diff options
Diffstat (limited to 'synapse/rest')
20 files changed, 311 insertions, 135 deletions
diff --git a/synapse/rest/admin/devices.py b/synapse/rest/admin/devices.py index ffd3aa38..5996de11 100644 --- a/synapse/rest/admin/devices.py +++ b/synapse/rest/admin/devices.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging +from typing import TYPE_CHECKING, Tuple from synapse.api.errors import NotFoundError, SynapseError from synapse.http.servlet import ( @@ -20,8 +21,12 @@ from synapse.http.servlet import ( assert_params_in_dict, parse_json_object_from_request, ) +from synapse.http.site import SynapseRequest from synapse.rest.admin._base import admin_patterns, assert_requester_is_admin -from synapse.types import UserID +from synapse.types import JsonDict, UserID + +if TYPE_CHECKING: + from synapse.server import HomeServer logger = logging.getLogger(__name__) @@ -35,14 +40,16 @@ class DeviceRestServlet(RestServlet): "/users/(?P<user_id>[^/]*)/devices/(?P<device_id>[^/]*)$", "v2" ) - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.hs = hs self.auth = hs.get_auth() self.device_handler = hs.get_device_handler() self.store = hs.get_datastore() - async def on_GET(self, request, user_id, device_id): + async def on_GET( + self, request: SynapseRequest, user_id, device_id: str + ) -> Tuple[int, JsonDict]: await assert_requester_is_admin(self.auth, request) target_user = UserID.from_string(user_id) @@ -58,7 +65,9 @@ class DeviceRestServlet(RestServlet): ) return 200, device - async def on_DELETE(self, request, user_id, device_id): + async def on_DELETE( + self, request: SynapseRequest, user_id: str, device_id: str + ) -> Tuple[int, JsonDict]: await assert_requester_is_admin(self.auth, request) target_user = UserID.from_string(user_id) @@ -72,7 +81,9 @@ class DeviceRestServlet(RestServlet): await self.device_handler.delete_device(target_user.to_string(), device_id) return 200, {} - async def on_PUT(self, request, user_id, device_id): + async def on_PUT( + self, request: SynapseRequest, user_id: str, device_id: str + ) -> Tuple[int, JsonDict]: await assert_requester_is_admin(self.auth, request) target_user = UserID.from_string(user_id) @@ -97,7 +108,7 @@ class DevicesRestServlet(RestServlet): PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/devices$", "v2") - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): """ Args: hs (synapse.server.HomeServer): server @@ -107,7 +118,9 @@ class DevicesRestServlet(RestServlet): self.device_handler = hs.get_device_handler() self.store = hs.get_datastore() - async def on_GET(self, request, user_id): + async def on_GET( + self, request: SynapseRequest, user_id: str + ) -> Tuple[int, JsonDict]: await assert_requester_is_admin(self.auth, request) target_user = UserID.from_string(user_id) @@ -130,13 +143,15 @@ class DeleteDevicesRestServlet(RestServlet): PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/delete_devices$", "v2") - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): self.hs = hs self.auth = hs.get_auth() self.device_handler = hs.get_device_handler() self.store = hs.get_datastore() - async def on_POST(self, request, user_id): + async def on_POST( + self, request: SynapseRequest, user_id: str + ) -> Tuple[int, JsonDict]: await assert_requester_is_admin(self.auth, request) target_user = UserID.from_string(user_id) diff --git a/synapse/rest/admin/event_reports.py b/synapse/rest/admin/event_reports.py index fd482f0e..381c3fe6 100644 --- a/synapse/rest/admin/event_reports.py +++ b/synapse/rest/admin/event_reports.py @@ -14,10 +14,16 @@ # limitations under the License. import logging +from typing import TYPE_CHECKING, Tuple from synapse.api.errors import Codes, NotFoundError, SynapseError from synapse.http.servlet import RestServlet, parse_integer, parse_string +from synapse.http.site import SynapseRequest from synapse.rest.admin._base import admin_patterns, assert_requester_is_admin +from synapse.types import JsonDict + +if TYPE_CHECKING: + from synapse.server import HomeServer logger = logging.getLogger(__name__) @@ -45,12 +51,12 @@ class EventReportsRestServlet(RestServlet): PATTERNS = admin_patterns("/event_reports$") - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): self.hs = hs self.auth = hs.get_auth() self.store = hs.get_datastore() - async def on_GET(self, request): + async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: await assert_requester_is_admin(self.auth, request) start = parse_integer(request, "from", default=0) @@ -106,26 +112,28 @@ class EventReportDetailRestServlet(RestServlet): PATTERNS = admin_patterns("/event_reports/(?P<report_id>[^/]*)$") - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): self.hs = hs self.auth = hs.get_auth() self.store = hs.get_datastore() - async def on_GET(self, request, report_id): + async def on_GET( + self, request: SynapseRequest, report_id: str + ) -> Tuple[int, JsonDict]: await assert_requester_is_admin(self.auth, request) message = ( "The report_id parameter must be a string representing a positive integer." ) try: - report_id = int(report_id) + resolved_report_id = int(report_id) except ValueError: raise SynapseError(400, message, errcode=Codes.INVALID_PARAM) - if report_id < 0: + if resolved_report_id < 0: raise SynapseError(400, message, errcode=Codes.INVALID_PARAM) - ret = await self.store.get_event_report(report_id) + ret = await self.store.get_event_report(resolved_report_id) if not ret: raise NotFoundError("Event report not found") diff --git a/synapse/rest/admin/media.py b/synapse/rest/admin/media.py index b996862c..511c859f 100644 --- a/synapse/rest/admin/media.py +++ b/synapse/rest/admin/media.py @@ -17,7 +17,7 @@ import logging from typing import TYPE_CHECKING, Tuple -from twisted.web.http import Request +from twisted.web.server import Request from synapse.api.errors import AuthError, Codes, NotFoundError, SynapseError from synapse.http.servlet import RestServlet, parse_boolean, parse_integer diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py index 1a3a36f6..f2c42a0f 100644 --- a/synapse/rest/admin/rooms.py +++ b/synapse/rest/admin/rooms.py @@ -44,6 +44,48 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) +class ResolveRoomIdMixin: + def __init__(self, hs: "HomeServer"): + self.room_member_handler = hs.get_room_member_handler() + + async def resolve_room_id( + self, room_identifier: str, remote_room_hosts: Optional[List[str]] = None + ) -> Tuple[str, Optional[List[str]]]: + """ + Resolve a room identifier to a room ID, if necessary. + + This also performanes checks to ensure the room ID is of the proper form. + + Args: + room_identifier: The room ID or alias. + remote_room_hosts: The potential remote room hosts to use. + + Returns: + The resolved room ID. + + Raises: + SynapseError if the room ID is of the wrong form. + """ + if RoomID.is_valid(room_identifier): + resolved_room_id = room_identifier + elif RoomAlias.is_valid(room_identifier): + room_alias = RoomAlias.from_string(room_identifier) + ( + room_id, + remote_room_hosts, + ) = await self.room_member_handler.lookup_room_alias(room_alias) + resolved_room_id = room_id.to_string() + else: + raise SynapseError( + 400, "%s was not legal room ID or room alias" % (room_identifier,) + ) + if not resolved_room_id: + raise SynapseError( + 400, "Unknown room ID or room alias %s" % room_identifier + ) + return resolved_room_id, remote_room_hosts + + class ShutdownRoomRestServlet(RestServlet): """Shuts down a room by removing all local users from the room and blocking all future invites and joins to the room. Any local aliases will be repointed @@ -334,14 +376,14 @@ class RoomStateRestServlet(RestServlet): return 200, ret -class JoinRoomAliasServlet(RestServlet): +class JoinRoomAliasServlet(ResolveRoomIdMixin, RestServlet): PATTERNS = admin_patterns("/join/(?P<room_identifier>[^/]*)") def __init__(self, hs: "HomeServer"): + super().__init__(hs) self.hs = hs self.auth = hs.get_auth() - self.room_member_handler = hs.get_room_member_handler() self.admin_handler = hs.get_admin_handler() self.state_handler = hs.get_state_handler() @@ -362,22 +404,16 @@ class JoinRoomAliasServlet(RestServlet): if not await self.admin_handler.get_user(target_user): raise NotFoundError("User not found") - if RoomID.is_valid(room_identifier): - room_id = room_identifier - try: - remote_room_hosts = [ - x.decode("ascii") for x in request.args[b"server_name"] - ] # type: Optional[List[str]] - except Exception: - remote_room_hosts = None - elif RoomAlias.is_valid(room_identifier): - handler = self.room_member_handler - room_alias = RoomAlias.from_string(room_identifier) - room_id, remote_room_hosts = await handler.lookup_room_alias(room_alias) - else: - raise SynapseError( - 400, "%s was not legal room ID or room alias" % (room_identifier,) - ) + # Get the room ID from the identifier. + try: + remote_room_hosts = [ + x.decode("ascii") for x in request.args[b"server_name"] + ] # type: Optional[List[str]] + except Exception: + remote_room_hosts = None + room_id, remote_room_hosts = await self.resolve_room_id( + room_identifier, remote_room_hosts + ) fake_requester = create_requester( target_user, authenticated_entity=requester.authenticated_entity @@ -412,7 +448,7 @@ class JoinRoomAliasServlet(RestServlet): return 200, {"room_id": room_id} -class MakeRoomAdminRestServlet(RestServlet): +class MakeRoomAdminRestServlet(ResolveRoomIdMixin, RestServlet): """Allows a server admin to get power in a room if a local user has power in a room. Will also invite the user if they're not in the room and it's a private room. Can specify another user (rather than the admin user) to be @@ -427,29 +463,21 @@ class MakeRoomAdminRestServlet(RestServlet): PATTERNS = admin_patterns("/rooms/(?P<room_identifier>[^/]*)/make_room_admin") def __init__(self, hs: "HomeServer"): + super().__init__(hs) self.hs = hs self.auth = hs.get_auth() - self.room_member_handler = hs.get_room_member_handler() self.event_creation_handler = hs.get_event_creation_handler() self.state_handler = hs.get_state_handler() self.is_mine_id = hs.is_mine_id - async def on_POST(self, request, room_identifier): + async def on_POST( + self, request: SynapseRequest, room_identifier: str + ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) await assert_user_is_admin(self.auth, requester.user) content = parse_json_object_from_request(request, allow_empty_body=True) - # Resolve to a room ID, if necessary. - if RoomID.is_valid(room_identifier): - room_id = room_identifier - elif RoomAlias.is_valid(room_identifier): - room_alias = RoomAlias.from_string(room_identifier) - room_id, _ = await self.room_member_handler.lookup_room_alias(room_alias) - room_id = room_id.to_string() - else: - raise SynapseError( - 400, "%s was not legal room ID or room alias" % (room_identifier,) - ) + room_id, _ = await self.resolve_room_id(room_identifier) # Which user to grant room admin rights to. user_to_add = content.get("user_id", requester.user.to_string()) @@ -556,7 +584,7 @@ class MakeRoomAdminRestServlet(RestServlet): return 200, {} -class ForwardExtremitiesRestServlet(RestServlet): +class ForwardExtremitiesRestServlet(ResolveRoomIdMixin, RestServlet): """Allows a server admin to get or clear forward extremities. Clearing does not require restarting the server. @@ -571,43 +599,29 @@ class ForwardExtremitiesRestServlet(RestServlet): PATTERNS = admin_patterns("/rooms/(?P<room_identifier>[^/]*)/forward_extremities") def __init__(self, hs: "HomeServer"): + super().__init__(hs) self.hs = hs self.auth = hs.get_auth() - self.room_member_handler = hs.get_room_member_handler() self.store = hs.get_datastore() - async def resolve_room_id(self, room_identifier: str) -> str: - """Resolve to a room ID, if necessary.""" - if RoomID.is_valid(room_identifier): - resolved_room_id = room_identifier - elif RoomAlias.is_valid(room_identifier): - room_alias = RoomAlias.from_string(room_identifier) - room_id, _ = await self.room_member_handler.lookup_room_alias(room_alias) - resolved_room_id = room_id.to_string() - else: - raise SynapseError( - 400, "%s was not legal room ID or room alias" % (room_identifier,) - ) - if not resolved_room_id: - raise SynapseError( - 400, "Unknown room ID or room alias %s" % room_identifier - ) - return resolved_room_id - - async def on_DELETE(self, request, room_identifier): + async def on_DELETE( + self, request: SynapseRequest, room_identifier: str + ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) await assert_user_is_admin(self.auth, requester.user) - room_id = await self.resolve_room_id(room_identifier) + room_id, _ = await self.resolve_room_id(room_identifier) deleted_count = await self.store.delete_forward_extremities_for_room(room_id) return 200, {"deleted": deleted_count} - async def on_GET(self, request, room_identifier): + async def on_GET( + self, request: SynapseRequest, room_identifier: str + ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) await assert_user_is_admin(self.auth, requester.user) - room_id = await self.resolve_room_id(room_identifier) + room_id, _ = await self.resolve_room_id(room_identifier) extremities = await self.store.get_forward_extremities_for_room(room_id) return 200, {"count": len(extremities), "results": extremities} @@ -623,14 +637,16 @@ class RoomEventContextServlet(RestServlet): PATTERNS = admin_patterns("/rooms/(?P<room_id>[^/]*)/context/(?P<event_id>[^/]*)$") - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.clock = hs.get_clock() self.room_context_handler = hs.get_room_context_handler() self._event_serializer = hs.get_event_client_serializer() self.auth = hs.get_auth() - async def on_GET(self, request, room_id, event_id): + async def on_GET( + self, request: SynapseRequest, room_id: str, event_id: str + ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request, allow_guest=False) await assert_user_is_admin(self.auth, requester.user) diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py index 998a0ef6..267a9934 100644 --- a/synapse/rest/admin/users.py +++ b/synapse/rest/admin/users.py @@ -16,7 +16,7 @@ import hashlib import hmac import logging from http import HTTPStatus -from typing import TYPE_CHECKING, Tuple +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple from synapse.api.constants import UserTypes from synapse.api.errors import Codes, NotFoundError, SynapseError @@ -35,6 +35,7 @@ from synapse.rest.admin._base import ( assert_user_is_admin, ) from synapse.rest.client.v2_alpha._base import client_patterns +from synapse.storage.databases.main.media_repository import MediaSortOrder from synapse.types import JsonDict, UserID if TYPE_CHECKING: @@ -46,13 +47,15 @@ logger = logging.getLogger(__name__) class UsersRestServlet(RestServlet): PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)$") - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): self.hs = hs self.store = hs.get_datastore() self.auth = hs.get_auth() self.admin_handler = hs.get_admin_handler() - async def on_GET(self, request, user_id): + async def on_GET( + self, request: SynapseRequest, user_id: str + ) -> Tuple[int, List[JsonDict]]: target_user = UserID.from_string(user_id) await assert_requester_is_admin(self.auth, request) @@ -152,7 +155,7 @@ class UserRestServletV2(RestServlet): otherwise an error. """ - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): self.hs = hs self.auth = hs.get_auth() self.admin_handler = hs.get_admin_handler() @@ -164,7 +167,9 @@ class UserRestServletV2(RestServlet): self.registration_handler = hs.get_registration_handler() self.pusher_pool = hs.get_pusherpool() - async def on_GET(self, request, user_id): + async def on_GET( + self, request: SynapseRequest, user_id: str + ) -> Tuple[int, JsonDict]: await assert_requester_is_admin(self.auth, request) target_user = UserID.from_string(user_id) @@ -178,7 +183,9 @@ class UserRestServletV2(RestServlet): return 200, ret - async def on_PUT(self, request, user_id): + async def on_PUT( + self, request: SynapseRequest, user_id: str + ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) await assert_user_is_admin(self.auth, requester.user) @@ -272,6 +279,8 @@ class UserRestServletV2(RestServlet): ) user = await self.admin_handler.get_user(target_user) + assert user is not None + return 200, user else: # create user @@ -329,9 +338,10 @@ class UserRestServletV2(RestServlet): target_user, requester, body["avatar_url"], True ) - ret = await self.admin_handler.get_user(target_user) + user = await self.admin_handler.get_user(target_user) + assert user is not None - return 201, ret + return 201, user class UserRegisterServlet(RestServlet): @@ -345,10 +355,10 @@ class UserRegisterServlet(RestServlet): PATTERNS = admin_patterns("/register") NONCE_TIMEOUT = 60 - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): self.auth_handler = hs.get_auth_handler() self.reactor = hs.get_reactor() - self.nonces = {} + self.nonces = {} # type: Dict[str, int] self.hs = hs def _clear_old_nonces(self): @@ -361,7 +371,7 @@ class UserRegisterServlet(RestServlet): if now - v > self.NONCE_TIMEOUT: del self.nonces[k] - def on_GET(self, request): + def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: """ Generate a new nonce. """ @@ -371,7 +381,7 @@ class UserRegisterServlet(RestServlet): self.nonces[nonce] = int(self.reactor.seconds()) return 200, {"nonce": nonce} - async def on_POST(self, request): + async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: self._clear_old_nonces() if not self.hs.config.registration_shared_secret: @@ -477,12 +487,14 @@ class WhoisRestServlet(RestServlet): client_patterns("/admin" + path_regex, v1=True) ) - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): self.hs = hs self.auth = hs.get_auth() self.admin_handler = hs.get_admin_handler() - async def on_GET(self, request, user_id): + async def on_GET( + self, request: SynapseRequest, user_id: str + ) -> Tuple[int, JsonDict]: target_user = UserID.from_string(user_id) requester = await self.auth.get_user_by_req(request) auth_user = requester.user @@ -507,7 +519,9 @@ class DeactivateAccountRestServlet(RestServlet): self.is_mine = hs.is_mine self.store = hs.get_datastore() - async def on_POST(self, request: str, target_user_id: str) -> Tuple[int, JsonDict]: + async def on_POST( + self, request: SynapseRequest, target_user_id: str + ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) await assert_user_is_admin(self.auth, requester.user) @@ -549,7 +563,7 @@ class AccountValidityRenewServlet(RestServlet): self.account_activity_handler = hs.get_account_validity_handler() self.auth = hs.get_auth() - async def on_POST(self, request): + async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: await assert_requester_is_admin(self.auth, request) body = parse_json_object_from_request(request) @@ -583,14 +597,16 @@ class ResetPasswordRestServlet(RestServlet): PATTERNS = admin_patterns("/reset_password/(?P<target_user_id>[^/]*)") - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): self.store = hs.get_datastore() self.hs = hs self.auth = hs.get_auth() self.auth_handler = hs.get_auth_handler() self._set_password_handler = hs.get_set_password_handler() - async def on_POST(self, request, target_user_id): + async def on_POST( + self, request: SynapseRequest, target_user_id: str + ) -> Tuple[int, JsonDict]: """Post request to allow an administrator reset password for a user. This needs user to have administrator access in Synapse. """ @@ -625,12 +641,14 @@ class SearchUsersRestServlet(RestServlet): PATTERNS = admin_patterns("/search_users/(?P<target_user_id>[^/]*)") - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): self.hs = hs self.store = hs.get_datastore() self.auth = hs.get_auth() - async def on_GET(self, request, target_user_id): + async def on_GET( + self, request: SynapseRequest, target_user_id: str + ) -> Tuple[int, Optional[List[JsonDict]]]: """Get request to search user table for specific users according to search term. This needs user to have a administrator access in Synapse. @@ -681,12 +699,14 @@ class UserAdminServlet(RestServlet): PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/admin$") - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): self.hs = hs self.store = hs.get_datastore() self.auth = hs.get_auth() - async def on_GET(self, request, user_id): + async def on_GET( + self, request: SynapseRequest, user_id: str + ) -> Tuple[int, JsonDict]: await assert_requester_is_admin(self.auth, request) target_user = UserID.from_string(user_id) @@ -698,7 +718,9 @@ class UserAdminServlet(RestServlet): return 200, {"admin": is_admin} - async def on_PUT(self, request, user_id): + async def on_PUT( + self, request: SynapseRequest, user_id: str + ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) await assert_user_is_admin(self.auth, requester.user) auth_user = requester.user @@ -729,12 +751,14 @@ class UserMembershipRestServlet(RestServlet): PATTERNS = admin_patterns("/users/(?P<user_id>[^/]+)/joined_rooms$") - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): self.is_mine = hs.is_mine self.auth = hs.get_auth() self.store = hs.get_datastore() - async def on_GET(self, request, user_id): + async def on_GET( + self, request: SynapseRequest, user_id: str + ) -> Tuple[int, JsonDict]: await assert_requester_is_admin(self.auth, request) room_ids = await self.store.get_rooms_for_user(user_id) @@ -757,7 +781,7 @@ class PushersRestServlet(RestServlet): PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/pushers$") - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): self.is_mine = hs.is_mine self.store = hs.get_datastore() self.auth = hs.get_auth() @@ -798,7 +822,7 @@ class UserMediaRestServlet(RestServlet): PATTERNS = admin_patterns("/users/(?P<user_id>[^/]+)/media$") - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): self.is_mine = hs.is_mine self.auth = hs.get_auth() self.store = hs.get_datastore() @@ -832,8 +856,33 @@ class UserMediaRestServlet(RestServlet): errcode=Codes.INVALID_PARAM, ) + # If neither `order_by` nor `dir` is set, set the default order + # to newest media is on top for backward compatibility. + if b"order_by" not in request.args and b"dir" not in request.args: + order_by = MediaSortOrder.CREATED_TS.value + direction = "b" + else: + order_by = parse_string( + request, + "order_by", + default=MediaSortOrder.CREATED_TS.value, + allowed_values=( + MediaSortOrder.MEDIA_ID.value, + MediaSortOrder.UPLOAD_NAME.value, + MediaSortOrder.CREATED_TS.value, + MediaSortOrder.LAST_ACCESS_TS.value, + MediaSortOrder.MEDIA_LENGTH.value, + MediaSortOrder.MEDIA_TYPE.value, + MediaSortOrder.QUARANTINED_BY.value, + MediaSortOrder.SAFE_FROM_QUARANTINE.value, + ), + ) + direction = parse_string( + request, "dir", default="f", allowed_values=("f", "b") + ) + media, total = await self.store.get_local_media_by_user_paginate( - start, limit, user_id + start, limit, user_id, order_by, direction ) ret = {"media": media, "total": total} @@ -865,7 +914,9 @@ class UserTokenRestServlet(RestServlet): self.auth = hs.get_auth() self.auth_handler = hs.get_auth_handler() - async def on_POST(self, request, user_id): + async def on_POST( + self, request: SynapseRequest, user_id: str + ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) await assert_user_is_admin(self.auth, requester.user) auth_user = requester.user @@ -917,7 +968,9 @@ class ShadowBanRestServlet(RestServlet): self.store = hs.get_datastore() self.auth = hs.get_auth() - async def on_POST(self, request, user_id): + async def on_POST( + self, request: SynapseRequest, user_id: str + ) -> Tuple[int, JsonDict]: await assert_requester_is_admin(self.auth, request) if not self.hs.is_mine_id(user_id): diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py index 6e2fbedd..925edfc4 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py @@ -20,6 +20,7 @@ from synapse.api.errors import Codes, LoginError, SynapseError from synapse.api.ratelimiting import Ratelimiter from synapse.appservice import ApplicationService from synapse.handlers.sso import SsoIdentityProvider +from synapse.http import get_request_uri from synapse.http.server import HttpServer, finish_request from synapse.http.servlet import ( RestServlet, @@ -354,6 +355,7 @@ class SsoRedirectServlet(RestServlet): hs.get_oidc_handler() self._sso_handler = hs.get_sso_handler() self._msc2858_enabled = hs.config.experimental.msc2858_enabled + self._public_baseurl = hs.config.public_baseurl def register(self, http_server: HttpServer) -> None: super().register(http_server) @@ -373,6 +375,32 @@ class SsoRedirectServlet(RestServlet): async def on_GET( self, request: SynapseRequest, idp_id: Optional[str] = None ) -> None: + if not self._public_baseurl: + raise SynapseError(400, "SSO requires a valid public_baseurl") + + # if this isn't the expected hostname, redirect to the right one, so that we + # get our cookies back. + requested_uri = get_request_uri(request) + baseurl_bytes = self._public_baseurl.encode("utf-8") + if not requested_uri.startswith(baseurl_bytes): + # swap out the incorrect base URL for the right one. + # + # The idea here is to redirect from + # https://foo.bar/whatever/_matrix/... + # to + # https://public.baseurl/_matrix/... + # + i = requested_uri.index(b"/_matrix") + new_uri = baseurl_bytes[:-1] + requested_uri[i:] + logger.info( + "Requested URI %s is not canonical: redirecting to %s", + requested_uri.decode("utf-8", errors="replace"), + new_uri.decode("utf-8", errors="replace"), + ) + request.redirect(new_uri) + finish_request(request) + return + client_redirect_url = parse_string( request, "redirectUrl", required=True, encoding=None ) diff --git a/synapse/rest/client/v2_alpha/groups.py b/synapse/rest/client/v2_alpha/groups.py index d3434225..7aea4ceb 100644 --- a/synapse/rest/client/v2_alpha/groups.py +++ b/synapse/rest/client/v2_alpha/groups.py @@ -18,7 +18,7 @@ import logging from functools import wraps from typing import TYPE_CHECKING, Optional, Tuple -from twisted.web.http import Request +from twisted.web.server import Request from synapse.api.constants import ( MAX_GROUP_CATEGORYID_LENGTH, diff --git a/synapse/rest/client/v2_alpha/sendtodevice.py b/synapse/rest/client/v2_alpha/sendtodevice.py index a3dee14e..79c1b526 100644 --- a/synapse/rest/client/v2_alpha/sendtodevice.py +++ b/synapse/rest/client/v2_alpha/sendtodevice.py @@ -56,10 +56,8 @@ class SendToDeviceRestServlet(servlet.RestServlet): content = parse_json_object_from_request(request) assert_params_in_dict(content, ("messages",)) - sender_user_id = requester.user.to_string() - await self.device_message_handler.send_device_message( - sender_user_id, message_type, content["messages"] + requester, message_type, content["messages"] ) response = (200, {}) # type: Tuple[int, dict] diff --git a/synapse/rest/media/v1/_base.py b/synapse/rest/media/v1/_base.py index 90bbeca6..63669470 100644 --- a/synapse/rest/media/v1/_base.py +++ b/synapse/rest/media/v1/_base.py @@ -21,7 +21,7 @@ from typing import Awaitable, Dict, Generator, List, Optional, Tuple from twisted.internet.interfaces import IConsumer from twisted.protocols.basic import FileSender -from twisted.web.http import Request +from twisted.web.server import Request from synapse.api.errors import Codes, SynapseError, cs_error from synapse.http.server import finish_request, respond_with_json @@ -49,18 +49,20 @@ TEXT_CONTENT_TYPES = [ def parse_media_id(request: Request) -> Tuple[str, str, Optional[str]]: try: + # The type on postpath seems incorrect in Twisted 21.2.0. + postpath = request.postpath # type: List[bytes] # type: ignore + assert postpath + # This allows users to append e.g. /test.png to the URL. Useful for # clients that parse the URL to see content type. - server_name, media_id = request.postpath[:2] - - if isinstance(server_name, bytes): - server_name = server_name.decode("utf-8") - media_id = media_id.decode("utf8") + server_name_bytes, media_id_bytes = postpath[:2] + server_name = server_name_bytes.decode("utf-8") + media_id = media_id_bytes.decode("utf8") file_name = None - if len(request.postpath) > 2: + if len(postpath) > 2: try: - file_name = urllib.parse.unquote(request.postpath[-1].decode("utf-8")) + file_name = urllib.parse.unquote(postpath[-1].decode("utf-8")) except UnicodeDecodeError: pass return server_name, media_id, file_name diff --git a/synapse/rest/media/v1/config_resource.py b/synapse/rest/media/v1/config_resource.py index 4e4c6971..9039662f 100644 --- a/synapse/rest/media/v1/config_resource.py +++ b/synapse/rest/media/v1/config_resource.py @@ -17,7 +17,7 @@ from typing import TYPE_CHECKING -from twisted.web.http import Request +from twisted.web.server import Request from synapse.http.server import DirectServeJsonResource, respond_with_json diff --git a/synapse/rest/media/v1/download_resource.py b/synapse/rest/media/v1/download_resource.py index 48f44331..8a43581f 100644 --- a/synapse/rest/media/v1/download_resource.py +++ b/synapse/rest/media/v1/download_resource.py @@ -16,7 +16,7 @@ import logging from typing import TYPE_CHECKING -from twisted.web.http import Request +from twisted.web.server import Request from synapse.http.server import DirectServeJsonResource, set_cors_headers from synapse.http.servlet import parse_boolean diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py index a0162d42..0641924f 100644 --- a/synapse/rest/media/v1/media_repository.py +++ b/synapse/rest/media/v1/media_repository.py @@ -22,8 +22,8 @@ from typing import IO, TYPE_CHECKING, Dict, List, Optional, Set, Tuple import twisted.internet.error import twisted.web.http -from twisted.web.http import Request from twisted.web.resource import Resource +from twisted.web.server import Request from synapse.api.errors import ( FederationDeniedError, @@ -509,7 +509,7 @@ class MediaRepository: t_height: int, t_method: str, t_type: str, - url_cache: str, + url_cache: Optional[str], ) -> Optional[str]: input_path = await self.media_storage.ensure_media_is_in_local_cache( FileInfo(None, media_id, url_cache=url_cache) diff --git a/synapse/rest/media/v1/media_storage.py b/synapse/rest/media/v1/media_storage.py index 1057e638..b1b1c9e6 100644 --- a/synapse/rest/media/v1/media_storage.py +++ b/synapse/rest/media/v1/media_storage.py @@ -244,7 +244,7 @@ class MediaStorage: await consumer.wait() return local_path - raise Exception("file could not be found") + raise NotFoundError() def _file_info_to_path(self, file_info: FileInfo) -> str: """Converts file_info into a relative path. diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py index 6104ef4e..a074e807 100644 --- a/synapse/rest/media/v1/preview_url_resource.py +++ b/synapse/rest/media/v1/preview_url_resource.py @@ -29,7 +29,7 @@ from urllib import parse as urlparse import attr from twisted.internet.error import DNSLookupError -from twisted.web.http import Request +from twisted.web.server import Request from synapse.api.errors import Codes, SynapseError from synapse.http.client import SimpleHttpClient @@ -149,8 +149,7 @@ class PreviewUrlResource(DirectServeJsonResource): treq_args={"browser_like_redirects": True}, ip_whitelist=hs.config.url_preview_ip_range_whitelist, ip_blacklist=hs.config.url_preview_ip_range_blacklist, - http_proxy=os.getenvb(b"http_proxy"), - https_proxy=os.getenvb(b"HTTPS_PROXY"), + use_proxy=True, ) self.media_repo = media_repo self.primary_base_path = media_repo.primary_base_path diff --git a/synapse/rest/media/v1/thumbnail_resource.py b/synapse/rest/media/v1/thumbnail_resource.py index d653a58b..fbcd50f1 100644 --- a/synapse/rest/media/v1/thumbnail_resource.py +++ b/synapse/rest/media/v1/thumbnail_resource.py @@ -18,7 +18,7 @@ import logging from typing import TYPE_CHECKING, Any, Dict, List, Optional -from twisted.web.http import Request +from twisted.web.server import Request from synapse.api.errors import SynapseError from synapse.http.server import DirectServeJsonResource, set_cors_headers @@ -114,6 +114,7 @@ class ThumbnailResource(DirectServeJsonResource): m_type, thumbnail_infos, media_id, + media_id, url_cache=media_info["url_cache"], server_name=None, ) @@ -269,6 +270,7 @@ class ThumbnailResource(DirectServeJsonResource): method, m_type, thumbnail_infos, + media_id, media_info["filesystem_id"], url_cache=None, server_name=server_name, @@ -282,6 +284,7 @@ class ThumbnailResource(DirectServeJsonResource): desired_method: str, desired_type: str, thumbnail_infos: List[Dict[str, Any]], + media_id: str, file_id: str, url_cache: Optional[str] = None, server_name: Optional[str] = None, @@ -317,8 +320,59 @@ class ThumbnailResource(DirectServeJsonResource): return responder = await self.media_storage.fetch_media(file_info) + if responder: + await respond_with_responder( + request, + responder, + file_info.thumbnail_type, + file_info.thumbnail_length, + ) + return + + # If we can't find the thumbnail we regenerate it. This can happen + # if e.g. we've deleted the thumbnails but still have the original + # image somewhere. + # + # Since we have an entry for the thumbnail in the DB we a) know we + # have have successfully generated the thumbnail in the past (so we + # don't need to worry about repeatedly failing to generate + # thumbnails), and b) have already calculated that appropriate + # width/height/method so we can just call the "generate exact" + # methods. + + # First let's check that we do actually have the original image + # still. This will throw a 404 if we don't. + # TODO: We should refetch the thumbnails for remote media. + await self.media_storage.ensure_media_is_in_local_cache( + FileInfo(server_name, file_id, url_cache=url_cache) + ) + + if server_name: + await self.media_repo.generate_remote_exact_thumbnail( + server_name, + file_id=file_id, + media_id=media_id, + t_width=file_info.thumbnail_width, + t_height=file_info.thumbnail_height, + t_method=file_info.thumbnail_method, + t_type=file_info.thumbnail_type, + ) + else: + await self.media_repo.generate_local_exact_thumbnail( + media_id=media_id, + t_width=file_info.thumbnail_width, + t_height=file_info.thumbnail_height, + t_method=file_info.thumbnail_method, + t_type=file_info.thumbnail_type, + url_cache=url_cache, + ) + + responder = await self.media_storage.fetch_media(file_info) await respond_with_responder( - request, responder, file_info.thumbnail_type, file_info.thumbnail_length + request, + responder, + file_info.thumbnail_type, + file_info.thumbnail_length, ) else: logger.info("Failed to find any generated thumbnails") diff --git a/synapse/rest/media/v1/upload_resource.py b/synapse/rest/media/v1/upload_resource.py index 11362777..5e104fac 100644 --- a/synapse/rest/media/v1/upload_resource.py +++ b/synapse/rest/media/v1/upload_resource.py @@ -15,9 +15,9 @@ # limitations under the License. import logging -from typing import TYPE_CHECKING +from typing import IO, TYPE_CHECKING -from twisted.web.http import Request +from twisted.web.server import Request from synapse.api.errors import Codes, SynapseError from synapse.http.server import DirectServeJsonResource, respond_with_json @@ -79,7 +79,9 @@ class UploadResource(DirectServeJsonResource): headers = request.requestHeaders if headers.hasHeader(b"Content-Type"): - media_type = headers.getRawHeaders(b"Content-Type")[0].decode("ascii") + content_type_headers = headers.getRawHeaders(b"Content-Type") + assert content_type_headers # for mypy + media_type = content_type_headers[0].decode("ascii") else: raise SynapseError(msg="Upload request missing 'Content-Type'", code=400) @@ -88,8 +90,9 @@ class UploadResource(DirectServeJsonResource): # TODO(markjh): parse content-dispostion try: + content = request.content # type: IO # type: ignore content_uri = await self.media_repo.create_content( - media_type, upload_name, request.content, content_length, requester.user + media_type, upload_name, content, content_length, requester.user ) except SpamMediaException: # For uploading of media we want to respond with a 400, instead of diff --git a/synapse/rest/synapse/client/new_user_consent.py b/synapse/rest/synapse/client/new_user_consent.py index b2e0f938..78ee0b5e 100644 --- a/synapse/rest/synapse/client/new_user_consent.py +++ b/synapse/rest/synapse/client/new_user_consent.py @@ -15,7 +15,7 @@ import logging from typing import TYPE_CHECKING -from twisted.web.http import Request +from twisted.web.server import Request from synapse.api.errors import SynapseError from synapse.handlers.sso import get_username_mapping_session_cookie_from_request diff --git a/synapse/rest/synapse/client/password_reset.py b/synapse/rest/synapse/client/password_reset.py index 9e4fbc0c..d26ce46e 100644 --- a/synapse/rest/synapse/client/password_reset.py +++ b/synapse/rest/synapse/client/password_reset.py @@ -15,7 +15,7 @@ import logging from typing import TYPE_CHECKING, Tuple -from twisted.web.http import Request +from twisted.web.server import Request from synapse.api.errors import ThreepidValidationError from synapse.config.emailconfig import ThreepidBehaviour diff --git a/synapse/rest/synapse/client/pick_username.py b/synapse/rest/synapse/client/pick_username.py index 96077cfc..51acaa9a 100644 --- a/synapse/rest/synapse/client/pick_username.py +++ b/synapse/rest/synapse/client/pick_username.py @@ -16,8 +16,8 @@ import logging from typing import TYPE_CHECKING, List -from twisted.web.http import Request from twisted.web.resource import Resource +from twisted.web.server import Request from synapse.api.errors import SynapseError from synapse.handlers.sso import get_username_mapping_session_cookie_from_request diff --git a/synapse/rest/synapse/client/sso_register.py b/synapse/rest/synapse/client/sso_register.py index dfefeb77..f2acce24 100644 --- a/synapse/rest/synapse/client/sso_register.py +++ b/synapse/rest/synapse/client/sso_register.py @@ -16,7 +16,7 @@ import logging from typing import TYPE_CHECKING -from twisted.web.http import Request +from twisted.web.server import Request from synapse.api.errors import SynapseError from synapse.handlers.sso import get_username_mapping_session_cookie_from_request |