summaryrefslogtreecommitdiff
path: root/synapse/handlers/auth.py
diff options
context:
space:
mode:
authorAndrej Shadura <andrewsh@debian.org>2021-07-14 08:49:13 +0200
committerAndrej Shadura <andrewsh@debian.org>2021-07-14 08:49:13 +0200
commitf63f4d3518e7df05a25fd057b7faadceab5c5bb6 (patch)
treeea36466b4a01868d60f8b1435e6dbf3243b29d88 /synapse/handlers/auth.py
parent149c9216ba9e2dfbda0fd178e19d25c177ca08a4 (diff)
New upstream version 1.38.0
Diffstat (limited to 'synapse/handlers/auth.py')
-rw-r--r--synapse/handlers/auth.py132
1 files changed, 127 insertions, 5 deletions
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index 1971e373..e2ac595a 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -30,6 +30,7 @@ from typing import (
Optional,
Tuple,
Union,
+ cast,
)
import attr
@@ -72,6 +73,7 @@ from synapse.util.stringutils import base62_encode
from synapse.util.threepids import canonicalise_email
if TYPE_CHECKING:
+ from synapse.rest.client.v1.login import LoginResponse
from synapse.server import HomeServer
logger = logging.getLogger(__name__)
@@ -777,6 +779,108 @@ class AuthHandler(BaseHandler):
"params": params,
}
+ async def refresh_token(
+ self,
+ refresh_token: str,
+ valid_until_ms: Optional[int],
+ ) -> Tuple[str, str]:
+ """
+ Consumes a refresh token and generate both a new access token and a new refresh token from it.
+
+ The consumed refresh token is considered invalid after the first use of the new access token or the new refresh token.
+
+ Args:
+ refresh_token: The token to consume.
+ valid_until_ms: The expiration timestamp of the new access token.
+
+ Returns:
+ A tuple containing the new access token and refresh token
+ """
+
+ # Verify the token signature first before looking up the token
+ if not self._verify_refresh_token(refresh_token):
+ raise SynapseError(401, "invalid refresh token", Codes.UNKNOWN_TOKEN)
+
+ existing_token = await self.store.lookup_refresh_token(refresh_token)
+ if existing_token is None:
+ raise SynapseError(401, "refresh token does not exist", Codes.UNKNOWN_TOKEN)
+
+ if (
+ existing_token.has_next_access_token_been_used
+ or existing_token.has_next_refresh_token_been_refreshed
+ ):
+ raise SynapseError(
+ 403, "refresh token isn't valid anymore", Codes.FORBIDDEN
+ )
+
+ (
+ new_refresh_token,
+ new_refresh_token_id,
+ ) = await self.get_refresh_token_for_user_id(
+ user_id=existing_token.user_id, device_id=existing_token.device_id
+ )
+ access_token = await self.get_access_token_for_user_id(
+ user_id=existing_token.user_id,
+ device_id=existing_token.device_id,
+ valid_until_ms=valid_until_ms,
+ refresh_token_id=new_refresh_token_id,
+ )
+ await self.store.replace_refresh_token(
+ existing_token.token_id, new_refresh_token_id
+ )
+ return access_token, new_refresh_token
+
+ def _verify_refresh_token(self, token: str) -> bool:
+ """
+ Verifies the shape of a refresh token.
+
+ Args:
+ token: The refresh token to verify
+
+ Returns:
+ Whether the token has the right shape
+ """
+ parts = token.split("_", maxsplit=4)
+ if len(parts) != 4:
+ return False
+
+ type, localpart, rand, crc = parts
+
+ # Refresh tokens are prefixed by "syr_", let's check that
+ if type != "syr":
+ return False
+
+ # Check the CRC
+ base = f"{type}_{localpart}_{rand}"
+ expected_crc = base62_encode(crc32(base.encode("ascii")), minwidth=6)
+ if crc != expected_crc:
+ return False
+
+ return True
+
+ async def get_refresh_token_for_user_id(
+ self,
+ user_id: str,
+ device_id: str,
+ ) -> Tuple[str, int]:
+ """
+ Creates a new refresh token for the user with the given user ID.
+
+ Args:
+ user_id: canonical user ID
+ device_id: the device ID to associate with the token.
+
+ Returns:
+ The newly created refresh token and its ID in the database
+ """
+ refresh_token = self.generate_refresh_token(UserID.from_string(user_id))
+ refresh_token_id = await self.store.add_refresh_token_to_user(
+ user_id=user_id,
+ token=refresh_token,
+ device_id=device_id,
+ )
+ return refresh_token, refresh_token_id
+
async def get_access_token_for_user_id(
self,
user_id: str,
@@ -784,6 +888,7 @@ class AuthHandler(BaseHandler):
valid_until_ms: Optional[int],
puppets_user_id: Optional[str] = None,
is_appservice_ghost: bool = False,
+ refresh_token_id: Optional[int] = None,
) -> str:
"""
Creates a new access token for the user with the given user ID.
@@ -801,6 +906,8 @@ class AuthHandler(BaseHandler):
valid_until_ms: when the token is valid until. None for
no expiry.
is_appservice_ghost: Whether the user is an application ghost user
+ refresh_token_id: the refresh token ID that will be associated with
+ this access token.
Returns:
The access token for the user's session.
Raises:
@@ -836,6 +943,7 @@ class AuthHandler(BaseHandler):
device_id=device_id,
valid_until_ms=valid_until_ms,
puppets_user_id=puppets_user_id,
+ refresh_token_id=refresh_token_id,
)
# the device *should* have been registered before we got here; however,
@@ -928,7 +1036,7 @@ class AuthHandler(BaseHandler):
self,
login_submission: Dict[str, Any],
ratelimit: bool = False,
- ) -> Tuple[str, Optional[Callable[[Dict[str, str]], Awaitable[None]]]]:
+ ) -> Tuple[str, Optional[Callable[["LoginResponse"], Awaitable[None]]]]:
"""Authenticates the user for the /login API
Also used by the user-interactive auth flow to validate auth types which don't
@@ -1073,7 +1181,7 @@ class AuthHandler(BaseHandler):
self,
username: str,
login_submission: Dict[str, Any],
- ) -> Tuple[str, Optional[Callable[[Dict[str, str]], Awaitable[None]]]]:
+ ) -> Tuple[str, Optional[Callable[["LoginResponse"], Awaitable[None]]]]:
"""Helper for validate_login
Handles login, once we've mapped 3pids onto userids
@@ -1151,7 +1259,7 @@ class AuthHandler(BaseHandler):
async def check_password_provider_3pid(
self, medium: str, address: str, password: str
- ) -> Tuple[Optional[str], Optional[Callable[[Dict[str, str]], Awaitable[None]]]]:
+ ) -> Tuple[Optional[str], Optional[Callable[["LoginResponse"], Awaitable[None]]]]:
"""Check if a password provider is able to validate a thirdparty login
Args:
@@ -1215,6 +1323,19 @@ class AuthHandler(BaseHandler):
crc = base62_encode(crc32(base.encode("ascii")), minwidth=6)
return f"{base}_{crc}"
+ def generate_refresh_token(self, for_user: UserID) -> str:
+ """Generates an opaque string, for use as a refresh token"""
+
+ # we use the following format for refresh tokens:
+ # syr_<base64 local part>_<random string>_<base62 crc check>
+
+ b64local = unpaddedbase64.encode_base64(for_user.localpart.encode("utf-8"))
+ random_string = stringutils.random_string(20)
+ base = f"syr_{b64local}_{random_string}"
+
+ crc = base62_encode(crc32(base.encode("ascii")), minwidth=6)
+ return f"{base}_{crc}"
+
async def validate_short_term_login_token(
self, login_token: str
) -> LoginTokenAttributes:
@@ -1563,7 +1684,7 @@ class AuthHandler(BaseHandler):
)
respond_with_html(request, 200, html)
- async def _sso_login_callback(self, login_result: JsonDict) -> None:
+ async def _sso_login_callback(self, login_result: "LoginResponse") -> None:
"""
A login callback which might add additional attributes to the login response.
@@ -1577,7 +1698,8 @@ class AuthHandler(BaseHandler):
extra_attributes = self._extra_attributes.get(login_result["user_id"])
if extra_attributes:
- login_result.update(extra_attributes.extra_attributes)
+ login_result_dict = cast(Dict[str, Any], login_result)
+ login_result_dict.update(extra_attributes.extra_attributes)
def _expire_sso_extra_attributes(self) -> None:
"""