diff --git a/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_split_table_batched_embeddings.cpp b/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_split_table_batched_embeddings.cpp index 9723ddfea..73e6973c6 100644 --- a/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_split_table_batched_embeddings.cpp +++ b/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_split_table_batched_embeddings.cpp @@ -396,24 +396,26 @@ class KVTensorWrapper : public torch::jit::CustomClassHolder { public: explicit KVTensorWrapper( c10::intrusive_ptr db, - c10::intrusive_ptr snapshot_handle, std::vector shape, int64_t dtype, - int64_t row_offset) - : db_(db->impl_), - snapshot_handle_(std::move(snapshot_handle)), - shape_(std::move(shape)), - row_offset_(row_offset) { + int64_t row_offset, + std::optional> + snapshot_handle) + : db_(db->impl_), shape_(std::move(shape)), row_offset_(row_offset) { CHECK_EQ(shape_.size(), 2) << "Only 2D emb tensors are supported"; options_ = at::TensorOptions() .dtype(static_cast(dtype)) .device(at::kCPU) .layout(at::kStrided); + if (snapshot_handle.has_value()) { + snapshot_handle_ = std::move(snapshot_handle.value()); + } } at::Tensor narrow(int64_t dim, int64_t start, int64_t length) { CHECK_EQ(dim, 0) << "Only narrow on dim 0 is supported"; CHECK_EQ(db_->get_max_D(), shape_[1]); + CHECK_TRUE(snapshot_handle_ != nullptr); auto t = at::empty(c10::IntArrayRef({length, db_->get_max_D()}), options_); db_->get_range_from_snapshot( t, start + row_offset_, length, snapshot_handle_->handle); @@ -422,6 +424,16 @@ class KVTensorWrapper : public torch::jit::CustomClassHolder { return t.narrow(1, 0, shape_[1]); } + void set_range( + int64_t dim, + const int64_t start, + const int64_t length, + const at::Tensor& weights) { + CHECK_EQ(dim, 0) << "Only set_range on dim 0 is supported"; + CHECK_EQ(db_->get_max_D(), shape_[1]); + db_->set_range(weights, start + row_offset_, length); + } + c10::IntArrayRef size() { return shape_; } @@ -537,21 +549,25 @@ static auto kv_tensor_wrapper = .def( torch::init< c10::intrusive_ptr, - c10::intrusive_ptr, std::vector, int64_t, - int64_t>(), + int64_t, + std::optional< + c10::intrusive_ptr>>(), "", {torch::arg("db"), - torch::arg("snapshot_handle"), torch::arg("shape"), torch::arg("dtype"), - torch::arg("row_offset")}) + torch::arg("row_offset"), + // snapshot must be provided for reading + // not needed for writing + torch::arg("snapshot_handle") = std::nullopt}) .def( "narrow", &KVTensorWrapper::narrow, "", {torch::arg("dim"), torch::arg("start"), torch::arg("length")}) + .def("set_range", &KVTensorWrapper::set_range) .def_property("dtype_str", &KVTensorWrapper::dtype_str) .def_property( "shape", diff --git a/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_table_batched_embeddings.h b/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_table_batched_embeddings.h index 2a8f101c5..f14897854 100644 --- a/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_table_batched_embeddings.h +++ b/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_table_batched_embeddings.h @@ -528,13 +528,21 @@ class EmbeddingRocksDB : public kv_db::EmbeddingKVDB { const SnapshotHandle* snapshot_handle) { const auto seq_indices = at::arange(start, start + length, at::TensorOptions().dtype(at::kLong)); - int64_t* count_ = new int64_t[1]; - count_[0] = length; - const auto count = at::from_blob(count_, {1}, at::kLong); + const auto count = at::tensor({length}, at::ScalarType::Long); folly::coro::blockingWait( get_kv_db_async_impl(seq_indices, weights, count, snapshot_handle)); } + void set_range( + const at::Tensor& weights, + const int64_t start, + const int64_t length) { + const auto seq_indices = + at::arange(start, start + length, at::TensorOptions().dtype(at::kLong)); + const auto count = at::tensor({length}, at::ScalarType::Long); + folly::coro::blockingWait(set_kv_db_async(seq_indices, weights, count)); + } + int64_t get_max_D() { return max_D_; } diff --git a/fbgemm_gpu/test/tbe/ssd/kv_tensor_wrapper_test.py b/fbgemm_gpu/test/tbe/ssd/kv_tensor_wrapper_test.py index 1530f1e75..f1277ccac 100644 --- a/fbgemm_gpu/test/tbe/ssd/kv_tensor_wrapper_test.py +++ b/fbgemm_gpu/test/tbe/ssd/kv_tensor_wrapper_test.py @@ -76,7 +76,7 @@ def test_read_tensor_using_wrapper_from_db(self) -> None: # create a view tensor wrapper snapshot = ssd_db.create_snapshot() tensor_wrapper = torch.classes.fbgemm.KVTensorWrapper( - ssd_db, snapshot, [E, D], weights.dtype, 0 + ssd_db, [E, D], weights.dtype, 0, snapshot ) self.assertEqual(tensor_wrapper.shape, [E, D]) @@ -100,3 +100,72 @@ def test_read_tensor_using_wrapper_from_db(self) -> None: del tensor_wrapper del snapshot self.assertEqual(ssd_db.get_snapshot_count(), 0) + + def test_write_tensor_to_db(self) -> None: + E = int(1e4) # num total rows + D = 128 # emb dimension + N = 1000 # window size + weights_precision = SparseType.FP32 + weights_dtype = weights_precision.as_dtype() + + with tempfile.TemporaryDirectory() as ssd_directory: + # pyre-fixme[16]: Module `classes` has no attribute `fbgemm`. + ssd_db = torch.classes.fbgemm.EmbeddingRocksDBWrapper( + ssd_directory, + 8, # num_shards + 8, # num_threads + 0, # ssd_memtable_flush_period, + 0, # ssd_memtable_flush_offset, + 4, # ssd_l0_files_per_compact, + D, # embedding_dim + 0, # ssd_rate_limit_mbps, + 1, # ssd_size_ratio, + 8, # ssd_compaction_trigger, + 536870912, # 512MB ssd_write_buffer_size, + 8, # ssd_max_write_buffer_num, + -0.01, # ssd_uniform_init_lower + 0.01, # ssd_uniform_init_upper + 32, # row_storage_bitwidth + 10 * (2**20), # block cache size + ) + + weights = torch.arange(N * D, dtype=weights_dtype).view(N, D) + output_weights = torch.empty_like(weights) + + # no snapshot needed for writing to rocksdb + tensor_wrapper0 = torch.classes.fbgemm.KVTensorWrapper( + ssd_db, [E, D], weights.dtype, 0 + ) + step = N + for i in range(0, E, step): + tensor_wrapper0.set_range(0, i, step, weights) + + # force waiting for set to complete + indices = torch.arange(step) + for i in range(0, E, step): + ssd_db.get(i + indices, output_weights, torch.tensor(indices.shape[0])) + + # create a view tensor wrapper + snapshot = ssd_db.create_snapshot() + tensor_wrapper = torch.classes.fbgemm.KVTensorWrapper( + ssd_db, [E, D], weights.dtype, 0, snapshot + ) + self.assertEqual(tensor_wrapper.shape, [E, D]) + + # table has a total of E rows + # load 1000 rows at a time + step = 1000 + for i in range(0, E, step): + narrowed = tensor_wrapper.narrow(0, i, step) + self.assertTrue( + torch.equal(narrowed, weights), + msg=( + f"Tensor value mismatch :\n" + f"actual\n{narrowed}\n\nexpected\n{weights}" + ), + ) + + del tensor_wrapper0 + del tensor_wrapper + del snapshot + self.assertEqual(ssd_db.get_snapshot_count(), 0)