diff --git a/tensorflow/core/framework/embedding/embedding_var.h b/tensorflow/core/framework/embedding/embedding_var.h index 9a5b5cf9a19..b29493f2169 100644 --- a/tensorflow/core/framework/embedding/embedding_var.h +++ b/tensorflow/core/framework/embedding/embedding_var.h @@ -186,6 +186,13 @@ class EmbeddingVar : public ResourceBase { } } + Status Insert(K key, V* value) { + ValuePtr* value_ptr = nullptr; + CreateKey(key, &value_ptr, true); + LookupOrCreateEmb(value_ptr, value); + return Status::OK(); + } + Status LookupOrCreateKey(K key, ValuePtr** value_ptr) { Status s = storage_->GetOrCreate(key, value_ptr, emb_config_.total_num(storage_->GetAllocLen())); @@ -592,6 +599,34 @@ class EmbeddingVar : public ResourceBase { default_value_); } + void GetSnapshot(std::vector* key_list, + std::vector* value_list, + std::vector* version_list, + std::vector* freq_list) { + std::vector*> value_ptr_list; + storage_->GetSnapshot(key_list, &value_ptr_list); + bool is_save_freq = emb_config_.is_save_freq(); + bool is_save_version = emb_config_.is_save_version(); + for (int64 i = 0; i < key_list->size(); i++) { + V* val = value_ptr_list[i]->GetValue(emb_config_.emb_index, 0); + if (val != nullptr) { + value_list->emplace_back(val); + } else { + value_list->emplace_back(default_value_); + } + + if(is_save_version) { + int64 dump_version = value_ptr_list[i]->GetStep(); + version_list->emplace_back(dump_version); + } + + if(is_save_freq) { + int64 dump_freq = value_ptr_list[i]->GetFreq(); + freq_list->emplace_back(dump_freq); + } + } + } + mutex* mu() { return &mu_; } diff --git a/tensorflow/core/framework/embedding/eviction_manager.h b/tensorflow/core/framework/embedding/eviction_manager.h index b5a78765170..ca646c9b420 100644 --- a/tensorflow/core/framework/embedding/eviction_manager.h +++ b/tensorflow/core/framework/embedding/eviction_manager.h @@ -47,8 +47,7 @@ class EvictionManager { "EVICTION_MANAGER", 3, /*low_latency_hint=*/false)); } - ~EvictionManager() { - } + ~EvictionManager() {} TF_DISALLOW_COPY_AND_ASSIGN(EvictionManager); @@ -124,8 +123,8 @@ class EvictionManager { int64 num_of_threads_; int64 num_of_active_threads_; std::atomic_flag flag_ = ATOMIC_FLAG_INIT; - std::unique_ptr thread_pool_; std::map*, StorageItem*> storage_table_; + std::unique_ptr thread_pool_; mutex mu_; }; diff --git a/tensorflow/core/kernels/embedding_variable_ops_test.cc b/tensorflow/core/kernels/embedding_variable_ops_test.cc index eff4b77c2dc..4839c171708 100644 --- a/tensorflow/core/kernels/embedding_variable_ops_test.cc +++ b/tensorflow/core/kernels/embedding_variable_ops_test.cc @@ -1191,6 +1191,7 @@ TEST(EmbeddingVariableTest, TestLFUCache) { } TEST(EmbeddingVariableTest, TestCacheRestore) { + setenv("TF_SSDHASH_ASYNC_COMPACTION", "false", 1); int64 value_size = 4; Tensor value(DT_FLOAT, TensorShape({value_size})); test::FillValues(&value, std::vector(value_size, 9.0)); @@ -1237,8 +1238,11 @@ TEST(EmbeddingVariableTest, TestCacheRestore) { LOG(INFO) << "size:" << variable->Size(); BundleWriter writer(Env::Default(), Prefix("foo")); - DumpEmbeddingValues(variable, "var/part_0", &writer, &part_offset_tensor); - TF_ASSERT_OK(writer.Finish()); + embedding::ShrinkArgs shrink_args; + shrink_args.global_step = 1; + variable->Save("var/part_0", Prefix("foo"), &writer, shrink_args); + TF_ASSERT_OK(writer.Finish()); + variable->Unref(); auto imported_storage= embedding::StorageFactory::Create( embedding::StorageConfig(embedding::DRAM_SSDHASH, @@ -1258,6 +1262,7 @@ TEST(EmbeddingVariableTest, TestCacheRestore) { ASSERT_EQ(imported_storage->Size(0), ev_size - cache_size); ASSERT_EQ(imported_storage->Size(1), 2); + delete imported_storage; } void t1_gpu(KVInterface* hashmap) { @@ -1703,7 +1708,50 @@ TEST(EmbeddingVariableTest, TestLookupRemoveConcurrency) { for (auto &t : insert_threads) { t.join(); } - } +} + +TEST(EmbeddingVariableTest, TestInsertAndGetSnapshot) { + int value_size = 10; + Tensor value(DT_FLOAT, TensorShape({value_size})); + test::FillValues(&value, std::vector(value_size, 10.0)); + auto emb_config = EmbeddingConfig( + /*emb_index = */0, /*primary_emb_index = */0, + /*block_num = */1, /*slot_num = */0, + /*name = */"", /*steps_to_live = */0, + /*filter_freq = */0, /*max_freq = */999999, + /*l2_weight_threshold = */-1.0, /*layout = */"normal", + /*max_element_size = */0, /*false_positive_probability = */-1.0, + /*counter_type = */DT_UINT64); + auto storage = embedding::StorageFactory::Create( + embedding::StorageConfig(), cpu_allocator(), "EmbeddingVar"); + auto var = new EmbeddingVar("EmbeddingVar", + storage, + emb_config, + cpu_allocator()); + var->Init(value, 1); + float* set_value = (float*)malloc(value_size * sizeof(float)); + //Insertion + for (int i = 0; i < 100; i++) { + for (int j = 0; j < value_size; j++) { + set_value[j] = i + j; + } + var->Insert(i, set_value); + } + free(set_value); + //GetSnapshot + std::vector key_list; + std::vector value_ptr_list; + std::vector version_list; + std::vector freq_list; + var->GetSnapshot(&key_list, &value_ptr_list, + &version_list, &freq_list); + for (int i = 0; i < key_list.size(); i++) { + ASSERT_EQ(key_list[i], i); + for (int j = 0; j < value_size; j++) { + ASSERT_EQ(value_ptr_list[i][j], i + j); + } + } +} } // namespace } // namespace embedding