Skip to content

Commit

Permalink
✨ add context for parse
Browse files Browse the repository at this point in the history
  • Loading branch information
RF-Tar-Railt committed Feb 25, 2024
1 parent 3dc2c87 commit e9d1c60
Show file tree
Hide file tree
Showing 8 changed files with 90 additions and 42 deletions.
11 changes: 8 additions & 3 deletions src/arclet/alconna/_internal/_analyser.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from ..model import HeadResult, OptionResult, Sentence, SubcommandResult
from ..output import output_manager
from ..typing import TDC, InnerShortcutArgs
from ..constraint import SHORTCUT_TRIGGER, SHORTCUT_ARGS, SHORTCUT_REST, SHORTCUT_REGEX_MATCH
from ._handlers import (
_handle_shortcut_data,
_handle_shortcut_reg,
Expand Down Expand Up @@ -181,7 +182,7 @@ def process(self, argv: Argv[TDC]) -> Self:
ParamsUnmatched: 名称不匹配
FuzzyMatchSuccess: 模糊匹配成功
"""
sub = argv.context = self.command
sub = argv.current_node = self.command
name, _ = argv.next(sub.separators)
if name != sub.name: # 先匹配节点名称
if argv.fuzzy_match and levenshtein(name, sub.name) >= argv.fuzzy_threshold:
Expand Down Expand Up @@ -332,6 +333,10 @@ def process(self, argv: Argv[TDC]) -> Arparma[TDC]:
raise e from exc
return self.export(argv, True, e)
else:
argv.context[SHORTCUT_TRIGGER] = _next
argv.context[SHORTCUT_ARGS] = short
argv.context[SHORTCUT_REST] = rest
argv.context[SHORTCUT_REGEX_MATCH] = mat
self.reset()
argv.reset()
return self.shortcut(argv, rest, short, mat)
Expand Down Expand Up @@ -384,7 +389,7 @@ def analyse(self, argv: Argv[TDC]) -> Arparma[TDC] | None:
return _SPECIAL[handler](self, argv)
if comp_ctx.get(None):
if isinstance(e1, InvalidParam):
argv.free(argv.context.separators if argv.context else None)
argv.free(argv.current_node.separators if argv.current_node else None)
raise PauseTriggered(prompt(self, argv), e1, argv) from e1
if self.command.meta.raise_exception:
raise
Expand All @@ -406,7 +411,7 @@ def export(
fail (bool, optional): 是否解析失败. Defaults to False.
exception (Exception | None, optional): 解析失败时的异常. Defaults to None.
"""
result = Arparma(self.command.path, argv.origin, not fail, self.header_result)
result = Arparma(self.command.path, argv.origin, not fail, self.header_result, ctx=argv.exit())
if fail:
result.error_info = exception
result.error_data = argv.release()
Expand Down
25 changes: 22 additions & 3 deletions src/arclet/alconna/_internal/_argv.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from dataclasses import dataclass, field
from dataclasses import dataclass, field, fields
from typing import Any, Callable, ClassVar, Generic, Iterable
from typing_extensions import Self

Expand All @@ -11,6 +11,7 @@
from ..config import Namespace, config
from ..exceptions import NullMessage
from ..typing import TDC
from ..constraint import ARGV_OVERRIDES


@dataclass(repr=True)
Expand Down Expand Up @@ -41,7 +42,7 @@ class Argv(Generic[TDC]):
param_ids: set[str] = field(default_factory=set)
"""节点名集合"""

context: Arg | Subcommand | Option | None = field(init=False)
current_node: Arg | Subcommand | Option | None = field(init=False)
"""当前节点"""
current_index: int = field(init=False)
"""当前数据的索引"""
Expand All @@ -55,6 +56,7 @@ class Argv(Generic[TDC]):
"""命令的token"""
origin: TDC = field(init=False)
"""原始命令"""
context: dict[str, Any] = field(init=False)
_sep: tuple[str, ...] | None = field(init=False)

_cache: ClassVar[dict[type, dict[str, Any]]] = {}
Expand Down Expand Up @@ -84,7 +86,7 @@ def reset(self):
self.token = 0
self.origin = "None" # type: ignore
self._sep = None
self.context = None
self.current_node = None

@staticmethod
def generate_token(data: list) -> int:
Expand Down Expand Up @@ -254,3 +256,20 @@ def data_set(self):
def data_reset(self, data: list[str | Any], index: int):
self.raw_data = data
self.current_index = index

def enter(self, ctx: dict[str, Any] | None = None) -> Self:
"""进入上下文"""
if ctx and ARGV_OVERRIDES in ctx:
field_names = [f.name for f in fields(self)]
for k, v in ctx[ARGV_OVERRIDES].items():
if k in field_names:
setattr(self, k, v)
self.context = {} if ctx is None else ctx
return self

def exit(self) -> dict[str, Any]:
"""退出上下文"""
try:
return self.context
finally:
self.context = {}
20 changes: 10 additions & 10 deletions src/arclet/alconna/_internal/_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def _validate(argv: Argv, target: Arg[Any], value: BasePattern[Any, Any], result

def step_varpos(argv: Argv, args: Args, slot: tuple[MultiVar, Arg], result: dict[str, Any]):
value, arg = slot
argv.context = arg
argv.current_node = arg
key = arg.name
default_val = arg.field.default
_result = []
Expand Down Expand Up @@ -87,7 +87,7 @@ def step_varpos(argv: Argv, args: Args, slot: tuple[MultiVar, Arg], result: dict

def step_varkey(argv: Argv, slot: tuple[MultiKeyWordVar, Arg], result: dict[str, Any]):
value, arg = slot
argv.context = arg
argv.current_node = arg
name = arg.name
default_val = arg.field.default
_result = {}
Expand Down Expand Up @@ -190,7 +190,7 @@ def analyse_args(argv: Argv, args: Args) -> dict[str, Any]:
"""
result = {}
for arg in args.argument.normal:
argv.context = arg
argv.current_node = arg
may_arg, _str = argv.next(arg.separators)
if _str and may_arg in argv.special:
if argv.special[may_arg] not in argv.namespace.disable_builtin_options:
Expand Down Expand Up @@ -230,7 +230,7 @@ def analyse_args(argv: Argv, args: Args) -> dict[str, Any]:
step_keyword(argv, args, result)
for slot in args.argument.vars_keyword:
step_varkey(argv, slot, result)
argv.context = None
argv.current_node = None
return result


Expand All @@ -242,7 +242,7 @@ def handle_option(argv: Argv, opt: Option) -> tuple[str, OptionResult]:
argv (Argv): 命令行参数
opt (Option): 目标 `Option`
"""
argv.context = opt
argv.current_node = opt
_cnt = 0
error = True
name, _ = argv.next(opt.separators)
Expand Down Expand Up @@ -340,7 +340,7 @@ def analyse_compact_params(analyser: SubAnalyser, argv: Argv):
_data.clear()
return True
except InvalidParam as e:
if argv.context.__class__ is Arg:
if argv.current_node.__class__ is Arg:
raise e
argv.data_reset(_data, _index)

Expand Down Expand Up @@ -376,14 +376,14 @@ def analyse_param(analyser: SubAnalyser, argv: Argv, seps: tuple[str, ...] | Non
elif analyser.compact_params and (res := analyse_compact_params(analyser, argv)):
if res.__class__ is str:
raise InvalidParam(res)
argv.context = None
argv.current_node = None
return True
else:
_param = None
if not _param and analyser.command.nargs and not analyser.args_result:
analyser.args_result = analyse_args(argv, analyser.self_args)
if analyser.args_result:
argv.context = None
argv.current_node = None
return True
if _param.__class__ is Sentence:
analyser.sentences.append(argv.next()[0])
Expand Down Expand Up @@ -432,7 +432,7 @@ def analyse_param(analyser: SubAnalyser, argv: Argv, seps: tuple[str, ...] | Non
else:
return False
analyser.sentences.clear()
argv.context = None
argv.current_node = None
return True


Expand Down Expand Up @@ -710,7 +710,7 @@ def _prompt_none(analyser: Analyser, argv: Argv, got: list[str]):

def prompt(analyser: Analyser, argv: Argv, trigger: str | None = None):
"""获取补全列表"""
_trigger = trigger or argv.context
_trigger = trigger or argv.current_node
got = [*analyser.options_result.keys(), *analyser.subcommands_result.keys(), *analyser.sentences]
if isinstance(_trigger, Arg):
return _prompt_unit(analyser, argv, _trigger)
Expand Down
9 changes: 9 additions & 0 deletions src/arclet/alconna/arparma.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,7 @@ def __init__(
main_args: dict[str, Any] | None = None,
options: dict[str, OptionResult] | None = None,
subcommands: dict[str, SubcommandResult] | None = None,
ctx: dict[str, Any] | None = None,
):
"""初始化 `Arparma`
Args:
Expand All @@ -135,6 +136,7 @@ def __init__(
main_args (dict[str, Any] | None, optional): 主参数匹配结果
options (dict[str, OptionResult] | None, optional): 选项匹配结果
subcommands (dict[str, SubcommandResult] | None, optional): 子命令匹配结果
ctx (dict[str, Any] | None, optional): 上下文
"""
self.source = source
self.origin = origin
Expand All @@ -146,11 +148,18 @@ def __init__(
self.other_args = {}
self.options = options or {}
self.subcommands = subcommands or {}
self.context = ctx or {}

_additional: ClassVar[dict[str, Callable[[], Any]]] = {}
query = _Query[Any]()

def _clr(self):
self.context.clear()
self.error_data.clear()
self.main_args.clear()
self.other_args.clear()
self.options.clear()
self.subcommands.clear()
ks = list(self.__dict__.keys())
for k in ks:
delattr(self, k)
Expand Down
38 changes: 21 additions & 17 deletions src/arclet/alconna/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@ class CommandNode:

name: str
"""命令节点名称"""
aliases: frozenset[str]
"""命令节点别名"""
dest: str
"""命令节点目标名称"""
default: Any
Expand All @@ -68,6 +70,7 @@ def __init__(
self,
name: str,
args: Arg | Args | None = None,
alias: Iterable[str] | None = None,
dest: str | None = None,
default: Any = Empty,
action: Action | None = None,
Expand All @@ -88,12 +91,21 @@ def __init__(
help_text (str | None, optional): 命令帮助信息
requires (str | list[str] | tuple[str, ...] | set[str] | None, optional): 命令节点需求前缀
"""
if not name:
raise InvalidArgs(lang.require("common", "name_empty"))
_parts = name.split(" ")
self.name = _parts[-1]
aliases = list(alias or [])
parts = name.split(" ")
_name = parts[-1]
if "|" in _name:
_aliases = _name.split("|")
_aliases.sort(key=len, reverse=True)
_name = _aliases[0]
aliases.extend(_aliases[1:])
if not _name:
raise InvalidArgs(lang.require("common", "name_empty"))
aliases.insert(0, _name)
self.name = _name
self.aliases = frozenset(aliases)
self.requires = ([requires] if isinstance(requires, str) else list(requires)) if requires else []
self.requires.extend(_parts[:-1])
self.requires.extend(parts[:-1])
self.args = Args() + args
self.default = default
self.action = action or store
Expand Down Expand Up @@ -184,21 +196,12 @@ def __init__(
compact (bool, optional): 是否允许名称与后随参数之间无分隔符
priority (int, optional): 命令选项优先级
"""
aliases = list(alias or [])
_name = name.split(" ")[-1]
if "|" in _name:
_aliases = _name.split("|")
_aliases.sort(key=len, reverse=True)
name = name.replace(_name, _aliases[0])
_name = _aliases[0]
aliases.extend(_aliases[1:])
aliases.insert(0, _name)
self.aliases = frozenset(aliases)

self.priority = priority
self.compact = compact
if default is not Empty:
default = default if isinstance(default, OptionResult) else OptionResult(default)
super().__init__(name, args, dest, default, action, separators, help_text, requires)
super().__init__(name, args, alias, dest, default, action, separators, help_text, requires)
if self.separators == ("",):
self.compact = True
self.separators = (" ",)
Expand Down Expand Up @@ -266,6 +269,7 @@ def __init__(
self,
name: str,
*args: Args | Arg | Option | Subcommand | list[Option | Subcommand],
alias: Iterable[str] | None = None,
dest: str | None = None,
default: Any = Empty,
separators: str | Sequence[str] | set[str] | None = None,
Expand Down Expand Up @@ -293,7 +297,7 @@ def __init__(
super().__init__(
name,
reduce(lambda x, y: x + y, [Args()] + [i for i in args if isinstance(i, (Arg, Args))]), # type: ignore
dest, default, None, separators, help_text, requires,
alias, dest, default, None, separators, help_text, requires,
)

def __add__(self, other: Option | Args | Arg | str) -> Self:
Expand Down
7 changes: 7 additions & 0 deletions src/arclet/alconna/constraint.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from typing import Literal

ARGV_OVERRIDES: Literal["$argv.overrides"] = "$argv.overrides"
SHORTCUT_TRIGGER: Literal["$shortcut.trigger"] = "$shortcut.trigger"
SHORTCUT_REST: Literal["$shortcut.rest"] = "$shortcut.rest"
SHORTCUT_ARGS: Literal["$shortcut.args"] = "$shortcut.args"
SHORTCUT_REGEX_MATCH: Literal["$shortcut.regex_match"] = "$shortcut.regex_match"
17 changes: 9 additions & 8 deletions src/arclet/alconna/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,41 +346,42 @@ def subcommand(self, sub: Subcommand) -> Self:
"""添加子命令"""
return self.add(sub)

def _parse(self, message: TDC) -> Arparma[TDC]:
def _parse(self, message: TDC, ctx: dict[str, Any] | None = None) -> Arparma[TDC]:
if self._union:
for ana, argv in command_manager.unpack(self.union):
if (res := ana.process(argv.build(message))).matched:
if (res := ana.process(argv.enter(ctx).build(message))).matched:
return res
analyser = command_manager.require(self)
argv = command_manager.resolve(self)
argv.build(message)
argv.enter(ctx).build(message)
return analyser.process(argv)

@overload
def parse(self, message: TDC) -> Arparma[TDC]:
def parse(self, message: TDC, ctx: dict[str, Any] | None = None) -> Arparma[TDC]:
...

@overload
def parse(self, message, *, duplication: type[T_Duplication]) -> T_Duplication:
def parse(self, message, ctx: dict[str, Any] | None = None, *, duplication: type[T_Duplication]) -> T_Duplication:
...

def parse(self, message: TDC, *, duplication: type[T_Duplication] | None = None) -> Arparma[TDC] | T_Duplication:
def parse(self, message: TDC, ctx: dict[str, Any] | None = None, *, duplication: type[T_Duplication] | None = None) -> Arparma[TDC] | T_Duplication:
"""命令分析功能, 传入字符串或消息链, 返回一个特定的数据集合类
Args:
message (TDC): 命令消息
ctx (dict[str, Any], optional): 上下文信息
duplication (type[T_Duplication], optional): 指定的`副本`类型
Returns:
Arparma[TDC] | T_Duplication: 若`duplication`参数为`None`则返回`Arparma`对象, 否则返回`duplication`类型的对象
Raises:
NullMessage: 传入的消息为空时抛出
"""
try:
arp = self._parse(message)
arp = self._parse(message, ctx)
except NullMessage as e:
if self.meta.raise_exception:
raise e
return Arparma(self.path, message, False, error_info=e)
return Arparma(self.path, message, False, error_info=e, ctx=ctx)
if arp.matched:
arp = arp.execute(self.behaviors)
if self._executors:
Expand Down
5 changes: 4 additions & 1 deletion src/arclet/alconna/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,9 @@ class ShortcutArgs(TypedDict):
"""快捷指令的正则匹配结果的额外处理函数"""


DEFAULT_WRAPPER = lambda slot, content: content


class InnerShortcutArgs:
command: DataCollection[Any]
args: list[Any]
Expand All @@ -51,7 +54,7 @@ def __init__(
self.args = args or []
self.fuzzy = fuzzy
self.prefix = prefix
self.wrapper = wrapper or (lambda slot, content: content)
self.wrapper = wrapper or DEFAULT_WRAPPER

def __repr__(self):
return f"ShortcutArgs({self.command!r}, args={self.args!r}, fuzzy={self.fuzzy}, prefix={self.prefix})"
Expand Down

0 comments on commit e9d1c60

Please sign in to comment.