diff options
Diffstat (limited to 'synapse/storage/util/id_generators.py')
-rw-r--r-- | synapse/storage/util/id_generators.py | 118 |
1 files changed, 65 insertions, 53 deletions
diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py index 5c522f4a..a02dfc7d 100644 --- a/synapse/storage/util/id_generators.py +++ b/synapse/storage/util/id_generators.py @@ -13,51 +13,28 @@ # See the License for the specific language governing permissions and # limitations under the License. -from twisted.internet import defer - from collections import deque import contextlib import threading class IdGenerator(object): - def __init__(self, table, column, store): - self.table = table - self.column = column - self.store = store + def __init__(self, db_conn, table, column): self._lock = threading.Lock() - self._next_id = None + self._next_id = _load_max_id(db_conn, table, column) - @defer.inlineCallbacks def get_next(self): - if self._next_id is None: - yield self.store.runInteraction( - "IdGenerator_%s" % (self.table,), - self.get_next_txn, - ) - with self._lock: - i = self._next_id self._next_id += 1 - defer.returnValue(i) - - def get_next_txn(self, txn): - with self._lock: - if self._next_id: - i = self._next_id - self._next_id += 1 - return i - else: - txn.execute( - "SELECT MAX(%s) FROM %s" % (self.column, self.table,) - ) + return self._next_id - val, = txn.fetchone() - cur = val or 0 - cur += 1 - self._next_id = cur + 1 - return cur +def _load_max_id(db_conn, table, column): + cur = db_conn.cursor() + cur.execute("SELECT MAX(%s) FROM %s" % (column, table,)) + val, = cur.fetchone() + cur.close() + return int(val) if val else 1 class StreamIdGenerator(object): @@ -69,25 +46,23 @@ class StreamIdGenerator(object): persistence of events can complete out of order. Usage: - with stream_id_gen.get_next_txn(txn) as stream_id: + with stream_id_gen.get_next() as stream_id: # ... persist event ... """ - def __init__(self, db_conn, table, column): - self.table = table - self.column = column - + def __init__(self, db_conn, table, column, extra_tables=[]): self._lock = threading.Lock() - - cur = db_conn.cursor() - self._current_max = self._get_or_compute_current_max(cur) - cur.close() - + self._current_max = _load_max_id(db_conn, table, column) + for table, column in extra_tables: + self._current_max = max( + self._current_max, + _load_max_id(db_conn, table, column) + ) self._unfinished_ids = deque() - def get_next(self, store): + def get_next(self): """ Usage: - with yield stream_id_gen.get_next as stream_id: + with stream_id_gen.get_next() as stream_id: # ... persist event ... """ with self._lock: @@ -106,10 +81,10 @@ class StreamIdGenerator(object): return manager() - def get_next_mult(self, store, n): + def get_next_mult(self, n): """ Usage: - with yield stream_id_gen.get_next(store, n) as stream_ids: + with stream_id_gen.get_next(n) as stream_ids: # ... persist events ... """ with self._lock: @@ -130,7 +105,7 @@ class StreamIdGenerator(object): return manager() - def get_max_token(self, store): + def get_max_token(self): """Returns the maximum stream id such that all stream ids less than or equal to it have been successfully persisted. """ @@ -140,12 +115,49 @@ class StreamIdGenerator(object): return self._current_max - def _get_or_compute_current_max(self, txn): + +class ChainedIdGenerator(object): + """Used to generate new stream ids where the stream must be kept in sync + with another stream. It generates pairs of IDs, the first element is an + integer ID for this stream, the second element is the ID for the stream + that this stream needs to be kept in sync with.""" + + def __init__(self, chained_generator, db_conn, table, column): + self.chained_generator = chained_generator + self._lock = threading.Lock() + self._current_max = _load_max_id(db_conn, table, column) + self._unfinished_ids = deque() + + def get_next(self): + """ + Usage: + with stream_id_gen.get_next() as (stream_id, chained_id): + # ... persist event ... + """ with self._lock: - txn.execute("SELECT MAX(%s) FROM %s" % (self.column, self.table)) - rows = txn.fetchall() - val, = rows[0] + self._current_max += 1 + next_id = self._current_max + chained_id = self.chained_generator.get_max_token() - self._current_max = int(val) if val else 1 + self._unfinished_ids.append((next_id, chained_id)) - return self._current_max + @contextlib.contextmanager + def manager(): + try: + yield (next_id, chained_id) + finally: + with self._lock: + self._unfinished_ids.remove((next_id, chained_id)) + + return manager() + + def get_max_token(self): + """Returns the maximum stream id such that all stream ids less than or + equal to it have been successfully persisted. + """ + with self._lock: + if self._unfinished_ids: + stream_id, chained_id = self._unfinished_ids[0] + return (stream_id - 1, chained_id) + + return (self._current_max, self.chained_generator.get_max_token()) |