summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--.git-blame-ignore-revs8
-rw-r--r--CHANGES.md84
-rw-r--r--MANIFEST.in7
-rw-r--r--README.rst5
-rw-r--r--UPGRADE.rst7
-rw-r--r--debian/changelog6
-rw-r--r--docker/Dockerfile1
-rw-r--r--docker/README.md5
-rwxr-xr-xdocker/start.py12
-rw-r--r--docs/admin_api/media_admin_api.md16
-rw-r--r--docs/openid.md50
-rw-r--r--docs/reverse_proxy.md51
-rw-r--r--docs/sample_config.yaml51
-rw-r--r--docs/spam_checker.md18
-rw-r--r--mypy.ini4
-rwxr-xr-xscripts-dev/config-lint.sh9
-rw-r--r--setup.cfg1
-rw-r--r--stubs/txredisapi.pyi6
-rw-r--r--synapse/__init__.py2
-rw-r--r--synapse/api/auth.py43
-rw-r--r--synapse/appservice/api.py2
-rw-r--r--synapse/config/_base.py38
-rw-r--r--synapse/config/_base.pyi2
-rw-r--r--synapse/config/logger.py5
-rw-r--r--synapse/config/oidc_config.py102
-rw-r--r--synapse/config/server.py3
-rw-r--r--synapse/config/stats.py45
-rw-r--r--synapse/events/spamcheck.py29
-rw-r--r--synapse/federation/federation_server.py175
-rw-r--r--synapse/federation/sender/per_destination_queue.py287
-rw-r--r--synapse/federation/sender/transaction_manager.py59
-rw-r--r--synapse/handlers/acme.py4
-rw-r--r--synapse/handlers/auth.py73
-rw-r--r--synapse/handlers/cas_handler.py1
-rw-r--r--synapse/handlers/federation.py220
-rw-r--r--synapse/handlers/initial_sync.py2
-rw-r--r--synapse/handlers/oidc_handler.py179
-rw-r--r--synapse/handlers/pagination.py2
-rw-r--r--synapse/handlers/register.py71
-rw-r--r--synapse/handlers/room.py2
-rw-r--r--synapse/handlers/room_list.py4
-rw-r--r--synapse/handlers/saml_handler.py1
-rw-r--r--synapse/handlers/sso.py8
-rw-r--r--synapse/handlers/sync.py2
-rw-r--r--synapse/http/client.py52
-rw-r--r--synapse/http/federation/matrix_federation_agent.py3
-rw-r--r--synapse/http/federation/well_known_resolver.py3
-rw-r--r--synapse/http/matrixfederationclient.py15
-rw-r--r--synapse/logging/_remote.py23
-rw-r--r--synapse/logging/context.py6
-rw-r--r--synapse/module_api/__init__.py31
-rw-r--r--synapse/push/emailpusher.py4
-rw-r--r--synapse/replication/http/_base.py9
-rw-r--r--synapse/replication/http/login.py4
-rw-r--r--synapse/replication/tcp/handler.py48
-rw-r--r--synapse/replication/tcp/protocol.py33
-rw-r--r--synapse/replication/tcp/redis.py45
-rw-r--r--synapse/rest/admin/_base.py15
-rw-r--r--synapse/rest/admin/media.py29
-rw-r--r--synapse/rest/admin/purge_room_servlet.py15
-rw-r--r--synapse/rest/admin/rooms.py5
-rw-r--r--synapse/rest/admin/server_notice_servlet.py23
-rw-r--r--synapse/rest/admin/users.py5
-rw-r--r--synapse/rest/client/v1/login.py53
-rw-r--r--synapse/rest/client/v1/room.py5
-rw-r--r--synapse/rest/client/v2_alpha/groups.py105
-rw-r--r--synapse/rest/media/v1/config_resource.py3
-rw-r--r--synapse/rest/media/v1/media_repository.py3
-rw-r--r--synapse/rest/media/v1/preview_url_resource.py3
-rw-r--r--synapse/rest/media/v1/thumbnailer.py11
-rw-r--r--synapse/rest/media/v1/upload_resource.py3
-rw-r--r--synapse/rest/synapse/client/saml2/response_resource.py10
-rw-r--r--synapse/server.py13
-rw-r--r--synapse/storage/databases/main/event_federation.py148
-rw-r--r--synapse/storage/databases/main/events_bg_updates.py79
-rw-r--r--synapse/storage/databases/main/events_worker.py12
-rw-r--r--synapse/storage/databases/main/purge_events.py8
-rw-r--r--synapse/storage/databases/main/registration.py6
-rw-r--r--synapse/storage/databases/main/schema/delta/59/10delete_purged_chain_cover.sql17
-rw-r--r--synapse/storage/databases/main/transactions.py10
-rw-r--r--synapse/types.py71
-rw-r--r--synapse/util/async_helpers.py26
-rw-r--r--synapse/util/caches/response_cache.py10
-rw-r--r--synapse/util/macaroons.py89
-rw-r--r--tests/federation/test_federation_catch_up.py3
-rw-r--r--tests/handlers/oidc_test_key.p85
-rw-r--r--tests/handlers/oidc_test_key.pub.pem4
-rw-r--r--tests/handlers/test_auth.py49
-rw-r--r--tests/handlers/test_cas.py10
-rw-r--r--tests/handlers/test_oidc.py217
-rw-r--r--tests/handlers/test_register.py31
-rw-r--r--tests/handlers/test_saml.py10
-rw-r--r--tests/http/test_client.py126
-rw-r--r--tests/replication/_base.py65
-rw-r--r--tests/replication/test_federation_ack.py8
-rw-r--r--tests/rest/client/v1/test_login.py43
-rw-r--r--tests/rest/media/v1/test_media_storage.py29
-rw-r--r--tests/server.py4
-rw-r--r--tests/storage/test_event_federation.py76
-rw-r--r--tests/storage/test_purge.py71
-rw-r--r--tests/test_utils/logging_setup.py2
-rw-r--r--tests/util/caches/test_responsecache.py131
-rw-r--r--tox.ini2
103 files changed, 2707 insertions, 934 deletions
diff --git a/.git-blame-ignore-revs b/.git-blame-ignore-revs
new file mode 100644
index 00000000..83ddd568
--- /dev/null
+++ b/.git-blame-ignore-revs
@@ -0,0 +1,8 @@
+# Black reformatting (#5482).
+32e7c9e7f20b57dd081023ac42d6931a8da9b3a3
+
+# Target Python 3.5 with black (#8664).
+aff1eb7c671b0a3813407321d2702ec46c71fa56
+
+# Update black to 20.8b1 (#9381).
+0a00b7ff14890987f09112a2ae696c61001e6cf1
diff --git a/CHANGES.md b/CHANGES.md
index 99e314c6..1bf9514a 100644
--- a/CHANGES.md
+++ b/CHANGES.md
@@ -1,3 +1,87 @@
+Synapse 1.30.0 (2021-03-22)
+===========================
+
+Note that this release deprecates the ability for appservices to
+call `POST /_matrix/client/r0/register` without the body parameter `type`. Appservice
+developers should use a `type` value of `m.login.application_service` as
+per [the spec](https://matrix.org/docs/spec/application_service/r0.1.2#server-admin-style-permissions).
+In future releases, calling this endpoint with an access token - but without a `m.login.application_service`
+type - will fail.
+
+
+No significant changes.
+
+
+Synapse 1.30.0rc1 (2021-03-16)
+==============================
+
+Features
+--------
+
+- Add prometheus metrics for number of users successfully registering and logging in. ([\#9510](https://github.com/matrix-org/synapse/issues/9510), [\#9511](https://github.com/matrix-org/synapse/issues/9511), [\#9573](https://github.com/matrix-org/synapse/issues/9573))
+- Add `synapse_federation_last_sent_pdu_time` and `synapse_federation_last_received_pdu_time` prometheus metrics, which monitor federation delays by reporting the timestamps of messages sent and received to a set of remote servers. ([\#9540](https://github.com/matrix-org/synapse/issues/9540))
+- Add support for generating JSON Web Tokens dynamically for use as OIDC client secrets. ([\#9549](https://github.com/matrix-org/synapse/issues/9549))
+- Optimise handling of incomplete room history for incoming federation. ([\#9601](https://github.com/matrix-org/synapse/issues/9601))
+- Finalise support for allowing clients to pick an SSO Identity Provider ([MSC2858](https://github.com/matrix-org/matrix-doc/pull/2858)). ([\#9617](https://github.com/matrix-org/synapse/issues/9617))
+- Tell spam checker modules about the SSO IdP a user registered through if one was used. ([\#9626](https://github.com/matrix-org/synapse/issues/9626))
+
+
+Bugfixes
+--------
+
+- Fix long-standing bug when generating thumbnails for some images with transparency: `TypeError: cannot unpack non-iterable int object`. ([\#9473](https://github.com/matrix-org/synapse/issues/9473))
+- Purge chain cover indexes for events that were purged prior to Synapse v1.29.0. ([\#9542](https://github.com/matrix-org/synapse/issues/9542), [\#9583](https://github.com/matrix-org/synapse/issues/9583))
+- Fix bug where federation requests were not correctly retried on 5xx responses. ([\#9567](https://github.com/matrix-org/synapse/issues/9567))
+- Fix re-activating an account via the admin API when local passwords are disabled. ([\#9587](https://github.com/matrix-org/synapse/issues/9587))
+- Fix a bug introduced in Synapse 1.20 which caused incoming federation transactions to stack up, causing slow recovery from outages. ([\#9597](https://github.com/matrix-org/synapse/issues/9597))
+- Fix a bug introduced in v1.28.0 where the OpenID Connect callback endpoint could error with a `MacaroonInitException`. ([\#9620](https://github.com/matrix-org/synapse/issues/9620))
+- Fix Internal Server Error on `GET /_synapse/client/saml2/authn_response` request. ([\#9623](https://github.com/matrix-org/synapse/issues/9623))
+
+
+Updates to the Docker image
+---------------------------
+
+- Make use of an improved malloc implementation (`jemalloc`) in the docker image. ([\#8553](https://github.com/matrix-org/synapse/issues/8553))
+
+
+Improved Documentation
+----------------------
+
+- Add relayd entry to reverse proxy example configurations. ([\#9508](https://github.com/matrix-org/synapse/issues/9508))
+- Improve the SAML2 upgrade notes for 1.27.0. ([\#9550](https://github.com/matrix-org/synapse/issues/9550))
+- Link to the "List user's media" admin API from the media admin API docs. ([\#9571](https://github.com/matrix-org/synapse/issues/9571))
+- Clarify the spam checker modules documentation example to mention that `parse_config` is a required method. ([\#9580](https://github.com/matrix-org/synapse/issues/9580))
+- Clarify the sample configuration for `stats` settings. ([\#9604](https://github.com/matrix-org/synapse/issues/9604))
+
+
+Deprecations and Removals
+-------------------------
+
+- The `synapse_federation_last_sent_pdu_age` and `synapse_federation_last_received_pdu_age` prometheus metrics have been removed. They are replaced by `synapse_federation_last_sent_pdu_time` and `synapse_federation_last_received_pdu_time`. ([\#9540](https://github.com/matrix-org/synapse/issues/9540))
+- Registering an Application Service user without using the `m.login.application_service` login type will be unsupported in an upcoming Synapse release. ([\#9559](https://github.com/matrix-org/synapse/issues/9559))
+
+
+Internal Changes
+----------------
+
+- Add tests to ResponseCache. ([\#9458](https://github.com/matrix-org/synapse/issues/9458))
+- Add type hints to purge room and server notice admin API. ([\#9520](https://github.com/matrix-org/synapse/issues/9520))
+- Add extra logging to ObservableDeferred when callbacks throw exceptions. ([\#9523](https://github.com/matrix-org/synapse/issues/9523))
+- Fix incorrect type hints. ([\#9528](https://github.com/matrix-org/synapse/issues/9528), [\#9543](https://github.com/matrix-org/synapse/issues/9543), [\#9591](https://github.com/matrix-org/synapse/issues/9591), [\#9608](https://github.com/matrix-org/synapse/issues/9608), [\#9618](https://github.com/matrix-org/synapse/issues/9618))
+- Add an additional test for purging a room. ([\#9541](https://github.com/matrix-org/synapse/issues/9541))
+- Add a `.git-blame-ignore-revs` file with the hashes of auto-formatting. ([\#9560](https://github.com/matrix-org/synapse/issues/9560))
+- Increase the threshold before which outbound federation to a server goes into "catch up" mode, which is expensive for the remote server to handle. ([\#9561](https://github.com/matrix-org/synapse/issues/9561))
+- Fix spurious errors reported by the `config-lint.sh` script. ([\#9562](https://github.com/matrix-org/synapse/issues/9562))
+- Fix type hints and tests for BlacklistingAgentWrapper and BlacklistingReactorWrapper. ([\#9563](https://github.com/matrix-org/synapse/issues/9563))
+- Do not have mypy ignore type hints from unpaddedbase64. ([\#9568](https://github.com/matrix-org/synapse/issues/9568))
+- Improve efficiency of calculating the auth chain in large rooms. ([\#9576](https://github.com/matrix-org/synapse/issues/9576))
+- Convert `synapse.types.Requester` to an `attrs` class. ([\#9586](https://github.com/matrix-org/synapse/issues/9586))
+- Add logging for redis connection setup. ([\#9590](https://github.com/matrix-org/synapse/issues/9590))
+- Improve logging when processing incoming transactions. ([\#9596](https://github.com/matrix-org/synapse/issues/9596))
+- Remove unused `stats.retention` setting, and emit a warning if stats are disabled. ([\#9604](https://github.com/matrix-org/synapse/issues/9604))
+- Prevent attempting to bundle aggregations for state events in /context APIs. ([\#9619](https://github.com/matrix-org/synapse/issues/9619))
+
+
Synapse 1.29.0 (2021-03-08)
===========================
diff --git a/MANIFEST.in b/MANIFEST.in
index 120ce5b7..25d1cb75 100644
--- a/MANIFEST.in
+++ b/MANIFEST.in
@@ -20,9 +20,10 @@ recursive-include scripts *
recursive-include scripts-dev *
recursive-include synapse *.pyi
recursive-include tests *.py
-include tests/http/ca.crt
-include tests/http/ca.key
-include tests/http/server.key
+recursive-include tests *.pem
+recursive-include tests *.p8
+recursive-include tests *.crt
+recursive-include tests *.key
recursive-include synapse/res *
recursive-include synapse/static *.css
diff --git a/README.rst b/README.rst
index d872b11f..6a1e7135 100644
--- a/README.rst
+++ b/README.rst
@@ -183,8 +183,9 @@ Using a reverse proxy with Synapse
It is recommended to put a reverse proxy such as
`nginx <https://nginx.org/en/docs/http/ngx_http_proxy_module.html>`_,
`Apache <https://httpd.apache.org/docs/current/mod/mod_proxy_http.html>`_,
-`Caddy <https://caddyserver.com/docs/quick-starts/reverse-proxy>`_ or
-`HAProxy <https://www.haproxy.org/>`_ in front of Synapse. One advantage of
+`Caddy <https://caddyserver.com/docs/quick-starts/reverse-proxy>`_,
+`HAProxy <https://www.haproxy.org/>`_ or
+`relayd <https://man.openbsd.org/relayd.8>`_ in front of Synapse. One advantage of
doing so is that it means that you can expose the default https port (443) to
Matrix clients without needing to run Synapse with root privileges.
diff --git a/UPGRADE.rst b/UPGRADE.rst
index 031e02bd..8bc2ff91 100644
--- a/UPGRADE.rst
+++ b/UPGRADE.rst
@@ -124,6 +124,13 @@ This version changes the URI used for callbacks from OAuth2 and SAML2 identity p
need to add ``[synapse public baseurl]/_synapse/client/saml2/authn_response`` as a permitted
"ACS location" (also known as "allowed callback URLs") at the identity provider.
+ The "Issuer" in the "AuthnRequest" to the SAML2 identity provider is also updated to
+ ``[synapse public baseurl]/_synapse/client/saml2/metadata.xml``. If your SAML2 identity
+ provider uses this property to validate or otherwise identify Synapse, its configuration
+ will need to be updated to use the new URL. Alternatively you could create a new, separate
+ "EntityDescriptor" in your SAML2 identity provider with the new URLs and leave the URLs in
+ the existing "EntityDescriptor" as they were.
+
Changes to HTML templates
-------------------------
diff --git a/debian/changelog b/debian/changelog
index 3feefd8a..e6b2122d 100644
--- a/debian/changelog
+++ b/debian/changelog
@@ -1,3 +1,9 @@
+matrix-synapse-py3 (1.30.0) stable; urgency=medium
+
+ * New synapse release 1.30.0.
+
+ -- Synapse Packaging team <packages@matrix.org> Mon, 22 Mar 2021 13:15:34 +0000
+
matrix-synapse-py3 (1.29.0) stable; urgency=medium
[ Jonathan de Jong ]
diff --git a/docker/Dockerfile b/docker/Dockerfile
index d619ee08..def45015 100644
--- a/docker/Dockerfile
+++ b/docker/Dockerfile
@@ -69,6 +69,7 @@ RUN apt-get update && apt-get install -y \
libpq5 \
libwebp6 \
xmlsec1 \
+ libjemalloc2 \
&& rm -rf /var/lib/apt/lists/*
COPY --from=builder /install /usr/local
diff --git a/docker/README.md b/docker/README.md
index 7b138df4..3a7dc585 100644
--- a/docker/README.md
+++ b/docker/README.md
@@ -204,3 +204,8 @@ healthcheck:
timeout: 10s
retries: 3
```
+
+## Using jemalloc
+
+Jemalloc is embedded in the image and will be used instead of the default allocator.
+You can read about jemalloc by reading the Synapse [README](../README.md) \ No newline at end of file
diff --git a/docker/start.py b/docker/start.py
index 0d2c590b..16d6a820 100755
--- a/docker/start.py
+++ b/docker/start.py
@@ -3,6 +3,7 @@
import codecs
import glob
import os
+import platform
import subprocess
import sys
@@ -213,6 +214,13 @@ def main(args, environ):
if "-m" not in args:
args = ["-m", synapse_worker] + args
+ jemallocpath = "/usr/lib/%s-linux-gnu/libjemalloc.so.2" % (platform.machine(),)
+
+ if os.path.isfile(jemallocpath):
+ environ["LD_PRELOAD"] = jemallocpath
+ else:
+ log("Could not find %s, will not use" % (jemallocpath,))
+
# if there are no config files passed to synapse, try adding the default file
if not any(p.startswith("--config-path") or p.startswith("-c") for p in args):
config_dir = environ.get("SYNAPSE_CONFIG_DIR", "/data")
@@ -248,9 +256,9 @@ running with 'migrate_config'. See the README for more details.
args = ["python"] + args
if ownership is not None:
args = ["gosu", ownership] + args
- os.execv("/usr/sbin/gosu", args)
+ os.execve("/usr/sbin/gosu", args, environ)
else:
- os.execv("/usr/local/bin/python", args)
+ os.execve("/usr/local/bin/python", args, environ)
if __name__ == "__main__":
diff --git a/docs/admin_api/media_admin_api.md b/docs/admin_api/media_admin_api.md
index 90faeaae..9dbec68c 100644
--- a/docs/admin_api/media_admin_api.md
+++ b/docs/admin_api/media_admin_api.md
@@ -1,5 +1,7 @@
# Contents
-- [List all media in a room](#list-all-media-in-a-room)
+- [Querying media](#querying-media)
+ * [List all media in a room](#list-all-media-in-a-room)
+ * [List all media uploaded by a user](#list-all-media-uploaded-by-a-user)
- [Quarantine media](#quarantine-media)
* [Quarantining media by ID](#quarantining-media-by-id)
* [Quarantining media in a room](#quarantining-media-in-a-room)
@@ -10,7 +12,11 @@
* [Delete local media by date or size](#delete-local-media-by-date-or-size)
- [Purge Remote Media API](#purge-remote-media-api)
-# List all media in a room
+# Querying media
+
+These APIs allow extracting media information from the homeserver.
+
+## List all media in a room
This API gets a list of known media in a room.
However, it only shows media from unencrypted events or rooms.
@@ -36,6 +42,12 @@ The API returns a JSON body like the following:
}
```
+## List all media uploaded by a user
+
+Listing all media that has been uploaded by a local user can be achieved through
+the use of the [List media of a user](user_admin_api.rst#list-media-of-a-user)
+Admin API.
+
# Quarantine media
Quarantining media means that it is marked as inaccessible by users. It applies
diff --git a/docs/openid.md b/docs/openid.md
index 263bc9f6..cfaafc50 100644
--- a/docs/openid.md
+++ b/docs/openid.md
@@ -226,7 +226,7 @@ Synapse config:
oidc_providers:
- idp_id: github
idp_name: Github
- idp_brand: "org.matrix.github" # optional: styling hint for clients
+ idp_brand: "github" # optional: styling hint for clients
discover: false
issuer: "https://github.com/"
client_id: "your-client-id" # TO BE FILLED
@@ -252,7 +252,7 @@ oidc_providers:
oidc_providers:
- idp_id: google
idp_name: Google
- idp_brand: "org.matrix.google" # optional: styling hint for clients
+ idp_brand: "google" # optional: styling hint for clients
issuer: "https://accounts.google.com/"
client_id: "your-client-id" # TO BE FILLED
client_secret: "your-client-secret" # TO BE FILLED
@@ -299,7 +299,7 @@ Synapse config:
oidc_providers:
- idp_id: gitlab
idp_name: Gitlab
- idp_brand: "org.matrix.gitlab" # optional: styling hint for clients
+ idp_brand: "gitlab" # optional: styling hint for clients
issuer: "https://gitlab.com/"
client_id: "your-client-id" # TO BE FILLED
client_secret: "your-client-secret" # TO BE FILLED
@@ -334,7 +334,7 @@ Synapse config:
```yaml
- idp_id: facebook
idp_name: Facebook
- idp_brand: "org.matrix.facebook" # optional: styling hint for clients
+ idp_brand: "facebook" # optional: styling hint for clients
discover: false
issuer: "https://facebook.com"
client_id: "your-client-id" # TO BE FILLED
@@ -386,7 +386,7 @@ oidc_providers:
config:
subject_claim: "id"
localpart_template: "{{ user.login }}"
- display_name_template: "{{ user.full_name }}"
+ display_name_template: "{{ user.full_name }}"
```
### XWiki
@@ -401,8 +401,7 @@ oidc_providers:
idp_name: "XWiki"
issuer: "https://myxwikihost/xwiki/oidc/"
client_id: "your-client-id" # TO BE FILLED
- # Needed until https://github.com/matrix-org/synapse/issues/9212 is fixed
- client_secret: "dontcare"
+ client_auth_method: none
scopes: ["openid", "profile"]
user_profile_method: "userinfo_endpoint"
user_mapping_provider:
@@ -410,3 +409,40 @@ oidc_providers:
localpart_template: "{{ user.preferred_username }}"
display_name_template: "{{ user.name }}"
```
+
+## Apple
+
+Configuring "Sign in with Apple" (SiWA) requires an Apple Developer account.
+
+You will need to create a new "Services ID" for SiWA, and create and download a
+private key with "SiWA" enabled.
+
+As well as the private key file, you will need:
+ * Client ID: the "identifier" you gave the "Services ID"
+ * Team ID: a 10-character ID associated with your developer account.
+ * Key ID: the 10-character identifier for the key.
+
+https://help.apple.com/developer-account/?lang=en#/dev77c875b7e has more
+documentation on setting up SiWA.
+
+The synapse config will look like this:
+
+```yaml
+ - idp_id: apple
+ idp_name: Apple
+ issuer: "https://appleid.apple.com"
+ client_id: "your-client-id" # Set to the "identifier" for your "ServicesID"
+ client_auth_method: "client_secret_post"
+ client_secret_jwt_key:
+ key_file: "/path/to/AuthKey_KEYIDCODE.p8" # point to your key file
+ jwt_header:
+ alg: ES256
+ kid: "KEYIDCODE" # Set to the 10-char Key ID
+ jwt_payload:
+ iss: TEAMIDCODE # Set to the 10-char Team ID
+ scopes: ["name", "email", "openid"]
+ authorization_endpoint: https://appleid.apple.com/auth/authorize?response_mode=form_post
+ user_mapping_provider:
+ config:
+ email_template: "{{ user.email }}"
+```
diff --git a/docs/reverse_proxy.md b/docs/reverse_proxy.md
index 81e5a68a..860afd5a 100644
--- a/docs/reverse_proxy.md
+++ b/docs/reverse_proxy.md
@@ -3,8 +3,9 @@
It is recommended to put a reverse proxy such as
[nginx](https://nginx.org/en/docs/http/ngx_http_proxy_module.html),
[Apache](https://httpd.apache.org/docs/current/mod/mod_proxy_http.html),
-[Caddy](https://caddyserver.com/docs/quick-starts/reverse-proxy) or
-[HAProxy](https://www.haproxy.org/) in front of Synapse. One advantage
+[Caddy](https://caddyserver.com/docs/quick-starts/reverse-proxy),
+[HAProxy](https://www.haproxy.org/) or
+[relayd](https://man.openbsd.org/relayd.8) in front of Synapse. One advantage
of doing so is that it means that you can expose the default https port
(443) to Matrix clients without needing to run Synapse with root
privileges.
@@ -162,6 +163,52 @@ backend matrix
server matrix 127.0.0.1:8008
```
+### Relayd
+
+```
+table <webserver> { 127.0.0.1 }
+table <matrixserver> { 127.0.0.1 }
+
+http protocol "https" {
+ tls { no tlsv1.0, ciphers "HIGH" }
+ tls keypair "example.com"
+ match header set "X-Forwarded-For" value "$REMOTE_ADDR"
+ match header set "X-Forwarded-Proto" value "https"
+
+ # set CORS header for .well-known/matrix/server, .well-known/matrix/client
+ # httpd does not support setting headers, so do it here
+ match request path "/.well-known/matrix/*" tag "matrix-cors"
+ match response tagged "matrix-cors" header set "Access-Control-Allow-Origin" value "*"
+
+ pass quick path "/_matrix/*" forward to <matrixserver>
+ pass quick path "/_synapse/client/*" forward to <matrixserver>
+
+ # pass on non-matrix traffic to webserver
+ pass forward to <webserver>
+}
+
+relay "https_traffic" {
+ listen on egress port 443 tls
+ protocol "https"
+ forward to <matrixserver> port 8008 check tcp
+ forward to <webserver> port 8080 check tcp
+}
+
+http protocol "matrix" {
+ tls { no tlsv1.0, ciphers "HIGH" }
+ tls keypair "example.com"
+ block
+ pass quick path "/_matrix/*" forward to <matrixserver>
+ pass quick path "/_synapse/client/*" forward to <matrixserver>
+}
+
+relay "matrix_federation" {
+ listen on egress port 8448 tls
+ protocol "matrix"
+ forward to <matrixserver> port 8008 check tcp
+}
+```
+
## Homeserver Configuration
You will also want to set `bind_addresses: ['127.0.0.1']` and
diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml
index 4dbef41b..7de000f4 100644
--- a/docs/sample_config.yaml
+++ b/docs/sample_config.yaml
@@ -89,8 +89,7 @@ pid_file: DATADIR/homeserver.pid
# Whether to require authentication to retrieve profile data (avatars,
# display names) of other users through the client API. Defaults to
# 'false'. Note that profile data is also available via the federation
-# API, so this setting is of limited value if federation is enabled on
-# the server.
+# API, unless allow_profile_lookup_over_federation is set to false.
#
#require_auth_for_profile_requests: true
@@ -1780,7 +1779,26 @@ saml2_config:
#
# client_id: Required. oauth2 client id to use.
#
-# client_secret: Required. oauth2 client secret to use.
+# client_secret: oauth2 client secret to use. May be omitted if
+# client_secret_jwt_key is given, or if client_auth_method is 'none'.
+#
+# client_secret_jwt_key: Alternative to client_secret: details of a key used
+# to create a JSON Web Token to be used as an OAuth2 client secret. If
+# given, must be a dictionary with the following properties:
+#
+# key: a pem-encoded signing key. Must be a suitable key for the
+# algorithm specified. Required unless 'key_file' is given.
+#
+# key_file: the path to file containing a pem-encoded signing key file.
+# Required unless 'key' is given.
+#
+# jwt_header: a dictionary giving properties to include in the JWT
+# header. Must include the key 'alg', giving the algorithm used to
+# sign the JWT, such as "ES256", using the JWA identifiers in
+# RFC7518.
+#
+# jwt_payload: an optional dictionary giving properties to include in
+# the JWT payload. Normally this should include an 'iss' key.
#
# client_auth_method: auth method to use when exchanging the token. Valid
# values are 'client_secret_basic' (default), 'client_secret_post' and
@@ -1901,7 +1919,7 @@ oidc_providers:
#
#- idp_id: github
# idp_name: Github
- # idp_brand: org.matrix.github
+ # idp_brand: github
# discover: false
# issuer: "https://github.com/"
# client_id: "your-client-id" # TO BE FILLED
@@ -2627,19 +2645,20 @@ user_directory:
-# Local statistics collection. Used in populating the room directory.
-#
-# 'bucket_size' controls how large each statistics timeslice is. It can
-# be defined in a human readable short form -- e.g. "1d", "1y".
+# Settings for local room and user statistics collection. See
+# docs/room_and_user_statistics.md.
#
-# 'retention' controls how long historical statistics will be kept for.
-# It can be defined in a human readable short form -- e.g. "1d", "1y".
-#
-#
-#stats:
-# enabled: true
-# bucket_size: 1d
-# retention: 1y
+stats:
+ # Uncomment the following to disable room and user statistics. Note that doing
+ # so may cause certain features (such as the room directory) not to work
+ # correctly.
+ #
+ #enabled: false
+
+ # The size of each timeslice in the room_stats_historical and
+ # user_stats_historical tables, as a time period. Defaults to "1d".
+ #
+ #bucket_size: 1h
# Server Notices room configuration
diff --git a/docs/spam_checker.md b/docs/spam_checker.md
index e615ac99..52947f60 100644
--- a/docs/spam_checker.md
+++ b/docs/spam_checker.md
@@ -14,6 +14,7 @@ The Python class is instantiated with two objects:
* An instance of `synapse.module_api.ModuleApi`.
It then implements methods which return a boolean to alter behavior in Synapse.
+All the methods must be defined.
There's a generic method for checking every event (`check_event_for_spam`), as
well as some specific methods:
@@ -24,6 +25,7 @@ well as some specific methods:
* `user_may_publish_room`
* `check_username_for_spam`
* `check_registration_for_spam`
+* `check_media_file_for_spam`
The details of each of these methods (as well as their inputs and outputs)
are documented in the `synapse.events.spamcheck.SpamChecker` class.
@@ -31,6 +33,10 @@ are documented in the `synapse.events.spamcheck.SpamChecker` class.
The `ModuleApi` class provides a way for the custom spam checker class to
call back into the homeserver internals.
+Additionally, a `parse_config` method is mandatory and receives the plugin config
+dictionary. After parsing, It must return an object which will be
+passed to `__init__` later.
+
### Example
```python
@@ -41,6 +47,10 @@ class ExampleSpamChecker:
self.config = config
self.api = api
+ @staticmethod
+ def parse_config(config):
+ return config
+
async def check_event_for_spam(self, foo):
return False # allow all events
@@ -59,7 +69,13 @@ class ExampleSpamChecker:
async def check_username_for_spam(self, user_profile):
return False # allow all usernames
- async def check_registration_for_spam(self, email_threepid, username, request_info):
+ async def check_registration_for_spam(
+ self,
+ email_threepid,
+ username,
+ request_info,
+ auth_provider_id,
+ ):
return RegistrationBehaviour.ALLOW # allow all registrations
async def check_media_file_for_spam(self, file_wrapper, file_info):
diff --git a/mypy.ini b/mypy.ini
index 64ed45da..e0685e09 100644
--- a/mypy.ini
+++ b/mypy.ini
@@ -69,6 +69,7 @@ files =
synapse/util/async_helpers.py,
synapse/util/caches,
synapse/util/metrics.py,
+ synapse/util/macaroons.py,
synapse/util/stringutils.py,
tests/replication,
tests/test_utils,
@@ -116,9 +117,6 @@ ignore_missing_imports = True
[mypy-saml2.*]
ignore_missing_imports = True
-[mypy-unpaddedbase64]
-ignore_missing_imports = True
-
[mypy-canonicaljson]
ignore_missing_imports = True
diff --git a/scripts-dev/config-lint.sh b/scripts-dev/config-lint.sh
index 189ca665..91321604 100755
--- a/scripts-dev/config-lint.sh
+++ b/scripts-dev/config-lint.sh
@@ -2,9 +2,14 @@
# Find linting errors in Synapse's default config file.
# Exits with 0 if there are no problems, or another code otherwise.
+# cd to the root of the repository
+cd `dirname $0`/..
+
+# Restore backup of sample config upon script exit
+trap "mv docs/sample_config.yaml.bak docs/sample_config.yaml" EXIT
+
# Fix non-lowercase true/false values
sed -i.bak -E "s/: +True/: true/g; s/: +False/: false/g;" docs/sample_config.yaml
-rm docs/sample_config.yaml.bak
# Check if anything changed
-git diff --exit-code docs/sample_config.yaml
+diff docs/sample_config.yaml docs/sample_config.yaml.bak
diff --git a/setup.cfg b/setup.cfg
index f46e43fa..5e301c2c 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -3,6 +3,7 @@ test_suite = tests
[check-manifest]
ignore =
+ .git-blame-ignore-revs
contrib
contrib/*
docs/*
diff --git a/stubs/txredisapi.pyi b/stubs/txredisapi.pyi
index 618548a3..080ca402 100644
--- a/stubs/txredisapi.pyi
+++ b/stubs/txredisapi.pyi
@@ -17,7 +17,9 @@
"""
from typing import Any, List, Optional, Type, Union
-class RedisProtocol:
+from twisted.internet import protocol
+
+class RedisProtocol(protocol.Protocol):
def publish(self, channel: str, message: bytes): ...
async def ping(self) -> None: ...
async def set(
@@ -52,7 +54,7 @@ def lazyConnection(
class ConnectionHandler: ...
-class RedisFactory:
+class RedisFactory(protocol.ReconnectingClientFactory):
continueTrying: bool
handler: RedisProtocol
pool: List[RedisProtocol]
diff --git a/synapse/__init__.py b/synapse/__init__.py
index 56ca8888..8e57739c 100644
--- a/synapse/__init__.py
+++ b/synapse/__init__.py
@@ -48,7 +48,7 @@ try:
except ImportError:
pass
-__version__ = "1.29.0"
+__version__ = "1.30.0"
if bool(os.environ.get("SYNAPSE_TEST_PATCH_LOG_CONTEXTS", False)):
# We import here so that we don't have to install a bunch of deps when
diff --git a/synapse/api/auth.py b/synapse/api/auth.py
index 89e62b0e..e10e33fd 100644
--- a/synapse/api/auth.py
+++ b/synapse/api/auth.py
@@ -39,6 +39,7 @@ from synapse.logging import opentracing as opentracing
from synapse.storage.databases.main.registration import TokenLookupResult
from synapse.types import StateMap, UserID
from synapse.util.caches.lrucache import LruCache
+from synapse.util.macaroons import get_value_from_macaroon, satisfy_expiry
from synapse.util.metrics import Measure
logger = logging.getLogger(__name__)
@@ -163,7 +164,7 @@ class Auth:
async def get_user_by_req(
self,
- request: Request,
+ request: SynapseRequest,
allow_guest: bool = False,
rights: str = "access",
allow_expired: bool = False,
@@ -408,7 +409,7 @@ class Auth:
raise _InvalidMacaroonException()
try:
- user_id = self.get_user_id_from_macaroon(macaroon)
+ user_id = get_value_from_macaroon(macaroon, "user_id")
guest = False
for caveat in macaroon.caveats:
@@ -416,7 +417,12 @@ class Auth:
guest = True
self.validate_macaroon(macaroon, rights, user_id=user_id)
- except (pymacaroons.exceptions.MacaroonException, TypeError, ValueError):
+ except (
+ pymacaroons.exceptions.MacaroonException,
+ KeyError,
+ TypeError,
+ ValueError,
+ ):
raise InvalidClientTokenError("Invalid macaroon passed.")
if rights == "access":
@@ -424,27 +430,6 @@ class Auth:
return user_id, guest
- def get_user_id_from_macaroon(self, macaroon):
- """Retrieve the user_id given by the caveats on the macaroon.
-
- Does *not* validate the macaroon.
-
- Args:
- macaroon (pymacaroons.Macaroon): The macaroon to validate
-
- Returns:
- (str) user id
-
- Raises:
- InvalidClientCredentialsError if there is no user_id caveat in the
- macaroon
- """
- user_prefix = "user_id = "
- for caveat in macaroon.caveats:
- if caveat.caveat_id.startswith(user_prefix):
- return caveat.caveat_id[len(user_prefix) :]
- raise InvalidClientTokenError("No user caveat in macaroon")
-
def validate_macaroon(self, macaroon, type_string, user_id):
"""
validate that a Macaroon is understood by and was signed by this server.
@@ -465,21 +450,13 @@ class Auth:
v.satisfy_exact("type = " + type_string)
v.satisfy_exact("user_id = %s" % user_id)
v.satisfy_exact("guest = true")
- v.satisfy_general(self._verify_expiry)
+ satisfy_expiry(v, self.clock.time_msec)
# access_tokens include a nonce for uniqueness: any value is acceptable
v.satisfy_general(lambda c: c.startswith("nonce = "))
v.verify(macaroon, self._macaroon_secret_key)
- def _verify_expiry(self, caveat):
- prefix = "time < "
- if not caveat.startswith(prefix):
- return False
- expiry = int(caveat[len(prefix) :])
- now = self.hs.get_clock().time_msec()
- return now < expiry
-
def get_appservice_by_req(self, request: SynapseRequest) -> ApplicationService:
token = self.get_access_token_from_request(request)
service = self.store.get_app_service_by_token(token)
diff --git a/synapse/appservice/api.py b/synapse/appservice/api.py
index 93c2aabc..9d3bbe3b 100644
--- a/synapse/appservice/api.py
+++ b/synapse/appservice/api.py
@@ -90,7 +90,7 @@ class ApplicationServiceApi(SimpleHttpClient):
self.clock = hs.get_clock()
self.protocol_meta_cache = ResponseCache(
- hs, "as_protocol_meta", timeout_ms=HOUR_IN_MS
+ hs.get_clock(), "as_protocol_meta", timeout_ms=HOUR_IN_MS
) # type: ResponseCache[Tuple[str, str]]
async def query_user(self, service, user_id):
diff --git a/synapse/config/_base.py b/synapse/config/_base.py
index 40269667..ba9cd63c 100644
--- a/synapse/config/_base.py
+++ b/synapse/config/_base.py
@@ -212,9 +212,8 @@ class Config:
@classmethod
def read_file(cls, file_path, config_name):
- cls.check_file(file_path, config_name)
- with open(file_path) as file_stream:
- return file_stream.read()
+ """Deprecated: call read_file directly"""
+ return read_file(file_path, (config_name,))
def read_template(self, filename: str) -> jinja2.Template:
"""Load a template file from disk.
@@ -894,4 +893,35 @@ class RoutableShardedWorkerHandlingConfig(ShardedWorkerHandlingConfig):
return self._get_instance(key)
-__all__ = ["Config", "RootConfig", "ShardedWorkerHandlingConfig"]
+def read_file(file_path: Any, config_path: Iterable[str]) -> str:
+ """Check the given file exists, and read it into a string
+
+ If it does not, emit an error indicating the problem
+
+ Args:
+ file_path: the file to be read
+ config_path: where in the configuration file_path came from, so that a useful
+ error can be emitted if it does not exist.
+ Returns:
+ content of the file.
+ Raises:
+ ConfigError if there is a problem reading the file.
+ """
+ if not isinstance(file_path, str):
+ raise ConfigError("%r is not a string", config_path)
+
+ try:
+ os.stat(file_path)
+ with open(file_path) as file_stream:
+ return file_stream.read()
+ except OSError as e:
+ raise ConfigError("Error accessing file %r" % (file_path,), config_path) from e
+
+
+__all__ = [
+ "Config",
+ "RootConfig",
+ "ShardedWorkerHandlingConfig",
+ "RoutableShardedWorkerHandlingConfig",
+ "read_file",
+]
diff --git a/synapse/config/_base.pyi b/synapse/config/_base.pyi
index db16c86f..e896fd34 100644
--- a/synapse/config/_base.pyi
+++ b/synapse/config/_base.pyi
@@ -152,3 +152,5 @@ class ShardedWorkerHandlingConfig:
class RoutableShardedWorkerHandlingConfig(ShardedWorkerHandlingConfig):
def get_instance(self, key: str) -> str: ...
+
+def read_file(file_path: Any, config_path: Iterable[str]) -> str: ...
diff --git a/synapse/config/logger.py b/synapse/config/logger.py
index e56cf846..999aecce 100644
--- a/synapse/config/logger.py
+++ b/synapse/config/logger.py
@@ -21,8 +21,10 @@ import threading
from string import Template
import yaml
+from zope.interface import implementer
from twisted.logger import (
+ ILogObserver,
LogBeginner,
STDLibLogObserver,
eventAsText,
@@ -227,7 +229,8 @@ def _setup_stdlib_logging(config, log_config_path, logBeginner: LogBeginner) ->
threadlocal = threading.local()
- def _log(event):
+ @implementer(ILogObserver)
+ def _log(event: dict) -> None:
if "log_text" in event:
if event["log_text"].startswith("DNSDatagramProtocol starting on "):
return
diff --git a/synapse/config/oidc_config.py b/synapse/config/oidc_config.py
index a27594be..2bfb537c 100644
--- a/synapse/config/oidc_config.py
+++ b/synapse/config/oidc_config.py
@@ -15,7 +15,7 @@
# limitations under the License.
from collections import Counter
-from typing import Iterable, Optional, Tuple, Type
+from typing import Iterable, Mapping, Optional, Tuple, Type
import attr
@@ -25,7 +25,7 @@ from synapse.types import Collection, JsonDict
from synapse.util.module_loader import load_module
from synapse.util.stringutils import parse_and_validate_mxc_uri
-from ._base import Config, ConfigError
+from ._base import Config, ConfigError, read_file
DEFAULT_USER_MAPPING_PROVIDER = "synapse.handlers.oidc_handler.JinjaOidcMappingProvider"
@@ -97,7 +97,26 @@ class OIDCConfig(Config):
#
# client_id: Required. oauth2 client id to use.
#
- # client_secret: Required. oauth2 client secret to use.
+ # client_secret: oauth2 client secret to use. May be omitted if
+ # client_secret_jwt_key is given, or if client_auth_method is 'none'.
+ #
+ # client_secret_jwt_key: Alternative to client_secret: details of a key used
+ # to create a JSON Web Token to be used as an OAuth2 client secret. If
+ # given, must be a dictionary with the following properties:
+ #
+ # key: a pem-encoded signing key. Must be a suitable key for the
+ # algorithm specified. Required unless 'key_file' is given.
+ #
+ # key_file: the path to file containing a pem-encoded signing key file.
+ # Required unless 'key' is given.
+ #
+ # jwt_header: a dictionary giving properties to include in the JWT
+ # header. Must include the key 'alg', giving the algorithm used to
+ # sign the JWT, such as "ES256", using the JWA identifiers in
+ # RFC7518.
+ #
+ # jwt_payload: an optional dictionary giving properties to include in
+ # the JWT payload. Normally this should include an 'iss' key.
#
# client_auth_method: auth method to use when exchanging the token. Valid
# values are 'client_secret_basic' (default), 'client_secret_post' and
@@ -218,7 +237,7 @@ class OIDCConfig(Config):
#
#- idp_id: github
# idp_name: Github
- # idp_brand: org.matrix.github
+ # idp_brand: github
# discover: false
# issuer: "https://github.com/"
# client_id: "your-client-id" # TO BE FILLED
@@ -240,7 +259,7 @@ class OIDCConfig(Config):
# jsonschema definition of the configuration settings for an oidc identity provider
OIDC_PROVIDER_CONFIG_SCHEMA = {
"type": "object",
- "required": ["issuer", "client_id", "client_secret"],
+ "required": ["issuer", "client_id"],
"properties": {
"idp_id": {
"type": "string",
@@ -253,7 +272,12 @@ OIDC_PROVIDER_CONFIG_SCHEMA = {
"idp_icon": {"type": "string"},
"idp_brand": {
"type": "string",
- # MSC2758-style namespaced identifier
+ "minLength": 1,
+ "maxLength": 255,
+ "pattern": "^[a-z][a-z0-9_.-]*$",
+ },
+ "idp_unstable_brand": {
+ "type": "string",
"minLength": 1,
"maxLength": 255,
"pattern": "^[a-z][a-z0-9_.-]*$",
@@ -262,6 +286,30 @@ OIDC_PROVIDER_CONFIG_SCHEMA = {
"issuer": {"type": "string"},
"client_id": {"type": "string"},
"client_secret": {"type": "string"},
+ "client_secret_jwt_key": {
+ "type": "object",
+ "required": ["jwt_header"],
+ "oneOf": [
+ {"required": ["key"]},
+ {"required": ["key_file"]},
+ ],
+ "properties": {
+ "key": {"type": "string"},
+ "key_file": {"type": "string"},
+ "jwt_header": {
+ "type": "object",
+ "required": ["alg"],
+ "properties": {
+ "alg": {"type": "string"},
+ },
+ "additionalProperties": {"type": "string"},
+ },
+ "jwt_payload": {
+ "type": "object",
+ "additionalProperties": {"type": "string"},
+ },
+ },
+ },
"client_auth_method": {
"type": "string",
# the following list is the same as the keys of
@@ -404,15 +452,31 @@ def _parse_oidc_config_dict(
"idp_icon must be a valid MXC URI", config_path + ("idp_icon",)
) from e
+ client_secret_jwt_key_config = oidc_config.get("client_secret_jwt_key")
+ client_secret_jwt_key = None # type: Optional[OidcProviderClientSecretJwtKey]
+ if client_secret_jwt_key_config is not None:
+ keyfile = client_secret_jwt_key_config.get("key_file")
+ if keyfile:
+ key = read_file(keyfile, config_path + ("client_secret_jwt_key",))
+ else:
+ key = client_secret_jwt_key_config["key"]
+ client_secret_jwt_key = OidcProviderClientSecretJwtKey(
+ key=key,
+ jwt_header=client_secret_jwt_key_config["jwt_header"],
+ jwt_payload=client_secret_jwt_key_config.get("jwt_payload", {}),
+ )
+
return OidcProviderConfig(
idp_id=idp_id,
idp_name=oidc_config.get("idp_name", "OIDC"),
idp_icon=idp_icon,
idp_brand=oidc_config.get("idp_brand"),
+ unstable_idp_brand=oidc_config.get("unstable_idp_brand"),
discover=oidc_config.get("discover", True),
issuer=oidc_config["issuer"],
client_id=oidc_config["client_id"],
- client_secret=oidc_config["client_secret"],
+ client_secret=oidc_config.get("client_secret"),
+ client_secret_jwt_key=client_secret_jwt_key,
client_auth_method=oidc_config.get("client_auth_method", "client_secret_basic"),
scopes=oidc_config.get("scopes", ["openid"]),
authorization_endpoint=oidc_config.get("authorization_endpoint"),
@@ -428,6 +492,18 @@ def _parse_oidc_config_dict(
@attr.s(slots=True, frozen=True)
+class OidcProviderClientSecretJwtKey:
+ # a pem-encoded signing key
+ key = attr.ib(type=str)
+
+ # properties to include in the JWT header
+ jwt_header = attr.ib(type=Mapping[str, str])
+
+ # properties to include in the JWT payload.
+ jwt_payload = attr.ib(type=Mapping[str, str])
+
+
+@attr.s(slots=True, frozen=True)
class OidcProviderConfig:
# a unique identifier for this identity provider. Used in the 'user_external_ids'
# table, as well as the query/path parameter used in the login protocol.
@@ -442,6 +518,9 @@ class OidcProviderConfig:
# Optional brand identifier for this IdP.
idp_brand = attr.ib(type=Optional[str])
+ # Optional brand identifier for the unstable API (see MSC2858).
+ unstable_idp_brand = attr.ib(type=Optional[str])
+
# whether the OIDC discovery mechanism is used to discover endpoints
discover = attr.ib(type=bool)
@@ -452,8 +531,13 @@ class OidcProviderConfig:
# oauth2 client id to use
client_id = attr.ib(type=str)
- # oauth2 client secret to use
- client_secret = attr.ib(type=str)
+ # oauth2 client secret to use. if `None`, use client_secret_jwt_key to generate
+ # a secret.
+ client_secret = attr.ib(type=Optional[str])
+
+ # key to use to construct a JWT to use as a client secret. May be `None` if
+ # `client_secret` is set.
+ client_secret_jwt_key = attr.ib(type=Optional[OidcProviderClientSecretJwtKey])
# auth method to use when exchanging the token.
# Valid values are 'client_secret_basic', 'client_secret_post' and
diff --git a/synapse/config/server.py b/synapse/config/server.py
index 2afca36e..5f8910b6 100644
--- a/synapse/config/server.py
+++ b/synapse/config/server.py
@@ -841,8 +841,7 @@ class ServerConfig(Config):
# Whether to require authentication to retrieve profile data (avatars,
# display names) of other users through the client API. Defaults to
# 'false'. Note that profile data is also available via the federation
- # API, so this setting is of limited value if federation is enabled on
- # the server.
+ # API, unless allow_profile_lookup_over_federation is set to false.
#
#require_auth_for_profile_requests: true
diff --git a/synapse/config/stats.py b/synapse/config/stats.py
index b559bfa4..2258329a 100644
--- a/synapse/config/stats.py
+++ b/synapse/config/stats.py
@@ -13,10 +13,22 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import sys
+import logging
from ._base import Config
+ROOM_STATS_DISABLED_WARN = """\
+WARNING: room/user statistics have been disabled via the stats.enabled
+configuration setting. This means that certain features (such as the room
+directory) will not operate correctly. Future versions of Synapse may ignore
+this setting.
+
+To fix this warning, remove the stats.enabled setting from your configuration
+file.
+--------------------------------------------------------------------------------"""
+
+logger = logging.getLogger(__name__)
+
class StatsConfig(Config):
"""Stats Configuration
@@ -28,30 +40,29 @@ class StatsConfig(Config):
def read_config(self, config, **kwargs):
self.stats_enabled = True
self.stats_bucket_size = 86400 * 1000
- self.stats_retention = sys.maxsize
stats_config = config.get("stats", None)
if stats_config:
self.stats_enabled = stats_config.get("enabled", self.stats_enabled)
self.stats_bucket_size = self.parse_duration(
stats_config.get("bucket_size", "1d")
)
- self.stats_retention = self.parse_duration(
- stats_config.get("retention", "%ds" % (sys.maxsize,))
- )
+ if not self.stats_enabled:
+ logger.warning(ROOM_STATS_DISABLED_WARN)
def generate_config_section(self, config_dir_path, server_name, **kwargs):
return """
- # Local statistics collection. Used in populating the room directory.
+ # Settings for local room and user statistics collection. See
+ # docs/room_and_user_statistics.md.
#
- # 'bucket_size' controls how large each statistics timeslice is. It can
- # be defined in a human readable short form -- e.g. "1d", "1y".
- #
- # 'retention' controls how long historical statistics will be kept for.
- # It can be defined in a human readable short form -- e.g. "1d", "1y".
- #
- #
- #stats:
- # enabled: true
- # bucket_size: 1d
- # retention: 1y
+ stats:
+ # Uncomment the following to disable room and user statistics. Note that doing
+ # so may cause certain features (such as the room directory) not to work
+ # correctly.
+ #
+ #enabled: false
+
+ # The size of each timeslice in the room_stats_historical and
+ # user_stats_historical tables, as a time period. Defaults to "1d".
+ #
+ #bucket_size: 1h
"""
diff --git a/synapse/events/spamcheck.py b/synapse/events/spamcheck.py
index 8cfc0bb3..a9185987 100644
--- a/synapse/events/spamcheck.py
+++ b/synapse/events/spamcheck.py
@@ -15,6 +15,7 @@
# limitations under the License.
import inspect
+import logging
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
from synapse.rest.media.v1._base import FileInfo
@@ -27,6 +28,8 @@ if TYPE_CHECKING:
import synapse.events
import synapse.server
+logger = logging.getLogger(__name__)
+
class SpamChecker:
def __init__(self, hs: "synapse.server.HomeServer"):
@@ -190,6 +193,7 @@ class SpamChecker:
email_threepid: Optional[dict],
username: Optional[str],
request_info: Collection[Tuple[str, str]],
+ auth_provider_id: Optional[str] = None,
) -> RegistrationBehaviour:
"""Checks if we should allow the given registration request.
@@ -198,6 +202,9 @@ class SpamChecker:
username: The request user name, if any
request_info: List of tuples of user agent and IP that
were used during the registration process.
+ auth_provider_id: The SSO IdP the user used, e.g "oidc", "saml",
+ "cas". If any. Note this does not include users registered
+ via a password provider.
Returns:
Enum for how the request should be handled
@@ -208,9 +215,25 @@ class SpamChecker:
# spam checker
checker = getattr(spam_checker, "check_registration_for_spam", None)
if checker:
- behaviour = await maybe_awaitable(
- checker(email_threepid, username, request_info)
- )
+ # Provide auth_provider_id if the function supports it
+ checker_args = inspect.signature(checker)
+ if len(checker_args.parameters) == 4:
+ d = checker(
+ email_threepid,
+ username,
+ request_info,
+ auth_provider_id,
+ )
+ elif len(checker_args.parameters) == 3:
+ d = checker(email_threepid, username, request_info)
+ else:
+ logger.error(
+ "Invalid signature for %s.check_registration_for_spam. Denying registration",
+ spam_checker.__module__,
+ )
+ return RegistrationBehaviour.DENY
+
+ behaviour = await maybe_awaitable(d)
assert isinstance(behaviour, RegistrationBehaviour)
if behaviour != RegistrationBehaviour.ALLOW:
return behaviour
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index 2f832b47..9839d3d0 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -22,6 +22,7 @@ from typing import (
Awaitable,
Callable,
Dict,
+ Iterable,
List,
Optional,
Tuple,
@@ -90,16 +91,15 @@ pdu_process_time = Histogram(
"Time taken to process an event",
)
-
-last_pdu_age_metric = Gauge(
- "synapse_federation_last_received_pdu_age",
- "The age (in seconds) of the last PDU successfully received from the given domain",
+last_pdu_ts_metric = Gauge(
+ "synapse_federation_last_received_pdu_time",
+ "The timestamp of the last PDU which was successfully received from the given domain",
labelnames=("server_name",),
)
class FederationServer(FederationBase):
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__(hs)
self.auth = hs.get_auth()
@@ -112,14 +112,15 @@ class FederationServer(FederationBase):
# with FederationHandlerRegistry.
hs.get_directory_handler()
- self._federation_ratelimiter = hs.get_federation_ratelimiter()
-
self._server_linearizer = Linearizer("fed_server")
- self._transaction_linearizer = Linearizer("fed_txn_handler")
+
+ # origins that we are currently processing a transaction from.
+ # a dict from origin to txn id.
+ self._active_transactions = {} # type: Dict[str, str]
# We cache results for transaction with the same ID
self._transaction_resp_cache = ResponseCache(
- hs, "fed_txn_handler", timeout_ms=30000
+ hs.get_clock(), "fed_txn_handler", timeout_ms=30000
) # type: ResponseCache[Tuple[str, str]]
self.transaction_actions = TransactionActions(self.store)
@@ -129,10 +130,10 @@ class FederationServer(FederationBase):
# We cache responses to state queries, as they take a while and often
# come in waves.
self._state_resp_cache = ResponseCache(
- hs, "state_resp", timeout_ms=30000
+ hs.get_clock(), "state_resp", timeout_ms=30000
) # type: ResponseCache[Tuple[str, str]]
self._state_ids_resp_cache = ResponseCache(
- hs, "state_ids_resp", timeout_ms=30000
+ hs.get_clock(), "state_ids_resp", timeout_ms=30000
) # type: ResponseCache[Tuple[str, str]]
self._federation_metrics_domains = (
@@ -169,6 +170,33 @@ class FederationServer(FederationBase):
logger.debug("[%s] Got transaction", transaction_id)
+ # Reject malformed transactions early: reject if too many PDUs/EDUs
+ if len(transaction.pdus) > 50 or ( # type: ignore
+ hasattr(transaction, "edus") and len(transaction.edus) > 100 # type: ignore
+ ):
+ logger.info("Transaction PDU or EDU count too large. Returning 400")
+ return 400, {}
+
+ # we only process one transaction from each origin at a time. We need to do
+ # this check here, rather than in _on_incoming_transaction_inner so that we
+ # don't cache the rejection in _transaction_resp_cache (so that if the txn
+ # arrives again later, we can process it).
+ current_transaction = self._active_transactions.get(origin)
+ if current_transaction and current_transaction != transaction_id:
+ logger.warning(
+ "Received another txn %s from %s while still processing %s",
+ transaction_id,
+ origin,
+ current_transaction,
+ )
+ return 429, {
+ "errcode": Codes.UNKNOWN,
+ "error": "Too many concurrent transactions",
+ }
+
+ # CRITICAL SECTION: we must now not await until we populate _active_transactions
+ # in _on_incoming_transaction_inner.
+
# We wrap in a ResponseCache so that we de-duplicate retried
# transactions.
return await self._transaction_resp_cache.wrap(
@@ -182,26 +210,18 @@ class FederationServer(FederationBase):
async def _on_incoming_transaction_inner(
self, origin: str, transaction: Transaction, request_time: int
) -> Tuple[int, Dict[str, Any]]:
- # Use a linearizer to ensure that transactions from a remote are
- # processed in order.
- with await self._transaction_linearizer.queue(origin):
- # We rate limit here *after* we've queued up the incoming requests,
- # so that we don't fill up the ratelimiter with blocked requests.
- #
- # This is important as the ratelimiter allows N concurrent requests
- # at a time, and only starts ratelimiting if there are more requests
- # than that being processed at a time. If we queued up requests in
- # the linearizer/response cache *after* the ratelimiting then those
- # queued up requests would count as part of the allowed limit of N
- # concurrent requests.
- with self._federation_ratelimiter.ratelimit(origin) as d:
- await d
-
- result = await self._handle_incoming_transaction(
- origin, transaction, request_time
- )
+ # CRITICAL SECTION: the first thing we must do (before awaiting) is
+ # add an entry to _active_transactions.
+ assert origin not in self._active_transactions
+ self._active_transactions[origin] = transaction.transaction_id # type: ignore
- return result
+ try:
+ result = await self._handle_incoming_transaction(
+ origin, transaction, request_time
+ )
+ return result
+ finally:
+ del self._active_transactions[origin]
async def _handle_incoming_transaction(
self, origin: str, transaction: Transaction, request_time: int
@@ -227,19 +247,6 @@ class FederationServer(FederationBase):
logger.debug("[%s] Transaction is new", transaction.transaction_id) # type: ignore
- # Reject if PDU count > 50 or EDU count > 100
- if len(transaction.pdus) > 50 or ( # type: ignore
- hasattr(transaction, "edus") and len(transaction.edus) > 100 # type: ignore
- ):
-
- logger.info("Transaction PDU or EDU count too large. Returning 400")
-
- response = {}
- await self.transaction_actions.set_response(
- origin, transaction, 400, response
- )
- return 400, response
-
# We process PDUs and EDUs in parallel. This is important as we don't
# want to block things like to device messages from reaching clients
# behind the potentially expensive handling of PDUs.
@@ -335,42 +342,48 @@ class FederationServer(FederationBase):
# impose a limit to avoid going too crazy with ram/cpu.
async def process_pdus_for_room(room_id: str):
- logger.debug("Processing PDUs for %s", room_id)
- try:
- await self.check_server_matches_acl(origin_host, room_id)
- except AuthError as e:
- logger.warning("Ignoring PDUs for room %s from banned server", room_id)
- for pdu in pdus_by_room[room_id]:
- event_id = pdu.event_id
- pdu_results[event_id] = e.error_dict()
- return
+ with nested_logging_context(room_id):
+ logger.debug("Processing PDUs for %s", room_id)
- for pdu in pdus_by_room[room_id]:
- event_id = pdu.event_id
- with pdu_process_time.time():
- with nested_logging_context(event_id):
- try:
- await self._handle_received_pdu(origin, pdu)
- pdu_results[event_id] = {}
- except FederationError as e:
- logger.warning("Error handling PDU %s: %s", event_id, e)
- pdu_results[event_id] = {"error": str(e)}
- except Exception as e:
- f = failure.Failure()
- pdu_results[event_id] = {"error": str(e)}
- logger.error(
- "Failed to handle PDU %s",
- event_id,
- exc_info=(f.type, f.value, f.getTracebackObject()),
- )
+ try:
+ await self.check_server_matches_acl(origin_host, room_id)
+ except AuthError as e:
+ logger.warning(
+ "Ignoring PDUs for room %s from banned server", room_id
+ )
+ for pdu in pdus_by_room[room_id]:
+ event_id = pdu.event_id
+ pdu_results[event_id] = e.error_dict()
+ return
+
+ for pdu in pdus_by_room[room_id]:
+ pdu_results[pdu.event_id] = await process_pdu(pdu)
+
+ async def process_pdu(pdu: EventBase) -> JsonDict:
+ event_id = pdu.event_id
+ with pdu_process_time.time():
+ with nested_logging_context(event_id):
+ try:
+ await self._handle_received_pdu(origin, pdu)
+ return {}
+ except FederationError as e:
+ logger.warning("Error handling PDU %s: %s", event_id, e)
+ return {"error": str(e)}
+ except Exception as e:
+ f = failure.Failure()
+ logger.error(
+ "Failed to handle PDU %s",
+ event_id,
+ exc_info=(f.type, f.value, f.getTracebackObject()), # type: ignore
+ )
+ return {"error": str(e)}
await concurrently_execute(
process_pdus_for_room, pdus_by_room.keys(), TRANSACTION_CONCURRENCY_LIMIT
)
if newest_pdu_ts and origin in self._federation_metrics_domains:
- newest_pdu_age = self._clock.time_msec() - newest_pdu_ts
- last_pdu_age_metric.labels(server_name=origin).set(newest_pdu_age / 1000)
+ last_pdu_ts_metric.labels(server_name=origin).set(newest_pdu_ts / 1000)
return pdu_results
@@ -448,18 +461,22 @@ class FederationServer(FederationBase):
async def _on_state_ids_request_compute(self, room_id, event_id):
state_ids = await self.handler.get_state_ids_for_pdu(room_id, event_id)
- auth_chain_ids = await self.store.get_auth_chain_ids(state_ids)
+ auth_chain_ids = await self.store.get_auth_chain_ids(room_id, state_ids)
return {"pdu_ids": state_ids, "auth_chain_ids": auth_chain_ids}
async def _on_context_state_request_compute(
self, room_id: str, event_id: str
) -> Dict[str, list]:
if event_id:
- pdus = await self.handler.get_state_for_pdu(room_id, event_id)
+ pdus = await self.handler.get_state_for_pdu(
+ room_id, event_id
+ ) # type: Iterable[EventBase]
else:
pdus = (await self.state.get_current_state(room_id)).values()
- auth_chain = await self.store.get_auth_chain([pdu.event_id for pdu in pdus])
+ auth_chain = await self.store.get_auth_chain(
+ room_id, [pdu.event_id for pdu in pdus]
+ )
return {
"pdus": [pdu.get_pdu_json() for pdu in pdus],
@@ -863,7 +880,9 @@ class FederationHandlerRegistry:
self.edu_handlers = (
{}
) # type: Dict[str, Callable[[str, dict], Awaitable[None]]]
- self.query_handlers = {} # type: Dict[str, Callable[[dict], Awaitable[None]]]
+ self.query_handlers = (
+ {}
+ ) # type: Dict[str, Callable[[dict], Awaitable[JsonDict]]]
# Map from type to instance names that we should route EDU handling to.
# We randomly choose one instance from the list to route to for each new
@@ -897,7 +916,7 @@ class FederationHandlerRegistry:
self.edu_handlers[edu_type] = handler
def register_query_handler(
- self, query_type: str, handler: Callable[[dict], defer.Deferred]
+ self, query_type: str, handler: Callable[[dict], Awaitable[JsonDict]]
):
"""Sets the handler callable that will be used to handle an incoming
federation query of the given type.
@@ -970,7 +989,7 @@ class FederationHandlerRegistry:
# Oh well, let's just log and move on.
logger.warning("No handler registered for EDU type %s", edu_type)
- async def on_query(self, query_type: str, args: dict):
+ async def on_query(self, query_type: str, args: dict) -> JsonDict:
handler = self.query_handlers.get(query_type)
if handler:
return await handler(args)
diff --git a/synapse/federation/sender/per_destination_queue.py b/synapse/federation/sender/per_destination_queue.py
index deb519f3..cc0d765e 100644
--- a/synapse/federation/sender/per_destination_queue.py
+++ b/synapse/federation/sender/per_destination_queue.py
@@ -17,6 +17,7 @@ import datetime
import logging
from typing import TYPE_CHECKING, Dict, Hashable, Iterable, List, Optional, Tuple, cast
+import attr
from prometheus_client import Counter
from synapse.api.errors import (
@@ -93,6 +94,10 @@ class PerDestinationQueue:
self._destination = destination
self.transmission_loop_running = False
+ # Flag to signal to any running transmission loop that there is new data
+ # queued up to be sent.
+ self._new_data_to_send = False
+
# True whilst we are sending events that the remote homeserver missed
# because it was unreachable. We start in this state so we can perform
# catch-up at startup.
@@ -108,7 +113,7 @@ class PerDestinationQueue:
# destination (we are the only updater so this is safe)
self._last_successful_stream_ordering = None # type: Optional[int]
- # a list of pending PDUs
+ # a queue of pending PDUs
self._pending_pdus = [] # type: List[EventBase]
# XXX this is never actually used: see
@@ -208,6 +213,10 @@ class PerDestinationQueue:
transaction in the background.
"""
+ # Mark that we (may) have new things to send, so that any running
+ # transmission loop will recheck whether there is stuff to send.
+ self._new_data_to_send = True
+
if self.transmission_loop_running:
# XXX: this can get stuck on by a never-ending
# request at which point pending_pdus just keeps growing.
@@ -250,125 +259,41 @@ class PerDestinationQueue:
pending_pdus = []
while True:
- # We have to keep 2 free slots for presence and rr_edus
- limit = MAX_EDUS_PER_TRANSACTION - 2
-
- device_update_edus, dev_list_id = await self._get_device_update_edus(
- limit
- )
-
- limit -= len(device_update_edus)
-
- (
- to_device_edus,
- device_stream_id,
- ) = await self._get_to_device_message_edus(limit)
-
- pending_edus = device_update_edus + to_device_edus
-
- # BEGIN CRITICAL SECTION
- #
- # In order to avoid a race condition, we need to make sure that
- # the following code (from popping the queues up to the point
- # where we decide if we actually have any pending messages) is
- # atomic - otherwise new PDUs or EDUs might arrive in the
- # meantime, but not get sent because we hold the
- # transmission_loop_running flag.
-
- pending_pdus = self._pending_pdus
+ self._new_data_to_send = False
- # We can only include at most 50 PDUs per transactions
- pending_pdus, self._pending_pdus = pending_pdus[:50], pending_pdus[50:]
+ async with _TransactionQueueManager(self) as (
+ pending_pdus,
+ pending_edus,
+ ):
+ if not pending_pdus and not pending_edus:
+ logger.debug("TX [%s] Nothing to send", self._destination)
+
+ # If we've gotten told about new things to send during
+ # checking for things to send, we try looking again.
+ # Otherwise new PDUs or EDUs might arrive in the meantime,
+ # but not get sent because we hold the
+ # `transmission_loop_running` flag.
+ if self._new_data_to_send:
+ continue
+ else:
+ return
- pending_edus.extend(self._get_rr_edus(force_flush=False))
- pending_presence = self._pending_presence
- self._pending_presence = {}
- if pending_presence:
- pending_edus.append(
- Edu(
- origin=self._server_name,
- destination=self._destination,
- edu_type="m.presence",
- content={
- "push": [
- format_user_presence_state(
- presence, self._clock.time_msec()
- )
- for presence in pending_presence.values()
- ]
- },
+ if pending_pdus:
+ logger.debug(
+ "TX [%s] len(pending_pdus_by_dest[dest]) = %d",
+ self._destination,
+ len(pending_pdus),
)
- )
- pending_edus.extend(
- self._pop_pending_edus(MAX_EDUS_PER_TRANSACTION - len(pending_edus))
- )
- while (
- len(pending_edus) < MAX_EDUS_PER_TRANSACTION
- and self._pending_edus_keyed
- ):
- _, val = self._pending_edus_keyed.popitem()
- pending_edus.append(val)
-
- if pending_pdus:
- logger.debug(
- "TX [%s] len(pending_pdus_by_dest[dest]) = %d",
- self._destination,
- len(pending_pdus),
+ await self._transaction_manager.send_new_transaction(
+ self._destination, pending_pdus, pending_edus
)
- if not pending_pdus and not pending_edus:
- logger.debug("TX [%s] Nothing to send", self._destination)
- self._last_device_stream_id = device_stream_id
- return
-
- # if we've decided to send a transaction anyway, and we have room, we
- # may as well send any pending RRs
- if len(pending_edus) < MAX_EDUS_PER_TRANSACTION:
- pending_edus.extend(self._get_rr_edus(force_flush=True))
-
- # END CRITICAL SECTION
-
- success = await self._transaction_manager.send_new_transaction(
- self._destination, pending_pdus, pending_edus
- )
- if success:
sent_transactions_counter.inc()
sent_edus_counter.inc(len(pending_edus))
for edu in pending_edus:
sent_edus_by_type.labels(edu.edu_type).inc()
- # Remove the acknowledged device messages from the database
- # Only bother if we actually sent some device messages
- if to_device_edus:
- await self._store.delete_device_msgs_for_remote(
- self._destination, device_stream_id
- )
- # also mark the device updates as sent
- if device_update_edus:
- logger.info(
- "Marking as sent %r %r", self._destination, dev_list_id
- )
- await self._store.mark_as_sent_devices_by_remote(
- self._destination, dev_list_id
- )
-
- self._last_device_stream_id = device_stream_id
- self._last_device_list_stream_id = dev_list_id
-
- if pending_pdus:
- # we sent some PDUs and it was successful, so update our
- # last_successful_stream_ordering in the destinations table.
- final_pdu = pending_pdus[-1]
- last_successful_stream_ordering = (
- final_pdu.internal_metadata.stream_ordering
- )
- assert last_successful_stream_ordering
- await self._store.set_destination_last_successful_stream_ordering(
- self._destination, last_successful_stream_ordering
- )
- else:
- break
except NotRetryingDestination as e:
logger.debug(
"TX [%s] not ready for retry yet (next retry at %s) - "
@@ -401,7 +326,7 @@ class PerDestinationQueue:
self._pending_presence = {}
self._pending_rrs = {}
- self._start_catching_up()
+ self._start_catching_up()
except FederationDeniedError as e:
logger.info(e)
except HttpResponseException as e:
@@ -412,7 +337,6 @@ class PerDestinationQueue:
e,
)
- self._start_catching_up()
except RequestSendFailed as e:
logger.warning(
"TX [%s] Failed to send transaction: %s", self._destination, e
@@ -422,16 +346,12 @@ class PerDestinationQueue:
logger.info(
"Failed to send event %s to %s", p.event_id, self._destination
)
-
- self._start_catching_up()
except Exception:
logger.exception("TX [%s] Failed to send transaction", self._destination)
for p in pending_pdus:
logger.info(
"Failed to send event %s to %s", p.event_id, self._destination
)
-
- self._start_catching_up()
finally:
# We want to be *very* sure we clear this after we stop processing
self.transmission_loop_running = False
@@ -499,13 +419,10 @@ class PerDestinationQueue:
rooms = [p.room_id for p in catchup_pdus]
logger.info("Catching up rooms to %s: %r", self._destination, rooms)
- success = await self._transaction_manager.send_new_transaction(
+ await self._transaction_manager.send_new_transaction(
self._destination, catchup_pdus, []
)
- if not success:
- return
-
sent_transactions_counter.inc()
final_pdu = catchup_pdus[-1]
self._last_successful_stream_ordering = cast(
@@ -584,3 +501,135 @@ class PerDestinationQueue:
"""
self._catching_up = True
self._pending_pdus = []
+
+
+@attr.s(slots=True)
+class _TransactionQueueManager:
+ """A helper async context manager for pulling stuff off the queues and
+ tracking what was last successfully sent, etc.
+ """
+
+ queue = attr.ib(type=PerDestinationQueue)
+
+ _device_stream_id = attr.ib(type=Optional[int], default=None)
+ _device_list_id = attr.ib(type=Optional[int], default=None)
+ _last_stream_ordering = attr.ib(type=Optional[int], default=None)
+ _pdus = attr.ib(type=List[EventBase], factory=list)
+
+ async def __aenter__(self) -> Tuple[List[EventBase], List[Edu]]:
+ # First we calculate the EDUs we want to send, if any.
+
+ # We start by fetching device related EDUs, i.e device updates and to
+ # device messages. We have to keep 2 free slots for presence and rr_edus.
+ limit = MAX_EDUS_PER_TRANSACTION - 2
+
+ device_update_edus, dev_list_id = await self.queue._get_device_update_edus(
+ limit
+ )
+
+ if device_update_edus:
+ self._device_list_id = dev_list_id
+ else:
+ self.queue._last_device_list_stream_id = dev_list_id
+
+ limit -= len(device_update_edus)
+
+ (
+ to_device_edus,
+ device_stream_id,
+ ) = await self.queue._get_to_device_message_edus(limit)
+
+ if to_device_edus:
+ self._device_stream_id = device_stream_id
+ else:
+ self.queue._last_device_stream_id = device_stream_id
+
+ pending_edus = device_update_edus + to_device_edus
+
+ # Now add the read receipt EDU.
+ pending_edus.extend(self.queue._get_rr_edus(force_flush=False))
+
+ # And presence EDU.
+ if self.queue._pending_presence:
+ pending_edus.append(
+ Edu(
+ origin=self.queue._server_name,
+ destination=self.queue._destination,
+ edu_type="m.presence",
+ content={
+ "push": [
+ format_user_presence_state(
+ presence, self.queue._clock.time_msec()
+ )
+ for presence in self.queue._pending_presence.values()
+ ]
+ },
+ )
+ )
+ self.queue._pending_presence = {}
+
+ # Finally add any other types of EDUs if there is room.
+ pending_edus.extend(
+ self.queue._pop_pending_edus(MAX_EDUS_PER_TRANSACTION - len(pending_edus))
+ )
+ while (
+ len(pending_edus) < MAX_EDUS_PER_TRANSACTION
+ and self.queue._pending_edus_keyed
+ ):
+ _, val = self.queue._pending_edus_keyed.popitem()
+ pending_edus.append(val)
+
+ # Now we look for any PDUs to send, by getting up to 50 PDUs from the
+ # queue
+ self._pdus = self.queue._pending_pdus[:50]
+
+ if not self._pdus and not pending_edus:
+ return [], []
+
+ # if we've decided to send a transaction anyway, and we have room, we
+ # may as well send any pending RRs
+ if len(pending_edus) < MAX_EDUS_PER_TRANSACTION:
+ pending_edus.extend(self.queue._get_rr_edus(force_flush=True))
+
+ if self._pdus:
+ self._last_stream_ordering = self._pdus[
+ -1
+ ].internal_metadata.stream_ordering
+ assert self._last_stream_ordering
+
+ return self._pdus, pending_edus
+
+ async def __aexit__(self, exc_type, exc, tb):
+ if exc_type is not None:
+ # Failed to send transaction, so we bail out.
+ return
+
+ # Successfully sent transactions, so we remove pending PDUs from the queue
+ if self._pdus:
+ self.queue._pending_pdus = self.queue._pending_pdus[len(self._pdus) :]
+
+ # Succeeded to send the transaction so we record where we have sent up
+ # to in the various streams
+
+ if self._device_stream_id:
+ await self.queue._store.delete_device_msgs_for_remote(
+ self.queue._destination, self._device_stream_id
+ )
+ self.queue._last_device_stream_id = self._device_stream_id
+
+ # also mark the device updates as sent
+ if self._device_list_id:
+ logger.info(
+ "Marking as sent %r %r", self.queue._destination, self._device_list_id
+ )
+ await self.queue._store.mark_as_sent_devices_by_remote(
+ self.queue._destination, self._device_list_id
+ )
+ self.queue._last_device_list_stream_id = self._device_list_id
+
+ if self._last_stream_ordering:
+ # we sent some PDUs and it was successful, so update our
+ # last_successful_stream_ordering in the destinations table.
+ await self.queue._store.set_destination_last_successful_stream_ordering(
+ self.queue._destination, self._last_stream_ordering
+ )
diff --git a/synapse/federation/sender/transaction_manager.py b/synapse/federation/sender/transaction_manager.py
index 763aff29..07b740c2 100644
--- a/synapse/federation/sender/transaction_manager.py
+++ b/synapse/federation/sender/transaction_manager.py
@@ -36,9 +36,9 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
-last_pdu_age_metric = Gauge(
- "synapse_federation_last_sent_pdu_age",
- "The age (in seconds) of the last PDU successfully sent to the given domain",
+last_pdu_ts_metric = Gauge(
+ "synapse_federation_last_sent_pdu_time",
+ "The timestamp of the last PDU which was successfully sent to the given domain",
labelnames=("server_name",),
)
@@ -69,15 +69,12 @@ class TransactionManager:
destination: str,
pdus: List[EventBase],
edus: List[Edu],
- ) -> bool:
+ ) -> None:
"""
Args:
destination: The destination to send to (e.g. 'example.org')
pdus: In-order list of PDUs to send
edus: List of EDUs to send
-
- Returns:
- True iff the transaction was successful
"""
# Make a transaction-sending opentracing span. This span follows on from
@@ -96,8 +93,6 @@ class TransactionManager:
edu.strip_context()
with start_active_span_follows_from("send_transaction", span_contexts):
- success = True
-
logger.debug("TX [%s] _attempt_new_transaction", destination)
txn_id = str(self._next_txn_id)
@@ -152,45 +147,29 @@ class TransactionManager:
response = await self._transport_layer.send_transaction(
transaction, json_data_cb
)
- code = 200
except HttpResponseException as e:
code = e.code
response = e.response
- if e.code in (401, 404, 429) or 500 <= e.code:
- logger.info(
- "TX [%s] {%s} got %d response", destination, txn_id, code
- )
- raise e
-
- logger.info("TX [%s] {%s} got %d response", destination, txn_id, code)
-
- if code == 200:
- for e_id, r in response.get("pdus", {}).items():
- if "error" in r:
- logger.warning(
- "TX [%s] {%s} Remote returned error for %s: %s",
- destination,
- txn_id,
- e_id,
- r,
- )
- else:
- for p in pdus:
+ set_tag(tags.ERROR, True)
+
+ logger.info("TX [%s] {%s} got %d response", destination, txn_id, code)
+ raise
+
+ logger.info("TX [%s] {%s} got 200 response", destination, txn_id)
+
+ for e_id, r in response.get("pdus", {}).items():
+ if "error" in r:
logger.warning(
- "TX [%s] {%s} Failed to send event %s",
+ "TX [%s] {%s} Remote returned error for %s: %s",
destination,
txn_id,
- p.event_id,
+ e_id,
+ r,
)
- success = False
- if success and pdus and destination in self._federation_metrics_domains:
+ if pdus and destination in self._federation_metrics_domains:
last_pdu = pdus[-1]
- last_pdu_age = self.clock.time_msec() - last_pdu.origin_server_ts
- last_pdu_age_metric.labels(server_name=destination).set(
- last_pdu_age / 1000
+ last_pdu_ts_metric.labels(server_name=destination).set(
+ last_pdu.origin_server_ts / 1000
)
-
- set_tag(tags.ERROR, not success)
- return success
diff --git a/synapse/handlers/acme.py b/synapse/handlers/acme.py
index 5ecb2da1..132be238 100644
--- a/synapse/handlers/acme.py
+++ b/synapse/handlers/acme.py
@@ -73,7 +73,9 @@ class AcmeHandler:
"Listening for ACME requests on %s:%i", host, self.hs.config.acme_port
)
try:
- self.reactor.listenTCP(self.hs.config.acme_port, srv, interface=host)
+ self.reactor.listenTCP(
+ self.hs.config.acme_port, srv, backlog=50, interface=host
+ )
except twisted.internet.error.CannotListenError as e:
check_bind_error(e, host, bind_addresses)
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index 3978e415..fb5f8118 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -65,6 +65,7 @@ from synapse.storage.roommember import ProfileInfo
from synapse.types import JsonDict, Requester, UserID
from synapse.util import stringutils as stringutils
from synapse.util.async_helpers import maybe_awaitable
+from synapse.util.macaroons import get_value_from_macaroon, satisfy_expiry
from synapse.util.msisdn import phone_number_to_msisdn
from synapse.util.threepids import canonicalise_email
@@ -170,6 +171,16 @@ class SsoLoginExtraAttributes:
extra_attributes = attr.ib(type=JsonDict)
+@attr.s(slots=True, frozen=True)
+class LoginTokenAttributes:
+ """Data we store in a short-term login token"""
+
+ user_id = attr.ib(type=str)
+
+ # the SSO Identity Provider that the user authenticated with, to get this token
+ auth_provider_id = attr.ib(type=str)
+
+
class AuthHandler(BaseHandler):
SESSION_EXPIRE_MS = 48 * 60 * 60 * 1000
@@ -326,7 +337,8 @@ class AuthHandler(BaseHandler):
user is too high to proceed
"""
-
+ if not requester.access_token_id:
+ raise ValueError("Cannot validate a user without an access token")
if self._ui_auth_session_timeout:
last_validated = await self.store.get_access_token_last_validated(
requester.access_token_id
@@ -1164,18 +1176,16 @@ class AuthHandler(BaseHandler):
return None
return user_id
- async def validate_short_term_login_token_and_get_user_id(self, login_token: str):
- auth_api = self.hs.get_auth()
- user_id = None
+ async def validate_short_term_login_token(
+ self, login_token: str
+ ) -> LoginTokenAttributes:
try:
- macaroon = pymacaroons.Macaroon.deserialize(login_token)
- user_id = auth_api.get_user_id_from_macaroon(macaroon)
- auth_api.validate_macaroon(macaroon, "login", user_id)
+ res = self.macaroon_gen.verify_short_term_login_token(login_token)
except Exception:
raise AuthError(403, "Invalid token", errcode=Codes.FORBIDDEN)
- await self.auth.check_auth_blocking(user_id)
- return user_id
+ await self.auth.check_auth_blocking(res.user_id)
+ return res
async def delete_access_token(self, access_token: str):
"""Invalidate a single access token
@@ -1204,7 +1214,7 @@ class AuthHandler(BaseHandler):
async def delete_access_tokens_for_user(
self,
user_id: str,
- except_token_id: Optional[str] = None,
+ except_token_id: Optional[int] = None,
device_id: Optional[str] = None,
):
"""Invalidate access tokens belonging to a user
@@ -1397,6 +1407,7 @@ class AuthHandler(BaseHandler):
async def complete_sso_login(
self,
registered_user_id: str,
+ auth_provider_id: str,
request: Request,
client_redirect_url: str,
extra_attributes: Optional[JsonDict] = None,
@@ -1406,6 +1417,9 @@ class AuthHandler(BaseHandler):
Args:
registered_user_id: The registered user ID to complete SSO login for.
+ auth_provider_id: The id of the SSO Identity provider that was used for
+ login. This will be stored in the login token for future tracking in
+ prometheus metrics.
request: The request to complete.
client_redirect_url: The URL to which to redirect the user at the end of the
process.
@@ -1427,6 +1441,7 @@ class AuthHandler(BaseHandler):
self._complete_sso_login(
registered_user_id,
+ auth_provider_id,
request,
client_redirect_url,
extra_attributes,
@@ -1437,6 +1452,7 @@ class AuthHandler(BaseHandler):
def _complete_sso_login(
self,
registered_user_id: str,
+ auth_provider_id: str,
request: Request,
client_redirect_url: str,
extra_attributes: Optional[JsonDict] = None,
@@ -1463,7 +1479,7 @@ class AuthHandler(BaseHandler):
# Create a login token
login_token = self.macaroon_gen.generate_short_term_login_token(
- registered_user_id
+ registered_user_id, auth_provider_id=auth_provider_id
)
# Append the login token to the original redirect URL (i.e. with its query
@@ -1569,15 +1585,48 @@ class MacaroonGenerator:
return macaroon.serialize()
def generate_short_term_login_token(
- self, user_id: str, duration_in_ms: int = (2 * 60 * 1000)
+ self,
+ user_id: str,
+ auth_provider_id: str,
+ duration_in_ms: int = (2 * 60 * 1000),
) -> str:
macaroon = self._generate_base_macaroon(user_id)
macaroon.add_first_party_caveat("type = login")
now = self.hs.get_clock().time_msec()
expiry = now + duration_in_ms
macaroon.add_first_party_caveat("time < %d" % (expiry,))
+ macaroon.add_first_party_caveat("auth_provider_id = %s" % (auth_provider_id,))
return macaroon.serialize()
+ def verify_short_term_login_token(self, token: str) -> LoginTokenAttributes:
+ """Verify a short-term-login macaroon
+
+ Checks that the given token is a valid, unexpired short-term-login token
+ minted by this server.
+
+ Args:
+ token: the login token to verify
+
+ Returns:
+ the user_id that this token is valid for
+
+ Raises:
+ MacaroonVerificationFailedException if the verification failed
+ """
+ macaroon = pymacaroons.Macaroon.deserialize(token)
+ user_id = get_value_from_macaroon(macaroon, "user_id")
+ auth_provider_id = get_value_from_macaroon(macaroon, "auth_provider_id")
+
+ v = pymacaroons.Verifier()
+ v.satisfy_exact("gen = 1")
+ v.satisfy_exact("type = login")
+ v.satisfy_general(lambda c: c.startswith("user_id = "))
+ v.satisfy_general(lambda c: c.startswith("auth_provider_id = "))
+ satisfy_expiry(v, self.hs.get_clock().time_msec)
+ v.verify(macaroon, self.hs.config.key.macaroon_secret_key)
+
+ return LoginTokenAttributes(user_id=user_id, auth_provider_id=auth_provider_id)
+
def generate_delete_pusher_token(self, user_id: str) -> str:
macaroon = self._generate_base_macaroon(user_id)
macaroon.add_first_party_caveat("type = delete_pusher")
diff --git a/synapse/handlers/cas_handler.py b/synapse/handlers/cas_handler.py
index 04972f9c..cb67589f 100644
--- a/synapse/handlers/cas_handler.py
+++ b/synapse/handlers/cas_handler.py
@@ -83,6 +83,7 @@ class CasHandler:
# the SsoIdentityProvider protocol type.
self.idp_icon = None
self.idp_brand = None
+ self.unstable_idp_brand = None
self._sso_handler = hs.get_sso_handler()
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 2ead626a..598a66f7 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -201,7 +201,7 @@ class FederationHandler(BaseHandler):
or pdu.internal_metadata.is_outlier()
)
if already_seen:
- logger.debug("[%s %s]: Already seen pdu", room_id, event_id)
+ logger.debug("Already seen pdu")
return
# do some initial sanity-checking of the event. In particular, make
@@ -210,18 +210,14 @@ class FederationHandler(BaseHandler):
try:
self._sanity_check_event(pdu)
except SynapseError as err:
- logger.warning(
- "[%s %s] Received event failed sanity checks", room_id, event_id
- )
+ logger.warning("Received event failed sanity checks")
raise FederationError("ERROR", err.code, err.msg, affected=pdu.event_id)
# If we are currently in the process of joining this room, then we
# queue up events for later processing.
if room_id in self.room_queues:
logger.info(
- "[%s %s] Queuing PDU from %s for now: join in progress",
- room_id,
- event_id,
+ "Queuing PDU from %s for now: join in progress",
origin,
)
self.room_queues[room_id].append((pdu, origin))
@@ -236,9 +232,7 @@ class FederationHandler(BaseHandler):
is_in_room = await self.auth.check_host_in_room(room_id, self.server_name)
if not is_in_room:
logger.info(
- "[%s %s] Ignoring PDU from %s as we're not in the room",
- room_id,
- event_id,
+ "Ignoring PDU from %s as we're not in the room",
origin,
)
return None
@@ -250,7 +244,7 @@ class FederationHandler(BaseHandler):
# We only backfill backwards to the min depth.
min_depth = await self.get_min_depth_for_context(pdu.room_id)
- logger.debug("[%s %s] min_depth: %d", room_id, event_id, min_depth)
+ logger.debug("min_depth: %d", min_depth)
prevs = set(pdu.prev_event_ids())
seen = await self.store.have_events_in_timeline(prevs)
@@ -267,17 +261,13 @@ class FederationHandler(BaseHandler):
# If we're missing stuff, ensure we only fetch stuff one
# at a time.
logger.info(
- "[%s %s] Acquiring room lock to fetch %d missing prev_events: %s",
- room_id,
- event_id,
+ "Acquiring room lock to fetch %d missing prev_events: %s",
len(missing_prevs),
shortstr(missing_prevs),
)
with (await self._room_pdu_linearizer.queue(pdu.room_id)):
logger.info(
- "[%s %s] Acquired room lock to fetch %d missing prev_events",
- room_id,
- event_id,
+ "Acquired room lock to fetch %d missing prev_events",
len(missing_prevs),
)
@@ -297,9 +287,7 @@ class FederationHandler(BaseHandler):
if not prevs - seen:
logger.info(
- "[%s %s] Found all missing prev_events",
- room_id,
- event_id,
+ "Found all missing prev_events",
)
if prevs - seen:
@@ -329,9 +317,7 @@ class FederationHandler(BaseHandler):
if sent_to_us_directly:
logger.warning(
- "[%s %s] Rejecting: failed to fetch %d prev events: %s",
- room_id,
- event_id,
+ "Rejecting: failed to fetch %d prev events: %s",
len(prevs - seen),
shortstr(prevs - seen),
)
@@ -367,17 +353,16 @@ class FederationHandler(BaseHandler):
# Ask the remote server for the states we don't
# know about
for p in prevs - seen:
- logger.info(
- "Requesting state at missing prev_event %s",
- event_id,
- )
+ logger.info("Requesting state after missing prev_event %s", p)
with nested_logging_context(p):
# note that if any of the missing prevs share missing state or
# auth events, the requests to fetch those events are deduped
# by the get_pdu_cache in federation_client.
- (remote_state, _,) = await self._get_state_for_room(
- origin, room_id, p, include_event_in_state=True
+ remote_state = (
+ await self._get_state_after_missing_prev_event(
+ origin, room_id, p
+ )
)
remote_state_map = {
@@ -414,10 +399,7 @@ class FederationHandler(BaseHandler):
state = [event_map[e] for e in state_map.values()]
except Exception:
logger.warning(
- "[%s %s] Error attempting to resolve state at missing "
- "prev_events",
- room_id,
- event_id,
+ "Error attempting to resolve state at missing " "prev_events",
exc_info=True,
)
raise FederationError(
@@ -454,9 +436,7 @@ class FederationHandler(BaseHandler):
latest |= seen
logger.info(
- "[%s %s]: Requesting missing events between %s and %s",
- room_id,
- event_id,
+ "Requesting missing events between %s and %s",
shortstr(latest),
event_id,
)
@@ -523,15 +503,11 @@ class FederationHandler(BaseHandler):
# We failed to get the missing events, but since we need to handle
# the case of `get_missing_events` not returning the necessary
# events anyway, it is safe to simply log the error and continue.
- logger.warning(
- "[%s %s]: Failed to get prev_events: %s", room_id, event_id, e
- )
+ logger.warning("Failed to get prev_events: %s", e)
return
logger.info(
- "[%s %s]: Got %d prev_events: %s",
- room_id,
- event_id,
+ "Got %d prev_events: %s",
len(missing_events),
shortstr(missing_events),
)
@@ -542,9 +518,7 @@ class FederationHandler(BaseHandler):
for ev in missing_events:
logger.info(
- "[%s %s] Handling received prev_event %s",
- room_id,
- event_id,
+ "Handling received prev_event %s",
ev.event_id,
)
with nested_logging_context(ev.event_id):
@@ -553,9 +527,7 @@ class FederationHandler(BaseHandler):
except FederationError as e:
if e.code == 403:
logger.warning(
- "[%s %s] Received prev_event %s failed history check.",
- room_id,
- event_id,
+ "Received prev_event %s failed history check.",
ev.event_id,
)
else:
@@ -566,7 +538,6 @@ class FederationHandler(BaseHandler):
destination: str,
room_id: str,
event_id: str,
- include_event_in_state: bool = False,
) -> Tuple[List[EventBase], List[EventBase]]:
"""Requests all of the room state at a given event from a remote homeserver.
@@ -574,11 +545,9 @@ class FederationHandler(BaseHandler):
destination: The remote homeserver to query for the state.
room_id: The id of the room we're interested in.
event_id: The id of the event we want the state at.
- include_event_in_state: if true, the event itself will be included in the
- returned state event list.
Returns:
- A list of events in the state, possibly including the event itself, and
+ A list of events in the state, not including the event itself, and
a list of events in the auth chain for the given event.
"""
(
@@ -590,9 +559,6 @@ class FederationHandler(BaseHandler):
desired_events = set(state_event_ids + auth_event_ids)
- if include_event_in_state:
- desired_events.add(event_id)
-
event_map = await self._get_events_from_store_or_dest(
destination, room_id, desired_events
)
@@ -609,13 +575,6 @@ class FederationHandler(BaseHandler):
event_map[e_id] for e_id in state_event_ids if e_id in event_map
]
- if include_event_in_state:
- remote_event = event_map.get(event_id)
- if not remote_event:
- raise Exception("Unable to get missing prev_event %s" % (event_id,))
- if remote_event.is_state() and remote_event.rejected_reason is None:
- remote_state.append(remote_event)
-
auth_chain = [event_map[e_id] for e_id in auth_event_ids if e_id in event_map]
auth_chain.sort(key=lambda e: e.depth)
@@ -689,6 +648,131 @@ class FederationHandler(BaseHandler):
return fetched_events
+ async def _get_state_after_missing_prev_event(
+ self,
+ destination: str,
+ room_id: str,
+ event_id: str,
+ ) -> List[EventBase]:
+ """Requests all of the room state at a given event from a remote homeserver.
+
+ Args:
+ destination: The remote homeserver to query for the state.
+ room_id: The id of the room we're interested in.
+ event_id: The id of the event we want the state at.
+
+ Returns:
+ A list of events in the state, including the event itself
+ """
+ # TODO: This function is basically the same as _get_state_for_room. Can
+ # we make backfill() use it, rather than having two code paths? I think the
+ # only difference is that backfill() persists the prev events separately.
+
+ (
+ state_event_ids,
+ auth_event_ids,
+ ) = await self.federation_client.get_room_state_ids(
+ destination, room_id, event_id=event_id
+ )
+
+ logger.debug(
+ "state_ids returned %i state events, %i auth events",
+ len(state_event_ids),
+ len(auth_event_ids),
+ )
+
+ # start by just trying to fetch the events from the store
+ desired_events = set(state_event_ids)
+ desired_events.add(event_id)
+ logger.debug("Fetching %i events from cache/store", len(desired_events))
+ fetched_events = await self.store.get_events(
+ desired_events, allow_rejected=True
+ )
+
+ missing_desired_events = desired_events - fetched_events.keys()
+ logger.debug(
+ "We are missing %i events (got %i)",
+ len(missing_desired_events),
+ len(fetched_events),
+ )
+
+ # We probably won't need most of the auth events, so let's just check which
+ # we have for now, rather than thrashing the event cache with them all
+ # unnecessarily.
+
+ # TODO: we probably won't actually need all of the auth events, since we
+ # already have a bunch of the state events. It would be nice if the
+ # federation api gave us a way of finding out which we actually need.
+
+ missing_auth_events = set(auth_event_ids) - fetched_events.keys()
+ missing_auth_events.difference_update(
+ await self.store.have_seen_events(missing_auth_events)
+ )
+ logger.debug("We are also missing %i auth events", len(missing_auth_events))
+
+ missing_events = missing_desired_events | missing_auth_events
+ logger.debug("Fetching %i events from remote", len(missing_events))
+ await self._get_events_and_persist(
+ destination=destination, room_id=room_id, events=missing_events
+ )
+
+ # we need to make sure we re-load from the database to get the rejected
+ # state correct.
+ fetched_events.update(
+ (await self.store.get_events(missing_desired_events, allow_rejected=True))
+ )
+
+ # check for events which were in the wrong room.
+ #
+ # this can happen if a remote server claims that the state or
+ # auth_events at an event in room A are actually events in room B
+
+ bad_events = [
+ (event_id, event.room_id)
+ for event_id, event in fetched_events.items()
+ if event.room_id != room_id
+ ]
+
+ for bad_event_id, bad_room_id in bad_events:
+ # This is a bogus situation, but since we may only discover it a long time
+ # after it happened, we try our best to carry on, by just omitting the
+ # bad events from the returned state set.
+ logger.warning(
+ "Remote server %s claims event %s in room %s is an auth/state "
+ "event in room %s",
+ destination,
+ bad_event_id,
+ bad_room_id,
+ room_id,
+ )
+
+ del fetched_events[bad_event_id]
+
+ # if we couldn't get the prev event in question, that's a problem.
+ remote_event = fetched_events.get(event_id)
+ if not remote_event:
+ raise Exception("Unable to get missing prev_event %s" % (event_id,))
+
+ # missing state at that event is a warning, not a blocker
+ # XXX: this doesn't sound right? it means that we'll end up with incomplete
+ # state.
+ failed_to_fetch = desired_events - fetched_events.keys()
+ if failed_to_fetch:
+ logger.warning(
+ "Failed to fetch missing state events for %s %s",
+ event_id,
+ failed_to_fetch,
+ )
+
+ remote_state = [
+ fetched_events[e_id] for e_id in state_event_ids if e_id in fetched_events
+ ]
+
+ if remote_event.is_state() and remote_event.rejected_reason is None:
+ remote_state.append(remote_event)
+
+ return remote_state
+
async def _process_received_pdu(
self,
origin: str,
@@ -707,10 +791,7 @@ class FederationHandler(BaseHandler):
(ie, we are missing one or more prev_events), the resolved state at the
event
"""
- room_id = event.room_id
- event_id = event.event_id
-
- logger.debug("[%s %s] Processing event: %s", room_id, event_id, event)
+ logger.debug("Processing event: %s", event)
try:
await self._handle_new_event(origin, event, state=state)
@@ -871,7 +952,6 @@ class FederationHandler(BaseHandler):
destination=dest,
room_id=room_id,
event_id=e_id,
- include_event_in_state=False,
)
auth_events.update({a.event_id: a for a in auth})
auth_events.update({s.event_id: s for s in state})
@@ -1317,7 +1397,7 @@ class FederationHandler(BaseHandler):
async def on_event_auth(self, event_id: str) -> List[EventBase]:
event = await self.store.get_event(event_id)
auth = await self.store.get_auth_chain(
- list(event.auth_event_ids()), include_given=True
+ event.room_id, list(event.auth_event_ids()), include_given=True
)
return list(auth)
@@ -1580,7 +1660,7 @@ class FederationHandler(BaseHandler):
prev_state_ids = await context.get_prev_state_ids()
state_ids = list(prev_state_ids.values())
- auth_chain = await self.store.get_auth_chain(state_ids)
+ auth_chain = await self.store.get_auth_chain(event.room_id, state_ids)
state = await self.store.get_events(list(prev_state_ids.values()))
@@ -2219,7 +2299,7 @@ class FederationHandler(BaseHandler):
# Now get the current auth_chain for the event.
local_auth_chain = await self.store.get_auth_chain(
- list(event.auth_event_ids()), include_given=True
+ room_id, list(event.auth_event_ids()), include_given=True
)
# TODO: Check if we would now reject event_id. If so we need to tell
diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py
index 71a50766..13f81522 100644
--- a/synapse/handlers/initial_sync.py
+++ b/synapse/handlers/initial_sync.py
@@ -48,7 +48,7 @@ class InitialSyncHandler(BaseHandler):
self.clock = hs.get_clock()
self.validator = EventValidator()
self.snapshot_cache = ResponseCache(
- hs, "initial_sync_cache"
+ hs.get_clock(), "initial_sync_cache"
) # type: ResponseCache[Tuple[str, Optional[StreamToken], Optional[StreamToken], str, Optional[int], bool, bool]]
self._event_serializer = hs.get_event_client_serializer()
self.storage = hs.get_storage()
diff --git a/synapse/handlers/oidc_handler.py b/synapse/handlers/oidc_handler.py
index 07db1e31..6d8551a6 100644
--- a/synapse/handlers/oidc_handler.py
+++ b/synapse/handlers/oidc_handler.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2020 Quentin Gliech
+# Copyright 2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -14,13 +15,13 @@
# limitations under the License.
import inspect
import logging
-from typing import TYPE_CHECKING, Dict, Generic, List, Optional, TypeVar
+from typing import TYPE_CHECKING, Dict, Generic, List, Optional, TypeVar, Union
from urllib.parse import urlencode
import attr
import pymacaroons
from authlib.common.security import generate_token
-from authlib.jose import JsonWebToken
+from authlib.jose import JsonWebToken, jwt
from authlib.oauth2.auth import ClientAuth
from authlib.oauth2.rfc6749.parameters import prepare_grant_uri
from authlib.oidc.core import CodeIDToken, ImplicitIDToken, UserInfo
@@ -28,20 +29,26 @@ from authlib.oidc.discovery import OpenIDProviderMetadata, get_well_known_url
from jinja2 import Environment, Template
from pymacaroons.exceptions import (
MacaroonDeserializationException,
+ MacaroonInitException,
MacaroonInvalidSignatureException,
)
from typing_extensions import TypedDict
from twisted.web.client import readBody
+from twisted.web.http_headers import Headers
from synapse.config import ConfigError
-from synapse.config.oidc_config import OidcProviderConfig
+from synapse.config.oidc_config import (
+ OidcProviderClientSecretJwtKey,
+ OidcProviderConfig,
+)
from synapse.handlers.sso import MappingException, UserAttributes
from synapse.http.site import SynapseRequest
from synapse.logging.context import make_deferred_yieldable
from synapse.types import JsonDict, UserID, map_username_to_mxid_localpart
-from synapse.util import json_decoder
+from synapse.util import Clock, json_decoder
from synapse.util.caches.cached_call import RetryOnExceptionCachedCall
+from synapse.util.macaroons import get_value_from_macaroon, satisfy_expiry
if TYPE_CHECKING:
from synapse.server import HomeServer
@@ -211,7 +218,7 @@ class OidcHandler:
session_data = self._token_generator.verify_oidc_session_token(
session, state
)
- except (MacaroonDeserializationException, ValueError) as e:
+ except (MacaroonInitException, MacaroonDeserializationException, KeyError) as e:
logger.exception("Invalid session for OIDC callback")
self._sso_handler.render_error(request, "invalid_session", str(e))
return
@@ -275,9 +282,21 @@ class OidcProvider:
self._scopes = provider.scopes
self._user_profile_method = provider.user_profile_method
+
+ client_secret = None # type: Union[None, str, JwtClientSecret]
+ if provider.client_secret:
+ client_secret = provider.client_secret
+ elif provider.client_secret_jwt_key:
+ client_secret = JwtClientSecret(
+ provider.client_secret_jwt_key,
+ provider.client_id,
+ provider.issuer,
+ hs.get_clock(),
+ )
+
self._client_auth = ClientAuth(
provider.client_id,
- provider.client_secret,
+ client_secret,
provider.client_auth_method,
) # type: ClientAuth
self._client_auth_method = provider.client_auth_method
@@ -312,6 +331,9 @@ class OidcProvider:
# optional brand identifier for this auth provider
self.idp_brand = provider.idp_brand
+ # Optional brand identifier for the unstable API (see MSC2858).
+ self.unstable_idp_brand = provider.unstable_idp_brand
+
self._sso_handler = hs.get_sso_handler()
self._sso_handler.register_identity_provider(self)
@@ -521,7 +543,7 @@ class OidcProvider:
"""
metadata = await self.load_metadata()
token_endpoint = metadata.get("token_endpoint")
- headers = {
+ raw_headers = {
"Content-Type": "application/x-www-form-urlencoded",
"User-Agent": self._http_client.user_agent,
"Accept": "application/json",
@@ -535,10 +557,10 @@ class OidcProvider:
body = urlencode(args, True)
# Fill the body/headers with credentials
- uri, headers, body = self._client_auth.prepare(
- method="POST", uri=token_endpoint, headers=headers, body=body
+ uri, raw_headers, body = self._client_auth.prepare(
+ method="POST", uri=token_endpoint, headers=raw_headers, body=body
)
- headers = {k: [v] for (k, v) in headers.items()}
+ headers = Headers({k: [v] for (k, v) in raw_headers.items()})
# Do the actual request
# We're not using the SimpleHttpClient util methods as we don't want to
@@ -745,7 +767,7 @@ class OidcProvider:
idp_id=self.idp_id,
nonce=nonce,
client_redirect_url=client_redirect_url.decode(),
- ui_auth_session_id=ui_auth_session_id,
+ ui_auth_session_id=ui_auth_session_id or "",
),
)
@@ -976,6 +998,81 @@ class OidcProvider:
return str(remote_user_id)
+# number of seconds a newly-generated client secret should be valid for
+CLIENT_SECRET_VALIDITY_SECONDS = 3600
+
+# minimum remaining validity on a client secret before we should generate a new one
+CLIENT_SECRET_MIN_VALIDITY_SECONDS = 600
+
+
+class JwtClientSecret:
+ """A class which generates a new client secret on demand, based on a JWK
+
+ This implementation is designed to comply with the requirements for Apple Sign in:
+ https://developer.apple.com/documentation/sign_in_with_apple/generate_and_validate_tokens#3262048
+
+ It looks like those requirements are based on https://tools.ietf.org/html/rfc7523,
+ but it's worth noting that we still put the generated secret in the "client_secret"
+ field (or rather, whereever client_auth_method puts it) rather than in a
+ client_assertion field in the body as that RFC seems to require.
+ """
+
+ def __init__(
+ self,
+ key: OidcProviderClientSecretJwtKey,
+ oauth_client_id: str,
+ oauth_issuer: str,
+ clock: Clock,
+ ):
+ self._key = key
+ self._oauth_client_id = oauth_client_id
+ self._oauth_issuer = oauth_issuer
+ self._clock = clock
+ self._cached_secret = b""
+ self._cached_secret_replacement_time = 0
+
+ def __str__(self):
+ # if client_auth_method is client_secret_basic, then ClientAuth.prepare calls
+ # encode_client_secret_basic, which calls "{}".format(secret), which ends up
+ # here.
+ return self._get_secret().decode("ascii")
+
+ def __bytes__(self):
+ # if client_auth_method is client_secret_post, then ClientAuth.prepare calls
+ # encode_client_secret_post, which ends up here.
+ return self._get_secret()
+
+ def _get_secret(self) -> bytes:
+ now = self._clock.time()
+
+ # if we have enough validity on our existing secret, use it
+ if now < self._cached_secret_replacement_time:
+ return self._cached_secret
+
+ issued_at = int(now)
+ expires_at = issued_at + CLIENT_SECRET_VALIDITY_SECONDS
+
+ # we copy the configured header because jwt.encode modifies it.
+ header = dict(self._key.jwt_header)
+
+ # see https://tools.ietf.org/html/rfc7523#section-3
+ payload = {
+ "sub": self._oauth_client_id,
+ "aud": self._oauth_issuer,
+ "iat": issued_at,
+ "exp": expires_at,
+ **self._key.jwt_payload,
+ }
+ logger.info(
+ "Generating new JWT for %s: %s %s", self._oauth_issuer, header, payload
+ )
+ self._cached_secret = jwt.encode(header, payload, self._key.key)
+ self._cached_secret_replacement_time = (
+ expires_at - CLIENT_SECRET_MIN_VALIDITY_SECONDS
+ )
+ return self._cached_secret
+
+
class OidcSessionTokenGenerator:
"""Methods for generating and checking OIDC Session cookies."""
@@ -1020,10 +1117,9 @@ class OidcSessionTokenGenerator:
macaroon.add_first_party_caveat(
"client_redirect_url = %s" % (session_data.client_redirect_url,)
)
- if session_data.ui_auth_session_id:
- macaroon.add_first_party_caveat(
- "ui_auth_session_id = %s" % (session_data.ui_auth_session_id,)
- )
+ macaroon.add_first_party_caveat(
+ "ui_auth_session_id = %s" % (session_data.ui_auth_session_id,)
+ )
now = self._clock.time_msec()
expiry = now + duration_in_ms
macaroon.add_first_party_caveat("time < %d" % (expiry,))
@@ -1046,7 +1142,7 @@ class OidcSessionTokenGenerator:
The data extracted from the session cookie
Raises:
- ValueError if an expected caveat is missing from the macaroon.
+ KeyError if an expected caveat is missing from the macaroon.
"""
macaroon = pymacaroons.Macaroon.deserialize(session)
@@ -1057,26 +1153,16 @@ class OidcSessionTokenGenerator:
v.satisfy_general(lambda c: c.startswith("nonce = "))
v.satisfy_general(lambda c: c.startswith("idp_id = "))
v.satisfy_general(lambda c: c.startswith("client_redirect_url = "))
- # Sometimes there's a UI auth session ID, it seems to be OK to attempt
- # to always satisfy this.
v.satisfy_general(lambda c: c.startswith("ui_auth_session_id = "))
- v.satisfy_general(self._verify_expiry)
+ satisfy_expiry(v, self._clock.time_msec)
v.verify(macaroon, self._macaroon_secret_key)
# Extract the session data from the token.
- nonce = self._get_value_from_macaroon(macaroon, "nonce")
- idp_id = self._get_value_from_macaroon(macaroon, "idp_id")
- client_redirect_url = self._get_value_from_macaroon(
- macaroon, "client_redirect_url"
- )
- try:
- ui_auth_session_id = self._get_value_from_macaroon(
- macaroon, "ui_auth_session_id"
- ) # type: Optional[str]
- except ValueError:
- ui_auth_session_id = None
-
+ nonce = get_value_from_macaroon(macaroon, "nonce")
+ idp_id = get_value_from_macaroon(macaroon, "idp_id")
+ client_redirect_url = get_value_from_macaroon(macaroon, "client_redirect_url")
+ ui_auth_session_id = get_value_from_macaroon(macaroon, "ui_auth_session_id")
return OidcSessionData(
nonce=nonce,
idp_id=idp_id,
@@ -1084,33 +1170,6 @@ class OidcSessionTokenGenerator:
ui_auth_session_id=ui_auth_session_id,
)
- def _get_value_from_macaroon(self, macaroon: pymacaroons.Macaroon, key: str) -> str:
- """Extracts a caveat value from a macaroon token.
-
- Args:
- macaroon: the token
- key: the key of the caveat to extract
-
- Returns:
- The extracted value
-
- Raises:
- ValueError: if the caveat was not in the macaroon
- """
- prefix = key + " = "
- for caveat in macaroon.caveats:
- if caveat.caveat_id.startswith(prefix):
- return caveat.caveat_id[len(prefix) :]
- raise ValueError("No %s caveat in macaroon" % (key,))
-
- def _verify_expiry(self, caveat: str) -> bool:
- prefix = "time < "
- if not caveat.startswith(prefix):
- return False
- expiry = int(caveat[len(prefix) :])
- now = self._clock.time_msec()
- return now < expiry
-
@attr.s(frozen=True, slots=True)
class OidcSessionData:
@@ -1125,8 +1184,8 @@ class OidcSessionData:
# The URL the client gave when it initiated the flow. ("" if this is a UI Auth)
client_redirect_url = attr.ib(type=str)
- # The session ID of the ongoing UI Auth (None if this is a login)
- ui_auth_session_id = attr.ib(type=Optional[str], default=None)
+ # The session ID of the ongoing UI Auth ("" if this is a login)
+ ui_auth_session_id = attr.ib(type=str)
UserAttributeDict = TypedDict(
diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py
index 059064a4..66dc886c 100644
--- a/synapse/handlers/pagination.py
+++ b/synapse/handlers/pagination.py
@@ -285,7 +285,7 @@ class PaginationHandler:
except Exception:
f = Failure()
logger.error(
- "[purge] failed", exc_info=(f.type, f.value, f.getTracebackObject())
+ "[purge] failed", exc_info=(f.type, f.value, f.getTracebackObject()) # type: ignore
)
self._purges_by_id[purge_id].status = PurgeStatus.STATUS_FAILED
finally:
diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py
index 3cda8965..1abc8875 100644
--- a/synapse/handlers/register.py
+++ b/synapse/handlers/register.py
@@ -16,7 +16,9 @@
"""Contains functions for registering clients."""
import logging
-from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple
+from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple
+
+from prometheus_client import Counter
from synapse import types
from synapse.api.constants import MAX_USERID_LENGTH, EventTypes, JoinRules, LoginType
@@ -41,6 +43,19 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
+registration_counter = Counter(
+ "synapse_user_registrations_total",
+ "Number of new users registered (since restart)",
+ ["guest", "shadow_banned", "auth_provider"],
+)
+
+login_counter = Counter(
+ "synapse_user_logins_total",
+ "Number of user logins (since restart)",
+ ["guest", "auth_provider"],
+)
+
+
class RegistrationHandler(BaseHandler):
def __init__(self, hs: "HomeServer"):
super().__init__(hs)
@@ -67,6 +82,7 @@ class RegistrationHandler(BaseHandler):
)
else:
self.device_handler = hs.get_device_handler()
+ self._register_device_client = self.register_device_inner
self.pusher_pool = hs.get_pusherpool()
self.session_lifetime = hs.config.session_lifetime
@@ -156,6 +172,7 @@ class RegistrationHandler(BaseHandler):
bind_emails: Iterable[str] = [],
by_admin: bool = False,
user_agent_ips: Optional[List[Tuple[str, str]]] = None,
+ auth_provider_id: Optional[str] = None,
) -> str:
"""Registers a new client on the server.
@@ -181,8 +198,9 @@ class RegistrationHandler(BaseHandler):
admin api, otherwise False.
user_agent_ips: Tuples of IP addresses and user-agents used
during the registration process.
+ auth_provider_id: The SSO IdP the user used, if any.
Returns:
- The registere user_id.
+ The registered user_id.
Raises:
SynapseError if there was a problem registering.
"""
@@ -192,6 +210,7 @@ class RegistrationHandler(BaseHandler):
threepid,
localpart,
user_agent_ips or [],
+ auth_provider_id=auth_provider_id,
)
if result == RegistrationBehaviour.DENY:
@@ -280,6 +299,12 @@ class RegistrationHandler(BaseHandler):
# if user id is taken, just generate another
fail_count += 1
+ registration_counter.labels(
+ guest=make_guest,
+ shadow_banned=shadow_banned,
+ auth_provider=(auth_provider_id or ""),
+ ).inc()
+
if not self.hs.config.user_consent_at_registration:
if not self.hs.config.auto_join_rooms_for_guests and make_guest:
logger.info(
@@ -638,6 +663,7 @@ class RegistrationHandler(BaseHandler):
initial_display_name: Optional[str],
is_guest: bool = False,
is_appservice_ghost: bool = False,
+ auth_provider_id: Optional[str] = None,
) -> Tuple[str, str]:
"""Register a device for a user and generate an access token.
@@ -648,21 +674,40 @@ class RegistrationHandler(BaseHandler):
device_id: The device ID to check, or None to generate a new one.
initial_display_name: An optional display name for the device.
is_guest: Whether this is a guest account
-
+ auth_provider_id: The SSO IdP the user used, if any (just used for the
+ prometheus metrics).
Returns:
Tuple of device ID and access token
"""
+ res = await self._register_device_client(
+ user_id=user_id,
+ device_id=device_id,
+ initial_display_name=initial_display_name,
+ is_guest=is_guest,
+ is_appservice_ghost=is_appservice_ghost,
+ )
- if self.hs.config.worker_app:
- r = await self._register_device_client(
- user_id=user_id,
- device_id=device_id,
- initial_display_name=initial_display_name,
- is_guest=is_guest,
- is_appservice_ghost=is_appservice_ghost,
- )
- return r["device_id"], r["access_token"]
+ login_counter.labels(
+ guest=is_guest,
+ auth_provider=(auth_provider_id or ""),
+ ).inc()
+
+ return res["device_id"], res["access_token"]
+ async def register_device_inner(
+ self,
+ user_id: str,
+ device_id: Optional[str],
+ initial_display_name: Optional[str],
+ is_guest: bool = False,
+ is_appservice_ghost: bool = False,
+ ) -> Dict[str, str]:
+ """Helper for register_device
+
+ Does the bits that need doing on the main process. Not for use outside this
+ class and RegisterDeviceReplicationServlet.
+ """
+ assert not self.hs.config.worker_app
valid_until_ms = None
if self.session_lifetime is not None:
if is_guest:
@@ -687,7 +732,7 @@ class RegistrationHandler(BaseHandler):
is_appservice_ghost=is_appservice_ghost,
)
- return (registered_device_id, access_token)
+ return {"device_id": registered_device_id, "access_token": access_token}
async def post_registration_actions(
self, user_id: str, auth_result: dict, access_token: Optional[str]
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index a488df10..4b3d0d72 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -121,7 +121,7 @@ class RoomCreationHandler(BaseHandler):
# succession, only process the first attempt and return its result to
# subsequent requests
self._upgrade_response_cache = ResponseCache(
- hs, "room_upgrade", timeout_ms=FIVE_MINUTES_IN_MS
+ hs.get_clock(), "room_upgrade", timeout_ms=FIVE_MINUTES_IN_MS
) # type: ResponseCache[Tuple[str, str]]
self._server_notices_mxid = hs.config.server_notices_mxid
diff --git a/synapse/handlers/room_list.py b/synapse/handlers/room_list.py
index 14f14db4..8bfc46c6 100644
--- a/synapse/handlers/room_list.py
+++ b/synapse/handlers/room_list.py
@@ -44,10 +44,10 @@ class RoomListHandler(BaseHandler):
super().__init__(hs)
self.enable_room_list_search = hs.config.enable_room_list_search
self.response_cache = ResponseCache(
- hs, "room_list"
+ hs.get_clock(), "room_list"
) # type: ResponseCache[Tuple[Optional[int], Optional[str], ThirdPartyInstanceID]]
self.remote_response_cache = ResponseCache(
- hs, "remote_room_list", timeout_ms=30 * 1000
+ hs.get_clock(), "remote_room_list", timeout_ms=30 * 1000
) # type: ResponseCache[Tuple[str, Optional[int], Optional[str], bool, Optional[str]]]
async def get_local_public_room_list(
diff --git a/synapse/handlers/saml_handler.py b/synapse/handlers/saml_handler.py
index a9645b77..ec2ba11c 100644
--- a/synapse/handlers/saml_handler.py
+++ b/synapse/handlers/saml_handler.py
@@ -81,6 +81,7 @@ class SamlHandler(BaseHandler):
# the SsoIdentityProvider protocol type.
self.idp_icon = None
self.idp_brand = None
+ self.unstable_idp_brand = None
# a map from saml session id to Saml2SessionData object
self._outstanding_requests_dict = {} # type: Dict[str, Saml2SessionData]
diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py
index 80e28bdc..415b1c2d 100644
--- a/synapse/handlers/sso.py
+++ b/synapse/handlers/sso.py
@@ -98,6 +98,11 @@ class SsoIdentityProvider(Protocol):
"""Optional branding identifier"""
return None
+ @property
+ def unstable_idp_brand(self) -> Optional[str]:
+ """Optional brand identifier for the unstable API (see MSC2858)."""
+ return None
+
@abc.abstractmethod
async def handle_redirect_request(
self,
@@ -456,6 +461,7 @@ class SsoHandler:
await self._auth_handler.complete_sso_login(
user_id,
+ auth_provider_id,
request,
client_redirect_url,
extra_login_attributes,
@@ -605,6 +611,7 @@ class SsoHandler:
default_display_name=attributes.display_name,
bind_emails=attributes.emails,
user_agent_ips=[(user_agent, ip_address)],
+ auth_provider_id=auth_provider_id,
)
await self._store.record_user_external_id(
@@ -886,6 +893,7 @@ class SsoHandler:
await self._auth_handler.complete_sso_login(
user_id,
+ session.auth_provider_id,
request,
session.client_redirect_url,
session.extra_login_attributes,
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index 4e8ed7b3..f50257cd 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -244,7 +244,7 @@ class SyncHandler:
self.event_sources = hs.get_event_sources()
self.clock = hs.get_clock()
self.response_cache = ResponseCache(
- hs, "sync"
+ hs.get_clock(), "sync"
) # type: ResponseCache[Tuple[Any, ...]]
self.state = hs.get_state_handler()
self.auth = hs.get_auth()
diff --git a/synapse/http/client.py b/synapse/http/client.py
index 72901e3f..1e01e0a9 100644
--- a/synapse/http/client.py
+++ b/synapse/http/client.py
@@ -39,12 +39,15 @@ from zope.interface import implementer, provider
from OpenSSL import SSL
from OpenSSL.SSL import VERIFY_NONE
from twisted.internet import defer, error as twisted_error, protocol, ssl
+from twisted.internet.address import IPv4Address, IPv6Address
from twisted.internet.interfaces import (
IAddress,
IHostResolution,
IReactorPluggableNameResolver,
IResolutionReceiver,
+ ITCPTransport,
)
+from twisted.internet.protocol import connectionDone
from twisted.internet.task import Cooperator
from twisted.python.failure import Failure
from twisted.web._newclient import ResponseDone
@@ -56,13 +59,20 @@ from twisted.web.client import (
)
from twisted.web.http import PotentialDataLoss
from twisted.web.http_headers import Headers
-from twisted.web.iweb import UNKNOWN_LENGTH, IAgent, IBodyProducer, IResponse
+from twisted.web.iweb import (
+ UNKNOWN_LENGTH,
+ IAgent,
+ IBodyProducer,
+ IPolicyForHTTPS,
+ IResponse,
+)
from synapse.api.errors import Codes, HttpResponseException, SynapseError
from synapse.http import QuieterFileBodyProducer, RequestTimedOutError, redact_uri
from synapse.http.proxyagent import ProxyAgent
from synapse.logging.context import make_deferred_yieldable
from synapse.logging.opentracing import set_tag, start_active_span, tags
+from synapse.types import ISynapseReactor
from synapse.util import json_decoder
from synapse.util.async_helpers import timeout_deferred
@@ -150,16 +160,17 @@ class _IPBlacklistingResolver:
def resolveHostName(
self, recv: IResolutionReceiver, hostname: str, portNumber: int = 0
) -> IResolutionReceiver:
-
- r = recv()
addresses = [] # type: List[IAddress]
def _callback() -> None:
- r.resolutionBegan(None)
-
has_bad_ip = False
- for i in addresses:
- ip_address = IPAddress(i.host)
+ for address in addresses:
+ # We only expect IPv4 and IPv6 addresses since only A/AAAA lookups
+ # should go through this path.
+ if not isinstance(address, (IPv4Address, IPv6Address)):
+ continue
+
+ ip_address = IPAddress(address.host)
if check_against_blacklist(
ip_address, self._ip_whitelist, self._ip_blacklist
@@ -174,15 +185,15 @@ class _IPBlacklistingResolver:
# request, but all we can really do from here is claim that there were no
# valid results.
if not has_bad_ip:
- for i in addresses:
- r.addressResolved(i)
- r.resolutionComplete()
+ for address in addresses:
+ recv.addressResolved(address)
+ recv.resolutionComplete()
@provider(IResolutionReceiver)
class EndpointReceiver:
@staticmethod
def resolutionBegan(resolutionInProgress: IHostResolution) -> None:
- pass
+ recv.resolutionBegan(resolutionInProgress)
@staticmethod
def addressResolved(address: IAddress) -> None:
@@ -196,10 +207,10 @@ class _IPBlacklistingResolver:
EndpointReceiver, hostname, portNumber=portNumber
)
- return r
+ return recv
-@implementer(IReactorPluggableNameResolver)
+@implementer(ISynapseReactor)
class BlacklistingReactorWrapper:
"""
A Reactor wrapper which will prevent DNS resolution to blacklisted IP
@@ -324,7 +335,7 @@ class SimpleHttpClient:
# filters out blacklisted IP addresses, to prevent DNS rebinding.
self.reactor = BlacklistingReactorWrapper(
hs.get_reactor(), self._ip_whitelist, self._ip_blacklist
- )
+ ) # type: ISynapseReactor
else:
self.reactor = hs.get_reactor()
@@ -345,7 +356,7 @@ class SimpleHttpClient:
contextFactory=self.hs.get_http_client_context_factory(),
pool=pool,
use_proxy=use_proxy,
- )
+ ) # type: IAgent
if self._ip_blacklist:
# If we have an IP blacklist, we then install the blacklisting Agent
@@ -751,6 +762,8 @@ class BodyExceededMaxSize(Exception):
class _DiscardBodyWithMaxSizeProtocol(protocol.Protocol):
"""A protocol which immediately errors upon receiving data."""
+ transport = None # type: Optional[ITCPTransport]
+
def __init__(self, deferred: defer.Deferred):
self.deferred = deferred
@@ -762,18 +775,21 @@ class _DiscardBodyWithMaxSizeProtocol(protocol.Protocol):
self.deferred.errback(BodyExceededMaxSize())
# Close the connection (forcefully) since all the data will get
# discarded anyway.
+ assert self.transport is not None
self.transport.abortConnection()
def dataReceived(self, data: bytes) -> None:
self._maybe_fail()
- def connectionLost(self, reason: Failure) -> None:
+ def connectionLost(self, reason: Failure = connectionDone) -> None:
self._maybe_fail()
class _ReadBodyWithMaxSizeProtocol(protocol.Protocol):
"""A protocol which reads body to a stream, erroring if the body exceeds a maximum size."""
+ transport = None # type: Optional[ITCPTransport]
+
def __init__(
self, stream: BinaryIO, deferred: defer.Deferred, max_size: Optional[int]
):
@@ -796,9 +812,10 @@ class _ReadBodyWithMaxSizeProtocol(protocol.Protocol):
self.deferred.errback(BodyExceededMaxSize())
# Close the connection (forcefully) since all the data will get
# discarded anyway.
+ assert self.transport is not None
self.transport.abortConnection()
- def connectionLost(self, reason: Failure) -> None:
+ def connectionLost(self, reason: Failure = connectionDone) -> None:
# If the maximum size was already exceeded, there's nothing to do.
if self.deferred.called:
return
@@ -867,6 +884,7 @@ def encode_query_args(args: Optional[Mapping[str, Union[str, List[str]]]]) -> by
return query_str.encode("utf8")
+@implementer(IPolicyForHTTPS)
class InsecureInterceptableContextFactory(ssl.ContextFactory):
"""
Factory for PyOpenSSL SSL contexts which accepts any certificate for any domain.
diff --git a/synapse/http/federation/matrix_federation_agent.py b/synapse/http/federation/matrix_federation_agent.py
index b07aa59c..5935a125 100644
--- a/synapse/http/federation/matrix_federation_agent.py
+++ b/synapse/http/federation/matrix_federation_agent.py
@@ -35,6 +35,7 @@ from synapse.http.client import BlacklistingAgentWrapper
from synapse.http.federation.srv_resolver import Server, SrvResolver
from synapse.http.federation.well_known_resolver import WellKnownResolver
from synapse.logging.context import make_deferred_yieldable, run_in_background
+from synapse.types import ISynapseReactor
from synapse.util import Clock
logger = logging.getLogger(__name__)
@@ -68,7 +69,7 @@ class MatrixFederationAgent:
def __init__(
self,
- reactor: IReactorCore,
+ reactor: ISynapseReactor,
tls_client_options_factory: Optional[FederationPolicyForHTTPS],
user_agent: bytes,
ip_blacklist: IPSet,
diff --git a/synapse/http/federation/well_known_resolver.py b/synapse/http/federation/well_known_resolver.py
index 4def7d76..ecd63e65 100644
--- a/synapse/http/federation/well_known_resolver.py
+++ b/synapse/http/federation/well_known_resolver.py
@@ -322,7 +322,8 @@ def _cache_period_from_headers(
def _parse_cache_control(headers: Headers) -> Dict[bytes, Optional[bytes]]:
cache_controls = {}
- for hdr in headers.getRawHeaders(b"cache-control", []):
+ cache_control_headers = headers.getRawHeaders(b"cache-control") or []
+ for hdr in cache_control_headers:
for directive in hdr.split(b","):
splits = [x.strip() for x in directive.split(b"=", 1)]
k = splits[0].lower()
diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py
index 0f107714..5f01ebd3 100644
--- a/synapse/http/matrixfederationclient.py
+++ b/synapse/http/matrixfederationclient.py
@@ -59,7 +59,7 @@ from synapse.logging.opentracing import (
start_active_span,
tags,
)
-from synapse.types import JsonDict
+from synapse.types import ISynapseReactor, JsonDict
from synapse.util import json_decoder
from synapse.util.async_helpers import timeout_deferred
from synapse.util.metrics import Measure
@@ -237,14 +237,14 @@ class MatrixFederationHttpClient:
# addresses, to prevent DNS rebinding.
self.reactor = BlacklistingReactorWrapper(
hs.get_reactor(), None, hs.config.federation_ip_range_blacklist
- )
+ ) # type: ISynapseReactor
user_agent = hs.version_string
if hs.config.user_agent_suffix:
user_agent = "%s %s" % (user_agent, hs.config.user_agent_suffix)
user_agent = user_agent.encode("ascii")
- self.agent = MatrixFederationAgent(
+ federation_agent = MatrixFederationAgent(
self.reactor,
tls_client_options_factory,
user_agent,
@@ -254,7 +254,7 @@ class MatrixFederationHttpClient:
# Use a BlacklistingAgentWrapper to prevent circumventing the IP
# blacklist via IP literals in server names
self.agent = BlacklistingAgentWrapper(
- self.agent,
+ federation_agent,
ip_blacklist=hs.config.federation_ip_range_blacklist,
)
@@ -534,9 +534,10 @@ class MatrixFederationHttpClient:
response.code, response_phrase, body
)
- # Retry if the error is a 429 (Too Many Requests),
- # otherwise just raise a standard HttpResponseException
- if response.code == 429:
+ # Retry if the error is a 5xx or a 429 (Too Many
+ # Requests), otherwise just raise a standard
+ # `HttpResponseException`
+ if 500 <= response.code < 600 or response.code == 429:
raise RequestSendFailed(exc, can_retry=True) from exc
else:
raise exc
diff --git a/synapse/logging/_remote.py b/synapse/logging/_remote.py
index 174ca7be..643492ce 100644
--- a/synapse/logging/_remote.py
+++ b/synapse/logging/_remote.py
@@ -32,8 +32,9 @@ from twisted.internet.endpoints import (
TCP4ClientEndpoint,
TCP6ClientEndpoint,
)
-from twisted.internet.interfaces import IPushProducer, IStreamClientEndpoint, ITransport
+from twisted.internet.interfaces import IPushProducer, IStreamClientEndpoint
from twisted.internet.protocol import Factory, Protocol
+from twisted.internet.tcp import Connection
from twisted.python.failure import Failure
logger = logging.getLogger(__name__)
@@ -52,7 +53,9 @@ class LogProducer:
format: A callable to format the log record to a string.
"""
- transport = attr.ib(type=ITransport)
+ # This is essentially ITCPTransport, but that is missing certain fields
+ # (connected and registerProducer) which are part of the implementation.
+ transport = attr.ib(type=Connection)
_format = attr.ib(type=Callable[[logging.LogRecord], str])
_buffer = attr.ib(type=deque)
_paused = attr.ib(default=False, type=bool, init=False)
@@ -149,8 +152,6 @@ class RemoteHandler(logging.Handler):
if self._connection_waiter:
return
- self._connection_waiter = self._service.whenConnected(failAfterFailures=1)
-
def fail(failure: Failure) -> None:
# If the Deferred was cancelled (e.g. during shutdown) do not try to
# reconnect (this will cause an infinite loop of errors).
@@ -163,9 +164,13 @@ class RemoteHandler(logging.Handler):
self._connect()
def writer(result: Protocol) -> None:
+ # Force recognising transport as a Connection and not the more
+ # generic ITransport.
+ transport = result.transport # type: Connection # type: ignore
+
# We have a connection. If we already have a producer, and its
# transport is the same, just trigger a resumeProducing.
- if self._producer and result.transport is self._producer.transport:
+ if self._producer and transport is self._producer.transport:
self._producer.resumeProducing()
self._connection_waiter = None
return
@@ -177,14 +182,16 @@ class RemoteHandler(logging.Handler):
# Make a new producer and start it.
self._producer = LogProducer(
buffer=self._buffer,
- transport=result.transport,
+ transport=transport,
format=self.format,
)
- result.transport.registerProducer(self._producer, True)
+ transport.registerProducer(self._producer, True)
self._producer.resumeProducing()
self._connection_waiter = None
- self._connection_waiter.addCallbacks(writer, fail)
+ deferred = self._service.whenConnected(failAfterFailures=1) # type: Deferred
+ deferred.addCallbacks(writer, fail)
+ self._connection_waiter = deferred
def _handle_pressure(self) -> None:
"""
diff --git a/synapse/logging/context.py b/synapse/logging/context.py
index 78e27bfb..1a7ea4fa 100644
--- a/synapse/logging/context.py
+++ b/synapse/logging/context.py
@@ -669,7 +669,7 @@ def preserve_fn(f):
return g
-def run_in_background(f, *args, **kwargs):
+def run_in_background(f, *args, **kwargs) -> defer.Deferred:
"""Calls a function, ensuring that the current context is restored after
return from the function, and that the sentinel context is set once the
deferred returned by the function completes.
@@ -697,8 +697,10 @@ def run_in_background(f, *args, **kwargs):
if isinstance(res, types.CoroutineType):
res = defer.ensureDeferred(res)
+ # At this point we should have a Deferred, if not then f was a synchronous
+ # function, wrap it in a Deferred for consistency.
if not isinstance(res, defer.Deferred):
- return res
+ return defer.succeed(res)
if res.called and not res.paused:
# The function should have maintained the logcontext, so we can
diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py
index db2d400b..781e02fb 100644
--- a/synapse/module_api/__init__.py
+++ b/synapse/module_api/__init__.py
@@ -203,11 +203,26 @@ class ModuleApi:
)
def generate_short_term_login_token(
- self, user_id: str, duration_in_ms: int = (2 * 60 * 1000)
+ self,
+ user_id: str,
+ duration_in_ms: int = (2 * 60 * 1000),
+ auth_provider_id: str = "",
) -> str:
- """Generate a login token suitable for m.login.token authentication"""
+ """Generate a login token suitable for m.login.token authentication
+
+ Args:
+ user_id: gives the ID of the user that the token is for
+
+ duration_in_ms: the time that the token will be valid for
+
+ auth_provider_id: the ID of the SSO IdP that the user used to authenticate
+ to get this token, if any. This is encoded in the token so that
+ /login can report stats on number of successful logins by IdP.
+ """
return self._hs.get_macaroon_generator().generate_short_term_login_token(
- user_id, duration_in_ms
+ user_id,
+ auth_provider_id,
+ duration_in_ms,
)
@defer.inlineCallbacks
@@ -276,6 +291,7 @@ class ModuleApi:
"""
self._auth_handler._complete_sso_login(
registered_user_id,
+ "<unknown>",
request,
client_redirect_url,
)
@@ -286,6 +302,7 @@ class ModuleApi:
request: SynapseRequest,
client_redirect_url: str,
new_user: bool = False,
+ auth_provider_id: str = "<unknown>",
):
"""Complete a SSO login by redirecting the user to a page to confirm whether they
want their access token sent to `client_redirect_url`, or redirect them to that
@@ -299,9 +316,15 @@ class ModuleApi:
redirect them directly if whitelisted).
new_user: set to true to use wording for the consent appropriate to a user
who has just registered.
+ auth_provider_id: the ID of the SSO IdP which was used to log in. This
+ is used to track counts of sucessful logins by IdP.
"""
await self._auth_handler.complete_sso_login(
- registered_user_id, request, client_redirect_url, new_user=new_user
+ registered_user_id,
+ auth_provider_id,
+ request,
+ client_redirect_url,
+ new_user=new_user,
)
@defer.inlineCallbacks
diff --git a/synapse/push/emailpusher.py b/synapse/push/emailpusher.py
index 5fec2aaf..3dc06a79 100644
--- a/synapse/push/emailpusher.py
+++ b/synapse/push/emailpusher.py
@@ -16,8 +16,8 @@
import logging
from typing import TYPE_CHECKING, Dict, List, Optional
-from twisted.internet.base import DelayedCall
from twisted.internet.error import AlreadyCalled, AlreadyCancelled
+from twisted.internet.interfaces import IDelayedCall
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.push import Pusher, PusherConfig, ThrottleParams
@@ -66,7 +66,7 @@ class EmailPusher(Pusher):
self.store = self.hs.get_datastore()
self.email = pusher_config.pushkey
- self.timed_call = None # type: Optional[DelayedCall]
+ self.timed_call = None # type: Optional[IDelayedCall]
self.throttle_params = {} # type: Dict[str, ThrottleParams]
self._inited = False
diff --git a/synapse/replication/http/_base.py b/synapse/replication/http/_base.py
index 8a3f113e..b7aa0c28 100644
--- a/synapse/replication/http/_base.py
+++ b/synapse/replication/http/_base.py
@@ -18,7 +18,7 @@ import logging
import re
import urllib
from inspect import signature
-from typing import Dict, List, Tuple
+from typing import TYPE_CHECKING, Dict, List, Tuple
from prometheus_client import Counter, Gauge
@@ -28,6 +28,9 @@ from synapse.logging.opentracing import inject_active_span_byte_dict, trace
from synapse.util.caches.response_cache import ResponseCache
from synapse.util.stringutils import random_string
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
logger = logging.getLogger(__name__)
_pending_outgoing_requests = Gauge(
@@ -88,10 +91,10 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
CACHE = True
RETRY_ON_TIMEOUT = True
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
if self.CACHE:
self.response_cache = ResponseCache(
- hs, "repl." + self.NAME, timeout_ms=30 * 60 * 1000
+ hs.get_clock(), "repl." + self.NAME, timeout_ms=30 * 60 * 1000
) # type: ResponseCache[str]
# We reserve `instance_name` as a parameter to sending requests, so we
diff --git a/synapse/replication/http/login.py b/synapse/replication/http/login.py
index 36071feb..4ec1bfa6 100644
--- a/synapse/replication/http/login.py
+++ b/synapse/replication/http/login.py
@@ -61,7 +61,7 @@ class RegisterDeviceReplicationServlet(ReplicationEndpoint):
is_guest = content["is_guest"]
is_appservice_ghost = content["is_appservice_ghost"]
- device_id, access_token = await self.registration_handler.register_device(
+ res = await self.registration_handler.register_device_inner(
user_id,
device_id,
initial_display_name,
@@ -69,7 +69,7 @@ class RegisterDeviceReplicationServlet(ReplicationEndpoint):
is_appservice_ghost=is_appservice_ghost,
)
- return 200, {"device_id": device_id, "access_token": access_token}
+ return 200, res
def register_servlets(hs, http_server):
diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py
index a7245da1..a8894bea 100644
--- a/synapse/replication/tcp/handler.py
+++ b/synapse/replication/tcp/handler.py
@@ -48,7 +48,7 @@ from synapse.replication.tcp.commands import (
UserIpCommand,
UserSyncCommand,
)
-from synapse.replication.tcp.protocol import AbstractConnection
+from synapse.replication.tcp.protocol import IReplicationConnection
from synapse.replication.tcp.streams import (
STREAMS_MAP,
AccountDataStream,
@@ -82,7 +82,7 @@ user_ip_cache_counter = Counter("synapse_replication_tcp_resource_user_ip_cache"
# the type of the entries in _command_queues_by_stream
_StreamCommandQueue = Deque[
- Tuple[Union[RdataCommand, PositionCommand], AbstractConnection]
+ Tuple[Union[RdataCommand, PositionCommand], IReplicationConnection]
]
@@ -174,7 +174,7 @@ class ReplicationCommandHandler:
# The currently connected connections. (The list of places we need to send
# outgoing replication commands to.)
- self._connections = [] # type: List[AbstractConnection]
+ self._connections = [] # type: List[IReplicationConnection]
LaterGauge(
"synapse_replication_tcp_resource_total_connections",
@@ -197,7 +197,7 @@ class ReplicationCommandHandler:
# For each connection, the incoming stream names that have received a POSITION
# from that connection.
- self._streams_by_connection = {} # type: Dict[AbstractConnection, Set[str]]
+ self._streams_by_connection = {} # type: Dict[IReplicationConnection, Set[str]]
LaterGauge(
"synapse_replication_tcp_command_queue",
@@ -220,7 +220,7 @@ class ReplicationCommandHandler:
self._server_notices_sender = hs.get_server_notices_sender()
def _add_command_to_stream_queue(
- self, conn: AbstractConnection, cmd: Union[RdataCommand, PositionCommand]
+ self, conn: IReplicationConnection, cmd: Union[RdataCommand, PositionCommand]
) -> None:
"""Queue the given received command for processing
@@ -267,7 +267,7 @@ class ReplicationCommandHandler:
async def _process_command(
self,
cmd: Union[PositionCommand, RdataCommand],
- conn: AbstractConnection,
+ conn: IReplicationConnection,
stream_name: str,
) -> None:
if isinstance(cmd, PositionCommand):
@@ -302,7 +302,7 @@ class ReplicationCommandHandler:
hs, outbound_redis_connection
)
hs.get_reactor().connectTCP(
- hs.config.redis.redis_host,
+ hs.config.redis.redis_host.encode(),
hs.config.redis.redis_port,
self._factory,
)
@@ -311,7 +311,7 @@ class ReplicationCommandHandler:
self._factory = DirectTcpReplicationClientFactory(hs, client_name, self)
host = hs.config.worker_replication_host
port = hs.config.worker_replication_port
- hs.get_reactor().connectTCP(host, port, self._factory)
+ hs.get_reactor().connectTCP(host.encode(), port, self._factory)
def get_streams(self) -> Dict[str, Stream]:
"""Get a map from stream name to all streams."""
@@ -321,10 +321,10 @@ class ReplicationCommandHandler:
"""Get a list of streams that this instances replicates."""
return self._streams_to_replicate
- def on_REPLICATE(self, conn: AbstractConnection, cmd: ReplicateCommand):
+ def on_REPLICATE(self, conn: IReplicationConnection, cmd: ReplicateCommand):
self.send_positions_to_connection(conn)
- def send_positions_to_connection(self, conn: AbstractConnection):
+ def send_positions_to_connection(self, conn: IReplicationConnection):
"""Send current position of all streams this process is source of to
the connection.
"""
@@ -347,7 +347,7 @@ class ReplicationCommandHandler:
)
def on_USER_SYNC(
- self, conn: AbstractConnection, cmd: UserSyncCommand
+ self, conn: IReplicationConnection, cmd: UserSyncCommand
) -> Optional[Awaitable[None]]:
user_sync_counter.inc()
@@ -359,21 +359,23 @@ class ReplicationCommandHandler:
return None
def on_CLEAR_USER_SYNC(
- self, conn: AbstractConnection, cmd: ClearUserSyncsCommand
+ self, conn: IReplicationConnection, cmd: ClearUserSyncsCommand
) -> Optional[Awaitable[None]]:
if self._is_master:
return self._presence_handler.update_external_syncs_clear(cmd.instance_id)
else:
return None
- def on_FEDERATION_ACK(self, conn: AbstractConnection, cmd: FederationAckCommand):
+ def on_FEDERATION_ACK(
+ self, conn: IReplicationConnection, cmd: FederationAckCommand
+ ):
federation_ack_counter.inc()
if self._federation_sender:
self._federation_sender.federation_ack(cmd.instance_name, cmd.token)
def on_USER_IP(
- self, conn: AbstractConnection, cmd: UserIpCommand
+ self, conn: IReplicationConnection, cmd: UserIpCommand
) -> Optional[Awaitable[None]]:
user_ip_cache_counter.inc()
@@ -395,7 +397,7 @@ class ReplicationCommandHandler:
assert self._server_notices_sender is not None
await self._server_notices_sender.on_user_ip(cmd.user_id)
- def on_RDATA(self, conn: AbstractConnection, cmd: RdataCommand):
+ def on_RDATA(self, conn: IReplicationConnection, cmd: RdataCommand):
if cmd.instance_name == self._instance_name:
# Ignore RDATA that are just our own echoes
return
@@ -412,7 +414,7 @@ class ReplicationCommandHandler:
self._add_command_to_stream_queue(conn, cmd)
async def _process_rdata(
- self, stream_name: str, conn: AbstractConnection, cmd: RdataCommand
+ self, stream_name: str, conn: IReplicationConnection, cmd: RdataCommand
) -> None:
"""Process an RDATA command
@@ -486,7 +488,7 @@ class ReplicationCommandHandler:
stream_name, instance_name, token, rows
)
- def on_POSITION(self, conn: AbstractConnection, cmd: PositionCommand):
+ def on_POSITION(self, conn: IReplicationConnection, cmd: PositionCommand):
if cmd.instance_name == self._instance_name:
# Ignore POSITION that are just our own echoes
return
@@ -496,7 +498,7 @@ class ReplicationCommandHandler:
self._add_command_to_stream_queue(conn, cmd)
async def _process_position(
- self, stream_name: str, conn: AbstractConnection, cmd: PositionCommand
+ self, stream_name: str, conn: IReplicationConnection, cmd: PositionCommand
) -> None:
"""Process a POSITION command
@@ -553,7 +555,9 @@ class ReplicationCommandHandler:
self._streams_by_connection.setdefault(conn, set()).add(stream_name)
- def on_REMOTE_SERVER_UP(self, conn: AbstractConnection, cmd: RemoteServerUpCommand):
+ def on_REMOTE_SERVER_UP(
+ self, conn: IReplicationConnection, cmd: RemoteServerUpCommand
+ ):
""""Called when get a new REMOTE_SERVER_UP command."""
self._replication_data_handler.on_remote_server_up(cmd.data)
@@ -576,7 +580,7 @@ class ReplicationCommandHandler:
# between two instances, but that is not currently supported).
self.send_command(cmd, ignore_conn=conn)
- def new_connection(self, connection: AbstractConnection):
+ def new_connection(self, connection: IReplicationConnection):
"""Called when we have a new connection."""
self._connections.append(connection)
@@ -603,7 +607,7 @@ class ReplicationCommandHandler:
UserSyncCommand(self._instance_id, user_id, True, now)
)
- def lost_connection(self, connection: AbstractConnection):
+ def lost_connection(self, connection: IReplicationConnection):
"""Called when a connection is closed/lost."""
# we no longer need _streams_by_connection for this connection.
streams = self._streams_by_connection.pop(connection, None)
@@ -624,7 +628,7 @@ class ReplicationCommandHandler:
return bool(self._connections)
def send_command(
- self, cmd: Command, ignore_conn: Optional[AbstractConnection] = None
+ self, cmd: Command, ignore_conn: Optional[IReplicationConnection] = None
):
"""Send a command to all connected connections.
diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py
index e0b4ad31..825900f6 100644
--- a/synapse/replication/tcp/protocol.py
+++ b/synapse/replication/tcp/protocol.py
@@ -46,7 +46,6 @@ indicate which side is sending, these are *not* included on the wire::
> ERROR server stopping
* connection closed by server *
"""
-import abc
import fcntl
import logging
import struct
@@ -54,8 +53,10 @@ from inspect import isawaitable
from typing import TYPE_CHECKING, List, Optional
from prometheus_client import Counter
+from zope.interface import Interface, implementer
from twisted.internet import task
+from twisted.internet.tcp import Connection
from twisted.protocols.basic import LineOnlyReceiver
from twisted.python.failure import Failure
@@ -121,6 +122,14 @@ class ConnectionStates:
CLOSED = "closed"
+class IReplicationConnection(Interface):
+ """An interface for replication connections."""
+
+ def send_command(cmd: Command):
+ """Send the command down the connection"""
+
+
+@implementer(IReplicationConnection)
class BaseReplicationStreamProtocol(LineOnlyReceiver):
"""Base replication protocol shared between client and server.
@@ -137,6 +146,10 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
(if they send a `PING` command)
"""
+ # The transport is going to be an ITCPTransport, but that doesn't have the
+ # (un)registerProducer methods, those are only on the implementation.
+ transport = None # type: Connection
+
delimiter = b"\n"
# Valid commands we expect to receive
@@ -181,6 +194,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
connected_connections.append(self) # Register connection for metrics
+ assert self.transport is not None
self.transport.registerProducer(self, True) # For the *Producing callbacks
self._send_pending_commands()
@@ -205,6 +219,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
logger.info(
"[%s] Failed to close connection gracefully, aborting", self.id()
)
+ assert self.transport is not None
self.transport.abortConnection()
else:
if now - self.last_sent_command >= PING_TIME:
@@ -294,6 +309,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
def close(self):
logger.warning("[%s] Closing connection", self.id())
self.time_we_closed = self.clock.time_msec()
+ assert self.transport is not None
self.transport.loseConnection()
self.on_connection_closed()
@@ -391,6 +407,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
def connectionLost(self, reason):
logger.info("[%s] Replication connection closed: %r", self.id(), reason)
if isinstance(reason, Failure):
+ assert reason.type is not None
connection_close_counter.labels(reason.type.__name__).inc()
else:
connection_close_counter.labels(reason.__class__.__name__).inc()
@@ -495,20 +512,6 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
self.send_command(ReplicateCommand())
-class AbstractConnection(abc.ABC):
- """An interface for replication connections."""
-
- @abc.abstractmethod
- def send_command(self, cmd: Command):
- """Send the command down the connection"""
- pass
-
-
-# This tells python that `BaseReplicationStreamProtocol` implements the
-# interface.
-AbstractConnection.register(BaseReplicationStreamProtocol)
-
-
# The following simply registers metrics for the replication connections
pending_commands = LaterGauge(
diff --git a/synapse/replication/tcp/redis.py b/synapse/replication/tcp/redis.py
index 0e6155cf..2f4d407f 100644
--- a/synapse/replication/tcp/redis.py
+++ b/synapse/replication/tcp/redis.py
@@ -19,6 +19,11 @@ from typing import TYPE_CHECKING, Generic, Optional, Type, TypeVar, cast
import attr
import txredisapi
+from zope.interface import implementer
+
+from twisted.internet.address import IPv4Address, IPv6Address
+from twisted.internet.interfaces import IAddress, IConnector
+from twisted.python.failure import Failure
from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable
from synapse.metrics.background_process_metrics import (
@@ -32,7 +37,7 @@ from synapse.replication.tcp.commands import (
parse_command_from_line,
)
from synapse.replication.tcp.protocol import (
- AbstractConnection,
+ IReplicationConnection,
tcp_inbound_commands_counter,
tcp_outbound_commands_counter,
)
@@ -62,7 +67,8 @@ class ConstantProperty(Generic[T, V]):
pass
-class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection):
+@implementer(IReplicationConnection)
+class RedisSubscriber(txredisapi.SubscriberProtocol):
"""Connection to redis subscribed to replication stream.
This class fulfils two functions:
@@ -71,7 +77,7 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection):
connection, parsing *incoming* messages into replication commands, and passing them
to `ReplicationCommandHandler`
- (b) it implements the AbstractConnection API, where it sends *outgoing* commands
+ (b) it implements the IReplicationConnection API, where it sends *outgoing* commands
onto outbound_redis_connection.
Due to the vagaries of `txredisapi` we don't want to have a custom
@@ -253,6 +259,37 @@ class SynapseRedisFactory(txredisapi.RedisFactory):
except Exception:
logger.warning("Failed to send ping to a redis connection")
+ # ReconnectingClientFactory has some logging (if you enable `self.noisy`), but
+ # it's rubbish. We add our own here.
+
+ def startedConnecting(self, connector: IConnector):
+ logger.info(
+ "Connecting to redis server %s", format_address(connector.getDestination())
+ )
+ super().startedConnecting(connector)
+
+ def clientConnectionFailed(self, connector: IConnector, reason: Failure):
+ logger.info(
+ "Connection to redis server %s failed: %s",
+ format_address(connector.getDestination()),
+ reason.value,
+ )
+ super().clientConnectionFailed(connector, reason)
+
+ def clientConnectionLost(self, connector: IConnector, reason: Failure):
+ logger.info(
+ "Connection to redis server %s lost: %s",
+ format_address(connector.getDestination()),
+ reason.value,
+ )
+ super().clientConnectionLost(connector, reason)
+
+
+def format_address(address: IAddress) -> str:
+ if isinstance(address, (IPv4Address, IPv6Address)):
+ return "%s:%i" % (address.host, address.port)
+ return str(address)
+
class RedisDirectTcpReplicationClientFactory(SynapseRedisFactory):
"""This is a reconnecting factory that connects to redis and immediately
@@ -328,6 +365,6 @@ def lazyConnection(
factory.continueTrying = reconnect
reactor = hs.get_reactor()
- reactor.connectTCP(host, port, factory, 30)
+ reactor.connectTCP(host.encode(), port, factory, timeout=30, bindAddress=None)
return factory.handler
diff --git a/synapse/rest/admin/_base.py b/synapse/rest/admin/_base.py
index e09234c6..7681e55b 100644
--- a/synapse/rest/admin/_base.py
+++ b/synapse/rest/admin/_base.py
@@ -15,10 +15,9 @@
import re
-import twisted.web.server
-
-import synapse.api.auth
+from synapse.api.auth import Auth
from synapse.api.errors import AuthError
+from synapse.http.site import SynapseRequest
from synapse.types import UserID
@@ -37,13 +36,11 @@ def admin_patterns(path_regex: str, version: str = "v1"):
return patterns
-async def assert_requester_is_admin(
- auth: synapse.api.auth.Auth, request: twisted.web.server.Request
-) -> None:
+async def assert_requester_is_admin(auth: Auth, request: SynapseRequest) -> None:
"""Verify that the requester is an admin user
Args:
- auth: api.auth.Auth singleton
+ auth: Auth singleton
request: incoming request
Raises:
@@ -53,11 +50,11 @@ async def assert_requester_is_admin(
await assert_user_is_admin(auth, requester.user)
-async def assert_user_is_admin(auth: synapse.api.auth.Auth, user_id: UserID) -> None:
+async def assert_user_is_admin(auth: Auth, user_id: UserID) -> None:
"""Verify that the given user is an admin user
Args:
- auth: api.auth.Auth singleton
+ auth: Auth singleton
user_id: user to check
Raises:
diff --git a/synapse/rest/admin/media.py b/synapse/rest/admin/media.py
index 511c859f..7fcc48a9 100644
--- a/synapse/rest/admin/media.py
+++ b/synapse/rest/admin/media.py
@@ -17,10 +17,9 @@
import logging
from typing import TYPE_CHECKING, Tuple
-from twisted.web.server import Request
-
from synapse.api.errors import AuthError, Codes, NotFoundError, SynapseError
from synapse.http.servlet import RestServlet, parse_boolean, parse_integer
+from synapse.http.site import SynapseRequest
from synapse.rest.admin._base import (
admin_patterns,
assert_requester_is_admin,
@@ -50,7 +49,9 @@ class QuarantineMediaInRoom(RestServlet):
self.store = hs.get_datastore()
self.auth = hs.get_auth()
- async def on_POST(self, request: Request, room_id: str) -> Tuple[int, JsonDict]:
+ async def on_POST(
+ self, request: SynapseRequest, room_id: str
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
await assert_user_is_admin(self.auth, requester.user)
@@ -75,7 +76,9 @@ class QuarantineMediaByUser(RestServlet):
self.store = hs.get_datastore()
self.auth = hs.get_auth()
- async def on_POST(self, request: Request, user_id: str) -> Tuple[int, JsonDict]:
+ async def on_POST(
+ self, request: SynapseRequest, user_id: str
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
await assert_user_is_admin(self.auth, requester.user)
@@ -103,7 +106,7 @@ class QuarantineMediaByID(RestServlet):
self.auth = hs.get_auth()
async def on_POST(
- self, request: Request, server_name: str, media_id: str
+ self, request: SynapseRequest, server_name: str, media_id: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
await assert_user_is_admin(self.auth, requester.user)
@@ -127,7 +130,9 @@ class ProtectMediaByID(RestServlet):
self.store = hs.get_datastore()
self.auth = hs.get_auth()
- async def on_POST(self, request: Request, media_id: str) -> Tuple[int, JsonDict]:
+ async def on_POST(
+ self, request: SynapseRequest, media_id: str
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
await assert_user_is_admin(self.auth, requester.user)
@@ -148,7 +153,9 @@ class ListMediaInRoom(RestServlet):
self.store = hs.get_datastore()
self.auth = hs.get_auth()
- async def on_GET(self, request: Request, room_id: str) -> Tuple[int, JsonDict]:
+ async def on_GET(
+ self, request: SynapseRequest, room_id: str
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
is_admin = await self.auth.is_server_admin(requester.user)
if not is_admin:
@@ -166,7 +173,7 @@ class PurgeMediaCacheRestServlet(RestServlet):
self.media_repository = hs.get_media_repository()
self.auth = hs.get_auth()
- async def on_POST(self, request: Request) -> Tuple[int, JsonDict]:
+ async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self.auth, request)
before_ts = parse_integer(request, "before_ts", required=True)
@@ -189,7 +196,7 @@ class DeleteMediaByID(RestServlet):
self.media_repository = hs.get_media_repository()
async def on_DELETE(
- self, request: Request, server_name: str, media_id: str
+ self, request: SynapseRequest, server_name: str, media_id: str
) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self.auth, request)
@@ -218,7 +225,9 @@ class DeleteMediaByDateSize(RestServlet):
self.server_name = hs.hostname
self.media_repository = hs.get_media_repository()
- async def on_POST(self, request: Request, server_name: str) -> Tuple[int, JsonDict]:
+ async def on_POST(
+ self, request: SynapseRequest, server_name: str
+ ) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self.auth, request)
before_ts = parse_integer(request, "before_ts", required=True)
diff --git a/synapse/rest/admin/purge_room_servlet.py b/synapse/rest/admin/purge_room_servlet.py
index 8b7bb6d4..49966ee3 100644
--- a/synapse/rest/admin/purge_room_servlet.py
+++ b/synapse/rest/admin/purge_room_servlet.py
@@ -12,13 +12,20 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+from typing import TYPE_CHECKING, Tuple
+
from synapse.http.servlet import (
RestServlet,
assert_params_in_dict,
parse_json_object_from_request,
)
+from synapse.http.site import SynapseRequest
from synapse.rest.admin import assert_requester_is_admin
from synapse.rest.admin._base import admin_patterns
+from synapse.types import JsonDict
+
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
class PurgeRoomServlet(RestServlet):
@@ -36,16 +43,12 @@ class PurgeRoomServlet(RestServlet):
PATTERNS = admin_patterns("/purge_room$")
- def __init__(self, hs):
- """
- Args:
- hs (synapse.server.HomeServer): server
- """
+ def __init__(self, hs: "HomeServer"):
self.hs = hs
self.auth = hs.get_auth()
self.pagination_handler = hs.get_pagination_handler()
- async def on_POST(self, request):
+ async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self.auth, request)
body = parse_json_object_from_request(request)
diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py
index f2c42a0f..263d8ec0 100644
--- a/synapse/rest/admin/rooms.py
+++ b/synapse/rest/admin/rooms.py
@@ -685,7 +685,10 @@ class RoomEventContextServlet(RestServlet):
results["events_after"], time_now
)
results["state"] = await self._event_serializer.serialize_events(
- results["state"], time_now
+ results["state"],
+ time_now,
+ # No need to bundle aggregations for state events
+ bundle_aggregations=False,
)
return 200, results
diff --git a/synapse/rest/admin/server_notice_servlet.py b/synapse/rest/admin/server_notice_servlet.py
index 375d0554..f495666f 100644
--- a/synapse/rest/admin/server_notice_servlet.py
+++ b/synapse/rest/admin/server_notice_servlet.py
@@ -12,17 +12,24 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+from typing import TYPE_CHECKING, Optional, Tuple
+
from synapse.api.constants import EventTypes
from synapse.api.errors import SynapseError
+from synapse.http.server import HttpServer
from synapse.http.servlet import (
RestServlet,
assert_params_in_dict,
parse_json_object_from_request,
)
+from synapse.http.site import SynapseRequest
from synapse.rest.admin import assert_requester_is_admin
from synapse.rest.admin._base import admin_patterns
from synapse.rest.client.transactions import HttpTransactionCache
-from synapse.types import UserID
+from synapse.types import JsonDict, UserID
+
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
class SendServerNoticeServlet(RestServlet):
@@ -44,17 +51,13 @@ class SendServerNoticeServlet(RestServlet):
}
"""
- def __init__(self, hs):
- """
- Args:
- hs (synapse.server.HomeServer): server
- """
+ def __init__(self, hs: "HomeServer"):
self.hs = hs
self.auth = hs.get_auth()
self.txns = HttpTransactionCache(hs)
self.snm = hs.get_server_notices_manager()
- def register(self, json_resource):
+ def register(self, json_resource: HttpServer):
PATTERN = "/send_server_notice"
json_resource.register_paths(
"POST", admin_patterns(PATTERN + "$"), self.on_POST, self.__class__.__name__
@@ -66,7 +69,9 @@ class SendServerNoticeServlet(RestServlet):
self.__class__.__name__,
)
- async def on_POST(self, request, txn_id=None):
+ async def on_POST(
+ self, request: SynapseRequest, txn_id: Optional[str] = None
+ ) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self.auth, request)
body = parse_json_object_from_request(request)
assert_params_in_dict(body, ("user_id", "content"))
@@ -90,7 +95,7 @@ class SendServerNoticeServlet(RestServlet):
return 200, {"event_id": event.event_id}
- def on_PUT(self, request, txn_id):
+ def on_PUT(self, request: SynapseRequest, txn_id: str) -> Tuple[int, JsonDict]:
return self.txns.fetch_or_execute_request(
request, self.on_POST, request, txn_id
)
diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py
index 267a9934..2c89b62e 100644
--- a/synapse/rest/admin/users.py
+++ b/synapse/rest/admin/users.py
@@ -269,7 +269,10 @@ class UserRestServletV2(RestServlet):
target_user.to_string(), False, requester, by_admin=True
)
elif not deactivate and user["deactivated"]:
- if "password" not in body:
+ if (
+ "password" not in body
+ and self.hs.config.password_localdb_enabled
+ ):
raise SynapseError(
400, "Must provide a password to re-activate an account."
)
diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py
index 925edfc4..e4c352f5 100644
--- a/synapse/rest/client/v1/login.py
+++ b/synapse/rest/client/v1/login.py
@@ -14,10 +14,12 @@
# limitations under the License.
import logging
+import re
from typing import TYPE_CHECKING, Awaitable, Callable, Dict, Optional
from synapse.api.errors import Codes, LoginError, SynapseError
from synapse.api.ratelimiting import Ratelimiter
+from synapse.api.urls import CLIENT_API_PREFIX
from synapse.appservice import ApplicationService
from synapse.handlers.sso import SsoIdentityProvider
from synapse.http import get_request_uri
@@ -94,11 +96,21 @@ class LoginRestServlet(RestServlet):
flows.append({"type": LoginRestServlet.CAS_TYPE})
if self.cas_enabled or self.saml2_enabled or self.oidc_enabled:
- sso_flow = {"type": LoginRestServlet.SSO_TYPE} # type: JsonDict
+ sso_flow = {
+ "type": LoginRestServlet.SSO_TYPE,
+ "identity_providers": [
+ _get_auth_flow_dict_for_idp(
+ idp,
+ )
+ for idp in self._sso_handler.get_identity_providers().values()
+ ],
+ } # type: JsonDict
if self._msc2858_enabled:
+ # backwards-compatibility support for clients which don't
+ # support the stable API yet
sso_flow["org.matrix.msc2858.identity_providers"] = [
- _get_auth_flow_dict_for_idp(idp)
+ _get_auth_flow_dict_for_idp(idp, use_unstable_brands=True)
for idp in self._sso_handler.get_identity_providers().values()
]
@@ -219,6 +231,7 @@ class LoginRestServlet(RestServlet):
callback: Optional[Callable[[Dict[str, str]], Awaitable[None]]] = None,
create_non_existent_users: bool = False,
ratelimit: bool = True,
+ auth_provider_id: Optional[str] = None,
) -> Dict[str, str]:
"""Called when we've successfully authed the user and now need to
actually login them in (e.g. create devices). This gets called on
@@ -234,6 +247,8 @@ class LoginRestServlet(RestServlet):
create_non_existent_users: Whether to create the user if they don't
exist. Defaults to False.
ratelimit: Whether to ratelimit the login request.
+ auth_provider_id: The SSO IdP the user used, if any (just used for the
+ prometheus metrics).
Returns:
result: Dictionary of account information after successful login.
@@ -256,7 +271,7 @@ class LoginRestServlet(RestServlet):
device_id = login_submission.get("device_id")
initial_display_name = login_submission.get("initial_device_display_name")
device_id, access_token = await self.registration_handler.register_device(
- user_id, device_id, initial_display_name
+ user_id, device_id, initial_display_name, auth_provider_id=auth_provider_id
)
result = {
@@ -283,12 +298,13 @@ class LoginRestServlet(RestServlet):
"""
token = login_submission["token"]
auth_handler = self.auth_handler
- user_id = await auth_handler.validate_short_term_login_token_and_get_user_id(
- token
- )
+ res = await auth_handler.validate_short_term_login_token(token)
return await self._complete_login(
- user_id, login_submission, self.auth_handler._sso_login_callback
+ res.user_id,
+ login_submission,
+ self.auth_handler._sso_login_callback,
+ auth_provider_id=res.auth_provider_id,
)
async def _do_jwt_login(self, login_submission: JsonDict) -> Dict[str, str]:
@@ -327,22 +343,38 @@ class LoginRestServlet(RestServlet):
return result
-def _get_auth_flow_dict_for_idp(idp: SsoIdentityProvider) -> JsonDict:
+def _get_auth_flow_dict_for_idp(
+ idp: SsoIdentityProvider, use_unstable_brands: bool = False
+) -> JsonDict:
"""Return an entry for the login flow dict
Returns an entry suitable for inclusion in "identity_providers" in the
response to GET /_matrix/client/r0/login
+
+ Args:
+ idp: the identity provider to describe
+ use_unstable_brands: whether we should use brand identifiers suitable
+ for the unstable API
"""
e = {"id": idp.idp_id, "name": idp.idp_name} # type: JsonDict
if idp.idp_icon:
e["icon"] = idp.idp_icon
if idp.idp_brand:
e["brand"] = idp.idp_brand
+ # use the stable brand identifier if the unstable identifier isn't defined.
+ if use_unstable_brands and idp.unstable_idp_brand:
+ e["brand"] = idp.unstable_idp_brand
return e
class SsoRedirectServlet(RestServlet):
- PATTERNS = client_patterns("/login/(cas|sso)/redirect$", v1=True)
+ PATTERNS = list(client_patterns("/login/(cas|sso)/redirect$", v1=True)) + [
+ re.compile(
+ "^"
+ + CLIENT_API_PREFIX
+ + "/r0/login/sso/redirect/(?P<idp_id>[A-Za-z0-9_.~-]+)$"
+ )
+ ]
def __init__(self, hs: "HomeServer"):
# make sure that the relevant handlers are instantiated, so that they
@@ -360,7 +392,8 @@ class SsoRedirectServlet(RestServlet):
def register(self, http_server: HttpServer) -> None:
super().register(http_server)
if self._msc2858_enabled:
- # expose additional endpoint for MSC2858 support
+ # expose additional endpoint for MSC2858 support: backwards-compat support
+ # for clients which don't yet support the stable endpoints.
http_server.register_paths(
"GET",
client_patterns(
diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py
index 9a1df30c..5884daea 100644
--- a/synapse/rest/client/v1/room.py
+++ b/synapse/rest/client/v1/room.py
@@ -671,7 +671,10 @@ class RoomEventContextServlet(RestServlet):
results["events_after"], time_now
)
results["state"] = await self._event_serializer.serialize_events(
- results["state"], time_now
+ results["state"],
+ time_now,
+ # No need to bundle aggregations for state events
+ bundle_aggregations=False,
)
return 200, results
diff --git a/synapse/rest/client/v2_alpha/groups.py b/synapse/rest/client/v2_alpha/groups.py
index 7aea4ceb..5901432f 100644
--- a/synapse/rest/client/v2_alpha/groups.py
+++ b/synapse/rest/client/v2_alpha/groups.py
@@ -32,6 +32,7 @@ from synapse.http.servlet import (
assert_params_in_dict,
parse_json_object_from_request,
)
+from synapse.http.site import SynapseRequest
from synapse.types import GroupID, JsonDict
from ._base import client_patterns
@@ -70,7 +71,9 @@ class GroupServlet(RestServlet):
self.groups_handler = hs.get_groups_local_handler()
@_validate_group_id
- async def on_GET(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
+ async def on_GET(
+ self, request: SynapseRequest, group_id: str
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
requester_user_id = requester.user.to_string()
@@ -81,7 +84,9 @@ class GroupServlet(RestServlet):
return 200, group_description
@_validate_group_id
- async def on_POST(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
+ async def on_POST(
+ self, request: SynapseRequest, group_id: str
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
@@ -111,7 +116,9 @@ class GroupSummaryServlet(RestServlet):
self.groups_handler = hs.get_groups_local_handler()
@_validate_group_id
- async def on_GET(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
+ async def on_GET(
+ self, request: SynapseRequest, group_id: str
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
requester_user_id = requester.user.to_string()
@@ -144,7 +151,11 @@ class GroupSummaryRoomsCatServlet(RestServlet):
@_validate_group_id
async def on_PUT(
- self, request: Request, group_id: str, category_id: Optional[str], room_id: str
+ self,
+ request: SynapseRequest,
+ group_id: str,
+ category_id: Optional[str],
+ room_id: str,
):
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
@@ -176,7 +187,7 @@ class GroupSummaryRoomsCatServlet(RestServlet):
@_validate_group_id
async def on_DELETE(
- self, request: Request, group_id: str, category_id: str, room_id: str
+ self, request: SynapseRequest, group_id: str, category_id: str, room_id: str
):
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
@@ -206,7 +217,7 @@ class GroupCategoryServlet(RestServlet):
@_validate_group_id
async def on_GET(
- self, request: Request, group_id: str, category_id: str
+ self, request: SynapseRequest, group_id: str, category_id: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
requester_user_id = requester.user.to_string()
@@ -219,7 +230,7 @@ class GroupCategoryServlet(RestServlet):
@_validate_group_id
async def on_PUT(
- self, request: Request, group_id: str, category_id: str
+ self, request: SynapseRequest, group_id: str, category_id: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
@@ -247,7 +258,7 @@ class GroupCategoryServlet(RestServlet):
@_validate_group_id
async def on_DELETE(
- self, request: Request, group_id: str, category_id: str
+ self, request: SynapseRequest, group_id: str, category_id: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
@@ -274,7 +285,9 @@ class GroupCategoriesServlet(RestServlet):
self.groups_handler = hs.get_groups_local_handler()
@_validate_group_id
- async def on_GET(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
+ async def on_GET(
+ self, request: SynapseRequest, group_id: str
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
requester_user_id = requester.user.to_string()
@@ -298,7 +311,7 @@ class GroupRoleServlet(RestServlet):
@_validate_group_id
async def on_GET(
- self, request: Request, group_id: str, role_id: str
+ self, request: SynapseRequest, group_id: str, role_id: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
requester_user_id = requester.user.to_string()
@@ -311,7 +324,7 @@ class GroupRoleServlet(RestServlet):
@_validate_group_id
async def on_PUT(
- self, request: Request, group_id: str, role_id: str
+ self, request: SynapseRequest, group_id: str, role_id: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
@@ -339,7 +352,7 @@ class GroupRoleServlet(RestServlet):
@_validate_group_id
async def on_DELETE(
- self, request: Request, group_id: str, role_id: str
+ self, request: SynapseRequest, group_id: str, role_id: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
@@ -366,7 +379,9 @@ class GroupRolesServlet(RestServlet):
self.groups_handler = hs.get_groups_local_handler()
@_validate_group_id
- async def on_GET(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
+ async def on_GET(
+ self, request: SynapseRequest, group_id: str
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
requester_user_id = requester.user.to_string()
@@ -399,7 +414,11 @@ class GroupSummaryUsersRoleServlet(RestServlet):
@_validate_group_id
async def on_PUT(
- self, request: Request, group_id: str, role_id: Optional[str], user_id: str
+ self,
+ request: SynapseRequest,
+ group_id: str,
+ role_id: Optional[str],
+ user_id: str,
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
@@ -431,7 +450,7 @@ class GroupSummaryUsersRoleServlet(RestServlet):
@_validate_group_id
async def on_DELETE(
- self, request: Request, group_id: str, role_id: str, user_id: str
+ self, request: SynapseRequest, group_id: str, role_id: str, user_id: str
):
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
@@ -458,7 +477,9 @@ class GroupRoomServlet(RestServlet):
self.groups_handler = hs.get_groups_local_handler()
@_validate_group_id
- async def on_GET(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
+ async def on_GET(
+ self, request: SynapseRequest, group_id: str
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
requester_user_id = requester.user.to_string()
@@ -481,7 +502,9 @@ class GroupUsersServlet(RestServlet):
self.groups_handler = hs.get_groups_local_handler()
@_validate_group_id
- async def on_GET(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
+ async def on_GET(
+ self, request: SynapseRequest, group_id: str
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
requester_user_id = requester.user.to_string()
@@ -504,7 +527,9 @@ class GroupInvitedUsersServlet(RestServlet):
self.groups_handler = hs.get_groups_local_handler()
@_validate_group_id
- async def on_GET(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
+ async def on_GET(
+ self, request: SynapseRequest, group_id: str
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
@@ -526,7 +551,9 @@ class GroupSettingJoinPolicyServlet(RestServlet):
self.groups_handler = hs.get_groups_local_handler()
@_validate_group_id
- async def on_PUT(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
+ async def on_PUT(
+ self, request: SynapseRequest, group_id: str
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
@@ -554,7 +581,7 @@ class GroupCreateServlet(RestServlet):
self.groups_handler = hs.get_groups_local_handler()
self.server_name = hs.hostname
- async def on_POST(self, request: Request) -> Tuple[int, JsonDict]:
+ async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
@@ -598,7 +625,7 @@ class GroupAdminRoomsServlet(RestServlet):
@_validate_group_id
async def on_PUT(
- self, request: Request, group_id: str, room_id: str
+ self, request: SynapseRequest, group_id: str, room_id: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
@@ -615,7 +642,7 @@ class GroupAdminRoomsServlet(RestServlet):
@_validate_group_id
async def on_DELETE(
- self, request: Request, group_id: str, room_id: str
+ self, request: SynapseRequest, group_id: str, room_id: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
@@ -646,7 +673,7 @@ class GroupAdminRoomsConfigServlet(RestServlet):
@_validate_group_id
async def on_PUT(
- self, request: Request, group_id: str, room_id: str, config_key: str
+ self, request: SynapseRequest, group_id: str, room_id: str, config_key: str
):
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
@@ -678,7 +705,9 @@ class GroupAdminUsersInviteServlet(RestServlet):
self.is_mine_id = hs.is_mine_id
@_validate_group_id
- async def on_PUT(self, request: Request, group_id, user_id) -> Tuple[int, JsonDict]:
+ async def on_PUT(
+ self, request: SynapseRequest, group_id, user_id
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
@@ -708,7 +737,9 @@ class GroupAdminUsersKickServlet(RestServlet):
self.groups_handler = hs.get_groups_local_handler()
@_validate_group_id
- async def on_PUT(self, request: Request, group_id, user_id) -> Tuple[int, JsonDict]:
+ async def on_PUT(
+ self, request: SynapseRequest, group_id, user_id
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
@@ -735,7 +766,9 @@ class GroupSelfLeaveServlet(RestServlet):
self.groups_handler = hs.get_groups_local_handler()
@_validate_group_id
- async def on_PUT(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
+ async def on_PUT(
+ self, request: SynapseRequest, group_id: str
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
@@ -762,7 +795,9 @@ class GroupSelfJoinServlet(RestServlet):
self.groups_handler = hs.get_groups_local_handler()
@_validate_group_id
- async def on_PUT(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
+ async def on_PUT(
+ self, request: SynapseRequest, group_id: str
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
@@ -789,7 +824,9 @@ class GroupSelfAcceptInviteServlet(RestServlet):
self.groups_handler = hs.get_groups_local_handler()
@_validate_group_id
- async def on_PUT(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
+ async def on_PUT(
+ self, request: SynapseRequest, group_id: str
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
@@ -816,7 +853,9 @@ class GroupSelfUpdatePublicityServlet(RestServlet):
self.store = hs.get_datastore()
@_validate_group_id
- async def on_PUT(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
+ async def on_PUT(
+ self, request: SynapseRequest, group_id: str
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
@@ -839,7 +878,9 @@ class PublicisedGroupsForUserServlet(RestServlet):
self.store = hs.get_datastore()
self.groups_handler = hs.get_groups_local_handler()
- async def on_GET(self, request: Request, user_id: str) -> Tuple[int, JsonDict]:
+ async def on_GET(
+ self, request: SynapseRequest, user_id: str
+ ) -> Tuple[int, JsonDict]:
await self.auth.get_user_by_req(request, allow_guest=True)
result = await self.groups_handler.get_publicised_groups_for_user(user_id)
@@ -859,7 +900,7 @@ class PublicisedGroupsForUsersServlet(RestServlet):
self.store = hs.get_datastore()
self.groups_handler = hs.get_groups_local_handler()
- async def on_POST(self, request: Request) -> Tuple[int, JsonDict]:
+ async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
await self.auth.get_user_by_req(request, allow_guest=True)
content = parse_json_object_from_request(request)
@@ -881,7 +922,7 @@ class GroupsForUserServlet(RestServlet):
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
- async def on_GET(self, request: Request) -> Tuple[int, JsonDict]:
+ async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
requester_user_id = requester.user.to_string()
diff --git a/synapse/rest/media/v1/config_resource.py b/synapse/rest/media/v1/config_resource.py
index 9039662f..1eff98ef 100644
--- a/synapse/rest/media/v1/config_resource.py
+++ b/synapse/rest/media/v1/config_resource.py
@@ -20,6 +20,7 @@ from typing import TYPE_CHECKING
from twisted.web.server import Request
from synapse.http.server import DirectServeJsonResource, respond_with_json
+from synapse.http.site import SynapseRequest
if TYPE_CHECKING:
from synapse.app.homeserver import HomeServer
@@ -35,7 +36,7 @@ class MediaConfigResource(DirectServeJsonResource):
self.auth = hs.get_auth()
self.limits_dict = {"m.upload.size": config.max_upload_size}
- async def _async_render_GET(self, request: Request) -> None:
+ async def _async_render_GET(self, request: SynapseRequest) -> None:
await self.auth.get_user_by_req(request)
respond_with_json(request, 200, self.limits_dict, send_cors=True)
diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py
index 0641924f..8b4841ed 100644
--- a/synapse/rest/media/v1/media_repository.py
+++ b/synapse/rest/media/v1/media_repository.py
@@ -35,6 +35,7 @@ from synapse.api.errors import (
from synapse.config._base import ConfigError
from synapse.logging.context import defer_to_thread
from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.types import UserID
from synapse.util.async_helpers import Linearizer
from synapse.util.retryutils import NotRetryingDestination
from synapse.util.stringutils import random_string
@@ -145,7 +146,7 @@ class MediaRepository:
upload_name: Optional[str],
content: IO,
content_length: int,
- auth_user: str,
+ auth_user: UserID,
) -> str:
"""Store uploaded content for a local user and return the mxc URL
diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py
index a074e807..b8895aea 100644
--- a/synapse/rest/media/v1/preview_url_resource.py
+++ b/synapse/rest/media/v1/preview_url_resource.py
@@ -39,6 +39,7 @@ from synapse.http.server import (
respond_with_json_bytes,
)
from synapse.http.servlet import parse_integer, parse_string
+from synapse.http.site import SynapseRequest
from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.rest.media.v1._base import get_filename_from_headers
@@ -185,7 +186,7 @@ class PreviewUrlResource(DirectServeJsonResource):
request.setHeader(b"Allow", b"OPTIONS, GET")
respond_with_json(request, 200, {}, send_cors=True)
- async def _async_render_GET(self, request: Request) -> None:
+ async def _async_render_GET(self, request: SynapseRequest) -> None:
# XXX: if get_user_by_req fails, what should we do in an async render?
requester = await self.auth.get_user_by_req(request)
diff --git a/synapse/rest/media/v1/thumbnailer.py b/synapse/rest/media/v1/thumbnailer.py
index 07903e40..988f52c7 100644
--- a/synapse/rest/media/v1/thumbnailer.py
+++ b/synapse/rest/media/v1/thumbnailer.py
@@ -96,9 +96,14 @@ class Thumbnailer:
def _resize(self, width: int, height: int) -> Image:
# 1-bit or 8-bit color palette images need converting to RGB
# otherwise they will be scaled using nearest neighbour which
- # looks awful
- if self.image.mode in ["1", "P"]:
- self.image = self.image.convert("RGB")
+ # looks awful.
+ #
+ # If the image has transparency, use RGBA instead.
+ if self.image.mode in ["1", "L", "P"]:
+ mode = "RGB"
+ if self.image.info.get("transparency", None) is not None:
+ mode = "RGBA"
+ self.image = self.image.convert(mode)
return self.image.resize((width, height), Image.ANTIALIAS)
def scale(self, width: int, height: int, output_type: str) -> BytesIO:
diff --git a/synapse/rest/media/v1/upload_resource.py b/synapse/rest/media/v1/upload_resource.py
index 5e104fac..ae5aef2f 100644
--- a/synapse/rest/media/v1/upload_resource.py
+++ b/synapse/rest/media/v1/upload_resource.py
@@ -22,6 +22,7 @@ from twisted.web.server import Request
from synapse.api.errors import Codes, SynapseError
from synapse.http.server import DirectServeJsonResource, respond_with_json
from synapse.http.servlet import parse_string
+from synapse.http.site import SynapseRequest
from synapse.rest.media.v1.media_storage import SpamMediaException
if TYPE_CHECKING:
@@ -49,7 +50,7 @@ class UploadResource(DirectServeJsonResource):
async def _async_render_OPTIONS(self, request: Request) -> None:
respond_with_json(request, 200, {}, send_cors=True)
- async def _async_render_POST(self, request: Request) -> None:
+ async def _async_render_POST(self, request: SynapseRequest) -> None:
requester = await self.auth.get_user_by_req(request)
# TODO: The checks here are a bit late. The content will have
# already been uploaded to a tmp file at this point
diff --git a/synapse/rest/synapse/client/saml2/response_resource.py b/synapse/rest/synapse/client/saml2/response_resource.py
index f6668fb5..4dfadf1b 100644
--- a/synapse/rest/synapse/client/saml2/response_resource.py
+++ b/synapse/rest/synapse/client/saml2/response_resource.py
@@ -14,24 +14,30 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from typing import TYPE_CHECKING
+
from synapse.http.server import DirectServeHtmlResource
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
class SAML2ResponseResource(DirectServeHtmlResource):
"""A Twisted web resource which handles the SAML response"""
isLeaf = 1
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self._saml_handler = hs.get_saml_handler()
+ self._sso_handler = hs.get_sso_handler()
async def _async_render_GET(self, request):
# We're not expecting any GET request on that resource if everything goes right,
# but some IdPs sometimes end up responding with a 302 redirect on this endpoint.
# In this case, just tell the user that something went wrong and they should
# try to authenticate again.
- self._saml_handler._render_error(
+ self._sso_handler.render_error(
request, "unexpected_get", "Unexpected GET request on /saml2/authn_response"
)
diff --git a/synapse/server.py b/synapse/server.py
index afd7cd72..48ac87a1 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -36,7 +36,6 @@ from typing import (
cast,
)
-import twisted.internet.base
import twisted.internet.tcp
from twisted.internet import defer
from twisted.mail.smtp import sendmail
@@ -130,7 +129,7 @@ from synapse.server_notices.worker_server_notices_sender import (
from synapse.state import StateHandler, StateResolutionHandler
from synapse.storage import Databases, DataStore, Storage
from synapse.streams.events import EventSources
-from synapse.types import DomainSpecificString
+from synapse.types import DomainSpecificString, ISynapseReactor
from synapse.util import Clock
from synapse.util.distributor import Distributor
from synapse.util.ratelimitutils import FederationRateLimiter
@@ -291,7 +290,7 @@ class HomeServer(metaclass=abc.ABCMeta):
for i in self.REQUIRED_ON_BACKGROUND_TASK_STARTUP:
getattr(self, "get_" + i + "_handler")()
- def get_reactor(self) -> twisted.internet.base.ReactorBase:
+ def get_reactor(self) -> ISynapseReactor:
"""
Fetch the Twisted reactor in use by this HomeServer.
"""
@@ -352,11 +351,9 @@ class HomeServer(metaclass=abc.ABCMeta):
@cache_in_self
def get_http_client_context_factory(self) -> IPolicyForHTTPS:
- return (
- InsecureInterceptableContextFactory()
- if self.config.use_insecure_ssl_client_just_for_testing_do_not_use
- else RegularPolicyForHTTPS()
- )
+ if self.config.use_insecure_ssl_client_just_for_testing_do_not_use:
+ return InsecureInterceptableContextFactory()
+ return RegularPolicyForHTTPS()
@cache_in_self
def get_simple_http_client(self) -> SimpleHttpClient:
diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py
index 18ddb92f..332193ad 100644
--- a/synapse/storage/databases/main/event_federation.py
+++ b/synapse/storage/databases/main/event_federation.py
@@ -54,11 +54,12 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
) # type: LruCache[str, List[Tuple[str, int]]]
async def get_auth_chain(
- self, event_ids: Collection[str], include_given: bool = False
+ self, room_id: str, event_ids: Collection[str], include_given: bool = False
) -> List[EventBase]:
"""Get auth events for given event_ids. The events *must* be state events.
Args:
+ room_id: The room the event is in.
event_ids: state events
include_given: include the given events in result
@@ -66,24 +67,44 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
list of events
"""
event_ids = await self.get_auth_chain_ids(
- event_ids, include_given=include_given
+ room_id, event_ids, include_given=include_given
)
return await self.get_events_as_list(event_ids)
async def get_auth_chain_ids(
self,
+ room_id: str,
event_ids: Collection[str],
include_given: bool = False,
) -> List[str]:
"""Get auth events for given event_ids. The events *must* be state events.
Args:
+ room_id: The room the event is in.
event_ids: state events
include_given: include the given events in result
Returns:
- An awaitable which resolve to a list of event_ids
+ list of event_ids
"""
+
+ # Check if we have indexed the room so we can use the chain cover
+ # algorithm.
+ room = await self.get_room(room_id)
+ if room["has_auth_chain_index"]:
+ try:
+ return await self.db_pool.runInteraction(
+ "get_auth_chain_ids_chains",
+ self._get_auth_chain_ids_using_cover_index_txn,
+ room_id,
+ event_ids,
+ include_given,
+ )
+ except _NoChainCoverIndex:
+ # For whatever reason we don't actually have a chain cover index
+ # for the events in question, so we fall back to the old method.
+ pass
+
return await self.db_pool.runInteraction(
"get_auth_chain_ids",
self._get_auth_chain_ids_txn,
@@ -91,9 +112,130 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
include_given,
)
+ def _get_auth_chain_ids_using_cover_index_txn(
+ self, txn: Cursor, room_id: str, event_ids: Collection[str], include_given: bool
+ ) -> List[str]:
+ """Calculates the auth chain IDs using the chain index."""
+
+ # First we look up the chain ID/sequence numbers for the given events.
+
+ initial_events = set(event_ids)
+
+ # All the events that we've found that are reachable from the events.
+ seen_events = set() # type: Set[str]
+
+ # A map from chain ID to max sequence number of the given events.
+ event_chains = {} # type: Dict[int, int]
+
+ sql = """
+ SELECT event_id, chain_id, sequence_number
+ FROM event_auth_chains
+ WHERE %s
+ """
+ for batch in batch_iter(initial_events, 1000):
+ clause, args = make_in_list_sql_clause(
+ txn.database_engine, "event_id", batch
+ )
+ txn.execute(sql % (clause,), args)
+
+ for event_id, chain_id, sequence_number in txn:
+ seen_events.add(event_id)
+ event_chains[chain_id] = max(
+ sequence_number, event_chains.get(chain_id, 0)
+ )
+
+ # Check that we actually have a chain ID for all the events.
+ events_missing_chain_info = initial_events.difference(seen_events)
+ if events_missing_chain_info:
+ # This can happen due to e.g. downgrade/upgrade of the server. We
+ # raise an exception and fall back to the previous algorithm.
+ logger.info(
+ "Unexpectedly found that events don't have chain IDs in room %s: %s",
+ room_id,
+ events_missing_chain_info,
+ )
+ raise _NoChainCoverIndex(room_id)
+
+ # Now we look up all links for the chains we have, adding chains that
+ # are reachable from any event.
+ sql = """
+ SELECT
+ origin_chain_id, origin_sequence_number,
+ target_chain_id, target_sequence_number
+ FROM event_auth_chain_links
+ WHERE %s
+ """
+
+ # A map from chain ID to max sequence number *reachable* from any event ID.
+ chains = {} # type: Dict[int, int]
+
+ # Add all linked chains reachable from initial set of chains.
+ for batch in batch_iter(event_chains, 1000):
+ clause, args = make_in_list_sql_clause(
+ txn.database_engine, "origin_chain_id", batch
+ )
+ txn.execute(sql % (clause,), args)
+
+ for (
+ origin_chain_id,
+ origin_sequence_number,
+ target_chain_id,
+ target_sequence_number,
+ ) in txn:
+ # chains are only reachable if the origin sequence number of
+ # the link is less than the max sequence number in the
+ # origin chain.
+ if origin_sequence_number <= event_chains.get(origin_chain_id, 0):
+ chains[target_chain_id] = max(
+ target_sequence_number,
+ chains.get(target_chain_id, 0),
+ )
+
+ # Add the initial set of chains, excluding the sequence corresponding to
+ # initial event.
+ for chain_id, seq_no in event_chains.items():
+ chains[chain_id] = max(seq_no - 1, chains.get(chain_id, 0))
+
+ # Now for each chain we figure out the maximum sequence number reachable
+ # from *any* event ID. Events with a sequence less than that are in the
+ # auth chain.
+ if include_given:
+ results = initial_events
+ else:
+ results = set()
+
+ if isinstance(self.database_engine, PostgresEngine):
+ # We can use `execute_values` to efficiently fetch the gaps when
+ # using postgres.
+ sql = """
+ SELECT event_id
+ FROM event_auth_chains AS c, (VALUES ?) AS l(chain_id, max_seq)
+ WHERE
+ c.chain_id = l.chain_id
+ AND sequence_number <= max_seq
+ """
+
+ rows = txn.execute_values(sql, chains.items())
+ results.update(r for r, in rows)
+ else:
+ # For SQLite we just fall back to doing a noddy for loop.
+ sql = """
+ SELECT event_id FROM event_auth_chains
+ WHERE chain_id = ? AND sequence_number <= ?
+ """
+ for chain_id, max_no in chains.items():
+ txn.execute(sql, (chain_id, max_no))
+ results.update(r for r, in txn)
+
+ return list(results)
+
def _get_auth_chain_ids_txn(
self, txn: LoggingTransaction, event_ids: Collection[str], include_given: bool
) -> List[str]:
+ """Calculates the auth chain IDs.
+
+ This is used when we don't have a cover index for the room.
+ """
if include_given:
results = set(event_ids)
else:
diff --git a/synapse/storage/databases/main/events_bg_updates.py b/synapse/storage/databases/main/events_bg_updates.py
index cb6b1f8a..78367ea5 100644
--- a/synapse/storage/databases/main/events_bg_updates.py
+++ b/synapse/storage/databases/main/events_bg_updates.py
@@ -135,6 +135,11 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
self._chain_cover_index,
)
+ self.db_pool.updates.register_background_update_handler(
+ "purged_chain_cover",
+ self._purged_chain_cover_index,
+ )
+
async def _background_reindex_fields_sender(self, progress, batch_size):
target_min_stream_id = progress["target_min_stream_id_inclusive"]
max_stream_id = progress["max_stream_id_exclusive"]
@@ -932,3 +937,77 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
processed_count=count,
finished_room_map=finished_rooms,
)
+
+ async def _purged_chain_cover_index(self, progress: dict, batch_size: int) -> int:
+ """
+ A background updates that iterates over the chain cover and deletes the
+ chain cover for events that have been purged.
+
+ This may be due to fully purging a room or via setting a retention policy.
+ """
+ current_event_id = progress.get("current_event_id", "")
+
+ def purged_chain_cover_txn(txn) -> int:
+ # The event ID from events will be null if the chain ID / sequence
+ # number points to a purged event.
+ sql = """
+ SELECT event_id, chain_id, sequence_number, e.event_id IS NOT NULL
+ FROM event_auth_chains
+ LEFT JOIN events AS e USING (event_id)
+ WHERE event_id > ? ORDER BY event_auth_chains.event_id ASC LIMIT ?
+ """
+ txn.execute(sql, (current_event_id, batch_size))
+
+ rows = txn.fetchall()
+ if not rows:
+ return 0
+
+ # The event IDs and chain IDs / sequence numbers where the event has
+ # been purged.
+ unreferenced_event_ids = []
+ unreferenced_chain_id_tuples = []
+ event_id = ""
+ for event_id, chain_id, sequence_number, has_event in rows:
+ if not has_event:
+ unreferenced_event_ids.append((event_id,))
+ unreferenced_chain_id_tuples.append((chain_id, sequence_number))
+
+ # Delete the unreferenced auth chains from event_auth_chain_links and
+ # event_auth_chains.
+ txn.executemany(
+ """
+ DELETE FROM event_auth_chains WHERE event_id = ?
+ """,
+ unreferenced_event_ids,
+ )
+ # We should also delete matching target_*, but there is no index on
+ # target_chain_id. Hopefully any purged events are due to a room
+ # being fully purged and they will be removed from the origin_*
+ # searches.
+ txn.executemany(
+ """
+ DELETE FROM event_auth_chain_links WHERE
+ origin_chain_id = ? AND origin_sequence_number = ?
+ """,
+ unreferenced_chain_id_tuples,
+ )
+
+ progress = {
+ "current_event_id": event_id,
+ }
+
+ self.db_pool.updates._background_update_progress_txn(
+ txn, "purged_chain_cover", progress
+ )
+
+ return len(rows)
+
+ result = await self.db_pool.runInteraction(
+ "_purged_chain_cover_index",
+ purged_chain_cover_txn,
+ )
+
+ if not result:
+ await self.db_pool.updates._end_background_update("purged_chain_cover")
+
+ return result
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index edbe42f2..c04e162c 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -12,7 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-import itertools
+
import logging
import threading
from collections import namedtuple
@@ -1044,7 +1044,8 @@ class EventsWorkerStore(SQLBaseStore):
Returns:
set[str]: The events we have already seen.
"""
- results = set()
+ # if the event cache contains the event, obviously we've seen it.
+ results = {x for x in event_ids if self._get_event_cache.contains(x)}
def have_seen_events_txn(txn, chunk):
sql = "SELECT event_id FROM events as e WHERE "
@@ -1052,12 +1053,9 @@ class EventsWorkerStore(SQLBaseStore):
txn.database_engine, "e.event_id", chunk
)
txn.execute(sql + clause, args)
- for (event_id,) in txn:
- results.add(event_id)
+ results.update(row[0] for row in txn)
- # break the input up into chunks of 100
- input_iterator = iter(event_ids)
- for chunk in iter(lambda: list(itertools.islice(input_iterator, 100)), []):
+ for chunk in batch_iter((x for x in event_ids if x not in results), 100):
await self.db_pool.runInteraction(
"have_seen_events", have_seen_events_txn, chunk
)
diff --git a/synapse/storage/databases/main/purge_events.py b/synapse/storage/databases/main/purge_events.py
index 0836e4af..41f4fe7f 100644
--- a/synapse/storage/databases/main/purge_events.py
+++ b/synapse/storage/databases/main/purge_events.py
@@ -331,13 +331,9 @@ class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore):
txn.executemany(
"""
DELETE FROM event_auth_chain_links WHERE
- (origin_chain_id = ? AND origin_sequence_number = ?) OR
- (target_chain_id = ? AND target_sequence_number = ?)
+ origin_chain_id = ? AND origin_sequence_number = ?
""",
- (
- (chain_id, seq_num, chain_id, seq_num)
- for (chain_id, seq_num) in referenced_chain_id_tuples
- ),
+ referenced_chain_id_tuples,
)
# Now we delete tables which lack an index on room_id but have one on event_id
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index 61a7556e..eba66ff3 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -16,7 +16,7 @@
# limitations under the License.
import logging
import re
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
import attr
@@ -1510,7 +1510,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
async def user_delete_access_tokens(
self,
user_id: str,
- except_token_id: Optional[str] = None,
+ except_token_id: Optional[int] = None,
device_id: Optional[str] = None,
) -> List[Tuple[str, int, Optional[str]]]:
"""
@@ -1533,7 +1533,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
items = keyvalues.items()
where_clause = " AND ".join(k + " = ?" for k, _ in items)
- values = [v for _, v in items]
+ values = [v for _, v in items] # type: List[Union[str, int]]
if except_token_id:
where_clause += " AND id != ?"
values.append(except_token_id)
diff --git a/synapse/storage/databases/main/schema/delta/59/10delete_purged_chain_cover.sql b/synapse/storage/databases/main/schema/delta/59/10delete_purged_chain_cover.sql
new file mode 100644
index 00000000..87cb1f3c
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/59/10delete_purged_chain_cover.sql
@@ -0,0 +1,17 @@
+/* Copyright 2021 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+ INSERT INTO background_updates (ordering, update_name, progress_json) VALUES
+ (5910, 'purged_chain_cover', '{}');
diff --git a/synapse/storage/databases/main/transactions.py b/synapse/storage/databases/main/transactions.py
index b921d63d..03096618 100644
--- a/synapse/storage/databases/main/transactions.py
+++ b/synapse/storage/databases/main/transactions.py
@@ -350,11 +350,11 @@ class TransactionStore(TransactionWorkerStore):
self.db_pool.simple_upsert_many_txn(
txn,
- "destination_rooms",
- ["destination", "room_id"],
- rows,
- ["stream_ordering"],
- [(stream_ordering,)] * len(rows),
+ table="destination_rooms",
+ key_names=("destination", "room_id"),
+ key_values=rows,
+ value_names=["stream_ordering"],
+ value_values=[(stream_ordering,)] * len(rows),
)
async def get_destination_last_successful_stream_ordering(
diff --git a/synapse/types.py b/synapse/types.py
index 721343f0..b08ce901 100644
--- a/synapse/types.py
+++ b/synapse/types.py
@@ -35,6 +35,14 @@ from typing import (
import attr
from signedjson.key import decode_verify_key_bytes
from unpaddedbase64 import decode_base64
+from zope.interface import Interface
+
+from twisted.internet.interfaces import (
+ IReactorCore,
+ IReactorPluggableNameResolver,
+ IReactorTCP,
+ IReactorTime,
+)
from synapse.api.errors import Codes, SynapseError
from synapse.util.stringutils import parse_and_validate_server_name
@@ -67,33 +75,40 @@ MutableStateMap = MutableMapping[StateKey, T]
JsonDict = Dict[str, Any]
-class Requester(
- namedtuple(
- "Requester",
- [
- "user",
- "access_token_id",
- "is_guest",
- "shadow_banned",
- "device_id",
- "app_service",
- "authenticated_entity",
- ],
- )
+# Note that this seems to require inheriting *directly* from Interface in order
+# for mypy-zope to realize it is an interface.
+class ISynapseReactor(
+ IReactorTCP, IReactorPluggableNameResolver, IReactorTime, IReactorCore, Interface
):
+ """The interfaces necessary for Synapse to function."""
+
+
+@attr.s(frozen=True, slots=True)
+class Requester:
"""
Represents the user making a request
Attributes:
- user (UserID): id of the user making the request
- access_token_id (int|None): *ID* of the access token used for this
+ user: id of the user making the request
+ access_token_id: *ID* of the access token used for this
request, or None if it came via the appservice API or similar
- is_guest (bool): True if the user making this request is a guest user
- shadow_banned (bool): True if the user making this request has been shadow-banned.
- device_id (str|None): device_id which was set at authentication time
- app_service (ApplicationService|None): the AS requesting on behalf of the user
+ is_guest: True if the user making this request is a guest user
+ shadow_banned: True if the user making this request has been shadow-banned.
+ device_id: device_id which was set at authentication time
+ app_service: the AS requesting on behalf of the user
+ authenticated_entity: The entity that authenticated when making the request.
+ This is different to the user_id when an admin user or the server is
+ "puppeting" the user.
"""
+ user = attr.ib(type="UserID")
+ access_token_id = attr.ib(type=Optional[int])
+ is_guest = attr.ib(type=bool)
+ shadow_banned = attr.ib(type=bool)
+ device_id = attr.ib(type=Optional[str])
+ app_service = attr.ib(type=Optional["ApplicationService"])
+ authenticated_entity = attr.ib(type=str)
+
def serialize(self):
"""Converts self to a type that can be serialized as JSON, and then
deserialized by `deserialize`
@@ -141,23 +156,23 @@ class Requester(
def create_requester(
user_id: Union[str, "UserID"],
access_token_id: Optional[int] = None,
- is_guest: Optional[bool] = False,
- shadow_banned: Optional[bool] = False,
+ is_guest: bool = False,
+ shadow_banned: bool = False,
device_id: Optional[str] = None,
app_service: Optional["ApplicationService"] = None,
authenticated_entity: Optional[str] = None,
-):
+) -> Requester:
"""
Create a new ``Requester`` object
Args:
- user_id (str|UserID): id of the user making the request
- access_token_id (int|None): *ID* of the access token used for this
+ user_id: id of the user making the request
+ access_token_id: *ID* of the access token used for this
request, or None if it came via the appservice API or similar
- is_guest (bool): True if the user making this request is a guest user
- shadow_banned (bool): True if the user making this request is shadow-banned.
- device_id (str|None): device_id which was set at authentication time
- app_service (ApplicationService|None): the AS requesting on behalf of the user
+ is_guest: True if the user making this request is a guest user
+ shadow_banned: True if the user making this request is shadow-banned.
+ device_id: device_id which was set at authentication time
+ app_service: the AS requesting on behalf of the user
authenticated_entity: The entity that authenticated when making the request.
This is different to the user_id when an admin user or the server is
"puppeting" the user.
diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py
index 719e35b7..f33c1158 100644
--- a/synapse/util/async_helpers.py
+++ b/synapse/util/async_helpers.py
@@ -76,11 +76,16 @@ class ObservableDeferred:
def callback(r):
object.__setattr__(self, "_result", (True, r))
while self._observers:
+ observer = self._observers.pop()
try:
- # TODO: Handle errors here.
- self._observers.pop().callback(r)
- except Exception:
- pass
+ observer.callback(r)
+ except Exception as e:
+ logger.exception(
+ "%r threw an exception on .callback(%r), ignoring...",
+ observer,
+ r,
+ exc_info=e,
+ )
return r
def errback(f):
@@ -90,11 +95,16 @@ class ObservableDeferred:
# traces when we `await` on one of the observer deferreds.
f.value.__failure__ = f
+ observer = self._observers.pop()
try:
- # TODO: Handle errors here.
- self._observers.pop().errback(f)
- except Exception:
- pass
+ observer.errback(f)
+ except Exception as e:
+ logger.exception(
+ "%r threw an exception on .errback(%r), ignoring...",
+ observer,
+ f,
+ exc_info=e,
+ )
if consumeErrors:
return None
diff --git a/synapse/util/caches/response_cache.py b/synapse/util/caches/response_cache.py
index 32228f42..46ea8e09 100644
--- a/synapse/util/caches/response_cache.py
+++ b/synapse/util/caches/response_cache.py
@@ -13,17 +13,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
-from typing import TYPE_CHECKING, Any, Callable, Dict, Generic, Optional, TypeVar
+from typing import Any, Callable, Dict, Generic, Optional, TypeVar
from twisted.internet import defer
from synapse.logging.context import make_deferred_yieldable, run_in_background
+from synapse.util import Clock
from synapse.util.async_helpers import ObservableDeferred
from synapse.util.caches import register_cache
-if TYPE_CHECKING:
- from synapse.app.homeserver import HomeServer
-
logger = logging.getLogger(__name__)
T = TypeVar("T")
@@ -37,11 +35,11 @@ class ResponseCache(Generic[T]):
used rather than trying to compute a new response.
"""
- def __init__(self, hs: "HomeServer", name: str, timeout_ms: float = 0):
+ def __init__(self, clock: Clock, name: str, timeout_ms: float = 0):
# Requests that haven't finished yet.
self.pending_result_cache = {} # type: Dict[T, ObservableDeferred]
- self.clock = hs.get_clock()
+ self.clock = clock
self.timeout_sec = timeout_ms / 1000.0
self._name = name
diff --git a/synapse/util/macaroons.py b/synapse/util/macaroons.py
new file mode 100644
index 00000000..12cdd533
--- /dev/null
+++ b/synapse/util/macaroons.py
@@ -0,0 +1,89 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 Quentin Gliech
+# Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Utilities for manipulating macaroons"""
+
+from typing import Callable, Optional
+
+import pymacaroons
+from pymacaroons.exceptions import MacaroonVerificationFailedException
+
+
+def get_value_from_macaroon(macaroon: pymacaroons.Macaroon, key: str) -> str:
+ """Extracts a caveat value from a macaroon token.
+
+ Checks that there is exactly one caveat of the form "key = <val>" in the macaroon,
+ and returns the extracted value.
+
+ Args:
+ macaroon: the token
+ key: the key of the caveat to extract
+
+ Returns:
+ The extracted value
+
+ Raises:
+ MacaroonVerificationFailedException: if there are conflicting values for the
+ caveat in the macaroon, or if the caveat was not found in the macaroon.
+ """
+ prefix = key + " = "
+ result = None # type: Optional[str]
+ for caveat in macaroon.caveats:
+ if not caveat.caveat_id.startswith(prefix):
+ continue
+
+ val = caveat.caveat_id[len(prefix) :]
+
+ if result is None:
+ # first time we found this caveat: record the value
+ result = val
+ elif val != result:
+ # on subsequent occurrences, raise if the value is different.
+ raise MacaroonVerificationFailedException(
+ "Conflicting values for caveat " + key
+ )
+
+ if result is not None:
+ return result
+
+ # If the caveat is not there, we raise a MacaroonVerificationFailedException.
+ # Note that it is insecure to generate a macaroon without all the caveats you
+ # might need (because there is nothing stopping people from adding extra caveats),
+ # so if the caveat isn't there, something odd must be going on.
+ raise MacaroonVerificationFailedException("No %s caveat in macaroon" % (key,))
+
+
+def satisfy_expiry(v: pymacaroons.Verifier, get_time_ms: Callable[[], int]) -> None:
+ """Make a macaroon verifier which accepts 'time' caveats
+
+ Builds a caveat verifier which will accept unexpired 'time' caveats, and adds it to
+ the given macaroon verifier.
+
+ Args:
+ v: the macaroon verifier
+ get_time_ms: a callable which will return the timestamp after which the caveat
+ should be considered expired. Normally the current time.
+ """
+
+ def verify_expiry_caveat(caveat: str):
+ time_msec = get_time_ms()
+ prefix = "time < "
+ if not caveat.startswith(prefix):
+ return False
+ expiry = int(caveat[len(prefix) :])
+ return time_msec < expiry
+
+ v.satisfy_general(verify_expiry_caveat)
diff --git a/tests/federation/test_federation_catch_up.py b/tests/federation/test_federation_catch_up.py
index 1a3ccb26..6f96cd79 100644
--- a/tests/federation/test_federation_catch_up.py
+++ b/tests/federation/test_federation_catch_up.py
@@ -7,6 +7,7 @@ from synapse.federation.sender import PerDestinationQueue, TransactionManager
from synapse.federation.units import Edu
from synapse.rest import admin
from synapse.rest.client.v1 import login, room
+from synapse.util.retryutils import NotRetryingDestination
from tests.test_utils import event_injection, make_awaitable
from tests.unittest import FederatingHomeserverTestCase, override_config
@@ -49,7 +50,7 @@ class FederationCatchUpTestCases(FederatingHomeserverTestCase):
else:
data = json_cb()
self.failed_pdus.extend(data["pdus"])
- raise IOError("Failed to connect because this is a test!")
+ raise NotRetryingDestination(0, 24 * 60 * 60 * 1000, txn.destination)
def get_destination_room(self, room: str, destination: str = "host2") -> dict:
"""
diff --git a/tests/handlers/oidc_test_key.p8 b/tests/handlers/oidc_test_key.p8
new file mode 100644
index 00000000..bb929763
--- /dev/null
+++ b/tests/handlers/oidc_test_key.p8
@@ -0,0 +1,5 @@
+-----BEGIN PRIVATE KEY-----
+MIGHAgEAMBMGByqGSM49AgEGCCqGSM49AwEHBG0wawIBAQQgrHMvFcFjFhei6gHp
+Gfy4C8+6z7634MZbC7SSx4a17GahRANCAATp0YxEzGUXuqszggiFxczDdPgDpCJA
+P18rRuN7FLwZDuzYQPb8zVd8eGh4BqxjiVocICnVWyaSWD96N00I96SW
+-----END PRIVATE KEY-----
diff --git a/tests/handlers/oidc_test_key.pub.pem b/tests/handlers/oidc_test_key.pub.pem
new file mode 100644
index 00000000..176d4a4b
--- /dev/null
+++ b/tests/handlers/oidc_test_key.pub.pem
@@ -0,0 +1,4 @@
+-----BEGIN PUBLIC KEY-----
+MFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAE6dGMRMxlF7qrM4IIhcXMw3T4A6Qi
+QD9fK0bjexS8GQ7s2ED2/M1XfHhoeAasY4laHCAp1Vsmklg/ejdNCPeklg==
+-----END PUBLIC KEY-----
diff --git a/tests/handlers/test_auth.py b/tests/handlers/test_auth.py
index 0e42013b..c9f889b5 100644
--- a/tests/handlers/test_auth.py
+++ b/tests/handlers/test_auth.py
@@ -68,38 +68,45 @@ class AuthTestCase(unittest.HomeserverTestCase):
v.verify(macaroon, self.hs.config.macaroon_secret_key)
def test_short_term_login_token_gives_user_id(self):
- token = self.macaroon_generator.generate_short_term_login_token("a_user", 5000)
- user_id = self.get_success(
- self.auth_handler.validate_short_term_login_token_and_get_user_id(token)
+ token = self.macaroon_generator.generate_short_term_login_token(
+ "a_user", "", 5000
)
- self.assertEqual("a_user", user_id)
+ res = self.get_success(self.auth_handler.validate_short_term_login_token(token))
+ self.assertEqual("a_user", res.user_id)
+ self.assertEqual("", res.auth_provider_id)
# when we advance the clock, the token should be rejected
self.reactor.advance(6)
self.get_failure(
- self.auth_handler.validate_short_term_login_token_and_get_user_id(token),
+ self.auth_handler.validate_short_term_login_token(token),
AuthError,
)
+ def test_short_term_login_token_gives_auth_provider(self):
+ token = self.macaroon_generator.generate_short_term_login_token(
+ "a_user", auth_provider_id="my_idp"
+ )
+ res = self.get_success(self.auth_handler.validate_short_term_login_token(token))
+ self.assertEqual("a_user", res.user_id)
+ self.assertEqual("my_idp", res.auth_provider_id)
+
def test_short_term_login_token_cannot_replace_user_id(self):
- token = self.macaroon_generator.generate_short_term_login_token("a_user", 5000)
+ token = self.macaroon_generator.generate_short_term_login_token(
+ "a_user", "", 5000
+ )
macaroon = pymacaroons.Macaroon.deserialize(token)
- user_id = self.get_success(
- self.auth_handler.validate_short_term_login_token_and_get_user_id(
- macaroon.serialize()
- )
+ res = self.get_success(
+ self.auth_handler.validate_short_term_login_token(macaroon.serialize())
)
- self.assertEqual("a_user", user_id)
+ self.assertEqual("a_user", res.user_id)
# add another "user_id" caveat, which might allow us to override the
# user_id.
macaroon.add_first_party_caveat("user_id = b_user")
self.get_failure(
- self.auth_handler.validate_short_term_login_token_and_get_user_id(
- macaroon.serialize()
- ),
+ self.auth_handler.validate_short_term_login_token(macaroon.serialize()),
AuthError,
)
@@ -113,7 +120,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
)
self.get_success(
- self.auth_handler.validate_short_term_login_token_and_get_user_id(
+ self.auth_handler.validate_short_term_login_token(
self._get_macaroon().serialize()
)
)
@@ -135,7 +142,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
return_value=make_awaitable(self.large_number_of_users)
)
self.get_failure(
- self.auth_handler.validate_short_term_login_token_and_get_user_id(
+ self.auth_handler.validate_short_term_login_token(
self._get_macaroon().serialize()
),
ResourceLimitError,
@@ -159,7 +166,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
ResourceLimitError,
)
self.get_failure(
- self.auth_handler.validate_short_term_login_token_and_get_user_id(
+ self.auth_handler.validate_short_term_login_token(
self._get_macaroon().serialize()
),
ResourceLimitError,
@@ -175,7 +182,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
)
)
self.get_success(
- self.auth_handler.validate_short_term_login_token_and_get_user_id(
+ self.auth_handler.validate_short_term_login_token(
self._get_macaroon().serialize()
)
)
@@ -197,11 +204,13 @@ class AuthTestCase(unittest.HomeserverTestCase):
return_value=make_awaitable(self.small_number_of_users)
)
self.get_success(
- self.auth_handler.validate_short_term_login_token_and_get_user_id(
+ self.auth_handler.validate_short_term_login_token(
self._get_macaroon().serialize()
)
)
def _get_macaroon(self):
- token = self.macaroon_generator.generate_short_term_login_token("user_a", 5000)
+ token = self.macaroon_generator.generate_short_term_login_token(
+ "user_a", "", 5000
+ )
return pymacaroons.Macaroon.deserialize(token)
diff --git a/tests/handlers/test_cas.py b/tests/handlers/test_cas.py
index 6f992291..7975af24 100644
--- a/tests/handlers/test_cas.py
+++ b/tests/handlers/test_cas.py
@@ -66,7 +66,7 @@ class CasHandlerTestCase(HomeserverTestCase):
# check that the auth handler got called as expected
auth_handler.complete_sso_login.assert_called_once_with(
- "@test_user:test", request, "redirect_uri", None, new_user=True
+ "@test_user:test", "cas", request, "redirect_uri", None, new_user=True
)
def test_map_cas_user_to_existing_user(self):
@@ -89,7 +89,7 @@ class CasHandlerTestCase(HomeserverTestCase):
# check that the auth handler got called as expected
auth_handler.complete_sso_login.assert_called_once_with(
- "@test_user:test", request, "redirect_uri", None, new_user=False
+ "@test_user:test", "cas", request, "redirect_uri", None, new_user=False
)
# Subsequent calls should map to the same mxid.
@@ -98,7 +98,7 @@ class CasHandlerTestCase(HomeserverTestCase):
self.handler._handle_cas_response(request, cas_response, "redirect_uri", "")
)
auth_handler.complete_sso_login.assert_called_once_with(
- "@test_user:test", request, "redirect_uri", None, new_user=False
+ "@test_user:test", "cas", request, "redirect_uri", None, new_user=False
)
def test_map_cas_user_to_invalid_localpart(self):
@@ -116,7 +116,7 @@ class CasHandlerTestCase(HomeserverTestCase):
# check that the auth handler got called as expected
auth_handler.complete_sso_login.assert_called_once_with(
- "@f=c3=b6=c3=b6:test", request, "redirect_uri", None, new_user=True
+ "@f=c3=b6=c3=b6:test", "cas", request, "redirect_uri", None, new_user=True
)
@override_config(
@@ -160,7 +160,7 @@ class CasHandlerTestCase(HomeserverTestCase):
# check that the auth handler got called as expected
auth_handler.complete_sso_login.assert_called_once_with(
- "@test_user:test", request, "redirect_uri", None, new_user=True
+ "@test_user:test", "cas", request, "redirect_uri", None, new_user=True
)
diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py
index cf1de28f..5e9c9c2e 100644
--- a/tests/handlers/test_oidc.py
+++ b/tests/handlers/test_oidc.py
@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import json
-from typing import Optional
+import os
from urllib.parse import parse_qs, urlparse
from mock import ANY, Mock, patch
@@ -23,6 +23,7 @@ import pymacaroons
from synapse.handlers.sso import MappingException
from synapse.server import HomeServer
from synapse.types import UserID
+from synapse.util.macaroons import get_value_from_macaroon
from tests.test_utils import FakeResponse, get_awaitable_result, simple_async_mock
from tests.unittest import HomeserverTestCase, override_config
@@ -50,7 +51,18 @@ WELL_KNOWN = ISSUER + ".well-known/openid-configuration"
JWKS_URI = ISSUER + ".well-known/jwks.json"
# config for common cases
-COMMON_CONFIG = {
+DEFAULT_CONFIG = {
+ "enabled": True,
+ "client_id": CLIENT_ID,
+ "client_secret": CLIENT_SECRET,
+ "issuer": ISSUER,
+ "scopes": SCOPES,
+ "user_mapping_provider": {"module": __name__ + ".TestMappingProvider"},
+}
+
+# extends the default config with explicit OAuth2 endpoints instead of using discovery
+EXPLICIT_ENDPOINT_CONFIG = {
+ **DEFAULT_CONFIG,
"discover": False,
"authorization_endpoint": AUTHORIZATION_ENDPOINT,
"token_endpoint": TOKEN_ENDPOINT,
@@ -107,6 +119,32 @@ async def get_json(url):
return {"keys": []}
+def _key_file_path() -> str:
+ """path to a file containing the private half of a test key"""
+
+ # this key was generated with:
+ # openssl ecparam -name prime256v1 -genkey -noout |
+ # openssl pkcs8 -topk8 -nocrypt -out oidc_test_key.p8
+ #
+ # we use PKCS8 rather than SEC-1 (which is what openssl ecparam spits out), because
+ # that's what Apple use, and we want to be sure that we work with Apple's keys.
+ #
+ # (For the record: both PKCS8 and SEC-1 specify (different) ways of representing
+ # keys using ASN.1. Both are then typically formatted using PEM, which says: use the
+ # base64-encoded DER encoding of ASN.1, with headers and footers. But we don't
+ # really need to care about any of that.)
+ return os.path.join(os.path.dirname(__file__), "oidc_test_key.p8")
+
+
+def _public_key_file_path() -> str:
+ """path to a file containing the public half of a test key"""
+ # this was generated with:
+ # openssl ec -in oidc_test_key.p8 -pubout -out oidc_test_key.pub.pem
+ #
+ # See above about where oidc_test_key.p8 came from
+ return os.path.join(os.path.dirname(__file__), "oidc_test_key.pub.pem")
+
+
class OidcHandlerTestCase(HomeserverTestCase):
if not HAS_OIDC:
skip = "requires OIDC"
@@ -114,20 +152,6 @@ class OidcHandlerTestCase(HomeserverTestCase):
def default_config(self):
config = super().default_config()
config["public_baseurl"] = BASE_URL
- oidc_config = {
- "enabled": True,
- "client_id": CLIENT_ID,
- "client_secret": CLIENT_SECRET,
- "issuer": ISSUER,
- "scopes": SCOPES,
- "user_mapping_provider": {"module": __name__ + ".TestMappingProvider"},
- }
-
- # Update this config with what's in the default config so that
- # override_config works as expected.
- oidc_config.update(config.get("oidc_config", {}))
- config["oidc_config"] = oidc_config
-
return config
def make_homeserver(self, reactor, clock):
@@ -170,13 +194,14 @@ class OidcHandlerTestCase(HomeserverTestCase):
self.render_error.reset_mock()
return args
+ @override_config({"oidc_config": DEFAULT_CONFIG})
def test_config(self):
"""Basic config correctly sets up the callback URL and client auth correctly."""
self.assertEqual(self.provider._callback_url, CALLBACK_URL)
self.assertEqual(self.provider._client_auth.client_id, CLIENT_ID)
self.assertEqual(self.provider._client_auth.client_secret, CLIENT_SECRET)
- @override_config({"oidc_config": {"discover": True}})
+ @override_config({"oidc_config": {**DEFAULT_CONFIG, "discover": True}})
def test_discovery(self):
"""The handler should discover the endpoints from OIDC discovery document."""
# This would throw if some metadata were invalid
@@ -195,13 +220,13 @@ class OidcHandlerTestCase(HomeserverTestCase):
self.get_success(self.provider.load_metadata())
self.http_client.get_json.assert_not_called()
- @override_config({"oidc_config": COMMON_CONFIG})
+ @override_config({"oidc_config": EXPLICIT_ENDPOINT_CONFIG})
def test_no_discovery(self):
"""When discovery is disabled, it should not try to load from discovery document."""
self.get_success(self.provider.load_metadata())
self.http_client.get_json.assert_not_called()
- @override_config({"oidc_config": COMMON_CONFIG})
+ @override_config({"oidc_config": EXPLICIT_ENDPOINT_CONFIG})
def test_load_jwks(self):
"""JWKS loading is done once (then cached) if used."""
jwks = self.get_success(self.provider.load_jwks())
@@ -236,6 +261,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
self.http_client.get_json.assert_not_called()
self.assertEqual(jwks, {"keys": []})
+ @override_config({"oidc_config": DEFAULT_CONFIG})
def test_validate_config(self):
"""Provider metadatas are extensively validated."""
h = self.provider
@@ -318,13 +344,14 @@ class OidcHandlerTestCase(HomeserverTestCase):
# Shouldn't raise with a valid userinfo, even without jwks
force_load_metadata()
- @override_config({"oidc_config": {"skip_verification": True}})
+ @override_config({"oidc_config": {**DEFAULT_CONFIG, "skip_verification": True}})
def test_skip_verification(self):
"""Provider metadata validation can be disabled by config."""
with self.metadata_edit({"issuer": "http://insecure"}):
# This should not throw
get_awaitable_result(self.provider.load_metadata())
+ @override_config({"oidc_config": DEFAULT_CONFIG})
def test_redirect_request(self):
"""The redirect request has the right arguments & generates a valid session cookie."""
req = Mock(spec=["cookies"])
@@ -360,20 +387,15 @@ class OidcHandlerTestCase(HomeserverTestCase):
self.assertEqual(name, b"oidc_session")
macaroon = pymacaroons.Macaroon.deserialize(cookie)
- state = self.handler._token_generator._get_value_from_macaroon(
- macaroon, "state"
- )
- nonce = self.handler._token_generator._get_value_from_macaroon(
- macaroon, "nonce"
- )
- redirect = self.handler._token_generator._get_value_from_macaroon(
- macaroon, "client_redirect_url"
- )
+ state = get_value_from_macaroon(macaroon, "state")
+ nonce = get_value_from_macaroon(macaroon, "nonce")
+ redirect = get_value_from_macaroon(macaroon, "client_redirect_url")
self.assertEqual(params["state"], [state])
self.assertEqual(params["nonce"], [nonce])
self.assertEqual(redirect, "http://client/redirect")
+ @override_config({"oidc_config": DEFAULT_CONFIG})
def test_callback_error(self):
"""Errors from the provider returned in the callback are displayed."""
request = Mock(args={})
@@ -385,6 +407,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
self.get_success(self.handler.handle_oidc_callback(request))
self.assertRenderedError("invalid_client", "some description")
+ @override_config({"oidc_config": DEFAULT_CONFIG})
def test_callback(self):
"""Code callback works and display errors if something went wrong.
@@ -434,7 +457,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
self.get_success(self.handler.handle_oidc_callback(request))
auth_handler.complete_sso_login.assert_called_once_with(
- expected_user_id, request, client_redirect_url, None, new_user=True
+ expected_user_id, "oidc", request, client_redirect_url, None, new_user=True
)
self.provider._exchange_code.assert_called_once_with(code)
self.provider._parse_id_token.assert_called_once_with(token, nonce=nonce)
@@ -465,7 +488,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
self.get_success(self.handler.handle_oidc_callback(request))
auth_handler.complete_sso_login.assert_called_once_with(
- expected_user_id, request, client_redirect_url, None, new_user=False
+ expected_user_id, "oidc", request, client_redirect_url, None, new_user=False
)
self.provider._exchange_code.assert_called_once_with(code)
self.provider._parse_id_token.assert_not_called()
@@ -486,6 +509,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
self.get_success(self.handler.handle_oidc_callback(request))
self.assertRenderedError("invalid_request")
+ @override_config({"oidc_config": DEFAULT_CONFIG})
def test_callback_session(self):
"""The callback verifies the session presence and validity"""
request = Mock(spec=["args", "getCookie", "cookies"])
@@ -528,7 +552,9 @@ class OidcHandlerTestCase(HomeserverTestCase):
self.get_success(self.handler.handle_oidc_callback(request))
self.assertRenderedError("invalid_request")
- @override_config({"oidc_config": {"client_auth_method": "client_secret_post"}})
+ @override_config(
+ {"oidc_config": {**DEFAULT_CONFIG, "client_auth_method": "client_secret_post"}}
+ )
def test_exchange_code(self):
"""Code exchange behaves correctly and handles various error scenarios."""
token = {"type": "bearer"}
@@ -613,9 +639,105 @@ class OidcHandlerTestCase(HomeserverTestCase):
@override_config(
{
"oidc_config": {
+ "enabled": True,
+ "client_id": CLIENT_ID,
+ "issuer": ISSUER,
+ "client_auth_method": "client_secret_post",
+ "client_secret_jwt_key": {
+ "key_file": _key_file_path(),
+ "jwt_header": {"alg": "ES256", "kid": "ABC789"},
+ "jwt_payload": {"iss": "DEFGHI"},
+ },
+ }
+ }
+ )
+ def test_exchange_code_jwt_key(self):
+ """Test that code exchange works with a JWK client secret."""
+ from authlib.jose import jwt
+
+ token = {"type": "bearer"}
+ self.http_client.request = simple_async_mock(
+ return_value=FakeResponse(
+ code=200, phrase=b"OK", body=json.dumps(token).encode("utf-8")
+ )
+ )
+ code = "code"
+
+ # advance the clock a bit before we start, so we aren't working with zero
+ # timestamps.
+ self.reactor.advance(1000)
+ start_time = self.reactor.seconds()
+ ret = self.get_success(self.provider._exchange_code(code))
+
+ self.assertEqual(ret, token)
+
+ # the request should have hit the token endpoint
+ kwargs = self.http_client.request.call_args[1]
+ self.assertEqual(kwargs["method"], "POST")
+ self.assertEqual(kwargs["uri"], TOKEN_ENDPOINT)
+
+ # the client secret provided to the should be a jwt which can be checked with
+ # the public key
+ args = parse_qs(kwargs["data"].decode("utf-8"))
+ secret = args["client_secret"][0]
+ with open(_public_key_file_path()) as f:
+ key = f.read()
+ claims = jwt.decode(secret, key)
+ self.assertEqual(claims.header["kid"], "ABC789")
+ self.assertEqual(claims["aud"], ISSUER)
+ self.assertEqual(claims["iss"], "DEFGHI")
+ self.assertEqual(claims["sub"], CLIENT_ID)
+ self.assertEqual(claims["iat"], start_time)
+ self.assertGreater(claims["exp"], start_time)
+
+ # check the rest of the POSTed data
+ self.assertEqual(args["grant_type"], ["authorization_code"])
+ self.assertEqual(args["code"], [code])
+ self.assertEqual(args["client_id"], [CLIENT_ID])
+ self.assertEqual(args["redirect_uri"], [CALLBACK_URL])
+
+ @override_config(
+ {
+ "oidc_config": {
+ "enabled": True,
+ "client_id": CLIENT_ID,
+ "issuer": ISSUER,
+ "client_auth_method": "none",
+ }
+ }
+ )
+ def test_exchange_code_no_auth(self):
+ """Test that code exchange works with no client secret."""
+ token = {"type": "bearer"}
+ self.http_client.request = simple_async_mock(
+ return_value=FakeResponse(
+ code=200, phrase=b"OK", body=json.dumps(token).encode("utf-8")
+ )
+ )
+ code = "code"
+ ret = self.get_success(self.provider._exchange_code(code))
+
+ self.assertEqual(ret, token)
+
+ # the request should have hit the token endpoint
+ kwargs = self.http_client.request.call_args[1]
+ self.assertEqual(kwargs["method"], "POST")
+ self.assertEqual(kwargs["uri"], TOKEN_ENDPOINT)
+
+ # check the POSTed data
+ args = parse_qs(kwargs["data"].decode("utf-8"))
+ self.assertEqual(args["grant_type"], ["authorization_code"])
+ self.assertEqual(args["code"], [code])
+ self.assertEqual(args["client_id"], [CLIENT_ID])
+ self.assertEqual(args["redirect_uri"], [CALLBACK_URL])
+
+ @override_config(
+ {
+ "oidc_config": {
+ **DEFAULT_CONFIG,
"user_mapping_provider": {
"module": __name__ + ".TestMappingProviderExtra"
- }
+ },
}
}
)
@@ -651,12 +773,14 @@ class OidcHandlerTestCase(HomeserverTestCase):
auth_handler.complete_sso_login.assert_called_once_with(
"@foo:test",
+ "oidc",
request,
client_redirect_url,
{"phone": "1234567"},
new_user=True,
)
+ @override_config({"oidc_config": DEFAULT_CONFIG})
def test_map_userinfo_to_user(self):
"""Ensure that mapping the userinfo returned from a provider to an MXID works properly."""
auth_handler = self.hs.get_auth_handler()
@@ -668,7 +792,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
}
self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
auth_handler.complete_sso_login.assert_called_once_with(
- "@test_user:test", ANY, ANY, None, new_user=True
+ "@test_user:test", "oidc", ANY, ANY, None, new_user=True
)
auth_handler.complete_sso_login.reset_mock()
@@ -679,7 +803,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
}
self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
auth_handler.complete_sso_login.assert_called_once_with(
- "@test_user_2:test", ANY, ANY, None, new_user=True
+ "@test_user_2:test", "oidc", ANY, ANY, None, new_user=True
)
auth_handler.complete_sso_login.reset_mock()
@@ -697,7 +821,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
"Mapping provider does not support de-duplicating Matrix IDs",
)
- @override_config({"oidc_config": {"allow_existing_users": True}})
+ @override_config({"oidc_config": {**DEFAULT_CONFIG, "allow_existing_users": True}})
def test_map_userinfo_to_existing_user(self):
"""Existing users can log in with OpenID Connect when allow_existing_users is True."""
store = self.hs.get_datastore()
@@ -716,14 +840,14 @@ class OidcHandlerTestCase(HomeserverTestCase):
}
self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
auth_handler.complete_sso_login.assert_called_once_with(
- user.to_string(), ANY, ANY, None, new_user=False
+ user.to_string(), "oidc", ANY, ANY, None, new_user=False
)
auth_handler.complete_sso_login.reset_mock()
# Subsequent calls should map to the same mxid.
self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
auth_handler.complete_sso_login.assert_called_once_with(
- user.to_string(), ANY, ANY, None, new_user=False
+ user.to_string(), "oidc", ANY, ANY, None, new_user=False
)
auth_handler.complete_sso_login.reset_mock()
@@ -738,7 +862,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
}
self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
auth_handler.complete_sso_login.assert_called_once_with(
- user.to_string(), ANY, ANY, None, new_user=False
+ user.to_string(), "oidc", ANY, ANY, None, new_user=False
)
auth_handler.complete_sso_login.reset_mock()
@@ -774,9 +898,10 @@ class OidcHandlerTestCase(HomeserverTestCase):
self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
auth_handler.complete_sso_login.assert_called_once_with(
- "@TEST_USER_2:test", ANY, ANY, None, new_user=False
+ "@TEST_USER_2:test", "oidc", ANY, ANY, None, new_user=False
)
+ @override_config({"oidc_config": DEFAULT_CONFIG})
def test_map_userinfo_to_invalid_localpart(self):
"""If the mapping provider generates an invalid localpart it should be rejected."""
self.get_success(
@@ -787,9 +912,10 @@ class OidcHandlerTestCase(HomeserverTestCase):
@override_config(
{
"oidc_config": {
+ **DEFAULT_CONFIG,
"user_mapping_provider": {
"module": __name__ + ".TestMappingProviderFailures"
- }
+ },
}
}
)
@@ -810,7 +936,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
# test_user is already taken, so test_user1 gets registered instead.
auth_handler.complete_sso_login.assert_called_once_with(
- "@test_user1:test", ANY, ANY, None, new_user=True
+ "@test_user1:test", "oidc", ANY, ANY, None, new_user=True
)
auth_handler.complete_sso_login.reset_mock()
@@ -834,6 +960,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
"mapping_error", "Unable to generate a Matrix ID from the SSO response"
)
+ @override_config({"oidc_config": DEFAULT_CONFIG})
def test_empty_localpart(self):
"""Attempts to map onto an empty localpart should be rejected."""
userinfo = {
@@ -846,9 +973,10 @@ class OidcHandlerTestCase(HomeserverTestCase):
@override_config(
{
"oidc_config": {
+ **DEFAULT_CONFIG,
"user_mapping_provider": {
"config": {"localpart_template": "{{ user.username }}"}
- }
+ },
}
}
)
@@ -866,7 +994,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
state: str,
nonce: str,
client_redirect_url: str,
- ui_auth_session_id: Optional[str] = None,
+ ui_auth_session_id: str = "",
) -> str:
from synapse.handlers.oidc_handler import OidcSessionData
@@ -909,6 +1037,7 @@ async def _make_callback_with_userinfo(
idp_id="oidc",
nonce="nonce",
client_redirect_url=client_redirect_url,
+ ui_auth_session_id="",
),
)
request = _build_callback_request("code", state, session)
diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py
index bdf3d0a8..94b69035 100644
--- a/tests/handlers/test_register.py
+++ b/tests/handlers/test_register.py
@@ -517,6 +517,37 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
self.assertTrue(requester.shadow_banned)
+ def test_spam_checker_receives_sso_type(self):
+ """Test rejecting registration based on SSO type"""
+
+ class BanBadIdPUser:
+ def check_registration_for_spam(
+ self, email_threepid, username, request_info, auth_provider_id=None
+ ):
+ # Reject any user coming from CAS and whose username contains profanity
+ if auth_provider_id == "cas" and "flimflob" in username:
+ return RegistrationBehaviour.DENY
+ return RegistrationBehaviour.ALLOW
+
+ # Configure a spam checker that denies a certain user on a specific IdP
+ spam_checker = self.hs.get_spam_checker()
+ spam_checker.spam_checkers = [BanBadIdPUser()]
+
+ f = self.get_failure(
+ self.handler.register_user(localpart="bobflimflob", auth_provider_id="cas"),
+ SynapseError,
+ )
+ exception = f.value
+
+ # We return 429 from the spam checker for denied registrations
+ self.assertIsInstance(exception, SynapseError)
+ self.assertEqual(exception.code, 429)
+
+ # Check the same username can register using SAML
+ self.get_success(
+ self.handler.register_user(localpart="bobflimflob", auth_provider_id="saml")
+ )
+
async def get_or_create_user(
self, requester, localpart, displayname, password_hash=None
):
diff --git a/tests/handlers/test_saml.py b/tests/handlers/test_saml.py
index 029af285..30efd43b 100644
--- a/tests/handlers/test_saml.py
+++ b/tests/handlers/test_saml.py
@@ -131,7 +131,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
# check that the auth handler got called as expected
auth_handler.complete_sso_login.assert_called_once_with(
- "@test_user:test", request, "redirect_uri", None, new_user=True
+ "@test_user:test", "saml", request, "redirect_uri", None, new_user=True
)
@override_config({"saml2_config": {"grandfathered_mxid_source_attribute": "mxid"}})
@@ -157,7 +157,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
# check that the auth handler got called as expected
auth_handler.complete_sso_login.assert_called_once_with(
- "@test_user:test", request, "", None, new_user=False
+ "@test_user:test", "saml", request, "", None, new_user=False
)
# Subsequent calls should map to the same mxid.
@@ -166,7 +166,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
self.handler._handle_authn_response(request, saml_response, "")
)
auth_handler.complete_sso_login.assert_called_once_with(
- "@test_user:test", request, "", None, new_user=False
+ "@test_user:test", "saml", request, "", None, new_user=False
)
def test_map_saml_response_to_invalid_localpart(self):
@@ -214,7 +214,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
# test_user is already taken, so test_user1 gets registered instead.
auth_handler.complete_sso_login.assert_called_once_with(
- "@test_user1:test", request, "", None, new_user=True
+ "@test_user1:test", "saml", request, "", None, new_user=True
)
auth_handler.complete_sso_login.reset_mock()
@@ -310,7 +310,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
# check that the auth handler got called as expected
auth_handler.complete_sso_login.assert_called_once_with(
- "@test_user:test", request, "redirect_uri", None, new_user=True
+ "@test_user:test", "saml", request, "redirect_uri", None, new_user=True
)
diff --git a/tests/http/test_client.py b/tests/http/test_client.py
index 21ecb81c..0ce181a5 100644
--- a/tests/http/test_client.py
+++ b/tests/http/test_client.py
@@ -16,12 +16,23 @@ from io import BytesIO
from mock import Mock
+from netaddr import IPSet
+
+from twisted.internet.error import DNSLookupError
from twisted.python.failure import Failure
-from twisted.web.client import ResponseDone
+from twisted.test.proto_helpers import AccumulatingProtocol
+from twisted.web.client import Agent, ResponseDone
from twisted.web.iweb import UNKNOWN_LENGTH
-from synapse.http.client import BodyExceededMaxSize, read_body_with_max_size
+from synapse.api.errors import SynapseError
+from synapse.http.client import (
+ BlacklistingAgentWrapper,
+ BlacklistingReactorWrapper,
+ BodyExceededMaxSize,
+ read_body_with_max_size,
+)
+from tests.server import FakeTransport, get_clock
from tests.unittest import TestCase
@@ -119,3 +130,114 @@ class ReadBodyWithMaxSizeTests(TestCase):
# The data is never consumed.
self.assertEqual(result.getvalue(), b"")
+
+
+class BlacklistingAgentTest(TestCase):
+ def setUp(self):
+ self.reactor, self.clock = get_clock()
+
+ self.safe_domain, self.safe_ip = b"safe.test", b"1.2.3.4"
+ self.unsafe_domain, self.unsafe_ip = b"danger.test", b"5.6.7.8"
+ self.allowed_domain, self.allowed_ip = b"allowed.test", b"5.1.1.1"
+
+ # Configure the reactor's DNS resolver.
+ for (domain, ip) in (
+ (self.safe_domain, self.safe_ip),
+ (self.unsafe_domain, self.unsafe_ip),
+ (self.allowed_domain, self.allowed_ip),
+ ):
+ self.reactor.lookups[domain.decode()] = ip.decode()
+ self.reactor.lookups[ip.decode()] = ip.decode()
+
+ self.ip_whitelist = IPSet([self.allowed_ip.decode()])
+ self.ip_blacklist = IPSet(["5.0.0.0/8"])
+
+ def test_reactor(self):
+ """Apply the blacklisting reactor and ensure it properly blocks connections to particular domains and IPs."""
+ agent = Agent(
+ BlacklistingReactorWrapper(
+ self.reactor,
+ ip_whitelist=self.ip_whitelist,
+ ip_blacklist=self.ip_blacklist,
+ ),
+ )
+
+ # The unsafe domains and IPs should be rejected.
+ for domain in (self.unsafe_domain, self.unsafe_ip):
+ self.failureResultOf(
+ agent.request(b"GET", b"http://" + domain), DNSLookupError
+ )
+
+ # The safe domains IPs should be accepted.
+ for domain in (
+ self.safe_domain,
+ self.allowed_domain,
+ self.safe_ip,
+ self.allowed_ip,
+ ):
+ d = agent.request(b"GET", b"http://" + domain)
+
+ # Grab the latest TCP connection.
+ (
+ host,
+ port,
+ client_factory,
+ _timeout,
+ _bindAddress,
+ ) = self.reactor.tcpClients[-1]
+
+ # Make the connection and pump data through it.
+ client = client_factory.buildProtocol(None)
+ server = AccumulatingProtocol()
+ server.makeConnection(FakeTransport(client, self.reactor))
+ client.makeConnection(FakeTransport(server, self.reactor))
+ client.dataReceived(
+ b"HTTP/1.0 200 OK\r\nContent-Length: 0\r\nContent-Type: text/html\r\n\r\n"
+ )
+
+ response = self.successResultOf(d)
+ self.assertEqual(response.code, 200)
+
+ def test_agent(self):
+ """Apply the blacklisting agent and ensure it properly blocks connections to particular IPs."""
+ agent = BlacklistingAgentWrapper(
+ Agent(self.reactor),
+ ip_whitelist=self.ip_whitelist,
+ ip_blacklist=self.ip_blacklist,
+ )
+
+ # The unsafe IPs should be rejected.
+ self.failureResultOf(
+ agent.request(b"GET", b"http://" + self.unsafe_ip), SynapseError
+ )
+
+ # The safe and unsafe domains and safe IPs should be accepted.
+ for domain in (
+ self.safe_domain,
+ self.unsafe_domain,
+ self.allowed_domain,
+ self.safe_ip,
+ self.allowed_ip,
+ ):
+ d = agent.request(b"GET", b"http://" + domain)
+
+ # Grab the latest TCP connection.
+ (
+ host,
+ port,
+ client_factory,
+ _timeout,
+ _bindAddress,
+ ) = self.reactor.tcpClients[-1]
+
+ # Make the connection and pump data through it.
+ client = client_factory.buildProtocol(None)
+ server = AccumulatingProtocol()
+ server.makeConnection(FakeTransport(client, self.reactor))
+ client.makeConnection(FakeTransport(server, self.reactor))
+ client.dataReceived(
+ b"HTTP/1.0 200 OK\r\nContent-Length: 0\r\nContent-Type: text/html\r\n\r\n"
+ )
+
+ response = self.successResultOf(d)
+ self.assertEqual(response.code, 200)
diff --git a/tests/replication/_base.py b/tests/replication/_base.py
index f6a6aed3..67b79136 100644
--- a/tests/replication/_base.py
+++ b/tests/replication/_base.py
@@ -13,15 +13,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
-from typing import Any, Callable, Dict, List, Optional, Tuple
-
-import attr
+from typing import Any, Callable, Dict, List, Optional, Tuple, Type
from twisted.internet.interfaces import IConsumer, IPullProducer, IReactorTime
from twisted.internet.protocol import Protocol
from twisted.internet.task import LoopingCall
from twisted.web.http import HTTPChannel
from twisted.web.resource import Resource
+from twisted.web.server import Request, Site
from synapse.app.generic_worker import (
GenericWorkerReplicationHandler,
@@ -32,7 +31,10 @@ from synapse.http.site import SynapseRequest, SynapseSite
from synapse.replication.http import ReplicationRestResource
from synapse.replication.tcp.handler import ReplicationCommandHandler
from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol
-from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory
+from synapse.replication.tcp.resource import (
+ ReplicationStreamProtocolFactory,
+ ServerReplicationStreamProtocol,
+)
from synapse.server import HomeServer
from synapse.util import Clock
@@ -59,7 +61,9 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
# build a replication server
server_factory = ReplicationStreamProtocolFactory(hs)
self.streamer = hs.get_replication_streamer()
- self.server = server_factory.buildProtocol(None)
+ self.server = server_factory.buildProtocol(
+ None
+ ) # type: ServerReplicationStreamProtocol
# Make a new HomeServer object for the worker
self.reactor.lookups["testserv"] = "1.2.3.4"
@@ -152,12 +156,8 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
# Set up client side protocol
client_protocol = client_factory.buildProtocol(None)
- request_factory = OneShotRequestFactory()
-
# Set up the server side protocol
- channel = _PushHTTPChannel(self.reactor)
- channel.requestFactory = request_factory
- channel.site = self.site
+ channel = _PushHTTPChannel(self.reactor, SynapseRequest, self.site)
# Connect client to server and vice versa.
client_to_server_transport = FakeTransport(
@@ -179,7 +179,7 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
server_to_client_transport.loseConnection()
client_to_server_transport.loseConnection()
- return request_factory.request
+ return channel.request
def assert_request_is_get_repl_stream_updates(
self, request: SynapseRequest, stream_name: str
@@ -188,8 +188,9 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
fetching updates for given stream.
"""
+ path = request.path # type: bytes # type: ignore
self.assertRegex(
- request.path,
+ path,
br"^/_synapse/replication/get_repl_stream_updates/%s/[^/]+$"
% (stream_name.encode("ascii"),),
)
@@ -232,7 +233,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
if self.hs.config.redis.redis_enabled:
# Handle attempts to connect to fake redis server.
self.reactor.add_tcp_client_callback(
- "localhost",
+ b"localhost",
6379,
self.connect_any_redis_attempts,
)
@@ -387,12 +388,8 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
# Set up client side protocol
client_protocol = client_factory.buildProtocol(None)
- request_factory = OneShotRequestFactory()
-
# Set up the server side protocol
- channel = _PushHTTPChannel(self.reactor)
- channel.requestFactory = request_factory
- channel.site = self._hs_to_site[hs]
+ channel = _PushHTTPChannel(self.reactor, SynapseRequest, self._hs_to_site[hs])
# Connect client to server and vice versa.
client_to_server_transport = FakeTransport(
@@ -418,7 +415,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
clients = self.reactor.tcpClients
while clients:
(host, port, client_factory, _timeout, _bindAddress) = clients.pop(0)
- self.assertEqual(host, "localhost")
+ self.assertEqual(host, b"localhost")
self.assertEqual(port, 6379)
client_protocol = client_factory.buildProtocol(None)
@@ -450,21 +447,6 @@ class TestReplicationDataHandler(GenericWorkerReplicationHandler):
self.received_rdata_rows.append((stream_name, token, r))
-@attr.s()
-class OneShotRequestFactory:
- """A simple request factory that generates a single `SynapseRequest` and
- stores it for future use. Can only be used once.
- """
-
- request = attr.ib(default=None)
-
- def __call__(self, *args, **kwargs):
- assert self.request is None
-
- self.request = SynapseRequest(*args, **kwargs)
- return self.request
-
-
class _PushHTTPChannel(HTTPChannel):
"""A HTTPChannel that wraps pull producers to push producers.
@@ -475,9 +457,13 @@ class _PushHTTPChannel(HTTPChannel):
makes it very hard to test.
"""
- def __init__(self, reactor: IReactorTime):
+ def __init__(
+ self, reactor: IReactorTime, request_factory: Type[Request], site: Site
+ ):
super().__init__()
self.reactor = reactor
+ self.requestFactory = request_factory
+ self.site = site
self._pull_to_push_producer = None # type: Optional[_PullToPushProducer]
@@ -503,6 +489,11 @@ class _PushHTTPChannel(HTTPChannel):
request.responseHeaders.setRawHeaders(b"connection", [b"close"])
return False
+ def requestDone(self, request):
+ # Store the request for inspection.
+ self.request = request
+ super().requestDone(request)
+
class _PullToPushProducer:
"""A push producer that wraps a pull producer."""
@@ -590,6 +581,8 @@ class FakeRedisPubSubServer:
class FakeRedisPubSubProtocol(Protocol):
"""A connection from a client talking to the fake Redis server."""
+ transport = None # type: Optional[FakeTransport]
+
def __init__(self, server: FakeRedisPubSubServer):
self._server = server
self._reader = hiredis.Reader()
@@ -634,6 +627,8 @@ class FakeRedisPubSubProtocol(Protocol):
def send(self, msg):
"""Send a message back to the client."""
+ assert self.transport is not None
+
raw = self.encode(msg).encode("utf-8")
self.transport.write(raw)
diff --git a/tests/replication/test_federation_ack.py b/tests/replication/test_federation_ack.py
index f235f1bd..0d9e3bb1 100644
--- a/tests/replication/test_federation_ack.py
+++ b/tests/replication/test_federation_ack.py
@@ -17,7 +17,7 @@ import mock
from synapse.app.generic_worker import GenericWorkerServer
from synapse.replication.tcp.commands import FederationAckCommand
-from synapse.replication.tcp.protocol import AbstractConnection
+from synapse.replication.tcp.protocol import IReplicationConnection
from synapse.replication.tcp.streams.federation import FederationStream
from tests.unittest import HomeserverTestCase
@@ -51,8 +51,10 @@ class FederationAckTestCase(HomeserverTestCase):
"""
rch = self.hs.get_tcp_replication()
- # wire up the ReplicationCommandHandler to a mock connection
- mock_connection = mock.Mock(spec=AbstractConnection)
+ # wire up the ReplicationCommandHandler to a mock connection, which needs
+ # to implement IReplicationConnection. (Note that Mock doesn't understand
+ # interfaces, but casing an interface to a list gives the attributes.)
+ mock_connection = mock.Mock(spec=list(IReplicationConnection))
rch.new_connection(mock_connection)
# tell it it received an RDATA row
diff --git a/tests/rest/client/v1/test_login.py b/tests/rest/client/v1/test_login.py
index 20af3285..988821b1 100644
--- a/tests/rest/client/v1/test_login.py
+++ b/tests/rest/client/v1/test_login.py
@@ -437,14 +437,16 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
channel = self.make_request("GET", "/_matrix/client/r0/login")
self.assertEqual(channel.code, 200, channel.result)
- expected_flows = [
- {"type": "m.login.cas"},
- {"type": "m.login.sso"},
- {"type": "m.login.token"},
- {"type": "m.login.password"},
- ] + ADDITIONAL_LOGIN_FLOWS
+ expected_flow_types = [
+ "m.login.cas",
+ "m.login.sso",
+ "m.login.token",
+ "m.login.password",
+ ] + [f["type"] for f in ADDITIONAL_LOGIN_FLOWS]
- self.assertCountEqual(channel.json_body["flows"], expected_flows)
+ self.assertCountEqual(
+ [f["type"] for f in channel.json_body["flows"]], expected_flow_types
+ )
@override_config({"experimental_features": {"msc2858_enabled": True}})
def test_get_msc2858_login_flows(self):
@@ -636,22 +638,25 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
)
self.assertEqual(channel.code, 400, channel.result)
- def test_client_idp_redirect_msc2858_disabled(self):
- """If the client tries to pick an IdP but MSC2858 is disabled, return a 400"""
- channel = self._make_sso_redirect_request(True, "oidc")
- self.assertEqual(channel.code, 400, channel.result)
- self.assertEqual(channel.json_body["errcode"], "M_UNRECOGNIZED")
-
- @override_config({"experimental_features": {"msc2858_enabled": True}})
def test_client_idp_redirect_to_unknown(self):
"""If the client tries to pick an unknown IdP, return a 404"""
- channel = self._make_sso_redirect_request(True, "xxx")
+ channel = self._make_sso_redirect_request(False, "xxx")
self.assertEqual(channel.code, 404, channel.result)
self.assertEqual(channel.json_body["errcode"], "M_NOT_FOUND")
- @override_config({"experimental_features": {"msc2858_enabled": True}})
def test_client_idp_redirect_to_oidc(self):
"""If the client pick a known IdP, redirect to it"""
+ channel = self._make_sso_redirect_request(False, "oidc")
+ self.assertEqual(channel.code, 302, channel.result)
+ oidc_uri = channel.headers.getRawHeaders("Location")[0]
+ oidc_uri_path, oidc_uri_query = oidc_uri.split("?", 1)
+
+ # it should redirect us to the auth page of the OIDC server
+ self.assertEqual(oidc_uri_path, TEST_OIDC_AUTH_ENDPOINT)
+
+ @override_config({"experimental_features": {"msc2858_enabled": True}})
+ def test_client_msc2858_redirect_to_oidc(self):
+ """Test the unstable API"""
channel = self._make_sso_redirect_request(True, "oidc")
self.assertEqual(channel.code, 302, channel.result)
oidc_uri = channel.headers.getRawHeaders("Location")[0]
@@ -660,6 +665,12 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
# it should redirect us to the auth page of the OIDC server
self.assertEqual(oidc_uri_path, TEST_OIDC_AUTH_ENDPOINT)
+ def test_client_idp_redirect_msc2858_disabled(self):
+ """If the client tries to use the MSC2858 endpoint but MSC2858 is disabled, return a 400"""
+ channel = self._make_sso_redirect_request(True, "oidc")
+ self.assertEqual(channel.code, 400, channel.result)
+ self.assertEqual(channel.json_body["errcode"], "M_UNRECOGNIZED")
+
def _make_sso_redirect_request(
self, unstable_endpoint: bool = False, idp_prov: Optional[str] = None
):
diff --git a/tests/rest/media/v1/test_media_storage.py b/tests/rest/media/v1/test_media_storage.py
index 36d1e6bc..9f77125f 100644
--- a/tests/rest/media/v1/test_media_storage.py
+++ b/tests/rest/media/v1/test_media_storage.py
@@ -105,7 +105,7 @@ class MediaStorageTests(unittest.HomeserverTestCase):
self.assertEqual(test_body, body)
-@attr.s
+@attr.s(slots=True, frozen=True)
class _TestImage:
"""An image for testing thumbnailing with the expected results
@@ -117,13 +117,15 @@ class _TestImage:
test should just check for success.
expected_scaled: The expected bytes from scaled thumbnailing, or None if
test should just check for a valid image returned.
+ expected_found: True if the file should exist on the server, or False if
+ a 404 is expected.
"""
data = attr.ib(type=bytes)
content_type = attr.ib(type=bytes)
extension = attr.ib(type=bytes)
- expected_cropped = attr.ib(type=Optional[bytes])
- expected_scaled = attr.ib(type=Optional[bytes])
+ expected_cropped = attr.ib(type=Optional[bytes], default=None)
+ expected_scaled = attr.ib(type=Optional[bytes], default=None)
expected_found = attr.ib(default=True, type=bool)
@@ -153,6 +155,21 @@ class _TestImage:
),
),
),
+ # small png with transparency.
+ (
+ _TestImage(
+ unhexlify(
+ b"89504e470d0a1a0a0000000d49484452000000010000000101000"
+ b"00000376ef9240000000274524e5300010194fdae0000000a4944"
+ b"4154789c636800000082008177cd72b60000000049454e44ae426"
+ b"082"
+ ),
+ b"image/png",
+ b".png",
+ # Note that we don't check the output since it varies across
+ # different versions of Pillow.
+ ),
+ ),
# small lossless webp
(
_TestImage(
@@ -162,8 +179,6 @@ class _TestImage:
),
b"image/webp",
b".webp",
- None,
- None,
),
),
# an empty file
@@ -172,9 +187,7 @@ class _TestImage:
b"",
b"image/gif",
b".gif",
- None,
- None,
- False,
+ expected_found=False,
),
),
],
diff --git a/tests/server.py b/tests/server.py
index 939a0008..2287d200 100644
--- a/tests/server.py
+++ b/tests/server.py
@@ -16,6 +16,7 @@ from twisted.internet.interfaces import (
IReactorPluggableNameResolver,
IReactorTCP,
IResolverSimple,
+ ITransport,
)
from twisted.python.failure import Failure
from twisted.test.proto_helpers import AccumulatingProtocol, MemoryReactorClock
@@ -188,7 +189,7 @@ class FakeSite:
def make_request(
reactor,
- site: Site,
+ site: Union[Site, FakeSite],
method,
path,
content=b"",
@@ -467,6 +468,7 @@ def get_clock():
return clock, hs_clock
+@implementer(ITransport)
@attr.s(cmp=False)
class FakeTransport:
"""
diff --git a/tests/storage/test_event_federation.py b/tests/storage/test_event_federation.py
index 06000f81..d597d712 100644
--- a/tests/storage/test_event_federation.py
+++ b/tests/storage/test_event_federation.py
@@ -118,8 +118,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
r = self.get_success(self.store.get_rooms_with_many_extremities(5, 1, [room1]))
self.assertTrue(r == [room2] or r == [room3])
- @parameterized.expand([(True,), (False,)])
- def test_auth_difference(self, use_chain_cover_index: bool):
+ def _setup_auth_chain(self, use_chain_cover_index: bool) -> str:
room_id = "@ROOM:local"
# The silly auth graph we use to test the auth difference algorithm,
@@ -165,7 +164,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
"j": 1,
}
- # Mark the room as not having a cover index
+ # Mark the room as maybe having a cover index.
def store_room(txn):
self.store.db_pool.simple_insert_txn(
@@ -222,6 +221,77 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
)
)
+ return room_id
+
+ @parameterized.expand([(True,), (False,)])
+ def test_auth_chain_ids(self, use_chain_cover_index: bool):
+ room_id = self._setup_auth_chain(use_chain_cover_index)
+
+ # a and b have the same auth chain.
+ auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["a"]))
+ self.assertCountEqual(auth_chain_ids, ["e", "f", "g", "h", "i", "j", "k"])
+ auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["b"]))
+ self.assertCountEqual(auth_chain_ids, ["e", "f", "g", "h", "i", "j", "k"])
+ auth_chain_ids = self.get_success(
+ self.store.get_auth_chain_ids(room_id, ["a", "b"])
+ )
+ self.assertCountEqual(auth_chain_ids, ["e", "f", "g", "h", "i", "j", "k"])
+
+ auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["c"]))
+ self.assertCountEqual(auth_chain_ids, ["g", "h", "i", "j", "k"])
+
+ # d and e have the same auth chain.
+ auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["d"]))
+ self.assertCountEqual(auth_chain_ids, ["f", "g", "h", "i", "j", "k"])
+ auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["e"]))
+ self.assertCountEqual(auth_chain_ids, ["f", "g", "h", "i", "j", "k"])
+
+ auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["f"]))
+ self.assertCountEqual(auth_chain_ids, ["g", "h", "i", "j", "k"])
+
+ auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["g"]))
+ self.assertCountEqual(auth_chain_ids, ["h", "i", "j", "k"])
+
+ auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["h"]))
+ self.assertEqual(auth_chain_ids, ["k"])
+
+ auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["i"]))
+ self.assertEqual(auth_chain_ids, ["j"])
+
+ # j and k have no parents.
+ auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["j"]))
+ self.assertEqual(auth_chain_ids, [])
+ auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["k"]))
+ self.assertEqual(auth_chain_ids, [])
+
+ # More complex input sequences.
+ auth_chain_ids = self.get_success(
+ self.store.get_auth_chain_ids(room_id, ["b", "c", "d"])
+ )
+ self.assertCountEqual(auth_chain_ids, ["e", "f", "g", "h", "i", "j", "k"])
+
+ auth_chain_ids = self.get_success(
+ self.store.get_auth_chain_ids(room_id, ["h", "i"])
+ )
+ self.assertCountEqual(auth_chain_ids, ["k", "j"])
+
+ # e gets returned even though include_given is false, but it is in the
+ # auth chain of b.
+ auth_chain_ids = self.get_success(
+ self.store.get_auth_chain_ids(room_id, ["b", "e"])
+ )
+ self.assertCountEqual(auth_chain_ids, ["e", "f", "g", "h", "i", "j", "k"])
+
+ # Test include_given.
+ auth_chain_ids = self.get_success(
+ self.store.get_auth_chain_ids(room_id, ["i"], include_given=True)
+ )
+ self.assertCountEqual(auth_chain_ids, ["i", "j"])
+
+ @parameterized.expand([(True,), (False,)])
+ def test_auth_difference(self, use_chain_cover_index: bool):
+ room_id = self._setup_auth_chain(use_chain_cover_index)
+
# Now actually test that various combinations give the right result:
difference = self.get_success(
diff --git a/tests/storage/test_purge.py b/tests/storage/test_purge.py
index a06ad2c0..41af8c48 100644
--- a/tests/storage/test_purge.py
+++ b/tests/storage/test_purge.py
@@ -13,9 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from twisted.internet import defer
-
-from synapse.api.errors import NotFoundError
+from synapse.api.errors import NotFoundError, SynapseError
from synapse.rest.client.v1 import room
from tests.unittest import HomeserverTestCase
@@ -33,9 +31,12 @@ class PurgeTests(HomeserverTestCase):
def prepare(self, reactor, clock, hs):
self.room_id = self.helper.create_room_as(self.user_id)
- def test_purge(self):
+ self.store = hs.get_datastore()
+ self.storage = self.hs.get_storage()
+
+ def test_purge_history(self):
"""
- Purging a room will delete everything before the topological point.
+ Purging a room history will delete everything before the topological point.
"""
# Send four messages to the room
first = self.helper.send(self.room_id, body="test1")
@@ -43,30 +44,27 @@ class PurgeTests(HomeserverTestCase):
third = self.helper.send(self.room_id, body="test3")
last = self.helper.send(self.room_id, body="test4")
- store = self.hs.get_datastore()
- storage = self.hs.get_storage()
-
# Get the topological token
token = self.get_success(
- store.get_topological_token_for_event(last["event_id"])
+ self.store.get_topological_token_for_event(last["event_id"])
)
token_str = self.get_success(token.to_string(self.hs.get_datastore()))
# Purge everything before this topological token
self.get_success(
- storage.purge_events.purge_history(self.room_id, token_str, True)
+ self.storage.purge_events.purge_history(self.room_id, token_str, True)
)
# 1-3 should fail and last will succeed, meaning that 1-3 are deleted
# and last is not.
- self.get_failure(store.get_event(first["event_id"]), NotFoundError)
- self.get_failure(store.get_event(second["event_id"]), NotFoundError)
- self.get_failure(store.get_event(third["event_id"]), NotFoundError)
- self.get_success(store.get_event(last["event_id"]))
+ self.get_failure(self.store.get_event(first["event_id"]), NotFoundError)
+ self.get_failure(self.store.get_event(second["event_id"]), NotFoundError)
+ self.get_failure(self.store.get_event(third["event_id"]), NotFoundError)
+ self.get_success(self.store.get_event(last["event_id"]))
- def test_purge_wont_delete_extrems(self):
+ def test_purge_history_wont_delete_extrems(self):
"""
- Purging a room will delete everything before the topological point.
+ Purging a room history will delete everything before the topological point.
"""
# Send four messages to the room
first = self.helper.send(self.room_id, body="test1")
@@ -74,22 +72,43 @@ class PurgeTests(HomeserverTestCase):
third = self.helper.send(self.room_id, body="test3")
last = self.helper.send(self.room_id, body="test4")
- storage = self.hs.get_datastore()
-
# Set the topological token higher than it should be
token = self.get_success(
- storage.get_topological_token_for_event(last["event_id"])
+ self.store.get_topological_token_for_event(last["event_id"])
)
event = "t{}-{}".format(token.topological + 1, token.stream + 1)
# Purge everything before this topological token
- purge = defer.ensureDeferred(storage.purge_history(self.room_id, event, True))
- self.pump()
- f = self.failureResultOf(purge)
+ f = self.get_failure(
+ self.storage.purge_events.purge_history(self.room_id, event, True),
+ SynapseError,
+ )
self.assertIn("greater than forward", f.value.args[0])
# Try and get the events
- self.get_success(storage.get_event(first["event_id"]))
- self.get_success(storage.get_event(second["event_id"]))
- self.get_success(storage.get_event(third["event_id"]))
- self.get_success(storage.get_event(last["event_id"]))
+ self.get_success(self.store.get_event(first["event_id"]))
+ self.get_success(self.store.get_event(second["event_id"]))
+ self.get_success(self.store.get_event(third["event_id"]))
+ self.get_success(self.store.get_event(last["event_id"]))
+
+ def test_purge_room(self):
+ """
+ Purging a room will delete everything about it.
+ """
+ # Send four messages to the room
+ first = self.helper.send(self.room_id, body="test1")
+
+ # Get the current room state.
+ state_handler = self.hs.get_state_handler()
+ create_event = self.get_success(
+ state_handler.get_current_state(self.room_id, "m.room.create", "")
+ )
+ self.assertIsNotNone(create_event)
+
+ # Purge everything before this topological token
+ self.get_success(self.storage.purge_events.purge_room(self.room_id))
+
+ # The events aren't found.
+ self.store._invalidate_get_event_cache(create_event.event_id)
+ self.get_failure(self.store.get_event(create_event.event_id), NotFoundError)
+ self.get_failure(self.store.get_event(first["event_id"]), NotFoundError)
diff --git a/tests/test_utils/logging_setup.py b/tests/test_utils/logging_setup.py
index 52ae5c57..74568b34 100644
--- a/tests/test_utils/logging_setup.py
+++ b/tests/test_utils/logging_setup.py
@@ -28,7 +28,7 @@ class ToTwistedHandler(logging.Handler):
def emit(self, record):
log_entry = self.format(record)
log_level = record.levelname.lower().replace("warning", "warn")
- self.tx_log.emit(
+ self.tx_log.emit( # type: ignore
twisted.logger.LogLevel.levelWithName(log_level), "{entry}", entry=log_entry
)
diff --git a/tests/util/caches/test_responsecache.py b/tests/util/caches/test_responsecache.py
new file mode 100644
index 00000000..f9a187b8
--- /dev/null
+++ b/tests/util/caches/test_responsecache.py
@@ -0,0 +1,131 @@
+# Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from synapse.util.caches.response_cache import ResponseCache
+
+from tests.server import get_clock
+from tests.unittest import TestCase
+
+
+class DeferredCacheTestCase(TestCase):
+ """
+ A TestCase class for ResponseCache.
+
+ The test-case function naming has some logic to it in it's parts, here's some notes about it:
+ wait: Denotes tests that have an element of "waiting" before its wrapped result becomes available
+ (Generally these just use .delayed_return instead of .instant_return in it's wrapped call.)
+ expire: Denotes tests that test expiry after assured existence.
+ (These have cache with a short timeout_ms=, shorter than will be tested through advancing the clock)
+ """
+
+ def setUp(self):
+ self.reactor, self.clock = get_clock()
+
+ def with_cache(self, name: str, ms: int = 0) -> ResponseCache:
+ return ResponseCache(self.clock, name, timeout_ms=ms)
+
+ @staticmethod
+ async def instant_return(o: str) -> str:
+ return o
+
+ async def delayed_return(self, o: str) -> str:
+ await self.clock.sleep(1)
+ return o
+
+ def test_cache_hit(self):
+ cache = self.with_cache("keeping_cache", ms=9001)
+
+ expected_result = "howdy"
+
+ wrap_d = cache.wrap(0, self.instant_return, expected_result)
+
+ self.assertEqual(
+ expected_result,
+ self.successResultOf(wrap_d),
+ "initial wrap result should be the same",
+ )
+ self.assertEqual(
+ expected_result,
+ self.successResultOf(cache.get(0)),
+ "cache should have the result",
+ )
+
+ def test_cache_miss(self):
+ cache = self.with_cache("trashing_cache", ms=0)
+
+ expected_result = "howdy"
+
+ wrap_d = cache.wrap(0, self.instant_return, expected_result)
+
+ self.assertEqual(
+ expected_result,
+ self.successResultOf(wrap_d),
+ "initial wrap result should be the same",
+ )
+ self.assertIsNone(cache.get(0), "cache should not have the result now")
+
+ def test_cache_expire(self):
+ cache = self.with_cache("short_cache", ms=1000)
+
+ expected_result = "howdy"
+
+ wrap_d = cache.wrap(0, self.instant_return, expected_result)
+
+ self.assertEqual(expected_result, self.successResultOf(wrap_d))
+ self.assertEqual(
+ expected_result,
+ self.successResultOf(cache.get(0)),
+ "cache should still have the result",
+ )
+
+ # cache eviction timer is handled
+ self.reactor.pump((2,))
+
+ self.assertIsNone(cache.get(0), "cache should not have the result now")
+
+ def test_cache_wait_hit(self):
+ cache = self.with_cache("neutral_cache")
+
+ expected_result = "howdy"
+
+ wrap_d = cache.wrap(0, self.delayed_return, expected_result)
+ self.assertNoResult(wrap_d)
+
+ # function wakes up, returns result
+ self.reactor.pump((2,))
+
+ self.assertEqual(expected_result, self.successResultOf(wrap_d))
+
+ def test_cache_wait_expire(self):
+ cache = self.with_cache("medium_cache", ms=3000)
+
+ expected_result = "howdy"
+
+ wrap_d = cache.wrap(0, self.delayed_return, expected_result)
+ self.assertNoResult(wrap_d)
+
+ # stop at 1 second to callback cache eviction callLater at that time, then another to set time at 2
+ self.reactor.pump((1, 1))
+
+ self.assertEqual(expected_result, self.successResultOf(wrap_d))
+ self.assertEqual(
+ expected_result,
+ self.successResultOf(cache.get(0)),
+ "cache should still have the result",
+ )
+
+ # (1 + 1 + 2) > 3.0, cache eviction timer is handled
+ self.reactor.pump((2,))
+
+ self.assertIsNone(cache.get(0), "cache should not have the result now")
diff --git a/tox.ini b/tox.ini
index a6d10537..9ff70fe3 100644
--- a/tox.ini
+++ b/tox.ini
@@ -189,7 +189,5 @@ commands=
[testenv:mypy]
deps =
{[base]deps}
- # Type hints are broken with Twisted > 20.3.0, see https://github.com/matrix-org/synapse/issues/9513
- twisted==20.3.0
extras = all,mypy
commands = mypy