summaryrefslogtreecommitdiff
path: root/synapse
diff options
context:
space:
mode:
authorAndrej Shadura <andrewsh@debian.org>2021-10-19 19:02:19 +0200
committerAndrej Shadura <andrewsh@debian.org>2021-10-19 19:02:19 +0200
commit94d2082531bf10c3cdf17b4e8fde9ca1a6c9de40 (patch)
tree8a96d1eb4c266243e10504a968fd49cb780df9d4 /synapse
parent6b06932344e635f554420698ecd1954e31d0c6ea (diff)
New upstream version 1.45.0
Diffstat (limited to 'synapse')
-rw-r--r--synapse/__init__.py2
-rw-r--r--synapse/api/filtering.py117
-rw-r--r--synapse/api/ratelimiting.py86
-rw-r--r--synapse/app/_base.py10
-rw-r--r--synapse/app/admin_cmd.py8
-rw-r--r--synapse/app/generic_worker.py2
-rw-r--r--synapse/app/homeserver.py16
-rw-r--r--synapse/app/phone_stats_home.py8
-rw-r--r--synapse/config/_base.py64
-rw-r--r--synapse/config/account_validity.py2
-rw-r--r--synapse/config/cas.py2
-rw-r--r--synapse/config/emailconfig.py9
-rw-r--r--synapse/config/key.py6
-rw-r--r--synapse/config/oidc.py2
-rw-r--r--synapse/config/registration.py7
-rw-r--r--synapse/config/repository.py2
-rw-r--r--synapse/config/saml2.py2
-rw-r--r--synapse/config/server.py104
-rw-r--r--synapse/config/server_notices.py4
-rw-r--r--synapse/config/sso.py6
-rw-r--r--synapse/config/tls.py9
-rw-r--r--synapse/event_auth.py156
-rw-r--r--synapse/events/builder.py20
-rw-r--r--synapse/events/presence_router.py6
-rw-r--r--synapse/events/spamcheck.py59
-rw-r--r--synapse/events/third_party_rules.py9
-rw-r--r--synapse/events/utils.py2
-rw-r--r--synapse/federation/federation_server.py5
-rw-r--r--synapse/federation/transport/server/__init__.py2
-rw-r--r--synapse/handlers/_base.py120
-rw-r--r--synapse/handlers/account_validity.py8
-rw-r--r--synapse/handlers/admin.py7
-rw-r--r--synapse/handlers/auth.py10
-rw-r--r--synapse/handlers/deactivate_account.py10
-rw-r--r--synapse/handlers/device.py10
-rw-r--r--synapse/handlers/directory.py11
-rw-r--r--synapse/handlers/event_auth.py15
-rw-r--r--synapse/handlers/events.py12
-rw-r--r--synapse/handlers/federation.py76
-rw-r--r--synapse/handlers/federation_event.py169
-rw-r--r--synapse/handlers/identity.py22
-rw-r--r--synapse/handlers/initial_sync.py8
-rw-r--r--synapse/handlers/message.py86
-rw-r--r--synapse/handlers/pagination.py22
-rw-r--r--synapse/handlers/profile.py17
-rw-r--r--synapse/handlers/read_marker.py5
-rw-r--r--synapse/handlers/receipts.py6
-rw-r--r--synapse/handlers/register.py24
-rw-r--r--synapse/handlers/room.py25
-rw-r--r--synapse/handlers/room_batch.py423
-rw-r--r--synapse/handlers/room_list.py7
-rw-r--r--synapse/handlers/room_member.py74
-rw-r--r--synapse/handlers/saml.py7
-rw-r--r--synapse/handlers/search.py11
-rw-r--r--synapse/handlers/send_email.py9
-rw-r--r--synapse/handlers/set_password.py6
-rw-r--r--synapse/handlers/ui_auth/checkers.py14
-rw-r--r--synapse/handlers/user_directory.py108
-rw-r--r--synapse/http/client.py2
-rw-r--r--synapse/http/matrixfederationclient.py10
-rw-r--r--synapse/http/server.py5
-rw-r--r--synapse/logging/_terse_json.py6
-rw-r--r--synapse/logging/context.py16
-rw-r--r--synapse/logging/opentracing.py9
-rw-r--r--synapse/metrics/background_process_metrics.py2
-rw-r--r--synapse/push/__init__.py2
-rw-r--r--synapse/push/bulk_push_rule_evaluator.py20
-rw-r--r--synapse/push/clientformat.py4
-rw-r--r--synapse/push/httppusher.py4
-rw-r--r--synapse/push/mailer.py2
-rw-r--r--synapse/replication/http/_base.py154
-rw-r--r--synapse/replication/slave/storage/_slaved_id_tracker.py4
-rw-r--r--synapse/replication/slave/storage/pushers.py10
-rw-r--r--synapse/replication/tcp/client.py2
-rw-r--r--synapse/replication/tcp/handler.py15
-rw-r--r--synapse/replication/tcp/redis.py8
-rw-r--r--synapse/replication/tcp/resource.py2
-rw-r--r--synapse/rest/admin/users.py4
-rw-r--r--synapse/rest/client/account.py40
-rw-r--r--synapse/rest/client/auth.py8
-rw-r--r--synapse/rest/client/capabilities.py10
-rw-r--r--synapse/rest/client/filter.py2
-rw-r--r--synapse/rest/client/login.py6
-rw-r--r--synapse/rest/client/profile.py6
-rw-r--r--synapse/rest/client/push_rule.py4
-rw-r--r--synapse/rest/client/register.py32
-rw-r--r--synapse/rest/client/room.py2
-rw-r--r--synapse/rest/client/room_batch.py337
-rw-r--r--synapse/rest/client/shared_rooms.py2
-rw-r--r--synapse/rest/client/sync.py2
-rw-r--r--synapse/rest/client/voip.py2
-rw-r--r--synapse/rest/media/v1/__init__.py38
-rw-r--r--synapse/rest/media/v1/oembed.py28
-rw-r--r--synapse/rest/media/v1/preview_url_resource.py186
-rw-r--r--synapse/rest/media/v1/thumbnailer.py21
-rw-r--r--synapse/rest/well_known.py4
-rw-r--r--synapse/server.py16
-rw-r--r--synapse/server_notices/resource_limits_server_notices.py8
-rw-r--r--synapse/server_notices/server_notices_manager.py8
-rw-r--r--synapse/state/__init__.py2
-rw-r--r--synapse/state/v1.py12
-rw-r--r--synapse/state/v2.py6
-rw-r--r--synapse/storage/databases/main/censor_events.py8
-rw-r--r--synapse/storage/databases/main/client_ips.py15
-rw-r--r--synapse/storage/databases/main/events.py34
-rw-r--r--synapse/storage/databases/main/filtering.py8
-rw-r--r--synapse/storage/databases/main/monthly_active_users.py36
-rw-r--r--synapse/storage/databases/main/push_rule.py8
-rw-r--r--synapse/storage/databases/main/pusher.py10
-rw-r--r--synapse/storage/databases/main/registration.py17
-rw-r--r--synapse/storage/databases/main/room.py8
-rw-r--r--synapse/storage/databases/main/room_batch.py6
-rw-r--r--synapse/storage/databases/main/search.py4
-rw-r--r--synapse/storage/databases/main/user_directory.py101
-rw-r--r--synapse/storage/prepare_database.py6
-rw-r--r--synapse/storage/schema/__init__.py6
-rw-r--r--synapse/storage/state.py172
-rw-r--r--synapse/storage/util/id_generators.py225
-rw-r--r--synapse/storage/util/sequence.py6
-rw-r--r--synapse/util/__init__.py11
-rw-r--r--synapse/util/async_helpers.py6
-rw-r--r--synapse/util/caches/cached_call.py2
-rw-r--r--synapse/util/caches/deferred_cache.py11
-rw-r--r--synapse/util/caches/lrucache.py57
-rw-r--r--synapse/util/caches/response_cache.py6
-rw-r--r--synapse/util/caches/stream_change_cache.py6
-rw-r--r--synapse/util/caches/ttlcache.py12
-rw-r--r--synapse/util/daemonize.py8
-rw-r--r--synapse/util/metrics.py27
-rw-r--r--synapse/util/patch_inline_callbacks.py28
-rw-r--r--synapse/util/threepids.py4
-rw-r--r--synapse/util/versionstring.py25
132 files changed, 2443 insertions, 1542 deletions
diff --git a/synapse/__init__.py b/synapse/__init__.py
index b8979c36..97452f34 100644
--- a/synapse/__init__.py
+++ b/synapse/__init__.py
@@ -47,7 +47,7 @@ try:
except ImportError:
pass
-__version__ = "1.44.0"
+__version__ = "1.45.0"
if bool(os.environ.get("SYNAPSE_TEST_PATCH_LOG_CONTEXTS", False)):
# We import here so that we don't have to install a bunch of deps when
diff --git a/synapse/api/filtering.py b/synapse/api/filtering.py
index ad1ff6a9..20e91a11 100644
--- a/synapse/api/filtering.py
+++ b/synapse/api/filtering.py
@@ -15,7 +15,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import json
-from typing import List
+from typing import (
+ TYPE_CHECKING,
+ Awaitable,
+ Container,
+ Iterable,
+ List,
+ Optional,
+ Set,
+ TypeVar,
+ Union,
+)
import jsonschema
from jsonschema import FormatChecker
@@ -23,7 +33,11 @@ from jsonschema import FormatChecker
from synapse.api.constants import EventContentFields
from synapse.api.errors import SynapseError
from synapse.api.presence import UserPresenceState
-from synapse.types import RoomID, UserID
+from synapse.events import EventBase
+from synapse.types import JsonDict, RoomID, UserID
+
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
FILTER_SCHEMA = {
"additionalProperties": False,
@@ -120,25 +134,29 @@ USER_FILTER_SCHEMA = {
@FormatChecker.cls_checks("matrix_room_id")
-def matrix_room_id_validator(room_id_str):
+def matrix_room_id_validator(room_id_str: str) -> RoomID:
return RoomID.from_string(room_id_str)
@FormatChecker.cls_checks("matrix_user_id")
-def matrix_user_id_validator(user_id_str):
+def matrix_user_id_validator(user_id_str: str) -> UserID:
return UserID.from_string(user_id_str)
class Filtering:
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.store = hs.get_datastore()
- async def get_user_filter(self, user_localpart, filter_id):
+ async def get_user_filter(
+ self, user_localpart: str, filter_id: Union[int, str]
+ ) -> "FilterCollection":
result = await self.store.get_user_filter(user_localpart, filter_id)
return FilterCollection(result)
- def add_user_filter(self, user_localpart, user_filter):
+ def add_user_filter(
+ self, user_localpart: str, user_filter: JsonDict
+ ) -> Awaitable[int]:
self.check_valid_filter(user_filter)
return self.store.add_user_filter(user_localpart, user_filter)
@@ -146,13 +164,13 @@ class Filtering:
# replace_user_filter at some point? There's no REST API specified for
# them however
- def check_valid_filter(self, user_filter_json):
+ def check_valid_filter(self, user_filter_json: JsonDict) -> None:
"""Check if the provided filter is valid.
This inspects all definitions contained within the filter.
Args:
- user_filter_json(dict): The filter
+ user_filter_json: The filter
Raises:
SynapseError: If the filter is not valid.
"""
@@ -167,8 +185,12 @@ class Filtering:
raise SynapseError(400, str(e))
+# Filters work across events, presence EDUs, and account data.
+FilterEvent = TypeVar("FilterEvent", EventBase, UserPresenceState, JsonDict)
+
+
class FilterCollection:
- def __init__(self, filter_json):
+ def __init__(self, filter_json: JsonDict):
self._filter_json = filter_json
room_filter_json = self._filter_json.get("room", {})
@@ -188,25 +210,25 @@ class FilterCollection:
self.event_fields = filter_json.get("event_fields", [])
self.event_format = filter_json.get("event_format", "client")
- def __repr__(self):
+ def __repr__(self) -> str:
return "<FilterCollection %s>" % (json.dumps(self._filter_json),)
- def get_filter_json(self):
+ def get_filter_json(self) -> JsonDict:
return self._filter_json
- def timeline_limit(self):
+ def timeline_limit(self) -> int:
return self._room_timeline_filter.limit()
- def presence_limit(self):
+ def presence_limit(self) -> int:
return self._presence_filter.limit()
- def ephemeral_limit(self):
+ def ephemeral_limit(self) -> int:
return self._room_ephemeral_filter.limit()
- def lazy_load_members(self):
+ def lazy_load_members(self) -> bool:
return self._room_state_filter.lazy_load_members()
- def include_redundant_members(self):
+ def include_redundant_members(self) -> bool:
return self._room_state_filter.include_redundant_members()
def filter_presence(self, events):
@@ -218,29 +240,31 @@ class FilterCollection:
def filter_room_state(self, events):
return self._room_state_filter.filter(self._room_filter.filter(events))
- def filter_room_timeline(self, events):
+ def filter_room_timeline(self, events: Iterable[FilterEvent]) -> List[FilterEvent]:
return self._room_timeline_filter.filter(self._room_filter.filter(events))
- def filter_room_ephemeral(self, events):
+ def filter_room_ephemeral(self, events: Iterable[FilterEvent]) -> List[FilterEvent]:
return self._room_ephemeral_filter.filter(self._room_filter.filter(events))
- def filter_room_account_data(self, events):
+ def filter_room_account_data(
+ self, events: Iterable[FilterEvent]
+ ) -> List[FilterEvent]:
return self._room_account_data.filter(self._room_filter.filter(events))
- def blocks_all_presence(self):
+ def blocks_all_presence(self) -> bool:
return (
self._presence_filter.filters_all_types()
or self._presence_filter.filters_all_senders()
)
- def blocks_all_room_ephemeral(self):
+ def blocks_all_room_ephemeral(self) -> bool:
return (
self._room_ephemeral_filter.filters_all_types()
or self._room_ephemeral_filter.filters_all_senders()
or self._room_ephemeral_filter.filters_all_rooms()
)
- def blocks_all_room_timeline(self):
+ def blocks_all_room_timeline(self) -> bool:
return (
self._room_timeline_filter.filters_all_types()
or self._room_timeline_filter.filters_all_senders()
@@ -249,7 +273,7 @@ class FilterCollection:
class Filter:
- def __init__(self, filter_json):
+ def __init__(self, filter_json: JsonDict):
self.filter_json = filter_json
self.types = self.filter_json.get("types", None)
@@ -266,20 +290,20 @@ class Filter:
self.labels = self.filter_json.get("org.matrix.labels", None)
self.not_labels = self.filter_json.get("org.matrix.not_labels", [])
- def filters_all_types(self):
+ def filters_all_types(self) -> bool:
return "*" in self.not_types
- def filters_all_senders(self):
+ def filters_all_senders(self) -> bool:
return "*" in self.not_senders
- def filters_all_rooms(self):
+ def filters_all_rooms(self) -> bool:
return "*" in self.not_rooms
- def check(self, event):
+ def check(self, event: FilterEvent) -> bool:
"""Checks whether the filter matches the given event.
Returns:
- bool: True if the event matches
+ True if the event matches
"""
# We usually get the full "events" as dictionaries coming through,
# except for presence which actually gets passed around as its own
@@ -305,18 +329,25 @@ class Filter:
room_id = event.get("room_id", None)
ev_type = event.get("type", None)
- content = event.get("content", {})
+ content = event.get("content") or {}
# check if there is a string url field in the content for filtering purposes
contains_url = isinstance(content.get("url"), str)
labels = content.get(EventContentFields.LABELS, [])
return self.check_fields(room_id, sender, ev_type, labels, contains_url)
- def check_fields(self, room_id, sender, event_type, labels, contains_url):
+ def check_fields(
+ self,
+ room_id: Optional[str],
+ sender: Optional[str],
+ event_type: Optional[str],
+ labels: Container[str],
+ contains_url: bool,
+ ) -> bool:
"""Checks whether the filter matches the given event fields.
Returns:
- bool: True if the event fields match
+ True if the event fields match
"""
literal_keys = {
"rooms": lambda v: room_id == v,
@@ -343,14 +374,14 @@ class Filter:
return True
- def filter_rooms(self, room_ids):
+ def filter_rooms(self, room_ids: Iterable[str]) -> Set[str]:
"""Apply the 'rooms' filter to a given list of rooms.
Args:
- room_ids (list): A list of room_ids.
+ room_ids: A list of room_ids.
Returns:
- list: A list of room_ids that match the filter
+ A list of room_ids that match the filter
"""
room_ids = set(room_ids)
@@ -363,23 +394,23 @@ class Filter:
return room_ids
- def filter(self, events):
+ def filter(self, events: Iterable[FilterEvent]) -> List[FilterEvent]:
return list(filter(self.check, events))
- def limit(self):
+ def limit(self) -> int:
return self.filter_json.get("limit", 10)
- def lazy_load_members(self):
+ def lazy_load_members(self) -> bool:
return self.filter_json.get("lazy_load_members", False)
- def include_redundant_members(self):
+ def include_redundant_members(self) -> bool:
return self.filter_json.get("include_redundant_members", False)
- def with_room_ids(self, room_ids):
+ def with_room_ids(self, room_ids: Iterable[str]) -> "Filter":
"""Returns a new filter with the given room IDs appended.
Args:
- room_ids (iterable[unicode]): The room_ids to add
+ room_ids: The room_ids to add
Returns:
filter: A new filter including the given rooms and the old
@@ -390,8 +421,8 @@ class Filter:
return newFilter
-def _matches_wildcard(actual_value, filter_value):
- if filter_value.endswith("*"):
+def _matches_wildcard(actual_value: Optional[str], filter_value: str) -> bool:
+ if filter_value.endswith("*") and isinstance(actual_value, str):
type_prefix = filter_value[:-1]
return actual_value.startswith(type_prefix)
else:
diff --git a/synapse/api/ratelimiting.py b/synapse/api/ratelimiting.py
index cbdd7402..e8964097 100644
--- a/synapse/api/ratelimiting.py
+++ b/synapse/api/ratelimiting.py
@@ -17,6 +17,7 @@ from collections import OrderedDict
from typing import Hashable, Optional, Tuple
from synapse.api.errors import LimitExceededError
+from synapse.config.ratelimiting import RateLimitConfig
from synapse.storage.databases.main import DataStore
from synapse.types import Requester
from synapse.util import Clock
@@ -233,3 +234,88 @@ class Ratelimiter:
raise LimitExceededError(
retry_after_ms=int(1000 * (time_allowed - time_now_s))
)
+
+
+class RequestRatelimiter:
+ def __init__(
+ self,
+ store: DataStore,
+ clock: Clock,
+ rc_message: RateLimitConfig,
+ rc_admin_redaction: Optional[RateLimitConfig],
+ ):
+ self.store = store
+ self.clock = clock
+
+ # The rate_hz and burst_count are overridden on a per-user basis
+ self.request_ratelimiter = Ratelimiter(
+ store=self.store, clock=self.clock, rate_hz=0, burst_count=0
+ )
+ self._rc_message = rc_message
+
+ # Check whether ratelimiting room admin message redaction is enabled
+ # by the presence of rate limits in the config
+ if rc_admin_redaction:
+ self.admin_redaction_ratelimiter: Optional[Ratelimiter] = Ratelimiter(
+ store=self.store,
+ clock=self.clock,
+ rate_hz=rc_admin_redaction.per_second,
+ burst_count=rc_admin_redaction.burst_count,
+ )
+ else:
+ self.admin_redaction_ratelimiter = None
+
+ async def ratelimit(
+ self,
+ requester: Requester,
+ update: bool = True,
+ is_admin_redaction: bool = False,
+ ) -> None:
+ """Ratelimits requests.
+
+ Args:
+ requester
+ update: Whether to record that a request is being processed.
+ Set to False when doing multiple checks for one request (e.g.
+ to check up front if we would reject the request), and set to
+ True for the last call for a given request.
+ is_admin_redaction: Whether this is a room admin/moderator
+ redacting an event. If so then we may apply different
+ ratelimits depending on config.
+
+ Raises:
+ LimitExceededError if the request should be ratelimited
+ """
+ user_id = requester.user.to_string()
+
+ # The AS user itself is never rate limited.
+ app_service = self.store.get_app_service_by_user_id(user_id)
+ if app_service is not None:
+ return # do not ratelimit app service senders
+
+ messages_per_second = self._rc_message.per_second
+ burst_count = self._rc_message.burst_count
+
+ # Check if there is a per user override in the DB.
+ override = await self.store.get_ratelimit_for_user(user_id)
+ if override:
+ # If overridden with a null Hz then ratelimiting has been entirely
+ # disabled for the user
+ if not override.messages_per_second:
+ return
+
+ messages_per_second = override.messages_per_second
+ burst_count = override.burst_count
+
+ if is_admin_redaction and self.admin_redaction_ratelimiter:
+ # If we have separate config for admin redactions, use a separate
+ # ratelimiter as to not have user_ids clash
+ await self.admin_redaction_ratelimiter.ratelimit(requester, update=update)
+ else:
+ # Override rate and burst count per-user
+ await self.request_ratelimiter.ratelimit(
+ requester,
+ rate_hz=messages_per_second,
+ burst_count=burst_count,
+ update=update,
+ )
diff --git a/synapse/app/_base.py b/synapse/app/_base.py
index 548f6dcd..4a204a58 100644
--- a/synapse/app/_base.py
+++ b/synapse/app/_base.py
@@ -86,11 +86,11 @@ def start_worker_reactor(appname, config, run_command=reactor.run):
start_reactor(
appname,
- soft_file_limit=config.soft_file_limit,
- gc_thresholds=config.gc_thresholds,
+ soft_file_limit=config.server.soft_file_limit,
+ gc_thresholds=config.server.gc_thresholds,
pid_file=config.worker.worker_pid_file,
daemonize=config.worker.worker_daemonize,
- print_pidfile=config.print_pidfile,
+ print_pidfile=config.server.print_pidfile,
logger=logger,
run_command=run_command,
)
@@ -298,10 +298,10 @@ def refresh_certificate(hs):
Refresh the TLS certificates that Synapse is using by re-reading them from
disk and updating the TLS context factories to use them.
"""
- if not hs.config.has_tls_listener():
+ if not hs.config.server.has_tls_listener():
return
- hs.config.read_certificate_from_disk()
+ hs.config.tls.read_certificate_from_disk()
hs.tls_server_context_factory = context_factory.ServerContextFactory(hs.config)
if hs._listening_services:
diff --git a/synapse/app/admin_cmd.py b/synapse/app/admin_cmd.py
index f2c5b752..13d20af4 100644
--- a/synapse/app/admin_cmd.py
+++ b/synapse/app/admin_cmd.py
@@ -195,14 +195,14 @@ def start(config_options):
config.logging.no_redirect_stdio = True
# Explicitly disable background processes
- config.update_user_directory = False
+ config.server.update_user_directory = False
config.worker.run_background_tasks = False
- config.start_pushers = False
+ config.worker.start_pushers = False
config.pusher_shard_config.instances = []
- config.send_federation = False
+ config.worker.send_federation = False
config.federation_shard_config.instances = []
- synapse.events.USE_FROZEN_DICTS = config.use_frozen_dicts
+ synapse.events.USE_FROZEN_DICTS = config.server.use_frozen_dicts
ss = AdminCmdServer(
config.server.server_name,
diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py
index 3036e1b4..7489f31d 100644
--- a/synapse/app/generic_worker.py
+++ b/synapse/app/generic_worker.py
@@ -462,7 +462,7 @@ def start(config_options):
# For other worker types we force this to off.
config.server.update_user_directory = False
- synapse.events.USE_FROZEN_DICTS = config.use_frozen_dicts
+ synapse.events.USE_FROZEN_DICTS = config.server.use_frozen_dicts
synapse.util.caches.TRACK_MEMORY_USAGE = config.caches.track_memory_usage
if config.server.gc_seconds:
diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py
index 205831dc..422f03cc 100644
--- a/synapse/app/homeserver.py
+++ b/synapse/app/homeserver.py
@@ -234,7 +234,7 @@ class SynapseHomeServer(HomeServer):
)
if name in ["media", "federation", "client"]:
- if self.config.media.enable_media_repo:
+ if self.config.server.enable_media_repo:
media_repo = self.get_media_repository_resource()
resources.update(
{MEDIA_PREFIX: media_repo, LEGACY_MEDIA_PREFIX: media_repo}
@@ -248,7 +248,7 @@ class SynapseHomeServer(HomeServer):
resources[SERVER_KEY_V2_PREFIX] = KeyApiV2Resource(self)
if name == "webclient":
- webclient_loc = self.config.web_client_location
+ webclient_loc = self.config.server.web_client_location
if webclient_loc is None:
logger.warning(
@@ -343,7 +343,7 @@ def setup(config_options):
# generating config files and shouldn't try to continue.
sys.exit(0)
- events.USE_FROZEN_DICTS = config.use_frozen_dicts
+ events.USE_FROZEN_DICTS = config.server.use_frozen_dicts
synapse.util.caches.TRACK_MEMORY_USAGE = config.caches.track_memory_usage
if config.server.gc_seconds:
@@ -439,11 +439,11 @@ def run(hs):
_base.start_reactor(
"synapse-homeserver",
- soft_file_limit=hs.config.soft_file_limit,
- gc_thresholds=hs.config.gc_thresholds,
- pid_file=hs.config.pid_file,
- daemonize=hs.config.daemonize,
- print_pidfile=hs.config.print_pidfile,
+ soft_file_limit=hs.config.server.soft_file_limit,
+ gc_thresholds=hs.config.server.gc_thresholds,
+ pid_file=hs.config.server.pid_file,
+ daemonize=hs.config.server.daemonize,
+ print_pidfile=hs.config.server.print_pidfile,
logger=logger,
)
diff --git a/synapse/app/phone_stats_home.py b/synapse/app/phone_stats_home.py
index 49e7a45e..fcd01e83 100644
--- a/synapse/app/phone_stats_home.py
+++ b/synapse/app/phone_stats_home.py
@@ -74,7 +74,7 @@ async def phone_stats_home(hs, stats, stats_process=_stats_process):
store = hs.get_datastore()
stats["homeserver"] = hs.config.server.server_name
- stats["server_context"] = hs.config.server_context
+ stats["server_context"] = hs.config.server.server_context
stats["timestamp"] = now
stats["uptime_seconds"] = uptime
version = sys.version_info
@@ -171,7 +171,7 @@ def start_phone_stats_home(hs):
current_mau_count_by_service = {}
reserved_users = ()
store = hs.get_datastore()
- if hs.config.limit_usage_by_mau or hs.config.mau_stats_only:
+ if hs.config.server.limit_usage_by_mau or hs.config.server.mau_stats_only:
current_mau_count = await store.get_monthly_active_count()
current_mau_count_by_service = (
await store.get_monthly_active_count_by_service()
@@ -183,9 +183,9 @@ def start_phone_stats_home(hs):
current_mau_by_service_gauge.labels(app_service).set(float(count))
registered_reserved_users_mau_gauge.set(float(len(reserved_users)))
- max_mau_gauge.set(float(hs.config.max_mau_value))
+ max_mau_gauge.set(float(hs.config.server.max_mau_value))
- if hs.config.limit_usage_by_mau or hs.config.mau_stats_only:
+ if hs.config.server.limit_usage_by_mau or hs.config.server.mau_stats_only:
generate_monthly_active_users()
clock.looping_call(generate_monthly_active_users, 5 * 60 * 1000)
# End of monthly active user settings
diff --git a/synapse/config/_base.py b/synapse/config/_base.py
index d974a1a2..7c4428a1 100644
--- a/synapse/config/_base.py
+++ b/synapse/config/_base.py
@@ -118,21 +118,6 @@ class Config:
"synapse", "res/templates"
)
- def __getattr__(self, item: str) -> Any:
- """
- Try and fetch a configuration option that does not exist on this class.
-
- This is so that existing configs that rely on `self.value`, where value
- is actually from a different config section, continue to work.
- """
- if item in ["generate_config_section", "read_config"]:
- raise AttributeError(item)
-
- if self.root is None:
- raise AttributeError(item)
- else:
- return self.root._get_unclassed_config(self.section, item)
-
@staticmethod
def parse_size(value):
if isinstance(value, int):
@@ -289,7 +274,9 @@ class Config:
env.filters.update(
{
"format_ts": _format_ts_filter,
- "mxc_to_http": _create_mxc_to_http_filter(self.public_baseurl),
+ "mxc_to_http": _create_mxc_to_http_filter(
+ self.root.server.public_baseurl
+ ),
}
)
@@ -311,8 +298,6 @@ class RootConfig:
config_classes = []
def __init__(self):
- self._configs = OrderedDict()
-
for config_class in self.config_classes:
if config_class.section is None:
raise ValueError("%r requires a section name" % (config_class,))
@@ -321,42 +306,7 @@ class RootConfig:
conf = config_class(self)
except Exception as e:
raise Exception("Failed making %s: %r" % (config_class.section, e))
- self._configs[config_class.section] = conf
-
- def __getattr__(self, item: str) -> Any:
- """
- Redirect lookups on this object either to config objects, or values on
- config objects, so that `config.tls.blah` works, as well as legacy uses
- of things like `config.server_name`. It will first look up the config
- section name, and then values on those config classes.
- """
- if item in self._configs.keys():
- return self._configs[item]
-
- return self._get_unclassed_config(None, item)
-
- def _get_unclassed_config(self, asking_section: Optional[str], item: str):
- """
- Fetch a config value from one of the instantiated config classes that
- has not been fetched directly.
-
- Args:
- asking_section: If this check is coming from a Config child, which
- one? This section will not be asked if it has the value.
- item: The configuration value key.
-
- Raises:
- AttributeError if no config classes have the config key. The body
- will contain what sections were checked.
- """
- for key, val in self._configs.items():
- if key == asking_section:
- continue
-
- if item in dir(val):
- return getattr(val, item)
-
- raise AttributeError(item, "not found in %s" % (list(self._configs.keys()),))
+ setattr(self, config_class.section, conf)
def invoke_all(self, func_name: str, *args, **kwargs) -> MutableMapping[str, Any]:
"""
@@ -373,9 +323,11 @@ class RootConfig:
"""
res = OrderedDict()
- for name, config in self._configs.items():
+ for config_class in self.config_classes:
+ config = getattr(self, config_class.section)
+
if hasattr(config, func_name):
- res[name] = getattr(config, func_name)(*args, **kwargs)
+ res[config_class.section] = getattr(config, func_name)(*args, **kwargs)
return res
diff --git a/synapse/config/account_validity.py b/synapse/config/account_validity.py
index ffaffc49..b56c2a24 100644
--- a/synapse/config/account_validity.py
+++ b/synapse/config/account_validity.py
@@ -76,7 +76,7 @@ class AccountValidityConfig(Config):
)
if self.account_validity_renew_by_email_enabled:
- if not self.public_baseurl:
+ if not self.root.server.public_baseurl:
raise ConfigError("Can't send renewal emails without 'public_baseurl'")
# Load account validity templates.
diff --git a/synapse/config/cas.py b/synapse/config/cas.py
index 901f4123..9b58ecf3 100644
--- a/synapse/config/cas.py
+++ b/synapse/config/cas.py
@@ -37,7 +37,7 @@ class CasConfig(Config):
# The public baseurl is required because it is used by the redirect
# template.
- public_baseurl = self.public_baseurl
+ public_baseurl = self.root.server.public_baseurl
if not public_baseurl:
raise ConfigError("cas_config requires a public_baseurl to be set")
diff --git a/synapse/config/emailconfig.py b/synapse/config/emailconfig.py
index 936abe61..8ff59aa2 100644
--- a/synapse/config/emailconfig.py
+++ b/synapse/config/emailconfig.py
@@ -19,7 +19,6 @@ import email.utils
import logging
import os
from enum import Enum
-from typing import Optional
import attr
@@ -135,7 +134,7 @@ class EmailConfig(Config):
# msisdn is currently always remote while Synapse does not support any method of
# sending SMS messages
ThreepidBehaviour.REMOTE
- if self.account_threepid_delegate_email
+ if self.root.registration.account_threepid_delegate_email
else ThreepidBehaviour.LOCAL
)
# Prior to Synapse v1.4.0, there was another option that defined whether Synapse would
@@ -144,7 +143,7 @@ class EmailConfig(Config):
# identity server in the process.
self.using_identity_server_from_trusted_list = False
if (
- not self.account_threepid_delegate_email
+ not self.root.registration.account_threepid_delegate_email
and config.get("trust_identity_server_for_password_resets", False) is True
):
# Use the first entry in self.trusted_third_party_id_servers instead
@@ -156,7 +155,7 @@ class EmailConfig(Config):
# trusted_third_party_id_servers does not contain a scheme whereas
# account_threepid_delegate_email is expected to. Presume https
- self.account_threepid_delegate_email: Optional[str] = (
+ self.root.registration.account_threepid_delegate_email = (
"https://" + first_trusted_identity_server
)
self.using_identity_server_from_trusted_list = True
@@ -335,7 +334,7 @@ class EmailConfig(Config):
"client_base_url", email_config.get("riot_base_url", None)
)
- if self.account_validity_renew_by_email_enabled:
+ if self.root.account_validity.account_validity_renew_by_email_enabled:
expiry_template_html = email_config.get(
"expiry_template_html", "notice_expiry.html"
)
diff --git a/synapse/config/key.py b/synapse/config/key.py
index 94a90630..015dbb8a 100644
--- a/synapse/config/key.py
+++ b/synapse/config/key.py
@@ -145,11 +145,13 @@ class KeyConfig(Config):
# list of TrustedKeyServer objects
self.key_servers = list(
- _parse_key_servers(key_servers, self.federation_verify_certificates)
+ _parse_key_servers(
+ key_servers, self.root.tls.federation_verify_certificates
+ )
)
self.macaroon_secret_key = config.get(
- "macaroon_secret_key", self.registration_shared_secret
+ "macaroon_secret_key", self.root.registration.registration_shared_secret
)
if not self.macaroon_secret_key:
diff --git a/synapse/config/oidc.py b/synapse/config/oidc.py
index 7e67fbad..10f57963 100644
--- a/synapse/config/oidc.py
+++ b/synapse/config/oidc.py
@@ -58,7 +58,7 @@ class OIDCConfig(Config):
"Multiple OIDC providers have the idp_id %r." % idp_id
)
- public_baseurl = self.public_baseurl
+ public_baseurl = self.root.server.public_baseurl
if public_baseurl is None:
raise ConfigError("oidc_config requires a public_baseurl to be set")
self.oidc_callback_url = public_baseurl + "_synapse/client/oidc/callback"
diff --git a/synapse/config/registration.py b/synapse/config/registration.py
index 7cffdacf..a3d2a38c 100644
--- a/synapse/config/registration.py
+++ b/synapse/config/registration.py
@@ -45,7 +45,10 @@ class RegistrationConfig(Config):
account_threepid_delegates = config.get("account_threepid_delegates") or {}
self.account_threepid_delegate_email = account_threepid_delegates.get("email")
self.account_threepid_delegate_msisdn = account_threepid_delegates.get("msisdn")
- if self.account_threepid_delegate_msisdn and not self.public_baseurl:
+ if (
+ self.account_threepid_delegate_msisdn
+ and not self.root.server.public_baseurl
+ ):
raise ConfigError(
"The configuration option `public_baseurl` is required if "
"`account_threepid_delegate.msisdn` is set, such that "
@@ -85,7 +88,7 @@ class RegistrationConfig(Config):
if mxid_localpart:
# Convert the localpart to a full mxid.
self.auto_join_user_id = UserID(
- mxid_localpart, self.server_name
+ mxid_localpart, self.root.server.server_name
).to_string()
if self.autocreate_auto_join_rooms:
diff --git a/synapse/config/repository.py b/synapse/config/repository.py
index 7481f3bf..69906a98 100644
--- a/synapse/config/repository.py
+++ b/synapse/config/repository.py
@@ -94,7 +94,7 @@ class ContentRepositoryConfig(Config):
# Only enable the media repo if either the media repo is enabled or the
# current worker app is the media repo.
if (
- self.enable_media_repo is False
+ self.root.server.enable_media_repo is False
and config.get("worker_app") != "synapse.app.media_repository"
):
self.can_load_media_repo = False
diff --git a/synapse/config/saml2.py b/synapse/config/saml2.py
index 05e98362..9c51b6a2 100644
--- a/synapse/config/saml2.py
+++ b/synapse/config/saml2.py
@@ -199,7 +199,7 @@ class SAML2Config(Config):
"""
import saml2
- public_baseurl = self.public_baseurl
+ public_baseurl = self.root.server.public_baseurl
if public_baseurl is None:
raise ConfigError("saml2_config requires a public_baseurl to be set")
diff --git a/synapse/config/server.py b/synapse/config/server.py
index ad8715da..818b8063 100644
--- a/synapse/config/server.py
+++ b/synapse/config/server.py
@@ -1,6 +1,4 @@
-# Copyright 2014-2016 OpenMarket Ltd
-# Copyright 2017-2018 New Vector Ltd
-# Copyright 2019 The Matrix.org Foundation C.I.C.
+# Copyright 2014-2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -19,7 +17,7 @@ import logging
import os.path
import re
from textwrap import indent
-from typing import Any, Dict, Iterable, List, Optional, Set, Tuple
+from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union
import attr
import yaml
@@ -184,49 +182,74 @@ KNOWN_RESOURCES = {
@attr.s(frozen=True)
class HttpResourceConfig:
- names = attr.ib(
- type=List[str],
+ names: List[str] = attr.ib(
factory=list,
validator=attr.validators.deep_iterable(attr.validators.in_(KNOWN_RESOURCES)), # type: ignore
)
- compress = attr.ib(
- type=bool,
+ compress: bool = attr.ib(
default=False,
validator=attr.validators.optional(attr.validators.instance_of(bool)), # type: ignore[arg-type]
)
-@attr.s(frozen=True)
+@attr.s(slots=True, frozen=True, auto_attribs=True)
class HttpListenerConfig:
"""Object describing the http-specific parts of the config of a listener"""
- x_forwarded = attr.ib(type=bool, default=False)
- resources = attr.ib(type=List[HttpResourceConfig], factory=list)
- additional_resources = attr.ib(type=Dict[str, dict], factory=dict)
- tag = attr.ib(type=str, default=None)
+ x_forwarded: bool = False
+ resources: List[HttpResourceConfig] = attr.ib(factory=list)
+ additional_resources: Dict[str, dict] = attr.ib(factory=dict)
+ tag: Optional[str] = None
-@attr.s(frozen=True)
+@attr.s(slots=True, frozen=True, auto_attribs=True)
class ListenerConfig:
"""Object describing the configuration of a single listener."""
- port = attr.ib(type=int, validator=attr.validators.instance_of(int))
- bind_addresses = attr.ib(type=List[str])
- type = attr.ib(type=str, validator=attr.validators.in_(KNOWN_LISTENER_TYPES))
- tls = attr.ib(type=bool, default=False)
+ port: int = attr.ib(validator=attr.validators.instance_of(int))
+ bind_addresses: List[str]
+ type: str = attr.ib(validator=attr.validators.in_(KNOWN_LISTENER_TYPES))
+ tls: bool = False
# http_options is only populated if type=http
- http_options = attr.ib(type=Optional[HttpListenerConfig], default=None)
+ http_options: Optional[HttpListenerConfig] = None
-@attr.s(frozen=True)
+@attr.s(slots=True, frozen=True, auto_attribs=True)
class ManholeConfig:
"""Object describing the configuration of the manhole"""
- username = attr.ib(type=str, validator=attr.validators.instance_of(str))
- password = attr.ib(type=str, validator=attr.validators.instance_of(str))
- priv_key = attr.ib(type=Optional[Key])
- pub_key = attr.ib(type=Optional[Key])
+ username: str = attr.ib(validator=attr.validators.instance_of(str))
+ password: str = attr.ib(validator=attr.validators.instance_of(str))
+ priv_key: Optional[Key]
+ pub_key: Optional[Key]
+
+
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class RetentionConfig:
+ """Object describing the configuration of the retention purge jobs"""
+
+ interval: int
+ shortest_max_lifetime: Optional[int]
+ longest_max_lifetime: Optional[int]
+
+
+@attr.s(frozen=True)
+class LimitRemoteRoomsConfig:
+ enabled: bool = attr.ib(validator=attr.validators.instance_of(bool), default=False)
+ complexity: Union[float, int] = attr.ib(
+ validator=attr.validators.instance_of(
+ (float, int) # type: ignore[arg-type] # noqa
+ ),
+ default=1.0,
+ )
+ complexity_error: str = attr.ib(
+ validator=attr.validators.instance_of(str),
+ default=ROOM_COMPLEXITY_TOO_GREAT,
+ )
+ admins_can_join: bool = attr.ib(
+ validator=attr.validators.instance_of(bool), default=False
+ )
class ServerConfig(Config):
@@ -519,7 +542,7 @@ class ServerConfig(Config):
" greater than 'allowed_lifetime_max'"
)
- self.retention_purge_jobs: List[Dict[str, Optional[int]]] = []
+ self.retention_purge_jobs: List[RetentionConfig] = []
for purge_job_config in retention_config.get("purge_jobs", []):
interval_config = purge_job_config.get("interval")
@@ -553,20 +576,12 @@ class ServerConfig(Config):
)
self.retention_purge_jobs.append(
- {
- "interval": interval,
- "shortest_max_lifetime": shortest_max_lifetime,
- "longest_max_lifetime": longest_max_lifetime,
- }
+ RetentionConfig(interval, shortest_max_lifetime, longest_max_lifetime)
)
if not self.retention_purge_jobs:
self.retention_purge_jobs = [
- {
- "interval": self.parse_duration("1d"),
- "shortest_max_lifetime": None,
- "longest_max_lifetime": None,
- }
+ RetentionConfig(self.parse_duration("1d"), None, None)
]
self.listeners = [parse_listener_def(x) for x in config.get("listeners", [])]
@@ -591,25 +606,6 @@ class ServerConfig(Config):
self.gc_thresholds = read_gc_thresholds(config.get("gc_thresholds", None))
self.gc_seconds = self.read_gc_intervals(config.get("gc_min_interval", None))
- @attr.s
- class LimitRemoteRoomsConfig:
- enabled = attr.ib(
- validator=attr.validators.instance_of(bool), default=False
- )
- complexity = attr.ib(
- validator=attr.validators.instance_of(
- (float, int) # type: ignore[arg-type] # noqa
- ),
- default=1.0,
- )
- complexity_error = attr.ib(
- validator=attr.validators.instance_of(str),
- default=ROOM_COMPLEXITY_TOO_GREAT,
- )
- admins_can_join = attr.ib(
- validator=attr.validators.instance_of(bool), default=False
- )
-
self.limit_remote_rooms = LimitRemoteRoomsConfig(
**(config.get("limit_remote_rooms") or {})
)
diff --git a/synapse/config/server_notices.py b/synapse/config/server_notices.py
index 48bf3241..bde4e879 100644
--- a/synapse/config/server_notices.py
+++ b/synapse/config/server_notices.py
@@ -73,7 +73,9 @@ class ServerNoticesConfig(Config):
return
mxid_localpart = c["system_mxid_localpart"]
- self.server_notices_mxid = UserID(mxid_localpart, self.server_name).to_string()
+ self.server_notices_mxid = UserID(
+ mxid_localpart, self.root.server.server_name
+ ).to_string()
self.server_notices_mxid_display_name = c.get("system_mxid_display_name", None)
self.server_notices_mxid_avatar_url = c.get("system_mxid_avatar_url", None)
# todo: i18n
diff --git a/synapse/config/sso.py b/synapse/config/sso.py
index 524a7ff3..11a9b76a 100644
--- a/synapse/config/sso.py
+++ b/synapse/config/sso.py
@@ -103,8 +103,10 @@ class SSOConfig(Config):
# the client's.
# public_baseurl is an optional setting, so we only add the fallback's URL to the
# list if it's provided (because we can't figure out what that URL is otherwise).
- if self.public_baseurl:
- login_fallback_url = self.public_baseurl + "_matrix/static/client/login"
+ if self.root.server.public_baseurl:
+ login_fallback_url = (
+ self.root.server.public_baseurl + "_matrix/static/client/login"
+ )
self.sso_client_whitelist.append(login_fallback_url)
def generate_config_section(self, **kwargs):
diff --git a/synapse/config/tls.py b/synapse/config/tls.py
index 5679f05e..6227434b 100644
--- a/synapse/config/tls.py
+++ b/synapse/config/tls.py
@@ -172,9 +172,12 @@ class TlsConfig(Config):
)
# YYYYMMDDhhmmssZ -- in UTC
- expires_on = datetime.strptime(
- tls_certificate.get_notAfter().decode("ascii"), "%Y%m%d%H%M%SZ"
- )
+ expiry_data = tls_certificate.get_notAfter()
+ if expiry_data is None:
+ raise ValueError(
+ "TLS Certificate has no expiry date, and this is not permitted"
+ )
+ expires_on = datetime.strptime(expiry_data.decode("ascii"), "%Y%m%d%H%M%SZ")
now = datetime.utcnow()
days_remaining = (expires_on - now).days
return days_remaining
diff --git a/synapse/event_auth.py b/synapse/event_auth.py
index 65040283..ca0293a3 100644
--- a/synapse/event_auth.py
+++ b/synapse/event_auth.py
@@ -41,42 +41,112 @@ from synapse.types import StateMap, UserID, get_domain_from_id
logger = logging.getLogger(__name__)
-def check(
- room_version_obj: RoomVersion,
- event: EventBase,
- auth_events: StateMap[EventBase],
- do_sig_check: bool = True,
- do_size_check: bool = True,
+def validate_event_for_room_version(
+ room_version_obj: RoomVersion, event: EventBase
) -> None:
- """Checks if this event is correctly authed.
+ """Ensure that the event complies with the limits, and has the right signatures
+
+ NB: does not *validate* the signatures - it assumes that any signatures present
+ have already been checked.
+
+ NB: it does not check that the event satisfies the auth rules (that is done in
+ check_auth_rules_for_event) - these tests are independent of the rest of the state
+ in the room.
+
+ NB: This is used to check events that have been received over federation. As such,
+ it can only enforce the checks specified in the relevant room version, to avoid
+ a split-brain situation where some servers accept such events, and others reject
+ them.
+
+ TODO: consider moving this into EventValidator
Args:
- room_version_obj: the version of the room
- event: the event being checked.
- auth_events: the existing room state.
- do_sig_check: True if it should be verified that the sending server
- signed the event.
- do_size_check: True if the size of the event fields should be verified.
+ room_version_obj: the version of the room which contains this event
+ event: the event to be checked
Raises:
- AuthError if the checks fail
-
- Returns:
- if the auth checks pass.
+ SynapseError if there is a problem with the event
"""
- assert isinstance(auth_events, dict)
-
- if do_size_check:
- _check_size_limits(event)
+ _check_size_limits(event)
if not hasattr(event, "room_id"):
raise AuthError(500, "Event has no room_id: %s" % event)
- room_id = event.room_id
+ # check that the event has the correct signatures
+ sender_domain = get_domain_from_id(event.sender)
+
+ is_invite_via_3pid = (
+ event.type == EventTypes.Member
+ and event.membership == Membership.INVITE
+ and "third_party_invite" in event.content
+ )
+
+ # Check the sender's domain has signed the event
+ if not event.signatures.get(sender_domain):
+ # We allow invites via 3pid to have a sender from a different
+ # HS, as the sender must match the sender of the original
+ # 3pid invite. This is checked further down with the
+ # other dedicated membership checks.
+ if not is_invite_via_3pid:
+ raise AuthError(403, "Event not signed by sender's server")
+
+ if event.format_version in (EventFormatVersions.V1,):
+ # Only older room versions have event IDs to check.
+ event_id_domain = get_domain_from_id(event.event_id)
+
+ # Check the origin domain has signed the event
+ if not event.signatures.get(event_id_domain):
+ raise AuthError(403, "Event not signed by sending server")
+
+ is_invite_via_allow_rule = (
+ room_version_obj.msc3083_join_rules
+ and event.type == EventTypes.Member
+ and event.membership == Membership.JOIN
+ and EventContentFields.AUTHORISING_USER in event.content
+ )
+ if is_invite_via_allow_rule:
+ authoriser_domain = get_domain_from_id(
+ event.content[EventContentFields.AUTHORISING_USER]
+ )
+ if not event.signatures.get(authoriser_domain):
+ raise AuthError(403, "Event not signed by authorising server")
+
+
+def check_auth_rules_for_event(
+ room_version_obj: RoomVersion, event: EventBase, auth_events: StateMap[EventBase]
+) -> None:
+ """Check that an event complies with the auth rules
+
+ Checks whether an event passes the auth rules with a given set of state events
+
+ Assumes that we have already checked that the event is the right shape (it has
+ enough signatures, has a room ID, etc). In other words:
+
+ - it's fine for use in state resolution, when we have already decided whether to
+ accept the event or not, and are now trying to decide whether it should make it
+ into the room state
+
+ - when we're doing the initial event auth, it is only suitable in combination with
+ a bunch of other tests.
+
+ Args:
+ room_version_obj: the version of the room
+ event: the event being checked.
+ auth_events: the room state to check the events against.
+
+ Raises:
+ AuthError if the checks fail
+ """
+ assert isinstance(auth_events, dict)
# We need to ensure that the auth events are actually for the same room, to
# stop people from using powers they've been granted in other rooms for
# example.
+ #
+ # Arguably we don't need to do this when we're just doing state res, as presumably
+ # the state res algorithm isn't silly enough to give us events from different rooms.
+ # Still, it's easier to do it anyway.
+ room_id = event.room_id
for auth_event in auth_events.values():
if auth_event.room_id != room_id:
raise AuthError(
@@ -85,44 +155,12 @@ def check(
"which is in room %s"
% (event.event_id, room_id, auth_event.event_id, auth_event.room_id),
)
-
- if do_sig_check:
- sender_domain = get_domain_from_id(event.sender)
-
- is_invite_via_3pid = (
- event.type == EventTypes.Member
- and event.membership == Membership.INVITE
- and "third_party_invite" in event.content
- )
-
- # Check the sender's domain has signed the event
- if not event.signatures.get(sender_domain):
- # We allow invites via 3pid to have a sender from a different
- # HS, as the sender must match the sender of the original
- # 3pid invite. This is checked further down with the
- # other dedicated membership checks.
- if not is_invite_via_3pid:
- raise AuthError(403, "Event not signed by sender's server")
-
- if event.format_version in (EventFormatVersions.V1,):
- # Only older room versions have event IDs to check.
- event_id_domain = get_domain_from_id(event.event_id)
-
- # Check the origin domain has signed the event
- if not event.signatures.get(event_id_domain):
- raise AuthError(403, "Event not signed by sending server")
-
- is_invite_via_allow_rule = (
- event.type == EventTypes.Member
- and event.membership == Membership.JOIN
- and EventContentFields.AUTHORISING_USER in event.content
- )
- if is_invite_via_allow_rule:
- authoriser_domain = get_domain_from_id(
- event.content[EventContentFields.AUTHORISING_USER]
+ if auth_event.rejected_reason:
+ raise AuthError(
+ 403,
+ "During auth for event %s: found rejected event %s in the state"
+ % (event.event_id, auth_event.event_id),
)
- if not event.signatures.get(authoriser_domain):
- raise AuthError(403, "Event not signed by authorising server")
# Implementation of https://matrix.org/docs/spec/rooms/v1#authorization-rules
#
diff --git a/synapse/events/builder.py b/synapse/events/builder.py
index 87e2bb12..50f2a4c1 100644
--- a/synapse/events/builder.py
+++ b/synapse/events/builder.py
@@ -18,10 +18,8 @@ import attr
from nacl.signing import SigningKey
from synapse.api.constants import MAX_DEPTH
-from synapse.api.errors import UnsupportedRoomVersionError
from synapse.api.room_versions import (
KNOWN_EVENT_FORMAT_VERSIONS,
- KNOWN_ROOM_VERSIONS,
EventFormatVersions,
RoomVersion,
)
@@ -197,24 +195,6 @@ class EventBuilderFactory:
self.state = hs.get_state_handler()
self._event_auth_handler = hs.get_event_auth_handler()
- def new(self, room_version: str, key_values: dict) -> EventBuilder:
- """Generate an event builder appropriate for the given room version
-
- Deprecated: use for_room_version with a RoomVersion object instead
-
- Args:
- room_version: Version of the room that we're creating an event builder for
- key_values: Fields used as the basis of the new event
-
- Returns:
- EventBuilder
- """
- v = KNOWN_ROOM_VERSIONS.get(room_version)
- if not v:
- # this can happen if support is withdrawn for a room version
- raise UnsupportedRoomVersionError()
- return self.for_room_version(v, key_values)
-
def for_room_version(
self, room_version: RoomVersion, key_values: dict
) -> EventBuilder:
diff --git a/synapse/events/presence_router.py b/synapse/events/presence_router.py
index eb4556cd..68b8b190 100644
--- a/synapse/events/presence_router.py
+++ b/synapse/events/presence_router.py
@@ -45,11 +45,11 @@ def load_legacy_presence_router(hs: "HomeServer"):
configuration, and registers the hooks they implement.
"""
- if hs.config.presence_router_module_class is None:
+ if hs.config.server.presence_router_module_class is None:
return
- module = hs.config.presence_router_module_class
- config = hs.config.presence_router_config
+ module = hs.config.server.presence_router_module_class
+ config = hs.config.server.presence_router_config
api = hs.get_module_api()
presence_router = module(config=config, module_api=api)
diff --git a/synapse/events/spamcheck.py b/synapse/events/spamcheck.py
index c389f70b..ae4c8ab2 100644
--- a/synapse/events/spamcheck.py
+++ b/synapse/events/spamcheck.py
@@ -44,7 +44,9 @@ CHECK_EVENT_FOR_SPAM_CALLBACK = Callable[
["synapse.events.EventBase"],
Awaitable[Union[bool, str]],
]
+USER_MAY_JOIN_ROOM_CALLBACK = Callable[[str, str, bool], Awaitable[bool]]
USER_MAY_INVITE_CALLBACK = Callable[[str, str, str], Awaitable[bool]]
+USER_MAY_SEND_3PID_INVITE_CALLBACK = Callable[[str, str, str, str], Awaitable[bool]]
USER_MAY_CREATE_ROOM_CALLBACK = Callable[[str], Awaitable[bool]]
USER_MAY_CREATE_ROOM_WITH_INVITES_CALLBACK = Callable[
[str, List[str], List[Dict[str, str]]], Awaitable[bool]
@@ -165,7 +167,11 @@ def load_legacy_spam_checkers(hs: "synapse.server.HomeServer"):
class SpamChecker:
def __init__(self):
self._check_event_for_spam_callbacks: List[CHECK_EVENT_FOR_SPAM_CALLBACK] = []
+ self._user_may_join_room_callbacks: List[USER_MAY_JOIN_ROOM_CALLBACK] = []
self._user_may_invite_callbacks: List[USER_MAY_INVITE_CALLBACK] = []
+ self._user_may_send_3pid_invite_callbacks: List[
+ USER_MAY_SEND_3PID_INVITE_CALLBACK
+ ] = []
self._user_may_create_room_callbacks: List[USER_MAY_CREATE_ROOM_CALLBACK] = []
self._user_may_create_room_with_invites_callbacks: List[
USER_MAY_CREATE_ROOM_WITH_INVITES_CALLBACK
@@ -187,7 +193,9 @@ class SpamChecker:
def register_callbacks(
self,
check_event_for_spam: Optional[CHECK_EVENT_FOR_SPAM_CALLBACK] = None,
+ user_may_join_room: Optional[USER_MAY_JOIN_ROOM_CALLBACK] = None,
user_may_invite: Optional[USER_MAY_INVITE_CALLBACK] = None,
+ user_may_send_3pid_invite: Optional[USER_MAY_SEND_3PID_INVITE_CALLBACK] = None,
user_may_create_room: Optional[USER_MAY_CREATE_ROOM_CALLBACK] = None,
user_may_create_room_with_invites: Optional[
USER_MAY_CREATE_ROOM_WITH_INVITES_CALLBACK
@@ -206,9 +214,17 @@ class SpamChecker:
if check_event_for_spam is not None:
self._check_event_for_spam_callbacks.append(check_event_for_spam)
+ if user_may_join_room is not None:
+ self._user_may_join_room_callbacks.append(user_may_join_room)
+
if user_may_invite is not None:
self._user_may_invite_callbacks.append(user_may_invite)
+ if user_may_send_3pid_invite is not None:
+ self._user_may_send_3pid_invite_callbacks.append(
+ user_may_send_3pid_invite,
+ )
+
if user_may_create_room is not None:
self._user_may_create_room_callbacks.append(user_may_create_room)
@@ -259,6 +275,24 @@ class SpamChecker:
return False
+ async def user_may_join_room(self, user_id: str, room_id: str, is_invited: bool):
+ """Checks if a given user is allowed to join a room.
+ Not called when a user creates a room.
+
+ Args:
+ user_id: The ID of the user wanting to join the room
+ room_id: The ID of the room the user wants to join
+ is_invited: Whether the user is invited into the room
+
+ Returns:
+ bool: Whether the user may join the room
+ """
+ for callback in self._user_may_join_room_callbacks:
+ if await callback(user_id, room_id, is_invited) is False:
+ return False
+
+ return True
+
async def user_may_invite(
self, inviter_userid: str, invitee_userid: str, room_id: str
) -> bool:
@@ -280,6 +314,31 @@ class SpamChecker:
return True
+ async def user_may_send_3pid_invite(
+ self, inviter_userid: str, medium: str, address: str, room_id: str
+ ) -> bool:
+ """Checks if a given user may invite a given threepid into the room
+
+ If this method returns false, the threepid invite will be rejected.
+
+ Note that if the threepid is already associated with a Matrix user ID, Synapse
+ will call user_may_invite with said user ID instead.
+
+ Args:
+ inviter_userid: The user ID of the sender of the invitation
+ medium: The 3PID's medium (e.g. "email")
+ address: The 3PID's address (e.g. "alice@example.com")
+ room_id: The room ID
+
+ Returns:
+ True if the user may send the invite, otherwise False
+ """
+ for callback in self._user_may_send_3pid_invite_callbacks:
+ if await callback(inviter_userid, medium, address, room_id) is False:
+ return False
+
+ return True
+
async def user_may_create_room(self, userid: str) -> bool:
"""Checks if a given user may create a room
diff --git a/synapse/events/third_party_rules.py b/synapse/events/third_party_rules.py
index d94b1bb4..976d9fa4 100644
--- a/synapse/events/third_party_rules.py
+++ b/synapse/events/third_party_rules.py
@@ -217,6 +217,15 @@ class ThirdPartyEventRules:
for callback in self._check_event_allowed_callbacks:
try:
res, replacement_data = await callback(event, state_events)
+ except SynapseError as e:
+ # FIXME: Being able to throw SynapseErrors is relied upon by
+ # some modules. PR #10386 accidentally broke this ability.
+ # That said, we aren't keen on exposing this implementation detail
+ # to modules and we should one day have a proper way to do what
+ # is wanted.
+ # This module callback needs a rework so that hacks such as
+ # this one are not necessary.
+ raise e
except Exception as e:
logger.warning("Failed to run module API callback %s: %s", callback, e)
continue
diff --git a/synapse/events/utils.py b/synapse/events/utils.py
index 38fccd1e..520edbbf 100644
--- a/synapse/events/utils.py
+++ b/synapse/events/utils.py
@@ -372,7 +372,7 @@ class EventClientSerializer:
def __init__(self, hs):
self.store = hs.get_datastore()
self.experimental_msc1849_support_enabled = (
- hs.config.experimental_msc1849_support_enabled
+ hs.config.server.experimental_msc1849_support_enabled
)
async def serialize_event(
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index 5f4383ee..d8c0b86f 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -1008,7 +1008,10 @@ class FederationServer(FederationBase):
async with lock:
logger.info("handling received PDU: %s", event)
try:
- await self._federation_event_handler.on_receive_pdu(origin, event)
+ with nested_logging_context(event.event_id):
+ await self._federation_event_handler.on_receive_pdu(
+ origin, event
+ )
except FederationError as e:
# XXX: Ideally we'd inform the remote we failed to process
# the event, but we can't return an error in the transaction
diff --git a/synapse/federation/transport/server/__init__.py b/synapse/federation/transport/server/__init__.py
index 95176ba6..c32539bf 100644
--- a/synapse/federation/transport/server/__init__.py
+++ b/synapse/federation/transport/server/__init__.py
@@ -117,7 +117,7 @@ class PublicRoomList(BaseFederationServlet):
):
super().__init__(hs, authenticator, ratelimiter, server_name)
self.handler = hs.get_room_list_handler()
- self.allow_access = hs.config.allow_public_rooms_over_federation
+ self.allow_access = hs.config.server.allow_public_rooms_over_federation
async def on_GET(
self, origin: str, content: Literal[None], query: Dict[bytes, List[bytes]]
diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py
deleted file mode 100644
index 0ccef884..00000000
--- a/synapse/handlers/_base.py
+++ /dev/null
@@ -1,120 +0,0 @@
-# Copyright 2014 - 2016 OpenMarket Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import logging
-from typing import TYPE_CHECKING, Optional
-
-from synapse.api.ratelimiting import Ratelimiter
-from synapse.types import Requester
-
-if TYPE_CHECKING:
- from synapse.server import HomeServer
-
-logger = logging.getLogger(__name__)
-
-
-class BaseHandler:
- """
- Common base class for the event handlers.
-
- Deprecated: new code should not use this. Instead, Handler classes should define the
- fields they actually need. The utility methods should either be factored out to
- standalone helper functions, or to different Handler classes.
- """
-
- def __init__(self, hs: "HomeServer"):
- self.store = hs.get_datastore()
- self.auth = hs.get_auth()
- self.notifier = hs.get_notifier()
- self.state_handler = hs.get_state_handler()
- self.distributor = hs.get_distributor()
- self.clock = hs.get_clock()
- self.hs = hs
-
- # The rate_hz and burst_count are overridden on a per-user basis
- self.request_ratelimiter = Ratelimiter(
- store=self.store, clock=self.clock, rate_hz=0, burst_count=0
- )
- self._rc_message = self.hs.config.ratelimiting.rc_message
-
- # Check whether ratelimiting room admin message redaction is enabled
- # by the presence of rate limits in the config
- if self.hs.config.ratelimiting.rc_admin_redaction:
- self.admin_redaction_ratelimiter: Optional[Ratelimiter] = Ratelimiter(
- store=self.store,
- clock=self.clock,
- rate_hz=self.hs.config.ratelimiting.rc_admin_redaction.per_second,
- burst_count=self.hs.config.ratelimiting.rc_admin_redaction.burst_count,
- )
- else:
- self.admin_redaction_ratelimiter = None
-
- self.server_name = hs.hostname
-
- self.event_builder_factory = hs.get_event_builder_factory()
-
- async def ratelimit(
- self,
- requester: Requester,
- update: bool = True,
- is_admin_redaction: bool = False,
- ) -> None:
- """Ratelimits requests.
-
- Args:
- requester
- update: Whether to record that a request is being processed.
- Set to False when doing multiple checks for one request (e.g.
- to check up front if we would reject the request), and set to
- True for the last call for a given request.
- is_admin_redaction: Whether this is a room admin/moderator
- redacting an event. If so then we may apply different
- ratelimits depending on config.
-
- Raises:
- LimitExceededError if the request should be ratelimited
- """
- user_id = requester.user.to_string()
-
- # The AS user itself is never rate limited.
- app_service = self.store.get_app_service_by_user_id(user_id)
- if app_service is not None:
- return # do not ratelimit app service senders
-
- messages_per_second = self._rc_message.per_second
- burst_count = self._rc_message.burst_count
-
- # Check if there is a per user override in the DB.
- override = await self.store.get_ratelimit_for_user(user_id)
- if override:
- # If overridden with a null Hz then ratelimiting has been entirely
- # disabled for the user
- if not override.messages_per_second:
- return
-
- messages_per_second = override.messages_per_second
- burst_count = override.burst_count
-
- if is_admin_redaction and self.admin_redaction_ratelimiter:
- # If we have separate config for admin redactions, use a separate
- # ratelimiter as to not have user_ids clash
- await self.admin_redaction_ratelimiter.ratelimit(requester, update=update)
- else:
- # Override rate and burst count per-user
- await self.request_ratelimiter.ratelimit(
- requester,
- rate_hz=messages_per_second,
- burst_count=burst_count,
- update=update,
- )
diff --git a/synapse/handlers/account_validity.py b/synapse/handlers/account_validity.py
index 5a5f124d..87e415df 100644
--- a/synapse/handlers/account_validity.py
+++ b/synapse/handlers/account_validity.py
@@ -67,12 +67,8 @@ class AccountValidityHandler:
and self._account_validity_renew_by_email_enabled
):
# Don't do email-specific configuration if renewal by email is disabled.
- self._template_html = (
- hs.config.account_validity.account_validity_template_html
- )
- self._template_text = (
- hs.config.account_validity.account_validity_template_text
- )
+ self._template_html = hs.config.email.account_validity_template_html
+ self._template_text = hs.config.email.account_validity_template_text
self._renew_email_subject = (
hs.config.account_validity.account_validity_renew_email_subject
)
diff --git a/synapse/handlers/admin.py b/synapse/handlers/admin.py
index bfa7f2c5..a53cd62d 100644
--- a/synapse/handlers/admin.py
+++ b/synapse/handlers/admin.py
@@ -21,18 +21,15 @@ from synapse.events import EventBase
from synapse.types import JsonDict, RoomStreamToken, StateMap, UserID
from synapse.visibility import filter_events_for_client
-from ._base import BaseHandler
-
if TYPE_CHECKING:
from synapse.server import HomeServer
logger = logging.getLogger(__name__)
-class AdminHandler(BaseHandler):
+class AdminHandler:
def __init__(self, hs: "HomeServer"):
- super().__init__(hs)
-
+ self.store = hs.get_datastore()
self.storage = hs.get_storage()
self.state_store = self.storage.state
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index a8c717ef..f4612a5b 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -52,7 +52,6 @@ from synapse.api.errors import (
UserDeactivatedError,
)
from synapse.api.ratelimiting import Ratelimiter
-from synapse.handlers._base import BaseHandler
from synapse.handlers.ui_auth import (
INTERACTIVE_AUTH_CHECKERS,
UIAuthSessionDataConstants,
@@ -186,19 +185,20 @@ class LoginTokenAttributes:
auth_provider_id = attr.ib(type=str)
-class AuthHandler(BaseHandler):
+class AuthHandler:
SESSION_EXPIRE_MS = 48 * 60 * 60 * 1000
def __init__(self, hs: "HomeServer"):
- super().__init__(hs)
-
+ self.store = hs.get_datastore()
+ self.auth = hs.get_auth()
+ self.clock = hs.get_clock()
self.checkers: Dict[str, UserInteractiveAuthChecker] = {}
for auth_checker_class in INTERACTIVE_AUTH_CHECKERS:
inst = auth_checker_class(hs)
if inst.is_enabled():
self.checkers[inst.AUTH_TYPE] = inst # type: ignore
- self.bcrypt_rounds = hs.config.bcrypt_rounds
+ self.bcrypt_rounds = hs.config.registration.bcrypt_rounds
# we can't use hs.get_module_api() here, because to do so will create an
# import loop.
diff --git a/synapse/handlers/deactivate_account.py b/synapse/handlers/deactivate_account.py
index 9ae5b775..e88c3c27 100644
--- a/synapse/handlers/deactivate_account.py
+++ b/synapse/handlers/deactivate_account.py
@@ -19,19 +19,17 @@ from synapse.api.errors import SynapseError
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.types import Requester, UserID, create_requester
-from ._base import BaseHandler
-
if TYPE_CHECKING:
from synapse.server import HomeServer
logger = logging.getLogger(__name__)
-class DeactivateAccountHandler(BaseHandler):
+class DeactivateAccountHandler:
"""Handler which deals with deactivating user accounts."""
def __init__(self, hs: "HomeServer"):
- super().__init__(hs)
+ self.store = hs.get_datastore()
self.hs = hs
self._auth_handler = hs.get_auth_handler()
self._device_handler = hs.get_device_handler()
@@ -133,6 +131,10 @@ class DeactivateAccountHandler(BaseHandler):
# delete from user directory
await self.user_directory_handler.handle_local_user_deactivated(user_id)
+ # If the user is present in the monthly active users table
+ # remove them
+ await self.store.remove_deactivated_user_from_mau_table(user_id)
+
# Mark the user as erased, if they asked for that
if erase_data:
user = UserID.from_string(user_id)
diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py
index 35334725..75e60197 100644
--- a/synapse/handlers/device.py
+++ b/synapse/handlers/device.py
@@ -40,8 +40,6 @@ from synapse.util.caches.expiringcache import ExpiringCache
from synapse.util.metrics import measure_func
from synapse.util.retryutils import NotRetryingDestination
-from ._base import BaseHandler
-
if TYPE_CHECKING:
from synapse.server import HomeServer
@@ -50,14 +48,16 @@ logger = logging.getLogger(__name__)
MAX_DEVICE_DISPLAY_NAME_LEN = 100
-class DeviceWorkerHandler(BaseHandler):
+class DeviceWorkerHandler:
def __init__(self, hs: "HomeServer"):
- super().__init__(hs)
-
+ self.clock = hs.get_clock()
self.hs = hs
+ self.store = hs.get_datastore()
+ self.notifier = hs.get_notifier()
self.state = hs.get_state_handler()
self.state_store = hs.get_storage().state
self._auth_handler = hs.get_auth_handler()
+ self.server_name = hs.hostname
@trace
async def get_devices_by_user(self, user_id: str) -> List[JsonDict]:
diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py
index 5cfba3c8..14ed7d98 100644
--- a/synapse/handlers/directory.py
+++ b/synapse/handlers/directory.py
@@ -31,26 +31,25 @@ from synapse.appservice import ApplicationService
from synapse.storage.databases.main.directory import RoomAliasMapping
from synapse.types import JsonDict, Requester, RoomAlias, UserID, get_domain_from_id
-from ._base import BaseHandler
-
if TYPE_CHECKING:
from synapse.server import HomeServer
logger = logging.getLogger(__name__)
-class DirectoryHandler(BaseHandler):
+class DirectoryHandler:
def __init__(self, hs: "HomeServer"):
- super().__init__(hs)
-
+ self.auth = hs.get_auth()
+ self.hs = hs
self.state = hs.get_state_handler()
self.appservice_handler = hs.get_application_service_handler()
self.event_creation_handler = hs.get_event_creation_handler()
self.store = hs.get_datastore()
self.config = hs.config
self.enable_room_list_search = hs.config.roomdirectory.enable_room_list_search
- self.require_membership = hs.config.require_membership_for_aliases
+ self.require_membership = hs.config.server.require_membership_for_aliases
self.third_party_event_rules = hs.get_third_party_event_rules()
+ self.server_name = hs.hostname
self.federation = hs.get_federation_client()
hs.get_federation_registry().register_query_handler(
diff --git a/synapse/handlers/event_auth.py b/synapse/handlers/event_auth.py
index cb81fa09..d089c562 100644
--- a/synapse/handlers/event_auth.py
+++ b/synapse/handlers/event_auth.py
@@ -22,7 +22,8 @@ from synapse.api.constants import (
RestrictedJoinRuleTypes,
)
from synapse.api.errors import AuthError, Codes, SynapseError
-from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion
+from synapse.api.room_versions import RoomVersion
+from synapse.event_auth import check_auth_rules_for_event
from synapse.events import EventBase
from synapse.events.builder import EventBuilder
from synapse.events.snapshot import EventContext
@@ -45,21 +46,17 @@ class EventAuthHandler:
self._store = hs.get_datastore()
self._server_name = hs.hostname
- async def check_from_context(
+ async def check_auth_rules_from_context(
self,
- room_version: str,
+ room_version_obj: RoomVersion,
event: EventBase,
context: EventContext,
- do_sig_check: bool = True,
) -> None:
+ """Check an event passes the auth rules at its own auth events"""
auth_event_ids = event.auth_event_ids()
auth_events_by_id = await self._store.get_events(auth_event_ids)
auth_events = {(e.type, e.state_key): e for e in auth_events_by_id.values()}
-
- room_version_obj = KNOWN_ROOM_VERSIONS[room_version]
- event_auth.check(
- room_version_obj, event, auth_events=auth_events, do_sig_check=do_sig_check
- )
+ check_auth_rules_for_event(room_version_obj, event, auth_events)
def compute_auth_events(
self,
diff --git a/synapse/handlers/events.py b/synapse/handlers/events.py
index 4b3f0370..1f64534a 100644
--- a/synapse/handlers/events.py
+++ b/synapse/handlers/events.py
@@ -25,8 +25,6 @@ from synapse.streams.config import PaginationConfig
from synapse.types import JsonDict, UserID
from synapse.visibility import filter_events_for_client
-from ._base import BaseHandler
-
if TYPE_CHECKING:
from synapse.server import HomeServer
@@ -34,11 +32,11 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
-class EventStreamHandler(BaseHandler):
+class EventStreamHandler:
def __init__(self, hs: "HomeServer"):
- super().__init__(hs)
-
+ self.store = hs.get_datastore()
self.clock = hs.get_clock()
+ self.hs = hs
self.notifier = hs.get_notifier()
self.state = hs.get_state_handler()
@@ -138,9 +136,9 @@ class EventStreamHandler(BaseHandler):
return chunk
-class EventHandler(BaseHandler):
+class EventHandler:
def __init__(self, hs: "HomeServer"):
- super().__init__(hs)
+ self.store = hs.get_datastore()
self.storage = hs.get_storage()
async def get_event(
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index adbd150e..3e341bd2 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -45,11 +45,14 @@ from synapse.api.errors import (
)
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion, RoomVersions
from synapse.crypto.event_signing import compute_event_signature
+from synapse.event_auth import (
+ check_auth_rules_for_event,
+ validate_event_for_room_version,
+)
from synapse.events import EventBase
from synapse.events.snapshot import EventContext
from synapse.events.validator import EventValidator
from synapse.federation.federation_client import InvalidResponseError
-from synapse.handlers._base import BaseHandler
from synapse.http.servlet import assert_params_in_dict
from synapse.logging.context import (
make_deferred_yieldable,
@@ -74,15 +77,13 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
-class FederationHandler(BaseHandler):
+class FederationHandler:
"""Handles general incoming federation requests
Incoming events are *not* handled here, for which see FederationEventHandler.
"""
def __init__(self, hs: "HomeServer"):
- super().__init__(hs)
-
self.hs = hs
self.store = hs.get_datastore()
@@ -95,6 +96,7 @@ class FederationHandler(BaseHandler):
self.is_mine_id = hs.is_mine_id
self.spam_checker = hs.get_spam_checker()
self.event_creation_handler = hs.get_event_creation_handler()
+ self.event_builder_factory = hs.get_event_builder_factory()
self._event_auth_handler = hs.get_event_auth_handler()
self._server_notices_mxid = hs.config.servernotices.server_notices_mxid
self.config = hs.config
@@ -723,8 +725,8 @@ class FederationHandler(BaseHandler):
state_ids,
)
- builder = self.event_builder_factory.new(
- room_version.identifier,
+ builder = self.event_builder_factory.for_room_version(
+ room_version,
{
"type": EventTypes.Member,
"content": event_content,
@@ -747,10 +749,9 @@ class FederationHandler(BaseHandler):
# The remote hasn't signed it yet, obviously. We'll do the full checks
# when we get the event back in `on_send_join_request`
- await self._event_auth_handler.check_from_context(
- room_version.identifier, event, context, do_sig_check=False
+ await self._event_auth_handler.check_auth_rules_from_context(
+ room_version, event, context
)
-
return event
async def on_invite_request(
@@ -767,7 +768,7 @@ class FederationHandler(BaseHandler):
if is_blocked:
raise SynapseError(403, "This room has been blocked on this server")
- if self.hs.config.block_non_admin_invites:
+ if self.hs.config.server.block_non_admin_invites:
raise SynapseError(403, "This server does not accept room invites")
if not await self.spam_checker.user_may_invite(
@@ -902,9 +903,9 @@ class FederationHandler(BaseHandler):
)
raise SynapseError(403, "User not from origin", Codes.FORBIDDEN)
- room_version = await self.store.get_room_version_id(room_id)
- builder = self.event_builder_factory.new(
- room_version,
+ room_version_obj = await self.store.get_room_version(room_id)
+ builder = self.event_builder_factory.for_room_version(
+ room_version_obj,
{
"type": EventTypes.Member,
"content": {"membership": Membership.LEAVE},
@@ -921,8 +922,8 @@ class FederationHandler(BaseHandler):
try:
# The remote hasn't signed it yet, obviously. We'll do the full checks
# when we get the event back in `on_send_leave_request`
- await self._event_auth_handler.check_from_context(
- room_version, event, context, do_sig_check=False
+ await self._event_auth_handler.check_auth_rules_from_context(
+ room_version_obj, event, context
)
except AuthError as e:
logger.warning("Failed to create new leave %r because %s", event, e)
@@ -954,10 +955,10 @@ class FederationHandler(BaseHandler):
)
raise SynapseError(403, "User not from origin", Codes.FORBIDDEN)
- room_version = await self.store.get_room_version_id(room_id)
+ room_version_obj = await self.store.get_room_version(room_id)
- builder = self.event_builder_factory.new(
- room_version,
+ builder = self.event_builder_factory.for_room_version(
+ room_version_obj,
{
"type": EventTypes.Member,
"content": {"membership": Membership.KNOCK},
@@ -983,8 +984,8 @@ class FederationHandler(BaseHandler):
try:
# The remote hasn't signed it yet, obviously. We'll do the full checks
# when we get the event back in `on_send_knock_request`
- await self._event_auth_handler.check_from_context(
- room_version, event, context, do_sig_check=False
+ await self._event_auth_handler.check_auth_rules_from_context(
+ room_version_obj, event, context
)
except AuthError as e:
logger.warning("Failed to create new knock %r because %s", event, e)
@@ -1173,7 +1174,8 @@ class FederationHandler(BaseHandler):
auth_for_e[(EventTypes.Create, "")] = create_event
try:
- event_auth.check(room_version, e, auth_events=auth_for_e)
+ validate_event_for_room_version(room_version, e)
+ check_auth_rules_for_event(room_version, e, auth_for_e)
except SynapseError as err:
# we may get SynapseErrors here as well as AuthErrors. For
# instance, there are a couple of (ancient) events in some
@@ -1250,8 +1252,10 @@ class FederationHandler(BaseHandler):
}
if await self._event_auth_handler.check_host_in_room(room_id, self.hs.hostname):
- room_version = await self.store.get_room_version_id(room_id)
- builder = self.event_builder_factory.new(room_version, event_dict)
+ room_version_obj = await self.store.get_room_version(room_id)
+ builder = self.event_builder_factory.for_room_version(
+ room_version_obj, event_dict
+ )
EventValidator().validate_builder(builder)
event, context = await self.event_creation_handler.create_new_client_event(
@@ -1259,7 +1263,7 @@ class FederationHandler(BaseHandler):
)
event, context = await self.add_display_name_to_third_party_invite(
- room_version, event_dict, event, context
+ room_version_obj, event_dict, event, context
)
EventValidator().validate_new(event, self.config)
@@ -1269,8 +1273,9 @@ class FederationHandler(BaseHandler):
event.internal_metadata.send_on_behalf_of = self.hs.hostname
try:
- await self._event_auth_handler.check_from_context(
- room_version, event, context
+ validate_event_for_room_version(room_version_obj, event)
+ await self._event_auth_handler.check_auth_rules_from_context(
+ room_version_obj, event, context
)
except AuthError as e:
logger.warning("Denying new third party invite %r because %s", event, e)
@@ -1304,22 +1309,25 @@ class FederationHandler(BaseHandler):
"""
assert_params_in_dict(event_dict, ["room_id"])
- room_version = await self.store.get_room_version_id(event_dict["room_id"])
+ room_version_obj = await self.store.get_room_version(event_dict["room_id"])
# NB: event_dict has a particular specced format we might need to fudge
# if we change event formats too much.
- builder = self.event_builder_factory.new(room_version, event_dict)
+ builder = self.event_builder_factory.for_room_version(
+ room_version_obj, event_dict
+ )
event, context = await self.event_creation_handler.create_new_client_event(
builder=builder
)
event, context = await self.add_display_name_to_third_party_invite(
- room_version, event_dict, event, context
+ room_version_obj, event_dict, event, context
)
try:
- await self._event_auth_handler.check_from_context(
- room_version, event, context
+ validate_event_for_room_version(room_version_obj, event)
+ await self._event_auth_handler.check_auth_rules_from_context(
+ room_version_obj, event, context
)
except AuthError as e:
logger.warning("Denying third party invite %r because %s", event, e)
@@ -1336,7 +1344,7 @@ class FederationHandler(BaseHandler):
async def add_display_name_to_third_party_invite(
self,
- room_version: str,
+ room_version_obj: RoomVersion,
event_dict: JsonDict,
event: EventBase,
context: EventContext,
@@ -1368,7 +1376,9 @@ class FederationHandler(BaseHandler):
# auth checks. If we need the invite and don't have it then the
# auth check code will explode appropriately.
- builder = self.event_builder_factory.new(room_version, event_dict)
+ builder = self.event_builder_factory.for_room_version(
+ room_version_obj, event_dict
+ )
EventValidator().validate_builder(builder)
event, context = await self.event_creation_handler.create_new_client_event(
builder=builder
diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py
index 01fd8411..f640b417 100644
--- a/synapse/handlers/federation_event.py
+++ b/synapse/handlers/federation_event.py
@@ -29,7 +29,6 @@ from typing import (
from prometheus_client import Counter
-from synapse import event_auth
from synapse.api.constants import (
EventContentFields,
EventTypes,
@@ -47,7 +46,11 @@ from synapse.api.errors import (
SynapseError,
)
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
-from synapse.event_auth import auth_types_for_event
+from synapse.event_auth import (
+ auth_types_for_event,
+ check_auth_rules_for_event,
+ validate_event_for_room_version,
+)
from synapse.events import EventBase
from synapse.events.snapshot import EventContext
from synapse.federation.federation_client import InvalidResponseError
@@ -68,11 +71,7 @@ from synapse.types import (
UserID,
get_domain_from_id,
)
-from synapse.util.async_helpers import (
- Linearizer,
- concurrently_execute,
- yieldable_gather_results,
-)
+from synapse.util.async_helpers import Linearizer, concurrently_execute
from synapse.util.iterutils import batch_iter
from synapse.util.retryutils import NotRetryingDestination
from synapse.util.stringutils import shortstr
@@ -357,6 +356,11 @@ class FederationEventHandler:
)
# all looks good, we can persist the event.
+
+ # First, precalculate the joined hosts so that the federation sender doesn't
+ # need to.
+ await self._event_creation_handler.cache_joined_hosts_for_event(event, context)
+
await self._run_push_actions_and_persist_event(event, context)
return event, context
@@ -890,6 +894,9 @@ class FederationEventHandler:
backfilled=backfilled,
)
except AuthError as e:
+ # FIXME richvdh 2021/10/07 I don't think this is reachable. Let's log it
+ # for now
+ logger.exception("Unexpected AuthError from _check_event_auth")
raise FederationError("ERROR", e.code, e.msg, affected=event.event_id)
await self._run_push_actions_and_persist_event(event, context, backfilled)
@@ -1011,9 +1018,8 @@ class FederationEventHandler:
room_version = await self._store.get_room_version(marker_event.room_id)
create_event = await self._store.get_create_event_for_room(marker_event.room_id)
room_creator = create_event.content.get(EventContentFields.ROOM_CREATOR)
- if (
- not room_version.msc2716_historical
- or not self._config.experimental.msc2716_enabled
+ if not room_version.msc2716_historical and (
+ not self._config.experimental.msc2716_enabled
or marker_event.sender != room_creator
):
return
@@ -1155,7 +1161,10 @@ class FederationEventHandler:
return
logger.info(
- "Persisting %i of %i remaining events", len(roots), len(event_map)
+ "Persisting %i of %i remaining outliers: %s",
+ len(roots),
+ len(event_map),
+ shortstr(e.event_id for e in roots),
)
await self._auth_and_persist_fetched_events_inner(origin, room_id, roots)
@@ -1189,7 +1198,10 @@ class FederationEventHandler:
allow_rejected=True,
)
- async def prep(event: EventBase) -> Optional[Tuple[EventBase, EventContext]]:
+ room_version = await self._store.get_room_version_id(room_id)
+ room_version_obj = KNOWN_ROOM_VERSIONS[room_version]
+
+ def prep(event: EventBase) -> Optional[Tuple[EventBase, EventContext]]:
with nested_logging_context(suffix=event.event_id):
auth = {}
for auth_event_id in event.auth_event_ids():
@@ -1207,17 +1219,16 @@ class FederationEventHandler:
auth[(ae.type, ae.state_key)] = ae
context = EventContext.for_outlier()
- context = await self._check_event_auth(
- origin,
- event,
- context,
- claimed_auth_event_map=auth,
- )
+ try:
+ validate_event_for_room_version(room_version_obj, event)
+ check_auth_rules_for_event(room_version_obj, event, auth)
+ except AuthError as e:
+ logger.warning("Rejecting %r because %s", event, e)
+ context.rejected = RejectedReason.AUTH_ERROR
+
return event, context
- events_to_persist = (
- x for x in await yieldable_gather_results(prep, fetched_events) if x
- )
+ events_to_persist = (x for x in (prep(event) for event in fetched_events) if x)
await self.persist_events_and_notify(room_id, tuple(events_to_persist))
async def _check_event_auth(
@@ -1226,7 +1237,6 @@ class FederationEventHandler:
event: EventBase,
context: EventContext,
state: Optional[Iterable[EventBase]] = None,
- claimed_auth_event_map: Optional[StateMap[EventBase]] = None,
backfilled: bool = False,
) -> EventContext:
"""
@@ -1242,42 +1252,45 @@ class FederationEventHandler:
The state events used to check the event for soft-fail. If this is
not provided the current state events will be used.
- claimed_auth_event_map:
- A map of (type, state_key) => event for the event's claimed auth_events.
- Possibly including events that were rejected, or are in the wrong room.
-
- Only populated when populating outliers.
-
backfilled: True if the event was backfilled.
Returns:
The updated context object.
"""
- # claimed_auth_event_map should be given iff the event is an outlier
- assert bool(claimed_auth_event_map) == event.internal_metadata.outlier
+ # This method should only be used for non-outliers
+ assert not event.internal_metadata.outlier
+ # first of all, check that the event itself is valid.
room_version = await self._store.get_room_version_id(event.room_id)
room_version_obj = KNOWN_ROOM_VERSIONS[room_version]
- if claimed_auth_event_map:
- # if we have a copy of the auth events from the event, use that as the
- # basis for auth.
- auth_events = claimed_auth_event_map
- else:
- # otherwise, we calculate what the auth events *should* be, and use that
- prev_state_ids = await context.get_prev_state_ids()
- auth_events_ids = self._event_auth_handler.compute_auth_events(
- event, prev_state_ids, for_verification=True
- )
- auth_events_x = await self._store.get_events(auth_events_ids)
- auth_events = {(e.type, e.state_key): e for e in auth_events_x.values()}
+ try:
+ validate_event_for_room_version(room_version_obj, event)
+ except AuthError as e:
+ logger.warning("While validating received event %r: %s", event, e)
+ # TODO: use a different rejected reason here?
+ context.rejected = RejectedReason.AUTH_ERROR
+ return context
+
+ # calculate what the auth events *should* be, to use as a basis for auth.
+ prev_state_ids = await context.get_prev_state_ids()
+ auth_events_ids = self._event_auth_handler.compute_auth_events(
+ event, prev_state_ids, for_verification=True
+ )
+ auth_events_x = await self._store.get_events(auth_events_ids)
+ calculated_auth_event_map = {
+ (e.type, e.state_key): e for e in auth_events_x.values()
+ }
try:
(
context,
auth_events_for_auth,
) = await self._update_auth_events_and_context_for_auth(
- origin, event, context, auth_events
+ origin,
+ event,
+ context,
+ calculated_auth_event_map=calculated_auth_event_map,
)
except Exception:
# We don't really mind if the above fails, so lets not fail
@@ -1289,24 +1302,17 @@ class FederationEventHandler:
"Ignoring failure and continuing processing of event.",
event.event_id,
)
- auth_events_for_auth = auth_events
+ auth_events_for_auth = calculated_auth_event_map
try:
- event_auth.check(room_version_obj, event, auth_events=auth_events_for_auth)
+ check_auth_rules_for_event(room_version_obj, event, auth_events_for_auth)
except AuthError as e:
logger.warning("Failed auth resolution for %r because %s", event, e)
context.rejected = RejectedReason.AUTH_ERROR
+ return context
- if not context.rejected:
- await self._check_for_soft_fail(event, state, backfilled, origin=origin)
- await self._maybe_kick_guest_users(event)
-
- # If we are going to send this event over federation we precaclculate
- # the joined hosts.
- if event.internal_metadata.get_send_on_behalf_of():
- await self._event_creation_handler.cache_joined_hosts_for_event(
- event, context
- )
+ await self._check_for_soft_fail(event, state, backfilled, origin=origin)
+ await self._maybe_kick_guest_users(event)
return context
@@ -1404,7 +1410,7 @@ class FederationEventHandler:
}
try:
- event_auth.check(room_version_obj, event, auth_events=current_auth_events)
+ check_auth_rules_for_event(room_version_obj, event, current_auth_events)
except AuthError as e:
logger.warning(
"Soft-failing %r (from %s) because %s",
@@ -1425,7 +1431,7 @@ class FederationEventHandler:
origin: str,
event: EventBase,
context: EventContext,
- input_auth_events: StateMap[EventBase],
+ calculated_auth_event_map: StateMap[EventBase],
) -> Tuple[EventContext, StateMap[EventBase]]:
"""Helper for _check_event_auth. See there for docs.
@@ -1443,19 +1449,17 @@ class FederationEventHandler:
event:
context:
- input_auth_events:
- Map from (event_type, state_key) to event
-
- Normally, our calculated auth_events based on the state of the room
- at the event's position in the DAG, though occasionally (eg if the
- event is an outlier), may be the auth events claimed by the remote
- server.
+ calculated_auth_event_map:
+ Our calculated auth_events based on the state of the room
+ at the event's position in the DAG.
Returns:
updated context, updated auth event map
"""
- # take a copy of input_auth_events before we modify it.
- auth_events: MutableStateMap[EventBase] = dict(input_auth_events)
+ assert not event.internal_metadata.outlier
+
+ # take a copy of calculated_auth_event_map before we modify it.
+ auth_events: MutableStateMap[EventBase] = dict(calculated_auth_event_map)
event_auth_events = set(event.auth_event_ids())
@@ -1475,6 +1479,11 @@ class FederationEventHandler:
logger.debug("Events %s are in the store", have_events)
missing_auth.difference_update(have_events)
+ # missing_auth is now the set of event_ids which:
+ # a. are listed in event.auth_events, *and*
+ # b. are *not* part of our calculated auth events based on room state, *and*
+ # c. are *not* yet in our database.
+
if missing_auth:
# If we don't have all the auth events, we need to get them.
logger.info("auth_events contains unknown events: %s", missing_auth)
@@ -1496,19 +1505,31 @@ class FederationEventHandler:
}
)
- if event.internal_metadata.is_outlier():
- # XXX: given that, for an outlier, we'll be working with the
- # event's *claimed* auth events rather than those we calculated:
- # (a) is there any point in this test, since different_auth below will
- # obviously be empty
- # (b) alternatively, why don't we do it earlier?
- logger.info("Skipping auth_event fetch for outlier")
- return context, auth_events
+ # auth_events now contains
+ # 1. our *calculated* auth events based on the room state, plus:
+ # 2. any events which:
+ # a. are listed in `event.auth_events`, *and*
+ # b. are not part of our calculated auth events, *and*
+ # c. were not in our database before the call to /event_auth
+ # d. have since been added to our database (most likely by /event_auth).
different_auth = event_auth_events.difference(
e.event_id for e in auth_events.values()
)
+ # different_auth is the set of events which *are* in `event.auth_events`, but
+ # which are *not* in `auth_events`. Comparing with (2.) above, this means
+ # exclusively the set of `event.auth_events` which we already had in our
+ # database before any call to /event_auth.
+ #
+ # I'm reasonably sure that the fact that events returned by /event_auth are
+ # blindly added to auth_events (and hence excluded from different_auth) is a bug
+ # - though it's a very long-standing one (see
+ # https://github.com/matrix-org/synapse/commit/78015948a7febb18e000651f72f8f58830a55b93#diff-0bc92da3d703202f5b9be2d3f845e375f5b1a6bc6ba61705a8af9be1121f5e42R786
+ # from Jan 2015 which seems to add it, though it actually just moves it from
+ # elsewhere (before that, it gets lost in a mess of huge "various bug fixes"
+ # PRs).
+
if not different_auth:
return context, auth_events
diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py
index fe8a9958..9c319b53 100644
--- a/synapse/handlers/identity.py
+++ b/synapse/handlers/identity.py
@@ -39,8 +39,6 @@ from synapse.util.stringutils import (
valid_id_server_location,
)
-from ._base import BaseHandler
-
if TYPE_CHECKING:
from synapse.server import HomeServer
@@ -49,15 +47,14 @@ logger = logging.getLogger(__name__)
id_server_scheme = "https://"
-class IdentityHandler(BaseHandler):
+class IdentityHandler:
def __init__(self, hs: "HomeServer"):
- super().__init__(hs)
-
+ self.store = hs.get_datastore()
# An HTTP client for contacting trusted URLs.
self.http_client = SimpleHttpClient(hs)
# An HTTP client for contacting identity servers specified by clients.
self.blacklisting_http_client = SimpleHttpClient(
- hs, ip_blacklist=hs.config.federation_ip_range_blacklist
+ hs, ip_blacklist=hs.config.server.federation_ip_range_blacklist
)
self.federation_http_client = hs.get_federation_http_client()
self.hs = hs
@@ -573,9 +570,15 @@ class IdentityHandler(BaseHandler):
# Try to validate as email
if self.hs.config.email.threepid_behaviour_email == ThreepidBehaviour.REMOTE:
+ # Remote emails will only be used if a valid identity server is provided.
+ assert (
+ self.hs.config.registration.account_threepid_delegate_email is not None
+ )
+
# Ask our delegated email identity server
validation_session = await self.threepid_from_creds(
- self.hs.config.account_threepid_delegate_email, threepid_creds
+ self.hs.config.registration.account_threepid_delegate_email,
+ threepid_creds,
)
elif self.hs.config.email.threepid_behaviour_email == ThreepidBehaviour.LOCAL:
# Get a validated session matching these details
@@ -587,10 +590,11 @@ class IdentityHandler(BaseHandler):
return validation_session
# Try to validate as msisdn
- if self.hs.config.account_threepid_delegate_msisdn:
+ if self.hs.config.registration.account_threepid_delegate_msisdn:
# Ask our delegated msisdn identity server
validation_session = await self.threepid_from_creds(
- self.hs.config.account_threepid_delegate_msisdn, threepid_creds
+ self.hs.config.registration.account_threepid_delegate_msisdn,
+ threepid_creds,
)
return validation_session
diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py
index 9ad39a65..d4e45561 100644
--- a/synapse/handlers/initial_sync.py
+++ b/synapse/handlers/initial_sync.py
@@ -31,8 +31,6 @@ from synapse.util.async_helpers import concurrently_execute
from synapse.util.caches.response_cache import ResponseCache
from synapse.visibility import filter_events_for_client
-from ._base import BaseHandler
-
if TYPE_CHECKING:
from synapse.server import HomeServer
@@ -40,9 +38,11 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
-class InitialSyncHandler(BaseHandler):
+class InitialSyncHandler:
def __init__(self, hs: "HomeServer"):
- super().__init__(hs)
+ self.store = hs.get_datastore()
+ self.auth = hs.get_auth()
+ self.state_handler = hs.get_state_handler()
self.hs = hs
self.state = hs.get_state_handler()
self.clock = hs.get_clock()
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index fd861e94..4de9f4b8 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -16,6 +16,7 @@
# limitations under the License.
import logging
import random
+from http import HTTPStatus
from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Tuple
from canonicaljson import encode_canonical_json
@@ -39,9 +40,11 @@ from synapse.api.errors import (
NotFoundError,
ShadowBanError,
SynapseError,
+ UnsupportedRoomVersionError,
)
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersions
from synapse.api.urls import ConsentURIBuilder
+from synapse.event_auth import validate_event_for_room_version
from synapse.events import EventBase
from synapse.events.builder import EventBuilder
from synapse.events.snapshot import EventContext
@@ -59,8 +62,6 @@ from synapse.util.caches.expiringcache import ExpiringCache
from synapse.util.metrics import measure_func
from synapse.visibility import filter_events_for_client
-from ._base import BaseHandler
-
if TYPE_CHECKING:
from synapse.events.third_party_rules import ThirdPartyEventRules
from synapse.server import HomeServer
@@ -79,7 +80,7 @@ class MessageHandler:
self.storage = hs.get_storage()
self.state_store = self.storage.state
self._event_serializer = hs.get_event_client_serializer()
- self._ephemeral_events_enabled = hs.config.enable_ephemeral_messages
+ self._ephemeral_events_enabled = hs.config.server.enable_ephemeral_messages
# The scheduled call to self._expire_event. None if no call is currently
# scheduled.
@@ -413,7 +414,9 @@ class EventCreationHandler:
self.server_name = hs.hostname
self.notifier = hs.get_notifier()
self.config = hs.config
- self.require_membership_for_aliases = hs.config.require_membership_for_aliases
+ self.require_membership_for_aliases = (
+ hs.config.server.require_membership_for_aliases
+ )
self._events_shard_config = self.config.worker.events_shard_config
self._instance_name = hs.get_instance_name()
@@ -423,13 +426,12 @@ class EventCreationHandler:
Membership.JOIN,
Membership.KNOCK,
}
- if self.hs.config.include_profile_data_on_invite:
+ if self.hs.config.server.include_profile_data_on_invite:
self.membership_types_to_include_profile_data_in.add(Membership.INVITE)
self.send_event = ReplicationSendEventRestServlet.make_client(hs)
- # This is only used to get at ratelimit function
- self.base_handler = BaseHandler(hs)
+ self.request_ratelimiter = hs.get_request_ratelimiter()
# We arbitrarily limit concurrent event creation for a room to 5.
# This is to stop us from diverging history *too* much.
@@ -459,11 +461,11 @@ class EventCreationHandler:
#
self._rooms_to_exclude_from_dummy_event_insertion: Dict[str, int] = {}
# The number of forward extremeities before a dummy event is sent.
- self._dummy_events_threshold = hs.config.dummy_events_threshold
+ self._dummy_events_threshold = hs.config.server.dummy_events_threshold
if (
self.config.worker.run_background_tasks
- and self.config.cleanup_extremities_with_dummy_events
+ and self.config.server.cleanup_extremities_with_dummy_events
):
self.clock.looping_call(
lambda: run_as_background_process(
@@ -475,7 +477,7 @@ class EventCreationHandler:
self._message_handler = hs.get_message_handler()
- self._ephemeral_events_enabled = hs.config.enable_ephemeral_messages
+ self._ephemeral_events_enabled = hs.config.server.enable_ephemeral_messages
self._external_cache = hs.get_external_cache()
@@ -549,16 +551,22 @@ class EventCreationHandler:
await self.auth.check_auth_blocking(requester=requester)
if event_dict["type"] == EventTypes.Create and event_dict["state_key"] == "":
- room_version = event_dict["content"]["room_version"]
+ room_version_id = event_dict["content"]["room_version"]
+ room_version_obj = KNOWN_ROOM_VERSIONS.get(room_version_id)
+ if not room_version_obj:
+ # this can happen if support is withdrawn for a room version
+ raise UnsupportedRoomVersionError(room_version_id)
else:
try:
- room_version = await self.store.get_room_version_id(
+ room_version_obj = await self.store.get_room_version(
event_dict["room_id"]
)
except NotFoundError:
raise AuthError(403, "Unknown room")
- builder = self.event_builder_factory.new(room_version, event_dict)
+ builder = self.event_builder_factory.for_room_version(
+ room_version_obj, event_dict
+ )
self.validator.validate_builder(builder)
@@ -1064,9 +1072,17 @@ class EventCreationHandler:
EventTypes.Create,
"",
):
- room_version = event.content.get("room_version", RoomVersions.V1.identifier)
+ room_version_id = event.content.get(
+ "room_version", RoomVersions.V1.identifier
+ )
+ room_version_obj = KNOWN_ROOM_VERSIONS.get(room_version_id)
+ if not room_version_obj:
+ raise UnsupportedRoomVersionError(
+ "Attempt to create a room with unsupported room version %s"
+ % (room_version_id,)
+ )
else:
- room_version = await self.store.get_room_version_id(event.room_id)
+ room_version_obj = await self.store.get_room_version(event.room_id)
if event.internal_metadata.is_out_of_band_membership():
# the only sort of out-of-band-membership events we expect to see here are
@@ -1075,8 +1091,9 @@ class EventCreationHandler:
assert event.content["membership"] == Membership.LEAVE
else:
try:
- await self._event_auth_handler.check_from_context(
- room_version, event, context
+ validate_event_for_room_version(room_version_obj, event)
+ await self._event_auth_handler.check_auth_rules_from_context(
+ room_version_obj, event, context
)
except AuthError as err:
logger.warning("Denying new event %r because %s", event, err)
@@ -1302,7 +1319,7 @@ class EventCreationHandler:
original_event and event.sender != original_event.sender
)
- await self.base_handler.ratelimit(
+ await self.request_ratelimiter.ratelimit(
requester, is_admin_redaction=is_admin_redaction
)
@@ -1456,6 +1473,39 @@ class EventCreationHandler:
if prev_state_ids:
raise AuthError(403, "Changing the room create event is forbidden")
+ if event.type == EventTypes.MSC2716_INSERTION:
+ room_version = await self.store.get_room_version_id(event.room_id)
+ room_version_obj = KNOWN_ROOM_VERSIONS[room_version]
+
+ create_event = await self.store.get_create_event_for_room(event.room_id)
+ room_creator = create_event.content.get(EventContentFields.ROOM_CREATOR)
+
+ # Only check an insertion event if the room version
+ # supports it or the event is from the room creator.
+ if room_version_obj.msc2716_historical or (
+ self.config.experimental.msc2716_enabled
+ and event.sender == room_creator
+ ):
+ next_batch_id = event.content.get(
+ EventContentFields.MSC2716_NEXT_BATCH_ID
+ )
+ conflicting_insertion_event_id = (
+ await self.store.get_insertion_event_by_batch_id(
+ event.room_id, next_batch_id
+ )
+ )
+ if conflicting_insertion_event_id is not None:
+ # The current insertion event that we're processing is invalid
+ # because an insertion event already exists in the room with the
+ # same next_batch_id. We can't allow multiple because the batch
+ # pointing will get weird, e.g. we can't determine which insertion
+ # event the batch event is pointing to.
+ raise SynapseError(
+ HTTPStatus.BAD_REQUEST,
+ "Another insertion event already exists with the same next_batch_id",
+ errcode=Codes.INVALID_PARAM,
+ )
+
# Mark any `m.historical` messages as backfilled so they don't appear
# in `/sync` and have the proper decrementing `stream_ordering` as we import
backfilled = False
diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py
index 08b93b3e..176e4dfd 100644
--- a/synapse/handlers/pagination.py
+++ b/synapse/handlers/pagination.py
@@ -85,23 +85,29 @@ class PaginationHandler:
self._purges_by_id: Dict[str, PurgeStatus] = {}
self._event_serializer = hs.get_event_client_serializer()
- self._retention_default_max_lifetime = hs.config.retention_default_max_lifetime
+ self._retention_default_max_lifetime = (
+ hs.config.server.retention_default_max_lifetime
+ )
- self._retention_allowed_lifetime_min = hs.config.retention_allowed_lifetime_min
- self._retention_allowed_lifetime_max = hs.config.retention_allowed_lifetime_max
+ self._retention_allowed_lifetime_min = (
+ hs.config.server.retention_allowed_lifetime_min
+ )
+ self._retention_allowed_lifetime_max = (
+ hs.config.server.retention_allowed_lifetime_max
+ )
- if hs.config.worker.run_background_tasks and hs.config.retention_enabled:
+ if hs.config.worker.run_background_tasks and hs.config.server.retention_enabled:
# Run the purge jobs described in the configuration file.
- for job in hs.config.retention_purge_jobs:
+ for job in hs.config.server.retention_purge_jobs:
logger.info("Setting up purge job with config: %s", job)
self.clock.looping_call(
run_as_background_process,
- job["interval"],
+ job.interval,
"purge_history_for_rooms_in_range",
self.purge_history_for_rooms_in_range,
- job["shortest_max_lifetime"],
- job["longest_max_lifetime"],
+ job.shortest_max_lifetime,
+ job.longest_max_lifetime,
)
async def purge_history_for_rooms_in_range(
diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py
index b23a1541..e6c3cf58 100644
--- a/synapse/handlers/profile.py
+++ b/synapse/handlers/profile.py
@@ -32,8 +32,6 @@ from synapse.types import (
get_domain_from_id,
)
-from ._base import BaseHandler
-
if TYPE_CHECKING:
from synapse.server import HomeServer
@@ -43,7 +41,7 @@ MAX_DISPLAYNAME_LEN = 256
MAX_AVATAR_URL_LEN = 1000
-class ProfileHandler(BaseHandler):
+class ProfileHandler:
"""Handles fetching and updating user profile information.
ProfileHandler can be instantiated directly on workers and will
@@ -54,7 +52,9 @@ class ProfileHandler(BaseHandler):
PROFILE_UPDATE_EVERY_MS = 24 * 60 * 60 * 1000
def __init__(self, hs: "HomeServer"):
- super().__init__(hs)
+ self.store = hs.get_datastore()
+ self.clock = hs.get_clock()
+ self.hs = hs
self.federation = hs.get_federation_client()
hs.get_federation_registry().register_query_handler(
@@ -62,6 +62,7 @@ class ProfileHandler(BaseHandler):
)
self.user_directory_handler = hs.get_user_directory_handler()
+ self.request_ratelimiter = hs.get_request_ratelimiter()
if hs.config.worker.run_background_tasks:
self.clock.looping_call(
@@ -178,7 +179,7 @@ class ProfileHandler(BaseHandler):
if not by_admin and target_user != requester.user:
raise AuthError(400, "Cannot set another user's displayname")
- if not by_admin and not self.hs.config.enable_set_displayname:
+ if not by_admin and not self.hs.config.registration.enable_set_displayname:
profile = await self.store.get_profileinfo(target_user.localpart)
if profile.display_name:
raise SynapseError(
@@ -268,7 +269,7 @@ class ProfileHandler(BaseHandler):
if not by_admin and target_user != requester.user:
raise AuthError(400, "Cannot set another user's avatar_url")
- if not by_admin and not self.hs.config.enable_set_avatar_url:
+ if not by_admin and not self.hs.config.registration.enable_set_avatar_url:
profile = await self.store.get_profileinfo(target_user.localpart)
if profile.avatar_url:
raise SynapseError(
@@ -346,7 +347,7 @@ class ProfileHandler(BaseHandler):
if not self.hs.is_mine(target_user):
return
- await self.ratelimit(requester)
+ await self.request_ratelimiter.ratelimit(requester)
# Do not actually update the room state for shadow-banned users.
if requester.shadow_banned:
@@ -397,7 +398,7 @@ class ProfileHandler(BaseHandler):
# when building a membership event. In this case, we must allow the
# lookup.
if (
- not self.hs.config.limit_profile_requests_to_users_who_share_rooms
+ not self.hs.config.server.limit_profile_requests_to_users_who_share_rooms
or not requester
):
return
diff --git a/synapse/handlers/read_marker.py b/synapse/handlers/read_marker.py
index bd8160e7..58593e57 100644
--- a/synapse/handlers/read_marker.py
+++ b/synapse/handlers/read_marker.py
@@ -17,17 +17,14 @@ from typing import TYPE_CHECKING
from synapse.util.async_helpers import Linearizer
-from ._base import BaseHandler
-
if TYPE_CHECKING:
from synapse.server import HomeServer
logger = logging.getLogger(__name__)
-class ReadMarkerHandler(BaseHandler):
+class ReadMarkerHandler:
def __init__(self, hs: "HomeServer"):
- super().__init__(hs)
self.server_name = hs.config.server.server_name
self.store = hs.get_datastore()
self.account_data_handler = hs.get_account_data_handler()
diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py
index f21f33ad..374e961e 100644
--- a/synapse/handlers/receipts.py
+++ b/synapse/handlers/receipts.py
@@ -16,7 +16,6 @@ from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple
from synapse.api.constants import ReadReceiptEventFields
from synapse.appservice import ApplicationService
-from synapse.handlers._base import BaseHandler
from synapse.streams import EventSource
from synapse.types import JsonDict, ReadReceipt, UserID, get_domain_from_id
@@ -26,10 +25,9 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
-class ReceiptsHandler(BaseHandler):
+class ReceiptsHandler:
def __init__(self, hs: "HomeServer"):
- super().__init__(hs)
-
+ self.notifier = hs.get_notifier()
self.server_name = hs.config.server.server_name
self.store = hs.get_datastore()
self.event_auth_handler = hs.get_event_auth_handler()
diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py
index 4f99f137..a0e6a017 100644
--- a/synapse/handlers/register.py
+++ b/synapse/handlers/register.py
@@ -41,8 +41,6 @@ from synapse.spam_checker_api import RegistrationBehaviour
from synapse.storage.state import StateFilter
from synapse.types import RoomAlias, UserID, create_requester
-from ._base import BaseHandler
-
if TYPE_CHECKING:
from synapse.server import HomeServer
@@ -85,9 +83,10 @@ class LoginDict(TypedDict):
refresh_token: Optional[str]
-class RegistrationHandler(BaseHandler):
+class RegistrationHandler:
def __init__(self, hs: "HomeServer"):
- super().__init__(hs)
+ self.store = hs.get_datastore()
+ self.clock = hs.get_clock()
self.hs = hs
self.auth = hs.get_auth()
self._auth_handler = hs.get_auth_handler()
@@ -116,8 +115,8 @@ class RegistrationHandler(BaseHandler):
self._register_device_client = self.register_device_inner
self.pusher_pool = hs.get_pusherpool()
- self.session_lifetime = hs.config.session_lifetime
- self.access_token_lifetime = hs.config.access_token_lifetime
+ self.session_lifetime = hs.config.registration.session_lifetime
+ self.access_token_lifetime = hs.config.registration.access_token_lifetime
init_counters_for_auth_provider("")
@@ -340,8 +339,13 @@ class RegistrationHandler(BaseHandler):
auth_provider=(auth_provider_id or ""),
).inc()
+ # If the user does not need to consent at registration, auto-join any
+ # configured rooms.
if not self.hs.config.consent.user_consent_at_registration:
- if not self.hs.config.auto_join_rooms_for_guests and make_guest:
+ if (
+ not self.hs.config.registration.auto_join_rooms_for_guests
+ and make_guest
+ ):
logger.info(
"Skipping auto-join for %s because auto-join for guests is disabled",
user_id,
@@ -387,7 +391,7 @@ class RegistrationHandler(BaseHandler):
"preset": self.hs.config.registration.autocreate_auto_join_room_preset,
}
- # If the configuration providers a user ID to create rooms with, use
+ # If the configuration provides a user ID to create rooms with, use
# that instead of the first user registered.
requires_join = False
if self.hs.config.registration.auto_join_user_id:
@@ -510,7 +514,7 @@ class RegistrationHandler(BaseHandler):
# we don't have a local user in the room to craft up an invite with.
requires_invite = await self.store.is_host_joined(
room_id,
- self.server_name,
+ self._server_name,
)
if requires_invite:
@@ -854,7 +858,7 @@ class RegistrationHandler(BaseHandler):
# Necessary due to auth checks prior to the threepid being
# written to the db
if is_threepid_reserved(
- self.hs.config.mau_limits_reserved_threepids, threepid
+ self.hs.config.server.mau_limits_reserved_threepids, threepid
):
await self.store.upsert_monthly_active_user(user_id)
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index 8fede5e9..7072bca1 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -52,6 +52,7 @@ from synapse.api.errors import (
)
from synapse.api.filtering import Filter
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion
+from synapse.event_auth import validate_event_for_room_version
from synapse.events import EventBase
from synapse.events.utils import copy_power_levels_contents
from synapse.rest.admin._base import assert_user_is_admin
@@ -75,8 +76,6 @@ from synapse.util.caches.response_cache import ResponseCache
from synapse.util.stringutils import parse_and_validate_server_name
from synapse.visibility import filter_events_for_client
-from ._base import BaseHandler
-
if TYPE_CHECKING:
from synapse.server import HomeServer
@@ -87,15 +86,18 @@ id_server_scheme = "https://"
FIVE_MINUTES_IN_MS = 5 * 60 * 1000
-class RoomCreationHandler(BaseHandler):
+class RoomCreationHandler:
def __init__(self, hs: "HomeServer"):
- super().__init__(hs)
-
+ self.store = hs.get_datastore()
+ self.auth = hs.get_auth()
+ self.clock = hs.get_clock()
+ self.hs = hs
self.spam_checker = hs.get_spam_checker()
self.event_creation_handler = hs.get_event_creation_handler()
self.room_member_handler = hs.get_room_member_handler()
self._event_auth_handler = hs.get_event_auth_handler()
self.config = hs.config
+ self.request_ratelimiter = hs.get_request_ratelimiter()
# Room state based off defined presets
self._presets_dict: Dict[str, Dict[str, Any]] = {
@@ -161,7 +163,7 @@ class RoomCreationHandler(BaseHandler):
Raises:
ShadowBanError if the requester is shadow-banned.
"""
- await self.ratelimit(requester)
+ await self.request_ratelimiter.ratelimit(requester)
user_id = requester.user.to_string()
@@ -237,8 +239,9 @@ class RoomCreationHandler(BaseHandler):
},
},
)
- old_room_version = await self.store.get_room_version_id(old_room_id)
- await self._event_auth_handler.check_from_context(
+ old_room_version = await self.store.get_room_version(old_room_id)
+ validate_event_for_room_version(old_room_version, tombstone_event)
+ await self._event_auth_handler.check_auth_rules_from_context(
old_room_version, tombstone_event, tombstone_context
)
@@ -663,10 +666,10 @@ class RoomCreationHandler(BaseHandler):
raise SynapseError(403, "You are not permitted to create rooms")
if ratelimit:
- await self.ratelimit(requester)
+ await self.request_ratelimiter.ratelimit(requester)
room_version_id = config.get(
- "room_version", self.config.default_room_version.identifier
+ "room_version", self.config.server.default_room_version.identifier
)
if not isinstance(room_version_id, str):
@@ -858,6 +861,7 @@ class RoomCreationHandler(BaseHandler):
"invite",
ratelimit=False,
content=content,
+ new_room=True,
)
for invite_3pid in invite_3pid_list:
@@ -960,6 +964,7 @@ class RoomCreationHandler(BaseHandler):
"join",
ratelimit=ratelimit,
content=creator_join_profile,
+ new_room=True,
)
# We treat the power levels override specially as this needs to be one
diff --git a/synapse/handlers/room_batch.py b/synapse/handlers/room_batch.py
new file mode 100644
index 00000000..51dd4e75
--- /dev/null
+++ b/synapse/handlers/room_batch.py
@@ -0,0 +1,423 @@
+import logging
+from typing import TYPE_CHECKING, List, Tuple
+
+from synapse.api.constants import EventContentFields, EventTypes
+from synapse.appservice import ApplicationService
+from synapse.http.servlet import assert_params_in_dict
+from synapse.types import JsonDict, Requester, UserID, create_requester
+from synapse.util.stringutils import random_string
+
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
+logger = logging.getLogger(__name__)
+
+
+class RoomBatchHandler:
+ def __init__(self, hs: "HomeServer"):
+ self.hs = hs
+ self.store = hs.get_datastore()
+ self.state_store = hs.get_storage().state
+ self.event_creation_handler = hs.get_event_creation_handler()
+ self.room_member_handler = hs.get_room_member_handler()
+ self.auth = hs.get_auth()
+
+ async def inherit_depth_from_prev_ids(self, prev_event_ids: List[str]) -> int:
+ """Finds the depth which would sort it after the most-recent
+ prev_event_id but before the successors of those events. If no
+ successors are found, we assume it's a historical extremity that is part of the
+ current batch and use the same depth as the prev_event_ids.
+
+ Args:
+ prev_event_ids: List of prev event IDs
+
+ Returns:
+ Inherited depth
+ """
+ (
+ most_recent_prev_event_id,
+ most_recent_prev_event_depth,
+ ) = await self.store.get_max_depth_of(prev_event_ids)
+
+ # We want to insert the historical event after the `prev_event` but before the successor event
+ #
+ # We inherit depth from the successor event instead of the `prev_event`
+ # because events returned from `/messages` are first sorted by `topological_ordering`
+ # which is just the `depth` and then tie-break with `stream_ordering`.
+ #
+ # We mark these inserted historical events as "backfilled" which gives them a
+ # negative `stream_ordering`. If we use the same depth as the `prev_event`,
+ # then our historical event will tie-break and be sorted before the `prev_event`
+ # when it should come after.
+ #
+ # We want to use the successor event depth so they appear after `prev_event` because
+ # it has a larger `depth` but before the successor event because the `stream_ordering`
+ # is negative before the successor event.
+ successor_event_ids = await self.store.get_successor_events(
+ [most_recent_prev_event_id]
+ )
+
+ # If we can't find any successor events, then it's a forward extremity of
+ # historical messages and we can just inherit from the previous historical
+ event, which we can already assume has the correct depth at which we want
+ to insert.
+ if not successor_event_ids:
+ depth = most_recent_prev_event_depth
+ else:
+ (
+ _,
+ oldest_successor_depth,
+ ) = await self.store.get_min_depth_of(successor_event_ids)
+
+ depth = oldest_successor_depth
+
+ return depth
+
+ def create_insertion_event_dict(
+ self, sender: str, room_id: str, origin_server_ts: int
+ ) -> JsonDict:
+ """Creates an event dict for an "insertion" event with the proper fields
+ and a random batch ID.
+
+ Args:
+ sender: The event author MXID
+ room_id: The room ID that the event belongs to
+ origin_server_ts: Timestamp when the event was sent
+
+ Returns:
+ The new event dictionary to insert.
+ """
+
+ next_batch_id = random_string(8)
+ insertion_event = {
+ "type": EventTypes.MSC2716_INSERTION,
+ "sender": sender,
+ "room_id": room_id,
+ "content": {
+ EventContentFields.MSC2716_NEXT_BATCH_ID: next_batch_id,
+ EventContentFields.MSC2716_HISTORICAL: True,
+ },
+ "origin_server_ts": origin_server_ts,
+ }
+
+ return insertion_event
+
+ async def create_requester_for_user_id_from_app_service(
+ self, user_id: str, app_service: ApplicationService
+ ) -> Requester:
+ """Creates a new requester for the given user_id
+ and validates that the app service is allowed to control
+ the given user.
+
+ Args:
+ user_id: The author MXID that the app service is controlling
+ app_service: The app service that controls the user
+
+ Returns:
+ Requester object
+ """
+
+ await self.auth.validate_appservice_can_control_user_id(app_service, user_id)
+
+ return create_requester(user_id, app_service=app_service)
+
+ async def get_most_recent_auth_event_ids_from_event_id_list(
+ self, event_ids: List[str]
+ ) -> List[str]:
+ """Find the most recent auth event ids (derived from state events) that
+ allowed that message to be sent. We will use this as a base
+ to auth our historical messages against.
+
+ Args:
+ event_ids: List of event ID's to look at
+
+ Returns:
+ List of event ID's
+ """
+
+ (
+ most_recent_prev_event_id,
+ _,
+ ) = await self.store.get_max_depth_of(event_ids)
+ # mapping from (type, state_key) -> state_event_id
+ prev_state_map = await self.state_store.get_state_ids_for_event(
+ most_recent_prev_event_id
+ )
+ # List of state event ID's
+ prev_state_ids = list(prev_state_map.values())
+ auth_event_ids = prev_state_ids
+
+ return auth_event_ids
+
+ async def persist_state_events_at_start(
+ self,
+ state_events_at_start: List[JsonDict],
+ room_id: str,
+ initial_auth_event_ids: List[str],
+ app_service_requester: Requester,
+ ) -> List[str]:
+ """Takes all `state_events_at_start` event dictionaries and creates/persists
+ them as floating state events which don't resolve into the current room state.
+ They are floating because they reference a fake prev_event which doesn't connect
+ to the normal DAG at all.
+
+ Args:
+ state_events_at_start:
+ room_id: Room where you want the events persisted in.
+ initial_auth_event_ids: These will be the auth_events for the first
+ state event created. Each event created afterwards will be
+ added to the list of auth events for the next state event
+ created.
+ app_service_requester: The requester of an application service.
+
+ Returns:
+ List of state event ID's we just persisted
+ """
+ assert app_service_requester.app_service
+
+ state_event_ids_at_start = []
+ auth_event_ids = initial_auth_event_ids.copy()
+ for state_event in state_events_at_start:
+ assert_params_in_dict(
+ state_event, ["type", "origin_server_ts", "content", "sender"]
+ )
+
+ logger.debug(
+ "RoomBatchSendEventRestServlet inserting state_event=%s, auth_event_ids=%s",
+ state_event,
+ auth_event_ids,
+ )
+
+ event_dict = {
+ "type": state_event["type"],
+ "origin_server_ts": state_event["origin_server_ts"],
+ "content": state_event["content"],
+ "room_id": room_id,
+ "sender": state_event["sender"],
+ "state_key": state_event["state_key"],
+ }
+
+ # Mark all events as historical
+ event_dict["content"][EventContentFields.MSC2716_HISTORICAL] = True
+
+ # Make the state events float off on their own so we don't have a
+ # bunch of `@mxid joined the room` noise between each batch
+ fake_prev_event_id = "$" + random_string(43)
+
+ # TODO: This is pretty much the same as some other code to handle inserting state in this file
+ if event_dict["type"] == EventTypes.Member:
+ membership = event_dict["content"].get("membership", None)
+ event_id, _ = await self.room_member_handler.update_membership(
+ await self.create_requester_for_user_id_from_app_service(
+ state_event["sender"], app_service_requester.app_service
+ ),
+ target=UserID.from_string(event_dict["state_key"]),
+ room_id=room_id,
+ action=membership,
+ content=event_dict["content"],
+ outlier=True,
+ prev_event_ids=[fake_prev_event_id],
+ # Make sure to use a copy of this list because we modify it
+ # later in the loop here. Otherwise it will be the same
+ # reference and also update in the event when we append later.
+ auth_event_ids=auth_event_ids.copy(),
+ )
+ else:
+ # TODO: Add some complement tests that adds state that is not member joins
+ # and will use this code path. Maybe we only want to support join state events
+ # and can get rid of this `else`?
+ (
+ event,
+ _,
+ ) = await self.event_creation_handler.create_and_send_nonmember_event(
+ await self.create_requester_for_user_id_from_app_service(
+ state_event["sender"], app_service_requester.app_service
+ ),
+ event_dict,
+ outlier=True,
+ prev_event_ids=[fake_prev_event_id],
+ # Make sure to use a copy of this list because we modify it
+ # later in the loop here. Otherwise it will be the same
+ # reference and also update in the event when we append later.
+ auth_event_ids=auth_event_ids.copy(),
+ )
+ event_id = event.event_id
+
+ state_event_ids_at_start.append(event_id)
+ auth_event_ids.append(event_id)
+
+ return state_event_ids_at_start
+
+ async def persist_historical_events(
+ self,
+ events_to_create: List[JsonDict],
+ room_id: str,
+ initial_prev_event_ids: List[str],
+ inherited_depth: int,
+ auth_event_ids: List[str],
+ app_service_requester: Requester,
+ ) -> List[str]:
+ """Creates and persists all events provided sequentially. Handles the
+ complexity of creating events in chronological order so they can
+ reference each other by prev_event but still persist in
+ reverse-chronological order so they have the correct
+ (topological_ordering, stream_ordering) and sort correctly from
+ /messages.
+
+ Args:
+ events_to_create: List of historical events to create in JSON
+ dictionary format.
+ room_id: Room where you want the events persisted in.
+ initial_prev_event_ids: These will be the prev_events for the first
+ event created. Each event created afterwards will point to the
+ previous event created.
+ inherited_depth: The depth to create the events at (you will
+ probably get by calling inherit_depth_from_prev_ids(...)).
+ auth_event_ids: Define which events allow you to create the given
+ event in the room.
+ app_service_requester: The requester of an application service.
+
+ Returns:
+ List of persisted event IDs
+ """
+ assert app_service_requester.app_service
+
+ prev_event_ids = initial_prev_event_ids.copy()
+
+ event_ids = []
+ events_to_persist = []
+ for ev in events_to_create:
+ assert_params_in_dict(ev, ["type", "origin_server_ts", "content", "sender"])
+
+ event_dict = {
+ "type": ev["type"],
+ "origin_server_ts": ev["origin_server_ts"],
+ "content": ev["content"],
+ "room_id": room_id,
+ "sender": ev["sender"], # requester.user.to_string(),
+ "prev_events": prev_event_ids.copy(),
+ }
+
+ # Mark all events as historical
+ event_dict["content"][EventContentFields.MSC2716_HISTORICAL] = True
+
+ event, context = await self.event_creation_handler.create_event(
+ await self.create_requester_for_user_id_from_app_service(
+ ev["sender"], app_service_requester.app_service
+ ),
+ event_dict,
+ prev_event_ids=event_dict.get("prev_events"),
+ auth_event_ids=auth_event_ids,
+ historical=True,
+ depth=inherited_depth,
+ )
+ logger.debug(
+ "RoomBatchSendEventRestServlet inserting event=%s, prev_event_ids=%s, auth_event_ids=%s",
+ event,
+ prev_event_ids,
+ auth_event_ids,
+ )
+
+ assert self.hs.is_mine_id(event.sender), "User must be our own: %s" % (
+ event.sender,
+ )
+
+ events_to_persist.append((event, context))
+ event_id = event.event_id
+
+ event_ids.append(event_id)
+ prev_event_ids = [event_id]
+
+ # Persist events in reverse-chronological order so they have the
+ # correct stream_ordering as they are backfilled (which decrements).
+ # Events are sorted by (topological_ordering, stream_ordering)
+ # where topological_ordering is just depth.
+ for (event, context) in reversed(events_to_persist):
+ await self.event_creation_handler.handle_new_client_event(
+ await self.create_requester_for_user_id_from_app_service(
+ event["sender"], app_service_requester.app_service
+ ),
+ event=event,
+ context=context,
+ )
+
+ return event_ids
+
+ async def handle_batch_of_events(
+ self,
+ events_to_create: List[JsonDict],
+ room_id: str,
+ batch_id_to_connect_to: str,
+ initial_prev_event_ids: List[str],
+ inherited_depth: int,
+ auth_event_ids: List[str],
+ app_service_requester: Requester,
+ ) -> Tuple[List[str], str]:
+ """
+ Handles creating and persisting all of the historical events as well
+ as insertion and batch meta events to make the batch navigable in the DAG.
+
+ Args:
+ events_to_create: List of historical events to create in JSON
+ dictionary format.
+ room_id: Room where you want the events created in.
+ batch_id_to_connect_to: The batch_id from the insertion event you
+ want this batch to connect to.
+ initial_prev_event_ids: These will be the prev_events for the first
+ event created. Each event created afterwards will point to the
+ previous event created.
+ inherited_depth: The depth to create the events at (you will
+ probably get by calling inherit_depth_from_prev_ids(...)).
+ auth_event_ids: Define which events allow you to create the given
+ event in the room.
+ app_service_requester: The requester of an application service.
+
+ Returns:
+ Tuple containing a list of created events and the next_batch_id
+ """
+
+ # Connect this current batch to the insertion event from the previous batch
+ last_event_in_batch = events_to_create[-1]
+ batch_event = {
+ "type": EventTypes.MSC2716_BATCH,
+ "sender": app_service_requester.user.to_string(),
+ "room_id": room_id,
+ "content": {
+ EventContentFields.MSC2716_BATCH_ID: batch_id_to_connect_to,
+ EventContentFields.MSC2716_HISTORICAL: True,
+ },
+ # Since the batch event is put at the end of the batch,
+ # where the newest-in-time event is, copy the origin_server_ts from
+ # the last event we're inserting
+ "origin_server_ts": last_event_in_batch["origin_server_ts"],
+ }
+ # Add the batch event to the end of the batch (newest-in-time)
+ events_to_create.append(batch_event)
+
+ # Add an "insertion" event to the start of each batch (next to the oldest-in-time
+ # event in the batch) so the next batch can be connected to this one.
+ insertion_event = self.create_insertion_event_dict(
+ sender=app_service_requester.user.to_string(),
+ room_id=room_id,
+ # Since the insertion event is put at the start of the batch,
+ # where the oldest-in-time event is, copy the origin_server_ts from
+ # the first event we're inserting
+ origin_server_ts=events_to_create[0]["origin_server_ts"],
+ )
+ next_batch_id = insertion_event["content"][
+ EventContentFields.MSC2716_NEXT_BATCH_ID
+ ]
+ # Prepend the insertion event to the start of the batch (oldest-in-time)
+ events_to_create = [insertion_event] + events_to_create
+
+ # Create and persist all of the historical events
+ event_ids = await self.persist_historical_events(
+ events_to_create=events_to_create,
+ room_id=room_id,
+ initial_prev_event_ids=initial_prev_event_ids,
+ inherited_depth=inherited_depth,
+ auth_event_ids=auth_event_ids,
+ app_service_requester=app_service_requester,
+ )
+
+ return event_ids, next_batch_id
diff --git a/synapse/handlers/room_list.py b/synapse/handlers/room_list.py
index c3d4199e..ba7a14d6 100644
--- a/synapse/handlers/room_list.py
+++ b/synapse/handlers/room_list.py
@@ -36,8 +36,6 @@ from synapse.types import JsonDict, ThirdPartyInstanceID
from synapse.util.caches.descriptors import _CacheContext, cached
from synapse.util.caches.response_cache import ResponseCache
-from ._base import BaseHandler
-
if TYPE_CHECKING:
from synapse.server import HomeServer
@@ -49,9 +47,10 @@ REMOTE_ROOM_LIST_POLL_INTERVAL = 60 * 1000
EMPTY_THIRD_PARTY_ID = ThirdPartyInstanceID(None, None)
-class RoomListHandler(BaseHandler):
+class RoomListHandler:
def __init__(self, hs: "HomeServer"):
- super().__init__(hs)
+ self.store = hs.get_datastore()
+ self.hs = hs
self.enable_room_list_search = hs.config.roomdirectory.enable_room_list_search
self.response_cache: ResponseCache[
Tuple[Optional[int], Optional[str], Optional[ThirdPartyInstanceID]]
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index afa7e472..74e6c7ec 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -51,8 +51,6 @@ from synapse.types import (
from synapse.util.async_helpers import Linearizer
from synapse.util.distributor import user_left_room
-from ._base import BaseHandler
-
if TYPE_CHECKING:
from synapse.server import HomeServer
@@ -89,8 +87,8 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
self.spam_checker = hs.get_spam_checker()
self.third_party_event_rules = hs.get_third_party_event_rules()
self._server_notices_mxid = self.config.servernotices.server_notices_mxid
- self._enable_lookup = hs.config.enable_3pid_lookup
- self.allow_per_room_profiles = self.config.allow_per_room_profiles
+ self._enable_lookup = hs.config.registration.enable_3pid_lookup
+ self.allow_per_room_profiles = self.config.server.allow_per_room_profiles
self._join_rate_limiter_local = Ratelimiter(
store=self.store,
@@ -118,9 +116,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
burst_count=hs.config.ratelimiting.rc_invites_per_user.burst_count,
)
- # This is only used to get at the ratelimit function. It's fine there are
- # multiple of these as it doesn't store state.
- self.base_handler = BaseHandler(hs)
+ self.request_ratelimiter = hs.get_request_ratelimiter()
@abc.abstractmethod
async def _remote_join(
@@ -434,6 +430,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
third_party_signed: Optional[dict] = None,
ratelimit: bool = True,
content: Optional[dict] = None,
+ new_room: bool = False,
require_consent: bool = True,
outlier: bool = False,
prev_event_ids: Optional[List[str]] = None,
@@ -451,6 +448,8 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
third_party_signed: Information from a 3PID invite.
ratelimit: Whether to rate limit the request.
content: The content of the created event.
+ new_room: Whether the membership update is happening in the context of a room
+ creation.
require_consent: Whether consent is required.
outlier: Indicates whether the event is an `outlier`, i.e. if
it's from an arbitrary point and floating in the DAG as
@@ -485,6 +484,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
third_party_signed=third_party_signed,
ratelimit=ratelimit,
content=content,
+ new_room=new_room,
require_consent=require_consent,
outlier=outlier,
prev_event_ids=prev_event_ids,
@@ -504,6 +504,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
third_party_signed: Optional[dict] = None,
ratelimit: bool = True,
content: Optional[dict] = None,
+ new_room: bool = False,
require_consent: bool = True,
outlier: bool = False,
prev_event_ids: Optional[List[str]] = None,
@@ -523,6 +524,8 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
third_party_signed:
ratelimit:
content:
+ new_room: Whether the membership update is happening in the context of a room
+ creation.
require_consent:
outlier: Indicates whether the event is an `outlier`, i.e. if
it's from an arbitrary point and floating in the DAG as
@@ -625,7 +628,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
is_requester_admin = await self.auth.is_server_admin(requester.user)
if not is_requester_admin:
- if self.config.block_non_admin_invites:
+ if self.config.server.block_non_admin_invites:
logger.info(
"Blocking invite: user is not admin and non-admin "
"invites disabled"
@@ -726,6 +729,30 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
# so don't really fit into the general auth process.
raise AuthError(403, "Guest access not allowed")
+ # Figure out whether the user is a server admin to determine whether they
+ # should be able to bypass the spam checker.
+ if (
+ self._server_notices_mxid is not None
+ and requester.user.to_string() == self._server_notices_mxid
+ ):
+ # allow the server notices mxid to join rooms
+ bypass_spam_checker = True
+
+ else:
+ bypass_spam_checker = await self.auth.is_server_admin(requester.user)
+
+ inviter = await self._get_inviter(target.to_string(), room_id)
+ if (
+ not bypass_spam_checker
+ # We assume that if the spam checker allowed the user to create
+ # a room then they're allowed to join it.
+ and not new_room
+ and not await self.spam_checker.user_may_join_room(
+ target.to_string(), room_id, is_invited=inviter is not None
+ )
+ ):
+ raise SynapseError(403, "Not allowed to join this room")
+
# Check if a remote join should be performed.
remote_join, remote_room_hosts = await self._should_perform_remote_join(
target.to_string(), room_id, remote_room_hosts, content, is_host_in_room
@@ -1230,7 +1257,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
Raises:
ShadowBanError if the requester has been shadow-banned.
"""
- if self.config.block_non_admin_invites:
+ if self.config.server.block_non_admin_invites:
is_requester_admin = await self.auth.is_server_admin(requester.user)
if not is_requester_admin:
raise SynapseError(
@@ -1244,7 +1271,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
# We need to rate limit *before* we send out any 3PID invites, so we
# can't just rely on the standard ratelimiting of events.
- await self.base_handler.ratelimit(requester)
+ await self.request_ratelimiter.ratelimit(requester)
can_invite = await self.third_party_event_rules.check_threepid_can_be_invited(
medium, address, room_id
@@ -1268,10 +1295,22 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
if invitee:
# Note that update_membership with an action of "invite" can raise
# a ShadowBanError, but this was done above already.
+ # We don't check the invite against the spamchecker(s) here (through
+ # user_may_invite) because we'll do it further down the line anyway (in
+ # update_membership_locked).
_, stream_id = await self.update_membership(
requester, UserID.from_string(invitee), room_id, "invite", txn_id=txn_id
)
else:
+ # Check if the spamchecker(s) allow this invite to go through.
+ if not await self.spam_checker.user_may_send_3pid_invite(
+ inviter_userid=requester.user.to_string(),
+ medium=medium,
+ address=address,
+ room_id=room_id,
+ ):
+ raise SynapseError(403, "Cannot send threepid invite")
+
stream_id = await self._make_and_store_3pid_invite(
requester,
id_server,
@@ -1428,7 +1467,7 @@ class RoomMemberMasterHandler(RoomMemberHandler):
Returns: bool of whether the complexity is too great, or None
if unable to be fetched
"""
- max_complexity = self.hs.config.limit_remote_rooms.complexity
+ max_complexity = self.hs.config.server.limit_remote_rooms.complexity
complexity = await self.federation_handler.get_room_complexity(
remote_room_hosts, room_id
)
@@ -1444,7 +1483,7 @@ class RoomMemberMasterHandler(RoomMemberHandler):
Args:
room_id: The room ID to check for complexity.
"""
- max_complexity = self.hs.config.limit_remote_rooms.complexity
+ max_complexity = self.hs.config.server.limit_remote_rooms.complexity
complexity = await self.store.get_room_complexity(room_id)
return complexity["v1"] > max_complexity
@@ -1468,8 +1507,11 @@ class RoomMemberMasterHandler(RoomMemberHandler):
if len(remote_room_hosts) == 0:
raise SynapseError(404, "No known servers")
- check_complexity = self.hs.config.limit_remote_rooms.enabled
- if check_complexity and self.hs.config.limit_remote_rooms.admins_can_join:
+ check_complexity = self.hs.config.server.limit_remote_rooms.enabled
+ if (
+ check_complexity
+ and self.hs.config.server.limit_remote_rooms.admins_can_join
+ ):
check_complexity = not await self.auth.is_server_admin(user)
if check_complexity:
@@ -1480,7 +1522,7 @@ class RoomMemberMasterHandler(RoomMemberHandler):
if too_complex is True:
raise SynapseError(
code=400,
- msg=self.hs.config.limit_remote_rooms.complexity_error,
+ msg=self.hs.config.server.limit_remote_rooms.complexity_error,
errcode=Codes.RESOURCE_LIMIT_EXCEEDED,
)
@@ -1515,7 +1557,7 @@ class RoomMemberMasterHandler(RoomMemberHandler):
)
raise SynapseError(
code=400,
- msg=self.hs.config.limit_remote_rooms.complexity_error,
+ msg=self.hs.config.server.limit_remote_rooms.complexity_error,
errcode=Codes.RESOURCE_LIMIT_EXCEEDED,
)
diff --git a/synapse/handlers/saml.py b/synapse/handlers/saml.py
index 2fed9f37..727d75a5 100644
--- a/synapse/handlers/saml.py
+++ b/synapse/handlers/saml.py
@@ -22,7 +22,6 @@ from saml2.client import Saml2Client
from synapse.api.errors import SynapseError
from synapse.config import ConfigError
-from synapse.handlers._base import BaseHandler
from synapse.handlers.sso import MappingException, UserAttributes
from synapse.http.servlet import parse_string
from synapse.http.site import SynapseRequest
@@ -51,9 +50,11 @@ class Saml2SessionData:
ui_auth_session_id: Optional[str] = None
-class SamlHandler(BaseHandler):
+class SamlHandler:
def __init__(self, hs: "HomeServer"):
- super().__init__(hs)
+ self.store = hs.get_datastore()
+ self.clock = hs.get_clock()
+ self.server_name = hs.hostname
self._saml_client = Saml2Client(hs.config.saml2.saml2_sp_config)
self._saml_idp_entityid = hs.config.saml2.saml2_idp_entityid
diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py
index 8226d6f5..a3ffa26b 100644
--- a/synapse/handlers/search.py
+++ b/synapse/handlers/search.py
@@ -26,17 +26,18 @@ from synapse.storage.state import StateFilter
from synapse.types import JsonDict, UserID
from synapse.visibility import filter_events_for_client
-from ._base import BaseHandler
-
if TYPE_CHECKING:
from synapse.server import HomeServer
logger = logging.getLogger(__name__)
-class SearchHandler(BaseHandler):
+class SearchHandler:
def __init__(self, hs: "HomeServer"):
- super().__init__(hs)
+ self.store = hs.get_datastore()
+ self.state_handler = hs.get_state_handler()
+ self.clock = hs.get_clock()
+ self.hs = hs
self._event_serializer = hs.get_event_client_serializer()
self.storage = hs.get_storage()
self.state_store = self.storage.state
@@ -105,7 +106,7 @@ class SearchHandler(BaseHandler):
dict to be returned to the client with results of search
"""
- if not self.hs.config.enable_search:
+ if not self.hs.config.server.enable_search:
raise SynapseError(400, "Search is disabled on this homeserver")
batch_group = None
diff --git a/synapse/handlers/send_email.py b/synapse/handlers/send_email.py
index 25e6b012..1a062a78 100644
--- a/synapse/handlers/send_email.py
+++ b/synapse/handlers/send_email.py
@@ -105,8 +105,13 @@ async def _sendmail(
# set to enable TLS.
factory = build_sender_factory(hostname=smtphost if enable_tls else None)
- # the IReactorTCP interface claims host has to be a bytes, which seems to be wrong
- reactor.connectTCP(smtphost, smtpport, factory, timeout=30, bindAddress=None) # type: ignore[arg-type]
+ reactor.connectTCP(
+ smtphost, # type: ignore[arg-type]
+ smtpport,
+ factory,
+ timeout=30,
+ bindAddress=None,
+ )
await make_deferred_yieldable(d)
diff --git a/synapse/handlers/set_password.py b/synapse/handlers/set_password.py
index a63fac82..706ad727 100644
--- a/synapse/handlers/set_password.py
+++ b/synapse/handlers/set_password.py
@@ -17,19 +17,17 @@ from typing import TYPE_CHECKING, Optional
from synapse.api.errors import Codes, StoreError, SynapseError
from synapse.types import Requester
-from ._base import BaseHandler
-
if TYPE_CHECKING:
from synapse.server import HomeServer
logger = logging.getLogger(__name__)
-class SetPasswordHandler(BaseHandler):
+class SetPasswordHandler:
"""Handler which deals with changing user account passwords"""
def __init__(self, hs: "HomeServer"):
- super().__init__(hs)
+ self.store = hs.get_datastore()
self._auth_handler = hs.get_auth_handler()
self._device_handler = hs.get_device_handler()
diff --git a/synapse/handlers/ui_auth/checkers.py b/synapse/handlers/ui_auth/checkers.py
index 8f5d465f..184730eb 100644
--- a/synapse/handlers/ui_auth/checkers.py
+++ b/synapse/handlers/ui_auth/checkers.py
@@ -153,21 +153,23 @@ class _BaseThreepidAuthChecker:
# msisdns are currently always ThreepidBehaviour.REMOTE
if medium == "msisdn":
- if not self.hs.config.account_threepid_delegate_msisdn:
+ if not self.hs.config.registration.account_threepid_delegate_msisdn:
raise SynapseError(
400, "Phone number verification is not enabled on this homeserver"
)
threepid = await identity_handler.threepid_from_creds(
- self.hs.config.account_threepid_delegate_msisdn, threepid_creds
+ self.hs.config.registration.account_threepid_delegate_msisdn,
+ threepid_creds,
)
elif medium == "email":
if (
self.hs.config.email.threepid_behaviour_email
== ThreepidBehaviour.REMOTE
):
- assert self.hs.config.account_threepid_delegate_email
+ assert self.hs.config.registration.account_threepid_delegate_email
threepid = await identity_handler.threepid_from_creds(
- self.hs.config.account_threepid_delegate_email, threepid_creds
+ self.hs.config.registration.account_threepid_delegate_email,
+ threepid_creds,
)
elif (
self.hs.config.email.threepid_behaviour_email == ThreepidBehaviour.LOCAL
@@ -240,7 +242,7 @@ class MsisdnAuthChecker(UserInteractiveAuthChecker, _BaseThreepidAuthChecker):
_BaseThreepidAuthChecker.__init__(self, hs)
def is_enabled(self) -> bool:
- return bool(self.hs.config.account_threepid_delegate_msisdn)
+ return bool(self.hs.config.registration.account_threepid_delegate_msisdn)
async def check_auth(self, authdict: dict, clientip: str) -> Any:
return await self._check_threepid("msisdn", authdict)
@@ -252,7 +254,7 @@ class RegistrationTokenAuthChecker(UserInteractiveAuthChecker):
def __init__(self, hs: "HomeServer"):
super().__init__(hs)
self.hs = hs
- self._enabled = bool(hs.config.registration_requires_token)
+ self._enabled = bool(hs.config.registration.registration_requires_token)
self.store = hs.get_datastore()
def is_enabled(self) -> bool:
diff --git a/synapse/handlers/user_directory.py b/synapse/handlers/user_directory.py
index b91e7cb5..8810f048 100644
--- a/synapse/handlers/user_directory.py
+++ b/synapse/handlers/user_directory.py
@@ -60,7 +60,7 @@ class UserDirectoryHandler(StateDeltasHandler):
self.clock = hs.get_clock()
self.notifier = hs.get_notifier()
self.is_mine_id = hs.is_mine_id
- self.update_user_directory = hs.config.update_user_directory
+ self.update_user_directory = hs.config.server.update_user_directory
self.search_all_users = hs.config.userdirectory.user_directory_search_all_users
self.spam_checker = hs.get_spam_checker()
# The current position in the current_state_delta stream
@@ -132,12 +132,7 @@ class UserDirectoryHandler(StateDeltasHandler):
# FIXME(#3714): We should probably do this in the same worker as all
# the other changes.
- # Support users are for diagnostics and should not appear in the user directory.
- is_support = await self.store.is_support_user(user_id)
- # When change profile information of deactivated user it should not appear in the user directory.
- is_deactivated = await self.store.get_user_deactivated_status(user_id)
-
- if not (is_support or is_deactivated):
+ if await self.store.should_include_local_user_in_dir(user_id):
await self.store.update_profile_in_user_dir(
user_id, profile.display_name, profile.avatar_url
)
@@ -208,6 +203,7 @@ class UserDirectoryHandler(StateDeltasHandler):
public_value=Membership.JOIN,
)
+ is_remote = not self.is_mine_id(state_key)
if change is MatchChange.now_false:
# Need to check if the server left the room entirely, if so
# we might need to remove all the users in that room
@@ -225,32 +221,36 @@ class UserDirectoryHandler(StateDeltasHandler):
for user_id in user_ids:
await self._handle_remove_user(room_id, user_id)
- return
+ continue
else:
logger.debug("Server is still in room: %r", room_id)
- is_support = await self.store.is_support_user(state_key)
- if not is_support:
+ include_in_dir = (
+ is_remote
+ or await self.store.should_include_local_user_in_dir(state_key)
+ )
+ if include_in_dir:
if change is MatchChange.no_change:
- # Handle any profile changes
- await self._handle_profile_change(
- state_key, room_id, prev_event_id, event_id
- )
+ # Handle any profile changes for remote users.
+ # (For local users we are not forced to scan membership
+ # events; instead the rest of the application calls
+ # `handle_local_profile_change`.)
+ if is_remote:
+ await self._handle_profile_change(
+ state_key, room_id, prev_event_id, event_id
+ )
continue
if change is MatchChange.now_true: # The user joined
- event = await self.store.get_event(event_id, allow_none=True)
- # It isn't expected for this event to not exist, but we
- # don't want the entire background process to break.
- if event is None:
- continue
-
- profile = ProfileInfo(
- avatar_url=event.content.get("avatar_url"),
- display_name=event.content.get("displayname"),
- )
-
- await self._handle_new_user(room_id, state_key, profile)
+ # This may be the first time we've seen a remote user. If
+ # so, ensure we have a directory entry for them. (We don't
+ # need to do this for local users: their directory entry
+                        # is created at the point of registration.)
+ if is_remote:
+ await self._upsert_directory_entry_for_remote_user(
+ state_key, event_id
+ )
+ await self._track_user_joined_room(room_id, state_key)
else: # The user left
await self._handle_remove_user(room_id, state_key)
else:
@@ -300,7 +300,7 @@ class UserDirectoryHandler(StateDeltasHandler):
room_id
)
- logger.debug("Change: %r, publicness: %r", publicness, is_public)
+ logger.debug("Publicness change: %r, is_public: %r", publicness, is_public)
if publicness is MatchChange.now_true and not is_public:
# If we became world readable but room isn't currently public then
@@ -311,42 +311,50 @@ class UserDirectoryHandler(StateDeltasHandler):
# ignore the change
return
- other_users_in_room_with_profiles = (
- await self.store.get_users_in_room_with_profiles(room_id)
- )
+ users_in_room = await self.store.get_users_in_room(room_id)
# Remove every user from the sharing tables for that room.
- for user_id in other_users_in_room_with_profiles.keys():
+ for user_id in users_in_room:
await self.store.remove_user_who_share_room(user_id, room_id)
# Then, re-add them to the tables.
- # NOTE: this is not the most efficient method, as handle_new_user sets
+ # NOTE: this is not the most efficient method, as _track_user_joined_room sets
# up local_user -> other_user and other_user_whos_local -> local_user,
# which when ran over an entire room, will result in the same values
# being added multiple times. The batching upserts shouldn't make this
# too bad, though.
- for user_id, profile in other_users_in_room_with_profiles.items():
- await self._handle_new_user(room_id, user_id, profile)
+ for user_id in users_in_room:
+ await self._track_user_joined_room(room_id, user_id)
- async def _handle_new_user(
- self, room_id: str, user_id: str, profile: ProfileInfo
+ async def _upsert_directory_entry_for_remote_user(
+ self, user_id: str, event_id: str
) -> None:
- """Called when we might need to add user to directory
-
- Args:
- room_id: The room ID that user joined or started being public
- user_id
+ """A remote user has just joined a room. Ensure they have an entry in
+ the user directory. The caller is responsible for making sure they're
+ remote.
"""
+ event = await self.store.get_event(event_id, allow_none=True)
+ # It isn't expected for this event to not exist, but we
+ # don't want the entire background process to break.
+ if event is None:
+ return
+
logger.debug("Adding new user to dir, %r", user_id)
await self.store.update_profile_in_user_dir(
- user_id, profile.display_name, profile.avatar_url
+ user_id, event.content.get("displayname"), event.content.get("avatar_url")
)
+ async def _track_user_joined_room(self, room_id: str, user_id: str) -> None:
+ """Someone's just joined a room. Update `users_in_public_rooms` or
+ `users_who_share_private_rooms` as appropriate.
+
+ The caller is responsible for ensuring that the given user is not excluded
+ from the user directory.
+ """
is_public = await self.store.is_room_world_readable_or_publicly_joinable(
room_id
)
- # Now we update users who share rooms with users.
other_users_in_room = await self.store.get_users_in_room(room_id)
if is_public:
@@ -356,13 +364,7 @@ class UserDirectoryHandler(StateDeltasHandler):
# First, if they're our user then we need to update for every user
if self.is_mine_id(user_id):
-
- is_appservice = self.store.get_if_app_services_interested_in_user(
- user_id
- )
-
- # We don't care about appservice users.
- if not is_appservice:
+ if await self.store.should_include_local_user_in_dir(user_id):
for other_user_id in other_users_in_room:
if user_id == other_user_id:
continue
@@ -374,10 +376,10 @@ class UserDirectoryHandler(StateDeltasHandler):
if user_id == other_user_id:
continue
- is_appservice = self.store.get_if_app_services_interested_in_user(
+ include_other_user = self.is_mine_id(
other_user_id
- )
- if self.is_mine_id(other_user_id) and not is_appservice:
+ ) and await self.store.should_include_local_user_in_dir(other_user_id)
+ if include_other_user:
to_insert.add((other_user_id, user_id))
if to_insert:
diff --git a/synapse/http/client.py b/synapse/http/client.py
index 5204c3d0..b5a2d333 100644
--- a/synapse/http/client.py
+++ b/synapse/http/client.py
@@ -912,7 +912,7 @@ class InsecureInterceptableContextFactory(ssl.ContextFactory):
def __init__(self):
self._context = SSL.Context(SSL.SSLv23_METHOD)
- self._context.set_verify(VERIFY_NONE, lambda *_: None)
+ self._context.set_verify(VERIFY_NONE, lambda *_: False)
def getContext(self, hostname=None, port=None):
return self._context
diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py
index cdc36b8d..4f592246 100644
--- a/synapse/http/matrixfederationclient.py
+++ b/synapse/http/matrixfederationclient.py
@@ -327,23 +327,23 @@ class MatrixFederationHttpClient:
self.reactor = hs.get_reactor()
user_agent = hs.version_string
- if hs.config.user_agent_suffix:
- user_agent = "%s %s" % (user_agent, hs.config.user_agent_suffix)
+ if hs.config.server.user_agent_suffix:
+ user_agent = "%s %s" % (user_agent, hs.config.server.user_agent_suffix)
user_agent = user_agent.encode("ascii")
federation_agent = MatrixFederationAgent(
self.reactor,
tls_client_options_factory,
user_agent,
- hs.config.federation_ip_range_whitelist,
- hs.config.federation_ip_range_blacklist,
+ hs.config.server.federation_ip_range_whitelist,
+ hs.config.server.federation_ip_range_blacklist,
)
# Use a BlacklistingAgentWrapper to prevent circumventing the IP
# blacklist via IP literals in server names
self.agent = BlacklistingAgentWrapper(
federation_agent,
- ip_blacklist=hs.config.federation_ip_range_blacklist,
+ ip_blacklist=hs.config.server.federation_ip_range_blacklist,
)
self.clock = hs.get_clock()
diff --git a/synapse/http/server.py b/synapse/http/server.py
index 0df1bfbe..897ba5e4 100644
--- a/synapse/http/server.py
+++ b/synapse/http/server.py
@@ -563,7 +563,10 @@ class _ByteProducer:
try:
self._request.registerProducer(self, True)
- except RuntimeError as e:
+ except AttributeError as e:
+ # Calling self._request.registerProducer might raise an AttributeError since
+ # the underlying Twisted code calls self._request.channel.registerProducer,
+ # however self._request.channel will be None if the connection was lost.
logger.info("Connection disconnected before response was written: %r", e)
# We drop our references to data we'll not use.
diff --git a/synapse/logging/_terse_json.py b/synapse/logging/_terse_json.py
index 6e82f7c7..b78d6e17 100644
--- a/synapse/logging/_terse_json.py
+++ b/synapse/logging/_terse_json.py
@@ -65,6 +65,12 @@ class JsonFormatter(logging.Formatter):
if key not in _IGNORED_LOG_RECORD_ATTRIBUTES:
event[key] = value
+ if record.exc_info:
+ exc_type, exc_value, _ = record.exc_info
+ if exc_type:
+ event["exc_type"] = f"{exc_type.__name__}"
+ event["exc_value"] = f"{exc_value}"
+
return _encoder.encode(event)
diff --git a/synapse/logging/context.py b/synapse/logging/context.py
index 02e5ddd2..bdc01877 100644
--- a/synapse/logging/context.py
+++ b/synapse/logging/context.py
@@ -52,7 +52,7 @@ try:
is_thread_resource_usage_supported = True
- def get_thread_resource_usage() -> "Optional[resource._RUsage]":
+ def get_thread_resource_usage() -> "Optional[resource.struct_rusage]":
return resource.getrusage(RUSAGE_THREAD)
@@ -61,7 +61,7 @@ except Exception:
# won't track resource usage.
is_thread_resource_usage_supported = False
- def get_thread_resource_usage() -> "Optional[resource._RUsage]":
+ def get_thread_resource_usage() -> "Optional[resource.struct_rusage]":
return None
@@ -226,10 +226,10 @@ class _Sentinel:
def copy_to(self, record):
pass
- def start(self, rusage: "Optional[resource._RUsage]"):
+ def start(self, rusage: "Optional[resource.struct_rusage]"):
pass
- def stop(self, rusage: "Optional[resource._RUsage]"):
+ def stop(self, rusage: "Optional[resource.struct_rusage]"):
pass
def add_database_transaction(self, duration_sec):
@@ -289,7 +289,7 @@ class LoggingContext:
# The thread resource usage when the logcontext became active. None
# if the context is not currently active.
- self.usage_start: Optional[resource._RUsage] = None
+ self.usage_start: Optional[resource.struct_rusage] = None
self.main_thread = get_thread_id()
self.request = None
@@ -410,7 +410,7 @@ class LoggingContext:
# we also track the current scope:
record.scope = self.scope
- def start(self, rusage: "Optional[resource._RUsage]") -> None:
+ def start(self, rusage: "Optional[resource.struct_rusage]") -> None:
"""
Record that this logcontext is currently running.
@@ -435,7 +435,7 @@ class LoggingContext:
else:
self.usage_start = rusage
- def stop(self, rusage: "Optional[resource._RUsage]") -> None:
+ def stop(self, rusage: "Optional[resource.struct_rusage]") -> None:
"""
Record that this logcontext is no longer running.
@@ -490,7 +490,7 @@ class LoggingContext:
return res
- def _get_cputime(self, current: "resource._RUsage") -> Tuple[float, float]:
+ def _get_cputime(self, current: "resource.struct_rusage") -> Tuple[float, float]:
"""Get the cpu usage time between start() and the given rusage
Args:
diff --git a/synapse/logging/opentracing.py b/synapse/logging/opentracing.py
index 03d2dd94..20d23a42 100644
--- a/synapse/logging/opentracing.py
+++ b/synapse/logging/opentracing.py
@@ -339,6 +339,7 @@ def ensure_active_span(message, ret=None):
"There was no active span when trying to %s."
" Did you forget to start one or did a context slip?",
message,
+ stack_info=True,
)
return ret
@@ -806,6 +807,14 @@ def trace(func=None, opname=None):
result.addCallbacks(call_back, err_back)
else:
+ if inspect.isawaitable(result):
+ logger.error(
+ "@trace may not have wrapped %s correctly! "
+ "The function is not async but returned a %s.",
+ func.__qualname__,
+ type(result).__name__,
+ )
+
scope.__exit__(None, None, None)
return result
diff --git a/synapse/metrics/background_process_metrics.py b/synapse/metrics/background_process_metrics.py
index 3a142607..2ab599a3 100644
--- a/synapse/metrics/background_process_metrics.py
+++ b/synapse/metrics/background_process_metrics.py
@@ -265,7 +265,7 @@ class BackgroundProcessLoggingContext(LoggingContext):
super().__init__("%s-%s" % (name, instance_id))
self._proc = _BackgroundProcess(name, self)
- def start(self, rusage: "Optional[resource._RUsage]"):
+ def start(self, rusage: "Optional[resource.struct_rusage]"):
"""Log context has started running (again)."""
super().start(rusage)
diff --git a/synapse/push/__init__.py b/synapse/push/__init__.py
index 2c23afe8..820f6f3f 100644
--- a/synapse/push/__init__.py
+++ b/synapse/push/__init__.py
@@ -94,7 +94,7 @@ class Pusher(metaclass=abc.ABCMeta):
self._start_processing()
@abc.abstractmethod
- def _start_processing(self):
+ def _start_processing(self) -> None:
"""Start processing push notifications."""
raise NotImplementedError()
diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py
index c337e530..0622a37a 100644
--- a/synapse/push/bulk_push_rule_evaluator.py
+++ b/synapse/push/bulk_push_rule_evaluator.py
@@ -290,6 +290,12 @@ def _condition_checker(
return True
+MemberMap = Dict[str, Tuple[str, str]]
+Rule = Dict[str, dict]
+RulesByUser = Dict[str, List[Rule]]
+StateGroup = Union[object, int]
+
+
@attr.s(slots=True)
class RulesForRoomData:
"""The data stored in the cache by `RulesForRoom`.
@@ -299,16 +305,16 @@ class RulesForRoomData:
"""
# event_id -> (user_id, state)
- member_map = attr.ib(type=Dict[str, Tuple[str, str]], factory=dict)
+ member_map = attr.ib(type=MemberMap, factory=dict)
# user_id -> rules
- rules_by_user = attr.ib(type=Dict[str, List[Dict[str, dict]]], factory=dict)
+ rules_by_user = attr.ib(type=RulesByUser, factory=dict)
# The last state group we updated the caches for. If the state_group of
# a new event comes along, we know that we can just return the cached
# result.
# On invalidation of the rules themselves (if the user changes them),
# we invalidate everything and set state_group to `object()`
- state_group = attr.ib(type=Union[object, int], factory=object)
+ state_group = attr.ib(type=StateGroup, factory=object)
# A sequence number to keep track of when we're allowed to update the
# cache. We bump the sequence number when we invalidate the cache. If
@@ -532,7 +538,13 @@ class RulesForRoom:
self.update_cache(sequence, members, ret_rules_by_user, state_group)
- def update_cache(self, sequence, members, rules_by_user, state_group) -> None:
+ def update_cache(
+ self,
+ sequence: int,
+ members: MemberMap,
+ rules_by_user: RulesByUser,
+ state_group: StateGroup,
+ ) -> None:
if sequence == self.data.sequence:
self.data.member_map.update(members)
self.data.rules_by_user = rules_by_user
diff --git a/synapse/push/clientformat.py b/synapse/push/clientformat.py
index 1fc9716a..c5708cd8 100644
--- a/synapse/push/clientformat.py
+++ b/synapse/push/clientformat.py
@@ -19,7 +19,9 @@ from synapse.push.rulekinds import PRIORITY_CLASS_INVERSE_MAP, PRIORITY_CLASS_MA
from synapse.types import UserID
-def format_push_rules_for_user(user: UserID, ruleslist) -> Dict[str, Dict[str, list]]:
+def format_push_rules_for_user(
+ user: UserID, ruleslist: List
+) -> Dict[str, Dict[str, list]]:
"""Converts a list of rawrules and a enabled map into nested dictionaries
to match the Matrix client-server format for push rules"""
diff --git a/synapse/push/httppusher.py b/synapse/push/httppusher.py
index eac65572..dbf4ad7f 100644
--- a/synapse/push/httppusher.py
+++ b/synapse/push/httppusher.py
@@ -403,10 +403,10 @@ class HttpPusher(Pusher):
rejected = resp["rejected"]
return rejected
- async def _send_badge(self, badge):
+ async def _send_badge(self, badge: int) -> None:
"""
Args:
- badge (int): number of unread messages
+ badge: number of unread messages
"""
logger.debug("Sending updated badge count %d to %s", badge, self.name)
d = {
diff --git a/synapse/push/mailer.py b/synapse/push/mailer.py
index e38e3c5d..ce299ba3 100644
--- a/synapse/push/mailer.py
+++ b/synapse/push/mailer.py
@@ -892,7 +892,7 @@ def safe_text(raw_text: str) -> jinja2.Markup:
A Markup object ready to safely use in a Jinja template.
"""
return jinja2.Markup(
- bleach.linkify(bleach.clean(raw_text, tags=[], attributes={}, strip=False))
+ bleach.linkify(bleach.clean(raw_text, tags=[], attributes=[], strip=False))
)
diff --git a/synapse/replication/http/_base.py b/synapse/replication/http/_base.py
index f1b78d09..e047ec74 100644
--- a/synapse/replication/http/_base.py
+++ b/synapse/replication/http/_base.py
@@ -182,85 +182,87 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
)
@trace(opname="outgoing_replication_request")
- @outgoing_gauge.track_inprogress()
async def send_request(*, instance_name="master", **kwargs):
- if instance_name == local_instance_name:
- raise Exception("Trying to send HTTP request to self")
- if instance_name == "master":
- host = master_host
- port = master_port
- elif instance_name in instance_map:
- host = instance_map[instance_name].host
- port = instance_map[instance_name].port
- else:
- raise Exception(
- "Instance %r not in 'instance_map' config" % (instance_name,)
+ with outgoing_gauge.track_inprogress():
+ if instance_name == local_instance_name:
+ raise Exception("Trying to send HTTP request to self")
+ if instance_name == "master":
+ host = master_host
+ port = master_port
+ elif instance_name in instance_map:
+ host = instance_map[instance_name].host
+ port = instance_map[instance_name].port
+ else:
+ raise Exception(
+ "Instance %r not in 'instance_map' config" % (instance_name,)
+ )
+
+ data = await cls._serialize_payload(**kwargs)
+
+ url_args = [
+ urllib.parse.quote(kwargs[name], safe="") for name in cls.PATH_ARGS
+ ]
+
+ if cls.CACHE:
+ txn_id = random_string(10)
+ url_args.append(txn_id)
+
+ if cls.METHOD == "POST":
+ request_func = client.post_json_get_json
+ elif cls.METHOD == "PUT":
+ request_func = client.put_json
+ elif cls.METHOD == "GET":
+ request_func = client.get_json
+ else:
+ # We have already asserted in the constructor that a
+ # compatible was picked, but lets be paranoid.
+ raise Exception(
+ "Unknown METHOD on %s replication endpoint" % (cls.NAME,)
+ )
+
+ uri = "http://%s:%s/_synapse/replication/%s/%s" % (
+ host,
+ port,
+ cls.NAME,
+ "/".join(url_args),
)
- data = await cls._serialize_payload(**kwargs)
-
- url_args = [
- urllib.parse.quote(kwargs[name], safe="") for name in cls.PATH_ARGS
- ]
-
- if cls.CACHE:
- txn_id = random_string(10)
- url_args.append(txn_id)
-
- if cls.METHOD == "POST":
- request_func = client.post_json_get_json
- elif cls.METHOD == "PUT":
- request_func = client.put_json
- elif cls.METHOD == "GET":
- request_func = client.get_json
- else:
- # We have already asserted in the constructor that a
- # compatible was picked, but lets be paranoid.
- raise Exception(
- "Unknown METHOD on %s replication endpoint" % (cls.NAME,)
- )
-
- uri = "http://%s:%s/_synapse/replication/%s/%s" % (
- host,
- port,
- cls.NAME,
- "/".join(url_args),
- )
-
- try:
- # We keep retrying the same request for timeouts. This is so that we
- # have a good idea that the request has either succeeded or failed on
- # the master, and so whether we should clean up or not.
- while True:
- headers: Dict[bytes, List[bytes]] = {}
- # Add an authorization header, if configured.
- if replication_secret:
- headers[b"Authorization"] = [b"Bearer " + replication_secret]
- opentracing.inject_header_dict(headers, check_destination=False)
- try:
- result = await request_func(uri, data, headers=headers)
- break
- except RequestTimedOutError:
- if not cls.RETRY_ON_TIMEOUT:
- raise
-
- logger.warning("%s request timed out; retrying", cls.NAME)
-
- # If we timed out we probably don't need to worry about backing
- # off too much, but lets just wait a little anyway.
- await clock.sleep(1)
- except HttpResponseException as e:
- # We convert to SynapseError as we know that it was a SynapseError
- # on the main process that we should send to the client. (And
- # importantly, not stack traces everywhere)
- _outgoing_request_counter.labels(cls.NAME, e.code).inc()
- raise e.to_synapse_error()
- except Exception as e:
- _outgoing_request_counter.labels(cls.NAME, "ERR").inc()
- raise SynapseError(502, "Failed to talk to main process") from e
-
- _outgoing_request_counter.labels(cls.NAME, 200).inc()
- return result
+ try:
+ # We keep retrying the same request for timeouts. This is so that we
+ # have a good idea that the request has either succeeded or failed
+ # on the master, and so whether we should clean up or not.
+ while True:
+ headers: Dict[bytes, List[bytes]] = {}
+ # Add an authorization header, if configured.
+ if replication_secret:
+ headers[b"Authorization"] = [
+ b"Bearer " + replication_secret
+ ]
+ opentracing.inject_header_dict(headers, check_destination=False)
+ try:
+ result = await request_func(uri, data, headers=headers)
+ break
+ except RequestTimedOutError:
+ if not cls.RETRY_ON_TIMEOUT:
+ raise
+
+ logger.warning("%s request timed out; retrying", cls.NAME)
+
+ # If we timed out we probably don't need to worry about backing
+ # off too much, but lets just wait a little anyway.
+ await clock.sleep(1)
+ except HttpResponseException as e:
+ # We convert to SynapseError as we know that it was a SynapseError
+ # on the main process that we should send to the client. (And
+ # importantly, not stack traces everywhere)
+ _outgoing_request_counter.labels(cls.NAME, e.code).inc()
+ raise e.to_synapse_error()
+ except Exception as e:
+ _outgoing_request_counter.labels(cls.NAME, "ERR").inc()
+ raise SynapseError(502, "Failed to talk to main process") from e
+
+ _outgoing_request_counter.labels(cls.NAME, 200).inc()
+ return result
return send_request
diff --git a/synapse/replication/slave/storage/_slaved_id_tracker.py b/synapse/replication/slave/storage/_slaved_id_tracker.py
index 2cb74890..8c1bf922 100644
--- a/synapse/replication/slave/storage/_slaved_id_tracker.py
+++ b/synapse/replication/slave/storage/_slaved_id_tracker.py
@@ -13,14 +13,14 @@
# limitations under the License.
from typing import List, Optional, Tuple
-from synapse.storage.types import Connection
+from synapse.storage.database import LoggingDatabaseConnection
from synapse.storage.util.id_generators import _load_current_id
class SlavedIdTracker:
def __init__(
self,
- db_conn: Connection,
+ db_conn: LoggingDatabaseConnection,
table: str,
column: str,
extra_tables: Optional[List[Tuple[str, str]]] = None,
diff --git a/synapse/replication/slave/storage/pushers.py b/synapse/replication/slave/storage/pushers.py
index 2672a2c9..cea90c0f 100644
--- a/synapse/replication/slave/storage/pushers.py
+++ b/synapse/replication/slave/storage/pushers.py
@@ -15,9 +15,8 @@
from typing import TYPE_CHECKING
from synapse.replication.tcp.streams import PushersStream
-from synapse.storage.database import DatabasePool
+from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
from synapse.storage.databases.main.pusher import PusherWorkerStore
-from synapse.storage.types import Connection
from ._base import BaseSlavedStore
from ._slaved_id_tracker import SlavedIdTracker
@@ -27,7 +26,12 @@ if TYPE_CHECKING:
class SlavedPusherStore(PusherWorkerStore, BaseSlavedStore):
- def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
+ def __init__(
+ self,
+ database: DatabasePool,
+ db_conn: LoggingDatabaseConnection,
+ hs: "HomeServer",
+ ):
super().__init__(database, db_conn, hs)
self._pushers_id_gen = SlavedIdTracker( # type: ignore
db_conn, "pushers", "id", extra_tables=[("deleted_pushers", "stream_id")]
diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py
index 37769ace..961c1776 100644
--- a/synapse/replication/tcp/client.py
+++ b/synapse/replication/tcp/client.py
@@ -117,7 +117,7 @@ class ReplicationDataHandler:
self._instance_name = hs.get_instance_name()
self._typing_handler = hs.get_typing_handler()
- self._notify_pushers = hs.config.start_pushers
+ self._notify_pushers = hs.config.worker.start_pushers
self._pusher_pool = hs.get_pusherpool()
self._presence_handler = hs.get_presence_handler()
diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py
index 1438a82b..6aa93180 100644
--- a/synapse/replication/tcp/handler.py
+++ b/synapse/replication/tcp/handler.py
@@ -171,7 +171,10 @@ class ReplicationCommandHandler:
if hs.config.worker.worker_app is not None:
continue
- if stream.NAME == FederationStream.NAME and hs.config.send_federation:
+ if (
+ stream.NAME == FederationStream.NAME
+ and hs.config.worker.send_federation
+ ):
# We only support federation stream if federation sending
# has been disabled on the master.
continue
@@ -225,7 +228,7 @@ class ReplicationCommandHandler:
self._is_master = hs.config.worker.worker_app is None
self._federation_sender = None
- if self._is_master and not hs.config.send_federation:
+ if self._is_master and not hs.config.worker.send_federation:
self._federation_sender = hs.get_federation_sender()
self._server_notices_sender = None
@@ -315,7 +318,7 @@ class ReplicationCommandHandler:
hs, outbound_redis_connection
)
hs.get_reactor().connectTCP(
- hs.config.redis.redis_host.encode(),
+ hs.config.redis.redis_host, # type: ignore[arg-type]
hs.config.redis.redis_port,
self._factory,
)
@@ -324,7 +327,11 @@ class ReplicationCommandHandler:
self._factory = DirectTcpReplicationClientFactory(hs, client_name, self)
host = hs.config.worker.worker_replication_host
port = hs.config.worker.worker_replication_port
- hs.get_reactor().connectTCP(host.encode(), port, self._factory)
+ hs.get_reactor().connectTCP(
+ host, # type: ignore[arg-type]
+ port,
+ self._factory,
+ )
def get_streams(self) -> Dict[str, Stream]:
"""Get a map from stream name to all streams."""
diff --git a/synapse/replication/tcp/redis.py b/synapse/replication/tcp/redis.py
index 8c0df627..062fe2f3 100644
--- a/synapse/replication/tcp/redis.py
+++ b/synapse/replication/tcp/redis.py
@@ -364,6 +364,12 @@ def lazyConnection(
factory.continueTrying = reconnect
reactor = hs.get_reactor()
- reactor.connectTCP(host.encode(), port, factory, timeout=30, bindAddress=None)
+ reactor.connectTCP(
+ host, # type: ignore[arg-type]
+ port,
+ factory,
+ timeout=30,
+ bindAddress=None,
+ )
return factory.handler
diff --git a/synapse/replication/tcp/resource.py b/synapse/replication/tcp/resource.py
index 030852cb..80f9b23b 100644
--- a/synapse/replication/tcp/resource.py
+++ b/synapse/replication/tcp/resource.py
@@ -71,7 +71,7 @@ class ReplicationStreamer:
self.notifier = hs.get_notifier()
self._instance_name = hs.get_instance_name()
- self._replication_torture_level = hs.config.replication_torture_level
+ self._replication_torture_level = hs.config.server.replication_torture_level
self.notifier.add_replication_callback(self.on_notifier_poke)
diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py
index 46bfec46..f20aa653 100644
--- a/synapse/rest/admin/users.py
+++ b/synapse/rest/admin/users.py
@@ -442,7 +442,7 @@ class UserRegisterServlet(RestServlet):
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
self._clear_old_nonces()
- if not self.hs.config.registration_shared_secret:
+ if not self.hs.config.registration.registration_shared_secret:
raise SynapseError(400, "Shared secret registration is not enabled")
body = parse_json_object_from_request(request)
@@ -498,7 +498,7 @@ class UserRegisterServlet(RestServlet):
got_mac = body["mac"]
want_mac_builder = hmac.new(
- key=self.hs.config.registration_shared_secret.encode(),
+ key=self.hs.config.registration.registration_shared_secret.encode(),
digestmod=hashlib.sha1,
)
want_mac_builder.update(nonce.encode("utf8"))
diff --git a/synapse/rest/client/account.py b/synapse/rest/client/account.py
index 6a7608d6..6b272658 100644
--- a/synapse/rest/client/account.py
+++ b/synapse/rest/client/account.py
@@ -119,7 +119,7 @@ class EmailPasswordRequestTokenRestServlet(RestServlet):
)
if existing_user_id is None:
- if self.config.request_token_inhibit_3pid_errors:
+ if self.config.server.request_token_inhibit_3pid_errors:
# Make the client think the operation succeeded. See the rationale in the
# comments for request_token_inhibit_3pid_errors.
# Also wait for some random amount of time between 100ms and 1s to make it
@@ -130,11 +130,11 @@ class EmailPasswordRequestTokenRestServlet(RestServlet):
raise SynapseError(400, "Email not found", Codes.THREEPID_NOT_FOUND)
if self.config.email.threepid_behaviour_email == ThreepidBehaviour.REMOTE:
- assert self.hs.config.account_threepid_delegate_email
+ assert self.hs.config.registration.account_threepid_delegate_email
# Have the configured identity server handle the request
ret = await self.identity_handler.requestEmailToken(
- self.hs.config.account_threepid_delegate_email,
+ self.hs.config.registration.account_threepid_delegate_email,
email,
client_secret,
send_attempt,
@@ -403,7 +403,7 @@ class EmailThreepidRequestTokenRestServlet(RestServlet):
existing_user_id = await self.store.get_user_id_by_threepid("email", email)
if existing_user_id is not None:
- if self.config.request_token_inhibit_3pid_errors:
+ if self.config.server.request_token_inhibit_3pid_errors:
# Make the client think the operation succeeded. See the rationale in the
# comments for request_token_inhibit_3pid_errors.
# Also wait for some random amount of time between 100ms and 1s to make it
@@ -414,11 +414,11 @@ class EmailThreepidRequestTokenRestServlet(RestServlet):
raise SynapseError(400, "Email is already in use", Codes.THREEPID_IN_USE)
if self.config.email.threepid_behaviour_email == ThreepidBehaviour.REMOTE:
- assert self.hs.config.account_threepid_delegate_email
+ assert self.hs.config.registration.account_threepid_delegate_email
# Have the configured identity server handle the request
ret = await self.identity_handler.requestEmailToken(
- self.hs.config.account_threepid_delegate_email,
+ self.hs.config.registration.account_threepid_delegate_email,
email,
client_secret,
send_attempt,
@@ -486,7 +486,7 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet):
existing_user_id = await self.store.get_user_id_by_threepid("msisdn", msisdn)
if existing_user_id is not None:
- if self.hs.config.request_token_inhibit_3pid_errors:
+ if self.hs.config.server.request_token_inhibit_3pid_errors:
# Make the client think the operation succeeded. See the rationale in the
# comments for request_token_inhibit_3pid_errors.
# Also wait for some random amount of time between 100ms and 1s to make it
@@ -496,7 +496,7 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet):
raise SynapseError(400, "MSISDN is already in use", Codes.THREEPID_IN_USE)
- if not self.hs.config.account_threepid_delegate_msisdn:
+ if not self.hs.config.registration.account_threepid_delegate_msisdn:
logger.warning(
"No upstream msisdn account_threepid_delegate configured on the server to "
"handle this request"
@@ -507,7 +507,7 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet):
)
ret = await self.identity_handler.requestMsisdnToken(
- self.hs.config.account_threepid_delegate_msisdn,
+ self.hs.config.registration.account_threepid_delegate_msisdn,
country,
phone_number,
client_secret,
@@ -604,7 +604,7 @@ class AddThreepidMsisdnSubmitTokenServlet(RestServlet):
self.identity_handler = hs.get_identity_handler()
async def on_POST(self, request: Request) -> Tuple[int, JsonDict]:
- if not self.config.account_threepid_delegate_msisdn:
+ if not self.config.registration.account_threepid_delegate_msisdn:
raise SynapseError(
400,
"This homeserver is not validating phone numbers. Use an identity server "
@@ -617,7 +617,7 @@ class AddThreepidMsisdnSubmitTokenServlet(RestServlet):
# Proxy submit_token request to msisdn threepid delegate
response = await self.identity_handler.proxy_msisdn_submit_token(
- self.config.account_threepid_delegate_msisdn,
+ self.config.registration.account_threepid_delegate_msisdn,
body["client_secret"],
body["sid"],
body["token"],
@@ -644,7 +644,7 @@ class ThreepidRestServlet(RestServlet):
return 200, {"threepids": threepids}
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
- if not self.hs.config.enable_3pid_changes:
+ if not self.hs.config.registration.enable_3pid_changes:
raise SynapseError(
400, "3PID changes are disabled on this server", Codes.FORBIDDEN
)
@@ -693,7 +693,7 @@ class ThreepidAddRestServlet(RestServlet):
@interactive_auth_handler
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
- if not self.hs.config.enable_3pid_changes:
+ if not self.hs.config.registration.enable_3pid_changes:
raise SynapseError(
400, "3PID changes are disabled on this server", Codes.FORBIDDEN
)
@@ -801,7 +801,7 @@ class ThreepidDeleteRestServlet(RestServlet):
self.auth_handler = hs.get_auth_handler()
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
- if not self.hs.config.enable_3pid_changes:
+ if not self.hs.config.registration.enable_3pid_changes:
raise SynapseError(
400, "3PID changes are disabled on this server", Codes.FORBIDDEN
)
@@ -857,8 +857,8 @@ def assert_valid_next_link(hs: "HomeServer", next_link: str) -> None:
# If the domain whitelist is set, the domain must be in it
if (
valid
- and hs.config.next_link_domain_whitelist is not None
- and next_link_parsed.hostname not in hs.config.next_link_domain_whitelist
+ and hs.config.server.next_link_domain_whitelist is not None
+ and next_link_parsed.hostname not in hs.config.server.next_link_domain_whitelist
):
valid = False
@@ -878,9 +878,13 @@ class WhoamiRestServlet(RestServlet):
self.auth = hs.get_auth()
async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
- requester = await self.auth.get_user_by_req(request)
+ requester = await self.auth.get_user_by_req(request, allow_guest=True)
- response = {"user_id": requester.user.to_string()}
+ response = {
+ "user_id": requester.user.to_string(),
+ # MSC: https://github.com/matrix-org/matrix-doc/pull/3069
+ "org.matrix.msc3069.is_guest": bool(requester.is_guest),
+ }
# Appservices and similar accounts do not have device IDs
# that we can report on, so exclude them for compliance.
diff --git a/synapse/rest/client/auth.py b/synapse/rest/client/auth.py
index 282861fa..9c15a043 100644
--- a/synapse/rest/client/auth.py
+++ b/synapse/rest/client/auth.py
@@ -48,9 +48,11 @@ class AuthRestServlet(RestServlet):
self.auth_handler = hs.get_auth_handler()
self.registration_handler = hs.get_registration_handler()
self.recaptcha_template = hs.config.captcha.recaptcha_template
- self.terms_template = hs.config.terms_template
- self.registration_token_template = hs.config.registration_token_template
- self.success_template = hs.config.fallback_success_template
+ self.terms_template = hs.config.consent.terms_template
+ self.registration_token_template = (
+ hs.config.registration.registration_token_template
+ )
+ self.success_template = hs.config.registration.fallback_success_template
async def on_GET(self, request: SynapseRequest, stagetype: str) -> None:
session = parse_string(request, "session")
diff --git a/synapse/rest/client/capabilities.py b/synapse/rest/client/capabilities.py
index 65b3b5ce..2a3e24ae 100644
--- a/synapse/rest/client/capabilities.py
+++ b/synapse/rest/client/capabilities.py
@@ -44,10 +44,10 @@ class CapabilitiesRestServlet(RestServlet):
await self.auth.get_user_by_req(request, allow_guest=True)
change_password = self.auth_handler.can_change_password()
- response = {
+ response: JsonDict = {
"capabilities": {
"m.room_versions": {
- "default": self.config.default_room_version.identifier,
+ "default": self.config.server.default_room_version.identifier,
"available": {
v.identifier: v.disposition
for v in KNOWN_ROOM_VERSIONS.values()
@@ -64,13 +64,13 @@ class CapabilitiesRestServlet(RestServlet):
if self.config.experimental.msc3283_enabled:
response["capabilities"]["org.matrix.msc3283.set_displayname"] = {
- "enabled": self.config.enable_set_displayname
+ "enabled": self.config.registration.enable_set_displayname
}
response["capabilities"]["org.matrix.msc3283.set_avatar_url"] = {
- "enabled": self.config.enable_set_avatar_url
+ "enabled": self.config.registration.enable_set_avatar_url
}
response["capabilities"]["org.matrix.msc3283.3pid_changes"] = {
- "enabled": self.config.enable_3pid_changes
+ "enabled": self.config.registration.enable_3pid_changes
}
return 200, response
diff --git a/synapse/rest/client/filter.py b/synapse/rest/client/filter.py
index 6ed60c74..cc1c2f97 100644
--- a/synapse/rest/client/filter.py
+++ b/synapse/rest/client/filter.py
@@ -90,7 +90,7 @@ class CreateFilterRestServlet(RestServlet):
raise AuthError(403, "Can only create filters for local users")
content = parse_json_object_from_request(request)
- set_timeline_upper_limit(content, self.hs.config.filter_timeline_limit)
+ set_timeline_upper_limit(content, self.hs.config.server.filter_timeline_limit)
filter_id = await self.filtering.add_user_filter(
user_localpart=target_user.localpart, user_filter=content
diff --git a/synapse/rest/client/login.py b/synapse/rest/client/login.py
index fa5c173f..d49a647b 100644
--- a/synapse/rest/client/login.py
+++ b/synapse/rest/client/login.py
@@ -79,7 +79,7 @@ class LoginRestServlet(RestServlet):
self.saml2_enabled = hs.config.saml2.saml2_enabled
self.cas_enabled = hs.config.cas.cas_enabled
self.oidc_enabled = hs.config.oidc.oidc_enabled
- self._msc2918_enabled = hs.config.access_token_lifetime is not None
+ self._msc2918_enabled = hs.config.registration.access_token_lifetime is not None
self.auth = hs.get_auth()
@@ -447,7 +447,7 @@ class RefreshTokenServlet(RestServlet):
def __init__(self, hs: "HomeServer"):
self._auth_handler = hs.get_auth_handler()
self._clock = hs.get_clock()
- self.access_token_lifetime = hs.config.access_token_lifetime
+ self.access_token_lifetime = hs.config.registration.access_token_lifetime
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
refresh_submission = parse_json_object_from_request(request)
@@ -556,7 +556,7 @@ class CasTicketServlet(RestServlet):
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
LoginRestServlet(hs).register(http_server)
- if hs.config.access_token_lifetime is not None:
+ if hs.config.registration.access_token_lifetime is not None:
RefreshTokenServlet(hs).register(http_server)
SsoRedirectServlet(hs).register(http_server)
if hs.config.cas.cas_enabled:
diff --git a/synapse/rest/client/profile.py b/synapse/rest/client/profile.py
index d0f20de5..c684636c 100644
--- a/synapse/rest/client/profile.py
+++ b/synapse/rest/client/profile.py
@@ -41,7 +41,7 @@ class ProfileDisplaynameRestServlet(RestServlet):
) -> Tuple[int, JsonDict]:
requester_user = None
- if self.hs.config.require_auth_for_profile_requests:
+ if self.hs.config.server.require_auth_for_profile_requests:
requester = await self.auth.get_user_by_req(request)
requester_user = requester.user
@@ -94,7 +94,7 @@ class ProfileAvatarURLRestServlet(RestServlet):
) -> Tuple[int, JsonDict]:
requester_user = None
- if self.hs.config.require_auth_for_profile_requests:
+ if self.hs.config.server.require_auth_for_profile_requests:
requester = await self.auth.get_user_by_req(request)
requester_user = requester.user
@@ -146,7 +146,7 @@ class ProfileRestServlet(RestServlet):
) -> Tuple[int, JsonDict]:
requester_user = None
- if self.hs.config.require_auth_for_profile_requests:
+ if self.hs.config.server.require_auth_for_profile_requests:
requester = await self.auth.get_user_by_req(request)
requester_user = requester.user
diff --git a/synapse/rest/client/push_rule.py b/synapse/rest/client/push_rule.py
index ecebc46e..6f796d5e 100644
--- a/synapse/rest/client/push_rule.py
+++ b/synapse/rest/client/push_rule.py
@@ -61,7 +61,9 @@ class PushRuleRestServlet(RestServlet):
self.notifier = hs.get_notifier()
self._is_worker = hs.config.worker.worker_app is not None
- self._users_new_default_push_rules = hs.config.users_new_default_push_rules
+ self._users_new_default_push_rules = (
+ hs.config.server.users_new_default_push_rules
+ )
async def on_PUT(self, request: SynapseRequest, path: str) -> Tuple[int, JsonDict]:
if self._is_worker:
diff --git a/synapse/rest/client/register.py b/synapse/rest/client/register.py
index 48b0062c..bf3cb341 100644
--- a/synapse/rest/client/register.py
+++ b/synapse/rest/client/register.py
@@ -129,7 +129,7 @@ class EmailRegisterRequestTokenRestServlet(RestServlet):
)
if existing_user_id is not None:
- if self.hs.config.request_token_inhibit_3pid_errors:
+ if self.hs.config.server.request_token_inhibit_3pid_errors:
# Make the client think the operation succeeded. See the rationale in the
# comments for request_token_inhibit_3pid_errors.
# Also wait for some random amount of time between 100ms and 1s to make it
@@ -140,11 +140,11 @@ class EmailRegisterRequestTokenRestServlet(RestServlet):
raise SynapseError(400, "Email is already in use", Codes.THREEPID_IN_USE)
if self.config.email.threepid_behaviour_email == ThreepidBehaviour.REMOTE:
- assert self.hs.config.account_threepid_delegate_email
+ assert self.hs.config.registration.account_threepid_delegate_email
# Have the configured identity server handle the request
ret = await self.identity_handler.requestEmailToken(
- self.hs.config.account_threepid_delegate_email,
+ self.hs.config.registration.account_threepid_delegate_email,
email,
client_secret,
send_attempt,
@@ -209,7 +209,7 @@ class MsisdnRegisterRequestTokenRestServlet(RestServlet):
)
if existing_user_id is not None:
- if self.hs.config.request_token_inhibit_3pid_errors:
+ if self.hs.config.server.request_token_inhibit_3pid_errors:
# Make the client think the operation succeeded. See the rationale in the
# comments for request_token_inhibit_3pid_errors.
# Also wait for some random amount of time between 100ms and 1s to make it
@@ -221,7 +221,7 @@ class MsisdnRegisterRequestTokenRestServlet(RestServlet):
400, "Phone number is already in use", Codes.THREEPID_IN_USE
)
- if not self.hs.config.account_threepid_delegate_msisdn:
+ if not self.hs.config.registration.account_threepid_delegate_msisdn:
logger.warning(
"No upstream msisdn account_threepid_delegate configured on the server to "
"handle this request"
@@ -231,7 +231,7 @@ class MsisdnRegisterRequestTokenRestServlet(RestServlet):
)
ret = await self.identity_handler.requestMsisdnToken(
- self.hs.config.account_threepid_delegate_msisdn,
+ self.hs.config.registration.account_threepid_delegate_msisdn,
country,
phone_number,
client_secret,
@@ -341,7 +341,7 @@ class UsernameAvailabilityRestServlet(RestServlet):
)
async def on_GET(self, request: Request) -> Tuple[int, JsonDict]:
- if not self.hs.config.enable_registration:
+ if not self.hs.config.registration.enable_registration:
raise SynapseError(
403, "Registration has been disabled", errcode=Codes.FORBIDDEN
)
@@ -391,7 +391,7 @@ class RegistrationTokenValidityRestServlet(RestServlet):
async def on_GET(self, request: Request) -> Tuple[int, JsonDict]:
await self.ratelimiter.ratelimit(None, (request.getClientIP(),))
- if not self.hs.config.enable_registration:
+ if not self.hs.config.registration.enable_registration:
raise SynapseError(
403, "Registration has been disabled", errcode=Codes.FORBIDDEN
)
@@ -419,8 +419,8 @@ class RegisterRestServlet(RestServlet):
self.ratelimiter = hs.get_registration_ratelimiter()
self.password_policy_handler = hs.get_password_policy_handler()
self.clock = hs.get_clock()
- self._registration_enabled = self.hs.config.enable_registration
- self._msc2918_enabled = hs.config.access_token_lifetime is not None
+ self._registration_enabled = self.hs.config.registration.enable_registration
+ self._msc2918_enabled = hs.config.registration.access_token_lifetime is not None
self._registration_flows = _calculate_registration_flows(
hs.config, self.auth_handler
@@ -682,7 +682,7 @@ class RegisterRestServlet(RestServlet):
# written to the db
if threepid:
if is_threepid_reserved(
- self.hs.config.mau_limits_reserved_threepids, threepid
+ self.hs.config.server.mau_limits_reserved_threepids, threepid
):
await self.store.upsert_monthly_active_user(registered_user_id)
@@ -800,7 +800,7 @@ class RegisterRestServlet(RestServlet):
async def _do_guest_registration(
self, params: JsonDict, address: Optional[str] = None
) -> Tuple[int, JsonDict]:
- if not self.hs.config.allow_guest_access:
+ if not self.hs.config.registration.allow_guest_access:
raise SynapseError(403, "Guest access is disabled")
user_id = await self.registration_handler.register_user(
make_guest=True, address=address
@@ -849,13 +849,13 @@ def _calculate_registration_flows(
"""
# FIXME: need a better error than "no auth flow found" for scenarios
# where we required 3PID for registration but the user didn't give one
- require_email = "email" in config.registrations_require_3pid
- require_msisdn = "msisdn" in config.registrations_require_3pid
+ require_email = "email" in config.registration.registrations_require_3pid
+ require_msisdn = "msisdn" in config.registration.registrations_require_3pid
show_msisdn = True
show_email = True
- if config.disable_msisdn_registration:
+ if config.registration.disable_msisdn_registration:
show_msisdn = False
require_msisdn = False
@@ -909,7 +909,7 @@ def _calculate_registration_flows(
flow.insert(0, LoginType.RECAPTCHA)
# Prepend registration token to all flows if we're requiring a token
- if config.registration_requires_token:
+ if config.registration.registration_requires_token:
for flow in flows:
flow.insert(0, LoginType.REGISTRATION_TOKEN)
diff --git a/synapse/rest/client/room.py b/synapse/rest/client/room.py
index bf46dc60..ed95189b 100644
--- a/synapse/rest/client/room.py
+++ b/synapse/rest/client/room.py
@@ -369,7 +369,7 @@ class PublicRoomListRestServlet(TransactionRestServlet):
# Option to allow servers to require auth when accessing
# /publicRooms via CS API. This is especially helpful in private
# federations.
- if not self.hs.config.allow_public_rooms_without_auth:
+ if not self.hs.config.server.allow_public_rooms_without_auth:
raise
# We allow people to not be authed if they're just looking at our
diff --git a/synapse/rest/client/room_batch.py b/synapse/rest/client/room_batch.py
index bf14ec38..38ad4c24 100644
--- a/synapse/rest/client/room_batch.py
+++ b/synapse/rest/client/room_batch.py
@@ -15,13 +15,12 @@
import logging
import re
from http import HTTPStatus
-from typing import TYPE_CHECKING, Awaitable, List, Tuple
+from typing import TYPE_CHECKING, Awaitable, Tuple
from twisted.web.server import Request
-from synapse.api.constants import EventContentFields, EventTypes
+from synapse.api.constants import EventContentFields
from synapse.api.errors import AuthError, Codes, SynapseError
-from synapse.appservice import ApplicationService
from synapse.http.server import HttpServer
from synapse.http.servlet import (
RestServlet,
@@ -32,7 +31,7 @@ from synapse.http.servlet import (
)
from synapse.http.site import SynapseRequest
from synapse.rest.client.transactions import HttpTransactionCache
-from synapse.types import JsonDict, Requester, UserID, create_requester
+from synapse.types import JsonDict
from synapse.util.stringutils import random_string
if TYPE_CHECKING:
@@ -77,102 +76,12 @@ class RoomBatchSendEventRestServlet(RestServlet):
def __init__(self, hs: "HomeServer"):
super().__init__()
- self.hs = hs
self.store = hs.get_datastore()
- self.state_store = hs.get_storage().state
self.event_creation_handler = hs.get_event_creation_handler()
- self.room_member_handler = hs.get_room_member_handler()
self.auth = hs.get_auth()
+ self.room_batch_handler = hs.get_room_batch_handler()
self.txns = HttpTransactionCache(hs)
- async def _inherit_depth_from_prev_ids(self, prev_event_ids: List[str]) -> int:
- (
- most_recent_prev_event_id,
- most_recent_prev_event_depth,
- ) = await self.store.get_max_depth_of(prev_event_ids)
-
- # We want to insert the historical event after the `prev_event` but before the successor event
- #
- # We inherit depth from the successor event instead of the `prev_event`
- # because events returned from `/messages` are first sorted by `topological_ordering`
- # which is just the `depth` and then tie-break with `stream_ordering`.
- #
- # We mark these inserted historical events as "backfilled" which gives them a
- # negative `stream_ordering`. If we use the same depth as the `prev_event`,
- # then our historical event will tie-break and be sorted before the `prev_event`
- # when it should come after.
- #
- # We want to use the successor event depth so they appear after `prev_event` because
- # it has a larger `depth` but before the successor event because the `stream_ordering`
- # is negative before the successor event.
- successor_event_ids = await self.store.get_successor_events(
- [most_recent_prev_event_id]
- )
-
- # If we can't find any successor events, then it's a forward extremity of
- # historical messages and we can just inherit from the previous historical
- # event which we can already assume has the correct depth where we want
- # to insert into.
- if not successor_event_ids:
- depth = most_recent_prev_event_depth
- else:
- (
- _,
- oldest_successor_depth,
- ) = await self.store.get_min_depth_of(successor_event_ids)
-
- depth = oldest_successor_depth
-
- return depth
-
- def _create_insertion_event_dict(
- self, sender: str, room_id: str, origin_server_ts: int
- ) -> JsonDict:
- """Creates an event dict for an "insertion" event with the proper fields
- and a random batch ID.
-
- Args:
- sender: The event author MXID
- room_id: The room ID that the event belongs to
- origin_server_ts: Timestamp when the event was sent
-
- Returns:
- The new event dictionary to insert.
- """
-
- next_batch_id = random_string(8)
- insertion_event = {
- "type": EventTypes.MSC2716_INSERTION,
- "sender": sender,
- "room_id": room_id,
- "content": {
- EventContentFields.MSC2716_NEXT_BATCH_ID: next_batch_id,
- EventContentFields.MSC2716_HISTORICAL: True,
- },
- "origin_server_ts": origin_server_ts,
- }
-
- return insertion_event
-
- async def _create_requester_for_user_id_from_app_service(
- self, user_id: str, app_service: ApplicationService
- ) -> Requester:
- """Creates a new requester for the given user_id
- and validates that the app service is allowed to control
- the given user.
-
- Args:
- user_id: The author MXID that the app service is controlling
- app_service: The app service that controls the user
-
- Returns:
- Requester object
- """
-
- await self.auth.validate_appservice_can_control_user_id(app_service, user_id)
-
- return create_requester(user_id, app_service=app_service)
-
async def on_POST(
self, request: SynapseRequest, room_id: str
) -> Tuple[int, JsonDict]:
@@ -200,121 +109,62 @@ class RoomBatchSendEventRestServlet(RestServlet):
errcode=Codes.MISSING_PARAM,
)
+ # Verify the batch_id_from_query corresponds to an actual insertion event
+ # and have the batch connected.
+ if batch_id_from_query:
+ corresponding_insertion_event_id = (
+ await self.store.get_insertion_event_by_batch_id(
+ room_id, batch_id_from_query
+ )
+ )
+ if corresponding_insertion_event_id is None:
+ raise SynapseError(
+ HTTPStatus.BAD_REQUEST,
+ "No insertion event corresponds to the given ?batch_id",
+ errcode=Codes.INVALID_PARAM,
+ )
+
# For the event we are inserting next to (`prev_event_ids_from_query`),
# find the most recent auth events (derived from state events) that
# allowed that message to be sent. We will use that as a base
# to auth our historical messages against.
- (
- most_recent_prev_event_id,
- _,
- ) = await self.store.get_max_depth_of(prev_event_ids_from_query)
- # mapping from (type, state_key) -> state_event_id
- prev_state_map = await self.state_store.get_state_ids_for_event(
- most_recent_prev_event_id
+ auth_event_ids = await self.room_batch_handler.get_most_recent_auth_event_ids_from_event_id_list(
+ prev_event_ids_from_query
)
- # List of state event ID's
- prev_state_ids = list(prev_state_map.values())
- auth_event_ids = prev_state_ids
-
- state_event_ids_at_start = []
- for state_event in body["state_events_at_start"]:
- assert_params_in_dict(
- state_event, ["type", "origin_server_ts", "content", "sender"]
- )
- logger.debug(
- "RoomBatchSendEventRestServlet inserting state_event=%s, auth_event_ids=%s",
- state_event,
- auth_event_ids,
+ # Create and persist all of the state events that float off on their own
+ # before the batch. These will most likely be all of the invite/member
+ # state events used to auth the upcoming historical messages.
+ state_event_ids_at_start = (
+ await self.room_batch_handler.persist_state_events_at_start(
+ state_events_at_start=body["state_events_at_start"],
+ room_id=room_id,
+ initial_auth_event_ids=auth_event_ids,
+ app_service_requester=requester,
)
+ )
+ # Update our ongoing auth event ID list with all of the new state we
+ # just created
+ auth_event_ids.extend(state_event_ids_at_start)
- event_dict = {
- "type": state_event["type"],
- "origin_server_ts": state_event["origin_server_ts"],
- "content": state_event["content"],
- "room_id": room_id,
- "sender": state_event["sender"],
- "state_key": state_event["state_key"],
- }
-
- # Mark all events as historical
- event_dict["content"][EventContentFields.MSC2716_HISTORICAL] = True
-
- # Make the state events float off on their own
- fake_prev_event_id = "$" + random_string(43)
-
- # TODO: This is pretty much the same as some other code to handle inserting state in this file
- if event_dict["type"] == EventTypes.Member:
- membership = event_dict["content"].get("membership", None)
- event_id, _ = await self.room_member_handler.update_membership(
- await self._create_requester_for_user_id_from_app_service(
- state_event["sender"], requester.app_service
- ),
- target=UserID.from_string(event_dict["state_key"]),
- room_id=room_id,
- action=membership,
- content=event_dict["content"],
- outlier=True,
- prev_event_ids=[fake_prev_event_id],
- # Make sure to use a copy of this list because we modify it
- # later in the loop here. Otherwise it will be the same
- # reference and also update in the event when we append later.
- auth_event_ids=auth_event_ids.copy(),
- )
- else:
- # TODO: Add some complement tests that adds state that is not member joins
- # and will use this code path. Maybe we only want to support join state events
- # and can get rid of this `else`?
- (
- event,
- _,
- ) = await self.event_creation_handler.create_and_send_nonmember_event(
- await self._create_requester_for_user_id_from_app_service(
- state_event["sender"], requester.app_service
- ),
- event_dict,
- outlier=True,
- prev_event_ids=[fake_prev_event_id],
- # Make sure to use a copy of this list because we modify it
- # later in the loop here. Otherwise it will be the same
- # reference and also update in the event when we append later.
- auth_event_ids=auth_event_ids.copy(),
- )
- event_id = event.event_id
-
- state_event_ids_at_start.append(event_id)
- auth_event_ids.append(event_id)
-
- events_to_create = body["events"]
-
- inherited_depth = await self._inherit_depth_from_prev_ids(
+ inherited_depth = await self.room_batch_handler.inherit_depth_from_prev_ids(
prev_event_ids_from_query
)
+ events_to_create = body["events"]
+
# Figure out which batch to connect to. If they passed in
# batch_id_from_query let's use it. The batch ID passed in comes
# from the batch_id in the "insertion" event from the previous batch.
last_event_in_batch = events_to_create[-1]
- batch_id_to_connect_to = batch_id_from_query
base_insertion_event = None
if batch_id_from_query:
+ batch_id_to_connect_to = batch_id_from_query
# All but the first base insertion event should point at a fake
# event, which causes the HS to ask for the state at the start of
# the batch later.
+ fake_prev_event_id = "$" + random_string(43)
prev_event_ids = [fake_prev_event_id]
-
- # Verify the batch_id_from_query corresponds to an actual insertion event
- # and have the batch connected.
- corresponding_insertion_event_id = (
- await self.store.get_insertion_event_by_batch_id(batch_id_from_query)
- )
- if corresponding_insertion_event_id is None:
- raise SynapseError(
- 400,
- "No insertion event corresponds to the given ?batch_id",
- errcode=Codes.INVALID_PARAM,
- )
- pass
# Otherwise, create an insertion event to act as a starting point.
#
# We don't always have an insertion event to start hanging more history
@@ -325,10 +175,12 @@ class RoomBatchSendEventRestServlet(RestServlet):
else:
prev_event_ids = prev_event_ids_from_query
- base_insertion_event_dict = self._create_insertion_event_dict(
- sender=requester.user.to_string(),
- room_id=room_id,
- origin_server_ts=last_event_in_batch["origin_server_ts"],
+ base_insertion_event_dict = (
+ self.room_batch_handler.create_insertion_event_dict(
+ sender=requester.user.to_string(),
+ room_id=room_id,
+ origin_server_ts=last_event_in_batch["origin_server_ts"],
+ )
)
base_insertion_event_dict["prev_events"] = prev_event_ids.copy()
@@ -336,7 +188,7 @@ class RoomBatchSendEventRestServlet(RestServlet):
base_insertion_event,
_,
) = await self.event_creation_handler.create_and_send_nonmember_event(
- await self._create_requester_for_user_id_from_app_service(
+ await self.room_batch_handler.create_requester_for_user_id_from_app_service(
base_insertion_event_dict["sender"],
requester.app_service,
),
@@ -351,92 +203,17 @@ class RoomBatchSendEventRestServlet(RestServlet):
EventContentFields.MSC2716_NEXT_BATCH_ID
]
- # Connect this current batch to the insertion event from the previous batch
- batch_event = {
- "type": EventTypes.MSC2716_BATCH,
- "sender": requester.user.to_string(),
- "room_id": room_id,
- "content": {
- EventContentFields.MSC2716_BATCH_ID: batch_id_to_connect_to,
- EventContentFields.MSC2716_HISTORICAL: True,
- },
- # Since the batch event is put at the end of the batch,
- # where the newest-in-time event is, copy the origin_server_ts from
- # the last event we're inserting
- "origin_server_ts": last_event_in_batch["origin_server_ts"],
- }
- # Add the batch event to the end of the batch (newest-in-time)
- events_to_create.append(batch_event)
-
- # Add an "insertion" event to the start of each batch (next to the oldest-in-time
- # event in the batch) so the next batch can be connected to this one.
- insertion_event = self._create_insertion_event_dict(
- sender=requester.user.to_string(),
+ # Create and persist all of the historical events as well as insertion
+ # and batch meta events to make the batch navigable in the DAG.
+ event_ids, next_batch_id = await self.room_batch_handler.handle_batch_of_events(
+ events_to_create=events_to_create,
room_id=room_id,
- # Since the insertion event is put at the start of the batch,
- # where the oldest-in-time event is, copy the origin_server_ts from
- # the first event we're inserting
- origin_server_ts=events_to_create[0]["origin_server_ts"],
+ batch_id_to_connect_to=batch_id_to_connect_to,
+ initial_prev_event_ids=prev_event_ids,
+ inherited_depth=inherited_depth,
+ auth_event_ids=auth_event_ids,
+ app_service_requester=requester,
)
- # Prepend the insertion event to the start of the batch (oldest-in-time)
- events_to_create = [insertion_event] + events_to_create
-
- event_ids = []
- events_to_persist = []
- for ev in events_to_create:
- assert_params_in_dict(ev, ["type", "origin_server_ts", "content", "sender"])
-
- event_dict = {
- "type": ev["type"],
- "origin_server_ts": ev["origin_server_ts"],
- "content": ev["content"],
- "room_id": room_id,
- "sender": ev["sender"], # requester.user.to_string(),
- "prev_events": prev_event_ids.copy(),
- }
-
- # Mark all events as historical
- event_dict["content"][EventContentFields.MSC2716_HISTORICAL] = True
-
- event, context = await self.event_creation_handler.create_event(
- await self._create_requester_for_user_id_from_app_service(
- ev["sender"], requester.app_service
- ),
- event_dict,
- prev_event_ids=event_dict.get("prev_events"),
- auth_event_ids=auth_event_ids,
- historical=True,
- depth=inherited_depth,
- )
- logger.debug(
- "RoomBatchSendEventRestServlet inserting event=%s, prev_event_ids=%s, auth_event_ids=%s",
- event,
- prev_event_ids,
- auth_event_ids,
- )
-
- assert self.hs.is_mine_id(event.sender), "User must be our own: %s" % (
- event.sender,
- )
-
- events_to_persist.append((event, context))
- event_id = event.event_id
-
- event_ids.append(event_id)
- prev_event_ids = [event_id]
-
- # Persist events in reverse-chronological order so they have the
- # correct stream_ordering as they are backfilled (which decrements).
- # Events are sorted by (topological_ordering, stream_ordering)
- # where topological_ordering is just depth.
- for (event, context) in reversed(events_to_persist):
- ev = await self.event_creation_handler.handle_new_client_event(
- await self._create_requester_for_user_id_from_app_service(
- event["sender"], requester.app_service
- ),
- event=event,
- context=context,
- )
insertion_event_id = event_ids[0]
batch_event_id = event_ids[-1]
@@ -445,9 +222,7 @@ class RoomBatchSendEventRestServlet(RestServlet):
response_dict = {
"state_event_ids": state_event_ids_at_start,
"event_ids": historical_event_ids,
- "next_batch_id": insertion_event["content"][
- EventContentFields.MSC2716_NEXT_BATCH_ID
- ],
+ "next_batch_id": next_batch_id,
"insertion_event_id": insertion_event_id,
"batch_event_id": batch_event_id,
}
diff --git a/synapse/rest/client/shared_rooms.py b/synapse/rest/client/shared_rooms.py
index 1d90493e..09a46737 100644
--- a/synapse/rest/client/shared_rooms.py
+++ b/synapse/rest/client/shared_rooms.py
@@ -42,7 +42,7 @@ class UserSharedRoomsServlet(RestServlet):
super().__init__()
self.auth = hs.get_auth()
self.store = hs.get_datastore()
- self.user_directory_active = hs.config.update_user_directory
+ self.user_directory_active = hs.config.server.update_user_directory
async def on_GET(
self, request: SynapseRequest, user_id: str
diff --git a/synapse/rest/client/sync.py b/synapse/rest/client/sync.py
index 1259058b..913216a7 100644
--- a/synapse/rest/client/sync.py
+++ b/synapse/rest/client/sync.py
@@ -155,7 +155,7 @@ class SyncRestServlet(RestServlet):
try:
filter_object = json_decoder.decode(filter_id)
set_timeline_upper_limit(
- filter_object, self.hs.config.filter_timeline_limit
+ filter_object, self.hs.config.server.filter_timeline_limit
)
except Exception:
raise SynapseError(400, "Invalid filter JSON")
diff --git a/synapse/rest/client/voip.py b/synapse/rest/client/voip.py
index ea2b8aa4..ea7e0251 100644
--- a/synapse/rest/client/voip.py
+++ b/synapse/rest/client/voip.py
@@ -70,7 +70,7 @@ class VoipRestServlet(RestServlet):
{
"username": username,
"password": password,
- "ttl": userLifetime / 1000,
+ "ttl": userLifetime // 1000,
"uris": turnUris,
},
)
diff --git a/synapse/rest/media/v1/__init__.py b/synapse/rest/media/v1/__init__.py
index 3dd16d4b..d5b74cdd 100644
--- a/synapse/rest/media/v1/__init__.py
+++ b/synapse/rest/media/v1/__init__.py
@@ -12,33 +12,21 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import PIL.Image
+from PIL.features import check_codec
# check for JPEG support.
-try:
- PIL.Image._getdecoder("rgb", "jpeg", None)
-except OSError as e:
- if str(e).startswith("decoder jpeg not available"):
- raise Exception(
- "FATAL: jpeg codec not supported. Install pillow correctly! "
- " 'sudo apt-get install libjpeg-dev' then 'pip uninstall pillow &&"
- " pip install pillow --user'"
- )
-except Exception:
- # any other exception is fine
- pass
+if not check_codec("jpg"):
+ raise Exception(
+ "FATAL: jpeg codec not supported. Install pillow correctly! "
+ " 'sudo apt-get install libjpeg-dev' then 'pip uninstall pillow &&"
+ " pip install pillow --user'"
+ )
# check for PNG support.
-try:
- PIL.Image._getdecoder("rgb", "zip", None)
-except OSError as e:
- if str(e).startswith("decoder zip not available"):
- raise Exception(
- "FATAL: zip codec not supported. Install pillow correctly! "
- " 'sudo apt-get install libjpeg-dev' then 'pip uninstall pillow &&"
- " pip install pillow --user'"
- )
-except Exception:
- # any other exception is fine
- pass
+if not check_codec("zlib"):
+ raise Exception(
+ "FATAL: zip codec not supported. Install pillow correctly! "
+ " 'sudo apt-get install libjpeg-dev' then 'pip uninstall pillow &&"
+ " pip install pillow --user'"
+ )
diff --git a/synapse/rest/media/v1/oembed.py b/synapse/rest/media/v1/oembed.py
index e04671fb..78b1603f 100644
--- a/synapse/rest/media/v1/oembed.py
+++ b/synapse/rest/media/v1/oembed.py
@@ -96,6 +96,32 @@ class OEmbedProvider:
# No match.
return None
+ def autodiscover_from_html(self, tree: "etree.Element") -> Optional[str]:
+ """
+ Search an HTML document for oEmbed autodiscovery information.
+
+ Args:
+ tree: The parsed HTML body.
+
+ Returns:
+ The URL to use for oEmbed information, or None if no URL was found.
+ """
+ # Search for link elements with the proper rel and type attributes.
+ for tag in tree.xpath(
+ "//link[@rel='alternate'][@type='application/json+oembed']"
+ ):
+ if "href" in tag.attrib:
+ return tag.attrib["href"]
+
+ # Some providers (e.g. Flickr) use alternative instead of alternate.
+ for tag in tree.xpath(
+ "//link[@rel='alternative'][@type='application/json+oembed']"
+ ):
+ if "href" in tag.attrib:
+ return tag.attrib["href"]
+
+ return None
+
def parse_oembed_response(self, url: str, raw_body: bytes) -> OEmbedResult:
"""
Parse the oEmbed response into an Open Graph response.
@@ -165,7 +191,7 @@ class OEmbedProvider:
except Exception as e:
# Trap any exception and let the code follow as usual.
- logger.warning(f"Error parsing oEmbed metadata from {url}: {e:r}")
+ logger.warning("Error parsing oEmbed metadata from %s: %r", url, e)
open_graph_response = {}
cache_age = None
diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py
index 79a42b24..1fe0fc8a 100644
--- a/synapse/rest/media/v1/preview_url_resource.py
+++ b/synapse/rest/media/v1/preview_url_resource.py
@@ -22,7 +22,7 @@ import re
import shutil
import sys
import traceback
-from typing import TYPE_CHECKING, Dict, Generator, Iterable, Optional, Union
+from typing import TYPE_CHECKING, Dict, Generator, Iterable, Optional, Tuple, Union
from urllib import parse as urlparse
import attr
@@ -73,6 +73,7 @@ OG_TAG_VALUE_MAXLEN = 1000
ONE_HOUR = 60 * 60 * 1000
ONE_DAY = 24 * ONE_HOUR
+IMAGE_CACHE_EXPIRY_MS = 2 * ONE_DAY
@attr.s(slots=True, frozen=True, auto_attribs=True)
@@ -295,22 +296,32 @@ class PreviewUrlResource(DirectServeJsonResource):
body = file.read()
encoding = get_html_media_encoding(body, media_info.media_type)
- og = decode_and_calc_og(body, media_info.uri, encoding)
-
- await self._precache_image_url(user, media_info, og)
-
- elif oembed_url and _is_json(media_info.media_type):
- # Handle an oEmbed response.
- with open(media_info.filename, "rb") as file:
- body = file.read()
-
- oembed_response = self._oembed.parse_oembed_response(url, body)
- og = oembed_response.open_graph_result
-
- # Use the cache age from the oEmbed result, instead of the HTTP response.
- if oembed_response.cache_age is not None:
- expiration_ms = oembed_response.cache_age
+ tree = decode_body(body, encoding)
+ if tree is not None:
+ # Check if this HTML document points to oEmbed information and
+ # defer to that.
+ oembed_url = self._oembed.autodiscover_from_html(tree)
+ og = {}
+ if oembed_url:
+ oembed_info = await self._download_url(oembed_url, user)
+ og, expiration_ms = await self._handle_oembed_response(
+ url, oembed_info, expiration_ms
+ )
+
+ # If there was no oEmbed URL (or oEmbed parsing failed), attempt
+ # to generate the Open Graph information from the HTML.
+ if not oembed_url or not og:
+ og = _calc_og(tree, media_info.uri)
+
+ await self._precache_image_url(user, media_info, og)
+ else:
+ og = {}
+ elif oembed_url:
+ # Handle the oEmbed information.
+ og, expiration_ms = await self._handle_oembed_response(
+ url, media_info, expiration_ms
+ )
await self._precache_image_url(user, media_info, og)
else:
@@ -478,6 +489,39 @@ class PreviewUrlResource(DirectServeJsonResource):
else:
del og["og:image"]
+ async def _handle_oembed_response(
+ self, url: str, media_info: MediaInfo, expiration_ms: int
+ ) -> Tuple[JsonDict, int]:
+ """
+ Parse the downloaded oEmbed info.
+
+ Args:
+ url: The URL which is being previewed (not the one which was
+ requested).
+ media_info: The media being previewed.
+ expiration_ms: The length of time, in milliseconds, the media is valid for.
+
+ Returns:
+ A tuple of:
+ The Open Graph dictionary, if the oEmbed info can be parsed.
+ The (possibly updated) length of time, in milliseconds, the media is valid for.
+ """
+ # If JSON was not returned, there's nothing to do.
+ if not _is_json(media_info.media_type):
+ return {}, expiration_ms
+
+ with open(media_info.filename, "rb") as file:
+ body = file.read()
+
+ oembed_response = self._oembed.parse_oembed_response(url, body)
+ open_graph_result = oembed_response.open_graph_result
+
+ # Use the cache age from the oEmbed result, if one was given.
+ if open_graph_result and oembed_response.cache_age is not None:
+ expiration_ms = oembed_response.cache_age
+
+ return open_graph_result, expiration_ms
+
def _start_expire_url_cache_data(self) -> Deferred:
return run_as_background_process(
"expire_url_cache_data", self._expire_url_cache_data
@@ -496,6 +540,27 @@ class PreviewUrlResource(DirectServeJsonResource):
logger.info("Still running DB updates; skipping expiry")
return
+ def try_remove_parent_dirs(dirs: Iterable[str]) -> None:
+ """Attempt to remove the given chain of parent directories
+
+ Args:
+ dirs: The list of directory paths to delete, with children appearing
+ before their parents.
+ """
+ for dir in dirs:
+ try:
+ os.rmdir(dir)
+ except FileNotFoundError:
+ # Already deleted, continue with deleting the rest
+ pass
+ except OSError as e:
+ # Failed, skip deleting the rest of the parent dirs
+ if e.errno != errno.ENOTEMPTY:
+ logger.warning(
+ "Failed to remove media directory: %r: %s", dir, e
+ )
+ break
+
# First we delete expired url cache entries
media_ids = await self.store.get_expired_url_cache(now)
@@ -504,20 +569,16 @@ class PreviewUrlResource(DirectServeJsonResource):
fname = self.filepaths.url_cache_filepath(media_id)
try:
os.remove(fname)
+ except FileNotFoundError:
+ pass # If the path doesn't exist, meh
except OSError as e:
- # If the path doesn't exist, meh
- if e.errno != errno.ENOENT:
- logger.warning("Failed to remove media: %r: %s", media_id, e)
- continue
+ logger.warning("Failed to remove media: %r: %s", media_id, e)
+ continue
removed_media.append(media_id)
- try:
- dirs = self.filepaths.url_cache_filepath_dirs_to_delete(media_id)
- for dir in dirs:
- os.rmdir(dir)
- except Exception:
- pass
+ dirs = self.filepaths.url_cache_filepath_dirs_to_delete(media_id)
+ try_remove_parent_dirs(dirs)
await self.store.delete_url_cache(removed_media)
@@ -530,7 +591,7 @@ class PreviewUrlResource(DirectServeJsonResource):
# These may be cached for a bit on the client (i.e., they
# may have a room open with a preview url thing open).
# So we wait a couple of days before deleting, just in case.
- expire_before = now - 2 * ONE_DAY
+ expire_before = now - IMAGE_CACHE_EXPIRY_MS
media_ids = await self.store.get_url_cache_media_before(expire_before)
removed_media = []
@@ -538,36 +599,30 @@ class PreviewUrlResource(DirectServeJsonResource):
fname = self.filepaths.url_cache_filepath(media_id)
try:
os.remove(fname)
+ except FileNotFoundError:
+ pass # If the path doesn't exist, meh
except OSError as e:
- # If the path doesn't exist, meh
- if e.errno != errno.ENOENT:
- logger.warning("Failed to remove media: %r: %s", media_id, e)
- continue
+ logger.warning("Failed to remove media: %r: %s", media_id, e)
+ continue
- try:
- dirs = self.filepaths.url_cache_filepath_dirs_to_delete(media_id)
- for dir in dirs:
- os.rmdir(dir)
- except Exception:
- pass
+ dirs = self.filepaths.url_cache_filepath_dirs_to_delete(media_id)
+ try_remove_parent_dirs(dirs)
thumbnail_dir = self.filepaths.url_cache_thumbnail_directory(media_id)
try:
shutil.rmtree(thumbnail_dir)
+ except FileNotFoundError:
+ pass # If the path doesn't exist, meh
except OSError as e:
- # If the path doesn't exist, meh
- if e.errno != errno.ENOENT:
- logger.warning("Failed to remove media: %r: %s", media_id, e)
- continue
+ logger.warning("Failed to remove media: %r: %s", media_id, e)
+ continue
removed_media.append(media_id)
- try:
- dirs = self.filepaths.url_cache_thumbnail_dirs_to_delete(media_id)
- for dir in dirs:
- os.rmdir(dir)
- except Exception:
- pass
+ dirs = self.filepaths.url_cache_thumbnail_dirs_to_delete(media_id)
+ # Note that one of the directories to be deleted has already been
+ # removed by the `rmtree` above.
+ try_remove_parent_dirs(dirs)
await self.store.delete_url_cache_media(removed_media)
@@ -619,26 +674,22 @@ def get_html_media_encoding(body: bytes, content_type: str) -> str:
return "utf-8"
-def decode_and_calc_og(
- body: bytes, media_uri: str, request_encoding: Optional[str] = None
-) -> JsonDict:
+def decode_body(
+ body: bytes, request_encoding: Optional[str] = None
+) -> Optional["etree.Element"]:
"""
- Calculate metadata for an HTML document.
-
- This uses lxml to parse the HTML document into the OG response. If errors
- occur during processing of the document, an empty response is returned.
+ This uses lxml to parse the HTML document.
Args:
body: The HTML document, as bytes.
- media_url: The URI used to download the body.
request_encoding: The character encoding of the body, as a string.
Returns:
- The OG response as a dictionary.
+ The parsed HTML body, or None if an error occurred during processed.
"""
# If there's no body, nothing useful is going to be found.
if not body:
- return {}
+ return None
from lxml import etree
@@ -650,25 +701,22 @@ def decode_and_calc_og(
parser = etree.HTMLParser(recover=True, encoding="utf-8")
except Exception as e:
logger.warning("Unable to create HTML parser: %s" % (e,))
- return {}
-
- def _attempt_calc_og(body_attempt: Union[bytes, str]) -> Dict[str, Optional[str]]:
- # Attempt to parse the body. If this fails, log and return no metadata.
- tree = etree.fromstring(body_attempt, parser)
-
- # The data was successfully parsed, but no tree was found.
- if tree is None:
- return {}
+ return None
- return _calc_og(tree, media_uri)
+ def _attempt_decode_body(
+ body_attempt: Union[bytes, str]
+ ) -> Optional["etree.Element"]:
+ # Attempt to parse the body. Returns None if the body was successfully
+ # parsed, but no tree was found.
+ return etree.fromstring(body_attempt, parser)
# Attempt to parse the body. If this fails, log and return no metadata.
try:
- return _attempt_calc_og(body)
+ return _attempt_decode_body(body)
except UnicodeDecodeError:
# blindly try decoding the body as utf-8, which seems to fix
# the charset mismatches on https://google.com
- return _attempt_calc_og(body.decode("utf-8", "ignore"))
+ return _attempt_decode_body(body.decode("utf-8", "ignore"))
def _calc_og(tree: "etree.Element", media_uri: str) -> Dict[str, Optional[str]]:
diff --git a/synapse/rest/media/v1/thumbnailer.py b/synapse/rest/media/v1/thumbnailer.py
index df54a406..46701a8b 100644
--- a/synapse/rest/media/v1/thumbnailer.py
+++ b/synapse/rest/media/v1/thumbnailer.py
@@ -61,9 +61,19 @@ class Thumbnailer:
self.transpose_method = None
try:
# We don't use ImageOps.exif_transpose since it crashes with big EXIF
- image_exif = self.image._getexif()
+ #
+ # Ignore safety: Pillow seems to acknowledge that this method is
+ # "private, experimental, but generally widely used". Pillow 6
+ # includes a public getexif() method (no underscore) that we might
+ # consider using instead when we can bump that dependency.
+ #
+ # At the time of writing, Debian buster (currently oldstable)
+ # provides version 5.4.1. It's expected to EOL in mid-2022, see
+ # https://wiki.debian.org/DebianReleases#Production_Releases
+ image_exif = self.image._getexif() # type: ignore
if image_exif is not None:
image_orientation = image_exif.get(EXIF_ORIENTATION_TAG)
+ assert isinstance(image_orientation, int)
self.transpose_method = EXIF_TRANSPOSE_MAPPINGS.get(image_orientation)
except Exception as e:
# A lot of parsing errors can happen when parsing EXIF
@@ -76,7 +86,10 @@ class Thumbnailer:
A tuple containing the new image size in pixels as (width, height).
"""
if self.transpose_method is not None:
- self.image = self.image.transpose(self.transpose_method)
+ # Safety: `transpose` takes an int rather than e.g. an IntEnum.
+ # self.transpose_method is set above to be a value in
+ # EXIF_TRANSPOSE_MAPPINGS, and that only contains correct values.
+ self.image = self.image.transpose(self.transpose_method) # type: ignore[arg-type]
self.width, self.height = self.image.size
self.transpose_method = None
# We don't need EXIF any more
@@ -101,7 +114,7 @@ class Thumbnailer:
else:
return (max_height * self.width) // self.height, max_height
- def _resize(self, width: int, height: int) -> Image:
+ def _resize(self, width: int, height: int) -> Image.Image:
# 1-bit or 8-bit color palette images need converting to RGB
# otherwise they will be scaled using nearest neighbour which
# looks awful.
@@ -151,7 +164,7 @@ class Thumbnailer:
cropped = scaled_image.crop((crop_left, 0, crop_right, height))
return self._encode_image(cropped, output_type)
- def _encode_image(self, output_image: Image, output_type: str) -> BytesIO:
+ def _encode_image(self, output_image: Image.Image, output_type: str) -> BytesIO:
output_bytes_io = BytesIO()
fmt = self.FORMATS[output_type]
if fmt == "JPEG":
diff --git a/synapse/rest/well_known.py b/synapse/rest/well_known.py
index c80a3a99..7ac01faa 100644
--- a/synapse/rest/well_known.py
+++ b/synapse/rest/well_known.py
@@ -39,9 +39,9 @@ class WellKnownBuilder:
result = {"m.homeserver": {"base_url": self._config.server.public_baseurl}}
- if self._config.default_identity_server:
+ if self._config.registration.default_identity_server:
result["m.identity_server"] = {
- "base_url": self._config.default_identity_server
+ "base_url": self._config.registration.default_identity_server
}
return result
diff --git a/synapse/server.py b/synapse/server.py
index 637eb15b..5bc045d6 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -39,7 +39,7 @@ from twisted.web.resource import IResource
from synapse.api.auth import Auth
from synapse.api.filtering import Filtering
-from synapse.api.ratelimiting import Ratelimiter
+from synapse.api.ratelimiting import Ratelimiter, RequestRatelimiter
from synapse.appservice.api import ApplicationServiceApi
from synapse.appservice.scheduler import ApplicationServiceScheduler
from synapse.config.homeserver import HomeServerConfig
@@ -97,6 +97,7 @@ from synapse.handlers.room import (
RoomCreationHandler,
RoomShutdownHandler,
)
+from synapse.handlers.room_batch import RoomBatchHandler
from synapse.handlers.room_list import RoomListHandler
from synapse.handlers.room_member import RoomMemberHandler, RoomMemberMasterHandler
from synapse.handlers.room_member_worker import RoomMemberWorkerHandler
@@ -438,6 +439,10 @@ class HomeServer(metaclass=abc.ABCMeta):
return RoomCreationHandler(self)
@cache_in_self
+ def get_room_batch_handler(self) -> RoomBatchHandler:
+ return RoomBatchHandler(self)
+
+ @cache_in_self
def get_room_shutdown_handler(self) -> RoomShutdownHandler:
return RoomShutdownHandler(self)
@@ -816,3 +821,12 @@ class HomeServer(metaclass=abc.ABCMeta):
def should_send_federation(self) -> bool:
"Should this server be sending federation traffic directly?"
return self.config.worker.send_federation
+
+ @cache_in_self
+ def get_request_ratelimiter(self) -> RequestRatelimiter:
+ return RequestRatelimiter(
+ self.get_datastore(),
+ self.get_clock(),
+ self.config.ratelimiting.rc_message,
+ self.config.ratelimiting.rc_admin_redaction,
+ )
diff --git a/synapse/server_notices/resource_limits_server_notices.py b/synapse/server_notices/resource_limits_server_notices.py
index 073b0d75..8522930b 100644
--- a/synapse/server_notices/resource_limits_server_notices.py
+++ b/synapse/server_notices/resource_limits_server_notices.py
@@ -47,9 +47,9 @@ class ResourceLimitsServerNotices:
self._notifier = hs.get_notifier()
self._enabled = (
- hs.config.limit_usage_by_mau
+ hs.config.server.limit_usage_by_mau
and self._server_notices_manager.is_enabled()
- and not hs.config.hs_disabled
+ and not hs.config.server.hs_disabled
)
async def maybe_send_server_notice_to_user(self, user_id: str) -> None:
@@ -98,7 +98,7 @@ class ResourceLimitsServerNotices:
try:
if (
limit_type == LimitBlockingTypes.MONTHLY_ACTIVE_USER
- and not self._config.mau_limit_alerting
+ and not self._config.server.mau_limit_alerting
):
# We have hit the MAU limit, but MAU alerting is disabled:
# reset room if necessary and return
@@ -149,7 +149,7 @@ class ResourceLimitsServerNotices:
"body": event_body,
"msgtype": ServerNoticeMsgType,
"server_notice_type": ServerNoticeLimitReached,
- "admin_contact": self._config.admin_contact,
+ "admin_contact": self._config.server.admin_contact,
"limit_type": event_limit_type,
}
event = await self._server_notices_manager.send_notice(
diff --git a/synapse/server_notices/server_notices_manager.py b/synapse/server_notices/server_notices_manager.py
index cd1c5ff6..0cf60236 100644
--- a/synapse/server_notices/server_notices_manager.py
+++ b/synapse/server_notices/server_notices_manager.py
@@ -41,12 +41,8 @@ class ServerNoticesManager:
self._notifier = hs.get_notifier()
self.server_notices_mxid = self._config.servernotices.server_notices_mxid
- def is_enabled(self):
- """Checks if server notices are enabled on this server.
-
- Returns:
- bool
- """
+ def is_enabled(self) -> bool:
+ """Checks if server notices are enabled on this server."""
return self.server_notices_mxid is not None
async def send_notice(
diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py
index c981df3f..5cf2e125 100644
--- a/synapse/state/__init__.py
+++ b/synapse/state/__init__.py
@@ -118,7 +118,7 @@ class _StateCacheEntry:
else:
self.state_id = _gen_state_id()
- def __len__(self):
+ def __len__(self) -> int:
return len(self.state)
diff --git a/synapse/state/v1.py b/synapse/state/v1.py
index 92336d7c..ffe6207a 100644
--- a/synapse/state/v1.py
+++ b/synapse/state/v1.py
@@ -225,7 +225,7 @@ def _resolve_with_state(
conflicted_state_ids: StateMap[Set[str]],
auth_event_ids: StateMap[str],
state_map: Dict[str, EventBase],
-):
+) -> MutableStateMap[str]:
conflicted_state = {}
for key, event_ids in conflicted_state_ids.items():
events = [state_map[ev_id] for ev_id in event_ids if ev_id in state_map]
@@ -329,12 +329,10 @@ def _resolve_auth_events(
auth_events[(prev_event.type, prev_event.state_key)] = prev_event
try:
# The signatures have already been checked at this point
- event_auth.check(
+ event_auth.check_auth_rules_for_event(
RoomVersions.V1,
event,
auth_events,
- do_sig_check=False,
- do_size_check=False,
)
prev_event = event
except AuthError:
@@ -349,12 +347,10 @@ def _resolve_normal_events(
for event in _ordered_events(events):
try:
# The signatures have already been checked at this point
- event_auth.check(
+ event_auth.check_auth_rules_for_event(
RoomVersions.V1,
event,
auth_events,
- do_sig_check=False,
- do_size_check=False,
)
return event
except AuthError:
@@ -366,7 +362,7 @@ def _resolve_normal_events(
def _ordered_events(events: Iterable[EventBase]) -> List[EventBase]:
- def key_func(e):
+ def key_func(e: EventBase) -> Tuple[int, str]:
# we have to use utf-8 rather than ascii here because it turns out we allow
# people to send us events with non-ascii event IDs :/
return -int(e.depth), hashlib.sha1(e.event_id.encode("utf-8")).hexdigest()
diff --git a/synapse/state/v2.py b/synapse/state/v2.py
index 7b1e8361..bd18eefd 100644
--- a/synapse/state/v2.py
+++ b/synapse/state/v2.py
@@ -481,7 +481,7 @@ async def _reverse_topological_power_sort(
if idx % _AWAIT_AFTER_ITERATIONS == 0:
await clock.sleep(0)
- def _get_power_order(event_id):
+ def _get_power_order(event_id: str) -> Tuple[int, int, str]:
ev = event_map[event_id]
pl = event_to_pl[event_id]
@@ -546,12 +546,10 @@ async def _iterative_auth_checks(
auth_events[key] = event_map[ev_id]
try:
- event_auth.check(
+ event_auth.check_auth_rules_for_event(
room_version,
event,
auth_events,
- do_sig_check=False,
- do_size_check=False,
)
resolved_state[(event.type, event.state_key)] = event_id
diff --git a/synapse/storage/databases/main/censor_events.py b/synapse/storage/databases/main/censor_events.py
index 6305414e..eee07227 100644
--- a/synapse/storage/databases/main/censor_events.py
+++ b/synapse/storage/databases/main/censor_events.py
@@ -36,7 +36,7 @@ class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBase
if (
hs.config.worker.run_background_tasks
- and self.hs.config.redaction_retention_period is not None
+ and self.hs.config.server.redaction_retention_period is not None
):
hs.get_clock().looping_call(self._censor_redactions, 5 * 60 * 1000)
@@ -48,7 +48,7 @@ class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBase
By censor we mean update the event_json table with the redacted event.
"""
- if self.hs.config.redaction_retention_period is None:
+ if self.hs.config.server.redaction_retention_period is None:
return
if not (
@@ -60,7 +60,9 @@ class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBase
# created.
return
- before_ts = self._clock.time_msec() - self.hs.config.redaction_retention_period
+ before_ts = (
+ self._clock.time_msec() - self.hs.config.server.redaction_retention_period
+ )
# We fetch all redactions that:
# 1. point to an event we have,
diff --git a/synapse/storage/databases/main/client_ips.py b/synapse/storage/databases/main/client_ips.py
index cc192f5c..6c1ef090 100644
--- a/synapse/storage/databases/main/client_ips.py
+++ b/synapse/storage/databases/main/client_ips.py
@@ -353,7 +353,7 @@ class ClientIpWorkerStore(ClientIpBackgroundUpdateStore):
def __init__(self, database: DatabasePool, db_conn, hs):
super().__init__(database, db_conn, hs)
- self.user_ips_max_age = hs.config.user_ips_max_age
+ self.user_ips_max_age = hs.config.server.user_ips_max_age
if hs.config.worker.run_background_tasks and self.user_ips_max_age:
self._clock.looping_call(self._prune_old_user_ips, 5 * 1000)
@@ -538,15 +538,20 @@ class ClientIpStore(ClientIpWorkerStore):
"""
ret = await super().get_last_client_ip_by_device(user_id, device_id)
- # Update what is retrieved from the database with data which is pending insertion.
+ # Update what is retrieved from the database with data which is pending
+ # insertion, as if it has already been stored in the database.
for key in self._batch_row_update:
- uid, access_token, ip = key
+ uid, _access_token, ip = key
if uid == user_id:
user_agent, did, last_seen = self._batch_row_update[key]
+
+ if did is None:
+ # These updates don't make it to the `devices` table
+ continue
+
if not device_id or did == device_id:
- ret[(user_id, device_id)] = {
+ ret[(user_id, did)] = {
"user_id": user_id,
- "access_token": access_token,
"ip": ip,
"user_agent": user_agent,
"device_id": did,
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 584f818f..19f55c19 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -104,7 +104,7 @@ class PersistEventsStore:
self._clock = hs.get_clock()
self._instance_name = hs.get_instance_name()
- self._ephemeral_messages_enabled = hs.config.enable_ephemeral_messages
+ self._ephemeral_messages_enabled = hs.config.server.enable_ephemeral_messages
self.is_mine_id = hs.is_mine_id
# Ideally we'd move these ID gens here, unfortunately some other ID
@@ -1276,13 +1276,6 @@ class PersistEventsStore:
logger.exception("")
raise
- # update the stored internal_metadata to update the "outlier" flag.
- # TODO: This is unused as of Synapse 1.31. Remove it once we are happy
- # to drop backwards-compatibility with 1.30.
- metadata_json = json_encoder.encode(event.internal_metadata.get_dict())
- sql = "UPDATE event_json SET internal_metadata = ? WHERE event_id = ?"
- txn.execute(sql, (metadata_json, event.event_id))
-
# Add an entry to the ex_outlier_stream table to replicate the
# change in outlier status to our workers.
stream_order = event.internal_metadata.stream_ordering
@@ -1327,19 +1320,6 @@ class PersistEventsStore:
d.pop("redacted_because", None)
return d
- def get_internal_metadata(event):
- im = event.internal_metadata.get_dict()
-
- # temporary hack for database compatibility with Synapse 1.30 and earlier:
- # store the `outlier` flag inside the internal_metadata json as well as in
- # the `events` table, so that if anyone rolls back to an older Synapse,
- # things keep working. This can be removed once we are happy to drop support
- # for that
- if event.internal_metadata.is_outlier():
- im["outlier"] = True
-
- return im
-
self.db_pool.simple_insert_many_txn(
txn,
table="event_json",
@@ -1348,7 +1328,7 @@ class PersistEventsStore:
"event_id": event.event_id,
"room_id": event.room_id,
"internal_metadata": json_encoder.encode(
- get_internal_metadata(event)
+ event.internal_metadata.get_dict()
),
"json": json_encoder.encode(event_dict(event)),
"format_version": event.format_version,
@@ -1783,9 +1763,8 @@ class PersistEventsStore:
retcol="creator",
allow_none=True,
)
- if (
- not room_version.msc2716_historical
- or not self.hs.config.experimental.msc2716_enabled
+ if not room_version.msc2716_historical and (
+ not self.hs.config.experimental.msc2716_enabled
or event.sender != room_creator
):
return
@@ -1845,9 +1824,8 @@ class PersistEventsStore:
retcol="creator",
allow_none=True,
)
- if (
- not room_version.msc2716_historical
- or not self.hs.config.experimental.msc2716_enabled
+ if not room_version.msc2716_historical and (
+ not self.hs.config.experimental.msc2716_enabled
or event.sender != room_creator
):
return
diff --git a/synapse/storage/databases/main/filtering.py b/synapse/storage/databases/main/filtering.py
index bb244a03..434986fa 100644
--- a/synapse/storage/databases/main/filtering.py
+++ b/synapse/storage/databases/main/filtering.py
@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from typing import Union
+
from canonicaljson import encode_canonical_json
from synapse.api.errors import Codes, SynapseError
@@ -22,7 +24,9 @@ from synapse.util.caches.descriptors import cached
class FilteringStore(SQLBaseStore):
@cached(num_args=2)
- async def get_user_filter(self, user_localpart, filter_id):
+ async def get_user_filter(
+ self, user_localpart: str, filter_id: Union[int, str]
+ ) -> JsonDict:
# filter_id is BIGINT UNSIGNED, so if it isn't a number, fail
# with a coherent error message rather than 500 M_UNKNOWN.
try:
@@ -40,7 +44,7 @@ class FilteringStore(SQLBaseStore):
return db_to_json(def_json)
- async def add_user_filter(self, user_localpart: str, user_filter: JsonDict) -> str:
+ async def add_user_filter(self, user_localpart: str, user_filter: JsonDict) -> int:
def_json = encode_canonical_json(user_filter)
# Need an atomic transaction to SELECT the maximal ID so far then
diff --git a/synapse/storage/databases/main/monthly_active_users.py b/synapse/storage/databases/main/monthly_active_users.py
index b76ee51a..ec4d47a5 100644
--- a/synapse/storage/databases/main/monthly_active_users.py
+++ b/synapse/storage/databases/main/monthly_active_users.py
@@ -32,8 +32,8 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore):
self._clock = hs.get_clock()
self.hs = hs
- self._limit_usage_by_mau = hs.config.limit_usage_by_mau
- self._max_mau_value = hs.config.max_mau_value
+ self._limit_usage_by_mau = hs.config.server.limit_usage_by_mau
+ self._max_mau_value = hs.config.server.max_mau_value
@cached(num_args=0)
async def get_monthly_active_count(self) -> int:
@@ -96,8 +96,8 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore):
"""
users = []
- for tp in self.hs.config.mau_limits_reserved_threepids[
- : self.hs.config.max_mau_value
+ for tp in self.hs.config.server.mau_limits_reserved_threepids[
+ : self.hs.config.server.max_mau_value
]:
user_id = await self.hs.get_datastore().get_user_id_by_threepid(
tp["medium"], tp["address"]
@@ -212,7 +212,7 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore):
def __init__(self, database: DatabasePool, db_conn, hs):
super().__init__(database, db_conn, hs)
- self._mau_stats_only = hs.config.mau_stats_only
+ self._mau_stats_only = hs.config.server.mau_stats_only
# Do not add more reserved users than the total allowable number
self.db_pool.new_transaction(
@@ -221,7 +221,7 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore):
[],
[],
self._initialise_reserved_users,
- hs.config.mau_limits_reserved_threepids[: self._max_mau_value],
+ hs.config.server.mau_limits_reserved_threepids[: self._max_mau_value],
)
def _initialise_reserved_users(self, txn, threepids):
@@ -354,3 +354,27 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore):
await self.upsert_monthly_active_user(user_id)
elif now - last_seen_timestamp > LAST_SEEN_GRANULARITY:
await self.upsert_monthly_active_user(user_id)
+
+ async def remove_deactivated_user_from_mau_table(self, user_id: str) -> None:
+ """
+ Removes a deactivated user from the monthly active user
+ table and resets affected caches.
+
+ Args:
+ user_id: the user ID to remove
+ """
+
+ rows_deleted = await self.db_pool.simple_delete(
+ table="monthly_active_users",
+ keyvalues={"user_id": user_id},
+ desc="simple_delete",
+ )
+
+ if rows_deleted != 0:
+ await self.invalidate_cache_and_stream(
+ "user_last_seen_monthly_active", (user_id,)
+ )
+ await self.invalidate_cache_and_stream("get_monthly_active_count", ())
+ await self.invalidate_cache_and_stream(
+ "get_monthly_active_count_by_service", ()
+ )
diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py
index a7fb8cd8..fc720f59 100644
--- a/synapse/storage/databases/main/push_rule.py
+++ b/synapse/storage/databases/main/push_rule.py
@@ -14,7 +14,7 @@
# limitations under the License.
import abc
import logging
-from typing import List, Tuple, Union
+from typing import Dict, List, Tuple, Union
from synapse.api.errors import NotFoundError, StoreError
from synapse.push.baserules import list_with_base_rules
@@ -101,7 +101,9 @@ class PushRulesWorkerStore(
prefilled_cache=push_rules_prefill,
)
- self._users_new_default_push_rules = hs.config.users_new_default_push_rules
+ self._users_new_default_push_rules = (
+ hs.config.server.users_new_default_push_rules
+ )
@abc.abstractmethod
def get_max_push_rules_stream_id(self):
@@ -137,7 +139,7 @@ class PushRulesWorkerStore(
return _load_rules(rows, enabled_map, use_new_defaults)
@cached(max_entries=5000)
- async def get_push_rules_enabled_for_user(self, user_id):
+ async def get_push_rules_enabled_for_user(self, user_id) -> Dict[str, bool]:
results = await self.db_pool.simple_select_list(
table="push_rules_enable",
keyvalues={"user_name": user_id},
diff --git a/synapse/storage/databases/main/pusher.py b/synapse/storage/databases/main/pusher.py
index a93caae8..b73ce53c 100644
--- a/synapse/storage/databases/main/pusher.py
+++ b/synapse/storage/databases/main/pusher.py
@@ -18,8 +18,7 @@ from typing import TYPE_CHECKING, Any, Dict, Iterable, Iterator, List, Optional,
from synapse.push import PusherConfig, ThrottleParams
from synapse.storage._base import SQLBaseStore, db_to_json
-from synapse.storage.database import DatabasePool
-from synapse.storage.types import Connection
+from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
from synapse.storage.util.id_generators import StreamIdGenerator
from synapse.types import JsonDict
from synapse.util import json_encoder
@@ -32,7 +31,12 @@ logger = logging.getLogger(__name__)
class PusherWorkerStore(SQLBaseStore):
- def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
+ def __init__(
+ self,
+ database: DatabasePool,
+ db_conn: LoggingDatabaseConnection,
+ hs: "HomeServer",
+ ):
super().__init__(database, db_conn, hs)
self._pushers_id_gen = StreamIdGenerator(
db_conn, "pushers", "id", extra_tables=[("deleted_pushers", "stream_id")]
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index c83089ee..181841ee 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -26,7 +26,7 @@ from synapse.metrics.background_process_metrics import wrap_as_background_proces
from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
from synapse.storage.databases.main.stats import StatsStore
-from synapse.storage.types import Connection, Cursor
+from synapse.storage.types import Cursor
from synapse.storage.util.id_generators import IdGenerator
from synapse.storage.util.sequence import build_sequence_generator
from synapse.types import UserID, UserInfo
@@ -207,7 +207,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
return False
now = self._clock.time_msec()
- trial_duration_ms = self.config.mau_trial_days * 24 * 60 * 60 * 1000
+ trial_duration_ms = self.config.server.mau_trial_days * 24 * 60 * 60 * 1000
is_trial = (now - info["creation_ts"] * 1000) < trial_duration_ms
return is_trial
@@ -1710,7 +1710,7 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
We do this by grandfathering in existing user threepids assuming that
they used one of the server configured trusted identity servers.
"""
- id_servers = set(self.config.trusted_third_party_id_servers)
+ id_servers = set(self.config.registration.trusted_third_party_id_servers)
def _bg_user_threepids_grandfather_txn(txn):
sql = """
@@ -1775,10 +1775,17 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
- def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
+ def __init__(
+ self,
+ database: DatabasePool,
+ db_conn: LoggingDatabaseConnection,
+ hs: "HomeServer",
+ ):
super().__init__(database, db_conn, hs)
- self._ignore_unknown_session_error = hs.config.request_token_inhibit_3pid_errors
+ self._ignore_unknown_session_error = (
+ hs.config.server.request_token_inhibit_3pid_errors
+ )
self._access_tokens_id_gen = IdGenerator(db_conn, "access_tokens", "id")
self._refresh_tokens_id_gen = IdGenerator(db_conn, "refresh_tokens", "id")
diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py
index 118b390e..d69eaf80 100644
--- a/synapse/storage/databases/main/room.py
+++ b/synapse/storage/databases/main/room.py
@@ -679,8 +679,8 @@ class RoomWorkerStore(SQLBaseStore):
# policy.
if not ret:
return {
- "min_lifetime": self.config.retention_default_min_lifetime,
- "max_lifetime": self.config.retention_default_max_lifetime,
+ "min_lifetime": self.config.server.retention_default_min_lifetime,
+ "max_lifetime": self.config.server.retention_default_max_lifetime,
}
row = ret[0]
@@ -690,10 +690,10 @@ class RoomWorkerStore(SQLBaseStore):
# The default values will be None if no default policy has been defined, or if one
# of the attributes is missing from the default policy.
if row["min_lifetime"] is None:
- row["min_lifetime"] = self.config.retention_default_min_lifetime
+ row["min_lifetime"] = self.config.server.retention_default_min_lifetime
if row["max_lifetime"] is None:
- row["max_lifetime"] = self.config.retention_default_max_lifetime
+ row["max_lifetime"] = self.config.server.retention_default_max_lifetime
return row
diff --git a/synapse/storage/databases/main/room_batch.py b/synapse/storage/databases/main/room_batch.py
index a3833887..300a563c 100644
--- a/synapse/storage/databases/main/room_batch.py
+++ b/synapse/storage/databases/main/room_batch.py
@@ -18,7 +18,9 @@ from synapse.storage._base import SQLBaseStore
class RoomBatchStore(SQLBaseStore):
- async def get_insertion_event_by_batch_id(self, batch_id: str) -> Optional[str]:
+ async def get_insertion_event_by_batch_id(
+ self, room_id: str, batch_id: str
+ ) -> Optional[str]:
"""Retrieve a insertion event ID.
Args:
@@ -30,7 +32,7 @@ class RoomBatchStore(SQLBaseStore):
"""
return await self.db_pool.simple_select_one_onecol(
table="insertion_events",
- keyvalues={"next_batch_id": batch_id},
+ keyvalues={"room_id": room_id, "next_batch_id": batch_id},
retcol="event_id",
allow_none=True,
)
diff --git a/synapse/storage/databases/main/search.py b/synapse/storage/databases/main/search.py
index 2a1e99e1..c85383c9 100644
--- a/synapse/storage/databases/main/search.py
+++ b/synapse/storage/databases/main/search.py
@@ -51,7 +51,7 @@ class SearchWorkerStore(SQLBaseStore):
txn:
entries: entries to be added to the table
"""
- if not self.hs.config.enable_search:
+ if not self.hs.config.server.enable_search:
return
if isinstance(self.database_engine, PostgresEngine):
sql = (
@@ -105,7 +105,7 @@ class SearchBackgroundUpdateStore(SearchWorkerStore):
def __init__(self, database: DatabasePool, db_conn, hs):
super().__init__(database, db_conn, hs)
- if not hs.config.enable_search:
+ if not hs.config.server.enable_search:
return
self.db_pool.updates.register_background_update_handler(
diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py
index 90d65edc..e98a45b6 100644
--- a/synapse/storage/databases/main/user_directory.py
+++ b/synapse/storage/databases/main/user_directory.py
@@ -26,6 +26,8 @@ from typing import (
cast,
)
+from synapse.api.errors import StoreError
+
if TYPE_CHECKING:
from synapse.server import HomeServer
@@ -40,12 +42,10 @@ from synapse.util.caches.descriptors import cached
logger = logging.getLogger(__name__)
-
TEMP_TABLE = "_temp_populate_user_directory"
class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
-
# How many records do we calculate before sending it to
# add_users_who_share_private_rooms?
SHARE_PRIVATE_WORKING_SET = 500
@@ -230,38 +230,49 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
is_in_room = await self.is_host_joined(room_id, self.server_name)
if is_in_room:
- is_public = await self.is_room_world_readable_or_publicly_joinable(
- room_id
- )
-
users_with_profile = await self.get_users_in_room_with_profiles(room_id)
+ # Throw away users excluded from the directory.
+ users_with_profile = {
+ user_id: profile
+ for user_id, profile in users_with_profile.items()
+ if not self.hs.is_mine_id(user_id)
+ or await self.should_include_local_user_in_dir(user_id)
+ }
- # Update each user in the user directory.
+ # Upsert a user_directory record for each remote user we see.
for user_id, profile in users_with_profile.items():
+ # Local users are processed separately in
+ # `_populate_user_directory_users`; there we can read from
+ # the `profiles` table to ensure we don't leak their per-room
+ # profiles. It also means we write local users to this table
+ # exactly once, rather than once for every room they're in.
+ if self.hs.is_mine_id(user_id):
+ continue
+ # TODO `users_with_profile` above reads from the `user_directory`
+ # table, meaning that `profile` is bespoke to this room,
+ # and this leaks remote users' per-room profiles to the user directory.
await self.update_profile_in_user_dir(
user_id, profile.display_name, profile.avatar_url
)
- to_insert = set()
-
+ # Now update the room sharing tables to include this room.
+ is_public = await self.is_room_world_readable_or_publicly_joinable(
+ room_id
+ )
if is_public:
- for user_id in users_with_profile:
- if self.get_if_app_services_interested_in_user(user_id):
- continue
-
- to_insert.add(user_id)
-
- if to_insert:
- await self.add_users_in_public_rooms(room_id, to_insert)
- to_insert.clear()
+ if users_with_profile:
+ await self.add_users_in_public_rooms(
+ room_id, users_with_profile.keys()
+ )
else:
+ to_insert = set()
for user_id in users_with_profile:
+ # We want the set of pairs (L, M) where L and M are
+ # in `users_with_profile` and L is local.
+ # Do so by looking for the local user L first.
if not self.hs.is_mine_id(user_id):
continue
- if self.get_if_app_services_interested_in_user(user_id):
- continue
-
for other_user_id in users_with_profile:
if user_id == other_user_id:
continue
@@ -349,10 +360,11 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
)
for user_id in users_to_work_on:
- profile = await self.get_profileinfo(get_localpart_from_id(user_id))
- await self.update_profile_in_user_dir(
- user_id, profile.display_name, profile.avatar_url
- )
+ if await self.should_include_local_user_in_dir(user_id):
+ profile = await self.get_profileinfo(get_localpart_from_id(user_id))
+ await self.update_profile_in_user_dir(
+ user_id, profile.display_name, profile.avatar_url
+ )
# We've finished processing a user. Delete it from the table.
await self.db_pool.simple_delete_one(
@@ -369,6 +381,42 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
return len(users_to_work_on)
+ async def should_include_local_user_in_dir(self, user: str) -> bool:
+ """Certain classes of local user are omitted from the user directory.
+ Is this user one of them?
+ """
+ # We're opting to exclude the appservice sender (user defined by the
+ # `sender_localpart` in the appservice registration) even though
+ # technically it could be DM-able. In the future, this could potentially
+ # be configurable per-appservice whether the appservice sender can be
+ # contacted.
+ if self.get_app_service_by_user_id(user) is not None:
+ return False
+
+ # We're opting to exclude appservice users (anyone matching the user
+ # namespace regex in the appservice registration) even though technically
+ # they could be DM-able. In the future, this could potentially
+ # be configurable per-appservice whether the appservice users can be
+ # contacted.
+ if self.get_if_app_services_interested_in_user(user):
+ # TODO we might want to make this configurable for each app service
+ return False
+
+ # Support users are for diagnostics and should not appear in the user directory.
+ if await self.is_support_user(user):
+ return False
+
+ # Deactivated users aren't contactable, so should not appear in the user directory.
+ try:
+ if await self.get_user_deactivated_status(user):
+ return False
+ except StoreError:
+ # No such user in the users table. No need to do this when calling
+ # is_support_user---that returns False if the user is missing.
+ return False
+
+ return True
+
async def is_room_world_readable_or_publicly_joinable(self, room_id: str) -> bool:
"""Check if the room is either world_readable or publically joinable"""
@@ -527,7 +575,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
desc="get_user_in_directory",
)
- async def update_user_directory_stream_pos(self, stream_id: int) -> None:
+ async def update_user_directory_stream_pos(self, stream_id: Optional[int]) -> None:
await self.db_pool.simple_update_one(
table="user_directory_stream_pos",
keyvalues={},
@@ -537,7 +585,6 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
-
# How many records do we calculate before sending it to
# add_users_who_share_private_rooms?
SHARE_PRIVATE_WORKING_SET = 500
diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py
index f31880b8..11ca47ea 100644
--- a/synapse/storage/prepare_database.py
+++ b/synapse/storage/prepare_database.py
@@ -366,7 +366,7 @@ def _upgrade_existing_database(
+ "new for the server to understand"
)
- # some of the deltas assume that config.server_name is set correctly, so now
+ # some of the deltas assume that server_name is set correctly, so now
# is a good time to run the sanity check.
if not is_empty and "main" in databases:
from synapse.storage.databases.main import check_database_before_upgrade
@@ -487,6 +487,10 @@ def _upgrade_existing_database(
spec = importlib.util.spec_from_file_location(
module_name, absolute_path
)
+ if spec is None:
+ raise RuntimeError(
+ f"Could not build a module spec for {module_name} at {absolute_path}"
+ )
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module) # type: ignore
diff --git a/synapse/storage/schema/__init__.py b/synapse/storage/schema/__init__.py
index 573e05a4..1aee741a 100644
--- a/synapse/storage/schema/__init__.py
+++ b/synapse/storage/schema/__init__.py
@@ -12,9 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-# When updating these values, please leave a short summary of the changes below.
-
-SCHEMA_VERSION = 64
+SCHEMA_VERSION = 64 # remember to update the list below when updating
"""Represents the expectations made by the codebase about the database schema
This should be incremented whenever the codebase changes its requirements on the
@@ -46,7 +44,7 @@ Changes in SCHEMA_VERSION = 64:
"""
-SCHEMA_COMPAT_VERSION = 59
+SCHEMA_COMPAT_VERSION = 60 # 60: "outlier" not in internal_metadata.
"""Limit on how far the synapse codebase can be rolled back without breaking db compat
This value is stored in the database, and checked on startup. If the value in the
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index 5e86befd..b5ba1560 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -15,9 +15,11 @@ import logging
from typing import (
TYPE_CHECKING,
Awaitable,
+ Collection,
Dict,
Iterable,
List,
+ Mapping,
Optional,
Set,
Tuple,
@@ -29,7 +31,7 @@ from frozendict import frozendict
from synapse.api.constants import EventTypes
from synapse.events import EventBase
-from synapse.types import MutableStateMap, StateMap
+from synapse.types import MutableStateMap, StateKey, StateMap
if TYPE_CHECKING:
from typing import FrozenSet # noqa: used within quoted type hint; flake8 sad
@@ -134,6 +136,23 @@ class StateFilter:
include_others=True,
)
+ @staticmethod
+ def freeze(types: Mapping[str, Optional[Collection[str]]], include_others: bool):
+ """
+ Returns a (frozen) StateFilter with the same contents as the parameters
+ specified here, which can be made of mutable types.
+ """
+ types_with_frozen_values: Dict[str, Optional[FrozenSet[str]]] = {}
+ for state_types, state_keys in types.items():
+ if state_keys is not None:
+ types_with_frozen_values[state_types] = frozenset(state_keys)
+ else:
+ types_with_frozen_values[state_types] = None
+
+ return StateFilter(
+ frozendict(types_with_frozen_values), include_others=include_others
+ )
+
def return_expanded(self) -> "StateFilter":
"""Creates a new StateFilter where type wild cards have been removed
(except for memberships). The returned filter is a superset of the
@@ -356,6 +375,157 @@ class StateFilter:
return member_filter, non_member_filter
+ def _decompose_into_four_parts(
+ self,
+ ) -> Tuple[Tuple[bool, Set[str]], Tuple[Set[str], Set[StateKey]]]:
+ """
+ Decomposes this state filter into 4 constituent parts, which can be
+ thought of as this:
+ all? - minus_wildcards + plus_wildcards + plus_state_keys
+
+ where
+ * all represents ALL state
+ * minus_wildcards represents entire state types to remove
+ * plus_wildcards represents entire state types to add
+ * plus_state_keys represents individual state keys to add
+
+        See `_recompose_from_four_parts` for the other direction of this
+        correspondence.
+ """
+ is_all = self.include_others
+ excluded_types: Set[str] = {t for t in self.types if is_all}
+ wildcard_types: Set[str] = {t for t, s in self.types.items() if s is None}
+ concrete_keys: Set[StateKey] = set(self.concrete_types())
+
+ return (is_all, excluded_types), (wildcard_types, concrete_keys)
+
+ @staticmethod
+ def _recompose_from_four_parts(
+ all_part: bool,
+ minus_wildcards: Set[str],
+ plus_wildcards: Set[str],
+ plus_state_keys: Set[StateKey],
+ ) -> "StateFilter":
+ """
+ Recomposes a state filter from 4 parts.
+
+        See `_decompose_into_four_parts` (the other direction of this
+        correspondence) for descriptions on each of the parts.
+ """
+
+ # {state type -> set of state keys OR None for wildcard}
+ # (The same structure as that of a StateFilter.)
+ new_types: Dict[str, Optional[Set[str]]] = {}
+
+        # if we start with all, insert the excluded state types as empty sets
+        # to prevent them from being included
+ if all_part:
+ new_types.update({state_type: set() for state_type in minus_wildcards})
+
+ # insert the plus wildcards
+ new_types.update({state_type: None for state_type in plus_wildcards})
+
+ # insert the specific state keys
+ for state_type, state_key in plus_state_keys:
+ if state_type in new_types:
+ entry = new_types[state_type]
+ if entry is not None:
+ entry.add(state_key)
+ elif not all_part:
+ # don't insert if the entire type is already included by
+ # include_others as this would actually shrink the state allowed
+ # by this filter.
+ new_types[state_type] = {state_key}
+
+ return StateFilter.freeze(new_types, include_others=all_part)
+
+ def approx_difference(self, other: "StateFilter") -> "StateFilter":
+ """
+ Returns a state filter which represents `self - other`.
+
+ This is useful for determining what state remains to be pulled out of the
+ database if we want the state included by `self` but already have the state
+ included by `other`.
+
+ The returned state filter
+ - MUST include all state events that are included by this filter (`self`)
+ unless they are included by `other`;
+ - MUST NOT include state events not included by this filter (`self`); and
+ - MAY be an over-approximation: the returned state filter
+ MAY additionally include some state events from `other`.
+
+ This implementation attempts to return the narrowest such state filter.
+ In the case that `self` contains wildcards for state types where
+ `other` contains specific state keys, an approximation must be made:
+ the returned state filter keeps the wildcard, as state filters are not
+ able to express 'all state keys except some given examples'.
+ e.g.
+ StateFilter(m.room.member -> None (wildcard))
+ minus
+ StateFilter(m.room.member -> {'@wombat:example.org'})
+ is approximated as
+ StateFilter(m.room.member -> None (wildcard))
+ """
+
+ # We first transform self and other into an alternative representation:
+ # - whether or not they include all events to begin with ('all')
+ # - if so, which event types are excluded? ('excludes')
+ # - which entire event types to include ('wildcards')
+ # - which concrete state keys to include ('concrete state keys')
+ (self_all, self_excludes), (
+ self_wildcards,
+ self_concrete_keys,
+ ) = self._decompose_into_four_parts()
+ (other_all, other_excludes), (
+ other_wildcards,
+ other_concrete_keys,
+ ) = other._decompose_into_four_parts()
+
+ # Start with an estimate of the difference based on self
+ new_all = self_all
+ # Wildcards from the other can be added to the exclusion filter
+ new_excludes = self_excludes | other_wildcards
+ # We remove wildcards that appeared as wildcards in the other
+ new_wildcards = self_wildcards - other_wildcards
+ # We filter out the concrete state keys that appear in the other
+ # as wildcards or concrete state keys.
+ new_concrete_keys = {
+ (state_type, state_key)
+ for (state_type, state_key) in self_concrete_keys
+ if state_type not in other_wildcards
+ } - other_concrete_keys
+
+ if other_all:
+ if self_all:
+ # If self starts with all, then we add as wildcards any
+ # types which appear in the other's exclusion filter (but
+ # aren't in the self exclusion filter). This is as the other
+ # filter will return everything BUT the types in its exclusion, so
+ # we need to add those excluded types that also match the self
+ # filter as wildcard types in the new filter.
+ new_wildcards |= other_excludes.difference(self_excludes)
+
+            # If `other` is an `include_others` filter, the difference must not be.
+ new_all = False
+ # (We have no need for excludes when we don't start with all, as there
+ # is nothing to exclude.)
+ new_excludes = set()
+
+ # We also filter out all state types that aren't in the exclusion
+ # list of the other.
+ new_wildcards &= other_excludes
+ new_concrete_keys = {
+ (state_type, state_key)
+ for (state_type, state_key) in new_concrete_keys
+ if state_type in other_excludes
+ }
+
+ # Transform our newly-constructed state filter from the alternative
+ # representation back into the normal StateFilter representation.
+ return StateFilter._recompose_from_four_parts(
+ new_all, new_excludes, new_wildcards, new_concrete_keys
+ )
+
class StateGroupStorage:
"""High level interface to fetching state for event."""
diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py
index 6f7cbe40..67081161 100644
--- a/synapse/storage/util/id_generators.py
+++ b/synapse/storage/util/id_generators.py
@@ -16,42 +16,62 @@ import logging
import threading
from collections import OrderedDict
from contextlib import contextmanager
-from typing import Dict, Iterable, List, Optional, Set, Tuple, Union
+from types import TracebackType
+from typing import (
+ AsyncContextManager,
+ ContextManager,
+ Dict,
+ Generator,
+ Generic,
+ Iterable,
+ List,
+ Optional,
+ Sequence,
+ Set,
+ Tuple,
+ Type,
+ TypeVar,
+ Union,
+ cast,
+)
import attr
-from sortedcontainers import SortedSet
+from sortedcontainers import SortedList, SortedSet
from synapse.metrics.background_process_metrics import run_as_background_process
-from synapse.storage.database import DatabasePool, LoggingTransaction
+from synapse.storage.database import (
+ DatabasePool,
+ LoggingDatabaseConnection,
+ LoggingTransaction,
+)
from synapse.storage.types import Cursor
from synapse.storage.util.sequence import PostgresSequenceGenerator
logger = logging.getLogger(__name__)
+T = TypeVar("T")
+
+
class IdGenerator:
- def __init__(self, db_conn, table, column):
+ def __init__(
+ self,
+ db_conn: LoggingDatabaseConnection,
+ table: str,
+ column: str,
+ ):
self._lock = threading.Lock()
self._next_id = _load_current_id(db_conn, table, column)
- def get_next(self):
+ def get_next(self) -> int:
with self._lock:
self._next_id += 1
return self._next_id
-def _load_current_id(db_conn, table, column, step=1):
- """
-
- Args:
- db_conn (object):
- table (str):
- column (str):
- step (int):
-
- Returns:
- int
- """
+def _load_current_id(
+ db_conn: LoggingDatabaseConnection, table: str, column: str, step: int = 1
+) -> int:
# debug logging for https://github.com/matrix-org/synapse/issues/7968
logger.info("initialising stream generator for %s(%s)", table, column)
cur = db_conn.cursor(txn_name="_load_current_id")
@@ -59,7 +79,9 @@ def _load_current_id(db_conn, table, column, step=1):
cur.execute("SELECT MAX(%s) FROM %s" % (column, table))
else:
cur.execute("SELECT MIN(%s) FROM %s" % (column, table))
- (val,) = cur.fetchone()
+ result = cur.fetchone()
+ assert result is not None
+ (val,) = result
cur.close()
current_id = int(val) if val else step
return (max if step > 0 else min)(current_id, step)
@@ -93,16 +115,16 @@ class StreamIdGenerator:
def __init__(
self,
- db_conn,
- table,
- column,
+ db_conn: LoggingDatabaseConnection,
+ table: str,
+ column: str,
extra_tables: Iterable[Tuple[str, str]] = (),
- step=1,
- ):
+ step: int = 1,
+ ) -> None:
assert step != 0
self._lock = threading.Lock()
- self._step = step
- self._current = _load_current_id(db_conn, table, column, step)
+ self._step: int = step
+ self._current: int = _load_current_id(db_conn, table, column, step)
for table, column in extra_tables:
self._current = (max if step > 0 else min)(
self._current, _load_current_id(db_conn, table, column, step)
@@ -115,7 +137,7 @@ class StreamIdGenerator:
# The key and values are the same, but we never look at the values.
self._unfinished_ids: OrderedDict[int, int] = OrderedDict()
- def get_next(self):
+ def get_next(self) -> AsyncContextManager[int]:
"""
Usage:
async with stream_id_gen.get_next() as stream_id:
@@ -128,7 +150,7 @@ class StreamIdGenerator:
self._unfinished_ids[next_id] = next_id
@contextmanager
- def manager():
+ def manager() -> Generator[int, None, None]:
try:
yield next_id
finally:
@@ -137,7 +159,7 @@ class StreamIdGenerator:
return _AsyncCtxManagerWrapper(manager())
- def get_next_mult(self, n):
+ def get_next_mult(self, n: int) -> AsyncContextManager[Sequence[int]]:
"""
Usage:
async with stream_id_gen.get_next(n) as stream_ids:
@@ -155,7 +177,7 @@ class StreamIdGenerator:
self._unfinished_ids[next_id] = next_id
@contextmanager
- def manager():
+ def manager() -> Generator[Sequence[int], None, None]:
try:
yield next_ids
finally:
@@ -215,7 +237,7 @@ class MultiWriterIdGenerator:
def __init__(
self,
- db_conn,
+ db_conn: LoggingDatabaseConnection,
db: DatabasePool,
stream_name: str,
instance_name: str,
@@ -223,7 +245,7 @@ class MultiWriterIdGenerator:
sequence_name: str,
writers: List[str],
positive: bool = True,
- ):
+ ) -> None:
self._db = db
self._stream_name = stream_name
self._instance_name = instance_name
@@ -243,6 +265,15 @@ class MultiWriterIdGenerator:
# should be less than the minimum of this set (if not empty).
self._unfinished_ids: SortedSet[int] = SortedSet()
+ # We also need to track when we've requested some new stream IDs but
+ # they haven't yet been added to the `_unfinished_ids` set. Every time
+ # we request a new stream ID we add the current max stream ID to the
+ # list, and remove it once we've added the newly allocated IDs to the
+ # `_unfinished_ids` set. This means that we *may* be allocated stream
+ # IDs above those in the list, and so we can't advance the local current
+ # position beyond the minimum stream ID in this list.
+ self._in_flight_fetches: SortedList[int] = SortedList()
+
# Set of local IDs that we've processed that are larger than the current
# position, due to there being smaller unpersisted IDs.
self._finished_ids: Set[int] = set()
@@ -268,6 +299,9 @@ class MultiWriterIdGenerator:
)
self._known_persisted_positions: List[int] = []
+    # The maximum stream ID that we have seen allocated across any writer.
+ self._max_seen_allocated_stream_id = 1
+
self._sequence_gen = PostgresSequenceGenerator(sequence_name)
# We check that the table and sequence haven't diverged.
@@ -283,11 +317,15 @@ class MultiWriterIdGenerator:
# This goes and fills out the above state from the database.
self._load_current_ids(db_conn, tables)
+ self._max_seen_allocated_stream_id = max(
+ self._current_positions.values(), default=1
+ )
+
def _load_current_ids(
self,
- db_conn,
+ db_conn: LoggingDatabaseConnection,
tables: List[Tuple[str, str, str]],
- ):
+ ) -> None:
cur = db_conn.cursor(txn_name="_load_current_ids")
# Load the current positions of all writers for the stream.
@@ -335,7 +373,9 @@ class MultiWriterIdGenerator:
"agg": "MAX" if self._positive else "-MIN",
}
cur.execute(sql)
- (stream_id,) = cur.fetchone()
+ result = cur.fetchone()
+ assert result is not None
+ (stream_id,) = result
max_stream_id = max(max_stream_id, stream_id)
@@ -354,7 +394,7 @@ class MultiWriterIdGenerator:
self._persisted_upto_position = min_stream_id
- rows = []
+ rows: List[Tuple[str, int]] = []
for table, instance_column, id_column in tables:
sql = """
SELECT %(instance)s, %(id)s FROM %(table)s
@@ -367,7 +407,8 @@ class MultiWriterIdGenerator:
}
cur.execute(sql, (min_stream_id * self._return_factor,))
- rows.extend(cur)
+ # Cast safety: this corresponds to the types returned by the query above.
+ rows.extend(cast(Iterable[Tuple[str, int]], cur))
# Sort so that we handle rows in order for each instance.
rows.sort()
@@ -385,13 +426,35 @@ class MultiWriterIdGenerator:
cur.close()
- def _load_next_id_txn(self, txn) -> int:
- return self._sequence_gen.get_next_id_txn(txn)
+ def _load_next_id_txn(self, txn: Cursor) -> int:
+ stream_ids = self._load_next_mult_id_txn(txn, 1)
+ return stream_ids[0]
- def _load_next_mult_id_txn(self, txn, n: int) -> List[int]:
- return self._sequence_gen.get_next_mult_txn(txn, n)
+ def _load_next_mult_id_txn(self, txn: Cursor, n: int) -> List[int]:
+ # We need to track that we've requested some more stream IDs, and what
+ # the current max allocated stream ID is. This is to prevent a race
+ # where we've been allocated stream IDs but they have not yet been added
+ # to the `_unfinished_ids` set, allowing the current position to advance
+ # past them.
+ with self._lock:
+ current_max = self._max_seen_allocated_stream_id
+ self._in_flight_fetches.add(current_max)
+
+ try:
+ stream_ids = self._sequence_gen.get_next_mult_txn(txn, n)
+
+ with self._lock:
+ self._unfinished_ids.update(stream_ids)
+ self._max_seen_allocated_stream_id = max(
+ self._max_seen_allocated_stream_id, self._unfinished_ids[-1]
+ )
+ finally:
+ with self._lock:
+ self._in_flight_fetches.remove(current_max)
+
+ return stream_ids
- def get_next(self):
+ def get_next(self) -> AsyncContextManager[int]:
"""
Usage:
async with stream_id_gen.get_next() as stream_id:
@@ -403,9 +466,12 @@ class MultiWriterIdGenerator:
if self._writers and self._instance_name not in self._writers:
raise Exception("Tried to allocate stream ID on non-writer")
- return _MultiWriterCtxManager(self)
+ # Cast safety: the second argument to _MultiWriterCtxManager, multiple_ids,
+ # controls the return type. If `None` or omitted, the context manager yields
+ # a single integer stream_id; otherwise it yields a list of stream_ids.
+ return cast(AsyncContextManager[int], _MultiWriterCtxManager(self))
- def get_next_mult(self, n: int):
+ def get_next_mult(self, n: int) -> AsyncContextManager[List[int]]:
"""
Usage:
async with stream_id_gen.get_next_mult(5) as stream_ids:
@@ -417,9 +483,10 @@ class MultiWriterIdGenerator:
if self._writers and self._instance_name not in self._writers:
raise Exception("Tried to allocate stream ID on non-writer")
- return _MultiWriterCtxManager(self, n)
+ # Cast safety: see get_next.
+ return cast(AsyncContextManager[List[int]], _MultiWriterCtxManager(self, n))
- def get_next_txn(self, txn: LoggingTransaction):
+ def get_next_txn(self, txn: LoggingTransaction) -> int:
"""
Usage:
@@ -434,9 +501,6 @@ class MultiWriterIdGenerator:
next_id = self._load_next_id_txn(txn)
- with self._lock:
- self._unfinished_ids.add(next_id)
-
txn.call_after(self._mark_id_as_finished, next_id)
txn.call_on_exception(self._mark_id_as_finished, next_id)
@@ -457,7 +521,7 @@ class MultiWriterIdGenerator:
return self._return_factor * next_id
- def _mark_id_as_finished(self, next_id: int):
+ def _mark_id_as_finished(self, next_id: int) -> None:
"""The ID has finished being processed so we should advance the
current position if possible.
"""
@@ -468,15 +532,27 @@ class MultiWriterIdGenerator:
new_cur: Optional[int] = None
- if self._unfinished_ids:
+ if self._unfinished_ids or self._in_flight_fetches:
# If there are unfinished IDs then the new position will be the
- # largest finished ID less than the minimum unfinished ID.
+ # largest finished ID strictly less than the minimum unfinished
+ # ID.
+
+ # The minimum unfinished ID needs to take account of both
+ # `_unfinished_ids` and `_in_flight_fetches`.
+ if self._unfinished_ids and self._in_flight_fetches:
+ # `_in_flight_fetches` stores the maximum safe stream ID, so
+ # we add one to make it equivalent to the minimum unsafe ID.
+ min_unfinished = min(
+ self._unfinished_ids[0], self._in_flight_fetches[0] + 1
+ )
+ elif self._in_flight_fetches:
+ min_unfinished = self._in_flight_fetches[0] + 1
+ else:
+ min_unfinished = self._unfinished_ids[0]
finished = set()
-
- min_unfinshed = self._unfinished_ids[0]
for s in self._finished_ids:
- if s < min_unfinshed:
+ if s < min_unfinished:
if new_cur is None or new_cur < s:
new_cur = s
else:
@@ -534,7 +610,7 @@ class MultiWriterIdGenerator:
for name, i in self._current_positions.items()
}
- def advance(self, instance_name: str, new_id: int):
+ def advance(self, instance_name: str, new_id: int) -> None:
"""Advance the position of the named writer to the given ID, if greater
than existing entry.
"""
@@ -546,6 +622,10 @@ class MultiWriterIdGenerator:
new_id, self._current_positions.get(instance_name, 0)
)
+ self._max_seen_allocated_stream_id = max(
+ self._max_seen_allocated_stream_id, new_id
+ )
+
self._add_persisted_position(new_id)
def get_persisted_upto_position(self) -> int:
@@ -560,7 +640,7 @@ class MultiWriterIdGenerator:
with self._lock:
return self._return_factor * self._persisted_upto_position
- def _add_persisted_position(self, new_id: int):
+ def _add_persisted_position(self, new_id: int) -> None:
"""Record that we have persisted a position.
This is used to keep the `_current_positions` up to date.
@@ -576,7 +656,11 @@ class MultiWriterIdGenerator:
# to report a recent position when asked, rather than a potentially old
# one (if this instance hasn't written anything for a while).
our_current_position = self._current_positions.get(self._instance_name)
- if our_current_position and not self._unfinished_ids:
+ if (
+ our_current_position
+ and not self._unfinished_ids
+ and not self._in_flight_fetches
+ ):
self._current_positions[self._instance_name] = max(
our_current_position, new_id
)
@@ -606,7 +690,7 @@ class MultiWriterIdGenerator:
# do.
break
- def _update_stream_positions_table_txn(self, txn: Cursor):
+ def _update_stream_positions_table_txn(self, txn: Cursor) -> None:
"""Update the `stream_positions` table with newly persisted position."""
if not self._writers:
@@ -628,20 +712,25 @@ class MultiWriterIdGenerator:
txn.execute(sql, (self._stream_name, self._instance_name, pos))
-@attr.s(slots=True)
-class _AsyncCtxManagerWrapper:
+@attr.s(frozen=True, auto_attribs=True)
+class _AsyncCtxManagerWrapper(Generic[T]):
"""Helper class to convert a plain context manager to an async one.
This is mainly useful if you have a plain context manager but the interface
requires an async one.
"""
- inner = attr.ib()
+ inner: ContextManager[T]
- async def __aenter__(self):
+ async def __aenter__(self) -> T:
return self.inner.__enter__()
- async def __aexit__(self, exc_type, exc, tb):
+ async def __aexit__(
+ self,
+ exc_type: Optional[Type[BaseException]],
+ exc: Optional[BaseException],
+ tb: Optional[TracebackType],
+ ) -> Optional[bool]:
return self.inner.__exit__(exc_type, exc, tb)
@@ -663,15 +752,17 @@ class _MultiWriterCtxManager:
db_autocommit=True,
)
- with self.id_gen._lock:
- self.id_gen._unfinished_ids.update(self.stream_ids)
-
if self.multiple_ids is None:
return self.stream_ids[0] * self.id_gen._return_factor
else:
return [i * self.id_gen._return_factor for i in self.stream_ids]
- async def __aexit__(self, exc_type, exc, tb):
+ async def __aexit__(
+ self,
+ exc_type: Optional[Type[BaseException]],
+ exc: Optional[BaseException],
+ tb: Optional[TracebackType],
+ ) -> bool:
for i in self.stream_ids:
self.id_gen._mark_id_as_finished(i)
diff --git a/synapse/storage/util/sequence.py b/synapse/storage/util/sequence.py
index bb33e04f..75268cbe 100644
--- a/synapse/storage/util/sequence.py
+++ b/synapse/storage/util/sequence.py
@@ -81,7 +81,7 @@ class SequenceGenerator(metaclass=abc.ABCMeta):
id_column: str,
stream_name: Optional[str] = None,
positive: bool = True,
- ):
+ ) -> None:
"""Should be called during start up to test that the current value of
the sequence is greater than or equal to the maximum ID in the table.
@@ -122,7 +122,7 @@ class PostgresSequenceGenerator(SequenceGenerator):
id_column: str,
stream_name: Optional[str] = None,
positive: bool = True,
- ):
+ ) -> None:
"""See SequenceGenerator.check_consistency for docstring."""
txn = db_conn.cursor(txn_name="sequence.check_consistency")
@@ -244,7 +244,7 @@ class LocalSequenceGenerator(SequenceGenerator):
id_column: str,
stream_name: Optional[str] = None,
positive: bool = True,
- ):
+ ) -> None:
# There is nothing to do for in memory sequences
pass
diff --git a/synapse/util/__init__.py b/synapse/util/__init__.py
index bd234549..abf53d14 100644
--- a/synapse/util/__init__.py
+++ b/synapse/util/__init__.py
@@ -50,7 +50,16 @@ def _handle_frozendict(obj: Any) -> Dict[Any, Any]:
if type(obj) is frozendict:
# fishing the protected dict out of the object is a bit nasty,
# but we don't really want the overhead of copying the dict.
- return obj._dict
+ try:
+ # Safety: we catch the AttributeError immediately below.
+ # See https://github.com/matrix-org/python-canonicaljson/issues/36#issuecomment-927816293
+ # for discussion on how frozendict's internals have changed over time.
+ return obj._dict # type: ignore[attr-defined]
+ except AttributeError:
+ # When the C implementation of frozendict is used,
+ # there isn't a `_dict` attribute with a dict
+ # so we resort to making a copy of the frozendict
+ return dict(obj)
raise TypeError(
"Object of type %s is not JSON serializable" % obj.__class__.__name__
)
diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py
index 82d918a0..5df80ea8 100644
--- a/synapse/util/async_helpers.py
+++ b/synapse/util/async_helpers.py
@@ -438,7 +438,8 @@ class ReadWriteLock:
try:
yield
finally:
- new_defer.callback(None)
+ with PreserveLoggingContext():
+ new_defer.callback(None)
self.key_to_current_readers.get(key, set()).discard(new_defer)
return _ctx_manager()
@@ -466,7 +467,8 @@ class ReadWriteLock:
try:
yield
finally:
- new_defer.callback(None)
+ with PreserveLoggingContext():
+ new_defer.callback(None)
if self.key_to_current_writer[key] == new_defer:
self.key_to_current_writer.pop(key)
diff --git a/synapse/util/caches/cached_call.py b/synapse/util/caches/cached_call.py
index e58dd91e..470f4f91 100644
--- a/synapse/util/caches/cached_call.py
+++ b/synapse/util/caches/cached_call.py
@@ -85,7 +85,7 @@ class CachedCall(Generic[TV]):
# result in the deferred, since `awaiting` a deferred destroys its result.
# (Also, if it's a Failure, GCing the deferred would log a critical error
# about unhandled Failures)
- def got_result(r):
+ def got_result(r: Union[TV, Failure]) -> None:
self._result = r
self._deferred.addBoth(got_result)
diff --git a/synapse/util/caches/deferred_cache.py b/synapse/util/caches/deferred_cache.py
index 6262efe0..da502aec 100644
--- a/synapse/util/caches/deferred_cache.py
+++ b/synapse/util/caches/deferred_cache.py
@@ -31,6 +31,7 @@ from prometheus_client import Gauge
from twisted.internet import defer
from twisted.python import failure
+from twisted.python.failure import Failure
from synapse.util.async_helpers import ObservableDeferred
from synapse.util.caches.lrucache import LruCache
@@ -112,7 +113,7 @@ class DeferredCache(Generic[KT, VT]):
self.thread: Optional[threading.Thread] = None
@property
- def max_entries(self):
+ def max_entries(self) -> int:
return self.cache.max_size
def check_thread(self) -> None:
@@ -258,7 +259,7 @@ class DeferredCache(Generic[KT, VT]):
return False
- def cb(result) -> None:
+ def cb(result: VT) -> None:
if compare_and_pop():
self.cache.set(key, result, entry.callbacks)
else:
@@ -270,7 +271,7 @@ class DeferredCache(Generic[KT, VT]):
# not have been. Either way, let's double-check now.
entry.invalidate()
- def eb(_fail) -> None:
+ def eb(_fail: Failure) -> None:
compare_and_pop()
entry.invalidate()
@@ -284,11 +285,11 @@ class DeferredCache(Generic[KT, VT]):
def prefill(
self, key: KT, value: VT, callback: Optional[Callable[[], None]] = None
- ):
+ ) -> None:
callbacks = [callback] if callback else []
self.cache.set(key, value, callbacks=callbacks)
- def invalidate(self, key):
+ def invalidate(self, key) -> None:
"""Delete a key, or tree of entries
If the cache is backed by a regular dict, then "key" must be of
diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py
index 4ff62b40..a0a7a9de 100644
--- a/synapse/util/caches/lrucache.py
+++ b/synapse/util/caches/lrucache.py
@@ -52,7 +52,7 @@ logger = logging.getLogger(__name__)
try:
from pympler.asizeof import Asizer
- def _get_size_of(val: Any, *, recurse=True) -> int:
+ def _get_size_of(val: Any, *, recurse: bool = True) -> int:
"""Get an estimate of the size in bytes of the object.
Args:
@@ -71,7 +71,7 @@ try:
except ImportError:
- def _get_size_of(val: Any, *, recurse=True) -> int:
+ def _get_size_of(val: Any, *, recurse: bool = True) -> int:
return 0
@@ -85,15 +85,6 @@ VT = TypeVar("VT")
# a general type var, distinct from either KT or VT
T = TypeVar("T")
-
-def enumerate_leaves(node, depth):
- if depth == 0:
- yield node
- else:
- for n in node.values():
- yield from enumerate_leaves(n, depth - 1)
-
-
P = TypeVar("P")
@@ -102,7 +93,7 @@ class _TimedListNode(ListNode[P]):
__slots__ = ["last_access_ts_secs"]
- def update_last_access(self, clock: Clock):
+ def update_last_access(self, clock: Clock) -> None:
self.last_access_ts_secs = int(clock.time())
@@ -115,7 +106,7 @@ GLOBAL_ROOT = ListNode["_Node"].create_root_node()
@wrap_as_background_process("LruCache._expire_old_entries")
-async def _expire_old_entries(clock: Clock, expiry_seconds: int):
+async def _expire_old_entries(clock: Clock, expiry_seconds: int) -> None:
"""Walks the global cache list to find cache entries that haven't been
accessed in the given number of seconds.
"""
@@ -163,7 +154,7 @@ async def _expire_old_entries(clock: Clock, expiry_seconds: int):
logger.info("Dropped %d items from caches", i)
-def setup_expire_lru_cache_entries(hs: "HomeServer"):
+def setup_expire_lru_cache_entries(hs: "HomeServer") -> None:
"""Start a background job that expires all cache entries if they have not
been accessed for the given number of seconds.
"""
@@ -183,7 +174,7 @@ def setup_expire_lru_cache_entries(hs: "HomeServer"):
)
-class _Node:
+class _Node(Generic[KT, VT]):
__slots__ = [
"_list_node",
"_global_list_node",
@@ -197,8 +188,8 @@ class _Node:
def __init__(
self,
root: "ListNode[_Node]",
- key,
- value,
+ key: KT,
+ value: VT,
cache: "weakref.ReferenceType[LruCache]",
clock: Clock,
callbacks: Collection[Callable[[], None]] = (),
@@ -409,7 +400,7 @@ class LruCache(Generic[KT, VT]):
def synchronized(f: FT) -> FT:
@wraps(f)
- def inner(*args, **kwargs):
+ def inner(*args: Any, **kwargs: Any) -> Any:
with lock:
return f(*args, **kwargs)
@@ -418,17 +409,19 @@ class LruCache(Generic[KT, VT]):
cached_cache_len = [0]
if size_callback is not None:
- def cache_len():
+ def cache_len() -> int:
return cached_cache_len[0]
else:
- def cache_len():
+ def cache_len() -> int:
return len(cache)
self.len = synchronized(cache_len)
- def add_node(key, value, callbacks: Collection[Callable[[], None]] = ()):
+ def add_node(
+ key: KT, value: VT, callbacks: Collection[Callable[[], None]] = ()
+ ) -> None:
node = _Node(
list_root,
key,
@@ -446,7 +439,7 @@ class LruCache(Generic[KT, VT]):
if caches.TRACK_MEMORY_USAGE and metrics:
metrics.inc_memory_usage(node.memory)
- def move_node_to_front(node: _Node):
+ def move_node_to_front(node: _Node) -> None:
node.move_to_front(real_clock, list_root)
def delete_node(node: _Node) -> int:
@@ -488,7 +481,7 @@ class LruCache(Generic[KT, VT]):
default: Optional[T] = None,
callbacks: Collection[Callable[[], None]] = (),
update_metrics: bool = True,
- ):
+ ) -> Union[None, T, VT]:
node = cache.get(key, None)
if node is not None:
move_node_to_front(node)
@@ -502,7 +495,9 @@ class LruCache(Generic[KT, VT]):
return default
@synchronized
- def cache_set(key: KT, value: VT, callbacks: Iterable[Callable[[], None]] = ()):
+ def cache_set(
+ key: KT, value: VT, callbacks: Iterable[Callable[[], None]] = ()
+ ) -> None:
node = cache.get(key, None)
if node is not None:
# We sometimes store large objects, e.g. dicts, which cause
@@ -547,7 +542,7 @@ class LruCache(Generic[KT, VT]):
...
@synchronized
- def cache_pop(key: KT, default: Optional[T] = None):
+ def cache_pop(key: KT, default: Optional[T] = None) -> Union[None, T, VT]:
node = cache.get(key, None)
if node:
delete_node(node)
@@ -612,25 +607,25 @@ class LruCache(Generic[KT, VT]):
self.contains = cache_contains
self.clear = cache_clear
- def __getitem__(self, key):
+ def __getitem__(self, key: KT) -> VT:
result = self.get(key, self.sentinel)
if result is self.sentinel:
raise KeyError()
else:
- return result
+ return cast(VT, result)
- def __setitem__(self, key, value):
+ def __setitem__(self, key: KT, value: VT) -> None:
self.set(key, value)
- def __delitem__(self, key, value):
+ def __delitem__(self, key: KT, value: VT) -> None:
result = self.pop(key, self.sentinel)
if result is self.sentinel:
raise KeyError()
- def __len__(self):
+ def __len__(self) -> int:
return self.len()
- def __contains__(self, key):
+ def __contains__(self, key: KT) -> bool:
return self.contains(key)
def set_cache_factor(self, factor: float) -> bool:
diff --git a/synapse/util/caches/response_cache.py b/synapse/util/caches/response_cache.py
index ed720433..88ccf443 100644
--- a/synapse/util/caches/response_cache.py
+++ b/synapse/util/caches/response_cache.py
@@ -104,8 +104,8 @@ class ResponseCache(Generic[KV]):
return None
def _set(
- self, context: ResponseCacheContext[KV], deferred: defer.Deferred
- ) -> defer.Deferred:
+ self, context: ResponseCacheContext[KV], deferred: "defer.Deferred[RV]"
+ ) -> "defer.Deferred[RV]":
"""Set the entry for the given key to the given deferred.
*deferred* should run its callbacks in the sentinel logcontext (ie,
@@ -126,7 +126,7 @@ class ResponseCache(Generic[KV]):
key = context.cache_key
self.pending_result_cache[key] = result
- def on_complete(r):
+ def on_complete(r: RV) -> RV:
# if this cache has a non-zero timeout, and the callback has not cleared
# the should_cache bit, we leave it in the cache for now and schedule
# its removal later.
diff --git a/synapse/util/caches/stream_change_cache.py b/synapse/util/caches/stream_change_cache.py
index 27b1da23..330709b8 100644
--- a/synapse/util/caches/stream_change_cache.py
+++ b/synapse/util/caches/stream_change_cache.py
@@ -40,10 +40,10 @@ class StreamChangeCache:
self,
name: str,
current_stream_pos: int,
- max_size=10000,
+ max_size: int = 10000,
prefilled_cache: Optional[Mapping[EntityType, int]] = None,
- ):
- self._original_max_size = max_size
+ ) -> None:
+ self._original_max_size: int = max_size
self._max_size = math.floor(max_size)
self._entity_to_key: Dict[EntityType, int] = {}
diff --git a/synapse/util/caches/ttlcache.py b/synapse/util/caches/ttlcache.py
index 46afe3f9..0b9ac26b 100644
--- a/synapse/util/caches/ttlcache.py
+++ b/synapse/util/caches/ttlcache.py
@@ -159,12 +159,12 @@ class TTLCache(Generic[KT, VT]):
del self._expiry_list[0]
-@attr.s(frozen=True, slots=True)
-class _CacheEntry:
+@attr.s(frozen=True, slots=True, auto_attribs=True)
+class _CacheEntry: # Should be Generic[KT, VT]. See python-attrs/attrs#313
"""TTLCache entry"""
# expiry_time is the first attribute, so that entries are sorted by expiry.
- expiry_time = attr.ib(type=float)
- ttl = attr.ib(type=float)
- key = attr.ib()
- value = attr.ib()
+ expiry_time: float
+ ttl: float
+ key: Any # should be KT
+ value: Any # should be VT
diff --git a/synapse/util/daemonize.py b/synapse/util/daemonize.py
index f1a351cf..de04f34e 100644
--- a/synapse/util/daemonize.py
+++ b/synapse/util/daemonize.py
@@ -19,6 +19,8 @@ import logging
import os
import signal
import sys
+from types import FrameType, TracebackType
+from typing import NoReturn, Type
def daemonize_process(pid_file: str, logger: logging.Logger, chdir: str = "/") -> None:
@@ -97,7 +99,9 @@ def daemonize_process(pid_file: str, logger: logging.Logger, chdir: str = "/") -
# (we don't normally expect reactor.run to raise any exceptions, but this will
# also catch any other uncaught exceptions before we get that far.)
- def excepthook(type_, value, traceback):
+ def excepthook(
+ type_: Type[BaseException], value: BaseException, traceback: TracebackType
+ ) -> None:
logger.critical("Unhanded exception", exc_info=(type_, value, traceback))
sys.excepthook = excepthook
@@ -119,7 +123,7 @@ def daemonize_process(pid_file: str, logger: logging.Logger, chdir: str = "/") -
sys.exit(1)
# write a log line on SIGTERM.
- def sigterm(signum, frame):
+ def sigterm(signum: signal.Signals, frame: FrameType) -> NoReturn:
logger.warning("Caught signal %s. Stopping daemon." % signum)
sys.exit(0)
diff --git a/synapse/util/metrics.py b/synapse/util/metrics.py
index 1b82dca8..1e784b3f 100644
--- a/synapse/util/metrics.py
+++ b/synapse/util/metrics.py
@@ -14,9 +14,11 @@
import logging
from functools import wraps
-from typing import Any, Callable, Optional, TypeVar, cast
+from types import TracebackType
+from typing import Any, Callable, Optional, Type, TypeVar, cast
from prometheus_client import Counter
+from typing_extensions import Protocol
from synapse.logging.context import (
ContextResourceUsage,
@@ -24,6 +26,7 @@ from synapse.logging.context import (
current_context,
)
from synapse.metrics import InFlightGauge
+from synapse.util import Clock
logger = logging.getLogger(__name__)
@@ -64,6 +67,10 @@ in_flight = InFlightGauge(
T = TypeVar("T", bound=Callable[..., Any])
+class HasClock(Protocol):
+ clock: Clock
+
+
def measure_func(name: Optional[str] = None) -> Callable[[T], T]:
"""
Used to decorate an async function with a `Measure` context manager.
@@ -86,7 +93,7 @@ def measure_func(name: Optional[str] = None) -> Callable[[T], T]:
block_name = func.__name__ if name is None else name
@wraps(func)
- async def measured_func(self, *args, **kwargs):
+ async def measured_func(self: HasClock, *args: Any, **kwargs: Any) -> Any:
with Measure(self.clock, block_name):
r = await func(self, *args, **kwargs)
return r
@@ -104,10 +111,10 @@ class Measure:
"start",
]
- def __init__(self, clock, name: str):
+ def __init__(self, clock: Clock, name: str) -> None:
"""
Args:
- clock: A n object with a "time()" method, which returns the current
+ clock: An object with a "time()" method, which returns the current
time in seconds.
name: The name of the metric to report.
"""
@@ -124,7 +131,7 @@ class Measure:
assert isinstance(curr_context, LoggingContext)
parent_context = curr_context
self._logging_context = LoggingContext(str(curr_context), parent_context)
- self.start: Optional[int] = None
+ self.start: Optional[float] = None
def __enter__(self) -> "Measure":
if self.start is not None:
@@ -138,7 +145,12 @@ class Measure:
return self
- def __exit__(self, exc_type, exc_val, exc_tb):
+ def __exit__(
+ self,
+ exc_type: Optional[Type[BaseException]],
+ exc_val: Optional[BaseException],
+ exc_tb: Optional[TracebackType],
+ ) -> None:
if self.start is None:
raise RuntimeError("Measure() block exited without being entered")
@@ -168,8 +180,9 @@ class Measure:
"""
return self._logging_context.get_resource_usage()
- def _update_in_flight(self, metrics):
+ def _update_in_flight(self, metrics) -> None:
"""Gets called when processing in flight metrics"""
+ assert self.start is not None
duration = self.clock.time() - self.start
metrics.real_time_max = max(metrics.real_time_max, duration)
diff --git a/synapse/util/patch_inline_callbacks.py b/synapse/util/patch_inline_callbacks.py
index 9dd010af..1f18654d 100644
--- a/synapse/util/patch_inline_callbacks.py
+++ b/synapse/util/patch_inline_callbacks.py
@@ -14,7 +14,7 @@
import functools
import sys
-from typing import Any, Callable, List
+from typing import Any, Callable, Generator, List, TypeVar
from twisted.internet import defer
from twisted.internet.defer import Deferred
@@ -24,6 +24,9 @@ from twisted.python.failure import Failure
_already_patched = False
+T = TypeVar("T")
+
+
def do_patch() -> None:
"""
Patch defer.inlineCallbacks so that it checks the state of the logcontext on exit
@@ -37,15 +40,19 @@ def do_patch() -> None:
if _already_patched:
return
- def new_inline_callbacks(f):
+ def new_inline_callbacks(
+ f: Callable[..., Generator["Deferred[object]", object, T]]
+ ) -> Callable[..., "Deferred[T]"]:
@functools.wraps(f)
- def wrapped(*args, **kwargs):
+ def wrapped(*args: Any, **kwargs: Any) -> "Deferred[T]":
start_context = current_context()
changes: List[str] = []
- orig = orig_inline_callbacks(_check_yield_points(f, changes))
+ orig: Callable[..., "Deferred[T]"] = orig_inline_callbacks(
+ _check_yield_points(f, changes)
+ )
try:
- res = orig(*args, **kwargs)
+ res: "Deferred[T]" = orig(*args, **kwargs)
except Exception:
if current_context() != start_context:
for err in changes:
@@ -84,7 +91,7 @@ def do_patch() -> None:
print(err, file=sys.stderr)
raise Exception(err)
- def check_ctx(r):
+ def check_ctx(r: T) -> T:
if current_context() != start_context:
for err in changes:
print(err, file=sys.stderr)
@@ -107,7 +114,10 @@ def do_patch() -> None:
_already_patched = True
-def _check_yield_points(f: Callable, changes: List[str]) -> Callable:
+def _check_yield_points(
+ f: Callable[..., Generator["Deferred[object]", object, T]],
+ changes: List[str],
+) -> Callable:
"""Wraps a generator that is about to be passed to defer.inlineCallbacks
checking that after every yield the log contexts are correct.
@@ -127,7 +137,9 @@ def _check_yield_points(f: Callable, changes: List[str]) -> Callable:
from synapse.logging.context import current_context
@functools.wraps(f)
- def check_yield_points_inner(*args, **kwargs):
+ def check_yield_points_inner(
+ *args: Any, **kwargs: Any
+ ) -> Generator["Deferred[object]", object, T]:
gen = f(*args, **kwargs)
last_yield_line_no = gen.gi_frame.f_lineno
diff --git a/synapse/util/threepids.py b/synapse/util/threepids.py
index baa9190a..389adf00 100644
--- a/synapse/util/threepids.py
+++ b/synapse/util/threepids.py
@@ -44,8 +44,8 @@ def check_3pid_allowed(hs: "HomeServer", medium: str, address: str) -> bool:
bool: whether the 3PID medium/address is allowed to be added to this HS
"""
- if hs.config.allowed_local_3pids:
- for constraint in hs.config.allowed_local_3pids:
+ if hs.config.registration.allowed_local_3pids:
+ for constraint in hs.config.registration.allowed_local_3pids:
logger.debug(
"Checking 3PID %s (%s) against %s (%s)",
address,
diff --git a/synapse/util/versionstring.py b/synapse/util/versionstring.py
index 1c20b24b..899ee0ad 100644
--- a/synapse/util/versionstring.py
+++ b/synapse/util/versionstring.py
@@ -15,14 +15,18 @@
import logging
import os
import subprocess
+from types import ModuleType
+from typing import Dict
logger = logging.getLogger(__name__)
+version_cache: Dict[ModuleType, str] = {}
-def get_version_string(module) -> str:
+
+def get_version_string(module: ModuleType) -> str:
"""Given a module calculate a git-aware version string for it.
- If called on a module not in a git checkout will return `__verison__`.
+ If called on a module not in a git checkout will return `__version__`.
Args:
module (module)
@@ -31,11 +35,13 @@ def get_version_string(module) -> str:
str
"""
- cached_version = getattr(module, "_synapse_version_string_cache", None)
- if cached_version:
+ cached_version = version_cache.get(module)
+ if cached_version is not None:
return cached_version
- version_string = module.__version__
+ # We want this to fail loudly with an AttributeError. Type-ignore this so
+ # mypy only considers the happy path.
+ version_string = module.__version__ # type: ignore[attr-defined]
try:
null = open(os.devnull, "w")
@@ -97,10 +103,15 @@ def get_version_string(module) -> str:
s for s in (git_branch, git_tag, git_commit, git_dirty) if s
)
- version_string = "%s (%s)" % (module.__version__, git_version)
+ version_string = "%s (%s)" % (
+ # If the __version__ attribute doesn't exist, we'll have failed
+ # loudly above.
+ module.__version__, # type: ignore[attr-defined]
+ git_version,
+ )
except Exception as e:
logger.info("Failed to check for git repository: %s", e)
- module._synapse_version_string_cache = version_string
+ version_cache[module] = version_string
return version_string