diff options
Diffstat (limited to 'synapse/replication/tcp')
-rw-r--r-- | synapse/replication/tcp/client.py | 8 | ||||
-rw-r--r-- | synapse/replication/tcp/commands.py | 10 | ||||
-rw-r--r-- | synapse/replication/tcp/handler.py | 427 | ||||
-rw-r--r-- | synapse/replication/tcp/protocol.py | 58 | ||||
-rw-r--r-- | synapse/replication/tcp/redis.py | 59 | ||||
-rw-r--r-- | synapse/replication/tcp/streams/_base.py | 7 | ||||
-rw-r--r-- | synapse/replication/tcp/streams/events.py | 2 |
7 files changed, 374 insertions, 197 deletions
diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py index 4985e40b..fcf8ebf1 100644 --- a/synapse/replication/tcp/client.py +++ b/synapse/replication/tcp/client.py @@ -24,6 +24,7 @@ from twisted.internet.protocol import ReconnectingClientFactory from synapse.api.constants import EventTypes from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol +from synapse.replication.tcp.streams import TypingStream from synapse.replication.tcp.streams.events import ( EventsStream, EventsStreamEventRow, @@ -104,6 +105,7 @@ class ReplicationDataHandler: self._clock = hs.get_clock() self._streams = hs.get_replication_streams() self._instance_name = hs.get_instance_name() + self._typing_handler = hs.get_typing_handler() # Map from stream to list of deferreds waiting for the stream to # arrive at a particular position. The lists are sorted by stream position. @@ -127,6 +129,12 @@ class ReplicationDataHandler: """ self.store.process_replication_rows(stream_name, instance_name, token, rows) + if stream_name == TypingStream.NAME: + self._typing_handler.process_replication_rows(token, rows) + self.notifier.on_new_event( + "typing_key", token, rooms=[row.room_id for row in rows] + ) + if stream_name == EventsStream.NAME: # We shouldn't get multiple rows per token for events stream, so # we don't need to optimise this for multiple rows. diff --git a/synapse/replication/tcp/commands.py b/synapse/replication/tcp/commands.py index ccc7f1f0..f33801f8 100644 --- a/synapse/replication/tcp/commands.py +++ b/synapse/replication/tcp/commands.py @@ -293,20 +293,22 @@ class FederationAckCommand(Command): Format:: - FEDERATION_ACK <token> + FEDERATION_ACK <instance_name> <token> """ NAME = "FEDERATION_ACK" - def __init__(self, token): + def __init__(self, instance_name, token): + self.instance_name = instance_name self.token = token @classmethod def from_line(cls, line): - return cls(int(line)) + instance_name, token = line.split(" ") + return cls(instance_name, int(token)) def to_line(self): - return str(self.token) + return "%s %s" % (self.instance_name, self.token) class RemovePusherCommand(Command): diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py index 55b3b790..1c303f3a 100644 --- a/synapse/replication/tcp/handler.py +++ b/synapse/replication/tcp/handler.py @@ -14,13 +14,27 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from typing import Any, Dict, Iterable, Iterator, List, Optional, Set, Tuple, TypeVar +from typing import ( + Any, + Awaitable, + Dict, + Iterable, + Iterator, + List, + Optional, + Set, + Tuple, + TypeVar, + Union, +) from prometheus_client import Counter +from typing_extensions import Deque from twisted.internet.protocol import ReconnectingClientFactory from synapse.metrics import LaterGauge +from synapse.metrics.background_process_metrics import run_as_background_process from synapse.replication.tcp.client import DirectTcpReplicationClientFactory from synapse.replication.tcp.commands import ( ClearUserSyncsCommand, @@ -42,8 +56,8 @@ from synapse.replication.tcp.streams import ( EventsStream, FederationStream, Stream, + TypingStream, ) -from synapse.util.async_helpers import Linearizer logger = logging.getLogger(__name__) @@ -55,12 +69,16 @@ inbound_rdata_count = Counter( user_sync_counter = Counter("synapse_replication_tcp_resource_user_sync", "") federation_ack_counter = Counter("synapse_replication_tcp_resource_federation_ack", "") remove_pusher_counter = Counter("synapse_replication_tcp_resource_remove_pusher", "") -invalidate_cache_counter = Counter( - "synapse_replication_tcp_resource_invalidate_cache", "" -) + user_ip_cache_counter = Counter("synapse_replication_tcp_resource_user_ip_cache", "") +# the type of the entries in _command_queues_by_stream +_StreamCommandQueue = Deque[ + Tuple[Union[RdataCommand, PositionCommand], AbstractConnection] +] + + class ReplicationCommandHandler: """Handles incoming commands from replication as well as sending commands back out to connections. @@ -96,6 +114,14 @@ class ReplicationCommandHandler: continue + if isinstance(stream, TypingStream): + # Only add TypingStream as a source on the instance in charge of + # typing. + if hs.config.worker.writers.typing == hs.get_instance_name(): + self._streams_to_replicate.append(stream) + + continue + # Only add any other streams if we're on master. if hs.config.worker_app is not None: continue @@ -107,10 +133,6 @@ class ReplicationCommandHandler: self._streams_to_replicate.append(stream) - self._position_linearizer = Linearizer( - "replication_position", clock=self._clock - ) - # Map of stream name to batched updates. See RdataCommand for info on # how batching works. self._pending_batches = {} # type: Dict[str, List[Any]] @@ -122,10 +144,6 @@ class ReplicationCommandHandler: # outgoing replication commands to.) self._connections = [] # type: List[AbstractConnection] - # For each connection, the incoming stream names that are coming from - # that connection. - self._streams_by_connection = {} # type: Dict[AbstractConnection, Set[str]] - LaterGauge( "synapse_replication_tcp_resource_total_connections", "", @@ -133,6 +151,32 @@ class ReplicationCommandHandler: lambda: len(self._connections), ) + # When POSITION or RDATA commands arrive, we stick them in a queue and process + # them in order in a separate background process. + + # the streams which are currently being processed by _unsafe_process_queue + self._processing_streams = set() # type: Set[str] + + # for each stream, a queue of commands that are awaiting processing, and the + # connection that they arrived on. + self._command_queues_by_stream = { + stream_name: _StreamCommandQueue() for stream_name in self._streams + } + + # For each connection, the incoming stream names that have received a POSITION + # from that connection. + self._streams_by_connection = {} # type: Dict[AbstractConnection, Set[str]] + + LaterGauge( + "synapse_replication_tcp_command_queue", + "Number of inbound RDATA/POSITION commands queued for processing", + ["stream_name"], + lambda: { + (stream_name,): len(queue) + for stream_name, queue in self._command_queues_by_stream.items() + }, + ) + self._is_master = hs.config.worker_app is None self._federation_sender = None @@ -143,6 +187,65 @@ class ReplicationCommandHandler: if self._is_master: self._server_notices_sender = hs.get_server_notices_sender() + def _add_command_to_stream_queue( + self, conn: AbstractConnection, cmd: Union[RdataCommand, PositionCommand] + ) -> None: + """Queue the given received command for processing + + Adds the given command to the per-stream queue, and processes the queue if + necessary + """ + stream_name = cmd.stream_name + queue = self._command_queues_by_stream.get(stream_name) + if queue is None: + logger.error("Got %s for unknown stream: %s", cmd.NAME, stream_name) + return + + queue.append((cmd, conn)) + + # if we're already processing this stream, there's nothing more to do: + # the new entry on the queue will get picked up in due course + if stream_name in self._processing_streams: + return + + # fire off a background process to start processing the queue. + run_as_background_process( + "process-replication-data", self._unsafe_process_queue, stream_name + ) + + async def _unsafe_process_queue(self, stream_name: str): + """Processes the command queue for the given stream, until it is empty + + Does not check if there is already a thread processing the queue, hence "unsafe" + """ + assert stream_name not in self._processing_streams + + self._processing_streams.add(stream_name) + try: + queue = self._command_queues_by_stream.get(stream_name) + while queue: + cmd, conn = queue.popleft() + try: + await self._process_command(cmd, conn, stream_name) + except Exception: + logger.exception("Failed to handle command %s", cmd) + finally: + self._processing_streams.discard(stream_name) + + async def _process_command( + self, + cmd: Union[PositionCommand, RdataCommand], + conn: AbstractConnection, + stream_name: str, + ) -> None: + if isinstance(cmd, PositionCommand): + await self._process_position(stream_name, conn, cmd) + elif isinstance(cmd, RdataCommand): + await self._process_rdata(stream_name, conn, cmd) + else: + # This shouldn't be possible + raise Exception("Unrecognised command %s in stream queue", cmd.NAME) + def start_replication(self, hs): """Helper method to start a replication connection to the remote server using TCP. @@ -199,7 +302,7 @@ class ReplicationCommandHandler: """ return self._streams_to_replicate - async def on_REPLICATE(self, conn: AbstractConnection, cmd: ReplicateCommand): + def on_REPLICATE(self, conn: AbstractConnection, cmd: ReplicateCommand): self.send_positions_to_connection(conn) def send_positions_to_connection(self, conn: AbstractConnection): @@ -218,57 +321,73 @@ class ReplicationCommandHandler: ) ) - async def on_USER_SYNC(self, conn: AbstractConnection, cmd: UserSyncCommand): + def on_USER_SYNC( + self, conn: AbstractConnection, cmd: UserSyncCommand + ) -> Optional[Awaitable[None]]: user_sync_counter.inc() if self._is_master: - await self._presence_handler.update_external_syncs_row( + return self._presence_handler.update_external_syncs_row( cmd.instance_id, cmd.user_id, cmd.is_syncing, cmd.last_sync_ms ) + else: + return None - async def on_CLEAR_USER_SYNC( + def on_CLEAR_USER_SYNC( self, conn: AbstractConnection, cmd: ClearUserSyncsCommand - ): + ) -> Optional[Awaitable[None]]: if self._is_master: - await self._presence_handler.update_external_syncs_clear(cmd.instance_id) + return self._presence_handler.update_external_syncs_clear(cmd.instance_id) + else: + return None - async def on_FEDERATION_ACK( - self, conn: AbstractConnection, cmd: FederationAckCommand - ): + def on_FEDERATION_ACK(self, conn: AbstractConnection, cmd: FederationAckCommand): federation_ack_counter.inc() if self._federation_sender: - self._federation_sender.federation_ack(cmd.token) + self._federation_sender.federation_ack(cmd.instance_name, cmd.token) - async def on_REMOVE_PUSHER( + def on_REMOVE_PUSHER( self, conn: AbstractConnection, cmd: RemovePusherCommand - ): + ) -> Optional[Awaitable[None]]: remove_pusher_counter.inc() if self._is_master: - await self._store.delete_pusher_by_app_id_pushkey_user_id( - app_id=cmd.app_id, pushkey=cmd.push_key, user_id=cmd.user_id - ) + return self._handle_remove_pusher(cmd) + else: + return None - self._notifier.on_new_replication_data() + async def _handle_remove_pusher(self, cmd: RemovePusherCommand): + await self._store.delete_pusher_by_app_id_pushkey_user_id( + app_id=cmd.app_id, pushkey=cmd.push_key, user_id=cmd.user_id + ) + + self._notifier.on_new_replication_data() - async def on_USER_IP(self, conn: AbstractConnection, cmd: UserIpCommand): + def on_USER_IP( + self, conn: AbstractConnection, cmd: UserIpCommand + ) -> Optional[Awaitable[None]]: user_ip_cache_counter.inc() if self._is_master: - await self._store.insert_client_ip( - cmd.user_id, - cmd.access_token, - cmd.ip, - cmd.user_agent, - cmd.device_id, - cmd.last_seen, - ) + return self._handle_user_ip(cmd) + else: + return None + + async def _handle_user_ip(self, cmd: UserIpCommand): + await self._store.insert_client_ip( + cmd.user_id, + cmd.access_token, + cmd.ip, + cmd.user_agent, + cmd.device_id, + cmd.last_seen, + ) - if self._server_notices_sender: - await self._server_notices_sender.on_user_ip(cmd.user_id) + assert self._server_notices_sender is not None + await self._server_notices_sender.on_user_ip(cmd.user_id) - async def on_RDATA(self, conn: AbstractConnection, cmd: RdataCommand): + def on_RDATA(self, conn: AbstractConnection, cmd: RdataCommand): if cmd.instance_name == self._instance_name: # Ignore RDATA that are just our own echoes return @@ -276,63 +395,71 @@ class ReplicationCommandHandler: stream_name = cmd.stream_name inbound_rdata_count.labels(stream_name).inc() - try: - row = STREAMS_MAP[stream_name].parse_row(cmd.row) - except Exception: - logger.exception("Failed to parse RDATA: %r %r", stream_name, cmd.row) - raise - - # We linearize here for two reasons: + # We put the received command into a queue here for two reasons: # 1. so we don't try and concurrently handle multiple rows for the # same stream, and # 2. so we don't race with getting a POSITION command and fetching # missing RDATA. - with await self._position_linearizer.queue(cmd.stream_name): - # make sure that we've processed a POSITION for this stream *on this - # connection*. (A POSITION on another connection is no good, as there - # is no guarantee that we have seen all the intermediate updates.) - sbc = self._streams_by_connection.get(conn) - if not sbc or stream_name not in sbc: - # Let's drop the row for now, on the assumption we'll receive a - # `POSITION` soon and we'll catch up correctly then. - logger.debug( - "Discarding RDATA for unconnected stream %s -> %s", - stream_name, - cmd.token, - ) - return - - if cmd.token is None: - # I.e. this is part of a batch of updates for this stream (in - # which case batch until we get an update for the stream with a non - # None token). - self._pending_batches.setdefault(stream_name, []).append(row) - else: - # Check if this is the last of a batch of updates - rows = self._pending_batches.pop(stream_name, []) - rows.append(row) - - stream = self._streams.get(stream_name) - if not stream: - logger.error("Got RDATA for unknown stream: %s", stream_name) - return - - # Find where we previously streamed up to. - current_token = stream.current_token(cmd.instance_name) - - # Discard this data if this token is earlier than the current - # position. Note that streams can be reset (in which case you - # expect an earlier token), but that must be preceded by a - # POSITION command. - if cmd.token <= current_token: - logger.debug( - "Discarding RDATA from stream %s at position %s before previous position %s", - stream_name, - cmd.token, - current_token, - ) - else: - await self.on_rdata(stream_name, cmd.instance_name, cmd.token, rows) + + self._add_command_to_stream_queue(conn, cmd) + + async def _process_rdata( + self, stream_name: str, conn: AbstractConnection, cmd: RdataCommand + ) -> None: + """Process an RDATA command + + Called after the command has been popped off the queue of inbound commands + """ + try: + row = STREAMS_MAP[stream_name].parse_row(cmd.row) + except Exception as e: + raise Exception( + "Failed to parse RDATA: %r %r" % (stream_name, cmd.row) + ) from e + + # make sure that we've processed a POSITION for this stream *on this + # connection*. (A POSITION on another connection is no good, as there + # is no guarantee that we have seen all the intermediate updates.) + sbc = self._streams_by_connection.get(conn) + if not sbc or stream_name not in sbc: + # Let's drop the row for now, on the assumption we'll receive a + # `POSITION` soon and we'll catch up correctly then. + logger.debug( + "Discarding RDATA for unconnected stream %s -> %s", + stream_name, + cmd.token, + ) + return + + if cmd.token is None: + # I.e. this is part of a batch of updates for this stream (in + # which case batch until we get an update for the stream with a non + # None token). + self._pending_batches.setdefault(stream_name, []).append(row) + return + + # Check if this is the last of a batch of updates + rows = self._pending_batches.pop(stream_name, []) + rows.append(row) + + stream = self._streams[stream_name] + + # Find where we previously streamed up to. + current_token = stream.current_token(cmd.instance_name) + + # Discard this data if this token is earlier than the current + # position. Note that streams can be reset (in which case you + # expect an earlier token), but that must be preceded by a + # POSITION command. + if cmd.token <= current_token: + logger.debug( + "Discarding RDATA from stream %s at position %s before previous position %s", + stream_name, + cmd.token, + current_token, + ) + else: + await self.on_rdata(stream_name, cmd.instance_name, cmd.token, rows) async def on_rdata( self, stream_name: str, instance_name: str, token: int, rows: list @@ -351,78 +478,74 @@ class ReplicationCommandHandler: stream_name, instance_name, token, rows ) - async def on_POSITION(self, conn: AbstractConnection, cmd: PositionCommand): + def on_POSITION(self, conn: AbstractConnection, cmd: PositionCommand): if cmd.instance_name == self._instance_name: # Ignore POSITION that are just our own echoes return logger.info("Handling '%s %s'", cmd.NAME, cmd.to_line()) - stream_name = cmd.stream_name - stream = self._streams.get(stream_name) - if not stream: - logger.error("Got POSITION for unknown stream: %s", stream_name) - return + self._add_command_to_stream_queue(conn, cmd) - # We protect catching up with a linearizer in case the replication - # connection reconnects under us. - with await self._position_linearizer.queue(stream_name): - # We're about to go and catch up with the stream, so remove from set - # of connected streams. - for streams in self._streams_by_connection.values(): - streams.discard(stream_name) - - # We clear the pending batches for the stream as the fetching of the - # missing updates below will fetch all rows in the batch. - self._pending_batches.pop(stream_name, []) - - # Find where we previously streamed up to. - current_token = stream.current_token(cmd.instance_name) - - # If the position token matches our current token then we're up to - # date and there's nothing to do. Otherwise, fetch all updates - # between then and now. - missing_updates = cmd.token != current_token - while missing_updates: - logger.info( - "Fetching replication rows for '%s' between %i and %i", - stream_name, - current_token, - cmd.token, - ) - ( - updates, - current_token, - missing_updates, - ) = await stream.get_updates_since( - cmd.instance_name, current_token, cmd.token - ) + async def _process_position( + self, stream_name: str, conn: AbstractConnection, cmd: PositionCommand + ) -> None: + """Process a POSITION command - # TODO: add some tests for this + Called after the command has been popped off the queue of inbound commands + """ + stream = self._streams[stream_name] - # Some streams return multiple rows with the same stream IDs, - # which need to be processed in batches. + # We're about to go and catch up with the stream, so remove from set + # of connected streams. + for streams in self._streams_by_connection.values(): + streams.discard(stream_name) - for token, rows in _batch_updates(updates): - await self.on_rdata( - stream_name, - cmd.instance_name, - token, - [stream.parse_row(row) for row in rows], - ) + # We clear the pending batches for the stream as the fetching of the + # missing updates below will fetch all rows in the batch. + self._pending_batches.pop(stream_name, []) - logger.info("Caught up with stream '%s' to %i", stream_name, cmd.token) + # Find where we previously streamed up to. + current_token = stream.current_token(cmd.instance_name) - # We've now caught up to position sent to us, notify handler. - await self._replication_data_handler.on_position( - cmd.stream_name, cmd.instance_name, cmd.token + # If the position token matches our current token then we're up to + # date and there's nothing to do. Otherwise, fetch all updates + # between then and now. + missing_updates = cmd.token != current_token + while missing_updates: + logger.info( + "Fetching replication rows for '%s' between %i and %i", + stream_name, + current_token, + cmd.token, + ) + (updates, current_token, missing_updates) = await stream.get_updates_since( + cmd.instance_name, current_token, cmd.token ) - self._streams_by_connection.setdefault(conn, set()).add(stream_name) + # TODO: add some tests for this - async def on_REMOTE_SERVER_UP( - self, conn: AbstractConnection, cmd: RemoteServerUpCommand - ): + # Some streams return multiple rows with the same stream IDs, + # which need to be processed in batches. + + for token, rows in _batch_updates(updates): + await self.on_rdata( + stream_name, + cmd.instance_name, + token, + [stream.parse_row(row) for row in rows], + ) + + logger.info("Caught up with stream '%s' to %i", stream_name, cmd.token) + + # We've now caught up to position sent to us, notify handler. + await self._replication_data_handler.on_position( + cmd.stream_name, cmd.instance_name, cmd.token + ) + + self._streams_by_connection.setdefault(conn, set()).add(stream_name) + + def on_REMOTE_SERVER_UP(self, conn: AbstractConnection, cmd: RemoteServerUpCommand): """"Called when get a new REMOTE_SERVER_UP command.""" self._replication_data_handler.on_remote_server_up(cmd.data) @@ -527,7 +650,7 @@ class ReplicationCommandHandler: """Ack data for the federation stream. This allows the master to drop data stored purely in memory. """ - self.send_command(FederationAckCommand(token)) + self.send_command(FederationAckCommand(self._instance_name, token)) def send_user_sync( self, instance_id: str, user_id: str, is_syncing: bool, last_sync_ms: int diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py index ca47f5cc..03509238 100644 --- a/synapse/replication/tcp/protocol.py +++ b/synapse/replication/tcp/protocol.py @@ -50,6 +50,7 @@ import abc import fcntl import logging import struct +from inspect import isawaitable from typing import TYPE_CHECKING, List from prometheus_client import Counter @@ -57,8 +58,12 @@ from prometheus_client import Counter from twisted.protocols.basic import LineOnlyReceiver from twisted.python.failure import Failure +from synapse.logging.context import PreserveLoggingContext from synapse.metrics import LaterGauge -from synapse.metrics.background_process_metrics import run_as_background_process +from synapse.metrics.background_process_metrics import ( + BackgroundProcessLoggingContext, + run_as_background_process, +) from synapse.replication.tcp.commands import ( VALID_CLIENT_COMMANDS, VALID_SERVER_COMMANDS, @@ -124,6 +129,8 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): On receiving a new command it calls `on_<COMMAND_NAME>` with the parsed command before delegating to `ReplicationCommandHandler.on_<COMMAND_NAME>`. + `ReplicationCommandHandler.on_<COMMAND_NAME>` can optionally return a coroutine; + if so, that will get run as a background process. It also sends `PING` periodically, and correctly times out remote connections (if they send a `PING` command) @@ -160,6 +167,12 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): # The LoopingCall for sending pings. self._send_ping_loop = None + # a logcontext which we use for processing incoming commands. We declare it as a + # background process so that the CPU stats get reported to prometheus. + ctx_name = "replication-conn-%s" % self.conn_id + self._logging_context = BackgroundProcessLoggingContext(ctx_name) + self._logging_context.request = ctx_name + def connectionMade(self): logger.info("[%s] Connection established", self.id()) @@ -210,6 +223,10 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): def lineReceived(self, line: bytes): """Called when we've received a line """ + with PreserveLoggingContext(self._logging_context): + self._parse_and_dispatch_line(line) + + def _parse_and_dispatch_line(self, line: bytes): if line.strip() == "": # Ignore blank lines return @@ -232,18 +249,17 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): tcp_inbound_commands_counter.labels(cmd.NAME, self.name).inc() - # Now lets try and call on_<CMD_NAME> function - run_as_background_process( - "replication-" + cmd.get_logcontext_id(), self.handle_command, cmd - ) + self.handle_command(cmd) - async def handle_command(self, cmd: Command): + def handle_command(self, cmd: Command) -> None: """Handle a command we have received over the replication stream. First calls `self.on_<COMMAND>` if it exists, then calls - `self.command_handler.on_<COMMAND>` if it exists. This allows for - protocol level handling of commands (e.g. PINGs), before delegating to - the handler. + `self.command_handler.on_<COMMAND>` if it exists (which can optionally + return an Awaitable). + + This allows for protocol level handling of commands (e.g. PINGs), before + delegating to the handler. Args: cmd: received command @@ -254,13 +270,22 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): # specific handling. cmd_func = getattr(self, "on_%s" % (cmd.NAME,), None) if cmd_func: - await cmd_func(cmd) + cmd_func(cmd) handled = True # Then call out to the handler. cmd_func = getattr(self.command_handler, "on_%s" % (cmd.NAME,), None) if cmd_func: - await cmd_func(self, cmd) + res = cmd_func(self, cmd) + + # the handler might be a coroutine: fire it off as a background process + # if so. + + if isawaitable(res): + run_as_background_process( + "replication-" + cmd.get_logcontext_id(), lambda: res + ) + handled = True if not handled: @@ -336,10 +361,10 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): for cmd in pending: self.send_command(cmd) - async def on_PING(self, line): + def on_PING(self, line): self.received_ping = True - async def on_ERROR(self, cmd): + def on_ERROR(self, cmd): logger.error("[%s] Remote reported error: %r", self.id(), cmd.data) def pauseProducing(self): @@ -397,6 +422,9 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): if self.transport: self.transport.unregisterProducer() + # mark the logging context as finished + self._logging_context.__exit__(None, None, None) + def __str__(self): addr = None if self.transport: @@ -431,7 +459,7 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol): self.send_command(ServerCommand(self.server_name)) super().connectionMade() - async def on_NAME(self, cmd): + def on_NAME(self, cmd): logger.info("[%s] Renamed to %r", self.id(), cmd.data) self.name = cmd.data @@ -460,7 +488,7 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol): # Once we've connected subscribe to the necessary streams self.replicate() - async def on_SERVER(self, cmd): + def on_SERVER(self, cmd): if cmd.data != self.server_name: logger.error("[%s] Connected to wrong remote: %r", self.id(), cmd.data) self.send_error("Wrong remote") diff --git a/synapse/replication/tcp/redis.py b/synapse/replication/tcp/redis.py index 0a7e7f67..f225e533 100644 --- a/synapse/replication/tcp/redis.py +++ b/synapse/replication/tcp/redis.py @@ -14,12 +14,16 @@ # limitations under the License. import logging +from inspect import isawaitable from typing import TYPE_CHECKING import txredisapi -from synapse.logging.context import make_deferred_yieldable -from synapse.metrics.background_process_metrics import run_as_background_process +from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable +from synapse.metrics.background_process_metrics import ( + BackgroundProcessLoggingContext, + run_as_background_process, +) from synapse.replication.tcp.commands import ( Command, ReplicateCommand, @@ -66,6 +70,15 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection): stream_name = None # type: str outbound_redis_connection = None # type: txredisapi.RedisProtocol + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # a logcontext which we use for processing incoming commands. We declare it as a + # background process so that the CPU stats get reported to prometheus. + self._logging_context = BackgroundProcessLoggingContext( + "replication_command_handler" + ) + def connectionMade(self): logger.info("Connected to redis") super().connectionMade() @@ -92,7 +105,10 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection): def messageReceived(self, pattern: str, channel: str, message: str): """Received a message from redis. """ + with PreserveLoggingContext(self._logging_context): + self._parse_and_dispatch_message(message) + def _parse_and_dispatch_message(self, message: str): if message.strip() == "": # Ignore blank lines return @@ -109,42 +125,41 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection): # remote instances. tcp_inbound_commands_counter.labels(cmd.NAME, "redis").inc() - # Now lets try and call on_<CMD_NAME> function - run_as_background_process( - "replication-" + cmd.get_logcontext_id(), self.handle_command, cmd - ) + self.handle_command(cmd) - async def handle_command(self, cmd: Command): + def handle_command(self, cmd: Command) -> None: """Handle a command we have received over the replication stream. - By default delegates to on_<COMMAND>, which should return an awaitable. + Delegates to `self.handler.on_<COMMAND>` (which can optionally return an + Awaitable). Args: cmd: received command """ - handled = False - - # First call any command handlers on this instance. These are for redis - # specific handling. - cmd_func = getattr(self, "on_%s" % (cmd.NAME,), None) - if cmd_func: - await cmd_func(cmd) - handled = True - # Then call out to the handler. cmd_func = getattr(self.handler, "on_%s" % (cmd.NAME,), None) - if cmd_func: - await cmd_func(self, cmd) - handled = True - - if not handled: + if not cmd_func: logger.warning("Unhandled command: %r", cmd) + return + + res = cmd_func(self, cmd) + + # the handler might be a coroutine: fire it off as a background process + # if so. + + if isawaitable(res): + run_as_background_process( + "replication-" + cmd.get_logcontext_id(), lambda: res + ) def connectionLost(self, reason): logger.info("Lost connection to redis") super().connectionLost(reason) self.handler.lost_connection(self) + # mark the logging context as finished + self._logging_context.__exit__(None, None, None) + def send_command(self, cmd: Command): """Send a command if connection has been established. diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py index 9076bbe9..7a42de3f 100644 --- a/synapse/replication/tcp/streams/_base.py +++ b/synapse/replication/tcp/streams/_base.py @@ -294,11 +294,12 @@ class TypingStream(Stream): def __init__(self, hs): typing_handler = hs.get_typing_handler() - if hs.config.worker_app is None: - # on the master, query the typing handler + writer_instance = hs.config.worker.writers.typing + if writer_instance == hs.get_instance_name(): + # On the writer, query the typing handler update_function = typing_handler.get_all_typing_updates else: - # Query master process + # Query the typing writer process update_function = make_http_update_function(hs, self.NAME) super().__init__( diff --git a/synapse/replication/tcp/streams/events.py b/synapse/replication/tcp/streams/events.py index 1c2a4cce..16c63ff4 100644 --- a/synapse/replication/tcp/streams/events.py +++ b/synapse/replication/tcp/streams/events.py @@ -14,7 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import heapq -from collections import Iterable +from collections.abc import Iterable from typing import List, Tuple, Type import attr |