Source code for oscar.flag

# pylint: disable=W0621
import base64
import sys

from six.moves import configparser


class _Required(object):

    def __str__(self):
        return '<required>'

REQUIRED = _Required()
UNSET = object()
INDENT = '    '


[docs]def default_usage(globalflags): '''Default for printing out usage. :type globalflags: GlobalFlagSet ''' sys.stderr.write('Usage of %s:\n' % sys.argv[0]) globalflags.write_flags(sys.stderr) sys.exit(0)
[docs]def default_usage_long(globalflags, return_code=0): '''Default for printing out long-form usage. :type globalflags: GlobalFlagSet ''' sys.stderr.write('Usage of %s:\n' % sys.argv[0]) globalflags.write_flags_long(sys.stderr) sys.exit(return_code)
[docs]class GlobalFlagSet(object): '''GlobalFlagSet is a collection of namespaces and flag logic.''' def __init__(self, usage=default_usage, usage_long=default_usage_long): '''Create a new GlobalFlagSet. :type usage: F(GlobalFlagSet) :type usage_long: G(GlobalFlagSet) ''' self.usage = usage self.usage_long = usage_long self.namespace_flags = dict() self.args = []
[docs] def namespace(self, namespace): '''Returns a :class:`NamespaceFlagSet` associated with ``namespace``. :type namespace: str :rtype: NamespaceFlagSet ''' return self.namespace_flags.setdefault(namespace, NamespaceFlagSet())
[docs] def get(self, namespace, flag): '''get the flag object associated with ``flag`` in ``namespace``. :type namespace: str :type flag: str :rtype: Var :raises KeyError: if the flag is not found ''' return self.namespace_flags[namespace]._flags[flag]
[docs] def find_short(self, flag): '''Find the namespace for a non-qualified ``flag``. If the ``flag`` is not found or multiple flags are found, :exc:`KeyError` is raised. :type flag: str :rtype: str :raises KeyError: if the flag is not found or ambiguous ''' matches = [] for namespace, flagset in self.namespace_flags.items(): if flag in flagset._flags: matches.append(namespace) if len(matches) > 1: raise KeyError('ambiguous flag \'%s\' in namespaces %s' % (flag, matches)) if len(matches) == 0: raise KeyError('%s' % flag) return matches[0]
[docs] def write_flags(self, out, namespace='__main__'): '''Prints the usage to ``out``. :type out: file :type namespace: str ''' for name, var in sorted(self.namespace(namespace)._flags.items()): try: _ = self.find_short(name) except KeyError: out.write(INDENT + '-%s.%s=%s: %s (%s)\n' % ( namespace, name, var.default, var.description, var.type_str)) else: out.write(INDENT + '[%s.]%s=%s: %s (%s)\n' % ( namespace, name, var.default, var.description, var.type_str)) if hasattr(var, 'long_description'): for line in var.long_description.splitlines(): out.write(INDENT + INDENT + line + '\n')
[docs] def write_flags_long(self, out): '''Prints all flag usage to ``out``. :type out: file ''' for namespace in sorted(self.namespace_flags): out.write('%s:\n' % namespace) self.write_flags(out, namespace) out.write('\n')
[docs] def visit(self, func): '''Walk all *set* flags, calling ``func`` on each. :type func: F(str, str, Var) :param func: visiting function ''' for namespace in sorted(self.namespace_flags): for name, flag in sorted(self.namespace(namespace)._flags.items()): if flag.is_set(): func(namespace, name, flag)
[docs] def visit_all(self, func): '''Walk all flags, calling ``func`` on each. :type func: F(str, str, Var) :param func: visiting function ''' for namespace in sorted(self.namespace_flags): for name, flag in sorted(self.namespace(namespace)._flags.items()): func(namespace, name, flag)
[docs] def check_required(self): '''Returns a list of ``(namespace, name, flag)`` of all unset, required flags. :rtype: list[tuple(str, str, Var)] ''' nonset = [] def _check(namespace, name, flag): if flag.default is REQUIRED and not flag.is_set(): nonset.append((namespace, name, flag)) self.visit_all(_check) return nonset
[docs] def parse_commandline(self, args): '''Parse a commandline into flags and arguments. Parse a commandline. a single '-' and a double '--' are treated as equivalent for denoting a flag. Flag values may be separated by '=' or be the next argument. Booleans are a special case that indicate a true value or must have their values separated by '='. :type args: list(str) :raises KeyError: on unknown flag :raises ParseException: on invalid command-line syntax ''' self.args = args while len(self.args): arg = self.args[0] if len(arg) == 0 or arg[0] != '-' or len(arg) == 1: break num_minuses = 1 if arg[1] == '-': num_minuses += 1 if len(arg) == 2: # -- terminates flag list (rest are args). self.args = self.args[1:] break name = arg[num_minuses:] if len(name) == 0 or name[0] == '-' or name[0] == '=': raise ParseException('bad flag syntax: %s' % arg) self.args = self.args[1:] has_value = False value = '' for i in range(1, len(name)): if name[i] == '=': value = name[i + 1:] has_value = True name = name[0:i] break if name == 'help' or name == 'h': self.usage(self) if name == 'helplong': self.usage_long(self) if '.' not in name: namespace = self.find_short(name) else: parts = name.split('.') namespace, name = '.'.join(parts[:-1]), parts[-1] flag = self.get(namespace, name) if isinstance(flag, Bool): if has_value: flag.set(value) else: flag.set('True') else: if not has_value and len(self.args) > 0: has_value = True value, self.args = self.args[0], self.args[1:] if not has_value: raise ParseException('flag needs an argument: %s' % name) flag.set(value)
[docs] def parse_environment(self, args): '''Parse environment variable tuples. Call with os.environ: >>> import os >>> flag.parse_environment(os.environ.items()) Recognizes `SECURED_SETTING_` prefixed flags, translates from base64 and maps to the short name. `secure` is also set to `True` on the underlying flag object, which should be respected by users of :meth:`visit` and :meth:`visit_all`. :type args: list[tuple(str, str)] ''' for name, value in args: # First treat SECURED_SETTING_ values specially. secure = False if name.startswith('SECURED_SETTING_'): name = name[len('SECURED_SETTING_'):] value = base64.b64decode(value) secure = True try: if '.' not in name: namespace = self.find_short(name) else: parts = name.split('.') namespace, name = '.'.join(parts[:-1]), parts[-1] flag = self.get(namespace, name) flag.set(value) flag.secure = flag.secure or secure except KeyError: # Ignore environment variables that don't map to a setting. pass
[docs] def parse_ini(self, file_p): '''Parse a :mod:`ConfigParser` compatible file object. Namespaces are section headers, keys are flags:: [__main__] foo=bar [foo.bar] baz=42 ``file_p`` only needs to implement ``readline(size=0)``. :type file_p: file ''' config = configparser.ConfigParser() # Need to set optionxform to `str` so that ConfigParser will be case # sensitive in its parsing of section headers and keys config.optionxform = str config.readfp(file_p) for section in config.sections(): for name, value in config.items(section): self.get(section, name).set(value)
[docs]class NamespaceFlagSet(object): '''Represents a set of flags in a namespace.''' def __init__(self): '''Create a new :class:`NamespaceFlagSet`.''' self.__dict__['_flags'] = dict() def __getattr__(self, name): '''Return the value for a flag. :type name: str :rtype: Var ''' return self._flags[name].get() def __setattr__(self, name, value): '''Set a flag accessor if none exists else return a flag value. :type name: str :type value: str or Var :raises FlagException: on flag redefinition or invalid definition ''' if name in self.__dict__: self.__dict__[name] = value return if name in self._flags: if isinstance(value, Var): raise FlagException('%s was already defined' % name) if value is UNSET: flag_obj = self._flags[name] try: del flag_obj.value except AttributeError: # No-op unsetting an unset flag. pass else: self._flags[name].set(value) else: if not isinstance(value, Var): raise FlagException('%s is not a flag.Var' % name) self._flags[name] = value def __dir__(self): '''Returns all flag attributes declared in the namespace. :rtype: list[str] ''' return list(self.__dict__.keys()) + list(self._flags.keys())
[docs]class FlagException(Exception): '''Error in flag initialization or access.'''
[docs]class ParseException(Exception): '''Error in command-line parsing.'''
[docs]class Var(object): '''Base of all flag accessors.''' value = UNSET type_str = 'Unknown' def __init__(self, description, default=None, secure=False): '''Create a new Var flag accessor. :type description: str :type default: None or _Required of object or T :type secure: bool ''' self.description = description self.default = default self.secure = secure
[docs] def get(self): '''Return the flag value, or default if it is not set.''' return self.value if self.is_set() else self.default
[docs] def is_set(self): ''' :rtype: bool ''' return self.value is not UNSET
[docs]class String(Var): '''String-valued flag.''' def __init__(self, description, default=None, secure=False): """ :rtype: str """ super(String, self).__init__(description, default, secure) type_str = 'String'
[docs] def set(self, value): self.value = value
[docs]class Int(Var): '''Integer-valued flag.''' def __init__(self, description, default=None, secure=False): """ :rtype: int """ super(Int, self).__init__(description, default, secure) type_str = 'Int'
[docs] def set(self, value): self.value = int(value)
[docs]class Float(Var): '''Float-valued flag.''' def __init__(self, description, default=None, secure=False): """ :rtype: float """ super(Float, self).__init__(description, default, secure) type_str = 'Float'
[docs] def set(self, value): self.value = float(value)
[docs]class Bool(Var): '''Boolean-valued flag.''' def __init__(self, description, default=None, secure=False): """ :rtype: bool """ super(Bool, self).__init__(description, default, secure) type_str = 'Bool'
[docs] def set(self, value): if value.lower() in ('t', 'true', 'yes', 'on', '1'): self.value = True elif value.lower() in ('f', 'false', 'no', 'off', '0'): self.value = False else: raise ValueError(value)
[docs]class List(Var): '''Flag that is a list of another flag type.''' type_str = 'List[_]' def __init__(self, inner_type, separator, description, default=None, secure=False): '''Create a list flag of `inner_type`. E.g.: >>> int_list = List(Int, ',', 'List of integers.') >>> int_list.set('1,2,3,4,5') >>> int_list.get() [1, 2, 3, 4, 5] :type inner_type: type :type separator: str :type description: str :type default: list or None or _Required :type secure: bool :rtype: list ''' default = [] if default is None else default super(List, self).__init__(description, default, secure) self.separator = separator self.inner_value = inner_type(description, default, secure) self.type_str = 'List[%s]' % self.inner_value.type_str
[docs] def set(self, value): ''' :type value: str ''' self.value = [] for part in value.split(self.separator): self.inner_value.set(part) self.value.append(self.inner_value.get())
# Default functions that use the default flagset. GLOBAL_FLAGS = GlobalFlagSet()
[docs]def namespace(name): '''Return a namespace from :const:`GLOBAL_FLAGS`. :type name: str :rtype: NamespaceFlagSet ''' return GLOBAL_FLAGS.namespace(name)
[docs]def parse_commandline(args): '''Parse commandline ``args`` with :const:`GLOBAL_FLAGS`. :type args: list[str] ''' GLOBAL_FLAGS.parse_commandline(args)
[docs]def parse_environment(args): '''Parse environment ``args`` with :const:`GLOBAL_FLAGS`. :type args: list[tuple(str, str)] ''' GLOBAL_FLAGS.parse_environment(args)
[docs]def parse_ini(file_p): '''Parse a :mod:`ConfigParser` compatible file with :const:`GLOBAL_FLAGS`. :type file_p: file ''' GLOBAL_FLAGS.parse_ini(file_p)
[docs]def args(): '''Return positional ``args`` from :const:`GLOBAL_FLAGS`. :rtype: list[str] ''' return GLOBAL_FLAGS.args
[docs]def die_on_missing_required(): '''If missing required flags, die and write usage.''' nonset = GLOBAL_FLAGS.check_required() if nonset: sys.stderr.write('Missing required flags:\n') for namespace, name, _ in nonset: sys.stderr.write(' [%s.]%s\n' % (namespace, name)) GLOBAL_FLAGS.usage_long(GLOBAL_FLAGS, return_code=1)