Skip to content

Commit

Permalink
More typing
Browse files Browse the repository at this point in the history
  • Loading branch information
mar10 committed Oct 27, 2024
1 parent 91e2069 commit 3cfe000
Show file tree
Hide file tree
Showing 5 changed files with 21 additions and 11 deletions.
5 changes: 4 additions & 1 deletion nutree/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
)

try:
from typing import Self # ruff: noqa: F401
from typing import Self
except ImportError:
from typing_extensions import Self # noqa

Expand Down Expand Up @@ -160,6 +160,9 @@ def __init__(self, value=None):
["Node"], Union[None, bool, IterationControl, Type[IterationControl]]
]

#:
MatchArgumentType = Union[str, PredicateCallbackType, list, tuple, Any]

#:
TraversalCallbackType = Callable[
["Node", Any],
Expand Down
13 changes: 9 additions & 4 deletions nutree/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
IterMethod,
KeyMapType,
MapperCallbackType,
MatchArgumentType,
PredicateCallbackType,
ReprArgType,
SelectBranch,
Expand Down Expand Up @@ -954,7 +955,7 @@ def _visit(other: Self) -> None:
pass
return

def filtered(self, predicate: PredicateCallbackType) -> Tree:
def filtered(self, predicate: PredicateCallbackType) -> Tree[Self]:
"""Return a filtered copy of this node and descendants as tree.
See also :ref:`iteration-callbacks`.
Expand Down Expand Up @@ -1207,7 +1208,11 @@ def iterator(
__iter__ = iterator

def _search(
self, match, *, max_results: int | None = None, add_self=False
self,
match,
*,
max_results: int | None = None,
add_self=False,
) -> Iterator[Self]:
if callable(match):
cb_match = match
Expand All @@ -1234,7 +1239,7 @@ def find_all(
self,
data=None,
*,
match: PredicateCallbackType | None = None,
match: MatchArgumentType | None = None,
data_id: DataIdType | None = None,
add_self=False,
max_results: int | None = None,
Expand All @@ -1259,7 +1264,7 @@ def find_first(
self,
data=None,
*,
match: PredicateCallbackType | None = None,
match: MatchArgumentType | None = None,
data_id: DataIdType | None = None,
) -> Self | None:
"""Return the first matching node or `None`.
Expand Down
5 changes: 3 additions & 2 deletions nutree/tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
IterMethod,
KeyMapType,
MapperCallbackType,
MatchArgumentType,
PredicateCallbackType,
ReprArgType,
Self,
Expand Down Expand Up @@ -460,7 +461,7 @@ def find_all(
self,
data=None,
*,
match: PredicateCallbackType | None = None,
match: MatchArgumentType | None = None,
data_id: DataIdType | None = None,
max_results: int | None = None,
) -> list[TNode]:
Expand Down Expand Up @@ -489,7 +490,7 @@ def find_first(
self,
data=None,
*,
match: PredicateCallbackType | None = None,
match: MatchArgumentType | None = None,
data_id: DataIdType | None = None,
node_id: int | None = None,
) -> TNode | None:
Expand Down
7 changes: 4 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,11 @@ ignore = [
# convention = "google"

[tool.pyright]
# include = ["nutree", "tests"]
include = ["nutree"]
# typeCheckingMode = "off"
typeCheckingMode = "basic"
# typeCheckingMode = "off"
# include = ["nutree"]
include = ["nutree", "tests"]
exclude = ["nutree/rdf.py"]

# https://github.com/microsoft/pyright/blob/main/docs/configuration.md#sample-pyprojecttoml-file
reportUnnecessaryTypeIgnoreComment = true
Expand Down
2 changes: 1 addition & 1 deletion tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1324,7 +1324,7 @@ def test_filter(self):
with pytest.raises(ValueError, match="Predicate is required"):
tree.filter(predicate=None) # type: ignore
with pytest.raises(ValueError, match="Predicate is required"):
tree.system_root.filter(predicate=None) # type: ignore
tree.system_root.filter(predicate=None)

def _tf(
*,
Expand Down

0 comments on commit 3cfe000

Please sign in to comment.