diff --git a/src/arclet/alconna/_internal/_analyser.py b/src/arclet/alconna/_internal/_analyser.py index ac4e1ac4..f6d4fb9b 100644 --- a/src/arclet/alconna/_internal/_analyser.py +++ b/src/arclet/alconna/_internal/_analyser.py @@ -17,6 +17,7 @@ from ..model import HeadResult, OptionResult, Sentence, SubcommandResult from ..output import output_manager from ..typing import TDC, InnerShortcutArgs +from ..constraint import SHORTCUT_TRIGGER, SHORTCUT_ARGS, SHORTCUT_REST, SHORTCUT_REGEX_MATCH from ._handlers import ( _handle_shortcut_data, _handle_shortcut_reg, @@ -181,7 +182,7 @@ def process(self, argv: Argv[TDC]) -> Self: ParamsUnmatched: 名称不匹配 FuzzyMatchSuccess: 模糊匹配成功 """ - sub = argv.context = self.command + sub = argv.current_node = self.command name, _ = argv.next(sub.separators) if name != sub.name: # 先匹配节点名称 if argv.fuzzy_match and levenshtein(name, sub.name) >= argv.fuzzy_threshold: @@ -332,6 +333,10 @@ def process(self, argv: Argv[TDC]) -> Arparma[TDC]: raise e from exc return self.export(argv, True, e) else: + argv.context[SHORTCUT_TRIGGER] = _next + argv.context[SHORTCUT_ARGS] = short + argv.context[SHORTCUT_REST] = rest + argv.context[SHORTCUT_REGEX_MATCH] = mat self.reset() argv.reset() return self.shortcut(argv, rest, short, mat) @@ -384,7 +389,7 @@ def analyse(self, argv: Argv[TDC]) -> Arparma[TDC] | None: return _SPECIAL[handler](self, argv) if comp_ctx.get(None): if isinstance(e1, InvalidParam): - argv.free(argv.context.separators if argv.context else None) + argv.free(argv.current_node.separators if argv.current_node else None) raise PauseTriggered(prompt(self, argv), e1, argv) from e1 if self.command.meta.raise_exception: raise @@ -406,7 +411,7 @@ def export( fail (bool, optional): 是否解析失败. Defaults to False. exception (Exception | None, optional): 解析失败时的异常. Defaults to None. """ - result = Arparma(self.command.path, argv.origin, not fail, self.header_result) + result = Arparma(self.command.path, argv.origin, not fail, self.header_result, ctx=argv.exit()) if fail: result.error_info = exception result.error_data = argv.release() diff --git a/src/arclet/alconna/_internal/_argv.py b/src/arclet/alconna/_internal/_argv.py index 0db02aca..ff4742bf 100644 --- a/src/arclet/alconna/_internal/_argv.py +++ b/src/arclet/alconna/_internal/_argv.py @@ -1,6 +1,6 @@ from __future__ import annotations -from dataclasses import dataclass, field +from dataclasses import dataclass, field, fields from typing import Any, Callable, ClassVar, Generic, Iterable from typing_extensions import Self @@ -11,6 +11,7 @@ from ..config import Namespace, config from ..exceptions import NullMessage from ..typing import TDC +from ..constraint import ARGV_OVERRIDES @dataclass(repr=True) @@ -41,7 +42,7 @@ class Argv(Generic[TDC]): param_ids: set[str] = field(default_factory=set) """节点名集合""" - context: Arg | Subcommand | Option | None = field(init=False) + current_node: Arg | Subcommand | Option | None = field(init=False) """当前节点""" current_index: int = field(init=False) """当前数据的索引""" @@ -55,6 +56,7 @@ class Argv(Generic[TDC]): """命令的token""" origin: TDC = field(init=False) """原始命令""" + context: dict[str, Any] = field(init=False) _sep: tuple[str, ...] | None = field(init=False) _cache: ClassVar[dict[type, dict[str, Any]]] = {} @@ -84,7 +86,7 @@ def reset(self): self.token = 0 self.origin = "None" # type: ignore self._sep = None - self.context = None + self.current_node = None @staticmethod def generate_token(data: list) -> int: @@ -254,3 +256,20 @@ def data_set(self): def data_reset(self, data: list[str | Any], index: int): self.raw_data = data self.current_index = index + + def enter(self, ctx: dict[str, Any] | None = None) -> Self: + """进入上下文""" + if ctx and ARGV_OVERRIDES in ctx: + field_names = [f.name for f in fields(self)] + for k, v in ctx[ARGV_OVERRIDES].items(): + if k in field_names: + setattr(self, k, v) + self.context = {} if ctx is None else ctx + return self + + def exit(self) -> dict[str, Any]: + """退出上下文""" + try: + return self.context + finally: + self.context = {} diff --git a/src/arclet/alconna/_internal/_handlers.py b/src/arclet/alconna/_internal/_handlers.py index 89bbc949..33818e76 100644 --- a/src/arclet/alconna/_internal/_handlers.py +++ b/src/arclet/alconna/_internal/_handlers.py @@ -46,7 +46,7 @@ def _validate(argv: Argv, target: Arg[Any], value: BasePattern[Any, Any], result def step_varpos(argv: Argv, args: Args, slot: tuple[MultiVar, Arg], result: dict[str, Any]): value, arg = slot - argv.context = arg + argv.current_node = arg key = arg.name default_val = arg.field.default _result = [] @@ -87,7 +87,7 @@ def step_varpos(argv: Argv, args: Args, slot: tuple[MultiVar, Arg], result: dict def step_varkey(argv: Argv, slot: tuple[MultiKeyWordVar, Arg], result: dict[str, Any]): value, arg = slot - argv.context = arg + argv.current_node = arg name = arg.name default_val = arg.field.default _result = {} @@ -190,7 +190,7 @@ def analyse_args(argv: Argv, args: Args) -> dict[str, Any]: """ result = {} for arg in args.argument.normal: - argv.context = arg + argv.current_node = arg may_arg, _str = argv.next(arg.separators) if _str and may_arg in argv.special: if argv.special[may_arg] not in argv.namespace.disable_builtin_options: @@ -230,7 +230,7 @@ def analyse_args(argv: Argv, args: Args) -> dict[str, Any]: step_keyword(argv, args, result) for slot in args.argument.vars_keyword: step_varkey(argv, slot, result) - argv.context = None + argv.current_node = None return result @@ -242,7 +242,7 @@ def handle_option(argv: Argv, opt: Option) -> tuple[str, OptionResult]: argv (Argv): 命令行参数 opt (Option): 目标 `Option` """ - argv.context = opt + argv.current_node = opt _cnt = 0 error = True name, _ = argv.next(opt.separators) @@ -340,7 +340,7 @@ def analyse_compact_params(analyser: SubAnalyser, argv: Argv): _data.clear() return True except InvalidParam as e: - if argv.context.__class__ is Arg: + if argv.current_node.__class__ is Arg: raise e argv.data_reset(_data, _index) @@ -376,14 +376,14 @@ def analyse_param(analyser: SubAnalyser, argv: Argv, seps: tuple[str, ...] | Non elif analyser.compact_params and (res := analyse_compact_params(analyser, argv)): if res.__class__ is str: raise InvalidParam(res) - argv.context = None + argv.current_node = None return True else: _param = None if not _param and analyser.command.nargs and not analyser.args_result: analyser.args_result = analyse_args(argv, analyser.self_args) if analyser.args_result: - argv.context = None + argv.current_node = None return True if _param.__class__ is Sentence: analyser.sentences.append(argv.next()[0]) @@ -432,7 +432,7 @@ def analyse_param(analyser: SubAnalyser, argv: Argv, seps: tuple[str, ...] | Non else: return False analyser.sentences.clear() - argv.context = None + argv.current_node = None return True @@ -710,7 +710,7 @@ def _prompt_none(analyser: Analyser, argv: Argv, got: list[str]): def prompt(analyser: Analyser, argv: Argv, trigger: str | None = None): """获取补全列表""" - _trigger = trigger or argv.context + _trigger = trigger or argv.current_node got = [*analyser.options_result.keys(), *analyser.subcommands_result.keys(), *analyser.sentences] if isinstance(_trigger, Arg): return _prompt_unit(analyser, argv, _trigger) diff --git a/src/arclet/alconna/arparma.py b/src/arclet/alconna/arparma.py index 7bbb8faa..0540ed4f 100644 --- a/src/arclet/alconna/arparma.py +++ b/src/arclet/alconna/arparma.py @@ -123,6 +123,7 @@ def __init__( main_args: dict[str, Any] | None = None, options: dict[str, OptionResult] | None = None, subcommands: dict[str, SubcommandResult] | None = None, + ctx: dict[str, Any] | None = None, ): """初始化 `Arparma` Args: @@ -135,6 +136,7 @@ def __init__( main_args (dict[str, Any] | None, optional): 主参数匹配结果 options (dict[str, OptionResult] | None, optional): 选项匹配结果 subcommands (dict[str, SubcommandResult] | None, optional): 子命令匹配结果 + ctx (dict[str, Any] | None, optional): 上下文 """ self.source = source self.origin = origin @@ -146,11 +148,18 @@ def __init__( self.other_args = {} self.options = options or {} self.subcommands = subcommands or {} + self.context = ctx or {} _additional: ClassVar[dict[str, Callable[[], Any]]] = {} query = _Query[Any]() def _clr(self): + self.context.clear() + self.error_data.clear() + self.main_args.clear() + self.other_args.clear() + self.options.clear() + self.subcommands.clear() ks = list(self.__dict__.keys()) for k in ks: delattr(self, k) diff --git a/src/arclet/alconna/base.py b/src/arclet/alconna/base.py index 22efaa8e..2b5c538b 100644 --- a/src/arclet/alconna/base.py +++ b/src/arclet/alconna/base.py @@ -49,6 +49,8 @@ class CommandNode: name: str """命令节点名称""" + aliases: frozenset[str] + """命令节点别名""" dest: str """命令节点目标名称""" default: Any @@ -68,6 +70,7 @@ def __init__( self, name: str, args: Arg | Args | None = None, + alias: Iterable[str] | None = None, dest: str | None = None, default: Any = Empty, action: Action | None = None, @@ -88,12 +91,21 @@ def __init__( help_text (str | None, optional): 命令帮助信息 requires (str | list[str] | tuple[str, ...] | set[str] | None, optional): 命令节点需求前缀 """ - if not name: - raise InvalidArgs(lang.require("common", "name_empty")) - _parts = name.split(" ") - self.name = _parts[-1] + aliases = list(alias or []) + parts = name.split(" ") + _name = parts[-1] + if "|" in _name: + _aliases = _name.split("|") + _aliases.sort(key=len, reverse=True) + _name = _aliases[0] + aliases.extend(_aliases[1:]) + if not _name: + raise InvalidArgs(lang.require("common", "name_empty")) + aliases.insert(0, _name) + self.name = _name + self.aliases = frozenset(aliases) self.requires = ([requires] if isinstance(requires, str) else list(requires)) if requires else [] - self.requires.extend(_parts[:-1]) + self.requires.extend(parts[:-1]) self.args = Args() + args self.default = default self.action = action or store @@ -184,21 +196,12 @@ def __init__( compact (bool, optional): 是否允许名称与后随参数之间无分隔符 priority (int, optional): 命令选项优先级 """ - aliases = list(alias or []) - _name = name.split(" ")[-1] - if "|" in _name: - _aliases = _name.split("|") - _aliases.sort(key=len, reverse=True) - name = name.replace(_name, _aliases[0]) - _name = _aliases[0] - aliases.extend(_aliases[1:]) - aliases.insert(0, _name) - self.aliases = frozenset(aliases) + self.priority = priority self.compact = compact if default is not Empty: default = default if isinstance(default, OptionResult) else OptionResult(default) - super().__init__(name, args, dest, default, action, separators, help_text, requires) + super().__init__(name, args, alias, dest, default, action, separators, help_text, requires) if self.separators == ("",): self.compact = True self.separators = (" ",) @@ -266,6 +269,7 @@ def __init__( self, name: str, *args: Args | Arg | Option | Subcommand | list[Option | Subcommand], + alias: Iterable[str] | None = None, dest: str | None = None, default: Any = Empty, separators: str | Sequence[str] | set[str] | None = None, @@ -293,7 +297,7 @@ def __init__( super().__init__( name, reduce(lambda x, y: x + y, [Args()] + [i for i in args if isinstance(i, (Arg, Args))]), # type: ignore - dest, default, None, separators, help_text, requires, + alias, dest, default, None, separators, help_text, requires, ) def __add__(self, other: Option | Args | Arg | str) -> Self: diff --git a/src/arclet/alconna/constraint.py b/src/arclet/alconna/constraint.py new file mode 100644 index 00000000..b2bbf9e8 --- /dev/null +++ b/src/arclet/alconna/constraint.py @@ -0,0 +1,7 @@ +from typing import Literal + +ARGV_OVERRIDES: Literal["$argv.overrides"] = "$argv.overrides" +SHORTCUT_TRIGGER: Literal["$shortcut.trigger"] = "$shortcut.trigger" +SHORTCUT_REST: Literal["$shortcut.rest"] = "$shortcut.rest" +SHORTCUT_ARGS: Literal["$shortcut.args"] = "$shortcut.args" +SHORTCUT_REGEX_MATCH: Literal["$shortcut.regex_match"] = "$shortcut.regex_match" diff --git a/src/arclet/alconna/core.py b/src/arclet/alconna/core.py index e8d4e8ff..7ec9b476 100644 --- a/src/arclet/alconna/core.py +++ b/src/arclet/alconna/core.py @@ -346,29 +346,30 @@ def subcommand(self, sub: Subcommand) -> Self: """添加子命令""" return self.add(sub) - def _parse(self, message: TDC) -> Arparma[TDC]: + def _parse(self, message: TDC, ctx: dict[str, Any] | None = None) -> Arparma[TDC]: if self._union: for ana, argv in command_manager.unpack(self.union): - if (res := ana.process(argv.build(message))).matched: + if (res := ana.process(argv.enter(ctx).build(message))).matched: return res analyser = command_manager.require(self) argv = command_manager.resolve(self) - argv.build(message) + argv.enter(ctx).build(message) return analyser.process(argv) @overload - def parse(self, message: TDC) -> Arparma[TDC]: + def parse(self, message: TDC, ctx: dict[str, Any] | None = None) -> Arparma[TDC]: ... @overload - def parse(self, message, *, duplication: type[T_Duplication]) -> T_Duplication: + def parse(self, message, ctx: dict[str, Any] | None = None, *, duplication: type[T_Duplication]) -> T_Duplication: ... - def parse(self, message: TDC, *, duplication: type[T_Duplication] | None = None) -> Arparma[TDC] | T_Duplication: + def parse(self, message: TDC, ctx: dict[str, Any] | None = None, *, duplication: type[T_Duplication] | None = None) -> Arparma[TDC] | T_Duplication: """命令分析功能, 传入字符串或消息链, 返回一个特定的数据集合类 Args: message (TDC): 命令消息 + ctx (dict[str, Any], optional): 上下文信息 duplication (type[T_Duplication], optional): 指定的`副本`类型 Returns: Arparma[TDC] | T_Duplication: 若`duplication`参数为`None`则返回`Arparma`对象, 否则返回`duplication`类型的对象 @@ -376,11 +377,11 @@ def parse(self, message: TDC, *, duplication: type[T_Duplication] | None = None) NullMessage: 传入的消息为空时抛出 """ try: - arp = self._parse(message) + arp = self._parse(message, ctx) except NullMessage as e: if self.meta.raise_exception: raise e - return Arparma(self.path, message, False, error_info=e) + return Arparma(self.path, message, False, error_info=e, ctx=ctx) if arp.matched: arp = arp.execute(self.behaviors) if self._executors: diff --git a/src/arclet/alconna/typing.py b/src/arclet/alconna/typing.py index bdec2159..78cbfb2e 100644 --- a/src/arclet/alconna/typing.py +++ b/src/arclet/alconna/typing.py @@ -30,6 +30,9 @@ class ShortcutArgs(TypedDict): """快捷指令的正则匹配结果的额外处理函数""" +DEFAULT_WRAPPER = lambda slot, content: content + + class InnerShortcutArgs: command: DataCollection[Any] args: list[Any] @@ -51,7 +54,7 @@ def __init__( self.args = args or [] self.fuzzy = fuzzy self.prefix = prefix - self.wrapper = wrapper or (lambda slot, content: content) + self.wrapper = wrapper or DEFAULT_WRAPPER def __repr__(self): return f"ShortcutArgs({self.command!r}, args={self.args!r}, fuzzy={self.fuzzy}, prefix={self.prefix})"