Skip to content

Commit

Permalink
fix: handle optional types (#35)
Browse files Browse the repository at this point in the history
We fix the handling of optional type to ensure that the type
wrapped by `Optional` undergoes the same handling as a
non-Optional one.
  • Loading branch information
P403n1x87 authored Jul 5, 2024
1 parent 7a7a988 commit 88d3bc2
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 28 deletions.
63 changes: 36 additions & 27 deletions envier/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,35 @@ def __init__(
self.help_type = help_type
self.help_default = help_default

def _cast(self, _type: t.Any, raw: str, env: "Env") -> t.Any:
if _type is bool:
return t.cast(T, raw.lower() in env.__truthy__)
elif _type in (list, tuple, set):
collection = raw.split(env.__item_separator__)
return t.cast(
T,
_type( # type: ignore[operator]
collection if self.map is None else map(self.map, collection) # type: ignore[arg-type]
),
)
elif _type is dict:
d = dict(
_.split(env.__value_separator__, 1)
for _ in raw.split(env.__item_separator__)
)
if self.map is not None:
d = dict(self.map(*_) for _ in d.items())
return t.cast(T, d)

if _check_type(raw, _type):
return t.cast(T, raw)

try:
return _type(raw)
except Exception as e:
msg = f"cannot cast {raw} to {self.type}"
raise TypeError(msg) from e

def _retrieve(self, env: "Env", prefix: str) -> T:
source = env.source

Expand Down Expand Up @@ -121,36 +150,14 @@ def _retrieve(self, env: "Env", prefix: str) -> T:
)
return parsed

if self.type is bool:
return t.cast(T, raw.lower() in env.__truthy__)
elif self.type in (list, tuple, set):
collection = raw.split(env.__item_separator__)
return t.cast(
T,
self.type( # type: ignore[operator]
collection if self.map is None else map(self.map, collection) # type: ignore[arg-type]
),
)
elif self.type is dict:
d = dict(
_.split(env.__value_separator__, 1)
for _ in raw.split(env.__item_separator__)
)
if self.map is not None:
d = dict(self.map(*_) for _ in d.items())
return t.cast(T, d)

if _check_type(raw, self.type):
return t.cast(T, raw)

if hasattr(self.type, "__origin__") and self.type.__origin__ is t.Union: # type: ignore[attr-defined,union-attr]
for ot in self.type.__args__: # type: ignore[attr-defined,union-attr]
try:
return t.cast(T, ot(raw))
return t.cast(T, self._cast(ot, raw, env))
except TypeError:
pass

return self.type(raw) # type: ignore[call-arg,operator]
return self._cast(self.type, raw, env)

def __call__(self, env: "Env", prefix: str) -> T:
value = self._retrieve(env, prefix)
Expand Down Expand Up @@ -436,9 +443,11 @@ def add_entries(full_prefix: str, config: t.Type[Env]) -> None:
(
f"``{private_prefix}{full_prefix}{_normalized(v.name)}``",
help_type, # type: ignore[attr-defined]
v.help_default
if v.help_default is not None
else str(v.default),
(
v.help_default
if v.help_default is not None
else str(v.default)
),
help_message,
)
)
Expand Down
7 changes: 6 additions & 1 deletion tests/test_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,11 +284,16 @@ class DictConfig(Env):
assert DictConfig().foo == expected


def test_env_optional_default():
def test_env_optional_default(monkeypatch):
class DictConfig(Env):
foo = Env.var(Optional[str], "foo", default=None)
bar = Env.var(Optional[bool], "bar", default=None)

assert DictConfig().foo is None
assert DictConfig().bar is None

monkeypatch.setenv("BAR", "0")
assert not DictConfig().bar


@pytest.mark.parametrize("value,_type", [(1, int), ("1", str)])
Expand Down

0 comments on commit 88d3bc2

Please sign in to comment.