summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAndrej Shadura <andrewsh@debian.org>2021-10-19 19:02:19 +0200
committerAndrej Shadura <andrewsh@debian.org>2021-10-19 19:02:19 +0200
commit94d2082531bf10c3cdf17b4e8fde9ca1a6c9de40 (patch)
tree8a96d1eb4c266243e10504a968fd49cb780df9d4
parent6b06932344e635f554420698ecd1954e31d0c6ea (diff)
New upstream version 1.45.0
-rwxr-xr-x.ci/scripts/test_synapse_port_db.sh4
-rw-r--r--.github/CODEOWNERS2
-rw-r--r--.github/workflows/tests.yml23
-rw-r--r--CHANGES.md134
-rw-r--r--README.rst9
-rw-r--r--debian/changelog22
-rw-r--r--debian/matrix-synapse-py3.links1
-rw-r--r--docs/MSC1711_certificates_FAQ.md4
-rw-r--r--docs/README.md6
-rw-r--r--docs/development/contributing_guide.md4
-rw-r--r--docs/development/saml.md11
-rw-r--r--docs/modules/spam_checker_callbacks.md50
-rw-r--r--docs/upgrade.md9
-rw-r--r--docs/usage/administration/admin_api/registration_tokens.md3
-rw-r--r--docs/welcome_and_overview.md79
-rw-r--r--mypy.ini114
-rwxr-xr-xscripts-dev/lint.sh2
-rwxr-xr-xscripts-dev/make_full_schema.sh2
-rwxr-xr-xscripts-dev/release.py36
-rwxr-xr-xscripts/synapse_port_db6
-rwxr-xr-xscripts/update_synapse_database (renamed from scripts-dev/update_database)50
-rwxr-xr-xsetup.py32
-rw-r--r--synapse/__init__.py2
-rw-r--r--synapse/api/filtering.py117
-rw-r--r--synapse/api/ratelimiting.py86
-rw-r--r--synapse/app/_base.py10
-rw-r--r--synapse/app/admin_cmd.py8
-rw-r--r--synapse/app/generic_worker.py2
-rw-r--r--synapse/app/homeserver.py16
-rw-r--r--synapse/app/phone_stats_home.py8
-rw-r--r--synapse/config/_base.py64
-rw-r--r--synapse/config/account_validity.py2
-rw-r--r--synapse/config/cas.py2
-rw-r--r--synapse/config/emailconfig.py9
-rw-r--r--synapse/config/key.py6
-rw-r--r--synapse/config/oidc.py2
-rw-r--r--synapse/config/registration.py7
-rw-r--r--synapse/config/repository.py2
-rw-r--r--synapse/config/saml2.py2
-rw-r--r--synapse/config/server.py104
-rw-r--r--synapse/config/server_notices.py4
-rw-r--r--synapse/config/sso.py6
-rw-r--r--synapse/config/tls.py9
-rw-r--r--synapse/event_auth.py156
-rw-r--r--synapse/events/builder.py20
-rw-r--r--synapse/events/presence_router.py6
-rw-r--r--synapse/events/spamcheck.py59
-rw-r--r--synapse/events/third_party_rules.py9
-rw-r--r--synapse/events/utils.py2
-rw-r--r--synapse/federation/federation_server.py5
-rw-r--r--synapse/federation/transport/server/__init__.py2
-rw-r--r--synapse/handlers/_base.py120
-rw-r--r--synapse/handlers/account_validity.py8
-rw-r--r--synapse/handlers/admin.py7
-rw-r--r--synapse/handlers/auth.py10
-rw-r--r--synapse/handlers/deactivate_account.py10
-rw-r--r--synapse/handlers/device.py10
-rw-r--r--synapse/handlers/directory.py11
-rw-r--r--synapse/handlers/event_auth.py15
-rw-r--r--synapse/handlers/events.py12
-rw-r--r--synapse/handlers/federation.py76
-rw-r--r--synapse/handlers/federation_event.py169
-rw-r--r--synapse/handlers/identity.py22
-rw-r--r--synapse/handlers/initial_sync.py8
-rw-r--r--synapse/handlers/message.py86
-rw-r--r--synapse/handlers/pagination.py22
-rw-r--r--synapse/handlers/profile.py17
-rw-r--r--synapse/handlers/read_marker.py5
-rw-r--r--synapse/handlers/receipts.py6
-rw-r--r--synapse/handlers/register.py24
-rw-r--r--synapse/handlers/room.py25
-rw-r--r--synapse/handlers/room_batch.py423
-rw-r--r--synapse/handlers/room_list.py7
-rw-r--r--synapse/handlers/room_member.py74
-rw-r--r--synapse/handlers/saml.py7
-rw-r--r--synapse/handlers/search.py11
-rw-r--r--synapse/handlers/send_email.py9
-rw-r--r--synapse/handlers/set_password.py6
-rw-r--r--synapse/handlers/ui_auth/checkers.py14
-rw-r--r--synapse/handlers/user_directory.py108
-rw-r--r--synapse/http/client.py2
-rw-r--r--synapse/http/matrixfederationclient.py10
-rw-r--r--synapse/http/server.py5
-rw-r--r--synapse/logging/_terse_json.py6
-rw-r--r--synapse/logging/context.py16
-rw-r--r--synapse/logging/opentracing.py9
-rw-r--r--synapse/metrics/background_process_metrics.py2
-rw-r--r--synapse/push/__init__.py2
-rw-r--r--synapse/push/bulk_push_rule_evaluator.py20
-rw-r--r--synapse/push/clientformat.py4
-rw-r--r--synapse/push/httppusher.py4
-rw-r--r--synapse/push/mailer.py2
-rw-r--r--synapse/replication/http/_base.py154
-rw-r--r--synapse/replication/slave/storage/_slaved_id_tracker.py4
-rw-r--r--synapse/replication/slave/storage/pushers.py10
-rw-r--r--synapse/replication/tcp/client.py2
-rw-r--r--synapse/replication/tcp/handler.py15
-rw-r--r--synapse/replication/tcp/redis.py8
-rw-r--r--synapse/replication/tcp/resource.py2
-rw-r--r--synapse/rest/admin/users.py4
-rw-r--r--synapse/rest/client/account.py40
-rw-r--r--synapse/rest/client/auth.py8
-rw-r--r--synapse/rest/client/capabilities.py10
-rw-r--r--synapse/rest/client/filter.py2
-rw-r--r--synapse/rest/client/login.py6
-rw-r--r--synapse/rest/client/profile.py6
-rw-r--r--synapse/rest/client/push_rule.py4
-rw-r--r--synapse/rest/client/register.py32
-rw-r--r--synapse/rest/client/room.py2
-rw-r--r--synapse/rest/client/room_batch.py337
-rw-r--r--synapse/rest/client/shared_rooms.py2
-rw-r--r--synapse/rest/client/sync.py2
-rw-r--r--synapse/rest/client/voip.py2
-rw-r--r--synapse/rest/media/v1/__init__.py38
-rw-r--r--synapse/rest/media/v1/oembed.py28
-rw-r--r--synapse/rest/media/v1/preview_url_resource.py186
-rw-r--r--synapse/rest/media/v1/thumbnailer.py21
-rw-r--r--synapse/rest/well_known.py4
-rw-r--r--synapse/server.py16
-rw-r--r--synapse/server_notices/resource_limits_server_notices.py8
-rw-r--r--synapse/server_notices/server_notices_manager.py8
-rw-r--r--synapse/state/__init__.py2
-rw-r--r--synapse/state/v1.py12
-rw-r--r--synapse/state/v2.py6
-rw-r--r--synapse/storage/databases/main/censor_events.py8
-rw-r--r--synapse/storage/databases/main/client_ips.py15
-rw-r--r--synapse/storage/databases/main/events.py34
-rw-r--r--synapse/storage/databases/main/filtering.py8
-rw-r--r--synapse/storage/databases/main/monthly_active_users.py36
-rw-r--r--synapse/storage/databases/main/push_rule.py8
-rw-r--r--synapse/storage/databases/main/pusher.py10
-rw-r--r--synapse/storage/databases/main/registration.py17
-rw-r--r--synapse/storage/databases/main/room.py8
-rw-r--r--synapse/storage/databases/main/room_batch.py6
-rw-r--r--synapse/storage/databases/main/search.py4
-rw-r--r--synapse/storage/databases/main/user_directory.py101
-rw-r--r--synapse/storage/prepare_database.py6
-rw-r--r--synapse/storage/schema/__init__.py6
-rw-r--r--synapse/storage/state.py172
-rw-r--r--synapse/storage/util/id_generators.py225
-rw-r--r--synapse/storage/util/sequence.py6
-rw-r--r--synapse/util/__init__.py11
-rw-r--r--synapse/util/async_helpers.py6
-rw-r--r--synapse/util/caches/cached_call.py2
-rw-r--r--synapse/util/caches/deferred_cache.py11
-rw-r--r--synapse/util/caches/lrucache.py57
-rw-r--r--synapse/util/caches/response_cache.py6
-rw-r--r--synapse/util/caches/stream_change_cache.py6
-rw-r--r--synapse/util/caches/ttlcache.py12
-rw-r--r--synapse/util/daemonize.py8
-rw-r--r--synapse/util/metrics.py27
-rw-r--r--synapse/util/patch_inline_callbacks.py28
-rw-r--r--synapse/util/threepids.py4
-rw-r--r--synapse/util/versionstring.py25
-rw-r--r--tests/api/test_auth.py14
-rw-r--r--tests/appservice/test_scheduler.py40
-rw-r--r--tests/config/test_base.py21
-rw-r--r--tests/config/test_cache.py54
-rw-r--r--tests/config/test_load.py18
-rw-r--r--tests/config/test_tls.py38
-rw-r--r--tests/events/test_presence_router.py7
-rw-r--r--tests/federation/test_federation_sender.py6
-rw-r--r--tests/federation/test_federation_server.py2
-rw-r--r--tests/handlers/test_profile.py4
-rw-r--r--tests/handlers/test_register.py95
-rw-r--r--tests/handlers/test_stats.py21
-rw-r--r--tests/handlers/test_user_directory.py680
-rw-r--r--tests/http/test_fedclient.py2
-rw-r--r--tests/logging/test_terse_json.py28
-rw-r--r--tests/module_api/test_api.py7
-rw-r--r--tests/replication/_base.py23
-rw-r--r--tests/rest/admin/test_user.py10
-rw-r--r--tests/rest/client/test_account.py55
-rw-r--r--tests/rest/client/test_capabilities.py2
-rw-r--r--tests/rest/client/test_identity.py2
-rw-r--r--tests/rest/client/test_login.py23
-rw-r--r--tests/rest/client/test_presence.py2
-rw-r--r--tests/rest/client/test_register.py8
-rw-r--r--tests/rest/client/test_rooms.py171
-rw-r--r--tests/rest/client/utils.py13
-rw-r--r--tests/rest/media/v1/test_url_preview.py131
-rw-r--r--tests/server.py8
-rw-r--r--tests/server_notices/test_resource_limits_server_notices.py2
-rw-r--r--tests/storage/databases/main/test_room.py7
-rw-r--r--tests/storage/test_appservice.py2
-rw-r--r--tests/storage/test_cleanup_extrems.py7
-rw-r--r--tests/storage/test_client_ips.py64
-rw-r--r--tests/storage/test_event_chain.py14
-rw-r--r--tests/storage/test_monthly_active_users.py14
-rw-r--r--tests/storage/test_roommember.py14
-rw-r--r--tests/storage/test_state.py513
-rw-r--r--tests/storage/test_txn_limit.py2
-rw-r--r--tests/storage/test_user_directory.py406
-rw-r--r--tests/test_event_auth.py108
-rw-r--r--tests/test_federation.py1
-rw-r--r--tests/test_mau.py39
-rw-r--r--tests/test_preview.py40
-rw-r--r--tests/unittest.py50
-rw-r--r--tox.ini2
199 files changed, 5120 insertions, 2238 deletions
diff --git a/.ci/scripts/test_synapse_port_db.sh b/.ci/scripts/test_synapse_port_db.sh
index 2b4e5ec1..50115b30 100755
--- a/.ci/scripts/test_synapse_port_db.sh
+++ b/.ci/scripts/test_synapse_port_db.sh
@@ -25,7 +25,7 @@ python -m synapse.app.homeserver --generate-keys -c .ci/sqlite-config.yaml
echo "--- Prepare test database"
# Make sure the SQLite3 database is using the latest schema and has no pending background update.
-scripts-dev/update_database --database-config .ci/sqlite-config.yaml
+scripts/update_synapse_database --database-config .ci/sqlite-config.yaml --run-background-updates
# Create the PostgreSQL database.
.ci/scripts/postgres_exec.py "CREATE DATABASE synapse"
@@ -46,7 +46,7 @@ echo "--- Prepare empty SQLite database"
# we do this by deleting the sqlite db, and then doing the same again.
rm .ci/test_db.db
-scripts-dev/update_database --database-config .ci/sqlite-config.yaml
+scripts/update_synapse_database --database-config .ci/sqlite-config.yaml --run-background-updates
# re-create the PostgreSQL database.
.ci/scripts/postgres_exec.py \
diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS
new file mode 100644
index 00000000..d6cd75f1
--- /dev/null
+++ b/.github/CODEOWNERS
@@ -0,0 +1,2 @@
+# Automatically request reviews from the synapse-core team when a pull request comes in.
+* @matrix-org/synapse-core \ No newline at end of file
diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml
index fa9c5e03..30a911fd 100644
--- a/.github/workflows/tests.yml
+++ b/.github/workflows/tests.yml
@@ -76,22 +76,25 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
- python-version: ["3.6", "3.7", "3.8", "3.9"]
+ python-version: ["3.6", "3.7", "3.8", "3.9", "3.10"]
database: ["sqlite"]
+ toxenv: ["py"]
include:
# Newest Python without optional deps
- - python-version: "3.9"
- toxenv: "py-noextras,combine"
+ - python-version: "3.10"
+ toxenv: "py-noextras"
# Oldest Python with PostgreSQL
- python-version: "3.6"
database: "postgres"
postgres-version: "9.6"
+ toxenv: "py"
- # Newest Python with PostgreSQL
- - python-version: "3.9"
+ # Newest Python with newest PostgreSQL
+ - python-version: "3.10"
database: "postgres"
- postgres-version: "13"
+ postgres-version: "14"
+ toxenv: "py"
steps:
- uses: actions/checkout@v2
@@ -111,7 +114,7 @@ jobs:
if: ${{ matrix.postgres-version }}
timeout-minutes: 2
run: until pg_isready -h localhost; do sleep 1; done
- - run: tox -e py,combine
+ - run: tox -e ${{ matrix.toxenv }}
env:
TRIAL_FLAGS: "--jobs=2"
SYNAPSE_POSTGRES: ${{ matrix.database == 'postgres' || '' }}
@@ -169,7 +172,7 @@ jobs:
with:
python-version: ${{ matrix.python-version }}
- run: pip install tox
- - run: tox -e py,combine
+ - run: tox -e py
env:
TRIAL_FLAGS: "--jobs=2"
- name: Dump logs
@@ -256,8 +259,8 @@ jobs:
- python-version: "3.6"
postgres-version: "9.6"
- - python-version: "3.9"
- postgres-version: "13"
+ - python-version: "3.10"
+ postgres-version: "14"
services:
postgres:
diff --git a/CHANGES.md b/CHANGES.md
index 3f048ba8..435387d7 100644
--- a/CHANGES.md
+++ b/CHANGES.md
@@ -1,3 +1,137 @@
+Synapse 1.45.0 (2021-10-19)
+===========================
+
+No functional changes since Synapse 1.45.0rc2.
+
+Known Issues
+------------
+
+- A suspected [performance regression](https://github.com/matrix-org/synapse/issues/11049) which was first reported after the release of 1.44.0 remains unresolved.
+
+ We have not been able to identify a probable cause. Affected users report that setting up a federation sender worker appears to alleviate symptoms of the regression.
+
+Improved Documentation
+----------------------
+
+- Reword changelog to clarify concerns about a suspected performance regression in 1.44.0. ([\#11117](https://github.com/matrix-org/synapse/issues/11117))
+
+
+Synapse 1.45.0rc2 (2021-10-14)
+==============================
+
+This release candidate [fixes](https://github.com/matrix-org/synapse/issues/11053) a user directory [bug](https://github.com/matrix-org/synapse/issues/11025) present in 1.45.0rc1.
+
+Known Issues
+------------
+
+- A suspected [performance regression](https://github.com/matrix-org/synapse/issues/11049) which was first reported after the release of 1.44.0 remains unresolved.
+
+ We have not been able to identify a probable cause. Affected users report that setting up a federation sender worker appears to alleviate symptoms of the regression.
+
+Bugfixes
+--------
+
+- Fix a long-standing bug when using multiple event persister workers where events were not correctly sent down `/sync` due to a race. ([\#11045](https://github.com/matrix-org/synapse/issues/11045))
+- Fix a bug introduced in Synapse 1.45.0rc1 where the user directory would stop updating if it processed an event from a
+ user not in the `users` table. ([\#11053](https://github.com/matrix-org/synapse/issues/11053))
+- Fix a bug introduced in Synapse 1.44.0 when logging errors during oEmbed processing. ([\#11061](https://github.com/matrix-org/synapse/issues/11061))
+
+
+Internal Changes
+----------------
+
+- Add an 'approximate difference' method to `StateFilter`. ([\#10825](https://github.com/matrix-org/synapse/issues/10825))
+- Fix inconsistent behavior of `get_last_client_by_ip` when reporting data that has not been stored in the database yet. ([\#10970](https://github.com/matrix-org/synapse/issues/10970))
+- Fix a bug introduced in Synapse 1.21.0 that causes opentracing and Prometheus metrics for replication requests to be measured incorrectly. ([\#10996](https://github.com/matrix-org/synapse/issues/10996))
+- Ensure that cache config tests do not share state. ([\#11036](https://github.com/matrix-org/synapse/issues/11036))
+
+
+Synapse 1.45.0rc1 (2021-10-12)
+==============================
+
+**Note:** Media storage providers module that read from Synapse's configuration need changes as of this version, see the [upgrade notes](https://matrix-org.github.io/synapse/develop/upgrade#upgrading-to-v1450) for more information.
+
+Known Issues
+------------
+
+- We are investigating [a performance issue](https://github.com/matrix-org/synapse/issues/11049) which was reported after the release of 1.44.0.
+- We are aware of [a bug](https://github.com/matrix-org/synapse/issues/11025) with the user directory when using application services. A second release candidate is expected which will resolve this.
+
+Features
+--------
+
+- Add [MSC3069](https://github.com/matrix-org/matrix-doc/pull/3069) support to `/account/whoami`. ([\#9655](https://github.com/matrix-org/synapse/issues/9655))
+- Support autodiscovery of oEmbed previews. ([\#10822](https://github.com/matrix-org/synapse/issues/10822))
+- Add a `user_may_send_3pid_invite` spam checker callback for modules to allow or deny 3PID invites. ([\#10894](https://github.com/matrix-org/synapse/issues/10894))
+- Add a spam checker callback to allow or deny room joins. ([\#10910](https://github.com/matrix-org/synapse/issues/10910))
+- Include an `update_synapse_database` script in the distribution. Contributed by @Fizzadar at Beeper. ([\#10954](https://github.com/matrix-org/synapse/issues/10954))
+- Include exception information in JSON logging output. Contributed by @Fizzadar at Beeper. ([\#11028](https://github.com/matrix-org/synapse/issues/11028))
+
+
+Bugfixes
+--------
+
+- Fix a minor bug in the response to `/_matrix/client/r0/voip/turnServer`. Contributed by @lukaslihotzki. ([\#10922](https://github.com/matrix-org/synapse/issues/10922))
+- Fix a bug where empty `yyyy-mm-dd/` directories would be left behind in the media store's `url_cache_thumbnails/` directory. ([\#10924](https://github.com/matrix-org/synapse/issues/10924))
+- Fix a bug introduced in Synapse v1.40.0 where the signature checks for room version 8 and 9 could be applied to earlier room versions in some situations. ([\#10927](https://github.com/matrix-org/synapse/issues/10927))
+- Fix a long-standing bug wherein deactivated users still count towards the monthly active users limit. ([\#10947](https://github.com/matrix-org/synapse/issues/10947))
+- Fix a long-standing bug which meant that events received over federation were sometimes incorrectly accepted into the room state. ([\#10956](https://github.com/matrix-org/synapse/issues/10956))
+- Fix a long-standing bug where rebuilding the user directory wouldn't exclude support and deactivated users. ([\#10960](https://github.com/matrix-org/synapse/issues/10960))
+- Fix [MSC2716](https://github.com/matrix-org/matrix-doc/pull/2716) `/batch_send` endpoint rejecting subsequent batches with unknown batch ID error in existing room versions from the room creator. ([\#10962](https://github.com/matrix-org/synapse/issues/10962))
+- Fix a bug that could leak local users' per-room nicknames and avatars when the user directory is rebuilt. ([\#10981](https://github.com/matrix-org/synapse/issues/10981))
+- Fix a long-standing bug where the remainder of a batch of user directory changes would be silently dropped if the server left a room early in the batch. ([\#10982](https://github.com/matrix-org/synapse/issues/10982))
+- Correct a bugfix introduced in Synapse v1.44.0 that would catch the wrong error if a connection is lost before a response could be written to it. ([\#10995](https://github.com/matrix-org/synapse/issues/10995))
+- Fix a long-standing bug where local users' per-room nicknames/avatars were visible to anyone who could see you in the user directory. ([\#11002](https://github.com/matrix-org/synapse/issues/11002))
+- Fix a long-standing bug where a user's per-room nickname/avatar would overwrite their profile in the user directory when a room was made public. ([\#11003](https://github.com/matrix-org/synapse/issues/11003))
+- Work around a regression, introduced in Synapse v1.39.0, that caused `SynapseError`s raised by the experimental third-party rules module callback `check_event_allowed` to be ignored. ([\#11042](https://github.com/matrix-org/synapse/issues/11042))
+- Fix a bug in [MSC2716](https://github.com/matrix-org/matrix-doc/pull/2716) insertion events in rooms that could cause cross-talk/conflicts between batches. ([\#10877](https://github.com/matrix-org/synapse/issues/10877))
+
+
+Improved Documentation
+----------------------
+
+- Change wording ("reference homeserver") in Synapse repository documentation. Contributed by @maxkratz. ([\#10971](https://github.com/matrix-org/synapse/issues/10971))
+- Fix a dead URL in development documentation (SAML) and change wording from "Riot" to "Element". Contributed by @maxkratz. ([\#10973](https://github.com/matrix-org/synapse/issues/10973))
+- Add additional content to the Welcome and Overview page of the documentation. ([\#10990](https://github.com/matrix-org/synapse/issues/10990))
+- Update links to MSCs in documentation. Contributed by @dklimpel. ([\#10991](https://github.com/matrix-org/synapse/issues/10991))
+
+
+Internal Changes
+----------------
+
+- Improve type hinting in `synapse.util`. ([\#10888](https://github.com/matrix-org/synapse/issues/10888))
+- Add further type hints to `synapse.storage.util`. ([\#10892](https://github.com/matrix-org/synapse/issues/10892))
+- Fix type hints to be compatible with an upcoming change to Twisted. ([\#10895](https://github.com/matrix-org/synapse/issues/10895))
+- Update utility code to handle C implementations of frozendict. ([\#10902](https://github.com/matrix-org/synapse/issues/10902))
+- Drop old functionality which maintained database compatibility with Synapse versions before v1.31. ([\#10903](https://github.com/matrix-org/synapse/issues/10903))
+- Clean-up configuration helper classes for the `ServerConfig` class. ([\#10915](https://github.com/matrix-org/synapse/issues/10915))
+- Use direct references to config flags. ([\#10916](https://github.com/matrix-org/synapse/issues/10916), [\#10959](https://github.com/matrix-org/synapse/issues/10959), [\#10985](https://github.com/matrix-org/synapse/issues/10985))
+- Clean up some of the federation event authentication code for clarity. ([\#10926](https://github.com/matrix-org/synapse/issues/10926), [\#10940](https://github.com/matrix-org/synapse/issues/10940), [\#10986](https://github.com/matrix-org/synapse/issues/10986), [\#10987](https://github.com/matrix-org/synapse/issues/10987), [\#10988](https://github.com/matrix-org/synapse/issues/10988), [\#11010](https://github.com/matrix-org/synapse/issues/11010), [\#11011](https://github.com/matrix-org/synapse/issues/11011))
+- Refactor various parts of the codebase to use `RoomVersion` objects instead of room version identifier strings. ([\#10934](https://github.com/matrix-org/synapse/issues/10934))
+- Refactor user directory tests in preparation for upcoming changes. ([\#10935](https://github.com/matrix-org/synapse/issues/10935))
+- Include the event id in the logcontext when handling PDUs received over federation. ([\#10936](https://github.com/matrix-org/synapse/issues/10936))
+- Fix logged errors in unit tests. ([\#10939](https://github.com/matrix-org/synapse/issues/10939))
+- Fix a broken test to ensure that consent configuration works during registration. ([\#10945](https://github.com/matrix-org/synapse/issues/10945))
+- Add type hints to filtering classes. ([\#10958](https://github.com/matrix-org/synapse/issues/10958))
+- Add type-hint to `HomeserverTestcase.setup_test_homeserver`. ([\#10961](https://github.com/matrix-org/synapse/issues/10961))
+- Fix the test utility function `create_room_as` so that `is_public=True` will explicitly set the `visibility` parameter of room creation requests to `public`. Contributed by @AndrewFerr. ([\#10963](https://github.com/matrix-org/synapse/issues/10963))
+- Make the release script more robust and transparent. ([\#10966](https://github.com/matrix-org/synapse/issues/10966))
+- Refactor [MSC2716](https://github.com/matrix-org/matrix-doc/pull/2716) `/batch_send` mega function into smaller handler functions. ([\#10974](https://github.com/matrix-org/synapse/issues/10974))
+- Log stack traces when a missing opentracing span is detected. ([\#10983](https://github.com/matrix-org/synapse/issues/10983))
+- Update GHA config to run tests against Python 3.10 and PostgreSQL 14. ([\#10992](https://github.com/matrix-org/synapse/issues/10992))
+- Fix a long-standing bug where `ReadWriteLock`s could drop logging contexts on exit. ([\#10993](https://github.com/matrix-org/synapse/issues/10993))
+- Add a `CODEOWNERS` file to automatically request reviews from the `@matrix-org/synapse-core` team on new pull requests. ([\#10994](https://github.com/matrix-org/synapse/issues/10994))
+- Add further type hints to `synapse.state`. ([\#11004](https://github.com/matrix-org/synapse/issues/11004))
+- Remove the deprecated `BaseHandler` object. ([\#11005](https://github.com/matrix-org/synapse/issues/11005))
+- Bump mypy version for CI to 0.910, and pull in new type stubs for dependencies. ([\#11006](https://github.com/matrix-org/synapse/issues/11006))
+- Fix CI to run the unit tests without optional deps. ([\#11017](https://github.com/matrix-org/synapse/issues/11017))
+- Ensure that cache config tests do not share state. ([\#11019](https://github.com/matrix-org/synapse/issues/11019))
+- Add additional type hints to `synapse.server_notices`. ([\#11021](https://github.com/matrix-org/synapse/issues/11021))
+- Add additional type hints for `synapse.push`. ([\#11023](https://github.com/matrix-org/synapse/issues/11023))
+- When installing the optional developer dependencies, also include the dependencies needed for type-checking and unit testing. ([\#11034](https://github.com/matrix-org/synapse/issues/11034))
+- Remove unnecessary list comprehension from `synapse_port_db` to satisfy code style requirements. ([\#11043](https://github.com/matrix-org/synapse/issues/11043))
+
+
Synapse 1.44.0 (2021-10-05)
===========================
diff --git a/README.rst b/README.rst
index 524a3a51..50de3a49 100644
--- a/README.rst
+++ b/README.rst
@@ -55,11 +55,8 @@ solutions. The hope is for Matrix to act as the building blocks for a new
generation of fully open and interoperable messaging and VoIP apps for the
internet.
-Synapse is a reference "homeserver" implementation of Matrix from the core
-development team at matrix.org, written in Python/Twisted. It is intended to
-showcase the concept of Matrix and let folks see the spec in the context of a
-codebase and let you run your own homeserver and generally help bootstrap the
-ecosystem.
+Synapse is a Matrix "homeserver" implementation developed by the matrix.org core
+team, written in Python 3/Twisted.
In Matrix, every user runs one or more Matrix clients, which connect through to
a Matrix homeserver. The homeserver stores all their personal chat history and
@@ -301,7 +298,7 @@ to install using pip and a virtualenv::
python3 -m venv ./env
source ./env/bin/activate
- pip install -e ".[all,test]"
+ pip install -e ".[all,dev]"
This will run a process of downloading and installing all the needed
dependencies into a virtual env. If any dependencies fail to install,
diff --git a/debian/changelog b/debian/changelog
index 9e878fbc..5fefb2f2 100644
--- a/debian/changelog
+++ b/debian/changelog
@@ -1,3 +1,25 @@
+matrix-synapse-py3 (1.45.0) stable; urgency=medium
+
+ * New synapse release 1.45.0.
+
+ -- Synapse Packaging team <packages@matrix.org> Tue, 19 Oct 2021 11:18:53 +0100
+
+matrix-synapse-py3 (1.45.0~rc2) stable; urgency=medium
+
+ * New synapse release 1.45.0~rc2.
+
+ -- Synapse Packaging team <packages@matrix.org> Thu, 14 Oct 2021 10:58:24 +0100
+
+matrix-synapse-py3 (1.45.0~rc1) stable; urgency=medium
+
+ [ Nick @ Beeper ]
+ * Include an `update_synapse_database` script in the distribution.
+
+ [ Synapse Packaging team ]
+ * New synapse release 1.45.0~rc1.
+
+ -- Synapse Packaging team <packages@matrix.org> Tue, 12 Oct 2021 10:46:27 +0100
+
matrix-synapse-py3 (1.44.0) stable; urgency=medium
* New synapse release 1.44.0.
diff --git a/debian/matrix-synapse-py3.links b/debian/matrix-synapse-py3.links
index 53e29654..7eeba180 100644
--- a/debian/matrix-synapse-py3.links
+++ b/debian/matrix-synapse-py3.links
@@ -3,3 +3,4 @@ opt/venvs/matrix-synapse/bin/register_new_matrix_user usr/bin/register_new_matri
opt/venvs/matrix-synapse/bin/synapse_port_db usr/bin/synapse_port_db
opt/venvs/matrix-synapse/bin/synapse_review_recent_signups usr/bin/synapse_review_recent_signups
opt/venvs/matrix-synapse/bin/synctl usr/bin/synctl
+opt/venvs/matrix-synapse/bin/update_synapse_database usr/bin/update_synapse_database
diff --git a/docs/MSC1711_certificates_FAQ.md b/docs/MSC1711_certificates_FAQ.md
index 7d71c190..086899a9 100644
--- a/docs/MSC1711_certificates_FAQ.md
+++ b/docs/MSC1711_certificates_FAQ.md
@@ -3,7 +3,7 @@
## Historical Note
This document was originally written to guide server admins through the upgrade
path towards Synapse 1.0. Specifically,
-[MSC1711](https://github.com/matrix-org/matrix-doc/blob/master/proposals/1711-x509-for-federation.md)
+[MSC1711](https://github.com/matrix-org/matrix-doc/blob/main/proposals/1711-x509-for-federation.md)
required that all servers present valid TLS certificates on their federation
API. Admins were encouraged to achieve compliance from version 0.99.0 (released
in February 2019) ahead of version 1.0 (released June 2019) enforcing the
@@ -282,7 +282,7 @@ coffin of the Perspectives project (which was already pretty dead). So, the
Spec Core Team decided that a better approach would be to mandate valid TLS
certificates for federation alongside the rest of the Web. More details can be
found in
-[MSC1711](https://github.com/matrix-org/matrix-doc/blob/master/proposals/1711-x509-for-federation.md#background-the-failure-of-the-perspectives-approach).
+[MSC1711](https://github.com/matrix-org/matrix-doc/blob/main/proposals/1711-x509-for-federation.md#background-the-failure-of-the-perspectives-approach).
This results in a breaking change, which is disruptive, but absolutely critical
for the security model. However, the existence of Let's Encrypt as a trivial
diff --git a/docs/README.md b/docs/README.md
index e113f55d..6d70f5af 100644
--- a/docs/README.md
+++ b/docs/README.md
@@ -6,9 +6,9 @@ Please update any links to point to the new website instead.
## About
This directory currently holds a series of markdown files documenting how to install, use
-and develop Synapse, the reference Matrix homeserver. The documentation is readable directly
-from this repository, but it is recommended to instead browse through the
-[website](https://matrix-org.github.io/synapse) for easier discoverability.
+and develop Synapse. The documentation is readable directly from this repository, but it is
+recommended to instead browse through the [website](https://matrix-org.github.io/synapse) for
+easier discoverability.
## Adding to the documentation
diff --git a/docs/development/contributing_guide.md b/docs/development/contributing_guide.md
index 71336636..3bf08a72 100644
--- a/docs/development/contributing_guide.md
+++ b/docs/development/contributing_guide.md
@@ -50,7 +50,7 @@ setup a *virtualenv*, as follows:
cd path/where/you/have/cloned/the/repository
python3 -m venv ./env
source ./env/bin/activate
-pip install -e ".[all,lint,mypy,test]"
+pip install -e ".[all,dev]"
pip install tox
```
@@ -63,7 +63,7 @@ TBD
# 5. Get in touch.
-Join our developer community on Matrix: #synapse-dev:matrix.org !
+Join our developer community on Matrix: [#synapse-dev:matrix.org](https://matrix.to/#/#synapse-dev:matrix.org)!
# 6. Pick an issue.
diff --git a/docs/development/saml.md b/docs/development/saml.md
index a9bfd2dc..60a431d6 100644
--- a/docs/development/saml.md
+++ b/docs/development/saml.md
@@ -1,10 +1,9 @@
# How to test SAML as a developer without a server
-https://capriza.github.io/samling/samling.html (https://github.com/capriza/samling) is a great
-resource for being able to tinker with the SAML options within Synapse without needing to
-deploy and configure a complicated software stack.
+https://fujifish.github.io/samling/samling.html (https://github.com/fujifish/samling) is a great resource for being able to tinker with the
+SAML options within Synapse without needing to deploy and configure a complicated software stack.
-To make Synapse (and therefore Riot) use it:
+To make Synapse (and therefore Element) use it:
1. Use the samling.html URL above or deploy your own and visit the IdP Metadata tab.
2. Copy the XML to your clipboard.
@@ -26,9 +25,9 @@ To make Synapse (and therefore Riot) use it:
the dependencies are installed and ready to go.
7. Restart Synapse.
-Then in Riot:
+Then in Element:
-1. Visit the login page with a Riot pointing at your homeserver.
+1. Visit the login page and point Element towards your homeserver using the `public_baseurl` above.
2. Click the Single Sign-On button.
3. On the samling page, enter a Name Identifier and add a SAML Attribute for `uid=your_localpart`.
The response must also be signed.
diff --git a/docs/modules/spam_checker_callbacks.md b/docs/modules/spam_checker_callbacks.md
index 7920ac5f..787e9907 100644
--- a/docs/modules/spam_checker_callbacks.md
+++ b/docs/modules/spam_checker_callbacks.md
@@ -19,6 +19,21 @@ either a `bool` to indicate whether the event must be rejected because of spam,
to indicate the event must be rejected because of spam and to give a rejection reason to
forward to clients.
+### `user_may_join_room`
+
+```python
+async def user_may_join_room(user: str, room: str, is_invited: bool) -> bool
+```
+
+Called when a user is trying to join a room. The module must return a `bool` to indicate
+whether the user can join the room. The user is represented by their Matrix user ID (e.g.
+`@alice:example.com`) and the room is represented by its Matrix ID (e.g.
+`!room:example.com`). The module is also given a boolean to indicate whether the user
+currently has a pending invite in the room.
+
+This callback isn't called if the join is performed by a server administrator, or in the
+context of a room creation.
+
### `user_may_invite`
```python
@@ -29,6 +44,41 @@ Called when processing an invitation. The module must return a `bool` indicating
the inviter can invite the invitee to the given room. Both inviter and invitee are
represented by their Matrix user ID (e.g. `@alice:example.com`).
+### `user_may_send_3pid_invite`
+
+```python
+async def user_may_send_3pid_invite(
+ inviter: str,
+ medium: str,
+ address: str,
+ room_id: str,
+) -> bool
+```
+
+Called when processing an invitation using a third-party identifier (also called a 3PID,
+e.g. an email address or a phone number). The module must return a `bool` indicating
+whether the inviter can invite the invitee to the given room.
+
+The inviter is represented by their Matrix user ID (e.g. `@alice:example.com`), and the
+invitee is represented by its medium (e.g. "email") and its address
+(e.g. `alice@example.com`). See [the Matrix specification](https://matrix.org/docs/spec/appendices#pid-types)
+for more information regarding third-party identifiers.
+
+For example, a call to this callback to send an invitation to the email address
+`alice@example.com` would look like this:
+
+```python
+await user_may_send_3pid_invite(
+ "@bob:example.com", # The inviter's user ID
+ "email", # The medium of the 3PID to invite
+ "alice@example.com", # The address of the 3PID to invite
+ "!some_room:example.com", # The ID of the room to send the invite into
+)
+```
+
+**Note**: If the third-party identifier is already associated with a matrix user ID,
+[`user_may_invite`](#user_may_invite) will be used instead.
+
### `user_may_create_room`
```python
diff --git a/docs/upgrade.md b/docs/upgrade.md
index a8221372..18ecb267 100644
--- a/docs/upgrade.md
+++ b/docs/upgrade.md
@@ -85,6 +85,15 @@ process, for example:
dpkg -i matrix-synapse-py3_1.3.0+stretch1_amd64.deb
```
+# Upgrading to v1.45.0
+
+## Changes required to media storage provider modules when reading from the Synapse configuration object
+
+Media storage provider modules that read from the Synapse configuration object (i.e. that
+read the value of `hs.config.[...]`) now need to specify the configuration section they're
+reading from. This means that if a module reads the value of e.g. `hs.config.media_store_path`,
+it needs to replace it with `hs.config.media.media_store_path`.
+
# Upgrading to v1.44.0
## The URL preview cache is no longer mirrored to storage providers
diff --git a/docs/usage/administration/admin_api/registration_tokens.md b/docs/usage/administration/admin_api/registration_tokens.md
index 828c0277..c48d060d 100644
--- a/docs/usage/administration/admin_api/registration_tokens.md
+++ b/docs/usage/administration/admin_api/registration_tokens.md
@@ -1,7 +1,8 @@
# Registration Tokens
This API allows you to manage tokens which can be used to authenticate
-registration requests, as proposed in [MSC3231](https://github.com/govynnus/matrix-doc/blob/token-registration/proposals/3231-token-authenticated-registration.md).
+registration requests, as proposed in
+[MSC3231](https://github.com/matrix-org/matrix-doc/blob/main/proposals/3231-token-authenticated-registration.md).
To use it, you will need to enable the `registration_requires_token` config
option, and authenticate by providing an `access_token` for a server admin:
see [Admin API](../../usage/administration/admin_api).
diff --git a/docs/welcome_and_overview.md b/docs/welcome_and_overview.md
index 30e75984..aab2d6b4 100644
--- a/docs/welcome_and_overview.md
+++ b/docs/welcome_and_overview.md
@@ -1,4 +1,79 @@
# Introduction
-Welcome to the documentation repository for Synapse, the reference
-[Matrix](https://matrix.org) homeserver implementation. \ No newline at end of file
+Welcome to the documentation repository for Synapse, a
+[Matrix](https://matrix.org) homeserver implementation developed by the matrix.org core
+team.
+
+## Installing and using Synapse
+
+This documentation covers topics for **installation**, **configuration** and
+**maintainence** of your Synapse process:
+
+* Learn how to [install](setup/installation.md) and
+ [configure](usage/configuration/index.html) your own instance, perhaps with [Single
+ Sign-On](usage/configuration/user_authentication/index.html).
+
+* See how to [upgrade](upgrade.md) between Synapse versions.
+
+* Administer your instance using the [Admin
+ API](usage/administration/admin_api/index.html), installing [pluggable
+ modules](modules/index.html), or by accessing the [manhole](manhole.md).
+
+* Learn how to [read log lines](usage/administration/request_log.md), configure
+ [logging](usage/configuration/logging_sample_config.md) or set up [structured
+ logging](structured_logging.md).
+
+* Scale Synapse through additional [worker processes](workers.md).
+
+* Set up [monitoring and metrics](metrics-howto.md) to keep an eye on your
+ Synapse instance's performance.
+
+## Developing on Synapse
+
+Contributions are welcome! Synapse is primarily written in
+[Python](https://python.org). As a developer, you may be interested in the
+following documentation:
+
+* Read the [Contributing Guide](development/contributing_guide.md). It is meant
+ to walk new contributors through the process of developing and submitting a
+ change to the Synapse codebase (which is [hosted on
+ GitHub](https://github.com/matrix-org/synapse)).
+
+* Set up your [development
+ environment](development/contributing_guide.md#2-what-do-i-need), then learn
+ how to [lint](development/contributing_guide.md#run-the-linters) and
+ [test](development/contributing_guide.md#8-test-test-test) your code.
+
+* Look at [the issue tracker](https://github.com/matrix-org/synapse/issues) for
+ bugs to fix or features to add. If you're new, it may be best to start with
+ those labeled [good first
+ issue](https://github.com/matrix-org/synapse/issues?q=is%3Aissue+is%3Aopen+label%3A%22good+first+issue%22).
+
+* Understand [how Synapse is
+ built](development/internal_documentation/index.html), how to [migrate
+ database schemas](development/database_schema.md), learn about
+ [federation](federate.md) and how to [set up a local
+ federation](federate.md#running-a-demo-federation-of-synapses) for development.
+
+* We like to keep our `git` history clean. [Learn](development/git.md) how to
+ do so!
+
+* And finally, contribute to this documentation! The source for which is
+ [located here](https://github.com/matrix-org/synapse/tree/develop/docs).
+
+## Donating to Synapse development
+
+Want to help keep Synapse going but don't know how to code? Synapse is a
+[Matrix.org Foundation](https://matrix.org) project. Consider becoming a
+supportor on [Liberapay](https://liberapay.com/matrixdotorg),
+[Patreon](https://patreon.com/matrixdotorg) or through
+[PayPal](https://paypal.me/matrixdotorg) via a one-time donation.
+
+If you are an organisation or enterprise and would like to sponsor development,
+reach out to us over email at: support (at) matrix.org
+
+## Reporting a security vulnerability
+
+If you've found a security issue in Synapse or any other Matrix.org Foundation
+project, please report it to us in accordance with our [Security Disclosure
+Policy](https://www.matrix.org/security-disclosure-policy/). Thank you!
diff --git a/mypy.ini b/mypy.ini
index 437d0a46..a7019e2b 100644
--- a/mypy.ini
+++ b/mypy.ini
@@ -96,15 +96,48 @@ files =
[mypy-synapse.handlers.*]
disallow_untyped_defs = True
+[mypy-synapse.push.*]
+disallow_untyped_defs = True
+
[mypy-synapse.rest.*]
disallow_untyped_defs = True
+[mypy-synapse.server_notices.*]
+disallow_untyped_defs = True
+
+[mypy-synapse.state.*]
+disallow_untyped_defs = True
+
+[mypy-synapse.storage.util.*]
+disallow_untyped_defs = True
+
+[mypy-synapse.streams.*]
+disallow_untyped_defs = True
+
[mypy-synapse.util.batching_queue]
disallow_untyped_defs = True
+[mypy-synapse.util.caches.cached_call]
+disallow_untyped_defs = True
+
[mypy-synapse.util.caches.dictionary_cache]
disallow_untyped_defs = True
+[mypy-synapse.util.caches.lrucache]
+disallow_untyped_defs = True
+
+[mypy-synapse.util.caches.response_cache]
+disallow_untyped_defs = True
+
+[mypy-synapse.util.caches.stream_change_cache]
+disallow_untyped_defs = True
+
+[mypy-synapse.util.caches.ttl_cache]
+disallow_untyped_defs = True
+
+[mypy-synapse.util.daemonize]
+disallow_untyped_defs = True
+
[mypy-synapse.util.file_consumer]
disallow_untyped_defs = True
@@ -141,6 +174,9 @@ disallow_untyped_defs = True
[mypy-synapse.util.msisdn]
disallow_untyped_defs = True
+[mypy-synapse.util.patch_inline_callbacks]
+disallow_untyped_defs = True
+
[mypy-synapse.util.ratelimitutils]
disallow_untyped_defs = True
@@ -162,98 +198,106 @@ disallow_untyped_defs = True
[mypy-synapse.util.wheel_timer]
disallow_untyped_defs = True
-[mypy-pymacaroons.*]
-ignore_missing_imports = True
+[mypy-synapse.util.versionstring]
+disallow_untyped_defs = True
-[mypy-zope]
-ignore_missing_imports = True
+[mypy-tests.handlers.test_user_directory]
+disallow_untyped_defs = True
-[mypy-bcrypt]
-ignore_missing_imports = True
+[mypy-tests.storage.test_user_directory]
+disallow_untyped_defs = True
-[mypy-constantly]
-ignore_missing_imports = True
+;; Dependencies without annotations
+;; Before ignoring a module, check to see if type stubs are available.
+;; The `typeshed` project maintains stubs here:
+;; https://github.com/python/typeshed/tree/master/stubs
+;; and for each package `foo` there's a corresponding `types-foo` package on PyPI,
+;; which we can pull in as a dev dependency by adding to `setup.py`'s
+;; `CONDITIONAL_REQUIREMENTS["mypy"]` list.
-[mypy-twisted.*]
+[mypy-authlib.*]
ignore_missing_imports = True
-[mypy-treq.*]
+[mypy-bcrypt]
ignore_missing_imports = True
-[mypy-hyperlink]
+[mypy-canonicaljson]
ignore_missing_imports = True
-[mypy-h11]
+[mypy-constantly]
ignore_missing_imports = True
-[mypy-msgpack]
+[mypy-daemonize]
ignore_missing_imports = True
-[mypy-opentracing]
+[mypy-h11]
ignore_missing_imports = True
-[mypy-OpenSSL.*]
+[mypy-hiredis]
ignore_missing_imports = True
-[mypy-netaddr]
+[mypy-hyperlink]
ignore_missing_imports = True
-[mypy-saml2.*]
+[mypy-ijson.*]
ignore_missing_imports = True
-[mypy-canonicaljson]
+[mypy-jaeger_client.*]
ignore_missing_imports = True
-[mypy-jaeger_client.*]
+[mypy-josepy.*]
ignore_missing_imports = True
-[mypy-jsonschema]
+[mypy-jwt.*]
ignore_missing_imports = True
-[mypy-signedjson.*]
+[mypy-lxml]
ignore_missing_imports = True
-[mypy-prometheus_client.*]
+[mypy-msgpack]
ignore_missing_imports = True
-[mypy-service_identity.*]
+[mypy-nacl.*]
ignore_missing_imports = True
-[mypy-daemonize]
+[mypy-netaddr]
ignore_missing_imports = True
-[mypy-sentry_sdk]
+[mypy-opentracing]
ignore_missing_imports = True
-[mypy-PIL.*]
+[mypy-phonenumbers.*]
ignore_missing_imports = True
-[mypy-lxml]
+[mypy-prometheus_client.*]
ignore_missing_imports = True
-[mypy-jwt.*]
+[mypy-pymacaroons.*]
ignore_missing_imports = True
-[mypy-authlib.*]
+[mypy-pympler.*]
ignore_missing_imports = True
[mypy-rust_python_jaeger_reporter.*]
ignore_missing_imports = True
-[mypy-nacl.*]
+[mypy-saml2.*]
ignore_missing_imports = True
-[mypy-hiredis]
+[mypy-sentry_sdk]
ignore_missing_imports = True
-[mypy-josepy.*]
+[mypy-service_identity.*]
ignore_missing_imports = True
-[mypy-pympler.*]
+[mypy-signedjson.*]
ignore_missing_imports = True
-[mypy-phonenumbers.*]
+[mypy-treq.*]
ignore_missing_imports = True
-[mypy-ijson.*]
+[mypy-twisted.*]
+ignore_missing_imports = True
+
+[mypy-zope]
ignore_missing_imports = True
diff --git a/scripts-dev/lint.sh b/scripts-dev/lint.sh
index 809eff16..b6554a73 100755
--- a/scripts-dev/lint.sh
+++ b/scripts-dev/lint.sh
@@ -90,10 +90,10 @@ else
"scripts/hash_password"
"scripts/register_new_matrix_user"
"scripts/synapse_port_db"
+ "scripts/update_synapse_database"
"scripts-dev"
"scripts-dev/build_debian_packages"
"scripts-dev/sign_json"
- "scripts-dev/update_database"
"contrib" "synctl" "setup.py" "synmark" "stubs" ".ci"
)
fi
diff --git a/scripts-dev/make_full_schema.sh b/scripts-dev/make_full_schema.sh
index 39bf30d2..c3c90f4e 100755
--- a/scripts-dev/make_full_schema.sh
+++ b/scripts-dev/make_full_schema.sh
@@ -147,7 +147,7 @@ python -m synapse.app.homeserver --generate-keys -c "$SQLITE_CONFIG"
# Make sure the SQLite3 database is using the latest schema and has no pending background update.
echo "Running db background jobs..."
-scripts-dev/update_database --database-config "$SQLITE_CONFIG"
+scripts/update_synapse_database --database-config --run-background-updates "$SQLITE_CONFIG"
# Create the PostgreSQL database.
echo "Creating postgres database..."
diff --git a/scripts-dev/release.py b/scripts-dev/release.py
index ab2d860a..4e1f99fe 100755
--- a/scripts-dev/release.py
+++ b/scripts-dev/release.py
@@ -35,6 +35,19 @@ from github import Github
from packaging import version
+def run_until_successful(command, *args, **kwargs):
+ while True:
+ completed_process = subprocess.run(command, *args, **kwargs)
+ exit_code = completed_process.returncode
+ if exit_code == 0:
+ # successful, so nothing more to do here.
+ return completed_process
+
+ print(f"The command {command!r} failed with exit code {exit_code}.")
+ print("Please try to correct the failure and then re-run.")
+ click.confirm("Try again?", abort=True)
+
+
@click.group()
def cli():
"""An interactive script to walk through the parts of creating a release.
@@ -197,7 +210,7 @@ def prepare():
f.write(parsed_synapse_ast.dumps())
# Generate changelogs
- subprocess.run("python3 -m towncrier", shell=True)
+ run_until_successful("python3 -m towncrier", shell=True)
# Generate debian changelogs
if parsed_new_version.pre is not None:
@@ -209,11 +222,11 @@ def prepare():
else:
debian_version = new_version
- subprocess.run(
+ run_until_successful(
f'dch -M -v {debian_version} "New synapse release {debian_version}."',
shell=True,
)
- subprocess.run('dch -M -r -D stable ""', shell=True)
+ run_until_successful('dch -M -r -D stable ""', shell=True)
# Show the user the changes and ask if they want to edit the change log.
repo.git.add("-u")
@@ -224,7 +237,7 @@ def prepare():
# Commit the changes.
repo.git.add("-u")
- repo.git.commit(f"-m {new_version}")
+ repo.git.commit("-m", new_version)
# We give the option to bail here in case the user wants to make sure things
# are OK before pushing.
@@ -239,6 +252,8 @@ def prepare():
# Otherwise, push and open the changelog in the browser.
repo.git.push("-u", repo.remote().name, repo.active_branch.name)
+ print("Opening the changelog in your browser...")
+ print("Please ask others to give it a check.")
click.launch(
f"https://github.com/matrix-org/synapse/blob/{repo.active_branch.name}/CHANGES.md"
)
@@ -290,7 +305,19 @@ def tag(gh_token: Optional[str]):
# If no token was given, we bail here
if not gh_token:
+ print("Launching the GitHub release page in your browser.")
+ print("Please correct the title and create a draft.")
+ if current_version.is_prerelease:
+ print("As this is an RC, remember to mark it as a pre-release!")
+ print("(by the way, this step can be automated by passing --gh-token,")
+ print("or one of the GH_TOKEN or GITHUB_TOKEN env vars.)")
click.launch(f"https://github.com/matrix-org/synapse/releases/edit/{tag_name}")
+
+ print("Once done, you need to wait for the release assets to build.")
+ if click.confirm("Launch the release assets actions page?", default=True):
+ click.launch(
+ f"https://github.com/matrix-org/synapse/actions?query=branch%3A{tag_name}"
+ )
return
# Create a new draft release
@@ -305,6 +332,7 @@ def tag(gh_token: Optional[str]):
)
# Open the release and the actions where we are building the assets.
+ print("Launching the release page and the actions page.")
click.launch(release.html_url)
click.launch(
f"https://github.com/matrix-org/synapse/actions?query=branch%3A{tag_name}"
diff --git a/scripts/synapse_port_db b/scripts/synapse_port_db
index fa6ac6d9..349866eb 100755
--- a/scripts/synapse_port_db
+++ b/scripts/synapse_port_db
@@ -215,7 +215,7 @@ class MockHomeserver:
def __init__(self, config):
self.clock = Clock(reactor)
self.config = config
- self.hostname = config.server_name
+ self.hostname = config.server.server_name
self.version_string = "Synapse/" + get_version_string(synapse)
def get_clock(self):
@@ -583,7 +583,7 @@ class Porter(object):
return
self.postgres_store = self.build_db_store(
- self.hs_config.get_single_database()
+ self.hs_config.database.get_single_database()
)
await self.run_background_updates_on_postgres()
@@ -1069,7 +1069,7 @@ class CursesProgress(Progress):
self.stdscr.addstr(0, 0, status, curses.A_BOLD)
- max_len = max([len(t) for t in self.tables.keys()])
+ max_len = max(len(t) for t in self.tables.keys())
left_margin = 5
middle_space = 1
diff --git a/scripts-dev/update_database b/scripts/update_synapse_database
index 87f709b6..6c088bad 100755
--- a/scripts-dev/update_database
+++ b/scripts/update_synapse_database
@@ -36,16 +36,35 @@ class MockHomeserver(HomeServer):
def __init__(self, config, **kwargs):
super(MockHomeserver, self).__init__(
- config.server_name, reactor=reactor, config=config, **kwargs
+ config.server.server_name, reactor=reactor, config=config, **kwargs
)
self.version_string = "Synapse/" + get_version_string(synapse)
-if __name__ == "__main__":
+def run_background_updates(hs):
+ store = hs.get_datastore()
+
+ async def run_background_updates():
+ await store.db_pool.updates.run_background_updates(sleep=False)
+ # Stop the reactor to exit the script once every background update is run.
+ reactor.stop()
+
+ def run():
+ # Apply all background updates on the database.
+ defer.ensureDeferred(
+ run_as_background_process("background_updates", run_background_updates)
+ )
+
+ reactor.callWhenRunning(run)
+
+ reactor.run()
+
+
+def main():
parser = argparse.ArgumentParser(
description=(
- "Updates a synapse database to the latest schema and runs background updates"
+ "Updates a synapse database to the latest schema and optionally runs background updates"
" on it."
)
)
@@ -54,7 +73,13 @@ if __name__ == "__main__":
"--database-config",
type=argparse.FileType("r"),
required=True,
- help="A database config file for either a SQLite3 database or a PostgreSQL one.",
+ help="Synapse configuration file, giving the details of the database to be updated",
+ )
+ parser.add_argument(
+ "--run-background-updates",
+ action="store_true",
+ required=False,
+ help="run background updates after upgrading the database schema",
)
args = parser.parse_args()
@@ -82,19 +107,10 @@ if __name__ == "__main__":
# Setup instantiates the store within the homeserver object and updates the
# DB.
hs.setup()
- store = hs.get_datastore()
- async def run_background_updates():
- await store.db_pool.updates.run_background_updates(sleep=False)
- # Stop the reactor to exit the script once every background update is run.
- reactor.stop()
+ if args.run_background_updates:
+ run_background_updates(hs)
- def run():
- # Apply all background updates on the database.
- defer.ensureDeferred(
- run_as_background_process("background_updates", run_background_updates)
- )
- reactor.callWhenRunning(run)
-
- reactor.run()
+if __name__ == "__main__":
+ main()
diff --git a/setup.py b/setup.py
index c4785635..220084a4 100755
--- a/setup.py
+++ b/setup.py
@@ -103,17 +103,17 @@ CONDITIONAL_REQUIREMENTS["lint"] = [
"flake8",
]
-CONDITIONAL_REQUIREMENTS["dev"] = CONDITIONAL_REQUIREMENTS["lint"] + [
- # The following are used by the release script
- "click==7.1.2",
- "redbaron==0.9.2",
- "GitPython==3.1.14",
- "commonmark==0.9.1",
- "pygithub==1.55",
+CONDITIONAL_REQUIREMENTS["mypy"] = [
+ "mypy==0.910",
+ "mypy-zope==0.3.2",
+ "types-bleach>=4.1.0",
+ "types-jsonschema>=3.2.0",
+ "types-Pillow>=8.3.4",
+ "types-pyOpenSSL>=20.0.7",
+ "types-PyYAML>=5.4.10",
+ "types-setuptools>=57.4.0",
]
-CONDITIONAL_REQUIREMENTS["mypy"] = ["mypy==0.812", "mypy-zope==0.2.13"]
-
# Dependencies which are exclusively required by unit test code. This is
# NOT a list of all modules that are necessary to run the unit tests.
# Tests assume that all optional dependencies are installed.
@@ -121,6 +121,20 @@ CONDITIONAL_REQUIREMENTS["mypy"] = ["mypy==0.812", "mypy-zope==0.2.13"]
# parameterized_class decorator was introduced in parameterized 0.7.0
CONDITIONAL_REQUIREMENTS["test"] = ["parameterized>=0.7.0"]
+CONDITIONAL_REQUIREMENTS["dev"] = (
+ CONDITIONAL_REQUIREMENTS["lint"]
+ + CONDITIONAL_REQUIREMENTS["mypy"]
+ + CONDITIONAL_REQUIREMENTS["test"]
+ + [
+ # The following are used by the release script
+ "click==7.1.2",
+ "redbaron==0.9.2",
+ "GitPython==3.1.14",
+ "commonmark==0.9.1",
+ "pygithub==1.55",
+ ]
+)
+
setup(
name="matrix-synapse",
version=version,
diff --git a/synapse/__init__.py b/synapse/__init__.py
index b8979c36..97452f34 100644
--- a/synapse/__init__.py
+++ b/synapse/__init__.py
@@ -47,7 +47,7 @@ try:
except ImportError:
pass
-__version__ = "1.44.0"
+__version__ = "1.45.0"
if bool(os.environ.get("SYNAPSE_TEST_PATCH_LOG_CONTEXTS", False)):
# We import here so that we don't have to install a bunch of deps when
diff --git a/synapse/api/filtering.py b/synapse/api/filtering.py
index ad1ff6a9..20e91a11 100644
--- a/synapse/api/filtering.py
+++ b/synapse/api/filtering.py
@@ -15,7 +15,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import json
-from typing import List
+from typing import (
+ TYPE_CHECKING,
+ Awaitable,
+ Container,
+ Iterable,
+ List,
+ Optional,
+ Set,
+ TypeVar,
+ Union,
+)
import jsonschema
from jsonschema import FormatChecker
@@ -23,7 +33,11 @@ from jsonschema import FormatChecker
from synapse.api.constants import EventContentFields
from synapse.api.errors import SynapseError
from synapse.api.presence import UserPresenceState
-from synapse.types import RoomID, UserID
+from synapse.events import EventBase
+from synapse.types import JsonDict, RoomID, UserID
+
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
FILTER_SCHEMA = {
"additionalProperties": False,
@@ -120,25 +134,29 @@ USER_FILTER_SCHEMA = {
@FormatChecker.cls_checks("matrix_room_id")
-def matrix_room_id_validator(room_id_str):
+def matrix_room_id_validator(room_id_str: str) -> RoomID:
return RoomID.from_string(room_id_str)
@FormatChecker.cls_checks("matrix_user_id")
-def matrix_user_id_validator(user_id_str):
+def matrix_user_id_validator(user_id_str: str) -> UserID:
return UserID.from_string(user_id_str)
class Filtering:
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.store = hs.get_datastore()
- async def get_user_filter(self, user_localpart, filter_id):
+ async def get_user_filter(
+ self, user_localpart: str, filter_id: Union[int, str]
+ ) -> "FilterCollection":
result = await self.store.get_user_filter(user_localpart, filter_id)
return FilterCollection(result)
- def add_user_filter(self, user_localpart, user_filter):
+ def add_user_filter(
+ self, user_localpart: str, user_filter: JsonDict
+ ) -> Awaitable[int]:
self.check_valid_filter(user_filter)
return self.store.add_user_filter(user_localpart, user_filter)
@@ -146,13 +164,13 @@ class Filtering:
# replace_user_filter at some point? There's no REST API specified for
# them however
- def check_valid_filter(self, user_filter_json):
+ def check_valid_filter(self, user_filter_json: JsonDict) -> None:
"""Check if the provided filter is valid.
This inspects all definitions contained within the filter.
Args:
- user_filter_json(dict): The filter
+ user_filter_json: The filter
Raises:
SynapseError: If the filter is not valid.
"""
@@ -167,8 +185,12 @@ class Filtering:
raise SynapseError(400, str(e))
+# Filters work across events, presence EDUs, and account data.
+FilterEvent = TypeVar("FilterEvent", EventBase, UserPresenceState, JsonDict)
+
+
class FilterCollection:
- def __init__(self, filter_json):
+ def __init__(self, filter_json: JsonDict):
self._filter_json = filter_json
room_filter_json = self._filter_json.get("room", {})
@@ -188,25 +210,25 @@ class FilterCollection:
self.event_fields = filter_json.get("event_fields", [])
self.event_format = filter_json.get("event_format", "client")
- def __repr__(self):
+ def __repr__(self) -> str:
return "<FilterCollection %s>" % (json.dumps(self._filter_json),)
- def get_filter_json(self):
+ def get_filter_json(self) -> JsonDict:
return self._filter_json
- def timeline_limit(self):
+ def timeline_limit(self) -> int:
return self._room_timeline_filter.limit()
- def presence_limit(self):
+ def presence_limit(self) -> int:
return self._presence_filter.limit()
- def ephemeral_limit(self):
+ def ephemeral_limit(self) -> int:
return self._room_ephemeral_filter.limit()
- def lazy_load_members(self):
+ def lazy_load_members(self) -> bool:
return self._room_state_filter.lazy_load_members()
- def include_redundant_members(self):
+ def include_redundant_members(self) -> bool:
return self._room_state_filter.include_redundant_members()
def filter_presence(self, events):
@@ -218,29 +240,31 @@ class FilterCollection:
def filter_room_state(self, events):
return self._room_state_filter.filter(self._room_filter.filter(events))
- def filter_room_timeline(self, events):
+ def filter_room_timeline(self, events: Iterable[FilterEvent]) -> List[FilterEvent]:
return self._room_timeline_filter.filter(self._room_filter.filter(events))
- def filter_room_ephemeral(self, events):
+ def filter_room_ephemeral(self, events: Iterable[FilterEvent]) -> List[FilterEvent]:
return self._room_ephemeral_filter.filter(self._room_filter.filter(events))
- def filter_room_account_data(self, events):
+ def filter_room_account_data(
+ self, events: Iterable[FilterEvent]
+ ) -> List[FilterEvent]:
return self._room_account_data.filter(self._room_filter.filter(events))
- def blocks_all_presence(self):
+ def blocks_all_presence(self) -> bool:
return (
self._presence_filter.filters_all_types()
or self._presence_filter.filters_all_senders()
)
- def blocks_all_room_ephemeral(self):
+ def blocks_all_room_ephemeral(self) -> bool:
return (
self._room_ephemeral_filter.filters_all_types()
or self._room_ephemeral_filter.filters_all_senders()
or self._room_ephemeral_filter.filters_all_rooms()
)
- def blocks_all_room_timeline(self):
+ def blocks_all_room_timeline(self) -> bool:
return (
self._room_timeline_filter.filters_all_types()
or self._room_timeline_filter.filters_all_senders()
@@ -249,7 +273,7 @@ class FilterCollection:
class Filter:
- def __init__(self, filter_json):
+ def __init__(self, filter_json: JsonDict):
self.filter_json = filter_json
self.types = self.filter_json.get("types", None)
@@ -266,20 +290,20 @@ class Filter:
self.labels = self.filter_json.get("org.matrix.labels", None)
self.not_labels = self.filter_json.get("org.matrix.not_labels", [])
- def filters_all_types(self):
+ def filters_all_types(self) -> bool:
return "*" in self.not_types
- def filters_all_senders(self):
+ def filters_all_senders(self) -> bool:
return "*" in self.not_senders
- def filters_all_rooms(self):
+ def filters_all_rooms(self) -> bool:
return "*" in self.not_rooms
- def check(self, event):
+ def check(self, event: FilterEvent) -> bool:
"""Checks whether the filter matches the given event.
Returns:
- bool: True if the event matches
+ True if the event matches
"""
# We usually get the full "events" as dictionaries coming through,
# except for presence which actually gets passed around as its own
@@ -305,18 +329,25 @@ class Filter:
room_id = event.get("room_id", None)
ev_type = event.get("type", None)
- content = event.get("content", {})
+ content = event.get("content") or {}
# check if there is a string url field in the content for filtering purposes
contains_url = isinstance(content.get("url"), str)
labels = content.get(EventContentFields.LABELS, [])
return self.check_fields(room_id, sender, ev_type, labels, contains_url)
- def check_fields(self, room_id, sender, event_type, labels, contains_url):
+ def check_fields(
+ self,
+ room_id: Optional[str],
+ sender: Optional[str],
+ event_type: Optional[str],
+ labels: Container[str],
+ contains_url: bool,
+ ) -> bool:
"""Checks whether the filter matches the given event fields.
Returns:
- bool: True if the event fields match
+ True if the event fields match
"""
literal_keys = {
"rooms": lambda v: room_id == v,
@@ -343,14 +374,14 @@ class Filter:
return True
- def filter_rooms(self, room_ids):
+ def filter_rooms(self, room_ids: Iterable[str]) -> Set[str]:
"""Apply the 'rooms' filter to a given list of rooms.
Args:
- room_ids (list): A list of room_ids.
+ room_ids: A list of room_ids.
Returns:
- list: A list of room_ids that match the filter
+ A list of room_ids that match the filter
"""
room_ids = set(room_ids)
@@ -363,23 +394,23 @@ class Filter:
return room_ids
- def filter(self, events):
+ def filter(self, events: Iterable[FilterEvent]) -> List[FilterEvent]:
return list(filter(self.check, events))
- def limit(self):
+ def limit(self) -> int:
return self.filter_json.get("limit", 10)
- def lazy_load_members(self):
+ def lazy_load_members(self) -> bool:
return self.filter_json.get("lazy_load_members", False)
- def include_redundant_members(self):
+ def include_redundant_members(self) -> bool:
return self.filter_json.get("include_redundant_members", False)
- def with_room_ids(self, room_ids):
+ def with_room_ids(self, room_ids: Iterable[str]) -> "Filter":
"""Returns a new filter with the given room IDs appended.
Args:
- room_ids (iterable[unicode]): The room_ids to add
+ room_ids: The room_ids to add
Returns:
filter: A new filter including the given rooms and the old
@@ -390,8 +421,8 @@ class Filter:
return newFilter
-def _matches_wildcard(actual_value, filter_value):
- if filter_value.endswith("*"):
+def _matches_wildcard(actual_value: Optional[str], filter_value: str) -> bool:
+ if filter_value.endswith("*") and isinstance(actual_value, str):
type_prefix = filter_value[:-1]
return actual_value.startswith(type_prefix)
else:
diff --git a/synapse/api/ratelimiting.py b/synapse/api/ratelimiting.py
index cbdd7402..e8964097 100644
--- a/synapse/api/ratelimiting.py
+++ b/synapse/api/ratelimiting.py
@@ -17,6 +17,7 @@ from collections import OrderedDict
from typing import Hashable, Optional, Tuple
from synapse.api.errors import LimitExceededError
+from synapse.config.ratelimiting import RateLimitConfig
from synapse.storage.databases.main import DataStore
from synapse.types import Requester
from synapse.util import Clock
@@ -233,3 +234,88 @@ class Ratelimiter:
raise LimitExceededError(
retry_after_ms=int(1000 * (time_allowed - time_now_s))
)
+
+
+class RequestRatelimiter:
+ def __init__(
+ self,
+ store: DataStore,
+ clock: Clock,
+ rc_message: RateLimitConfig,
+ rc_admin_redaction: Optional[RateLimitConfig],
+ ):
+ self.store = store
+ self.clock = clock
+
+ # The rate_hz and burst_count are overridden on a per-user basis
+ self.request_ratelimiter = Ratelimiter(
+ store=self.store, clock=self.clock, rate_hz=0, burst_count=0
+ )
+ self._rc_message = rc_message
+
+ # Check whether ratelimiting room admin message redaction is enabled
+ # by the presence of rate limits in the config
+ if rc_admin_redaction:
+ self.admin_redaction_ratelimiter: Optional[Ratelimiter] = Ratelimiter(
+ store=self.store,
+ clock=self.clock,
+ rate_hz=rc_admin_redaction.per_second,
+ burst_count=rc_admin_redaction.burst_count,
+ )
+ else:
+ self.admin_redaction_ratelimiter = None
+
+ async def ratelimit(
+ self,
+ requester: Requester,
+ update: bool = True,
+ is_admin_redaction: bool = False,
+ ) -> None:
+ """Ratelimits requests.
+
+ Args:
+ requester
+ update: Whether to record that a request is being processed.
+ Set to False when doing multiple checks for one request (e.g.
+ to check up front if we would reject the request), and set to
+ True for the last call for a given request.
+ is_admin_redaction: Whether this is a room admin/moderator
+ redacting an event. If so then we may apply different
+ ratelimits depending on config.
+
+ Raises:
+ LimitExceededError if the request should be ratelimited
+ """
+ user_id = requester.user.to_string()
+
+ # The AS user itself is never rate limited.
+ app_service = self.store.get_app_service_by_user_id(user_id)
+ if app_service is not None:
+ return # do not ratelimit app service senders
+
+ messages_per_second = self._rc_message.per_second
+ burst_count = self._rc_message.burst_count
+
+ # Check if there is a per user override in the DB.
+ override = await self.store.get_ratelimit_for_user(user_id)
+ if override:
+ # If overridden with a null Hz then ratelimiting has been entirely
+ # disabled for the user
+ if not override.messages_per_second:
+ return
+
+ messages_per_second = override.messages_per_second
+ burst_count = override.burst_count
+
+ if is_admin_redaction and self.admin_redaction_ratelimiter:
+ # If we have separate config for admin redactions, use a separate
+ # ratelimiter as to not have user_ids clash
+ await self.admin_redaction_ratelimiter.ratelimit(requester, update=update)
+ else:
+ # Override rate and burst count per-user
+ await self.request_ratelimiter.ratelimit(
+ requester,
+ rate_hz=messages_per_second,
+ burst_count=burst_count,
+ update=update,
+ )
diff --git a/synapse/app/_base.py b/synapse/app/_base.py
index 548f6dcd..4a204a58 100644
--- a/synapse/app/_base.py
+++ b/synapse/app/_base.py
@@ -86,11 +86,11 @@ def start_worker_reactor(appname, config, run_command=reactor.run):
start_reactor(
appname,
- soft_file_limit=config.soft_file_limit,
- gc_thresholds=config.gc_thresholds,
+ soft_file_limit=config.server.soft_file_limit,
+ gc_thresholds=config.server.gc_thresholds,
pid_file=config.worker.worker_pid_file,
daemonize=config.worker.worker_daemonize,
- print_pidfile=config.print_pidfile,
+ print_pidfile=config.server.print_pidfile,
logger=logger,
run_command=run_command,
)
@@ -298,10 +298,10 @@ def refresh_certificate(hs):
Refresh the TLS certificates that Synapse is using by re-reading them from
disk and updating the TLS context factories to use them.
"""
- if not hs.config.has_tls_listener():
+ if not hs.config.server.has_tls_listener():
return
- hs.config.read_certificate_from_disk()
+ hs.config.tls.read_certificate_from_disk()
hs.tls_server_context_factory = context_factory.ServerContextFactory(hs.config)
if hs._listening_services:
diff --git a/synapse/app/admin_cmd.py b/synapse/app/admin_cmd.py
index f2c5b752..13d20af4 100644
--- a/synapse/app/admin_cmd.py
+++ b/synapse/app/admin_cmd.py
@@ -195,14 +195,14 @@ def start(config_options):
config.logging.no_redirect_stdio = True
# Explicitly disable background processes
- config.update_user_directory = False
+ config.server.update_user_directory = False
config.worker.run_background_tasks = False
- config.start_pushers = False
+ config.worker.start_pushers = False
config.pusher_shard_config.instances = []
- config.send_federation = False
+ config.worker.send_federation = False
config.federation_shard_config.instances = []
- synapse.events.USE_FROZEN_DICTS = config.use_frozen_dicts
+ synapse.events.USE_FROZEN_DICTS = config.server.use_frozen_dicts
ss = AdminCmdServer(
config.server.server_name,
diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py
index 3036e1b4..7489f31d 100644
--- a/synapse/app/generic_worker.py
+++ b/synapse/app/generic_worker.py
@@ -462,7 +462,7 @@ def start(config_options):
# For other worker types we force this to off.
config.server.update_user_directory = False
- synapse.events.USE_FROZEN_DICTS = config.use_frozen_dicts
+ synapse.events.USE_FROZEN_DICTS = config.server.use_frozen_dicts
synapse.util.caches.TRACK_MEMORY_USAGE = config.caches.track_memory_usage
if config.server.gc_seconds:
diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py
index 205831dc..422f03cc 100644
--- a/synapse/app/homeserver.py
+++ b/synapse/app/homeserver.py
@@ -234,7 +234,7 @@ class SynapseHomeServer(HomeServer):
)
if name in ["media", "federation", "client"]:
- if self.config.media.enable_media_repo:
+ if self.config.server.enable_media_repo:
media_repo = self.get_media_repository_resource()
resources.update(
{MEDIA_PREFIX: media_repo, LEGACY_MEDIA_PREFIX: media_repo}
@@ -248,7 +248,7 @@ class SynapseHomeServer(HomeServer):
resources[SERVER_KEY_V2_PREFIX] = KeyApiV2Resource(self)
if name == "webclient":
- webclient_loc = self.config.web_client_location
+ webclient_loc = self.config.server.web_client_location
if webclient_loc is None:
logger.warning(
@@ -343,7 +343,7 @@ def setup(config_options):
# generating config files and shouldn't try to continue.
sys.exit(0)
- events.USE_FROZEN_DICTS = config.use_frozen_dicts
+ events.USE_FROZEN_DICTS = config.server.use_frozen_dicts
synapse.util.caches.TRACK_MEMORY_USAGE = config.caches.track_memory_usage
if config.server.gc_seconds:
@@ -439,11 +439,11 @@ def run(hs):
_base.start_reactor(
"synapse-homeserver",
- soft_file_limit=hs.config.soft_file_limit,
- gc_thresholds=hs.config.gc_thresholds,
- pid_file=hs.config.pid_file,
- daemonize=hs.config.daemonize,
- print_pidfile=hs.config.print_pidfile,
+ soft_file_limit=hs.config.server.soft_file_limit,
+ gc_thresholds=hs.config.server.gc_thresholds,
+ pid_file=hs.config.server.pid_file,
+ daemonize=hs.config.server.daemonize,
+ print_pidfile=hs.config.server.print_pidfile,
logger=logger,
)
diff --git a/synapse/app/phone_stats_home.py b/synapse/app/phone_stats_home.py
index 49e7a45e..fcd01e83 100644
--- a/synapse/app/phone_stats_home.py
+++ b/synapse/app/phone_stats_home.py
@@ -74,7 +74,7 @@ async def phone_stats_home(hs, stats, stats_process=_stats_process):
store = hs.get_datastore()
stats["homeserver"] = hs.config.server.server_name
- stats["server_context"] = hs.config.server_context
+ stats["server_context"] = hs.config.server.server_context
stats["timestamp"] = now
stats["uptime_seconds"] = uptime
version = sys.version_info
@@ -171,7 +171,7 @@ def start_phone_stats_home(hs):
current_mau_count_by_service = {}
reserved_users = ()
store = hs.get_datastore()
- if hs.config.limit_usage_by_mau or hs.config.mau_stats_only:
+ if hs.config.server.limit_usage_by_mau or hs.config.server.mau_stats_only:
current_mau_count = await store.get_monthly_active_count()
current_mau_count_by_service = (
await store.get_monthly_active_count_by_service()
@@ -183,9 +183,9 @@ def start_phone_stats_home(hs):
current_mau_by_service_gauge.labels(app_service).set(float(count))
registered_reserved_users_mau_gauge.set(float(len(reserved_users)))
- max_mau_gauge.set(float(hs.config.max_mau_value))
+ max_mau_gauge.set(float(hs.config.server.max_mau_value))
- if hs.config.limit_usage_by_mau or hs.config.mau_stats_only:
+ if hs.config.server.limit_usage_by_mau or hs.config.server.mau_stats_only:
generate_monthly_active_users()
clock.looping_call(generate_monthly_active_users, 5 * 60 * 1000)
# End of monthly active user settings
diff --git a/synapse/config/_base.py b/synapse/config/_base.py
index d974a1a2..7c4428a1 100644
--- a/synapse/config/_base.py
+++ b/synapse/config/_base.py
@@ -118,21 +118,6 @@ class Config:
"synapse", "res/templates"
)
- def __getattr__(self, item: str) -> Any:
- """
- Try and fetch a configuration option that does not exist on this class.
-
- This is so that existing configs that rely on `self.value`, where value
- is actually from a different config section, continue to work.
- """
- if item in ["generate_config_section", "read_config"]:
- raise AttributeError(item)
-
- if self.root is None:
- raise AttributeError(item)
- else:
- return self.root._get_unclassed_config(self.section, item)
-
@staticmethod
def parse_size(value):
if isinstance(value, int):
@@ -289,7 +274,9 @@ class Config:
env.filters.update(
{
"format_ts": _format_ts_filter,
- "mxc_to_http": _create_mxc_to_http_filter(self.public_baseurl),
+ "mxc_to_http": _create_mxc_to_http_filter(
+ self.root.server.public_baseurl
+ ),
}
)
@@ -311,8 +298,6 @@ class RootConfig:
config_classes = []
def __init__(self):
- self._configs = OrderedDict()
-
for config_class in self.config_classes:
if config_class.section is None:
raise ValueError("%r requires a section name" % (config_class,))
@@ -321,42 +306,7 @@ class RootConfig:
conf = config_class(self)
except Exception as e:
raise Exception("Failed making %s: %r" % (config_class.section, e))
- self._configs[config_class.section] = conf
-
- def __getattr__(self, item: str) -> Any:
- """
- Redirect lookups on this object either to config objects, or values on
- config objects, so that `config.tls.blah` works, as well as legacy uses
- of things like `config.server_name`. It will first look up the config
- section name, and then values on those config classes.
- """
- if item in self._configs.keys():
- return self._configs[item]
-
- return self._get_unclassed_config(None, item)
-
- def _get_unclassed_config(self, asking_section: Optional[str], item: str):
- """
- Fetch a config value from one of the instantiated config classes that
- has not been fetched directly.
-
- Args:
- asking_section: If this check is coming from a Config child, which
- one? This section will not be asked if it has the value.
- item: The configuration value key.
-
- Raises:
- AttributeError if no config classes have the config key. The body
- will contain what sections were checked.
- """
- for key, val in self._configs.items():
- if key == asking_section:
- continue
-
- if item in dir(val):
- return getattr(val, item)
-
- raise AttributeError(item, "not found in %s" % (list(self._configs.keys()),))
+ setattr(self, config_class.section, conf)
def invoke_all(self, func_name: str, *args, **kwargs) -> MutableMapping[str, Any]:
"""
@@ -373,9 +323,11 @@ class RootConfig:
"""
res = OrderedDict()
- for name, config in self._configs.items():
+ for config_class in self.config_classes:
+ config = getattr(self, config_class.section)
+
if hasattr(config, func_name):
- res[name] = getattr(config, func_name)(*args, **kwargs)
+ res[config_class.section] = getattr(config, func_name)(*args, **kwargs)
return res
diff --git a/synapse/config/account_validity.py b/synapse/config/account_validity.py
index ffaffc49..b56c2a24 100644
--- a/synapse/config/account_validity.py
+++ b/synapse/config/account_validity.py
@@ -76,7 +76,7 @@ class AccountValidityConfig(Config):
)
if self.account_validity_renew_by_email_enabled:
- if not self.public_baseurl:
+ if not self.root.server.public_baseurl:
raise ConfigError("Can't send renewal emails without 'public_baseurl'")
# Load account validity templates.
diff --git a/synapse/config/cas.py b/synapse/config/cas.py
index 901f4123..9b58ecf3 100644
--- a/synapse/config/cas.py
+++ b/synapse/config/cas.py
@@ -37,7 +37,7 @@ class CasConfig(Config):
# The public baseurl is required because it is used by the redirect
# template.
- public_baseurl = self.public_baseurl
+ public_baseurl = self.root.server.public_baseurl
if not public_baseurl:
raise ConfigError("cas_config requires a public_baseurl to be set")
diff --git a/synapse/config/emailconfig.py b/synapse/config/emailconfig.py
index 936abe61..8ff59aa2 100644
--- a/synapse/config/emailconfig.py
+++ b/synapse/config/emailconfig.py
@@ -19,7 +19,6 @@ import email.utils
import logging
import os
from enum import Enum
-from typing import Optional
import attr
@@ -135,7 +134,7 @@ class EmailConfig(Config):
# msisdn is currently always remote while Synapse does not support any method of
# sending SMS messages
ThreepidBehaviour.REMOTE
- if self.account_threepid_delegate_email
+ if self.root.registration.account_threepid_delegate_email
else ThreepidBehaviour.LOCAL
)
# Prior to Synapse v1.4.0, there was another option that defined whether Synapse would
@@ -144,7 +143,7 @@ class EmailConfig(Config):
# identity server in the process.
self.using_identity_server_from_trusted_list = False
if (
- not self.account_threepid_delegate_email
+ not self.root.registration.account_threepid_delegate_email
and config.get("trust_identity_server_for_password_resets", False) is True
):
# Use the first entry in self.trusted_third_party_id_servers instead
@@ -156,7 +155,7 @@ class EmailConfig(Config):
# trusted_third_party_id_servers does not contain a scheme whereas
# account_threepid_delegate_email is expected to. Presume https
- self.account_threepid_delegate_email: Optional[str] = (
+ self.root.registration.account_threepid_delegate_email = (
"https://" + first_trusted_identity_server
)
self.using_identity_server_from_trusted_list = True
@@ -335,7 +334,7 @@ class EmailConfig(Config):
"client_base_url", email_config.get("riot_base_url", None)
)
- if self.account_validity_renew_by_email_enabled:
+ if self.root.account_validity.account_validity_renew_by_email_enabled:
expiry_template_html = email_config.get(
"expiry_template_html", "notice_expiry.html"
)
diff --git a/synapse/config/key.py b/synapse/config/key.py
index 94a90630..015dbb8a 100644
--- a/synapse/config/key.py
+++ b/synapse/config/key.py
@@ -145,11 +145,13 @@ class KeyConfig(Config):
# list of TrustedKeyServer objects
self.key_servers = list(
- _parse_key_servers(key_servers, self.federation_verify_certificates)
+ _parse_key_servers(
+ key_servers, self.root.tls.federation_verify_certificates
+ )
)
self.macaroon_secret_key = config.get(
- "macaroon_secret_key", self.registration_shared_secret
+ "macaroon_secret_key", self.root.registration.registration_shared_secret
)
if not self.macaroon_secret_key:
diff --git a/synapse/config/oidc.py b/synapse/config/oidc.py
index 7e67fbad..10f57963 100644
--- a/synapse/config/oidc.py
+++ b/synapse/config/oidc.py
@@ -58,7 +58,7 @@ class OIDCConfig(Config):
"Multiple OIDC providers have the idp_id %r." % idp_id
)
- public_baseurl = self.public_baseurl
+ public_baseurl = self.root.server.public_baseurl
if public_baseurl is None:
raise ConfigError("oidc_config requires a public_baseurl to be set")
self.oidc_callback_url = public_baseurl + "_synapse/client/oidc/callback"
diff --git a/synapse/config/registration.py b/synapse/config/registration.py
index 7cffdacf..a3d2a38c 100644
--- a/synapse/config/registration.py
+++ b/synapse/config/registration.py
@@ -45,7 +45,10 @@ class RegistrationConfig(Config):
account_threepid_delegates = config.get("account_threepid_delegates") or {}
self.account_threepid_delegate_email = account_threepid_delegates.get("email")
self.account_threepid_delegate_msisdn = account_threepid_delegates.get("msisdn")
- if self.account_threepid_delegate_msisdn and not self.public_baseurl:
+ if (
+ self.account_threepid_delegate_msisdn
+ and not self.root.server.public_baseurl
+ ):
raise ConfigError(
"The configuration option `public_baseurl` is required if "
"`account_threepid_delegate.msisdn` is set, such that "
@@ -85,7 +88,7 @@ class RegistrationConfig(Config):
if mxid_localpart:
# Convert the localpart to a full mxid.
self.auto_join_user_id = UserID(
- mxid_localpart, self.server_name
+ mxid_localpart, self.root.server.server_name
).to_string()
if self.autocreate_auto_join_rooms:
diff --git a/synapse/config/repository.py b/synapse/config/repository.py
index 7481f3bf..69906a98 100644
--- a/synapse/config/repository.py
+++ b/synapse/config/repository.py
@@ -94,7 +94,7 @@ class ContentRepositoryConfig(Config):
# Only enable the media repo if either the media repo is enabled or the
# current worker app is the media repo.
if (
- self.enable_media_repo is False
+ self.root.server.enable_media_repo is False
and config.get("worker_app") != "synapse.app.media_repository"
):
self.can_load_media_repo = False
diff --git a/synapse/config/saml2.py b/synapse/config/saml2.py
index 05e98362..9c51b6a2 100644
--- a/synapse/config/saml2.py
+++ b/synapse/config/saml2.py
@@ -199,7 +199,7 @@ class SAML2Config(Config):
"""
import saml2
- public_baseurl = self.public_baseurl
+ public_baseurl = self.root.server.public_baseurl
if public_baseurl is None:
raise ConfigError("saml2_config requires a public_baseurl to be set")
diff --git a/synapse/config/server.py b/synapse/config/server.py
index ad8715da..818b8063 100644
--- a/synapse/config/server.py
+++ b/synapse/config/server.py
@@ -1,6 +1,4 @@
-# Copyright 2014-2016 OpenMarket Ltd
-# Copyright 2017-2018 New Vector Ltd
-# Copyright 2019 The Matrix.org Foundation C.I.C.
+# Copyright 2014-2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -19,7 +17,7 @@ import logging
import os.path
import re
from textwrap import indent
-from typing import Any, Dict, Iterable, List, Optional, Set, Tuple
+from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union
import attr
import yaml
@@ -184,49 +182,74 @@ KNOWN_RESOURCES = {
@attr.s(frozen=True)
class HttpResourceConfig:
- names = attr.ib(
- type=List[str],
+ names: List[str] = attr.ib(
factory=list,
validator=attr.validators.deep_iterable(attr.validators.in_(KNOWN_RESOURCES)), # type: ignore
)
- compress = attr.ib(
- type=bool,
+ compress: bool = attr.ib(
default=False,
validator=attr.validators.optional(attr.validators.instance_of(bool)), # type: ignore[arg-type]
)
-@attr.s(frozen=True)
+@attr.s(slots=True, frozen=True, auto_attribs=True)
class HttpListenerConfig:
"""Object describing the http-specific parts of the config of a listener"""
- x_forwarded = attr.ib(type=bool, default=False)
- resources = attr.ib(type=List[HttpResourceConfig], factory=list)
- additional_resources = attr.ib(type=Dict[str, dict], factory=dict)
- tag = attr.ib(type=str, default=None)
+ x_forwarded: bool = False
+ resources: List[HttpResourceConfig] = attr.ib(factory=list)
+ additional_resources: Dict[str, dict] = attr.ib(factory=dict)
+ tag: Optional[str] = None
-@attr.s(frozen=True)
+@attr.s(slots=True, frozen=True, auto_attribs=True)
class ListenerConfig:
"""Object describing the configuration of a single listener."""
- port = attr.ib(type=int, validator=attr.validators.instance_of(int))
- bind_addresses = attr.ib(type=List[str])
- type = attr.ib(type=str, validator=attr.validators.in_(KNOWN_LISTENER_TYPES))
- tls = attr.ib(type=bool, default=False)
+ port: int = attr.ib(validator=attr.validators.instance_of(int))
+ bind_addresses: List[str]
+ type: str = attr.ib(validator=attr.validators.in_(KNOWN_LISTENER_TYPES))
+ tls: bool = False
# http_options is only populated if type=http
- http_options = attr.ib(type=Optional[HttpListenerConfig], default=None)
+ http_options: Optional[HttpListenerConfig] = None
-@attr.s(frozen=True)
+@attr.s(slots=True, frozen=True, auto_attribs=True)
class ManholeConfig:
"""Object describing the configuration of the manhole"""
- username = attr.ib(type=str, validator=attr.validators.instance_of(str))
- password = attr.ib(type=str, validator=attr.validators.instance_of(str))
- priv_key = attr.ib(type=Optional[Key])
- pub_key = attr.ib(type=Optional[Key])
+ username: str = attr.ib(validator=attr.validators.instance_of(str))
+ password: str = attr.ib(validator=attr.validators.instance_of(str))
+ priv_key: Optional[Key]
+ pub_key: Optional[Key]
+
+
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class RetentionConfig:
+ """Object describing the configuration of the manhole"""
+
+ interval: int
+ shortest_max_lifetime: Optional[int]
+ longest_max_lifetime: Optional[int]
+
+
+@attr.s(frozen=True)
+class LimitRemoteRoomsConfig:
+ enabled: bool = attr.ib(validator=attr.validators.instance_of(bool), default=False)
+ complexity: Union[float, int] = attr.ib(
+ validator=attr.validators.instance_of(
+ (float, int) # type: ignore[arg-type] # noqa
+ ),
+ default=1.0,
+ )
+ complexity_error: str = attr.ib(
+ validator=attr.validators.instance_of(str),
+ default=ROOM_COMPLEXITY_TOO_GREAT,
+ )
+ admins_can_join: bool = attr.ib(
+ validator=attr.validators.instance_of(bool), default=False
+ )
class ServerConfig(Config):
@@ -519,7 +542,7 @@ class ServerConfig(Config):
" greater than 'allowed_lifetime_max'"
)
- self.retention_purge_jobs: List[Dict[str, Optional[int]]] = []
+ self.retention_purge_jobs: List[RetentionConfig] = []
for purge_job_config in retention_config.get("purge_jobs", []):
interval_config = purge_job_config.get("interval")
@@ -553,20 +576,12 @@ class ServerConfig(Config):
)
self.retention_purge_jobs.append(
- {
- "interval": interval,
- "shortest_max_lifetime": shortest_max_lifetime,
- "longest_max_lifetime": longest_max_lifetime,
- }
+ RetentionConfig(interval, shortest_max_lifetime, longest_max_lifetime)
)
if not self.retention_purge_jobs:
self.retention_purge_jobs = [
- {
- "interval": self.parse_duration("1d"),
- "shortest_max_lifetime": None,
- "longest_max_lifetime": None,
- }
+ RetentionConfig(self.parse_duration("1d"), None, None)
]
self.listeners = [parse_listener_def(x) for x in config.get("listeners", [])]
@@ -591,25 +606,6 @@ class ServerConfig(Config):
self.gc_thresholds = read_gc_thresholds(config.get("gc_thresholds", None))
self.gc_seconds = self.read_gc_intervals(config.get("gc_min_interval", None))
- @attr.s
- class LimitRemoteRoomsConfig:
- enabled = attr.ib(
- validator=attr.validators.instance_of(bool), default=False
- )
- complexity = attr.ib(
- validator=attr.validators.instance_of(
- (float, int) # type: ignore[arg-type] # noqa
- ),
- default=1.0,
- )
- complexity_error = attr.ib(
- validator=attr.validators.instance_of(str),
- default=ROOM_COMPLEXITY_TOO_GREAT,
- )
- admins_can_join = attr.ib(
- validator=attr.validators.instance_of(bool), default=False
- )
-
self.limit_remote_rooms = LimitRemoteRoomsConfig(
**(config.get("limit_remote_rooms") or {})
)
diff --git a/synapse/config/server_notices.py b/synapse/config/server_notices.py
index 48bf3241..bde4e879 100644
--- a/synapse/config/server_notices.py
+++ b/synapse/config/server_notices.py
@@ -73,7 +73,9 @@ class ServerNoticesConfig(Config):
return
mxid_localpart = c["system_mxid_localpart"]
- self.server_notices_mxid = UserID(mxid_localpart, self.server_name).to_string()
+ self.server_notices_mxid = UserID(
+ mxid_localpart, self.root.server.server_name
+ ).to_string()
self.server_notices_mxid_display_name = c.get("system_mxid_display_name", None)
self.server_notices_mxid_avatar_url = c.get("system_mxid_avatar_url", None)
# todo: i18n
diff --git a/synapse/config/sso.py b/synapse/config/sso.py
index 524a7ff3..11a9b76a 100644
--- a/synapse/config/sso.py
+++ b/synapse/config/sso.py
@@ -103,8 +103,10 @@ class SSOConfig(Config):
# the client's.
# public_baseurl is an optional setting, so we only add the fallback's URL to the
# list if it's provided (because we can't figure out what that URL is otherwise).
- if self.public_baseurl:
- login_fallback_url = self.public_baseurl + "_matrix/static/client/login"
+ if self.root.server.public_baseurl:
+ login_fallback_url = (
+ self.root.server.public_baseurl + "_matrix/static/client/login"
+ )
self.sso_client_whitelist.append(login_fallback_url)
def generate_config_section(self, **kwargs):
diff --git a/synapse/config/tls.py b/synapse/config/tls.py
index 5679f05e..6227434b 100644
--- a/synapse/config/tls.py
+++ b/synapse/config/tls.py
@@ -172,9 +172,12 @@ class TlsConfig(Config):
)
# YYYYMMDDhhmmssZ -- in UTC
- expires_on = datetime.strptime(
- tls_certificate.get_notAfter().decode("ascii"), "%Y%m%d%H%M%SZ"
- )
+ expiry_data = tls_certificate.get_notAfter()
+ if expiry_data is None:
+ raise ValueError(
+ "TLS Certificate has no expiry date, and this is not permitted"
+ )
+ expires_on = datetime.strptime(expiry_data.decode("ascii"), "%Y%m%d%H%M%SZ")
now = datetime.utcnow()
days_remaining = (expires_on - now).days
return days_remaining
diff --git a/synapse/event_auth.py b/synapse/event_auth.py
index 65040283..ca0293a3 100644
--- a/synapse/event_auth.py
+++ b/synapse/event_auth.py
@@ -41,42 +41,112 @@ from synapse.types import StateMap, UserID, get_domain_from_id
logger = logging.getLogger(__name__)
-def check(
- room_version_obj: RoomVersion,
- event: EventBase,
- auth_events: StateMap[EventBase],
- do_sig_check: bool = True,
- do_size_check: bool = True,
+def validate_event_for_room_version(
+ room_version_obj: RoomVersion, event: EventBase
) -> None:
- """Checks if this event is correctly authed.
+ """Ensure that the event complies with the limits, and has the right signatures
+
+ NB: does not *validate* the signatures - it assumes that any signatures present
+ have already been checked.
+
+ NB: it does not check that the event satisfies the auth rules (that is done in
+ check_auth_rules_for_event) - these tests are independent of the rest of the state
+ in the room.
+
+ NB: This is used to check events that have been received over federation. As such,
+ it can only enforce the checks specified in the relevant room version, to avoid
+ a split-brain situation where some servers accept such events, and others reject
+ them.
+
+ TODO: consider moving this into EventValidator
Args:
- room_version_obj: the version of the room
- event: the event being checked.
- auth_events: the existing room state.
- do_sig_check: True if it should be verified that the sending server
- signed the event.
- do_size_check: True if the size of the event fields should be verified.
+ room_version_obj: the version of the room which contains this event
+ event: the event to be checked
Raises:
- AuthError if the checks fail
-
- Returns:
- if the auth checks pass.
+ SynapseError if there is a problem with the event
"""
- assert isinstance(auth_events, dict)
-
- if do_size_check:
- _check_size_limits(event)
+ _check_size_limits(event)
if not hasattr(event, "room_id"):
raise AuthError(500, "Event has no room_id: %s" % event)
- room_id = event.room_id
+ # check that the event has the correct signatures
+ sender_domain = get_domain_from_id(event.sender)
+
+ is_invite_via_3pid = (
+ event.type == EventTypes.Member
+ and event.membership == Membership.INVITE
+ and "third_party_invite" in event.content
+ )
+
+ # Check the sender's domain has signed the event
+ if not event.signatures.get(sender_domain):
+ # We allow invites via 3pid to have a sender from a different
+ # HS, as the sender must match the sender of the original
+ # 3pid invite. This is checked further down with the
+ # other dedicated membership checks.
+ if not is_invite_via_3pid:
+ raise AuthError(403, "Event not signed by sender's server")
+
+ if event.format_version in (EventFormatVersions.V1,):
+ # Only older room versions have event IDs to check.
+ event_id_domain = get_domain_from_id(event.event_id)
+
+ # Check the origin domain has signed the event
+ if not event.signatures.get(event_id_domain):
+ raise AuthError(403, "Event not signed by sending server")
+
+ is_invite_via_allow_rule = (
+ room_version_obj.msc3083_join_rules
+ and event.type == EventTypes.Member
+ and event.membership == Membership.JOIN
+ and EventContentFields.AUTHORISING_USER in event.content
+ )
+ if is_invite_via_allow_rule:
+ authoriser_domain = get_domain_from_id(
+ event.content[EventContentFields.AUTHORISING_USER]
+ )
+ if not event.signatures.get(authoriser_domain):
+ raise AuthError(403, "Event not signed by authorising server")
+
+
+def check_auth_rules_for_event(
+ room_version_obj: RoomVersion, event: EventBase, auth_events: StateMap[EventBase]
+) -> None:
+ """Check that an event complies with the auth rules
+
+ Checks whether an event passes the auth rules with a given set of state events
+
+ Assumes that we have already checked that the event is the right shape (it has
+ enough signatures, has a room ID, etc). In other words:
+
+ - it's fine for use in state resolution, when we have already decided whether to
+ accept the event or not, and are now trying to decide whether it should make it
+ into the room state
+
+ - when we're doing the initial event auth, it is only suitable in combination with
+ a bunch of other tests.
+
+ Args:
+ room_version_obj: the version of the room
+ event: the event being checked.
+ auth_events: the room state to check the events against.
+
+ Raises:
+ AuthError if the checks fail
+ """
+ assert isinstance(auth_events, dict)
# We need to ensure that the auth events are actually for the same room, to
# stop people from using powers they've been granted in other rooms for
# example.
+ #
+ # Arguably we don't need to do this when we're just doing state res, as presumably
+ # the state res algorithm isn't silly enough to give us events from different rooms.
+ # Still, it's easier to do it anyway.
+ room_id = event.room_id
for auth_event in auth_events.values():
if auth_event.room_id != room_id:
raise AuthError(
@@ -85,44 +155,12 @@ def check(
"which is in room %s"
% (event.event_id, room_id, auth_event.event_id, auth_event.room_id),
)
-
- if do_sig_check:
- sender_domain = get_domain_from_id(event.sender)
-
- is_invite_via_3pid = (
- event.type == EventTypes.Member
- and event.membership == Membership.INVITE
- and "third_party_invite" in event.content
- )
-
- # Check the sender's domain has signed the event
- if not event.signatures.get(sender_domain):
- # We allow invites via 3pid to have a sender from a different
- # HS, as the sender must match the sender of the original
- # 3pid invite. This is checked further down with the
- # other dedicated membership checks.
- if not is_invite_via_3pid:
- raise AuthError(403, "Event not signed by sender's server")
-
- if event.format_version in (EventFormatVersions.V1,):
- # Only older room versions have event IDs to check.
- event_id_domain = get_domain_from_id(event.event_id)
-
- # Check the origin domain has signed the event
- if not event.signatures.get(event_id_domain):
- raise AuthError(403, "Event not signed by sending server")
-
- is_invite_via_allow_rule = (
- event.type == EventTypes.Member
- and event.membership == Membership.JOIN
- and EventContentFields.AUTHORISING_USER in event.content
- )
- if is_invite_via_allow_rule:
- authoriser_domain = get_domain_from_id(
- event.content[EventContentFields.AUTHORISING_USER]
+ if auth_event.rejected_reason:
+ raise AuthError(
+ 403,
+ "During auth for event %s: found rejected event %s in the state"
+ % (event.event_id, auth_event.event_id),
)
- if not event.signatures.get(authoriser_domain):
- raise AuthError(403, "Event not signed by authorising server")
# Implementation of https://matrix.org/docs/spec/rooms/v1#authorization-rules
#
diff --git a/synapse/events/builder.py b/synapse/events/builder.py
index 87e2bb12..50f2a4c1 100644
--- a/synapse/events/builder.py
+++ b/synapse/events/builder.py
@@ -18,10 +18,8 @@ import attr
from nacl.signing import SigningKey
from synapse.api.constants import MAX_DEPTH
-from synapse.api.errors import UnsupportedRoomVersionError
from synapse.api.room_versions import (
KNOWN_EVENT_FORMAT_VERSIONS,
- KNOWN_ROOM_VERSIONS,
EventFormatVersions,
RoomVersion,
)
@@ -197,24 +195,6 @@ class EventBuilderFactory:
self.state = hs.get_state_handler()
self._event_auth_handler = hs.get_event_auth_handler()
- def new(self, room_version: str, key_values: dict) -> EventBuilder:
- """Generate an event builder appropriate for the given room version
-
- Deprecated: use for_room_version with a RoomVersion object instead
-
- Args:
- room_version: Version of the room that we're creating an event builder for
- key_values: Fields used as the basis of the new event
-
- Returns:
- EventBuilder
- """
- v = KNOWN_ROOM_VERSIONS.get(room_version)
- if not v:
- # this can happen if support is withdrawn for a room version
- raise UnsupportedRoomVersionError()
- return self.for_room_version(v, key_values)
-
def for_room_version(
self, room_version: RoomVersion, key_values: dict
) -> EventBuilder:
diff --git a/synapse/events/presence_router.py b/synapse/events/presence_router.py
index eb4556cd..68b8b190 100644
--- a/synapse/events/presence_router.py
+++ b/synapse/events/presence_router.py
@@ -45,11 +45,11 @@ def load_legacy_presence_router(hs: "HomeServer"):
configuration, and registers the hooks they implement.
"""
- if hs.config.presence_router_module_class is None:
+ if hs.config.server.presence_router_module_class is None:
return
- module = hs.config.presence_router_module_class
- config = hs.config.presence_router_config
+ module = hs.config.server.presence_router_module_class
+ config = hs.config.server.presence_router_config
api = hs.get_module_api()
presence_router = module(config=config, module_api=api)
diff --git a/synapse/events/spamcheck.py b/synapse/events/spamcheck.py
index c389f70b..ae4c8ab2 100644
--- a/synapse/events/spamcheck.py
+++ b/synapse/events/spamcheck.py
@@ -44,7 +44,9 @@ CHECK_EVENT_FOR_SPAM_CALLBACK = Callable[
["synapse.events.EventBase"],
Awaitable[Union[bool, str]],
]
+USER_MAY_JOIN_ROOM_CALLBACK = Callable[[str, str, bool], Awaitable[bool]]
USER_MAY_INVITE_CALLBACK = Callable[[str, str, str], Awaitable[bool]]
+USER_MAY_SEND_3PID_INVITE_CALLBACK = Callable[[str, str, str, str], Awaitable[bool]]
USER_MAY_CREATE_ROOM_CALLBACK = Callable[[str], Awaitable[bool]]
USER_MAY_CREATE_ROOM_WITH_INVITES_CALLBACK = Callable[
[str, List[str], List[Dict[str, str]]], Awaitable[bool]
@@ -165,7 +167,11 @@ def load_legacy_spam_checkers(hs: "synapse.server.HomeServer"):
class SpamChecker:
def __init__(self):
self._check_event_for_spam_callbacks: List[CHECK_EVENT_FOR_SPAM_CALLBACK] = []
+ self._user_may_join_room_callbacks: List[USER_MAY_JOIN_ROOM_CALLBACK] = []
self._user_may_invite_callbacks: List[USER_MAY_INVITE_CALLBACK] = []
+ self._user_may_send_3pid_invite_callbacks: List[
+ USER_MAY_SEND_3PID_INVITE_CALLBACK
+ ] = []
self._user_may_create_room_callbacks: List[USER_MAY_CREATE_ROOM_CALLBACK] = []
self._user_may_create_room_with_invites_callbacks: List[
USER_MAY_CREATE_ROOM_WITH_INVITES_CALLBACK
@@ -187,7 +193,9 @@ class SpamChecker:
def register_callbacks(
self,
check_event_for_spam: Optional[CHECK_EVENT_FOR_SPAM_CALLBACK] = None,
+ user_may_join_room: Optional[USER_MAY_JOIN_ROOM_CALLBACK] = None,
user_may_invite: Optional[USER_MAY_INVITE_CALLBACK] = None,
+ user_may_send_3pid_invite: Optional[USER_MAY_SEND_3PID_INVITE_CALLBACK] = None,
user_may_create_room: Optional[USER_MAY_CREATE_ROOM_CALLBACK] = None,
user_may_create_room_with_invites: Optional[
USER_MAY_CREATE_ROOM_WITH_INVITES_CALLBACK
@@ -206,9 +214,17 @@ class SpamChecker:
if check_event_for_spam is not None:
self._check_event_for_spam_callbacks.append(check_event_for_spam)
+ if user_may_join_room is not None:
+ self._user_may_join_room_callbacks.append(user_may_join_room)
+
if user_may_invite is not None:
self._user_may_invite_callbacks.append(user_may_invite)
+ if user_may_send_3pid_invite is not None:
+ self._user_may_send_3pid_invite_callbacks.append(
+ user_may_send_3pid_invite,
+ )
+
if user_may_create_room is not None:
self._user_may_create_room_callbacks.append(user_may_create_room)
@@ -259,6 +275,24 @@ class SpamChecker:
return False
+ async def user_may_join_room(self, user_id: str, room_id: str, is_invited: bool):
+ """Checks if a given users is allowed to join a room.
+ Not called when a user creates a room.
+
+ Args:
+ userid: The ID of the user wanting to join the room
+ room_id: The ID of the room the user wants to join
+ is_invited: Whether the user is invited into the room
+
+ Returns:
+ bool: Whether the user may join the room
+ """
+ for callback in self._user_may_join_room_callbacks:
+ if await callback(user_id, room_id, is_invited) is False:
+ return False
+
+ return True
+
async def user_may_invite(
self, inviter_userid: str, invitee_userid: str, room_id: str
) -> bool:
@@ -280,6 +314,31 @@ class SpamChecker:
return True
+ async def user_may_send_3pid_invite(
+ self, inviter_userid: str, medium: str, address: str, room_id: str
+ ) -> bool:
+ """Checks if a given user may invite a given threepid into the room
+
+ If this method returns false, the threepid invite will be rejected.
+
+ Note that if the threepid is already associated with a Matrix user ID, Synapse
+ will call user_may_invite with said user ID instead.
+
+ Args:
+ inviter_userid: The user ID of the sender of the invitation
+ medium: The 3PID's medium (e.g. "email")
+ address: The 3PID's address (e.g. "alice@example.com")
+ room_id: The room ID
+
+ Returns:
+ True if the user may send the invite, otherwise False
+ """
+ for callback in self._user_may_send_3pid_invite_callbacks:
+ if await callback(inviter_userid, medium, address, room_id) is False:
+ return False
+
+ return True
+
async def user_may_create_room(self, userid: str) -> bool:
"""Checks if a given user may create a room
diff --git a/synapse/events/third_party_rules.py b/synapse/events/third_party_rules.py
index d94b1bb4..976d9fa4 100644
--- a/synapse/events/third_party_rules.py
+++ b/synapse/events/third_party_rules.py
@@ -217,6 +217,15 @@ class ThirdPartyEventRules:
for callback in self._check_event_allowed_callbacks:
try:
res, replacement_data = await callback(event, state_events)
+ except SynapseError as e:
+ # FIXME: Being able to throw SynapseErrors is relied upon by
+ # some modules. PR #10386 accidentally broke this ability.
+ # That said, we aren't keen on exposing this implementation detail
+ # to modules and we should one day have a proper way to do what
+ # is wanted.
+ # This module callback needs a rework so that hacks such as
+ # this one are not necessary.
+ raise e
except Exception as e:
logger.warning("Failed to run module API callback %s: %s", callback, e)
continue
diff --git a/synapse/events/utils.py b/synapse/events/utils.py
index 38fccd1e..520edbbf 100644
--- a/synapse/events/utils.py
+++ b/synapse/events/utils.py
@@ -372,7 +372,7 @@ class EventClientSerializer:
def __init__(self, hs):
self.store = hs.get_datastore()
self.experimental_msc1849_support_enabled = (
- hs.config.experimental_msc1849_support_enabled
+ hs.config.server.experimental_msc1849_support_enabled
)
async def serialize_event(
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index 5f4383ee..d8c0b86f 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -1008,7 +1008,10 @@ class FederationServer(FederationBase):
async with lock:
logger.info("handling received PDU: %s", event)
try:
- await self._federation_event_handler.on_receive_pdu(origin, event)
+ with nested_logging_context(event.event_id):
+ await self._federation_event_handler.on_receive_pdu(
+ origin, event
+ )
except FederationError as e:
# XXX: Ideally we'd inform the remote we failed to process
# the event, but we can't return an error in the transaction
diff --git a/synapse/federation/transport/server/__init__.py b/synapse/federation/transport/server/__init__.py
index 95176ba6..c32539bf 100644
--- a/synapse/federation/transport/server/__init__.py
+++ b/synapse/federation/transport/server/__init__.py
@@ -117,7 +117,7 @@ class PublicRoomList(BaseFederationServlet):
):
super().__init__(hs, authenticator, ratelimiter, server_name)
self.handler = hs.get_room_list_handler()
- self.allow_access = hs.config.allow_public_rooms_over_federation
+ self.allow_access = hs.config.server.allow_public_rooms_over_federation
async def on_GET(
self, origin: str, content: Literal[None], query: Dict[bytes, List[bytes]]
diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py
deleted file mode 100644
index 0ccef884..00000000
--- a/synapse/handlers/_base.py
+++ /dev/null
@@ -1,120 +0,0 @@
-# Copyright 2014 - 2016 OpenMarket Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import logging
-from typing import TYPE_CHECKING, Optional
-
-from synapse.api.ratelimiting import Ratelimiter
-from synapse.types import Requester
-
-if TYPE_CHECKING:
- from synapse.server import HomeServer
-
-logger = logging.getLogger(__name__)
-
-
-class BaseHandler:
- """
- Common base class for the event handlers.
-
- Deprecated: new code should not use this. Instead, Handler classes should define the
- fields they actually need. The utility methods should either be factored out to
- standalone helper functions, or to different Handler classes.
- """
-
- def __init__(self, hs: "HomeServer"):
- self.store = hs.get_datastore()
- self.auth = hs.get_auth()
- self.notifier = hs.get_notifier()
- self.state_handler = hs.get_state_handler()
- self.distributor = hs.get_distributor()
- self.clock = hs.get_clock()
- self.hs = hs
-
- # The rate_hz and burst_count are overridden on a per-user basis
- self.request_ratelimiter = Ratelimiter(
- store=self.store, clock=self.clock, rate_hz=0, burst_count=0
- )
- self._rc_message = self.hs.config.ratelimiting.rc_message
-
- # Check whether ratelimiting room admin message redaction is enabled
- # by the presence of rate limits in the config
- if self.hs.config.ratelimiting.rc_admin_redaction:
- self.admin_redaction_ratelimiter: Optional[Ratelimiter] = Ratelimiter(
- store=self.store,
- clock=self.clock,
- rate_hz=self.hs.config.ratelimiting.rc_admin_redaction.per_second,
- burst_count=self.hs.config.ratelimiting.rc_admin_redaction.burst_count,
- )
- else:
- self.admin_redaction_ratelimiter = None
-
- self.server_name = hs.hostname
-
- self.event_builder_factory = hs.get_event_builder_factory()
-
- async def ratelimit(
- self,
- requester: Requester,
- update: bool = True,
- is_admin_redaction: bool = False,
- ) -> None:
- """Ratelimits requests.
-
- Args:
- requester
- update: Whether to record that a request is being processed.
- Set to False when doing multiple checks for one request (e.g.
- to check up front if we would reject the request), and set to
- True for the last call for a given request.
- is_admin_redaction: Whether this is a room admin/moderator
- redacting an event. If so then we may apply different
- ratelimits depending on config.
-
- Raises:
- LimitExceededError if the request should be ratelimited
- """
- user_id = requester.user.to_string()
-
- # The AS user itself is never rate limited.
- app_service = self.store.get_app_service_by_user_id(user_id)
- if app_service is not None:
- return # do not ratelimit app service senders
-
- messages_per_second = self._rc_message.per_second
- burst_count = self._rc_message.burst_count
-
- # Check if there is a per user override in the DB.
- override = await self.store.get_ratelimit_for_user(user_id)
- if override:
- # If overridden with a null Hz then ratelimiting has been entirely
- # disabled for the user
- if not override.messages_per_second:
- return
-
- messages_per_second = override.messages_per_second
- burst_count = override.burst_count
-
- if is_admin_redaction and self.admin_redaction_ratelimiter:
- # If we have separate config for admin redactions, use a separate
- # ratelimiter as to not have user_ids clash
- await self.admin_redaction_ratelimiter.ratelimit(requester, update=update)
- else:
- # Override rate and burst count per-user
- await self.request_ratelimiter.ratelimit(
- requester,
- rate_hz=messages_per_second,
- burst_count=burst_count,
- update=update,
- )
diff --git a/synapse/handlers/account_validity.py b/synapse/handlers/account_validity.py
index 5a5f124d..87e415df 100644
--- a/synapse/handlers/account_validity.py
+++ b/synapse/handlers/account_validity.py
@@ -67,12 +67,8 @@ class AccountValidityHandler:
and self._account_validity_renew_by_email_enabled
):
# Don't do email-specific configuration if renewal by email is disabled.
- self._template_html = (
- hs.config.account_validity.account_validity_template_html
- )
- self._template_text = (
- hs.config.account_validity.account_validity_template_text
- )
+ self._template_html = hs.config.email.account_validity_template_html
+ self._template_text = hs.config.email.account_validity_template_text
self._renew_email_subject = (
hs.config.account_validity.account_validity_renew_email_subject
)
diff --git a/synapse/handlers/admin.py b/synapse/handlers/admin.py
index bfa7f2c5..a53cd62d 100644
--- a/synapse/handlers/admin.py
+++ b/synapse/handlers/admin.py
@@ -21,18 +21,15 @@ from synapse.events import EventBase
from synapse.types import JsonDict, RoomStreamToken, StateMap, UserID
from synapse.visibility import filter_events_for_client
-from ._base import BaseHandler
-
if TYPE_CHECKING:
from synapse.server import HomeServer
logger = logging.getLogger(__name__)
-class AdminHandler(BaseHandler):
+class AdminHandler:
def __init__(self, hs: "HomeServer"):
- super().__init__(hs)
-
+ self.store = hs.get_datastore()
self.storage = hs.get_storage()
self.state_store = self.storage.state
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index a8c717ef..f4612a5b 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -52,7 +52,6 @@ from synapse.api.errors import (
UserDeactivatedError,
)
from synapse.api.ratelimiting import Ratelimiter
-from synapse.handlers._base import BaseHandler
from synapse.handlers.ui_auth import (
INTERACTIVE_AUTH_CHECKERS,
UIAuthSessionDataConstants,
@@ -186,19 +185,20 @@ class LoginTokenAttributes:
auth_provider_id = attr.ib(type=str)
-class AuthHandler(BaseHandler):
+class AuthHandler:
SESSION_EXPIRE_MS = 48 * 60 * 60 * 1000
def __init__(self, hs: "HomeServer"):
- super().__init__(hs)
-
+ self.store = hs.get_datastore()
+ self.auth = hs.get_auth()
+ self.clock = hs.get_clock()
self.checkers: Dict[str, UserInteractiveAuthChecker] = {}
for auth_checker_class in INTERACTIVE_AUTH_CHECKERS:
inst = auth_checker_class(hs)
if inst.is_enabled():
self.checkers[inst.AUTH_TYPE] = inst # type: ignore
- self.bcrypt_rounds = hs.config.bcrypt_rounds
+ self.bcrypt_rounds = hs.config.registration.bcrypt_rounds
# we can't use hs.get_module_api() here, because to do so will create an
# import loop.
diff --git a/synapse/handlers/deactivate_account.py b/synapse/handlers/deactivate_account.py
index 9ae5b775..e88c3c27 100644
--- a/synapse/handlers/deactivate_account.py
+++ b/synapse/handlers/deactivate_account.py
@@ -19,19 +19,17 @@ from synapse.api.errors import SynapseError
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.types import Requester, UserID, create_requester
-from ._base import BaseHandler
-
if TYPE_CHECKING:
from synapse.server import HomeServer
logger = logging.getLogger(__name__)
-class DeactivateAccountHandler(BaseHandler):
+class DeactivateAccountHandler:
"""Handler which deals with deactivating user accounts."""
def __init__(self, hs: "HomeServer"):
- super().__init__(hs)
+ self.store = hs.get_datastore()
self.hs = hs
self._auth_handler = hs.get_auth_handler()
self._device_handler = hs.get_device_handler()
@@ -133,6 +131,10 @@ class DeactivateAccountHandler(BaseHandler):
# delete from user directory
await self.user_directory_handler.handle_local_user_deactivated(user_id)
+ # If the user is present in the monthly active users table
+ # remove them
+ await self.store.remove_deactivated_user_from_mau_table(user_id)
+
# Mark the user as erased, if they asked for that
if erase_data:
user = UserID.from_string(user_id)
diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py
index 35334725..75e60197 100644
--- a/synapse/handlers/device.py
+++ b/synapse/handlers/device.py
@@ -40,8 +40,6 @@ from synapse.util.caches.expiringcache import ExpiringCache
from synapse.util.metrics import measure_func
from synapse.util.retryutils import NotRetryingDestination
-from ._base import BaseHandler
-
if TYPE_CHECKING:
from synapse.server import HomeServer
@@ -50,14 +48,16 @@ logger = logging.getLogger(__name__)
MAX_DEVICE_DISPLAY_NAME_LEN = 100
-class DeviceWorkerHandler(BaseHandler):
+class DeviceWorkerHandler:
def __init__(self, hs: "HomeServer"):
- super().__init__(hs)
-
+ self.clock = hs.get_clock()
self.hs = hs
+ self.store = hs.get_datastore()
+ self.notifier = hs.get_notifier()
self.state = hs.get_state_handler()
self.state_store = hs.get_storage().state
self._auth_handler = hs.get_auth_handler()
+ self.server_name = hs.hostname
@trace
async def get_devices_by_user(self, user_id: str) -> List[JsonDict]:
diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py
index 5cfba3c8..14ed7d98 100644
--- a/synapse/handlers/directory.py
+++ b/synapse/handlers/directory.py
@@ -31,26 +31,25 @@ from synapse.appservice import ApplicationService
from synapse.storage.databases.main.directory import RoomAliasMapping
from synapse.types import JsonDict, Requester, RoomAlias, UserID, get_domain_from_id
-from ._base import BaseHandler
-
if TYPE_CHECKING:
from synapse.server import HomeServer
logger = logging.getLogger(__name__)
-class DirectoryHandler(BaseHandler):
+class DirectoryHandler:
def __init__(self, hs: "HomeServer"):
- super().__init__(hs)
-
+ self.auth = hs.get_auth()
+ self.hs = hs
self.state = hs.get_state_handler()
self.appservice_handler = hs.get_application_service_handler()
self.event_creation_handler = hs.get_event_creation_handler()
self.store = hs.get_datastore()
self.config = hs.config
self.enable_room_list_search = hs.config.roomdirectory.enable_room_list_search
- self.require_membership = hs.config.require_membership_for_aliases
+ self.require_membership = hs.config.server.require_membership_for_aliases
self.third_party_event_rules = hs.get_third_party_event_rules()
+ self.server_name = hs.hostname
self.federation = hs.get_federation_client()
hs.get_federation_registry().register_query_handler(
diff --git a/synapse/handlers/event_auth.py b/synapse/handlers/event_auth.py
index cb81fa09..d089c562 100644
--- a/synapse/handlers/event_auth.py
+++ b/synapse/handlers/event_auth.py
@@ -22,7 +22,8 @@ from synapse.api.constants import (
RestrictedJoinRuleTypes,
)
from synapse.api.errors import AuthError, Codes, SynapseError
-from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion
+from synapse.api.room_versions import RoomVersion
+from synapse.event_auth import check_auth_rules_for_event
from synapse.events import EventBase
from synapse.events.builder import EventBuilder
from synapse.events.snapshot import EventContext
@@ -45,21 +46,17 @@ class EventAuthHandler:
self._store = hs.get_datastore()
self._server_name = hs.hostname
- async def check_from_context(
+ async def check_auth_rules_from_context(
self,
- room_version: str,
+ room_version_obj: RoomVersion,
event: EventBase,
context: EventContext,
- do_sig_check: bool = True,
) -> None:
+ """Check an event passes the auth rules at its own auth events"""
auth_event_ids = event.auth_event_ids()
auth_events_by_id = await self._store.get_events(auth_event_ids)
auth_events = {(e.type, e.state_key): e for e in auth_events_by_id.values()}
-
- room_version_obj = KNOWN_ROOM_VERSIONS[room_version]
- event_auth.check(
- room_version_obj, event, auth_events=auth_events, do_sig_check=do_sig_check
- )
+ check_auth_rules_for_event(room_version_obj, event, auth_events)
def compute_auth_events(
self,
diff --git a/synapse/handlers/events.py b/synapse/handlers/events.py
index 4b3f0370..1f64534a 100644
--- a/synapse/handlers/events.py
+++ b/synapse/handlers/events.py
@@ -25,8 +25,6 @@ from synapse.streams.config import PaginationConfig
from synapse.types import JsonDict, UserID
from synapse.visibility import filter_events_for_client
-from ._base import BaseHandler
-
if TYPE_CHECKING:
from synapse.server import HomeServer
@@ -34,11 +32,11 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
-class EventStreamHandler(BaseHandler):
+class EventStreamHandler:
def __init__(self, hs: "HomeServer"):
- super().__init__(hs)
-
+ self.store = hs.get_datastore()
self.clock = hs.get_clock()
+ self.hs = hs
self.notifier = hs.get_notifier()
self.state = hs.get_state_handler()
@@ -138,9 +136,9 @@ class EventStreamHandler(BaseHandler):
return chunk
-class EventHandler(BaseHandler):
+class EventHandler:
def __init__(self, hs: "HomeServer"):
- super().__init__(hs)
+ self.store = hs.get_datastore()
self.storage = hs.get_storage()
async def get_event(
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index adbd150e..3e341bd2 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -45,11 +45,14 @@ from synapse.api.errors import (
)
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion, RoomVersions
from synapse.crypto.event_signing import compute_event_signature
+from synapse.event_auth import (
+ check_auth_rules_for_event,
+ validate_event_for_room_version,
+)
from synapse.events import EventBase
from synapse.events.snapshot import EventContext
from synapse.events.validator import EventValidator
from synapse.federation.federation_client import InvalidResponseError
-from synapse.handlers._base import BaseHandler
from synapse.http.servlet import assert_params_in_dict
from synapse.logging.context import (
make_deferred_yieldable,
@@ -74,15 +77,13 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
-class FederationHandler(BaseHandler):
+class FederationHandler:
"""Handles general incoming federation requests
Incoming events are *not* handled here, for which see FederationEventHandler.
"""
def __init__(self, hs: "HomeServer"):
- super().__init__(hs)
-
self.hs = hs
self.store = hs.get_datastore()
@@ -95,6 +96,7 @@ class FederationHandler(BaseHandler):
self.is_mine_id = hs.is_mine_id
self.spam_checker = hs.get_spam_checker()
self.event_creation_handler = hs.get_event_creation_handler()
+ self.event_builder_factory = hs.get_event_builder_factory()
self._event_auth_handler = hs.get_event_auth_handler()
self._server_notices_mxid = hs.config.servernotices.server_notices_mxid
self.config = hs.config
@@ -723,8 +725,8 @@ class FederationHandler(BaseHandler):
state_ids,
)
- builder = self.event_builder_factory.new(
- room_version.identifier,
+ builder = self.event_builder_factory.for_room_version(
+ room_version,
{
"type": EventTypes.Member,
"content": event_content,
@@ -747,10 +749,9 @@ class FederationHandler(BaseHandler):
# The remote hasn't signed it yet, obviously. We'll do the full checks
# when we get the event back in `on_send_join_request`
- await self._event_auth_handler.check_from_context(
- room_version.identifier, event, context, do_sig_check=False
+ await self._event_auth_handler.check_auth_rules_from_context(
+ room_version, event, context
)
-
return event
async def on_invite_request(
@@ -767,7 +768,7 @@ class FederationHandler(BaseHandler):
if is_blocked:
raise SynapseError(403, "This room has been blocked on this server")
- if self.hs.config.block_non_admin_invites:
+ if self.hs.config.server.block_non_admin_invites:
raise SynapseError(403, "This server does not accept room invites")
if not await self.spam_checker.user_may_invite(
@@ -902,9 +903,9 @@ class FederationHandler(BaseHandler):
)
raise SynapseError(403, "User not from origin", Codes.FORBIDDEN)
- room_version = await self.store.get_room_version_id(room_id)
- builder = self.event_builder_factory.new(
- room_version,
+ room_version_obj = await self.store.get_room_version(room_id)
+ builder = self.event_builder_factory.for_room_version(
+ room_version_obj,
{
"type": EventTypes.Member,
"content": {"membership": Membership.LEAVE},
@@ -921,8 +922,8 @@ class FederationHandler(BaseHandler):
try:
# The remote hasn't signed it yet, obviously. We'll do the full checks
# when we get the event back in `on_send_leave_request`
- await self._event_auth_handler.check_from_context(
- room_version, event, context, do_sig_check=False
+ await self._event_auth_handler.check_auth_rules_from_context(
+ room_version_obj, event, context
)
except AuthError as e:
logger.warning("Failed to create new leave %r because %s", event, e)
@@ -954,10 +955,10 @@ class FederationHandler(BaseHandler):
)
raise SynapseError(403, "User not from origin", Codes.FORBIDDEN)
- room_version = await self.store.get_room_version_id(room_id)
+ room_version_obj = await self.store.get_room_version(room_id)
- builder = self.event_builder_factory.new(
- room_version,
+ builder = self.event_builder_factory.for_room_version(
+ room_version_obj,
{
"type": EventTypes.Member,
"content": {"membership": Membership.KNOCK},
@@ -983,8 +984,8 @@ class FederationHandler(BaseHandler):
try:
# The remote hasn't signed it yet, obviously. We'll do the full checks
# when we get the event back in `on_send_knock_request`
- await self._event_auth_handler.check_from_context(
- room_version, event, context, do_sig_check=False
+ await self._event_auth_handler.check_auth_rules_from_context(
+ room_version_obj, event, context
)
except AuthError as e:
logger.warning("Failed to create new knock %r because %s", event, e)
@@ -1173,7 +1174,8 @@ class FederationHandler(BaseHandler):
auth_for_e[(EventTypes.Create, "")] = create_event
try:
- event_auth.check(room_version, e, auth_events=auth_for_e)
+ validate_event_for_room_version(room_version, e)
+ check_auth_rules_for_event(room_version, e, auth_for_e)
except SynapseError as err:
# we may get SynapseErrors here as well as AuthErrors. For
# instance, there are a couple of (ancient) events in some
@@ -1250,8 +1252,10 @@ class FederationHandler(BaseHandler):
}
if await self._event_auth_handler.check_host_in_room(room_id, self.hs.hostname):
- room_version = await self.store.get_room_version_id(room_id)
- builder = self.event_builder_factory.new(room_version, event_dict)
+ room_version_obj = await self.store.get_room_version(room_id)
+ builder = self.event_builder_factory.for_room_version(
+ room_version_obj, event_dict
+ )
EventValidator().validate_builder(builder)
event, context = await self.event_creation_handler.create_new_client_event(
@@ -1259,7 +1263,7 @@ class FederationHandler(BaseHandler):
)
event, context = await self.add_display_name_to_third_party_invite(
- room_version, event_dict, event, context
+ room_version_obj, event_dict, event, context
)
EventValidator().validate_new(event, self.config)
@@ -1269,8 +1273,9 @@ class FederationHandler(BaseHandler):
event.internal_metadata.send_on_behalf_of = self.hs.hostname
try:
- await self._event_auth_handler.check_from_context(
- room_version, event, context
+ validate_event_for_room_version(room_version_obj, event)
+ await self._event_auth_handler.check_auth_rules_from_context(
+ room_version_obj, event, context
)
except AuthError as e:
logger.warning("Denying new third party invite %r because %s", event, e)
@@ -1304,22 +1309,25 @@ class FederationHandler(BaseHandler):
"""
assert_params_in_dict(event_dict, ["room_id"])
- room_version = await self.store.get_room_version_id(event_dict["room_id"])
+ room_version_obj = await self.store.get_room_version(event_dict["room_id"])
# NB: event_dict has a particular specced format we might need to fudge
# if we change event formats too much.
- builder = self.event_builder_factory.new(room_version, event_dict)
+ builder = self.event_builder_factory.for_room_version(
+ room_version_obj, event_dict
+ )
event, context = await self.event_creation_handler.create_new_client_event(
builder=builder
)
event, context = await self.add_display_name_to_third_party_invite(
- room_version, event_dict, event, context
+ room_version_obj, event_dict, event, context
)
try:
- await self._event_auth_handler.check_from_context(
- room_version, event, context
+ validate_event_for_room_version(room_version_obj, event)
+ await self._event_auth_handler.check_auth_rules_from_context(
+ room_version_obj, event, context
)
except AuthError as e:
logger.warning("Denying third party invite %r because %s", event, e)
@@ -1336,7 +1344,7 @@ class FederationHandler(BaseHandler):
async def add_display_name_to_third_party_invite(
self,
- room_version: str,
+ room_version_obj: RoomVersion,
event_dict: JsonDict,
event: EventBase,
context: EventContext,
@@ -1368,7 +1376,9 @@ class FederationHandler(BaseHandler):
# auth checks. If we need the invite and don't have it then the
# auth check code will explode appropriately.
- builder = self.event_builder_factory.new(room_version, event_dict)
+ builder = self.event_builder_factory.for_room_version(
+ room_version_obj, event_dict
+ )
EventValidator().validate_builder(builder)
event, context = await self.event_creation_handler.create_new_client_event(
builder=builder
diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py
index 01fd8411..f640b417 100644
--- a/synapse/handlers/federation_event.py
+++ b/synapse/handlers/federation_event.py
@@ -29,7 +29,6 @@ from typing import (
from prometheus_client import Counter
-from synapse import event_auth
from synapse.api.constants import (
EventContentFields,
EventTypes,
@@ -47,7 +46,11 @@ from synapse.api.errors import (
SynapseError,
)
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
-from synapse.event_auth import auth_types_for_event
+from synapse.event_auth import (
+ auth_types_for_event,
+ check_auth_rules_for_event,
+ validate_event_for_room_version,
+)
from synapse.events import EventBase
from synapse.events.snapshot import EventContext
from synapse.federation.federation_client import InvalidResponseError
@@ -68,11 +71,7 @@ from synapse.types import (
UserID,
get_domain_from_id,
)
-from synapse.util.async_helpers import (
- Linearizer,
- concurrently_execute,
- yieldable_gather_results,
-)
+from synapse.util.async_helpers import Linearizer, concurrently_execute
from synapse.util.iterutils import batch_iter
from synapse.util.retryutils import NotRetryingDestination
from synapse.util.stringutils import shortstr
@@ -357,6 +356,11 @@ class FederationEventHandler:
)
# all looks good, we can persist the event.
+
+ # First, precalculate the joined hosts so that the federation sender doesn't
+ # need to.
+ await self._event_creation_handler.cache_joined_hosts_for_event(event, context)
+
await self._run_push_actions_and_persist_event(event, context)
return event, context
@@ -890,6 +894,9 @@ class FederationEventHandler:
backfilled=backfilled,
)
except AuthError as e:
+ # FIXME richvdh 2021/10/07 I don't think this is reachable. Let's log it
+ # for now
+ logger.exception("Unexpected AuthError from _check_event_auth")
raise FederationError("ERROR", e.code, e.msg, affected=event.event_id)
await self._run_push_actions_and_persist_event(event, context, backfilled)
@@ -1011,9 +1018,8 @@ class FederationEventHandler:
room_version = await self._store.get_room_version(marker_event.room_id)
create_event = await self._store.get_create_event_for_room(marker_event.room_id)
room_creator = create_event.content.get(EventContentFields.ROOM_CREATOR)
- if (
- not room_version.msc2716_historical
- or not self._config.experimental.msc2716_enabled
+ if not room_version.msc2716_historical and (
+ not self._config.experimental.msc2716_enabled
or marker_event.sender != room_creator
):
return
@@ -1155,7 +1161,10 @@ class FederationEventHandler:
return
logger.info(
- "Persisting %i of %i remaining events", len(roots), len(event_map)
+ "Persisting %i of %i remaining outliers: %s",
+ len(roots),
+ len(event_map),
+ shortstr(e.event_id for e in roots),
)
await self._auth_and_persist_fetched_events_inner(origin, room_id, roots)
@@ -1189,7 +1198,10 @@ class FederationEventHandler:
allow_rejected=True,
)
- async def prep(event: EventBase) -> Optional[Tuple[EventBase, EventContext]]:
+ room_version = await self._store.get_room_version_id(room_id)
+ room_version_obj = KNOWN_ROOM_VERSIONS[room_version]
+
+ def prep(event: EventBase) -> Optional[Tuple[EventBase, EventContext]]:
with nested_logging_context(suffix=event.event_id):
auth = {}
for auth_event_id in event.auth_event_ids():
@@ -1207,17 +1219,16 @@ class FederationEventHandler:
auth[(ae.type, ae.state_key)] = ae
context = EventContext.for_outlier()
- context = await self._check_event_auth(
- origin,
- event,
- context,
- claimed_auth_event_map=auth,
- )
+ try:
+ validate_event_for_room_version(room_version_obj, event)
+ check_auth_rules_for_event(room_version_obj, event, auth)
+ except AuthError as e:
+ logger.warning("Rejecting %r because %s", event, e)
+ context.rejected = RejectedReason.AUTH_ERROR
+
return event, context
- events_to_persist = (
- x for x in await yieldable_gather_results(prep, fetched_events) if x
- )
+ events_to_persist = (x for x in (prep(event) for event in fetched_events) if x)
await self.persist_events_and_notify(room_id, tuple(events_to_persist))
async def _check_event_auth(
@@ -1226,7 +1237,6 @@ class FederationEventHandler:
event: EventBase,
context: EventContext,
state: Optional[Iterable[EventBase]] = None,
- claimed_auth_event_map: Optional[StateMap[EventBase]] = None,
backfilled: bool = False,
) -> EventContext:
"""
@@ -1242,42 +1252,45 @@ class FederationEventHandler:
The state events used to check the event for soft-fail. If this is
not provided the current state events will be used.
- claimed_auth_event_map:
- A map of (type, state_key) => event for the event's claimed auth_events.
- Possibly including events that were rejected, or are in the wrong room.
-
- Only populated when populating outliers.
-
backfilled: True if the event was backfilled.
Returns:
The updated context object.
"""
- # claimed_auth_event_map should be given iff the event is an outlier
- assert bool(claimed_auth_event_map) == event.internal_metadata.outlier
+ # This method should only be used for non-outliers
+ assert not event.internal_metadata.outlier
+ # first of all, check that the event itself is valid.
room_version = await self._store.get_room_version_id(event.room_id)
room_version_obj = KNOWN_ROOM_VERSIONS[room_version]
- if claimed_auth_event_map:
- # if we have a copy of the auth events from the event, use that as the
- # basis for auth.
- auth_events = claimed_auth_event_map
- else:
- # otherwise, we calculate what the auth events *should* be, and use that
- prev_state_ids = await context.get_prev_state_ids()
- auth_events_ids = self._event_auth_handler.compute_auth_events(
- event, prev_state_ids, for_verification=True
- )
- auth_events_x = await self._store.get_events(auth_events_ids)
- auth_events = {(e.type, e.state_key): e for e in auth_events_x.values()}
+ try:
+ validate_event_for_room_version(room_version_obj, event)
+ except AuthError as e:
+ logger.warning("While validating received event %r: %s", event, e)
+ # TODO: use a different rejected reason here?
+ context.rejected = RejectedReason.AUTH_ERROR
+ return context
+
+ # calculate what the auth events *should* be, to use as a basis for auth.
+ prev_state_ids = await context.get_prev_state_ids()
+ auth_events_ids = self._event_auth_handler.compute_auth_events(
+ event, prev_state_ids, for_verification=True
+ )
+ auth_events_x = await self._store.get_events(auth_events_ids)
+ calculated_auth_event_map = {
+ (e.type, e.state_key): e for e in auth_events_x.values()
+ }
try:
(
context,
auth_events_for_auth,
) = await self._update_auth_events_and_context_for_auth(
- origin, event, context, auth_events
+ origin,
+ event,
+ context,
+ calculated_auth_event_map=calculated_auth_event_map,
)
except Exception:
# We don't really mind if the above fails, so lets not fail
@@ -1289,24 +1302,17 @@ class FederationEventHandler:
"Ignoring failure and continuing processing of event.",
event.event_id,
)
- auth_events_for_auth = auth_events
+ auth_events_for_auth = calculated_auth_event_map
try:
- event_auth.check(room_version_obj, event, auth_events=auth_events_for_auth)
+ check_auth_rules_for_event(room_version_obj, event, auth_events_for_auth)
except AuthError as e:
logger.warning("Failed auth resolution for %r because %s", event, e)
context.rejected = RejectedReason.AUTH_ERROR
+ return context
- if not context.rejected:
- await self._check_for_soft_fail(event, state, backfilled, origin=origin)
- await self._maybe_kick_guest_users(event)
-
- # If we are going to send this event over federation we precaclculate
- # the joined hosts.
- if event.internal_metadata.get_send_on_behalf_of():
- await self._event_creation_handler.cache_joined_hosts_for_event(
- event, context
- )
+ await self._check_for_soft_fail(event, state, backfilled, origin=origin)
+ await self._maybe_kick_guest_users(event)
return context
@@ -1404,7 +1410,7 @@ class FederationEventHandler:
}
try:
- event_auth.check(room_version_obj, event, auth_events=current_auth_events)
+ check_auth_rules_for_event(room_version_obj, event, current_auth_events)
except AuthError as e:
logger.warning(
"Soft-failing %r (from %s) because %s",
@@ -1425,7 +1431,7 @@ class FederationEventHandler:
origin: str,
event: EventBase,
context: EventContext,
- input_auth_events: StateMap[EventBase],
+ calculated_auth_event_map: StateMap[EventBase],
) -> Tuple[EventContext, StateMap[EventBase]]:
"""Helper for _check_event_auth. See there for docs.
@@ -1443,19 +1449,17 @@ class FederationEventHandler:
event:
context:
- input_auth_events:
- Map from (event_type, state_key) to event
-
- Normally, our calculated auth_events based on the state of the room
- at the event's position in the DAG, though occasionally (eg if the
- event is an outlier), may be the auth events claimed by the remote
- server.
+ calculated_auth_event_map:
+ Our calculated auth_events based on the state of the room
+ at the event's position in the DAG.
Returns:
updated context, updated auth event map
"""
- # take a copy of input_auth_events before we modify it.
- auth_events: MutableStateMap[EventBase] = dict(input_auth_events)
+ assert not event.internal_metadata.outlier
+
+ # take a copy of calculated_auth_event_map before we modify it.
+ auth_events: MutableStateMap[EventBase] = dict(calculated_auth_event_map)
event_auth_events = set(event.auth_event_ids())
@@ -1475,6 +1479,11 @@ class FederationEventHandler:
logger.debug("Events %s are in the store", have_events)
missing_auth.difference_update(have_events)
+ # missing_auth is now the set of event_ids which:
+ # a. are listed in event.auth_events, *and*
+ # b. are *not* part of our calculated auth events based on room state, *and*
+ # c. are *not* yet in our database.
+
if missing_auth:
# If we don't have all the auth events, we need to get them.
logger.info("auth_events contains unknown events: %s", missing_auth)
@@ -1496,19 +1505,31 @@ class FederationEventHandler:
}
)
- if event.internal_metadata.is_outlier():
- # XXX: given that, for an outlier, we'll be working with the
- # event's *claimed* auth events rather than those we calculated:
- # (a) is there any point in this test, since different_auth below will
- # obviously be empty
- # (b) alternatively, why don't we do it earlier?
- logger.info("Skipping auth_event fetch for outlier")
- return context, auth_events
+ # auth_events now contains
+ # 1. our *calculated* auth events based on the room state, plus:
+ # 2. any events which:
+ # a. are listed in `event.auth_events`, *and*
+ # b. are not part of our calculated auth events, *and*
+ # c. were not in our database before the call to /event_auth
+ # d. have since been added to our database (most likely by /event_auth).
different_auth = event_auth_events.difference(
e.event_id for e in auth_events.values()
)
+ # different_auth is the set of events which *are* in `event.auth_events`, but
+ # which are *not* in `auth_events`. Comparing with (2.) above, this means
+ # exclusively the set of `event.auth_events` which we already had in our
+ # database before any call to /event_auth.
+ #
+ # I'm reasonably sure that the fact that events returned by /event_auth are
+ # blindly added to auth_events (and hence excluded from different_auth) is a bug
+ # - though it's a very long-standing one (see
+ # https://github.com/matrix-org/synapse/commit/78015948a7febb18e000651f72f8f58830a55b93#diff-0bc92da3d703202f5b9be2d3f845e375f5b1a6bc6ba61705a8af9be1121f5e42R786
+ # from Jan 2015 which seems to add it, though it actually just moves it from
+ # elsewhere (before that, it gets lost in a mess of huge "various bug fixes"
+ # PRs).
+
if not different_auth:
return context, auth_events
diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py
index fe8a9958..9c319b53 100644
--- a/synapse/handlers/identity.py
+++ b/synapse/handlers/identity.py
@@ -39,8 +39,6 @@ from synapse.util.stringutils import (
valid_id_server_location,
)
-from ._base import BaseHandler
-
if TYPE_CHECKING:
from synapse.server import HomeServer
@@ -49,15 +47,14 @@ logger = logging.getLogger(__name__)
id_server_scheme = "https://"
-class IdentityHandler(BaseHandler):
+class IdentityHandler:
def __init__(self, hs: "HomeServer"):
- super().__init__(hs)
-
+ self.store = hs.get_datastore()
# An HTTP client for contacting trusted URLs.
self.http_client = SimpleHttpClient(hs)
# An HTTP client for contacting identity servers specified by clients.
self.blacklisting_http_client = SimpleHttpClient(
- hs, ip_blacklist=hs.config.federation_ip_range_blacklist
+ hs, ip_blacklist=hs.config.server.federation_ip_range_blacklist
)
self.federation_http_client = hs.get_federation_http_client()
self.hs = hs
@@ -573,9 +570,15 @@ class IdentityHandler(BaseHandler):
# Try to validate as email
if self.hs.config.email.threepid_behaviour_email == ThreepidBehaviour.REMOTE:
+ # Remote emails will only be used if a valid identity server is provided.
+ assert (
+ self.hs.config.registration.account_threepid_delegate_email is not None
+ )
+
# Ask our delegated email identity server
validation_session = await self.threepid_from_creds(
- self.hs.config.account_threepid_delegate_email, threepid_creds
+ self.hs.config.registration.account_threepid_delegate_email,
+ threepid_creds,
)
elif self.hs.config.email.threepid_behaviour_email == ThreepidBehaviour.LOCAL:
# Get a validated session matching these details
@@ -587,10 +590,11 @@ class IdentityHandler(BaseHandler):
return validation_session
# Try to validate as msisdn
- if self.hs.config.account_threepid_delegate_msisdn:
+ if self.hs.config.registration.account_threepid_delegate_msisdn:
# Ask our delegated msisdn identity server
validation_session = await self.threepid_from_creds(
- self.hs.config.account_threepid_delegate_msisdn, threepid_creds
+ self.hs.config.registration.account_threepid_delegate_msisdn,
+ threepid_creds,
)
return validation_session
diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py
index 9ad39a65..d4e45561 100644
--- a/synapse/handlers/initial_sync.py
+++ b/synapse/handlers/initial_sync.py
@@ -31,8 +31,6 @@ from synapse.util.async_helpers import concurrently_execute
from synapse.util.caches.response_cache import ResponseCache
from synapse.visibility import filter_events_for_client
-from ._base import BaseHandler
-
if TYPE_CHECKING:
from synapse.server import HomeServer
@@ -40,9 +38,11 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
-class InitialSyncHandler(BaseHandler):
+class InitialSyncHandler:
def __init__(self, hs: "HomeServer"):
- super().__init__(hs)
+ self.store = hs.get_datastore()
+ self.auth = hs.get_auth()
+ self.state_handler = hs.get_state_handler()
self.hs = hs
self.state = hs.get_state_handler()
self.clock = hs.get_clock()
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index fd861e94..4de9f4b8 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -16,6 +16,7 @@
# limitations under the License.
import logging
import random
+from http import HTTPStatus
from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Tuple
from canonicaljson import encode_canonical_json
@@ -39,9 +40,11 @@ from synapse.api.errors import (
NotFoundError,
ShadowBanError,
SynapseError,
+ UnsupportedRoomVersionError,
)
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersions
from synapse.api.urls import ConsentURIBuilder
+from synapse.event_auth import validate_event_for_room_version
from synapse.events import EventBase
from synapse.events.builder import EventBuilder
from synapse.events.snapshot import EventContext
@@ -59,8 +62,6 @@ from synapse.util.caches.expiringcache import ExpiringCache
from synapse.util.metrics import measure_func
from synapse.visibility import filter_events_for_client
-from ._base import BaseHandler
-
if TYPE_CHECKING:
from synapse.events.third_party_rules import ThirdPartyEventRules
from synapse.server import HomeServer
@@ -79,7 +80,7 @@ class MessageHandler:
self.storage = hs.get_storage()
self.state_store = self.storage.state
self._event_serializer = hs.get_event_client_serializer()
- self._ephemeral_events_enabled = hs.config.enable_ephemeral_messages
+ self._ephemeral_events_enabled = hs.config.server.enable_ephemeral_messages
# The scheduled call to self._expire_event. None if no call is currently
# scheduled.
@@ -413,7 +414,9 @@ class EventCreationHandler:
self.server_name = hs.hostname
self.notifier = hs.get_notifier()
self.config = hs.config
- self.require_membership_for_aliases = hs.config.require_membership_for_aliases
+ self.require_membership_for_aliases = (
+ hs.config.server.require_membership_for_aliases
+ )
self._events_shard_config = self.config.worker.events_shard_config
self._instance_name = hs.get_instance_name()
@@ -423,13 +426,12 @@ class EventCreationHandler:
Membership.JOIN,
Membership.KNOCK,
}
- if self.hs.config.include_profile_data_on_invite:
+ if self.hs.config.server.include_profile_data_on_invite:
self.membership_types_to_include_profile_data_in.add(Membership.INVITE)
self.send_event = ReplicationSendEventRestServlet.make_client(hs)
- # This is only used to get at ratelimit function
- self.base_handler = BaseHandler(hs)
+ self.request_ratelimiter = hs.get_request_ratelimiter()
# We arbitrarily limit concurrent event creation for a room to 5.
# This is to stop us from diverging history *too* much.
@@ -459,11 +461,11 @@ class EventCreationHandler:
#
self._rooms_to_exclude_from_dummy_event_insertion: Dict[str, int] = {}
# The number of forward extremeities before a dummy event is sent.
- self._dummy_events_threshold = hs.config.dummy_events_threshold
+ self._dummy_events_threshold = hs.config.server.dummy_events_threshold
if (
self.config.worker.run_background_tasks
- and self.config.cleanup_extremities_with_dummy_events
+ and self.config.server.cleanup_extremities_with_dummy_events
):
self.clock.looping_call(
lambda: run_as_background_process(
@@ -475,7 +477,7 @@ class EventCreationHandler:
self._message_handler = hs.get_message_handler()
- self._ephemeral_events_enabled = hs.config.enable_ephemeral_messages
+ self._ephemeral_events_enabled = hs.config.server.enable_ephemeral_messages
self._external_cache = hs.get_external_cache()
@@ -549,16 +551,22 @@ class EventCreationHandler:
await self.auth.check_auth_blocking(requester=requester)
if event_dict["type"] == EventTypes.Create and event_dict["state_key"] == "":
- room_version = event_dict["content"]["room_version"]
+ room_version_id = event_dict["content"]["room_version"]
+ room_version_obj = KNOWN_ROOM_VERSIONS.get(room_version_id)
+ if not room_version_obj:
+ # this can happen if support is withdrawn for a room version
+ raise UnsupportedRoomVersionError(room_version_id)
else:
try:
- room_version = await self.store.get_room_version_id(
+ room_version_obj = await self.store.get_room_version(
event_dict["room_id"]
)
except NotFoundError:
raise AuthError(403, "Unknown room")
- builder = self.event_builder_factory.new(room_version, event_dict)
+ builder = self.event_builder_factory.for_room_version(
+ room_version_obj, event_dict
+ )
self.validator.validate_builder(builder)
@@ -1064,9 +1072,17 @@ class EventCreationHandler:
EventTypes.Create,
"",
):
- room_version = event.content.get("room_version", RoomVersions.V1.identifier)
+ room_version_id = event.content.get(
+ "room_version", RoomVersions.V1.identifier
+ )
+ room_version_obj = KNOWN_ROOM_VERSIONS.get(room_version_id)
+ if not room_version_obj:
+ raise UnsupportedRoomVersionError(
+ "Attempt to create a room with unsupported room version %s"
+ % (room_version_id,)
+ )
else:
- room_version = await self.store.get_room_version_id(event.room_id)
+ room_version_obj = await self.store.get_room_version(event.room_id)
if event.internal_metadata.is_out_of_band_membership():
# the only sort of out-of-band-membership events we expect to see here are
@@ -1075,8 +1091,9 @@ class EventCreationHandler:
assert event.content["membership"] == Membership.LEAVE
else:
try:
- await self._event_auth_handler.check_from_context(
- room_version, event, context
+ validate_event_for_room_version(room_version_obj, event)
+ await self._event_auth_handler.check_auth_rules_from_context(
+ room_version_obj, event, context
)
except AuthError as err:
logger.warning("Denying new event %r because %s", event, err)
@@ -1302,7 +1319,7 @@ class EventCreationHandler:
original_event and event.sender != original_event.sender
)
- await self.base_handler.ratelimit(
+ await self.request_ratelimiter.ratelimit(
requester, is_admin_redaction=is_admin_redaction
)
@@ -1456,6 +1473,39 @@ class EventCreationHandler:
if prev_state_ids:
raise AuthError(403, "Changing the room create event is forbidden")
+ if event.type == EventTypes.MSC2716_INSERTION:
+ room_version = await self.store.get_room_version_id(event.room_id)
+ room_version_obj = KNOWN_ROOM_VERSIONS[room_version]
+
+ create_event = await self.store.get_create_event_for_room(event.room_id)
+ room_creator = create_event.content.get(EventContentFields.ROOM_CREATOR)
+
+ # Only check an insertion event if the room version
+ # supports it or the event is from the room creator.
+ if room_version_obj.msc2716_historical or (
+ self.config.experimental.msc2716_enabled
+ and event.sender == room_creator
+ ):
+ next_batch_id = event.content.get(
+ EventContentFields.MSC2716_NEXT_BATCH_ID
+ )
+ conflicting_insertion_event_id = (
+ await self.store.get_insertion_event_by_batch_id(
+ event.room_id, next_batch_id
+ )
+ )
+ if conflicting_insertion_event_id is not None:
+ # The current insertion event that we're processing is invalid
+ # because an insertion event already exists in the room with the
+ # same next_batch_id. We can't allow multiple because the batch
+ # pointing will get weird, e.g. we can't determine which insertion
+ # event the batch event is pointing to.
+ raise SynapseError(
+ HTTPStatus.BAD_REQUEST,
+ "Another insertion event already exists with the same next_batch_id",
+ errcode=Codes.INVALID_PARAM,
+ )
+
# Mark any `m.historical` messages as backfilled so they don't appear
# in `/sync` and have the proper decrementing `stream_ordering` as we import
backfilled = False
diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py
index 08b93b3e..176e4dfd 100644
--- a/synapse/handlers/pagination.py
+++ b/synapse/handlers/pagination.py
@@ -85,23 +85,29 @@ class PaginationHandler:
self._purges_by_id: Dict[str, PurgeStatus] = {}
self._event_serializer = hs.get_event_client_serializer()
- self._retention_default_max_lifetime = hs.config.retention_default_max_lifetime
+ self._retention_default_max_lifetime = (
+ hs.config.server.retention_default_max_lifetime
+ )
- self._retention_allowed_lifetime_min = hs.config.retention_allowed_lifetime_min
- self._retention_allowed_lifetime_max = hs.config.retention_allowed_lifetime_max
+ self._retention_allowed_lifetime_min = (
+ hs.config.server.retention_allowed_lifetime_min
+ )
+ self._retention_allowed_lifetime_max = (
+ hs.config.server.retention_allowed_lifetime_max
+ )
- if hs.config.worker.run_background_tasks and hs.config.retention_enabled:
+ if hs.config.worker.run_background_tasks and hs.config.server.retention_enabled:
# Run the purge jobs described in the configuration file.
- for job in hs.config.retention_purge_jobs:
+ for job in hs.config.server.retention_purge_jobs:
logger.info("Setting up purge job with config: %s", job)
self.clock.looping_call(
run_as_background_process,
- job["interval"],
+ job.interval,
"purge_history_for_rooms_in_range",
self.purge_history_for_rooms_in_range,
- job["shortest_max_lifetime"],
- job["longest_max_lifetime"],
+ job.shortest_max_lifetime,
+ job.longest_max_lifetime,
)
async def purge_history_for_rooms_in_range(
diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py
index b23a1541..e6c3cf58 100644
--- a/synapse/handlers/profile.py
+++ b/synapse/handlers/profile.py
@@ -32,8 +32,6 @@ from synapse.types import (
get_domain_from_id,
)
-from ._base import BaseHandler
-
if TYPE_CHECKING:
from synapse.server import HomeServer
@@ -43,7 +41,7 @@ MAX_DISPLAYNAME_LEN = 256
MAX_AVATAR_URL_LEN = 1000
-class ProfileHandler(BaseHandler):
+class ProfileHandler:
"""Handles fetching and updating user profile information.
ProfileHandler can be instantiated directly on workers and will
@@ -54,7 +52,9 @@ class ProfileHandler(BaseHandler):
PROFILE_UPDATE_EVERY_MS = 24 * 60 * 60 * 1000
def __init__(self, hs: "HomeServer"):
- super().__init__(hs)
+ self.store = hs.get_datastore()
+ self.clock = hs.get_clock()
+ self.hs = hs
self.federation = hs.get_federation_client()
hs.get_federation_registry().register_query_handler(
@@ -62,6 +62,7 @@ class ProfileHandler(BaseHandler):
)
self.user_directory_handler = hs.get_user_directory_handler()
+ self.request_ratelimiter = hs.get_request_ratelimiter()
if hs.config.worker.run_background_tasks:
self.clock.looping_call(
@@ -178,7 +179,7 @@ class ProfileHandler(BaseHandler):
if not by_admin and target_user != requester.user:
raise AuthError(400, "Cannot set another user's displayname")
- if not by_admin and not self.hs.config.enable_set_displayname:
+ if not by_admin and not self.hs.config.registration.enable_set_displayname:
profile = await self.store.get_profileinfo(target_user.localpart)
if profile.display_name:
raise SynapseError(
@@ -268,7 +269,7 @@ class ProfileHandler(BaseHandler):
if not by_admin and target_user != requester.user:
raise AuthError(400, "Cannot set another user's avatar_url")
- if not by_admin and not self.hs.config.enable_set_avatar_url:
+ if not by_admin and not self.hs.config.registration.enable_set_avatar_url:
profile = await self.store.get_profileinfo(target_user.localpart)
if profile.avatar_url:
raise SynapseError(
@@ -346,7 +347,7 @@ class ProfileHandler(BaseHandler):
if not self.hs.is_mine(target_user):
return
- await self.ratelimit(requester)
+ await self.request_ratelimiter.ratelimit(requester)
# Do not actually update the room state for shadow-banned users.
if requester.shadow_banned:
@@ -397,7 +398,7 @@ class ProfileHandler(BaseHandler):
# when building a membership event. In this case, we must allow the
# lookup.
if (
- not self.hs.config.limit_profile_requests_to_users_who_share_rooms
+ not self.hs.config.server.limit_profile_requests_to_users_who_share_rooms
or not requester
):
return
diff --git a/synapse/handlers/read_marker.py b/synapse/handlers/read_marker.py
index bd8160e7..58593e57 100644
--- a/synapse/handlers/read_marker.py
+++ b/synapse/handlers/read_marker.py
@@ -17,17 +17,14 @@ from typing import TYPE_CHECKING
from synapse.util.async_helpers import Linearizer
-from ._base import BaseHandler
-
if TYPE_CHECKING:
from synapse.server import HomeServer
logger = logging.getLogger(__name__)
-class ReadMarkerHandler(BaseHandler):
+class ReadMarkerHandler:
def __init__(self, hs: "HomeServer"):
- super().__init__(hs)
self.server_name = hs.config.server.server_name
self.store = hs.get_datastore()
self.account_data_handler = hs.get_account_data_handler()
diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py
index f21f33ad..374e961e 100644
--- a/synapse/handlers/receipts.py
+++ b/synapse/handlers/receipts.py
@@ -16,7 +16,6 @@ from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple
from synapse.api.constants import ReadReceiptEventFields
from synapse.appservice import ApplicationService
-from synapse.handlers._base import BaseHandler
from synapse.streams import EventSource
from synapse.types import JsonDict, ReadReceipt, UserID, get_domain_from_id
@@ -26,10 +25,9 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
-class ReceiptsHandler(BaseHandler):
+class ReceiptsHandler:
def __init__(self, hs: "HomeServer"):
- super().__init__(hs)
-
+ self.notifier = hs.get_notifier()
self.server_name = hs.config.server.server_name
self.store = hs.get_datastore()
self.event_auth_handler = hs.get_event_auth_handler()
diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py
index 4f99f137..a0e6a017 100644
--- a/synapse/handlers/register.py
+++ b/synapse/handlers/register.py
@@ -41,8 +41,6 @@ from synapse.spam_checker_api import RegistrationBehaviour
from synapse.storage.state import StateFilter
from synapse.types import RoomAlias, UserID, create_requester
-from ._base import BaseHandler
-
if TYPE_CHECKING:
from synapse.server import HomeServer
@@ -85,9 +83,10 @@ class LoginDict(TypedDict):
refresh_token: Optional[str]
-class RegistrationHandler(BaseHandler):
+class RegistrationHandler:
def __init__(self, hs: "HomeServer"):
- super().__init__(hs)
+ self.store = hs.get_datastore()
+ self.clock = hs.get_clock()
self.hs = hs
self.auth = hs.get_auth()
self._auth_handler = hs.get_auth_handler()
@@ -116,8 +115,8 @@ class RegistrationHandler(BaseHandler):
self._register_device_client = self.register_device_inner
self.pusher_pool = hs.get_pusherpool()
- self.session_lifetime = hs.config.session_lifetime
- self.access_token_lifetime = hs.config.access_token_lifetime
+ self.session_lifetime = hs.config.registration.session_lifetime
+ self.access_token_lifetime = hs.config.registration.access_token_lifetime
init_counters_for_auth_provider("")
@@ -340,8 +339,13 @@ class RegistrationHandler(BaseHandler):
auth_provider=(auth_provider_id or ""),
).inc()
+ # If the user does not need to consent at registration, auto-join any
+ # configured rooms.
if not self.hs.config.consent.user_consent_at_registration:
- if not self.hs.config.auto_join_rooms_for_guests and make_guest:
+ if (
+ not self.hs.config.registration.auto_join_rooms_for_guests
+ and make_guest
+ ):
logger.info(
"Skipping auto-join for %s because auto-join for guests is disabled",
user_id,
@@ -387,7 +391,7 @@ class RegistrationHandler(BaseHandler):
"preset": self.hs.config.registration.autocreate_auto_join_room_preset,
}
- # If the configuration providers a user ID to create rooms with, use
+ # If the configuration provides a user ID to create rooms with, use
# that instead of the first user registered.
requires_join = False
if self.hs.config.registration.auto_join_user_id:
@@ -510,7 +514,7 @@ class RegistrationHandler(BaseHandler):
# we don't have a local user in the room to craft up an invite with.
requires_invite = await self.store.is_host_joined(
room_id,
- self.server_name,
+ self._server_name,
)
if requires_invite:
@@ -854,7 +858,7 @@ class RegistrationHandler(BaseHandler):
# Necessary due to auth checks prior to the threepid being
# written to the db
if is_threepid_reserved(
- self.hs.config.mau_limits_reserved_threepids, threepid
+ self.hs.config.server.mau_limits_reserved_threepids, threepid
):
await self.store.upsert_monthly_active_user(user_id)
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index 8fede5e9..7072bca1 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -52,6 +52,7 @@ from synapse.api.errors import (
)
from synapse.api.filtering import Filter
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion
+from synapse.event_auth import validate_event_for_room_version
from synapse.events import EventBase
from synapse.events.utils import copy_power_levels_contents
from synapse.rest.admin._base import assert_user_is_admin
@@ -75,8 +76,6 @@ from synapse.util.caches.response_cache import ResponseCache
from synapse.util.stringutils import parse_and_validate_server_name
from synapse.visibility import filter_events_for_client
-from ._base import BaseHandler
-
if TYPE_CHECKING:
from synapse.server import HomeServer
@@ -87,15 +86,18 @@ id_server_scheme = "https://"
FIVE_MINUTES_IN_MS = 5 * 60 * 1000
-class RoomCreationHandler(BaseHandler):
+class RoomCreationHandler:
def __init__(self, hs: "HomeServer"):
- super().__init__(hs)
-
+ self.store = hs.get_datastore()
+ self.auth = hs.get_auth()
+ self.clock = hs.get_clock()
+ self.hs = hs
self.spam_checker = hs.get_spam_checker()
self.event_creation_handler = hs.get_event_creation_handler()
self.room_member_handler = hs.get_room_member_handler()
self._event_auth_handler = hs.get_event_auth_handler()
self.config = hs.config
+ self.request_ratelimiter = hs.get_request_ratelimiter()
# Room state based off defined presets
self._presets_dict: Dict[str, Dict[str, Any]] = {
@@ -161,7 +163,7 @@ class RoomCreationHandler(BaseHandler):
Raises:
ShadowBanError if the requester is shadow-banned.
"""
- await self.ratelimit(requester)
+ await self.request_ratelimiter.ratelimit(requester)
user_id = requester.user.to_string()
@@ -237,8 +239,9 @@ class RoomCreationHandler(BaseHandler):
},
},
)
- old_room_version = await self.store.get_room_version_id(old_room_id)
- await self._event_auth_handler.check_from_context(
+ old_room_version = await self.store.get_room_version(old_room_id)
+ validate_event_for_room_version(old_room_version, tombstone_event)
+ await self._event_auth_handler.check_auth_rules_from_context(
old_room_version, tombstone_event, tombstone_context
)
@@ -663,10 +666,10 @@ class RoomCreationHandler(BaseHandler):
raise SynapseError(403, "You are not permitted to create rooms")
if ratelimit:
- await self.ratelimit(requester)
+ await self.request_ratelimiter.ratelimit(requester)
room_version_id = config.get(
- "room_version", self.config.default_room_version.identifier
+ "room_version", self.config.server.default_room_version.identifier
)
if not isinstance(room_version_id, str):
@@ -858,6 +861,7 @@ class RoomCreationHandler(BaseHandler):
"invite",
ratelimit=False,
content=content,
+ new_room=True,
)
for invite_3pid in invite_3pid_list:
@@ -960,6 +964,7 @@ class RoomCreationHandler(BaseHandler):
"join",
ratelimit=ratelimit,
content=creator_join_profile,
+ new_room=True,
)
# We treat the power levels override specially as this needs to be one
diff --git a/synapse/handlers/room_batch.py b/synapse/handlers/room_batch.py
new file mode 100644
index 00000000..51dd4e75
--- /dev/null
+++ b/synapse/handlers/room_batch.py
@@ -0,0 +1,423 @@
+import logging
+from typing import TYPE_CHECKING, List, Tuple
+
+from synapse.api.constants import EventContentFields, EventTypes
+from synapse.appservice import ApplicationService
+from synapse.http.servlet import assert_params_in_dict
+from synapse.types import JsonDict, Requester, UserID, create_requester
+from synapse.util.stringutils import random_string
+
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
+logger = logging.getLogger(__name__)
+
+
+class RoomBatchHandler:
+ def __init__(self, hs: "HomeServer"):
+ self.hs = hs
+ self.store = hs.get_datastore()
+ self.state_store = hs.get_storage().state
+ self.event_creation_handler = hs.get_event_creation_handler()
+ self.room_member_handler = hs.get_room_member_handler()
+ self.auth = hs.get_auth()
+
+ async def inherit_depth_from_prev_ids(self, prev_event_ids: List[str]) -> int:
+ """Finds the depth which would sort it after the most-recent
+ prev_event_id but before the successors of those events. If no
+ successors are found, we assume it's an historical extremity part of the
+ current batch and use the same depth of the prev_event_ids.
+
+ Args:
+ prev_event_ids: List of prev event IDs
+
+ Returns:
+ Inherited depth
+ """
+ (
+ most_recent_prev_event_id,
+ most_recent_prev_event_depth,
+ ) = await self.store.get_max_depth_of(prev_event_ids)
+
+ # We want to insert the historical event after the `prev_event` but before the successor event
+ #
+ # We inherit depth from the successor event instead of the `prev_event`
+ # because events returned from `/messages` are first sorted by `topological_ordering`
+ # which is just the `depth` and then tie-break with `stream_ordering`.
+ #
+ # We mark these inserted historical events as "backfilled" which gives them a
+ # negative `stream_ordering`. If we use the same depth as the `prev_event`,
+ # then our historical event will tie-break and be sorted before the `prev_event`
+ # when it should come after.
+ #
+ # We want to use the successor event depth so they appear after `prev_event` because
+ # it has a larger `depth` but before the successor event because the `stream_ordering`
+ # is negative before the successor event.
+ successor_event_ids = await self.store.get_successor_events(
+ [most_recent_prev_event_id]
+ )
+
+ # If we can't find any successor events, then it's a forward extremity of
+ # historical messages and we can just inherit from the previous historical
+ # event which we can already assume has the correct depth where we want
+ # to insert into.
+ if not successor_event_ids:
+ depth = most_recent_prev_event_depth
+ else:
+ (
+ _,
+ oldest_successor_depth,
+ ) = await self.store.get_min_depth_of(successor_event_ids)
+
+ depth = oldest_successor_depth
+
+ return depth
+
+ def create_insertion_event_dict(
+ self, sender: str, room_id: str, origin_server_ts: int
+ ) -> JsonDict:
+ """Creates an event dict for an "insertion" event with the proper fields
+ and a random batch ID.
+
+ Args:
+ sender: The event author MXID
+ room_id: The room ID that the event belongs to
+ origin_server_ts: Timestamp when the event was sent
+
+ Returns:
+ The new event dictionary to insert.
+ """
+
+ next_batch_id = random_string(8)
+ insertion_event = {
+ "type": EventTypes.MSC2716_INSERTION,
+ "sender": sender,
+ "room_id": room_id,
+ "content": {
+ EventContentFields.MSC2716_NEXT_BATCH_ID: next_batch_id,
+ EventContentFields.MSC2716_HISTORICAL: True,
+ },
+ "origin_server_ts": origin_server_ts,
+ }
+
+ return insertion_event
+
+ async def create_requester_for_user_id_from_app_service(
+ self, user_id: str, app_service: ApplicationService
+ ) -> Requester:
+ """Creates a new requester for the given user_id
+ and validates that the app service is allowed to control
+ the given user.
+
+ Args:
+ user_id: The author MXID that the app service is controlling
+ app_service: The app service that controls the user
+
+ Returns:
+ Requester object
+ """
+
+ await self.auth.validate_appservice_can_control_user_id(app_service, user_id)
+
+ return create_requester(user_id, app_service=app_service)
+
+ async def get_most_recent_auth_event_ids_from_event_id_list(
+ self, event_ids: List[str]
+ ) -> List[str]:
+ """Find the most recent auth event ids (derived from state events) that
+ allowed that message to be sent. We will use this as a base
+ to auth our historical messages against.
+
+ Args:
+ event_ids: List of event ID's to look at
+
+ Returns:
+ List of event ID's
+ """
+
+ (
+ most_recent_prev_event_id,
+ _,
+ ) = await self.store.get_max_depth_of(event_ids)
+ # mapping from (type, state_key) -> state_event_id
+ prev_state_map = await self.state_store.get_state_ids_for_event(
+ most_recent_prev_event_id
+ )
+ # List of state event ID's
+ prev_state_ids = list(prev_state_map.values())
+ auth_event_ids = prev_state_ids
+
+ return auth_event_ids
+
+ async def persist_state_events_at_start(
+ self,
+ state_events_at_start: List[JsonDict],
+ room_id: str,
+ initial_auth_event_ids: List[str],
+ app_service_requester: Requester,
+ ) -> List[str]:
+ """Takes all `state_events_at_start` event dictionaries and creates/persists
+ them as floating state events which don't resolve into the current room state.
+ They are floating because they reference a fake prev_event which doesn't connect
+ to the normal DAG at all.
+
+ Args:
+ state_events_at_start:
+ room_id: Room where you want the events persisted in.
+ initial_auth_event_ids: These will be the auth_events for the first
+ state event created. Each event created afterwards will be
+ added to the list of auth events for the next state event
+ created.
+ app_service_requester: The requester of an application service.
+
+ Returns:
+ List of state event ID's we just persisted
+ """
+ assert app_service_requester.app_service
+
+ state_event_ids_at_start = []
+ auth_event_ids = initial_auth_event_ids.copy()
+ for state_event in state_events_at_start:
+ assert_params_in_dict(
+ state_event, ["type", "origin_server_ts", "content", "sender"]
+ )
+
+ logger.debug(
+ "RoomBatchSendEventRestServlet inserting state_event=%s, auth_event_ids=%s",
+ state_event,
+ auth_event_ids,
+ )
+
+ event_dict = {
+ "type": state_event["type"],
+ "origin_server_ts": state_event["origin_server_ts"],
+ "content": state_event["content"],
+ "room_id": room_id,
+ "sender": state_event["sender"],
+ "state_key": state_event["state_key"],
+ }
+
+ # Mark all events as historical
+ event_dict["content"][EventContentFields.MSC2716_HISTORICAL] = True
+
+ # Make the state events float off on their own so we don't have a
+ # bunch of `@mxid joined the room` noise between each batch
+ fake_prev_event_id = "$" + random_string(43)
+
+ # TODO: This is pretty much the same as some other code to handle inserting state in this file
+ if event_dict["type"] == EventTypes.Member:
+ membership = event_dict["content"].get("membership", None)
+ event_id, _ = await self.room_member_handler.update_membership(
+ await self.create_requester_for_user_id_from_app_service(
+ state_event["sender"], app_service_requester.app_service
+ ),
+ target=UserID.from_string(event_dict["state_key"]),
+ room_id=room_id,
+ action=membership,
+ content=event_dict["content"],
+ outlier=True,
+ prev_event_ids=[fake_prev_event_id],
+ # Make sure to use a copy of this list because we modify it
+ # later in the loop here. Otherwise it will be the same
+ # reference and also update in the event when we append later.
+ auth_event_ids=auth_event_ids.copy(),
+ )
+ else:
+ # TODO: Add some complement tests that adds state that is not member joins
+ # and will use this code path. Maybe we only want to support join state events
+ # and can get rid of this `else`?
+ (
+ event,
+ _,
+ ) = await self.event_creation_handler.create_and_send_nonmember_event(
+ await self.create_requester_for_user_id_from_app_service(
+ state_event["sender"], app_service_requester.app_service
+ ),
+ event_dict,
+ outlier=True,
+ prev_event_ids=[fake_prev_event_id],
+ # Make sure to use a copy of this list because we modify it
+ # later in the loop here. Otherwise it will be the same
+ # reference and also update in the event when we append later.
+ auth_event_ids=auth_event_ids.copy(),
+ )
+ event_id = event.event_id
+
+ state_event_ids_at_start.append(event_id)
+ auth_event_ids.append(event_id)
+
+ return state_event_ids_at_start
+
+ async def persist_historical_events(
+ self,
+ events_to_create: List[JsonDict],
+ room_id: str,
+ initial_prev_event_ids: List[str],
+ inherited_depth: int,
+ auth_event_ids: List[str],
+ app_service_requester: Requester,
+ ) -> List[str]:
+ """Create and persists all events provided sequentially. Handles the
+ complexity of creating events in chronological order so they can
+ reference each other by prev_event but still persists in
+ reverse-chronoloical order so they have the correct
+ (topological_ordering, stream_ordering) and sort correctly from
+ /messages.
+
+ Args:
+ events_to_create: List of historical events to create in JSON
+ dictionary format.
+ room_id: Room where you want the events persisted in.
+ initial_prev_event_ids: These will be the prev_events for the first
+ event created. Each event created afterwards will point to the
+ previous event created.
+ inherited_depth: The depth to create the events at (you will
+ probably by calling inherit_depth_from_prev_ids(...)).
+ auth_event_ids: Define which events allow you to create the given
+ event in the room.
+ app_service_requester: The requester of an application service.
+
+ Returns:
+ List of persisted event IDs
+ """
+ assert app_service_requester.app_service
+
+ prev_event_ids = initial_prev_event_ids.copy()
+
+ event_ids = []
+ events_to_persist = []
+ for ev in events_to_create:
+ assert_params_in_dict(ev, ["type", "origin_server_ts", "content", "sender"])
+
+ event_dict = {
+ "type": ev["type"],
+ "origin_server_ts": ev["origin_server_ts"],
+ "content": ev["content"],
+ "room_id": room_id,
+ "sender": ev["sender"], # requester.user.to_string(),
+ "prev_events": prev_event_ids.copy(),
+ }
+
+ # Mark all events as historical
+ event_dict["content"][EventContentFields.MSC2716_HISTORICAL] = True
+
+ event, context = await self.event_creation_handler.create_event(
+ await self.create_requester_for_user_id_from_app_service(
+ ev["sender"], app_service_requester.app_service
+ ),
+ event_dict,
+ prev_event_ids=event_dict.get("prev_events"),
+ auth_event_ids=auth_event_ids,
+ historical=True,
+ depth=inherited_depth,
+ )
+ logger.debug(
+ "RoomBatchSendEventRestServlet inserting event=%s, prev_event_ids=%s, auth_event_ids=%s",
+ event,
+ prev_event_ids,
+ auth_event_ids,
+ )
+
+ assert self.hs.is_mine_id(event.sender), "User must be our own: %s" % (
+ event.sender,
+ )
+
+ events_to_persist.append((event, context))
+ event_id = event.event_id
+
+ event_ids.append(event_id)
+ prev_event_ids = [event_id]
+
+ # Persist events in reverse-chronological order so they have the
+ # correct stream_ordering as they are backfilled (which decrements).
+ # Events are sorted by (topological_ordering, stream_ordering)
+ # where topological_ordering is just depth.
+ for (event, context) in reversed(events_to_persist):
+ await self.event_creation_handler.handle_new_client_event(
+ await self.create_requester_for_user_id_from_app_service(
+ event["sender"], app_service_requester.app_service
+ ),
+ event=event,
+ context=context,
+ )
+
+ return event_ids
+
+ async def handle_batch_of_events(
+ self,
+ events_to_create: List[JsonDict],
+ room_id: str,
+ batch_id_to_connect_to: str,
+ initial_prev_event_ids: List[str],
+ inherited_depth: int,
+ auth_event_ids: List[str],
+ app_service_requester: Requester,
+ ) -> Tuple[List[str], str]:
+ """
+ Handles creating and persisting all of the historical events as well
+ as insertion and batch meta events to make the batch navigable in the DAG.
+
+ Args:
+ events_to_create: List of historical events to create in JSON
+ dictionary format.
+ room_id: Room where you want the events created in.
+ batch_id_to_connect_to: The batch_id from the insertion event you
+ want this batch to connect to.
+ initial_prev_event_ids: These will be the prev_events for the first
+ event created. Each event created afterwards will point to the
+ previous event created.
+ inherited_depth: The depth to create the events at (you will
+ probably by calling inherit_depth_from_prev_ids(...)).
+ auth_event_ids: Define which events allow you to create the given
+ event in the room.
+ app_service_requester: The requester of an application service.
+
+ Returns:
+ Tuple containing a list of created events and the next_batch_id
+ """
+
+ # Connect this current batch to the insertion event from the previous batch
+ last_event_in_batch = events_to_create[-1]
+ batch_event = {
+ "type": EventTypes.MSC2716_BATCH,
+ "sender": app_service_requester.user.to_string(),
+ "room_id": room_id,
+ "content": {
+ EventContentFields.MSC2716_BATCH_ID: batch_id_to_connect_to,
+ EventContentFields.MSC2716_HISTORICAL: True,
+ },
+ # Since the batch event is put at the end of the batch,
+ # where the newest-in-time event is, copy the origin_server_ts from
+ # the last event we're inserting
+ "origin_server_ts": last_event_in_batch["origin_server_ts"],
+ }
+ # Add the batch event to the end of the batch (newest-in-time)
+ events_to_create.append(batch_event)
+
+ # Add an "insertion" event to the start of each batch (next to the oldest-in-time
+ # event in the batch) so the next batch can be connected to this one.
+ insertion_event = self.create_insertion_event_dict(
+ sender=app_service_requester.user.to_string(),
+ room_id=room_id,
+ # Since the insertion event is put at the start of the batch,
+ # where the oldest-in-time event is, copy the origin_server_ts from
+ # the first event we're inserting
+ origin_server_ts=events_to_create[0]["origin_server_ts"],
+ )
+ next_batch_id = insertion_event["content"][
+ EventContentFields.MSC2716_NEXT_BATCH_ID
+ ]
+ # Prepend the insertion event to the start of the batch (oldest-in-time)
+ events_to_create = [insertion_event] + events_to_create
+
+ # Create and persist all of the historical events
+ event_ids = await self.persist_historical_events(
+ events_to_create=events_to_create,
+ room_id=room_id,
+ initial_prev_event_ids=initial_prev_event_ids,
+ inherited_depth=inherited_depth,
+ auth_event_ids=auth_event_ids,
+ app_service_requester=app_service_requester,
+ )
+
+ return event_ids, next_batch_id
diff --git a/synapse/handlers/room_list.py b/synapse/handlers/room_list.py
index c3d4199e..ba7a14d6 100644
--- a/synapse/handlers/room_list.py
+++ b/synapse/handlers/room_list.py
@@ -36,8 +36,6 @@ from synapse.types import JsonDict, ThirdPartyInstanceID
from synapse.util.caches.descriptors import _CacheContext, cached
from synapse.util.caches.response_cache import ResponseCache
-from ._base import BaseHandler
-
if TYPE_CHECKING:
from synapse.server import HomeServer
@@ -49,9 +47,10 @@ REMOTE_ROOM_LIST_POLL_INTERVAL = 60 * 1000
EMPTY_THIRD_PARTY_ID = ThirdPartyInstanceID(None, None)
-class RoomListHandler(BaseHandler):
+class RoomListHandler:
def __init__(self, hs: "HomeServer"):
- super().__init__(hs)
+ self.store = hs.get_datastore()
+ self.hs = hs
self.enable_room_list_search = hs.config.roomdirectory.enable_room_list_search
self.response_cache: ResponseCache[
Tuple[Optional[int], Optional[str], Optional[ThirdPartyInstanceID]]
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index afa7e472..74e6c7ec 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -51,8 +51,6 @@ from synapse.types import (
from synapse.util.async_helpers import Linearizer
from synapse.util.distributor import user_left_room
-from ._base import BaseHandler
-
if TYPE_CHECKING:
from synapse.server import HomeServer
@@ -89,8 +87,8 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
self.spam_checker = hs.get_spam_checker()
self.third_party_event_rules = hs.get_third_party_event_rules()
self._server_notices_mxid = self.config.servernotices.server_notices_mxid
- self._enable_lookup = hs.config.enable_3pid_lookup
- self.allow_per_room_profiles = self.config.allow_per_room_profiles
+ self._enable_lookup = hs.config.registration.enable_3pid_lookup
+ self.allow_per_room_profiles = self.config.server.allow_per_room_profiles
self._join_rate_limiter_local = Ratelimiter(
store=self.store,
@@ -118,9 +116,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
burst_count=hs.config.ratelimiting.rc_invites_per_user.burst_count,
)
- # This is only used to get at the ratelimit function. It's fine there are
- # multiple of these as it doesn't store state.
- self.base_handler = BaseHandler(hs)
+ self.request_ratelimiter = hs.get_request_ratelimiter()
@abc.abstractmethod
async def _remote_join(
@@ -434,6 +430,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
third_party_signed: Optional[dict] = None,
ratelimit: bool = True,
content: Optional[dict] = None,
+ new_room: bool = False,
require_consent: bool = True,
outlier: bool = False,
prev_event_ids: Optional[List[str]] = None,
@@ -451,6 +448,8 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
third_party_signed: Information from a 3PID invite.
ratelimit: Whether to rate limit the request.
content: The content of the created event.
+ new_room: Whether the membership update is happening in the context of a room
+ creation.
require_consent: Whether consent is required.
outlier: Indicates whether the event is an `outlier`, i.e. if
it's from an arbitrary point and floating in the DAG as
@@ -485,6 +484,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
third_party_signed=third_party_signed,
ratelimit=ratelimit,
content=content,
+ new_room=new_room,
require_consent=require_consent,
outlier=outlier,
prev_event_ids=prev_event_ids,
@@ -504,6 +504,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
third_party_signed: Optional[dict] = None,
ratelimit: bool = True,
content: Optional[dict] = None,
+ new_room: bool = False,
require_consent: bool = True,
outlier: bool = False,
prev_event_ids: Optional[List[str]] = None,
@@ -523,6 +524,8 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
third_party_signed:
ratelimit:
content:
+ new_room: Whether the membership update is happening in the context of a room
+ creation.
require_consent:
outlier: Indicates whether the event is an `outlier`, i.e. if
it's from an arbitrary point and floating in the DAG as
@@ -625,7 +628,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
is_requester_admin = await self.auth.is_server_admin(requester.user)
if not is_requester_admin:
- if self.config.block_non_admin_invites:
+ if self.config.server.block_non_admin_invites:
logger.info(
"Blocking invite: user is not admin and non-admin "
"invites disabled"
@@ -726,6 +729,30 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
# so don't really fit into the general auth process.
raise AuthError(403, "Guest access not allowed")
+ # Figure out whether the user is a server admin to determine whether they
+ # should be able to bypass the spam checker.
+ if (
+ self._server_notices_mxid is not None
+ and requester.user.to_string() == self._server_notices_mxid
+ ):
+ # allow the server notices mxid to join rooms
+ bypass_spam_checker = True
+
+ else:
+ bypass_spam_checker = await self.auth.is_server_admin(requester.user)
+
+ inviter = await self._get_inviter(target.to_string(), room_id)
+ if (
+ not bypass_spam_checker
+ # We assume that if the spam checker allowed the user to create
+ # a room then they're allowed to join it.
+ and not new_room
+ and not await self.spam_checker.user_may_join_room(
+ target.to_string(), room_id, is_invited=inviter is not None
+ )
+ ):
+ raise SynapseError(403, "Not allowed to join this room")
+
# Check if a remote join should be performed.
remote_join, remote_room_hosts = await self._should_perform_remote_join(
target.to_string(), room_id, remote_room_hosts, content, is_host_in_room
@@ -1230,7 +1257,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
Raises:
ShadowBanError if the requester has been shadow-banned.
"""
- if self.config.block_non_admin_invites:
+ if self.config.server.block_non_admin_invites:
is_requester_admin = await self.auth.is_server_admin(requester.user)
if not is_requester_admin:
raise SynapseError(
@@ -1244,7 +1271,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
# We need to rate limit *before* we send out any 3PID invites, so we
# can't just rely on the standard ratelimiting of events.
- await self.base_handler.ratelimit(requester)
+ await self.request_ratelimiter.ratelimit(requester)
can_invite = await self.third_party_event_rules.check_threepid_can_be_invited(
medium, address, room_id
@@ -1268,10 +1295,22 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
if invitee:
# Note that update_membership with an action of "invite" can raise
# a ShadowBanError, but this was done above already.
+ # We don't check the invite against the spamchecker(s) here (through
+ # user_may_invite) because we'll do it further down the line anyway (in
+ # update_membership_locked).
_, stream_id = await self.update_membership(
requester, UserID.from_string(invitee), room_id, "invite", txn_id=txn_id
)
else:
+ # Check if the spamchecker(s) allow this invite to go through.
+ if not await self.spam_checker.user_may_send_3pid_invite(
+ inviter_userid=requester.user.to_string(),
+ medium=medium,
+ address=address,
+ room_id=room_id,
+ ):
+ raise SynapseError(403, "Cannot send threepid invite")
+
stream_id = await self._make_and_store_3pid_invite(
requester,
id_server,
@@ -1428,7 +1467,7 @@ class RoomMemberMasterHandler(RoomMemberHandler):
Returns: bool of whether the complexity is too great, or None
if unable to be fetched
"""
- max_complexity = self.hs.config.limit_remote_rooms.complexity
+ max_complexity = self.hs.config.server.limit_remote_rooms.complexity
complexity = await self.federation_handler.get_room_complexity(
remote_room_hosts, room_id
)
@@ -1444,7 +1483,7 @@ class RoomMemberMasterHandler(RoomMemberHandler):
Args:
room_id: The room ID to check for complexity.
"""
- max_complexity = self.hs.config.limit_remote_rooms.complexity
+ max_complexity = self.hs.config.server.limit_remote_rooms.complexity
complexity = await self.store.get_room_complexity(room_id)
return complexity["v1"] > max_complexity
@@ -1468,8 +1507,11 @@ class RoomMemberMasterHandler(RoomMemberHandler):
if len(remote_room_hosts) == 0:
raise SynapseError(404, "No known servers")
- check_complexity = self.hs.config.limit_remote_rooms.enabled
- if check_complexity and self.hs.config.limit_remote_rooms.admins_can_join:
+ check_complexity = self.hs.config.server.limit_remote_rooms.enabled
+ if (
+ check_complexity
+ and self.hs.config.server.limit_remote_rooms.admins_can_join
+ ):
check_complexity = not await self.auth.is_server_admin(user)
if check_complexity:
@@ -1480,7 +1522,7 @@ class RoomMemberMasterHandler(RoomMemberHandler):
if too_complex is True:
raise SynapseError(
code=400,
- msg=self.hs.config.limit_remote_rooms.complexity_error,
+ msg=self.hs.config.server.limit_remote_rooms.complexity_error,
errcode=Codes.RESOURCE_LIMIT_EXCEEDED,
)
@@ -1515,7 +1557,7 @@ class RoomMemberMasterHandler(RoomMemberHandler):
)
raise SynapseError(
code=400,
- msg=self.hs.config.limit_remote_rooms.complexity_error,
+ msg=self.hs.config.server.limit_remote_rooms.complexity_error,
errcode=Codes.RESOURCE_LIMIT_EXCEEDED,
)
diff --git a/synapse/handlers/saml.py b/synapse/handlers/saml.py
index 2fed9f37..727d75a5 100644
--- a/synapse/handlers/saml.py
+++ b/synapse/handlers/saml.py
@@ -22,7 +22,6 @@ from saml2.client import Saml2Client
from synapse.api.errors import SynapseError
from synapse.config import ConfigError
-from synapse.handlers._base import BaseHandler
from synapse.handlers.sso import MappingException, UserAttributes
from synapse.http.servlet import parse_string
from synapse.http.site import SynapseRequest
@@ -51,9 +50,11 @@ class Saml2SessionData:
ui_auth_session_id: Optional[str] = None
-class SamlHandler(BaseHandler):
+class SamlHandler:
def __init__(self, hs: "HomeServer"):
- super().__init__(hs)
+ self.store = hs.get_datastore()
+ self.clock = hs.get_clock()
+ self.server_name = hs.hostname
self._saml_client = Saml2Client(hs.config.saml2.saml2_sp_config)
self._saml_idp_entityid = hs.config.saml2.saml2_idp_entityid
diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py
index 8226d6f5..a3ffa26b 100644
--- a/synapse/handlers/search.py
+++ b/synapse/handlers/search.py
@@ -26,17 +26,18 @@ from synapse.storage.state import StateFilter
from synapse.types import JsonDict, UserID
from synapse.visibility import filter_events_for_client
-from ._base import BaseHandler
-
if TYPE_CHECKING:
from synapse.server import HomeServer
logger = logging.getLogger(__name__)
-class SearchHandler(BaseHandler):
+class SearchHandler:
def __init__(self, hs: "HomeServer"):
- super().__init__(hs)
+ self.store = hs.get_datastore()
+ self.state_handler = hs.get_state_handler()
+ self.clock = hs.get_clock()
+ self.hs = hs
self._event_serializer = hs.get_event_client_serializer()
self.storage = hs.get_storage()
self.state_store = self.storage.state
@@ -105,7 +106,7 @@ class SearchHandler(BaseHandler):
dict to be returned to the client with results of search
"""
- if not self.hs.config.enable_search:
+ if not self.hs.config.server.enable_search:
raise SynapseError(400, "Search is disabled on this homeserver")
batch_group = None
diff --git a/synapse/handlers/send_email.py b/synapse/handlers/send_email.py
index 25e6b012..1a062a78 100644
--- a/synapse/handlers/send_email.py
+++ b/synapse/handlers/send_email.py
@@ -105,8 +105,13 @@ async def _sendmail(
# set to enable TLS.
factory = build_sender_factory(hostname=smtphost if enable_tls else None)
- # the IReactorTCP interface claims host has to be a bytes, which seems to be wrong
- reactor.connectTCP(smtphost, smtpport, factory, timeout=30, bindAddress=None) # type: ignore[arg-type]
+ reactor.connectTCP(
+ smtphost, # type: ignore[arg-type]
+ smtpport,
+ factory,
+ timeout=30,
+ bindAddress=None,
+ )
await make_deferred_yieldable(d)
diff --git a/synapse/handlers/set_password.py b/synapse/handlers/set_password.py
index a63fac82..706ad727 100644
--- a/synapse/handlers/set_password.py
+++ b/synapse/handlers/set_password.py
@@ -17,19 +17,17 @@ from typing import TYPE_CHECKING, Optional
from synapse.api.errors import Codes, StoreError, SynapseError
from synapse.types import Requester
-from ._base import BaseHandler
-
if TYPE_CHECKING:
from synapse.server import HomeServer
logger = logging.getLogger(__name__)
-class SetPasswordHandler(BaseHandler):
+class SetPasswordHandler:
"""Handler which deals with changing user account passwords"""
def __init__(self, hs: "HomeServer"):
- super().__init__(hs)
+ self.store = hs.get_datastore()
self._auth_handler = hs.get_auth_handler()
self._device_handler = hs.get_device_handler()
diff --git a/synapse/handlers/ui_auth/checkers.py b/synapse/handlers/ui_auth/checkers.py
index 8f5d465f..184730eb 100644
--- a/synapse/handlers/ui_auth/checkers.py
+++ b/synapse/handlers/ui_auth/checkers.py
@@ -153,21 +153,23 @@ class _BaseThreepidAuthChecker:
# msisdns are currently always ThreepidBehaviour.REMOTE
if medium == "msisdn":
- if not self.hs.config.account_threepid_delegate_msisdn:
+ if not self.hs.config.registration.account_threepid_delegate_msisdn:
raise SynapseError(
400, "Phone number verification is not enabled on this homeserver"
)
threepid = await identity_handler.threepid_from_creds(
- self.hs.config.account_threepid_delegate_msisdn, threepid_creds
+ self.hs.config.registration.account_threepid_delegate_msisdn,
+ threepid_creds,
)
elif medium == "email":
if (
self.hs.config.email.threepid_behaviour_email
== ThreepidBehaviour.REMOTE
):
- assert self.hs.config.account_threepid_delegate_email
+ assert self.hs.config.registration.account_threepid_delegate_email
threepid = await identity_handler.threepid_from_creds(
- self.hs.config.account_threepid_delegate_email, threepid_creds
+ self.hs.config.registration.account_threepid_delegate_email,
+ threepid_creds,
)
elif (
self.hs.config.email.threepid_behaviour_email == ThreepidBehaviour.LOCAL
@@ -240,7 +242,7 @@ class MsisdnAuthChecker(UserInteractiveAuthChecker, _BaseThreepidAuthChecker):
_BaseThreepidAuthChecker.__init__(self, hs)
def is_enabled(self) -> bool:
- return bool(self.hs.config.account_threepid_delegate_msisdn)
+ return bool(self.hs.config.registration.account_threepid_delegate_msisdn)
async def check_auth(self, authdict: dict, clientip: str) -> Any:
return await self._check_threepid("msisdn", authdict)
@@ -252,7 +254,7 @@ class RegistrationTokenAuthChecker(UserInteractiveAuthChecker):
def __init__(self, hs: "HomeServer"):
super().__init__(hs)
self.hs = hs
- self._enabled = bool(hs.config.registration_requires_token)
+ self._enabled = bool(hs.config.registration.registration_requires_token)
self.store = hs.get_datastore()
def is_enabled(self) -> bool:
diff --git a/synapse/handlers/user_directory.py b/synapse/handlers/user_directory.py
index b91e7cb5..8810f048 100644
--- a/synapse/handlers/user_directory.py
+++ b/synapse/handlers/user_directory.py
@@ -60,7 +60,7 @@ class UserDirectoryHandler(StateDeltasHandler):
self.clock = hs.get_clock()
self.notifier = hs.get_notifier()
self.is_mine_id = hs.is_mine_id
- self.update_user_directory = hs.config.update_user_directory
+ self.update_user_directory = hs.config.server.update_user_directory
self.search_all_users = hs.config.userdirectory.user_directory_search_all_users
self.spam_checker = hs.get_spam_checker()
# The current position in the current_state_delta stream
@@ -132,12 +132,7 @@ class UserDirectoryHandler(StateDeltasHandler):
# FIXME(#3714): We should probably do this in the same worker as all
# the other changes.
- # Support users are for diagnostics and should not appear in the user directory.
- is_support = await self.store.is_support_user(user_id)
- # When change profile information of deactivated user it should not appear in the user directory.
- is_deactivated = await self.store.get_user_deactivated_status(user_id)
-
- if not (is_support or is_deactivated):
+ if await self.store.should_include_local_user_in_dir(user_id):
await self.store.update_profile_in_user_dir(
user_id, profile.display_name, profile.avatar_url
)
@@ -208,6 +203,7 @@ class UserDirectoryHandler(StateDeltasHandler):
public_value=Membership.JOIN,
)
+ is_remote = not self.is_mine_id(state_key)
if change is MatchChange.now_false:
# Need to check if the server left the room entirely, if so
# we might need to remove all the users in that room
@@ -225,32 +221,36 @@ class UserDirectoryHandler(StateDeltasHandler):
for user_id in user_ids:
await self._handle_remove_user(room_id, user_id)
- return
+ continue
else:
logger.debug("Server is still in room: %r", room_id)
- is_support = await self.store.is_support_user(state_key)
- if not is_support:
+ include_in_dir = (
+ is_remote
+ or await self.store.should_include_local_user_in_dir(state_key)
+ )
+ if include_in_dir:
if change is MatchChange.no_change:
- # Handle any profile changes
- await self._handle_profile_change(
- state_key, room_id, prev_event_id, event_id
- )
+ # Handle any profile changes for remote users.
+ # (For local users we are not forced to scan membership
+ # events; instead the rest of the application calls
+ # `handle_local_profile_change`.)
+ if is_remote:
+ await self._handle_profile_change(
+ state_key, room_id, prev_event_id, event_id
+ )
continue
if change is MatchChange.now_true: # The user joined
- event = await self.store.get_event(event_id, allow_none=True)
- # It isn't expected for this event to not exist, but we
- # don't want the entire background process to break.
- if event is None:
- continue
-
- profile = ProfileInfo(
- avatar_url=event.content.get("avatar_url"),
- display_name=event.content.get("displayname"),
- )
-
- await self._handle_new_user(room_id, state_key, profile)
+ # This may be the first time we've seen a remote user. If
+ # so, ensure we have a directory entry for them. (We don't
+ # need to do this for local users: their directory entry
+ # is created at the point of registration.
+ if is_remote:
+ await self._upsert_directory_entry_for_remote_user(
+ state_key, event_id
+ )
+ await self._track_user_joined_room(room_id, state_key)
else: # The user left
await self._handle_remove_user(room_id, state_key)
else:
@@ -300,7 +300,7 @@ class UserDirectoryHandler(StateDeltasHandler):
room_id
)
- logger.debug("Change: %r, publicness: %r", publicness, is_public)
+ logger.debug("Publicness change: %r, is_public: %r", publicness, is_public)
if publicness is MatchChange.now_true and not is_public:
# If we became world readable but room isn't currently public then
@@ -311,42 +311,50 @@ class UserDirectoryHandler(StateDeltasHandler):
# ignore the change
return
- other_users_in_room_with_profiles = (
- await self.store.get_users_in_room_with_profiles(room_id)
- )
+ users_in_room = await self.store.get_users_in_room(room_id)
# Remove every user from the sharing tables for that room.
- for user_id in other_users_in_room_with_profiles.keys():
+ for user_id in users_in_room:
await self.store.remove_user_who_share_room(user_id, room_id)
# Then, re-add them to the tables.
- # NOTE: this is not the most efficient method, as handle_new_user sets
+ # NOTE: this is not the most efficient method, as _track_user_joined_room sets
# up local_user -> other_user and other_user_whos_local -> local_user,
# which when ran over an entire room, will result in the same values
# being added multiple times. The batching upserts shouldn't make this
# too bad, though.
- for user_id, profile in other_users_in_room_with_profiles.items():
- await self._handle_new_user(room_id, user_id, profile)
+ for user_id in users_in_room:
+ await self._track_user_joined_room(room_id, user_id)
- async def _handle_new_user(
- self, room_id: str, user_id: str, profile: ProfileInfo
+ async def _upsert_directory_entry_for_remote_user(
+ self, user_id: str, event_id: str
) -> None:
- """Called when we might need to add user to directory
-
- Args:
- room_id: The room ID that user joined or started being public
- user_id
+ """A remote user has just joined a room. Ensure they have an entry in
+ the user directory. The caller is responsible for making sure they're
+ remote.
"""
+ event = await self.store.get_event(event_id, allow_none=True)
+ # It isn't expected for this event to not exist, but we
+ # don't want the entire background process to break.
+ if event is None:
+ return
+
logger.debug("Adding new user to dir, %r", user_id)
await self.store.update_profile_in_user_dir(
- user_id, profile.display_name, profile.avatar_url
+ user_id, event.content.get("displayname"), event.content.get("avatar_url")
)
+ async def _track_user_joined_room(self, room_id: str, user_id: str) -> None:
+ """Someone's just joined a room. Update `users_in_public_rooms` or
+ `users_who_share_private_rooms` as appropriate.
+
+ The caller is responsible for ensuring that the given user is not excluded
+ from the user directory.
+ """
is_public = await self.store.is_room_world_readable_or_publicly_joinable(
room_id
)
- # Now we update users who share rooms with users.
other_users_in_room = await self.store.get_users_in_room(room_id)
if is_public:
@@ -356,13 +364,7 @@ class UserDirectoryHandler(StateDeltasHandler):
# First, if they're our user then we need to update for every user
if self.is_mine_id(user_id):
-
- is_appservice = self.store.get_if_app_services_interested_in_user(
- user_id
- )
-
- # We don't care about appservice users.
- if not is_appservice:
+ if await self.store.should_include_local_user_in_dir(user_id):
for other_user_id in other_users_in_room:
if user_id == other_user_id:
continue
@@ -374,10 +376,10 @@ class UserDirectoryHandler(StateDeltasHandler):
if user_id == other_user_id:
continue
- is_appservice = self.store.get_if_app_services_interested_in_user(
+ include_other_user = self.is_mine_id(
other_user_id
- )
- if self.is_mine_id(other_user_id) and not is_appservice:
+ ) and await self.store.should_include_local_user_in_dir(other_user_id)
+ if include_other_user:
to_insert.add((other_user_id, user_id))
if to_insert:
diff --git a/synapse/http/client.py b/synapse/http/client.py
index 5204c3d0..b5a2d333 100644
--- a/synapse/http/client.py
+++ b/synapse/http/client.py
@@ -912,7 +912,7 @@ class InsecureInterceptableContextFactory(ssl.ContextFactory):
def __init__(self):
self._context = SSL.Context(SSL.SSLv23_METHOD)
- self._context.set_verify(VERIFY_NONE, lambda *_: None)
+ self._context.set_verify(VERIFY_NONE, lambda *_: False)
def getContext(self, hostname=None, port=None):
return self._context
diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py
index cdc36b8d..4f592246 100644
--- a/synapse/http/matrixfederationclient.py
+++ b/synapse/http/matrixfederationclient.py
@@ -327,23 +327,23 @@ class MatrixFederationHttpClient:
self.reactor = hs.get_reactor()
user_agent = hs.version_string
- if hs.config.user_agent_suffix:
- user_agent = "%s %s" % (user_agent, hs.config.user_agent_suffix)
+ if hs.config.server.user_agent_suffix:
+ user_agent = "%s %s" % (user_agent, hs.config.server.user_agent_suffix)
user_agent = user_agent.encode("ascii")
federation_agent = MatrixFederationAgent(
self.reactor,
tls_client_options_factory,
user_agent,
- hs.config.federation_ip_range_whitelist,
- hs.config.federation_ip_range_blacklist,
+ hs.config.server.federation_ip_range_whitelist,
+ hs.config.server.federation_ip_range_blacklist,
)
# Use a BlacklistingAgentWrapper to prevent circumventing the IP
# blacklist via IP literals in server names
self.agent = BlacklistingAgentWrapper(
federation_agent,
- ip_blacklist=hs.config.federation_ip_range_blacklist,
+ ip_blacklist=hs.config.server.federation_ip_range_blacklist,
)
self.clock = hs.get_clock()
diff --git a/synapse/http/server.py b/synapse/http/server.py
index 0df1bfbe..897ba5e4 100644
--- a/synapse/http/server.py
+++ b/synapse/http/server.py
@@ -563,7 +563,10 @@ class _ByteProducer:
try:
self._request.registerProducer(self, True)
- except RuntimeError as e:
+ except AttributeError as e:
+ # Calling self._request.registerProducer might raise an AttributeError since
+ # the underlying Twisted code calls self._request.channel.registerProducer,
+ # however self._request.channel will be None if the connection was lost.
logger.info("Connection disconnected before response was written: %r", e)
# We drop our references to data we'll not use.
diff --git a/synapse/logging/_terse_json.py b/synapse/logging/_terse_json.py
index 6e82f7c7..b78d6e17 100644
--- a/synapse/logging/_terse_json.py
+++ b/synapse/logging/_terse_json.py
@@ -65,6 +65,12 @@ class JsonFormatter(logging.Formatter):
if key not in _IGNORED_LOG_RECORD_ATTRIBUTES:
event[key] = value
+ if record.exc_info:
+ exc_type, exc_value, _ = record.exc_info
+ if exc_type:
+ event["exc_type"] = f"{exc_type.__name__}"
+ event["exc_value"] = f"{exc_value}"
+
return _encoder.encode(event)
diff --git a/synapse/logging/context.py b/synapse/logging/context.py
index 02e5ddd2..bdc01877 100644
--- a/synapse/logging/context.py
+++ b/synapse/logging/context.py
@@ -52,7 +52,7 @@ try:
is_thread_resource_usage_supported = True
- def get_thread_resource_usage() -> "Optional[resource._RUsage]":
+ def get_thread_resource_usage() -> "Optional[resource.struct_rusage]":
return resource.getrusage(RUSAGE_THREAD)
@@ -61,7 +61,7 @@ except Exception:
# won't track resource usage.
is_thread_resource_usage_supported = False
- def get_thread_resource_usage() -> "Optional[resource._RUsage]":
+ def get_thread_resource_usage() -> "Optional[resource.struct_rusage]":
return None
@@ -226,10 +226,10 @@ class _Sentinel:
def copy_to(self, record):
pass
- def start(self, rusage: "Optional[resource._RUsage]"):
+ def start(self, rusage: "Optional[resource.struct_rusage]"):
pass
- def stop(self, rusage: "Optional[resource._RUsage]"):
+ def stop(self, rusage: "Optional[resource.struct_rusage]"):
pass
def add_database_transaction(self, duration_sec):
@@ -289,7 +289,7 @@ class LoggingContext:
# The thread resource usage when the logcontext became active. None
# if the context is not currently active.
- self.usage_start: Optional[resource._RUsage] = None
+ self.usage_start: Optional[resource.struct_rusage] = None
self.main_thread = get_thread_id()
self.request = None
@@ -410,7 +410,7 @@ class LoggingContext:
# we also track the current scope:
record.scope = self.scope
- def start(self, rusage: "Optional[resource._RUsage]") -> None:
+ def start(self, rusage: "Optional[resource.struct_rusage]") -> None:
"""
Record that this logcontext is currently running.
@@ -435,7 +435,7 @@ class LoggingContext:
else:
self.usage_start = rusage
- def stop(self, rusage: "Optional[resource._RUsage]") -> None:
+ def stop(self, rusage: "Optional[resource.struct_rusage]") -> None:
"""
Record that this logcontext is no longer running.
@@ -490,7 +490,7 @@ class LoggingContext:
return res
- def _get_cputime(self, current: "resource._RUsage") -> Tuple[float, float]:
+ def _get_cputime(self, current: "resource.struct_rusage") -> Tuple[float, float]:
"""Get the cpu usage time between start() and the given rusage
Args:
diff --git a/synapse/logging/opentracing.py b/synapse/logging/opentracing.py
index 03d2dd94..20d23a42 100644
--- a/synapse/logging/opentracing.py
+++ b/synapse/logging/opentracing.py
@@ -339,6 +339,7 @@ def ensure_active_span(message, ret=None):
"There was no active span when trying to %s."
" Did you forget to start one or did a context slip?",
message,
+ stack_info=True,
)
return ret
@@ -806,6 +807,14 @@ def trace(func=None, opname=None):
result.addCallbacks(call_back, err_back)
else:
+ if inspect.isawaitable(result):
+ logger.error(
+ "@trace may not have wrapped %s correctly! "
+ "The function is not async but returned a %s.",
+ func.__qualname__,
+ type(result).__name__,
+ )
+
scope.__exit__(None, None, None)
return result
diff --git a/synapse/metrics/background_process_metrics.py b/synapse/metrics/background_process_metrics.py
index 3a142607..2ab599a3 100644
--- a/synapse/metrics/background_process_metrics.py
+++ b/synapse/metrics/background_process_metrics.py
@@ -265,7 +265,7 @@ class BackgroundProcessLoggingContext(LoggingContext):
super().__init__("%s-%s" % (name, instance_id))
self._proc = _BackgroundProcess(name, self)
- def start(self, rusage: "Optional[resource._RUsage]"):
+ def start(self, rusage: "Optional[resource.struct_rusage]"):
"""Log context has started running (again)."""
super().start(rusage)
diff --git a/synapse/push/__init__.py b/synapse/push/__init__.py
index 2c23afe8..820f6f3f 100644
--- a/synapse/push/__init__.py
+++ b/synapse/push/__init__.py
@@ -94,7 +94,7 @@ class Pusher(metaclass=abc.ABCMeta):
self._start_processing()
@abc.abstractmethod
- def _start_processing(self):
+ def _start_processing(self) -> None:
"""Start processing push notifications."""
raise NotImplementedError()
diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py
index c337e530..0622a37a 100644
--- a/synapse/push/bulk_push_rule_evaluator.py
+++ b/synapse/push/bulk_push_rule_evaluator.py
@@ -290,6 +290,12 @@ def _condition_checker(
return True
+MemberMap = Dict[str, Tuple[str, str]]
+Rule = Dict[str, dict]
+RulesByUser = Dict[str, List[Rule]]
+StateGroup = Union[object, int]
+
+
@attr.s(slots=True)
class RulesForRoomData:
"""The data stored in the cache by `RulesForRoom`.
@@ -299,16 +305,16 @@ class RulesForRoomData:
"""
# event_id -> (user_id, state)
- member_map = attr.ib(type=Dict[str, Tuple[str, str]], factory=dict)
+ member_map = attr.ib(type=MemberMap, factory=dict)
# user_id -> rules
- rules_by_user = attr.ib(type=Dict[str, List[Dict[str, dict]]], factory=dict)
+ rules_by_user = attr.ib(type=RulesByUser, factory=dict)
# The last state group we updated the caches for. If the state_group of
# a new event comes along, we know that we can just return the cached
# result.
# On invalidation of the rules themselves (if the user changes them),
# we invalidate everything and set state_group to `object()`
- state_group = attr.ib(type=Union[object, int], factory=object)
+ state_group = attr.ib(type=StateGroup, factory=object)
# A sequence number to keep track of when we're allowed to update the
# cache. We bump the sequence number when we invalidate the cache. If
@@ -532,7 +538,13 @@ class RulesForRoom:
self.update_cache(sequence, members, ret_rules_by_user, state_group)
- def update_cache(self, sequence, members, rules_by_user, state_group) -> None:
+ def update_cache(
+ self,
+ sequence: int,
+ members: MemberMap,
+ rules_by_user: RulesByUser,
+ state_group: StateGroup,
+ ) -> None:
if sequence == self.data.sequence:
self.data.member_map.update(members)
self.data.rules_by_user = rules_by_user
diff --git a/synapse/push/clientformat.py b/synapse/push/clientformat.py
index 1fc9716a..c5708cd8 100644
--- a/synapse/push/clientformat.py
+++ b/synapse/push/clientformat.py
@@ -19,7 +19,9 @@ from synapse.push.rulekinds import PRIORITY_CLASS_INVERSE_MAP, PRIORITY_CLASS_MA
from synapse.types import UserID
-def format_push_rules_for_user(user: UserID, ruleslist) -> Dict[str, Dict[str, list]]:
+def format_push_rules_for_user(
+ user: UserID, ruleslist: List
+) -> Dict[str, Dict[str, list]]:
"""Converts a list of rawrules and a enabled map into nested dictionaries
to match the Matrix client-server format for push rules"""
diff --git a/synapse/push/httppusher.py b/synapse/push/httppusher.py
index eac65572..dbf4ad7f 100644
--- a/synapse/push/httppusher.py
+++ b/synapse/push/httppusher.py
@@ -403,10 +403,10 @@ class HttpPusher(Pusher):
rejected = resp["rejected"]
return rejected
- async def _send_badge(self, badge):
+ async def _send_badge(self, badge: int) -> None:
"""
Args:
- badge (int): number of unread messages
+ badge: number of unread messages
"""
logger.debug("Sending updated badge count %d to %s", badge, self.name)
d = {
diff --git a/synapse/push/mailer.py b/synapse/push/mailer.py
index e38e3c5d..ce299ba3 100644
--- a/synapse/push/mailer.py
+++ b/synapse/push/mailer.py
@@ -892,7 +892,7 @@ def safe_text(raw_text: str) -> jinja2.Markup:
A Markup object ready to safely use in a Jinja template.
"""
return jinja2.Markup(
- bleach.linkify(bleach.clean(raw_text, tags=[], attributes={}, strip=False))
+ bleach.linkify(bleach.clean(raw_text, tags=[], attributes=[], strip=False))
)
diff --git a/synapse/replication/http/_base.py b/synapse/replication/http/_base.py
index f1b78d09..e047ec74 100644
--- a/synapse/replication/http/_base.py
+++ b/synapse/replication/http/_base.py
@@ -182,85 +182,87 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
)
@trace(opname="outgoing_replication_request")
- @outgoing_gauge.track_inprogress()
async def send_request(*, instance_name="master", **kwargs):
- if instance_name == local_instance_name:
- raise Exception("Trying to send HTTP request to self")
- if instance_name == "master":
- host = master_host
- port = master_port
- elif instance_name in instance_map:
- host = instance_map[instance_name].host
- port = instance_map[instance_name].port
- else:
- raise Exception(
- "Instance %r not in 'instance_map' config" % (instance_name,)
+ with outgoing_gauge.track_inprogress():
+ if instance_name == local_instance_name:
+ raise Exception("Trying to send HTTP request to self")
+ if instance_name == "master":
+ host = master_host
+ port = master_port
+ elif instance_name in instance_map:
+ host = instance_map[instance_name].host
+ port = instance_map[instance_name].port
+ else:
+ raise Exception(
+ "Instance %r not in 'instance_map' config" % (instance_name,)
+ )
+
+ data = await cls._serialize_payload(**kwargs)
+
+ url_args = [
+ urllib.parse.quote(kwargs[name], safe="") for name in cls.PATH_ARGS
+ ]
+
+ if cls.CACHE:
+ txn_id = random_string(10)
+ url_args.append(txn_id)
+
+ if cls.METHOD == "POST":
+ request_func = client.post_json_get_json
+ elif cls.METHOD == "PUT":
+ request_func = client.put_json
+ elif cls.METHOD == "GET":
+ request_func = client.get_json
+ else:
+ # We have already asserted in the constructor that a
+ # compatible was picked, but lets be paranoid.
+ raise Exception(
+ "Unknown METHOD on %s replication endpoint" % (cls.NAME,)
+ )
+
+ uri = "http://%s:%s/_synapse/replication/%s/%s" % (
+ host,
+ port,
+ cls.NAME,
+ "/".join(url_args),
)
- data = await cls._serialize_payload(**kwargs)
-
- url_args = [
- urllib.parse.quote(kwargs[name], safe="") for name in cls.PATH_ARGS
- ]
-
- if cls.CACHE:
- txn_id = random_string(10)
- url_args.append(txn_id)
-
- if cls.METHOD == "POST":
- request_func = client.post_json_get_json
- elif cls.METHOD == "PUT":
- request_func = client.put_json
- elif cls.METHOD == "GET":
- request_func = client.get_json
- else:
- # We have already asserted in the constructor that a
- # compatible was picked, but lets be paranoid.
- raise Exception(
- "Unknown METHOD on %s replication endpoint" % (cls.NAME,)
- )
-
- uri = "http://%s:%s/_synapse/replication/%s/%s" % (
- host,
- port,
- cls.NAME,
- "/".join(url_args),
- )
-
- try:
- # We keep retrying the same request for timeouts. This is so that we
- # have a good idea that the request has either succeeded or failed on
- # the master, and so whether we should clean up or not.
- while True:
- headers: Dict[bytes, List[bytes]] = {}
- # Add an authorization header, if configured.
- if replication_secret:
- headers[b"Authorization"] = [b"Bearer " + replication_secret]
- opentracing.inject_header_dict(headers, check_destination=False)
- try:
- result = await request_func(uri, data, headers=headers)
- break
- except RequestTimedOutError:
- if not cls.RETRY_ON_TIMEOUT:
- raise
-
- logger.warning("%s request timed out; retrying", cls.NAME)
-
- # If we timed out we probably don't need to worry about backing
- # off too much, but lets just wait a little anyway.
- await clock.sleep(1)
- except HttpResponseException as e:
- # We convert to SynapseError as we know that it was a SynapseError
- # on the main process that we should send to the client. (And
- # importantly, not stack traces everywhere)
- _outgoing_request_counter.labels(cls.NAME, e.code).inc()
- raise e.to_synapse_error()
- except Exception as e:
- _outgoing_request_counter.labels(cls.NAME, "ERR").inc()
- raise SynapseError(502, "Failed to talk to main process") from e
-
- _outgoing_request_counter.labels(cls.NAME, 200).inc()
- return result
+ try:
+ # We keep retrying the same request for timeouts. This is so that we
+ # have a good idea that the request has either succeeded or failed
+ # on the master, and so whether we should clean up or not.
+ while True:
+ headers: Dict[bytes, List[bytes]] = {}
+ # Add an authorization header, if configured.
+ if replication_secret:
+ headers[b"Authorization"] = [
+ b"Bearer " + replication_secret
+ ]
+ opentracing.inject_header_dict(headers, check_destination=False)
+ try:
+ result = await request_func(uri, data, headers=headers)
+ break
+ except RequestTimedOutError:
+ if not cls.RETRY_ON_TIMEOUT:
+ raise
+
+ logger.warning("%s request timed out; retrying", cls.NAME)
+
+ # If we timed out we probably don't need to worry about backing
+ # off too much, but lets just wait a little anyway.
+ await clock.sleep(1)
+ except HttpResponseException as e:
+ # We convert to SynapseError as we know that it was a SynapseError
+ # on the main process that we should send to the client. (And
+ # importantly, not stack traces everywhere)
+ _outgoing_request_counter.labels(cls.NAME, e.code).inc()
+ raise e.to_synapse_error()
+ except Exception as e:
+ _outgoing_request_counter.labels(cls.NAME, "ERR").inc()
+ raise SynapseError(502, "Failed to talk to main process") from e
+
+ _outgoing_request_counter.labels(cls.NAME, 200).inc()
+ return result
return send_request
diff --git a/synapse/replication/slave/storage/_slaved_id_tracker.py b/synapse/replication/slave/storage/_slaved_id_tracker.py
index 2cb74890..8c1bf922 100644
--- a/synapse/replication/slave/storage/_slaved_id_tracker.py
+++ b/synapse/replication/slave/storage/_slaved_id_tracker.py
@@ -13,14 +13,14 @@
# limitations under the License.
from typing import List, Optional, Tuple
-from synapse.storage.types import Connection
+from synapse.storage.database import LoggingDatabaseConnection
from synapse.storage.util.id_generators import _load_current_id
class SlavedIdTracker:
def __init__(
self,
- db_conn: Connection,
+ db_conn: LoggingDatabaseConnection,
table: str,
column: str,
extra_tables: Optional[List[Tuple[str, str]]] = None,
diff --git a/synapse/replication/slave/storage/pushers.py b/synapse/replication/slave/storage/pushers.py
index 2672a2c9..cea90c0f 100644
--- a/synapse/replication/slave/storage/pushers.py
+++ b/synapse/replication/slave/storage/pushers.py
@@ -15,9 +15,8 @@
from typing import TYPE_CHECKING
from synapse.replication.tcp.streams import PushersStream
-from synapse.storage.database import DatabasePool
+from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
from synapse.storage.databases.main.pusher import PusherWorkerStore
-from synapse.storage.types import Connection
from ._base import BaseSlavedStore
from ._slaved_id_tracker import SlavedIdTracker
@@ -27,7 +26,12 @@ if TYPE_CHECKING:
class SlavedPusherStore(PusherWorkerStore, BaseSlavedStore):
- def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
+ def __init__(
+ self,
+ database: DatabasePool,
+ db_conn: LoggingDatabaseConnection,
+ hs: "HomeServer",
+ ):
super().__init__(database, db_conn, hs)
self._pushers_id_gen = SlavedIdTracker( # type: ignore
db_conn, "pushers", "id", extra_tables=[("deleted_pushers", "stream_id")]
diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py
index 37769ace..961c1776 100644
--- a/synapse/replication/tcp/client.py
+++ b/synapse/replication/tcp/client.py
@@ -117,7 +117,7 @@ class ReplicationDataHandler:
self._instance_name = hs.get_instance_name()
self._typing_handler = hs.get_typing_handler()
- self._notify_pushers = hs.config.start_pushers
+ self._notify_pushers = hs.config.worker.start_pushers
self._pusher_pool = hs.get_pusherpool()
self._presence_handler = hs.get_presence_handler()
diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py
index 1438a82b..6aa93180 100644
--- a/synapse/replication/tcp/handler.py
+++ b/synapse/replication/tcp/handler.py
@@ -171,7 +171,10 @@ class ReplicationCommandHandler:
if hs.config.worker.worker_app is not None:
continue
- if stream.NAME == FederationStream.NAME and hs.config.send_federation:
+ if (
+ stream.NAME == FederationStream.NAME
+ and hs.config.worker.send_federation
+ ):
# We only support federation stream if federation sending
# has been disabled on the master.
continue
@@ -225,7 +228,7 @@ class ReplicationCommandHandler:
self._is_master = hs.config.worker.worker_app is None
self._federation_sender = None
- if self._is_master and not hs.config.send_federation:
+ if self._is_master and not hs.config.worker.send_federation:
self._federation_sender = hs.get_federation_sender()
self._server_notices_sender = None
@@ -315,7 +318,7 @@ class ReplicationCommandHandler:
hs, outbound_redis_connection
)
hs.get_reactor().connectTCP(
- hs.config.redis.redis_host.encode(),
+ hs.config.redis.redis_host, # type: ignore[arg-type]
hs.config.redis.redis_port,
self._factory,
)
@@ -324,7 +327,11 @@ class ReplicationCommandHandler:
self._factory = DirectTcpReplicationClientFactory(hs, client_name, self)
host = hs.config.worker.worker_replication_host
port = hs.config.worker.worker_replication_port
- hs.get_reactor().connectTCP(host.encode(), port, self._factory)
+ hs.get_reactor().connectTCP(
+ host, # type: ignore[arg-type]
+ port,
+ self._factory,
+ )
def get_streams(self) -> Dict[str, Stream]:
"""Get a map from stream name to all streams."""
diff --git a/synapse/replication/tcp/redis.py b/synapse/replication/tcp/redis.py
index 8c0df627..062fe2f3 100644
--- a/synapse/replication/tcp/redis.py
+++ b/synapse/replication/tcp/redis.py
@@ -364,6 +364,12 @@ def lazyConnection(
factory.continueTrying = reconnect
reactor = hs.get_reactor()
- reactor.connectTCP(host.encode(), port, factory, timeout=30, bindAddress=None)
+ reactor.connectTCP(
+ host, # type: ignore[arg-type]
+ port,
+ factory,
+ timeout=30,
+ bindAddress=None,
+ )
return factory.handler
diff --git a/synapse/replication/tcp/resource.py b/synapse/replication/tcp/resource.py
index 030852cb..80f9b23b 100644
--- a/synapse/replication/tcp/resource.py
+++ b/synapse/replication/tcp/resource.py
@@ -71,7 +71,7 @@ class ReplicationStreamer:
self.notifier = hs.get_notifier()
self._instance_name = hs.get_instance_name()
- self._replication_torture_level = hs.config.replication_torture_level
+ self._replication_torture_level = hs.config.server.replication_torture_level
self.notifier.add_replication_callback(self.on_notifier_poke)
diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py
index 46bfec46..f20aa653 100644
--- a/synapse/rest/admin/users.py
+++ b/synapse/rest/admin/users.py
@@ -442,7 +442,7 @@ class UserRegisterServlet(RestServlet):
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
self._clear_old_nonces()
- if not self.hs.config.registration_shared_secret:
+ if not self.hs.config.registration.registration_shared_secret:
raise SynapseError(400, "Shared secret registration is not enabled")
body = parse_json_object_from_request(request)
@@ -498,7 +498,7 @@ class UserRegisterServlet(RestServlet):
got_mac = body["mac"]
want_mac_builder = hmac.new(
- key=self.hs.config.registration_shared_secret.encode(),
+ key=self.hs.config.registration.registration_shared_secret.encode(),
digestmod=hashlib.sha1,
)
want_mac_builder.update(nonce.encode("utf8"))
diff --git a/synapse/rest/client/account.py b/synapse/rest/client/account.py
index 6a7608d6..6b272658 100644
--- a/synapse/rest/client/account.py
+++ b/synapse/rest/client/account.py
@@ -119,7 +119,7 @@ class EmailPasswordRequestTokenRestServlet(RestServlet):
)
if existing_user_id is None:
- if self.config.request_token_inhibit_3pid_errors:
+ if self.config.server.request_token_inhibit_3pid_errors:
# Make the client think the operation succeeded. See the rationale in the
# comments for request_token_inhibit_3pid_errors.
# Also wait for some random amount of time between 100ms and 1s to make it
@@ -130,11 +130,11 @@ class EmailPasswordRequestTokenRestServlet(RestServlet):
raise SynapseError(400, "Email not found", Codes.THREEPID_NOT_FOUND)
if self.config.email.threepid_behaviour_email == ThreepidBehaviour.REMOTE:
- assert self.hs.config.account_threepid_delegate_email
+ assert self.hs.config.registration.account_threepid_delegate_email
# Have the configured identity server handle the request
ret = await self.identity_handler.requestEmailToken(
- self.hs.config.account_threepid_delegate_email,
+ self.hs.config.registration.account_threepid_delegate_email,
email,
client_secret,
send_attempt,
@@ -403,7 +403,7 @@ class EmailThreepidRequestTokenRestServlet(RestServlet):
existing_user_id = await self.store.get_user_id_by_threepid("email", email)
if existing_user_id is not None:
- if self.config.request_token_inhibit_3pid_errors:
+ if self.config.server.request_token_inhibit_3pid_errors:
# Make the client think the operation succeeded. See the rationale in the
# comments for request_token_inhibit_3pid_errors.
# Also wait for some random amount of time between 100ms and 1s to make it
@@ -414,11 +414,11 @@ class EmailThreepidRequestTokenRestServlet(RestServlet):
raise SynapseError(400, "Email is already in use", Codes.THREEPID_IN_USE)
if self.config.email.threepid_behaviour_email == ThreepidBehaviour.REMOTE:
- assert self.hs.config.account_threepid_delegate_email
+ assert self.hs.config.registration.account_threepid_delegate_email
# Have the configured identity server handle the request
ret = await self.identity_handler.requestEmailToken(
- self.hs.config.account_threepid_delegate_email,
+ self.hs.config.registration.account_threepid_delegate_email,
email,
client_secret,
send_attempt,
@@ -486,7 +486,7 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet):
existing_user_id = await self.store.get_user_id_by_threepid("msisdn", msisdn)
if existing_user_id is not None:
- if self.hs.config.request_token_inhibit_3pid_errors:
+ if self.hs.config.server.request_token_inhibit_3pid_errors:
# Make the client think the operation succeeded. See the rationale in the
# comments for request_token_inhibit_3pid_errors.
# Also wait for some random amount of time between 100ms and 1s to make it
@@ -496,7 +496,7 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet):
raise SynapseError(400, "MSISDN is already in use", Codes.THREEPID_IN_USE)
- if not self.hs.config.account_threepid_delegate_msisdn:
+ if not self.hs.config.registration.account_threepid_delegate_msisdn:
logger.warning(
"No upstream msisdn account_threepid_delegate configured on the server to "
"handle this request"
@@ -507,7 +507,7 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet):
)
ret = await self.identity_handler.requestMsisdnToken(
- self.hs.config.account_threepid_delegate_msisdn,
+ self.hs.config.registration.account_threepid_delegate_msisdn,
country,
phone_number,
client_secret,
@@ -604,7 +604,7 @@ class AddThreepidMsisdnSubmitTokenServlet(RestServlet):
self.identity_handler = hs.get_identity_handler()
async def on_POST(self, request: Request) -> Tuple[int, JsonDict]:
- if not self.config.account_threepid_delegate_msisdn:
+ if not self.config.registration.account_threepid_delegate_msisdn:
raise SynapseError(
400,
"This homeserver is not validating phone numbers. Use an identity server "
@@ -617,7 +617,7 @@ class AddThreepidMsisdnSubmitTokenServlet(RestServlet):
# Proxy submit_token request to msisdn threepid delegate
response = await self.identity_handler.proxy_msisdn_submit_token(
- self.config.account_threepid_delegate_msisdn,
+ self.config.registration.account_threepid_delegate_msisdn,
body["client_secret"],
body["sid"],
body["token"],
@@ -644,7 +644,7 @@ class ThreepidRestServlet(RestServlet):
return 200, {"threepids": threepids}
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
- if not self.hs.config.enable_3pid_changes:
+ if not self.hs.config.registration.enable_3pid_changes:
raise SynapseError(
400, "3PID changes are disabled on this server", Codes.FORBIDDEN
)
@@ -693,7 +693,7 @@ class ThreepidAddRestServlet(RestServlet):
@interactive_auth_handler
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
- if not self.hs.config.enable_3pid_changes:
+ if not self.hs.config.registration.enable_3pid_changes:
raise SynapseError(
400, "3PID changes are disabled on this server", Codes.FORBIDDEN
)
@@ -801,7 +801,7 @@ class ThreepidDeleteRestServlet(RestServlet):
self.auth_handler = hs.get_auth_handler()
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
- if not self.hs.config.enable_3pid_changes:
+ if not self.hs.config.registration.enable_3pid_changes:
raise SynapseError(
400, "3PID changes are disabled on this server", Codes.FORBIDDEN
)
@@ -857,8 +857,8 @@ def assert_valid_next_link(hs: "HomeServer", next_link: str) -> None:
# If the domain whitelist is set, the domain must be in it
if (
valid
- and hs.config.next_link_domain_whitelist is not None
- and next_link_parsed.hostname not in hs.config.next_link_domain_whitelist
+ and hs.config.server.next_link_domain_whitelist is not None
+ and next_link_parsed.hostname not in hs.config.server.next_link_domain_whitelist
):
valid = False
@@ -878,9 +878,13 @@ class WhoamiRestServlet(RestServlet):
self.auth = hs.get_auth()
async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
- requester = await self.auth.get_user_by_req(request)
+ requester = await self.auth.get_user_by_req(request, allow_guest=True)
- response = {"user_id": requester.user.to_string()}
+ response = {
+ "user_id": requester.user.to_string(),
+ # MSC: https://github.com/matrix-org/matrix-doc/pull/3069
+ "org.matrix.msc3069.is_guest": bool(requester.is_guest),
+ }
# Appservices and similar accounts do not have device IDs
# that we can report on, so exclude them for compliance.
diff --git a/synapse/rest/client/auth.py b/synapse/rest/client/auth.py
index 282861fa..9c15a043 100644
--- a/synapse/rest/client/auth.py
+++ b/synapse/rest/client/auth.py
@@ -48,9 +48,11 @@ class AuthRestServlet(RestServlet):
self.auth_handler = hs.get_auth_handler()
self.registration_handler = hs.get_registration_handler()
self.recaptcha_template = hs.config.captcha.recaptcha_template
- self.terms_template = hs.config.terms_template
- self.registration_token_template = hs.config.registration_token_template
- self.success_template = hs.config.fallback_success_template
+ self.terms_template = hs.config.consent.terms_template
+ self.registration_token_template = (
+ hs.config.registration.registration_token_template
+ )
+ self.success_template = hs.config.registration.fallback_success_template
async def on_GET(self, request: SynapseRequest, stagetype: str) -> None:
session = parse_string(request, "session")
diff --git a/synapse/rest/client/capabilities.py b/synapse/rest/client/capabilities.py
index 65b3b5ce..2a3e24ae 100644
--- a/synapse/rest/client/capabilities.py
+++ b/synapse/rest/client/capabilities.py
@@ -44,10 +44,10 @@ class CapabilitiesRestServlet(RestServlet):
await self.auth.get_user_by_req(request, allow_guest=True)
change_password = self.auth_handler.can_change_password()
- response = {
+ response: JsonDict = {
"capabilities": {
"m.room_versions": {
- "default": self.config.default_room_version.identifier,
+ "default": self.config.server.default_room_version.identifier,
"available": {
v.identifier: v.disposition
for v in KNOWN_ROOM_VERSIONS.values()
@@ -64,13 +64,13 @@ class CapabilitiesRestServlet(RestServlet):
if self.config.experimental.msc3283_enabled:
response["capabilities"]["org.matrix.msc3283.set_displayname"] = {
- "enabled": self.config.enable_set_displayname
+ "enabled": self.config.registration.enable_set_displayname
}
response["capabilities"]["org.matrix.msc3283.set_avatar_url"] = {
- "enabled": self.config.enable_set_avatar_url
+ "enabled": self.config.registration.enable_set_avatar_url
}
response["capabilities"]["org.matrix.msc3283.3pid_changes"] = {
- "enabled": self.config.enable_3pid_changes
+ "enabled": self.config.registration.enable_3pid_changes
}
return 200, response
diff --git a/synapse/rest/client/filter.py b/synapse/rest/client/filter.py
index 6ed60c74..cc1c2f97 100644
--- a/synapse/rest/client/filter.py
+++ b/synapse/rest/client/filter.py
@@ -90,7 +90,7 @@ class CreateFilterRestServlet(RestServlet):
raise AuthError(403, "Can only create filters for local users")
content = parse_json_object_from_request(request)
- set_timeline_upper_limit(content, self.hs.config.filter_timeline_limit)
+ set_timeline_upper_limit(content, self.hs.config.server.filter_timeline_limit)
filter_id = await self.filtering.add_user_filter(
user_localpart=target_user.localpart, user_filter=content
diff --git a/synapse/rest/client/login.py b/synapse/rest/client/login.py
index fa5c173f..d49a647b 100644
--- a/synapse/rest/client/login.py
+++ b/synapse/rest/client/login.py
@@ -79,7 +79,7 @@ class LoginRestServlet(RestServlet):
self.saml2_enabled = hs.config.saml2.saml2_enabled
self.cas_enabled = hs.config.cas.cas_enabled
self.oidc_enabled = hs.config.oidc.oidc_enabled
- self._msc2918_enabled = hs.config.access_token_lifetime is not None
+ self._msc2918_enabled = hs.config.registration.access_token_lifetime is not None
self.auth = hs.get_auth()
@@ -447,7 +447,7 @@ class RefreshTokenServlet(RestServlet):
def __init__(self, hs: "HomeServer"):
self._auth_handler = hs.get_auth_handler()
self._clock = hs.get_clock()
- self.access_token_lifetime = hs.config.access_token_lifetime
+ self.access_token_lifetime = hs.config.registration.access_token_lifetime
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
refresh_submission = parse_json_object_from_request(request)
@@ -556,7 +556,7 @@ class CasTicketServlet(RestServlet):
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
LoginRestServlet(hs).register(http_server)
- if hs.config.access_token_lifetime is not None:
+ if hs.config.registration.access_token_lifetime is not None:
RefreshTokenServlet(hs).register(http_server)
SsoRedirectServlet(hs).register(http_server)
if hs.config.cas.cas_enabled:
diff --git a/synapse/rest/client/profile.py b/synapse/rest/client/profile.py
index d0f20de5..c684636c 100644
--- a/synapse/rest/client/profile.py
+++ b/synapse/rest/client/profile.py
@@ -41,7 +41,7 @@ class ProfileDisplaynameRestServlet(RestServlet):
) -> Tuple[int, JsonDict]:
requester_user = None
- if self.hs.config.require_auth_for_profile_requests:
+ if self.hs.config.server.require_auth_for_profile_requests:
requester = await self.auth.get_user_by_req(request)
requester_user = requester.user
@@ -94,7 +94,7 @@ class ProfileAvatarURLRestServlet(RestServlet):
) -> Tuple[int, JsonDict]:
requester_user = None
- if self.hs.config.require_auth_for_profile_requests:
+ if self.hs.config.server.require_auth_for_profile_requests:
requester = await self.auth.get_user_by_req(request)
requester_user = requester.user
@@ -146,7 +146,7 @@ class ProfileRestServlet(RestServlet):
) -> Tuple[int, JsonDict]:
requester_user = None
- if self.hs.config.require_auth_for_profile_requests:
+ if self.hs.config.server.require_auth_for_profile_requests:
requester = await self.auth.get_user_by_req(request)
requester_user = requester.user
diff --git a/synapse/rest/client/push_rule.py b/synapse/rest/client/push_rule.py
index ecebc46e..6f796d5e 100644
--- a/synapse/rest/client/push_rule.py
+++ b/synapse/rest/client/push_rule.py
@@ -61,7 +61,9 @@ class PushRuleRestServlet(RestServlet):
self.notifier = hs.get_notifier()
self._is_worker = hs.config.worker.worker_app is not None
- self._users_new_default_push_rules = hs.config.users_new_default_push_rules
+ self._users_new_default_push_rules = (
+ hs.config.server.users_new_default_push_rules
+ )
async def on_PUT(self, request: SynapseRequest, path: str) -> Tuple[int, JsonDict]:
if self._is_worker:
diff --git a/synapse/rest/client/register.py b/synapse/rest/client/register.py
index 48b0062c..bf3cb341 100644
--- a/synapse/rest/client/register.py
+++ b/synapse/rest/client/register.py
@@ -129,7 +129,7 @@ class EmailRegisterRequestTokenRestServlet(RestServlet):
)
if existing_user_id is not None:
- if self.hs.config.request_token_inhibit_3pid_errors:
+ if self.hs.config.server.request_token_inhibit_3pid_errors:
# Make the client think the operation succeeded. See the rationale in the
# comments for request_token_inhibit_3pid_errors.
# Also wait for some random amount of time between 100ms and 1s to make it
@@ -140,11 +140,11 @@ class EmailRegisterRequestTokenRestServlet(RestServlet):
raise SynapseError(400, "Email is already in use", Codes.THREEPID_IN_USE)
if self.config.email.threepid_behaviour_email == ThreepidBehaviour.REMOTE:
- assert self.hs.config.account_threepid_delegate_email
+ assert self.hs.config.registration.account_threepid_delegate_email
# Have the configured identity server handle the request
ret = await self.identity_handler.requestEmailToken(
- self.hs.config.account_threepid_delegate_email,
+ self.hs.config.registration.account_threepid_delegate_email,
email,
client_secret,
send_attempt,
@@ -209,7 +209,7 @@ class MsisdnRegisterRequestTokenRestServlet(RestServlet):
)
if existing_user_id is not None:
- if self.hs.config.request_token_inhibit_3pid_errors:
+ if self.hs.config.server.request_token_inhibit_3pid_errors:
# Make the client think the operation succeeded. See the rationale in the
# comments for request_token_inhibit_3pid_errors.
# Also wait for some random amount of time between 100ms and 1s to make it
@@ -221,7 +221,7 @@ class MsisdnRegisterRequestTokenRestServlet(RestServlet):
400, "Phone number is already in use", Codes.THREEPID_IN_USE
)
- if not self.hs.config.account_threepid_delegate_msisdn:
+ if not self.hs.config.registration.account_threepid_delegate_msisdn:
logger.warning(
"No upstream msisdn account_threepid_delegate configured on the server to "
"handle this request"
@@ -231,7 +231,7 @@ class MsisdnRegisterRequestTokenRestServlet(RestServlet):
)
ret = await self.identity_handler.requestMsisdnToken(
- self.hs.config.account_threepid_delegate_msisdn,
+ self.hs.config.registration.account_threepid_delegate_msisdn,
country,
phone_number,
client_secret,
@@ -341,7 +341,7 @@ class UsernameAvailabilityRestServlet(RestServlet):
)
async def on_GET(self, request: Request) -> Tuple[int, JsonDict]:
- if not self.hs.config.enable_registration:
+ if not self.hs.config.registration.enable_registration:
raise SynapseError(
403, "Registration has been disabled", errcode=Codes.FORBIDDEN
)
@@ -391,7 +391,7 @@ class RegistrationTokenValidityRestServlet(RestServlet):
async def on_GET(self, request: Request) -> Tuple[int, JsonDict]:
await self.ratelimiter.ratelimit(None, (request.getClientIP(),))
- if not self.hs.config.enable_registration:
+ if not self.hs.config.registration.enable_registration:
raise SynapseError(
403, "Registration has been disabled", errcode=Codes.FORBIDDEN
)
@@ -419,8 +419,8 @@ class RegisterRestServlet(RestServlet):
self.ratelimiter = hs.get_registration_ratelimiter()
self.password_policy_handler = hs.get_password_policy_handler()
self.clock = hs.get_clock()
- self._registration_enabled = self.hs.config.enable_registration
- self._msc2918_enabled = hs.config.access_token_lifetime is not None
+ self._registration_enabled = self.hs.config.registration.enable_registration
+ self._msc2918_enabled = hs.config.registration.access_token_lifetime is not None
self._registration_flows = _calculate_registration_flows(
hs.config, self.auth_handler
@@ -682,7 +682,7 @@ class RegisterRestServlet(RestServlet):
# written to the db
if threepid:
if is_threepid_reserved(
- self.hs.config.mau_limits_reserved_threepids, threepid
+ self.hs.config.server.mau_limits_reserved_threepids, threepid
):
await self.store.upsert_monthly_active_user(registered_user_id)
@@ -800,7 +800,7 @@ class RegisterRestServlet(RestServlet):
async def _do_guest_registration(
self, params: JsonDict, address: Optional[str] = None
) -> Tuple[int, JsonDict]:
- if not self.hs.config.allow_guest_access:
+ if not self.hs.config.registration.allow_guest_access:
raise SynapseError(403, "Guest access is disabled")
user_id = await self.registration_handler.register_user(
make_guest=True, address=address
@@ -849,13 +849,13 @@ def _calculate_registration_flows(
"""
# FIXME: need a better error than "no auth flow found" for scenarios
# where we required 3PID for registration but the user didn't give one
- require_email = "email" in config.registrations_require_3pid
- require_msisdn = "msisdn" in config.registrations_require_3pid
+ require_email = "email" in config.registration.registrations_require_3pid
+ require_msisdn = "msisdn" in config.registration.registrations_require_3pid
show_msisdn = True
show_email = True
- if config.disable_msisdn_registration:
+ if config.registration.disable_msisdn_registration:
show_msisdn = False
require_msisdn = False
@@ -909,7 +909,7 @@ def _calculate_registration_flows(
flow.insert(0, LoginType.RECAPTCHA)
# Prepend registration token to all flows if we're requiring a token
- if config.registration_requires_token:
+ if config.registration.registration_requires_token:
for flow in flows:
flow.insert(0, LoginType.REGISTRATION_TOKEN)
diff --git a/synapse/rest/client/room.py b/synapse/rest/client/room.py
index bf46dc60..ed95189b 100644
--- a/synapse/rest/client/room.py
+++ b/synapse/rest/client/room.py
@@ -369,7 +369,7 @@ class PublicRoomListRestServlet(TransactionRestServlet):
# Option to allow servers to require auth when accessing
# /publicRooms via CS API. This is especially helpful in private
# federations.
- if not self.hs.config.allow_public_rooms_without_auth:
+ if not self.hs.config.server.allow_public_rooms_without_auth:
raise
# We allow people to not be authed if they're just looking at our
diff --git a/synapse/rest/client/room_batch.py b/synapse/rest/client/room_batch.py
index bf14ec38..38ad4c24 100644
--- a/synapse/rest/client/room_batch.py
+++ b/synapse/rest/client/room_batch.py
@@ -15,13 +15,12 @@
import logging
import re
from http import HTTPStatus
-from typing import TYPE_CHECKING, Awaitable, List, Tuple
+from typing import TYPE_CHECKING, Awaitable, Tuple
from twisted.web.server import Request
-from synapse.api.constants import EventContentFields, EventTypes
+from synapse.api.constants import EventContentFields
from synapse.api.errors import AuthError, Codes, SynapseError
-from synapse.appservice import ApplicationService
from synapse.http.server import HttpServer
from synapse.http.servlet import (
RestServlet,
@@ -32,7 +31,7 @@ from synapse.http.servlet import (
)
from synapse.http.site import SynapseRequest
from synapse.rest.client.transactions import HttpTransactionCache
-from synapse.types import JsonDict, Requester, UserID, create_requester
+from synapse.types import JsonDict
from synapse.util.stringutils import random_string
if TYPE_CHECKING:
@@ -77,102 +76,12 @@ class RoomBatchSendEventRestServlet(RestServlet):
def __init__(self, hs: "HomeServer"):
super().__init__()
- self.hs = hs
self.store = hs.get_datastore()
- self.state_store = hs.get_storage().state
self.event_creation_handler = hs.get_event_creation_handler()
- self.room_member_handler = hs.get_room_member_handler()
self.auth = hs.get_auth()
+ self.room_batch_handler = hs.get_room_batch_handler()
self.txns = HttpTransactionCache(hs)
- async def _inherit_depth_from_prev_ids(self, prev_event_ids: List[str]) -> int:
- (
- most_recent_prev_event_id,
- most_recent_prev_event_depth,
- ) = await self.store.get_max_depth_of(prev_event_ids)
-
- # We want to insert the historical event after the `prev_event` but before the successor event
- #
- # We inherit depth from the successor event instead of the `prev_event`
- # because events returned from `/messages` are first sorted by `topological_ordering`
- # which is just the `depth` and then tie-break with `stream_ordering`.
- #
- # We mark these inserted historical events as "backfilled" which gives them a
- # negative `stream_ordering`. If we use the same depth as the `prev_event`,
- # then our historical event will tie-break and be sorted before the `prev_event`
- # when it should come after.
- #
- # We want to use the successor event depth so they appear after `prev_event` because
- # it has a larger `depth` but before the successor event because the `stream_ordering`
- # is negative before the successor event.
- successor_event_ids = await self.store.get_successor_events(
- [most_recent_prev_event_id]
- )
-
- # If we can't find any successor events, then it's a forward extremity of
- # historical messages and we can just inherit from the previous historical
- # event which we can already assume has the correct depth where we want
- # to insert into.
- if not successor_event_ids:
- depth = most_recent_prev_event_depth
- else:
- (
- _,
- oldest_successor_depth,
- ) = await self.store.get_min_depth_of(successor_event_ids)
-
- depth = oldest_successor_depth
-
- return depth
-
- def _create_insertion_event_dict(
- self, sender: str, room_id: str, origin_server_ts: int
- ) -> JsonDict:
- """Creates an event dict for an "insertion" event with the proper fields
- and a random batch ID.
-
- Args:
- sender: The event author MXID
- room_id: The room ID that the event belongs to
- origin_server_ts: Timestamp when the event was sent
-
- Returns:
- The new event dictionary to insert.
- """
-
- next_batch_id = random_string(8)
- insertion_event = {
- "type": EventTypes.MSC2716_INSERTION,
- "sender": sender,
- "room_id": room_id,
- "content": {
- EventContentFields.MSC2716_NEXT_BATCH_ID: next_batch_id,
- EventContentFields.MSC2716_HISTORICAL: True,
- },
- "origin_server_ts": origin_server_ts,
- }
-
- return insertion_event
-
- async def _create_requester_for_user_id_from_app_service(
- self, user_id: str, app_service: ApplicationService
- ) -> Requester:
- """Creates a new requester for the given user_id
- and validates that the app service is allowed to control
- the given user.
-
- Args:
- user_id: The author MXID that the app service is controlling
- app_service: The app service that controls the user
-
- Returns:
- Requester object
- """
-
- await self.auth.validate_appservice_can_control_user_id(app_service, user_id)
-
- return create_requester(user_id, app_service=app_service)
-
async def on_POST(
self, request: SynapseRequest, room_id: str
) -> Tuple[int, JsonDict]:
@@ -200,121 +109,62 @@ class RoomBatchSendEventRestServlet(RestServlet):
errcode=Codes.MISSING_PARAM,
)
+ # Verify the batch_id_from_query corresponds to an actual insertion event
+ # and have the batch connected.
+ if batch_id_from_query:
+ corresponding_insertion_event_id = (
+ await self.store.get_insertion_event_by_batch_id(
+ room_id, batch_id_from_query
+ )
+ )
+ if corresponding_insertion_event_id is None:
+ raise SynapseError(
+ HTTPStatus.BAD_REQUEST,
+ "No insertion event corresponds to the given ?batch_id",
+ errcode=Codes.INVALID_PARAM,
+ )
+
# For the event we are inserting next to (`prev_event_ids_from_query`),
# find the most recent auth events (derived from state events) that
# allowed that message to be sent. We will use that as a base
# to auth our historical messages against.
- (
- most_recent_prev_event_id,
- _,
- ) = await self.store.get_max_depth_of(prev_event_ids_from_query)
- # mapping from (type, state_key) -> state_event_id
- prev_state_map = await self.state_store.get_state_ids_for_event(
- most_recent_prev_event_id
+ auth_event_ids = await self.room_batch_handler.get_most_recent_auth_event_ids_from_event_id_list(
+ prev_event_ids_from_query
)
- # List of state event ID's
- prev_state_ids = list(prev_state_map.values())
- auth_event_ids = prev_state_ids
-
- state_event_ids_at_start = []
- for state_event in body["state_events_at_start"]:
- assert_params_in_dict(
- state_event, ["type", "origin_server_ts", "content", "sender"]
- )
- logger.debug(
- "RoomBatchSendEventRestServlet inserting state_event=%s, auth_event_ids=%s",
- state_event,
- auth_event_ids,
+ # Create and persist all of the state events that float off on their own
+ # before the batch. These will most likely be all of the invite/member
+ # state events used to auth the upcoming historical messages.
+ state_event_ids_at_start = (
+ await self.room_batch_handler.persist_state_events_at_start(
+ state_events_at_start=body["state_events_at_start"],
+ room_id=room_id,
+ initial_auth_event_ids=auth_event_ids,
+ app_service_requester=requester,
)
+ )
+ # Update our ongoing auth event ID list with all of the new state we
+ # just created
+ auth_event_ids.extend(state_event_ids_at_start)
- event_dict = {
- "type": state_event["type"],
- "origin_server_ts": state_event["origin_server_ts"],
- "content": state_event["content"],
- "room_id": room_id,
- "sender": state_event["sender"],
- "state_key": state_event["state_key"],
- }
-
- # Mark all events as historical
- event_dict["content"][EventContentFields.MSC2716_HISTORICAL] = True
-
- # Make the state events float off on their own
- fake_prev_event_id = "$" + random_string(43)
-
- # TODO: This is pretty much the same as some other code to handle inserting state in this file
- if event_dict["type"] == EventTypes.Member:
- membership = event_dict["content"].get("membership", None)
- event_id, _ = await self.room_member_handler.update_membership(
- await self._create_requester_for_user_id_from_app_service(
- state_event["sender"], requester.app_service
- ),
- target=UserID.from_string(event_dict["state_key"]),
- room_id=room_id,
- action=membership,
- content=event_dict["content"],
- outlier=True,
- prev_event_ids=[fake_prev_event_id],
- # Make sure to use a copy of this list because we modify it
- # later in the loop here. Otherwise it will be the same
- # reference and also update in the event when we append later.
- auth_event_ids=auth_event_ids.copy(),
- )
- else:
- # TODO: Add some complement tests that adds state that is not member joins
- # and will use this code path. Maybe we only want to support join state events
- # and can get rid of this `else`?
- (
- event,
- _,
- ) = await self.event_creation_handler.create_and_send_nonmember_event(
- await self._create_requester_for_user_id_from_app_service(
- state_event["sender"], requester.app_service
- ),
- event_dict,
- outlier=True,
- prev_event_ids=[fake_prev_event_id],
- # Make sure to use a copy of this list because we modify it
- # later in the loop here. Otherwise it will be the same
- # reference and also update in the event when we append later.
- auth_event_ids=auth_event_ids.copy(),
- )
- event_id = event.event_id
-
- state_event_ids_at_start.append(event_id)
- auth_event_ids.append(event_id)
-
- events_to_create = body["events"]
-
- inherited_depth = await self._inherit_depth_from_prev_ids(
+ inherited_depth = await self.room_batch_handler.inherit_depth_from_prev_ids(
prev_event_ids_from_query
)
+ events_to_create = body["events"]
+
# Figure out which batch to connect to. If they passed in
# batch_id_from_query let's use it. The batch ID passed in comes
# from the batch_id in the "insertion" event from the previous batch.
last_event_in_batch = events_to_create[-1]
- batch_id_to_connect_to = batch_id_from_query
base_insertion_event = None
if batch_id_from_query:
+ batch_id_to_connect_to = batch_id_from_query
# All but the first base insertion event should point at a fake
# event, which causes the HS to ask for the state at the start of
# the batch later.
+ fake_prev_event_id = "$" + random_string(43)
prev_event_ids = [fake_prev_event_id]
-
- # Verify the batch_id_from_query corresponds to an actual insertion event
- # and have the batch connected.
- corresponding_insertion_event_id = (
- await self.store.get_insertion_event_by_batch_id(batch_id_from_query)
- )
- if corresponding_insertion_event_id is None:
- raise SynapseError(
- 400,
- "No insertion event corresponds to the given ?batch_id",
- errcode=Codes.INVALID_PARAM,
- )
- pass
# Otherwise, create an insertion event to act as a starting point.
#
# We don't always have an insertion event to start hanging more history
@@ -325,10 +175,12 @@ class RoomBatchSendEventRestServlet(RestServlet):
else:
prev_event_ids = prev_event_ids_from_query
- base_insertion_event_dict = self._create_insertion_event_dict(
- sender=requester.user.to_string(),
- room_id=room_id,
- origin_server_ts=last_event_in_batch["origin_server_ts"],
+ base_insertion_event_dict = (
+ self.room_batch_handler.create_insertion_event_dict(
+ sender=requester.user.to_string(),
+ room_id=room_id,
+ origin_server_ts=last_event_in_batch["origin_server_ts"],
+ )
)
base_insertion_event_dict["prev_events"] = prev_event_ids.copy()
@@ -336,7 +188,7 @@ class RoomBatchSendEventRestServlet(RestServlet):
base_insertion_event,
_,
) = await self.event_creation_handler.create_and_send_nonmember_event(
- await self._create_requester_for_user_id_from_app_service(
+ await self.room_batch_handler.create_requester_for_user_id_from_app_service(
base_insertion_event_dict["sender"],
requester.app_service,
),
@@ -351,92 +203,17 @@ class RoomBatchSendEventRestServlet(RestServlet):
EventContentFields.MSC2716_NEXT_BATCH_ID
]
- # Connect this current batch to the insertion event from the previous batch
- batch_event = {
- "type": EventTypes.MSC2716_BATCH,
- "sender": requester.user.to_string(),
- "room_id": room_id,
- "content": {
- EventContentFields.MSC2716_BATCH_ID: batch_id_to_connect_to,
- EventContentFields.MSC2716_HISTORICAL: True,
- },
- # Since the batch event is put at the end of the batch,
- # where the newest-in-time event is, copy the origin_server_ts from
- # the last event we're inserting
- "origin_server_ts": last_event_in_batch["origin_server_ts"],
- }
- # Add the batch event to the end of the batch (newest-in-time)
- events_to_create.append(batch_event)
-
- # Add an "insertion" event to the start of each batch (next to the oldest-in-time
- # event in the batch) so the next batch can be connected to this one.
- insertion_event = self._create_insertion_event_dict(
- sender=requester.user.to_string(),
+ # Create and persist all of the historical events as well as insertion
+ # and batch meta events to make the batch navigable in the DAG.
+ event_ids, next_batch_id = await self.room_batch_handler.handle_batch_of_events(
+ events_to_create=events_to_create,
room_id=room_id,
- # Since the insertion event is put at the start of the batch,
- # where the oldest-in-time event is, copy the origin_server_ts from
- # the first event we're inserting
- origin_server_ts=events_to_create[0]["origin_server_ts"],
+ batch_id_to_connect_to=batch_id_to_connect_to,
+ initial_prev_event_ids=prev_event_ids,
+ inherited_depth=inherited_depth,
+ auth_event_ids=auth_event_ids,
+ app_service_requester=requester,
)
- # Prepend the insertion event to the start of the batch (oldest-in-time)
- events_to_create = [insertion_event] + events_to_create
-
- event_ids = []
- events_to_persist = []
- for ev in events_to_create:
- assert_params_in_dict(ev, ["type", "origin_server_ts", "content", "sender"])
-
- event_dict = {
- "type": ev["type"],
- "origin_server_ts": ev["origin_server_ts"],
- "content": ev["content"],
- "room_id": room_id,
- "sender": ev["sender"], # requester.user.to_string(),
- "prev_events": prev_event_ids.copy(),
- }
-
- # Mark all events as historical
- event_dict["content"][EventContentFields.MSC2716_HISTORICAL] = True
-
- event, context = await self.event_creation_handler.create_event(
- await self._create_requester_for_user_id_from_app_service(
- ev["sender"], requester.app_service
- ),
- event_dict,
- prev_event_ids=event_dict.get("prev_events"),
- auth_event_ids=auth_event_ids,
- historical=True,
- depth=inherited_depth,
- )
- logger.debug(
- "RoomBatchSendEventRestServlet inserting event=%s, prev_event_ids=%s, auth_event_ids=%s",
- event,
- prev_event_ids,
- auth_event_ids,
- )
-
- assert self.hs.is_mine_id(event.sender), "User must be our own: %s" % (
- event.sender,
- )
-
- events_to_persist.append((event, context))
- event_id = event.event_id
-
- event_ids.append(event_id)
- prev_event_ids = [event_id]
-
- # Persist events in reverse-chronological order so they have the
- # correct stream_ordering as they are backfilled (which decrements).
- # Events are sorted by (topological_ordering, stream_ordering)
- # where topological_ordering is just depth.
- for (event, context) in reversed(events_to_persist):
- ev = await self.event_creation_handler.handle_new_client_event(
- await self._create_requester_for_user_id_from_app_service(
- event["sender"], requester.app_service
- ),
- event=event,
- context=context,
- )
insertion_event_id = event_ids[0]
batch_event_id = event_ids[-1]
@@ -445,9 +222,7 @@ class RoomBatchSendEventRestServlet(RestServlet):
response_dict = {
"state_event_ids": state_event_ids_at_start,
"event_ids": historical_event_ids,
- "next_batch_id": insertion_event["content"][
- EventContentFields.MSC2716_NEXT_BATCH_ID
- ],
+ "next_batch_id": next_batch_id,
"insertion_event_id": insertion_event_id,
"batch_event_id": batch_event_id,
}
diff --git a/synapse/rest/client/shared_rooms.py b/synapse/rest/client/shared_rooms.py
index 1d90493e..09a46737 100644
--- a/synapse/rest/client/shared_rooms.py
+++ b/synapse/rest/client/shared_rooms.py
@@ -42,7 +42,7 @@ class UserSharedRoomsServlet(RestServlet):
super().__init__()
self.auth = hs.get_auth()
self.store = hs.get_datastore()
- self.user_directory_active = hs.config.update_user_directory
+ self.user_directory_active = hs.config.server.update_user_directory
async def on_GET(
self, request: SynapseRequest, user_id: str
diff --git a/synapse/rest/client/sync.py b/synapse/rest/client/sync.py
index 1259058b..913216a7 100644
--- a/synapse/rest/client/sync.py
+++ b/synapse/rest/client/sync.py
@@ -155,7 +155,7 @@ class SyncRestServlet(RestServlet):
try:
filter_object = json_decoder.decode(filter_id)
set_timeline_upper_limit(
- filter_object, self.hs.config.filter_timeline_limit
+ filter_object, self.hs.config.server.filter_timeline_limit
)
except Exception:
raise SynapseError(400, "Invalid filter JSON")
diff --git a/synapse/rest/client/voip.py b/synapse/rest/client/voip.py
index ea2b8aa4..ea7e0251 100644
--- a/synapse/rest/client/voip.py
+++ b/synapse/rest/client/voip.py
@@ -70,7 +70,7 @@ class VoipRestServlet(RestServlet):
{
"username": username,
"password": password,
- "ttl": userLifetime / 1000,
+ "ttl": userLifetime // 1000,
"uris": turnUris,
},
)
diff --git a/synapse/rest/media/v1/__init__.py b/synapse/rest/media/v1/__init__.py
index 3dd16d4b..d5b74cdd 100644
--- a/synapse/rest/media/v1/__init__.py
+++ b/synapse/rest/media/v1/__init__.py
@@ -12,33 +12,21 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import PIL.Image
+from PIL.features import check_codec
# check for JPEG support.
-try:
- PIL.Image._getdecoder("rgb", "jpeg", None)
-except OSError as e:
- if str(e).startswith("decoder jpeg not available"):
- raise Exception(
- "FATAL: jpeg codec not supported. Install pillow correctly! "
- " 'sudo apt-get install libjpeg-dev' then 'pip uninstall pillow &&"
- " pip install pillow --user'"
- )
-except Exception:
- # any other exception is fine
- pass
+if not check_codec("jpg"):
+ raise Exception(
+ "FATAL: jpeg codec not supported. Install pillow correctly! "
+ " 'sudo apt-get install libjpeg-dev' then 'pip uninstall pillow &&"
+ " pip install pillow --user'"
+ )
# check for PNG support.
-try:
- PIL.Image._getdecoder("rgb", "zip", None)
-except OSError as e:
- if str(e).startswith("decoder zip not available"):
- raise Exception(
- "FATAL: zip codec not supported. Install pillow correctly! "
- " 'sudo apt-get install libjpeg-dev' then 'pip uninstall pillow &&"
- " pip install pillow --user'"
- )
-except Exception:
- # any other exception is fine
- pass
+if not check_codec("zlib"):
+ raise Exception(
+ "FATAL: zip codec not supported. Install pillow correctly! "
+ " 'sudo apt-get install libjpeg-dev' then 'pip uninstall pillow &&"
+ " pip install pillow --user'"
+ )
diff --git a/synapse/rest/media/v1/oembed.py b/synapse/rest/media/v1/oembed.py
index e04671fb..78b1603f 100644
--- a/synapse/rest/media/v1/oembed.py
+++ b/synapse/rest/media/v1/oembed.py
@@ -96,6 +96,32 @@ class OEmbedProvider:
# No match.
return None
+ def autodiscover_from_html(self, tree: "etree.Element") -> Optional[str]:
+ """
+ Search an HTML document for oEmbed autodiscovery information.
+
+ Args:
+ tree: The parsed HTML body.
+
+ Returns:
+ The URL to use for oEmbed information, or None if no URL was found.
+ """
+ # Search for link elements with the proper rel and type attributes.
+ for tag in tree.xpath(
+ "//link[@rel='alternate'][@type='application/json+oembed']"
+ ):
+ if "href" in tag.attrib:
+ return tag.attrib["href"]
+
+ # Some providers (e.g. Flickr) use alternative instead of alternate.
+ for tag in tree.xpath(
+ "//link[@rel='alternative'][@type='application/json+oembed']"
+ ):
+ if "href" in tag.attrib:
+ return tag.attrib["href"]
+
+ return None
+
def parse_oembed_response(self, url: str, raw_body: bytes) -> OEmbedResult:
"""
Parse the oEmbed response into an Open Graph response.
@@ -165,7 +191,7 @@ class OEmbedProvider:
except Exception as e:
# Trap any exception and let the code follow as usual.
- logger.warning(f"Error parsing oEmbed metadata from {url}: {e:r}")
+ logger.warning("Error parsing oEmbed metadata from %s: %r", url, e)
open_graph_response = {}
cache_age = None
diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py
index 79a42b24..1fe0fc8a 100644
--- a/synapse/rest/media/v1/preview_url_resource.py
+++ b/synapse/rest/media/v1/preview_url_resource.py
@@ -22,7 +22,7 @@ import re
import shutil
import sys
import traceback
-from typing import TYPE_CHECKING, Dict, Generator, Iterable, Optional, Union
+from typing import TYPE_CHECKING, Dict, Generator, Iterable, Optional, Tuple, Union
from urllib import parse as urlparse
import attr
@@ -73,6 +73,7 @@ OG_TAG_VALUE_MAXLEN = 1000
ONE_HOUR = 60 * 60 * 1000
ONE_DAY = 24 * ONE_HOUR
+IMAGE_CACHE_EXPIRY_MS = 2 * ONE_DAY
@attr.s(slots=True, frozen=True, auto_attribs=True)
@@ -295,22 +296,32 @@ class PreviewUrlResource(DirectServeJsonResource):
body = file.read()
encoding = get_html_media_encoding(body, media_info.media_type)
- og = decode_and_calc_og(body, media_info.uri, encoding)
-
- await self._precache_image_url(user, media_info, og)
-
- elif oembed_url and _is_json(media_info.media_type):
- # Handle an oEmbed response.
- with open(media_info.filename, "rb") as file:
- body = file.read()
-
- oembed_response = self._oembed.parse_oembed_response(url, body)
- og = oembed_response.open_graph_result
-
- # Use the cache age from the oEmbed result, instead of the HTTP response.
- if oembed_response.cache_age is not None:
- expiration_ms = oembed_response.cache_age
+ tree = decode_body(body, encoding)
+ if tree is not None:
+ # Check if this HTML document points to oEmbed information and
+ # defer to that.
+ oembed_url = self._oembed.autodiscover_from_html(tree)
+ og = {}
+ if oembed_url:
+ oembed_info = await self._download_url(oembed_url, user)
+ og, expiration_ms = await self._handle_oembed_response(
+ url, oembed_info, expiration_ms
+ )
+
+ # If there was no oEmbed URL (or oEmbed parsing failed), attempt
+ # to generate the Open Graph information from the HTML.
+ if not oembed_url or not og:
+ og = _calc_og(tree, media_info.uri)
+
+ await self._precache_image_url(user, media_info, og)
+ else:
+ og = {}
+ elif oembed_url:
+ # Handle the oEmbed information.
+ og, expiration_ms = await self._handle_oembed_response(
+ url, media_info, expiration_ms
+ )
await self._precache_image_url(user, media_info, og)
else:
@@ -478,6 +489,39 @@ class PreviewUrlResource(DirectServeJsonResource):
else:
del og["og:image"]
+ async def _handle_oembed_response(
+ self, url: str, media_info: MediaInfo, expiration_ms: int
+ ) -> Tuple[JsonDict, int]:
+ """
+ Parse the downloaded oEmbed info.
+
+ Args:
+ url: The URL which is being previewed (not the one which was
+ requested).
+ media_info: The media being previewed.
+ expiration_ms: The length of time, in milliseconds, the media is valid for.
+
+ Returns:
+ A tuple of:
+ The Open Graph dictionary, if the oEmbed info can be parsed.
+ The (possibly updated) length of time, in milliseconds, the media is valid for.
+ """
+ # If JSON was not returned, there's nothing to do.
+ if not _is_json(media_info.media_type):
+ return {}, expiration_ms
+
+ with open(media_info.filename, "rb") as file:
+ body = file.read()
+
+ oembed_response = self._oembed.parse_oembed_response(url, body)
+ open_graph_result = oembed_response.open_graph_result
+
+ # Use the cache age from the oEmbed result, if one was given.
+ if open_graph_result and oembed_response.cache_age is not None:
+ expiration_ms = oembed_response.cache_age
+
+ return open_graph_result, expiration_ms
+
def _start_expire_url_cache_data(self) -> Deferred:
return run_as_background_process(
"expire_url_cache_data", self._expire_url_cache_data
@@ -496,6 +540,27 @@ class PreviewUrlResource(DirectServeJsonResource):
logger.info("Still running DB updates; skipping expiry")
return
+ def try_remove_parent_dirs(dirs: Iterable[str]) -> None:
+ """Attempt to remove the given chain of parent directories
+
+ Args:
+ dirs: The list of directory paths to delete, with children appearing
+ before their parents.
+ """
+ for dir in dirs:
+ try:
+ os.rmdir(dir)
+ except FileNotFoundError:
+ # Already deleted, continue with deleting the rest
+ pass
+ except OSError as e:
+ # Failed, skip deleting the rest of the parent dirs
+ if e.errno != errno.ENOTEMPTY:
+ logger.warning(
+ "Failed to remove media directory: %r: %s", dir, e
+ )
+ break
+
# First we delete expired url cache entries
media_ids = await self.store.get_expired_url_cache(now)
@@ -504,20 +569,16 @@ class PreviewUrlResource(DirectServeJsonResource):
fname = self.filepaths.url_cache_filepath(media_id)
try:
os.remove(fname)
+ except FileNotFoundError:
+ pass # If the path doesn't exist, meh
except OSError as e:
- # If the path doesn't exist, meh
- if e.errno != errno.ENOENT:
- logger.warning("Failed to remove media: %r: %s", media_id, e)
- continue
+ logger.warning("Failed to remove media: %r: %s", media_id, e)
+ continue
removed_media.append(media_id)
- try:
- dirs = self.filepaths.url_cache_filepath_dirs_to_delete(media_id)
- for dir in dirs:
- os.rmdir(dir)
- except Exception:
- pass
+ dirs = self.filepaths.url_cache_filepath_dirs_to_delete(media_id)
+ try_remove_parent_dirs(dirs)
await self.store.delete_url_cache(removed_media)
@@ -530,7 +591,7 @@ class PreviewUrlResource(DirectServeJsonResource):
# These may be cached for a bit on the client (i.e., they
# may have a room open with a preview url thing open).
# So we wait a couple of days before deleting, just in case.
- expire_before = now - 2 * ONE_DAY
+ expire_before = now - IMAGE_CACHE_EXPIRY_MS
media_ids = await self.store.get_url_cache_media_before(expire_before)
removed_media = []
@@ -538,36 +599,30 @@ class PreviewUrlResource(DirectServeJsonResource):
fname = self.filepaths.url_cache_filepath(media_id)
try:
os.remove(fname)
+ except FileNotFoundError:
+ pass # If the path doesn't exist, meh
except OSError as e:
- # If the path doesn't exist, meh
- if e.errno != errno.ENOENT:
- logger.warning("Failed to remove media: %r: %s", media_id, e)
- continue
+ logger.warning("Failed to remove media: %r: %s", media_id, e)
+ continue
- try:
- dirs = self.filepaths.url_cache_filepath_dirs_to_delete(media_id)
- for dir in dirs:
- os.rmdir(dir)
- except Exception:
- pass
+ dirs = self.filepaths.url_cache_filepath_dirs_to_delete(media_id)
+ try_remove_parent_dirs(dirs)
thumbnail_dir = self.filepaths.url_cache_thumbnail_directory(media_id)
try:
shutil.rmtree(thumbnail_dir)
+ except FileNotFoundError:
+ pass # If the path doesn't exist, meh
except OSError as e:
- # If the path doesn't exist, meh
- if e.errno != errno.ENOENT:
- logger.warning("Failed to remove media: %r: %s", media_id, e)
- continue
+ logger.warning("Failed to remove media: %r: %s", media_id, e)
+ continue
removed_media.append(media_id)
- try:
- dirs = self.filepaths.url_cache_thumbnail_dirs_to_delete(media_id)
- for dir in dirs:
- os.rmdir(dir)
- except Exception:
- pass
+ dirs = self.filepaths.url_cache_thumbnail_dirs_to_delete(media_id)
+ # Note that one of the directories to be deleted has already been
+ # removed by the `rmtree` above.
+ try_remove_parent_dirs(dirs)
await self.store.delete_url_cache_media(removed_media)
@@ -619,26 +674,22 @@ def get_html_media_encoding(body: bytes, content_type: str) -> str:
return "utf-8"
-def decode_and_calc_og(
- body: bytes, media_uri: str, request_encoding: Optional[str] = None
-) -> JsonDict:
+def decode_body(
+ body: bytes, request_encoding: Optional[str] = None
+) -> Optional["etree.Element"]:
"""
- Calculate metadata for an HTML document.
-
- This uses lxml to parse the HTML document into the OG response. If errors
- occur during processing of the document, an empty response is returned.
+ This uses lxml to parse the HTML document.
Args:
body: The HTML document, as bytes.
- media_url: The URI used to download the body.
request_encoding: The character encoding of the body, as a string.
Returns:
- The OG response as a dictionary.
+ The parsed HTML body, or None if an error occurred during processed.
"""
# If there's no body, nothing useful is going to be found.
if not body:
- return {}
+ return None
from lxml import etree
@@ -650,25 +701,22 @@ def decode_and_calc_og(
parser = etree.HTMLParser(recover=True, encoding="utf-8")
except Exception as e:
logger.warning("Unable to create HTML parser: %s" % (e,))
- return {}
-
- def _attempt_calc_og(body_attempt: Union[bytes, str]) -> Dict[str, Optional[str]]:
- # Attempt to parse the body. If this fails, log and return no metadata.
- tree = etree.fromstring(body_attempt, parser)
-
- # The data was successfully parsed, but no tree was found.
- if tree is None:
- return {}
+ return None
- return _calc_og(tree, media_uri)
+ def _attempt_decode_body(
+ body_attempt: Union[bytes, str]
+ ) -> Optional["etree.Element"]:
+ # Attempt to parse the body. Returns None if the body was successfully
+ # parsed, but no tree was found.
+ return etree.fromstring(body_attempt, parser)
# Attempt to parse the body. If this fails, log and return no metadata.
try:
- return _attempt_calc_og(body)
+ return _attempt_decode_body(body)
except UnicodeDecodeError:
# blindly try decoding the body as utf-8, which seems to fix
# the charset mismatches on https://google.com
- return _attempt_calc_og(body.decode("utf-8", "ignore"))
+ return _attempt_decode_body(body.decode("utf-8", "ignore"))
def _calc_og(tree: "etree.Element", media_uri: str) -> Dict[str, Optional[str]]:
diff --git a/synapse/rest/media/v1/thumbnailer.py b/synapse/rest/media/v1/thumbnailer.py
index df54a406..46701a8b 100644
--- a/synapse/rest/media/v1/thumbnailer.py
+++ b/synapse/rest/media/v1/thumbnailer.py
@@ -61,9 +61,19 @@ class Thumbnailer:
self.transpose_method = None
try:
# We don't use ImageOps.exif_transpose since it crashes with big EXIF
- image_exif = self.image._getexif()
+ #
+ # Ignore safety: Pillow seems to acknowledge that this method is
+ # "private, experimental, but generally widely used". Pillow 6
+ # includes a public getexif() method (no underscore) that we might
+ # consider using instead when we can bump that dependency.
+ #
+ # At the time of writing, Debian buster (currently oldstable)
+ # provides version 5.4.1. It's expected to EOL in mid-2022, see
+ # https://wiki.debian.org/DebianReleases#Production_Releases
+ image_exif = self.image._getexif() # type: ignore
if image_exif is not None:
image_orientation = image_exif.get(EXIF_ORIENTATION_TAG)
+ assert isinstance(image_orientation, int)
self.transpose_method = EXIF_TRANSPOSE_MAPPINGS.get(image_orientation)
except Exception as e:
# A lot of parsing errors can happen when parsing EXIF
@@ -76,7 +86,10 @@ class Thumbnailer:
A tuple containing the new image size in pixels as (width, height).
"""
if self.transpose_method is not None:
- self.image = self.image.transpose(self.transpose_method)
+ # Safety: `transpose` takes an int rather than e.g. an IntEnum.
+ # self.transpose_method is set above to be a value in
+ # EXIF_TRANSPOSE_MAPPINGS, and that only contains correct values.
+ self.image = self.image.transpose(self.transpose_method) # type: ignore[arg-type]
self.width, self.height = self.image.size
self.transpose_method = None
# We don't need EXIF any more
@@ -101,7 +114,7 @@ class Thumbnailer:
else:
return (max_height * self.width) // self.height, max_height
- def _resize(self, width: int, height: int) -> Image:
+ def _resize(self, width: int, height: int) -> Image.Image:
# 1-bit or 8-bit color palette images need converting to RGB
# otherwise they will be scaled using nearest neighbour which
# looks awful.
@@ -151,7 +164,7 @@ class Thumbnailer:
cropped = scaled_image.crop((crop_left, 0, crop_right, height))
return self._encode_image(cropped, output_type)
- def _encode_image(self, output_image: Image, output_type: str) -> BytesIO:
+ def _encode_image(self, output_image: Image.Image, output_type: str) -> BytesIO:
output_bytes_io = BytesIO()
fmt = self.FORMATS[output_type]
if fmt == "JPEG":
diff --git a/synapse/rest/well_known.py b/synapse/rest/well_known.py
index c80a3a99..7ac01faa 100644
--- a/synapse/rest/well_known.py
+++ b/synapse/rest/well_known.py
@@ -39,9 +39,9 @@ class WellKnownBuilder:
result = {"m.homeserver": {"base_url": self._config.server.public_baseurl}}
- if self._config.default_identity_server:
+ if self._config.registration.default_identity_server:
result["m.identity_server"] = {
- "base_url": self._config.default_identity_server
+ "base_url": self._config.registration.default_identity_server
}
return result
diff --git a/synapse/server.py b/synapse/server.py
index 637eb15b..5bc045d6 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -39,7 +39,7 @@ from twisted.web.resource import IResource
from synapse.api.auth import Auth
from synapse.api.filtering import Filtering
-from synapse.api.ratelimiting import Ratelimiter
+from synapse.api.ratelimiting import Ratelimiter, RequestRatelimiter
from synapse.appservice.api import ApplicationServiceApi
from synapse.appservice.scheduler import ApplicationServiceScheduler
from synapse.config.homeserver import HomeServerConfig
@@ -97,6 +97,7 @@ from synapse.handlers.room import (
RoomCreationHandler,
RoomShutdownHandler,
)
+from synapse.handlers.room_batch import RoomBatchHandler
from synapse.handlers.room_list import RoomListHandler
from synapse.handlers.room_member import RoomMemberHandler, RoomMemberMasterHandler
from synapse.handlers.room_member_worker import RoomMemberWorkerHandler
@@ -438,6 +439,10 @@ class HomeServer(metaclass=abc.ABCMeta):
return RoomCreationHandler(self)
@cache_in_self
+ def get_room_batch_handler(self) -> RoomBatchHandler:
+ return RoomBatchHandler(self)
+
+ @cache_in_self
def get_room_shutdown_handler(self) -> RoomShutdownHandler:
return RoomShutdownHandler(self)
@@ -816,3 +821,12 @@ class HomeServer(metaclass=abc.ABCMeta):
def should_send_federation(self) -> bool:
"Should this server be sending federation traffic directly?"
return self.config.worker.send_federation
+
+ @cache_in_self
+ def get_request_ratelimiter(self) -> RequestRatelimiter:
+ return RequestRatelimiter(
+ self.get_datastore(),
+ self.get_clock(),
+ self.config.ratelimiting.rc_message,
+ self.config.ratelimiting.rc_admin_redaction,
+ )
diff --git a/synapse/server_notices/resource_limits_server_notices.py b/synapse/server_notices/resource_limits_server_notices.py
index 073b0d75..8522930b 100644
--- a/synapse/server_notices/resource_limits_server_notices.py
+++ b/synapse/server_notices/resource_limits_server_notices.py
@@ -47,9 +47,9 @@ class ResourceLimitsServerNotices:
self._notifier = hs.get_notifier()
self._enabled = (
- hs.config.limit_usage_by_mau
+ hs.config.server.limit_usage_by_mau
and self._server_notices_manager.is_enabled()
- and not hs.config.hs_disabled
+ and not hs.config.server.hs_disabled
)
async def maybe_send_server_notice_to_user(self, user_id: str) -> None:
@@ -98,7 +98,7 @@ class ResourceLimitsServerNotices:
try:
if (
limit_type == LimitBlockingTypes.MONTHLY_ACTIVE_USER
- and not self._config.mau_limit_alerting
+ and not self._config.server.mau_limit_alerting
):
# We have hit the MAU limit, but MAU alerting is disabled:
# reset room if necessary and return
@@ -149,7 +149,7 @@ class ResourceLimitsServerNotices:
"body": event_body,
"msgtype": ServerNoticeMsgType,
"server_notice_type": ServerNoticeLimitReached,
- "admin_contact": self._config.admin_contact,
+ "admin_contact": self._config.server.admin_contact,
"limit_type": event_limit_type,
}
event = await self._server_notices_manager.send_notice(
diff --git a/synapse/server_notices/server_notices_manager.py b/synapse/server_notices/server_notices_manager.py
index cd1c5ff6..0cf60236 100644
--- a/synapse/server_notices/server_notices_manager.py
+++ b/synapse/server_notices/server_notices_manager.py
@@ -41,12 +41,8 @@ class ServerNoticesManager:
self._notifier = hs.get_notifier()
self.server_notices_mxid = self._config.servernotices.server_notices_mxid
- def is_enabled(self):
- """Checks if server notices are enabled on this server.
-
- Returns:
- bool
- """
+ def is_enabled(self) -> bool:
+ """Checks if server notices are enabled on this server."""
return self.server_notices_mxid is not None
async def send_notice(
diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py
index c981df3f..5cf2e125 100644
--- a/synapse/state/__init__.py
+++ b/synapse/state/__init__.py
@@ -118,7 +118,7 @@ class _StateCacheEntry:
else:
self.state_id = _gen_state_id()
- def __len__(self):
+ def __len__(self) -> int:
return len(self.state)
diff --git a/synapse/state/v1.py b/synapse/state/v1.py
index 92336d7c..ffe6207a 100644
--- a/synapse/state/v1.py
+++ b/synapse/state/v1.py
@@ -225,7 +225,7 @@ def _resolve_with_state(
conflicted_state_ids: StateMap[Set[str]],
auth_event_ids: StateMap[str],
state_map: Dict[str, EventBase],
-):
+) -> MutableStateMap[str]:
conflicted_state = {}
for key, event_ids in conflicted_state_ids.items():
events = [state_map[ev_id] for ev_id in event_ids if ev_id in state_map]
@@ -329,12 +329,10 @@ def _resolve_auth_events(
auth_events[(prev_event.type, prev_event.state_key)] = prev_event
try:
# The signatures have already been checked at this point
- event_auth.check(
+ event_auth.check_auth_rules_for_event(
RoomVersions.V1,
event,
auth_events,
- do_sig_check=False,
- do_size_check=False,
)
prev_event = event
except AuthError:
@@ -349,12 +347,10 @@ def _resolve_normal_events(
for event in _ordered_events(events):
try:
# The signatures have already been checked at this point
- event_auth.check(
+ event_auth.check_auth_rules_for_event(
RoomVersions.V1,
event,
auth_events,
- do_sig_check=False,
- do_size_check=False,
)
return event
except AuthError:
@@ -366,7 +362,7 @@ def _resolve_normal_events(
def _ordered_events(events: Iterable[EventBase]) -> List[EventBase]:
- def key_func(e):
+ def key_func(e: EventBase) -> Tuple[int, str]:
# we have to use utf-8 rather than ascii here because it turns out we allow
# people to send us events with non-ascii event IDs :/
return -int(e.depth), hashlib.sha1(e.event_id.encode("utf-8")).hexdigest()
diff --git a/synapse/state/v2.py b/synapse/state/v2.py
index 7b1e8361..bd18eefd 100644
--- a/synapse/state/v2.py
+++ b/synapse/state/v2.py
@@ -481,7 +481,7 @@ async def _reverse_topological_power_sort(
if idx % _AWAIT_AFTER_ITERATIONS == 0:
await clock.sleep(0)
- def _get_power_order(event_id):
+ def _get_power_order(event_id: str) -> Tuple[int, int, str]:
ev = event_map[event_id]
pl = event_to_pl[event_id]
@@ -546,12 +546,10 @@ async def _iterative_auth_checks(
auth_events[key] = event_map[ev_id]
try:
- event_auth.check(
+ event_auth.check_auth_rules_for_event(
room_version,
event,
auth_events,
- do_sig_check=False,
- do_size_check=False,
)
resolved_state[(event.type, event.state_key)] = event_id
diff --git a/synapse/storage/databases/main/censor_events.py b/synapse/storage/databases/main/censor_events.py
index 6305414e..eee07227 100644
--- a/synapse/storage/databases/main/censor_events.py
+++ b/synapse/storage/databases/main/censor_events.py
@@ -36,7 +36,7 @@ class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBase
if (
hs.config.worker.run_background_tasks
- and self.hs.config.redaction_retention_period is not None
+ and self.hs.config.server.redaction_retention_period is not None
):
hs.get_clock().looping_call(self._censor_redactions, 5 * 60 * 1000)
@@ -48,7 +48,7 @@ class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBase
By censor we mean update the event_json table with the redacted event.
"""
- if self.hs.config.redaction_retention_period is None:
+ if self.hs.config.server.redaction_retention_period is None:
return
if not (
@@ -60,7 +60,9 @@ class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBase
# created.
return
- before_ts = self._clock.time_msec() - self.hs.config.redaction_retention_period
+ before_ts = (
+ self._clock.time_msec() - self.hs.config.server.redaction_retention_period
+ )
# We fetch all redactions that:
# 1. point to an event we have,
diff --git a/synapse/storage/databases/main/client_ips.py b/synapse/storage/databases/main/client_ips.py
index cc192f5c..6c1ef090 100644
--- a/synapse/storage/databases/main/client_ips.py
+++ b/synapse/storage/databases/main/client_ips.py
@@ -353,7 +353,7 @@ class ClientIpWorkerStore(ClientIpBackgroundUpdateStore):
def __init__(self, database: DatabasePool, db_conn, hs):
super().__init__(database, db_conn, hs)
- self.user_ips_max_age = hs.config.user_ips_max_age
+ self.user_ips_max_age = hs.config.server.user_ips_max_age
if hs.config.worker.run_background_tasks and self.user_ips_max_age:
self._clock.looping_call(self._prune_old_user_ips, 5 * 1000)
@@ -538,15 +538,20 @@ class ClientIpStore(ClientIpWorkerStore):
"""
ret = await super().get_last_client_ip_by_device(user_id, device_id)
- # Update what is retrieved from the database with data which is pending insertion.
+ # Update what is retrieved from the database with data which is pending
+ # insertion, as if it has already been stored in the database.
for key in self._batch_row_update:
- uid, access_token, ip = key
+ uid, _access_token, ip = key
if uid == user_id:
user_agent, did, last_seen = self._batch_row_update[key]
+
+ if did is None:
+ # These updates don't make it to the `devices` table
+ continue
+
if not device_id or did == device_id:
- ret[(user_id, device_id)] = {
+ ret[(user_id, did)] = {
"user_id": user_id,
- "access_token": access_token,
"ip": ip,
"user_agent": user_agent,
"device_id": did,
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 584f818f..19f55c19 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -104,7 +104,7 @@ class PersistEventsStore:
self._clock = hs.get_clock()
self._instance_name = hs.get_instance_name()
- self._ephemeral_messages_enabled = hs.config.enable_ephemeral_messages
+ self._ephemeral_messages_enabled = hs.config.server.enable_ephemeral_messages
self.is_mine_id = hs.is_mine_id
# Ideally we'd move these ID gens here, unfortunately some other ID
@@ -1276,13 +1276,6 @@ class PersistEventsStore:
logger.exception("")
raise
- # update the stored internal_metadata to update the "outlier" flag.
- # TODO: This is unused as of Synapse 1.31. Remove it once we are happy
- # to drop backwards-compatibility with 1.30.
- metadata_json = json_encoder.encode(event.internal_metadata.get_dict())
- sql = "UPDATE event_json SET internal_metadata = ? WHERE event_id = ?"
- txn.execute(sql, (metadata_json, event.event_id))
-
# Add an entry to the ex_outlier_stream table to replicate the
# change in outlier status to our workers.
stream_order = event.internal_metadata.stream_ordering
@@ -1327,19 +1320,6 @@ class PersistEventsStore:
d.pop("redacted_because", None)
return d
- def get_internal_metadata(event):
- im = event.internal_metadata.get_dict()
-
- # temporary hack for database compatibility with Synapse 1.30 and earlier:
- # store the `outlier` flag inside the internal_metadata json as well as in
- # the `events` table, so that if anyone rolls back to an older Synapse,
- # things keep working. This can be removed once we are happy to drop support
- # for that
- if event.internal_metadata.is_outlier():
- im["outlier"] = True
-
- return im
-
self.db_pool.simple_insert_many_txn(
txn,
table="event_json",
@@ -1348,7 +1328,7 @@ class PersistEventsStore:
"event_id": event.event_id,
"room_id": event.room_id,
"internal_metadata": json_encoder.encode(
- get_internal_metadata(event)
+ event.internal_metadata.get_dict()
),
"json": json_encoder.encode(event_dict(event)),
"format_version": event.format_version,
@@ -1783,9 +1763,8 @@ class PersistEventsStore:
retcol="creator",
allow_none=True,
)
- if (
- not room_version.msc2716_historical
- or not self.hs.config.experimental.msc2716_enabled
+ if not room_version.msc2716_historical and (
+ not self.hs.config.experimental.msc2716_enabled
or event.sender != room_creator
):
return
@@ -1845,9 +1824,8 @@ class PersistEventsStore:
retcol="creator",
allow_none=True,
)
- if (
- not room_version.msc2716_historical
- or not self.hs.config.experimental.msc2716_enabled
+ if not room_version.msc2716_historical and (
+ not self.hs.config.experimental.msc2716_enabled
or event.sender != room_creator
):
return
diff --git a/synapse/storage/databases/main/filtering.py b/synapse/storage/databases/main/filtering.py
index bb244a03..434986fa 100644
--- a/synapse/storage/databases/main/filtering.py
+++ b/synapse/storage/databases/main/filtering.py
@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from typing import Union
+
from canonicaljson import encode_canonical_json
from synapse.api.errors import Codes, SynapseError
@@ -22,7 +24,9 @@ from synapse.util.caches.descriptors import cached
class FilteringStore(SQLBaseStore):
@cached(num_args=2)
- async def get_user_filter(self, user_localpart, filter_id):
+ async def get_user_filter(
+ self, user_localpart: str, filter_id: Union[int, str]
+ ) -> JsonDict:
# filter_id is BIGINT UNSIGNED, so if it isn't a number, fail
# with a coherent error message rather than 500 M_UNKNOWN.
try:
@@ -40,7 +44,7 @@ class FilteringStore(SQLBaseStore):
return db_to_json(def_json)
- async def add_user_filter(self, user_localpart: str, user_filter: JsonDict) -> str:
+ async def add_user_filter(self, user_localpart: str, user_filter: JsonDict) -> int:
def_json = encode_canonical_json(user_filter)
# Need an atomic transaction to SELECT the maximal ID so far then
diff --git a/synapse/storage/databases/main/monthly_active_users.py b/synapse/storage/databases/main/monthly_active_users.py
index b76ee51a..ec4d47a5 100644
--- a/synapse/storage/databases/main/monthly_active_users.py
+++ b/synapse/storage/databases/main/monthly_active_users.py
@@ -32,8 +32,8 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore):
self._clock = hs.get_clock()
self.hs = hs
- self._limit_usage_by_mau = hs.config.limit_usage_by_mau
- self._max_mau_value = hs.config.max_mau_value
+ self._limit_usage_by_mau = hs.config.server.limit_usage_by_mau
+ self._max_mau_value = hs.config.server.max_mau_value
@cached(num_args=0)
async def get_monthly_active_count(self) -> int:
@@ -96,8 +96,8 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore):
"""
users = []
- for tp in self.hs.config.mau_limits_reserved_threepids[
- : self.hs.config.max_mau_value
+ for tp in self.hs.config.server.mau_limits_reserved_threepids[
+ : self.hs.config.server.max_mau_value
]:
user_id = await self.hs.get_datastore().get_user_id_by_threepid(
tp["medium"], tp["address"]
@@ -212,7 +212,7 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore):
def __init__(self, database: DatabasePool, db_conn, hs):
super().__init__(database, db_conn, hs)
- self._mau_stats_only = hs.config.mau_stats_only
+ self._mau_stats_only = hs.config.server.mau_stats_only
# Do not add more reserved users than the total allowable number
self.db_pool.new_transaction(
@@ -221,7 +221,7 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore):
[],
[],
self._initialise_reserved_users,
- hs.config.mau_limits_reserved_threepids[: self._max_mau_value],
+ hs.config.server.mau_limits_reserved_threepids[: self._max_mau_value],
)
def _initialise_reserved_users(self, txn, threepids):
@@ -354,3 +354,27 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore):
await self.upsert_monthly_active_user(user_id)
elif now - last_seen_timestamp > LAST_SEEN_GRANULARITY:
await self.upsert_monthly_active_user(user_id)
+
+ async def remove_deactivated_user_from_mau_table(self, user_id: str) -> None:
+ """
+ Removes a deactivated user from the monthly active user
+ table and resets affected caches.
+
+ Args:
+ user_id(str): the user_id to remove
+ """
+
+ rows_deleted = await self.db_pool.simple_delete(
+ table="monthly_active_users",
+ keyvalues={"user_id": user_id},
+ desc="simple_delete",
+ )
+
+ if rows_deleted != 0:
+ await self.invalidate_cache_and_stream(
+ "user_last_seen_monthly_active", (user_id,)
+ )
+ await self.invalidate_cache_and_stream("get_monthly_active_count", ())
+ await self.invalidate_cache_and_stream(
+ "get_monthly_active_count_by_service", ()
+ )
diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py
index a7fb8cd8..fc720f59 100644
--- a/synapse/storage/databases/main/push_rule.py
+++ b/synapse/storage/databases/main/push_rule.py
@@ -14,7 +14,7 @@
# limitations under the License.
import abc
import logging
-from typing import List, Tuple, Union
+from typing import Dict, List, Tuple, Union
from synapse.api.errors import NotFoundError, StoreError
from synapse.push.baserules import list_with_base_rules
@@ -101,7 +101,9 @@ class PushRulesWorkerStore(
prefilled_cache=push_rules_prefill,
)
- self._users_new_default_push_rules = hs.config.users_new_default_push_rules
+ self._users_new_default_push_rules = (
+ hs.config.server.users_new_default_push_rules
+ )
@abc.abstractmethod
def get_max_push_rules_stream_id(self):
@@ -137,7 +139,7 @@ class PushRulesWorkerStore(
return _load_rules(rows, enabled_map, use_new_defaults)
@cached(max_entries=5000)
- async def get_push_rules_enabled_for_user(self, user_id):
+ async def get_push_rules_enabled_for_user(self, user_id) -> Dict[str, bool]:
results = await self.db_pool.simple_select_list(
table="push_rules_enable",
keyvalues={"user_name": user_id},
diff --git a/synapse/storage/databases/main/pusher.py b/synapse/storage/databases/main/pusher.py
index a93caae8..b73ce53c 100644
--- a/synapse/storage/databases/main/pusher.py
+++ b/synapse/storage/databases/main/pusher.py
@@ -18,8 +18,7 @@ from typing import TYPE_CHECKING, Any, Dict, Iterable, Iterator, List, Optional,
from synapse.push import PusherConfig, ThrottleParams
from synapse.storage._base import SQLBaseStore, db_to_json
-from synapse.storage.database import DatabasePool
-from synapse.storage.types import Connection
+from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
from synapse.storage.util.id_generators import StreamIdGenerator
from synapse.types import JsonDict
from synapse.util import json_encoder
@@ -32,7 +31,12 @@ logger = logging.getLogger(__name__)
class PusherWorkerStore(SQLBaseStore):
- def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
+ def __init__(
+ self,
+ database: DatabasePool,
+ db_conn: LoggingDatabaseConnection,
+ hs: "HomeServer",
+ ):
super().__init__(database, db_conn, hs)
self._pushers_id_gen = StreamIdGenerator(
db_conn, "pushers", "id", extra_tables=[("deleted_pushers", "stream_id")]
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index c83089ee..181841ee 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -26,7 +26,7 @@ from synapse.metrics.background_process_metrics import wrap_as_background_proces
from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
from synapse.storage.databases.main.stats import StatsStore
-from synapse.storage.types import Connection, Cursor
+from synapse.storage.types import Cursor
from synapse.storage.util.id_generators import IdGenerator
from synapse.storage.util.sequence import build_sequence_generator
from synapse.types import UserID, UserInfo
@@ -207,7 +207,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
return False
now = self._clock.time_msec()
- trial_duration_ms = self.config.mau_trial_days * 24 * 60 * 60 * 1000
+ trial_duration_ms = self.config.server.mau_trial_days * 24 * 60 * 60 * 1000
is_trial = (now - info["creation_ts"] * 1000) < trial_duration_ms
return is_trial
@@ -1710,7 +1710,7 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
We do this by grandfathering in existing user threepids assuming that
they used one of the server configured trusted identity servers.
"""
- id_servers = set(self.config.trusted_third_party_id_servers)
+ id_servers = set(self.config.registration.trusted_third_party_id_servers)
def _bg_user_threepids_grandfather_txn(txn):
sql = """
@@ -1775,10 +1775,17 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
- def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
+ def __init__(
+ self,
+ database: DatabasePool,
+ db_conn: LoggingDatabaseConnection,
+ hs: "HomeServer",
+ ):
super().__init__(database, db_conn, hs)
- self._ignore_unknown_session_error = hs.config.request_token_inhibit_3pid_errors
+ self._ignore_unknown_session_error = (
+ hs.config.server.request_token_inhibit_3pid_errors
+ )
self._access_tokens_id_gen = IdGenerator(db_conn, "access_tokens", "id")
self._refresh_tokens_id_gen = IdGenerator(db_conn, "refresh_tokens", "id")
diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py
index 118b390e..d69eaf80 100644
--- a/synapse/storage/databases/main/room.py
+++ b/synapse/storage/databases/main/room.py
@@ -679,8 +679,8 @@ class RoomWorkerStore(SQLBaseStore):
# policy.
if not ret:
return {
- "min_lifetime": self.config.retention_default_min_lifetime,
- "max_lifetime": self.config.retention_default_max_lifetime,
+ "min_lifetime": self.config.server.retention_default_min_lifetime,
+ "max_lifetime": self.config.server.retention_default_max_lifetime,
}
row = ret[0]
@@ -690,10 +690,10 @@ class RoomWorkerStore(SQLBaseStore):
# The default values will be None if no default policy has been defined, or if one
# of the attributes is missing from the default policy.
if row["min_lifetime"] is None:
- row["min_lifetime"] = self.config.retention_default_min_lifetime
+ row["min_lifetime"] = self.config.server.retention_default_min_lifetime
if row["max_lifetime"] is None:
- row["max_lifetime"] = self.config.retention_default_max_lifetime
+ row["max_lifetime"] = self.config.server.retention_default_max_lifetime
return row
diff --git a/synapse/storage/databases/main/room_batch.py b/synapse/storage/databases/main/room_batch.py
index a3833887..300a563c 100644
--- a/synapse/storage/databases/main/room_batch.py
+++ b/synapse/storage/databases/main/room_batch.py
@@ -18,7 +18,9 @@ from synapse.storage._base import SQLBaseStore
class RoomBatchStore(SQLBaseStore):
- async def get_insertion_event_by_batch_id(self, batch_id: str) -> Optional[str]:
+ async def get_insertion_event_by_batch_id(
+ self, room_id: str, batch_id: str
+ ) -> Optional[str]:
"""Retrieve a insertion event ID.
Args:
@@ -30,7 +32,7 @@ class RoomBatchStore(SQLBaseStore):
"""
return await self.db_pool.simple_select_one_onecol(
table="insertion_events",
- keyvalues={"next_batch_id": batch_id},
+ keyvalues={"room_id": room_id, "next_batch_id": batch_id},
retcol="event_id",
allow_none=True,
)
diff --git a/synapse/storage/databases/main/search.py b/synapse/storage/databases/main/search.py
index 2a1e99e1..c85383c9 100644
--- a/synapse/storage/databases/main/search.py
+++ b/synapse/storage/databases/main/search.py
@@ -51,7 +51,7 @@ class SearchWorkerStore(SQLBaseStore):
txn:
entries: entries to be added to the table
"""
- if not self.hs.config.enable_search:
+ if not self.hs.config.server.enable_search:
return
if isinstance(self.database_engine, PostgresEngine):
sql = (
@@ -105,7 +105,7 @@ class SearchBackgroundUpdateStore(SearchWorkerStore):
def __init__(self, database: DatabasePool, db_conn, hs):
super().__init__(database, db_conn, hs)
- if not hs.config.enable_search:
+ if not hs.config.server.enable_search:
return
self.db_pool.updates.register_background_update_handler(
diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py
index 90d65edc..e98a45b6 100644
--- a/synapse/storage/databases/main/user_directory.py
+++ b/synapse/storage/databases/main/user_directory.py
@@ -26,6 +26,8 @@ from typing import (
cast,
)
+from synapse.api.errors import StoreError
+
if TYPE_CHECKING:
from synapse.server import HomeServer
@@ -40,12 +42,10 @@ from synapse.util.caches.descriptors import cached
logger = logging.getLogger(__name__)
-
TEMP_TABLE = "_temp_populate_user_directory"
class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
-
# How many records do we calculate before sending it to
# add_users_who_share_private_rooms?
SHARE_PRIVATE_WORKING_SET = 500
@@ -230,38 +230,49 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
is_in_room = await self.is_host_joined(room_id, self.server_name)
if is_in_room:
- is_public = await self.is_room_world_readable_or_publicly_joinable(
- room_id
- )
-
users_with_profile = await self.get_users_in_room_with_profiles(room_id)
+ # Throw away users excluded from the directory.
+ users_with_profile = {
+ user_id: profile
+ for user_id, profile in users_with_profile.items()
+ if not self.hs.is_mine_id(user_id)
+ or await self.should_include_local_user_in_dir(user_id)
+ }
- # Update each user in the user directory.
+ # Upsert a user_directory record for each remote user we see.
for user_id, profile in users_with_profile.items():
+ # Local users are processed separately in
+ # `_populate_user_directory_users`; there we can read from
+ # the `profiles` table to ensure we don't leak their per-room
+ # profiles. It also means we write local users to this table
+ # exactly once, rather than once for every room they're in.
+ if self.hs.is_mine_id(user_id):
+ continue
+ # TODO `users_with_profile` above reads from the `user_directory`
+ # table, meaning that `profile` is bespoke to this room.
+ # and this leaks remote users' per-room profiles to the user directory.
await self.update_profile_in_user_dir(
user_id, profile.display_name, profile.avatar_url
)
- to_insert = set()
-
+ # Now update the room sharing tables to include this room.
+ is_public = await self.is_room_world_readable_or_publicly_joinable(
+ room_id
+ )
if is_public:
- for user_id in users_with_profile:
- if self.get_if_app_services_interested_in_user(user_id):
- continue
-
- to_insert.add(user_id)
-
- if to_insert:
- await self.add_users_in_public_rooms(room_id, to_insert)
- to_insert.clear()
+ if users_with_profile:
+ await self.add_users_in_public_rooms(
+ room_id, users_with_profile.keys()
+ )
else:
+ to_insert = set()
for user_id in users_with_profile:
+ # We want the set of pairs (L, M) where L and M are
+ # in `users_with_profile` and L is local.
+ # Do so by looking for the local user L first.
if not self.hs.is_mine_id(user_id):
continue
- if self.get_if_app_services_interested_in_user(user_id):
- continue
-
for other_user_id in users_with_profile:
if user_id == other_user_id:
continue
@@ -349,10 +360,11 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
)
for user_id in users_to_work_on:
- profile = await self.get_profileinfo(get_localpart_from_id(user_id))
- await self.update_profile_in_user_dir(
- user_id, profile.display_name, profile.avatar_url
- )
+ if await self.should_include_local_user_in_dir(user_id):
+ profile = await self.get_profileinfo(get_localpart_from_id(user_id))
+ await self.update_profile_in_user_dir(
+ user_id, profile.display_name, profile.avatar_url
+ )
# We've finished processing a user. Delete it from the table.
await self.db_pool.simple_delete_one(
@@ -369,6 +381,42 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
return len(users_to_work_on)
+ async def should_include_local_user_in_dir(self, user: str) -> bool:
+ """Certain classes of local user are omitted from the user directory.
+ Is this user one of them?
+ """
+ # We're opting to exclude the appservice sender (user defined by the
+ # `sender_localpart` in the appservice registration) even though
+ # technically it could be DM-able. In the future, this could potentially
+ # be configurable per-appservice whether the appservice sender can be
+ # contacted.
+ if self.get_app_service_by_user_id(user) is not None:
+ return False
+
+ # We're opting to exclude appservice users (anyone matching the user
+ # namespace regex in the appservice registration) even though technically
+ # they could be DM-able. In the future, this could potentially
+ # be configurable per-appservice whether the appservice users can be
+ # contacted.
+ if self.get_if_app_services_interested_in_user(user):
+ # TODO we might want to make this configurable for each app service
+ return False
+
+ # Support users are for diagnostics and should not appear in the user directory.
+ if await self.is_support_user(user):
+ return False
+
+ # Deactivated users aren't contactable, so should not appear in the user directory.
+ try:
+ if await self.get_user_deactivated_status(user):
+ return False
+ except StoreError:
+ # No such user in the users table. No need to do this when calling
+ # is_support_user---that returns False if the user is missing.
+ return False
+
+ return True
+
async def is_room_world_readable_or_publicly_joinable(self, room_id: str) -> bool:
"""Check if the room is either world_readable or publically joinable"""
@@ -527,7 +575,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
desc="get_user_in_directory",
)
- async def update_user_directory_stream_pos(self, stream_id: int) -> None:
+ async def update_user_directory_stream_pos(self, stream_id: Optional[int]) -> None:
await self.db_pool.simple_update_one(
table="user_directory_stream_pos",
keyvalues={},
@@ -537,7 +585,6 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
-
# How many records do we calculate before sending it to
# add_users_who_share_private_rooms?
SHARE_PRIVATE_WORKING_SET = 500
diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py
index f31880b8..11ca47ea 100644
--- a/synapse/storage/prepare_database.py
+++ b/synapse/storage/prepare_database.py
@@ -366,7 +366,7 @@ def _upgrade_existing_database(
+ "new for the server to understand"
)
- # some of the deltas assume that config.server_name is set correctly, so now
+ # some of the deltas assume that server_name is set correctly, so now
# is a good time to run the sanity check.
if not is_empty and "main" in databases:
from synapse.storage.databases.main import check_database_before_upgrade
@@ -487,6 +487,10 @@ def _upgrade_existing_database(
spec = importlib.util.spec_from_file_location(
module_name, absolute_path
)
+ if spec is None:
+ raise RuntimeError(
+ f"Could not build a module spec for {module_name} at {absolute_path}"
+ )
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module) # type: ignore
diff --git a/synapse/storage/schema/__init__.py b/synapse/storage/schema/__init__.py
index 573e05a4..1aee741a 100644
--- a/synapse/storage/schema/__init__.py
+++ b/synapse/storage/schema/__init__.py
@@ -12,9 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-# When updating these values, please leave a short summary of the changes below.
-
-SCHEMA_VERSION = 64
+SCHEMA_VERSION = 64 # remember to update the list below when updating
"""Represents the expectations made by the codebase about the database schema
This should be incremented whenever the codebase changes its requirements on the
@@ -46,7 +44,7 @@ Changes in SCHEMA_VERSION = 64:
"""
-SCHEMA_COMPAT_VERSION = 59
+SCHEMA_COMPAT_VERSION = 60 # 60: "outlier" not in internal_metadata.
"""Limit on how far the synapse codebase can be rolled back without breaking db compat
This value is stored in the database, and checked on startup. If the value in the
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index 5e86befd..b5ba1560 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -15,9 +15,11 @@ import logging
from typing import (
TYPE_CHECKING,
Awaitable,
+ Collection,
Dict,
Iterable,
List,
+ Mapping,
Optional,
Set,
Tuple,
@@ -29,7 +31,7 @@ from frozendict import frozendict
from synapse.api.constants import EventTypes
from synapse.events import EventBase
-from synapse.types import MutableStateMap, StateMap
+from synapse.types import MutableStateMap, StateKey, StateMap
if TYPE_CHECKING:
from typing import FrozenSet # noqa: used within quoted type hint; flake8 sad
@@ -134,6 +136,23 @@ class StateFilter:
include_others=True,
)
+ @staticmethod
+ def freeze(types: Mapping[str, Optional[Collection[str]]], include_others: bool):
+ """
+ Returns a (frozen) StateFilter with the same contents as the parameters
+ specified here, which can be made of mutable types.
+ """
+ types_with_frozen_values: Dict[str, Optional[FrozenSet[str]]] = {}
+ for state_types, state_keys in types.items():
+ if state_keys is not None:
+ types_with_frozen_values[state_types] = frozenset(state_keys)
+ else:
+ types_with_frozen_values[state_types] = None
+
+ return StateFilter(
+ frozendict(types_with_frozen_values), include_others=include_others
+ )
+
def return_expanded(self) -> "StateFilter":
"""Creates a new StateFilter where type wild cards have been removed
(except for memberships). The returned filter is a superset of the
@@ -356,6 +375,157 @@ class StateFilter:
return member_filter, non_member_filter
+ def _decompose_into_four_parts(
+ self,
+ ) -> Tuple[Tuple[bool, Set[str]], Tuple[Set[str], Set[StateKey]]]:
+ """
+ Decomposes this state filter into 4 constituent parts, which can be
+ thought of as this:
+ all? - minus_wildcards + plus_wildcards + plus_state_keys
+
+ where
+ * all represents ALL state
+ * minus_wildcards represents entire state types to remove
+ * plus_wildcards represents entire state types to add
+ * plus_state_keys represents individual state keys to add
+
+ See `recompose_from_four_parts` for the other direction of this
+ correspondence.
+ """
+ is_all = self.include_others
+ excluded_types: Set[str] = {t for t in self.types if is_all}
+ wildcard_types: Set[str] = {t for t, s in self.types.items() if s is None}
+ concrete_keys: Set[StateKey] = set(self.concrete_types())
+
+ return (is_all, excluded_types), (wildcard_types, concrete_keys)
+
+ @staticmethod
+ def _recompose_from_four_parts(
+ all_part: bool,
+ minus_wildcards: Set[str],
+ plus_wildcards: Set[str],
+ plus_state_keys: Set[StateKey],
+ ) -> "StateFilter":
+ """
+ Recomposes a state filter from 4 parts.
+
+ See `decompose_into_four_parts` (the other direction of this
+ correspondence) for descriptions on each of the parts.
+ """
+
+ # {state type -> set of state keys OR None for wildcard}
+ # (The same structure as that of a StateFilter.)
+ new_types: Dict[str, Optional[Set[str]]] = {}
+
+ # if we start with all, insert the excluded statetypes as empty sets
+ # to prevent them from being included
+ if all_part:
+ new_types.update({state_type: set() for state_type in minus_wildcards})
+
+ # insert the plus wildcards
+ new_types.update({state_type: None for state_type in plus_wildcards})
+
+ # insert the specific state keys
+ for state_type, state_key in plus_state_keys:
+ if state_type in new_types:
+ entry = new_types[state_type]
+ if entry is not None:
+ entry.add(state_key)
+ elif not all_part:
+ # don't insert if the entire type is already included by
+ # include_others as this would actually shrink the state allowed
+ # by this filter.
+ new_types[state_type] = {state_key}
+
+ return StateFilter.freeze(new_types, include_others=all_part)
+
+ def approx_difference(self, other: "StateFilter") -> "StateFilter":
+ """
+ Returns a state filter which represents `self - other`.
+
+ This is useful for determining what state remains to be pulled out of the
+ database if we want the state included by `self` but already have the state
+ included by `other`.
+
+ The returned state filter
+ - MUST include all state events that are included by this filter (`self`)
+ unless they are included by `other`;
+ - MUST NOT include state events not included by this filter (`self`); and
+ - MAY be an over-approximation: the returned state filter
+ MAY additionally include some state events from `other`.
+
+ This implementation attempts to return the narrowest such state filter.
+ In the case that `self` contains wildcards for state types where
+ `other` contains specific state keys, an approximation must be made:
+ the returned state filter keeps the wildcard, as state filters are not
+ able to express 'all state keys except some given examples'.
+ e.g.
+ StateFilter(m.room.member -> None (wildcard))
+ minus
+ StateFilter(m.room.member -> {'@wombat:example.org'})
+ is approximated as
+ StateFilter(m.room.member -> None (wildcard))
+ """
+
+ # We first transform self and other into an alternative representation:
+ # - whether or not they include all events to begin with ('all')
+ # - if so, which event types are excluded? ('excludes')
+ # - which entire event types to include ('wildcards')
+ # - which concrete state keys to include ('concrete state keys')
+ (self_all, self_excludes), (
+ self_wildcards,
+ self_concrete_keys,
+ ) = self._decompose_into_four_parts()
+ (other_all, other_excludes), (
+ other_wildcards,
+ other_concrete_keys,
+ ) = other._decompose_into_four_parts()
+
+ # Start with an estimate of the difference based on self
+ new_all = self_all
+ # Wildcards from the other can be added to the exclusion filter
+ new_excludes = self_excludes | other_wildcards
+ # We remove wildcards that appeared as wildcards in the other
+ new_wildcards = self_wildcards - other_wildcards
+ # We filter out the concrete state keys that appear in the other
+ # as wildcards or concrete state keys.
+ new_concrete_keys = {
+ (state_type, state_key)
+ for (state_type, state_key) in self_concrete_keys
+ if state_type not in other_wildcards
+ } - other_concrete_keys
+
+ if other_all:
+ if self_all:
+ # If self starts with all, then we add as wildcards any
+ # types which appear in the other's exclusion filter (but
+ # aren't in the self exclusion filter). This is as the other
+ # filter will return everything BUT the types in its exclusion, so
+ # we need to add those excluded types that also match the self
+ # filter as wildcard types in the new filter.
+ new_wildcards |= other_excludes.difference(self_excludes)
+
+ # If other is an `include_others` then the difference isn't.
+ new_all = False
+ # (We have no need for excludes when we don't start with all, as there
+ # is nothing to exclude.)
+ new_excludes = set()
+
+ # We also filter out all state types that aren't in the exclusion
+ # list of the other.
+ new_wildcards &= other_excludes
+ new_concrete_keys = {
+ (state_type, state_key)
+ for (state_type, state_key) in new_concrete_keys
+ if state_type in other_excludes
+ }
+
+ # Transform our newly-constructed state filter from the alternative
+ # representation back into the normal StateFilter representation.
+ return StateFilter._recompose_from_four_parts(
+ new_all, new_excludes, new_wildcards, new_concrete_keys
+ )
+
class StateGroupStorage:
"""High level interface to fetching state for event."""
diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py
index 6f7cbe40..67081161 100644
--- a/synapse/storage/util/id_generators.py
+++ b/synapse/storage/util/id_generators.py
@@ -16,42 +16,62 @@ import logging
import threading
from collections import OrderedDict
from contextlib import contextmanager
-from typing import Dict, Iterable, List, Optional, Set, Tuple, Union
+from types import TracebackType
+from typing import (
+ AsyncContextManager,
+ ContextManager,
+ Dict,
+ Generator,
+ Generic,
+ Iterable,
+ List,
+ Optional,
+ Sequence,
+ Set,
+ Tuple,
+ Type,
+ TypeVar,
+ Union,
+ cast,
+)
import attr
-from sortedcontainers import SortedSet
+from sortedcontainers import SortedList, SortedSet
from synapse.metrics.background_process_metrics import run_as_background_process
-from synapse.storage.database import DatabasePool, LoggingTransaction
+from synapse.storage.database import (
+ DatabasePool,
+ LoggingDatabaseConnection,
+ LoggingTransaction,
+)
from synapse.storage.types import Cursor
from synapse.storage.util.sequence import PostgresSequenceGenerator
logger = logging.getLogger(__name__)
+T = TypeVar("T")
+
+
class IdGenerator:
- def __init__(self, db_conn, table, column):
+ def __init__(
+ self,
+ db_conn: LoggingDatabaseConnection,
+ table: str,
+ column: str,
+ ):
self._lock = threading.Lock()
self._next_id = _load_current_id(db_conn, table, column)
- def get_next(self):
+ def get_next(self) -> int:
with self._lock:
self._next_id += 1
return self._next_id
-def _load_current_id(db_conn, table, column, step=1):
- """
-
- Args:
- db_conn (object):
- table (str):
- column (str):
- step (int):
-
- Returns:
- int
- """
+def _load_current_id(
+ db_conn: LoggingDatabaseConnection, table: str, column: str, step: int = 1
+) -> int:
# debug logging for https://github.com/matrix-org/synapse/issues/7968
logger.info("initialising stream generator for %s(%s)", table, column)
cur = db_conn.cursor(txn_name="_load_current_id")
@@ -59,7 +79,9 @@ def _load_current_id(db_conn, table, column, step=1):
cur.execute("SELECT MAX(%s) FROM %s" % (column, table))
else:
cur.execute("SELECT MIN(%s) FROM %s" % (column, table))
- (val,) = cur.fetchone()
+ result = cur.fetchone()
+ assert result is not None
+ (val,) = result
cur.close()
current_id = int(val) if val else step
return (max if step > 0 else min)(current_id, step)
@@ -93,16 +115,16 @@ class StreamIdGenerator:
def __init__(
self,
- db_conn,
- table,
- column,
+ db_conn: LoggingDatabaseConnection,
+ table: str,
+ column: str,
extra_tables: Iterable[Tuple[str, str]] = (),
- step=1,
- ):
+ step: int = 1,
+ ) -> None:
assert step != 0
self._lock = threading.Lock()
- self._step = step
- self._current = _load_current_id(db_conn, table, column, step)
+ self._step: int = step
+ self._current: int = _load_current_id(db_conn, table, column, step)
for table, column in extra_tables:
self._current = (max if step > 0 else min)(
self._current, _load_current_id(db_conn, table, column, step)
@@ -115,7 +137,7 @@ class StreamIdGenerator:
# The key and values are the same, but we never look at the values.
self._unfinished_ids: OrderedDict[int, int] = OrderedDict()
- def get_next(self):
+ def get_next(self) -> AsyncContextManager[int]:
"""
Usage:
async with stream_id_gen.get_next() as stream_id:
@@ -128,7 +150,7 @@ class StreamIdGenerator:
self._unfinished_ids[next_id] = next_id
@contextmanager
- def manager():
+ def manager() -> Generator[int, None, None]:
try:
yield next_id
finally:
@@ -137,7 +159,7 @@ class StreamIdGenerator:
return _AsyncCtxManagerWrapper(manager())
- def get_next_mult(self, n):
+ def get_next_mult(self, n: int) -> AsyncContextManager[Sequence[int]]:
"""
Usage:
async with stream_id_gen.get_next(n) as stream_ids:
@@ -155,7 +177,7 @@ class StreamIdGenerator:
self._unfinished_ids[next_id] = next_id
@contextmanager
- def manager():
+ def manager() -> Generator[Sequence[int], None, None]:
try:
yield next_ids
finally:
@@ -215,7 +237,7 @@ class MultiWriterIdGenerator:
def __init__(
self,
- db_conn,
+ db_conn: LoggingDatabaseConnection,
db: DatabasePool,
stream_name: str,
instance_name: str,
@@ -223,7 +245,7 @@ class MultiWriterIdGenerator:
sequence_name: str,
writers: List[str],
positive: bool = True,
- ):
+ ) -> None:
self._db = db
self._stream_name = stream_name
self._instance_name = instance_name
@@ -243,6 +265,15 @@ class MultiWriterIdGenerator:
# should be less than the minimum of this set (if not empty).
self._unfinished_ids: SortedSet[int] = SortedSet()
+ # We also need to track when we've requested some new stream IDs but
+ # they haven't yet been added to the `_unfinished_ids` set. Every time
+ # we request a new stream ID we add the current max stream ID to the
+ # list, and remove it once we've added the newly allocated IDs to the
+ # `_unfinished_ids` set. This means that we *may* be allocated stream
+ # IDs above those in the list, and so we can't advance the local current
+ # position beyond the minimum stream ID in this list.
+ self._in_flight_fetches: SortedList[int] = SortedList()
+
# Set of local IDs that we've processed that are larger than the current
# position, due to there being smaller unpersisted IDs.
self._finished_ids: Set[int] = set()
@@ -268,6 +299,9 @@ class MultiWriterIdGenerator:
)
self._known_persisted_positions: List[int] = []
+ # The maximum stream ID that we have seen been allocated across any writer.
+ self._max_seen_allocated_stream_id = 1
+
self._sequence_gen = PostgresSequenceGenerator(sequence_name)
# We check that the table and sequence haven't diverged.
@@ -283,11 +317,15 @@ class MultiWriterIdGenerator:
# This goes and fills out the above state from the database.
self._load_current_ids(db_conn, tables)
+ self._max_seen_allocated_stream_id = max(
+ self._current_positions.values(), default=1
+ )
+
def _load_current_ids(
self,
- db_conn,
+ db_conn: LoggingDatabaseConnection,
tables: List[Tuple[str, str, str]],
- ):
+ ) -> None:
cur = db_conn.cursor(txn_name="_load_current_ids")
# Load the current positions of all writers for the stream.
@@ -335,7 +373,9 @@ class MultiWriterIdGenerator:
"agg": "MAX" if self._positive else "-MIN",
}
cur.execute(sql)
- (stream_id,) = cur.fetchone()
+ result = cur.fetchone()
+ assert result is not None
+ (stream_id,) = result
max_stream_id = max(max_stream_id, stream_id)
@@ -354,7 +394,7 @@ class MultiWriterIdGenerator:
self._persisted_upto_position = min_stream_id
- rows = []
+ rows: List[Tuple[str, int]] = []
for table, instance_column, id_column in tables:
sql = """
SELECT %(instance)s, %(id)s FROM %(table)s
@@ -367,7 +407,8 @@ class MultiWriterIdGenerator:
}
cur.execute(sql, (min_stream_id * self._return_factor,))
- rows.extend(cur)
+ # Cast safety: this corresponds to the types returned by the query above.
+ rows.extend(cast(Iterable[Tuple[str, int]], cur))
# Sort so that we handle rows in order for each instance.
rows.sort()
@@ -385,13 +426,35 @@ class MultiWriterIdGenerator:
cur.close()
- def _load_next_id_txn(self, txn) -> int:
- return self._sequence_gen.get_next_id_txn(txn)
+ def _load_next_id_txn(self, txn: Cursor) -> int:
+ stream_ids = self._load_next_mult_id_txn(txn, 1)
+ return stream_ids[0]
- def _load_next_mult_id_txn(self, txn, n: int) -> List[int]:
- return self._sequence_gen.get_next_mult_txn(txn, n)
+ def _load_next_mult_id_txn(self, txn: Cursor, n: int) -> List[int]:
+ # We need to track that we've requested some more stream IDs, and what
+ # the current max allocated stream ID is. This is to prevent a race
+ # where we've been allocated stream IDs but they have not yet been added
+ # to the `_unfinished_ids` set, allowing the current position to advance
+ # past them.
+ with self._lock:
+ current_max = self._max_seen_allocated_stream_id
+ self._in_flight_fetches.add(current_max)
+
+ try:
+ stream_ids = self._sequence_gen.get_next_mult_txn(txn, n)
+
+ with self._lock:
+ self._unfinished_ids.update(stream_ids)
+ self._max_seen_allocated_stream_id = max(
+ self._max_seen_allocated_stream_id, self._unfinished_ids[-1]
+ )
+ finally:
+ with self._lock:
+ self._in_flight_fetches.remove(current_max)
+
+ return stream_ids
- def get_next(self):
+ def get_next(self) -> AsyncContextManager[int]:
"""
Usage:
async with stream_id_gen.get_next() as stream_id:
@@ -403,9 +466,12 @@ class MultiWriterIdGenerator:
if self._writers and self._instance_name not in self._writers:
raise Exception("Tried to allocate stream ID on non-writer")
- return _MultiWriterCtxManager(self)
+ # Cast safety: the second argument to _MultiWriterCtxManager, multiple_ids,
+ # controls the return type. If `None` or omitted, the context manager yields
+ # a single integer stream_id; otherwise it yields a list of stream_ids.
+ return cast(AsyncContextManager[int], _MultiWriterCtxManager(self))
- def get_next_mult(self, n: int):
+ def get_next_mult(self, n: int) -> AsyncContextManager[List[int]]:
"""
Usage:
async with stream_id_gen.get_next_mult(5) as stream_ids:
@@ -417,9 +483,10 @@ class MultiWriterIdGenerator:
if self._writers and self._instance_name not in self._writers:
raise Exception("Tried to allocate stream ID on non-writer")
- return _MultiWriterCtxManager(self, n)
+ # Cast safety: see get_next.
+ return cast(AsyncContextManager[List[int]], _MultiWriterCtxManager(self, n))
- def get_next_txn(self, txn: LoggingTransaction):
+ def get_next_txn(self, txn: LoggingTransaction) -> int:
"""
Usage:
@@ -434,9 +501,6 @@ class MultiWriterIdGenerator:
next_id = self._load_next_id_txn(txn)
- with self._lock:
- self._unfinished_ids.add(next_id)
-
txn.call_after(self._mark_id_as_finished, next_id)
txn.call_on_exception(self._mark_id_as_finished, next_id)
@@ -457,7 +521,7 @@ class MultiWriterIdGenerator:
return self._return_factor * next_id
- def _mark_id_as_finished(self, next_id: int):
+ def _mark_id_as_finished(self, next_id: int) -> None:
"""The ID has finished being processed so we should advance the
current position if possible.
"""
@@ -468,15 +532,27 @@ class MultiWriterIdGenerator:
new_cur: Optional[int] = None
- if self._unfinished_ids:
+ if self._unfinished_ids or self._in_flight_fetches:
# If there are unfinished IDs then the new position will be the
- # largest finished ID less than the minimum unfinished ID.
+ # largest finished ID strictly less than the minimum unfinished
+ # ID.
+
+ # The minimum unfinished ID needs to take account of both
+ # `_unfinished_ids` and `_in_flight_fetches`.
+ if self._unfinished_ids and self._in_flight_fetches:
+ # `_in_flight_fetches` stores the maximum safe stream ID, so
+ # we add one to make it equivalent to the minimum unsafe ID.
+ min_unfinished = min(
+ self._unfinished_ids[0], self._in_flight_fetches[0] + 1
+ )
+ elif self._in_flight_fetches:
+ min_unfinished = self._in_flight_fetches[0] + 1
+ else:
+ min_unfinished = self._unfinished_ids[0]
finished = set()
-
- min_unfinshed = self._unfinished_ids[0]
for s in self._finished_ids:
- if s < min_unfinshed:
+ if s < min_unfinished:
if new_cur is None or new_cur < s:
new_cur = s
else:
@@ -534,7 +610,7 @@ class MultiWriterIdGenerator:
for name, i in self._current_positions.items()
}
- def advance(self, instance_name: str, new_id: int):
+ def advance(self, instance_name: str, new_id: int) -> None:
"""Advance the position of the named writer to the given ID, if greater
than existing entry.
"""
@@ -546,6 +622,10 @@ class MultiWriterIdGenerator:
new_id, self._current_positions.get(instance_name, 0)
)
+ self._max_seen_allocated_stream_id = max(
+ self._max_seen_allocated_stream_id, new_id
+ )
+
self._add_persisted_position(new_id)
def get_persisted_upto_position(self) -> int:
@@ -560,7 +640,7 @@ class MultiWriterIdGenerator:
with self._lock:
return self._return_factor * self._persisted_upto_position
- def _add_persisted_position(self, new_id: int):
+ def _add_persisted_position(self, new_id: int) -> None:
"""Record that we have persisted a position.
This is used to keep the `_current_positions` up to date.
@@ -576,7 +656,11 @@ class MultiWriterIdGenerator:
# to report a recent position when asked, rather than a potentially old
# one (if this instance hasn't written anything for a while).
our_current_position = self._current_positions.get(self._instance_name)
- if our_current_position and not self._unfinished_ids:
+ if (
+ our_current_position
+ and not self._unfinished_ids
+ and not self._in_flight_fetches
+ ):
self._current_positions[self._instance_name] = max(
our_current_position, new_id
)
@@ -606,7 +690,7 @@ class MultiWriterIdGenerator:
# do.
break
- def _update_stream_positions_table_txn(self, txn: Cursor):
+ def _update_stream_positions_table_txn(self, txn: Cursor) -> None:
"""Update the `stream_positions` table with newly persisted position."""
if not self._writers:
@@ -628,20 +712,25 @@ class MultiWriterIdGenerator:
txn.execute(sql, (self._stream_name, self._instance_name, pos))
-@attr.s(slots=True)
-class _AsyncCtxManagerWrapper:
+@attr.s(frozen=True, auto_attribs=True)
+class _AsyncCtxManagerWrapper(Generic[T]):
"""Helper class to convert a plain context manager to an async one.
This is mainly useful if you have a plain context manager but the interface
requires an async one.
"""
- inner = attr.ib()
+ inner: ContextManager[T]
- async def __aenter__(self):
+ async def __aenter__(self) -> T:
return self.inner.__enter__()
- async def __aexit__(self, exc_type, exc, tb):
+ async def __aexit__(
+ self,
+ exc_type: Optional[Type[BaseException]],
+ exc: Optional[BaseException],
+ tb: Optional[TracebackType],
+ ) -> Optional[bool]:
return self.inner.__exit__(exc_type, exc, tb)
@@ -663,15 +752,17 @@ class _MultiWriterCtxManager:
db_autocommit=True,
)
- with self.id_gen._lock:
- self.id_gen._unfinished_ids.update(self.stream_ids)
-
if self.multiple_ids is None:
return self.stream_ids[0] * self.id_gen._return_factor
else:
return [i * self.id_gen._return_factor for i in self.stream_ids]
- async def __aexit__(self, exc_type, exc, tb):
+ async def __aexit__(
+ self,
+ exc_type: Optional[Type[BaseException]],
+ exc: Optional[BaseException],
+ tb: Optional[TracebackType],
+ ) -> bool:
for i in self.stream_ids:
self.id_gen._mark_id_as_finished(i)
diff --git a/synapse/storage/util/sequence.py b/synapse/storage/util/sequence.py
index bb33e04f..75268cbe 100644
--- a/synapse/storage/util/sequence.py
+++ b/synapse/storage/util/sequence.py
@@ -81,7 +81,7 @@ class SequenceGenerator(metaclass=abc.ABCMeta):
id_column: str,
stream_name: Optional[str] = None,
positive: bool = True,
- ):
+ ) -> None:
"""Should be called during start up to test that the current value of
the sequence is greater than or equal to the maximum ID in the table.
@@ -122,7 +122,7 @@ class PostgresSequenceGenerator(SequenceGenerator):
id_column: str,
stream_name: Optional[str] = None,
positive: bool = True,
- ):
+ ) -> None:
"""See SequenceGenerator.check_consistency for docstring."""
txn = db_conn.cursor(txn_name="sequence.check_consistency")
@@ -244,7 +244,7 @@ class LocalSequenceGenerator(SequenceGenerator):
id_column: str,
stream_name: Optional[str] = None,
positive: bool = True,
- ):
+ ) -> None:
# There is nothing to do for in memory sequences
pass
diff --git a/synapse/util/__init__.py b/synapse/util/__init__.py
index bd234549..abf53d14 100644
--- a/synapse/util/__init__.py
+++ b/synapse/util/__init__.py
@@ -50,7 +50,16 @@ def _handle_frozendict(obj: Any) -> Dict[Any, Any]:
if type(obj) is frozendict:
# fishing the protected dict out of the object is a bit nasty,
# but we don't really want the overhead of copying the dict.
- return obj._dict
+ try:
+ # Safety: we catch the AttributeError immediately below.
+ # See https://github.com/matrix-org/python-canonicaljson/issues/36#issuecomment-927816293
+ # for discussion on how frozendict's internals have changed over time.
+ return obj._dict # type: ignore[attr-defined]
+ except AttributeError:
+ # When the C implementation of frozendict is used,
+ # there isn't a `_dict` attribute with a dict
+ # so we resort to making a copy of the frozendict
+ return dict(obj)
raise TypeError(
"Object of type %s is not JSON serializable" % obj.__class__.__name__
)
diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py
index 82d918a0..5df80ea8 100644
--- a/synapse/util/async_helpers.py
+++ b/synapse/util/async_helpers.py
@@ -438,7 +438,8 @@ class ReadWriteLock:
try:
yield
finally:
- new_defer.callback(None)
+ with PreserveLoggingContext():
+ new_defer.callback(None)
self.key_to_current_readers.get(key, set()).discard(new_defer)
return _ctx_manager()
@@ -466,7 +467,8 @@ class ReadWriteLock:
try:
yield
finally:
- new_defer.callback(None)
+ with PreserveLoggingContext():
+ new_defer.callback(None)
if self.key_to_current_writer[key] == new_defer:
self.key_to_current_writer.pop(key)
diff --git a/synapse/util/caches/cached_call.py b/synapse/util/caches/cached_call.py
index e58dd91e..470f4f91 100644
--- a/synapse/util/caches/cached_call.py
+++ b/synapse/util/caches/cached_call.py
@@ -85,7 +85,7 @@ class CachedCall(Generic[TV]):
# result in the deferred, since `awaiting` a deferred destroys its result.
# (Also, if it's a Failure, GCing the deferred would log a critical error
# about unhandled Failures)
- def got_result(r):
+ def got_result(r: Union[TV, Failure]) -> None:
self._result = r
self._deferred.addBoth(got_result)
diff --git a/synapse/util/caches/deferred_cache.py b/synapse/util/caches/deferred_cache.py
index 6262efe0..da502aec 100644
--- a/synapse/util/caches/deferred_cache.py
+++ b/synapse/util/caches/deferred_cache.py
@@ -31,6 +31,7 @@ from prometheus_client import Gauge
from twisted.internet import defer
from twisted.python import failure
+from twisted.python.failure import Failure
from synapse.util.async_helpers import ObservableDeferred
from synapse.util.caches.lrucache import LruCache
@@ -112,7 +113,7 @@ class DeferredCache(Generic[KT, VT]):
self.thread: Optional[threading.Thread] = None
@property
- def max_entries(self):
+ def max_entries(self) -> int:
return self.cache.max_size
def check_thread(self) -> None:
@@ -258,7 +259,7 @@ class DeferredCache(Generic[KT, VT]):
return False
- def cb(result) -> None:
+ def cb(result: VT) -> None:
if compare_and_pop():
self.cache.set(key, result, entry.callbacks)
else:
@@ -270,7 +271,7 @@ class DeferredCache(Generic[KT, VT]):
# not have been. Either way, let's double-check now.
entry.invalidate()
- def eb(_fail) -> None:
+ def eb(_fail: Failure) -> None:
compare_and_pop()
entry.invalidate()
@@ -284,11 +285,11 @@ class DeferredCache(Generic[KT, VT]):
def prefill(
self, key: KT, value: VT, callback: Optional[Callable[[], None]] = None
- ):
+ ) -> None:
callbacks = [callback] if callback else []
self.cache.set(key, value, callbacks=callbacks)
- def invalidate(self, key):
+ def invalidate(self, key) -> None:
"""Delete a key, or tree of entries
If the cache is backed by a regular dict, then "key" must be of
diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py
index 4ff62b40..a0a7a9de 100644
--- a/synapse/util/caches/lrucache.py
+++ b/synapse/util/caches/lrucache.py
@@ -52,7 +52,7 @@ logger = logging.getLogger(__name__)
try:
from pympler.asizeof import Asizer
- def _get_size_of(val: Any, *, recurse=True) -> int:
+ def _get_size_of(val: Any, *, recurse: bool = True) -> int:
"""Get an estimate of the size in bytes of the object.
Args:
@@ -71,7 +71,7 @@ try:
except ImportError:
- def _get_size_of(val: Any, *, recurse=True) -> int:
+ def _get_size_of(val: Any, *, recurse: bool = True) -> int:
return 0
@@ -85,15 +85,6 @@ VT = TypeVar("VT")
# a general type var, distinct from either KT or VT
T = TypeVar("T")
-
-def enumerate_leaves(node, depth):
- if depth == 0:
- yield node
- else:
- for n in node.values():
- yield from enumerate_leaves(n, depth - 1)
-
-
P = TypeVar("P")
@@ -102,7 +93,7 @@ class _TimedListNode(ListNode[P]):
__slots__ = ["last_access_ts_secs"]
- def update_last_access(self, clock: Clock):
+ def update_last_access(self, clock: Clock) -> None:
self.last_access_ts_secs = int(clock.time())
@@ -115,7 +106,7 @@ GLOBAL_ROOT = ListNode["_Node"].create_root_node()
@wrap_as_background_process("LruCache._expire_old_entries")
-async def _expire_old_entries(clock: Clock, expiry_seconds: int):
+async def _expire_old_entries(clock: Clock, expiry_seconds: int) -> None:
"""Walks the global cache list to find cache entries that haven't been
accessed in the given number of seconds.
"""
@@ -163,7 +154,7 @@ async def _expire_old_entries(clock: Clock, expiry_seconds: int):
logger.info("Dropped %d items from caches", i)
-def setup_expire_lru_cache_entries(hs: "HomeServer"):
+def setup_expire_lru_cache_entries(hs: "HomeServer") -> None:
"""Start a background job that expires all cache entries if they have not
been accessed for the given number of seconds.
"""
@@ -183,7 +174,7 @@ def setup_expire_lru_cache_entries(hs: "HomeServer"):
)
-class _Node:
+class _Node(Generic[KT, VT]):
__slots__ = [
"_list_node",
"_global_list_node",
@@ -197,8 +188,8 @@ class _Node:
def __init__(
self,
root: "ListNode[_Node]",
- key,
- value,
+ key: KT,
+ value: VT,
cache: "weakref.ReferenceType[LruCache]",
clock: Clock,
callbacks: Collection[Callable[[], None]] = (),
@@ -409,7 +400,7 @@ class LruCache(Generic[KT, VT]):
def synchronized(f: FT) -> FT:
@wraps(f)
- def inner(*args, **kwargs):
+ def inner(*args: Any, **kwargs: Any) -> Any:
with lock:
return f(*args, **kwargs)
@@ -418,17 +409,19 @@ class LruCache(Generic[KT, VT]):
cached_cache_len = [0]
if size_callback is not None:
- def cache_len():
+ def cache_len() -> int:
return cached_cache_len[0]
else:
- def cache_len():
+ def cache_len() -> int:
return len(cache)
self.len = synchronized(cache_len)
- def add_node(key, value, callbacks: Collection[Callable[[], None]] = ()):
+ def add_node(
+ key: KT, value: VT, callbacks: Collection[Callable[[], None]] = ()
+ ) -> None:
node = _Node(
list_root,
key,
@@ -446,7 +439,7 @@ class LruCache(Generic[KT, VT]):
if caches.TRACK_MEMORY_USAGE and metrics:
metrics.inc_memory_usage(node.memory)
- def move_node_to_front(node: _Node):
+ def move_node_to_front(node: _Node) -> None:
node.move_to_front(real_clock, list_root)
def delete_node(node: _Node) -> int:
@@ -488,7 +481,7 @@ class LruCache(Generic[KT, VT]):
default: Optional[T] = None,
callbacks: Collection[Callable[[], None]] = (),
update_metrics: bool = True,
- ):
+ ) -> Union[None, T, VT]:
node = cache.get(key, None)
if node is not None:
move_node_to_front(node)
@@ -502,7 +495,9 @@ class LruCache(Generic[KT, VT]):
return default
@synchronized
- def cache_set(key: KT, value: VT, callbacks: Iterable[Callable[[], None]] = ()):
+ def cache_set(
+ key: KT, value: VT, callbacks: Iterable[Callable[[], None]] = ()
+ ) -> None:
node = cache.get(key, None)
if node is not None:
# We sometimes store large objects, e.g. dicts, which cause
@@ -547,7 +542,7 @@ class LruCache(Generic[KT, VT]):
...
@synchronized
- def cache_pop(key: KT, default: Optional[T] = None):
+ def cache_pop(key: KT, default: Optional[T] = None) -> Union[None, T, VT]:
node = cache.get(key, None)
if node:
delete_node(node)
@@ -612,25 +607,25 @@ class LruCache(Generic[KT, VT]):
self.contains = cache_contains
self.clear = cache_clear
- def __getitem__(self, key):
+ def __getitem__(self, key: KT) -> VT:
result = self.get(key, self.sentinel)
if result is self.sentinel:
raise KeyError()
else:
- return result
+ return cast(VT, result)
- def __setitem__(self, key, value):
+ def __setitem__(self, key: KT, value: VT) -> None:
self.set(key, value)
- def __delitem__(self, key, value):
+ def __delitem__(self, key: KT, value: VT) -> None:
result = self.pop(key, self.sentinel)
if result is self.sentinel:
raise KeyError()
- def __len__(self):
+ def __len__(self) -> int:
return self.len()
- def __contains__(self, key):
+ def __contains__(self, key: KT) -> bool:
return self.contains(key)
def set_cache_factor(self, factor: float) -> bool:
diff --git a/synapse/util/caches/response_cache.py b/synapse/util/caches/response_cache.py
index ed720433..88ccf443 100644
--- a/synapse/util/caches/response_cache.py
+++ b/synapse/util/caches/response_cache.py
@@ -104,8 +104,8 @@ class ResponseCache(Generic[KV]):
return None
def _set(
- self, context: ResponseCacheContext[KV], deferred: defer.Deferred
- ) -> defer.Deferred:
+ self, context: ResponseCacheContext[KV], deferred: "defer.Deferred[RV]"
+ ) -> "defer.Deferred[RV]":
"""Set the entry for the given key to the given deferred.
*deferred* should run its callbacks in the sentinel logcontext (ie,
@@ -126,7 +126,7 @@ class ResponseCache(Generic[KV]):
key = context.cache_key
self.pending_result_cache[key] = result
- def on_complete(r):
+ def on_complete(r: RV) -> RV:
# if this cache has a non-zero timeout, and the callback has not cleared
# the should_cache bit, we leave it in the cache for now and schedule
# its removal later.
diff --git a/synapse/util/caches/stream_change_cache.py b/synapse/util/caches/stream_change_cache.py
index 27b1da23..330709b8 100644
--- a/synapse/util/caches/stream_change_cache.py
+++ b/synapse/util/caches/stream_change_cache.py
@@ -40,10 +40,10 @@ class StreamChangeCache:
self,
name: str,
current_stream_pos: int,
- max_size=10000,
+ max_size: int = 10000,
prefilled_cache: Optional[Mapping[EntityType, int]] = None,
- ):
- self._original_max_size = max_size
+ ) -> None:
+ self._original_max_size: int = max_size
self._max_size = math.floor(max_size)
self._entity_to_key: Dict[EntityType, int] = {}
diff --git a/synapse/util/caches/ttlcache.py b/synapse/util/caches/ttlcache.py
index 46afe3f9..0b9ac26b 100644
--- a/synapse/util/caches/ttlcache.py
+++ b/synapse/util/caches/ttlcache.py
@@ -159,12 +159,12 @@ class TTLCache(Generic[KT, VT]):
del self._expiry_list[0]
-@attr.s(frozen=True, slots=True)
-class _CacheEntry:
+@attr.s(frozen=True, slots=True, auto_attribs=True)
+class _CacheEntry: # Should be Generic[KT, VT]. See python-attrs/attrs#313
"""TTLCache entry"""
# expiry_time is the first attribute, so that entries are sorted by expiry.
- expiry_time = attr.ib(type=float)
- ttl = attr.ib(type=float)
- key = attr.ib()
- value = attr.ib()
+ expiry_time: float
+ ttl: float
+ key: Any # should be KT
+ value: Any # should be VT
diff --git a/synapse/util/daemonize.py b/synapse/util/daemonize.py
index f1a351cf..de04f34e 100644
--- a/synapse/util/daemonize.py
+++ b/synapse/util/daemonize.py
@@ -19,6 +19,8 @@ import logging
import os
import signal
import sys
+from types import FrameType, TracebackType
+from typing import NoReturn, Type
def daemonize_process(pid_file: str, logger: logging.Logger, chdir: str = "/") -> None:
@@ -97,7 +99,9 @@ def daemonize_process(pid_file: str, logger: logging.Logger, chdir: str = "/") -
# (we don't normally expect reactor.run to raise any exceptions, but this will
# also catch any other uncaught exceptions before we get that far.)
- def excepthook(type_, value, traceback):
+ def excepthook(
+ type_: Type[BaseException], value: BaseException, traceback: TracebackType
+ ) -> None:
logger.critical("Unhanded exception", exc_info=(type_, value, traceback))
sys.excepthook = excepthook
@@ -119,7 +123,7 @@ def daemonize_process(pid_file: str, logger: logging.Logger, chdir: str = "/") -
sys.exit(1)
# write a log line on SIGTERM.
- def sigterm(signum, frame):
+ def sigterm(signum: signal.Signals, frame: FrameType) -> NoReturn:
logger.warning("Caught signal %s. Stopping daemon." % signum)
sys.exit(0)
diff --git a/synapse/util/metrics.py b/synapse/util/metrics.py
index 1b82dca8..1e784b3f 100644
--- a/synapse/util/metrics.py
+++ b/synapse/util/metrics.py
@@ -14,9 +14,11 @@
import logging
from functools import wraps
-from typing import Any, Callable, Optional, TypeVar, cast
+from types import TracebackType
+from typing import Any, Callable, Optional, Type, TypeVar, cast
from prometheus_client import Counter
+from typing_extensions import Protocol
from synapse.logging.context import (
ContextResourceUsage,
@@ -24,6 +26,7 @@ from synapse.logging.context import (
current_context,
)
from synapse.metrics import InFlightGauge
+from synapse.util import Clock
logger = logging.getLogger(__name__)
@@ -64,6 +67,10 @@ in_flight = InFlightGauge(
T = TypeVar("T", bound=Callable[..., Any])
+class HasClock(Protocol):
+ clock: Clock
+
+
def measure_func(name: Optional[str] = None) -> Callable[[T], T]:
"""
Used to decorate an async function with a `Measure` context manager.
@@ -86,7 +93,7 @@ def measure_func(name: Optional[str] = None) -> Callable[[T], T]:
block_name = func.__name__ if name is None else name
@wraps(func)
- async def measured_func(self, *args, **kwargs):
+ async def measured_func(self: HasClock, *args: Any, **kwargs: Any) -> Any:
with Measure(self.clock, block_name):
r = await func(self, *args, **kwargs)
return r
@@ -104,10 +111,10 @@ class Measure:
"start",
]
- def __init__(self, clock, name: str):
+ def __init__(self, clock: Clock, name: str) -> None:
"""
Args:
- clock: A n object with a "time()" method, which returns the current
+ clock: An object with a "time()" method, which returns the current
time in seconds.
name: The name of the metric to report.
"""
@@ -124,7 +131,7 @@ class Measure:
assert isinstance(curr_context, LoggingContext)
parent_context = curr_context
self._logging_context = LoggingContext(str(curr_context), parent_context)
- self.start: Optional[int] = None
+ self.start: Optional[float] = None
def __enter__(self) -> "Measure":
if self.start is not None:
@@ -138,7 +145,12 @@ class Measure:
return self
- def __exit__(self, exc_type, exc_val, exc_tb):
+ def __exit__(
+ self,
+ exc_type: Optional[Type[BaseException]],
+ exc_val: Optional[BaseException],
+ exc_tb: Optional[TracebackType],
+ ) -> None:
if self.start is None:
raise RuntimeError("Measure() block exited without being entered")
@@ -168,8 +180,9 @@ class Measure:
"""
return self._logging_context.get_resource_usage()
- def _update_in_flight(self, metrics):
+ def _update_in_flight(self, metrics) -> None:
"""Gets called when processing in flight metrics"""
+ assert self.start is not None
duration = self.clock.time() - self.start
metrics.real_time_max = max(metrics.real_time_max, duration)
diff --git a/synapse/util/patch_inline_callbacks.py b/synapse/util/patch_inline_callbacks.py
index 9dd010af..1f18654d 100644
--- a/synapse/util/patch_inline_callbacks.py
+++ b/synapse/util/patch_inline_callbacks.py
@@ -14,7 +14,7 @@
import functools
import sys
-from typing import Any, Callable, List
+from typing import Any, Callable, Generator, List, TypeVar
from twisted.internet import defer
from twisted.internet.defer import Deferred
@@ -24,6 +24,9 @@ from twisted.python.failure import Failure
_already_patched = False
+T = TypeVar("T")
+
+
def do_patch() -> None:
"""
Patch defer.inlineCallbacks so that it checks the state of the logcontext on exit
@@ -37,15 +40,19 @@ def do_patch() -> None:
if _already_patched:
return
- def new_inline_callbacks(f):
+ def new_inline_callbacks(
+ f: Callable[..., Generator["Deferred[object]", object, T]]
+ ) -> Callable[..., "Deferred[T]"]:
@functools.wraps(f)
- def wrapped(*args, **kwargs):
+ def wrapped(*args: Any, **kwargs: Any) -> "Deferred[T]":
start_context = current_context()
changes: List[str] = []
- orig = orig_inline_callbacks(_check_yield_points(f, changes))
+ orig: Callable[..., "Deferred[T]"] = orig_inline_callbacks(
+ _check_yield_points(f, changes)
+ )
try:
- res = orig(*args, **kwargs)
+ res: "Deferred[T]" = orig(*args, **kwargs)
except Exception:
if current_context() != start_context:
for err in changes:
@@ -84,7 +91,7 @@ def do_patch() -> None:
print(err, file=sys.stderr)
raise Exception(err)
- def check_ctx(r):
+ def check_ctx(r: T) -> T:
if current_context() != start_context:
for err in changes:
print(err, file=sys.stderr)
@@ -107,7 +114,10 @@ def do_patch() -> None:
_already_patched = True
-def _check_yield_points(f: Callable, changes: List[str]) -> Callable:
+def _check_yield_points(
+ f: Callable[..., Generator["Deferred[object]", object, T]],
+ changes: List[str],
+) -> Callable:
"""Wraps a generator that is about to be passed to defer.inlineCallbacks
checking that after every yield the log contexts are correct.
@@ -127,7 +137,9 @@ def _check_yield_points(f: Callable, changes: List[str]) -> Callable:
from synapse.logging.context import current_context
@functools.wraps(f)
- def check_yield_points_inner(*args, **kwargs):
+ def check_yield_points_inner(
+ *args: Any, **kwargs: Any
+ ) -> Generator["Deferred[object]", object, T]:
gen = f(*args, **kwargs)
last_yield_line_no = gen.gi_frame.f_lineno
diff --git a/synapse/util/threepids.py b/synapse/util/threepids.py
index baa9190a..389adf00 100644
--- a/synapse/util/threepids.py
+++ b/synapse/util/threepids.py
@@ -44,8 +44,8 @@ def check_3pid_allowed(hs: "HomeServer", medium: str, address: str) -> bool:
bool: whether the 3PID medium/address is allowed to be added to this HS
"""
- if hs.config.allowed_local_3pids:
- for constraint in hs.config.allowed_local_3pids:
+ if hs.config.registration.allowed_local_3pids:
+ for constraint in hs.config.registration.allowed_local_3pids:
logger.debug(
"Checking 3PID %s (%s) against %s (%s)",
address,
diff --git a/synapse/util/versionstring.py b/synapse/util/versionstring.py
index 1c20b24b..899ee0ad 100644
--- a/synapse/util/versionstring.py
+++ b/synapse/util/versionstring.py
@@ -15,14 +15,18 @@
import logging
import os
import subprocess
+from types import ModuleType
+from typing import Dict
logger = logging.getLogger(__name__)
+version_cache: Dict[ModuleType, str] = {}
-def get_version_string(module) -> str:
+
+def get_version_string(module: ModuleType) -> str:
"""Given a module calculate a git-aware version string for it.
- If called on a module not in a git checkout will return `__verison__`.
+ If called on a module not in a git checkout will return `__version__`.
Args:
module (module)
@@ -31,11 +35,13 @@ def get_version_string(module) -> str:
str
"""
- cached_version = getattr(module, "_synapse_version_string_cache", None)
- if cached_version:
+ cached_version = version_cache.get(module)
+ if cached_version is not None:
return cached_version
- version_string = module.__version__
+ # We want this to fail loudly with an AttributeError. Type-ignore this so
+ # mypy only considers the happy path.
+ version_string = module.__version__ # type: ignore[attr-defined]
try:
null = open(os.devnull, "w")
@@ -97,10 +103,15 @@ def get_version_string(module) -> str:
s for s in (git_branch, git_tag, git_commit, git_dirty) if s
)
- version_string = "%s (%s)" % (module.__version__, git_version)
+ version_string = "%s (%s)" % (
+ # If the __version__ attribute doesn't exist, we'll have failed
+ # loudly above.
+ module.__version__, # type: ignore[attr-defined]
+ git_version,
+ )
except Exception as e:
logger.info("Failed to check for git repository: %s", e)
- module._synapse_version_string_cache = version_string
+ version_cache[module] = version_string
return version_string
diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py
index cccff7af..3aa9ba3c 100644
--- a/tests/api/test_auth.py
+++ b/tests/api/test_auth.py
@@ -217,7 +217,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
user_id = "@baldrick:matrix.org"
macaroon = pymacaroons.Macaroon(
- location=self.hs.config.server_name,
+ location=self.hs.config.server.server_name,
identifier="key",
key=self.hs.config.key.macaroon_secret_key,
)
@@ -239,7 +239,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
user_id = "@baldrick:matrix.org"
macaroon = pymacaroons.Macaroon(
- location=self.hs.config.server_name,
+ location=self.hs.config.server.server_name,
identifier="key",
key=self.hs.config.key.macaroon_secret_key,
)
@@ -268,7 +268,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
self.store.get_monthly_active_count = simple_async_mock(lots_of_users)
e = self.get_failure(self.auth.check_auth_blocking(), ResourceLimitError)
- self.assertEquals(e.value.admin_contact, self.hs.config.admin_contact)
+ self.assertEquals(e.value.admin_contact, self.hs.config.server.admin_contact)
self.assertEquals(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
self.assertEquals(e.value.code, 403)
@@ -303,7 +303,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
appservice = ApplicationService(
"abcd",
- self.hs.config.server_name,
+ self.hs.config.server.server_name,
id="1234",
namespaces={
"users": [{"regex": "@_appservice.*:sender", "exclusive": True}]
@@ -332,7 +332,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
appservice = ApplicationService(
"abcd",
- self.hs.config.server_name,
+ self.hs.config.server.server_name,
id="1234",
namespaces={
"users": [{"regex": "@_appservice.*:sender", "exclusive": True}]
@@ -372,7 +372,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
self.auth_blocking._hs_disabled = True
self.auth_blocking._hs_disabled_message = "Reason for being disabled"
e = self.get_failure(self.auth.check_auth_blocking(), ResourceLimitError)
- self.assertEquals(e.value.admin_contact, self.hs.config.admin_contact)
+ self.assertEquals(e.value.admin_contact, self.hs.config.server.admin_contact)
self.assertEquals(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
self.assertEquals(e.value.code, 403)
@@ -387,7 +387,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
self.auth_blocking._hs_disabled = True
self.auth_blocking._hs_disabled_message = "Reason for being disabled"
e = self.get_failure(self.auth.check_auth_blocking(), ResourceLimitError)
- self.assertEquals(e.value.admin_contact, self.hs.config.admin_contact)
+ self.assertEquals(e.value.admin_contact, self.hs.config.server.admin_contact)
self.assertEquals(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
self.assertEquals(e.value.code, 403)
diff --git a/tests/appservice/test_scheduler.py b/tests/appservice/test_scheduler.py
index a2b5ed20..55f0899b 100644
--- a/tests/appservice/test_scheduler.py
+++ b/tests/appservice/test_scheduler.py
@@ -24,7 +24,7 @@ from synapse.appservice.scheduler import (
from synapse.logging.context import make_deferred_yieldable
from tests import unittest
-from tests.test_utils import make_awaitable
+from tests.test_utils import simple_async_mock
from ..utils import MockClock
@@ -49,11 +49,10 @@ class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase):
txn = Mock(id=txn_id, service=service, events=events)
# mock methods
- self.store.get_appservice_state = Mock(
- return_value=defer.succeed(ApplicationServiceState.UP)
- )
- txn.send = Mock(return_value=make_awaitable(True))
- self.store.create_appservice_txn = Mock(return_value=defer.succeed(txn))
+ self.store.get_appservice_state = simple_async_mock(ApplicationServiceState.UP)
+ txn.send = simple_async_mock(True)
+ txn.complete = simple_async_mock(True)
+ self.store.create_appservice_txn = simple_async_mock(txn)
# actual call
self.successResultOf(defer.ensureDeferred(self.txnctrl.send(service, events)))
@@ -71,10 +70,10 @@ class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase):
events = [Mock(), Mock()]
txn = Mock(id="idhere", service=service, events=events)
- self.store.get_appservice_state = Mock(
- return_value=defer.succeed(ApplicationServiceState.DOWN)
+ self.store.get_appservice_state = simple_async_mock(
+ ApplicationServiceState.DOWN
)
- self.store.create_appservice_txn = Mock(return_value=defer.succeed(txn))
+ self.store.create_appservice_txn = simple_async_mock(txn)
# actual call
self.successResultOf(defer.ensureDeferred(self.txnctrl.send(service, events)))
@@ -94,12 +93,10 @@ class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase):
txn = Mock(id=txn_id, service=service, events=events)
# mock methods
- self.store.get_appservice_state = Mock(
- return_value=defer.succeed(ApplicationServiceState.UP)
- )
- self.store.set_appservice_state = Mock(return_value=defer.succeed(True))
- txn.send = Mock(return_value=make_awaitable(False)) # fails to send
- self.store.create_appservice_txn = Mock(return_value=defer.succeed(txn))
+ self.store.get_appservice_state = simple_async_mock(ApplicationServiceState.UP)
+ self.store.set_appservice_state = simple_async_mock(True)
+ txn.send = simple_async_mock(False) # fails to send
+ self.store.create_appservice_txn = simple_async_mock(txn)
# actual call
self.successResultOf(defer.ensureDeferred(self.txnctrl.send(service, events)))
@@ -122,7 +119,7 @@ class ApplicationServiceSchedulerRecovererTestCase(unittest.TestCase):
self.as_api = Mock()
self.store = Mock()
self.service = Mock()
- self.callback = Mock()
+ self.callback = simple_async_mock()
self.recoverer = _Recoverer(
clock=self.clock,
as_api=self.as_api,
@@ -144,8 +141,8 @@ class ApplicationServiceSchedulerRecovererTestCase(unittest.TestCase):
self.recoverer.recover()
# shouldn't have called anything prior to waiting for exp backoff
self.assertEquals(0, self.store.get_oldest_unsent_txn.call_count)
- txn.send = Mock(return_value=make_awaitable(True))
- txn.complete.return_value = make_awaitable(None)
+ txn.send = simple_async_mock(True)
+ txn.complete = simple_async_mock(None)
# wait for exp backoff
self.clock.advance_time(2)
self.assertEquals(1, txn.send.call_count)
@@ -170,8 +167,8 @@ class ApplicationServiceSchedulerRecovererTestCase(unittest.TestCase):
self.recoverer.recover()
self.assertEquals(0, self.store.get_oldest_unsent_txn.call_count)
- txn.send = Mock(return_value=make_awaitable(False))
- txn.complete.return_value = make_awaitable(None)
+ txn.send = simple_async_mock(False)
+ txn.complete = simple_async_mock(None)
self.clock.advance_time(2)
self.assertEquals(1, txn.send.call_count)
self.assertEquals(0, txn.complete.call_count)
@@ -184,7 +181,7 @@ class ApplicationServiceSchedulerRecovererTestCase(unittest.TestCase):
self.assertEquals(3, txn.send.call_count)
self.assertEquals(0, txn.complete.call_count)
self.assertEquals(0, self.callback.call_count)
- txn.send = Mock(return_value=make_awaitable(True)) # successfully send the txn
+ txn.send = simple_async_mock(True) # successfully send the txn
pop_txn = True # returns the txn the first time, then no more.
self.clock.advance_time(16)
self.assertEquals(1, txn.send.call_count) # new mock reset call count
@@ -195,6 +192,7 @@ class ApplicationServiceSchedulerRecovererTestCase(unittest.TestCase):
class ApplicationServiceSchedulerQueuerTestCase(unittest.TestCase):
def setUp(self):
self.txn_ctrl = Mock()
+ self.txn_ctrl.send = simple_async_mock()
self.queuer = _ServiceQueuer(self.txn_ctrl, MockClock())
def test_send_single_event_no_queue(self):
diff --git a/tests/config/test_base.py b/tests/config/test_base.py
index baa5313f..6a52f862 100644
--- a/tests/config/test_base.py
+++ b/tests/config/test_base.py
@@ -14,23 +14,28 @@
import os.path
import tempfile
+from unittest.mock import Mock
from synapse.config import ConfigError
+from synapse.config._base import Config
from synapse.util.stringutils import random_string
from tests import unittest
-class BaseConfigTestCase(unittest.HomeserverTestCase):
- def prepare(self, reactor, clock, hs):
- self.hs = hs
+class BaseConfigTestCase(unittest.TestCase):
+ def setUp(self):
+ # The root object needs a server property with a public_baseurl.
+ root = Mock()
+ root.server.public_baseurl = "http://test"
+ self.config = Config(root)
def test_loading_missing_templates(self):
# Use a temporary directory that exists on the system, but that isn't likely to
# contain template files
with tempfile.TemporaryDirectory() as tmp_dir:
# Attempt to load an HTML template from our custom template directory
- template = self.hs.config.read_templates(["sso_error.html"], (tmp_dir,))[0]
+ template = self.config.read_templates(["sso_error.html"], (tmp_dir,))[0]
# If no errors, we should've gotten the default template instead
@@ -60,7 +65,7 @@ class BaseConfigTestCase(unittest.HomeserverTestCase):
# Attempt to load the template from our custom template directory
template = (
- self.hs.config.read_templates([template_filename], (tmp_dir,))
+ self.config.read_templates([template_filename], (tmp_dir,))
)[0]
# Render the template
@@ -97,7 +102,7 @@ class BaseConfigTestCase(unittest.HomeserverTestCase):
# Retrieve the template.
template = (
- self.hs.config.read_templates(
+ self.config.read_templates(
[template_filename],
(td.name for td in tempdirs),
)
@@ -118,7 +123,7 @@ class BaseConfigTestCase(unittest.HomeserverTestCase):
# Retrieve the template.
template = (
- self.hs.config.read_templates(
+ self.config.read_templates(
[other_template_name],
(td.name for td in tempdirs),
)
@@ -134,6 +139,6 @@ class BaseConfigTestCase(unittest.HomeserverTestCase):
def test_loading_template_from_nonexistent_custom_directory(self):
with self.assertRaises(ConfigError):
- self.hs.config.read_templates(
+ self.config.read_templates(
["some_filename.html"], ("a_nonexistent_directory",)
)
diff --git a/tests/config/test_cache.py b/tests/config/test_cache.py
index 857d9cd0..4bb82e81 100644
--- a/tests/config/test_cache.py
+++ b/tests/config/test_cache.py
@@ -12,39 +12,34 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from synapse.config._base import Config, RootConfig
from synapse.config.cache import CacheConfig, add_resizable_cache
from synapse.util.caches.lrucache import LruCache
from tests.unittest import TestCase
-class FakeServer(Config):
- section = "server"
-
-
-class TestConfig(RootConfig):
- config_classes = [FakeServer, CacheConfig]
-
-
class CacheConfigTests(TestCase):
def setUp(self):
- # Reset caches before each test
- TestConfig().caches.reset()
+ # Reset caches before each test since there's global state involved.
+ self.config = CacheConfig()
+ self.config.reset()
+
+ def tearDown(self):
+ # Also reset the caches after each test to leave state pristine.
+ self.config.reset()
def test_individual_caches_from_environ(self):
"""
Individual cache factors will be loaded from the environment.
"""
config = {}
- t = TestConfig()
- t.caches._environ = {
+ self.config._environ = {
"SYNAPSE_CACHE_FACTOR_SOMETHING_OR_OTHER": "2",
"SYNAPSE_NOT_CACHE": "BLAH",
}
- t.read_config(config, config_dir_path="", data_dir_path="")
+ self.config.read_config(config, config_dir_path="", data_dir_path="")
- self.assertEqual(dict(t.caches.cache_factors), {"something_or_other": 2.0})
+ self.assertEqual(dict(self.config.cache_factors), {"something_or_other": 2.0})
def test_config_overrides_environ(self):
"""
@@ -52,15 +47,14 @@ class CacheConfigTests(TestCase):
over those in the config.
"""
config = {"caches": {"per_cache_factors": {"foo": 2, "bar": 3}}}
- t = TestConfig()
- t.caches._environ = {
+ self.config._environ = {
"SYNAPSE_CACHE_FACTOR_SOMETHING_OR_OTHER": "2",
"SYNAPSE_CACHE_FACTOR_FOO": 1,
}
- t.read_config(config, config_dir_path="", data_dir_path="")
+ self.config.read_config(config, config_dir_path="", data_dir_path="")
self.assertEqual(
- dict(t.caches.cache_factors),
+ dict(self.config.cache_factors),
{"foo": 1.0, "bar": 3.0, "something_or_other": 2.0},
)
@@ -76,8 +70,7 @@ class CacheConfigTests(TestCase):
self.assertEqual(cache.max_size, 50)
config = {"caches": {"per_cache_factors": {"foo": 3}}}
- t = TestConfig()
- t.read_config(config, config_dir_path="", data_dir_path="")
+ self.config.read_config(config)
self.assertEqual(cache.max_size, 300)
@@ -88,8 +81,7 @@ class CacheConfigTests(TestCase):
there is one.
"""
config = {"caches": {"per_cache_factors": {"foo": 2}}}
- t = TestConfig()
- t.read_config(config, config_dir_path="", data_dir_path="")
+ self.config.read_config(config, config_dir_path="", data_dir_path="")
cache = LruCache(100)
add_resizable_cache("foo", cache_resize_callback=cache.set_cache_factor)
@@ -106,8 +98,7 @@ class CacheConfigTests(TestCase):
self.assertEqual(cache.max_size, 50)
config = {"caches": {"global_factor": 4}}
- t = TestConfig()
- t.read_config(config, config_dir_path="", data_dir_path="")
+ self.config.read_config(config, config_dir_path="", data_dir_path="")
self.assertEqual(cache.max_size, 400)
@@ -118,8 +109,7 @@ class CacheConfigTests(TestCase):
is no per-cache factor.
"""
config = {"caches": {"global_factor": 1.5}}
- t = TestConfig()
- t.read_config(config, config_dir_path="", data_dir_path="")
+ self.config.read_config(config, config_dir_path="", data_dir_path="")
cache = LruCache(100)
add_resizable_cache("foo", cache_resize_callback=cache.set_cache_factor)
@@ -133,12 +123,11 @@ class CacheConfigTests(TestCase):
"per_cache_factors": {"*cache_a*": 5, "cache_b": 6, "cache_c": 2}
}
}
- t = TestConfig()
- t.caches._environ = {
+ self.config._environ = {
"SYNAPSE_CACHE_FACTOR_CACHE_A": "2",
"SYNAPSE_CACHE_FACTOR_CACHE_B": 3,
}
- t.read_config(config, config_dir_path="", data_dir_path="")
+ self.config.read_config(config, config_dir_path="", data_dir_path="")
cache_a = LruCache(100)
add_resizable_cache("*cache_a*", cache_resize_callback=cache_a.set_cache_factor)
@@ -158,11 +147,10 @@ class CacheConfigTests(TestCase):
"""
config = {"caches": {"event_cache_size": "10k"}}
- t = TestConfig()
- t.read_config(config, config_dir_path="", data_dir_path="")
+ self.config.read_config(config, config_dir_path="", data_dir_path="")
cache = LruCache(
- max_size=t.caches.event_cache_size,
+ max_size=self.config.event_cache_size,
apply_cache_factor_from_config=False,
)
add_resizable_cache("event_cache", cache_resize_callback=cache.set_cache_factor)
diff --git a/tests/config/test_load.py b/tests/config/test_load.py
index ef6c2bee..59635de2 100644
--- a/tests/config/test_load.py
+++ b/tests/config/test_load.py
@@ -49,7 +49,7 @@ class ConfigLoadingTestCase(unittest.TestCase):
config = HomeServerConfig.load_config("", ["-c", self.file])
self.assertTrue(
- hasattr(config, "macaroon_secret_key"),
+ hasattr(config.key, "macaroon_secret_key"),
"Want config to have attr macaroon_secret_key",
)
if len(config.key.macaroon_secret_key) < 5:
@@ -60,7 +60,7 @@ class ConfigLoadingTestCase(unittest.TestCase):
config = HomeServerConfig.load_or_generate_config("", ["-c", self.file])
self.assertTrue(
- hasattr(config, "macaroon_secret_key"),
+ hasattr(config.key, "macaroon_secret_key"),
"Want config to have attr macaroon_secret_key",
)
if len(config.key.macaroon_secret_key) < 5:
@@ -74,8 +74,12 @@ class ConfigLoadingTestCase(unittest.TestCase):
config1 = HomeServerConfig.load_config("", ["-c", self.file])
config2 = HomeServerConfig.load_config("", ["-c", self.file])
config3 = HomeServerConfig.load_or_generate_config("", ["-c", self.file])
- self.assertEqual(config1.macaroon_secret_key, config2.macaroon_secret_key)
- self.assertEqual(config1.macaroon_secret_key, config3.macaroon_secret_key)
+ self.assertEqual(
+ config1.key.macaroon_secret_key, config2.key.macaroon_secret_key
+ )
+ self.assertEqual(
+ config1.key.macaroon_secret_key, config3.key.macaroon_secret_key
+ )
def test_disable_registration(self):
self.generate_config()
@@ -84,16 +88,16 @@ class ConfigLoadingTestCase(unittest.TestCase):
)
# Check that disable_registration clobbers enable_registration.
config = HomeServerConfig.load_config("", ["-c", self.file])
- self.assertFalse(config.enable_registration)
+ self.assertFalse(config.registration.enable_registration)
config = HomeServerConfig.load_or_generate_config("", ["-c", self.file])
- self.assertFalse(config.enable_registration)
+ self.assertFalse(config.registration.enable_registration)
# Check that either config value is clobbered by the command line.
config = HomeServerConfig.load_or_generate_config(
"", ["-c", self.file, "--enable-registration"]
)
- self.assertTrue(config.enable_registration)
+ self.assertTrue(config.registration.enable_registration)
def test_stats_enabled(self):
self.generate_config_and_remove_lines_containing("enable_metrics")
diff --git a/tests/config/test_tls.py b/tests/config/test_tls.py
index b6bc1876..9ba57815 100644
--- a/tests/config/test_tls.py
+++ b/tests/config/test_tls.py
@@ -42,9 +42,9 @@ class TLSConfigTests(TestCase):
"""
config = {}
t = TestConfig()
- t.read_config(config, config_dir_path="", data_dir_path="")
+ t.tls.read_config(config, config_dir_path="", data_dir_path="")
- self.assertEqual(t.federation_client_minimum_tls_version, "1")
+ self.assertEqual(t.tls.federation_client_minimum_tls_version, "1")
def test_tls_client_minimum_set(self):
"""
@@ -52,29 +52,29 @@ class TLSConfigTests(TestCase):
"""
config = {"federation_client_minimum_tls_version": 1}
t = TestConfig()
- t.read_config(config, config_dir_path="", data_dir_path="")
- self.assertEqual(t.federation_client_minimum_tls_version, "1")
+ t.tls.read_config(config, config_dir_path="", data_dir_path="")
+ self.assertEqual(t.tls.federation_client_minimum_tls_version, "1")
config = {"federation_client_minimum_tls_version": 1.1}
t = TestConfig()
- t.read_config(config, config_dir_path="", data_dir_path="")
- self.assertEqual(t.federation_client_minimum_tls_version, "1.1")
+ t.tls.read_config(config, config_dir_path="", data_dir_path="")
+ self.assertEqual(t.tls.federation_client_minimum_tls_version, "1.1")
config = {"federation_client_minimum_tls_version": 1.2}
t = TestConfig()
- t.read_config(config, config_dir_path="", data_dir_path="")
- self.assertEqual(t.federation_client_minimum_tls_version, "1.2")
+ t.tls.read_config(config, config_dir_path="", data_dir_path="")
+ self.assertEqual(t.tls.federation_client_minimum_tls_version, "1.2")
# Also test a string version
config = {"federation_client_minimum_tls_version": "1"}
t = TestConfig()
- t.read_config(config, config_dir_path="", data_dir_path="")
- self.assertEqual(t.federation_client_minimum_tls_version, "1")
+ t.tls.read_config(config, config_dir_path="", data_dir_path="")
+ self.assertEqual(t.tls.federation_client_minimum_tls_version, "1")
config = {"federation_client_minimum_tls_version": "1.2"}
t = TestConfig()
- t.read_config(config, config_dir_path="", data_dir_path="")
- self.assertEqual(t.federation_client_minimum_tls_version, "1.2")
+ t.tls.read_config(config, config_dir_path="", data_dir_path="")
+ self.assertEqual(t.tls.federation_client_minimum_tls_version, "1.2")
def test_tls_client_minimum_1_point_3_missing(self):
"""
@@ -91,7 +91,7 @@ class TLSConfigTests(TestCase):
config = {"federation_client_minimum_tls_version": 1.3}
t = TestConfig()
with self.assertRaises(ConfigError) as e:
- t.read_config(config, config_dir_path="", data_dir_path="")
+ t.tls.read_config(config, config_dir_path="", data_dir_path="")
self.assertEqual(
e.exception.args[0],
(
@@ -112,8 +112,8 @@ class TLSConfigTests(TestCase):
config = {"federation_client_minimum_tls_version": 1.3}
t = TestConfig()
- t.read_config(config, config_dir_path="", data_dir_path="")
- self.assertEqual(t.federation_client_minimum_tls_version, "1.3")
+ t.tls.read_config(config, config_dir_path="", data_dir_path="")
+ self.assertEqual(t.tls.federation_client_minimum_tls_version, "1.3")
def test_tls_client_minimum_set_passed_through_1_2(self):
"""
@@ -121,7 +121,7 @@ class TLSConfigTests(TestCase):
"""
config = {"federation_client_minimum_tls_version": 1.2}
t = TestConfig()
- t.read_config(config, config_dir_path="", data_dir_path="")
+ t.tls.read_config(config, config_dir_path="", data_dir_path="")
cf = FederationPolicyForHTTPS(t)
options = _get_ssl_context_options(cf._verify_ssl_context)
@@ -137,7 +137,7 @@ class TLSConfigTests(TestCase):
"""
config = {"federation_client_minimum_tls_version": 1}
t = TestConfig()
- t.read_config(config, config_dir_path="", data_dir_path="")
+ t.tls.read_config(config, config_dir_path="", data_dir_path="")
cf = FederationPolicyForHTTPS(t)
options = _get_ssl_context_options(cf._verify_ssl_context)
@@ -159,7 +159,7 @@ class TLSConfigTests(TestCase):
}
t = TestConfig()
e = self.assertRaises(
- ConfigError, t.read_config, config, config_dir_path="", data_dir_path=""
+ ConfigError, t.tls.read_config, config, config_dir_path="", data_dir_path=""
)
self.assertIn("IDNA domain names", str(e))
@@ -174,7 +174,7 @@ class TLSConfigTests(TestCase):
]
}
t = TestConfig()
- t.read_config(config, config_dir_path="", data_dir_path="")
+ t.tls.read_config(config, config_dir_path="", data_dir_path="")
cf = FederationPolicyForHTTPS(t)
diff --git a/tests/events/test_presence_router.py b/tests/events/test_presence_router.py
index 3b3866bf..3deb14c3 100644
--- a/tests/events/test_presence_router.py
+++ b/tests/events/test_presence_router.py
@@ -26,6 +26,7 @@ from synapse.rest.client import login, presence, room
from synapse.types import JsonDict, StreamToken, create_requester
from tests.handlers.test_sync import generate_sync_config
+from tests.test_utils import simple_async_mock
from tests.unittest import FederatingHomeserverTestCase, TestCase, override_config
@@ -133,8 +134,12 @@ class PresenceRouterTestCase(FederatingHomeserverTestCase):
]
def make_homeserver(self, reactor, clock):
+ # Mock out the calls over federation.
+ fed_transport_client = Mock(spec=["send_transaction"])
+ fed_transport_client.send_transaction = simple_async_mock({})
+
hs = self.setup_test_homeserver(
- federation_transport_client=Mock(spec=["send_transaction"]),
+ federation_transport_client=fed_transport_client,
)
# Load the modules into the homeserver
module_api = hs.get_module_api()
diff --git a/tests/federation/test_federation_sender.py b/tests/federation/test_federation_sender.py
index 65b18fbd..b457dad6 100644
--- a/tests/federation/test_federation_sender.py
+++ b/tests/federation/test_federation_sender.py
@@ -336,7 +336,7 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
recovery
"""
mock_send_txn = self.hs.get_federation_transport_client().send_transaction
- mock_send_txn.side_effect = lambda t, cb: defer.fail("fail")
+ mock_send_txn.side_effect = lambda t, cb: defer.fail(AssertionError("fail"))
# create devices
u1 = self.register_user("user", "pass")
@@ -376,7 +376,7 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
This case tests the behaviour when the server has never been reachable.
"""
mock_send_txn = self.hs.get_federation_transport_client().send_transaction
- mock_send_txn.side_effect = lambda t, cb: defer.fail("fail")
+ mock_send_txn.side_effect = lambda t, cb: defer.fail(AssertionError("fail"))
# create devices
u1 = self.register_user("user", "pass")
@@ -429,7 +429,7 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
# now the server goes offline
mock_send_txn = self.hs.get_federation_transport_client().send_transaction
- mock_send_txn.side_effect = lambda t, cb: defer.fail("fail")
+ mock_send_txn.side_effect = lambda t, cb: defer.fail(AssertionError("fail"))
self.login("user", "pass", device_id="D2")
self.login("user", "pass", device_id="D3")
diff --git a/tests/federation/test_federation_server.py b/tests/federation/test_federation_server.py
index 0b60cc42..03e1e11f 100644
--- a/tests/federation/test_federation_server.py
+++ b/tests/federation/test_federation_server.py
@@ -120,7 +120,7 @@ class StateQueryTests(unittest.FederatingHomeserverTestCase):
self.assertEqual(
channel.json_body["room_version"],
- self.hs.config.default_room_version.identifier,
+ self.hs.config.server.default_room_version.identifier,
)
members = set(
diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py
index 57cc3e26..c153018f 100644
--- a/tests/handlers/test_profile.py
+++ b/tests/handlers/test_profile.py
@@ -110,7 +110,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
)
def test_set_my_name_if_disabled(self):
- self.hs.config.enable_set_displayname = False
+ self.hs.config.registration.enable_set_displayname = False
# Setting displayname for the first time is allowed
self.get_success(
@@ -225,7 +225,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
)
def test_set_my_avatar_if_disabled(self):
- self.hs.config.enable_set_avatar_url = False
+ self.hs.config.registration.enable_set_avatar_url = False
# Setting displayname for the first time is allowed
self.get_success(
diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py
index d3efb67e..db691c4c 100644
--- a/tests/handlers/test_register.py
+++ b/tests/handlers/test_register.py
@@ -16,7 +16,12 @@ from unittest.mock import Mock
from synapse.api.auth import Auth
from synapse.api.constants import UserTypes
-from synapse.api.errors import Codes, ResourceLimitError, SynapseError
+from synapse.api.errors import (
+ CodeMessageException,
+ Codes,
+ ResourceLimitError,
+ SynapseError,
+)
from synapse.events.spamcheck import load_legacy_spam_checkers
from synapse.spam_checker_api import RegistrationBehaviour
from synapse.types import RoomAlias, RoomID, UserID, create_requester
@@ -120,14 +125,24 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
hs_config = self.default_config()
# some of the tests rely on us having a user consent version
- hs_config["user_consent"] = {
- "version": "test_consent_version",
- "template_dir": ".",
- }
+ hs_config.setdefault("user_consent", {}).update(
+ {
+ "version": "test_consent_version",
+ "template_dir": ".",
+ }
+ )
hs_config["max_mau_value"] = 50
hs_config["limit_usage_by_mau"] = True
- hs = self.setup_test_homeserver(config=hs_config)
+ # Don't attempt to reach out over federation.
+ self.mock_federation_client = Mock()
+ self.mock_federation_client.make_query.side_effect = CodeMessageException(
+ 500, ""
+ )
+
+ hs = self.setup_test_homeserver(
+ config=hs_config, federation_client=self.mock_federation_client
+ )
load_legacy_spam_checkers(hs)
@@ -138,9 +153,6 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
return hs
def prepare(self, reactor, clock, hs):
- self.mock_distributor = Mock()
- self.mock_distributor.declare("registered_user")
- self.mock_captcha_client = Mock()
self.handler = self.hs.get_registration_handler()
self.store = self.hs.get_datastore()
self.lots_of_users = 100
@@ -174,21 +186,21 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
self.assertEquals(result_user_id, user_id)
self.assertTrue(result_token is not None)
+ @override_config({"limit_usage_by_mau": False})
def test_mau_limits_when_disabled(self):
- self.hs.config.limit_usage_by_mau = False
# Ensure does not throw exception
self.get_success(self.get_or_create_user(self.requester, "a", "display_name"))
+ @override_config({"limit_usage_by_mau": True})
def test_get_or_create_user_mau_not_blocked(self):
- self.hs.config.limit_usage_by_mau = True
self.store.count_monthly_users = Mock(
- return_value=make_awaitable(self.hs.config.max_mau_value - 1)
+ return_value=make_awaitable(self.hs.config.server.max_mau_value - 1)
)
# Ensure does not throw exception
self.get_success(self.get_or_create_user(self.requester, "c", "User"))
+ @override_config({"limit_usage_by_mau": True})
def test_get_or_create_user_mau_blocked(self):
- self.hs.config.limit_usage_by_mau = True
self.store.get_monthly_active_count = Mock(
return_value=make_awaitable(self.lots_of_users)
)
@@ -198,15 +210,15 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
)
self.store.get_monthly_active_count = Mock(
- return_value=make_awaitable(self.hs.config.max_mau_value)
+ return_value=make_awaitable(self.hs.config.server.max_mau_value)
)
self.get_failure(
self.get_or_create_user(self.requester, "b", "display_name"),
ResourceLimitError,
)
+ @override_config({"limit_usage_by_mau": True})
def test_register_mau_blocked(self):
- self.hs.config.limit_usage_by_mau = True
self.store.get_monthly_active_count = Mock(
return_value=make_awaitable(self.lots_of_users)
)
@@ -215,16 +227,16 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
)
self.store.get_monthly_active_count = Mock(
- return_value=make_awaitable(self.hs.config.max_mau_value)
+ return_value=make_awaitable(self.hs.config.server.max_mau_value)
)
self.get_failure(
self.handler.register_user(localpart="local_part"), ResourceLimitError
)
+ @override_config(
+ {"auto_join_rooms": ["#room:test"], "auto_join_rooms_for_guests": False}
+ )
def test_auto_join_rooms_for_guests(self):
- room_alias_str = "#room:test"
- self.hs.config.auto_join_rooms = [room_alias_str]
- self.hs.config.auto_join_rooms_for_guests = False
user_id = self.get_success(
self.handler.register_user(localpart="jeff", make_guest=True),
)
@@ -243,34 +255,33 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
self.assertTrue(room_id["room_id"] in rooms)
self.assertEqual(len(rooms), 1)
+ @override_config({"auto_join_rooms": []})
def test_auto_create_auto_join_rooms_with_no_rooms(self):
- self.hs.config.auto_join_rooms = []
frank = UserID.from_string("@frank:test")
user_id = self.get_success(self.handler.register_user(frank.localpart))
self.assertEqual(user_id, frank.to_string())
rooms = self.get_success(self.store.get_rooms_for_user(user_id))
self.assertEqual(len(rooms), 0)
+ @override_config({"auto_join_rooms": ["#room:another"]})
def test_auto_create_auto_join_where_room_is_another_domain(self):
- self.hs.config.auto_join_rooms = ["#room:another"]
frank = UserID.from_string("@frank:test")
user_id = self.get_success(self.handler.register_user(frank.localpart))
self.assertEqual(user_id, frank.to_string())
rooms = self.get_success(self.store.get_rooms_for_user(user_id))
self.assertEqual(len(rooms), 0)
+ @override_config(
+ {"auto_join_rooms": ["#room:test"], "autocreate_auto_join_rooms": False}
+ )
def test_auto_create_auto_join_where_auto_create_is_false(self):
- self.hs.config.autocreate_auto_join_rooms = False
- room_alias_str = "#room:test"
- self.hs.config.auto_join_rooms = [room_alias_str]
user_id = self.get_success(self.handler.register_user(localpart="jeff"))
rooms = self.get_success(self.store.get_rooms_for_user(user_id))
self.assertEqual(len(rooms), 0)
+ @override_config({"auto_join_rooms": ["#room:test"]})
def test_auto_create_auto_join_rooms_when_user_is_not_a_real_user(self):
room_alias_str = "#room:test"
- self.hs.config.auto_join_rooms = [room_alias_str]
-
self.store.is_real_user = Mock(return_value=make_awaitable(False))
user_id = self.get_success(self.handler.register_user(localpart="support"))
rooms = self.get_success(self.store.get_rooms_for_user(user_id))
@@ -294,10 +305,8 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
self.assertTrue(room_id["room_id"] in rooms)
self.assertEqual(len(rooms), 1)
+ @override_config({"auto_join_rooms": ["#room:test"]})
def test_auto_create_auto_join_rooms_when_user_is_not_the_first_real_user(self):
- room_alias_str = "#room:test"
- self.hs.config.auto_join_rooms = [room_alias_str]
-
self.store.count_real_users = Mock(return_value=make_awaitable(2))
self.store.is_real_user = Mock(return_value=make_awaitable(True))
user_id = self.get_success(self.handler.register_user(localpart="real"))
@@ -510,6 +519,17 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
self.assertEqual(rooms, set())
self.assertEqual(invited_rooms, [])
+ @override_config(
+ {
+ "user_consent": {
+ "block_events_error": "Error",
+ "require_at_registration": True,
+ },
+ "form_secret": "53cr3t",
+ "public_baseurl": "http://test",
+ "auto_join_rooms": ["#room:test"],
+ },
+ )
def test_auto_create_auto_join_where_no_consent(self):
"""Test to ensure that the first user is not auto-joined to a room if
they have not given general consent.
@@ -521,25 +541,20 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
# * The server is configured to auto-join to a room
# (and autocreate if necessary)
- event_creation_handler = self.hs.get_event_creation_handler()
- # (Messing with the internals of event_creation_handler is fragile
- # but can't see a better way to do this. One option could be to subclass
- # the test with custom config.)
- event_creation_handler._block_events_without_consent_error = "Error"
- event_creation_handler._consent_uri_builder = Mock()
- room_alias_str = "#room:test"
- self.hs.config.auto_join_rooms = [room_alias_str]
-
# When:-
- # * the user is registered and post consent actions are called
+ # * the user is registered
user_id = self.get_success(self.handler.register_user(localpart="jeff"))
- self.get_success(self.handler.post_consent_actions(user_id))
# Then:-
# * Ensure that they have not been joined to the room
rooms = self.get_success(self.store.get_rooms_for_user(user_id))
self.assertEqual(len(rooms), 0)
+ # The user provides consent; ensure they are now in the rooms.
+ self.get_success(self.handler.post_consent_actions(user_id))
+ rooms = self.get_success(self.store.get_rooms_for_user(user_id))
+ self.assertEqual(len(rooms), 1)
+
def test_register_support_user(self):
user_id = self.get_success(
self.handler.register_user(localpart="user", user_type=UserTypes.SUPPORT)
diff --git a/tests/handlers/test_stats.py b/tests/handlers/test_stats.py
index 24b7ef6e..56207f4d 100644
--- a/tests/handlers/test_stats.py
+++ b/tests/handlers/test_stats.py
@@ -103,12 +103,7 @@ class StatsRoomTests(unittest.HomeserverTestCase):
# Do the initial population of the stats via the background update
self._add_background_updates()
- while not self.get_success(
- self.store.db_pool.updates.has_completed_background_updates()
- ):
- self.get_success(
- self.store.db_pool.updates.do_next_background_update(100), by=0.1
- )
+ self.wait_for_background_updates()
def test_initial_room(self):
"""
@@ -140,12 +135,7 @@ class StatsRoomTests(unittest.HomeserverTestCase):
# Do the initial population of the user directory via the background update
self._add_background_updates()
- while not self.get_success(
- self.store.db_pool.updates.has_completed_background_updates()
- ):
- self.get_success(
- self.store.db_pool.updates.do_next_background_update(100), by=0.1
- )
+ self.wait_for_background_updates()
r = self.get_success(self.get_all_room_state())
@@ -568,12 +558,7 @@ class StatsRoomTests(unittest.HomeserverTestCase):
)
)
- while not self.get_success(
- self.store.db_pool.updates.has_completed_background_updates()
- ):
- self.get_success(
- self.store.db_pool.updates.do_next_background_update(100), by=0.1
- )
+ self.wait_for_background_updates()
r1stats_complete = self._get_current_stats("room", r1)
u1stats_complete = self._get_current_stats("user", u1)
diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py
index 266333c5..0120b468 100644
--- a/tests/handlers/test_user_directory.py
+++ b/tests/handlers/test_user_directory.py
@@ -11,47 +11,238 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import List, Tuple
-from unittest.mock import Mock
+from typing import Tuple
+from unittest.mock import Mock, patch
from urllib.parse import quote
from twisted.internet import defer
+from twisted.test.proto_helpers import MemoryReactor
import synapse.rest.admin
from synapse.api.constants import UserTypes
from synapse.api.room_versions import RoomVersion, RoomVersions
-from synapse.rest.client import login, room, user_directory
+from synapse.appservice import ApplicationService
+from synapse.rest.client import login, register, room, user_directory
+from synapse.server import HomeServer
from synapse.storage.roommember import ProfileInfo
from synapse.types import create_requester
+from synapse.util import Clock
from tests import unittest
+from tests.storage.test_user_directory import GetUserDirectoryTables
+from tests.test_utils.event_injection import inject_member_event
from tests.unittest import override_config
class UserDirectoryTestCase(unittest.HomeserverTestCase):
- """
- Tests the UserDirectoryHandler.
+ """Tests the UserDirectoryHandler.
+
+ We're broadly testing two kinds of things here.
+
+ 1. Check that we correctly update the user directory in response
+ to events (e.g. join a room, leave a room, change name, make public)
+ 2. Check that the search logic behaves as expected.
+
+ The background process that rebuilds the user directory is tested in
+ tests/storage/test_user_directory.py.
"""
servlets = [
login.register_servlets,
synapse.rest.admin.register_servlets,
+ register.register_servlets,
room.register_servlets,
]
- def make_homeserver(self, reactor, clock):
-
+ def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
config = self.default_config()
config["update_user_directory"] = True
- return self.setup_test_homeserver(config=config)
- def prepare(self, reactor, clock, hs):
+ self.appservice = ApplicationService(
+ token="i_am_an_app_service",
+ hostname="test",
+ id="1234",
+ namespaces={"users": [{"regex": r"@as_user.*", "exclusive": True}]},
+ # Note: this user does not match the regex above, so that tests
+ # can distinguish the sender from the AS user.
+ sender="@as_main:test",
+ )
+
+ mock_load_appservices = Mock(return_value=[self.appservice])
+ with patch(
+ "synapse.storage.databases.main.appservice.load_appservices",
+ mock_load_appservices,
+ ):
+ hs = self.setup_test_homeserver(config=config)
+ return hs
+
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = hs.get_datastore()
self.handler = hs.get_user_directory_handler()
self.event_builder_factory = self.hs.get_event_builder_factory()
self.event_creation_handler = self.hs.get_event_creation_handler()
+ self.user_dir_helper = GetUserDirectoryTables(self.store)
+
+ def test_normal_user_pair(self) -> None:
+ """Sanity check that the room-sharing tables are updated correctly."""
+ alice = self.register_user("alice", "pass")
+ alice_token = self.login(alice, "pass")
+ bob = self.register_user("bob", "pass")
+ bob_token = self.login(bob, "pass")
+
+ public = self.helper.create_room_as(
+ alice,
+ is_public=True,
+ extra_content={"visibility": "public"},
+ tok=alice_token,
+ )
+ private = self.helper.create_room_as(alice, is_public=False, tok=alice_token)
+ self.helper.invite(private, alice, bob, tok=alice_token)
+ self.helper.join(public, bob, tok=bob_token)
+ self.helper.join(private, bob, tok=bob_token)
+
+ # Alice also makes a second public room but no-one else joins
+ public2 = self.helper.create_room_as(
+ alice,
+ is_public=True,
+ extra_content={"visibility": "public"},
+ tok=alice_token,
+ )
+
+ users = self.get_success(self.user_dir_helper.get_users_in_user_directory())
+ in_public = self.get_success(self.user_dir_helper.get_users_in_public_rooms())
+ in_private = self.get_success(
+ self.user_dir_helper.get_users_who_share_private_rooms()
+ )
+
+ self.assertEqual(users, {alice, bob})
+ self.assertEqual(
+ set(in_public), {(alice, public), (bob, public), (alice, public2)}
+ )
+ self.assertEqual(
+ self.user_dir_helper._compress_shared(in_private),
+ {(alice, bob, private), (bob, alice, private)},
+ )
+
+ # The next four tests (test_excludes_*) all setup
+ # - A normal user included in the user dir
+ # - A public and private room created by that user
+ # - A user excluded from the room dir, belonging to both rooms
+
+ # They match similar logic in storage/test_user_directory. But that tests
+ # rebuilding the directory; this tests updating it incrementally.
+
+ def test_excludes_support_user(self) -> None:
+ alice = self.register_user("alice", "pass")
+ alice_token = self.login(alice, "pass")
+ support = "@support1:test"
+ self.get_success(
+ self.store.register_user(
+ user_id=support, password_hash=None, user_type=UserTypes.SUPPORT
+ )
+ )
+
+ public, private = self._create_rooms_and_inject_memberships(
+ alice, alice_token, support
+ )
+ self._check_only_one_user_in_directory(alice, public)
+
+ def test_excludes_deactivated_user(self) -> None:
+ admin = self.register_user("admin", "pass", admin=True)
+ admin_token = self.login(admin, "pass")
+ user = self.register_user("naughty", "pass")
+
+ # Deactivate the user.
+ channel = self.make_request(
+ "PUT",
+ f"/_synapse/admin/v2/users/{user}",
+ access_token=admin_token,
+ content={"deactivated": True},
+ )
+ self.assertEqual(channel.code, 200)
+ self.assertEqual(channel.json_body["deactivated"], True)
+
+ # Join the deactivated user to rooms owned by the admin.
+ # Is this something that could actually happen outside of a test?
+ public, private = self._create_rooms_and_inject_memberships(
+ admin, admin_token, user
+ )
+ self._check_only_one_user_in_directory(admin, public)
+
+ def test_excludes_appservices_user(self) -> None:
+ # Register an AS user.
+ user = self.register_user("user", "pass")
+ token = self.login(user, "pass")
+ as_user = self.register_appservice_user("as_user_potato", self.appservice.token)
+
+ # Join the AS user to rooms owned by the normal user.
+ public, private = self._create_rooms_and_inject_memberships(
+ user, token, as_user
+ )
+ self._check_only_one_user_in_directory(user, public)
+
+ def test_excludes_appservice_sender(self) -> None:
+ user = self.register_user("user", "pass")
+ token = self.login(user, "pass")
+ room = self.helper.create_room_as(user, is_public=True, tok=token)
+ self.helper.join(room, self.appservice.sender, tok=self.appservice.token)
+ self._check_only_one_user_in_directory(user, room)
+
+ def test_user_not_in_users_table(self) -> None:
+ """Unclear how it happens, but on matrix.org we've seen join events
+ for users who aren't in the users table. Test that we don't fall over
+ when processing such a user.
+ """
+ user1 = self.register_user("user1", "pass")
+ token1 = self.login(user1, "pass")
+ room = self.helper.create_room_as(user1, is_public=True, tok=token1)
+
+ # Inject a join event for a user who doesn't exist
+ self.get_success(inject_member_event(self.hs, room, "@not-a-user:test", "join"))
+
+ # Another new user registers and joins the room
+ user2 = self.register_user("user2", "pass")
+ token2 = self.login(user2, "pass")
+ self.helper.join(room, user2, tok=token2)
+
+ # The dodgy event should not have stopped us from processing user2's join.
+ in_public = self.get_success(self.user_dir_helper.get_users_in_public_rooms())
+ self.assertEqual(set(in_public), {(user1, room), (user2, room)})
+
+ def _create_rooms_and_inject_memberships(
+ self, creator: str, token: str, joiner: str
+ ) -> Tuple[str, str]:
+ """Create a public and private room as a normal user.
+ Then get the `joiner` into those rooms.
+ """
+ # TODO: Duplicates the same-named method in UserDirectoryInitialPopulationTest.
+ public_room = self.helper.create_room_as(
+ creator,
+ is_public=True,
+ # See https://github.com/matrix-org/synapse/issues/10951
+ extra_content={"visibility": "public"},
+ tok=token,
+ )
+ private_room = self.helper.create_room_as(creator, is_public=False, tok=token)
- def test_handle_local_profile_change_with_support_user(self):
+ # HACK: get the user into these rooms
+ self.get_success(inject_member_event(self.hs, public_room, joiner, "join"))
+ self.get_success(inject_member_event(self.hs, private_room, joiner, "join"))
+
+ return public_room, private_room
+
+ def _check_only_one_user_in_directory(self, user: str, public: str) -> None:
+ users = self.get_success(self.user_dir_helper.get_users_in_user_directory())
+ in_public = self.get_success(self.user_dir_helper.get_users_in_public_rooms())
+ in_private = self.get_success(
+ self.user_dir_helper.get_users_who_share_private_rooms()
+ )
+
+ self.assertEqual(users, {user})
+ self.assertEqual(set(in_public), {(user, public)})
+ self.assertEqual(in_private, [])
+
+ def test_handle_local_profile_change_with_support_user(self) -> None:
support_user_id = "@support:test"
self.get_success(
self.store.register_user(
@@ -64,10 +255,12 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
)
self.get_success(
- self.handler.handle_local_profile_change(support_user_id, None)
+ self.handler.handle_local_profile_change(
+ support_user_id, ProfileInfo("I love support me", None)
+ )
)
profile = self.get_success(self.store.get_user_in_directory(support_user_id))
- self.assertTrue(profile is None)
+ self.assertIsNone(profile)
display_name = "display_name"
profile_info = ProfileInfo(avatar_url="avatar_url", display_name=display_name)
@@ -77,7 +270,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
profile = self.get_success(self.store.get_user_in_directory(regular_user_id))
self.assertTrue(profile["display_name"] == display_name)
- def test_handle_local_profile_change_with_deactivated_user(self):
+ def test_handle_local_profile_change_with_deactivated_user(self) -> None:
# create user
r_user_id = "@regular:test"
self.get_success(
@@ -101,7 +294,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
# profile is not in directory
profile = self.get_success(self.store.get_user_in_directory(r_user_id))
- self.assertTrue(profile is None)
+ self.assertIsNone(profile)
# update profile after deactivation
self.get_success(
@@ -110,9 +303,50 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
# profile is furthermore not in directory
profile = self.get_success(self.store.get_user_in_directory(r_user_id))
- self.assertTrue(profile is None)
+ self.assertIsNone(profile)
+
+ def test_handle_local_profile_change_with_appservice_user(self) -> None:
+ # create user
+ as_user_id = self.register_appservice_user(
+ "as_user_alice", self.appservice.token
+ )
+
+ # profile is not in directory
+ profile = self.get_success(self.store.get_user_in_directory(as_user_id))
+ self.assertIsNone(profile)
- def test_handle_user_deactivated_support_user(self):
+ # update profile
+ profile_info = ProfileInfo(avatar_url="avatar_url", display_name="4L1c3")
+ self.get_success(
+ self.handler.handle_local_profile_change(as_user_id, profile_info)
+ )
+
+ # profile is still not in directory
+ profile = self.get_success(self.store.get_user_in_directory(as_user_id))
+ self.assertIsNone(profile)
+
+ def test_handle_local_profile_change_with_appservice_sender(self) -> None:
+ # profile is not in directory
+ profile = self.get_success(
+ self.store.get_user_in_directory(self.appservice.sender)
+ )
+ self.assertIsNone(profile)
+
+ # update profile
+ profile_info = ProfileInfo(avatar_url="avatar_url", display_name="4L1c3")
+ self.get_success(
+ self.handler.handle_local_profile_change(
+ self.appservice.sender, profile_info
+ )
+ )
+
+ # profile is still not in directory
+ profile = self.get_success(
+ self.store.get_user_in_directory(self.appservice.sender)
+ )
+ self.assertIsNone(profile)
+
+ def test_handle_user_deactivated_support_user(self) -> None:
s_user_id = "@support:test"
self.get_success(
self.store.register_user(
@@ -120,20 +354,29 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
)
)
- self.store.remove_from_user_dir = Mock(return_value=defer.succeed(None))
- self.get_success(self.handler.handle_local_user_deactivated(s_user_id))
- self.store.remove_from_user_dir.not_called()
+ mock_remove_from_user_dir = Mock(return_value=defer.succeed(None))
+ with patch.object(
+ self.store, "remove_from_user_dir", mock_remove_from_user_dir
+ ):
+ self.get_success(self.handler.handle_local_user_deactivated(s_user_id))
+ # BUG: the correct spelling is assert_not_called, but that makes the test fail
+ # and it's not clear that this is actually the behaviour we want.
+ mock_remove_from_user_dir.not_called()
- def test_handle_user_deactivated_regular_user(self):
+ def test_handle_user_deactivated_regular_user(self) -> None:
r_user_id = "@regular:test"
self.get_success(
self.store.register_user(user_id=r_user_id, password_hash=None)
)
- self.store.remove_from_user_dir = Mock(return_value=defer.succeed(None))
- self.get_success(self.handler.handle_local_user_deactivated(r_user_id))
- self.store.remove_from_user_dir.called_once_with(r_user_id)
- def test_reactivation_makes_regular_user_searchable(self):
+ mock_remove_from_user_dir = Mock(return_value=defer.succeed(None))
+ with patch.object(
+ self.store, "remove_from_user_dir", mock_remove_from_user_dir
+ ):
+ self.get_success(self.handler.handle_local_user_deactivated(r_user_id))
+ mock_remove_from_user_dir.assert_called_once_with(r_user_id)
+
+ def test_reactivation_makes_regular_user_searchable(self) -> None:
user = self.register_user("regular", "pass")
user_token = self.login(user, "pass")
admin_user = self.register_user("admin", "pass", admin=True)
@@ -171,7 +414,147 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
self.assertEqual(len(s["results"]), 1)
self.assertEqual(s["results"][0]["user_id"], user)
- def test_private_room(self):
+ def test_process_join_after_server_leaves_room(self) -> None:
+ alice = self.register_user("alice", "pass")
+ alice_token = self.login(alice, "pass")
+ bob = self.register_user("bob", "pass")
+ bob_token = self.login(bob, "pass")
+
+ # Alice makes two rooms. Bob joins one of them.
+ room1 = self.helper.create_room_as(alice, tok=alice_token)
+ room2 = self.helper.create_room_as(alice, tok=alice_token)
+ self.helper.join(room1, bob, tok=bob_token)
+
+ # The user sharing tables should have been updated.
+ public1 = self.get_success(self.user_dir_helper.get_users_in_public_rooms())
+ self.assertEqual(set(public1), {(alice, room1), (alice, room2), (bob, room1)})
+
+ # Alice leaves room1. The user sharing tables should be updated.
+ self.helper.leave(room1, alice, tok=alice_token)
+ public2 = self.get_success(self.user_dir_helper.get_users_in_public_rooms())
+ self.assertEqual(set(public2), {(alice, room2), (bob, room1)})
+
+ # Pause the processing of new events.
+ dir_handler = self.hs.get_user_directory_handler()
+ dir_handler.update_user_directory = False
+
+ # Bob leaves one room and joins the other.
+ self.helper.leave(room1, bob, tok=bob_token)
+ self.helper.join(room2, bob, tok=bob_token)
+
+ # Process the leave and join in one go.
+ dir_handler.update_user_directory = True
+ dir_handler.notify_new_event()
+ self.wait_for_background_updates()
+
+ # The user sharing tables should have been updated.
+ public3 = self.get_success(self.user_dir_helper.get_users_in_public_rooms())
+ self.assertEqual(set(public3), {(alice, room2), (bob, room2)})
+
+ def test_per_room_profile_doesnt_alter_directory_entry(self) -> None:
+ alice = self.register_user("alice", "pass")
+ alice_token = self.login(alice, "pass")
+ bob = self.register_user("bob", "pass")
+
+ # Alice should have a user directory entry created at registration.
+ users = self.get_success(self.user_dir_helper.get_profiles_in_user_directory())
+ self.assertEqual(
+ users[alice], ProfileInfo(display_name="alice", avatar_url=None)
+ )
+
+ # Alice makes a room for herself.
+ room = self.helper.create_room_as(alice, is_public=True, tok=alice_token)
+
+ # Alice sets a nickname unique to that room.
+ self.helper.send_state(
+ room,
+ "m.room.member",
+ {
+ "displayname": "Freddy Mercury",
+ "membership": "join",
+ },
+ alice_token,
+ state_key=alice,
+ )
+
+ # Alice's display name remains the same in the user directory.
+ search_result = self.get_success(self.handler.search_users(bob, alice, 10))
+ self.assertEqual(
+ search_result["results"],
+ [{"display_name": "alice", "avatar_url": None, "user_id": alice}],
+ 0,
+ )
+
+ def test_making_room_public_doesnt_alter_directory_entry(self) -> None:
+ """Per-room names shouldn't go to the directory when the room becomes public.
+
+ This isn't about preventing a leak (the room is now public, so the nickname
+ is too). It's about preserving the invariant that we only show a user's public
+ profile in the user directory results.
+
+ I made this a Synapse test case rather than a Complement one because
+ I think this is (strictly speaking) an implementation choice. Synapse
+ has chosen to only ever use the public profile when responding to a user
+ directory search. There's no privacy leak here, because making the room
+ public discloses the per-room name.
+
+ The spec doesn't mandate anything about _how_ a user
+ should appear in a /user_directory/search result. Hypothetical example:
+ suppose Bob searches for Alice. When representing Alice in a search
+ result, it's reasonable to use any of Alice's nicknames that Bob is
+ aware of. Heck, maybe we even want to use lots of them in a combined
+ displayname like `Alice (aka "ali", "ally", "41iC3")`.
+ """
+
+ # TODO the same should apply when Alice is a remote user.
+ alice = self.register_user("alice", "pass")
+ alice_token = self.login(alice, "pass")
+ bob = self.register_user("bob", "pass")
+ bob_token = self.login(bob, "pass")
+
+ # Alice and Bob are in a private room.
+ room = self.helper.create_room_as(alice, is_public=False, tok=alice_token)
+ self.helper.invite(room, src=alice, targ=bob, tok=alice_token)
+ self.helper.join(room, user=bob, tok=bob_token)
+
+ # Alice has a nickname unique to that room.
+
+ self.helper.send_state(
+ room,
+ "m.room.member",
+ {
+ "displayname": "Freddy Mercury",
+ "membership": "join",
+ },
+ alice_token,
+ state_key=alice,
+ )
+
+ # Check Alice isn't recorded as being in a public room.
+ public = self.get_success(self.user_dir_helper.get_users_in_public_rooms())
+ self.assertNotIn((alice, room), public)
+
+ # One of them makes the room public.
+ self.helper.send_state(
+ room,
+ "m.room.join_rules",
+ {"join_rule": "public"},
+ alice_token,
+ )
+
+ # Check that Alice is now recorded as being in a public room
+ public = self.get_success(self.user_dir_helper.get_users_in_public_rooms())
+ self.assertIn((alice, room), public)
+
+ # Alice's display name remains the same in the user directory.
+ search_result = self.get_success(self.handler.search_users(bob, alice, 10))
+ self.assertEqual(
+ search_result["results"],
+ [{"display_name": "alice", "avatar_url": None, "user_id": alice}],
+ 0,
+ )
+
+ def test_private_room(self) -> None:
"""
A user can be searched for only by people that are either in a public
room, or that share a private chat.
@@ -191,11 +574,16 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
self.helper.join(room, user=u2, tok=u2_token)
# Check we have populated the database correctly.
- shares_private = self.get_users_who_share_private_rooms()
- public_users = self.get_users_in_public_rooms()
+ shares_private = self.get_success(
+ self.user_dir_helper.get_users_who_share_private_rooms()
+ )
+ public_users = self.get_success(
+ self.user_dir_helper.get_users_in_public_rooms()
+ )
self.assertEqual(
- self._compress_shared(shares_private), {(u1, u2, room), (u2, u1, room)}
+ self.user_dir_helper._compress_shared(shares_private),
+ {(u1, u2, room), (u2, u1, room)},
)
self.assertEqual(public_users, [])
@@ -215,10 +603,14 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
self.helper.leave(room, user=u2, tok=u2_token)
# Check we have removed the values.
- shares_private = self.get_users_who_share_private_rooms()
- public_users = self.get_users_in_public_rooms()
+ shares_private = self.get_success(
+ self.user_dir_helper.get_users_who_share_private_rooms()
+ )
+ public_users = self.get_success(
+ self.user_dir_helper.get_users_in_public_rooms()
+ )
- self.assertEqual(self._compress_shared(shares_private), set())
+ self.assertEqual(self.user_dir_helper._compress_shared(shares_private), set())
self.assertEqual(public_users, [])
# User1 now gets no search results for any of the other users.
@@ -228,7 +620,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
s = self.get_success(self.handler.search_users(u1, "user3", 10))
self.assertEqual(len(s["results"]), 0)
- def test_spam_checker(self):
+ def test_spam_checker(self) -> None:
"""
A user which fails the spam checks will not appear in search results.
"""
@@ -246,11 +638,16 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
self.helper.join(room, user=u2, tok=u2_token)
# Check we have populated the database correctly.
- shares_private = self.get_users_who_share_private_rooms()
- public_users = self.get_users_in_public_rooms()
+ shares_private = self.get_success(
+ self.user_dir_helper.get_users_who_share_private_rooms()
+ )
+ public_users = self.get_success(
+ self.user_dir_helper.get_users_in_public_rooms()
+ )
self.assertEqual(
- self._compress_shared(shares_private), {(u1, u2, room), (u2, u1, room)}
+ self.user_dir_helper._compress_shared(shares_private),
+ {(u1, u2, room), (u2, u1, room)},
)
self.assertEqual(public_users, [])
@@ -258,7 +655,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
s = self.get_success(self.handler.search_users(u1, "user2", 10))
self.assertEqual(len(s["results"]), 1)
- async def allow_all(user_profile):
+ async def allow_all(user_profile: ProfileInfo) -> bool:
# Allow all users.
return False
@@ -272,7 +669,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
self.assertEqual(len(s["results"]), 1)
# Configure a spam checker that filters all users.
- async def block_all(user_profile):
+ async def block_all(user_profile: ProfileInfo) -> bool:
# All users are spammy.
return True
@@ -282,7 +679,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
s = self.get_success(self.handler.search_users(u1, "user2", 10))
self.assertEqual(len(s["results"]), 0)
- def test_legacy_spam_checker(self):
+ def test_legacy_spam_checker(self) -> None:
"""
A spam checker without the expected method should be ignored.
"""
@@ -300,11 +697,16 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
self.helper.join(room, user=u2, tok=u2_token)
# Check we have populated the database correctly.
- shares_private = self.get_users_who_share_private_rooms()
- public_users = self.get_users_in_public_rooms()
+ shares_private = self.get_success(
+ self.user_dir_helper.get_users_who_share_private_rooms()
+ )
+ public_users = self.get_success(
+ self.user_dir_helper.get_users_in_public_rooms()
+ )
self.assertEqual(
- self._compress_shared(shares_private), {(u1, u2, room), (u2, u1, room)}
+ self.user_dir_helper._compress_shared(shares_private),
+ {(u1, u2, room), (u2, u1, room)},
)
self.assertEqual(public_users, [])
@@ -317,134 +719,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
s = self.get_success(self.handler.search_users(u1, "user2", 10))
self.assertEqual(len(s["results"]), 1)
- def _compress_shared(self, shared):
- """
- Compress a list of users who share rooms dicts to a list of tuples.
- """
- r = set()
- for i in shared:
- r.add((i["user_id"], i["other_user_id"], i["room_id"]))
- return r
-
- def get_users_in_public_rooms(self) -> List[Tuple[str, str]]:
- r = self.get_success(
- self.store.db_pool.simple_select_list(
- "users_in_public_rooms", None, ("user_id", "room_id")
- )
- )
- retval = []
- for i in r:
- retval.append((i["user_id"], i["room_id"]))
- return retval
-
- def get_users_who_share_private_rooms(self) -> List[Tuple[str, str, str]]:
- return self.get_success(
- self.store.db_pool.simple_select_list(
- "users_who_share_private_rooms",
- None,
- ["user_id", "other_user_id", "room_id"],
- )
- )
-
- def _add_background_updates(self):
- """
- Add the background updates we need to run.
- """
- # Ugh, have to reset this flag
- self.store.db_pool.updates._all_done = False
-
- self.get_success(
- self.store.db_pool.simple_insert(
- "background_updates",
- {
- "update_name": "populate_user_directory_createtables",
- "progress_json": "{}",
- },
- )
- )
- self.get_success(
- self.store.db_pool.simple_insert(
- "background_updates",
- {
- "update_name": "populate_user_directory_process_rooms",
- "progress_json": "{}",
- "depends_on": "populate_user_directory_createtables",
- },
- )
- )
- self.get_success(
- self.store.db_pool.simple_insert(
- "background_updates",
- {
- "update_name": "populate_user_directory_process_users",
- "progress_json": "{}",
- "depends_on": "populate_user_directory_process_rooms",
- },
- )
- )
- self.get_success(
- self.store.db_pool.simple_insert(
- "background_updates",
- {
- "update_name": "populate_user_directory_cleanup",
- "progress_json": "{}",
- "depends_on": "populate_user_directory_process_users",
- },
- )
- )
-
- def test_initial(self):
- """
- The user directory's initial handler correctly updates the search tables.
- """
- u1 = self.register_user("user1", "pass")
- u1_token = self.login(u1, "pass")
- u2 = self.register_user("user2", "pass")
- u2_token = self.login(u2, "pass")
- u3 = self.register_user("user3", "pass")
- u3_token = self.login(u3, "pass")
-
- room = self.helper.create_room_as(u1, is_public=True, tok=u1_token)
- self.helper.invite(room, src=u1, targ=u2, tok=u1_token)
- self.helper.join(room, user=u2, tok=u2_token)
-
- private_room = self.helper.create_room_as(u1, is_public=False, tok=u1_token)
- self.helper.invite(private_room, src=u1, targ=u3, tok=u1_token)
- self.helper.join(private_room, user=u3, tok=u3_token)
-
- self.get_success(self.store.update_user_directory_stream_pos(None))
- self.get_success(self.store.delete_all_from_user_dir())
-
- shares_private = self.get_users_who_share_private_rooms()
- public_users = self.get_users_in_public_rooms()
-
- # Nothing updated yet
- self.assertEqual(shares_private, [])
- self.assertEqual(public_users, [])
-
- # Do the initial population of the user directory via the background update
- self._add_background_updates()
-
- while not self.get_success(
- self.store.db_pool.updates.has_completed_background_updates()
- ):
- self.get_success(
- self.store.db_pool.updates.do_next_background_update(100), by=0.1
- )
-
- shares_private = self.get_users_who_share_private_rooms()
- public_users = self.get_users_in_public_rooms()
-
- # User 1 and User 2 are in the same public room
- self.assertEqual(set(public_users), {(u1, room), (u2, room)})
-
- # User 1 and User 3 share private rooms
- self.assertEqual(
- self._compress_shared(shares_private),
- {(u1, u3, private_room), (u3, u1, private_room)},
- )
-
- def test_initial_share_all_users(self):
+ def test_initial_share_all_users(self) -> None:
"""
Search all users = True means that a user does not have to share a
private room with the searching user or be in a public room to be search
@@ -457,26 +732,16 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
self.register_user("user2", "pass")
u3 = self.register_user("user3", "pass")
- # Wipe the user dir
- self.get_success(self.store.update_user_directory_stream_pos(None))
- self.get_success(self.store.delete_all_from_user_dir())
-
- # Do the initial population of the user directory via the background update
- self._add_background_updates()
-
- while not self.get_success(
- self.store.db_pool.updates.has_completed_background_updates()
- ):
- self.get_success(
- self.store.db_pool.updates.do_next_background_update(100), by=0.1
- )
-
- shares_private = self.get_users_who_share_private_rooms()
- public_users = self.get_users_in_public_rooms()
+ shares_private = self.get_success(
+ self.user_dir_helper.get_users_who_share_private_rooms()
+ )
+ public_users = self.get_success(
+ self.user_dir_helper.get_users_in_public_rooms()
+ )
# No users share rooms
self.assertEqual(public_users, [])
- self.assertEqual(self._compress_shared(shares_private), set())
+ self.assertEqual(self.user_dir_helper._compress_shared(shares_private), set())
# Despite not sharing a room, search_all_users means we get a search
# result.
@@ -501,7 +766,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
}
}
)
- def test_prefer_local_users(self):
+ def test_prefer_local_users(self) -> None:
"""Tests that local users are shown higher in search results when
user_directory.prefer_local_users is True.
"""
@@ -535,15 +800,6 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
local_users = [local_user_1, local_user_2, local_user_3]
remote_users = [remote_user_1, remote_user_2, remote_user_3]
- # Populate the user directory via background update
- self._add_background_updates()
- while not self.get_success(
- self.store.db_pool.updates.has_completed_background_updates()
- ):
- self.get_success(
- self.store.db_pool.updates.do_next_background_update(100), by=0.1
- )
-
# The local searching user searches for the term "user", which other users have
# in their user id
results = self.get_success(
@@ -565,7 +821,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
room_id: str,
room_version: RoomVersion,
user_id: str,
- ):
+ ) -> None:
# Add a user to the room.
builder = self.event_builder_factory.for_room_version(
room_version,
@@ -588,8 +844,6 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
class TestUserDirSearchDisabled(unittest.HomeserverTestCase):
- user_id = "@test:test"
-
servlets = [
user_directory.register_servlets,
room.register_servlets,
@@ -597,7 +851,7 @@ class TestUserDirSearchDisabled(unittest.HomeserverTestCase):
synapse.rest.admin.register_servlets_for_client_rest_resource,
]
- def make_homeserver(self, reactor, clock):
+ def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
config = self.default_config()
config["update_user_directory"] = True
hs = self.setup_test_homeserver(config=config)
@@ -606,19 +860,24 @@ class TestUserDirSearchDisabled(unittest.HomeserverTestCase):
return hs
- def test_disabling_room_list(self):
+ def test_disabling_room_list(self) -> None:
self.config.userdirectory.user_directory_search_enabled = True
- # First we create a room with another user so that user dir is non-empty
- # for our user
- self.helper.create_room_as(self.user_id)
+ # Create two users and put them in the same room.
+ u1 = self.register_user("user1", "pass")
+ u1_token = self.login(u1, "pass")
u2 = self.register_user("user2", "pass")
- room = self.helper.create_room_as(self.user_id)
- self.helper.join(room, user=u2)
+ u2_token = self.login(u2, "pass")
+
+ room = self.helper.create_room_as(u1, tok=u1_token)
+ self.helper.join(room, user=u2, tok=u2_token)
- # Assert user directory is not empty
+ # Each should see the other when searching the user directory.
channel = self.make_request(
- "POST", b"user_directory/search", b'{"search_term":"user2"}'
+ "POST",
+ b"user_directory/search",
+ b'{"search_term":"user2"}',
+ access_token=u1_token,
)
self.assertEquals(200, channel.code, channel.result)
self.assertTrue(len(channel.json_body["results"]) > 0)
@@ -626,7 +885,10 @@ class TestUserDirSearchDisabled(unittest.HomeserverTestCase):
# Disable user directory and check search returns nothing
self.config.userdirectory.user_directory_search_enabled = False
channel = self.make_request(
- "POST", b"user_directory/search", b'{"search_term":"user2"}'
+ "POST",
+ b"user_directory/search",
+ b'{"search_term":"user2"}',
+ access_token=u1_token,
)
self.assertEquals(200, channel.code, channel.result)
self.assertTrue(len(channel.json_body["results"]) == 0)
diff --git a/tests/http/test_fedclient.py b/tests/http/test_fedclient.py
index d9a8b077..638babae 100644
--- a/tests/http/test_fedclient.py
+++ b/tests/http/test_fedclient.py
@@ -226,7 +226,7 @@ class FederationClientTests(HomeserverTestCase):
"""Ensure that Synapse does not try to connect to blacklisted IPs"""
# Set up the ip_range blacklist
- self.hs.config.federation_ip_range_blacklist = IPSet(
+ self.hs.config.server.federation_ip_range_blacklist = IPSet(
["127.0.0.0/8", "fe80::/64"]
)
self.reactor.lookups["internal"] = "127.0.0.1"
diff --git a/tests/logging/test_terse_json.py b/tests/logging/test_terse_json.py
index f73fcd68..96f399b7 100644
--- a/tests/logging/test_terse_json.py
+++ b/tests/logging/test_terse_json.py
@@ -198,3 +198,31 @@ class TerseJsonTestCase(LoggerCleanupMixin, TestCase):
self.assertEqual(log["url"], "/_matrix/client/versions")
self.assertEqual(log["protocol"], "1.1")
self.assertEqual(log["user_agent"], "")
+
+ def test_with_exception(self):
+ """
+ The logging exception type & value should be added to the JSON response.
+ """
+ handler = logging.StreamHandler(self.output)
+ handler.setFormatter(JsonFormatter())
+ logger = self.get_logger(handler)
+
+ try:
+ raise ValueError("That's wrong, you wally!")
+ except ValueError:
+ logger.exception("Hello there, %s!", "wally")
+
+ log = self.get_log_line()
+
+ # The terse logger should give us these keys.
+ expected_log_keys = [
+ "log",
+ "level",
+ "namespace",
+ "exc_type",
+ "exc_value",
+ ]
+ self.assertCountEqual(log.keys(), expected_log_keys)
+ self.assertEqual(log["log"], "Hello there, wally!")
+ self.assertEqual(log["exc_type"], "ValueError")
+ self.assertEqual(log["exc_value"], "That's wrong, you wally!")
diff --git a/tests/module_api/test_api.py b/tests/module_api/test_api.py
index 9d38974f..e915dd5c 100644
--- a/tests/module_api/test_api.py
+++ b/tests/module_api/test_api.py
@@ -25,6 +25,7 @@ from synapse.types import create_requester
from tests.events.test_presence_router import send_presence_update, sync_presence
from tests.replication._base import BaseMultiWorkerStreamTestCase
+from tests.test_utils import simple_async_mock
from tests.test_utils.event_injection import inject_member_event
from tests.unittest import HomeserverTestCase, override_config
from tests.utils import USE_POSTGRES_FOR_TESTS
@@ -46,8 +47,12 @@ class ModuleApiTestCase(HomeserverTestCase):
self.auth_handler = homeserver.get_auth_handler()
def make_homeserver(self, reactor, clock):
+ # Mock out the calls over federation.
+ fed_transport_client = Mock(spec=["send_transaction"])
+ fed_transport_client.send_transaction = simple_async_mock({})
+
return self.setup_test_homeserver(
- federation_transport_client=Mock(spec=["send_transaction"]),
+ federation_transport_client=fed_transport_client,
)
def test_can_register_user(self):
diff --git a/tests/replication/_base.py b/tests/replication/_base.py
index c7555c26..eac4664b 100644
--- a/tests/replication/_base.py
+++ b/tests/replication/_base.py
@@ -70,8 +70,16 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
# databases objects are the same.
self.worker_hs.get_datastore().db_pool = hs.get_datastore().db_pool
+ # Normally we'd pass in the handler to `setup_test_homeserver`, which would
+ # eventually hit "Install @cache_in_self attributes" in tests/utils.py.
+ # Unfortunately our handler wants a reference to the homeserver. That leaves
+ # us with a chicken-and-egg problem.
+ # We can workaround this: create the homeserver first, create the handler
+ # and bodge it in after the fact. The bodging requires us to know the
+ # dirty details of how `cache_in_self` works. We politely ask mypy to
+ # ignore our dirty dealings.
self.test_handler = self._build_replication_data_handler()
- self.worker_hs._replication_data_handler = self.test_handler
+ self.worker_hs._replication_data_handler = self.test_handler # type: ignore[attr-defined]
repl_handler = ReplicationCommandHandler(self.worker_hs)
self.client = ClientReplicationStreamProtocol(
@@ -240,7 +248,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
if self.hs.config.redis.redis_enabled:
# Handle attempts to connect to fake redis server.
self.reactor.add_tcp_client_callback(
- b"localhost",
+ "localhost",
6379,
self.connect_any_redis_attempts,
)
@@ -315,12 +323,15 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
)
)
+ # Copy the port into a new, non-Optional variable so mypy knows we're
+ # not going to reset `instance_loc` to `None` under its feet. See
+ # https://mypy.readthedocs.io/en/latest/common_issues.html#narrowing-and-inner-functions
+ port = instance_loc.port
+
self.reactor.add_tcp_client_callback(
self.reactor.lookups[instance_loc.host],
instance_loc.port,
- lambda: self._handle_http_replication_attempt(
- worker_hs, instance_loc.port
- ),
+ lambda: self._handle_http_replication_attempt(worker_hs, port),
)
store = worker_hs.get_datastore()
@@ -424,7 +435,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
clients = self.reactor.tcpClients
while clients:
(host, port, client_factory, _timeout, _bindAddress) = clients.pop(0)
- self.assertEqual(host, b"localhost")
+ self.assertEqual(host, "localhost")
self.assertEqual(port, 6379)
client_protocol = client_factory.buildProtocol(None)
diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py
index ee3ae9cc..6ed9e421 100644
--- a/tests/rest/admin/test_user.py
+++ b/tests/rest/admin/test_user.py
@@ -59,7 +59,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
self.hs = self.setup_test_homeserver()
- self.hs.config.registration_shared_secret = "shared"
+ self.hs.config.registration.registration_shared_secret = "shared"
self.hs.get_media_repository = Mock()
self.hs.get_deactivate_account_handler = Mock()
@@ -71,7 +71,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
If there is no shared secret, registration through this method will be
prevented.
"""
- self.hs.config.registration_shared_secret = None
+ self.hs.config.registration.registration_shared_secret = None
channel = self.make_request("POST", self.url, b"{}")
@@ -422,7 +422,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
# Set monthly active users to the limit
store.get_monthly_active_count = Mock(
- return_value=make_awaitable(self.hs.config.max_mau_value)
+ return_value=make_awaitable(self.hs.config.server.max_mau_value)
)
# Check that the blocking of monthly active users is working as expected
# The registration of a new user fails due to the limit
@@ -1485,7 +1485,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
# Set monthly active users to the limit
self.store.get_monthly_active_count = Mock(
- return_value=make_awaitable(self.hs.config.max_mau_value)
+ return_value=make_awaitable(self.hs.config.server.max_mau_value)
)
# Check that the blocking of monthly active users is working as expected
# The registration of a new user fails due to the limit
@@ -1522,7 +1522,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
# Set monthly active users to the limit
self.store.get_monthly_active_count = Mock(
- return_value=make_awaitable(self.hs.config.max_mau_value)
+ return_value=make_awaitable(self.hs.config.server.max_mau_value)
)
# Check that the blocking of monthly active users is working as expected
# The registration of a new user fails due to the limit
diff --git a/tests/rest/client/test_account.py b/tests/rest/client/test_account.py
index 9e9e953c..89d85b0a 100644
--- a/tests/rest/client/test_account.py
+++ b/tests/rest/client/test_account.py
@@ -470,13 +470,45 @@ class WhoamiTestCase(unittest.HomeserverTestCase):
register.register_servlets,
]
+ def default_config(self):
+ config = super().default_config()
+ config["allow_guest_access"] = True
+ return config
+
def test_GET_whoami(self):
device_id = "wouldgohere"
user_id = self.register_user("kermit", "test")
tok = self.login("kermit", "test", device_id=device_id)
- whoami = self.whoami(tok)
- self.assertEqual(whoami, {"user_id": user_id, "device_id": device_id})
+ whoami = self._whoami(tok)
+ self.assertEqual(
+ whoami,
+ {
+ "user_id": user_id,
+ "device_id": device_id,
+ # Unstable until MSC3069 enters spec
+ "org.matrix.msc3069.is_guest": False,
+ },
+ )
+
+ def test_GET_whoami_guests(self):
+ channel = self.make_request(
+ b"POST", b"/_matrix/client/r0/register?kind=guest", b"{}"
+ )
+ tok = channel.json_body["access_token"]
+ user_id = channel.json_body["user_id"]
+ device_id = channel.json_body["device_id"]
+
+ whoami = self._whoami(tok)
+ self.assertEqual(
+ whoami,
+ {
+ "user_id": user_id,
+ "device_id": device_id,
+ # Unstable until MSC3069 enters spec
+ "org.matrix.msc3069.is_guest": True,
+ },
+ )
def test_GET_whoami_appservices(self):
user_id = "@as:test"
@@ -484,18 +516,25 @@ class WhoamiTestCase(unittest.HomeserverTestCase):
appservice = ApplicationService(
as_token,
- self.hs.config.server_name,
+ self.hs.config.server.server_name,
id="1234",
namespaces={"users": [{"regex": user_id, "exclusive": True}]},
sender=user_id,
)
self.hs.get_datastore().services_cache.append(appservice)
- whoami = self.whoami(as_token)
- self.assertEqual(whoami, {"user_id": user_id})
+ whoami = self._whoami(as_token)
+ self.assertEqual(
+ whoami,
+ {
+ "user_id": user_id,
+ # Unstable until MSC3069 enters spec
+ "org.matrix.msc3069.is_guest": False,
+ },
+ )
self.assertFalse(hasattr(whoami, "device_id"))
- def whoami(self, tok):
+ def _whoami(self, tok):
channel = self.make_request("GET", "account/whoami", {}, access_token=tok)
self.assertEqual(channel.code, 200)
return channel.json_body
@@ -625,7 +664,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
def test_add_email_if_disabled(self):
"""Test adding email to profile when doing so is disallowed"""
- self.hs.config.enable_3pid_changes = False
+ self.hs.config.registration.enable_3pid_changes = False
client_secret = "foobar"
session_id = self._request_token(self.email, client_secret)
@@ -695,7 +734,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
def test_delete_email_if_disabled(self):
"""Test deleting an email from profile when disallowed"""
- self.hs.config.enable_3pid_changes = False
+ self.hs.config.registration.enable_3pid_changes = False
# Add a threepid
self.get_success(
diff --git a/tests/rest/client/test_capabilities.py b/tests/rest/client/test_capabilities.py
index 422361b6..b9e36025 100644
--- a/tests/rest/client/test_capabilities.py
+++ b/tests/rest/client/test_capabilities.py
@@ -55,7 +55,7 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase):
self.assertTrue(room_version in KNOWN_ROOM_VERSIONS, "" + room_version)
self.assertEqual(
- self.config.default_room_version.identifier,
+ self.config.server.default_room_version.identifier,
capabilities["m.room_versions"]["default"],
)
diff --git a/tests/rest/client/test_identity.py b/tests/rest/client/test_identity.py
index ca2e8ff8..becb4e8d 100644
--- a/tests/rest/client/test_identity.py
+++ b/tests/rest/client/test_identity.py
@@ -37,7 +37,7 @@ class IdentityTestCase(unittest.HomeserverTestCase):
return self.hs
def test_3pid_lookup_disabled(self):
- self.hs.config.enable_3pid_lookup = False
+ self.hs.config.registration.enable_3pid_lookup = False
self.register_user("kermit", "monkey")
tok = self.login("kermit", "monkey")
diff --git a/tests/rest/client/test_login.py b/tests/rest/client/test_login.py
index 371615a0..a63f04bd 100644
--- a/tests/rest/client/test_login.py
+++ b/tests/rest/client/test_login.py
@@ -94,9 +94,9 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
def make_homeserver(self, reactor, clock):
self.hs = self.setup_test_homeserver()
- self.hs.config.enable_registration = True
- self.hs.config.registrations_require_3pid = []
- self.hs.config.auto_join_rooms = []
+ self.hs.config.registration.enable_registration = True
+ self.hs.config.registration.registrations_require_3pid = []
+ self.hs.config.registration.auto_join_rooms = []
self.hs.config.captcha.enable_registration_captcha = False
return self.hs
@@ -1064,13 +1064,6 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase):
register.register_servlets,
]
- def register_as_user(self, username):
- self.make_request(
- b"POST",
- "/_matrix/client/r0/register?access_token=%s" % (self.service.token,),
- {"username": username},
- )
-
def make_homeserver(self, reactor, clock):
self.hs = self.setup_test_homeserver()
@@ -1107,7 +1100,7 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase):
def test_login_appservice_user(self):
"""Test that an appservice user can use /login"""
- self.register_as_user(AS_USER)
+ self.register_appservice_user(AS_USER, self.service.token)
params = {
"type": login.LoginRestServlet.APPSERVICE_TYPE,
@@ -1121,7 +1114,7 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase):
def test_login_appservice_user_bot(self):
"""Test that the appservice bot can use /login"""
- self.register_as_user(AS_USER)
+ self.register_appservice_user(AS_USER, self.service.token)
params = {
"type": login.LoginRestServlet.APPSERVICE_TYPE,
@@ -1135,7 +1128,7 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase):
def test_login_appservice_wrong_user(self):
"""Test that non-as users cannot login with the as token"""
- self.register_as_user(AS_USER)
+ self.register_appservice_user(AS_USER, self.service.token)
params = {
"type": login.LoginRestServlet.APPSERVICE_TYPE,
@@ -1149,7 +1142,7 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase):
def test_login_appservice_wrong_as(self):
"""Test that as users cannot login with wrong as token"""
- self.register_as_user(AS_USER)
+ self.register_appservice_user(AS_USER, self.service.token)
params = {
"type": login.LoginRestServlet.APPSERVICE_TYPE,
@@ -1165,7 +1158,7 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase):
"""Test that users must provide a token when using the appservice
login method
"""
- self.register_as_user(AS_USER)
+ self.register_appservice_user(AS_USER, self.service.token)
params = {
"type": login.LoginRestServlet.APPSERVICE_TYPE,
diff --git a/tests/rest/client/test_presence.py b/tests/rest/client/test_presence.py
index 1d152352..56fe1a3d 100644
--- a/tests/rest/client/test_presence.py
+++ b/tests/rest/client/test_presence.py
@@ -50,7 +50,7 @@ class PresenceTestCase(unittest.HomeserverTestCase):
PUT to the status endpoint with use_presence enabled will call
set_state on the presence handler.
"""
- self.hs.config.use_presence = True
+ self.hs.config.server.use_presence = True
body = {"presence": "here", "status_msg": "beep boop"}
channel = self.make_request(
diff --git a/tests/rest/client/test_register.py b/tests/rest/client/test_register.py
index 72a5a11b..66dcfc9f 100644
--- a/tests/rest/client/test_register.py
+++ b/tests/rest/client/test_register.py
@@ -50,7 +50,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
appservice = ApplicationService(
as_token,
- self.hs.config.server_name,
+ self.hs.config.server.server_name,
id="1234",
namespaces={"users": [{"regex": r"@as_user.*", "exclusive": True}]},
sender="@as:test",
@@ -74,7 +74,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
appservice = ApplicationService(
as_token,
- self.hs.config.server_name,
+ self.hs.config.server.server_name,
id="1234",
namespaces={"users": [{"regex": r"@as_user.*", "exclusive": True}]},
sender="@as:test",
@@ -147,7 +147,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
def test_POST_guest_registration(self):
self.hs.config.key.macaroon_secret_key = "test"
- self.hs.config.allow_guest_access = True
+ self.hs.config.registration.allow_guest_access = True
channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}")
@@ -156,7 +156,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
self.assertDictContainsSubset(det_data, channel.json_body)
def test_POST_disabled_guest_registration(self):
- self.hs.config.allow_guest_access = False
+ self.hs.config.registration.allow_guest_access = False
channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}")
diff --git a/tests/rest/client/test_rooms.py b/tests/rest/client/test_rooms.py
index 30bdaa9c..376853fd 100644
--- a/tests/rest/client/test_rooms.py
+++ b/tests/rest/client/test_rooms.py
@@ -784,6 +784,30 @@ class RoomsCreateTestCase(RoomBase):
# Check that do_3pid_invite wasn't called this time.
self.assertEquals(do_3pid_invite_mock.call_count, len(invited_3pids))
+ def test_spam_checker_may_join_room(self):
+ """Tests that the user_may_join_room spam checker callback is correctly bypassed
+ when creating a new room.
+ """
+
+ async def user_may_join_room(
+ mxid: str,
+ room_id: str,
+ is_invite: bool,
+ ) -> bool:
+ return False
+
+ join_mock = Mock(side_effect=user_may_join_room)
+ self.hs.get_spam_checker()._user_may_join_room_callbacks.append(join_mock)
+
+ channel = self.make_request(
+ "POST",
+ "/createRoom",
+ {},
+ )
+ self.assertEquals(channel.code, 200, channel.json_body)
+
+ self.assertEquals(join_mock.call_count, 0)
+
class RoomTopicTestCase(RoomBase):
"""Tests /rooms/$room_id/topic REST events."""
@@ -975,6 +999,83 @@ class RoomInviteRatelimitTestCase(RoomBase):
self.helper.invite(room_id, self.user_id, "@other-users:red", expect_code=429)
+class RoomJoinTestCase(RoomBase):
+
+ servlets = [
+ admin.register_servlets,
+ login.register_servlets,
+ room.register_servlets,
+ ]
+
+ def prepare(self, reactor, clock, homeserver):
+ self.user1 = self.register_user("thomas", "hackme")
+ self.tok1 = self.login("thomas", "hackme")
+
+ self.user2 = self.register_user("teresa", "hackme")
+ self.tok2 = self.login("teresa", "hackme")
+
+ self.room1 = self.helper.create_room_as(room_creator=self.user1, tok=self.tok1)
+ self.room2 = self.helper.create_room_as(room_creator=self.user1, tok=self.tok1)
+ self.room3 = self.helper.create_room_as(room_creator=self.user1, tok=self.tok1)
+
+ def test_spam_checker_may_join_room(self):
+ """Tests that the user_may_join_room spam checker callback is correctly called
+ and blocks room joins when needed.
+ """
+
+ # Register a dummy callback. Make it allow all room joins for now.
+ return_value = True
+
+ async def user_may_join_room(
+ userid: str,
+ room_id: str,
+ is_invited: bool,
+ ) -> bool:
+ return return_value
+
+ callback_mock = Mock(side_effect=user_may_join_room)
+ self.hs.get_spam_checker()._user_may_join_room_callbacks.append(callback_mock)
+
+ # Join a first room, without being invited to it.
+ self.helper.join(self.room1, self.user2, tok=self.tok2)
+
+ # Check that the callback was called with the right arguments.
+ expected_call_args = (
+ (
+ self.user2,
+ self.room1,
+ False,
+ ),
+ )
+ self.assertEquals(
+ callback_mock.call_args,
+ expected_call_args,
+ callback_mock.call_args,
+ )
+
+ # Join a second room, this time with an invite for it.
+ self.helper.invite(self.room2, self.user1, self.user2, tok=self.tok1)
+ self.helper.join(self.room2, self.user2, tok=self.tok2)
+
+ # Check that the callback was called with the right arguments.
+ expected_call_args = (
+ (
+ self.user2,
+ self.room2,
+ True,
+ ),
+ )
+ self.assertEquals(
+ callback_mock.call_args,
+ expected_call_args,
+ callback_mock.call_args,
+ )
+
+ # Now make the callback deny all room joins, and check that a join actually fails.
+ return_value = False
+ self.helper.join(self.room3, self.user2, expect_code=403, tok=self.tok2)
+
+
class RoomJoinRatelimitTestCase(RoomBase):
user_id = "@sid1:red"
@@ -2430,3 +2531,73 @@ class RoomCanonicalAliasTestCase(unittest.HomeserverTestCase):
"""An alias which does not point to the room raises a SynapseError."""
self._set_canonical_alias({"alias": "@unknown:test"}, expected_code=400)
self._set_canonical_alias({"alt_aliases": ["@unknown:test"]}, expected_code=400)
+
+
+class ThreepidInviteTestCase(unittest.HomeserverTestCase):
+
+ servlets = [
+ admin.register_servlets,
+ login.register_servlets,
+ room.register_servlets,
+ ]
+
+ def prepare(self, reactor, clock, homeserver):
+ self.user_id = self.register_user("thomas", "hackme")
+ self.tok = self.login("thomas", "hackme")
+
+ self.room_id = self.helper.create_room_as(self.user_id, tok=self.tok)
+
+ def test_threepid_invite_spamcheck(self):
+ # Mock a few functions to prevent the test from failing due to failing to talk to
+ # a remote IS. We keep the mock for _mock_make_and_store_3pid_invite around so we
+ # can check its call_count later on during the test.
+ make_invite_mock = Mock(return_value=make_awaitable(0))
+ self.hs.get_room_member_handler()._make_and_store_3pid_invite = make_invite_mock
+ self.hs.get_identity_handler().lookup_3pid = Mock(
+ return_value=make_awaitable(None),
+ )
+
+ # Add a mock to the spamchecker callbacks for user_may_send_3pid_invite. Make it
+ # allow everything for now.
+ mock = Mock(return_value=make_awaitable(True))
+ self.hs.get_spam_checker()._user_may_send_3pid_invite_callbacks.append(mock)
+
+ # Send a 3PID invite into the room and check that it succeeded.
+ email_to_invite = "teresa@example.com"
+ channel = self.make_request(
+ method="POST",
+ path="/rooms/" + self.room_id + "/invite",
+ content={
+ "id_server": "example.com",
+ "id_access_token": "sometoken",
+ "medium": "email",
+ "address": email_to_invite,
+ },
+ access_token=self.tok,
+ )
+ self.assertEquals(channel.code, 200)
+
+ # Check that the callback was called with the right params.
+ mock.assert_called_with(self.user_id, "email", email_to_invite, self.room_id)
+
+ # Check that the call to send the invite was made.
+ make_invite_mock.assert_called_once()
+
+ # Now change the return value of the callback to deny any invite and test that
+ # we can't send the invite.
+ mock.return_value = make_awaitable(False)
+ channel = self.make_request(
+ method="POST",
+ path="/rooms/" + self.room_id + "/invite",
+ content={
+ "id_server": "example.com",
+ "id_access_token": "sometoken",
+ "medium": "email",
+ "address": email_to_invite,
+ },
+ access_token=self.tok,
+ )
+ self.assertEquals(channel.code, 403)
+
+ # Also check that it stopped before calling _make_and_store_3pid_invite.
+ make_invite_mock.assert_called_once()
diff --git a/tests/rest/client/utils.py b/tests/rest/client/utils.py
index 3075d3f2..71fa87ce 100644
--- a/tests/rest/client/utils.py
+++ b/tests/rest/client/utils.py
@@ -48,7 +48,7 @@ class RestHelper:
def create_room_as(
self,
room_creator: Optional[str] = None,
- is_public: bool = True,
+ is_public: Optional[bool] = None,
room_version: Optional[str] = None,
tok: Optional[str] = None,
expect_code: int = 200,
@@ -62,9 +62,10 @@ class RestHelper:
Args:
room_creator: The user ID to create the room with.
- is_public: If True, the `visibility` parameter will be set to the
- default (public). Otherwise, the `visibility` parameter will be set
- to "private".
+ is_public: If True, the `visibility` parameter will be set to
+ "public". If False, it will be set to "private". If left
+ unspecified, the server will set it to an appropriate default
+ (which should be "private" as per the CS spec).
room_version: The room version to create the room as. Defaults to Synapse's
default room version.
tok: The access token to use in the request.
@@ -77,8 +78,8 @@ class RestHelper:
self.auth_user_id = room_creator
path = "/_matrix/client/r0/createRoom"
content = extra_content or {}
- if not is_public:
- content["visibility"] = "private"
+ if is_public is not None:
+ content["visibility"] = "public" if is_public else "private"
if room_version:
content["room_version"] = room_version
if tok:
diff --git a/tests/rest/media/v1/test_url_preview.py b/tests/rest/media/v1/test_url_preview.py
index 4d09b5d0..8698135a 100644
--- a/tests/rest/media/v1/test_url_preview.py
+++ b/tests/rest/media/v1/test_url_preview.py
@@ -21,11 +21,13 @@ from twisted.internet.error import DNSLookupError
from twisted.test.proto_helpers import AccumulatingProtocol
from synapse.config.oembed import OEmbedEndpointConfig
+from synapse.rest.media.v1.preview_url_resource import IMAGE_CACHE_EXPIRY_MS
from synapse.util.stringutils import parse_and_validate_mxc_uri
from tests import unittest
from tests.server import FakeTransport
from tests.test_utils import SMALL_PNG
+from tests.utils import MockClock
try:
import lxml
@@ -723,9 +725,107 @@ class URLPreviewTests(unittest.HomeserverTestCase):
},
)
+ def test_oembed_autodiscovery(self):
+ """
+ Autodiscovery works by finding the link in the HTML response and then requesting an oEmbed URL.
+ 1. Request a preview of a URL which is not known to the oEmbed code.
+ 2. It returns HTML including a link to an oEmbed preview.
+ 3. The oEmbed preview is requested and returns a URL for an image.
+ 4. The image is requested for thumbnailing.
+ """
+ # This is a little cheesy in that we use the www subdomain (which isn't the
+ # list of oEmbed patterns) to get "raw" HTML response.
+ self.lookups["www.twitter.com"] = [(IPv4Address, "10.1.2.3")]
+ self.lookups["publish.twitter.com"] = [(IPv4Address, "10.1.2.3")]
+ self.lookups["cdn.twitter.com"] = [(IPv4Address, "10.1.2.3")]
+
+ result = b"""
+ <link rel="alternate" type="application/json+oembed"
+ href="http://publish.twitter.com/oembed?url=http%3A%2F%2Fcdn.twitter.com%2Fmatrixdotorg%2Fstatus%2F12345&format=json"
+ title="matrixdotorg" />
+ """
+
+ channel = self.make_request(
+ "GET",
+ "preview_url?url=http://www.twitter.com/matrixdotorg/status/12345",
+ shorthand=False,
+ await_result=False,
+ )
+ self.pump()
+
+ client = self.reactor.tcpClients[0][2].buildProtocol(None)
+ server = AccumulatingProtocol()
+ server.makeConnection(FakeTransport(client, self.reactor))
+ client.makeConnection(FakeTransport(server, self.reactor))
+ client.dataReceived(
+ (
+ b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\n"
+ b'Content-Type: text/html; charset="utf8"\r\n\r\n'
+ )
+ % (len(result),)
+ + result
+ )
+
+ self.pump()
+
+ # The oEmbed response.
+ result2 = {
+ "version": "1.0",
+ "type": "photo",
+ "url": "http://cdn.twitter.com/matrixdotorg",
+ }
+ oembed_content = json.dumps(result2).encode("utf-8")
+
+ # Ensure a second request is made to the oEmbed URL.
+ client = self.reactor.tcpClients[1][2].buildProtocol(None)
+ server = AccumulatingProtocol()
+ server.makeConnection(FakeTransport(client, self.reactor))
+ client.makeConnection(FakeTransport(server, self.reactor))
+ client.dataReceived(
+ (
+ b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\n"
+ b'Content-Type: application/json; charset="utf8"\r\n\r\n'
+ )
+ % (len(oembed_content),)
+ + oembed_content
+ )
+
+ self.pump()
+
+ # Ensure the URL is what was requested.
+ self.assertIn(b"/oembed?", server.data)
+
+ # Ensure a third request is made to the photo URL.
+ client = self.reactor.tcpClients[2][2].buildProtocol(None)
+ server = AccumulatingProtocol()
+ server.makeConnection(FakeTransport(client, self.reactor))
+ client.makeConnection(FakeTransport(server, self.reactor))
+ client.dataReceived(
+ (
+ b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\n"
+ b"Content-Type: image/png\r\n\r\n"
+ )
+ % (len(SMALL_PNG),)
+ + SMALL_PNG
+ )
+
+ self.pump()
+
+ # Ensure the URL is what was requested.
+ self.assertIn(b"/matrixdotorg", server.data)
+
+ self.assertEqual(channel.code, 200)
+ body = channel.json_body
+ self.assertEqual(
+ body["og:url"], "http://www.twitter.com/matrixdotorg/status/12345"
+ )
+ self.assertTrue(body["og:image"].startswith("mxc://"))
+ self.assertEqual(body["og:image:height"], 1)
+ self.assertEqual(body["og:image:width"], 1)
+ self.assertEqual(body["og:image:type"], "image/png")
+
def _download_image(self):
"""Downloads an image into the URL cache.
-
Returns:
A (host, media_id) tuple representing the MXC URI of the image.
"""
@@ -851,3 +951,32 @@ class URLPreviewTests(unittest.HomeserverTestCase):
404,
"URL cache thumbnail was unexpectedly retrieved from a storage provider",
)
+
+ def test_cache_expiry(self):
+ """Test that URL cache files and thumbnails are cleaned up properly on expiry."""
+ self.preview_url.clock = MockClock()
+
+ _host, media_id = self._download_image()
+
+ file_path = self.preview_url.filepaths.url_cache_filepath(media_id)
+ file_dirs = self.preview_url.filepaths.url_cache_filepath_dirs_to_delete(
+ media_id
+ )
+ thumbnail_dir = self.preview_url.filepaths.url_cache_thumbnail_directory(
+ media_id
+ )
+ thumbnail_dirs = self.preview_url.filepaths.url_cache_thumbnail_dirs_to_delete(
+ media_id
+ )
+
+ self.assertTrue(os.path.isfile(file_path))
+ self.assertTrue(os.path.isdir(thumbnail_dir))
+
+ self.preview_url.clock.advance_time_msec(IMAGE_CACHE_EXPIRY_MS + 1)
+ self.get_success(self.preview_url._expire_url_cache_data())
+
+ for path in [file_path] + file_dirs + [thumbnail_dir] + thumbnail_dirs:
+ self.assertFalse(
+ os.path.exists(path),
+ f"{os.path.relpath(path, self.media_store_path)} was not deleted",
+ )
diff --git a/tests/server.py b/tests/server.py
index 88dfa805..64645651 100644
--- a/tests/server.py
+++ b/tests/server.py
@@ -317,7 +317,7 @@ class ThreadedMemoryReactorClock(MemoryReactorClock):
def __init__(self):
self.threadpool = ThreadPool(self)
- self._tcp_callbacks = {}
+ self._tcp_callbacks: Dict[Tuple[str, int], Callable] = {}
self._udp = []
self.lookups: Dict[str, str] = {}
self._thread_callbacks: Deque[Callable[[], None]] = deque()
@@ -355,7 +355,7 @@ class ThreadedMemoryReactorClock(MemoryReactorClock):
def getThreadPool(self):
return self.threadpool
- def add_tcp_client_callback(self, host, port, callback):
+ def add_tcp_client_callback(self, host: str, port: int, callback: Callable):
"""Add a callback that will be invoked when we receive a connection
attempt to the given IP/port using `connectTCP`.
@@ -364,7 +364,7 @@ class ThreadedMemoryReactorClock(MemoryReactorClock):
"""
self._tcp_callbacks[(host, port)] = callback
- def connectTCP(self, host, port, factory, timeout=30, bindAddress=None):
+ def connectTCP(self, host: str, port: int, factory, timeout=30, bindAddress=None):
"""Fake L{IReactorTCP.connectTCP}."""
conn = super().connectTCP(
@@ -475,7 +475,7 @@ def setup_test_homeserver(cleanup_func, *args, **kwargs):
return server
-def get_clock():
+def get_clock() -> Tuple[ThreadedMemoryReactorClock, Clock]:
clock = ThreadedMemoryReactorClock()
hs_clock = Clock(clock)
return clock, hs_clock
diff --git a/tests/server_notices/test_resource_limits_server_notices.py b/tests/server_notices/test_resource_limits_server_notices.py
index 7f25200a..36c49595 100644
--- a/tests/server_notices/test_resource_limits_server_notices.py
+++ b/tests/server_notices/test_resource_limits_server_notices.py
@@ -346,7 +346,7 @@ class TestResourceLimitsServerNoticesWithRealRooms(unittest.HomeserverTestCase):
invites = []
# Register as many users as the MAU limit allows.
- for i in range(self.hs.config.max_mau_value):
+ for i in range(self.hs.config.server.max_mau_value):
localpart = "user%d" % i
user_id = self.register_user(localpart, "password")
tok = self.login(localpart, "password")
diff --git a/tests/storage/databases/main/test_room.py b/tests/storage/databases/main/test_room.py
index ffee7071..7496974d 100644
--- a/tests/storage/databases/main/test_room.py
+++ b/tests/storage/databases/main/test_room.py
@@ -79,12 +79,7 @@ class RoomBackgroundUpdateStoreTestCase(HomeserverTestCase):
self.store.db_pool.updates._all_done = False
# Now let's actually drive the updates to completion
- while not self.get_success(
- self.store.db_pool.updates.has_completed_background_updates()
- ):
- self.get_success(
- self.store.db_pool.updates.do_next_background_update(100), by=0.1
- )
+ self.wait_for_background_updates()
# Make sure the background update filled in the room creator
room_creator_after = self.get_success(
diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py
index cf9748f2..f26d5acf 100644
--- a/tests/storage/test_appservice.py
+++ b/tests/storage/test_appservice.py
@@ -126,7 +126,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
self.db_pool = database._db_pool
self.engine = database.engine
- db_config = hs.config.get_single_database()
+ db_config = hs.config.database.get_single_database()
self.store = TestTransactionStore(
database, make_conn(db_config, self.engine, "test"), hs
)
diff --git a/tests/storage/test_cleanup_extrems.py b/tests/storage/test_cleanup_extrems.py
index 7cc5e621..a59c28f8 100644
--- a/tests/storage/test_cleanup_extrems.py
+++ b/tests/storage/test_cleanup_extrems.py
@@ -66,12 +66,7 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase):
# Ugh, have to reset this flag
self.store.db_pool.updates._all_done = False
- while not self.get_success(
- self.store.db_pool.updates.has_completed_background_updates()
- ):
- self.get_success(
- self.store.db_pool.updates.do_next_background_update(100), by=0.1
- )
+ self.wait_for_background_updates()
def test_soft_failed_extremities_handled_correctly(self):
"""Test that extremities are correctly calculated in the presence of
diff --git a/tests/storage/test_client_ips.py b/tests/storage/test_client_ips.py
index 3cc8038f..0e4013eb 100644
--- a/tests/storage/test_client_ips.py
+++ b/tests/storage/test_client_ips.py
@@ -147,6 +147,49 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
)
@parameterized.expand([(False,), (True,)])
+ def test_get_last_client_ip_by_device(self, after_persisting: bool):
+ """Test `get_last_client_ip_by_device` for persisted and unpersisted data"""
+ self.reactor.advance(12345678)
+
+ user_id = "@user:id"
+ device_id = "MY_DEVICE"
+
+ # Insert a user IP
+ self.get_success(
+ self.store.store_device(
+ user_id,
+ device_id,
+ "display name",
+ )
+ )
+ self.get_success(
+ self.store.insert_client_ip(
+ user_id, "access_token", "ip", "user_agent", device_id
+ )
+ )
+
+ if after_persisting:
+ # Trigger the storage loop
+ self.reactor.advance(10)
+
+ result = self.get_success(
+ self.store.get_last_client_ip_by_device(user_id, device_id)
+ )
+
+ self.assertEqual(
+ result,
+ {
+ (user_id, device_id): {
+ "user_id": user_id,
+ "device_id": device_id,
+ "ip": "ip",
+ "user_agent": "user_agent",
+ "last_seen": 12345678000,
+ },
+ },
+ )
+
+ @parameterized.expand([(False,), (True,)])
def test_get_user_ip_and_agents(self, after_persisting: bool):
"""Test `get_user_ip_and_agents` for persisted and unpersisted data"""
self.reactor.advance(12345678)
@@ -242,12 +285,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
def test_devices_last_seen_bg_update(self):
# First make sure we have completed all updates.
- while not self.get_success(
- self.store.db_pool.updates.has_completed_background_updates()
- ):
- self.get_success(
- self.store.db_pool.updates.do_next_background_update(100), by=0.1
- )
+ self.wait_for_background_updates()
user_id = "@user:id"
device_id = "MY_DEVICE"
@@ -311,12 +349,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
self.store.db_pool.updates._all_done = False
# Now let's actually drive the updates to completion
- while not self.get_success(
- self.store.db_pool.updates.has_completed_background_updates()
- ):
- self.get_success(
- self.store.db_pool.updates.do_next_background_update(100), by=0.1
- )
+ self.wait_for_background_updates()
# We should now get the correct result again
result = self.get_success(
@@ -337,12 +370,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
def test_old_user_ips_pruned(self):
# First make sure we have completed all updates.
- while not self.get_success(
- self.store.db_pool.updates.has_completed_background_updates()
- ):
- self.get_success(
- self.store.db_pool.updates.do_next_background_update(100), by=0.1
- )
+ self.wait_for_background_updates()
user_id = "@user:id"
device_id = "MY_DEVICE"
diff --git a/tests/storage/test_event_chain.py b/tests/storage/test_event_chain.py
index 93136f07..b31c5eb5 100644
--- a/tests/storage/test_event_chain.py
+++ b/tests/storage/test_event_chain.py
@@ -578,12 +578,7 @@ class EventChainBackgroundUpdateTestCase(HomeserverTestCase):
# Ugh, have to reset this flag
self.store.db_pool.updates._all_done = False
- while not self.get_success(
- self.store.db_pool.updates.has_completed_background_updates()
- ):
- self.get_success(
- self.store.db_pool.updates.do_next_background_update(100), by=0.1
- )
+ self.wait_for_background_updates()
# Test that the `has_auth_chain_index` has been set
self.assertTrue(self.get_success(self.store.has_auth_chain_index(room_id)))
@@ -619,12 +614,7 @@ class EventChainBackgroundUpdateTestCase(HomeserverTestCase):
# Ugh, have to reset this flag
self.store.db_pool.updates._all_done = False
- while not self.get_success(
- self.store.db_pool.updates.has_completed_background_updates()
- ):
- self.get_success(
- self.store.db_pool.updates.do_next_background_update(100), by=0.1
- )
+ self.wait_for_background_updates()
# Test that the `has_auth_chain_index` has been set
self.assertTrue(self.get_success(self.store.has_auth_chain_index(room_id1)))
diff --git a/tests/storage/test_monthly_active_users.py b/tests/storage/test_monthly_active_users.py
index 944dbc34..d6b4cdd7 100644
--- a/tests/storage/test_monthly_active_users.py
+++ b/tests/storage/test_monthly_active_users.py
@@ -51,7 +51,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
@override_config({"max_mau_value": 3, "mau_limit_reserved_threepids": gen_3pids(3)})
def test_initialise_reserved_users(self):
- threepids = self.hs.config.mau_limits_reserved_threepids
+ threepids = self.hs.config.server.mau_limits_reserved_threepids
# register three users, of which two have reserved 3pids, and a third
# which is a support user.
@@ -101,9 +101,9 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
# XXX some of this is redundant. poking things into the config shouldn't
# work, and in any case it's not obvious what we expect to happen when
# we advance the reactor.
- self.hs.config.max_mau_value = 0
+ self.hs.config.server.max_mau_value = 0
self.reactor.advance(FORTY_DAYS)
- self.hs.config.max_mau_value = 5
+ self.hs.config.server.max_mau_value = 5
self.get_success(self.store.reap_monthly_active_users())
@@ -183,7 +183,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
self.get_success(d)
count = self.get_success(self.store.get_monthly_active_count())
- self.assertEqual(count, self.hs.config.max_mau_value)
+ self.assertEqual(count, self.hs.config.server.max_mau_value)
self.reactor.advance(FORTY_DAYS)
@@ -199,7 +199,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
def test_reap_monthly_active_users_reserved_users(self):
"""Tests that reaping correctly handles reaping where reserved users are
present"""
- threepids = self.hs.config.mau_limits_reserved_threepids
+ threepids = self.hs.config.server.mau_limits_reserved_threepids
initial_users = len(threepids)
reserved_user_number = initial_users - 1
for i in range(initial_users):
@@ -234,7 +234,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
self.get_success(d)
count = self.get_success(self.store.get_monthly_active_count())
- self.assertEqual(count, self.hs.config.max_mau_value)
+ self.assertEqual(count, self.hs.config.server.max_mau_value)
def test_populate_monthly_users_is_guest(self):
# Test that guest users are not added to mau list
@@ -294,7 +294,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
{"medium": "email", "address": user2_email},
]
- self.hs.config.mau_limits_reserved_threepids = threepids
+ self.hs.config.server.mau_limits_reserved_threepids = threepids
d = self.store.db_pool.runInteraction(
"initialise", self.store._initialise_reserved_users, threepids
)
diff --git a/tests/storage/test_roommember.py b/tests/storage/test_roommember.py
index c72dc405..2873e22c 100644
--- a/tests/storage/test_roommember.py
+++ b/tests/storage/test_roommember.py
@@ -169,12 +169,7 @@ class CurrentStateMembershipUpdateTestCase(unittest.HomeserverTestCase):
def test_can_rerun_update(self):
# First make sure we have completed all updates.
- while not self.get_success(
- self.store.db_pool.updates.has_completed_background_updates()
- ):
- self.get_success(
- self.store.db_pool.updates.do_next_background_update(100), by=0.1
- )
+ self.wait_for_background_updates()
# Now let's create a room, which will insert a membership
user = UserID("alice", "test")
@@ -197,9 +192,4 @@ class CurrentStateMembershipUpdateTestCase(unittest.HomeserverTestCase):
self.store.db_pool.updates._all_done = False
# Now let's actually drive the updates to completion
- while not self.get_success(
- self.store.db_pool.updates.has_completed_background_updates()
- ):
- self.get_success(
- self.store.db_pool.updates.do_next_background_update(100), by=0.1
- )
+ self.wait_for_background_updates()
diff --git a/tests/storage/test_state.py b/tests/storage/test_state.py
index 32060f2a..70d52b08 100644
--- a/tests/storage/test_state.py
+++ b/tests/storage/test_state.py
@@ -21,7 +21,7 @@ from synapse.api.room_versions import RoomVersions
from synapse.storage.state import StateFilter
from synapse.types import RoomID, UserID
-from tests.unittest import HomeserverTestCase
+from tests.unittest import HomeserverTestCase, TestCase
logger = logging.getLogger(__name__)
@@ -105,7 +105,6 @@ class StateStoreTestCase(HomeserverTestCase):
self.assertEqual({ev.event_id for ev in state_list}, {e1.event_id, e2.event_id})
def test_get_state_for_event(self):
-
# this defaults to a linear DAG as each new injection defaults to whatever
# forward extremities are currently in the DB for this room.
e1 = self.inject_state_event(self.room, self.u_alice, EventTypes.Create, "", {})
@@ -483,3 +482,513 @@ class StateStoreTestCase(HomeserverTestCase):
self.assertEqual(is_all, True)
self.assertDictEqual({(e5.type, e5.state_key): e5.event_id}, state_dict)
+
+
+class StateFilterDifferenceTestCase(TestCase):
+ def assert_difference(
+ self, minuend: StateFilter, subtrahend: StateFilter, expected: StateFilter
+ ):
+ self.assertEqual(
+ minuend.approx_difference(subtrahend),
+ expected,
+ f"StateFilter difference not correct:\n\n\t{minuend!r}\nminus\n\t{subtrahend!r}\nwas\n\t{minuend.approx_difference(subtrahend)}\nexpected\n\t{expected}",
+ )
+
+ def test_state_filter_difference_no_include_other_minus_no_include_other(self):
+ """
+ Tests the StateFilter.approx_difference method
+ where, in a.approx_difference(b), both a and b do not have the
+ include_others flag set.
+ """
+ # (wildcard on state keys) - (wildcard on state keys):
+ self.assert_difference(
+ StateFilter.freeze(
+ {EventTypes.Member: None, EventTypes.Create: None},
+ include_others=False,
+ ),
+ StateFilter.freeze(
+ {EventTypes.Member: None, EventTypes.CanonicalAlias: None},
+ include_others=False,
+ ),
+ StateFilter.freeze({EventTypes.Create: None}, include_others=False),
+ )
+
+ # (wildcard on state keys) - (specific state keys)
+ # This one is an over-approximation because we can't represent
+ # 'all state keys except a few named examples'
+ self.assert_difference(
+ StateFilter.freeze({EventTypes.Member: None}, include_others=False),
+ StateFilter.freeze(
+ {EventTypes.Member: {"@wombat:spqr"}},
+ include_others=False,
+ ),
+ StateFilter.freeze({EventTypes.Member: None}, include_others=False),
+ )
+
+ # (wildcard on state keys) - (no state keys)
+ self.assert_difference(
+ StateFilter.freeze(
+ {EventTypes.Member: None},
+ include_others=False,
+ ),
+ StateFilter.freeze(
+ {
+ EventTypes.Member: set(),
+ },
+ include_others=False,
+ ),
+ StateFilter.freeze(
+ {EventTypes.Member: None},
+ include_others=False,
+ ),
+ )
+
+ # (specific state keys) - (wildcard on state keys):
+ self.assert_difference(
+ StateFilter.freeze(
+ {
+ EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
+ EventTypes.CanonicalAlias: {""},
+ },
+ include_others=False,
+ ),
+ StateFilter.freeze(
+ {EventTypes.Member: None},
+ include_others=False,
+ ),
+ StateFilter.freeze(
+ {EventTypes.CanonicalAlias: {""}},
+ include_others=False,
+ ),
+ )
+
+ # (specific state keys) - (specific state keys)
+ self.assert_difference(
+ StateFilter.freeze(
+ {
+ EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
+ EventTypes.CanonicalAlias: {""},
+ },
+ include_others=False,
+ ),
+ StateFilter.freeze(
+ {
+ EventTypes.Member: {"@wombat:spqr"},
+ },
+ include_others=False,
+ ),
+ StateFilter.freeze(
+ {
+ EventTypes.Member: {"@spqr:spqr"},
+ EventTypes.CanonicalAlias: {""},
+ },
+ include_others=False,
+ ),
+ )
+
+ # (specific state keys) - (no state keys)
+ self.assert_difference(
+ StateFilter.freeze(
+ {
+ EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
+ EventTypes.CanonicalAlias: {""},
+ },
+ include_others=False,
+ ),
+ StateFilter.freeze(
+ {
+ EventTypes.Member: set(),
+ },
+ include_others=False,
+ ),
+ StateFilter.freeze(
+ {
+ EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
+ EventTypes.CanonicalAlias: {""},
+ },
+ include_others=False,
+ ),
+ )
+
+ def test_state_filter_difference_include_other_minus_no_include_other(self):
+ """
+ Tests the StateFilter.approx_difference method
+ where, in a.approx_difference(b), only a has the include_others flag set.
+ """
+ # (wildcard on state keys) - (wildcard on state keys):
+ self.assert_difference(
+ StateFilter.freeze(
+ {EventTypes.Member: None, EventTypes.Create: None},
+ include_others=True,
+ ),
+ StateFilter.freeze(
+ {EventTypes.Member: None, EventTypes.CanonicalAlias: None},
+ include_others=False,
+ ),
+ StateFilter.freeze(
+ {
+ EventTypes.Create: None,
+ EventTypes.Member: set(),
+ EventTypes.CanonicalAlias: set(),
+ },
+ include_others=True,
+ ),
+ )
+
+ # (wildcard on state keys) - (specific state keys)
+ # This one is an over-approximation because we can't represent
+ # 'all state keys except a few named examples'
+ # This also shows that the resultant state filter is normalised.
+ self.assert_difference(
+ StateFilter.freeze({EventTypes.Member: None}, include_others=True),
+ StateFilter.freeze(
+ {
+ EventTypes.Member: {"@wombat:spqr"},
+ EventTypes.Create: {""},
+ },
+ include_others=False,
+ ),
+ StateFilter(types=frozendict(), include_others=True),
+ )
+
+ # (wildcard on state keys) - (no state keys)
+ self.assert_difference(
+ StateFilter.freeze(
+ {EventTypes.Member: None},
+ include_others=True,
+ ),
+ StateFilter.freeze(
+ {
+ EventTypes.Member: set(),
+ },
+ include_others=False,
+ ),
+ StateFilter(
+ types=frozendict(),
+ include_others=True,
+ ),
+ )
+
+ # (specific state keys) - (wildcard on state keys):
+ self.assert_difference(
+ StateFilter.freeze(
+ {
+ EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
+ EventTypes.CanonicalAlias: {""},
+ },
+ include_others=True,
+ ),
+ StateFilter.freeze(
+ {EventTypes.Member: None},
+ include_others=False,
+ ),
+ StateFilter.freeze(
+ {
+ EventTypes.CanonicalAlias: {""},
+ EventTypes.Member: set(),
+ },
+ include_others=True,
+ ),
+ )
+
+ # (specific state keys) - (specific state keys)
+ self.assert_difference(
+ StateFilter.freeze(
+ {
+ EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
+ EventTypes.CanonicalAlias: {""},
+ },
+ include_others=True,
+ ),
+ StateFilter.freeze(
+ {
+ EventTypes.Member: {"@wombat:spqr"},
+ },
+ include_others=False,
+ ),
+ StateFilter.freeze(
+ {
+ EventTypes.Member: {"@spqr:spqr"},
+ EventTypes.CanonicalAlias: {""},
+ },
+ include_others=True,
+ ),
+ )
+
+ # (specific state keys) - (no state keys)
+ self.assert_difference(
+ StateFilter.freeze(
+ {
+ EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
+ EventTypes.CanonicalAlias: {""},
+ },
+ include_others=True,
+ ),
+ StateFilter.freeze(
+ {
+ EventTypes.Member: set(),
+ },
+ include_others=False,
+ ),
+ StateFilter.freeze(
+ {
+ EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
+ EventTypes.CanonicalAlias: {""},
+ },
+ include_others=True,
+ ),
+ )
+
+ def test_state_filter_difference_include_other_minus_include_other(self):
+ """
+ Tests the StateFilter.approx_difference method
+ where, in a.approx_difference(b), both a and b have the include_others
+ flag set.
+ """
+ # (wildcard on state keys) - (wildcard on state keys):
+ self.assert_difference(
+ StateFilter.freeze(
+ {EventTypes.Member: None, EventTypes.Create: None},
+ include_others=True,
+ ),
+ StateFilter.freeze(
+ {EventTypes.Member: None, EventTypes.CanonicalAlias: None},
+ include_others=True,
+ ),
+ StateFilter(types=frozendict(), include_others=False),
+ )
+
+ # (wildcard on state keys) - (specific state keys)
+ # This one is an over-approximation because we can't represent
+ # 'all state keys except a few named examples'
+ self.assert_difference(
+ StateFilter.freeze({EventTypes.Member: None}, include_others=True),
+ StateFilter.freeze(
+ {
+ EventTypes.Member: {"@wombat:spqr"},
+ EventTypes.CanonicalAlias: {""},
+ },
+ include_others=True,
+ ),
+ StateFilter.freeze(
+ {EventTypes.Member: None, EventTypes.CanonicalAlias: None},
+ include_others=False,
+ ),
+ )
+
+ # (wildcard on state keys) - (no state keys)
+ self.assert_difference(
+ StateFilter.freeze(
+ {EventTypes.Member: None},
+ include_others=True,
+ ),
+ StateFilter.freeze(
+ {
+ EventTypes.Member: set(),
+ },
+ include_others=True,
+ ),
+ StateFilter.freeze(
+ {EventTypes.Member: None},
+ include_others=False,
+ ),
+ )
+
+ # (specific state keys) - (wildcard on state keys):
+ self.assert_difference(
+ StateFilter.freeze(
+ {
+ EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
+ EventTypes.CanonicalAlias: {""},
+ },
+ include_others=True,
+ ),
+ StateFilter.freeze(
+ {EventTypes.Member: None},
+ include_others=True,
+ ),
+ StateFilter(
+ types=frozendict(),
+ include_others=False,
+ ),
+ )
+
+ # (specific state keys) - (specific state keys)
+ # This one is an over-approximation because we can't represent
+ # 'all state keys except a few named examples'
+ self.assert_difference(
+ StateFilter.freeze(
+ {
+ EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
+ EventTypes.CanonicalAlias: {""},
+ EventTypes.Create: {""},
+ },
+ include_others=True,
+ ),
+ StateFilter.freeze(
+ {
+ EventTypes.Member: {"@wombat:spqr"},
+ EventTypes.Create: set(),
+ },
+ include_others=True,
+ ),
+ StateFilter.freeze(
+ {
+ EventTypes.Member: {"@spqr:spqr"},
+ EventTypes.Create: {""},
+ },
+ include_others=False,
+ ),
+ )
+
+ # (specific state keys) - (no state keys)
+ self.assert_difference(
+ StateFilter.freeze(
+ {
+ EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
+ EventTypes.CanonicalAlias: {""},
+ },
+ include_others=True,
+ ),
+ StateFilter.freeze(
+ {
+ EventTypes.Member: set(),
+ },
+ include_others=True,
+ ),
+ StateFilter.freeze(
+ {
+ EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
+ },
+ include_others=False,
+ ),
+ )
+
+ def test_state_filter_difference_no_include_other_minus_include_other(self):
+ """
+ Tests the StateFilter.approx_difference method
+ where, in a.approx_difference(b), only b has the include_others flag set.
+ """
+ # (wildcard on state keys) - (wildcard on state keys):
+ self.assert_difference(
+ StateFilter.freeze(
+ {EventTypes.Member: None, EventTypes.Create: None},
+ include_others=False,
+ ),
+ StateFilter.freeze(
+ {EventTypes.Member: None, EventTypes.CanonicalAlias: None},
+ include_others=True,
+ ),
+ StateFilter(types=frozendict(), include_others=False),
+ )
+
+ # (wildcard on state keys) - (specific state keys)
+ # This one is an over-approximation because we can't represent
+ # 'all state keys except a few named examples'
+ self.assert_difference(
+ StateFilter.freeze({EventTypes.Member: None}, include_others=False),
+ StateFilter.freeze(
+ {EventTypes.Member: {"@wombat:spqr"}},
+ include_others=True,
+ ),
+ StateFilter.freeze({EventTypes.Member: None}, include_others=False),
+ )
+
+ # (wildcard on state keys) - (no state keys)
+ self.assert_difference(
+ StateFilter.freeze(
+ {EventTypes.Member: None},
+ include_others=False,
+ ),
+ StateFilter.freeze(
+ {
+ EventTypes.Member: set(),
+ },
+ include_others=True,
+ ),
+ StateFilter.freeze(
+ {EventTypes.Member: None},
+ include_others=False,
+ ),
+ )
+
+ # (specific state keys) - (wildcard on state keys):
+ self.assert_difference(
+ StateFilter.freeze(
+ {
+ EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
+ EventTypes.CanonicalAlias: {""},
+ },
+ include_others=False,
+ ),
+ StateFilter.freeze(
+ {EventTypes.Member: None},
+ include_others=True,
+ ),
+ StateFilter(
+ types=frozendict(),
+ include_others=False,
+ ),
+ )
+
+ # (specific state keys) - (specific state keys)
+ # This one is an over-approximation because we can't represent
+ # 'all state keys except a few named examples'
+ self.assert_difference(
+ StateFilter.freeze(
+ {
+ EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
+ EventTypes.CanonicalAlias: {""},
+ },
+ include_others=False,
+ ),
+ StateFilter.freeze(
+ {
+ EventTypes.Member: {"@wombat:spqr"},
+ },
+ include_others=True,
+ ),
+ StateFilter.freeze(
+ {
+ EventTypes.Member: {"@spqr:spqr"},
+ },
+ include_others=False,
+ ),
+ )
+
+ # (specific state keys) - (no state keys)
+ self.assert_difference(
+ StateFilter.freeze(
+ {
+ EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
+ EventTypes.CanonicalAlias: {""},
+ },
+ include_others=False,
+ ),
+ StateFilter.freeze(
+ {
+ EventTypes.Member: set(),
+ },
+ include_others=True,
+ ),
+ StateFilter.freeze(
+ {
+ EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
+ },
+ include_others=False,
+ ),
+ )
+
+ def test_state_filter_difference_simple_cases(self):
+ """
+ Tests some very simple cases of the StateFilter approx_difference,
+ that are not explicitly tested by the more in-depth tests.
+ """
+
+ self.assert_difference(StateFilter.all(), StateFilter.all(), StateFilter.none())
+
+ self.assert_difference(
+ StateFilter.all(),
+ StateFilter.none(),
+ StateFilter.all(),
+ )
diff --git a/tests/storage/test_txn_limit.py b/tests/storage/test_txn_limit.py
index 6ff3ebb1..ace82cbf 100644
--- a/tests/storage/test_txn_limit.py
+++ b/tests/storage/test_txn_limit.py
@@ -22,7 +22,7 @@ class SQLTransactionLimitTestCase(unittest.HomeserverTestCase):
return self.setup_test_homeserver(db_txn_limit=1000)
def test_config(self):
- db_config = self.hs.config.get_single_database()
+ db_config = self.hs.config.database.get_single_database()
self.assertEqual(db_config.config["txn_limit"], 1000)
def test_select(self):
diff --git a/tests/storage/test_user_directory.py b/tests/storage/test_user_directory.py
index 222e5d12..be3ed64f 100644
--- a/tests/storage/test_user_directory.py
+++ b/tests/storage/test_user_directory.py
@@ -11,7 +11,22 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+from typing import Any, Dict, List, Set, Tuple
+from unittest import mock
+from unittest.mock import Mock, patch
+from twisted.test.proto_helpers import MemoryReactor
+
+from synapse.api.constants import EventTypes, Membership, UserTypes
+from synapse.appservice import ApplicationService
+from synapse.rest import admin
+from synapse.rest.client import login, register, room
+from synapse.server import HomeServer
+from synapse.storage import DataStore
+from synapse.storage.roommember import ProfileInfo
+from synapse.util import Clock
+
+from tests.test_utils.event_injection import inject_member_event
from tests.unittest import HomeserverTestCase, override_config
ALICE = "@alice:a"
@@ -21,8 +36,391 @@ BOBBY = "@bobby:a"
BELA = "@somenickname:a"
+class GetUserDirectoryTables:
+ """Helper functions that we want to reuse in tests/handlers/test_user_directory.py"""
+
+ def __init__(self, store: DataStore):
+ self.store = store
+
+ def _compress_shared(
+ self, shared: List[Dict[str, str]]
+ ) -> Set[Tuple[str, str, str]]:
+ """
+ Compress a list of users who share rooms dicts to a list of tuples.
+ """
+ r = set()
+ for i in shared:
+ r.add((i["user_id"], i["other_user_id"], i["room_id"]))
+ return r
+
+ async def get_users_in_public_rooms(self) -> List[Tuple[str, str]]:
+ """Fetch the entire `users_in_public_rooms` table.
+
+ Returns a list of tuples (user_id, room_id) where room_id is public and
+ contains the user with the given id.
+ """
+ r = await self.store.db_pool.simple_select_list(
+ "users_in_public_rooms", None, ("user_id", "room_id")
+ )
+
+ retval = []
+ for i in r:
+ retval.append((i["user_id"], i["room_id"]))
+ return retval
+
+ async def get_users_who_share_private_rooms(self) -> List[Dict[str, str]]:
+ """Fetch the entire `users_who_share_private_rooms` table.
+
+ Returns a dict containing "user_id", "other_user_id" and "room_id" keys.
+ The dicts can be flattened to Tuples with the `_compress_shared` method.
+ (This seems a little awkward---maybe we could clean this up.)
+ """
+
+ return await self.store.db_pool.simple_select_list(
+ "users_who_share_private_rooms",
+ None,
+ ["user_id", "other_user_id", "room_id"],
+ )
+
+ async def get_users_in_user_directory(self) -> Set[str]:
+ """Fetch the set of users in the `user_directory` table.
+
+ This is useful when checking we've correctly excluded users from the directory.
+ """
+ result = await self.store.db_pool.simple_select_list(
+ "user_directory",
+ None,
+ ["user_id"],
+ )
+ return {row["user_id"] for row in result}
+
+ async def get_profiles_in_user_directory(self) -> Dict[str, ProfileInfo]:
+ """Fetch users and their profiles from the `user_directory` table.
+
+ This is useful when we want to inspect display names and avatars.
+ It's almost the entire contents of the `user_directory` table: the only
+ thing missing is an unused room_id column.
+ """
+ rows = await self.store.db_pool.simple_select_list(
+ "user_directory",
+ None,
+ ("user_id", "display_name", "avatar_url"),
+ )
+ return {
+ row["user_id"]: ProfileInfo(
+ display_name=row["display_name"], avatar_url=row["avatar_url"]
+ )
+ for row in rows
+ }
+
+
+class UserDirectoryInitialPopulationTestcase(HomeserverTestCase):
+ """Ensure that rebuilding the directory writes the correct data to the DB.
+
+ See also tests/handlers/test_user_directory.py for similar checks. They
+ test the incremental updates, rather than the big rebuild.
+ """
+
+ servlets = [
+ login.register_servlets,
+ admin.register_servlets,
+ room.register_servlets,
+ register.register_servlets,
+ ]
+
+ def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
+ self.appservice = ApplicationService(
+ token="i_am_an_app_service",
+ hostname="test",
+ id="1234",
+ namespaces={"users": [{"regex": r"@as_user.*", "exclusive": True}]},
+ sender="@as:test",
+ )
+
+ mock_load_appservices = Mock(return_value=[self.appservice])
+ with patch(
+ "synapse.storage.databases.main.appservice.load_appservices",
+ mock_load_appservices,
+ ):
+ hs = super().make_homeserver(reactor, clock)
+ return hs
+
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+ self.store = hs.get_datastore()
+ self.user_dir_helper = GetUserDirectoryTables(self.store)
+
+ def _purge_and_rebuild_user_dir(self) -> None:
+ """Nuke the user directory tables, start the background process to
+ repopulate them, and wait for the process to complete. This allows us
+ to inspect the outcome of the background process alone, without any of
+ the other incremental updates.
+ """
+ self.get_success(self.store.update_user_directory_stream_pos(None))
+ self.get_success(self.store.delete_all_from_user_dir())
+
+ shares_private = self.get_success(
+ self.user_dir_helper.get_users_who_share_private_rooms()
+ )
+ public_users = self.get_success(
+ self.user_dir_helper.get_users_in_public_rooms()
+ )
+
+ # Nothing updated yet
+ self.assertEqual(shares_private, [])
+ self.assertEqual(public_users, [])
+
+ # Ugh, have to reset this flag
+ self.store.db_pool.updates._all_done = False
+
+ self.get_success(
+ self.store.db_pool.simple_insert(
+ "background_updates",
+ {
+ "update_name": "populate_user_directory_createtables",
+ "progress_json": "{}",
+ },
+ )
+ )
+ self.get_success(
+ self.store.db_pool.simple_insert(
+ "background_updates",
+ {
+ "update_name": "populate_user_directory_process_rooms",
+ "progress_json": "{}",
+ "depends_on": "populate_user_directory_createtables",
+ },
+ )
+ )
+ self.get_success(
+ self.store.db_pool.simple_insert(
+ "background_updates",
+ {
+ "update_name": "populate_user_directory_process_users",
+ "progress_json": "{}",
+ "depends_on": "populate_user_directory_process_rooms",
+ },
+ )
+ )
+ self.get_success(
+ self.store.db_pool.simple_insert(
+ "background_updates",
+ {
+ "update_name": "populate_user_directory_cleanup",
+ "progress_json": "{}",
+ "depends_on": "populate_user_directory_process_users",
+ },
+ )
+ )
+
+ self.wait_for_background_updates()
+
+ def test_initial(self) -> None:
+ """
+ The user directory's initial handler correctly updates the search tables.
+ """
+ u1 = self.register_user("user1", "pass")
+ u1_token = self.login(u1, "pass")
+ u2 = self.register_user("user2", "pass")
+ u2_token = self.login(u2, "pass")
+ u3 = self.register_user("user3", "pass")
+ u3_token = self.login(u3, "pass")
+
+ room = self.helper.create_room_as(u1, is_public=True, tok=u1_token)
+ self.helper.invite(room, src=u1, targ=u2, tok=u1_token)
+ self.helper.join(room, user=u2, tok=u2_token)
+
+ private_room = self.helper.create_room_as(u1, is_public=False, tok=u1_token)
+ self.helper.invite(private_room, src=u1, targ=u3, tok=u1_token)
+ self.helper.join(private_room, user=u3, tok=u3_token)
+
+ # Do the initial population of the user directory via the background update
+ self._purge_and_rebuild_user_dir()
+
+ shares_private = self.get_success(
+ self.user_dir_helper.get_users_who_share_private_rooms()
+ )
+ public_users = self.get_success(
+ self.user_dir_helper.get_users_in_public_rooms()
+ )
+
+ # User 1 and User 2 are in the same public room
+ self.assertEqual(set(public_users), {(u1, room), (u2, room)})
+
+ # User 1 and User 3 share private rooms
+ self.assertEqual(
+ self.user_dir_helper._compress_shared(shares_private),
+ {(u1, u3, private_room), (u3, u1, private_room)},
+ )
+
+ # All three should have entries in the directory
+ users = self.get_success(self.user_dir_helper.get_users_in_user_directory())
+ self.assertEqual(users, {u1, u2, u3})
+
+ # The next four tests (test_population_excludes_*) all set up
+ # - A normal user included in the user dir
+ # - A public and private room created by that user
+ # - A user excluded from the room dir, belonging to both rooms
+
+ # They match similar logic in handlers/test_user_directory.py But that tests
+ # updating the directory; this tests rebuilding it from scratch.
+
+ def _create_rooms_and_inject_memberships(
+ self, creator: str, token: str, joiner: str
+ ) -> Tuple[str, str]:
+ """Create a public and private room as a normal user.
+ Then get the `joiner` into those rooms.
+ """
+ public_room = self.helper.create_room_as(
+ creator,
+ is_public=True,
+ # See https://github.com/matrix-org/synapse/issues/10951
+ extra_content={"visibility": "public"},
+ tok=token,
+ )
+ private_room = self.helper.create_room_as(creator, is_public=False, tok=token)
+
+ # HACK: get the user into these rooms
+ self.get_success(inject_member_event(self.hs, public_room, joiner, "join"))
+ self.get_success(inject_member_event(self.hs, private_room, joiner, "join"))
+
+ return public_room, private_room
+
+ def _check_room_sharing_tables(
+ self, normal_user: str, public_room: str, private_room: str
+ ) -> None:
+ # After rebuilding the directory, we should only see the normal user.
+ users = self.get_success(self.user_dir_helper.get_users_in_user_directory())
+ self.assertEqual(users, {normal_user})
+ in_public_rooms = self.get_success(
+ self.user_dir_helper.get_users_in_public_rooms()
+ )
+ self.assertEqual(set(in_public_rooms), {(normal_user, public_room)})
+ in_private_rooms = self.get_success(
+ self.user_dir_helper.get_users_who_share_private_rooms()
+ )
+ self.assertEqual(in_private_rooms, [])
+
+ def test_population_excludes_support_user(self) -> None:
+ # Create a normal and support user.
+ user = self.register_user("user", "pass")
+ token = self.login(user, "pass")
+ support = "@support1:test"
+ self.get_success(
+ self.store.register_user(
+ user_id=support, password_hash=None, user_type=UserTypes.SUPPORT
+ )
+ )
+
+ # Join the support user to rooms owned by the normal user.
+ public, private = self._create_rooms_and_inject_memberships(
+ user, token, support
+ )
+
+ # Rebuild the directory.
+ self._purge_and_rebuild_user_dir()
+
+ # Check the support user is not in the directory.
+ self._check_room_sharing_tables(user, public, private)
+
+ def test_population_excludes_deactivated_user(self) -> None:
+ user = self.register_user("naughty", "pass")
+ admin = self.register_user("admin", "pass", admin=True)
+ admin_token = self.login(admin, "pass")
+
+ # Deactivate the user.
+ channel = self.make_request(
+ "PUT",
+ f"/_synapse/admin/v2/users/{user}",
+ access_token=admin_token,
+ content={"deactivated": True},
+ )
+ self.assertEqual(channel.code, 200)
+ self.assertEqual(channel.json_body["deactivated"], True)
+
+ # Join the deactivated user to rooms owned by the admin.
+ # Is this something that could actually happen outside of a test?
+ public, private = self._create_rooms_and_inject_memberships(
+ admin, admin_token, user
+ )
+
+ # Rebuild the user dir. The deactivated user should be missing.
+ self._purge_and_rebuild_user_dir()
+ self._check_room_sharing_tables(admin, public, private)
+
+ def test_population_excludes_appservice_user(self) -> None:
+ # Register an AS user.
+ user = self.register_user("user", "pass")
+ token = self.login(user, "pass")
+ as_user = self.register_appservice_user("as_user_potato", self.appservice.token)
+
+ # Join the AS user to rooms owned by the normal user.
+ public, private = self._create_rooms_and_inject_memberships(
+ user, token, as_user
+ )
+
+ # Rebuild the directory.
+ self._purge_and_rebuild_user_dir()
+
+ # Check the AS user is not in the directory.
+ self._check_room_sharing_tables(user, public, private)
+
+ def test_population_excludes_appservice_sender(self) -> None:
+ user = self.register_user("user", "pass")
+ token = self.login(user, "pass")
+
+ # Join the AS sender to rooms owned by the normal user.
+ public, private = self._create_rooms_and_inject_memberships(
+ user, token, self.appservice.sender
+ )
+
+ # Rebuild the directory.
+ self._purge_and_rebuild_user_dir()
+
+ # Check the AS sender is not in the directory.
+ self._check_room_sharing_tables(user, public, private)
+
+ def test_population_conceals_private_nickname(self) -> None:
+ # Make a private room, and set a nickname within
+ user = self.register_user("aaaa", "pass")
+ user_token = self.login(user, "pass")
+ private_room = self.helper.create_room_as(user, is_public=False, tok=user_token)
+ self.helper.send_state(
+ private_room,
+ EventTypes.Member,
+ state_key=user,
+ body={"membership": Membership.JOIN, "displayname": "BBBB"},
+ tok=user_token,
+ )
+
+ # Rebuild the user directory. Make the rescan of the `users` table a no-op
+ # so we only see the effect of scanning the `room_memberships` table.
+ async def mocked_process_users(*args: Any, **kwargs: Any) -> int:
+ await self.store.db_pool.updates._end_background_update(
+ "populate_user_directory_process_users"
+ )
+ return 1
+
+ with mock.patch.dict(
+ self.store.db_pool.updates._background_update_handlers,
+ populate_user_directory_process_users=mocked_process_users,
+ ):
+ self._purge_and_rebuild_user_dir()
+
+ # Local users are ignored by the scan over rooms
+ users = self.get_success(self.user_dir_helper.get_profiles_in_user_directory())
+ self.assertEqual(users, {})
+
+ # Do a full rebuild including the scan over the `users` table. The local
+ # user should appear with their profile name.
+ self._purge_and_rebuild_user_dir()
+ users = self.get_success(self.user_dir_helper.get_profiles_in_user_directory())
+ self.assertEqual(
+ users, {user: ProfileInfo(display_name="aaaa", avatar_url=None)}
+ )
+
+
class UserDirectoryStoreTestCase(HomeserverTestCase):
- def prepare(self, reactor, clock, hs):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = hs.get_datastore()
# alice and bob are both in !room_id. bobby is not but shares
@@ -33,7 +431,7 @@ class UserDirectoryStoreTestCase(HomeserverTestCase):
self.get_success(self.store.update_profile_in_user_dir(BELA, "Bela", None))
self.get_success(self.store.add_users_in_public_rooms("!room:id", (ALICE, BOB)))
- def test_search_user_dir(self):
+ def test_search_user_dir(self) -> None:
# normally when alice searches the directory she should just find
# bob because bobby doesn't share a room with her.
r = self.get_success(self.store.search_user_dir(ALICE, "bob", 10))
@@ -44,7 +442,7 @@ class UserDirectoryStoreTestCase(HomeserverTestCase):
)
@override_config({"user_directory": {"search_all_users": True}})
- def test_search_user_dir_all_users(self):
+ def test_search_user_dir_all_users(self) -> None:
r = self.get_success(self.store.search_user_dir(ALICE, "bob", 10))
self.assertFalse(r["limited"])
self.assertEqual(2, len(r["results"]))
@@ -58,7 +456,7 @@ class UserDirectoryStoreTestCase(HomeserverTestCase):
)
@override_config({"user_directory": {"search_all_users": True}})
- def test_search_user_dir_stop_words(self):
+ def test_search_user_dir_stop_words(self) -> None:
"""Tests that a user can look up another user by searching for the start if its
display name even if that name happens to be a common English word that would
usually be ignored in full text searches.
diff --git a/tests/test_event_auth.py b/tests/test_event_auth.py
index 1a4d0787..cf407c51 100644
--- a/tests/test_event_auth.py
+++ b/tests/test_event_auth.py
@@ -38,21 +38,19 @@ class EventAuthTestCase(unittest.TestCase):
}
# creator should be able to send state
- event_auth.check(
+ event_auth.check_auth_rules_for_event(
RoomVersions.V1,
_random_state_event(creator),
auth_events,
- do_sig_check=False,
)
# joiner should not be able to send state
self.assertRaises(
AuthError,
- event_auth.check,
+ event_auth.check_auth_rules_for_event,
RoomVersions.V1,
_random_state_event(joiner),
auth_events,
- do_sig_check=False,
)
def test_state_default_level(self):
@@ -77,19 +75,17 @@ class EventAuthTestCase(unittest.TestCase):
# pleb should not be able to send state
self.assertRaises(
AuthError,
- event_auth.check,
+ event_auth.check_auth_rules_for_event,
RoomVersions.V1,
_random_state_event(pleb),
auth_events,
- do_sig_check=False,
),
# king should be able to send state
- event_auth.check(
+ event_auth.check_auth_rules_for_event(
RoomVersions.V1,
_random_state_event(king),
auth_events,
- do_sig_check=False,
)
def test_alias_event(self):
@@ -102,37 +98,33 @@ class EventAuthTestCase(unittest.TestCase):
}
# creator should be able to send aliases
- event_auth.check(
+ event_auth.check_auth_rules_for_event(
RoomVersions.V1,
_alias_event(creator),
auth_events,
- do_sig_check=False,
)
# Reject an event with no state key.
with self.assertRaises(AuthError):
- event_auth.check(
+ event_auth.check_auth_rules_for_event(
RoomVersions.V1,
_alias_event(creator, state_key=""),
auth_events,
- do_sig_check=False,
)
# If the domain of the sender does not match the state key, reject.
with self.assertRaises(AuthError):
- event_auth.check(
+ event_auth.check_auth_rules_for_event(
RoomVersions.V1,
_alias_event(creator, state_key="test.com"),
auth_events,
- do_sig_check=False,
)
# Note that the member does *not* need to be in the room.
- event_auth.check(
+ event_auth.check_auth_rules_for_event(
RoomVersions.V1,
_alias_event(other),
auth_events,
- do_sig_check=False,
)
def test_msc2432_alias_event(self):
@@ -145,34 +137,30 @@ class EventAuthTestCase(unittest.TestCase):
}
# creator should be able to send aliases
- event_auth.check(
+ event_auth.check_auth_rules_for_event(
RoomVersions.V6,
_alias_event(creator),
auth_events,
- do_sig_check=False,
)
# No particular checks are done on the state key.
- event_auth.check(
+ event_auth.check_auth_rules_for_event(
RoomVersions.V6,
_alias_event(creator, state_key=""),
auth_events,
- do_sig_check=False,
)
- event_auth.check(
+ event_auth.check_auth_rules_for_event(
RoomVersions.V6,
_alias_event(creator, state_key="test.com"),
auth_events,
- do_sig_check=False,
)
# Per standard auth rules, the member must be in the room.
with self.assertRaises(AuthError):
- event_auth.check(
+ event_auth.check_auth_rules_for_event(
RoomVersions.V6,
_alias_event(other),
auth_events,
- do_sig_check=False,
)
def test_msc2209(self):
@@ -192,20 +180,18 @@ class EventAuthTestCase(unittest.TestCase):
}
# pleb should be able to modify the notifications power level.
- event_auth.check(
+ event_auth.check_auth_rules_for_event(
RoomVersions.V1,
_power_levels_event(pleb, {"notifications": {"room": 100}}),
auth_events,
- do_sig_check=False,
)
# But an MSC2209 room rejects this change.
with self.assertRaises(AuthError):
- event_auth.check(
+ event_auth.check_auth_rules_for_event(
RoomVersions.V6,
_power_levels_event(pleb, {"notifications": {"room": 100}}),
auth_events,
- do_sig_check=False,
)
def test_join_rules_public(self):
@@ -222,59 +208,53 @@ class EventAuthTestCase(unittest.TestCase):
}
# Check join.
- event_auth.check(
+ event_auth.check_auth_rules_for_event(
RoomVersions.V6,
_join_event(pleb),
auth_events,
- do_sig_check=False,
)
# A user cannot be force-joined to a room.
with self.assertRaises(AuthError):
- event_auth.check(
+ event_auth.check_auth_rules_for_event(
RoomVersions.V6,
_member_event(pleb, "join", sender=creator),
auth_events,
- do_sig_check=False,
)
# Banned should be rejected.
auth_events[("m.room.member", pleb)] = _member_event(pleb, "ban")
with self.assertRaises(AuthError):
- event_auth.check(
+ event_auth.check_auth_rules_for_event(
RoomVersions.V6,
_join_event(pleb),
auth_events,
- do_sig_check=False,
)
# A user who left can re-join.
auth_events[("m.room.member", pleb)] = _member_event(pleb, "leave")
- event_auth.check(
+ event_auth.check_auth_rules_for_event(
RoomVersions.V6,
_join_event(pleb),
auth_events,
- do_sig_check=False,
)
# A user can send a join if they're in the room.
auth_events[("m.room.member", pleb)] = _member_event(pleb, "join")
- event_auth.check(
+ event_auth.check_auth_rules_for_event(
RoomVersions.V6,
_join_event(pleb),
auth_events,
- do_sig_check=False,
)
# A user can accept an invite.
auth_events[("m.room.member", pleb)] = _member_event(
pleb, "invite", sender=creator
)
- event_auth.check(
+ event_auth.check_auth_rules_for_event(
RoomVersions.V6,
_join_event(pleb),
auth_events,
- do_sig_check=False,
)
def test_join_rules_invite(self):
@@ -292,60 +272,54 @@ class EventAuthTestCase(unittest.TestCase):
# A join without an invite is rejected.
with self.assertRaises(AuthError):
- event_auth.check(
+ event_auth.check_auth_rules_for_event(
RoomVersions.V6,
_join_event(pleb),
auth_events,
- do_sig_check=False,
)
# A user cannot be force-joined to a room.
with self.assertRaises(AuthError):
- event_auth.check(
+ event_auth.check_auth_rules_for_event(
RoomVersions.V6,
_member_event(pleb, "join", sender=creator),
auth_events,
- do_sig_check=False,
)
# Banned should be rejected.
auth_events[("m.room.member", pleb)] = _member_event(pleb, "ban")
with self.assertRaises(AuthError):
- event_auth.check(
+ event_auth.check_auth_rules_for_event(
RoomVersions.V6,
_join_event(pleb),
auth_events,
- do_sig_check=False,
)
# A user who left cannot re-join.
auth_events[("m.room.member", pleb)] = _member_event(pleb, "leave")
with self.assertRaises(AuthError):
- event_auth.check(
+ event_auth.check_auth_rules_for_event(
RoomVersions.V6,
_join_event(pleb),
auth_events,
- do_sig_check=False,
)
# A user can send a join if they're in the room.
auth_events[("m.room.member", pleb)] = _member_event(pleb, "join")
- event_auth.check(
+ event_auth.check_auth_rules_for_event(
RoomVersions.V6,
_join_event(pleb),
auth_events,
- do_sig_check=False,
)
# A user can accept an invite.
auth_events[("m.room.member", pleb)] = _member_event(
pleb, "invite", sender=creator
)
- event_auth.check(
+ event_auth.check_auth_rules_for_event(
RoomVersions.V6,
_join_event(pleb),
auth_events,
- do_sig_check=False,
)
def test_join_rules_msc3083_restricted(self):
@@ -370,11 +344,10 @@ class EventAuthTestCase(unittest.TestCase):
# Older room versions don't understand this join rule
with self.assertRaises(AuthError):
- event_auth.check(
+ event_auth.check_auth_rules_for_event(
RoomVersions.V6,
_join_event(pleb),
auth_events,
- do_sig_check=False,
)
# A properly formatted join event should work.
@@ -384,11 +357,10 @@ class EventAuthTestCase(unittest.TestCase):
EventContentFields.AUTHORISING_USER: "@creator:example.com"
},
)
- event_auth.check(
+ event_auth.check_auth_rules_for_event(
RoomVersions.V8,
authorised_join_event,
auth_events,
- do_sig_check=False,
)
# A join issued by a specific user works (i.e. the power level checks
@@ -400,7 +372,7 @@ class EventAuthTestCase(unittest.TestCase):
pl_auth_events[("m.room.member", "@inviter:foo.test")] = _join_event(
"@inviter:foo.test"
)
- event_auth.check(
+ event_auth.check_auth_rules_for_event(
RoomVersions.V8,
_join_event(
pleb,
@@ -409,16 +381,14 @@ class EventAuthTestCase(unittest.TestCase):
},
),
pl_auth_events,
- do_sig_check=False,
)
# A join which is missing an authorised server is rejected.
with self.assertRaises(AuthError):
- event_auth.check(
+ event_auth.check_auth_rules_for_event(
RoomVersions.V8,
_join_event(pleb),
auth_events,
- do_sig_check=False,
)
# An join authorised by a user who is not in the room is rejected.
@@ -427,7 +397,7 @@ class EventAuthTestCase(unittest.TestCase):
creator, {"invite": 100, "users": {"@other:example.com": 150}}
)
with self.assertRaises(AuthError):
- event_auth.check(
+ event_auth.check_auth_rules_for_event(
RoomVersions.V8,
_join_event(
pleb,
@@ -436,13 +406,12 @@ class EventAuthTestCase(unittest.TestCase):
},
),
auth_events,
- do_sig_check=False,
)
# A user cannot be force-joined to a room. (This uses an event which
# *would* be valid, but is sent be a different user.)
with self.assertRaises(AuthError):
- event_auth.check(
+ event_auth.check_auth_rules_for_event(
RoomVersions.V8,
_member_event(
pleb,
@@ -453,36 +422,32 @@ class EventAuthTestCase(unittest.TestCase):
},
),
auth_events,
- do_sig_check=False,
)
# Banned should be rejected.
auth_events[("m.room.member", pleb)] = _member_event(pleb, "ban")
with self.assertRaises(AuthError):
- event_auth.check(
+ event_auth.check_auth_rules_for_event(
RoomVersions.V8,
authorised_join_event,
auth_events,
- do_sig_check=False,
)
# A user who left can re-join.
auth_events[("m.room.member", pleb)] = _member_event(pleb, "leave")
- event_auth.check(
+ event_auth.check_auth_rules_for_event(
RoomVersions.V8,
authorised_join_event,
auth_events,
- do_sig_check=False,
)
# A user can send a join if they're in the room. (This doesn't need to
# be authorised since the user is already joined.)
auth_events[("m.room.member", pleb)] = _member_event(pleb, "join")
- event_auth.check(
+ event_auth.check_auth_rules_for_event(
RoomVersions.V8,
_join_event(pleb),
auth_events,
- do_sig_check=False,
)
# A user can accept an invite. (This doesn't need to be authorised since
@@ -490,11 +455,10 @@ class EventAuthTestCase(unittest.TestCase):
auth_events[("m.room.member", pleb)] = _member_event(
pleb, "invite", sender=creator
)
- event_auth.check(
+ event_auth.check_auth_rules_for_event(
RoomVersions.V8,
_join_event(pleb),
auth_events,
- do_sig_check=False,
)
diff --git a/tests/test_federation.py b/tests/test_federation.py
index c51e018d..24fc77d7 100644
--- a/tests/test_federation.py
+++ b/tests/test_federation.py
@@ -82,7 +82,6 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
event,
context,
state=None,
- claimed_auth_event_map=None,
backfilled=False,
):
return context
diff --git a/tests/test_mau.py b/tests/test_mau.py
index 66111eb3..c683c893 100644
--- a/tests/test_mau.py
+++ b/tests/test_mau.py
@@ -13,11 +13,11 @@
# limitations under the License.
"""Tests REST events for /rooms paths."""
-
+import synapse.rest.admin
from synapse.api.constants import APP_SERVICE_REGISTRATION_TYPE, LoginType
from synapse.api.errors import Codes, HttpResponseException, SynapseError
from synapse.appservice import ApplicationService
-from synapse.rest.client import register, sync
+from synapse.rest.client import login, profile, register, sync
from tests import unittest
from tests.unittest import override_config
@@ -26,7 +26,13 @@ from tests.utils import default_config
class TestMauLimit(unittest.HomeserverTestCase):
- servlets = [register.register_servlets, sync.register_servlets]
+ servlets = [
+ register.register_servlets,
+ sync.register_servlets,
+ synapse.rest.admin.register_servlets_for_client_rest_resource,
+ profile.register_servlets,
+ login.register_servlets,
+ ]
def default_config(self):
config = default_config("test")
@@ -165,7 +171,7 @@ class TestMauLimit(unittest.HomeserverTestCase):
@override_config({"mau_trial_days": 1})
def test_trial_users_cant_come_back(self):
- self.hs.config.mau_trial_days = 1
+ self.hs.config.server.mau_trial_days = 1
# We should be able to register more than the limit initially
token1 = self.create_user("kermit1")
@@ -229,6 +235,31 @@ class TestMauLimit(unittest.HomeserverTestCase):
self.reactor.advance(100)
self.assertEqual(2, self.successResultOf(count))
+ def test_deactivated_users_dont_count_towards_mau(self):
+ user1 = self.register_user("madonna", "password")
+ self.register_user("prince", "password2")
+ self.register_user("frodo", "onering", True)
+
+ token1 = self.login("madonna", "password")
+ token2 = self.login("prince", "password2")
+ admin_token = self.login("frodo", "onering")
+
+ self.do_sync_for_user(token1)
+ self.do_sync_for_user(token2)
+
+ # Check that mau count is what we expect
+ count = self.get_success(self.store.get_monthly_active_count())
+ self.assertEqual(count, 2)
+
+ # Deactivate user1
+ url = "/_synapse/admin/v1/deactivate/%s" % user1
+ channel = self.make_request("POST", url, access_token=admin_token)
+ self.assertIn("success", channel.json_body["id_server_unbind_result"])
+
+ # Check that deactivated user is no longer counted
+ count = self.get_success(self.store.get_monthly_active_count())
+ self.assertEqual(count, 1)
+
def create_user(self, localpart, token=None, appservice=False):
request_data = {
"username": localpart,
diff --git a/tests/test_preview.py b/tests/test_preview.py
index 48e792b5..09e017b4 100644
--- a/tests/test_preview.py
+++ b/tests/test_preview.py
@@ -13,7 +13,8 @@
# limitations under the License.
from synapse.rest.media.v1.preview_url_resource import (
- decode_and_calc_og,
+ _calc_og,
+ decode_body,
get_html_media_encoding,
summarize_paragraphs,
)
@@ -158,7 +159,8 @@ class CalcOgTestCase(unittest.TestCase):
</html>
"""
- og = decode_and_calc_og(html, "http://example.com/test.html")
+ tree = decode_body(html)
+ og = _calc_og(tree, "http://example.com/test.html")
self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."})
@@ -173,7 +175,8 @@ class CalcOgTestCase(unittest.TestCase):
</html>
"""
- og = decode_and_calc_og(html, "http://example.com/test.html")
+ tree = decode_body(html)
+ og = _calc_og(tree, "http://example.com/test.html")
self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."})
@@ -191,7 +194,8 @@ class CalcOgTestCase(unittest.TestCase):
</html>
"""
- og = decode_and_calc_og(html, "http://example.com/test.html")
+ tree = decode_body(html)
+ og = _calc_og(tree, "http://example.com/test.html")
self.assertEqual(
og,
@@ -212,7 +216,8 @@ class CalcOgTestCase(unittest.TestCase):
</html>
"""
- og = decode_and_calc_og(html, "http://example.com/test.html")
+ tree = decode_body(html)
+ og = _calc_og(tree, "http://example.com/test.html")
self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."})
@@ -225,7 +230,8 @@ class CalcOgTestCase(unittest.TestCase):
</html>
"""
- og = decode_and_calc_og(html, "http://example.com/test.html")
+ tree = decode_body(html)
+ og = _calc_og(tree, "http://example.com/test.html")
self.assertEqual(og, {"og:title": None, "og:description": "Some text."})
@@ -239,7 +245,8 @@ class CalcOgTestCase(unittest.TestCase):
</html>
"""
- og = decode_and_calc_og(html, "http://example.com/test.html")
+ tree = decode_body(html)
+ og = _calc_og(tree, "http://example.com/test.html")
self.assertEqual(og, {"og:title": "Title", "og:description": "Some text."})
@@ -253,21 +260,22 @@ class CalcOgTestCase(unittest.TestCase):
</html>
"""
- og = decode_and_calc_og(html, "http://example.com/test.html")
+ tree = decode_body(html)
+ og = _calc_og(tree, "http://example.com/test.html")
self.assertEqual(og, {"og:title": None, "og:description": "Some text."})
def test_empty(self):
"""Test a body with no data in it."""
html = b""
- og = decode_and_calc_og(html, "http://example.com/test.html")
- self.assertEqual(og, {})
+ tree = decode_body(html)
+ self.assertIsNone(tree)
def test_no_tree(self):
"""A valid body with no tree in it."""
html = b"\x00"
- og = decode_and_calc_og(html, "http://example.com/test.html")
- self.assertEqual(og, {})
+ tree = decode_body(html)
+ self.assertIsNone(tree)
def test_invalid_encoding(self):
"""An invalid character encoding should be ignored and treated as UTF-8, if possible."""
@@ -279,9 +287,8 @@ class CalcOgTestCase(unittest.TestCase):
</body>
</html>
"""
- og = decode_and_calc_og(
- html, "http://example.com/test.html", "invalid-encoding"
- )
+ tree = decode_body(html, "invalid-encoding")
+ og = _calc_og(tree, "http://example.com/test.html")
self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."})
def test_invalid_encoding2(self):
@@ -295,7 +302,8 @@ class CalcOgTestCase(unittest.TestCase):
</body>
</html>
"""
- og = decode_and_calc_og(html, "http://example.com/test.html")
+ tree = decode_body(html)
+ og = _calc_og(tree, "http://example.com/test.html")
self.assertEqual(og, {"og:title": "ÿÿ Foo", "og:description": "Some text."})
diff --git a/tests/unittest.py b/tests/unittest.py
index 7a6f5954..81c1a9e9 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -20,7 +20,7 @@ import inspect
import logging
import secrets
import time
-from typing import Callable, Dict, Iterable, Optional, Tuple, Type, TypeVar, Union
+from typing import Any, Callable, Dict, Iterable, Optional, Tuple, Type, TypeVar, Union
from unittest.mock import Mock, patch
from canonicaljson import json
@@ -28,6 +28,7 @@ from canonicaljson import json
from twisted.internet.defer import Deferred, ensureDeferred, succeed
from twisted.python.failure import Failure
from twisted.python.threadpool import ThreadPool
+from twisted.test.proto_helpers import MemoryReactor
from twisted.trial import unittest
from twisted.web.resource import Resource
@@ -46,6 +47,7 @@ from synapse.logging.context import (
)
from synapse.server import HomeServer
from synapse.types import UserID, create_requester
+from synapse.util import Clock
from synapse.util.httpresourcetree import create_resource_tree
from synapse.util.ratelimitutils import FederationRateLimiter
@@ -232,7 +234,7 @@ class HomeserverTestCase(TestCase):
# Honour the `use_frozen_dicts` config option. We have to do this
# manually because this is taken care of in the app `start` code, which
# we don't run. Plus we want to reset it on tearDown.
- events.USE_FROZEN_DICTS = self.hs.config.use_frozen_dicts
+ events.USE_FROZEN_DICTS = self.hs.config.server.use_frozen_dicts
if self.hs is None:
raise Exception("No homeserver returned from make_homeserver.")
@@ -315,6 +317,15 @@ class HomeserverTestCase(TestCase):
self.reactor.advance(0.01)
time.sleep(0.01)
+ def wait_for_background_updates(self) -> None:
+ """Block until all background database updates have completed."""
+ while not self.get_success(
+ self.store.db_pool.updates.has_completed_background_updates()
+ ):
+ self.get_success(
+ self.store.db_pool.updates.do_next_background_update(100), by=0.1
+ )
+
def make_homeserver(self, reactor, clock):
"""
Make and return a homeserver.
@@ -371,7 +382,7 @@ class HomeserverTestCase(TestCase):
return config
- def prepare(self, reactor, clock, homeserver):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer):
"""
Prepare for the test. This involves things like mocking out parts of
the homeserver, or building test data common across the whole test
@@ -447,7 +458,7 @@ class HomeserverTestCase(TestCase):
client_ip,
)
- def setup_test_homeserver(self, *args, **kwargs):
+ def setup_test_homeserver(self, *args: Any, **kwargs: Any) -> HomeServer:
"""
Set up the test homeserver, meant to be called by the overridable
make_homeserver. It automatically passes through the test class's
@@ -558,7 +569,7 @@ class HomeserverTestCase(TestCase):
Returns:
The MXID of the new user.
"""
- self.hs.config.registration_shared_secret = "shared"
+ self.hs.config.registration.registration_shared_secret = "shared"
# Create the user
channel = self.make_request("GET", "/_synapse/admin/v1/register")
@@ -594,6 +605,35 @@ class HomeserverTestCase(TestCase):
user_id = channel.json_body["user_id"]
return user_id
+ def register_appservice_user(
+ self,
+ username: str,
+ appservice_token: str,
+ ) -> str:
+ """Register an appservice user as an application service.
+ Requires the client-facing registration API be registered.
+
+ Args:
+ username: the user to be registered by an application service.
+ Should be a full username, i.e. ""@localpart:hostname" as opposed to just "localpart"
+ appservice_token: the acccess token for that application service.
+
+ Raises: if the request to '/register' does not return 200 OK.
+
+ Returns: the MXID of the new user.
+ """
+ channel = self.make_request(
+ "POST",
+ "/_matrix/client/r0/register",
+ {
+ "username": username,
+ "type": "m.login.application_service",
+ },
+ access_token=appservice_token,
+ )
+ self.assertEqual(channel.code, 200, channel.json_body)
+ return channel.json_body["user_id"]
+
def login(
self,
username,
diff --git a/tox.ini b/tox.ini
index 5a62ec76..cfe6a069 100644
--- a/tox.ini
+++ b/tox.ini
@@ -41,10 +41,10 @@ lint_targets =
scripts/hash_password
scripts/register_new_matrix_user
scripts/synapse_port_db
+ scripts/update_synapse_database
scripts-dev
scripts-dev/build_debian_packages
scripts-dev/sign_json
- scripts-dev/update_database
stubs
contrib
synctl