summaryrefslogtreecommitdiff
path: root/tests
diff options
context:
space:
mode:
Diffstat (limited to 'tests')
-rw-r--r--tests/api/test_auth.py12
-rw-r--r--tests/appservice/test_api.py18
-rw-r--r--tests/crypto/test_keyring.py53
-rw-r--r--tests/handlers/test_appservice.py125
-rw-r--r--tests/handlers/test_cas.py17
-rw-r--r--tests/handlers/test_device.py57
-rw-r--r--tests/handlers/test_federation.py2
-rw-r--r--tests/handlers/test_presence.py600
-rw-r--r--tests/http/federation/test_matrix_federation_agent.py323
-rw-r--r--tests/logging/test_remote_handler.py12
-rw-r--r--tests/push/test_email.py55
-rw-r--r--tests/replication/storage/_base.py17
-rw-r--r--tests/replication/storage/test_events.py53
-rw-r--r--tests/replication/tcp/streams/test_events.py10
-rw-r--r--tests/replication/tcp/streams/test_to_device.py2
-rw-r--r--tests/replication/test_federation_sender_shard.py2
-rw-r--r--tests/rest/admin/test_federation.py6
-rw-r--r--tests/rest/admin/test_room.py159
-rw-r--r--tests/rest/admin/test_server_notice.py20
-rw-r--r--tests/rest/admin/test_user.py26
-rw-r--r--tests/rest/client/test_account.py15
-rw-r--r--tests/rest/client/test_login.py8
-rw-r--r--tests/rest/client/test_receipts.py221
-rw-r--r--tests/rest/client/test_register.py6
-rw-r--r--tests/rest/client/test_rooms.py6
-rw-r--r--tests/rest/client/test_sync.py154
-rw-r--r--tests/storage/databases/main/test_lock.py2
-rw-r--r--tests/storage/test_cleanup_extrems.py14
-rw-r--r--tests/storage/test_event_chain.py6
-rw-r--r--tests/storage/test_event_federation.py6
-rw-r--r--tests/storage/test_keys.py137
-rw-r--r--tests/storage/test_profile.py4
-rw-r--r--tests/storage/test_registration.py48
-rw-r--r--tests/storage/test_txn_limit.py2
-rw-r--r--tests/storage/test_user_filters.py4
-rw-r--r--tests/test_federation.py26
-rw-r--r--tests/test_visibility.py8
-rw-r--r--tests/unittest.py26
-rw-r--r--tests/util/caches/test_descriptors.py4
39 files changed, 1721 insertions, 545 deletions
diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py
index dcd01d56..e00d7215 100644
--- a/tests/api/test_auth.py
+++ b/tests/api/test_auth.py
@@ -188,8 +188,11 @@ class AuthTestCase(unittest.HomeserverTestCase):
)
app_service.is_interested_in_user = Mock(return_value=True)
self.store.get_app_service_by_token = Mock(return_value=app_service)
- # This just needs to return a truth-y value.
- self.store.get_user_by_id = AsyncMock(return_value={"is_guest": False})
+
+ class FakeUserInfo:
+ is_guest = False
+
+ self.store.get_user_by_id = AsyncMock(return_value=FakeUserInfo())
self.store.get_user_by_access_token = AsyncMock(return_value=None)
request = Mock(args={})
@@ -341,7 +344,10 @@ class AuthTestCase(unittest.HomeserverTestCase):
)
def test_get_guest_user_from_macaroon(self) -> None:
- self.store.get_user_by_id = AsyncMock(return_value={"is_guest": True})
+ class FakeUserInfo:
+ is_guest = True
+
+ self.store.get_user_by_id = AsyncMock(return_value=FakeUserInfo())
self.store.get_user_by_access_token = AsyncMock(return_value=None)
user_id = "@baldrick:matrix.org"
diff --git a/tests/appservice/test_api.py b/tests/appservice/test_api.py
index 75fb5fae..366b6fd5 100644
--- a/tests/appservice/test_api.py
+++ b/tests/appservice/test_api.py
@@ -76,7 +76,7 @@ class ApplicationServiceApiTestCase(unittest.HomeserverTestCase):
headers: Mapping[Union[str, bytes], Sequence[Union[str, bytes]]],
) -> List[JsonDict]:
# Ensure the access token is passed as a header.
- if not headers or not headers.get("Authorization"):
+ if not headers or not headers.get(b"Authorization"):
raise RuntimeError("Access token not provided")
# ... and not as a query param
if b"access_token" in args:
@@ -84,7 +84,9 @@ class ApplicationServiceApiTestCase(unittest.HomeserverTestCase):
"Access token should not be passed as a query param."
)
- self.assertEqual(headers.get("Authorization"), [f"Bearer {TOKEN}"])
+ self.assertEqual(
+ headers.get(b"Authorization"), [f"Bearer {TOKEN}".encode()]
+ )
self.request_url = url
if url == URL_USER:
return SUCCESS_RESULT_USER
@@ -152,11 +154,13 @@ class ApplicationServiceApiTestCase(unittest.HomeserverTestCase):
# Ensure the access token is passed as a both a query param and in the headers.
if not args.get(b"access_token"):
raise RuntimeError("Access token should be provided in query params.")
- if not headers or not headers.get("Authorization"):
+ if not headers or not headers.get(b"Authorization"):
raise RuntimeError("Access token should be provided in auth headers.")
self.assertEqual(args.get(b"access_token"), TOKEN)
- self.assertEqual(headers.get("Authorization"), [f"Bearer {TOKEN}"])
+ self.assertEqual(
+ headers.get(b"Authorization"), [f"Bearer {TOKEN}".encode()]
+ )
self.request_url = url
if url == URL_USER:
return SUCCESS_RESULT_USER
@@ -208,10 +212,12 @@ class ApplicationServiceApiTestCase(unittest.HomeserverTestCase):
headers: Mapping[Union[str, bytes], Sequence[Union[str, bytes]]],
) -> JsonDict:
# Ensure the access token is passed as both a header and query arg.
- if not headers.get("Authorization"):
+ if not headers.get(b"Authorization"):
raise RuntimeError("Access token not provided")
- self.assertEqual(headers.get("Authorization"), [f"Bearer {TOKEN}"])
+ self.assertEqual(
+ headers.get(b"Authorization"), [f"Bearer {TOKEN}".encode()]
+ )
return RESPONSE
# We assign to a method, which mypy doesn't like.
diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py
index f93ba5d4..c5700771 100644
--- a/tests/crypto/test_keyring.py
+++ b/tests/crypto/test_keyring.py
@@ -13,7 +13,7 @@
# limitations under the License.
import time
from typing import Any, Dict, List, Optional, cast
-from unittest.mock import AsyncMock, Mock
+from unittest.mock import Mock
import attr
import canonicaljson
@@ -189,23 +189,24 @@ class KeyringTestCase(unittest.HomeserverTestCase):
kr = keyring.Keyring(self.hs)
key1 = signedjson.key.generate_signing_key("1")
- r = self.hs.get_datastores().main.store_server_keys_json(
+ r = self.hs.get_datastores().main.store_server_keys_response(
"server9",
- get_key_id(key1),
from_server="test",
- ts_now_ms=int(time.time() * 1000),
- ts_expires_ms=1000,
+ ts_added_ms=int(time.time() * 1000),
+ verify_keys={
+ get_key_id(key1): FetchKeyResult(
+ verify_key=get_verify_key(key1), valid_until_ts=1000
+ )
+ },
# The entire response gets signed & stored, just include the bits we
# care about.
- key_json_bytes=canonicaljson.encode_canonical_json(
- {
- "verify_keys": {
- get_key_id(key1): {
- "key": encode_verify_key_base64(get_verify_key(key1))
- }
+ response_json={
+ "verify_keys": {
+ get_key_id(key1): {
+ "key": encode_verify_key_base64(get_verify_key(key1))
}
}
- ),
+ },
)
self.get_success(r)
@@ -285,34 +286,6 @@ class KeyringTestCase(unittest.HomeserverTestCase):
d = kr.verify_json_for_server(self.hs.hostname, json1, 0)
self.get_success(d)
- def test_verify_json_for_server_with_null_valid_until_ms(self) -> None:
- """Tests that we correctly handle key requests for keys we've stored
- with a null `ts_valid_until_ms`
- """
- mock_fetcher = Mock()
- mock_fetcher.get_keys = AsyncMock(return_value={})
-
- key1 = signedjson.key.generate_signing_key("1")
- r = self.hs.get_datastores().main.store_server_signature_keys(
- "server9",
- int(time.time() * 1000),
- # None is not a valid value in FetchKeyResult, but we're abusing this
- # API to insert null values into the database. The nulls get converted
- # to 0 when fetched in KeyStore.get_server_signature_keys.
- {("server9", get_key_id(key1)): FetchKeyResult(get_verify_key(key1), None)}, # type: ignore[arg-type]
- )
- self.get_success(r)
-
- json1: JsonDict = {}
- signedjson.sign.sign_json(json1, "server9", key1)
-
- # should succeed on a signed object with a 0 minimum_valid_until_ms
- d = self.hs.get_datastores().main.get_server_signature_keys(
- [("server9", get_key_id(key1))]
- )
- result = self.get_success(d)
- self.assertEqual(result[("server9", get_key_id(key1))].valid_until_ts, 0)
-
def test_verify_json_dedupes_key_requests(self) -> None:
"""Two requests for the same key should be deduped."""
key1 = signedjson.key.generate_signing_key("1")
diff --git a/tests/handlers/test_appservice.py b/tests/handlers/test_appservice.py
index 46d02209..a7e6cdd6 100644
--- a/tests/handlers/test_appservice.py
+++ b/tests/handlers/test_appservice.py
@@ -422,6 +422,18 @@ class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase):
"exclusive_as_user", "password", self.exclusive_as_user_device_id
)
+ self.exclusive_as_user_2_device_id = "exclusive_as_device_2"
+ self.exclusive_as_user_2 = self.register_user("exclusive_as_user_2", "password")
+ self.exclusive_as_user_2_token = self.login(
+ "exclusive_as_user_2", "password", self.exclusive_as_user_2_device_id
+ )
+
+ self.exclusive_as_user_3_device_id = "exclusive_as_device_3"
+ self.exclusive_as_user_3 = self.register_user("exclusive_as_user_3", "password")
+ self.exclusive_as_user_3_token = self.login(
+ "exclusive_as_user_3", "password", self.exclusive_as_user_3_device_id
+ )
+
def _notify_interested_services(self) -> None:
# This is normally set in `notify_interested_services` but we need to call the
# internal async version so the reactor gets pushed to completion.
@@ -849,6 +861,119 @@ class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase):
for count in service_id_to_message_count.values():
self.assertEqual(count, number_of_messages)
+ @unittest.override_config(
+ {"experimental_features": {"msc2409_to_device_messages_enabled": True}}
+ )
+ def test_application_services_receive_local_to_device_for_many_users(self) -> None:
+ """
+ Test that when a user sends a to-device message to many users
+ in an application service's user namespace, the
+ application service will receive all of them.
+ """
+ interested_appservice = self._register_application_service(
+ namespaces={
+ ApplicationService.NS_USERS: [
+ {
+ "regex": "@exclusive_as_user:.+",
+ "exclusive": True,
+ },
+ {
+ "regex": "@exclusive_as_user_2:.+",
+ "exclusive": True,
+ },
+ {
+ "regex": "@exclusive_as_user_3:.+",
+ "exclusive": True,
+ },
+ ],
+ },
+ )
+
+ # Have local_user send a to-device message to exclusive_as_users
+ message_content = {"some_key": "some really interesting value"}
+ chan = self.make_request(
+ "PUT",
+ "/_matrix/client/r0/sendToDevice/m.room_key_request/3",
+ content={
+ "messages": {
+ self.exclusive_as_user: {
+ self.exclusive_as_user_device_id: message_content
+ },
+ self.exclusive_as_user_2: {
+ self.exclusive_as_user_2_device_id: message_content
+ },
+ self.exclusive_as_user_3: {
+ self.exclusive_as_user_3_device_id: message_content
+ },
+ }
+ },
+ access_token=self.local_user_token,
+ )
+ self.assertEqual(chan.code, 200, chan.result)
+
+ # Have exclusive_as_user send a to-device message to local_user
+ for user_token in [
+ self.exclusive_as_user_token,
+ self.exclusive_as_user_2_token,
+ self.exclusive_as_user_3_token,
+ ]:
+ chan = self.make_request(
+ "PUT",
+ "/_matrix/client/r0/sendToDevice/m.room_key_request/4",
+ content={
+ "messages": {
+ self.local_user: {self.local_user_device_id: message_content}
+ }
+ },
+ access_token=user_token,
+ )
+ self.assertEqual(chan.code, 200, chan.result)
+
+ # Check if our application service - that is interested in exclusive_as_user - received
+ # the to-device message as part of an AS transaction.
+ # Only the local_user -> exclusive_as_user to-device message should have been forwarded to the AS.
+ #
+ # The uninterested application service should not have been notified at all.
+ self.send_mock.assert_called_once()
+ (
+ service,
+ _events,
+ _ephemeral,
+ to_device_messages,
+ _otks,
+ _fbks,
+ _device_list_summary,
+ ) = self.send_mock.call_args[0]
+
+ # Assert that this was the same to-device message that local_user sent
+ self.assertEqual(service, interested_appservice)
+
+ # Assert expected number of messages
+ self.assertEqual(len(to_device_messages), 3)
+
+ for device_msg in to_device_messages:
+ self.assertEqual(device_msg["type"], "m.room_key_request")
+ self.assertEqual(device_msg["sender"], self.local_user)
+ self.assertEqual(device_msg["content"], message_content)
+
+ self.assertEqual(to_device_messages[0]["to_user_id"], self.exclusive_as_user)
+ self.assertEqual(
+ to_device_messages[0]["to_device_id"],
+ self.exclusive_as_user_device_id,
+ )
+
+ self.assertEqual(to_device_messages[1]["to_user_id"], self.exclusive_as_user_2)
+ self.assertEqual(
+ to_device_messages[1]["to_device_id"],
+ self.exclusive_as_user_2_device_id,
+ )
+
+ self.assertEqual(to_device_messages[2]["to_user_id"], self.exclusive_as_user_3)
+ self.assertEqual(
+ to_device_messages[2]["to_device_id"],
+ self.exclusive_as_user_3_device_id,
+ )
+
def _register_application_service(
self,
namespaces: Optional[Dict[str, Iterable[Dict]]] = None,
diff --git a/tests/handlers/test_cas.py b/tests/handlers/test_cas.py
index 8582b1cd..13e2cd15 100644
--- a/tests/handlers/test_cas.py
+++ b/tests/handlers/test_cas.py
@@ -197,6 +197,23 @@ class CasHandlerTestCase(HomeserverTestCase):
auth_provider_session_id=None,
)
+ @override_config({"cas_config": {"enable_registration": False}})
+ def test_map_cas_user_does_not_register_new_user(self) -> None:
+ """Ensures new users are not registered if the enabled registration flag is disabled."""
+
+ # stub out the auth handler
+ auth_handler = self.hs.get_auth_handler()
+ auth_handler.complete_sso_login = AsyncMock() # type: ignore[method-assign]
+
+ cas_response = CasResponse("test_user", {})
+ request = _mock_request()
+ self.get_success(
+ self.handler._handle_cas_response(request, cas_response, "redirect_uri", "")
+ )
+
+ # check that the auth handler was not called as expected
+ auth_handler.complete_sso_login.assert_not_called()
+
def _mock_request() -> Mock:
"""Returns a mock which will stand in as a SynapseRequest"""
diff --git a/tests/handlers/test_device.py b/tests/handlers/test_device.py
index 55a4f95e..d4ed0683 100644
--- a/tests/handlers/test_device.py
+++ b/tests/handlers/test_device.py
@@ -30,6 +30,7 @@ from synapse.server import HomeServer
from synapse.storage.databases.main.appservice import _make_exclusive_regex
from synapse.types import JsonDict, create_requester
from synapse.util import Clock
+from synapse.util.task_scheduler import TaskScheduler
from tests import unittest
from tests.unittest import override_config
@@ -49,6 +50,7 @@ class DeviceTestCase(unittest.HomeserverTestCase):
assert isinstance(handler, DeviceHandler)
self.handler = handler
self.store = hs.get_datastores().main
+ self.device_message_handler = hs.get_device_message_handler()
return hs
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
@@ -211,6 +213,51 @@ class DeviceTestCase(unittest.HomeserverTestCase):
)
self.assertIsNone(res)
+ def test_delete_device_and_big_device_inbox(self) -> None:
+ """Check that deleting a big device inbox is staged and batched asynchronously."""
+ DEVICE_ID = "abc"
+ sender = "@sender:" + self.hs.hostname
+ receiver = "@receiver:" + self.hs.hostname
+ self._record_user(sender, DEVICE_ID, DEVICE_ID)
+ self._record_user(receiver, DEVICE_ID, DEVICE_ID)
+
+ # queue a bunch of messages in the inbox
+ requester = create_requester(sender, device_id=DEVICE_ID)
+ for i in range(DeviceHandler.DEVICE_MSGS_DELETE_BATCH_LIMIT + 10):
+ self.get_success(
+ self.device_message_handler.send_device_message(
+ requester, "message_type", {receiver: {"*": {"val": i}}}
+ )
+ )
+
+ # delete the device
+ self.get_success(self.handler.delete_devices(receiver, [DEVICE_ID]))
+
+ # messages should be deleted up to DEVICE_MSGS_DELETE_BATCH_LIMIT straight away
+ res = self.get_success(
+ self.store.db_pool.simple_select_list(
+ table="device_inbox",
+ keyvalues={"user_id": receiver},
+ retcols=("user_id", "device_id", "stream_id"),
+ desc="get_device_id_from_device_inbox",
+ )
+ )
+ self.assertEqual(10, len(res))
+
+ # wait for the task scheduler to do a second delete pass
+ self.reactor.advance(TaskScheduler.SCHEDULE_INTERVAL_MS / 1000)
+
+ # remaining messages should now be deleted
+ res = self.get_success(
+ self.store.db_pool.simple_select_list(
+ table="device_inbox",
+ keyvalues={"user_id": receiver},
+ retcols=("user_id", "device_id", "stream_id"),
+ desc="get_device_id_from_device_inbox",
+ )
+ )
+ self.assertEqual(0, len(res))
+
def test_update_device(self) -> None:
self._record_users()
@@ -414,6 +461,7 @@ class DehydrationTestCase(unittest.HomeserverTestCase):
self.message_handler = hs.get_device_message_handler()
self.registration = hs.get_registration_handler()
self.auth = hs.get_auth()
+ self.auth_handler = hs.get_auth_handler()
self.store = hs.get_datastores().main
return hs
@@ -440,11 +488,12 @@ class DehydrationTestCase(unittest.HomeserverTestCase):
self.assertEqual(device_data, {"device_data": {"foo": "bar"}})
# Create a new login for the user and dehydrated the device
- device_id, access_token, _expiration_time, _refresh_token = self.get_success(
+ device_id, access_token, _expiration_time, refresh_token = self.get_success(
self.registration.register_device(
user_id=user_id,
device_id=None,
initial_display_name="new device",
+ should_issue_refresh_token=True,
)
)
@@ -475,6 +524,12 @@ class DehydrationTestCase(unittest.HomeserverTestCase):
self.assertEqual(user_info.device_id, retrieved_device_id)
+ # make sure the user device has the refresh token
+ assert refresh_token is not None
+ self.get_success(
+ self.auth_handler.refresh_token(refresh_token, 5 * 60 * 1000, 5 * 60 * 1000)
+ )
+
# make sure the device has the display name that was set from the login
res = self.get_success(self.handler.get_device(user_id, retrieved_device_id))
diff --git a/tests/handlers/test_federation.py b/tests/handlers/test_federation.py
index 21d63ab1..4fc07424 100644
--- a/tests/handlers/test_federation.py
+++ b/tests/handlers/test_federation.py
@@ -262,7 +262,7 @@ class FederationTestCase(unittest.FederatingHomeserverTestCase):
if (ev.type, ev.state_key)
in {("m.room.create", ""), ("m.room.member", remote_server_user_id)}
]
- for _ in range(0, 8):
+ for _ in range(8):
event = make_event_from_dict(
self.add_hashes_and_signatures_from_other_server(
{
diff --git a/tests/handlers/test_presence.py b/tests/handlers/test_presence.py
index 88a16193..41c8c44e 100644
--- a/tests/handlers/test_presence.py
+++ b/tests/handlers/test_presence.py
@@ -21,11 +21,12 @@ from signedjson.key import generate_signing_key
from twisted.test.proto_helpers import MemoryReactor
from synapse.api.constants import EventTypes, Membership, PresenceState
-from synapse.api.presence import UserPresenceState
+from synapse.api.presence import UserDevicePresenceState, UserPresenceState
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
from synapse.events.builder import EventBuilder
from synapse.federation.sender import FederationSender
from synapse.handlers.presence import (
+ BUSY_ONLINE_TIMEOUT,
EXTERNAL_PROCESS_EXPIRY,
FEDERATION_PING_INTERVAL,
FEDERATION_TIMEOUT,
@@ -352,6 +353,7 @@ class PresenceTimeoutTestCase(unittest.TestCase):
def test_idle_timer(self) -> None:
user_id = "@foo:bar"
+ device_id = "dev-1"
status_msg = "I'm here!"
now = 5000000
@@ -362,8 +364,21 @@ class PresenceTimeoutTestCase(unittest.TestCase):
last_user_sync_ts=now,
status_msg=status_msg,
)
+ device_state = UserDevicePresenceState(
+ user_id=user_id,
+ device_id=device_id,
+ state=state.state,
+ last_active_ts=state.last_active_ts,
+ last_sync_ts=state.last_user_sync_ts,
+ )
- new_state = handle_timeout(state, is_mine=True, syncing_user_ids=set(), now=now)
+ new_state = handle_timeout(
+ state,
+ is_mine=True,
+ syncing_device_ids=set(),
+ user_devices={device_id: device_state},
+ now=now,
+ )
self.assertIsNotNone(new_state)
assert new_state is not None
@@ -376,6 +391,7 @@ class PresenceTimeoutTestCase(unittest.TestCase):
presence state into unavailable.
"""
user_id = "@foo:bar"
+ device_id = "dev-1"
status_msg = "I'm here!"
now = 5000000
@@ -386,8 +402,21 @@ class PresenceTimeoutTestCase(unittest.TestCase):
last_user_sync_ts=now,
status_msg=status_msg,
)
+ device_state = UserDevicePresenceState(
+ user_id=user_id,
+ device_id=device_id,
+ state=state.state,
+ last_active_ts=state.last_active_ts,
+ last_sync_ts=state.last_user_sync_ts,
+ )
- new_state = handle_timeout(state, is_mine=True, syncing_user_ids=set(), now=now)
+ new_state = handle_timeout(
+ state,
+ is_mine=True,
+ syncing_device_ids=set(),
+ user_devices={device_id: device_state},
+ now=now,
+ )
self.assertIsNotNone(new_state)
assert new_state is not None
@@ -396,6 +425,7 @@ class PresenceTimeoutTestCase(unittest.TestCase):
def test_sync_timeout(self) -> None:
user_id = "@foo:bar"
+ device_id = "dev-1"
status_msg = "I'm here!"
now = 5000000
@@ -406,8 +436,21 @@ class PresenceTimeoutTestCase(unittest.TestCase):
last_user_sync_ts=now - SYNC_ONLINE_TIMEOUT - 1,
status_msg=status_msg,
)
+ device_state = UserDevicePresenceState(
+ user_id=user_id,
+ device_id=device_id,
+ state=state.state,
+ last_active_ts=state.last_active_ts,
+ last_sync_ts=state.last_user_sync_ts,
+ )
- new_state = handle_timeout(state, is_mine=True, syncing_user_ids=set(), now=now)
+ new_state = handle_timeout(
+ state,
+ is_mine=True,
+ syncing_device_ids=set(),
+ user_devices={device_id: device_state},
+ now=now,
+ )
self.assertIsNotNone(new_state)
assert new_state is not None
@@ -416,6 +459,7 @@ class PresenceTimeoutTestCase(unittest.TestCase):
def test_sync_online(self) -> None:
user_id = "@foo:bar"
+ device_id = "dev-1"
status_msg = "I'm here!"
now = 5000000
@@ -426,9 +470,20 @@ class PresenceTimeoutTestCase(unittest.TestCase):
last_user_sync_ts=now - SYNC_ONLINE_TIMEOUT - 1,
status_msg=status_msg,
)
+ device_state = UserDevicePresenceState(
+ user_id=user_id,
+ device_id=device_id,
+ state=state.state,
+ last_active_ts=state.last_active_ts,
+ last_sync_ts=state.last_user_sync_ts,
+ )
new_state = handle_timeout(
- state, is_mine=True, syncing_user_ids={user_id}, now=now
+ state,
+ is_mine=True,
+ syncing_device_ids={(user_id, device_id)},
+ user_devices={device_id: device_state},
+ now=now,
)
self.assertIsNotNone(new_state)
@@ -438,6 +493,7 @@ class PresenceTimeoutTestCase(unittest.TestCase):
def test_federation_ping(self) -> None:
user_id = "@foo:bar"
+ device_id = "dev-1"
status_msg = "I'm here!"
now = 5000000
@@ -449,14 +505,28 @@ class PresenceTimeoutTestCase(unittest.TestCase):
last_federation_update_ts=now - FEDERATION_PING_INTERVAL - 1,
status_msg=status_msg,
)
+ device_state = UserDevicePresenceState(
+ user_id=user_id,
+ device_id=device_id,
+ state=state.state,
+ last_active_ts=state.last_active_ts,
+ last_sync_ts=state.last_user_sync_ts,
+ )
- new_state = handle_timeout(state, is_mine=True, syncing_user_ids=set(), now=now)
+ new_state = handle_timeout(
+ state,
+ is_mine=True,
+ syncing_device_ids=set(),
+ user_devices={device_id: device_state},
+ now=now,
+ )
self.assertIsNotNone(new_state)
self.assertEqual(state, new_state)
def test_no_timeout(self) -> None:
user_id = "@foo:bar"
+ device_id = "dev-1"
now = 5000000
state = UserPresenceState.default(user_id)
@@ -466,8 +536,21 @@ class PresenceTimeoutTestCase(unittest.TestCase):
last_user_sync_ts=now,
last_federation_update_ts=now,
)
+ device_state = UserDevicePresenceState(
+ user_id=user_id,
+ device_id=device_id,
+ state=state.state,
+ last_active_ts=state.last_active_ts,
+ last_sync_ts=state.last_user_sync_ts,
+ )
- new_state = handle_timeout(state, is_mine=True, syncing_user_ids=set(), now=now)
+ new_state = handle_timeout(
+ state,
+ is_mine=True,
+ syncing_device_ids=set(),
+ user_devices={device_id: device_state},
+ now=now,
+ )
self.assertIsNone(new_state)
@@ -485,8 +568,9 @@ class PresenceTimeoutTestCase(unittest.TestCase):
status_msg=status_msg,
)
+ # Note that this is a remote user so we do not have their device information.
new_state = handle_timeout(
- state, is_mine=False, syncing_user_ids=set(), now=now
+ state, is_mine=False, syncing_device_ids=set(), user_devices={}, now=now
)
self.assertIsNotNone(new_state)
@@ -496,6 +580,7 @@ class PresenceTimeoutTestCase(unittest.TestCase):
def test_last_active(self) -> None:
user_id = "@foo:bar"
+ device_id = "dev-1"
status_msg = "I'm here!"
now = 5000000
@@ -507,8 +592,21 @@ class PresenceTimeoutTestCase(unittest.TestCase):
last_federation_update_ts=now,
status_msg=status_msg,
)
+ device_state = UserDevicePresenceState(
+ user_id=user_id,
+ device_id=device_id,
+ state=state.state,
+ last_active_ts=state.last_active_ts,
+ last_sync_ts=state.last_user_sync_ts,
+ )
- new_state = handle_timeout(state, is_mine=True, syncing_user_ids=set(), now=now)
+ new_state = handle_timeout(
+ state,
+ is_mine=True,
+ syncing_device_ids=set(),
+ user_devices={device_id: device_state},
+ now=now,
+ )
self.assertIsNotNone(new_state)
self.assertEqual(state, new_state)
@@ -579,7 +677,7 @@ class PresenceHandlerInitTestCase(unittest.HomeserverTestCase):
[
(PresenceState.BUSY, PresenceState.BUSY),
(PresenceState.ONLINE, PresenceState.ONLINE),
- (PresenceState.UNAVAILABLE, PresenceState.UNAVAILABLE),
+ (PresenceState.UNAVAILABLE, PresenceState.ONLINE),
# Offline syncs don't update the state.
(PresenceState.OFFLINE, PresenceState.ONLINE),
]
@@ -800,6 +898,486 @@ class PresenceHandlerTestCase(BaseMultiWorkerStreamTestCase):
# we should now be online
self.assertEqual(state.state, PresenceState.ONLINE)
+ @parameterized.expand(
+ # A list of tuples of 4 strings:
+ #
+ # * The presence state of device 1.
+ # * The presence state of device 2.
+ # * The expected user presence state after both devices have synced.
+ # * The expected user presence state after device 1 has idled.
+ # * The expected user presence state after device 2 has idled.
+ # * True to use workers, False a monolith.
+ [
+ (*cases, workers)
+ for workers in (False, True)
+ for cases in [
+ # If both devices have the same state, online should eventually idle.
+ # Otherwise, the state doesn't change.
+ (
+ PresenceState.BUSY,
+ PresenceState.BUSY,
+ PresenceState.BUSY,
+ PresenceState.BUSY,
+ PresenceState.BUSY,
+ ),
+ (
+ PresenceState.ONLINE,
+ PresenceState.ONLINE,
+ PresenceState.ONLINE,
+ PresenceState.ONLINE,
+ PresenceState.UNAVAILABLE,
+ ),
+ (
+ PresenceState.UNAVAILABLE,
+ PresenceState.UNAVAILABLE,
+ PresenceState.UNAVAILABLE,
+ PresenceState.UNAVAILABLE,
+ PresenceState.UNAVAILABLE,
+ ),
+ (
+ PresenceState.OFFLINE,
+ PresenceState.OFFLINE,
+ PresenceState.OFFLINE,
+ PresenceState.OFFLINE,
+ PresenceState.OFFLINE,
+ ),
+ # If the second device has a "lower" state it should fallback to it,
+ # except for "busy" which overrides.
+ (
+ PresenceState.BUSY,
+ PresenceState.ONLINE,
+ PresenceState.BUSY,
+ PresenceState.BUSY,
+ PresenceState.BUSY,
+ ),
+ (
+ PresenceState.BUSY,
+ PresenceState.UNAVAILABLE,
+ PresenceState.BUSY,
+ PresenceState.BUSY,
+ PresenceState.BUSY,
+ ),
+ (
+ PresenceState.BUSY,
+ PresenceState.OFFLINE,
+ PresenceState.BUSY,
+ PresenceState.BUSY,
+ PresenceState.BUSY,
+ ),
+ (
+ PresenceState.ONLINE,
+ PresenceState.UNAVAILABLE,
+ PresenceState.ONLINE,
+ PresenceState.UNAVAILABLE,
+ PresenceState.UNAVAILABLE,
+ ),
+ (
+ PresenceState.ONLINE,
+ PresenceState.OFFLINE,
+ PresenceState.ONLINE,
+ PresenceState.UNAVAILABLE,
+ PresenceState.UNAVAILABLE,
+ ),
+ (
+ PresenceState.UNAVAILABLE,
+ PresenceState.OFFLINE,
+ PresenceState.UNAVAILABLE,
+ PresenceState.UNAVAILABLE,
+ PresenceState.UNAVAILABLE,
+ ),
+ # If the second device has a "higher" state it should override.
+ (
+ PresenceState.ONLINE,
+ PresenceState.BUSY,
+ PresenceState.BUSY,
+ PresenceState.BUSY,
+ PresenceState.BUSY,
+ ),
+ (
+ PresenceState.UNAVAILABLE,
+ PresenceState.BUSY,
+ PresenceState.BUSY,
+ PresenceState.BUSY,
+ PresenceState.BUSY,
+ ),
+ (
+ PresenceState.OFFLINE,
+ PresenceState.BUSY,
+ PresenceState.BUSY,
+ PresenceState.BUSY,
+ PresenceState.BUSY,
+ ),
+ (
+ PresenceState.UNAVAILABLE,
+ PresenceState.ONLINE,
+ PresenceState.ONLINE,
+ PresenceState.ONLINE,
+ PresenceState.UNAVAILABLE,
+ ),
+ (
+ PresenceState.OFFLINE,
+ PresenceState.ONLINE,
+ PresenceState.ONLINE,
+ PresenceState.ONLINE,
+ PresenceState.UNAVAILABLE,
+ ),
+ (
+ PresenceState.OFFLINE,
+ PresenceState.UNAVAILABLE,
+ PresenceState.UNAVAILABLE,
+ PresenceState.UNAVAILABLE,
+ PresenceState.UNAVAILABLE,
+ ),
+ ]
+ ],
+ name_func=lambda testcase_func, param_num, params: f"{testcase_func.__name__}_{param_num}_{'workers' if params.args[5] else 'monolith'}",
+ )
+ @unittest.override_config({"experimental_features": {"msc3026_enabled": True}})
+ def test_set_presence_from_syncing_multi_device(
+ self,
+ dev_1_state: str,
+ dev_2_state: str,
+ expected_state_1: str,
+ expected_state_2: str,
+ expected_state_3: str,
+ test_with_workers: bool,
+ ) -> None:
+ """
+ Test the behaviour of multiple devices syncing at the same time.
+
+ Roughly the user's presence state should be set to the "highest" priority
+ of all the devices. When a device then goes offline its state should be
+ discarded and the next highest should win.
+
+ Note that these tests use the idle timer (and don't close the syncs), it
+ is unlikely that a *single* sync would last this long, but is close enough
+ to continually syncing with that current state.
+ """
+ user_id = f"@test:{self.hs.config.server.server_name}"
+
+ # By default, we call /sync against the main process.
+ worker_presence_handler = self.presence_handler
+ if test_with_workers:
+ # Create a worker and use it to handle /sync traffic instead.
+ # This is used to test that presence changes get replicated from workers
+ # to the main process correctly.
+ worker_to_sync_against = self.make_worker_hs(
+ "synapse.app.generic_worker", {"worker_name": "synchrotron"}
+ )
+ worker_presence_handler = worker_to_sync_against.get_presence_handler()
+
+ # 1. Sync with the first device.
+ self.get_success(
+ worker_presence_handler.user_syncing(
+ user_id,
+ "dev-1",
+ affect_presence=dev_1_state != PresenceState.OFFLINE,
+ presence_state=dev_1_state,
+ ),
+ by=0.01,
+ )
+
+ # 2. Wait half the idle timer.
+ self.reactor.advance(IDLE_TIMER / 1000 / 2)
+ self.reactor.pump([0.1])
+
+ # 3. Sync with the second device.
+ self.get_success(
+ worker_presence_handler.user_syncing(
+ user_id,
+ "dev-2",
+ affect_presence=dev_2_state != PresenceState.OFFLINE,
+ presence_state=dev_2_state,
+ ),
+ by=0.01,
+ )
+
+ # 4. Assert the expected presence state.
+ state = self.get_success(
+ self.presence_handler.get_state(UserID.from_string(user_id))
+ )
+ self.assertEqual(state.state, expected_state_1)
+ if test_with_workers:
+ state = self.get_success(
+ worker_presence_handler.get_state(UserID.from_string(user_id))
+ )
+ self.assertEqual(state.state, expected_state_1)
+
+ # When testing with workers, make another random sync (with any *different*
+ # user) to keep the process information from expiring.
+ #
+ # This is due to EXTERNAL_PROCESS_EXPIRY being equivalent to IDLE_TIMER.
+ if test_with_workers:
+ with self.get_success(
+ worker_presence_handler.user_syncing(
+ f"@other-user:{self.hs.config.server.server_name}",
+ "dev-3",
+ affect_presence=True,
+ presence_state=PresenceState.ONLINE,
+ ),
+ by=0.01,
+ ):
+ pass
+
+ # 5. Advance such that the first device should be discarded (the idle timer),
+ # then pump so _handle_timeouts function to called.
+ self.reactor.advance(IDLE_TIMER / 1000 / 2)
+ self.reactor.pump([0.01])
+
+ # 6. Assert the expected presence state.
+ state = self.get_success(
+ self.presence_handler.get_state(UserID.from_string(user_id))
+ )
+ self.assertEqual(state.state, expected_state_2)
+ if test_with_workers:
+ state = self.get_success(
+ worker_presence_handler.get_state(UserID.from_string(user_id))
+ )
+ self.assertEqual(state.state, expected_state_2)
+
+ # 7. Advance such that the second device should be discarded (half the idle timer),
+ # then pump so _handle_timeouts function to called.
+ self.reactor.advance(IDLE_TIMER / 1000 / 2)
+ self.reactor.pump([0.1])
+
+ # 8. The devices are still "syncing" (the sync context managers were never
+ # closed), so might idle.
+ state = self.get_success(
+ self.presence_handler.get_state(UserID.from_string(user_id))
+ )
+ self.assertEqual(state.state, expected_state_3)
+ if test_with_workers:
+ state = self.get_success(
+ worker_presence_handler.get_state(UserID.from_string(user_id))
+ )
+ self.assertEqual(state.state, expected_state_3)
+
+ @parameterized.expand(
+ # A list of tuples of 4 strings:
+ #
+ # * The presence state of device 1.
+ # * The presence state of device 2.
+ # * The expected user presence state after both devices have synced.
+ # * The expected user presence state after device 1 has stopped syncing.
+ # * True to use workers, False a monolith.
+ [
+ (*cases, workers)
+ for workers in (False, True)
+ for cases in [
+ # If both devices have the same state, nothing exciting should happen.
+ (
+ PresenceState.BUSY,
+ PresenceState.BUSY,
+ PresenceState.BUSY,
+ PresenceState.BUSY,
+ ),
+ (
+ PresenceState.ONLINE,
+ PresenceState.ONLINE,
+ PresenceState.ONLINE,
+ PresenceState.ONLINE,
+ ),
+ (
+ PresenceState.UNAVAILABLE,
+ PresenceState.UNAVAILABLE,
+ PresenceState.UNAVAILABLE,
+ PresenceState.UNAVAILABLE,
+ ),
+ (
+ PresenceState.OFFLINE,
+ PresenceState.OFFLINE,
+ PresenceState.OFFLINE,
+ PresenceState.OFFLINE,
+ ),
+ # If the second device has a "lower" state it should fallback to it,
+ # except for "busy" which overrides.
+ (
+ PresenceState.BUSY,
+ PresenceState.ONLINE,
+ PresenceState.BUSY,
+ PresenceState.BUSY,
+ ),
+ (
+ PresenceState.BUSY,
+ PresenceState.UNAVAILABLE,
+ PresenceState.BUSY,
+ PresenceState.BUSY,
+ ),
+ (
+ PresenceState.BUSY,
+ PresenceState.OFFLINE,
+ PresenceState.BUSY,
+ PresenceState.BUSY,
+ ),
+ (
+ PresenceState.ONLINE,
+ PresenceState.UNAVAILABLE,
+ PresenceState.ONLINE,
+ PresenceState.UNAVAILABLE,
+ ),
+ (
+ PresenceState.ONLINE,
+ PresenceState.OFFLINE,
+ PresenceState.ONLINE,
+ PresenceState.OFFLINE,
+ ),
+ (
+ PresenceState.UNAVAILABLE,
+ PresenceState.OFFLINE,
+ PresenceState.UNAVAILABLE,
+ PresenceState.OFFLINE,
+ ),
+ # If the second device has a "higher" state it should override.
+ (
+ PresenceState.ONLINE,
+ PresenceState.BUSY,
+ PresenceState.BUSY,
+ PresenceState.BUSY,
+ ),
+ (
+ PresenceState.UNAVAILABLE,
+ PresenceState.BUSY,
+ PresenceState.BUSY,
+ PresenceState.BUSY,
+ ),
+ (
+ PresenceState.OFFLINE,
+ PresenceState.BUSY,
+ PresenceState.BUSY,
+ PresenceState.BUSY,
+ ),
+ (
+ PresenceState.UNAVAILABLE,
+ PresenceState.ONLINE,
+ PresenceState.ONLINE,
+ PresenceState.ONLINE,
+ ),
+ (
+ PresenceState.OFFLINE,
+ PresenceState.ONLINE,
+ PresenceState.ONLINE,
+ PresenceState.ONLINE,
+ ),
+ (
+ PresenceState.OFFLINE,
+ PresenceState.UNAVAILABLE,
+ PresenceState.UNAVAILABLE,
+ PresenceState.UNAVAILABLE,
+ ),
+ ]
+ ],
+ name_func=lambda testcase_func, param_num, params: f"{testcase_func.__name__}_{param_num}_{'workers' if params.args[4] else 'monolith'}",
+ )
+ @unittest.override_config({"experimental_features": {"msc3026_enabled": True}})
+ def test_set_presence_from_non_syncing_multi_device(
+ self,
+ dev_1_state: str,
+ dev_2_state: str,
+ expected_state_1: str,
+ expected_state_2: str,
+ test_with_workers: bool,
+ ) -> None:
+ """
+ Test the behaviour of multiple devices syncing at the same time.
+
+ Roughly the user's presence state should be set to the "highest" priority
+ of all the devices. When a device then goes offline its state should be
+ discarded and the next highest should win.
+
+ Note that these tests use the idle timer (and don't close the syncs), it
+ is unlikely that a *single* sync would last this long, but is close enough
+ to continually syncing with that current state.
+ """
+ user_id = f"@test:{self.hs.config.server.server_name}"
+
+ # By default, we call /sync against the main process.
+ worker_presence_handler = self.presence_handler
+ if test_with_workers:
+ # Create a worker and use it to handle /sync traffic instead.
+ # This is used to test that presence changes get replicated from workers
+ # to the main process correctly.
+ worker_to_sync_against = self.make_worker_hs(
+ "synapse.app.generic_worker", {"worker_name": "synchrotron"}
+ )
+ worker_presence_handler = worker_to_sync_against.get_presence_handler()
+
+ # 1. Sync with the first device.
+ sync_1 = self.get_success(
+ worker_presence_handler.user_syncing(
+ user_id,
+ "dev-1",
+ affect_presence=dev_1_state != PresenceState.OFFLINE,
+ presence_state=dev_1_state,
+ ),
+ by=0.1,
+ )
+
+ # 2. Sync with the second device.
+ sync_2 = self.get_success(
+ worker_presence_handler.user_syncing(
+ user_id,
+ "dev-2",
+ affect_presence=dev_2_state != PresenceState.OFFLINE,
+ presence_state=dev_2_state,
+ ),
+ by=0.1,
+ )
+
+ # 3. Assert the expected presence state.
+ state = self.get_success(
+ self.presence_handler.get_state(UserID.from_string(user_id))
+ )
+ self.assertEqual(state.state, expected_state_1)
+ if test_with_workers:
+ state = self.get_success(
+ worker_presence_handler.get_state(UserID.from_string(user_id))
+ )
+ self.assertEqual(state.state, expected_state_1)
+
+ # 4. Disconnect the first device.
+ with sync_1:
+ pass
+
+ # 5. Advance such that the first device should be discarded (the sync timeout),
+ # then pump so _handle_timeouts function to called.
+ self.reactor.advance(SYNC_ONLINE_TIMEOUT / 1000)
+ self.reactor.pump([5])
+
+ # 6. Assert the expected presence state.
+ state = self.get_success(
+ self.presence_handler.get_state(UserID.from_string(user_id))
+ )
+ self.assertEqual(state.state, expected_state_2)
+ if test_with_workers:
+ state = self.get_success(
+ worker_presence_handler.get_state(UserID.from_string(user_id))
+ )
+ self.assertEqual(state.state, expected_state_2)
+
+ # 7. Disconnect the second device.
+ with sync_2:
+ pass
+
+ # 8. Advance such that the second device should be discarded (the sync timeout),
+ # then pump so _handle_timeouts function to called.
+ if dev_1_state == PresenceState.BUSY or dev_2_state == PresenceState.BUSY:
+ timeout = BUSY_ONLINE_TIMEOUT
+ else:
+ timeout = SYNC_ONLINE_TIMEOUT
+ self.reactor.advance(timeout / 1000)
+ self.reactor.pump([5])
+
+ # 9. There are no more devices, should be offline.
+ state = self.get_success(
+ self.presence_handler.get_state(UserID.from_string(user_id))
+ )
+ self.assertEqual(state.state, PresenceState.OFFLINE)
+ if test_with_workers:
+ state = self.get_success(
+ worker_presence_handler.get_state(UserID.from_string(user_id))
+ )
+ self.assertEqual(state.state, PresenceState.OFFLINE)
+
def test_set_presence_from_syncing_keeps_status(self) -> None:
"""Test that presence set by syncing retains status message"""
status_msg = "I'm here!"
@@ -1280,7 +1858,7 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase):
)
event = self.get_success(
- builder.build(prev_event_ids=prev_event_ids, auth_event_ids=None)
+ builder.build(prev_event_ids=list(prev_event_ids), auth_event_ids=None)
)
self.get_success(self.federation_event_handler.on_receive_pdu(hostname, event))
diff --git a/tests/http/federation/test_matrix_federation_agent.py b/tests/http/federation/test_matrix_federation_agent.py
index 0d17f2fe..9f63fa6f 100644
--- a/tests/http/federation/test_matrix_federation_agent.py
+++ b/tests/http/federation/test_matrix_federation_agent.py
@@ -15,7 +15,7 @@ import base64
import logging
import os
from typing import Generator, List, Optional, cast
-from unittest.mock import AsyncMock, patch
+from unittest.mock import AsyncMock, call, patch
import treq
from netaddr import IPSet
@@ -651,9 +651,9 @@ class MatrixFederationAgentTests(unittest.TestCase):
# .well-known request fails.
self.reactor.pump((0.4,))
- # now there should be a SRV lookup
- self.mock_resolver.resolve_service.assert_called_once_with(
- b"_matrix._tcp.testserv1"
+ # now there should be two SRV lookups
+ self.mock_resolver.resolve_service.assert_has_calls(
+ [call(b"_matrix-fed._tcp.testserv1"), call(b"_matrix._tcp.testserv1")]
)
# we should fall back to a direct connection
@@ -737,9 +737,9 @@ class MatrixFederationAgentTests(unittest.TestCase):
# .well-known request fails.
self.reactor.pump((0.4,))
- # now there should be a SRV lookup
- self.mock_resolver.resolve_service.assert_called_once_with(
- b"_matrix._tcp.testserv"
+ # now there should be two SRV lookups
+ self.mock_resolver.resolve_service.assert_has_calls(
+ [call(b"_matrix-fed._tcp.testserv"), call(b"_matrix._tcp.testserv")]
)
# we should fall back to a direct connection
@@ -788,9 +788,12 @@ class MatrixFederationAgentTests(unittest.TestCase):
content=b'{ "m.server": "target-server" }',
)
- # there should be a SRV lookup
- self.mock_resolver.resolve_service.assert_called_once_with(
- b"_matrix._tcp.target-server"
+ # there should be two SRV lookups
+ self.mock_resolver.resolve_service.assert_has_calls(
+ [
+ call(b"_matrix-fed._tcp.target-server"),
+ call(b"_matrix._tcp.target-server"),
+ ]
)
# now we should get a connection to the target server
@@ -878,9 +881,12 @@ class MatrixFederationAgentTests(unittest.TestCase):
self.reactor.pump((0.1,))
- # there should be a SRV lookup
- self.mock_resolver.resolve_service.assert_called_once_with(
- b"_matrix._tcp.target-server"
+ # there should be two SRV lookups
+ self.mock_resolver.resolve_service.assert_has_calls(
+ [
+ call(b"_matrix-fed._tcp.target-server"),
+ call(b"_matrix._tcp.target-server"),
+ ]
)
# now we should get a connection to the target server
@@ -942,9 +948,9 @@ class MatrixFederationAgentTests(unittest.TestCase):
client_factory, expected_sni=b"testserv", content=b"NOT JSON"
)
- # now there should be a SRV lookup
- self.mock_resolver.resolve_service.assert_called_once_with(
- b"_matrix._tcp.testserv"
+ # now there should be two SRV lookups
+ self.mock_resolver.resolve_service.assert_has_calls(
+ [call(b"_matrix-fed._tcp.testserv"), call(b"_matrix._tcp.testserv")]
)
# we should fall back to a direct connection
@@ -1016,14 +1022,14 @@ class MatrixFederationAgentTests(unittest.TestCase):
# there should be no requests
self.assertEqual(len(http_proto.requests), 0)
- # and there should be a SRV lookup instead
- self.mock_resolver.resolve_service.assert_called_once_with(
- b"_matrix._tcp.testserv"
+ # and there should be two SRV lookups instead
+ self.mock_resolver.resolve_service.assert_has_calls(
+ [call(b"_matrix-fed._tcp.testserv"), call(b"_matrix._tcp.testserv")]
)
def test_get_hostname_srv(self) -> None:
"""
- Test the behaviour when there is a single SRV record
+ Test the behaviour when there is a single SRV record for _matrix-fed.
"""
self.agent = self._make_agent()
@@ -1039,7 +1045,51 @@ class MatrixFederationAgentTests(unittest.TestCase):
# the request for a .well-known will have failed with a DNS lookup error.
self.mock_resolver.resolve_service.assert_called_once_with(
- b"_matrix._tcp.testserv"
+ b"_matrix-fed._tcp.testserv"
+ )
+
+ # Make sure treq is trying to connect
+ clients = self.reactor.tcpClients
+ self.assertEqual(len(clients), 1)
+ (host, port, client_factory, _timeout, _bindAddress) = clients[0]
+ self.assertEqual(host, "1.2.3.4")
+ self.assertEqual(port, 8443)
+
+ # make a test server, and wire up the client
+ http_server = self._make_connection(client_factory, expected_sni=b"testserv")
+
+ self.assertEqual(len(http_server.requests), 1)
+ request = http_server.requests[0]
+ self.assertEqual(request.method, b"GET")
+ self.assertEqual(request.path, b"/foo/bar")
+ self.assertEqual(request.requestHeaders.getRawHeaders(b"host"), [b"testserv"])
+
+ # finish the request
+ request.finish()
+ self.reactor.pump((0.1,))
+ self.successResultOf(test_d)
+
+ def test_get_hostname_srv_legacy(self) -> None:
+ """
+ Test the behaviour when there is a single SRV record for _matrix.
+ """
+ self.agent = self._make_agent()
+
+ # Return no entries for the _matrix-fed lookup, and a response for _matrix.
+ self.mock_resolver.resolve_service.side_effect = [
+ [],
+ [Server(host=b"srvtarget", port=8443)],
+ ]
+ self.reactor.lookups["srvtarget"] = "1.2.3.4"
+
+ test_d = self._make_get_request(b"matrix-federation://testserv/foo/bar")
+
+ # Nothing happened yet
+ self.assertNoResult(test_d)
+
+ # the request for a .well-known will have failed with a DNS lookup error.
+ self.mock_resolver.resolve_service.assert_has_calls(
+ [call(b"_matrix-fed._tcp.testserv"), call(b"_matrix._tcp.testserv")]
)
# Make sure treq is trying to connect
@@ -1065,7 +1115,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
def test_get_well_known_srv(self) -> None:
"""Test the behaviour when the .well-known redirects to a place where there
- is a SRV.
+ is a _matrix-fed SRV record.
"""
self.agent = self._make_agent()
@@ -1096,7 +1146,72 @@ class MatrixFederationAgentTests(unittest.TestCase):
# there should be a SRV lookup
self.mock_resolver.resolve_service.assert_called_once_with(
- b"_matrix._tcp.target-server"
+ b"_matrix-fed._tcp.target-server"
+ )
+
+ # now we should get a connection to the target of the SRV record
+ self.assertEqual(len(clients), 2)
+ (host, port, client_factory, _timeout, _bindAddress) = clients[1]
+ self.assertEqual(host, "5.6.7.8")
+ self.assertEqual(port, 8443)
+
+ # make a test server, and wire up the client
+ http_server = self._make_connection(
+ client_factory, expected_sni=b"target-server"
+ )
+
+ self.assertEqual(len(http_server.requests), 1)
+ request = http_server.requests[0]
+ self.assertEqual(request.method, b"GET")
+ self.assertEqual(request.path, b"/foo/bar")
+ self.assertEqual(
+ request.requestHeaders.getRawHeaders(b"host"), [b"target-server"]
+ )
+
+ # finish the request
+ request.finish()
+ self.reactor.pump((0.1,))
+ self.successResultOf(test_d)
+
+ def test_get_well_known_srv_legacy(self) -> None:
+ """Test the behaviour when the .well-known redirects to a place where there
+ is a _matrix SRV record.
+ """
+ self.agent = self._make_agent()
+
+ self.reactor.lookups["testserv"] = "1.2.3.4"
+ self.reactor.lookups["srvtarget"] = "5.6.7.8"
+
+ test_d = self._make_get_request(b"matrix-federation://testserv/foo/bar")
+
+ # Nothing happened yet
+ self.assertNoResult(test_d)
+
+ # there should be an attempt to connect on port 443 for the .well-known
+ clients = self.reactor.tcpClients
+ self.assertEqual(len(clients), 1)
+ (host, port, client_factory, _timeout, _bindAddress) = clients[0]
+ self.assertEqual(host, "1.2.3.4")
+ self.assertEqual(port, 443)
+
+ # Return no entries for the _matrix-fed lookup, and a response for _matrix.
+ self.mock_resolver.resolve_service.side_effect = [
+ [],
+ [Server(host=b"srvtarget", port=8443)],
+ ]
+
+ self._handle_well_known_connection(
+ client_factory,
+ expected_sni=b"testserv",
+ content=b'{ "m.server": "target-server" }',
+ )
+
+ # there should be two SRV lookups
+ self.mock_resolver.resolve_service.assert_has_calls(
+ [
+ call(b"_matrix-fed._tcp.target-server"),
+ call(b"_matrix._tcp.target-server"),
+ ]
)
# now we should get a connection to the target of the SRV record
@@ -1158,8 +1273,11 @@ class MatrixFederationAgentTests(unittest.TestCase):
self.reactor.pump((0.4,))
# now there should have been a SRV lookup
- self.mock_resolver.resolve_service.assert_called_once_with(
- b"_matrix._tcp.xn--bcher-kva.com"
+ self.mock_resolver.resolve_service.assert_has_calls(
+ [
+ call(b"_matrix-fed._tcp.xn--bcher-kva.com"),
+ call(b"_matrix._tcp.xn--bcher-kva.com"),
+ ]
)
# We should fall back to port 8448
@@ -1188,7 +1306,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
self.successResultOf(test_d)
def test_idna_srv_target(self) -> None:
- """test the behaviour when the target of a SRV record has idna chars"""
+ """test the behaviour when the target of a _matrix-fed SRV record has idna chars"""
self.agent = self._make_agent()
self.mock_resolver.resolve_service.return_value = [
@@ -1204,7 +1322,57 @@ class MatrixFederationAgentTests(unittest.TestCase):
self.assertNoResult(test_d)
self.mock_resolver.resolve_service.assert_called_once_with(
- b"_matrix._tcp.xn--bcher-kva.com"
+ b"_matrix-fed._tcp.xn--bcher-kva.com"
+ )
+
+ # Make sure treq is trying to connect
+ clients = self.reactor.tcpClients
+ self.assertEqual(len(clients), 1)
+ (host, port, client_factory, _timeout, _bindAddress) = clients[0]
+ self.assertEqual(host, "1.2.3.4")
+ self.assertEqual(port, 8443)
+
+ # make a test server, and wire up the client
+ http_server = self._make_connection(
+ client_factory, expected_sni=b"xn--bcher-kva.com"
+ )
+
+ self.assertEqual(len(http_server.requests), 1)
+ request = http_server.requests[0]
+ self.assertEqual(request.method, b"GET")
+ self.assertEqual(request.path, b"/foo/bar")
+ self.assertEqual(
+ request.requestHeaders.getRawHeaders(b"host"), [b"xn--bcher-kva.com"]
+ )
+
+ # finish the request
+ request.finish()
+ self.reactor.pump((0.1,))
+ self.successResultOf(test_d)
+
+ def test_idna_srv_target_legacy(self) -> None:
+ """test the behaviour when the target of a _matrix SRV record has idna chars"""
+ self.agent = self._make_agent()
+
+ # Return no entries for the _matrix-fed lookup, and a response for _matrix.
+ self.mock_resolver.resolve_service.side_effect = [
+ [],
+ [Server(host=b"xn--trget-3qa.com", port=8443)],
+ ] # târget.com
+ self.reactor.lookups["xn--trget-3qa.com"] = "1.2.3.4"
+
+ test_d = self._make_get_request(
+ b"matrix-federation://xn--bcher-kva.com/foo/bar"
+ )
+
+ # Nothing happened yet
+ self.assertNoResult(test_d)
+
+ self.mock_resolver.resolve_service.assert_has_calls(
+ [
+ call(b"_matrix-fed._tcp.xn--bcher-kva.com"),
+ call(b"_matrix._tcp.xn--bcher-kva.com"),
+ ]
)
# Make sure treq is trying to connect
@@ -1394,7 +1562,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
self.assertIsNone(r.delegated_server)
def test_srv_fallbacks(self) -> None:
- """Test that other SRV results are tried if the first one fails."""
+ """Test that other SRV results are tried if the first one fails for _matrix-fed SRV."""
self.agent = self._make_agent()
self.mock_resolver.resolve_service.return_value = [
@@ -1409,7 +1577,67 @@ class MatrixFederationAgentTests(unittest.TestCase):
self.assertNoResult(test_d)
self.mock_resolver.resolve_service.assert_called_once_with(
- b"_matrix._tcp.testserv"
+ b"_matrix-fed._tcp.testserv"
+ )
+
+ # We should see an attempt to connect to the first server
+ clients = self.reactor.tcpClients
+ self.assertEqual(len(clients), 1)
+ (host, port, client_factory, _timeout, _bindAddress) = clients.pop(0)
+ self.assertEqual(host, "1.2.3.4")
+ self.assertEqual(port, 8443)
+
+ # Fonx the connection
+ client_factory.clientConnectionFailed(None, Exception("nope"))
+
+ # There's a 300ms delay in HostnameEndpoint
+ self.reactor.pump((0.4,))
+
+ # Hasn't failed yet
+ self.assertNoResult(test_d)
+
+ # We shouldnow see an attempt to connect to the second server
+ clients = self.reactor.tcpClients
+ self.assertEqual(len(clients), 1)
+ (host, port, client_factory, _timeout, _bindAddress) = clients.pop(0)
+ self.assertEqual(host, "1.2.3.4")
+ self.assertEqual(port, 8444)
+
+ # make a test server, and wire up the client
+ http_server = self._make_connection(client_factory, expected_sni=b"testserv")
+
+ self.assertEqual(len(http_server.requests), 1)
+ request = http_server.requests[0]
+ self.assertEqual(request.method, b"GET")
+ self.assertEqual(request.path, b"/foo/bar")
+ self.assertEqual(request.requestHeaders.getRawHeaders(b"host"), [b"testserv"])
+
+ # finish the request
+ request.finish()
+ self.reactor.pump((0.1,))
+ self.successResultOf(test_d)
+
+ def test_srv_fallbacks_legacy(self) -> None:
+ """Test that other SRV results are tried if the first one fails for _matrix SRV."""
+ self.agent = self._make_agent()
+
+ # Return no entries for the _matrix-fed lookup, and a response for _matrix.
+ self.mock_resolver.resolve_service.side_effect = [
+ [],
+ [
+ Server(host=b"target.com", port=8443),
+ Server(host=b"target.com", port=8444),
+ ],
+ ]
+ self.reactor.lookups["target.com"] = "1.2.3.4"
+
+ test_d = self._make_get_request(b"matrix-federation://testserv/foo/bar")
+
+ # Nothing happened yet
+ self.assertNoResult(test_d)
+
+ self.mock_resolver.resolve_service.assert_has_calls(
+ [call(b"_matrix-fed._tcp.testserv"), call(b"_matrix._tcp.testserv")]
)
# We should see an attempt to connect to the first server
@@ -1449,6 +1677,43 @@ class MatrixFederationAgentTests(unittest.TestCase):
self.reactor.pump((0.1,))
self.successResultOf(test_d)
+ def test_srv_no_fallback_to_legacy(self) -> None:
+ """Test that _matrix SRV results are not tried if the _matrix-fed one fails."""
+ self.agent = self._make_agent()
+
+ # Return a failing entry for _matrix-fed.
+ self.mock_resolver.resolve_service.side_effect = [
+ [Server(host=b"target.com", port=8443)],
+ [],
+ ]
+ self.reactor.lookups["target.com"] = "1.2.3.4"
+
+ test_d = self._make_get_request(b"matrix-federation://testserv/foo/bar")
+
+ # Nothing happened yet
+ self.assertNoResult(test_d)
+
+ # Only the _matrix-fed is checked, _matrix is ignored.
+ self.mock_resolver.resolve_service.assert_called_once_with(
+ b"_matrix-fed._tcp.testserv"
+ )
+
+ # We should see an attempt to connect to the first server
+ clients = self.reactor.tcpClients
+ self.assertEqual(len(clients), 1)
+ (host, port, client_factory, _timeout, _bindAddress) = clients.pop(0)
+ self.assertEqual(host, "1.2.3.4")
+ self.assertEqual(port, 8443)
+
+ # Fonx the connection
+ client_factory.clientConnectionFailed(None, Exception("nope"))
+
+ # There's a 300ms delay in HostnameEndpoint
+ self.reactor.pump((0.4,))
+
+ # Failed to resolve a server.
+ self.assertFailure(test_d, Exception)
+
class TestCachePeriodFromHeaders(unittest.TestCase):
def test_cache_control(self) -> None:
diff --git a/tests/logging/test_remote_handler.py b/tests/logging/test_remote_handler.py
index 5191e31a..45eac100 100644
--- a/tests/logging/test_remote_handler.py
+++ b/tests/logging/test_remote_handler.py
@@ -78,11 +78,11 @@ class RemoteHandlerTestCase(LoggerCleanupMixin, TestCase):
logger = self.get_logger(handler)
# Send some debug messages
- for i in range(0, 3):
+ for i in range(3):
logger.debug("debug %s" % (i,))
# Send a bunch of useful messages
- for i in range(0, 7):
+ for i in range(7):
logger.info("info %s" % (i,))
# The last debug message pushes it past the maximum buffer
@@ -108,15 +108,15 @@ class RemoteHandlerTestCase(LoggerCleanupMixin, TestCase):
logger = self.get_logger(handler)
# Send some debug messages
- for i in range(0, 3):
+ for i in range(3):
logger.debug("debug %s" % (i,))
# Send a bunch of useful messages
- for i in range(0, 10):
+ for i in range(10):
logger.warning("warn %s" % (i,))
# Send a bunch of info messages
- for i in range(0, 3):
+ for i in range(3):
logger.info("info %s" % (i,))
# The last debug message pushes it past the maximum buffer
@@ -144,7 +144,7 @@ class RemoteHandlerTestCase(LoggerCleanupMixin, TestCase):
logger = self.get_logger(handler)
# Send a bunch of useful messages
- for i in range(0, 20):
+ for i in range(20):
logger.warning("warn %s" % (i,))
# Allow the reconnection
diff --git a/tests/push/test_email.py b/tests/push/test_email.py
index 4b5c96ae..73a430dd 100644
--- a/tests/push/test_email.py
+++ b/tests/push/test_email.py
@@ -13,10 +13,12 @@
# limitations under the License.
import email.message
import os
+from http import HTTPStatus
from typing import Any, Dict, List, Sequence, Tuple
import attr
import pkg_resources
+from parameterized import parameterized
from twisted.internet.defer import Deferred
from twisted.test.proto_helpers import MemoryReactor
@@ -25,9 +27,11 @@ import synapse.rest.admin
from synapse.api.errors import Codes, SynapseError
from synapse.push.emailpusher import EmailPusher
from synapse.rest.client import login, room
+from synapse.rest.synapse.client.unsubscribe import UnsubscribeResource
from synapse.server import HomeServer
from synapse.util import Clock
+from tests.server import FakeSite, make_request
from tests.unittest import HomeserverTestCase
@@ -175,6 +179,57 @@ class EmailPusherTests(HomeserverTestCase):
self._check_for_mail()
+ @parameterized.expand([(False,), (True,)])
+ def test_unsubscribe(self, use_post: bool) -> None:
+ # Create a simple room with two users
+ room = self.helper.create_room_as(self.user_id, tok=self.access_token)
+ self.helper.invite(
+ room=room, src=self.user_id, tok=self.access_token, targ=self.others[0].id
+ )
+ self.helper.join(room=room, user=self.others[0].id, tok=self.others[0].token)
+
+ # The other user sends a single message.
+ self.helper.send(room, body="Hi!", tok=self.others[0].token)
+
+ # We should get emailed about that message
+ args, kwargs = self._check_for_mail()
+
+ # That email should contain an unsubscribe link in the body and header.
+ msg: bytes = args[5]
+
+ # Multipart: plain text, base 64 encoded; html, base 64 encoded
+ multipart_msg = email.message_from_bytes(msg)
+ txt = multipart_msg.get_payload()[0].get_payload(decode=True).decode()
+ html = multipart_msg.get_payload()[1].get_payload(decode=True).decode()
+ self.assertIn("/_synapse/client/unsubscribe", txt)
+ self.assertIn("/_synapse/client/unsubscribe", html)
+
+ # The unsubscribe headers should exist.
+ assert multipart_msg.get("List-Unsubscribe") is not None
+ self.assertIsNotNone(multipart_msg.get("List-Unsubscribe-Post"))
+
+ # Open the unsubscribe link.
+ unsubscribe_link = multipart_msg["List-Unsubscribe"].strip("<>")
+ unsubscribe_resource = UnsubscribeResource(self.hs)
+ channel = make_request(
+ self.reactor,
+ FakeSite(unsubscribe_resource, self.reactor),
+ "POST" if use_post else "GET",
+ unsubscribe_link,
+ shorthand=False,
+ )
+ self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
+
+ # Ensure the pusher was removed.
+ pushers = list(
+ self.get_success(
+ self.hs.get_datastores().main.get_pushers_by(
+ {"user_name": self.user_id}
+ )
+ )
+ )
+ self.assertEqual(pushers, [])
+
def test_invite_sends_email(self) -> None:
# Create a room and invite the user to it
room = self.helper.create_room_as(self.others[0].id, tok=self.others[0].token)
diff --git a/tests/replication/storage/_base.py b/tests/replication/storage/_base.py
index de26a62a..afcc80a8 100644
--- a/tests/replication/storage/_base.py
+++ b/tests/replication/storage/_base.py
@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Any, Iterable, Optional
+from typing import Any, Callable, Iterable, Optional
from unittest.mock import Mock
from twisted.test.proto_helpers import MemoryReactor
@@ -47,24 +47,31 @@ class BaseWorkerStoreTestCase(BaseStreamTestCase):
self.pump(0.1)
def check(
- self, method: str, args: Iterable[Any], expected_result: Optional[Any] = None
+ self,
+ method: str,
+ args: Iterable[Any],
+ expected_result: Optional[Any] = None,
+ asserter: Optional[Callable[[Any, Any, Optional[Any]], None]] = None,
) -> None:
+ if asserter is None:
+ asserter = self.assertEqual
+
master_result = self.get_success(getattr(self.master_store, method)(*args))
worker_result = self.get_success(getattr(self.worker_store, method)(*args))
if expected_result is not None:
- self.assertEqual(
+ asserter(
master_result,
expected_result,
"Expected master result to be %r but was %r"
% (expected_result, master_result),
)
- self.assertEqual(
+ asserter(
worker_result,
expected_result,
"Expected worker result to be %r but was %r"
% (expected_result, worker_result),
)
- self.assertEqual(
+ asserter(
master_result,
worker_result,
"Worker result %r does not match master result %r"
diff --git a/tests/replication/storage/test_events.py b/tests/replication/storage/test_events.py
index af25815f..17716253 100644
--- a/tests/replication/storage/test_events.py
+++ b/tests/replication/storage/test_events.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
-from typing import Any, Callable, Iterable, List, Optional, Tuple
+from typing import Any, Iterable, List, Optional, Tuple
from canonicaljson import encode_canonical_json
from parameterized import parameterized
@@ -21,7 +21,7 @@ from twisted.test.proto_helpers import MemoryReactor
from synapse.api.constants import ReceiptTypes
from synapse.api.room_versions import RoomVersions
-from synapse.events import EventBase, _EventInternalMetadata, make_event_from_dict
+from synapse.events import EventBase, make_event_from_dict
from synapse.events.snapshot import EventContext
from synapse.handlers.room import RoomEventSource
from synapse.server import HomeServer
@@ -46,32 +46,9 @@ ROOM_ID = "!room:test"
logger = logging.getLogger(__name__)
-def dict_equals(self: EventBase, other: EventBase) -> bool:
- me = encode_canonical_json(self.get_pdu_json())
- them = encode_canonical_json(other.get_pdu_json())
- return me == them
-
-
-def patch__eq__(cls: object) -> Callable[[], None]:
- eq = getattr(cls, "__eq__", None)
- cls.__eq__ = dict_equals # type: ignore[assignment]
-
- def unpatch() -> None:
- if eq is not None:
- cls.__eq__ = eq # type: ignore[method-assign]
-
- return unpatch
-
-
class EventsWorkerStoreTestCase(BaseWorkerStoreTestCase):
STORE_TYPE = EventsWorkerStore
- def setUp(self) -> None:
- # Patch up the equality operator for events so that we can check
- # whether lists of events match using assertEqual
- self.unpatches = [patch__eq__(_EventInternalMetadata), patch__eq__(EventBase)]
- super().setUp()
-
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
super().prepare(reactor, clock, hs)
@@ -84,13 +61,19 @@ class EventsWorkerStoreTestCase(BaseWorkerStoreTestCase):
)
)
- def tearDown(self) -> None:
- [unpatch() for unpatch in self.unpatches]
+ def assertEventsEqual(
+ self, first: EventBase, second: EventBase, msg: Optional[Any] = None
+ ) -> None:
+ self.assertEqual(
+ encode_canonical_json(first.get_pdu_json()),
+ encode_canonical_json(second.get_pdu_json()),
+ msg,
+ )
def test_get_latest_event_ids_in_room(self) -> None:
create = self.persist(type="m.room.create", key="", creator=USER_ID)
self.replicate()
- self.check("get_latest_event_ids_in_room", (ROOM_ID,), [create.event_id])
+ self.check("get_latest_event_ids_in_room", (ROOM_ID,), {create.event_id})
join = self.persist(
type="m.room.member",
@@ -99,7 +82,7 @@ class EventsWorkerStoreTestCase(BaseWorkerStoreTestCase):
prev_events=[(create.event_id, {})],
)
self.replicate()
- self.check("get_latest_event_ids_in_room", (ROOM_ID,), [join.event_id])
+ self.check("get_latest_event_ids_in_room", (ROOM_ID,), {join.event_id})
def test_redactions(self) -> None:
self.persist(type="m.room.create", key="", creator=USER_ID)
@@ -107,7 +90,7 @@ class EventsWorkerStoreTestCase(BaseWorkerStoreTestCase):
msg = self.persist(type="m.room.message", msgtype="m.text", body="Hello")
self.replicate()
- self.check("get_event", [msg.event_id], msg)
+ self.check("get_event", [msg.event_id], msg, asserter=self.assertEventsEqual)
redaction = self.persist(type="m.room.redaction", redacts=msg.event_id)
self.replicate()
@@ -119,7 +102,9 @@ class EventsWorkerStoreTestCase(BaseWorkerStoreTestCase):
redacted = make_event_from_dict(
msg_dict, internal_metadata_dict=msg.internal_metadata.get_dict()
)
- self.check("get_event", [msg.event_id], redacted)
+ self.check(
+ "get_event", [msg.event_id], redacted, asserter=self.assertEventsEqual
+ )
def test_backfilled_redactions(self) -> None:
self.persist(type="m.room.create", key="", creator=USER_ID)
@@ -127,7 +112,7 @@ class EventsWorkerStoreTestCase(BaseWorkerStoreTestCase):
msg = self.persist(type="m.room.message", msgtype="m.text", body="Hello")
self.replicate()
- self.check("get_event", [msg.event_id], msg)
+ self.check("get_event", [msg.event_id], msg, asserter=self.assertEventsEqual)
redaction = self.persist(
type="m.room.redaction", redacts=msg.event_id, backfill=True
@@ -141,7 +126,9 @@ class EventsWorkerStoreTestCase(BaseWorkerStoreTestCase):
redacted = make_event_from_dict(
msg_dict, internal_metadata_dict=msg.internal_metadata.get_dict()
)
- self.check("get_event", [msg.event_id], redacted)
+ self.check(
+ "get_event", [msg.event_id], redacted, asserter=self.assertEventsEqual
+ )
def test_invites(self) -> None:
self.persist(type="m.room.create", key="", creator=USER_ID)
diff --git a/tests/replication/tcp/streams/test_events.py b/tests/replication/tcp/streams/test_events.py
index 65ef4bb1..128fc3e0 100644
--- a/tests/replication/tcp/streams/test_events.py
+++ b/tests/replication/tcp/streams/test_events.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Any, List, Optional, Sequence
+from typing import Any, List, Optional
from twisted.test.proto_helpers import MemoryReactor
@@ -139,7 +139,7 @@ class EventsStreamTestCase(BaseStreamTestCase):
)
# this is the point in the DAG where we make a fork
- fork_point: Sequence[str] = self.get_success(
+ fork_point = self.get_success(
self.hs.get_datastores().main.get_latest_event_ids_in_room(self.room_id)
)
@@ -294,7 +294,7 @@ class EventsStreamTestCase(BaseStreamTestCase):
)
# this is the point in the DAG where we make a fork
- fork_point: Sequence[str] = self.get_success(
+ fork_point = self.get_success(
self.hs.get_datastores().main.get_latest_event_ids_in_room(self.room_id)
)
@@ -316,14 +316,14 @@ class EventsStreamTestCase(BaseStreamTestCase):
self.test_handler.received_rdata_rows.clear()
# now roll back all that state by de-modding the users
- prev_events = fork_point
+ prev_events = list(fork_point)
pl_events = []
for u in user_ids:
pls["users"][u] = 0
e = self.get_success(
inject_event(
self.hs,
- prev_event_ids=list(prev_events),
+ prev_event_ids=prev_events,
type=EventTypes.PowerLevels,
state_key="",
sender=self.user_id,
diff --git a/tests/replication/tcp/streams/test_to_device.py b/tests/replication/tcp/streams/test_to_device.py
index fb9eac66..ab379e8c 100644
--- a/tests/replication/tcp/streams/test_to_device.py
+++ b/tests/replication/tcp/streams/test_to_device.py
@@ -49,7 +49,7 @@ class ToDeviceStreamTestCase(BaseStreamTestCase):
# add messages to the device inbox for user1 up until the
# limit defined for a stream update batch
- for i in range(0, _STREAM_UPDATE_TARGET_ROW_COUNT):
+ for i in range(_STREAM_UPDATE_TARGET_ROW_COUNT):
msg["content"] = {"device": {}}
messages = {user1: {"device": msg}}
diff --git a/tests/replication/test_federation_sender_shard.py b/tests/replication/test_federation_sender_shard.py
index 9b28cd47..59f4fdc7 100644
--- a/tests/replication/test_federation_sender_shard.py
+++ b/tests/replication/test_federation_sender_shard.py
@@ -261,7 +261,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
builder = factory.for_room_version(room_version, event_dict)
join_event = self.get_success(
- builder.build(prev_event_ids=prev_event_ids, auth_event_ids=None)
+ builder.build(prev_event_ids=list(prev_event_ids), auth_event_ids=None)
)
self.get_success(federation.on_send_membership_event(remote_server, join_event))
diff --git a/tests/rest/admin/test_federation.py b/tests/rest/admin/test_federation.py
index 4c7864c6..0e2824d1 100644
--- a/tests/rest/admin/test_federation.py
+++ b/tests/rest/admin/test_federation.py
@@ -510,7 +510,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
Args:
number_destinations: Number of destinations to be created
"""
- for i in range(0, number_destinations):
+ for i in range(number_destinations):
dest = f"sub{i}.example.com"
self._create_destination(dest, 50, 50, 50, 100)
@@ -690,7 +690,7 @@ class DestinationMembershipTestCase(unittest.HomeserverTestCase):
self._check_fields(channel_desc.json_body["rooms"])
# test that both lists have different directions
- for i in range(0, number_rooms):
+ for i in range(number_rooms):
self.assertEqual(
channel_asc.json_body["rooms"][i]["room_id"],
channel_desc.json_body["rooms"][number_rooms - 1 - i]["room_id"],
@@ -777,7 +777,7 @@ class DestinationMembershipTestCase(unittest.HomeserverTestCase):
Args:
number_rooms: Number of rooms to be created
"""
- for _ in range(0, number_rooms):
+ for _ in range(number_rooms):
room_id = self.helper.create_room_as(
self.admin_user, tok=self.admin_user_tok
)
diff --git a/tests/rest/admin/test_room.py b/tests/rest/admin/test_room.py
index eb50086c..6ed451d7 100644
--- a/tests/rest/admin/test_room.py
+++ b/tests/rest/admin/test_room.py
@@ -15,26 +15,34 @@ import json
import time
import urllib.parse
from typing import List, Optional
-from unittest.mock import Mock
+from unittest.mock import AsyncMock, Mock
from parameterized import parameterized
+from twisted.internet.task import deferLater
from twisted.test.proto_helpers import MemoryReactor
import synapse.rest.admin
from synapse.api.constants import EventTypes, Membership, RoomTypes
from synapse.api.errors import Codes
-from synapse.handlers.pagination import PaginationHandler, PurgeStatus
+from synapse.handlers.pagination import (
+ PURGE_ROOM_ACTION_NAME,
+ SHUTDOWN_AND_PURGE_ROOM_ACTION_NAME,
+)
from synapse.rest.client import directory, events, login, room
from synapse.server import HomeServer
+from synapse.types import UserID
from synapse.util import Clock
-from synapse.util.stringutils import random_string
+from synapse.util.task_scheduler import TaskScheduler
from tests import unittest
"""Tests admin REST events for /rooms paths."""
+ONE_HOUR_IN_S = 3600
+
+
class DeleteRoomTestCase(unittest.HomeserverTestCase):
servlets = [
synapse.rest.admin.register_servlets,
@@ -46,6 +54,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.event_creation_handler = hs.get_event_creation_handler()
+ self.task_scheduler = hs.get_task_scheduler()
hs.config.consent.user_consent_version = "1"
consent_uri_builder = Mock()
@@ -476,6 +485,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase):
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.event_creation_handler = hs.get_event_creation_handler()
+ self.task_scheduler = hs.get_task_scheduler()
hs.config.consent.user_consent_version = "1"
consent_uri_builder = Mock()
@@ -502,6 +512,9 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase):
)
self.url_status_by_delete_id = "/_synapse/admin/v2/rooms/delete_status/"
+ self.room_member_handler = hs.get_room_member_handler()
+ self.pagination_handler = hs.get_pagination_handler()
+
@parameterized.expand(
[
("DELETE", "/_synapse/admin/v2/rooms/%s"),
@@ -661,7 +674,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase):
delete_id1 = channel.json_body["delete_id"]
# go ahead
- self.reactor.advance(PaginationHandler.CLEAR_PURGE_AFTER_MS / 1000 / 2)
+ self.reactor.advance(TaskScheduler.KEEP_TASKS_FOR_MS / 1000 / 2)
# second task
channel = self.make_request(
@@ -686,12 +699,14 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase):
self.assertEqual(2, len(channel.json_body["results"]))
self.assertEqual("complete", channel.json_body["results"][0]["status"])
self.assertEqual("complete", channel.json_body["results"][1]["status"])
- self.assertEqual(delete_id1, channel.json_body["results"][0]["delete_id"])
- self.assertEqual(delete_id2, channel.json_body["results"][1]["delete_id"])
+ delete_ids = {delete_id1, delete_id2}
+ self.assertTrue(channel.json_body["results"][0]["delete_id"] in delete_ids)
+ delete_ids.remove(channel.json_body["results"][0]["delete_id"])
+ self.assertTrue(channel.json_body["results"][1]["delete_id"] in delete_ids)
# get status after more than clearing time for first task
# second task is not cleared
- self.reactor.advance(PaginationHandler.CLEAR_PURGE_AFTER_MS / 1000 / 2)
+ self.reactor.advance(TaskScheduler.KEEP_TASKS_FOR_MS / 1000 / 2)
channel = self.make_request(
"GET",
@@ -705,7 +720,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase):
self.assertEqual(delete_id2, channel.json_body["results"][0]["delete_id"])
# get status after more than clearing time for all tasks
- self.reactor.advance(PaginationHandler.CLEAR_PURGE_AFTER_MS / 1000 / 2)
+ self.reactor.advance(TaskScheduler.KEEP_TASKS_FOR_MS / 1000 / 2)
channel = self.make_request(
"GET",
@@ -721,6 +736,13 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase):
body = {"new_room_user_id": self.admin_user}
+ # Mock PaginationHandler.purge_room to sleep for 100s, so we have time to do a second call
+ # before the purge is over. Note that it doesn't purge anymore, but we don't care.
+ async def purge_room(room_id: str, force: bool) -> None:
+ await deferLater(self.hs.get_reactor(), 100, lambda: None)
+
+ self.pagination_handler.purge_room = AsyncMock(side_effect=purge_room) # type: ignore[method-assign]
+
# first call to delete room
# and do not wait for finish the task
first_channel = self.make_request(
@@ -728,7 +750,6 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase):
self.url.encode("ascii"),
content=body,
access_token=self.admin_user_tok,
- await_result=False,
)
# second call to delete room
@@ -742,7 +763,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase):
self.assertEqual(400, second_channel.code, msg=second_channel.json_body)
self.assertEqual(Codes.UNKNOWN, second_channel.json_body["errcode"])
self.assertEqual(
- f"History purge already in progress for {self.room_id}",
+ f"Purge already in progress for {self.room_id}",
second_channel.json_body["error"],
)
@@ -751,6 +772,9 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase):
self.assertEqual(200, first_channel.code, msg=first_channel.json_body)
self.assertIn("delete_id", first_channel.json_body)
+ # wait for purge_room to finish
+ self.pump(1)
+
# check status after finish the task
self._test_result(
first_channel.json_body["delete_id"],
@@ -972,6 +996,115 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase):
# Assert we can no longer peek into the room
self._assert_peek(self.room_id, expect_code=403)
+ @unittest.override_config({"forgotten_room_retention_period": "1d"})
+ def test_purge_forgotten_room(self) -> None:
+ # Create a test room
+ room_id = self.helper.create_room_as(
+ self.admin_user,
+ tok=self.admin_user_tok,
+ )
+
+ self.helper.leave(room_id, user=self.admin_user, tok=self.admin_user_tok)
+ self.get_success(
+ self.room_member_handler.forget(
+ UserID.from_string(self.admin_user), room_id
+ )
+ )
+
+ # Test that room is not yet purged
+ with self.assertRaises(AssertionError):
+ self._is_purged(room_id)
+
+ # Advance 24 hours in the future, past the `forgotten_room_retention_period`
+ self.reactor.advance(24 * ONE_HOUR_IN_S)
+
+ self._is_purged(room_id)
+
+ def test_scheduled_purge_room(self) -> None:
+ # Create a test room
+ room_id = self.helper.create_room_as(
+ self.admin_user,
+ tok=self.admin_user_tok,
+ )
+ self.helper.leave(room_id, user=self.admin_user, tok=self.admin_user_tok)
+
+ # Schedule a purge 10 seconds in the future
+ self.get_success(
+ self.task_scheduler.schedule_task(
+ PURGE_ROOM_ACTION_NAME,
+ resource_id=room_id,
+ timestamp=self.clock.time_msec() + 10 * 1000,
+ )
+ )
+
+ # Test that room is not yet purged
+ with self.assertRaises(AssertionError):
+ self._is_purged(room_id)
+
+ # Wait for next scheduler run
+ self.reactor.advance(TaskScheduler.SCHEDULE_INTERVAL_MS)
+
+ self._is_purged(room_id)
+
+ def test_schedule_shutdown_room(self) -> None:
+ # Create a test room
+ room_id = self.helper.create_room_as(
+ self.other_user,
+ tok=self.other_user_tok,
+ )
+
+ # Schedule a shutdown 10 seconds in the future
+ delete_id = self.get_success(
+ self.task_scheduler.schedule_task(
+ SHUTDOWN_AND_PURGE_ROOM_ACTION_NAME,
+ resource_id=room_id,
+ params={
+ "requester_user_id": self.admin_user,
+ "new_room_user_id": self.admin_user,
+ "new_room_name": None,
+ "message": None,
+ "block": False,
+ "purge": True,
+ "force_purge": True,
+ },
+ timestamp=self.clock.time_msec() + 10 * 1000,
+ )
+ )
+
+ # Test that room is not yet shutdown
+ self._is_member(room_id, self.other_user)
+
+ # Test that room is not yet purged
+ with self.assertRaises(AssertionError):
+ self._is_purged(room_id)
+
+ # Wait for next scheduler run
+ self.reactor.advance(TaskScheduler.SCHEDULE_INTERVAL_MS)
+
+ # Test that all users has been kicked (room is shutdown)
+ self._has_no_members(room_id)
+
+ self._is_purged(room_id)
+
+ # Retrieve delete results
+ result = self.make_request(
+ "GET",
+ self.url_status_by_delete_id + delete_id,
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(200, result.code, msg=result.json_body)
+
+ # Check that the user is in kicked_users
+ self.assertIn(
+ self.other_user, result.json_body["shutdown_room"]["kicked_users"]
+ )
+
+ new_room_id = result.json_body["shutdown_room"]["new_room_id"]
+ self.assertTrue(new_room_id)
+
+ # Check that the user is actually in the new room
+ self._is_member(new_room_id, self.other_user)
+
def _is_blocked(self, room_id: str, expect: bool = True) -> None:
"""Assert that the room is blocked or not"""
d = self.store.is_room_blocked(room_id)
@@ -1034,7 +1167,6 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase):
kicked_user: a user_id which is kicked from the room
expect_new_room: if we expect that a new room was created
"""
-
# get information by room_id
channel_room_id = self.make_request(
"GET",
@@ -1957,11 +2089,8 @@ class RoomMessagesTestCase(unittest.HomeserverTestCase):
self.assertEqual(len(chunk), 2, [event["content"] for event in chunk])
# Purge every event before the second event.
- purge_id = random_string(16)
- pagination_handler._purges_by_id[purge_id] = PurgeStatus()
self.get_success(
- pagination_handler._purge_history(
- purge_id=purge_id,
+ pagination_handler.purge_history(
room_id=self.room_id,
token=second_token_str,
delete_local_events=True,
diff --git a/tests/rest/admin/test_server_notice.py b/tests/rest/admin/test_server_notice.py
index 28b99957..dfd14f57 100644
--- a/tests/rest/admin/test_server_notice.py
+++ b/tests/rest/admin/test_server_notice.py
@@ -22,6 +22,7 @@ from synapse.server import HomeServer
from synapse.storage.roommember import RoomsForUser
from synapse.types import JsonDict
from synapse.util import Clock
+from synapse.util.stringutils import random_string
from tests import unittest
from tests.unittest import override_config
@@ -413,11 +414,24 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase):
self.assertEqual(messages[0]["content"]["body"], "test msg one")
self.assertEqual(messages[0]["sender"], "@notices:test")
+ random_string(16)
+
# shut down and purge room
self.get_success(
- self.room_shutdown_handler.shutdown_room(first_room_id, self.admin_user)
- )
- self.get_success(self.pagination_handler.purge_room(first_room_id))
+ self.room_shutdown_handler.shutdown_room(
+ first_room_id,
+ {
+ "requester_user_id": self.admin_user,
+ "new_room_user_id": None,
+ "new_room_name": None,
+ "message": None,
+ "block": False,
+ "purge": True,
+ "force_purge": False,
+ },
+ )
+ )
+ self.get_success(self.pagination_handler.purge_room(first_room_id, force=False))
# user is not member anymore
self._check_invite_and_join_status(self.other_user, 0, 0)
diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py
index 761871b9..b326ad2c 100644
--- a/tests/rest/admin/test_user.py
+++ b/tests/rest/admin/test_user.py
@@ -1146,6 +1146,32 @@ class UsersListTestCase(unittest.HomeserverTestCase):
users = {user["name"]: user for user in channel.json_body["users"]}
self.assertIs(users[user_id]["erased"], True)
+ def test_filter_locked(self) -> None:
+ # Create a new user.
+ user_id = self.register_user("lockme", "lockme")
+
+ # Lock them
+ self.get_success(self.store.set_user_locked_status(user_id, True))
+
+ # Locked user should appear in list users API
+ channel = self.make_request(
+ "GET",
+ self.url + "?locked=true",
+ access_token=self.admin_user_tok,
+ )
+ users = {user["name"]: user for user in channel.json_body["users"]}
+ self.assertIn(user_id, users)
+ self.assertTrue(users[user_id]["locked"])
+
+ # Locked user should not appear in list users API
+ channel = self.make_request(
+ "GET",
+ self.url + "?locked=false",
+ access_token=self.admin_user_tok,
+ )
+ users = {user["name"]: user for user in channel.json_body["users"]}
+ self.assertNotIn(user_id, users)
+
def _order_test(
self,
expected_user_list: List[str],
diff --git a/tests/rest/client/test_account.py b/tests/rest/client/test_account.py
index e9f495e2..cffbda9a 100644
--- a/tests/rest/client/test_account.py
+++ b/tests/rest/client/test_account.py
@@ -31,6 +31,7 @@ from synapse.rest import admin
from synapse.rest.client import account, login, register, room
from synapse.rest.synapse.client.password_reset import PasswordResetSubmitTokenResource
from synapse.server import HomeServer
+from synapse.storage._base import db_to_json
from synapse.types import JsonDict, UserID
from synapse.util import Clock
@@ -134,6 +135,18 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
# Assert we can't log in with the old password
self.attempt_wrong_password_login("kermit", old_password)
+ # Check that the UI Auth information doesn't store the password in the database.
+ #
+ # Note that we don't have the UI Auth session ID, so just pull out the single
+ # row.
+ ui_auth_data = self.get_success(
+ self.store.db_pool.simple_select_one(
+ "ui_auth_sessions", keyvalues={}, retcols=("clientdict",)
+ )
+ )
+ client_dict = db_to_json(ui_auth_data["clientdict"])
+ self.assertNotIn("new_password", client_dict)
+
@override_config({"rc_3pid_validation": {"burst_count": 3}})
def test_ratelimit_by_email(self) -> None:
"""Test that we ratelimit /requestToken for the same email."""
@@ -562,7 +575,7 @@ class DeactivateTestCase(unittest.HomeserverTestCase):
# create a bunch of users and add keys for them
users = []
- for i in range(0, 20):
+ for i in range(20):
user_id = self.register_user("missPiggy" + str(i), "test")
users.append((user_id,))
diff --git a/tests/rest/client/test_login.py b/tests/rest/client/test_login.py
index a2a65895..768d7ad4 100644
--- a/tests/rest/client/test_login.py
+++ b/tests/rest/client/test_login.py
@@ -176,10 +176,10 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
def test_POST_ratelimiting_per_address(self) -> None:
# Create different users so we're sure not to be bothered by the per-user
# ratelimiter.
- for i in range(0, 6):
+ for i in range(6):
self.register_user("kermit" + str(i), "monkey")
- for i in range(0, 6):
+ for i in range(6):
params = {
"type": "m.login.password",
"identifier": {"type": "m.id.user", "user": "kermit" + str(i)},
@@ -228,7 +228,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
def test_POST_ratelimiting_per_account(self) -> None:
self.register_user("kermit", "monkey")
- for i in range(0, 6):
+ for i in range(6):
params = {
"type": "m.login.password",
"identifier": {"type": "m.id.user", "user": "kermit"},
@@ -277,7 +277,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
def test_POST_ratelimiting_per_account_failed_attempts(self) -> None:
self.register_user("kermit", "monkey")
- for i in range(0, 6):
+ for i in range(6):
params = {
"type": "m.login.password",
"identifier": {"type": "m.id.user", "user": "kermit"},
diff --git a/tests/rest/client/test_receipts.py b/tests/rest/client/test_receipts.py
index 2a7fcea3..ec638c89 100644
--- a/tests/rest/client/test_receipts.py
+++ b/tests/rest/client/test_receipts.py
@@ -11,11 +11,16 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+from http import HTTPStatus
+from typing import Optional
+
from twisted.test.proto_helpers import MemoryReactor
import synapse.rest.admin
-from synapse.rest.client import login, receipts, register
+from synapse.api.constants import EduTypes, EventTypes, HistoryVisibility, ReceiptTypes
+from synapse.rest.client import login, receipts, room, sync
from synapse.server import HomeServer
+from synapse.types import JsonDict
from synapse.util import Clock
from tests import unittest
@@ -24,30 +29,113 @@ from tests import unittest
class ReceiptsTestCase(unittest.HomeserverTestCase):
servlets = [
login.register_servlets,
- register.register_servlets,
receipts.register_servlets,
synapse.rest.admin.register_servlets,
+ room.register_servlets,
+ sync.register_servlets,
]
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
- self.owner = self.register_user("owner", "pass")
- self.owner_tok = self.login("owner", "pass")
+ self.url = "/sync?since=%s"
+ self.next_batch = "s0"
+
+ # Register the first user
+ self.user_id = self.register_user("kermit", "monkey")
+ self.tok = self.login("kermit", "monkey")
+
+ # Create the room
+ self.room_id = self.helper.create_room_as(self.user_id, tok=self.tok)
+
+ # Register the second user
+ self.user2 = self.register_user("kermit2", "monkey")
+ self.tok2 = self.login("kermit2", "monkey")
+
+ # Join the second user
+ self.helper.join(room=self.room_id, user=self.user2, tok=self.tok2)
def test_send_receipt(self) -> None:
+ # Send a message.
+ res = self.helper.send(self.room_id, body="hello", tok=self.tok)
+
+ # Send a read receipt
+ channel = self.make_request(
+ "POST",
+ f"/rooms/{self.room_id}/receipt/{ReceiptTypes.READ}/{res['event_id']}",
+ {},
+ access_token=self.tok2,
+ )
+ self.assertEqual(channel.code, 200)
+ self.assertNotEqual(self._get_read_receipt(), None)
+
+ def test_send_receipt_unknown_event(self) -> None:
+ """Receipts sent for unknown events are ignored to not break message retention."""
+ # Attempt to send a receipt to an unknown room.
channel = self.make_request(
"POST",
"/rooms/!abc:beep/receipt/m.read/$def",
content={},
- access_token=self.owner_tok,
+ access_token=self.tok2,
+ )
+ self.assertEqual(channel.code, 200, channel.result)
+ self.assertIsNone(self._get_read_receipt())
+
+ # Attempt to send a receipt to an unknown event.
+ channel = self.make_request(
+ "POST",
+ f"/rooms/{self.room_id}/receipt/m.read/$def",
+ content={},
+ access_token=self.tok2,
)
self.assertEqual(channel.code, 200, channel.result)
+ self.assertIsNone(self._get_read_receipt())
+
+ def test_send_receipt_unviewable_event(self) -> None:
+ """Receipts sent for unviewable events are errors."""
+ # Create a room where new users can't see events from before their join
+ # & send events into it.
+ room_id = self.helper.create_room_as(
+ self.user_id,
+ tok=self.tok,
+ extra_content={
+ "preset": "private_chat",
+ "initial_state": [
+ {
+ "content": {"history_visibility": HistoryVisibility.JOINED},
+ "state_key": "",
+ "type": EventTypes.RoomHistoryVisibility,
+ }
+ ],
+ },
+ )
+ res = self.helper.send(room_id, body="hello", tok=self.tok)
+
+ # Attempt to send a receipt from the wrong user.
+ channel = self.make_request(
+ "POST",
+ f"/rooms/{room_id}/receipt/{ReceiptTypes.READ}/{res['event_id']}",
+ content={},
+ access_token=self.tok2,
+ )
+ self.assertEqual(channel.code, 403, channel.result)
+
+ # Join the user to the room, but they still can't see the event.
+ self.helper.invite(room_id, self.user_id, self.user2, tok=self.tok)
+ self.helper.join(room=room_id, user=self.user2, tok=self.tok2)
+
+ channel = self.make_request(
+ "POST",
+ f"/rooms/{room_id}/receipt/{ReceiptTypes.READ}/{res['event_id']}",
+ content={},
+ access_token=self.tok2,
+ )
+ self.assertEqual(channel.code, 403, channel.result)
def test_send_receipt_invalid_room_id(self) -> None:
channel = self.make_request(
"POST",
"/rooms/not-a-room-id/receipt/m.read/$def",
content={},
- access_token=self.owner_tok,
+ access_token=self.tok,
)
self.assertEqual(channel.code, 400, channel.result)
self.assertEqual(
@@ -59,7 +147,7 @@ class ReceiptsTestCase(unittest.HomeserverTestCase):
"POST",
"/rooms/!abc:beep/receipt/m.read/not-an-event-id",
content={},
- access_token=self.owner_tok,
+ access_token=self.tok,
)
self.assertEqual(channel.code, 400, channel.result)
self.assertEqual(
@@ -71,6 +159,123 @@ class ReceiptsTestCase(unittest.HomeserverTestCase):
"POST",
"/rooms/!abc:beep/receipt/invalid-receipt-type/$def",
content={},
- access_token=self.owner_tok,
+ access_token=self.tok,
)
self.assertEqual(channel.code, 400, channel.result)
+
+ def test_private_read_receipts(self) -> None:
+ # Send a message as the first user
+ res = self.helper.send(self.room_id, body="hello", tok=self.tok)
+
+ # Send a private read receipt to tell the server the first user's message was read
+ channel = self.make_request(
+ "POST",
+ f"/rooms/{self.room_id}/receipt/{ReceiptTypes.READ_PRIVATE}/{res['event_id']}",
+ {},
+ access_token=self.tok2,
+ )
+ self.assertEqual(channel.code, 200)
+
+ # Test that the first user can't see the other user's private read receipt
+ self.assertIsNone(self._get_read_receipt())
+
+ def test_public_receipt_can_override_private(self) -> None:
+ """
+ Sending a public read receipt to the same event which has a private read
+ receipt should cause that receipt to become public.
+ """
+ # Send a message as the first user
+ res = self.helper.send(self.room_id, body="hello", tok=self.tok)
+
+ # Send a private read receipt
+ channel = self.make_request(
+ "POST",
+ f"/rooms/{self.room_id}/receipt/{ReceiptTypes.READ_PRIVATE}/{res['event_id']}",
+ {},
+ access_token=self.tok2,
+ )
+ self.assertEqual(channel.code, 200)
+ self.assertIsNone(self._get_read_receipt())
+
+ # Send a public read receipt
+ channel = self.make_request(
+ "POST",
+ f"/rooms/{self.room_id}/receipt/{ReceiptTypes.READ}/{res['event_id']}",
+ {},
+ access_token=self.tok2,
+ )
+ self.assertEqual(channel.code, 200)
+
+ # Test that we did override the private read receipt
+ self.assertNotEqual(self._get_read_receipt(), None)
+
+ def test_private_receipt_cannot_override_public(self) -> None:
+ """
+ Sending a private read receipt to the same event which has a public read
+ receipt should cause no change.
+ """
+ # Send a message as the first user
+ res = self.helper.send(self.room_id, body="hello", tok=self.tok)
+
+ # Send a public read receipt
+ channel = self.make_request(
+ "POST",
+ f"/rooms/{self.room_id}/receipt/{ReceiptTypes.READ}/{res['event_id']}",
+ {},
+ access_token=self.tok2,
+ )
+ self.assertEqual(channel.code, 200)
+ self.assertNotEqual(self._get_read_receipt(), None)
+
+ # Send a private read receipt
+ channel = self.make_request(
+ "POST",
+ f"/rooms/{self.room_id}/receipt/{ReceiptTypes.READ_PRIVATE}/{res['event_id']}",
+ {},
+ access_token=self.tok2,
+ )
+ self.assertEqual(channel.code, 200)
+
+ # Test that we didn't override the public read receipt
+ self.assertIsNone(self._get_read_receipt())
+
+ def test_read_receipt_with_empty_body_is_rejected(self) -> None:
+ # Send a message as the first user
+ res = self.helper.send(self.room_id, body="hello", tok=self.tok)
+
+ # Send a read receipt for this message with an empty body
+ channel = self.make_request(
+ "POST",
+ f"/rooms/{self.room_id}/receipt/m.read/{res['event_id']}",
+ access_token=self.tok2,
+ )
+ self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST)
+ self.assertEqual(channel.json_body["errcode"], "M_NOT_JSON", channel.json_body)
+
+ def _get_read_receipt(self) -> Optional[JsonDict]:
+ """Syncs and returns the read receipt."""
+
+ # Checks if event is a read receipt
+ def is_read_receipt(event: JsonDict) -> bool:
+ return event["type"] == EduTypes.RECEIPT
+
+ # Sync
+ channel = self.make_request(
+ "GET",
+ self.url % self.next_batch,
+ access_token=self.tok,
+ )
+ self.assertEqual(channel.code, 200)
+
+ # Store the next batch for the next request.
+ self.next_batch = channel.json_body["next_batch"]
+
+ if channel.json_body.get("rooms", None) is None:
+ return None
+
+ # Return the read receipt
+ ephemeral_events = channel.json_body["rooms"]["join"][self.room_id][
+ "ephemeral"
+ ]["events"]
+ receipt_event = filter(is_read_receipt, ephemeral_events)
+ return next(receipt_event, None)
diff --git a/tests/rest/client/test_register.py b/tests/rest/client/test_register.py
index c33393dc..ba4e017a 100644
--- a/tests/rest/client/test_register.py
+++ b/tests/rest/client/test_register.py
@@ -169,7 +169,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
@override_config({"rc_registration": {"per_second": 0.17, "burst_count": 5}})
def test_POST_ratelimiting_guest(self) -> None:
- for i in range(0, 6):
+ for i in range(6):
url = self.url + b"?kind=guest"
channel = self.make_request(b"POST", url, b"{}")
@@ -187,7 +187,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
@override_config({"rc_registration": {"per_second": 0.17, "burst_count": 5}})
def test_POST_ratelimiting(self) -> None:
- for i in range(0, 6):
+ for i in range(6):
request_data = {
"username": "kermit" + str(i),
"password": "monkey",
@@ -1223,7 +1223,7 @@ class RegistrationTokenValidityRestServletTestCase(unittest.HomeserverTestCase):
def test_GET_ratelimiting(self) -> None:
token = "1234"
- for i in range(0, 6):
+ for i in range(6):
channel = self.make_request(
b"GET",
f"{self.url}?token={token}",
diff --git a/tests/rest/client/test_rooms.py b/tests/rest/client/test_rooms.py
index 47c1d38a..7627823d 100644
--- a/tests/rest/client/test_rooms.py
+++ b/tests/rest/client/test_rooms.py
@@ -41,7 +41,6 @@ from synapse.api.errors import Codes, HttpResponseException
from synapse.appservice import ApplicationService
from synapse.events import EventBase
from synapse.events.snapshot import EventContext
-from synapse.handlers.pagination import PurgeStatus
from synapse.rest import admin
from synapse.rest.client import account, directory, login, profile, register, room, sync
from synapse.server import HomeServer
@@ -2086,11 +2085,8 @@ class RoomMessageListTestCase(RoomBase):
self.assertEqual(len(chunk), 2, [event["content"] for event in chunk])
# Purge every event before the second event.
- purge_id = random_string(16)
- pagination_handler._purges_by_id[purge_id] = PurgeStatus()
self.get_success(
- pagination_handler._purge_history(
- purge_id=purge_id,
+ pagination_handler.purge_history(
room_id=self.room_id,
token=second_token_str,
delete_local_events=True,
diff --git a/tests/rest/client/test_sync.py b/tests/rest/client/test_sync.py
index 9c876c7a..d6066525 100644
--- a/tests/rest/client/test_sync.py
+++ b/tests/rest/client/test_sync.py
@@ -13,8 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import json
-from http import HTTPStatus
-from typing import List, Optional
+from typing import List
from parameterized import parameterized
@@ -22,7 +21,6 @@ from twisted.test.proto_helpers import MemoryReactor
import synapse.rest.admin
from synapse.api.constants import (
- EduTypes,
EventContentFields,
EventTypes,
ReceiptTypes,
@@ -376,156 +374,6 @@ class SyncKnockTestCase(KnockingStrippedStateEventHelperMixin):
)
-class ReadReceiptsTestCase(unittest.HomeserverTestCase):
- servlets = [
- synapse.rest.admin.register_servlets,
- login.register_servlets,
- receipts.register_servlets,
- room.register_servlets,
- sync.register_servlets,
- ]
-
- def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
- config = self.default_config()
-
- return self.setup_test_homeserver(config=config)
-
- def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
- self.url = "/sync?since=%s"
- self.next_batch = "s0"
-
- # Register the first user
- self.user_id = self.register_user("kermit", "monkey")
- self.tok = self.login("kermit", "monkey")
-
- # Create the room
- self.room_id = self.helper.create_room_as(self.user_id, tok=self.tok)
-
- # Register the second user
- self.user2 = self.register_user("kermit2", "monkey")
- self.tok2 = self.login("kermit2", "monkey")
-
- # Join the second user
- self.helper.join(room=self.room_id, user=self.user2, tok=self.tok2)
-
- def test_private_read_receipts(self) -> None:
- # Send a message as the first user
- res = self.helper.send(self.room_id, body="hello", tok=self.tok)
-
- # Send a private read receipt to tell the server the first user's message was read
- channel = self.make_request(
- "POST",
- f"/rooms/{self.room_id}/receipt/{ReceiptTypes.READ_PRIVATE}/{res['event_id']}",
- {},
- access_token=self.tok2,
- )
- self.assertEqual(channel.code, 200)
-
- # Test that the first user can't see the other user's private read receipt
- self.assertIsNone(self._get_read_receipt())
-
- def test_public_receipt_can_override_private(self) -> None:
- """
- Sending a public read receipt to the same event which has a private read
- receipt should cause that receipt to become public.
- """
- # Send a message as the first user
- res = self.helper.send(self.room_id, body="hello", tok=self.tok)
-
- # Send a private read receipt
- channel = self.make_request(
- "POST",
- f"/rooms/{self.room_id}/receipt/{ReceiptTypes.READ_PRIVATE}/{res['event_id']}",
- {},
- access_token=self.tok2,
- )
- self.assertEqual(channel.code, 200)
- self.assertIsNone(self._get_read_receipt())
-
- # Send a public read receipt
- channel = self.make_request(
- "POST",
- f"/rooms/{self.room_id}/receipt/{ReceiptTypes.READ}/{res['event_id']}",
- {},
- access_token=self.tok2,
- )
- self.assertEqual(channel.code, 200)
-
- # Test that we did override the private read receipt
- self.assertNotEqual(self._get_read_receipt(), None)
-
- def test_private_receipt_cannot_override_public(self) -> None:
- """
- Sending a private read receipt to the same event which has a public read
- receipt should cause no change.
- """
- # Send a message as the first user
- res = self.helper.send(self.room_id, body="hello", tok=self.tok)
-
- # Send a public read receipt
- channel = self.make_request(
- "POST",
- f"/rooms/{self.room_id}/receipt/{ReceiptTypes.READ}/{res['event_id']}",
- {},
- access_token=self.tok2,
- )
- self.assertEqual(channel.code, 200)
- self.assertNotEqual(self._get_read_receipt(), None)
-
- # Send a private read receipt
- channel = self.make_request(
- "POST",
- f"/rooms/{self.room_id}/receipt/{ReceiptTypes.READ_PRIVATE}/{res['event_id']}",
- {},
- access_token=self.tok2,
- )
- self.assertEqual(channel.code, 200)
-
- # Test that we didn't override the public read receipt
- self.assertIsNone(self._get_read_receipt())
-
- def test_read_receipt_with_empty_body_is_rejected(self) -> None:
- # Send a message as the first user
- res = self.helper.send(self.room_id, body="hello", tok=self.tok)
-
- # Send a read receipt for this message with an empty body
- channel = self.make_request(
- "POST",
- f"/rooms/{self.room_id}/receipt/m.read/{res['event_id']}",
- access_token=self.tok2,
- )
- self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST)
- self.assertEqual(channel.json_body["errcode"], "M_NOT_JSON", channel.json_body)
-
- def _get_read_receipt(self) -> Optional[JsonDict]:
- """Syncs and returns the read receipt."""
-
- # Checks if event is a read receipt
- def is_read_receipt(event: JsonDict) -> bool:
- return event["type"] == EduTypes.RECEIPT
-
- # Sync
- channel = self.make_request(
- "GET",
- self.url % self.next_batch,
- access_token=self.tok,
- )
- self.assertEqual(channel.code, 200)
-
- # Store the next batch for the next request.
- self.next_batch = channel.json_body["next_batch"]
-
- if channel.json_body.get("rooms", None) is None:
- return None
-
- # Return the read receipt
- ephemeral_events = channel.json_body["rooms"]["join"][self.room_id][
- "ephemeral"
- ]["events"]
- receipt_event = filter(is_read_receipt, ephemeral_events)
- return next(receipt_event, None)
-
-
class UnreadMessagesTestCase(unittest.HomeserverTestCase):
servlets = [
synapse.rest.admin.register_servlets,
diff --git a/tests/storage/databases/main/test_lock.py b/tests/storage/databases/main/test_lock.py
index 650b4941..35f77052 100644
--- a/tests/storage/databases/main/test_lock.py
+++ b/tests/storage/databases/main/test_lock.py
@@ -382,7 +382,7 @@ class ReadWriteLockTestCase(unittest.HomeserverTestCase):
self.get_success(lock.__aenter__())
# Wait for ages with the lock, we should not be able to get the lock.
- for _ in range(0, 10):
+ for _ in range(10):
self.reactor.advance((_RENEWAL_INTERVAL_MS / 1000))
lock2 = self.get_success(
diff --git a/tests/storage/test_cleanup_extrems.py b/tests/storage/test_cleanup_extrems.py
index 7de10996..ceb9597d 100644
--- a/tests/storage/test_cleanup_extrems.py
+++ b/tests/storage/test_cleanup_extrems.py
@@ -120,7 +120,7 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase):
self.store.get_latest_event_ids_in_room(self.room_id)
)
- self.assertEqual(latest_event_ids, [event_id_4])
+ self.assertEqual(latest_event_ids, {event_id_4})
def test_basic_cleanup(self) -> None:
"""Test that extremities are correctly calculated in the presence of
@@ -147,7 +147,7 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase):
latest_event_ids = self.get_success(
self.store.get_latest_event_ids_in_room(self.room_id)
)
- self.assertEqual(set(latest_event_ids), {event_id_a, event_id_b})
+ self.assertEqual(latest_event_ids, {event_id_a, event_id_b})
# Run the background update and check it did the right thing
self.run_background_update()
@@ -155,7 +155,7 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase):
latest_event_ids = self.get_success(
self.store.get_latest_event_ids_in_room(self.room_id)
)
- self.assertEqual(latest_event_ids, [event_id_b])
+ self.assertEqual(latest_event_ids, {event_id_b})
def test_chain_of_fail_cleanup(self) -> None:
"""Test that extremities are correctly calculated in the presence of
@@ -185,7 +185,7 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase):
latest_event_ids = self.get_success(
self.store.get_latest_event_ids_in_room(self.room_id)
)
- self.assertEqual(set(latest_event_ids), {event_id_a, event_id_b})
+ self.assertEqual(latest_event_ids, {event_id_a, event_id_b})
# Run the background update and check it did the right thing
self.run_background_update()
@@ -193,7 +193,7 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase):
latest_event_ids = self.get_success(
self.store.get_latest_event_ids_in_room(self.room_id)
)
- self.assertEqual(latest_event_ids, [event_id_b])
+ self.assertEqual(latest_event_ids, {event_id_b})
def test_forked_graph_cleanup(self) -> None:
r"""Test that extremities are correctly calculated in the presence of
@@ -240,7 +240,7 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase):
latest_event_ids = self.get_success(
self.store.get_latest_event_ids_in_room(self.room_id)
)
- self.assertEqual(set(latest_event_ids), {event_id_a, event_id_b, event_id_c})
+ self.assertEqual(latest_event_ids, {event_id_a, event_id_b, event_id_c})
# Run the background update and check it did the right thing
self.run_background_update()
@@ -248,7 +248,7 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase):
latest_event_ids = self.get_success(
self.store.get_latest_event_ids_in_room(self.room_id)
)
- self.assertEqual(set(latest_event_ids), {event_id_b, event_id_c})
+ self.assertEqual(latest_event_ids, {event_id_b, event_id_c})
class CleanupExtremDummyEventsTestCase(HomeserverTestCase):
diff --git a/tests/storage/test_event_chain.py b/tests/storage/test_event_chain.py
index 48ebfada..b55dd07f 100644
--- a/tests/storage/test_event_chain.py
+++ b/tests/storage/test_event_chain.py
@@ -664,7 +664,7 @@ class EventChainBackgroundUpdateTestCase(HomeserverTestCase):
# Add a bunch of state so that it takes multiple iterations of the
# background update to process the room.
- for i in range(0, 150):
+ for i in range(150):
self.helper.send_state(
room_id, event_type="m.test", body={"index": i}, tok=self.token
)
@@ -718,12 +718,12 @@ class EventChainBackgroundUpdateTestCase(HomeserverTestCase):
# Add a bunch of state so that it takes multiple iterations of the
# background update to process the room.
- for i in range(0, 150):
+ for i in range(150):
self.helper.send_state(
room_id1, event_type="m.test", body={"index": i}, tok=self.token
)
- for i in range(0, 150):
+ for i in range(150):
self.helper.send_state(
room_id2, event_type="m.test", body={"index": i}, tok=self.token
)
diff --git a/tests/storage/test_event_federation.py b/tests/storage/test_event_federation.py
index 7a4ecab2..d3e20f44 100644
--- a/tests/storage/test_event_federation.py
+++ b/tests/storage/test_event_federation.py
@@ -227,7 +227,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
(room_id, event_id),
)
- for i in range(0, 20):
+ for i in range(20):
self.get_success(
self.store.db_pool.runInteraction("insert", insert_event, i)
)
@@ -235,7 +235,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
# this should get the last ten
r = self.get_success(self.store.get_prev_events_for_room(room_id))
self.assertEqual(10, len(r))
- for i in range(0, 10):
+ for i in range(10):
self.assertEqual("$event_%i:local" % (19 - i), r[i])
def test_get_rooms_with_many_extremities(self) -> None:
@@ -277,7 +277,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
(room_id, event_id),
)
- for i in range(0, 20):
+ for i in range(20):
self.get_success(
self.store.db_pool.runInteraction("insert", insert_event, i, room1)
)
diff --git a/tests/storage/test_keys.py b/tests/storage/test_keys.py
deleted file mode 100644
index 5d7c13e6..00000000
--- a/tests/storage/test_keys.py
+++ /dev/null
@@ -1,137 +0,0 @@
-# Copyright 2017 Vector Creations Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import signedjson.key
-import signedjson.types
-import unpaddedbase64
-
-from synapse.storage.keys import FetchKeyResult
-
-import tests.unittest
-
-
-def decode_verify_key_base64(
- key_id: str, key_base64: str
-) -> signedjson.types.VerifyKey:
- key_bytes = unpaddedbase64.decode_base64(key_base64)
- return signedjson.key.decode_verify_key_bytes(key_id, key_bytes)
-
-
-KEY_1 = decode_verify_key_base64(
- "ed25519:key1", "fP5l4JzpZPq/zdbBg5xx6lQGAAOM9/3w94cqiJ5jPrw"
-)
-KEY_2 = decode_verify_key_base64(
- "ed25519:key2", "Noi6WqcDj0QmPxCNQqgezwTlBKrfqehY1u2FyWP9uYw"
-)
-
-
-class KeyStoreTestCase(tests.unittest.HomeserverTestCase):
- def test_get_server_signature_keys(self) -> None:
- store = self.hs.get_datastores().main
-
- key_id_1 = "ed25519:key1"
- key_id_2 = "ed25519:KEY_ID_2"
- self.get_success(
- store.store_server_signature_keys(
- "from_server",
- 10,
- {
- ("server1", key_id_1): FetchKeyResult(KEY_1, 100),
- ("server1", key_id_2): FetchKeyResult(KEY_2, 200),
- },
- )
- )
-
- res = self.get_success(
- store.get_server_signature_keys(
- [
- ("server1", key_id_1),
- ("server1", key_id_2),
- ("server1", "ed25519:key3"),
- ]
- )
- )
-
- self.assertEqual(len(res.keys()), 3)
- res1 = res[("server1", key_id_1)]
- self.assertEqual(res1.verify_key, KEY_1)
- self.assertEqual(res1.verify_key.version, "key1")
- self.assertEqual(res1.valid_until_ts, 100)
-
- res2 = res[("server1", key_id_2)]
- self.assertEqual(res2.verify_key, KEY_2)
- # version comes from the ID it was stored with
- self.assertEqual(res2.verify_key.version, "KEY_ID_2")
- self.assertEqual(res2.valid_until_ts, 200)
-
- # non-existent result gives None
- self.assertIsNone(res[("server1", "ed25519:key3")])
-
- def test_cache(self) -> None:
- """Check that updates correctly invalidate the cache."""
-
- store = self.hs.get_datastores().main
-
- key_id_1 = "ed25519:key1"
- key_id_2 = "ed25519:key2"
-
- self.get_success(
- store.store_server_signature_keys(
- "from_server",
- 0,
- {
- ("srv1", key_id_1): FetchKeyResult(KEY_1, 100),
- ("srv1", key_id_2): FetchKeyResult(KEY_2, 200),
- },
- )
- )
-
- res = self.get_success(
- store.get_server_signature_keys([("srv1", key_id_1), ("srv1", key_id_2)])
- )
- self.assertEqual(len(res.keys()), 2)
-
- res1 = res[("srv1", key_id_1)]
- self.assertEqual(res1.verify_key, KEY_1)
- self.assertEqual(res1.valid_until_ts, 100)
-
- res2 = res[("srv1", key_id_2)]
- self.assertEqual(res2.verify_key, KEY_2)
- self.assertEqual(res2.valid_until_ts, 200)
-
- # we should be able to look up the same thing again without a db hit
- res = self.get_success(store.get_server_signature_keys([("srv1", key_id_1)]))
- self.assertEqual(len(res.keys()), 1)
- self.assertEqual(res[("srv1", key_id_1)].verify_key, KEY_1)
-
- new_key_2 = signedjson.key.get_verify_key(
- signedjson.key.generate_signing_key("key2")
- )
- d = store.store_server_signature_keys(
- "from_server", 10, {("srv1", key_id_2): FetchKeyResult(new_key_2, 300)}
- )
- self.get_success(d)
-
- res = self.get_success(
- store.get_server_signature_keys([("srv1", key_id_1), ("srv1", key_id_2)])
- )
- self.assertEqual(len(res.keys()), 2)
-
- res1 = res[("srv1", key_id_1)]
- self.assertEqual(res1.verify_key, KEY_1)
- self.assertEqual(res1.valid_until_ts, 100)
-
- res2 = res[("srv1", key_id_2)]
- self.assertEqual(res2.verify_key, new_key_2)
- self.assertEqual(res2.valid_until_ts, 300)
diff --git a/tests/storage/test_profile.py b/tests/storage/test_profile.py
index fe5bb779..95f99f41 100644
--- a/tests/storage/test_profile.py
+++ b/tests/storage/test_profile.py
@@ -82,7 +82,7 @@ class ProfileStoreTestCase(unittest.HomeserverTestCase):
self.get_success(self.store.db_pool.runInteraction("", f))
- for i in range(0, 70):
+ for i in range(70):
self.get_success(
self.store.db_pool.simple_insert(
"profiles",
@@ -115,7 +115,7 @@ class ProfileStoreTestCase(unittest.HomeserverTestCase):
)
expected_values = []
- for i in range(0, 70):
+ for i in range(70):
expected_values.append((f"@hello{i:02}:{self.hs.hostname}",))
res = self.get_success(
diff --git a/tests/storage/test_registration.py b/tests/storage/test_registration.py
index 95c9792d..0cca34d3 100644
--- a/tests/storage/test_registration.py
+++ b/tests/storage/test_registration.py
@@ -16,7 +16,7 @@ from twisted.test.proto_helpers import MemoryReactor
from synapse.api.constants import UserTypes
from synapse.api.errors import ThreepidValidationError
from synapse.server import HomeServer
-from synapse.types import JsonDict, UserID
+from synapse.types import JsonDict, UserID, UserInfo
from synapse.util import Clock
from tests.unittest import HomeserverTestCase, override_config
@@ -35,24 +35,22 @@ class RegistrationStoreTestCase(HomeserverTestCase):
self.get_success(self.store.register_user(self.user_id, self.pwhash))
self.assertEqual(
- {
+ UserInfo(
# TODO(paul): Surely this field should be 'user_id', not 'name'
- "name": self.user_id,
- "password_hash": self.pwhash,
- "admin": 0,
- "is_guest": 0,
- "consent_version": None,
- "consent_ts": None,
- "consent_server_notice_sent": None,
- "appservice_id": None,
- "creation_ts": 0,
- "user_type": None,
- "deactivated": 0,
- "locked": 0,
- "shadow_banned": 0,
- "approved": 1,
- "last_seen_ts": None,
- },
+ user_id=UserID.from_string(self.user_id),
+ is_admin=False,
+ is_guest=False,
+ consent_server_notice_sent=None,
+ consent_ts=None,
+ consent_version=None,
+ appservice_id=None,
+ creation_ts=0,
+ user_type=None,
+ is_deactivated=False,
+ locked=False,
+ is_shadow_banned=False,
+ approved=True,
+ ),
(self.get_success(self.store.get_user_by_id(self.user_id))),
)
@@ -65,9 +63,11 @@ class RegistrationStoreTestCase(HomeserverTestCase):
user = self.get_success(self.store.get_user_by_id(self.user_id))
assert user
- self.assertEqual(user["consent_version"], "1")
- self.assertGreater(user["consent_ts"], before_consent)
- self.assertLess(user["consent_ts"], self.clock.time_msec())
+ self.assertEqual(user.consent_version, "1")
+ self.assertIsNotNone(user.consent_ts)
+ assert user.consent_ts is not None
+ self.assertGreater(user.consent_ts, before_consent)
+ self.assertLess(user.consent_ts, self.clock.time_msec())
def test_add_tokens(self) -> None:
self.get_success(self.store.register_user(self.user_id, self.pwhash))
@@ -215,7 +215,7 @@ class ApprovalRequiredRegistrationTestCase(HomeserverTestCase):
user = self.get_success(self.store.get_user_by_id(self.user_id))
assert user is not None
- self.assertTrue(user["approved"])
+ self.assertTrue(user.approved)
approved = self.get_success(self.store.is_user_approved(self.user_id))
self.assertTrue(approved)
@@ -228,7 +228,7 @@ class ApprovalRequiredRegistrationTestCase(HomeserverTestCase):
user = self.get_success(self.store.get_user_by_id(self.user_id))
assert user is not None
- self.assertFalse(user["approved"])
+ self.assertFalse(user.approved)
approved = self.get_success(self.store.is_user_approved(self.user_id))
self.assertFalse(approved)
@@ -248,7 +248,7 @@ class ApprovalRequiredRegistrationTestCase(HomeserverTestCase):
user = self.get_success(self.store.get_user_by_id(self.user_id))
self.assertIsNotNone(user)
assert user is not None
- self.assertEqual(user["approved"], 1)
+ self.assertEqual(user.approved, 1)
approved = self.get_success(self.store.is_user_approved(self.user_id))
self.assertTrue(approved)
diff --git a/tests/storage/test_txn_limit.py b/tests/storage/test_txn_limit.py
index 15ea4770..22f07498 100644
--- a/tests/storage/test_txn_limit.py
+++ b/tests/storage/test_txn_limit.py
@@ -38,5 +38,5 @@ class SQLTransactionLimitTestCase(unittest.HomeserverTestCase):
db_pool = self.hs.get_datastores().databases[0]
# force txn limit to roll over at least once
- for _ in range(0, 1001):
+ for _ in range(1001):
self.get_success_or_raise(db_pool.runInteraction("test_select", do_select))
diff --git a/tests/storage/test_user_filters.py b/tests/storage/test_user_filters.py
index bab802f5..d4637d9d 100644
--- a/tests/storage/test_user_filters.py
+++ b/tests/storage/test_user_filters.py
@@ -45,7 +45,7 @@ class UserFiltersStoreTestCase(unittest.HomeserverTestCase):
self.get_success(self.store.db_pool.runInteraction("", f))
- for i in range(0, 70):
+ for i in range(70):
self.get_success(
self.store.db_pool.simple_insert(
"user_filters",
@@ -82,7 +82,7 @@ class UserFiltersStoreTestCase(unittest.HomeserverTestCase):
)
expected_values = []
- for i in range(0, 70):
+ for i in range(70):
expected_values.append((f"@hello{i:02}:{self.hs.hostname}",))
res = self.get_success(
diff --git a/tests/test_federation.py b/tests/test_federation.py
index f8ade6da..1b050470 100644
--- a/tests/test_federation.py
+++ b/tests/test_federation.py
@@ -51,9 +51,15 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
self.store = self.hs.get_datastores().main
# Figure out what the most recent event is
- most_recent = self.get_success(
- self.hs.get_datastores().main.get_latest_event_ids_in_room(self.room_id)
- )[0]
+ most_recent = next(
+ iter(
+ self.get_success(
+ self.hs.get_datastores().main.get_latest_event_ids_in_room(
+ self.room_id
+ )
+ )
+ )
+ )
join_event = make_event_from_dict(
{
@@ -100,8 +106,8 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
# Make sure we actually joined the room
self.assertEqual(
- self.get_success(self.store.get_latest_event_ids_in_room(self.room_id))[0],
- "$join:test.serv",
+ self.get_success(self.store.get_latest_event_ids_in_room(self.room_id)),
+ {"$join:test.serv"},
)
def test_cant_hide_direct_ancestors(self) -> None:
@@ -127,9 +133,11 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
self.http_client.post_json = post_json
# Figure out what the most recent event is
- most_recent = self.get_success(
- self.store.get_latest_event_ids_in_room(self.room_id)
- )[0]
+ most_recent = next(
+ iter(
+ self.get_success(self.store.get_latest_event_ids_in_room(self.room_id))
+ )
+ )
# Now lie about an event
lying_event = make_event_from_dict(
@@ -165,7 +173,7 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
# Make sure the invalid event isn't there
extrem = self.get_success(self.store.get_latest_event_ids_in_room(self.room_id))
- self.assertEqual(extrem[0], "$join:test.serv")
+ self.assertEqual(extrem, {"$join:test.serv"})
def test_retry_device_list_resync(self) -> None:
"""Tests that device lists are marked as stale if they couldn't be synced, and
diff --git a/tests/test_visibility.py b/tests/test_visibility.py
index a46c29dd..434902c3 100644
--- a/tests/test_visibility.py
+++ b/tests/test_visibility.py
@@ -51,12 +51,12 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
# before we do that, we persist some other events to act as state.
self._inject_visibility("@admin:hs", "joined")
- for i in range(0, 10):
+ for i in range(10):
self._inject_room_member("@resident%i:hs" % i)
events_to_filter = []
- for i in range(0, 10):
+ for i in range(10):
user = "@user%i:%s" % (i, "test_server" if i == 5 else "other_server")
evt = self._inject_room_member(user, extra_content={"a": "b"})
events_to_filter.append(evt)
@@ -74,7 +74,7 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
)
# the result should be 5 redacted events, and 5 unredacted events.
- for i in range(0, 5):
+ for i in range(5):
self.assertEqual(events_to_filter[i].event_id, filtered[i].event_id)
self.assertNotIn("a", filtered[i].content)
@@ -177,7 +177,7 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
)
)
- for i in range(0, len(events_to_filter)):
+ for i in range(len(events_to_filter)):
self.assertEqual(
events_to_filter[i].event_id,
filtered[i].event_id,
diff --git a/tests/unittest.py b/tests/unittest.py
index 5d3640d8..dbaff361 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -70,6 +70,7 @@ from synapse.logging.context import (
)
from synapse.rest import RegisterServletsFunc
from synapse.server import HomeServer
+from synapse.storage.keys import FetchKeyResult
from synapse.types import JsonDict, Requester, UserID, create_requester
from synapse.util import Clock
from synapse.util.httpresourcetree import create_resource_tree
@@ -858,23 +859,22 @@ class FederatingHomeserverTestCase(HomeserverTestCase):
verify_key_id = "%s:%s" % (verify_key.alg, verify_key.version)
self.get_success(
- hs.get_datastores().main.store_server_keys_json(
+ hs.get_datastores().main.store_server_keys_response(
self.OTHER_SERVER_NAME,
- verify_key_id,
from_server=self.OTHER_SERVER_NAME,
- ts_now_ms=clock.time_msec(),
- ts_expires_ms=clock.time_msec() + 10000,
- key_json_bytes=canonicaljson.encode_canonical_json(
- {
- "verify_keys": {
- verify_key_id: {
- "key": signedjson.key.encode_verify_key_base64(
- verify_key
- )
- }
+ ts_added_ms=clock.time_msec(),
+ verify_keys={
+ verify_key_id: FetchKeyResult(
+ verify_key=verify_key, valid_until_ts=clock.time_msec() + 10000
+ ),
+ },
+ response_json={
+ "verify_keys": {
+ verify_key_id: {
+ "key": signedjson.key.encode_verify_key_base64(verify_key)
}
}
- ),
+ },
)
)
diff --git a/tests/util/caches/test_descriptors.py b/tests/util/caches/test_descriptors.py
index 064f4987..168419f4 100644
--- a/tests/util/caches/test_descriptors.py
+++ b/tests/util/caches/test_descriptors.py
@@ -623,14 +623,14 @@ class CacheDecoratorTestCase(unittest.HomeserverTestCase):
a = A()
- for k in range(0, 12):
+ for k in range(12):
yield a.func(k)
self.assertEqual(callcount[0], 12)
# There must have been at least 2 evictions, meaning if we calculate
# all 12 values again, we must get called at least 2 more times
- for k in range(0, 12):
+ for k in range(12):
yield a.func(k)
self.assertTrue(