From 72ae407259bab90e49cf2294957444e12c13fe39 Mon Sep 17 00:00:00 2001 From: Xuehai Pan Date: Mon, 25 Mar 2024 23:31:55 +0800 Subject: [PATCH] feat(ops): add `tree_iter` function (#130) --- CHANGELOG.md | 1 + docs/source/ops.rst | 2 + include/registry.h | 6 ++ include/treespec.h | 89 ++++++++++++++-------- optree/_C.pyi | 13 +++- optree/__init__.py | 2 + optree/ops.py | 50 ++++++++++++- src/optree.cpp | 35 +++++++-- src/registry.cpp | 33 +++++++++ src/treespec/constructor.cpp | 5 +- src/treespec/flatten.cpp | 56 +++++++------- src/treespec/traversal.cpp | 139 ++++++++++++++++++++++++++++++++++- src/treespec/treespec.cpp | 31 -------- tests/test_ops.py | 20 +++++ 14 files changed, 383 insertions(+), 99 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 0ceaf150..ec27abc4 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -13,6 +13,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added +- Add `tree_iter` function by [@XuehaiPan](https://github.com/XuehaiPan) in [#130](https://github.com/metaopt/optree/pull/130). - Add API to unregister node type in the registry by [@XuehaiPan](https://github.com/XuehaiPan) in [#124](https://github.com/metaopt/optree/pull/124). - Add tree map functions with transposed outputs `tree_transpose_map` and `tree_transpose_map_with_path` by [@XuehaiPan](https://github.com/XuehaiPan) in [#127](https://github.com/metaopt/optree/pull/127). - Add static constructors to create `PyTreeSpec` instances by [@XuehaiPan](https://github.com/XuehaiPan) in [#120](https://github.com/metaopt/optree/pull/120). diff --git a/docs/source/ops.rst b/docs/source/ops.rst index e9c2902d..d70bfe35 100644 --- a/docs/source/ops.rst +++ b/docs/source/ops.rst @@ -26,6 +26,7 @@ Tree Manipulation Functions tree_flatten tree_flatten_with_path tree_unflatten + tree_iter tree_leaves tree_structure tree_paths @@ -51,6 +52,7 @@ Tree Manipulation Functions .. autofunction:: tree_flatten .. autofunction:: tree_flatten_with_path .. autofunction:: tree_unflatten +.. autofunction:: tree_iter .. autofunction:: tree_leaves .. autofunction:: tree_structure .. autofunction:: tree_paths diff --git a/include/registry.h b/include/registry.h index aefbec6e..62a5be35 100644 --- a/include/registry.h +++ b/include/registry.h @@ -88,6 +88,12 @@ class PyTreeTypeRegistry { template static RegistrationPtr Lookup(const py::object &cls, const std::string ®istry_namespace); + // Compute the node kind of a given Python object. + template + static PyTreeKind GetKind(const py::handle &handle, + RegistrationPtr &custom, // NOLINT[runtime/references] + const std::string ®istry_namespace); + private: template static PyTreeTypeRegistry *Singleton(); diff --git a/include/treespec.h b/include/treespec.h index 74c2172e..c9b6e121 100644 --- a/include/treespec.h +++ b/include/treespec.h @@ -25,7 +25,7 @@ limitations under the License. #include // std::thread::id // NOLINT[build/c++11] #include // std::tuple #include // std::unordered_set -#include // std::pair +#include // std::pair, std::make_pair #include // std::vector #include "include/registry.h" @@ -40,6 +40,28 @@ using ssize_t = py::ssize_t; // The maximum depth of a pytree. constexpr ssize_t MAX_RECURSION_DEPTH = 2000; +// Test whether the given object is a leaf node. +bool IsLeaf(const py::object &object, + const std::optional &leaf_predicate, + const bool &none_is_leaf = false, + const std::string ®istry_namespace = ""); + +// Test whether all elements in the given iterable are all leaves. +bool AllLeaves(const py::iterable &iterable, + const std::optional &leaf_predicate, + const bool &none_is_leaf = false, + const std::string ®istry_namespace = ""); + +template +bool IsLeafImpl(const py::handle &handle, + const std::optional &leaf_predicate, + const std::string ®istry_namespace); + +template +bool AllLeavesImpl(const py::iterable &iterable, + const std::optional &leaf_predicate, + const std::string ®istry_namespace); + // A PyTreeSpec describes the tree structure of a PyTree. A PyTree is a tree of Python values, where // the interior nodes are tuples, lists, dictionaries, or user-defined containers, and the leaves // are other objects. @@ -164,18 +186,6 @@ class PyTreeSpec { const bool &none_is_leaf = false, const std::string ®istry_namespace = ""); - // Test whether the given object is a leaf node. - static bool ObjectIsLeaf(const py::object &object, - const std::optional &leaf_predicate, - const bool &none_is_leaf = false, - const std::string ®istry_namespace = ""); - - // Test whether all elements in the given iterable are all leaves. - static bool AllLeaves(const py::iterable &iterable, - const std::optional &leaf_predicate, - const bool &none_is_leaf = false, - const std::string ®istry_namespace = ""); - private: using RegistrationPtr = PyTreeTypeRegistry::RegistrationPtr; @@ -232,12 +242,6 @@ class PyTreeSpec { const py::object *children, const size_t &num_children); - // Compute the node kind of a given Python object. - template - static PyTreeKind GetKind(const py::handle &handle, - RegistrationPtr &custom, // NOLINT[runtime/references] - const std::string ®istry_namespace); - // Recursive helper used to implement Flatten(). bool FlattenInto(const py::handle &handle, std::vector &leaves, // NOLINT[runtime/references] @@ -296,16 +300,6 @@ class PyTreeSpec { static std::unique_ptr MakeFromCollectionImpl(const py::handle &handle, std::string registry_namespace); - template - static bool ObjectIsLeafImpl(const py::handle &handle, - const std::optional &leaf_predicate, - const std::string ®istry_namespace); - - template - static bool AllLeavesImpl(const py::iterable &iterable, - const std::optional &leaf_predicate, - const std::string ®istry_namespace); - class ThreadIndentTypeHash { public: using is_transparent = void; @@ -323,4 +317,41 @@ class PyTreeSpec { sm_hash_running{}; }; +class PyTreeIter { + public: + PyTreeIter(const py::object &tree, + const std::optional &leaf_predicate, + bool none_is_leaf, + std::string registry_namespace) + : m_agenda({std::make_pair(tree, 0)}), + m_leaf_predicate(leaf_predicate), + m_none_is_leaf(none_is_leaf), + m_namespace(std::move(registry_namespace)){}; + + PyTreeIter() = delete; + + ~PyTreeIter() = default; + + PyTreeIter(const PyTreeIter &) = delete; + + PyTreeIter operator=(const PyTreeIter &) = delete; + + PyTreeIter(PyTreeIter &&) = default; + + PyTreeIter &operator=(PyTreeIter &&) = default; + + [[nodiscard]] PyTreeIter &Iter() { return *this; } + + [[nodiscard]] py::object Next(); + + private: + std::vector> m_agenda; + std::optional m_leaf_predicate; + bool m_none_is_leaf; + std::string m_namespace; + + template + [[nodiscard]] py::object NextImpl(); +}; + } // namespace optree diff --git a/optree/_C.pyi b/optree/_C.pyi index 205175a3..cce76a91 100644 --- a/optree/_C.pyi +++ b/optree/_C.pyi @@ -17,7 +17,7 @@ import builtins import enum -from collections.abc import Callable, Iterable +from collections.abc import Callable, Iterable, Iterator from typing import Any from optree.typing import CustomTreeNode, FlattenFunc, MetaData, PyTree, T, U, UnflattenFunc @@ -123,6 +123,17 @@ class PyTreeSpec: def __hash__(self) -> int: ... def __len__(self) -> int: ... +class PyTreeIter(Iterator[T]): + def __init__( + self, + tree: PyTree[T], + leaf_predicate: Callable[[T], bool] | None = None, + node_is_leaf: bool = False, + namespace: str = '', + ) -> None: ... + def __iter__(self) -> PyTreeIter[T]: ... + def __next__(self) -> T: ... + def register_node( cls: type[CustomTreeNode[T]], flatten_func: FlattenFunc, diff --git a/optree/__init__.py b/optree/__init__.py index 9f5fb82e..0965f59c 100644 --- a/optree/__init__.py +++ b/optree/__init__.py @@ -33,6 +33,7 @@ tree_flatten_one_level, tree_flatten_with_path, tree_is_leaf, + tree_iter, tree_leaves, tree_map, tree_map_, @@ -106,6 +107,7 @@ 'tree_flatten', 'tree_flatten_with_path', 'tree_unflatten', + 'tree_iter', 'tree_leaves', 'tree_structure', 'tree_paths', diff --git a/optree/ops.py b/optree/ops.py index 025fe8f0..28b9ecef 100644 --- a/optree/ops.py +++ b/optree/ops.py @@ -59,6 +59,7 @@ 'tree_flatten', 'tree_flatten_with_path', 'tree_unflatten', + 'tree_iter', 'tree_leaves', 'tree_structure', 'tree_paths', @@ -129,7 +130,7 @@ def tree_flatten( ) -> tuple[list[T], PyTreeSpec]: """Flatten a pytree. - See also :func:`tree_flatten_with_path`. + See also :func:`tree_flatten_with_path` and :func:`tree_unflatten`. The flattening order (i.e., the order of elements in the output list) is deterministic, corresponding to a left-to-right depth-first tree traversal. @@ -283,6 +284,47 @@ def tree_unflatten(treespec: PyTreeSpec, leaves: Iterable[T]) -> PyTree[T]: return treespec.unflatten(leaves) +def tree_iter( + tree: PyTree[T], + is_leaf: Callable[[T], bool] | None = None, + *, + none_is_leaf: bool = False, + namespace: str = '', +) -> Iterable[T]: + """Get an iterator over the leaves of a pytree. + + See also :func:`tree_flatten` and :func:`tree_leaves`. + + >>> tree = {'b': (2, [3, 4]), 'a': 1, 'c': None, 'd': 5} + >>> list(tree_iter(tree)) + [1, 2, 3, 4, 5] + >>> list(tree_iter(tree, none_is_leaf=True)) + [1, 2, 3, 4, None, 5] + >>> list(tree_iter(1)) + [1] + >>> list(tree_iter(None)) + [] + >>> list(tree_iter(None, none_is_leaf=True)) + [None] + + Args: + tree (pytree): A pytree to iterate over. + is_leaf (callable, optional): An optionally specified function that will be called at each + flattening step. It should return a boolean, with :data:`True` stopping the traversal + and the whole subtree being treated as a leaf, and :data:`False` indicating the + flattening should traverse the current object. + none_is_leaf (bool, optional): Whether to treat :data:`None` as a leaf. If :data:`False`, + :data:`None` is a non-leaf node with arity 0. Thus :data:`None` is contained in the + treespec rather than in the leaves list. (default: :data:`False`) + namespace (str, optional): The registry namespace used for custom pytree node types. + (default: :const:`''`, i.e., the global namespace) + + Returns: + An iterator over the leaf values. + """ + return _C.PyTreeIter(tree, is_leaf, none_is_leaf, namespace) + + def tree_leaves( tree: PyTree[T], is_leaf: Callable[[T], bool] | None = None, @@ -292,7 +334,7 @@ def tree_leaves( ) -> list[T]: """Get the leaves of a pytree. - See also :func:`tree_flatten`. + See also :func:`tree_flatten` and :func:`tree_iter`. >>> tree = {'b': (2, [3, 4]), 'a': 1, 'c': None, 'd': 5} >>> tree_leaves(tree) @@ -1827,7 +1869,7 @@ def tree_all( Otherwise, :data:`False`. """ return all( - tree_leaves( + tree_iter( tree, # type: ignore[arg-type] is_leaf=is_leaf, # type: ignore[arg-type] none_is_leaf=none_is_leaf, @@ -1878,7 +1920,7 @@ def tree_any( empty, return :data:`False`. """ return any( - tree_leaves( + tree_iter( tree, # type: ignore[arg-type] is_leaf=is_leaf, # type: ignore[arg-type] none_is_leaf=none_is_leaf, diff --git a/src/optree.cpp b/src/optree.cpp index 03a858ed..78ecdb3e 100644 --- a/src/optree.cpp +++ b/src/optree.cpp @@ -19,6 +19,7 @@ limitations under the License. #include #include // std::nullopt +#include // std::string #include "include/exceptions.h" #include "include/registry.h" @@ -67,14 +68,14 @@ void BuildModule(py::module_& mod) { // NOLINT[runtime/references] py::arg("none_is_leaf") = false, py::arg("namespace") = "") .def("is_leaf", - &PyTreeSpec::ObjectIsLeaf, + &IsLeaf, "Test whether the given object is a leaf node.", py::arg("obj"), py::arg("leaf_predicate") = std::nullopt, py::arg("none_is_leaf") = false, py::arg("namespace") = "") .def("all_leaves", - &PyTreeSpec::AllLeaves, + &AllLeaves, "Test whether all elements in the given iterable are all leaves.", py::arg("iterable"), py::arg("leaf_predicate") = std::nullopt, @@ -259,22 +260,44 @@ void BuildModule(py::module_& mod) { // NOLINT[runtime/references] "Serialization support for PyTreeSpec.", py::arg("state")); + auto PyTreeIterTypeObject = + py::class_(mod, "PyTreeIter", "Iterator over the leaves of a pytree."); + reinterpret_cast(PyTreeIterTypeObject.ptr())->tp_name = "optree.PyTreeIter"; + py::setattr(PyTreeIterTypeObject.ptr(), Py_Get_ID(__module__), Py_Get_ID(optree)); + + PyTreeIterTypeObject + .def(py::init, bool, std::string>(), + "Create a new iterator over the leaves of a pytree.", + py::arg("tree"), + py::arg("leaf_predicate") = std::nullopt, + py::arg("none_is_leaf") = false, + py::arg("namespace") = "") + .def("__iter__", &PyTreeIter::Iter, "Return the iterator object itself.") + .def("__next__", &PyTreeIter::Next, "Return the next leaf in the pytree."); + #ifdef Py_TPFLAGS_IMMUTABLETYPE + reinterpret_cast(PyTreeKindTypeObject.ptr())->tp_flags |= + Py_TPFLAGS_IMMUTABLETYPE; reinterpret_cast(PyTreeSpecTypeObject.ptr())->tp_flags |= Py_TPFLAGS_IMMUTABLETYPE; - reinterpret_cast(PyTreeKindTypeObject.ptr())->tp_flags |= + reinterpret_cast(PyTreeIterTypeObject.ptr())->tp_flags |= Py_TPFLAGS_IMMUTABLETYPE; - reinterpret_cast(PyTreeSpecTypeObject.ptr())->tp_flags &= ~Py_TPFLAGS_READY; reinterpret_cast(PyTreeKindTypeObject.ptr())->tp_flags &= ~Py_TPFLAGS_READY; + reinterpret_cast(PyTreeSpecTypeObject.ptr())->tp_flags &= ~Py_TPFLAGS_READY; + reinterpret_cast(PyTreeIterTypeObject.ptr())->tp_flags &= ~Py_TPFLAGS_READY; #endif + if (PyType_Ready(reinterpret_cast(PyTreeKindTypeObject.ptr())) < 0) + [[unlikely]] { + INTERNAL_ERROR("`PyType_Ready(&PyTreeKind_Type)` failed."); + } if (PyType_Ready(reinterpret_cast(PyTreeSpecTypeObject.ptr())) < 0) [[unlikely]] { INTERNAL_ERROR("`PyType_Ready(&PyTreeSpec_Type)` failed."); } - if (PyType_Ready(reinterpret_cast(PyTreeKindTypeObject.ptr())) < 0) + if (PyType_Ready(reinterpret_cast(PyTreeIterTypeObject.ptr())) < 0) [[unlikely]] { - INTERNAL_ERROR("`PyType_Ready(&PyTreeKind_Type)` failed."); + INTERNAL_ERROR("`PyType_Ready(&PyTreeIter_Type)` failed."); } } diff --git a/src/registry.cpp b/src/registry.cpp index d3c99fe7..cf764426 100644 --- a/src/registry.cpp +++ b/src/registry.cpp @@ -233,6 +233,39 @@ template PyTreeTypeRegistry::RegistrationPtr PyTreeTypeRegistry::Lookup( const py::object&, const std::string&); +template +/*static*/ PyTreeKind PyTreeTypeRegistry::GetKind( + const py::handle& handle, + PyTreeTypeRegistry::RegistrationPtr& custom, // NOLINT[runtime/references] + const std::string& registry_namespace) { + RegistrationPtr registration = Lookup(py::type::of(handle), registry_namespace); + if (registration) [[likely]] { + if (registration->kind == PyTreeKind::Custom) [[unlikely]] { + custom = registration; + } else [[likely]] { + custom = nullptr; + } + return registration->kind; + } + custom = nullptr; + if (IsStructSequenceInstance(handle)) [[unlikely]] { + return PyTreeKind::StructSequence; + } + if (IsNamedTupleInstance(handle)) [[unlikely]] { + return PyTreeKind::NamedTuple; + } + return PyTreeKind::Leaf; +} + +template PyTreeKind PyTreeTypeRegistry::GetKind( + const py::handle&, + PyTreeTypeRegistry::RegistrationPtr& custom, // NOLINT[runtime/references] + const std::string&); +template PyTreeKind PyTreeTypeRegistry::GetKind( + const py::handle&, + PyTreeTypeRegistry::RegistrationPtr& custom, // NOLINT[runtime/references] + const std::string&); + size_t PyTreeTypeRegistry::TypeHash::operator()(const py::object& t) const { return std::hash{}(t.ptr()); } diff --git a/src/treespec/constructor.cpp b/src/treespec/constructor.cpp index c9f95c18..53be2196 100644 --- a/src/treespec/constructor.cpp +++ b/src/treespec/constructor.cpp @@ -72,7 +72,7 @@ template auto treespecs = reserved_vector(4); Node node; - node.kind = GetKind(handle, node.custom, registry_namespace); + node.kind = PyTreeTypeRegistry::GetKind(handle, node.custom, registry_namespace); auto verify_children = [&handle, &node](const std::vector& children, std::vector& treespecs, @@ -135,7 +135,8 @@ template break; } INTERNAL_ERROR( - "NoneIsLeaf is true, but PyTreeSpec::GetKind() returned `PyTreeKind::None`."); + "NoneIsLeaf is true, but PyTreeTypeRegistry::GetKind() returned " + "`PyTreeKind::None`."); } case PyTreeKind::Tuple: { diff --git a/src/treespec/flatten.cpp b/src/treespec/flatten.cpp index e89669a9..e29964b9 100644 --- a/src/treespec/flatten.cpp +++ b/src/treespec/flatten.cpp @@ -51,7 +51,8 @@ bool PyTreeSpec::FlattenIntoImpl(const py::handle& handle, if (leaf_predicate && (*leaf_predicate)(handle).cast()) [[unlikely]] { leaves.emplace_back(py::reinterpret_borrow(handle)); } else [[likely]] { - node.kind = GetKind(handle, node.custom, registry_namespace); + node.kind = + PyTreeTypeRegistry::GetKind(handle, node.custom, registry_namespace); // NOLINTNEXTLINE[misc-no-recursion] auto recurse = [this, &found_custom, &leaf_predicate, ®istry_namespace, &leaves, &depth]( const py::handle& child) -> void { @@ -69,7 +70,8 @@ bool PyTreeSpec::FlattenIntoImpl(const py::handle& handle, break; } INTERNAL_ERROR( - "NoneIsLeaf is true, but PyTreeSpec::GetKind() returned `PyTreeKind::None`."); + "NoneIsLeaf is true, but PyTreeTypeRegistry::GetKind() returned " + "`PyTreeKind::None`."); } case PyTreeKind::Tuple: { @@ -230,7 +232,8 @@ bool PyTreeSpec::FlattenIntoWithPathImpl(const py::handle& handle, leaves.emplace_back(py::reinterpret_borrow(handle)); paths.emplace_back(std::move(path)); } else [[likely]] { - node.kind = GetKind(handle, node.custom, registry_namespace); + node.kind = + PyTreeTypeRegistry::GetKind(handle, node.custom, registry_namespace); // NOLINTNEXTLINE[misc-no-recursion] auto recurse = [this, &found_custom, @@ -261,7 +264,8 @@ bool PyTreeSpec::FlattenIntoWithPathImpl(const py::handle& handle, break; } INTERNAL_ERROR( - "NoneIsLeaf is true, but PyTreeSpec::GetKind() returned PyTreeKind::None`."); + "NoneIsLeaf is true, but PyTreeTypeRegistry::GetKind() returned " + "PyTreeKind::None`."); } case PyTreeKind::Tuple: { @@ -448,7 +452,7 @@ py::list PyTreeSpec::FlattenUpTo(const py::object& full_tree) const { case PyTreeKind::None: { if (m_none_is_leaf) [[unlikely]] { INTERNAL_ERROR( - "NoneIsLeaf is true, but PyTreeSpec::GetKind() returned " + "NoneIsLeaf is true, but PyTreeTypeRegistry::GetKind() returned " "`PyTreeKind::None`."); } if (!object.is_none()) [[likely]] { @@ -651,45 +655,47 @@ py::list PyTreeSpec::FlattenUpTo(const py::object& full_tree) const { } template -/*static*/ bool PyTreeSpec::ObjectIsLeafImpl(const py::handle& handle, - const std::optional& leaf_predicate, - const std::string& registry_namespace) { - RegistrationPtr custom{nullptr}; +bool IsLeafImpl(const py::handle& handle, + const std::optional& leaf_predicate, + const std::string& registry_namespace) { + PyTreeTypeRegistry::RegistrationPtr custom{nullptr}; return ((leaf_predicate && (*leaf_predicate)(handle).cast()) || - (GetKind(handle, custom, registry_namespace) == PyTreeKind::Leaf)); + (PyTreeTypeRegistry::GetKind(handle, custom, registry_namespace) == + PyTreeKind::Leaf)); } -/*static*/ bool PyTreeSpec::ObjectIsLeaf(const py::object& object, - const std::optional& leaf_predicate, - const bool& none_is_leaf, - const std::string& registry_namespace) { +bool IsLeaf(const py::object& object, + const std::optional& leaf_predicate, + const bool& none_is_leaf, + const std::string& registry_namespace) { if (none_is_leaf) [[unlikely]] { - return ObjectIsLeafImpl(object, leaf_predicate, registry_namespace); + return IsLeafImpl(object, leaf_predicate, registry_namespace); } else [[likely]] { - return ObjectIsLeafImpl(object, leaf_predicate, registry_namespace); + return IsLeafImpl(object, leaf_predicate, registry_namespace); } } template -/*static*/ bool PyTreeSpec::AllLeavesImpl(const py::iterable& iterable, - const std::optional& leaf_predicate, - const std::string& registry_namespace) { - RegistrationPtr custom{nullptr}; +bool AllLeavesImpl(const py::iterable& iterable, + const std::optional& leaf_predicate, + const std::string& registry_namespace) { + PyTreeTypeRegistry::RegistrationPtr custom{nullptr}; for (const py::handle& h : iterable) { if (leaf_predicate && (*leaf_predicate)(h).cast()) [[unlikely]] { continue; } - if (GetKind(h, custom, registry_namespace) != PyTreeKind::Leaf) [[unlikely]] { + if (PyTreeTypeRegistry::GetKind(h, custom, registry_namespace) != + PyTreeKind::Leaf) [[unlikely]] { return false; } } return true; } -/*static*/ bool PyTreeSpec::AllLeaves(const py::iterable& iterable, - const std::optional& leaf_predicate, - const bool& none_is_leaf, - const std::string& registry_namespace) { +bool AllLeaves(const py::iterable& iterable, + const std::optional& leaf_predicate, + const bool& none_is_leaf, + const std::string& registry_namespace) { if (none_is_leaf) [[unlikely]] { return AllLeavesImpl(iterable, leaf_predicate, registry_namespace); } else [[likely]] { diff --git a/src/treespec/traversal.cpp b/src/treespec/traversal.cpp index 45b2a5b2..cacf9e4e 100644 --- a/src/treespec/traversal.cpp +++ b/src/treespec/traversal.cpp @@ -15,7 +15,10 @@ limitations under the License. ================================================================================ */ -#include // std::move +#include // std::ostringstream +#include // std::runtime_error +#include // std::string +#include // std::move #include "include/exceptions.h" #include "include/registry.h" @@ -24,6 +27,140 @@ limitations under the License. namespace optree { +template +// NOLINTNEXTLINE[readability-function-cognitive-complexity] +py::object PyTreeIter::NextImpl() { + while (!m_agenda.empty()) [[likely]] { + auto [object, depth] = m_agenda.back(); + m_agenda.pop_back(); + + if (depth > MAX_RECURSION_DEPTH) [[unlikely]] { + PyErr_SetString(PyExc_RecursionError, + "Maximum recursion depth exceeded during flattening the tree."); + throw py::error_already_set(); + } + + if (m_leaf_predicate && (*m_leaf_predicate)(object).cast()) [[unlikely]] { + return object; + } + + PyTreeTypeRegistry::RegistrationPtr custom{nullptr}; + PyTreeKind kind = PyTreeTypeRegistry::GetKind(object, custom, m_namespace); + + ++depth; + switch (kind) { + case PyTreeKind::Leaf: { + return object; + } + + case PyTreeKind::None: { + if (!NoneIsLeaf) { + break; + } + INTERNAL_ERROR( + "NoneIsLeaf is true, but PyTreeTypeRegistry::GetKind() returned " + "`PyTreeKind::None`."); + } + + case PyTreeKind::Tuple: { + ssize_t arity = GET_SIZE(object); + for (ssize_t i = arity - 1; i >= 0; --i) { + m_agenda.emplace_back(GET_ITEM_BORROW(object, i), depth); + } + break; + } + + case PyTreeKind::List: { + ssize_t arity = GET_SIZE(object); + for (ssize_t i = arity - 1; i >= 0; --i) { + m_agenda.emplace_back(GET_ITEM_BORROW(object, i), depth); + } + break; + } + + case PyTreeKind::Dict: + case PyTreeKind::OrderedDict: + case PyTreeKind::DefaultDict: { + auto dict = py::reinterpret_borrow(object); + py::list keys = DictKeys(dict); + if (kind != PyTreeKind::OrderedDict) [[likely]] { + TotalOrderSort(keys); + } + if (PyList_Reverse(keys.ptr()) < 0) [[unlikely]] { + throw py::error_already_set(); + } + for (const py::handle& key : keys) { + m_agenda.emplace_back(dict[key], depth); + } + break; + } + + case PyTreeKind::NamedTuple: + case PyTreeKind::StructSequence: { + auto tuple = py::reinterpret_borrow(object); + ssize_t arity = GET_SIZE(tuple); + for (ssize_t i = arity - 1; i >= 0; --i) { + m_agenda.emplace_back(GET_ITEM_BORROW(tuple, i), depth); + } + break; + } + + case PyTreeKind::Deque: { + auto list = py::cast(object); + ssize_t arity = GET_SIZE(list); + for (ssize_t i = arity - 1; i >= 0; --i) { + m_agenda.emplace_back(GET_ITEM_BORROW(list, i), depth); + } + break; + } + + case PyTreeKind::Custom: { + py::tuple out = py::cast(custom->flatten_func(object)); + const ssize_t num_out = GET_SIZE(out); + if (num_out != 2 && num_out != 3) [[unlikely]] { + std::ostringstream oss{}; + oss << "PyTree custom flatten function for type " << PyRepr(custom->type) + << " should return a 2- or 3-tuple, got " << num_out << "."; + throw std::runtime_error(oss.str()); + } + auto children = py::cast(GET_ITEM_BORROW(out, 0)); + ssize_t arity = GET_SIZE(children); + if (num_out == 3) [[likely]] { + py::object node_entries = GET_ITEM_BORROW(out, 2); + if (!node_entries.is_none()) [[likely]] { + const ssize_t num_entries = + GET_SIZE(py::cast(std::move(node_entries))); + if (num_entries != arity) [[unlikely]] { + std::ostringstream oss{}; + oss << "PyTree custom flatten function for type " + << PyRepr(custom->type) + << " returned inconsistent number of children (" << arity + << ") and number of entries (" << num_entries << ")."; + throw std::runtime_error(oss.str()); + } + } + } + for (ssize_t i = arity - 1; i >= 0; --i) { + m_agenda.emplace_back(GET_ITEM_BORROW(children, i), depth); + } + break; + } + + default: + INTERNAL_ERROR(); + } + } + + throw py::stop_iteration(); +} + +py::object PyTreeIter::Next() { + if (m_none_is_leaf) [[unlikely]] { + return NextImpl(); + } + return NextImpl(); +} + py::object PyTreeSpec::Walk(const py::function& f_node, const py::object& f_leaf, const py::iterable& leaves) const { diff --git a/src/treespec/treespec.cpp b/src/treespec/treespec.cpp index b3ea5fa9..7c0ac2ea 100644 --- a/src/treespec/treespec.cpp +++ b/src/treespec/treespec.cpp @@ -166,37 +166,6 @@ namespace optree { } } -template -/*static*/ PyTreeKind PyTreeSpec::GetKind(const py::handle& handle, - RegistrationPtr& custom, - const std::string& registry_namespace) { - RegistrationPtr registration = - PyTreeTypeRegistry::Lookup(py::type::of(handle), registry_namespace); - if (registration) [[likely]] { - if (registration->kind == PyTreeKind::Custom) [[unlikely]] { - custom = registration; - } else [[likely]] { - custom = nullptr; - } - return registration->kind; - } - custom = nullptr; - if (IsStructSequenceInstance(handle)) [[unlikely]] { - return PyTreeKind::StructSequence; - } - if (IsNamedTupleInstance(handle)) [[unlikely]] { - return PyTreeKind::NamedTuple; - } - return PyTreeKind::Leaf; -} - -template PyTreeKind PyTreeSpec::GetKind(const py::handle&, - RegistrationPtr&, - const std::string&); -template PyTreeKind PyTreeSpec::GetKind(const py::handle&, - RegistrationPtr&, - const std::string&); - // NOLINTNEXTLINE[readability-function-cognitive-complexity] /*static*/ std::tuple PyTreeSpec::BroadcastToCommonSuffixImpl( std::vector& nodes, diff --git a/tests/test_ops.py b/tests/test_ops.py index 6c56db5f..7bfe5d66 100644 --- a/tests/test_ops.py +++ b/tests/test_ops.py @@ -130,6 +130,9 @@ def test_flatten_dict_order(): assert optree.tree_leaves({'a': 1, 2: 2}) == [2, 1] assert optree.tree_leaves({'a': 1, 2: 2, 3.0: 3}) == [3, 2, 1] assert optree.tree_leaves({2: 2, 3.0: 3}) == [2, 3] + assert list(optree.tree_iter({'a': 1, 2: 2})) == [2, 1] + assert list(optree.tree_iter({'a': 1, 2: 2, 3.0: 3})) == [3, 2, 1] + assert list(optree.tree_iter({2: 2, 3.0: 3})) == [2, 3] sorted_treespec = optree.tree_structure({'a': 1, 'b': 2, 'c': {'e': 3, 'f': None, 'g': 4}}) @@ -164,6 +167,20 @@ def test_tree_unflatten_mismatch_number_of_leaves(tree, none_is_leaf, namespace) optree.tree_unflatten(treespec, (*leaves, 0)) +@parametrize( + tree=list(TREES + LEAVES), + none_is_leaf=[False, True], + namespace=['', 'undefined', 'namespace'], +) +def test_tree_iter(tree, none_is_leaf, namespace): + leaves = optree.tree_leaves(tree, none_is_leaf=none_is_leaf, namespace=namespace) + it = optree.tree_iter(tree, none_is_leaf=none_is_leaf, namespace=namespace) + assert iter(it) is it + assert list(it) == leaves + with pytest.raises(StopIteration): + next(it) + + def test_walk(): tree = {'b': 2, 'a': 1, 'c': {'f': None, 'e': 3, 'g': 4}} # tree @@ -270,7 +287,9 @@ def test_flatten_up_to_none_is_leaf(): @parametrize( leaves_fn=[ optree.tree_leaves, + lambda tree, is_leaf: list(optree.tree_iter(tree, is_leaf)), lambda tree, is_leaf: optree.tree_flatten(tree, is_leaf)[0], + lambda tree, is_leaf: optree.tree_flatten_with_path(tree, is_leaf)[1], ], ) def test_flatten_is_leaf(leaves_fn): @@ -297,6 +316,7 @@ def test_flatten_is_leaf(leaves_fn): structure_fn=[ optree.tree_structure, lambda tree, is_leaf: optree.tree_flatten(tree, is_leaf)[1], + lambda tree, is_leaf: optree.tree_flatten_with_path(tree, is_leaf)[2], ], ) def test_structure_is_leaf(structure_fn):