summaryrefslogtreecommitdiff
path: root/macaroonbakery/tests/test_discharge_all.py
diff options
context:
space:
mode:
Diffstat (limited to 'macaroonbakery/tests/test_discharge_all.py')
-rw-r--r--macaroonbakery/tests/test_discharge_all.py170
1 files changed, 170 insertions, 0 deletions
diff --git a/macaroonbakery/tests/test_discharge_all.py b/macaroonbakery/tests/test_discharge_all.py
new file mode 100644
index 0000000..8da8823
--- /dev/null
+++ b/macaroonbakery/tests/test_discharge_all.py
@@ -0,0 +1,170 @@
+# Copyright 2017 Canonical Ltd.
+# Licensed under the LGPLv3, see LICENCE file for details.
+import unittest
+
+from pymacaroons.verifier import Verifier
+
+import macaroonbakery
+import macaroonbakery.checkers as checkers
+from macaroonbakery.tests import common
+
+
+def always_ok(predicate):
+ return True
+
+
+class TestDischargeAll(unittest.TestCase):
+ def test_discharge_all_no_discharges(self):
+ root_key = b'root key'
+ m = macaroonbakery.Macaroon(
+ root_key=root_key, id=b'id0', location='loc0',
+ version=macaroonbakery.LATEST_BAKERY_VERSION,
+ namespace=common.test_checker().namespace())
+ ms = macaroonbakery.discharge_all(
+ common.test_context, m, no_discharge(self))
+ self.assertEqual(len(ms), 1)
+ self.assertEqual(ms[0], m.macaroon)
+ v = Verifier()
+ v.satisfy_general(always_ok)
+ v.verify(m.macaroon, root_key, None)
+
+ def test_discharge_all_many_discharges(self):
+ root_key = b'root key'
+ m0 = macaroonbakery.Macaroon(
+ root_key=root_key, id=b'id0', location='loc0',
+ version=macaroonbakery.LATEST_BAKERY_VERSION)
+
+ class State(object):
+ total_required = 40
+ id = 1
+
+ def add_caveats(m):
+ for i in range(0, 1):
+ if State.total_required == 0:
+ break
+ cid = 'id{}'.format(State.id)
+ m.macaroon.add_third_party_caveat(
+ location='somewhere',
+ key='root key {}'.format(cid).encode('utf-8'),
+ key_id=cid.encode('utf-8'))
+ State.id += 1
+ State.total_required -= 1
+
+ add_caveats(m0)
+
+ def get_discharge(_, cav, payload):
+ self.assertEqual(payload, None)
+ m = macaroonbakery.Macaroon(
+ root_key='root key {}'.format(
+ cav.caveat_id.decode('utf-8')).encode('utf-8'),
+ id=cav.caveat_id, location='',
+ version=macaroonbakery.LATEST_BAKERY_VERSION)
+
+ add_caveats(m)
+ return m
+
+ ms = macaroonbakery.discharge_all(
+ common.test_context, m0, get_discharge)
+
+ self.assertEqual(len(ms), 41)
+
+ v = Verifier()
+ v.satisfy_general(always_ok)
+ v.verify(ms[0], root_key, ms[1:])
+
+ def test_discharge_all_many_discharges_with_real_third_party_caveats(self):
+ # This is the same flow as TestDischargeAllManyDischarges except that
+ # we're using actual third party caveats as added by
+ # Macaroon.add_caveat and we use a larger number of caveats
+ # so that caveat ids will need to get larger.
+ locator = macaroonbakery.ThirdPartyStore()
+ bakeries = {}
+ total_discharges_required = 40
+
+ class M:
+ bakery_id = 0
+ still_required = total_discharges_required
+
+ def add_bakery():
+ M.bakery_id += 1
+ loc = 'loc{}'.format(M.bakery_id)
+ bakeries[loc] = common.new_bakery(loc, locator)
+ return loc
+
+ ts = common.new_bakery('ts-loc', locator)
+
+ def checker(_, ci):
+ caveats = []
+ if ci.condition != 'something':
+ self.fail('unexpected condition')
+ for i in range(0, 2):
+ if M.still_required <= 0:
+ break
+ caveats.append(checkers.Caveat(location=add_bakery(),
+ condition='something'))
+ M.still_required -= 1
+ return caveats
+
+ root_key = b'root key'
+ m0 = macaroonbakery.Macaroon(
+ root_key=root_key, id=b'id0', location='ts-loc',
+ version=macaroonbakery.LATEST_BAKERY_VERSION)
+
+ m0.add_caveat(checkers. Caveat(location=add_bakery(),
+ condition='something'),
+ ts.oven.key, locator)
+
+ # We've added a caveat (the first) so one less caveat is required.
+ M.still_required -= 1
+
+ class ThirdPartyCaveatCheckerF(macaroonbakery.ThirdPartyCaveatChecker):
+ def check_third_party_caveat(self, ctx, info):
+ return checker(ctx, info)
+
+ def get_discharge(ctx, cav, payload):
+ return macaroonbakery.discharge(
+ ctx, cav.caveat_id, payload,
+ bakeries[cav.location].oven.key,
+ ThirdPartyCaveatCheckerF(), locator)
+
+ ms = macaroonbakery.discharge_all(common.test_context, m0,
+ get_discharge)
+
+ self.assertEqual(len(ms), total_discharges_required + 1)
+
+ v = Verifier()
+ v.satisfy_general(always_ok)
+ v.verify(ms[0], root_key, ms[1:])
+
+ def test_discharge_all_local_discharge(self):
+ oc = common.new_bakery('ts', None)
+ client_key = macaroonbakery.generate_key()
+ m = oc.oven.macaroon(macaroonbakery.LATEST_BAKERY_VERSION, common.ages,
+ [
+ macaroonbakery.local_third_party_caveat(
+ client_key.public_key,
+ macaroonbakery.LATEST_BAKERY_VERSION)
+ ], [macaroonbakery.LOGIN_OP])
+ ms = macaroonbakery.discharge_all(
+ common.test_context, m, no_discharge(self), client_key)
+ oc.checker.auth([ms]).allow(common.test_context,
+ [macaroonbakery.LOGIN_OP])
+
+ def test_discharge_all_local_discharge_version1(self):
+ oc = common.new_bakery('ts', None)
+ client_key = macaroonbakery.generate_key()
+ m = oc.oven.macaroon(macaroonbakery.BAKERY_V1, common.ages, [
+ macaroonbakery.local_third_party_caveat(
+ client_key.public_key, macaroonbakery.BAKERY_V1)
+ ], [macaroonbakery.LOGIN_OP])
+ ms = macaroonbakery.discharge_all(
+ common.test_context, m, no_discharge(self), client_key)
+ oc.checker.auth([ms]).allow(common.test_context,
+ [macaroonbakery.LOGIN_OP])
+
+
+def no_discharge(test):
+ def get_discharge(ctx, cav, payload):
+ test.fail("get_discharge called unexpectedly")
+
+ return get_discharge