Skip to content

Commit

Permalink
Register fake class for fbgemm::TensorQueue before export.
Browse files Browse the repository at this point in the history
Summary: Fakify tensor queue torch bind class.

Differential Revision: D57509249
  • Loading branch information
ydwu4 authored and facebook-github-bot committed May 21, 2024
1 parent 66efb75 commit c6d7656
Showing 1 changed file with 14 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -539,6 +539,19 @@ struct TensorQueue : torch::CustomClassHolder {
return queue_.size();
}

std::tuple<
std::tuple<std::string, Tensor>,
std::tuple<std::string, std::vector<Tensor>>>
__obj_flatten__() {
std::vector<Tensor> queue_vec;
for (const auto& val : queue_) {
queue_vec.push_back(val);
}
return std::make_tuple(
std::make_tuple("init_tensor", init_tensor_),
std::make_tuple("queue", queue_vec));
}

private:
std::deque<Tensor> queue_;
std::mutex mutex_;
Expand All @@ -552,6 +565,7 @@ static auto TensorQueueRegistry =
.def("pop", &TensorQueue::pop)
.def("top", &TensorQueue::top)
.def("size", &TensorQueue::size)
.def("__obj_flatten__", &TensorQueue::__obj_flatten__)
.def_pickle(
// __getstate__
[](const c10::intrusive_ptr<TensorQueue>& self)
Expand Down

0 comments on commit c6d7656

Please sign in to comment.