diff --git a/nutree/common.py b/nutree/common.py index 9d57ed4..4face0c 100644 --- a/nutree/common.py +++ b/nutree/common.py @@ -297,7 +297,10 @@ def call_traversal_cb(fn: Callable, node: Node, memo: Any) -> False | None: """ try: res = fn(node, memo) - if res is SkipBranch or isinstance(res, SkipBranch): + + if res is None: + return None + elif res is SkipBranch or isinstance(res, SkipBranch): return False elif res is StopTraversal or isinstance(res, StopTraversal): raise res @@ -306,10 +309,10 @@ def call_traversal_cb(fn: Callable, node: Node, memo: Any) -> False | None: elif res is StopIteration or isinstance(res, StopIteration): # Converts wrong syntax in exception handler... raise res - elif res is not None: + else: raise ValueError( "callback should not return values except for " - f"False, SkipBranch, or StopTraversal: {res!r}." + f"None, False, SkipBranch, or StopTraversal: {res!r}." ) except SkipBranch: return False