Skip to content

Commit

Permalink
add set_range (#3213)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #3213

X-link: facebookresearch/FBGEMM#310

add set_range

in `KVTensorWrapper`:
- add `set_range()` method to write data into the backing storage
- make the constructor parameter `snapshot_handle` optional, because it's not needed for writing to the backing storage

Reviewed By: xunnanxu

Differential Revision: D63702602

fbshipit-source-id: 3b064c1bb12bca0c85f77739458e94a2722b3749
  • Loading branch information
Yulu Jia authored and facebook-github-bot committed Oct 3, 2024
1 parent d1c40a9 commit 9a845cc
Show file tree
Hide file tree
Showing 3 changed files with 107 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -396,24 +396,26 @@ class KVTensorWrapper : public torch::jit::CustomClassHolder {
public:
explicit KVTensorWrapper(
c10::intrusive_ptr<EmbeddingRocksDBWrapper> db,
c10::intrusive_ptr<EmbeddingSnapshotHandleWrapper> snapshot_handle,
std::vector<int64_t> shape,
int64_t dtype,
int64_t row_offset)
: db_(db->impl_),
snapshot_handle_(std::move(snapshot_handle)),
shape_(std::move(shape)),
row_offset_(row_offset) {
int64_t row_offset,
std::optional<c10::intrusive_ptr<EmbeddingSnapshotHandleWrapper>>
snapshot_handle)
: db_(db->impl_), shape_(std::move(shape)), row_offset_(row_offset) {
CHECK_EQ(shape_.size(), 2) << "Only 2D emb tensors are supported";
options_ = at::TensorOptions()
.dtype(static_cast<c10::ScalarType>(dtype))
.device(at::kCPU)
.layout(at::kStrided);
if (snapshot_handle.has_value()) {
snapshot_handle_ = std::move(snapshot_handle.value());
}
}

at::Tensor narrow(int64_t dim, int64_t start, int64_t length) {
CHECK_EQ(dim, 0) << "Only narrow on dim 0 is supported";
CHECK_EQ(db_->get_max_D(), shape_[1]);
CHECK_TRUE(snapshot_handle_ != nullptr);
auto t = at::empty(c10::IntArrayRef({length, db_->get_max_D()}), options_);
db_->get_range_from_snapshot(
t, start + row_offset_, length, snapshot_handle_->handle);
Expand All @@ -422,6 +424,16 @@ class KVTensorWrapper : public torch::jit::CustomClassHolder {
return t.narrow(1, 0, shape_[1]);
}

void set_range(
int64_t dim,
const int64_t start,
const int64_t length,
const at::Tensor& weights) {
CHECK_EQ(dim, 0) << "Only set_range on dim 0 is supported";
CHECK_EQ(db_->get_max_D(), shape_[1]);
db_->set_range(weights, start + row_offset_, length);
}

c10::IntArrayRef size() {
return shape_;
}
Expand Down Expand Up @@ -537,21 +549,25 @@ static auto kv_tensor_wrapper =
.def(
torch::init<
c10::intrusive_ptr<EmbeddingRocksDBWrapper>,
c10::intrusive_ptr<EmbeddingSnapshotHandleWrapper>,
std::vector<int64_t>,
int64_t,
int64_t>(),
int64_t,
std::optional<
c10::intrusive_ptr<EmbeddingSnapshotHandleWrapper>>>(),
"",
{torch::arg("db"),
torch::arg("snapshot_handle"),
torch::arg("shape"),
torch::arg("dtype"),
torch::arg("row_offset")})
torch::arg("row_offset"),
// snapshot must be provided for reading
// not needed for writing
torch::arg("snapshot_handle") = std::nullopt})
.def(
"narrow",
&KVTensorWrapper::narrow,
"",
{torch::arg("dim"), torch::arg("start"), torch::arg("length")})
.def("set_range", &KVTensorWrapper::set_range)
.def_property("dtype_str", &KVTensorWrapper::dtype_str)
.def_property(
"shape",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -528,13 +528,21 @@ class EmbeddingRocksDB : public kv_db::EmbeddingKVDB {
const SnapshotHandle* snapshot_handle) {
const auto seq_indices =
at::arange(start, start + length, at::TensorOptions().dtype(at::kLong));
int64_t* count_ = new int64_t[1];
count_[0] = length;
const auto count = at::from_blob(count_, {1}, at::kLong);
const auto count = at::tensor({length}, at::ScalarType::Long);
folly::coro::blockingWait(
get_kv_db_async_impl(seq_indices, weights, count, snapshot_handle));
}

void set_range(
const at::Tensor& weights,
const int64_t start,
const int64_t length) {
const auto seq_indices =
at::arange(start, start + length, at::TensorOptions().dtype(at::kLong));
const auto count = at::tensor({length}, at::ScalarType::Long);
folly::coro::blockingWait(set_kv_db_async(seq_indices, weights, count));
}

int64_t get_max_D() {
return max_D_;
}
Expand Down
71 changes: 70 additions & 1 deletion fbgemm_gpu/test/tbe/ssd/kv_tensor_wrapper_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def test_read_tensor_using_wrapper_from_db(self) -> None:
# create a view tensor wrapper
snapshot = ssd_db.create_snapshot()
tensor_wrapper = torch.classes.fbgemm.KVTensorWrapper(
ssd_db, snapshot, [E, D], weights.dtype, 0
ssd_db, [E, D], weights.dtype, 0, snapshot
)
self.assertEqual(tensor_wrapper.shape, [E, D])

Expand All @@ -100,3 +100,72 @@ def test_read_tensor_using_wrapper_from_db(self) -> None:
del tensor_wrapper
del snapshot
self.assertEqual(ssd_db.get_snapshot_count(), 0)

def test_write_tensor_to_db(self) -> None:
E = int(1e4) # num total rows
D = 128 # emb dimension
N = 1000 # window size
weights_precision = SparseType.FP32
weights_dtype = weights_precision.as_dtype()

with tempfile.TemporaryDirectory() as ssd_directory:
# pyre-fixme[16]: Module `classes` has no attribute `fbgemm`.
ssd_db = torch.classes.fbgemm.EmbeddingRocksDBWrapper(
ssd_directory,
8, # num_shards
8, # num_threads
0, # ssd_memtable_flush_period,
0, # ssd_memtable_flush_offset,
4, # ssd_l0_files_per_compact,
D, # embedding_dim
0, # ssd_rate_limit_mbps,
1, # ssd_size_ratio,
8, # ssd_compaction_trigger,
536870912, # 512MB ssd_write_buffer_size,
8, # ssd_max_write_buffer_num,
-0.01, # ssd_uniform_init_lower
0.01, # ssd_uniform_init_upper
32, # row_storage_bitwidth
10 * (2**20), # block cache size
)

weights = torch.arange(N * D, dtype=weights_dtype).view(N, D)
output_weights = torch.empty_like(weights)

# no snapshot needed for writing to rocksdb
tensor_wrapper0 = torch.classes.fbgemm.KVTensorWrapper(
ssd_db, [E, D], weights.dtype, 0
)
step = N
for i in range(0, E, step):
tensor_wrapper0.set_range(0, i, step, weights)

# force waiting for set to complete
indices = torch.arange(step)
for i in range(0, E, step):
ssd_db.get(i + indices, output_weights, torch.tensor(indices.shape[0]))

# create a view tensor wrapper
snapshot = ssd_db.create_snapshot()
tensor_wrapper = torch.classes.fbgemm.KVTensorWrapper(
ssd_db, [E, D], weights.dtype, 0, snapshot
)
self.assertEqual(tensor_wrapper.shape, [E, D])

# table has a total of E rows
# load 1000 rows at a time
step = 1000
for i in range(0, E, step):
narrowed = tensor_wrapper.narrow(0, i, step)
self.assertTrue(
torch.equal(narrowed, weights),
msg=(
f"Tensor value mismatch :\n"
f"actual\n{narrowed}\n\nexpected\n{weights}"
),
)

del tensor_wrapper0
del tensor_wrapper
del snapshot
self.assertEqual(ssd_db.get_snapshot_count(), 0)

0 comments on commit 9a845cc

Please sign in to comment.