Skip to content

Commit

Permalink
[Embedding] Add GetSnapshot and Create API for EmbeddingVariable.
Browse files Browse the repository at this point in the history
Signed-off-by: lixy9474 <lxy268263@alibaba-inc.com>
  • Loading branch information
lixy9474 committed Aug 9, 2023
1 parent 4983e02 commit 6faa4b3
Show file tree
Hide file tree
Showing 3 changed files with 88 additions and 6 deletions.
35 changes: 35 additions & 0 deletions tensorflow/core/framework/embedding/embedding_var.h
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,13 @@ class EmbeddingVar : public ResourceBase {
}
}

Status Create(K key, V* value) {
ValuePtr<V>* value_ptr = nullptr;
CreateKey(key, &value_ptr, true);
LookupOrCreateEmb(value_ptr, value);
return Status::OK();
}

Status LookupOrCreateKey(K key, ValuePtr<V>** value_ptr) {
Status s = storage_->GetOrCreate(key, value_ptr,
emb_config_.total_num(storage_->GetAllocLen()));
Expand Down Expand Up @@ -592,6 +599,34 @@ class EmbeddingVar : public ResourceBase {
default_value_);
}

void GetSnapshot(std::vector<K>* key_list,
std::vector<V*>* value_list,
std::vector<int64>* version_list,
std::vector<int64>* freq_list) {
std::vector<ValuePtr<V>*> value_ptr_list;
storage_->GetSnapshot(key_list, &value_ptr_list);
bool is_save_freq = emb_config_.is_save_freq();
bool is_save_version = emb_config_.is_save_version();
for (int64 i = 0; i < key_list->size(); i++) {
V* val = value_ptr_list[i]->GetValue(emb_config_.emb_index, 0);
if (val != nullptr) {
value_list->emplace_back(val);
} else {
value_list->emplace_back(default_value_);
}

if(is_save_version) {
int64 dump_version = value_ptr_list[i]->GetStep();
version_list->emplace_back(dump_version);
}

if(is_save_freq) {
int64 dump_freq = value_ptr_list[i]->GetFreq();
freq_list->emplace_back(dump_freq);
}
}
}

mutex* mu() {
return &mu_;
}
Expand Down
5 changes: 2 additions & 3 deletions tensorflow/core/framework/embedding/eviction_manager.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,7 @@ class EvictionManager {
"EVICTION_MANAGER", 3, /*low_latency_hint=*/false));
}

~EvictionManager() {
}
~EvictionManager() {}

TF_DISALLOW_COPY_AND_ASSIGN(EvictionManager);

Expand Down Expand Up @@ -124,8 +123,8 @@ class EvictionManager {
int64 num_of_threads_;
int64 num_of_active_threads_;
std::atomic_flag flag_ = ATOMIC_FLAG_INIT;
std::unique_ptr<thread::ThreadPool> thread_pool_;
std::map<MultiTierStorage<K,V>*, StorageItem<K, V>*> storage_table_;
std::unique_ptr<thread::ThreadPool> thread_pool_;
mutex mu_;
};

Expand Down
54 changes: 51 additions & 3 deletions tensorflow/core/kernels/embedding_variable_ops_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1191,6 +1191,7 @@ TEST(EmbeddingVariableTest, TestLFUCache) {
}

TEST(EmbeddingVariableTest, TestCacheRestore) {
setenv("TF_SSDHASH_ASYNC_COMPACTION", "false", 1);
int64 value_size = 4;
Tensor value(DT_FLOAT, TensorShape({value_size}));
test::FillValues<float>(&value, std::vector<float>(value_size, 9.0));
Expand Down Expand Up @@ -1237,8 +1238,11 @@ TEST(EmbeddingVariableTest, TestCacheRestore) {
LOG(INFO) << "size:" << variable->Size();

BundleWriter writer(Env::Default(), Prefix("foo"));
DumpEmbeddingValues(variable, "var/part_0", &writer, &part_offset_tensor);
TF_ASSERT_OK(writer.Finish());
embedding::ShrinkArgs shrink_args;
shrink_args.global_step = 1;
variable->Save("var/part_0", Prefix("foo"), &writer, shrink_args);
TF_ASSERT_OK(writer.Finish());
variable->Unref();

auto imported_storage= embedding::StorageFactory::Create<int64, float>(
embedding::StorageConfig(embedding::DRAM_SSDHASH,
Expand All @@ -1258,6 +1262,7 @@ TEST(EmbeddingVariableTest, TestCacheRestore) {

ASSERT_EQ(imported_storage->Size(0), ev_size - cache_size);
ASSERT_EQ(imported_storage->Size(1), 2);
delete imported_storage;
}

void t1_gpu(KVInterface<int64, float>* hashmap) {
Expand Down Expand Up @@ -1703,7 +1708,50 @@ TEST(EmbeddingVariableTest, TestLookupRemoveConcurrency) {
for (auto &t : insert_threads) {
t.join();
}
}
}

TEST(EmbeddingVariableTest, TestCreateAndGetSnapshot) {
int value_size = 10;
Tensor value(DT_FLOAT, TensorShape({value_size}));
test::FillValues<float>(&value, std::vector<float>(value_size, 10.0));
auto emb_config = EmbeddingConfig(
/*emb_index = */0, /*primary_emb_index = */0,
/*block_num = */1, /*slot_num = */0,
/*name = */"", /*steps_to_live = */0,
/*filter_freq = */0, /*max_freq = */999999,
/*l2_weight_threshold = */-1.0, /*layout = */"normal",
/*max_element_size = */0, /*false_positive_probability = */-1.0,
/*counter_type = */DT_UINT64);
auto storage = embedding::StorageFactory::Create<int64, float>(
embedding::StorageConfig(), cpu_allocator(), "EmbeddingVar");
auto var = new EmbeddingVar<int64, float>("EmbeddingVar",
storage,
emb_config,
cpu_allocator());
var->Init(value, 1);
float* set_value = (float*)malloc(value_size * sizeof(float));
//Create
for (int i = 0; i < 100; i++) {
for (int j = 0; j < value_size; j++) {
set_value[j] = i + j;
}
var->Create(i, set_value);
}
free(set_value);
//GetSnapshot
std::vector<int64> key_list;
std::vector<float*> value_ptr_list;
std::vector<int64> version_list;
std::vector<int64> freq_list;
var->GetSnapshot(&key_list, &value_ptr_list,
&version_list, &freq_list);
for (int i = 0; i < key_list.size(); i++) {
ASSERT_EQ(key_list[i], i);
for (int j = 0; j < value_size; j++) {
ASSERT_EQ(value_ptr_list[i][j], i + j);
}
}
}

} // namespace
} // namespace embedding
Expand Down

0 comments on commit 6faa4b3

Please sign in to comment.