author     Andrej Shadura <andrewsh@debian.org>    2022-04-22 20:34:38 +0200
committer  Andrej Shadura <andrewsh@debian.org>    2022-04-22 20:34:38 +0200
commit     1b9b92888056ce7fd1f3a010ca7afd5c3963d44e (patch)
tree       716cb361eb4332eca7b147f35c87ceb29b3ac958 /synapse
parent     02eb467b57bf21597094de52232c93d5f4a38b7d (diff)
New upstream version 1.57.1
Diffstat (limited to 'synapse')
-rw-r--r--  synapse/__init__.py | 2
-rwxr-xr-x  synapse/_scripts/export_signing_key.py | 13
-rwxr-xr-x  synapse/_scripts/generate_config.py | 2
-rwxr-xr-x  synapse/_scripts/generate_log_config.py | 2
-rwxr-xr-x  synapse/_scripts/generate_signing_key.py | 2
-rwxr-xr-x  synapse/_scripts/hash_password.py | 4
-rwxr-xr-x  synapse/_scripts/move_remote_media_to_new_store.py | 28
-rw-r--r--  synapse/_scripts/register_new_matrix_user.py | 3
-rw-r--r--  synapse/_scripts/review_recent_signups.py | 4
-rwxr-xr-x  synapse/_scripts/synapse_port_db.py | 271
-rwxr-xr-x  synapse/_scripts/synctl.py | 10
-rwxr-xr-x  synapse/_scripts/update_synapse_database.py | 39
-rw-r--r--  synapse/app/_base.py | 4
-rw-r--r--  synapse/app/admin_cmd.py | 2
-rw-r--r--  synapse/app/generic_worker.py | 2
-rw-r--r--  synapse/appservice/__init__.py | 12
-rw-r--r--  synapse/appservice/api.py | 10
-rw-r--r--  synapse/appservice/scheduler.py | 53
-rw-r--r--  synapse/config/_base.py | 5
-rw-r--r--  synapse/config/_base.pyi | 5
-rw-r--r--  synapse/config/account_validity.py | 4
-rw-r--r--  synapse/config/api.py | 6
-rw-r--r--  synapse/config/appservice.py | 7
-rw-r--r--  synapse/config/auth.py | 7
-rw-r--r--  synapse/config/background_updates.py | 7
-rw-r--r--  synapse/config/cache.py | 7
-rw-r--r--  synapse/config/captcha.py | 26
-rw-r--r--  synapse/config/cas.py | 5
-rw-r--r--  synapse/config/consent.py | 13
-rw-r--r--  synapse/config/database.py | 12
-rw-r--r--  synapse/config/emailconfig.py | 7
-rw-r--r--  synapse/config/experimental.py | 14
-rw-r--r--  synapse/config/federation.py | 7
-rw-r--r--  synapse/config/groups.py | 8
-rw-r--r--  synapse/config/jwt.py | 8
-rw-r--r--  synapse/config/key.py | 26
-rw-r--r--  synapse/config/logger.py | 7
-rw-r--r--  synapse/config/metrics.py | 9
-rw-r--r--  synapse/config/modules.py | 5
-rw-r--r--  synapse/config/oembed.py | 4
-rw-r--r--  synapse/config/oidc.py | 4
-rw-r--r--  synapse/config/password_auth_providers.py | 3
-rw-r--r--  synapse/config/push.py | 8
-rw-r--r--  synapse/config/ratelimiting.py | 8
-rw-r--r--  synapse/config/redis.py | 7
-rw-r--r--  synapse/config/registration.py | 10
-rw-r--r--  synapse/config/repository.py | 7
-rw-r--r--  synapse/config/retention.py | 7
-rw-r--r--  synapse/config/room.py | 6
-rw-r--r--  synapse/config/room_directory.py | 6
-rw-r--r--  synapse/config/saml2.py | 10
-rw-r--r--  synapse/config/server.py | 41
-rw-r--r--  synapse/config/server_notices.py | 19
-rw-r--r--  synapse/config/spam_checker.py | 3
-rw-r--r--  synapse/config/sso.py | 6
-rw-r--r--  synapse/config/stats.py | 7
-rw-r--r--  synapse/config/third_party_event_rules.py | 5
-rw-r--r--  synapse/config/tls.py | 19
-rw-r--r--  synapse/config/tracer.py | 7
-rw-r--r--  synapse/config/user_directory.py | 8
-rw-r--r--  synapse/config/voip.py | 8
-rw-r--r--  synapse/config/workers.py | 16
-rw-r--r--  synapse/crypto/keyring.py | 2
-rw-r--r--  synapse/events/builder.py | 2
-rw-r--r--  synapse/events/third_party_rules.py | 26
-rw-r--r--  synapse/federation/federation_client.py | 3
-rw-r--r--  synapse/federation/federation_server.py | 10
-rw-r--r--  synapse/federation/transport/client.py | 7
-rw-r--r--  synapse/handlers/account_data.py | 52
-rw-r--r--  synapse/handlers/account_validity.py | 4
-rw-r--r--  synapse/handlers/appservice.py | 154
-rw-r--r--  synapse/handlers/auth.py | 3
-rw-r--r--  synapse/handlers/device.py | 134
-rw-r--r--  synapse/handlers/e2e_keys.py | 4
-rw-r--r--  synapse/handlers/e2e_room_keys.py | 14
-rw-r--r--  synapse/handlers/federation.py | 2
-rw-r--r--  synapse/handlers/federation_event.py | 51
-rw-r--r--  synapse/handlers/identity.py | 2
-rw-r--r--  synapse/handlers/message.py | 2
-rw-r--r--  synapse/handlers/pagination.py | 9
-rw-r--r--  synapse/handlers/presence.py | 6
-rw-r--r--  synapse/handlers/read_marker.py | 2
-rw-r--r--  synapse/handlers/relations.py | 275
-rw-r--r--  synapse/handlers/room.py | 10
-rw-r--r--  synapse/handlers/room_batch.py | 38
-rw-r--r--  synapse/handlers/room_member.py | 4
-rw-r--r--  synapse/handlers/search.py | 30
-rw-r--r--  synapse/handlers/sso.py | 2
-rw-r--r--  synapse/handlers/sync.py | 121
-rw-r--r--  synapse/handlers/ui_auth/checkers.py | 2
-rw-r--r--  synapse/http/client.py | 16
-rw-r--r--  synapse/http/matrixfederationclient.py | 35
-rw-r--r--  synapse/http/types.py | 21
-rw-r--r--  synapse/logging/opentracing.py | 3
-rw-r--r--  synapse/metrics/__init__.py | 16
-rw-r--r--  synapse/metrics/_gc.py | 6
-rw-r--r--  synapse/metrics/_reactor_metrics.py | 4
-rw-r--r--  synapse/metrics/_types.py | 31
-rw-r--r--  synapse/metrics/background_process_metrics.py | 3
-rw-r--r--  synapse/metrics/jemalloc.py | 20
-rw-r--r--  synapse/module_api/__init__.py | 105
-rw-r--r--  synapse/push/bulk_push_rule_evaluator.py | 2
-rw-r--r--  synapse/python_dependencies.py | 8
-rw-r--r--  synapse/replication/slave/storage/client_ips.py | 59
-rw-r--r--  synapse/replication/slave/storage/devices.py | 16
-rw-r--r--  synapse/replication/tcp/client.py | 2
-rw-r--r--  synapse/replication/tcp/commands.py | 8
-rw-r--r--  synapse/replication/tcp/external_cache.py | 31
-rw-r--r--  synapse/replication/tcp/handler.py | 48
-rw-r--r--  synapse/rest/client/relations.py | 172
-rw-r--r--  synapse/rest/client/room_batch.py | 21
-rw-r--r--  synapse/rest/client/sync.py | 19
-rw-r--r--  synapse/rest/key/v2/local_key_resource.py | 12
-rw-r--r--  synapse/rest/key/v2/remote_key_resource.py | 8
-rw-r--r--  synapse/rest/media/v1/media_repository.py | 6
-rw-r--r--  synapse/rest/media/v1/preview_url_resource.py | 9
-rw-r--r--  synapse/server_notices/server_notices_manager.py | 59
-rw-r--r--  synapse/state/__init__.py | 2
-rw-r--r--  synapse/storage/database.py | 167
-rw-r--r--  synapse/storage/databases/main/__init__.py | 20
-rw-r--r--  synapse/storage/databases/main/appservice.py | 107
-rw-r--r--  synapse/storage/databases/main/client_ips.py | 167
-rw-r--r--  synapse/storage/databases/main/devices.py | 315
-rw-r--r--  synapse/storage/databases/main/events.py | 10
-rw-r--r--  synapse/storage/databases/main/events_worker.py | 8
-rw-r--r--  synapse/storage/databases/main/monthly_active_users.py | 60
-rw-r--r--  synapse/storage/databases/main/receipts.py | 13
-rw-r--r--  synapse/storage/databases/main/registration.py | 157
-rw-r--r--  synapse/storage/databases/main/relations.py | 227
-rw-r--r--  synapse/storage/databases/main/roommember.py | 23
-rw-r--r--  synapse/storage/databases/main/signatures.py | 2
-rw-r--r--  synapse/storage/databases/main/state.py | 32
-rw-r--r--  synapse/storage/databases/main/stream.py | 70
-rw-r--r--  synapse/storage/databases/main/tags.py | 4
-rw-r--r--  synapse/storage/relations.py | 84
-rw-r--r--  synapse/storage/schema/__init__.py | 6
-rw-r--r--  synapse/storage/schema/main/delta/68/06_msc3202_add_device_list_appservice_stream_type.sql | 23
-rw-r--r--  synapse/storage/schema/main/delta/69/01as_txn_seq.py | 44
-rw-r--r--  synapse/storage/schema/main/delta/69/01device_list_oubound_by_room.sql | 38
-rw-r--r--  synapse/storage/state.py | 20
-rw-r--r--  synapse/storage/types.py | 1
-rw-r--r--  synapse/streams/events.py | 4
-rw-r--r--  synapse/types.py | 126
-rw-r--r--  synapse/util/async_helpers.py | 148
-rw-r--r--  synapse/util/caches/__init__.py | 6
-rw-r--r--  synapse/visibility.py | 234
146 files changed, 3132 insertions, 1574 deletions
diff --git a/synapse/__init__.py b/synapse/__init__.py
index 2e651053..b62eed66 100644
--- a/synapse/__init__.py
+++ b/synapse/__init__.py
@@ -68,7 +68,7 @@ try:
except ImportError:
pass
-__version__ = "1.56.0"
+__version__ = "1.57.1"
if bool(os.environ.get("SYNAPSE_TEST_PATCH_LOG_CONTEXTS", False)):
# We import here so that we don't have to install a bunch of deps when
diff --git a/synapse/_scripts/export_signing_key.py b/synapse/_scripts/export_signing_key.py
index 3d254348..12c890bd 100755
--- a/synapse/_scripts/export_signing_key.py
+++ b/synapse/_scripts/export_signing_key.py
@@ -15,19 +15,19 @@
import argparse
import sys
import time
-from typing import Optional
+from typing import NoReturn, Optional
-import nacl.signing
from signedjson.key import encode_verify_key_base64, get_verify_key, read_signing_keys
+from signedjson.types import VerifyKey
-def exit(status: int = 0, message: Optional[str] = None):
+def exit(status: int = 0, message: Optional[str] = None) -> NoReturn:
if message:
print(message, file=sys.stderr)
sys.exit(status)
-def format_plain(public_key: nacl.signing.VerifyKey):
+def format_plain(public_key: VerifyKey) -> None:
print(
"%s:%s %s"
% (
@@ -38,7 +38,7 @@ def format_plain(public_key: nacl.signing.VerifyKey):
)
-def format_for_config(public_key: nacl.signing.VerifyKey, expiry_ts: int):
+def format_for_config(public_key: VerifyKey, expiry_ts: int) -> None:
print(
' "%s:%s": { key: "%s", expired_ts: %i }'
% (
@@ -50,7 +50,7 @@ def format_for_config(public_key: nacl.signing.VerifyKey, expiry_ts: int):
)
-def main():
+def main() -> None:
parser = argparse.ArgumentParser()
parser.add_argument(
@@ -94,7 +94,6 @@ def main():
message="Error reading key from file %s: %s %s"
% (file.name, type(e), e),
)
- res = []
for key in res:
formatter(get_verify_key(key))
diff --git a/synapse/_scripts/generate_config.py b/synapse/_scripts/generate_config.py
index 75fce20b..08eb8ef1 100755
--- a/synapse/_scripts/generate_config.py
+++ b/synapse/_scripts/generate_config.py
@@ -7,7 +7,7 @@ import sys
from synapse.config.homeserver import HomeServerConfig
-def main():
+def main() -> None:
parser = argparse.ArgumentParser()
parser.add_argument(
"--config-dir",
diff --git a/synapse/_scripts/generate_log_config.py b/synapse/_scripts/generate_log_config.py
index 82fc7631..7ae08ec0 100755
--- a/synapse/_scripts/generate_log_config.py
+++ b/synapse/_scripts/generate_log_config.py
@@ -20,7 +20,7 @@ import sys
from synapse.config.logger import DEFAULT_LOG_CONFIG
-def main():
+def main() -> None:
parser = argparse.ArgumentParser()
parser.add_argument(
diff --git a/synapse/_scripts/generate_signing_key.py b/synapse/_scripts/generate_signing_key.py
index bc26d25b..3f8f5da7 100755
--- a/synapse/_scripts/generate_signing_key.py
+++ b/synapse/_scripts/generate_signing_key.py
@@ -20,7 +20,7 @@ from signedjson.key import generate_signing_key, write_signing_keys
from synapse.util.stringutils import random_string
-def main():
+def main() -> None:
parser = argparse.ArgumentParser()
parser.add_argument(
diff --git a/synapse/_scripts/hash_password.py b/synapse/_scripts/hash_password.py
index 708640c7..3aa29de5 100755
--- a/synapse/_scripts/hash_password.py
+++ b/synapse/_scripts/hash_password.py
@@ -9,7 +9,7 @@ import bcrypt
import yaml
-def prompt_for_pass():
+def prompt_for_pass() -> str:
password = getpass.getpass("Password: ")
if not password:
@@ -23,7 +23,7 @@ def prompt_for_pass():
return password
-def main():
+def main() -> None:
bcrypt_rounds = 12
password_pepper = ""
diff --git a/synapse/_scripts/move_remote_media_to_new_store.py b/synapse/_scripts/move_remote_media_to_new_store.py
index 9667d95d..819afaac 100755
--- a/synapse/_scripts/move_remote_media_to_new_store.py
+++ b/synapse/_scripts/move_remote_media_to_new_store.py
@@ -42,7 +42,7 @@ from synapse.rest.media.v1.filepath import MediaFilePaths
logger = logging.getLogger()
-def main(src_repo, dest_repo):
+def main(src_repo: str, dest_repo: str) -> None:
src_paths = MediaFilePaths(src_repo)
dest_paths = MediaFilePaths(dest_repo)
for line in sys.stdin:
@@ -55,14 +55,19 @@ def main(src_repo, dest_repo):
move_media(parts[0], parts[1], src_paths, dest_paths)
-def move_media(origin_server, file_id, src_paths, dest_paths):
+def move_media(
+ origin_server: str,
+ file_id: str,
+ src_paths: MediaFilePaths,
+ dest_paths: MediaFilePaths,
+) -> None:
"""Move the given file, and any thumbnails, to the dest repo
Args:
- origin_server (str):
- file_id (str):
- src_paths (MediaFilePaths):
- dest_paths (MediaFilePaths):
+ origin_server:
+ file_id:
+ src_paths:
+ dest_paths:
"""
logger.info("%s/%s", origin_server, file_id)
@@ -91,7 +96,7 @@ def move_media(origin_server, file_id, src_paths, dest_paths):
)
-def mkdir_and_move(original_file, dest_file):
+def mkdir_and_move(original_file: str, dest_file: str) -> None:
dirname = os.path.dirname(dest_file)
if not os.path.exists(dirname):
logger.debug("mkdir %s", dirname)
@@ -109,10 +114,9 @@ if __name__ == "__main__":
parser.add_argument("dest_repo", help="Path to source content repo")
args = parser.parse_args()
- logging_config = {
- "level": logging.DEBUG if args.v else logging.INFO,
- "format": "%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(message)s",
- }
- logging.basicConfig(**logging_config)
+ logging.basicConfig(
+ level=logging.DEBUG if args.v else logging.INFO,
+ format="%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(message)s",
+ )
main(args.src_repo, args.dest_repo)
diff --git a/synapse/_scripts/register_new_matrix_user.py b/synapse/_scripts/register_new_matrix_user.py
index 4ffe6a1e..092601f5 100644
--- a/synapse/_scripts/register_new_matrix_user.py
+++ b/synapse/_scripts/register_new_matrix_user.py
@@ -22,7 +22,7 @@ import logging
import sys
from typing import Callable, Optional
-import requests as _requests
+import requests
import yaml
@@ -33,7 +33,6 @@ def request_registration(
shared_secret: str,
admin: bool = False,
user_type: Optional[str] = None,
- requests=_requests,
_print: Callable[[str], None] = print,
exit: Callable[[int], None] = sys.exit,
) -> None:
diff --git a/synapse/_scripts/review_recent_signups.py b/synapse/_scripts/review_recent_signups.py
index e207f154..a935c50f 100644
--- a/synapse/_scripts/review_recent_signups.py
+++ b/synapse/_scripts/review_recent_signups.py
@@ -138,9 +138,7 @@ def main() -> None:
config_args = parser.parse_args(sys.argv[1:])
config_files = find_config_files(search_paths=config_args.config_path)
config_dict = read_config_files(config_files)
- config.parse_config_dict(
- config_dict,
- )
+ config.parse_config_dict(config_dict, "", "")
since_ms = time.time() * 1000 - Config.parse_duration(config_args.since)
exclude_users_with_email = config_args.exclude_emails
diff --git a/synapse/_scripts/synapse_port_db.py b/synapse/_scripts/synapse_port_db.py
index c38666da..12ff79f6 100755
--- a/synapse/_scripts/synapse_port_db.py
+++ b/synapse/_scripts/synapse_port_db.py
@@ -21,12 +21,29 @@ import logging
import sys
import time
import traceback
-from typing import Dict, Iterable, Optional, Set
+from types import TracebackType
+from typing import (
+ Any,
+ Awaitable,
+ Callable,
+ Dict,
+ Generator,
+ Iterable,
+ List,
+ NoReturn,
+ Optional,
+ Set,
+ Tuple,
+ Type,
+ TypeVar,
+ cast,
+)
import yaml
from matrix_common.versionstring import get_distribution_version_string
+from typing_extensions import TypedDict
-from twisted.internet import defer, reactor
+from twisted.internet import defer, reactor as reactor_
from synapse.config.database import DatabaseConnectionConfig
from synapse.config.homeserver import HomeServerConfig
@@ -35,7 +52,7 @@ from synapse.logging.context import (
make_deferred_yieldable,
run_in_background,
)
-from synapse.storage.database import DatabasePool, make_conn
+from synapse.storage.database import DatabasePool, LoggingTransaction, make_conn
from synapse.storage.databases.main import PushRuleStore
from synapse.storage.databases.main.account_data import AccountDataWorkerStore
from synapse.storage.databases.main.client_ips import ClientIpBackgroundUpdateStore
@@ -66,8 +83,12 @@ from synapse.storage.databases.main.user_directory import (
from synapse.storage.databases.state.bg_updates import StateBackgroundUpdateStore
from synapse.storage.engines import create_engine
from synapse.storage.prepare_database import prepare_database
+from synapse.types import ISynapseReactor
from synapse.util import Clock
+# Cast safety: Twisted does some naughty magic which replaces the
+# twisted.internet.reactor module with a Reactor instance at runtime.
+reactor = cast(ISynapseReactor, reactor_)
logger = logging.getLogger("synapse_port_db")
@@ -97,6 +118,7 @@ BOOLEAN_COLUMNS = {
"users": ["shadow_banned"],
"e2e_fallback_keys_json": ["used"],
"access_tokens": ["used"],
+ "device_lists_changes_in_room": ["converted_to_destinations"],
}
@@ -158,12 +180,16 @@ IGNORED_TABLES = {
# Error returned by the run function. Used at the top-level part of the script to
# handle errors and return codes.
-end_error = None # type: Optional[str]
+end_error: Optional[str] = None
# The exec_info for the error, if any. If error is defined but not exec_info the script
# will show only the error message without the stacktrace, if exec_info is defined but
# not the error then the script will show nothing outside of what's printed in the run
# function. If both are defined, the script will print both the error and the stacktrace.
-end_error_exec_info = None
+end_error_exec_info: Optional[
+ Tuple[Type[BaseException], BaseException, TracebackType]
+] = None
+
+R = TypeVar("R")
class Store(
@@ -187,17 +213,19 @@ class Store(
PresenceBackgroundUpdateStore,
GroupServerWorkerStore,
):
- def execute(self, f, *args, **kwargs):
+ def execute(self, f: Callable[..., R], *args: Any, **kwargs: Any) -> Awaitable[R]:
return self.db_pool.runInteraction(f.__name__, f, *args, **kwargs)
- def execute_sql(self, sql, *args):
- def r(txn):
+ def execute_sql(self, sql: str, *args: object) -> Awaitable[List[Tuple]]:
+ def r(txn: LoggingTransaction) -> List[Tuple]:
txn.execute(sql, args)
return txn.fetchall()
return self.db_pool.runInteraction("execute_sql", r)
- def insert_many_txn(self, txn, table, headers, rows):
+ def insert_many_txn(
+ self, txn: LoggingTransaction, table: str, headers: List[str], rows: List[Tuple]
+ ) -> None:
sql = "INSERT INTO %s (%s) VALUES (%s)" % (
table,
", ".join(k for k in headers),
@@ -210,14 +238,15 @@ class Store(
logger.exception("Failed to insert: %s", table)
raise
- def set_room_is_public(self, room_id, is_public):
+ # Note: the parent method is an `async def`.
+ def set_room_is_public(self, room_id: str, is_public: bool) -> NoReturn:
raise Exception(
"Attempt to set room_is_public during port_db: database not empty?"
)
class MockHomeserver:
- def __init__(self, config):
+ def __init__(self, config: HomeServerConfig):
self.clock = Clock(reactor)
self.config = config
self.hostname = config.server.server_name
@@ -225,21 +254,30 @@ class MockHomeserver:
"matrix-synapse"
)
- def get_clock(self):
+ def get_clock(self) -> Clock:
return self.clock
- def get_reactor(self):
+ def get_reactor(self) -> ISynapseReactor:
return reactor
- def get_instance_name(self):
+ def get_instance_name(self) -> str:
return "master"
-class Porter(object):
- def __init__(self, **kwargs):
- self.__dict__.update(kwargs)
+class Porter:
+ def __init__(
+ self,
+ sqlite_config: Dict[str, Any],
+ progress: "Progress",
+ batch_size: int,
+ hs_config: HomeServerConfig,
+ ):
+ self.sqlite_config = sqlite_config
+ self.progress = progress
+ self.batch_size = batch_size
+ self.hs_config = hs_config
- async def setup_table(self, table):
+ async def setup_table(self, table: str) -> Tuple[str, int, int, int, int]:
if table in APPEND_ONLY_TABLES:
# It's safe to just carry on inserting.
row = await self.postgres_store.db_pool.simple_select_one(
@@ -281,7 +319,7 @@ class Porter(object):
)
else:
- def delete_all(txn):
+ def delete_all(txn: LoggingTransaction) -> None:
txn.execute(
"DELETE FROM port_from_sqlite3 WHERE table_name = %s", (table,)
)
@@ -306,7 +344,7 @@ class Porter(object):
async def get_table_constraints(self) -> Dict[str, Set[str]]:
"""Returns a map of tables that have foreign key constraints to tables they depend on."""
- def _get_constraints(txn):
+ def _get_constraints(txn: LoggingTransaction) -> Dict[str, Set[str]]:
# We can pull the information about foreign key constraints out from
# the postgres schema tables.
sql = """
@@ -322,7 +360,7 @@ class Porter(object):
"""
txn.execute(sql)
- results = {}
+ results: Dict[str, Set[str]] = {}
for table, foreign_table in txn:
results.setdefault(table, set()).add(foreign_table)
return results
@@ -332,8 +370,13 @@ class Porter(object):
)
async def handle_table(
- self, table, postgres_size, table_size, forward_chunk, backward_chunk
- ):
+ self,
+ table: str,
+ postgres_size: int,
+ table_size: int,
+ forward_chunk: int,
+ backward_chunk: int,
+ ) -> None:
logger.info(
"Table %s: %i/%i (rows %i-%i) already ported",
table,
@@ -380,7 +423,9 @@ class Porter(object):
while True:
- def r(txn):
+ def r(
+ txn: LoggingTransaction,
+ ) -> Tuple[Optional[List[str]], List[Tuple], List[Tuple]]:
forward_rows = []
backward_rows = []
if do_forward[0]:
@@ -407,6 +452,7 @@ class Porter(object):
)
if frows or brows:
+ assert headers is not None
if frows:
forward_chunk = max(row[0] for row in frows) + 1
if brows:
@@ -415,7 +461,8 @@ class Porter(object):
rows = frows + brows
rows = self._convert_rows(table, headers, rows)
- def insert(txn):
+ def insert(txn: LoggingTransaction) -> None:
+ assert headers is not None
self.postgres_store.insert_many_txn(txn, table, headers[1:], rows)
self.postgres_store.db_pool.simple_update_one_txn(
@@ -437,8 +484,12 @@ class Porter(object):
return
async def handle_search_table(
- self, postgres_size, table_size, forward_chunk, backward_chunk
- ):
+ self,
+ postgres_size: int,
+ table_size: int,
+ forward_chunk: int,
+ backward_chunk: int,
+ ) -> None:
select = (
"SELECT es.rowid, es.*, e.origin_server_ts, e.stream_ordering"
" FROM event_search as es"
@@ -449,7 +500,7 @@ class Porter(object):
while True:
- def r(txn):
+ def r(txn: LoggingTransaction) -> Tuple[List[str], List[Tuple]]:
txn.execute(select, (forward_chunk, self.batch_size))
rows = txn.fetchall()
headers = [column[0] for column in txn.description]
@@ -463,7 +514,7 @@ class Porter(object):
# We have to treat event_search differently since it has a
# different structure in the two different databases.
- def insert(txn):
+ def insert(txn: LoggingTransaction) -> None:
sql = (
"INSERT INTO event_search (event_id, room_id, key,"
" sender, vector, origin_server_ts, stream_ordering)"
@@ -517,7 +568,7 @@ class Porter(object):
self,
db_config: DatabaseConnectionConfig,
allow_outdated_version: bool = False,
- ):
+ ) -> Store:
"""Builds and returns a database store using the provided configuration.
Args:
@@ -539,12 +590,13 @@ class Porter(object):
db_conn, allow_outdated_version=allow_outdated_version
)
prepare_database(db_conn, engine, config=self.hs_config)
- store = Store(DatabasePool(hs, db_config, engine), db_conn, hs)
+ # Type safety: ignore that we're using Mock homeservers here.
+ store = Store(DatabasePool(hs, db_config, engine), db_conn, hs) # type: ignore[arg-type]
db_conn.commit()
return store
- async def run_background_updates_on_postgres(self):
+ async def run_background_updates_on_postgres(self) -> None:
# Manually apply all background updates on the PostgreSQL database.
postgres_ready = (
await self.postgres_store.db_pool.updates.has_completed_background_updates()
@@ -556,12 +608,12 @@ class Porter(object):
self.progress.set_state("Running background updates on PostgreSQL")
while not postgres_ready:
- await self.postgres_store.db_pool.updates.do_next_background_update(100)
+ await self.postgres_store.db_pool.updates.do_next_background_update(True)
postgres_ready = await (
self.postgres_store.db_pool.updates.has_completed_background_updates()
)
- async def run(self):
+ async def run(self) -> None:
"""Ports the SQLite database to a PostgreSQL database.
When a fatal error is met, its message is assigned to the global "end_error"
@@ -597,7 +649,7 @@ class Porter(object):
self.progress.set_state("Creating port tables")
- def create_port_table(txn):
+ def create_port_table(txn: LoggingTransaction) -> None:
txn.execute(
"CREATE TABLE IF NOT EXISTS port_from_sqlite3 ("
" table_name varchar(100) NOT NULL UNIQUE,"
@@ -610,7 +662,7 @@ class Porter(object):
# We want people to be able to rerun this script from an old port
# so that they can pick up any missing events that were not
# ported across.
- def alter_table(txn):
+ def alter_table(txn: LoggingTransaction) -> None:
txn.execute(
"ALTER TABLE IF EXISTS port_from_sqlite3"
" RENAME rowid TO forward_rowid"
@@ -723,12 +775,16 @@ class Porter(object):
except Exception as e:
global end_error_exec_info
end_error = str(e)
- end_error_exec_info = sys.exc_info()
+ # Type safety: we're in an exception handler, so the exc_info() tuple
+ # will not be (None, None, None).
+ end_error_exec_info = sys.exc_info() # type: ignore[assignment]
logger.exception("")
finally:
reactor.stop()
- def _convert_rows(self, table, headers, rows):
+ def _convert_rows(
+ self, table: str, headers: List[str], rows: List[Tuple]
+ ) -> List[Tuple]:
bool_col_names = BOOLEAN_COLUMNS.get(table, [])
bool_cols = [i for i, h in enumerate(headers) if h in bool_col_names]
@@ -736,7 +792,7 @@ class Porter(object):
class BadValueException(Exception):
pass
- def conv(j, col):
+ def conv(j: int, col: object) -> object:
if j in bool_cols:
return bool(col)
if isinstance(col, bytes):
@@ -762,7 +818,7 @@ class Porter(object):
return outrows
- async def _setup_sent_transactions(self):
+ async def _setup_sent_transactions(self) -> Tuple[int, int, int]:
# Only save things from the last day
yesterday = int(time.time() * 1000) - 86400000
@@ -774,10 +830,10 @@ class Porter(object):
")"
)
- def r(txn):
+ def r(txn: LoggingTransaction) -> Tuple[List[str], List[Tuple]]:
txn.execute(select)
rows = txn.fetchall()
- headers = [column[0] for column in txn.description]
+ headers: List[str] = [column[0] for column in txn.description]
ts_ind = headers.index("ts")
@@ -791,7 +847,7 @@ class Porter(object):
if inserted_rows:
max_inserted_rowid = max(r[0] for r in rows)
- def insert(txn):
+ def insert(txn: LoggingTransaction) -> None:
self.postgres_store.insert_many_txn(
txn, "sent_transactions", headers[1:], rows
)
@@ -800,7 +856,7 @@ class Porter(object):
else:
max_inserted_rowid = 0
- def get_start_id(txn):
+ def get_start_id(txn: LoggingTransaction) -> int:
txn.execute(
"SELECT rowid FROM sent_transactions WHERE ts >= ?"
" ORDER BY rowid ASC LIMIT 1",
@@ -825,12 +881,13 @@ class Porter(object):
},
)
- def get_sent_table_size(txn):
+ def get_sent_table_size(txn: LoggingTransaction) -> int:
txn.execute(
"SELECT count(*) FROM sent_transactions" " WHERE ts >= ?", (yesterday,)
)
- (size,) = txn.fetchone()
- return int(size)
+ result = txn.fetchone()
+ assert result is not None
+ return int(result[0])
remaining_count = await self.sqlite_store.execute(get_sent_table_size)
@@ -838,25 +895,35 @@ class Porter(object):
return next_chunk, inserted_rows, total_count
- async def _get_remaining_count_to_port(self, table, forward_chunk, backward_chunk):
- frows = await self.sqlite_store.execute_sql(
- "SELECT count(*) FROM %s WHERE rowid >= ?" % (table,), forward_chunk
+ async def _get_remaining_count_to_port(
+ self, table: str, forward_chunk: int, backward_chunk: int
+ ) -> int:
+ frows = cast(
+ List[Tuple[int]],
+ await self.sqlite_store.execute_sql(
+ "SELECT count(*) FROM %s WHERE rowid >= ?" % (table,), forward_chunk
+ ),
)
- brows = await self.sqlite_store.execute_sql(
- "SELECT count(*) FROM %s WHERE rowid <= ?" % (table,), backward_chunk
+ brows = cast(
+ List[Tuple[int]],
+ await self.sqlite_store.execute_sql(
+ "SELECT count(*) FROM %s WHERE rowid <= ?" % (table,), backward_chunk
+ ),
)
return frows[0][0] + brows[0][0]
- async def _get_already_ported_count(self, table):
+ async def _get_already_ported_count(self, table: str) -> int:
rows = await self.postgres_store.execute_sql(
"SELECT count(*) FROM %s" % (table,)
)
return rows[0][0]
- async def _get_total_count_to_port(self, table, forward_chunk, backward_chunk):
+ async def _get_total_count_to_port(
+ self, table: str, forward_chunk: int, backward_chunk: int
+ ) -> Tuple[int, int]:
remaining, done = await make_deferred_yieldable(
defer.gatherResults(
[
@@ -877,14 +944,17 @@ class Porter(object):
return done, remaining + done
async def _setup_state_group_id_seq(self) -> None:
- curr_id = await self.sqlite_store.db_pool.simple_select_one_onecol(
+ curr_id: Optional[
+ int
+ ] = await self.sqlite_store.db_pool.simple_select_one_onecol(
table="state_groups", keyvalues={}, retcol="MAX(id)", allow_none=True
)
if not curr_id:
return
- def r(txn):
+ def r(txn: LoggingTransaction) -> None:
+ assert curr_id is not None
next_id = curr_id + 1
txn.execute("ALTER SEQUENCE state_group_id_seq RESTART WITH %s", (next_id,))
@@ -895,7 +965,7 @@ class Porter(object):
"setup_user_id_seq", find_max_generated_user_id_localpart
)
- def r(txn):
+ def r(txn: LoggingTransaction) -> None:
next_id = curr_id + 1
txn.execute("ALTER SEQUENCE user_id_seq RESTART WITH %s", (next_id,))
@@ -917,7 +987,7 @@ class Porter(object):
allow_none=True,
)
- def _setup_events_stream_seqs_set_pos(txn):
+ def _setup_events_stream_seqs_set_pos(txn: LoggingTransaction) -> None:
if curr_forward_id:
txn.execute(
"ALTER SEQUENCE events_stream_seq RESTART WITH %s",
@@ -941,17 +1011,20 @@ class Porter(object):
"""Set a sequence to the correct value."""
current_stream_ids = []
for stream_id_table in stream_id_tables:
- max_stream_id = await self.sqlite_store.db_pool.simple_select_one_onecol(
- table=stream_id_table,
- keyvalues={},
- retcol="COALESCE(MAX(stream_id), 1)",
- allow_none=True,
+ max_stream_id = cast(
+ int,
+ await self.sqlite_store.db_pool.simple_select_one_onecol(
+ table=stream_id_table,
+ keyvalues={},
+ retcol="COALESCE(MAX(stream_id), 1)",
+ allow_none=True,
+ ),
)
current_stream_ids.append(max_stream_id)
next_id = max(current_stream_ids) + 1
- def r(txn):
+ def r(txn: LoggingTransaction) -> None:
sql = "ALTER SEQUENCE %s RESTART WITH" % (sequence_name,)
txn.execute(sql + " %s", (next_id,))
@@ -960,14 +1033,18 @@ class Porter(object):
)
async def _setup_auth_chain_sequence(self) -> None:
- curr_chain_id = await self.sqlite_store.db_pool.simple_select_one_onecol(
+ curr_chain_id: Optional[
+ int
+ ] = await self.sqlite_store.db_pool.simple_select_one_onecol(
table="event_auth_chains",
keyvalues={},
retcol="MAX(chain_id)",
allow_none=True,
)
- def r(txn):
+ def r(txn: LoggingTransaction) -> None:
+ # Presumably there is at least one row in event_auth_chains.
+ assert curr_chain_id is not None
txn.execute(
"ALTER SEQUENCE event_auth_chain_id RESTART WITH %s",
(curr_chain_id + 1,),
@@ -985,15 +1062,22 @@ class Porter(object):
##############################################
-class Progress(object):
+class TableProgress(TypedDict):
+ start: int
+ num_done: int
+ total: int
+ perc: int
+
+
+class Progress:
"""Used to report progress of the port"""
- def __init__(self):
- self.tables = {}
+ def __init__(self) -> None:
+ self.tables: Dict[str, TableProgress] = {}
self.start_time = int(time.time())
- def add_table(self, table, cur, size):
+ def add_table(self, table: str, cur: int, size: int) -> None:
self.tables[table] = {
"start": cur,
"num_done": cur,
@@ -1001,19 +1085,22 @@ class Progress(object):
"perc": int(cur * 100 / size),
}
- def update(self, table, num_done):
+ def update(self, table: str, num_done: int) -> None:
data = self.tables[table]
data["num_done"] = num_done
data["perc"] = int(num_done * 100 / data["total"])
- def done(self):
+ def done(self) -> None:
+ pass
+
+ def set_state(self, state: str) -> None:
pass
class CursesProgress(Progress):
"""Reports progress to a curses window"""
- def __init__(self, stdscr):
+ def __init__(self, stdscr: "curses.window"):
self.stdscr = stdscr
curses.use_default_colors()
@@ -1022,7 +1109,7 @@ class CursesProgress(Progress):
curses.init_pair(1, curses.COLOR_RED, -1)
curses.init_pair(2, curses.COLOR_GREEN, -1)
- self.last_update = 0
+ self.last_update = 0.0
self.finished = False
@@ -1031,7 +1118,7 @@ class CursesProgress(Progress):
super(CursesProgress, self).__init__()
- def update(self, table, num_done):
+ def update(self, table: str, num_done: int) -> None:
super(CursesProgress, self).update(table, num_done)
self.total_processed = 0
@@ -1042,7 +1129,7 @@ class CursesProgress(Progress):
self.render()
- def render(self, force=False):
+ def render(self, force: bool = False) -> None:
now = time.time()
if not force and now - self.last_update < 0.2:
@@ -1081,8 +1168,7 @@ class CursesProgress(Progress):
left_margin = 5
middle_space = 1
- items = self.tables.items()
- items = sorted(items, key=lambda i: (i[1]["perc"], i[0]))
+ items = sorted(self.tables.items(), key=lambda i: (i[1]["perc"], i[0]))
for i, (table, data) in enumerate(items):
if i + 2 >= rows:
@@ -1115,12 +1201,12 @@ class CursesProgress(Progress):
self.stdscr.refresh()
self.last_update = time.time()
- def done(self):
+ def done(self) -> None:
self.finished = True
self.render(True)
self.stdscr.getch()
- def set_state(self, state):
+ def set_state(self, state: str) -> None:
self.stdscr.clear()
self.stdscr.addstr(0, 0, state + "...", curses.A_BOLD)
self.stdscr.refresh()
@@ -1129,7 +1215,7 @@ class CursesProgress(Progress):
class TerminalProgress(Progress):
"""Just prints progress to the terminal"""
- def update(self, table, num_done):
+ def update(self, table: str, num_done: int) -> None:
super(TerminalProgress, self).update(table, num_done)
data = self.tables[table]
@@ -1138,7 +1224,7 @@ class TerminalProgress(Progress):
"%s: %d%% (%d/%d)" % (table, data["perc"], data["num_done"], data["total"])
)
- def set_state(self, state):
+ def set_state(self, state: str) -> None:
print(state + "...")
@@ -1146,7 +1232,7 @@ class TerminalProgress(Progress):
##############################################
-def main():
+def main() -> None:
parser = argparse.ArgumentParser(
description="A script to port an existing synapse SQLite database to"
" a new PostgreSQL database."
@@ -1178,15 +1264,11 @@ def main():
args = parser.parse_args()
- logging_config = {
- "level": logging.DEBUG if args.v else logging.INFO,
- "format": "%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(message)s",
- }
-
- if args.curses:
- logging_config["filename"] = "port-synapse.log"
-
- logging.basicConfig(**logging_config)
+ logging.basicConfig(
+ level=logging.DEBUG if args.v else logging.INFO,
+ format="%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(message)s",
+ filename="port-synapse.log" if args.curses else None,
+ )
sqlite_config = {
"name": "sqlite3",
@@ -1216,7 +1298,8 @@ def main():
config = HomeServerConfig()
config.parse_config_dict(hs_config, "", "")
- def start(stdscr=None):
+ def start(stdscr: Optional["curses.window"] = None) -> None:
+ progress: Progress
if stdscr:
progress = CursesProgress(stdscr)
else:
@@ -1230,7 +1313,7 @@ def main():
)
@defer.inlineCallbacks
- def run():
+ def run() -> Generator["defer.Deferred[Any]", Any, None]:
with LoggingContext("synapse_port_db_run"):
yield defer.ensureDeferred(porter.run())
diff --git a/synapse/_scripts/synctl.py b/synapse/_scripts/synctl.py
index 1ab36949..b4c96ad7 100755
--- a/synapse/_scripts/synctl.py
+++ b/synapse/_scripts/synctl.py
@@ -24,7 +24,7 @@ import signal
import subprocess
import sys
import time
-from typing import Iterable, Optional
+from typing import Iterable, NoReturn, Optional, TextIO
import yaml
@@ -45,7 +45,7 @@ one of the following:
--------------------------------------------------------------------------------"""
-def pid_running(pid):
+def pid_running(pid: int) -> bool:
try:
os.kill(pid, 0)
except OSError as err:
@@ -68,7 +68,7 @@ def pid_running(pid):
return True
-def write(message, colour=NORMAL, stream=sys.stdout):
+def write(message: str, colour: str = NORMAL, stream: TextIO = sys.stdout) -> None:
# Lets check if we're writing to a TTY before colouring
should_colour = False
try:
@@ -84,7 +84,7 @@ def write(message, colour=NORMAL, stream=sys.stdout):
stream.write(colour + message + NORMAL + "\n")
-def abort(message, colour=RED, stream=sys.stderr):
+def abort(message: str, colour: str = RED, stream: TextIO = sys.stderr) -> NoReturn:
write(message, colour, stream)
sys.exit(1)
@@ -166,7 +166,7 @@ Worker = collections.namedtuple(
)
-def main():
+def main() -> None:
parser = argparse.ArgumentParser()
diff --git a/synapse/_scripts/update_synapse_database.py b/synapse/_scripts/update_synapse_database.py
index f43676af..c443522c 100755
--- a/synapse/_scripts/update_synapse_database.py
+++ b/synapse/_scripts/update_synapse_database.py
@@ -16,42 +16,47 @@
import argparse
import logging
import sys
+from typing import cast
import yaml
from matrix_common.versionstring import get_distribution_version_string
-from twisted.internet import defer, reactor
+from twisted.internet import defer, reactor as reactor_
from synapse.config.homeserver import HomeServerConfig
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.server import HomeServer
from synapse.storage import DataStore
+from synapse.types import ISynapseReactor
+# Cast safety: Twisted does some naughty magic which replaces the
+# twisted.internet.reactor module with a Reactor instance at runtime.
+reactor = cast(ISynapseReactor, reactor_)
logger = logging.getLogger("update_database")
class MockHomeserver(HomeServer):
- DATASTORE_CLASS = DataStore
+ DATASTORE_CLASS = DataStore # type: ignore [assignment]
- def __init__(self, config, **kwargs):
+ def __init__(self, config: HomeServerConfig):
super(MockHomeserver, self).__init__(
- config.server.server_name, reactor=reactor, config=config, **kwargs
- )
-
- self.version_string = "Synapse/" + get_distribution_version_string(
- "matrix-synapse"
+ hostname=config.server.server_name,
+ config=config,
+ reactor=reactor,
+ version_string="Synapse/"
+ + get_distribution_version_string("matrix-synapse"),
)
-def run_background_updates(hs):
+def run_background_updates(hs: HomeServer) -> None:
store = hs.get_datastores().main
- async def run_background_updates():
+ async def run_background_updates() -> None:
await store.db_pool.updates.run_background_updates(sleep=False)
# Stop the reactor to exit the script once every background update is run.
reactor.stop()
- def run():
+ def run() -> None:
# Apply all background updates on the database.
defer.ensureDeferred(
run_as_background_process("background_updates", run_background_updates)
@@ -62,7 +67,7 @@ def run_background_updates(hs):
reactor.run()
-def main():
+def main() -> None:
parser = argparse.ArgumentParser(
description=(
"Updates a synapse database to the latest schema and optionally runs background updates"
@@ -85,12 +90,10 @@ def main():
args = parser.parse_args()
- logging_config = {
- "level": logging.DEBUG if args.v else logging.INFO,
- "format": "%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(message)s",
- }
-
- logging.basicConfig(**logging_config)
+ logging.basicConfig(
+ level=logging.DEBUG if args.v else logging.INFO,
+ format="%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(message)s",
+ )
# Load, process and sanity-check the config.
hs_config = yaml.safe_load(args.database_config)
diff --git a/synapse/app/_base.py b/synapse/app/_base.py
index 3e59805b..37321f91 100644
--- a/synapse/app/_base.py
+++ b/synapse/app/_base.py
@@ -130,7 +130,7 @@ def start_reactor(
appname: str,
soft_file_limit: int,
gc_thresholds: Optional[Tuple[int, int, int]],
- pid_file: str,
+ pid_file: Optional[str],
daemonize: bool,
print_pidfile: bool,
logger: logging.Logger,
@@ -171,6 +171,8 @@ def start_reactor(
# appearing to go backwards.
with PreserveLoggingContext():
if daemonize:
+ assert pid_file is not None
+
if print_pidfile:
print(pid_file)
diff --git a/synapse/app/admin_cmd.py b/synapse/app/admin_cmd.py
index 6f8e33a1..2b0d92cb 100644
--- a/synapse/app/admin_cmd.py
+++ b/synapse/app/admin_cmd.py
@@ -33,7 +33,6 @@ from synapse.handlers.admin import ExfiltrationWriter
from synapse.replication.slave.storage._base import BaseSlavedStore
from synapse.replication.slave.storage.account_data import SlavedAccountDataStore
from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore
-from synapse.replication.slave.storage.client_ips import SlavedClientIpStore
from synapse.replication.slave.storage.deviceinbox import SlavedDeviceInboxStore
from synapse.replication.slave.storage.devices import SlavedDeviceStore
from synapse.replication.slave.storage.events import SlavedEventStore
@@ -61,7 +60,6 @@ class AdminCmdSlavedStore(
SlavedDeviceStore,
SlavedPushRuleStore,
SlavedEventStore,
- SlavedClientIpStore,
BaseSlavedStore,
RoomWorkerStore,
):
diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py
index b6f510ed..1865c671 100644
--- a/synapse/app/generic_worker.py
+++ b/synapse/app/generic_worker.py
@@ -53,7 +53,6 @@ from synapse.replication.http import REPLICATION_PREFIX, ReplicationRestResource
from synapse.replication.slave.storage._base import BaseSlavedStore
from synapse.replication.slave.storage.account_data import SlavedAccountDataStore
from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore
-from synapse.replication.slave.storage.client_ips import SlavedClientIpStore
from synapse.replication.slave.storage.deviceinbox import SlavedDeviceInboxStore
from synapse.replication.slave.storage.devices import SlavedDeviceStore
from synapse.replication.slave.storage.directory import DirectoryStore
@@ -247,7 +246,6 @@ class GenericWorkerSlavedStore(
SlavedApplicationServiceStore,
SlavedRegistrationStore,
SlavedProfileStore,
- SlavedClientIpStore,
SlavedFilteringStore,
MonthlyActiveUsersWorkerStore,
MediaRepositoryStore,
diff --git a/synapse/appservice/__init__.py b/synapse/appservice/__init__.py
index 07ec95f1..d23d9221 100644
--- a/synapse/appservice/__init__.py
+++ b/synapse/appservice/__init__.py
@@ -1,4 +1,5 @@
# Copyright 2015, 2016 OpenMarket Ltd
+# Copyright 2022 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -22,7 +23,13 @@ from netaddr import IPSet
from synapse.api.constants import EventTypes
from synapse.events import EventBase
-from synapse.types import GroupID, JsonDict, UserID, get_domain_from_id
+from synapse.types import (
+ DeviceListUpdates,
+ GroupID,
+ JsonDict,
+ UserID,
+ get_domain_from_id,
+)
from synapse.util.caches.descriptors import _CacheContext, cached
if TYPE_CHECKING:
@@ -400,6 +407,7 @@ class AppServiceTransaction:
to_device_messages: List[JsonDict],
one_time_key_counts: TransactionOneTimeKeyCounts,
unused_fallback_keys: TransactionUnusedFallbackKeys,
+ device_list_summary: DeviceListUpdates,
):
self.service = service
self.id = id
@@ -408,6 +416,7 @@ class AppServiceTransaction:
self.to_device_messages = to_device_messages
self.one_time_key_counts = one_time_key_counts
self.unused_fallback_keys = unused_fallback_keys
+ self.device_list_summary = device_list_summary
async def send(self, as_api: "ApplicationServiceApi") -> bool:
"""Sends this transaction using the provided AS API interface.
@@ -424,6 +433,7 @@ class AppServiceTransaction:
to_device_messages=self.to_device_messages,
one_time_key_counts=self.one_time_key_counts,
unused_fallback_keys=self.unused_fallback_keys,
+ device_list_summary=self.device_list_summary,
txn_id=self.id,
)
diff --git a/synapse/appservice/api.py b/synapse/appservice/api.py
index 98fe3540..0cdbb04b 100644
--- a/synapse/appservice/api.py
+++ b/synapse/appservice/api.py
@@ -1,4 +1,5 @@
# Copyright 2015, 2016 OpenMarket Ltd
+# Copyright 2022 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -27,7 +28,7 @@ from synapse.appservice import (
from synapse.events import EventBase
from synapse.events.utils import SerializeEventConfig, serialize_event
from synapse.http.client import SimpleHttpClient
-from synapse.types import JsonDict, ThirdPartyInstanceID
+from synapse.types import DeviceListUpdates, JsonDict, ThirdPartyInstanceID
from synapse.util.caches.response_cache import ResponseCache
if TYPE_CHECKING:
@@ -225,6 +226,7 @@ class ApplicationServiceApi(SimpleHttpClient):
to_device_messages: List[JsonDict],
one_time_key_counts: TransactionOneTimeKeyCounts,
unused_fallback_keys: TransactionUnusedFallbackKeys,
+ device_list_summary: DeviceListUpdates,
txn_id: Optional[int] = None,
) -> bool:
"""
@@ -268,6 +270,7 @@ class ApplicationServiceApi(SimpleHttpClient):
}
)
+ # TODO: Update to stable prefixes once MSC3202 completes FCP merge
if service.msc3202_transaction_extensions:
if one_time_key_counts:
body[
@@ -277,6 +280,11 @@ class ApplicationServiceApi(SimpleHttpClient):
body[
"org.matrix.msc3202.device_unused_fallback_keys"
] = unused_fallback_keys
+ if device_list_summary:
+ body["org.matrix.msc3202.device_lists"] = {
+ "changed": list(device_list_summary.changed),
+ "left": list(device_list_summary.left),
+ }
try:
await self.put_json(
diff --git a/synapse/appservice/scheduler.py b/synapse/appservice/scheduler.py
index a6084b9c..3b49e607 100644
--- a/synapse/appservice/scheduler.py
+++ b/synapse/appservice/scheduler.py
@@ -72,7 +72,7 @@ from synapse.events import EventBase
from synapse.logging.context import run_in_background
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage.databases.main import DataStore
-from synapse.types import JsonDict
+from synapse.types import DeviceListUpdates, JsonDict
from synapse.util import Clock
if TYPE_CHECKING:
@@ -122,6 +122,7 @@ class ApplicationServiceScheduler:
events: Optional[Collection[EventBase]] = None,
ephemeral: Optional[Collection[JsonDict]] = None,
to_device_messages: Optional[Collection[JsonDict]] = None,
+ device_list_summary: Optional[DeviceListUpdates] = None,
) -> None:
"""
Enqueue some data to be sent off to an application service.
@@ -133,10 +134,18 @@ class ApplicationServiceScheduler:
to_device_messages: The to-device messages to send. These differ from normal
to-device messages sent to clients, as they have 'to_device_id' and
'to_user_id' fields.
+ device_list_summary: A summary of users that the application service either needs
+ to refresh the device lists of, or those that the application service need no
+ longer track the device lists of.
"""
# We purposefully allow this method to run with empty events/ephemeral
# collections, so that callers do not need to check iterable size themselves.
- if not events and not ephemeral and not to_device_messages:
+ if (
+ not events
+ and not ephemeral
+ and not to_device_messages
+ and not device_list_summary
+ ):
return
if events:
@@ -147,6 +156,10 @@ class ApplicationServiceScheduler:
self.queuer.queued_to_device_messages.setdefault(appservice.id, []).extend(
to_device_messages
)
+ if device_list_summary:
+ self.queuer.queued_device_list_summaries.setdefault(
+ appservice.id, []
+ ).append(device_list_summary)
# Kick off a new application service transaction
self.queuer.start_background_request(appservice)
@@ -169,6 +182,8 @@ class _ServiceQueuer:
self.queued_ephemeral: Dict[str, List[JsonDict]] = {}
# dict of {service_id: [to_device_message_json]}
self.queued_to_device_messages: Dict[str, List[JsonDict]] = {}
+ # dict of {service_id: [device_list_summary]}
+ self.queued_device_list_summaries: Dict[str, List[DeviceListUpdates]] = {}
# the appservices which currently have a transaction in flight
self.requests_in_flight: Set[str] = set()
@@ -212,7 +227,35 @@ class _ServiceQueuer:
]
del all_to_device_messages[:MAX_TO_DEVICE_MESSAGES_PER_TRANSACTION]
- if not events and not ephemeral and not to_device_messages_to_send:
+ # Consolidate any pending device list summaries into a single, up-to-date
+ # summary.
+ # Note: this code assumes that in a single DeviceListUpdates, a user will
+ # never be in both "changed" and "left" sets.
+ device_list_summary = DeviceListUpdates()
+ for summary in self.queued_device_list_summaries.get(service.id, []):
+ # For every user in the incoming "changed" set:
+ # * Remove them from the existing "left" set if necessary
+ # (as we need to start tracking them again)
+ # * Add them to the existing "changed" set if necessary.
+ device_list_summary.left.difference_update(summary.changed)
+ device_list_summary.changed.update(summary.changed)
+
+ # For every user in the incoming "left" set:
+ # * Remove them from the existing "changed" set if necessary
+ # (we no longer need to track them)
+ # * Add them to the existing "left" set if necessary.
+ device_list_summary.changed.difference_update(summary.left)
+ device_list_summary.left.update(summary.left)
+ self.queued_device_list_summaries.clear()
+
+ if (
+ not events
+ and not ephemeral
+ and not to_device_messages_to_send
+ # DeviceListUpdates is True if either the 'changed' or 'left' sets have
+ # at least one entry, otherwise False
+ and not device_list_summary
+ ):
return
one_time_key_counts: Optional[TransactionOneTimeKeyCounts] = None
@@ -240,6 +283,7 @@ class _ServiceQueuer:
to_device_messages_to_send,
one_time_key_counts,
unused_fallback_keys,
+ device_list_summary,
)
except Exception:
logger.exception("AS request failed")
@@ -322,6 +366,7 @@ class _TransactionController:
to_device_messages: Optional[List[JsonDict]] = None,
one_time_key_counts: Optional[TransactionOneTimeKeyCounts] = None,
unused_fallback_keys: Optional[TransactionUnusedFallbackKeys] = None,
+ device_list_summary: Optional[DeviceListUpdates] = None,
) -> None:
"""
Create a transaction with the given data and send to the provided
@@ -336,6 +381,7 @@ class _TransactionController:
appservice devices in the transaction.
unused_fallback_keys: Lists of unused fallback keys for relevant
appservice devices in the transaction.
+ device_list_summary: The device list summary to include in the transaction.
"""
try:
txn = await self.store.create_appservice_txn(
@@ -345,6 +391,7 @@ class _TransactionController:
to_device_messages=to_device_messages or [],
one_time_key_counts=one_time_key_counts or {},
unused_fallback_keys=unused_fallback_keys or {},
+ device_list_summary=device_list_summary or DeviceListUpdates(),
)
service_is_up = await self._is_service_up(service)
if service_is_up:
diff --git a/synapse/config/_base.py b/synapse/config/_base.py
index 8e19e2fc..179aa7ff 100644
--- a/synapse/config/_base.py
+++ b/synapse/config/_base.py
@@ -702,10 +702,7 @@ class RootConfig:
return obj
def parse_config_dict(
- self,
- config_dict: Dict[str, Any],
- config_dir_path: Optional[str] = None,
- data_dir_path: Optional[str] = None,
+ self, config_dict: Dict[str, Any], config_dir_path: str, data_dir_path: str
) -> None:
"""Read the information from the config dict into this Config object.
diff --git a/synapse/config/_base.pyi b/synapse/config/_base.pyi
index 363d8b45..bd092f95 100644
--- a/synapse/config/_base.pyi
+++ b/synapse/config/_base.pyi
@@ -124,10 +124,7 @@ class RootConfig:
@classmethod
def invoke_all_static(cls, func_name: str, *args: Any, **kwargs: Any) -> None: ...
def parse_config_dict(
- self,
- config_dict: Dict[str, Any],
- config_dir_path: Optional[str] = ...,
- data_dir_path: Optional[str] = ...,
+ self, config_dict: Dict[str, Any], config_dir_path: str, data_dir_path: str
) -> None: ...
def generate_config(
self,
diff --git a/synapse/config/account_validity.py b/synapse/config/account_validity.py
index c533452c..d1335e77 100644
--- a/synapse/config/account_validity.py
+++ b/synapse/config/account_validity.py
@@ -12,8 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
+from typing import Any
from synapse.config._base import Config, ConfigError
+from synapse.types import JsonDict
logger = logging.getLogger(__name__)
@@ -29,7 +31,7 @@ https://matrix-org.github.io/synapse/latest/templates.html
class AccountValidityConfig(Config):
section = "account_validity"
- def read_config(self, config, **kwargs):
+ def read_config(self, config: JsonDict, **kwargs: Any) -> None:
"""Parses the old account validity config. The config format looks like this:
account_validity:
diff --git a/synapse/config/api.py b/synapse/config/api.py
index 8133b6b6..2cc63053 100644
--- a/synapse/config/api.py
+++ b/synapse/config/api.py
@@ -13,7 +13,7 @@
# limitations under the License.
import logging
-from typing import Iterable
+from typing import Any, Iterable
from synapse.api.constants import EventTypes
from synapse.config._base import Config, ConfigError
@@ -26,12 +26,12 @@ logger = logging.getLogger(__name__)
class ApiConfig(Config):
section = "api"
- def read_config(self, config: JsonDict, **kwargs):
+ def read_config(self, config: JsonDict, **kwargs: Any) -> None:
validate_config(_MAIN_SCHEMA, config, ())
self.room_prejoin_state = list(self._get_prejoin_state_types(config))
self.track_puppeted_user_ips = config.get("track_puppeted_user_ips", False)
- def generate_config_section(cls, **kwargs) -> str:
+ def generate_config_section(cls, **kwargs: Any) -> str:
formatted_default_state_types = "\n".join(
" # - %s" % (t,) for t in _DEFAULT_PREJOIN_STATE_TYPES
)
diff --git a/synapse/config/appservice.py b/synapse/config/appservice.py
index 439bfe15..720b90a2 100644
--- a/synapse/config/appservice.py
+++ b/synapse/config/appservice.py
@@ -14,7 +14,7 @@
# limitations under the License.
import logging
-from typing import Dict, List
+from typing import Any, Dict, List
from urllib import parse as urlparse
import yaml
@@ -31,12 +31,12 @@ logger = logging.getLogger(__name__)
class AppServiceConfig(Config):
section = "appservice"
- def read_config(self, config, **kwargs) -> None:
+ def read_config(self, config: JsonDict, **kwargs: Any) -> None:
self.app_service_config_files = config.get("app_service_config_files", [])
self.notify_appservices = config.get("notify_appservices", True)
self.track_appservice_user_ips = config.get("track_appservice_user_ips", False)
- def generate_config_section(cls, **kwargs) -> str:
+ def generate_config_section(cls, **kwargs: Any) -> str:
return """\
# A list of application service config files to use
#
@@ -170,6 +170,7 @@ def _load_appservice(
# When enabled, appservice transactions contain the following information:
# - device One-Time Key counts
# - device unused fallback key usage states
+ # - device list changes
msc3202_transaction_extensions = as_info.get("org.matrix.msc3202", False)
if not isinstance(msc3202_transaction_extensions, bool):
raise ValueError(
diff --git a/synapse/config/auth.py b/synapse/config/auth.py
index ba8bf9cb..bb417a23 100644
--- a/synapse/config/auth.py
+++ b/synapse/config/auth.py
@@ -12,6 +12,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+from typing import Any
+
+from synapse.types import JsonDict
from ._base import Config
@@ -21,7 +24,7 @@ class AuthConfig(Config):
section = "auth"
- def read_config(self, config, **kwargs):
+ def read_config(self, config: JsonDict, **kwargs: Any) -> None:
password_config = config.get("password_config", {})
if password_config is None:
password_config = {}
@@ -40,7 +43,7 @@ class AuthConfig(Config):
ui_auth.get("session_timeout", 0)
)
- def generate_config_section(self, config_dir_path, server_name, **kwargs):
+ def generate_config_section(self, **kwargs: Any) -> str:
return """\
password_config:
# Uncomment to disable password login
diff --git a/synapse/config/background_updates.py b/synapse/config/background_updates.py
index f6cdeacc..07fadbe0 100644
--- a/synapse/config/background_updates.py
+++ b/synapse/config/background_updates.py
@@ -11,6 +11,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+from typing import Any
+
+from synapse.types import JsonDict
from ._base import Config
@@ -18,7 +21,7 @@ from ._base import Config
class BackgroundUpdateConfig(Config):
section = "background_updates"
- def generate_config_section(self, **kwargs) -> str:
+ def generate_config_section(self, **kwargs: Any) -> str:
return """\
## Background Updates ##
@@ -52,7 +55,7 @@ class BackgroundUpdateConfig(Config):
#default_batch_size: 50
"""
- def read_config(self, config, **kwargs) -> None:
+ def read_config(self, config: JsonDict, **kwargs: Any) -> None:
bg_update_config = config.get("background_updates") or {}
self.update_duration_ms = bg_update_config.get(
diff --git a/synapse/config/cache.py b/synapse/config/cache.py
index 9a68da9c..94d852f4 100644
--- a/synapse/config/cache.py
+++ b/synapse/config/cache.py
@@ -16,10 +16,11 @@ import logging
import os
import re
import threading
-from typing import Callable, Dict, Optional
+from typing import Any, Callable, Dict, Optional
import attr
+from synapse.types import JsonDict
from synapse.util.check_dependencies import DependencyException, check_requirements
from ._base import Config, ConfigError
@@ -105,7 +106,7 @@ class CacheConfig(Config):
with _CACHES_LOCK:
_CACHES.clear()
- def generate_config_section(self, **kwargs) -> str:
+ def generate_config_section(self, **kwargs: Any) -> str:
return """\
## Caching ##
@@ -172,7 +173,7 @@ class CacheConfig(Config):
#sync_response_cache_duration: 2m
"""
- def read_config(self, config, **kwargs) -> None:
+ def read_config(self, config: JsonDict, **kwargs: Any) -> None:
self.event_cache_size = self.parse_size(
config.get("event_cache_size", _DEFAULT_EVENT_CACHE_SIZE)
)
diff --git a/synapse/config/captcha.py b/synapse/config/captcha.py
index 9e48f865..92c603f2 100644
--- a/synapse/config/captcha.py
+++ b/synapse/config/captcha.py
@@ -12,15 +12,31 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from ._base import Config
+from typing import Any
+
+from synapse.types import JsonDict
+
+from ._base import Config, ConfigError
class CaptchaConfig(Config):
section = "captcha"
- def read_config(self, config, **kwargs):
- self.recaptcha_private_key = config.get("recaptcha_private_key")
- self.recaptcha_public_key = config.get("recaptcha_public_key")
+ def read_config(self, config: JsonDict, **kwargs: Any) -> None:
+ recaptcha_private_key = config.get("recaptcha_private_key")
+ if recaptcha_private_key is not None and not isinstance(
+ recaptcha_private_key, str
+ ):
+ raise ConfigError("recaptcha_private_key must be a string.")
+ self.recaptcha_private_key = recaptcha_private_key
+
+ recaptcha_public_key = config.get("recaptcha_public_key")
+ if recaptcha_public_key is not None and not isinstance(
+ recaptcha_public_key, str
+ ):
+ raise ConfigError("recaptcha_public_key must be a string.")
+ self.recaptcha_public_key = recaptcha_public_key
+
self.enable_registration_captcha = config.get(
"enable_registration_captcha", False
)
@@ -30,7 +46,7 @@ class CaptchaConfig(Config):
)
self.recaptcha_template = self.read_template("recaptcha.html")
- def generate_config_section(self, **kwargs):
+ def generate_config_section(self, **kwargs: Any) -> str:
return """\
## Captcha ##
# See docs/CAPTCHA_SETUP.md for full details of configuring this.
diff --git a/synapse/config/cas.py b/synapse/config/cas.py
index 6f275409..8af0794b 100644
--- a/synapse/config/cas.py
+++ b/synapse/config/cas.py
@@ -16,6 +16,7 @@
from typing import Any, List
from synapse.config.sso import SsoAttributeRequirement
+from synapse.types import JsonDict
from ._base import Config
from ._util import validate_config
@@ -29,7 +30,7 @@ class CasConfig(Config):
section = "cas"
- def read_config(self, config, **kwargs) -> None:
+ def read_config(self, config: JsonDict, **kwargs: Any) -> None:
cas_config = config.get("cas_config", None)
self.cas_enabled = cas_config and cas_config.get("enabled", True)
@@ -52,7 +53,7 @@ class CasConfig(Config):
self.cas_displayname_attribute = None
self.cas_required_attributes = []
- def generate_config_section(self, config_dir_path, server_name, **kwargs) -> str:
+ def generate_config_section(self, **kwargs: Any) -> str:
return """\
# Enable Central Authentication Service (CAS) for registration and login.
#
diff --git a/synapse/config/consent.py b/synapse/config/consent.py
index ecc43b08..8ee3d345 100644
--- a/synapse/config/consent.py
+++ b/synapse/config/consent.py
@@ -13,9 +13,10 @@
# limitations under the License.
from os import path
-from typing import Optional
+from typing import Any, Optional
from synapse.config import ConfigError
+from synapse.types import JsonDict
from ._base import Config
@@ -76,18 +77,18 @@ class ConsentConfig(Config):
section = "consent"
- def __init__(self, *args):
+ def __init__(self, *args: Any):
super().__init__(*args)
self.user_consent_version: Optional[str] = None
self.user_consent_template_dir: Optional[str] = None
- self.user_consent_server_notice_content = None
+ self.user_consent_server_notice_content: Optional[JsonDict] = None
self.user_consent_server_notice_to_guests = False
- self.block_events_without_consent_error = None
+ self.block_events_without_consent_error: Optional[str] = None
self.user_consent_at_registration = False
self.user_consent_policy_name = "Privacy Policy"
- def read_config(self, config, **kwargs):
+ def read_config(self, config: JsonDict, **kwargs: Any) -> None:
consent_config = config.get("user_consent")
self.terms_template = self.read_template("terms.html")
@@ -118,5 +119,5 @@ class ConsentConfig(Config):
"policy_name", "Privacy Policy"
)
- def generate_config_section(self, **kwargs):
+ def generate_config_section(self, **kwargs: Any) -> str:
return DEFAULT_CONFIG
diff --git a/synapse/config/database.py b/synapse/config/database.py
index d7f2219f..de0d3ca0 100644
--- a/synapse/config/database.py
+++ b/synapse/config/database.py
@@ -15,8 +15,10 @@
import argparse
import logging
import os
+from typing import Any, List
from synapse.config._base import Config, ConfigError
+from synapse.types import JsonDict
logger = logging.getLogger(__name__)
@@ -121,12 +123,12 @@ class DatabaseConnectionConfig:
class DatabaseConfig(Config):
section = "database"
- def __init__(self, *args, **kwargs):
- super().__init__(*args, **kwargs)
+ def __init__(self, *args: Any):
+ super().__init__(*args)
- self.databases = []
+ self.databases: List[DatabaseConnectionConfig] = []
- def read_config(self, config, **kwargs) -> None:
+ def read_config(self, config: JsonDict, **kwargs: Any) -> None:
# We *experimentally* support specifying multiple databases via the
# `databases` key. This is a map from a label to database config in the
# same format as the `database` config option, plus an extra
@@ -170,7 +172,7 @@ class DatabaseConfig(Config):
self.databases = [DatabaseConnectionConfig("master", database_config)]
self.set_databasepath(database_path)
- def generate_config_section(self, data_dir_path, **kwargs) -> str:
+ def generate_config_section(self, data_dir_path: str, **kwargs: Any) -> str:
return DEFAULT_CONFIG % {
"database_path": os.path.join(data_dir_path, "homeserver.db")
}
diff --git a/synapse/config/emailconfig.py b/synapse/config/emailconfig.py
index 949d7dd5..5b5c2f4f 100644
--- a/synapse/config/emailconfig.py
+++ b/synapse/config/emailconfig.py
@@ -19,9 +19,12 @@ import email.utils
import logging
import os
from enum import Enum
+from typing import Any
import attr
+from synapse.types import JsonDict
+
from ._base import Config, ConfigError
logger = logging.getLogger(__name__)
@@ -73,7 +76,7 @@ class EmailSubjectConfig:
class EmailConfig(Config):
section = "email"
- def read_config(self, config, **kwargs):
+ def read_config(self, config: JsonDict, **kwargs: Any) -> None:
# TODO: We should separate better the email configuration from the notification
# and account validity config.
@@ -354,7 +357,7 @@ class EmailConfig(Config):
path=("email", "invite_client_location"),
)
- def generate_config_section(self, config_dir_path, server_name, **kwargs):
+ def generate_config_section(self, **kwargs: Any) -> str:
return (
"""\
# Configuration for sending emails from Synapse.
diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py
index 064db448..447476fb 100644
--- a/synapse/config/experimental.py
+++ b/synapse/config/experimental.py
@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from typing import Any
+
from synapse.config._base import Config
from synapse.types import JsonDict
@@ -21,13 +23,11 @@ class ExperimentalConfig(Config):
section = "experimental"
- def read_config(self, config: JsonDict, **kwargs):
+ def read_config(self, config: JsonDict, **kwargs: Any) -> None:
experimental = config.get("experimental_features") or {}
# MSC3440 (thread relation)
self.msc3440_enabled: bool = experimental.get("msc3440_enabled", False)
- # MSC3666: including bundled relations in /search.
- self.msc3666_enabled: bool = experimental.get("msc3666_enabled", False)
# MSC3026 (busy presence state)
self.msc3026_enabled: bool = experimental.get("msc3026_enabled", False)
@@ -59,8 +59,9 @@ class ExperimentalConfig(Config):
"msc3202_device_masquerading", False
)
- # Portion of MSC3202 related to transaction extensions:
- # sending one-time key counts and fallback key usage to application services.
+ # The portion of MSC3202 related to transaction extensions:
+ # sending device list changes, one-time key counts and fallback key
+ # usage to application services.
self.msc3202_transaction_extensions: bool = experimental.get(
"msc3202_transaction_extensions", False
)
@@ -77,3 +78,6 @@ class ExperimentalConfig(Config):
# The deprecated groups feature.
self.groups_enabled: bool = experimental.get("groups_enabled", True)
+
+ # MSC2654: Unread counts
+ self.msc2654_enabled: bool = experimental.get("msc2654_enabled", False)
diff --git a/synapse/config/federation.py b/synapse/config/federation.py
index 7d64993e..0e74f707 100644
--- a/synapse/config/federation.py
+++ b/synapse/config/federation.py
@@ -11,16 +11,17 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Optional
+from typing import Any, Optional
from synapse.config._base import Config
from synapse.config._util import validate_config
+from synapse.types import JsonDict
class FederationConfig(Config):
section = "federation"
- def read_config(self, config, **kwargs):
+ def read_config(self, config: JsonDict, **kwargs: Any) -> None:
# FIXME: federation_domain_whitelist needs sytests
self.federation_domain_whitelist: Optional[dict] = None
federation_domain_whitelist = config.get("federation_domain_whitelist", None)
@@ -48,7 +49,7 @@ class FederationConfig(Config):
"allow_device_name_lookup_over_federation", True
)
- def generate_config_section(self, config_dir_path, server_name, **kwargs):
+ def generate_config_section(self, **kwargs: Any) -> str:
return """\
## Federation ##
diff --git a/synapse/config/groups.py b/synapse/config/groups.py
index 15c2e64b..c9b9c6da 100644
--- a/synapse/config/groups.py
+++ b/synapse/config/groups.py
@@ -12,17 +12,21 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from typing import Any
+
+from synapse.types import JsonDict
+
from ._base import Config
class GroupsConfig(Config):
section = "groups"
- def read_config(self, config, **kwargs):
+ def read_config(self, config: JsonDict, **kwargs: Any) -> None:
self.enable_group_creation = config.get("enable_group_creation", False)
self.group_creation_prefix = config.get("group_creation_prefix", "")
- def generate_config_section(self, **kwargs):
+ def generate_config_section(self, **kwargs: Any) -> str:
return """\
# Uncomment to allow non-server-admin users to create groups on this server
#
diff --git a/synapse/config/jwt.py b/synapse/config/jwt.py
index 24c3ef01..2a756d1a 100644
--- a/synapse/config/jwt.py
+++ b/synapse/config/jwt.py
@@ -12,6 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from typing import Any
+
+from synapse.types import JsonDict
+
from ._base import Config, ConfigError
MISSING_JWT = """Missing jwt library. This is required for jwt login.
@@ -24,7 +28,7 @@ MISSING_JWT = """Missing jwt library. This is required for jwt login.
class JWTConfig(Config):
section = "jwt"
- def read_config(self, config, **kwargs):
+ def read_config(self, config: JsonDict, **kwargs: Any) -> None:
jwt_config = config.get("jwt_config", None)
if jwt_config:
self.jwt_enabled = jwt_config.get("enabled", False)
@@ -52,7 +56,7 @@ class JWTConfig(Config):
self.jwt_issuer = None
self.jwt_audiences = None
- def generate_config_section(self, **kwargs):
+ def generate_config_section(self, **kwargs: Any) -> str:
return """\
# JSON web token integration. The following settings can be used to make
# Synapse JSON web tokens for authentication, instead of its internal
diff --git a/synapse/config/key.py b/synapse/config/key.py
index ee83c6c0..ada65f6d 100644
--- a/synapse/config/key.py
+++ b/synapse/config/key.py
@@ -16,7 +16,7 @@
import hashlib
import logging
import os
-from typing import Any, Dict, Iterator, List, Optional
+from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Optional
import attr
import jsonschema
@@ -38,6 +38,9 @@ from synapse.util.stringutils import random_string, random_string_with_symbols
from ._base import Config, ConfigError
+if TYPE_CHECKING:
+ from signedjson.key import VerifyKeyWithExpiry
+
INSECURE_NOTARY_ERROR = """\
Your server is configured to accept key server responses without signature
validation or TLS certificate validation. This is likely to be very insecure. If
@@ -96,11 +99,14 @@ class TrustedKeyServer:
class KeyConfig(Config):
section = "key"
- def read_config(self, config, config_dir_path, **kwargs):
+ def read_config(
+ self, config: JsonDict, config_dir_path: str, **kwargs: Any
+ ) -> None:
# the signing key can be specified inline or in a separate file
if "signing_key" in config:
self.signing_key = read_signing_keys([config["signing_key"]])
else:
+ assert config_dir_path is not None
signing_key_path = config.get("signing_key_path")
if signing_key_path is None:
signing_key_path = os.path.join(
@@ -169,8 +175,12 @@ class KeyConfig(Config):
self.form_secret = config.get("form_secret", None)
def generate_config_section(
- self, config_dir_path, server_name, generate_secrets=False, **kwargs
- ):
+ self,
+ config_dir_path: str,
+ server_name: str,
+ generate_secrets: bool = False,
+ **kwargs: Any,
+ ) -> str:
base_key_name = os.path.join(config_dir_path, server_name)
if generate_secrets:
@@ -300,7 +310,7 @@ class KeyConfig(Config):
def read_old_signing_keys(
self, old_signing_keys: Optional[JsonDict]
- ) -> Dict[str, VerifyKey]:
+ ) -> Dict[str, "VerifyKeyWithExpiry"]:
if old_signing_keys is None:
return {}
keys = {}
@@ -308,8 +318,8 @@ class KeyConfig(Config):
if is_signing_algorithm_supported(key_id):
key_base64 = key_data["key"]
key_bytes = decode_base64(key_base64)
- verify_key = decode_verify_key_bytes(key_id, key_bytes)
- verify_key.expired_ts = key_data["expired_ts"]
+ verify_key: "VerifyKeyWithExpiry" = decode_verify_key_bytes(key_id, key_bytes) # type: ignore[assignment]
+ verify_key.expired = key_data["expired_ts"]
keys[key_id] = verify_key
else:
raise ConfigError(
@@ -422,7 +432,7 @@ def _parse_key_servers(
server_name = server["server_name"]
result = TrustedKeyServer(server_name=server_name)
- verify_keys = server.get("verify_keys")
+ verify_keys: Optional[Dict[str, str]] = server.get("verify_keys")
if verify_keys is not None:
result.verify_keys = {}
for key_id, key_base64 in verify_keys.items():
diff --git a/synapse/config/logger.py b/synapse/config/logger.py
index cbbe2219..99db9e1e 100644
--- a/synapse/config/logger.py
+++ b/synapse/config/logger.py
@@ -35,6 +35,7 @@ from twisted.logger import (
from synapse.logging.context import LoggingContextFilter
from synapse.logging.filter import MetadataFilter
+from synapse.types import JsonDict
from ._base import Config, ConfigError
@@ -147,13 +148,15 @@ https://matrix-org.github.io/synapse/v1.54/structured_logging.html
class LoggingConfig(Config):
section = "logging"
- def read_config(self, config, **kwargs) -> None:
+ def read_config(self, config: JsonDict, **kwargs: Any) -> None:
if config.get("log_file"):
raise ConfigError(LOG_FILE_ERROR)
self.log_config = self.abspath(config.get("log_config"))
self.no_redirect_stdio = config.get("no_redirect_stdio", False)
- def generate_config_section(self, config_dir_path, server_name, **kwargs) -> str:
+ def generate_config_section(
+ self, config_dir_path: str, server_name: str, **kwargs: Any
+ ) -> str:
log_config = os.path.join(config_dir_path, server_name + ".log.config")
return (
"""\
diff --git a/synapse/config/metrics.py b/synapse/config/metrics.py
index f62292ec..aa360a41 100644
--- a/synapse/config/metrics.py
+++ b/synapse/config/metrics.py
@@ -13,8 +13,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from typing import Any, Optional
+
import attr
+from synapse.types import JsonDict
from synapse.util.check_dependencies import DependencyException, check_requirements
from ._base import Config, ConfigError
@@ -37,7 +40,7 @@ class MetricsFlags:
class MetricsConfig(Config):
section = "metrics"
- def read_config(self, config, **kwargs):
+ def read_config(self, config: JsonDict, **kwargs: Any) -> None:
self.enable_metrics = config.get("enable_metrics", False)
self.report_stats = config.get("report_stats", None)
self.report_stats_endpoint = config.get(
@@ -67,7 +70,9 @@ class MetricsConfig(Config):
"sentry.dsn field is required when sentry integration is enabled"
)
- def generate_config_section(self, report_stats=None, **kwargs):
+ def generate_config_section(
+ self, report_stats: Optional[bool] = None, **kwargs: Any
+ ) -> str:
res = """\
## Metrics ###
diff --git a/synapse/config/modules.py b/synapse/config/modules.py
index 2ef02b8f..0915014f 100644
--- a/synapse/config/modules.py
+++ b/synapse/config/modules.py
@@ -14,13 +14,14 @@
from typing import Any, Dict, List, Tuple
from synapse.config._base import Config, ConfigError
+from synapse.types import JsonDict
from synapse.util.module_loader import load_module
class ModulesConfig(Config):
section = "modules"
- def read_config(self, config: dict, **kwargs):
+ def read_config(self, config: JsonDict, **kwargs: Any) -> None:
self.loaded_modules: List[Tuple[Any, Dict]] = []
configured_modules = config.get("modules") or []
@@ -31,7 +32,7 @@ class ModulesConfig(Config):
self.loaded_modules.append(load_module(module, config_path))
- def generate_config_section(self, **kwargs):
+ def generate_config_section(self, **kwargs: Any) -> str:
return """
## Modules ##
diff --git a/synapse/config/oembed.py b/synapse/config/oembed.py
index ea6ace47..690ffb52 100644
--- a/synapse/config/oembed.py
+++ b/synapse/config/oembed.py
@@ -40,7 +40,7 @@ class OembedConfig(Config):
section = "oembed"
- def read_config(self, config, **kwargs):
+ def read_config(self, config: JsonDict, **kwargs: Any) -> None:
oembed_config: Dict[str, Any] = config.get("oembed") or {}
# A list of patterns which will be used.
@@ -143,7 +143,7 @@ class OembedConfig(Config):
)
return re.compile(pattern)
- def generate_config_section(self, **kwargs):
+ def generate_config_section(self, **kwargs: Any) -> str:
return """\
# oEmbed allows for easier embedding content from a website. It can be
# used for generating URLs previews of services which support it.
diff --git a/synapse/config/oidc.py b/synapse/config/oidc.py
index 5d571651..b9c40522 100644
--- a/synapse/config/oidc.py
+++ b/synapse/config/oidc.py
@@ -36,7 +36,7 @@ LEGACY_USER_MAPPING_PROVIDER = "synapse.handlers.oidc_handler.JinjaOidcMappingPr
class OIDCConfig(Config):
section = "oidc"
- def read_config(self, config, **kwargs) -> None:
+ def read_config(self, config: JsonDict, **kwargs: Any) -> None:
self.oidc_providers = tuple(_parse_oidc_provider_configs(config))
if not self.oidc_providers:
return
@@ -66,7 +66,7 @@ class OIDCConfig(Config):
# OIDC is enabled if we have a provider
return bool(self.oidc_providers)
- def generate_config_section(self, config_dir_path, server_name, **kwargs) -> str:
+ def generate_config_section(self, **kwargs: Any) -> str:
return """\
# List of OpenID Connect (OIDC) / OAuth 2.0 identity providers, for registration
# and login.
diff --git a/synapse/config/password_auth_providers.py b/synapse/config/password_auth_providers.py
index f980102b..35df4254 100644
--- a/synapse/config/password_auth_providers.py
+++ b/synapse/config/password_auth_providers.py
@@ -14,6 +14,7 @@
from typing import Any, List, Tuple, Type
+from synapse.types import JsonDict
from synapse.util.module_loader import load_module
from ._base import Config
@@ -24,7 +25,7 @@ LDAP_PROVIDER = "ldap_auth_provider.LdapAuthProvider"
class PasswordAuthProviderConfig(Config):
section = "authproviders"
- def read_config(self, config, **kwargs):
+ def read_config(self, config: JsonDict, **kwargs: Any) -> None:
"""Parses the old password auth providers config. The config format looks like this:
password_providers:
diff --git a/synapse/config/push.py b/synapse/config/push.py
index 6ef8491c..2e796d1c 100644
--- a/synapse/config/push.py
+++ b/synapse/config/push.py
@@ -13,13 +13,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from typing import Any
+
+from synapse.types import JsonDict
+
from ._base import Config
class PushConfig(Config):
section = "push"
- def read_config(self, config, **kwargs):
+ def read_config(self, config: JsonDict, **kwargs: Any) -> None:
push_config = config.get("push") or {}
self.push_include_content = push_config.get("include_content", True)
self.push_group_unread_count_by_room = push_config.get(
@@ -46,7 +50,7 @@ class PushConfig(Config):
)
self.push_include_content = not redact_content
- def generate_config_section(self, config_dir_path, server_name, **kwargs):
+ def generate_config_section(self, **kwargs: Any) -> str:
return """
## Push ##
diff --git a/synapse/config/ratelimiting.py b/synapse/config/ratelimiting.py
index e9ccf1bd..0587f5c1 100644
--- a/synapse/config/ratelimiting.py
+++ b/synapse/config/ratelimiting.py
@@ -12,10 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Dict, Optional
+from typing import Any, Dict, Optional
import attr
+from synapse.types import JsonDict
+
from ._base import Config
@@ -43,7 +45,7 @@ class FederationRateLimitConfig:
class RatelimitConfig(Config):
section = "ratelimiting"
- def read_config(self, config, **kwargs):
+ def read_config(self, config: JsonDict, **kwargs: Any) -> None:
# Load the new-style messages config if it exists. Otherwise fall back
# to the old method.
@@ -142,7 +144,7 @@ class RatelimitConfig(Config):
},
)
- def generate_config_section(self, **kwargs):
+ def generate_config_section(self, **kwargs: Any) -> str:
return """\
## Ratelimiting ##
diff --git a/synapse/config/redis.py b/synapse/config/redis.py
index bdb1aac3..ec7a7354 100644
--- a/synapse/config/redis.py
+++ b/synapse/config/redis.py
@@ -12,14 +12,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from typing import Any
+
from synapse.config._base import Config
+from synapse.types import JsonDict
from synapse.util.check_dependencies import check_requirements
class RedisConfig(Config):
section = "redis"
- def read_config(self, config, **kwargs):
+ def read_config(self, config: JsonDict, **kwargs: Any) -> None:
redis_config = config.get("redis") or {}
self.redis_enabled = redis_config.get("enabled", False)
@@ -32,7 +35,7 @@ class RedisConfig(Config):
self.redis_port = redis_config.get("port", 6379)
self.redis_password = redis_config.get("password")
- def generate_config_section(self, config_dir_path, server_name, **kwargs):
+ def generate_config_section(self, **kwargs: Any) -> str:
return """\
# Configuration for Redis when using workers. This *must* be enabled when
# using workers (unless using old style direct TCP configuration).
diff --git a/synapse/config/registration.py b/synapse/config/registration.py
index 40fb329a..39e9acb6 100644
--- a/synapse/config/registration.py
+++ b/synapse/config/registration.py
@@ -13,18 +13,18 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
-from typing import Optional
+from typing import Any, Optional
from synapse.api.constants import RoomCreationPreset
from synapse.config._base import Config, ConfigError
-from synapse.types import RoomAlias, UserID
+from synapse.types import JsonDict, RoomAlias, UserID
from synapse.util.stringutils import random_string_with_symbols, strtobool
class RegistrationConfig(Config):
section = "registration"
- def read_config(self, config, **kwargs):
+ def read_config(self, config: JsonDict, **kwargs: Any) -> None:
self.enable_registration = strtobool(
str(config.get("enable_registration", False))
)
@@ -196,7 +196,9 @@ class RegistrationConfig(Config):
self.inhibit_user_in_use_error = config.get("inhibit_user_in_use_error", False)
- def generate_config_section(self, generate_secrets=False, **kwargs):
+ def generate_config_section(
+ self, generate_secrets: bool = False, **kwargs: Any
+ ) -> str:
if generate_secrets:
registration_shared_secret = 'registration_shared_secret: "%s"' % (
random_string_with_symbols(50),
diff --git a/synapse/config/repository.py b/synapse/config/repository.py
index 0a0d901b..98d8a166 100644
--- a/synapse/config/repository.py
+++ b/synapse/config/repository.py
@@ -14,7 +14,7 @@
import logging
import os
-from typing import Dict, List, Tuple
+from typing import Any, Dict, List, Tuple
from urllib.request import getproxies_environment # type: ignore
import attr
@@ -94,7 +94,7 @@ def parse_thumbnail_requirements(
class ContentRepositoryConfig(Config):
section = "media"
- def read_config(self, config, **kwargs):
+ def read_config(self, config: JsonDict, **kwargs: Any) -> None:
# Only enable the media repo if either the media repo is enabled or the
# current worker app is the media repo.
@@ -223,7 +223,8 @@ class ContentRepositoryConfig(Config):
"url_preview_accept_language"
) or ["en"]
- def generate_config_section(self, data_dir_path, **kwargs):
+ def generate_config_section(self, data_dir_path: str, **kwargs: Any) -> str:
+ assert data_dir_path is not None
media_store = os.path.join(data_dir_path, "media_store")
formatted_thumbnail_sizes = "".join(
diff --git a/synapse/config/retention.py b/synapse/config/retention.py
index aed9bf45..03b723b8 100644
--- a/synapse/config/retention.py
+++ b/synapse/config/retention.py
@@ -13,11 +13,12 @@
# limitations under the License.
import logging
-from typing import List, Optional
+from typing import Any, List, Optional
import attr
from synapse.config._base import Config, ConfigError
+from synapse.types import JsonDict
logger = logging.getLogger(__name__)
@@ -34,7 +35,7 @@ class RetentionPurgeJob:
class RetentionConfig(Config):
section = "retention"
- def read_config(self, config, **kwargs):
+ def read_config(self, config: JsonDict, **kwargs: Any) -> None:
retention_config = config.get("retention")
if retention_config is None:
retention_config = {}
@@ -153,7 +154,7 @@ class RetentionConfig(Config):
RetentionPurgeJob(self.parse_duration("1d"), None, None)
]
- def generate_config_section(self, config_dir_path, server_name, **kwargs):
+ def generate_config_section(self, **kwargs: Any) -> str:
return """\
# Message retention policy at the server level.
#
diff --git a/synapse/config/room.py b/synapse/config/room.py
index d889d90d..e18a87ea 100644
--- a/synapse/config/room.py
+++ b/synapse/config/room.py
@@ -13,8 +13,10 @@
# limitations under the License.
import logging
+from typing import Any
from synapse.api.constants import RoomCreationPreset
+from synapse.types import JsonDict
from ._base import Config, ConfigError
@@ -32,7 +34,7 @@ class RoomDefaultEncryptionTypes:
class RoomConfig(Config):
section = "room"
- def read_config(self, config, **kwargs):
+ def read_config(self, config: JsonDict, **kwargs: Any) -> None:
# Whether new, locally-created rooms should have encryption enabled
encryption_for_room_type = config.get(
"encryption_enabled_by_default_for_room_type",
@@ -61,7 +63,7 @@ class RoomConfig(Config):
"Invalid value for encryption_enabled_by_default_for_room_type"
)
- def generate_config_section(self, **kwargs):
+ def generate_config_section(self, **kwargs: Any) -> str:
return """\
## Rooms ##
diff --git a/synapse/config/room_directory.py b/synapse/config/room_directory.py
index 3c5e0f7c..717ba70e 100644
--- a/synapse/config/room_directory.py
+++ b/synapse/config/room_directory.py
@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import List
+from typing import Any, List
from matrix_common.regex import glob_to_regex
@@ -25,7 +25,7 @@ from ._base import Config, ConfigError
class RoomDirectoryConfig(Config):
section = "roomdirectory"
- def read_config(self, config, **kwargs) -> None:
+ def read_config(self, config: JsonDict, **kwargs: Any) -> None:
self.enable_room_list_search = config.get("enable_room_list_search", True)
alias_creation_rules = config.get("alias_creation_rules")
@@ -52,7 +52,7 @@ class RoomDirectoryConfig(Config):
_RoomDirectoryRule("room_list_publication_rules", {"action": "allow"})
]
- def generate_config_section(self, config_dir_path, server_name, **kwargs) -> str:
+ def generate_config_section(self, **kwargs: Any) -> str:
return """
# Uncomment to disable searching the public room list. When disabled
# blocks searching local and remote room lists for local and remote
diff --git a/synapse/config/saml2.py b/synapse/config/saml2.py
index 43c456d5..19b2f1b2 100644
--- a/synapse/config/saml2.py
+++ b/synapse/config/saml2.py
@@ -65,7 +65,7 @@ def _dict_merge(merge_dict: dict, into_dict: dict) -> None:
class SAML2Config(Config):
section = "saml2"
- def read_config(self, config, **kwargs) -> None:
+ def read_config(self, config: JsonDict, **kwargs: Any) -> None:
self.saml2_enabled = False
saml2_config = config.get("saml2_config")
@@ -165,13 +165,13 @@ class SAML2Config(Config):
config_path = saml2_config.get("config_path", None)
if config_path is not None:
mod = load_python_module(config_path)
- config = getattr(mod, "CONFIG", None)
- if config is None:
+ config_dict_from_file = getattr(mod, "CONFIG", None)
+ if config_dict_from_file is None:
raise ConfigError(
"Config path specified by saml2_config.config_path does not "
"have a CONFIG property."
)
- _dict_merge(merge_dict=config, into_dict=saml2_config_dict)
+ _dict_merge(merge_dict=config_dict_from_file, into_dict=saml2_config_dict)
import saml2.config
@@ -223,7 +223,7 @@ class SAML2Config(Config):
},
}
- def generate_config_section(self, config_dir_path, server_name, **kwargs) -> str:
+ def generate_config_section(self, config_dir_path: str, **kwargs: Any) -> str:
return """\
## Single sign-on integration ##
diff --git a/synapse/config/server.py b/synapse/config/server.py
index 38de4b80..415279d2 100644
--- a/synapse/config/server.py
+++ b/synapse/config/server.py
@@ -248,7 +248,7 @@ class LimitRemoteRoomsConfig:
class ServerConfig(Config):
section = "server"
- def read_config(self, config, **kwargs):
+ def read_config(self, config: JsonDict, **kwargs: Any) -> None:
self.server_name = config["server_name"]
self.server_context = config.get("server_context", None)
@@ -259,8 +259,8 @@ class ServerConfig(Config):
self.pid_file = self.abspath(config.get("pid_file"))
self.soft_file_limit = config.get("soft_file_limit", 0)
- self.daemonize = config.get("daemonize")
- self.print_pidfile = config.get("print_pidfile")
+ self.daemonize = bool(config.get("daemonize"))
+ self.print_pidfile = bool(config.get("print_pidfile"))
self.user_agent_suffix = config.get("user_agent_suffix")
self.use_frozen_dicts = config.get("use_frozen_dicts", False)
self.serve_server_wellknown = config.get("serve_server_wellknown", False)
@@ -680,18 +680,30 @@ class ServerConfig(Config):
config.get("use_account_validity_in_account_status") or False
)
+ # This is a temporary option that enables fully using the new
+ # `device_lists_changes_in_room` without the backwards compat code. This
+ # is primarily for testing. If enabled the server should *not* be
+ # downgraded, as it may lead to missing device list updates.
+ self.use_new_device_lists_changes_in_room = (
+ config.get("use_new_device_lists_changes_in_room") or False
+ )
+
+ self.rooms_to_exclude_from_sync: List[str] = (
+ config.get("exclude_rooms_from_sync") or []
+ )
+
def has_tls_listener(self) -> bool:
return any(listener.tls for listener in self.listeners)
def generate_config_section(
self,
- server_name,
- data_dir_path,
- open_private_ports,
- listeners,
- config_dir_path,
- **kwargs,
- ):
+ config_dir_path: str,
+ data_dir_path: str,
+ server_name: str,
+ open_private_ports: bool,
+ listeners: Optional[List[dict]],
+ **kwargs: Any,
+ ) -> str:
ip_range_blacklist = "\n".join(
" # - '%s'" % ip for ip in DEFAULT_IP_RANGE_BLACKLIST
)
@@ -1234,6 +1246,15 @@ class ServerConfig(Config):
# information about using custom templates.
#
#custom_template_directory: /path/to/custom/templates/
+
+ # List of rooms to exclude from sync responses. This is useful for server
+ # administrators wishing to group users into a room without these users being able
+ # to see it from their client.
+ #
+ # By default, no room is excluded.
+ #
+ #exclude_rooms_from_sync:
+ # - !foo:example.com
"""
% locals()
)
diff --git a/synapse/config/server_notices.py b/synapse/config/server_notices.py
index bde4e879..505b4f6c 100644
--- a/synapse/config/server_notices.py
+++ b/synapse/config/server_notices.py
@@ -11,7 +11,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from synapse.types import UserID
+
+from typing import Any, Optional
+
+from synapse.types import JsonDict, UserID
from ._base import Config
@@ -60,14 +63,14 @@ class ServerNoticesConfig(Config):
section = "servernotices"
- def __init__(self, *args):
+ def __init__(self, *args: Any):
super().__init__(*args)
- self.server_notices_mxid = None
- self.server_notices_mxid_display_name = None
- self.server_notices_mxid_avatar_url = None
- self.server_notices_room_name = None
+ self.server_notices_mxid: Optional[str] = None
+ self.server_notices_mxid_display_name: Optional[str] = None
+ self.server_notices_mxid_avatar_url: Optional[str] = None
+ self.server_notices_room_name: Optional[str] = None
- def read_config(self, config, **kwargs):
+ def read_config(self, config: JsonDict, **kwargs: Any) -> None:
c = config.get("server_notices")
if c is None:
return
@@ -81,5 +84,5 @@ class ServerNoticesConfig(Config):
# todo: i18n
self.server_notices_room_name = c.get("room_name", "Server Notices")
- def generate_config_section(self, **kwargs):
+ def generate_config_section(self, **kwargs: Any) -> str:
return DEFAULT_CONFIG
diff --git a/synapse/config/spam_checker.py b/synapse/config/spam_checker.py
index 4c52103b..f22784f9 100644
--- a/synapse/config/spam_checker.py
+++ b/synapse/config/spam_checker.py
@@ -16,6 +16,7 @@ import logging
from typing import Any, Dict, List, Tuple
from synapse.config import ConfigError
+from synapse.types import JsonDict
from synapse.util.module_loader import load_module
from ._base import Config
@@ -33,7 +34,7 @@ see https://matrix-org.github.io/synapse/latest/modules/index.html
class SpamCheckerConfig(Config):
section = "spamchecker"
- def read_config(self, config, **kwargs):
+ def read_config(self, config: JsonDict, **kwargs: Any) -> None:
self.spam_checkers: List[Tuple[Any, Dict]] = []
spam_checkers = config.get("spam_checker") or []
diff --git a/synapse/config/sso.py b/synapse/config/sso.py
index e4a42432..f88eba77 100644
--- a/synapse/config/sso.py
+++ b/synapse/config/sso.py
@@ -16,6 +16,8 @@ from typing import Any, Dict, Optional
import attr
+from synapse.types import JsonDict
+
from ._base import Config
logger = logging.getLogger(__name__)
@@ -49,7 +51,7 @@ class SSOConfig(Config):
section = "sso"
- def read_config(self, config, **kwargs) -> None:
+ def read_config(self, config: JsonDict, **kwargs: Any) -> None:
sso_config: Dict[str, Any] = config.get("sso") or {}
# The sso-specific template_dir
@@ -106,7 +108,7 @@ class SSOConfig(Config):
)
self.sso_client_whitelist.append(login_fallback_url)
- def generate_config_section(self, **kwargs) -> str:
+ def generate_config_section(self, **kwargs: Any) -> str:
return """\
# Additional settings to use with single-sign on systems such as OpenID Connect,
# SAML2 and CAS.
diff --git a/synapse/config/stats.py b/synapse/config/stats.py
index 6f253e00..ed1f416e 100644
--- a/synapse/config/stats.py
+++ b/synapse/config/stats.py
@@ -13,6 +13,9 @@
# limitations under the License.
import logging
+from typing import Any
+
+from synapse.types import JsonDict
from ._base import Config
@@ -36,7 +39,7 @@ class StatsConfig(Config):
section = "stats"
- def read_config(self, config, **kwargs):
+ def read_config(self, config: JsonDict, **kwargs: Any) -> None:
self.stats_enabled = True
stats_config = config.get("stats", None)
if stats_config:
@@ -44,7 +47,7 @@ class StatsConfig(Config):
if not self.stats_enabled:
logger.warning(ROOM_STATS_DISABLED_WARN)
- def generate_config_section(self, config_dir_path, server_name, **kwargs):
+ def generate_config_section(self, **kwargs: Any) -> str:
return """
# Settings for local room and user statistics collection. See
# https://matrix-org.github.io/synapse/latest/room_and_user_statistics.html.
diff --git a/synapse/config/third_party_event_rules.py b/synapse/config/third_party_event_rules.py
index a3fae024..eca209d5 100644
--- a/synapse/config/third_party_event_rules.py
+++ b/synapse/config/third_party_event_rules.py
@@ -12,6 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from typing import Any
+
+from synapse.types import JsonDict
from synapse.util.module_loader import load_module
from ._base import Config
@@ -20,7 +23,7 @@ from ._base import Config
class ThirdPartyRulesConfig(Config):
section = "thirdpartyrules"
- def read_config(self, config, **kwargs):
+ def read_config(self, config: JsonDict, **kwargs: Any) -> None:
self.third_party_event_rules = None
provider = config.get("third_party_event_rules", None)
diff --git a/synapse/config/tls.py b/synapse/config/tls.py
index 6e673d65..cb17950d 100644
--- a/synapse/config/tls.py
+++ b/synapse/config/tls.py
@@ -14,7 +14,7 @@
import logging
import os
-from typing import List, Optional, Pattern
+from typing import Any, List, Optional, Pattern
from matrix_common.regex import glob_to_regex
@@ -22,6 +22,7 @@ from OpenSSL import SSL, crypto
from twisted.internet._sslverify import Certificate, trustRootFromCertificates
from synapse.config._base import Config, ConfigError
+from synapse.types import JsonDict
logger = logging.getLogger(__name__)
@@ -29,7 +30,7 @@ logger = logging.getLogger(__name__)
class TlsConfig(Config):
section = "tls"
- def read_config(self, config: dict, config_dir_path: str, **kwargs):
+ def read_config(self, config: JsonDict, **kwargs: Any) -> None:
self.tls_certificate_file = self.abspath(config.get("tls_certificate_path"))
self.tls_private_key_file = self.abspath(config.get("tls_private_key_path"))
@@ -142,13 +143,13 @@ class TlsConfig(Config):
def generate_config_section(
self,
- config_dir_path,
- server_name,
- data_dir_path,
- tls_certificate_path,
- tls_private_key_path,
- **kwargs,
- ):
+ config_dir_path: str,
+ data_dir_path: str,
+ server_name: str,
+ tls_certificate_path: Optional[str],
+ tls_private_key_path: Optional[str],
+ **kwargs: Any,
+ ) -> str:
"""If the TLS paths are not specified the default will be certs in the
config directory"""
diff --git a/synapse/config/tracer.py b/synapse/config/tracer.py
index 7aff618e..3472a9a0 100644
--- a/synapse/config/tracer.py
+++ b/synapse/config/tracer.py
@@ -12,8 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Set
+from typing import Any, Set
+from synapse.types import JsonDict
from synapse.util.check_dependencies import DependencyException, check_requirements
from ._base import Config, ConfigError
@@ -22,7 +23,7 @@ from ._base import Config, ConfigError
class TracerConfig(Config):
section = "tracing"
- def read_config(self, config, **kwargs):
+ def read_config(self, config: JsonDict, **kwargs: Any) -> None:
opentracing_config = config.get("opentracing")
if opentracing_config is None:
opentracing_config = {}
@@ -65,7 +66,7 @@ class TracerConfig(Config):
)
self.force_tracing_for_users.add(u)
- def generate_config_section(cls, **kwargs):
+ def generate_config_section(cls, **kwargs: Any) -> str:
return """\
## Opentracing ##
diff --git a/synapse/config/user_directory.py b/synapse/config/user_directory.py
index 6d6678c7..010e7919 100644
--- a/synapse/config/user_directory.py
+++ b/synapse/config/user_directory.py
@@ -12,6 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from typing import Any
+
+from synapse.types import JsonDict
+
from ._base import Config
@@ -22,7 +26,7 @@ class UserDirectoryConfig(Config):
section = "userdirectory"
- def read_config(self, config, **kwargs):
+ def read_config(self, config: JsonDict, **kwargs: Any) -> None:
user_directory_config = config.get("user_directory") or {}
self.user_directory_search_enabled = user_directory_config.get("enabled", True)
self.user_directory_search_all_users = user_directory_config.get(
@@ -32,7 +36,7 @@ class UserDirectoryConfig(Config):
"prefer_local_users", False
)
- def generate_config_section(self, config_dir_path, server_name, **kwargs):
+ def generate_config_section(self, **kwargs: Any) -> str:
return """
# User Directory configuration
#
diff --git a/synapse/config/voip.py b/synapse/config/voip.py
index b313bff1..87c09abe 100644
--- a/synapse/config/voip.py
+++ b/synapse/config/voip.py
@@ -12,13 +12,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from typing import Any
+
+from synapse.types import JsonDict
+
from ._base import Config
class VoipConfig(Config):
section = "voip"
- def read_config(self, config, **kwargs):
+ def read_config(self, config: JsonDict, **kwargs: Any) -> None:
self.turn_uris = config.get("turn_uris", [])
self.turn_shared_secret = config.get("turn_shared_secret")
self.turn_username = config.get("turn_username")
@@ -28,7 +32,7 @@ class VoipConfig(Config):
)
self.turn_allow_guests = config.get("turn_allow_guests", True)
- def generate_config_section(self, **kwargs):
+ def generate_config_section(self, **kwargs: Any) -> str:
return """\
## TURN ##
diff --git a/synapse/config/workers.py b/synapse/config/workers.py
index bdaba6db..a5479dfc 100644
--- a/synapse/config/workers.py
+++ b/synapse/config/workers.py
@@ -14,10 +14,12 @@
# limitations under the License.
import argparse
-from typing import List, Union
+from typing import Any, List, Union
import attr
+from synapse.types import JsonDict
+
from ._base import (
Config,
ConfigError,
@@ -110,7 +112,7 @@ class WorkerConfig(Config):
section = "worker"
- def read_config(self, config, **kwargs):
+ def read_config(self, config: JsonDict, **kwargs: Any) -> None:
self.worker_app = config.get("worker_app")
# Canonicalise worker_app so that master always has None
@@ -120,9 +122,13 @@ class WorkerConfig(Config):
self.worker_listeners = [
parse_listener_def(x) for x in config.get("worker_listeners", [])
]
- self.worker_daemonize = config.get("worker_daemonize")
+ self.worker_daemonize = bool(config.get("worker_daemonize"))
self.worker_pid_file = config.get("worker_pid_file")
- self.worker_log_config = config.get("worker_log_config")
+
+ worker_log_config = config.get("worker_log_config")
+ if worker_log_config is not None and not isinstance(worker_log_config, str):
+ raise ConfigError("worker_log_config must be a string")
+ self.worker_log_config = worker_log_config
# The host used to connect to the main synapse
self.worker_replication_host = config.get("worker_replication_host", None)
@@ -290,7 +296,7 @@ class WorkerConfig(Config):
self.worker_name is None and background_tasks_instance == "master"
) or self.worker_name == background_tasks_instance
- def generate_config_section(self, config_dir_path, server_name, **kwargs):
+ def generate_config_section(self, **kwargs: Any) -> str:
return """\
## Workers ##
diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py
index 6cf384f6..c88afb29 100644
--- a/synapse/crypto/keyring.py
+++ b/synapse/crypto/keyring.py
@@ -176,7 +176,7 @@ class Keyring:
self._local_verify_keys: Dict[str, FetchKeyResult] = {}
for key_id, key in hs.config.key.old_signing_keys.items():
self._local_verify_keys[key_id] = FetchKeyResult(
- verify_key=key, valid_until_ts=key.expired_ts
+ verify_key=key, valid_until_ts=key.expired
)
vk = get_verify_key(hs.signing_key)
diff --git a/synapse/events/builder.py b/synapse/events/builder.py
index 1ea1bb7d..98c203ad 100644
--- a/synapse/events/builder.py
+++ b/synapse/events/builder.py
@@ -15,7 +15,7 @@ import logging
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
import attr
-from nacl.signing import SigningKey
+from signedjson.types import SigningKey
from synapse.api.constants import MAX_DEPTH
from synapse.api.room_versions import (
diff --git a/synapse/events/third_party_rules.py b/synapse/events/third_party_rules.py
index bfca454f..ef68e202 100644
--- a/synapse/events/third_party_rules.py
+++ b/synapse/events/third_party_rules.py
@@ -42,6 +42,7 @@ CHECK_CAN_SHUTDOWN_ROOM_CALLBACK = Callable[[str, str], Awaitable[bool]]
CHECK_CAN_DEACTIVATE_USER_CALLBACK = Callable[[str, bool], Awaitable[bool]]
ON_PROFILE_UPDATE_CALLBACK = Callable[[str, ProfileInfo, bool, bool], Awaitable]
ON_USER_DEACTIVATION_STATUS_CHANGED_CALLBACK = Callable[[str, bool, bool], Awaitable]
+ON_THREEPID_BIND_CALLBACK = Callable[[str, str, str], Awaitable]
def load_legacy_third_party_event_rules(hs: "HomeServer") -> None:
@@ -169,6 +170,7 @@ class ThirdPartyEventRules:
self._on_user_deactivation_status_changed_callbacks: List[
ON_USER_DEACTIVATION_STATUS_CHANGED_CALLBACK
] = []
+ self._on_threepid_bind_callbacks: List[ON_THREEPID_BIND_CALLBACK] = []
def register_third_party_rules_callbacks(
self,
@@ -187,6 +189,7 @@ class ThirdPartyEventRules:
on_user_deactivation_status_changed: Optional[
ON_USER_DEACTIVATION_STATUS_CHANGED_CALLBACK
] = None,
+ on_threepid_bind: Optional[ON_THREEPID_BIND_CALLBACK] = None,
) -> None:
"""Register callbacks from modules for each hook."""
if check_event_allowed is not None:
@@ -221,6 +224,9 @@ class ThirdPartyEventRules:
on_user_deactivation_status_changed,
)
+ if on_threepid_bind is not None:
+ self._on_threepid_bind_callbacks.append(on_threepid_bind)
+
async def check_event_allowed(
self, event: EventBase, context: EventContext
) -> Tuple[bool, Optional[dict]]:
@@ -479,3 +485,23 @@ class ThirdPartyEventRules:
logger.exception(
"Failed to run module API callback %s: %s", callback, e
)
+
+ async def on_threepid_bind(self, user_id: str, medium: str, address: str) -> None:
+ """Called after a threepid association has been verified and stored.
+
+ Note that this callback is called when an association is created on the
+ local homeserver, not when it's created on an identity server (and then kept track
+ of so that it can be unbound on the same IS later on).
+
+ Args:
+ user_id: the user being associated with the threepid.
+ medium: the threepid's medium.
+ address: the threepid's address.
+ """
+ for callback in self._on_threepid_bind_callbacks:
+ try:
+ await callback(user_id, medium, address)
+ except Exception as e:
+ logger.exception(
+ "Failed to run module API callback %s: %s", callback, e
+ )
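For illustration, a minimal sketch of a module hooking the new on_threepid_bind callback added above. The module itself and its registration via ModuleApi.register_third_party_rules_callbacks are assumptions for the example, not part of this diff:

import logging

from synapse.module_api import ModuleApi

logger = logging.getLogger(__name__)


class ThreepidAuditModule:
    """Hypothetical module that logs local threepid associations."""

    def __init__(self, config: dict, api: ModuleApi) -> None:
        # Register the new hook so it runs after a threepid association has
        # been verified and stored on the local homeserver.
        api.register_third_party_rules_callbacks(
            on_threepid_bind=self.on_threepid_bind,
        )

    async def on_threepid_bind(self, user_id: str, medium: str, address: str) -> None:
        # Purely illustrative: record the association.
        logger.info("User %s bound %s threepid %s", user_id, medium, address)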
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index 467275b9..6a59cb4b 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -56,6 +56,7 @@ from synapse.api.room_versions import (
from synapse.events import EventBase, builder
from synapse.federation.federation_base import FederationBase, event_from_pdu_json
from synapse.federation.transport.client import SendJoinResponse
+from synapse.http.types import QueryParams
from synapse.types import JsonDict, UserID, get_domain_from_id
from synapse.util.async_helpers import concurrently_execute
from synapse.util.caches.expiringcache import ExpiringCache
@@ -154,7 +155,7 @@ class FederationClient(FederationBase):
self,
destination: str,
query_type: str,
- args: dict,
+ args: QueryParams,
retry_on_dns_fail: bool = False,
ignore_backoff: bool = False,
) -> JsonDict:
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index c7400c73..69d83358 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -188,7 +188,7 @@ class FederationServer(FederationBase):
async def on_backfill_request(
self, origin: str, room_id: str, versions: List[str], limit: int
) -> Tuple[int, Dict[str, Any]]:
- with (await self._server_linearizer.queue((origin, room_id))):
+ async with self._server_linearizer.queue((origin, room_id)):
origin_host, _ = parse_server_name(origin)
await self.check_server_matches_acl(origin_host, room_id)
@@ -218,7 +218,7 @@ class FederationServer(FederationBase):
Tuple indicating the response status code and dictionary response
body including `event_id`.
"""
- with (await self._server_linearizer.queue((origin, room_id))):
+ async with self._server_linearizer.queue((origin, room_id)):
origin_host, _ = parse_server_name(origin)
await self.check_server_matches_acl(origin_host, room_id)
@@ -529,7 +529,7 @@ class FederationServer(FederationBase):
# in the cache so we could return it without waiting for the linearizer
# - but that's non-trivial to get right, and anyway somewhat defeats
# the point of the linearizer.
- with (await self._server_linearizer.queue((origin, room_id))):
+ async with self._server_linearizer.queue((origin, room_id)):
resp: JsonDict = dict(
await self._state_resp_cache.wrap(
(room_id, event_id),
@@ -883,7 +883,7 @@ class FederationServer(FederationBase):
async def on_event_auth(
self, origin: str, room_id: str, event_id: str
) -> Tuple[int, Dict[str, Any]]:
- with (await self._server_linearizer.queue((origin, room_id))):
+ async with self._server_linearizer.queue((origin, room_id)):
origin_host, _ = parse_server_name(origin)
await self.check_server_matches_acl(origin_host, room_id)
@@ -945,7 +945,7 @@ class FederationServer(FederationBase):
latest_events: List[str],
limit: int,
) -> Dict[str, list]:
- with (await self._server_linearizer.queue((origin, room_id))):
+ async with self._server_linearizer.queue((origin, room_id)):
origin_host, _ = parse_server_name(origin)
await self.check_server_matches_acl(origin_host, room_id)
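The hunks above replace the old awaited-context-manager idiom with a native async context manager. A minimal sketch of the resulting usage, assuming Linearizer from synapse.util.async_helpers as imported elsewhere in this diff:

from synapse.util.async_helpers import Linearizer

linearizer = Linearizer(name="example")


async def serialised_work(key: str) -> None:
    # Old idiom (removed above): with (await linearizer.queue(key)): ...
    # New idiom: enter the queue directly as an async context manager.
    async with linearizer.queue(key):
        # Only one caller per key runs this block at a time.
        ...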
diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py
index de6e5f44..01dc5ca9 100644
--- a/synapse/federation/transport/client.py
+++ b/synapse/federation/transport/client.py
@@ -44,6 +44,7 @@ from synapse.api.urls import (
from synapse.events import EventBase, make_event_from_dict
from synapse.federation.units import Transaction
from synapse.http.matrixfederationclient import ByteParser
+from synapse.http.types import QueryParams
from synapse.types import JsonDict
logger = logging.getLogger(__name__)
@@ -255,7 +256,7 @@ class TransportLayerClient:
self,
destination: str,
query_type: str,
- args: dict,
+ args: QueryParams,
retry_on_dns_fail: bool,
ignore_backoff: bool = False,
prefix: str = FEDERATION_V1_PREFIX,
@@ -481,7 +482,7 @@ class TransportLayerClient:
if third_party_instance_id:
data["third_party_instance_id"] = third_party_instance_id
if limit:
- data["limit"] = str(limit)
+ data["limit"] = limit
if since_token:
data["since"] = since_token
@@ -503,7 +504,7 @@ class TransportLayerClient:
else:
path = _create_v1_path("/publicRooms")
- args: Dict[str, Any] = {
+ args: Dict[str, Union[str, Iterable[str]]] = {
"include_all_networks": "true" if include_all_networks else "false"
}
if third_party_instance_id:
diff --git a/synapse/handlers/account_data.py b/synapse/handlers/account_data.py
index 177b4f89..4af9fbc5 100644
--- a/synapse/handlers/account_data.py
+++ b/synapse/handlers/account_data.py
@@ -12,8 +12,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import logging
import random
-from typing import TYPE_CHECKING, Collection, List, Optional, Tuple
+from typing import TYPE_CHECKING, Awaitable, Callable, Collection, List, Optional, Tuple
from synapse.replication.http.account_data import (
ReplicationAddTagRestServlet,
@@ -27,6 +28,12 @@ from synapse.types import JsonDict, UserID
if TYPE_CHECKING:
from synapse.server import HomeServer
+logger = logging.getLogger(__name__)
+
+ON_ACCOUNT_DATA_UPDATED_CALLBACK = Callable[
+ [str, Optional[str], str, JsonDict], Awaitable
+]
+
class AccountDataHandler:
def __init__(self, hs: "HomeServer"):
@@ -40,6 +47,44 @@ class AccountDataHandler:
self._remove_tag_client = ReplicationRemoveTagRestServlet.make_client(hs)
self._account_data_writers = hs.config.worker.writers.account_data
+ self._on_account_data_updated_callbacks: List[
+ ON_ACCOUNT_DATA_UPDATED_CALLBACK
+ ] = []
+
+ def register_module_callbacks(
+ self, on_account_data_updated: Optional[ON_ACCOUNT_DATA_UPDATED_CALLBACK] = None
+ ) -> None:
+ """Register callbacks from modules."""
+ if on_account_data_updated is not None:
+ self._on_account_data_updated_callbacks.append(on_account_data_updated)
+
+ async def _notify_modules(
+ self,
+ user_id: str,
+ room_id: Optional[str],
+ account_data_type: str,
+ content: JsonDict,
+ ) -> None:
+ """Notifies modules about new account data changes.
+
+ A change can be either a new account data type being added, or the content
+ associated with a type being changed. Account data for a given type is removed by
+ changing the associated content to an empty dictionary.
+
+ Note that this is not called when the tags associated with a room change.
+
+ Args:
+ user_id: The user whose account data is changing.
+ room_id: The ID of the room the account data change concerns, if any.
+ account_data_type: The type of the account data.
+ content: The content that is now associated with this type.
+ """
+ for callback in self._on_account_data_updated_callbacks:
+ try:
+ await callback(user_id, room_id, account_data_type, content)
+ except Exception as e:
+ logger.exception("Failed to run module callback %s: %s", callback, e)
+
async def add_account_data_to_room(
self, user_id: str, room_id: str, account_data_type: str, content: JsonDict
) -> int:
@@ -63,6 +108,8 @@ class AccountDataHandler:
"account_data_key", max_stream_id, users=[user_id]
)
+ await self._notify_modules(user_id, room_id, account_data_type, content)
+
return max_stream_id
else:
response = await self._room_data_client(
@@ -96,6 +143,9 @@ class AccountDataHandler:
self._notifier.on_new_event(
"account_data_key", max_stream_id, users=[user_id]
)
+
+ await self._notify_modules(user_id, None, account_data_type, content)
+
return max_stream_id
else:
response = await self._user_data_client(
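Both the room-scoped and global paths above now call _notify_modules. A minimal sketch of a module consuming the new hook; the ModuleApi.register_account_data_callbacks registration method is assumed here rather than taken from this diff:

import logging
from typing import Optional

from synapse.module_api import ModuleApi
from synapse.types import JsonDict

logger = logging.getLogger(__name__)


class AccountDataLogger:
    """Hypothetical module that records account data changes."""

    def __init__(self, config: dict, api: ModuleApi) -> None:
        api.register_account_data_callbacks(
            on_account_data_updated=self.on_account_data_updated,
        )

    async def on_account_data_updated(
        self,
        user_id: str,
        room_id: Optional[str],
        account_data_type: str,
        content: JsonDict,
    ) -> None:
        # Tag changes do not reach this callback, per the handler docstring above.
        logger.info(
            "Account data %s updated for %s (room=%s)",
            account_data_type, user_id, room_id,
        )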
diff --git a/synapse/handlers/account_validity.py b/synapse/handlers/account_validity.py
index 9d0975f6..05a13841 100644
--- a/synapse/handlers/account_validity.py
+++ b/synapse/handlers/account_validity.py
@@ -180,9 +180,9 @@ class AccountValidityHandler:
expiring_users = await self.store.get_users_expiring_soon()
if expiring_users:
- for user in expiring_users:
+ for user_id, expiration_ts_ms in expiring_users:
await self._send_renewal_email(
- user_id=user["user_id"], expiration_ts=user["expiration_ts_ms"]
+ user_id=user_id, expiration_ts=expiration_ts_ms
)
async def send_renewal_email_to_user(self, user_id: str) -> None:
diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py
index bd913e52..1b578405 100644
--- a/synapse/handlers/appservice.py
+++ b/synapse/handlers/appservice.py
@@ -33,7 +33,13 @@ from synapse.metrics.background_process_metrics import (
wrap_as_background_process,
)
from synapse.storage.databases.main.directory import RoomAliasMapping
-from synapse.types import JsonDict, RoomAlias, RoomStreamToken, UserID
+from synapse.types import (
+ DeviceListUpdates,
+ JsonDict,
+ RoomAlias,
+ RoomStreamToken,
+ UserID,
+)
from synapse.util.async_helpers import Linearizer
from synapse.util.metrics import Measure
@@ -58,6 +64,9 @@ class ApplicationServicesHandler:
self._msc2409_to_device_messages_enabled = (
hs.config.experimental.msc2409_to_device_messages_enabled
)
+ self._msc3202_transaction_extensions_enabled = (
+ hs.config.experimental.msc3202_transaction_extensions
+ )
self.current_max = 0
self.is_processing = False
@@ -204,9 +213,9 @@ class ApplicationServicesHandler:
Args:
stream_key: The stream the event came from.
- `stream_key` can be "typing_key", "receipt_key", "presence_key" or
- "to_device_key". Any other value for `stream_key` will cause this function
- to return early.
+ `stream_key` can be "typing_key", "receipt_key", "presence_key",
+ "to_device_key" or "device_list_key". Any other value for `stream_key`
+ will cause this function to return early.
Ephemeral events will only be pushed to appservices that have opted into
receiving them by setting `push_ephemeral` to true in their registration
@@ -230,6 +239,7 @@ class ApplicationServicesHandler:
"receipt_key",
"presence_key",
"to_device_key",
+ "device_list_key",
):
return
@@ -253,15 +263,37 @@ class ApplicationServicesHandler:
):
return
+ # Ignore device lists if the feature flag is not enabled
+ if (
+ stream_key == "device_list_key"
+ and not self._msc3202_transaction_extensions_enabled
+ ):
+ return
+
# Check whether there are any appservices which have registered to receive
# ephemeral events.
#
# Note that whether these events are actually relevant to these appservices
# is decided later on.
+ services = self.store.get_app_services()
services = [
service
- for service in self.store.get_app_services()
- if service.supports_ephemeral
+ for service in services
+ # Different stream keys require different support booleans
+ if (
+ stream_key
+ in (
+ "typing_key",
+ "receipt_key",
+ "presence_key",
+ "to_device_key",
+ )
+ and service.supports_ephemeral
+ )
+ or (
+ stream_key == "device_list_key"
+ and service.msc3202_transaction_extensions
+ )
]
if not services:
# Bail out early if none of the target appservices have explicitly registered
@@ -298,10 +330,8 @@ class ApplicationServicesHandler:
continue
# Since we read/update the stream position for this AS/stream
- with (
- await self._ephemeral_events_linearizer.queue(
- (service.id, stream_key)
- )
+ async with self._ephemeral_events_linearizer.queue(
+ (service.id, stream_key)
):
if stream_key == "receipt_key":
events = await self._handle_receipts(service, new_token)
@@ -336,6 +366,20 @@ class ApplicationServicesHandler:
service, "to_device", new_token
)
+ elif stream_key == "device_list_key":
+ device_list_summary = await self._get_device_list_summary(
+ service, new_token
+ )
+ if device_list_summary:
+ self.scheduler.enqueue_for_appservice(
+ service, device_list_summary=device_list_summary
+ )
+
+ # Persist the latest handled stream token for this appservice
+ await self.store.set_appservice_stream_type_pos(
+ service, "device_list", new_token
+ )
+
async def _handle_typing(
self, service: ApplicationService, new_token: int
) -> List[JsonDict]:
@@ -542,6 +586,96 @@ class ApplicationServicesHandler:
return message_payload
+ async def _get_device_list_summary(
+ self,
+ appservice: ApplicationService,
+ new_key: int,
+ ) -> DeviceListUpdates:
+ """
+ Retrieve a list of users who have changed their device lists.
+
+ Args:
+ appservice: The application service to retrieve device list changes for.
+ new_key: The stream key of the device list change that triggered this method call.
+
+ Returns:
+ A set of device list updates, comprising users that the appservice needs to:
+ * resync the device list of, and
+ * stop tracking the device list of.
+ """
+ # Fetch the last successfully processed device list update stream ID
+ # for this appservice.
+ from_key = await self.store.get_type_stream_id_for_appservice(
+ appservice, "device_list"
+ )
+
+ # Fetch the users who have modified their device list since then.
+ users_with_changed_device_lists = (
+ await self.store.get_users_whose_devices_changed(from_key, to_key=new_key)
+ )
+
+ # Filter out any users the application service is not interested in
+ #
+ # For each user who changed their device list, we want to check whether this
+ # appservice would be interested in the change.
+ filtered_users_with_changed_device_lists = {
+ user_id
+ for user_id in users_with_changed_device_lists
+ if await self._is_appservice_interested_in_device_lists_of_user(
+ appservice, user_id
+ )
+ }
+
+ # Create a summary of "changed" and "left" users.
+ # TODO: Calculate "left" users.
+ device_list_summary = DeviceListUpdates(
+ changed=filtered_users_with_changed_device_lists
+ )
+
+ return device_list_summary
+
+ async def _is_appservice_interested_in_device_lists_of_user(
+ self,
+ appservice: ApplicationService,
+ user_id: str,
+ ) -> bool:
+ """
+ Returns whether a given application service is interested in the device list
+ updates of a given user.
+
+ The application service is interested in the user's device list updates if any
+ of the following are true:
+ * The user is the appservice's sender localpart user.
+ * The user is in the appservice's user namespace.
+ * At least one member of one room that the user is a part of is in the
+ appservice's user namespace.
+ * The appservice is explicitly (via room ID or alias) interested in at
+ least one room that the user is in.
+
+ Args:
+ appservice: The application service to gauge interest of.
+ user_id: The ID of the user whose device list interest is in question.
+
+ Returns:
+ True if the application service is interested in the user's device lists, False
+ otherwise.
+ """
+ # This method checks against both the sender localpart user as well as if the
+ # user is in the appservice's user namespace.
+ if appservice.is_interested_in_user(user_id):
+ return True
+
+ # Determine whether any of the rooms the user is in justifies sending this
+ # device list update to the application service.
+ room_ids = await self.store.get_rooms_for_user(user_id)
+ for room_id in room_ids:
+ # This method covers checking room members for appservice interest as well as
+ # room ID and alias checks.
+ if await appservice.is_interested_in_room(room_id, self.store):
+ return True
+
+ return False
+
async def query_user_exists(self, user_id: str) -> bool:
"""Check if any application service knows this user_id exists.
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index 3e29c96a..86991d26 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -211,6 +211,7 @@ class AuthHandler:
self.macaroon_gen = hs.get_macaroon_generator()
self._password_enabled = hs.config.auth.password_enabled
self._password_localdb_enabled = hs.config.auth.password_localdb_enabled
+ self._third_party_rules = hs.get_third_party_event_rules()
# Ratelimiter for failed auth during UIA. Uses same ratelimit config
# as per `rc_login.failed_attempts`.
@@ -1505,6 +1506,8 @@ class AuthHandler:
user_id, medium, address, validated_at, self.hs.get_clock().time_msec()
)
+ await self._third_party_rules.on_threepid_bind(user_id, medium, address)
+
async def delete_threepid(
self, user_id: str, medium: str, address: str, id_server: Optional[str] = None
) -> bool:
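The auth handler now notifies third-party event rules whenever a threepid is bound. A hedged sketch of a module consuming such a hook is below; it assumes a Synapse-style module API exposing `register_third_party_rules_callbacks(on_threepid_bind=...)`, so check the module API documentation for your version before relying on the exact name or signature.

```python
import logging

logger = logging.getLogger(__name__)


class ThreepidAuditModule:
    """Illustrative module that logs third-party identifier bindings.

    Assumption: the module API passed in exposes
    `register_third_party_rules_callbacks(on_threepid_bind=...)`.
    """

    def __init__(self, config: dict, api) -> None:
        self._api = api
        api.register_third_party_rules_callbacks(on_threepid_bind=self.on_threepid_bind)

    async def on_threepid_bind(self, user_id: str, medium: str, address: str) -> None:
        # Called after a threepid (e.g. an email address) has been validated
        # and associated with a local user account.
        logger.info("User %s bound %s threepid %s", user_id, medium, address)
```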
diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py
index d5ccaa0c..ffa28b2a 100644
--- a/synapse/handlers/device.py
+++ b/synapse/handlers/device.py
@@ -37,7 +37,10 @@ from synapse.api.errors import (
SynapseError,
)
from synapse.logging.opentracing import log_kv, set_tag, trace
-from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.metrics.background_process_metrics import (
+ run_as_background_process,
+ wrap_as_background_process,
+)
from synapse.types import (
JsonDict,
StreamToken,
@@ -278,6 +281,22 @@ class DeviceHandler(DeviceWorkerHandler):
hs.get_distributor().observe("user_left_room", self.user_left_room)
+ # Whether `_handle_new_device_update_async` is currently processing.
+ self._handle_new_device_update_is_processing = False
+
+ # If a new device update may have happened while the loop was
+ # processing.
+ self._handle_new_device_update_new_data = False
+
+ # On start up check if there are any updates pending.
+ hs.get_reactor().callWhenRunning(self._handle_new_device_update_async)
+
+ # Used to decide if we calculate outbound pokes up front or not. By
+ # default we do to allow safely downgrading Synapse.
+ self.use_new_device_lists_changes_in_room = (
+ hs.config.server.use_new_device_lists_changes_in_room
+ )
+
def _check_device_name_length(self, name: Optional[str]) -> None:
"""
Checks whether a device name is longer than the maximum allowed length.
@@ -469,19 +488,26 @@ class DeviceHandler(DeviceWorkerHandler):
# No changes to notify about, so this is a no-op.
return
- users_who_share_room = await self.store.get_users_who_share_room_with_user(
- user_id
- )
+ room_ids = await self.store.get_rooms_for_user(user_id)
+
+ hosts: Optional[Set[str]] = None
+ if not self.use_new_device_lists_changes_in_room:
+ hosts = set()
- hosts: Set[str] = set()
- if self.hs.is_mine_id(user_id):
- hosts.update(get_domain_from_id(u) for u in users_who_share_room)
- hosts.discard(self.server_name)
+ if self.hs.is_mine_id(user_id):
+ for room_id in room_ids:
+ joined_users = await self.store.get_users_in_room(room_id)
+ hosts.update(get_domain_from_id(u) for u in joined_users)
- set_tag("target_hosts", hosts)
+ set_tag("target_hosts", hosts)
+
+ hosts.discard(self.server_name)
position = await self.store.add_device_change_to_streams(
- user_id, device_ids, list(hosts)
+ user_id,
+ device_ids,
+ hosts=hosts,
+ room_ids=room_ids,
)
if not position:
@@ -495,9 +521,12 @@ class DeviceHandler(DeviceWorkerHandler):
# specify the user ID too since the user should always get their own device list
# updates, even if they aren't in any rooms.
- users_to_notify = users_who_share_room.union({user_id})
+ self.notifier.on_new_event(
+ "device_list_key", position, users={user_id}, rooms=room_ids
+ )
- self.notifier.on_new_event("device_list_key", position, users=users_to_notify)
+ # We may need to do some processing asynchronously.
+ self._handle_new_device_update_async()
if hosts:
logger.info(
@@ -614,6 +643,85 @@ class DeviceHandler(DeviceWorkerHandler):
return {"success": True}
+ @wrap_as_background_process("_handle_new_device_update_async")
+ async def _handle_new_device_update_async(self) -> None:
+ """Called when we have a new local device list update that we need to
+ send out over federation.
+
+ This happens in the background so as not to block the original request
+ that generated the device update.
+ """
+ if self._handle_new_device_update_is_processing:
+ self._handle_new_device_update_new_data = True
+ return
+
+ self._handle_new_device_update_is_processing = True
+
+ # The stream ID we processed in the previous iteration (if any), and the set of
+ # hosts we've already poked about for this update. This is so that we
+ # don't poke the same remote server about the same update repeatedly.
+ current_stream_id = None
+ hosts_already_sent_to: Set[str] = set()
+
+ try:
+ while True:
+ self._handle_new_device_update_new_data = False
+ rows = await self.store.get_uncoverted_outbound_room_pokes()
+ if not rows:
+ # If the DB returned nothing then there is nothing left to
+ # do, *unless* a new device list update happened during the
+ # DB query.
+ if self._handle_new_device_update_new_data:
+ continue
+ else:
+ return
+
+ for user_id, device_id, room_id, stream_id, opentracing_context in rows:
+ joined_user_ids = await self.store.get_users_in_room(room_id)
+ hosts = {get_domain_from_id(u) for u in joined_user_ids}
+ hosts.discard(self.server_name)
+
+ # Check if we've already sent this update to some hosts
+ if current_stream_id == stream_id:
+ hosts -= hosts_already_sent_to
+
+ await self.store.add_device_list_outbound_pokes(
+ user_id=user_id,
+ device_id=device_id,
+ room_id=room_id,
+ stream_id=stream_id,
+ hosts=hosts,
+ context=opentracing_context,
+ )
+
+ # Notify replication that we've updated the device list stream.
+ self.notifier.notify_replication()
+
+ if hosts:
+ logger.info(
+ "Sending device list update notif for %r to: %r",
+ user_id,
+ hosts,
+ )
+ for host in hosts:
+ self.federation_sender.send_device_messages(
+ host, immediate=False
+ )
+ log_kv(
+ {"message": "sent device update to host", "host": host}
+ )
+
+ if current_stream_id != stream_id:
+ # Clear the set of hosts we've already sent to as we're
+ # processing a new update.
+ hosts_already_sent_to.clear()
+
+ hosts_already_sent_to.update(hosts)
+ current_stream_id = stream_id
+
+ finally:
+ self._handle_new_device_update_is_processing = False
+
def _update_device_from_client_ips(
device: JsonDict, client_ips: Mapping[Tuple[str, str], Mapping[str, Any]]
@@ -725,7 +833,7 @@ class DeviceListUpdater:
async def _handle_device_updates(self, user_id: str) -> None:
"Actually handle pending updates."
- with (await self._remote_edu_linearizer.queue(user_id)):
+ async with self._remote_edu_linearizer.queue(user_id):
pending_updates = self._pending_updates.pop(user_id, [])
if not pending_updates:
# This can happen since we batch updates
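`_handle_new_device_update_async` above uses a common coalescing pattern for background fan-out: a "currently processing" flag plus a "new data arrived while I was busy" flag, so that concurrent pokes collapse into a single drain loop instead of running in parallel. A minimal sketch of that pattern, stripped of the device-list specifics (the queue and the sleep stand in for the real per-row work):

```python
import asyncio
from typing import List


class CoalescingWorker:
    """Runs at most one drain loop at a time: pokes that arrive while the loop
    is running just flag that another pass is needed, mirroring the
    _handle_new_device_update_is_processing / _new_data pair above."""

    def __init__(self) -> None:
        self._queue: List[int] = []
        self._tasks: List[asyncio.Task] = []  # keep references to spawned tasks
        self._is_processing = False
        self._new_data = False

    def poke(self, item: int) -> None:
        self._queue.append(item)
        # Fire-and-forget, standing in for wrap_as_background_process.
        self._tasks.append(asyncio.create_task(self._process()))

    async def _process(self) -> None:
        if self._is_processing:
            self._new_data = True
            return
        self._is_processing = True
        try:
            while True:
                self._new_data = False
                batch, self._queue = self._queue, []
                if not batch:
                    # Nothing left to do, *unless* a poke raced with the drain.
                    if self._new_data:
                        continue
                    return
                await asyncio.sleep(0)  # stand-in for the real per-row work
                print("processed", batch)
        finally:
            self._is_processing = False


async def main() -> None:
    worker = CoalescingWorker()
    for i in range(5):
        worker.poke(i)
    await asyncio.sleep(0.1)


if __name__ == "__main__":
    asyncio.run(main())
```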
diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py
index d96456cd..d6714228 100644
--- a/synapse/handlers/e2e_keys.py
+++ b/synapse/handlers/e2e_keys.py
@@ -118,7 +118,7 @@ class E2eKeysHandler:
from_device_id: the device making the query. This is used to limit
the number of in-flight queries at a time.
"""
- with await self._query_devices_linearizer.queue((from_user_id, from_device_id)):
+ async with self._query_devices_linearizer.queue((from_user_id, from_device_id)):
device_keys_query: Dict[str, Iterable[str]] = query_body.get(
"device_keys", {}
)
@@ -1386,7 +1386,7 @@ class SigningKeyEduUpdater:
device_handler = self.e2e_keys_handler.device_handler
device_list_updater = device_handler.device_list_updater
- with (await self._remote_edu_linearizer.queue(user_id)):
+ async with self._remote_edu_linearizer.queue(user_id):
pending_updates = self._pending_updates.pop(user_id, [])
if not pending_updates:
# This can happen since we batch updates
diff --git a/synapse/handlers/e2e_room_keys.py b/synapse/handlers/e2e_room_keys.py
index 52e44a2d..446f509b 100644
--- a/synapse/handlers/e2e_room_keys.py
+++ b/synapse/handlers/e2e_room_keys.py
@@ -83,7 +83,7 @@ class E2eRoomKeysHandler:
# we deliberately take the lock to get keys so that changing the version
# works atomically
- with (await self._upload_linearizer.queue(user_id)):
+ async with self._upload_linearizer.queue(user_id):
# make sure the backup version exists
try:
await self.store.get_e2e_room_keys_version_info(user_id, version)
@@ -126,7 +126,7 @@ class E2eRoomKeysHandler:
"""
# lock for consistency with uploading
- with (await self._upload_linearizer.queue(user_id)):
+ async with self._upload_linearizer.queue(user_id):
# make sure the backup version exists
try:
version_info = await self.store.get_e2e_room_keys_version_info(
@@ -187,7 +187,7 @@ class E2eRoomKeysHandler:
# TODO: Validate the JSON to make sure it has the right keys.
# XXX: perhaps we should use a finer grained lock here?
- with (await self._upload_linearizer.queue(user_id)):
+ async with self._upload_linearizer.queue(user_id):
# Check that the version we're trying to upload is the current version
try:
@@ -332,7 +332,7 @@ class E2eRoomKeysHandler:
# TODO: Validate the JSON to make sure it has the right keys.
# lock everyone out until we've switched version
- with (await self._upload_linearizer.queue(user_id)):
+ async with self._upload_linearizer.queue(user_id):
new_version = await self.store.create_e2e_room_keys_version(
user_id, version_info
)
@@ -359,7 +359,7 @@ class E2eRoomKeysHandler:
}
"""
- with (await self._upload_linearizer.queue(user_id)):
+ async with self._upload_linearizer.queue(user_id):
try:
res = await self.store.get_e2e_room_keys_version_info(user_id, version)
except StoreError as e:
@@ -383,7 +383,7 @@ class E2eRoomKeysHandler:
NotFoundError: if this backup version doesn't exist
"""
- with (await self._upload_linearizer.queue(user_id)):
+ async with self._upload_linearizer.queue(user_id):
try:
await self.store.delete_e2e_room_keys_version(user_id, version)
except StoreError as e:
@@ -413,7 +413,7 @@ class E2eRoomKeysHandler:
raise SynapseError(
400, "Version in body does not match", Codes.INVALID_PARAM
)
- with (await self._upload_linearizer.queue(user_id)):
+ async with self._upload_linearizer.queue(user_id):
try:
old_info = await self.store.get_e2e_room_keys_version_info(
user_id, version
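Many hunks in this diff make the same mechanical change: Synapse's `Linearizer` used to be entered with `with (await linearizer.queue(key)):`, and is now an async context manager entered with `async with linearizer.queue(key):`. The toy linearizer below shows the shape of the new idiom (it is not Synapse's implementation, just a per-key lock with the same calling convention):

```python
import asyncio
from collections import defaultdict
from contextlib import asynccontextmanager
from typing import AsyncIterator, DefaultDict, Hashable


class ToyLinearizer:
    """Serialises work per key: only one `async with` block per key runs at once."""

    def __init__(self) -> None:
        self._locks: DefaultDict[Hashable, asyncio.Lock] = defaultdict(asyncio.Lock)

    @asynccontextmanager
    async def queue(self, key: Hashable) -> AsyncIterator[None]:
        lock = self._locks[key]
        async with lock:
            yield


async def worker(linearizer: ToyLinearizer, key: str, n: int) -> None:
    # New idiom: `async with linearizer.queue(key):` (the old code awaited
    # queue() to get a plain context manager and wrapped it in a bare `with`).
    async with linearizer.queue(key):
        print(f"worker {n} holds the lock for {key!r}")
        await asyncio.sleep(0.01)


async def main() -> None:
    lin = ToyLinearizer()
    await asyncio.gather(*(worker(lin, "!room:example.org", i) for i in range(3)))


if __name__ == "__main__":
    asyncio.run(main())
```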
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 350ec9c0..78d14990 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -151,7 +151,7 @@ class FederationHandler:
return. This is used as part of the heuristic to decide if we
should back paginate.
"""
- with (await self._room_backfill.queue(room_id)):
+ async with self._room_backfill.queue(room_id):
return await self._maybe_backfill_inner(room_id, current_depth, limit)
async def _maybe_backfill_inner(
diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py
index 4bd87709..03c1197c 100644
--- a/synapse/handlers/federation_event.py
+++ b/synapse/handlers/federation_event.py
@@ -224,7 +224,7 @@ class FederationEventHandler:
len(missing_prevs),
shortstr(missing_prevs),
)
- with (await self._room_pdu_linearizer.queue(pdu.room_id)):
+ async with self._room_pdu_linearizer.queue(pdu.room_id):
logger.info(
"Acquired room lock to fetch %d missing prev_events",
len(missing_prevs),
@@ -469,6 +469,12 @@ class FederationEventHandler:
if context.rejected:
raise SynapseError(400, "Join event was rejected")
+ # the remote server is responsible for sending our join event to the rest
+ # of the federation. Indeed, attempting to do so will result in problems
+ # when we try to look up the state before the join (to get the server list)
+ # and discover that we do not have it.
+ event.internal_metadata.proactively_send = False
+
return await self.persist_events_and_notify(room_id, [(event, context)])
async def backfill(
@@ -891,10 +897,24 @@ class FederationEventHandler:
logger.debug("We are also missing %i auth events", len(missing_auth_events))
missing_events = missing_desired_events | missing_auth_events
- logger.debug("Fetching %i events from remote", len(missing_events))
- await self._get_events_and_persist(
- destination=destination, room_id=room_id, event_ids=missing_events
- )
+
+ # Making an individual request for each of 1000s of events has a lot of
+ # overhead. On the other hand, we don't really want to fetch all of the events
+ # if we already have most of them.
+ #
+ # As an arbitrary heuristic, if we are missing more than 10% of the events, then
+ # we fetch the whole state.
+ #
+ # TODO: might it be better to have an API which lets us do an aggregate event
+ # request
+ if (len(missing_events) * 10) >= len(auth_event_ids) + len(state_event_ids):
+ logger.debug("Requesting complete state from remote")
+ await self._get_state_and_persist(destination, room_id, event_id)
+ else:
+ logger.debug("Fetching %i events from remote", len(missing_events))
+ await self._get_events_and_persist(
+ destination=destination, room_id=room_id, event_ids=missing_events
+ )
# we need to make sure we re-load from the database to get the rejected
# state correct.
@@ -953,6 +973,27 @@ class FederationEventHandler:
return remote_state
+ async def _get_state_and_persist(
+ self, destination: str, room_id: str, event_id: str
+ ) -> None:
+ """Get the complete room state at a given event, and persist any new events
+ as outliers"""
+ room_version = await self._store.get_room_version(room_id)
+ auth_events, state_events = await self._federation_client.get_room_state(
+ destination, room_id, event_id=event_id, room_version=room_version
+ )
+ logger.info("/state returned %i events", len(auth_events) + len(state_events))
+
+ await self._auth_and_persist_outliers(
+ room_id, itertools.chain(auth_events, state_events)
+ )
+
+ # we also need the event itself.
+ if not await self._store.have_seen_event(room_id, event_id):
+ await self._get_events_and_persist(
+ destination=destination, room_id=room_id, event_ids=(event_id,)
+ )
+
async def _process_received_pdu(
self,
origin: str,
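The new branch in `federation_event.py` chooses between fetching missing events one by one and fetching the remote's complete `/state` snapshot: once roughly 10% or more of the desired state and auth events are missing, a single bulk request is cheaper than many per-event requests. A small worked sketch of that threshold:

```python
def should_fetch_full_state(
    n_missing: int, n_auth_events: int, n_state_events: int
) -> bool:
    """Mirror of the heuristic above: prefer one /state request once we are
    missing at least ~10% of the events we want."""
    return (n_missing * 10) >= n_auth_events + n_state_events


# Missing 30 of a 1000-event room state (3%): cheaper to fetch them individually.
assert not should_fetch_full_state(30, 200, 800)

# Missing 400 of them (40%): ask the remote for the complete state instead.
assert should_fetch_full_state(400, 200, 800)
```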
diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py
index 57c9fdfe..c183e9c4 100644
--- a/synapse/handlers/identity.py
+++ b/synapse/handlers/identity.py
@@ -858,8 +858,6 @@ class IdentityHandler:
if room_type is not None:
invite_config["room_type"] = room_type
- # TODO The unstable field is deprecated and should be removed in the future.
- invite_config["org.matrix.msc3288.room_type"] = room_type
# If a custom web client location is available, include it in the request.
if self._web_client_location:
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index 766f597a..7db6905c 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -851,7 +851,7 @@ class EventCreationHandler:
# a situation where event persistence can't keep up, causing
# extremities to pile up, which in turn leads to state resolution
# taking longer.
- with (await self.limiter.queue(event_dict["room_id"])):
+ async with self.limiter.queue(event_dict["room_id"]):
if txn_id and requester.access_token_id:
existing_event_id = await self.store.get_event_id_from_transaction_id(
event_dict["room_id"],
diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py
index 876b8794..7ee33403 100644
--- a/synapse/handlers/pagination.py
+++ b/synapse/handlers/pagination.py
@@ -441,7 +441,14 @@ class PaginationHandler:
if pagin_config.from_token:
from_token = pagin_config.from_token
else:
- from_token = self.hs.get_event_sources().get_current_token_for_pagination()
+ from_token = (
+ await self.hs.get_event_sources().get_current_token_for_pagination(
+ room_id
+ )
+ )
+ # We expect `/messages` to use historic pagination tokens by default but
+ # `/messages` should still work with live tokens when manually provided.
+ assert from_token.room_key.topological
if pagin_config.limit is None:
# This shouldn't happen as we've set a default limit before this
diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py
index 34d9411b..209a4b0e 100644
--- a/synapse/handlers/presence.py
+++ b/synapse/handlers/presence.py
@@ -1030,7 +1030,7 @@ class PresenceHandler(BasePresenceHandler):
is_syncing: Whether or not the user is now syncing
sync_time_msec: Time in ms when the user was last syncing
"""
- with (await self.external_sync_linearizer.queue(process_id)):
+ async with self.external_sync_linearizer.queue(process_id):
prev_state = await self.current_state_for_user(user_id)
process_presence = self.external_process_to_current_syncs.setdefault(
@@ -1071,7 +1071,7 @@ class PresenceHandler(BasePresenceHandler):
Used when the process has stopped/disappeared.
"""
- with (await self.external_sync_linearizer.queue(process_id)):
+ async with self.external_sync_linearizer.queue(process_id):
process_presence = self.external_process_to_current_syncs.pop(
process_id, set()
)
@@ -1625,7 +1625,7 @@ class PresenceEventSource(EventSource[int, UserPresenceState]):
# We'll actually pull the presence updates for these users at the end.
interested_and_updated_users: Union[Set[str], FrozenSet[str]] = set()
- if from_key:
+ if from_key is not None:
# First get all users that have had a presence update
updated_users = stream_change_cache.get_all_entities_changed(from_key)
diff --git a/synapse/handlers/read_marker.py b/synapse/handlers/read_marker.py
index bad1acc6..05122fd5 100644
--- a/synapse/handlers/read_marker.py
+++ b/synapse/handlers/read_marker.py
@@ -40,7 +40,7 @@ class ReadMarkerHandler:
the read marker has changed.
"""
- with await self.read_marker_linearizer.queue((room_id, user_id)):
+ async with self.read_marker_linearizer.queue((room_id, user_id)):
existing_read_marker = await self.store.get_account_data_for_room_and_type(
user_id, room_id, "m.fully_read"
)
diff --git a/synapse/handlers/relations.py b/synapse/handlers/relations.py
index 73217d13..0be23195 100644
--- a/synapse/handlers/relations.py
+++ b/synapse/handlers/relations.py
@@ -12,7 +12,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
-from typing import TYPE_CHECKING, Dict, Iterable, Optional, cast
+from typing import (
+ TYPE_CHECKING,
+ Collection,
+ Dict,
+ FrozenSet,
+ Iterable,
+ List,
+ Optional,
+ Tuple,
+)
import attr
from frozendict import frozendict
@@ -20,12 +29,12 @@ from frozendict import frozendict
from synapse.api.constants import RelationTypes
from synapse.api.errors import SynapseError
from synapse.events import EventBase
-from synapse.types import JsonDict, Requester, StreamToken
+from synapse.storage.databases.main.relations import _RelatedEvent
+from synapse.types import JsonDict, Requester, StreamToken, UserID
from synapse.visibility import filter_events_for_client
if TYPE_CHECKING:
from synapse.server import HomeServer
- from synapse.storage.databases.main import DataStore
logger = logging.getLogger(__name__)
@@ -116,7 +125,10 @@ class RelationsHandler:
if event is None:
raise SynapseError(404, "Unknown parent event.")
- pagination_chunk = await self._main_store.get_relations_for_event(
+ # Note that ignored users are not passed into get_relations_for_event
+ # below. Ignored users are handled in filter_events_for_client (and by
+ # not passing them in here we should get a better cache hit rate).
+ related_events, next_token = await self._main_store.get_relations_for_event(
event_id=event_id,
event=event,
room_id=room_id,
@@ -130,7 +142,7 @@ class RelationsHandler:
)
events = await self._main_store.get_events_as_list(
- [c["event_id"] for c in pagination_chunk.chunk]
+ [e.event_id for e in related_events]
)
events = await filter_events_for_client(
@@ -152,14 +164,100 @@ class RelationsHandler:
events, now, bundle_aggregations=aggregations
)
- return_value = await pagination_chunk.to_dict(self._main_store)
- return_value["chunk"] = serialized_events
- return_value["original_event"] = original_event
+ return_value = {
+ "chunk": serialized_events,
+ "original_event": original_event,
+ }
+
+ if next_token:
+ return_value["next_batch"] = await next_token.to_string(self._main_store)
+
+ if from_token:
+ return_value["prev_batch"] = await from_token.to_string(self._main_store)
return return_value
+ async def get_relations_for_event(
+ self,
+ event_id: str,
+ event: EventBase,
+ room_id: str,
+ relation_type: str,
+ ignored_users: FrozenSet[str] = frozenset(),
+ ) -> Tuple[List[_RelatedEvent], Optional[StreamToken]]:
+ """Get a list of events which relate to an event, ordered by topological ordering.
+
+ Args:
+ event_id: Fetch events that relate to this event ID.
+ event: The matching EventBase to event_id.
+ room_id: The room the event belongs to.
+ relation_type: The type of relation.
+ ignored_users: The users ignored by the requesting user.
+
+ Returns:
+ List of event IDs that match relations requested. The rows are of
+ the form `{"event_id": "..."}`.
+ """
+
+ # Call the underlying storage method, which is cached.
+ related_events, next_token = await self._main_store.get_relations_for_event(
+ event_id, event, room_id, relation_type, direction="f"
+ )
+
+ # Filter out ignored users and convert to the expected format.
+ related_events = [
+ event for event in related_events if event.sender not in ignored_users
+ ]
+
+ return related_events, next_token
+
+ async def get_annotations_for_event(
+ self,
+ event_id: str,
+ room_id: str,
+ limit: int = 5,
+ ignored_users: FrozenSet[str] = frozenset(),
+ ) -> List[JsonDict]:
+ """Get a list of annotations on the event, grouped by event type and
+ aggregation key, sorted by count.
+
+ This is used e.g. to get what reactions, and how many of them, have happened
+ on an event.
+
+ Args:
+ event_id: Fetch events that relate to this event ID.
+ room_id: The room the event belongs to.
+ limit: Only fetch the `limit` groups.
+ ignored_users: The users ignored by the requesting user.
+
+ Returns:
+ List of groups of annotations that match. Each row is a dict with
+ `type`, `key` and `count` fields.
+ """
+ # Get the base results for all users.
+ full_results = await self._main_store.get_aggregation_groups_for_event(
+ event_id, room_id, limit
+ )
+
+ # Then subtract off the results for any ignored users.
+ ignored_results = await self._main_store.get_aggregation_groups_for_users(
+ event_id, room_id, limit, ignored_users
+ )
+
+ filtered_results = []
+ for result in full_results:
+ key = (result["type"], result["key"])
+ if key in ignored_results:
+ result = result.copy()
+ result["count"] -= ignored_results[key]
+ if result["count"] <= 0:
+ continue
+ filtered_results.append(result)
+
+ return filtered_results
+
async def _get_bundled_aggregation_for_event(
- self, event: EventBase, user_id: str
+ self, event: EventBase, ignored_users: FrozenSet[str]
) -> Optional[BundledAggregations]:
"""Generate bundled aggregations for an event.
@@ -167,7 +265,7 @@ class RelationsHandler:
Args:
event: The event to calculate bundled aggregations for.
- user_id: The user requesting the bundled aggregations.
+ ignored_users: The users ignored by the requesting user.
Returns:
The bundled aggregations for an event, if bundled aggregations are
@@ -190,23 +288,125 @@ class RelationsHandler:
# while others need more processing during serialization.
aggregations = BundledAggregations()
- annotations = await self._main_store.get_aggregation_groups_for_event(
- event_id, room_id
+ annotations = await self.get_annotations_for_event(
+ event_id, room_id, ignored_users=ignored_users
)
- if annotations.chunk:
- aggregations.annotations = await annotations.to_dict(
- cast("DataStore", self)
- )
-
- references = await self._main_store.get_relations_for_event(
- event_id, event, room_id, RelationTypes.REFERENCE, direction="f"
+ if annotations:
+ aggregations.annotations = {"chunk": annotations}
+
+ references, next_token = await self.get_relations_for_event(
+ event_id,
+ event,
+ room_id,
+ RelationTypes.REFERENCE,
+ ignored_users=ignored_users,
)
- if references.chunk:
- aggregations.references = await references.to_dict(cast("DataStore", self))
+ if references:
+ aggregations.references = {
+ "chunk": [{"event_id": event.event_id} for event in references]
+ }
+
+ if next_token:
+ aggregations.references["next_batch"] = await next_token.to_string(
+ self._main_store
+ )
# Store the bundled aggregations in the event metadata for later use.
return aggregations
+ async def get_threads_for_events(
+ self, event_ids: Collection[str], user_id: str, ignored_users: FrozenSet[str]
+ ) -> Dict[str, _ThreadAggregation]:
+ """Get the bundled aggregations for threads for the requested events.
+
+ Args:
+ event_ids: Events to get aggregations for threads.
+ user_id: The user requesting the bundled aggregations.
+ ignored_users: The users ignored by the requesting user.
+
+ Returns:
+ A dictionary mapping event ID to the thread information.
+
+ May not contain a value for all requested event IDs.
+ """
+ user = UserID.from_string(user_id)
+
+ # Fetch thread summaries.
+ summaries = await self._main_store.get_thread_summaries(event_ids)
+
+ # Only fetch participated for a limited selection based on what had
+ # summaries.
+ thread_event_ids = [
+ event_id for event_id, summary in summaries.items() if summary
+ ]
+ participated = await self._main_store.get_threads_participated(
+ thread_event_ids, user_id
+ )
+
+ # Then subtract off the results for any ignored users.
+ ignored_results = await self._main_store.get_threaded_messages_per_user(
+ thread_event_ids, ignored_users
+ )
+
+ # A map of event ID to the thread aggregation.
+ results = {}
+
+ for event_id, summary in summaries.items():
+ if summary:
+ thread_count, latest_thread_event, edit = summary
+
+ # Subtract off the count of any ignored users.
+ for ignored_user in ignored_users:
+ thread_count -= ignored_results.get((event_id, ignored_user), 0)
+
+ # This is gnarly, but if the latest event is from an ignored user,
+ # attempt to find one that isn't from an ignored user.
+ if latest_thread_event.sender in ignored_users:
+ room_id = latest_thread_event.room_id
+
+ # If the root event is not found, something went wrong, do
+ # not include a summary of the thread.
+ event = await self._event_handler.get_event(user, room_id, event_id)
+ if event is None:
+ continue
+
+ potential_events, _ = await self.get_relations_for_event(
+ event_id,
+ event,
+ room_id,
+ RelationTypes.THREAD,
+ ignored_users,
+ )
+
+ # If all found events are from ignored users, do not include
+ # a summary of the thread.
+ if not potential_events:
+ continue
+
+ # The *last* event returned is the one that is cared about.
+ event = await self._event_handler.get_event(
+ user, room_id, potential_events[-1].event_id
+ )
+ # It is unexpected that the event will not exist.
+ if event is None:
+ logger.warning(
+ "Unable to fetch latest event in a thread with event ID: %s",
+ potential_events[-1].event_id,
+ )
+ continue
+ latest_thread_event = event
+
+ results[event_id] = _ThreadAggregation(
+ latest_event=latest_thread_event,
+ latest_edit=edit,
+ count=thread_count,
+ # If there's a thread summary it must also exist in the
+ # participated dictionary.
+ current_user_participated=participated[event_id],
+ )
+
+ return results
+
async def get_bundled_aggregations(
self, events: Iterable[EventBase], user_id: str
) -> Dict[str, BundledAggregations]:
@@ -230,13 +430,21 @@ class RelationsHandler:
# event ID -> bundled aggregation in non-serialized form.
results: Dict[str, BundledAggregations] = {}
+ # Fetch any ignored users of the requesting user.
+ ignored_users = await self._main_store.ignored_users(user_id)
+
# Fetch other relations per event.
for event in events_by_id.values():
- event_result = await self._get_bundled_aggregation_for_event(event, user_id)
+ event_result = await self._get_bundled_aggregation_for_event(
+ event, ignored_users
+ )
if event_result:
results[event.event_id] = event_result
# Fetch any edits (but not for redacted events).
+ #
+ # Note that there is no use in limiting edits by ignored users since the
+ # parent event should be ignored in the first place if the user is ignored.
edits = await self._main_store.get_applicable_edits(
[
event_id
@@ -247,25 +455,10 @@ class RelationsHandler:
for event_id, edit in edits.items():
results.setdefault(event_id, BundledAggregations()).replace = edit
- # Fetch thread summaries.
- summaries = await self._main_store.get_thread_summaries(events_by_id.keys())
- # Only fetch participated for a limited selection based on what had
- # summaries.
- participated = await self._main_store.get_threads_participated(
- [event_id for event_id, summary in summaries.items() if summary], user_id
+ threads = await self.get_threads_for_events(
+ events_by_id.keys(), user_id, ignored_users
)
- for event_id, summary in summaries.items():
- if summary:
- thread_count, latest_thread_event, edit = summary
- results.setdefault(
- event_id, BundledAggregations()
- ).thread = _ThreadAggregation(
- latest_event=latest_thread_event,
- latest_edit=edit,
- count=thread_count,
- # If there's a thread summary it must also exist in the
- # participated dictionary.
- current_user_participated=participated[event_id],
- )
+ for event_id, thread in threads.items():
+ results.setdefault(event_id, BundledAggregations()).thread = thread
return results
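`get_annotations_for_event` above removes ignored users from reaction counts by fetching a second aggregation restricted to the ignored users, subtracting it per `(type, key)` group, and dropping any group whose count reaches zero. A small worked sketch of just that counting logic (the storage calls are stand-ins):

```python
from typing import Any, Dict, List, Tuple

JsonDict = Dict[str, Any]


def subtract_ignored(
    full_results: List[JsonDict],
    ignored_results: Dict[Tuple[str, str], int],
) -> List[JsonDict]:
    """Remove reactions sent by ignored users from aggregated annotation counts."""
    filtered: List[JsonDict] = []
    for result in full_results:
        key = (result["type"], result["key"])
        if key in ignored_results:
            result = dict(result)
            result["count"] -= ignored_results[key]
            if result["count"] <= 0:
                continue  # every reaction in this group came from ignored users
        filtered.append(result)
    return filtered


full = [
    {"type": "m.reaction", "key": "👍", "count": 3},
    {"type": "m.reaction", "key": "🎉", "count": 1},
]
# Ignored users account for one 👍 and the only 🎉.
ignored = {("m.reaction", "👍"): 1, ("m.reaction", "🎉"): 1}
assert subtract_ignored(full, ignored) == [
    {"type": "m.reaction", "key": "👍", "count": 2}
]
```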
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index 092e185c..b31f00b5 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -771,7 +771,9 @@ class RoomCreationHandler:
% (user_id,),
)
- visibility = config.get("visibility", None)
+ # The spec says rooms should default to private visibility if
+ # `visibility` is not specified.
+ visibility = config.get("visibility", "private")
is_public = visibility == "public"
room_id = await self._generate_room_id(
@@ -881,7 +883,7 @@ class RoomCreationHandler:
#
# we also don't need to check the requester's shadow-ban here, as we
# have already done so above (and potentially emptied invite_list).
- with (await self.room_member_handler.member_linearizer.queue((room_id,))):
+ async with self.room_member_handler.member_linearizer.queue((room_id,)):
content = {}
is_direct = config.get("is_direct", None)
if is_direct:
@@ -1442,8 +1444,8 @@ class RoomEventSource(EventSource[RoomStreamToken, EventBase]):
def get_current_key(self) -> RoomStreamToken:
return self.store.get_room_max_token()
- def get_current_key_for_room(self, room_id: str) -> Awaitable[str]:
- return self.store.get_room_events_max_id(room_id)
+ def get_current_key_for_room(self, room_id: str) -> Awaitable[RoomStreamToken]:
+ return self.store.get_current_room_stream_token_for_room_id(room_id)
class ShutdownRoomResponse(TypedDict):
diff --git a/synapse/handlers/room_batch.py b/synapse/handlers/room_batch.py
index a0255bd1..78e299d3 100644
--- a/synapse/handlers/room_batch.py
+++ b/synapse/handlers/room_batch.py
@@ -156,8 +156,8 @@ class RoomBatchHandler:
) -> List[str]:
"""Takes all `state_events_at_start` event dictionaries and creates/persists
them in a floating state event chain which don't resolve into the current room
- state. They are floating because they reference no prev_events and are marked
- as outliers which disconnects them from the normal DAG.
+ state. They are floating because they reference no prev_events which disconnects
+ them from the normal DAG.
Args:
state_events_at_start:
@@ -213,31 +213,23 @@ class RoomBatchHandler:
room_id=room_id,
action=membership,
content=event_dict["content"],
- # Mark as an outlier to disconnect it from the normal DAG
- # and not show up between batches of history.
- outlier=True,
historical=True,
# Only the first event in the state chain should be floating.
# The rest should hang off each other in a chain.
allow_no_prev_events=index == 0,
prev_event_ids=prev_event_ids_for_state_chain,
- # Since each state event is marked as an outlier, the
- # `EventContext.for_outlier()` won't have any `state_ids`
- # set and therefore can't derive any state even though the
- # prev_events are set. Also since the first event in the
- # state chain is floating with no `prev_events`, it can't
- # derive state from anywhere automatically. So we need to
- # set some state explicitly.
+ # The first event in the state chain is floating with no
+ # `prev_events` which means it can't derive state from
+ # anywhere automatically. So we need to set some state
+ # explicitly.
#
# Make sure to use a copy of this list because we modify it
# later in the loop here. Otherwise it will be the same
- # reference and also update in the event when we append later.
+ # reference and also update in the event when we append
+ # later.
state_event_ids=state_event_ids.copy(),
)
else:
- # TODO: Add some complement tests that adds state that is not member joins
- # and will use this code path. Maybe we only want to support join state events
- # and can get rid of this `else`?
(
event,
_,
@@ -246,21 +238,15 @@ class RoomBatchHandler:
state_event["sender"], app_service_requester.app_service
),
event_dict,
- # Mark as an outlier to disconnect it from the normal DAG
- # and not show up between batches of history.
- outlier=True,
historical=True,
# Only the first event in the state chain should be floating.
# The rest should hang off each other in a chain.
allow_no_prev_events=index == 0,
prev_event_ids=prev_event_ids_for_state_chain,
- # Since each state event is marked as an outlier, the
- # `EventContext.for_outlier()` won't have any `state_ids`
- # set and therefore can't derive any state even though the
- # prev_events are set. Also since the first event in the
- # state chain is floating with no `prev_events`, it can't
- # derive state from anywhere automatically. So we need to
- # set some state explicitly.
+ # The first event in the state chain is floating with no
+ # `prev_events` which means it can't derive state from
+ # anywhere automatically. So we need to set some state
+ # explicitly.
#
# Make sure to use a copy of this list because we modify it
# later in the loop here. Otherwise it will be the same
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index 0785e311..802e57c4 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -515,8 +515,8 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
# We first linearise by the application service (to try to limit concurrent joins
# by application services), and then by room ID.
- with (await self.member_as_limiter.queue(as_id)):
- with (await self.member_linearizer.queue(key)):
+ async with self.member_as_limiter.queue(as_id):
+ async with self.member_linearizer.queue(key):
result = await self.update_membership_locked(
requester,
target,
diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py
index 30eddda6..102dd4b5 100644
--- a/synapse/handlers/search.py
+++ b/synapse/handlers/search.py
@@ -59,8 +59,6 @@ class SearchHandler:
self.state_store = self.storage.state
self.auth = hs.get_auth()
- self._msc3666_enabled = hs.config.experimental.msc3666_enabled
-
async def get_old_rooms_from_upgraded_room(self, room_id: str) -> Iterable[str]:
"""Retrieves room IDs of old rooms in the history of an upgraded room.
@@ -353,22 +351,20 @@ class SearchHandler:
state = await self.state_handler.get_current_state(room_id)
state_results[room_id] = list(state.values())
- aggregations = None
- if self._msc3666_enabled:
- aggregations = await self._relations_handler.get_bundled_aggregations(
- # Generate an iterable of EventBase for all the events that will be
- # returned, including contextual events.
- itertools.chain(
- # The events_before and events_after for each context.
- itertools.chain.from_iterable(
- itertools.chain(context["events_before"], context["events_after"]) # type: ignore[arg-type]
- for context in contexts.values()
- ),
- # The returned events.
- search_result.allowed_events,
+ aggregations = await self._relations_handler.get_bundled_aggregations(
+ # Generate an iterable of EventBase for all the events that will be
+ # returned, including contextual events.
+ itertools.chain(
+ # The events_before and events_after for each context.
+ itertools.chain.from_iterable(
+ itertools.chain(context["events_before"], context["events_after"]) # type: ignore[arg-type]
+ for context in contexts.values()
),
- user.to_string(),
- )
+ # The returned events.
+ search_result.allowed_events,
+ ),
+ user.to_string(),
+ )
# We're now about to serialize the events. We should not make any
# blocking calls after this. Otherwise, the 'age' will be wrong.
diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py
index 4f02a060..e4fe94e5 100644
--- a/synapse/handlers/sso.py
+++ b/synapse/handlers/sso.py
@@ -430,7 +430,7 @@ class SsoHandler:
# grab a lock while we try to find a mapping for this user. This seems...
# optimistic, especially for implementations that end up redirecting to
# interstitial pages.
- with await self._mapping_lock.queue(auth_provider_id):
+ async with self._mapping_lock.queue(auth_provider_id):
# first of all, check if we already have a mapping for this user
user_id = await self.get_sso_user_by_remote_user_id(
auth_provider_id,
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index 6c569cfb..6c8b17c4 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -13,17 +13,7 @@
# limitations under the License.
import itertools
import logging
-from typing import (
- TYPE_CHECKING,
- Any,
- Collection,
- Dict,
- FrozenSet,
- List,
- Optional,
- Set,
- Tuple,
-)
+from typing import TYPE_CHECKING, Any, Dict, FrozenSet, List, Optional, Set, Tuple
import attr
from prometheus_client import Counter
@@ -41,6 +31,7 @@ from synapse.storage.databases.main.event_push_actions import NotifCounts
from synapse.storage.roommember import MemberSummary
from synapse.storage.state import StateFilter
from synapse.types import (
+ DeviceListUpdates,
JsonDict,
MutableStateMap,
Requester,
@@ -184,21 +175,6 @@ class GroupsSyncResult:
return bool(self.join or self.invite or self.leave)
-@attr.s(slots=True, frozen=True, auto_attribs=True)
-class DeviceLists:
- """
- Attributes:
- changed: List of user_ids whose devices may have changed
- left: List of user_ids whose devices we no longer track
- """
-
- changed: Collection[str]
- left: Collection[str]
-
- def __bool__(self) -> bool:
- return bool(self.changed or self.left)
-
-
@attr.s(slots=True, auto_attribs=True)
class _RoomChanges:
"""The set of room entries to include in the sync, plus the set of joined
@@ -240,7 +216,7 @@ class SyncResult:
knocked: List[KnockedSyncResult]
archived: List[ArchivedSyncResult]
to_device: List[JsonDict]
- device_lists: DeviceLists
+ device_lists: DeviceListUpdates
device_one_time_keys_count: JsonDict
device_unused_fallback_key_types: List[str]
groups: Optional[GroupsSyncResult]
@@ -298,6 +274,8 @@ class SyncHandler:
expiry_ms=LAZY_LOADED_MEMBERS_CACHE_MAX_AGE,
)
+ self.rooms_to_exclude = hs.config.server.rooms_to_exclude_from_sync
+
async def wait_for_sync_for_user(
self,
requester: Requester,
@@ -1177,8 +1155,9 @@ class SyncHandler:
await self.store.get_e2e_unused_fallback_key_types(user_id, device_id)
)
- logger.debug("Fetching group data")
- await self._generate_sync_entry_for_groups(sync_result_builder)
+ if self.hs_config.experimental.groups_enabled:
+ logger.debug("Fetching group data")
+ await self._generate_sync_entry_for_groups(sync_result_builder)
num_events = 0
@@ -1262,8 +1241,8 @@ class SyncHandler:
newly_joined_or_invited_or_knocked_users: Set[str],
newly_left_rooms: Set[str],
newly_left_users: Set[str],
- ) -> DeviceLists:
- """Generate the DeviceLists section of sync
+ ) -> DeviceListUpdates:
+ """Generate the DeviceListUpdates section of sync
Args:
sync_result_builder
@@ -1381,9 +1360,11 @@ class SyncHandler:
if any(e.room_id in joined_rooms for e in entries):
newly_left_users.discard(user_id)
- return DeviceLists(changed=users_that_have_changed, left=newly_left_users)
+ return DeviceListUpdates(
+ changed=users_that_have_changed, left=newly_left_users
+ )
else:
- return DeviceLists(changed=[], left=[])
+ return DeviceListUpdates()
async def _generate_sync_entry_for_to_device(
self, sync_result_builder: "SyncResultBuilder"
@@ -1607,13 +1588,15 @@ class SyncHandler:
ignored_users = await self.store.ignored_users(user_id)
if since_token:
room_changes = await self._get_rooms_changed(
- sync_result_builder, ignored_users
+ sync_result_builder, ignored_users, self.rooms_to_exclude
)
tags_by_room = await self.store.get_updated_tags(
user_id, since_token.account_data_key
)
else:
- room_changes = await self._get_all_rooms(sync_result_builder, ignored_users)
+ room_changes = await self._get_all_rooms(
+ sync_result_builder, ignored_users, self.rooms_to_exclude
+ )
tags_by_room = await self.store.get_tags_for_user(user_id)
log_kv({"rooms_changed": len(room_changes.room_entries)})
@@ -1689,7 +1672,10 @@ class SyncHandler:
return False
async def _get_rooms_changed(
- self, sync_result_builder: "SyncResultBuilder", ignored_users: FrozenSet[str]
+ self,
+ sync_result_builder: "SyncResultBuilder",
+ ignored_users: FrozenSet[str],
+ excluded_rooms: List[str],
) -> _RoomChanges:
"""Determine the changes in rooms to report to the user.
@@ -1721,7 +1707,7 @@ class SyncHandler:
# _have_rooms_changed. We could keep the results in memory to avoid a
# second query, at the cost of more complicated source code.
membership_change_events = await self.store.get_membership_changes_for_user(
- user_id, since_token.room_key, now_token.room_key
+ user_id, since_token.room_key, now_token.room_key, excluded_rooms
)
mem_change_events_by_room_id: Dict[str, List[EventBase]] = {}
@@ -1865,6 +1851,7 @@ class SyncHandler:
full_state=False,
since_token=since_token,
upto_token=leave_token,
+ out_of_band=leave_event.internal_metadata.is_out_of_band_membership(),
)
)
@@ -1922,7 +1909,10 @@ class SyncHandler:
)
async def _get_all_rooms(
- self, sync_result_builder: "SyncResultBuilder", ignored_users: FrozenSet[str]
+ self,
+ sync_result_builder: "SyncResultBuilder",
+ ignored_users: FrozenSet[str],
+ ignored_rooms: List[str],
) -> _RoomChanges:
"""Returns entries for all rooms for the user.
@@ -1933,7 +1923,7 @@ class SyncHandler:
Args:
sync_result_builder
ignored_users: Set of users ignored by user.
-
+ ignored_rooms: List of rooms to ignore.
"""
user_id = sync_result_builder.sync_config.user.to_string()
@@ -1944,6 +1934,7 @@ class SyncHandler:
room_list = await self.store.get_rooms_for_local_user_where_membership_is(
user_id=user_id,
membership_list=Membership.LIST,
+ excluded_rooms=ignored_rooms,
)
room_entries = []
@@ -2127,33 +2118,41 @@ class SyncHandler:
):
return
- state = await self.compute_state_delta(
- room_id,
- batch,
- sync_config,
- since_token,
- now_token,
- full_state=full_state,
- )
+ if not room_builder.out_of_band:
+ state = await self.compute_state_delta(
+ room_id,
+ batch,
+ sync_config,
+ since_token,
+ now_token,
+ full_state=full_state,
+ )
+ else:
+ # An out of band room won't have any state changes.
+ state = {}
summary: Optional[JsonDict] = {}
# we include a summary in room responses when we're lazy loading
# members (as the client otherwise doesn't have enough info to form
# the name itself).
- if sync_config.filter_collection.lazy_load_members() and (
- # we recalculate the summary:
- # if there are membership changes in the timeline, or
- # if membership has changed during a gappy sync, or
- # if this is an initial sync.
- any(ev.type == EventTypes.Member for ev in batch.events)
- or (
- # XXX: this may include false positives in the form of LL
- # members which have snuck into state
- batch.limited
- and any(t == EventTypes.Member for (t, k) in state)
+ if (
+ not room_builder.out_of_band
+ and sync_config.filter_collection.lazy_load_members()
+ and (
+ # we recalculate the summary:
+ # if there are membership changes in the timeline, or
+ # if membership has changed during a gappy sync, or
+ # if this is an initial sync.
+ any(ev.type == EventTypes.Member for ev in batch.events)
+ or (
+ # XXX: this may include false positives in the form of LL
+ # members which have snuck into state
+ batch.limited
+ and any(t == EventTypes.Member for (t, k) in state)
+ )
+ or since_token is None
)
- or since_token is None
):
summary = await self.compute_summary(
room_id, sync_config, batch, state, now_token
@@ -2397,6 +2396,8 @@ class RoomSyncResultBuilder:
full_state: Whether the full state should be sent in result
since_token: Earliest point to return events from, or None
upto_token: Latest point to return events from.
+ out_of_band: whether the events in the room are "out of band" events
+ and the server isn't in the room.
"""
room_id: str
@@ -2406,3 +2407,5 @@ class RoomSyncResultBuilder:
full_state: bool
since_token: Optional[StreamToken]
upto_token: StreamToken
+
+ out_of_band: bool = False
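The sync-specific `DeviceLists` class is dropped in favour of `synapse.types.DeviceListUpdates`, so the same summary object can also be handed to the appservice handler earlier in this diff. The sketch below is a hedged stand-in for that type, based only on the attributes this diff relies on (`changed`, `left`, and truthiness when either set is non-empty); the real definition lives in `synapse/types.py`.

```python
from typing import Collection

import attr


@attr.s(slots=True, frozen=True, auto_attribs=True)
class DeviceListUpdatesSketch:
    """Illustrative stand-in for synapse.types.DeviceListUpdates.

    changed: users whose device lists the recipient should resync.
    left: users whose device lists the recipient can stop tracking.
    """

    changed: Collection[str] = attr.ib(factory=set)
    left: Collection[str] = attr.ib(factory=set)

    def __bool__(self) -> bool:
        return bool(self.changed or self.left)


assert not DeviceListUpdatesSketch()  # an empty summary is falsy and can be skipped
assert DeviceListUpdatesSketch(changed={"@alice:example.org"})
```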
diff --git a/synapse/handlers/ui_auth/checkers.py b/synapse/handlers/ui_auth/checkers.py
index 014754a6..472b029a 100644
--- a/synapse/handlers/ui_auth/checkers.py
+++ b/synapse/handlers/ui_auth/checkers.py
@@ -107,6 +107,8 @@ class RecaptchaAuthChecker(UserInteractiveAuthChecker):
# TODO: get this from the homeserver rather than creating a new one for
# each request
try:
+ assert self._secret is not None
+
resp_body = await self._http_client.post_urlencoded_get_json(
self._url,
args={
diff --git a/synapse/http/client.py b/synapse/http/client.py
index c01d2326..8310fb46 100644
--- a/synapse/http/client.py
+++ b/synapse/http/client.py
@@ -22,7 +22,6 @@ from typing import (
BinaryIO,
Callable,
Dict,
- Iterable,
List,
Mapping,
Optional,
@@ -72,6 +71,7 @@ from twisted.web.iweb import (
from synapse.api.errors import Codes, HttpResponseException, SynapseError
from synapse.http import QuieterFileBodyProducer, RequestTimedOutError, redact_uri
from synapse.http.proxyagent import ProxyAgent
+from synapse.http.types import QueryParams
from synapse.logging.context import make_deferred_yieldable
from synapse.logging.opentracing import set_tag, start_active_span, tags
from synapse.types import ISynapseReactor
@@ -97,10 +97,6 @@ RawHeaders = Union[Mapping[str, "RawHeaderValue"], Mapping[bytes, "RawHeaderValu
# the entries can either be Lists or bytes.
RawHeaderValue = Sequence[Union[str, bytes]]
-# the type of the query params, to be passed into `urlencode`
-QueryParamValue = Union[str, bytes, Iterable[Union[str, bytes]]]
-QueryParams = Union[Mapping[str, QueryParamValue], Mapping[bytes, QueryParamValue]]
-
def check_against_blacklist(
ip_address: IPAddress, ip_whitelist: Optional[IPSet], ip_blacklist: IPSet
@@ -911,7 +907,7 @@ def read_body_with_max_size(
return d
-def encode_query_args(args: Optional[Mapping[str, Union[str, List[str]]]]) -> bytes:
+def encode_query_args(args: Optional[QueryParams]) -> bytes:
"""
Encodes a map of query arguments to bytes which can be appended to a URL.
@@ -924,13 +920,7 @@ def encode_query_args(args: Optional[Mapping[str, Union[str, List[str]]]]) -> by
if args is None:
return b""
- encoded_args = {}
- for k, vs in args.items():
- if isinstance(vs, str):
- vs = [vs]
- encoded_args[k] = [v.encode("utf8") for v in vs]
-
- query_str = urllib.parse.urlencode(encoded_args, True)
+ query_str = urllib.parse.urlencode(args, True)
return query_str.encode("utf8")
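`encode_query_args` no longer re-encodes values by hand: it passes the mapping straight to `urllib.parse.urlencode(args, True)`, where the positional `True` is `doseq`, which expands sequence values into repeated parameters. A quick demonstration of the stdlib behaviour being relied on:

```python
from urllib.parse import urlencode

# A QueryParams-style mapping: values may be strings or sequences of strings.
args = {"limit": "10", "event_id": ["$abc", "$def"]}

# doseq=True (the positional `True` above) repeats the key for each element.
assert urlencode(args, True) == "limit=10&event_id=%24abc&event_id=%24def"
```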
diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py
index 6b98d865..5097b3ca 100644
--- a/synapse/http/matrixfederationclient.py
+++ b/synapse/http/matrixfederationclient.py
@@ -67,6 +67,7 @@ from synapse.http.client import (
read_body_with_max_size,
)
from synapse.http.federation.matrix_federation_agent import MatrixFederationAgent
+from synapse.http.types import QueryParams
from synapse.logging import opentracing
from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.logging.opentracing import set_tag, start_active_span, tags
@@ -98,10 +99,6 @@ MAXINT = sys.maxsize
_next_id = 1
-
-QueryArgs = Dict[str, Union[str, List[str]]]
-
-
T = TypeVar("T")
@@ -144,7 +141,7 @@ class MatrixFederationRequest:
"""A callback to generate the JSON.
"""
- query: Optional[dict] = None
+ query: Optional[QueryParams] = None
"""Query arguments.
"""
@@ -165,10 +162,7 @@ class MatrixFederationRequest:
destination_bytes = self.destination.encode("ascii")
path_bytes = self.path.encode("ascii")
- if self.query:
- query_bytes = encode_query_args(self.query)
- else:
- query_bytes = b""
+ query_bytes = encode_query_args(self.query)
# The object is frozen so we can pre-compute this.
uri = urllib.parse.urlunparse(
@@ -485,10 +479,7 @@ class MatrixFederationHttpClient:
method_bytes = request.method.encode("ascii")
destination_bytes = request.destination.encode("ascii")
path_bytes = request.path.encode("ascii")
- if request.query:
- query_bytes = encode_query_args(request.query)
- else:
- query_bytes = b""
+ query_bytes = encode_query_args(request.query)
scope = start_active_span(
"outgoing-federation-request",
@@ -746,7 +737,7 @@ class MatrixFederationHttpClient:
self,
destination: str,
path: str,
- args: Optional[QueryArgs] = None,
+ args: Optional[QueryParams] = None,
data: Optional[JsonDict] = None,
json_data_callback: Optional[Callable[[], JsonDict]] = None,
long_retries: bool = False,
@@ -764,7 +755,7 @@ class MatrixFederationHttpClient:
self,
destination: str,
path: str,
- args: Optional[QueryArgs] = None,
+ args: Optional[QueryParams] = None,
data: Optional[JsonDict] = None,
json_data_callback: Optional[Callable[[], JsonDict]] = None,
long_retries: bool = False,
@@ -781,7 +772,7 @@ class MatrixFederationHttpClient:
self,
destination: str,
path: str,
- args: Optional[QueryArgs] = None,
+ args: Optional[QueryParams] = None,
data: Optional[JsonDict] = None,
json_data_callback: Optional[Callable[[], JsonDict]] = None,
long_retries: bool = False,
@@ -891,7 +882,7 @@ class MatrixFederationHttpClient:
long_retries: bool = False,
timeout: Optional[int] = None,
ignore_backoff: bool = False,
- args: Optional[QueryArgs] = None,
+ args: Optional[QueryParams] = None,
) -> Union[JsonDict, list]:
"""Sends the specified json data using POST
@@ -961,7 +952,7 @@ class MatrixFederationHttpClient:
self,
destination: str,
path: str,
- args: Optional[QueryArgs] = None,
+ args: Optional[QueryParams] = None,
retry_on_dns_fail: bool = True,
timeout: Optional[int] = None,
ignore_backoff: bool = False,
@@ -976,7 +967,7 @@ class MatrixFederationHttpClient:
self,
destination: str,
path: str,
- args: Optional[QueryArgs] = ...,
+ args: Optional[QueryParams] = ...,
retry_on_dns_fail: bool = ...,
timeout: Optional[int] = ...,
ignore_backoff: bool = ...,
@@ -990,7 +981,7 @@ class MatrixFederationHttpClient:
self,
destination: str,
path: str,
- args: Optional[QueryArgs] = None,
+ args: Optional[QueryParams] = None,
retry_on_dns_fail: bool = True,
timeout: Optional[int] = None,
ignore_backoff: bool = False,
@@ -1085,7 +1076,7 @@ class MatrixFederationHttpClient:
long_retries: bool = False,
timeout: Optional[int] = None,
ignore_backoff: bool = False,
- args: Optional[QueryArgs] = None,
+ args: Optional[QueryParams] = None,
) -> Union[JsonDict, list]:
"""Send a DELETE request to the remote expecting some json response
@@ -1150,7 +1141,7 @@ class MatrixFederationHttpClient:
destination: str,
path: str,
output_stream,
- args: Optional[QueryArgs] = None,
+ args: Optional[QueryParams] = None,
retry_on_dns_fail: bool = True,
max_size: Optional[int] = None,
ignore_backoff: bool = False,
diff --git a/synapse/http/types.py b/synapse/http/types.py
new file mode 100644
index 00000000..11fe232d
--- /dev/null
+++ b/synapse/http/types.py
@@ -0,0 +1,21 @@
+# Copyright 2022 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+from typing import Iterable, Mapping, Union
+
+# the type of the query params, to be passed into `urlencode` with `doseq=True`.
+QueryParamValue = Union[str, bytes, Iterable[Union[str, bytes]]]
+QueryParams = Union[Mapping[str, QueryParamValue], Mapping[bytes, QueryParamValue]]
+
+__all__ = ["QueryParams"]
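The QueryParams alias above mirrors what `urllib.parse.urlencode` accepts when called with `doseq=True`: each value is either a single str/bytes or an iterable of them. A quick illustrative sketch (the parameter names and values are made up, not taken from the codebase):

from urllib.parse import urlencode

# A mapping matching the QueryParams shape: one single-valued key and one
# key carrying several values.
params = {"limit": "10", "event_id": ["$abc", "$def"]}

# With doseq=True, iterable values expand into repeated keys.
print(urlencode(params, doseq=True))
# limit=10&event_id=%24abc&event_id=%24def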
diff --git a/synapse/logging/opentracing.py b/synapse/logging/opentracing.py
index 3ebed5c1..f86ee9aa 100644
--- a/synapse/logging/opentracing.py
+++ b/synapse/logging/opentracing.py
@@ -289,6 +289,9 @@ class SynapseTags:
# Uniqueish ID of a database transaction
DB_TXN_ID = "db.txn_id"
+ # The name of the external cache
+ CACHE_NAME = "cache.name"
+
class SynapseBaggage:
FORCE_TRACING = "synapse-force-tracing"
diff --git a/synapse/metrics/__init__.py b/synapse/metrics/__init__.py
index d321946a..fffd8354 100644
--- a/synapse/metrics/__init__.py
+++ b/synapse/metrics/__init__.py
@@ -1,4 +1,5 @@
# Copyright 2015, 2016 OpenMarket Ltd
+# Copyright 2022 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -52,12 +53,13 @@ from synapse.metrics._exposition import (
start_http_server,
)
from synapse.metrics._gc import MIN_TIME_BETWEEN_GCS, install_gc_manager
+from synapse.metrics._types import Collector
logger = logging.getLogger(__name__)
METRICS_PREFIX = "/_synapse/metrics"
-all_gauges: "Dict[str, Union[LaterGauge, InFlightGauge]]" = {}
+all_gauges: Dict[str, Collector] = {}
HAVE_PROC_SELF_STAT = os.path.exists("/proc/self/stat")
@@ -78,11 +80,10 @@ RegistryProxy = cast(CollectorRegistry, _RegistryProxy)
@attr.s(slots=True, hash=True, auto_attribs=True)
-class LaterGauge:
-
+class LaterGauge(Collector):
name: str
desc: str
- labels: Optional[Iterable[str]] = attr.ib(hash=False)
+ labels: Optional[Sequence[str]] = attr.ib(hash=False)
# callback: should either return a value (if there are no labels for this metric),
# or dict mapping from a label tuple to a value
caller: Callable[
@@ -125,7 +126,7 @@ class LaterGauge:
MetricsEntry = TypeVar("MetricsEntry")
-class InFlightGauge(Generic[MetricsEntry]):
+class InFlightGauge(Generic[MetricsEntry], Collector):
"""Tracks number of things (e.g. requests, Measure blocks, etc) in flight
at any given time.
@@ -246,7 +247,7 @@ class InFlightGauge(Generic[MetricsEntry]):
all_gauges[self.name] = self
-class GaugeBucketCollector:
+class GaugeBucketCollector(Collector):
"""Like a Histogram, but the buckets are Gauges which are updated atomically.
The data is updated by calling `update_data` with an iterable of measurements.
@@ -340,7 +341,7 @@ class GaugeBucketCollector:
#
-class CPUMetrics:
+class CPUMetrics(Collector):
def __init__(self) -> None:
ticks_per_sec = 100
try:
@@ -470,6 +471,7 @@ def register_threadpool(name: str, threadpool: ThreadPool) -> None:
__all__ = [
+ "Collector",
"MetricsResource",
"generate_latest",
"start_http_server",
diff --git a/synapse/metrics/_gc.py b/synapse/metrics/_gc.py
index 2bc909ef..b7d47ce3 100644
--- a/synapse/metrics/_gc.py
+++ b/synapse/metrics/_gc.py
@@ -30,6 +30,8 @@ from prometheus_client.core import (
from twisted.internet import task
+from synapse.metrics._types import Collector
+
"""Prometheus metrics for garbage collection"""
@@ -71,7 +73,7 @@ gc_time = Histogram(
)
-class GCCounts:
+class GCCounts(Collector):
def collect(self) -> Iterable[Metric]:
cm = GaugeMetricFamily("python_gc_counts", "GC object counts", labels=["gen"])
for n, m in enumerate(gc.get_count()):
@@ -135,7 +137,7 @@ def install_gc_manager() -> None:
#
-class PyPyGCStats:
+class PyPyGCStats(Collector):
def collect(self) -> Iterable[Metric]:
# @stats is a pretty-printer object with __str__() returning a nice table,
diff --git a/synapse/metrics/_reactor_metrics.py b/synapse/metrics/_reactor_metrics.py
index f38f7983..a2c6e684 100644
--- a/synapse/metrics/_reactor_metrics.py
+++ b/synapse/metrics/_reactor_metrics.py
@@ -21,6 +21,8 @@ from prometheus_client.core import REGISTRY, GaugeMetricFamily
from twisted.internet import reactor
+from synapse.metrics._types import Collector
+
#
# Twisted reactor metrics
#
@@ -54,7 +56,7 @@ class EpollWrapper:
return getattr(self._poller, item)
-class ReactorLastSeenMetric:
+class ReactorLastSeenMetric(Collector):
def __init__(self, epoll_wrapper: EpollWrapper):
self._epoll_wrapper = epoll_wrapper
diff --git a/synapse/metrics/_types.py b/synapse/metrics/_types.py
new file mode 100644
index 00000000..dc5aa493
--- /dev/null
+++ b/synapse/metrics/_types.py
@@ -0,0 +1,31 @@
+# Copyright 2022 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from abc import ABC, abstractmethod
+from typing import Iterable
+
+from prometheus_client import Metric
+
+try:
+ from prometheus_client.registry import Collector
+except ImportError:
+ # prometheus_client.Collector is new as of prometheus 0.14. We redefine it here
+ # for compatibility with earlier versions.
+ class _Collector(ABC):
+ @abstractmethod
+ def collect(self) -> Iterable[Metric]:
+ pass
+
+ Collector = _Collector # type: ignore
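With this shim, collectors can subclass `Collector` regardless of whether the installed prometheus_client exposes it. A minimal, hypothetical collector written against that interface might look like the following (the class, metric name and registration call are illustrative, not part of this change):

from typing import Iterable

from prometheus_client import REGISTRY, Metric
from prometheus_client.core import GaugeMetricFamily

from synapse.metrics import Collector


class UptimeCollector(Collector):
    """Toy collector: reports a constant gauge every time it is scraped."""

    def collect(self) -> Iterable[Metric]:
        g = GaugeMetricFamily("example_up", "Always 1 while the process is alive")
        g.add_metric([], 1.0)
        yield g


# Registration works the same way as for the collectors converted above.
REGISTRY.register(UptimeCollector())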
diff --git a/synapse/metrics/background_process_metrics.py b/synapse/metrics/background_process_metrics.py
index 53c508af..f61396bb 100644
--- a/synapse/metrics/background_process_metrics.py
+++ b/synapse/metrics/background_process_metrics.py
@@ -46,6 +46,7 @@ from synapse.logging.opentracing import (
noop_context_manager,
start_active_span,
)
+from synapse.metrics._types import Collector
if TYPE_CHECKING:
import resource
@@ -127,7 +128,7 @@ _background_processes_active_since_last_scrape: "Set[_BackgroundProcess]" = set(
_bg_metrics_lock = threading.Lock()
-class _Collector:
+class _Collector(Collector):
"""A custom metrics collector for the background process metrics.
Ensures that all of the metrics are up-to-date with any in-flight processes
diff --git a/synapse/metrics/jemalloc.py b/synapse/metrics/jemalloc.py
index 98ed9c08..6bc329f0 100644
--- a/synapse/metrics/jemalloc.py
+++ b/synapse/metrics/jemalloc.py
@@ -16,11 +16,13 @@ import ctypes
import logging
import os
import re
-from typing import Iterable, Optional
+from typing import Iterable, Optional, overload
-from prometheus_client import Metric
+from prometheus_client import REGISTRY, Metric
+from typing_extensions import Literal
-from synapse.metrics import REGISTRY, GaugeMetricFamily
+from synapse.metrics import GaugeMetricFamily
+from synapse.metrics._types import Collector
logger = logging.getLogger(__name__)
@@ -59,6 +61,16 @@ def _setup_jemalloc_stats() -> None:
jemalloc = ctypes.CDLL(jemalloc_path)
+ @overload
+ def _mallctl(
+ name: str, read: Literal[True] = True, write: Optional[int] = None
+ ) -> int:
+ ...
+
+ @overload
+ def _mallctl(name: str, read: Literal[False], write: Optional[int] = None) -> None:
+ ...
+
def _mallctl(
name: str, read: bool = True, write: Optional[int] = None
) -> Optional[int]:
@@ -134,7 +146,7 @@ def _setup_jemalloc_stats() -> None:
except Exception as e:
logger.warning("Failed to reload jemalloc stats: %s", e)
- class JemallocCollector:
+ class JemallocCollector(Collector):
"""Metrics for internal jemalloc stats."""
def collect(self) -> Iterable[Metric]:
diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py
index ba9755f0..8f9e6292 100644
--- a/synapse/module_api/__init__.py
+++ b/synapse/module_api/__init__.py
@@ -62,8 +62,10 @@ from synapse.events.third_party_rules import (
ON_CREATE_ROOM_CALLBACK,
ON_NEW_EVENT_CALLBACK,
ON_PROFILE_UPDATE_CALLBACK,
+ ON_THREEPID_BIND_CALLBACK,
ON_USER_DEACTIVATION_STATUS_CHANGED_CALLBACK,
)
+from synapse.handlers.account_data import ON_ACCOUNT_DATA_UPDATED_CALLBACK
from synapse.handlers.account_validity import (
IS_USER_EXPIRED_CALLBACK,
ON_LEGACY_ADMIN_REQUEST,
@@ -117,6 +119,7 @@ from synapse.types import (
from synapse.util import Clock
from synapse.util.async_helpers import maybe_awaitable
from synapse.util.caches.descriptors import cached
+from synapse.util.frozenutils import freeze
if TYPE_CHECKING:
from synapse.app.generic_worker import GenericWorkerSlavedStore
@@ -209,12 +212,14 @@ class ModuleApi:
# We expose these as properties below in order to attach a helpful docstring.
self._http_client: SimpleHttpClient = hs.get_simple_http_client()
self._public_room_list_manager = PublicRoomListManager(hs)
+ self._account_data_manager = AccountDataManager(hs)
self._spam_checker = hs.get_spam_checker()
self._account_validity_handler = hs.get_account_validity_handler()
self._third_party_event_rules = hs.get_third_party_event_rules()
self._password_auth_provider = hs.get_password_auth_provider()
self._presence_router = hs.get_presence_router()
+ self._account_data_handler = hs.get_account_data_handler()
#################################################################################
# The following methods should only be called during the module's initialisation.
@@ -293,6 +298,7 @@ class ModuleApi:
on_user_deactivation_status_changed: Optional[
ON_USER_DEACTIVATION_STATUS_CHANGED_CALLBACK
] = None,
+ on_threepid_bind: Optional[ON_THREEPID_BIND_CALLBACK] = None,
) -> None:
"""Registers callbacks for third party event rules capabilities.
@@ -308,6 +314,7 @@ class ModuleApi:
check_can_deactivate_user=check_can_deactivate_user,
on_profile_update=on_profile_update,
on_user_deactivation_status_changed=on_user_deactivation_status_changed,
+ on_threepid_bind=on_threepid_bind,
)
def register_presence_router_callbacks(
@@ -373,6 +380,19 @@ class ModuleApi:
min_batch_size=min_batch_size,
)
+ def register_account_data_callbacks(
+ self,
+ *,
+ on_account_data_updated: Optional[ON_ACCOUNT_DATA_UPDATED_CALLBACK] = None,
+ ) -> None:
+ """Registers account data callbacks.
+
+ Added in Synapse 1.57.0.
+ """
+ return self._account_data_handler.register_module_callbacks(
+ on_account_data_updated=on_account_data_updated,
+ )
+
def register_web_resource(self, path: str, resource: Resource) -> None:
"""Registers a web resource to be served at the given path.
@@ -414,6 +434,14 @@ class ModuleApi:
return self._public_room_list_manager
@property
+ def account_data_manager(self) -> "AccountDataManager":
+ """Allows reading and modifying users' account data.
+
+ Added in Synapse v1.57.0.
+ """
+ return self._account_data_manager
+
+ @property
def public_baseurl(self) -> str:
"""The configured public base URL for this homeserver.
@@ -512,6 +540,17 @@ class ModuleApi:
"""
return await self._store.is_server_admin(UserID.from_string(user_id))
+ async def set_user_admin(self, user_id: str, admin: bool) -> None:
+ """Sets if a user is a server admin.
+
+ Added in Synapse v1.56.0.
+
+ Args:
+ user_id: The Matrix ID of the user to set admin status for.
+ admin: True iff the user is to be a server admin, false otherwise.
+ """
+ await self._store.set_server_admin(UserID.from_string(user_id), admin)
+
def get_qualified_user_id(self, username: str) -> str:
"""Qualify a user id, if necessary
@@ -1357,3 +1396,69 @@ class PublicRoomListManager:
room_id: The ID of the room.
"""
await self._store.set_room_is_public(room_id, False)
+
+
+class AccountDataManager:
+ """
+ Allows modules to manage account data.
+ """
+
+ def __init__(self, hs: "HomeServer") -> None:
+ self._hs = hs
+ self._store = hs.get_datastores().main
+ self._handler = hs.get_account_data_handler()
+
+ def _validate_user_id(self, user_id: str) -> None:
+ """
+ Validates a user ID is valid and local.
+ Private method to be used in other account data methods.
+ """
+ user = UserID.from_string(user_id)
+ if not self._hs.is_mine(user):
+ raise ValueError(
+ f"{user_id} is not local to this homeserver; can't access account data for remote users."
+ )
+
+ async def get_global(self, user_id: str, data_type: str) -> Optional[JsonDict]:
+ """
+ Gets some global account data, of a specified type, for the specified user.
+
+ The provided user ID must be a valid user ID of a local user.
+
+ Added in Synapse v1.57.0.
+ """
+ self._validate_user_id(user_id)
+
+ data = await self._store.get_global_account_data_by_type_for_user(
+ user_id, data_type
+ )
+ # We clone and freeze to prevent the module accidentally mutating the
+ # dict that lives in the cache, as that could introduce nasty bugs.
+ return freeze(data)
+
+ async def put_global(
+ self, user_id: str, data_type: str, new_data: JsonDict
+ ) -> None:
+ """
+ Puts some global account data, of a specified type, for the specified user.
+
+ The provided user ID must be a valid user ID of a local user.
+
+ Please note that this will overwrite any existing account data of that type
+ for that user!
+
+ Added in Synapse v1.57.0.
+ """
+ self._validate_user_id(user_id)
+
+ if not isinstance(data_type, str):
+ raise TypeError(f"data_type must be a str; got {type(data_type).__name__}")
+
+ if not isinstance(new_data, dict):
+ raise TypeError(f"new_data must be a dict; got {type(new_data).__name__}")
+
+ # Ensure the user exists, so we don't just write to users that aren't there.
+ if await self._store.get_userinfo_by_id(user_id) is None:
+ raise ValueError(f"User {user_id} does not exist on this server.")
+
+ await self._handler.add_account_data_for_user(user_id, data_type, new_data)
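Taken together, the new callback registration and the AccountDataManager give modules a read/modify path for account data. A hedged sketch of how a module might use them follows; the module class, the `org.example.last_update` data type and the assumed callback signature (user ID, optional room ID, data type, content) are illustrative rather than part of this diff:

from typing import Optional

from synapse.module_api import ModuleApi
from synapse.types import JsonDict


class ExampleAccountDataModule:
    """Toy module: records the type of the last global account data update."""

    MIRROR_TYPE = "org.example.last_update"  # hypothetical account data type

    def __init__(self, config: dict, api: ModuleApi):
        self._api = api
        # Added in Synapse 1.57: be told whenever a user's account data changes.
        api.register_account_data_callbacks(
            on_account_data_updated=self.on_account_data_updated,
        )

    async def on_account_data_updated(
        self,
        user_id: str,
        room_id: Optional[str],
        account_data_type: str,
        content: JsonDict,
    ) -> None:
        # Ignore room account data and our own writes, to avoid looping.
        if room_id is not None or account_data_type == self.MIRROR_TYPE:
            return

        previous = await self._api.account_data_manager.get_global(
            user_id, self.MIRROR_TYPE
        )
        count = (previous or {}).get("count", 0)
        # put_global overwrites any existing account data of that type.
        await self._api.account_data_manager.put_global(
            user_id, self.MIRROR_TYPE, {"type": account_data_type, "count": count + 1}
        )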
diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py
index a402a3e4..b07cf2ee 100644
--- a/synapse/push/bulk_push_rule_evaluator.py
+++ b/synapse/push/bulk_push_rule_evaluator.py
@@ -397,7 +397,7 @@ class RulesForRoom:
self.room_push_rule_cache_metrics.inc_hits()
return self.data.rules_by_user
- with (await self.linearizer.queue(self.room_id)):
+ async with self.linearizer.queue(self.room_id):
if state_group and self.data.state_group == state_group:
logger.debug("Using cached rules for %r", self.room_id)
self.room_push_rule_cache_metrics.inc_hits()
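This `with (await ...)` to `async with ...` conversion recurs throughout the release: `Linearizer.queue` can now be used directly as an async context manager instead of being awaited to obtain one. A minimal sketch of the new idiom (the linearizer name, function and key are illustrative):

from synapse.util.async_helpers import Linearizer

linearizer = Linearizer(name="example")


async def do_exclusive_work(key: str) -> None:
    # Previously: with (await linearizer.queue(key)): ...
    # Now the queue() call itself is the async context manager.
    async with linearizer.queue(key):
        ...  # only one coroutine per key runs this section at a time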
diff --git a/synapse/python_dependencies.py b/synapse/python_dependencies.py
index d02cca0b..ec199a16 100644
--- a/synapse/python_dependencies.py
+++ b/synapse/python_dependencies.py
@@ -48,14 +48,13 @@ REQUIREMENTS = [
"unpaddedbase64>=1.1.0",
"canonicaljson>=1.4.0",
# we use the type definitions added in signedjson 1.1.
- "signedjson>=1.1.0,<=1.1.1",
+ "signedjson>=1.1.0",
"pynacl>=1.2.1",
- "idna>=2.5",
# validating SSL certs for IP addresses requires service_identity 18.1.
"service_identity>=18.1.0",
# Twisted 18.9 introduces some logger improvements that the structured
# logger utilises
- "Twisted>=18.9.0",
+ "Twisted[tls]>=18.9.0",
"treq>=15.1",
# Twisted has required pyopenssl 16.0 since about Twisted 16.6.
"pyopenssl>=16.0.0",
@@ -89,6 +88,9 @@ REQUIREMENTS = [
"matrix-common~=1.1.0",
# We need packaging.requirements.Requirement, added in 16.1.
"packaging>=16.1",
+ # At the time of writing, we only use functions from the version of `importlib.metadata`
+ # which shipped in Python 3.8. This corresponds to version 1.4 of the backport.
+ "importlib_metadata>=1.4 ; python_version < '3.8'",
]
CONDITIONAL_REQUIREMENTS = {
diff --git a/synapse/replication/slave/storage/client_ips.py b/synapse/replication/slave/storage/client_ips.py
deleted file mode 100644
index 14706a08..00000000
--- a/synapse/replication/slave/storage/client_ips.py
+++ /dev/null
@@ -1,59 +0,0 @@
-# Copyright 2017 Vector Creations Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from typing import TYPE_CHECKING
-
-from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
-from synapse.storage.databases.main.client_ips import LAST_SEEN_GRANULARITY
-from synapse.util.caches.lrucache import LruCache
-
-from ._base import BaseSlavedStore
-
-if TYPE_CHECKING:
- from synapse.server import HomeServer
-
-
-class SlavedClientIpStore(BaseSlavedStore):
- def __init__(
- self,
- database: DatabasePool,
- db_conn: LoggingDatabaseConnection,
- hs: "HomeServer",
- ):
- super().__init__(database, db_conn, hs)
-
- self.client_ip_last_seen: LruCache[tuple, int] = LruCache(
- cache_name="client_ip_last_seen", max_size=50000
- )
-
- async def insert_client_ip(
- self, user_id: str, access_token: str, ip: str, user_agent: str, device_id: str
- ) -> None:
- now = int(self._clock.time_msec())
- key = (user_id, access_token, ip)
-
- try:
- last_seen = self.client_ip_last_seen.get(key)
- except KeyError:
- last_seen = None
-
- # Rate-limited inserts
- if last_seen is not None and (now - last_seen) < LAST_SEEN_GRANULARITY:
- return
-
- self.client_ip_last_seen.set(key, now)
-
- self.hs.get_replication_command_handler().send_user_ip(
- user_id, access_token, ip, user_agent, device_id, now
- )
diff --git a/synapse/replication/slave/storage/devices.py b/synapse/replication/slave/storage/devices.py
index 0ffd34f1..30717c2b 100644
--- a/synapse/replication/slave/storage/devices.py
+++ b/synapse/replication/slave/storage/devices.py
@@ -20,7 +20,6 @@ from synapse.replication.tcp.streams._base import DeviceListsStream, UserSignatu
from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
from synapse.storage.databases.main.devices import DeviceWorkerStore
from synapse.storage.databases.main.end_to_end_keys import EndToEndKeyWorkerStore
-from synapse.util.caches.stream_change_cache import StreamChangeCache
if TYPE_CHECKING:
from synapse.server import HomeServer
@@ -33,8 +32,6 @@ class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedSto
db_conn: LoggingDatabaseConnection,
hs: "HomeServer",
):
- super().__init__(database, db_conn, hs)
-
self.hs = hs
self._device_list_id_gen = SlavedIdTracker(
@@ -44,18 +41,11 @@ class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedSto
extra_tables=[
("user_signature_stream", "stream_id"),
("device_lists_outbound_pokes", "stream_id"),
+ ("device_lists_changes_in_room", "stream_id"),
],
)
- device_list_max = self._device_list_id_gen.get_current_token()
- self._device_list_stream_cache = StreamChangeCache(
- "DeviceListStreamChangeCache", device_list_max
- )
- self._user_signature_stream_cache = StreamChangeCache(
- "UserSignatureStreamChangeCache", device_list_max
- )
- self._device_list_federation_stream_cache = StreamChangeCache(
- "DeviceListFederationStreamChangeCache", device_list_max
- )
+
+ super().__init__(database, db_conn, hs)
def get_device_stream_token(self) -> int:
return self._device_list_id_gen.get_current_token()
diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py
index deeaaec4..122892c7 100644
--- a/synapse/replication/tcp/client.py
+++ b/synapse/replication/tcp/client.py
@@ -451,7 +451,7 @@ class FederationSenderHandler:
# service for robustness? Or could we replace it with an assertion that
# we're not being re-entered?
- with (await self._fed_position_linearizer.queue(None)):
+ async with self._fed_position_linearizer.queue(None):
# We persist and ack the same position, so we take a copy of it
# here as otherwise it can get modified from underneath us.
current_position = self.federation_position
diff --git a/synapse/replication/tcp/commands.py b/synapse/replication/tcp/commands.py
index 3654f6c0..fe349481 100644
--- a/synapse/replication/tcp/commands.py
+++ b/synapse/replication/tcp/commands.py
@@ -356,7 +356,7 @@ class UserIpCommand(Command):
access_token: str,
ip: str,
user_agent: str,
- device_id: str,
+ device_id: Optional[str],
last_seen: int,
):
self.user_id = user_id
@@ -389,6 +389,12 @@ class UserIpCommand(Command):
)
)
+ def __repr__(self) -> str:
+ return (
+ f"UserIpCommand({self.user_id!r}, .., {self.ip!r}, "
+ f"{self.user_agent!r}, {self.device_id!r}, {self.last_seen})"
+ )
+
class RemoteServerUpCommand(_SimpleCommand):
"""Sent when a worker has detected that a remote server is no longer
diff --git a/synapse/replication/tcp/external_cache.py b/synapse/replication/tcp/external_cache.py
index bf7d0179..a448dd7e 100644
--- a/synapse/replication/tcp/external_cache.py
+++ b/synapse/replication/tcp/external_cache.py
@@ -17,6 +17,7 @@ from typing import TYPE_CHECKING, Any, Optional
from prometheus_client import Counter, Histogram
+from synapse.logging import opentracing
from synapse.logging.context import make_deferred_yieldable
from synapse.util import json_decoder, json_encoder
@@ -93,14 +94,18 @@ class ExternalCache:
logger.debug("Caching %s %s: %r", cache_name, key, encoded_value)
- with response_timer.labels("set").time():
- return await make_deferred_yieldable(
- self._redis_connection.set(
- self._get_redis_key(cache_name, key),
- encoded_value,
- pexpire=expiry_ms,
+ with opentracing.start_active_span(
+ "ExternalCache.set",
+ tags={opentracing.SynapseTags.CACHE_NAME: cache_name},
+ ):
+ with response_timer.labels("set").time():
+ return await make_deferred_yieldable(
+ self._redis_connection.set(
+ self._get_redis_key(cache_name, key),
+ encoded_value,
+ pexpire=expiry_ms,
+ )
)
- )
async def get(self, cache_name: str, key: str) -> Optional[Any]:
"""Look up a key/value in the named cache."""
@@ -108,10 +113,14 @@ class ExternalCache:
if self._redis_connection is None:
return None
- with response_timer.labels("get").time():
- result = await make_deferred_yieldable(
- self._redis_connection.get(self._get_redis_key(cache_name, key))
- )
+ with opentracing.start_active_span(
+ "ExternalCache.get",
+ tags={opentracing.SynapseTags.CACHE_NAME: cache_name},
+ ):
+ with response_timer.labels("get").time():
+ result = await make_deferred_yieldable(
+ self._redis_connection.get(self._get_redis_key(cache_name, key))
+ )
logger.debug("Got cache result %s %s: %r", cache_name, key, result)
diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py
index b217c35f..615f1828 100644
--- a/synapse/replication/tcp/handler.py
+++ b/synapse/replication/tcp/handler.py
@@ -235,6 +235,14 @@ class ReplicationCommandHandler:
if self._is_master:
self._server_notices_sender = hs.get_server_notices_sender()
+ if hs.config.redis.redis_enabled:
+ # If we're using Redis, it's the background worker that should
+ # receive USER_IP commands and store the relevant client IPs.
+ self._should_insert_client_ips = hs.config.worker.run_background_tasks
+ else:
+ # If we're NOT using Redis, this must be handled by the master
+ self._should_insert_client_ips = hs.get_instance_name() == "master"
+
def _add_command_to_stream_queue(
self, conn: IReplicationConnection, cmd: Union[RdataCommand, PositionCommand]
) -> None:
@@ -401,23 +409,37 @@ class ReplicationCommandHandler:
) -> Optional[Awaitable[None]]:
user_ip_cache_counter.inc()
- if self._is_master:
+ if self._is_master or self._should_insert_client_ips:
+ # We make a point of only returning an awaitable if there's actually
+ # something to do; on_USER_IP is not an async function, but
+ # _handle_user_ip is.
+ # If on_USER_IP returns an awaitable, it gets scheduled as a
+ # background process (see `BaseReplicationStreamProtocol.handle_command`).
return self._handle_user_ip(cmd)
else:
+ # Returning None when this process definitely has nothing to do
+ # reduces the overhead of handling the USER_IP command, which is
+ # currently broadcast to all workers regardless of utility.
return None
async def _handle_user_ip(self, cmd: UserIpCommand) -> None:
- await self._store.insert_client_ip(
- cmd.user_id,
- cmd.access_token,
- cmd.ip,
- cmd.user_agent,
- cmd.device_id,
- cmd.last_seen,
- )
-
- assert self._server_notices_sender is not None
- await self._server_notices_sender.on_user_ip(cmd.user_id)
+ """
+ Handles a User IP, branching depending on whether we are the main process
+ and/or the background worker.
+ """
+ if self._is_master:
+ assert self._server_notices_sender is not None
+ await self._server_notices_sender.on_user_ip(cmd.user_id)
+
+ if self._should_insert_client_ips:
+ await self._store.insert_client_ip(
+ cmd.user_id,
+ cmd.access_token,
+ cmd.ip,
+ cmd.user_agent,
+ cmd.device_id,
+ cmd.last_seen,
+ )
def on_RDATA(self, conn: IReplicationConnection, cmd: RdataCommand) -> None:
if cmd.instance_name == self._instance_name:
@@ -698,7 +720,7 @@ class ReplicationCommandHandler:
access_token: str,
ip: str,
user_agent: str,
- device_id: str,
+ device_id: Optional[str],
last_seen: int,
) -> None:
"""Tell the master that the user made a request."""
diff --git a/synapse/rest/client/relations.py b/synapse/rest/client/relations.py
index c16078b1..3cae6d2b 100644
--- a/synapse/rest/client/relations.py
+++ b/synapse/rest/client/relations.py
@@ -12,22 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-"""This class implements the proposed relation APIs from MSC 1849.
-
-Since the MSC has not been approved all APIs here are unstable and may change at
-any time to reflect changes in the MSC.
-"""
-
import logging
from typing import TYPE_CHECKING, Optional, Tuple
-from synapse.api.constants import RelationTypes
-from synapse.api.errors import SynapseError
from synapse.http.server import HttpServer
from synapse.http.servlet import RestServlet, parse_integer, parse_string
from synapse.http.site import SynapseRequest
from synapse.rest.client._base import client_patterns
-from synapse.storage.relations import AggregationPaginationToken
from synapse.types import JsonDict, StreamToken
if TYPE_CHECKING:
@@ -44,7 +35,7 @@ class RelationPaginationServlet(RestServlet):
PATTERNS = client_patterns(
"/rooms/(?P<room_id>[^/]*)/relations/(?P<parent_id>[^/]*)"
"(/(?P<relation_type>[^/]*)(/(?P<event_type>[^/]*))?)?$",
- releases=(),
+ releases=("v1",),
)
def __init__(self, hs: "HomeServer"):
@@ -93,166 +84,5 @@ class RelationPaginationServlet(RestServlet):
return 200, result
-class RelationAggregationPaginationServlet(RestServlet):
- """API to paginate aggregation groups of relations, e.g. paginate the
- types and counts of the reactions on the events.
-
- Example request and response:
-
- GET /rooms/{room_id}/aggregations/{parent_id}
-
- {
- chunk: [
- {
- "type": "m.reaction",
- "key": "👍",
- "count": 3
- }
- ]
- }
- """
-
- PATTERNS = client_patterns(
- "/rooms/(?P<room_id>[^/]*)/aggregations/(?P<parent_id>[^/]*)"
- "(/(?P<relation_type>[^/]*)(/(?P<event_type>[^/]*))?)?$",
- releases=(),
- )
-
- def __init__(self, hs: "HomeServer"):
- super().__init__()
- self.auth = hs.get_auth()
- self.store = hs.get_datastores().main
- self.event_handler = hs.get_event_handler()
-
- async def on_GET(
- self,
- request: SynapseRequest,
- room_id: str,
- parent_id: str,
- relation_type: Optional[str] = None,
- event_type: Optional[str] = None,
- ) -> Tuple[int, JsonDict]:
- requester = await self.auth.get_user_by_req(request, allow_guest=True)
-
- await self.auth.check_user_in_room_or_world_readable(
- room_id,
- requester.user.to_string(),
- allow_departed_users=True,
- )
-
- # This checks that a) the event exists and b) the user is allowed to
- # view it.
- event = await self.event_handler.get_event(requester.user, room_id, parent_id)
- if event is None:
- raise SynapseError(404, "Unknown parent event.")
-
- if relation_type not in (RelationTypes.ANNOTATION, None):
- raise SynapseError(
- 400, f"Relation type must be '{RelationTypes.ANNOTATION}'"
- )
-
- limit = parse_integer(request, "limit", default=5)
- from_token_str = parse_string(request, "from")
- to_token_str = parse_string(request, "to")
-
- # Return the relations
- from_token = None
- if from_token_str:
- from_token = AggregationPaginationToken.from_string(from_token_str)
-
- to_token = None
- if to_token_str:
- to_token = AggregationPaginationToken.from_string(to_token_str)
-
- pagination_chunk = await self.store.get_aggregation_groups_for_event(
- event_id=parent_id,
- room_id=room_id,
- event_type=event_type,
- limit=limit,
- from_token=from_token,
- to_token=to_token,
- )
-
- return 200, await pagination_chunk.to_dict(self.store)
-
-
-class RelationAggregationGroupPaginationServlet(RestServlet):
- """API to paginate within an aggregation group of relations, e.g. paginate
- all the 👍 reactions on an event.
-
- Example request and response:
-
- GET /rooms/{room_id}/aggregations/{parent_id}/m.annotation/m.reaction/👍
-
- {
- chunk: [
- {
- "type": "m.reaction",
- "content": {
- "m.relates_to": {
- "rel_type": "m.annotation",
- "key": "👍"
- }
- }
- },
- ...
- ]
- }
- """
-
- PATTERNS = client_patterns(
- "/rooms/(?P<room_id>[^/]*)/aggregations/(?P<parent_id>[^/]*)"
- "/(?P<relation_type>[^/]*)/(?P<event_type>[^/]*)/(?P<key>[^/]*)$",
- releases=(),
- )
-
- def __init__(self, hs: "HomeServer"):
- super().__init__()
- self.auth = hs.get_auth()
- self.store = hs.get_datastores().main
- self._relations_handler = hs.get_relations_handler()
-
- async def on_GET(
- self,
- request: SynapseRequest,
- room_id: str,
- parent_id: str,
- relation_type: str,
- event_type: str,
- key: str,
- ) -> Tuple[int, JsonDict]:
- requester = await self.auth.get_user_by_req(request, allow_guest=True)
-
- if relation_type != RelationTypes.ANNOTATION:
- raise SynapseError(400, "Relation type must be 'annotation'")
-
- limit = parse_integer(request, "limit", default=5)
- from_token_str = parse_string(request, "from")
- to_token_str = parse_string(request, "to")
-
- from_token = None
- if from_token_str:
- from_token = await StreamToken.from_string(self.store, from_token_str)
- to_token = None
- if to_token_str:
- to_token = await StreamToken.from_string(self.store, to_token_str)
-
- result = await self._relations_handler.get_relations(
- requester=requester,
- event_id=parent_id,
- room_id=room_id,
- relation_type=relation_type,
- event_type=event_type,
- aggregation_key=key,
- limit=limit,
- from_token=from_token,
- to_token=to_token,
- )
-
- return 200, result
-
-
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
RelationPaginationServlet(hs).register(http_server)
- RelationAggregationPaginationServlet(hs).register(http_server)
- RelationAggregationGroupPaginationServlet(hs).register(http_server)
diff --git a/synapse/rest/client/room_batch.py b/synapse/rest/client/room_batch.py
index 07804853..dd91dabe 100644
--- a/synapse/rest/client/room_batch.py
+++ b/synapse/rest/client/room_batch.py
@@ -123,6 +123,19 @@ class RoomBatchSendEventRestServlet(RestServlet):
errcode=Codes.INVALID_PARAM,
)
+ # Make sure that the prev_event_ids exist and aren't outliers - ie, they are
+ # regular parts of the room DAG where we know the state.
+ non_outlier_prev_events = await self.store.have_events_in_timeline(
+ prev_event_ids_from_query
+ )
+ for prev_event_id in prev_event_ids_from_query:
+ if prev_event_id not in non_outlier_prev_events:
+ raise SynapseError(
+ HTTPStatus.BAD_REQUEST,
+ "prev_event %s does not exist, or is an outlier" % (prev_event_id,),
+ errcode=Codes.INVALID_PARAM,
+ )
+
# For the event we are inserting next to (`prev_event_ids_from_query`),
# find the most recent state events that allowed that message to be
# sent. We will use that as a base to auth our historical messages
@@ -131,14 +144,6 @@ class RoomBatchSendEventRestServlet(RestServlet):
prev_event_ids_from_query
)
- if not state_event_ids:
- raise SynapseError(
- HTTPStatus.BAD_REQUEST,
- "No auth events found for given prev_event query parameter. The prev_event=%s probably does not exist."
- % prev_event_ids_from_query,
- errcode=Codes.INVALID_PARAM,
- )
-
state_event_ids_at_start = []
# Create and persist all of the state events that float off on their own
# before the batch. These will most likely be all of the invite/member
diff --git a/synapse/rest/client/sync.py b/synapse/rest/client/sync.py
index 53c385a8..2e25e863 100644
--- a/synapse/rest/client/sync.py
+++ b/synapse/rest/client/sync.py
@@ -99,6 +99,7 @@ class SyncRestServlet(RestServlet):
self.presence_handler = hs.get_presence_handler()
self._server_notices_sender = hs.get_server_notices_sender()
self._event_serializer = hs.get_event_client_serializer()
+ self._msc2654_enabled = hs.config.experimental.msc2654_enabled
async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
# This will always be set by the time Twisted calls us.
@@ -300,14 +301,13 @@ class SyncRestServlet(RestServlet):
if archived:
response["rooms"][Membership.LEAVE] = archived
- # By the time we get here groups is no longer optional.
- assert sync_result.groups is not None
- if sync_result.groups.join:
- response["groups"][Membership.JOIN] = sync_result.groups.join
- if sync_result.groups.invite:
- response["groups"][Membership.INVITE] = sync_result.groups.invite
- if sync_result.groups.leave:
- response["groups"][Membership.LEAVE] = sync_result.groups.leave
+ if sync_result.groups is not None:
+ if sync_result.groups.join:
+ response["groups"][Membership.JOIN] = sync_result.groups.join
+ if sync_result.groups.invite:
+ response["groups"][Membership.INVITE] = sync_result.groups.invite
+ if sync_result.groups.leave:
+ response["groups"][Membership.LEAVE] = sync_result.groups.leave
return response
@@ -521,7 +521,8 @@ class SyncRestServlet(RestServlet):
result["ephemeral"] = {"events": ephemeral_events}
result["unread_notifications"] = room.unread_notifications
result["summary"] = room.summary
- result["org.matrix.msc2654.unread_count"] = room.unread_count
+ if self._msc2654_enabled:
+ result["org.matrix.msc2654.unread_count"] = room.unread_count
return result
diff --git a/synapse/rest/key/v2/local_key_resource.py b/synapse/rest/key/v2/local_key_resource.py
index b9bfbea2..0c9f042c 100644
--- a/synapse/rest/key/v2/local_key_resource.py
+++ b/synapse/rest/key/v2/local_key_resource.py
@@ -76,17 +76,17 @@ class LocalKey(Resource):
def response_json_object(self) -> JsonDict:
verify_keys = {}
- for key in self.config.key.signing_key:
- verify_key_bytes = key.verify_key.encode()
- key_id = "%s:%s" % (key.alg, key.version)
+ for signing_key in self.config.key.signing_key:
+ verify_key_bytes = signing_key.verify_key.encode()
+ key_id = "%s:%s" % (signing_key.alg, signing_key.version)
verify_keys[key_id] = {"key": encode_base64(verify_key_bytes)}
old_verify_keys = {}
- for key_id, key in self.config.key.old_signing_keys.items():
- verify_key_bytes = key.encode()
+ for key_id, old_signing_key in self.config.key.old_signing_keys.items():
+ verify_key_bytes = old_signing_key.encode()
old_verify_keys[key_id] = {
"key": encode_base64(verify_key_bytes),
- "expired_ts": key.expired_ts,
+ "expired_ts": old_signing_key.expired,
}
json_object = {
diff --git a/synapse/rest/key/v2/remote_key_resource.py b/synapse/rest/key/v2/remote_key_resource.py
index 3525d6ae..f5971575 100644
--- a/synapse/rest/key/v2/remote_key_resource.py
+++ b/synapse/rest/key/v2/remote_key_resource.py
@@ -13,7 +13,7 @@
# limitations under the License.
import logging
-from typing import TYPE_CHECKING, Dict
+from typing import TYPE_CHECKING, Dict, Set
from signedjson.sign import sign_json
@@ -149,7 +149,7 @@ class RemoteKey(DirectServeJsonResource):
cached = await self.store.get_server_keys_json(store_queries)
- json_results = set()
+ json_results: Set[bytes] = set()
time_now_ms = self.clock.time_msec()
@@ -234,8 +234,8 @@ class RemoteKey(DirectServeJsonResource):
await self.query_keys(request, query, query_remote_on_cache_miss=False)
else:
signed_keys = []
- for key_json in json_results:
- key_json = json_decoder.decode(key_json.decode("utf-8"))
+ for key_json_raw in json_results:
+ key_json = json_decoder.decode(key_json_raw.decode("utf-8"))
for signing_key in self.config.key.key_server_signing_keys:
key_json = sign_json(
key_json, self.config.server.server_name, signing_key
diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py
index 6c414402..3e5d6c62 100644
--- a/synapse/rest/media/v1/media_repository.py
+++ b/synapse/rest/media/v1/media_repository.py
@@ -258,7 +258,7 @@ class MediaRepository:
# We linearize here to ensure that we don't try and download remote
# media multiple times concurrently
key = (server_name, media_id)
- with (await self.remote_media_linearizer.queue(key)):
+ async with self.remote_media_linearizer.queue(key):
responder, media_info = await self._get_remote_media_impl(
server_name, media_id
)
@@ -294,7 +294,7 @@ class MediaRepository:
# We linearize here to ensure that we don't try and download remote
# media multiple times concurrently
key = (server_name, media_id)
- with (await self.remote_media_linearizer.queue(key)):
+ async with self.remote_media_linearizer.queue(key):
responder, media_info = await self._get_remote_media_impl(
server_name, media_id
)
@@ -850,7 +850,7 @@ class MediaRepository:
# TODO: Should we delete from the backup store
- with (await self.remote_media_linearizer.queue(key)):
+ async with self.remote_media_linearizer.queue(key):
full_path = self.filepaths.remote_media_filepath(origin, file_id)
try:
os.remove(full_path)
diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py
index d47af8ea..50383bdb 100644
--- a/synapse/rest/media/v1/preview_url_resource.py
+++ b/synapse/rest/media/v1/preview_url_resource.py
@@ -200,12 +200,17 @@ class PreviewUrlResource(DirectServeJsonResource):
match = False
continue
+ # Some attributes might not be parsed as strings by urlsplit (such as the
+ # port, which is parsed as an int). Because we use match functions that
+ # expect strings, we want to make sure that's what we give them.
+ value_str = str(value)
+
if pattern.startswith("^"):
- if not re.match(pattern, getattr(url_tuple, attrib)):
+ if not re.match(pattern, value_str):
match = False
continue
else:
- if not fnmatch.fnmatch(getattr(url_tuple, attrib), pattern):
+ if not fnmatch.fnmatch(value_str, pattern):
match = False
continue
if match:
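The comment added above is easy to verify in isolation: `urlsplit` parses the port as an int, so the raw attribute cannot be passed straight to `re.match` or `fnmatch.fnmatch`. A quick standalone check (the URL and pattern are made up):

import fnmatch
from urllib.parse import urlsplit

url_tuple = urlsplit("http://example.com:8080/path")
print(type(url_tuple.port))  # <class 'int'>

# fnmatch.fnmatch(url_tuple.port, "80*") raises TypeError on an int;
# coercing to str first, as the code now does, behaves as intended.
print(fnmatch.fnmatch(str(url_tuple.port), "80*"))  # True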
diff --git a/synapse/server_notices/server_notices_manager.py b/synapse/server_notices/server_notices_manager.py
index 7b4814e0..48eae5fa 100644
--- a/synapse/server_notices/server_notices_manager.py
+++ b/synapse/server_notices/server_notices_manager.py
@@ -16,7 +16,7 @@ from typing import TYPE_CHECKING, Optional
from synapse.api.constants import EventTypes, Membership, RoomCreationPreset
from synapse.events import EventBase
-from synapse.types import UserID, create_requester
+from synapse.types import Requester, UserID, create_requester
from synapse.util.caches.descriptors import cached
if TYPE_CHECKING:
@@ -35,6 +35,7 @@ class ServerNoticesManager:
self._room_creation_handler = hs.get_room_creation_handler()
self._room_member_handler = hs.get_room_member_handler()
self._event_creation_handler = hs.get_event_creation_handler()
+ self._message_handler = hs.get_message_handler()
self._is_mine_id = hs.is_mine_id
self._server_name = hs.hostname
@@ -107,6 +108,10 @@ class ServerNoticesManager:
assert self._is_mine_id(user_id), "Cannot send server notices to remote users"
+ requester = create_requester(
+ self.server_notices_mxid, authenticated_entity=self._server_name
+ )
+
rooms = await self._store.get_rooms_for_local_user_where_membership_is(
user_id, [Membership.INVITE, Membership.JOIN]
)
@@ -125,6 +130,12 @@ class ServerNoticesManager:
room.room_id,
user_id,
)
+ await self._update_notice_user_profile_if_changed(
+ requester,
+ room.room_id,
+ self._config.servernotices.server_notices_mxid_display_name,
+ self._config.servernotices.server_notices_mxid_avatar_url,
+ )
return room.room_id
# apparently no existing notice room: create a new one
@@ -143,9 +154,6 @@ class ServerNoticesManager:
"avatar_url": self._config.servernotices.server_notices_mxid_avatar_url,
}
- requester = create_requester(
- self.server_notices_mxid, authenticated_entity=self._server_name
- )
info, _ = await self._room_creation_handler.create_room(
requester,
config={
@@ -194,3 +202,46 @@ class ServerNoticesManager:
room_id=room_id,
action="invite",
)
+
+ async def _update_notice_user_profile_if_changed(
+ self,
+ requester: Requester,
+ room_id: str,
+ display_name: Optional[str],
+ avatar_url: Optional[str],
+ ) -> None:
+ """
+ Updates the notice user's profile if it's different from what is in the room.
+
+ Args:
+ requester: The user who is performing the update.
+ room_id: The ID of the server notice room
+ display_name: The displayname of the server notice user
+ avatar_url: The avatar url of the server notice user
+ """
+ logger.debug("Checking whether notice user profile has changed for %s", room_id)
+
+ assert self.server_notices_mxid is not None
+
+ notice_user_data_in_room = await self._message_handler.get_room_data(
+ self.server_notices_mxid,
+ room_id,
+ EventTypes.Member,
+ self.server_notices_mxid,
+ )
+
+ assert notice_user_data_in_room is not None
+
+ notice_user_profile_changed = (
+ display_name != notice_user_data_in_room.content.get("displayname")
+ or avatar_url != notice_user_data_in_room.content.get("avatar_url")
+ )
+ if notice_user_profile_changed:
+ logger.info("Updating notice user profile in room %s", room_id)
+ await self._room_member_handler.update_membership(
+ requester=requester,
+ target=UserID.from_string(self.server_notices_mxid),
+ room_id=room_id,
+ action="join",
+ content={"displayname": display_name, "avatar_url": avatar_url},
+ )
diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py
index 21888cc8..fbf7ba46 100644
--- a/synapse/state/__init__.py
+++ b/synapse/state/__init__.py
@@ -573,7 +573,7 @@ class StateResolutionHandler:
"""
group_names = frozenset(state_groups_ids.keys())
- with (await self.resolve_linearizer.queue(group_names)):
+ async with self.resolve_linearizer.queue(group_names):
cache = self._state_cache.get(group_names, None)
if cache:
return cache
diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index 3ef2bdd7..5eb545c8 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -241,9 +241,17 @@ class LoggingTransaction:
self.exception_callbacks = exception_callbacks
def call_after(self, callback: Callable[..., object], *args: Any, **kwargs: Any):
- """Call the given callback on the main twisted thread after the
- transaction has finished. Used to invalidate the caches on the
- correct thread.
+ """Call the given callback on the main twisted thread after the transaction has
+ finished.
+
+ Mostly used to invalidate the caches on the correct thread.
+
+ Note that transactions may be retried a few times if they encounter database
+ errors such as serialization failures. Callbacks given to `call_after`
+ will accumulate across transaction attempts and will _all_ be called once a
+ transaction attempt succeeds, regardless of whether previous transaction
+ attempts failed. Otherwise, if all transaction attempts fail, all
+ `call_on_exception` callbacks will be run instead.
"""
# if self.after_callbacks is None, that means that whatever constructed the
# LoggingTransaction isn't expecting there to be any callbacks; assert that
@@ -254,6 +262,15 @@ class LoggingTransaction:
def call_on_exception(
self, callback: Callable[..., object], *args: Any, **kwargs: Any
):
+ """Call the given callback on the main twisted thread after the transaction has
+ failed.
+
+ Note that transactions may be retried a few times if they encounter database
+ errors such as serialization failures. Callbacks given to `call_on_exception`
+ will accumulate across transaction attempts and will _all_ be called once the
+ final transaction attempt fails. No `call_on_exception` callbacks will be run
+ if any transaction attempt succeeds.
+ """
# if self.exception_callbacks is None, that means that whatever constructed the
# LoggingTransaction isn't expecting there to be any callbacks; assert that
# is not the case.
@@ -1251,6 +1268,7 @@ class DatabasePool:
value_names: Collection[str],
value_values: Collection[Collection[Any]],
desc: str,
+ lock: bool = True,
) -> None:
"""
Upsert, many times.
@@ -1262,6 +1280,8 @@ class DatabasePool:
value_names: The value column names
value_values: A list of each row's value column values.
Ignored if value_names is empty.
+ lock: True to lock the table when doing the upsert. Unused if the database engine
+ supports native upserts.
"""
# We can autocommit if we are going to use native upserts
@@ -1269,7 +1289,7 @@ class DatabasePool:
self.engine.can_native_upsert and table not in self._unsafe_to_upsert_tables
)
- return await self.runInteraction(
+ await self.runInteraction(
desc,
self.simple_upsert_many_txn,
table,
@@ -1277,6 +1297,7 @@ class DatabasePool:
key_values,
value_names,
value_values,
+ lock=lock,
db_autocommit=autocommit,
)
@@ -1288,6 +1309,7 @@ class DatabasePool:
key_values: Collection[Iterable[Any]],
value_names: Collection[str],
value_values: Iterable[Iterable[Any]],
+ lock: bool = True,
) -> None:
"""
Upsert, many times.
@@ -1299,6 +1321,8 @@ class DatabasePool:
value_names: The value column names
value_values: A list of each row's value column values.
Ignored if value_names is empty.
+ lock: True to lock the table when doing the upsert. Unused if the database engine
+ supports native upserts.
"""
if self.engine.can_native_upsert and table not in self._unsafe_to_upsert_tables:
return self.simple_upsert_many_txn_native_upsert(
@@ -1306,7 +1330,7 @@ class DatabasePool:
)
else:
return self.simple_upsert_many_txn_emulated(
- txn, table, key_names, key_values, value_names, value_values
+ txn, table, key_names, key_values, value_names, value_values, lock=lock
)
def simple_upsert_many_txn_emulated(
@@ -1317,6 +1341,7 @@ class DatabasePool:
key_values: Collection[Iterable[Any]],
value_names: Collection[str],
value_values: Iterable[Iterable[Any]],
+ lock: bool = True,
) -> None:
"""
Upsert, many times, but without native UPSERT support or batching.
@@ -1328,17 +1353,24 @@ class DatabasePool:
value_names: The value column names
value_values: A list of each row's value column values.
Ignored if value_names is empty.
+ lock: True to lock the table when doing the upsert.
"""
# No value columns, therefore make a blank list so that the following
# zip() works correctly.
if not value_names:
value_values = [() for x in range(len(key_values))]
+ if lock:
+ # Lock the table just once, to prevent it being done once per row.
+ # Note that, according to Postgres' documentation, once obtained,
+ # the lock is held for the remainder of the current transaction.
+ self.engine.lock_table(txn, "user_ips")
+
for keyv, valv in zip(key_values, value_values):
_keys = {x: y for x, y in zip(key_names, keyv)}
_vals = {x: y for x, y in zip(value_names, valv)}
- self.simple_upsert_txn_emulated(txn, table, _keys, _vals)
+ self.simple_upsert_txn_emulated(txn, table, _keys, _vals, lock=False)
def simple_upsert_many_txn_native_upsert(
self,
@@ -1775,6 +1807,86 @@ class DatabasePool:
return txn.rowcount
+ async def simple_update_many(
+ self,
+ table: str,
+ key_names: Collection[str],
+ key_values: Collection[Iterable[Any]],
+ value_names: Collection[str],
+ value_values: Iterable[Iterable[Any]],
+ desc: str,
+ ) -> None:
+ """
+ Update, many times, using batching where possible.
+ If the keys don't match anything, nothing will be updated.
+
+ Args:
+ table: The table to update
+ key_names: The key column names.
+ key_values: A list of each row's key column values.
+ value_names: The names of value columns to update.
+ value_values: A list of each row's value column values.
+ """
+
+ await self.runInteraction(
+ desc,
+ self.simple_update_many_txn,
+ table,
+ key_names,
+ key_values,
+ value_names,
+ value_values,
+ )
+
+ @staticmethod
+ def simple_update_many_txn(
+ txn: LoggingTransaction,
+ table: str,
+ key_names: Collection[str],
+ key_values: Collection[Iterable[Any]],
+ value_names: Collection[str],
+ value_values: Collection[Iterable[Any]],
+ ) -> None:
+ """
+ Update, many times, using batching where possible.
+ If the keys don't match anything, nothing will be updated.
+
+ Args:
+ table: The table to update
+ key_names: The key column names.
+ key_values: A list of each row's key column values.
+ value_names: The names of value columns to update.
+ value_values: A list of each row's value column values.
+ """
+
+ if len(value_values) != len(key_values):
+ raise ValueError(
+ f"{len(key_values)} key rows and {len(value_values)} value rows: should be the same number."
+ )
+
+ # List of tuples of (value values, then key values)
+ # (This matches the order needed for the query)
+ args = [tuple(x) + tuple(y) for x, y in zip(value_values, key_values)]
+
+ # 'col1 = ?, col2 = ?, ...'
+ set_clause = ", ".join(f"{n} = ?" for n in value_names)
+
+ if key_names:
+ # 'WHERE col3 = ? AND col4 = ? AND col5 = ?'
+ where_clause = "WHERE " + (" AND ".join(f"{n} = ?" for n in key_names))
+ else:
+ where_clause = ""
+
+ # UPDATE mytable SET col1 = ?, col2 = ? WHERE col3 = ? AND col4 = ?
+ sql = f"""
+ UPDATE {table} SET {set_clause} {where_clause}
+ """
+
+ txn.execute_batch(sql, args)
+
async def simple_update_one(
self,
table: str,
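To make the argument ordering of `simple_update_many_txn` concrete, the following standalone sketch mirrors its string building for a two-row update; the table and column names are invented and no database is touched:

table = "device_names"
key_names = ["user_id", "device_id"]
key_values = [("@a:test", "DEV1"), ("@b:test", "DEV2")]
value_names = ["display_name"]
value_values = [("Alice's phone",), ("Bob's laptop",)]

# Value columns come first in each arg tuple, then the key columns,
# matching the "SET ... WHERE ..." placeholder order.
args = [tuple(v) + tuple(k) for v, k in zip(value_values, key_values)]

set_clause = ", ".join(f"{n} = ?" for n in value_names)
where_clause = "WHERE " + " AND ".join(f"{n} = ?" for n in key_names)
sql = f"UPDATE {table} SET {set_clause} {where_clause}"

print(sql)   # UPDATE device_names SET display_name = ? WHERE user_id = ? AND device_id = ?
print(args)  # [("Alice's phone", '@a:test', 'DEV1'), ("Bob's laptop", '@b:test', 'DEV2')]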
@@ -2013,29 +2125,40 @@ class DatabasePool:
max_value: int,
limit: int = 100000,
) -> Tuple[Dict[Any, int], int]:
- # Fetch a mapping of room_id -> max stream position for "recent" rooms.
- # It doesn't really matter how many we get, the StreamChangeCache will
- # do the right thing to ensure it respects the max size of cache.
- sql = (
- "SELECT %(entity)s, MAX(%(stream)s) FROM %(table)s"
- " WHERE %(stream)s > ? - %(limit)s"
- " GROUP BY %(entity)s"
- ) % {
- "table": table,
- "entity": entity_column,
- "stream": stream_column,
- "limit": limit,
- }
+ """Gets roughly the last N changes in the given stream table as a
+ map from entity to the stream ID of the most recent change.
+
+ Also returns the minimum stream ID.
+ """
+
+ # This may return many rows for the same entity, but the `limit` is only
+ # a suggestion so we don't care that much.
+ #
+ # Note: Some stream tables can have multiple rows with the same stream
+ # ID. Instead of handling this with complicated SQL, we instead simply
+ # add one to the returned minimum stream ID to ensure correctness.
+ sql = f"""
+ SELECT {entity_column}, {stream_column}
+ FROM {table}
+ ORDER BY {stream_column} DESC
+ LIMIT ?
+ """
txn = db_conn.cursor(txn_name="get_cache_dict")
- txn.execute(sql, (int(max_value),))
+ txn.execute(sql, (limit,))
- cache = {row[0]: int(row[1]) for row in txn}
+ # The rows come out in reverse stream ID order, so we want to keep the
+ # stream ID of the first row for each entity.
+ cache: Dict[Any, int] = {}
+ for row in txn:
+ cache.setdefault(row[0], int(row[1]))
txn.close()
if cache:
- min_val = min(cache.values())
+ # We add one here as we don't know if we have all rows for the
+ # minimum stream ID.
+ min_val = min(cache.values()) + 1
else:
min_val = max_value
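The `+ 1` on the minimum stream ID is the subtle part of this rewrite: several rows can share a stream ID and the LIMIT may cut through the middle of such a group, so the lowest stream ID in the window cannot be assumed complete. A tiny standalone illustration of the new aggregation (the rows are invented):

# (entity, stream_id) rows as returned, newest first, after LIMIT 4.
rows = [("@a:test", 10), ("@b:test", 9), ("@a:test", 8), ("@c:test", 8)]
# A fifth row ("@d:test", 8) may have been cut off by the LIMIT.

cache: dict = {}
for entity, stream_id in rows:
    # Keep only the first (highest) stream ID seen for each entity.
    cache.setdefault(entity, int(stream_id))

# We may not have seen every row at stream ID 8, so completeness is only
# claimed from 9 onwards.
min_val = min(cache.values()) + 1

print(cache)    # {'@a:test': 10, '@b:test': 9, '@c:test': 8}
print(min_val)  # 9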
diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py
index f024761b..951031af 100644
--- a/synapse/storage/databases/main/__init__.py
+++ b/synapse/storage/databases/main/__init__.py
@@ -33,7 +33,7 @@ from .account_data import AccountDataStore
from .appservice import ApplicationServiceStore, ApplicationServiceTransactionStore
from .cache import CacheInvalidationWorkerStore
from .censor_events import CensorEventsStore
-from .client_ips import ClientIpStore
+from .client_ips import ClientIpWorkerStore
from .deviceinbox import DeviceInboxStore
from .devices import DeviceStore
from .directory import DirectoryStore
@@ -49,7 +49,7 @@ from .keys import KeyStore
from .lock import LockStore
from .media_repository import MediaRepositoryStore
from .metrics import ServerMetricsStore
-from .monthly_active_users import MonthlyActiveUsersStore
+from .monthly_active_users import MonthlyActiveUsersWorkerStore
from .openid import OpenIdStore
from .presence import PresenceStore
from .profile import ProfileStore
@@ -112,13 +112,13 @@ class DataStore(
AccountDataStore,
EventPushActionsStore,
OpenIdStore,
- ClientIpStore,
+ ClientIpWorkerStore,
DeviceStore,
DeviceInboxStore,
UserDirectoryStore,
GroupServerStore,
UserErasureStore,
- MonthlyActiveUsersStore,
+ MonthlyActiveUsersWorkerStore,
StatsStore,
RelationsStore,
CensorEventsStore,
@@ -146,6 +146,7 @@ class DataStore(
extra_tables=[
("user_signature_stream", "stream_id"),
("device_lists_outbound_pokes", "stream_id"),
+ ("device_lists_changes_in_room", "stream_id"),
],
)
@@ -182,17 +183,6 @@ class DataStore(
super().__init__(database, db_conn, hs)
- device_list_max = self._device_list_id_gen.get_current_token()
- self._device_list_stream_cache = StreamChangeCache(
- "DeviceListStreamChangeCache", device_list_max
- )
- self._user_signature_stream_cache = StreamChangeCache(
- "UserSignatureStreamChangeCache", device_list_max
- )
- self._device_list_federation_stream_cache = StreamChangeCache(
- "DeviceListFederationStreamChangeCache", device_list_max
- )
-
events_max = self._stream_id_gen.get_current_token()
curr_state_delta_prefill, min_curr_state_delta_id = self.db_pool.get_cache_dict(
db_conn,
diff --git a/synapse/storage/databases/main/appservice.py b/synapse/storage/databases/main/appservice.py
index 06944465..fa732edc 100644
--- a/synapse/storage/databases/main/appservice.py
+++ b/synapse/storage/databases/main/appservice.py
@@ -14,7 +14,7 @@
# limitations under the License.
import logging
import re
-from typing import TYPE_CHECKING, List, Optional, Pattern, Tuple
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Pattern, Tuple
from synapse.appservice import (
ApplicationService,
@@ -26,10 +26,16 @@ from synapse.appservice import (
from synapse.config.appservice import load_appservices
from synapse.events import EventBase
from synapse.storage._base import db_to_json
-from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
+from synapse.storage.database import (
+ DatabasePool,
+ LoggingDatabaseConnection,
+ LoggingTransaction,
+)
from synapse.storage.databases.main.events_worker import EventsWorkerStore
from synapse.storage.databases.main.roommember import RoomMemberWorkerStore
-from synapse.types import JsonDict
+from synapse.storage.types import Cursor
+from synapse.storage.util.sequence import build_sequence_generator
+from synapse.types import DeviceListUpdates, JsonDict
from synapse.util import json_encoder
from synapse.util.caches.descriptors import _CacheContext, cached
@@ -72,9 +78,25 @@ class ApplicationServiceWorkerStore(RoomMemberWorkerStore):
)
self.exclusive_user_regex = _make_exclusive_regex(self.services_cache)
+ def get_max_as_txn_id(txn: Cursor) -> int:
+ logger.warning("Falling back to slow query, you should port to postgres")
+ txn.execute(
+ "SELECT COALESCE(max(txn_id), 0) FROM application_services_txns"
+ )
+ return txn.fetchone()[0] # type: ignore
+
+ self._as_txn_seq_gen = build_sequence_generator(
+ db_conn,
+ database.engine,
+ get_max_as_txn_id,
+ "application_services_txn_id_seq",
+ table="application_services_txns",
+ id_column="txn_id",
+ )
+
super().__init__(database, db_conn, hs)
- def get_app_services(self):
+ def get_app_services(self) -> List[ApplicationService]:
return self.services_cache
def get_if_app_services_interested_in_user(self, user_id: str) -> bool:
@@ -217,6 +239,7 @@ class ApplicationServiceTransactionWorkerStore(
to_device_messages: List[JsonDict],
one_time_key_counts: TransactionOneTimeKeyCounts,
unused_fallback_keys: TransactionUnusedFallbackKeys,
+ device_list_summary: DeviceListUpdates,
) -> AppServiceTransaction:
"""Atomically creates a new transaction for this application service
with the given list of events. Ephemeral events are NOT persisted to the
@@ -231,27 +254,14 @@ class ApplicationServiceTransactionWorkerStore(
appservice devices in the transaction.
unused_fallback_keys: Lists of unused fallback keys for relevant
appservice devices in the transaction.
+ device_list_summary: The device list summary to include in the transaction.
Returns:
A new transaction.
"""
- def _create_appservice_txn(txn):
- # work out new txn id (highest txn id for this service += 1)
- # The highest id may be the last one sent (in which case it is last_txn)
- # or it may be the highest in the txns list (which are waiting to be/are
- # being sent)
- last_txn_id = self._get_last_txn(txn, service.id)
-
- txn.execute(
- "SELECT MAX(txn_id) FROM application_services_txns WHERE as_id=?",
- (service.id,),
- )
- highest_txn_id = txn.fetchone()[0]
- if highest_txn_id is None:
- highest_txn_id = 0
-
- new_txn_id = max(highest_txn_id, last_txn_id) + 1
+ def _create_appservice_txn(txn: LoggingTransaction) -> AppServiceTransaction:
+ new_txn_id = self._as_txn_seq_gen.get_next_id_txn(txn)
# Insert new txn into txn table
event_ids = json_encoder.encode([e.event_id for e in events])
@@ -268,6 +278,7 @@ class ApplicationServiceTransactionWorkerStore(
to_device_messages=to_device_messages,
one_time_key_counts=one_time_key_counts,
unused_fallback_keys=unused_fallback_keys,
+ device_list_summary=device_list_summary,
)
return await self.db_pool.runInteraction(
@@ -283,25 +294,8 @@ class ApplicationServiceTransactionWorkerStore(
txn_id: The transaction ID being completed.
service: The application service which was sent this transaction.
"""
- txn_id = int(txn_id)
-
- def _complete_appservice_txn(txn):
- # Debugging query: Make sure the txn being completed is EXACTLY +1 from
- # what was there before. If it isn't, we've got problems (e.g. the AS
- # has probably missed some events), so whine loudly but still continue,
- # since it shouldn't fail completion of the transaction.
- last_txn_id = self._get_last_txn(txn, service.id)
- if (last_txn_id + 1) != txn_id:
- logger.error(
- "appservice: Completing a transaction which has an ID > 1 from "
- "the last ID sent to this AS. We've either dropped events or "
- "sent it to the AS out of order. FIX ME. last_txn=%s "
- "completing_txn=%s service_id=%s",
- last_txn_id,
- txn_id,
- service.id,
- )
+ def _complete_appservice_txn(txn: LoggingTransaction) -> None:
# Set current txn_id for AS to 'txn_id'
self.db_pool.simple_upsert_txn(
txn,
@@ -332,7 +326,9 @@ class ApplicationServiceTransactionWorkerStore(
An AppServiceTransaction or None.
"""
- def _get_oldest_unsent_txn(txn):
+ def _get_oldest_unsent_txn(
+ txn: LoggingTransaction,
+ ) -> Optional[Dict[str, Any]]:
# Monotonically increasing txn ids, so just select the smallest
# one in the txns table (we delete them when they are sent)
txn.execute(
@@ -359,8 +355,8 @@ class ApplicationServiceTransactionWorkerStore(
events = await self.get_events_as_list(event_ids)
- # TODO: to-device messages, one-time key counts and unused fallback keys
- # are not yet populated for catch-up transactions.
+ # TODO: to-device messages, one-time key counts, device list summaries and unused
+ # fallback keys are not yet populated for catch-up transactions.
# We likely want to populate those for reliability.
return AppServiceTransaction(
service=service,
@@ -370,21 +366,11 @@ class ApplicationServiceTransactionWorkerStore(
to_device_messages=[],
one_time_key_counts={},
unused_fallback_keys={},
+ device_list_summary=DeviceListUpdates(),
)
- def _get_last_txn(self, txn, service_id: Optional[str]) -> int:
- txn.execute(
- "SELECT last_txn FROM application_services_state WHERE as_id=?",
- (service_id,),
- )
- last_txn_id = txn.fetchone()
- if last_txn_id is None or last_txn_id[0] is None: # no row exists
- return 0
- else:
- return int(last_txn_id[0]) # select 'last_txn' col
-
async def set_appservice_last_pos(self, pos: int) -> None:
- def set_appservice_last_pos_txn(txn):
+ def set_appservice_last_pos_txn(txn: LoggingTransaction) -> None:
txn.execute(
"UPDATE appservice_stream_position SET stream_ordering = ?", (pos,)
)
@@ -398,7 +384,9 @@ class ApplicationServiceTransactionWorkerStore(
) -> Tuple[int, List[EventBase]]:
"""Get all new events for an appservice"""
- def get_new_events_for_appservice_txn(txn):
+ def get_new_events_for_appservice_txn(
+ txn: LoggingTransaction,
+ ) -> Tuple[int, List[str]]:
sql = (
"SELECT e.stream_ordering, e.event_id"
" FROM events AS e"
@@ -430,13 +418,13 @@ class ApplicationServiceTransactionWorkerStore(
async def get_type_stream_id_for_appservice(
self, service: ApplicationService, type: str
) -> int:
- if type not in ("read_receipt", "presence", "to_device"):
+ if type not in ("read_receipt", "presence", "to_device", "device_list"):
raise ValueError(
"Expected type to be a valid application stream id type, got %s"
% (type,)
)
- def get_type_stream_id_for_appservice_txn(txn):
+ def get_type_stream_id_for_appservice_txn(txn: LoggingTransaction) -> int:
stream_id_type = "%s_stream_id" % type
txn.execute(
# We do NOT want to escape `stream_id_type`.
@@ -446,7 +434,8 @@ class ApplicationServiceTransactionWorkerStore(
)
last_stream_id = txn.fetchone()
if last_stream_id is None or last_stream_id[0] is None: # no row exists
- return 0
+                # Stream tokens always start from 1, to avoid footguns around `0` being falsy.
+ return 1
else:
return int(last_stream_id[0])
@@ -457,13 +446,13 @@ class ApplicationServiceTransactionWorkerStore(
async def set_appservice_stream_type_pos(
self, service: ApplicationService, stream_type: str, pos: Optional[int]
) -> None:
- if stream_type not in ("read_receipt", "presence", "to_device"):
+ if stream_type not in ("read_receipt", "presence", "to_device", "device_list"):
raise ValueError(
"Expected type to be a valid application stream id type, got %s"
% (stream_type,)
)
- def set_appservice_stream_type_pos_txn(txn):
+ def set_appservice_stream_type_pos_txn(txn: LoggingTransaction) -> None:
stream_id_type = "%s_stream_id" % stream_type
txn.execute(
"UPDATE application_services_state SET %s = ? WHERE as_id=?"
diff --git a/synapse/storage/databases/main/client_ips.py b/synapse/storage/databases/main/client_ips.py
index 8b0c614e..0df160d2 100644
--- a/synapse/storage/databases/main/client_ips.py
+++ b/synapse/storage/databases/main/client_ips.py
@@ -25,7 +25,9 @@ from synapse.storage.database import (
LoggingTransaction,
make_tuple_comparison_clause,
)
-from synapse.storage.databases.main.monthly_active_users import MonthlyActiveUsersStore
+from synapse.storage.databases.main.monthly_active_users import (
+ MonthlyActiveUsersWorkerStore,
+)
from synapse.types import JsonDict, UserID
from synapse.util.caches.lrucache import LruCache
@@ -397,7 +399,7 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore):
return updated
-class ClientIpWorkerStore(ClientIpBackgroundUpdateStore):
+class ClientIpWorkerStore(ClientIpBackgroundUpdateStore, MonthlyActiveUsersWorkerStore):
def __init__(
self,
database: DatabasePool,
@@ -406,11 +408,40 @@ class ClientIpWorkerStore(ClientIpBackgroundUpdateStore):
):
super().__init__(database, db_conn, hs)
+ if hs.config.redis.redis_enabled:
+ # If we're using Redis, we can shift this update process off to
+ # the background worker
+ self._update_on_this_worker = hs.config.worker.run_background_tasks
+ else:
+ # If we're NOT using Redis, this must be handled by the master
+ self._update_on_this_worker = hs.get_instance_name() == "master"
+
self.user_ips_max_age = hs.config.server.user_ips_max_age
+ # (user_id, access_token, ip,) -> last_seen
+ self.client_ip_last_seen = LruCache[Tuple[str, str, str], int](
+ cache_name="client_ip_last_seen", max_size=50000
+ )
+
if hs.config.worker.run_background_tasks and self.user_ips_max_age:
self._clock.looping_call(self._prune_old_user_ips, 5 * 1000)
+ if self._update_on_this_worker:
+ # This is the designated worker that can write to the client IP
+ # tables.
+
+ # (user_id, access_token, ip,) -> (user_agent, device_id, last_seen)
+ self._batch_row_update: Dict[
+ Tuple[str, str, str], Tuple[str, Optional[str], int]
+ ] = {}
+
+ self._client_ip_looper = self._clock.looping_call(
+ self._update_client_ips_batch, 5 * 1000
+ )
+ self.hs.get_reactor().addSystemEventTrigger(
+ "before", "shutdown", self._update_client_ips_batch
+ )
+
@wrap_as_background_process("prune_old_user_ips")
async def _prune_old_user_ips(self) -> None:
"""Removes entries in user IPs older than the configured period."""
@@ -456,7 +487,7 @@ class ClientIpWorkerStore(ClientIpBackgroundUpdateStore):
"_prune_old_user_ips", _prune_old_user_ips_txn
)
- async def get_last_client_ip_by_device(
+ async def _get_last_client_ip_by_device_from_database(
self, user_id: str, device_id: Optional[str]
) -> Dict[Tuple[str, str], DeviceLastConnectionInfo]:
"""For each device_id listed, give the user_ip it was last seen on.
@@ -487,7 +518,7 @@ class ClientIpWorkerStore(ClientIpBackgroundUpdateStore):
return {(d["user_id"], d["device_id"]): d for d in res}
- async def get_user_ip_and_agents(
+ async def _get_user_ip_and_agents_from_database(
self, user: UserID, since_ts: int = 0
) -> List[LastConnectionInfo]:
"""Fetch the IPs and user agents for a user since the given timestamp.
@@ -539,34 +570,6 @@ class ClientIpWorkerStore(ClientIpBackgroundUpdateStore):
for access_token, ip, user_agent, last_seen in rows
]
-
-class ClientIpStore(ClientIpWorkerStore, MonthlyActiveUsersStore):
- def __init__(
- self,
- database: DatabasePool,
- db_conn: LoggingDatabaseConnection,
- hs: "HomeServer",
- ):
-
- # (user_id, access_token, ip,) -> last_seen
- self.client_ip_last_seen = LruCache[Tuple[str, str, str], int](
- cache_name="client_ip_last_seen", max_size=50000
- )
-
- super().__init__(database, db_conn, hs)
-
- # (user_id, access_token, ip,) -> (user_agent, device_id, last_seen)
- self._batch_row_update: Dict[
- Tuple[str, str, str], Tuple[str, Optional[str], int]
- ] = {}
-
- self._client_ip_looper = self._clock.looping_call(
- self._update_client_ips_batch, 5 * 1000
- )
- self.hs.get_reactor().addSystemEventTrigger(
- "before", "shutdown", self._update_client_ips_batch
- )
-
async def insert_client_ip(
self,
user_id: str,
@@ -584,17 +587,27 @@ class ClientIpStore(ClientIpWorkerStore, MonthlyActiveUsersStore):
last_seen = self.client_ip_last_seen.get(key)
except KeyError:
last_seen = None
- await self.populate_monthly_active_users(user_id)
+
# Rate-limited inserts
if last_seen is not None and (now - last_seen) < LAST_SEEN_GRANULARITY:
return
self.client_ip_last_seen.set(key, now)
- self._batch_row_update[key] = (user_agent, device_id, now)
+ if self._update_on_this_worker:
+ await self.populate_monthly_active_users(user_id)
+ self._batch_row_update[key] = (user_agent, device_id, now)
+ else:
+ # We are not the designated writer-worker, so stream over replication
+ self.hs.get_replication_command_handler().send_user_ip(
+ user_id, access_token, ip, user_agent, device_id, now
+ )
@wrap_as_background_process("update_client_ips")
async def _update_client_ips_batch(self) -> None:
+ assert (
+ self._update_on_this_worker
+ ), "This worker is not designated to update client IPs"
# If the DB pool has already terminated, don't try updating
if not self.db_pool.is_running():
@@ -603,51 +616,57 @@ class ClientIpStore(ClientIpWorkerStore, MonthlyActiveUsersStore):
to_update = self._batch_row_update
self._batch_row_update = {}
- await self.db_pool.runInteraction(
- "_update_client_ips_batch", self._update_client_ips_batch_txn, to_update
- )
+ if to_update:
+ await self.db_pool.runInteraction(
+ "_update_client_ips_batch", self._update_client_ips_batch_txn, to_update
+ )
def _update_client_ips_batch_txn(
self,
txn: LoggingTransaction,
to_update: Mapping[Tuple[str, str, str], Tuple[str, Optional[str], int]],
) -> None:
- if "user_ips" in self.db_pool._unsafe_to_upsert_tables or (
- not self.database_engine.can_native_upsert
- ):
- self.database_engine.lock_table(txn, "user_ips")
+ assert (
+ self._update_on_this_worker
+ ), "This worker is not designated to update client IPs"
+
+ # Keys and values for the `user_ips` upsert.
+ user_ips_keys = []
+ user_ips_values = []
+
+ # Keys and values for the `devices` update.
+ devices_keys = []
+ devices_values = []
for entry in to_update.items():
(user_id, access_token, ip), (user_agent, device_id, last_seen) = entry
-
- self.db_pool.simple_upsert_txn(
- txn,
- table="user_ips",
- keyvalues={"user_id": user_id, "access_token": access_token, "ip": ip},
- values={
- "user_agent": user_agent,
- "device_id": device_id,
- "last_seen": last_seen,
- },
- lock=False,
- )
+ user_ips_keys.append((user_id, access_token, ip))
+ user_ips_values.append((user_agent, device_id, last_seen))
# Technically an access token might not be associated with
# a device so we need to check.
if device_id:
- # this is always an update rather than an upsert: the row should
- # already exist, and if it doesn't, that may be because it has been
- # deleted, and we don't want to re-create it.
- self.db_pool.simple_update_txn(
- txn,
- table="devices",
- keyvalues={"user_id": user_id, "device_id": device_id},
- updatevalues={
- "user_agent": user_agent,
- "last_seen": last_seen,
- "ip": ip,
- },
- )
+ devices_keys.append((user_id, device_id))
+ devices_values.append((user_agent, last_seen, ip))
+
+ self.db_pool.simple_upsert_many_txn(
+ txn,
+ table="user_ips",
+ key_names=("user_id", "access_token", "ip"),
+ key_values=user_ips_keys,
+ value_names=("user_agent", "device_id", "last_seen"),
+ value_values=user_ips_values,
+ )
+
+ if devices_values:
+ self.db_pool.simple_update_many_txn(
+ txn,
+ table="devices",
+ key_names=("user_id", "device_id"),
+ key_values=devices_keys,
+ value_names=("user_agent", "last_seen", "ip"),
+ value_values=devices_values,
+ )
async def get_last_client_ip_by_device(
self, user_id: str, device_id: Optional[str]
@@ -662,7 +681,12 @@ class ClientIpStore(ClientIpWorkerStore, MonthlyActiveUsersStore):
A dictionary mapping a tuple of (user_id, device_id) to dicts, with
keys giving the column names from the devices table.
"""
- ret = await super().get_last_client_ip_by_device(user_id, device_id)
+ ret = await self._get_last_client_ip_by_device_from_database(user_id, device_id)
+
+ if not self._update_on_this_worker:
+ # Only the writing-worker has additional in-memory data to enhance
+ # the result
+ return ret
# Update what is retrieved from the database with data which is pending
# insertion, as if it has already been stored in the database.
@@ -707,9 +731,16 @@ class ClientIpStore(ClientIpWorkerStore, MonthlyActiveUsersStore):
Only the latest user agent for each access token and IP address combination
is available.
"""
+ rows_from_db = await self._get_user_ip_and_agents_from_database(user, since_ts)
+
+ if not self._update_on_this_worker:
+ # Only the writing-worker has additional in-memory data to enhance
+ # the result
+ return rows_from_db
+
results: Dict[Tuple[str, str], LastConnectionInfo] = {
(connection["access_token"], connection["ip"]): connection
- for connection in await super().get_user_ip_and_agents(user, since_ts)
+ for connection in rows_from_db
}
# Overlay data that is pending insertion on top of the results from the
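Editor's note: the client_ips.py changes above move ClientIpWorkerStore to a designated-writer pattern: one worker buffers client IP rows in memory and flushes them with a batched upsert, while every other worker forwards the row over replication instead of writing it. A minimal, illustrative sketch with stub methods standing in for the real replication and database plumbing (all names invented).

import time
from typing import Dict, Optional, Tuple

class ClientIpBatcher:
    def __init__(self, is_designated_writer: bool) -> None:
        self.is_designated_writer = is_designated_writer
        # (user_id, access_token, ip) -> (user_agent, device_id, last_seen)
        self._pending: Dict[Tuple[str, str, str], Tuple[str, Optional[str], int]] = {}

    def record(
        self,
        user_id: str,
        access_token: str,
        ip: str,
        user_agent: str,
        device_id: Optional[str],
    ) -> None:
        now = int(time.time() * 1000)
        if self.is_designated_writer:
            # Collapse repeated hits for the same key; the periodic flush
            # writes at most one row per (user, token, ip).
            self._pending[(user_id, access_token, ip)] = (user_agent, device_id, now)
        else:
            self._send_over_replication(
                user_id, access_token, ip, user_agent, device_id, now
            )

    def flush(self) -> None:
        # Swap the buffer out first so new hits land in a fresh dict.
        to_update, self._pending = self._pending, {}
        if to_update:
            self._upsert_many("user_ips", to_update)

    # Stubs standing in for the real replication command and batched upsert.
    def _send_over_replication(self, *args: object) -> None: ...
    def _upsert_many(self, table: str, rows: Dict) -> None: ...

batcher = ClientIpBatcher(is_designated_writer=True)
batcher.record("@alice:example.org", "token", "10.0.0.1", "UA", "DEV1")
batcher.flush()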
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index 3b3a089b..dc8009b2 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -46,6 +46,7 @@ from synapse.types import JsonDict, get_verify_key_from_cross_signing_key
from synapse.util import json_decoder, json_encoder
from synapse.util.caches.descriptors import cached, cachedList
from synapse.util.caches.lrucache import LruCache
+from synapse.util.caches.stream_change_cache import StreamChangeCache
from synapse.util.iterutils import batch_iter
from synapse.util.stringutils import shortstr
@@ -71,6 +72,55 @@ class DeviceWorkerStore(SQLBaseStore):
):
super().__init__(database, db_conn, hs)
+ device_list_max = self._device_list_id_gen.get_current_token()
+ device_list_prefill, min_device_list_id = self.db_pool.get_cache_dict(
+ db_conn,
+ "device_lists_stream",
+ entity_column="user_id",
+ stream_column="stream_id",
+ max_value=device_list_max,
+ limit=10000,
+ )
+ self._device_list_stream_cache = StreamChangeCache(
+ "DeviceListStreamChangeCache",
+ min_device_list_id,
+ prefilled_cache=device_list_prefill,
+ )
+
+ (
+ user_signature_stream_prefill,
+ user_signature_stream_list_id,
+ ) = self.db_pool.get_cache_dict(
+ db_conn,
+ "user_signature_stream",
+ entity_column="from_user_id",
+ stream_column="stream_id",
+ max_value=device_list_max,
+ limit=1000,
+ )
+ self._user_signature_stream_cache = StreamChangeCache(
+ "UserSignatureStreamChangeCache",
+ user_signature_stream_list_id,
+ prefilled_cache=user_signature_stream_prefill,
+ )
+
+ (
+ device_list_federation_prefill,
+ device_list_federation_list_id,
+ ) = self.db_pool.get_cache_dict(
+ db_conn,
+ "device_lists_outbound_pokes",
+ entity_column="destination",
+ stream_column="stream_id",
+ max_value=device_list_max,
+ limit=10000,
+ )
+ self._device_list_federation_stream_cache = StreamChangeCache(
+ "DeviceListFederationStreamChangeCache",
+ device_list_federation_list_id,
+ prefilled_cache=device_list_federation_prefill,
+ )
+
if hs.config.worker.run_background_tasks:
self._clock.looping_call(
self._prune_old_outbound_device_pokes, 60 * 60 * 1000
@@ -681,42 +731,64 @@ class DeviceWorkerStore(SQLBaseStore):
return self._device_list_stream_cache.get_all_entities_changed(from_key)
async def get_users_whose_devices_changed(
- self, from_key: int, user_ids: Iterable[str]
+ self,
+ from_key: int,
+ user_ids: Optional[Iterable[str]] = None,
+ to_key: Optional[int] = None,
) -> Set[str]:
"""Get set of users whose devices have changed since `from_key` that
are in the given list of user_ids.
Args:
- from_key: The device lists stream token
- user_ids: The user IDs to query for devices.
+ from_key: The minimum device lists stream token to query device list changes for,
+ exclusive.
+ user_ids: If provided, only check if these users have changed their device lists.
+ Otherwise changes from all users are returned.
+ to_key: The maximum device lists stream token to query device list changes for,
+ inclusive.
Returns:
- The set of user_ids whose devices have changed since `from_key`
+ The set of user_ids whose devices have changed since `from_key` (exclusive)
+ until `to_key` (inclusive).
"""
-
# Get set of users who *may* have changed. Users not in the returned
# list have definitely not changed.
- to_check = self._device_list_stream_cache.get_entities_changed(
- user_ids, from_key
- )
+ if user_ids is None:
+ # Get set of all users that have had device list changes since 'from_key'
+ user_ids_to_check = self._device_list_stream_cache.get_all_entities_changed(
+ from_key
+ )
+ else:
+ # The same as above, but filter results to only those users in 'user_ids'
+ user_ids_to_check = self._device_list_stream_cache.get_entities_changed(
+ user_ids, from_key
+ )
- if not to_check:
+ if not user_ids_to_check:
return set()
def _get_users_whose_devices_changed_txn(txn):
changes = set()
- sql = """
+ stream_id_where_clause = "stream_id > ?"
+ sql_args = [from_key]
+
+ if to_key:
+ stream_id_where_clause += " AND stream_id <= ?"
+ sql_args.append(to_key)
+
+ sql = f"""
SELECT DISTINCT user_id FROM device_lists_stream
- WHERE stream_id > ?
+ WHERE {stream_id_where_clause}
AND
"""
- for chunk in batch_iter(to_check, 100):
+ # Query device changes with a batch of users at a time
+ for chunk in batch_iter(user_ids_to_check, 100):
clause, args = make_in_list_sql_clause(
txn.database_engine, "user_id", chunk
)
- txn.execute(sql + clause, (from_key,) + tuple(args))
+ txn.execute(sql + clause, sql_args + args)
changes.update(user_id for user_id, in txn)
return changes
@@ -788,6 +860,7 @@ class DeviceWorkerStore(SQLBaseStore):
SELECT stream_id, destination AS entity FROM device_lists_outbound_pokes
) AS e
WHERE ? < stream_id AND stream_id <= ?
+ ORDER BY stream_id ASC
LIMIT ?
"""
@@ -1506,7 +1579,11 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
)
async def add_device_change_to_streams(
- self, user_id: str, device_ids: Collection[str], hosts: Collection[str]
+ self,
+ user_id: str,
+ device_ids: Collection[str],
+ hosts: Optional[Collection[str]],
+ room_ids: Collection[str],
) -> Optional[int]:
"""Persist that a user's devices have been updated, and which hosts
(if any) should be poked.
@@ -1515,7 +1592,10 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
user_id: The ID of the user whose device changed.
device_ids: The IDs of any changed devices. If empty, this function will
return None.
- hosts: The remote destinations that should be notified of the change.
+ hosts: The remote destinations that should be notified of the change. If
+                None then the set of hosts has *not* been calculated, and will be
+ calculated later by a background task.
+ room_ids: The rooms that the user is in
Returns:
The maximum stream ID of device list updates that were added to the database, or
@@ -1524,34 +1604,62 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
if not device_ids:
return None
- async with self._device_list_id_gen.get_next_mult(
- len(device_ids)
- ) as stream_ids:
- await self.db_pool.runInteraction(
- "add_device_change_to_stream",
- self._add_device_change_to_stream_txn,
+ context = get_active_span_text_map()
+
+ def add_device_changes_txn(
+ txn, stream_ids_for_device_change, stream_ids_for_outbound_pokes
+ ):
+ self._add_device_change_to_stream_txn(
+ txn,
user_id,
device_ids,
- stream_ids,
+ stream_ids_for_device_change,
)
- if not hosts:
- return stream_ids[-1]
+ self._add_device_outbound_room_poke_txn(
+ txn,
+ user_id,
+ device_ids,
+ room_ids,
+ stream_ids_for_device_change,
+ context,
+ hosts_have_been_calculated=hosts is not None,
+ )
- context = get_active_span_text_map()
- async with self._device_list_id_gen.get_next_mult(
- len(hosts) * len(device_ids)
- ) as stream_ids:
- await self.db_pool.runInteraction(
- "add_device_outbound_poke_to_stream",
- self._add_device_outbound_poke_to_stream_txn,
+ # If the set of hosts to send to has not been calculated yet (and so
+ # `hosts` is None) or there are no `hosts` to send to, then skip
+ # trying to persist them to the DB.
+ if not hosts:
+ return
+
+ self._add_device_outbound_poke_to_stream_txn(
+ txn,
user_id,
device_ids,
hosts,
- stream_ids,
+ stream_ids_for_outbound_pokes,
context,
)
+ # `device_lists_stream` wants a stream ID per device update.
+ num_stream_ids = len(device_ids)
+
+ if hosts:
+ # `device_lists_outbound_pokes` wants a different stream ID for
+ # each row, which is a row per host per device update.
+ num_stream_ids += len(hosts) * len(device_ids)
+
+ async with self._device_list_id_gen.get_next_mult(num_stream_ids) as stream_ids:
+ stream_ids_for_device_change = stream_ids[: len(device_ids)]
+ stream_ids_for_outbound_pokes = stream_ids[len(device_ids) :]
+
+ await self.db_pool.runInteraction(
+ "add_device_change_to_stream",
+ add_device_changes_txn,
+ stream_ids_for_device_change,
+ stream_ids_for_outbound_pokes,
+ )
+
return stream_ids[-1]
def _add_device_change_to_stream_txn(
@@ -1595,7 +1703,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
user_id: str,
device_ids: Iterable[str],
hosts: Collection[str],
- stream_ids: List[str],
+ stream_ids: List[int],
context: Dict[str, str],
) -> None:
for host in hosts:
@@ -1606,8 +1714,9 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
)
now = self._clock.time_msec()
- next_stream_id = iter(stream_ids)
+ stream_id_iterator = iter(stream_ids)
+ encoded_context = json_encoder.encode(context)
self.db_pool.simple_insert_many_txn(
txn,
table="device_lists_outbound_pokes",
@@ -1623,16 +1732,146 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
values=[
(
destination,
- next(next_stream_id),
+ next(stream_id_iterator),
user_id,
device_id,
False,
now,
- json_encoder.encode(context)
- if whitelisted_homeserver(destination)
- else "{}",
+ encoded_context if whitelisted_homeserver(destination) else "{}",
)
for destination in hosts
for device_id in device_ids
],
)
+
+ def _add_device_outbound_room_poke_txn(
+ self,
+ txn: LoggingTransaction,
+ user_id: str,
+ device_ids: Iterable[str],
+ room_ids: Collection[str],
+ stream_ids: List[str],
+ context: Dict[str, str],
+ hosts_have_been_calculated: bool,
+ ) -> None:
+        """Record that the user in the room has updated their device.
+
+ Args:
+ hosts_have_been_calculated: True if `device_lists_outbound_pokes`
+ has been updated already with the updates.
+ """
+
+        # We only need to convert to outbound pokes if the user is ours.
+ converted_to_destinations = (
+ hosts_have_been_calculated or not self.hs.is_mine_id(user_id)
+ )
+
+ encoded_context = json_encoder.encode(context)
+
+ # The `device_lists_changes_in_room.stream_id` column matches the
+ # corresponding `stream_id` of the update in the `device_lists_stream`
+ # table, i.e. all rows persisted for the same device update will have
+ # the same `stream_id` (but different room IDs).
+ self.db_pool.simple_insert_many_txn(
+ txn,
+ table="device_lists_changes_in_room",
+ keys=(
+ "user_id",
+ "device_id",
+ "room_id",
+ "stream_id",
+ "converted_to_destinations",
+ "opentracing_context",
+ ),
+ values=[
+ (
+ user_id,
+ device_id,
+ room_id,
+ stream_id,
+ converted_to_destinations,
+ encoded_context,
+ )
+ for room_id in room_ids
+ for device_id, stream_id in zip(device_ids, stream_ids)
+ ],
+ )
+
+ async def get_uncoverted_outbound_room_pokes(
+ self, limit: int = 10
+ ) -> List[Tuple[str, str, str, int, Optional[Dict[str, str]]]]:
+ """Get device list changes by room that have not yet been handled and
+ written to `device_lists_outbound_pokes`.
+
+ Returns:
+ A list of user ID, device ID, room ID, stream ID and optional opentracing context.
+ """
+
+ sql = """
+ SELECT user_id, device_id, room_id, stream_id, opentracing_context
+ FROM device_lists_changes_in_room
+ WHERE NOT converted_to_destinations
+ ORDER BY stream_id
+ LIMIT ?
+ """
+
+ def get_uncoverted_outbound_room_pokes_txn(txn):
+ txn.execute(sql, (limit,))
+ return txn.fetchall()
+
+ return await self.db_pool.runInteraction(
+ "get_uncoverted_outbound_room_pokes", get_uncoverted_outbound_room_pokes_txn
+ )
+
+ async def add_device_list_outbound_pokes(
+ self,
+ user_id: str,
+ device_id: str,
+ room_id: str,
+ stream_id: int,
+ hosts: Collection[str],
+ context: Optional[Dict[str, str]],
+ ) -> None:
+ """Queue the device update to be sent to the given set of hosts,
+ calculated from the room ID.
+
+ Marks the associated row in `device_lists_changes_in_room` as handled.
+ """
+
+ def add_device_list_outbound_pokes_txn(txn, stream_ids: List[int]):
+ if hosts:
+ self._add_device_outbound_poke_to_stream_txn(
+ txn,
+ user_id=user_id,
+ device_ids=[device_id],
+ hosts=hosts,
+ stream_ids=stream_ids,
+ context=context,
+ )
+
+ self.db_pool.simple_update_txn(
+ txn,
+ table="device_lists_changes_in_room",
+ keyvalues={
+ "user_id": user_id,
+ "device_id": device_id,
+ "stream_id": stream_id,
+ "room_id": room_id,
+ },
+ updatevalues={"converted_to_destinations": True},
+ )
+
+ if not hosts:
+ # If there are no hosts then we don't try and generate stream IDs.
+ return await self.db_pool.runInteraction(
+ "add_device_list_outbound_pokes",
+ add_device_list_outbound_pokes_txn,
+ [],
+ )
+
+ async with self._device_list_id_gen.get_next_mult(len(hosts)) as stream_ids:
+ return await self.db_pool.runInteraction(
+ "add_device_list_outbound_pokes",
+ add_device_list_outbound_pokes_txn,
+ stream_ids,
+ )
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index d2532431..3fcd5f5b 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -197,12 +197,10 @@ class PersistEventsStore:
)
persist_event_counter.inc(len(events_and_contexts))
- if stream < 0:
- # backfilled events have negative stream orderings, so we don't
- # want to set the event_persisted_position to that.
- synapse.metrics.event_persisted_position.set(
- events_and_contexts[-1][0].internal_metadata.stream_ordering
- )
+ if not use_negative_stream_ordering:
+ # we don't want to set the event_persisted_position to a negative
+ # stream_ordering.
+ synapse.metrics.event_persisted_position.set(stream)
for event, context in events_and_contexts:
if context.app_service:
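Editor's note: the events.py hunk above only advances the persisted-position metric for forward persists, since backfilled events are stored with negative stream orderings. A tiny sketch of that guard with a stand-in gauge class (names invented).

class Gauge:
    def __init__(self) -> None:
        self.value = 0

    def set(self, v: int) -> None:
        self.value = v

event_persisted_position = Gauge()

def record_persist(stream: int, use_negative_stream_ordering: bool) -> None:
    if not use_negative_stream_ordering:
        # Backfilled batches use negative stream orderings; don't let them
        # drag the gauge backwards.
        event_persisted_position.set(stream)

record_persist(42, use_negative_stream_ordering=False)
record_persist(-7, use_negative_stream_ordering=True)  # gauge stays at 42
assert event_persisted_position.value == 42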
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index 59454a47..a60e3f4f 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -22,7 +22,6 @@ from typing import (
Dict,
Iterable,
List,
- NoReturn,
Optional,
Set,
Tuple,
@@ -1330,10 +1329,9 @@ class EventsWorkerStore(SQLBaseStore):
return results
@cached(max_entries=100000, tree=True)
- async def have_seen_event(self, room_id: str, event_id: str) -> NoReturn:
- # this only exists for the benefit of the @cachedList descriptor on
- # _have_seen_events_dict
- raise NotImplementedError()
+ async def have_seen_event(self, room_id: str, event_id: str) -> bool:
+ res = await self._have_seen_events_dict(((room_id, event_id),))
+ return res[(room_id, event_id)]
def _get_current_state_event_counts_txn(
self, txn: LoggingTransaction, room_id: str
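Editor's note: the events_worker.py change makes have_seen_event a thin wrapper around the batched lookup, returning a real bool instead of raising. A minimal sketch of that delegation shape, ignoring the caching decorators and using invented helper names.

from typing import Dict, Iterable, Tuple

_SEEN = {("!room:example.org", "$event1")}

def have_seen_events_dict(
    keys: Iterable[Tuple[str, str]]
) -> Dict[Tuple[str, str], bool]:
    # Stand-in for the batched, cached database lookup.
    return {key: key in _SEEN for key in keys}

def have_seen_event(room_id: str, event_id: str) -> bool:
    res = have_seen_events_dict(((room_id, event_id),))
    return res[(room_id, event_id)]

assert have_seen_event("!room:example.org", "$event1") is True
assert have_seen_event("!room:example.org", "$event2") is False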
diff --git a/synapse/storage/databases/main/monthly_active_users.py b/synapse/storage/databases/main/monthly_active_users.py
index 21662296..4f1c22c7 100644
--- a/synapse/storage/databases/main/monthly_active_users.py
+++ b/synapse/storage/databases/main/monthly_active_users.py
@@ -15,7 +15,6 @@ import logging
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, cast
from synapse.metrics.background_process_metrics import wrap_as_background_process
-from synapse.storage._base import SQLBaseStore
from synapse.storage.database import (
DatabasePool,
LoggingDatabaseConnection,
@@ -36,7 +35,7 @@ logger = logging.getLogger(__name__)
LAST_SEEN_GRANULARITY = 60 * 60 * 1000
-class MonthlyActiveUsersWorkerStore(SQLBaseStore):
+class MonthlyActiveUsersWorkerStore(RegistrationWorkerStore):
def __init__(
self,
database: DatabasePool,
@@ -47,9 +46,30 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore):
self._clock = hs.get_clock()
self.hs = hs
+ if hs.config.redis.redis_enabled:
+ # If we're using Redis, we can shift this update process off to
+ # the background worker
+ self._update_on_this_worker = hs.config.worker.run_background_tasks
+ else:
+ # If we're NOT using Redis, this must be handled by the master
+ self._update_on_this_worker = hs.get_instance_name() == "master"
+
self._limit_usage_by_mau = hs.config.server.limit_usage_by_mau
self._max_mau_value = hs.config.server.max_mau_value
+ self._mau_stats_only = hs.config.server.mau_stats_only
+
+ if self._update_on_this_worker:
+ # Do not add more reserved users than the total allowable number
+ self.db_pool.new_transaction(
+ db_conn,
+ "initialise_mau_threepids",
+ [],
+ [],
+ self._initialise_reserved_users,
+ hs.config.server.mau_limits_reserved_threepids[: self._max_mau_value],
+ )
+
@cached(num_args=0)
async def get_monthly_active_count(self) -> int:
"""Generates current count of monthly active users
@@ -222,28 +242,6 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore):
"reap_monthly_active_users", _reap_users, reserved_users
)
-
-class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore, RegistrationWorkerStore):
- def __init__(
- self,
- database: DatabasePool,
- db_conn: LoggingDatabaseConnection,
- hs: "HomeServer",
- ):
- super().__init__(database, db_conn, hs)
-
- self._mau_stats_only = hs.config.server.mau_stats_only
-
- # Do not add more reserved users than the total allowable number
- self.db_pool.new_transaction(
- db_conn,
- "initialise_mau_threepids",
- [],
- [],
- self._initialise_reserved_users,
- hs.config.server.mau_limits_reserved_threepids[: self._max_mau_value],
- )
-
def _initialise_reserved_users(
self, txn: LoggingTransaction, threepids: List[dict]
) -> None:
@@ -254,6 +252,9 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore, RegistrationWorkerS
txn:
threepids: List of threepid dicts to reserve
"""
+ assert (
+ self._update_on_this_worker
+ ), "This worker is not designated to update MAUs"
# XXX what is this function trying to achieve? It upserts into
# monthly_active_users for each *registered* reserved mau user, but why?
@@ -287,6 +288,10 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore, RegistrationWorkerS
Args:
user_id: user to add/update
"""
+ assert (
+ self._update_on_this_worker
+ ), "This worker is not designated to update MAUs"
+
# Support user never to be included in MAU stats. Note I can't easily call this
# from upsert_monthly_active_user_txn because then I need a _txn form of
# is_support_user which is complicated because I want to cache the result.
@@ -322,6 +327,9 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore, RegistrationWorkerS
txn (cursor):
user_id (str): user to add/update
"""
+ assert (
+ self._update_on_this_worker
+ ), "This worker is not designated to update MAUs"
         # Am consciously deciding to lock the table on the basis that it ought
# never be a big table and alternative approaches (batching multiple
@@ -349,6 +357,10 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore, RegistrationWorkerS
Args:
user_id(str): the user_id to query
"""
+ assert (
+ self._update_on_this_worker
+ ), "This worker is not designated to update MAUs"
+
if self._limit_usage_by_mau or self._mau_stats_only:
# Trial users and guests should not be included as part of MAU group
is_guest = await self.is_guest(user_id) # type: ignore[attr-defined]
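Editor's note: monthly_active_users.py and client_ips.py now pick their single writer with the same rule: if Redis replication is available, the background-tasks worker does the updates; otherwise only the master may. A small sketch of that selection logic, with simplified stand-ins for the real config attributes.

from dataclasses import dataclass

@dataclass
class WorkerConfig:
    redis_enabled: bool
    run_background_tasks: bool
    instance_name: str

def should_update_on_this_worker(cfg: WorkerConfig) -> bool:
    if cfg.redis_enabled:
        # Redis-backed replication lets a dedicated background worker own the writes.
        return cfg.run_background_tasks
    # Without Redis there is no worker-to-worker stream, so the master must do it.
    return cfg.instance_name == "master"

assert should_update_on_this_worker(WorkerConfig(True, True, "bg1"))
assert not should_update_on_this_worker(WorkerConfig(False, False, "worker2"))
assert should_update_on_this_worker(WorkerConfig(False, True, "master"))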
diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py
index e6f97aee..332e901d 100644
--- a/synapse/storage/databases/main/receipts.py
+++ b/synapse/storage/databases/main/receipts.py
@@ -98,8 +98,19 @@ class ReceiptsWorkerStore(SQLBaseStore):
super().__init__(database, db_conn, hs)
+ max_receipts_stream_id = self.get_max_receipt_stream_id()
+ receipts_stream_prefill, min_receipts_stream_id = self.db_pool.get_cache_dict(
+ db_conn,
+ "receipts_linearized",
+ entity_column="room_id",
+ stream_column="stream_id",
+ max_value=max_receipts_stream_id,
+ limit=10000,
+ )
self._receipts_stream_cache = StreamChangeCache(
- "ReceiptsRoomChangeCache", self.get_max_receipt_stream_id()
+ "ReceiptsRoomChangeCache",
+ min_receipts_stream_id,
+ prefilled_cache=receipts_stream_prefill,
)
def get_max_receipt_stream_id(self) -> int:
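Editor's note: the receipts.py change seeds the ReceiptsRoomChangeCache from the most recent receipts_linearized rows instead of starting it empty at the current maximum stream ID. Roughly, prefilling lets the cache answer "has this room changed since token X?" for tokens newer than the prefill horizon straight after startup. A simplified sketch, not the real StreamChangeCache.

from typing import Dict

class StreamChangeCacheSketch:
    def __init__(self, earliest_known: int, prefilled: Dict[str, int]) -> None:
        self._earliest_known = earliest_known
        # entity -> last stream_id at which it changed
        self._latest_change = dict(prefilled)

    def has_entity_changed(self, entity: str, since: int) -> bool:
        if since < self._earliest_known:
            # The cache cannot know; callers must assume it may have changed.
            return True
        return self._latest_change.get(entity, self._earliest_known) > since

cache = StreamChangeCacheSketch(
    earliest_known=90,
    prefilled={"!roomA:example.org": 95, "!roomB:example.org": 92},
)
assert cache.has_entity_changed("!roomA:example.org", since=93)
assert not cache.has_entity_changed("!roomB:example.org", since=93)
assert cache.has_entity_changed("!roomC:example.org", since=50)  # below the horizon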
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index 7f3d190e..d43163c2 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -34,7 +34,7 @@ from synapse.storage.databases.main.stats import StatsStore
from synapse.storage.types import Cursor
from synapse.storage.util.id_generators import IdGenerator
from synapse.storage.util.sequence import build_sequence_generator
-from synapse.types import UserID, UserInfo
+from synapse.types import JsonDict, UserID, UserInfo
from synapse.util.caches.descriptors import cached
if TYPE_CHECKING:
@@ -79,7 +79,7 @@ class TokenLookupResult:
# Make the token owner default to the user ID, which is the common case.
@token_owner.default
- def _default_token_owner(self):
+ def _default_token_owner(self) -> str:
return self.user_id
@@ -299,7 +299,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
the account.
"""
- def set_account_validity_for_user_txn(txn):
+ def set_account_validity_for_user_txn(txn: LoggingTransaction) -> None:
self.db_pool.simple_update_txn(
txn=txn,
table="account_validity",
@@ -385,23 +385,25 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
desc="get_renewal_token_for_user",
)
- async def get_users_expiring_soon(self) -> List[Dict[str, Any]]:
+ async def get_users_expiring_soon(self) -> List[Tuple[str, int]]:
"""Selects users whose account will expire in the [now, now + renew_at] time
window (see configuration for account_validity for information on what renew_at
refers to).
Returns:
- A list of dictionaries, each with a user ID and expiration time (in milliseconds).
+ A list of tuples, each with a user ID and expiration time (in milliseconds).
"""
- def select_users_txn(txn, now_ms, renew_at):
+ def select_users_txn(
+ txn: LoggingTransaction, now_ms: int, renew_at: int
+ ) -> List[Tuple[str, int]]:
sql = (
"SELECT user_id, expiration_ts_ms FROM account_validity"
" WHERE email_sent = ? AND (expiration_ts_ms - ?) <= ?"
)
values = [False, now_ms, renew_at]
txn.execute(sql, values)
- return self.db_pool.cursor_to_dict(txn)
+ return cast(List[Tuple[str, int]], txn.fetchall())
return await self.db_pool.runInteraction(
"get_users_expiring_soon",
@@ -466,7 +468,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
admin: true iff the user is to be a server admin, false otherwise.
"""
- def set_server_admin_txn(txn):
+ def set_server_admin_txn(txn: LoggingTransaction) -> None:
self.db_pool.simple_update_one_txn(
txn, "users", {"name": user.to_string()}, {"admin": 1 if admin else 0}
)
@@ -515,7 +517,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
user_type: type of the user or None for a user without a type.
"""
- def set_user_type_txn(txn):
+ def set_user_type_txn(txn: LoggingTransaction) -> None:
self.db_pool.simple_update_one_txn(
txn, "users", {"name": user.to_string()}, {"user_type": user_type}
)
@@ -525,7 +527,9 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
await self.db_pool.runInteraction("set_user_type", set_user_type_txn)
- def _query_for_auth(self, txn, token: str) -> Optional[TokenLookupResult]:
+ def _query_for_auth(
+ self, txn: LoggingTransaction, token: str
+ ) -> Optional[TokenLookupResult]:
sql = """
SELECT users.name as user_id,
users.is_guest,
@@ -582,7 +586,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
"is_support_user", self.is_support_user_txn, user_id
)
- def is_real_user_txn(self, txn, user_id):
+ def is_real_user_txn(self, txn: LoggingTransaction, user_id: str) -> bool:
res = self.db_pool.simple_select_one_onecol_txn(
txn=txn,
table="users",
@@ -592,7 +596,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
)
return res is None
- def is_support_user_txn(self, txn, user_id):
+ def is_support_user_txn(self, txn: LoggingTransaction, user_id: str) -> bool:
res = self.db_pool.simple_select_one_onecol_txn(
txn=txn,
table="users",
@@ -609,10 +613,11 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
A mapping of user_id -> password_hash.
"""
- def f(txn):
+ def f(txn: LoggingTransaction) -> Dict[str, str]:
sql = "SELECT name, password_hash FROM users WHERE lower(name) = lower(?)"
txn.execute(sql, (user_id,))
- return dict(txn)
+ result = cast(List[Tuple[str, str]], txn.fetchall())
+ return dict(result)
return await self.db_pool.runInteraction("get_users_by_id_case_insensitive", f)
@@ -734,7 +739,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
def _replace_user_external_id_txn(
txn: LoggingTransaction,
- ):
+ ) -> None:
_remove_user_external_ids_txn(txn, user_id)
for auth_provider, external_id in record_external_ids:
@@ -790,10 +795,10 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
)
return [(r["auth_provider"], r["external_id"]) for r in res]
- async def count_all_users(self):
+ async def count_all_users(self) -> int:
"""Counts all users registered on the homeserver."""
- def _count_users(txn):
+ def _count_users(txn: LoggingTransaction) -> int:
txn.execute("SELECT COUNT(*) AS users FROM users")
rows = self.db_pool.cursor_to_dict(txn)
if rows:
@@ -810,7 +815,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
who registered on the homeserver in the past 24 hours
"""
- def _count_daily_user_type(txn):
+ def _count_daily_user_type(txn: LoggingTransaction) -> Dict[str, int]:
yesterday = int(self._clock.time()) - (60 * 60 * 24)
sql = """
@@ -835,23 +840,23 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
"count_daily_user_type", _count_daily_user_type
)
- async def count_nonbridged_users(self):
- def _count_users(txn):
+ async def count_nonbridged_users(self) -> int:
+ def _count_users(txn: LoggingTransaction) -> int:
txn.execute(
"""
SELECT COUNT(*) FROM users
WHERE appservice_id IS NULL
"""
)
- (count,) = txn.fetchone()
+ (count,) = cast(Tuple[int], txn.fetchone())
return count
return await self.db_pool.runInteraction("count_users", _count_users)
- async def count_real_users(self):
+ async def count_real_users(self) -> int:
"""Counts all users without a special user_type registered on the homeserver."""
- def _count_users(txn):
+ def _count_users(txn: LoggingTransaction) -> int:
txn.execute("SELECT COUNT(*) AS users FROM users where user_type is null")
rows = self.db_pool.cursor_to_dict(txn)
if rows:
@@ -888,7 +893,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
return user_id
def get_user_id_by_threepid_txn(
- self, txn, medium: str, address: str
+ self, txn: LoggingTransaction, medium: str, address: str
) -> Optional[str]:
"""Returns user id from threepid
@@ -925,7 +930,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
{"user_id": user_id, "validated_at": validated_at, "added_at": added_at},
)
- async def user_get_threepids(self, user_id) -> List[Dict[str, Any]]:
+ async def user_get_threepids(self, user_id: str) -> List[Dict[str, Any]]:
return await self.db_pool.simple_select_list(
"user_threepids",
{"user_id": user_id},
@@ -957,7 +962,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
async def add_user_bound_threepid(
self, user_id: str, medium: str, address: str, id_server: str
- ):
+ ) -> None:
"""The server proxied a bind request to the given identity server on
behalf of the given user. We need to remember this in case the user
asks us to unbind the threepid.
@@ -1116,7 +1121,9 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
assert address or sid
- def get_threepid_validation_session_txn(txn):
+ def get_threepid_validation_session_txn(
+ txn: LoggingTransaction,
+ ) -> Optional[Dict[str, Any]]:
sql = """
SELECT address, session_id, medium, client_secret,
last_send_attempt, validated_at
@@ -1150,7 +1157,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
session_id: The ID of the session to delete
"""
- def delete_threepid_session_txn(txn):
+ def delete_threepid_session_txn(txn: LoggingTransaction) -> None:
self.db_pool.simple_delete_txn(
txn,
table="threepid_validation_token",
@@ -1170,7 +1177,9 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
async def cull_expired_threepid_validation_tokens(self) -> None:
"""Remove threepid validation tokens with expiry dates that have passed"""
- def cull_expired_threepid_validation_tokens_txn(txn, ts):
+ def cull_expired_threepid_validation_tokens_txn(
+ txn: LoggingTransaction, ts: int
+ ) -> None:
sql = """
DELETE FROM threepid_validation_token WHERE
expires < ?
@@ -1184,13 +1193,13 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
)
@wrap_as_background_process("account_validity_set_expiration_dates")
- async def _set_expiration_date_when_missing(self):
+ async def _set_expiration_date_when_missing(self) -> None:
"""
Retrieves the list of registered users that don't have an expiration date, and
adds an expiration date for each of them.
"""
- def select_users_with_no_expiration_date_txn(txn):
+ def select_users_with_no_expiration_date_txn(txn: LoggingTransaction) -> None:
"""Retrieves the list of registered users with no expiration date from the
database, filtering out deactivated users.
"""
@@ -1213,7 +1222,9 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
select_users_with_no_expiration_date_txn,
)
- def set_expiration_date_for_user_txn(self, txn, user_id, use_delta=False):
+ def set_expiration_date_for_user_txn(
+ self, txn: LoggingTransaction, user_id: str, use_delta: bool = False
+ ) -> None:
"""Sets an expiration date to the account with the given user ID.
Args:
@@ -1344,7 +1355,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
token: The registration token pending use
"""
- def _set_registration_token_pending_txn(txn):
+ def _set_registration_token_pending_txn(txn: LoggingTransaction) -> None:
pending = self.db_pool.simple_select_one_onecol_txn(
txn,
"registration_tokens",
@@ -1358,7 +1369,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
updatevalues={"pending": pending + 1},
)
- return await self.db_pool.runInteraction(
+ await self.db_pool.runInteraction(
"set_registration_token_pending", _set_registration_token_pending_txn
)
@@ -1372,7 +1383,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
token: The registration token to be 'used'
"""
- def _use_registration_token_txn(txn):
+ def _use_registration_token_txn(txn: LoggingTransaction) -> None:
# Normally, res is Optional[Dict[str, Any]].
# Override type because the return type is only optional if
# allow_none is True, and we don't want mypy throwing errors
@@ -1398,7 +1409,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
},
)
- return await self.db_pool.runInteraction(
+ await self.db_pool.runInteraction(
"use_registration_token", _use_registration_token_txn
)
@@ -1416,7 +1427,9 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
A list of dicts, each containing details of a token.
"""
- def select_registration_tokens_txn(txn, now: int, valid: Optional[bool]):
+ def select_registration_tokens_txn(
+ txn: LoggingTransaction, now: int, valid: Optional[bool]
+ ) -> List[Dict[str, Any]]:
if valid is None:
# Return all tokens regardless of validity
txn.execute("SELECT * FROM registration_tokens")
@@ -1523,7 +1536,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
Whether the row was inserted or not.
"""
- def _create_registration_token_txn(txn):
+ def _create_registration_token_txn(txn: LoggingTransaction) -> bool:
row = self.db_pool.simple_select_one_txn(
txn,
"registration_tokens",
@@ -1570,7 +1583,9 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
A dict with all info about the token, or None if token doesn't exist.
"""
- def _update_registration_token_txn(txn):
+ def _update_registration_token_txn(
+ txn: LoggingTransaction,
+ ) -> Optional[Dict[str, Any]]:
try:
self.db_pool.simple_update_one_txn(
txn,
@@ -1651,7 +1666,9 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
) -> Optional[RefreshTokenLookupResult]:
"""Lookup a refresh token with hints about its validity."""
- def _lookup_refresh_token_txn(txn) -> Optional[RefreshTokenLookupResult]:
+ def _lookup_refresh_token_txn(
+ txn: LoggingTransaction,
+ ) -> Optional[RefreshTokenLookupResult]:
txn.execute(
"""
SELECT
@@ -1745,6 +1762,18 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
"replace_refresh_token", _replace_refresh_token_txn
)
+ @cached()
+ async def is_guest(self, user_id: str) -> bool:
+ res = await self.db_pool.simple_select_one_onecol(
+ table="users",
+ keyvalues={"name": user_id},
+ retcol="is_guest",
+ allow_none=True,
+ desc="is_guest",
+ )
+
+ return res if res else False
+
class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
def __init__(
@@ -1795,14 +1824,18 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
unique=False,
)
- async def _background_update_set_deactivated_flag(self, progress, batch_size):
+ async def _background_update_set_deactivated_flag(
+ self, progress: JsonDict, batch_size: int
+ ) -> int:
"""Retrieves a list of all deactivated users and sets the 'deactivated' flag to 1
for each of them.
"""
last_user = progress.get("user_id", "")
- def _background_update_set_deactivated_flag_txn(txn):
+ def _background_update_set_deactivated_flag_txn(
+ txn: LoggingTransaction,
+ ) -> Tuple[bool, int]:
txn.execute(
"""
SELECT
@@ -1874,7 +1907,9 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
deactivated,
)
- def set_user_deactivated_status_txn(self, txn, user_id: str, deactivated: bool):
+ def set_user_deactivated_status_txn(
+ self, txn: LoggingTransaction, user_id: str, deactivated: bool
+ ) -> None:
self.db_pool.simple_update_one_txn(
txn=txn,
table="users",
@@ -1887,18 +1922,6 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,))
txn.call_after(self.is_guest.invalidate, (user_id,))
- @cached()
- async def is_guest(self, user_id: str) -> bool:
- res = await self.db_pool.simple_select_one_onecol(
- table="users",
- keyvalues={"name": user_id},
- retcol="is_guest",
- allow_none=True,
- desc="is_guest",
- )
-
- return res if res else False
-
class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
def __init__(
@@ -2005,7 +2028,9 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
return next_id
- def _set_device_for_access_token_txn(self, txn, token: str, device_id: str) -> str:
+ def _set_device_for_access_token_txn(
+ self, txn: LoggingTransaction, token: str, device_id: str
+ ) -> str:
old_device_id = self.db_pool.simple_select_one_onecol_txn(
txn, "access_tokens", {"token": token}, "device_id"
)
@@ -2084,7 +2109,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
def _register_user(
self,
- txn,
+ txn: LoggingTransaction,
user_id: str,
password_hash: Optional[str],
was_guest: bool,
@@ -2094,7 +2119,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
admin: bool,
user_type: Optional[str],
shadow_banned: bool,
- ):
+ ) -> None:
user_id_obj = UserID.from_string(user_id)
now = int(self._clock.time())
@@ -2181,7 +2206,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
pointless. Use flush_user separately.
"""
- def user_set_password_hash_txn(txn):
+ def user_set_password_hash_txn(txn: LoggingTransaction) -> None:
self.db_pool.simple_update_one_txn(
txn, "users", {"name": user_id}, {"password_hash": password_hash}
)
@@ -2204,7 +2229,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
StoreError(404) if user not found
"""
- def f(txn):
+ def f(txn: LoggingTransaction) -> None:
self.db_pool.simple_update_one_txn(
txn,
table="users",
@@ -2229,7 +2254,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
StoreError(404) if user not found
"""
- def f(txn):
+ def f(txn: LoggingTransaction) -> None:
self.db_pool.simple_update_one_txn(
txn,
table="users",
@@ -2259,7 +2284,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
A tuple of (token, token id, device id) for each of the deleted tokens
"""
- def f(txn):
+ def f(txn: LoggingTransaction) -> List[Tuple[str, int, Optional[str]]]:
keyvalues = {"user_id": user_id}
if device_id is not None:
keyvalues["device_id"] = device_id
@@ -2301,7 +2326,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
return await self.db_pool.runInteraction("user_delete_access_tokens", f)
async def delete_access_token(self, access_token: str) -> None:
- def f(txn):
+ def f(txn: LoggingTransaction) -> None:
self.db_pool.simple_delete_one_txn(
txn, table="access_tokens", keyvalues={"token": access_token}
)
@@ -2313,7 +2338,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
await self.db_pool.runInteraction("delete_access_token", f)
async def delete_refresh_token(self, refresh_token: str) -> None:
- def f(txn):
+ def f(txn: LoggingTransaction) -> None:
self.db_pool.simple_delete_one_txn(
txn, table="refresh_tokens", keyvalues={"token": refresh_token}
)
@@ -2353,7 +2378,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
"""
# Insert everything into a transaction in order to run atomically
- def validate_threepid_session_txn(txn):
+ def validate_threepid_session_txn(txn: LoggingTransaction) -> Optional[str]:
row = self.db_pool.simple_select_one_txn(
txn,
table="threepid_validation_session",
@@ -2450,7 +2475,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
longer be valid
"""
- def start_or_continue_validation_session_txn(txn):
+ def start_or_continue_validation_session_txn(txn: LoggingTransaction) -> None:
# Create or update a validation session
self.db_pool.simple_upsert_txn(
txn,
diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py
index b2295fd5..407158ce 100644
--- a/synapse/storage/databases/main/relations.py
+++ b/synapse/storage/databases/main/relations.py
@@ -17,6 +17,7 @@ from typing import (
TYPE_CHECKING,
Collection,
Dict,
+ FrozenSet,
Iterable,
List,
Optional,
@@ -39,8 +40,7 @@ from synapse.storage.database import (
)
from synapse.storage.databases.main.stream import generate_pagination_where_clause
from synapse.storage.engines import PostgresEngine
-from synapse.storage.relations import AggregationPaginationToken, PaginationChunk
-from synapse.types import RoomStreamToken, StreamToken
+from synapse.types import JsonDict, RoomStreamToken, StreamToken
from synapse.util.caches.descriptors import cached, cachedList
if TYPE_CHECKING:
@@ -49,6 +49,19 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class _RelatedEvent:
+ """
+ Contains enough information about a related event in order to properly filter
+ events from ignored users.
+ """
+
+ # The event ID of the related event.
+ event_id: str
+ # The sender of the related event.
+ sender: str
+
+
class RelationsWorkerStore(SQLBaseStore):
def __init__(
self,
@@ -73,7 +86,7 @@ class RelationsWorkerStore(SQLBaseStore):
direction: str = "b",
from_token: Optional[StreamToken] = None,
to_token: Optional[StreamToken] = None,
- ) -> PaginationChunk:
+ ) -> Tuple[List[_RelatedEvent], Optional[StreamToken]]:
"""Get a list of relations for an event, ordered by topological ordering.
Args:
@@ -90,8 +103,10 @@ class RelationsWorkerStore(SQLBaseStore):
to_token: Fetch rows up to the given token, or up to the end if None.
Returns:
- List of event IDs that match relations requested. The rows are of
- the form `{"event_id": "..."}`.
+ A tuple of:
+ A list of related event IDs & their senders.
+
+ The next stream token, if one exists.
"""
# We don't use `event_id`, it's there so that we can cache based on
# it. The `event_id` must match the `event.event_id`.
@@ -132,7 +147,7 @@ class RelationsWorkerStore(SQLBaseStore):
order = "ASC"
sql = """
- SELECT event_id, relation_type, topological_ordering, stream_ordering
+ SELECT event_id, relation_type, sender, topological_ordering, stream_ordering
FROM event_relations
INNER JOIN events USING (event_id)
WHERE %s
@@ -146,7 +161,7 @@ class RelationsWorkerStore(SQLBaseStore):
def _get_recent_references_for_event_txn(
txn: LoggingTransaction,
- ) -> PaginationChunk:
+ ) -> Tuple[List[_RelatedEvent], Optional[StreamToken]]:
txn.execute(sql, where_args + [limit + 1])
last_topo_id = None
@@ -156,9 +171,9 @@ class RelationsWorkerStore(SQLBaseStore):
# Do not include edits for redacted events as they leak event
# content.
if not is_redacted or row[1] != RelationTypes.REPLACE:
- events.append({"event_id": row[0]})
- last_topo_id = row[2]
- last_stream_id = row[3]
+ events.append(_RelatedEvent(row[0], row[2]))
+ last_topo_id = row[3]
+ last_stream_id = row[4]
# If there are more events, generate the next pagination key.
next_token = None
@@ -179,9 +194,7 @@ class RelationsWorkerStore(SQLBaseStore):
groups_key=0,
)
- return PaginationChunk(
- chunk=list(events[:limit]), next_batch=next_token, prev_batch=from_token
- )
+ return events[:limit], next_token
return await self.db_pool.runInteraction(
"get_recent_references_for_event", _get_recent_references_for_event_txn
@@ -252,15 +265,8 @@ class RelationsWorkerStore(SQLBaseStore):
@cached(tree=True)
async def get_aggregation_groups_for_event(
- self,
- event_id: str,
- room_id: str,
- event_type: Optional[str] = None,
- limit: int = 5,
- direction: str = "b",
- from_token: Optional[AggregationPaginationToken] = None,
- to_token: Optional[AggregationPaginationToken] = None,
- ) -> PaginationChunk:
+ self, event_id: str, room_id: str, limit: int = 5
+ ) -> List[JsonDict]:
"""Get a list of annotations on the event, grouped by event type and
aggregation key, sorted by count.
@@ -270,82 +276,96 @@ class RelationsWorkerStore(SQLBaseStore):
Args:
event_id: Fetch events that relate to this event ID.
room_id: The room the event belongs to.
- event_type: Only fetch events with this event type, if given.
limit: Only fetch the `limit` groups.
- direction: Whether to fetch the highest count first (`"b"`) or
- the lowest count first (`"f"`).
- from_token: Fetch rows from the given token, or from the start if None.
- to_token: Fetch rows up to the given token, or up to the end if None.
Returns:
List of groups of annotations that match. Each row is a dict with
`type`, `key` and `count` fields.
"""
- where_clause = ["relates_to_id = ?", "room_id = ?", "relation_type = ?"]
- where_args: List[Union[str, int]] = [
+ args = [
event_id,
room_id,
RelationTypes.ANNOTATION,
+ limit,
]
- if event_type:
- where_clause.append("type = ?")
- where_args.append(event_type)
+ sql = """
+ SELECT type, aggregation_key, COUNT(DISTINCT sender)
+ FROM event_relations
+ INNER JOIN events USING (event_id)
+ WHERE relates_to_id = ? AND room_id = ? AND relation_type = ?
+ GROUP BY relation_type, type, aggregation_key
+ ORDER BY COUNT(*) DESC
+ LIMIT ?
+ """
- having_clause = generate_pagination_where_clause(
- direction=direction,
- column_names=("COUNT(*)", "MAX(stream_ordering)"),
- from_token=attr.astuple(from_token) if from_token else None, # type: ignore[arg-type]
- to_token=attr.astuple(to_token) if to_token else None, # type: ignore[arg-type]
- engine=self.database_engine,
+ def _get_aggregation_groups_for_event_txn(
+ txn: LoggingTransaction,
+ ) -> List[JsonDict]:
+ txn.execute(sql, args)
+
+ return [{"type": row[0], "key": row[1], "count": row[2]} for row in txn]
+
+ return await self.db_pool.runInteraction(
+ "get_aggregation_groups_for_event", _get_aggregation_groups_for_event_txn
)
- if direction == "b":
- order = "DESC"
- else:
- order = "ASC"
+ async def get_aggregation_groups_for_users(
+ self,
+ event_id: str,
+ room_id: str,
+ limit: int,
+ users: FrozenSet[str] = frozenset(),
+ ) -> Dict[Tuple[str, str], int]:
+ """Fetch the partial aggregations for an event for specific users.
- if having_clause:
- having_clause = "HAVING " + having_clause
- else:
- having_clause = ""
+ This is used, in conjunction with get_aggregation_groups_for_event, to
+ remove information from the results for ignored users.
- sql = """
- SELECT type, aggregation_key, COUNT(DISTINCT sender), MAX(stream_ordering)
+ Args:
+ event_id: Fetch events that relate to this event ID.
+ room_id: The room the event belongs to.
+ limit: Only fetch the `limit` groups.
+ users: The users to fetch information for.
+
+ Returns:
+ A map of (event type, aggregation key) to a count of users.
+ """
+
+ if not users:
+ return {}
+
+ args: List[Union[str, int]] = [
+ event_id,
+ room_id,
+ RelationTypes.ANNOTATION,
+ ]
+
+ users_sql, users_args = make_in_list_sql_clause(
+ self.database_engine, "sender", users
+ )
+ args.extend(users_args)
+
+ sql = f"""
+ SELECT type, aggregation_key, COUNT(DISTINCT sender)
FROM event_relations
INNER JOIN events USING (event_id)
- WHERE {where_clause}
+ WHERE relates_to_id = ? AND room_id = ? AND relation_type = ? AND {users_sql}
GROUP BY relation_type, type, aggregation_key
- {having_clause}
- ORDER BY COUNT(*) {order}, MAX(stream_ordering) {order}
+ ORDER BY COUNT(*) DESC
LIMIT ?
- """.format(
- where_clause=" AND ".join(where_clause),
- order=order,
- having_clause=having_clause,
- )
+ """
- def _get_aggregation_groups_for_event_txn(
+ def _get_aggregation_groups_for_users_txn(
txn: LoggingTransaction,
- ) -> PaginationChunk:
- txn.execute(sql, where_args + [limit + 1])
+ ) -> Dict[Tuple[str, str], int]:
+ txn.execute(sql, args + [limit])
- next_batch = None
- events = []
- for row in txn:
- events.append({"type": row[0], "key": row[1], "count": row[2]})
- next_batch = AggregationPaginationToken(row[2], row[3])
-
- if len(events) <= limit:
- next_batch = None
-
- return PaginationChunk(
- chunk=list(events[:limit]), next_batch=next_batch, prev_batch=from_token
- )
+ return {(row[0], row[1]): row[2] for row in txn}
return await self.db_pool.runInteraction(
- "get_aggregation_groups_for_event", _get_aggregation_groups_for_event_txn
+ "get_aggregation_groups_for_users", _get_aggregation_groups_for_users_txn
)
@cached()
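The two aggregation methods above are meant to be combined by the caller: the full counts come from get_aggregation_groups_for_event and the per-ignored-user counts from get_aggregation_groups_for_users. A minimal sketch of that subtraction, with hypothetical names and inputs (this is not the actual handler code):

from typing import Any, Dict, List, Tuple

def subtract_ignored_annotations(
    aggregations: List[Dict[str, Any]],           # rows of {"type", "key", "count"}
    ignored_counts: Dict[Tuple[str, str], int],   # (type, key) -> count from ignored users
) -> List[Dict[str, Any]]:
    result = []
    for row in aggregations:
        count = row["count"] - ignored_counts.get((row["type"], row["key"]), 0)
        if count > 0:
            # Drop groups whose only senders were ignored users.
            result.append({"type": row["type"], "key": row["key"], "count": count})
    return result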
@@ -574,6 +594,67 @@ class RelationsWorkerStore(SQLBaseStore):
return summaries
+ async def get_threaded_messages_per_user(
+ self,
+ event_ids: Collection[str],
+ users: FrozenSet[str] = frozenset(),
+ ) -> Dict[Tuple[str, str], int]:
+ """Get the number of threaded replies for a set of users.
+
+ This is used, in conjunction with get_thread_summaries, to calculate an
+ accurate count of the replies to a thread by subtracting ignored users.
+
+ Args:
+ event_ids: The events to check for threaded replies.
+ users: The users to calculate reply counts for.
+
+ Returns:
+ A map of (event_id, sender) to the count of that user's replies.
+ """
+ if not users:
+ return {}
+
+ # Fetch the number of threaded replies.
+ sql = """
+ SELECT parent.event_id, child.sender, COUNT(child.event_id) FROM events AS child
+ INNER JOIN event_relations USING (event_id)
+ INNER JOIN events AS parent ON
+ parent.event_id = relates_to_id
+ AND parent.room_id = child.room_id
+ WHERE
+ %s
+ AND %s
+ AND %s
+ GROUP BY parent.event_id, child.sender
+ """
+
+ def _get_threaded_messages_per_user_txn(
+ txn: LoggingTransaction,
+ ) -> Dict[Tuple[str, str], int]:
+ users_sql, users_args = make_in_list_sql_clause(
+ self.database_engine, "child.sender", users
+ )
+ events_clause, events_args = make_in_list_sql_clause(
+ txn.database_engine, "relates_to_id", event_ids
+ )
+
+ if self._msc3440_enabled:
+ relations_clause = "(relation_type = ? OR relation_type = ?)"
+ relations_args = [RelationTypes.THREAD, RelationTypes.UNSTABLE_THREAD]
+ else:
+ relations_clause = "relation_type = ?"
+ relations_args = [RelationTypes.THREAD]
+
+ txn.execute(
+ sql % (users_sql, events_clause, relations_clause),
+ users_args + events_args + relations_args,
+ )
+ return {(row[0], row[1]): row[2] for row in txn}
+
+ return await self.db_pool.runInteraction(
+ "get_threaded_messages_per_user", _get_threaded_messages_per_user_txn
+ )
+
@cached()
def get_thread_participated(self, event_id: str, user_id: str) -> bool:
raise NotImplementedError()
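As the docstring above notes, get_threaded_messages_per_user exists so that thread reply counts can be corrected for ignored users. A rough sketch of that adjustment, using hypothetical names (the real logic lives in the relations handler):

from typing import Dict, FrozenSet, Tuple

def adjusted_reply_count(
    thread_root_id: str,
    total_replies: int,
    ignored_reply_counts: Dict[Tuple[str, str], int],  # (event_id, sender) -> count
    ignored_users: FrozenSet[str],
) -> int:
    # Sum the replies to this thread that came from ignored users and
    # subtract them from the total, never going below zero.
    ignored = sum(
        count
        for (event_id, sender), count in ignored_reply_counts.items()
        if event_id == thread_root_id and sender in ignored_users
    )
    return max(total_replies - ignored, 0)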
@@ -661,7 +742,7 @@ class RelationsWorkerStore(SQLBaseStore):
%s;
"""
- def _get_if_events_have_relations(txn) -> List[str]:
+ def _get_if_events_have_relations(txn: LoggingTransaction) -> List[str]:
clauses: List[str] = []
clause, args = make_in_list_sql_clause(
txn.database_engine, "relates_to_id", parent_ids
diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py
index 3248da53..48e83592 100644
--- a/synapse/storage/databases/main/roommember.py
+++ b/synapse/storage/databases/main/roommember.py
@@ -361,7 +361,10 @@ class RoomMemberWorkerStore(EventsWorkerStore):
return None
async def get_rooms_for_local_user_where_membership_is(
- self, user_id: str, membership_list: Collection[str]
+ self,
+ user_id: str,
+ membership_list: Collection[str],
+ excluded_rooms: Optional[List[str]] = None,
) -> List[RoomsForUser]:
"""Get all the rooms for this *local* user where the membership for this user
matches one in the membership list.
@@ -372,6 +375,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
user_id: The user ID.
membership_list: A list of synapse.api.constants.Membership
values which the user must be in.
+ excluded_rooms: A list of rooms to ignore.
Returns:
The RoomsForUser entries for rooms where the user's membership matches one of the given types.
@@ -386,12 +390,19 @@ class RoomMemberWorkerStore(EventsWorkerStore):
membership_list,
)
- # Now we filter out forgotten rooms
- forgotten_rooms = await self.get_forgotten_rooms_for_user(user_id)
- return [room for room in rooms if room.room_id not in forgotten_rooms]
+ # Now we filter out forgotten and excluded rooms
+ rooms_to_exclude: Set[str] = await self.get_forgotten_rooms_for_user(user_id)
+
+ if excluded_rooms is not None:
+ rooms_to_exclude.update(set(excluded_rooms))
+
+ return [room for room in rooms if room.room_id not in rooms_to_exclude]
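A brief usage sketch of the new excluded_rooms parameter (the store handle and room IDs below are placeholders): excluded rooms are filtered out in addition to rooms the user has forgotten.

async def joined_rooms_except(store, user_id: str, excluded: list):
    # Fetch joined rooms, leaving out both forgotten rooms and the caller's
    # explicit exclusions.
    return await store.get_rooms_for_local_user_where_membership_is(
        user_id,
        membership_list=["join"],
        excluded_rooms=excluded,
    )

# e.g. await joined_rooms_except(store, "@alice:example.com", ["!hidden:example.com"])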
def _get_rooms_for_local_user_where_membership_is_txn(
- self, txn, user_id: str, membership_list: List[str]
+ self,
+ txn,
+ user_id: str,
+ membership_list: List[str],
) -> List[RoomsForUser]:
# Paranoia check.
if not self.hs.is_mine_id(user_id):
@@ -877,7 +888,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
return frozenset(cache.hosts_to_joined_users)
# Since we'll mutate the cache we need to lock.
- with (await self._joined_host_linearizer.queue(room_id)):
+ async with self._joined_host_linearizer.queue(room_id):
if state_entry.state_group == cache.state_group:
# Same state group, so nothing to do. We've already checked for
# this above, but the cache may have changed while waiting on
diff --git a/synapse/storage/databases/main/signatures.py b/synapse/storage/databases/main/signatures.py
index 0518b8b9..95148fd2 100644
--- a/synapse/storage/databases/main/signatures.py
+++ b/synapse/storage/databases/main/signatures.py
@@ -26,7 +26,7 @@ from synapse.util.caches.descriptors import cached, cachedList
class SignatureWorkerStore(EventsWorkerStore):
@cached()
- def get_event_reference_hash(self, event_id):
+ def get_event_reference_hash(self, event_id: str) -> Dict[str, Dict[str, bytes]]:
# This is a dummy function to allow get_event_reference_hashes
# to use its cache
raise NotImplementedError()
diff --git a/synapse/storage/databases/main/state.py b/synapse/storage/databases/main/state.py
index 28460fd3..ecdc1fdc 100644
--- a/synapse/storage/databases/main/state.py
+++ b/synapse/storage/databases/main/state.py
@@ -12,9 +12,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-import collections.abc
import logging
-from typing import TYPE_CHECKING, Collection, Iterable, Optional, Set, Tuple
+from typing import TYPE_CHECKING, Collection, Dict, Iterable, Optional, Set, Tuple
+
+from frozendict import frozendict
from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import NotFoundError, UnsupportedRoomVersionError
@@ -29,7 +30,7 @@ from synapse.storage.database import (
from synapse.storage.databases.main.events_worker import EventsWorkerStore
from synapse.storage.databases.main.roommember import RoomMemberWorkerStore
from synapse.storage.state import StateFilter
-from synapse.types import JsonDict, StateMap
+from synapse.types import JsonDict, JsonMapping, StateMap
from synapse.util.caches import intern_string
from synapse.util.caches.descriptors import cached, cachedList
@@ -132,7 +133,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
return room_version
- async def get_room_predecessor(self, room_id: str) -> Optional[dict]:
+ async def get_room_predecessor(self, room_id: str) -> Optional[JsonMapping]:
"""Get the predecessor of an upgraded room if it exists.
Otherwise return None.
@@ -158,9 +159,10 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
predecessor = create_event.content.get("predecessor", None)
# Ensure the key is a dictionary
- if not isinstance(predecessor, collections.abc.Mapping):
+ if not isinstance(predecessor, (dict, frozendict)):
return None
+ # The keys must be strings since the data is JSON.
return predecessor
async def get_create_event_for_room(self, room_id: str) -> EventBase:
@@ -202,7 +204,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
The current state of the room.
"""
- def _get_current_state_ids_txn(txn):
+ def _get_current_state_ids_txn(txn: LoggingTransaction) -> StateMap[str]:
txn.execute(
"""SELECT type, state_key, event_id FROM current_state_events
WHERE room_id = ?
@@ -306,8 +308,14 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
list_name="event_ids",
num_args=1,
)
- async def _get_state_group_for_events(self, event_ids: Collection[str]) -> JsonDict:
- """Returns mapping event_id -> state_group"""
+ async def _get_state_group_for_events(
+ self, event_ids: Collection[str]
+ ) -> Dict[str, int]:
+ """Returns mapping event_id -> state_group.
+
+ Raises:
+ RuntimeError if the state is unknown at any of the given events
+ """
rows = await self.db_pool.simple_select_many_batch(
table="event_to_state_groups",
column="event_id",
@@ -317,7 +325,11 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
desc="_get_state_group_for_events",
)
- return {row["event_id"]: row["state_group"] for row in rows}
+ res = {row["event_id"]: row["state_group"] for row in rows}
+ for e in event_ids:
+ if e not in res:
+ raise RuntimeError("No state group for unknown or outlier event %s" % e)
+ return res
async def get_referenced_state_groups(
self, state_groups: Iterable[int]
@@ -521,7 +533,7 @@ class MainStateBackgroundUpdateStore(RoomMemberWorkerStore):
)
for user_id in potentially_left_users - joined_users:
- await self.mark_remote_user_device_list_as_unsubscribed(user_id)
+ await self.mark_remote_user_device_list_as_unsubscribed(user_id) # type: ignore[attr-defined]
return batch_size
diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py
index 39e1efe3..6d45a8a9 100644
--- a/synapse/storage/databases/main/stream.py
+++ b/synapse/storage/databases/main/stream.py
@@ -36,7 +36,17 @@ what sort order was used:
"""
import logging
-from typing import TYPE_CHECKING, Collection, Dict, List, Optional, Set, Tuple
+from typing import (
+ TYPE_CHECKING,
+ Any,
+ Collection,
+ Dict,
+ List,
+ Optional,
+ Set,
+ Tuple,
+ cast,
+)
import attr
from frozendict import frozendict
@@ -585,7 +595,11 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
return ret, key
async def get_membership_changes_for_user(
- self, user_id: str, from_key: RoomStreamToken, to_key: RoomStreamToken
+ self,
+ user_id: str,
+ from_key: RoomStreamToken,
+ to_key: RoomStreamToken,
+ excluded_rooms: Optional[List[str]] = None,
) -> List[EventBase]:
"""Fetch membership events for a given user.
@@ -610,23 +624,29 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
min_from_id = from_key.stream
max_to_id = to_key.get_max_stream_pos()
+ args: List[Any] = [user_id, min_from_id, max_to_id]
+
+ ignore_room_clause = ""
+ if excluded_rooms is not None and len(excluded_rooms) > 0:
+ ignore_room_clause = "AND e.room_id NOT IN (%s)" % ",".join(
+ "?" for _ in excluded_rooms
+ )
+ args = args + excluded_rooms
+
sql = """
SELECT m.event_id, instance_name, topological_ordering, stream_ordering
FROM events AS e, room_memberships AS m
WHERE e.event_id = m.event_id
AND m.user_id = ?
AND e.stream_ordering > ? AND e.stream_ordering <= ?
+ %s
ORDER BY e.stream_ordering ASC
- """
- txn.execute(
- sql,
- (
- user_id,
- min_from_id,
- max_to_id,
- ),
+ """ % (
+ ignore_room_clause,
)
+ txn.execute(sql, args)
+
rows = [
_EventDictReturn(event_id, None, stream_ordering)
for event_id, instance_name, topological_ordering, stream_ordering in txn
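The optional room exclusion above is built with the usual one-placeholder-per-value pattern; a standalone sketch of that clause construction (illustrative only):

from typing import List, Tuple

def build_room_exclusion_clause(excluded_rooms: List[str]) -> Tuple[str, List[str]]:
    # No clause and no extra args when there is nothing to exclude.
    if not excluded_rooms:
        return "", []
    clause = "AND e.room_id NOT IN (%s)" % ",".join("?" for _ in excluded_rooms)
    return clause, list(excluded_rooms)

# build_room_exclusion_clause(["!a:hs", "!b:hs"])
# -> ("AND e.room_id NOT IN (?,?)", ["!a:hs", "!b:hs"])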
@@ -722,7 +742,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
A tuple of (stream ordering, topological ordering, event_id)
"""
- def _f(txn):
+ def _f(txn: LoggingTransaction) -> Optional[Tuple[int, int, str]]:
sql = (
"SELECT stream_ordering, topological_ordering, event_id"
" FROM events"
@@ -732,27 +752,29 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
" LIMIT 1"
)
txn.execute(sql, (room_id, stream_ordering))
- return txn.fetchone()
+ return cast(Optional[Tuple[int, int, str]], txn.fetchone())
return await self.db_pool.runInteraction(
"get_room_event_before_stream_ordering", _f
)
- async def get_room_events_max_id(self, room_id: Optional[str] = None) -> str:
- """Returns the current token for rooms stream.
+ async def get_current_room_stream_token_for_room_id(
+ self, room_id: Optional[str] = None
+ ) -> RoomStreamToken:
+ """Returns the current position of the rooms stream.
- By default, it returns the current global stream token. Specifying a
- `room_id` causes it to return the current room specific topological
- token.
+ By default, it returns a live token with the current global stream
+ token. Specifying a `room_id` causes it to return a historic token with
+ the room specific topological token.
"""
- token = self.get_room_max_stream_ordering()
+ stream_ordering = self.get_room_max_stream_ordering()
if room_id is None:
- return "s%d" % (token,)
+ return RoomStreamToken(None, stream_ordering)
else:
topo = await self.db_pool.runInteraction(
"_get_max_topological_txn", self._get_max_topological_txn, room_id
)
- return "t%d-%d" % (topo, token)
+ return RoomStreamToken(topo, stream_ordering)
def get_stream_id_for_event_txn(
self,
@@ -827,7 +849,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
@staticmethod
def _set_before_and_after(
events: List[EventBase], rows: List[_EventDictReturn], topo_order: bool = True
- ):
+ ) -> None:
"""Inserts ordering information to events' internal metadata from
the DB rows.
@@ -973,7 +995,9 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
the `current_id`).
"""
- def get_all_new_events_stream_txn(txn):
+ def get_all_new_events_stream_txn(
+ txn: LoggingTransaction,
+ ) -> Tuple[int, List[str]]:
sql = (
"SELECT e.stream_ordering, e.event_id"
" FROM events AS e"
@@ -1319,7 +1343,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
async def get_id_for_instance(self, instance_name: str) -> int:
"""Get a unique, immutable ID that corresponds to the given Synapse worker instance."""
- def _get_id_for_instance_txn(txn):
+ def _get_id_for_instance_txn(txn: LoggingTransaction) -> int:
instance_id = self.db_pool.simple_select_one_onecol_txn(
txn,
table="instance_map",
diff --git a/synapse/storage/databases/main/tags.py b/synapse/storage/databases/main/tags.py
index c8e508a9..b0f5de67 100644
--- a/synapse/storage/databases/main/tags.py
+++ b/synapse/storage/databases/main/tags.py
@@ -97,7 +97,7 @@ class TagsWorkerStore(AccountDataWorkerStore):
)
def get_tag_content(
- txn: LoggingTransaction, tag_ids
+ txn: LoggingTransaction, tag_ids: List[Tuple[int, str, str]]
) -> List[Tuple[int, Tuple[str, str, str]]]:
sql = "SELECT tag, content FROM room_tags WHERE user_id=? AND room_id=?"
results = []
@@ -251,7 +251,7 @@ class TagsWorkerStore(AccountDataWorkerStore):
return self._account_data_id_gen.get_current_token()
def _update_revision_txn(
- self, txn, user_id: str, room_id: str, next_id: int
+ self, txn: LoggingTransaction, user_id: str, room_id: str, next_id: int
) -> None:
"""Update the latest revision of the tags for the given user and room.
diff --git a/synapse/storage/relations.py b/synapse/storage/relations.py
deleted file mode 100644
index fba27015..00000000
--- a/synapse/storage/relations.py
+++ /dev/null
@@ -1,84 +0,0 @@
-# Copyright 2019 New Vector Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import logging
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
-
-import attr
-
-from synapse.api.errors import SynapseError
-from synapse.types import JsonDict
-
-if TYPE_CHECKING:
- from synapse.storage.databases.main import DataStore
-
-logger = logging.getLogger(__name__)
-
-
-@attr.s(slots=True, auto_attribs=True)
-class PaginationChunk:
- """Returned by relation pagination APIs.
-
- Attributes:
- chunk: The rows returned by pagination
- next_batch: Token to fetch next set of results with, if
- None then there are no more results.
- prev_batch: Token to fetch previous set of results with, if
- None then there are no previous results.
- """
-
- chunk: List[JsonDict]
- next_batch: Optional[Any] = None
- prev_batch: Optional[Any] = None
-
- async def to_dict(self, store: "DataStore") -> Dict[str, Any]:
- d = {"chunk": self.chunk}
-
- if self.next_batch:
- d["next_batch"] = await self.next_batch.to_string(store)
-
- if self.prev_batch:
- d["prev_batch"] = await self.prev_batch.to_string(store)
-
- return d
-
-
-@attr.s(frozen=True, slots=True, auto_attribs=True)
-class AggregationPaginationToken:
- """Pagination token for relation aggregation pagination API.
-
- As the results are order by count and then MAX(stream_ordering) of the
- aggregation groups, we can just use them as our pagination token.
-
- Attributes:
- count: The count of relations in the boundary group.
- stream: The MAX stream ordering in the boundary group.
- """
-
- count: int
- stream: int
-
- @staticmethod
- def from_string(string: str) -> "AggregationPaginationToken":
- try:
- c, s = string.split("-")
- return AggregationPaginationToken(int(c), int(s))
- except ValueError:
- raise SynapseError(400, "Invalid aggregation pagination token")
-
- async def to_string(self, store: "DataStore") -> str:
- return "%d-%d" % (self.count, self.stream)
-
- def as_tuple(self) -> Tuple[Any, ...]:
- return attr.astuple(self)
diff --git a/synapse/storage/schema/__init__.py b/synapse/storage/schema/__init__.py
index 7b21c1b9..151f2aa9 100644
--- a/synapse/storage/schema/__init__.py
+++ b/synapse/storage/schema/__init__.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-SCHEMA_VERSION = 68 # remember to update the list below when updating
+SCHEMA_VERSION = 69 # remember to update the list below when updating
"""Represents the expectations made by the codebase about the database schema
This should be incremented whenever the codebase changes its requirements on the
@@ -58,6 +58,10 @@ Changes in SCHEMA_VERSION = 68:
- event_reference_hashes is no longer read.
- `events` has `state_key` and `rejection_reason` columns, which are populated for
new events.
+
+Changes in SCHEMA_VERSION = 69:
+ - We now write to the `device_lists_changes_in_room` table.
+ - Use a sequence to generate future `application_services_txns.txn_id`s
"""
diff --git a/synapse/storage/schema/main/delta/68/06_msc3202_add_device_list_appservice_stream_type.sql b/synapse/storage/schema/main/delta/68/06_msc3202_add_device_list_appservice_stream_type.sql
new file mode 100644
index 00000000..7590e34b
--- /dev/null
+++ b/synapse/storage/schema/main/delta/68/06_msc3202_add_device_list_appservice_stream_type.sql
@@ -0,0 +1,23 @@
+/* Copyright 2022 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- Add a column to track which device list changes stream ID this application
+-- service has been caught up to.
+
+-- We explicitly don't set this field as "NOT NULL", as having NULL as a possible
+-- state is useful for determining if we've ever sent traffic for a stream type
+-- to an appservice. See https://github.com/matrix-org/synapse/issues/10836 for
+-- one way this can be used.
+ALTER TABLE application_services_state ADD COLUMN device_list_stream_id BIGINT;
\ No newline at end of file
diff --git a/synapse/storage/schema/main/delta/69/01as_txn_seq.py b/synapse/storage/schema/main/delta/69/01as_txn_seq.py
new file mode 100644
index 00000000..24bd4b39
--- /dev/null
+++ b/synapse/storage/schema/main/delta/69/01as_txn_seq.py
@@ -0,0 +1,44 @@
+# Copyright 2022 Beeper
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+"""
+Adds a postgres SEQUENCE for generating application service transaction IDs.
+"""
+
+from synapse.storage.engines import PostgresEngine
+
+
+def run_create(cur, database_engine, *args, **kwargs):
+ if isinstance(database_engine, PostgresEngine):
+ # If we already have some AS TXNs we want to start from the current
+ # maximum value. There are two potential places this is stored - the
+ # actual TXNs themselves *and* the AS state table. At time of migration
+ # it is possible the TXNs table is empty so we must include the AS state
+ # last_txn as a potential option, and pick the maximum.
+
+ cur.execute("SELECT COALESCE(max(txn_id), 0) FROM application_services_txns")
+ row = cur.fetchone()
+ txn_max = row[0]
+
+ cur.execute("SELECT COALESCE(max(last_txn), 0) FROM application_services_state")
+ row = cur.fetchone()
+ last_txn_max = row[0]
+
+ start_val = max(last_txn_max, txn_max) + 1
+
+ cur.execute(
+ "CREATE SEQUENCE application_services_txn_id_seq START WITH %s",
+ (start_val,),
+ )
diff --git a/synapse/storage/schema/main/delta/69/01device_list_oubound_by_room.sql b/synapse/storage/schema/main/delta/69/01device_list_oubound_by_room.sql
new file mode 100644
index 00000000..b5b1782b
--- /dev/null
+++ b/synapse/storage/schema/main/delta/69/01device_list_oubound_by_room.sql
@@ -0,0 +1,38 @@
+/* Copyright 2022 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+CREATE TABLE device_lists_changes_in_room (
+ user_id TEXT NOT NULL,
+ device_id TEXT NOT NULL,
+ room_id TEXT NOT NULL,
+
+ -- This initially matches `device_lists_stream.stream_id`. Note that we
+ -- delete older values from `device_lists_stream`, so we can't use a foreign
+ -- constraint here.
+ --
+ -- The table will contain rows with the same `stream_id` but different
+ -- `room_id`, as for each device update we store a row per room the user is
+ -- joined to. Therefore `(stream_id, room_id)` gives a unique index.
+ stream_id BIGINT NOT NULL,
+
+ -- We have a background process which goes through this table and converts
+ -- entries into rows in `device_lists_outbound_pokes`. Once we have processed
+ -- a row, we mark it as such by setting `converted_to_destinations=TRUE`.
+ converted_to_destinations BOOLEAN NOT NULL,
+ opentracing_context TEXT
+);
+
+CREATE UNIQUE INDEX device_lists_changes_in_stream_id ON device_lists_changes_in_room(stream_id, room_id);
+CREATE INDEX device_lists_changes_in_stream_id_unconverted ON device_lists_changes_in_room(stream_id) WHERE NOT converted_to_destinations;
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index 86f1a537..cda194e8 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -571,6 +571,10 @@ class StateGroupStorage:
Returns:
dict of state_group_id -> (dict of (type, state_key) -> event id)
+
+ Raises:
+ RuntimeError if we don't have a state group for one or more of the events
+ (ie they are outliers or unknown)
"""
if not event_ids:
return {}
@@ -659,6 +663,10 @@ class StateGroupStorage:
Returns:
A dict of (event_id) -> (type, state_key) -> [state_events]
+
+ Raises:
+ RuntimeError if we don't have a state group for one or more of the events
+ (ie they are outliers or unknown)
"""
event_to_groups = await self.stores.main._get_state_group_for_events(event_ids)
@@ -696,6 +704,10 @@ class StateGroupStorage:
Returns:
A dict from event_id -> (type, state_key) -> event_id
+
+ Raises:
+ RuntimeError if we don't have a state group for one or more of the events
+ (ie they are outliers or unknown)
"""
event_to_groups = await self.stores.main._get_state_group_for_events(event_ids)
@@ -723,6 +735,10 @@ class StateGroupStorage:
Returns:
A dict from (type, state_key) -> state_event
+
+ Raises:
+ RuntimeError if we don't have a state group for the event (ie it is an
+ outlier or is unknown)
"""
state_map = await self.get_state_for_events(
[event_id], state_filter or StateFilter.all()
@@ -741,6 +757,10 @@ class StateGroupStorage:
Returns:
A dict from (type, state_key) -> state_event_id
+
+ Raises:
+ RuntimeError if we don't have a state group for the event (ie it is an
+ outlier or is unknown)
"""
state_map = await self.get_state_ids_for_events(
[event_id], state_filter or StateFilter.all()
diff --git a/synapse/storage/types.py b/synapse/storage/types.py
index 57f4883b..d7d6f1d9 100644
--- a/synapse/storage/types.py
+++ b/synapse/storage/types.py
@@ -45,6 +45,7 @@ class Cursor(Protocol):
Sequence[
# Note that this is an approximate typing based on sqlite3 and other
# drivers, and may not be entirely accurate.
+ # FWIW, the DBAPI 2 spec is: https://peps.python.org/pep-0249/#description
Tuple[
str,
Optional[Any],
diff --git a/synapse/streams/events.py b/synapse/streams/events.py
index fb8fe172..acf17ba6 100644
--- a/synapse/streams/events.py
+++ b/synapse/streams/events.py
@@ -69,7 +69,7 @@ class EventSources:
)
return token
- def get_current_token_for_pagination(self) -> StreamToken:
+ async def get_current_token_for_pagination(self, room_id: str) -> StreamToken:
"""Get the current token for a given room to be used to paginate
events.
@@ -80,7 +80,7 @@ class EventSources:
The current token for pagination.
"""
token = StreamToken(
- room_key=self.sources.room.get_current_key(),
+ room_key=await self.sources.room.get_current_key_for_room(room_id),
presence_key=0,
typing_key=0,
receipt_key=0,
diff --git a/synapse/types.py b/synapse/types.py
index 5ce2a5b0..9ac688b2 100644
--- a/synapse/types.py
+++ b/synapse/types.py
@@ -25,6 +25,7 @@ from typing import (
Match,
MutableMapping,
Optional,
+ Set,
Tuple,
Type,
TypeVar,
@@ -38,6 +39,7 @@ from typing_extensions import TypedDict
from unpaddedbase64 import decode_base64
from zope.interface import Interface
+from twisted.internet.defer import CancelledError
from twisted.internet.interfaces import (
IReactorCore,
IReactorPluggableNameResolver,
@@ -421,22 +423,44 @@ class RoomStreamToken:
s0 s1
| |
- [0] V [1] V [2]
+ [0] ▼ [1] ▼ [2]
Tokens can either be a point in the live event stream or a cursor going
through historic events.
- When traversing the live event stream events are ordered by when they
- arrived at the homeserver.
+ When traversing the live event stream, events are ordered by
+ `stream_ordering` (when they arrived at the homeserver).
- When traversing historic events the events are ordered by their depth in
- the event graph "topological_ordering" and then by when they arrived at the
- homeserver "stream_ordering".
+ When traversing historic events, events are first ordered by their `depth`
+ (`topological_ordering` in the event graph) and tie-broken by
+ `stream_ordering` (when the event arrived at the homeserver).
- Live tokens start with an "s" followed by the "stream_ordering" id of the
- event it comes after. Historic tokens start with a "t" followed by the
- "topological_ordering" id of the event it comes after, followed by "-",
- followed by the "stream_ordering" id of the event it comes after.
+ If you're looking for more info about what a token with all of the
+ underscores means, ex.
+ `s2633508_17_338_6732159_1082514_541479_274711_265584_1`, see the docstring
+ for `StreamToken` below.
+
+ ---
+
+ Live tokens start with an "s" followed by the `stream_ordering` of the event
+ that comes before the position of the token. Said another way:
+ `stream_ordering` uniquely identifies a persisted event. The live token
+ means "the position just after the event identified by `stream_ordering`".
+ An example token is:
+
+ s2633508
+
+ ---
+
+ Historic tokens start with a "t" followed by the `depth`
+ (`topological_ordering` in the event graph) of the event that comes before
+ the position of the token, followed by "-", followed by the
+ `stream_ordering` of the event that comes before the position of the token.
+ An example token is:
+
+ t426-2633508
+
+ ---
There is also a third mode for live tokens where the token starts with "m",
which is sometimes used when using sharded event persisters. In this case
@@ -463,6 +487,8 @@ class RoomStreamToken:
Note: The `RoomStreamToken` cannot have both a topological part and an
instance map.
+ ---
+
For caching purposes, `RoomStreamToken`s and by extension, all their
attributes, must be hashable.
"""
@@ -515,6 +541,8 @@ class RoomStreamToken:
stream=stream,
instance_map=frozendict(instance_map),
)
+ except CancelledError:
+ raise
except Exception:
pass
raise SynapseError(400, "Invalid room stream token %r" % (string,))
@@ -599,7 +627,57 @@ class RoomStreamToken:
@attr.s(slots=True, frozen=True, auto_attribs=True)
class StreamToken:
- """A collection of positions within multiple streams.
+ """A collection of keys joined together by underscores in the following
+ order and which represent the position in their respective streams.
+
+ ex. `s2633508_17_338_6732159_1082514_541479_274711_265584_1`
+ 1. `room_key`: `s2633508` which is a `RoomStreamToken`
+ - `RoomStreamToken`'s can also look like `t426-2633508` or `m56~2.58~3.59`
+ - See the docstring for `RoomStreamToken` for more details.
+ 2. `presence_key`: `17`
+ 3. `typing_key`: `338`
+ 4. `receipt_key`: `6732159`
+ 5. `account_data_key`: `1082514`
+ 6. `push_rules_key`: `541479`
+ 7. `to_device_key`: `274711`
+ 8. `device_list_key`: `265584`
+ 9. `groups_key`: `1`
+
+ You can see how many of these keys correspond to the various
+ fields in a "/sync" response:
+ ```json
+ {
+ "next_batch": "s12_4_0_1_1_1_1_4_1",
+ "presence": {
+ "events": []
+ },
+ "device_lists": {
+ "changed": []
+ },
+ "rooms": {
+ "join": {
+ "!QrZlfIDQLNLdZHqTnt:hs1": {
+ "timeline": {
+ "events": [],
+ "prev_batch": "s10_4_0_1_1_1_1_4_1",
+ "limited": false
+ },
+ "state": {
+ "events": []
+ },
+ "account_data": {
+ "events": []
+ },
+ "ephemeral": {
+ "events": []
+ }
+ }
+ }
+ }
+ }
+ ```
+
+ ---
For caching purposes, `StreamToken`s and by extension, all their attributes,
must be hashable.
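To make the key ordering above concrete, here is a small sketch of splitting a serialized token into its named parts. It mirrors the documented format only; the real parser is async and validates the room key via RoomStreamToken.parse (shown in the hunk below).

def split_stream_token(token: str) -> dict:
    names = [
        "room_key", "presence_key", "typing_key", "receipt_key",
        "account_data_key", "push_rules_key", "to_device_key",
        "device_list_key", "groups_key",
    ]
    parts = token.split("_")
    # The first part is the RoomStreamToken string; the rest are integers.
    return dict(zip(names, [parts[0]] + [int(p) for p in parts[1:]]))

# split_stream_token("s2633508_17_338_6732159_1082514_541479_274711_265584_1")
# -> {"room_key": "s2633508", "presence_key": 17, ..., "groups_key": 1}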
@@ -630,6 +708,8 @@ class StreamToken:
return cls(
await RoomStreamToken.parse(store, keys[0]), *(int(k) for k in keys[1:])
)
+ except CancelledError:
+ raise
except Exception:
raise SynapseError(400, "Invalid stream token")
@@ -748,6 +828,30 @@ class ReadReceipt:
data: JsonDict
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class DeviceListUpdates:
+ """
+ An object containing a diff of information regarding other users' device lists, intended for
+ a recipient to carry out device list tracking.
+
+ Attributes:
+ changed: A set of users whose device lists have changed recently.
+ left: A set of users who the recipient no longer needs to track the device lists of.
+ Typically when those users no longer share any end-to-end encryption enabled rooms.
+ """
+
+ # We need to use a factory here, otherwise `set` is not evaluated at
+ # object instantiation, but instead at class definition instantiation.
+ # The latter happens only once, so every DeviceListUpdates instance would
+ # end up sharing the same sets.
+ # Also see: don't define mutable default arguments.
+ changed: Set[str] = attr.ib(factory=set)
+ left: Set[str] = attr.ib(factory=set)
+
+ def __bool__(self) -> bool:
+ return bool(self.changed or self.left)
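A small sketch of the pitfall the comment above guards against, using hypothetical attrs classes: a plain mutable default is shared across instances, whereas a factory produces a fresh set per instance.

import attr

@attr.s(slots=True, auto_attribs=True)
class SharedDefault:
    users: set = set()  # the same set object is reused for every instance

@attr.s(slots=True, auto_attribs=True)
class FreshDefault:
    users: set = attr.ib(factory=set)  # a new set is created per instance

a, b = SharedDefault(), SharedDefault()
a.users.add("@alice:example.com")
assert "@alice:example.com" in b.users  # surprising: both share one set

c, d = FreshDefault(), FreshDefault()
c.users.add("@alice:example.com")
assert "@alice:example.com" not in d.users  # independent sets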
+
+
def get_verify_key_from_cross_signing_key(key_info):
"""Get the key ID and signedjson verify key from a cross-signing key dict
diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py
index 6a8e844d..650e44de 100644
--- a/synapse/util/async_helpers.py
+++ b/synapse/util/async_helpers.py
@@ -18,7 +18,7 @@ import collections
import inspect
import itertools
import logging
-from contextlib import asynccontextmanager, contextmanager
+from contextlib import asynccontextmanager
from typing import (
Any,
AsyncIterator,
@@ -29,7 +29,6 @@ from typing import (
Generic,
Hashable,
Iterable,
- Iterator,
List,
Optional,
Set,
@@ -342,7 +341,7 @@ class Linearizer:
Example:
- with await limiter.queue("test_key"):
+ async with limiter.queue("test_key"):
# do some work.
"""
@@ -383,95 +382,53 @@ class Linearizer:
# non-empty.
return bool(entry.deferreds)
- def queue(self, key: Hashable) -> defer.Deferred:
- # we avoid doing defer.inlineCallbacks here, so that cancellation works correctly.
- # (https://twistedmatrix.com/trac/ticket/4632 meant that cancellations were not
- # propagated inside inlineCallbacks until Twisted 18.7)
+ def queue(self, key: Hashable) -> AsyncContextManager[None]:
+ @asynccontextmanager
+ async def _ctx_manager() -> AsyncIterator[None]:
+ entry = await self._acquire_lock(key)
+ try:
+ yield
+ finally:
+ self._release_lock(key, entry)
+
+ return _ctx_manager()
+
+ async def _acquire_lock(self, key: Hashable) -> _LinearizerEntry:
+ """Acquires a linearizer lock, waiting if necessary.
+
+ Returns once we have secured the lock.
+ """
entry = self.key_to_defer.setdefault(
key, _LinearizerEntry(0, collections.OrderedDict())
)
- # If the number of things executing is greater than the maximum
- # then add a deferred to the list of blocked items
- # When one of the things currently executing finishes it will callback
- # this item so that it can continue executing.
- if entry.count >= self.max_count:
- res = self._await_lock(key)
- else:
+ if entry.count < self.max_count:
+ # The number of things executing is less than the maximum.
logger.debug(
"Acquired uncontended linearizer lock %r for key %r", self.name, key
)
entry.count += 1
- res = defer.succeed(None)
-
- # once we successfully get the lock, we need to return a context manager which
- # will release the lock.
-
- @contextmanager
- def _ctx_manager(_: None) -> Iterator[None]:
- try:
- yield
- finally:
- logger.debug("Releasing linearizer lock %r for key %r", self.name, key)
-
- # We've finished executing so check if there are any things
- # blocked waiting to execute and start one of them
- entry.count -= 1
-
- if entry.deferreds:
- (next_def, _) = entry.deferreds.popitem(last=False)
-
- # we need to run the next thing in the sentinel context.
- with PreserveLoggingContext():
- next_def.callback(None)
- elif entry.count == 0:
- # We were the last thing for this key: remove it from the
- # map.
- del self.key_to_defer[key]
-
- res.addCallback(_ctx_manager)
- return res
-
- def _await_lock(self, key: Hashable) -> defer.Deferred:
- """Helper for queue: adds a deferred to the queue
-
- Assumes that we've already checked that we've reached the limit of the number
- of lock-holders we allow. Creates a new deferred which is added to the list, and
- adds some management around cancellations.
-
- Returns the deferred, which will callback once we have secured the lock.
-
- """
- entry = self.key_to_defer[key]
+ return entry
+ # Otherwise, the number of things executing is at the maximum and we have to
+ # add a deferred to the list of blocked items.
+ # When one of the things currently executing finishes it will callback
+ # this item so that it can continue executing.
logger.debug("Waiting to acquire linearizer lock %r for key %r", self.name, key)
new_defer: "defer.Deferred[None]" = make_deferred_yieldable(defer.Deferred())
entry.deferreds[new_defer] = 1
- def cb(_r: None) -> "defer.Deferred[None]":
- logger.debug("Acquired linearizer lock %r for key %r", self.name, key)
- entry.count += 1
-
- # if the code holding the lock completes synchronously, then it
- # will recursively run the next claimant on the list. That can
- # relatively rapidly lead to stack exhaustion. This is essentially
- # the same problem as http://twistedmatrix.com/trac/ticket/9304.
- #
- # In order to break the cycle, we add a cheeky sleep(0) here to
- # ensure that we fall back to the reactor between each iteration.
- #
- # (This needs to happen while we hold the lock, and the context manager's exit
- # code must be synchronous, so this is the only sensible place.)
- return self._clock.sleep(0)
-
- def eb(e: Failure) -> Failure:
+ try:
+ await new_defer
+ except Exception as e:
logger.info("defer %r got err %r", new_defer, e)
if isinstance(e, CancelledError):
logger.debug(
- "Cancelling wait for linearizer lock %r for key %r", self.name, key
+ "Cancelling wait for linearizer lock %r for key %r",
+ self.name,
+ key,
)
-
else:
logger.warning(
"Unexpected exception waiting for linearizer lock %r for key %r",
@@ -481,10 +438,47 @@ class Linearizer:
# we just have to take ourselves back out of the queue.
del entry.deferreds[new_defer]
- return e
+ raise
- new_defer.addCallbacks(cb, eb)
- return new_defer
+ logger.debug("Acquired linearizer lock %r for key %r", self.name, key)
+ entry.count += 1
+
+ # if the code holding the lock completes synchronously, then it
+ # will recursively run the next claimant on the list. That can
+ # relatively rapidly lead to stack exhaustion. This is essentially
+ # the same problem as http://twistedmatrix.com/trac/ticket/9304.
+ #
+ # In order to break the cycle, we add a cheeky sleep(0) here to
+ # ensure that we fall back to the reactor between each iteration.
+ #
+ # This needs to happen while we hold the lock. We could put it on the
+ # exit path, but that would slow down the uncontended case.
+ try:
+ await self._clock.sleep(0)
+ except CancelledError:
+ self._release_lock(key, entry)
+ raise
+
+ return entry
+
+ def _release_lock(self, key: Hashable, entry: _LinearizerEntry) -> None:
+ """Releases a held linearizer lock."""
+ logger.debug("Releasing linearizer lock %r for key %r", self.name, key)
+
+ # We've finished executing so check if there are any things
+ # blocked waiting to execute and start one of them
+ entry.count -= 1
+
+ if entry.deferreds:
+ (next_def, _) = entry.deferreds.popitem(last=False)
+
+ # we need to run the next thing in the sentinel context.
+ with PreserveLoggingContext():
+ next_def.callback(None)
+ elif entry.count == 0:
+ # We were the last thing for this key: remove it from the
+ # map.
+ del self.key_to_defer[key]
class ReadWriteLock:
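A minimal usage sketch of the Linearizer after this refactor (names are placeholders): queue() now returns an async context manager, so call sites use `async with` rather than the old `with await` pattern.

from synapse.util.async_helpers import Linearizer

linearizer = Linearizer(name="example_lock")

async def do_exclusive_work(key: str) -> None:
    async with linearizer.queue(key):
        # At most max_count (default 1) coroutines enter this block per key;
        # the lock is released when the block exits, even on exceptions.
        ...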
diff --git a/synapse/util/caches/__init__.py b/synapse/util/caches/__init__.py
index 1cbc180e..42f6abb5 100644
--- a/synapse/util/caches/__init__.py
+++ b/synapse/util/caches/__init__.py
@@ -17,7 +17,7 @@ import logging
import typing
from enum import Enum, auto
from sys import intern
-from typing import Any, Callable, Dict, List, Optional, Sized
+from typing import Any, Callable, Dict, List, Optional, Sized, TypeVar
import attr
from prometheus_client.core import Gauge
@@ -195,8 +195,10 @@ KNOWN_KEYS = {
)
}
+T = TypeVar("T", Optional[str], str)
-def intern_string(string: Optional[str]) -> Optional[str]:
+
+def intern_string(string: T) -> T:
"""Takes a (potentially) unicode string and interns it if it's ascii"""
if string is None:
return None
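A short illustration of what the constrained TypeVar changes for callers (assuming intern_string is imported from synapse.util.caches): the return type now follows the argument type, so a plain str argument is no longer widened to Optional[str].

from typing import Optional
from synapse.util.caches import intern_string

definitely: str = intern_string("example.com")   # str in -> str out
maybe: Optional[str] = intern_string(None)       # Optional[str] in -> Optional[str] out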
diff --git a/synapse/visibility.py b/synapse/visibility.py
index 49519eb8..250f0735 100644
--- a/synapse/visibility.py
+++ b/synapse/visibility.py
@@ -1,4 +1,5 @@
# Copyright 2014 - 2016 OpenMarket Ltd
+# Copyright (C) The Matrix.org Foundation C.I.C. 2022
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -12,7 +13,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
-from typing import Dict, FrozenSet, List, Optional
+from typing import Collection, Dict, FrozenSet, List, Optional, Tuple
+
+from typing_extensions import Final
from synapse.api.constants import EventTypes, HistoryVisibility, Membership
from synapse.events import EventBase
@@ -40,6 +43,8 @@ MEMBERSHIP_PRIORITY = (
Membership.BAN,
)
+_HISTORY_VIS_KEY: Final[Tuple[str, str]] = (EventTypes.RoomHistoryVisibility, "")
+
async def filter_events_for_client(
storage: Storage,
@@ -74,7 +79,7 @@ async def filter_events_for_client(
# to clients.
events = [e for e in events if not e.internal_metadata.is_soft_failed()]
- types = ((EventTypes.RoomHistoryVisibility, ""), (EventTypes.Member, user_id))
+ types = (_HISTORY_VIS_KEY, (EventTypes.Member, user_id))
# we exclude outliers at this point, and then handle them separately later
event_id_to_state = await storage.state.get_state_for_events(
@@ -157,7 +162,7 @@ async def filter_events_for_client(
state = event_id_to_state[event.event_id]
# get the room_visibility at the time of the event.
- visibility_event = state.get((EventTypes.RoomHistoryVisibility, ""), None)
+ visibility_event = state.get(_HISTORY_VIS_KEY, None)
if visibility_event:
visibility = visibility_event.content.get(
"history_visibility", HistoryVisibility.SHARED
@@ -293,67 +298,28 @@ async def filter_events_for_server(
return True
return False
- def check_event_is_visible(event: EventBase, state: StateMap[EventBase]) -> bool:
- history = state.get((EventTypes.RoomHistoryVisibility, ""), None)
- if history:
- visibility = history.content.get(
- "history_visibility", HistoryVisibility.SHARED
- )
- if visibility in [HistoryVisibility.INVITED, HistoryVisibility.JOINED]:
- # We now loop through all state events looking for
- # membership states for the requesting server to determine
- # if the server is either in the room or has been invited
- # into the room.
- for ev in state.values():
- if ev.type != EventTypes.Member:
- continue
- try:
- domain = get_domain_from_id(ev.state_key)
- except Exception:
- continue
-
- if domain != server_name:
- continue
-
- memtype = ev.membership
- if memtype == Membership.JOIN:
- return True
- elif memtype == Membership.INVITE:
- if visibility == HistoryVisibility.INVITED:
- return True
- else:
- # server has no users in the room: redact
- return False
-
- return True
-
- # Lets check to see if all the events have a history visibility
- # of "shared" or "world_readable". If that's the case then we don't
- # need to check membership (as we know the server is in the room).
- event_to_state_ids = await storage.state.get_state_ids_for_events(
- frozenset(e.event_id for e in events),
- state_filter=StateFilter.from_types(
- types=((EventTypes.RoomHistoryVisibility, ""),)
- ),
- )
-
- visibility_ids = set()
- for sids in event_to_state_ids.values():
- hist = sids.get((EventTypes.RoomHistoryVisibility, ""))
- if hist:
- visibility_ids.add(hist)
+ def check_event_is_visible(
+ visibility: str, memberships: StateMap[EventBase]
+ ) -> bool:
+ if visibility not in (HistoryVisibility.INVITED, HistoryVisibility.JOINED):
+ return True
- # If we failed to find any history visibility events then the default
- # is "shared" visibility.
- if not visibility_ids:
- all_open = True
- else:
- event_map = await storage.main.get_events(visibility_ids)
- all_open = all(
- e.content.get("history_visibility")
- in (None, HistoryVisibility.SHARED, HistoryVisibility.WORLD_READABLE)
- for e in event_map.values()
- )
+ # We now loop through all membership events looking for
+ # membership states for the requesting server to determine
+ # if the server is either in the room or has been invited
+ # into the room.
+ for ev in memberships.values():
+ assert get_domain_from_id(ev.state_key) == server_name
+
+ memtype = ev.membership
+ if memtype == Membership.JOIN:
+ return True
+ elif memtype == Membership.INVITE:
+ if visibility == HistoryVisibility.INVITED:
+ return True
+
+ # server has no users in the room: redact
+ return False
if not check_history_visibility_only:
erased_senders = await storage.main.are_users_erased(e.sender for e in events)
@@ -362,34 +328,100 @@ async def filter_events_for_server(
# to no users having been erased.
erased_senders = {}
- if all_open:
- # all the history_visibility state affecting these events is open, so
- # we don't need to filter by membership state. We *do* need to check
- # for user erasure, though.
- if erased_senders:
- to_return = []
- for e in events:
- if not is_sender_erased(e, erased_senders):
- to_return.append(e)
- elif redact:
- to_return.append(prune_event(e))
-
- return to_return
-
- # If there are no erased users then we can just return the given list
- # of events without having to copy it.
- return events
-
- # Ok, so we're dealing with events that have non-trivial visibility
- # rules, so we need to also get the memberships of the room.
-
- # first, for each event we're wanting to return, get the event_ids
- # of the history vis and membership state at those events.
+ # Let's check to see if all the events have a history visibility
+ # of "shared" or "world_readable". If that's the case then we don't
+ # need to check membership (as we know the server is in the room).
+ event_to_history_vis = await _event_to_history_vis(storage, events)
+
+ # for any with restricted vis, we also need the memberships
+ event_to_memberships = await _event_to_memberships(
+ storage,
+ [
+ e
+ for e in events
+ if event_to_history_vis[e.event_id]
+ not in (HistoryVisibility.SHARED, HistoryVisibility.WORLD_READABLE)
+ ],
+ server_name,
+ )
+
+ to_return = []
+ for e in events:
+ erased = is_sender_erased(e, erased_senders)
+ visible = check_event_is_visible(
+ event_to_history_vis[e.event_id], event_to_memberships.get(e.event_id, {})
+ )
+ if visible and not erased:
+ to_return.append(e)
+ elif redact:
+ to_return.append(prune_event(e))
+
+ return to_return
+
+
+async def _event_to_history_vis(
+ storage: Storage, events: Collection[EventBase]
+) -> Dict[str, str]:
+ """Get the history visibility at each of the given events
+
+ Returns a map from event id to history_visibility setting
+ """
+
+ # outliers get special treatment here. We don't have the state at that point in the
+ # room (and attempting to look it up will raise an exception), so all we can really
+ # do is assume that the requesting server is allowed to see the event. That's
+ # equivalent to there not being a history_visibility event, so we just exclude
+ # any outliers from the query.
+ event_to_state_ids = await storage.state.get_state_ids_for_events(
+ frozenset(e.event_id for e in events if not e.internal_metadata.is_outlier()),
+ state_filter=StateFilter.from_types(types=(_HISTORY_VIS_KEY,)),
+ )
+
+ visibility_ids = {
+ vis_event_id
+ for vis_event_id in (
+ state_ids.get(_HISTORY_VIS_KEY) for state_ids in event_to_state_ids.values()
+ )
+ if vis_event_id
+ }
+ vis_events = await storage.main.get_events(visibility_ids)
+
+ result: Dict[str, str] = {}
+ for event in events:
+ vis = HistoryVisibility.SHARED
+ state_ids = event_to_state_ids.get(event.event_id)
+
+ # if we didn't find any state for this event, it's an outlier, and we assume
+ # it's open
+ visibility_id = None
+ if state_ids:
+ visibility_id = state_ids.get(_HISTORY_VIS_KEY)
+
+ if visibility_id:
+ vis_event = vis_events[visibility_id]
+ vis = vis_event.content.get("history_visibility", HistoryVisibility.SHARED)
+ assert isinstance(vis, str)
+
+ result[event.event_id] = vis
+ return result
+
+
+async def _event_to_memberships(
+ storage: Storage, events: Collection[EventBase], server_name: str
+) -> Dict[str, StateMap[EventBase]]:
+ """Get the remote membership list at each of the given events
+
+ Returns a map from event id to state map, which will contain only membership events
+ for the given server.
+ """
+
+ if not events:
+ return {}
+
+ # for each event, get the event_ids of the membership state at those events.
event_to_state_ids = await storage.state.get_state_ids_for_events(
frozenset(e.event_id for e in events),
- state_filter=StateFilter.from_types(
- types=((EventTypes.RoomHistoryVisibility, ""), (EventTypes.Member, None))
- ),
+ state_filter=StateFilter.from_types(types=((EventTypes.Member, None),)),
)
# We only want to pull out member events that correspond to the
@@ -405,10 +437,7 @@ async def filter_events_for_server(
for key, event_id in key_to_eid.items()
}
- def include(typ, state_key):
- if typ != EventTypes.Member:
- return True
-
+ def include(state_key: str) -> bool:
# we avoid using get_domain_from_id here for efficiency.
idx = state_key.find(":")
if idx == -1:
@@ -416,10 +445,14 @@ async def filter_events_for_server(
return state_key[idx + 1 :] == server_name
event_map = await storage.main.get_events(
- [e_id for e_id, key in event_id_to_state_key.items() if include(key[0], key[1])]
+ [
+ e_id
+ for e_id, (_, state_key) in event_id_to_state_key.items()
+ if include(state_key)
+ ]
)
- event_to_state = {
+ return {
e_id: {
key: event_map[inner_e_id]
for key, inner_e_id in key_to_eid.items()
@@ -427,14 +460,3 @@ async def filter_events_for_server(
}
for e_id, key_to_eid in event_to_state_ids.items()
}
-
- to_return = []
- for e in events:
- erased = is_sender_erased(e, erased_senders)
- visible = check_event_is_visible(e, event_to_state[e.event_id])
- if visible and not erased:
- to_return.append(e)
- elif redact:
- to_return.append(prune_event(e))
-
- return to_return