summaryrefslogtreecommitdiff
path: root/synapse/handlers
diff options
context:
space:
mode:
authorAndrej Shadura <andrewsh@debian.org>2021-07-30 10:04:43 +0200
committerAndrej Shadura <andrewsh@debian.org>2021-07-30 10:04:43 +0200
commit679ff900f5e9b83af346904d7c8604cc5917608d (patch)
tree6e38ee74d09dcfb5a348090be1f0feac4fae47f9 /synapse/handlers
parent81e25363896b892f797b8e8ca906f2b4b49a386a (diff)
New upstream version 1.39.0
Diffstat (limited to 'synapse/handlers')
-rw-r--r--synapse/handlers/_base.py10
-rw-r--r--synapse/handlers/account_validity.py128
-rw-r--r--synapse/handlers/admin.py4
-rw-r--r--synapse/handlers/appservice.py6
-rw-r--r--synapse/handlers/auth.py16
-rw-r--r--synapse/handlers/cas.py4
-rw-r--r--synapse/handlers/device.py14
-rw-r--r--synapse/handlers/devicemessage.py2
-rw-r--r--synapse/handlers/directory.py11
-rw-r--r--synapse/handlers/e2e_keys.py40
-rw-r--r--synapse/handlers/events.py6
-rw-r--r--synapse/handlers/federation.py53
-rw-r--r--synapse/handlers/groups_local.py4
-rw-r--r--synapse/handlers/identity.py4
-rw-r--r--synapse/handlers/initial_sync.py14
-rw-r--r--synapse/handlers/message.py44
-rw-r--r--synapse/handlers/oidc.py56
-rw-r--r--synapse/handlers/pagination.py4
-rw-r--r--synapse/handlers/presence.py28
-rw-r--r--synapse/handlers/profile.py4
-rw-r--r--synapse/handlers/receipts.py19
-rw-r--r--synapse/handlers/register.py20
-rw-r--r--synapse/handlers/room.py26
-rw-r--r--synapse/handlers/room_list.py50
-rw-r--r--synapse/handlers/saml.py8
-rw-r--r--synapse/handlers/search.py8
-rw-r--r--synapse/handlers/space_summary.py84
-rw-r--r--synapse/handlers/sso.py12
-rw-r--r--synapse/handlers/stats.py37
-rw-r--r--synapse/handlers/sync.py38
-rw-r--r--synapse/handlers/typing.py28
-rw-r--r--synapse/handlers/user_directory.py2
32 files changed, 487 insertions, 297 deletions
diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py
index d800e169..6a05a653 100644
--- a/synapse/handlers/_base.py
+++ b/synapse/handlers/_base.py
@@ -15,8 +15,6 @@
import logging
from typing import TYPE_CHECKING, Optional
-import synapse.state
-import synapse.storage
import synapse.types
from synapse.api.constants import EventTypes, Membership
from synapse.api.ratelimiting import Ratelimiter
@@ -38,10 +36,10 @@ class BaseHandler:
"""
def __init__(self, hs: "HomeServer"):
- self.store = hs.get_datastore() # type: synapse.storage.DataStore
+ self.store = hs.get_datastore()
self.auth = hs.get_auth()
self.notifier = hs.get_notifier()
- self.state_handler = hs.get_state_handler() # type: synapse.state.StateHandler
+ self.state_handler = hs.get_state_handler()
self.distributor = hs.get_distributor()
self.clock = hs.get_clock()
self.hs = hs
@@ -55,12 +53,12 @@ class BaseHandler:
# Check whether ratelimiting room admin message redaction is enabled
# by the presence of rate limits in the config
if self.hs.config.rc_admin_redaction:
- self.admin_redaction_ratelimiter = Ratelimiter(
+ self.admin_redaction_ratelimiter: Optional[Ratelimiter] = Ratelimiter(
store=self.store,
clock=self.clock,
rate_hz=self.hs.config.rc_admin_redaction.per_second,
burst_count=self.hs.config.rc_admin_redaction.burst_count,
- ) # type: Optional[Ratelimiter]
+ )
else:
self.admin_redaction_ratelimiter = None
diff --git a/synapse/handlers/account_validity.py b/synapse/handlers/account_validity.py
index d752cf34..078accd6 100644
--- a/synapse/handlers/account_validity.py
+++ b/synapse/handlers/account_validity.py
@@ -15,9 +15,11 @@
import email.mime.multipart
import email.utils
import logging
-from typing import TYPE_CHECKING, List, Optional, Tuple
+from typing import TYPE_CHECKING, Awaitable, Callable, List, Optional, Tuple
-from synapse.api.errors import StoreError, SynapseError
+from twisted.web.http import Request
+
+from synapse.api.errors import AuthError, StoreError, SynapseError
from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.types import UserID
from synapse.util import stringutils
@@ -27,6 +29,15 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
+# Types for callbacks to be registered via the module api
+IS_USER_EXPIRED_CALLBACK = Callable[[str], Awaitable[Optional[bool]]]
+ON_USER_REGISTRATION_CALLBACK = Callable[[str], Awaitable]
+# Temporary hooks to allow for a transition from `/_matrix/client` endpoints
+# to `/_synapse/client/account_validity`. See `register_account_validity_callbacks`.
+ON_LEGACY_SEND_MAIL_CALLBACK = Callable[[str], Awaitable]
+ON_LEGACY_RENEW_CALLBACK = Callable[[str], Awaitable[Tuple[bool, bool, int]]]
+ON_LEGACY_ADMIN_REQUEST = Callable[[Request], Awaitable]
+
class AccountValidityHandler:
def __init__(self, hs: "HomeServer"):
@@ -70,6 +81,99 @@ class AccountValidityHandler:
if hs.config.run_background_tasks:
self.clock.looping_call(self._send_renewal_emails, 30 * 60 * 1000)
+ self._is_user_expired_callbacks: List[IS_USER_EXPIRED_CALLBACK] = []
+ self._on_user_registration_callbacks: List[ON_USER_REGISTRATION_CALLBACK] = []
+ self._on_legacy_send_mail_callback: Optional[
+ ON_LEGACY_SEND_MAIL_CALLBACK
+ ] = None
+ self._on_legacy_renew_callback: Optional[ON_LEGACY_RENEW_CALLBACK] = None
+
+ # The legacy admin requests callback isn't a protected attribute because we need
+ # to access it from the admin servlet, which is outside of this handler.
+ self.on_legacy_admin_request_callback: Optional[ON_LEGACY_ADMIN_REQUEST] = None
+
+ def register_account_validity_callbacks(
+ self,
+ is_user_expired: Optional[IS_USER_EXPIRED_CALLBACK] = None,
+ on_user_registration: Optional[ON_USER_REGISTRATION_CALLBACK] = None,
+ on_legacy_send_mail: Optional[ON_LEGACY_SEND_MAIL_CALLBACK] = None,
+ on_legacy_renew: Optional[ON_LEGACY_RENEW_CALLBACK] = None,
+ on_legacy_admin_request: Optional[ON_LEGACY_ADMIN_REQUEST] = None,
+ ):
+ """Register callbacks from module for each hook."""
+ if is_user_expired is not None:
+ self._is_user_expired_callbacks.append(is_user_expired)
+
+ if on_user_registration is not None:
+ self._on_user_registration_callbacks.append(on_user_registration)
+
+ # The builtin account validity feature exposes 3 endpoints (send_mail, renew, and
+ # an admin one). As part of moving the feature into a module, we need to change
+ # the path from /_matrix/client/unstable/account_validity/... to
+ # /_synapse/client/account_validity, because:
+ #
+ # * the feature isn't part of the Matrix spec thus shouldn't live under /_matrix
+ # * the way we register servlets means that modules can't register resources
+ # under /_matrix/client
+ #
+ # We need to allow for a transition period between the old and new endpoints
+ # in order to allow for clients to update (and for emails to be processed).
+ #
+ # Once the email-account-validity module is loaded, it will take control of account
+ # validity by moving the rows from our `account_validity` table into its own table.
+ #
+ # Therefore, we need to allow modules (in practice just the one implementing the
+ # email-based account validity) to temporarily hook into the legacy endpoints so we
+ # can route the traffic coming into the old endpoints into the module, which is
+ # why we have the following three temporary hooks.
+ if on_legacy_send_mail is not None:
+ if self._on_legacy_send_mail_callback is not None:
+ raise RuntimeError("Tried to register on_legacy_send_mail twice")
+
+ self._on_legacy_send_mail_callback = on_legacy_send_mail
+
+ if on_legacy_renew is not None:
+ if self._on_legacy_renew_callback is not None:
+ raise RuntimeError("Tried to register on_legacy_renew twice")
+
+ self._on_legacy_renew_callback = on_legacy_renew
+
+ if on_legacy_admin_request is not None:
+ if self.on_legacy_admin_request_callback is not None:
+ raise RuntimeError("Tried to register on_legacy_admin_request twice")
+
+ self.on_legacy_admin_request_callback = on_legacy_admin_request
+
+ async def is_user_expired(self, user_id: str) -> bool:
+ """Checks if a user has expired against third-party modules.
+
+ Args:
+ user_id: The user to check the expiry of.
+
+ Returns:
+ Whether the user has expired.
+ """
+ for callback in self._is_user_expired_callbacks:
+ expired = await callback(user_id)
+ if expired is not None:
+ return expired
+
+ if self._account_validity_enabled:
+ # If no module could determine whether the user has expired and the legacy
+ # configuration is enabled, fall back to it.
+ return await self.store.is_account_expired(user_id, self.clock.time_msec())
+
+ return False
+
+ async def on_user_registration(self, user_id: str):
+ """Tell third-party modules about a user's registration.
+
+ Args:
+ user_id: The ID of the newly registered user.
+ """
+ for callback in self._on_user_registration_callbacks:
+ await callback(user_id)
+
@wrap_as_background_process("send_renewals")
async def _send_renewal_emails(self) -> None:
"""Gets the list of users whose account is expiring in the amount of time
@@ -95,6 +199,17 @@ class AccountValidityHandler:
Raises:
SynapseError if the user is not set to renew.
"""
+ # If a module supports sending a renewal email from here, do that, otherwise do
+ # the legacy dance.
+ if self._on_legacy_send_mail_callback is not None:
+ await self._on_legacy_send_mail_callback(user_id)
+ return
+
+ if not self._account_validity_renew_by_email_enabled:
+ raise AuthError(
+ 403, "Account renewal via email is disabled on this server."
+ )
+
expiration_ts = await self.store.get_expiration_ts_for_user(user_id)
# If this user isn't set to be expired, raise an error.
@@ -209,6 +324,10 @@ class AccountValidityHandler:
token is considered stale. A token is stale if the 'token_used_ts_ms' db column
is non-null.
+ This method exists to support handling the legacy account validity /renew
+ endpoint. If a module implements the on_legacy_renew callback, then this process
+ is delegated to the module instead.
+
Args:
renewal_token: Token sent with the renewal request.
Returns:
@@ -218,6 +337,11 @@ class AccountValidityHandler:
* An int representing the user's expiry timestamp as milliseconds since the
epoch, or 0 if the token was invalid.
"""
+ # If a module supports triggering a renew from here, do that, otherwise do the
+ # legacy dance.
+ if self._on_legacy_renew_callback is not None:
+ return await self._on_legacy_renew_callback(renewal_token)
+
try:
(
user_id,
diff --git a/synapse/handlers/admin.py b/synapse/handlers/admin.py
index d75a8b15..bfa7f2c5 100644
--- a/synapse/handlers/admin.py
+++ b/synapse/handlers/admin.py
@@ -139,7 +139,7 @@ class AdminHandler(BaseHandler):
to_key = RoomStreamToken(None, stream_ordering)
# Events that we've processed in this room
- written_events = set() # type: Set[str]
+ written_events: Set[str] = set()
# We need to track gaps in the events stream so that we can then
# write out the state at those events. We do this by keeping track
@@ -152,7 +152,7 @@ class AdminHandler(BaseHandler):
# The reverse mapping to above, i.e. map from unseen event to events
# that have the unseen event in their prev_events, i.e. the unseen
# events "children".
- unseen_to_child_events = {} # type: Dict[str, Set[str]]
+ unseen_to_child_events: Dict[str, Set[str]] = {}
# We fetch events in the room the user could see by fetching *all*
# events that we have and then filtering, this isn't the most
diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py
index 862638cc..21a17cd2 100644
--- a/synapse/handlers/appservice.py
+++ b/synapse/handlers/appservice.py
@@ -96,7 +96,7 @@ class ApplicationServicesHandler:
self.current_max, limit
)
- events_by_room = {} # type: Dict[str, List[EventBase]]
+ events_by_room: Dict[str, List[EventBase]] = {}
for event in events:
events_by_room.setdefault(event.room_id, []).append(event)
@@ -275,7 +275,7 @@ class ApplicationServicesHandler:
async def _handle_presence(
self, service: ApplicationService, users: Collection[Union[str, UserID]]
) -> List[JsonDict]:
- events = [] # type: List[JsonDict]
+ events: List[JsonDict] = []
presence_source = self.event_sources.sources["presence"]
from_key = await self.store.get_type_stream_id_for_appservice(
service, "presence"
@@ -375,7 +375,7 @@ class ApplicationServicesHandler:
self, only_protocol: Optional[str] = None
) -> Dict[str, JsonDict]:
services = self.store.get_app_services()
- protocols = {} # type: Dict[str, List[JsonDict]]
+ protocols: Dict[str, List[JsonDict]] = {}
# Collect up all the individual protocol responses out of the ASes
for s in services:
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index e2ac595a..22a85522 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -191,7 +191,7 @@ class AuthHandler(BaseHandler):
def __init__(self, hs: "HomeServer"):
super().__init__(hs)
- self.checkers = {} # type: Dict[str, UserInteractiveAuthChecker]
+ self.checkers: Dict[str, UserInteractiveAuthChecker] = {}
for auth_checker_class in INTERACTIVE_AUTH_CHECKERS:
inst = auth_checker_class(hs)
if inst.is_enabled():
@@ -296,7 +296,7 @@ class AuthHandler(BaseHandler):
# A mapping of user ID to extra attributes to include in the login
# response.
- self._extra_attributes = {} # type: Dict[str, SsoLoginExtraAttributes]
+ self._extra_attributes: Dict[str, SsoLoginExtraAttributes] = {}
async def validate_user_via_ui_auth(
self,
@@ -500,7 +500,7 @@ class AuthHandler(BaseHandler):
all the stages in any of the permitted flows.
"""
- sid = None # type: Optional[str]
+ sid: Optional[str] = None
authdict = clientdict.pop("auth", {})
if "session" in authdict:
sid = authdict["session"]
@@ -588,9 +588,9 @@ class AuthHandler(BaseHandler):
)
# check auth type currently being presented
- errordict = {} # type: Dict[str, Any]
+ errordict: Dict[str, Any] = {}
if "type" in authdict:
- login_type = authdict["type"] # type: str
+ login_type: str = authdict["type"]
try:
result = await self._check_auth_dict(authdict, clientip)
if result:
@@ -766,7 +766,7 @@ class AuthHandler(BaseHandler):
LoginType.TERMS: self._get_params_terms,
}
- params = {} # type: Dict[str, Any]
+ params: Dict[str, Any] = {}
for f in public_flows:
for stage in f:
@@ -1530,9 +1530,9 @@ class AuthHandler(BaseHandler):
except StoreError:
raise SynapseError(400, "Unknown session ID: %s" % (session_id,))
- user_id_to_verify = await self.get_session_data(
+ user_id_to_verify: str = await self.get_session_data(
session_id, UIAuthSessionDataConstants.REQUEST_USER_ID
- ) # type: str
+ )
idps = await self.hs.get_sso_handler().get_identity_providers_for_user(
user_id_to_verify
diff --git a/synapse/handlers/cas.py b/synapse/handlers/cas.py
index 7346ccfe..0325f86e 100644
--- a/synapse/handlers/cas.py
+++ b/synapse/handlers/cas.py
@@ -40,7 +40,7 @@ class CasError(Exception):
def __str__(self):
if self.error_description:
- return "{}: {}".format(self.error, self.error_description)
+ return f"{self.error}: {self.error_description}"
return self.error
@@ -171,7 +171,7 @@ class CasHandler:
# Iterate through the nodes and pull out the user and any extra attributes.
user = None
- attributes = {} # type: Dict[str, List[Optional[str]]]
+ attributes: Dict[str, List[Optional[str]]] = {}
for child in root[0]:
if child.tag.endswith("user"):
user = child.text
diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py
index 95bdc590..46ee8344 100644
--- a/synapse/handlers/device.py
+++ b/synapse/handlers/device.py
@@ -452,7 +452,7 @@ class DeviceHandler(DeviceWorkerHandler):
user_id
)
- hosts = set() # type: Set[str]
+ hosts: Set[str] = set()
if self.hs.is_mine_id(user_id):
hosts.update(get_domain_from_id(u) for u in users_who_share_room)
hosts.discard(self.server_name)
@@ -613,20 +613,20 @@ class DeviceListUpdater:
self._remote_edu_linearizer = Linearizer(name="remote_device_list")
# user_id -> list of updates waiting to be handled.
- self._pending_updates = (
- {}
- ) # type: Dict[str, List[Tuple[str, str, Iterable[str], JsonDict]]]
+ self._pending_updates: Dict[
+ str, List[Tuple[str, str, Iterable[str], JsonDict]]
+ ] = {}
# Recently seen stream ids. We don't bother keeping these in the DB,
# but they're useful to have them about to reduce the number of spurious
# resyncs.
- self._seen_updates = ExpiringCache(
+ self._seen_updates: ExpiringCache[str, Set[str]] = ExpiringCache(
cache_name="device_update_edu",
clock=self.clock,
max_len=10000,
expiry_ms=30 * 60 * 1000,
iterable=True,
- ) # type: ExpiringCache[str, Set[str]]
+ )
# Attempt to resync out of sync device lists every 30s.
self._resync_retry_in_progress = False
@@ -755,7 +755,7 @@ class DeviceListUpdater:
"""Given a list of updates for a user figure out if we need to do a full
resync, or whether we have enough data that we can just apply the delta.
"""
- seen_updates = self._seen_updates.get(user_id, set()) # type: Set[str]
+ seen_updates: Set[str] = self._seen_updates.get(user_id, set())
extremity = await self.store.get_device_list_last_stream_id_for_remote(user_id)
diff --git a/synapse/handlers/devicemessage.py b/synapse/handlers/devicemessage.py
index 580b9415..679b47f0 100644
--- a/synapse/handlers/devicemessage.py
+++ b/synapse/handlers/devicemessage.py
@@ -203,7 +203,7 @@ class DeviceMessageHandler:
log_kv({"number_of_to_device_messages": len(messages)})
set_tag("sender", sender_user_id)
local_messages = {}
- remote_messages = {} # type: Dict[str, Dict[str, Dict[str, JsonDict]]]
+ remote_messages: Dict[str, Dict[str, Dict[str, JsonDict]]] = {}
for user_id, by_device in messages.items():
# Ratelimit local cross-user key requests by the sending device.
if (
diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py
index 4064a2b8..d487fee6 100644
--- a/synapse/handlers/directory.py
+++ b/synapse/handlers/directory.py
@@ -22,6 +22,7 @@ from synapse.api.errors import (
CodeMessageException,
Codes,
NotFoundError,
+ RequestSendFailed,
ShadowBanError,
StoreError,
SynapseError,
@@ -236,9 +237,9 @@ class DirectoryHandler(BaseHandler):
async def get_association(self, room_alias: RoomAlias) -> JsonDict:
room_id = None
if self.hs.is_mine(room_alias):
- result = await self.get_association_from_room_alias(
- room_alias
- ) # type: Optional[RoomAliasMapping]
+ result: Optional[
+ RoomAliasMapping
+ ] = await self.get_association_from_room_alias(room_alias)
if result:
room_id = result.room_id
@@ -252,12 +253,14 @@ class DirectoryHandler(BaseHandler):
retry_on_dns_fail=False,
ignore_backoff=True,
)
+ except RequestSendFailed:
+ raise SynapseError(502, "Failed to fetch alias")
except CodeMessageException as e:
logging.warning("Error retrieving alias")
if e.code == 404:
fed_result = None
else:
- raise
+ raise SynapseError(502, "Failed to fetch alias")
if fed_result and "room_id" in fed_result and "servers" in fed_result:
room_id = fed_result["room_id"]
diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py
index 3972849d..d9237085 100644
--- a/synapse/handlers/e2e_keys.py
+++ b/synapse/handlers/e2e_keys.py
@@ -115,9 +115,9 @@ class E2eKeysHandler:
the number of in-flight queries at a time.
"""
with await self._query_devices_linearizer.queue((from_user_id, from_device_id)):
- device_keys_query = query_body.get(
+ device_keys_query: Dict[str, Iterable[str]] = query_body.get(
"device_keys", {}
- ) # type: Dict[str, Iterable[str]]
+ )
# separate users by domain.
# make a map from domain to user_id to device_ids
@@ -136,7 +136,7 @@ class E2eKeysHandler:
# First get local devices.
# A map of destination -> failure response.
- failures = {} # type: Dict[str, JsonDict]
+ failures: Dict[str, JsonDict] = {}
results = {}
if local_query:
local_result = await self.query_local_devices(local_query)
@@ -151,11 +151,9 @@ class E2eKeysHandler:
# Now attempt to get any remote devices from our local cache.
# A map of destination -> user ID -> device IDs.
- remote_queries_not_in_cache = (
- {}
- ) # type: Dict[str, Dict[str, Iterable[str]]]
+ remote_queries_not_in_cache: Dict[str, Dict[str, Iterable[str]]] = {}
if remote_queries:
- query_list = [] # type: List[Tuple[str, Optional[str]]]
+ query_list: List[Tuple[str, Optional[str]]] = []
for user_id, device_ids in remote_queries.items():
if device_ids:
query_list.extend(
@@ -362,9 +360,9 @@ class E2eKeysHandler:
A map from user_id -> device_id -> device details
"""
set_tag("local_query", query)
- local_query = [] # type: List[Tuple[str, Optional[str]]]
+ local_query: List[Tuple[str, Optional[str]]] = []
- result_dict = {} # type: Dict[str, Dict[str, dict]]
+ result_dict: Dict[str, Dict[str, dict]] = {}
for user_id, device_ids in query.items():
# we use UserID.from_string to catch invalid user ids
if not self.is_mine(UserID.from_string(user_id)):
@@ -402,9 +400,9 @@ class E2eKeysHandler:
self, query_body: Dict[str, Dict[str, Optional[List[str]]]]
) -> JsonDict:
"""Handle a device key query from a federated server"""
- device_keys_query = query_body.get(
+ device_keys_query: Dict[str, Optional[List[str]]] = query_body.get(
"device_keys", {}
- ) # type: Dict[str, Optional[List[str]]]
+ )
res = await self.query_local_devices(device_keys_query)
ret = {"device_keys": res}
@@ -421,8 +419,8 @@ class E2eKeysHandler:
async def claim_one_time_keys(
self, query: Dict[str, Dict[str, Dict[str, str]]], timeout: int
) -> JsonDict:
- local_query = [] # type: List[Tuple[str, str, str]]
- remote_queries = {} # type: Dict[str, Dict[str, Dict[str, str]]]
+ local_query: List[Tuple[str, str, str]] = []
+ remote_queries: Dict[str, Dict[str, Dict[str, str]]] = {}
for user_id, one_time_keys in query.get("one_time_keys", {}).items():
# we use UserID.from_string to catch invalid user ids
@@ -439,8 +437,8 @@ class E2eKeysHandler:
results = await self.store.claim_e2e_one_time_keys(local_query)
# A map of user ID -> device ID -> key ID -> key.
- json_result = {} # type: Dict[str, Dict[str, Dict[str, JsonDict]]]
- failures = {} # type: Dict[str, JsonDict]
+ json_result: Dict[str, Dict[str, Dict[str, JsonDict]]] = {}
+ failures: Dict[str, JsonDict] = {}
for user_id, device_keys in results.items():
for device_id, keys in device_keys.items():
for key_id, json_str in keys.items():
@@ -768,8 +766,8 @@ class E2eKeysHandler:
Raises:
SynapseError: if the input is malformed
"""
- signature_list = [] # type: List[SignatureListItem]
- failures = {} # type: Dict[str, Dict[str, JsonDict]]
+ signature_list: List["SignatureListItem"] = []
+ failures: Dict[str, Dict[str, JsonDict]] = {}
if not signatures:
return signature_list, failures
@@ -930,8 +928,8 @@ class E2eKeysHandler:
Raises:
SynapseError: if the input is malformed
"""
- signature_list = [] # type: List[SignatureListItem]
- failures = {} # type: Dict[str, Dict[str, JsonDict]]
+ signature_list: List["SignatureListItem"] = []
+ failures: Dict[str, Dict[str, JsonDict]] = {}
if not signatures:
return signature_list, failures
@@ -1300,7 +1298,7 @@ class SigningKeyEduUpdater:
self._remote_edu_linearizer = Linearizer(name="remote_signing_key")
# user_id -> list of updates waiting to be handled.
- self._pending_updates = {} # type: Dict[str, List[Tuple[JsonDict, JsonDict]]]
+ self._pending_updates: Dict[str, List[Tuple[JsonDict, JsonDict]]] = {}
async def incoming_signing_key_update(
self, origin: str, edu_content: JsonDict
@@ -1349,7 +1347,7 @@ class SigningKeyEduUpdater:
# This can happen since we batch updates
return
- device_ids = [] # type: List[str]
+ device_ids: List[str] = []
logger.info("pending updates: %r", pending_updates)
diff --git a/synapse/handlers/events.py b/synapse/handlers/events.py
index f134f1e2..4b3f0370 100644
--- a/synapse/handlers/events.py
+++ b/synapse/handlers/events.py
@@ -93,7 +93,7 @@ class EventStreamHandler(BaseHandler):
# When the user joins a new room, or another user joins a currently
# joined room, we need to send down presence for those users.
- to_add = [] # type: List[JsonDict]
+ to_add: List[JsonDict] = []
for event in events:
if not isinstance(event, EventBase):
continue
@@ -103,9 +103,9 @@ class EventStreamHandler(BaseHandler):
# Send down presence.
if event.state_key == auth_user_id:
# Send down presence for everyone in the room.
- users = await self.store.get_users_in_room(
+ users: Iterable[str] = await self.store.get_users_in_room(
event.room_id
- ) # type: Iterable[str]
+ )
else:
users = [event.state_key]
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 991ec991..57287199 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -181,7 +181,7 @@ class FederationHandler(BaseHandler):
# When joining a room we need to queue any events for that room up.
# For each room, a list of (pdu, origin) tuples.
- self.room_queues = {} # type: Dict[str, List[Tuple[EventBase, str]]]
+ self.room_queues: Dict[str, List[Tuple[EventBase, str]]] = {}
self._room_pdu_linearizer = Linearizer("fed_room_pdu")
self._room_backfill = Linearizer("room_backfill")
@@ -368,7 +368,7 @@ class FederationHandler(BaseHandler):
ours = await self.state_store.get_state_groups_ids(room_id, seen)
# state_maps is a list of mappings from (type, state_key) to event_id
- state_maps = list(ours.values()) # type: List[StateMap[str]]
+ state_maps: List[StateMap[str]] = list(ours.values())
# we don't need this any more, let's delete it.
del ours
@@ -735,7 +735,7 @@ class FederationHandler(BaseHandler):
# we need to make sure we re-load from the database to get the rejected
# state correct.
fetched_events.update(
- (await self.store.get_events(missing_desired_events, allow_rejected=True))
+ await self.store.get_events(missing_desired_events, allow_rejected=True)
)
# check for events which were in the wrong room.
@@ -845,7 +845,7 @@ class FederationHandler(BaseHandler):
# exact key to expect. Otherwise check it matches any key we
# have for that device.
- current_keys = [] # type: Container[str]
+ current_keys: Container[str] = []
if device:
keys = device.get("keys", {}).get("keys", {})
@@ -1185,7 +1185,7 @@ class FederationHandler(BaseHandler):
if e_type == EventTypes.Member and event.membership == Membership.JOIN
]
- joined_domains = {} # type: Dict[str, int]
+ joined_domains: Dict[str, int] = {}
for u, d in joined_users:
try:
dom = get_domain_from_id(u)
@@ -1314,7 +1314,7 @@ class FederationHandler(BaseHandler):
room_version = await self.store.get_room_version(room_id)
- event_map = {} # type: Dict[str, EventBase]
+ event_map: Dict[str, EventBase] = {}
async def get_event(event_id: str):
with nested_logging_context(event_id):
@@ -1414,12 +1414,15 @@ class FederationHandler(BaseHandler):
Invites must be signed by the invitee's server before distribution.
"""
- pdu = await self.federation_client.send_invite(
- destination=target_host,
- room_id=event.room_id,
- event_id=event.event_id,
- pdu=event,
- )
+ try:
+ pdu = await self.federation_client.send_invite(
+ destination=target_host,
+ room_id=event.room_id,
+ event_id=event.event_id,
+ pdu=event,
+ )
+ except RequestSendFailed:
+ raise SynapseError(502, f"Can't connect to server {target_host}")
return pdu
@@ -1593,7 +1596,7 @@ class FederationHandler(BaseHandler):
# Ask the remote server to create a valid knock event for us. Once received,
# we sign the event
- params = {"ver": supported_room_versions} # type: Dict[str, Iterable[str]]
+ params: Dict[str, Iterable[str]] = {"ver": supported_room_versions}
origin, event, event_format_version = await self._make_and_verify_event(
target_hosts, room_id, knockee, Membership.KNOCK, content, params=params
)
@@ -1931,7 +1934,7 @@ class FederationHandler(BaseHandler):
builder=builder
)
- event_allowed = await self.third_party_event_rules.check_event_allowed(
+ event_allowed, _ = await self.third_party_event_rules.check_event_allowed(
event, context
)
if not event_allowed:
@@ -2023,7 +2026,7 @@ class FederationHandler(BaseHandler):
# for knock events, we run the third-party event rules. It's not entirely clear
# why we don't do this for other sorts of membership events.
if event.membership == Membership.KNOCK:
- event_allowed = await self.third_party_event_rules.check_event_allowed(
+ event_allowed, _ = await self.third_party_event_rules.check_event_allowed(
event, context
)
if not event_allowed:
@@ -2450,14 +2453,14 @@ class FederationHandler(BaseHandler):
state_sets_d = await self.state_store.get_state_groups(
event.room_id, extrem_ids
)
- state_sets = list(state_sets_d.values()) # type: List[Iterable[EventBase]]
+ state_sets: List[Iterable[EventBase]] = list(state_sets_d.values())
state_sets.append(state)
current_states = await self.state_handler.resolve_events(
room_version, state_sets, event
)
- current_state_ids = {
+ current_state_ids: StateMap[str] = {
k: e.event_id for k, e in current_states.items()
- } # type: StateMap[str]
+ }
else:
current_state_ids = await self.state_handler.get_current_state_ids(
event.room_id, latest_event_ids=extrem_ids
@@ -2814,7 +2817,7 @@ class FederationHandler(BaseHandler):
"""
# exclude the state key of the new event from the current_state in the context.
if event.is_state():
- event_key = (event.type, event.state_key) # type: Optional[Tuple[str, str]]
+ event_key: Optional[Tuple[str, str]] = (event.type, event.state_key)
else:
event_key = None
state_updates = {
@@ -3031,9 +3034,13 @@ class FederationHandler(BaseHandler):
await member_handler.send_membership_event(None, event, context)
else:
destinations = {x.split(":", 1)[-1] for x in (sender_user_id, room_id)}
- await self.federation_client.forward_third_party_invite(
- destinations, room_id, event_dict
- )
+
+ try:
+ await self.federation_client.forward_third_party_invite(
+ destinations, room_id, event_dict
+ )
+ except (RequestSendFailed, HttpResponseException):
+ raise SynapseError(502, "Failed to forward third party invite")
async def on_exchange_third_party_invite_request(
self, event_dict: JsonDict
@@ -3149,7 +3156,7 @@ class FederationHandler(BaseHandler):
logger.debug("Checking auth on event %r", event.content)
- last_exception = None # type: Optional[Exception]
+ last_exception: Optional[Exception] = None
# for each public key in the 3pid invite event
for public_key_object in event_auth.get_public_keys(invite_event):
diff --git a/synapse/handlers/groups_local.py b/synapse/handlers/groups_local.py
index 157f2ff2..1a6c5c64 100644
--- a/synapse/handlers/groups_local.py
+++ b/synapse/handlers/groups_local.py
@@ -214,7 +214,7 @@ class GroupsLocalWorkerHandler:
async def bulk_get_publicised_groups(
self, user_ids: Iterable[str], proxy: bool = True
) -> JsonDict:
- destinations = {} # type: Dict[str, Set[str]]
+ destinations: Dict[str, Set[str]] = {}
local_users = set()
for user_id in user_ids:
@@ -227,7 +227,7 @@ class GroupsLocalWorkerHandler:
raise SynapseError(400, "Some user_ids are not local")
results = {}
- failed_results = [] # type: List[str]
+ failed_results: List[str] = []
for destination, dest_user_ids in destinations.items():
try:
r = await self.transport_client.bulk_get_publicised_groups(
diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py
index 33d16fbf..0961dec5 100644
--- a/synapse/handlers/identity.py
+++ b/synapse/handlers/identity.py
@@ -302,7 +302,7 @@ class IdentityHandler(BaseHandler):
)
url = "https://%s/_matrix/identity/api/v1/3pid/unbind" % (id_server,)
- url_bytes = "/_matrix/identity/api/v1/3pid/unbind".encode("ascii")
+ url_bytes = b"/_matrix/identity/api/v1/3pid/unbind"
content = {
"mxid": mxid,
@@ -695,7 +695,7 @@ class IdentityHandler(BaseHandler):
return data["mxid"]
except RequestTimedOutError:
raise SynapseError(500, "Timed out contacting identity server")
- except IOError as e:
+ except OSError as e:
logger.warning("Error from v1 identity server lookup: %s" % (e,))
return None
diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py
index 76242865..5d496407 100644
--- a/synapse/handlers/initial_sync.py
+++ b/synapse/handlers/initial_sync.py
@@ -46,9 +46,17 @@ class InitialSyncHandler(BaseHandler):
self.state = hs.get_state_handler()
self.clock = hs.get_clock()
self.validator = EventValidator()
- self.snapshot_cache = ResponseCache(
- hs.get_clock(), "initial_sync_cache"
- ) # type: ResponseCache[Tuple[str, Optional[StreamToken], Optional[StreamToken], str, Optional[int], bool, bool]]
+ self.snapshot_cache: ResponseCache[
+ Tuple[
+ str,
+ Optional[StreamToken],
+ Optional[StreamToken],
+ str,
+ Optional[int],
+ bool,
+ bool,
+ ]
+ ] = ResponseCache(hs.get_clock(), "initial_sync_cache")
self._event_serializer = hs.get_event_client_serializer()
self.storage = hs.get_storage()
self.state_store = self.storage.state
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index 66e40a91..8a0024ce 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -81,7 +81,7 @@ class MessageHandler:
# The scheduled call to self._expire_event. None if no call is currently
# scheduled.
- self._scheduled_expiry = None # type: Optional[IDelayedCall]
+ self._scheduled_expiry: Optional[IDelayedCall] = None
if not hs.config.worker_app:
run_as_background_process(
@@ -196,9 +196,7 @@ class MessageHandler:
room_state_events = await self.state_store.get_state_for_events(
[event.event_id], state_filter=state_filter
)
- room_state = room_state_events[
- event.event_id
- ] # type: Mapping[Any, EventBase]
+ room_state: Mapping[Any, EventBase] = room_state_events[event.event_id]
else:
raise AuthError(
403,
@@ -421,9 +419,9 @@ class EventCreationHandler:
self.action_generator = hs.get_action_generator()
self.spam_checker = hs.get_spam_checker()
- self.third_party_event_rules = (
+ self.third_party_event_rules: "ThirdPartyEventRules" = (
self.hs.get_third_party_event_rules()
- ) # type: ThirdPartyEventRules
+ )
self._block_events_without_consent_error = (
self.config.block_events_without_consent_error
@@ -440,7 +438,7 @@ class EventCreationHandler:
#
# map from room id to time-of-last-attempt.
#
- self._rooms_to_exclude_from_dummy_event_insertion = {} # type: Dict[str, int]
+ self._rooms_to_exclude_from_dummy_event_insertion: Dict[str, int] = {}
# The number of forward extremeities before a dummy event is sent.
self._dummy_events_threshold = hs.config.dummy_events_threshold
@@ -465,9 +463,7 @@ class EventCreationHandler:
# Stores the state groups we've recently added to the joined hosts
# external cache. Note that the timeout must be significantly less than
# the TTL on the external cache.
- self._external_cache_joined_hosts_updates = (
- None
- ) # type: Optional[ExpiringCache]
+ self._external_cache_joined_hosts_updates: Optional[ExpiringCache] = None
if self._external_cache.is_enabled():
self._external_cache_joined_hosts_updates = ExpiringCache(
"_external_cache_joined_hosts_updates",
@@ -518,6 +514,9 @@ class EventCreationHandler:
outlier: Indicates whether the event is an `outlier`, i.e. if
it's from an arbitrary point and floating in the DAG as
opposed to being inline with the current DAG.
+ historical: Indicates whether the message is being inserted
+ back in time around some existing events. This is used to skip
+ a few checks and mark the event as backfilled.
depth: Override the depth used to order the event in the DAG.
Should normally be set to None, which will cause the depth to be calculated
based on the prev_events.
@@ -772,6 +771,7 @@ class EventCreationHandler:
txn_id: Optional[str] = None,
ignore_shadow_ban: bool = False,
outlier: bool = False,
+ historical: bool = False,
depth: Optional[int] = None,
) -> Tuple[EventBase, int]:
"""
@@ -799,6 +799,9 @@ class EventCreationHandler:
outlier: Indicates whether the event is an `outlier`, i.e. if
it's from an arbitrary point and floating in the DAG as
opposed to being inline with the current DAG.
+ historical: Indicates whether the message is being inserted
+ back in time around some existing events. This is used to skip
+ a few checks and mark the event as backfilled.
depth: Override the depth used to order the event in the DAG.
Should normally be set to None, which will cause the depth to be calculated
based on the prev_events.
@@ -847,6 +850,7 @@ class EventCreationHandler:
prev_event_ids=prev_event_ids,
auth_event_ids=auth_event_ids,
outlier=outlier,
+ historical=historical,
depth=depth,
)
@@ -945,10 +949,10 @@ class EventCreationHandler:
if requester:
context.app_service = requester.app_service
- third_party_result = await self.third_party_event_rules.check_event_allowed(
+ res, new_content = await self.third_party_event_rules.check_event_allowed(
event, context
)
- if not third_party_result:
+ if res is False:
logger.info(
"Event %s forbidden by third-party rules",
event,
@@ -956,11 +960,11 @@ class EventCreationHandler:
raise SynapseError(
403, "This event is not allowed in this context", Codes.FORBIDDEN
)
- elif isinstance(third_party_result, dict):
+ elif new_content is not None:
# the third-party rules want to replace the event. We'll need to build a new
# event.
event, context = await self._rebuild_event_after_third_party_rules(
- third_party_result, event
+ new_content, event
)
self.validator.validate_new(event, self.config)
@@ -1291,7 +1295,7 @@ class EventCreationHandler:
# Validate a newly added alias or newly added alt_aliases.
original_alias = None
- original_alt_aliases = [] # type: List[str]
+ original_alt_aliases: List[str] = []
original_event_id = event.unsigned.get("replaces_state")
if original_event_id:
@@ -1594,11 +1598,13 @@ class EventCreationHandler:
for k, v in original_event.internal_metadata.get_dict().items():
setattr(builder.internal_metadata, k, v)
- # the event type hasn't changed, so there's no point in re-calculating the
- # auth events.
+ # modules can send new state events, so we re-calculate the auth events just in
+ # case.
+ prev_event_ids = await self.store.get_prev_events_for_room(builder.room_id)
+
event = await builder.build(
- prev_event_ids=original_event.prev_event_ids(),
- auth_event_ids=original_event.auth_event_ids(),
+ prev_event_ids=prev_event_ids,
+ auth_event_ids=None,
)
# we rebuild the event context, to be on the safe side. If nothing else,
diff --git a/synapse/handlers/oidc.py b/synapse/handlers/oidc.py
index ee6e41c0..eca8f160 100644
--- a/synapse/handlers/oidc.py
+++ b/synapse/handlers/oidc.py
@@ -72,26 +72,26 @@ _SESSION_COOKIES = [
(b"oidc_session_no_samesite", b"HttpOnly"),
]
+
#: A token exchanged from the token endpoint, as per RFC6749 sec 5.1. and
#: OpenID.Core sec 3.1.3.3.
-Token = TypedDict(
- "Token",
- {
- "access_token": str,
- "token_type": str,
- "id_token": Optional[str],
- "refresh_token": Optional[str],
- "expires_in": int,
- "scope": Optional[str],
- },
-)
+class Token(TypedDict):
+ access_token: str
+ token_type: str
+ id_token: Optional[str]
+ refresh_token: Optional[str]
+ expires_in: int
+ scope: Optional[str]
+
#: A JWK, as per RFC7517 sec 4. The type could be more precise than that, but
#: there is no real point of doing this in our case.
JWK = Dict[str, str]
+
#: A JWK Set, as per RFC7517 sec 5.
-JWKS = TypedDict("JWKS", {"keys": List[JWK]})
+class JWKS(TypedDict):
+ keys: List[JWK]
class OidcHandler:
@@ -105,9 +105,9 @@ class OidcHandler:
assert provider_confs
self._token_generator = OidcSessionTokenGenerator(hs)
- self._providers = {
+ self._providers: Dict[str, "OidcProvider"] = {
p.idp_id: OidcProvider(hs, self._token_generator, p) for p in provider_confs
- } # type: Dict[str, OidcProvider]
+ }
async def load_metadata(self) -> None:
"""Validate the config and load the metadata from the remote endpoint.
@@ -178,7 +178,7 @@ class OidcHandler:
# are two.
for cookie_name, _ in _SESSION_COOKIES:
- session = request.getCookie(cookie_name) # type: Optional[bytes]
+ session: Optional[bytes] = request.getCookie(cookie_name)
if session is not None:
break
else:
@@ -255,7 +255,7 @@ class OidcError(Exception):
def __str__(self):
if self.error_description:
- return "{}: {}".format(self.error, self.error_description)
+ return f"{self.error}: {self.error_description}"
return self.error
@@ -277,7 +277,7 @@ class OidcProvider:
self._token_generator = token_generator
self._config = provider
- self._callback_url = hs.config.oidc_callback_url # type: str
+ self._callback_url: str = hs.config.oidc_callback_url
# Calculate the prefix for OIDC callback paths based on the public_baseurl.
# We'll insert this into the Path= parameter of any session cookies we set.
@@ -290,7 +290,7 @@ class OidcProvider:
self._scopes = provider.scopes
self._user_profile_method = provider.user_profile_method
- client_secret = None # type: Union[None, str, JwtClientSecret]
+ client_secret: Optional[Union[str, JwtClientSecret]] = None
if provider.client_secret:
client_secret = provider.client_secret
elif provider.client_secret_jwt_key:
@@ -305,7 +305,7 @@ class OidcProvider:
provider.client_id,
client_secret,
provider.client_auth_method,
- ) # type: ClientAuth
+ )
self._client_auth_method = provider.client_auth_method
# cache of metadata for the identity provider (endpoint uris, mostly). This is
@@ -324,7 +324,7 @@ class OidcProvider:
self._allow_existing_users = provider.allow_existing_users
self._http_client = hs.get_proxied_http_client()
- self._server_name = hs.config.server_name # type: str
+ self._server_name: str = hs.config.server_name
# identifier for the external_ids table
self.idp_id = provider.idp_id
@@ -639,7 +639,7 @@ class OidcProvider:
)
logger.warning(description)
# Body was still valid JSON. Might be useful to log it for debugging.
- logger.warning("Code exchange response: {resp!r}".format(resp=resp))
+ logger.warning("Code exchange response: %r", resp)
raise OidcError("server_error", description)
return resp
@@ -1217,10 +1217,12 @@ class OidcSessionData:
ui_auth_session_id = attr.ib(type=str)
-UserAttributeDict = TypedDict(
- "UserAttributeDict",
- {"localpart": Optional[str], "display_name": Optional[str], "emails": List[str]},
-)
+class UserAttributeDict(TypedDict):
+ localpart: Optional[str]
+ display_name: Optional[str]
+ emails: List[str]
+
+
C = TypeVar("C")
@@ -1381,7 +1383,7 @@ class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]):
if display_name == "":
display_name = None
- emails = [] # type: List[str]
+ emails: List[str] = []
email = render_template_field(self._config.email_template)
if email:
emails.append(email)
@@ -1391,7 +1393,7 @@ class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]):
)
async def get_extra_attributes(self, userinfo: UserInfo, token: Token) -> JsonDict:
- extras = {} # type: Dict[str, str]
+ extras: Dict[str, str] = {}
for key, template in self._config.extra_attributes.items():
try:
extras[key] = template.render(user=userinfo).strip()
diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py
index 1e1186c2..1dbafd25 100644
--- a/synapse/handlers/pagination.py
+++ b/synapse/handlers/pagination.py
@@ -81,9 +81,9 @@ class PaginationHandler:
self._server_name = hs.hostname
self.pagination_lock = ReadWriteLock()
- self._purges_in_progress_by_room = set() # type: Set[str]
+ self._purges_in_progress_by_room: Set[str] = set()
# map from purge id to PurgeStatus
- self._purges_by_id = {} # type: Dict[str, PurgeStatus]
+ self._purges_by_id: Dict[str, PurgeStatus] = {}
self._event_serializer = hs.get_event_client_serializer()
self._retention_default_max_lifetime = hs.config.retention_default_max_lifetime
diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py
index 44ed7a07..016c5df2 100644
--- a/synapse/handlers/presence.py
+++ b/synapse/handlers/presence.py
@@ -378,14 +378,14 @@ class WorkerPresenceHandler(BasePresenceHandler):
# The number of ongoing syncs on this process, by user id.
# Empty if _presence_enabled is false.
- self._user_to_num_current_syncs = {} # type: Dict[str, int]
+ self._user_to_num_current_syncs: Dict[str, int] = {}
self.notifier = hs.get_notifier()
self.instance_id = hs.get_instance_id()
# user_id -> last_sync_ms. Lists the users that have stopped syncing but
# we haven't notified the presence writer of that yet
- self.users_going_offline = {} # type: Dict[str, int]
+ self.users_going_offline: Dict[str, int] = {}
self._bump_active_client = ReplicationBumpPresenceActiveTime.make_client(hs)
self._set_state_client = ReplicationPresenceSetState.make_client(hs)
@@ -650,7 +650,7 @@ class PresenceHandler(BasePresenceHandler):
# Set of users who have presence in the `user_to_current_state` that
# have not yet been persisted
- self.unpersisted_users_changes = set() # type: Set[str]
+ self.unpersisted_users_changes: Set[str] = set()
hs.get_reactor().addSystemEventTrigger(
"before",
@@ -664,7 +664,7 @@ class PresenceHandler(BasePresenceHandler):
# Keeps track of the number of *ongoing* syncs on this process. While
# this is non zero a user will never go offline.
- self.user_to_num_current_syncs = {} # type: Dict[str, int]
+ self.user_to_num_current_syncs: Dict[str, int] = {}
# Keeps track of the number of *ongoing* syncs on other processes.
# While any sync is ongoing on another process the user will never
@@ -674,8 +674,8 @@ class PresenceHandler(BasePresenceHandler):
# we assume that all the sync requests on that process have stopped.
# Stored as a dict from process_id to set of user_id, and a dict of
# process_id to millisecond timestamp last updated.
- self.external_process_to_current_syncs = {} # type: Dict[str, Set[str]]
- self.external_process_last_updated_ms = {} # type: Dict[str, int]
+ self.external_process_to_current_syncs: Dict[str, Set[str]] = {}
+ self.external_process_last_updated_ms: Dict[str, int] = {}
self.external_sync_linearizer = Linearizer(name="external_sync_linearizer")
@@ -1581,9 +1581,7 @@ class PresenceEventSource:
# The set of users that we're interested in and that have had a presence update.
# We'll actually pull the presence updates for these users at the end.
- interested_and_updated_users = (
- set()
- ) # type: Union[Set[str], FrozenSet[str]]
+ interested_and_updated_users: Union[Set[str], FrozenSet[str]] = set()
if from_key:
# First get all users that have had a presence update
@@ -1950,8 +1948,8 @@ async def get_interested_parties(
A 2-tuple of `(room_ids_to_states, users_to_states)`,
with each item being a dict of `entity_name` -> `[UserPresenceState]`
"""
- room_ids_to_states = {} # type: Dict[str, List[UserPresenceState]]
- users_to_states = {} # type: Dict[str, List[UserPresenceState]]
+ room_ids_to_states: Dict[str, List[UserPresenceState]] = {}
+ users_to_states: Dict[str, List[UserPresenceState]] = {}
for state in states:
room_ids = await store.get_rooms_for_user(state.user_id)
for room_id in room_ids:
@@ -2063,12 +2061,12 @@ class PresenceFederationQueue:
# stream_id, destinations, user_ids)`. We don't store the full states
# for efficiency, and remote workers will already have the full states
# cached.
- self._queue = [] # type: List[Tuple[int, int, Collection[str], Set[str]]]
+ self._queue: List[Tuple[int, int, Collection[str], Set[str]]] = []
self._next_id = 1
# Map from instance name to current token
- self._current_tokens = {} # type: Dict[str, int]
+ self._current_tokens: Dict[str, int] = {}
if self._queue_presence_updates:
self._clock.looping_call(self._clear_queue, self._CLEAR_ITEMS_EVERY_MS)
@@ -2168,7 +2166,7 @@ class PresenceFederationQueue:
# handle the case where `from_token` stream ID has already been dropped.
start_idx = max(from_token + 1 - self._next_id, -len(self._queue))
- to_send = [] # type: List[Tuple[int, Tuple[str, str]]]
+ to_send: List[Tuple[int, Tuple[str, str]]] = []
limited = False
new_id = upto_token
for _, stream_id, destinations, user_ids in self._queue[start_idx:]:
@@ -2216,7 +2214,7 @@ class PresenceFederationQueue:
if not self._federation:
return
- hosts_to_users = {} # type: Dict[str, Set[str]]
+ hosts_to_users: Dict[str, Set[str]] = {}
for row in rows:
hosts_to_users.setdefault(row.destination, set()).add(row.user_id)
diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py
index 05b4a97b..20a033d0 100644
--- a/synapse/handlers/profile.py
+++ b/synapse/handlers/profile.py
@@ -197,7 +197,7 @@ class ProfileHandler(BaseHandler):
400, "Displayname is too long (max %i)" % (MAX_DISPLAYNAME_LEN,)
)
- displayname_to_set = new_displayname # type: Optional[str]
+ displayname_to_set: Optional[str] = new_displayname
if new_displayname == "":
displayname_to_set = None
@@ -286,7 +286,7 @@ class ProfileHandler(BaseHandler):
400, "Avatar URL is too long (max %i)" % (MAX_AVATAR_URL_LEN,)
)
- avatar_url_to_set = new_avatar_url # type: Optional[str]
+ avatar_url_to_set: Optional[str] = new_avatar_url
if new_avatar_url == "":
avatar_url_to_set = None
diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py
index f782d9db..283483fc 100644
--- a/synapse/handlers/receipts.py
+++ b/synapse/handlers/receipts.py
@@ -30,6 +30,8 @@ class ReceiptsHandler(BaseHandler):
self.server_name = hs.config.server_name
self.store = hs.get_datastore()
+ self.event_auth_handler = hs.get_event_auth_handler()
+
self.hs = hs
# We only need to poke the federation sender explicitly if its on the
@@ -59,6 +61,19 @@ class ReceiptsHandler(BaseHandler):
"""Called when we receive an EDU of type m.receipt from a remote HS."""
receipts = []
for room_id, room_values in content.items():
+ # If we're not in the room just ditch the event entirely. This is
+ # probably an old server that has come back and thinks we're still in
+ # the room (or we've been rejoined to the room by a state reset).
+ is_in_room = await self.event_auth_handler.check_host_in_room(
+ room_id, self.server_name
+ )
+ if not is_in_room:
+ logger.info(
+ "Ignoring receipt from %s as we're not in the room",
+ origin,
+ )
+ continue
+
for receipt_type, users in room_values.items():
for user_id, user_values in users.items():
if get_domain_from_id(user_id) != origin:
@@ -83,8 +98,8 @@ class ReceiptsHandler(BaseHandler):
async def _handle_new_receipts(self, receipts: List[ReadReceipt]) -> bool:
"""Takes a list of receipts, stores them and informs the notifier."""
- min_batch_id = None # type: Optional[int]
- max_batch_id = None # type: Optional[int]
+ min_batch_id: Optional[int] = None
+ max_batch_id: Optional[int] = None
for receipt in receipts:
res = await self.store.insert_receipt(
diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py
index 26ef0161..8cf61413 100644
--- a/synapse/handlers/register.py
+++ b/synapse/handlers/register.py
@@ -55,15 +55,12 @@ login_counter = Counter(
["guest", "auth_provider"],
)
-LoginDict = TypedDict(
- "LoginDict",
- {
- "device_id": str,
- "access_token": str,
- "valid_until_ms": Optional[int],
- "refresh_token": Optional[str],
- },
-)
+
+class LoginDict(TypedDict):
+ device_id: str
+ access_token: str
+ valid_until_ms: Optional[int]
+ refresh_token: Optional[str]
class RegistrationHandler(BaseHandler):
@@ -77,6 +74,7 @@ class RegistrationHandler(BaseHandler):
self.identity_handler = self.hs.get_identity_handler()
self.ratelimiter = hs.get_registration_ratelimiter()
self.macaroon_gen = hs.get_macaroon_generator()
+ self._account_validity_handler = hs.get_account_validity_handler()
self._server_notices_mxid = hs.config.server_notices_mxid
self._server_name = hs.hostname
@@ -700,6 +698,10 @@ class RegistrationHandler(BaseHandler):
shadow_banned=shadow_banned,
)
+ # Only call the account validity module(s) on the main process, to avoid
+ # repeating e.g. database writes on all of the workers.
+ await self._account_validity_handler.on_user_registration(user_id)
+
async def register_device(
self,
user_id: str,
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index 579b1b93..370561e5 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -87,7 +87,7 @@ class RoomCreationHandler(BaseHandler):
self.config = hs.config
# Room state based off defined presets
- self._presets_dict = {
+ self._presets_dict: Dict[str, Dict[str, Any]] = {
RoomCreationPreset.PRIVATE_CHAT: {
"join_rules": JoinRules.INVITE,
"history_visibility": HistoryVisibility.SHARED,
@@ -109,7 +109,7 @@ class RoomCreationHandler(BaseHandler):
"guest_can_join": False,
"power_level_content_override": {},
},
- } # type: Dict[str, Dict[str, Any]]
+ }
# Modify presets to selectively enable encryption by default per homeserver config
for preset_name, preset_config in self._presets_dict.items():
@@ -127,9 +127,9 @@ class RoomCreationHandler(BaseHandler):
# If a user tries to update the same room multiple times in quick
# succession, only process the first attempt and return its result to
# subsequent requests
- self._upgrade_response_cache = ResponseCache(
+ self._upgrade_response_cache: ResponseCache[Tuple[str, str]] = ResponseCache(
hs.get_clock(), "room_upgrade", timeout_ms=FIVE_MINUTES_IN_MS
- ) # type: ResponseCache[Tuple[str, str]]
+ )
self._server_notices_mxid = hs.config.server_notices_mxid
self.third_party_event_rules = hs.get_third_party_event_rules()
@@ -377,10 +377,10 @@ class RoomCreationHandler(BaseHandler):
if not await self.spam_checker.user_may_create_room(user_id):
raise SynapseError(403, "You are not permitted to create rooms")
- creation_content = {
+ creation_content: JsonDict = {
"room_version": new_room_version.identifier,
"predecessor": {"room_id": old_room_id, "event_id": tombstone_event_id},
- } # type: JsonDict
+ }
# Check if old room was non-federatable
@@ -618,15 +618,11 @@ class RoomCreationHandler(BaseHandler):
else:
is_requester_admin = await self.auth.is_server_admin(requester.user)
- # Check whether the third party rules allows/changes the room create
- # request.
- event_allowed = await self.third_party_event_rules.on_create_room(
+ # Let the third party rules modify the room creation config if needed, or abort
+ # the room creation entirely with an exception.
+ await self.third_party_event_rules.on_create_room(
requester, config, is_requester_admin=is_requester_admin
)
- if not event_allowed:
- raise SynapseError(
- 403, "You are not permitted to create rooms", Codes.FORBIDDEN
- )
if not is_requester_admin and not await self.spam_checker.user_may_create_room(
user_id
@@ -936,7 +932,7 @@ class RoomCreationHandler(BaseHandler):
etype=EventTypes.PowerLevels, content=pl_content
)
else:
- power_level_content = {
+ power_level_content: JsonDict = {
"users": {creator_id: 100},
"users_default": 0,
"events": {
@@ -955,7 +951,7 @@ class RoomCreationHandler(BaseHandler):
"kick": 50,
"redact": 50,
"invite": 50,
- } # type: JsonDict
+ }
if config["original_invitees_have_ops"]:
for invitee in invite_list:
diff --git a/synapse/handlers/room_list.py b/synapse/handlers/room_list.py
index 5e3ef7ce..fae2c098 100644
--- a/synapse/handlers/room_list.py
+++ b/synapse/handlers/room_list.py
@@ -20,7 +20,12 @@ import msgpack
from unpaddedbase64 import decode_base64, encode_base64
from synapse.api.constants import EventTypes, HistoryVisibility, JoinRules
-from synapse.api.errors import Codes, HttpResponseException
+from synapse.api.errors import (
+ Codes,
+ HttpResponseException,
+ RequestSendFailed,
+ SynapseError,
+)
from synapse.types import JsonDict, ThirdPartyInstanceID
from synapse.util.caches.descriptors import cached
from synapse.util.caches.response_cache import ResponseCache
@@ -42,12 +47,12 @@ class RoomListHandler(BaseHandler):
def __init__(self, hs: "HomeServer"):
super().__init__(hs)
self.enable_room_list_search = hs.config.enable_room_list_search
- self.response_cache = ResponseCache(
- hs.get_clock(), "room_list"
- ) # type: ResponseCache[Tuple[Optional[int], Optional[str], Optional[ThirdPartyInstanceID]]]
- self.remote_response_cache = ResponseCache(
- hs.get_clock(), "remote_room_list", timeout_ms=30 * 1000
- ) # type: ResponseCache[Tuple[str, Optional[int], Optional[str], bool, Optional[str]]]
+ self.response_cache: ResponseCache[
+ Tuple[Optional[int], Optional[str], Optional[ThirdPartyInstanceID]]
+ ] = ResponseCache(hs.get_clock(), "room_list")
+ self.remote_response_cache: ResponseCache[
+ Tuple[str, Optional[int], Optional[str], bool, Optional[str]]
+ ] = ResponseCache(hs.get_clock(), "remote_room_list", timeout_ms=30 * 1000)
async def get_local_public_room_list(
self,
@@ -134,10 +139,10 @@ class RoomListHandler(BaseHandler):
if since_token:
batch_token = RoomListNextBatch.from_token(since_token)
- bounds = (
+ bounds: Optional[Tuple[int, str]] = (
batch_token.last_joined_members,
batch_token.last_room_id,
- ) # type: Optional[Tuple[int, str]]
+ )
forwards = batch_token.direction_is_forward
has_batch_token = True
else:
@@ -177,7 +182,7 @@ class RoomListHandler(BaseHandler):
results = [build_room_entry(r) for r in results]
- response = {} # type: JsonDict
+ response: JsonDict = {}
num_results = len(results)
if limit is not None:
more_to_come = num_results == probing_limit
@@ -378,7 +383,11 @@ class RoomListHandler(BaseHandler):
):
logger.debug("Falling back to locally-filtered /publicRooms")
else:
- raise # Not an error that should trigger a fallback.
+ # Not an error that should trigger a fallback.
+ raise SynapseError(502, "Failed to fetch room list")
+ except RequestSendFailed:
+ # Not an error that should trigger a fallback.
+ raise SynapseError(502, "Failed to fetch room list")
# if we reach this point, then we fall back to the situation where
# we currently don't support searching across federation, so we have
@@ -417,14 +426,17 @@ class RoomListHandler(BaseHandler):
repl_layer = self.hs.get_federation_client()
if search_filter:
# We can't cache when asking for search
- return await repl_layer.get_public_rooms(
- server_name,
- limit=limit,
- since_token=since_token,
- search_filter=search_filter,
- include_all_networks=include_all_networks,
- third_party_instance_id=third_party_instance_id,
- )
+ try:
+ return await repl_layer.get_public_rooms(
+ server_name,
+ limit=limit,
+ since_token=since_token,
+ search_filter=search_filter,
+ include_all_networks=include_all_networks,
+ third_party_instance_id=third_party_instance_id,
+ )
+ except (RequestSendFailed, HttpResponseException):
+ raise SynapseError(502, "Failed to fetch room list")
key = (
server_name,
diff --git a/synapse/handlers/saml.py b/synapse/handlers/saml.py
index 80ba65b9..e6e71e97 100644
--- a/synapse/handlers/saml.py
+++ b/synapse/handlers/saml.py
@@ -83,7 +83,7 @@ class SamlHandler(BaseHandler):
self.unstable_idp_brand = None
# a map from saml session id to Saml2SessionData object
- self._outstanding_requests_dict = {} # type: Dict[str, Saml2SessionData]
+ self._outstanding_requests_dict: Dict[str, Saml2SessionData] = {}
self._sso_handler = hs.get_sso_handler()
self._sso_handler.register_identity_provider(self)
@@ -372,7 +372,7 @@ class SamlHandler(BaseHandler):
DOT_REPLACE_PATTERN = re.compile(
- ("[^%s]" % (re.escape("".join(mxid_localpart_allowed_characters)),))
+ "[^%s]" % (re.escape("".join(mxid_localpart_allowed_characters)),)
)
@@ -386,10 +386,10 @@ def dot_replace_for_mxid(username: str) -> str:
return username
-MXID_MAPPER_MAP = {
+MXID_MAPPER_MAP: Dict[str, Callable[[str], str]] = {
"hexencode": map_username_to_mxid_localpart,
"dotreplace": dot_replace_for_mxid,
-} # type: Dict[str, Callable[[str], str]]
+}
@attr.s
diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py
index 4e718d3f..8226d6f5 100644
--- a/synapse/handlers/search.py
+++ b/synapse/handlers/search.py
@@ -192,7 +192,7 @@ class SearchHandler(BaseHandler):
# If doing a subset of all rooms seearch, check if any of the rooms
# are from an upgraded room, and search their contents as well
if search_filter.rooms:
- historical_room_ids = [] # type: List[str]
+ historical_room_ids: List[str] = []
for room_id in search_filter.rooms:
# Add any previous rooms to the search if they exist
ids = await self.get_old_rooms_from_upgraded_room(room_id)
@@ -216,9 +216,9 @@ class SearchHandler(BaseHandler):
rank_map = {} # event_id -> rank of event
allowed_events = []
# Holds result of grouping by room, if applicable
- room_groups = {} # type: Dict[str, JsonDict]
+ room_groups: Dict[str, JsonDict] = {}
# Holds result of grouping by sender, if applicable
- sender_group = {} # type: Dict[str, JsonDict]
+ sender_group: Dict[str, JsonDict] = {}
# Holds the next_batch for the entire result set if one of those exists
global_next_batch = None
@@ -262,7 +262,7 @@ class SearchHandler(BaseHandler):
s["results"].append(e.event_id)
elif order_by == "recent":
- room_events = [] # type: List[EventBase]
+ room_events: List[EventBase] = []
i = 0
pagination_token = batch_token
diff --git a/synapse/handlers/space_summary.py b/synapse/handlers/space_summary.py
index b585057e..5f7d4602 100644
--- a/synapse/handlers/space_summary.py
+++ b/synapse/handlers/space_summary.py
@@ -24,6 +24,7 @@ from synapse.api.constants import (
EventContentFields,
EventTypes,
HistoryVisibility,
+ JoinRules,
Membership,
RoomTypes,
)
@@ -89,14 +90,14 @@ class SpaceSummaryHandler:
room_queue = deque((_RoomQueueEntry(room_id, ()),))
# rooms we have already processed
- processed_rooms = set() # type: Set[str]
+ processed_rooms: Set[str] = set()
# events we have already processed. We don't necessarily have their event ids,
# so instead we key on (room id, state key)
- processed_events = set() # type: Set[Tuple[str, str]]
+ processed_events: Set[Tuple[str, str]] = set()
- rooms_result = [] # type: List[JsonDict]
- events_result = [] # type: List[JsonDict]
+ rooms_result: List[JsonDict] = []
+ events_result: List[JsonDict] = []
while room_queue and len(rooms_result) < MAX_ROOMS:
queue_entry = room_queue.popleft()
@@ -150,14 +151,21 @@ class SpaceSummaryHandler:
# The room should only be included in the summary if:
# a. the user is in the room;
# b. the room is world readable; or
- # c. the user is in a space that has been granted access to
- # the room.
+ # c. the user could join the room, e.g. the join rules
+ # are set to public or the user is in a space that
+ # has been granted access to the room.
#
# Note that we know the user is not in the root room (which is
# why the remote call was made in the first place), but the user
# could be in one of the children rooms and we just didn't know
# about the link.
- include_room = room.get("world_readable") is True
+
+ # The API doesn't return the room version so assume that a
+ # join rule of knock is valid.
+ include_room = (
+ room.get("join_rules") in (JoinRules.PUBLIC, JoinRules.KNOCK)
+ or room.get("world_readable") is True
+ )
# Check if the user is a member of any of the allowed spaces
# from the response.
@@ -264,10 +272,10 @@ class SpaceSummaryHandler:
# the set of rooms that we should not walk further. Initialise it with the
# excluded-rooms list; we will add other rooms as we process them so that
# we do not loop.
- processed_rooms = set(exclude_rooms) # type: Set[str]
+ processed_rooms: Set[str] = set(exclude_rooms)
- rooms_result = [] # type: List[JsonDict]
- events_result = [] # type: List[JsonDict]
+ rooms_result: List[JsonDict] = []
+ events_result: List[JsonDict] = []
while room_queue and len(rooms_result) < MAX_ROOMS:
room_id = room_queue.popleft()
@@ -345,7 +353,7 @@ class SpaceSummaryHandler:
max_children = MAX_ROOMS_PER_SPACE
now = self._clock.time_msec()
- events_result = [] # type: List[JsonDict]
+ events_result: List[JsonDict] = []
for edge_event in itertools.islice(child_events, max_children):
events_result.append(
await self._event_serializer.serialize_event(
@@ -420,9 +428,8 @@ class SpaceSummaryHandler:
It should be included if:
- * The requester is joined or invited to the room.
- * The requester can join without an invite (per MSC3083).
- * The origin server has any user that is joined or invited to the room.
+ * The requester is joined or can join the room (per MSC3173).
+ * The origin server has any user that is joined or can join the room.
* The history visibility is set to world readable.
Args:
@@ -441,13 +448,39 @@ class SpaceSummaryHandler:
# If there's no state for the room, it isn't known.
if not state_ids:
+ # The user might have a pending invite for the room.
+ if requester and await self._store.get_invite_for_local_user_in_room(
+ requester, room_id
+ ):
+ return True
+
logger.info("room %s is unknown, omitting from summary", room_id)
return False
room_version = await self._store.get_room_version(room_id)
- # if we have an authenticated requesting user, first check if they are able to view
- # stripped state in the room.
+ # Include the room if it has join rules of public or knock.
+ join_rules_event_id = state_ids.get((EventTypes.JoinRules, ""))
+ if join_rules_event_id:
+ join_rules_event = await self._store.get_event(join_rules_event_id)
+ join_rule = join_rules_event.content.get("join_rule")
+ if join_rule == JoinRules.PUBLIC or (
+ room_version.msc2403_knocking and join_rule == JoinRules.KNOCK
+ ):
+ return True
+
+ # Include the room if it is peekable.
+ hist_vis_event_id = state_ids.get((EventTypes.RoomHistoryVisibility, ""))
+ if hist_vis_event_id:
+ hist_vis_ev = await self._store.get_event(hist_vis_event_id)
+ hist_vis = hist_vis_ev.content.get("history_visibility")
+ if hist_vis == HistoryVisibility.WORLD_READABLE:
+ return True
+
+ # Otherwise we need to check information specific to the user or server.
+
+ # If we have an authenticated requesting user, check if they are a member
+ # of the room (or can join the room).
if requester:
member_event_id = state_ids.get((EventTypes.Member, requester), None)
@@ -470,9 +503,11 @@ class SpaceSummaryHandler:
return True
# If this is a request over federation, check if the host is in the room or
- # is in one of the spaces specified via the join rules.
+ # has a user who could join the room.
elif origin:
- if await self._event_auth_handler.check_host_in_room(room_id, origin):
+ if await self._event_auth_handler.check_host_in_room(
+ room_id, origin
+ ) or await self._store.is_host_invited(room_id, origin):
return True
# Alternately, if the host has a user in any of the spaces specified
@@ -490,18 +525,10 @@ class SpaceSummaryHandler:
):
return True
- # otherwise, check if the room is peekable
- hist_vis_event_id = state_ids.get((EventTypes.RoomHistoryVisibility, ""), None)
- if hist_vis_event_id:
- hist_vis_ev = await self._store.get_event(hist_vis_event_id)
- hist_vis = hist_vis_ev.content.get("history_visibility")
- if hist_vis == HistoryVisibility.WORLD_READABLE:
- return True
-
logger.info(
- "room %s is unpeekable and user %s is not a member / not allowed to join, omitting from summary",
+ "room %s is unpeekable and requester %s is not a member / not allowed to join, omitting from summary",
room_id,
- requester,
+ requester or origin,
)
return False
@@ -535,6 +562,7 @@ class SpaceSummaryHandler:
"canonical_alias": stats["canonical_alias"],
"num_joined_members": stats["joined_members"],
"avatar_url": stats["avatar"],
+ "join_rules": stats["join_rules"],
"world_readable": (
stats["history_visibility"] == HistoryVisibility.WORLD_READABLE
),
diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py
index 0b297e54..1b855a68 100644
--- a/synapse/handlers/sso.py
+++ b/synapse/handlers/sso.py
@@ -202,10 +202,10 @@ class SsoHandler:
self._mapping_lock = Linearizer(name="sso_user_mapping", clock=hs.get_clock())
# a map from session id to session data
- self._username_mapping_sessions = {} # type: Dict[str, UsernameMappingSession]
+ self._username_mapping_sessions: Dict[str, UsernameMappingSession] = {}
# map from idp_id to SsoIdentityProvider
- self._identity_providers = {} # type: Dict[str, SsoIdentityProvider]
+ self._identity_providers: Dict[str, SsoIdentityProvider] = {}
self._consent_at_registration = hs.config.consent.user_consent_at_registration
@@ -296,7 +296,7 @@ class SsoHandler:
)
# if the client chose an IdP, use that
- idp = None # type: Optional[SsoIdentityProvider]
+ idp: Optional[SsoIdentityProvider] = None
if idp_id:
idp = self._identity_providers.get(idp_id)
if not idp:
@@ -669,9 +669,9 @@ class SsoHandler:
remote_user_id,
)
- user_id_to_verify = await self._auth_handler.get_session_data(
+ user_id_to_verify: str = await self._auth_handler.get_session_data(
ui_auth_session_id, UIAuthSessionDataConstants.REQUEST_USER_ID
- ) # type: str
+ )
if not user_id:
logger.warning(
@@ -793,7 +793,7 @@ class SsoHandler:
session.use_display_name = use_display_name
emails_from_idp = set(session.emails)
- filtered_emails = set() # type: Set[str]
+ filtered_emails: Set[str] = set()
# we iterate through the list rather than just building a set conjunction, so
# that we can log attempts to use unknown addresses
diff --git a/synapse/handlers/stats.py b/synapse/handlers/stats.py
index 4e45d1da..3fd89af2 100644
--- a/synapse/handlers/stats.py
+++ b/synapse/handlers/stats.py
@@ -45,12 +45,11 @@ class StatsHandler:
self.clock = hs.get_clock()
self.notifier = hs.get_notifier()
self.is_mine_id = hs.is_mine_id
- self.stats_bucket_size = hs.config.stats_bucket_size
self.stats_enabled = hs.config.stats_enabled
# The current position in the current_state_delta stream
- self.pos = None # type: Optional[int]
+ self.pos: Optional[int] = None
# Guard to ensure we only process deltas one at a time
self._is_processing = False
@@ -106,20 +105,6 @@ class StatsHandler:
room_deltas = {}
user_deltas = {}
- # Then count deltas for total_events and total_event_bytes.
- (
- room_count,
- user_count,
- ) = await self.store.get_changes_room_total_events_and_bytes(
- self.pos, max_pos
- )
-
- for room_id, fields in room_count.items():
- room_deltas.setdefault(room_id, Counter()).update(fields)
-
- for user_id, fields in user_count.items():
- user_deltas.setdefault(user_id, Counter()).update(fields)
-
logger.debug("room_deltas: %s", room_deltas)
logger.debug("user_deltas: %s", user_deltas)
@@ -146,10 +131,10 @@ class StatsHandler:
mapping from room/user ID to changes in the various fields.
"""
- room_to_stats_deltas = {} # type: Dict[str, CounterType[str]]
- user_to_stats_deltas = {} # type: Dict[str, CounterType[str]]
+ room_to_stats_deltas: Dict[str, CounterType[str]] = {}
+ user_to_stats_deltas: Dict[str, CounterType[str]] = {}
- room_to_state_updates = {} # type: Dict[str, Dict[str, Any]]
+ room_to_state_updates: Dict[str, Dict[str, Any]] = {}
for delta in deltas:
typ = delta["type"]
@@ -179,14 +164,12 @@ class StatsHandler:
)
continue
- event_content = {} # type: JsonDict
+ event_content: JsonDict = {}
- sender = None
if event_id is not None:
event = await self.store.get_event(event_id, allow_none=True)
if event:
event_content = event.content or {}
- sender = event.sender
# All the values in this dict are deltas (RELATIVE changes)
room_stats_delta = room_to_stats_deltas.setdefault(room_id, Counter())
@@ -244,12 +227,6 @@ class StatsHandler:
room_stats_delta["joined_members"] += 1
elif membership == Membership.INVITE:
room_stats_delta["invited_members"] += 1
-
- if sender and self.is_mine_id(sender):
- user_to_stats_deltas.setdefault(sender, Counter())[
- "invites_sent"
- ] += 1
-
elif membership == Membership.LEAVE:
room_stats_delta["left_members"] += 1
elif membership == Membership.BAN:
@@ -279,10 +256,6 @@ class StatsHandler:
room_state["is_federatable"] = (
event_content.get("m.federate", True) is True
)
- if sender and self.is_mine_id(sender):
- user_to_stats_deltas.setdefault(sender, Counter())[
- "rooms_created"
- ] += 1
elif typ == EventTypes.JoinRules:
room_state["join_rules"] = event_content.get("join_rule")
elif typ == EventTypes.RoomHistoryVisibility:
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index b9a03610..f30bfcc9 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -278,12 +278,14 @@ class SyncHandler:
self.state_store = self.storage.state
# ExpiringCache((User, Device)) -> LruCache(user_id => event_id)
- self.lazy_loaded_members_cache = ExpiringCache(
+ self.lazy_loaded_members_cache: ExpiringCache[
+ Tuple[str, Optional[str]], LruCache[str, str]
+ ] = ExpiringCache(
"lazy_loaded_members_cache",
self.clock,
max_len=0,
expiry_ms=LAZY_LOADED_MEMBERS_CACHE_MAX_AGE,
- ) # type: ExpiringCache[Tuple[str, Optional[str]], LruCache[str, str]]
+ )
async def wait_for_sync_for_user(
self,
@@ -440,7 +442,7 @@ class SyncHandler:
)
now_token = now_token.copy_and_replace("typing_key", typing_key)
- ephemeral_by_room = {} # type: JsonDict
+ ephemeral_by_room: JsonDict = {}
for event in typing:
# we want to exclude the room_id from the event, but modifying the
@@ -502,7 +504,7 @@ class SyncHandler:
# We check if there are any state events, if there are then we pass
# all current state events to the filter_events function. This is to
# ensure that we always include current state in the timeline
- current_state_ids = frozenset() # type: FrozenSet[str]
+ current_state_ids: FrozenSet[str] = frozenset()
if any(e.is_state() for e in recents):
current_state_ids_map = await self.store.get_current_state_ids(
room_id
@@ -783,9 +785,9 @@ class SyncHandler:
def get_lazy_loaded_members_cache(
self, cache_key: Tuple[str, Optional[str]]
) -> LruCache[str, str]:
- cache = self.lazy_loaded_members_cache.get(
+ cache: Optional[LruCache[str, str]] = self.lazy_loaded_members_cache.get(
cache_key
- ) # type: Optional[LruCache[str, str]]
+ )
if cache is None:
logger.debug("creating LruCache for %r", cache_key)
cache = LruCache(LAZY_LOADED_MEMBERS_CACHE_MAX_SIZE)
@@ -984,7 +986,7 @@ class SyncHandler:
if t[0] == EventTypes.Member:
cache.set(t[1], event_id)
- state = {} # type: Dict[str, EventBase]
+ state: Dict[str, EventBase] = {}
if state_ids:
state = await self.store.get_events(list(state_ids.values()))
@@ -1088,9 +1090,13 @@ class SyncHandler:
logger.debug("Fetching OTK data")
device_id = sync_config.device_id
- one_time_key_counts = {} # type: JsonDict
- unused_fallback_key_types = [] # type: List[str]
+ one_time_key_counts: JsonDict = {}
+ unused_fallback_key_types: List[str] = []
if device_id:
+ # TODO: We should have a way to let clients differentiate between the states of:
+ # * no change in OTK count since the provided since token
+ # * the server has zero OTKs left for this device
+ # Spec issue: https://github.com/matrix-org/matrix-doc/issues/3298
one_time_key_counts = await self.store.count_e2e_one_time_keys(
user_id, device_id
)
@@ -1437,7 +1443,7 @@ class SyncHandler:
)
if block_all_room_ephemeral:
- ephemeral_by_room = {} # type: Dict[str, List[JsonDict]]
+ ephemeral_by_room: Dict[str, List[JsonDict]] = {}
else:
now_token, ephemeral_by_room = await self.ephemeral_by_room(
sync_result_builder,
@@ -1468,7 +1474,7 @@ class SyncHandler:
# If there is ignored users account data and it matches the proper type,
# then use it.
- ignored_users = frozenset() # type: FrozenSet[str]
+ ignored_users: FrozenSet[str] = frozenset()
if ignored_account_data:
ignored_users_data = ignored_account_data.get("ignored_users", {})
if isinstance(ignored_users_data, dict):
@@ -1586,7 +1592,7 @@ class SyncHandler:
user_id, since_token.room_key, now_token.room_key
)
- mem_change_events_by_room_id = {} # type: Dict[str, List[EventBase]]
+ mem_change_events_by_room_id: Dict[str, List[EventBase]] = {}
for event in rooms_changed:
mem_change_events_by_room_id.setdefault(event.room_id, []).append(event)
@@ -1599,7 +1605,7 @@ class SyncHandler:
logger.debug(
"Membership changes in %s: [%s]",
room_id,
- ", ".join(("%s (%s)" % (e.event_id, e.membership) for e in events)),
+ ", ".join("%s (%s)" % (e.event_id, e.membership) for e in events),
)
non_joins = [e for e in events if e.membership != Membership.JOIN]
@@ -1722,7 +1728,7 @@ class SyncHandler:
# This is all screaming out for a refactor, as the logic here is
# subtle and the moving parts numerous.
if leave_event.internal_metadata.is_out_of_band_membership():
- batch_events = [leave_event] # type: Optional[List[EventBase]]
+ batch_events: Optional[List[EventBase]] = [leave_event]
else:
batch_events = None
@@ -1971,7 +1977,7 @@ class SyncHandler:
room_id, batch, sync_config, since_token, now_token, full_state=full_state
)
- summary = {} # type: Optional[JsonDict]
+ summary: Optional[JsonDict] = {}
# we include a summary in room responses when we're lazy loading
# members (as the client otherwise doesn't have enough info to form
@@ -1995,7 +2001,7 @@ class SyncHandler:
)
if room_builder.rtype == "joined":
- unread_notifications = {} # type: Dict[str, int]
+ unread_notifications: Dict[str, int] = {}
room_sync = JoinedSyncResult(
room_id=room_id,
timeline=batch,
diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py
index e22393ad..0cb651a4 100644
--- a/synapse/handlers/typing.py
+++ b/synapse/handlers/typing.py
@@ -68,11 +68,11 @@ class FollowerTypingHandler:
)
# map room IDs to serial numbers
- self._room_serials = {} # type: Dict[str, int]
+ self._room_serials: Dict[str, int] = {}
# map room IDs to sets of users currently typing
- self._room_typing = {} # type: Dict[str, Set[str]]
+ self._room_typing: Dict[str, Set[str]] = {}
- self._member_last_federation_poke = {} # type: Dict[RoomMember, int]
+ self._member_last_federation_poke: Dict[RoomMember, int] = {}
self.wheel_timer = WheelTimer(bucket_size=5000)
self._latest_room_serial = 0
@@ -208,6 +208,7 @@ class TypingWriterHandler(FollowerTypingHandler):
self.auth = hs.get_auth()
self.notifier = hs.get_notifier()
+ self.event_auth_handler = hs.get_event_auth_handler()
self.hs = hs
@@ -216,7 +217,7 @@ class TypingWriterHandler(FollowerTypingHandler):
hs.get_distributor().observe("user_left_room", self.user_left_room)
# clock time we expect to stop
- self._member_typing_until = {} # type: Dict[RoomMember, int]
+ self._member_typing_until: Dict[RoomMember, int] = {}
# caches which room_ids changed at which serials
self._typing_stream_change_cache = StreamChangeCache(
@@ -326,6 +327,19 @@ class TypingWriterHandler(FollowerTypingHandler):
room_id = content["room_id"]
user_id = content["user_id"]
+ # If we're not in the room just ditch the event entirely. This is
+ # probably an old server that has come back and thinks we're still in
+ # the room (or we've been rejoined to the room by a state reset).
+ is_in_room = await self.event_auth_handler.check_host_in_room(
+ room_id, self.server_name
+ )
+ if not is_in_room:
+ logger.info(
+ "Ignoring typing update from %s as we're not in the room",
+ origin,
+ )
+ return
+
member = RoomMember(user_id=user_id, room_id=room_id)
# Check that the string is a valid user id
@@ -391,9 +405,9 @@ class TypingWriterHandler(FollowerTypingHandler):
if last_id == current_id:
return [], current_id, False
- changed_rooms = self._typing_stream_change_cache.get_all_entities_changed(
- last_id
- ) # type: Optional[Iterable[str]]
+ changed_rooms: Optional[
+ Iterable[str]
+ ] = self._typing_stream_change_cache.get_all_entities_changed(last_id)
if changed_rooms is None:
changed_rooms = self._room_serials
diff --git a/synapse/handlers/user_directory.py b/synapse/handlers/user_directory.py
index dacc4f30..6edb1da5 100644
--- a/synapse/handlers/user_directory.py
+++ b/synapse/handlers/user_directory.py
@@ -52,7 +52,7 @@ class UserDirectoryHandler(StateDeltasHandler):
self.search_all_users = hs.config.user_directory_search_all_users
self.spam_checker = hs.get_spam_checker()
# The current position in the current_state_delta stream
- self.pos = None # type: Optional[int]
+ self.pos: Optional[int] = None
# Guard to ensure we only process deltas one at a time
self._is_processing = False