summaryrefslogtreecommitdiff
path: root/synapse/replication/tcp
diff options
context:
space:
mode:
authorAndrej Shadura <andrewsh@debian.org>2020-06-14 18:05:46 +0200
committerAndrej Shadura <andrewsh@debian.org>2020-06-14 18:05:46 +0200
commit08d5e062dc43f8af47c39699483652c1235ec5d2 (patch)
treefe437819faf5d7e1173b41388ac2973cc0c342d1 /synapse/replication/tcp
parentda7f96aa2a3b1485dafa016f38aac1d4376b64e7 (diff)
New upstream version 1.15.0
Diffstat (limited to 'synapse/replication/tcp')
-rw-r--r--synapse/replication/tcp/client.py133
-rw-r--r--synapse/replication/tcp/commands.py33
-rw-r--r--synapse/replication/tcp/handler.py111
-rw-r--r--synapse/replication/tcp/redis.py7
-rw-r--r--synapse/replication/tcp/resource.py33
-rw-r--r--synapse/replication/tcp/streams/_base.py160
-rw-r--r--synapse/replication/tcp/streams/events.py4
-rw-r--r--synapse/replication/tcp/streams/federation.py36
8 files changed, 358 insertions, 159 deletions
diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py
index 3bbf3c35..df29732f 100644
--- a/synapse/replication/tcp/client.py
+++ b/synapse/replication/tcp/client.py
@@ -14,14 +14,23 @@
# limitations under the License.
"""A replication client for use by synapse workers.
"""
-
+import heapq
import logging
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Dict, List, Tuple
+from twisted.internet.defer import Deferred
from twisted.internet.protocol import ReconnectingClientFactory
-from synapse.replication.slave.storage._base import BaseSlavedStore
+from synapse.api.constants import EventTypes
+from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable
from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol
+from synapse.replication.tcp.streams.events import (
+ EventsStream,
+ EventsStreamEventRow,
+ EventsStreamRow,
+)
+from synapse.util.async_helpers import timeout_deferred
+from synapse.util.metrics import Measure
if TYPE_CHECKING:
from synapse.server import HomeServer
@@ -30,6 +39,10 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
+# How long we allow callers to wait for replication updates before timing out.
+_WAIT_FOR_REPLICATION_TIMEOUT_SECONDS = 30
+
+
class DirectTcpReplicationClientFactory(ReconnectingClientFactory):
"""Factory for building connections to the master. Will reconnect if the
connection is lost.
@@ -83,8 +96,20 @@ class ReplicationDataHandler:
to handle updates in additional ways.
"""
- def __init__(self, store: BaseSlavedStore):
- self.store = store
+ def __init__(self, hs: "HomeServer"):
+ self.store = hs.get_datastore()
+ self.pusher_pool = hs.get_pusherpool()
+ self.notifier = hs.get_notifier()
+ self._reactor = hs.get_reactor()
+ self._clock = hs.get_clock()
+ self._streams = hs.get_replication_streams()
+ self._instance_name = hs.get_instance_name()
+
+ # Map from stream to list of deferreds waiting for the stream to
+ # arrive at a particular position. The lists are sorted by stream position.
+ self._streams_to_waiters = (
+ {}
+ ) # type: Dict[str, List[Tuple[int, Deferred[None]]]]
async def on_rdata(
self, stream_name: str, instance_name: str, token: int, rows: list
@@ -100,10 +125,100 @@ class ReplicationDataHandler:
token: stream token for this batch of rows
rows: a list of Stream.ROW_TYPE objects as returned by Stream.parse_row.
"""
- self.store.process_replication_rows(stream_name, token, rows)
-
- async def on_position(self, stream_name: str, token: int):
- self.store.process_replication_rows(stream_name, token, [])
+ self.store.process_replication_rows(stream_name, instance_name, token, rows)
+
+ if stream_name == EventsStream.NAME:
+ # We shouldn't get multiple rows per token for events stream, so
+ # we don't need to optimise this for multiple rows.
+ for row in rows:
+ if row.type != EventsStreamEventRow.TypeId:
+ continue
+ assert isinstance(row, EventsStreamRow)
+
+ event = await self.store.get_event(
+ row.data.event_id, allow_rejected=True
+ )
+ if event.rejected_reason:
+ continue
+
+ extra_users = () # type: Tuple[str, ...]
+ if event.type == EventTypes.Member:
+ extra_users = (event.state_key,)
+ max_token = self.store.get_room_max_stream_ordering()
+ self.notifier.on_new_room_event(event, token, max_token, extra_users)
+
+ await self.pusher_pool.on_new_notifications(token, token)
+
+ # Notify any waiting deferreds. The list is ordered by position so we
+ # just iterate through the list until we reach a position that is
+ # greater than the received row position.
+ waiting_list = self._streams_to_waiters.get(stream_name, [])
+
+ # Index of first item with a position after the current token, i.e we
+ # have called all deferreds before this index. If not overwritten by
+ # loop below means either a) no items in list so no-op or b) all items
+ # in list were called and so the list should be cleared. Setting it to
+ # `len(list)` works for both cases.
+ index_of_first_deferred_not_called = len(waiting_list)
+
+ for idx, (position, deferred) in enumerate(waiting_list):
+ if position <= token:
+ try:
+ with PreserveLoggingContext():
+ deferred.callback(None)
+ except Exception:
+ # The deferred has been cancelled or timed out.
+ pass
+ else:
+ # The list is sorted by position so we don't need to continue
+ # checking any further entries in the list.
+ index_of_first_deferred_not_called = idx
+ break
+
+ # Drop all entries in the waiting list that were called in the above
+ # loop. (This maintains the order so no need to resort)
+ waiting_list[:] = waiting_list[index_of_first_deferred_not_called:]
+
+ async def on_position(self, stream_name: str, instance_name: str, token: int):
+ self.store.process_replication_rows(stream_name, instance_name, token, [])
def on_remote_server_up(self, server: str):
"""Called when get a new REMOTE_SERVER_UP command."""
+
+ async def wait_for_stream_position(
+ self, instance_name: str, stream_name: str, position: int
+ ):
+ """Wait until this instance has received updates up to and including
+ the given stream position.
+ """
+
+ if instance_name == self._instance_name:
+ # We don't get told about updates written by this process, and
+ # anyway in that case we don't need to wait.
+ return
+
+ current_position = self._streams[stream_name].current_token(self._instance_name)
+ if position <= current_position:
+ # We're already past the position
+ return
+
+ # Create a new deferred that times out after N seconds, as we don't want
+ # to wedge here forever.
+ deferred = Deferred()
+ deferred = timeout_deferred(
+ deferred, _WAIT_FOR_REPLICATION_TIMEOUT_SECONDS, self._reactor
+ )
+
+ waiting_list = self._streams_to_waiters.setdefault(stream_name, [])
+
+ # We insert into the list using heapq as it is more efficient than
+ # pushing then resorting each time.
+ heapq.heappush(waiting_list, (position, deferred))
+
+ # We measure here to get in flight counts and average waiting time.
+ with Measure(self._clock, "repl.wait_for_stream_position"):
+ logger.info("Waiting for repl stream %r to reach %s", stream_name, position)
+ await make_deferred_yieldable(deferred)
+ logger.info(
+ "Finished waiting for repl stream %r to reach %s", stream_name, position
+ )
diff --git a/synapse/replication/tcp/commands.py b/synapse/replication/tcp/commands.py
index f58e384d..c04f6228 100644
--- a/synapse/replication/tcp/commands.py
+++ b/synapse/replication/tcp/commands.py
@@ -341,37 +341,6 @@ class RemovePusherCommand(Command):
return " ".join((self.app_id, self.push_key, self.user_id))
-class InvalidateCacheCommand(Command):
- """Sent by the client to invalidate an upstream cache.
-
- THIS IS NOT RELIABLE, AND SHOULD *NOT* BE USED ACCEPT FOR THINGS THAT ARE
- NOT DISASTROUS IF WE DROP ON THE FLOOR.
-
- Mainly used to invalidate destination retry timing caches.
-
- Format::
-
- INVALIDATE_CACHE <cache_func> <keys_json>
-
- Where <keys_json> is a json list.
- """
-
- NAME = "INVALIDATE_CACHE"
-
- def __init__(self, cache_func, keys):
- self.cache_func = cache_func
- self.keys = keys
-
- @classmethod
- def from_line(cls, line):
- cache_func, keys_json = line.split(" ", 1)
-
- return cls(cache_func, json.loads(keys_json))
-
- def to_line(self):
- return " ".join((self.cache_func, _json_encoder.encode(self.keys)))
-
-
class UserIpCommand(Command):
"""Sent periodically when a worker sees activity from a client.
@@ -439,7 +408,6 @@ _COMMANDS = (
UserSyncCommand,
FederationAckCommand,
RemovePusherCommand,
- InvalidateCacheCommand,
UserIpCommand,
RemoteServerUpCommand,
ClearUserSyncsCommand,
@@ -467,7 +435,6 @@ VALID_CLIENT_COMMANDS = (
ClearUserSyncsCommand.NAME,
FederationAckCommand.NAME,
RemovePusherCommand.NAME,
- InvalidateCacheCommand.NAME,
UserIpCommand.NAME,
ErrorCommand.NAME,
RemoteServerUpCommand.NAME,
diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py
index 4328b38e..cbcf46f3 100644
--- a/synapse/replication/tcp/handler.py
+++ b/synapse/replication/tcp/handler.py
@@ -15,18 +15,7 @@
# limitations under the License.
import logging
-from typing import (
- Any,
- Callable,
- Dict,
- Iterable,
- Iterator,
- List,
- Optional,
- Set,
- Tuple,
- TypeVar,
-)
+from typing import Any, Dict, Iterable, Iterator, List, Optional, Set, Tuple, TypeVar
from prometheus_client import Counter
@@ -38,7 +27,6 @@ from synapse.replication.tcp.commands import (
ClearUserSyncsCommand,
Command,
FederationAckCommand,
- InvalidateCacheCommand,
PositionCommand,
RdataCommand,
RemoteServerUpCommand,
@@ -48,7 +36,14 @@ from synapse.replication.tcp.commands import (
UserSyncCommand,
)
from synapse.replication.tcp.protocol import AbstractConnection
-from synapse.replication.tcp.streams import STREAMS_MAP, Stream
+from synapse.replication.tcp.streams import (
+ STREAMS_MAP,
+ BackfillStream,
+ CachesStream,
+ EventsStream,
+ FederationStream,
+ Stream,
+)
from synapse.util.async_helpers import Linearizer
logger = logging.getLogger(__name__)
@@ -85,6 +80,34 @@ class ReplicationCommandHandler:
stream.NAME: stream(hs) for stream in STREAMS_MAP.values()
} # type: Dict[str, Stream]
+ # List of streams that this instance is the source of
+ self._streams_to_replicate = [] # type: List[Stream]
+
+ for stream in self._streams.values():
+ if stream.NAME == CachesStream.NAME:
+ # All workers can write to the cache invalidation stream.
+ self._streams_to_replicate.append(stream)
+ continue
+
+ if isinstance(stream, (EventsStream, BackfillStream)):
+ # Only add EventStream and BackfillStream as a source on the
+ # instance in charge of event persistence.
+ if hs.config.worker.writers.events == hs.get_instance_name():
+ self._streams_to_replicate.append(stream)
+
+ continue
+
+ # Only add any other streams if we're on master.
+ if hs.config.worker_app is not None:
+ continue
+
+ if stream.NAME == FederationStream.NAME and hs.config.send_federation:
+ # We only support federation stream if federation sending
+ # has been disabled on the master.
+ continue
+
+ self._streams_to_replicate.append(stream)
+
self._position_linearizer = Linearizer(
"replication_position", clock=self._clock
)
@@ -136,6 +159,9 @@ class ReplicationCommandHandler:
hs.config.redis_port,
)
+ # First let's ensure that we have a ReplicationStreamer started.
+ hs.get_replication_streamer()
+
# We need two connections to redis, one for the subscription stream and
# one to send commands to (as you can't send further redis commands to a
# connection after SUBSCRIBE is called).
@@ -162,16 +188,33 @@ class ReplicationCommandHandler:
port = hs.config.worker_replication_port
hs.get_reactor().connectTCP(host, port, self._factory)
+ def get_streams(self) -> Dict[str, Stream]:
+ """Get a map from stream name to all streams.
+ """
+ return self._streams
+
+ def get_streams_to_replicate(self) -> List[Stream]:
+ """Get a list of streams that this instances replicates.
+ """
+ return self._streams_to_replicate
+
async def on_REPLICATE(self, conn: AbstractConnection, cmd: ReplicateCommand):
- # We only want to announce positions by the writer of the streams.
- # Currently this is just the master process.
- if not self._is_master:
- return
+ self.send_positions_to_connection(conn)
+
+ def send_positions_to_connection(self, conn: AbstractConnection):
+ """Send current position of all streams this process is source of to
+ the connection.
+ """
- for stream_name, stream in self._streams.items():
- current_token = stream.current_token()
+ # We respond with current position of all streams this instance
+ # replicates.
+ for stream in self.get_streams_to_replicate():
self.send_command(
- PositionCommand(stream_name, self._instance_name, current_token)
+ PositionCommand(
+ stream.NAME,
+ self._instance_name,
+ stream.current_token(self._instance_name),
+ )
)
async def on_USER_SYNC(self, conn: AbstractConnection, cmd: UserSyncCommand):
@@ -208,18 +251,6 @@ class ReplicationCommandHandler:
self._notifier.on_new_replication_data()
- async def on_INVALIDATE_CACHE(
- self, conn: AbstractConnection, cmd: InvalidateCacheCommand
- ):
- invalidate_cache_counter.inc()
-
- if self._is_master:
- # We invalidate the cache locally, but then also stream that to other
- # workers.
- await self._store.invalidate_cache_and_stream(
- cmd.cache_func, tuple(cmd.keys)
- )
-
async def on_USER_IP(self, conn: AbstractConnection, cmd: UserIpCommand):
user_ip_cache_counter.inc()
@@ -293,7 +324,7 @@ class ReplicationCommandHandler:
rows: a list of Stream.ROW_TYPE objects as returned by
Stream.parse_row.
"""
- logger.debug("Received rdata %s -> %s", stream_name, token)
+ logger.debug("Received rdata %s (%s) -> %s", stream_name, instance_name, token)
await self._replication_data_handler.on_rdata(
stream_name, instance_name, token, rows
)
@@ -324,7 +355,7 @@ class ReplicationCommandHandler:
self._pending_batches.pop(stream_name, [])
# Find where we previously streamed up to.
- current_token = stream.current_token()
+ current_token = stream.current_token(cmd.instance_name)
# If the position token matches our current token then we're up to
# date and there's nothing to do. Otherwise, fetch all updates
@@ -361,7 +392,9 @@ class ReplicationCommandHandler:
logger.info("Caught up with stream '%s' to %i", stream_name, cmd.token)
# We've now caught up to position sent to us, notify handler.
- await self._replication_data_handler.on_position(stream_name, cmd.token)
+ await self._replication_data_handler.on_position(
+ cmd.stream_name, cmd.instance_name, cmd.token
+ )
self._streams_by_connection.setdefault(conn, set()).add(stream_name)
@@ -489,12 +522,6 @@ class ReplicationCommandHandler:
cmd = RemovePusherCommand(app_id, push_key, user_id)
self.send_command(cmd)
- def send_invalidate_cache(self, cache_func: Callable, keys: tuple):
- """Poke the master to invalidate a cache.
- """
- cmd = InvalidateCacheCommand(cache_func.__name__, keys)
- self.send_command(cmd)
-
def send_user_ip(
self,
user_id: str,
diff --git a/synapse/replication/tcp/redis.py b/synapse/replication/tcp/redis.py
index 55bfa71d..e776b631 100644
--- a/synapse/replication/tcp/redis.py
+++ b/synapse/replication/tcp/redis.py
@@ -70,7 +70,6 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection):
logger.info("Connected to redis")
super().connectionMade()
run_as_background_process("subscribe-replication", self._send_subscribe)
- self.handler.new_connection(self)
async def _send_subscribe(self):
# it's important to make sure that we only send the REPLICATE command once we
@@ -81,9 +80,15 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection):
logger.info(
"Successfully subscribed to redis stream, sending REPLICATE command"
)
+ self.handler.new_connection(self)
await self._async_send_command(ReplicateCommand())
logger.info("REPLICATE successfully sent")
+ # We send out our positions when there is a new connection in case the
+ # other side missed updates. We do this for Redis connections as the
+ # otherside won't know we've connected and so won't issue a REPLICATE.
+ self.handler.send_positions_to_connection(self)
+
def messageReceived(self, pattern: str, channel: str, message: str):
"""Received a message from redis.
"""
diff --git a/synapse/replication/tcp/resource.py b/synapse/replication/tcp/resource.py
index 33d2f589..41569305 100644
--- a/synapse/replication/tcp/resource.py
+++ b/synapse/replication/tcp/resource.py
@@ -17,7 +17,6 @@
import logging
import random
-from typing import Dict, List
from prometheus_client import Counter
@@ -25,7 +24,6 @@ from twisted.internet.protocol import Factory
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.replication.tcp.protocol import ServerReplicationStreamProtocol
-from synapse.replication.tcp.streams import STREAMS_MAP, FederationStream, Stream
from synapse.util.metrics import Measure
stream_updates_counter = Counter(
@@ -71,26 +69,11 @@ class ReplicationStreamer(object):
self.store = hs.get_datastore()
self.clock = hs.get_clock()
self.notifier = hs.get_notifier()
+ self._instance_name = hs.get_instance_name()
self._replication_torture_level = hs.config.replication_torture_level
- # Work out list of streams that this instance is the source of.
- self.streams = [] # type: List[Stream]
- if hs.config.worker_app is None:
- for stream in STREAMS_MAP.values():
- if stream == FederationStream and hs.config.send_federation:
- # We only support federation stream if federation sending
- # hase been disabled on the master.
- continue
-
- self.streams.append(stream(hs))
-
- self.streams_by_name = {stream.NAME: stream for stream in self.streams}
-
- # Only bother registering the notifier callback if we have streams to
- # publish.
- if self.streams:
- self.notifier.add_replication_callback(self.on_notifier_poke)
+ self.notifier.add_replication_callback(self.on_notifier_poke)
# Keeps track of whether we are currently checking for updates
self.is_looping = False
@@ -98,10 +81,8 @@ class ReplicationStreamer(object):
self.command_handler = hs.get_tcp_replication()
- def get_streams(self) -> Dict[str, Stream]:
- """Get a mapp from stream name to stream instance.
- """
- return self.streams_by_name
+ # Set of streams to replicate.
+ self.streams = self.command_handler.get_streams_to_replicate()
def on_notifier_poke(self):
"""Checks if there is actually any new data and sends it to the
@@ -145,7 +126,9 @@ class ReplicationStreamer(object):
random.shuffle(all_streams)
for stream in all_streams:
- if stream.last_token == stream.current_token():
+ if stream.last_token == stream.current_token(
+ self._instance_name
+ ):
continue
if self._replication_torture_level:
@@ -157,7 +140,7 @@ class ReplicationStreamer(object):
"Getting stream: %s: %s -> %s",
stream.NAME,
stream.last_token,
- stream.current_token(),
+ stream.current_token(self._instance_name),
)
try:
updates, current_token, limited = await stream.get_updates()
diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py
index b0f87c36..4acefc8a 100644
--- a/synapse/replication/tcp/streams/_base.py
+++ b/synapse/replication/tcp/streams/_base.py
@@ -14,14 +14,27 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import heapq
import logging
from collections import namedtuple
-from typing import Any, Awaitable, Callable, List, Optional, Tuple
+from typing import (
+ TYPE_CHECKING,
+ Any,
+ Awaitable,
+ Callable,
+ List,
+ Optional,
+ Tuple,
+ TypeVar,
+)
import attr
from synapse.replication.http.streams import ReplicationGetStreamUpdates
+if TYPE_CHECKING:
+ import synapse.server
+
logger = logging.getLogger(__name__)
# the number of rows to request from an update_function.
@@ -37,7 +50,7 @@ Token = int
# parsing with Stream.parse_row (which turns it into a `ROW_TYPE`). Normally it's
# just a row from a database query, though this is dependent on the stream in question.
#
-StreamRow = Tuple
+StreamRow = TypeVar("StreamRow", bound=Tuple)
# The type returned by the update_function of a stream, as well as get_updates(),
# get_updates_since, etc.
@@ -95,19 +108,25 @@ class Stream(object):
def __init__(
self,
local_instance_name: str,
- current_token_function: Callable[[], Token],
+ current_token_function: Callable[[str], Token],
update_function: UpdateFunction,
):
"""Instantiate a Stream
- current_token_function and update_function are callbacks which should be
- implemented by subclasses.
+ `current_token_function` and `update_function` are callbacks which
+ should be implemented by subclasses.
- current_token_function is called to get the current token of the underlying
- stream.
+ `current_token_function` takes an instance name, which is a writer to
+ the stream, and returns the position in the stream of the writer (as
+ viewed from the current process). On the writer process this is where
+ the writer has successfully written up to, whereas on other processes
+ this is the position which we have received updates up to over
+ replication. (Note that most streams have a single writer and so their
+ implementations ignore the instance name passed in).
- update_function is called to get updates for this stream between a pair of
- stream tokens. See the UpdateFunction type definition for more info.
+ `update_function` is called to get updates for this stream between a
+ pair of stream tokens. See the `UpdateFunction` type definition for more
+ info.
Args:
local_instance_name: The instance name of the current process
@@ -119,13 +138,13 @@ class Stream(object):
self.update_function = update_function
# The token from which we last asked for updates
- self.last_token = self.current_token()
+ self.last_token = self.current_token(self.local_instance_name)
def discard_updates_and_advance(self):
"""Called when the stream should advance but the updates would be discarded,
e.g. when there are no currently connected workers.
"""
- self.last_token = self.current_token()
+ self.last_token = self.current_token(self.local_instance_name)
async def get_updates(self) -> StreamUpdateResult:
"""Gets all updates since the last time this function was called (or
@@ -137,7 +156,7 @@ class Stream(object):
position in stream, and `limited` is whether there are more updates
to fetch.
"""
- current_token = self.current_token()
+ current_token = self.current_token(self.local_instance_name)
updates, current_token, limited = await self.get_updates_since(
self.local_instance_name, self.last_token, current_token
)
@@ -169,6 +188,16 @@ class Stream(object):
return updates, upto_token, limited
+def current_token_without_instance(
+ current_token: Callable[[], int]
+) -> Callable[[str], int]:
+ """Takes a current token callback function for a single writer stream
+ that doesn't take an instance name parameter and wraps it in a function that
+ does accept an instance name parameter but ignores it.
+ """
+ return lambda instance_name: current_token()
+
+
def db_query_to_update_function(
query_function: Callable[[Token, Token, int], Awaitable[List[tuple]]]
) -> UpdateFunction:
@@ -234,7 +263,7 @@ class BackfillStream(Stream):
store = hs.get_datastore()
super().__init__(
hs.get_instance_name(),
- store.get_current_backfill_token,
+ current_token_without_instance(store.get_current_backfill_token),
db_query_to_update_function(store.get_all_new_backfill_event_rows),
)
@@ -270,7 +299,9 @@ class PresenceStream(Stream):
update_function = make_http_update_function(hs, self.NAME)
super().__init__(
- hs.get_instance_name(), store.get_current_presence_token, update_function
+ hs.get_instance_name(),
+ current_token_without_instance(store.get_current_presence_token),
+ update_function,
)
@@ -295,7 +326,9 @@ class TypingStream(Stream):
update_function = make_http_update_function(hs, self.NAME)
super().__init__(
- hs.get_instance_name(), typing_handler.get_current_token, update_function
+ hs.get_instance_name(),
+ current_token_without_instance(typing_handler.get_current_token),
+ update_function,
)
@@ -318,7 +351,7 @@ class ReceiptsStream(Stream):
store = hs.get_datastore()
super().__init__(
hs.get_instance_name(),
- store.get_max_receipt_stream_id,
+ current_token_without_instance(store.get_max_receipt_stream_id),
db_query_to_update_function(store.get_all_updated_receipts),
)
@@ -338,7 +371,7 @@ class PushRulesStream(Stream):
hs.get_instance_name(), self._current_token, self._update_function
)
- def _current_token(self) -> int:
+ def _current_token(self, instance_name: str) -> int:
push_rules_token, _ = self.store.get_push_rules_stream_token()
return push_rules_token
@@ -372,7 +405,7 @@ class PushersStream(Stream):
super().__init__(
hs.get_instance_name(),
- store.get_pushers_stream_token,
+ current_token_without_instance(store.get_pushers_stream_token),
db_query_to_update_function(store.get_all_updated_pushers_rows),
)
@@ -401,13 +434,27 @@ class CachesStream(Stream):
ROW_TYPE = CachesStreamRow
def __init__(self, hs):
- store = hs.get_datastore()
+ self.store = hs.get_datastore()
super().__init__(
hs.get_instance_name(),
- store.get_cache_stream_token,
- db_query_to_update_function(store.get_all_updated_caches),
+ self.store.get_cache_stream_token,
+ self._update_function,
)
+ async def _update_function(
+ self, instance_name: str, from_token: int, upto_token: int, limit: int
+ ):
+ rows = await self.store.get_all_updated_caches(
+ instance_name, from_token, upto_token, limit
+ )
+ updates = [(row[0], row[1:]) for row in rows]
+ limited = False
+ if len(updates) >= limit:
+ upto_token = updates[-1][0]
+ limited = True
+
+ return updates, upto_token, limited
+
class PublicRoomsStream(Stream):
"""The public rooms list changed
@@ -430,7 +477,7 @@ class PublicRoomsStream(Stream):
store = hs.get_datastore()
super().__init__(
hs.get_instance_name(),
- store.get_current_public_room_stream_id,
+ current_token_without_instance(store.get_current_public_room_stream_id),
db_query_to_update_function(store.get_all_new_public_rooms),
)
@@ -451,7 +498,7 @@ class DeviceListsStream(Stream):
store = hs.get_datastore()
super().__init__(
hs.get_instance_name(),
- store.get_device_stream_token,
+ current_token_without_instance(store.get_device_stream_token),
db_query_to_update_function(store.get_all_device_list_changes_for_remotes),
)
@@ -469,7 +516,7 @@ class ToDeviceStream(Stream):
store = hs.get_datastore()
super().__init__(
hs.get_instance_name(),
- store.get_to_device_stream_token,
+ current_token_without_instance(store.get_to_device_stream_token),
db_query_to_update_function(store.get_all_new_device_messages),
)
@@ -489,7 +536,7 @@ class TagAccountDataStream(Stream):
store = hs.get_datastore()
super().__init__(
hs.get_instance_name(),
- store.get_max_account_data_stream_id,
+ current_token_without_instance(store.get_max_account_data_stream_id),
db_query_to_update_function(store.get_all_updated_tags),
)
@@ -499,32 +546,69 @@ class AccountDataStream(Stream):
"""
AccountDataStreamRow = namedtuple(
- "AccountDataStream", ("user_id", "room_id", "data_type") # str # str # str
+ "AccountDataStream",
+ ("user_id", "room_id", "data_type"), # str # Optional[str] # str
)
NAME = "account_data"
ROW_TYPE = AccountDataStreamRow
- def __init__(self, hs):
+ def __init__(self, hs: "synapse.server.HomeServer"):
self.store = hs.get_datastore()
super().__init__(
hs.get_instance_name(),
- self.store.get_max_account_data_stream_id,
- db_query_to_update_function(self._update_function),
+ current_token_without_instance(self.store.get_max_account_data_stream_id),
+ self._update_function,
)
- async def _update_function(self, from_token, to_token, limit):
- global_results, room_results = await self.store.get_all_updated_account_data(
- from_token, from_token, to_token, limit
+ async def _update_function(
+ self, instance_name: str, from_token: int, to_token: int, limit: int
+ ) -> StreamUpdateResult:
+ limited = False
+ global_results = await self.store.get_updated_global_account_data(
+ from_token, to_token, limit
)
- results = list(room_results)
- results.extend(
- (stream_id, user_id, None, account_data_type)
+ # if the global results hit the limit, we'll need to limit the room results to
+ # the same stream token.
+ if len(global_results) >= limit:
+ to_token = global_results[-1][0]
+ limited = True
+
+ room_results = await self.store.get_updated_room_account_data(
+ from_token, to_token, limit
+ )
+
+ # likewise, if the room results hit the limit, limit the global results to
+ # the same stream token.
+ if len(room_results) >= limit:
+ to_token = room_results[-1][0]
+ limited = True
+
+ # convert the global results to the right format, and limit them to the to_token
+ # at the same time
+ global_rows = (
+ (stream_id, (user_id, None, account_data_type))
for stream_id, user_id, account_data_type in global_results
+ if stream_id <= to_token
+ )
+
+ # we know that the room_results are already limited to `to_token` so no need
+ # for a check on `stream_id` here.
+ room_rows = (
+ (stream_id, (user_id, room_id, account_data_type))
+ for stream_id, user_id, room_id, account_data_type in room_results
)
- return results
+ # We need to return a sorted list, so merge them together.
+ #
+ # Note: We order only by the stream ID to work around a bug where the
+ # same stream ID could appear in both `global_rows` and `room_rows`,
+ # leading to a comparison between the data tuples. The comparison could
+ # fail due to attempting to compare the `room_id` which results in a
+ # `TypeError` from comparing a `str` vs `None`.
+ updates = list(heapq.merge(room_rows, global_rows, key=lambda row: row[0]))
+ return updates, to_token, limited
class GroupServerStream(Stream):
@@ -540,7 +624,7 @@ class GroupServerStream(Stream):
store = hs.get_datastore()
super().__init__(
hs.get_instance_name(),
- store.get_group_stream_token,
+ current_token_without_instance(store.get_group_stream_token),
db_query_to_update_function(store.get_all_groups_changes),
)
@@ -558,7 +642,7 @@ class UserSignatureStream(Stream):
store = hs.get_datastore()
super().__init__(
hs.get_instance_name(),
- store.get_device_stream_token,
+ current_token_without_instance(store.get_device_stream_token),
db_query_to_update_function(
store.get_all_user_signature_changes_for_remotes
),
diff --git a/synapse/replication/tcp/streams/events.py b/synapse/replication/tcp/streams/events.py
index 890e75d8..f3703903 100644
--- a/synapse/replication/tcp/streams/events.py
+++ b/synapse/replication/tcp/streams/events.py
@@ -20,7 +20,7 @@ from typing import List, Tuple, Type
import attr
-from ._base import Stream, StreamUpdateResult, Token
+from ._base import Stream, StreamUpdateResult, Token, current_token_without_instance
"""Handling of the 'events' replication stream
@@ -119,7 +119,7 @@ class EventsStream(Stream):
self._store = hs.get_datastore()
super().__init__(
hs.get_instance_name(),
- self._store.get_current_events_token,
+ current_token_without_instance(self._store.get_current_events_token),
self._update_function,
)
diff --git a/synapse/replication/tcp/streams/federation.py b/synapse/replication/tcp/streams/federation.py
index e8bd52e3..9bcd13b0 100644
--- a/synapse/replication/tcp/streams/federation.py
+++ b/synapse/replication/tcp/streams/federation.py
@@ -15,7 +15,11 @@
# limitations under the License.
from collections import namedtuple
-from synapse.replication.tcp.streams._base import Stream, db_query_to_update_function
+from synapse.replication.tcp.streams._base import (
+ Stream,
+ current_token_without_instance,
+ make_http_update_function,
+)
class FederationStream(Stream):
@@ -35,21 +39,35 @@ class FederationStream(Stream):
ROW_TYPE = FederationStreamRow
def __init__(self, hs):
- # Not all synapse instances will have a federation sender instance,
- # whether that's a `FederationSender` or a `FederationRemoteSendQueue`,
- # so we stub the stream out when that is the case.
- if hs.config.worker_app is None or hs.should_send_federation():
+ if hs.config.worker_app is None:
+ # master process: get updates from the FederationRemoteSendQueue.
+ # (if the master is configured to send federation itself, federation_sender
+ # will be a real FederationSender, which has stubs for current_token and
+ # get_replication_rows.)
federation_sender = hs.get_federation_sender()
- current_token = federation_sender.get_current_token
- update_function = db_query_to_update_function(
- federation_sender.get_replication_rows
+ current_token = current_token_without_instance(
+ federation_sender.get_current_token
)
+ update_function = federation_sender.get_replication_rows
+
+ elif hs.should_send_federation():
+ # federation sender: Query master process
+ update_function = make_http_update_function(hs, self.NAME)
+ current_token = self._stub_current_token
+
else:
- current_token = lambda: 0
+ # other worker: stub out the update function (we're not interested in
+ # any updates so when we get a POSITION we do nothing)
update_function = self._stub_update_function
+ current_token = self._stub_current_token
super().__init__(hs.get_instance_name(), current_token, update_function)
@staticmethod
+ def _stub_current_token(instance_name: str) -> int:
+ # dummy current-token method for use on workers
+ return 0
+
+ @staticmethod
async def _stub_update_function(instance_name, from_token, upto_token, limit):
return [], upto_token, False