Skip to content

Commit

Permalink
Add some more tests
Browse files Browse the repository at this point in the history
  • Loading branch information
mar10 committed Sep 22, 2024
1 parent 22f3d30 commit 89da78d
Show file tree
Hide file tree
Showing 5 changed files with 328 additions and 103 deletions.
8 changes: 8 additions & 0 deletions docs/sphinx/ug_objects.rst
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,14 @@ GenericNodeData can also be initialized with keyword args like this::

obj = GenericNodeData(a=1, b=2)

Trees that contain GenericNodeData objects can be serialized and deserialized
using the :meth:`~nutree.tree.Tree.save` and :meth:`~nutree.tree.Tree.load`
methods::

tree.save(file_path, mapper=GenericNodeData.serialize_mapper)
...
tree2 = Tree.load(file_path, mapper=GenericNodeData.deserialize_mapper)

.. warning::
The :class:`~nutree.common.GenericNodeData` provides a hash value because
any class that is hashable, so it can be used as a data object. However, the
Expand Down
29 changes: 23 additions & 6 deletions nutree/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,17 +107,17 @@ class StopTraversal(IterationControl):
"""Raised or returned by traversal callbacks to stop iteration.
Optionally, a return value may be passed.
Note that if a callback returns ``False``, this will be converted to an
Note that if a callback returns ``False``, this will be converted to a
``StopTraversal(None)`` exception.
"""

def __init__(self, value=None):
self.value = value


#:
#: Generic callback for `tree.filter()`, `tree.copy()`, ...
PredicateCallbackType = Callable[["Node"], Union[None, bool, IterationControl]]
#:
#: Generic callback for `tree.to_dot()`, ...
MapperCallbackType = Callable[["Node", dict], Union[None, Any]]
#: Callback for `tree.save()`
SerializeMapperType = Callable[["Node", dict], Union[None, dict]]
Expand Down Expand Up @@ -221,10 +221,27 @@ def __getattr__(self, name: str) -> Any:
except KeyError:
raise AttributeError(name) from None

@staticmethod
def serialize_mapper(nutree_node, data):
@classmethod
def serialize_mapper(cls, nutree_node: Node, data: dict) -> Union[None, dict]:
"""Serialize the data object to a dictionary.
Example::
tree.save(file_path, mapper=GenericNodeData.serialize_mapper)
"""
return nutree_node.data._dict.copy()

@classmethod
def deserialize_mapper(cls, nutree_node: Node, data: dict) -> Union[str, object]:
"""Serialize the data object to a dictionary.
Example::
tree = Tree.load(file_path, mapper=GenericNodeData.deserialize_mapper)
"""
return cls(**data)


def get_version() -> str:
from nutree import __version__
Expand Down Expand Up @@ -323,7 +340,7 @@ def call_traversal_cb(fn: Callable, node: Node, memo: Any) -> False | None:
RuntimeWarning,
stacklevel=3,
)
raise StopTraversal(e.value) from None
raise StopTraversal(e.value) from e
return None


Expand Down
157 changes: 155 additions & 2 deletions tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

import pytest
from nutree import AmbiguousMatchError, IterMethod, Node, Tree
from nutree.common import SkipBranch, StopTraversal
from nutree.common import SkipBranch, StopTraversal, check_python_version
from nutree.fs import load_tree_from_fs

from . import fixture
Expand All @@ -23,6 +23,12 @@ def _make_tree_2():
return t


class TestCommon:
def test_check_python_version(self):
assert check_python_version((3, 7)) is True
assert check_python_version((99, 1)) is False


class TestBasics:
def test_add_child(self):
tree = Tree("fixture")
Expand Down Expand Up @@ -354,7 +360,7 @@ def test_data_id(self):
)
assert tree._self_check()

def test_search(self):
def test_find(self):
tree = self.tree

records = tree["Records"]
Expand Down Expand Up @@ -637,20 +643,94 @@ def cb(node, memo):
tree.visit(cb, method=IterMethod.LEVEL_ORDER)
assert ",".join(res) == "A,B,a1,a2,b1,a11,a12,b11"

def test_visit_cb(self):
"""
Tree<'fixture'>
├── A
│ ├── a1
│ │ ├── a11
│ │ ╰── a12
│ ╰── a2
╰── B
╰── b1
╰── b11
"""
tree = fixture.create_tree()

res = []

def cb(node, memo):
res.append(node.name)
if node.name == "a1":
return SkipBranch
if node.name == "b1":
return StopTraversal

res_2 = tree.visit(cb)

assert res_2 is None
assert ",".join(res) == "A,a1,a2,B,b1"

res = []

def cb(node, memo):
res.append(node.name)
if node.name == "a1":
raise SkipBranch(and_self=True)
if node.name == "b1":
raise StopTraversal("Found b1")

res_2 = tree.visit(cb)

assert res_2 == "Found b1"
# and_self does not skip self in this case
assert ",".join(res) == "A,a1,a2,B,b1"

res = []

def cb(node, memo):
res.append(node.name)
if node.name == "a12":
raise StopIteration

res_2 = tree.visit(cb)

assert ",".join(res) == "A,a1,a11,a12"

res = []

def cb(node, memo):
res.append(node.name)
if node.name == "a12":
return StopIteration

res_2 = tree.visit(cb)

assert ",".join(res) == "A,a1,a11,a12"

res = []

def cb(node, memo):
res.append(node.name)
if node.name == "a12":
return False

res_2 = tree.visit(cb)

assert ",".join(res) == "A,a1,a11,a12"

res = []

def cb(node, memo):
res.append(node.name)
if node.name == "b1":
return 17

with pytest.raises(
ValueError, match="callback should not return values except for"
):
res_2 = tree.visit(cb)


class TestMutate:
def test_add(self):
Expand Down Expand Up @@ -983,6 +1063,17 @@ def test_tree_copy_to(self):
)

def test_filter(self):
"""
Tree<'fixture'>
├── A
│ ├── a1
│ │ ├── a11
│ │ ╰── a12
│ ╰── a2
╰── B
╰── b1
╰── b11
"""
tree = fixture.create_tree()

def pred(node):
Expand All @@ -1005,6 +1096,17 @@ def pred(node):
)

def test_filtered(self):
"""
Tree<'fixture'>
├── A
│ ├── a1
│ │ ├── a11
│ │ ╰── a12
│ ╰── a2
╰── B
╰── b1
╰── b11
"""
tree = fixture.create_tree()

def pred(node):
Expand All @@ -1026,6 +1128,57 @@ def pred(node):
""",
)

def pred(node):
if node.name == "a12":
raise SkipBranch
return "2" in node.name.lower()

tree_2 = tree.filtered(predicate=pred)

assert tree_2._self_check()
assert fixture.check_content(
tree_2,
"""
Tree<*>
╰── A
╰── a2
╰── a2
""",
)

def pred(node):
if node.name == "a12":
raise StopIteration
return "2" in node.name.lower()

tree_2 = tree.filtered(predicate=pred)

assert tree_2._self_check()
assert fixture.check_content(
tree_2,
"""
Tree<*>
""",
)

tree_2 = tree.filtered(predicate=None)

assert tree_2._self_check()
assert fixture.check_content(
tree_2,
"""
Tree<*>
├── A
│ ├── a1
│ │ ├── a11
│ │ ╰── a12
│ ╰── a2
╰── B
╰── b1
╰── b11
""",
)


class TestFS:
@pytest.mark.skipif(os.name == "nt", reason="windows has different eol size")
Expand Down
20 changes: 20 additions & 0 deletions tests/test_serialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,15 @@ def test_serialize_compressed(self):
with fixture.WritableTempFile("r+t") as temp_file:
tree.save(temp_file.name, compression=zipfile.ZIP_DEFLATED)
tree_2 = Tree.load(temp_file.name)

with pytest.raises(UnicodeDecodeError):
_ = Tree.load(temp_file.name, auto_uncompress=False)

assert fixture.trees_equal(tree, tree_2)

with fixture.WritableTempFile("r+t") as temp_file:
tree.save(temp_file.name, compression=True)
tree_2 = Tree.load(temp_file.name)
assert fixture.trees_equal(tree, tree_2)

with fixture.WritableTempFile("r+t") as temp_file:
Expand All @@ -99,6 +108,17 @@ def test_serialize_compressed(self):
tree_2 = Tree.load(temp_file.name)
assert fixture.trees_equal(tree, tree_2)

def test_serialize_uncompressed(self):
tree = fixture.create_tree()
tree.add_child("äöüß: \u00e4\u00f6\u00fc\u00df")
tree.add_child("emoji: 😀")

with fixture.WritableTempFile("r+t") as temp_file:
tree.save(temp_file.name, compression=False)
tree_2 = Tree.load(temp_file.name)

assert fixture.trees_equal(tree, tree_2)

def _test_serialize_objects(self, *, mode: str):
"""Save/load an object tree with clones.
Expand Down
Loading

0 comments on commit 89da78d

Please sign in to comment.