diff options
Diffstat (limited to 'macaroonbakery/tests/test_client.py')
-rw-r--r-- | macaroonbakery/tests/test_client.py | 395 |
1 files changed, 395 insertions, 0 deletions
diff --git a/macaroonbakery/tests/test_client.py b/macaroonbakery/tests/test_client.py new file mode 100644 index 0000000..8263f54 --- /dev/null +++ b/macaroonbakery/tests/test_client.py @@ -0,0 +1,395 @@ +# Copyright 2017 Canonical Ltd. +# Licensed under the LGPLv3, see LICENCE file for details. +import base64 +import datetime +import json +import os +from unittest import TestCase +try: + from BaseHTTPServer import HTTPServer, BaseHTTPRequestHandler +except ImportError: + from http.server import HTTPServer, BaseHTTPRequestHandler +import threading + +from httmock import ( + HTTMock, + urlmatch +) +import requests +from six.moves.urllib.parse import parse_qs + +import macaroonbakery as bakery +import macaroonbakery.httpbakery as httpbakery +import macaroonbakery.checkers as checkers + +AGES = datetime.datetime.utcnow() + datetime.timedelta(days=1) +TEST_OP = bakery.Op(entity='test', action='test') + + +class TestClient(TestCase): + def setUp(self): + super(TestClient, self).setUp() + # http_proxy would cause requests to talk to the proxy, which is + # unlikely to know how to talk to the test server. + os.environ.pop('http_proxy', None) + + def test_single_service_first_party(self): + b = new_bakery('loc', None, None) + + def handler(*args): + GetHandler(b, None, None, None, None, *args) + try: + httpd = HTTPServer(('', 0), handler) + thread = threading.Thread(target=httpd.serve_forever) + thread.start() + srv_macaroon = b.oven.macaroon( + version=bakery.LATEST_VERSION, expiry=AGES, + caveats=None, ops=[TEST_OP]) + self.assertEquals(srv_macaroon.macaroon.location, 'loc') + client = httpbakery.Client() + client.cookies.set_cookie(requests.cookies.create_cookie( + 'macaroon-test', base64.b64encode(json.dumps([ + srv_macaroon.to_dict().get('m') + ]).encode('utf-8')).decode('utf-8') + )) + resp = requests.get( + url='http://' + httpd.server_address[0] + ':' + + str(httpd.server_address[1]), + cookies=client.cookies, auth=client.auth()) + resp.raise_for_status() + self.assertEquals(resp.text, 'done') + finally: + httpd.shutdown() + + def test_single_party_with_header(self): + b = new_bakery('loc', None, None) + + def handler(*args): + GetHandler(b, None, None, None, None, *args) + try: + httpd = HTTPServer(('', 0), handler) + thread = threading.Thread(target=httpd.serve_forever) + thread.start() + srv_macaroon = b.oven.macaroon( + version=bakery.LATEST_VERSION, + expiry=AGES, caveats=None, ops=[TEST_OP]) + self.assertEquals(srv_macaroon.macaroon.location, 'loc') + headers = { + 'Macaroons': base64.b64encode(json.dumps([ + srv_macaroon.to_dict().get('m') + ]).encode('utf-8')) + } + resp = requests.get( + url='http://' + httpd.server_address[0] + ':' + + str(httpd.server_address[1]), + headers=headers) + resp.raise_for_status() + self.assertEquals(resp.text, 'done') + finally: + httpd.shutdown() + + def test_repeated_request_with_body(self): + class _DischargerLocator(bakery.ThirdPartyLocator): + def __init__(self): + self.key = bakery.generate_key() + + def third_party_info(self, loc): + if loc == 'http://1.2.3.4': + return bakery.ThirdPartyInfo( + public_key=self.key.public_key, + version=bakery.LATEST_VERSION, + ) + + d = _DischargerLocator() + b = new_bakery('loc', d, None) + + @urlmatch(path='.*/discharge') + def discharge(url, request): + qs = parse_qs(request.body) + content = {q: qs[q][0] for q in qs} + m = httpbakery.discharge(checkers.AuthContext(), content, d.key, d, alwaysOK3rd) + return { + 'status_code': 200, + 'content': { + 'Macaroon': m.to_dict() + } + } + + def handler(*args): + GetHandler(b, 'http://1.2.3.4', None, None, None, *args) + try: + httpd = HTTPServer(('', 0), handler) + thread = threading.Thread(target=httpd.serve_forever) + thread.start() + client = httpbakery.Client() + with HTTMock(discharge): + resp = requests.get( + url='http://' + httpd.server_address[0] + ':' + + str(httpd.server_address[1]), + cookies=client.cookies, + auth=client.auth()) + resp.raise_for_status() + self.assertEquals(resp.text, 'done') + finally: + httpd.shutdown() + + def test_too_many_discharge(self): + class _DischargerLocator(bakery.ThirdPartyLocator): + def __init__(self): + self.key = bakery.generate_key() + + def third_party_info(self, loc): + if loc == 'http://1.2.3.4': + return bakery.ThirdPartyInfo( + public_key=self.key.public_key, + version=bakery.LATEST_VERSION, + ) + + d = _DischargerLocator() + b = new_bakery('loc', d, None) + + @urlmatch(path='.*/discharge') + def discharge(url, request): + wrong_macaroon = bakery.Macaroon( + root_key=b'some key', id=b'xxx', + location='some other location', + version=bakery.VERSION_0) + return { + 'status_code': 200, + 'content': { + 'Macaroon': wrong_macaroon.to_dict() + } + } + + def handler(*args): + GetHandler(b, 'http://1.2.3.4', None, None, None, *args) + try: + httpd = HTTPServer(('', 0), handler) + thread = threading.Thread(target=httpd.serve_forever) + thread.start() + client = httpbakery.Client() + with HTTMock(discharge): + with self.assertRaises(httpbakery.BakeryException) as ctx: + requests.get( + url='http://' + httpd.server_address[0] + ':' + + str(httpd.server_address[1]), + cookies=client.cookies, + auth=client.auth()) + self.assertEqual(ctx.exception.args[0], + 'too many (3) discharge requests') + finally: + httpd.shutdown() + + def test_third_party_discharge_refused(self): + class _DischargerLocator(bakery.ThirdPartyLocator): + def __init__(self): + self.key = bakery.generate_key() + + def third_party_info(self, loc): + if loc == 'http://1.2.3.4': + return bakery.ThirdPartyInfo( + public_key=self.key.public_key, + version=bakery.LATEST_VERSION, + ) + + def check(cond, arg): + raise bakery.ThirdPartyCaveatCheckFailed('boo! cond' + cond) + + d = _DischargerLocator() + b = new_bakery('loc', d, None) + + @urlmatch(path='.*/discharge') + def discharge(url, request): + qs = parse_qs(request.body) + content = {q: qs[q][0] for q in qs} + httpbakery.discharge(checkers.AuthContext(), content, d.key, d, + ThirdPartyCaveatCheckerF(check)) + + def handler(*args): + GetHandler(b, 'http://1.2.3.4', None, None, None, *args) + try: + httpd = HTTPServer(('', 0), handler) + thread = threading.Thread(target=httpd.serve_forever) + thread.start() + client = httpbakery.Client() + with HTTMock(discharge): + with self.assertRaises(bakery.ThirdPartyCaveatCheckFailed): + requests.get( + url='http://' + httpd.server_address[0] + ':' + + str(httpd.server_address[1]), + cookies=client.cookies, + auth=client.auth()) + finally: + httpd.shutdown() + + def test_discharge_with_interaction_required_error(self): + class _DischargerLocator(bakery.ThirdPartyLocator): + def __init__(self): + self.key = bakery.generate_key() + + def third_party_info(self, loc): + if loc == 'http://1.2.3.4': + return bakery.ThirdPartyInfo( + public_key=self.key.public_key, + version=bakery.LATEST_VERSION, + ) + d = _DischargerLocator() + b = new_bakery('loc', d, None) + + @urlmatch(path='.*/discharge') + def discharge(url, request): + return { + 'status_code': 401, + 'content': { + 'Code': httpbakery.ERR_INTERACTION_REQUIRED, + 'Message': 'interaction required', + 'Info': { + 'WaitURL': 'http://0.1.2.3/', + 'VisitURL': 'http://0.1.2.3/', + }, + } + } + + def handler(*args): + GetHandler(b, 'http://1.2.3.4', None, None, None, *args) + + try: + httpd = HTTPServer(('', 0), handler) + thread = threading.Thread(target=httpd.serve_forever) + thread.start() + + class MyInteractor(httpbakery.LegacyInteractor): + def legacy_interact(self, ctx, location, visit_url): + raise httpbakery.InteractionError('cannot visit') + + def interact(self, ctx, location, interaction_required_err): + pass + + def kind(self): + return httpbakery.WEB_BROWSER_INTERACTION_KIND + + client = httpbakery.Client(interaction_methods=[MyInteractor()]) + + with HTTMock(discharge): + with self.assertRaises(httpbakery.InteractionError): + requests.get( + 'http://' + httpd.server_address[0] + ':' + str( + httpd.server_address[1]), + cookies=client.cookies, + auth=client.auth()) + finally: + httpd.shutdown() + + +class GetHandler(BaseHTTPRequestHandler): + '''A mock HTTP server that serves a GET request''' + def __init__(self, bakery, auth_location, mutate_error, + caveats, version, *args): + ''' + @param bakery used to check incoming requests and macaroons + for discharge-required errors. + @param auth_location holds the location of any 3rd party + authorizer. If this is not None, a 3rd party caveat will be + added addressed to this location. + @param mutate_error if non None, will be called with any + discharge-required error before responding to the client. + @param caveats called to get caveats to add to the returned + macaroon. + @param holds the version of the bakery that the + server will purport to serve. + ''' + self._bakery = bakery + self._auth_location = auth_location + self._mutate_error = mutate_error + self._caveats = caveats + self._server_version = version + BaseHTTPRequestHandler.__init__(self, *args) + + def do_GET(self): + '''do_GET implements a handler for the HTTP GET method''' + ctx = checkers.AuthContext() + auth_checker = self._bakery.checker.auth( + httpbakery.extract_macaroons(self.headers)) + try: + auth_checker.allow(ctx, [TEST_OP]) + except (bakery.PermissionDenied, + bakery.VerificationError) as exc: + return self._write_discharge_error(exc) + self.send_response(200) + self.end_headers() + content_len = int(self.headers.get('content-length', 0)) + content = 'done' + if self.path != '/no-body'and content_len > 0: + body = self.rfile.read(content_len) + content = content + ' ' + body + self.wfile.write(content.encode('utf-8')) + return + + def _write_discharge_error(self, exc): + version = httpbakery.request_version(self.headers) + if version < bakery.LATEST_VERSION: + self._server_version = version + + caveats = [] + if self._auth_location != '': + caveats = [ + checkers.Caveat(location=self._auth_location, + condition='is-ok') + ] + if self._caveats is not None: + caveats.extend(self._caveats) + + m = self._bakery.oven.macaroon( + version=bakery.LATEST_VERSION, expiry=AGES, + caveats=caveats, ops=[TEST_OP]) + + content, headers = httpbakery.discharge_required_response( + m, '/', 'test', exc.args[0]) + self.send_response(401) + for h in headers: + self.send_header(h, headers[h]) + self.send_header('Connection', 'close') + self.end_headers() + self.wfile.write(content) + + +def new_bakery(location, locator, checker): + '''Return a new bakery instance. + @param location Location of the bakery {str}. + @param locator Locator for third parties {ThirdPartyLocator or None} + @param checker Caveat checker {FirstPartyCaveatChecker or None} + @return {Bakery} + ''' + if checker is None: + c = checkers.Checker() + c.namespace().register('testns', '') + c.register('is', 'testns', check_is_something) + checker = c + key = bakery.generate_key() + return bakery.Bakery( + location=location, + locator=locator, + key=key, + checker=checker, + ) + + +def is_something_caveat(): + return checkers.Caveat(condition='is something', namespace='testns') + + +def check_is_something(ctx, cond, arg): + if arg != 'something': + return '{} doesn\'t match "something"'.format(arg) + return None + + +class ThirdPartyCaveatCheckerF(bakery.ThirdPartyCaveatChecker): + def __init__(self, check): + self._check = check + + def check_third_party_caveat(self, ctx, info): + cond, arg = checkers.parse_caveat(info.condition) + return self._check(cond, arg) + +alwaysOK3rd = ThirdPartyCaveatCheckerF(lambda cond, arg: []) |