# Copyright 2017 Canonical Ltd.
# Licensed under the LGPLv3, see LICENCE file for details.
"""Tests for bakery.discharge_all."""
import unittest

import macaroonbakery.bakery as bakery
import macaroonbakery.checkers as checkers
from macaroonbakery.tests import common
from pymacaroons.verifier import Verifier


def always_ok(predicate):
    """General caveat satisfier that accepts every first-party predicate."""
    return True


class TestDischargeAll(unittest.TestCase):
    """Exercise bakery.discharge_all with no, many and local discharges."""

    def test_discharge_all_no_discharges(self):
        # A macaroon with no third-party caveats needs no discharges:
        # discharge_all must return just the primary macaroon and must
        # never call the discharge callback.
        root_key = b'root key'
        m = bakery.Macaroon(
            root_key=root_key, id=b'id0', location='loc0',
            version=bakery.LATEST_VERSION,
            namespace=common.test_checker().namespace())
        ms = bakery.discharge_all(m, no_discharge(self))
        self.assertEqual(len(ms), 1)
        self.assertEqual(ms[0], m.macaroon)
        v = Verifier()
        v.satisfy_general(always_ok)
        v.verify(m.macaroon, root_key, None)

    def test_discharge_all_many_discharges(self):
        # The primary macaroon and every discharge macaroon each get at most
        # one raw third-party caveat, so the 40 required discharges form a
        # chain: each discharge is requested for a caveat on the previous one.
        root_key = b'root key'
        m0 = bakery.Macaroon(
            root_key=root_key, id=b'id0', location='loc0',
            version=bakery.LATEST_VERSION)

        class State(object):
            # Mutable counters shared with the closures below (a class is
            # used rather than `nonlocal` to stay Python 2 compatible).
            total_required = 40
            id = 1

        def add_caveats(m):
            # Add one third-party caveat to m, unless all required caveats
            # have already been created.
            if State.total_required == 0:
                return
            cid = 'id{}'.format(State.id)
            # The caveat root key is derived from the caveat id so that
            # get_discharge can recreate it from the caveat alone.
            m.macaroon.add_third_party_caveat(
                location='somewhere',
                key='root key {}'.format(cid).encode('utf-8'),
                key_id=cid.encode('utf-8'))
            State.id += 1
            State.total_required -= 1

        add_caveats(m0)

        def get_discharge(cav, payload):
            # Raw caveats added via add_third_party_caveat carry no payload.
            self.assertIsNone(payload)
            m = bakery.Macaroon(
                root_key='root key {}'.format(
                    cav.caveat_id.decode('utf-8')).encode('utf-8'),
                id=cav.caveat_id, location='',
                version=bakery.LATEST_VERSION)
            add_caveats(m)
            return m

        ms = bakery.discharge_all(m0, get_discharge)
        # The primary macaroon followed by its 40 discharges.
        self.assertEqual(len(ms), 41)
        v = Verifier()
        v.satisfy_general(always_ok)
        v.verify(ms[0], root_key, ms[1:])

    def test_discharge_all_many_discharges_with_real_third_party_caveats(self):
        # This is the same flow as test_discharge_all_many_discharges except
        # that we use actual third-party caveats: the first one is added by
        # Macaroon.add_caveat and the rest are returned by the third-party
        # caveat checker. Each caveat is addressed to its own freshly
        # created bakery.
        locator = bakery.ThirdPartyStore()
        bakeries = {}
        total_discharges_required = 40

        class M(object):
            # Mutable counters shared with the closures below.
            bakery_id = 0
            still_required = total_discharges_required

        def add_bakery():
            # Create a new bakery at a unique location, register it with the
            # locator and remember it so get_discharge can find its key.
            M.bakery_id += 1
            loc = 'loc{}'.format(M.bakery_id)
            bakeries[loc] = common.new_bakery(loc, locator)
            return loc

        ts = common.new_bakery('ts-loc', locator)

        def checker(ctx, ci):
            # Accept the caveat and request up to two further third-party
            # caveats (each for a new bakery) until the required total of
            # caveats has been created.
            caveats = []
            if ci.condition != 'something':
                self.fail('unexpected condition')
            for _ in range(2):
                if M.still_required <= 0:
                    break
                caveats.append(checkers.Caveat(location=add_bakery(),
                                               condition='something'))
                M.still_required -= 1
            return caveats

        root_key = b'root key'
        m0 = bakery.Macaroon(
            root_key=root_key, id=b'id0', location='ts-loc',
            version=bakery.LATEST_VERSION)
        m0.add_caveat(checkers.Caveat(location=add_bakery(),
                                      condition='something'),
                      ts.oven.key, locator)
        # We've added a caveat (the first) so one less caveat is required.
        M.still_required -= 1

        class ThirdPartyCaveatCheckerF(bakery.ThirdPartyCaveatChecker):
            # Adapter that delegates caveat checking to `checker` above.
            def check_third_party_caveat(self, ctx, info):
                return checker(ctx, info)

        def get_discharge(cav, payload):
            # Discharge using the key of the bakery the caveat addresses.
            return bakery.discharge(
                common.test_context, cav.caveat_id, payload,
                bakeries[cav.location].oven.key,
                ThirdPartyCaveatCheckerF(), locator)

        ms = bakery.discharge_all(m0, get_discharge)
        self.assertEqual(len(ms), total_discharges_required + 1)
        v = Verifier()
        v.satisfy_general(always_ok)
        v.verify(ms[0], root_key, ms[1:])

    def _check_local_discharge(self, version):
        """Check that a local third-party caveat is discharged client-side.

        The macaroon's only third-party caveat is a local one for
        ``client_key``. discharge_all is given ``client_key`` and a callback
        that fails the test if called, so the discharge must happen locally.
        The result must then authorize LOGIN_OP.

        :param version: the bakery macaroon version to mint with.
        """
        oc = common.new_bakery('ts', None)
        client_key = bakery.generate_key()
        m = oc.oven.macaroon(version, common.ages, [
            bakery.local_third_party_caveat(client_key.public_key, version)
        ], [bakery.LOGIN_OP])
        ms = bakery.discharge_all(m, no_discharge(self), client_key)
        oc.checker.auth([ms]).allow(common.test_context, [bakery.LOGIN_OP])

    def test_discharge_all_local_discharge(self):
        self._check_local_discharge(bakery.LATEST_VERSION)

    def test_discharge_all_local_discharge_version1(self):
        self._check_local_discharge(bakery.VERSION_1)


def no_discharge(test):
    """Return a discharge callback that fails *test* if it is ever called.

    :param test: the unittest.TestCase whose ``fail`` method is invoked.
    """
    def get_discharge(cav, payload):
        test.fail("get_discharge called unexpectedly")

    return get_discharge