Diffstat (limited to 'synapse/storage')
36 files changed, 471 insertions, 238 deletions
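The dominant theme of this changeset is replacing scattered direct `json.loads` calls with the `db_to_json` helper, so that `memoryview` and `bytes` values coming back from the database driver are normalised in one place. A minimal sketch of that helper, reconstructed from the `_base.py` hunk below (the real function lives in `synapse/storage/_base.py` and may do more):

```python
import json


def db_to_json(db_content):
    # psycopg2 can hand back a memoryview for bytea columns; flatten it first.
    if isinstance(db_content, memoryview):
        db_content = db_content.tobytes()

    # Decode to a Unicode string before feeding it to json.loads, since
    # Python 3.5 does not support deserializing bytes.
    if isinstance(db_content, (bytes, bytearray)):
        db_content = db_content.decode("utf8")

    return json.loads(db_content)
```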
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index bfce541c..985a0428 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -100,8 +100,8 @@ def db_to_json(db_content): if isinstance(db_content, memoryview): db_content = db_content.tobytes() - # Decode it to a Unicode string before feeding it to json.loads, so we - # consistenty get a Unicode-containing object out. + # Decode it to a Unicode string before feeding it to json.loads, since + # Python 3.5 does not support deserializing bytes. if isinstance(db_content, (bytes, bytearray)): db_content = db_content.decode("utf8") diff --git a/synapse/storage/background_updates.py b/synapse/storage/background_updates.py index 59f3394b..018826ef 100644 --- a/synapse/storage/background_updates.py +++ b/synapse/storage/background_updates.py @@ -249,7 +249,10 @@ class BackgroundUpdater(object): retcol="progress_json", ) - progress = json.loads(progress_json) + # Avoid a circular import. + from synapse.storage._base import db_to_json + + progress = db_to_json(progress_json) time_start = self._clock.time_msec() items_updated = await update_handler(progress, batch_size) diff --git a/synapse/storage/data_stores/main/__init__.py b/synapse/storage/data_stores/main/__init__.py index 4b4763c7..932458f6 100644 --- a/synapse/storage/data_stores/main/__init__.py +++ b/synapse/storage/data_stores/main/__init__.py @@ -128,7 +128,7 @@ class DataStore( db_conn, "presence_stream", "stream_id" ) self._device_inbox_id_gen = StreamIdGenerator( - db_conn, "device_max_stream_id", "stream_id" + db_conn, "device_inbox", "stream_id" ) self._public_room_id_gen = StreamIdGenerator( db_conn, "public_room_list_stream", "stream_id" diff --git a/synapse/storage/data_stores/main/account_data.py b/synapse/storage/data_stores/main/account_data.py index b58f04d0..33cc372d 100644 --- a/synapse/storage/data_stores/main/account_data.py +++ b/synapse/storage/data_stores/main/account_data.py @@ -22,7 +22,7 @@ from canonicaljson import json from twisted.internet import defer -from synapse.storage._base import SQLBaseStore +from synapse.storage._base import SQLBaseStore, db_to_json from synapse.storage.database import Database from synapse.storage.util.id_generators import StreamIdGenerator from synapse.util.caches.descriptors import cached, cachedInlineCallbacks @@ -77,7 +77,7 @@ class AccountDataWorkerStore(SQLBaseStore): ) global_account_data = { - row["account_data_type"]: json.loads(row["content"]) for row in rows + row["account_data_type"]: db_to_json(row["content"]) for row in rows } rows = self.db.simple_select_list_txn( @@ -90,7 +90,7 @@ class AccountDataWorkerStore(SQLBaseStore): by_room = {} for row in rows: room_data = by_room.setdefault(row["room_id"], {}) - room_data[row["account_data_type"]] = json.loads(row["content"]) + room_data[row["account_data_type"]] = db_to_json(row["content"]) return global_account_data, by_room @@ -113,7 +113,7 @@ class AccountDataWorkerStore(SQLBaseStore): ) if result: - return json.loads(result) + return db_to_json(result) else: return None @@ -137,7 +137,7 @@ class AccountDataWorkerStore(SQLBaseStore): ) return { - row["account_data_type"]: json.loads(row["content"]) for row in rows + row["account_data_type"]: db_to_json(row["content"]) for row in rows } return self.db.runInteraction( @@ -170,7 +170,7 @@ class AccountDataWorkerStore(SQLBaseStore): allow_none=True, ) - return json.loads(content_json) if content_json else None + return db_to_json(content_json) if content_json else None return 
self.db.runInteraction( "get_account_data_for_room_and_type", get_account_data_for_room_and_type_txn @@ -255,7 +255,7 @@ class AccountDataWorkerStore(SQLBaseStore): txn.execute(sql, (user_id, stream_id)) - global_account_data = {row[0]: json.loads(row[1]) for row in txn} + global_account_data = {row[0]: db_to_json(row[1]) for row in txn} sql = ( "SELECT room_id, account_data_type, content FROM room_account_data" @@ -267,7 +267,7 @@ class AccountDataWorkerStore(SQLBaseStore): account_data_by_room = {} for row in txn: room_account_data = account_data_by_room.setdefault(row[0], {}) - room_account_data[row[1]] = json.loads(row[2]) + room_account_data[row[1]] = db_to_json(row[2]) return global_account_data, account_data_by_room diff --git a/synapse/storage/data_stores/main/appservice.py b/synapse/storage/data_stores/main/appservice.py index 7a1fe8cd..56659fed 100644 --- a/synapse/storage/data_stores/main/appservice.py +++ b/synapse/storage/data_stores/main/appservice.py @@ -22,7 +22,7 @@ from twisted.internet import defer from synapse.appservice import AppServiceTransaction from synapse.config.appservice import load_appservices -from synapse.storage._base import SQLBaseStore +from synapse.storage._base import SQLBaseStore, db_to_json from synapse.storage.data_stores.main.events_worker import EventsWorkerStore from synapse.storage.database import Database @@ -303,7 +303,7 @@ class ApplicationServiceTransactionWorkerStore( if not entry: return None - event_ids = json.loads(entry["event_ids"]) + event_ids = db_to_json(entry["event_ids"]) events = yield self.get_events_as_list(event_ids) diff --git a/synapse/storage/data_stores/main/deviceinbox.py b/synapse/storage/data_stores/main/deviceinbox.py index d313b970..da297b31 100644 --- a/synapse/storage/data_stores/main/deviceinbox.py +++ b/synapse/storage/data_stores/main/deviceinbox.py @@ -21,7 +21,7 @@ from canonicaljson import json from twisted.internet import defer from synapse.logging.opentracing import log_kv, set_tag, trace -from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause +from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause from synapse.storage.database import Database from synapse.util.caches.expiringcache import ExpiringCache @@ -65,7 +65,7 @@ class DeviceInboxWorkerStore(SQLBaseStore): messages = [] for row in txn: stream_pos = row[0] - messages.append(json.loads(row[1])) + messages.append(db_to_json(row[1])) if len(messages) < limit: stream_pos = current_stream_id return messages, stream_pos @@ -173,7 +173,7 @@ class DeviceInboxWorkerStore(SQLBaseStore): messages = [] for row in txn: stream_pos = row[0] - messages.append(json.loads(row[1])) + messages.append(db_to_json(row[1])) if len(messages) < limit: log_kv({"message": "Set stream position to current position"}) stream_pos = current_stream_id @@ -424,9 +424,6 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore) def _add_messages_to_local_device_inbox_txn( self, txn, stream_id, messages_by_user_then_device ): - sql = "UPDATE device_max_stream_id" " SET stream_id = ?" " WHERE stream_id < ?" 
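Two of the hunks above work together: the device-message `StreamIdGenerator` in `__init__.py` now seeds from the `device_inbox` table itself rather than the auxiliary `device_max_stream_id` counter, which is why the `UPDATE` being removed here (its `txn.execute` continues just below) is no longer needed. Roughly how the generator derives its position; this is a sketch assuming the usual MAX-seeding behaviour, not the full class:

```python
class StreamIdGenerator:
    """Illustrative skeleton only; the real implementation lives in
    synapse/storage/util/id_generators.py and also tracks in-flight IDs."""

    def __init__(self, db_conn, table, column):
        cur = db_conn.cursor()
        # Seed from the data table's high-water mark, so no separate
        # counter table has to be kept in sync on every write.
        cur.execute("SELECT COALESCE(MAX(%s), 1) FROM %s" % (column, table))
        (self._current,) = cur.fetchone()
        cur.close()

    def get_current_token(self):
        return self._current
```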
- txn.execute(sql, (stream_id, stream_id)) - local_by_user_then_device = {} for user_id, messages_by_device in messages_by_user_then_device.items(): messages_json_for_user = {} diff --git a/synapse/storage/data_stores/main/devices.py b/synapse/storage/data_stores/main/devices.py index 343cf9a2..45581a65 100644 --- a/synapse/storage/data_stores/main/devices.py +++ b/synapse/storage/data_stores/main/devices.py @@ -577,7 +577,7 @@ class DeviceWorkerStore(SQLBaseStore): rows = yield self.db.execute( "get_users_whose_signatures_changed", None, sql, user_id, from_key ) - return {user for row in rows for user in json.loads(row[0])} + return {user for row in rows for user in db_to_json(row[0])} else: return set() diff --git a/synapse/storage/data_stores/main/e2e_room_keys.py b/synapse/storage/data_stores/main/e2e_room_keys.py index 23f4570c..615364f0 100644 --- a/synapse/storage/data_stores/main/e2e_room_keys.py +++ b/synapse/storage/data_stores/main/e2e_room_keys.py @@ -14,13 +14,13 @@ # See the License for the specific language governing permissions and # limitations under the License. -import json +from canonicaljson import json from twisted.internet import defer from synapse.api.errors import StoreError from synapse.logging.opentracing import log_kv, trace -from synapse.storage._base import SQLBaseStore +from synapse.storage._base import SQLBaseStore, db_to_json class EndToEndRoomKeyStore(SQLBaseStore): @@ -148,7 +148,7 @@ class EndToEndRoomKeyStore(SQLBaseStore): "forwarded_count": row["forwarded_count"], # is_verified must be returned to the client as a boolean "is_verified": bool(row["is_verified"]), - "session_data": json.loads(row["session_data"]), + "session_data": db_to_json(row["session_data"]), } return sessions @@ -222,7 +222,7 @@ class EndToEndRoomKeyStore(SQLBaseStore): "first_message_index": row[2], "forwarded_count": row[3], "is_verified": row[4], - "session_data": json.loads(row[5]), + "session_data": db_to_json(row[5]), } return ret @@ -319,7 +319,7 @@ class EndToEndRoomKeyStore(SQLBaseStore): keyvalues={"user_id": user_id, "version": this_version, "deleted": 0}, retcols=("version", "algorithm", "auth_data", "etag"), ) - result["auth_data"] = json.loads(result["auth_data"]) + result["auth_data"] = db_to_json(result["auth_data"]) result["version"] = str(result["version"]) if result["etag"] is None: result["etag"] = 0 diff --git a/synapse/storage/data_stores/main/end_to_end_keys.py b/synapse/storage/data_stores/main/end_to_end_keys.py index 6c3cff82..317c07a8 100644 --- a/synapse/storage/data_stores/main/end_to_end_keys.py +++ b/synapse/storage/data_stores/main/end_to_end_keys.py @@ -366,7 +366,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore): for row in rows: user_id = row["user_id"] key_type = row["keytype"] - key = json.loads(row["keydata"]) + key = db_to_json(row["keydata"]) user_info = result.setdefault(user_id, {}) user_info[key_type] = key diff --git a/synapse/storage/data_stores/main/event_push_actions.py b/synapse/storage/data_stores/main/event_push_actions.py index bc9f4f08..504babaa 100644 --- a/synapse/storage/data_stores/main/event_push_actions.py +++ b/synapse/storage/data_stores/main/event_push_actions.py @@ -21,7 +21,7 @@ from canonicaljson import json from twisted.internet import defer from synapse.metrics.background_process_metrics import run_as_background_process -from synapse.storage._base import LoggingTransaction, SQLBaseStore +from synapse.storage._base import LoggingTransaction, SQLBaseStore, db_to_json from synapse.storage.database import Database from 
synapse.util.caches.descriptors import cachedInlineCallbacks @@ -58,7 +58,7 @@ def _deserialize_action(actions, is_highlight): """Custom deserializer for actions. This allows us to "compress" common actions """ if actions: - return json.loads(actions) + return db_to_json(actions) if is_highlight: return DEFAULT_HIGHLIGHT_ACTION diff --git a/synapse/storage/data_stores/main/events.py b/synapse/storage/data_stores/main/events.py index 230fb5cd..6f2e0d15 100644 --- a/synapse/storage/data_stores/main/events.py +++ b/synapse/storage/data_stores/main/events.py @@ -17,11 +17,9 @@ import itertools import logging from collections import OrderedDict, namedtuple -from functools import wraps from typing import TYPE_CHECKING, Dict, Iterable, List, Tuple import attr -from canonicaljson import json from prometheus_client import Counter from twisted.internet import defer @@ -33,7 +31,7 @@ from synapse.crypto.event_signing import compute_event_reference_hash from synapse.events import EventBase # noqa: F401 from synapse.events.snapshot import EventContext # noqa: F401 from synapse.logging.utils import log_function -from synapse.storage._base import make_in_list_sql_clause +from synapse.storage._base import db_to_json, make_in_list_sql_clause from synapse.storage.data_stores.main.search import SearchEntry from synapse.storage.database import Database, LoggingTransaction from synapse.storage.util.id_generators import StreamIdGenerator @@ -69,27 +67,6 @@ def encode_json(json_object): _EventCacheEntry = namedtuple("_EventCacheEntry", ("event", "redacted_event")) -def _retry_on_integrity_error(func): - """Wraps a database function so that it gets retried on IntegrityError, - with `delete_existing=True` passed in. - - Args: - func: function that returns a Deferred and accepts a `delete_existing` arg - """ - - @wraps(func) - @defer.inlineCallbacks - def f(self, *args, **kwargs): - try: - res = yield func(self, *args, delete_existing=False, **kwargs) - except self.database_engine.module.IntegrityError: - logger.exception("IntegrityError, retrying.") - res = yield func(self, *args, delete_existing=True, **kwargs) - return res - - return f - - @attr.s(slots=True) class DeltaState: """Deltas to use to update the `current_state_events` table. @@ -134,7 +111,6 @@ class PersistEventsStore: hs.config.worker.writers.events == hs.get_instance_name() ), "Can only instantiate EventsStore on master" - @_retry_on_integrity_error @defer.inlineCallbacks def _persist_events_and_state_updates( self, @@ -143,7 +119,6 @@ class PersistEventsStore: state_delta_for_room: Dict[str, DeltaState], new_forward_extremeties: Dict[str, List[str]], backfilled: bool = False, - delete_existing: bool = False, ): """Persist a set of events alongside updates to the current state and forward extremities tables. @@ -157,7 +132,6 @@ class PersistEventsStore: new_forward_extremities: Map from room_id to list of event IDs that are the new forward extremities of the room. 
backfilled - delete_existing Returns: Deferred: resolves when the events have been persisted @@ -197,7 +171,6 @@ class PersistEventsStore: self._persist_events_txn, events_and_contexts=events_and_contexts, backfilled=backfilled, - delete_existing=delete_existing, state_delta_for_room=state_delta_for_room, new_forward_extremeties=new_forward_extremeties, ) @@ -262,7 +235,7 @@ class PersistEventsStore: ) txn.execute(sql + clause, args) - results.extend(r[0] for r in txn if not json.loads(r[1]).get("soft_failed")) + results.extend(r[0] for r in txn if not db_to_json(r[1]).get("soft_failed")) for chunk in batch_iter(event_ids, 100): yield self.db.runInteraction( @@ -323,7 +296,7 @@ class PersistEventsStore: if prev_event_id in existing_prevs: continue - soft_failed = json.loads(metadata).get("soft_failed") + soft_failed = db_to_json(metadata).get("soft_failed") if soft_failed or rejected: to_recursively_check.append(prev_event_id) existing_prevs.add(prev_event_id) @@ -341,7 +314,6 @@ class PersistEventsStore: txn: LoggingTransaction, events_and_contexts: List[Tuple[EventBase, EventContext]], backfilled: bool, - delete_existing: bool = False, state_delta_for_room: Dict[str, DeltaState] = {}, new_forward_extremeties: Dict[str, List[str]] = {}, ): @@ -393,13 +365,6 @@ class PersistEventsStore: # From this point onwards the events are only events that we haven't # seen before. - if delete_existing: - # For paranoia reasons, we go and delete all the existing entries - # for these events so we can reinsert them. - # This gets around any problems with some tables already having - # entries. - self._delete_existing_rows_txn(txn, events_and_contexts=events_and_contexts) - self._store_event_txn(txn, events_and_contexts=events_and_contexts) # Insert into event_to_state_groups. @@ -617,7 +582,7 @@ class PersistEventsStore: txn.execute(sql, (room_id, EventTypes.Create, "")) row = txn.fetchone() if row: - event_json = json.loads(row[0]) + event_json = db_to_json(row[0]) content = event_json.get("content", {}) creator = content.get("creator") room_version_id = content.get("room_version", RoomVersions.V1.identifier) @@ -797,39 +762,6 @@ class PersistEventsStore: return [ec for ec in events_and_contexts if ec[0] not in to_remove] - @classmethod - def _delete_existing_rows_txn(cls, txn, events_and_contexts): - if not events_and_contexts: - # nothing to do here - return - - logger.info("Deleting existing") - - for table in ( - "events", - "event_auth", - "event_json", - "event_edges", - "event_forward_extremities", - "event_reference_hashes", - "event_search", - "event_to_state_groups", - "state_events", - "rejections", - "redactions", - "room_memberships", - ): - txn.executemany( - "DELETE FROM %s WHERE event_id = ?" % (table,), - [(ev.event_id,) for ev, _ in events_and_contexts], - ) - - for table in ("event_push_actions",): - txn.executemany( - "DELETE FROM %s WHERE room_id = ? AND event_id = ?" 
% (table,), - [(ev.room_id, ev.event_id) for ev, _ in events_and_contexts], - ) - def _store_event_txn(self, txn, events_and_contexts): """Insert new events into the event and event_json tables diff --git a/synapse/storage/data_stores/main/events_bg_updates.py b/synapse/storage/data_stores/main/events_bg_updates.py index 62d28f44..663c94b2 100644 --- a/synapse/storage/data_stores/main/events_bg_updates.py +++ b/synapse/storage/data_stores/main/events_bg_updates.py @@ -15,12 +15,10 @@ import logging -from canonicaljson import json - from twisted.internet import defer from synapse.api.constants import EventContentFields -from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause +from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause from synapse.storage.database import Database logger = logging.getLogger(__name__) @@ -125,7 +123,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): for row in rows: try: event_id = row[1] - event_json = json.loads(row[2]) + event_json = db_to_json(row[2]) sender = event_json["sender"] content = event_json["content"] @@ -208,7 +206,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): for row in ev_rows: event_id = row["event_id"] - event_json = json.loads(row["json"]) + event_json = db_to_json(row["json"]) try: origin_server_ts = event_json["origin_server_ts"] except (KeyError, AttributeError): @@ -317,7 +315,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): soft_failed = False if metadata: - soft_failed = json.loads(metadata).get("soft_failed") + soft_failed = db_to_json(metadata).get("soft_failed") if soft_failed or rejected: soft_failed_events_to_lookup.add(event_id) @@ -358,7 +356,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): graph[event_id] = {prev_event_id} - soft_failed = json.loads(metadata).get("soft_failed") + soft_failed = db_to_json(metadata).get("soft_failed") if soft_failed or rejected: soft_failed_events_to_lookup.add(event_id) else: @@ -543,7 +541,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): last_row_event_id = "" for (event_id, event_json_raw) in results: try: - event_json = json.loads(event_json_raw) + event_json = db_to_json(event_json_raw) self.db.simple_insert_many_txn( txn=txn, diff --git a/synapse/storage/data_stores/main/events_worker.py b/synapse/storage/data_stores/main/events_worker.py index 01cad7d4..e812c670 100644 --- a/synapse/storage/data_stores/main/events_worker.py +++ b/synapse/storage/data_stores/main/events_worker.py @@ -21,7 +21,6 @@ import threading from collections import namedtuple from typing import List, Optional, Tuple -from canonicaljson import json from constantly import NamedConstant, Names from twisted.internet import defer @@ -40,7 +39,7 @@ from synapse.metrics.background_process_metrics import run_as_background_process from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker from synapse.replication.tcp.streams import BackfillStream from synapse.replication.tcp.streams.events import EventsStream -from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause +from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause from synapse.storage.database import Database from synapse.storage.util.id_generators import StreamIdGenerator from synapse.types import get_domain_from_id @@ -611,8 +610,8 @@ class EventsWorkerStore(SQLBaseStore): if not allow_rejected and rejected_reason: continue - d = json.loads(row["json"]) - internal_metadata = json.loads(row["internal_metadata"]) + d = 
db_to_json(row["json"]) + internal_metadata = db_to_json(row["internal_metadata"]) format_version = row["format_version"] if format_version is None: @@ -640,7 +639,7 @@ class EventsWorkerStore(SQLBaseStore): else: room_version = KNOWN_ROOM_VERSIONS.get(room_version_id) if not room_version: - logger.error( + logger.warning( "Event %s in room %s has unknown room version %s", event_id, d["room_id"], diff --git a/synapse/storage/data_stores/main/group_server.py b/synapse/storage/data_stores/main/group_server.py index 4fb9f985..01ff561e 100644 --- a/synapse/storage/data_stores/main/group_server.py +++ b/synapse/storage/data_stores/main/group_server.py @@ -21,7 +21,7 @@ from canonicaljson import json from twisted.internet import defer from synapse.api.errors import SynapseError -from synapse.storage._base import SQLBaseStore +from synapse.storage._base import SQLBaseStore, db_to_json # The category ID for the "default" category. We don't store as null in the # database to avoid the fun of null != null @@ -197,7 +197,7 @@ class GroupServerWorkerStore(SQLBaseStore): categories = { row[0]: { "is_public": row[1], - "profile": json.loads(row[2]), + "profile": db_to_json(row[2]), "order": row[3], } for row in txn @@ -221,7 +221,7 @@ class GroupServerWorkerStore(SQLBaseStore): return { row["category_id"]: { "is_public": row["is_public"], - "profile": json.loads(row["profile"]), + "profile": db_to_json(row["profile"]), } for row in rows } @@ -235,7 +235,7 @@ class GroupServerWorkerStore(SQLBaseStore): desc="get_group_category", ) - category["profile"] = json.loads(category["profile"]) + category["profile"] = db_to_json(category["profile"]) return category @@ -251,7 +251,7 @@ class GroupServerWorkerStore(SQLBaseStore): return { row["role_id"]: { "is_public": row["is_public"], - "profile": json.loads(row["profile"]), + "profile": db_to_json(row["profile"]), } for row in rows } @@ -265,7 +265,7 @@ class GroupServerWorkerStore(SQLBaseStore): desc="get_group_role", ) - role["profile"] = json.loads(role["profile"]) + role["profile"] = db_to_json(role["profile"]) return role @@ -333,7 +333,7 @@ class GroupServerWorkerStore(SQLBaseStore): roles = { row[0]: { "is_public": row[1], - "profile": json.loads(row[2]), + "profile": db_to_json(row[2]), "order": row[3], } for row in txn @@ -462,7 +462,7 @@ class GroupServerWorkerStore(SQLBaseStore): now = int(self._clock.time_msec()) if row and now < row["valid_until_ms"]: - return json.loads(row["attestation_json"]) + return db_to_json(row["attestation_json"]) return None @@ -489,7 +489,7 @@ class GroupServerWorkerStore(SQLBaseStore): "group_id": row[0], "type": row[1], "membership": row[2], - "content": json.loads(row[3]), + "content": db_to_json(row[3]), } for row in txn ] @@ -519,7 +519,7 @@ class GroupServerWorkerStore(SQLBaseStore): "group_id": group_id, "membership": membership, "type": gtype, - "content": json.loads(content_json), + "content": db_to_json(content_json), } for group_id, membership, gtype, content_json in txn ] @@ -567,7 +567,7 @@ class GroupServerWorkerStore(SQLBaseStore): """ txn.execute(sql, (last_id, current_id, limit)) updates = [ - (stream_id, (group_id, user_id, gtype, json.loads(content_json))) + (stream_id, (group_id, user_id, gtype, db_to_json(content_json))) for stream_id, group_id, user_id, gtype, content_json in txn ] diff --git a/synapse/storage/data_stores/main/push_rule.py b/synapse/storage/data_stores/main/push_rule.py index f6e78ca5..c2292481 100644 --- a/synapse/storage/data_stores/main/push_rule.py +++ 
b/synapse/storage/data_stores/main/push_rule.py @@ -24,7 +24,7 @@ from twisted.internet import defer from synapse.push.baserules import list_with_base_rules from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker -from synapse.storage._base import SQLBaseStore +from synapse.storage._base import SQLBaseStore, db_to_json from synapse.storage.data_stores.main.appservice import ApplicationServiceWorkerStore from synapse.storage.data_stores.main.events_worker import EventsWorkerStore from synapse.storage.data_stores.main.pusher import PusherWorkerStore @@ -43,8 +43,8 @@ def _load_rules(rawrules, enabled_map): ruleslist = [] for rawrule in rawrules: rule = dict(rawrule) - rule["conditions"] = json.loads(rawrule["conditions"]) - rule["actions"] = json.loads(rawrule["actions"]) + rule["conditions"] = db_to_json(rawrule["conditions"]) + rule["actions"] = db_to_json(rawrule["actions"]) rule["default"] = False ruleslist.append(rule) @@ -259,7 +259,7 @@ class PushRulesWorkerStore( # To do this we set the state_group to a new object as object() != object() state_group = object() - current_state_ids = yield context.get_current_state_ids() + current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids()) result = yield self._bulk_get_push_rules_for_room( event.room_id, state_group, current_state_ids, event=event ) diff --git a/synapse/storage/data_stores/main/pusher.py b/synapse/storage/data_stores/main/pusher.py index 54610162..e18f1ca8 100644 --- a/synapse/storage/data_stores/main/pusher.py +++ b/synapse/storage/data_stores/main/pusher.py @@ -17,11 +17,11 @@ import logging from typing import Iterable, Iterator, List, Tuple -from canonicaljson import encode_canonical_json, json +from canonicaljson import encode_canonical_json from twisted.internet import defer -from synapse.storage._base import SQLBaseStore +from synapse.storage._base import SQLBaseStore, db_to_json from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList logger = logging.getLogger(__name__) @@ -36,7 +36,7 @@ class PusherWorkerStore(SQLBaseStore): for r in rows: dataJson = r["data"] try: - r["data"] = json.loads(dataJson) + r["data"] = db_to_json(dataJson) except Exception as e: logger.warning( "Invalid JSON in data for pusher %d: %s, %s", diff --git a/synapse/storage/data_stores/main/receipts.py b/synapse/storage/data_stores/main/receipts.py index 8f5505bd..1d723f2d 100644 --- a/synapse/storage/data_stores/main/receipts.py +++ b/synapse/storage/data_stores/main/receipts.py @@ -22,7 +22,7 @@ from canonicaljson import json from twisted.internet import defer -from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause +from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause from synapse.storage.database import Database from synapse.storage.util.id_generators import StreamIdGenerator from synapse.util.async_helpers import ObservableDeferred @@ -203,7 +203,7 @@ class ReceiptsWorkerStore(SQLBaseStore): for row in rows: content.setdefault(row["event_id"], {}).setdefault(row["receipt_type"], {})[ row["user_id"] - ] = json.loads(row["data"]) + ] = db_to_json(row["data"]) return [{"type": "m.receipt", "room_id": room_id, "content": content}] @@ -260,7 +260,7 @@ class ReceiptsWorkerStore(SQLBaseStore): event_entry = room_event["content"].setdefault(row["event_id"], {}) receipt_type = event_entry.setdefault(row["receipt_type"], {}) - receipt_type[row["user_id"]] = json.loads(row["data"]) + receipt_type[row["user_id"]] = db_to_json(row["data"]) 
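One behavioural fix buried in the `push_rule.py` hunk above (and repeated for `roommember.py` and `user_directory.py` later in the diff): `context.get_current_state_ids` has become a native async function, and a bare coroutine cannot be yielded from `@defer.inlineCallbacks` code, so it is wrapped in `defer.ensureDeferred`. A toy illustration of the pattern:

```python
from twisted.internet import defer


async def get_current_state_ids():
    # Stand-in for the real EventContext method.
    return {("m.room.member", "@alice:example.com"): "$event_id"}


@defer.inlineCallbacks
def caller():
    # Yielding the coroutine directly would fail; convert it to a Deferred.
    state_ids = yield defer.ensureDeferred(get_current_state_ids())
    return state_ids
```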
results = { room_id: [results[room_id]] if room_id in results else [] @@ -329,7 +329,7 @@ class ReceiptsWorkerStore(SQLBaseStore): """ txn.execute(sql, (last_id, current_id, limit)) - updates = [(r[0], r[1:5] + (json.loads(r[5]),)) for r in txn] + updates = [(r[0], r[1:5] + (db_to_json(r[5]),)) for r in txn] limited = False upper_bound = current_id diff --git a/synapse/storage/data_stores/main/registration.py b/synapse/storage/data_stores/main/registration.py index 587d4b91..27d2c502 100644 --- a/synapse/storage/data_stores/main/registration.py +++ b/synapse/storage/data_stores/main/registration.py @@ -27,6 +27,8 @@ from synapse.api.errors import Codes, StoreError, SynapseError, ThreepidValidati from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage._base import SQLBaseStore from synapse.storage.database import Database +from synapse.storage.types import Cursor +from synapse.storage.util.sequence import build_sequence_generator from synapse.types import UserID from synapse.util.caches.descriptors import cached, cachedInlineCallbacks @@ -42,6 +44,10 @@ class RegistrationWorkerStore(SQLBaseStore): self.config = hs.config self.clock = hs.get_clock() + self._user_id_seq = build_sequence_generator( + database.engine, find_max_generated_user_id_localpart, "user_id_seq", + ) + @cached() def get_user_by_id(self, user_id): return self.db.simple_select_one( @@ -481,39 +487,17 @@ class RegistrationWorkerStore(SQLBaseStore): ret = yield self.db.runInteraction("count_real_users", _count_users) return ret - @defer.inlineCallbacks - def find_next_generated_user_id_localpart(self): - """ - Gets the localpart of the next generated user ID. + async def generate_user_id(self) -> str: + """Generate a suitable localpart for a guest user - Generated user IDs are integers, so we find the largest integer user ID - already taken and return that plus one. + Returns: a (hopefully) free localpart """ - - def _find_next_generated_user_id(txn): - # We bound between '@0' and '@a' to avoid pulling the entire table - # out. - txn.execute("SELECT name FROM users WHERE '@0' <= name AND name < '@a'") - - regex = re.compile(r"^@(\d+):") - - max_found = 0 - - for (user_id,) in txn: - match = regex.search(user_id) - if match: - max_found = max(int(match.group(1)), max_found) - - return max_found + 1 - - return ( - ( - yield self.db.runInteraction( - "find_next_generated_user_id", _find_next_generated_user_id - ) - ) + next_id = await self.db.runInteraction( + "generate_user_id", self._user_id_seq.get_next_id_txn ) + return str(next_id) + async def get_user_id_by_threepid(self, medium: str, address: str) -> Optional[str]: """Returns user id from threepid @@ -1573,3 +1557,26 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): keyvalues={"user_id": user_id}, values={"expiration_ts_ms": expiration_ts, "email_sent": False}, ) + + +def find_max_generated_user_id_localpart(cur: Cursor) -> int: + """ + Gets the localpart of the max current generated user ID. + + Generated user IDs are integers, so we find the largest integer user ID + already taken and return that. + """ + + # We bound between '@0' and '@a' to avoid pulling the entire table + # out. 
+ cur.execute("SELECT name FROM users WHERE '@0' <= name AND name < '@a'") + + regex = re.compile(r"^@(\d+):") + + max_found = 0 + + for (user_id,) in cur: + match = regex.search(user_id) + if match: + max_found = max(int(match.group(1)), max_found) + return max_found diff --git a/synapse/storage/data_stores/main/room.py b/synapse/storage/data_stores/main/room.py index c473cf15..d2e1e36e 100644 --- a/synapse/storage/data_stores/main/room.py +++ b/synapse/storage/data_stores/main/room.py @@ -28,7 +28,7 @@ from twisted.internet import defer from synapse.api.constants import EventTypes from synapse.api.errors import StoreError from synapse.api.room_versions import RoomVersion, RoomVersions -from synapse.storage._base import SQLBaseStore +from synapse.storage._base import SQLBaseStore, db_to_json from synapse.storage.data_stores.main.search import SearchStore from synapse.storage.database import Database, LoggingTransaction from synapse.types import ThirdPartyInstanceID @@ -118,7 +118,12 @@ class RoomWorkerStore(SQLBaseStore): WHERE room_id = ? """ txn.execute(sql, [room_id]) - res = self.db.cursor_to_dict(txn)[0] + # Catch error if sql returns empty result to return "None" instead of an error + try: + res = self.db.cursor_to_dict(txn)[0] + except IndexError: + return None + res["federatable"] = bool(res["federatable"]) res["public"] = bool(res["public"]) return res @@ -665,7 +670,7 @@ class RoomWorkerStore(SQLBaseStore): next_token = None for stream_ordering, content_json in txn: next_token = stream_ordering - event_json = json.loads(content_json) + event_json = db_to_json(content_json) content = event_json["content"] content_url = content.get("url") thumbnail_url = content.get("info", {}).get("thumbnail_url") @@ -910,8 +915,8 @@ class RoomBackgroundUpdateStore(SQLBaseStore): if not row["json"]: retention_policy = {} else: - ev = json.loads(row["json"]) - retention_policy = json.dumps(ev["content"]) + ev = db_to_json(row["json"]) + retention_policy = ev["content"] self.db.simple_insert_txn( txn=txn, @@ -966,7 +971,7 @@ class RoomBackgroundUpdateStore(SQLBaseStore): updates = [] for room_id, event_json in txn: - event_dict = json.loads(event_json) + event_dict = db_to_json(event_json) room_version_id = event_dict.get("content", {}).get( "room_version", RoomVersions.V1.identifier ) diff --git a/synapse/storage/data_stores/main/roommember.py b/synapse/storage/data_stores/main/roommember.py index 44bab65e..a92e401e 100644 --- a/synapse/storage/data_stores/main/roommember.py +++ b/synapse/storage/data_stores/main/roommember.py @@ -17,8 +17,6 @@ import logging from typing import Iterable, List, Set -from canonicaljson import json - from twisted.internet import defer from synapse.api.constants import EventTypes, Membership @@ -27,6 +25,7 @@ from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage._base import ( LoggingTransaction, SQLBaseStore, + db_to_json, make_in_list_sql_clause, ) from synapse.storage.data_stores.main.events_worker import EventsWorkerStore @@ -498,7 +497,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): # To do this we set the state_group to a new object as object() != object() state_group = object() - current_state_ids = yield context.get_current_state_ids() + current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids()) result = yield self._get_joined_users_from_context( event.room_id, state_group, current_state_ids, event=event, context=context ) @@ -938,7 +937,7 @@ class 
RoomMemberBackgroundUpdateStore(SQLBaseStore): event_id = row["event_id"] room_id = row["room_id"] try: - event_json = json.loads(row["json"]) + event_json = db_to_json(row["json"]) content = event_json["content"] except Exception: continue diff --git a/synapse/storage/data_stores/main/schema/delta/58/10federation_pos_instance_name.sql b/synapse/storage/data_stores/main/schema/delta/58/10federation_pos_instance_name.sql new file mode 100644 index 00000000..1cc2633a --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/58/10federation_pos_instance_name.sql @@ -0,0 +1,22 @@ +/* Copyright 2020 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- We need to store the stream positions by instance in a sharded config world. +-- +-- We default to master as we want the column to be NOT NULL and we correctly +-- reset the instance name to match the config each time we start up. +ALTER TABLE federation_stream_position ADD COLUMN instance_name TEXT NOT NULL DEFAULT 'master'; + +CREATE UNIQUE INDEX federation_stream_position_instance ON federation_stream_position(type, instance_name); diff --git a/synapse/storage/data_stores/main/schema/delta/58/11user_id_seq.py b/synapse/storage/data_stores/main/schema/delta/58/11user_id_seq.py new file mode 100644 index 00000000..2011f6bc --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/58/11user_id_seq.py @@ -0,0 +1,34 @@ +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Adds a postgres SEQUENCE for generating guest user IDs. 
+""" + +from synapse.storage.data_stores.main.registration import ( + find_max_generated_user_id_localpart, +) +from synapse.storage.engines import PostgresEngine + + +def run_create(cur, database_engine, *args, **kwargs): + if not isinstance(database_engine, PostgresEngine): + return + + next_id = find_max_generated_user_id_localpart(cur) + 1 + cur.execute("CREATE SEQUENCE user_id_seq START WITH %s", (next_id,)) + + +def run_upgrade(*args, **kwargs): + pass diff --git a/synapse/storage/data_stores/main/search.py b/synapse/storage/data_stores/main/search.py index a8381dc5..d5222829 100644 --- a/synapse/storage/data_stores/main/search.py +++ b/synapse/storage/data_stores/main/search.py @@ -17,12 +17,10 @@ import logging import re from collections import namedtuple -from canonicaljson import json - from twisted.internet import defer from synapse.api.errors import SynapseError -from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause +from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause from synapse.storage.data_stores.main.events_worker import EventRedactBehaviour from synapse.storage.database import Database from synapse.storage.engines import PostgresEngine, Sqlite3Engine @@ -157,7 +155,7 @@ class SearchBackgroundUpdateStore(SearchWorkerStore): stream_ordering = row["stream_ordering"] origin_server_ts = row["origin_server_ts"] try: - event_json = json.loads(row["json"]) + event_json = db_to_json(row["json"]) content = event_json["content"] except Exception: continue diff --git a/synapse/storage/data_stores/main/state.py b/synapse/storage/data_stores/main/state.py index 347cc507..bb38a04e 100644 --- a/synapse/storage/data_stores/main/state.py +++ b/synapse/storage/data_stores/main/state.py @@ -353,6 +353,7 @@ class MainStateBackgroundUpdateStore(RoomMemberWorkerStore): last_room_id = progress.get("last_room_id", "") def _background_remove_left_rooms_txn(txn): + # get a batch of room ids to consider sql = """ SELECT DISTINCT room_id FROM current_state_events WHERE room_id > ? ORDER BY room_id LIMIT ? @@ -363,24 +364,68 @@ class MainStateBackgroundUpdateStore(RoomMemberWorkerStore): if not room_ids: return True, set() + ########################################################################### + # + # exclude rooms where we have active members + sql = """ SELECT room_id - FROM current_state_events + FROM local_current_membership WHERE room_id > ? AND room_id <= ? - AND type = 'm.room.member' AND membership = 'join' - AND state_key LIKE ? GROUP BY room_id """ - txn.execute(sql, (last_room_id, room_ids[-1], "%:" + self.server_name)) - + txn.execute(sql, (last_room_id, room_ids[-1])) joined_room_ids = {row[0] for row in txn} + to_delete = set(room_ids) - joined_room_ids + + ########################################################################### + # + # exclude rooms which we are in the process of constructing; these otherwise + # qualify as "rooms with no local users", and would have their + # forward extremities cleaned up. + + # the following query will return a list of rooms which have forward + # extremities that are *not* also the create event in the room - ie + # those that are not being created currently. + + sql = """ + SELECT DISTINCT efe.room_id + FROM event_forward_extremities efe + LEFT JOIN current_state_events cse ON + cse.event_id = efe.event_id + AND cse.type = 'm.room.create' + AND cse.state_key = '' + WHERE + cse.event_id IS NULL + AND efe.room_id > ? AND efe.room_id <= ? 
+ """ + + txn.execute(sql, (last_room_id, room_ids[-1])) + + # build a set of those rooms within `to_delete` that do not appear in + # the above, leaving us with the rooms in `to_delete` that *are* being + # created. + creating_rooms = to_delete.difference(row[0] for row in txn) + logger.info("skipping rooms which are being created: %s", creating_rooms) + + # now remove the rooms being created from the list of those to delete. + # + # (we could have just taken the intersection of `to_delete` with the result + # of the sql query, but it's useful to be able to log `creating_rooms`; and + # having done so, it's quicker to remove the (few) creating rooms from + # `to_delete` than it is to form the intersection with the (larger) list of + # not-creating-rooms) + + to_delete -= creating_rooms - left_rooms = set(room_ids) - joined_room_ids + ########################################################################### + # + # now clear the state for the rooms - logger.info("Deleting current state left rooms: %r", left_rooms) + logger.info("Deleting current state left rooms: %r", to_delete) # First we get all users that we still think were joined to the # room. This is so that we can mark those device lists as @@ -391,7 +436,7 @@ class MainStateBackgroundUpdateStore(RoomMemberWorkerStore): txn, table="current_state_events", column="room_id", - iterable=left_rooms, + iterable=to_delete, keyvalues={"type": EventTypes.Member, "membership": Membership.JOIN}, retcols=("state_key",), ) @@ -403,7 +448,7 @@ class MainStateBackgroundUpdateStore(RoomMemberWorkerStore): txn, table="current_state_events", column="room_id", - iterable=left_rooms, + iterable=to_delete, keyvalues={}, ) @@ -411,7 +456,7 @@ class MainStateBackgroundUpdateStore(RoomMemberWorkerStore): txn, table="event_forward_extremities", column="room_id", - iterable=left_rooms, + iterable=to_delete, keyvalues={}, ) diff --git a/synapse/storage/data_stores/main/stream.py b/synapse/storage/data_stores/main/stream.py index 379d758b..10d39b36 100644 --- a/synapse/storage/data_stores/main/stream.py +++ b/synapse/storage/data_stores/main/stream.py @@ -45,7 +45,7 @@ from twisted.internet import defer from synapse.logging.context import make_deferred_yieldable, run_in_background from synapse.storage._base import SQLBaseStore from synapse.storage.data_stores.main.events_worker import EventsWorkerStore -from synapse.storage.database import Database +from synapse.storage.database import Database, make_in_list_sql_clause from synapse.storage.engines import PostgresEngine from synapse.types import RoomStreamToken from synapse.util.caches.stream_change_cache import StreamChangeCache @@ -253,6 +253,16 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): def __init__(self, database: Database, db_conn, hs): super(StreamWorkerStore, self).__init__(database, db_conn, hs) + self._instance_name = hs.get_instance_name() + self._send_federation = hs.should_send_federation() + self._federation_shard_config = hs.config.worker.federation_shard_config + + # If we're a process that sends federation we may need to reset the + # `federation_stream_position` table to match the current sharding + # config. We don't do this now as otherwise two processes could conflict + # during startup which would cause one to die. 
+ self._need_to_reset_federation_stream_positions = self._send_federation + events_max = self.get_room_max_stream_ordering() event_cache_prefill, min_event_val = self.db.get_cache_dict( db_conn, @@ -793,22 +803,95 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): return upper_bound, events - def get_federation_out_pos(self, typ): - return self.db.simple_select_one_onecol( + async def get_federation_out_pos(self, typ: str) -> int: + if self._need_to_reset_federation_stream_positions: + await self.db.runInteraction( + "_reset_federation_positions_txn", self._reset_federation_positions_txn + ) + self._need_to_reset_federation_stream_positions = False + + return await self.db.simple_select_one_onecol( table="federation_stream_position", retcol="stream_id", - keyvalues={"type": typ}, + keyvalues={"type": typ, "instance_name": self._instance_name}, desc="get_federation_out_pos", ) - def update_federation_out_pos(self, typ, stream_id): - return self.db.simple_update_one( + async def update_federation_out_pos(self, typ, stream_id): + if self._need_to_reset_federation_stream_positions: + await self.db.runInteraction( + "_reset_federation_positions_txn", self._reset_federation_positions_txn + ) + self._need_to_reset_federation_stream_positions = False + + return await self.db.simple_update_one( table="federation_stream_position", - keyvalues={"type": typ}, + keyvalues={"type": typ, "instance_name": self._instance_name}, updatevalues={"stream_id": stream_id}, desc="update_federation_out_pos", ) + def _reset_federation_positions_txn(self, txn): + """Fiddles with the `federation_stream_position` table to make it match + the configured federation sender instances during start up. + """ + + # The federation sender instances may have changed, so we need to + # massage the `federation_stream_position` table to have a row per type + # per instance sending federation. If there is a mismatch we update the + # table with the correct rows using the *minimum* stream ID seen. This + # may result in resending of events/EDUs to remote servers, but that is + # preferable to dropping them. + + if not self._send_federation: + return + + # Pull out the configured instances. If we don't have a shard config then + # we assume that we're the only instance sending. 
+ configured_instances = self._federation_shard_config.instances + if not configured_instances: + configured_instances = [self._instance_name] + elif self._instance_name not in configured_instances: + return + + instances_in_table = self.db.simple_select_onecol_txn( + txn, + table="federation_stream_position", + keyvalues={}, + retcol="instance_name", + ) + + if set(instances_in_table) == set(configured_instances): + # Nothing to do + return + + sql = """ + SELECT type, MIN(stream_id) FROM federation_stream_position + GROUP BY type + """ + txn.execute(sql) + min_positions = dict(txn) # Map from type -> min position + + # Ensure we do actually have some values here + assert set(min_positions) == {"federation", "events"} + + sql = """ + DELETE FROM federation_stream_position + WHERE NOT (%s) + """ + clause, args = make_in_list_sql_clause( + txn.database_engine, "instance_name", configured_instances + ) + txn.execute(sql % (clause,), args) + + for typ, stream_id in min_positions.items(): + self.db.simple_upsert_txn( + txn, + table="federation_stream_position", + keyvalues={"type": typ, "instance_name": self._instance_name}, + values={"stream_id": stream_id}, + ) + def has_room_changed_since(self, room_id, stream_id): return self._events_stream_cache.has_entity_changed(room_id, stream_id) diff --git a/synapse/storage/data_stores/main/tags.py b/synapse/storage/data_stores/main/tags.py index 290317fd..bd722777 100644 --- a/synapse/storage/data_stores/main/tags.py +++ b/synapse/storage/data_stores/main/tags.py @@ -21,6 +21,7 @@ from canonicaljson import json from twisted.internet import defer +from synapse.storage._base import db_to_json from synapse.storage.data_stores.main.account_data import AccountDataWorkerStore from synapse.util.caches.descriptors import cached @@ -49,7 +50,7 @@ class TagsWorkerStore(AccountDataWorkerStore): tags_by_room = {} for row in rows: room_tags = tags_by_room.setdefault(row["room_id"], {}) - room_tags[row["tag"]] = json.loads(row["content"]) + room_tags[row["tag"]] = db_to_json(row["content"]) return tags_by_room return deferred @@ -180,7 +181,7 @@ class TagsWorkerStore(AccountDataWorkerStore): retcols=("tag", "content"), desc="get_tags_for_room", ).addCallback( - lambda rows: {row["tag"]: json.loads(row["content"]) for row in rows} + lambda rows: {row["tag"]: db_to_json(row["content"]) for row in rows} ) diff --git a/synapse/storage/data_stores/main/ui_auth.py b/synapse/storage/data_stores/main/ui_auth.py index 4c044b1a..5f1b9197 100644 --- a/synapse/storage/data_stores/main/ui_auth.py +++ b/synapse/storage/data_stores/main/ui_auth.py @@ -12,13 +12,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
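Returning to the `stream.py` hunk above: when the federation-sender sharding config changes, `_reset_federation_positions_txn` rebuilds `federation_stream_position` around the per-type minimum stream ID and deletes rows for instances that are no longer configured. That `DELETE` is built with `make_in_list_sql_clause`; a simplified sketch of what the helper produces (the real implementation special-cases Postgres, using `column = ANY(?)`):

```python
def make_in_list_sql_clause(database_engine, column, iterable):
    # Simplified generic form; engine-specific optimisations omitted.
    values = list(iterable)
    return "%s IN (%s)" % (column, ", ".join("?" * len(values))), values


clause, args = make_in_list_sql_clause(None, "instance_name", ["master", "fed1"])
# clause == "instance_name IN (?, ?)"; args == ["master", "fed1"]
```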
-import json from typing import Any, Dict, Optional, Union import attr +from canonicaljson import json from synapse.api.errors import StoreError -from synapse.storage._base import SQLBaseStore +from synapse.storage._base import SQLBaseStore, db_to_json from synapse.types import JsonDict from synapse.util import stringutils as stringutils @@ -118,7 +118,7 @@ class UIAuthWorkerStore(SQLBaseStore): desc="get_ui_auth_session", ) - result["clientdict"] = json.loads(result["clientdict"]) + result["clientdict"] = db_to_json(result["clientdict"]) return UIAuthSessionData(session_id, **result) @@ -168,7 +168,7 @@ class UIAuthWorkerStore(SQLBaseStore): retcols=("stage_type", "result"), desc="get_completed_ui_auth_stages", ): - results[row["stage_type"]] = json.loads(row["result"]) + results[row["stage_type"]] = db_to_json(row["result"]) return results @@ -224,7 +224,7 @@ class UIAuthWorkerStore(SQLBaseStore): ) # Update it and add it back to the database. - serverdict = json.loads(result["serverdict"]) + serverdict = db_to_json(result["serverdict"]) serverdict[key] = value self.db.simple_update_one_txn( @@ -254,7 +254,7 @@ class UIAuthWorkerStore(SQLBaseStore): desc="get_ui_auth_session_data", ) - serverdict = json.loads(result["serverdict"]) + serverdict = db_to_json(result["serverdict"]) return serverdict.get(key, default) diff --git a/synapse/storage/data_stores/main/user_directory.py b/synapse/storage/data_stores/main/user_directory.py index 6b8130bf..942e51fd 100644 --- a/synapse/storage/data_stores/main/user_directory.py +++ b/synapse/storage/data_stores/main/user_directory.py @@ -198,7 +198,9 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): room_id ) - users_with_profile = yield state.get_current_users_in_room(room_id) + users_with_profile = yield defer.ensureDeferred( + state.get_current_users_in_room(room_id) + ) user_ids = set(users_with_profile) # Update each user in the user directory. diff --git a/synapse/storage/data_stores/main/user_erasure_store.py b/synapse/storage/data_stores/main/user_erasure_store.py index ec6b8a4f..d3038ff0 100644 --- a/synapse/storage/data_stores/main/user_erasure_store.py +++ b/synapse/storage/data_stores/main/user_erasure_store.py @@ -70,11 +70,11 @@ class UserErasureWorkerStore(SQLBaseStore): class UserErasureStore(UserErasureWorkerStore): - def mark_user_erased(self, user_id): + def mark_user_erased(self, user_id: str) -> None: """Indicate that user_id wishes their message history to be erased. Args: - user_id (str): full user_id to be erased + user_id: full user_id to be erased """ def f(txn): @@ -89,3 +89,25 @@ class UserErasureStore(UserErasureWorkerStore): self._invalidate_cache_and_stream(txn, self.is_user_erased, (user_id,)) return self.db.runInteraction("mark_user_erased", f) + + def mark_user_not_erased(self, user_id: str) -> None: + """Indicate that user_id is no longer erased. + + Args: + user_id: full user_id to be un-erased + """ + + def f(txn): + # first check if they are already in the list + txn.execute("SELECT 1 FROM erased_users WHERE user_id = ?", (user_id,)) + if not txn.fetchone(): + return + + # They are there, delete them. 
+ self.simple_delete_one_txn( + txn, "erased_users", keyvalues={"user_id": user_id} + ) + + self._invalidate_cache_and_stream(txn, self.is_user_erased, (user_id,)) + + return self.db.runInteraction("mark_user_not_erased", f) diff --git a/synapse/storage/data_stores/state/store.py b/synapse/storage/data_stores/state/store.py index 5db9f201..128c09a2 100644 --- a/synapse/storage/data_stores/state/store.py +++ b/synapse/storage/data_stores/state/store.py @@ -24,6 +24,8 @@ from synapse.storage._base import SQLBaseStore from synapse.storage.data_stores.state.bg_updates import StateBackgroundUpdateStore from synapse.storage.database import Database from synapse.storage.state import StateFilter +from synapse.storage.types import Cursor +from synapse.storage.util.sequence import build_sequence_generator from synapse.types import StateMap from synapse.util.caches.descriptors import cached from synapse.util.caches.dictionary_cache import DictionaryCache @@ -92,6 +94,14 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): "*stateGroupMembersCache*", 500000, ) + def get_max_state_group_txn(txn: Cursor): + txn.execute("SELECT COALESCE(max(id), 0) FROM state_groups") + return txn.fetchone()[0] + + self._state_group_seq_gen = build_sequence_generator( + self.database_engine, get_max_state_group_txn, "state_group_id_seq" + ) + @cached(max_entries=10000, iterable=True) def get_state_group_delta(self, state_group): """Given a state group try to return a previous group and a delta between @@ -386,7 +396,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): # AFAIK, this can never happen raise Exception("current_state_ids cannot be None") - state_group = self.database_engine.get_next_state_group_id(txn) + state_group = self._state_group_seq_gen.get_next_id_txn(txn) self.db.simple_insert_txn( txn, diff --git a/synapse/storage/engines/_base.py b/synapse/storage/engines/_base.py index ab0bbe4b..908cbc79 100644 --- a/synapse/storage/engines/_base.py +++ b/synapse/storage/engines/_base.py @@ -91,12 +91,6 @@ class BaseDatabaseEngine(Generic[ConnectionType], metaclass=abc.ABCMeta): def lock_table(self, txn, table: str) -> None: ... - @abc.abstractmethod - def get_next_state_group_id(self, txn) -> int: - """Returns an int that can be used as a new state_group ID - """ - ... - @property @abc.abstractmethod def server_version(self) -> str: diff --git a/synapse/storage/engines/postgres.py b/synapse/storage/engines/postgres.py index a3158808..ff39281f 100644 --- a/synapse/storage/engines/postgres.py +++ b/synapse/storage/engines/postgres.py @@ -154,12 +154,6 @@ class PostgresEngine(BaseDatabaseEngine): def lock_table(self, txn, table): txn.execute("LOCK TABLE %s in EXCLUSIVE MODE" % (table,)) - def get_next_state_group_id(self, txn): - """Returns an int that can be used as a new state_group ID - """ - txn.execute("SELECT nextval('state_group_id_seq')") - return txn.fetchone()[0] - @property def server_version(self): """Returns a string giving the server version. 
For example: '8.1.5' diff --git a/synapse/storage/engines/sqlite.py b/synapse/storage/engines/sqlite.py index 215a9494..8a0f8c89 100644 --- a/synapse/storage/engines/sqlite.py +++ b/synapse/storage/engines/sqlite.py @@ -96,19 +96,6 @@ class Sqlite3Engine(BaseDatabaseEngine["sqlite3.Connection"]): def lock_table(self, txn, table): return - def get_next_state_group_id(self, txn): - """Returns an int that can be used as a new state_group ID - """ - # We do application locking here since if we're using sqlite then - # we are a single process synapse. - with self._current_state_group_id_lock: - if self._current_state_group_id is None: - txn.execute("SELECT COALESCE(max(id), 0) FROM state_groups") - self._current_state_group_id = txn.fetchone()[0] - - self._current_state_group_id += 1 - return self._current_state_group_id - @property def server_version(self): """Gets a string giving the server version. For example: '3.22.0' diff --git a/synapse/storage/persist_events.py b/synapse/storage/persist_events.py index fa460416..78fbdcde 100644 --- a/synapse/storage/persist_events.py +++ b/synapse/storage/persist_events.py @@ -29,7 +29,6 @@ from synapse.events import FrozenEvent from synapse.events.snapshot import EventContext from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable from synapse.metrics.background_process_metrics import run_as_background_process -from synapse.state import StateResolutionStore from synapse.storage.data_stores import DataStores from synapse.storage.data_stores.main.events import DeltaState from synapse.types import StateMap @@ -648,6 +647,10 @@ class EventsPersistenceStorage(object): room_version = await self.main_store.get_room_version_id(room_id) logger.debug("calling resolve_state_groups from preserve_events") + + # Avoid a circular import. + from synapse.state import StateResolutionStore + res = await self._state_resolution_handler.resolve_state_groups( room_id, room_version, diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py index f89ce0be..787cebfb 100644 --- a/synapse/storage/util/id_generators.py +++ b/synapse/storage/util/id_generators.py @@ -21,6 +21,7 @@ from typing import Dict, Set, Tuple from typing_extensions import Deque from synapse.storage.database import Database, LoggingTransaction +from synapse.storage.util.sequence import PostgresSequenceGenerator class IdGenerator(object): @@ -247,7 +248,6 @@ class MultiWriterIdGenerator: ): self._db = db self._instance_name = instance_name - self._sequence_name = sequence_name # We lock as some functions may be called from DB threads. self._lock = threading.Lock() @@ -260,6 +260,8 @@ class MultiWriterIdGenerator: # should be less than the minimum of this set (if not empty). self._unfinished_ids = set() # type: Set[int] + self._sequence_gen = PostgresSequenceGenerator(sequence_name) + def _load_current_ids( self, db_conn, table: str, instance_column: str, id_column: str ) -> Dict[str, int]: @@ -283,9 +285,7 @@ class MultiWriterIdGenerator: return current_positions def _load_next_id_txn(self, txn): - txn.execute("SELECT nextval(?)", (self._sequence_name,)) - (next_id,) = txn.fetchone() - return next_id + return self._sequence_gen.get_next_id_txn(txn) async def get_next(self): """ diff --git a/synapse/storage/util/sequence.py b/synapse/storage/util/sequence.py new file mode 100644 index 00000000..63dfea42 --- /dev/null +++ b/synapse/storage/util/sequence.py @@ -0,0 +1,98 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 The Matrix.org Foundation C.I.C. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import abc +import threading +from typing import Callable, Optional + +from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine +from synapse.storage.types import Cursor + + +class SequenceGenerator(metaclass=abc.ABCMeta): + """A class which generates a unique sequence of integers""" + + @abc.abstractmethod + def get_next_id_txn(self, txn: Cursor) -> int: + """Gets the next ID in the sequence""" + ... + + +class PostgresSequenceGenerator(SequenceGenerator): + """An implementation of SequenceGenerator which uses a postgres sequence""" + + def __init__(self, sequence_name: str): + self._sequence_name = sequence_name + + def get_next_id_txn(self, txn: Cursor) -> int: + txn.execute("SELECT nextval(?)", (self._sequence_name,)) + return txn.fetchone()[0] + + +GetFirstCallbackType = Callable[[Cursor], int] + + +class LocalSequenceGenerator(SequenceGenerator): + """An implementation of SequenceGenerator which uses local locking + + This only works reliably if there are no other worker processes generating IDs at + the same time. + """ + + def __init__(self, get_first_callback: GetFirstCallbackType): + """ + Args: + get_first_callback: a callback which is called on the first call to + get_next_id_txn; should return the curreent maximum id + """ + # the callback. this is cleared after it is called, so that it can be GCed. + self._callback = get_first_callback # type: Optional[GetFirstCallbackType] + + # The current max value, or None if we haven't looked in the DB yet. + self._current_max_id = None # type: Optional[int] + self._lock = threading.Lock() + + def get_next_id_txn(self, txn: Cursor) -> int: + # We do application locking here since if we're using sqlite then + # we are a single process synapse. + with self._lock: + if self._current_max_id is None: + assert self._callback is not None + self._current_max_id = self._callback(txn) + self._callback = None + + self._current_max_id += 1 + return self._current_max_id + + +def build_sequence_generator( + database_engine: BaseDatabaseEngine, + get_first_callback: GetFirstCallbackType, + sequence_name: str, +) -> SequenceGenerator: + """Get the best impl of SequenceGenerator available + + This uses PostgresSequenceGenerator on postgres, and a locally-locked impl on + sqlite. + + Args: + database_engine: the database engine we are connected to + get_first_callback: a callback which gets the next sequence ID. Used if + we're on sqlite. + sequence_name: the name of a postgres sequence to use. + """ + if isinstance(database_engine, PostgresEngine): + return PostgresSequenceGenerator(sequence_name) + else: + return LocalSequenceGenerator(get_first_callback) |
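The new sequence module above is the thread tying several hunks together: registration now builds a generator over `user_id_seq`, seeded by `find_max_generated_user_id_localpart` (whose `'@0' <= name AND name < '@a'` bound keeps the scan to a small slice of the `users` table, since every all-numeric localpart sorts in that range, with the regex doing the exact filtering), and the state store does the same over `state_group_id_seq`. On SQLite, `build_sequence_generator` falls back to `LocalSequenceGenerator`; a usage sketch against an illustrative in-memory database:

```python
import sqlite3

from synapse.storage.util.sequence import LocalSequenceGenerator


def get_max_state_group_txn(txn):
    txn.execute("SELECT COALESCE(max(id), 0) FROM state_groups")
    return txn.fetchone()[0]


seq = LocalSequenceGenerator(get_max_state_group_txn)

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE state_groups (id INTEGER PRIMARY KEY)")
conn.execute("INSERT INTO state_groups (id) VALUES (41)")

cur = conn.cursor()
print(seq.get_next_id_txn(cur))  # 42: the callback runs once, then +1
print(seq.get_next_id_txn(cur))  # 43: no further DB reads; guarded by a lock
```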