From 6dc64c92c6991f09910f3e6db368e6eeb4b1981e Mon Sep 17 00:00:00 2001
From: Andrej Shadura
Date: Sun, 19 Jun 2022 15:20:00 +0200
Subject: New upstream version 1.61.0

---
 tests/api/test_auth.py | 2 -
 tests/api/test_filtering.py | 6 +-
 tests/api/test_ratelimiting.py | 2 -
 tests/appservice/test_api.py | 101 +++++++
 tests/appservice/test_appservice.py | 3 +-
 tests/config/test_cache.py | 8 +
 tests/crypto/test_event_signing.py | 17 +-
 tests/crypto/test_keyring.py | 2 +-
 tests/events/test_presence_router.py | 2 +-
 tests/events/test_snapshot.py | 4 +-
 tests/federation/test_federation_sender.py | 40 +--
 tests/federation/test_federation_server.py | 6 +-
 tests/federation/transport/server/__init__.py | 13 +
 tests/federation/transport/server/test__base.py | 141 +++++++++
 tests/federation/transport/test_server.py | 4 +-
 tests/handlers/test_appservice.py | 16 +-
 tests/handlers/test_directory.py | 3 +-
 tests/handlers/test_federation.py | 19 +-
 tests/handlers/test_federation_event.py | 15 +-
 tests/handlers/test_message.py | 14 +-
 tests/handlers/test_receipts.py | 94 +++---
 tests/handlers/test_room_summary.py | 20 +-
 tests/handlers/test_sync.py | 1 +
 tests/handlers/test_typing.py | 41 ++-
 tests/handlers/test_user_directory.py | 3 +-
 tests/http/server/__init__.py | 13 +
 tests/http/server/_base.py | 100 +++++++
 tests/http/test_fedclient.py | 6 +-
 tests/http/test_servlet.py | 74 ++++-
 tests/http/test_site.py | 2 +-
 tests/module_api/test_api.py | 2 +-
 tests/push/test_push_rule_evaluator.py | 84 +++++-
 tests/replication/_base.py | 54 +++-
 tests/replication/http/__init__.py | 13 +
 tests/replication/http/test__base.py | 106 +++++++
 tests/replication/slave/storage/_base.py | 2 +-
 tests/replication/slave/storage/test_events.py | 10 +-
 tests/replication/slave/storage/test_receipts.py | 12 +-
 tests/replication/tcp/test_handler.py | 73 +++++
 tests/replication/test_sharded_event_persister.py | 14 +-
 tests/rest/admin/test_admin.py | 90 +-----
 tests/rest/admin/test_room.py | 3 +-
 tests/rest/admin/test_user.py | 4 +-
 tests/rest/client/test_account.py | 1 -
 tests/rest/client/test_auth.py | 41 +++
 tests/rest/client/test_device_lists.py | 159 ----------
 tests/rest/client/test_devices.py | 202 +++++++++++++
 tests/rest/client/test_events.py | 3 +-
 tests/rest/client/test_groups.py | 56 ----
 tests/rest/client/test_login.py | 2 -
 tests/rest/client/test_mutual_rooms.py | 2 -
 tests/rest/client/test_notifications.py | 91 ++++++
 tests/rest/client/test_register.py | 2 -
 tests/rest/client/test_relations.py | 89 ++++--
 tests/rest/client/test_retention.py | 39 ++-
 tests/rest/client/test_room_batch.py | 7 +-
 tests/rest/client/test_rooms.py | 267 ++++++++++++++++-
 tests/rest/client/test_sendtodevice.py | 5 +-
 tests/rest/client/test_shadow_banned.py | 4 +-
 tests/rest/client/test_sync.py | 41 +--
 tests/rest/client/test_typing.py | 3 +-
 tests/rest/client/test_upgrade_room.py | 38 ++-
 tests/rest/media/test_media_retention.py | 321 +++++++++++++++++++++
 tests/rest/media/v1/test_html_preview.py | 37 ++-
 tests/rest/media/v1/test_url_preview.py | 35 +++
 tests/scripts/test_new_matrix_user.py | 13 +-
 tests/server.py | 14 +
 .../test_resource_limits_server_notices.py | 11 +-
 tests/storage/databases/main/test_events_worker.py | 25 ++
 tests/storage/databases/main/test_lock.py | 54 ++++
 tests/storage/test_appservice.py | 27 +-
 tests/storage/test_base.py | 2 +-
 tests/storage/test_devices.py | 7 +-
 tests/storage/test_event_chain.py | 3 +-
 tests/storage/test_event_federation.py | 9 -
 tests/storage/test_events.py | 58 ++--
 tests/storage/test_monthly_active_users.py | 83 ++++++
 tests/storage/test_purge.py | 19 +-
 tests/storage/test_redaction.py | 14 +-
 tests/storage/test_room.py | 12 +-
 tests/storage/test_room_search.py | 4 +-
 tests/storage/test_roommember.py | 2 +-
 tests/storage/test_state.py | 2 +-
 tests/storage/test_user_directory.py | 1 -
 .../util/test_partial_state_events_tracker.py | 59 +++-
 tests/test_mau.py | 3 -
 tests/test_server.py | 111 ++++++-
 tests/test_state.py | 36 ++-
 tests/test_types.py | 21 +-
 tests/test_utils/event_injection.py | 2 +-
 tests/test_visibility.py | 46 ++-
 tests/unittest.py | 2 +-
 tests/util/test_lrucache.py | 58 +++-
 tests/utils.py | 2 +-
 94 files changed, 2659 insertions(+), 725 deletions(-)
 create mode 100644 tests/appservice/test_api.py
 create mode 100644 tests/federation/transport/server/__init__.py
 create mode 100644 tests/federation/transport/server/test__base.py
 create mode 100644 tests/http/server/__init__.py
 create mode 100644 tests/http/server/_base.py
 create mode 100644 tests/replication/http/__init__.py
 create mode 100644 tests/replication/http/test__base.py
 create mode 100644 tests/replication/tcp/test_handler.py
 delete mode 100644 tests/rest/client/test_device_lists.py
 create mode 100644 tests/rest/client/test_devices.py
 delete mode 100644 tests/rest/client/test_groups.py
 create mode 100644 tests/rest/client/test_notifications.py
 create mode 100644 tests/rest/media/test_media_retention.py
(limited to 'tests')

diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py index d547df8a..bc75ddd3 100644 --- a/tests/api/test_auth.py +++ b/tests/api/test_auth.py @@ -404,7 +404,6 @@ class AuthTestCase(unittest.HomeserverTestCase): appservice = ApplicationService( "abcd", - self.hs.config.server.server_name, id="1234", namespaces={ "users": [{"regex": "@_appservice.*:sender", "exclusive": True}] @@ -433,7 +432,6 @@ class AuthTestCase(unittest.HomeserverTestCase): appservice = ApplicationService( "abcd", - self.hs.config.server.server_name, id="1234", namespaces={ "users": [{"regex": "@_appservice.*:sender", "exclusive": True}] diff --git a/tests/api/test_filtering.py b/tests/api/test_filtering.py index 985d6e39..a269c477 100644 --- a/tests/api/test_filtering.py +++ b/tests/api/test_filtering.py @@ -20,7 +20,7 @@ from unittest.mock import patch import jsonschema from frozendict import frozendict -from synapse.api.constants import EventContentFields +from synapse.api.constants import EduTypes, EventContentFields from synapse.api.errors import SynapseError from synapse.api.filtering import Filter from synapse.events import make_event_from_dict @@ -85,13 +85,13 @@ class FilteringTestCase(unittest.HomeserverTestCase): "org.matrix.not_labels": ["#work"], }, "ephemeral": { - "types": ["m.receipt", "m.typing"], + "types": [EduTypes.RECEIPT, EduTypes.TYPING], "not_rooms": ["!726s6s6q:example.com"], "not_senders": ["@spam:example.com"], }, }, "presence": { - "types": ["m.presence"], + "types": [EduTypes.PRESENCE], "not_senders": ["@alice:example.com"], }, "event_format": "client", diff --git a/tests/api/test_ratelimiting.py b/tests/api/test_ratelimiting.py index 483d5463..f661a9ff 100644 --- a/tests/api/test_ratelimiting.py +++ b/tests/api/test_ratelimiting.py @@ -31,7 +31,6 @@ class TestRatelimiter(unittest.HomeserverTestCase): def test_allowed_appservice_ratelimited_via_can_requester_do_action(self): appservice = ApplicationService( None, - "example.com", id="foo", rate_limited=True, sender="@as:example.com", ) @@ -62,7 +61,6 @@
def test_allowed_appservice_via_can_requester_do_action(self): appservice = ApplicationService( None, - "example.com", id="foo", rate_limited=False, sender="@as:example.com", diff --git a/tests/appservice/test_api.py b/tests/appservice/test_api.py new file mode 100644 index 00000000..532b6763 --- /dev/null +++ b/tests/appservice/test_api.py @@ -0,0 +1,101 @@ +# Copyright 2022 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any, List, Mapping +from unittest.mock import Mock + +from twisted.test.proto_helpers import MemoryReactor + +from synapse.appservice import ApplicationService +from synapse.server import HomeServer +from synapse.types import JsonDict +from synapse.util import Clock + +from tests import unittest + +PROTOCOL = "myproto" +TOKEN = "myastoken" +URL = "http://mytestservice" + + +class ApplicationServiceApiTestCase(unittest.HomeserverTestCase): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer): + self.api = hs.get_application_service_api() + self.service = ApplicationService( + id="unique_identifier", + sender="@as:test", + url=URL, + token="unused", + hs_token=TOKEN, + ) + + def test_query_3pe_authenticates_token(self): + """ + Tests that 3pe queries to the appservice are authenticated + with the appservice's token. + """ + + SUCCESS_RESULT_USER = [ + { + "protocol": PROTOCOL, + "userid": "@a:user", + "fields": { + "more": "fields", + }, + } + ] + SUCCESS_RESULT_LOCATION = [ + { + "protocol": PROTOCOL, + "alias": "#a:room", + "fields": { + "more": "fields", + }, + } + ] + + URL_USER = f"{URL}/_matrix/app/unstable/thirdparty/user/{PROTOCOL}" + URL_LOCATION = f"{URL}/_matrix/app/unstable/thirdparty/location/{PROTOCOL}" + + self.request_url = None + + async def get_json(url: str, args: Mapping[Any, Any]) -> List[JsonDict]: + if not args.get(b"access_token"): + raise RuntimeError("Access token not provided") + + self.assertEqual(args.get(b"access_token"), TOKEN) + self.request_url = url + if url == URL_USER: + return SUCCESS_RESULT_USER + elif url == URL_LOCATION: + return SUCCESS_RESULT_LOCATION + else: + raise RuntimeError( + "URL provided was invalid. This should never be seen." + ) + + # We assign to a method, which mypy doesn't like. 
+ self.api.get_json = Mock(side_effect=get_json) # type: ignore[assignment] + + result = self.get_success( + self.api.query_3pe(self.service, "user", PROTOCOL, {b"some": [b"field"]}) + ) + self.assertEqual(self.request_url, URL_USER) + self.assertEqual(result, SUCCESS_RESULT_USER) + result = self.get_success( + self.api.query_3pe( + self.service, "location", PROTOCOL, {b"some": [b"field"]} + ) + ) + self.assertEqual(self.request_url, URL_LOCATION) + self.assertEqual(result, SUCCESS_RESULT_LOCATION) diff --git a/tests/appservice/test_appservice.py b/tests/appservice/test_appservice.py index edc584d0..3018d3fc 100644 --- a/tests/appservice/test_appservice.py +++ b/tests/appservice/test_appservice.py @@ -23,7 +23,7 @@ from tests.test_utils import simple_async_mock def _regex(regex: str, exclusive: bool = True) -> Namespace: - return Namespace(exclusive, None, re.compile(regex)) + return Namespace(exclusive, re.compile(regex)) class ApplicationServiceTestCase(unittest.TestCase): @@ -33,7 +33,6 @@ class ApplicationServiceTestCase(unittest.TestCase): sender="@as:test", url="some_url", token="some_token", - hostname="matrix.org", # only used by get_groups_for_user ) self.event = Mock( event_id="$abc:xyz", diff --git a/tests/config/test_cache.py b/tests/config/test_cache.py index 4bb82e81..d2b3c299 100644 --- a/tests/config/test_cache.py +++ b/tests/config/test_cache.py @@ -38,6 +38,7 @@ class CacheConfigTests(TestCase): "SYNAPSE_NOT_CACHE": "BLAH", } self.config.read_config(config, config_dir_path="", data_dir_path="") + self.config.resize_all_caches() self.assertEqual(dict(self.config.cache_factors), {"something_or_other": 2.0}) @@ -52,6 +53,7 @@ class CacheConfigTests(TestCase): "SYNAPSE_CACHE_FACTOR_FOO": 1, } self.config.read_config(config, config_dir_path="", data_dir_path="") + self.config.resize_all_caches() self.assertEqual( dict(self.config.cache_factors), @@ -71,6 +73,7 @@ class CacheConfigTests(TestCase): config = {"caches": {"per_cache_factors": {"foo": 3}}} self.config.read_config(config) + self.config.resize_all_caches() self.assertEqual(cache.max_size, 300) @@ -82,6 +85,7 @@ class CacheConfigTests(TestCase): """ config = {"caches": {"per_cache_factors": {"foo": 2}}} self.config.read_config(config, config_dir_path="", data_dir_path="") + self.config.resize_all_caches() cache = LruCache(100) add_resizable_cache("foo", cache_resize_callback=cache.set_cache_factor) @@ -99,6 +103,7 @@ class CacheConfigTests(TestCase): config = {"caches": {"global_factor": 4}} self.config.read_config(config, config_dir_path="", data_dir_path="") + self.config.resize_all_caches() self.assertEqual(cache.max_size, 400) @@ -110,6 +115,7 @@ class CacheConfigTests(TestCase): """ config = {"caches": {"global_factor": 1.5}} self.config.read_config(config, config_dir_path="", data_dir_path="") + self.config.resize_all_caches() cache = LruCache(100) add_resizable_cache("foo", cache_resize_callback=cache.set_cache_factor) @@ -128,6 +134,7 @@ class CacheConfigTests(TestCase): "SYNAPSE_CACHE_FACTOR_CACHE_B": 3, } self.config.read_config(config, config_dir_path="", data_dir_path="") + self.config.resize_all_caches() cache_a = LruCache(100) add_resizable_cache("*cache_a*", cache_resize_callback=cache_a.set_cache_factor) @@ -148,6 +155,7 @@ class CacheConfigTests(TestCase): config = {"caches": {"event_cache_size": "10k"}} self.config.read_config(config, config_dir_path="", data_dir_path="") + self.config.resize_all_caches() cache = LruCache( max_size=self.config.event_cache_size, diff --git 
a/tests/crypto/test_event_signing.py b/tests/crypto/test_event_signing.py index 06e0545a..8fa710c9 100644 --- a/tests/crypto/test_event_signing.py +++ b/tests/crypto/test_event_signing.py @@ -12,10 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. - -import nacl.signing -import signedjson.types -from unpaddedbase64 import decode_base64 +from signedjson.key import decode_signing_key_base64 +from signedjson.types import SigningKey from synapse.api.room_versions import RoomVersions from synapse.crypto.event_signing import add_hashes_and_signatures @@ -25,7 +23,7 @@ from tests import unittest # Perform these tests using given secret key so we get entirely deterministic # signatures output that we can test against. -SIGNING_KEY_SEED = decode_base64("YJDBA9Xnr2sVqXD9Vj7XVUnmFZcZrlw8Md7kMW+3XA1") +SIGNING_KEY_SEED = "YJDBA9Xnr2sVqXD9Vj7XVUnmFZcZrlw8Md7kMW+3XA1" KEY_ALG = "ed25519" KEY_VER = "1" @@ -36,14 +34,9 @@ HOSTNAME = "domain" class EventSigningTestCase(unittest.TestCase): def setUp(self): - # NB: `signedjson` expects `nacl.signing.SigningKey` instances which have been - # monkeypatched to include new `alg` and `version` attributes. This is captured - # by the `signedjson.types.SigningKey` protocol. - self.signing_key: signedjson.types.SigningKey = nacl.signing.SigningKey( # type: ignore[assignment] - SIGNING_KEY_SEED + self.signing_key: SigningKey = decode_signing_key_base64( + KEY_ALG, KEY_VER, SIGNING_KEY_SEED ) - self.signing_key.alg = KEY_ALG - self.signing_key.version = KEY_VER def test_sign_minimal(self): event_dict = { diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py index d00ef24c..820a1a54 100644 --- a/tests/crypto/test_keyring.py +++ b/tests/crypto/test_keyring.py @@ -19,8 +19,8 @@ import attr import canonicaljson import signedjson.key import signedjson.sign -from nacl.signing import SigningKey from signedjson.key import encode_verify_key_base64, get_verify_key +from signedjson.types import SigningKey from twisted.internet import defer from twisted.internet.defer import Deferred, ensureDeferred diff --git a/tests/events/test_presence_router.py b/tests/events/test_presence_router.py index 3deb14c3..ffc3012a 100644 --- a/tests/events/test_presence_router.py +++ b/tests/events/test_presence_router.py @@ -439,7 +439,7 @@ class PresenceRouterTestCase(FederatingHomeserverTestCase): for edu in edus: # Make sure we're only checking presence-type EDUs - if edu["edu_type"] != EduTypes.Presence: + if edu["edu_type"] != EduTypes.PRESENCE: continue # EDUs can contain multiple presence updates diff --git a/tests/events/test_snapshot.py b/tests/events/test_snapshot.py index defbc68c..8ddce83b 100644 --- a/tests/events/test_snapshot.py +++ b/tests/events/test_snapshot.py @@ -29,7 +29,7 @@ class TestEventContext(unittest.HomeserverTestCase): def prepare(self, reactor, clock, hs): self.store = hs.get_datastores().main - self.storage = hs.get_storage() + self._storage_controllers = hs.get_storage_controllers() self.user_id = self.register_user("u1", "pass") self.user_tok = self.login("u1", "pass") @@ -87,7 +87,7 @@ class TestEventContext(unittest.HomeserverTestCase): def _check_serialize_deserialize(self, event, context): serialized = self.get_success(context.serialize(event, self.store)) - d_context = EventContext.deserialize(self.storage, serialized) + d_context = EventContext.deserialize(self._storage_controllers, serialized) self.assertEqual(context.state_group, d_context.state_group) 
self.assertEqual(context.rejected, d_context.rejected) diff --git a/tests/federation/test_federation_sender.py b/tests/federation/test_federation_sender.py index 6b26353d..01a1db61 100644 --- a/tests/federation/test_federation_sender.py +++ b/tests/federation/test_federation_sender.py @@ -19,7 +19,7 @@ from signedjson.types import BaseKey, SigningKey from twisted.internet import defer -from synapse.api.constants import RoomEncryptionAlgorithms +from synapse.api.constants import EduTypes, RoomEncryptionAlgorithms from synapse.rest import admin from synapse.rest.client import login from synapse.types import JsonDict, ReadReceipt @@ -30,16 +30,16 @@ from tests.unittest import HomeserverTestCase, override_config class FederationSenderReceiptsTestCases(HomeserverTestCase): def make_homeserver(self, reactor, clock): - mock_state_handler = Mock(spec=["get_current_hosts_in_room"]) - # Ensure a new Awaitable is created for each call. - mock_state_handler.get_current_hosts_in_room.return_value = make_awaitable( - ["test", "host2"] - ) - return self.setup_test_homeserver( - state_handler=mock_state_handler, + hs = self.setup_test_homeserver( federation_transport_client=Mock(spec=["send_transaction"]), ) + hs.get_storage_controllers().state.get_current_hosts_in_room = Mock( + return_value=make_awaitable({"test", "host2"}) + ) + + return hs + @override_config({"send_federation": True}) def test_send_receipts(self): mock_send_transaction = ( @@ -63,7 +63,7 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase): data["edus"], [ { - "edu_type": "m.receipt", + "edu_type": EduTypes.RECEIPT, "content": { "room_id": { "m.read": { @@ -103,7 +103,7 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase): data["edus"], [ { - "edu_type": "m.receipt", + "edu_type": EduTypes.RECEIPT, "content": { "room_id": { "m.read": { @@ -138,7 +138,7 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase): data["edus"], [ { - "edu_type": "m.receipt", + "edu_type": EduTypes.RECEIPT, "content": { "room_id": { "m.read": { @@ -322,8 +322,10 @@ class FederationSenderDevicesTestCases(HomeserverTestCase): # expect signing key update edu self.assertEqual(len(self.edus), 2) - self.assertEqual(self.edus.pop(0)["edu_type"], "m.signing_key_update") - self.assertEqual(self.edus.pop(0)["edu_type"], "org.matrix.signing_key_update") + self.assertEqual(self.edus.pop(0)["edu_type"], EduTypes.SIGNING_KEY_UPDATE) + self.assertEqual( + self.edus.pop(0)["edu_type"], EduTypes.UNSTABLE_SIGNING_KEY_UPDATE + ) # sign the devices d1_json = build_device_dict(u1, "D1", device1_signing_key) @@ -348,7 +350,7 @@ class FederationSenderDevicesTestCases(HomeserverTestCase): self.assertEqual(len(self.edus), 2) stream_id = None # FIXME: there is a discontinuity in the stream IDs: see #7142 for edu in self.edus: - self.assertEqual(edu["edu_type"], "m.device_list_update") + self.assertEqual(edu["edu_type"], EduTypes.DEVICE_LIST_UPDATE) c = edu["content"] if stream_id is not None: self.assertEqual(c["prev_id"], [stream_id]) @@ -388,7 +390,7 @@ class FederationSenderDevicesTestCases(HomeserverTestCase): # expect three edus, in an unknown order self.assertEqual(len(self.edus), 3) for edu in self.edus: - self.assertEqual(edu["edu_type"], "m.device_list_update") + self.assertEqual(edu["edu_type"], EduTypes.DEVICE_LIST_UPDATE) c = edu["content"] self.assertGreaterEqual( c.items(), @@ -435,7 +437,7 @@ class FederationSenderDevicesTestCases(HomeserverTestCase): self.assertEqual(len(self.edus), 3) stream_id = None for edu in self.edus: - 
self.assertEqual(edu["edu_type"], "m.device_list_update") + self.assertEqual(edu["edu_type"], EduTypes.DEVICE_LIST_UPDATE) c = edu["content"] self.assertEqual(c["prev_id"], [stream_id] if stream_id is not None else []) if stream_id is not None: @@ -487,7 +489,7 @@ class FederationSenderDevicesTestCases(HomeserverTestCase): # there should be a single update for this user. self.assertEqual(len(self.edus), 1) edu = self.edus.pop(0) - self.assertEqual(edu["edu_type"], "m.device_list_update") + self.assertEqual(edu["edu_type"], EduTypes.DEVICE_LIST_UPDATE) c = edu["content"] # synapse uses an empty prev_id list to indicate "needs a full resync". @@ -544,7 +546,7 @@ class FederationSenderDevicesTestCases(HomeserverTestCase): # ... and we should get a single update for this user. self.assertEqual(len(self.edus), 1) edu = self.edus.pop(0) - self.assertEqual(edu["edu_type"], "m.device_list_update") + self.assertEqual(edu["edu_type"], EduTypes.DEVICE_LIST_UPDATE) c = edu["content"] # synapse uses an empty prev_id list to indicate "needs a full resync". @@ -560,7 +562,7 @@ class FederationSenderDevicesTestCases(HomeserverTestCase): """Check that the given EDU is an update for the given device Returns the stream_id. """ - self.assertEqual(edu["edu_type"], "m.device_list_update") + self.assertEqual(edu["edu_type"], EduTypes.DEVICE_LIST_UPDATE) content = edu["content"] expected = { diff --git a/tests/federation/test_federation_server.py b/tests/federation/test_federation_server.py index b19365b8..413b3c94 100644 --- a/tests/federation/test_federation_server.py +++ b/tests/federation/test_federation_server.py @@ -134,6 +134,8 @@ class SendJoinFederationTests(unittest.FederatingHomeserverTestCase): def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer): super().prepare(reactor, clock, hs) + self._storage_controllers = hs.get_storage_controllers() + # create the room creator_user_id = self.register_user("kermit", "test") tok = self.login("kermit", "test") @@ -207,7 +209,7 @@ class SendJoinFederationTests(unittest.FederatingHomeserverTestCase): # the room should show that the new user is a member r = self.get_success( - self.hs.get_state_handler().get_current_state(self._room_id) + self._storage_controllers.state.get_current_state(self._room_id) ) self.assertEqual(r[("m.room.member", joining_user)].membership, "join") @@ -258,7 +260,7 @@ class SendJoinFederationTests(unittest.FederatingHomeserverTestCase): # the room should show that the new user is a member r = self.get_success( - self.hs.get_state_handler().get_current_state(self._room_id) + self._storage_controllers.state.get_current_state(self._room_id) ) self.assertEqual(r[("m.room.member", joining_user)].membership, "join") diff --git a/tests/federation/transport/server/__init__.py b/tests/federation/transport/server/__init__.py new file mode 100644 index 00000000..3a5f22c0 --- /dev/null +++ b/tests/federation/transport/server/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2022 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/federation/transport/server/test__base.py b/tests/federation/transport/server/test__base.py new file mode 100644 index 00000000..e63885c1 --- /dev/null +++ b/tests/federation/transport/server/test__base.py @@ -0,0 +1,141 @@ +# Copyright 2022 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from http import HTTPStatus +from typing import Dict, List, Tuple + +from synapse.api.errors import Codes +from synapse.federation.transport.server import BaseFederationServlet +from synapse.federation.transport.server._base import Authenticator, _parse_auth_header +from synapse.http.server import JsonResource, cancellable +from synapse.server import HomeServer +from synapse.types import JsonDict +from synapse.util.ratelimitutils import FederationRateLimiter + +from tests import unittest +from tests.http.server._base import EndpointCancellationTestHelperMixin + + +class CancellableFederationServlet(BaseFederationServlet): + PATH = "/sleep" + + def __init__( + self, + hs: HomeServer, + authenticator: Authenticator, + ratelimiter: FederationRateLimiter, + server_name: str, + ): + super().__init__(hs, authenticator, ratelimiter, server_name) + self.clock = hs.get_clock() + + @cancellable + async def on_GET( + self, origin: str, content: None, query: Dict[bytes, List[bytes]] + ) -> Tuple[int, JsonDict]: + await self.clock.sleep(1.0) + return HTTPStatus.OK, {"result": True} + + async def on_POST( + self, origin: str, content: JsonDict, query: Dict[bytes, List[bytes]] + ) -> Tuple[int, JsonDict]: + await self.clock.sleep(1.0) + return HTTPStatus.OK, {"result": True} + + +class BaseFederationServletCancellationTests( + unittest.FederatingHomeserverTestCase, EndpointCancellationTestHelperMixin +): + """Tests for `BaseFederationServlet` cancellation.""" + + skip = "`BaseFederationServlet` does not support cancellation yet." + + path = f"{CancellableFederationServlet.PREFIX}{CancellableFederationServlet.PATH}" + + def create_test_resource(self): + """Overrides `HomeserverTestCase.create_test_resource`.""" + resource = JsonResource(self.hs) + + CancellableFederationServlet( + hs=self.hs, + authenticator=Authenticator(self.hs), + ratelimiter=self.hs.get_federation_ratelimiter(), + server_name=self.hs.hostname, + ).register(resource) + + return resource + + def test_cancellable_disconnect(self) -> None: + """Test that handlers with the `@cancellable` flag can be cancelled.""" + channel = self.make_signed_federation_request( + "GET", self.path, await_result=False + ) + + # Advance past all the rate limiting logic. If we disconnect too early, the + # request won't be processed. 
+ self.pump() + + self._test_disconnect( + self.reactor, + channel, + expect_cancellation=True, + expected_body={"error": "Request cancelled", "errcode": Codes.UNKNOWN}, + ) + + def test_uncancellable_disconnect(self) -> None: + """Test that handlers without the `@cancellable` flag cannot be cancelled.""" + channel = self.make_signed_federation_request( + "POST", + self.path, + content={}, + await_result=False, + ) + + # Advance past all the rate limiting logic. If we disconnect too early, the + # request won't be processed. + self.pump() + + self._test_disconnect( + self.reactor, + channel, + expect_cancellation=False, + expected_body={"result": True}, + ) + + +class BaseFederationAuthorizationTests(unittest.TestCase): + def test_authorization_header(self) -> None: + """Tests that the Authorization header is parsed correctly.""" + + # test a "normal" Authorization header + self.assertEqual( + _parse_auth_header( + b'X-Matrix origin=foo,key="ed25519:1",sig="sig",destination="bar"' + ), + ("foo", "ed25519:1", "sig", "bar"), + ) + # test an Authorization with extra spaces, upper-case names, and escaped + # characters + self.assertEqual( + _parse_auth_header( + b'X-Matrix ORIGIN=foo,KEY="ed25\\519:1",SIG="sig",destination="bar"' + ), + ("foo", "ed25519:1", "sig", "bar"), + ) + self.assertEqual( + _parse_auth_header( + b'X-Matrix origin=foo,key="ed25519:1",sig="sig",destination="bar",extra_field=ignored' + ), + ("foo", "ed25519:1", "sig", "bar"), + ) diff --git a/tests/federation/transport/test_server.py b/tests/federation/transport/test_server.py index 5f001c33..cfd550a0 100644 --- a/tests/federation/transport/test_server.py +++ b/tests/federation/transport/test_server.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from synapse.api.constants import EduTypes + from tests import unittest from tests.unittest import DEBUG, override_config @@ -50,7 +52,7 @@ class RoomDirectoryFederationTests(unittest.FederatingHomeserverTestCase): "/_matrix/federation/v1/send/txn_id_1234/", content={ "edus": [ - {"edu_type": "m.device_list_update", "content": {"foo": "bar"}} + {"edu_type": EduTypes.DEVICE_LIST_UPDATE, "content": {"foo": "bar"}} ], "pdus": [], }, diff --git a/tests/handlers/test_appservice.py b/tests/handlers/test_appservice.py index 5b0cd1ab..d96d5aa1 100644 --- a/tests/handlers/test_appservice.py +++ b/tests/handlers/test_appservice.py @@ -22,6 +22,7 @@ from twisted.test.proto_helpers import MemoryReactor import synapse.rest.admin import synapse.storage +from synapse.api.constants import EduTypes from synapse.appservice import ( ApplicationService, TransactionOneTimeKeyCounts, @@ -434,16 +435,6 @@ class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase): }, ) - # "Complete" a transaction. - # All this really does for us is make an entry in the application_services_state - # database table, which tracks the current stream_token per stream ID per AS. - self.get_success( - self.hs.get_datastores().main.complete_appservice_txn( - 0, - interested_appservice, - ) - ) - # Now, pretend that we receive a large burst of read receipts (300 total) that # all come in at once. 
for i in range(300): @@ -486,7 +477,7 @@ class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase): # Check that the ephemeral event is a read receipt with the expected structure latest_read_receipt = all_ephemeral_events[-1] - self.assertEqual(latest_read_receipt["type"], "m.receipt") + self.assertEqual(latest_read_receipt["type"], EduTypes.RECEIPT) event_id = list(latest_read_receipt["content"].keys())[0] self.assertEqual( @@ -706,7 +697,6 @@ class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase): # Create an application service appservice = ApplicationService( token=random_string(10), - hostname="example.com", id=random_string(10), sender="@as:example.com", rate_limited=False, @@ -785,7 +775,6 @@ class ApplicationServicesHandlerDeviceListsTestCase(unittest.HomeserverTestCase) # Create an appservice that is interested in "local_user" appservice = ApplicationService( token=random_string(10), - hostname="example.com", id=random_string(10), sender="@as:example.com", rate_limited=False, @@ -852,7 +841,6 @@ class ApplicationServicesHandlerOtkCountsTestCase(unittest.HomeserverTestCase): self._service_token = "VERYSECRET" self._service = ApplicationService( self._service_token, - "as1.invalid", "as1", "@as.sender:test", namespaces={ diff --git a/tests/handlers/test_directory.py b/tests/handlers/test_directory.py index 11ad4422..53d49ca8 100644 --- a/tests/handlers/test_directory.py +++ b/tests/handlers/test_directory.py @@ -298,6 +298,7 @@ class CanonicalAliasTestCase(unittest.HomeserverTestCase): self.store = hs.get_datastores().main self.handler = hs.get_directory_handler() self.state_handler = hs.get_state_handler() + self._storage_controllers = hs.get_storage_controllers() # Create user self.admin_user = self.register_user("admin", "pass", admin=True) @@ -335,7 +336,7 @@ class CanonicalAliasTestCase(unittest.HomeserverTestCase): def _get_canonical_alias(self): """Get the canonical alias state of the room.""" return self.get_success( - self.state_handler.get_current_state( + self._storage_controllers.state.get_current_state_event( self.room_id, EventTypes.CanonicalAlias, "" ) ) diff --git a/tests/handlers/test_federation.py b/tests/handlers/test_federation.py index 060ba5f5..e0eda545 100644 --- a/tests/handlers/test_federation.py +++ b/tests/handlers/test_federation.py @@ -50,7 +50,7 @@ class FederationTestCase(unittest.FederatingHomeserverTestCase): hs = self.setup_test_homeserver(federation_http_client=None) self.handler = hs.get_federation_handler() self.store = hs.get_datastores().main - self.state_store = hs.get_storage().state + self.state_storage_controller = hs.get_storage_controllers().state self._event_auth_handler = hs.get_event_auth_handler() return hs @@ -237,7 +237,9 @@ class FederationTestCase(unittest.FederatingHomeserverTestCase): ) current_state = self.get_success( self.store.get_events_as_list( - (self.get_success(self.store.get_current_state_ids(room_id))).values() + ( + self.get_success(self.store.get_partial_current_state_ids(room_id)) + ).values() ) ) @@ -276,7 +278,11 @@ class FederationTestCase(unittest.FederatingHomeserverTestCase): # federation handler wanting to backfill the fake event. 
self.get_success( federation_event_handler._process_received_pdu( - self.OTHER_SERVER_NAME, event, state=current_state + self.OTHER_SERVER_NAME, + event, + state_ids={ + (e.type, e.state_key): e.event_id for e in current_state + }, ) ) @@ -332,8 +338,11 @@ class FederationTestCase(unittest.FederatingHomeserverTestCase): most_recent_prev_event_depth, ) = self.get_success(self.store.get_max_depth_of(prev_event_ids)) # mapping from (type, state_key) -> state_event_id + assert most_recent_prev_event_id is not None prev_state_map = self.get_success( - self.state_store.get_state_ids_for_event(most_recent_prev_event_id) + self.state_storage_controller.get_state_ids_for_event( + most_recent_prev_event_id + ) ) # List of state event ID's prev_state_ids = list(prev_state_map.values()) @@ -505,7 +514,7 @@ class FederationTestCase(unittest.FederatingHomeserverTestCase): self.get_success(d) # sanity-check: the room should show that the new user is a member - r = self.get_success(self.store.get_current_state_ids(room_id)) + r = self.get_success(self.store.get_partial_current_state_ids(room_id)) self.assertEqual(r[(EventTypes.Member, other_user)], join_event.event_id) return join_event diff --git a/tests/handlers/test_federation_event.py b/tests/handlers/test_federation_event.py index 489ba577..1a36c25c 100644 --- a/tests/handlers/test_federation_event.py +++ b/tests/handlers/test_federation_event.py @@ -70,7 +70,7 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase): ) -> None: OTHER_USER = f"@user:{self.OTHER_SERVER_NAME}" main_store = self.hs.get_datastores().main - state_storage = self.hs.get_storage().state + state_storage_controller = self.hs.get_storage_controllers().state # create the room user_id = self.register_user("kermit", "test") @@ -91,7 +91,9 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase): event_injection.inject_member_event(self.hs, room_id, OTHER_USER, "join") ) - initial_state_map = self.get_success(main_store.get_current_state_ids(room_id)) + initial_state_map = self.get_success( + main_store.get_partial_current_state_ids(room_id) + ) auth_event_ids = [ initial_state_map[("m.room.create", "")], @@ -146,9 +148,12 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase): ) if prev_exists_as_outlier: prev_event.internal_metadata.outlier = True - persistence = self.hs.get_storage().persistence + persistence = self.hs.get_storage_controllers().persistence self.get_success( - persistence.persist_event(prev_event, EventContext.for_outlier()) + persistence.persist_event( + prev_event, + EventContext.for_outlier(self.hs.get_storage_controllers()), + ) ) else: @@ -214,7 +219,7 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase): # check that the state at that event is as expected state = self.get_success( - state_storage.get_state_ids_for_event(pulled_event.event_id) + state_storage_controller.get_state_ids_for_event(pulled_event.event_id) ) expected_state = { (e.type, e.state_key): e.event_id for e in state_at_prev_event diff --git a/tests/handlers/test_message.py b/tests/handlers/test_message.py index f4f7ab48..44da96c7 100644 --- a/tests/handlers/test_message.py +++ b/tests/handlers/test_message.py @@ -37,7 +37,9 @@ class EventCreationTestCase(unittest.HomeserverTestCase): def prepare(self, reactor, clock, hs): self.handler = self.hs.get_event_creation_handler() - self.persist_event_storage = self.hs.get_storage().persistence + self._persist_event_storage_controller = ( + 
self.hs.get_storage_controllers().persistence + ) self.user_id = self.register_user("tester", "foobar") self.access_token = self.login("tester", "foobar") @@ -65,7 +67,9 @@ class EventCreationTestCase(unittest.HomeserverTestCase): ) ) self.get_success( - self.persist_event_storage.persist_event(memberEvent, memberEventContext) + self._persist_event_storage_controller.persist_event( + memberEvent, memberEventContext + ) ) return memberEvent, memberEventContext @@ -129,7 +133,7 @@ class EventCreationTestCase(unittest.HomeserverTestCase): self.assertNotEqual(event1.event_id, event3.event_id) ret_event3, event_pos3, _ = self.get_success( - self.persist_event_storage.persist_event(event3, context) + self._persist_event_storage_controller.persist_event(event3, context) ) # Assert that the returned values match those from the initial event @@ -143,7 +147,7 @@ class EventCreationTestCase(unittest.HomeserverTestCase): self.assertNotEqual(event1.event_id, event3.event_id) events, _ = self.get_success( - self.persist_event_storage.persist_events([(event3, context)]) + self._persist_event_storage_controller.persist_events([(event3, context)]) ) ret_event4 = events[0] @@ -166,7 +170,7 @@ class EventCreationTestCase(unittest.HomeserverTestCase): self.assertNotEqual(event1.event_id, event2.event_id) events, _ = self.get_success( - self.persist_event_storage.persist_events( + self._persist_event_storage_controller.persist_events( [(event1, context1), (event2, context2)] ) ) diff --git a/tests/handlers/test_receipts.py b/tests/handlers/test_receipts.py index 0482a1ea..a95868b5 100644 --- a/tests/handlers/test_receipts.py +++ b/tests/handlers/test_receipts.py @@ -12,10 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. 
- +from copy import deepcopy from typing import List -from synapse.api.constants import ReceiptTypes +from synapse.api.constants import EduTypes, ReceiptTypes from synapse.types import JsonDict from tests import unittest @@ -39,7 +39,7 @@ class ReceiptsTestCase(unittest.HomeserverTestCase): } }, "room_id": "!jEsUZKDJdhlrceRyVU:example.org", - "type": "m.receipt", + "type": EduTypes.RECEIPT, } ], [], @@ -64,7 +64,7 @@ class ReceiptsTestCase(unittest.HomeserverTestCase): }, }, "room_id": "!jEsUZKDJdhlrceRyVU:example.org", - "type": "m.receipt", + "type": EduTypes.RECEIPT, } ], [ @@ -79,7 +79,7 @@ class ReceiptsTestCase(unittest.HomeserverTestCase): } }, "room_id": "!jEsUZKDJdhlrceRyVU:example.org", - "type": "m.receipt", + "type": EduTypes.RECEIPT, } ], ) @@ -105,7 +105,7 @@ class ReceiptsTestCase(unittest.HomeserverTestCase): }, }, "room_id": "!jEsUZKDJdhlrceRyVU:example.org", - "type": "m.receipt", + "type": EduTypes.RECEIPT, } ], [ @@ -120,43 +120,7 @@ class ReceiptsTestCase(unittest.HomeserverTestCase): } }, "room_id": "!jEsUZKDJdhlrceRyVU:example.org", - "type": "m.receipt", - } - ], - ) - - def test_handles_missing_content_of_m_read(self): - self._test_filters_private( - [ - { - "content": { - "$14356419ggffg114394fHBLK:matrix.org": {ReceiptTypes.READ: {}}, - "$1435641916114394fHBLK:matrix.org": { - ReceiptTypes.READ: { - "@user:jki.re": { - "ts": 1436451550453, - } - } - }, - }, - "room_id": "!jEsUZKDJdhlrceRyVU:example.org", - "type": "m.receipt", - } - ], - [ - { - "content": { - "$14356419ggffg114394fHBLK:matrix.org": {ReceiptTypes.READ: {}}, - "$1435641916114394fHBLK:matrix.org": { - ReceiptTypes.READ: { - "@user:jki.re": { - "ts": 1436451550453, - } - } - }, - }, - "room_id": "!jEsUZKDJdhlrceRyVU:example.org", - "type": "m.receipt", + "type": EduTypes.RECEIPT, } ], ) @@ -176,7 +140,7 @@ class ReceiptsTestCase(unittest.HomeserverTestCase): }, }, "room_id": "!jEsUZKDJdhlrceRyVU:example.org", - "type": "m.receipt", + "type": EduTypes.RECEIPT, } ], [ @@ -191,7 +155,7 @@ class ReceiptsTestCase(unittest.HomeserverTestCase): }, }, "room_id": "!jEsUZKDJdhlrceRyVU:example.org", - "type": "m.receipt", + "type": EduTypes.RECEIPT, } ], ) @@ -210,7 +174,7 @@ class ReceiptsTestCase(unittest.HomeserverTestCase): }, }, "room_id": "!jEsUZKDJdhlrceRyVU:example.org", - "type": "m.receipt", + "type": EduTypes.RECEIPT, }, { "content": { @@ -223,7 +187,7 @@ class ReceiptsTestCase(unittest.HomeserverTestCase): }, }, "room_id": "!jEsUZKDJdhlrceRyVU:example.org", - "type": "m.receipt", + "type": EduTypes.RECEIPT, }, ], [ @@ -238,7 +202,7 @@ class ReceiptsTestCase(unittest.HomeserverTestCase): } }, "room_id": "!jEsUZKDJdhlrceRyVU:example.org", - "type": "m.receipt", + "type": EduTypes.RECEIPT, } ], ) @@ -260,7 +224,7 @@ class ReceiptsTestCase(unittest.HomeserverTestCase): }, }, "room_id": "!jEsUZKDJdhlrceRyVU:example.org", - "type": "m.receipt", + "type": EduTypes.RECEIPT, }, ], [ @@ -273,7 +237,7 @@ class ReceiptsTestCase(unittest.HomeserverTestCase): }, }, "room_id": "!jEsUZKDJdhlrceRyVU:example.org", - "type": "m.receipt", + "type": EduTypes.RECEIPT, }, ], ) @@ -302,7 +266,7 @@ class ReceiptsTestCase(unittest.HomeserverTestCase): }, }, "room_id": "!jEsUZKDJdhlrceRyVU:example.org", - "type": "m.receipt", + "type": EduTypes.RECEIPT, } ], [ @@ -327,14 +291,38 @@ class ReceiptsTestCase(unittest.HomeserverTestCase): } }, "room_id": "!jEsUZKDJdhlrceRyVU:example.org", - "type": "m.receipt", + "type": EduTypes.RECEIPT, } ], ) + def test_we_do_not_mutate(self): + """Ensure the input values are not modified.""" 
+ events = [ + { + "content": { + "$1435641916114394fHBLK:matrix.org": { + ReceiptTypes.READ_PRIVATE: { + "@rikj:jki.re": { + "ts": 1436451550453, + } + } + } + }, + "room_id": "!jEsUZKDJdhlrceRyVU:example.org", + "type": EduTypes.RECEIPT, + } + ] + original_events = deepcopy(events) + self._test_filters_private(events, []) + # Since the events are fed in from a cache they should not be modified. + self.assertEqual(events, original_events) + def _test_filters_private( self, events: List[JsonDict], expected_output: List[JsonDict] ): """Tests that the _filter_out_private returns the expected output""" - filtered_events = self.event_source.filter_out_private(events, "@me:server.org") + filtered_events = self.event_source.filter_out_private_receipts( + events, "@me:server.org" + ) self.assertEqual(filtered_events, expected_output) diff --git a/tests/handlers/test_room_summary.py b/tests/handlers/test_room_summary.py index e74eb717..05466556 100644 --- a/tests/handlers/test_room_summary.py +++ b/tests/handlers/test_room_summary.py @@ -179,7 +179,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): result_children_ids.append( [ (cs["room_id"], cs["state_key"]) - for cs in result_room.get("children_state") + for cs in result_room["children_state"] ] ) @@ -772,7 +772,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): { "room_id": public_room, "world_readable": False, - "join_rules": JoinRules.PUBLIC, + "join_rule": JoinRules.PUBLIC, }, ), ( @@ -780,7 +780,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): { "room_id": knock_room, "world_readable": False, - "join_rules": JoinRules.KNOCK, + "join_rule": JoinRules.KNOCK, }, ), ( @@ -788,7 +788,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): { "room_id": not_invited_room, "world_readable": False, - "join_rules": JoinRules.INVITE, + "join_rule": JoinRules.INVITE, }, ), ( @@ -796,7 +796,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): { "room_id": invited_room, "world_readable": False, - "join_rules": JoinRules.INVITE, + "join_rule": JoinRules.INVITE, }, ), ( @@ -804,7 +804,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): { "room_id": restricted_room, "world_readable": False, - "join_rules": JoinRules.RESTRICTED, + "join_rule": JoinRules.RESTRICTED, "allowed_room_ids": [], }, ), @@ -813,7 +813,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): { "room_id": restricted_accessible_room, "world_readable": False, - "join_rules": JoinRules.RESTRICTED, + "join_rule": JoinRules.RESTRICTED, "allowed_room_ids": [self.room], }, ), @@ -822,7 +822,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): { "room_id": world_readable_room, "world_readable": True, - "join_rules": JoinRules.INVITE, + "join_rule": JoinRules.INVITE, }, ), ( @@ -830,7 +830,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): { "room_id": joined_room, "world_readable": False, - "join_rules": JoinRules.INVITE, + "join_rule": JoinRules.INVITE, }, ), ) @@ -911,7 +911,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): { "room_id": fed_room, "world_readable": False, - "join_rules": JoinRules.INVITE, + "join_rule": JoinRules.INVITE, }, ) diff --git a/tests/handlers/test_sync.py b/tests/handlers/test_sync.py index 865b8b7e..db3302a4 100644 --- a/tests/handlers/test_sync.py +++ b/tests/handlers/test_sync.py @@ -160,6 +160,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase): # Blow away caches (supported room versions can only change due to a restart). 
self.store.get_rooms_for_user_with_stream_ordering.invalidate_all() self.store._get_event_cache.clear() + self.store._event_ref.clear() # The rooms should be excluded from the sync response. # Get a new request key. diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py index 5f2e26a5..7af13331 100644 --- a/tests/handlers/test_typing.py +++ b/tests/handlers/test_typing.py @@ -21,6 +21,7 @@ from twisted.internet import defer from twisted.test.proto_helpers import MemoryReactor from twisted.web.resource import Resource +from synapse.api.constants import EduTypes from synapse.api.errors import AuthError from synapse.federation.transport.server import TransportLayerServer from synapse.server import HomeServer @@ -128,10 +129,12 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): hs.get_event_auth_handler().check_host_in_room = check_host_in_room - def get_joined_hosts_for_room(room_id: str): + async def get_current_hosts_in_room(room_id: str): return {member.domain for member in self.room_members} - self.datastore.get_joined_hosts_for_room = get_joined_hosts_for_room + hs.get_storage_controllers().state.get_current_hosts_in_room = ( + get_current_hosts_in_room + ) async def get_users_in_room(room_id: str): return {str(u) for u in self.room_members} @@ -145,7 +148,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): ) ) - self.datastore.get_current_state_deltas = Mock(return_value=(0, None)) + self.datastore.get_partial_current_state_deltas = Mock(return_value=(0, None)) self.datastore.get_to_device_stream_token = lambda: 0 self.datastore.get_new_device_msgs_for_remote = ( @@ -184,7 +187,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): events[0], [ { - "type": "m.typing", + "type": EduTypes.TYPING, "room_id": ROOM_ID, "content": {"user_ids": [U_APPLE.to_string()]}, } @@ -209,7 +212,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): "farm", path="/_matrix/federation/v1/send/1000000", data=_expect_edu_transaction( - "m.typing", + EduTypes.TYPING, content={ "room_id": ROOM_ID, "user_id": U_APPLE.to_string(), @@ -231,7 +234,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): "PUT", "/_matrix/federation/v1/send/1000000", _make_edu_transaction_json( - "m.typing", + EduTypes.TYPING, content={ "room_id": ROOM_ID, "user_id": U_ONION.to_string(), @@ -254,7 +257,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): events[0], [ { - "type": "m.typing", + "type": EduTypes.TYPING, "room_id": ROOM_ID, "content": {"user_ids": [U_ONION.to_string()]}, } @@ -270,7 +273,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): "PUT", "/_matrix/federation/v1/send/1000000", _make_edu_transaction_json( - "m.typing", + EduTypes.TYPING, content={ "room_id": OTHER_ROOM_ID, "user_id": U_ONION.to_string(), @@ -324,7 +327,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): "farm", path="/_matrix/federation/v1/send/1000000", data=_expect_edu_transaction( - "m.typing", + EduTypes.TYPING, content={ "room_id": ROOM_ID, "user_id": U_APPLE.to_string(), @@ -345,7 +348,13 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): ) self.assertEqual( events[0], - [{"type": "m.typing", "room_id": ROOM_ID, "content": {"user_ids": []}}], + [ + { + "type": EduTypes.TYPING, + "room_id": ROOM_ID, + "content": {"user_ids": []}, + } + ], ) def test_typing_timeout(self) -> None: @@ -379,7 +388,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): events[0], [ { - "type": 
"m.typing", + "type": EduTypes.TYPING, "room_id": ROOM_ID, "content": {"user_ids": [U_APPLE.to_string()]}, } @@ -402,7 +411,13 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): ) self.assertEqual( events[0], - [{"type": "m.typing", "room_id": ROOM_ID, "content": {"user_ids": []}}], + [ + { + "type": EduTypes.TYPING, + "room_id": ROOM_ID, + "content": {"user_ids": []}, + } + ], ) # SYN-230 - see if we can still set after timeout @@ -433,7 +448,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): events[0], [ { - "type": "m.typing", + "type": EduTypes.TYPING, "room_id": ROOM_ID, "content": {"user_ids": [U_APPLE.to_string()]}, } diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py index 4d658d29..9e39cd97 100644 --- a/tests/handlers/test_user_directory.py +++ b/tests/handlers/test_user_directory.py @@ -60,7 +60,6 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): self.appservice = ApplicationService( token="i_am_an_app_service", - hostname="test", id="1234", namespaces={"users": [{"regex": r"@as_user.*", "exclusive": True}]}, # Note: this user does not match the regex above, so that tests @@ -954,7 +953,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): ) self.get_success( - self.hs.get_storage().persistence.persist_event(event, context) + self.hs.get_storage_controllers().persistence.persist_event(event, context) ) def test_local_user_leaving_room_remains_in_user_directory(self) -> None: diff --git a/tests/http/server/__init__.py b/tests/http/server/__init__.py new file mode 100644 index 00000000..3a5f22c0 --- /dev/null +++ b/tests/http/server/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2022 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/http/server/_base.py b/tests/http/server/_base.py new file mode 100644 index 00000000..b9f1a381 --- /dev/null +++ b/tests/http/server/_base.py @@ -0,0 +1,100 @@ +# Copyright 2022 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unles4s required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from http import HTTPStatus +from typing import Any, Callable, Optional, Union +from unittest import mock + +from twisted.internet.error import ConnectionDone + +from synapse.http.server import ( + HTTP_STATUS_REQUEST_CANCELLED, + respond_with_html_bytes, + respond_with_json, +) +from synapse.types import JsonDict + +from tests import unittest +from tests.server import FakeChannel, ThreadedMemoryReactorClock + + +class EndpointCancellationTestHelperMixin(unittest.TestCase): + """Provides helper methods for testing cancellation of endpoints.""" + + def _test_disconnect( + self, + reactor: ThreadedMemoryReactorClock, + channel: FakeChannel, + expect_cancellation: bool, + expected_body: Union[bytes, JsonDict], + expected_code: Optional[int] = None, + ) -> None: + """Disconnects an in-flight request and checks the response. + + Args: + reactor: The twisted reactor running the request handler. + channel: The `FakeChannel` for the request. + expect_cancellation: `True` if request processing is expected to be + cancelled, `False` if the request should run to completion. + expected_body: The expected response for the request. + expected_code: The expected status code for the request. Defaults to `200` + or `499` depending on `expect_cancellation`. + """ + # Determine the expected status code. + if expected_code is None: + if expect_cancellation: + expected_code = HTTP_STATUS_REQUEST_CANCELLED + else: + expected_code = HTTPStatus.OK + + request = channel.request + self.assertFalse( + channel.is_finished(), + "Request finished before we could disconnect - " + "was `await_result=False` passed to `make_request`?", + ) + + # We're about to disconnect the request. This also disconnects the channel, so + # we have to rely on mocks to extract the response. + respond_method: Callable[..., Any] + if isinstance(expected_body, bytes): + respond_method = respond_with_html_bytes + else: + respond_method = respond_with_json + + with mock.patch( + f"synapse.http.server.{respond_method.__name__}", wraps=respond_method + ) as respond_mock: + # Disconnect the request. + request.connectionLost(reason=ConnectionDone()) + + if expect_cancellation: + # An immediate cancellation is expected. + respond_mock.assert_called_once() + args, _kwargs = respond_mock.call_args + code, body = args[1], args[2] + self.assertEqual(code, expected_code) + self.assertEqual(request.code, expected_code) + self.assertEqual(body, expected_body) + else: + respond_mock.assert_not_called() + + # The handler is expected to run to completion. 
+ reactor.pump([1.0]) + respond_mock.assert_called_once() + args, _kwargs = respond_mock.call_args + code, body = args[1], args[2] + self.assertEqual(code, expected_code) + self.assertEqual(request.code, expected_code) + self.assertEqual(body, expected_body) diff --git a/tests/http/test_fedclient.py b/tests/http/test_fedclient.py index 638babae..006dbab0 100644 --- a/tests/http/test_fedclient.py +++ b/tests/http/test_fedclient.py @@ -26,7 +26,7 @@ from twisted.web.http import HTTPChannel from synapse.api.errors import RequestSendFailed from synapse.http.matrixfederationclient import ( - MAX_RESPONSE_SIZE, + JsonParser, MatrixFederationHttpClient, MatrixFederationRequest, ) @@ -609,9 +609,9 @@ class FederationClientTests(HomeserverTestCase): while not test_d.called: protocol.dataReceived(b"a" * chunk_size) sent += chunk_size - self.assertLessEqual(sent, MAX_RESPONSE_SIZE) + self.assertLessEqual(sent, JsonParser.MAX_RESPONSE_SIZE) - self.assertEqual(sent, MAX_RESPONSE_SIZE) + self.assertEqual(sent, JsonParser.MAX_RESPONSE_SIZE) f = self.failureResultOf(test_d) self.assertIsInstance(f.value, RequestSendFailed) diff --git a/tests/http/test_servlet.py b/tests/http/test_servlet.py index a80bfb9f..b3655d7b 100644 --- a/tests/http/test_servlet.py +++ b/tests/http/test_servlet.py @@ -12,16 +12,25 @@ # See the License for the specific language governing permissions and # limitations under the License. import json +from http import HTTPStatus from io import BytesIO +from typing import Tuple from unittest.mock import Mock -from synapse.api.errors import SynapseError +from synapse.api.errors import Codes, SynapseError +from synapse.http.server import cancellable from synapse.http.servlet import ( + RestServlet, parse_json_object_from_request, parse_json_value_from_request, ) +from synapse.http.site import SynapseRequest +from synapse.rest.client._base import client_patterns +from synapse.server import HomeServer +from synapse.types import JsonDict from tests import unittest +from tests.http.server._base import EndpointCancellationTestHelperMixin def make_request(content): @@ -40,19 +49,21 @@ class TestServletUtils(unittest.TestCase): """Basic tests for parse_json_value_from_request.""" # Test round-tripping. obj = {"foo": 1} - result = parse_json_value_from_request(make_request(obj)) - self.assertEqual(result, obj) + result1 = parse_json_value_from_request(make_request(obj)) + self.assertEqual(result1, obj) # Results don't have to be objects. - result = parse_json_value_from_request(make_request(b'["foo"]')) - self.assertEqual(result, ["foo"]) + result2 = parse_json_value_from_request(make_request(b'["foo"]')) + self.assertEqual(result2, ["foo"]) # Test empty. with self.assertRaises(SynapseError): parse_json_value_from_request(make_request(b"")) - result = parse_json_value_from_request(make_request(b""), allow_empty_body=True) - self.assertIsNone(result) + result3 = parse_json_value_from_request( + make_request(b""), allow_empty_body=True + ) + self.assertIsNone(result3) # Invalid UTF-8. 
with self.assertRaises(SynapseError): @@ -76,3 +87,52 @@ class TestServletUtils(unittest.TestCase): # Test not an object with self.assertRaises(SynapseError): parse_json_object_from_request(make_request(b'["foo"]')) + + +class CancellableRestServlet(RestServlet): + """A `RestServlet` with a mix of cancellable and uncancellable handlers.""" + + PATTERNS = client_patterns("/sleep$") + + def __init__(self, hs: HomeServer): + super().__init__() + self.clock = hs.get_clock() + + @cancellable + async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: + await self.clock.sleep(1.0) + return HTTPStatus.OK, {"result": True} + + async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: + await self.clock.sleep(1.0) + return HTTPStatus.OK, {"result": True} + + +class TestRestServletCancellation( + unittest.HomeserverTestCase, EndpointCancellationTestHelperMixin +): + """Tests for `RestServlet` cancellation.""" + + servlets = [ + lambda hs, http_server: CancellableRestServlet(hs).register(http_server) + ] + + def test_cancellable_disconnect(self) -> None: + """Test that handlers with the `@cancellable` flag can be cancelled.""" + channel = self.make_request("GET", "/sleep", await_result=False) + self._test_disconnect( + self.reactor, + channel, + expect_cancellation=True, + expected_body={"error": "Request cancelled", "errcode": Codes.UNKNOWN}, + ) + + def test_uncancellable_disconnect(self) -> None: + """Test that handlers without the `@cancellable` flag cannot be cancelled.""" + channel = self.make_request("POST", "/sleep", await_result=False) + self._test_disconnect( + self.reactor, + channel, + expect_cancellation=False, + expected_body={"result": True}, + ) diff --git a/tests/http/test_site.py b/tests/http/test_site.py index 8c13b4f6..b2dbf76d 100644 --- a/tests/http/test_site.py +++ b/tests/http/test_site.py @@ -36,7 +36,7 @@ class SynapseRequestTestCase(HomeserverTestCase): # as a control case, first send a regular request. # complete the connection and wire it up to a fake transport - client_address = IPv6Address("TCP", "::1", "2345") + client_address = IPv6Address("TCP", "::1", 2345) protocol = factory.buildProtocol(client_address) transport = StringTransport() protocol.makeConnection(transport) diff --git a/tests/module_api/test_api.py b/tests/module_api/test_api.py index 8bc84aaa..169e29b5 100644 --- a/tests/module_api/test_api.py +++ b/tests/module_api/test_api.py @@ -399,7 +399,7 @@ class ModuleApiTestCase(HomeserverTestCase): for edu in edus: # Make sure we're only checking presence-type EDUs - if edu["edu_type"] != EduTypes.Presence: + if edu["edu_type"] != EduTypes.PRESENCE: continue # EDUs can contain multiple presence updates diff --git a/tests/push/test_push_rule_evaluator.py b/tests/push/test_push_rule_evaluator.py index 5dba1870..9b623d00 100644 --- a/tests/push/test_push_rule_evaluator.py +++ b/tests/push/test_push_rule_evaluator.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
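# A small, self-contained illustration (not Synapse code; the `respond` helper
# below is invented for the example) of the `mock.patch(..., wraps=...)`
# technique that `_test_disconnect` relies on above: the wrapped function still
# runs normally, but the arguments it was called with can be inspected
# afterwards, even once the fake channel has been disconnected.
from unittest import mock

def respond(request, code, body):
    """Stand-in for a response helper such as respond_with_json."""
    return code, body

with mock.patch(f"{__name__}.respond", wraps=respond) as respond_mock:
    respond("fake-request", 200, {"result": True})

respond_mock.assert_called_once()
args, _kwargs = respond_mock.call_args
assert (args[1], args[2]) == (200, {"result": True})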
-from typing import Dict, Optional, Union +from typing import Dict, Optional, Set, Tuple, Union import frozendict @@ -26,7 +26,12 @@ from tests import unittest class PushRuleEvaluatorTestCase(unittest.TestCase): - def _get_evaluator(self, content: JsonDict) -> PushRuleEvaluatorForEvent: + def _get_evaluator( + self, + content: JsonDict, + relations: Optional[Dict[str, Set[Tuple[str, str]]]] = None, + relations_match_enabled: bool = False, + ) -> PushRuleEvaluatorForEvent: event = FrozenEvent( { "event_id": "$event_id", @@ -42,7 +47,12 @@ class PushRuleEvaluatorTestCase(unittest.TestCase): sender_power_level = 0 power_levels: Dict[str, Union[int, Dict[str, int]]] = {} return PushRuleEvaluatorForEvent( - event, room_member_count, sender_power_level, power_levels + event, + room_member_count, + sender_power_level, + power_levels, + relations or set(), + relations_match_enabled, ) def test_display_name(self) -> None: @@ -276,3 +286,71 @@ class PushRuleEvaluatorTestCase(unittest.TestCase): push_rule_evaluator.tweaks_for_actions(actions), {"sound": "default", "highlight": True}, ) + + def test_relation_match(self) -> None: + """Test the relation_match push rule kind.""" + + # Check if the experimental feature is disabled. + evaluator = self._get_evaluator( + {}, {"m.annotation": {("@user:test", "m.reaction")}} + ) + condition = {"kind": "relation_match"} + # Oddly, an unknown condition always matches. + self.assertTrue(evaluator.matches(condition, "@user:test", "foo")) + + # A push rule evaluator with the experimental rule enabled. + evaluator = self._get_evaluator( + {}, {"m.annotation": {("@user:test", "m.reaction")}}, True + ) + + # Check just relation type. + condition = { + "kind": "org.matrix.msc3772.relation_match", + "rel_type": "m.annotation", + } + self.assertTrue(evaluator.matches(condition, "@user:test", "foo")) + + # Check relation type and sender. + condition = { + "kind": "org.matrix.msc3772.relation_match", + "rel_type": "m.annotation", + "sender": "@user:test", + } + self.assertTrue(evaluator.matches(condition, "@user:test", "foo")) + condition = { + "kind": "org.matrix.msc3772.relation_match", + "rel_type": "m.annotation", + "sender": "@other:test", + } + self.assertFalse(evaluator.matches(condition, "@user:test", "foo")) + + # Check relation type and event type. + condition = { + "kind": "org.matrix.msc3772.relation_match", + "rel_type": "m.annotation", + "type": "m.reaction", + } + self.assertTrue(evaluator.matches(condition, "@user:test", "foo")) + + # Check just sender, this fails since rel_type is required. + condition = { + "kind": "org.matrix.msc3772.relation_match", + "sender": "@user:test", + } + self.assertFalse(evaluator.matches(condition, "@user:test", "foo")) + + # Check sender glob. + condition = { + "kind": "org.matrix.msc3772.relation_match", + "rel_type": "m.annotation", + "sender": "@*:test", + } + self.assertTrue(evaluator.matches(condition, "@user:test", "foo")) + + # Check event type glob. + condition = { + "kind": "org.matrix.msc3772.relation_match", + "rel_type": "m.annotation", + "event_type": "*.reaction", + } + self.assertTrue(evaluator.matches(condition, "@user:test", "foo")) diff --git a/tests/replication/_base.py b/tests/replication/_base.py index a7602b4c..970d5e53 100644 --- a/tests/replication/_base.py +++ b/tests/replication/_base.py @@ -12,7 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. 
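# A minimal, standalone sketch (not Synapse's evaluator) of the matching rule
# the MSC3772 relation_match tests above exercise, assuming a relations map
# shaped like {rel_type: {(sender, event_type), ...}} and plain fnmatch-style
# globs for the sender and event type:
from fnmatch import fnmatchcase
from typing import Dict, Set, Tuple

def relation_match(
    relations: Dict[str, Set[Tuple[str, str]]],
    rel_type: str,
    sender_glob: str = "*",
    type_glob: str = "*",
) -> bool:
    """True if any relation of `rel_type` has a matching sender and event type."""
    return any(
        fnmatchcase(sender, sender_glob) and fnmatchcase(ev_type, type_glob)
        for sender, ev_type in relations.get(rel_type, set())
    )

relations = {"m.annotation": {("@user:test", "m.reaction")}}
assert relation_match(relations, "m.annotation", sender_glob="@*:test")
assert relation_match(relations, "m.annotation", type_glob="*.reaction")
assert not relation_match(relations, "m.annotation", sender_glob="@other:test")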
import logging -from typing import Any, Dict, List, Optional, Tuple +from collections import defaultdict +from typing import Any, Dict, List, Optional, Set, Tuple from twisted.internet.address import IPv4Address from twisted.internet.protocol import Protocol @@ -32,6 +33,7 @@ from synapse.server import HomeServer from tests import unittest from tests.server import FakeTransport +from tests.utils import USE_POSTGRES_FOR_TESTS try: import hiredis @@ -475,22 +477,25 @@ class FakeRedisPubSubServer: """A fake Redis server for pub/sub.""" def __init__(self): - self._subscribers = set() + self._subscribers_by_channel: Dict[ + bytes, Set["FakeRedisPubSubProtocol"] + ] = defaultdict(set) - def add_subscriber(self, conn): + def add_subscriber(self, conn, channel: bytes): """A connection has called SUBSCRIBE""" - self._subscribers.add(conn) + self._subscribers_by_channel[channel].add(conn) def remove_subscriber(self, conn): - """A connection has called UNSUBSCRIBE""" - self._subscribers.discard(conn) + """A connection has lost connection""" + for subscribers in self._subscribers_by_channel.values(): + subscribers.discard(conn) - def publish(self, conn, channel, msg) -> int: + def publish(self, conn, channel: bytes, msg) -> int: """A connection want to publish a message to subscribers.""" - for sub in self._subscribers: + for sub in self._subscribers_by_channel[channel]: sub.send(["message", channel, msg]) - return len(self._subscribers) + return len(self._subscribers_by_channel) def buildProtocol(self, addr): return FakeRedisPubSubProtocol(self) @@ -531,9 +536,10 @@ class FakeRedisPubSubProtocol(Protocol): num_subscribers = self._server.publish(self, channel, message) self.send(num_subscribers) elif command == b"SUBSCRIBE": - (channel,) = args - self._server.add_subscriber(self) - self.send(["subscribe", channel, 1]) + for idx, channel in enumerate(args): + num_channels = idx + 1 + self._server.add_subscriber(self, channel) + self.send(["subscribe", channel, num_channels]) # Since we use SET/GET to cache things we can safely no-op them. elif command == b"SET": @@ -576,3 +582,27 @@ class FakeRedisPubSubProtocol(Protocol): def connectionLost(self, reason): self._server.remove_subscriber(self) + + +class RedisMultiWorkerStreamTestCase(BaseMultiWorkerStreamTestCase): + """ + A test case that enables Redis, providing a fake Redis server. + """ + + if not hiredis: + skip = "Requires hiredis" + + if not USE_POSTGRES_FOR_TESTS: + # Redis replication only takes place on Postgres + skip = "Requires Postgres" + + def default_config(self) -> Dict[str, Any]: + """ + Overrides the default config to enable Redis. + Even if the test only uses make_worker_hs, the main process needs Redis + enabled otherwise it won't create a Fake Redis server to listen on the + Redis port and accept fake TCP connections. + """ + base = super().default_config() + base["redis"] = {"enabled": True} + return base diff --git a/tests/replication/http/__init__.py b/tests/replication/http/__init__.py new file mode 100644 index 00000000..3a5f22c0 --- /dev/null +++ b/tests/replication/http/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2022 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/replication/http/test__base.py b/tests/replication/http/test__base.py new file mode 100644 index 00000000..a5ab093a --- /dev/null +++ b/tests/replication/http/test__base.py @@ -0,0 +1,106 @@ +# Copyright 2022 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from http import HTTPStatus +from typing import Tuple + +from twisted.web.server import Request + +from synapse.api.errors import Codes +from synapse.http.server import JsonResource, cancellable +from synapse.replication.http import REPLICATION_PREFIX +from synapse.replication.http._base import ReplicationEndpoint +from synapse.server import HomeServer +from synapse.types import JsonDict + +from tests import unittest +from tests.http.server._base import EndpointCancellationTestHelperMixin + + +class CancellableReplicationEndpoint(ReplicationEndpoint): + NAME = "cancellable_sleep" + PATH_ARGS = () + CACHE = False + + def __init__(self, hs: HomeServer): + super().__init__(hs) + self.clock = hs.get_clock() + + @staticmethod + async def _serialize_payload() -> JsonDict: + return {} + + @cancellable + async def _handle_request( # type: ignore[override] + self, request: Request + ) -> Tuple[int, JsonDict]: + await self.clock.sleep(1.0) + return HTTPStatus.OK, {"result": True} + + +class UncancellableReplicationEndpoint(ReplicationEndpoint): + NAME = "uncancellable_sleep" + PATH_ARGS = () + CACHE = False + + def __init__(self, hs: HomeServer): + super().__init__(hs) + self.clock = hs.get_clock() + + @staticmethod + async def _serialize_payload() -> JsonDict: + return {} + + async def _handle_request( # type: ignore[override] + self, request: Request + ) -> Tuple[int, JsonDict]: + await self.clock.sleep(1.0) + return HTTPStatus.OK, {"result": True} + + +class ReplicationEndpointCancellationTestCase( + unittest.HomeserverTestCase, EndpointCancellationTestHelperMixin +): + """Tests for `ReplicationEndpoint` cancellation.""" + + def create_test_resource(self): + """Overrides `HomeserverTestCase.create_test_resource`.""" + resource = JsonResource(self.hs) + + CancellableReplicationEndpoint(self.hs).register(resource) + UncancellableReplicationEndpoint(self.hs).register(resource) + + return resource + + def test_cancellable_disconnect(self) -> None: + """Test that handlers with the `@cancellable` flag can be cancelled.""" + path = f"{REPLICATION_PREFIX}/{CancellableReplicationEndpoint.NAME}/" + channel = self.make_request("POST", path, await_result=False) + self._test_disconnect( + self.reactor, + channel, + expect_cancellation=True, + expected_body={"error": "Request cancelled", "errcode": 
Codes.UNKNOWN}, + ) + + def test_uncancellable_disconnect(self) -> None: + """Test that handlers without the `@cancellable` flag cannot be cancelled.""" + path = f"{REPLICATION_PREFIX}/{UncancellableReplicationEndpoint.NAME}/" + channel = self.make_request("POST", path, await_result=False) + self._test_disconnect( + self.reactor, + channel, + expect_cancellation=False, + expected_body={"result": True}, + ) diff --git a/tests/replication/slave/storage/_base.py b/tests/replication/slave/storage/_base.py index 85be79d1..c5705256 100644 --- a/tests/replication/slave/storage/_base.py +++ b/tests/replication/slave/storage/_base.py @@ -32,7 +32,7 @@ class BaseSlavedStoreTestCase(BaseStreamTestCase): self.master_store = hs.get_datastores().main self.slaved_store = self.worker_hs.get_datastores().main - self.storage = hs.get_storage() + self._storage_controllers = hs.get_storage_controllers() def replicate(self): """Tell the master side of replication that something has happened, and then diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py index 297a9e77..6d3d4afe 100644 --- a/tests/replication/slave/storage/test_events.py +++ b/tests/replication/slave/storage/test_events.py @@ -262,7 +262,9 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): ) msg, msgctx = self.build_event() self.get_success( - self.storage.persistence.persist_events([(j2, j2ctx), (msg, msgctx)]) + self._storage_controllers.persistence.persist_events( + [(j2, j2ctx), (msg, msgctx)] + ) ) self.replicate() @@ -323,12 +325,14 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): if backfill: self.get_success( - self.storage.persistence.persist_events( + self._storage_controllers.persistence.persist_events( [(event, context)], backfilled=True ) ) else: - self.get_success(self.storage.persistence.persist_event(event, context)) + self.get_success( + self._storage_controllers.persistence.persist_event(event, context) + ) return event diff --git a/tests/replication/slave/storage/test_receipts.py b/tests/replication/slave/storage/test_receipts.py index 5bbbd5fb..19f57115 100644 --- a/tests/replication/slave/storage/test_receipts.py +++ b/tests/replication/slave/storage/test_receipts.py @@ -31,7 +31,9 @@ class SlavedReceiptTestCase(BaseSlavedStoreTestCase): def prepare(self, reactor, clock, homeserver): super().prepare(reactor, clock, homeserver) self.room_creator = homeserver.get_room_creation_handler() - self.persist_event_storage = self.hs.get_storage().persistence + self.persist_event_storage_controller = ( + self.hs.get_storage_controllers().persistence + ) # Create a test user self.ourUser = UserID.from_string(OUR_USER_ID) @@ -61,7 +63,9 @@ class SlavedReceiptTestCase(BaseSlavedStoreTestCase): ) ) self.get_success( - self.persist_event_storage.persist_event(memberEvent, memberEventContext) + self.persist_event_storage_controller.persist_event( + memberEvent, memberEventContext + ) ) # Join the second user to the second room @@ -76,7 +80,9 @@ class SlavedReceiptTestCase(BaseSlavedStoreTestCase): ) ) self.get_success( - self.persist_event_storage.persist_event(memberEvent, memberEventContext) + self.persist_event_storage_controller.persist_event( + memberEvent, memberEventContext + ) ) def test_return_empty_with_no_data(self): diff --git a/tests/replication/tcp/test_handler.py b/tests/replication/tcp/test_handler.py new file mode 100644 index 00000000..e6a19eaf --- /dev/null +++ b/tests/replication/tcp/test_handler.py @@ -0,0 +1,73 @@ +# Copyright 2022 The 
Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from tests.replication._base import RedisMultiWorkerStreamTestCase + + +class ChannelsTestCase(RedisMultiWorkerStreamTestCase): + def test_subscribed_to_enough_redis_channels(self) -> None: + # The default main process is subscribed to the USER_IP channel. + self.assertCountEqual( + self.hs.get_replication_command_handler()._channels_to_subscribe_to, + ["USER_IP"], + ) + + def test_background_worker_subscribed_to_user_ip(self) -> None: + # The default main process is subscribed to the USER_IP channel. + worker1 = self.make_worker_hs( + "synapse.app.generic_worker", + extra_config={ + "worker_name": "worker1", + "run_background_tasks_on": "worker1", + "redis": {"enabled": True}, + }, + ) + self.assertIn( + "USER_IP", + worker1.get_replication_command_handler()._channels_to_subscribe_to, + ) + + # Advance so the Redis subscription gets processed + self.pump(0.1) + + # The counts are 2 because both the main process and the worker are subscribed. + self.assertEqual(len(self._redis_server._subscribers_by_channel[b"test"]), 2) + self.assertEqual( + len(self._redis_server._subscribers_by_channel[b"test/USER_IP"]), 2 + ) + + def test_non_background_worker_not_subscribed_to_user_ip(self) -> None: + # The default main process is subscribed to the USER_IP channel. + worker2 = self.make_worker_hs( + "synapse.app.generic_worker", + extra_config={ + "worker_name": "worker2", + "run_background_tasks_on": "worker1", + "redis": {"enabled": True}, + }, + ) + self.assertNotIn( + "USER_IP", + worker2.get_replication_command_handler()._channels_to_subscribe_to, + ) + + # Advance so the Redis subscription gets processed + self.pump(0.1) + + # The count is 2 because both the main process and the worker are subscribed. + self.assertEqual(len(self._redis_server._subscribers_by_channel[b"test"]), 2) + # For USER_IP, the count is 1 because only the main process is subscribed. 
+ self.assertEqual( + len(self._redis_server._subscribers_by_channel[b"test/USER_IP"]), 1 + ) diff --git a/tests/replication/test_sharded_event_persister.py b/tests/replication/test_sharded_event_persister.py index 5f142e84..a7ca6806 100644 --- a/tests/replication/test_sharded_event_persister.py +++ b/tests/replication/test_sharded_event_persister.py @@ -14,7 +14,6 @@ import logging from unittest.mock import patch -from synapse.api.room_versions import RoomVersion from synapse.rest import admin from synapse.rest.client import login, room, sync from synapse.storage.util.id_generators import MultiWriterIdGenerator @@ -64,21 +63,10 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase): # We control the room ID generation by patching out the # `_generate_room_id` method - async def generate_room( - creator_id: str, is_public: bool, room_version: RoomVersion - ): - await self.store.store_room( - room_id=room_id, - room_creator_user_id=creator_id, - is_public=is_public, - room_version=room_version, - ) - return room_id - with patch( "synapse.handlers.room.RoomCreationHandler._generate_room_id" ) as mock: - mock.side_effect = generate_room + mock.side_effect = lambda: room_id self.helper.create_room_as(user_id, tok=tok) def test_basic(self): diff --git a/tests/rest/admin/test_admin.py b/tests/rest/admin/test_admin.py index 40571b75..82ac5991 100644 --- a/tests/rest/admin/test_admin.py +++ b/tests/rest/admin/test_admin.py @@ -14,7 +14,6 @@ import urllib.parse from http import HTTPStatus -from typing import List from parameterized import parameterized @@ -23,7 +22,7 @@ from twisted.test.proto_helpers import MemoryReactor import synapse.rest.admin from synapse.http.server import JsonResource from synapse.rest.admin import VersionServlet -from synapse.rest.client import groups, login, room +from synapse.rest.client import login, room from synapse.server import HomeServer from synapse.util import Clock @@ -49,93 +48,6 @@ class VersionTestCase(unittest.HomeserverTestCase): ) -class DeleteGroupTestCase(unittest.HomeserverTestCase): - servlets = [ - synapse.rest.admin.register_servlets_for_client_rest_resource, - login.register_servlets, - groups.register_servlets, - ] - - def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: - self.admin_user = self.register_user("admin", "pass", admin=True) - self.admin_user_tok = self.login("admin", "pass") - - self.other_user = self.register_user("user", "pass") - self.other_user_token = self.login("user", "pass") - - @unittest.override_config({"experimental_features": {"groups_enabled": True}}) - def test_delete_group(self) -> None: - # Create a new group - channel = self.make_request( - "POST", - b"/create_group", - access_token=self.admin_user_tok, - content={"localpart": "test"}, - ) - - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) - - group_id = channel.json_body["group_id"] - - self._check_group(group_id, expect_code=HTTPStatus.OK) - - # Invite/join another user - - url = "/groups/%s/admin/users/invite/%s" % (group_id, self.other_user) - channel = self.make_request( - "PUT", url.encode("ascii"), access_token=self.admin_user_tok, content={} - ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) - - url = "/groups/%s/self/accept_invite" % (group_id,) - channel = self.make_request( - "PUT", url.encode("ascii"), access_token=self.other_user_token, content={} - ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) - - # Check other user knows they're in the group 
- self.assertIn(group_id, self._get_groups_user_is_in(self.admin_user_tok)) - self.assertIn(group_id, self._get_groups_user_is_in(self.other_user_token)) - - # Now delete the group - url = "/_synapse/admin/v1/delete_group/" + group_id - channel = self.make_request( - "POST", - url.encode("ascii"), - access_token=self.admin_user_tok, - content={"localpart": "test"}, - ) - - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) - - # Check group returns HTTPStatus.NOT_FOUND - self._check_group(group_id, expect_code=HTTPStatus.NOT_FOUND) - - # Check users don't think they're in the group - self.assertNotIn(group_id, self._get_groups_user_is_in(self.admin_user_tok)) - self.assertNotIn(group_id, self._get_groups_user_is_in(self.other_user_token)) - - def _check_group(self, group_id: str, expect_code: int) -> None: - """Assert that trying to fetch the given group results in the given - HTTP status code - """ - - url = "/groups/%s/profile" % (group_id,) - channel = self.make_request( - "GET", url.encode("ascii"), access_token=self.admin_user_tok - ) - - self.assertEqual(expect_code, channel.code, msg=channel.json_body) - - def _get_groups_user_is_in(self, access_token: str) -> List[str]: - """Returns the list of groups the user is in (given their access token)""" - channel = self.make_request("GET", b"/joined_groups", access_token=access_token) - - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) - - return channel.json_body["groups"] - - class QuarantineMediaTestCase(unittest.HomeserverTestCase): """Test /quarantine_media admin API.""" diff --git a/tests/rest/admin/test_room.py b/tests/rest/admin/test_room.py index 95282f07..ca6af941 100644 --- a/tests/rest/admin/test_room.py +++ b/tests/rest/admin/test_room.py @@ -2467,7 +2467,6 @@ PURGE_TABLES = [ "event_push_actions", "event_search", "events", - "group_rooms", "receipts_graph", "receipts_linearized", "room_aliases", @@ -2484,9 +2483,9 @@ PURGE_TABLES = [ "e2e_room_keys", "event_push_summary", "pusher_throttle", - "group_summary_rooms", "room_account_data", "room_tags", # "state_groups", # Current impl leaves orphaned state groups around. "state_groups_state", + "federation_inbound_events_staging", ] diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py index 0cdf1dec..0d441022 100644 --- a/tests/rest/admin/test_user.py +++ b/tests/rest/admin/test_user.py @@ -2579,7 +2579,7 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase): other_user_tok = self.login("user", "pass") event_builder_factory = self.hs.get_event_builder_factory() event_creation_handler = self.hs.get_event_creation_handler() - storage = self.hs.get_storage() + storage_controllers = self.hs.get_storage_controllers() # Create two rooms, one with a local user only and one with both a local # and remote user. 
@@ -2604,7 +2604,7 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase): event_creation_handler.create_new_client_event(builder) ) - self.get_success(storage.persistence.persist_event(event, context)) + self.get_success(storage_controllers.persistence.persist_event(event, context)) # Now get rooms url = "/_synapse/admin/v1/users/@joiner:remote_hs/joined_rooms" diff --git a/tests/rest/client/test_account.py b/tests/rest/client/test_account.py index e0a11da9..a43a1372 100644 --- a/tests/rest/client/test_account.py +++ b/tests/rest/client/test_account.py @@ -548,7 +548,6 @@ class WhoamiTestCase(unittest.HomeserverTestCase): appservice = ApplicationService( as_token, - self.hs.config.server.server_name, id="1234", namespaces={"users": [{"regex": user_id, "exclusive": True}]}, sender=user_id, diff --git a/tests/rest/client/test_auth.py b/tests/rest/client/test_auth.py index 9653f458..05355c7f 100644 --- a/tests/rest/client/test_auth.py +++ b/tests/rest/client/test_auth.py @@ -195,8 +195,17 @@ class UIAuthTests(unittest.HomeserverTestCase): self.user_pass = "pass" self.user = self.register_user("test", self.user_pass) self.device_id = "dev1" + + # Force-enable password login for just long enough to log in. + auth_handler = self.hs.get_auth_handler() + allow_auth_for_login = auth_handler._password_enabled_for_login + auth_handler._password_enabled_for_login = True + self.user_tok = self.login("test", self.user_pass, self.device_id) + # Restore password login to however it was. + auth_handler._password_enabled_for_login = allow_auth_for_login + def delete_device( self, access_token: str, @@ -263,6 +272,38 @@ class UIAuthTests(unittest.HomeserverTestCase): }, ) + @override_config({"password_config": {"enabled": "only_for_reauth"}}) + def test_ui_auth_with_passwords_for_reauth_only(self) -> None: + """ + Test user interactive authentication outside of registration. + """ + + # Attempt to delete this device. + # Returns a 401 as per the spec + channel = self.delete_device( + self.user_tok, self.device_id, HTTPStatus.UNAUTHORIZED + ) + + # Grab the session + session = channel.json_body["session"] + # Ensure that flows are what is expected. + self.assertIn({"stages": ["m.login.password"]}, channel.json_body["flows"]) + + # Make another request providing the UI auth flow. + self.delete_device( + self.user_tok, + self.device_id, + HTTPStatus.OK, + { + "auth": { + "type": "m.login.password", + "identifier": {"type": "m.id.user", "user": self.user}, + "password": self.user_pass, + "session": session, + }, + }, + ) + def test_grandfathered_identifier(self) -> None: """Check behaviour without "identifier" dict diff --git a/tests/rest/client/test_device_lists.py b/tests/rest/client/test_device_lists.py deleted file mode 100644 index a8af4e24..00000000 --- a/tests/rest/client/test_device_lists.py +++ /dev/null @@ -1,159 +0,0 @@ -# Copyright 2022 The Matrix.org Foundation C.I.C. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
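# A standalone sketch (not Synapse's implementation) of the behaviour the
# "only_for_reauth" test above checks: with password_config.enabled set to
# "only_for_reauth", password auth is refused for ordinary login but is still
# accepted when completing user-interactive auth, e.g. to delete a device.
def password_auth_allowed(purpose: str, enabled) -> bool:
    """`enabled` may be True, False, or the string "only_for_reauth"."""
    if enabled == "only_for_reauth":
        return purpose == "reauth"
    return bool(enabled)

assert password_auth_allowed("reauth", "only_for_reauth")
assert not password_auth_allowed("login", "only_for_reauth")
assert password_auth_allowed("login", True)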
-from http import HTTPStatus - -from synapse.rest import admin, devices, room, sync -from synapse.rest.client import account, login, register - -from tests import unittest - - -class DeviceListsTestCase(unittest.HomeserverTestCase): - """Tests regarding device list changes.""" - - servlets = [ - admin.register_servlets_for_client_rest_resource, - login.register_servlets, - register.register_servlets, - account.register_servlets, - room.register_servlets, - sync.register_servlets, - devices.register_servlets, - ] - - def test_receiving_local_device_list_changes(self) -> None: - """Tests that a local users that share a room receive each other's device list - changes. - """ - # Register two users - test_device_id = "TESTDEVICE" - alice_user_id = self.register_user("alice", "correcthorse") - alice_access_token = self.login( - alice_user_id, "correcthorse", device_id=test_device_id - ) - - bob_user_id = self.register_user("bob", "ponyponypony") - bob_access_token = self.login(bob_user_id, "ponyponypony") - - # Create a room for them to coexist peacefully in - new_room_id = self.helper.create_room_as( - alice_user_id, is_public=True, tok=alice_access_token - ) - self.assertIsNotNone(new_room_id) - - # Have Bob join the room - self.helper.invite( - new_room_id, alice_user_id, bob_user_id, tok=alice_access_token - ) - self.helper.join(new_room_id, bob_user_id, tok=bob_access_token) - - # Now have Bob initiate an initial sync (in order to get a since token) - channel = self.make_request( - "GET", - "/sync", - access_token=bob_access_token, - ) - self.assertEqual(channel.code, 200, channel.json_body) - next_batch_token = channel.json_body["next_batch"] - - # ...and then an incremental sync. This should block until the sync stream is woken up, - # which we hope will happen as a result of Alice updating their device list. - bob_sync_channel = self.make_request( - "GET", - f"/sync?since={next_batch_token}&timeout=30000", - access_token=bob_access_token, - # Start the request, then continue on. - await_result=False, - ) - - # Have alice update their device list - channel = self.make_request( - "PUT", - f"/devices/{test_device_id}", - { - "display_name": "New Device Name", - }, - access_token=alice_access_token, - ) - self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body) - - # Check that bob's incremental sync contains the updated device list. - # If not, the client would only receive the device list update on the - # *next* sync. - bob_sync_channel.await_result() - self.assertEqual(bob_sync_channel.code, 200, bob_sync_channel.json_body) - - changed_device_lists = bob_sync_channel.json_body.get("device_lists", {}).get( - "changed", [] - ) - self.assertIn(alice_user_id, changed_device_lists, bob_sync_channel.json_body) - - def test_not_receiving_local_device_list_changes(self) -> None: - """Tests a local users DO NOT receive device updates from each other if they do not - share a room. - """ - # Register two users - test_device_id = "TESTDEVICE" - alice_user_id = self.register_user("alice", "correcthorse") - alice_access_token = self.login( - alice_user_id, "correcthorse", device_id=test_device_id - ) - - bob_user_id = self.register_user("bob", "ponyponypony") - bob_access_token = self.login(bob_user_id, "ponyponypony") - - # These users do not share a room. They are lonely. 
- - # Have Bob initiate an initial sync (in order to get a since token) - channel = self.make_request( - "GET", - "/sync", - access_token=bob_access_token, - ) - self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body) - next_batch_token = channel.json_body["next_batch"] - - # ...and then an incremental sync. This should block until the sync stream is woken up, - # which we hope will happen as a result of Alice updating their device list. - bob_sync_channel = self.make_request( - "GET", - f"/sync?since={next_batch_token}&timeout=1000", - access_token=bob_access_token, - # Start the request, then continue on. - await_result=False, - ) - - # Have alice update their device list - channel = self.make_request( - "PUT", - f"/devices/{test_device_id}", - { - "display_name": "New Device Name", - }, - access_token=alice_access_token, - ) - self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body) - - # Check that bob's incremental sync does not contain the updated device list. - bob_sync_channel.await_result() - self.assertEqual( - bob_sync_channel.code, HTTPStatus.OK, bob_sync_channel.json_body - ) - - changed_device_lists = bob_sync_channel.json_body.get("device_lists", {}).get( - "changed", [] - ) - self.assertNotIn( - alice_user_id, changed_device_lists, bob_sync_channel.json_body - ) diff --git a/tests/rest/client/test_devices.py b/tests/rest/client/test_devices.py new file mode 100644 index 00000000..aa982224 --- /dev/null +++ b/tests/rest/client/test_devices.py @@ -0,0 +1,202 @@ +# Copyright 2022 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from http import HTTPStatus + +from twisted.test.proto_helpers import MemoryReactor + +from synapse.api.errors import NotFoundError +from synapse.rest import admin, devices, room, sync +from synapse.rest.client import account, login, register +from synapse.server import HomeServer +from synapse.util import Clock + +from tests import unittest + + +class DeviceListsTestCase(unittest.HomeserverTestCase): + """Tests regarding device list changes.""" + + servlets = [ + admin.register_servlets_for_client_rest_resource, + login.register_servlets, + register.register_servlets, + account.register_servlets, + room.register_servlets, + sync.register_servlets, + devices.register_servlets, + ] + + def test_receiving_local_device_list_changes(self) -> None: + """Tests that a local users that share a room receive each other's device list + changes. 
+ """ + # Register two users + test_device_id = "TESTDEVICE" + alice_user_id = self.register_user("alice", "correcthorse") + alice_access_token = self.login( + alice_user_id, "correcthorse", device_id=test_device_id + ) + + bob_user_id = self.register_user("bob", "ponyponypony") + bob_access_token = self.login(bob_user_id, "ponyponypony") + + # Create a room for them to coexist peacefully in + new_room_id = self.helper.create_room_as( + alice_user_id, is_public=True, tok=alice_access_token + ) + self.assertIsNotNone(new_room_id) + + # Have Bob join the room + self.helper.invite( + new_room_id, alice_user_id, bob_user_id, tok=alice_access_token + ) + self.helper.join(new_room_id, bob_user_id, tok=bob_access_token) + + # Now have Bob initiate an initial sync (in order to get a since token) + channel = self.make_request( + "GET", + "/sync", + access_token=bob_access_token, + ) + self.assertEqual(channel.code, 200, channel.json_body) + next_batch_token = channel.json_body["next_batch"] + + # ...and then an incremental sync. This should block until the sync stream is woken up, + # which we hope will happen as a result of Alice updating their device list. + bob_sync_channel = self.make_request( + "GET", + f"/sync?since={next_batch_token}&timeout=30000", + access_token=bob_access_token, + # Start the request, then continue on. + await_result=False, + ) + + # Have alice update their device list + channel = self.make_request( + "PUT", + f"/devices/{test_device_id}", + { + "display_name": "New Device Name", + }, + access_token=alice_access_token, + ) + self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body) + + # Check that bob's incremental sync contains the updated device list. + # If not, the client would only receive the device list update on the + # *next* sync. + bob_sync_channel.await_result() + self.assertEqual(bob_sync_channel.code, 200, bob_sync_channel.json_body) + + changed_device_lists = bob_sync_channel.json_body.get("device_lists", {}).get( + "changed", [] + ) + self.assertIn(alice_user_id, changed_device_lists, bob_sync_channel.json_body) + + def test_not_receiving_local_device_list_changes(self) -> None: + """Tests a local users DO NOT receive device updates from each other if they do not + share a room. + """ + # Register two users + test_device_id = "TESTDEVICE" + alice_user_id = self.register_user("alice", "correcthorse") + alice_access_token = self.login( + alice_user_id, "correcthorse", device_id=test_device_id + ) + + bob_user_id = self.register_user("bob", "ponyponypony") + bob_access_token = self.login(bob_user_id, "ponyponypony") + + # These users do not share a room. They are lonely. + + # Have Bob initiate an initial sync (in order to get a since token) + channel = self.make_request( + "GET", + "/sync", + access_token=bob_access_token, + ) + self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body) + next_batch_token = channel.json_body["next_batch"] + + # ...and then an incremental sync. This should block until the sync stream is woken up, + # which we hope will happen as a result of Alice updating their device list. + bob_sync_channel = self.make_request( + "GET", + f"/sync?since={next_batch_token}&timeout=1000", + access_token=bob_access_token, + # Start the request, then continue on. 
+ await_result=False, + ) + + # Have alice update their device list + channel = self.make_request( + "PUT", + f"/devices/{test_device_id}", + { + "display_name": "New Device Name", + }, + access_token=alice_access_token, + ) + self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body) + + # Check that bob's incremental sync does not contain the updated device list. + bob_sync_channel.await_result() + self.assertEqual( + bob_sync_channel.code, HTTPStatus.OK, bob_sync_channel.json_body + ) + + changed_device_lists = bob_sync_channel.json_body.get("device_lists", {}).get( + "changed", [] + ) + self.assertNotIn( + alice_user_id, changed_device_lists, bob_sync_channel.json_body + ) + + +class DevicesTestCase(unittest.HomeserverTestCase): + servlets = [ + admin.register_servlets, + login.register_servlets, + sync.register_servlets, + ] + + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + self.handler = hs.get_device_handler() + + @unittest.override_config({"delete_stale_devices_after": 72000000}) + def test_delete_stale_devices(self) -> None: + """Tests that stale devices are automatically removed after a set time of + inactivity. + The configuration is set to delete devices that haven't been used in the past 20h. + """ + # Register a user and creates 2 devices for them. + user_id = self.register_user("user", "password") + tok1 = self.login("user", "password", device_id="abc") + tok2 = self.login("user", "password", device_id="def") + + # Sync them so they have a last_seen value. + self.make_request("GET", "/sync", access_token=tok1) + self.make_request("GET", "/sync", access_token=tok2) + + # Advance half a day and sync again with one of the devices, so that the next + # time the background job runs we don't delete this device (since it will look + # for devices that haven't been used for over an hour). + self.reactor.advance(43200) + self.make_request("GET", "/sync", access_token=tok1) + + # Advance another half a day, and check that the device that has synced still + # exists but the one that hasn't has been removed. + self.reactor.advance(43200) + self.get_success(self.handler.get_device(user_id, "abc")) + self.get_failure(self.handler.get_device(user_id, "def"), NotFoundError) diff --git a/tests/rest/client/test_events.py b/tests/rest/client/test_events.py index 1b1392fa..a9b7db9d 100644 --- a/tests/rest/client/test_events.py +++ b/tests/rest/client/test_events.py @@ -19,6 +19,7 @@ from unittest.mock import Mock from twisted.test.proto_helpers import MemoryReactor import synapse.rest.admin +from synapse.api.constants import EduTypes from synapse.rest.client import events, login, room from synapse.server import HomeServer from synapse.util import Clock @@ -103,7 +104,7 @@ class EventStreamPermissionsTestCase(unittest.HomeserverTestCase): c for c in channel.json_body["chunk"] if not ( - c.get("type") == "m.presence" + c.get("type") == EduTypes.PRESENCE and c["content"].get("user_id") == self.user_id ) ] diff --git a/tests/rest/client/test_groups.py b/tests/rest/client/test_groups.py deleted file mode 100644 index e067cf82..00000000 --- a/tests/rest/client/test_groups.py +++ /dev/null @@ -1,56 +0,0 @@ -# Copyright 2021 The Matrix.org Foundation C.I.C. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from synapse.rest.client import groups, room - -from tests import unittest -from tests.unittest import override_config - - -class GroupsTestCase(unittest.HomeserverTestCase): - user_id = "@alice:test" - room_creator_user_id = "@bob:test" - - servlets = [room.register_servlets, groups.register_servlets] - - @override_config({"enable_group_creation": True}) - def test_rooms_limited_by_visibility(self) -> None: - group_id = "+spqr:test" - - # Alice creates a group - channel = self.make_request("POST", "/create_group", {"localpart": "spqr"}) - self.assertEqual(channel.code, 200, msg=channel.text_body) - self.assertEqual(channel.json_body, {"group_id": group_id}) - - # Bob creates a private room - room_id = self.helper.create_room_as(self.room_creator_user_id, is_public=False) - self.helper.auth_user_id = self.room_creator_user_id - self.helper.send_state( - room_id, "m.room.name", {"name": "bob's secret room"}, tok=None - ) - self.helper.auth_user_id = self.user_id - - # Alice adds the room to her group. - channel = self.make_request( - "PUT", f"/groups/{group_id}/admin/rooms/{room_id}", {} - ) - self.assertEqual(channel.code, 200, msg=channel.text_body) - self.assertEqual(channel.json_body, {}) - - # Alice now tries to retrieve the room list of the space. - channel = self.make_request("GET", f"/groups/{group_id}/rooms") - self.assertEqual(channel.code, 200, msg=channel.text_body) - self.assertEqual( - channel.json_body, {"chunk": [], "total_room_count_estimate": 0} - ) diff --git a/tests/rest/client/test_login.py b/tests/rest/client/test_login.py index 4920468f..f4ea1209 100644 --- a/tests/rest/client/test_login.py +++ b/tests/rest/client/test_login.py @@ -1112,7 +1112,6 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase): self.service = ApplicationService( id="unique_identifier", token="some_token", - hostname="example.com", sender="@asbot:example.com", namespaces={ ApplicationService.NS_USERS: [ @@ -1125,7 +1124,6 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase): self.another_service = ApplicationService( id="another__identifier", token="another_token", - hostname="example.com", sender="@as2bot:example.com", namespaces={ ApplicationService.NS_USERS: [ diff --git a/tests/rest/client/test_mutual_rooms.py b/tests/rest/client/test_mutual_rooms.py index 7b7d283b..a4327f7a 100644 --- a/tests/rest/client/test_mutual_rooms.py +++ b/tests/rest/client/test_mutual_rooms.py @@ -36,12 +36,10 @@ class UserMutualRoomsTest(unittest.HomeserverTestCase): def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: config = self.default_config() - config["update_user_directory"] = True return self.setup_test_homeserver(config=config) def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.store = hs.get_datastores().main - self.handler = hs.get_user_directory_handler() def _get_mutual_rooms(self, token: str, other_user: str) -> FakeChannel: return self.make_request( diff --git a/tests/rest/client/test_notifications.py b/tests/rest/client/test_notifications.py new file mode 100644 index 00000000..700f6587 
--- /dev/null +++ b/tests/rest/client/test_notifications.py @@ -0,0 +1,91 @@ +# Copyright 2022 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from unittest.mock import Mock + +from twisted.test.proto_helpers import MemoryReactor + +import synapse.rest.admin +from synapse.rest.client import login, notifications, receipts, room +from synapse.server import HomeServer +from synapse.util import Clock + +from tests.test_utils import simple_async_mock +from tests.unittest import HomeserverTestCase + + +class HTTPPusherTests(HomeserverTestCase): + servlets = [ + synapse.rest.admin.register_servlets_for_client_rest_resource, + room.register_servlets, + login.register_servlets, + receipts.register_servlets, + notifications.register_servlets, + ] + + def prepare( + self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer + ) -> None: + self.store = homeserver.get_datastores().main + self.module_api = homeserver.get_module_api() + self.event_creation_handler = homeserver.get_event_creation_handler() + self.sync_handler = homeserver.get_sync_handler() + self.auth_handler = homeserver.get_auth_handler() + + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: + # Mock out the calls over federation. 
+ fed_transport_client = Mock(spec=["send_transaction"]) + fed_transport_client.send_transaction = simple_async_mock({}) + + return self.setup_test_homeserver( + federation_transport_client=fed_transport_client, + ) + + def test_notify_for_local_invites(self) -> None: + """ + Local users will get notified for invites + """ + + user_id = self.register_user("user", "pass") + access_token = self.login("user", "pass") + other_user_id = self.register_user("otheruser", "pass") + other_access_token = self.login("otheruser", "pass") + + # Create a room + room = self.helper.create_room_as(user_id, tok=access_token) + + # Check we start with no pushes + channel = self.make_request( + "GET", + "/notifications", + access_token=other_access_token, + ) + self.assertEqual(channel.code, 200, channel.result) + self.assertEqual(len(channel.json_body["notifications"]), 0, channel.json_body) + + # Send an invite + self.helper.invite(room=room, src=user_id, targ=other_user_id, tok=access_token) + + # We should have a notification now + channel = self.make_request( + "GET", + "/notifications", + access_token=other_access_token, + ) + self.assertEqual(channel.code, 200) + self.assertEqual(len(channel.json_body["notifications"]), 1, channel.json_body) + self.assertEqual( + channel.json_body["notifications"][0]["event"]["content"]["membership"], + "invite", + channel.json_body, + ) diff --git a/tests/rest/client/test_register.py b/tests/rest/client/test_register.py index 9aebf173..afb08b27 100644 --- a/tests/rest/client/test_register.py +++ b/tests/rest/client/test_register.py @@ -56,7 +56,6 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): appservice = ApplicationService( as_token, - self.hs.config.server.server_name, id="1234", namespaces={"users": [{"regex": r"@as_user.*", "exclusive": True}]}, sender="@as:test", @@ -80,7 +79,6 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): appservice = ApplicationService( as_token, - self.hs.config.server.server_name, id="1234", namespaces={"users": [{"regex": r"@as_user.*", "exclusive": True}]}, sender="@as:test", diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py index 27dee8f6..62e4db23 100644 --- a/tests/rest/client/test_relations.py +++ b/tests/rest/client/test_relations.py @@ -896,6 +896,7 @@ class BundledAggregationsTestCase(BaseRelationsTestCase): relation_type: str, assertion_callable: Callable[[JsonDict], None], expected_db_txn_for_event: int, + access_token: Optional[str] = None, ) -> None: """ Makes requests to various endpoints which should include bundled aggregations @@ -907,7 +908,9 @@ class BundledAggregationsTestCase(BaseRelationsTestCase): for relation-specific assertions. expected_db_txn_for_event: The number of database transactions which are expected for a call to /event/. + access_token: The access token to user, defaults to self.user_token. 
""" + access_token = access_token or self.user_token def assert_bundle(event_json: JsonDict) -> None: """Assert the expected values of the bundled aggregations.""" @@ -921,7 +924,7 @@ class BundledAggregationsTestCase(BaseRelationsTestCase): channel = self.make_request( "GET", f"/rooms/{self.room}/event/{self.parent_id}", - access_token=self.user_token, + access_token=access_token, ) self.assertEqual(200, channel.code, channel.json_body) assert_bundle(channel.json_body) @@ -932,7 +935,7 @@ class BundledAggregationsTestCase(BaseRelationsTestCase): channel = self.make_request( "GET", f"/rooms/{self.room}/messages?dir=b", - access_token=self.user_token, + access_token=access_token, ) self.assertEqual(200, channel.code, channel.json_body) assert_bundle(self._find_event_in_chunk(channel.json_body["chunk"])) @@ -941,7 +944,7 @@ class BundledAggregationsTestCase(BaseRelationsTestCase): channel = self.make_request( "GET", f"/rooms/{self.room}/context/{self.parent_id}", - access_token=self.user_token, + access_token=access_token, ) self.assertEqual(200, channel.code, channel.json_body) assert_bundle(channel.json_body["event"]) @@ -949,7 +952,7 @@ class BundledAggregationsTestCase(BaseRelationsTestCase): # Request sync. filter = urllib.parse.quote_plus(b'{"room": {"timeline": {"limit": 4}}}') channel = self.make_request( - "GET", f"/sync?filter={filter}", access_token=self.user_token + "GET", f"/sync?filter={filter}", access_token=access_token ) self.assertEqual(200, channel.code, channel.json_body) room_timeline = channel.json_body["rooms"]["join"][self.room]["timeline"] @@ -962,7 +965,7 @@ class BundledAggregationsTestCase(BaseRelationsTestCase): "/search", # Search term matches the parent message. content={"search_categories": {"room_events": {"search_term": "Hi"}}}, - access_token=self.user_token, + access_token=access_token, ) self.assertEqual(200, channel.code, channel.json_body) chunk = [ @@ -995,7 +998,7 @@ class BundledAggregationsTestCase(BaseRelationsTestCase): bundled_aggregations, ) - self._test_bundled_aggregations(RelationTypes.ANNOTATION, assert_annotations, 7) + self._test_bundled_aggregations(RelationTypes.ANNOTATION, assert_annotations, 6) def test_annotation_to_annotation(self) -> None: """Any relation to an annotation should be ignored.""" @@ -1031,36 +1034,66 @@ class BundledAggregationsTestCase(BaseRelationsTestCase): bundled_aggregations, ) - self._test_bundled_aggregations(RelationTypes.REFERENCE, assert_annotations, 7) + self._test_bundled_aggregations(RelationTypes.REFERENCE, assert_annotations, 6) def test_thread(self) -> None: """ Test that threads get correctly bundled. """ - self._send_relation(RelationTypes.THREAD, "m.room.test") - channel = self._send_relation(RelationTypes.THREAD, "m.room.test") + # The root message is from "user", send replies as "user2". + self._send_relation( + RelationTypes.THREAD, "m.room.test", access_token=self.user2_token + ) + channel = self._send_relation( + RelationTypes.THREAD, "m.room.test", access_token=self.user2_token + ) thread_2 = channel.json_body["event_id"] - def assert_thread(bundled_aggregations: JsonDict) -> None: - self.assertEqual(2, bundled_aggregations.get("count")) - self.assertTrue(bundled_aggregations.get("current_user_participated")) - # The latest thread event has some fields that don't matter. 
- self.assert_dict( - { - "content": { - "m.relates_to": { - "event_id": self.parent_id, - "rel_type": RelationTypes.THREAD, - } + # This needs two assertion functions which are identical except for whether + # the current_user_participated flag is True, create a factory for the + # two versions. + def _gen_assert(participated: bool) -> Callable[[JsonDict], None]: + def assert_thread(bundled_aggregations: JsonDict) -> None: + self.assertEqual(2, bundled_aggregations.get("count")) + self.assertEqual( + participated, bundled_aggregations.get("current_user_participated") + ) + # The latest thread event has some fields that don't matter. + self.assert_dict( + { + "content": { + "m.relates_to": { + "event_id": self.parent_id, + "rel_type": RelationTypes.THREAD, + } + }, + "event_id": thread_2, + "sender": self.user2_id, + "type": "m.room.test", }, - "event_id": thread_2, - "sender": self.user_id, - "type": "m.room.test", - }, - bundled_aggregations.get("latest_event"), - ) + bundled_aggregations.get("latest_event"), + ) - self._test_bundled_aggregations(RelationTypes.THREAD, assert_thread, 10) + return assert_thread + + # The "user" sent the root event and is making queries for the bundled + # aggregations: they have participated. + self._test_bundled_aggregations(RelationTypes.THREAD, _gen_assert(True), 8) + # The "user2" sent replies in the thread and is making queries for the + # bundled aggregations: they have participated. + # + # Note that this re-uses some cached values, so the total number of + # queries is much smaller. + self._test_bundled_aggregations( + RelationTypes.THREAD, _gen_assert(True), 2, access_token=self.user2_token + ) + + # A user with no interactions with the thread: they have not participated. + user3_id, user3_token = self._create_user("charlie") + self.helper.join(self.room, user=user3_id, tok=user3_token) + self._test_bundled_aggregations( + RelationTypes.THREAD, _gen_assert(False), 2, access_token=user3_token + ) def test_thread_with_bundled_aggregations_for_latest(self) -> None: """ @@ -1106,7 +1139,7 @@ class BundledAggregationsTestCase(BaseRelationsTestCase): bundled_aggregations["latest_event"].get("unsigned"), ) - self._test_bundled_aggregations(RelationTypes.THREAD, assert_thread, 10) + self._test_bundled_aggregations(RelationTypes.THREAD, assert_thread, 8) def test_nested_thread(self) -> None: """ diff --git a/tests/rest/client/test_retention.py b/tests/rest/client/test_retention.py index 7b8fe6d0..ac9c1133 100644 --- a/tests/rest/client/test_retention.py +++ b/tests/rest/client/test_retention.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from typing import Any, Dict from unittest.mock import Mock from twisted.test.proto_helpers import MemoryReactor @@ -129,7 +130,7 @@ class RetentionTestCase(unittest.HomeserverTestCase): We do this by setting a very long time between purge jobs. """ store = self.hs.get_datastores().main - storage = self.hs.get_storage() + storage_controllers = self.hs.get_storage_controllers() room_id = self.helper.create_room_as(self.user_id, tok=self.token) # Send a first event, which should be filtered out at the end of the test. 
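# A tiny, standalone illustration (not Synapse code) of the closure-factory
# pattern `_gen_assert` uses above: build near-identical assertion callbacks
# that differ only in the expected `current_user_participated` flag.
from typing import Any, Callable, Dict

def make_thread_assertion(participated: bool) -> Callable[[Dict[str, Any]], None]:
    def assert_thread(bundle: Dict[str, Any]) -> None:
        assert bundle["count"] == 2
        assert bundle["current_user_participated"] is participated
    return assert_thread

make_thread_assertion(True)({"count": 2, "current_user_participated": True})
make_thread_assertion(False)({"count": 2, "current_user_participated": False})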
@@ -154,7 +155,7 @@ class RetentionTestCase(unittest.HomeserverTestCase): ) self.assertEqual(2, len(events), "events retrieved from database") filtered_events = self.get_success( - filter_events_for_client(storage, self.user_id, events) + filter_events_for_client(storage_controllers, self.user_id, events) ) # We should only get one event back. @@ -252,16 +253,24 @@ class RetentionNoDefaultPolicyTestCase(unittest.HomeserverTestCase): room.register_servlets, ] - def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: - config = self.default_config() - config["retention"] = { + def default_config(self) -> Dict[str, Any]: + config = super().default_config() + + retention_config = { "enabled": True, } + # Update this config with what's in the default config so that + # override_config works as expected. + retention_config.update(config.get("retention", {})) + config["retention"] = retention_config + + return config + + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: mock_federation_client = Mock(spec=["backfill"]) self.hs = self.setup_test_homeserver( - config=config, federation_client=mock_federation_client, ) return self.hs @@ -295,6 +304,24 @@ class RetentionNoDefaultPolicyTestCase(unittest.HomeserverTestCase): self._test_retention(room_id, expected_code_for_first_event=404) + @unittest.override_config({"retention": {"enabled": False}}) + def test_visibility_when_disabled(self) -> None: + """Retention policies should be ignored when the retention feature is disabled.""" + room_id = self.helper.create_room_as(self.user_id, tok=self.token) + + self.helper.send_state( + room_id=room_id, + event_type=EventTypes.Retention, + body={"max_lifetime": one_day_ms}, + tok=self.token, + ) + + resp = self.helper.send(room_id=room_id, body="test", tok=self.token) + + self.reactor.advance(one_day_ms * 2 / 1000) + + self.get_event(room_id, resp["event_id"]) + def _test_retention( self, room_id: str, expected_code_for_first_event: int = 200 ) -> None: diff --git a/tests/rest/client/test_room_batch.py b/tests/rest/client/test_room_batch.py index 41a1bf6d..9d5cb60d 100644 --- a/tests/rest/client/test_room_batch.py +++ b/tests/rest/client/test_room_batch.py @@ -71,7 +71,6 @@ class RoomBatchTestCase(unittest.HomeserverTestCase): self.appservice = ApplicationService( token="i_am_an_app_service", - hostname="test", id="1234", namespaces={"users": [{"regex": r"@as_user.*", "exclusive": True}]}, # Note: this user does not have to match the regex above @@ -88,7 +87,7 @@ class RoomBatchTestCase(unittest.HomeserverTestCase): def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.clock = clock - self.storage = hs.get_storage() + self._storage_controllers = hs.get_storage_controllers() self.virtual_user_id, _ = self.register_appservice_user( "as_user_potato", self.appservice.token @@ -168,7 +167,9 @@ class RoomBatchTestCase(unittest.HomeserverTestCase): # Fetch the state_groups state_group_map = self.get_success( - self.storage.state.get_state_groups_ids(room_id, historical_event_ids) + self._storage_controllers.state.get_state_groups_ids( + room_id, historical_event_ids + ) ) # We expect all of the historical events to be using the same state_group diff --git a/tests/rest/client/test_rooms.py b/tests/rest/client/test_rooms.py index 9443daa0..f523d89b 100644 --- a/tests/rest/client/test_rooms.py +++ b/tests/rest/client/test_rooms.py @@ -26,6 +26,7 @@ from twisted.test.proto_helpers import MemoryReactor import synapse.rest.admin from 
synapse.api.constants import ( + EduTypes, EventContentFields, EventTypes, Membership, @@ -925,7 +926,7 @@ class RoomJoinTestCase(RoomBase): ) -> bool: return return_value - callback_mock = Mock(side_effect=user_may_join_room) + callback_mock = Mock(side_effect=user_may_join_room, spec=lambda *x: None) self.hs.get_spam_checker()._user_may_join_room_callbacks.append(callback_mock) # Join a first room, without being invited to it. @@ -1116,6 +1117,264 @@ class RoomMessagesTestCase(RoomBase): self.assertEqual(200, channel.code, msg=channel.result["body"]) +class RoomPowerLevelOverridesTestCase(RoomBase): + """Tests that the power levels can be overridden with server config.""" + + user_id = "@sid1:red" + + servlets = [ + admin.register_servlets, + room.register_servlets, + login.register_servlets, + ] + + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + self.admin_user_id = self.register_user("admin", "pass") + self.admin_access_token = self.login("admin", "pass") + + def power_levels(self, room_id: str) -> Dict[str, Any]: + return self.helper.get_state( + room_id, "m.room.power_levels", self.admin_access_token + ) + + def test_default_power_levels_with_room_override(self) -> None: + """ + Create a room, providing power level overrides. + Confirm that the room's power levels reflect the overrides. + + See https://github.com/matrix-org/matrix-spec/issues/492 + - currently we overwrite each key of power_level_content_override + completely. + """ + + room_id = self.helper.create_room_as( + self.user_id, + extra_content={ + "power_level_content_override": {"events": {"custom.event": 0}} + }, + ) + self.assertEqual( + { + "custom.event": 0, + }, + self.power_levels(room_id)["events"], + ) + + @unittest.override_config( + { + "default_power_level_content_override": { + "public_chat": {"events": {"custom.event": 0}}, + } + }, + ) + def test_power_levels_with_server_override(self) -> None: + """ + With a server configured to modify the room-level defaults, + Create a room, without providing any extra power level overrides. + Confirm that the room's power levels reflect the server-level overrides. + + Similar to https://github.com/matrix-org/matrix-spec/issues/492, + we overwrite each key of power_level_content_override completely. + """ + + room_id = self.helper.create_room_as(self.user_id) + self.assertEqual( + { + "custom.event": 0, + }, + self.power_levels(room_id)["events"], + ) + + @unittest.override_config( + { + "default_power_level_content_override": { + "public_chat": { + "events": {"server.event": 0}, + "ban": 13, + }, + } + }, + ) + def test_power_levels_with_server_and_room_overrides(self) -> None: + """ + With a server configured to modify the room-level defaults, + create a room, providing different overrides. + Confirm that the room's power levels reflect both overrides, and + choose the room overrides where they clash. + """ + + room_id = self.helper.create_room_as( + self.user_id, + extra_content={ + "power_level_content_override": {"events": {"room.event": 0}} + }, + ) + + # Room override wins over server config + self.assertEqual( + {"room.event": 0}, + self.power_levels(room_id)["events"], + ) + + # But where there is no room override, server config wins + self.assertEqual(13, self.power_levels(room_id)["ban"]) + + +class RoomPowerLevelOverridesInPracticeTestCase(RoomBase): + """ + Tests that we can really do various otherwise-prohibited actions + based on overriding the power levels in config. 
+ """ + + user_id = "@sid1:red" + + def test_creator_can_post_state_event(self) -> None: + # Given I am the creator of a room + room_id = self.helper.create_room_as(self.user_id) + + # When I send a state event + path = "/rooms/{room_id}/state/custom.event/my_state_key".format( + room_id=urlparse.quote(room_id), + ) + channel = self.make_request("PUT", path, "{}") + + # Then I am allowed + self.assertEqual(200, channel.code, msg=channel.result["body"]) + + def test_normal_user_can_not_post_state_event(self) -> None: + # Given I am a normal member of a room + room_id = self.helper.create_room_as("@some_other_guy:red") + self.helper.join(room=room_id, user=self.user_id) + + # When I send a state event + path = "/rooms/{room_id}/state/custom.event/my_state_key".format( + room_id=urlparse.quote(room_id), + ) + channel = self.make_request("PUT", path, "{}") + + # Then I am not allowed because state events require PL>=50 + self.assertEqual(403, channel.code, msg=channel.result["body"]) + self.assertEqual( + "You don't have permission to post that to the room. " + "user_level (0) < send_level (50)", + channel.json_body["error"], + ) + + @unittest.override_config( + { + "default_power_level_content_override": { + "public_chat": {"events": {"custom.event": 0}}, + } + }, + ) + def test_with_config_override_normal_user_can_post_state_event(self) -> None: + # Given the server has config allowing normal users to post my event type, + # and I am a normal member of a room + room_id = self.helper.create_room_as("@some_other_guy:red") + self.helper.join(room=room_id, user=self.user_id) + + # When I send a state event + path = "/rooms/{room_id}/state/custom.event/my_state_key".format( + room_id=urlparse.quote(room_id), + ) + channel = self.make_request("PUT", path, "{}") + + # Then I am allowed + self.assertEqual(200, channel.code, msg=channel.result["body"]) + + @unittest.override_config( + { + "default_power_level_content_override": { + "public_chat": {"events": {"custom.event": 0}}, + } + }, + ) + def test_any_room_override_defeats_config_override(self) -> None: + # Given the server has config allowing normal users to post my event type + # And I am a normal member of a room + # But the room was created with special permissions + extra_content: Dict[str, Any] = { + "power_level_content_override": {"events": {}}, + } + room_id = self.helper.create_room_as( + "@some_other_guy:red", extra_content=extra_content + ) + self.helper.join(room=room_id, user=self.user_id) + + # When I send a state event + path = "/rooms/{room_id}/state/custom.event/my_state_key".format( + room_id=urlparse.quote(room_id), + ) + channel = self.make_request("PUT", path, "{}") + + # Then I am not allowed + self.assertEqual(403, channel.code, msg=channel.result["body"]) + + @unittest.override_config( + { + "default_power_level_content_override": { + "public_chat": {"events": {"custom.event": 0}}, + } + }, + ) + def test_specific_room_override_defeats_config_override(self) -> None: + # Given the server has config allowing normal users to post my event type, + # and I am a normal member of a room, + # but the room was created with special permissions for this event type + extra_content = { + "power_level_content_override": {"events": {"custom.event": 1}}, + } + room_id = self.helper.create_room_as( + "@some_other_guy:red", extra_content=extra_content + ) + self.helper.join(room=room_id, user=self.user_id) + + # When I send a state event + path = "/rooms/{room_id}/state/custom.event/my_state_key".format( + room_id=urlparse.quote(room_id), 
+ ) + channel = self.make_request("PUT", path, "{}") + + # Then I am not allowed + self.assertEqual(403, channel.code, msg=channel.result["body"]) + self.assertEqual( + "You don't have permission to post that to the room. " + + "user_level (0) < send_level (1)", + channel.json_body["error"], + ) + + @unittest.override_config( + { + "default_power_level_content_override": { + "public_chat": {"events": {"custom.event": 0}}, + "private_chat": None, + "trusted_private_chat": None, + } + }, + ) + def test_config_override_applies_only_to_specific_preset(self) -> None: + # Given the server has config for public_chats, + # and I am a normal member of a private_chat room + room_id = self.helper.create_room_as("@some_other_guy:red", is_public=False) + self.helper.invite(room=room_id, src="@some_other_guy:red", targ=self.user_id) + self.helper.join(room=room_id, user=self.user_id) + + # When I send a state event + path = "/rooms/{room_id}/state/custom.event/my_state_key".format( + room_id=urlparse.quote(room_id), + ) + channel = self.make_request("PUT", path, "{}") + + # Then I am not allowed because the public_chat config does not + # affect this room, because this room is a private_chat + self.assertEqual(403, channel.code, msg=channel.result["body"]) + self.assertEqual( + "You don't have permission to post that to the room. " + + "user_level (0) < send_level (50)", + channel.json_body["error"], + ) + + class RoomInitialSyncTestCase(RoomBase): """Tests /rooms/$room_id/initialSync.""" @@ -1154,7 +1413,7 @@ class RoomInitialSyncTestCase(RoomBase): e["content"]["user_id"]: e for e in channel.json_body["presence"] } self.assertTrue(self.user_id in presence_by_user) - self.assertEqual("m.presence", presence_by_user[self.user_id]["type"]) + self.assertEqual(EduTypes.PRESENCE, presence_by_user[self.user_id]["type"]) class RoomMessageListTestCase(RoomBase): @@ -2598,7 +2857,9 @@ class ThreepidInviteTestCase(unittest.HomeserverTestCase): # Add a mock to the spamchecker callbacks for user_may_send_3pid_invite. Make it # allow everything for now. - mock = Mock(return_value=make_awaitable(True)) + # `spec` argument is needed for this function mock to have `__qualname__`, which + # is needed for `Measure` metrics buried in SpamChecker. + mock = Mock(return_value=make_awaitable(True), spec=lambda *x: None) self.hs.get_spam_checker()._user_may_send_3pid_invite_callbacks.append(mock) # Send a 3PID invite into the room and check that it succeeded. diff --git a/tests/rest/client/test_sendtodevice.py b/tests/rest/client/test_sendtodevice.py index c3942889..6435800f 100644 --- a/tests/rest/client/test_sendtodevice.py +++ b/tests/rest/client/test_sendtodevice.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
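The power-level override tests above all rest on one merge rule: each top-level key of an override replaces the matching key of the defaults wholesale, with the room-supplied power_level_content_override applied after the server's per-preset default_power_level_content_override. The sketch below is not Synapse code, only that layering written out; layer_power_levels and the defaults dict are illustrative assumptions, while the override values are taken from test_power_levels_with_server_and_room_overrides.

from typing import Any, Dict


def layer_power_levels(
    defaults: Dict[str, Any],
    server_override: Dict[str, Any],
    room_override: Dict[str, Any],
) -> Dict[str, Any]:
    """Apply the server's per-preset override, then the room's; each key replaces wholesale."""
    merged = dict(defaults)
    merged.update(server_override)
    merged.update(room_override)
    return merged


# The room's "events" key wins outright, while the server's "ban": 13 survives
# because the room never overrides that key.
merged = layer_power_levels(
    defaults={"ban": 50, "events": {"m.room.name": 50}},
    server_override={"events": {"server.event": 0}, "ban": 13},
    room_override={"events": {"room.event": 0}},
)
assert merged["events"] == {"room.event": 0}
assert merged["ban"] == 13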
+from synapse.api.constants import EduTypes from synapse.rest import admin from synapse.rest.client import login, sendtodevice, sync @@ -139,7 +140,7 @@ class SendToDeviceTestCase(HomeserverTestCase): for i in range(3): self.get_success( federation_registry.on_edu( - "m.direct_to_device", + EduTypes.DIRECT_TO_DEVICE, "remote_server", { "sender": "@user:remote_server", @@ -172,7 +173,7 @@ class SendToDeviceTestCase(HomeserverTestCase): # and we can send more messages self.get_success( federation_registry.on_edu( - "m.direct_to_device", + EduTypes.DIRECT_TO_DEVICE, "remote_server", { "sender": "@user:remote_server", diff --git a/tests/rest/client/test_shadow_banned.py b/tests/rest/client/test_shadow_banned.py index ae5ada3b..d9bd8c4a 100644 --- a/tests/rest/client/test_shadow_banned.py +++ b/tests/rest/client/test_shadow_banned.py @@ -17,7 +17,7 @@ from unittest.mock import Mock, patch from twisted.test.proto_helpers import MemoryReactor import synapse.rest.admin -from synapse.api.constants import EventTypes +from synapse.api.constants import EduTypes, EventTypes from synapse.rest.client import ( directory, login, @@ -226,7 +226,7 @@ class RoomTestCase(_ShadowBannedBase): events[0], [ { - "type": "m.typing", + "type": EduTypes.TYPING, "room_id": room_id, "content": {"user_ids": [self.other_user_id]}, } diff --git a/tests/rest/client/test_sync.py b/tests/rest/client/test_sync.py index 01083376..e3efd1f1 100644 --- a/tests/rest/client/test_sync.py +++ b/tests/rest/client/test_sync.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import json +from http import HTTPStatus from typing import List, Optional from parameterized import parameterized @@ -21,6 +22,7 @@ from twisted.test.proto_helpers import MemoryReactor import synapse.rest.admin from synapse.api.constants import ( + EduTypes, EventContentFields, EventTypes, ReceiptTypes, @@ -485,30 +487,7 @@ class ReadReceiptsTestCase(unittest.HomeserverTestCase): # Test that we didn't override the public read receipt self.assertIsNone(self._get_read_receipt()) - @parameterized.expand( - [ - # Old Element version, expected to send an empty body - ( - "agent1", - "Element/1.2.2 (Linux; U; Android 9; MatrixAndroidSDK_X 0.0.1)", - 200, - ), - # Old SchildiChat version, expected to send an empty body - ("agent2", "SchildiChat/1.2.1 (Android 10)", 200), - # Expected 400: Denies empty body starting at version 1.3+ - ("agent3", "Element/1.3.6 (Android 10)", 400), - ("agent4", "SchildiChat/1.3.6 (Android 11)", 400), - # Contains "Riot": Receipts with empty bodies expected - ("agent5", "Element (Riot.im) (Android 9)", 200), - # Expected 400: Does not contain "Android" - ("agent6", "Element/1.2.1", 400), - # Expected 400: Different format, missing "/" after Element; existing build that should allow empty bodies, but minimal ongoing usage - ("agent7", "Element dbg/1.1.8-dev (Android)", 400), - ] - ) - def test_read_receipt_with_empty_body( - self, name: str, user_agent: str, expected_status_code: int - ) -> None: + def test_read_receipt_with_empty_body_is_rejected(self) -> None: # Send a message as the first user res = self.helper.send(self.room_id, body="hello", tok=self.tok) @@ -517,16 +496,16 @@ class ReadReceiptsTestCase(unittest.HomeserverTestCase): "POST", f"/rooms/{self.room_id}/receipt/m.read/{res['event_id']}", access_token=self.tok2, - custom_headers=[("User-Agent", user_agent)], ) - self.assertEqual(channel.code, expected_status_code) + self.assertEqual(channel.code, 
HTTPStatus.BAD_REQUEST) + self.assertEqual(channel.json_body["errcode"], "M_NOT_JSON", channel.json_body) def _get_read_receipt(self) -> Optional[JsonDict]: """Syncs and returns the read receipt.""" # Checks if event is a read receipt def is_read_receipt(event: JsonDict) -> bool: - return event["type"] == "m.receipt" + return event["type"] == EduTypes.RECEIPT # Sync channel = self.make_request( @@ -678,12 +657,13 @@ class UnreadMessagesTestCase(unittest.HomeserverTestCase): self._check_unread_count(3) # Check that custom events with a body increase the unread counter. - self.helper.send_event( + result = self.helper.send_event( self.room_id, "org.matrix.custom_type", {"body": "hello"}, tok=self.tok2, ) + event_id = result["event_id"] self._check_unread_count(4) # Check that edits don't increase the unread counter. @@ -693,7 +673,10 @@ class UnreadMessagesTestCase(unittest.HomeserverTestCase): content={ "body": "hello", "msgtype": "m.text", - "m.relates_to": {"rel_type": RelationTypes.REPLACE}, + "m.relates_to": { + "rel_type": RelationTypes.REPLACE, + "event_id": event_id, + }, }, tok=self.tok2, ) diff --git a/tests/rest/client/test_typing.py b/tests/rest/client/test_typing.py index d6da5107..61b66d76 100644 --- a/tests/rest/client/test_typing.py +++ b/tests/rest/client/test_typing.py @@ -17,6 +17,7 @@ from twisted.test.proto_helpers import MemoryReactor +from synapse.api.constants import EduTypes from synapse.rest.client import room from synapse.server import HomeServer from synapse.types import UserID @@ -67,7 +68,7 @@ class RoomTypingTestCase(unittest.HomeserverTestCase): events[0], [ { - "type": "m.typing", + "type": EduTypes.TYPING, "room_id": self.room_id, "content": {"user_ids": [self.user_id]}, } diff --git a/tests/rest/client/test_upgrade_room.py b/tests/rest/client/test_upgrade_room.py index c86fc5df..98c1039d 100644 --- a/tests/rest/client/test_upgrade_room.py +++ b/tests/rest/client/test_upgrade_room.py @@ -76,7 +76,7 @@ class UpgradeRoomTest(unittest.HomeserverTestCase): """ Upgrading a room should work fine. """ - # THe user isn't in the room. + # The user isn't in the room. roomless = self.register_user("roomless", "pass") roomless_token = self.login(roomless, "pass") @@ -249,7 +249,9 @@ class UpgradeRoomTest(unittest.HomeserverTestCase): new_space_id = channel.json_body["replacement_room"] - state_ids = self.get_success(self.store.get_current_state_ids(new_space_id)) + state_ids = self.get_success( + self.store.get_partial_current_state_ids(new_space_id) + ) # Ensure the new room is still a space. create_event = self.get_success( @@ -263,3 +265,35 @@ class UpgradeRoomTest(unittest.HomeserverTestCase): self.assertIn((EventTypes.SpaceChild, self.room_id), state_ids) # The child that was removed should not be copied over. self.assertNotIn((EventTypes.SpaceChild, old_room_id), state_ids) + + def test_custom_room_type(self) -> None: + """Test upgrading a room that has a custom room type set.""" + test_room_type = "com.example.my_custom_room_type" + + # Create a room with a custom room type. + room_id = self.helper.create_room_as( + self.creator, + tok=self.creator_token, + extra_content={ + "creation_content": {EventContentFields.ROOM_TYPE: test_room_type} + }, + ) + + # Upgrade the room! 
+ channel = self._upgrade_room(room_id=room_id) + self.assertEqual(200, channel.code, channel.result) + self.assertIn("replacement_room", channel.json_body) + + new_room_id = channel.json_body["replacement_room"] + + state_ids = self.get_success( + self.store.get_partial_current_state_ids(new_room_id) + ) + + # Ensure the new room is the same type as the old room. + create_event = self.get_success( + self.store.get_event(state_ids[(EventTypes.Create, "")]) + ) + self.assertEqual( + create_event.content.get(EventContentFields.ROOM_TYPE), test_room_type + ) diff --git a/tests/rest/media/test_media_retention.py b/tests/rest/media/test_media_retention.py new file mode 100644 index 00000000..14af07c5 --- /dev/null +++ b/tests/rest/media/test_media_retention.py @@ -0,0 +1,321 @@ +# Copyright 2022 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import io +from typing import Iterable, Optional, Tuple + +from twisted.test.proto_helpers import MemoryReactor + +from synapse.rest import admin +from synapse.rest.client import login, register, room +from synapse.server import HomeServer +from synapse.types import UserID +from synapse.util import Clock + +from tests import unittest +from tests.unittest import override_config +from tests.utils import MockClock + + +class MediaRetentionTestCase(unittest.HomeserverTestCase): + + ONE_DAY_IN_MS = 24 * 60 * 60 * 1000 + THIRTY_DAYS_IN_MS = 30 * ONE_DAY_IN_MS + + servlets = [ + room.register_servlets, + login.register_servlets, + register.register_servlets, + admin.register_servlets_for_client_rest_resource, + ] + + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: + # We need to be able to test advancing time in the homeserver, so we + # replace the test homeserver's default clock with a MockClock, which + # supports advancing time. + return self.setup_test_homeserver(clock=MockClock()) + + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + self.remote_server_name = "remote.homeserver" + self.store = hs.get_datastores().main + + # Create a user to upload media with + test_user_id = self.register_user("alice", "password") + + # Inject media (recently accessed, old access, never accessed, old access + # quarantined media) into both the local store and the remote cache, plus + # one additional local media that is marked as protected from quarantine. 
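# A minimal sketch, not Synapse's purge logic: the retention rule these
# fixtures are built to exercise. A media item becomes eligible for purging
# once its last access is older than the configured lifetime, which is why
# quarantined and protected variants are created as well: they must survive
# the purge. The helper name past_lifetime is illustrative only.
ONE_DAY_IN_MS = 24 * 60 * 60 * 1000

def past_lifetime(
    last_access_ms: int, now_ms: int, lifetime_ms: int = 30 * ONE_DAY_IN_MS
) -> bool:
    """True once the media was last touched longer ago than the lifetime."""
    return now_ms - last_access_ms > lifetime_ms

# Last touched on day 1 and checked on day 32 against a 30-day lifetime:
assert past_lifetime(last_access_ms=ONE_DAY_IN_MS, now_ms=32 * ONE_DAY_IN_MS)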
+ media_repository = hs.get_media_repository() + test_media_content = b"example string" + + def _create_media_and_set_attributes( + last_accessed_ms: Optional[int], + is_quarantined: Optional[bool] = False, + is_protected: Optional[bool] = False, + ) -> str: + # "Upload" some media to the local media store + mxc_uri = self.get_success( + media_repository.create_content( + media_type="text/plain", + upload_name=None, + content=io.BytesIO(test_media_content), + content_length=len(test_media_content), + auth_user=UserID.from_string(test_user_id), + ) + ) + + media_id = mxc_uri.split("/")[-1] + + # Set the last recently accessed time for this media + if last_accessed_ms is not None: + self.get_success( + self.store.update_cached_last_access_time( + local_media=(media_id,), + remote_media=(), + time_ms=last_accessed_ms, + ) + ) + + if is_quarantined: + # Mark this media as quarantined + self.get_success( + self.store.quarantine_media_by_id( + server_name=self.hs.config.server.server_name, + media_id=media_id, + quarantined_by="@theadmin:test", + ) + ) + + if is_protected: + # Mark this media as protected from quarantine + self.get_success( + self.store.mark_local_media_as_safe( + media_id=media_id, + safe=True, + ) + ) + + return media_id + + def _cache_remote_media_and_set_attributes( + media_id: str, + last_accessed_ms: Optional[int], + is_quarantined: Optional[bool] = False, + ) -> str: + # Pretend to cache some remote media + self.get_success( + self.store.store_cached_remote_media( + origin=self.remote_server_name, + media_id=media_id, + media_type="text/plain", + media_length=1, + time_now_ms=clock.time_msec(), + upload_name="testfile.txt", + filesystem_id="abcdefg12345", + ) + ) + + # Set the last recently accessed time for this media + if last_accessed_ms is not None: + self.get_success( + hs.get_datastores().main.update_cached_last_access_time( + local_media=(), + remote_media=((self.remote_server_name, media_id),), + time_ms=last_accessed_ms, + ) + ) + + if is_quarantined: + # Mark this media as quarantined + self.get_success( + self.store.quarantine_media_by_id( + server_name=self.remote_server_name, + media_id=media_id, + quarantined_by="@theadmin:test", + ) + ) + + return media_id + + # Start with the local media store + self.local_recently_accessed_media = _create_media_and_set_attributes( + last_accessed_ms=self.THIRTY_DAYS_IN_MS, + ) + self.local_not_recently_accessed_media = _create_media_and_set_attributes( + last_accessed_ms=self.ONE_DAY_IN_MS, + ) + self.local_not_recently_accessed_quarantined_media = ( + _create_media_and_set_attributes( + last_accessed_ms=self.ONE_DAY_IN_MS, + is_quarantined=True, + ) + ) + self.local_not_recently_accessed_protected_media = ( + _create_media_and_set_attributes( + last_accessed_ms=self.ONE_DAY_IN_MS, + is_protected=True, + ) + ) + self.local_never_accessed_media = _create_media_and_set_attributes( + last_accessed_ms=None, + ) + + # And now the remote media store + self.remote_recently_accessed_media = _cache_remote_media_and_set_attributes( + media_id="a", + last_accessed_ms=self.THIRTY_DAYS_IN_MS, + ) + self.remote_not_recently_accessed_media = ( + _cache_remote_media_and_set_attributes( + media_id="b", + last_accessed_ms=self.ONE_DAY_IN_MS, + ) + ) + self.remote_not_recently_accessed_quarantined_media = ( + _cache_remote_media_and_set_attributes( + media_id="c", + last_accessed_ms=self.ONE_DAY_IN_MS, + is_quarantined=True, + ) + ) + # Remote media will always have a "last accessed" attribute, as it would not + # be fetched from the remote 
homeserver unless instigated by a user. + + @override_config( + { + "media_retention": { + # Enable retention for local media + "local_media_lifetime": "30d" + # Cached remote media should not be purged + } + } + ) + def test_local_media_retention(self) -> None: + """ + Tests that local media that has not been accessed recently is purged, while + cached remote media is unaffected. + """ + # Advance 31 days (in seconds) + self.reactor.advance(31 * 24 * 60 * 60) + + # Check that media has been correctly purged. + # Local media accessed <30 days ago should still exist. + # Remote media should be unaffected. + self._assert_if_mxc_uris_purged( + purged=[ + ( + self.hs.config.server.server_name, + self.local_not_recently_accessed_media, + ), + (self.hs.config.server.server_name, self.local_never_accessed_media), + ], + not_purged=[ + (self.hs.config.server.server_name, self.local_recently_accessed_media), + ( + self.hs.config.server.server_name, + self.local_not_recently_accessed_quarantined_media, + ), + ( + self.hs.config.server.server_name, + self.local_not_recently_accessed_protected_media, + ), + (self.remote_server_name, self.remote_recently_accessed_media), + (self.remote_server_name, self.remote_not_recently_accessed_media), + ( + self.remote_server_name, + self.remote_not_recently_accessed_quarantined_media, + ), + ], + ) + + @override_config( + { + "media_retention": { + # Enable retention for cached remote media + "remote_media_lifetime": "30d" + # Local media should not be purged + } + } + ) + def test_remote_media_cache_retention(self) -> None: + """ + Tests that entries from the remote media cache that have not been accessed + recently are purged, while local media is unaffected. + """ + # Advance 31 days (in seconds) + self.reactor.advance(31 * 24 * 60 * 60) + + # Check that media has been correctly purged. + # Local media should be unaffected. + # Remote media accessed <30 days ago should still exist.
+ self._assert_if_mxc_uris_purged( + purged=[ + (self.remote_server_name, self.remote_not_recently_accessed_media), + ], + not_purged=[ + (self.remote_server_name, self.remote_recently_accessed_media), + (self.hs.config.server.server_name, self.local_recently_accessed_media), + ( + self.hs.config.server.server_name, + self.local_not_recently_accessed_media, + ), + ( + self.hs.config.server.server_name, + self.local_not_recently_accessed_quarantined_media, + ), + ( + self.hs.config.server.server_name, + self.local_not_recently_accessed_protected_media, + ), + ( + self.remote_server_name, + self.remote_not_recently_accessed_quarantined_media, + ), + (self.hs.config.server.server_name, self.local_never_accessed_media), + ], + ) + + def _assert_if_mxc_uris_purged( + self, purged: Iterable[Tuple[str, str]], not_purged: Iterable[Tuple[str, str]] + ) -> None: + def _assert_mxc_uri_purge_state( + server_name: str, media_id: str, expect_purged: bool + ) -> None: + """Given an MXC URI, assert whether it has been purged or not.""" + if server_name == self.hs.config.server.server_name: + found_media_dict = self.get_success( + self.store.get_local_media(media_id) + ) + else: + found_media_dict = self.get_success( + self.store.get_cached_remote_media(server_name, media_id) + ) + + mxc_uri = f"mxc://{server_name}/{media_id}" + + if expect_purged: + self.assertIsNone( + found_media_dict, msg=f"{mxc_uri} unexpectedly not purged" + ) + else: + self.assertIsNotNone( + found_media_dict, + msg=f"{mxc_uri} unexpectedly purged", + ) + + # Assert that the given MXC URIs have either been correctly purged or not. + for server_name, media_id in purged: + _assert_mxc_uri_purge_state(server_name, media_id, expect_purged=True) + for server_name, media_id in not_purged: + _assert_mxc_uri_purge_state(server_name, media_id, expect_purged=False) diff --git a/tests/rest/media/v1/test_html_preview.py b/tests/rest/media/v1/test_html_preview.py index 62e30881..ea9e5889 100644 --- a/tests/rest/media/v1/test_html_preview.py +++ b/tests/rest/media/v1/test_html_preview.py @@ -145,7 +145,7 @@ class SummarizeTestCase(unittest.TestCase): ) -class CalcOgTestCase(unittest.TestCase): +class OpenGraphFromHtmlTestCase(unittest.TestCase): if not lxml: skip = "url preview feature requires lxml" @@ -235,6 +235,21 @@ class CalcOgTestCase(unittest.TestCase): self.assertEqual(og, {"og:title": None, "og:description": "Some text."}) + # Another variant is a title with no content. + html = b""" + + + +

[HTML test fixture: a page titled "Title" with an empty body; the markup itself was lost in extraction]
+ + + """ + + tree = decode_body(html, "http://example.com/test.html") + og = parse_html_to_open_graph(tree) + + self.assertEqual(og, {"og:title": "Title", "og:description": "Title"}) + def test_h1_as_title(self) -> None: html = b""" @@ -250,6 +265,26 @@ class CalcOgTestCase(unittest.TestCase): self.assertEqual(og, {"og:title": "Title", "og:description": "Some text."}) + def test_empty_description(self) -> None: + """Description tags with empty content should be ignored.""" + html = b""" + + + + + + + +

[HTML test fixture: a page titled "Title" plus several description meta tags, all empty except one reading "Finally!"; the markup itself was lost in extraction]
+ + + """ + + tree = decode_body(html, "http://example.com/test.html") + og = parse_html_to_open_graph(tree) + + self.assertEqual(og, {"og:title": "Title", "og:description": "Finally!"}) + def test_missing_title_and_broken_h1(self) -> None: html = b""" diff --git a/tests/rest/media/v1/test_url_preview.py b/tests/rest/media/v1/test_url_preview.py index 3b24d0ac..2c321f8d 100644 --- a/tests/rest/media/v1/test_url_preview.py +++ b/tests/rest/media/v1/test_url_preview.py @@ -656,6 +656,41 @@ class URLPreviewTests(unittest.HomeserverTestCase): server.data, ) + def test_nonexistent_image(self) -> None: + """If the preview image doesn't exist, ensure some data is returned.""" + self.lookups["matrix.org"] = [(IPv4Address, "10.1.2.3")] + + end_content = ( + b"""""" + ) + + channel = self.make_request( + "GET", + "preview_url?url=http://matrix.org", + shorthand=False, + await_result=False, + ) + self.pump() + + client = self.reactor.tcpClients[0][2].buildProtocol(None) + server = AccumulatingProtocol() + server.makeConnection(FakeTransport(client, self.reactor)) + client.makeConnection(FakeTransport(server, self.reactor)) + client.dataReceived( + ( + b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\n" + b'Content-Type: text/html; charset="utf8"\r\n\r\n' + ) + % (len(end_content),) + + end_content + ) + + self.pump() + self.assertEqual(channel.code, 200) + + # The image should not be in the result. + self.assertNotIn("og:image", channel.json_body) + def test_data_url(self) -> None: """ Requesting to preview a data URL is not supported. diff --git a/tests/scripts/test_new_matrix_user.py b/tests/scripts/test_new_matrix_user.py index 19a145ee..22f99c6a 100644 --- a/tests/scripts/test_new_matrix_user.py +++ b/tests/scripts/test_new_matrix_user.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
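The two HTML-preview tests above (a title with no body text, and empty description meta tags) both come down to a fallback rule: blank descriptions are skipped and, when nothing usable remains, the title is reused as the description. This is only a sketch of the rule those assertions imply, not the real parse_html_to_open_graph logic, and the helper name pick_description is invented for illustration.

from typing import Optional, Sequence


def pick_description(
    title: Optional[str], candidates: Sequence[Optional[str]]
) -> Optional[str]:
    """Return the first non-empty description candidate, else fall back to the title."""
    for candidate in candidates:
        if candidate:  # skips both None and ""
            return candidate
    return title


# Matches the assertions above: empty metas are skipped so "Finally!" wins,
# and with no usable candidates the title doubles as the description.
assert pick_description("Title", ["", "", "Finally!"]) == "Finally!"
assert pick_description("Title", []) == "Title"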
+from typing import List from unittest.mock import Mock, patch from synapse._scripts.register_new_matrix_user import request_registration @@ -49,8 +50,8 @@ class RegisterTestCase(TestCase): requests.post = post # The fake stdout will be written here - out = [] - err_code = [] + out: List[str] = [] + err_code: List[int] = [] with patch("synapse._scripts.register_new_matrix_user.requests", requests): request_registration( @@ -85,8 +86,8 @@ class RegisterTestCase(TestCase): requests.get = get # The fake stdout will be written here - out = [] - err_code = [] + out: List[str] = [] + err_code: List[int] = [] with patch("synapse._scripts.register_new_matrix_user.requests", requests): request_registration( @@ -137,8 +138,8 @@ class RegisterTestCase(TestCase): requests.post = post # The fake stdout will be written here - out = [] - err_code = [] + out: List[str] = [] + err_code: List[int] = [] with patch("synapse._scripts.register_new_matrix_user.requests", requests): request_registration( diff --git a/tests/server.py b/tests/server.py index 8f30e250..b9f46597 100644 --- a/tests/server.py +++ b/tests/server.py @@ -109,6 +109,17 @@ class FakeChannel: _ip: str = "127.0.0.1" _producer: Optional[Union[IPullProducer, IPushProducer]] = None resource_usage: Optional[ContextResourceUsage] = None + _request: Optional[Request] = None + + @property + def request(self) -> Request: + assert self._request is not None + return self._request + + @request.setter + def request(self, request: Request) -> None: + assert self._request is None + self._request = request @property def json_body(self): @@ -322,6 +333,8 @@ def make_request( channel = FakeChannel(site, reactor, ip=client_ip) req = request(channel, site) + channel.request = req + req.content = BytesIO(content) # Twisted expects to be at the end of the content when parsing the request. req.content.seek(0, SEEK_END) @@ -736,6 +749,7 @@ def setup_test_homeserver( if config is None: config = default_config(name, parse=True) + config.caches.resize_all_caches() config.ldap_enabled = False if "clock" not in kwargs: diff --git a/tests/server_notices/test_resource_limits_server_notices.py b/tests/server_notices/test_resource_limits_server_notices.py index 9ee9509d..07e29788 100644 --- a/tests/server_notices/test_resource_limits_server_notices.py +++ b/tests/server_notices/test_resource_limits_server_notices.py @@ -75,6 +75,9 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase): self._rlsn._server_notices_manager.get_or_create_notice_room_for_user = Mock( return_value=make_awaitable("!something:localhost") ) + self._rlsn._server_notices_manager.maybe_get_notice_room_for_user = Mock( + return_value=make_awaitable("!something:localhost") + ) self._rlsn._store.add_tag_to_room = Mock(return_value=make_awaitable(None)) self._rlsn._store.get_tags_for_room = Mock(return_value=make_awaitable({})) @@ -102,6 +105,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase): ) self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id)) # Would be better to check the content, but once == remove blocking event + self._rlsn._server_notices_manager.maybe_get_notice_room_for_user.assert_called_once() self._send_notice.assert_called_once() def test_maybe_send_server_notice_to_user_remove_blocked_notice_noop(self): @@ -300,7 +304,10 @@ class TestResourceLimitsServerNoticesWithRealRooms(unittest.HomeserverTestCase): hasn't been reached (since it's the only user and the limit is 5), so users shouldn't receive a server notice. 
""" - self.register_user("user", "password") + m = Mock(return_value=make_awaitable(None)) + self._rlsn._server_notices_manager.maybe_get_notice_room_for_user = m + + user_id = self.register_user("user", "password") tok = self.login("user", "password") channel = self.make_request("GET", "/sync?timeout=0", access_token=tok) @@ -309,6 +316,8 @@ class TestResourceLimitsServerNoticesWithRealRooms(unittest.HomeserverTestCase): "rooms", channel.json_body, "Got invites without server notice" ) + m.assert_called_once_with(user_id) + def test_invite_with_notice(self): """Tests that, if the MAU limit is hit, the server notices user invites each user to a room in which it has sent a notice. diff --git a/tests/storage/databases/main/test_events_worker.py b/tests/storage/databases/main/test_events_worker.py index c237a8c7..38963ce4 100644 --- a/tests/storage/databases/main/test_events_worker.py +++ b/tests/storage/databases/main/test_events_worker.py @@ -154,6 +154,31 @@ class EventCacheTestCase(unittest.HomeserverTestCase): # We should have fetched the event from the DB self.assertEqual(ctx.get_resource_usage().evt_db_fetch_count, 1) + def test_event_ref(self): + """Test that we reuse events that are still in memory but have fallen + out of the cache, rather than requesting them from the DB. + """ + + # Reset the event cache + self.store._get_event_cache.clear() + + with LoggingContext("test") as ctx: + # We keep hold of the event event though we never use it. + event = self.get_success(self.store.get_event(self.event_id)) # noqa: F841 + + # We should have fetched the event from the DB + self.assertEqual(ctx.get_resource_usage().evt_db_fetch_count, 1) + + # Reset the event cache + self.store._get_event_cache.clear() + + with LoggingContext("test") as ctx: + self.get_success(self.store.get_event(self.event_id)) + + # Since the event is still in memory we shouldn't have fetched it + # from the DB + self.assertEqual(ctx.get_resource_usage().evt_db_fetch_count, 0) + def test_dedupe(self): """Test that if we request the same event multiple times we only pull it out once. diff --git a/tests/storage/databases/main/test_lock.py b/tests/storage/databases/main/test_lock.py index 74c6224e..3cc2a58d 100644 --- a/tests/storage/databases/main/test_lock.py +++ b/tests/storage/databases/main/test_lock.py @@ -12,6 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. +from twisted.internet import defer, reactor +from twisted.internet.base import ReactorBase +from twisted.internet.defer import Deferred + from synapse.server import HomeServer from synapse.storage.databases.main.lock import _LOCK_TIMEOUT_MS @@ -22,6 +26,56 @@ class LockTestCase(unittest.HomeserverTestCase): def prepare(self, reactor, clock, hs: HomeServer): self.store = hs.get_datastores().main + def test_acquire_contention(self): + # Track the number of tasks holding the lock. + # Should be at most 1. + in_lock = 0 + max_in_lock = 0 + + release_lock: "Deferred[None]" = Deferred() + + async def task(): + nonlocal in_lock + nonlocal max_in_lock + + lock = await self.store.try_acquire_lock("name", "key") + if not lock: + return + + async with lock: + in_lock += 1 + max_in_lock = max(max_in_lock, in_lock) + + # Block to allow other tasks to attempt to take the lock. + await release_lock + + in_lock -= 1 + + # Start 3 tasks. 
+ task1 = defer.ensureDeferred(task()) + task2 = defer.ensureDeferred(task()) + task3 = defer.ensureDeferred(task()) + + # Give the reactor a kick so that the database transaction returns. + self.pump() + + release_lock.callback(None) + + # Run the tasks to completion. + # To work around `Linearizer`s using a different reactor to sleep when + # contended (#12841), we call `runUntilCurrent` on + # `twisted.internet.reactor`, which is a different reactor to that used + # by the homeserver. + assert isinstance(reactor, ReactorBase) + self.get_success(task1) + reactor.runUntilCurrent() + self.get_success(task2) + reactor.runUntilCurrent() + self.get_success(task3) + + # At most one task should have held the lock at a time. + self.assertEqual(max_in_lock, 1) + def test_simple_lock(self): """Test that we can take out a lock and that while we hold it nobody else can take it out. diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py index 1bf93e79..1047ed09 100644 --- a/tests/storage/test_appservice.py +++ b/tests/storage/test_appservice.py @@ -14,7 +14,7 @@ import json import os import tempfile -from typing import List, Optional, cast +from typing import List, cast from unittest.mock import Mock import yaml @@ -149,15 +149,12 @@ class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase): outfile.write(yaml.dump(as_yaml)) self.as_yaml_files.append(as_token) - def _set_state( - self, id: str, state: ApplicationServiceState, txn: Optional[int] = None - ): + def _set_state(self, id: str, state: ApplicationServiceState): return self.db_pool.runOperation( self.engine.convert_param_style( - "INSERT INTO application_services_state(as_id, state, last_txn) " - "VALUES(?,?,?)" + "INSERT INTO application_services_state(as_id, state) VALUES(?,?)" ), - (id, state.value, txn), + (id, state.value), ) def _insert_txn(self, as_id, txn_id, events): @@ -280,17 +277,6 @@ class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase): self.store.complete_appservice_txn(txn_id=txn_id, service=service) ) - res = self.get_success( - self.db_pool.runQuery( - self.engine.convert_param_style( - "SELECT last_txn FROM application_services_state WHERE as_id=?" - ), - (service.id,), - ) - ) - self.assertEqual(1, len(res)) - self.assertEqual(txn_id, res[0][0]) - res = self.get_success( self.db_pool.runQuery( self.engine.convert_param_style( @@ -316,14 +302,13 @@ class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase): res = self.get_success( self.db_pool.runQuery( self.engine.convert_param_style( - "SELECT last_txn, state FROM application_services_state WHERE as_id=?" + "SELECT state FROM application_services_state WHERE as_id=?" 
), (service.id,), ) ) self.assertEqual(1, len(res)) - self.assertEqual(txn_id, res[0][0]) - self.assertEqual(ApplicationServiceState.UP.value, res[0][1]) + self.assertEqual(ApplicationServiceState.UP.value, res[0][0]) res = self.get_success( self.db_pool.runQuery( diff --git a/tests/storage/test_base.py b/tests/storage/test_base.py index a8ffb52c..cce8e75c 100644 --- a/tests/storage/test_base.py +++ b/tests/storage/test_base.py @@ -60,7 +60,7 @@ class SQLBaseStoreTestCase(unittest.TestCase): db = DatabasePool(Mock(), Mock(config=sqlite_config), fake_engine) db._db_pool = self.db_pool - self.datastore = SQLBaseStore(db, None, hs) + self.datastore = SQLBaseStore(db, None, hs) # type: ignore[arg-type] @defer.inlineCallbacks def test_insert_1col(self): diff --git a/tests/storage/test_devices.py b/tests/storage/test_devices.py index bbf079b2..f37505b6 100644 --- a/tests/storage/test_devices.py +++ b/tests/storage/test_devices.py @@ -13,6 +13,7 @@ # limitations under the License. import synapse.api.errors +from synapse.api.constants import EduTypes from tests.unittest import HomeserverTestCase @@ -266,10 +267,12 @@ class DeviceStoreTestCase(HomeserverTestCase): # (This is a temporary arrangement for backwards compatibility!) self.assertEqual(len(device_updates), 2, device_updates) self.assertEqual( - device_updates[0][0], "m.signing_key_update", device_updates[0] + device_updates[0][0], EduTypes.SIGNING_KEY_UPDATE, device_updates[0] ) self.assertEqual( - device_updates[1][0], "org.matrix.signing_key_update", device_updates[1] + device_updates[1][0], + EduTypes.UNSTABLE_SIGNING_KEY_UPDATE, + device_updates[1], ) # Check there are no more device updates left. diff --git a/tests/storage/test_event_chain.py b/tests/storage/test_event_chain.py index 401020fd..a0ce077a 100644 --- a/tests/storage/test_event_chain.py +++ b/tests/storage/test_event_chain.py @@ -393,7 +393,8 @@ class EventChainStoreTestCase(HomeserverTestCase): # We need to persist the events to the events and state_events # tables. persist_events_store._store_event_txn( - txn, [(e, EventContext()) for e in events] + txn, + [(e, EventContext(self.hs.get_storage_controllers())) for e in events], ) # Actually call the function that calculates the auth chain stuff. 
diff --git a/tests/storage/test_event_federation.py b/tests/storage/test_event_federation.py index 645d564d..d92a9ac5 100644 --- a/tests/storage/test_event_federation.py +++ b/tests/storage/test_event_federation.py @@ -58,15 +58,6 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase): (room_id, event_id), ) - txn.execute( - ( - "INSERT INTO event_reference_hashes " - "(event_id, algorithm, hash) " - "VALUES (?, 'sha256', ?)" - ), - (event_id, bytearray(b"ffff")), - ) - for i in range(0, 20): self.get_success( self.store.db_pool.runInteraction("insert", insert_event, i) diff --git a/tests/storage/test_events.py b/tests/storage/test_events.py index ef5e2587..2ff88e64 100644 --- a/tests/storage/test_events.py +++ b/tests/storage/test_events.py @@ -31,7 +31,8 @@ class ExtremPruneTestCase(HomeserverTestCase): def prepare(self, reactor, clock, homeserver): self.state = self.hs.get_state_handler() - self.persistence = self.hs.get_storage().persistence + self._persistence = self.hs.get_storage_controllers().persistence + self._state_storage_controller = self.hs.get_storage_controllers().state self.store = self.hs.get_datastores().main self.register_user("user", "pass") @@ -69,9 +70,9 @@ class ExtremPruneTestCase(HomeserverTestCase): def persist_event(self, event, state=None): """Persist the event, with optional state""" context = self.get_success( - self.state.compute_event_context(event, old_state=state) + self.state.compute_event_context(event, state_ids_before_event=state) ) - self.get_success(self.persistence.persist_event(event, context)) + self.get_success(self._persistence.persist_event(event, context)) def assert_extremities(self, expected_extremities): """Assert the current extremities for the room""" @@ -103,9 +104,11 @@ class ExtremPruneTestCase(HomeserverTestCase): RoomVersions.V6, ) - state_before_gap = self.get_success(self.state.get_current_state(self.room_id)) + state_before_gap = self.get_success( + self._state_storage_controller.get_current_state_ids(self.room_id) + ) - self.persist_event(remote_event_2, state=state_before_gap.values()) + self.persist_event(remote_event_2, state=state_before_gap) # Check the new extremity is just the new remote event. self.assert_extremities([remote_event_2.event_id]) @@ -135,17 +138,20 @@ class ExtremPruneTestCase(HomeserverTestCase): # setting. The state resolution across the old and new event will then # include it, and so the resolved state won't match the new state. state_before_gap = dict( - self.get_success(self.state.get_current_state(self.room_id)) + self.get_success( + self._state_storage_controller.get_current_state_ids(self.room_id) + ) ) state_before_gap.pop(("m.room.history_visibility", "")) context = self.get_success( self.state.compute_event_context( - remote_event_2, old_state=state_before_gap.values() + remote_event_2, + state_ids_before_event=state_before_gap, ) ) - self.get_success(self.persistence.persist_event(remote_event_2, context)) + self.get_success(self._persistence.persist_event(remote_event_2, context)) # Check that we haven't dropped the old extremity. 
self.assert_extremities([self.remote_event_1.event_id, remote_event_2.event_id]) @@ -177,9 +183,11 @@ class ExtremPruneTestCase(HomeserverTestCase): RoomVersions.V6, ) - state_before_gap = self.get_success(self.state.get_current_state(self.room_id)) + state_before_gap = self.get_success( + self._state_storage_controller.get_current_state_ids(self.room_id) + ) - self.persist_event(remote_event_2, state=state_before_gap.values()) + self.persist_event(remote_event_2, state=state_before_gap) # Check the new extremity is just the new remote event. self.assert_extremities([remote_event_2.event_id]) @@ -207,9 +215,11 @@ class ExtremPruneTestCase(HomeserverTestCase): RoomVersions.V6, ) - state_before_gap = self.get_success(self.state.get_current_state(self.room_id)) + state_before_gap = self.get_success( + self._state_storage_controller.get_current_state_ids(self.room_id) + ) - self.persist_event(remote_event_2, state=state_before_gap.values()) + self.persist_event(remote_event_2, state=state_before_gap) # Check the new extremity is just the new remote event. self.assert_extremities([self.remote_event_1.event_id, remote_event_2.event_id]) @@ -247,9 +257,11 @@ class ExtremPruneTestCase(HomeserverTestCase): RoomVersions.V6, ) - state_before_gap = self.get_success(self.state.get_current_state(self.room_id)) + state_before_gap = self.get_success( + self._state_storage_controller.get_current_state_ids(self.room_id) + ) - self.persist_event(remote_event_2, state=state_before_gap.values()) + self.persist_event(remote_event_2, state=state_before_gap) # Check the new extremity is just the new remote event. self.assert_extremities([remote_event_2.event_id]) @@ -289,9 +301,11 @@ class ExtremPruneTestCase(HomeserverTestCase): RoomVersions.V6, ) - state_before_gap = self.get_success(self.state.get_current_state(self.room_id)) + state_before_gap = self.get_success( + self._state_storage_controller.get_current_state_ids(self.room_id) + ) - self.persist_event(remote_event_2, state=state_before_gap.values()) + self.persist_event(remote_event_2, state=state_before_gap) # Check the new extremity is just the new remote event. self.assert_extremities([remote_event_2.event_id, local_message_event_id]) @@ -323,9 +337,11 @@ class ExtremPruneTestCase(HomeserverTestCase): RoomVersions.V6, ) - state_before_gap = self.get_success(self.state.get_current_state(self.room_id)) + state_before_gap = self.get_success( + self._state_storage_controller.get_current_state_ids(self.room_id) + ) - self.persist_event(remote_event_2, state=state_before_gap.values()) + self.persist_event(remote_event_2, state=state_before_gap) # Check the new extremity is just the new remote event. 
self.assert_extremities([local_message_event_id, remote_event_2.event_id]) @@ -340,7 +356,7 @@ class InvalideUsersInRoomCacheTestCase(HomeserverTestCase): def prepare(self, reactor, clock, homeserver): self.state = self.hs.get_state_handler() - self.persistence = self.hs.get_storage().persistence + self._persistence = self.hs.get_storage_controllers().persistence self.store = self.hs.get_datastores().main def test_remote_user_rooms_cache_invalidated(self): @@ -377,7 +393,7 @@ class InvalideUsersInRoomCacheTestCase(HomeserverTestCase): ) context = self.get_success(self.state.compute_event_context(remote_event_1)) - self.get_success(self.persistence.persist_event(remote_event_1, context)) + self.get_success(self._persistence.persist_event(remote_event_1, context)) # Call `get_rooms_for_user` to add the remote user to the cache rooms = self.get_success(self.store.get_rooms_for_user(remote_user)) @@ -424,7 +440,7 @@ class InvalideUsersInRoomCacheTestCase(HomeserverTestCase): ) context = self.get_success(self.state.compute_event_context(remote_event_1)) - self.get_success(self.persistence.persist_event(remote_event_1, context)) + self.get_success(self._persistence.persist_event(remote_event_1, context)) # Call `get_users_in_room` to add the remote user to the cache users = self.get_success(self.store.get_users_in_room(room_id)) diff --git a/tests/storage/test_monthly_active_users.py b/tests/storage/test_monthly_active_users.py index 4c29ad79..e8b4a564 100644 --- a/tests/storage/test_monthly_active_users.py +++ b/tests/storage/test_monthly_active_users.py @@ -407,3 +407,86 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase): self.assertEqual(result[service1], 2) self.assertEqual(result[service2], 1) self.assertEqual(result[native], 1) + + def test_get_monthly_active_users_by_service(self): + # (No users, no filtering) -> empty result + result = self.get_success(self.store.get_monthly_active_users_by_service()) + + self.assertEqual(len(result), 0) + + # (Some users, no filtering) -> non-empty result + appservice1_user1 = "@appservice1_user1:example.com" + appservice2_user1 = "@appservice2_user1:example.com" + service1 = "service1" + service2 = "service2" + self.get_success( + self.store.register_user( + user_id=appservice1_user1, password_hash=None, appservice_id=service1 + ) + ) + self.get_success(self.store.upsert_monthly_active_user(appservice1_user1)) + self.get_success( + self.store.register_user( + user_id=appservice2_user1, password_hash=None, appservice_id=service2 + ) + ) + self.get_success(self.store.upsert_monthly_active_user(appservice2_user1)) + + result = self.get_success(self.store.get_monthly_active_users_by_service()) + + self.assertEqual(len(result), 2) + self.assertIn((service1, appservice1_user1), result) + self.assertIn((service2, appservice2_user1), result) + + # (Some users, end-timestamp filtering) -> non-empty result + appservice1_user2 = "@appservice1_user2:example.com" + timestamp1 = self.reactor.seconds() + self.reactor.advance(5) + timestamp2 = self.reactor.seconds() + self.get_success( + self.store.register_user( + user_id=appservice1_user2, password_hash=None, appservice_id=service1 + ) + ) + self.get_success(self.store.upsert_monthly_active_user(appservice1_user2)) + + result = self.get_success( + self.store.get_monthly_active_users_by_service( + end_timestamp=round(timestamp1 * 1000) + ) + ) + + self.assertEqual(len(result), 2) + self.assertNotIn((service1, appservice1_user2), result) + + # (Some users, start-timestamp filtering) -> non-empty result + 
result = self.get_success( + self.store.get_monthly_active_users_by_service( + start_timestamp=round(timestamp2 * 1000) + ) + ) + + self.assertEqual(len(result), 1) + self.assertIn((service1, appservice1_user2), result) + + # (Some users, full-timestamp filtering) -> non-empty result + native_user1 = "@native_user1:example.com" + native = "native" + timestamp3 = self.reactor.seconds() + self.reactor.advance(100) + self.get_success( + self.store.register_user( + user_id=native_user1, password_hash=None, appservice_id=native + ) + ) + self.get_success(self.store.upsert_monthly_active_user(native_user1)) + + result = self.get_success( + self.store.get_monthly_active_users_by_service( + start_timestamp=round(timestamp2 * 1000), + end_timestamp=round(timestamp3 * 1000), + ) + ) + + self.assertEqual(len(result), 1) + self.assertIn((service1, appservice1_user2), result) diff --git a/tests/storage/test_purge.py b/tests/storage/test_purge.py index 08cc6023..8dfaa055 100644 --- a/tests/storage/test_purge.py +++ b/tests/storage/test_purge.py @@ -31,7 +31,7 @@ class PurgeTests(HomeserverTestCase): self.room_id = self.helper.create_room_as(self.user_id) self.store = hs.get_datastores().main - self.storage = self.hs.get_storage() + self._storage_controllers = self.hs.get_storage_controllers() def test_purge_history(self): """ @@ -51,7 +51,9 @@ class PurgeTests(HomeserverTestCase): # Purge everything before this topological token self.get_success( - self.storage.purge_events.purge_history(self.room_id, token_str, True) + self._storage_controllers.purge_events.purge_history( + self.room_id, token_str, True + ) ) # 1-3 should fail and last will succeed, meaning that 1-3 are deleted @@ -79,7 +81,9 @@ class PurgeTests(HomeserverTestCase): # Purge everything before this topological token f = self.get_failure( - self.storage.purge_events.purge_history(self.room_id, event, True), + self._storage_controllers.purge_events.purge_history( + self.room_id, event, True + ), SynapseError, ) self.assertIn("greater than forward", f.value.args[0]) @@ -98,14 +102,17 @@ class PurgeTests(HomeserverTestCase): first = self.helper.send(self.room_id, body="test1") # Get the current room state. - state_handler = self.hs.get_state_handler() create_event = self.get_success( - state_handler.get_current_state(self.room_id, "m.room.create", "") + self._storage_controllers.state.get_current_state_event( + self.room_id, "m.room.create", "" + ) ) self.assertIsNotNone(create_event) # Purge everything before this topological token - self.get_success(self.storage.purge_events.purge_room(self.room_id)) + self.get_success( + self._storage_controllers.purge_events.purge_room(self.room_id) + ) # The events aren't found. 
self.store._invalidate_get_event_cache(create_event.event_id) diff --git a/tests/storage/test_redaction.py b/tests/storage/test_redaction.py index d8d17ef3..6c4e63b7 100644 --- a/tests/storage/test_redaction.py +++ b/tests/storage/test_redaction.py @@ -31,7 +31,7 @@ class RedactionTestCase(unittest.HomeserverTestCase): def prepare(self, reactor, clock, hs): self.store = hs.get_datastores().main - self.storage = hs.get_storage() + self._storage = hs.get_storage_controllers() self.event_builder_factory = hs.get_event_builder_factory() self.event_creation_handler = hs.get_event_creation_handler() @@ -71,7 +71,7 @@ class RedactionTestCase(unittest.HomeserverTestCase): self.event_creation_handler.create_new_client_event(builder) ) - self.get_success(self.storage.persistence.persist_event(event, context)) + self.get_success(self._storage.persistence.persist_event(event, context)) return event @@ -93,7 +93,7 @@ class RedactionTestCase(unittest.HomeserverTestCase): self.event_creation_handler.create_new_client_event(builder) ) - self.get_success(self.storage.persistence.persist_event(event, context)) + self.get_success(self._storage.persistence.persist_event(event, context)) return event @@ -114,7 +114,7 @@ class RedactionTestCase(unittest.HomeserverTestCase): self.event_creation_handler.create_new_client_event(builder) ) - self.get_success(self.storage.persistence.persist_event(event, context)) + self.get_success(self._storage.persistence.persist_event(event, context)) return event @@ -268,7 +268,7 @@ class RedactionTestCase(unittest.HomeserverTestCase): ) ) - self.get_success(self.storage.persistence.persist_event(event_1, context_1)) + self.get_success(self._storage.persistence.persist_event(event_1, context_1)) event_2, context_2 = self.get_success( self.event_creation_handler.create_new_client_event( @@ -287,7 +287,7 @@ class RedactionTestCase(unittest.HomeserverTestCase): ) ) ) - self.get_success(self.storage.persistence.persist_event(event_2, context_2)) + self.get_success(self._storage.persistence.persist_event(event_2, context_2)) # fetch one of the redactions fetched = self.get_success(self.store.get_event(redaction_event_id1)) @@ -411,7 +411,7 @@ class RedactionTestCase(unittest.HomeserverTestCase): ) self.get_success( - self.storage.persistence.persist_event(redaction_event, context) + self._storage.persistence.persist_event(redaction_event, context) ) # Now lets jump to the future where we have censored the redaction event diff --git a/tests/storage/test_room.py b/tests/storage/test_room.py index 5b011e18..3c79dabc 100644 --- a/tests/storage/test_room.py +++ b/tests/storage/test_room.py @@ -72,7 +72,7 @@ class RoomEventsStoreTestCase(HomeserverTestCase): # Room events need the full datastore, for persist_event() and # get_room_state() self.store = hs.get_datastores().main - self.storage = hs.get_storage() + self._storage_controllers = hs.get_storage_controllers() self.event_factory = hs.get_event_factory() self.room = RoomID.from_string("!abcde:test") @@ -88,7 +88,7 @@ class RoomEventsStoreTestCase(HomeserverTestCase): def inject_room_event(self, **kwargs): self.get_success( - self.storage.persistence.persist_event( + self._storage_controllers.persistence.persist_event( self.event_factory.create_event(room_id=self.room.to_string(), **kwargs) ) ) @@ -101,7 +101,9 @@ class RoomEventsStoreTestCase(HomeserverTestCase): ) state = self.get_success( - self.store.get_current_state(room_id=self.room.to_string()) + self._storage_controllers.state.get_current_state( + 
room_id=self.room.to_string() + ) ) self.assertEqual(1, len(state)) @@ -118,7 +120,9 @@ class RoomEventsStoreTestCase(HomeserverTestCase): ) state = self.get_success( - self.store.get_current_state(room_id=self.room.to_string()) + self._storage_controllers.state.get_current_state( + room_id=self.room.to_string() + ) ) self.assertEqual(1, len(state)) diff --git a/tests/storage/test_room_search.py b/tests/storage/test_room_search.py index 8dfc1e1d..e747c6b5 100644 --- a/tests/storage/test_room_search.py +++ b/tests/storage/test_room_search.py @@ -99,7 +99,9 @@ class EventSearchInsertionTest(HomeserverTestCase): prev_event_ids = self.get_success(store.get_prev_events_for_room(room_id)) prev_event = self.get_success(store.get_event(prev_event_ids[0])) prev_state_map = self.get_success( - self.hs.get_storage().state.get_state_ids_for_event(prev_event_ids[0]) + self.hs.get_storage_controllers().state.get_state_ids_for_event( + prev_event_ids[0] + ) ) event_dict = { diff --git a/tests/storage/test_roommember.py b/tests/storage/test_roommember.py index a2a9c05f..1218786d 100644 --- a/tests/storage/test_roommember.py +++ b/tests/storage/test_roommember.py @@ -34,7 +34,7 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase): room.register_servlets, ] - def prepare(self, reactor: MemoryReactor, clock: Clock, hs: TestHomeServer) -> None: + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: TestHomeServer) -> None: # type: ignore[override] # We can't test the RoomMemberStore on its own without the other event # storage logic diff --git a/tests/storage/test_state.py b/tests/storage/test_state.py index f88f1c55..8043bdbd 100644 --- a/tests/storage/test_state.py +++ b/tests/storage/test_state.py @@ -29,7 +29,7 @@ logger = logging.getLogger(__name__) class StateStoreTestCase(HomeserverTestCase): def prepare(self, reactor, clock, hs): self.store = hs.get_datastores().main - self.storage = hs.get_storage() + self.storage = hs.get_storage_controllers() self.state_datastore = self.storage.state.stores.state self.event_builder_factory = hs.get_event_builder_factory() self.event_creation_handler = hs.get_event_creation_handler() diff --git a/tests/storage/test_user_directory.py b/tests/storage/test_user_directory.py index 7f1964eb..5b60cf52 100644 --- a/tests/storage/test_user_directory.py +++ b/tests/storage/test_user_directory.py @@ -134,7 +134,6 @@ class UserDirectoryInitialPopulationTestcase(HomeserverTestCase): def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: self.appservice = ApplicationService( token="i_am_an_app_service", - hostname="test", id="1234", namespaces={"users": [{"regex": r"@as_user.*", "exclusive": True}]}, sender="@as:test", diff --git a/tests/storage/util/test_partial_state_events_tracker.py b/tests/storage/util/test_partial_state_events_tracker.py index 303e190b..cae14151 100644 --- a/tests/storage/util/test_partial_state_events_tracker.py +++ b/tests/storage/util/test_partial_state_events_tracker.py @@ -17,8 +17,12 @@ from unittest import mock from twisted.internet.defer import CancelledError, ensureDeferred -from synapse.storage.util.partial_state_events_tracker import PartialStateEventsTracker +from synapse.storage.util.partial_state_events_tracker import ( + PartialCurrentStateTracker, + PartialStateEventsTracker, +) +from tests.test_utils import make_awaitable from tests.unittest import TestCase @@ -115,3 +119,56 @@ class PartialStateEventsTrackerTestCase(TestCase): self.tracker.notify_un_partial_stated("event1") self.successResultOf(d2) 
+ + +class PartialCurrentStateTrackerTestCase(TestCase): + def setUp(self) -> None: + self.mock_store = mock.Mock(spec_set=["is_partial_state_room"]) + + self.tracker = PartialCurrentStateTracker(self.mock_store) + + def test_does_not_block_for_full_state_rooms(self): + self.mock_store.is_partial_state_room.return_value = make_awaitable(False) + + self.successResultOf(ensureDeferred(self.tracker.await_full_state("room_id"))) + + def test_blocks_for_partial_room_state(self): + self.mock_store.is_partial_state_room.return_value = make_awaitable(True) + + d = ensureDeferred(self.tracker.await_full_state("room_id")) + + # there should be no result yet + self.assertNoResult(d) + + # notifying that the room has been de-partial-stated should unblock + self.tracker.notify_un_partial_stated("room_id") + self.successResultOf(d) + + def test_un_partial_state_race(self): + # We should correctly handle race between awaiting the state and us + # un-partialling the state + async def is_partial_state_room(events): + self.tracker.notify_un_partial_stated("room_id") + return True + + self.mock_store.is_partial_state_room.side_effect = is_partial_state_room + + self.successResultOf(ensureDeferred(self.tracker.await_full_state("room_id"))) + + def test_cancellation(self): + self.mock_store.is_partial_state_room.return_value = make_awaitable(True) + + d1 = ensureDeferred(self.tracker.await_full_state("room_id")) + self.assertNoResult(d1) + + d2 = ensureDeferred(self.tracker.await_full_state("room_id")) + self.assertNoResult(d2) + + d1.cancel() + self.assertFailure(d1, CancelledError) + + # d2 should still be waiting! + self.assertNoResult(d2) + + self.tracker.notify_un_partial_stated("room_id") + self.successResultOf(d2) diff --git a/tests/test_mau.py b/tests/test_mau.py index 5bbc361a..f14fcb7d 100644 --- a/tests/test_mau.py +++ b/tests/test_mau.py @@ -105,7 +105,6 @@ class TestMauLimit(unittest.HomeserverTestCase): self.store.services_cache.append( ApplicationService( token=as_token, - hostname=self.hs.hostname, id="SomeASID", sender="@as_sender:test", namespaces={"users": [{"regex": "@as_*", "exclusive": True}]}, @@ -251,7 +250,6 @@ class TestMauLimit(unittest.HomeserverTestCase): self.store.services_cache.append( ApplicationService( token=as_token_1, - hostname=self.hs.hostname, id="SomeASID", sender="@as_sender_1:test", namespaces={"users": [{"regex": "@as_1.*", "exclusive": True}]}, @@ -262,7 +260,6 @@ class TestMauLimit(unittest.HomeserverTestCase): self.store.services_cache.append( ApplicationService( token=as_token_2, - hostname=self.hs.hostname, id="AnotherASID", sender="@as_sender_2:test", namespaces={"users": [{"regex": "@as_2.*", "exclusive": True}]}, diff --git a/tests/test_server.py b/tests/test_server.py index f2ffbc89..0f1eb43c 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -13,18 +13,28 @@ # limitations under the License. 
import re +from http import HTTPStatus +from typing import Tuple from twisted.internet.defer import Deferred from twisted.web.resource import Resource from synapse.api.errors import Codes, RedirectException, SynapseError from synapse.config.server import parse_listener_def -from synapse.http.server import DirectServeHtmlResource, JsonResource, OptionsResource -from synapse.http.site import SynapseSite +from synapse.http.server import ( + DirectServeHtmlResource, + DirectServeJsonResource, + JsonResource, + OptionsResource, + cancellable, +) +from synapse.http.site import SynapseRequest, SynapseSite from synapse.logging.context import make_deferred_yieldable +from synapse.types import JsonDict from synapse.util import Clock from tests import unittest +from tests.http.server._base import EndpointCancellationTestHelperMixin from tests.server import ( FakeSite, ThreadedMemoryReactorClock, @@ -363,3 +373,100 @@ class WrapHtmlRequestHandlerTests(unittest.TestCase): self.assertEqual(channel.result["code"], b"200") self.assertNotIn("body", channel.result) + + +class CancellableDirectServeJsonResource(DirectServeJsonResource): + def __init__(self, clock: Clock): + super().__init__() + self.clock = clock + + @cancellable + async def _async_render_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: + await self.clock.sleep(1.0) + return HTTPStatus.OK, {"result": True} + + async def _async_render_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: + await self.clock.sleep(1.0) + return HTTPStatus.OK, {"result": True} + + +class CancellableDirectServeHtmlResource(DirectServeHtmlResource): + ERROR_TEMPLATE = "{code} {msg}" + + def __init__(self, clock: Clock): + super().__init__() + self.clock = clock + + @cancellable + async def _async_render_GET(self, request: SynapseRequest) -> Tuple[int, bytes]: + await self.clock.sleep(1.0) + return HTTPStatus.OK, b"ok" + + async def _async_render_POST(self, request: SynapseRequest) -> Tuple[int, bytes]: + await self.clock.sleep(1.0) + return HTTPStatus.OK, b"ok" + + +class DirectServeJsonResourceCancellationTests(EndpointCancellationTestHelperMixin): + """Tests for `DirectServeJsonResource` cancellation.""" + + def setUp(self): + self.reactor = ThreadedMemoryReactorClock() + self.clock = Clock(self.reactor) + self.resource = CancellableDirectServeJsonResource(self.clock) + self.site = FakeSite(self.resource, self.reactor) + + def test_cancellable_disconnect(self) -> None: + """Test that handlers with the `@cancellable` flag can be cancelled.""" + channel = make_request( + self.reactor, self.site, "GET", "/sleep", await_result=False + ) + self._test_disconnect( + self.reactor, + channel, + expect_cancellation=True, + expected_body={"error": "Request cancelled", "errcode": Codes.UNKNOWN}, + ) + + def test_uncancellable_disconnect(self) -> None: + """Test that handlers without the `@cancellable` flag cannot be cancelled.""" + channel = make_request( + self.reactor, self.site, "POST", "/sleep", await_result=False + ) + self._test_disconnect( + self.reactor, + channel, + expect_cancellation=False, + expected_body={"result": True}, + ) + + +class DirectServeHtmlResourceCancellationTests(EndpointCancellationTestHelperMixin): + """Tests for `DirectServeHtmlResource` cancellation.""" + + def setUp(self): + self.reactor = ThreadedMemoryReactorClock() + self.clock = Clock(self.reactor) + self.resource = CancellableDirectServeHtmlResource(self.clock) + self.site = FakeSite(self.resource, self.reactor) + + def test_cancellable_disconnect(self) -> None: + 
"""Test that handlers with the `@cancellable` flag can be cancelled.""" + channel = make_request( + self.reactor, self.site, "GET", "/sleep", await_result=False + ) + self._test_disconnect( + self.reactor, + channel, + expect_cancellation=True, + expected_body=b"499 Request cancelled", + ) + + def test_uncancellable_disconnect(self) -> None: + """Test that handlers without the `@cancellable` flag cannot be cancelled.""" + channel = make_request( + self.reactor, self.site, "POST", "/sleep", await_result=False + ) + self._test_disconnect( + self.reactor, channel, expect_cancellation=False, expected_body=b"ok" + ) diff --git a/tests/test_state.py b/tests/test_state.py index e4baa691..95f81beb 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -88,6 +88,9 @@ class _DummyStore: return groups + async def get_state_ids_for_group(self, state_group, state_filter=None): + return self._group_to_state[state_group] + async def store_state_group( self, event_id, room_id, prev_group, delta_ids, current_state_ids ): @@ -126,6 +129,19 @@ class _DummyStore: async def get_room_version_id(self, room_id): return RoomVersions.V1.identifier + async def get_state_group_for_events(self, event_ids): + res = {} + for event in event_ids: + res[event] = self._event_to_state_group[event] + return res + + async def get_state_for_groups(self, groups): + res = {} + for group in groups: + state = self._group_to_state[group] + res[group] = state + return res + class DictObj(dict): def __init__(self, **kwargs): @@ -163,12 +179,12 @@ class Graph: class StateTestCase(unittest.TestCase): def setUp(self): self.dummy_store = _DummyStore() - storage = Mock(main=self.dummy_store, state=self.dummy_store) + storage_controllers = Mock(main=self.dummy_store, state=self.dummy_store) hs = Mock( spec_set=[ "config", "get_datastores", - "get_storage", + "get_storage_controllers", "get_auth", "get_state_handler", "get_clock", @@ -183,7 +199,7 @@ class StateTestCase(unittest.TestCase): hs.get_clock.return_value = MockClock() hs.get_auth.return_value = Auth(hs) hs.get_state_resolution_handler = lambda: StateResolutionHandler(hs) - hs.get_storage.return_value = storage + hs.get_storage_controllers.return_value = storage_controllers self.state = StateHandler(hs) self.event_id = 0 @@ -426,7 +442,12 @@ class StateTestCase(unittest.TestCase): ] context = yield defer.ensureDeferred( - self.state.compute_event_context(event, old_state=old_state) + self.state.compute_event_context( + event, + state_ids_before_event={ + (e.type, e.state_key): e.event_id for e in old_state + }, + ) ) prev_state_ids = yield defer.ensureDeferred(context.get_prev_state_ids()) @@ -451,7 +472,12 @@ class StateTestCase(unittest.TestCase): ] context = yield defer.ensureDeferred( - self.state.compute_event_context(event, old_state=old_state) + self.state.compute_event_context( + event, + state_ids_before_event={ + (e.type, e.state_key): e.event_id for e in old_state + }, + ) ) prev_state_ids = yield defer.ensureDeferred(context.get_prev_state_ids()) diff --git a/tests/test_types.py b/tests/test_types.py index 80888a74..0b10dae8 100644 --- a/tests/test_types.py +++ b/tests/test_types.py @@ -13,7 +13,7 @@ # limitations under the License. 
from synapse.api.errors import SynapseError -from synapse.types import GroupID, RoomAlias, UserID, map_username_to_mxid_localpart +from synapse.types import RoomAlias, UserID, map_username_to_mxid_localpart from tests import unittest @@ -62,25 +62,6 @@ class RoomAliasTestCase(unittest.HomeserverTestCase): self.assertFalse(RoomAlias.is_valid(id_string)) -class GroupIDTestCase(unittest.TestCase): - def test_parse(self): - group_id = GroupID.from_string("+group/=_-.123:my.domain") - self.assertEqual("group/=_-.123", group_id.localpart) - self.assertEqual("my.domain", group_id.domain) - - def test_validate(self): - bad_ids = ["$badsigil:domain", "+:empty"] + [ - "+group" + c + ":domain" for c in "A%?æ£" - ] - for id_string in bad_ids: - try: - GroupID.from_string(id_string) - self.fail("Parsing '%s' should raise exception" % id_string) - except SynapseError as exc: - self.assertEqual(400, exc.code) - self.assertEqual("M_INVALID_PARAM", exc.errcode) - - class MapUsernameTestCase(unittest.TestCase): def testPassThrough(self): self.assertEqual(map_username_to_mxid_localpart("test1234"), "test1234") diff --git a/tests/test_utils/event_injection.py b/tests/test_utils/event_injection.py index c654e36e..8027c7a8 100644 --- a/tests/test_utils/event_injection.py +++ b/tests/test_utils/event_injection.py @@ -70,7 +70,7 @@ async def inject_event( """ event, context = await create_event(hs, room_version, prev_event_ids, **kwargs) - persistence = hs.get_storage().persistence + persistence = hs.get_storage_controllers().persistence assert persistence is not None await persistence.persist_event(event, context) diff --git a/tests/test_visibility.py b/tests/test_visibility.py index d0230f9e..f338af6c 100644 --- a/tests/test_visibility.py +++ b/tests/test_visibility.py @@ -34,7 +34,7 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase): super(FilterEventsForServerTestCase, self).setUp() self.event_creation_handler = self.hs.get_event_creation_handler() self.event_builder_factory = self.hs.get_event_builder_factory() - self.storage = self.hs.get_storage() + self._storage_controllers = self.hs.get_storage_controllers() self.get_success(create_room(self.hs, TEST_ROOM_ID, "@someone:ROOM")) @@ -60,7 +60,9 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase): events_to_filter.append(evt) filtered = self.get_success( - filter_events_for_server(self.storage, "test_server", events_to_filter) + filter_events_for_server( + self._storage_controllers, "test_server", events_to_filter + ) ) # the result should be 5 redacted events, and 5 unredacted events. @@ -80,7 +82,9 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase): outlier = self._inject_outlier() self.assertEqual( self.get_success( - filter_events_for_server(self.storage, "remote_hs", [outlier]) + filter_events_for_server( + self._storage_controllers, "remote_hs", [outlier] + ) ), [outlier], ) @@ -89,7 +93,9 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase): evt = self._inject_message("@unerased:local_hs") filtered = self.get_success( - filter_events_for_server(self.storage, "remote_hs", [outlier, evt]) + filter_events_for_server( + self._storage_controllers, "remote_hs", [outlier, evt] + ) ) self.assertEqual(len(filtered), 2, f"expected 2 results, got: {filtered}") self.assertEqual(filtered[0], outlier) @@ -99,7 +105,9 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase): # ... 
but other servers should only be able to see the outlier (the other should # be redacted) filtered = self.get_success( - filter_events_for_server(self.storage, "other_server", [outlier, evt]) + filter_events_for_server( + self._storage_controllers, "other_server", [outlier, evt] + ) ) self.assertEqual(filtered[0], outlier) self.assertEqual(filtered[1].event_id, evt.event_id) @@ -132,7 +140,9 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase): # ... and the filtering happens. filtered = self.get_success( - filter_events_for_server(self.storage, "test_server", events_to_filter) + filter_events_for_server( + self._storage_controllers, "test_server", events_to_filter + ) ) for i in range(0, len(events_to_filter)): @@ -168,7 +178,9 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase): event, context = self.get_success( self.event_creation_handler.create_new_client_event(builder) ) - self.get_success(self.storage.persistence.persist_event(event, context)) + self.get_success( + self._storage_controllers.persistence.persist_event(event, context) + ) return event def _inject_room_member( @@ -194,7 +206,9 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase): self.event_creation_handler.create_new_client_event(builder) ) - self.get_success(self.storage.persistence.persist_event(event, context)) + self.get_success( + self._storage_controllers.persistence.persist_event(event, context) + ) return event def _inject_message( @@ -216,7 +230,9 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase): self.event_creation_handler.create_new_client_event(builder) ) - self.get_success(self.storage.persistence.persist_event(event, context)) + self.get_success( + self._storage_controllers.persistence.persist_event(event, context) + ) return event def _inject_outlier(self) -> EventBase: @@ -234,7 +250,9 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase): event = self.get_success(builder.build(prev_event_ids=[], auth_event_ids=[])) event.internal_metadata.outlier = True self.get_success( - self.storage.persistence.persist_event(event, EventContext.for_outlier()) + self._storage_controllers.persistence.persist_event( + event, EventContext.for_outlier(self._storage_controllers) + ) ) return event @@ -291,7 +309,9 @@ class FilterEventsForClientTestCase(unittest.FederatingHomeserverTestCase): self.assertEqual( self.get_success( filter_events_for_client( - self.hs.get_storage(), "@user:test", [invite_event, reject_event] + self.hs.get_storage_controllers(), + "@user:test", + [invite_event, reject_event], ) ), [invite_event, reject_event], @@ -301,7 +321,9 @@ class FilterEventsForClientTestCase(unittest.FederatingHomeserverTestCase): self.assertEqual( self.get_success( filter_events_for_client( - self.hs.get_storage(), "@other:test", [invite_event, reject_event] + self.hs.get_storage_controllers(), + "@other:test", + [invite_event, reject_event], ) ), [], diff --git a/tests/unittest.py b/tests/unittest.py index 9afa68c1..e7f255b4 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -831,7 +831,7 @@ class FederatingHomeserverTestCase(HomeserverTestCase): self.site, method=method, path=path, - content=content or "", + content=content if content is not None else "", shorthand=False, await_result=await_result, custom_headers=custom_headers, diff --git a/tests/util/test_lrucache.py b/tests/util/test_lrucache.py index 321fc177..67173a4f 100644 --- a/tests/util/test_lrucache.py +++ b/tests/util/test_lrucache.py @@ -14,8 +14,9 @@ from typing import 
List -from unittest.mock import Mock +from unittest.mock import Mock, patch +from synapse.metrics.jemalloc import JemallocStats from synapse.util.caches.lrucache import LruCache, setup_expire_lru_cache_entries from synapse.util.caches.treecache import TreeCache @@ -316,3 +317,58 @@ class TimeEvictionTestCase(unittest.HomeserverTestCase): self.assertEqual(cache.get("key1"), None) self.assertEqual(cache.get("key2"), 3) + + +class MemoryEvictionTestCase(unittest.HomeserverTestCase): + @override_config( + { + "caches": { + "cache_autotuning": { + "max_cache_memory_usage": "700M", + "target_cache_memory_usage": "500M", + "min_cache_ttl": "5m", + } + } + } + ) + @patch("synapse.util.caches.lrucache.get_jemalloc_stats") + def test_evict_memory(self, jemalloc_interface) -> None: + mock_jemalloc_class = Mock(spec=JemallocStats) + jemalloc_interface.return_value = mock_jemalloc_class + + # set the return value of get_stat() to be greater than max_cache_memory_usage + mock_jemalloc_class.get_stat.return_value = 924288000 + + setup_expire_lru_cache_entries(self.hs) + cache = LruCache(4, clock=self.hs.get_clock()) + + cache["key1"] = 1 + cache["key2"] = 2 + + # advance the reactor less than the min_cache_ttl + self.reactor.advance(60 * 2) + + # our items should still be in the cache + self.assertEqual(cache.get("key1"), 1) + self.assertEqual(cache.get("key2"), 2) + + # advance the reactor past the min_cache_ttl + self.reactor.advance(60 * 6) + + # the items should be cleared from cache + self.assertEqual(cache.get("key1"), None) + self.assertEqual(cache.get("key2"), None) + + # add more stuff to caches + cache["key1"] = 1 + cache["key2"] = 2 + + # set the return value of get_stat() to be lower than target_cache_memory_usage + mock_jemalloc_class.get_stat.return_value = 10000 + + # advance the reactor past the min_cache_ttl + self.reactor.advance(60 * 6) + + # the items should still be in the cache + self.assertEqual(cache.get("key1"), 1) + self.assertEqual(cache.get("key2"), 2) diff --git a/tests/utils.py b/tests/utils.py index d4ba3a9b..3059c453 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -264,7 +264,7 @@ class MockClock: async def create_room(hs, room_id: str, creator_id: str): """Creates and persist a creation event for the given room""" - persistence_store = hs.get_storage().persistence + persistence_store = hs.get_storage_controllers().persistence store = hs.get_datastores().main event_builder_factory = hs.get_event_builder_factory() event_creation_handler = hs.get_event_creation_handler() -- cgit v1.2.3
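
A minimal illustrative sketch, not part of the upstream patch: it exercises the storage-controllers accessors that the hunks above migrate the tests to, where `hs.get_storage()` becomes `hs.get_storage_controllers()`, events go through its `persistence` controller, and current room state is read through its `state` controller. The test class name, user, and password below are assumptions for illustration; the accessor and method names (`get_storage_controllers`, `persistence`, `state.get_current_state_event`) are taken from the diff itself.

# Illustrative sketch only -- not part of the patch above.
from synapse.rest import admin
from synapse.rest.client import login, room

from tests import unittest


class StorageControllersSketchTestCase(unittest.HomeserverTestCase):
    servlets = [
        admin.register_servlets,
        login.register_servlets,
        room.register_servlets,
    ]

    def prepare(self, reactor, clock, hs):
        self.store = hs.get_datastores().main
        # Replaces the pre-1.61 `hs.get_storage()` accessor seen on the "-" lines above.
        self._storage_controllers = hs.get_storage_controllers()

    def test_read_create_event(self):
        user_id = self.register_user("alice", "pass")
        tok = self.login("alice", "pass")
        room_id = self.helper.create_room_as(user_id, tok=tok)

        # Event persistence now hangs off the persistence controller...
        self.assertIsNotNone(self._storage_controllers.persistence)

        # ...and current state is read through the state controller instead of
        # the main datastore's get_current_state().
        create_event = self.get_success(
            self._storage_controllers.state.get_current_state_event(
                room_id, "m.room.create", ""
            )
        )
        self.assertIsNotNone(create_event)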