Diffstat (limited to 'synapse/storage/data_stores/main/devices.py')
-rw-r--r--  synapse/storage/data_stores/main/devices.py  300
1 file changed, 216 insertions(+), 84 deletions(-)
diff --git a/synapse/storage/data_stores/main/devices.py b/synapse/storage/data_stores/main/devices.py
index 8af5f7de..03f5141e 100644
--- a/synapse/storage/data_stores/main/devices.py
+++ b/synapse/storage/data_stores/main/devices.py
@@ -15,6 +15,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
+from typing import List, Tuple
from six import iteritems
@@ -31,7 +32,11 @@ from synapse.logging.opentracing import (
)
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
-from synapse.storage.database import Database
+from synapse.storage.database import (
+ Database,
+ LoggingTransaction,
+ make_tuple_comparison_clause,
+)
from synapse.types import Collection, get_verify_key_from_cross_signing_key
from synapse.util.caches.descriptors import (
Cache,
@@ -40,6 +45,7 @@ from synapse.util.caches.descriptors import (
cachedList,
)
from synapse.util.iterutils import batch_iter
+from synapse.util.stringutils import shortstr
logger = logging.getLogger(__name__)
@@ -47,6 +53,8 @@ DROP_DEVICE_LIST_STREAMS_NON_UNIQUE_INDEXES = (
"drop_device_list_streams_non_unique_indexes"
)
+BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES = "remove_dup_outbound_pokes"
+
class DeviceWorkerStore(SQLBaseStore):
def get_device(self, user_id, device_id):
@@ -112,23 +120,13 @@ class DeviceWorkerStore(SQLBaseStore):
if not has_changed:
return now_stream_id, []
- # We retrieve n+1 devices from the list of outbound pokes where n is
- # our outbound device update limit. We then check if the very last
- # device has the same stream_id as the second-to-last device. If so,
- # then we ignore all devices with that stream_id and only send the
- # devices with a lower stream_id.
- #
- # If when culling the list we end up with no devices afterwards, we
- # consider the device update to be too large, and simply skip the
- # stream_id; the rationale being that such a large device list update
- # is likely an error.
updates = yield self.db.runInteraction(
"get_device_updates_by_remote",
self._get_device_updates_by_remote_txn,
destination,
from_stream_id,
now_stream_id,
- limit + 1,
+ limit,
)
# Return an empty list if there are no updates
@@ -166,14 +164,6 @@ class DeviceWorkerStore(SQLBaseStore):
"device_id": verify_key.version,
}
- # if we have exceeded the limit, we need to exclude any results with the
- # same stream_id as the last row.
- if len(updates) > limit:
- stream_id_cutoff = updates[-1][2]
- now_stream_id = stream_id_cutoff - 1
- else:
- stream_id_cutoff = None
-
# Perform the equivalent of a GROUP BY
#
# Iterate through the updates list and copy non-duplicate
@@ -181,7 +171,6 @@ class DeviceWorkerStore(SQLBaseStore):
# the max stream_id across each set of duplicate entries
#
# maps (user_id, device_id) -> (stream_id, opentracing_context)
- # as long as their stream_id does not match that of the last row
#
# opentracing_context contains the opentracing metadata for the request
# that created the poke
@@ -192,10 +181,6 @@ class DeviceWorkerStore(SQLBaseStore):
query_map = {}
cross_signing_keys_by_user = {}
for user_id, device_id, update_stream_id, update_context in updates:
- if stream_id_cutoff is not None and update_stream_id >= stream_id_cutoff:
- # Stop processing updates
- break
-
if (
user_id in master_key_by_user
and device_id == master_key_by_user[user_id]["device_id"]
@@ -218,17 +203,6 @@ class DeviceWorkerStore(SQLBaseStore):
if update_stream_id > previous_update_stream_id:
query_map[key] = (update_stream_id, update_context)
- # If we didn't find any updates with a stream_id lower than the cutoff, it
- # means that there are more than limit updates all of which have the same
- # steam_id.
-
- # That should only happen if a client is spamming the server with new
- # devices, in which case E2E isn't going to work well anyway. We'll just
- # skip that stream_id and return an empty list, and continue with the next
- # stream_id next time.
- if not query_map and not cross_signing_keys_by_user:
- return stream_id_cutoff, []
-
results = yield self._get_device_update_edus_by_remote(
destination, from_stream_id, query_map
)
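The comments above describe a GROUP BY performed in Python: the raw poke rows are reduced to one entry per (user_id, device_id), keeping the highest stream_id for each pair. A minimal, self-contained sketch of that reduction (the cross-signing special cases are omitted):

    def dedupe_device_updates(updates):
        """Reduce (user_id, device_id, stream_id, context) rows to one entry per device.

        Mirrors the query_map construction above: later pokes for the same device
        replace earlier ones, so only the highest stream_id survives.
        """
        query_map = {}
        for user_id, device_id, stream_id, context in updates:
            key = (user_id, device_id)
            previous_stream_id = query_map.get(key, (0, None))[0]
            if stream_id > previous_stream_id:
                query_map[key] = (stream_id, context)
        return query_map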
@@ -259,11 +233,11 @@ class DeviceWorkerStore(SQLBaseStore):
# get the list of device updates that need to be sent
sql = """
SELECT user_id, device_id, stream_id, opentracing_context FROM device_lists_outbound_pokes
- WHERE destination = ? AND ? < stream_id AND stream_id <= ? AND sent = ?
+ WHERE destination = ? AND ? < stream_id AND stream_id <= ?
ORDER BY stream_id
LIMIT ?
"""
- txn.execute(sql, (destination, from_stream_id, now_stream_id, False, limit))
+ txn.execute(sql, (destination, from_stream_id, now_stream_id, limit))
return list(txn)
@@ -301,7 +275,14 @@ class DeviceWorkerStore(SQLBaseStore):
prev_id = yield self._get_last_device_update_for_remote_user(
destination, user_id, from_stream_id
)
- for device_id, device in iteritems(user_devices):
+
+ # make sure we go through the devices in stream order
+ device_ids = sorted(
+ user_devices.keys(), key=lambda i: query_map[(user_id, i)][0],
+ )
+
+ for device_id in device_ids:
+ device = user_devices[device_id]
stream_id, opentracing_context = query_map[(user_id, device_id)]
result = {
"user_id": user_id,
@@ -560,8 +541,8 @@ class DeviceWorkerStore(SQLBaseStore):
# Get set of users who *may* have changed. Users not in the returned
# list have definitely not changed.
- to_check = list(
- self._device_list_stream_cache.get_entities_changed(user_ids, from_key)
+ to_check = self._device_list_stream_cache.get_entities_changed(
+ user_ids, from_key
)
if not to_check:
@@ -611,22 +592,33 @@ class DeviceWorkerStore(SQLBaseStore):
else:
return set()
- def get_all_device_list_changes_for_remotes(self, from_key, to_key):
- """Return a list of `(stream_id, user_id, destination)` which is the
- combined list of changes to devices, and which destinations need to be
- poked. `destination` may be None if no destinations need to be poked.
+ async def get_all_device_list_changes_for_remotes(
+ self, from_key: int, to_key: int, limit: int,
+ ) -> List[Tuple[int, str]]:
+ """Return a list of `(stream_id, entity)` which is the combined list of
+ changes to devices and which destinations need to be poked. Entity is
+ either a user ID (starting with '@') or a remote destination.
"""
- # We do a group by here as there can be a large number of duplicate
- # entries, since we throw away device IDs.
+
+ # This query is constructed so that the stream_id bounds are applied to
+ # both of the inner queries.
sql = """
- SELECT MAX(stream_id) AS stream_id, user_id, destination
- FROM device_lists_stream
- LEFT JOIN device_lists_outbound_pokes USING (stream_id, user_id, device_id)
+ SELECT stream_id, entity FROM (
+ SELECT stream_id, user_id AS entity FROM device_lists_stream
+ UNION ALL
+ SELECT stream_id, destination AS entity FROM device_lists_outbound_pokes
+ ) AS e
WHERE ? < stream_id AND stream_id <= ?
- GROUP BY user_id, destination
+ LIMIT ?
"""
- return self.db.execute(
- "get_all_device_list_changes_for_remotes", None, sql, from_key, to_key
+
+ return await self.db.execute(
+ "get_all_device_list_changes_for_remotes",
+ None,
+ sql,
+ from_key,
+ to_key,
+ limit,
)
@cached(max_entries=10000)
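Since the rewritten query mixes user IDs and destinations in a single `entity` column, a consumer can tell them apart by the leading '@'. A rough sketch of how a caller might partition the rows (the helper name is illustrative, not part of this module):

    from typing import Iterable, List, Tuple

    def partition_device_list_changes(
        rows: Iterable[Tuple[int, str]],
    ) -> Tuple[List[Tuple[int, str]], List[Tuple[int, str]]]:
        """Split (stream_id, entity) rows into local user changes and remote pokes."""
        user_changes = []  # entity is a user ID, e.g. "@alice:example.com"
        destination_pokes = []  # entity is a remote server name
        for stream_id, entity in rows:
            if entity.startswith("@"):
                user_changes.append((stream_id, entity))
            else:
                destination_pokes.append((stream_id, entity))
        return user_changes, destination_pokes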
@@ -728,6 +720,11 @@ class DeviceBackgroundUpdateStore(SQLBaseStore):
self._drop_device_list_streams_non_unique_indexes,
)
+ # clear out duplicate device list outbound pokes
+ self.db.updates.register_background_update_handler(
+ BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES, self._remove_duplicate_outbound_pokes,
+ )
+
@defer.inlineCallbacks
def _drop_device_list_streams_non_unique_indexes(self, progress, batch_size):
def f(conn):
@@ -742,6 +739,66 @@ class DeviceBackgroundUpdateStore(SQLBaseStore):
)
return 1
+ async def _remove_duplicate_outbound_pokes(self, progress, batch_size):
+ # for some reason, we have accumulated duplicate entries in
+ # device_lists_outbound_pokes, which makes _prune_old_outbound_device_pokes less
+ # efficient.
+ #
+ # For each duplicate, we delete all the existing rows and put one back.
+
+ KEY_COLS = ["stream_id", "destination", "user_id", "device_id"]
+ last_row = progress.get(
+ "last_row",
+ {"stream_id": 0, "destination": "", "user_id": "", "device_id": ""},
+ )
+
+ def _txn(txn):
+ clause, args = make_tuple_comparison_clause(
+ self.db.engine, [(x, last_row[x]) for x in KEY_COLS]
+ )
+ sql = """
+ SELECT stream_id, destination, user_id, device_id, MAX(ts) AS ts
+ FROM device_lists_outbound_pokes
+ WHERE %s
+ GROUP BY %s
+ HAVING count(*) > 1
+ ORDER BY %s
+ LIMIT ?
+ """ % (
+ clause, # WHERE
+ ",".join(KEY_COLS), # GROUP BY
+ ",".join(KEY_COLS), # ORDER BY
+ )
+ txn.execute(sql, args + [batch_size])
+ rows = self.db.cursor_to_dict(txn)
+
+ row = None
+ for row in rows:
+ self.db.simple_delete_txn(
+ txn, "device_lists_outbound_pokes", {x: row[x] for x in KEY_COLS},
+ )
+
+ row["sent"] = False
+ self.db.simple_insert_txn(
+ txn, "device_lists_outbound_pokes", row,
+ )
+
+ if row:
+ self.db.updates._background_update_progress_txn(
+ txn, BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES, {"last_row": row},
+ )
+
+ return len(rows)
+
+ rows = await self.db.runInteraction(BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES, _txn)
+
+ if not rows:
+ await self.db.updates._end_background_update(
+ BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES
+ )
+
+ return rows
+
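The background update pages through the table in (stream_id, destination, user_id, device_id) order and resumes from the last processed row, so make_tuple_comparison_clause needs to produce a lexicographic "strictly greater than" comparison over those columns. A simplified, engine-agnostic sketch of what such an expansion can look like (an assumption for illustration, not the actual Synapse implementation):

    def tuple_comparison_clause(key_values):
        """Build an SQL fragment equivalent to (c1, c2, ...) > (v1, v2, ...).

        key_values: list of (column, value) pairs, most significant column first.
        Returns (sql_fragment, args) for use in a parameterised WHERE clause.
        """
        clauses = []
        args = []
        for i, (column, _) in enumerate(key_values):
            # rows equal on the first i columns and strictly greater on column i
            parts = ["%s = ?" % c for c, _ in key_values[:i]] + ["%s > ?" % column]
            clauses.append("(%s)" % " AND ".join(parts))
            args.extend(value for _, value in key_values[: i + 1])
        return "(%s)" % " OR ".join(clauses), args

Because each batch starts strictly after last_row, the update can be interrupted and resumed without reprocessing rows.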
class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
def __init__(self, database: Database, db_conn, hs):
@@ -1021,29 +1078,49 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
"""Persist that a user's devices have been updated, and which hosts
(if any) should be poked.
"""
- with self._device_list_id_gen.get_next() as stream_id:
+ if not device_ids:
+ return
+
+ with self._device_list_id_gen.get_next_mult(len(device_ids)) as stream_ids:
+ yield self.db.runInteraction(
+ "add_device_change_to_stream",
+ self._add_device_change_to_stream_txn,
+ user_id,
+ device_ids,
+ stream_ids,
+ )
+
+ if not hosts:
+ return stream_ids[-1]
+
+ context = get_active_span_text_map()
+ with self._device_list_id_gen.get_next_mult(
+ len(hosts) * len(device_ids)
+ ) as stream_ids:
yield self.db.runInteraction(
- "add_device_change_to_streams",
- self._add_device_change_txn,
+ "add_device_outbound_poke_to_stream",
+ self._add_device_outbound_poke_to_stream_txn,
user_id,
device_ids,
hosts,
- stream_id,
+ stream_ids,
+ context,
)
- return stream_id
- def _add_device_change_txn(self, txn, user_id, device_ids, hosts, stream_id):
- now = self._clock.time_msec()
+ return stream_ids[-1]
+ def _add_device_change_to_stream_txn(
+ self,
+ txn: LoggingTransaction,
+ user_id: str,
+ device_ids: Collection[str],
+ stream_ids: List[int],
+ ):
txn.call_after(
- self._device_list_stream_cache.entity_has_changed, user_id, stream_id
+ self._device_list_stream_cache.entity_has_changed, user_id, stream_ids[-1],
)
- for host in hosts:
- txn.call_after(
- self._device_list_federation_stream_cache.entity_has_changed,
- host,
- stream_id,
- )
+
+ min_stream_id = stream_ids[0]
# Delete older entries in the table, as we really only care about
# when the latest change happened.
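Each device change now gets its own stream ID rather than all devices sharing one, which is why the DELETE below only needs the lowest ID of the batch while the cache and the caller only see the highest. A toy sketch of the pairing, assuming the generator hands back consecutive IDs (the user and device names are made up):

    device_ids = ["DEV1", "DEV2", "DEV3"]
    stream_ids = [101, 102, 103]  # one ID per device, as from get_next_mult(3)

    # device_lists_stream gets one row per device, each with its own stream_id ...
    rows = [
        {"stream_id": stream_id, "user_id": "@alice:example.com", "device_id": device_id}
        for stream_id, device_id in zip(stream_ids, device_ids)
    ]

    min_stream_id = stream_ids[0]   # used to delete older entries for these devices
    latest = stream_ids[-1]         # reported to the cache and returned to the caller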
@@ -1052,7 +1129,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
DELETE FROM device_lists_stream
WHERE user_id = ? AND device_id = ? AND stream_id < ?
""",
- [(user_id, device_id, stream_id) for device_id in device_ids],
+ [(user_id, device_id, min_stream_id) for device_id in device_ids],
)
self.db.simple_insert_many_txn(
@@ -1060,11 +1137,22 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
table="device_lists_stream",
values=[
{"stream_id": stream_id, "user_id": user_id, "device_id": device_id}
- for device_id in device_ids
+ for stream_id, device_id in zip(stream_ids, device_ids)
],
)
- context = get_active_span_text_map()
+ def _add_device_outbound_poke_to_stream_txn(
+ self, txn, user_id, device_ids, hosts, stream_ids, context,
+ ):
+ for host in hosts:
+ txn.call_after(
+ self._device_list_federation_stream_cache.entity_has_changed,
+ host,
+ stream_ids[-1],
+ )
+
+ now = self._clock.time_msec()
+ next_stream_id = iter(stream_ids)
self.db.simple_insert_many_txn(
txn,
@@ -1072,7 +1160,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
values=[
{
"destination": destination,
- "stream_id": stream_id,
+ "stream_id": next(next_stream_id),
"user_id": user_id,
"device_id": device_id,
"sent": False,
@@ -1086,18 +1174,47 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
],
)
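For the outbound pokes, a separate block of len(hosts) * len(device_ids) stream IDs is allocated, and the iterator draws one per (destination, device) row. A small sketch of that pairing, under the assumption that the rows are built host by host (names are illustrative):

    import itertools

    hosts = ["remote1.example.com", "remote2.example.com"]
    device_ids = ["DEV1", "DEV2"]
    next_stream_id = iter([201, 202, 203, 204])  # one ID per (host, device) pair

    rows = [
        {"destination": host, "device_id": device_id, "stream_id": next(next_stream_id)}
        for host, device_id in itertools.product(hosts, device_ids)
    ]
    # remote1/DEV1 -> 201, remote1/DEV2 -> 202, remote2/DEV1 -> 203, remote2/DEV2 -> 204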
- def _prune_old_outbound_device_pokes(self):
+ def _prune_old_outbound_device_pokes(self, prune_age=24 * 60 * 60 * 1000):
"""Delete old entries out of the device_lists_outbound_pokes to ensure
- that we don't fill up due to dead servers. We keep one entry per
- (destination, user_id) tuple to ensure that the prev_ids remain correct
- if the server does come back.
+ that we don't fill up due to dead servers.
+
+ Normally, we try to send device updates as a delta since a previous known point:
+ this is done by setting the prev_id in the m.device_list_update EDU. However,
+ for that to work, we have to have a complete record of each change to
+ each device, which can add up to quite a lot of data.
+
+ An alternative mechanism is that, if the remote server sees that it has missed
+ an entry in the stream_id sequence for a given user, it will request a full
+ list of that user's devices. Hence, we can reduce the amount of data we have to
+ store (and transmit in some future transaction), by clearing almost everything
+ for a given destination out of the database, and having the remote server
+ resync.
+
+ All we need to do is make sure we keep at least one row for each
+ (user, destination) pair, to remind us to send an m.device_list_update EDU for
+ that user when the destination comes back. It doesn't matter which device
+ we keep.
"""
- yesterday = self._clock.time_msec() - 24 * 60 * 60 * 1000
+ yesterday = self._clock.time_msec() - prune_age
def _prune_txn(txn):
+ # look for (user, destination) pairs which have an update older than
+ # the cutoff.
+ #
+ # For each pair, we also need to know the most recent stream_id, and
+ # an arbitrary device_id at that stream_id.
select_sql = """
- SELECT destination, user_id, max(stream_id) as stream_id
- FROM device_lists_outbound_pokes
+ SELECT
+ dlop1.destination,
+ dlop1.user_id,
+ MAX(dlop1.stream_id) AS stream_id,
+ (SELECT MIN(dlop2.device_id) AS device_id FROM
+ device_lists_outbound_pokes dlop2
+ WHERE dlop2.destination = dlop1.destination AND
+ dlop2.user_id=dlop1.user_id AND
+ dlop2.stream_id=MAX(dlop1.stream_id)
+ )
+ FROM device_lists_outbound_pokes dlop1
GROUP BY destination, user_id
HAVING min(ts) < ? AND count(*) > 1
"""
@@ -1108,14 +1225,29 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
if not rows:
return
+ logger.info(
+ "Pruning old outbound device list updates for %i users/destinations: %s",
+ len(rows),
+ shortstr((row[0], row[1]) for row in rows),
+ )
+
+ # we want to keep the update with the highest stream_id for each user.
+ #
+ # there might be more than one update (with different device_ids) with the
+ # same stream_id, so we also delete all but one of the rows with the max stream_id.
delete_sql = """
DELETE FROM device_lists_outbound_pokes
- WHERE ts < ? AND destination = ? AND user_id = ? AND stream_id < ?
+ WHERE destination = ? AND user_id = ? AND (
+ stream_id < ? OR
+ (stream_id = ? AND device_id != ?)
+ )
"""
-
- txn.executemany(
- delete_sql, ((yesterday, row[0], row[1], row[2]) for row in rows)
- )
+ count = 0
+ for (destination, user_id, stream_id, device_id) in rows:
+ txn.execute(
+ delete_sql, (destination, user_id, stream_id, stream_id, device_id)
+ )
+ count += txn.rowcount
# Since we've deleted unsent deltas, we need to remove the entry
# of last successful sent so that the prev_ids are correctly set.
@@ -1125,7 +1257,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
"""
txn.executemany(sql, ((row[0], row[1]) for row in rows))
- logger.info("Pruned %d device list outbound pokes", txn.rowcount)
+ logger.info("Pruned %d device list outbound pokes", count)
return run_as_background_process(
"prune_old_outbound_device_pokes",