summaryrefslogtreecommitdiff
path: root/synapse/handlers
diff options
context:
space:
mode:
authorAndrej Shadura <andrewsh@debian.org>2020-09-23 10:00:31 +0200
committerAndrej Shadura <andrewsh@debian.org>2020-09-23 10:00:31 +0200
commitc4b994b356a5af29bdf1f8648dd7d929a237acbd (patch)
tree6163eda9a9a5329caea1cf5b4ad42b974cc1ae3b /synapse/handlers
parenta8007890d174f089f2ce28aae9d919df346f74f9 (diff)
New upstream version 1.20.0
Diffstat (limited to 'synapse/handlers')
-rw-r--r--synapse/handlers/__init__.py2
-rw-r--r--synapse/handlers/_base.py2
-rw-r--r--synapse/handlers/account_data.py2
-rw-r--r--synapse/handlers/account_validity.py22
-rw-r--r--synapse/handlers/acme.py2
-rw-r--r--synapse/handlers/acme_issuing_service.py2
-rw-r--r--synapse/handlers/admin.py2
-rw-r--r--synapse/handlers/appservice.py2
-rw-r--r--synapse/handlers/auth.py110
-rw-r--r--synapse/handlers/cas_handler.py11
-rw-r--r--synapse/handlers/device.py6
-rw-r--r--synapse/handlers/devicemessage.py7
-rw-r--r--synapse/handlers/directory.py6
-rw-r--r--synapse/handlers/e2e_keys.py16
-rw-r--r--synapse/handlers/e2e_room_keys.py2
-rw-r--r--synapse/handlers/events.py49
-rw-r--r--synapse/handlers/federation.py121
-rw-r--r--synapse/handlers/groups_local.py2
-rw-r--r--synapse/handlers/identity.py5
-rw-r--r--synapse/handlers/initial_sync.py80
-rw-r--r--synapse/handlers/message.py126
-rw-r--r--synapse/handlers/oidc_handler.py35
-rw-r--r--synapse/handlers/pagination.py145
-rw-r--r--synapse/handlers/password_policy.py2
-rw-r--r--synapse/handlers/presence.py10
-rw-r--r--synapse/handlers/profile.py23
-rw-r--r--synapse/handlers/receipts.py2
-rw-r--r--synapse/handlers/register.py34
-rw-r--r--synapse/handlers/room.py154
-rw-r--r--synapse/handlers/room_member.py82
-rw-r--r--synapse/handlers/saml_handler.py22
-rw-r--r--synapse/handlers/state_deltas.py2
-rw-r--r--synapse/handlers/sync.py54
-rw-r--r--synapse/handlers/typing.py23
-rw-r--r--synapse/handlers/ui_auth/checkers.py5
-rw-r--r--synapse/handlers/user_directory.py8
36 files changed, 814 insertions, 364 deletions
diff --git a/synapse/handlers/__init__.py b/synapse/handlers/__init__.py
index 2dd18301..286f0054 100644
--- a/synapse/handlers/__init__.py
+++ b/synapse/handlers/__init__.py
@@ -20,7 +20,7 @@ from .identity import IdentityHandler
from .search import SearchHandler
-class Handlers(object):
+class Handlers:
""" Deprecated. A collection of handlers.
diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py
index ba2bf998..0206320e 100644
--- a/synapse/handlers/_base.py
+++ b/synapse/handlers/_base.py
@@ -25,7 +25,7 @@ from synapse.types import UserID
logger = logging.getLogger(__name__)
-class BaseHandler(object):
+class BaseHandler:
"""
Common base class for the event handlers.
"""
diff --git a/synapse/handlers/account_data.py b/synapse/handlers/account_data.py
index a8d3fbc6..9112a0ab 100644
--- a/synapse/handlers/account_data.py
+++ b/synapse/handlers/account_data.py
@@ -14,7 +14,7 @@
# limitations under the License.
-class AccountDataEventSource(object):
+class AccountDataEventSource:
def __init__(self, hs):
self.store = hs.get_datastore()
diff --git a/synapse/handlers/account_validity.py b/synapse/handlers/account_validity.py
index 590135d1..4caf6d59 100644
--- a/synapse/handlers/account_validity.py
+++ b/synapse/handlers/account_validity.py
@@ -26,15 +26,10 @@ from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.types import UserID
from synapse.util import stringutils
-try:
- from synapse.push.mailer import load_jinja2_templates
-except ImportError:
- load_jinja2_templates = None
-
logger = logging.getLogger(__name__)
-class AccountValidityHandler(object):
+class AccountValidityHandler:
def __init__(self, hs):
self.hs = hs
self.config = hs.config
@@ -47,9 +42,11 @@ class AccountValidityHandler(object):
if (
self._account_validity.enabled
and self._account_validity.renew_by_email_enabled
- and load_jinja2_templates
):
# Don't do email-specific configuration if renewal by email is disabled.
+ self._template_html = self.config.account_validity_template_html
+ self._template_text = self.config.account_validity_template_text
+
try:
app_name = self.hs.config.email_app_name
@@ -65,17 +62,6 @@ class AccountValidityHandler(object):
self._raw_from = email.utils.parseaddr(self._from_string)[1]
- self._template_html, self._template_text = load_jinja2_templates(
- self.config.email_template_dir,
- [
- self.config.email_expiry_template_html,
- self.config.email_expiry_template_text,
- ],
- apply_format_ts_filter=True,
- apply_mxc_to_http_filter=True,
- public_baseurl=self.config.public_baseurl,
- )
-
# Check the renewal emails to send and send them every 30min.
def send_emails():
# run as a background process to make sure that the database transactions
diff --git a/synapse/handlers/acme.py b/synapse/handlers/acme.py
index 7666d3ab..8476256a 100644
--- a/synapse/handlers/acme.py
+++ b/synapse/handlers/acme.py
@@ -34,7 +34,7 @@ solutions, please read https://github.com/matrix-org/synapse/blob/master/docs/AC
--------------------------------------------------------------------------------"""
-class AcmeHandler(object):
+class AcmeHandler:
def __init__(self, hs):
self.hs = hs
self.reactor = hs.get_reactor()
diff --git a/synapse/handlers/acme_issuing_service.py b/synapse/handlers/acme_issuing_service.py
index e1d4224e..69650ff2 100644
--- a/synapse/handlers/acme_issuing_service.py
+++ b/synapse/handlers/acme_issuing_service.py
@@ -78,7 +78,7 @@ def create_issuing_service(reactor, acme_url, account_key_file, well_known_resou
@attr.s
@implementer(ICertificateStore)
-class ErsatzStore(object):
+class ErsatzStore:
"""
A store that only stores in memory.
"""
diff --git a/synapse/handlers/admin.py b/synapse/handlers/admin.py
index 506bb2b2..918d0e03 100644
--- a/synapse/handlers/admin.py
+++ b/synapse/handlers/admin.py
@@ -197,7 +197,7 @@ class AdminHandler(BaseHandler):
return writer.finished()
-class ExfiltrationWriter(object):
+class ExfiltrationWriter:
"""Interface used to specify how to write exported data.
"""
diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py
index c9044a50..9d4e87da 100644
--- a/synapse/handlers/appservice.py
+++ b/synapse/handlers/appservice.py
@@ -34,7 +34,7 @@ logger = logging.getLogger(__name__)
events_processed_counter = Counter("synapse_handlers_appservice_events_processed", "")
-class ApplicationServicesHandler(object):
+class ApplicationServicesHandler:
def __init__(self, hs):
self.store = hs.get_datastore()
self.is_mine_id = hs.is_mine_id
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index c24e7baf..90189869 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -42,9 +42,9 @@ from synapse.http.site import SynapseRequest
from synapse.logging.context import defer_to_thread
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.module_api import ModuleApi
-from synapse.push.mailer import load_jinja2_templates
-from synapse.types import Requester, UserID
+from synapse.types import JsonDict, Requester, UserID
from synapse.util import stringutils as stringutils
+from synapse.util.msisdn import phone_number_to_msisdn
from synapse.util.threepids import canonicalise_email
from ._base import BaseHandler
@@ -52,6 +52,91 @@ from ._base import BaseHandler
logger = logging.getLogger(__name__)
+def convert_client_dict_legacy_fields_to_identifier(
+ submission: JsonDict,
+) -> Dict[str, str]:
+ """
+ Convert a legacy-formatted login submission to an identifier dict.
+
+ Legacy login submissions (used in both login and user-interactive authentication)
+ provide user-identifying information at the top-level instead.
+
+ These are now deprecated and replaced with identifiers:
+ https://matrix.org/docs/spec/client_server/r0.6.1#identifier-types
+
+ Args:
+ submission: The client dict to convert
+
+ Returns:
+ The matching identifier dict
+
+ Raises:
+ SynapseError: If the format of the client dict is invalid
+ """
+ identifier = submission.get("identifier", {})
+
+ # Generate an m.id.user identifier if "user" parameter is present
+ user = submission.get("user")
+ if user:
+ identifier = {"type": "m.id.user", "user": user}
+
+ # Generate an m.id.thirdparty identifier if "medium" and "address" parameters are present
+ medium = submission.get("medium")
+ address = submission.get("address")
+ if medium and address:
+ identifier = {
+ "type": "m.id.thirdparty",
+ "medium": medium,
+ "address": address,
+ }
+
+ # We've converted valid, legacy login submissions to an identifier. If the
+ # submission still doesn't have an identifier, it's invalid
+ if not identifier:
+ raise SynapseError(400, "Invalid login submission", Codes.INVALID_PARAM)
+
+ # Ensure the identifier has a type
+ if "type" not in identifier:
+ raise SynapseError(
+ 400, "'identifier' dict has no key 'type'", errcode=Codes.MISSING_PARAM,
+ )
+
+ return identifier
+
+
+def login_id_phone_to_thirdparty(identifier: JsonDict) -> Dict[str, str]:
+ """
+ Convert a phone login identifier type to a generic threepid identifier.
+
+ Args:
+ identifier: Login identifier dict of type 'm.id.phone'
+
+ Returns:
+ An equivalent m.id.thirdparty identifier dict
+ """
+ if "country" not in identifier or (
+ # The specification requires a "phone" field, while Synapse used to require a "number"
+ # field. Accept both for backwards compatibility.
+ "phone" not in identifier
+ and "number" not in identifier
+ ):
+ raise SynapseError(
+ 400, "Invalid phone-type identifier", errcode=Codes.INVALID_PARAM
+ )
+
+ # Accept both "phone" and "number" as valid keys in m.id.phone
+ phone_number = identifier.get("phone", identifier["number"])
+
+ # Convert user-provided phone number to a consistent representation
+ msisdn = phone_number_to_msisdn(identifier["country"], phone_number)
+
+ return {
+ "type": "m.id.thirdparty",
+ "medium": "msisdn",
+ "address": msisdn,
+ }
+
+
class AuthHandler(BaseHandler):
SESSION_EXPIRE_MS = 48 * 60 * 60 * 1000
@@ -132,18 +217,17 @@ class AuthHandler(BaseHandler):
# after the SSO completes and before redirecting them back to their client.
# It notifies the user they are about to give access to their matrix account
# to the client.
- self._sso_redirect_confirm_template = load_jinja2_templates(
- hs.config.sso_template_dir, ["sso_redirect_confirm.html"],
- )[0]
+ self._sso_redirect_confirm_template = hs.config.sso_redirect_confirm_template
+
# The following template is shown during user interactive authentication
# in the fallback auth scenario. It notifies the user that they are
# authenticating for an operation to occur on their account.
- self._sso_auth_confirm_template = load_jinja2_templates(
- hs.config.sso_template_dir, ["sso_auth_confirm.html"],
- )[0]
+ self._sso_auth_confirm_template = hs.config.sso_auth_confirm_template
+
# The following template is shown after a successful user interactive
# authentication session. It tells the user they can close the window.
self._sso_auth_success_template = hs.config.sso_auth_success_template
+
# The following template is shown during the SSO authentication process if
# the account is deactivated.
self._sso_account_deactivated_template = (
@@ -366,6 +450,14 @@ class AuthHandler(BaseHandler):
# authentication flow.
await self.store.set_ui_auth_clientdict(sid, clientdict)
+ user_agent = request.requestHeaders.getRawHeaders(b"User-Agent", default=[b""])[
+ 0
+ ].decode("ascii", "surrogateescape")
+
+ await self.store.add_user_agent_ip_to_ui_auth_session(
+ session.session_id, user_agent, clientip
+ )
+
if not authdict:
raise InteractiveAuthIncompleteError(
session.session_id, self._auth_dict_for_flows(flows, session.session_id)
@@ -1144,7 +1236,7 @@ class AuthHandler(BaseHandler):
@attr.s
-class MacaroonGenerator(object):
+class MacaroonGenerator:
hs = attr.ib()
diff --git a/synapse/handlers/cas_handler.py b/synapse/handlers/cas_handler.py
index 786e608f..a4cc4b9a 100644
--- a/synapse/handlers/cas_handler.py
+++ b/synapse/handlers/cas_handler.py
@@ -35,6 +35,7 @@ class CasHandler:
"""
def __init__(self, hs):
+ self.hs = hs
self._hostname = hs.hostname
self._auth_handler = hs.get_auth_handler()
self._registration_handler = hs.get_registration_handler()
@@ -210,8 +211,16 @@ class CasHandler:
else:
if not registered_user_id:
+ # Pull out the user-agent and IP from the request.
+ user_agent = request.requestHeaders.getRawHeaders(
+ b"User-Agent", default=[b""]
+ )[0].decode("ascii", "surrogateescape")
+ ip_address = self.hs.get_ip_from_request(request)
+
registered_user_id = await self._registration_handler.register_user(
- localpart=localpart, default_display_name=user_display_name
+ localpart=localpart,
+ default_display_name=user_display_name,
+ user_agent_ips=(user_agent, ip_address),
)
await self._auth_handler.complete_sso_login(
diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py
index db417d60..643d71a7 100644
--- a/synapse/handlers/device.py
+++ b/synapse/handlers/device.py
@@ -234,7 +234,9 @@ class DeviceWorkerHandler(BaseHandler):
return result
async def on_federation_query_user_devices(self, user_id):
- stream_id, devices = await self.store.get_devices_with_keys_by_user(user_id)
+ stream_id, devices = await self.store.get_e2e_device_keys_for_federation_query(
+ user_id
+ )
master_key = await self.store.get_e2e_cross_signing_key(user_id, "master")
self_signing_key = await self.store.get_e2e_cross_signing_key(
user_id, "self_signing"
@@ -495,7 +497,7 @@ def _update_device_from_client_ips(device, client_ips):
device.update({"last_seen_ts": ip.get("last_seen"), "last_seen_ip": ip.get("ip")})
-class DeviceListUpdater(object):
+class DeviceListUpdater:
"Handles incoming device list updates from federation and updates the DB"
def __init__(self, hs, device_handler):
diff --git a/synapse/handlers/devicemessage.py b/synapse/handlers/devicemessage.py
index 610b08d0..64ef7f63 100644
--- a/synapse/handlers/devicemessage.py
+++ b/synapse/handlers/devicemessage.py
@@ -16,8 +16,6 @@
import logging
from typing import Any, Dict
-from canonicaljson import json
-
from synapse.api.errors import SynapseError
from synapse.logging.context import run_in_background
from synapse.logging.opentracing import (
@@ -27,12 +25,13 @@ from synapse.logging.opentracing import (
start_active_span,
)
from synapse.types import UserID, get_domain_from_id
+from synapse.util import json_encoder
from synapse.util.stringutils import random_string
logger = logging.getLogger(__name__)
-class DeviceMessageHandler(object):
+class DeviceMessageHandler:
def __init__(self, hs):
"""
Args:
@@ -174,7 +173,7 @@ class DeviceMessageHandler(object):
"sender": sender_user_id,
"type": message_type,
"message_id": message_id,
- "org.matrix.opentracing_context": json.dumps(context),
+ "org.matrix.opentracing_context": json_encoder.encode(context),
}
log_kv({"local_messages": local_messages})
diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py
index 79a2df62..46826eb7 100644
--- a/synapse/handlers/directory.py
+++ b/synapse/handlers/directory.py
@@ -23,6 +23,7 @@ from synapse.api.errors import (
CodeMessageException,
Codes,
NotFoundError,
+ ShadowBanError,
StoreError,
SynapseError,
)
@@ -199,6 +200,8 @@ class DirectoryHandler(BaseHandler):
try:
await self._update_canonical_alias(requester, user_id, room_id, room_alias)
+ except ShadowBanError as e:
+ logger.info("Failed to update alias events due to shadow-ban: %s", e)
except AuthError as e:
logger.info("Failed to update alias events: %s", e)
@@ -292,6 +295,9 @@ class DirectoryHandler(BaseHandler):
"""
Send an updated canonical alias event if the removed alias was set as
the canonical alias or listed in the alt_aliases field.
+
+ Raises:
+ ShadowBanError if the requester has been shadow-banned.
"""
alias_event = await self.state.get_current_state(
room_id, EventTypes.CanonicalAlias, ""
diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py
index 84169c10..d629c7c1 100644
--- a/synapse/handlers/e2e_keys.py
+++ b/synapse/handlers/e2e_keys.py
@@ -19,7 +19,7 @@ import logging
from typing import Dict, List, Optional, Tuple
import attr
-from canonicaljson import encode_canonical_json, json
+from canonicaljson import encode_canonical_json
from signedjson.key import VerifyKey, decode_verify_key_bytes
from signedjson.sign import SignatureVerifyException, verify_signed_json
from unpaddedbase64 import decode_base64
@@ -35,7 +35,7 @@ from synapse.types import (
get_domain_from_id,
get_verify_key_from_cross_signing_key,
)
-from synapse.util import unwrapFirstError
+from synapse.util import json_decoder, unwrapFirstError
from synapse.util.async_helpers import Linearizer
from synapse.util.caches.expiringcache import ExpiringCache
from synapse.util.retryutils import NotRetryingDestination
@@ -43,7 +43,7 @@ from synapse.util.retryutils import NotRetryingDestination
logger = logging.getLogger(__name__)
-class E2eKeysHandler(object):
+class E2eKeysHandler:
def __init__(self, hs):
self.store = hs.get_datastore()
self.federation = hs.get_federation_client()
@@ -353,7 +353,7 @@ class E2eKeysHandler(object):
# make sure that each queried user appears in the result dict
result_dict[user_id] = {}
- results = await self.store.get_e2e_device_keys(local_query)
+ results = await self.store.get_e2e_device_keys_for_cs_api(local_query)
# Build the result structure
for user_id, device_keys in results.items():
@@ -404,7 +404,7 @@ class E2eKeysHandler(object):
for device_id, keys in device_keys.items():
for key_id, json_bytes in keys.items():
json_result.setdefault(user_id, {})[device_id] = {
- key_id: json.loads(json_bytes)
+ key_id: json_decoder.decode(json_bytes)
}
@trace
@@ -734,7 +734,7 @@ class E2eKeysHandler(object):
# fetch our stored devices. This is used to 1. verify
# signatures on the master key, and 2. to compare with what
# was sent if the device was signed
- devices = await self.store.get_e2e_device_keys([(user_id, None)])
+ devices = await self.store.get_e2e_device_keys_for_cs_api([(user_id, None)])
if user_id not in devices:
raise NotFoundError("No device keys found")
@@ -1186,7 +1186,7 @@ def _exception_to_failure(e):
def _one_time_keys_match(old_key_json, new_key):
- old_key = json.loads(old_key_json)
+ old_key = json_decoder.decode(old_key_json)
# if either is a string rather than an object, they must match exactly
if not isinstance(old_key, dict) or not isinstance(new_key, dict):
@@ -1212,7 +1212,7 @@ class SignatureListItem:
signature = attr.ib()
-class SigningKeyEduUpdater(object):
+class SigningKeyEduUpdater:
"""Handles incoming signing key updates from federation and updates the DB"""
def __init__(self, hs, e2e_keys_handler):
diff --git a/synapse/handlers/e2e_room_keys.py b/synapse/handlers/e2e_room_keys.py
index 0bb983dc..f01b0907 100644
--- a/synapse/handlers/e2e_room_keys.py
+++ b/synapse/handlers/e2e_room_keys.py
@@ -29,7 +29,7 @@ from synapse.util.async_helpers import Linearizer
logger = logging.getLogger(__name__)
-class E2eRoomKeysHandler(object):
+class E2eRoomKeysHandler:
"""
Implements an optional realtime backup mechanism for encrypted E2E megolm room keys.
This gives a way for users to store and recover their megolm keys if they lose all
diff --git a/synapse/handlers/events.py b/synapse/handlers/events.py
index 1924636c..b05e32f4 100644
--- a/synapse/handlers/events.py
+++ b/synapse/handlers/events.py
@@ -15,29 +15,30 @@
import logging
import random
+from typing import TYPE_CHECKING, Iterable, List, Optional
from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import AuthError, SynapseError
from synapse.events import EventBase
from synapse.handlers.presence import format_user_presence_state
from synapse.logging.utils import log_function
-from synapse.types import UserID
+from synapse.streams.config import PaginationConfig
+from synapse.types import JsonDict, UserID
from synapse.visibility import filter_events_for_client
from ._base import BaseHandler
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
+
logger = logging.getLogger(__name__)
class EventStreamHandler(BaseHandler):
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super(EventStreamHandler, self).__init__(hs)
- # Count of active streams per user
- self._streams_per_user = {}
- # Grace timers per user to delay the "stopped" signal
- self._stop_timer_per_user = {}
-
self.distributor = hs.get_distributor()
self.distributor.declare("started_user_eventstream")
self.distributor.declare("stopped_user_eventstream")
@@ -52,14 +53,14 @@ class EventStreamHandler(BaseHandler):
@log_function
async def get_stream(
self,
- auth_user_id,
- pagin_config,
- timeout=0,
- as_client_event=True,
- affect_presence=True,
- room_id=None,
- is_guest=False,
- ):
+ auth_user_id: str,
+ pagin_config: PaginationConfig,
+ timeout: int = 0,
+ as_client_event: bool = True,
+ affect_presence: bool = True,
+ room_id: Optional[str] = None,
+ is_guest: bool = False,
+ ) -> JsonDict:
"""Fetches the events stream for a given user.
"""
@@ -98,7 +99,7 @@ class EventStreamHandler(BaseHandler):
# When the user joins a new room, or another user joins a currently
# joined room, we need to send down presence for those users.
- to_add = []
+ to_add = [] # type: List[JsonDict]
for event in events:
if not isinstance(event, EventBase):
continue
@@ -110,7 +111,7 @@ class EventStreamHandler(BaseHandler):
# Send down presence for everyone in the room.
users = await self.state.get_current_users_in_room(
event.room_id
- )
+ ) # type: Iterable[str]
else:
users = [event.state_key]
@@ -144,20 +145,22 @@ class EventStreamHandler(BaseHandler):
class EventHandler(BaseHandler):
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super(EventHandler, self).__init__(hs)
self.storage = hs.get_storage()
- async def get_event(self, user, room_id, event_id):
+ async def get_event(
+ self, user: UserID, room_id: Optional[str], event_id: str
+ ) -> Optional[EventBase]:
"""Retrieve a single specified event.
Args:
- user (synapse.types.UserID): The user requesting the event
- room_id (str|None): The expected room id. We'll return None if the
+ user: The user requesting the event
+ room_id: The expected room id. We'll return None if the
event's room does not match.
- event_id (str): The event ID to obtain.
+ event_id: The event ID to obtain.
Returns:
- dict: An event, or None if there is no event matching this ID.
+ An event, or None if there is no event matching this ID.
Raises:
SynapseError if there was a problem retrieving this event, or
AuthError if the user does not have the rights to inspect this
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 593932ad..014dab29 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -72,7 +72,13 @@ from synapse.replication.http.federation import (
from synapse.replication.http.membership import ReplicationUserJoinedLeftRoomRestServlet
from synapse.state import StateResolutionStore, resolve_events_with_store
from synapse.storage.databases.main.events_worker import EventRedactBehaviour
-from synapse.types import JsonDict, StateMap, UserID, get_domain_from_id
+from synapse.types import (
+ JsonDict,
+ MutableStateMap,
+ StateMap,
+ UserID,
+ get_domain_from_id,
+)
from synapse.util.async_helpers import Linearizer, concurrently_execute
from synapse.util.distributor import user_joined_room
from synapse.util.retryutils import NotRetryingDestination
@@ -96,7 +102,7 @@ class _NewEventInfo:
event = attr.ib(type=EventBase)
state = attr.ib(type=Optional[Sequence[EventBase]], default=None)
- auth_events = attr.ib(type=Optional[StateMap[EventBase]], default=None)
+ auth_events = attr.ib(type=Optional[MutableStateMap[EventBase]], default=None)
class FederationHandler(BaseHandler):
@@ -434,11 +440,11 @@ class FederationHandler(BaseHandler):
if not prevs - seen:
return
- latest = await self.store.get_latest_event_ids_in_room(room_id)
+ latest_list = await self.store.get_latest_event_ids_in_room(room_id)
# We add the prev events that we have seen to the latest
# list to ensure the remote server doesn't give them to us
- latest = set(latest)
+ latest = set(latest_list)
latest |= seen
logger.info(
@@ -775,7 +781,7 @@ class FederationHandler(BaseHandler):
# keys across all devices.
current_keys = [
key
- for device in cached_devices
+ for device in cached_devices.values()
for key in device.get("keys", {}).get("keys", {}).values()
]
@@ -937,15 +943,26 @@ class FederationHandler(BaseHandler):
return events
- async def maybe_backfill(self, room_id, current_depth):
+ async def maybe_backfill(
+ self, room_id: str, current_depth: int, limit: int
+ ) -> bool:
"""Checks the database to see if we should backfill before paginating,
and if so do.
+
+ Args:
+ room_id
+ current_depth: The depth from which we're paginating from. This is
+ used to decide if we should backfill and what extremities to
+ use.
+ limit: The number of events that the pagination request will
+ return. This is used as part of the heuristic to decide if we
+ should back paginate.
"""
extremities = await self.store.get_oldest_events_with_depth_in_room(room_id)
if not extremities:
logger.debug("Not backfilling as no extremeties found.")
- return
+ return False
# We only want to paginate if we can actually see the events we'll get,
# as otherwise we'll just spend a lot of resources to get redacted
@@ -998,16 +1015,54 @@ class FederationHandler(BaseHandler):
sorted_extremeties_tuple = sorted(extremities.items(), key=lambda e: -int(e[1]))
max_depth = sorted_extremeties_tuple[0][1]
+ # If we're approaching an extremity we trigger a backfill, otherwise we
+ # no-op.
+ #
+ # We chose twice the limit here as then clients paginating backwards
+ # will send pagination requests that trigger backfill at least twice
+ # using the most recent extremity before it gets removed (see below). We
+ # chose more than one times the limit in case of failure, but choosing a
+ # much larger factor will result in triggering a backfill request much
+ # earlier than necessary.
+ if current_depth - 2 * limit > max_depth:
+ logger.debug(
+ "Not backfilling as we don't need to. %d < %d - 2 * %d",
+ max_depth,
+ current_depth,
+ limit,
+ )
+ return False
+
+ logger.debug(
+ "room_id: %s, backfill: current_depth: %s, max_depth: %s, extrems: %s",
+ room_id,
+ current_depth,
+ max_depth,
+ sorted_extremeties_tuple,
+ )
+
+ # We ignore extremities that have a greater depth than our current depth
+ # as:
+ # 1. we don't really care about getting events that have happened
+ # before our current position; and
+ # 2. we have likely previously tried and failed to backfill from that
+ # extremity, so to avoid getting "stuck" requesting the same
+ # backfill repeatedly we drop those extremities.
+ filtered_sorted_extremeties_tuple = [
+ t for t in sorted_extremeties_tuple if int(t[1]) <= current_depth
+ ]
+
+ # However, we need to check that the filtered extremities are non-empty.
+ # If they are empty then either we can a) bail or b) still attempt to
+ # backill. We opt to try backfilling anyway just in case we do get
+ # relevant events.
+ if filtered_sorted_extremeties_tuple:
+ sorted_extremeties_tuple = filtered_sorted_extremeties_tuple
+
# We don't want to specify too many extremities as it causes the backfill
# request URI to be too long.
extremities = dict(sorted_extremeties_tuple[:5])
- if current_depth > max_depth:
- logger.debug(
- "Not backfilling as we don't need to. %d < %d", max_depth, current_depth
- )
- return
-
# Now we need to decide which hosts to hit first.
# First we try hosts that are already in the room
@@ -1777,9 +1832,7 @@ class FederationHandler(BaseHandler):
"""Returns the state at the event. i.e. not including said event.
"""
- event = await self.store.get_event(
- event_id, allow_none=False, check_room_id=room_id
- )
+ event = await self.store.get_event(event_id, check_room_id=room_id)
state_groups = await self.state_store.get_state_groups(room_id, [event_id])
@@ -1805,9 +1858,7 @@ class FederationHandler(BaseHandler):
async def get_state_ids_for_pdu(self, room_id: str, event_id: str) -> List[str]:
"""Returns the state at the event. i.e. not including said event.
"""
- event = await self.store.get_event(
- event_id, allow_none=False, check_room_id=room_id
- )
+ event = await self.store.get_event(event_id, check_room_id=room_id)
state_groups = await self.state_store.get_state_groups_ids(room_id, [event_id])
@@ -1877,8 +1928,8 @@ class FederationHandler(BaseHandler):
else:
return None
- def get_min_depth_for_context(self, context):
- return self.store.get_min_depth(context)
+ async def get_min_depth_for_context(self, context):
+ return await self.store.get_min_depth(context)
async def _handle_new_event(
self, origin, event, state=None, auth_events=None, backfilled=False
@@ -2057,7 +2108,7 @@ class FederationHandler(BaseHandler):
origin: str,
event: EventBase,
state: Optional[Iterable[EventBase]],
- auth_events: Optional[StateMap[EventBase]],
+ auth_events: Optional[MutableStateMap[EventBase]],
backfilled: bool,
) -> EventContext:
context = await self.state_handler.compute_event_context(event, old_state=state)
@@ -2107,8 +2158,8 @@ class FederationHandler(BaseHandler):
if backfilled or event.internal_metadata.is_outlier():
return
- extrem_ids = await self.store.get_latest_event_ids_in_room(event.room_id)
- extrem_ids = set(extrem_ids)
+ extrem_ids_list = await self.store.get_latest_event_ids_in_room(event.room_id)
+ extrem_ids = set(extrem_ids_list)
prev_event_ids = set(event.prev_event_ids())
if extrem_ids == prev_event_ids:
@@ -2138,10 +2189,12 @@ class FederationHandler(BaseHandler):
)
state_sets = list(state_sets.values())
state_sets.append(state)
- current_state_ids = await self.state_handler.resolve_events(
+ current_states = await self.state_handler.resolve_events(
room_version, state_sets, event
)
- current_state_ids = {k: e.event_id for k, e in current_state_ids.items()}
+ current_state_ids = {
+ k: e.event_id for k, e in current_states.items()
+ } # type: StateMap[str]
else:
current_state_ids = await self.state_handler.get_current_state_ids(
event.room_id, latest_event_ids=extrem_ids
@@ -2153,11 +2206,13 @@ class FederationHandler(BaseHandler):
# Now check if event pass auth against said current state
auth_types = auth_types_for_event(event)
- current_state_ids = [e for k, e in current_state_ids.items() if k in auth_types]
+ current_state_ids_list = [
+ e for k, e in current_state_ids.items() if k in auth_types
+ ]
- current_auth_events = await self.store.get_events(current_state_ids)
+ auth_events_map = await self.store.get_events(current_state_ids_list)
current_auth_events = {
- (e.type, e.state_key): e for e in current_auth_events.values()
+ (e.type, e.state_key): e for e in auth_events_map.values()
}
try:
@@ -2173,9 +2228,7 @@ class FederationHandler(BaseHandler):
if not in_room:
raise AuthError(403, "Host not in room.")
- event = await self.store.get_event(
- event_id, allow_none=False, check_room_id=room_id
- )
+ event = await self.store.get_event(event_id, check_room_id=room_id)
# Just go through and process each event in `remote_auth_chain`. We
# don't want to fall into the trap of `missing` being wrong.
@@ -2227,7 +2280,7 @@ class FederationHandler(BaseHandler):
origin: str,
event: EventBase,
context: EventContext,
- auth_events: StateMap[EventBase],
+ auth_events: MutableStateMap[EventBase],
) -> EventContext:
"""
@@ -2278,7 +2331,7 @@ class FederationHandler(BaseHandler):
origin: str,
event: EventBase,
context: EventContext,
- auth_events: StateMap[EventBase],
+ auth_events: MutableStateMap[EventBase],
) -> EventContext:
"""Helper for do_auth. See there for docs.
diff --git a/synapse/handlers/groups_local.py b/synapse/handlers/groups_local.py
index 0e2656cc..44df5679 100644
--- a/synapse/handlers/groups_local.py
+++ b/synapse/handlers/groups_local.py
@@ -52,7 +52,7 @@ def _create_rerouter(func_name):
return f
-class GroupsLocalWorkerHandler(object):
+class GroupsLocalWorkerHandler:
def __init__(self, hs):
self.hs = hs
self.store = hs.get_datastore()
diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py
index 92b74047..0ce6ddfb 100644
--- a/synapse/handlers/identity.py
+++ b/synapse/handlers/identity.py
@@ -21,8 +21,6 @@ import logging
import urllib.parse
from typing import Awaitable, Callable, Dict, List, Optional, Tuple
-from canonicaljson import json
-
from twisted.internet.error import TimeoutError
from synapse.api.errors import (
@@ -34,6 +32,7 @@ from synapse.api.errors import (
from synapse.config.emailconfig import ThreepidBehaviour
from synapse.http.client import SimpleHttpClient
from synapse.types import JsonDict, Requester
+from synapse.util import json_decoder
from synapse.util.hash import sha256_and_url_safe_base64
from synapse.util.stringutils import assert_valid_client_secret, random_string
@@ -177,7 +176,7 @@ class IdentityHandler(BaseHandler):
except TimeoutError:
raise SynapseError(500, "Timed out contacting identity server")
except CodeMessageException as e:
- data = json.loads(e.msg) # XXX WAT?
+ data = json_decoder.decode(e.msg) # XXX WAT?
return data
logger.info("Got 404 when POSTing JSON %s, falling back to v1 URL", bind_url)
diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py
index ae6bd1d3..d5ddc583 100644
--- a/synapse/handlers/initial_sync.py
+++ b/synapse/handlers/initial_sync.py
@@ -14,6 +14,7 @@
# limitations under the License.
import logging
+from typing import TYPE_CHECKING
from twisted.internet import defer
@@ -22,8 +23,9 @@ from synapse.api.errors import SynapseError
from synapse.events.validator import EventValidator
from synapse.handlers.presence import format_user_presence_state
from synapse.logging.context import make_deferred_yieldable, run_in_background
+from synapse.storage.roommember import RoomsForUser
from synapse.streams.config import PaginationConfig
-from synapse.types import StreamToken, UserID
+from synapse.types import JsonDict, Requester, StreamToken, UserID
from synapse.util import unwrapFirstError
from synapse.util.async_helpers import concurrently_execute
from synapse.util.caches.response_cache import ResponseCache
@@ -31,11 +33,15 @@ from synapse.visibility import filter_events_for_client
from ._base import BaseHandler
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
+
logger = logging.getLogger(__name__)
class InitialSyncHandler(BaseHandler):
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super(InitialSyncHandler, self).__init__(hs)
self.hs = hs
self.state = hs.get_state_handler()
@@ -48,27 +54,25 @@ class InitialSyncHandler(BaseHandler):
def snapshot_all_rooms(
self,
- user_id=None,
- pagin_config=None,
- as_client_event=True,
- include_archived=False,
- ):
+ user_id: str,
+ pagin_config: PaginationConfig,
+ as_client_event: bool = True,
+ include_archived: bool = False,
+ ) -> JsonDict:
"""Retrieve a snapshot of all rooms the user is invited or has joined.
This snapshot may include messages for all rooms where the user is
joined, depending on the pagination config.
Args:
- user_id (str): The ID of the user making the request.
- pagin_config (synapse.api.streams.PaginationConfig): The pagination
- config used to determine how many messages *PER ROOM* to return.
- as_client_event (bool): True to get events in client-server format.
- include_archived (bool): True to get rooms that the user has left
+ user_id: The ID of the user making the request.
+ pagin_config: The pagination config used to determine how many
+ messages *PER ROOM* to return.
+ as_client_event: True to get events in client-server format.
+ include_archived: True to get rooms that the user has left
Returns:
- A list of dicts with "room_id" and "membership" keys for all rooms
- the user is currently invited or joined in on. Rooms where the user
- is joined on, may return a "messages" key with messages, depending
- on the specified PaginationConfig.
+ A JsonDict with the same format as the response to `/intialSync`
+ API
"""
key = (
user_id,
@@ -91,11 +95,11 @@ class InitialSyncHandler(BaseHandler):
async def _snapshot_all_rooms(
self,
- user_id=None,
- pagin_config=None,
- as_client_event=True,
- include_archived=False,
- ):
+ user_id: str,
+ pagin_config: PaginationConfig,
+ as_client_event: bool = True,
+ include_archived: bool = False,
+ ) -> JsonDict:
memberships = [Membership.INVITE, Membership.JOIN]
if include_archived:
@@ -134,7 +138,7 @@ class InitialSyncHandler(BaseHandler):
if limit is None:
limit = 10
- async def handle_room(event):
+ async def handle_room(event: RoomsForUser):
d = {
"room_id": event.room_id,
"membership": event.membership,
@@ -251,17 +255,18 @@ class InitialSyncHandler(BaseHandler):
return ret
- async def room_initial_sync(self, requester, room_id, pagin_config=None):
+ async def room_initial_sync(
+ self, requester: Requester, room_id: str, pagin_config: PaginationConfig
+ ) -> JsonDict:
"""Capture the a snapshot of a room. If user is currently a member of
the room this will be what is currently in the room. If the user left
the room this will be what was in the room when they left.
Args:
- requester(Requester): The user to get a snapshot for.
- room_id(str): The room to get a snapshot of.
- pagin_config(synapse.streams.config.PaginationConfig):
- The pagination config used to determine how many messages to
- return.
+ requester: The user to get a snapshot for.
+ room_id: The room to get a snapshot of.
+ pagin_config: The pagination config used to determine how many
+ messages to return.
Raises:
AuthError if the user wasn't in the room.
Returns:
@@ -305,8 +310,14 @@ class InitialSyncHandler(BaseHandler):
return result
async def _room_initial_sync_parted(
- self, user_id, room_id, pagin_config, membership, member_event_id, is_peeking
- ):
+ self,
+ user_id: str,
+ room_id: str,
+ pagin_config: PaginationConfig,
+ membership: Membership,
+ member_event_id: str,
+ is_peeking: bool,
+ ) -> JsonDict:
room_state = await self.state_store.get_state_for_events([member_event_id])
room_state = room_state[member_event_id]
@@ -350,8 +361,13 @@ class InitialSyncHandler(BaseHandler):
}
async def _room_initial_sync_joined(
- self, user_id, room_id, pagin_config, membership, is_peeking
- ):
+ self,
+ user_id: str,
+ room_id: str,
+ pagin_config: PaginationConfig,
+ membership: Membership,
+ is_peeking: bool,
+ ) -> JsonDict:
current_state = await self.state.get_current_state(room_id=room_id)
# TODO: These concurrently
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index 2643438e..8a7b4916 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -15,9 +15,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
+import random
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
-from canonicaljson import encode_canonical_json, json
+from canonicaljson import encode_canonical_json
from twisted.internet.interfaces import IDelayedCall
@@ -34,6 +35,7 @@ from synapse.api.errors import (
Codes,
ConsentNotGivenError,
NotFoundError,
+ ShadowBanError,
SynapseError,
)
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersions
@@ -47,14 +49,8 @@ from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.replication.http.send_event import ReplicationSendEventRestServlet
from synapse.storage.databases.main.events_worker import EventRedactBehaviour
from synapse.storage.state import StateFilter
-from synapse.types import (
- Collection,
- Requester,
- RoomAlias,
- StreamToken,
- UserID,
- create_requester,
-)
+from synapse.types import Requester, RoomAlias, StreamToken, UserID, create_requester
+from synapse.util import json_decoder
from synapse.util.async_helpers import Linearizer
from synapse.util.frozenutils import frozendict_json_encoder
from synapse.util.metrics import measure_func
@@ -68,7 +64,7 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
-class MessageHandler(object):
+class MessageHandler:
"""Contains some read only APIs to get state about a room
"""
@@ -92,12 +88,7 @@ class MessageHandler(object):
)
async def get_room_data(
- self,
- user_id: str,
- room_id: str,
- event_type: str,
- state_key: str,
- is_guest: bool,
+ self, user_id: str, room_id: str, event_type: str, state_key: str,
) -> dict:
""" Get data from a room.
@@ -106,11 +97,10 @@ class MessageHandler(object):
room_id
event_type
state_key
- is_guest
Returns:
The path data content.
Raises:
- SynapseError if something went wrong.
+ SynapseError or AuthError if the user is not in the room
"""
(
membership,
@@ -127,6 +117,16 @@ class MessageHandler(object):
[membership_event_id], StateFilter.from_types([key])
)
data = room_state[membership_event_id].get(key)
+ else:
+ # check_user_in_room_or_world_readable, if it doesn't raise an AuthError, should
+ # only ever return a Membership.JOIN/LEAVE object
+ #
+ # Safeguard in case it returned something else
+ logger.error(
+ "Attempted to retrieve data from a room for a user that has never been in it. "
+ "This should not have happened."
+ )
+ raise SynapseError(403, "User not in room", errcode=Codes.FORBIDDEN)
return data
@@ -361,7 +361,7 @@ class MessageHandler(object):
_DUMMY_EVENT_ROOM_EXCLUSION_EXPIRY = 7 * 24 * 60 * 60 * 1000
-class EventCreationHandler(object):
+class EventCreationHandler:
def __init__(self, hs: "HomeServer"):
self.hs = hs
self.auth = hs.get_auth()
@@ -439,7 +439,7 @@ class EventCreationHandler(object):
event_dict: dict,
token_id: Optional[str] = None,
txn_id: Optional[str] = None,
- prev_event_ids: Optional[Collection[str]] = None,
+ prev_event_ids: Optional[List[str]] = None,
require_consent: bool = True,
) -> Tuple[EventBase, EventContext]:
"""
@@ -644,37 +644,48 @@ class EventCreationHandler(object):
event: EventBase,
context: EventContext,
ratelimit: bool = True,
+ ignore_shadow_ban: bool = False,
) -> int:
"""
Persists and notifies local clients and federation of an event.
Args:
- requester
- event the event to send.
- context: the context of the event.
+ requester: The requester sending the event.
+ event: The event to send.
+ context: The context of the event.
ratelimit: Whether to rate limit this send.
+ ignore_shadow_ban: True if shadow-banned users should be allowed to
+ send this event.
Return:
The stream_id of the persisted event.
+
+ Raises:
+ ShadowBanError if the requester has been shadow-banned.
"""
if event.type == EventTypes.Member:
raise SynapseError(
500, "Tried to send member event through non-member codepath"
)
+ if not ignore_shadow_ban and requester.shadow_banned:
+ # We randomly sleep a bit just to annoy the requester.
+ await self.clock.sleep(random.randint(1, 10))
+ raise ShadowBanError()
+
user = UserID.from_string(event.sender)
assert self.hs.is_mine(user), "User must be our own: %s" % (user,)
if event.is_state():
- prev_state = await self.deduplicate_state_event(event, context)
- if prev_state is not None:
+ prev_event = await self.deduplicate_state_event(event, context)
+ if prev_event is not None:
logger.info(
"Not bothering to persist state event %s duplicated by %s",
event.event_id,
- prev_state.event_id,
+ prev_event.event_id,
)
- return prev_state
+ return await self.store.get_stream_id_for_event(prev_event.event_id)
return await self.handle_new_client_event(
requester=requester, event=event, context=context, ratelimit=ratelimit
@@ -682,27 +693,32 @@ class EventCreationHandler(object):
async def deduplicate_state_event(
self, event: EventBase, context: EventContext
- ) -> None:
+ ) -> Optional[EventBase]:
"""
Checks whether event is in the latest resolved state in context.
- If so, returns the version of the event in context.
- Otherwise, returns None.
+ Args:
+ event: The event to check for duplication.
+ context: The event context.
+
+ Returns:
+ The previous verion of the event is returned, if it is found in the
+ event context. Otherwise, None is returned.
"""
prev_state_ids = await context.get_prev_state_ids()
prev_event_id = prev_state_ids.get((event.type, event.state_key))
if not prev_event_id:
- return
+ return None
prev_event = await self.store.get_event(prev_event_id, allow_none=True)
if not prev_event:
- return
+ return None
if prev_event and event.user_id == prev_event.user_id:
prev_content = encode_canonical_json(prev_event.content)
next_content = encode_canonical_json(event.content)
if prev_content == next_content:
return prev_event
- return
+ return None
async def create_and_send_nonmember_event(
self,
@@ -710,12 +726,28 @@ class EventCreationHandler(object):
event_dict: dict,
ratelimit: bool = True,
txn_id: Optional[str] = None,
+ ignore_shadow_ban: bool = False,
) -> Tuple[EventBase, int]:
"""
Creates an event, then sends it.
See self.create_event and self.send_nonmember_event.
+
+ Args:
+ requester: The requester sending the event.
+ event_dict: An entire event.
+ ratelimit: Whether to rate limit this send.
+ txn_id: The transaction ID.
+ ignore_shadow_ban: True if shadow-banned users should be allowed to
+ send this event.
+
+ Raises:
+ ShadowBanError if the requester has been shadow-banned.
"""
+ if not ignore_shadow_ban and requester.shadow_banned:
+ # We randomly sleep a bit just to annoy the requester.
+ await self.clock.sleep(random.randint(1, 10))
+ raise ShadowBanError()
# We limit the number of concurrent event sends in a room so that we
# don't fork the DAG too much. If we don't limit then we can end up in
@@ -734,7 +766,11 @@ class EventCreationHandler(object):
raise SynapseError(403, spam_error, Codes.FORBIDDEN)
stream_id = await self.send_nonmember_event(
- requester, event, context, ratelimit=ratelimit
+ requester,
+ event,
+ context,
+ ratelimit=ratelimit,
+ ignore_shadow_ban=ignore_shadow_ban,
)
return event, stream_id
@@ -743,7 +779,7 @@ class EventCreationHandler(object):
self,
builder: EventBuilder,
requester: Optional[Requester] = None,
- prev_event_ids: Optional[Collection[str]] = None,
+ prev_event_ids: Optional[List[str]] = None,
) -> Tuple[EventBase, EventContext]:
"""Create a new event for a local client
@@ -859,7 +895,7 @@ class EventCreationHandler(object):
# Ensure that we can round trip before trying to persist in db
try:
dump = frozendict_json_encoder.encode(event.content)
- json.loads(dump)
+ json_decoder.decode(dump)
except Exception:
logger.exception("Failed to encode content: %r", event.content)
raise
@@ -891,9 +927,7 @@ class EventCreationHandler(object):
except Exception:
# Ensure that we actually remove the entries in the push actions
# staging area, if we calculated them.
- run_in_background(
- self.store.remove_push_actions_from_staging, event.event_id
- )
+ await self.store.remove_push_actions_from_staging(event.event_id)
raise
async def _validate_canonical_alias(
@@ -957,7 +991,7 @@ class EventCreationHandler(object):
allow_none=True,
)
- is_admin_redaction = (
+ is_admin_redaction = bool(
original_event and event.sender != original_event.sender
)
@@ -1077,8 +1111,8 @@ class EventCreationHandler(object):
auth_events_ids = self.auth.compute_auth_events(
event, prev_state_ids, for_verification=True
)
- auth_events = await self.store.get_events(auth_events_ids)
- auth_events = {(e.type, e.state_key): e for e in auth_events.values()}
+ auth_events_map = await self.store.get_events(auth_events_ids)
+ auth_events = {(e.type, e.state_key): e for e in auth_events_map.values()}
room_version = await self.store.get_room_version_id(event.room_id)
room_version_obj = KNOWN_ROOM_VERSIONS[room_version]
@@ -1176,8 +1210,14 @@ class EventCreationHandler(object):
event.internal_metadata.proactively_send = False
+ # Since this is a dummy-event it is OK if it is sent by a
+ # shadow-banned user.
await self.send_nonmember_event(
- requester, event, context, ratelimit=False
+ requester,
+ event,
+ context,
+ ratelimit=False,
+ ignore_shadow_ban=True,
)
dummy_event_sent = True
break
diff --git a/synapse/handlers/oidc_handler.py b/synapse/handlers/oidc_handler.py
index fa5ee5de..1b06f317 100644
--- a/synapse/handlers/oidc_handler.py
+++ b/synapse/handlers/oidc_handler.py
@@ -12,7 +12,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-import json
import logging
from typing import TYPE_CHECKING, Dict, Generic, List, Optional, Tuple, TypeVar
from urllib.parse import urlencode
@@ -38,8 +37,8 @@ from synapse.config import ConfigError
from synapse.http.server import respond_with_html
from synapse.http.site import SynapseRequest
from synapse.logging.context import make_deferred_yieldable
-from synapse.push.mailer import load_jinja2_templates
from synapse.types import UserID, map_username_to_mxid_localpart
+from synapse.util import json_decoder
if TYPE_CHECKING:
from synapse.server import HomeServer
@@ -94,6 +93,7 @@ class OidcHandler:
"""
def __init__(self, hs: "HomeServer"):
+ self.hs = hs
self._callback_url = hs.config.oidc_callback_url # type: str
self._scopes = hs.config.oidc_scopes # type: List[str]
self._client_auth = ClientAuth(
@@ -123,9 +123,7 @@ class OidcHandler:
self._hostname = hs.hostname # type: str
self._server_name = hs.config.server_name # type: str
self._macaroon_secret_key = hs.config.macaroon_secret_key
- self._error_template = load_jinja2_templates(
- hs.config.sso_template_dir, ["sso_error.html"]
- )[0]
+ self._error_template = hs.config.sso_error_template
# identifier for the external_ids table
self._auth_provider_id = "oidc"
@@ -370,7 +368,7 @@ class OidcHandler:
# and check for an error field. If not, we respond with a generic
# error message.
try:
- resp = json.loads(resp_body.decode("utf-8"))
+ resp = json_decoder.decode(resp_body.decode("utf-8"))
error = resp["error"]
description = resp.get("error_description", error)
except (ValueError, KeyError):
@@ -387,7 +385,7 @@ class OidcHandler:
# Since it is a not a 5xx code, body should be a valid JSON. It will
# raise if not.
- resp = json.loads(resp_body.decode("utf-8"))
+ resp = json_decoder.decode(resp_body.decode("utf-8"))
if "error" in resp:
error = resp["error"]
@@ -692,9 +690,17 @@ class OidcHandler:
self._render_error(request, "invalid_token", str(e))
return
+ # Pull out the user-agent and IP from the request.
+ user_agent = request.requestHeaders.getRawHeaders(b"User-Agent", default=[b""])[
+ 0
+ ].decode("ascii", "surrogateescape")
+ ip_address = self.hs.get_ip_from_request(request)
+
# Call the mapper to register/login the user
try:
- user_id = await self._map_userinfo_to_user(userinfo, token)
+ user_id = await self._map_userinfo_to_user(
+ userinfo, token, user_agent, ip_address
+ )
except MappingException as e:
logger.exception("Could not map user")
self._render_error(request, "mapping_error", str(e))
@@ -831,7 +837,9 @@ class OidcHandler:
now = self._clock.time_msec()
return now < expiry
- async def _map_userinfo_to_user(self, userinfo: UserInfo, token: Token) -> str:
+ async def _map_userinfo_to_user(
+ self, userinfo: UserInfo, token: Token, user_agent: str, ip_address: str
+ ) -> str:
"""Maps a UserInfo object to a mxid.
UserInfo should have a claim that uniquely identifies users. This claim
@@ -846,6 +854,8 @@ class OidcHandler:
Args:
userinfo: an object representing the user
token: a dict with the tokens obtained from the provider
+ user_agent: The user agent of the client making the request.
+ ip_address: The IP address of the client making the request.
Raises:
MappingException: if there was an error while mapping some properties
@@ -859,6 +869,9 @@ class OidcHandler:
raise MappingException(
"Failed to extract subject from OIDC response: %s" % (e,)
)
+ # Some OIDC providers use integer IDs, but Synapse expects external IDs
+ # to be strings.
+ remote_user_id = str(remote_user_id)
logger.info(
"Looking for existing mapping for user %s:%s",
@@ -902,7 +915,9 @@ class OidcHandler:
# It's the first time this user is logging in and the mapped mxid was
# not taken, register the user
registered_user_id = await self._registration_handler.register_user(
- localpart=localpart, default_display_name=attributes["display_name"],
+ localpart=localpart,
+ default_display_name=attributes["display_name"],
+ user_agent_ips=(user_agent, ip_address),
)
await self._datastore.record_user_external_id(
diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py
index 487420bb..6067585f 100644
--- a/synapse/handlers/pagination.py
+++ b/synapse/handlers/pagination.py
@@ -14,23 +14,30 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
+from typing import TYPE_CHECKING, Any, Dict, Optional, Set
from twisted.python.failure import Failure
from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import SynapseError
+from synapse.api.filtering import Filter
from synapse.logging.context import run_in_background
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage.state import StateFilter
-from synapse.types import RoomStreamToken
+from synapse.streams.config import PaginationConfig
+from synapse.types import Requester, RoomStreamToken
from synapse.util.async_helpers import ReadWriteLock
from synapse.util.stringutils import random_string
from synapse.visibility import filter_events_for_client
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
+
logger = logging.getLogger(__name__)
-class PurgeStatus(object):
+class PurgeStatus:
"""Object tracking the status of a purge request
This class contains information on the progress of a purge request, for
@@ -58,14 +65,14 @@ class PurgeStatus(object):
return {"status": PurgeStatus.STATUS_TEXT[self.status]}
-class PaginationHandler(object):
+class PaginationHandler:
"""Handles pagination and purge history requests.
These are in the same handler due to the fact we need to block clients
paginating during a purge.
"""
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
self.hs = hs
self.auth = hs.get_auth()
self.store = hs.get_datastore()
@@ -75,13 +82,16 @@ class PaginationHandler(object):
self._server_name = hs.hostname
self.pagination_lock = ReadWriteLock()
- self._purges_in_progress_by_room = set()
+ self._purges_in_progress_by_room = set() # type: Set[str]
# map from purge id to PurgeStatus
- self._purges_by_id = {}
+ self._purges_by_id = {} # type: Dict[str, PurgeStatus]
self._event_serializer = hs.get_event_client_serializer()
self._retention_default_max_lifetime = hs.config.retention_default_max_lifetime
+ self._retention_allowed_lifetime_min = hs.config.retention_allowed_lifetime_min
+ self._retention_allowed_lifetime_max = hs.config.retention_allowed_lifetime_max
+
if hs.config.retention_enabled:
# Run the purge jobs described in the configuration file.
for job in hs.config.retention_purge_jobs:
@@ -96,7 +106,9 @@ class PaginationHandler(object):
job["longest_max_lifetime"],
)
- async def purge_history_for_rooms_in_range(self, min_ms, max_ms):
+ async def purge_history_for_rooms_in_range(
+ self, min_ms: Optional[int], max_ms: Optional[int]
+ ):
"""Purge outdated events from rooms within the given retention range.
If a default retention policy is defined in the server's configuration and its
@@ -104,14 +116,14 @@ class PaginationHandler(object):
retention policy.
Args:
- min_ms (int|None): Duration in milliseconds that define the lower limit of
+ min_ms: Duration in milliseconds that define the lower limit of
the range to handle (exclusive). If None, it means that the range has no
lower limit.
- max_ms (int|None): Duration in milliseconds that define the upper limit of
+ max_ms: Duration in milliseconds that define the upper limit of
the range to handle (inclusive). If None, it means that the range has no
upper limit.
"""
- # We want the storage layer to to include rooms with no retention policy in its
+ # We want the storage layer to include rooms with no retention policy in its
# return value only if a default retention policy is defined in the server's
# configuration and that policy's 'max_lifetime' is either lower (or equal) than
# max_ms or higher than min_ms (or both).
@@ -152,13 +164,32 @@ class PaginationHandler(object):
)
continue
- max_lifetime = retention_policy["max_lifetime"]
+ # If max_lifetime is None, it means that the room has no retention policy.
+ # Given we only retrieve such rooms when there's a default retention policy
+ # defined in the server's configuration, we can safely assume that's the
+ # case and use it for this room.
+ max_lifetime = (
+ retention_policy["max_lifetime"] or self._retention_default_max_lifetime
+ )
- if max_lifetime is None:
- # If max_lifetime is None, it means that include_null equals True,
- # therefore we can safely assume that there is a default policy defined
- # in the server's configuration.
- max_lifetime = self._retention_default_max_lifetime
+ # Cap the effective max_lifetime to be within the range allowed in the
+ # config.
+ # We do this in two steps:
+ # 1. Make sure it's higher or equal to the minimum allowed value, and if
+ # it's not replace it with that value. This is because the server
+ # operator can be required to not delete information before a given
+ # time, e.g. to comply with freedom of information laws.
+ # 2. Make sure the resulting value is lower or equal to the maximum allowed
+ # value, and if it's not replace it with that value. This is because the
+ # server operator can be required to delete any data after a specific
+ # amount of time.
+ if self._retention_allowed_lifetime_min is not None:
+ max_lifetime = max(self._retention_allowed_lifetime_min, max_lifetime)
+
+ if self._retention_allowed_lifetime_max is not None:
+ max_lifetime = min(max_lifetime, self._retention_allowed_lifetime_max)
+
+ logger.debug("[purge] max_lifetime for room %s: %s", room_id, max_lifetime)
# Figure out what token we should start purging at.
ts = self.clock.time_msec() - max_lifetime
@@ -195,18 +226,19 @@ class PaginationHandler(object):
"_purge_history", self._purge_history, purge_id, room_id, token, True,
)
- def start_purge_history(self, room_id, token, delete_local_events=False):
+ def start_purge_history(
+ self, room_id: str, token: str, delete_local_events: bool = False
+ ) -> str:
"""Start off a history purge on a room.
Args:
- room_id (str): The room to purge from
-
- token (str): topological token to delete events before
- delete_local_events (bool): True to delete local events as well as
+ room_id: The room to purge from
+ token: topological token to delete events before
+ delete_local_events: True to delete local events as well as
remote ones
Returns:
- str: unique ID for this purge transaction.
+ unique ID for this purge transaction.
"""
if room_id in self._purges_in_progress_by_room:
raise SynapseError(
@@ -225,15 +257,16 @@ class PaginationHandler(object):
)
return purge_id
- async def _purge_history(self, purge_id, room_id, token, delete_local_events):
+ async def _purge_history(
+ self, purge_id: str, room_id: str, token: str, delete_local_events: bool
+ ) -> None:
"""Carry out a history purge on a room.
Args:
- purge_id (str): The id for this purge
- room_id (str): The room to purge from
- token (str): topological token to delete events before
- delete_local_events (bool): True to delete local events as well as
- remote ones
+ purge_id: The id for this purge
+ room_id: The room to purge from
+ token: topological token to delete events before
+ delete_local_events: True to delete local events as well as remote ones
"""
self._purges_in_progress_by_room.add(room_id)
try:
@@ -258,20 +291,17 @@ class PaginationHandler(object):
self.hs.get_reactor().callLater(24 * 3600, clear_purge)
- def get_purge_status(self, purge_id):
+ def get_purge_status(self, purge_id: str) -> Optional[PurgeStatus]:
"""Get the current status of an active purge
Args:
- purge_id (str): purge_id returned by start_purge_history
-
- Returns:
- PurgeStatus|None
+ purge_id: purge_id returned by start_purge_history
"""
return self._purges_by_id.get(purge_id)
- async def purge_room(self, room_id):
+ async def purge_room(self, room_id: str) -> None:
"""Purge the given room from the database"""
- with (await self.pagination_lock.write(room_id)):
+ with await self.pagination_lock.write(room_id):
# check we know about the room
await self.store.get_room_version_id(room_id)
@@ -285,23 +315,22 @@ class PaginationHandler(object):
async def get_messages(
self,
- requester,
- room_id=None,
- pagin_config=None,
- as_client_event=True,
- event_filter=None,
- ):
+ requester: Requester,
+ room_id: str,
+ pagin_config: PaginationConfig,
+ as_client_event: bool = True,
+ event_filter: Optional[Filter] = None,
+ ) -> Dict[str, Any]:
"""Get messages in a room.
Args:
- requester (Requester): The user requesting messages.
- room_id (str): The room they want messages from.
- pagin_config (synapse.api.streams.PaginationConfig): The pagination
- config rules to apply, if any.
- as_client_event (bool): True to get events in client-server format.
- event_filter (Filter): Filter to apply to results or None
+ requester: The user requesting messages.
+ room_id: The room they want messages from.
+ pagin_config: The pagination config rules to apply, if any.
+ as_client_event: True to get events in client-server format.
+ event_filter: Filter to apply to results or None
Returns:
- dict: Pagination API results
+ Pagination API results
"""
user_id = requester.user.to_string()
@@ -321,7 +350,7 @@ class PaginationHandler(object):
source_config = pagin_config.get_source_config("room")
- with (await self.pagination_lock.read(room_id)):
+ with await self.pagination_lock.read(room_id):
(
membership,
member_event_id,
@@ -333,9 +362,9 @@ class PaginationHandler(object):
# if we're going backwards, we might need to backfill. This
# requires that we have a topo token.
if room_token.topological:
- max_topo = room_token.topological
+ curr_topo = room_token.topological
else:
- max_topo = await self.store.get_max_topological_token(
+ curr_topo = await self.store.get_current_topological_token(
room_id, room_token.stream
)
@@ -343,15 +372,19 @@ class PaginationHandler(object):
# If they have left the room then clamp the token to be before
# they left the room, to save the effort of loading from the
# database.
+
+ # This is only None if the room is world_readable, in which
+ # case "JOIN" would have been returned.
+ assert member_event_id
+
leave_token = await self.store.get_topological_token_for_event(
member_event_id
)
- leave_token = RoomStreamToken.parse(leave_token)
- if leave_token.topological < max_topo:
+ if RoomStreamToken.parse(leave_token).topological < curr_topo:
source_config.from_key = str(leave_token)
await self.hs.get_handlers().federation_handler.maybe_backfill(
- room_id, max_topo
+ room_id, curr_topo, limit=source_config.limit,
)
events, next_key = await self.store.paginate_room_events(
@@ -394,8 +427,8 @@ class PaginationHandler(object):
)
if state_ids:
- state = await self.store.get_events(list(state_ids.values()))
- state = state.values()
+ state_dict = await self.store.get_events(list(state_ids.values()))
+ state = state_dict.values()
time_now = self.clock.time_msec()
diff --git a/synapse/handlers/password_policy.py b/synapse/handlers/password_policy.py
index d06b1102..88e2f872 100644
--- a/synapse/handlers/password_policy.py
+++ b/synapse/handlers/password_policy.py
@@ -22,7 +22,7 @@ from synapse.api.errors import Codes, PasswordRefusedError
logger = logging.getLogger(__name__)
-class PasswordPolicyHandler(object):
+class PasswordPolicyHandler:
def __init__(self, hs):
self.policy = hs.config.password_policy
self.enabled = hs.config.password_policy_enabled
diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py
index 5387b372..91a3aec1 100644
--- a/synapse/handlers/presence.py
+++ b/synapse/handlers/presence.py
@@ -33,14 +33,14 @@ from typing_extensions import ContextManager
import synapse.metrics
from synapse.api.constants import EventTypes, Membership, PresenceState
from synapse.api.errors import SynapseError
+from synapse.api.presence import UserPresenceState
from synapse.logging.context import run_in_background
from synapse.logging.utils import log_function
from synapse.metrics import LaterGauge
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.state import StateHandler
from synapse.storage.databases.main import DataStore
-from synapse.storage.presence import UserPresenceState
-from synapse.types import JsonDict, UserID, get_domain_from_id
+from synapse.types import Collection, JsonDict, UserID, get_domain_from_id
from synapse.util.async_helpers import Linearizer
from synapse.util.caches.descriptors import cached
from synapse.util.metrics import Measure
@@ -1010,7 +1010,7 @@ def format_user_presence_state(state, now, include_user_id=True):
return content
-class PresenceEventSource(object):
+class PresenceEventSource:
def __init__(self, hs):
# We can't call get_presence_handler here because there's a cycle:
#
@@ -1318,7 +1318,7 @@ async def get_interested_parties(
async def get_interested_remotes(
store: DataStore, states: List[UserPresenceState], state_handler: StateHandler
-) -> List[Tuple[List[str], List[UserPresenceState]]]:
+) -> List[Tuple[Collection[str], List[UserPresenceState]]]:
"""Given a list of presence states figure out which remote servers
should be sent which.
@@ -1334,7 +1334,7 @@ async def get_interested_remotes(
each tuple the list of UserPresenceState should be sent to each
destination
"""
- hosts_and_states = []
+ hosts_and_states = [] # type: List[Tuple[Collection[str], List[UserPresenceState]]]
# First we look up the rooms each user is in (as well as any explicit
# subscriptions), then for each distinct room we look up the remote
diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py
index 31a2e5ea..0cb8fad8 100644
--- a/synapse/handlers/profile.py
+++ b/synapse/handlers/profile.py
@@ -14,6 +14,7 @@
# limitations under the License.
import logging
+import random
from synapse.api.errors import (
AuthError,
@@ -160,6 +161,9 @@ class BaseProfileHandler(BaseHandler):
Codes.FORBIDDEN,
)
+ if not isinstance(new_displayname, str):
+ raise SynapseError(400, "Invalid displayname")
+
if len(new_displayname) > MAX_DISPLAYNAME_LEN:
raise SynapseError(
400, "Displayname is too long (max %i)" % (MAX_DISPLAYNAME_LEN,)
@@ -213,8 +217,14 @@ class BaseProfileHandler(BaseHandler):
async def set_avatar_url(
self, target_user, requester, new_avatar_url, by_admin=False
):
- """target_user is the user whose avatar_url is to be changed;
- auth_user is the user attempting to make this change."""
+ """Set a new avatar URL for a user.
+
+ Args:
+ target_user (UserID): the user whose avatar URL is to be changed.
+ requester (Requester): The user attempting to make this change.
+ new_avatar_url (str): The avatar URL to give this user.
+ by_admin (bool): Whether this change was made by an administrator.
+ """
if not self.hs.is_mine(target_user):
raise SynapseError(400, "User is not hosted on this homeserver")
@@ -228,6 +238,9 @@ class BaseProfileHandler(BaseHandler):
400, "Changing avatar is disabled on this server", Codes.FORBIDDEN
)
+ if not isinstance(new_avatar_url, str):
+ raise SynapseError(400, "Invalid displayname")
+
if len(new_avatar_url) > MAX_AVATAR_URL_LEN:
raise SynapseError(
400, "Avatar URL is too long (max %i)" % (MAX_AVATAR_URL_LEN,)
@@ -278,6 +291,12 @@ class BaseProfileHandler(BaseHandler):
await self.ratelimit(requester)
+ # Do not actually update the room state for shadow-banned users.
+ if requester.shadow_banned:
+ # We randomly sleep a bit just to annoy the requester.
+ await self.clock.sleep(random.randint(1, 10))
+ return
+
room_ids = await self.store.get_rooms_for_user(target_user.to_string())
for room_id in room_ids:
diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py
index f922d8a5..2cc6c2eb 100644
--- a/synapse/handlers/receipts.py
+++ b/synapse/handlers/receipts.py
@@ -123,7 +123,7 @@ class ReceiptsHandler(BaseHandler):
await self.federation.send_read_receipt(receipt)
-class ReceiptEventSource(object):
+class ReceiptEventSource:
def __init__(self, hs):
self.store = hs.get_datastore()
diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py
index c94209ab..cde2dbca 100644
--- a/synapse/handlers/register.py
+++ b/synapse/handlers/register.py
@@ -26,6 +26,7 @@ from synapse.replication.http.register import (
ReplicationPostRegisterActionsServlet,
ReplicationRegisterServlet,
)
+from synapse.spam_checker_api import RegistrationBehaviour
from synapse.storage.state import StateFilter
from synapse.types import RoomAlias, UserID, create_requester
@@ -52,6 +53,8 @@ class RegistrationHandler(BaseHandler):
self.macaroon_gen = hs.get_macaroon_generator()
self._server_notices_mxid = hs.config.server_notices_mxid
+ self.spam_checker = hs.get_spam_checker()
+
if hs.config.worker_app:
self._register_client = ReplicationRegisterServlet.make_client(hs)
self._register_device_client = RegisterDeviceReplicationServlet.make_client(
@@ -124,7 +127,9 @@ class RegistrationHandler(BaseHandler):
try:
int(localpart)
raise SynapseError(
- 400, "Numeric user IDs are reserved for guest users."
+ 400,
+ "Numeric user IDs are reserved for guest users.",
+ errcode=Codes.INVALID_USERNAME,
)
except ValueError:
pass
@@ -142,6 +147,7 @@ class RegistrationHandler(BaseHandler):
address=None,
bind_emails=[],
by_admin=False,
+ user_agent_ips=None,
):
"""Registers a new client on the server.
@@ -159,6 +165,8 @@ class RegistrationHandler(BaseHandler):
bind_emails (List[str]): list of emails to bind to this account.
by_admin (bool): True if this registration is being made via the
admin api, otherwise False.
+ user_agent_ips (List[(str, str)]): Tuples of IP addresses and user-agents used
+ during the registration process.
Returns:
str: user_id
Raises:
@@ -166,6 +174,24 @@ class RegistrationHandler(BaseHandler):
"""
self.check_registration_ratelimit(address)
+ result = self.spam_checker.check_registration_for_spam(
+ threepid, localpart, user_agent_ips or [],
+ )
+
+ if result == RegistrationBehaviour.DENY:
+ logger.info(
+ "Blocked registration of %r", localpart,
+ )
+ # We return a 429 to make it not obvious that they've been
+ # denied.
+ raise SynapseError(429, "Rate limited")
+
+ shadow_banned = result == RegistrationBehaviour.SHADOW_BAN
+ if shadow_banned:
+ logger.info(
+ "Shadow banning registration of %r", localpart,
+ )
+
# do not check_auth_blocking if the call is coming through the Admin API
if not by_admin:
await self.auth.check_auth_blocking(threepid=threepid)
@@ -194,6 +220,7 @@ class RegistrationHandler(BaseHandler):
admin=admin,
user_type=user_type,
address=address,
+ shadow_banned=shadow_banned,
)
if self.hs.config.user_directory_search_all_users:
@@ -224,6 +251,7 @@ class RegistrationHandler(BaseHandler):
make_guest=make_guest,
create_profile_with_displayname=default_display_name,
address=address,
+ shadow_banned=shadow_banned,
)
# Successfully registered
@@ -529,6 +557,7 @@ class RegistrationHandler(BaseHandler):
admin=False,
user_type=None,
address=None,
+ shadow_banned=False,
):
"""Register user in the datastore.
@@ -546,6 +575,7 @@ class RegistrationHandler(BaseHandler):
user_type (str|None): type of user. One of the values from
api.constants.UserTypes, or None for a normal user.
address (str|None): the IP address used to perform the registration.
+ shadow_banned (bool): Whether to shadow-ban the user
Returns:
Awaitable
@@ -561,6 +591,7 @@ class RegistrationHandler(BaseHandler):
admin=admin,
user_type=user_type,
address=address,
+ shadow_banned=shadow_banned,
)
else:
return self.store.register_user(
@@ -572,6 +603,7 @@ class RegistrationHandler(BaseHandler):
create_profile_with_displayname=create_profile_with_displayname,
admin=admin,
user_type=user_type,
+ shadow_banned=shadow_banned,
)
async def register_device(
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index a8545255..a29305f6 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -20,9 +20,10 @@
import itertools
import logging
import math
+import random
import string
from collections import OrderedDict
-from typing import Awaitable, Optional, Tuple
+from typing import TYPE_CHECKING, Any, Awaitable, Dict, List, Optional, Tuple
from synapse.api.constants import (
EventTypes,
@@ -32,11 +33,15 @@ from synapse.api.constants import (
RoomEncryptionAlgorithms,
)
from synapse.api.errors import AuthError, Codes, NotFoundError, StoreError, SynapseError
+from synapse.api.filtering import Filter
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion
+from synapse.events import EventBase
from synapse.events.utils import copy_power_levels_contents
from synapse.http.endpoint import parse_and_validate_server_name
from synapse.storage.state import StateFilter
from synapse.types import (
+ JsonDict,
+ MutableStateMap,
Requester,
RoomAlias,
RoomID,
@@ -47,12 +52,15 @@ from synapse.types import (
create_requester,
)
from synapse.util import stringutils
-from synapse.util.async_helpers import Linearizer, maybe_awaitable
+from synapse.util.async_helpers import Linearizer
from synapse.util.caches.response_cache import ResponseCache
from synapse.visibility import filter_events_for_client
from ._base import BaseHandler
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
logger = logging.getLogger(__name__)
id_server_scheme = "https://"
@@ -61,7 +69,7 @@ FIVE_MINUTES_IN_MS = 5 * 60 * 1000
class RoomCreationHandler(BaseHandler):
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super(RoomCreationHandler, self).__init__(hs)
self.spam_checker = hs.get_spam_checker()
@@ -92,7 +100,7 @@ class RoomCreationHandler(BaseHandler):
"guest_can_join": False,
"power_level_content_override": {},
},
- }
+ } # type: Dict[str, Dict[str, Any]]
# Modify presets to selectively enable encryption by default per homeserver config
for preset_name, preset_config in self._presets_dict.items():
@@ -129,6 +137,9 @@ class RoomCreationHandler(BaseHandler):
Returns:
the new room id
+
+ Raises:
+ ShadowBanError if the requester is shadow-banned.
"""
await self.ratelimit(requester)
@@ -164,6 +175,15 @@ class RoomCreationHandler(BaseHandler):
async def _upgrade_room(
self, requester: Requester, old_room_id: str, new_version: RoomVersion
):
+ """
+ Args:
+ requester: the user requesting the upgrade
+ old_room_id: the id of the room to be replaced
+ new_versions: the version to upgrade the room to
+
+ Raises:
+ ShadowBanError if the requester is shadow-banned.
+ """
user_id = requester.user.to_string()
# start by allocating a new room id
@@ -215,6 +235,9 @@ class RoomCreationHandler(BaseHandler):
old_room_state = await tombstone_context.get_current_state_ids()
+ # We know the tombstone event isn't an outlier so it has current state.
+ assert old_room_state is not None
+
# update any aliases
await self._move_aliases_to_new_room(
requester, old_room_id, new_room_id, old_room_state
@@ -247,6 +270,9 @@ class RoomCreationHandler(BaseHandler):
old_room_id: the id of the room to be replaced
new_room_id: the id of the replacement room
old_room_state: the state map for the old room
+
+ Raises:
+ ShadowBanError if the requester is shadow-banned.
"""
old_room_pl_event_id = old_room_state.get((EventTypes.PowerLevels, ""))
@@ -425,7 +451,7 @@ class RoomCreationHandler(BaseHandler):
old_room_member_state_events = await self.store.get_events(
old_room_member_state_ids.values()
)
- for k, old_event in old_room_member_state_events.items():
+ for old_event in old_room_member_state_events.values():
# Only transfer ban events
if (
"membership" in old_event.content
@@ -528,17 +554,21 @@ class RoomCreationHandler(BaseHandler):
logger.error("Unable to send updated alias events in new room: %s", e)
async def create_room(
- self, requester, config, ratelimit=True, creator_join_profile=None
+ self,
+ requester: Requester,
+ config: JsonDict,
+ ratelimit: bool = True,
+ creator_join_profile: Optional[JsonDict] = None,
) -> Tuple[dict, int]:
""" Creates a new room.
Args:
- requester (synapse.types.Requester):
+ requester:
The user who requested the room creation.
- config (dict) : A dict of configuration options.
- ratelimit (bool): set to False to disable the rate limiter
+ config : A dict of configuration options.
+ ratelimit: set to False to disable the rate limiter
- creator_join_profile (dict|None):
+ creator_join_profile:
Set to override the displayname and avatar for the creating
user in this room. If unset, displayname and avatar will be
derived from the user's profile. If set, should contain the
@@ -601,6 +631,7 @@ class RoomCreationHandler(BaseHandler):
Codes.UNSUPPORTED_ROOM_VERSION,
)
+ room_alias = None
if "room_alias_name" in config:
for wchar in string.whitespace:
if wchar in config["room_alias_name"]:
@@ -611,9 +642,8 @@ class RoomCreationHandler(BaseHandler):
if mapping:
raise SynapseError(400, "Room alias already taken", Codes.ROOM_IN_USE)
- else:
- room_alias = None
+ invite_3pid_list = config.get("invite_3pid", [])
invite_list = config.get("invite", [])
for i in invite_list:
try:
@@ -622,6 +652,14 @@ class RoomCreationHandler(BaseHandler):
except Exception:
raise SynapseError(400, "Invalid user_id: %s" % (i,))
+ if (invite_list or invite_3pid_list) and requester.shadow_banned:
+ # We randomly sleep a bit just to annoy the requester.
+ await self.clock.sleep(random.randint(1, 10))
+
+ # Allow the request to go through, but remove any associated invites.
+ invite_3pid_list = []
+ invite_list = []
+
await self.event_creation_handler.assert_accepted_privacy_policy(requester)
power_level_content_override = config.get("power_level_content_override")
@@ -636,8 +674,6 @@ class RoomCreationHandler(BaseHandler):
% (user_id,),
)
- invite_3pid_list = config.get("invite_3pid", [])
-
visibility = config.get("visibility", None)
is_public = visibility == "public"
@@ -732,6 +768,8 @@ class RoomCreationHandler(BaseHandler):
if is_direct:
content["is_direct"] = is_direct
+ # Note that update_membership with an action of "invite" can raise a
+ # ShadowBanError, but this was handled above by emptying invite_list.
_, last_stream_id = await self.room_member_handler.update_membership(
requester,
UserID.from_string(invitee),
@@ -746,6 +784,8 @@ class RoomCreationHandler(BaseHandler):
id_access_token = invite_3pid.get("id_access_token") # optional
address = invite_3pid["address"]
medium = invite_3pid["medium"]
+ # Note that do_3pid_invite can raise a ShadowBanError, but this was
+ # handled above by emptying invite_3pid_list.
last_stream_id = await self.hs.get_room_member_handler().do_3pid_invite(
room_id,
requester.user,
@@ -771,23 +811,30 @@ class RoomCreationHandler(BaseHandler):
async def _send_events_for_new_room(
self,
- creator, # A Requester object.
- room_id,
- preset_config,
- invite_list,
- initial_state,
- creation_content,
- room_alias=None,
- power_level_content_override=None, # Doesn't apply when initial state has power level state event content
- creator_join_profile=None,
+ creator: Requester,
+ room_id: str,
+ preset_config: str,
+ invite_list: List[str],
+ initial_state: MutableStateMap,
+ creation_content: JsonDict,
+ room_alias: Optional[RoomAlias] = None,
+ power_level_content_override: Optional[JsonDict] = None,
+ creator_join_profile: Optional[JsonDict] = None,
) -> int:
"""Sends the initial events into a new room.
+ `power_level_content_override` doesn't apply when initial state has
+ power level state event content.
+
Returns:
The stream_id of the last event persisted.
"""
- def create(etype, content, **kwargs):
+ creator_id = creator.user.to_string()
+
+ event_keys = {"room_id": room_id, "sender": creator_id, "state_key": ""}
+
+ def create(etype: str, content: JsonDict, **kwargs) -> JsonDict:
e = {"type": etype, "content": content}
e.update(event_keys)
@@ -795,23 +842,21 @@ class RoomCreationHandler(BaseHandler):
return e
- async def send(etype, content, **kwargs) -> int:
+ async def send(etype: str, content: JsonDict, **kwargs) -> int:
event = create(etype, content, **kwargs)
logger.debug("Sending %s in new room", etype)
+ # Allow these events to be sent even if the user is shadow-banned to
+ # allow the room creation to complete.
(
_,
last_stream_id,
) = await self.event_creation_handler.create_and_send_nonmember_event(
- creator, event, ratelimit=False
+ creator, event, ratelimit=False, ignore_shadow_ban=True,
)
return last_stream_id
config = self._presets_dict[preset_config]
- creator_id = creator.user.to_string()
-
- event_keys = {"room_id": room_id, "sender": creator_id, "state_key": ""}
-
creation_content.update({"creator": creator_id})
await send(etype=EventTypes.Create, content=creation_content)
@@ -852,7 +897,7 @@ class RoomCreationHandler(BaseHandler):
"kick": 50,
"redact": 50,
"invite": 50,
- }
+ } # type: JsonDict
if config["original_invitees_have_ops"]:
for invitee in invite_list:
@@ -906,7 +951,7 @@ class RoomCreationHandler(BaseHandler):
return last_sent_stream_id
async def _generate_room_id(
- self, creator_id: str, is_public: str, room_version: RoomVersion,
+ self, creator_id: str, is_public: bool, room_version: RoomVersion,
):
# autogen room IDs and try to create it. We may clash, so just
# try a few times till one goes through, giving up eventually.
@@ -929,24 +974,31 @@ class RoomCreationHandler(BaseHandler):
raise StoreError(500, "Couldn't generate a room ID.")
-class RoomContextHandler(object):
- def __init__(self, hs):
+class RoomContextHandler:
+ def __init__(self, hs: "HomeServer"):
self.hs = hs
self.store = hs.get_datastore()
self.storage = hs.get_storage()
self.state_store = self.storage.state
- async def get_event_context(self, user, room_id, event_id, limit, event_filter):
+ async def get_event_context(
+ self,
+ user: UserID,
+ room_id: str,
+ event_id: str,
+ limit: int,
+ event_filter: Optional[Filter],
+ ) -> Optional[JsonDict]:
"""Retrieves events, pagination tokens and state around a given event
in a room.
Args:
- user (UserID)
- room_id (str)
- event_id (str)
- limit (int): The maximum number of events to return in total
+ user
+ room_id
+ event_id
+ limit: The maximum number of events to return in total
(excluding state).
- event_filter (Filter|None): the filter to apply to the events returned
+ event_filter: the filter to apply to the events returned
(excluding the target event_id)
Returns:
@@ -1032,13 +1084,19 @@ class RoomContextHandler(object):
return results
-class RoomEventSource(object):
- def __init__(self, hs):
+class RoomEventSource:
+ def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastore()
async def get_new_events(
- self, user, from_key, limit, room_ids, is_guest, explicit_room_id=None
- ):
+ self,
+ user: UserID,
+ from_key: str,
+ limit: int,
+ room_ids: List[str],
+ is_guest: bool,
+ explicit_room_id: Optional[str] = None,
+ ) -> Tuple[List[EventBase], str]:
# We just ignore the key for now.
to_key = self.get_current_key()
@@ -1088,7 +1146,7 @@ class RoomEventSource(object):
return self.store.get_room_events_max_id(room_id)
-class RoomShutdownHandler(object):
+class RoomShutdownHandler:
DEFAULT_MESSAGE = (
"Sharing illegal content on this server is not permitted and rooms in"
@@ -1096,7 +1154,7 @@ class RoomShutdownHandler(object):
)
DEFAULT_ROOM_NAME = "Content Violation Notification"
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
self.hs = hs
self.room_member_handler = hs.get_room_member_handler()
self._room_creation_handler = hs.get_room_creation_handler()
@@ -1272,9 +1330,7 @@ class RoomShutdownHandler(object):
ratelimit=False,
)
- aliases_for_room = await maybe_awaitable(
- self.store.get_aliases_for_room(room_id)
- )
+ aliases_for_room = await self.store.get_aliases_for_room(room_id)
await self.store.update_aliases_for_room(
room_id, new_room_id, requester_user_id
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index 9fcabb22..32b7e323 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -15,14 +15,21 @@
import abc
import logging
+import random
from http import HTTPStatus
-from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple, Union
from unpaddedbase64 import encode_base64
from synapse import types
from synapse.api.constants import MAX_DEPTH, EventTypes, Membership
-from synapse.api.errors import AuthError, Codes, LimitExceededError, SynapseError
+from synapse.api.errors import (
+ AuthError,
+ Codes,
+ LimitExceededError,
+ ShadowBanError,
+ SynapseError,
+)
from synapse.api.ratelimiting import Ratelimiter
from synapse.api.room_versions import EventFormatVersions
from synapse.crypto.event_signing import compute_event_reference_hash
@@ -31,7 +38,7 @@ from synapse.events.builder import create_local_event_from_event_dict
from synapse.events.snapshot import EventContext
from synapse.events.validator import EventValidator
from synapse.storage.roommember import RoomsForUser
-from synapse.types import Collection, JsonDict, Requester, RoomAlias, RoomID, UserID
+from synapse.types import JsonDict, Requester, RoomAlias, RoomID, StateMap, UserID
from synapse.util.async_helpers import Linearizer
from synapse.util.distributor import user_joined_room, user_left_room
@@ -44,7 +51,7 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
-class RoomMemberHandler(object):
+class RoomMemberHandler:
# TODO(paul): This handler currently contains a messy conflation of
# low-level API that works on UserID objects and so on, and REST-level
# API that takes ID strings and returns pagination chunks. These concerns
@@ -169,7 +176,7 @@ class RoomMemberHandler(object):
target: UserID,
room_id: str,
membership: str,
- prev_event_ids: Collection[str],
+ prev_event_ids: List[str],
txn_id: Optional[str] = None,
ratelimit: bool = True,
content: Optional[dict] = None,
@@ -301,6 +308,31 @@ class RoomMemberHandler(object):
content: Optional[dict] = None,
require_consent: bool = True,
) -> Tuple[str, int]:
+ """Update a user's membership in a room.
+
+ Params:
+ requester: The user who is performing the update.
+ target: The user whose membership is being updated.
+ room_id: The room ID whose membership is being updated.
+ action: The membership change, see synapse.api.constants.Membership.
+ txn_id: The transaction ID, if given.
+ remote_room_hosts: Remote servers to send the update to.
+ third_party_signed: Information from a 3PID invite.
+ ratelimit: Whether to rate limit the request.
+ content: The content of the created event.
+ require_consent: Whether consent is required.
+
+ Returns:
+ A tuple of the new event ID and stream ID.
+
+ Raises:
+ ShadowBanError if a shadow-banned requester attempts to send an invite.
+ """
+ if action == Membership.INVITE and requester.shadow_banned:
+ # We randomly sleep a bit just to annoy the requester.
+ await self.clock.sleep(random.randint(1, 10))
+ raise ShadowBanError()
+
key = (room_id,)
with (await self.member_linearizer.queue(key)):
@@ -340,7 +372,7 @@ class RoomMemberHandler(object):
# later on.
content = dict(content)
- if not self.allow_per_room_profiles:
+ if not self.allow_per_room_profiles or requester.shadow_banned:
# Strip profile data, knowing that new profile data will be added to the
# event's content in event_creation_handler.create_event() using the target's
# global profile.
@@ -710,9 +742,7 @@ class RoomMemberHandler(object):
if prev_member_event.membership == Membership.JOIN:
await self._user_left_room(target_user, room_id)
- async def _can_guest_join(
- self, current_state_ids: Dict[Tuple[str, str], str]
- ) -> bool:
+ async def _can_guest_join(self, current_state_ids: StateMap[str]) -> bool:
"""
Returns whether a guest can join a room based on its current state.
"""
@@ -722,7 +752,7 @@ class RoomMemberHandler(object):
guest_access = await self.store.get_event(guest_access_id)
- return (
+ return bool(
guest_access
and guest_access.content
and "guest_access" in guest_access.content
@@ -779,6 +809,25 @@ class RoomMemberHandler(object):
txn_id: Optional[str],
id_access_token: Optional[str] = None,
) -> int:
+ """Invite a 3PID to a room.
+
+ Args:
+ room_id: The room to invite the 3PID to.
+ inviter: The user sending the invite.
+ medium: The 3PID's medium.
+ address: The 3PID's address.
+ id_server: The identity server to use.
+ requester: The user making the request.
+ txn_id: The transaction ID this is part of, or None if this is not
+ part of a transaction.
+ id_access_token: The optional identity server access token.
+
+ Returns:
+ The new stream ID.
+
+ Raises:
+ ShadowBanError if the requester has been shadow-banned.
+ """
if self.config.block_non_admin_invites:
is_requester_admin = await self.auth.is_server_admin(requester.user)
if not is_requester_admin:
@@ -786,6 +835,11 @@ class RoomMemberHandler(object):
403, "Invites have been disabled on this server", Codes.FORBIDDEN
)
+ if requester.shadow_banned:
+ # We randomly sleep a bit just to annoy the requester.
+ await self.clock.sleep(random.randint(1, 10))
+ raise ShadowBanError()
+
# We need to rate limit *before* we send out any 3PID invites, so we
# can't just rely on the standard ratelimiting of events.
await self.base_handler.ratelimit(requester)
@@ -810,6 +864,8 @@ class RoomMemberHandler(object):
)
if invitee:
+ # Note that update_membership with an action of "invite" can raise
+ # a ShadowBanError, but this was done above already.
_, stream_id = await self.update_membership(
requester, UserID.from_string(invitee), room_id, "invite", txn_id=txn_id
)
@@ -915,9 +971,7 @@ class RoomMemberHandler(object):
)
return stream_id
- async def _is_host_in_room(
- self, current_state_ids: Dict[Tuple[str, str], str]
- ) -> bool:
+ async def _is_host_in_room(self, current_state_ids: StateMap[str]) -> bool:
# Have we just created the room, and is this about to be the very
# first member event?
create_event_id = current_state_ids.get(("m.room.create", ""))
@@ -1048,7 +1102,7 @@ class RoomMemberMasterHandler(RoomMemberHandler):
return event_id, stream_id
# The room is too large. Leave.
- requester = types.create_requester(user, None, False, None)
+ requester = types.create_requester(user, None, False, False, None)
await self.update_membership(
requester=requester, target=user, room_id=room_id, action="leave"
)
diff --git a/synapse/handlers/saml_handler.py b/synapse/handlers/saml_handler.py
index c1fcb984..66b063f9 100644
--- a/synapse/handlers/saml_handler.py
+++ b/synapse/handlers/saml_handler.py
@@ -54,6 +54,7 @@ class Saml2SessionData:
class SamlHandler:
def __init__(self, hs: "synapse.server.HomeServer"):
+ self.hs = hs
self._saml_client = Saml2Client(hs.config.saml2_sp_config)
self._auth = hs.get_auth()
self._auth_handler = hs.get_auth_handler()
@@ -133,8 +134,14 @@ class SamlHandler:
# the dict.
self.expire_sessions()
+ # Pull out the user-agent and IP from the request.
+ user_agent = request.requestHeaders.getRawHeaders(b"User-Agent", default=[b""])[
+ 0
+ ].decode("ascii", "surrogateescape")
+ ip_address = self.hs.get_ip_from_request(request)
+
user_id, current_session = await self._map_saml_response_to_user(
- resp_bytes, relay_state
+ resp_bytes, relay_state, user_agent, ip_address
)
# Complete the interactive auth session or the login.
@@ -147,7 +154,11 @@ class SamlHandler:
await self._auth_handler.complete_sso_login(user_id, request, relay_state)
async def _map_saml_response_to_user(
- self, resp_bytes: str, client_redirect_url: str
+ self,
+ resp_bytes: str,
+ client_redirect_url: str,
+ user_agent: str,
+ ip_address: str,
) -> Tuple[str, Optional[Saml2SessionData]]:
"""
Given a sample response, retrieve the cached session and user for it.
@@ -155,6 +166,8 @@ class SamlHandler:
Args:
resp_bytes: The SAML response.
client_redirect_url: The redirect URL passed in by the client.
+ user_agent: The user agent of the client making the request.
+ ip_address: The IP address of the client making the request.
Returns:
Tuple of the user ID and SAML session associated with this response.
@@ -291,6 +304,7 @@ class SamlHandler:
localpart=localpart,
default_display_name=displayname,
bind_emails=emails,
+ user_agent_ips=(user_agent, ip_address),
)
await self._datastore.record_user_external_id(
@@ -346,12 +360,12 @@ MXID_MAPPER_MAP = {
@attr.s
-class SamlConfig(object):
+class SamlConfig:
mxid_source_attribute = attr.ib()
mxid_mapper = attr.ib()
-class DefaultSamlMappingProvider(object):
+class DefaultSamlMappingProvider:
__version__ = "0.0.1"
def __init__(self, parsed_config: SamlConfig, module_api: ModuleApi):
diff --git a/synapse/handlers/state_deltas.py b/synapse/handlers/state_deltas.py
index 8590c1ef..7a4ae072 100644
--- a/synapse/handlers/state_deltas.py
+++ b/synapse/handlers/state_deltas.py
@@ -18,7 +18,7 @@ import logging
logger = logging.getLogger(__name__)
-class StateDeltasHandler(object):
+class StateDeltasHandler:
def __init__(self, hs):
self.store = hs.get_datastore()
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index c42dac18..e2ddb628 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -16,7 +16,7 @@
import itertools
import logging
-from typing import Any, Dict, FrozenSet, List, Optional, Set, Tuple
+from typing import TYPE_CHECKING, Any, Dict, FrozenSet, List, Optional, Set, Tuple
import attr
from prometheus_client import Counter
@@ -31,6 +31,7 @@ from synapse.storage.state import StateFilter
from synapse.types import (
Collection,
JsonDict,
+ MutableStateMap,
RoomStreamToken,
StateMap,
StreamToken,
@@ -43,6 +44,9 @@ from synapse.util.caches.response_cache import ResponseCache
from synapse.util.metrics import Measure, measure_func
from synapse.visibility import filter_events_for_client
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
logger = logging.getLogger(__name__)
# Debug logger for https://github.com/matrix-org/synapse/issues/4422
@@ -94,7 +98,12 @@ class TimelineBatch:
__bool__ = __nonzero__ # python3
-@attr.s(slots=True, frozen=True)
+# We can't freeze this class, because we need to update it after it's instantiated to
+# update its unread count. This is because we calculate the unread count for a room only
+# if there are updates for it, which we check after the instance has been created.
+# This should not be a big deal because we update the notification counts afterwards as
+# well anyway.
+@attr.s(slots=True)
class JoinedSyncResult:
room_id = attr.ib(type=str)
timeline = attr.ib(type=TimelineBatch)
@@ -103,6 +112,7 @@ class JoinedSyncResult:
account_data = attr.ib(type=List[JsonDict])
unread_notifications = attr.ib(type=JsonDict)
summary = attr.ib(type=Optional[JsonDict])
+ unread_count = attr.ib(type=int)
def __nonzero__(self) -> bool:
"""Make the result appear empty if there are no updates. This is used
@@ -236,8 +246,8 @@ class SyncResult:
__bool__ = __nonzero__ # python3
-class SyncHandler(object):
- def __init__(self, hs):
+class SyncHandler:
+ def __init__(self, hs: "HomeServer"):
self.hs_config = hs.config
self.store = hs.get_datastore()
self.notifier = hs.get_notifier()
@@ -588,7 +598,7 @@ class SyncHandler(object):
room_id: str,
sync_config: SyncConfig,
batch: TimelineBatch,
- state: StateMap[EventBase],
+ state: MutableStateMap[EventBase],
now_token: StreamToken,
) -> Optional[JsonDict]:
""" Works out a room summary block for this room, summarising the number
@@ -710,9 +720,8 @@ class SyncHandler(object):
]
missing_hero_state = await self.store.get_events(missing_hero_event_ids)
- missing_hero_state = missing_hero_state.values()
- for s in missing_hero_state:
+ for s in missing_hero_state.values():
cache.set(s.state_key, s.event_id)
state[(EventTypes.Member, s.state_key)] = s
@@ -736,7 +745,7 @@ class SyncHandler(object):
since_token: Optional[StreamToken],
now_token: StreamToken,
full_state: bool,
- ) -> StateMap[EventBase]:
+ ) -> MutableStateMap[EventBase]:
""" Works out the difference in state between the start of the timeline
and the previous sync.
@@ -930,7 +939,7 @@ class SyncHandler(object):
async def unread_notifs_for_room_id(
self, room_id: str, sync_config: SyncConfig
- ) -> Optional[Dict[str, str]]:
+ ) -> Dict[str, int]:
with Measure(self.clock, "unread_notifs_for_room_id"):
last_unread_event_id = await self.store.get_last_receipt_event_id_for_user(
user_id=sync_config.user.to_string(),
@@ -938,15 +947,10 @@ class SyncHandler(object):
receipt_type="m.read",
)
- if last_unread_event_id:
- notifs = await self.store.get_unread_event_push_actions_by_room_for_user(
- room_id, sync_config.user.to_string(), last_unread_event_id
- )
- return notifs
-
- # There is no new information in this period, so your notification
- # count is whatever it was last time.
- return None
+ notifs = await self.store.get_unread_event_push_actions_by_room_for_user(
+ room_id, sync_config.user.to_string(), last_unread_event_id
+ )
+ return notifs
async def generate_sync_result(
self,
@@ -1769,7 +1773,7 @@ class SyncHandler(object):
ignored_users: Set[str],
room_builder: "RoomSyncResultBuilder",
ephemeral: List[JsonDict],
- tags: Optional[List[JsonDict]],
+ tags: Optional[Dict[str, Dict[str, Any]]],
account_data: Dict[str, JsonDict],
always_include: bool = False,
):
@@ -1885,7 +1889,7 @@ class SyncHandler(object):
)
if room_builder.rtype == "joined":
- unread_notifications = {} # type: Dict[str, str]
+ unread_notifications = {} # type: Dict[str, int]
room_sync = JoinedSyncResult(
room_id=room_id,
timeline=batch,
@@ -1894,14 +1898,16 @@ class SyncHandler(object):
account_data=account_data_events,
unread_notifications=unread_notifications,
summary=summary,
+ unread_count=0,
)
if room_sync or always_include:
notifs = await self.unread_notifs_for_room_id(room_id, sync_config)
- if notifs is not None:
- unread_notifications["notification_count"] = notifs["notify_count"]
- unread_notifications["highlight_count"] = notifs["highlight_count"]
+ unread_notifications["notification_count"] = notifs["notify_count"]
+ unread_notifications["highlight_count"] = notifs["highlight_count"]
+
+ room_sync.unread_count = notifs["unread_count"]
sync_result_builder.joined.append(room_sync)
@@ -2069,7 +2075,7 @@ class SyncResultBuilder:
@attr.s
-class RoomSyncResultBuilder(object):
+class RoomSyncResultBuilder:
"""Stores information needed to create either a `JoinedSyncResult` or
`ArchivedSyncResult`.
diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py
index a86ac015..3cbfc2d7 100644
--- a/synapse/handlers/typing.py
+++ b/synapse/handlers/typing.py
@@ -14,10 +14,11 @@
# limitations under the License.
import logging
+import random
from collections import namedtuple
from typing import TYPE_CHECKING, List, Set, Tuple
-from synapse.api.errors import AuthError, SynapseError
+from synapse.api.errors import AuthError, ShadowBanError, SynapseError
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.replication.tcp.streams import TypingStream
from synapse.types import UserID, get_domain_from_id
@@ -227,9 +228,9 @@ class TypingWriterHandler(FollowerTypingHandler):
self._stopped_typing(member)
return
- async def started_typing(self, target_user, auth_user, room_id, timeout):
+ async def started_typing(self, target_user, requester, room_id, timeout):
target_user_id = target_user.to_string()
- auth_user_id = auth_user.to_string()
+ auth_user_id = requester.user.to_string()
if not self.is_mine_id(target_user_id):
raise SynapseError(400, "User is not hosted on this homeserver")
@@ -237,6 +238,11 @@ class TypingWriterHandler(FollowerTypingHandler):
if target_user_id != auth_user_id:
raise AuthError(400, "Cannot set another user's typing state")
+ if requester.shadow_banned:
+ # We randomly sleep a bit just to annoy the requester.
+ await self.clock.sleep(random.randint(1, 10))
+ raise ShadowBanError()
+
await self.auth.check_user_in_room(room_id, target_user_id)
logger.debug("%s has started typing in %s", target_user_id, room_id)
@@ -256,9 +262,9 @@ class TypingWriterHandler(FollowerTypingHandler):
self._push_update(member=member, typing=True)
- async def stopped_typing(self, target_user, auth_user, room_id):
+ async def stopped_typing(self, target_user, requester, room_id):
target_user_id = target_user.to_string()
- auth_user_id = auth_user.to_string()
+ auth_user_id = requester.user.to_string()
if not self.is_mine_id(target_user_id):
raise SynapseError(400, "User is not hosted on this homeserver")
@@ -266,6 +272,11 @@ class TypingWriterHandler(FollowerTypingHandler):
if target_user_id != auth_user_id:
raise AuthError(400, "Cannot set another user's typing state")
+ if requester.shadow_banned:
+ # We randomly sleep a bit just to annoy the requester.
+ await self.clock.sleep(random.randint(1, 10))
+ raise ShadowBanError()
+
await self.auth.check_user_in_room(room_id, target_user_id)
logger.debug("%s has stopped typing in %s", target_user_id, room_id)
@@ -401,7 +412,7 @@ class TypingWriterHandler(FollowerTypingHandler):
raise Exception("Typing writer instance got typing info over replication")
-class TypingNotificationEventSource(object):
+class TypingNotificationEventSource:
def __init__(self, hs):
self.hs = hs
self.clock = hs.get_clock()
diff --git a/synapse/handlers/ui_auth/checkers.py b/synapse/handlers/ui_auth/checkers.py
index a011e9fe..9146dc1a 100644
--- a/synapse/handlers/ui_auth/checkers.py
+++ b/synapse/handlers/ui_auth/checkers.py
@@ -16,13 +16,12 @@
import logging
from typing import Any
-from canonicaljson import json
-
from twisted.web.client import PartialDownloadError
from synapse.api.constants import LoginType
from synapse.api.errors import Codes, LoginError, SynapseError
from synapse.config.emailconfig import ThreepidBehaviour
+from synapse.util import json_decoder
logger = logging.getLogger(__name__)
@@ -117,7 +116,7 @@ class RecaptchaAuthChecker(UserInteractiveAuthChecker):
except PartialDownloadError as pde:
# Twisted is silly
data = pde.response
- resp_body = json.loads(data.decode("utf-8"))
+ resp_body = json_decoder.decode(data.decode("utf-8"))
if "success" in resp_body:
# Note that we do NOT check the hostname here: we explicitly
diff --git a/synapse/handlers/user_directory.py b/synapse/handlers/user_directory.py
index 521b6d62..e21f8dbc 100644
--- a/synapse/handlers/user_directory.py
+++ b/synapse/handlers/user_directory.py
@@ -234,7 +234,7 @@ class UserDirectoryHandler(StateDeltasHandler):
async def _handle_room_publicity_change(
self, room_id, prev_event_id, event_id, typ
):
- """Handle a room having potentially changed from/to world_readable/publically
+ """Handle a room having potentially changed from/to world_readable/publicly
joinable.
Args:
@@ -388,9 +388,15 @@ class UserDirectoryHandler(StateDeltasHandler):
prev_name = prev_event.content.get("displayname")
new_name = event.content.get("displayname")
+ # If the new name is an unexpected form, do not update the directory.
+ if not isinstance(new_name, str):
+ new_name = prev_name
prev_avatar = prev_event.content.get("avatar_url")
new_avatar = event.content.get("avatar_url")
+ # If the new avatar is an unexpected form, do not update the directory.
+ if not isinstance(new_avatar, str):
+ new_avatar = prev_avatar
if prev_name != new_name or prev_avatar != new_avatar:
await self.store.update_profile_in_user_dir(user_id, new_name, new_avatar)