From ba2dd9265661a21265c58560150b88fcdd371b94 Mon Sep 17 00:00:00 2001 From: rf_tar_railt <3165388245@qq.com> Date: Mon, 26 Feb 2024 18:24:12 +0800 Subject: [PATCH] :sparkles: add context query in arparma --- exam7.py | 2 ++ example/exec_sql.py | 2 ++ pyproject.toml | 8 ++++- src/arclet/alconna/_internal/_header.py | 46 ++++++++++++------------- src/arclet/alconna/arparma.py | 28 ++++++++++----- src/arclet/alconna/manager.py | 4 +-- src/arclet/alconna/typing.py | 8 ++--- 7 files changed, 59 insertions(+), 39 deletions(-) diff --git a/exam7.py b/exam7.py index ba903d8a..946362c2 100644 --- a/exam7.py +++ b/exam7.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from dataclasses import dataclass, field from typing import Literal, Optional, overload from typing_extensions import Self diff --git a/example/exec_sql.py b/example/exec_sql.py index 1f49ddc8..38c96f49 100644 --- a/example/exec_sql.py +++ b/example/exec_sql.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from sqlite3 import connect from typing import Optional diff --git a/pyproject.toml b/pyproject.toml index 32c90f2e..73a05681 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -132,4 +132,10 @@ extend-exclude = ''' profile = "black" line_length = 120 skip_gitignore = true -extra_standard_library = ["typing_extensions"] \ No newline at end of file +extra_standard_library = ["typing_extensions"] + +[tool.pyright] +pythonVersion = "3.8" +pythonPlatform = "All" +typeCheckingMode = "basic" +disableBytesTypePromotions = true diff --git a/src/arclet/alconna/_internal/_header.py b/src/arclet/alconna/_internal/_header.py index 534089b1..7190b3da 100644 --- a/src/arclet/alconna/_internal/_header.py +++ b/src/arclet/alconna/_internal/_header.py @@ -59,35 +59,35 @@ class Pair: __slots__ = ("prefix", "pattern", "is_prefix_pat", "gd_supplier", "_match") + def _match1(self, command: str, pbfn: Callable[..., Any], comp: bool): + if command == self.pattern: + return command, None + if comp and command.startswith(self.pattern): + pbfn(command[len(self.pattern):], replace=True) + return self.pattern, None + return None, None + + def _match2(self, command: str, pbfn: Callable[..., Any], comp: bool): + if mat := self.pattern.fullmatch(command): + return command, mat + if comp and (mat := self.pattern.match(command)): + pbfn(command[len(mat[0]):], replace=True) + return mat[0], mat + return None, None + def __init__(self, prefix: Any, pattern: TPattern | str): self.prefix = prefix self.pattern = pattern self.is_prefix_pat = isinstance(self.prefix, BasePattern) if isinstance(self.pattern, str): self.gd_supplier = lambda mat: None - - def _match(command: str, pbfn: Callable[..., ...], comp: bool): - if command == self.pattern: - return command, None - if comp and command.startswith(self.pattern): - pbfn(command[len(self.pattern):], replace=True) - return self.pattern, None - return None, None - + self._match = self._match1 else: self.gd_supplier = lambda mat: mat.groupdict() + self._match = self._match2 - def _match(command: str, pbfn: Callable[..., ...], comp: bool): - if mat := self.pattern.fullmatch(command): - return command, mat - if comp and (mat := self.pattern.match(command)): - pbfn(command[len(mat[0]):], replace=True) - return mat[0], mat - return None, None - - self._match = _match - def match(self, _pf: Any, command: str, pbfn: Callable[..., ...], comp: bool): + def match(self, _pf: Any, command: str, pbfn: Callable[..., Any], comp: bool): cmd, mat = self._match(command, pbfn, comp) if cmd is None: return @@ -159,7 +159,7 @@ def __repr__(self): prefixes.append(pf) return f"[{'│'.join(prefixes)}]{cmd}" - def match0(self, pf: Any, cmd: Any, p_str: bool, c_str: bool, pbfn: Callable[..., ...], comp: bool): + def match0(self, pf: Any, cmd: Any, p_str: bool, c_str: bool, pbfn: Callable[..., Any], comp: bool): if self.prefix and p_str and pf in self.prefix: if (val := self.command.validate(cmd)).success: return (pf, cmd), (pf, val._value), True, None @@ -177,7 +177,7 @@ def match0(self, pf: Any, cmd: Any, p_str: bool, c_str: bool, pbfn: Callable[... return (pf, cmd), (val._value, cmd[:len(str(val2._value))]), True, None return - def match1(self, pf: Any, cmd: Any, p_str: bool, c_str: bool, pbfn: Callable[..., ...], comp: bool): + def match1(self, pf: Any, cmd: Any, p_str: bool, c_str: bool, pbfn: Callable[..., Any], comp: bool): if p_str or not c_str: return if (val := self.patterns.validate(pf)).success and (mat := self.command.fullmatch(cmd)): @@ -186,7 +186,7 @@ def match1(self, pf: Any, cmd: Any, p_str: bool, c_str: bool, pbfn: Callable[... pbfn(cmd[len(mat[0]):], replace=True) return (pf, cmd), (pf, mat[0]), True, mat.groupdict() - def match2(self, pf: Any, cmd: Any, p_str: bool, c_str: bool, pbfn: Callable[..., ...], comp: bool): + def match2(self, pf: Any, cmd: Any, p_str: bool, c_str: bool, pbfn: Callable[..., Any], comp: bool): if not p_str and not c_str: return if p_str: @@ -212,7 +212,7 @@ def match2(self, pf: Any, cmd: Any, p_str: bool, c_str: bool, pbfn: Callable[... pbfn(cmd[len(mat[0]):], replace=True) return (pf, cmd), (val._value, mat[0]), True, mat.groupdict() - def match(self, pf: Any, cmd: Any, p_str: bool, c_str: bool, pbfn: Callable[..., ...], comp: bool): + def match(self, pf: Any, cmd: Any, p_str: bool, c_str: bool, pbfn: Callable[..., Any], comp: bool) -> Any: ... diff --git a/src/arclet/alconna/arparma.py b/src/arclet/alconna/arparma.py index 0540ed4f..3341b78e 100644 --- a/src/arclet/alconna/arparma.py +++ b/src/arclet/alconna/arparma.py @@ -8,7 +8,7 @@ from typing import Any, Callable, ClassVar, Generic, TypeVar, cast, overload from typing_extensions import Self -from tarina import Empty, generic_isinstance, lang +from tarina import Empty, generic_isinstance, lang, safe_eval from .exceptions import BehaveCancelled, OutBoundsBehave from .model import HeadResult, OptionResult, SubcommandResult @@ -88,9 +88,9 @@ def __call__(self, path: str, default: D | None = None) -> T | D | None: source, endpoint = self.source.__require__(path.split(".")) if source is None: return default - if isinstance(source, (OptionResult, SubcommandResult)): - return getattr(source, endpoint, default) if endpoint else source # type: ignore - return source.get(endpoint, default) if endpoint else MappingProxyType(source) # type: ignore + if isinstance(source, dict): + return source.get(endpoint, default) if endpoint else MappingProxyType(source) # type: ignore + return getattr(source, endpoint, default) if endpoint else source # type: ignore class Arparma(Generic[TDC]): @@ -266,6 +266,7 @@ def call(self, target: Callable[..., T]) -> T: data = { **{k: v() for k, v in self._additional.items()}, **self.all_matched_args, + "context": self.context, "all_args": self.all_matched_args, "options": self.options, "subcommands": self.subcommands, @@ -285,7 +286,10 @@ def call(self, target: Callable[..., T]) -> T: kw_args[p.name] = data[p.name] bind = sig.bind(*pos_args, **kw_args) bind.apply_defaults() - return target(*bind.args, **bind.kwargs) + try: + return target(*bind.args, **bind.kwargs) + finally: + data.clear() def fail(self, exc: type[Exception] | Exception) -> Self: """生成一个失败的 `Arparma`""" @@ -295,11 +299,11 @@ def __require__(self, parts: list[str]) -> tuple[dict[str, Any] | OptionResult | """如果能够返回, 除开基本信息, 一定返回该path所在的dict""" if len(parts) == 1: part = parts[0] - for src in (self.main_args, self.other_args, self.options, self.subcommands): + if part in {"options", "subcommands", "main_args", "other_args", "context"}: + return getattr(self, part, {}), "" + for src in (self.main_args, self.other_args, self.options, self.subcommands, self.context): if part in src: return src, part - if part in {"options", "subcommands", "main_args", "other_args"}: - return getattr(self, part, {}), "" return (self.all_matched_args, "") if part == "args" else (None, part) prefix = parts.pop(0) # parts[0] if prefix in {"options", "subcommands"} and prefix in self.components: @@ -311,7 +315,13 @@ def __require__(self, parts: list[str]) -> tuple[dict[str, Any] | OptionResult | prefix = prefix.replace("$main", "main_args").replace("$other", "other_args") if prefix in {"main_args", "other_args"}: return getattr(self, prefix, {}), parts.pop(0) - return None, prefix + path = ".".join([prefix] + parts) + if path in self.context: + return self.context, path + try: + return safe_eval(path, self.context), "" # type: ignore + except Exception: + return None, prefix def query_with(self, arg_type: type[T], *args): return self.query[arg_type](*args) diff --git a/src/arclet/alconna/manager.py b/src/arclet/alconna/manager.py index eb910d61..069031d5 100644 --- a/src/arclet/alconna/manager.py +++ b/src/arclet/alconna/manager.py @@ -8,7 +8,7 @@ import weakref from copy import copy from datetime import datetime -from typing import TYPE_CHECKING, Any, Match, Union +from typing import TYPE_CHECKING, Any, Match, Union, MutableSet from weakref import WeakKeyDictionary, WeakValueDictionary from tarina import LRU, lang @@ -152,7 +152,7 @@ def require(self, command: Alconna[TDC]) -> Analyser[TDC]: namespace, name = self._command_part(command.path) raise ValueError(lang.require("manager", "undefined_command").format(target=f"{namespace}.{name}")) from e - def unpack(self, commands: set[Alconna]) -> zip[tuple[Analyser, Argv]]: # type: ignore + def unpack(self, commands: MutableSet[Alconna]) -> "zip[tuple[Analyser, Argv]]": """获取多个命令解析器""" return zip( [v for k, v in self.__analysers.items() if k in commands], diff --git a/src/arclet/alconna/typing.py b/src/arclet/alconna/typing.py index c4a716a7..5d236e4d 100644 --- a/src/arclet/alconna/typing.py +++ b/src/arclet/alconna/typing.py @@ -136,18 +136,18 @@ def __calc_eq__(self, other): # pragma: no cover class _ContextPattern(BasePattern[Any, Tuple[str, Dict[str, Any]]]): def __init__(self, *names: str, style: Literal["bracket", "parentheses"] = "parentheses"): if not names: - pat = "$(VAR)" if style == "parentheses" else "${VAR}" + pat = "$(VAR)" if style == "parentheses" else "{VAR}" super().__init__(mode=MatchMode.TYPE_CONVERT, origin=Any, alias=pat) self.pattern = pat self.regex_pattern = re.compile(r"\$\((.+)\)", re.DOTALL) if style == "parentheses" else re.compile(r"\{(.+)\}", re.DOTALL) else: - pat = f"$({'|'.join(names)})" if style == "parentheses" else f"${{{'|'.join(names)}}}" + pat = f"$({'|'.join(names)})" if style == "parentheses" else f"{{{'|'.join(names)}}}" super().__init__(mode=MatchMode.TYPE_CONVERT, origin=Any, alias=pat) self.pattern = pat if style == "parentheses": - self.regex_pattern = re.compile(rf"\$\(({'|'.join(map(re.escape, names))})\)") + self.regex_pattern = re.compile(rf"\$\(({'|'.join(re.escape(name) + '.*' for name in names)})\)") else: - self.regex_pattern = re.compile(rf"\{{({'|'.join(map(re.escape, names))})\}}") + self.regex_pattern = re.compile(rf"\{{({'|'.join(re.escape(name) + '.*' for name in names)})\}}") def match(self, input_: Tuple[str, Dict[str, Any]]) -> Any: pat, ctx = input_