Skip to content

Commit

Permalink
🍻 Args compatibility in v1 dir
Browse files Browse the repository at this point in the history
  • Loading branch information
RF-Tar-Railt committed Oct 24, 2024
1 parent a6fd365 commit 78f6342
Show file tree
Hide file tree
Showing 13 changed files with 425 additions and 166 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ dev = [
]

[tool.pdm.scripts]
test = "pytest -v -W ignore --ignore entry_test.py --durations=0 -s"
test = "pytest -v --ignore entry_test.py --durations=0 -s"
benchmark = "python benchmark.py"
deps = "pydeps -o alconna.svg ./src/arclet/alconna --max-bacon=4 --cluster --keep-target-cluster --rmprefix alconna. "

Expand Down
64 changes: 39 additions & 25 deletions src/arclet/alconna/args.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from typing import Any, Callable, Generic, Literal, TypeVar, ClassVar, ForwardRef, Final, TYPE_CHECKING, get_origin, get_args
from typing_extensions import dataclass_transform, ParamSpec, Concatenate, TypeAlias

from nepattern import NONE, BasePattern, RawStr, UnionPattern, parser, STRING
from nepattern import NONE, BasePattern, RawStr, UnionPattern, parser
from tarina import Empty, lang

from ._dcls import safe_dcls_kw, safe_field_kw
Expand Down Expand Up @@ -147,7 +147,7 @@ def __init__(
setattr(self.field, k, v)

def __str__(self):
n, v = f"'{self.name}'", str(self.type_)
n, v = f"'{self.name_display}'", self.type_display
return (n if n == v else f"{n}: {v}") + (f" = '{self.field.display}'" if self.field.display is not Empty else "")

def __add__(self, other) -> "ArgsBuilder":
Expand All @@ -162,6 +162,33 @@ def __iter__(self):
def separators(self):
return self.field.seps

@property
def name_display(self):
n = self.name
if self.field.optional:
n = f"{n}?"
if self.field.notice:
n = f"{n}#{self.field.notice}"
return n

@property
def type_display(self):
if self.field.hidden:
return "***"
v = str(self.type_)
if self.field.kw_only:
v = f"{self.field.kw_sep}{v}"
if self.field.multiple is not False:
if self.field.multiple is True:
v = f"({v}+)"
elif self.field.multiple == "str":
v = f"{v}+"
elif isinstance(self.field.multiple, int):
v = f"({v}+)[:{self.field.multiple}]"
else:
v = f"({v}{self.field.multiple})"
return v


class _Args:
__slots__ = ("unpack", "vars_positional", "vars_keyword", "keyword_only", "normal", "data", "_visit", "optional_count", "origin")
Expand All @@ -172,7 +199,7 @@ def __init__(self, args: list[Arg[Any]], origin: type[ArgsBase] | None = None):
self.normal: list[Arg[Any]] = []
self.keyword_only: dict[str, Arg[Any]] = {}
self.vars_positional: list[tuple[int | Literal["+", "*", "str"], Arg[Any]]] = []
self.vars_keyword: list[tuple[str, Arg[Any]]] = []
self.vars_keyword: list[tuple[int | Literal["+", "*", "str"], Arg[Any]]] = []
self._visit = set()
self.optional_count = 0
self.__check_vars__()
Expand All @@ -194,18 +221,18 @@ def __check_vars__(self):
continue
self._visit.add(arg.name)
if arg.field.multiple is not False:
flag = arg.field.multiple
if flag is True:
flag = "+"
if arg.field.kw_only:
for slot in self.vars_positional:
_, a = slot
if arg.field.kw_sep in a.field.seps:
raise InvalidArgs("varkey cannot use the same sep as varpos's Arg")
self.vars_keyword.append((arg.field.kw_sep, arg))
self.vars_keyword.append((flag, arg))
elif self.keyword_only:
raise InvalidArgs(lang.require("args", "exclude_mutable_args"))
else:
flag = arg.field.multiple
if flag is True:
flag = "+"
self.vars_positional.append((flag, arg))
elif arg.field.kw_only:
if self.vars_keyword:
Expand Down Expand Up @@ -239,9 +266,10 @@ def __repr__(self):


_P = ParamSpec("_P")
_T1 = TypeVar("_T1", bound="ArgsBuilder")


def _arg_init_wrapper(func: Callable[_P, Field[_T]]) -> Callable[[ArgsBuilder, str], Callable[Concatenate[TAValue[_T], _P], ArgsBuilder]]:
def _arg_init_wrapper(func: Callable[_P, Field[_T]]) -> Callable[[_T1, str], Callable[Concatenate[TAValue[_T], _P], _T1]]:
return lambda builder, name: lambda type_, *args, **kwargs: builder.__lshift__(Arg(name, type_, func(*args, **kwargs)))


Expand All @@ -259,15 +287,6 @@ def __lshift__(self, arg: Arg):
def __getattr__(self, item: str):
return wrapper(self, item)

def __getitem__(self, item):
# warnings.warn("Args[...] is deprecated, use Args.xxx(...) instead", DeprecationWarning, stacklevel=2)
data: tuple[Arg, ...] | tuple[Any, ...] = item if isinstance(item, tuple) else (item,)
if isinstance(data[0], Arg):
self._args.extend(data)
else:
self._args.append(Arg(*data))
return self

def build(self):
return _Args(self._args)

Expand All @@ -284,13 +303,8 @@ class __ArgsBuilderInstance:
def __getattr__(self, item: str):
return ArgsBuilder().__getattr__(item)

def __getitem__(self, item):
# warnings.warn("Args[...] is deprecated, use Args.xxx(...) instead", DeprecationWarning, stacklevel=2)
data: tuple[Arg, ...] | tuple[Any, ...] = item if isinstance(item, tuple) else (item,)
if isinstance(data[0], Arg):
return ArgsBuilder(*data)
else:
return ArgsBuilder(Arg(*data))
def __lshift__(self, other):
return ArgsBuilder() << other


Args: Final = __ArgsBuilderInstance()
Expand Down Expand Up @@ -349,7 +363,7 @@ def __new__(
raise TypeError(f"{name!r} is a Field but has no type annotation")
cls.__args_data__ = _Args(data_args + cls_args, cls)

dcls = dc.make_dataclass(cls.__name__, [((arg.name, arg.type_, arg.field.to_dc_field()) ) for arg in cls.__args_data__.data], namespace=types_namespace, repr=True)
dcls = dc.make_dataclass(cls.__name__, [(arg.name, arg.type_, arg.field.to_dc_field()) for arg in cls.__args_data__.data], namespace=types_namespace, repr=True)
cls.__init__ = dcls.__init__ # type: ignore
if "__repr__" not in cls.__dict__:
cls.__repr__ = dcls.__repr__ # type: ignore
Expand Down
2 changes: 1 addition & 1 deletion src/arclet/alconna/formatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ def param(self, parameter: Arg) -> str:
return f"<...{name}>"
arg = f"[{name}" if parameter.field.optional else f"<{name}"
if parameter.type_ not in (ANY, AnyString):
arg += f": {parameter.type_}"
arg += f": {parameter.type_display}"
if parameter.field.display is not Empty:
arg += f" = {parameter.field.display}"
return f"{arg}]" if parameter.field.optional else f"{arg}>"
Expand Down
8 changes: 4 additions & 4 deletions src/arclet/alconna/ingedia/_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def step_varpos(argv: Argv, args: _Args, slot: tuple[int | Literal["+", "*", "st
if _str and kwonly_seps and split_once(pat.match(may_arg)["name"], kwonly_seps, argv.filter_crlf)[0] in args.keyword_only: # noqa: E501 # type: ignore
argv.rollback(may_arg)
break
if _str and args.vars_keyword and args.vars_keyword[0][0] in may_arg:
if _str and args.vars_keyword and args.vars_keyword[0][1].field.kw_sep in may_arg:
argv.rollback(may_arg)
break
if (res := value.validate(may_arg)).flag != "valid":
Expand All @@ -118,13 +118,13 @@ def step_varpos(argv: Argv, args: _Args, slot: tuple[int | Literal["+", "*", "st
result[key] = tuple(_result)


def step_varkey(argv: Argv, slot: tuple[str, Arg], result: dict[str, Any]):
kw_sep, arg = slot
flag = arg.field.multiple
def step_varkey(argv: Argv, slot: tuple[int | Literal["+", "*", "str"], Arg], result: dict[str, Any]):
flag, arg = slot
length = int(flag) if flag.__class__ is int else -1
value = arg.type_
name = arg.name
default_val = arg.field.default
kw_sep = arg.field.kw_sep
_result = {}
count = 0
while argv.current_index != argv.ndata:
Expand Down
2 changes: 1 addition & 1 deletion src/arclet/alconna/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
)
from typing_extensions import TypeAlias

from nepattern import BasePattern, MatchFailed, MatchMode, parser
from nepattern import BasePattern, MatchFailed, MatchMode
from tarina import generic_isinstance, lang


Expand Down
20 changes: 11 additions & 9 deletions src/arclet/alconna/v1/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,6 @@
from arclet.alconna.action import store_true as store_true # noqa: F401
from arclet.alconna.action import store_value as store_value # noqa: F401
from arclet.alconna.args import Arg as Arg # noqa: F401
# from arclet.alconna.args import ArgFlag as ArgFlag # noqa: F401
from arclet.alconna.args import Args as Args # noqa: F401
from arclet.alconna.args import Field as Field # noqa: F401
from arclet.alconna.ingedia._argv import Argv as Argv # noqa: F401
from arclet.alconna.ingedia._argv import argv_config as argv_config # noqa: F401
Expand All @@ -37,13 +35,9 @@
from arclet.alconna.manager import ShortcutArgs as ShortcutArgs # noqa: F401
from arclet.alconna.manager import command_manager as command_manager # noqa: F401
from arclet.alconna.typing import AllParam as AllParam # noqa: F401
# from arclet.alconna.typing import KeyWordVar as KeyWordVar # noqa: F401
# from arclet.alconna.typing import Kw as Kw # noqa: F401
# from arclet.alconna.typing import MultiVar as MultiVar # noqa: F401
# from arclet.alconna.typing import Nargs as Nargs # noqa: F401
# from arclet.alconna.typing import StrMulti as StrMulti # noqa: F401
# from arclet.alconna.typing import UnpackVar as UnpackVar # noqa: F401
# from arclet.alconna.typing import Up as Up # noqa: F401

from .args import ArgFlag as ArgFlag
from .args import Args as Args

from .compat import CommandMeta as CommandMeta
from .compat import Namespace as Namespace
Expand All @@ -54,3 +48,11 @@
from .stub import ArgsStub as ArgsStub
from .stub import OptionStub as OptionStub
from .stub import SubcommandStub as SubcommandStub

from .typing import KeyWordVar as KeyWordVar
from .typing import Kw as Kw
from .typing import MultiVar as MultiVar
from .typing import Nargs as Nargs
from .typing import StrMulti as StrMulti
from .typing import UnpackVar as UnpackVar
from .typing import Up as Up
118 changes: 118 additions & 0 deletions src/arclet/alconna/v1/args.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
from __future__ import annotations

import warnings
from enum import Enum
from typing import Any, Final, Iterable

from tarina import Empty
from typing_extensions import Self

from arclet.alconna.args import ArgsBuilder, Arg
from arclet.alconna.typing import TAValue

from .typing import KeyWordVar, MultiVar, _StrMulti, UnpackVar


class ArgFlag(str, Enum):
"""标识参数单元的特殊属性"""

OPTIONAL = "?"
HIDDEN = "/"
ANTI = "!"


class _CompatArgsBuilder(ArgsBuilder):
def __getitem__(self, item):
data: tuple[Arg, ...] | tuple[Any, ...] = item if isinstance(item, tuple) else (item,)
if isinstance(data[0], Arg):
self._args.extend(data)
else:
self._args.append(Arg(*data))
return self

def build(self):
for arg in self._args:
value = arg.type_
if isinstance(value, MultiVar):
if isinstance(value, _StrMulti):
arg.field.multiple = "str"
else:
arg.field.multiple = value.flag if value.length < 1 else value.length
arg.type_ = value.base
if isinstance(value.base, KeyWordVar):
arg.type_ = value.base.base
arg.field.kw_only = True
arg.field.kw_sep = value.base.sep
elif isinstance(value, KeyWordVar):
arg.field.kw_only = True
arg.field.kw_sep = value.sep
arg.type_ = value.base
elif isinstance(value, UnpackVar):
arg.type_ = value.of(value.origin)
return super().build()

def __truediv__(self, other) -> Self:
self.separate(*other if isinstance(other, (list, tuple, set)) else other)
return self

def separate(self, *separator: str) -> Self:
"""设置参数的分隔符
Args:
*separator (str): 分隔符
Returns:
Self: 参数集合自身
"""
for arg in self._args:
arg.field.seps = "".join(separator)
return self

def add(self, name: str, *, value: TAValue[Any], default: Any = Empty, flags: list[ArgFlag] | None = None) -> Self:
"""添加一个参数
Args:
name (str): 参数名称
value (TAValue): 参数值
default (Any, optional): 参数默认值.
flags (list[ArgFlag] | None, optional): 参数标记.
Returns:
Self: 参数集合自身
"""
if next(filter(lambda x: x.name == name, self._args), False):
return self
self._args.append(Arg(name, value, default))
return self


class __CompatArgsBuilderInstance:
__slots__ = ()

def __getattr__(self, item: str):
return _CompatArgsBuilder().__getattr__(item)

def __getitem__(self, item):
warnings.warn("Args[...] is deprecated, use Args.xxx(...) instead", DeprecationWarning, stacklevel=2)
data: tuple[Arg, ...] | tuple[Any, ...] = item if isinstance(item, tuple) else (item,)
if isinstance(data[0], Arg):
return _CompatArgsBuilder(*data)
else:
return _CompatArgsBuilder(Arg(*data))

def __call__(self, *args: Arg[Any], separators: str | Iterable[str] | None = None):
"""
构造一个 `Args`
Args:
*args (Arg): 参数单元
separators (str | Iterable[str] | None, optional): 可选的为所有参数单元指定分隔符
"""
if separators is not None:
seps = "".join(separators) if isinstance(separators, Iterable) else separators
for arg in args:
arg.field.seps = seps
return _CompatArgsBuilder(*args)


Args: Final = __CompatArgsBuilderInstance()
Loading

0 comments on commit 78f6342

Please sign in to comment.