summaryrefslogtreecommitdiff
path: root/tests/storage
diff options
context:
space:
mode:
Diffstat (limited to 'tests/storage')
-rw-r--r--tests/storage/databases/main/test_events_worker.py25
-rw-r--r--tests/storage/databases/main/test_lock.py54
-rw-r--r--tests/storage/test_appservice.py27
-rw-r--r--tests/storage/test_base.py2
-rw-r--r--tests/storage/test_devices.py7
-rw-r--r--tests/storage/test_event_chain.py3
-rw-r--r--tests/storage/test_event_federation.py9
-rw-r--r--tests/storage/test_events.py58
-rw-r--r--tests/storage/test_monthly_active_users.py83
-rw-r--r--tests/storage/test_purge.py19
-rw-r--r--tests/storage/test_redaction.py14
-rw-r--r--tests/storage/test_room.py12
-rw-r--r--tests/storage/test_room_search.py4
-rw-r--r--tests/storage/test_roommember.py2
-rw-r--r--tests/storage/test_state.py2
-rw-r--r--tests/storage/test_user_directory.py1
-rw-r--r--tests/storage/util/test_partial_state_events_tracker.py59
17 files changed, 304 insertions, 77 deletions
diff --git a/tests/storage/databases/main/test_events_worker.py b/tests/storage/databases/main/test_events_worker.py
index c237a8c7..38963ce4 100644
--- a/tests/storage/databases/main/test_events_worker.py
+++ b/tests/storage/databases/main/test_events_worker.py
@@ -154,6 +154,31 @@ class EventCacheTestCase(unittest.HomeserverTestCase):
# We should have fetched the event from the DB
self.assertEqual(ctx.get_resource_usage().evt_db_fetch_count, 1)
+ def test_event_ref(self):
+ """Test that we reuse events that are still in memory but have fallen
+ out of the cache, rather than requesting them from the DB.
+ """
+
+ # Reset the event cache
+ self.store._get_event_cache.clear()
+
+ with LoggingContext("test") as ctx:
+ # We keep hold of the event event though we never use it.
+ event = self.get_success(self.store.get_event(self.event_id)) # noqa: F841
+
+ # We should have fetched the event from the DB
+ self.assertEqual(ctx.get_resource_usage().evt_db_fetch_count, 1)
+
+ # Reset the event cache
+ self.store._get_event_cache.clear()
+
+ with LoggingContext("test") as ctx:
+ self.get_success(self.store.get_event(self.event_id))
+
+ # Since the event is still in memory we shouldn't have fetched it
+ # from the DB
+ self.assertEqual(ctx.get_resource_usage().evt_db_fetch_count, 0)
+
def test_dedupe(self):
"""Test that if we request the same event multiple times we only pull it
out once.
diff --git a/tests/storage/databases/main/test_lock.py b/tests/storage/databases/main/test_lock.py
index 74c6224e..3cc2a58d 100644
--- a/tests/storage/databases/main/test_lock.py
+++ b/tests/storage/databases/main/test_lock.py
@@ -12,6 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from twisted.internet import defer, reactor
+from twisted.internet.base import ReactorBase
+from twisted.internet.defer import Deferred
+
from synapse.server import HomeServer
from synapse.storage.databases.main.lock import _LOCK_TIMEOUT_MS
@@ -22,6 +26,56 @@ class LockTestCase(unittest.HomeserverTestCase):
def prepare(self, reactor, clock, hs: HomeServer):
self.store = hs.get_datastores().main
+ def test_acquire_contention(self):
+ # Track the number of tasks holding the lock.
+ # Should be at most 1.
+ in_lock = 0
+ max_in_lock = 0
+
+ release_lock: "Deferred[None]" = Deferred()
+
+ async def task():
+ nonlocal in_lock
+ nonlocal max_in_lock
+
+ lock = await self.store.try_acquire_lock("name", "key")
+ if not lock:
+ return
+
+ async with lock:
+ in_lock += 1
+ max_in_lock = max(max_in_lock, in_lock)
+
+ # Block to allow other tasks to attempt to take the lock.
+ await release_lock
+
+ in_lock -= 1
+
+ # Start 3 tasks.
+ task1 = defer.ensureDeferred(task())
+ task2 = defer.ensureDeferred(task())
+ task3 = defer.ensureDeferred(task())
+
+ # Give the reactor a kick so that the database transaction returns.
+ self.pump()
+
+ release_lock.callback(None)
+
+ # Run the tasks to completion.
+ # To work around `Linearizer`s using a different reactor to sleep when
+ # contended (#12841), we call `runUntilCurrent` on
+ # `twisted.internet.reactor`, which is a different reactor to that used
+ # by the homeserver.
+ assert isinstance(reactor, ReactorBase)
+ self.get_success(task1)
+ reactor.runUntilCurrent()
+ self.get_success(task2)
+ reactor.runUntilCurrent()
+ self.get_success(task3)
+
+ # At most one task should have held the lock at a time.
+ self.assertEqual(max_in_lock, 1)
+
def test_simple_lock(self):
"""Test that we can take out a lock and that while we hold it nobody
else can take it out.
diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py
index 1bf93e79..1047ed09 100644
--- a/tests/storage/test_appservice.py
+++ b/tests/storage/test_appservice.py
@@ -14,7 +14,7 @@
import json
import os
import tempfile
-from typing import List, Optional, cast
+from typing import List, cast
from unittest.mock import Mock
import yaml
@@ -149,15 +149,12 @@ class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase):
outfile.write(yaml.dump(as_yaml))
self.as_yaml_files.append(as_token)
- def _set_state(
- self, id: str, state: ApplicationServiceState, txn: Optional[int] = None
- ):
+ def _set_state(self, id: str, state: ApplicationServiceState):
return self.db_pool.runOperation(
self.engine.convert_param_style(
- "INSERT INTO application_services_state(as_id, state, last_txn) "
- "VALUES(?,?,?)"
+ "INSERT INTO application_services_state(as_id, state) VALUES(?,?)"
),
- (id, state.value, txn),
+ (id, state.value),
)
def _insert_txn(self, as_id, txn_id, events):
@@ -283,17 +280,6 @@ class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase):
res = self.get_success(
self.db_pool.runQuery(
self.engine.convert_param_style(
- "SELECT last_txn FROM application_services_state WHERE as_id=?"
- ),
- (service.id,),
- )
- )
- self.assertEqual(1, len(res))
- self.assertEqual(txn_id, res[0][0])
-
- res = self.get_success(
- self.db_pool.runQuery(
- self.engine.convert_param_style(
"SELECT * FROM application_services_txns WHERE txn_id=?"
),
(txn_id,),
@@ -316,14 +302,13 @@ class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase):
res = self.get_success(
self.db_pool.runQuery(
self.engine.convert_param_style(
- "SELECT last_txn, state FROM application_services_state WHERE as_id=?"
+ "SELECT state FROM application_services_state WHERE as_id=?"
),
(service.id,),
)
)
self.assertEqual(1, len(res))
- self.assertEqual(txn_id, res[0][0])
- self.assertEqual(ApplicationServiceState.UP.value, res[0][1])
+ self.assertEqual(ApplicationServiceState.UP.value, res[0][0])
res = self.get_success(
self.db_pool.runQuery(
diff --git a/tests/storage/test_base.py b/tests/storage/test_base.py
index a8ffb52c..cce8e75c 100644
--- a/tests/storage/test_base.py
+++ b/tests/storage/test_base.py
@@ -60,7 +60,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
db = DatabasePool(Mock(), Mock(config=sqlite_config), fake_engine)
db._db_pool = self.db_pool
- self.datastore = SQLBaseStore(db, None, hs)
+ self.datastore = SQLBaseStore(db, None, hs) # type: ignore[arg-type]
@defer.inlineCallbacks
def test_insert_1col(self):
diff --git a/tests/storage/test_devices.py b/tests/storage/test_devices.py
index bbf079b2..f37505b6 100644
--- a/tests/storage/test_devices.py
+++ b/tests/storage/test_devices.py
@@ -13,6 +13,7 @@
# limitations under the License.
import synapse.api.errors
+from synapse.api.constants import EduTypes
from tests.unittest import HomeserverTestCase
@@ -266,10 +267,12 @@ class DeviceStoreTestCase(HomeserverTestCase):
# (This is a temporary arrangement for backwards compatibility!)
self.assertEqual(len(device_updates), 2, device_updates)
self.assertEqual(
- device_updates[0][0], "m.signing_key_update", device_updates[0]
+ device_updates[0][0], EduTypes.SIGNING_KEY_UPDATE, device_updates[0]
)
self.assertEqual(
- device_updates[1][0], "org.matrix.signing_key_update", device_updates[1]
+ device_updates[1][0],
+ EduTypes.UNSTABLE_SIGNING_KEY_UPDATE,
+ device_updates[1],
)
# Check there are no more device updates left.
diff --git a/tests/storage/test_event_chain.py b/tests/storage/test_event_chain.py
index 401020fd..a0ce077a 100644
--- a/tests/storage/test_event_chain.py
+++ b/tests/storage/test_event_chain.py
@@ -393,7 +393,8 @@ class EventChainStoreTestCase(HomeserverTestCase):
# We need to persist the events to the events and state_events
# tables.
persist_events_store._store_event_txn(
- txn, [(e, EventContext()) for e in events]
+ txn,
+ [(e, EventContext(self.hs.get_storage_controllers())) for e in events],
)
# Actually call the function that calculates the auth chain stuff.
diff --git a/tests/storage/test_event_federation.py b/tests/storage/test_event_federation.py
index 645d564d..d92a9ac5 100644
--- a/tests/storage/test_event_federation.py
+++ b/tests/storage/test_event_federation.py
@@ -58,15 +58,6 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
(room_id, event_id),
)
- txn.execute(
- (
- "INSERT INTO event_reference_hashes "
- "(event_id, algorithm, hash) "
- "VALUES (?, 'sha256', ?)"
- ),
- (event_id, bytearray(b"ffff")),
- )
-
for i in range(0, 20):
self.get_success(
self.store.db_pool.runInteraction("insert", insert_event, i)
diff --git a/tests/storage/test_events.py b/tests/storage/test_events.py
index ef5e2587..2ff88e64 100644
--- a/tests/storage/test_events.py
+++ b/tests/storage/test_events.py
@@ -31,7 +31,8 @@ class ExtremPruneTestCase(HomeserverTestCase):
def prepare(self, reactor, clock, homeserver):
self.state = self.hs.get_state_handler()
- self.persistence = self.hs.get_storage().persistence
+ self._persistence = self.hs.get_storage_controllers().persistence
+ self._state_storage_controller = self.hs.get_storage_controllers().state
self.store = self.hs.get_datastores().main
self.register_user("user", "pass")
@@ -69,9 +70,9 @@ class ExtremPruneTestCase(HomeserverTestCase):
def persist_event(self, event, state=None):
"""Persist the event, with optional state"""
context = self.get_success(
- self.state.compute_event_context(event, old_state=state)
+ self.state.compute_event_context(event, state_ids_before_event=state)
)
- self.get_success(self.persistence.persist_event(event, context))
+ self.get_success(self._persistence.persist_event(event, context))
def assert_extremities(self, expected_extremities):
"""Assert the current extremities for the room"""
@@ -103,9 +104,11 @@ class ExtremPruneTestCase(HomeserverTestCase):
RoomVersions.V6,
)
- state_before_gap = self.get_success(self.state.get_current_state(self.room_id))
+ state_before_gap = self.get_success(
+ self._state_storage_controller.get_current_state_ids(self.room_id)
+ )
- self.persist_event(remote_event_2, state=state_before_gap.values())
+ self.persist_event(remote_event_2, state=state_before_gap)
# Check the new extremity is just the new remote event.
self.assert_extremities([remote_event_2.event_id])
@@ -135,17 +138,20 @@ class ExtremPruneTestCase(HomeserverTestCase):
# setting. The state resolution across the old and new event will then
# include it, and so the resolved state won't match the new state.
state_before_gap = dict(
- self.get_success(self.state.get_current_state(self.room_id))
+ self.get_success(
+ self._state_storage_controller.get_current_state_ids(self.room_id)
+ )
)
state_before_gap.pop(("m.room.history_visibility", ""))
context = self.get_success(
self.state.compute_event_context(
- remote_event_2, old_state=state_before_gap.values()
+ remote_event_2,
+ state_ids_before_event=state_before_gap,
)
)
- self.get_success(self.persistence.persist_event(remote_event_2, context))
+ self.get_success(self._persistence.persist_event(remote_event_2, context))
# Check that we haven't dropped the old extremity.
self.assert_extremities([self.remote_event_1.event_id, remote_event_2.event_id])
@@ -177,9 +183,11 @@ class ExtremPruneTestCase(HomeserverTestCase):
RoomVersions.V6,
)
- state_before_gap = self.get_success(self.state.get_current_state(self.room_id))
+ state_before_gap = self.get_success(
+ self._state_storage_controller.get_current_state_ids(self.room_id)
+ )
- self.persist_event(remote_event_2, state=state_before_gap.values())
+ self.persist_event(remote_event_2, state=state_before_gap)
# Check the new extremity is just the new remote event.
self.assert_extremities([remote_event_2.event_id])
@@ -207,9 +215,11 @@ class ExtremPruneTestCase(HomeserverTestCase):
RoomVersions.V6,
)
- state_before_gap = self.get_success(self.state.get_current_state(self.room_id))
+ state_before_gap = self.get_success(
+ self._state_storage_controller.get_current_state_ids(self.room_id)
+ )
- self.persist_event(remote_event_2, state=state_before_gap.values())
+ self.persist_event(remote_event_2, state=state_before_gap)
# Check the new extremity is just the new remote event.
self.assert_extremities([self.remote_event_1.event_id, remote_event_2.event_id])
@@ -247,9 +257,11 @@ class ExtremPruneTestCase(HomeserverTestCase):
RoomVersions.V6,
)
- state_before_gap = self.get_success(self.state.get_current_state(self.room_id))
+ state_before_gap = self.get_success(
+ self._state_storage_controller.get_current_state_ids(self.room_id)
+ )
- self.persist_event(remote_event_2, state=state_before_gap.values())
+ self.persist_event(remote_event_2, state=state_before_gap)
# Check the new extremity is just the new remote event.
self.assert_extremities([remote_event_2.event_id])
@@ -289,9 +301,11 @@ class ExtremPruneTestCase(HomeserverTestCase):
RoomVersions.V6,
)
- state_before_gap = self.get_success(self.state.get_current_state(self.room_id))
+ state_before_gap = self.get_success(
+ self._state_storage_controller.get_current_state_ids(self.room_id)
+ )
- self.persist_event(remote_event_2, state=state_before_gap.values())
+ self.persist_event(remote_event_2, state=state_before_gap)
# Check the new extremity is just the new remote event.
self.assert_extremities([remote_event_2.event_id, local_message_event_id])
@@ -323,9 +337,11 @@ class ExtremPruneTestCase(HomeserverTestCase):
RoomVersions.V6,
)
- state_before_gap = self.get_success(self.state.get_current_state(self.room_id))
+ state_before_gap = self.get_success(
+ self._state_storage_controller.get_current_state_ids(self.room_id)
+ )
- self.persist_event(remote_event_2, state=state_before_gap.values())
+ self.persist_event(remote_event_2, state=state_before_gap)
# Check the new extremity is just the new remote event.
self.assert_extremities([local_message_event_id, remote_event_2.event_id])
@@ -340,7 +356,7 @@ class InvalideUsersInRoomCacheTestCase(HomeserverTestCase):
def prepare(self, reactor, clock, homeserver):
self.state = self.hs.get_state_handler()
- self.persistence = self.hs.get_storage().persistence
+ self._persistence = self.hs.get_storage_controllers().persistence
self.store = self.hs.get_datastores().main
def test_remote_user_rooms_cache_invalidated(self):
@@ -377,7 +393,7 @@ class InvalideUsersInRoomCacheTestCase(HomeserverTestCase):
)
context = self.get_success(self.state.compute_event_context(remote_event_1))
- self.get_success(self.persistence.persist_event(remote_event_1, context))
+ self.get_success(self._persistence.persist_event(remote_event_1, context))
# Call `get_rooms_for_user` to add the remote user to the cache
rooms = self.get_success(self.store.get_rooms_for_user(remote_user))
@@ -424,7 +440,7 @@ class InvalideUsersInRoomCacheTestCase(HomeserverTestCase):
)
context = self.get_success(self.state.compute_event_context(remote_event_1))
- self.get_success(self.persistence.persist_event(remote_event_1, context))
+ self.get_success(self._persistence.persist_event(remote_event_1, context))
# Call `get_users_in_room` to add the remote user to the cache
users = self.get_success(self.store.get_users_in_room(room_id))
diff --git a/tests/storage/test_monthly_active_users.py b/tests/storage/test_monthly_active_users.py
index 4c29ad79..e8b4a564 100644
--- a/tests/storage/test_monthly_active_users.py
+++ b/tests/storage/test_monthly_active_users.py
@@ -407,3 +407,86 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
self.assertEqual(result[service1], 2)
self.assertEqual(result[service2], 1)
self.assertEqual(result[native], 1)
+
+ def test_get_monthly_active_users_by_service(self):
+ # (No users, no filtering) -> empty result
+ result = self.get_success(self.store.get_monthly_active_users_by_service())
+
+ self.assertEqual(len(result), 0)
+
+ # (Some users, no filtering) -> non-empty result
+ appservice1_user1 = "@appservice1_user1:example.com"
+ appservice2_user1 = "@appservice2_user1:example.com"
+ service1 = "service1"
+ service2 = "service2"
+ self.get_success(
+ self.store.register_user(
+ user_id=appservice1_user1, password_hash=None, appservice_id=service1
+ )
+ )
+ self.get_success(self.store.upsert_monthly_active_user(appservice1_user1))
+ self.get_success(
+ self.store.register_user(
+ user_id=appservice2_user1, password_hash=None, appservice_id=service2
+ )
+ )
+ self.get_success(self.store.upsert_monthly_active_user(appservice2_user1))
+
+ result = self.get_success(self.store.get_monthly_active_users_by_service())
+
+ self.assertEqual(len(result), 2)
+ self.assertIn((service1, appservice1_user1), result)
+ self.assertIn((service2, appservice2_user1), result)
+
+ # (Some users, end-timestamp filtering) -> non-empty result
+ appservice1_user2 = "@appservice1_user2:example.com"
+ timestamp1 = self.reactor.seconds()
+ self.reactor.advance(5)
+ timestamp2 = self.reactor.seconds()
+ self.get_success(
+ self.store.register_user(
+ user_id=appservice1_user2, password_hash=None, appservice_id=service1
+ )
+ )
+ self.get_success(self.store.upsert_monthly_active_user(appservice1_user2))
+
+ result = self.get_success(
+ self.store.get_monthly_active_users_by_service(
+ end_timestamp=round(timestamp1 * 1000)
+ )
+ )
+
+ self.assertEqual(len(result), 2)
+ self.assertNotIn((service1, appservice1_user2), result)
+
+ # (Some users, start-timestamp filtering) -> non-empty result
+ result = self.get_success(
+ self.store.get_monthly_active_users_by_service(
+ start_timestamp=round(timestamp2 * 1000)
+ )
+ )
+
+ self.assertEqual(len(result), 1)
+ self.assertIn((service1, appservice1_user2), result)
+
+ # (Some users, full-timestamp filtering) -> non-empty result
+ native_user1 = "@native_user1:example.com"
+ native = "native"
+ timestamp3 = self.reactor.seconds()
+ self.reactor.advance(100)
+ self.get_success(
+ self.store.register_user(
+ user_id=native_user1, password_hash=None, appservice_id=native
+ )
+ )
+ self.get_success(self.store.upsert_monthly_active_user(native_user1))
+
+ result = self.get_success(
+ self.store.get_monthly_active_users_by_service(
+ start_timestamp=round(timestamp2 * 1000),
+ end_timestamp=round(timestamp3 * 1000),
+ )
+ )
+
+ self.assertEqual(len(result), 1)
+ self.assertIn((service1, appservice1_user2), result)
diff --git a/tests/storage/test_purge.py b/tests/storage/test_purge.py
index 08cc6023..8dfaa055 100644
--- a/tests/storage/test_purge.py
+++ b/tests/storage/test_purge.py
@@ -31,7 +31,7 @@ class PurgeTests(HomeserverTestCase):
self.room_id = self.helper.create_room_as(self.user_id)
self.store = hs.get_datastores().main
- self.storage = self.hs.get_storage()
+ self._storage_controllers = self.hs.get_storage_controllers()
def test_purge_history(self):
"""
@@ -51,7 +51,9 @@ class PurgeTests(HomeserverTestCase):
# Purge everything before this topological token
self.get_success(
- self.storage.purge_events.purge_history(self.room_id, token_str, True)
+ self._storage_controllers.purge_events.purge_history(
+ self.room_id, token_str, True
+ )
)
# 1-3 should fail and last will succeed, meaning that 1-3 are deleted
@@ -79,7 +81,9 @@ class PurgeTests(HomeserverTestCase):
# Purge everything before this topological token
f = self.get_failure(
- self.storage.purge_events.purge_history(self.room_id, event, True),
+ self._storage_controllers.purge_events.purge_history(
+ self.room_id, event, True
+ ),
SynapseError,
)
self.assertIn("greater than forward", f.value.args[0])
@@ -98,14 +102,17 @@ class PurgeTests(HomeserverTestCase):
first = self.helper.send(self.room_id, body="test1")
# Get the current room state.
- state_handler = self.hs.get_state_handler()
create_event = self.get_success(
- state_handler.get_current_state(self.room_id, "m.room.create", "")
+ self._storage_controllers.state.get_current_state_event(
+ self.room_id, "m.room.create", ""
+ )
)
self.assertIsNotNone(create_event)
# Purge everything before this topological token
- self.get_success(self.storage.purge_events.purge_room(self.room_id))
+ self.get_success(
+ self._storage_controllers.purge_events.purge_room(self.room_id)
+ )
# The events aren't found.
self.store._invalidate_get_event_cache(create_event.event_id)
diff --git a/tests/storage/test_redaction.py b/tests/storage/test_redaction.py
index d8d17ef3..6c4e63b7 100644
--- a/tests/storage/test_redaction.py
+++ b/tests/storage/test_redaction.py
@@ -31,7 +31,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
def prepare(self, reactor, clock, hs):
self.store = hs.get_datastores().main
- self.storage = hs.get_storage()
+ self._storage = hs.get_storage_controllers()
self.event_builder_factory = hs.get_event_builder_factory()
self.event_creation_handler = hs.get_event_creation_handler()
@@ -71,7 +71,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
self.event_creation_handler.create_new_client_event(builder)
)
- self.get_success(self.storage.persistence.persist_event(event, context))
+ self.get_success(self._storage.persistence.persist_event(event, context))
return event
@@ -93,7 +93,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
self.event_creation_handler.create_new_client_event(builder)
)
- self.get_success(self.storage.persistence.persist_event(event, context))
+ self.get_success(self._storage.persistence.persist_event(event, context))
return event
@@ -114,7 +114,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
self.event_creation_handler.create_new_client_event(builder)
)
- self.get_success(self.storage.persistence.persist_event(event, context))
+ self.get_success(self._storage.persistence.persist_event(event, context))
return event
@@ -268,7 +268,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
)
)
- self.get_success(self.storage.persistence.persist_event(event_1, context_1))
+ self.get_success(self._storage.persistence.persist_event(event_1, context_1))
event_2, context_2 = self.get_success(
self.event_creation_handler.create_new_client_event(
@@ -287,7 +287,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
)
)
)
- self.get_success(self.storage.persistence.persist_event(event_2, context_2))
+ self.get_success(self._storage.persistence.persist_event(event_2, context_2))
# fetch one of the redactions
fetched = self.get_success(self.store.get_event(redaction_event_id1))
@@ -411,7 +411,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
)
self.get_success(
- self.storage.persistence.persist_event(redaction_event, context)
+ self._storage.persistence.persist_event(redaction_event, context)
)
# Now lets jump to the future where we have censored the redaction event
diff --git a/tests/storage/test_room.py b/tests/storage/test_room.py
index 5b011e18..3c79dabc 100644
--- a/tests/storage/test_room.py
+++ b/tests/storage/test_room.py
@@ -72,7 +72,7 @@ class RoomEventsStoreTestCase(HomeserverTestCase):
# Room events need the full datastore, for persist_event() and
# get_room_state()
self.store = hs.get_datastores().main
- self.storage = hs.get_storage()
+ self._storage_controllers = hs.get_storage_controllers()
self.event_factory = hs.get_event_factory()
self.room = RoomID.from_string("!abcde:test")
@@ -88,7 +88,7 @@ class RoomEventsStoreTestCase(HomeserverTestCase):
def inject_room_event(self, **kwargs):
self.get_success(
- self.storage.persistence.persist_event(
+ self._storage_controllers.persistence.persist_event(
self.event_factory.create_event(room_id=self.room.to_string(), **kwargs)
)
)
@@ -101,7 +101,9 @@ class RoomEventsStoreTestCase(HomeserverTestCase):
)
state = self.get_success(
- self.store.get_current_state(room_id=self.room.to_string())
+ self._storage_controllers.state.get_current_state(
+ room_id=self.room.to_string()
+ )
)
self.assertEqual(1, len(state))
@@ -118,7 +120,9 @@ class RoomEventsStoreTestCase(HomeserverTestCase):
)
state = self.get_success(
- self.store.get_current_state(room_id=self.room.to_string())
+ self._storage_controllers.state.get_current_state(
+ room_id=self.room.to_string()
+ )
)
self.assertEqual(1, len(state))
diff --git a/tests/storage/test_room_search.py b/tests/storage/test_room_search.py
index 8dfc1e1d..e747c6b5 100644
--- a/tests/storage/test_room_search.py
+++ b/tests/storage/test_room_search.py
@@ -99,7 +99,9 @@ class EventSearchInsertionTest(HomeserverTestCase):
prev_event_ids = self.get_success(store.get_prev_events_for_room(room_id))
prev_event = self.get_success(store.get_event(prev_event_ids[0]))
prev_state_map = self.get_success(
- self.hs.get_storage().state.get_state_ids_for_event(prev_event_ids[0])
+ self.hs.get_storage_controllers().state.get_state_ids_for_event(
+ prev_event_ids[0]
+ )
)
event_dict = {
diff --git a/tests/storage/test_roommember.py b/tests/storage/test_roommember.py
index a2a9c05f..1218786d 100644
--- a/tests/storage/test_roommember.py
+++ b/tests/storage/test_roommember.py
@@ -34,7 +34,7 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase):
room.register_servlets,
]
- def prepare(self, reactor: MemoryReactor, clock: Clock, hs: TestHomeServer) -> None:
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: TestHomeServer) -> None: # type: ignore[override]
# We can't test the RoomMemberStore on its own without the other event
# storage logic
diff --git a/tests/storage/test_state.py b/tests/storage/test_state.py
index f88f1c55..8043bdbd 100644
--- a/tests/storage/test_state.py
+++ b/tests/storage/test_state.py
@@ -29,7 +29,7 @@ logger = logging.getLogger(__name__)
class StateStoreTestCase(HomeserverTestCase):
def prepare(self, reactor, clock, hs):
self.store = hs.get_datastores().main
- self.storage = hs.get_storage()
+ self.storage = hs.get_storage_controllers()
self.state_datastore = self.storage.state.stores.state
self.event_builder_factory = hs.get_event_builder_factory()
self.event_creation_handler = hs.get_event_creation_handler()
diff --git a/tests/storage/test_user_directory.py b/tests/storage/test_user_directory.py
index 7f1964eb..5b60cf52 100644
--- a/tests/storage/test_user_directory.py
+++ b/tests/storage/test_user_directory.py
@@ -134,7 +134,6 @@ class UserDirectoryInitialPopulationTestcase(HomeserverTestCase):
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
self.appservice = ApplicationService(
token="i_am_an_app_service",
- hostname="test",
id="1234",
namespaces={"users": [{"regex": r"@as_user.*", "exclusive": True}]},
sender="@as:test",
diff --git a/tests/storage/util/test_partial_state_events_tracker.py b/tests/storage/util/test_partial_state_events_tracker.py
index 303e190b..cae14151 100644
--- a/tests/storage/util/test_partial_state_events_tracker.py
+++ b/tests/storage/util/test_partial_state_events_tracker.py
@@ -17,8 +17,12 @@ from unittest import mock
from twisted.internet.defer import CancelledError, ensureDeferred
-from synapse.storage.util.partial_state_events_tracker import PartialStateEventsTracker
+from synapse.storage.util.partial_state_events_tracker import (
+ PartialCurrentStateTracker,
+ PartialStateEventsTracker,
+)
+from tests.test_utils import make_awaitable
from tests.unittest import TestCase
@@ -115,3 +119,56 @@ class PartialStateEventsTrackerTestCase(TestCase):
self.tracker.notify_un_partial_stated("event1")
self.successResultOf(d2)
+
+
+class PartialCurrentStateTrackerTestCase(TestCase):
+ def setUp(self) -> None:
+ self.mock_store = mock.Mock(spec_set=["is_partial_state_room"])
+
+ self.tracker = PartialCurrentStateTracker(self.mock_store)
+
+ def test_does_not_block_for_full_state_rooms(self):
+ self.mock_store.is_partial_state_room.return_value = make_awaitable(False)
+
+ self.successResultOf(ensureDeferred(self.tracker.await_full_state("room_id")))
+
+ def test_blocks_for_partial_room_state(self):
+ self.mock_store.is_partial_state_room.return_value = make_awaitable(True)
+
+ d = ensureDeferred(self.tracker.await_full_state("room_id"))
+
+ # there should be no result yet
+ self.assertNoResult(d)
+
+ # notifying that the room has been de-partial-stated should unblock
+ self.tracker.notify_un_partial_stated("room_id")
+ self.successResultOf(d)
+
+ def test_un_partial_state_race(self):
+ # We should correctly handle race between awaiting the state and us
+ # un-partialling the state
+ async def is_partial_state_room(events):
+ self.tracker.notify_un_partial_stated("room_id")
+ return True
+
+ self.mock_store.is_partial_state_room.side_effect = is_partial_state_room
+
+ self.successResultOf(ensureDeferred(self.tracker.await_full_state("room_id")))
+
+ def test_cancellation(self):
+ self.mock_store.is_partial_state_room.return_value = make_awaitable(True)
+
+ d1 = ensureDeferred(self.tracker.await_full_state("room_id"))
+ self.assertNoResult(d1)
+
+ d2 = ensureDeferred(self.tracker.await_full_state("room_id"))
+ self.assertNoResult(d2)
+
+ d1.cancel()
+ self.assertFailure(d1, CancelledError)
+
+ # d2 should still be waiting!
+ self.assertNoResult(d2)
+
+ self.tracker.notify_un_partial_stated("room_id")
+ self.successResultOf(d2)