From c6d76569c7bde2ac3d86d5f8f19d830b4eb2e53d Mon Sep 17 00:00:00 2001 From: Yidi Wu Date: Tue, 21 May 2024 10:36:45 -0700 Subject: [PATCH] Register fake class for fbgemm::TensorQueue before export. Summary: Fakify tensor queue torch bind class. Differential Revision: D57509249 --- .../embedding_forward_quantized_host_cpu.cpp | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/fbgemm_gpu/codegen/inference/embedding_forward_quantized_host_cpu.cpp b/fbgemm_gpu/codegen/inference/embedding_forward_quantized_host_cpu.cpp index 1bf00c5e2d..3127a4980b 100644 --- a/fbgemm_gpu/codegen/inference/embedding_forward_quantized_host_cpu.cpp +++ b/fbgemm_gpu/codegen/inference/embedding_forward_quantized_host_cpu.cpp @@ -539,6 +539,19 @@ struct TensorQueue : torch::CustomClassHolder { return queue_.size(); } + std::tuple< + std::tuple, + std::tuple>> + __obj_flatten__() { + std::vector queue_vec; + for (const auto& val : queue_) { + queue_vec.push_back(val); + } + return std::make_tuple( + std::make_tuple("init_tensor", init_tensor_), + std::make_tuple("queue", queue_vec)); + } + private: std::deque queue_; std::mutex mutex_; @@ -552,6 +565,7 @@ static auto TensorQueueRegistry = .def("pop", &TensorQueue::pop) .def("top", &TensorQueue::top) .def("size", &TensorQueue::size) + .def("__obj_flatten__", &TensorQueue::__obj_flatten__) .def_pickle( // __getstate__ [](const c10::intrusive_ptr& self)