summaryrefslogtreecommitdiff
path: root/macaroonbakery/tests/test_client.py
diff options
context:
space:
mode:
Diffstat (limited to 'macaroonbakery/tests/test_client.py')
-rw-r--r--macaroonbakery/tests/test_client.py395
1 files changed, 395 insertions, 0 deletions
diff --git a/macaroonbakery/tests/test_client.py b/macaroonbakery/tests/test_client.py
new file mode 100644
index 0000000..8263f54
--- /dev/null
+++ b/macaroonbakery/tests/test_client.py
@@ -0,0 +1,395 @@
+# Copyright 2017 Canonical Ltd.
+# Licensed under the LGPLv3, see LICENCE file for details.
+import base64
+import datetime
+import json
+import os
+from unittest import TestCase
+try:
+ from BaseHTTPServer import HTTPServer, BaseHTTPRequestHandler
+except ImportError:
+ from http.server import HTTPServer, BaseHTTPRequestHandler
+import threading
+
+from httmock import (
+ HTTMock,
+ urlmatch
+)
+import requests
+from six.moves.urllib.parse import parse_qs
+
+import macaroonbakery as bakery
+import macaroonbakery.httpbakery as httpbakery
+import macaroonbakery.checkers as checkers
+
+AGES = datetime.datetime.utcnow() + datetime.timedelta(days=1)
+TEST_OP = bakery.Op(entity='test', action='test')
+
+
+class TestClient(TestCase):
+ def setUp(self):
+ super(TestClient, self).setUp()
+ # http_proxy would cause requests to talk to the proxy, which is
+ # unlikely to know how to talk to the test server.
+ os.environ.pop('http_proxy', None)
+
+ def test_single_service_first_party(self):
+ b = new_bakery('loc', None, None)
+
+ def handler(*args):
+ GetHandler(b, None, None, None, None, *args)
+ try:
+ httpd = HTTPServer(('', 0), handler)
+ thread = threading.Thread(target=httpd.serve_forever)
+ thread.start()
+ srv_macaroon = b.oven.macaroon(
+ version=bakery.LATEST_VERSION, expiry=AGES,
+ caveats=None, ops=[TEST_OP])
+ self.assertEquals(srv_macaroon.macaroon.location, 'loc')
+ client = httpbakery.Client()
+ client.cookies.set_cookie(requests.cookies.create_cookie(
+ 'macaroon-test', base64.b64encode(json.dumps([
+ srv_macaroon.to_dict().get('m')
+ ]).encode('utf-8')).decode('utf-8')
+ ))
+ resp = requests.get(
+ url='http://' + httpd.server_address[0] + ':' +
+ str(httpd.server_address[1]),
+ cookies=client.cookies, auth=client.auth())
+ resp.raise_for_status()
+ self.assertEquals(resp.text, 'done')
+ finally:
+ httpd.shutdown()
+
+ def test_single_party_with_header(self):
+ b = new_bakery('loc', None, None)
+
+ def handler(*args):
+ GetHandler(b, None, None, None, None, *args)
+ try:
+ httpd = HTTPServer(('', 0), handler)
+ thread = threading.Thread(target=httpd.serve_forever)
+ thread.start()
+ srv_macaroon = b.oven.macaroon(
+ version=bakery.LATEST_VERSION,
+ expiry=AGES, caveats=None, ops=[TEST_OP])
+ self.assertEquals(srv_macaroon.macaroon.location, 'loc')
+ headers = {
+ 'Macaroons': base64.b64encode(json.dumps([
+ srv_macaroon.to_dict().get('m')
+ ]).encode('utf-8'))
+ }
+ resp = requests.get(
+ url='http://' + httpd.server_address[0] + ':' +
+ str(httpd.server_address[1]),
+ headers=headers)
+ resp.raise_for_status()
+ self.assertEquals(resp.text, 'done')
+ finally:
+ httpd.shutdown()
+
+ def test_repeated_request_with_body(self):
+ class _DischargerLocator(bakery.ThirdPartyLocator):
+ def __init__(self):
+ self.key = bakery.generate_key()
+
+ def third_party_info(self, loc):
+ if loc == 'http://1.2.3.4':
+ return bakery.ThirdPartyInfo(
+ public_key=self.key.public_key,
+ version=bakery.LATEST_VERSION,
+ )
+
+ d = _DischargerLocator()
+ b = new_bakery('loc', d, None)
+
+ @urlmatch(path='.*/discharge')
+ def discharge(url, request):
+ qs = parse_qs(request.body)
+ content = {q: qs[q][0] for q in qs}
+ m = httpbakery.discharge(checkers.AuthContext(), content, d.key, d, alwaysOK3rd)
+ return {
+ 'status_code': 200,
+ 'content': {
+ 'Macaroon': m.to_dict()
+ }
+ }
+
+ def handler(*args):
+ GetHandler(b, 'http://1.2.3.4', None, None, None, *args)
+ try:
+ httpd = HTTPServer(('', 0), handler)
+ thread = threading.Thread(target=httpd.serve_forever)
+ thread.start()
+ client = httpbakery.Client()
+ with HTTMock(discharge):
+ resp = requests.get(
+ url='http://' + httpd.server_address[0] + ':' +
+ str(httpd.server_address[1]),
+ cookies=client.cookies,
+ auth=client.auth())
+ resp.raise_for_status()
+ self.assertEquals(resp.text, 'done')
+ finally:
+ httpd.shutdown()
+
+ def test_too_many_discharge(self):
+ class _DischargerLocator(bakery.ThirdPartyLocator):
+ def __init__(self):
+ self.key = bakery.generate_key()
+
+ def third_party_info(self, loc):
+ if loc == 'http://1.2.3.4':
+ return bakery.ThirdPartyInfo(
+ public_key=self.key.public_key,
+ version=bakery.LATEST_VERSION,
+ )
+
+ d = _DischargerLocator()
+ b = new_bakery('loc', d, None)
+
+ @urlmatch(path='.*/discharge')
+ def discharge(url, request):
+ wrong_macaroon = bakery.Macaroon(
+ root_key=b'some key', id=b'xxx',
+ location='some other location',
+ version=bakery.VERSION_0)
+ return {
+ 'status_code': 200,
+ 'content': {
+ 'Macaroon': wrong_macaroon.to_dict()
+ }
+ }
+
+ def handler(*args):
+ GetHandler(b, 'http://1.2.3.4', None, None, None, *args)
+ try:
+ httpd = HTTPServer(('', 0), handler)
+ thread = threading.Thread(target=httpd.serve_forever)
+ thread.start()
+ client = httpbakery.Client()
+ with HTTMock(discharge):
+ with self.assertRaises(httpbakery.BakeryException) as ctx:
+ requests.get(
+ url='http://' + httpd.server_address[0] + ':' +
+ str(httpd.server_address[1]),
+ cookies=client.cookies,
+ auth=client.auth())
+ self.assertEqual(ctx.exception.args[0],
+ 'too many (3) discharge requests')
+ finally:
+ httpd.shutdown()
+
+ def test_third_party_discharge_refused(self):
+ class _DischargerLocator(bakery.ThirdPartyLocator):
+ def __init__(self):
+ self.key = bakery.generate_key()
+
+ def third_party_info(self, loc):
+ if loc == 'http://1.2.3.4':
+ return bakery.ThirdPartyInfo(
+ public_key=self.key.public_key,
+ version=bakery.LATEST_VERSION,
+ )
+
+ def check(cond, arg):
+ raise bakery.ThirdPartyCaveatCheckFailed('boo! cond' + cond)
+
+ d = _DischargerLocator()
+ b = new_bakery('loc', d, None)
+
+ @urlmatch(path='.*/discharge')
+ def discharge(url, request):
+ qs = parse_qs(request.body)
+ content = {q: qs[q][0] for q in qs}
+ httpbakery.discharge(checkers.AuthContext(), content, d.key, d,
+ ThirdPartyCaveatCheckerF(check))
+
+ def handler(*args):
+ GetHandler(b, 'http://1.2.3.4', None, None, None, *args)
+ try:
+ httpd = HTTPServer(('', 0), handler)
+ thread = threading.Thread(target=httpd.serve_forever)
+ thread.start()
+ client = httpbakery.Client()
+ with HTTMock(discharge):
+ with self.assertRaises(bakery.ThirdPartyCaveatCheckFailed):
+ requests.get(
+ url='http://' + httpd.server_address[0] + ':' +
+ str(httpd.server_address[1]),
+ cookies=client.cookies,
+ auth=client.auth())
+ finally:
+ httpd.shutdown()
+
+ def test_discharge_with_interaction_required_error(self):
+ class _DischargerLocator(bakery.ThirdPartyLocator):
+ def __init__(self):
+ self.key = bakery.generate_key()
+
+ def third_party_info(self, loc):
+ if loc == 'http://1.2.3.4':
+ return bakery.ThirdPartyInfo(
+ public_key=self.key.public_key,
+ version=bakery.LATEST_VERSION,
+ )
+ d = _DischargerLocator()
+ b = new_bakery('loc', d, None)
+
+ @urlmatch(path='.*/discharge')
+ def discharge(url, request):
+ return {
+ 'status_code': 401,
+ 'content': {
+ 'Code': httpbakery.ERR_INTERACTION_REQUIRED,
+ 'Message': 'interaction required',
+ 'Info': {
+ 'WaitURL': 'http://0.1.2.3/',
+ 'VisitURL': 'http://0.1.2.3/',
+ },
+ }
+ }
+
+ def handler(*args):
+ GetHandler(b, 'http://1.2.3.4', None, None, None, *args)
+
+ try:
+ httpd = HTTPServer(('', 0), handler)
+ thread = threading.Thread(target=httpd.serve_forever)
+ thread.start()
+
+ class MyInteractor(httpbakery.LegacyInteractor):
+ def legacy_interact(self, ctx, location, visit_url):
+ raise httpbakery.InteractionError('cannot visit')
+
+ def interact(self, ctx, location, interaction_required_err):
+ pass
+
+ def kind(self):
+ return httpbakery.WEB_BROWSER_INTERACTION_KIND
+
+ client = httpbakery.Client(interaction_methods=[MyInteractor()])
+
+ with HTTMock(discharge):
+ with self.assertRaises(httpbakery.InteractionError):
+ requests.get(
+ 'http://' + httpd.server_address[0] + ':' + str(
+ httpd.server_address[1]),
+ cookies=client.cookies,
+ auth=client.auth())
+ finally:
+ httpd.shutdown()
+
+
+class GetHandler(BaseHTTPRequestHandler):
+ '''A mock HTTP server that serves a GET request'''
+ def __init__(self, bakery, auth_location, mutate_error,
+ caveats, version, *args):
+ '''
+ @param bakery used to check incoming requests and macaroons
+ for discharge-required errors.
+ @param auth_location holds the location of any 3rd party
+ authorizer. If this is not None, a 3rd party caveat will be
+ added addressed to this location.
+ @param mutate_error if non None, will be called with any
+ discharge-required error before responding to the client.
+ @param caveats called to get caveats to add to the returned
+ macaroon.
+ @param holds the version of the bakery that the
+ server will purport to serve.
+ '''
+ self._bakery = bakery
+ self._auth_location = auth_location
+ self._mutate_error = mutate_error
+ self._caveats = caveats
+ self._server_version = version
+ BaseHTTPRequestHandler.__init__(self, *args)
+
+ def do_GET(self):
+ '''do_GET implements a handler for the HTTP GET method'''
+ ctx = checkers.AuthContext()
+ auth_checker = self._bakery.checker.auth(
+ httpbakery.extract_macaroons(self.headers))
+ try:
+ auth_checker.allow(ctx, [TEST_OP])
+ except (bakery.PermissionDenied,
+ bakery.VerificationError) as exc:
+ return self._write_discharge_error(exc)
+ self.send_response(200)
+ self.end_headers()
+ content_len = int(self.headers.get('content-length', 0))
+ content = 'done'
+ if self.path != '/no-body'and content_len > 0:
+ body = self.rfile.read(content_len)
+ content = content + ' ' + body
+ self.wfile.write(content.encode('utf-8'))
+ return
+
+ def _write_discharge_error(self, exc):
+ version = httpbakery.request_version(self.headers)
+ if version < bakery.LATEST_VERSION:
+ self._server_version = version
+
+ caveats = []
+ if self._auth_location != '':
+ caveats = [
+ checkers.Caveat(location=self._auth_location,
+ condition='is-ok')
+ ]
+ if self._caveats is not None:
+ caveats.extend(self._caveats)
+
+ m = self._bakery.oven.macaroon(
+ version=bakery.LATEST_VERSION, expiry=AGES,
+ caveats=caveats, ops=[TEST_OP])
+
+ content, headers = httpbakery.discharge_required_response(
+ m, '/', 'test', exc.args[0])
+ self.send_response(401)
+ for h in headers:
+ self.send_header(h, headers[h])
+ self.send_header('Connection', 'close')
+ self.end_headers()
+ self.wfile.write(content)
+
+
+def new_bakery(location, locator, checker):
+ '''Return a new bakery instance.
+ @param location Location of the bakery {str}.
+ @param locator Locator for third parties {ThirdPartyLocator or None}
+ @param checker Caveat checker {FirstPartyCaveatChecker or None}
+ @return {Bakery}
+ '''
+ if checker is None:
+ c = checkers.Checker()
+ c.namespace().register('testns', '')
+ c.register('is', 'testns', check_is_something)
+ checker = c
+ key = bakery.generate_key()
+ return bakery.Bakery(
+ location=location,
+ locator=locator,
+ key=key,
+ checker=checker,
+ )
+
+
+def is_something_caveat():
+ return checkers.Caveat(condition='is something', namespace='testns')
+
+
+def check_is_something(ctx, cond, arg):
+ if arg != 'something':
+ return '{} doesn\'t match "something"'.format(arg)
+ return None
+
+
+class ThirdPartyCaveatCheckerF(bakery.ThirdPartyCaveatChecker):
+ def __init__(self, check):
+ self._check = check
+
+ def check_third_party_caveat(self, ctx, info):
+ cond, arg = checkers.parse_caveat(info.condition)
+ return self._check(cond, arg)
+
+alwaysOK3rd = ThirdPartyCaveatCheckerF(lambda cond, arg: [])