Skip to content

Commit

Permalink
@Inject decorator (#23)
Browse files Browse the repository at this point in the history
  • Loading branch information
kmagusiak authored Dec 4, 2023
1 parent af440bd commit a93f0e1
Show file tree
Hide file tree
Showing 10 changed files with 250 additions and 51 deletions.
36 changes: 29 additions & 7 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -66,10 +66,9 @@ Then configuration is built from:
- `PYTHON_ALPHACONF` environment variable may contain a path to load
- configuration files from configuration directories (using application name)
- environment variables based on key prefixes,
except "BASE" and "PYTHON";
except "BASE" and "PYTHON"; \
if you have a configuration key "abc", all environment variables starting
with "ABC_" will be loaded where keys are converted to lower case and "_"
to ".": "ABC_HELLO=a" would set "abc.hello=a"
with "ABC_" will be loaded, for example "ABC_HELLO=a" would set "abc.hello=a"
- key-values from the program arguments

Finally, the configuration is fully resolved and logging is configured.
Expand Down Expand Up @@ -104,10 +103,11 @@ class MyConf(pydantic.BaseModel):

def build(self):
# use as a factory pattern to create more complex objects
# for example, a connection to the database
return self.value * 2

# setup the configuration
alphaconf.setup_configuration(MyConf, path='a')
alphaconf.setup_configuration(MyConf, prefix='a')
# read the value
alphaconf.get('a', MyConf)
v = alphaconf.get(MyConf) # because it's registered as a type
Expand All @@ -122,6 +122,29 @@ You can read values or passwords from files, by using the template
or, more securely, read the file in the code
`alphaconf.get('secret_file', Path).read_text().strip()`.

### Inject parameters

We can inject default values to functions from the configuration.
Either one by one, where we can map a factory function or a configuration key.
Or inject all automatically base on the parameter name.

```python
from alphaconf.inject import inject, inject_auto

@inject('name', 'application.name')
@inject_auto(ignore={'name'})
def main(name: str, example=None):
pass

# similar to
def main(name: str=None, example=None):
if name is None:
name = alphaconf.get('application.name', str)
if example is None:
example = alphaconf.get('example', default=example)
...
```

### Invoke integration

Just add the lines below to parameterize invoke.
Expand All @@ -136,9 +159,8 @@ alphaconf.invoke.run(__name__, ns)
```

## Way to 1.0
- Run function `@alphaconf.inject`
- Run a specific function `alphaconf.cli.run_module()`:
find functions and parse their args
- Run a specific function `alphaconf my.module.main`:
find functions and inject args
- Install completions for bash `alphaconf --install-autocompletion`

[OmegaConf]: https://omegaconf.readthedocs.io/
Expand Down
92 changes: 92 additions & 0 deletions alphaconf/inject.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
import functools
import inspect
from typing import Any, Callable, Dict, Optional, Union

import alphaconf

from .internal.type_resolvers import type_from_annotation

__all__ = ["inject", "inject_auto"]


class ParamDefaultsFunction:
"""Function wrapper that injects default parameters"""

_arg_factory: Dict[str, Callable[[], Any]]

def __init__(self, func: Callable):
self.func = func
self.signature = inspect.signature(func)
self._arg_factory = {}

def bind(self, name: str, factory: Callable[[], Any]):
self._arg_factory[name] = factory

def __call__(self, *a, **kw):
args = self.signature.bind_partial(*a, **kw).arguments
kw.update(
{name: factory() for name, factory in self._arg_factory.items() if name not in args}
)
return self.func(*a, **kw)

@staticmethod
def wrap(func) -> "ParamDefaultsFunction":
if isinstance(func, ParamDefaultsFunction):
return func
return functools.wraps(func)(ParamDefaultsFunction(func))


def getter(
key: str, ktype: Optional[type] = None, *, param: Optional[inspect.Parameter] = None
) -> Callable[[], Any]:
"""Factory function that calls alphaconf.get
The parameter from the signature can be given to extract the type to cast to
and whether the configuration value is optional.
:param key: The key using in alphaconf.get
:param ktype: Type to cast to
:param param: The parameter object from the signature
"""
if ktype is None and param and (ptype := param.annotation) is not param.empty:
ktype = next(type_from_annotation(ptype), None)
if param is not None and param.default is not param.empty:
xparam = param
return (
lambda: xparam.default
if (value := alphaconf.get(key, ktype, default=None)) is None
and xparam.default is not xparam.empty
else value
)
return lambda: alphaconf.get(key, ktype)


def inject(name: str, factory: Union[None, str, Callable[[], Any]]):
"""Inject an argument to a function from a factory or alphaconf"""

def do_inject(func):
f = ParamDefaultsFunction.wrap(func)
if isinstance(factory, str) or factory is None:
b = getter(factory or name, param=f.signature.parameters[name])
else:
b = factory
f.bind(name, b)
return f

return do_inject


def inject_auto(*, prefix: str = "", ignore: set = set()):
"""Inject automatically all paramters"""
if prefix and not prefix.endswith("."):
prefix += "."

def do_inject(func):
f = ParamDefaultsFunction.wrap(func)
for name, param in f.signature.parameters.items():
if name in ignore:
continue
f.bind(name, getter(prefix + name, param=param))
return f

return do_inject
49 changes: 23 additions & 26 deletions alphaconf/internal/configuration.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import copy
import os
import typing
import warnings
from enum import Enum
from typing import (
Expand All @@ -19,7 +18,7 @@

from omegaconf import Container, DictConfig, OmegaConf

from .type_resolvers import convert_to_type, pydantic
from .type_resolvers import convert_to_type, pydantic, type_from_annotation

T = TypeVar('T')

Expand Down Expand Up @@ -92,14 +91,14 @@ def get(self, key: Union[str, Type], type=None, *, default=raise_on_missing):
)
if value is raise_on_missing:
if default is raise_on_missing:
raise ValueError(f"No value for: {key}")
raise KeyError(f"No value for: {key}")
return default
# check the returned type and convert when necessary
if type is not None and isinstance(value, type):
return value
if isinstance(value, Container):
value = OmegaConf.to_object(value)
if type is not None and default is not None:
if type is not None and value is not default:
value = convert_to_type(value, type)
return value

Expand All @@ -110,12 +109,12 @@ def __get_type(self, key: Type, *, default=raise_on_missing):
key_str = self.__type_path.get(key)
if key_str is None:
if default is raise_on_missing:
raise ValueError(f"Key not found for type {key}")
raise KeyError(f"Key not found for type {key}")
return default
try:
value = self.get(key_str, key)
self.__type_value = value
except ValueError:
except KeyError:
if default is raise_on_missing:
raise
value = default
Expand All @@ -130,7 +129,7 @@ def setup_configuration(
conf: Union[DictConfig, dict, Any],
helpers: Dict[str, str] = {},
*,
path: str = "",
prefix: str = "",
):
"""Add a default configuration
Expand All @@ -146,27 +145,27 @@ def setup_configuration(
conf_type = None
if conf_type:
# if already registered, set path to None
self.__type_path[conf_type] = None if conf_type in self.__type_path else path
self.__type_path[conf_type] = None if conf_type in self.__type_path else prefix
self.__type_value.pop(conf_type, None)
if path and not path.endswith('.'):
path += "."
if prefix and not prefix.endswith('.'):
prefix += "."
if isinstance(conf, str):
warnings.warn("provide a dict directly", DeprecationWarning)
created_config = OmegaConf.create(conf)
if not isinstance(created_config, DictConfig):
raise ValueError("The config is not a dict")
conf = created_config
if isinstance(conf, DictConfig):
config = self.__prepare_dictconfig(conf, path=path)
config = self.__prepare_dictconfig(conf, path=prefix)
else:
created_config = self.__prepare_config(conf, path=path)
created_config = self.__prepare_config(conf, path=prefix)
if not isinstance(created_config, DictConfig):
raise ValueError("Failed to convert to a DictConfig")
raise TypeError("Failed to convert to a DictConfig")
config = created_config
# add path and merge
if path:
config = self.__add_path(config, path.rstrip("."))
helpers = {path + k: v for k, v in helpers.items()}
# add prefix and merge
if prefix:
config = self.__add_prefix(config, prefix.rstrip("."))
helpers = {prefix + k: v for k, v in helpers.items()}
self._merge([config])
# helpers
self.helpers.update(**helpers)
Expand Down Expand Up @@ -221,12 +220,12 @@ def __prepare_dictconfig(
sub_configs = []
for k, v in obj.items_ex(resolve=False):
if not isinstance(k, str):
raise ValueError("Expecting only str instances in dict")
raise TypeError("Expecting only str instances in dict")
if recursive:
v = self.__prepare_config(v, path + k + ".")
if '.' in k:
obj.pop(k)
sub_configs.append(self.__add_path(v, k))
sub_configs.append(self.__add_prefix(v, k))
if sub_configs:
obj = cast(DictConfig, OmegaConf.unsafe_merge(obj, *sub_configs))
return obj
Expand All @@ -252,9 +251,6 @@ def __prepare_pydantic(self, obj, path):
# pydantic instance, prepare helpers
self.__prepare_pydantic(type(obj), path)
return obj.model_dump(mode="json")
# parse typing recursively for documentation
for t in typing.get_args(obj):
self.__prepare_pydantic(t, path)
# check if not a type
if not isinstance(obj, type):
return obj
Expand All @@ -279,13 +275,14 @@ def __prepare_pydantic(self, obj, path):
from alphaconf import SECRET_MASKS

SECRET_MASKS.append(lambda s: s == path)
elif check_type and field.annotation:
self.__prepare_pydantic(field.annotation, path + k + ".")
elif check_type:
for ftype in type_from_annotation(field.annotation):
self.__prepare_pydantic(ftype, path + k + ".")
return defaults
return None

@staticmethod
def __add_path(config: Any, path: str) -> DictConfig:
for part in reversed(path.split(".")):
def __add_prefix(config: Any, prefix: str) -> DictConfig:
for part in reversed(prefix.split(".")):
config = OmegaConf.create({part: config})
return config
24 changes: 15 additions & 9 deletions alphaconf/internal/type_resolvers.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import datetime
import typing
from pathlib import Path

from omegaconf import OmegaConf
Expand All @@ -16,11 +17,7 @@
"""


def read_text(value):
return Path(value).expanduser().read_text()


def parse_bool(value) -> bool:
def _parse_bool(value) -> bool:
if isinstance(value, str):
value = value.strip().lower()
if value in ('no', 'false', 'n', 'f', 'off', 'none', 'null', 'undefined', '0'):
Expand All @@ -29,14 +26,14 @@ def parse_bool(value) -> bool:


TYPE_CONVERTER = {
bool: parse_bool,
bool: _parse_bool,
datetime.datetime: datetime.datetime.fromisoformat,
datetime.date: lambda s: datetime.datetime.strptime(s, '%Y-%m-%d').date(),
datetime.time: datetime.time.fromisoformat,
Path: lambda s: Path(s).expanduser(),
Path: lambda s: Path(str(s)).expanduser(),
str: lambda v: str(v),
'read_text': read_text,
'read_strip': lambda s: read_text(s).strip(),
'read_text': lambda s: Path(s).expanduser().read_text(),
'read_strip': lambda s: Path(s).expanduser().read_text().strip(),
'read_bytes': lambda s: Path(s).expanduser().read_bytes(),
}

Expand Down Expand Up @@ -65,3 +62,12 @@ def convert_to_type(value, type):
if pydantic:
return pydantic.TypeAdapter(type).validate_python(value)
return type(value)


def type_from_annotation(annotation) -> typing.Generator[type, None, None]:
"""Given an annotation (optional), figure out the types"""
if isinstance(annotation, type) and annotation is not type(None):
yield annotation
else:
for t in typing.get_args(annotation):
yield from type_from_annotation(t)
2 changes: 1 addition & 1 deletion example-typed.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ class MyConfiguration(BaseModel):
connection: Optional[Conn] = None


alphaconf.setup_configuration(MyConfiguration, path="c")
alphaconf.setup_configuration(MyConfiguration, prefix="c")


def main():
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,13 @@ classifiers = [
]
dependencies = [
"omegaconf>=2",
"pydantic>=2",
]

[project.optional-dependencies]
color = ["colorama"]
dotenv = ["python-dotenv"]
invoke = ["invoke"]
pydantic = ["pydantic>=2"]
toml = ["toml"]

[project.urls]
Expand Down
6 changes: 3 additions & 3 deletions tests/test_alphaconf.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def test_setup_configuration():


def test_setup_configuration_invalid():
with pytest.raises(ValueError):
with pytest.raises(TypeError):
# invalid configuration (must be non-empty)
alphaconf.setup_configuration(None)

Expand Down Expand Up @@ -132,8 +132,8 @@ def test_app_environ(application):
)
application.setup_configuration(load_dotenv=False, env_prefixes=True)
config = application.configuration
with pytest.raises(ValueError):
# XXX should not be loaded
with pytest.raises(KeyError):
# prefix with underscore only should be loaded
config.get('xxx')
assert config.get('testmyenv.x') == 'overwrite'
assert config.get('testmyenv.y') == 'new'
Loading

0 comments on commit a93f0e1

Please sign in to comment.