From fe08be698d36d42a66839ce284989947220931cd Mon Sep 17 00:00:00 2001 From: Daniel Nephin Date: Thu, 10 Mar 2016 17:33:01 -0500 Subject: Support inline default values. Signed-off-by: Daniel Nephin --- compose/config/config.py | 22 +++++++----- compose/config/interpolation.py | 74 +++++++++++++++++++++++++++++++---------- 2 files changed, 70 insertions(+), 26 deletions(-) (limited to 'compose') diff --git a/compose/config/config.py b/compose/config/config.py index aea1e094..4d32b50c 100644 --- a/compose/config/config.py +++ b/compose/config/config.py @@ -413,31 +413,35 @@ def load_services(config_details, config_file): return build_services(service_config) -def interpolate_config_section(filename, config, section, environment): - validate_config_section(filename, config, section) - return interpolate_environment_variables(config, section, environment) +def interpolate_config_section(config_file, config, section, environment): + validate_config_section(config_file.filename, config, section) + return interpolate_environment_variables( + config_file.version, + config, + section, + environment) def process_config_file(config_file, environment, service_name=None): services = interpolate_config_section( - config_file.filename, + config_file, config_file.get_service_dicts(), 'service', - environment,) + environment) if config_file.version in (V2_0, V2_1): processed_config = dict(config_file.config) processed_config['services'] = services processed_config['volumes'] = interpolate_config_section( - config_file.filename, + config_file, config_file.get_volumes(), 'volume', - environment,) + environment) processed_config['networks'] = interpolate_config_section( - config_file.filename, + config_file, config_file.get_networks(), 'network', - environment,) + environment) if config_file.version == V1: processed_config = services diff --git a/compose/config/interpolation.py b/compose/config/interpolation.py index 63020d91..cb841437 100644 --- a/compose/config/interpolation.py +++ b/compose/config/interpolation.py @@ -7,14 +7,35 @@ from string import Template import six from .errors import ConfigurationError +from compose.const import COMPOSEFILE_V1 as V1 +from compose.const import COMPOSEFILE_V2_0 as V2_0 + + log = logging.getLogger(__name__) -def interpolate_environment_variables(config, section, environment): +class Interpolator(object): + + def __init__(self, templater, mapping): + self.templater = templater + self.mapping = mapping + + def interpolate(self, string): + try: + return self.templater(string).substitute(self.mapping) + except ValueError: + raise InvalidInterpolation(string) + + +def interpolate_environment_variables(version, config, section, environment): + if version in (V2_0, V1): + interpolator = Interpolator(Template, environment) + else: + interpolator = Interpolator(TemplateWithDefaults, environment) def process_item(name, config_dict): return dict( - (key, interpolate_value(name, key, val, section, environment)) + (key, interpolate_value(name, key, val, section, interpolator)) for key, val in (config_dict or {}).items() ) @@ -24,9 +45,9 @@ def interpolate_environment_variables(config, section, environment): ) -def interpolate_value(name, config_key, value, section, mapping): +def interpolate_value(name, config_key, value, section, interpolator): try: - return recursive_interpolate(value, mapping) + return recursive_interpolate(value, interpolator) except InvalidInterpolation as e: raise ConfigurationError( 'Invalid interpolation format for "{config_key}" option ' @@ -37,25 +58,44 @@ def interpolate_value(name, config_key, value, section, mapping): string=e.string)) -def recursive_interpolate(obj, mapping): +def recursive_interpolate(obj, interpolator): if isinstance(obj, six.string_types): - return interpolate(obj, mapping) - elif isinstance(obj, dict): + return interpolator.interpolate(obj) + if isinstance(obj, dict): return dict( - (key, recursive_interpolate(val, mapping)) + (key, recursive_interpolate(val, interpolator)) for (key, val) in obj.items() ) - elif isinstance(obj, list): - return [recursive_interpolate(val, mapping) for val in obj] - else: - return obj + if isinstance(obj, list): + return [recursive_interpolate(val, interpolator) for val in obj] + return obj -def interpolate(string, mapping): - try: - return Template(string).substitute(mapping) - except ValueError: - raise InvalidInterpolation(string) +class TemplateWithDefaults(Template): + idpattern = r'[_a-z][_a-z0-9]*(?::?-[_a-z0-9]+)?' + + # Modified from python2.7/string.py + def substitute(self, mapping): + # Helper function for .sub() + def convert(mo): + # Check the most common path first. + named = mo.group('named') or mo.group('braced') + if named is not None: + if ':-' in named: + var, _, default = named.partition(':-') + return mapping.get(var) or default + if '-' in named: + var, _, default = named.partition('-') + return mapping.get(var, default) + val = mapping[named] + return '%s' % (val,) + if mo.group('escaped') is not None: + return self.delimiter + if mo.group('invalid') is not None: + self._invalid(mo) + raise ValueError('Unrecognized named group in pattern', + self.pattern) + return self.pattern.sub(convert, self.template) class InvalidInterpolation(Exception): -- cgit v1.2.3