Skip to content

Commit

Permalink
Fix copy() with predicate
Browse files Browse the repository at this point in the history
  • Loading branch information
mar10 committed Oct 22, 2024
1 parent 4948958 commit 26b590a
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 15 deletions.
19 changes: 11 additions & 8 deletions nutree/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -906,7 +906,14 @@ def _visit(other: Node) -> None:
parent_stack.append((False, n))

res = call_predicate(predicate, n)
if isinstance(res, SkipBranch):

if res is None or res is False: # Add only if has a `true` descendant
_visit(n)
elif res is True: # Add this node (and also check children)
p = _create_parents()
# p.add_child(n)
_visit(n)
elif isinstance(res, SkipBranch):
if res.and_self is False:
# Add the node itself if user explicitly returned
# `SkipBranch(and_self=False)`
Expand All @@ -918,12 +925,8 @@ def _visit(other: Node) -> None:
# Unconditionally copy whole branch: no need to visit children
p = _create_parents()
p._add_from(n)
elif res in (None, False): # Add only if has a `true` descendant
_visit(n)
elif res is True: # Add this node (and also check children)
p = _create_parents()
p.add_child(n)
_visit(n)
else:
raise ValueError(f"Invalid predicate return value: {res}")

parent_stack.pop()
return
Expand Down Expand Up @@ -958,7 +961,7 @@ def _visit(parent: Node) -> bool:

for n in parent.children:
res = call_predicate(predicate, n)
if res in (None, False): # Keep only if has a `true` descendant
if res is None or res is False: # Keep only if has a `true` descendant
if _visit(n):
must_keep = True
else:
Expand Down
6 changes: 6 additions & 0 deletions tests/fixture.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,9 @@ def _check_content(
else:
style = "ascii32"

if isinstance(tree, Tree):
assert tree._self_check()

s1 = indent(canonical_repr(tree, repr=repr, style=style), " ")
s2 = indent(canonical_repr(expect_ascii, repr=repr, style=style), " ")
if ignore_tree_name:
Expand Down Expand Up @@ -281,6 +284,9 @@ def check_content(


def trees_equal(tree_1, tree_2, ignore_tree_name=True) -> bool:
assert tree_1 is not tree_2
if not tree_1 or not tree_2 or (len(tree_1) != len(tree_2)):
return False
return check_content(tree_1, tree_2, ignore_tree_name=ignore_tree_name)


Expand Down
27 changes: 20 additions & 7 deletions tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1125,7 +1125,6 @@ def test_tree_copy(self):
`- a12
""",
)
assert subtree._self_check()

subtree = tree["A"].copy(add_self=False)
assert fixture.check_content(
Expand All @@ -1138,7 +1137,26 @@ def test_tree_copy(self):
`- a2
""",
)
assert subtree._self_check()

def test_node_copy_predicate(self):
tree = fixture.create_tree()

tree_2 = tree.copy()
assert fixture.trees_equal(tree, tree_2)

tree_3 = tree.copy(predicate=lambda n: "2" not in n.name.lower())
assert fixture.check_content(
tree_3,
"""
Tree<'fixture'>
├── A
│ ╰── a1
│ ╰── a11
╰── B
╰── b1
╰── b11
""",
)

def test_node_copy_to(self):
tree_1 = fixture.create_tree()
Expand Down Expand Up @@ -1278,17 +1296,14 @@ def pred(node):

tree_2 = tree.filtered(predicate=pred)

assert tree_2._self_check()
assert fixture.check_content(
tree_2,
"""
Tree<*>
╰── A
├── a1
│ ╰── a12
│ ╰── a12
╰── a2
╰── a2
""",
)

Expand All @@ -1299,14 +1314,12 @@ def pred(node):

tree_2 = tree.filtered(predicate=pred)

assert tree_2._self_check()
assert fixture.check_content(
tree_2,
"""
Tree<*>
╰── A
╰── a2
╰── a2
""",
)

Expand Down

0 comments on commit 26b590a

Please sign in to comment.