Diffstat (limited to 'synapse/storage')
-rw-r--r--  synapse/storage/_base.py | 4
-rw-r--r--  synapse/storage/background_updates.py | 5
-rw-r--r--  synapse/storage/data_stores/main/__init__.py | 2
-rw-r--r--  synapse/storage/data_stores/main/account_data.py | 16
-rw-r--r--  synapse/storage/data_stores/main/appservice.py | 4
-rw-r--r--  synapse/storage/data_stores/main/deviceinbox.py | 9
-rw-r--r--  synapse/storage/data_stores/main/devices.py | 2
-rw-r--r--  synapse/storage/data_stores/main/e2e_room_keys.py | 10
-rw-r--r--  synapse/storage/data_stores/main/end_to_end_keys.py | 2
-rw-r--r--  synapse/storage/data_stores/main/event_push_actions.py | 4
-rw-r--r--  synapse/storage/data_stores/main/events.py | 76
-rw-r--r--  synapse/storage/data_stores/main/events_bg_updates.py | 14
-rw-r--r--  synapse/storage/data_stores/main/events_worker.py | 9
-rw-r--r--  synapse/storage/data_stores/main/group_server.py | 22
-rw-r--r--  synapse/storage/data_stores/main/push_rule.py | 8
-rw-r--r--  synapse/storage/data_stores/main/pusher.py | 6
-rw-r--r--  synapse/storage/data_stores/main/receipts.py | 8
-rw-r--r--  synapse/storage/data_stores/main/registration.py | 65
-rw-r--r--  synapse/storage/data_stores/main/room.py | 17
-rw-r--r--  synapse/storage/data_stores/main/roommember.py | 7
-rw-r--r--  synapse/storage/data_stores/main/schema/delta/58/10federation_pos_instance_name.sql | 22
-rw-r--r--  synapse/storage/data_stores/main/schema/delta/58/11user_id_seq.py | 34
-rw-r--r--  synapse/storage/data_stores/main/search.py | 6
-rw-r--r--  synapse/storage/data_stores/main/state.py | 65
-rw-r--r--  synapse/storage/data_stores/main/stream.py | 97
-rw-r--r--  synapse/storage/data_stores/main/tags.py | 5
-rw-r--r--  synapse/storage/data_stores/main/ui_auth.py | 12
-rw-r--r--  synapse/storage/data_stores/main/user_directory.py | 4
-rw-r--r--  synapse/storage/data_stores/main/user_erasure_store.py | 26
-rw-r--r--  synapse/storage/data_stores/state/store.py | 12
-rw-r--r--  synapse/storage/engines/_base.py | 6
-rw-r--r--  synapse/storage/engines/postgres.py | 6
-rw-r--r--  synapse/storage/engines/sqlite.py | 13
-rw-r--r--  synapse/storage/persist_events.py | 5
-rw-r--r--  synapse/storage/util/id_generators.py | 8
-rw-r--r--  synapse/storage/util/sequence.py | 98
36 files changed, 471 insertions, 238 deletions
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index bfce541c..985a0428 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -100,8 +100,8 @@ def db_to_json(db_content):
if isinstance(db_content, memoryview):
db_content = db_content.tobytes()
- # Decode it to a Unicode string before feeding it to json.loads, so we
- # consistenty get a Unicode-containing object out.
+ # Decode it to a Unicode string before feeding it to json.loads, since
+ # Python 3.5 does not support deserializing bytes.
if isinstance(db_content, (bytes, bytearray)):
db_content = db_content.decode("utf8")
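The hunk above only shows the normalisation half of the helper. For context, a minimal sketch of the whole function's shape, assuming it simply ends in a plain json.loads (the real helper may add error handling around it):

    import json

    def db_to_json(db_content):
        # Postgres can hand back a memoryview for bytea columns; flatten it first.
        if isinstance(db_content, memoryview):
            db_content = db_content.tobytes()

        # Decode it to a Unicode string before feeding it to json.loads, since
        # Python 3.5 does not support deserializing bytes.
        if isinstance(db_content, (bytes, bytearray)):
            db_content = db_content.decode("utf8")

        return json.loads(db_content)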
diff --git a/synapse/storage/background_updates.py b/synapse/storage/background_updates.py
index 59f3394b..018826ef 100644
--- a/synapse/storage/background_updates.py
+++ b/synapse/storage/background_updates.py
@@ -249,7 +249,10 @@ class BackgroundUpdater(object):
retcol="progress_json",
)
- progress = json.loads(progress_json)
+ # Avoid a circular import.
+ from synapse.storage._base import db_to_json
+
+ progress = db_to_json(progress_json)
time_start = self._clock.time_msec()
items_updated = await update_handler(progress, batch_size)
diff --git a/synapse/storage/data_stores/main/__init__.py b/synapse/storage/data_stores/main/__init__.py
index 4b4763c7..932458f6 100644
--- a/synapse/storage/data_stores/main/__init__.py
+++ b/synapse/storage/data_stores/main/__init__.py
@@ -128,7 +128,7 @@ class DataStore(
db_conn, "presence_stream", "stream_id"
)
self._device_inbox_id_gen = StreamIdGenerator(
- db_conn, "device_max_stream_id", "stream_id"
+ db_conn, "device_inbox", "stream_id"
)
self._public_room_id_gen = StreamIdGenerator(
db_conn, "public_room_list_stream", "stream_id"
diff --git a/synapse/storage/data_stores/main/account_data.py b/synapse/storage/data_stores/main/account_data.py
index b58f04d0..33cc372d 100644
--- a/synapse/storage/data_stores/main/account_data.py
+++ b/synapse/storage/data_stores/main/account_data.py
@@ -22,7 +22,7 @@ from canonicaljson import json
from twisted.internet import defer
-from synapse.storage._base import SQLBaseStore
+from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.storage.database import Database
from synapse.storage.util.id_generators import StreamIdGenerator
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
@@ -77,7 +77,7 @@ class AccountDataWorkerStore(SQLBaseStore):
)
global_account_data = {
- row["account_data_type"]: json.loads(row["content"]) for row in rows
+ row["account_data_type"]: db_to_json(row["content"]) for row in rows
}
rows = self.db.simple_select_list_txn(
@@ -90,7 +90,7 @@ class AccountDataWorkerStore(SQLBaseStore):
by_room = {}
for row in rows:
room_data = by_room.setdefault(row["room_id"], {})
- room_data[row["account_data_type"]] = json.loads(row["content"])
+ room_data[row["account_data_type"]] = db_to_json(row["content"])
return global_account_data, by_room
@@ -113,7 +113,7 @@ class AccountDataWorkerStore(SQLBaseStore):
)
if result:
- return json.loads(result)
+ return db_to_json(result)
else:
return None
@@ -137,7 +137,7 @@ class AccountDataWorkerStore(SQLBaseStore):
)
return {
- row["account_data_type"]: json.loads(row["content"]) for row in rows
+ row["account_data_type"]: db_to_json(row["content"]) for row in rows
}
return self.db.runInteraction(
@@ -170,7 +170,7 @@ class AccountDataWorkerStore(SQLBaseStore):
allow_none=True,
)
- return json.loads(content_json) if content_json else None
+ return db_to_json(content_json) if content_json else None
return self.db.runInteraction(
"get_account_data_for_room_and_type", get_account_data_for_room_and_type_txn
@@ -255,7 +255,7 @@ class AccountDataWorkerStore(SQLBaseStore):
txn.execute(sql, (user_id, stream_id))
- global_account_data = {row[0]: json.loads(row[1]) for row in txn}
+ global_account_data = {row[0]: db_to_json(row[1]) for row in txn}
sql = (
"SELECT room_id, account_data_type, content FROM room_account_data"
@@ -267,7 +267,7 @@ class AccountDataWorkerStore(SQLBaseStore):
account_data_by_room = {}
for row in txn:
room_account_data = account_data_by_room.setdefault(row[0], {})
- room_account_data[row[1]] = json.loads(row[2])
+ room_account_data[row[1]] = db_to_json(row[2])
return global_account_data, account_data_by_room
diff --git a/synapse/storage/data_stores/main/appservice.py b/synapse/storage/data_stores/main/appservice.py
index 7a1fe8cd..56659fed 100644
--- a/synapse/storage/data_stores/main/appservice.py
+++ b/synapse/storage/data_stores/main/appservice.py
@@ -22,7 +22,7 @@ from twisted.internet import defer
from synapse.appservice import AppServiceTransaction
from synapse.config.appservice import load_appservices
-from synapse.storage._base import SQLBaseStore
+from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.storage.data_stores.main.events_worker import EventsWorkerStore
from synapse.storage.database import Database
@@ -303,7 +303,7 @@ class ApplicationServiceTransactionWorkerStore(
if not entry:
return None
- event_ids = json.loads(entry["event_ids"])
+ event_ids = db_to_json(entry["event_ids"])
events = yield self.get_events_as_list(event_ids)
diff --git a/synapse/storage/data_stores/main/deviceinbox.py b/synapse/storage/data_stores/main/deviceinbox.py
index d313b970..da297b31 100644
--- a/synapse/storage/data_stores/main/deviceinbox.py
+++ b/synapse/storage/data_stores/main/deviceinbox.py
@@ -21,7 +21,7 @@ from canonicaljson import json
from twisted.internet import defer
from synapse.logging.opentracing import log_kv, set_tag, trace
-from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause
+from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
from synapse.storage.database import Database
from synapse.util.caches.expiringcache import ExpiringCache
@@ -65,7 +65,7 @@ class DeviceInboxWorkerStore(SQLBaseStore):
messages = []
for row in txn:
stream_pos = row[0]
- messages.append(json.loads(row[1]))
+ messages.append(db_to_json(row[1]))
if len(messages) < limit:
stream_pos = current_stream_id
return messages, stream_pos
@@ -173,7 +173,7 @@ class DeviceInboxWorkerStore(SQLBaseStore):
messages = []
for row in txn:
stream_pos = row[0]
- messages.append(json.loads(row[1]))
+ messages.append(db_to_json(row[1]))
if len(messages) < limit:
log_kv({"message": "Set stream position to current position"})
stream_pos = current_stream_id
@@ -424,9 +424,6 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore)
def _add_messages_to_local_device_inbox_txn(
self, txn, stream_id, messages_by_user_then_device
):
- sql = "UPDATE device_max_stream_id" " SET stream_id = ?" " WHERE stream_id < ?"
- txn.execute(sql, (stream_id, stream_id))
-
local_by_user_then_device = {}
for user_id, messages_by_device in messages_by_user_then_device.items():
messages_json_for_user = {}
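This hunk and the __init__.py hunk above are two halves of one change: StreamIdGenerator seeds itself from the maximum value of the named column, so once it reads device_inbox.stream_id directly, the separate device_max_stream_id bookkeeping table, and the UPDATE removed here, become redundant. A simplified sketch of the seeding step, assuming a plain MAX() query (the real helper also handles backwards-running streams):

    def _load_current_id(db_conn, table, column):
        # e.g. SELECT MAX(stream_id) FROM device_inbox
        cur = db_conn.cursor()
        cur.execute("SELECT MAX(%s) FROM %s" % (column, table))
        (val,) = cur.fetchone()
        cur.close()
        return int(val) if val else 1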
diff --git a/synapse/storage/data_stores/main/devices.py b/synapse/storage/data_stores/main/devices.py
index 343cf9a2..45581a65 100644
--- a/synapse/storage/data_stores/main/devices.py
+++ b/synapse/storage/data_stores/main/devices.py
@@ -577,7 +577,7 @@ class DeviceWorkerStore(SQLBaseStore):
rows = yield self.db.execute(
"get_users_whose_signatures_changed", None, sql, user_id, from_key
)
- return {user for row in rows for user in json.loads(row[0])}
+ return {user for row in rows for user in db_to_json(row[0])}
else:
return set()
diff --git a/synapse/storage/data_stores/main/e2e_room_keys.py b/synapse/storage/data_stores/main/e2e_room_keys.py
index 23f4570c..615364f0 100644
--- a/synapse/storage/data_stores/main/e2e_room_keys.py
+++ b/synapse/storage/data_stores/main/e2e_room_keys.py
@@ -14,13 +14,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import json
+from canonicaljson import json
from twisted.internet import defer
from synapse.api.errors import StoreError
from synapse.logging.opentracing import log_kv, trace
-from synapse.storage._base import SQLBaseStore
+from synapse.storage._base import SQLBaseStore, db_to_json
class EndToEndRoomKeyStore(SQLBaseStore):
@@ -148,7 +148,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
"forwarded_count": row["forwarded_count"],
# is_verified must be returned to the client as a boolean
"is_verified": bool(row["is_verified"]),
- "session_data": json.loads(row["session_data"]),
+ "session_data": db_to_json(row["session_data"]),
}
return sessions
@@ -222,7 +222,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
"first_message_index": row[2],
"forwarded_count": row[3],
"is_verified": row[4],
- "session_data": json.loads(row[5]),
+ "session_data": db_to_json(row[5]),
}
return ret
@@ -319,7 +319,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
keyvalues={"user_id": user_id, "version": this_version, "deleted": 0},
retcols=("version", "algorithm", "auth_data", "etag"),
)
- result["auth_data"] = json.loads(result["auth_data"])
+ result["auth_data"] = db_to_json(result["auth_data"])
result["version"] = str(result["version"])
if result["etag"] is None:
result["etag"] = 0
diff --git a/synapse/storage/data_stores/main/end_to_end_keys.py b/synapse/storage/data_stores/main/end_to_end_keys.py
index 6c3cff82..317c07a8 100644
--- a/synapse/storage/data_stores/main/end_to_end_keys.py
+++ b/synapse/storage/data_stores/main/end_to_end_keys.py
@@ -366,7 +366,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
for row in rows:
user_id = row["user_id"]
key_type = row["keytype"]
- key = json.loads(row["keydata"])
+ key = db_to_json(row["keydata"])
user_info = result.setdefault(user_id, {})
user_info[key_type] = key
diff --git a/synapse/storage/data_stores/main/event_push_actions.py b/synapse/storage/data_stores/main/event_push_actions.py
index bc9f4f08..504babaa 100644
--- a/synapse/storage/data_stores/main/event_push_actions.py
+++ b/synapse/storage/data_stores/main/event_push_actions.py
@@ -21,7 +21,7 @@ from canonicaljson import json
from twisted.internet import defer
from synapse.metrics.background_process_metrics import run_as_background_process
-from synapse.storage._base import LoggingTransaction, SQLBaseStore
+from synapse.storage._base import LoggingTransaction, SQLBaseStore, db_to_json
from synapse.storage.database import Database
from synapse.util.caches.descriptors import cachedInlineCallbacks
@@ -58,7 +58,7 @@ def _deserialize_action(actions, is_highlight):
"""Custom deserializer for actions. This allows us to "compress" common actions
"""
if actions:
- return json.loads(actions)
+ return db_to_json(actions)
if is_highlight:
return DEFAULT_HIGHLIGHT_ACTION
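_deserialize_action is the read half of a scheme that avoids storing JSON for the two most common action sets: those are stored as an empty string, and the is_highlight column disambiguates which default applies on the way back out. A sketch of the corresponding write half; the constant values here are illustrative assumptions, not taken from this diff:

    import json

    # Assumed values for the defaults referenced above.
    DEFAULT_NOTIF_ACTION = ["notify", {"set_tweak": "highlight", "value": False}]
    DEFAULT_HIGHLIGHT_ACTION = [
        "notify",
        {"set_tweak": "sound", "value": "default"},
        {"set_tweak": "highlight"},
    ]

    def _serialize_action(actions, is_highlight):
        # Store "" for a default action set; is_highlight carries the rest.
        if is_highlight:
            if actions == DEFAULT_HIGHLIGHT_ACTION:
                return ""
        else:
            if actions == DEFAULT_NOTIF_ACTION:
                return ""
        return json.dumps(actions)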
diff --git a/synapse/storage/data_stores/main/events.py b/synapse/storage/data_stores/main/events.py
index 230fb5cd..6f2e0d15 100644
--- a/synapse/storage/data_stores/main/events.py
+++ b/synapse/storage/data_stores/main/events.py
@@ -17,11 +17,9 @@
import itertools
import logging
from collections import OrderedDict, namedtuple
-from functools import wraps
from typing import TYPE_CHECKING, Dict, Iterable, List, Tuple
import attr
-from canonicaljson import json
from prometheus_client import Counter
from twisted.internet import defer
@@ -33,7 +31,7 @@ from synapse.crypto.event_signing import compute_event_reference_hash
from synapse.events import EventBase # noqa: F401
from synapse.events.snapshot import EventContext # noqa: F401
from synapse.logging.utils import log_function
-from synapse.storage._base import make_in_list_sql_clause
+from synapse.storage._base import db_to_json, make_in_list_sql_clause
from synapse.storage.data_stores.main.search import SearchEntry
from synapse.storage.database import Database, LoggingTransaction
from synapse.storage.util.id_generators import StreamIdGenerator
@@ -69,27 +67,6 @@ def encode_json(json_object):
_EventCacheEntry = namedtuple("_EventCacheEntry", ("event", "redacted_event"))
-def _retry_on_integrity_error(func):
- """Wraps a database function so that it gets retried on IntegrityError,
- with `delete_existing=True` passed in.
-
- Args:
- func: function that returns a Deferred and accepts a `delete_existing` arg
- """
-
- @wraps(func)
- @defer.inlineCallbacks
- def f(self, *args, **kwargs):
- try:
- res = yield func(self, *args, delete_existing=False, **kwargs)
- except self.database_engine.module.IntegrityError:
- logger.exception("IntegrityError, retrying.")
- res = yield func(self, *args, delete_existing=True, **kwargs)
- return res
-
- return f
-
-
@attr.s(slots=True)
class DeltaState:
"""Deltas to use to update the `current_state_events` table.
@@ -134,7 +111,6 @@ class PersistEventsStore:
hs.config.worker.writers.events == hs.get_instance_name()
), "Can only instantiate EventsStore on master"
- @_retry_on_integrity_error
@defer.inlineCallbacks
def _persist_events_and_state_updates(
self,
@@ -143,7 +119,6 @@ class PersistEventsStore:
state_delta_for_room: Dict[str, DeltaState],
new_forward_extremeties: Dict[str, List[str]],
backfilled: bool = False,
- delete_existing: bool = False,
):
"""Persist a set of events alongside updates to the current state and
forward extremities tables.
@@ -157,7 +132,6 @@ class PersistEventsStore:
new_forward_extremities: Map from room_id to list of event IDs
that are the new forward extremities of the room.
backfilled
- delete_existing
Returns:
Deferred: resolves when the events have been persisted
@@ -197,7 +171,6 @@ class PersistEventsStore:
self._persist_events_txn,
events_and_contexts=events_and_contexts,
backfilled=backfilled,
- delete_existing=delete_existing,
state_delta_for_room=state_delta_for_room,
new_forward_extremeties=new_forward_extremeties,
)
@@ -262,7 +235,7 @@ class PersistEventsStore:
)
txn.execute(sql + clause, args)
- results.extend(r[0] for r in txn if not json.loads(r[1]).get("soft_failed"))
+ results.extend(r[0] for r in txn if not db_to_json(r[1]).get("soft_failed"))
for chunk in batch_iter(event_ids, 100):
yield self.db.runInteraction(
@@ -323,7 +296,7 @@ class PersistEventsStore:
if prev_event_id in existing_prevs:
continue
- soft_failed = json.loads(metadata).get("soft_failed")
+ soft_failed = db_to_json(metadata).get("soft_failed")
if soft_failed or rejected:
to_recursively_check.append(prev_event_id)
existing_prevs.add(prev_event_id)
@@ -341,7 +314,6 @@ class PersistEventsStore:
txn: LoggingTransaction,
events_and_contexts: List[Tuple[EventBase, EventContext]],
backfilled: bool,
- delete_existing: bool = False,
state_delta_for_room: Dict[str, DeltaState] = {},
new_forward_extremeties: Dict[str, List[str]] = {},
):
@@ -393,13 +365,6 @@ class PersistEventsStore:
# From this point onwards the events are only events that we haven't
# seen before.
- if delete_existing:
- # For paranoia reasons, we go and delete all the existing entries
- # for these events so we can reinsert them.
- # This gets around any problems with some tables already having
- # entries.
- self._delete_existing_rows_txn(txn, events_and_contexts=events_and_contexts)
-
self._store_event_txn(txn, events_and_contexts=events_and_contexts)
# Insert into event_to_state_groups.
@@ -617,7 +582,7 @@ class PersistEventsStore:
txn.execute(sql, (room_id, EventTypes.Create, ""))
row = txn.fetchone()
if row:
- event_json = json.loads(row[0])
+ event_json = db_to_json(row[0])
content = event_json.get("content", {})
creator = content.get("creator")
room_version_id = content.get("room_version", RoomVersions.V1.identifier)
@@ -797,39 +762,6 @@ class PersistEventsStore:
return [ec for ec in events_and_contexts if ec[0] not in to_remove]
- @classmethod
- def _delete_existing_rows_txn(cls, txn, events_and_contexts):
- if not events_and_contexts:
- # nothing to do here
- return
-
- logger.info("Deleting existing")
-
- for table in (
- "events",
- "event_auth",
- "event_json",
- "event_edges",
- "event_forward_extremities",
- "event_reference_hashes",
- "event_search",
- "event_to_state_groups",
- "state_events",
- "rejections",
- "redactions",
- "room_memberships",
- ):
- txn.executemany(
- "DELETE FROM %s WHERE event_id = ?" % (table,),
- [(ev.event_id,) for ev, _ in events_and_contexts],
- )
-
- for table in ("event_push_actions",):
- txn.executemany(
- "DELETE FROM %s WHERE room_id = ? AND event_id = ?" % (table,),
- [(ev.room_id, ev.event_id) for ev, _ in events_and_contexts],
- )
-
def _store_event_txn(self, txn, events_and_contexts):
"""Insert new events into the event and event_json tables
diff --git a/synapse/storage/data_stores/main/events_bg_updates.py b/synapse/storage/data_stores/main/events_bg_updates.py
index 62d28f44..663c94b2 100644
--- a/synapse/storage/data_stores/main/events_bg_updates.py
+++ b/synapse/storage/data_stores/main/events_bg_updates.py
@@ -15,12 +15,10 @@
import logging
-from canonicaljson import json
-
from twisted.internet import defer
from synapse.api.constants import EventContentFields
-from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause
+from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
from synapse.storage.database import Database
logger = logging.getLogger(__name__)
@@ -125,7 +123,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
for row in rows:
try:
event_id = row[1]
- event_json = json.loads(row[2])
+ event_json = db_to_json(row[2])
sender = event_json["sender"]
content = event_json["content"]
@@ -208,7 +206,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
for row in ev_rows:
event_id = row["event_id"]
- event_json = json.loads(row["json"])
+ event_json = db_to_json(row["json"])
try:
origin_server_ts = event_json["origin_server_ts"]
except (KeyError, AttributeError):
@@ -317,7 +315,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
soft_failed = False
if metadata:
- soft_failed = json.loads(metadata).get("soft_failed")
+ soft_failed = db_to_json(metadata).get("soft_failed")
if soft_failed or rejected:
soft_failed_events_to_lookup.add(event_id)
@@ -358,7 +356,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
graph[event_id] = {prev_event_id}
- soft_failed = json.loads(metadata).get("soft_failed")
+ soft_failed = db_to_json(metadata).get("soft_failed")
if soft_failed or rejected:
soft_failed_events_to_lookup.add(event_id)
else:
@@ -543,7 +541,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
last_row_event_id = ""
for (event_id, event_json_raw) in results:
try:
- event_json = json.loads(event_json_raw)
+ event_json = db_to_json(event_json_raw)
self.db.simple_insert_many_txn(
txn=txn,
diff --git a/synapse/storage/data_stores/main/events_worker.py b/synapse/storage/data_stores/main/events_worker.py
index 01cad7d4..e812c670 100644
--- a/synapse/storage/data_stores/main/events_worker.py
+++ b/synapse/storage/data_stores/main/events_worker.py
@@ -21,7 +21,6 @@ import threading
from collections import namedtuple
from typing import List, Optional, Tuple
-from canonicaljson import json
from constantly import NamedConstant, Names
from twisted.internet import defer
@@ -40,7 +39,7 @@ from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
from synapse.replication.tcp.streams import BackfillStream
from synapse.replication.tcp.streams.events import EventsStream
-from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause
+from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
from synapse.storage.database import Database
from synapse.storage.util.id_generators import StreamIdGenerator
from synapse.types import get_domain_from_id
@@ -611,8 +610,8 @@ class EventsWorkerStore(SQLBaseStore):
if not allow_rejected and rejected_reason:
continue
- d = json.loads(row["json"])
- internal_metadata = json.loads(row["internal_metadata"])
+ d = db_to_json(row["json"])
+ internal_metadata = db_to_json(row["internal_metadata"])
format_version = row["format_version"]
if format_version is None:
@@ -640,7 +639,7 @@ class EventsWorkerStore(SQLBaseStore):
else:
room_version = KNOWN_ROOM_VERSIONS.get(room_version_id)
if not room_version:
- logger.error(
+ logger.warning(
"Event %s in room %s has unknown room version %s",
event_id,
d["room_id"],
diff --git a/synapse/storage/data_stores/main/group_server.py b/synapse/storage/data_stores/main/group_server.py
index 4fb9f985..01ff561e 100644
--- a/synapse/storage/data_stores/main/group_server.py
+++ b/synapse/storage/data_stores/main/group_server.py
@@ -21,7 +21,7 @@ from canonicaljson import json
from twisted.internet import defer
from synapse.api.errors import SynapseError
-from synapse.storage._base import SQLBaseStore
+from synapse.storage._base import SQLBaseStore, db_to_json
# The category ID for the "default" category. We don't store as null in the
# database to avoid the fun of null != null
@@ -197,7 +197,7 @@ class GroupServerWorkerStore(SQLBaseStore):
categories = {
row[0]: {
"is_public": row[1],
- "profile": json.loads(row[2]),
+ "profile": db_to_json(row[2]),
"order": row[3],
}
for row in txn
@@ -221,7 +221,7 @@ class GroupServerWorkerStore(SQLBaseStore):
return {
row["category_id"]: {
"is_public": row["is_public"],
- "profile": json.loads(row["profile"]),
+ "profile": db_to_json(row["profile"]),
}
for row in rows
}
@@ -235,7 +235,7 @@ class GroupServerWorkerStore(SQLBaseStore):
desc="get_group_category",
)
- category["profile"] = json.loads(category["profile"])
+ category["profile"] = db_to_json(category["profile"])
return category
@@ -251,7 +251,7 @@ class GroupServerWorkerStore(SQLBaseStore):
return {
row["role_id"]: {
"is_public": row["is_public"],
- "profile": json.loads(row["profile"]),
+ "profile": db_to_json(row["profile"]),
}
for row in rows
}
@@ -265,7 +265,7 @@ class GroupServerWorkerStore(SQLBaseStore):
desc="get_group_role",
)
- role["profile"] = json.loads(role["profile"])
+ role["profile"] = db_to_json(role["profile"])
return role
@@ -333,7 +333,7 @@ class GroupServerWorkerStore(SQLBaseStore):
roles = {
row[0]: {
"is_public": row[1],
- "profile": json.loads(row[2]),
+ "profile": db_to_json(row[2]),
"order": row[3],
}
for row in txn
@@ -462,7 +462,7 @@ class GroupServerWorkerStore(SQLBaseStore):
now = int(self._clock.time_msec())
if row and now < row["valid_until_ms"]:
- return json.loads(row["attestation_json"])
+ return db_to_json(row["attestation_json"])
return None
@@ -489,7 +489,7 @@ class GroupServerWorkerStore(SQLBaseStore):
"group_id": row[0],
"type": row[1],
"membership": row[2],
- "content": json.loads(row[3]),
+ "content": db_to_json(row[3]),
}
for row in txn
]
@@ -519,7 +519,7 @@ class GroupServerWorkerStore(SQLBaseStore):
"group_id": group_id,
"membership": membership,
"type": gtype,
- "content": json.loads(content_json),
+ "content": db_to_json(content_json),
}
for group_id, membership, gtype, content_json in txn
]
@@ -567,7 +567,7 @@ class GroupServerWorkerStore(SQLBaseStore):
"""
txn.execute(sql, (last_id, current_id, limit))
updates = [
- (stream_id, (group_id, user_id, gtype, json.loads(content_json)))
+ (stream_id, (group_id, user_id, gtype, db_to_json(content_json)))
for stream_id, group_id, user_id, gtype, content_json in txn
]
diff --git a/synapse/storage/data_stores/main/push_rule.py b/synapse/storage/data_stores/main/push_rule.py
index f6e78ca5..c2292481 100644
--- a/synapse/storage/data_stores/main/push_rule.py
+++ b/synapse/storage/data_stores/main/push_rule.py
@@ -24,7 +24,7 @@ from twisted.internet import defer
from synapse.push.baserules import list_with_base_rules
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
-from synapse.storage._base import SQLBaseStore
+from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.storage.data_stores.main.appservice import ApplicationServiceWorkerStore
from synapse.storage.data_stores.main.events_worker import EventsWorkerStore
from synapse.storage.data_stores.main.pusher import PusherWorkerStore
@@ -43,8 +43,8 @@ def _load_rules(rawrules, enabled_map):
ruleslist = []
for rawrule in rawrules:
rule = dict(rawrule)
- rule["conditions"] = json.loads(rawrule["conditions"])
- rule["actions"] = json.loads(rawrule["actions"])
+ rule["conditions"] = db_to_json(rawrule["conditions"])
+ rule["actions"] = db_to_json(rawrule["actions"])
rule["default"] = False
ruleslist.append(rule)
@@ -259,7 +259,7 @@ class PushRulesWorkerStore(
# To do this we set the state_group to a new object as object() != object()
state_group = object()
- current_state_ids = yield context.get_current_state_ids()
+ current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids())
result = yield self._bulk_get_push_rules_for_room(
event.room_id, state_group, current_state_ids, event=event
)
diff --git a/synapse/storage/data_stores/main/pusher.py b/synapse/storage/data_stores/main/pusher.py
index 54610162..e18f1ca8 100644
--- a/synapse/storage/data_stores/main/pusher.py
+++ b/synapse/storage/data_stores/main/pusher.py
@@ -17,11 +17,11 @@
import logging
from typing import Iterable, Iterator, List, Tuple
-from canonicaljson import encode_canonical_json, json
+from canonicaljson import encode_canonical_json
from twisted.internet import defer
-from synapse.storage._base import SQLBaseStore
+from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList
logger = logging.getLogger(__name__)
@@ -36,7 +36,7 @@ class PusherWorkerStore(SQLBaseStore):
for r in rows:
dataJson = r["data"]
try:
- r["data"] = json.loads(dataJson)
+ r["data"] = db_to_json(dataJson)
except Exception as e:
logger.warning(
"Invalid JSON in data for pusher %d: %s, %s",
diff --git a/synapse/storage/data_stores/main/receipts.py b/synapse/storage/data_stores/main/receipts.py
index 8f5505bd..1d723f2d 100644
--- a/synapse/storage/data_stores/main/receipts.py
+++ b/synapse/storage/data_stores/main/receipts.py
@@ -22,7 +22,7 @@ from canonicaljson import json
from twisted.internet import defer
-from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause
+from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
from synapse.storage.database import Database
from synapse.storage.util.id_generators import StreamIdGenerator
from synapse.util.async_helpers import ObservableDeferred
@@ -203,7 +203,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
for row in rows:
content.setdefault(row["event_id"], {}).setdefault(row["receipt_type"], {})[
row["user_id"]
- ] = json.loads(row["data"])
+ ] = db_to_json(row["data"])
return [{"type": "m.receipt", "room_id": room_id, "content": content}]
@@ -260,7 +260,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
event_entry = room_event["content"].setdefault(row["event_id"], {})
receipt_type = event_entry.setdefault(row["receipt_type"], {})
- receipt_type[row["user_id"]] = json.loads(row["data"])
+ receipt_type[row["user_id"]] = db_to_json(row["data"])
results = {
room_id: [results[room_id]] if room_id in results else []
@@ -329,7 +329,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
"""
txn.execute(sql, (last_id, current_id, limit))
- updates = [(r[0], r[1:5] + (json.loads(r[5]),)) for r in txn]
+ updates = [(r[0], r[1:5] + (db_to_json(r[5]),)) for r in txn]
limited = False
upper_bound = current_id
diff --git a/synapse/storage/data_stores/main/registration.py b/synapse/storage/data_stores/main/registration.py
index 587d4b91..27d2c502 100644
--- a/synapse/storage/data_stores/main/registration.py
+++ b/synapse/storage/data_stores/main/registration.py
@@ -27,6 +27,8 @@ from synapse.api.errors import Codes, StoreError, SynapseError, ThreepidValidati
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage._base import SQLBaseStore
from synapse.storage.database import Database
+from synapse.storage.types import Cursor
+from synapse.storage.util.sequence import build_sequence_generator
from synapse.types import UserID
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
@@ -42,6 +44,10 @@ class RegistrationWorkerStore(SQLBaseStore):
self.config = hs.config
self.clock = hs.get_clock()
+ self._user_id_seq = build_sequence_generator(
+ database.engine, find_max_generated_user_id_localpart, "user_id_seq",
+ )
+
@cached()
def get_user_by_id(self, user_id):
return self.db.simple_select_one(
@@ -481,39 +487,17 @@ class RegistrationWorkerStore(SQLBaseStore):
ret = yield self.db.runInteraction("count_real_users", _count_users)
return ret
- @defer.inlineCallbacks
- def find_next_generated_user_id_localpart(self):
- """
- Gets the localpart of the next generated user ID.
+ async def generate_user_id(self) -> str:
+ """Generate a suitable localpart for a guest user
- Generated user IDs are integers, so we find the largest integer user ID
- already taken and return that plus one.
+ Returns: a (hopefully) free localpart
"""
-
- def _find_next_generated_user_id(txn):
- # We bound between '@0' and '@a' to avoid pulling the entire table
- # out.
- txn.execute("SELECT name FROM users WHERE '@0' <= name AND name < '@a'")
-
- regex = re.compile(r"^@(\d+):")
-
- max_found = 0
-
- for (user_id,) in txn:
- match = regex.search(user_id)
- if match:
- max_found = max(int(match.group(1)), max_found)
-
- return max_found + 1
-
- return (
- (
- yield self.db.runInteraction(
- "find_next_generated_user_id", _find_next_generated_user_id
- )
- )
+ next_id = await self.db.runInteraction(
+ "generate_user_id", self._user_id_seq.get_next_id_txn
)
+ return str(next_id)
+
async def get_user_id_by_threepid(self, medium: str, address: str) -> Optional[str]:
"""Returns user id from threepid
@@ -1573,3 +1557,26 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
keyvalues={"user_id": user_id},
values={"expiration_ts_ms": expiration_ts, "email_sent": False},
)
+
+
+def find_max_generated_user_id_localpart(cur: Cursor) -> int:
+ """
+ Gets the localpart of the max current generated user ID.
+
+ Generated user IDs are integers, so we find the largest integer user ID
+ already taken and return that.
+ """
+
+ # We bound between '@0' and '@a' to avoid pulling the entire table
+ # out.
+ cur.execute("SELECT name FROM users WHERE '@0' <= name AND name < '@a'")
+
+ regex = re.compile(r"^@(\d+):")
+
+ max_found = 0
+
+ for (user_id,) in cur:
+ match = regex.search(user_id)
+ if match:
+ max_found = max(int(match.group(1)), max_found)
+ return max_found
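Note that the rewritten helper returns the current maximum rather than max-plus-one: the increment now lives in the sequence generator (or, on postgres, in the SEQUENCE created by the 58/11 delta below). A runnable toy of the scan itself, using an in-memory sqlite table as a stand-in for the users table:

    import re
    import sqlite3

    def find_max_generated_user_id_localpart(cur):
        # Same logic as above: scan the narrow '@0' <= name < '@a' range and
        # take the largest purely-numeric localpart already in use.
        cur.execute("SELECT name FROM users WHERE '@0' <= name AND name < '@a'")
        regex = re.compile(r"^@(\d+):")
        max_found = 0
        for (user_id,) in cur:
            match = regex.search(user_id)
            if match:
                max_found = max(int(match.group(1)), max_found)
        return max_found

    conn = sqlite3.connect(":memory:")
    conn.execute("CREATE TABLE users (name TEXT)")
    conn.executemany(
        "INSERT INTO users VALUES (?)",
        [("@41:example.com",), ("@alice:example.com",)],
    )
    assert find_max_generated_user_id_localpart(conn.cursor()) == 41
    # generate_user_id() would then hand out "42" via the sequence generator.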
diff --git a/synapse/storage/data_stores/main/room.py b/synapse/storage/data_stores/main/room.py
index c473cf15..d2e1e36e 100644
--- a/synapse/storage/data_stores/main/room.py
+++ b/synapse/storage/data_stores/main/room.py
@@ -28,7 +28,7 @@ from twisted.internet import defer
from synapse.api.constants import EventTypes
from synapse.api.errors import StoreError
from synapse.api.room_versions import RoomVersion, RoomVersions
-from synapse.storage._base import SQLBaseStore
+from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.storage.data_stores.main.search import SearchStore
from synapse.storage.database import Database, LoggingTransaction
from synapse.types import ThirdPartyInstanceID
@@ -118,7 +118,12 @@ class RoomWorkerStore(SQLBaseStore):
WHERE room_id = ?
"""
txn.execute(sql, [room_id])
- res = self.db.cursor_to_dict(txn)[0]
+ # If the query returns no rows, return None rather than raising
+ try:
+ res = self.db.cursor_to_dict(txn)[0]
+ except IndexError:
+ return None
+
res["federatable"] = bool(res["federatable"])
res["public"] = bool(res["public"])
return res
@@ -665,7 +670,7 @@ class RoomWorkerStore(SQLBaseStore):
next_token = None
for stream_ordering, content_json in txn:
next_token = stream_ordering
- event_json = json.loads(content_json)
+ event_json = db_to_json(content_json)
content = event_json["content"]
content_url = content.get("url")
thumbnail_url = content.get("info", {}).get("thumbnail_url")
@@ -910,8 +915,8 @@ class RoomBackgroundUpdateStore(SQLBaseStore):
if not row["json"]:
retention_policy = {}
else:
- ev = json.loads(row["json"])
- retention_policy = json.dumps(ev["content"])
+ ev = db_to_json(row["json"])
+ retention_policy = ev["content"]
self.db.simple_insert_txn(
txn=txn,
@@ -966,7 +971,7 @@ class RoomBackgroundUpdateStore(SQLBaseStore):
updates = []
for room_id, event_json in txn:
- event_dict = json.loads(event_json)
+ event_dict = db_to_json(event_json)
room_version_id = event_dict.get("content", {}).get(
"room_version", RoomVersions.V1.identifier
)
diff --git a/synapse/storage/data_stores/main/roommember.py b/synapse/storage/data_stores/main/roommember.py
index 44bab65e..a92e401e 100644
--- a/synapse/storage/data_stores/main/roommember.py
+++ b/synapse/storage/data_stores/main/roommember.py
@@ -17,8 +17,6 @@
import logging
from typing import Iterable, List, Set
-from canonicaljson import json
-
from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership
@@ -27,6 +25,7 @@ from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage._base import (
LoggingTransaction,
SQLBaseStore,
+ db_to_json,
make_in_list_sql_clause,
)
from synapse.storage.data_stores.main.events_worker import EventsWorkerStore
@@ -498,7 +497,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
# To do this we set the state_group to a new object as object() != object()
state_group = object()
- current_state_ids = yield context.get_current_state_ids()
+ current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids())
result = yield self._get_joined_users_from_context(
event.room_id, state_group, current_state_ids, event=event, context=context
)
@@ -938,7 +937,7 @@ class RoomMemberBackgroundUpdateStore(SQLBaseStore):
event_id = row["event_id"]
room_id = row["room_id"]
try:
- event_json = json.loads(row["json"])
+ event_json = db_to_json(row["json"])
content = event_json["content"]
except Exception:
continue
diff --git a/synapse/storage/data_stores/main/schema/delta/58/10federation_pos_instance_name.sql b/synapse/storage/data_stores/main/schema/delta/58/10federation_pos_instance_name.sql
new file mode 100644
index 00000000..1cc2633a
--- /dev/null
+++ b/synapse/storage/data_stores/main/schema/delta/58/10federation_pos_instance_name.sql
@@ -0,0 +1,22 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- We need to store the stream positions by instance in a sharded config world.
+--
+-- We default to master as we want the column to be NOT NULL and we correctly
+-- reset the instance name to match the config each time we start up.
+ALTER TABLE federation_stream_position ADD COLUMN instance_name TEXT NOT NULL DEFAULT 'master';
+
+CREATE UNIQUE INDEX federation_stream_position_instance ON federation_stream_position(type, instance_name);
diff --git a/synapse/storage/data_stores/main/schema/delta/58/11user_id_seq.py b/synapse/storage/data_stores/main/schema/delta/58/11user_id_seq.py
new file mode 100644
index 00000000..2011f6bc
--- /dev/null
+++ b/synapse/storage/data_stores/main/schema/delta/58/11user_id_seq.py
@@ -0,0 +1,34 @@
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Adds a postgres SEQUENCE for generating guest user IDs.
+"""
+
+from synapse.storage.data_stores.main.registration import (
+ find_max_generated_user_id_localpart,
+)
+from synapse.storage.engines import PostgresEngine
+
+
+def run_create(cur, database_engine, *args, **kwargs):
+ if not isinstance(database_engine, PostgresEngine):
+ return
+
+ next_id = find_max_generated_user_id_localpart(cur) + 1
+ cur.execute("CREATE SEQUENCE user_id_seq START WITH %s", (next_id,))
+
+
+def run_upgrade(*args, **kwargs):
+ pass
diff --git a/synapse/storage/data_stores/main/search.py b/synapse/storage/data_stores/main/search.py
index a8381dc5..d5222829 100644
--- a/synapse/storage/data_stores/main/search.py
+++ b/synapse/storage/data_stores/main/search.py
@@ -17,12 +17,10 @@ import logging
import re
from collections import namedtuple
-from canonicaljson import json
-
from twisted.internet import defer
from synapse.api.errors import SynapseError
-from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause
+from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
from synapse.storage.data_stores.main.events_worker import EventRedactBehaviour
from synapse.storage.database import Database
from synapse.storage.engines import PostgresEngine, Sqlite3Engine
@@ -157,7 +155,7 @@ class SearchBackgroundUpdateStore(SearchWorkerStore):
stream_ordering = row["stream_ordering"]
origin_server_ts = row["origin_server_ts"]
try:
- event_json = json.loads(row["json"])
+ event_json = db_to_json(row["json"])
content = event_json["content"]
except Exception:
continue
diff --git a/synapse/storage/data_stores/main/state.py b/synapse/storage/data_stores/main/state.py
index 347cc507..bb38a04e 100644
--- a/synapse/storage/data_stores/main/state.py
+++ b/synapse/storage/data_stores/main/state.py
@@ -353,6 +353,7 @@ class MainStateBackgroundUpdateStore(RoomMemberWorkerStore):
last_room_id = progress.get("last_room_id", "")
def _background_remove_left_rooms_txn(txn):
+ # get a batch of room ids to consider
sql = """
SELECT DISTINCT room_id FROM current_state_events
WHERE room_id > ? ORDER BY room_id LIMIT ?
@@ -363,24 +364,68 @@ class MainStateBackgroundUpdateStore(RoomMemberWorkerStore):
if not room_ids:
return True, set()
+ ###########################################################################
+ #
+ # exclude rooms where we have active members
+
sql = """
SELECT room_id
- FROM current_state_events
+ FROM local_current_membership
WHERE
room_id > ? AND room_id <= ?
- AND type = 'm.room.member'
AND membership = 'join'
- AND state_key LIKE ?
GROUP BY room_id
"""
- txn.execute(sql, (last_room_id, room_ids[-1], "%:" + self.server_name))
-
+ txn.execute(sql, (last_room_id, room_ids[-1]))
joined_room_ids = {row[0] for row in txn}
+ to_delete = set(room_ids) - joined_room_ids
+
+ ###########################################################################
+ #
+ # exclude rooms which we are in the process of constructing; these otherwise
+ # qualify as "rooms with no local users", and would have their
+ # forward extremities cleaned up.
+
+ # the following query will return a list of rooms which have forward
+ # extremities that are *not* also the create event in the room - ie
+ # those that are not being created currently.
+
+ sql = """
+ SELECT DISTINCT efe.room_id
+ FROM event_forward_extremities efe
+ LEFT JOIN current_state_events cse ON
+ cse.event_id = efe.event_id
+ AND cse.type = 'm.room.create'
+ AND cse.state_key = ''
+ WHERE
+ cse.event_id IS NULL
+ AND efe.room_id > ? AND efe.room_id <= ?
+ """
+
+ txn.execute(sql, (last_room_id, room_ids[-1]))
+
+ # build a set of those rooms within `to_delete` that do not appear in
+ # the above, leaving us with the rooms in `to_delete` that *are* being
+ # created.
+ creating_rooms = to_delete.difference(row[0] for row in txn)
+ logger.info("skipping rooms which are being created: %s", creating_rooms)
+
+ # now remove the rooms being created from the list of those to delete.
+ #
+ # (we could have just taken the intersection of `to_delete` with the result
+ # of the sql query, but it's useful to be able to log `creating_rooms`; and
+ # having done so, it's quicker to remove the (few) creating rooms from
+ # `to_delete` than it is to form the intersection with the (larger) list of
+ # not-creating-rooms)
+
+ to_delete -= creating_rooms
- left_rooms = set(room_ids) - joined_room_ids
+ ###########################################################################
+ #
+ # now clear the state for the rooms
- logger.info("Deleting current state left rooms: %r", left_rooms)
+ logger.info("Deleting current state left rooms: %r", to_delete)
# First we get all users that we still think were joined to the
# room. This is so that we can mark those device lists as
@@ -391,7 +436,7 @@ class MainStateBackgroundUpdateStore(RoomMemberWorkerStore):
txn,
table="current_state_events",
column="room_id",
- iterable=left_rooms,
+ iterable=to_delete,
keyvalues={"type": EventTypes.Member, "membership": Membership.JOIN},
retcols=("state_key",),
)
@@ -403,7 +448,7 @@ class MainStateBackgroundUpdateStore(RoomMemberWorkerStore):
txn,
table="current_state_events",
column="room_id",
- iterable=left_rooms,
+ iterable=to_delete,
keyvalues={},
)
@@ -411,7 +456,7 @@ class MainStateBackgroundUpdateStore(RoomMemberWorkerStore):
txn,
table="event_forward_extremities",
column="room_id",
- iterable=left_rooms,
+ iterable=to_delete,
keyvalues={},
)
diff --git a/synapse/storage/data_stores/main/stream.py b/synapse/storage/data_stores/main/stream.py
index 379d758b..10d39b36 100644
--- a/synapse/storage/data_stores/main/stream.py
+++ b/synapse/storage/data_stores/main/stream.py
@@ -45,7 +45,7 @@ from twisted.internet import defer
from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.storage._base import SQLBaseStore
from synapse.storage.data_stores.main.events_worker import EventsWorkerStore
-from synapse.storage.database import Database
+from synapse.storage.database import Database, make_in_list_sql_clause
from synapse.storage.engines import PostgresEngine
from synapse.types import RoomStreamToken
from synapse.util.caches.stream_change_cache import StreamChangeCache
@@ -253,6 +253,16 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
def __init__(self, database: Database, db_conn, hs):
super(StreamWorkerStore, self).__init__(database, db_conn, hs)
+ self._instance_name = hs.get_instance_name()
+ self._send_federation = hs.should_send_federation()
+ self._federation_shard_config = hs.config.worker.federation_shard_config
+
+ # If we're a process that sends federation we may need to reset the
+ # `federation_stream_position` table to match the current sharding
+ # config. We don't do this now as otherwise two processes could conflict
+ # during startup which would cause one to die.
+ self._need_to_reset_federation_stream_positions = self._send_federation
+
events_max = self.get_room_max_stream_ordering()
event_cache_prefill, min_event_val = self.db.get_cache_dict(
db_conn,
@@ -793,22 +803,95 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
return upper_bound, events
- def get_federation_out_pos(self, typ):
- return self.db.simple_select_one_onecol(
+ async def get_federation_out_pos(self, typ: str) -> int:
+ if self._need_to_reset_federation_stream_positions:
+ await self.db.runInteraction(
+ "_reset_federation_positions_txn", self._reset_federation_positions_txn
+ )
+ self._need_to_reset_federation_stream_positions = False
+
+ return await self.db.simple_select_one_onecol(
table="federation_stream_position",
retcol="stream_id",
- keyvalues={"type": typ},
+ keyvalues={"type": typ, "instance_name": self._instance_name},
desc="get_federation_out_pos",
)
- def update_federation_out_pos(self, typ, stream_id):
- return self.db.simple_update_one(
+ async def update_federation_out_pos(self, typ, stream_id):
+ if self._need_to_reset_federation_stream_positions:
+ await self.db.runInteraction(
+ "_reset_federation_positions_txn", self._reset_federation_positions_txn
+ )
+ self._need_to_reset_federation_stream_positions = False
+
+ return await self.db.simple_update_one(
table="federation_stream_position",
- keyvalues={"type": typ},
+ keyvalues={"type": typ, "instance_name": self._instance_name},
updatevalues={"stream_id": stream_id},
desc="update_federation_out_pos",
)
+ def _reset_federation_positions_txn(self, txn):
+ """Fiddles with the `federation_stream_position` table to make it match
+ the configured federation sender instances during start up.
+ """
+
+ # The federation sender instances may have changed, so we need to
+ # massage the `federation_stream_position` table to have a row per type
+ # per instance sending federation. If there is a mismatch we update the
+ # table with the correct rows using the *minimum* stream ID seen. This
+ # may result in resending of events/EDUs to remote servers, but that is
+ # preferable to dropping them.
+
+ if not self._send_federation:
+ return
+
+ # Pull out the configured instances. If we don't have a shard config then
+ # we assume that we're the only instance sending.
+ configured_instances = self._federation_shard_config.instances
+ if not configured_instances:
+ configured_instances = [self._instance_name]
+ elif self._instance_name not in configured_instances:
+ return
+
+ instances_in_table = self.db.simple_select_onecol_txn(
+ txn,
+ table="federation_stream_position",
+ keyvalues={},
+ retcol="instance_name",
+ )
+
+ if set(instances_in_table) == set(configured_instances):
+ # Nothing to do
+ return
+
+ sql = """
+ SELECT type, MIN(stream_id) FROM federation_stream_position
+ GROUP BY type
+ """
+ txn.execute(sql)
+ min_positions = dict(txn) # Map from type -> min position
+
+ # Ensure we do actually have some values here
+ assert set(min_positions) == {"federation", "events"}
+
+ sql = """
+ DELETE FROM federation_stream_position
+ WHERE NOT (%s)
+ """
+ clause, args = make_in_list_sql_clause(
+ txn.database_engine, "instance_name", configured_instances
+ )
+ txn.execute(sql % (clause,), args)
+
+ for typ, stream_id in min_positions.items():
+ self.db.simple_upsert_txn(
+ txn,
+ table="federation_stream_position",
+ keyvalues={"type": typ, "instance_name": self._instance_name},
+ values={"stream_id": stream_id},
+ )
+
def has_room_changed_since(self, room_id, stream_id):
return self._events_stream_cache.has_entity_changed(room_id, stream_id)
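In _reset_federation_positions_txn above, the DELETE keeps only rows for the configured instances; make_in_list_sql_clause papers over the difference in how postgres and sqlite bind a list parameter. A sketch of the shapes it returns, based on the Synapse helper (exact strings may differ):

    # sqlite-style engines expand the iterable into placeholders:
    #   ("instance_name IN (?,?)", ["master", "fed1"])
    # postgres binds the whole list as a single array argument:
    #   ("instance_name = ANY(?)", [["master", "fed1"]])
    clause, args = make_in_list_sql_clause(
        txn.database_engine, "instance_name", configured_instances
    )
    txn.execute(
        "DELETE FROM federation_stream_position WHERE NOT (%s)" % (clause,), args
    )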
diff --git a/synapse/storage/data_stores/main/tags.py b/synapse/storage/data_stores/main/tags.py
index 290317fd..bd722777 100644
--- a/synapse/storage/data_stores/main/tags.py
+++ b/synapse/storage/data_stores/main/tags.py
@@ -21,6 +21,7 @@ from canonicaljson import json
from twisted.internet import defer
+from synapse.storage._base import db_to_json
from synapse.storage.data_stores.main.account_data import AccountDataWorkerStore
from synapse.util.caches.descriptors import cached
@@ -49,7 +50,7 @@ class TagsWorkerStore(AccountDataWorkerStore):
tags_by_room = {}
for row in rows:
room_tags = tags_by_room.setdefault(row["room_id"], {})
- room_tags[row["tag"]] = json.loads(row["content"])
+ room_tags[row["tag"]] = db_to_json(row["content"])
return tags_by_room
return deferred
@@ -180,7 +181,7 @@ class TagsWorkerStore(AccountDataWorkerStore):
retcols=("tag", "content"),
desc="get_tags_for_room",
).addCallback(
- lambda rows: {row["tag"]: json.loads(row["content"]) for row in rows}
+ lambda rows: {row["tag"]: db_to_json(row["content"]) for row in rows}
)
diff --git a/synapse/storage/data_stores/main/ui_auth.py b/synapse/storage/data_stores/main/ui_auth.py
index 4c044b1a..5f1b9197 100644
--- a/synapse/storage/data_stores/main/ui_auth.py
+++ b/synapse/storage/data_stores/main/ui_auth.py
@@ -12,13 +12,13 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-import json
from typing import Any, Dict, Optional, Union
import attr
+from canonicaljson import json
from synapse.api.errors import StoreError
-from synapse.storage._base import SQLBaseStore
+from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.types import JsonDict
from synapse.util import stringutils as stringutils
@@ -118,7 +118,7 @@ class UIAuthWorkerStore(SQLBaseStore):
desc="get_ui_auth_session",
)
- result["clientdict"] = json.loads(result["clientdict"])
+ result["clientdict"] = db_to_json(result["clientdict"])
return UIAuthSessionData(session_id, **result)
@@ -168,7 +168,7 @@ class UIAuthWorkerStore(SQLBaseStore):
retcols=("stage_type", "result"),
desc="get_completed_ui_auth_stages",
):
- results[row["stage_type"]] = json.loads(row["result"])
+ results[row["stage_type"]] = db_to_json(row["result"])
return results
@@ -224,7 +224,7 @@ class UIAuthWorkerStore(SQLBaseStore):
)
# Update it and add it back to the database.
- serverdict = json.loads(result["serverdict"])
+ serverdict = db_to_json(result["serverdict"])
serverdict[key] = value
self.db.simple_update_one_txn(
@@ -254,7 +254,7 @@ class UIAuthWorkerStore(SQLBaseStore):
desc="get_ui_auth_session_data",
)
- serverdict = json.loads(result["serverdict"])
+ serverdict = db_to_json(result["serverdict"])
return serverdict.get(key, default)
diff --git a/synapse/storage/data_stores/main/user_directory.py b/synapse/storage/data_stores/main/user_directory.py
index 6b8130bf..942e51fd 100644
--- a/synapse/storage/data_stores/main/user_directory.py
+++ b/synapse/storage/data_stores/main/user_directory.py
@@ -198,7 +198,9 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
room_id
)
- users_with_profile = yield state.get_current_users_in_room(room_id)
+ users_with_profile = yield defer.ensureDeferred(
+ state.get_current_users_in_room(room_id)
+ )
user_ids = set(users_with_profile)
# Update each user in the user directory.
diff --git a/synapse/storage/data_stores/main/user_erasure_store.py b/synapse/storage/data_stores/main/user_erasure_store.py
index ec6b8a4f..d3038ff0 100644
--- a/synapse/storage/data_stores/main/user_erasure_store.py
+++ b/synapse/storage/data_stores/main/user_erasure_store.py
@@ -70,11 +70,11 @@ class UserErasureWorkerStore(SQLBaseStore):
class UserErasureStore(UserErasureWorkerStore):
- def mark_user_erased(self, user_id):
+ def mark_user_erased(self, user_id: str) -> None:
"""Indicate that user_id wishes their message history to be erased.
Args:
- user_id (str): full user_id to be erased
+ user_id: full user_id to be erased
"""
def f(txn):
@@ -89,3 +89,25 @@ class UserErasureStore(UserErasureWorkerStore):
self._invalidate_cache_and_stream(txn, self.is_user_erased, (user_id,))
return self.db.runInteraction("mark_user_erased", f)
+
+ def mark_user_not_erased(self, user_id: str) -> None:
+ """Indicate that user_id is no longer erased.
+
+ Args:
+ user_id: full user_id to be un-erased
+ """
+
+ def f(txn):
+ # first check if they are already in the list
+ txn.execute("SELECT 1 FROM erased_users WHERE user_id = ?", (user_id,))
+ if not txn.fetchone():
+ return
+
+ # They are there, delete them.
+ self.simple_delete_one_txn(
+ txn, "erased_users", keyvalues={"user_id": user_id}
+ )
+
+ self._invalidate_cache_and_stream(txn, self.is_user_erased, (user_id,))
+
+ return self.db.runInteraction("mark_user_not_erased", f)
diff --git a/synapse/storage/data_stores/state/store.py b/synapse/storage/data_stores/state/store.py
index 5db9f201..128c09a2 100644
--- a/synapse/storage/data_stores/state/store.py
+++ b/synapse/storage/data_stores/state/store.py
@@ -24,6 +24,8 @@ from synapse.storage._base import SQLBaseStore
from synapse.storage.data_stores.state.bg_updates import StateBackgroundUpdateStore
from synapse.storage.database import Database
from synapse.storage.state import StateFilter
+from synapse.storage.types import Cursor
+from synapse.storage.util.sequence import build_sequence_generator
from synapse.types import StateMap
from synapse.util.caches.descriptors import cached
from synapse.util.caches.dictionary_cache import DictionaryCache
@@ -92,6 +94,14 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
"*stateGroupMembersCache*", 500000,
)
+ def get_max_state_group_txn(txn: Cursor):
+ txn.execute("SELECT COALESCE(max(id), 0) FROM state_groups")
+ return txn.fetchone()[0]
+
+ self._state_group_seq_gen = build_sequence_generator(
+ self.database_engine, get_max_state_group_txn, "state_group_id_seq"
+ )
+
@cached(max_entries=10000, iterable=True)
def get_state_group_delta(self, state_group):
"""Given a state group try to return a previous group and a delta between
@@ -386,7 +396,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
# AFAIK, this can never happen
raise Exception("current_state_ids cannot be None")
- state_group = self.database_engine.get_next_state_group_id(txn)
+ state_group = self._state_group_seq_gen.get_next_id_txn(txn)
self.db.simple_insert_txn(
txn,
diff --git a/synapse/storage/engines/_base.py b/synapse/storage/engines/_base.py
index ab0bbe4b..908cbc79 100644
--- a/synapse/storage/engines/_base.py
+++ b/synapse/storage/engines/_base.py
@@ -91,12 +91,6 @@ class BaseDatabaseEngine(Generic[ConnectionType], metaclass=abc.ABCMeta):
def lock_table(self, txn, table: str) -> None:
...
- @abc.abstractmethod
- def get_next_state_group_id(self, txn) -> int:
- """Returns an int that can be used as a new state_group ID
- """
- ...
-
@property
@abc.abstractmethod
def server_version(self) -> str:
diff --git a/synapse/storage/engines/postgres.py b/synapse/storage/engines/postgres.py
index a3158808..ff39281f 100644
--- a/synapse/storage/engines/postgres.py
+++ b/synapse/storage/engines/postgres.py
@@ -154,12 +154,6 @@ class PostgresEngine(BaseDatabaseEngine):
def lock_table(self, txn, table):
txn.execute("LOCK TABLE %s in EXCLUSIVE MODE" % (table,))
- def get_next_state_group_id(self, txn):
- """Returns an int that can be used as a new state_group ID
- """
- txn.execute("SELECT nextval('state_group_id_seq')")
- return txn.fetchone()[0]
-
@property
def server_version(self):
"""Returns a string giving the server version. For example: '8.1.5'
diff --git a/synapse/storage/engines/sqlite.py b/synapse/storage/engines/sqlite.py
index 215a9494..8a0f8c89 100644
--- a/synapse/storage/engines/sqlite.py
+++ b/synapse/storage/engines/sqlite.py
@@ -96,19 +96,6 @@ class Sqlite3Engine(BaseDatabaseEngine["sqlite3.Connection"]):
def lock_table(self, txn, table):
return
- def get_next_state_group_id(self, txn):
- """Returns an int that can be used as a new state_group ID
- """
- # We do application locking here since if we're using sqlite then
- # we are a single process synapse.
- with self._current_state_group_id_lock:
- if self._current_state_group_id is None:
- txn.execute("SELECT COALESCE(max(id), 0) FROM state_groups")
- self._current_state_group_id = txn.fetchone()[0]
-
- self._current_state_group_id += 1
- return self._current_state_group_id
-
@property
def server_version(self):
"""Gets a string giving the server version. For example: '3.22.0'
diff --git a/synapse/storage/persist_events.py b/synapse/storage/persist_events.py
index fa460416..78fbdcde 100644
--- a/synapse/storage/persist_events.py
+++ b/synapse/storage/persist_events.py
@@ -29,7 +29,6 @@ from synapse.events import FrozenEvent
from synapse.events.snapshot import EventContext
from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable
from synapse.metrics.background_process_metrics import run_as_background_process
-from synapse.state import StateResolutionStore
from synapse.storage.data_stores import DataStores
from synapse.storage.data_stores.main.events import DeltaState
from synapse.types import StateMap
@@ -648,6 +647,10 @@ class EventsPersistenceStorage(object):
room_version = await self.main_store.get_room_version_id(room_id)
logger.debug("calling resolve_state_groups from preserve_events")
+
+ # Avoid a circular import.
+ from synapse.state import StateResolutionStore
+
res = await self._state_resolution_handler.resolve_state_groups(
room_id,
room_version,
diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py
index f89ce0be..787cebfb 100644
--- a/synapse/storage/util/id_generators.py
+++ b/synapse/storage/util/id_generators.py
@@ -21,6 +21,7 @@ from typing import Dict, Set, Tuple
from typing_extensions import Deque
from synapse.storage.database import Database, LoggingTransaction
+from synapse.storage.util.sequence import PostgresSequenceGenerator
class IdGenerator(object):
@@ -247,7 +248,6 @@ class MultiWriterIdGenerator:
):
self._db = db
self._instance_name = instance_name
- self._sequence_name = sequence_name
# We lock as some functions may be called from DB threads.
self._lock = threading.Lock()
@@ -260,6 +260,8 @@ class MultiWriterIdGenerator:
# should be less than the minimum of this set (if not empty).
self._unfinished_ids = set() # type: Set[int]
+ self._sequence_gen = PostgresSequenceGenerator(sequence_name)
+
def _load_current_ids(
self, db_conn, table: str, instance_column: str, id_column: str
) -> Dict[str, int]:
@@ -283,9 +285,7 @@ class MultiWriterIdGenerator:
return current_positions
def _load_next_id_txn(self, txn):
- txn.execute("SELECT nextval(?)", (self._sequence_name,))
- (next_id,) = txn.fetchone()
- return next_id
+ return self._sequence_gen.get_next_id_txn(txn)
async def get_next(self):
"""
diff --git a/synapse/storage/util/sequence.py b/synapse/storage/util/sequence.py
new file mode 100644
index 00000000..63dfea42
--- /dev/null
+++ b/synapse/storage/util/sequence.py
@@ -0,0 +1,98 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import abc
+import threading
+from typing import Callable, Optional
+
+from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine
+from synapse.storage.types import Cursor
+
+
+class SequenceGenerator(metaclass=abc.ABCMeta):
+ """A class which generates a unique sequence of integers"""
+
+ @abc.abstractmethod
+ def get_next_id_txn(self, txn: Cursor) -> int:
+ """Gets the next ID in the sequence"""
+ ...
+
+
+class PostgresSequenceGenerator(SequenceGenerator):
+ """An implementation of SequenceGenerator which uses a postgres sequence"""
+
+ def __init__(self, sequence_name: str):
+ self._sequence_name = sequence_name
+
+ def get_next_id_txn(self, txn: Cursor) -> int:
+ txn.execute("SELECT nextval(?)", (self._sequence_name,))
+ return txn.fetchone()[0]
+
+
+GetFirstCallbackType = Callable[[Cursor], int]
+
+
+class LocalSequenceGenerator(SequenceGenerator):
+ """An implementation of SequenceGenerator which uses local locking
+
+ This only works reliably if there are no other worker processes generating IDs at
+ the same time.
+ """
+
+ def __init__(self, get_first_callback: GetFirstCallbackType):
+ """
+ Args:
+ get_first_callback: a callback which is called on the first call to
+ get_next_id_txn; should return the current maximum id
+ """
+ # the callback. this is cleared after it is called, so that it can be GCed.
+ self._callback = get_first_callback # type: Optional[GetFirstCallbackType]
+
+ # The current max value, or None if we haven't looked in the DB yet.
+ self._current_max_id = None # type: Optional[int]
+ self._lock = threading.Lock()
+
+ def get_next_id_txn(self, txn: Cursor) -> int:
+ # We do application locking here since if we're using sqlite then
+ # we are a single process synapse.
+ with self._lock:
+ if self._current_max_id is None:
+ assert self._callback is not None
+ self._current_max_id = self._callback(txn)
+ self._callback = None
+
+ self._current_max_id += 1
+ return self._current_max_id
+
+
+def build_sequence_generator(
+ database_engine: BaseDatabaseEngine,
+ get_first_callback: GetFirstCallbackType,
+ sequence_name: str,
+) -> SequenceGenerator:
+ """Get the best impl of SequenceGenerator available
+
+ This uses PostgresSequenceGenerator on postgres, and a locally-locked impl on
+ sqlite.
+
+ Args:
+ database_engine: the database engine we are connected to
+ get_first_callback: a callback which returns the current maximum ID;
+ used to seed the generator if we're on sqlite.
+ sequence_name: the name of a postgres sequence to use.
+ """
+ if isinstance(database_engine, PostgresEngine):
+ return PostgresSequenceGenerator(sequence_name)
+ else:
+ return LocalSequenceGenerator(get_first_callback)
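A runnable toy of the sqlite-path semantics defined above: the callback is consulted exactly once to seed the counter, after which the lock-protected counter takes over. The table and values here are invented for illustration:

    import sqlite3

    from synapse.storage.util.sequence import LocalSequenceGenerator

    conn = sqlite3.connect(":memory:")
    conn.execute("CREATE TABLE state_groups (id INTEGER)")
    conn.execute("INSERT INTO state_groups VALUES (7)")

    def get_max_state_group_txn(txn):
        txn.execute("SELECT COALESCE(max(id), 0) FROM state_groups")
        return txn.fetchone()[0]

    seq = LocalSequenceGenerator(get_max_state_group_txn)

    txn = conn.cursor()
    assert seq.get_next_id_txn(txn) == 8
    assert seq.get_next_id_txn(txn) == 9  # callback not consulted again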