summaryrefslogtreecommitdiff
path: root/synapse/storage/data_stores/main/censor_events.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/data_stores/main/censor_events.py')
-rw-r--r--synapse/storage/data_stores/main/censor_events.py208
1 files changed, 208 insertions, 0 deletions
diff --git a/synapse/storage/data_stores/main/censor_events.py b/synapse/storage/data_stores/main/censor_events.py
new file mode 100644
index 00000000..2d482617
--- /dev/null
+++ b/synapse/storage/data_stores/main/censor_events.py
@@ -0,0 +1,208 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+from typing import TYPE_CHECKING
+
+from twisted.internet import defer
+
+from synapse.events.utils import prune_event_dict
+from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.storage._base import SQLBaseStore
+from synapse.storage.data_stores.main.cache import CacheInvalidationWorkerStore
+from synapse.storage.data_stores.main.events import encode_json
+from synapse.storage.data_stores.main.events_worker import EventsWorkerStore
+from synapse.storage.database import Database
+
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
+
+logger = logging.getLogger(__name__)
+
+
+class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBaseStore):
+ def __init__(self, database: Database, db_conn, hs: "HomeServer"):
+ super().__init__(database, db_conn, hs)
+
+ def _censor_redactions():
+ return run_as_background_process(
+ "_censor_redactions", self._censor_redactions
+ )
+
+ if self.hs.config.redaction_retention_period is not None:
+ hs.get_clock().looping_call(_censor_redactions, 5 * 60 * 1000)
+
+ async def _censor_redactions(self):
+ """Censors all redactions older than the configured period that haven't
+ been censored yet.
+
+ By censor we mean update the event_json table with the redacted event.
+ """
+
+ if self.hs.config.redaction_retention_period is None:
+ return
+
+ if not (
+ await self.db.updates.has_completed_background_update(
+ "redactions_have_censored_ts_idx"
+ )
+ ):
+ # We don't want to run this until the appropriate index has been
+ # created.
+ return
+
+ before_ts = self._clock.time_msec() - self.hs.config.redaction_retention_period
+
+ # We fetch all redactions that:
+ # 1. point to an event we have,
+ # 2. has a received_ts from before the cut off, and
+ # 3. we haven't yet censored.
+ #
+ # This is limited to 100 events to ensure that we don't try and do too
+ # much at once. We'll get called again so this should eventually catch
+ # up.
+ sql = """
+ SELECT redactions.event_id, redacts FROM redactions
+ LEFT JOIN events AS original_event ON (
+ redacts = original_event.event_id
+ )
+ WHERE NOT have_censored
+ AND redactions.received_ts <= ?
+ ORDER BY redactions.received_ts ASC
+ LIMIT ?
+ """
+
+ rows = await self.db.execute(
+ "_censor_redactions_fetch", None, sql, before_ts, 100
+ )
+
+ updates = []
+
+ for redaction_id, event_id in rows:
+ redaction_event = await self.get_event(redaction_id, allow_none=True)
+ original_event = await self.get_event(
+ event_id, allow_rejected=True, allow_none=True
+ )
+
+ # The SQL above ensures that we have both the redaction and
+ # original event, so if the `get_event` calls return None it
+ # means that the redaction wasn't allowed. Either way we know that
+ # the result won't change so we mark the fact that we've checked.
+ if (
+ redaction_event
+ and original_event
+ and original_event.internal_metadata.is_redacted()
+ ):
+ # Redaction was allowed
+ pruned_json = encode_json(
+ prune_event_dict(
+ original_event.room_version, original_event.get_dict()
+ )
+ )
+ else:
+ # Redaction wasn't allowed
+ pruned_json = None
+
+ updates.append((redaction_id, event_id, pruned_json))
+
+ def _update_censor_txn(txn):
+ for redaction_id, event_id, pruned_json in updates:
+ if pruned_json:
+ self._censor_event_txn(txn, event_id, pruned_json)
+
+ self.db.simple_update_one_txn(
+ txn,
+ table="redactions",
+ keyvalues={"event_id": redaction_id},
+ updatevalues={"have_censored": True},
+ )
+
+ await self.db.runInteraction("_update_censor_txn", _update_censor_txn)
+
+ def _censor_event_txn(self, txn, event_id, pruned_json):
+ """Censor an event by replacing its JSON in the event_json table with the
+ provided pruned JSON.
+
+ Args:
+ txn (LoggingTransaction): The database transaction.
+ event_id (str): The ID of the event to censor.
+ pruned_json (str): The pruned JSON
+ """
+ self.db.simple_update_one_txn(
+ txn,
+ table="event_json",
+ keyvalues={"event_id": event_id},
+ updatevalues={"json": pruned_json},
+ )
+
+ @defer.inlineCallbacks
+ def expire_event(self, event_id):
+ """Retrieve and expire an event that has expired, and delete its associated
+ expiry timestamp. If the event can't be retrieved, delete its associated
+ timestamp so we don't try to expire it again in the future.
+
+ Args:
+ event_id (str): The ID of the event to delete.
+ """
+ # Try to retrieve the event's content from the database or the event cache.
+ event = yield self.get_event(event_id)
+
+ def delete_expired_event_txn(txn):
+ # Delete the expiry timestamp associated with this event from the database.
+ self._delete_event_expiry_txn(txn, event_id)
+
+ if not event:
+ # If we can't find the event, log a warning and delete the expiry date
+ # from the database so that we don't try to expire it again in the
+ # future.
+ logger.warning(
+ "Can't expire event %s because we don't have it.", event_id
+ )
+ return
+
+ # Prune the event's dict then convert it to JSON.
+ pruned_json = encode_json(
+ prune_event_dict(event.room_version, event.get_dict())
+ )
+
+ # Update the event_json table to replace the event's JSON with the pruned
+ # JSON.
+ self._censor_event_txn(txn, event.event_id, pruned_json)
+
+ # We need to invalidate the event cache entry for this event because we
+ # changed its content in the database. We can't call
+ # self._invalidate_cache_and_stream because self.get_event_cache isn't of the
+ # right type.
+ txn.call_after(self._get_event_cache.invalidate, (event.event_id,))
+ # Send that invalidation to replication so that other workers also invalidate
+ # the event cache.
+ self._send_invalidation_to_replication(
+ txn, "_get_event_cache", (event.event_id,)
+ )
+
+ yield self.db.runInteraction("delete_expired_event", delete_expired_event_txn)
+
+ def _delete_event_expiry_txn(self, txn, event_id):
+ """Delete the expiry timestamp associated with an event ID without deleting the
+ actual event.
+
+ Args:
+ txn (LoggingTransaction): The transaction to use to perform the deletion.
+ event_id (str): The event ID to delete the associated expiry timestamp of.
+ """
+ return self.db.simple_delete_txn(
+ txn=txn, table="event_expiry", keyvalues={"event_id": event_id}
+ )