Skip to content

Commit

Permalink
✨ add context query in arparma
Browse files Browse the repository at this point in the history
  • Loading branch information
RF-Tar-Railt committed Feb 26, 2024
1 parent aa4cc6a commit ba2dd92
Show file tree
Hide file tree
Showing 7 changed files with 59 additions and 39 deletions.
2 changes: 2 additions & 0 deletions exam7.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

from dataclasses import dataclass, field
from typing import Literal, Optional, overload
from typing_extensions import Self
Expand Down
2 changes: 2 additions & 0 deletions example/exec_sql.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

from sqlite3 import connect
from typing import Optional

Expand Down
8 changes: 7 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -132,4 +132,10 @@ extend-exclude = '''
profile = "black"
line_length = 120
skip_gitignore = true
extra_standard_library = ["typing_extensions"]
extra_standard_library = ["typing_extensions"]

[tool.pyright]
pythonVersion = "3.8"
pythonPlatform = "All"
typeCheckingMode = "basic"
disableBytesTypePromotions = true
46 changes: 23 additions & 23 deletions src/arclet/alconna/_internal/_header.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,35 +59,35 @@ class Pair:

__slots__ = ("prefix", "pattern", "is_prefix_pat", "gd_supplier", "_match")

def _match1(self, command: str, pbfn: Callable[..., Any], comp: bool):
if command == self.pattern:
return command, None
if comp and command.startswith(self.pattern):
pbfn(command[len(self.pattern):], replace=True)
return self.pattern, None
return None, None

def _match2(self, command: str, pbfn: Callable[..., Any], comp: bool):
if mat := self.pattern.fullmatch(command):
return command, mat
if comp and (mat := self.pattern.match(command)):
pbfn(command[len(mat[0]):], replace=True)
return mat[0], mat
return None, None

def __init__(self, prefix: Any, pattern: TPattern | str):
self.prefix = prefix
self.pattern = pattern
self.is_prefix_pat = isinstance(self.prefix, BasePattern)
if isinstance(self.pattern, str):
self.gd_supplier = lambda mat: None

def _match(command: str, pbfn: Callable[..., ...], comp: bool):
if command == self.pattern:
return command, None
if comp and command.startswith(self.pattern):
pbfn(command[len(self.pattern):], replace=True)
return self.pattern, None
return None, None

self._match = self._match1
else:
self.gd_supplier = lambda mat: mat.groupdict()
self._match = self._match2

def _match(command: str, pbfn: Callable[..., ...], comp: bool):
if mat := self.pattern.fullmatch(command):
return command, mat
if comp and (mat := self.pattern.match(command)):
pbfn(command[len(mat[0]):], replace=True)
return mat[0], mat
return None, None

self._match = _match

def match(self, _pf: Any, command: str, pbfn: Callable[..., ...], comp: bool):
def match(self, _pf: Any, command: str, pbfn: Callable[..., Any], comp: bool):
cmd, mat = self._match(command, pbfn, comp)
if cmd is None:
return
Expand Down Expand Up @@ -159,7 +159,7 @@ def __repr__(self):
prefixes.append(pf)
return f"[{'│'.join(prefixes)}]{cmd}"

def match0(self, pf: Any, cmd: Any, p_str: bool, c_str: bool, pbfn: Callable[..., ...], comp: bool):
def match0(self, pf: Any, cmd: Any, p_str: bool, c_str: bool, pbfn: Callable[..., Any], comp: bool):
if self.prefix and p_str and pf in self.prefix:
if (val := self.command.validate(cmd)).success:
return (pf, cmd), (pf, val._value), True, None
Expand All @@ -177,7 +177,7 @@ def match0(self, pf: Any, cmd: Any, p_str: bool, c_str: bool, pbfn: Callable[...
return (pf, cmd), (val._value, cmd[:len(str(val2._value))]), True, None
return

def match1(self, pf: Any, cmd: Any, p_str: bool, c_str: bool, pbfn: Callable[..., ...], comp: bool):
def match1(self, pf: Any, cmd: Any, p_str: bool, c_str: bool, pbfn: Callable[..., Any], comp: bool):
if p_str or not c_str:
return
if (val := self.patterns.validate(pf)).success and (mat := self.command.fullmatch(cmd)):
Expand All @@ -186,7 +186,7 @@ def match1(self, pf: Any, cmd: Any, p_str: bool, c_str: bool, pbfn: Callable[...
pbfn(cmd[len(mat[0]):], replace=True)
return (pf, cmd), (pf, mat[0]), True, mat.groupdict()

def match2(self, pf: Any, cmd: Any, p_str: bool, c_str: bool, pbfn: Callable[..., ...], comp: bool):
def match2(self, pf: Any, cmd: Any, p_str: bool, c_str: bool, pbfn: Callable[..., Any], comp: bool):
if not p_str and not c_str:
return
if p_str:
Expand All @@ -212,7 +212,7 @@ def match2(self, pf: Any, cmd: Any, p_str: bool, c_str: bool, pbfn: Callable[...
pbfn(cmd[len(mat[0]):], replace=True)
return (pf, cmd), (val._value, mat[0]), True, mat.groupdict()

def match(self, pf: Any, cmd: Any, p_str: bool, c_str: bool, pbfn: Callable[..., ...], comp: bool):
def match(self, pf: Any, cmd: Any, p_str: bool, c_str: bool, pbfn: Callable[..., Any], comp: bool) -> Any:
...


Expand Down
28 changes: 19 additions & 9 deletions src/arclet/alconna/arparma.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from typing import Any, Callable, ClassVar, Generic, TypeVar, cast, overload
from typing_extensions import Self

from tarina import Empty, generic_isinstance, lang
from tarina import Empty, generic_isinstance, lang, safe_eval

from .exceptions import BehaveCancelled, OutBoundsBehave
from .model import HeadResult, OptionResult, SubcommandResult
Expand Down Expand Up @@ -88,9 +88,9 @@ def __call__(self, path: str, default: D | None = None) -> T | D | None:
source, endpoint = self.source.__require__(path.split("."))
if source is None:
return default
if isinstance(source, (OptionResult, SubcommandResult)):
return getattr(source, endpoint, default) if endpoint else source # type: ignore
return source.get(endpoint, default) if endpoint else MappingProxyType(source) # type: ignore
if isinstance(source, dict):
return source.get(endpoint, default) if endpoint else MappingProxyType(source) # type: ignore
return getattr(source, endpoint, default) if endpoint else source # type: ignore


class Arparma(Generic[TDC]):
Expand Down Expand Up @@ -266,6 +266,7 @@ def call(self, target: Callable[..., T]) -> T:
data = {
**{k: v() for k, v in self._additional.items()},
**self.all_matched_args,
"context": self.context,
"all_args": self.all_matched_args,
"options": self.options,
"subcommands": self.subcommands,
Expand All @@ -285,7 +286,10 @@ def call(self, target: Callable[..., T]) -> T:
kw_args[p.name] = data[p.name]
bind = sig.bind(*pos_args, **kw_args)
bind.apply_defaults()
return target(*bind.args, **bind.kwargs)
try:
return target(*bind.args, **bind.kwargs)
finally:
data.clear()

def fail(self, exc: type[Exception] | Exception) -> Self:
"""生成一个失败的 `Arparma`"""
Expand All @@ -295,11 +299,11 @@ def __require__(self, parts: list[str]) -> tuple[dict[str, Any] | OptionResult |
"""如果能够返回, 除开基本信息, 一定返回该path所在的dict"""
if len(parts) == 1:
part = parts[0]
for src in (self.main_args, self.other_args, self.options, self.subcommands):
if part in {"options", "subcommands", "main_args", "other_args", "context"}:
return getattr(self, part, {}), ""
for src in (self.main_args, self.other_args, self.options, self.subcommands, self.context):
if part in src:
return src, part
if part in {"options", "subcommands", "main_args", "other_args"}:
return getattr(self, part, {}), ""
return (self.all_matched_args, "") if part == "args" else (None, part)
prefix = parts.pop(0) # parts[0]
if prefix in {"options", "subcommands"} and prefix in self.components:
Expand All @@ -311,7 +315,13 @@ def __require__(self, parts: list[str]) -> tuple[dict[str, Any] | OptionResult |
prefix = prefix.replace("$main", "main_args").replace("$other", "other_args")
if prefix in {"main_args", "other_args"}:
return getattr(self, prefix, {}), parts.pop(0)
return None, prefix
path = ".".join([prefix] + parts)
if path in self.context:
return self.context, path
try:
return safe_eval(path, self.context), "" # type: ignore
except Exception:
return None, prefix

def query_with(self, arg_type: type[T], *args):
return self.query[arg_type](*args)
Expand Down
4 changes: 2 additions & 2 deletions src/arclet/alconna/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import weakref
from copy import copy
from datetime import datetime
from typing import TYPE_CHECKING, Any, Match, Union
from typing import TYPE_CHECKING, Any, Match, Union, MutableSet
from weakref import WeakKeyDictionary, WeakValueDictionary

from tarina import LRU, lang
Expand Down Expand Up @@ -152,7 +152,7 @@ def require(self, command: Alconna[TDC]) -> Analyser[TDC]:
namespace, name = self._command_part(command.path)
raise ValueError(lang.require("manager", "undefined_command").format(target=f"{namespace}.{name}")) from e

def unpack(self, commands: set[Alconna]) -> zip[tuple[Analyser, Argv]]: # type: ignore
def unpack(self, commands: MutableSet[Alconna]) -> "zip[tuple[Analyser, Argv]]":
"""获取多个命令解析器"""
return zip(
[v for k, v in self.__analysers.items() if k in commands],
Expand Down
8 changes: 4 additions & 4 deletions src/arclet/alconna/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,18 +136,18 @@ def __calc_eq__(self, other): # pragma: no cover
class _ContextPattern(BasePattern[Any, Tuple[str, Dict[str, Any]]]):
def __init__(self, *names: str, style: Literal["bracket", "parentheses"] = "parentheses"):
if not names:
pat = "$(VAR)" if style == "parentheses" else "${VAR}"
pat = "$(VAR)" if style == "parentheses" else "{VAR}"
super().__init__(mode=MatchMode.TYPE_CONVERT, origin=Any, alias=pat)
self.pattern = pat
self.regex_pattern = re.compile(r"\$\((.+)\)", re.DOTALL) if style == "parentheses" else re.compile(r"\{(.+)\}", re.DOTALL)
else:
pat = f"$({'|'.join(names)})" if style == "parentheses" else f"${{{'|'.join(names)}}}"
pat = f"$({'|'.join(names)})" if style == "parentheses" else f"{{{'|'.join(names)}}}"
super().__init__(mode=MatchMode.TYPE_CONVERT, origin=Any, alias=pat)
self.pattern = pat
if style == "parentheses":
self.regex_pattern = re.compile(rf"\$\(({'|'.join(map(re.escape, names))})\)")
self.regex_pattern = re.compile(rf"\$\(({'|'.join(re.escape(name) + '.*' for name in names)})\)")
else:
self.regex_pattern = re.compile(rf"\{{({'|'.join(map(re.escape, names))})\}}")
self.regex_pattern = re.compile(rf"\{{({'|'.join(re.escape(name) + '.*' for name in names)})\}}")

def match(self, input_: Tuple[str, Dict[str, Any]]) -> Any:
pat, ctx = input_
Expand Down

0 comments on commit ba2dd92

Please sign in to comment.