summaryrefslogtreecommitdiff
path: root/tests
diff options
context:
space:
mode:
authorAndrej Shadura <andrewsh@debian.org>2022-06-19 15:20:00 +0200
committerAndrej Shadura <andrewsh@debian.org>2022-06-19 15:20:00 +0200
commit6dc64c92c6991f09910f3e6db368e6eeb4b1981e (patch)
treed8bab73ee460e0a96bbda9c5988d8025dbbe2eb3 /tests
parentc2d3cd76c24f663449bfa209ac920305f0501d3a (diff)
New upstream version 1.61.0
Diffstat (limited to 'tests')
-rw-r--r--tests/api/test_auth.py2
-rw-r--r--tests/api/test_filtering.py6
-rw-r--r--tests/api/test_ratelimiting.py2
-rw-r--r--tests/appservice/test_api.py101
-rw-r--r--tests/appservice/test_appservice.py3
-rw-r--r--tests/config/test_cache.py8
-rw-r--r--tests/crypto/test_event_signing.py17
-rw-r--r--tests/crypto/test_keyring.py2
-rw-r--r--tests/events/test_presence_router.py2
-rw-r--r--tests/events/test_snapshot.py4
-rw-r--r--tests/federation/test_federation_sender.py40
-rw-r--r--tests/federation/test_federation_server.py6
-rw-r--r--tests/federation/transport/server/__init__.py13
-rw-r--r--tests/federation/transport/server/test__base.py141
-rw-r--r--tests/federation/transport/test_server.py4
-rw-r--r--tests/handlers/test_appservice.py16
-rw-r--r--tests/handlers/test_directory.py3
-rw-r--r--tests/handlers/test_federation.py19
-rw-r--r--tests/handlers/test_federation_event.py15
-rw-r--r--tests/handlers/test_message.py14
-rw-r--r--tests/handlers/test_receipts.py94
-rw-r--r--tests/handlers/test_room_summary.py20
-rw-r--r--tests/handlers/test_sync.py1
-rw-r--r--tests/handlers/test_typing.py41
-rw-r--r--tests/handlers/test_user_directory.py3
-rw-r--r--tests/http/server/__init__.py13
-rw-r--r--tests/http/server/_base.py100
-rw-r--r--tests/http/test_fedclient.py6
-rw-r--r--tests/http/test_servlet.py74
-rw-r--r--tests/http/test_site.py2
-rw-r--r--tests/module_api/test_api.py2
-rw-r--r--tests/push/test_push_rule_evaluator.py84
-rw-r--r--tests/replication/_base.py54
-rw-r--r--tests/replication/http/__init__.py13
-rw-r--r--tests/replication/http/test__base.py106
-rw-r--r--tests/replication/slave/storage/_base.py2
-rw-r--r--tests/replication/slave/storage/test_events.py10
-rw-r--r--tests/replication/slave/storage/test_receipts.py12
-rw-r--r--tests/replication/tcp/test_handler.py73
-rw-r--r--tests/replication/test_sharded_event_persister.py14
-rw-r--r--tests/rest/admin/test_admin.py90
-rw-r--r--tests/rest/admin/test_room.py3
-rw-r--r--tests/rest/admin/test_user.py4
-rw-r--r--tests/rest/client/test_account.py1
-rw-r--r--tests/rest/client/test_auth.py41
-rw-r--r--tests/rest/client/test_devices.py (renamed from tests/rest/client/test_device_lists.py)43
-rw-r--r--tests/rest/client/test_events.py3
-rw-r--r--tests/rest/client/test_groups.py56
-rw-r--r--tests/rest/client/test_login.py2
-rw-r--r--tests/rest/client/test_mutual_rooms.py2
-rw-r--r--tests/rest/client/test_notifications.py91
-rw-r--r--tests/rest/client/test_register.py2
-rw-r--r--tests/rest/client/test_relations.py89
-rw-r--r--tests/rest/client/test_retention.py39
-rw-r--r--tests/rest/client/test_room_batch.py7
-rw-r--r--tests/rest/client/test_rooms.py267
-rw-r--r--tests/rest/client/test_sendtodevice.py5
-rw-r--r--tests/rest/client/test_shadow_banned.py4
-rw-r--r--tests/rest/client/test_sync.py41
-rw-r--r--tests/rest/client/test_typing.py3
-rw-r--r--tests/rest/client/test_upgrade_room.py38
-rw-r--r--tests/rest/media/test_media_retention.py321
-rw-r--r--tests/rest/media/v1/test_html_preview.py37
-rw-r--r--tests/rest/media/v1/test_url_preview.py35
-rw-r--r--tests/scripts/test_new_matrix_user.py13
-rw-r--r--tests/server.py14
-rw-r--r--tests/server_notices/test_resource_limits_server_notices.py11
-rw-r--r--tests/storage/databases/main/test_events_worker.py25
-rw-r--r--tests/storage/databases/main/test_lock.py54
-rw-r--r--tests/storage/test_appservice.py27
-rw-r--r--tests/storage/test_base.py2
-rw-r--r--tests/storage/test_devices.py7
-rw-r--r--tests/storage/test_event_chain.py3
-rw-r--r--tests/storage/test_event_federation.py9
-rw-r--r--tests/storage/test_events.py58
-rw-r--r--tests/storage/test_monthly_active_users.py83
-rw-r--r--tests/storage/test_purge.py19
-rw-r--r--tests/storage/test_redaction.py14
-rw-r--r--tests/storage/test_room.py12
-rw-r--r--tests/storage/test_room_search.py4
-rw-r--r--tests/storage/test_roommember.py2
-rw-r--r--tests/storage/test_state.py2
-rw-r--r--tests/storage/test_user_directory.py1
-rw-r--r--tests/storage/util/test_partial_state_events_tracker.py59
-rw-r--r--tests/test_mau.py3
-rw-r--r--tests/test_server.py111
-rw-r--r--tests/test_state.py36
-rw-r--r--tests/test_types.py21
-rw-r--r--tests/test_utils/event_injection.py2
-rw-r--r--tests/test_visibility.py46
-rw-r--r--tests/unittest.py2
-rw-r--r--tests/util/test_lrucache.py58
-rw-r--r--tests/utils.py2
93 files changed, 2500 insertions, 566 deletions
diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py
index d547df8a..bc75ddd3 100644
--- a/tests/api/test_auth.py
+++ b/tests/api/test_auth.py
@@ -404,7 +404,6 @@ class AuthTestCase(unittest.HomeserverTestCase):
appservice = ApplicationService(
"abcd",
- self.hs.config.server.server_name,
id="1234",
namespaces={
"users": [{"regex": "@_appservice.*:sender", "exclusive": True}]
@@ -433,7 +432,6 @@ class AuthTestCase(unittest.HomeserverTestCase):
appservice = ApplicationService(
"abcd",
- self.hs.config.server.server_name,
id="1234",
namespaces={
"users": [{"regex": "@_appservice.*:sender", "exclusive": True}]
diff --git a/tests/api/test_filtering.py b/tests/api/test_filtering.py
index 985d6e39..a269c477 100644
--- a/tests/api/test_filtering.py
+++ b/tests/api/test_filtering.py
@@ -20,7 +20,7 @@ from unittest.mock import patch
import jsonschema
from frozendict import frozendict
-from synapse.api.constants import EventContentFields
+from synapse.api.constants import EduTypes, EventContentFields
from synapse.api.errors import SynapseError
from synapse.api.filtering import Filter
from synapse.events import make_event_from_dict
@@ -85,13 +85,13 @@ class FilteringTestCase(unittest.HomeserverTestCase):
"org.matrix.not_labels": ["#work"],
},
"ephemeral": {
- "types": ["m.receipt", "m.typing"],
+ "types": [EduTypes.RECEIPT, EduTypes.TYPING],
"not_rooms": ["!726s6s6q:example.com"],
"not_senders": ["@spam:example.com"],
},
},
"presence": {
- "types": ["m.presence"],
+ "types": [EduTypes.PRESENCE],
"not_senders": ["@alice:example.com"],
},
"event_format": "client",
diff --git a/tests/api/test_ratelimiting.py b/tests/api/test_ratelimiting.py
index 483d5463..f661a9ff 100644
--- a/tests/api/test_ratelimiting.py
+++ b/tests/api/test_ratelimiting.py
@@ -31,7 +31,6 @@ class TestRatelimiter(unittest.HomeserverTestCase):
def test_allowed_appservice_ratelimited_via_can_requester_do_action(self):
appservice = ApplicationService(
None,
- "example.com",
id="foo",
rate_limited=True,
sender="@as:example.com",
@@ -62,7 +61,6 @@ class TestRatelimiter(unittest.HomeserverTestCase):
def test_allowed_appservice_via_can_requester_do_action(self):
appservice = ApplicationService(
None,
- "example.com",
id="foo",
rate_limited=False,
sender="@as:example.com",
diff --git a/tests/appservice/test_api.py b/tests/appservice/test_api.py
new file mode 100644
index 00000000..532b6763
--- /dev/null
+++ b/tests/appservice/test_api.py
@@ -0,0 +1,101 @@
+# Copyright 2022 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Any, List, Mapping
+from unittest.mock import Mock
+
+from twisted.test.proto_helpers import MemoryReactor
+
+from synapse.appservice import ApplicationService
+from synapse.server import HomeServer
+from synapse.types import JsonDict
+from synapse.util import Clock
+
+from tests import unittest
+
+PROTOCOL = "myproto"
+TOKEN = "myastoken"
+URL = "http://mytestservice"
+
+
+class ApplicationServiceApiTestCase(unittest.HomeserverTestCase):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer):
+ self.api = hs.get_application_service_api()
+ self.service = ApplicationService(
+ id="unique_identifier",
+ sender="@as:test",
+ url=URL,
+ token="unused",
+ hs_token=TOKEN,
+ )
+
+ def test_query_3pe_authenticates_token(self):
+ """
+ Tests that 3pe queries to the appservice are authenticated
+ with the appservice's token.
+ """
+
+ SUCCESS_RESULT_USER = [
+ {
+ "protocol": PROTOCOL,
+ "userid": "@a:user",
+ "fields": {
+ "more": "fields",
+ },
+ }
+ ]
+ SUCCESS_RESULT_LOCATION = [
+ {
+ "protocol": PROTOCOL,
+ "alias": "#a:room",
+ "fields": {
+ "more": "fields",
+ },
+ }
+ ]
+
+ URL_USER = f"{URL}/_matrix/app/unstable/thirdparty/user/{PROTOCOL}"
+ URL_LOCATION = f"{URL}/_matrix/app/unstable/thirdparty/location/{PROTOCOL}"
+
+ self.request_url = None
+
+ async def get_json(url: str, args: Mapping[Any, Any]) -> List[JsonDict]:
+ if not args.get(b"access_token"):
+ raise RuntimeError("Access token not provided")
+
+ self.assertEqual(args.get(b"access_token"), TOKEN)
+ self.request_url = url
+ if url == URL_USER:
+ return SUCCESS_RESULT_USER
+ elif url == URL_LOCATION:
+ return SUCCESS_RESULT_LOCATION
+ else:
+ raise RuntimeError(
+ "URL provided was invalid. This should never be seen."
+ )
+
+ # We assign to a method, which mypy doesn't like.
+ self.api.get_json = Mock(side_effect=get_json) # type: ignore[assignment]
+
+ result = self.get_success(
+ self.api.query_3pe(self.service, "user", PROTOCOL, {b"some": [b"field"]})
+ )
+ self.assertEqual(self.request_url, URL_USER)
+ self.assertEqual(result, SUCCESS_RESULT_USER)
+ result = self.get_success(
+ self.api.query_3pe(
+ self.service, "location", PROTOCOL, {b"some": [b"field"]}
+ )
+ )
+ self.assertEqual(self.request_url, URL_LOCATION)
+ self.assertEqual(result, SUCCESS_RESULT_LOCATION)
diff --git a/tests/appservice/test_appservice.py b/tests/appservice/test_appservice.py
index edc584d0..3018d3fc 100644
--- a/tests/appservice/test_appservice.py
+++ b/tests/appservice/test_appservice.py
@@ -23,7 +23,7 @@ from tests.test_utils import simple_async_mock
def _regex(regex: str, exclusive: bool = True) -> Namespace:
- return Namespace(exclusive, None, re.compile(regex))
+ return Namespace(exclusive, re.compile(regex))
class ApplicationServiceTestCase(unittest.TestCase):
@@ -33,7 +33,6 @@ class ApplicationServiceTestCase(unittest.TestCase):
sender="@as:test",
url="some_url",
token="some_token",
- hostname="matrix.org", # only used by get_groups_for_user
)
self.event = Mock(
event_id="$abc:xyz",
diff --git a/tests/config/test_cache.py b/tests/config/test_cache.py
index 4bb82e81..d2b3c299 100644
--- a/tests/config/test_cache.py
+++ b/tests/config/test_cache.py
@@ -38,6 +38,7 @@ class CacheConfigTests(TestCase):
"SYNAPSE_NOT_CACHE": "BLAH",
}
self.config.read_config(config, config_dir_path="", data_dir_path="")
+ self.config.resize_all_caches()
self.assertEqual(dict(self.config.cache_factors), {"something_or_other": 2.0})
@@ -52,6 +53,7 @@ class CacheConfigTests(TestCase):
"SYNAPSE_CACHE_FACTOR_FOO": 1,
}
self.config.read_config(config, config_dir_path="", data_dir_path="")
+ self.config.resize_all_caches()
self.assertEqual(
dict(self.config.cache_factors),
@@ -71,6 +73,7 @@ class CacheConfigTests(TestCase):
config = {"caches": {"per_cache_factors": {"foo": 3}}}
self.config.read_config(config)
+ self.config.resize_all_caches()
self.assertEqual(cache.max_size, 300)
@@ -82,6 +85,7 @@ class CacheConfigTests(TestCase):
"""
config = {"caches": {"per_cache_factors": {"foo": 2}}}
self.config.read_config(config, config_dir_path="", data_dir_path="")
+ self.config.resize_all_caches()
cache = LruCache(100)
add_resizable_cache("foo", cache_resize_callback=cache.set_cache_factor)
@@ -99,6 +103,7 @@ class CacheConfigTests(TestCase):
config = {"caches": {"global_factor": 4}}
self.config.read_config(config, config_dir_path="", data_dir_path="")
+ self.config.resize_all_caches()
self.assertEqual(cache.max_size, 400)
@@ -110,6 +115,7 @@ class CacheConfigTests(TestCase):
"""
config = {"caches": {"global_factor": 1.5}}
self.config.read_config(config, config_dir_path="", data_dir_path="")
+ self.config.resize_all_caches()
cache = LruCache(100)
add_resizable_cache("foo", cache_resize_callback=cache.set_cache_factor)
@@ -128,6 +134,7 @@ class CacheConfigTests(TestCase):
"SYNAPSE_CACHE_FACTOR_CACHE_B": 3,
}
self.config.read_config(config, config_dir_path="", data_dir_path="")
+ self.config.resize_all_caches()
cache_a = LruCache(100)
add_resizable_cache("*cache_a*", cache_resize_callback=cache_a.set_cache_factor)
@@ -148,6 +155,7 @@ class CacheConfigTests(TestCase):
config = {"caches": {"event_cache_size": "10k"}}
self.config.read_config(config, config_dir_path="", data_dir_path="")
+ self.config.resize_all_caches()
cache = LruCache(
max_size=self.config.event_cache_size,
diff --git a/tests/crypto/test_event_signing.py b/tests/crypto/test_event_signing.py
index 06e0545a..8fa710c9 100644
--- a/tests/crypto/test_event_signing.py
+++ b/tests/crypto/test_event_signing.py
@@ -12,10 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-
-import nacl.signing
-import signedjson.types
-from unpaddedbase64 import decode_base64
+from signedjson.key import decode_signing_key_base64
+from signedjson.types import SigningKey
from synapse.api.room_versions import RoomVersions
from synapse.crypto.event_signing import add_hashes_and_signatures
@@ -25,7 +23,7 @@ from tests import unittest
# Perform these tests using given secret key so we get entirely deterministic
# signatures output that we can test against.
-SIGNING_KEY_SEED = decode_base64("YJDBA9Xnr2sVqXD9Vj7XVUnmFZcZrlw8Md7kMW+3XA1")
+SIGNING_KEY_SEED = "YJDBA9Xnr2sVqXD9Vj7XVUnmFZcZrlw8Md7kMW+3XA1"
KEY_ALG = "ed25519"
KEY_VER = "1"
@@ -36,14 +34,9 @@ HOSTNAME = "domain"
class EventSigningTestCase(unittest.TestCase):
def setUp(self):
- # NB: `signedjson` expects `nacl.signing.SigningKey` instances which have been
- # monkeypatched to include new `alg` and `version` attributes. This is captured
- # by the `signedjson.types.SigningKey` protocol.
- self.signing_key: signedjson.types.SigningKey = nacl.signing.SigningKey( # type: ignore[assignment]
- SIGNING_KEY_SEED
+ self.signing_key: SigningKey = decode_signing_key_base64(
+ KEY_ALG, KEY_VER, SIGNING_KEY_SEED
)
- self.signing_key.alg = KEY_ALG
- self.signing_key.version = KEY_VER
def test_sign_minimal(self):
event_dict = {
diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py
index d00ef24c..820a1a54 100644
--- a/tests/crypto/test_keyring.py
+++ b/tests/crypto/test_keyring.py
@@ -19,8 +19,8 @@ import attr
import canonicaljson
import signedjson.key
import signedjson.sign
-from nacl.signing import SigningKey
from signedjson.key import encode_verify_key_base64, get_verify_key
+from signedjson.types import SigningKey
from twisted.internet import defer
from twisted.internet.defer import Deferred, ensureDeferred
diff --git a/tests/events/test_presence_router.py b/tests/events/test_presence_router.py
index 3deb14c3..ffc3012a 100644
--- a/tests/events/test_presence_router.py
+++ b/tests/events/test_presence_router.py
@@ -439,7 +439,7 @@ class PresenceRouterTestCase(FederatingHomeserverTestCase):
for edu in edus:
# Make sure we're only checking presence-type EDUs
- if edu["edu_type"] != EduTypes.Presence:
+ if edu["edu_type"] != EduTypes.PRESENCE:
continue
# EDUs can contain multiple presence updates
diff --git a/tests/events/test_snapshot.py b/tests/events/test_snapshot.py
index defbc68c..8ddce83b 100644
--- a/tests/events/test_snapshot.py
+++ b/tests/events/test_snapshot.py
@@ -29,7 +29,7 @@ class TestEventContext(unittest.HomeserverTestCase):
def prepare(self, reactor, clock, hs):
self.store = hs.get_datastores().main
- self.storage = hs.get_storage()
+ self._storage_controllers = hs.get_storage_controllers()
self.user_id = self.register_user("u1", "pass")
self.user_tok = self.login("u1", "pass")
@@ -87,7 +87,7 @@ class TestEventContext(unittest.HomeserverTestCase):
def _check_serialize_deserialize(self, event, context):
serialized = self.get_success(context.serialize(event, self.store))
- d_context = EventContext.deserialize(self.storage, serialized)
+ d_context = EventContext.deserialize(self._storage_controllers, serialized)
self.assertEqual(context.state_group, d_context.state_group)
self.assertEqual(context.rejected, d_context.rejected)
diff --git a/tests/federation/test_federation_sender.py b/tests/federation/test_federation_sender.py
index 6b26353d..01a1db61 100644
--- a/tests/federation/test_federation_sender.py
+++ b/tests/federation/test_federation_sender.py
@@ -19,7 +19,7 @@ from signedjson.types import BaseKey, SigningKey
from twisted.internet import defer
-from synapse.api.constants import RoomEncryptionAlgorithms
+from synapse.api.constants import EduTypes, RoomEncryptionAlgorithms
from synapse.rest import admin
from synapse.rest.client import login
from synapse.types import JsonDict, ReadReceipt
@@ -30,16 +30,16 @@ from tests.unittest import HomeserverTestCase, override_config
class FederationSenderReceiptsTestCases(HomeserverTestCase):
def make_homeserver(self, reactor, clock):
- mock_state_handler = Mock(spec=["get_current_hosts_in_room"])
- # Ensure a new Awaitable is created for each call.
- mock_state_handler.get_current_hosts_in_room.return_value = make_awaitable(
- ["test", "host2"]
- )
- return self.setup_test_homeserver(
- state_handler=mock_state_handler,
+ hs = self.setup_test_homeserver(
federation_transport_client=Mock(spec=["send_transaction"]),
)
+ hs.get_storage_controllers().state.get_current_hosts_in_room = Mock(
+ return_value=make_awaitable({"test", "host2"})
+ )
+
+ return hs
+
@override_config({"send_federation": True})
def test_send_receipts(self):
mock_send_transaction = (
@@ -63,7 +63,7 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase):
data["edus"],
[
{
- "edu_type": "m.receipt",
+ "edu_type": EduTypes.RECEIPT,
"content": {
"room_id": {
"m.read": {
@@ -103,7 +103,7 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase):
data["edus"],
[
{
- "edu_type": "m.receipt",
+ "edu_type": EduTypes.RECEIPT,
"content": {
"room_id": {
"m.read": {
@@ -138,7 +138,7 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase):
data["edus"],
[
{
- "edu_type": "m.receipt",
+ "edu_type": EduTypes.RECEIPT,
"content": {
"room_id": {
"m.read": {
@@ -322,8 +322,10 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
# expect signing key update edu
self.assertEqual(len(self.edus), 2)
- self.assertEqual(self.edus.pop(0)["edu_type"], "m.signing_key_update")
- self.assertEqual(self.edus.pop(0)["edu_type"], "org.matrix.signing_key_update")
+ self.assertEqual(self.edus.pop(0)["edu_type"], EduTypes.SIGNING_KEY_UPDATE)
+ self.assertEqual(
+ self.edus.pop(0)["edu_type"], EduTypes.UNSTABLE_SIGNING_KEY_UPDATE
+ )
# sign the devices
d1_json = build_device_dict(u1, "D1", device1_signing_key)
@@ -348,7 +350,7 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
self.assertEqual(len(self.edus), 2)
stream_id = None # FIXME: there is a discontinuity in the stream IDs: see #7142
for edu in self.edus:
- self.assertEqual(edu["edu_type"], "m.device_list_update")
+ self.assertEqual(edu["edu_type"], EduTypes.DEVICE_LIST_UPDATE)
c = edu["content"]
if stream_id is not None:
self.assertEqual(c["prev_id"], [stream_id])
@@ -388,7 +390,7 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
# expect three edus, in an unknown order
self.assertEqual(len(self.edus), 3)
for edu in self.edus:
- self.assertEqual(edu["edu_type"], "m.device_list_update")
+ self.assertEqual(edu["edu_type"], EduTypes.DEVICE_LIST_UPDATE)
c = edu["content"]
self.assertGreaterEqual(
c.items(),
@@ -435,7 +437,7 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
self.assertEqual(len(self.edus), 3)
stream_id = None
for edu in self.edus:
- self.assertEqual(edu["edu_type"], "m.device_list_update")
+ self.assertEqual(edu["edu_type"], EduTypes.DEVICE_LIST_UPDATE)
c = edu["content"]
self.assertEqual(c["prev_id"], [stream_id] if stream_id is not None else [])
if stream_id is not None:
@@ -487,7 +489,7 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
# there should be a single update for this user.
self.assertEqual(len(self.edus), 1)
edu = self.edus.pop(0)
- self.assertEqual(edu["edu_type"], "m.device_list_update")
+ self.assertEqual(edu["edu_type"], EduTypes.DEVICE_LIST_UPDATE)
c = edu["content"]
# synapse uses an empty prev_id list to indicate "needs a full resync".
@@ -544,7 +546,7 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
# ... and we should get a single update for this user.
self.assertEqual(len(self.edus), 1)
edu = self.edus.pop(0)
- self.assertEqual(edu["edu_type"], "m.device_list_update")
+ self.assertEqual(edu["edu_type"], EduTypes.DEVICE_LIST_UPDATE)
c = edu["content"]
# synapse uses an empty prev_id list to indicate "needs a full resync".
@@ -560,7 +562,7 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
"""Check that the given EDU is an update for the given device
Returns the stream_id.
"""
- self.assertEqual(edu["edu_type"], "m.device_list_update")
+ self.assertEqual(edu["edu_type"], EduTypes.DEVICE_LIST_UPDATE)
content = edu["content"]
expected = {
diff --git a/tests/federation/test_federation_server.py b/tests/federation/test_federation_server.py
index b19365b8..413b3c94 100644
--- a/tests/federation/test_federation_server.py
+++ b/tests/federation/test_federation_server.py
@@ -134,6 +134,8 @@ class SendJoinFederationTests(unittest.FederatingHomeserverTestCase):
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer):
super().prepare(reactor, clock, hs)
+ self._storage_controllers = hs.get_storage_controllers()
+
# create the room
creator_user_id = self.register_user("kermit", "test")
tok = self.login("kermit", "test")
@@ -207,7 +209,7 @@ class SendJoinFederationTests(unittest.FederatingHomeserverTestCase):
# the room should show that the new user is a member
r = self.get_success(
- self.hs.get_state_handler().get_current_state(self._room_id)
+ self._storage_controllers.state.get_current_state(self._room_id)
)
self.assertEqual(r[("m.room.member", joining_user)].membership, "join")
@@ -258,7 +260,7 @@ class SendJoinFederationTests(unittest.FederatingHomeserverTestCase):
# the room should show that the new user is a member
r = self.get_success(
- self.hs.get_state_handler().get_current_state(self._room_id)
+ self._storage_controllers.state.get_current_state(self._room_id)
)
self.assertEqual(r[("m.room.member", joining_user)].membership, "join")
diff --git a/tests/federation/transport/server/__init__.py b/tests/federation/transport/server/__init__.py
new file mode 100644
index 00000000..3a5f22c0
--- /dev/null
+++ b/tests/federation/transport/server/__init__.py
@@ -0,0 +1,13 @@
+# Copyright 2022 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/tests/federation/transport/server/test__base.py b/tests/federation/transport/server/test__base.py
new file mode 100644
index 00000000..e63885c1
--- /dev/null
+++ b/tests/federation/transport/server/test__base.py
@@ -0,0 +1,141 @@
+# Copyright 2022 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from http import HTTPStatus
+from typing import Dict, List, Tuple
+
+from synapse.api.errors import Codes
+from synapse.federation.transport.server import BaseFederationServlet
+from synapse.federation.transport.server._base import Authenticator, _parse_auth_header
+from synapse.http.server import JsonResource, cancellable
+from synapse.server import HomeServer
+from synapse.types import JsonDict
+from synapse.util.ratelimitutils import FederationRateLimiter
+
+from tests import unittest
+from tests.http.server._base import EndpointCancellationTestHelperMixin
+
+
+class CancellableFederationServlet(BaseFederationServlet):
+ PATH = "/sleep"
+
+ def __init__(
+ self,
+ hs: HomeServer,
+ authenticator: Authenticator,
+ ratelimiter: FederationRateLimiter,
+ server_name: str,
+ ):
+ super().__init__(hs, authenticator, ratelimiter, server_name)
+ self.clock = hs.get_clock()
+
+ @cancellable
+ async def on_GET(
+ self, origin: str, content: None, query: Dict[bytes, List[bytes]]
+ ) -> Tuple[int, JsonDict]:
+ await self.clock.sleep(1.0)
+ return HTTPStatus.OK, {"result": True}
+
+ async def on_POST(
+ self, origin: str, content: JsonDict, query: Dict[bytes, List[bytes]]
+ ) -> Tuple[int, JsonDict]:
+ await self.clock.sleep(1.0)
+ return HTTPStatus.OK, {"result": True}
+
+
+class BaseFederationServletCancellationTests(
+ unittest.FederatingHomeserverTestCase, EndpointCancellationTestHelperMixin
+):
+ """Tests for `BaseFederationServlet` cancellation."""
+
+ skip = "`BaseFederationServlet` does not support cancellation yet."
+
+ path = f"{CancellableFederationServlet.PREFIX}{CancellableFederationServlet.PATH}"
+
+ def create_test_resource(self):
+ """Overrides `HomeserverTestCase.create_test_resource`."""
+ resource = JsonResource(self.hs)
+
+ CancellableFederationServlet(
+ hs=self.hs,
+ authenticator=Authenticator(self.hs),
+ ratelimiter=self.hs.get_federation_ratelimiter(),
+ server_name=self.hs.hostname,
+ ).register(resource)
+
+ return resource
+
+ def test_cancellable_disconnect(self) -> None:
+ """Test that handlers with the `@cancellable` flag can be cancelled."""
+ channel = self.make_signed_federation_request(
+ "GET", self.path, await_result=False
+ )
+
+ # Advance past all the rate limiting logic. If we disconnect too early, the
+ # request won't be processed.
+ self.pump()
+
+ self._test_disconnect(
+ self.reactor,
+ channel,
+ expect_cancellation=True,
+ expected_body={"error": "Request cancelled", "errcode": Codes.UNKNOWN},
+ )
+
+ def test_uncancellable_disconnect(self) -> None:
+ """Test that handlers without the `@cancellable` flag cannot be cancelled."""
+ channel = self.make_signed_federation_request(
+ "POST",
+ self.path,
+ content={},
+ await_result=False,
+ )
+
+ # Advance past all the rate limiting logic. If we disconnect too early, the
+ # request won't be processed.
+ self.pump()
+
+ self._test_disconnect(
+ self.reactor,
+ channel,
+ expect_cancellation=False,
+ expected_body={"result": True},
+ )
+
+
+class BaseFederationAuthorizationTests(unittest.TestCase):
+ def test_authorization_header(self) -> None:
+ """Tests that the Authorization header is parsed correctly."""
+
+ # test a "normal" Authorization header
+ self.assertEqual(
+ _parse_auth_header(
+ b'X-Matrix origin=foo,key="ed25519:1",sig="sig",destination="bar"'
+ ),
+ ("foo", "ed25519:1", "sig", "bar"),
+ )
+ # test an Authorization with extra spaces, upper-case names, and escaped
+ # characters
+ self.assertEqual(
+ _parse_auth_header(
+ b'X-Matrix ORIGIN=foo,KEY="ed25\\519:1",SIG="sig",destination="bar"'
+ ),
+ ("foo", "ed25519:1", "sig", "bar"),
+ )
+ self.assertEqual(
+ _parse_auth_header(
+ b'X-Matrix origin=foo,key="ed25519:1",sig="sig",destination="bar",extra_field=ignored'
+ ),
+ ("foo", "ed25519:1", "sig", "bar"),
+ )
diff --git a/tests/federation/transport/test_server.py b/tests/federation/transport/test_server.py
index 5f001c33..cfd550a0 100644
--- a/tests/federation/transport/test_server.py
+++ b/tests/federation/transport/test_server.py
@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from synapse.api.constants import EduTypes
+
from tests import unittest
from tests.unittest import DEBUG, override_config
@@ -50,7 +52,7 @@ class RoomDirectoryFederationTests(unittest.FederatingHomeserverTestCase):
"/_matrix/federation/v1/send/txn_id_1234/",
content={
"edus": [
- {"edu_type": "m.device_list_update", "content": {"foo": "bar"}}
+ {"edu_type": EduTypes.DEVICE_LIST_UPDATE, "content": {"foo": "bar"}}
],
"pdus": [],
},
diff --git a/tests/handlers/test_appservice.py b/tests/handlers/test_appservice.py
index 5b0cd1ab..d96d5aa1 100644
--- a/tests/handlers/test_appservice.py
+++ b/tests/handlers/test_appservice.py
@@ -22,6 +22,7 @@ from twisted.test.proto_helpers import MemoryReactor
import synapse.rest.admin
import synapse.storage
+from synapse.api.constants import EduTypes
from synapse.appservice import (
ApplicationService,
TransactionOneTimeKeyCounts,
@@ -434,16 +435,6 @@ class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase):
},
)
- # "Complete" a transaction.
- # All this really does for us is make an entry in the application_services_state
- # database table, which tracks the current stream_token per stream ID per AS.
- self.get_success(
- self.hs.get_datastores().main.complete_appservice_txn(
- 0,
- interested_appservice,
- )
- )
-
# Now, pretend that we receive a large burst of read receipts (300 total) that
# all come in at once.
for i in range(300):
@@ -486,7 +477,7 @@ class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase):
# Check that the ephemeral event is a read receipt with the expected structure
latest_read_receipt = all_ephemeral_events[-1]
- self.assertEqual(latest_read_receipt["type"], "m.receipt")
+ self.assertEqual(latest_read_receipt["type"], EduTypes.RECEIPT)
event_id = list(latest_read_receipt["content"].keys())[0]
self.assertEqual(
@@ -706,7 +697,6 @@ class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase):
# Create an application service
appservice = ApplicationService(
token=random_string(10),
- hostname="example.com",
id=random_string(10),
sender="@as:example.com",
rate_limited=False,
@@ -785,7 +775,6 @@ class ApplicationServicesHandlerDeviceListsTestCase(unittest.HomeserverTestCase)
# Create an appservice that is interested in "local_user"
appservice = ApplicationService(
token=random_string(10),
- hostname="example.com",
id=random_string(10),
sender="@as:example.com",
rate_limited=False,
@@ -852,7 +841,6 @@ class ApplicationServicesHandlerOtkCountsTestCase(unittest.HomeserverTestCase):
self._service_token = "VERYSECRET"
self._service = ApplicationService(
self._service_token,
- "as1.invalid",
"as1",
"@as.sender:test",
namespaces={
diff --git a/tests/handlers/test_directory.py b/tests/handlers/test_directory.py
index 11ad4422..53d49ca8 100644
--- a/tests/handlers/test_directory.py
+++ b/tests/handlers/test_directory.py
@@ -298,6 +298,7 @@ class CanonicalAliasTestCase(unittest.HomeserverTestCase):
self.store = hs.get_datastores().main
self.handler = hs.get_directory_handler()
self.state_handler = hs.get_state_handler()
+ self._storage_controllers = hs.get_storage_controllers()
# Create user
self.admin_user = self.register_user("admin", "pass", admin=True)
@@ -335,7 +336,7 @@ class CanonicalAliasTestCase(unittest.HomeserverTestCase):
def _get_canonical_alias(self):
"""Get the canonical alias state of the room."""
return self.get_success(
- self.state_handler.get_current_state(
+ self._storage_controllers.state.get_current_state_event(
self.room_id, EventTypes.CanonicalAlias, ""
)
)
diff --git a/tests/handlers/test_federation.py b/tests/handlers/test_federation.py
index 060ba5f5..e0eda545 100644
--- a/tests/handlers/test_federation.py
+++ b/tests/handlers/test_federation.py
@@ -50,7 +50,7 @@ class FederationTestCase(unittest.FederatingHomeserverTestCase):
hs = self.setup_test_homeserver(federation_http_client=None)
self.handler = hs.get_federation_handler()
self.store = hs.get_datastores().main
- self.state_store = hs.get_storage().state
+ self.state_storage_controller = hs.get_storage_controllers().state
self._event_auth_handler = hs.get_event_auth_handler()
return hs
@@ -237,7 +237,9 @@ class FederationTestCase(unittest.FederatingHomeserverTestCase):
)
current_state = self.get_success(
self.store.get_events_as_list(
- (self.get_success(self.store.get_current_state_ids(room_id))).values()
+ (
+ self.get_success(self.store.get_partial_current_state_ids(room_id))
+ ).values()
)
)
@@ -276,7 +278,11 @@ class FederationTestCase(unittest.FederatingHomeserverTestCase):
# federation handler wanting to backfill the fake event.
self.get_success(
federation_event_handler._process_received_pdu(
- self.OTHER_SERVER_NAME, event, state=current_state
+ self.OTHER_SERVER_NAME,
+ event,
+ state_ids={
+ (e.type, e.state_key): e.event_id for e in current_state
+ },
)
)
@@ -332,8 +338,11 @@ class FederationTestCase(unittest.FederatingHomeserverTestCase):
most_recent_prev_event_depth,
) = self.get_success(self.store.get_max_depth_of(prev_event_ids))
# mapping from (type, state_key) -> state_event_id
+ assert most_recent_prev_event_id is not None
prev_state_map = self.get_success(
- self.state_store.get_state_ids_for_event(most_recent_prev_event_id)
+ self.state_storage_controller.get_state_ids_for_event(
+ most_recent_prev_event_id
+ )
)
# List of state event ID's
prev_state_ids = list(prev_state_map.values())
@@ -505,7 +514,7 @@ class FederationTestCase(unittest.FederatingHomeserverTestCase):
self.get_success(d)
# sanity-check: the room should show that the new user is a member
- r = self.get_success(self.store.get_current_state_ids(room_id))
+ r = self.get_success(self.store.get_partial_current_state_ids(room_id))
self.assertEqual(r[(EventTypes.Member, other_user)], join_event.event_id)
return join_event
diff --git a/tests/handlers/test_federation_event.py b/tests/handlers/test_federation_event.py
index 489ba577..1a36c25c 100644
--- a/tests/handlers/test_federation_event.py
+++ b/tests/handlers/test_federation_event.py
@@ -70,7 +70,7 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase):
) -> None:
OTHER_USER = f"@user:{self.OTHER_SERVER_NAME}"
main_store = self.hs.get_datastores().main
- state_storage = self.hs.get_storage().state
+ state_storage_controller = self.hs.get_storage_controllers().state
# create the room
user_id = self.register_user("kermit", "test")
@@ -91,7 +91,9 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase):
event_injection.inject_member_event(self.hs, room_id, OTHER_USER, "join")
)
- initial_state_map = self.get_success(main_store.get_current_state_ids(room_id))
+ initial_state_map = self.get_success(
+ main_store.get_partial_current_state_ids(room_id)
+ )
auth_event_ids = [
initial_state_map[("m.room.create", "")],
@@ -146,9 +148,12 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase):
)
if prev_exists_as_outlier:
prev_event.internal_metadata.outlier = True
- persistence = self.hs.get_storage().persistence
+ persistence = self.hs.get_storage_controllers().persistence
self.get_success(
- persistence.persist_event(prev_event, EventContext.for_outlier())
+ persistence.persist_event(
+ prev_event,
+ EventContext.for_outlier(self.hs.get_storage_controllers()),
+ )
)
else:
@@ -214,7 +219,7 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase):
# check that the state at that event is as expected
state = self.get_success(
- state_storage.get_state_ids_for_event(pulled_event.event_id)
+ state_storage_controller.get_state_ids_for_event(pulled_event.event_id)
)
expected_state = {
(e.type, e.state_key): e.event_id for e in state_at_prev_event
diff --git a/tests/handlers/test_message.py b/tests/handlers/test_message.py
index f4f7ab48..44da96c7 100644
--- a/tests/handlers/test_message.py
+++ b/tests/handlers/test_message.py
@@ -37,7 +37,9 @@ class EventCreationTestCase(unittest.HomeserverTestCase):
def prepare(self, reactor, clock, hs):
self.handler = self.hs.get_event_creation_handler()
- self.persist_event_storage = self.hs.get_storage().persistence
+ self._persist_event_storage_controller = (
+ self.hs.get_storage_controllers().persistence
+ )
self.user_id = self.register_user("tester", "foobar")
self.access_token = self.login("tester", "foobar")
@@ -65,7 +67,9 @@ class EventCreationTestCase(unittest.HomeserverTestCase):
)
)
self.get_success(
- self.persist_event_storage.persist_event(memberEvent, memberEventContext)
+ self._persist_event_storage_controller.persist_event(
+ memberEvent, memberEventContext
+ )
)
return memberEvent, memberEventContext
@@ -129,7 +133,7 @@ class EventCreationTestCase(unittest.HomeserverTestCase):
self.assertNotEqual(event1.event_id, event3.event_id)
ret_event3, event_pos3, _ = self.get_success(
- self.persist_event_storage.persist_event(event3, context)
+ self._persist_event_storage_controller.persist_event(event3, context)
)
# Assert that the returned values match those from the initial event
@@ -143,7 +147,7 @@ class EventCreationTestCase(unittest.HomeserverTestCase):
self.assertNotEqual(event1.event_id, event3.event_id)
events, _ = self.get_success(
- self.persist_event_storage.persist_events([(event3, context)])
+ self._persist_event_storage_controller.persist_events([(event3, context)])
)
ret_event4 = events[0]
@@ -166,7 +170,7 @@ class EventCreationTestCase(unittest.HomeserverTestCase):
self.assertNotEqual(event1.event_id, event2.event_id)
events, _ = self.get_success(
- self.persist_event_storage.persist_events(
+ self._persist_event_storage_controller.persist_events(
[(event1, context1), (event2, context2)]
)
)
diff --git a/tests/handlers/test_receipts.py b/tests/handlers/test_receipts.py
index 0482a1ea..a95868b5 100644
--- a/tests/handlers/test_receipts.py
+++ b/tests/handlers/test_receipts.py
@@ -12,10 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-
+from copy import deepcopy
from typing import List
-from synapse.api.constants import ReceiptTypes
+from synapse.api.constants import EduTypes, ReceiptTypes
from synapse.types import JsonDict
from tests import unittest
@@ -39,7 +39,7 @@ class ReceiptsTestCase(unittest.HomeserverTestCase):
}
},
"room_id": "!jEsUZKDJdhlrceRyVU:example.org",
- "type": "m.receipt",
+ "type": EduTypes.RECEIPT,
}
],
[],
@@ -64,7 +64,7 @@ class ReceiptsTestCase(unittest.HomeserverTestCase):
},
},
"room_id": "!jEsUZKDJdhlrceRyVU:example.org",
- "type": "m.receipt",
+ "type": EduTypes.RECEIPT,
}
],
[
@@ -79,7 +79,7 @@ class ReceiptsTestCase(unittest.HomeserverTestCase):
}
},
"room_id": "!jEsUZKDJdhlrceRyVU:example.org",
- "type": "m.receipt",
+ "type": EduTypes.RECEIPT,
}
],
)
@@ -105,7 +105,7 @@ class ReceiptsTestCase(unittest.HomeserverTestCase):
},
},
"room_id": "!jEsUZKDJdhlrceRyVU:example.org",
- "type": "m.receipt",
+ "type": EduTypes.RECEIPT,
}
],
[
@@ -120,43 +120,7 @@ class ReceiptsTestCase(unittest.HomeserverTestCase):
}
},
"room_id": "!jEsUZKDJdhlrceRyVU:example.org",
- "type": "m.receipt",
- }
- ],
- )
-
- def test_handles_missing_content_of_m_read(self):
- self._test_filters_private(
- [
- {
- "content": {
- "$14356419ggffg114394fHBLK:matrix.org": {ReceiptTypes.READ: {}},
- "$1435641916114394fHBLK:matrix.org": {
- ReceiptTypes.READ: {
- "@user:jki.re": {
- "ts": 1436451550453,
- }
- }
- },
- },
- "room_id": "!jEsUZKDJdhlrceRyVU:example.org",
- "type": "m.receipt",
- }
- ],
- [
- {
- "content": {
- "$14356419ggffg114394fHBLK:matrix.org": {ReceiptTypes.READ: {}},
- "$1435641916114394fHBLK:matrix.org": {
- ReceiptTypes.READ: {
- "@user:jki.re": {
- "ts": 1436451550453,
- }
- }
- },
- },
- "room_id": "!jEsUZKDJdhlrceRyVU:example.org",
- "type": "m.receipt",
+ "type": EduTypes.RECEIPT,
}
],
)
@@ -176,7 +140,7 @@ class ReceiptsTestCase(unittest.HomeserverTestCase):
},
},
"room_id": "!jEsUZKDJdhlrceRyVU:example.org",
- "type": "m.receipt",
+ "type": EduTypes.RECEIPT,
}
],
[
@@ -191,7 +155,7 @@ class ReceiptsTestCase(unittest.HomeserverTestCase):
},
},
"room_id": "!jEsUZKDJdhlrceRyVU:example.org",
- "type": "m.receipt",
+ "type": EduTypes.RECEIPT,
}
],
)
@@ -210,7 +174,7 @@ class ReceiptsTestCase(unittest.HomeserverTestCase):
},
},
"room_id": "!jEsUZKDJdhlrceRyVU:example.org",
- "type": "m.receipt",
+ "type": EduTypes.RECEIPT,
},
{
"content": {
@@ -223,7 +187,7 @@ class ReceiptsTestCase(unittest.HomeserverTestCase):
},
},
"room_id": "!jEsUZKDJdhlrceRyVU:example.org",
- "type": "m.receipt",
+ "type": EduTypes.RECEIPT,
},
],
[
@@ -238,7 +202,7 @@ class ReceiptsTestCase(unittest.HomeserverTestCase):
}
},
"room_id": "!jEsUZKDJdhlrceRyVU:example.org",
- "type": "m.receipt",
+ "type": EduTypes.RECEIPT,
}
],
)
@@ -260,7 +224,7 @@ class ReceiptsTestCase(unittest.HomeserverTestCase):
},
},
"room_id": "!jEsUZKDJdhlrceRyVU:example.org",
- "type": "m.receipt",
+ "type": EduTypes.RECEIPT,
},
],
[
@@ -273,7 +237,7 @@ class ReceiptsTestCase(unittest.HomeserverTestCase):
},
},
"room_id": "!jEsUZKDJdhlrceRyVU:example.org",
- "type": "m.receipt",
+ "type": EduTypes.RECEIPT,
},
],
)
@@ -302,7 +266,7 @@ class ReceiptsTestCase(unittest.HomeserverTestCase):
},
},
"room_id": "!jEsUZKDJdhlrceRyVU:example.org",
- "type": "m.receipt",
+ "type": EduTypes.RECEIPT,
}
],
[
@@ -327,14 +291,38 @@ class ReceiptsTestCase(unittest.HomeserverTestCase):
}
},
"room_id": "!jEsUZKDJdhlrceRyVU:example.org",
- "type": "m.receipt",
+ "type": EduTypes.RECEIPT,
}
],
)
+ def test_we_do_not_mutate(self):
+ """Ensure the input values are not modified."""
+ events = [
+ {
+ "content": {
+ "$1435641916114394fHBLK:matrix.org": {
+ ReceiptTypes.READ_PRIVATE: {
+ "@rikj:jki.re": {
+ "ts": 1436451550453,
+ }
+ }
+ }
+ },
+ "room_id": "!jEsUZKDJdhlrceRyVU:example.org",
+ "type": EduTypes.RECEIPT,
+ }
+ ]
+ original_events = deepcopy(events)
+ self._test_filters_private(events, [])
+ # Since the events are fed in from a cache they should not be modified.
+ self.assertEqual(events, original_events)
+
def _test_filters_private(
self, events: List[JsonDict], expected_output: List[JsonDict]
):
"""Tests that the _filter_out_private returns the expected output"""
- filtered_events = self.event_source.filter_out_private(events, "@me:server.org")
+ filtered_events = self.event_source.filter_out_private_receipts(
+ events, "@me:server.org"
+ )
self.assertEqual(filtered_events, expected_output)
diff --git a/tests/handlers/test_room_summary.py b/tests/handlers/test_room_summary.py
index e74eb717..05466556 100644
--- a/tests/handlers/test_room_summary.py
+++ b/tests/handlers/test_room_summary.py
@@ -179,7 +179,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
result_children_ids.append(
[
(cs["room_id"], cs["state_key"])
- for cs in result_room.get("children_state")
+ for cs in result_room["children_state"]
]
)
@@ -772,7 +772,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
{
"room_id": public_room,
"world_readable": False,
- "join_rules": JoinRules.PUBLIC,
+ "join_rule": JoinRules.PUBLIC,
},
),
(
@@ -780,7 +780,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
{
"room_id": knock_room,
"world_readable": False,
- "join_rules": JoinRules.KNOCK,
+ "join_rule": JoinRules.KNOCK,
},
),
(
@@ -788,7 +788,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
{
"room_id": not_invited_room,
"world_readable": False,
- "join_rules": JoinRules.INVITE,
+ "join_rule": JoinRules.INVITE,
},
),
(
@@ -796,7 +796,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
{
"room_id": invited_room,
"world_readable": False,
- "join_rules": JoinRules.INVITE,
+ "join_rule": JoinRules.INVITE,
},
),
(
@@ -804,7 +804,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
{
"room_id": restricted_room,
"world_readable": False,
- "join_rules": JoinRules.RESTRICTED,
+ "join_rule": JoinRules.RESTRICTED,
"allowed_room_ids": [],
},
),
@@ -813,7 +813,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
{
"room_id": restricted_accessible_room,
"world_readable": False,
- "join_rules": JoinRules.RESTRICTED,
+ "join_rule": JoinRules.RESTRICTED,
"allowed_room_ids": [self.room],
},
),
@@ -822,7 +822,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
{
"room_id": world_readable_room,
"world_readable": True,
- "join_rules": JoinRules.INVITE,
+ "join_rule": JoinRules.INVITE,
},
),
(
@@ -830,7 +830,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
{
"room_id": joined_room,
"world_readable": False,
- "join_rules": JoinRules.INVITE,
+ "join_rule": JoinRules.INVITE,
},
),
)
@@ -911,7 +911,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
{
"room_id": fed_room,
"world_readable": False,
- "join_rules": JoinRules.INVITE,
+ "join_rule": JoinRules.INVITE,
},
)
diff --git a/tests/handlers/test_sync.py b/tests/handlers/test_sync.py
index 865b8b7e..db3302a4 100644
--- a/tests/handlers/test_sync.py
+++ b/tests/handlers/test_sync.py
@@ -160,6 +160,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
# Blow away caches (supported room versions can only change due to a restart).
self.store.get_rooms_for_user_with_stream_ordering.invalidate_all()
self.store._get_event_cache.clear()
+ self.store._event_ref.clear()
# The rooms should be excluded from the sync response.
# Get a new request key.
diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py
index 5f2e26a5..7af13331 100644
--- a/tests/handlers/test_typing.py
+++ b/tests/handlers/test_typing.py
@@ -21,6 +21,7 @@ from twisted.internet import defer
from twisted.test.proto_helpers import MemoryReactor
from twisted.web.resource import Resource
+from synapse.api.constants import EduTypes
from synapse.api.errors import AuthError
from synapse.federation.transport.server import TransportLayerServer
from synapse.server import HomeServer
@@ -128,10 +129,12 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
hs.get_event_auth_handler().check_host_in_room = check_host_in_room
- def get_joined_hosts_for_room(room_id: str):
+ async def get_current_hosts_in_room(room_id: str):
return {member.domain for member in self.room_members}
- self.datastore.get_joined_hosts_for_room = get_joined_hosts_for_room
+ hs.get_storage_controllers().state.get_current_hosts_in_room = (
+ get_current_hosts_in_room
+ )
async def get_users_in_room(room_id: str):
return {str(u) for u in self.room_members}
@@ -145,7 +148,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
)
)
- self.datastore.get_current_state_deltas = Mock(return_value=(0, None))
+ self.datastore.get_partial_current_state_deltas = Mock(return_value=(0, None))
self.datastore.get_to_device_stream_token = lambda: 0
self.datastore.get_new_device_msgs_for_remote = (
@@ -184,7 +187,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
events[0],
[
{
- "type": "m.typing",
+ "type": EduTypes.TYPING,
"room_id": ROOM_ID,
"content": {"user_ids": [U_APPLE.to_string()]},
}
@@ -209,7 +212,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
"farm",
path="/_matrix/federation/v1/send/1000000",
data=_expect_edu_transaction(
- "m.typing",
+ EduTypes.TYPING,
content={
"room_id": ROOM_ID,
"user_id": U_APPLE.to_string(),
@@ -231,7 +234,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
"PUT",
"/_matrix/federation/v1/send/1000000",
_make_edu_transaction_json(
- "m.typing",
+ EduTypes.TYPING,
content={
"room_id": ROOM_ID,
"user_id": U_ONION.to_string(),
@@ -254,7 +257,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
events[0],
[
{
- "type": "m.typing",
+ "type": EduTypes.TYPING,
"room_id": ROOM_ID,
"content": {"user_ids": [U_ONION.to_string()]},
}
@@ -270,7 +273,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
"PUT",
"/_matrix/federation/v1/send/1000000",
_make_edu_transaction_json(
- "m.typing",
+ EduTypes.TYPING,
content={
"room_id": OTHER_ROOM_ID,
"user_id": U_ONION.to_string(),
@@ -324,7 +327,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
"farm",
path="/_matrix/federation/v1/send/1000000",
data=_expect_edu_transaction(
- "m.typing",
+ EduTypes.TYPING,
content={
"room_id": ROOM_ID,
"user_id": U_APPLE.to_string(),
@@ -345,7 +348,13 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
)
self.assertEqual(
events[0],
- [{"type": "m.typing", "room_id": ROOM_ID, "content": {"user_ids": []}}],
+ [
+ {
+ "type": EduTypes.TYPING,
+ "room_id": ROOM_ID,
+ "content": {"user_ids": []},
+ }
+ ],
)
def test_typing_timeout(self) -> None:
@@ -379,7 +388,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
events[0],
[
{
- "type": "m.typing",
+ "type": EduTypes.TYPING,
"room_id": ROOM_ID,
"content": {"user_ids": [U_APPLE.to_string()]},
}
@@ -402,7 +411,13 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
)
self.assertEqual(
events[0],
- [{"type": "m.typing", "room_id": ROOM_ID, "content": {"user_ids": []}}],
+ [
+ {
+ "type": EduTypes.TYPING,
+ "room_id": ROOM_ID,
+ "content": {"user_ids": []},
+ }
+ ],
)
# SYN-230 - see if we can still set after timeout
@@ -433,7 +448,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
events[0],
[
{
- "type": "m.typing",
+ "type": EduTypes.TYPING,
"room_id": ROOM_ID,
"content": {"user_ids": [U_APPLE.to_string()]},
}
diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py
index 4d658d29..9e39cd97 100644
--- a/tests/handlers/test_user_directory.py
+++ b/tests/handlers/test_user_directory.py
@@ -60,7 +60,6 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
self.appservice = ApplicationService(
token="i_am_an_app_service",
- hostname="test",
id="1234",
namespaces={"users": [{"regex": r"@as_user.*", "exclusive": True}]},
# Note: this user does not match the regex above, so that tests
@@ -954,7 +953,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
)
self.get_success(
- self.hs.get_storage().persistence.persist_event(event, context)
+ self.hs.get_storage_controllers().persistence.persist_event(event, context)
)
def test_local_user_leaving_room_remains_in_user_directory(self) -> None:
diff --git a/tests/http/server/__init__.py b/tests/http/server/__init__.py
new file mode 100644
index 00000000..3a5f22c0
--- /dev/null
+++ b/tests/http/server/__init__.py
@@ -0,0 +1,13 @@
+# Copyright 2022 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/tests/http/server/_base.py b/tests/http/server/_base.py
new file mode 100644
index 00000000..b9f1a381
--- /dev/null
+++ b/tests/http/server/_base.py
@@ -0,0 +1,100 @@
+# Copyright 2022 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unles4s required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from http import HTTPStatus
+from typing import Any, Callable, Optional, Union
+from unittest import mock
+
+from twisted.internet.error import ConnectionDone
+
+from synapse.http.server import (
+ HTTP_STATUS_REQUEST_CANCELLED,
+ respond_with_html_bytes,
+ respond_with_json,
+)
+from synapse.types import JsonDict
+
+from tests import unittest
+from tests.server import FakeChannel, ThreadedMemoryReactorClock
+
+
+class EndpointCancellationTestHelperMixin(unittest.TestCase):
+ """Provides helper methods for testing cancellation of endpoints."""
+
+ def _test_disconnect(
+ self,
+ reactor: ThreadedMemoryReactorClock,
+ channel: FakeChannel,
+ expect_cancellation: bool,
+ expected_body: Union[bytes, JsonDict],
+ expected_code: Optional[int] = None,
+ ) -> None:
+ """Disconnects an in-flight request and checks the response.
+
+ Args:
+ reactor: The twisted reactor running the request handler.
+ channel: The `FakeChannel` for the request.
+ expect_cancellation: `True` if request processing is expected to be
+ cancelled, `False` if the request should run to completion.
+ expected_body: The expected response for the request.
+ expected_code: The expected status code for the request. Defaults to `200`
+ or `499` depending on `expect_cancellation`.
+ """
+ # Determine the expected status code.
+ if expected_code is None:
+ if expect_cancellation:
+ expected_code = HTTP_STATUS_REQUEST_CANCELLED
+ else:
+ expected_code = HTTPStatus.OK
+
+ request = channel.request
+ self.assertFalse(
+ channel.is_finished(),
+ "Request finished before we could disconnect - "
+ "was `await_result=False` passed to `make_request`?",
+ )
+
+ # We're about to disconnect the request. This also disconnects the channel, so
+ # we have to rely on mocks to extract the response.
+ respond_method: Callable[..., Any]
+ if isinstance(expected_body, bytes):
+ respond_method = respond_with_html_bytes
+ else:
+ respond_method = respond_with_json
+
+ with mock.patch(
+ f"synapse.http.server.{respond_method.__name__}", wraps=respond_method
+ ) as respond_mock:
+ # Disconnect the request.
+ request.connectionLost(reason=ConnectionDone())
+
+ if expect_cancellation:
+ # An immediate cancellation is expected.
+ respond_mock.assert_called_once()
+ args, _kwargs = respond_mock.call_args
+ code, body = args[1], args[2]
+ self.assertEqual(code, expected_code)
+ self.assertEqual(request.code, expected_code)
+ self.assertEqual(body, expected_body)
+ else:
+ respond_mock.assert_not_called()
+
+ # The handler is expected to run to completion.
+ reactor.pump([1.0])
+ respond_mock.assert_called_once()
+ args, _kwargs = respond_mock.call_args
+ code, body = args[1], args[2]
+ self.assertEqual(code, expected_code)
+ self.assertEqual(request.code, expected_code)
+ self.assertEqual(body, expected_body)
diff --git a/tests/http/test_fedclient.py b/tests/http/test_fedclient.py
index 638babae..006dbab0 100644
--- a/tests/http/test_fedclient.py
+++ b/tests/http/test_fedclient.py
@@ -26,7 +26,7 @@ from twisted.web.http import HTTPChannel
from synapse.api.errors import RequestSendFailed
from synapse.http.matrixfederationclient import (
- MAX_RESPONSE_SIZE,
+ JsonParser,
MatrixFederationHttpClient,
MatrixFederationRequest,
)
@@ -609,9 +609,9 @@ class FederationClientTests(HomeserverTestCase):
while not test_d.called:
protocol.dataReceived(b"a" * chunk_size)
sent += chunk_size
- self.assertLessEqual(sent, MAX_RESPONSE_SIZE)
+ self.assertLessEqual(sent, JsonParser.MAX_RESPONSE_SIZE)
- self.assertEqual(sent, MAX_RESPONSE_SIZE)
+ self.assertEqual(sent, JsonParser.MAX_RESPONSE_SIZE)
f = self.failureResultOf(test_d)
self.assertIsInstance(f.value, RequestSendFailed)
diff --git a/tests/http/test_servlet.py b/tests/http/test_servlet.py
index a80bfb9f..b3655d7b 100644
--- a/tests/http/test_servlet.py
+++ b/tests/http/test_servlet.py
@@ -12,16 +12,25 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import json
+from http import HTTPStatus
from io import BytesIO
+from typing import Tuple
from unittest.mock import Mock
-from synapse.api.errors import SynapseError
+from synapse.api.errors import Codes, SynapseError
+from synapse.http.server import cancellable
from synapse.http.servlet import (
+ RestServlet,
parse_json_object_from_request,
parse_json_value_from_request,
)
+from synapse.http.site import SynapseRequest
+from synapse.rest.client._base import client_patterns
+from synapse.server import HomeServer
+from synapse.types import JsonDict
from tests import unittest
+from tests.http.server._base import EndpointCancellationTestHelperMixin
def make_request(content):
@@ -40,19 +49,21 @@ class TestServletUtils(unittest.TestCase):
"""Basic tests for parse_json_value_from_request."""
# Test round-tripping.
obj = {"foo": 1}
- result = parse_json_value_from_request(make_request(obj))
- self.assertEqual(result, obj)
+ result1 = parse_json_value_from_request(make_request(obj))
+ self.assertEqual(result1, obj)
# Results don't have to be objects.
- result = parse_json_value_from_request(make_request(b'["foo"]'))
- self.assertEqual(result, ["foo"])
+ result2 = parse_json_value_from_request(make_request(b'["foo"]'))
+ self.assertEqual(result2, ["foo"])
# Test empty.
with self.assertRaises(SynapseError):
parse_json_value_from_request(make_request(b""))
- result = parse_json_value_from_request(make_request(b""), allow_empty_body=True)
- self.assertIsNone(result)
+ result3 = parse_json_value_from_request(
+ make_request(b""), allow_empty_body=True
+ )
+ self.assertIsNone(result3)
# Invalid UTF-8.
with self.assertRaises(SynapseError):
@@ -76,3 +87,52 @@ class TestServletUtils(unittest.TestCase):
# Test not an object
with self.assertRaises(SynapseError):
parse_json_object_from_request(make_request(b'["foo"]'))
+
+
+class CancellableRestServlet(RestServlet):
+ """A `RestServlet` with a mix of cancellable and uncancellable handlers."""
+
+ PATTERNS = client_patterns("/sleep$")
+
+ def __init__(self, hs: HomeServer):
+ super().__init__()
+ self.clock = hs.get_clock()
+
+ @cancellable
+ async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
+ await self.clock.sleep(1.0)
+ return HTTPStatus.OK, {"result": True}
+
+ async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
+ await self.clock.sleep(1.0)
+ return HTTPStatus.OK, {"result": True}
+
+
+class TestRestServletCancellation(
+ unittest.HomeserverTestCase, EndpointCancellationTestHelperMixin
+):
+ """Tests for `RestServlet` cancellation."""
+
+ servlets = [
+ lambda hs, http_server: CancellableRestServlet(hs).register(http_server)
+ ]
+
+ def test_cancellable_disconnect(self) -> None:
+ """Test that handlers with the `@cancellable` flag can be cancelled."""
+ channel = self.make_request("GET", "/sleep", await_result=False)
+ self._test_disconnect(
+ self.reactor,
+ channel,
+ expect_cancellation=True,
+ expected_body={"error": "Request cancelled", "errcode": Codes.UNKNOWN},
+ )
+
+ def test_uncancellable_disconnect(self) -> None:
+ """Test that handlers without the `@cancellable` flag cannot be cancelled."""
+ channel = self.make_request("POST", "/sleep", await_result=False)
+ self._test_disconnect(
+ self.reactor,
+ channel,
+ expect_cancellation=False,
+ expected_body={"result": True},
+ )
diff --git a/tests/http/test_site.py b/tests/http/test_site.py
index 8c13b4f6..b2dbf76d 100644
--- a/tests/http/test_site.py
+++ b/tests/http/test_site.py
@@ -36,7 +36,7 @@ class SynapseRequestTestCase(HomeserverTestCase):
# as a control case, first send a regular request.
# complete the connection and wire it up to a fake transport
- client_address = IPv6Address("TCP", "::1", "2345")
+ client_address = IPv6Address("TCP", "::1", 2345)
protocol = factory.buildProtocol(client_address)
transport = StringTransport()
protocol.makeConnection(transport)
diff --git a/tests/module_api/test_api.py b/tests/module_api/test_api.py
index 8bc84aaa..169e29b5 100644
--- a/tests/module_api/test_api.py
+++ b/tests/module_api/test_api.py
@@ -399,7 +399,7 @@ class ModuleApiTestCase(HomeserverTestCase):
for edu in edus:
# Make sure we're only checking presence-type EDUs
- if edu["edu_type"] != EduTypes.Presence:
+ if edu["edu_type"] != EduTypes.PRESENCE:
continue
# EDUs can contain multiple presence updates
diff --git a/tests/push/test_push_rule_evaluator.py b/tests/push/test_push_rule_evaluator.py
index 5dba1870..9b623d00 100644
--- a/tests/push/test_push_rule_evaluator.py
+++ b/tests/push/test_push_rule_evaluator.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Dict, Optional, Union
+from typing import Dict, Optional, Set, Tuple, Union
import frozendict
@@ -26,7 +26,12 @@ from tests import unittest
class PushRuleEvaluatorTestCase(unittest.TestCase):
- def _get_evaluator(self, content: JsonDict) -> PushRuleEvaluatorForEvent:
+ def _get_evaluator(
+ self,
+ content: JsonDict,
+ relations: Optional[Dict[str, Set[Tuple[str, str]]]] = None,
+ relations_match_enabled: bool = False,
+ ) -> PushRuleEvaluatorForEvent:
event = FrozenEvent(
{
"event_id": "$event_id",
@@ -42,7 +47,12 @@ class PushRuleEvaluatorTestCase(unittest.TestCase):
sender_power_level = 0
power_levels: Dict[str, Union[int, Dict[str, int]]] = {}
return PushRuleEvaluatorForEvent(
- event, room_member_count, sender_power_level, power_levels
+ event,
+ room_member_count,
+ sender_power_level,
+ power_levels,
+ relations or set(),
+ relations_match_enabled,
)
def test_display_name(self) -> None:
@@ -276,3 +286,71 @@ class PushRuleEvaluatorTestCase(unittest.TestCase):
push_rule_evaluator.tweaks_for_actions(actions),
{"sound": "default", "highlight": True},
)
+
+ def test_relation_match(self) -> None:
+ """Test the relation_match push rule kind."""
+
+ # Check if the experimental feature is disabled.
+ evaluator = self._get_evaluator(
+ {}, {"m.annotation": {("@user:test", "m.reaction")}}
+ )
+ condition = {"kind": "relation_match"}
+ # Oddly, an unknown condition always matches.
+ self.assertTrue(evaluator.matches(condition, "@user:test", "foo"))
+
+ # A push rule evaluator with the experimental rule enabled.
+ evaluator = self._get_evaluator(
+ {}, {"m.annotation": {("@user:test", "m.reaction")}}, True
+ )
+
+ # Check just relation type.
+ condition = {
+ "kind": "org.matrix.msc3772.relation_match",
+ "rel_type": "m.annotation",
+ }
+ self.assertTrue(evaluator.matches(condition, "@user:test", "foo"))
+
+ # Check relation type and sender.
+ condition = {
+ "kind": "org.matrix.msc3772.relation_match",
+ "rel_type": "m.annotation",
+ "sender": "@user:test",
+ }
+ self.assertTrue(evaluator.matches(condition, "@user:test", "foo"))
+ condition = {
+ "kind": "org.matrix.msc3772.relation_match",
+ "rel_type": "m.annotation",
+ "sender": "@other:test",
+ }
+ self.assertFalse(evaluator.matches(condition, "@user:test", "foo"))
+
+ # Check relation type and event type.
+ condition = {
+ "kind": "org.matrix.msc3772.relation_match",
+ "rel_type": "m.annotation",
+ "type": "m.reaction",
+ }
+ self.assertTrue(evaluator.matches(condition, "@user:test", "foo"))
+
+ # Check just sender, this fails since rel_type is required.
+ condition = {
+ "kind": "org.matrix.msc3772.relation_match",
+ "sender": "@user:test",
+ }
+ self.assertFalse(evaluator.matches(condition, "@user:test", "foo"))
+
+ # Check sender glob.
+ condition = {
+ "kind": "org.matrix.msc3772.relation_match",
+ "rel_type": "m.annotation",
+ "sender": "@*:test",
+ }
+ self.assertTrue(evaluator.matches(condition, "@user:test", "foo"))
+
+ # Check event type glob.
+ condition = {
+ "kind": "org.matrix.msc3772.relation_match",
+ "rel_type": "m.annotation",
+ "event_type": "*.reaction",
+ }
+ self.assertTrue(evaluator.matches(condition, "@user:test", "foo"))
diff --git a/tests/replication/_base.py b/tests/replication/_base.py
index a7602b4c..970d5e53 100644
--- a/tests/replication/_base.py
+++ b/tests/replication/_base.py
@@ -12,7 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
-from typing import Any, Dict, List, Optional, Tuple
+from collections import defaultdict
+from typing import Any, Dict, List, Optional, Set, Tuple
from twisted.internet.address import IPv4Address
from twisted.internet.protocol import Protocol
@@ -32,6 +33,7 @@ from synapse.server import HomeServer
from tests import unittest
from tests.server import FakeTransport
+from tests.utils import USE_POSTGRES_FOR_TESTS
try:
import hiredis
@@ -475,22 +477,25 @@ class FakeRedisPubSubServer:
"""A fake Redis server for pub/sub."""
def __init__(self):
- self._subscribers = set()
+ self._subscribers_by_channel: Dict[
+ bytes, Set["FakeRedisPubSubProtocol"]
+ ] = defaultdict(set)
- def add_subscriber(self, conn):
+ def add_subscriber(self, conn, channel: bytes):
"""A connection has called SUBSCRIBE"""
- self._subscribers.add(conn)
+ self._subscribers_by_channel[channel].add(conn)
def remove_subscriber(self, conn):
- """A connection has called UNSUBSCRIBE"""
- self._subscribers.discard(conn)
+ """A connection has lost connection"""
+ for subscribers in self._subscribers_by_channel.values():
+ subscribers.discard(conn)
- def publish(self, conn, channel, msg) -> int:
+ def publish(self, conn, channel: bytes, msg) -> int:
"""A connection want to publish a message to subscribers."""
- for sub in self._subscribers:
+ for sub in self._subscribers_by_channel[channel]:
sub.send(["message", channel, msg])
- return len(self._subscribers)
+ return len(self._subscribers_by_channel)
def buildProtocol(self, addr):
return FakeRedisPubSubProtocol(self)
@@ -531,9 +536,10 @@ class FakeRedisPubSubProtocol(Protocol):
num_subscribers = self._server.publish(self, channel, message)
self.send(num_subscribers)
elif command == b"SUBSCRIBE":
- (channel,) = args
- self._server.add_subscriber(self)
- self.send(["subscribe", channel, 1])
+ for idx, channel in enumerate(args):
+ num_channels = idx + 1
+ self._server.add_subscriber(self, channel)
+ self.send(["subscribe", channel, num_channels])
# Since we use SET/GET to cache things we can safely no-op them.
elif command == b"SET":
@@ -576,3 +582,27 @@ class FakeRedisPubSubProtocol(Protocol):
def connectionLost(self, reason):
self._server.remove_subscriber(self)
+
+
+class RedisMultiWorkerStreamTestCase(BaseMultiWorkerStreamTestCase):
+ """
+ A test case that enables Redis, providing a fake Redis server.
+ """
+
+ if not hiredis:
+ skip = "Requires hiredis"
+
+ if not USE_POSTGRES_FOR_TESTS:
+ # Redis replication only takes place on Postgres
+ skip = "Requires Postgres"
+
+ def default_config(self) -> Dict[str, Any]:
+ """
+ Overrides the default config to enable Redis.
+ Even if the test only uses make_worker_hs, the main process needs Redis
+ enabled otherwise it won't create a Fake Redis server to listen on the
+ Redis port and accept fake TCP connections.
+ """
+ base = super().default_config()
+ base["redis"] = {"enabled": True}
+ return base
diff --git a/tests/replication/http/__init__.py b/tests/replication/http/__init__.py
new file mode 100644
index 00000000..3a5f22c0
--- /dev/null
+++ b/tests/replication/http/__init__.py
@@ -0,0 +1,13 @@
+# Copyright 2022 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/tests/replication/http/test__base.py b/tests/replication/http/test__base.py
new file mode 100644
index 00000000..a5ab093a
--- /dev/null
+++ b/tests/replication/http/test__base.py
@@ -0,0 +1,106 @@
+# Copyright 2022 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from http import HTTPStatus
+from typing import Tuple
+
+from twisted.web.server import Request
+
+from synapse.api.errors import Codes
+from synapse.http.server import JsonResource, cancellable
+from synapse.replication.http import REPLICATION_PREFIX
+from synapse.replication.http._base import ReplicationEndpoint
+from synapse.server import HomeServer
+from synapse.types import JsonDict
+
+from tests import unittest
+from tests.http.server._base import EndpointCancellationTestHelperMixin
+
+
+class CancellableReplicationEndpoint(ReplicationEndpoint):
+ NAME = "cancellable_sleep"
+ PATH_ARGS = ()
+ CACHE = False
+
+ def __init__(self, hs: HomeServer):
+ super().__init__(hs)
+ self.clock = hs.get_clock()
+
+ @staticmethod
+ async def _serialize_payload() -> JsonDict:
+ return {}
+
+ @cancellable
+ async def _handle_request( # type: ignore[override]
+ self, request: Request
+ ) -> Tuple[int, JsonDict]:
+ await self.clock.sleep(1.0)
+ return HTTPStatus.OK, {"result": True}
+
+
+class UncancellableReplicationEndpoint(ReplicationEndpoint):
+ NAME = "uncancellable_sleep"
+ PATH_ARGS = ()
+ CACHE = False
+
+ def __init__(self, hs: HomeServer):
+ super().__init__(hs)
+ self.clock = hs.get_clock()
+
+ @staticmethod
+ async def _serialize_payload() -> JsonDict:
+ return {}
+
+ async def _handle_request( # type: ignore[override]
+ self, request: Request
+ ) -> Tuple[int, JsonDict]:
+ await self.clock.sleep(1.0)
+ return HTTPStatus.OK, {"result": True}
+
+
+class ReplicationEndpointCancellationTestCase(
+ unittest.HomeserverTestCase, EndpointCancellationTestHelperMixin
+):
+ """Tests for `ReplicationEndpoint` cancellation."""
+
+ def create_test_resource(self):
+ """Overrides `HomeserverTestCase.create_test_resource`."""
+ resource = JsonResource(self.hs)
+
+ CancellableReplicationEndpoint(self.hs).register(resource)
+ UncancellableReplicationEndpoint(self.hs).register(resource)
+
+ return resource
+
+ def test_cancellable_disconnect(self) -> None:
+ """Test that handlers with the `@cancellable` flag can be cancelled."""
+ path = f"{REPLICATION_PREFIX}/{CancellableReplicationEndpoint.NAME}/"
+ channel = self.make_request("POST", path, await_result=False)
+ self._test_disconnect(
+ self.reactor,
+ channel,
+ expect_cancellation=True,
+ expected_body={"error": "Request cancelled", "errcode": Codes.UNKNOWN},
+ )
+
+ def test_uncancellable_disconnect(self) -> None:
+ """Test that handlers without the `@cancellable` flag cannot be cancelled."""
+ path = f"{REPLICATION_PREFIX}/{UncancellableReplicationEndpoint.NAME}/"
+ channel = self.make_request("POST", path, await_result=False)
+ self._test_disconnect(
+ self.reactor,
+ channel,
+ expect_cancellation=False,
+ expected_body={"result": True},
+ )
diff --git a/tests/replication/slave/storage/_base.py b/tests/replication/slave/storage/_base.py
index 85be79d1..c5705256 100644
--- a/tests/replication/slave/storage/_base.py
+++ b/tests/replication/slave/storage/_base.py
@@ -32,7 +32,7 @@ class BaseSlavedStoreTestCase(BaseStreamTestCase):
self.master_store = hs.get_datastores().main
self.slaved_store = self.worker_hs.get_datastores().main
- self.storage = hs.get_storage()
+ self._storage_controllers = hs.get_storage_controllers()
def replicate(self):
"""Tell the master side of replication that something has happened, and then
diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py
index 297a9e77..6d3d4afe 100644
--- a/tests/replication/slave/storage/test_events.py
+++ b/tests/replication/slave/storage/test_events.py
@@ -262,7 +262,9 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
)
msg, msgctx = self.build_event()
self.get_success(
- self.storage.persistence.persist_events([(j2, j2ctx), (msg, msgctx)])
+ self._storage_controllers.persistence.persist_events(
+ [(j2, j2ctx), (msg, msgctx)]
+ )
)
self.replicate()
@@ -323,12 +325,14 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
if backfill:
self.get_success(
- self.storage.persistence.persist_events(
+ self._storage_controllers.persistence.persist_events(
[(event, context)], backfilled=True
)
)
else:
- self.get_success(self.storage.persistence.persist_event(event, context))
+ self.get_success(
+ self._storage_controllers.persistence.persist_event(event, context)
+ )
return event
diff --git a/tests/replication/slave/storage/test_receipts.py b/tests/replication/slave/storage/test_receipts.py
index 5bbbd5fb..19f57115 100644
--- a/tests/replication/slave/storage/test_receipts.py
+++ b/tests/replication/slave/storage/test_receipts.py
@@ -31,7 +31,9 @@ class SlavedReceiptTestCase(BaseSlavedStoreTestCase):
def prepare(self, reactor, clock, homeserver):
super().prepare(reactor, clock, homeserver)
self.room_creator = homeserver.get_room_creation_handler()
- self.persist_event_storage = self.hs.get_storage().persistence
+ self.persist_event_storage_controller = (
+ self.hs.get_storage_controllers().persistence
+ )
# Create a test user
self.ourUser = UserID.from_string(OUR_USER_ID)
@@ -61,7 +63,9 @@ class SlavedReceiptTestCase(BaseSlavedStoreTestCase):
)
)
self.get_success(
- self.persist_event_storage.persist_event(memberEvent, memberEventContext)
+ self.persist_event_storage_controller.persist_event(
+ memberEvent, memberEventContext
+ )
)
# Join the second user to the second room
@@ -76,7 +80,9 @@ class SlavedReceiptTestCase(BaseSlavedStoreTestCase):
)
)
self.get_success(
- self.persist_event_storage.persist_event(memberEvent, memberEventContext)
+ self.persist_event_storage_controller.persist_event(
+ memberEvent, memberEventContext
+ )
)
def test_return_empty_with_no_data(self):
diff --git a/tests/replication/tcp/test_handler.py b/tests/replication/tcp/test_handler.py
new file mode 100644
index 00000000..e6a19eaf
--- /dev/null
+++ b/tests/replication/tcp/test_handler.py
@@ -0,0 +1,73 @@
+# Copyright 2022 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from tests.replication._base import RedisMultiWorkerStreamTestCase
+
+
+class ChannelsTestCase(RedisMultiWorkerStreamTestCase):
+ def test_subscribed_to_enough_redis_channels(self) -> None:
+ # The default main process is subscribed to the USER_IP channel.
+ self.assertCountEqual(
+ self.hs.get_replication_command_handler()._channels_to_subscribe_to,
+ ["USER_IP"],
+ )
+
+ def test_background_worker_subscribed_to_user_ip(self) -> None:
+ # The default main process is subscribed to the USER_IP channel.
+ worker1 = self.make_worker_hs(
+ "synapse.app.generic_worker",
+ extra_config={
+ "worker_name": "worker1",
+ "run_background_tasks_on": "worker1",
+ "redis": {"enabled": True},
+ },
+ )
+ self.assertIn(
+ "USER_IP",
+ worker1.get_replication_command_handler()._channels_to_subscribe_to,
+ )
+
+ # Advance so the Redis subscription gets processed
+ self.pump(0.1)
+
+ # The counts are 2 because both the main process and the worker are subscribed.
+ self.assertEqual(len(self._redis_server._subscribers_by_channel[b"test"]), 2)
+ self.assertEqual(
+ len(self._redis_server._subscribers_by_channel[b"test/USER_IP"]), 2
+ )
+
+ def test_non_background_worker_not_subscribed_to_user_ip(self) -> None:
+ # The default main process is subscribed to the USER_IP channel.
+ worker2 = self.make_worker_hs(
+ "synapse.app.generic_worker",
+ extra_config={
+ "worker_name": "worker2",
+ "run_background_tasks_on": "worker1",
+ "redis": {"enabled": True},
+ },
+ )
+ self.assertNotIn(
+ "USER_IP",
+ worker2.get_replication_command_handler()._channels_to_subscribe_to,
+ )
+
+ # Advance so the Redis subscription gets processed
+ self.pump(0.1)
+
+ # The count is 2 because both the main process and the worker are subscribed.
+ self.assertEqual(len(self._redis_server._subscribers_by_channel[b"test"]), 2)
+ # For USER_IP, the count is 1 because only the main process is subscribed.
+ self.assertEqual(
+ len(self._redis_server._subscribers_by_channel[b"test/USER_IP"]), 1
+ )
diff --git a/tests/replication/test_sharded_event_persister.py b/tests/replication/test_sharded_event_persister.py
index 5f142e84..a7ca6806 100644
--- a/tests/replication/test_sharded_event_persister.py
+++ b/tests/replication/test_sharded_event_persister.py
@@ -14,7 +14,6 @@
import logging
from unittest.mock import patch
-from synapse.api.room_versions import RoomVersion
from synapse.rest import admin
from synapse.rest.client import login, room, sync
from synapse.storage.util.id_generators import MultiWriterIdGenerator
@@ -64,21 +63,10 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase):
# We control the room ID generation by patching out the
# `_generate_room_id` method
- async def generate_room(
- creator_id: str, is_public: bool, room_version: RoomVersion
- ):
- await self.store.store_room(
- room_id=room_id,
- room_creator_user_id=creator_id,
- is_public=is_public,
- room_version=room_version,
- )
- return room_id
-
with patch(
"synapse.handlers.room.RoomCreationHandler._generate_room_id"
) as mock:
- mock.side_effect = generate_room
+ mock.side_effect = lambda: room_id
self.helper.create_room_as(user_id, tok=tok)
def test_basic(self):
diff --git a/tests/rest/admin/test_admin.py b/tests/rest/admin/test_admin.py
index 40571b75..82ac5991 100644
--- a/tests/rest/admin/test_admin.py
+++ b/tests/rest/admin/test_admin.py
@@ -14,7 +14,6 @@
import urllib.parse
from http import HTTPStatus
-from typing import List
from parameterized import parameterized
@@ -23,7 +22,7 @@ from twisted.test.proto_helpers import MemoryReactor
import synapse.rest.admin
from synapse.http.server import JsonResource
from synapse.rest.admin import VersionServlet
-from synapse.rest.client import groups, login, room
+from synapse.rest.client import login, room
from synapse.server import HomeServer
from synapse.util import Clock
@@ -49,93 +48,6 @@ class VersionTestCase(unittest.HomeserverTestCase):
)
-class DeleteGroupTestCase(unittest.HomeserverTestCase):
- servlets = [
- synapse.rest.admin.register_servlets_for_client_rest_resource,
- login.register_servlets,
- groups.register_servlets,
- ]
-
- def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
- self.admin_user = self.register_user("admin", "pass", admin=True)
- self.admin_user_tok = self.login("admin", "pass")
-
- self.other_user = self.register_user("user", "pass")
- self.other_user_token = self.login("user", "pass")
-
- @unittest.override_config({"experimental_features": {"groups_enabled": True}})
- def test_delete_group(self) -> None:
- # Create a new group
- channel = self.make_request(
- "POST",
- b"/create_group",
- access_token=self.admin_user_tok,
- content={"localpart": "test"},
- )
-
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
-
- group_id = channel.json_body["group_id"]
-
- self._check_group(group_id, expect_code=HTTPStatus.OK)
-
- # Invite/join another user
-
- url = "/groups/%s/admin/users/invite/%s" % (group_id, self.other_user)
- channel = self.make_request(
- "PUT", url.encode("ascii"), access_token=self.admin_user_tok, content={}
- )
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
-
- url = "/groups/%s/self/accept_invite" % (group_id,)
- channel = self.make_request(
- "PUT", url.encode("ascii"), access_token=self.other_user_token, content={}
- )
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
-
- # Check other user knows they're in the group
- self.assertIn(group_id, self._get_groups_user_is_in(self.admin_user_tok))
- self.assertIn(group_id, self._get_groups_user_is_in(self.other_user_token))
-
- # Now delete the group
- url = "/_synapse/admin/v1/delete_group/" + group_id
- channel = self.make_request(
- "POST",
- url.encode("ascii"),
- access_token=self.admin_user_tok,
- content={"localpart": "test"},
- )
-
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
-
- # Check group returns HTTPStatus.NOT_FOUND
- self._check_group(group_id, expect_code=HTTPStatus.NOT_FOUND)
-
- # Check users don't think they're in the group
- self.assertNotIn(group_id, self._get_groups_user_is_in(self.admin_user_tok))
- self.assertNotIn(group_id, self._get_groups_user_is_in(self.other_user_token))
-
- def _check_group(self, group_id: str, expect_code: int) -> None:
- """Assert that trying to fetch the given group results in the given
- HTTP status code
- """
-
- url = "/groups/%s/profile" % (group_id,)
- channel = self.make_request(
- "GET", url.encode("ascii"), access_token=self.admin_user_tok
- )
-
- self.assertEqual(expect_code, channel.code, msg=channel.json_body)
-
- def _get_groups_user_is_in(self, access_token: str) -> List[str]:
- """Returns the list of groups the user is in (given their access token)"""
- channel = self.make_request("GET", b"/joined_groups", access_token=access_token)
-
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
-
- return channel.json_body["groups"]
-
-
class QuarantineMediaTestCase(unittest.HomeserverTestCase):
"""Test /quarantine_media admin API."""
diff --git a/tests/rest/admin/test_room.py b/tests/rest/admin/test_room.py
index 95282f07..ca6af941 100644
--- a/tests/rest/admin/test_room.py
+++ b/tests/rest/admin/test_room.py
@@ -2467,7 +2467,6 @@ PURGE_TABLES = [
"event_push_actions",
"event_search",
"events",
- "group_rooms",
"receipts_graph",
"receipts_linearized",
"room_aliases",
@@ -2484,9 +2483,9 @@ PURGE_TABLES = [
"e2e_room_keys",
"event_push_summary",
"pusher_throttle",
- "group_summary_rooms",
"room_account_data",
"room_tags",
# "state_groups", # Current impl leaves orphaned state groups around.
"state_groups_state",
+ "federation_inbound_events_staging",
]
diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py
index 0cdf1dec..0d441022 100644
--- a/tests/rest/admin/test_user.py
+++ b/tests/rest/admin/test_user.py
@@ -2579,7 +2579,7 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase):
other_user_tok = self.login("user", "pass")
event_builder_factory = self.hs.get_event_builder_factory()
event_creation_handler = self.hs.get_event_creation_handler()
- storage = self.hs.get_storage()
+ storage_controllers = self.hs.get_storage_controllers()
# Create two rooms, one with a local user only and one with both a local
# and remote user.
@@ -2604,7 +2604,7 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase):
event_creation_handler.create_new_client_event(builder)
)
- self.get_success(storage.persistence.persist_event(event, context))
+ self.get_success(storage_controllers.persistence.persist_event(event, context))
# Now get rooms
url = "/_synapse/admin/v1/users/@joiner:remote_hs/joined_rooms"
diff --git a/tests/rest/client/test_account.py b/tests/rest/client/test_account.py
index e0a11da9..a43a1372 100644
--- a/tests/rest/client/test_account.py
+++ b/tests/rest/client/test_account.py
@@ -548,7 +548,6 @@ class WhoamiTestCase(unittest.HomeserverTestCase):
appservice = ApplicationService(
as_token,
- self.hs.config.server.server_name,
id="1234",
namespaces={"users": [{"regex": user_id, "exclusive": True}]},
sender=user_id,
diff --git a/tests/rest/client/test_auth.py b/tests/rest/client/test_auth.py
index 9653f458..05355c7f 100644
--- a/tests/rest/client/test_auth.py
+++ b/tests/rest/client/test_auth.py
@@ -195,8 +195,17 @@ class UIAuthTests(unittest.HomeserverTestCase):
self.user_pass = "pass"
self.user = self.register_user("test", self.user_pass)
self.device_id = "dev1"
+
+ # Force-enable password login for just long enough to log in.
+ auth_handler = self.hs.get_auth_handler()
+ allow_auth_for_login = auth_handler._password_enabled_for_login
+ auth_handler._password_enabled_for_login = True
+
self.user_tok = self.login("test", self.user_pass, self.device_id)
+ # Restore password login to however it was.
+ auth_handler._password_enabled_for_login = allow_auth_for_login
+
def delete_device(
self,
access_token: str,
@@ -263,6 +272,38 @@ class UIAuthTests(unittest.HomeserverTestCase):
},
)
+ @override_config({"password_config": {"enabled": "only_for_reauth"}})
+ def test_ui_auth_with_passwords_for_reauth_only(self) -> None:
+ """
+ Test user interactive authentication outside of registration.
+ """
+
+ # Attempt to delete this device.
+ # Returns a 401 as per the spec
+ channel = self.delete_device(
+ self.user_tok, self.device_id, HTTPStatus.UNAUTHORIZED
+ )
+
+ # Grab the session
+ session = channel.json_body["session"]
+ # Ensure that flows are what is expected.
+ self.assertIn({"stages": ["m.login.password"]}, channel.json_body["flows"])
+
+ # Make another request providing the UI auth flow.
+ self.delete_device(
+ self.user_tok,
+ self.device_id,
+ HTTPStatus.OK,
+ {
+ "auth": {
+ "type": "m.login.password",
+ "identifier": {"type": "m.id.user", "user": self.user},
+ "password": self.user_pass,
+ "session": session,
+ },
+ },
+ )
+
def test_grandfathered_identifier(self) -> None:
"""Check behaviour without "identifier" dict
diff --git a/tests/rest/client/test_device_lists.py b/tests/rest/client/test_devices.py
index a8af4e24..aa982224 100644
--- a/tests/rest/client/test_device_lists.py
+++ b/tests/rest/client/test_devices.py
@@ -13,8 +13,13 @@
# limitations under the License.
from http import HTTPStatus
+from twisted.test.proto_helpers import MemoryReactor
+
+from synapse.api.errors import NotFoundError
from synapse.rest import admin, devices, room, sync
from synapse.rest.client import account, login, register
+from synapse.server import HomeServer
+from synapse.util import Clock
from tests import unittest
@@ -157,3 +162,41 @@ class DeviceListsTestCase(unittest.HomeserverTestCase):
self.assertNotIn(
alice_user_id, changed_device_lists, bob_sync_channel.json_body
)
+
+
+class DevicesTestCase(unittest.HomeserverTestCase):
+ servlets = [
+ admin.register_servlets,
+ login.register_servlets,
+ sync.register_servlets,
+ ]
+
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+ self.handler = hs.get_device_handler()
+
+ @unittest.override_config({"delete_stale_devices_after": 72000000})
+ def test_delete_stale_devices(self) -> None:
+ """Tests that stale devices are automatically removed after a set time of
+ inactivity.
+ The configuration is set to delete devices that haven't been used in the past 20h.
+ """
+ # Register a user and creates 2 devices for them.
+ user_id = self.register_user("user", "password")
+ tok1 = self.login("user", "password", device_id="abc")
+ tok2 = self.login("user", "password", device_id="def")
+
+ # Sync them so they have a last_seen value.
+ self.make_request("GET", "/sync", access_token=tok1)
+ self.make_request("GET", "/sync", access_token=tok2)
+
+ # Advance half a day and sync again with one of the devices, so that the next
+ # time the background job runs we don't delete this device (since it will look
+ # for devices that haven't been used for over an hour).
+ self.reactor.advance(43200)
+ self.make_request("GET", "/sync", access_token=tok1)
+
+ # Advance another half a day, and check that the device that has synced still
+ # exists but the one that hasn't has been removed.
+ self.reactor.advance(43200)
+ self.get_success(self.handler.get_device(user_id, "abc"))
+ self.get_failure(self.handler.get_device(user_id, "def"), NotFoundError)
diff --git a/tests/rest/client/test_events.py b/tests/rest/client/test_events.py
index 1b1392fa..a9b7db9d 100644
--- a/tests/rest/client/test_events.py
+++ b/tests/rest/client/test_events.py
@@ -19,6 +19,7 @@ from unittest.mock import Mock
from twisted.test.proto_helpers import MemoryReactor
import synapse.rest.admin
+from synapse.api.constants import EduTypes
from synapse.rest.client import events, login, room
from synapse.server import HomeServer
from synapse.util import Clock
@@ -103,7 +104,7 @@ class EventStreamPermissionsTestCase(unittest.HomeserverTestCase):
c
for c in channel.json_body["chunk"]
if not (
- c.get("type") == "m.presence"
+ c.get("type") == EduTypes.PRESENCE
and c["content"].get("user_id") == self.user_id
)
]
diff --git a/tests/rest/client/test_groups.py b/tests/rest/client/test_groups.py
deleted file mode 100644
index e067cf82..00000000
--- a/tests/rest/client/test_groups.py
+++ /dev/null
@@ -1,56 +0,0 @@
-# Copyright 2021 The Matrix.org Foundation C.I.C.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from synapse.rest.client import groups, room
-
-from tests import unittest
-from tests.unittest import override_config
-
-
-class GroupsTestCase(unittest.HomeserverTestCase):
- user_id = "@alice:test"
- room_creator_user_id = "@bob:test"
-
- servlets = [room.register_servlets, groups.register_servlets]
-
- @override_config({"enable_group_creation": True})
- def test_rooms_limited_by_visibility(self) -> None:
- group_id = "+spqr:test"
-
- # Alice creates a group
- channel = self.make_request("POST", "/create_group", {"localpart": "spqr"})
- self.assertEqual(channel.code, 200, msg=channel.text_body)
- self.assertEqual(channel.json_body, {"group_id": group_id})
-
- # Bob creates a private room
- room_id = self.helper.create_room_as(self.room_creator_user_id, is_public=False)
- self.helper.auth_user_id = self.room_creator_user_id
- self.helper.send_state(
- room_id, "m.room.name", {"name": "bob's secret room"}, tok=None
- )
- self.helper.auth_user_id = self.user_id
-
- # Alice adds the room to her group.
- channel = self.make_request(
- "PUT", f"/groups/{group_id}/admin/rooms/{room_id}", {}
- )
- self.assertEqual(channel.code, 200, msg=channel.text_body)
- self.assertEqual(channel.json_body, {})
-
- # Alice now tries to retrieve the room list of the space.
- channel = self.make_request("GET", f"/groups/{group_id}/rooms")
- self.assertEqual(channel.code, 200, msg=channel.text_body)
- self.assertEqual(
- channel.json_body, {"chunk": [], "total_room_count_estimate": 0}
- )
diff --git a/tests/rest/client/test_login.py b/tests/rest/client/test_login.py
index 4920468f..f4ea1209 100644
--- a/tests/rest/client/test_login.py
+++ b/tests/rest/client/test_login.py
@@ -1112,7 +1112,6 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase):
self.service = ApplicationService(
id="unique_identifier",
token="some_token",
- hostname="example.com",
sender="@asbot:example.com",
namespaces={
ApplicationService.NS_USERS: [
@@ -1125,7 +1124,6 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase):
self.another_service = ApplicationService(
id="another__identifier",
token="another_token",
- hostname="example.com",
sender="@as2bot:example.com",
namespaces={
ApplicationService.NS_USERS: [
diff --git a/tests/rest/client/test_mutual_rooms.py b/tests/rest/client/test_mutual_rooms.py
index 7b7d283b..a4327f7a 100644
--- a/tests/rest/client/test_mutual_rooms.py
+++ b/tests/rest/client/test_mutual_rooms.py
@@ -36,12 +36,10 @@ class UserMutualRoomsTest(unittest.HomeserverTestCase):
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
config = self.default_config()
- config["update_user_directory"] = True
return self.setup_test_homeserver(config=config)
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = hs.get_datastores().main
- self.handler = hs.get_user_directory_handler()
def _get_mutual_rooms(self, token: str, other_user: str) -> FakeChannel:
return self.make_request(
diff --git a/tests/rest/client/test_notifications.py b/tests/rest/client/test_notifications.py
new file mode 100644
index 00000000..700f6587
--- /dev/null
+++ b/tests/rest/client/test_notifications.py
@@ -0,0 +1,91 @@
+# Copyright 2022 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from unittest.mock import Mock
+
+from twisted.test.proto_helpers import MemoryReactor
+
+import synapse.rest.admin
+from synapse.rest.client import login, notifications, receipts, room
+from synapse.server import HomeServer
+from synapse.util import Clock
+
+from tests.test_utils import simple_async_mock
+from tests.unittest import HomeserverTestCase
+
+
+class HTTPPusherTests(HomeserverTestCase):
+ servlets = [
+ synapse.rest.admin.register_servlets_for_client_rest_resource,
+ room.register_servlets,
+ login.register_servlets,
+ receipts.register_servlets,
+ notifications.register_servlets,
+ ]
+
+ def prepare(
+ self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer
+ ) -> None:
+ self.store = homeserver.get_datastores().main
+ self.module_api = homeserver.get_module_api()
+ self.event_creation_handler = homeserver.get_event_creation_handler()
+ self.sync_handler = homeserver.get_sync_handler()
+ self.auth_handler = homeserver.get_auth_handler()
+
+ def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
+ # Mock out the calls over federation.
+ fed_transport_client = Mock(spec=["send_transaction"])
+ fed_transport_client.send_transaction = simple_async_mock({})
+
+ return self.setup_test_homeserver(
+ federation_transport_client=fed_transport_client,
+ )
+
+ def test_notify_for_local_invites(self) -> None:
+ """
+ Local users will get notified for invites
+ """
+
+ user_id = self.register_user("user", "pass")
+ access_token = self.login("user", "pass")
+ other_user_id = self.register_user("otheruser", "pass")
+ other_access_token = self.login("otheruser", "pass")
+
+ # Create a room
+ room = self.helper.create_room_as(user_id, tok=access_token)
+
+ # Check we start with no pushes
+ channel = self.make_request(
+ "GET",
+ "/notifications",
+ access_token=other_access_token,
+ )
+ self.assertEqual(channel.code, 200, channel.result)
+ self.assertEqual(len(channel.json_body["notifications"]), 0, channel.json_body)
+
+ # Send an invite
+ self.helper.invite(room=room, src=user_id, targ=other_user_id, tok=access_token)
+
+ # We should have a notification now
+ channel = self.make_request(
+ "GET",
+ "/notifications",
+ access_token=other_access_token,
+ )
+ self.assertEqual(channel.code, 200)
+ self.assertEqual(len(channel.json_body["notifications"]), 1, channel.json_body)
+ self.assertEqual(
+ channel.json_body["notifications"][0]["event"]["content"]["membership"],
+ "invite",
+ channel.json_body,
+ )
diff --git a/tests/rest/client/test_register.py b/tests/rest/client/test_register.py
index 9aebf173..afb08b27 100644
--- a/tests/rest/client/test_register.py
+++ b/tests/rest/client/test_register.py
@@ -56,7 +56,6 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
appservice = ApplicationService(
as_token,
- self.hs.config.server.server_name,
id="1234",
namespaces={"users": [{"regex": r"@as_user.*", "exclusive": True}]},
sender="@as:test",
@@ -80,7 +79,6 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
appservice = ApplicationService(
as_token,
- self.hs.config.server.server_name,
id="1234",
namespaces={"users": [{"regex": r"@as_user.*", "exclusive": True}]},
sender="@as:test",
diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py
index 27dee8f6..62e4db23 100644
--- a/tests/rest/client/test_relations.py
+++ b/tests/rest/client/test_relations.py
@@ -896,6 +896,7 @@ class BundledAggregationsTestCase(BaseRelationsTestCase):
relation_type: str,
assertion_callable: Callable[[JsonDict], None],
expected_db_txn_for_event: int,
+ access_token: Optional[str] = None,
) -> None:
"""
Makes requests to various endpoints which should include bundled aggregations
@@ -907,7 +908,9 @@ class BundledAggregationsTestCase(BaseRelationsTestCase):
for relation-specific assertions.
expected_db_txn_for_event: The number of database transactions which
are expected for a call to /event/.
+ access_token: The access token to user, defaults to self.user_token.
"""
+ access_token = access_token or self.user_token
def assert_bundle(event_json: JsonDict) -> None:
"""Assert the expected values of the bundled aggregations."""
@@ -921,7 +924,7 @@ class BundledAggregationsTestCase(BaseRelationsTestCase):
channel = self.make_request(
"GET",
f"/rooms/{self.room}/event/{self.parent_id}",
- access_token=self.user_token,
+ access_token=access_token,
)
self.assertEqual(200, channel.code, channel.json_body)
assert_bundle(channel.json_body)
@@ -932,7 +935,7 @@ class BundledAggregationsTestCase(BaseRelationsTestCase):
channel = self.make_request(
"GET",
f"/rooms/{self.room}/messages?dir=b",
- access_token=self.user_token,
+ access_token=access_token,
)
self.assertEqual(200, channel.code, channel.json_body)
assert_bundle(self._find_event_in_chunk(channel.json_body["chunk"]))
@@ -941,7 +944,7 @@ class BundledAggregationsTestCase(BaseRelationsTestCase):
channel = self.make_request(
"GET",
f"/rooms/{self.room}/context/{self.parent_id}",
- access_token=self.user_token,
+ access_token=access_token,
)
self.assertEqual(200, channel.code, channel.json_body)
assert_bundle(channel.json_body["event"])
@@ -949,7 +952,7 @@ class BundledAggregationsTestCase(BaseRelationsTestCase):
# Request sync.
filter = urllib.parse.quote_plus(b'{"room": {"timeline": {"limit": 4}}}')
channel = self.make_request(
- "GET", f"/sync?filter={filter}", access_token=self.user_token
+ "GET", f"/sync?filter={filter}", access_token=access_token
)
self.assertEqual(200, channel.code, channel.json_body)
room_timeline = channel.json_body["rooms"]["join"][self.room]["timeline"]
@@ -962,7 +965,7 @@ class BundledAggregationsTestCase(BaseRelationsTestCase):
"/search",
# Search term matches the parent message.
content={"search_categories": {"room_events": {"search_term": "Hi"}}},
- access_token=self.user_token,
+ access_token=access_token,
)
self.assertEqual(200, channel.code, channel.json_body)
chunk = [
@@ -995,7 +998,7 @@ class BundledAggregationsTestCase(BaseRelationsTestCase):
bundled_aggregations,
)
- self._test_bundled_aggregations(RelationTypes.ANNOTATION, assert_annotations, 7)
+ self._test_bundled_aggregations(RelationTypes.ANNOTATION, assert_annotations, 6)
def test_annotation_to_annotation(self) -> None:
"""Any relation to an annotation should be ignored."""
@@ -1031,36 +1034,66 @@ class BundledAggregationsTestCase(BaseRelationsTestCase):
bundled_aggregations,
)
- self._test_bundled_aggregations(RelationTypes.REFERENCE, assert_annotations, 7)
+ self._test_bundled_aggregations(RelationTypes.REFERENCE, assert_annotations, 6)
def test_thread(self) -> None:
"""
Test that threads get correctly bundled.
"""
- self._send_relation(RelationTypes.THREAD, "m.room.test")
- channel = self._send_relation(RelationTypes.THREAD, "m.room.test")
+ # The root message is from "user", send replies as "user2".
+ self._send_relation(
+ RelationTypes.THREAD, "m.room.test", access_token=self.user2_token
+ )
+ channel = self._send_relation(
+ RelationTypes.THREAD, "m.room.test", access_token=self.user2_token
+ )
thread_2 = channel.json_body["event_id"]
- def assert_thread(bundled_aggregations: JsonDict) -> None:
- self.assertEqual(2, bundled_aggregations.get("count"))
- self.assertTrue(bundled_aggregations.get("current_user_participated"))
- # The latest thread event has some fields that don't matter.
- self.assert_dict(
- {
- "content": {
- "m.relates_to": {
- "event_id": self.parent_id,
- "rel_type": RelationTypes.THREAD,
- }
+ # This needs two assertion functions which are identical except for whether
+ # the current_user_participated flag is True, create a factory for the
+ # two versions.
+ def _gen_assert(participated: bool) -> Callable[[JsonDict], None]:
+ def assert_thread(bundled_aggregations: JsonDict) -> None:
+ self.assertEqual(2, bundled_aggregations.get("count"))
+ self.assertEqual(
+ participated, bundled_aggregations.get("current_user_participated")
+ )
+ # The latest thread event has some fields that don't matter.
+ self.assert_dict(
+ {
+ "content": {
+ "m.relates_to": {
+ "event_id": self.parent_id,
+ "rel_type": RelationTypes.THREAD,
+ }
+ },
+ "event_id": thread_2,
+ "sender": self.user2_id,
+ "type": "m.room.test",
},
- "event_id": thread_2,
- "sender": self.user_id,
- "type": "m.room.test",
- },
- bundled_aggregations.get("latest_event"),
- )
+ bundled_aggregations.get("latest_event"),
+ )
- self._test_bundled_aggregations(RelationTypes.THREAD, assert_thread, 10)
+ return assert_thread
+
+ # The "user" sent the root event and is making queries for the bundled
+ # aggregations: they have participated.
+ self._test_bundled_aggregations(RelationTypes.THREAD, _gen_assert(True), 8)
+ # The "user2" sent replies in the thread and is making queries for the
+ # bundled aggregations: they have participated.
+ #
+ # Note that this re-uses some cached values, so the total number of
+ # queries is much smaller.
+ self._test_bundled_aggregations(
+ RelationTypes.THREAD, _gen_assert(True), 2, access_token=self.user2_token
+ )
+
+ # A user with no interactions with the thread: they have not participated.
+ user3_id, user3_token = self._create_user("charlie")
+ self.helper.join(self.room, user=user3_id, tok=user3_token)
+ self._test_bundled_aggregations(
+ RelationTypes.THREAD, _gen_assert(False), 2, access_token=user3_token
+ )
def test_thread_with_bundled_aggregations_for_latest(self) -> None:
"""
@@ -1106,7 +1139,7 @@ class BundledAggregationsTestCase(BaseRelationsTestCase):
bundled_aggregations["latest_event"].get("unsigned"),
)
- self._test_bundled_aggregations(RelationTypes.THREAD, assert_thread, 10)
+ self._test_bundled_aggregations(RelationTypes.THREAD, assert_thread, 8)
def test_nested_thread(self) -> None:
"""
diff --git a/tests/rest/client/test_retention.py b/tests/rest/client/test_retention.py
index 7b8fe6d0..ac9c1133 100644
--- a/tests/rest/client/test_retention.py
+++ b/tests/rest/client/test_retention.py
@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+from typing import Any, Dict
from unittest.mock import Mock
from twisted.test.proto_helpers import MemoryReactor
@@ -129,7 +130,7 @@ class RetentionTestCase(unittest.HomeserverTestCase):
We do this by setting a very long time between purge jobs.
"""
store = self.hs.get_datastores().main
- storage = self.hs.get_storage()
+ storage_controllers = self.hs.get_storage_controllers()
room_id = self.helper.create_room_as(self.user_id, tok=self.token)
# Send a first event, which should be filtered out at the end of the test.
@@ -154,7 +155,7 @@ class RetentionTestCase(unittest.HomeserverTestCase):
)
self.assertEqual(2, len(events), "events retrieved from database")
filtered_events = self.get_success(
- filter_events_for_client(storage, self.user_id, events)
+ filter_events_for_client(storage_controllers, self.user_id, events)
)
# We should only get one event back.
@@ -252,16 +253,24 @@ class RetentionNoDefaultPolicyTestCase(unittest.HomeserverTestCase):
room.register_servlets,
]
- def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
- config = self.default_config()
- config["retention"] = {
+ def default_config(self) -> Dict[str, Any]:
+ config = super().default_config()
+
+ retention_config = {
"enabled": True,
}
+ # Update this config with what's in the default config so that
+ # override_config works as expected.
+ retention_config.update(config.get("retention", {}))
+ config["retention"] = retention_config
+
+ return config
+
+ def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
mock_federation_client = Mock(spec=["backfill"])
self.hs = self.setup_test_homeserver(
- config=config,
federation_client=mock_federation_client,
)
return self.hs
@@ -295,6 +304,24 @@ class RetentionNoDefaultPolicyTestCase(unittest.HomeserverTestCase):
self._test_retention(room_id, expected_code_for_first_event=404)
+ @unittest.override_config({"retention": {"enabled": False}})
+ def test_visibility_when_disabled(self) -> None:
+ """Retention policies should be ignored when the retention feature is disabled."""
+ room_id = self.helper.create_room_as(self.user_id, tok=self.token)
+
+ self.helper.send_state(
+ room_id=room_id,
+ event_type=EventTypes.Retention,
+ body={"max_lifetime": one_day_ms},
+ tok=self.token,
+ )
+
+ resp = self.helper.send(room_id=room_id, body="test", tok=self.token)
+
+ self.reactor.advance(one_day_ms * 2 / 1000)
+
+ self.get_event(room_id, resp["event_id"])
+
def _test_retention(
self, room_id: str, expected_code_for_first_event: int = 200
) -> None:
diff --git a/tests/rest/client/test_room_batch.py b/tests/rest/client/test_room_batch.py
index 41a1bf6d..9d5cb60d 100644
--- a/tests/rest/client/test_room_batch.py
+++ b/tests/rest/client/test_room_batch.py
@@ -71,7 +71,6 @@ class RoomBatchTestCase(unittest.HomeserverTestCase):
self.appservice = ApplicationService(
token="i_am_an_app_service",
- hostname="test",
id="1234",
namespaces={"users": [{"regex": r"@as_user.*", "exclusive": True}]},
# Note: this user does not have to match the regex above
@@ -88,7 +87,7 @@ class RoomBatchTestCase(unittest.HomeserverTestCase):
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.clock = clock
- self.storage = hs.get_storage()
+ self._storage_controllers = hs.get_storage_controllers()
self.virtual_user_id, _ = self.register_appservice_user(
"as_user_potato", self.appservice.token
@@ -168,7 +167,9 @@ class RoomBatchTestCase(unittest.HomeserverTestCase):
# Fetch the state_groups
state_group_map = self.get_success(
- self.storage.state.get_state_groups_ids(room_id, historical_event_ids)
+ self._storage_controllers.state.get_state_groups_ids(
+ room_id, historical_event_ids
+ )
)
# We expect all of the historical events to be using the same state_group
diff --git a/tests/rest/client/test_rooms.py b/tests/rest/client/test_rooms.py
index 9443daa0..f523d89b 100644
--- a/tests/rest/client/test_rooms.py
+++ b/tests/rest/client/test_rooms.py
@@ -26,6 +26,7 @@ from twisted.test.proto_helpers import MemoryReactor
import synapse.rest.admin
from synapse.api.constants import (
+ EduTypes,
EventContentFields,
EventTypes,
Membership,
@@ -925,7 +926,7 @@ class RoomJoinTestCase(RoomBase):
) -> bool:
return return_value
- callback_mock = Mock(side_effect=user_may_join_room)
+ callback_mock = Mock(side_effect=user_may_join_room, spec=lambda *x: None)
self.hs.get_spam_checker()._user_may_join_room_callbacks.append(callback_mock)
# Join a first room, without being invited to it.
@@ -1116,6 +1117,264 @@ class RoomMessagesTestCase(RoomBase):
self.assertEqual(200, channel.code, msg=channel.result["body"])
+class RoomPowerLevelOverridesTestCase(RoomBase):
+ """Tests that the power levels can be overridden with server config."""
+
+ user_id = "@sid1:red"
+
+ servlets = [
+ admin.register_servlets,
+ room.register_servlets,
+ login.register_servlets,
+ ]
+
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+ self.admin_user_id = self.register_user("admin", "pass")
+ self.admin_access_token = self.login("admin", "pass")
+
+ def power_levels(self, room_id: str) -> Dict[str, Any]:
+ return self.helper.get_state(
+ room_id, "m.room.power_levels", self.admin_access_token
+ )
+
+ def test_default_power_levels_with_room_override(self) -> None:
+ """
+ Create a room, providing power level overrides.
+ Confirm that the room's power levels reflect the overrides.
+
+ See https://github.com/matrix-org/matrix-spec/issues/492
+ - currently we overwrite each key of power_level_content_override
+ completely.
+ """
+
+ room_id = self.helper.create_room_as(
+ self.user_id,
+ extra_content={
+ "power_level_content_override": {"events": {"custom.event": 0}}
+ },
+ )
+ self.assertEqual(
+ {
+ "custom.event": 0,
+ },
+ self.power_levels(room_id)["events"],
+ )
+
+ @unittest.override_config(
+ {
+ "default_power_level_content_override": {
+ "public_chat": {"events": {"custom.event": 0}},
+ }
+ },
+ )
+ def test_power_levels_with_server_override(self) -> None:
+ """
+ With a server configured to modify the room-level defaults,
+ Create a room, without providing any extra power level overrides.
+ Confirm that the room's power levels reflect the server-level overrides.
+
+ Similar to https://github.com/matrix-org/matrix-spec/issues/492,
+ we overwrite each key of power_level_content_override completely.
+ """
+
+ room_id = self.helper.create_room_as(self.user_id)
+ self.assertEqual(
+ {
+ "custom.event": 0,
+ },
+ self.power_levels(room_id)["events"],
+ )
+
+ @unittest.override_config(
+ {
+ "default_power_level_content_override": {
+ "public_chat": {
+ "events": {"server.event": 0},
+ "ban": 13,
+ },
+ }
+ },
+ )
+ def test_power_levels_with_server_and_room_overrides(self) -> None:
+ """
+ With a server configured to modify the room-level defaults,
+ create a room, providing different overrides.
+ Confirm that the room's power levels reflect both overrides, and
+ choose the room overrides where they clash.
+ """
+
+ room_id = self.helper.create_room_as(
+ self.user_id,
+ extra_content={
+ "power_level_content_override": {"events": {"room.event": 0}}
+ },
+ )
+
+ # Room override wins over server config
+ self.assertEqual(
+ {"room.event": 0},
+ self.power_levels(room_id)["events"],
+ )
+
+ # But where there is no room override, server config wins
+ self.assertEqual(13, self.power_levels(room_id)["ban"])
+
+
+class RoomPowerLevelOverridesInPracticeTestCase(RoomBase):
+ """
+ Tests that we can really do various otherwise-prohibited actions
+ based on overriding the power levels in config.
+ """
+
+ user_id = "@sid1:red"
+
+ def test_creator_can_post_state_event(self) -> None:
+ # Given I am the creator of a room
+ room_id = self.helper.create_room_as(self.user_id)
+
+ # When I send a state event
+ path = "/rooms/{room_id}/state/custom.event/my_state_key".format(
+ room_id=urlparse.quote(room_id),
+ )
+ channel = self.make_request("PUT", path, "{}")
+
+ # Then I am allowed
+ self.assertEqual(200, channel.code, msg=channel.result["body"])
+
+ def test_normal_user_can_not_post_state_event(self) -> None:
+ # Given I am a normal member of a room
+ room_id = self.helper.create_room_as("@some_other_guy:red")
+ self.helper.join(room=room_id, user=self.user_id)
+
+ # When I send a state event
+ path = "/rooms/{room_id}/state/custom.event/my_state_key".format(
+ room_id=urlparse.quote(room_id),
+ )
+ channel = self.make_request("PUT", path, "{}")
+
+ # Then I am not allowed because state events require PL>=50
+ self.assertEqual(403, channel.code, msg=channel.result["body"])
+ self.assertEqual(
+ "You don't have permission to post that to the room. "
+ "user_level (0) < send_level (50)",
+ channel.json_body["error"],
+ )
+
+ @unittest.override_config(
+ {
+ "default_power_level_content_override": {
+ "public_chat": {"events": {"custom.event": 0}},
+ }
+ },
+ )
+ def test_with_config_override_normal_user_can_post_state_event(self) -> None:
+ # Given the server has config allowing normal users to post my event type,
+ # and I am a normal member of a room
+ room_id = self.helper.create_room_as("@some_other_guy:red")
+ self.helper.join(room=room_id, user=self.user_id)
+
+ # When I send a state event
+ path = "/rooms/{room_id}/state/custom.event/my_state_key".format(
+ room_id=urlparse.quote(room_id),
+ )
+ channel = self.make_request("PUT", path, "{}")
+
+ # Then I am allowed
+ self.assertEqual(200, channel.code, msg=channel.result["body"])
+
+ @unittest.override_config(
+ {
+ "default_power_level_content_override": {
+ "public_chat": {"events": {"custom.event": 0}},
+ }
+ },
+ )
+ def test_any_room_override_defeats_config_override(self) -> None:
+ # Given the server has config allowing normal users to post my event type
+ # And I am a normal member of a room
+ # But the room was created with special permissions
+ extra_content: Dict[str, Any] = {
+ "power_level_content_override": {"events": {}},
+ }
+ room_id = self.helper.create_room_as(
+ "@some_other_guy:red", extra_content=extra_content
+ )
+ self.helper.join(room=room_id, user=self.user_id)
+
+ # When I send a state event
+ path = "/rooms/{room_id}/state/custom.event/my_state_key".format(
+ room_id=urlparse.quote(room_id),
+ )
+ channel = self.make_request("PUT", path, "{}")
+
+ # Then I am not allowed
+ self.assertEqual(403, channel.code, msg=channel.result["body"])
+
+ @unittest.override_config(
+ {
+ "default_power_level_content_override": {
+ "public_chat": {"events": {"custom.event": 0}},
+ }
+ },
+ )
+ def test_specific_room_override_defeats_config_override(self) -> None:
+ # Given the server has config allowing normal users to post my event type,
+ # and I am a normal member of a room,
+ # but the room was created with special permissions for this event type
+ extra_content = {
+ "power_level_content_override": {"events": {"custom.event": 1}},
+ }
+ room_id = self.helper.create_room_as(
+ "@some_other_guy:red", extra_content=extra_content
+ )
+ self.helper.join(room=room_id, user=self.user_id)
+
+ # When I send a state event
+ path = "/rooms/{room_id}/state/custom.event/my_state_key".format(
+ room_id=urlparse.quote(room_id),
+ )
+ channel = self.make_request("PUT", path, "{}")
+
+ # Then I am not allowed
+ self.assertEqual(403, channel.code, msg=channel.result["body"])
+ self.assertEqual(
+ "You don't have permission to post that to the room. "
+ + "user_level (0) < send_level (1)",
+ channel.json_body["error"],
+ )
+
+ @unittest.override_config(
+ {
+ "default_power_level_content_override": {
+ "public_chat": {"events": {"custom.event": 0}},
+ "private_chat": None,
+ "trusted_private_chat": None,
+ }
+ },
+ )
+ def test_config_override_applies_only_to_specific_preset(self) -> None:
+ # Given the server has config for public_chats,
+ # and I am a normal member of a private_chat room
+ room_id = self.helper.create_room_as("@some_other_guy:red", is_public=False)
+ self.helper.invite(room=room_id, src="@some_other_guy:red", targ=self.user_id)
+ self.helper.join(room=room_id, user=self.user_id)
+
+ # When I send a state event
+ path = "/rooms/{room_id}/state/custom.event/my_state_key".format(
+ room_id=urlparse.quote(room_id),
+ )
+ channel = self.make_request("PUT", path, "{}")
+
+ # Then I am not allowed because the public_chat config does not
+ # affect this room, because this room is a private_chat
+ self.assertEqual(403, channel.code, msg=channel.result["body"])
+ self.assertEqual(
+ "You don't have permission to post that to the room. "
+ + "user_level (0) < send_level (50)",
+ channel.json_body["error"],
+ )
+
+
class RoomInitialSyncTestCase(RoomBase):
"""Tests /rooms/$room_id/initialSync."""
@@ -1154,7 +1413,7 @@ class RoomInitialSyncTestCase(RoomBase):
e["content"]["user_id"]: e for e in channel.json_body["presence"]
}
self.assertTrue(self.user_id in presence_by_user)
- self.assertEqual("m.presence", presence_by_user[self.user_id]["type"])
+ self.assertEqual(EduTypes.PRESENCE, presence_by_user[self.user_id]["type"])
class RoomMessageListTestCase(RoomBase):
@@ -2598,7 +2857,9 @@ class ThreepidInviteTestCase(unittest.HomeserverTestCase):
# Add a mock to the spamchecker callbacks for user_may_send_3pid_invite. Make it
# allow everything for now.
- mock = Mock(return_value=make_awaitable(True))
+ # `spec` argument is needed for this function mock to have `__qualname__`, which
+ # is needed for `Measure` metrics buried in SpamChecker.
+ mock = Mock(return_value=make_awaitable(True), spec=lambda *x: None)
self.hs.get_spam_checker()._user_may_send_3pid_invite_callbacks.append(mock)
# Send a 3PID invite into the room and check that it succeeded.
diff --git a/tests/rest/client/test_sendtodevice.py b/tests/rest/client/test_sendtodevice.py
index c3942889..6435800f 100644
--- a/tests/rest/client/test_sendtodevice.py
+++ b/tests/rest/client/test_sendtodevice.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from synapse.api.constants import EduTypes
from synapse.rest import admin
from synapse.rest.client import login, sendtodevice, sync
@@ -139,7 +140,7 @@ class SendToDeviceTestCase(HomeserverTestCase):
for i in range(3):
self.get_success(
federation_registry.on_edu(
- "m.direct_to_device",
+ EduTypes.DIRECT_TO_DEVICE,
"remote_server",
{
"sender": "@user:remote_server",
@@ -172,7 +173,7 @@ class SendToDeviceTestCase(HomeserverTestCase):
# and we can send more messages
self.get_success(
federation_registry.on_edu(
- "m.direct_to_device",
+ EduTypes.DIRECT_TO_DEVICE,
"remote_server",
{
"sender": "@user:remote_server",
diff --git a/tests/rest/client/test_shadow_banned.py b/tests/rest/client/test_shadow_banned.py
index ae5ada3b..d9bd8c4a 100644
--- a/tests/rest/client/test_shadow_banned.py
+++ b/tests/rest/client/test_shadow_banned.py
@@ -17,7 +17,7 @@ from unittest.mock import Mock, patch
from twisted.test.proto_helpers import MemoryReactor
import synapse.rest.admin
-from synapse.api.constants import EventTypes
+from synapse.api.constants import EduTypes, EventTypes
from synapse.rest.client import (
directory,
login,
@@ -226,7 +226,7 @@ class RoomTestCase(_ShadowBannedBase):
events[0],
[
{
- "type": "m.typing",
+ "type": EduTypes.TYPING,
"room_id": room_id,
"content": {"user_ids": [self.other_user_id]},
}
diff --git a/tests/rest/client/test_sync.py b/tests/rest/client/test_sync.py
index 01083376..e3efd1f1 100644
--- a/tests/rest/client/test_sync.py
+++ b/tests/rest/client/test_sync.py
@@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import json
+from http import HTTPStatus
from typing import List, Optional
from parameterized import parameterized
@@ -21,6 +22,7 @@ from twisted.test.proto_helpers import MemoryReactor
import synapse.rest.admin
from synapse.api.constants import (
+ EduTypes,
EventContentFields,
EventTypes,
ReceiptTypes,
@@ -485,30 +487,7 @@ class ReadReceiptsTestCase(unittest.HomeserverTestCase):
# Test that we didn't override the public read receipt
self.assertIsNone(self._get_read_receipt())
- @parameterized.expand(
- [
- # Old Element version, expected to send an empty body
- (
- "agent1",
- "Element/1.2.2 (Linux; U; Android 9; MatrixAndroidSDK_X 0.0.1)",
- 200,
- ),
- # Old SchildiChat version, expected to send an empty body
- ("agent2", "SchildiChat/1.2.1 (Android 10)", 200),
- # Expected 400: Denies empty body starting at version 1.3+
- ("agent3", "Element/1.3.6 (Android 10)", 400),
- ("agent4", "SchildiChat/1.3.6 (Android 11)", 400),
- # Contains "Riot": Receipts with empty bodies expected
- ("agent5", "Element (Riot.im) (Android 9)", 200),
- # Expected 400: Does not contain "Android"
- ("agent6", "Element/1.2.1", 400),
- # Expected 400: Different format, missing "/" after Element; existing build that should allow empty bodies, but minimal ongoing usage
- ("agent7", "Element dbg/1.1.8-dev (Android)", 400),
- ]
- )
- def test_read_receipt_with_empty_body(
- self, name: str, user_agent: str, expected_status_code: int
- ) -> None:
+ def test_read_receipt_with_empty_body_is_rejected(self) -> None:
# Send a message as the first user
res = self.helper.send(self.room_id, body="hello", tok=self.tok)
@@ -517,16 +496,16 @@ class ReadReceiptsTestCase(unittest.HomeserverTestCase):
"POST",
f"/rooms/{self.room_id}/receipt/m.read/{res['event_id']}",
access_token=self.tok2,
- custom_headers=[("User-Agent", user_agent)],
)
- self.assertEqual(channel.code, expected_status_code)
+ self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST)
+ self.assertEqual(channel.json_body["errcode"], "M_NOT_JSON", channel.json_body)
def _get_read_receipt(self) -> Optional[JsonDict]:
"""Syncs and returns the read receipt."""
# Checks if event is a read receipt
def is_read_receipt(event: JsonDict) -> bool:
- return event["type"] == "m.receipt"
+ return event["type"] == EduTypes.RECEIPT
# Sync
channel = self.make_request(
@@ -678,12 +657,13 @@ class UnreadMessagesTestCase(unittest.HomeserverTestCase):
self._check_unread_count(3)
# Check that custom events with a body increase the unread counter.
- self.helper.send_event(
+ result = self.helper.send_event(
self.room_id,
"org.matrix.custom_type",
{"body": "hello"},
tok=self.tok2,
)
+ event_id = result["event_id"]
self._check_unread_count(4)
# Check that edits don't increase the unread counter.
@@ -693,7 +673,10 @@ class UnreadMessagesTestCase(unittest.HomeserverTestCase):
content={
"body": "hello",
"msgtype": "m.text",
- "m.relates_to": {"rel_type": RelationTypes.REPLACE},
+ "m.relates_to": {
+ "rel_type": RelationTypes.REPLACE,
+ "event_id": event_id,
+ },
},
tok=self.tok2,
)
diff --git a/tests/rest/client/test_typing.py b/tests/rest/client/test_typing.py
index d6da5107..61b66d76 100644
--- a/tests/rest/client/test_typing.py
+++ b/tests/rest/client/test_typing.py
@@ -17,6 +17,7 @@
from twisted.test.proto_helpers import MemoryReactor
+from synapse.api.constants import EduTypes
from synapse.rest.client import room
from synapse.server import HomeServer
from synapse.types import UserID
@@ -67,7 +68,7 @@ class RoomTypingTestCase(unittest.HomeserverTestCase):
events[0],
[
{
- "type": "m.typing",
+ "type": EduTypes.TYPING,
"room_id": self.room_id,
"content": {"user_ids": [self.user_id]},
}
diff --git a/tests/rest/client/test_upgrade_room.py b/tests/rest/client/test_upgrade_room.py
index c86fc5df..98c1039d 100644
--- a/tests/rest/client/test_upgrade_room.py
+++ b/tests/rest/client/test_upgrade_room.py
@@ -76,7 +76,7 @@ class UpgradeRoomTest(unittest.HomeserverTestCase):
"""
Upgrading a room should work fine.
"""
- # THe user isn't in the room.
+ # The user isn't in the room.
roomless = self.register_user("roomless", "pass")
roomless_token = self.login(roomless, "pass")
@@ -249,7 +249,9 @@ class UpgradeRoomTest(unittest.HomeserverTestCase):
new_space_id = channel.json_body["replacement_room"]
- state_ids = self.get_success(self.store.get_current_state_ids(new_space_id))
+ state_ids = self.get_success(
+ self.store.get_partial_current_state_ids(new_space_id)
+ )
# Ensure the new room is still a space.
create_event = self.get_success(
@@ -263,3 +265,35 @@ class UpgradeRoomTest(unittest.HomeserverTestCase):
self.assertIn((EventTypes.SpaceChild, self.room_id), state_ids)
# The child that was removed should not be copied over.
self.assertNotIn((EventTypes.SpaceChild, old_room_id), state_ids)
+
+ def test_custom_room_type(self) -> None:
+ """Test upgrading a room that has a custom room type set."""
+ test_room_type = "com.example.my_custom_room_type"
+
+ # Create a room with a custom room type.
+ room_id = self.helper.create_room_as(
+ self.creator,
+ tok=self.creator_token,
+ extra_content={
+ "creation_content": {EventContentFields.ROOM_TYPE: test_room_type}
+ },
+ )
+
+ # Upgrade the room!
+ channel = self._upgrade_room(room_id=room_id)
+ self.assertEqual(200, channel.code, channel.result)
+ self.assertIn("replacement_room", channel.json_body)
+
+ new_room_id = channel.json_body["replacement_room"]
+
+ state_ids = self.get_success(
+ self.store.get_partial_current_state_ids(new_room_id)
+ )
+
+ # Ensure the new room is the same type as the old room.
+ create_event = self.get_success(
+ self.store.get_event(state_ids[(EventTypes.Create, "")])
+ )
+ self.assertEqual(
+ create_event.content.get(EventContentFields.ROOM_TYPE), test_room_type
+ )
diff --git a/tests/rest/media/test_media_retention.py b/tests/rest/media/test_media_retention.py
new file mode 100644
index 00000000..14af07c5
--- /dev/null
+++ b/tests/rest/media/test_media_retention.py
@@ -0,0 +1,321 @@
+# Copyright 2022 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import io
+from typing import Iterable, Optional, Tuple
+
+from twisted.test.proto_helpers import MemoryReactor
+
+from synapse.rest import admin
+from synapse.rest.client import login, register, room
+from synapse.server import HomeServer
+from synapse.types import UserID
+from synapse.util import Clock
+
+from tests import unittest
+from tests.unittest import override_config
+from tests.utils import MockClock
+
+
+class MediaRetentionTestCase(unittest.HomeserverTestCase):
+
+ ONE_DAY_IN_MS = 24 * 60 * 60 * 1000
+ THIRTY_DAYS_IN_MS = 30 * ONE_DAY_IN_MS
+
+ servlets = [
+ room.register_servlets,
+ login.register_servlets,
+ register.register_servlets,
+ admin.register_servlets_for_client_rest_resource,
+ ]
+
+ def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
+ # We need to be able to test advancing time in the homeserver, so we
+ # replace the test homeserver's default clock with a MockClock, which
+ # supports advancing time.
+ return self.setup_test_homeserver(clock=MockClock())
+
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+ self.remote_server_name = "remote.homeserver"
+ self.store = hs.get_datastores().main
+
+ # Create a user to upload media with
+ test_user_id = self.register_user("alice", "password")
+
+ # Inject media (recently accessed, old access, never accessed, old access
+ # quarantined media) into both the local store and the remote cache, plus
+ # one additional local media that is marked as protected from quarantine.
+ media_repository = hs.get_media_repository()
+ test_media_content = b"example string"
+
+ def _create_media_and_set_attributes(
+ last_accessed_ms: Optional[int],
+ is_quarantined: Optional[bool] = False,
+ is_protected: Optional[bool] = False,
+ ) -> str:
+ # "Upload" some media to the local media store
+ mxc_uri = self.get_success(
+ media_repository.create_content(
+ media_type="text/plain",
+ upload_name=None,
+ content=io.BytesIO(test_media_content),
+ content_length=len(test_media_content),
+ auth_user=UserID.from_string(test_user_id),
+ )
+ )
+
+ media_id = mxc_uri.split("/")[-1]
+
+ # Set the last recently accessed time for this media
+ if last_accessed_ms is not None:
+ self.get_success(
+ self.store.update_cached_last_access_time(
+ local_media=(media_id,),
+ remote_media=(),
+ time_ms=last_accessed_ms,
+ )
+ )
+
+ if is_quarantined:
+ # Mark this media as quarantined
+ self.get_success(
+ self.store.quarantine_media_by_id(
+ server_name=self.hs.config.server.server_name,
+ media_id=media_id,
+ quarantined_by="@theadmin:test",
+ )
+ )
+
+ if is_protected:
+ # Mark this media as protected from quarantine
+ self.get_success(
+ self.store.mark_local_media_as_safe(
+ media_id=media_id,
+ safe=True,
+ )
+ )
+
+ return media_id
+
+ def _cache_remote_media_and_set_attributes(
+ media_id: str,
+ last_accessed_ms: Optional[int],
+ is_quarantined: Optional[bool] = False,
+ ) -> str:
+ # Pretend to cache some remote media
+ self.get_success(
+ self.store.store_cached_remote_media(
+ origin=self.remote_server_name,
+ media_id=media_id,
+ media_type="text/plain",
+ media_length=1,
+ time_now_ms=clock.time_msec(),
+ upload_name="testfile.txt",
+ filesystem_id="abcdefg12345",
+ )
+ )
+
+ # Set the last recently accessed time for this media
+ if last_accessed_ms is not None:
+ self.get_success(
+ hs.get_datastores().main.update_cached_last_access_time(
+ local_media=(),
+ remote_media=((self.remote_server_name, media_id),),
+ time_ms=last_accessed_ms,
+ )
+ )
+
+ if is_quarantined:
+ # Mark this media as quarantined
+ self.get_success(
+ self.store.quarantine_media_by_id(
+ server_name=self.remote_server_name,
+ media_id=media_id,
+ quarantined_by="@theadmin:test",
+ )
+ )
+
+ return media_id
+
+ # Start with the local media store
+ self.local_recently_accessed_media = _create_media_and_set_attributes(
+ last_accessed_ms=self.THIRTY_DAYS_IN_MS,
+ )
+ self.local_not_recently_accessed_media = _create_media_and_set_attributes(
+ last_accessed_ms=self.ONE_DAY_IN_MS,
+ )
+ self.local_not_recently_accessed_quarantined_media = (
+ _create_media_and_set_attributes(
+ last_accessed_ms=self.ONE_DAY_IN_MS,
+ is_quarantined=True,
+ )
+ )
+ self.local_not_recently_accessed_protected_media = (
+ _create_media_and_set_attributes(
+ last_accessed_ms=self.ONE_DAY_IN_MS,
+ is_protected=True,
+ )
+ )
+ self.local_never_accessed_media = _create_media_and_set_attributes(
+ last_accessed_ms=None,
+ )
+
+ # And now the remote media store
+ self.remote_recently_accessed_media = _cache_remote_media_and_set_attributes(
+ media_id="a",
+ last_accessed_ms=self.THIRTY_DAYS_IN_MS,
+ )
+ self.remote_not_recently_accessed_media = (
+ _cache_remote_media_and_set_attributes(
+ media_id="b",
+ last_accessed_ms=self.ONE_DAY_IN_MS,
+ )
+ )
+ self.remote_not_recently_accessed_quarantined_media = (
+ _cache_remote_media_and_set_attributes(
+ media_id="c",
+ last_accessed_ms=self.ONE_DAY_IN_MS,
+ is_quarantined=True,
+ )
+ )
+ # Remote media will always have a "last accessed" attribute, as it would not
+ # be fetched from the remote homeserver unless instigated by a user.
+
+ @override_config(
+ {
+ "media_retention": {
+ # Enable retention for local media
+ "local_media_lifetime": "30d"
+ # Cached remote media should not be purged
+ }
+ }
+ )
+ def test_local_media_retention(self) -> None:
+ """
+ Tests that local media that have not been accessed recently is purged, while
+ cached remote media is unaffected.
+ """
+ # Advance 31 days (in seconds)
+ self.reactor.advance(31 * 24 * 60 * 60)
+
+ # Check that media has been correctly purged.
+ # Local media accessed <30 days ago should still exist.
+ # Remote media should be unaffected.
+ self._assert_if_mxc_uris_purged(
+ purged=[
+ (
+ self.hs.config.server.server_name,
+ self.local_not_recently_accessed_media,
+ ),
+ (self.hs.config.server.server_name, self.local_never_accessed_media),
+ ],
+ not_purged=[
+ (self.hs.config.server.server_name, self.local_recently_accessed_media),
+ (
+ self.hs.config.server.server_name,
+ self.local_not_recently_accessed_quarantined_media,
+ ),
+ (
+ self.hs.config.server.server_name,
+ self.local_not_recently_accessed_protected_media,
+ ),
+ (self.remote_server_name, self.remote_recently_accessed_media),
+ (self.remote_server_name, self.remote_not_recently_accessed_media),
+ (
+ self.remote_server_name,
+ self.remote_not_recently_accessed_quarantined_media,
+ ),
+ ],
+ )
+
+ @override_config(
+ {
+ "media_retention": {
+ # Enable retention for cached remote media
+ "remote_media_lifetime": "30d"
+ # Local media should not be purged
+ }
+ }
+ )
+ def test_remote_media_cache_retention(self) -> None:
+ """
+ Tests that entries from the remote media cache that have not been accessed
+ recently is purged, while local media is unaffected.
+ """
+ # Advance 31 days (in seconds)
+ self.reactor.advance(31 * 24 * 60 * 60)
+
+ # Check that media has been correctly purged.
+ # Local media should be unaffected.
+ # Remote media accessed <30 days ago should still exist.
+ self._assert_if_mxc_uris_purged(
+ purged=[
+ (self.remote_server_name, self.remote_not_recently_accessed_media),
+ ],
+ not_purged=[
+ (self.remote_server_name, self.remote_recently_accessed_media),
+ (self.hs.config.server.server_name, self.local_recently_accessed_media),
+ (
+ self.hs.config.server.server_name,
+ self.local_not_recently_accessed_media,
+ ),
+ (
+ self.hs.config.server.server_name,
+ self.local_not_recently_accessed_quarantined_media,
+ ),
+ (
+ self.hs.config.server.server_name,
+ self.local_not_recently_accessed_protected_media,
+ ),
+ (
+ self.remote_server_name,
+ self.remote_not_recently_accessed_quarantined_media,
+ ),
+ (self.hs.config.server.server_name, self.local_never_accessed_media),
+ ],
+ )
+
+ def _assert_if_mxc_uris_purged(
+ self, purged: Iterable[Tuple[str, str]], not_purged: Iterable[Tuple[str, str]]
+ ) -> None:
+ def _assert_mxc_uri_purge_state(
+ server_name: str, media_id: str, expect_purged: bool
+ ) -> None:
+ """Given an MXC URI, assert whether it has been purged or not."""
+ if server_name == self.hs.config.server.server_name:
+ found_media_dict = self.get_success(
+ self.store.get_local_media(media_id)
+ )
+ else:
+ found_media_dict = self.get_success(
+ self.store.get_cached_remote_media(server_name, media_id)
+ )
+
+ mxc_uri = f"mxc://{server_name}/{media_id}"
+
+ if expect_purged:
+ self.assertIsNone(
+ found_media_dict, msg=f"{mxc_uri} unexpectedly not purged"
+ )
+ else:
+ self.assertIsNotNone(
+ found_media_dict,
+ msg=f"{mxc_uri} unexpectedly purged",
+ )
+
+ # Assert that the given MXC URIs have either been correctly purged or not.
+ for server_name, media_id in purged:
+ _assert_mxc_uri_purge_state(server_name, media_id, expect_purged=True)
+ for server_name, media_id in not_purged:
+ _assert_mxc_uri_purge_state(server_name, media_id, expect_purged=False)
diff --git a/tests/rest/media/v1/test_html_preview.py b/tests/rest/media/v1/test_html_preview.py
index 62e30881..ea9e5889 100644
--- a/tests/rest/media/v1/test_html_preview.py
+++ b/tests/rest/media/v1/test_html_preview.py
@@ -145,7 +145,7 @@ class SummarizeTestCase(unittest.TestCase):
)
-class CalcOgTestCase(unittest.TestCase):
+class OpenGraphFromHtmlTestCase(unittest.TestCase):
if not lxml:
skip = "url preview feature requires lxml"
@@ -235,6 +235,21 @@ class CalcOgTestCase(unittest.TestCase):
self.assertEqual(og, {"og:title": None, "og:description": "Some text."})
+ # Another variant is a title with no content.
+ html = b"""
+ <html>
+ <head><title></title></head>
+ <body>
+ <h1>Title</h1>
+ </body>
+ </html>
+ """
+
+ tree = decode_body(html, "http://example.com/test.html")
+ og = parse_html_to_open_graph(tree)
+
+ self.assertEqual(og, {"og:title": "Title", "og:description": "Title"})
+
def test_h1_as_title(self) -> None:
html = b"""
<html>
@@ -250,6 +265,26 @@ class CalcOgTestCase(unittest.TestCase):
self.assertEqual(og, {"og:title": "Title", "og:description": "Some text."})
+ def test_empty_description(self) -> None:
+ """Description tags with empty content should be ignored."""
+ html = b"""
+ <html>
+ <meta property="og:description" content=""/>
+ <meta property="og:description"/>
+ <meta name="description" content=""/>
+ <meta name="description"/>
+ <meta name="description" content="Finally!"/>
+ <body>
+ <h1>Title</h1>
+ </body>
+ </html>
+ """
+
+ tree = decode_body(html, "http://example.com/test.html")
+ og = parse_html_to_open_graph(tree)
+
+ self.assertEqual(og, {"og:title": "Title", "og:description": "Finally!"})
+
def test_missing_title_and_broken_h1(self) -> None:
html = b"""
<html>
diff --git a/tests/rest/media/v1/test_url_preview.py b/tests/rest/media/v1/test_url_preview.py
index 3b24d0ac..2c321f8d 100644
--- a/tests/rest/media/v1/test_url_preview.py
+++ b/tests/rest/media/v1/test_url_preview.py
@@ -656,6 +656,41 @@ class URLPreviewTests(unittest.HomeserverTestCase):
server.data,
)
+ def test_nonexistent_image(self) -> None:
+ """If the preview image doesn't exist, ensure some data is returned."""
+ self.lookups["matrix.org"] = [(IPv4Address, "10.1.2.3")]
+
+ end_content = (
+ b"""<html><body><img src="http://cdn.matrix.org/foo.jpg"></body></html>"""
+ )
+
+ channel = self.make_request(
+ "GET",
+ "preview_url?url=http://matrix.org",
+ shorthand=False,
+ await_result=False,
+ )
+ self.pump()
+
+ client = self.reactor.tcpClients[0][2].buildProtocol(None)
+ server = AccumulatingProtocol()
+ server.makeConnection(FakeTransport(client, self.reactor))
+ client.makeConnection(FakeTransport(server, self.reactor))
+ client.dataReceived(
+ (
+ b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\n"
+ b'Content-Type: text/html; charset="utf8"\r\n\r\n'
+ )
+ % (len(end_content),)
+ + end_content
+ )
+
+ self.pump()
+ self.assertEqual(channel.code, 200)
+
+ # The image should not be in the result.
+ self.assertNotIn("og:image", channel.json_body)
+
def test_data_url(self) -> None:
"""
Requesting to preview a data URL is not supported.
diff --git a/tests/scripts/test_new_matrix_user.py b/tests/scripts/test_new_matrix_user.py
index 19a145ee..22f99c6a 100644
--- a/tests/scripts/test_new_matrix_user.py
+++ b/tests/scripts/test_new_matrix_user.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from typing import List
from unittest.mock import Mock, patch
from synapse._scripts.register_new_matrix_user import request_registration
@@ -49,8 +50,8 @@ class RegisterTestCase(TestCase):
requests.post = post
# The fake stdout will be written here
- out = []
- err_code = []
+ out: List[str] = []
+ err_code: List[int] = []
with patch("synapse._scripts.register_new_matrix_user.requests", requests):
request_registration(
@@ -85,8 +86,8 @@ class RegisterTestCase(TestCase):
requests.get = get
# The fake stdout will be written here
- out = []
- err_code = []
+ out: List[str] = []
+ err_code: List[int] = []
with patch("synapse._scripts.register_new_matrix_user.requests", requests):
request_registration(
@@ -137,8 +138,8 @@ class RegisterTestCase(TestCase):
requests.post = post
# The fake stdout will be written here
- out = []
- err_code = []
+ out: List[str] = []
+ err_code: List[int] = []
with patch("synapse._scripts.register_new_matrix_user.requests", requests):
request_registration(
diff --git a/tests/server.py b/tests/server.py
index 8f30e250..b9f46597 100644
--- a/tests/server.py
+++ b/tests/server.py
@@ -109,6 +109,17 @@ class FakeChannel:
_ip: str = "127.0.0.1"
_producer: Optional[Union[IPullProducer, IPushProducer]] = None
resource_usage: Optional[ContextResourceUsage] = None
+ _request: Optional[Request] = None
+
+ @property
+ def request(self) -> Request:
+ assert self._request is not None
+ return self._request
+
+ @request.setter
+ def request(self, request: Request) -> None:
+ assert self._request is None
+ self._request = request
@property
def json_body(self):
@@ -322,6 +333,8 @@ def make_request(
channel = FakeChannel(site, reactor, ip=client_ip)
req = request(channel, site)
+ channel.request = req
+
req.content = BytesIO(content)
# Twisted expects to be at the end of the content when parsing the request.
req.content.seek(0, SEEK_END)
@@ -736,6 +749,7 @@ def setup_test_homeserver(
if config is None:
config = default_config(name, parse=True)
+ config.caches.resize_all_caches()
config.ldap_enabled = False
if "clock" not in kwargs:
diff --git a/tests/server_notices/test_resource_limits_server_notices.py b/tests/server_notices/test_resource_limits_server_notices.py
index 9ee9509d..07e29788 100644
--- a/tests/server_notices/test_resource_limits_server_notices.py
+++ b/tests/server_notices/test_resource_limits_server_notices.py
@@ -75,6 +75,9 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
self._rlsn._server_notices_manager.get_or_create_notice_room_for_user = Mock(
return_value=make_awaitable("!something:localhost")
)
+ self._rlsn._server_notices_manager.maybe_get_notice_room_for_user = Mock(
+ return_value=make_awaitable("!something:localhost")
+ )
self._rlsn._store.add_tag_to_room = Mock(return_value=make_awaitable(None))
self._rlsn._store.get_tags_for_room = Mock(return_value=make_awaitable({}))
@@ -102,6 +105,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
)
self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
# Would be better to check the content, but once == remove blocking event
+ self._rlsn._server_notices_manager.maybe_get_notice_room_for_user.assert_called_once()
self._send_notice.assert_called_once()
def test_maybe_send_server_notice_to_user_remove_blocked_notice_noop(self):
@@ -300,7 +304,10 @@ class TestResourceLimitsServerNoticesWithRealRooms(unittest.HomeserverTestCase):
hasn't been reached (since it's the only user and the limit is 5), so users
shouldn't receive a server notice.
"""
- self.register_user("user", "password")
+ m = Mock(return_value=make_awaitable(None))
+ self._rlsn._server_notices_manager.maybe_get_notice_room_for_user = m
+
+ user_id = self.register_user("user", "password")
tok = self.login("user", "password")
channel = self.make_request("GET", "/sync?timeout=0", access_token=tok)
@@ -309,6 +316,8 @@ class TestResourceLimitsServerNoticesWithRealRooms(unittest.HomeserverTestCase):
"rooms", channel.json_body, "Got invites without server notice"
)
+ m.assert_called_once_with(user_id)
+
def test_invite_with_notice(self):
"""Tests that, if the MAU limit is hit, the server notices user invites each user
to a room in which it has sent a notice.
diff --git a/tests/storage/databases/main/test_events_worker.py b/tests/storage/databases/main/test_events_worker.py
index c237a8c7..38963ce4 100644
--- a/tests/storage/databases/main/test_events_worker.py
+++ b/tests/storage/databases/main/test_events_worker.py
@@ -154,6 +154,31 @@ class EventCacheTestCase(unittest.HomeserverTestCase):
# We should have fetched the event from the DB
self.assertEqual(ctx.get_resource_usage().evt_db_fetch_count, 1)
+ def test_event_ref(self):
+ """Test that we reuse events that are still in memory but have fallen
+ out of the cache, rather than requesting them from the DB.
+ """
+
+ # Reset the event cache
+ self.store._get_event_cache.clear()
+
+ with LoggingContext("test") as ctx:
+ # We keep hold of the event event though we never use it.
+ event = self.get_success(self.store.get_event(self.event_id)) # noqa: F841
+
+ # We should have fetched the event from the DB
+ self.assertEqual(ctx.get_resource_usage().evt_db_fetch_count, 1)
+
+ # Reset the event cache
+ self.store._get_event_cache.clear()
+
+ with LoggingContext("test") as ctx:
+ self.get_success(self.store.get_event(self.event_id))
+
+ # Since the event is still in memory we shouldn't have fetched it
+ # from the DB
+ self.assertEqual(ctx.get_resource_usage().evt_db_fetch_count, 0)
+
def test_dedupe(self):
"""Test that if we request the same event multiple times we only pull it
out once.
diff --git a/tests/storage/databases/main/test_lock.py b/tests/storage/databases/main/test_lock.py
index 74c6224e..3cc2a58d 100644
--- a/tests/storage/databases/main/test_lock.py
+++ b/tests/storage/databases/main/test_lock.py
@@ -12,6 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from twisted.internet import defer, reactor
+from twisted.internet.base import ReactorBase
+from twisted.internet.defer import Deferred
+
from synapse.server import HomeServer
from synapse.storage.databases.main.lock import _LOCK_TIMEOUT_MS
@@ -22,6 +26,56 @@ class LockTestCase(unittest.HomeserverTestCase):
def prepare(self, reactor, clock, hs: HomeServer):
self.store = hs.get_datastores().main
+ def test_acquire_contention(self):
+ # Track the number of tasks holding the lock.
+ # Should be at most 1.
+ in_lock = 0
+ max_in_lock = 0
+
+ release_lock: "Deferred[None]" = Deferred()
+
+ async def task():
+ nonlocal in_lock
+ nonlocal max_in_lock
+
+ lock = await self.store.try_acquire_lock("name", "key")
+ if not lock:
+ return
+
+ async with lock:
+ in_lock += 1
+ max_in_lock = max(max_in_lock, in_lock)
+
+ # Block to allow other tasks to attempt to take the lock.
+ await release_lock
+
+ in_lock -= 1
+
+ # Start 3 tasks.
+ task1 = defer.ensureDeferred(task())
+ task2 = defer.ensureDeferred(task())
+ task3 = defer.ensureDeferred(task())
+
+ # Give the reactor a kick so that the database transaction returns.
+ self.pump()
+
+ release_lock.callback(None)
+
+ # Run the tasks to completion.
+ # To work around `Linearizer`s using a different reactor to sleep when
+ # contended (#12841), we call `runUntilCurrent` on
+ # `twisted.internet.reactor`, which is a different reactor to that used
+ # by the homeserver.
+ assert isinstance(reactor, ReactorBase)
+ self.get_success(task1)
+ reactor.runUntilCurrent()
+ self.get_success(task2)
+ reactor.runUntilCurrent()
+ self.get_success(task3)
+
+ # At most one task should have held the lock at a time.
+ self.assertEqual(max_in_lock, 1)
+
def test_simple_lock(self):
"""Test that we can take out a lock and that while we hold it nobody
else can take it out.
diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py
index 1bf93e79..1047ed09 100644
--- a/tests/storage/test_appservice.py
+++ b/tests/storage/test_appservice.py
@@ -14,7 +14,7 @@
import json
import os
import tempfile
-from typing import List, Optional, cast
+from typing import List, cast
from unittest.mock import Mock
import yaml
@@ -149,15 +149,12 @@ class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase):
outfile.write(yaml.dump(as_yaml))
self.as_yaml_files.append(as_token)
- def _set_state(
- self, id: str, state: ApplicationServiceState, txn: Optional[int] = None
- ):
+ def _set_state(self, id: str, state: ApplicationServiceState):
return self.db_pool.runOperation(
self.engine.convert_param_style(
- "INSERT INTO application_services_state(as_id, state, last_txn) "
- "VALUES(?,?,?)"
+ "INSERT INTO application_services_state(as_id, state) VALUES(?,?)"
),
- (id, state.value, txn),
+ (id, state.value),
)
def _insert_txn(self, as_id, txn_id, events):
@@ -283,17 +280,6 @@ class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase):
res = self.get_success(
self.db_pool.runQuery(
self.engine.convert_param_style(
- "SELECT last_txn FROM application_services_state WHERE as_id=?"
- ),
- (service.id,),
- )
- )
- self.assertEqual(1, len(res))
- self.assertEqual(txn_id, res[0][0])
-
- res = self.get_success(
- self.db_pool.runQuery(
- self.engine.convert_param_style(
"SELECT * FROM application_services_txns WHERE txn_id=?"
),
(txn_id,),
@@ -316,14 +302,13 @@ class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase):
res = self.get_success(
self.db_pool.runQuery(
self.engine.convert_param_style(
- "SELECT last_txn, state FROM application_services_state WHERE as_id=?"
+ "SELECT state FROM application_services_state WHERE as_id=?"
),
(service.id,),
)
)
self.assertEqual(1, len(res))
- self.assertEqual(txn_id, res[0][0])
- self.assertEqual(ApplicationServiceState.UP.value, res[0][1])
+ self.assertEqual(ApplicationServiceState.UP.value, res[0][0])
res = self.get_success(
self.db_pool.runQuery(
diff --git a/tests/storage/test_base.py b/tests/storage/test_base.py
index a8ffb52c..cce8e75c 100644
--- a/tests/storage/test_base.py
+++ b/tests/storage/test_base.py
@@ -60,7 +60,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
db = DatabasePool(Mock(), Mock(config=sqlite_config), fake_engine)
db._db_pool = self.db_pool
- self.datastore = SQLBaseStore(db, None, hs)
+ self.datastore = SQLBaseStore(db, None, hs) # type: ignore[arg-type]
@defer.inlineCallbacks
def test_insert_1col(self):
diff --git a/tests/storage/test_devices.py b/tests/storage/test_devices.py
index bbf079b2..f37505b6 100644
--- a/tests/storage/test_devices.py
+++ b/tests/storage/test_devices.py
@@ -13,6 +13,7 @@
# limitations under the License.
import synapse.api.errors
+from synapse.api.constants import EduTypes
from tests.unittest import HomeserverTestCase
@@ -266,10 +267,12 @@ class DeviceStoreTestCase(HomeserverTestCase):
# (This is a temporary arrangement for backwards compatibility!)
self.assertEqual(len(device_updates), 2, device_updates)
self.assertEqual(
- device_updates[0][0], "m.signing_key_update", device_updates[0]
+ device_updates[0][0], EduTypes.SIGNING_KEY_UPDATE, device_updates[0]
)
self.assertEqual(
- device_updates[1][0], "org.matrix.signing_key_update", device_updates[1]
+ device_updates[1][0],
+ EduTypes.UNSTABLE_SIGNING_KEY_UPDATE,
+ device_updates[1],
)
# Check there are no more device updates left.
diff --git a/tests/storage/test_event_chain.py b/tests/storage/test_event_chain.py
index 401020fd..a0ce077a 100644
--- a/tests/storage/test_event_chain.py
+++ b/tests/storage/test_event_chain.py
@@ -393,7 +393,8 @@ class EventChainStoreTestCase(HomeserverTestCase):
# We need to persist the events to the events and state_events
# tables.
persist_events_store._store_event_txn(
- txn, [(e, EventContext()) for e in events]
+ txn,
+ [(e, EventContext(self.hs.get_storage_controllers())) for e in events],
)
# Actually call the function that calculates the auth chain stuff.
diff --git a/tests/storage/test_event_federation.py b/tests/storage/test_event_federation.py
index 645d564d..d92a9ac5 100644
--- a/tests/storage/test_event_federation.py
+++ b/tests/storage/test_event_federation.py
@@ -58,15 +58,6 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
(room_id, event_id),
)
- txn.execute(
- (
- "INSERT INTO event_reference_hashes "
- "(event_id, algorithm, hash) "
- "VALUES (?, 'sha256', ?)"
- ),
- (event_id, bytearray(b"ffff")),
- )
-
for i in range(0, 20):
self.get_success(
self.store.db_pool.runInteraction("insert", insert_event, i)
diff --git a/tests/storage/test_events.py b/tests/storage/test_events.py
index ef5e2587..2ff88e64 100644
--- a/tests/storage/test_events.py
+++ b/tests/storage/test_events.py
@@ -31,7 +31,8 @@ class ExtremPruneTestCase(HomeserverTestCase):
def prepare(self, reactor, clock, homeserver):
self.state = self.hs.get_state_handler()
- self.persistence = self.hs.get_storage().persistence
+ self._persistence = self.hs.get_storage_controllers().persistence
+ self._state_storage_controller = self.hs.get_storage_controllers().state
self.store = self.hs.get_datastores().main
self.register_user("user", "pass")
@@ -69,9 +70,9 @@ class ExtremPruneTestCase(HomeserverTestCase):
def persist_event(self, event, state=None):
"""Persist the event, with optional state"""
context = self.get_success(
- self.state.compute_event_context(event, old_state=state)
+ self.state.compute_event_context(event, state_ids_before_event=state)
)
- self.get_success(self.persistence.persist_event(event, context))
+ self.get_success(self._persistence.persist_event(event, context))
def assert_extremities(self, expected_extremities):
"""Assert the current extremities for the room"""
@@ -103,9 +104,11 @@ class ExtremPruneTestCase(HomeserverTestCase):
RoomVersions.V6,
)
- state_before_gap = self.get_success(self.state.get_current_state(self.room_id))
+ state_before_gap = self.get_success(
+ self._state_storage_controller.get_current_state_ids(self.room_id)
+ )
- self.persist_event(remote_event_2, state=state_before_gap.values())
+ self.persist_event(remote_event_2, state=state_before_gap)
# Check the new extremity is just the new remote event.
self.assert_extremities([remote_event_2.event_id])
@@ -135,17 +138,20 @@ class ExtremPruneTestCase(HomeserverTestCase):
# setting. The state resolution across the old and new event will then
# include it, and so the resolved state won't match the new state.
state_before_gap = dict(
- self.get_success(self.state.get_current_state(self.room_id))
+ self.get_success(
+ self._state_storage_controller.get_current_state_ids(self.room_id)
+ )
)
state_before_gap.pop(("m.room.history_visibility", ""))
context = self.get_success(
self.state.compute_event_context(
- remote_event_2, old_state=state_before_gap.values()
+ remote_event_2,
+ state_ids_before_event=state_before_gap,
)
)
- self.get_success(self.persistence.persist_event(remote_event_2, context))
+ self.get_success(self._persistence.persist_event(remote_event_2, context))
# Check that we haven't dropped the old extremity.
self.assert_extremities([self.remote_event_1.event_id, remote_event_2.event_id])
@@ -177,9 +183,11 @@ class ExtremPruneTestCase(HomeserverTestCase):
RoomVersions.V6,
)
- state_before_gap = self.get_success(self.state.get_current_state(self.room_id))
+ state_before_gap = self.get_success(
+ self._state_storage_controller.get_current_state_ids(self.room_id)
+ )
- self.persist_event(remote_event_2, state=state_before_gap.values())
+ self.persist_event(remote_event_2, state=state_before_gap)
# Check the new extremity is just the new remote event.
self.assert_extremities([remote_event_2.event_id])
@@ -207,9 +215,11 @@ class ExtremPruneTestCase(HomeserverTestCase):
RoomVersions.V6,
)
- state_before_gap = self.get_success(self.state.get_current_state(self.room_id))
+ state_before_gap = self.get_success(
+ self._state_storage_controller.get_current_state_ids(self.room_id)
+ )
- self.persist_event(remote_event_2, state=state_before_gap.values())
+ self.persist_event(remote_event_2, state=state_before_gap)
# Check the new extremity is just the new remote event.
self.assert_extremities([self.remote_event_1.event_id, remote_event_2.event_id])
@@ -247,9 +257,11 @@ class ExtremPruneTestCase(HomeserverTestCase):
RoomVersions.V6,
)
- state_before_gap = self.get_success(self.state.get_current_state(self.room_id))
+ state_before_gap = self.get_success(
+ self._state_storage_controller.get_current_state_ids(self.room_id)
+ )
- self.persist_event(remote_event_2, state=state_before_gap.values())
+ self.persist_event(remote_event_2, state=state_before_gap)
# Check the new extremity is just the new remote event.
self.assert_extremities([remote_event_2.event_id])
@@ -289,9 +301,11 @@ class ExtremPruneTestCase(HomeserverTestCase):
RoomVersions.V6,
)
- state_before_gap = self.get_success(self.state.get_current_state(self.room_id))
+ state_before_gap = self.get_success(
+ self._state_storage_controller.get_current_state_ids(self.room_id)
+ )
- self.persist_event(remote_event_2, state=state_before_gap.values())
+ self.persist_event(remote_event_2, state=state_before_gap)
# Check the new extremity is just the new remote event.
self.assert_extremities([remote_event_2.event_id, local_message_event_id])
@@ -323,9 +337,11 @@ class ExtremPruneTestCase(HomeserverTestCase):
RoomVersions.V6,
)
- state_before_gap = self.get_success(self.state.get_current_state(self.room_id))
+ state_before_gap = self.get_success(
+ self._state_storage_controller.get_current_state_ids(self.room_id)
+ )
- self.persist_event(remote_event_2, state=state_before_gap.values())
+ self.persist_event(remote_event_2, state=state_before_gap)
# Check the new extremity is just the new remote event.
self.assert_extremities([local_message_event_id, remote_event_2.event_id])
@@ -340,7 +356,7 @@ class InvalideUsersInRoomCacheTestCase(HomeserverTestCase):
def prepare(self, reactor, clock, homeserver):
self.state = self.hs.get_state_handler()
- self.persistence = self.hs.get_storage().persistence
+ self._persistence = self.hs.get_storage_controllers().persistence
self.store = self.hs.get_datastores().main
def test_remote_user_rooms_cache_invalidated(self):
@@ -377,7 +393,7 @@ class InvalideUsersInRoomCacheTestCase(HomeserverTestCase):
)
context = self.get_success(self.state.compute_event_context(remote_event_1))
- self.get_success(self.persistence.persist_event(remote_event_1, context))
+ self.get_success(self._persistence.persist_event(remote_event_1, context))
# Call `get_rooms_for_user` to add the remote user to the cache
rooms = self.get_success(self.store.get_rooms_for_user(remote_user))
@@ -424,7 +440,7 @@ class InvalideUsersInRoomCacheTestCase(HomeserverTestCase):
)
context = self.get_success(self.state.compute_event_context(remote_event_1))
- self.get_success(self.persistence.persist_event(remote_event_1, context))
+ self.get_success(self._persistence.persist_event(remote_event_1, context))
# Call `get_users_in_room` to add the remote user to the cache
users = self.get_success(self.store.get_users_in_room(room_id))
diff --git a/tests/storage/test_monthly_active_users.py b/tests/storage/test_monthly_active_users.py
index 4c29ad79..e8b4a564 100644
--- a/tests/storage/test_monthly_active_users.py
+++ b/tests/storage/test_monthly_active_users.py
@@ -407,3 +407,86 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
self.assertEqual(result[service1], 2)
self.assertEqual(result[service2], 1)
self.assertEqual(result[native], 1)
+
+ def test_get_monthly_active_users_by_service(self):
+ # (No users, no filtering) -> empty result
+ result = self.get_success(self.store.get_monthly_active_users_by_service())
+
+ self.assertEqual(len(result), 0)
+
+ # (Some users, no filtering) -> non-empty result
+ appservice1_user1 = "@appservice1_user1:example.com"
+ appservice2_user1 = "@appservice2_user1:example.com"
+ service1 = "service1"
+ service2 = "service2"
+ self.get_success(
+ self.store.register_user(
+ user_id=appservice1_user1, password_hash=None, appservice_id=service1
+ )
+ )
+ self.get_success(self.store.upsert_monthly_active_user(appservice1_user1))
+ self.get_success(
+ self.store.register_user(
+ user_id=appservice2_user1, password_hash=None, appservice_id=service2
+ )
+ )
+ self.get_success(self.store.upsert_monthly_active_user(appservice2_user1))
+
+ result = self.get_success(self.store.get_monthly_active_users_by_service())
+
+ self.assertEqual(len(result), 2)
+ self.assertIn((service1, appservice1_user1), result)
+ self.assertIn((service2, appservice2_user1), result)
+
+ # (Some users, end-timestamp filtering) -> non-empty result
+ appservice1_user2 = "@appservice1_user2:example.com"
+ timestamp1 = self.reactor.seconds()
+ self.reactor.advance(5)
+ timestamp2 = self.reactor.seconds()
+ self.get_success(
+ self.store.register_user(
+ user_id=appservice1_user2, password_hash=None, appservice_id=service1
+ )
+ )
+ self.get_success(self.store.upsert_monthly_active_user(appservice1_user2))
+
+ result = self.get_success(
+ self.store.get_monthly_active_users_by_service(
+ end_timestamp=round(timestamp1 * 1000)
+ )
+ )
+
+ self.assertEqual(len(result), 2)
+ self.assertNotIn((service1, appservice1_user2), result)
+
+ # (Some users, start-timestamp filtering) -> non-empty result
+ result = self.get_success(
+ self.store.get_monthly_active_users_by_service(
+ start_timestamp=round(timestamp2 * 1000)
+ )
+ )
+
+ self.assertEqual(len(result), 1)
+ self.assertIn((service1, appservice1_user2), result)
+
+ # (Some users, full-timestamp filtering) -> non-empty result
+ native_user1 = "@native_user1:example.com"
+ native = "native"
+ timestamp3 = self.reactor.seconds()
+ self.reactor.advance(100)
+ self.get_success(
+ self.store.register_user(
+ user_id=native_user1, password_hash=None, appservice_id=native
+ )
+ )
+ self.get_success(self.store.upsert_monthly_active_user(native_user1))
+
+ result = self.get_success(
+ self.store.get_monthly_active_users_by_service(
+ start_timestamp=round(timestamp2 * 1000),
+ end_timestamp=round(timestamp3 * 1000),
+ )
+ )
+
+ self.assertEqual(len(result), 1)
+ self.assertIn((service1, appservice1_user2), result)
diff --git a/tests/storage/test_purge.py b/tests/storage/test_purge.py
index 08cc6023..8dfaa055 100644
--- a/tests/storage/test_purge.py
+++ b/tests/storage/test_purge.py
@@ -31,7 +31,7 @@ class PurgeTests(HomeserverTestCase):
self.room_id = self.helper.create_room_as(self.user_id)
self.store = hs.get_datastores().main
- self.storage = self.hs.get_storage()
+ self._storage_controllers = self.hs.get_storage_controllers()
def test_purge_history(self):
"""
@@ -51,7 +51,9 @@ class PurgeTests(HomeserverTestCase):
# Purge everything before this topological token
self.get_success(
- self.storage.purge_events.purge_history(self.room_id, token_str, True)
+ self._storage_controllers.purge_events.purge_history(
+ self.room_id, token_str, True
+ )
)
# 1-3 should fail and last will succeed, meaning that 1-3 are deleted
@@ -79,7 +81,9 @@ class PurgeTests(HomeserverTestCase):
# Purge everything before this topological token
f = self.get_failure(
- self.storage.purge_events.purge_history(self.room_id, event, True),
+ self._storage_controllers.purge_events.purge_history(
+ self.room_id, event, True
+ ),
SynapseError,
)
self.assertIn("greater than forward", f.value.args[0])
@@ -98,14 +102,17 @@ class PurgeTests(HomeserverTestCase):
first = self.helper.send(self.room_id, body="test1")
# Get the current room state.
- state_handler = self.hs.get_state_handler()
create_event = self.get_success(
- state_handler.get_current_state(self.room_id, "m.room.create", "")
+ self._storage_controllers.state.get_current_state_event(
+ self.room_id, "m.room.create", ""
+ )
)
self.assertIsNotNone(create_event)
# Purge everything before this topological token
- self.get_success(self.storage.purge_events.purge_room(self.room_id))
+ self.get_success(
+ self._storage_controllers.purge_events.purge_room(self.room_id)
+ )
# The events aren't found.
self.store._invalidate_get_event_cache(create_event.event_id)
diff --git a/tests/storage/test_redaction.py b/tests/storage/test_redaction.py
index d8d17ef3..6c4e63b7 100644
--- a/tests/storage/test_redaction.py
+++ b/tests/storage/test_redaction.py
@@ -31,7 +31,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
def prepare(self, reactor, clock, hs):
self.store = hs.get_datastores().main
- self.storage = hs.get_storage()
+ self._storage = hs.get_storage_controllers()
self.event_builder_factory = hs.get_event_builder_factory()
self.event_creation_handler = hs.get_event_creation_handler()
@@ -71,7 +71,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
self.event_creation_handler.create_new_client_event(builder)
)
- self.get_success(self.storage.persistence.persist_event(event, context))
+ self.get_success(self._storage.persistence.persist_event(event, context))
return event
@@ -93,7 +93,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
self.event_creation_handler.create_new_client_event(builder)
)
- self.get_success(self.storage.persistence.persist_event(event, context))
+ self.get_success(self._storage.persistence.persist_event(event, context))
return event
@@ -114,7 +114,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
self.event_creation_handler.create_new_client_event(builder)
)
- self.get_success(self.storage.persistence.persist_event(event, context))
+ self.get_success(self._storage.persistence.persist_event(event, context))
return event
@@ -268,7 +268,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
)
)
- self.get_success(self.storage.persistence.persist_event(event_1, context_1))
+ self.get_success(self._storage.persistence.persist_event(event_1, context_1))
event_2, context_2 = self.get_success(
self.event_creation_handler.create_new_client_event(
@@ -287,7 +287,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
)
)
)
- self.get_success(self.storage.persistence.persist_event(event_2, context_2))
+ self.get_success(self._storage.persistence.persist_event(event_2, context_2))
# fetch one of the redactions
fetched = self.get_success(self.store.get_event(redaction_event_id1))
@@ -411,7 +411,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
)
self.get_success(
- self.storage.persistence.persist_event(redaction_event, context)
+ self._storage.persistence.persist_event(redaction_event, context)
)
# Now lets jump to the future where we have censored the redaction event
diff --git a/tests/storage/test_room.py b/tests/storage/test_room.py
index 5b011e18..3c79dabc 100644
--- a/tests/storage/test_room.py
+++ b/tests/storage/test_room.py
@@ -72,7 +72,7 @@ class RoomEventsStoreTestCase(HomeserverTestCase):
# Room events need the full datastore, for persist_event() and
# get_room_state()
self.store = hs.get_datastores().main
- self.storage = hs.get_storage()
+ self._storage_controllers = hs.get_storage_controllers()
self.event_factory = hs.get_event_factory()
self.room = RoomID.from_string("!abcde:test")
@@ -88,7 +88,7 @@ class RoomEventsStoreTestCase(HomeserverTestCase):
def inject_room_event(self, **kwargs):
self.get_success(
- self.storage.persistence.persist_event(
+ self._storage_controllers.persistence.persist_event(
self.event_factory.create_event(room_id=self.room.to_string(), **kwargs)
)
)
@@ -101,7 +101,9 @@ class RoomEventsStoreTestCase(HomeserverTestCase):
)
state = self.get_success(
- self.store.get_current_state(room_id=self.room.to_string())
+ self._storage_controllers.state.get_current_state(
+ room_id=self.room.to_string()
+ )
)
self.assertEqual(1, len(state))
@@ -118,7 +120,9 @@ class RoomEventsStoreTestCase(HomeserverTestCase):
)
state = self.get_success(
- self.store.get_current_state(room_id=self.room.to_string())
+ self._storage_controllers.state.get_current_state(
+ room_id=self.room.to_string()
+ )
)
self.assertEqual(1, len(state))
diff --git a/tests/storage/test_room_search.py b/tests/storage/test_room_search.py
index 8dfc1e1d..e747c6b5 100644
--- a/tests/storage/test_room_search.py
+++ b/tests/storage/test_room_search.py
@@ -99,7 +99,9 @@ class EventSearchInsertionTest(HomeserverTestCase):
prev_event_ids = self.get_success(store.get_prev_events_for_room(room_id))
prev_event = self.get_success(store.get_event(prev_event_ids[0]))
prev_state_map = self.get_success(
- self.hs.get_storage().state.get_state_ids_for_event(prev_event_ids[0])
+ self.hs.get_storage_controllers().state.get_state_ids_for_event(
+ prev_event_ids[0]
+ )
)
event_dict = {
diff --git a/tests/storage/test_roommember.py b/tests/storage/test_roommember.py
index a2a9c05f..1218786d 100644
--- a/tests/storage/test_roommember.py
+++ b/tests/storage/test_roommember.py
@@ -34,7 +34,7 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase):
room.register_servlets,
]
- def prepare(self, reactor: MemoryReactor, clock: Clock, hs: TestHomeServer) -> None:
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: TestHomeServer) -> None: # type: ignore[override]
# We can't test the RoomMemberStore on its own without the other event
# storage logic
diff --git a/tests/storage/test_state.py b/tests/storage/test_state.py
index f88f1c55..8043bdbd 100644
--- a/tests/storage/test_state.py
+++ b/tests/storage/test_state.py
@@ -29,7 +29,7 @@ logger = logging.getLogger(__name__)
class StateStoreTestCase(HomeserverTestCase):
def prepare(self, reactor, clock, hs):
self.store = hs.get_datastores().main
- self.storage = hs.get_storage()
+ self.storage = hs.get_storage_controllers()
self.state_datastore = self.storage.state.stores.state
self.event_builder_factory = hs.get_event_builder_factory()
self.event_creation_handler = hs.get_event_creation_handler()
diff --git a/tests/storage/test_user_directory.py b/tests/storage/test_user_directory.py
index 7f1964eb..5b60cf52 100644
--- a/tests/storage/test_user_directory.py
+++ b/tests/storage/test_user_directory.py
@@ -134,7 +134,6 @@ class UserDirectoryInitialPopulationTestcase(HomeserverTestCase):
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
self.appservice = ApplicationService(
token="i_am_an_app_service",
- hostname="test",
id="1234",
namespaces={"users": [{"regex": r"@as_user.*", "exclusive": True}]},
sender="@as:test",
diff --git a/tests/storage/util/test_partial_state_events_tracker.py b/tests/storage/util/test_partial_state_events_tracker.py
index 303e190b..cae14151 100644
--- a/tests/storage/util/test_partial_state_events_tracker.py
+++ b/tests/storage/util/test_partial_state_events_tracker.py
@@ -17,8 +17,12 @@ from unittest import mock
from twisted.internet.defer import CancelledError, ensureDeferred
-from synapse.storage.util.partial_state_events_tracker import PartialStateEventsTracker
+from synapse.storage.util.partial_state_events_tracker import (
+ PartialCurrentStateTracker,
+ PartialStateEventsTracker,
+)
+from tests.test_utils import make_awaitable
from tests.unittest import TestCase
@@ -115,3 +119,56 @@ class PartialStateEventsTrackerTestCase(TestCase):
self.tracker.notify_un_partial_stated("event1")
self.successResultOf(d2)
+
+
+class PartialCurrentStateTrackerTestCase(TestCase):
+ def setUp(self) -> None:
+ self.mock_store = mock.Mock(spec_set=["is_partial_state_room"])
+
+ self.tracker = PartialCurrentStateTracker(self.mock_store)
+
+ def test_does_not_block_for_full_state_rooms(self):
+ self.mock_store.is_partial_state_room.return_value = make_awaitable(False)
+
+ self.successResultOf(ensureDeferred(self.tracker.await_full_state("room_id")))
+
+ def test_blocks_for_partial_room_state(self):
+ self.mock_store.is_partial_state_room.return_value = make_awaitable(True)
+
+ d = ensureDeferred(self.tracker.await_full_state("room_id"))
+
+ # there should be no result yet
+ self.assertNoResult(d)
+
+ # notifying that the room has been de-partial-stated should unblock
+ self.tracker.notify_un_partial_stated("room_id")
+ self.successResultOf(d)
+
+ def test_un_partial_state_race(self):
+ # We should correctly handle race between awaiting the state and us
+ # un-partialling the state
+ async def is_partial_state_room(events):
+ self.tracker.notify_un_partial_stated("room_id")
+ return True
+
+ self.mock_store.is_partial_state_room.side_effect = is_partial_state_room
+
+ self.successResultOf(ensureDeferred(self.tracker.await_full_state("room_id")))
+
+ def test_cancellation(self):
+ self.mock_store.is_partial_state_room.return_value = make_awaitable(True)
+
+ d1 = ensureDeferred(self.tracker.await_full_state("room_id"))
+ self.assertNoResult(d1)
+
+ d2 = ensureDeferred(self.tracker.await_full_state("room_id"))
+ self.assertNoResult(d2)
+
+ d1.cancel()
+ self.assertFailure(d1, CancelledError)
+
+ # d2 should still be waiting!
+ self.assertNoResult(d2)
+
+ self.tracker.notify_un_partial_stated("room_id")
+ self.successResultOf(d2)
diff --git a/tests/test_mau.py b/tests/test_mau.py
index 5bbc361a..f14fcb7d 100644
--- a/tests/test_mau.py
+++ b/tests/test_mau.py
@@ -105,7 +105,6 @@ class TestMauLimit(unittest.HomeserverTestCase):
self.store.services_cache.append(
ApplicationService(
token=as_token,
- hostname=self.hs.hostname,
id="SomeASID",
sender="@as_sender:test",
namespaces={"users": [{"regex": "@as_*", "exclusive": True}]},
@@ -251,7 +250,6 @@ class TestMauLimit(unittest.HomeserverTestCase):
self.store.services_cache.append(
ApplicationService(
token=as_token_1,
- hostname=self.hs.hostname,
id="SomeASID",
sender="@as_sender_1:test",
namespaces={"users": [{"regex": "@as_1.*", "exclusive": True}]},
@@ -262,7 +260,6 @@ class TestMauLimit(unittest.HomeserverTestCase):
self.store.services_cache.append(
ApplicationService(
token=as_token_2,
- hostname=self.hs.hostname,
id="AnotherASID",
sender="@as_sender_2:test",
namespaces={"users": [{"regex": "@as_2.*", "exclusive": True}]},
diff --git a/tests/test_server.py b/tests/test_server.py
index f2ffbc89..0f1eb43c 100644
--- a/tests/test_server.py
+++ b/tests/test_server.py
@@ -13,18 +13,28 @@
# limitations under the License.
import re
+from http import HTTPStatus
+from typing import Tuple
from twisted.internet.defer import Deferred
from twisted.web.resource import Resource
from synapse.api.errors import Codes, RedirectException, SynapseError
from synapse.config.server import parse_listener_def
-from synapse.http.server import DirectServeHtmlResource, JsonResource, OptionsResource
-from synapse.http.site import SynapseSite
+from synapse.http.server import (
+ DirectServeHtmlResource,
+ DirectServeJsonResource,
+ JsonResource,
+ OptionsResource,
+ cancellable,
+)
+from synapse.http.site import SynapseRequest, SynapseSite
from synapse.logging.context import make_deferred_yieldable
+from synapse.types import JsonDict
from synapse.util import Clock
from tests import unittest
+from tests.http.server._base import EndpointCancellationTestHelperMixin
from tests.server import (
FakeSite,
ThreadedMemoryReactorClock,
@@ -363,3 +373,100 @@ class WrapHtmlRequestHandlerTests(unittest.TestCase):
self.assertEqual(channel.result["code"], b"200")
self.assertNotIn("body", channel.result)
+
+
+class CancellableDirectServeJsonResource(DirectServeJsonResource):
+ def __init__(self, clock: Clock):
+ super().__init__()
+ self.clock = clock
+
+ @cancellable
+ async def _async_render_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
+ await self.clock.sleep(1.0)
+ return HTTPStatus.OK, {"result": True}
+
+ async def _async_render_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
+ await self.clock.sleep(1.0)
+ return HTTPStatus.OK, {"result": True}
+
+
+class CancellableDirectServeHtmlResource(DirectServeHtmlResource):
+ ERROR_TEMPLATE = "{code} {msg}"
+
+ def __init__(self, clock: Clock):
+ super().__init__()
+ self.clock = clock
+
+ @cancellable
+ async def _async_render_GET(self, request: SynapseRequest) -> Tuple[int, bytes]:
+ await self.clock.sleep(1.0)
+ return HTTPStatus.OK, b"ok"
+
+ async def _async_render_POST(self, request: SynapseRequest) -> Tuple[int, bytes]:
+ await self.clock.sleep(1.0)
+ return HTTPStatus.OK, b"ok"
+
+
+class DirectServeJsonResourceCancellationTests(EndpointCancellationTestHelperMixin):
+ """Tests for `DirectServeJsonResource` cancellation."""
+
+ def setUp(self):
+ self.reactor = ThreadedMemoryReactorClock()
+ self.clock = Clock(self.reactor)
+ self.resource = CancellableDirectServeJsonResource(self.clock)
+ self.site = FakeSite(self.resource, self.reactor)
+
+ def test_cancellable_disconnect(self) -> None:
+ """Test that handlers with the `@cancellable` flag can be cancelled."""
+ channel = make_request(
+ self.reactor, self.site, "GET", "/sleep", await_result=False
+ )
+ self._test_disconnect(
+ self.reactor,
+ channel,
+ expect_cancellation=True,
+ expected_body={"error": "Request cancelled", "errcode": Codes.UNKNOWN},
+ )
+
+ def test_uncancellable_disconnect(self) -> None:
+ """Test that handlers without the `@cancellable` flag cannot be cancelled."""
+ channel = make_request(
+ self.reactor, self.site, "POST", "/sleep", await_result=False
+ )
+ self._test_disconnect(
+ self.reactor,
+ channel,
+ expect_cancellation=False,
+ expected_body={"result": True},
+ )
+
+
+class DirectServeHtmlResourceCancellationTests(EndpointCancellationTestHelperMixin):
+ """Tests for `DirectServeHtmlResource` cancellation."""
+
+ def setUp(self):
+ self.reactor = ThreadedMemoryReactorClock()
+ self.clock = Clock(self.reactor)
+ self.resource = CancellableDirectServeHtmlResource(self.clock)
+ self.site = FakeSite(self.resource, self.reactor)
+
+ def test_cancellable_disconnect(self) -> None:
+ """Test that handlers with the `@cancellable` flag can be cancelled."""
+ channel = make_request(
+ self.reactor, self.site, "GET", "/sleep", await_result=False
+ )
+ self._test_disconnect(
+ self.reactor,
+ channel,
+ expect_cancellation=True,
+ expected_body=b"499 Request cancelled",
+ )
+
+ def test_uncancellable_disconnect(self) -> None:
+ """Test that handlers without the `@cancellable` flag cannot be cancelled."""
+ channel = make_request(
+ self.reactor, self.site, "POST", "/sleep", await_result=False
+ )
+ self._test_disconnect(
+ self.reactor, channel, expect_cancellation=False, expected_body=b"ok"
+ )
diff --git a/tests/test_state.py b/tests/test_state.py
index e4baa691..95f81beb 100644
--- a/tests/test_state.py
+++ b/tests/test_state.py
@@ -88,6 +88,9 @@ class _DummyStore:
return groups
+ async def get_state_ids_for_group(self, state_group, state_filter=None):
+ return self._group_to_state[state_group]
+
async def store_state_group(
self, event_id, room_id, prev_group, delta_ids, current_state_ids
):
@@ -126,6 +129,19 @@ class _DummyStore:
async def get_room_version_id(self, room_id):
return RoomVersions.V1.identifier
+ async def get_state_group_for_events(self, event_ids):
+ res = {}
+ for event in event_ids:
+ res[event] = self._event_to_state_group[event]
+ return res
+
+ async def get_state_for_groups(self, groups):
+ res = {}
+ for group in groups:
+ state = self._group_to_state[group]
+ res[group] = state
+ return res
+
class DictObj(dict):
def __init__(self, **kwargs):
@@ -163,12 +179,12 @@ class Graph:
class StateTestCase(unittest.TestCase):
def setUp(self):
self.dummy_store = _DummyStore()
- storage = Mock(main=self.dummy_store, state=self.dummy_store)
+ storage_controllers = Mock(main=self.dummy_store, state=self.dummy_store)
hs = Mock(
spec_set=[
"config",
"get_datastores",
- "get_storage",
+ "get_storage_controllers",
"get_auth",
"get_state_handler",
"get_clock",
@@ -183,7 +199,7 @@ class StateTestCase(unittest.TestCase):
hs.get_clock.return_value = MockClock()
hs.get_auth.return_value = Auth(hs)
hs.get_state_resolution_handler = lambda: StateResolutionHandler(hs)
- hs.get_storage.return_value = storage
+ hs.get_storage_controllers.return_value = storage_controllers
self.state = StateHandler(hs)
self.event_id = 0
@@ -426,7 +442,12 @@ class StateTestCase(unittest.TestCase):
]
context = yield defer.ensureDeferred(
- self.state.compute_event_context(event, old_state=old_state)
+ self.state.compute_event_context(
+ event,
+ state_ids_before_event={
+ (e.type, e.state_key): e.event_id for e in old_state
+ },
+ )
)
prev_state_ids = yield defer.ensureDeferred(context.get_prev_state_ids())
@@ -451,7 +472,12 @@ class StateTestCase(unittest.TestCase):
]
context = yield defer.ensureDeferred(
- self.state.compute_event_context(event, old_state=old_state)
+ self.state.compute_event_context(
+ event,
+ state_ids_before_event={
+ (e.type, e.state_key): e.event_id for e in old_state
+ },
+ )
)
prev_state_ids = yield defer.ensureDeferred(context.get_prev_state_ids())
diff --git a/tests/test_types.py b/tests/test_types.py
index 80888a74..0b10dae8 100644
--- a/tests/test_types.py
+++ b/tests/test_types.py
@@ -13,7 +13,7 @@
# limitations under the License.
from synapse.api.errors import SynapseError
-from synapse.types import GroupID, RoomAlias, UserID, map_username_to_mxid_localpart
+from synapse.types import RoomAlias, UserID, map_username_to_mxid_localpart
from tests import unittest
@@ -62,25 +62,6 @@ class RoomAliasTestCase(unittest.HomeserverTestCase):
self.assertFalse(RoomAlias.is_valid(id_string))
-class GroupIDTestCase(unittest.TestCase):
- def test_parse(self):
- group_id = GroupID.from_string("+group/=_-.123:my.domain")
- self.assertEqual("group/=_-.123", group_id.localpart)
- self.assertEqual("my.domain", group_id.domain)
-
- def test_validate(self):
- bad_ids = ["$badsigil:domain", "+:empty"] + [
- "+group" + c + ":domain" for c in "A%?æ£"
- ]
- for id_string in bad_ids:
- try:
- GroupID.from_string(id_string)
- self.fail("Parsing '%s' should raise exception" % id_string)
- except SynapseError as exc:
- self.assertEqual(400, exc.code)
- self.assertEqual("M_INVALID_PARAM", exc.errcode)
-
-
class MapUsernameTestCase(unittest.TestCase):
def testPassThrough(self):
self.assertEqual(map_username_to_mxid_localpart("test1234"), "test1234")
diff --git a/tests/test_utils/event_injection.py b/tests/test_utils/event_injection.py
index c654e36e..8027c7a8 100644
--- a/tests/test_utils/event_injection.py
+++ b/tests/test_utils/event_injection.py
@@ -70,7 +70,7 @@ async def inject_event(
"""
event, context = await create_event(hs, room_version, prev_event_ids, **kwargs)
- persistence = hs.get_storage().persistence
+ persistence = hs.get_storage_controllers().persistence
assert persistence is not None
await persistence.persist_event(event, context)
diff --git a/tests/test_visibility.py b/tests/test_visibility.py
index d0230f9e..f338af6c 100644
--- a/tests/test_visibility.py
+++ b/tests/test_visibility.py
@@ -34,7 +34,7 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
super(FilterEventsForServerTestCase, self).setUp()
self.event_creation_handler = self.hs.get_event_creation_handler()
self.event_builder_factory = self.hs.get_event_builder_factory()
- self.storage = self.hs.get_storage()
+ self._storage_controllers = self.hs.get_storage_controllers()
self.get_success(create_room(self.hs, TEST_ROOM_ID, "@someone:ROOM"))
@@ -60,7 +60,9 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
events_to_filter.append(evt)
filtered = self.get_success(
- filter_events_for_server(self.storage, "test_server", events_to_filter)
+ filter_events_for_server(
+ self._storage_controllers, "test_server", events_to_filter
+ )
)
# the result should be 5 redacted events, and 5 unredacted events.
@@ -80,7 +82,9 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
outlier = self._inject_outlier()
self.assertEqual(
self.get_success(
- filter_events_for_server(self.storage, "remote_hs", [outlier])
+ filter_events_for_server(
+ self._storage_controllers, "remote_hs", [outlier]
+ )
),
[outlier],
)
@@ -89,7 +93,9 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
evt = self._inject_message("@unerased:local_hs")
filtered = self.get_success(
- filter_events_for_server(self.storage, "remote_hs", [outlier, evt])
+ filter_events_for_server(
+ self._storage_controllers, "remote_hs", [outlier, evt]
+ )
)
self.assertEqual(len(filtered), 2, f"expected 2 results, got: {filtered}")
self.assertEqual(filtered[0], outlier)
@@ -99,7 +105,9 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
# ... but other servers should only be able to see the outlier (the other should
# be redacted)
filtered = self.get_success(
- filter_events_for_server(self.storage, "other_server", [outlier, evt])
+ filter_events_for_server(
+ self._storage_controllers, "other_server", [outlier, evt]
+ )
)
self.assertEqual(filtered[0], outlier)
self.assertEqual(filtered[1].event_id, evt.event_id)
@@ -132,7 +140,9 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
# ... and the filtering happens.
filtered = self.get_success(
- filter_events_for_server(self.storage, "test_server", events_to_filter)
+ filter_events_for_server(
+ self._storage_controllers, "test_server", events_to_filter
+ )
)
for i in range(0, len(events_to_filter)):
@@ -168,7 +178,9 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
event, context = self.get_success(
self.event_creation_handler.create_new_client_event(builder)
)
- self.get_success(self.storage.persistence.persist_event(event, context))
+ self.get_success(
+ self._storage_controllers.persistence.persist_event(event, context)
+ )
return event
def _inject_room_member(
@@ -194,7 +206,9 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
self.event_creation_handler.create_new_client_event(builder)
)
- self.get_success(self.storage.persistence.persist_event(event, context))
+ self.get_success(
+ self._storage_controllers.persistence.persist_event(event, context)
+ )
return event
def _inject_message(
@@ -216,7 +230,9 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
self.event_creation_handler.create_new_client_event(builder)
)
- self.get_success(self.storage.persistence.persist_event(event, context))
+ self.get_success(
+ self._storage_controllers.persistence.persist_event(event, context)
+ )
return event
def _inject_outlier(self) -> EventBase:
@@ -234,7 +250,9 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
event = self.get_success(builder.build(prev_event_ids=[], auth_event_ids=[]))
event.internal_metadata.outlier = True
self.get_success(
- self.storage.persistence.persist_event(event, EventContext.for_outlier())
+ self._storage_controllers.persistence.persist_event(
+ event, EventContext.for_outlier(self._storage_controllers)
+ )
)
return event
@@ -291,7 +309,9 @@ class FilterEventsForClientTestCase(unittest.FederatingHomeserverTestCase):
self.assertEqual(
self.get_success(
filter_events_for_client(
- self.hs.get_storage(), "@user:test", [invite_event, reject_event]
+ self.hs.get_storage_controllers(),
+ "@user:test",
+ [invite_event, reject_event],
)
),
[invite_event, reject_event],
@@ -301,7 +321,9 @@ class FilterEventsForClientTestCase(unittest.FederatingHomeserverTestCase):
self.assertEqual(
self.get_success(
filter_events_for_client(
- self.hs.get_storage(), "@other:test", [invite_event, reject_event]
+ self.hs.get_storage_controllers(),
+ "@other:test",
+ [invite_event, reject_event],
)
),
[],
diff --git a/tests/unittest.py b/tests/unittest.py
index 9afa68c1..e7f255b4 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -831,7 +831,7 @@ class FederatingHomeserverTestCase(HomeserverTestCase):
self.site,
method=method,
path=path,
- content=content or "",
+ content=content if content is not None else "",
shorthand=False,
await_result=await_result,
custom_headers=custom_headers,
diff --git a/tests/util/test_lrucache.py b/tests/util/test_lrucache.py
index 321fc177..67173a4f 100644
--- a/tests/util/test_lrucache.py
+++ b/tests/util/test_lrucache.py
@@ -14,8 +14,9 @@
from typing import List
-from unittest.mock import Mock
+from unittest.mock import Mock, patch
+from synapse.metrics.jemalloc import JemallocStats
from synapse.util.caches.lrucache import LruCache, setup_expire_lru_cache_entries
from synapse.util.caches.treecache import TreeCache
@@ -316,3 +317,58 @@ class TimeEvictionTestCase(unittest.HomeserverTestCase):
self.assertEqual(cache.get("key1"), None)
self.assertEqual(cache.get("key2"), 3)
+
+
+class MemoryEvictionTestCase(unittest.HomeserverTestCase):
+ @override_config(
+ {
+ "caches": {
+ "cache_autotuning": {
+ "max_cache_memory_usage": "700M",
+ "target_cache_memory_usage": "500M",
+ "min_cache_ttl": "5m",
+ }
+ }
+ }
+ )
+ @patch("synapse.util.caches.lrucache.get_jemalloc_stats")
+ def test_evict_memory(self, jemalloc_interface) -> None:
+ mock_jemalloc_class = Mock(spec=JemallocStats)
+ jemalloc_interface.return_value = mock_jemalloc_class
+
+ # set the return value of get_stat() to be greater than max_cache_memory_usage
+ mock_jemalloc_class.get_stat.return_value = 924288000
+
+ setup_expire_lru_cache_entries(self.hs)
+ cache = LruCache(4, clock=self.hs.get_clock())
+
+ cache["key1"] = 1
+ cache["key2"] = 2
+
+ # advance the reactor less than the min_cache_ttl
+ self.reactor.advance(60 * 2)
+
+ # our items should still be in the cache
+ self.assertEqual(cache.get("key1"), 1)
+ self.assertEqual(cache.get("key2"), 2)
+
+ # advance the reactor past the min_cache_ttl
+ self.reactor.advance(60 * 6)
+
+ # the items should be cleared from cache
+ self.assertEqual(cache.get("key1"), None)
+ self.assertEqual(cache.get("key2"), None)
+
+ # add more stuff to caches
+ cache["key1"] = 1
+ cache["key2"] = 2
+
+ # set the return value of get_stat() to be lower than target_cache_memory_usage
+ mock_jemalloc_class.get_stat.return_value = 10000
+
+ # advance the reactor past the min_cache_ttl
+ self.reactor.advance(60 * 6)
+
+ # the items should still be in the cache
+ self.assertEqual(cache.get("key1"), 1)
+ self.assertEqual(cache.get("key2"), 2)
diff --git a/tests/utils.py b/tests/utils.py
index d4ba3a9b..3059c453 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -264,7 +264,7 @@ class MockClock:
async def create_room(hs, room_id: str, creator_id: str):
"""Creates and persist a creation event for the given room"""
- persistence_store = hs.get_storage().persistence
+ persistence_store = hs.get_storage_controllers().persistence
store = hs.get_datastores().main
event_builder_factory = hs.get_event_builder_factory()
event_creation_handler = hs.get_event_creation_handler()