Skip to content

Commit

Permalink
feat(src/utils): add cache to is_namedtuple and is_structseq (#121)
Browse files Browse the repository at this point in the history
  • Loading branch information
XuehaiPan authored Mar 25, 2024
1 parent 72ae407 commit b3e2abf
Show file tree
Hide file tree
Showing 13 changed files with 422 additions and 121 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Added

- Add functions `is_namedtuple_instance` and `is_structseq_instance`, and add result caches to the `is_namedtuple` and `is_structseq` checks by [@XuehaiPan](https://github.com/XuehaiPan) in [#121](https://github.com/metaopt/optree/pull/121).
- Add `tree_iter` function by [@XuehaiPan](https://github.com/XuehaiPan) in [#130](https://github.com/metaopt/optree/pull/130).
- Add API to unregister node type in the registry by [@XuehaiPan](https://github.com/XuehaiPan) in [#124](https://github.com/metaopt/optree/pull/124).
- Add tree map functions with transposed outputs `tree_transpose_map` and `tree_transpose_map_with_path` by [@XuehaiPan](https://github.com/XuehaiPan) in [#127](https://github.com/metaopt/optree/pull/127).
Expand All @@ -23,6 +24,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Changed

- Set recursion limit to 1000 for all platforms by [@XuehaiPan](https://github.com/XuehaiPan) in [#121](https://github.com/metaopt/optree/pull/121).
- Allow types to be registered in both the global namespace and custom namespaces by [@XuehaiPan](https://github.com/XuehaiPan) in [#124](https://github.com/metaopt/optree/pull/124).
- Set `treespec_is_leaf` as strict by default by [@XuehaiPan](https://github.com/XuehaiPan) in [#120](https://github.com/metaopt/optree/pull/120).
- Reorder functions for better code correspondence between C++ and Python by [@XuehaiPan](https://github.com/XuehaiPan) in [#117](https://github.com/metaopt/optree/pull/117).
Expand Down
6 changes: 6 additions & 0 deletions docs/source/typing.rst
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,11 @@ Typing Support
PyTreeTypeVar
CustomTreeNode
is_namedtuple
is_namedtuple_instance
is_namedtuple_class
namedtuple_fields
is_structseq
is_structseq_instance
is_structseq_class
structseq_fields

Expand All @@ -38,12 +40,16 @@ Typing Support

.. autofunction:: is_namedtuple

.. autofunction:: is_namedtuple_instance

.. autofunction:: is_namedtuple_class

.. autofunction:: namedtuple_fields

.. autofunction:: is_structseq

.. autofunction:: is_structseq_instance

.. autofunction:: is_structseq_class

.. autofunction:: structseq_fields
36 changes: 2 additions & 34 deletions include/registry.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ limitations under the License.
#include <unordered_set> // std::unordered_set
#include <utility> // std::pair

#include "include/utils.h"

namespace optree {

namespace py = pybind11;
Expand Down Expand Up @@ -108,40 +110,6 @@ class PyTreeTypeRegistry {
static RegistrationPtr UnregisterImpl(const py::object &cls,
const std::string &registry_namespace);

// Hash functor declaration. It is used as the hasher of the type-keyed containers
// declared further down in this class (`sm_builtins_types`, `m_registrations`).
class TypeHash {
public:
// Member type `is_transparent`.
// NOTE(review): presumably declared to opt in to heterogeneous lookup — confirm.
using is_transparent = void;
// One overload per key flavour: `py::object` and `py::handle`.
size_t operator()(const py::object &t) const;
size_t operator()(const py::handle &t) const;
};
// Equality functor declaration. It is used as the key-equality predicate of the
// type-keyed containers declared further down in this class.
class TypeEq {
public:
// Member type `is_transparent`.
// NOTE(review): presumably declared to opt in to heterogeneous lookup — confirm.
using is_transparent = void;
// All four pairings of `py::object` and `py::handle` operands are declared.
bool operator()(const py::object &a, const py::object &b) const;
bool operator()(const py::object &a, const py::handle &b) const;
bool operator()(const py::handle &a, const py::object &b) const;
bool operator()(const py::handle &a, const py::handle &b) const;
};

// Hash functor declaration for `(std::string, Python object)` pairs.
// NOTE(review): the string is presumably the registry namespace — confirm against the
// definitions, which are not visible here.
class NamedTypeHash {
public:
// Member type `is_transparent`.
// NOTE(review): presumably declared to opt in to heterogeneous lookup — confirm.
using is_transparent = void;
// One overload per flavour of the pair's second element: `py::object` and `py::handle`.
size_t operator()(const std::pair<std::string, py::object> &p) const;
size_t operator()(const std::pair<std::string, py::handle> &p) const;
};
// Equality functor declaration for `(std::string, Python object)` pairs.
// NOTE(review): the string is presumably the registry namespace — confirm against the
// definitions, which are not visible here.
class NamedTypeEq {
public:
// Member type `is_transparent`.
// NOTE(review): presumably declared to opt in to heterogeneous lookup — confirm.
using is_transparent = void;
// All four pairings of `py::object`-valued and `py::handle`-valued pairs are declared.
bool operator()(const std::pair<std::string, py::object> &a,
const std::pair<std::string, py::object> &b) const;
bool operator()(const std::pair<std::string, py::object> &a,
const std::pair<std::string, py::handle> &b) const;
bool operator()(const std::pair<std::string, py::handle> &a,
const std::pair<std::string, py::object> &b) const;
bool operator()(const std::pair<std::string, py::handle> &a,
const std::pair<std::string, py::handle> &b) const;
};

inline static std::unordered_set<py::object, TypeHash, TypeEq> sm_builtins_types{};
std::unordered_map<py::object, RegistrationPtr, TypeHash, TypeEq> m_registrations{};
std::unordered_map<std::pair<std::string, py::object>,
Expand Down
2 changes: 1 addition & 1 deletion include/treespec.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ using size_t = py::size_t;
using ssize_t = py::ssize_t;

// The maximum depth of a pytree.
constexpr ssize_t MAX_RECURSION_DEPTH = 2000;
constexpr ssize_t MAX_RECURSION_DEPTH = 1000;

// Test whether the given object is a leaf node.
bool IsLeaf(const py::object &object,
Expand Down
142 changes: 128 additions & 14 deletions include/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,17 +25,21 @@ limitations under the License.

#include <pybind11/pybind11.h>

#include <exception> // std::rethrow_exception, std::current_exception
#include <functional> // std::hash
#include <sstream> // std::ostringstream
#include <string> // std::string
#include <utility> // std::move, std::pair, std::make_pair
#include <vector> // std::vector
#include <exception> // std::rethrow_exception, std::current_exception
#include <functional> // std::hash
#include <sstream> // std::ostringstream
#include <string> // std::string
#include <unordered_map> // std::unordered_map
#include <utility> // std::move, std::pair, std::make_pair
#include <vector> // std::vector

namespace py = pybind11;
using size_t = py::size_t;
using ssize_t = py::ssize_t;

// The maximum size of the type cache.
constexpr ssize_t MAX_TYPE_CACHE_SIZE = 4096;

// boost::hash_combine
template <class T>
inline void HashCombine(py::size_t& seed, const T& v) { // NOLINT[runtime/references]
Expand All @@ -50,6 +54,58 @@ inline void HashCombine(py::ssize_t& seed, const T& v) { // NOLINT[runtime/refe
seed ^= (hasher(v) + 0x9e3779b9 + (seed << 6) + (seed >> 2));
}

// Hash functor for Python objects used as container keys.
// The hash is computed from the address returned by `ptr()`, i.e. it is an identity
// hash of the underlying `PyObject*`, not a call into Python's `__hash__`.
class TypeHash {
public:
    using is_transparent = void;

    py::size_t operator()(const py::object& t) const { return HashAddress(t.ptr()); }
    py::size_t operator()(const py::handle& t) const { return HashAddress(t.ptr()); }

private:
    // Shared implementation for both overloads: std::hash over the raw pointer.
    static py::size_t HashAddress(PyObject* address) {
        const std::hash<PyObject*> hasher{};
        return hasher(address);
    }
};
class TypeEq {
public:
using is_transparent = void;
bool operator()(const py::object& a, const py::object& b) const { return a.ptr() == b.ptr(); }
bool operator()(const py::object& a, const py::handle& b) const { return a.ptr() == b.ptr(); }
bool operator()(const py::handle& a, const py::object& b) const { return a.ptr() == b.ptr(); }
bool operator()(const py::handle& a, const py::handle& b) const { return a.ptr() == b.ptr(); }
};

// Hash functor for `(string, Python object)` pairs.
// Starting from a zero seed, the string is mixed in first and the object's
// `PyObject*` address second, both through `HashCombine`.
class NamedTypeHash {
public:
    using is_transparent = void;

    py::size_t operator()(const std::pair<std::string, py::object>& p) const {
        return Mix(p.first, p.second.ptr());
    }
    py::size_t operator()(const std::pair<std::string, py::handle>& p) const {
        return Mix(p.first, p.second.ptr());
    }

private:
    // Shared implementation for both overloads; the combine order (name, then
    // address) must stay fixed so that both overloads produce the same hash.
    static py::size_t Mix(const std::string& name, PyObject* address) {
        py::size_t accumulated = 0;
        HashCombine(accumulated, name);
        HashCombine(accumulated, address);
        return accumulated;
    }
};
class NamedTypeEq {
public:
using is_transparent = void;
bool operator()(const std::pair<std::string, py::object>& a,
const std::pair<std::string, py::object>& b) const {
return a.first == b.first && a.second.ptr() == b.second.ptr();
}
bool operator()(const std::pair<std::string, py::object>& a,
const std::pair<std::string, py::handle>& b) const {
return a.first == b.first && a.second.ptr() == b.second.ptr();
}
bool operator()(const std::pair<std::string, py::handle>& a,
const std::pair<std::string, py::object>& b) const {
return a.first == b.first && a.second.ptr() == b.second.ptr();
}
bool operator()(const std::pair<std::string, py::handle>& a,
const std::pair<std::string, py::handle>& b) const {
return a.first == b.first && a.second.ptr() == b.second.ptr();
}
};

constexpr bool NONE_IS_LEAF = true;
constexpr bool NONE_IS_NODE = false;

Expand Down Expand Up @@ -399,7 +455,25 @@ inline bool IsNamedTupleClassImpl(const py::handle& type) {
return false;
}
// Reports whether `type` is a namedtuple class, caching the answer per type object.
// NOTE(review): this listing is a commit diff with its +/- markers stripped. The `return`
// directly below looks like the pre-change one-line body shown next to its replacement —
// confirm against the real header before relying on the control flow as written here.
inline bool IsNamedTupleClass(const py::handle& type) {
return PyType_Check(type.ptr()) && IsNamedTupleClassImpl(type);
// Anything that fails `PyType_Check` is rejected without touching the cache.
if (!PyType_Check(type.ptr())) [[unlikely]] {
return false;
}

// Function-local cache mapping a type (as a `py::handle` key) to the computed result.
static auto cache = std::unordered_map<py::handle, bool, TypeHash, TypeEq>{};
auto it = cache.find(type);
if (it != cache.end()) [[likely]] {
return it->second;
}
bool result = IsNamedTupleClassImpl(type);
// New entries are only added while the cache holds fewer than MAX_TYPE_CACHE_SIZE items;
// past that limit the result is still computed and returned, just not stored.
if (cache.size() < MAX_TYPE_CACHE_SIZE) [[likely]] {
cache.emplace(type, result);
// Create a weak reference to `type` whose callback erases the cache entry for `type` and
// then calls `dec_ref()` on the weakref it receives.
// NOTE(review): `.release()` presumably keeps the weakref alive past this call so the
// callback can fire when the type is collected — confirm against pybind11 semantics.
(void)py::weakref(type, py::cpp_function([type](py::handle weakref) -> void {
cache.erase(type);
weakref.dec_ref();
}))
.release();
}
return result;
}
inline bool IsNamedTupleInstance(const py::handle& object) {
return IsNamedTupleClass(py::type::handle_of(object));
Expand Down Expand Up @@ -462,7 +536,25 @@ inline bool IsStructSequenceClassImpl(const py::handle& type) {
return false;
}
// Reports whether `type` is a struct sequence class, caching the answer per type object.
// NOTE(review): this listing is a commit diff with its +/- markers stripped. The `return`
// directly below looks like the pre-change one-line body shown next to its replacement —
// confirm against the real header before relying on the control flow as written here.
inline bool IsStructSequenceClass(const py::handle& type) {
return PyType_Check(type.ptr()) && IsStructSequenceClassImpl(type);
// Anything that fails `PyType_Check` is rejected without touching the cache.
if (!PyType_Check(type.ptr())) [[unlikely]] {
return false;
}

// Function-local cache mapping a type (as a `py::handle` key) to the computed result.
static auto cache = std::unordered_map<py::handle, bool, TypeHash, TypeEq>{};
auto it = cache.find(type);
if (it != cache.end()) [[likely]] {
return it->second;
}
bool result = IsStructSequenceClassImpl(type);
// New entries are only added while the cache holds fewer than MAX_TYPE_CACHE_SIZE items;
// past that limit the result is still computed and returned, just not stored.
if (cache.size() < MAX_TYPE_CACHE_SIZE) [[likely]] {
cache.emplace(type, result);
// Create a weak reference to `type` whose callback erases the cache entry for `type` and
// then calls `dec_ref()` on the weakref it receives.
// NOTE(review): `.release()` presumably keeps the weakref alive past this call so the
// callback can fire when the type is collected — confirm against pybind11 semantics.
(void)py::weakref(type, py::cpp_function([type](py::handle weakref) -> void {
cache.erase(type);
weakref.dec_ref();
}))
.release();
}
return result;
}
inline bool IsStructSequenceInstance(const py::handle& object) {
return IsStructSequenceClass(py::type::handle_of(object));
Expand All @@ -477,6 +569,16 @@ inline void AssertExactStructSequence(const py::handle& object) {
PyRepr(object) + ".");
}
}
// Builds a tuple holding the names of the sequence fields of the struct sequence `type`.
// The number of names is read from the type's `n_sequence_fields` attribute and the names
// themselves come from the leading entries of the type object's `tp_members` table.
// No validation of `type` is performed here; callers are expected to pass a type object.
inline py::tuple StructSequenceGetFieldsImpl(const py::handle& type) {
    const auto num_fields = getattr(type, Py_Get_ID(n_sequence_fields)).cast<ssize_t>();
    auto* type_object = reinterpret_cast<PyTypeObject*>(type.ptr());
    auto* member_table = type_object->tp_members;

    py::tuple names{num_fields};
    ssize_t index = 0;
    while (index < num_fields) {
        // NOLINTNEXTLINE[cppcoreguidelines-pro-bounds-pointer-arithmetic]
        SET_ITEM<py::tuple>(names, index, py::str(member_table[index].name));
        ++index;
    }
    return names;
}
inline py::tuple StructSequenceGetFields(const py::handle& object) {
py::handle type;
if (PyType_Check(object.ptr())) [[unlikely]] {
Expand All @@ -492,12 +594,24 @@ inline py::tuple StructSequenceGetFields(const py::handle& object) {
}
}

const auto n_sequence_fields = getattr(type, Py_Get_ID(n_sequence_fields)).cast<ssize_t>();
auto* members = reinterpret_cast<PyTypeObject*>(type.ptr())->tp_members;
py::tuple fields{n_sequence_fields};
for (ssize_t i = 0; i < n_sequence_fields; ++i) {
// NOLINTNEXTLINE[cppcoreguidelines-pro-bounds-pointer-arithmetic]
SET_ITEM<py::tuple>(fields, i, py::str(members[i].name));
static auto cache = std::unordered_map<py::handle, py::tuple, TypeHash, TypeEq>{};
auto it = cache.find(type);
if (it != cache.end()) [[likely]] {
return it->second;
}
py::tuple fields = StructSequenceGetFieldsImpl(type);
if (cache.size() < MAX_TYPE_CACHE_SIZE) [[likely]] {
cache.emplace(type, fields);
fields.inc_ref();
(void)py::weakref(type, py::cpp_function([type](py::handle weakref) -> void {
auto it = cache.find(type);
if (it != cache.end()) [[likely]] {
it->second.dec_ref();
cache.erase(it);
}
weakref.dec_ref();
}))
.release();
}
return fields;
}
Expand Down
2 changes: 2 additions & 0 deletions optree/_C.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -69,9 +69,11 @@ def all_leaves(
namespace: str = '',
) -> bool: ...
def is_namedtuple(obj: object | type) -> bool: ...
def is_namedtuple_instance(obj: object) -> bool: ...
def is_namedtuple_class(cls: type) -> bool: ...
def namedtuple_fields(obj: tuple | type[tuple]) -> tuple[str, ...]: ...
def is_structseq(obj: object | type) -> bool: ...
def is_structseq_instance(obj: object) -> bool: ...
def is_structseq_class(cls: type) -> bool: ...
def structseq_fields(obj: tuple | type[tuple]) -> tuple[str, ...]: ...

Expand Down
8 changes: 6 additions & 2 deletions optree/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,8 +91,10 @@
UnflattenFunc,
is_namedtuple,
is_namedtuple_class,
is_namedtuple_instance,
is_structseq,
is_structseq_class,
is_structseq_instance,
namedtuple_fields,
structseq_fields,
)
Expand Down Expand Up @@ -174,14 +176,16 @@
'UnflattenFunc',
'is_namedtuple',
'is_namedtuple_class',
'is_namedtuple_instance',
'namedtuple_fields',
'is_structseq',
'is_structseq_instance',
'is_structseq_class',
'structseq_fields',
]

MAX_RECURSION_DEPTH: int = MAX_RECURSION_DEPTH # 2000
"""Maximum recursion depth for pytree traversal. It is 2000.
MAX_RECURSION_DEPTH: int = MAX_RECURSION_DEPTH # 1000
"""Maximum recursion depth for pytree traversal. It is 1000.
This limit prevents infinite recursion from causing an overflow of the C stack
and crashing Python.
Expand Down
16 changes: 8 additions & 8 deletions optree/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,8 @@
S,
T,
U,
is_namedtuple_class,
is_structseq_class,
is_namedtuple_instance,
is_structseq_instance,
namedtuple_fields,
)
from optree.typing import structseq as PyStructSequence # noqa: N812
Expand Down Expand Up @@ -109,8 +109,8 @@
'prefix_errors',
]

MAX_RECURSION_DEPTH: int = _C.MAX_RECURSION_DEPTH # 2000
"""Maximum recursion depth for pytree traversal. It is 2000.
MAX_RECURSION_DEPTH: int = _C.MAX_RECURSION_DEPTH # 1000
"""Maximum recursion depth for pytree traversal. It is 1000.
This limit prevents infinite recursion from causing an overflow of the C stack
and crashing Python.
Expand Down Expand Up @@ -2418,7 +2418,7 @@ def treespec_namedtuple(
Returns:
A treespec representing a dict node with the given children.
"""
if not is_namedtuple_class(type(namedtuple)):
if not is_namedtuple_instance(namedtuple):
raise ValueError(f'Expected a namedtuple of PyTreeSpec(s), got {namedtuple!r}.')
return _C.make_from_collection(
namedtuple, # type: ignore[arg-type]
Expand Down Expand Up @@ -2595,7 +2595,7 @@ def treespec_structseq(
Returns:
A treespec representing a PyStructSequence node with the given children.
"""
if not is_structseq_class(type(structseq)):
if not is_structseq_instance(structseq):
raise ValueError(f'Expected a PyStructSequence of PyTreeSpec(s), got {structseq!r}.')
return _C.make_from_collection(
structseq, # type: ignore[arg-type]
Expand Down Expand Up @@ -2855,11 +2855,11 @@ def _child_keys(
if handler:
return list(handler(tree))

if is_structseq_class(type(tree)):
if is_structseq_instance(tree):
# Handle PyStructSequence as a special case, based on heuristic
return list(map(AttributeKeyPathEntry, structseq_fields(tree))) # type: ignore[arg-type]

if is_namedtuple_class(type(tree)):
if is_namedtuple_instance(tree):
# Handle namedtuple as a special case, based on heuristic
return list(map(AttributeKeyPathEntry, namedtuple_fields(tree))) # type: ignore[arg-type]

Expand Down
Loading

0 comments on commit b3e2abf

Please sign in to comment.