summaryrefslogtreecommitdiff
path: root/tests/rest/client/v1/test_login.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/rest/client/v1/test_login.py')
-rw-r--r--tests/rest/client/v1/test_login.py271
1 files changed, 260 insertions, 11 deletions
diff --git a/tests/rest/client/v1/test_login.py b/tests/rest/client/v1/test_login.py
index 1856c7ff..9033f09f 100644
--- a/tests/rest/client/v1/test_login.py
+++ b/tests/rest/client/v1/test_login.py
@@ -1,10 +1,13 @@
import json
+import time
import urllib.parse
from mock import Mock
+import jwt
+
import synapse.rest.admin
-from synapse.rest.client.v1 import login
+from synapse.rest.client.v1 import login, logout
from synapse.rest.client.v2_alpha import devices
from synapse.rest.client.v2_alpha.account import WhoamiRestServlet
@@ -20,12 +23,12 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
servlets = [
synapse.rest.admin.register_servlets_for_client_rest_resource,
login.register_servlets,
+ logout.register_servlets,
devices.register_servlets,
lambda hs, http_server: WhoamiRestServlet(hs).register(http_server),
]
def make_homeserver(self, reactor, clock):
-
self.hs = self.setup_test_homeserver()
self.hs.config.enable_registration = True
self.hs.config.registrations_require_3pid = []
@@ -34,10 +37,20 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
return self.hs
+ @override_config(
+ {
+ "rc_login": {
+ "address": {"per_second": 0.17, "burst_count": 5},
+ # Prevent the account login ratelimiter from raising first
+ #
+ # This is normally covered by the default test homeserver config
+ # which sets these values to 10000, but as we're overriding the entire
+ # rc_login dict here, we need to set this manually as well
+ "account": {"per_second": 10000, "burst_count": 10000},
+ }
+ }
+ )
def test_POST_ratelimiting_per_address(self):
- self.hs.config.rc_login_address.burst_count = 5
- self.hs.config.rc_login_address.per_second = 0.17
-
# Create different users so we're sure not to be bothered by the per-user
# ratelimiter.
for i in range(0, 6):
@@ -76,10 +89,20 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
self.assertEquals(channel.result["code"], b"200", channel.result)
+ @override_config(
+ {
+ "rc_login": {
+ "account": {"per_second": 0.17, "burst_count": 5},
+ # Prevent the address login ratelimiter from raising first
+ #
+ # This is normally covered by the default test homeserver config
+ # which sets these values to 10000, but as we're overriding the entire
+ # rc_login dict here, we need to set this manually as well
+ "address": {"per_second": 10000, "burst_count": 10000},
+ }
+ }
+ )
def test_POST_ratelimiting_per_account(self):
- self.hs.config.rc_login_account.burst_count = 5
- self.hs.config.rc_login_account.per_second = 0.17
-
self.register_user("kermit", "monkey")
for i in range(0, 6):
@@ -115,10 +138,20 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
self.assertEquals(channel.result["code"], b"200", channel.result)
+ @override_config(
+ {
+ "rc_login": {
+ # Prevent the address login ratelimiter from raising first
+ #
+ # This is normally covered by the default test homeserver config
+ # which sets these values to 10000, but as we're overriding the entire
+ # rc_login dict here, we need to set this manually as well
+ "address": {"per_second": 10000, "burst_count": 10000},
+ "failed_attempts": {"per_second": 0.17, "burst_count": 5},
+ }
+ }
+ )
def test_POST_ratelimiting_per_account_failed_attempts(self):
- self.hs.config.rc_login_failed_attempts.burst_count = 5
- self.hs.config.rc_login_failed_attempts.per_second = 0.17
-
self.register_user("kermit", "monkey")
for i in range(0, 6):
@@ -256,6 +289,72 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
self.render(request)
self.assertEquals(channel.code, 200, channel.result)
+ @override_config({"session_lifetime": "24h"})
+ def test_session_can_hard_logout_after_being_soft_logged_out(self):
+ self.register_user("kermit", "monkey")
+
+ # log in as normal
+ access_token = self.login("kermit", "monkey")
+
+ # we should now be able to make requests with the access token
+ request, channel = self.make_request(
+ b"GET", TEST_URL, access_token=access_token
+ )
+ self.render(request)
+ self.assertEquals(channel.code, 200, channel.result)
+
+ # time passes
+ self.reactor.advance(24 * 3600)
+
+ # ... and we should be soft-logouted
+ request, channel = self.make_request(
+ b"GET", TEST_URL, access_token=access_token
+ )
+ self.render(request)
+ self.assertEquals(channel.code, 401, channel.result)
+ self.assertEquals(channel.json_body["errcode"], "M_UNKNOWN_TOKEN")
+ self.assertEquals(channel.json_body["soft_logout"], True)
+
+ # Now try to hard logout this session
+ request, channel = self.make_request(
+ b"POST", "/logout", access_token=access_token
+ )
+ self.render(request)
+ self.assertEquals(channel.result["code"], b"200", channel.result)
+
+ @override_config({"session_lifetime": "24h"})
+ def test_session_can_hard_logout_all_sessions_after_being_soft_logged_out(self):
+ self.register_user("kermit", "monkey")
+
+ # log in as normal
+ access_token = self.login("kermit", "monkey")
+
+ # we should now be able to make requests with the access token
+ request, channel = self.make_request(
+ b"GET", TEST_URL, access_token=access_token
+ )
+ self.render(request)
+ self.assertEquals(channel.code, 200, channel.result)
+
+ # time passes
+ self.reactor.advance(24 * 3600)
+
+ # ... and we should be soft-logouted
+ request, channel = self.make_request(
+ b"GET", TEST_URL, access_token=access_token
+ )
+ self.render(request)
+ self.assertEquals(channel.code, 401, channel.result)
+ self.assertEquals(channel.json_body["errcode"], "M_UNKNOWN_TOKEN")
+ self.assertEquals(channel.json_body["soft_logout"], True)
+
+ # Now try to hard log out all of the user's sessions
+ request, channel = self.make_request(
+ b"POST", "/logout/all", access_token=access_token
+ )
+ self.render(request)
+ self.assertEquals(channel.result["code"], b"200", channel.result)
+
class CASTestCase(unittest.HomeserverTestCase):
@@ -406,3 +505,153 @@ class CASTestCase(unittest.HomeserverTestCase):
# Because the user is deactivated they are served an error template.
self.assertEqual(channel.code, 403)
self.assertIn(b"SSO account deactivated", channel.result["body"])
+
+
+class JWTTestCase(unittest.HomeserverTestCase):
+ servlets = [
+ synapse.rest.admin.register_servlets_for_client_rest_resource,
+ login.register_servlets,
+ ]
+
+ jwt_secret = "secret"
+
+ def make_homeserver(self, reactor, clock):
+ self.hs = self.setup_test_homeserver()
+ self.hs.config.jwt_enabled = True
+ self.hs.config.jwt_secret = self.jwt_secret
+ self.hs.config.jwt_algorithm = "HS256"
+ return self.hs
+
+ def jwt_encode(self, token, secret=jwt_secret):
+ return jwt.encode(token, secret, "HS256").decode("ascii")
+
+ def jwt_login(self, *args):
+ params = json.dumps({"type": "m.login.jwt", "token": self.jwt_encode(*args)})
+ request, channel = self.make_request(b"POST", LOGIN_URL, params)
+ self.render(request)
+ return channel
+
+ def test_login_jwt_valid_registered(self):
+ self.register_user("kermit", "monkey")
+ channel = self.jwt_login({"sub": "kermit"})
+ self.assertEqual(channel.result["code"], b"200", channel.result)
+ self.assertEqual(channel.json_body["user_id"], "@kermit:test")
+
+ def test_login_jwt_valid_unregistered(self):
+ channel = self.jwt_login({"sub": "frog"})
+ self.assertEqual(channel.result["code"], b"200", channel.result)
+ self.assertEqual(channel.json_body["user_id"], "@frog:test")
+
+ def test_login_jwt_invalid_signature(self):
+ channel = self.jwt_login({"sub": "frog"}, "notsecret")
+ self.assertEqual(channel.result["code"], b"401", channel.result)
+ self.assertEqual(channel.json_body["errcode"], "M_UNAUTHORIZED")
+ self.assertEqual(channel.json_body["error"], "Invalid JWT")
+
+ def test_login_jwt_expired(self):
+ channel = self.jwt_login({"sub": "frog", "exp": 864000})
+ self.assertEqual(channel.result["code"], b"401", channel.result)
+ self.assertEqual(channel.json_body["errcode"], "M_UNAUTHORIZED")
+ self.assertEqual(channel.json_body["error"], "JWT expired")
+
+ def test_login_jwt_not_before(self):
+ now = int(time.time())
+ channel = self.jwt_login({"sub": "frog", "nbf": now + 3600})
+ self.assertEqual(channel.result["code"], b"401", channel.result)
+ self.assertEqual(channel.json_body["errcode"], "M_UNAUTHORIZED")
+ self.assertEqual(channel.json_body["error"], "Invalid JWT")
+
+ def test_login_no_sub(self):
+ channel = self.jwt_login({"username": "root"})
+ self.assertEqual(channel.result["code"], b"401", channel.result)
+ self.assertEqual(channel.json_body["errcode"], "M_UNAUTHORIZED")
+ self.assertEqual(channel.json_body["error"], "Invalid JWT")
+
+ def test_login_no_token(self):
+ params = json.dumps({"type": "m.login.jwt"})
+ request, channel = self.make_request(b"POST", LOGIN_URL, params)
+ self.render(request)
+ self.assertEqual(channel.result["code"], b"401", channel.result)
+ self.assertEqual(channel.json_body["errcode"], "M_UNAUTHORIZED")
+ self.assertEqual(channel.json_body["error"], "Token field for JWT is missing")
+
+
+# The JWTPubKeyTestCase is a complement to JWTTestCase where we instead use
+# RSS256, with a public key configured in synapse as "jwt_secret", and tokens
+# signed by the private key.
+class JWTPubKeyTestCase(unittest.HomeserverTestCase):
+ servlets = [
+ login.register_servlets,
+ ]
+
+ # This key's pubkey is used as the jwt_secret setting of synapse. Valid
+ # tokens are signed by this and validated using the pubkey. It is generated
+ # with `openssl genrsa 512` (not a secure way to generate real keys, but
+ # good enough for tests!)
+ jwt_privatekey = "\n".join(
+ [
+ "-----BEGIN RSA PRIVATE KEY-----",
+ "MIIBPAIBAAJBAM50f1Q5gsdmzifLstzLHb5NhfajiOt7TKO1vSEWdq7u9x8SMFiB",
+ "492RM9W/XFoh8WUfL9uL6Now6tPRDsWv3xsCAwEAAQJAUv7OOSOtiU+wzJq82rnk",
+ "yR4NHqt7XX8BvkZPM7/+EjBRanmZNSp5kYZzKVaZ/gTOM9+9MwlmhidrUOweKfB/",
+ "kQIhAPZwHazbjo7dYlJs7wPQz1vd+aHSEH+3uQKIysebkmm3AiEA1nc6mDdmgiUq",
+ "TpIN8A4MBKmfZMWTLq6z05y/qjKyxb0CIQDYJxCwTEenIaEa4PdoJl+qmXFasVDN",
+ "ZU0+XtNV7yul0wIhAMI9IhiStIjS2EppBa6RSlk+t1oxh2gUWlIh+YVQfZGRAiEA",
+ "tqBR7qLZGJ5CVKxWmNhJZGt1QHoUtOch8t9C4IdOZ2g=",
+ "-----END RSA PRIVATE KEY-----",
+ ]
+ )
+
+ # Generated with `openssl rsa -in foo.key -pubout`, with the the above
+ # private key placed in foo.key (jwt_privatekey).
+ jwt_pubkey = "\n".join(
+ [
+ "-----BEGIN PUBLIC KEY-----",
+ "MFwwDQYJKoZIhvcNAQEBBQADSwAwSAJBAM50f1Q5gsdmzifLstzLHb5NhfajiOt7",
+ "TKO1vSEWdq7u9x8SMFiB492RM9W/XFoh8WUfL9uL6Now6tPRDsWv3xsCAwEAAQ==",
+ "-----END PUBLIC KEY-----",
+ ]
+ )
+
+ # This key is used to sign tokens that shouldn't be accepted by synapse.
+ # Generated just like jwt_privatekey.
+ bad_privatekey = "\n".join(
+ [
+ "-----BEGIN RSA PRIVATE KEY-----",
+ "MIIBOgIBAAJBAL//SQrKpKbjCCnv/FlasJCv+t3k/MPsZfniJe4DVFhsktF2lwQv",
+ "gLjmQD3jBUTz+/FndLSBvr3F4OHtGL9O/osCAwEAAQJAJqH0jZJW7Smzo9ShP02L",
+ "R6HRZcLExZuUrWI+5ZSP7TaZ1uwJzGFspDrunqaVoPobndw/8VsP8HFyKtceC7vY",
+ "uQIhAPdYInDDSJ8rFKGiy3Ajv5KWISBicjevWHF9dbotmNO9AiEAxrdRJVU+EI9I",
+ "eB4qRZpY6n4pnwyP0p8f/A3NBaQPG+cCIFlj08aW/PbxNdqYoBdeBA0xDrXKfmbb",
+ "iwYxBkwL0JCtAiBYmsi94sJn09u2Y4zpuCbJeDPKzWkbuwQh+W1fhIWQJQIhAKR0",
+ "KydN6cRLvphNQ9c/vBTdlzWxzcSxREpguC7F1J1m",
+ "-----END RSA PRIVATE KEY-----",
+ ]
+ )
+
+ def make_homeserver(self, reactor, clock):
+ self.hs = self.setup_test_homeserver()
+ self.hs.config.jwt_enabled = True
+ self.hs.config.jwt_secret = self.jwt_pubkey
+ self.hs.config.jwt_algorithm = "RS256"
+ return self.hs
+
+ def jwt_encode(self, token, secret=jwt_privatekey):
+ return jwt.encode(token, secret, "RS256").decode("ascii")
+
+ def jwt_login(self, *args):
+ params = json.dumps({"type": "m.login.jwt", "token": self.jwt_encode(*args)})
+ request, channel = self.make_request(b"POST", LOGIN_URL, params)
+ self.render(request)
+ return channel
+
+ def test_login_jwt_valid(self):
+ channel = self.jwt_login({"sub": "kermit"})
+ self.assertEqual(channel.result["code"], b"200", channel.result)
+ self.assertEqual(channel.json_body["user_id"], "@kermit:test")
+
+ def test_login_jwt_invalid_signature(self):
+ channel = self.jwt_login({"sub": "frog"}, self.bad_privatekey)
+ self.assertEqual(channel.result["code"], b"401", channel.result)
+ self.assertEqual(channel.json_body["errcode"], "M_UNAUTHORIZED")
+ self.assertEqual(channel.json_body["error"], "Invalid JWT")