diff --git a/docs/sphinx/ug_objects.rst b/docs/sphinx/ug_objects.rst index 9e5b831..b5eac05 100644 --- a/docs/sphinx/ug_objects.rst +++ b/docs/sphinx/ug_objects.rst @@ -207,6 +207,14 @@ GenericNodeData can also be initialized with keyword args like this:: obj = GenericNodeData(a=1, b=2) +Trees that contain GenericNodeData objects can be serialized and deserialized +using the :meth:`~nutree.tree.Tree.save` and :meth:`~nutree.tree.Tree.load` +methods:: + + tree.save(file_path, mapper=GenericNodeData.serialize_mapper) + ... + tree2 = Tree.load(file_path, mapper=GenericNodeData.deserialize_mapper) + .. warning:: The :class:`~nutree.common.GenericNodeData` provides a hash value because any class that is hashable, so it can be used as a data object. However, the diff --git a/nutree/common.py b/nutree/common.py index 4face0c..8dc3b9e 100644 --- a/nutree/common.py +++ b/nutree/common.py @@ -107,7 +107,7 @@ class StopTraversal(IterationControl): """Raised or returned by traversal callbacks to stop iteration. Optionally, a return value may be passed. - Note that if a callback returns ``False``, this will be converted to an + Note that if a callback returns ``False``, this will be converted to a ``StopTraversal(None)`` exception. """ @@ -115,9 +115,9 @@ def __init__(self, value=None): self.value = value -#: +#: Generic callback for `tree.filter()`, `tree.copy()`, ... PredicateCallbackType = Callable[["Node"], Union[None, bool, IterationControl]] -#: +#: Generic callback for `tree.to_dot()`, ... MapperCallbackType = Callable[["Node", dict], Union[None, Any]] #: Callback for `tree.save()` SerializeMapperType = Callable[["Node", dict], Union[None, dict]] @@ -221,10 +221,27 @@ def __getattr__(self, name: str) -> Any: except KeyError: raise AttributeError(name) from None - @staticmethod - def serialize_mapper(nutree_node, data): + @classmethod + def serialize_mapper(cls, nutree_node: Node, data: dict) -> Union[None, dict]: + """Serialize the data object to a dictionary. + + Example:: + + tree.save(file_path, mapper=GenericNodeData.serialize_mapper) + + """ return nutree_node.data._dict.copy() + @classmethod + def deserialize_mapper(cls, nutree_node: Node, data: dict) -> Union[str, object]: + """Serialize the data object to a dictionary. + + Example:: + + tree = Tree.load(file_path, mapper=GenericNodeData.deserialize_mapper) + """ + return cls(**data) + def get_version() -> str: from nutree import __version__ @@ -323,7 +340,7 @@ def call_traversal_cb(fn: Callable, node: Node, memo: Any) -> False | None: RuntimeWarning, stacklevel=3, ) - raise StopTraversal(e.value) from None + raise StopTraversal(e.value) from e return None diff --git a/tests/test_core.py b/tests/test_core.py index ac8e7a2..4d3acd9 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -9,7 +9,7 @@ import pytest from nutree import AmbiguousMatchError, IterMethod, Node, Tree -from nutree.common import SkipBranch, StopTraversal +from nutree.common import SkipBranch, StopTraversal, check_python_version from nutree.fs import load_tree_from_fs from . import fixture @@ -23,6 +23,12 @@ def _make_tree_2(): return t +class TestCommon: + def test_check_python_version(self): + assert check_python_version((3, 7)) is True + assert check_python_version((99, 1)) is False + + class TestBasics: def test_add_child(self): tree = Tree("fixture") @@ -354,7 +360,7 @@ def test_data_id(self): ) assert tree._self_check() - def test_search(self): + def test_find(self): tree = self.tree records = tree["Records"] @@ -637,20 +643,94 @@ def cb(node, memo): tree.visit(cb, method=IterMethod.LEVEL_ORDER) assert ",".join(res) == "A,B,a1,a2,b1,a11,a12,b11" + def test_visit_cb(self): + """ + Tree<'fixture'> + ├── A + │ ├── a1 + │ │ ├── a11 + │ │ ╰── a12 + │ ╰── a2 + ╰── B + ╰── b1 + ╰── b11 + """ + tree = fixture.create_tree() + res = [] def cb(node, memo): res.append(node.name) if node.name == "a1": return SkipBranch + if node.name == "b1": + return StopTraversal + + res_2 = tree.visit(cb) + + assert res_2 is None + assert ",".join(res) == "A,a1,a2,B,b1" + + res = [] + + def cb(node, memo): + res.append(node.name) + if node.name == "a1": + raise SkipBranch(and_self=True) if node.name == "b1": raise StopTraversal("Found b1") res_2 = tree.visit(cb) assert res_2 == "Found b1" + # and_self does not skip self in this case assert ",".join(res) == "A,a1,a2,B,b1" + res = [] + + def cb(node, memo): + res.append(node.name) + if node.name == "a12": + raise StopIteration + + res_2 = tree.visit(cb) + + assert ",".join(res) == "A,a1,a11,a12" + + res = [] + + def cb(node, memo): + res.append(node.name) + if node.name == "a12": + return StopIteration + + res_2 = tree.visit(cb) + + assert ",".join(res) == "A,a1,a11,a12" + + res = [] + + def cb(node, memo): + res.append(node.name) + if node.name == "a12": + return False + + res_2 = tree.visit(cb) + + assert ",".join(res) == "A,a1,a11,a12" + + res = [] + + def cb(node, memo): + res.append(node.name) + if node.name == "b1": + return 17 + + with pytest.raises( + ValueError, match="callback should not return values except for" + ): + res_2 = tree.visit(cb) + class TestMutate: def test_add(self): @@ -983,6 +1063,17 @@ def test_tree_copy_to(self): ) def test_filter(self): + """ + Tree<'fixture'> + ├── A + │ ├── a1 + │ │ ├── a11 + │ │ ╰── a12 + │ ╰── a2 + ╰── B + ╰── b1 + ╰── b11 + """ tree = fixture.create_tree() def pred(node): @@ -1005,6 +1096,17 @@ def pred(node): ) def test_filtered(self): + """ + Tree<'fixture'> + ├── A + │ ├── a1 + │ │ ├── a11 + │ │ ╰── a12 + │ ╰── a2 + ╰── B + ╰── b1 + ╰── b11 + """ tree = fixture.create_tree() def pred(node): @@ -1026,6 +1128,57 @@ def pred(node): """, ) + def pred(node): + if node.name == "a12": + raise SkipBranch + return "2" in node.name.lower() + + tree_2 = tree.filtered(predicate=pred) + + assert tree_2._self_check() + assert fixture.check_content( + tree_2, + """ + Tree<*> + ╰── A + ╰── a2 + ╰── a2 + """, + ) + + def pred(node): + if node.name == "a12": + raise StopIteration + return "2" in node.name.lower() + + tree_2 = tree.filtered(predicate=pred) + + assert tree_2._self_check() + assert fixture.check_content( + tree_2, + """ + Tree<*> + """, + ) + + tree_2 = tree.filtered(predicate=None) + + assert tree_2._self_check() + assert fixture.check_content( + tree_2, + """ + Tree<*> + ├── A + │ ├── a1 + │ │ ├── a11 + │ │ ╰── a12 + │ ╰── a2 + ╰── B + ╰── b1 + ╰── b11 + """, + ) + class TestFS: @pytest.mark.skipif(os.name == "nt", reason="windows has different eol size") diff --git a/tests/test_serialize.py b/tests/test_serialize.py index 525f6e1..bdf0475 100644 --- a/tests/test_serialize.py +++ b/tests/test_serialize.py @@ -87,6 +87,15 @@ def test_serialize_compressed(self): with fixture.WritableTempFile("r+t") as temp_file: tree.save(temp_file.name, compression=zipfile.ZIP_DEFLATED) tree_2 = Tree.load(temp_file.name) + + with pytest.raises(UnicodeDecodeError): + _ = Tree.load(temp_file.name, auto_uncompress=False) + + assert fixture.trees_equal(tree, tree_2) + + with fixture.WritableTempFile("r+t") as temp_file: + tree.save(temp_file.name, compression=True) + tree_2 = Tree.load(temp_file.name) assert fixture.trees_equal(tree, tree_2) with fixture.WritableTempFile("r+t") as temp_file: @@ -99,6 +108,17 @@ def test_serialize_compressed(self): tree_2 = Tree.load(temp_file.name) assert fixture.trees_equal(tree, tree_2) + def test_serialize_uncompressed(self): + tree = fixture.create_tree() + tree.add_child("äöüß: \u00e4\u00f6\u00fc\u00df") + tree.add_child("emoji: 😀") + + with fixture.WritableTempFile("r+t") as temp_file: + tree.save(temp_file.name, compression=False) + tree_2 = Tree.load(temp_file.name) + + assert fixture.trees_equal(tree, tree_2) + def _test_serialize_objects(self, *, mode: str): """Save/load an object tree with clones. diff --git a/tests/test_tree_generator.py b/tests/test_tree_generator.py index e1dde91..ec5ff96 100644 --- a/tests/test_tree_generator.py +++ b/tests/test_tree_generator.py @@ -19,113 +19,140 @@ ) from nutree.typed_tree import TypedTree +from tests import fixture -def test_simple(): - structure_def = { - "name": "fmea", - #: Types define the default properties of the nodes - "types": { - #: Default properties for all node types - "*": {":factory": GenericNodeData}, - #: Specific default properties for each node type - "function": {"icon": "bi bi-gear"}, - "failure": {"icon": "bi bi-exclamation-triangle"}, - "cause": {"icon": "bi bi-tools"}, - "effect": {"icon": "bi bi-lightning"}, - }, - #: Relations define the possible parent / child relationships between - #: node types and optionally override the default properties. - "relations": { - "__root__": { - "function": { - ":count": 3, - "title": "Function {hier_idx}", - "date": DateRangeRandomizer( - datetime.date(2020, 1, 1), datetime.date(2020, 12, 31) - ), - "date2": DateRangeRandomizer( - datetime.date(2020, 1, 1), 365, probability=0.99 - ), - "value": ValueRandomizer("foo", probability=0.5), - "expanded": SparseBoolRandomizer(probability=0.5), - "state": SampleRandomizer(["open", "closed"], probability=0.99), - }, + +class TestBase: + def test_simple(self): + structure_def = { + "name": "fmea", + #: Types define the default properties of the nodes + "types": { + #: Default properties for all node types + "*": {":factory": GenericNodeData}, + #: Specific default properties for each node type + "function": {"icon": "bi bi-gear"}, + "failure": {"icon": "bi bi-exclamation-triangle"}, + "cause": {"icon": "bi bi-tools"}, + "effect": {"icon": "bi bi-lightning"}, }, - "function": { - "failure": { - ":count": RangeRandomizer(1, 3), - "title": "Failure {hier_idx}", + #: Relations define the possible parent / child relationships between + #: node types and optionally override the default properties. + "relations": { + "__root__": { + "function": { + ":count": 3, + "title": "Function {hier_idx}", + "date": DateRangeRandomizer( + datetime.date(2020, 1, 1), datetime.date(2020, 12, 31) + ), + "date2": DateRangeRandomizer( + datetime.date(2020, 1, 1), 365, probability=0.99 + ), + "value": ValueRandomizer("foo", probability=0.5), + "expanded": SparseBoolRandomizer(probability=0.5), + "state": SampleRandomizer(["open", "closed"], probability=0.99), + }, }, - }, - "failure": { - "cause": { - ":count": RangeRandomizer(1, 3, probability=0.99), - "title": "Cause {hier_idx}", + "function": { + "failure": { + ":count": RangeRandomizer(1, 3), + "title": "Failure {hier_idx}", + }, }, - "effect": { - ":count": RangeRandomizer(1, 3), - "title": "Effect {hier_idx}", + "failure": { + "cause": { + ":count": RangeRandomizer(1, 3, probability=0.99), + "title": "Cause {hier_idx}", + }, + "effect": { + ":count": RangeRandomizer(1, 3), + "title": "Effect {hier_idx}", + }, }, }, - }, - } - tree = Tree.build_random_tree(structure_def) - tree.print() - assert type(tree) is Tree - assert tree.calc_height() == 3 + } + tree = Tree.build_random_tree(structure_def) + tree.print() + assert type(tree) is Tree + assert tree.calc_height() == 3 - tree2 = TypedTree.build_random_tree(structure_def) - tree2.print() - assert type(tree2) is TypedTree - assert tree2.calc_height() == 3 + tree2 = TypedTree.build_random_tree(structure_def) + tree2.print() + assert type(tree2) is TypedTree + assert tree2.calc_height() == 3 + # Save and load with GenericNodeData mappers + with fixture.WritableTempFile("r+t") as temp_file: + tree.save( + temp_file.name, + compression=True, + mapper=GenericNodeData.serialize_mapper, + ) + tree3 = Tree.load(temp_file.name, mapper=GenericNodeData.deserialize_mapper) + tree3.print() + assert fixture.trees_equal(tree, tree3) -def test_fabulist(): - if not fab: - pytest.skip("fabulist not installed") + def test_fabulist(self): + if not fab: + pytest.skip("fabulist not installed") - structure_def = { - "name": "fmea", - #: Types define the default properties of the nodes - "types": { - #: Default properties for all node types (optional, default - #: is GenericNodeData) - "*": {":factory": GenericNodeData}, - #: Specific default properties for each node type - "function": {"icon": "bi bi-gear"}, - "failure": {"icon": "bi bi-exclamation-triangle"}, - "cause": {"icon": "bi bi-tools"}, - "effect": {"icon": "bi bi-lightning"}, - }, - #: Relations define the possible parent / child relationships between - #: node types and optionally override the default properties. - "relations": { - "__root__": { - "function": { - ":count": 3, - "title": TextRandomizer(("{idx}: Provide $(Noun:plural)",)), - "details": BlindTextRandomizer(dialect="ipsum"), - "expanded": True, - }, + structure_def = { + "name": "fmea", + #: Types define the default properties of the nodes + "types": { + #: Default properties for all node types (optional, default + #: is GenericNodeData) + "*": {":factory": GenericNodeData}, + #: Specific default properties for each node type + "function": {"icon": "bi bi-gear"}, + "failure": {"icon": "bi bi-exclamation-triangle"}, + "cause": {"icon": "bi bi-tools"}, + "effect": {"icon": "bi bi-lightning"}, }, - "function": { - "failure": { - ":count": RangeRandomizer(1, 3), - "title": TextRandomizer("$(Noun:plural) not provided"), + #: Relations define the possible parent / child relationships between + #: node types and optionally override the default properties. + "relations": { + "__root__": { + "function": { + ":count": 3, + "title": TextRandomizer(("{idx}: Provide $(Noun:plural)",)), + "details": BlindTextRandomizer(dialect="ipsum"), + "expanded": True, + }, }, - }, - "failure": { - "cause": { - ":count": RangeRandomizer(1, 3), - "title": TextRandomizer("$(Noun:plural) not provided"), + "function": { + "failure": { + ":count": RangeRandomizer(1, 3), + "title": TextRandomizer("$(Noun:plural) not provided"), + }, }, - "effect": { - ":count": RangeRandomizer(1, 3), - "title": TextRandomizer("$(Noun:plural) not provided"), + "failure": { + "cause": { + ":count": RangeRandomizer(1, 3), + "title": TextRandomizer("$(Noun:plural) not provided"), + }, + "effect": { + ":count": RangeRandomizer(1, 3), + "title": TextRandomizer("$(Noun:plural) not provided"), + }, }, }, - }, - } - tree = TypedTree.build_random_tree(structure_def) - tree.print() - assert type(tree) is TypedTree + } + tree = TypedTree.build_random_tree(structure_def) + tree.print() + + assert type(tree) is TypedTree + + +class TestRandomizers: + def test_range(self): + r = RangeRandomizer(1, 3) + for v in (r.generate() for _ in range(100)): + assert isinstance(v, int) + assert 1 <= v <= 3 + + r = RangeRandomizer(1.0, 3.0) + for v in (r.generate() for _ in range(100)): + assert isinstance(v, float) + assert 1 <= v <= 3