From 26b590a6c79f4fd46fc93cb9fdc0a1b747d5d209 Mon Sep 17 00:00:00 2001 From: Martin Wendt Date: Tue, 22 Oct 2024 21:38:16 +0200 Subject: [PATCH] Fix copy() with predicate --- nutree/node.py | 19 +++++++++++-------- tests/fixture.py | 6 ++++++ tests/test_core.py | 27 ++++++++++++++++++++------- 3 files changed, 37 insertions(+), 15 deletions(-) diff --git a/nutree/node.py b/nutree/node.py index 493aa27..45ec1dd 100644 --- a/nutree/node.py +++ b/nutree/node.py @@ -906,7 +906,14 @@ def _visit(other: Node) -> None: parent_stack.append((False, n)) res = call_predicate(predicate, n) - if isinstance(res, SkipBranch): + + if res is None or res is False: # Add only if has a `true` descendant + _visit(n) + elif res is True: # Add this node (and also check children) + p = _create_parents() + # p.add_child(n) + _visit(n) + elif isinstance(res, SkipBranch): if res.and_self is False: # Add the node itself if user explicitly returned # `SkipBranch(and_self=False)` @@ -918,12 +925,8 @@ def _visit(other: Node) -> None: # Unconditionally copy whole branch: no need to visit children p = _create_parents() p._add_from(n) - elif res in (None, False): # Add only if has a `true` descendant - _visit(n) - elif res is True: # Add this node (and also check children) - p = _create_parents() - p.add_child(n) - _visit(n) + else: + raise ValueError(f"Invalid predicate return value: {res}") parent_stack.pop() return @@ -958,7 +961,7 @@ def _visit(parent: Node) -> bool: for n in parent.children: res = call_predicate(predicate, n) - if res in (None, False): # Keep only if has a `true` descendant + if res is None or res is False: # Keep only if has a `true` descendant if _visit(n): must_keep = True else: diff --git a/tests/fixture.py b/tests/fixture.py index 7d49670..4261a4a 100644 --- a/tests/fixture.py +++ b/tests/fixture.py @@ -253,6 +253,9 @@ def _check_content( else: style = "ascii32" + if isinstance(tree, Tree): + assert tree._self_check() + s1 = indent(canonical_repr(tree, repr=repr, style=style), " ") s2 = indent(canonical_repr(expect_ascii, repr=repr, style=style), " ") if ignore_tree_name: @@ -281,6 +284,9 @@ def check_content( def trees_equal(tree_1, tree_2, ignore_tree_name=True) -> bool: + assert tree_1 is not tree_2 + if not tree_1 or not tree_2 or (len(tree_1) != len(tree_2)): + return False return check_content(tree_1, tree_2, ignore_tree_name=ignore_tree_name) diff --git a/tests/test_core.py b/tests/test_core.py index e47892e..0c586bf 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -1125,7 +1125,6 @@ def test_tree_copy(self): `- a12 """, ) - assert subtree._self_check() subtree = tree["A"].copy(add_self=False) assert fixture.check_content( @@ -1138,7 +1137,26 @@ def test_tree_copy(self): `- a2 """, ) - assert subtree._self_check() + + def test_node_copy_predicate(self): + tree = fixture.create_tree() + + tree_2 = tree.copy() + assert fixture.trees_equal(tree, tree_2) + + tree_3 = tree.copy(predicate=lambda n: "2" not in n.name.lower()) + assert fixture.check_content( + tree_3, + """ + Tree<'fixture'> + ├── A + │ ╰── a1 + │ ╰── a11 + ╰── B + ╰── b1 + ╰── b11 + """, + ) def test_node_copy_to(self): tree_1 = fixture.create_tree() @@ -1278,7 +1296,6 @@ def pred(node): tree_2 = tree.filtered(predicate=pred) - assert tree_2._self_check() assert fixture.check_content( tree_2, """ @@ -1286,9 +1303,7 @@ def pred(node): ╰── A ├── a1 │ ╰── a12 - │ ╰── a12 ╰── a2 - ╰── a2 """, ) @@ -1299,14 +1314,12 @@ def pred(node): tree_2 = tree.filtered(predicate=pred) - assert tree_2._self_check() assert fixture.check_content( tree_2, """ Tree<*> ╰── A ╰── a2 - ╰── a2 """, )