summaryrefslogtreecommitdiff
path: root/synapse/storage/util/id_generators.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/util/id_generators.py')
-rw-r--r--synapse/storage/util/id_generators.py118
1 files changed, 65 insertions, 53 deletions
diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py
index 5c522f4a..a02dfc7d 100644
--- a/synapse/storage/util/id_generators.py
+++ b/synapse/storage/util/id_generators.py
@@ -13,51 +13,28 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from twisted.internet import defer
-
from collections import deque
import contextlib
import threading
class IdGenerator(object):
- def __init__(self, table, column, store):
- self.table = table
- self.column = column
- self.store = store
+ def __init__(self, db_conn, table, column):
self._lock = threading.Lock()
- self._next_id = None
+ self._next_id = _load_max_id(db_conn, table, column)
- @defer.inlineCallbacks
def get_next(self):
- if self._next_id is None:
- yield self.store.runInteraction(
- "IdGenerator_%s" % (self.table,),
- self.get_next_txn,
- )
-
with self._lock:
- i = self._next_id
self._next_id += 1
- defer.returnValue(i)
-
- def get_next_txn(self, txn):
- with self._lock:
- if self._next_id:
- i = self._next_id
- self._next_id += 1
- return i
- else:
- txn.execute(
- "SELECT MAX(%s) FROM %s" % (self.column, self.table,)
- )
+ return self._next_id
- val, = txn.fetchone()
- cur = val or 0
- cur += 1
- self._next_id = cur + 1
- return cur
+def _load_max_id(db_conn, table, column):
+ cur = db_conn.cursor()
+ cur.execute("SELECT MAX(%s) FROM %s" % (column, table,))
+ val, = cur.fetchone()
+ cur.close()
+ return int(val) if val else 1
class StreamIdGenerator(object):
@@ -69,25 +46,23 @@ class StreamIdGenerator(object):
persistence of events can complete out of order.
Usage:
- with stream_id_gen.get_next_txn(txn) as stream_id:
+ with stream_id_gen.get_next() as stream_id:
# ... persist event ...
"""
- def __init__(self, db_conn, table, column):
- self.table = table
- self.column = column
-
+ def __init__(self, db_conn, table, column, extra_tables=[]):
self._lock = threading.Lock()
-
- cur = db_conn.cursor()
- self._current_max = self._get_or_compute_current_max(cur)
- cur.close()
-
+ self._current_max = _load_max_id(db_conn, table, column)
+ for table, column in extra_tables:
+ self._current_max = max(
+ self._current_max,
+ _load_max_id(db_conn, table, column)
+ )
self._unfinished_ids = deque()
- def get_next(self, store):
+ def get_next(self):
"""
Usage:
- with yield stream_id_gen.get_next as stream_id:
+ with stream_id_gen.get_next() as stream_id:
# ... persist event ...
"""
with self._lock:
@@ -106,10 +81,10 @@ class StreamIdGenerator(object):
return manager()
- def get_next_mult(self, store, n):
+ def get_next_mult(self, n):
"""
Usage:
- with yield stream_id_gen.get_next(store, n) as stream_ids:
+ with stream_id_gen.get_next(n) as stream_ids:
# ... persist events ...
"""
with self._lock:
@@ -130,7 +105,7 @@ class StreamIdGenerator(object):
return manager()
- def get_max_token(self, store):
+ def get_max_token(self):
"""Returns the maximum stream id such that all stream ids less than or
equal to it have been successfully persisted.
"""
@@ -140,12 +115,49 @@ class StreamIdGenerator(object):
return self._current_max
- def _get_or_compute_current_max(self, txn):
+
+class ChainedIdGenerator(object):
+ """Used to generate new stream ids where the stream must be kept in sync
+ with another stream. It generates pairs of IDs, the first element is an
+ integer ID for this stream, the second element is the ID for the stream
+ that this stream needs to be kept in sync with."""
+
+ def __init__(self, chained_generator, db_conn, table, column):
+ self.chained_generator = chained_generator
+ self._lock = threading.Lock()
+ self._current_max = _load_max_id(db_conn, table, column)
+ self._unfinished_ids = deque()
+
+ def get_next(self):
+ """
+ Usage:
+ with stream_id_gen.get_next() as (stream_id, chained_id):
+ # ... persist event ...
+ """
with self._lock:
- txn.execute("SELECT MAX(%s) FROM %s" % (self.column, self.table))
- rows = txn.fetchall()
- val, = rows[0]
+ self._current_max += 1
+ next_id = self._current_max
+ chained_id = self.chained_generator.get_max_token()
- self._current_max = int(val) if val else 1
+ self._unfinished_ids.append((next_id, chained_id))
- return self._current_max
+ @contextlib.contextmanager
+ def manager():
+ try:
+ yield (next_id, chained_id)
+ finally:
+ with self._lock:
+ self._unfinished_ids.remove((next_id, chained_id))
+
+ return manager()
+
+ def get_max_token(self):
+ """Returns the maximum stream id such that all stream ids less than or
+ equal to it have been successfully persisted.
+ """
+ with self._lock:
+ if self._unfinished_ids:
+ stream_id, chained_id = self._unfinished_ids[0]
+ return (stream_id - 1, chained_id)
+
+ return (self._current_max, self.chained_generator.get_max_token())