path: root/synapse/storage
author    Andrej Shadura <andrewsh@debian.org>  2021-03-01 15:20:32 +0100
committer Andrej Shadura <andrewsh@debian.org>  2021-03-01 15:20:32 +0100
commit    7b07dc8dd1aa7eb4c55edb19822a30cfdc4adc0b (patch)
tree      093d5ba40a632cf760b57d2fe0b60c41b73bb8ef /synapse/storage
parent    c5414640d8cd028fad2320bad78ffc179251e4ef (diff)
New upstream version 1.28.0
Diffstat (limited to 'synapse/storage')
-rw-r--r--  synapse/storage/__init__.py | 3
-rw-r--r--  synapse/storage/background_updates.py | 8
-rw-r--r--  synapse/storage/database.py | 41
-rw-r--r--  synapse/storage/databases/__init__.py | 5
-rw-r--r--  synapse/storage/databases/main/__init__.py | 2
-rw-r--r--  synapse/storage/databases/main/appservice.py | 3
-rw-r--r--  synapse/storage/databases/main/client_ips.py | 12
-rw-r--r--  synapse/storage/databases/main/deviceinbox.py | 2
-rw-r--r--  synapse/storage/databases/main/devices.py | 38
-rw-r--r--  synapse/storage/databases/main/directory.py | 7
-rw-r--r--  synapse/storage/databases/main/end_to_end_keys.py | 11
-rw-r--r--  synapse/storage/databases/main/event_federation.py | 19
-rw-r--r--  synapse/storage/databases/main/event_push_actions.py | 18
-rw-r--r--  synapse/storage/databases/main/events.py | 45
-rw-r--r--  synapse/storage/databases/main/events_bg_updates.py | 29
-rw-r--r--  synapse/storage/databases/main/events_forward_extremities.py | 7
-rw-r--r--  synapse/storage/databases/main/events_worker.py | 16
-rw-r--r--  synapse/storage/databases/main/group_server.py | 40
-rw-r--r--  synapse/storage/databases/main/keys.py | 7
-rw-r--r--  synapse/storage/databases/main/media_repository.py | 15
-rw-r--r--  synapse/storage/databases/main/metrics.py | 4
-rw-r--r--  synapse/storage/databases/main/presence.py | 4
-rw-r--r--  synapse/storage/databases/main/profile.py | 6
-rw-r--r--  synapse/storage/databases/main/push_rule.py | 8
-rw-r--r--  synapse/storage/databases/main/pusher.py | 7
-rw-r--r--  synapse/storage/databases/main/receipts.py | 13
-rw-r--r--  synapse/storage/databases/main/registration.py | 14
-rw-r--r--  synapse/storage/databases/main/room.py | 25
-rw-r--r--  synapse/storage/databases/main/roommember.py | 17
-rw-r--r--  synapse/storage/databases/main/schema/delta/33/remote_media_ts.py | 3
-rw-r--r--  synapse/storage/databases/main/schema/full_schemas/54/full.sql.sqlite | 10
-rw-r--r--  synapse/storage/databases/main/state.py | 11
-rw-r--r--  synapse/storage/databases/main/state_deltas.py | 4
-rw-r--r--  synapse/storage/databases/main/stats.py | 4
-rw-r--r--  synapse/storage/databases/main/stream.py | 42
-rw-r--r--  synapse/storage/databases/main/transactions.py | 21
-rw-r--r--  synapse/storage/databases/main/ui_auth.py | 22
-rw-r--r--  synapse/storage/databases/main/user_directory.py | 14
-rw-r--r--  synapse/storage/databases/state/bg_updates.py | 2
-rw-r--r--  synapse/storage/databases/state/store.py | 6
-rw-r--r--  synapse/storage/engines/__init__.py | 8
-rw-r--r--  synapse/storage/engines/_base.py | 6
-rw-r--r--  synapse/storage/engines/postgres.py | 3
-rw-r--r--  synapse/storage/engines/sqlite.py | 14
-rw-r--r--  synapse/storage/persist_events.py | 12
-rw-r--r--  synapse/storage/prepare_database.py | 14
-rw-r--r--  synapse/storage/purge_events.py | 6
-rw-r--r--  synapse/storage/state.py | 5
-rw-r--r--  synapse/storage/types.py | 37
-rw-r--r--  synapse/storage/util/id_generators.py | 24
-rw-r--r--  synapse/storage/util/sequence.py | 11
51 files changed, 431 insertions, 274 deletions
diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py
index c0d9d124..a3c52695 100644
--- a/synapse/storage/__init__.py
+++ b/synapse/storage/__init__.py
@@ -43,8 +43,7 @@ __all__ = ["Databases", "DataStore"]
class Storage:
- """The high level interfaces for talking to various storage layers.
- """
+ """The high level interfaces for talking to various storage layers."""
def __init__(self, hs: "HomeServer", stores: Databases):
# We include the main data store here mainly so that we don't have to
diff --git a/synapse/storage/background_updates.py b/synapse/storage/background_updates.py
index 29b8ca67..329660cf 100644
--- a/synapse/storage/background_updates.py
+++ b/synapse/storage/background_updates.py
@@ -77,7 +77,7 @@ class BackgroundUpdatePerformance:
class BackgroundUpdater:
- """ Background updates are updates to the database that run in the
+ """Background updates are updates to the database that run in the
background. Each update processes a batch of data at once. We attempt to
limit the impact of each update by monitoring how long each batch takes to
process and autotuning the batch size.
@@ -158,8 +158,7 @@ class BackgroundUpdater:
return False
async def has_completed_background_update(self, update_name: str) -> bool:
- """Check if the given background update has finished running.
- """
+ """Check if the given background update has finished running."""
if self._all_done:
return True
@@ -198,7 +197,8 @@ class BackgroundUpdater:
if not self._current_background_update:
all_pending_updates = await self.db_pool.runInteraction(
- "background_updates", get_background_updates_txn,
+ "background_updates",
+ get_background_updates_txn,
)
if not all_pending_updates:
# no work left to do
diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index d2ba4bd2..46469264 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -85,8 +85,7 @@ UNIQUE_INDEX_BACKGROUND_UPDATES = {
def make_pool(
reactor, db_config: DatabaseConnectionConfig, engine: BaseDatabaseEngine
) -> adbapi.ConnectionPool:
- """Get the connection pool for the database.
- """
+ """Get the connection pool for the database."""
# By default enable `cp_reconnect`. We need to fiddle with db_args in case
# someone has explicitly set `cp_reconnect`.
@@ -158,8 +157,8 @@ class LoggingDatabaseConnection:
def commit(self) -> None:
self.conn.commit()
- def rollback(self, *args, **kwargs) -> None:
- self.conn.rollback(*args, **kwargs)
+ def rollback(self) -> None:
+ self.conn.rollback()
def __enter__(self) -> "Connection":
self.conn.__enter__()
@@ -244,12 +243,15 @@ class LoggingTransaction:
assert self.exception_callbacks is not None
self.exception_callbacks.append((callback, args, kwargs))
+ def fetchone(self) -> Optional[Tuple]:
+ return self.txn.fetchone()
+
+ def fetchmany(self, size: Optional[int] = None) -> List[Tuple]:
+ return self.txn.fetchmany(size=size)
+
def fetchall(self) -> List[Tuple]:
return self.txn.fetchall()
- def fetchone(self) -> Tuple:
- return self.txn.fetchone()
-
def __iter__(self) -> Iterator[Tuple]:
return self.txn.__iter__()
@@ -429,8 +431,7 @@ class DatabasePool:
)
def is_running(self) -> bool:
- """Is the database pool currently running
- """
+ """Is the database pool currently running"""
return self._db_pool.running
async def _check_safe_to_upsert(self) -> None:
@@ -543,7 +544,11 @@ class DatabasePool:
# This can happen if the database disappears mid
# transaction.
transaction_logger.warning(
- "[TXN OPERROR] {%s} %s %d/%d", name, e, i, N,
+ "[TXN OPERROR] {%s} %s %d/%d",
+ name,
+ e,
+ i,
+ N,
)
if i < N:
i += 1
@@ -564,7 +569,9 @@ class DatabasePool:
conn.rollback()
except self.engine.module.Error as e1:
transaction_logger.warning(
- "[TXN EROLL] {%s} %s", name, e1,
+ "[TXN EROLL] {%s} %s",
+ name,
+ e1,
)
continue
raise
@@ -754,6 +761,7 @@ class DatabasePool:
Returns:
A list of dicts where the key is the column header.
"""
+ assert cursor.description is not None, "cursor.description was None"
col_headers = [intern(str(column[0])) for column in cursor.description]
results = [dict(zip(col_headers, row)) for row in cursor]
return results
@@ -1402,7 +1410,10 @@ class DatabasePool:
@staticmethod
def simple_select_onecol_txn(
- txn: LoggingTransaction, table: str, keyvalues: Dict[str, Any], retcol: str,
+ txn: LoggingTransaction,
+ table: str,
+ keyvalues: Dict[str, Any],
+ retcol: str,
) -> List[Any]:
sql = ("SELECT %(retcol)s FROM %(table)s") % {"retcol": retcol, "table": table}
@@ -1712,7 +1723,11 @@ class DatabasePool:
desc: description of the transaction, for logging and metrics
"""
await self.runInteraction(
- desc, self.simple_delete_one_txn, table, keyvalues, db_autocommit=True,
+ desc,
+ self.simple_delete_one_txn,
+ table,
+ keyvalues,
+ db_autocommit=True,
)
@staticmethod
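Note on the LoggingTransaction changes above: the new fetchone/fetchmany methods simply delegate to the underlying DB-API cursor, with fetchone now typed as Optional since a cursor returns None once the result set is exhausted. A minimal, self-contained sketch of that delegation pattern, using plain sqlite3 and a hypothetical CursorWrapper class rather than Synapse's actual types:

import sqlite3
from typing import List, Optional, Tuple

class CursorWrapper:
    """Thin wrapper that forwards fetch calls to a DB-API cursor."""

    def __init__(self, cursor: sqlite3.Cursor):
        self.txn = cursor

    def fetchone(self) -> Optional[Tuple]:
        # Returns None once no rows are left.
        return self.txn.fetchone()

    def fetchmany(self, size: Optional[int] = None) -> List[Tuple]:
        # DB-API falls back to cursor.arraysize when no explicit size is given.
        return self.txn.fetchmany(size if size is not None else self.txn.arraysize)

    def fetchall(self) -> List[Tuple]:
        return self.txn.fetchall()

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE t (n INTEGER)")
conn.executemany("INSERT INTO t (n) VALUES (?)", [(1,), (2,), (3,)])
wrapped = CursorWrapper(conn.execute("SELECT n FROM t ORDER BY n"))
print(wrapped.fetchone())    # (1,)
print(wrapped.fetchmany(2))  # [(2,), (3,)]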
diff --git a/synapse/storage/databases/__init__.py b/synapse/storage/databases/__init__.py
index 0c243250..e84f8b42 100644
--- a/synapse/storage/databases/__init__.py
+++ b/synapse/storage/databases/__init__.py
@@ -56,7 +56,10 @@ class Databases:
database_config.databases,
)
prepare_database(
- db_conn, engine, hs.config, databases=database_config.databases,
+ db_conn,
+ engine,
+ hs.config,
+ databases=database_config.databases,
)
database = DatabasePool(hs, database_config, engine)
diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py
index 5d084558..70b49854 100644
--- a/synapse/storage/databases/main/__init__.py
+++ b/synapse/storage/databases/main/__init__.py
@@ -340,7 +340,7 @@ class DataStore(
count = txn.fetchone()[0]
sql = (
- "SELECT name, user_type, is_guest, admin, deactivated, displayname, avatar_url "
+ "SELECT name, user_type, is_guest, admin, deactivated, shadow_banned, displayname, avatar_url "
+ sql_base
+ " ORDER BY u.name LIMIT ? OFFSET ?"
)
diff --git a/synapse/storage/databases/main/appservice.py b/synapse/storage/databases/main/appservice.py
index e550cbc8..03a38422 100644
--- a/synapse/storage/databases/main/appservice.py
+++ b/synapse/storage/databases/main/appservice.py
@@ -73,8 +73,7 @@ class ApplicationServiceWorkerStore(SQLBaseStore):
return self.services_cache
def get_if_app_services_interested_in_user(self, user_id: str) -> bool:
- """Check if the user is one associated with an app service (exclusively)
- """
+ """Check if the user is one associated with an app service (exclusively)"""
if self.exclusive_user_regex:
return bool(self.exclusive_user_regex.match(user_id))
else:
diff --git a/synapse/storage/databases/main/client_ips.py b/synapse/storage/databases/main/client_ips.py
index ea1e8fb5..6d18e692 100644
--- a/synapse/storage/databases/main/client_ips.py
+++ b/synapse/storage/databases/main/client_ips.py
@@ -280,8 +280,7 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore):
return batch_size
async def _devices_last_seen_update(self, progress, batch_size):
- """Background update to insert last seen info into devices table
- """
+ """Background update to insert last seen info into devices table"""
last_user_id = progress.get("last_user_id", "")
last_device_id = progress.get("last_device_id", "")
@@ -363,8 +362,7 @@ class ClientIpWorkerStore(ClientIpBackgroundUpdateStore):
@wrap_as_background_process("prune_old_user_ips")
async def _prune_old_user_ips(self):
- """Removes entries in user IPs older than the configured period.
- """
+ """Removes entries in user IPs older than the configured period."""
if self.user_ips_max_age is None:
# Nothing to do
@@ -565,7 +563,11 @@ class ClientIpStore(ClientIpWorkerStore):
results = {}
for key in self._batch_row_update:
- uid, access_token, ip, = key
+ (
+ uid,
+ access_token,
+ ip,
+ ) = key
if uid == user_id:
user_agent, _, last_seen = self._batch_row_update[key]
results[(access_token, ip)] = (user_agent, last_seen)
diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py
index 31f70ac5..45ca6620 100644
--- a/synapse/storage/databases/main/deviceinbox.py
+++ b/synapse/storage/databases/main/deviceinbox.py
@@ -450,7 +450,7 @@ class DeviceInboxWorkerStore(SQLBaseStore):
},
)
- # Add the messages to the approriate local device inboxes so that
+ # Add the messages to the appropriate local device inboxes so that
# they'll be sent to the devices when they next sync.
self._add_messages_to_local_device_inbox_txn(
txn, stream_id, local_messages_by_user_then_device
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index 659d8f24..d327e9aa 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -315,7 +315,8 @@ class DeviceWorkerStore(SQLBaseStore):
# make sure we go through the devices in stream order
device_ids = sorted(
- user_devices.keys(), key=lambda i: query_map[(user_id, i)][0],
+ user_devices.keys(),
+ key=lambda i: query_map[(user_id, i)][0],
)
for device_id in device_ids:
@@ -366,8 +367,7 @@ class DeviceWorkerStore(SQLBaseStore):
async def mark_as_sent_devices_by_remote(
self, destination: str, stream_id: int
) -> None:
- """Mark that updates have successfully been sent to the destination.
- """
+ """Mark that updates have successfully been sent to the destination."""
await self.db_pool.runInteraction(
"mark_as_sent_devices_by_remote",
self._mark_as_sent_devices_by_remote_txn,
@@ -681,7 +681,8 @@ class DeviceWorkerStore(SQLBaseStore):
return results
async def get_user_ids_requiring_device_list_resync(
- self, user_ids: Optional[Collection[str]] = None,
+ self,
+ user_ids: Optional[Collection[str]] = None,
) -> Set[str]:
"""Given a list of remote users return the list of users that we
should resync the device lists for. If None is given instead of a list,
@@ -721,8 +722,7 @@ class DeviceWorkerStore(SQLBaseStore):
)
async def mark_remote_user_device_list_as_unsubscribed(self, user_id: str) -> None:
- """Mark that we no longer track device lists for remote user.
- """
+ """Mark that we no longer track device lists for remote user."""
def _mark_remote_user_device_list_as_unsubscribed_txn(txn):
self.db_pool.simple_delete_txn(
@@ -902,7 +902,8 @@ class DeviceWorkerStore(SQLBaseStore):
logger.info("Pruned %d device list outbound pokes", count)
await self.db_pool.runInteraction(
- "_prune_old_outbound_device_pokes", _prune_txn,
+ "_prune_old_outbound_device_pokes",
+ _prune_txn,
)
@@ -943,7 +944,8 @@ class DeviceBackgroundUpdateStore(SQLBaseStore):
# clear out duplicate device list outbound pokes
self.db_pool.updates.register_background_update_handler(
- BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES, self._remove_duplicate_outbound_pokes,
+ BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES,
+ self._remove_duplicate_outbound_pokes,
)
# a pair of background updates that were added during the 1.14 release cycle,
@@ -1004,17 +1006,23 @@ class DeviceBackgroundUpdateStore(SQLBaseStore):
row = None
for row in rows:
self.db_pool.simple_delete_txn(
- txn, "device_lists_outbound_pokes", {x: row[x] for x in KEY_COLS},
+ txn,
+ "device_lists_outbound_pokes",
+ {x: row[x] for x in KEY_COLS},
)
row["sent"] = False
self.db_pool.simple_insert_txn(
- txn, "device_lists_outbound_pokes", row,
+ txn,
+ "device_lists_outbound_pokes",
+ row,
)
if row:
self.db_pool.updates._background_update_progress_txn(
- txn, BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES, {"last_row": row},
+ txn,
+ BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES,
+ {"last_row": row},
)
return len(rows)
@@ -1286,7 +1294,9 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
# we've done a full resync, so we remove the entry that says we need
# to resync
self.db_pool.simple_delete_txn(
- txn, table="device_lists_remote_resync", keyvalues={"user_id": user_id},
+ txn,
+ table="device_lists_remote_resync",
+ keyvalues={"user_id": user_id},
)
async def add_device_change_to_streams(
@@ -1336,7 +1346,9 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
stream_ids: List[str],
):
txn.call_after(
- self._device_list_stream_cache.entity_has_changed, user_id, stream_ids[-1],
+ self._device_list_stream_cache.entity_has_changed,
+ user_id,
+ stream_ids[-1],
)
min_stream_id = stream_ids[0]
diff --git a/synapse/storage/databases/main/directory.py b/synapse/storage/databases/main/directory.py
index e5060d4c..267b9483 100644
--- a/synapse/storage/databases/main/directory.py
+++ b/synapse/storage/databases/main/directory.py
@@ -85,7 +85,7 @@ class DirectoryStore(DirectoryWorkerStore):
servers: Iterable[str],
creator: Optional[str] = None,
) -> None:
- """ Creates an association between a room alias and room_id/servers
+ """Creates an association between a room alias and room_id/servers
Args:
room_alias: The alias to create.
@@ -160,7 +160,10 @@ class DirectoryStore(DirectoryWorkerStore):
return room_id
async def update_aliases_for_room(
- self, old_room_id: str, new_room_id: str, creator: Optional[str] = None,
+ self,
+ old_room_id: str,
+ new_room_id: str,
+ creator: Optional[str] = None,
) -> None:
"""Repoint all of the aliases for a given room, to a different room.
diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py
index 309f1e86..f1e7859d 100644
--- a/synapse/storage/databases/main/end_to_end_keys.py
+++ b/synapse/storage/databases/main/end_to_end_keys.py
@@ -361,7 +361,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
async def count_e2e_one_time_keys(
self, user_id: str, device_id: str
) -> Dict[str, int]:
- """ Count the number of one time keys the server has for a device
+ """Count the number of one time keys the server has for a device
Returns:
A mapping from algorithm to number of keys for that algorithm.
"""
@@ -494,7 +494,9 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
)
def _get_bare_e2e_cross_signing_keys_bulk_txn(
- self, txn: Connection, user_ids: List[str],
+ self,
+ txn: Connection,
+ user_ids: List[str],
) -> Dict[str, Dict[str, dict]]:
"""Returns the cross-signing keys for a set of users. The output of this
function should be passed to _get_e2e_cross_signing_signatures_txn if
@@ -556,7 +558,10 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
return result
def _get_e2e_cross_signing_signatures_txn(
- self, txn: Connection, keys: Dict[str, Dict[str, dict]], from_user_id: str,
+ self,
+ txn: Connection,
+ keys: Dict[str, Dict[str, dict]],
+ from_user_id: str,
) -> Dict[str, Dict[str, dict]]:
"""Returns the cross-signing signatures made by a user on a set of keys.
diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py
index 8326640d..18ddb92f 100644
--- a/synapse/storage/databases/main/event_federation.py
+++ b/synapse/storage/databases/main/event_federation.py
@@ -71,7 +71,9 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
return await self.get_events_as_list(event_ids)
async def get_auth_chain_ids(
- self, event_ids: Collection[str], include_given: bool = False,
+ self,
+ event_ids: Collection[str],
+ include_given: bool = False,
) -> List[str]:
"""Get auth events for given event_ids. The events *must* be state events.
@@ -273,7 +275,8 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
# origin chain.
if origin_sequence_number <= chains.get(origin_chain_id, 0):
chains[target_chain_id] = max(
- target_sequence_number, chains.get(target_chain_id, 0),
+ target_sequence_number,
+ chains.get(target_chain_id, 0),
)
seen_chains.add(target_chain_id)
@@ -371,7 +374,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
# and state sets {A} and {B} then walking the auth chains of A and B
# would immediately show that C is reachable by both. However, if we
# stopped at C then we'd only reach E via the auth chain of B and so E
- # would errornously get included in the returned difference.
+ # would erroneously get included in the returned difference.
#
# The other thing that we do is limit the number of auth chains we walk
# at once, due to practical limits (i.e. we can only query the database
@@ -497,7 +500,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
a_ids = new_aids
- # Mark that the auth event is reachable by the approriate sets.
+ # Mark that the auth event is reachable by the appropriate sets.
sets.intersection_update(event_to_missing_sets[event_id])
search.sort()
@@ -632,8 +635,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
)
async def get_min_depth(self, room_id: str) -> int:
- """For the given room, get the minimum depth we have seen for it.
- """
+ """For the given room, get the minimum depth we have seen for it."""
return await self.db_pool.runInteraction(
"get_min_depth", self._get_min_depth_interaction, room_id
)
@@ -858,12 +860,13 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
)
await self.db_pool.runInteraction(
- "_delete_old_forward_extrem_cache", _delete_old_forward_extrem_cache_txn,
+ "_delete_old_forward_extrem_cache",
+ _delete_old_forward_extrem_cache_txn,
)
class EventFederationStore(EventFederationWorkerStore):
- """ Responsible for storing and serving up the various graphs associated
+ """Responsible for storing and serving up the various graphs associated
with an event. Including the main event graph and the auth chains for an
event.
diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py
index 438383ab..78245ad5 100644
--- a/synapse/storage/databases/main/event_push_actions.py
+++ b/synapse/storage/databases/main/event_push_actions.py
@@ -54,8 +54,7 @@ def _serialize_action(actions, is_highlight):
def _deserialize_action(actions, is_highlight):
- """Custom deserializer for actions. This allows us to "compress" common actions
- """
+ """Custom deserializer for actions. This allows us to "compress" common actions"""
if actions:
return db_to_json(actions)
@@ -91,7 +90,10 @@ class EventPushActionsWorkerStore(SQLBaseStore):
@cached(num_args=3, tree=True, max_entries=5000)
async def get_unread_event_push_actions_by_room_for_user(
- self, room_id: str, user_id: str, last_read_event_id: Optional[str],
+ self,
+ room_id: str,
+ user_id: str,
+ last_read_event_id: Optional[str],
) -> Dict[str, int]:
"""Get the notification count, the highlight count and the unread message count
for a given user in a given room after the given read receipt.
@@ -120,13 +122,19 @@ class EventPushActionsWorkerStore(SQLBaseStore):
)
def _get_unread_counts_by_receipt_txn(
- self, txn, room_id, user_id, last_read_event_id,
+ self,
+ txn,
+ room_id,
+ user_id,
+ last_read_event_id,
):
stream_ordering = None
if last_read_event_id is not None:
stream_ordering = self.get_stream_id_for_event_txn(
- txn, last_read_event_id, allow_none=True,
+ txn,
+ last_read_event_id,
+ allow_none=True,
)
if stream_ordering is None:
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index ccda9f1c..287606cb 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -399,7 +399,9 @@ class PersistEventsStore:
self._update_current_state_txn(txn, state_delta_for_room, min_stream_order)
def _persist_event_auth_chain_txn(
- self, txn: LoggingTransaction, events: List[EventBase],
+ self,
+ txn: LoggingTransaction,
+ events: List[EventBase],
) -> None:
# We only care about state events, so this if there are no state events.
@@ -470,7 +472,11 @@ class PersistEventsStore:
event_to_room_id = {e.event_id: e.room_id for e in state_events.values()}
self._add_chain_cover_index(
- txn, self.db_pool, event_to_room_id, event_to_types, event_to_auth_chain,
+ txn,
+ self.db_pool,
+ event_to_room_id,
+ event_to_types,
+ event_to_auth_chain,
)
@classmethod
@@ -517,7 +523,10 @@ class PersistEventsStore:
# simple_select_many, but this case happens rarely and almost always
# with a single row.)
auth_events = db_pool.simple_select_onecol_txn(
- txn, "event_auth", keyvalues={"event_id": event_id}, retcol="auth_id",
+ txn,
+ "event_auth",
+ keyvalues={"event_id": event_id},
+ retcol="auth_id",
)
events_to_calc_chain_id_for.add(event_id)
@@ -550,7 +559,9 @@ class PersistEventsStore:
WHERE
"""
clause, args = make_in_list_sql_clause(
- txn.database_engine, "event_id", missing_auth_chains,
+ txn.database_engine,
+ "event_id",
+ missing_auth_chains,
)
txn.execute(sql + clause, args)
@@ -704,7 +715,8 @@ class PersistEventsStore:
if chain_map[a_id][0] != chain_id
}
for start_auth_id, end_auth_id in itertools.permutations(
- event_to_auth_chain.get(event_id, []), r=2,
+ event_to_auth_chain.get(event_id, []),
+ r=2,
):
if chain_links.exists_path_from(
chain_map[start_auth_id], chain_map[end_auth_id]
@@ -888,8 +900,7 @@ class PersistEventsStore:
txn: LoggingTransaction,
events_and_contexts: List[Tuple[EventBase, EventContext]],
):
- """Persist the mapping from transaction IDs to event IDs (if defined).
- """
+ """Persist the mapping from transaction IDs to event IDs (if defined)."""
to_insert = []
for event, _ in events_and_contexts:
@@ -909,7 +920,9 @@ class PersistEventsStore:
if to_insert:
self.db_pool.simple_insert_many_txn(
- txn, table="event_txn_id", values=to_insert,
+ txn,
+ table="event_txn_id",
+ values=to_insert,
)
def _update_current_state_txn(
@@ -941,7 +954,9 @@ class PersistEventsStore:
txn.execute(sql, (stream_id, self._instance_name, room_id))
self.db_pool.simple_delete_txn(
- txn, table="current_state_events", keyvalues={"room_id": room_id},
+ txn,
+ table="current_state_events",
+ keyvalues={"room_id": room_id},
)
else:
# We're still in the room, so we update the current state as normal.
@@ -1050,7 +1065,7 @@ class PersistEventsStore:
# Figure out the changes of membership to invalidate the
# `get_rooms_for_user` cache.
# We find out which membership events we may have deleted
- # and which we have added, then we invlidate the caches for all
+ # and which we have added, then we invalidate the caches for all
# those users.
members_changed = {
state_key
@@ -1608,8 +1623,7 @@ class PersistEventsStore:
)
def _store_room_members_txn(self, txn, events, backfilled):
- """Store a room member in the database.
- """
+ """Store a room member in the database."""
def str_or_none(val: Any) -> Optional[str]:
return val if isinstance(val, str) else None
@@ -2001,8 +2015,7 @@ class PersistEventsStore:
@attr.s(slots=True)
class _LinkMap:
- """A helper type for tracking links between chains.
- """
+ """A helper type for tracking links between chains."""
# Stores the set of links as nested maps: source chain ID -> target chain ID
# -> source sequence number -> target sequence number.
@@ -2108,7 +2121,9 @@ class _LinkMap:
yield (src_chain, src_seq, target_chain, target_seq)
def exists_path_from(
- self, src_tuple: Tuple[int, int], target_tuple: Tuple[int, int],
+ self,
+ src_tuple: Tuple[int, int],
+ target_tuple: Tuple[int, int],
) -> bool:
"""Checks if there is a path between the source chain ID/sequence and
target chain ID/sequence.
diff --git a/synapse/storage/databases/main/events_bg_updates.py b/synapse/storage/databases/main/events_bg_updates.py
index 5ca4fa68..89274e75 100644
--- a/synapse/storage/databases/main/events_bg_updates.py
+++ b/synapse/storage/databases/main/events_bg_updates.py
@@ -32,8 +32,7 @@ logger = logging.getLogger(__name__)
@attr.s(slots=True, frozen=True)
class _CalculateChainCover:
- """Return value for _calculate_chain_cover_txn.
- """
+ """Return value for _calculate_chain_cover_txn."""
# The last room_id/depth/stream processed.
room_id = attr.ib(type=str)
@@ -127,11 +126,13 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
)
self.db_pool.updates.register_background_update_handler(
- "rejected_events_metadata", self._rejected_events_metadata,
+ "rejected_events_metadata",
+ self._rejected_events_metadata,
)
self.db_pool.updates.register_background_update_handler(
- "chain_cover", self._chain_cover_index,
+ "chain_cover",
+ self._chain_cover_index,
)
async def _background_reindex_fields_sender(self, progress, batch_size):
@@ -462,8 +463,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
return num_handled
async def _redactions_received_ts(self, progress, batch_size):
- """Handles filling out the `received_ts` column in redactions.
- """
+ """Handles filling out the `received_ts` column in redactions."""
last_event_id = progress.get("last_event_id", "")
def _redactions_received_ts_txn(txn):
@@ -518,8 +518,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
return count
async def _event_fix_redactions_bytes(self, progress, batch_size):
- """Undoes hex encoded censored redacted event JSON.
- """
+ """Undoes hex encoded censored redacted event JSON."""
def _event_fix_redactions_bytes_txn(txn):
# This update is quite fast due to new index.
@@ -642,7 +641,13 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
LIMIT ?
"""
- txn.execute(sql, (last_event_id, batch_size,))
+ txn.execute(
+ sql,
+ (
+ last_event_id,
+ batch_size,
+ ),
+ )
return [(row[0], row[1], db_to_json(row[2]), row[3], row[4]) for row in txn] # type: ignore
@@ -910,7 +915,11 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
# Annoyingly we need to gut wrench into the persit event store so that
# we can reuse the function to calculate the chain cover for rooms.
PersistEventsStore._add_chain_cover_index(
- txn, self.db_pool, event_to_room_id, event_to_types, event_to_auth_chain,
+ txn,
+ self.db_pool,
+ event_to_room_id,
+ event_to_types,
+ event_to_auth_chain,
)
return _CalculateChainCover(
diff --git a/synapse/storage/databases/main/events_forward_extremities.py b/synapse/storage/databases/main/events_forward_extremities.py
index 0ac1da9c..b3703ae1 100644
--- a/synapse/storage/databases/main/events_forward_extremities.py
+++ b/synapse/storage/databases/main/events_forward_extremities.py
@@ -71,7 +71,9 @@ class EventForwardExtremitiesStore(SQLBaseStore):
if txn.rowcount > 0:
# Invalidate the cache
self._invalidate_cache_and_stream(
- txn, self.get_latest_event_ids_in_room, (room_id,),
+ txn,
+ self.get_latest_event_ids_in_room,
+ (room_id,),
)
return txn.rowcount
@@ -97,5 +99,6 @@ class EventForwardExtremitiesStore(SQLBaseStore):
return self.db_pool.cursor_to_dict(txn)
return await self.db_pool.runInteraction(
- "get_forward_extremities_for_room", get_forward_extremities_for_room_txn,
+ "get_forward_extremities_for_room",
+ get_forward_extremities_for_room_txn,
)
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index 71d823be..c8850a47 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -120,7 +120,9 @@ class EventsWorkerStore(SQLBaseStore):
# SQLite).
if hs.get_instance_name() in hs.config.worker.writers.events:
self._stream_id_gen = StreamIdGenerator(
- db_conn, "events", "stream_ordering",
+ db_conn,
+ "events",
+ "stream_ordering",
)
self._backfill_id_gen = StreamIdGenerator(
db_conn,
@@ -140,7 +142,8 @@ class EventsWorkerStore(SQLBaseStore):
if hs.config.run_background_tasks:
# We periodically clean out old transaction ID mappings
self._clock.looping_call(
- self._cleanup_old_transaction_ids, 5 * 60 * 1000,
+ self._cleanup_old_transaction_ids,
+ 5 * 60 * 1000,
)
self._get_event_cache = LruCache(
@@ -1325,8 +1328,7 @@ class EventsWorkerStore(SQLBaseStore):
return rows, to_token, True
async def is_event_after(self, event_id1, event_id2):
- """Returns True if event_id1 is after event_id2 in the stream
- """
+ """Returns True if event_id1 is after event_id2 in the stream"""
to_1, so_1 = await self.get_event_ordering(event_id1)
to_2, so_2 = await self.get_event_ordering(event_id2)
return (to_1, so_1) > (to_2, so_2)
@@ -1428,8 +1430,7 @@ class EventsWorkerStore(SQLBaseStore):
@wrap_as_background_process("_cleanup_old_transaction_ids")
async def _cleanup_old_transaction_ids(self):
- """Cleans out transaction id mappings older than 24hrs.
- """
+ """Cleans out transaction id mappings older than 24hrs."""
def _cleanup_old_transaction_ids_txn(txn):
sql = """
@@ -1440,5 +1441,6 @@ class EventsWorkerStore(SQLBaseStore):
txn.execute(sql, (one_day_ago,))
return await self.db_pool.runInteraction(
- "_cleanup_old_transaction_ids", _cleanup_old_transaction_ids_txn,
+ "_cleanup_old_transaction_ids",
+ _cleanup_old_transaction_ids_txn,
)
diff --git a/synapse/storage/databases/main/group_server.py b/synapse/storage/databases/main/group_server.py
index 72181919..ac07e019 100644
--- a/synapse/storage/databases/main/group_server.py
+++ b/synapse/storage/databases/main/group_server.py
@@ -14,7 +14,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Any, Dict, List, Optional, Tuple, Union
+from typing import Any, Dict, List, Optional, Tuple
+
+from typing_extensions import TypedDict
from synapse.api.errors import SynapseError
from synapse.storage._base import SQLBaseStore, db_to_json
@@ -26,6 +28,9 @@ from synapse.util import json_encoder
_DEFAULT_CATEGORY_ID = ""
_DEFAULT_ROLE_ID = ""
+# A room in a group.
+_RoomInGroup = TypedDict("_RoomInGroup", {"room_id": str, "is_public": bool})
+
class GroupServerWorkerStore(SQLBaseStore):
async def get_group(self, group_id: str) -> Optional[Dict[str, Any]]:
@@ -72,7 +77,7 @@ class GroupServerWorkerStore(SQLBaseStore):
async def get_rooms_in_group(
self, group_id: str, include_private: bool = False
- ) -> List[Dict[str, Union[str, bool]]]:
+ ) -> List[_RoomInGroup]:
"""Retrieve the rooms that belong to a given group. Does not return rooms that
lack members.
@@ -123,7 +128,9 @@ class GroupServerWorkerStore(SQLBaseStore):
)
async def get_rooms_for_summary_by_category(
- self, group_id: str, include_private: bool = False,
+ self,
+ group_id: str,
+ include_private: bool = False,
) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:
"""Get the rooms and categories that should be included in a summary request
@@ -368,8 +375,7 @@ class GroupServerWorkerStore(SQLBaseStore):
async def is_user_invited_to_local_group(
self, group_id: str, user_id: str
) -> Optional[bool]:
- """Has the group server invited a user?
- """
+ """Has the group server invited a user?"""
return await self.db_pool.simple_select_one_onecol(
table="group_invites",
keyvalues={"group_id": group_id, "user_id": user_id},
@@ -427,8 +433,7 @@ class GroupServerWorkerStore(SQLBaseStore):
)
async def get_publicised_groups_for_user(self, user_id: str) -> List[str]:
- """Get all groups a user is publicising
- """
+ """Get all groups a user is publicising"""
return await self.db_pool.simple_select_onecol(
table="local_group_membership",
keyvalues={"user_id": user_id, "membership": "join", "is_publicised": True},
@@ -437,8 +442,7 @@ class GroupServerWorkerStore(SQLBaseStore):
)
async def get_attestations_need_renewals(self, valid_until_ms):
- """Get all attestations that need to be renewed until givent time
- """
+ """Get all attestations that need to be renewed until givent time"""
def _get_attestations_need_renewals_txn(txn):
sql = """
@@ -781,8 +785,7 @@ class GroupServerStore(GroupServerWorkerStore):
profile: Optional[JsonDict],
is_public: Optional[bool],
) -> None:
- """Add/update room category for group
- """
+ """Add/update room category for group"""
insertion_values = {}
update_values = {"category_id": category_id} # This cannot be empty
@@ -818,8 +821,7 @@ class GroupServerStore(GroupServerWorkerStore):
profile: Optional[JsonDict],
is_public: Optional[bool],
) -> None:
- """Add/remove user role
- """
+ """Add/remove user role"""
insertion_values = {}
update_values = {"role_id": role_id} # This cannot be empty
@@ -1012,8 +1014,7 @@ class GroupServerStore(GroupServerWorkerStore):
)
async def add_group_invite(self, group_id: str, user_id: str) -> None:
- """Record that the group server has invited a user
- """
+ """Record that the group server has invited a user"""
await self.db_pool.simple_insert(
table="group_invites",
values={"group_id": group_id, "user_id": user_id},
@@ -1156,8 +1157,7 @@ class GroupServerStore(GroupServerWorkerStore):
async def update_group_publicity(
self, group_id: str, user_id: str, publicise: bool
) -> None:
- """Update whether the user is publicising their membership of the group
- """
+ """Update whether the user is publicising their membership of the group"""
await self.db_pool.simple_update_one(
table="local_group_membership",
keyvalues={"group_id": group_id, "user_id": user_id},
@@ -1300,8 +1300,7 @@ class GroupServerStore(GroupServerWorkerStore):
async def update_attestation_renewal(
self, group_id: str, user_id: str, attestation: dict
) -> None:
- """Update an attestation that we have renewed
- """
+ """Update an attestation that we have renewed"""
await self.db_pool.simple_update_one(
table="group_attestations_renewals",
keyvalues={"group_id": group_id, "user_id": user_id},
@@ -1312,8 +1311,7 @@ class GroupServerStore(GroupServerWorkerStore):
async def update_remote_attestion(
self, group_id: str, user_id: str, attestation: dict
) -> None:
- """Update an attestation that a remote has renewed
- """
+ """Update an attestation that a remote has renewed"""
await self.db_pool.simple_update_one(
table="group_attestations_remote",
keyvalues={"group_id": group_id, "user_id": user_id},
diff --git a/synapse/storage/databases/main/keys.py b/synapse/storage/databases/main/keys.py
index 04ac2d0c..d504323b 100644
--- a/synapse/storage/databases/main/keys.py
+++ b/synapse/storage/databases/main/keys.py
@@ -33,8 +33,7 @@ db_binary_type = memoryview
class KeyStore(SQLBaseStore):
- """Persistence for signature verification keys
- """
+ """Persistence for signature verification keys"""
@cached()
def _get_server_verify_key(self, server_name_and_key_id):
@@ -155,7 +154,7 @@ class KeyStore(SQLBaseStore):
(server_name, key_id, from_server) triplet if one already existed.
Args:
server_name: The name of the server.
- key_id: The identifer of the key this JSON is for.
+ key_id: The identifier of the key this JSON is for.
from_server: The server this JSON was fetched from.
ts_now_ms: The time now in milliseconds.
ts_valid_until_ms: The time when this json stops being valid.
@@ -182,7 +181,7 @@ class KeyStore(SQLBaseStore):
async def get_server_keys_json(
self, server_keys: Iterable[Tuple[str, Optional[str], Optional[str]]]
) -> Dict[Tuple[str, Optional[str], Optional[str]], List[dict]]:
- """Retrive the key json for a list of server_keys and key ids.
+ """Retrieve the key json for a list of server_keys and key ids.
If no keys are found for a given server, key_id and source then
that server, key_id, and source triplet entry will be an empty list.
The JSON is returned as a byte array so that it can be efficiently
diff --git a/synapse/storage/databases/main/media_repository.py b/synapse/storage/databases/main/media_repository.py
index e0171776..a0313c3c 100644
--- a/synapse/storage/databases/main/media_repository.py
+++ b/synapse/storage/databases/main/media_repository.py
@@ -169,7 +169,10 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
)
async def get_local_media_before(
- self, before_ts: int, size_gt: int, keep_profiles: bool,
+ self,
+ before_ts: int,
+ size_gt: int,
+ keep_profiles: bool,
) -> List[str]:
# to find files that have never been accessed (last_access_ts IS NULL)
@@ -454,10 +457,14 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
)
async def get_remote_media_thumbnail(
- self, origin: str, media_id: str, t_width: int, t_height: int, t_type: str,
+ self,
+ origin: str,
+ media_id: str,
+ t_width: int,
+ t_height: int,
+ t_type: str,
) -> Optional[Dict[str, Any]]:
- """Fetch the thumbnail info of given width, height and type.
- """
+ """Fetch the thumbnail info of given width, height and type."""
return await self.db_pool.simple_select_one(
table="remote_media_cache_thumbnails",
diff --git a/synapse/storage/databases/main/metrics.py b/synapse/storage/databases/main/metrics.py
index 92e65aa6..614a418a 100644
--- a/synapse/storage/databases/main/metrics.py
+++ b/synapse/storage/databases/main/metrics.py
@@ -111,7 +111,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
async def count_daily_sent_e2ee_messages(self):
def _count_messages(txn):
# This is good enough as if you have silly characters in your own
- # hostname then thats your own fault.
+ # hostname then that's your own fault.
like_clause = "%:" + self.hs.hostname
sql = """
@@ -167,7 +167,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
async def count_daily_sent_messages(self):
def _count_messages(txn):
# This is good enough as if you have silly characters in your own
- # hostname then thats your own fault.
+ # hostname then that's your own fault.
like_clause = "%:" + self.hs.hostname
sql = """
diff --git a/synapse/storage/databases/main/presence.py b/synapse/storage/databases/main/presence.py
index dbbb99cb..29edab34 100644
--- a/synapse/storage/databases/main/presence.py
+++ b/synapse/storage/databases/main/presence.py
@@ -130,7 +130,9 @@ class PresenceStore(SQLBaseStore):
raise NotImplementedError()
@cachedList(
- cached_method_name="_get_presence_for_user", list_name="user_ids", num_args=1,
+ cached_method_name="_get_presence_for_user",
+ list_name="user_ids",
+ num_args=1,
)
async def get_presence_for_users(self, user_ids):
rows = await self.db_pool.simple_select_many_batch(
diff --git a/synapse/storage/databases/main/profile.py b/synapse/storage/databases/main/profile.py
index 54ef0f1f..ba01d310 100644
--- a/synapse/storage/databases/main/profile.py
+++ b/synapse/storage/databases/main/profile.py
@@ -118,8 +118,7 @@ class ProfileWorkerStore(SQLBaseStore):
)
async def is_subscribed_remote_profile_for_user(self, user_id):
- """Check whether we are interested in a remote user's profile.
- """
+ """Check whether we are interested in a remote user's profile."""
res = await self.db_pool.simple_select_one_onecol(
table="group_users",
keyvalues={"user_id": user_id},
@@ -145,8 +144,7 @@ class ProfileWorkerStore(SQLBaseStore):
async def get_remote_profile_cache_entries_that_expire(
self, last_checked: int
) -> List[Dict[str, str]]:
- """Get all users who haven't been checked since `last_checked`
- """
+ """Get all users who haven't been checked since `last_checked`"""
def _get_remote_profile_cache_entries_that_expire_txn(txn):
sql = """
diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py
index 711d5aa2..9e58dc0e 100644
--- a/synapse/storage/databases/main/push_rule.py
+++ b/synapse/storage/databases/main/push_rule.py
@@ -168,7 +168,9 @@ class PushRulesWorkerStore(
)
@cachedList(
- cached_method_name="get_push_rules_for_user", list_name="user_ids", num_args=1,
+ cached_method_name="get_push_rules_for_user",
+ list_name="user_ids",
+ num_args=1,
)
async def bulk_get_push_rules(self, user_ids):
if not user_ids:
@@ -195,7 +197,9 @@ class PushRulesWorkerStore(
use_new_defaults = user_id in self._users_new_default_push_rules
results[user_id] = _load_rules(
- rules, enabled_map_by_user.get(user_id, {}), use_new_defaults,
+ rules,
+ enabled_map_by_user.get(user_id, {}),
+ use_new_defaults,
)
return results
diff --git a/synapse/storage/databases/main/pusher.py b/synapse/storage/databases/main/pusher.py
index 2687ef3e..7cb69dd6 100644
--- a/synapse/storage/databases/main/pusher.py
+++ b/synapse/storage/databases/main/pusher.py
@@ -179,7 +179,9 @@ class PusherWorkerStore(SQLBaseStore):
raise NotImplementedError()
@cachedList(
- cached_method_name="get_if_user_has_pusher", list_name="user_ids", num_args=1,
+ cached_method_name="get_if_user_has_pusher",
+ list_name="user_ids",
+ num_args=1,
)
async def get_if_users_have_pushers(
self, user_ids: Iterable[str]
@@ -263,7 +265,8 @@ class PusherWorkerStore(SQLBaseStore):
params_by_room = {}
for row in res:
params_by_room[row["room_id"]] = ThrottleParams(
- row["last_sent_ts"], row["throttle_ms"],
+ row["last_sent_ts"],
+ row["throttle_ms"],
)
return params_by_room
diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py
index e4843a20..43c852c9 100644
--- a/synapse/storage/databases/main/receipts.py
+++ b/synapse/storage/databases/main/receipts.py
@@ -160,7 +160,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
Args:
room_id: List of room_ids.
- to_key: Max stream id to fetch receipts upto.
+ to_key: Max stream id to fetch receipts up to.
from_key: Min stream id to fetch receipts from. None fetches
from the start.
@@ -189,7 +189,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
Args:
room_ids: The room id.
- to_key: Max stream id to fetch receipts upto.
+ to_key: Max stream id to fetch receipts up to.
from_key: Min stream id to fetch receipts from. None fetches
from the start.
@@ -208,8 +208,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
async def _get_linearized_receipts_for_room(
self, room_id: str, to_key: int, from_key: Optional[int] = None
) -> List[dict]:
- """See get_linearized_receipts_for_room
- """
+ """See get_linearized_receipts_for_room"""
def f(txn):
if from_key:
@@ -304,7 +303,9 @@ class ReceiptsWorkerStore(SQLBaseStore):
}
return results
- @cached(num_args=2,)
+ @cached(
+ num_args=2,
+ )
async def get_linearized_receipts_for_all_rooms(
self, to_key: int, from_key: Optional[int] = None
) -> Dict[str, JsonDict]:
@@ -312,7 +313,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
to a limit of the latest 100 read receipts.
Args:
- to_key: Max stream id to fetch receipts upto.
+ to_key: Max stream id to fetch receipts up to.
from_key: Min stream id to fetch receipts from. None fetches
from the start.
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index 8405dd46..d5b55078 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -79,13 +79,16 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
# call `find_max_generated_user_id_localpart` each time, which is
# expensive if there are many entries.
self._user_id_seq = build_sequence_generator(
- database.engine, find_max_generated_user_id_localpart, "user_id_seq",
+ database.engine,
+ find_max_generated_user_id_localpart,
+ "user_id_seq",
)
self._account_validity = hs.config.account_validity
if hs.config.run_background_tasks and self._account_validity.enabled:
self._clock.call_later(
- 0.0, self._set_expiration_date_when_missing,
+ 0.0,
+ self._set_expiration_date_when_missing,
)
# Create a background job for culling expired 3PID validity tokens
@@ -110,6 +113,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
"creation_ts",
"user_type",
"deactivated",
+ "shadow_banned",
],
allow_none=True,
desc="get_user_by_id",
@@ -369,23 +373,25 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
"""
def set_shadow_banned_txn(txn):
+ user_id = user.to_string()
self.db_pool.simple_update_one_txn(
txn,
table="users",
- keyvalues={"name": user.to_string()},
+ keyvalues={"name": user_id},
updatevalues={"shadow_banned": shadow_banned},
)
# In order for this to apply immediately, clear the cache for this user.
tokens = self.db_pool.simple_select_onecol_txn(
txn,
table="access_tokens",
- keyvalues={"user_id": user.to_string()},
+ keyvalues={"user_id": user_id},
retcol="token",
)
for token in tokens:
self._invalidate_cache_and_stream(
txn, self.get_user_by_access_token, (token,)
)
+ self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,))
await self.db_pool.runInteraction("set_shadow_banned", set_shadow_banned_txn)
diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py
index a9fcb5f5..9cbcd530 100644
--- a/synapse/storage/databases/main/room.py
+++ b/synapse/storage/databases/main/room.py
@@ -193,8 +193,7 @@ class RoomWorkerStore(SQLBaseStore):
)
async def get_room_count(self) -> int:
- """Retrieve the total number of rooms.
- """
+ """Retrieve the total number of rooms."""
def f(txn):
sql = "SELECT count(*) FROM rooms"
@@ -517,7 +516,8 @@ class RoomWorkerStore(SQLBaseStore):
return rooms, room_count[0]
return await self.db_pool.runInteraction(
- "get_rooms_paginate", _get_rooms_paginate_txn,
+ "get_rooms_paginate",
+ _get_rooms_paginate_txn,
)
@cached(max_entries=10000)
@@ -578,7 +578,8 @@ class RoomWorkerStore(SQLBaseStore):
return self.db_pool.cursor_to_dict(txn)
ret = await self.db_pool.runInteraction(
- "get_retention_policy_for_room", get_retention_policy_for_room_txn,
+ "get_retention_policy_for_room",
+ get_retention_policy_for_room_txn,
)
# If we don't know this room ID, ret will be None, in this case return the default
@@ -707,7 +708,10 @@ class RoomWorkerStore(SQLBaseStore):
return local_media_mxcs, remote_media_mxcs
async def quarantine_media_by_id(
- self, server_name: str, media_id: str, quarantined_by: str,
+ self,
+ server_name: str,
+ media_id: str,
+ quarantined_by: str,
) -> int:
"""quarantines a single local or remote media id
@@ -961,7 +965,8 @@ class RoomBackgroundUpdateStore(SQLBaseStore):
self.config = hs.config
self.db_pool.updates.register_background_update_handler(
- "insert_room_retention", self._background_insert_retention,
+ "insert_room_retention",
+ self._background_insert_retention,
)
self.db_pool.updates.register_background_update_handler(
@@ -1033,7 +1038,8 @@ class RoomBackgroundUpdateStore(SQLBaseStore):
return False
end = await self.db_pool.runInteraction(
- "insert_room_retention", _background_insert_retention_txn,
+ "insert_room_retention",
+ _background_insert_retention_txn,
)
if end:
@@ -1044,7 +1050,7 @@ class RoomBackgroundUpdateStore(SQLBaseStore):
async def _background_add_rooms_room_version_column(
self, progress: dict, batch_size: int
):
- """Background update to go and add room version inforamtion to `rooms`
+ """Background update to go and add room version information to `rooms`
table from `current_state_events` table.
"""
@@ -1588,7 +1594,8 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
LIMIT ?
OFFSET ?
""".format(
- where_clause=where_clause, order=order,
+ where_clause=where_clause,
+ order=order,
)
args += [limit, start]
diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py
index 92382bed..a9216ca9 100644
--- a/synapse/storage/databases/main/roommember.py
+++ b/synapse/storage/databases/main/roommember.py
@@ -70,10 +70,12 @@ class RoomMemberWorkerStore(EventsWorkerStore):
):
self._known_servers_count = 1
self.hs.get_clock().looping_call(
- self._count_known_servers, 60 * 1000,
+ self._count_known_servers,
+ 60 * 1000,
)
self.hs.get_clock().call_later(
- 1000, self._count_known_servers,
+ 1000,
+ self._count_known_servers,
)
LaterGauge(
"synapse_federation_known_servers",
@@ -174,7 +176,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
@cached(max_entries=100000)
async def get_room_summary(self, room_id: str) -> Dict[str, MemberSummary]:
- """ Get the details of a room roughly suitable for use by the room
+ """Get the details of a room roughly suitable for use by the room
summary extension to /sync. Useful when lazy loading room members.
Args:
room_id: The room ID to query
@@ -488,8 +490,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
async def get_users_who_share_room_with_user(
self, user_id: str, cache_context: _CacheContext
) -> Set[str]:
- """Returns the set of users who share a room with `user_id`
- """
+ """Returns the set of users who share a room with `user_id`"""
room_ids = await self.get_rooms_for_user(
user_id, on_invalidate=cache_context.invalidate
)
@@ -618,7 +619,8 @@ class RoomMemberWorkerStore(EventsWorkerStore):
raise NotImplementedError()
@cachedList(
- cached_method_name="_get_joined_profile_from_event_id", list_name="event_ids",
+ cached_method_name="_get_joined_profile_from_event_id",
+ list_name="event_ids",
)
async def _get_joined_profiles_from_event_ids(self, event_ids: Iterable[str]):
"""For given set of member event_ids check if they point to a join
@@ -802,8 +804,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
async def get_membership_from_event_ids(
self, member_event_ids: Iterable[str]
) -> List[dict]:
- """Get user_id and membership of a set of event IDs.
- """
+ """Get user_id and membership of a set of event IDs."""
return await self.db_pool.simple_select_many_batch(
table="room_memberships",
diff --git a/synapse/storage/databases/main/schema/delta/33/remote_media_ts.py b/synapse/storage/databases/main/schema/delta/33/remote_media_ts.py
index ad875c73..3907189e 100644
--- a/synapse/storage/databases/main/schema/delta/33/remote_media_ts.py
+++ b/synapse/storage/databases/main/schema/delta/33/remote_media_ts.py
@@ -23,5 +23,6 @@ def run_create(cur, database_engine, *args, **kwargs):
def run_upgrade(cur, database_engine, *args, **kwargs):
cur.execute(
- "UPDATE remote_media_cache SET last_access_ts = ?", (int(time.time() * 1000),),
+ "UPDATE remote_media_cache SET last_access_ts = ?",
+ (int(time.time() * 1000),),
)
diff --git a/synapse/storage/databases/main/schema/full_schemas/54/full.sql.sqlite b/synapse/storage/databases/main/schema/full_schemas/54/full.sql.sqlite
index a0411ede..308124e5 100644
--- a/synapse/storage/databases/main/schema/full_schemas/54/full.sql.sqlite
+++ b/synapse/storage/databases/main/schema/full_schemas/54/full.sql.sqlite
@@ -67,11 +67,6 @@ CREATE TABLE IF NOT EXISTS "user_threepids" ( user_id TEXT NOT NULL, medium TEXT
CREATE INDEX user_threepids_user_id ON user_threepids(user_id);
CREATE VIRTUAL TABLE event_search USING fts4 ( event_id, room_id, sender, key, value )
/* event_search(event_id,room_id,sender,"key",value) */;
-CREATE TABLE IF NOT EXISTS 'event_search_content'(docid INTEGER PRIMARY KEY, 'c0event_id', 'c1room_id', 'c2sender', 'c3key', 'c4value');
-CREATE TABLE IF NOT EXISTS 'event_search_segments'(blockid INTEGER PRIMARY KEY, block BLOB);
-CREATE TABLE IF NOT EXISTS 'event_search_segdir'(level INTEGER,idx INTEGER,start_block INTEGER,leaves_end_block INTEGER,end_block INTEGER,root BLOB,PRIMARY KEY(level, idx));
-CREATE TABLE IF NOT EXISTS 'event_search_docsize'(docid INTEGER PRIMARY KEY, size BLOB);
-CREATE TABLE IF NOT EXISTS 'event_search_stat'(id INTEGER PRIMARY KEY, value BLOB);
CREATE TABLE guest_access( event_id TEXT NOT NULL, room_id TEXT NOT NULL, guest_access TEXT NOT NULL, UNIQUE (event_id) );
CREATE TABLE history_visibility( event_id TEXT NOT NULL, room_id TEXT NOT NULL, history_visibility TEXT NOT NULL, UNIQUE (event_id) );
CREATE TABLE room_tags( user_id TEXT NOT NULL, room_id TEXT NOT NULL, tag TEXT NOT NULL, content TEXT NOT NULL, CONSTRAINT room_tag_uniqueness UNIQUE (user_id, room_id, tag) );
@@ -149,11 +144,6 @@ CREATE INDEX device_lists_outbound_last_success_idx ON device_lists_outbound_las
CREATE TABLE user_directory_stream_pos ( Lock CHAR(1) NOT NULL DEFAULT 'X' UNIQUE, stream_id BIGINT, CHECK (Lock='X') );
CREATE VIRTUAL TABLE user_directory_search USING fts4 ( user_id, value )
/* user_directory_search(user_id,value) */;
-CREATE TABLE IF NOT EXISTS 'user_directory_search_content'(docid INTEGER PRIMARY KEY, 'c0user_id', 'c1value');
-CREATE TABLE IF NOT EXISTS 'user_directory_search_segments'(blockid INTEGER PRIMARY KEY, block BLOB);
-CREATE TABLE IF NOT EXISTS 'user_directory_search_segdir'(level INTEGER,idx INTEGER,start_block INTEGER,leaves_end_block INTEGER,end_block INTEGER,root BLOB,PRIMARY KEY(level, idx));
-CREATE TABLE IF NOT EXISTS 'user_directory_search_docsize'(docid INTEGER PRIMARY KEY, size BLOB);
-CREATE TABLE IF NOT EXISTS 'user_directory_search_stat'(id INTEGER PRIMARY KEY, value BLOB);
CREATE TABLE blocked_rooms ( room_id TEXT NOT NULL, user_id TEXT NOT NULL );
CREATE UNIQUE INDEX blocked_rooms_idx ON blocked_rooms(room_id);
CREATE TABLE IF NOT EXISTS "local_media_repository_url_cache"( url TEXT, response_code INTEGER, etag TEXT, expires_ts BIGINT, og TEXT, media_id TEXT, download_ts BIGINT );
diff --git a/synapse/storage/databases/main/state.py b/synapse/storage/databases/main/state.py
index 3c1e3381..a7f37173 100644
--- a/synapse/storage/databases/main/state.py
+++ b/synapse/storage/databases/main/state.py
@@ -52,8 +52,7 @@ class _GetStateGroupDelta(
# this inherits from EventsWorkerStore because it calls self.get_events
class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
- """The parts of StateGroupStore that can be called from workers.
- """
+ """The parts of StateGroupStore that can be called from workers."""
def __init__(self, database: DatabasePool, db_conn, hs):
super().__init__(database, db_conn, hs)
@@ -276,8 +275,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
num_args=1,
)
async def _get_state_group_for_events(self, event_ids):
- """Returns mapping event_id -> state_group
- """
+ """Returns mapping event_id -> state_group"""
rows = await self.db_pool.simple_select_many_batch(
table="event_to_state_groups",
column="event_id",
@@ -338,7 +336,8 @@ class MainStateBackgroundUpdateStore(RoomMemberWorkerStore):
columns=["state_group"],
)
self.db_pool.updates.register_background_update_handler(
- self.DELETE_CURRENT_STATE_UPDATE_NAME, self._background_remove_left_rooms,
+ self.DELETE_CURRENT_STATE_UPDATE_NAME,
+ self._background_remove_left_rooms,
)
async def _background_remove_left_rooms(self, progress, batch_size):
@@ -487,7 +486,7 @@ class MainStateBackgroundUpdateStore(RoomMemberWorkerStore):
class StateStore(StateGroupWorkerStore, MainStateBackgroundUpdateStore):
- """ Keeps track of the state at a given event.
+ """Keeps track of the state at a given event.
This is done by the concept of `state groups`. Every event is a assigned
a state group (identified by an arbitrary string), which references a
diff --git a/synapse/storage/databases/main/state_deltas.py b/synapse/storage/databases/main/state_deltas.py
index 356623fc..0dbb501f 100644
--- a/synapse/storage/databases/main/state_deltas.py
+++ b/synapse/storage/databases/main/state_deltas.py
@@ -64,7 +64,7 @@ class StateDeltasStore(SQLBaseStore):
def get_current_state_deltas_txn(txn):
# First we calculate the max stream id that will give us less than
# N results.
- # We arbitarily limit to 100 stream_id entries to ensure we don't
+ # We arbitrarily limit to 100 stream_id entries to ensure we don't
# select toooo many.
sql = """
SELECT stream_id, count(*)
@@ -81,7 +81,7 @@ class StateDeltasStore(SQLBaseStore):
for stream_id, count in txn:
total += count
if total > 100:
- # We arbitarily limit to 100 entries to ensure we don't
+ # We arbitrarily limit to 100 entries to ensure we don't
# select toooo many.
logger.debug(
"Clipping current_state_delta_stream rows to stream_id %i",
diff --git a/synapse/storage/databases/main/stats.py b/synapse/storage/databases/main/stats.py
index d421d18f..1c99393c 100644
--- a/synapse/storage/databases/main/stats.py
+++ b/synapse/storage/databases/main/stats.py
@@ -1001,7 +1001,9 @@ class StatsStore(StateDeltasStore):
ORDER BY {order_by_column} {order}
LIMIT ? OFFSET ?
""".format(
- sql_base=sql_base, order_by_column=order_by_column, order=order,
+ sql_base=sql_base,
+ order_by_column=order_by_column,
+ order=order,
)
args += [limit, start]
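The hunk above only reflows a format() call, but the pattern it touches is worth spelling out: column names and sort direction cannot be bound as SQL parameters, so they are interpolated into the query text, while the LIMIT/OFFSET values still go through placeholders. A hedged sketch of that split; the table name and the allow-list are illustrative assumptions, not Synapse's actual validation logic:

    ALLOWED_ORDER_COLUMNS = {"joined_members", "total_events"}  # illustrative allow-list

    def build_room_stats_query(order_by_column: str, order: str) -> str:
        # Identifiers and keywords can't be bound with "?", so validate them
        # before formatting them into the SQL text.
        if order_by_column not in ALLOWED_ORDER_COLUMNS or order not in ("ASC", "DESC"):
            raise ValueError("unsafe ORDER BY input")
        return """
            SELECT room_id FROM room_stats_current
            ORDER BY {order_by_column} {order}
            LIMIT ? OFFSET ?
        """.format(order_by_column=order_by_column, order=order)

    # The limit and offset remain real bind parameters:
    # txn.execute(build_room_stats_query("joined_members", "DESC"), (limit, start))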
diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py
index e3b9ff5c..91f8abb6 100644
--- a/synapse/storage/databases/main/stream.py
+++ b/synapse/storage/databases/main/stream.py
@@ -565,7 +565,14 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
AND e.stream_ordering > ? AND e.stream_ordering <= ?
ORDER BY e.stream_ordering ASC
"""
- txn.execute(sql, (user_id, min_from_id, max_to_id,))
+ txn.execute(
+ sql,
+ (
+ user_id,
+ min_from_id,
+ max_to_id,
+ ),
+ )
rows = [
_EventDictReturn(event_id, None, stream_ordering)
@@ -695,7 +702,10 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
return "t%d-%d" % (topo, token)
def get_stream_id_for_event_txn(
- self, txn: LoggingTransaction, event_id: str, allow_none=False,
+ self,
+ txn: LoggingTransaction,
+ event_id: str,
+ allow_none=False,
) -> int:
return self.db_pool.simple_select_one_onecol_txn(
txn=txn,
@@ -706,8 +716,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
)
async def get_position_for_event(self, event_id: str) -> PersistedEventPosition:
- """Get the persisted position for an event
- """
+ """Get the persisted position for an event"""
row = await self.db_pool.simple_select_one(
table="events",
keyvalues={"event_id": event_id},
@@ -897,19 +906,19 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
) -> Tuple[int, List[EventBase]]:
"""Get all new events
- Returns all events with from_id < stream_ordering <= current_id.
+ Returns all events with from_id < stream_ordering <= current_id.
- Args:
- from_id: the stream_ordering of the last event we processed
- current_id: the stream_ordering of the most recently processed event
- limit: the maximum number of events to return
+ Args:
+ from_id: the stream_ordering of the last event we processed
+ current_id: the stream_ordering of the most recently processed event
+ limit: the maximum number of events to return
- Returns:
- A tuple of (next_id, events), where `next_id` is the next value to
- pass as `from_id` (it will either be the stream_ordering of the
- last returned event, or, if fewer than `limit` events were found,
- the `current_id`).
- """
+ Returns:
+ A tuple of (next_id, events), where `next_id` is the next value to
+ pass as `from_id` (it will either be the stream_ordering of the
+ last returned event, or, if fewer than `limit` events were found,
+ the `current_id`).
+ """
def get_all_new_events_stream_txn(txn):
sql = (
@@ -1238,8 +1247,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
@cached()
async def get_id_for_instance(self, instance_name: str) -> int:
- """Get a unique, immutable ID that corresponds to the given Synapse worker instance.
- """
+ """Get a unique, immutable ID that corresponds to the given Synapse worker instance."""
def _get_id_for_instance_txn(txn):
instance_id = self.db_pool.simple_select_one_onecol_txn(
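The re-indented docstring earlier in this file describes a simple pagination contract: get_all_new_events_stream returns (next_id, events), where next_id is either the stream_ordering of the last event returned or, if fewer than `limit` events were found, the supplied current_id. A hypothetical caller illustrating that contract; store, handle and the limit value are stand-ins, not Synapse code:

    async def drain_new_events(store, from_id: int, current_id: int, handle) -> int:
        # Repeatedly page through events with from_id < stream_ordering <= current_id.
        while from_id < current_id:
            next_id, events = await store.get_all_new_events_stream(
                from_id, current_id, 100
            )
            for event in events:
                await handle(event)
            if next_id == from_id:
                break  # defensive: no progress means nothing left to fetch
            from_id = next_id  # resume from the last event we saw
        return from_id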
diff --git a/synapse/storage/databases/main/transactions.py b/synapse/storage/databases/main/transactions.py
index cea595ff..b921d63d 100644
--- a/synapse/storage/databases/main/transactions.py
+++ b/synapse/storage/databases/main/transactions.py
@@ -64,8 +64,7 @@ class TransactionWorkerStore(SQLBaseStore):
class TransactionStore(TransactionWorkerStore):
- """A collection of queries for handling PDUs.
- """
+ """A collection of queries for handling PDUs."""
def __init__(self, database: DatabasePool, db_conn, hs):
super().__init__(database, db_conn, hs)
@@ -198,7 +197,7 @@ class TransactionStore(TransactionWorkerStore):
retry_interval: int,
) -> None:
"""Sets the current retry timings for a given destination.
- Both timings should be zero if retrying is no longer occuring.
+ Both timings should be zero if retrying is no longer occurring.
Args:
destination
@@ -299,7 +298,10 @@ class TransactionStore(TransactionWorkerStore):
)
async def store_destination_rooms_entries(
- self, destinations: Iterable[str], room_id: str, stream_ordering: int,
+ self,
+ destinations: Iterable[str],
+ room_id: str,
+ stream_ordering: int,
) -> None:
"""
Updates or creates `destination_rooms` entries in batch for a single event.
@@ -394,7 +396,9 @@ class TransactionStore(TransactionWorkerStore):
)
async def get_catch_up_room_event_ids(
- self, destination: str, last_successful_stream_ordering: int,
+ self,
+ destination: str,
+ last_successful_stream_ordering: int,
) -> List[str]:
"""
Returns at most 50 event IDs and their corresponding stream_orderings
@@ -418,7 +422,9 @@ class TransactionStore(TransactionWorkerStore):
@staticmethod
def _get_catch_up_room_event_ids_txn(
- txn: LoggingTransaction, destination: str, last_successful_stream_ordering: int,
+ txn: LoggingTransaction,
+ destination: str,
+ last_successful_stream_ordering: int,
) -> List[str]:
q = """
SELECT event_id FROM destination_rooms
@@ -429,7 +435,8 @@ class TransactionStore(TransactionWorkerStore):
LIMIT 50
"""
txn.execute(
- q, (destination, last_successful_stream_ordering),
+ q,
+ (destination, last_successful_stream_ordering),
)
event_ids = [row[0] for row in txn]
return event_ids
diff --git a/synapse/storage/databases/main/ui_auth.py b/synapse/storage/databases/main/ui_auth.py
index 79b7ece3..5473ec14 100644
--- a/synapse/storage/databases/main/ui_auth.py
+++ b/synapse/storage/databases/main/ui_auth.py
@@ -44,7 +44,11 @@ class UIAuthWorkerStore(SQLBaseStore):
"""
async def create_ui_auth_session(
- self, clientdict: JsonDict, uri: str, method: str, description: str,
+ self,
+ clientdict: JsonDict,
+ uri: str,
+ method: str,
+ description: str,
) -> UIAuthSessionData:
"""
Creates a new user interactive authentication session.
@@ -123,7 +127,10 @@ class UIAuthWorkerStore(SQLBaseStore):
return UIAuthSessionData(session_id, **result)
async def mark_ui_auth_stage_complete(
- self, session_id: str, stage_type: str, result: Union[str, bool, JsonDict],
+ self,
+ session_id: str,
+ stage_type: str,
+ result: Union[str, bool, JsonDict],
):
"""
Mark a session stage as completed.
@@ -261,10 +268,12 @@ class UIAuthWorkerStore(SQLBaseStore):
return serverdict.get(key, default)
async def add_user_agent_ip_to_ui_auth_session(
- self, session_id: str, user_agent: str, ip: str,
+ self,
+ session_id: str,
+ user_agent: str,
+ ip: str,
):
- """Add the given user agent / IP to the tracking table
- """
+ """Add the given user agent / IP to the tracking table"""
await self.db_pool.simple_upsert(
table="ui_auth_sessions_ips",
keyvalues={"session_id": session_id, "user_agent": user_agent, "ip": ip},
@@ -273,7 +282,8 @@ class UIAuthWorkerStore(SQLBaseStore):
)
async def get_user_agents_ips_to_ui_auth_session(
- self, session_id: str,
+ self,
+ session_id: str,
) -> List[Tuple[str, str]]:
"""Get the given user agents / IPs used during the ui auth process
diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py
index 7b9729da..63f88eac 100644
--- a/synapse/storage/databases/main/user_directory.py
+++ b/synapse/storage/databases/main/user_directory.py
@@ -336,8 +336,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
return len(users_to_work_on)
async def is_room_world_readable_or_publicly_joinable(self, room_id):
- """Check if the room is either world_readable or publically joinable
- """
+ """Check if the room is either world_readable or publically joinable"""
# Create a state filter that only queries join and history state event
types_to_filter = (
@@ -516,8 +515,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
)
async def delete_all_from_user_dir(self) -> None:
- """Delete the entire user directory
- """
+ """Delete the entire user directory"""
def _delete_all_from_user_dir_txn(txn):
txn.execute("DELETE FROM user_directory")
@@ -709,7 +707,13 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
return {row["room_id"] for row in rows}
- async def get_user_directory_stream_pos(self) -> int:
+ async def get_user_directory_stream_pos(self) -> Optional[int]:
+ """
+ Get the stream ID of the user directory stream.
+
+ Returns:
+ The stream token or None if the initial background update hasn't happened yet.
+ """
return await self.db_pool.simple_select_one_onecol(
table="user_directory_stream_pos",
keyvalues={},
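With get_user_directory_stream_pos now typed as Optional[int], callers have to treat "no row yet" explicitly rather than assuming an integer. A hedged sketch of how a caller might handle that; the function and its collaborators are stand-ins, not the actual user directory handler:

    async def maybe_process_directory_deltas(store, process_deltas) -> None:
        pos = await store.get_user_directory_stream_pos()
        if pos is None:
            # The initial background update hasn't populated the stream
            # position yet, so there is nothing to catch up on.
            return
        await process_deltas(from_pos=pos)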
diff --git a/synapse/storage/databases/state/bg_updates.py b/synapse/storage/databases/state/bg_updates.py
index acb24e33..1fd333b7 100644
--- a/synapse/storage/databases/state/bg_updates.py
+++ b/synapse/storage/databases/state/bg_updates.py
@@ -27,7 +27,7 @@ MAX_STATE_DELTA_HOPS = 100
class StateGroupBackgroundUpdateStore(SQLBaseStore):
- """Defines functions related to state groups needed to run the state backgroud
+ """Defines functions related to state groups needed to run the state background
updates.
"""
diff --git a/synapse/storage/databases/state/store.py b/synapse/storage/databases/state/store.py
index 89cdc84a..b16b9905 100644
--- a/synapse/storage/databases/state/store.py
+++ b/synapse/storage/databases/state/store.py
@@ -48,8 +48,7 @@ class _GetStateGroupDelta(
class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
- """A data store for fetching/storing state groups.
- """
+ """A data store for fetching/storing state groups."""
def __init__(self, database: DatabasePool, db_conn, hs):
super().__init__(database, db_conn, hs)
@@ -89,7 +88,8 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
50000,
)
self._state_group_members_cache = DictionaryCache(
- "*stateGroupMembersCache*", 500000,
+ "*stateGroupMembersCache*",
+ 500000,
)
def get_max_state_group_txn(txn: Cursor):
diff --git a/synapse/storage/engines/__init__.py b/synapse/storage/engines/__init__.py
index 035f9ea6..d15ccfac 100644
--- a/synapse/storage/engines/__init__.py
+++ b/synapse/storage/engines/__init__.py
@@ -12,7 +12,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-import platform
from ._base import BaseDatabaseEngine, IncorrectDatabaseSetup
from .postgres import PostgresEngine
@@ -28,11 +27,8 @@ def create_engine(database_config) -> BaseDatabaseEngine:
return Sqlite3Engine(sqlite3, database_config)
if name == "psycopg2":
- # pypy requires psycopg2cffi rather than psycopg2
- if platform.python_implementation() == "PyPy":
- import psycopg2cffi as psycopg2 # type: ignore
- else:
- import psycopg2 # type: ignore
+ # Note that psycopg2cffi-compat provides the psycopg2 module on pypy.
+ import psycopg2 # type: ignore
return PostgresEngine(psycopg2, database_config)
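The removed branch special-cased PyPy by importing psycopg2cffi directly; the replacement relies on a compat shim making a module named psycopg2 importable on PyPy. For context, psycopg2cffi itself ships an equivalent hook, sketched below; this is not Synapse code, just the mechanism the comment refers to:

    # On PyPy, register psycopg2cffi under the name "psycopg2" so that a
    # plain import works the same way it does on CPython.
    from psycopg2cffi import compat

    compat.register()

    import psycopg2  # now resolves on PyPy and CPython alike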
diff --git a/synapse/storage/engines/_base.py b/synapse/storage/engines/_base.py
index d6d632dc..cca839c7 100644
--- a/synapse/storage/engines/_base.py
+++ b/synapse/storage/engines/_base.py
@@ -94,14 +94,12 @@ class BaseDatabaseEngine(Generic[ConnectionType], metaclass=abc.ABCMeta):
@property
@abc.abstractmethod
def server_version(self) -> str:
- """Gets a string giving the server version. For example: '3.22.0'
- """
+ """Gets a string giving the server version. For example: '3.22.0'"""
...
@abc.abstractmethod
def in_transaction(self, conn: Connection) -> bool:
- """Whether the connection is currently in a transaction.
- """
+ """Whether the connection is currently in a transaction."""
...
@abc.abstractmethod
diff --git a/synapse/storage/engines/postgres.py b/synapse/storage/engines/postgres.py
index 7719ac32..80a3558a 100644
--- a/synapse/storage/engines/postgres.py
+++ b/synapse/storage/engines/postgres.py
@@ -138,8 +138,7 @@ class PostgresEngine(BaseDatabaseEngine):
@property
def supports_using_any_list(self):
- """Do we support using `a = ANY(?)` and passing a list
- """
+ """Do we support using `a = ANY(?)` and passing a list"""
return True
def is_deadlock(self, error):
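supports_using_any_list gates whether a query can bind a whole Python list to a single `= ANY(?)` placeholder (Postgres) or must expand to an IN clause with one placeholder per value (SQLite). A hedged illustration of the two shapes such a helper might produce; this is not Synapse's actual helper, and the '?' placeholder style is shown only for brevity:

    from typing import List, Tuple

    def make_in_list_clause(
        supports_any: bool, column: str, values: List[int]
    ) -> Tuple[str, list]:
        if supports_any:
            # Postgres: one placeholder bound to the whole list.
            return f"{column} = ANY(?)", [values]
        # SQLite: expand to IN (?, ?, ...) with one placeholder per value.
        placeholders = ", ".join("?" for _ in values)
        return f"{column} IN ({placeholders})", list(values)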
diff --git a/synapse/storage/engines/sqlite.py b/synapse/storage/engines/sqlite.py
index 5db0f0b5..b87e7798 100644
--- a/synapse/storage/engines/sqlite.py
+++ b/synapse/storage/engines/sqlite.py
@@ -12,6 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import platform
import struct
import threading
import typing
@@ -28,7 +29,15 @@ class Sqlite3Engine(BaseDatabaseEngine["sqlite3.Connection"]):
super().__init__(database_module, database_config)
database = database_config.get("args", {}).get("database")
- self._is_in_memory = database in (None, ":memory:",)
+ self._is_in_memory = database in (
+ None,
+ ":memory:",
+ )
+
+ if platform.python_implementation() == "PyPy":
+ # pypy's sqlite3 module doesn't handle bytearrays, convert them
+ # back to bytes.
+ database_module.register_adapter(bytearray, lambda array: bytes(array))
# The current max state_group, or None if we haven't looked
# in the DB yet.
@@ -57,8 +66,7 @@ class Sqlite3Engine(BaseDatabaseEngine["sqlite3.Connection"]):
@property
def supports_using_any_list(self):
- """Do we support using `a = ANY(?)` and passing a list
- """
+ """Do we support using `a = ANY(?)` and passing a list"""
return False
def check_database(self, db_conn, allow_outdated_version: bool = False):
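The register_adapter call added above works around PyPy's sqlite3 module not accepting bytearray parameters: an adapter converts them to bytes before they reach the driver. A standalone sketch of the same idea, runnable on its own:

    import sqlite3

    # On PyPy, passing a bytearray parameter to sqlite3 can fail, so register
    # an adapter that converts bytearrays to bytes first (harmless on CPython,
    # where bytearray is already accepted natively).
    sqlite3.register_adapter(bytearray, bytes)

    conn = sqlite3.connect(":memory:")
    conn.execute("CREATE TABLE blobs (data BLOB)")
    conn.execute("INSERT INTO blobs VALUES (?)", (bytearray(b"\x00\x01\x02"),))
    print(conn.execute("SELECT data FROM blobs").fetchone())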
diff --git a/synapse/storage/persist_events.py b/synapse/storage/persist_events.py
index 61fc49c6..3a0d6fb3 100644
--- a/synapse/storage/persist_events.py
+++ b/synapse/storage/persist_events.py
@@ -411,8 +411,8 @@ class EventsPersistenceStorage:
)
for room_id, ev_ctx_rm in events_by_room.items():
- latest_event_ids = await self.main_store.get_latest_event_ids_in_room(
- room_id
+ latest_event_ids = (
+ await self.main_store.get_latest_event_ids_in_room(room_id)
)
new_latest_event_ids = await self._calculate_new_extremities(
room_id, ev_ctx_rm, latest_event_ids
@@ -889,7 +889,8 @@ class EventsPersistenceStorage:
continue
logger.debug(
- "Not dropping as too new and not in new_senders: %s", new_senders,
+ "Not dropping as too new and not in new_senders: %s",
+ new_senders,
)
return new_latest_event_ids
@@ -1004,7 +1005,10 @@ class EventsPersistenceStorage:
remote_event_ids = [
event_id
- for (typ, state_key,), event_id in current_state.items()
+ for (
+ typ,
+ state_key,
+ ), event_id in current_state.items()
if typ == EventTypes.Member and not self.is_mine_id(state_key)
]
rows = await self.main_store.get_membership_from_event_ids(remote_event_ids)
diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py
index 566ea19b..6c3c2da5 100644
--- a/synapse/storage/prepare_database.py
+++ b/synapse/storage/prepare_database.py
@@ -113,7 +113,7 @@ def prepare_database(
# which should be empty.
if config is None:
raise ValueError(
- "config==None in prepare_database, but databse is not empty"
+ "config==None in prepare_database, but database is not empty"
)
# if it's a worker app, refuse to upgrade the database, to avoid multiple
@@ -425,7 +425,10 @@ def _upgrade_existing_database(
# We don't support using the same file name in the same delta version.
raise PrepareDatabaseException(
"Found multiple delta files with the same name in v%d: %s"
- % (v, duplicates,)
+ % (
+ v,
+ duplicates,
+ )
)
# We sort to ensure that we apply the delta files in a consistent
@@ -532,7 +535,8 @@ def _apply_module_schema_files(
names_and_streams: the names and streams of schemas to be applied
"""
cur.execute(
- "SELECT file FROM applied_module_schemas WHERE module_name = ?", (modname,),
+ "SELECT file FROM applied_module_schemas WHERE module_name = ?",
+ (modname,),
)
applied_deltas = {d for d, in cur}
for (name, stream) in names_and_streams:
@@ -619,9 +623,9 @@ def _get_or_create_schema_state(
txn.execute("SELECT version, upgraded FROM schema_version")
row = txn.fetchone()
- current_version = int(row[0]) if row else None
- if current_version:
+ if row is not None:
+ current_version = int(row[0])
txn.execute(
"SELECT file FROM applied_schema_deltas WHERE version >= ?",
(current_version,),
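The change just above replaces a truthiness check on the fetched version with an explicit `row is not None`, which keeps a legitimate falsy value (such as version 0) from being treated as "no row at all". A small, hypothetical illustration of the distinction, assuming a DB-API cursor:

    def read_schema_version(txn):
        txn.execute("SELECT version, upgraded FROM schema_version")
        row = txn.fetchone()
        if row is not None:          # an empty table makes fetchone() return None
            return int(row[0])       # version 0 is still a real answer here
        return None                  # genuinely no row yet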
diff --git a/synapse/storage/purge_events.py b/synapse/storage/purge_events.py
index 6c359c1a..3c490886 100644
--- a/synapse/storage/purge_events.py
+++ b/synapse/storage/purge_events.py
@@ -26,15 +26,13 @@ logger = logging.getLogger(__name__)
class PurgeEventsStorage:
- """High level interface for purging rooms and event history.
- """
+ """High level interface for purging rooms and event history."""
def __init__(self, hs: "HomeServer", stores: Databases):
self.stores = stores
async def purge_room(self, room_id: str) -> None:
- """Deletes all record of a room
- """
+ """Deletes all record of a room"""
state_groups_to_delete = await self.stores.main.purge_room(room_id)
await self.stores.state.purge_room_state(room_id, state_groups_to_delete)
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index 31ccbf23..d179a418 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -340,8 +340,7 @@ class StateFilter:
class StateGroupStorage:
- """High level interface to fetching state for event.
- """
+ """High level interface to fetching state for event."""
def __init__(self, hs: "HomeServer", stores: "Databases"):
self.stores = stores
@@ -400,7 +399,7 @@ class StateGroupStorage:
async def get_state_groups(
self, room_id: str, event_ids: Iterable[str]
) -> Dict[int, List[EventBase]]:
- """ Get the state groups for the given list of event_ids
+ """Get the state groups for the given list of event_ids
Args:
room_id: ID of the room for these events.
diff --git a/synapse/storage/types.py b/synapse/storage/types.py
index 9cadcba1..17291c9d 100644
--- a/synapse/storage/types.py
+++ b/synapse/storage/types.py
@@ -12,7 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Any, Iterable, Iterator, List, Optional, Tuple
+from typing import Any, Iterator, List, Mapping, Optional, Sequence, Tuple, Union
from typing_extensions import Protocol
@@ -20,23 +20,44 @@ from typing_extensions import Protocol
Some very basic protocol definitions for the DB-API2 classes specified in PEP-249
"""
+_Parameters = Union[Sequence[Any], Mapping[str, Any]]
+
class Cursor(Protocol):
- def execute(self, sql: str, parameters: Iterable[Any] = ...) -> Any:
+ def execute(self, sql: str, parameters: _Parameters = ...) -> Any:
...
- def executemany(self, sql: str, parameters: Iterable[Iterable[Any]]) -> Any:
+ def executemany(self, sql: str, parameters: Sequence[_Parameters]) -> Any:
...
- def fetchall(self) -> List[Tuple]:
+ def fetchone(self) -> Optional[Tuple]:
+ ...
+
+ def fetchmany(self, size: Optional[int] = ...) -> List[Tuple]:
...
- def fetchone(self) -> Tuple:
+ def fetchall(self) -> List[Tuple]:
...
@property
- def description(self) -> Any:
- return None
+ def description(
+ self,
+ ) -> Optional[
+ Sequence[
+ # Note that this is an approximate typing based on sqlite3 and other
+ # drivers, and may not be entirely accurate.
+ Tuple[
+ str,
+ Optional[Any],
+ Optional[int],
+ Optional[int],
+ Optional[int],
+ Optional[int],
+ Optional[int],
+ ]
+ ]
+ ]:
+ ...
@property
def rowcount(self) -> int:
@@ -59,7 +80,7 @@ class Connection(Protocol):
def commit(self) -> None:
...
- def rollback(self, *args, **kwargs) -> None:
+ def rollback(self) -> None:
...
def __enter__(self) -> "Connection":
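The Cursor and Connection classes in this file are structural Protocols, so sqlite3 and psycopg2 objects satisfy them without inheriting from anything, and the new _Parameters alias matches DB-API2's allowance of either a positional sequence or a named mapping. A minimal, illustrative sketch of how such a protocol is used structurally; MiniCursor and count_rows are assumptions for the example, not Synapse names:

    import sqlite3
    from typing import Any, Mapping, Optional, Sequence, Tuple, Union
    from typing_extensions import Protocol

    _Parameters = Union[Sequence[Any], Mapping[str, Any]]

    class MiniCursor(Protocol):
        def execute(self, sql: str, parameters: _Parameters = ...) -> Any: ...
        def fetchone(self) -> Optional[Tuple]: ...

    def count_rows(cur: MiniCursor, table: str) -> int:
        # Any object with matching execute/fetchone methods type-checks here,
        # whether it comes from sqlite3, psycopg2 or a test double.
        cur.execute(f"SELECT count(*) FROM {table}")
        row = cur.fetchone()
        assert row is not None
        return row[0]

    conn = sqlite3.connect(":memory:")
    conn.execute("CREATE TABLE t (x INTEGER)")
    print(count_rows(conn.cursor(), "t"))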
diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py
index 71ef5a72..d4643c4f 100644
--- a/synapse/storage/util/id_generators.py
+++ b/synapse/storage/util/id_generators.py
@@ -245,7 +245,7 @@ class MultiWriterIdGenerator:
# and b) noting that if we have seen a run of persisted positions
# without gaps (e.g. 5, 6, 7) then we can skip forward (e.g. to 7).
#
- # Note: There is no guarentee that the IDs generated by the sequence
+ # Note: There is no guarantee that the IDs generated by the sequence
# will be gapless; gaps can form when e.g. a transaction was rolled
# back. This means that sometimes we won't be able to skip forward the
# position even though everything has been persisted. However, since
@@ -277,7 +277,9 @@ class MultiWriterIdGenerator:
self._load_current_ids(db_conn, tables)
def _load_current_ids(
- self, db_conn, tables: List[Tuple[str, str, str]],
+ self,
+ db_conn,
+ tables: List[Tuple[str, str, str]],
):
cur = db_conn.cursor(txn_name="_load_current_ids")
@@ -364,7 +366,10 @@ class MultiWriterIdGenerator:
rows.sort()
with self._lock:
- for (instance, stream_id,) in rows:
+ for (
+ instance,
+ stream_id,
+ ) in rows:
stream_id = self._return_factor * stream_id
self._add_persisted_position(stream_id)
@@ -418,7 +423,7 @@ class MultiWriterIdGenerator:
# bother, as nothing will read it).
#
# We only do this on the success path so that the persisted current
- # position points to a persited row with the correct instance name.
+ # position points to a persisted row with the correct instance name.
if self._writers:
txn.call_after(
run_as_background_process,
@@ -481,8 +486,7 @@ class MultiWriterIdGenerator:
return self.get_persisted_upto_position()
def get_current_token_for_writer(self, instance_name: str) -> int:
- """Returns the position of the given writer.
- """
+ """Returns the position of the given writer."""
# If we don't have an entry for the given instance name, we assume it's a
# new writer.
@@ -509,7 +513,7 @@ class MultiWriterIdGenerator:
}
def advance(self, instance_name: str, new_id: int):
- """Advance the postion of the named writer to the given ID, if greater
+ """Advance the position of the named writer to the given ID, if greater
than existing entry.
"""
@@ -581,8 +585,7 @@ class MultiWriterIdGenerator:
break
def _update_stream_positions_table_txn(self, txn: Cursor):
- """Update the `stream_positions` table with newly persisted position.
- """
+ """Update the `stream_positions` table with newly persisted position."""
if not self._writers:
return
@@ -622,8 +625,7 @@ class _AsyncCtxManagerWrapper:
@attr.s(slots=True)
class _MultiWriterCtxManager:
- """Async context manager returned by MultiWriterIdGenerator
- """
+ """Async context manager returned by MultiWriterIdGenerator"""
id_gen = attr.ib(type=MultiWriterIdGenerator)
multiple_ids = attr.ib(type=Optional[int], default=None)
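The corrected comment earlier in this file ("no guarantee that the IDs generated by the sequence will be gapless") is the reason the generator only advances its persisted-up-to position across contiguous runs of completed IDs. A toy sketch of that skip-forward rule, a simplification rather than the real class:

    def advance_persisted_position(current: int, persisted_ids: set) -> int:
        # E.g. current=4, persisted={5, 6, 7, 9} -> returns 7; 8 is an
        # unfilled gap (perhaps a rolled-back transaction), so we stop there
        # even though 9 has already been persisted.
        while current + 1 in persisted_ids:
            current += 1
        return current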
diff --git a/synapse/storage/util/sequence.py b/synapse/storage/util/sequence.py
index 0ec4dc29..3ea637b2 100644
--- a/synapse/storage/util/sequence.py
+++ b/synapse/storage/util/sequence.py
@@ -106,7 +106,9 @@ class PostgresSequenceGenerator(SequenceGenerator):
def get_next_id_txn(self, txn: Cursor) -> int:
txn.execute("SELECT nextval(?)", (self._sequence_name,))
- return txn.fetchone()[0]
+ fetch_res = txn.fetchone()
+ assert fetch_res is not None
+ return fetch_res[0]
def get_next_mult_txn(self, txn: Cursor, n: int) -> List[int]:
txn.execute(
@@ -122,8 +124,7 @@ class PostgresSequenceGenerator(SequenceGenerator):
stream_name: Optional[str] = None,
positive: bool = True,
):
- """See SequenceGenerator.check_consistency for docstring.
- """
+ """See SequenceGenerator.check_consistency for docstring."""
txn = db_conn.cursor(txn_name="sequence.check_consistency")
@@ -147,7 +148,9 @@ class PostgresSequenceGenerator(SequenceGenerator):
txn.execute(
"SELECT last_value, is_called FROM %(seq)s" % {"seq": self._sequence_name}
)
- last_value, is_called = txn.fetchone()
+ fetch_res = txn.fetchone()
+ assert fetch_res is not None
+ last_value, is_called = fetch_res
# If we have an associated stream check the stream_positions table.
max_in_stream_positions = None