From 4cd9ed895a3b8f64909b65c04d06ea4e95f761ad Mon Sep 17 00:00:00 2001 From: lixy9474 Date: Mon, 7 Aug 2023 11:38:43 +0800 Subject: [PATCH] [Embedding] Refactor the code of Save Op for EmbeddingVariable. (#900) Signed-off-by: lixy9474 --- .../core/framework/embedding/config.proto | 10 + .../embedding/dram_leveldb_storage.h | 86 ++-- .../framework/embedding/dram_pmem_storage.h | 82 ++-- .../framework/embedding/dram_ssd_storage.h | 75 +--- .../framework/embedding/embedding_config.h | 10 + .../core/framework/embedding/embedding_var.h | 32 +- .../embedding/embedding_var_ckpt_data.h | 236 ++++++++++ .../embedding/embedding_var_dump_iterator.h | 95 ++++ .../embedding/embedding_var_restore.h | 2 - .../embedding/globalstep_shrink_policy.h | 12 +- .../framework/embedding/gpu_hash_map_kv.h | 2 - .../embedding/hbm_dram_ssd_storage.h | 84 ++-- .../framework/embedding/hbm_dram_storage.h | 98 ++-- .../embedding/hbm_storage_iterator.h | 315 +++---------- .../core/framework/embedding/kv_interface.h | 20 +- .../embedding/l2weight_shrink_policy.h | 14 +- .../core/framework/embedding/leveldb_kv.h | 114 +++-- .../framework/embedding/multi_tier_storage.h | 69 +-- .../core/framework/embedding/shrink_policy.h | 8 +- .../framework/embedding/single_tier_storage.h | 195 ++++---- .../core/framework/embedding/ssd_hash_kv.h | 37 +- .../embedding/ssd_record_descriptor.h | 148 ++++++ tensorflow/core/framework/embedding/storage.h | 117 ++++- tensorflow/core/kernels/BUILD | 6 +- .../kernels/embedding_variable_ops_test.cc | 161 +------ .../embedding_variable_performance_test.cc | 13 +- tensorflow/core/kernels/kv_variable_ops.h | 420 ------------------ ...tore_ops.cc => kv_variable_restore_ops.cc} | 80 ---- tensorflow/core/kernels/save_restore_tensor.h | 54 +-- .../core/kernels/save_restore_v2_ops.cc | 18 +- .../python/ops/embedding_variable_ops_test.py | 132 ++++-- 31 files changed, 1191 insertions(+), 1554 deletions(-) create mode 100644 tensorflow/core/framework/embedding/embedding_var_ckpt_data.h create mode 100644 tensorflow/core/framework/embedding/embedding_var_dump_iterator.h create mode 100644 tensorflow/core/framework/embedding/ssd_record_descriptor.h rename tensorflow/core/kernels/{kv_variable_save_restore_ops.cc => kv_variable_restore_ops.cc} (86%) diff --git a/tensorflow/core/framework/embedding/config.proto b/tensorflow/core/framework/embedding/config.proto index 1eef3edccc2..3d5fae9f6ad 100644 --- a/tensorflow/core/framework/embedding/config.proto +++ b/tensorflow/core/framework/embedding/config.proto @@ -46,3 +46,13 @@ enum EmbeddingVariableType { IMMUTABLE = 0; MUTABLE = 1; } + +enum ValuePtrStatus { + OK = 0; + IS_DELETED = 1; +} + +enum ValuePosition { + IN_DRAM = 0; + NOT_IN_DRAM = 1; +} diff --git a/tensorflow/core/framework/embedding/dram_leveldb_storage.h b/tensorflow/core/framework/embedding/dram_leveldb_storage.h index c6c64e14865..fdb6697d541 100644 --- a/tensorflow/core/framework/embedding/dram_leveldb_storage.h +++ b/tensorflow/core/framework/embedding/dram_leveldb_storage.h @@ -111,14 +111,6 @@ class DramLevelDBStore : public MultiTierStorage { return false; } - void iterator_mutex_lock() override { - leveldb_->get_mutex()->lock(); - } - - void iterator_mutex_unlock() override { - leveldb_->get_mutex()->unlock(); - } - int64 Size() const override { int64 total_size = dram_->Size(); total_size += leveldb_->Size(); @@ -145,46 +137,58 @@ class DramLevelDBStore : public MultiTierStorage { return -1; } - Status GetSnapshot(std::vector* key_list, - std::vector*>* value_ptr_list) override { - { - 
mutex_lock l(*(dram_->get_mutex())); - TF_CHECK_OK(dram_->GetSnapshot(key_list, value_ptr_list)); + Status Save( + const string& tensor_name, + const string& prefix, + BundleWriter* writer, + const EmbeddingConfig& emb_config, + ShrinkArgs& shrink_args, + int64 value_len, + V* default_value) override { + std::vector key_list, tmp_leveldb_key_list; + std::vector*> value_ptr_list, tmp_leveldb_value_list; + TF_CHECK_OK(dram_->GetSnapshot(&key_list, &value_ptr_list)); + + TF_CHECK_OK(leveldb_->GetSnapshot( + &tmp_leveldb_key_list, &tmp_leveldb_value_list)); + + for (int64 i = 0; i < tmp_leveldb_value_list.size(); i++) { + tmp_leveldb_value_list[i]->SetPtr((V*)ValuePosition::NOT_IN_DRAM); + tmp_leveldb_value_list[i]->SetInitialized(emb_config.primary_emb_index); } - { - mutex_lock l(*(leveldb_->get_mutex())); - TF_CHECK_OK(leveldb_->GetSnapshot(key_list, value_ptr_list)); + + std::vector leveldb_key_list; + for (int64 i = 0; i < tmp_leveldb_key_list.size(); i++) { + Status s = dram_->Contains(tmp_leveldb_key_list[i]); + if (!s.ok()) { + key_list.emplace_back(tmp_leveldb_key_list[i]); + leveldb_key_list.emplace_back(tmp_leveldb_key_list[i]); + value_ptr_list.emplace_back(tmp_leveldb_value_list[i]); + } } - return Status::OK(); - } - Status Shrink(const ShrinkArgs& shrink_args) override { - dram_->Shrink(shrink_args); - leveldb_->Shrink(shrink_args); - return Status::OK(); - } + ValueIterator* value_iter = + leveldb_->GetValueIterator( + leveldb_key_list, emb_config.emb_index, value_len); - int64 GetSnapshot(std::vector* key_list, - std::vector* value_list, - std::vector* version_list, - std::vector* freq_list, - const EmbeddingConfig& emb_config, - FilterPolicy>* filter, - embedding::Iterator** it) override { - { - mutex_lock l(*(dram_->get_mutex())); - std::vector*> value_ptr_list; - std::vector key_list_tmp; - TF_CHECK_OK(dram_->GetSnapshot(&key_list_tmp, &value_ptr_list)); - MultiTierStorage::SetListsForCheckpoint( - key_list_tmp, value_ptr_list, emb_config, - key_list, value_list, version_list, freq_list); - } { mutex_lock l(*(leveldb_->get_mutex())); - *it = leveldb_->GetIterator(); + TF_CHECK_OK((Storage::SaveToCheckpoint( + tensor_name, writer, + emb_config, + value_len, default_value, + key_list, + value_ptr_list, + value_iter))); } - return key_list->size(); + + for (auto it: tmp_leveldb_value_list) { + delete it; + } + + delete value_iter; + + return Status::OK(); } Status Eviction(K* evict_ids, int64 evict_size) override { diff --git a/tensorflow/core/framework/embedding/dram_pmem_storage.h b/tensorflow/core/framework/embedding/dram_pmem_storage.h index 47b6115e801..fd19f75ab4c 100644 --- a/tensorflow/core/framework/embedding/dram_pmem_storage.h +++ b/tensorflow/core/framework/embedding/dram_pmem_storage.h @@ -150,59 +150,41 @@ class DramPmemStorage : public MultiTierStorage { return -1; } - Status GetSnapshot(std::vector* key_list, - std::vector* >* value_ptr_list) override { - { - mutex_lock l(*(dram_->get_mutex())); - TF_CHECK_OK(dram_->GetSnapshot(key_list, value_ptr_list)); - } - { - mutex_lock l(*(pmem_->get_mutex())); - TF_CHECK_OK(pmem_->GetSnapshot(key_list, value_ptr_list)); + Status Save( + const string& tensor_name, + const string& prefix, + BundleWriter* writer, + const EmbeddingConfig& emb_config, + ShrinkArgs& shrink_args, + int64 value_len, + V* default_value) override { + std::vector key_list, tmp_pmem_key_list; + std::vector*> value_ptr_list, tmp_pmem_value_list; + + TF_CHECK_OK(dram_->GetSnapshot(&key_list, &value_ptr_list)); + dram_->Shrink(key_list, 
value_ptr_list, shrink_args, value_len); + + TF_CHECK_OK(pmem_->GetSnapshot(&tmp_pmem_key_list, + &tmp_pmem_value_list)); + pmem_->Shrink(tmp_pmem_key_list, tmp_pmem_value_list, + shrink_args, value_len); + + for (int64 i = 0; i < tmp_pmem_key_list.size(); i++) { + Status s = dram_->Contains(tmp_pmem_key_list[i]); + if (!s.ok()) { + key_list.emplace_back(tmp_pmem_key_list[i]); + value_ptr_list.emplace_back(tmp_pmem_value_list[i]); + } } - return Status::OK(); - } - Status Shrink(const ShrinkArgs& shrink_args) override { - dram_->Shrink(shrink_args); - pmem_->Shrink(shrink_args); - return Status::OK(); - } + TF_CHECK_OK((Storage::SaveToCheckpoint( + tensor_name, writer, + emb_config, + value_len, default_value, + key_list, + value_ptr_list))); - void iterator_mutex_lock() override { - return; - } - - void iterator_mutex_unlock() override { - return; - } - - int64 GetSnapshot(std::vector* key_list, - std::vector* value_list, - std::vector* version_list, - std::vector* freq_list, - const EmbeddingConfig& emb_config, - FilterPolicy>* filter, - embedding::Iterator** it) override { - { - mutex_lock l(*(dram_->get_mutex())); - std::vector*> value_ptr_list; - std::vector key_list_tmp; - TF_CHECK_OK(dram_->GetSnapshot(&key_list_tmp, &value_ptr_list)); - MultiTierStorage::SetListsForCheckpoint( - key_list_tmp, value_ptr_list, emb_config, - key_list, value_list, version_list, freq_list); - } - { - mutex_lock l(*(pmem_->get_mutex())); - std::vector*> value_ptr_list; - std::vector key_list_tmp; - TF_CHECK_OK(pmem_->GetSnapshot(&key_list_tmp, &value_ptr_list)); - MultiTierStorage::SetListsForCheckpoint( - key_list_tmp, value_ptr_list, emb_config, - key_list, value_list, version_list, freq_list); - } - return key_list->size(); + return Status::OK(); } Status Eviction(K* evict_ids, int64 evict_size) override { diff --git a/tensorflow/core/framework/embedding/dram_ssd_storage.h b/tensorflow/core/framework/embedding/dram_ssd_storage.h index 675395c667d..4243cc14eb3 100644 --- a/tensorflow/core/framework/embedding/dram_ssd_storage.h +++ b/tensorflow/core/framework/embedding/dram_ssd_storage.h @@ -144,70 +144,21 @@ class DramSsdHashStorage : public MultiTierStorage { return true; } - Status GetSnapshot(std::vector* key_list, - std::vector*>* value_ptr_list) override { - { - mutex_lock l(*(dram_->get_mutex())); - TF_CHECK_OK(dram_->GetSnapshot(key_list, value_ptr_list)); - } - { - mutex_lock l(*(ssd_hash_->get_mutex())); - TF_CHECK_OK(ssd_hash_->GetSnapshot(key_list, value_ptr_list)); - } - return Status::OK(); - } - - Status Shrink(const ShrinkArgs& shrink_args) override { - dram_->Shrink(shrink_args); - ssd_hash_->Shrink(shrink_args); - return Status::OK(); - } - - int64 GetSnapshot(std::vector* key_list, - std::vector* value_list, - std::vector* version_list, - std::vector* freq_list, + Status Save( + const string& tensor_name, + const string& prefix, + BundleWriter* writer, const EmbeddingConfig& emb_config, - FilterPolicy>* filter, - embedding::Iterator** it) override { - { - mutex_lock l(*(dram_->get_mutex())); - std::vector*> value_ptr_list; - std::vector key_list_tmp; - TF_CHECK_OK(dram_->GetSnapshot(&key_list_tmp, &value_ptr_list)); - MultiTierStorage::SetListsForCheckpoint( - key_list_tmp, value_ptr_list, emb_config, - key_list, value_list, version_list, freq_list); - } - { - mutex_lock l(*(ssd_hash_->get_mutex())); - *it = ssd_hash_->GetIterator(); - } - return key_list->size(); - } + ShrinkArgs& shrink_args, + int64 value_len, + V* default_value) override { + dram_->Save(tensor_name, prefix, 
writer, emb_config, + shrink_args, value_len, default_value); - int64 GetSnapshotWithoutFetchPersistentEmb( - std::vector* key_list, - std::vector* value_list, - std::vector* version_list, - std::vector* freq_list, - const EmbeddingConfig& emb_config, - SsdRecordDescriptor* ssd_rec_desc) override { - { - mutex_lock l(*(dram_->get_mutex())); - std::vector*> value_ptr_list; - std::vector temp_key_list; - TF_CHECK_OK(dram_->GetSnapshot(&temp_key_list, &value_ptr_list)); - MultiTierStorage::SetListsForCheckpoint( - temp_key_list, value_ptr_list, emb_config, - key_list, value_list, version_list, - freq_list); - } - { - mutex_lock l(*(ssd_hash_->get_mutex())); - ssd_hash_->SetSsdRecordDescriptor(ssd_rec_desc); - } - return key_list->size() + ssd_rec_desc->key_list.size(); + ssd_hash_->Save(tensor_name, prefix, writer, emb_config, + shrink_args, value_len, default_value); + + return Status::OK(); } Status RestoreSSD(int64 emb_index, int64 emb_slot_num, int64 value_len, diff --git a/tensorflow/core/framework/embedding/embedding_config.h b/tensorflow/core/framework/embedding/embedding_config.h index 3aaa259c3f2..0a50b492159 100644 --- a/tensorflow/core/framework/embedding/embedding_config.h +++ b/tensorflow/core/framework/embedding/embedding_config.h @@ -101,6 +101,16 @@ struct EmbeddingConfig { return emb_index == primary_emb_index; } + bool is_save_freq() const { + return filter_freq != 0 || + record_freq || + normal_fix_flag == 1; + } + + bool is_save_version() const { + return steps_to_live != 0 || record_version; + } + int64 total_num(int alloc_len) { return block_num * (1 + (1 - normal_fix_flag) * slot_num) * diff --git a/tensorflow/core/framework/embedding/embedding_var.h b/tensorflow/core/framework/embedding/embedding_var.h index 534ebf68950..9a5b5cf9a19 100644 --- a/tensorflow/core/framework/embedding/embedding_var.h +++ b/tensorflow/core/framework/embedding/embedding_var.h @@ -582,30 +582,14 @@ class EmbeddingVar : public ResourceBase { emb_config_, device, reader, this, filter_); } - int64 GetSnapshot(std::vector* key_list, - std::vector* value_list, - std::vector* version_list, - std::vector* freq_list, - embedding::Iterator** it = nullptr) { - // for Interface Compatible - // TODO Multi-tiered Embedding should use iterator in 'GetSnapshot' caller - embedding::Iterator* _it = nullptr; - it = (it == nullptr) ? &_it : it; - return storage_->GetSnapshot( - key_list, value_list, version_list, - freq_list, emb_config_, filter_, it); - } - - int64 GetSnapshotWithoutFetchPersistentEmb( - std::vector* key_list, - std::vector* value_list, - std::vector* version_list, - std::vector* freq_list, - SsdRecordDescriptor* ssd_rec_desc) { - return storage_-> - GetSnapshotWithoutFetchPersistentEmb( - key_list, value_list, version_list, - freq_list, emb_config_, ssd_rec_desc); + Status Save(const string& tensor_name, + const string& prefix, + BundleWriter* writer, + embedding::ShrinkArgs& shrink_args) { + return storage_->Save(tensor_name, prefix, + writer, emb_config_, + shrink_args, value_len_, + default_value_); } mutex* mu() { diff --git a/tensorflow/core/framework/embedding/embedding_var_ckpt_data.h b/tensorflow/core/framework/embedding/embedding_var_ckpt_data.h new file mode 100644 index 00000000000..aa1a08cbcfd --- /dev/null +++ b/tensorflow/core/framework/embedding/embedding_var_ckpt_data.h @@ -0,0 +1,236 @@ +/* Copyright 2022 The DeepRec Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +======================================================================*/ +#ifndef TENSORFLOW_CORE_FRAMEWORK_EMBEDDING_EMBEDDING_VAR_CKPT_DATA_ +#define TENSORFLOW_CORE_FRAMEWORK_EMBEDDING_EMBEDDING_VAR_CKPT_DATA_ +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/util/tensor_bundle/tensor_bundle.h" +#include "tensorflow/core/kernels/save_restore_tensor.h" +#include "tensorflow/core/framework/embedding/embedding_config.h" +#include "tensorflow/core/framework/embedding/embedding_var_dump_iterator.h" +namespace tensorflow { +namespace embedding { + +template +class EmbeddingVarCkptData { + public: + void Emplace(K key, ValuePtr* value_ptr, + const EmbeddingConfig& emb_config, + V* default_value, int64 value_offset, + bool is_save_freq, + bool is_save_version, + bool save_unfiltered_features) { + if((int64)value_ptr == ValuePtrStatus::IS_DELETED) + return; + + V* primary_val = value_ptr->GetValue(0, 0); + bool is_not_admit = + primary_val == nullptr + && emb_config.filter_freq != 0; + + if (!is_not_admit) { + key_vec_.emplace_back(key); + + if (primary_val == nullptr) { + value_ptr_vec_.emplace_back(default_value); + } else if ( + (int64)primary_val == ValuePosition::NOT_IN_DRAM) { + value_ptr_vec_.emplace_back((V*)ValuePosition::NOT_IN_DRAM); + } else { + V* val = value_ptr->GetValue(emb_config.emb_index, + value_offset); + value_ptr_vec_.emplace_back(val); + } + + + if(is_save_version) { + int64 dump_version = value_ptr->GetStep(); + version_vec_.emplace_back(dump_version); + } + + if(is_save_freq) { + int64 dump_freq = value_ptr->GetFreq(); + freq_vec_.emplace_back(dump_freq); + } + } else { + if (!save_unfiltered_features) + return; + + key_filter_vec_.emplace_back(key); + + if(is_save_version) { + int64 dump_version = value_ptr->GetStep(); + version_filter_vec_.emplace_back(dump_version); + } + + int64 dump_freq = value_ptr->GetFreq(); + freq_filter_vec_.emplace_back(dump_freq); + } + } + + void Emplace(K key, V* value_ptr) { + key_vec_.emplace_back(key); + value_ptr_vec_.emplace_back(value_ptr); + } + + void SetWithPartition( + std::vector>& ev_ckpt_data_parts) { + part_offset_.resize(kSavedPartitionNum + 1); + part_filter_offset_.resize(kSavedPartitionNum + 1); + part_offset_[0] = 0; + part_filter_offset_[0] = 0; + for (int i = 0; i < kSavedPartitionNum; i++) { + part_offset_[i + 1] = + part_offset_[i] + ev_ckpt_data_parts[i].key_vec_.size(); + + part_filter_offset_[i + 1] = + part_filter_offset_[i] + + ev_ckpt_data_parts[i].key_filter_vec_.size(); + + for (int64 j = 0; j < ev_ckpt_data_parts[i].key_vec_.size(); j++) { + key_vec_.emplace_back(ev_ckpt_data_parts[i].key_vec_[j]); + } + + for (int64 j = 0; j < ev_ckpt_data_parts[i].value_ptr_vec_.size(); j++) { + value_ptr_vec_.emplace_back(ev_ckpt_data_parts[i].value_ptr_vec_[j]); + } + + for (int64 j = 0; j < ev_ckpt_data_parts[i].version_vec_.size(); j++) { + version_vec_.emplace_back(ev_ckpt_data_parts[i].version_vec_[j]); + } + + for (int64 j = 0; j < 
ev_ckpt_data_parts[i].freq_vec_.size(); j++) { + freq_vec_.emplace_back(ev_ckpt_data_parts[i].freq_vec_[j]); + } + + for (int64 j = 0; j < ev_ckpt_data_parts[i].key_filter_vec_.size(); j++) { + key_filter_vec_.emplace_back(ev_ckpt_data_parts[i].key_filter_vec_[j]); + } + + for (int64 j = 0; j < ev_ckpt_data_parts[i].version_filter_vec_.size(); j++) { + version_filter_vec_.emplace_back(ev_ckpt_data_parts[i].version_filter_vec_[j]); + } + + for (int64 j = 0; j < ev_ckpt_data_parts[i].freq_filter_vec_.size(); j++) { + freq_filter_vec_.emplace_back(ev_ckpt_data_parts[i].freq_filter_vec_[j]); + } + } + } + + Status ExportToCkpt(const string& tensor_name, + BundleWriter* writer, + int64 value_len, + ValueIterator* value_iter = nullptr) { + size_t bytes_limit = 8 << 20; + std::unique_ptr dump_buffer(new char[bytes_limit]); + + EVVectorDataDumpIterator key_dump_iter(key_vec_); + Status s = SaveTensorWithFixedBuffer( + tensor_name + "-keys", writer, dump_buffer.get(), + bytes_limit, &key_dump_iter, + TensorShape({key_vec_.size()})); + if (!s.ok()) + return s; + + EV2dVectorDataDumpIterator value_dump_iter( + value_ptr_vec_, value_len, value_iter); + s = SaveTensorWithFixedBuffer( + tensor_name + "-values", writer, dump_buffer.get(), + bytes_limit, &value_dump_iter, + TensorShape({value_ptr_vec_.size(), value_len})); + if (!s.ok()) + return s; + + EVVectorDataDumpIterator version_dump_iter(version_vec_); + s = SaveTensorWithFixedBuffer( + tensor_name + "-versions", writer, dump_buffer.get(), + bytes_limit, &version_dump_iter, + TensorShape({version_vec_.size()})); + if (!s.ok()) + return s; + + EVVectorDataDumpIterator freq_dump_iter(freq_vec_); + s = SaveTensorWithFixedBuffer( + tensor_name + "-freqs", writer, dump_buffer.get(), + bytes_limit, &freq_dump_iter, + TensorShape({freq_vec_.size()})); + if (!s.ok()) + return s; + + EVVectorDataDumpIterator filtered_key_dump_iter(key_filter_vec_); + s = SaveTensorWithFixedBuffer( + tensor_name + "-keys_filtered", writer, dump_buffer.get(), + bytes_limit, &filtered_key_dump_iter, + TensorShape({key_filter_vec_.size()})); + if (!s.ok()) + return s; + + EVVectorDataDumpIterator + filtered_version_dump_iter(version_filter_vec_); + s = SaveTensorWithFixedBuffer( + tensor_name + "-versions_filtered", + writer, dump_buffer.get(), + bytes_limit, &filtered_version_dump_iter, + TensorShape({version_filter_vec_.size()})); + if (!s.ok()) + return s; + + EVVectorDataDumpIterator + filtered_freq_dump_iter(freq_filter_vec_); + s = SaveTensorWithFixedBuffer( + tensor_name + "-freqs_filtered", + writer, dump_buffer.get(), + bytes_limit, &filtered_freq_dump_iter, + TensorShape({freq_filter_vec_.size()})); + if (!s.ok()) + return s; + + EVVectorDataDumpIterator + part_offset_dump_iter(part_offset_); + s = SaveTensorWithFixedBuffer( + tensor_name + "-partition_offset", + writer, dump_buffer.get(), + bytes_limit, &part_offset_dump_iter, + TensorShape({part_offset_.size()})); + if (!s.ok()) + return s; + + EVVectorDataDumpIterator + part_filter_offset_dump_iter(part_filter_offset_); + s = SaveTensorWithFixedBuffer( + tensor_name + "-partition_filter_offset", + writer, dump_buffer.get(), + bytes_limit, &part_filter_offset_dump_iter, + TensorShape({part_filter_offset_.size()})); + if (!s.ok()) + return s; + + return Status::OK(); + } + + private: + std::vector key_vec_; + std::vector value_ptr_vec_; + std::vector version_vec_; + std::vector freq_vec_; + std::vector key_filter_vec_; + std::vector version_filter_vec_; + std::vector freq_filter_vec_; + std::vector part_offset_; + 
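  // part_offset_ and part_filter_offset_ are prefix sums over the
  // kSavedPartitionNum saved partitions: entry i+1 minus entry i is the
  // number of admitted (resp. filtered) keys belonging to partition i
  // (keys are bucketed by key % kSavedPartitionNum elsewhere in this patch).
  // SetWithPartition() fills both vectors and ExportToCkpt() dumps them as
  // the "-partition_offset" and "-partition_filter_offset" tensors.
  // Illustrative example with hypothetical keys {0, 1000, 7}: partition 0
  // holds {0, 1000} and partition 7 holds {7}, so part_offset_ would be
  // {0, 2, 2, 2, 2, 2, 2, 2, 3, 3, ..., 3} (kSavedPartitionNum + 1 entries).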
std::vector part_filter_offset_; + const int kSavedPartitionNum = 1000; +}; +} //namespace embedding +} //namespace tensorflow +#endif //TENSORFLOW_CORE_FRAMEWORK_EMBEDDING_EMBEDDING_VAR_CKPT_DATA_ diff --git a/tensorflow/core/framework/embedding/embedding_var_dump_iterator.h b/tensorflow/core/framework/embedding/embedding_var_dump_iterator.h new file mode 100644 index 00000000000..71ba054b873 --- /dev/null +++ b/tensorflow/core/framework/embedding/embedding_var_dump_iterator.h @@ -0,0 +1,95 @@ +/* Copyright 2022 The DeepRec Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +======================================================================*/ +#ifndef TENSORFLOW_CORE_FRAMEWORK_EMBEDDING_EMBEDDING_VAR_DUMP_ITERATOR_ +#define TENSORFLOW_CORE_FRAMEWORK_EMBEDDING_EMBEDDING_VAR_DUMP_ITERATOR_ +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/util/tensor_bundle/tensor_bundle.h" +#include "tensorflow/core/kernels/save_restore_tensor.h" +namespace tensorflow { +namespace embedding { +template +class EVVectorDataDumpIterator: public DumpIterator { + public: + EVVectorDataDumpIterator(const std::vector& item_list) + : curr_iter_(item_list.begin()), + end_iter_(item_list.end()) {} + + bool HasNext() const { + return curr_iter_ != end_iter_; + } + + T Next() { + T val = *curr_iter_; + curr_iter_++; + return val; + } + + private: + typename std::vector::const_iterator curr_iter_; + typename std::vector::const_iterator end_iter_; +}; + +template +class EV2dVectorDataDumpIterator: public DumpIterator { + public: + EV2dVectorDataDumpIterator( + std::vector& valueptr_list, + int64 value_len, + ValueIterator* val_iter) + : curr_iter_(valueptr_list.begin()), + end_iter_(valueptr_list.end()), + val_iter_(val_iter), + value_len_(value_len), + col_idx_(0) { + if (!valueptr_list.empty()) { + if ((int64)*curr_iter_ == ValuePosition::NOT_IN_DRAM) { + curr_ptr_ = val_iter_->Next(); + } else { + curr_ptr_ = *curr_iter_; + } + } + } + + bool HasNext() const { + return curr_iter_ != end_iter_; + } + + T Next() { + T val = curr_ptr_[col_idx_++]; + if (col_idx_ >= value_len_) { + curr_iter_++; + col_idx_ = 0; + if (curr_iter_ != end_iter_) { + if ((int64)*curr_iter_ == ValuePosition::NOT_IN_DRAM) { + curr_ptr_ = val_iter_->Next(); + } else { + curr_ptr_ = *curr_iter_; + } + } + } + return val; + } + + private: + typename std::vector::const_iterator curr_iter_; + typename std::vector::const_iterator end_iter_; + ValueIterator* val_iter_; + int64 value_len_; + int64 col_idx_; + T* curr_ptr_ = nullptr; +}; +} //namespace embedding +} //namespace tensorflow +#endif //TENSORFLOW_CORE_FRAMEWORK_EMBEDDING_EMBEDDING_VAR_DUMP_ITERATOR_ diff --git a/tensorflow/core/framework/embedding/embedding_var_restore.h b/tensorflow/core/framework/embedding/embedding_var_restore.h index 821ef7485e8..ec97566fbec 100644 --- a/tensorflow/core/framework/embedding/embedding_var_restore.h +++ b/tensorflow/core/framework/embedding/embedding_var_restore.h @@ -40,9 +40,7 @@ using GPUDevice = Eigen::GpuDevice; template 
class EmbeddingVar; - namespace { - const int kSavedPartitionNum = 1000; const size_t kBufferSize = 8 << 20; constexpr char kPartStr[] = "part_"; diff --git a/tensorflow/core/framework/embedding/globalstep_shrink_policy.h b/tensorflow/core/framework/embedding/globalstep_shrink_policy.h index 17551a6c387..a2af6a2430a 100644 --- a/tensorflow/core/framework/embedding/globalstep_shrink_policy.h +++ b/tensorflow/core/framework/embedding/globalstep_shrink_policy.h @@ -35,19 +35,18 @@ class GlobalStepShrinkPolicy : public ShrinkPolicy { TF_DISALLOW_COPY_AND_ASSIGN(GlobalStepShrinkPolicy); - void Shrink(const ShrinkArgs& shrink_args) override { + void Shrink(std::vector& key_list, + std::vector*>& value_list, + const ShrinkArgs& shrink_args) override { ShrinkPolicy::ReleaseValuePtrs(); - std::vector key_list; - std::vector*> value_list; - kv_->GetSnapshot(&key_list, &value_list); FilterToDelete(shrink_args.global_step, key_list, value_list); } private: void FilterToDelete(int64 global_step, - const std::vector& key_list, - const std::vector*>& value_list) { + std::vector& key_list, + std::vector*>& value_list) { for (int64 i = 0; i < key_list.size(); ++i) { int64 version = value_list[i]->GetStep(); if (version == -1) { @@ -56,6 +55,7 @@ class GlobalStepShrinkPolicy : public ShrinkPolicy { if (global_step - version > steps_to_live_) { kv_->Remove(key_list[i]); ShrinkPolicy::EmplacePointer(value_list[i]); + value_list[i] = (ValuePtr*)ValuePtrStatus::IS_DELETED; } } } diff --git a/tensorflow/core/framework/embedding/gpu_hash_map_kv.h b/tensorflow/core/framework/embedding/gpu_hash_map_kv.h index 56542237a3e..1dd90d63a6e 100644 --- a/tensorflow/core/framework/embedding/gpu_hash_map_kv.h +++ b/tensorflow/core/framework/embedding/gpu_hash_map_kv.h @@ -256,8 +256,6 @@ class GPUHashMapKV : public KVInterface { std::string DebugString() const override { return std::string(); } - Iterator* GetIterator() override { return nullptr; } - GPUHashTable* HashTable() override { return hash_table_; } Status BatchLookup(const Eigen::GpuDevice& device, const K* keys, diff --git a/tensorflow/core/framework/embedding/hbm_dram_ssd_storage.h b/tensorflow/core/framework/embedding/hbm_dram_ssd_storage.h index 72a3ef4483c..581f1f1cfaf 100644 --- a/tensorflow/core/framework/embedding/hbm_dram_ssd_storage.h +++ b/tensorflow/core/framework/embedding/hbm_dram_ssd_storage.h @@ -302,45 +302,67 @@ class HbmDramSsdStorage : public MultiTierStorage { return false; } - void iterator_mutex_lock() override { - ssd_->get_mutex()->lock(); - } + Status Save( + const string& tensor_name, + const string& prefix, + BundleWriter* writer, + const EmbeddingConfig& emb_config, + ShrinkArgs& shrink_args, + int64 value_len, + V* default_value) override { + std::vector key_list, tmp_dram_key_list; + std::vector*> value_ptr_list, tmp_dram_value_list; + TF_CHECK_OK(hbm_->GetSnapshot(&key_list, &value_ptr_list)); + hbm_->Shrink(key_list, value_ptr_list, shrink_args, value_len); + + HbmValueIterator hbm_value_iter( + key_list, value_ptr_list, + emb_config.emb_index, Storage::alloc_len_, + gpu_alloc_); + + std::vector*> tmp_hbm_value_ptrs(value_ptr_list.size()); + for (int64 i = 0; i < value_ptr_list.size(); i++) { + ValuePtr* value_ptr = hbm_->CreateValuePtr(value_len); + memcpy((char *)value_ptr->GetPtr(), + (char *)value_ptr_list[i]->GetPtr(), + sizeof(FixedLengthHeader)); + value_ptr->SetPtr((V*)ValuePosition::NOT_IN_DRAM); + value_ptr->SetInitialized(emb_config.primary_emb_index); + tmp_hbm_value_ptrs[i] = value_ptr; + value_ptr_list[i] = value_ptr; 
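      // Each HBM entry above is mirrored by a temporary host-side ValuePtr:
      // its fixed-length header (global step / freq counter) is copied from
      // the device entry and its value pointer is tagged NOT_IN_DRAM, so
      // SaveToCheckpoint() fetches the actual embedding values through
      // hbm_value_iter. The temporaries collected in tmp_hbm_value_ptrs are
      // deleted after the checkpoint tensors have been written.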
+ } - void iterator_mutex_unlock() override { - ssd_->get_mutex()->unlock(); - } + TF_CHECK_OK(dram_->GetSnapshot(&tmp_dram_key_list, + &tmp_dram_value_list)); + dram_->Shrink(tmp_dram_key_list, tmp_dram_value_list, + shrink_args, value_len); - Status GetSnapshot(std::vector* key_list, - std::vector* >* value_ptr_list) override { - { - mutex_lock l(*(hbm_->get_mutex())); - TF_CHECK_OK(hbm_->GetSnapshot(key_list, value_ptr_list)); + for (int64 i = 0; i < tmp_dram_key_list.size(); i++) { + Status s = hbm_->Contains(tmp_dram_key_list[i]); + if (!s.ok()) { + key_list.emplace_back(tmp_dram_key_list[i]); + value_ptr_list.emplace_back(tmp_dram_value_list[i]); + } } + { - mutex_lock l(*(dram_->get_mutex())); - TF_CHECK_OK(dram_->GetSnapshot(key_list, value_ptr_list)); + mutex_lock l(*(hbm_->get_mutex())); + TF_CHECK_OK((Storage::SaveToCheckpoint( + tensor_name, writer, + emb_config, + value_len, default_value, + key_list, + value_ptr_list, + &hbm_value_iter))); } - { - mutex_lock l(*(ssd_->get_mutex())); - TF_CHECK_OK(ssd_->GetSnapshot(key_list, value_ptr_list)); + + for (auto it: tmp_hbm_value_ptrs) { + delete it; } - return Status::OK(); - } - int64 GetSnapshot(std::vector* key_list, - std::vector* value_list, - std::vector* version_list, - std::vector* freq_list, - const EmbeddingConfig& emb_config, - FilterPolicy>* filter, - embedding::Iterator** it) override { - LOG(FATAL)<<"HbmDramSsdStorage dosen't support GetSnaoshot."; - } + ssd_->Save(tensor_name, prefix, writer, emb_config, + shrink_args, value_len, default_value); - Status Shrink(const ShrinkArgs& shrink_args) override { - hbm_->Shrink(shrink_args); - dram_->Shrink(shrink_args); - ssd_->Shrink(shrink_args); return Status::OK(); } diff --git a/tensorflow/core/framework/embedding/hbm_dram_storage.h b/tensorflow/core/framework/embedding/hbm_dram_storage.h index ce8e9a91643..518c39287e0 100644 --- a/tensorflow/core/framework/embedding/hbm_dram_storage.h +++ b/tensorflow/core/framework/embedding/hbm_dram_storage.h @@ -261,63 +261,63 @@ class HbmDramStorage : public MultiTierStorage { return false; } - void iterator_mutex_lock() override { - return; - } + Status Save( + const string& tensor_name, + const string& prefix, + BundleWriter* writer, + const EmbeddingConfig& emb_config, + ShrinkArgs& shrink_args, + int64 value_len, + V* default_value) override { + std::vector key_list, tmp_dram_key_list; + std::vector*> value_ptr_list, tmp_dram_value_list; + TF_CHECK_OK(hbm_->GetSnapshot(&key_list, &value_ptr_list)); + hbm_->Shrink(key_list, value_ptr_list, shrink_args, value_len); + + HbmValueIterator hbm_value_iter( + key_list, value_ptr_list, + emb_config.emb_index, Storage::alloc_len_, + gpu_alloc_); + + std::vector*> tmp_hbm_value_ptrs(value_ptr_list.size()); + for (int64 i = 0; i < value_ptr_list.size(); i++) { + ValuePtr* value_ptr = hbm_->CreateValuePtr(value_len); + memcpy((char *)value_ptr->GetPtr(), + (char *)value_ptr_list[i]->GetPtr(), + sizeof(FixedLengthHeader)); + value_ptr->SetPtr((V*)ValuePosition::NOT_IN_DRAM); + value_ptr->SetInitialized(emb_config.primary_emb_index); + tmp_hbm_value_ptrs[i] = value_ptr; + value_ptr_list[i] = value_ptr; + } - void iterator_mutex_unlock() override { - return; - } + TF_CHECK_OK(dram_->GetSnapshot(&tmp_dram_key_list, + &tmp_dram_value_list)); + dram_->Shrink(tmp_dram_key_list, tmp_dram_value_list, + shrink_args, value_len); - Status GetSnapshot(std::vector* key_list, - std::vector* >* value_ptr_list) override { - { - mutex_lock l(*(hbm_->get_mutex())); - TF_CHECK_OK(hbm_->GetSnapshot(key_list, 
value_ptr_list)); - } - { - mutex_lock l(*(dram_->get_mutex())); - TF_CHECK_OK(dram_->GetSnapshot(key_list, value_ptr_list)); + for (int64 i = 0; i < tmp_dram_key_list.size(); i++) { + Status s = hbm_->Contains(tmp_dram_key_list[i]); + if (!s.ok()) { + key_list.emplace_back(tmp_dram_key_list[i]); + value_ptr_list.emplace_back(tmp_dram_value_list[i]); + } } - return Status::OK(); - } - int64 GetSnapshot(std::vector* key_list, - std::vector* value_list, - std::vector* version_list, - std::vector* freq_list, - const EmbeddingConfig& emb_config, - FilterPolicy>* filter, - embedding::Iterator** it) override { - std::vector*> hbm_value_ptr_list, dram_value_ptr_list; - std::vector temp_hbm_key_list, temp_dram_key_list; - // Get Snapshot of HBM storage { mutex_lock l(*(hbm_->get_mutex())); - TF_CHECK_OK(hbm_->GetSnapshot(&temp_hbm_key_list, - &hbm_value_ptr_list)); + TF_CHECK_OK((Storage::SaveToCheckpoint( + tensor_name, writer, + emb_config, + value_len, default_value, + key_list, + value_ptr_list, + &hbm_value_iter))); } - // Get Snapshot of DRAM storage. - { - mutex_lock l(*(dram_->get_mutex())); - TF_CHECK_OK(dram_->GetSnapshot(&temp_dram_key_list, - &dram_value_ptr_list)); - } - *it = new HbmDramIterator(temp_hbm_key_list, - temp_dram_key_list, - hbm_value_ptr_list, - dram_value_ptr_list, - Storage::alloc_len_, - gpu_alloc_, - emb_config.emb_index); - // This return value is not the exact number of IDs - // because the two tables intersect. - return temp_hbm_key_list.size() + temp_dram_key_list.size(); - } - Status Shrink(const ShrinkArgs& shrink_args) override { - hbm_->Shrink(shrink_args); - dram_->Shrink(shrink_args); + for (auto it: tmp_hbm_value_ptrs) { + delete it; + } return Status::OK(); } diff --git a/tensorflow/core/framework/embedding/hbm_storage_iterator.h b/tensorflow/core/framework/embedding/hbm_storage_iterator.h index 4831b940bb8..36d331e74aa 100644 --- a/tensorflow/core/framework/embedding/hbm_storage_iterator.h +++ b/tensorflow/core/framework/embedding/hbm_storage_iterator.h @@ -23,267 +23,77 @@ template class ValuePtr; namespace embedding { -class Iterator; - -namespace { - const int kSavedPartitionNum = 1000; -} - -template -class PartitionedCheckpointData { +template +class HbmValueIterator: public ValueIterator { public: - PartitionedCheckpointData() { - key_list_parts.resize(kSavedPartitionNum); - value_list_parts.resize(kSavedPartitionNum); - version_list_parts.resize(kSavedPartitionNum); - freq_list_parts.resize(kSavedPartitionNum); - key_filter_list_parts.resize(kSavedPartitionNum); - version_filter_list_parts.resize(kSavedPartitionNum); - freq_filter_list_parts.resize(kSavedPartitionNum); - } - - ~PartitionedCheckpointData() { - } - - void EmplaceToPartList(K key, ValuePtr* value_ptr, bool is_on_hbm, - int64 emb_index, int64 emb_offset) { - int64 part_id = key % kSavedPartitionNum; - V* val = value_ptr->GetValue(emb_index, emb_offset); - V* primary_val = value_ptr->GetValue(0, 0); - - int64 freq = value_ptr->GetFreq(); - int64 version = value_ptr->GetStep(); - if (primary_val == nullptr) { - // id is filtered by feature filter. 
- key_filter_list_parts[part_id].emplace_back(key); - freq_filter_list_parts[part_id].emplace_back(freq); - version_filter_list_parts[part_id].emplace_back(version); - } else { - if (val != nullptr) { - key_list_parts[part_id].emplace_back(key); - freq_list_parts[part_id].emplace_back(freq); - version_list_parts[part_id].emplace_back(version); - value_list_parts[part_id].emplace_back( - std::pair(val, is_on_hbm)); + HbmValueIterator( + const std::vector& key_list, + const std::vector*>& value_ptr_list, + int64 emb_index, + int64 value_len, + Allocator* alloc) + : value_len_(value_len), + alloc_(alloc) { + int64 emb_offset = value_len_ * emb_index; + std::vector> value_parts_vec(kSavedPartitionNum); + for (int64 i = 0; i < key_list.size(); i++) { + for (int part_id = 0; part_id < kSavedPartitionNum; part_id++) { + if (key_list[i] % kSavedPartitionNum == part_id) { + value_parts_vec[part_id].emplace_back( + value_ptr_list[i]->GetValue(emb_index, emb_offset)); + break; + } } } - } - - void GenerateKeyList(std::vector* output_key_list) { - MergePartList(key_list_parts, output_key_list); - } - void GenerateFilteredKeyList(std::vector* output_filter_key_list) { - MergePartList(key_filter_list_parts, output_filter_key_list); - } - - void GenerateValueList( - std::vector>* output_value_list, - std::vector* hbm_ptr_list) { - for (int i = 0; i < kSavedPartitionNum; i++) { - for (int j = 0; j < value_list_parts[i].size(); j++) { - output_value_list->emplace_back(value_list_parts[i][j]); - if (value_list_parts[i][j].second) - hbm_ptr_list->emplace_back(value_list_parts[i][j].first); - } - } - } - - void GenerateFreqList(std::vector* output_freq_list) { - MergePartList(freq_list_parts, output_freq_list); - } - - void GenerateFilteredFreqList( - std::vector* output_filter_freq_list) { - MergePartList(freq_filter_list_parts, output_filter_freq_list); - } - - void GenerateVersionList( - std::vector* output_version_list) { - MergePartList(version_list_parts, output_version_list); - } - - void GenerateFilteredVersionList( - std::vector* output_filter_version_list) { - MergePartList(version_filter_list_parts, - output_filter_version_list); - } - - void GeneratePartOffset(std::vector* part_offset) { - for (int64 i = 0; i < kSavedPartitionNum; i++) { - (*part_offset)[i + 1] = (*part_offset)[i] + key_list_parts[i].size(); - } - } - - void GeneratePartFilterOffset(std::vector* part_filter_offset) { for (int64 i = 0; i < kSavedPartitionNum; i++) { - (*part_filter_offset)[i + 1] = (*part_filter_offset)[i] - + key_filter_list_parts[i].size(); + values_.splice(values_.end(), value_parts_vec[i]); } - } - private: - template - void MergePartList( - const std::vector>& part_list, - std::vector *output_list) { - for (int i = 0; i < kSavedPartitionNum; i++) { - for (int j = 0; j < part_list[i].size(); j++) { - output_list->emplace_back(part_list[i][j]); - } - } - } - - std::vector> key_list_parts; - std::vector>> value_list_parts; - std::vector> version_list_parts; - std::vector> freq_list_parts; - std::vector> key_filter_list_parts; - std::vector> version_filter_list_parts; - std::vector> freq_filter_list_parts; -}; - -template -class HbmDramIterator: public Iterator { - public: - HbmDramIterator( - const std::vector& hbm_key_list, - const std::vector& dram_key_list, - const std::vector*>& hbm_value_ptr_list, - const std::vector*>& dram_value_ptr_list, - int64 value_len, - Allocator* alloc, - int64 emb_index): - value_len_(value_len), - alloc_(alloc), - cursor_(0), - hbm_ptr_cursor_(0), - fill_buffer_st_(0), - 
fill_buffer_ed_(0), - emb_index_(emb_index) { - part_offset_.resize(kSavedPartitionNum + 1); - part_offset_[0] = 0; - part_filter_offset_.resize(kSavedPartitionNum + 1); - part_filter_offset_[0] = 0; - emb_offset_ = value_len_ * emb_index_; - std::set hbm_keys; + values_iter_ = values_.begin(); - PartitionedCheckpointData ckpt_data; - for (int64 i = 0; i < hbm_key_list.size(); i++) { - ckpt_data.EmplaceToPartList( - hbm_key_list[i], hbm_value_ptr_list[i], true, - emb_index_, emb_offset_); - hbm_keys.insert(hbm_key_list[i]); - } - for (int64 i = 0; i < dram_key_list.size(); i++) { - if (hbm_keys.find(dram_key_list[i]) == hbm_keys.end()) { - ckpt_data.EmplaceToPartList( - dram_key_list[i], dram_value_ptr_list[i], false, - emb_index_, emb_offset_); - } - } - - ckpt_data.GenerateKeyList(&key_list_); - ckpt_data.GenerateValueList(&value_list_, &hbm_ptr_list_); - ckpt_data.GenerateFreqList(&freq_list_); - ckpt_data.GenerateVersionList(&version_list_); - ckpt_data.GeneratePartOffset(&part_offset_); - - ckpt_data.GenerateFilteredKeyList(&filtered_key_list_); - ckpt_data.GenerateFilteredFreqList(&filtered_freq_list_); - ckpt_data.GenerateFilteredVersionList(&filtered_version_list_); - ckpt_data.GeneratePartFilterOffset(&part_filter_offset_); - - dev_addr_list_ = (V**)alloc_->AllocateRaw(Allocator::kAllocatorAlignment, - buffer_capacity_ / value_len_ * sizeof(V*)); - dev_embedding_buffer_ = (V*)alloc_->AllocateRaw(Allocator::kAllocatorAlignment, + num_of_embs_ = buffer_capacity_ / value_len_; + dev_addr_list_ = (V**)alloc_->AllocateRaw( + Allocator::kAllocatorAlignment, + num_of_embs_ * sizeof(V*)); + dev_embedding_buffer_ = (V*)alloc_->AllocateRaw( + Allocator::kAllocatorAlignment, buffer_capacity_ * sizeof(V)); - local_addr_list_ = new V*[buffer_capacity_ / value_len_]; + + FillEmbeddingBuffer(); } - ~HbmDramIterator() { + ~HbmValueIterator() { alloc_->DeallocateRaw(dev_addr_list_); alloc_->DeallocateRaw(dev_embedding_buffer_); - delete[] local_addr_list_; - } - - virtual bool Valid() { - return !(cursor_ == current_key_list_->size()); - } - - virtual void SeekToFirst() { - cursor_ = 0; - hbm_ptr_cursor_ = 0; - fill_buffer_st_ = 0; - fill_buffer_ed_ = 0; - } - - virtual void SwitchToFilteredFeatures() { - current_key_list_ = &filtered_key_list_; - current_freq_list_ = &filtered_freq_list_; - current_version_list_ = &filtered_version_list_; - } - - virtual void SwitchToAdmitFeatures() { - current_key_list_ = &key_list_; - current_freq_list_ = &freq_list_; - current_version_list_ = &version_list_; - } - - virtual void Next() { - cursor_++; - } - - virtual void Key(char* val, int64 dim) { - *((int64*)val) = (*current_key_list_)[cursor_]; - } - - virtual void Value(char* val, int64 dim, int64 value_offset) { - if (value_list_[cursor_].second) { - if (hbm_ptr_cursor_ == fill_buffer_ed_) { - FillEmbeddingBuffer(); - } - memcpy(val, - embedding_buffer_ + - (hbm_ptr_cursor_ - fill_buffer_st_) * value_len_, - dim); - hbm_ptr_cursor_++; - } else { - memcpy(val, value_list_[cursor_].first, dim); - } - } - - virtual void Freq(char* val, int64 dim) { - *((int64*)val) = (*current_freq_list_)[cursor_]; - } - - virtual void Version(char* val, int64 dim) { - *((int64*)val) = (*current_version_list_)[cursor_]; } - virtual void SetPartOffset(int32* part_offset_ptr) { - for (int64 i = 0; i < kSavedPartitionNum + 1; i++) { - part_offset_ptr[i] = part_offset_[i]; + V* Next() { + if (buffer_cursor_ == num_of_embs_) { + FillEmbeddingBuffer(); + buffer_cursor_ = 0; } - } - virtual void SetPartFilterOffset(int32* 
part_offset_ptr) { - for (int64 i = 0; i < kSavedPartitionNum + 1; i++) { - part_offset_ptr[i] = part_filter_offset_[i]; - } + V* val = embedding_buffer_ + value_len_ * buffer_cursor_; + counter_++; + values_iter_++; + buffer_cursor_++; + return val; } private: void FillEmbeddingBuffer() { int64 total_num = std::min( - buffer_capacity_ / value_len_, - (int64)(hbm_ptr_list_.size() - hbm_ptr_cursor_)); - fill_buffer_st_ = hbm_ptr_cursor_; + num_of_embs_, + (int64)(values_.size() - counter_)); + std::vector local_addr_list(total_num); + auto iter = values_iter_; for (int64 i = 0; i < total_num; i++) { - local_addr_list_[i] = hbm_ptr_list_[fill_buffer_st_ + i]; + local_addr_list[i] = *iter; + iter++; } cudaMemcpy(dev_addr_list_, - local_addr_list_, + local_addr_list.data(), sizeof(V*) * total_num, cudaMemcpyHostToDevice); int block_dim = 128; @@ -301,36 +111,19 @@ class HbmDramIterator: public Iterator { dev_embedding_buffer_, sizeof(V) * total_num * value_len_, cudaMemcpyDeviceToHost); - fill_buffer_ed_ = fill_buffer_st_ + total_num; } - - std::vector key_list_; - std::vector> value_list_; - std::vector freq_list_; - std::vector version_list_; - std::vector part_offset_; - std::vector filtered_key_list_; - std::vector filtered_freq_list_; - std::vector filtered_version_list_; - std::vector part_filter_offset_; - std::vector hbm_ptr_list_; - + private: + std::list values_; + typename std::list::iterator values_iter_; const static int64 buffer_capacity_ = 1024 * 1024 * 1; V embedding_buffer_[buffer_capacity_]; + int64 counter_ = 0; + int64 buffer_cursor_ = 0; + int64 value_len_; + int64 num_of_embs_ = 0; + Allocator* alloc_; V** dev_addr_list_; V* dev_embedding_buffer_; - V** local_addr_list_; - Allocator* alloc_; - int64 value_len_; - int64 cursor_; - int64 hbm_ptr_cursor_; - int64 fill_buffer_st_; - int64 fill_buffer_ed_; - int64 emb_index_; - int64 emb_offset_; - std::vector* current_key_list_; - std::vector* current_freq_list_; - std::vector* current_version_list_; }; } // embedding diff --git a/tensorflow/core/framework/embedding/kv_interface.h b/tensorflow/core/framework/embedding/kv_interface.h index 40108a140cc..71667cf0917 100644 --- a/tensorflow/core/framework/embedding/kv_interface.h +++ b/tensorflow/core/framework/embedding/kv_interface.h @@ -30,21 +30,11 @@ template class GPUHashTable; namespace embedding { -class Iterator { + +template +class ValueIterator { public: - Iterator() {}; - virtual ~Iterator() {}; - virtual bool Valid() {return true;}; - virtual void SeekToFirst() {}; - virtual void SwitchToFilteredFeatures() {}; - virtual void SwitchToAdmitFeatures() {}; - virtual void Next() {}; - virtual void Key(char* val, int64 dim) {}; - virtual void Freq(char* val, int64 dim) {}; - virtual void Version(char* val, int64 dim) {}; - virtual void Value(char* val, int64 dim, int64 value_offset) {}; - virtual void SetPartOffset(int32* part_offet_ptr) {}; - virtual void SetPartFilterOffset(int32* part_offet_ptr) {}; + virtual V* Next() = 0; }; template @@ -98,8 +88,6 @@ class KVInterface { virtual std::string DebugString() const = 0; - virtual Iterator* GetIterator() { return nullptr; } - virtual Status BatchLookupOrCreate(const K* keys, V* val, V* default_v, int32 default_v_num, size_t n, const Eigen::GpuDevice& device) { diff --git a/tensorflow/core/framework/embedding/l2weight_shrink_policy.h b/tensorflow/core/framework/embedding/l2weight_shrink_policy.h index 3185cd539ab..2af6b58f94b 100644 --- a/tensorflow/core/framework/embedding/l2weight_shrink_policy.h +++ 
b/tensorflow/core/framework/embedding/l2weight_shrink_policy.h @@ -38,20 +38,19 @@ class L2WeightShrinkPolicy : public ShrinkPolicy { ShrinkPolicy(alloc) {} TF_DISALLOW_COPY_AND_ASSIGN(L2WeightShrinkPolicy); - - void Shrink(const ShrinkArgs& shrink_args) override { + + void Shrink(std::vector& key_list, + std::vector*>& value_list, + const ShrinkArgs& shrink_args) override { ShrinkPolicy::ReleaseValuePtrs(); - std::vector key_list; - std::vector*> value_list; - kv_->GetSnapshot(&key_list, &value_list); FilterToDelete(shrink_args.value_len, key_list, value_list); } private: void FilterToDelete(int64 value_len, - const std::vector& key_list, - const std::vector*>& value_list) { + std::vector& key_list, + std::vector*>& value_list) { for (int64 i = 0; i < key_list.size(); ++i) { V* val = value_list[i]->GetValue(index_, offset_); if (val != nullptr) { @@ -62,6 +61,7 @@ class L2WeightShrinkPolicy : public ShrinkPolicy { l2_weight *= (V)0.5; if (l2_weight < (V)l2_weight_threshold_) { kv_->Remove(key_list[i]); + value_list[i] = (ValuePtr*)ValuePtrStatus::IS_DELETED; ShrinkPolicy::EmplacePointer(value_list[i]); } } diff --git a/tensorflow/core/framework/embedding/leveldb_kv.h b/tensorflow/core/framework/embedding/leveldb_kv.h index d6dc09b49b4..8ea1fa63fc2 100644 --- a/tensorflow/core/framework/embedding/leveldb_kv.h +++ b/tensorflow/core/framework/embedding/leveldb_kv.h @@ -73,45 +73,6 @@ class SizeCounter { int num_parts_; }; -class DBIterator : public Iterator { - public: - DBIterator(leveldb::Iterator* it):it_(it) {} - virtual ~DBIterator() { - delete it_; - }; - virtual bool Valid() { - return it_->Valid(); - } - virtual void SeekToFirst() { - return it_->SeekToFirst(); - } - virtual void Next() { - return it_->Next(); - } - virtual void Key(char* val, int64 dim) { - memcpy(val, it_->key().ToString().data(), dim); - } - virtual void Value(char* val, int64 dim, int64 value_offset) { - memcpy(val, - it_->value().ToString().data() + - value_offset + sizeof(FixedLengthHeader), dim); - } - virtual void Freq(char* val, int64 dim) { - memcpy(val, - it_->value().ToString().data(), sizeof(FixedLengthHeader)); - *((int64*)val) = - reinterpret_cast(val)->GetFreqCounter(); - } - virtual void Version(char* val, int64 dim) { - memcpy(val, - it_->value().ToString().data(), sizeof(FixedLengthHeader)); - *((int64*)val) = - reinterpret_cast(val)->GetGlobalStep(); - } - private: - leveldb::Iterator* it_; -}; - template class LevelDBKV : public KVInterface { public: @@ -216,14 +177,22 @@ class LevelDBKV : public KVInterface { Status GetSnapshot(std::vector* key_list, std::vector*>* value_ptr_list) override { - return Status::OK(); - } - - Iterator* GetIterator() override { ReadOptions options; options.snapshot = db_->GetSnapshot(); leveldb::Iterator* it = db_->NewIterator(options); - return new DBIterator(it); + for (it->SeekToFirst(); it->Valid(); it->Next()) { + K key; + memcpy((char*)&key, it->key().ToString().data(), sizeof(K)); + key_list->emplace_back(key); + ValuePtr* value_ptr = + new NormalGPUValuePtr(ev_allocator(), 1); + memcpy((char *)value_ptr->GetPtr(), + it->value().ToString().data(), + sizeof(FixedLengthHeader)); + value_ptr_list->emplace_back(value_ptr); + } + delete it; + return Status::OK(); } int64 Size() const override { @@ -247,6 +216,63 @@ class LevelDBKV : public KVInterface { int total_dims_; }; +template +class DBValueIterator: public ValueIterator { + public: + DBValueIterator( + const std::vector& key_list, + int64 emb_index, + int64 value_len, + LevelDBKV* leveldb_kv) + : 
value_len_(value_len), + emb_index_(emb_index), + leveldb_kv_(leveldb_kv) { + int64 emb_offset = value_len_ * emb_index; + std::vector> keys_parts_vec(kSavedPartitionNum); + for (int64 i = 0; i < key_list.size(); i++) { + for (int part_id = 0; part_id < kSavedPartitionNum; part_id++) { + if (key_list[i] % kSavedPartitionNum == part_id) { + keys_parts_vec[part_id].emplace_back(key_list[i]); + break; + } + } + } + + for (int64 i = 0; i < kSavedPartitionNum; i++) { + keys_.splice(keys_.end(), keys_parts_vec[i]); + } + + keys_iter_= keys_.begin(); + } + + ~DBValueIterator() { + delete value_ptr_; + } + + V* Next() { + if (value_ptr_ != nullptr) { + value_ptr_->Destroy(ev_allocator()); + delete value_ptr_; + } + K key = *(keys_iter_++); + + Status s = leveldb_kv_->Lookup(key, &value_ptr_); + if (!s.ok()) { + LOG(FATAL)<<"Not found value in LevelDB when Save."; + } + return value_ptr_->GetValue(emb_index_, value_len_ * emb_index_); + } + + private: + int64 value_len_; + int64 emb_index_; + LevelDBKV* leveldb_kv_; + std::list keys_; + typename std::list::const_iterator keys_iter_; + ValuePtr* value_ptr_ = nullptr; + int64 key_cursor_ = 0; +}; + } //namespace embedding } //namespace tensorflow diff --git a/tensorflow/core/framework/embedding/multi_tier_storage.h b/tensorflow/core/framework/embedding/multi_tier_storage.h index ac82f3911fb..ff18425ad9a 100644 --- a/tensorflow/core/framework/embedding/multi_tier_storage.h +++ b/tensorflow/core/framework/embedding/multi_tier_storage.h @@ -93,9 +93,9 @@ class MultiTierStorage : public Storage { return Status::OK(); } - embedding::Iterator* GetIterator() { - LOG(FATAL)<<"GetIterator isn't support by MultiTierStorage."; - return nullptr; + Status GetSnapshot(std::vector* key_list, + std::vector*>* value_ptr_list) override { + LOG(FATAL)<<"Can't get snapshot of MultiTierStorage."; } void CopyEmbeddingsFromCPUToGPU( @@ -110,74 +110,11 @@ class MultiTierStorage : public Storage { LOG(FATAL) << "Unsupport CopyEmbeddingsFromCPUToGPU in MultiTierStorage."; }; - void SetListsForCheckpoint( - const std::vector& input_key_list, - const std::vector*>& input_value_ptr_list, - const EmbeddingConfig& emb_config, - std::vector* output_key_list, - std::vector* output_value_list, - std::vector* output_version_list, - std::vector* output_freq_list) { - for (int64 i = 0; i < input_key_list.size(); ++i) { - output_key_list->emplace_back(input_key_list[i]); - - //NormalContiguousValuePtr is used, GetFreq() is valid. - int64 dump_freq = input_value_ptr_list[i]->GetFreq(); - output_freq_list->emplace_back(dump_freq); - - if (emb_config.steps_to_live != 0 || emb_config.record_version) { - int64 dump_version = input_value_ptr_list[i]->GetStep(); - output_version_list->emplace_back(dump_version); - } - - V* val = input_value_ptr_list[i]->GetValue(emb_config.emb_index, - Storage::GetOffset(emb_config.emb_index)); - V* primary_val = input_value_ptr_list[i]->GetValue( - emb_config.primary_emb_index, - Storage::GetOffset(emb_config.primary_emb_index)); - /* Classify features into 3 categories: - 1. filtered - 2. not involved in backward - 3. 
normal - */ - if (primary_val == nullptr) { - output_value_list->emplace_back(nullptr); - } else { - if (val == nullptr) { - output_value_list->emplace_back(reinterpret_cast(-1)); - } else { - output_value_list->emplace_back(val); - } - } - } - } - - virtual int64 GetSnapshotWithoutFetchPersistentEmb( - std::vector* key_list, - std::vector* value_list, - std::vector* version_list, - std::vector* freq_list, - const EmbeddingConfig& emb_config, - SsdRecordDescriptor* ssd_rec_desc) override { - LOG(FATAL)<<"The Storage dosen't use presisten memory" - <<" or this storage hasn't suppported" - <<" GetSnapshotWithoutFetchPersistentEmb yet"; - return -1; - } - Status Contains(K key) override { LOG(FATAL)<<"Contains is not support in MultiTierStorage."; return Status::OK(); } - void iterator_mutex_lock() override { - return; - } - - void iterator_mutex_unlock() override { - return; - } - bool IsMultiLevel() override { return true; } diff --git a/tensorflow/core/framework/embedding/shrink_policy.h b/tensorflow/core/framework/embedding/shrink_policy.h index 13cb51ff30d..ea063a113a3 100644 --- a/tensorflow/core/framework/embedding/shrink_policy.h +++ b/tensorflow/core/framework/embedding/shrink_policy.h @@ -45,7 +45,9 @@ class ShrinkPolicy { TF_DISALLOW_COPY_AND_ASSIGN(ShrinkPolicy); - virtual void Shrink(const ShrinkArgs& shrink_args) = 0; + virtual void Shrink(std::vector& key_list, + std::vector*>& value_list, + const ShrinkArgs& shrink_args) = 0; protected: void EmplacePointer(ValuePtr* value_ptr) { @@ -71,7 +73,9 @@ class NonShrinkPolicy: public ShrinkPolicy { NonShrinkPolicy(): ShrinkPolicy(nullptr) {} TF_DISALLOW_COPY_AND_ASSIGN(NonShrinkPolicy); - void Shrink(const ShrinkArgs& shrink_args) {} + void Shrink(std::vector& key_list, + std::vector*>& value_list, + const ShrinkArgs& shrink_args) override {} }; } // embedding } // tensorflow diff --git a/tensorflow/core/framework/embedding/single_tier_storage.h b/tensorflow/core/framework/embedding/single_tier_storage.h index 54bf1f76c14..f9de65df588 100644 --- a/tensorflow/core/framework/embedding/single_tier_storage.h +++ b/tensorflow/core/framework/embedding/single_tier_storage.h @@ -235,73 +235,33 @@ class SingleTierStorage : public Storage { Status GetSnapshot(std::vector* key_list, std::vector*>* value_ptr_list) override { + mutex_lock l(Storage::mu_); return kv_->GetSnapshot(key_list, value_ptr_list); } - virtual int64 GetSnapshot(std::vector* key_list, - std::vector* value_list, - std::vector* version_list, - std::vector* freq_list, + Status Save( + const std::string& tensor_name, + const std::string& prefix, + BundleWriter* writer, const EmbeddingConfig& emb_config, - FilterPolicy>* filter, - embedding::Iterator** it) override { + ShrinkArgs& shrink_args, + int64 value_len, + V* default_value) override { std::vector*> value_ptr_list; std::vector key_list_tmp; - TF_CHECK_OK(kv_->GetSnapshot(&key_list_tmp, &value_ptr_list)); - if (key_list_tmp.empty()) { - *it = kv_->GetIterator(); - return 0; - } - for (int64 i = 0; i < key_list_tmp.size(); ++i) { - V* val = value_ptr_list[i]->GetValue(emb_config.emb_index, - Storage::GetOffset(emb_config.emb_index)); - V* primary_val = value_ptr_list[i]->GetValue( - emb_config.primary_emb_index, - Storage::GetOffset(emb_config.primary_emb_index)); - key_list->emplace_back(key_list_tmp[i]); - if (emb_config.filter_freq != 0 || emb_config.record_freq) { - int64 dump_freq = filter->GetFreq( - key_list_tmp[i], value_ptr_list[i]); - freq_list->emplace_back(dump_freq); - } - if (emb_config.steps_to_live != 0 || 
emb_config.record_version) { - int64 dump_version = value_ptr_list[i]->GetStep(); - version_list->emplace_back(dump_version); - } - if (val != nullptr && primary_val != nullptr) { - value_list->emplace_back(val); - } else if (val == nullptr && primary_val != nullptr) { - // only forward, no backward - value_list->emplace_back(reinterpret_cast(-1)); - } else { - // feature filtered - value_list->emplace_back(nullptr); - } - } - return key_list->size(); - } - - int64 GetSnapshotWithoutFetchPersistentEmb( - std::vector* key_list, - std::vector* value_list, - std::vector* version_list, - std::vector* freq_list, - const EmbeddingConfig& emb_config, - SsdRecordDescriptor* ssd_rec_desc) override { - LOG(FATAL)<<"The Storage dosen't use presisten memory" - <<" or this storage hasn't suppported " - <<" GetSnapshotWithoutFetchPersistentEmb yet"; - return -1; - } + TF_CHECK_OK(kv_->GetSnapshot( + &key_list_tmp, &value_ptr_list)); - virtual embedding::Iterator* GetIterator() override { - LOG(FATAL)<<"GetIterator isn't support by "<::mu_); - shrink_policy_->Shrink(shrink_args); + TF_CHECK_OK((Storage::SaveToCheckpoint( + tensor_name, writer, + emb_config, + value_len, default_value, + key_list_tmp, + value_ptr_list))); return Status::OK(); } @@ -331,12 +291,8 @@ class SingleTierStorage : public Storage { return false; } - void iterator_mutex_lock() override { - return; - } - - void iterator_mutex_unlock() override { - return; + bool IsUsePersistentStorage() override { + return false; } void Schedule(std::function fn) override { @@ -366,6 +322,19 @@ class SingleTierStorage : public Storage { false/*to_dram*/, is_incr, restore_buff); return s; } + + virtual void Shrink(std::vector& key_list, + std::vector*>& value_ptr_list, + ShrinkArgs& shrink_args, + int64 value_len) { + mutex_lock l(Storage::mu_); + shrink_args.value_len = value_len; + shrink_policy_->Shrink( + key_list, + value_ptr_list, + shrink_args); + } + protected: KVInterface* kv_; ShrinkPolicy* shrink_policy_; @@ -409,6 +378,17 @@ class DramStorage : public SingleTierStorage { void SetTotalDims(int64 total_dims) override { SingleTierStorage::kv_->SetTotalDims(total_dims); } + + void Shrink(std::vector& key_list, + std::vector*>& value_ptr_list, + ShrinkArgs& shrink_args, + int64 value_len) override { + SingleTierStorage::Shrink( + key_list, + value_ptr_list, + shrink_args, + value_len); + } }; #if GOOGLE_CUDA @@ -449,18 +429,33 @@ class HbmStorage : public SingleTierStorage { size_t n, const V* default_v) override { SingleTierStorage::kv_->BatchLookup(device, keys, val, n, default_v); } - - int64 GetSnapshot(std::vector* key_list, - std::vector* value_list, - std::vector* version_list, - std::vector* freq_list, + + Status Save( + const string& tensor_name, + const string& prefix, + BundleWriter* writer, const EmbeddingConfig& emb_config, - FilterPolicy>* filter, - embedding::Iterator** it) override { + ShrinkArgs& shrink_args, + int64 value_len, + V* default_value) override { + std::vector value_ptr_list; + std::vector key_list_tmp; GPUHashMapKV* gpu_kv = dynamic_cast*>(SingleTierStorage::kv_); - gpu_kv->GetSnapshot(key_list, value_list, emb_config); - return key_list->size(); + gpu_kv->GetSnapshot(&key_list_tmp, &value_ptr_list, emb_config); + + TF_CHECK_OK((Storage::SaveToCheckpoint( + tensor_name, writer, + value_len, + key_list_tmp, + value_ptr_list))); + + if (value_ptr_list.size() > 0) { + TypedAllocator::Deallocate( + cpu_allocator(), value_ptr_list[0], + value_ptr_list.size() * value_len); + } + return Status::OK(); } GPUHashTable* 
HashTable() override { @@ -532,6 +527,17 @@ class HbmStorageWithCpuKv: public SingleTierStorage { friend class HbmDramSsdStorage; protected: void SetTotalDims(int64 total_dims) override {} + + void Shrink(std::vector& key_list, + std::vector*>& value_ptr_list, + ShrinkArgs& shrink_args, + int64 value_len) override { + SingleTierStorage::Shrink( + key_list, + value_ptr_list, + shrink_args, + value_len); + } }; #endif // GOOGLE_CUDA @@ -568,6 +574,17 @@ class PmemLibpmemStorage : public SingleTierStorage { protected: friend class DramPmemStorage; void SetTotalDims(int64 total_dims) override {} + + void Shrink(std::vector& key_list, + std::vector*>& value_ptr_list, + ShrinkArgs& shrink_args, + int64 value_len) override { + SingleTierStorage::Shrink( + key_list, + value_ptr_list, + shrink_args, + value_len); + } }; template @@ -585,10 +602,13 @@ class LevelDBStore : public SingleTierStorage { return SingleTierStorage::kv_->Commit(keys, value_ptr); } - embedding::Iterator* GetIterator() override { + embedding::ValueIterator* GetValueIterator( + const std::vector& key_list, + int64 emb_index, int64 value_len) { LevelDBKV* leveldb_kv = reinterpret_cast*>(SingleTierStorage::kv_); - return leveldb_kv->GetIterator(); + return new DBValueIterator( + key_list, emb_index, value_len, leveldb_kv); } public: friend class DramLevelDBStore; @@ -614,10 +634,25 @@ class SsdHashStorage : public SingleTierStorage { return SingleTierStorage::kv_->Commit(keys, value_ptr); } - embedding::Iterator* GetIterator() override { - SSDHashKV* ssd_kv = - reinterpret_cast*>(SingleTierStorage::kv_); - return ssd_kv->GetIterator(); + Status Save( + const string& tensor_name, + const string& prefix, + BundleWriter* writer, + const EmbeddingConfig& emb_config, + ShrinkArgs& shrink_args, + int64 value_len, + V* default_value) override { + if (emb_config.is_primary()) { + SSDHashKV* ssd_kv = + reinterpret_cast*>(SingleTierStorage::kv_); + SsdRecordDescriptor ssd_rec_desc; + { + mutex_lock l(Storage::mu_); + ssd_kv->SetSsdRecordDescriptor(&ssd_rec_desc); + } + ssd_rec_desc.GenerateCheckpoint(prefix, tensor_name); + } + return Status::OK(); } void Import(K* key_list, int64* key_file_id_list, diff --git a/tensorflow/core/framework/embedding/ssd_hash_kv.h b/tensorflow/core/framework/embedding/ssd_hash_kv.h index 9e2afc360a9..8040421233e 100644 --- a/tensorflow/core/framework/embedding/ssd_hash_kv.h +++ b/tensorflow/core/framework/embedding/ssd_hash_kv.h @@ -22,6 +22,7 @@ limitations under the License. 
#include "sparsehash/dense_hash_map_lockless" #include "sparsehash/dense_hash_set_lockless" +#include "tensorflow/core/framework/embedding/ssd_record_descriptor.h" #include "tensorflow/core/framework/embedding/emb_file_creator.h" #include "tensorflow/core/framework/embedding/kv_interface.h" #include "tensorflow/core/framework/embedding/value_ptr.h" @@ -35,24 +36,6 @@ namespace tensorflow { template class ValuePtr; -template -struct SsdRecordDescriptor { - //prefix of embedding file - tstring file_prefix; - //keys in ssd storage - std::vector key_list; - //file ids of features - std::vector key_file_id_list; - //offsets in the file of features - std::vector key_offset_list; - //files in ssd storage - std::vector file_list; - //number of invalid records in the file - std::vector invalid_record_count_list; - //number of records in the file - std::vector record_count_list; -}; - namespace embedding { class EmbPosition { public: @@ -83,7 +66,7 @@ class EmbPosition { }; template -class SSDIterator : public Iterator { +class SSDIterator { public: SSDIterator(google::dense_hash_map_lockless* hash_map, const std::vector& emb_files, int64 value_len, @@ -271,19 +254,13 @@ class SSDHashKV : public KVInterface { done_ = true; } - Iterator* GetIterator() override { - return new SSDIterator(&hash_map_, emb_files_, val_len_, - write_buffer_); - } - void SetSsdRecordDescriptor(SsdRecordDescriptor* ssd_rec_desc) { mutex_lock l(compact_save_mu_); - auto ssd_iter = - reinterpret_cast*>(GetIterator()); - for (ssd_iter->SeekToFirst(); ssd_iter->Valid(); ssd_iter->Next()) { - ssd_rec_desc->key_list.emplace_back(ssd_iter->Key()); - ssd_rec_desc->key_file_id_list.emplace_back(ssd_iter->FileId()); - ssd_rec_desc->key_offset_list.emplace_back(ssd_iter->Offset()); + SSDIterator ssd_iter(&hash_map_, emb_files_, val_len_, write_buffer_); + for (ssd_iter.SeekToFirst(); ssd_iter.Valid(); ssd_iter.Next()) { + ssd_rec_desc->key_list.emplace_back(ssd_iter.Key()); + ssd_rec_desc->key_file_id_list.emplace_back(ssd_iter.FileId()); + ssd_rec_desc->key_offset_list.emplace_back(ssd_iter.Offset()); } ssd_rec_desc->file_prefix = path_; diff --git a/tensorflow/core/framework/embedding/ssd_record_descriptor.h b/tensorflow/core/framework/embedding/ssd_record_descriptor.h new file mode 100644 index 00000000000..9d015236934 --- /dev/null +++ b/tensorflow/core/framework/embedding/ssd_record_descriptor.h @@ -0,0 +1,148 @@ +/* Copyright 2022 The DeepRec Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+=======================================================================*/ + +#ifndef TENSORFLOW_CORE_FRAMEWORK_EMBEDDING_SSD_RECORD_DESCRIPTOR_H_ +#define TENSORFLOW_CORE_FRAMEWORK_EMBEDDING_SSD_RECORD_DESCRIPTOR_H_ + +#include +#include +#include +#include + +#include "tensorflow/core/framework/embedding/embedding_var_dump_iterator.h" +#include "tensorflow/core/framework/embedding/kv_interface.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/io/path.h" +#include "tensorflow/core/util/env_var.h" + +namespace tensorflow { +namespace embedding { + +template +class SsdRecordDescriptor { + public: + //prefix of embedding file + tstring file_prefix; + //keys in ssd storage + std::vector key_list; + //file ids of features + std::vector key_file_id_list; + //offsets in the file of features + std::vector key_offset_list; + //files in ssd storage + std::vector file_list; + //number of invalid records in the file + std::vector invalid_record_count_list; + //number of records in the file + std::vector record_count_list; + + void GenerateCheckpoint(const std::string& prefix, + const std::string& var_name) { + DumpSsdMeta(prefix, var_name); + CopyEmbeddingFilesToCkptDir(prefix, var_name); + } + + private: + template + void DumpSection(const std::vector& data_vec, + const std::string& section_str, + BundleWriter* writer, + std::vector& dump_buffer) { + EVVectorDataDumpIterator iter(data_vec); + SaveTensorWithFixedBuffer( + section_str, + writer, dump_buffer.data(), + dump_buffer.size(), &iter, + TensorShape({data_vec.size()})); + } + + void DumpSsdMeta(const std::string& prefix, + const std::string& var_name) { + std::fstream fs; + std::string var_name_temp(var_name); + std::string new_str = "_"; + int64 pos = var_name_temp.find("/"); + while (pos != std::string::npos) { + var_name_temp.replace(pos, 1, new_str.data(), 1); + pos = var_name_temp.find("/"); + } + + std::string ssd_record_path = + prefix + "-" + var_name_temp + "-ssd_record"; + BundleWriter ssd_record_writer(Env::Default(), + ssd_record_path); + size_t bytes_limit = 8 << 20; + std::vector dump_buffer(bytes_limit); + + DumpSection(key_list, "keys", + &ssd_record_writer, dump_buffer); + DumpSection(key_file_id_list, "keys_file_id", + &ssd_record_writer, dump_buffer); + DumpSection(key_offset_list, "keys_offset", + &ssd_record_writer, dump_buffer); + DumpSection(file_list, "files", + &ssd_record_writer, dump_buffer); + DumpSection(invalid_record_count_list, "invalid_record_count", + &ssd_record_writer, dump_buffer); + DumpSection(record_count_list, "record_count", + &ssd_record_writer, dump_buffer); + + ssd_record_writer.Finish(); + } + + void CopyEmbeddingFilesToCkptDir( + const std::string& prefix, + const std::string& var_name) { + std::string var_name_temp(var_name); + std::string new_str = "_"; + int64 pos = var_name_temp.find("/"); + while (pos != std::string::npos) { + var_name_temp.replace(pos, 1, new_str.data(), 1); + pos = var_name_temp.find("/"); + } + + std::string embedding_folder_path = + prefix + "-" + var_name_temp + "-emb_files/"; + Status s = Env::Default()->CreateDir(embedding_folder_path); + if (errors::IsAlreadyExists(s)) { + int64 undeleted_files, undeleted_dirs; + Env::Default()-> + DeleteRecursively(embedding_folder_path, + &undeleted_files, + &undeleted_dirs); + Env::Default()->CreateDir(embedding_folder_path); + } + + for (int64 i = 0; i < file_list.size(); i++) { + int64 file_id = file_list[i]; + std::stringstream old_ss; + old_ss << std::setw(4) << std::setfill('0') << file_id << 
".emb"; + std::string file_path = file_prefix + old_ss.str(); + std::string file_name = file_path.substr(file_path.rfind("/")); + std::stringstream new_ss; + new_ss << file_id << ".emb"; + std::string new_file_path = embedding_folder_path + new_ss.str(); + Status s = Env::Default()->CopyFile(file_path, new_file_path); + if (!s.ok()) { + LOG(FATAL)<<"Copy file "< struct EmbeddingVarContext; +namespace { + const int kSavedPartitionNum = 1000; +} namespace embedding { template @@ -97,22 +101,14 @@ class Storage { virtual int64 Size(int level) const = 0; virtual Status GetSnapshot(std::vector* key_list, std::vector*>* value_ptr_list) = 0; - virtual int64 GetSnapshot(std::vector* key_list, - std::vector* value_list, - std::vector* version_list, - std::vector* freq_list, + virtual Status Save( + const string& tensor_name, + const string& prefix, + BundleWriter* writer, const EmbeddingConfig& emb_config, - FilterPolicy>* filter, - embedding::Iterator** it) = 0; - virtual int64 GetSnapshotWithoutFetchPersistentEmb( - std::vector* key_list, - std::vector* value_list, - std::vector* version_list, - std::vector* freq_list, - const EmbeddingConfig& emb_config, - SsdRecordDescriptor* ssd_rec_desc) = 0; - virtual embedding::Iterator* GetIterator() = 0; - virtual Status Shrink(const ShrinkArgs& shrink_args) = 0; + ShrinkArgs& shrink_args, + int64 value_len, + V* default_value) = 0; virtual Status BatchCommit(const std::vector& keys, const std::vector*>& value_ptrs) = 0; @@ -146,8 +142,6 @@ class Storage { virtual bool IsUseHbm() = 0; virtual bool IsSingleHbm() = 0; virtual bool IsUsePersistentStorage() { return false; }; - virtual void iterator_mutex_lock() = 0; - virtual void iterator_mutex_unlock() = 0; virtual void Schedule(std::function fn) = 0; virtual void CreateEmbeddingMemoryPool( Allocator* alloc, @@ -274,6 +268,95 @@ class Storage { return Status::OK(); } + private: + void GeneratePartitionedCkptData( + const std::vector& key_list, + const std::vector*>& value_ptr_list, + EmbeddingVarCkptData* partitioned_ckpt_data, + const EmbeddingConfig& emb_config, + V* default_value) { + std::vector> + ev_ckpt_data_parts(kSavedPartitionNum); + + bool save_unfiltered_features = true; + TF_CHECK_OK(ReadBoolFromEnvVar( + "TF_EV_SAVE_FILTERED_FEATURES", true, &save_unfiltered_features)); + + bool is_save_freq = emb_config.is_save_freq(); + bool is_save_version = emb_config.is_save_version(); + + for (int64 i = 0; i < key_list.size(); i++) { + for (int part_id = 0; part_id < kSavedPartitionNum; part_id++) { + if (key_list[i] % kSavedPartitionNum == part_id) { + ev_ckpt_data_parts[part_id].Emplace( + key_list[i], value_ptr_list[i], + emb_config, default_value, + GetOffset(emb_config.emb_index), + is_save_freq, + is_save_version, + save_unfiltered_features); + break; + } + } + } + + partitioned_ckpt_data->SetWithPartition(ev_ckpt_data_parts); + } + + void GeneratePartitionedCkptData( + const std::vector& key_list, + const std::vector& value_ptr_list, + EmbeddingVarCkptData* partitioned_ckpt_data) { + std::vector> + ev_ckpt_data_parts(kSavedPartitionNum); + + for (int64 i = 0; i < key_list.size(); i++) { + for (int part_id = 0; part_id < kSavedPartitionNum; part_id++) { + if (key_list[i] % kSavedPartitionNum == part_id) { + ev_ckpt_data_parts[part_id].Emplace( + key_list[i], value_ptr_list[i]); + break; + } + } + } + + partitioned_ckpt_data->SetWithPartition(ev_ckpt_data_parts); + } + + protected: + Status SaveToCheckpoint( + const string& tensor_name, + BundleWriter* writer, + const EmbeddingConfig& emb_config, + 
int64 value_len, + V* default_value, + const std::vector& key_list, + const std::vector*>& value_ptr_list, + ValueIterator* value_iter = nullptr) { + EmbeddingVarCkptData partitioned_ckpt_data; + GeneratePartitionedCkptData(key_list, value_ptr_list, + &partitioned_ckpt_data, emb_config, + default_value); + Status s = + partitioned_ckpt_data.ExportToCkpt( + tensor_name, writer, value_len, value_iter); + return Status::OK(); + } + + Status SaveToCheckpoint( + const string& tensor_name, + BundleWriter* writer, + int64 value_len, + const std::vector& key_list, + const std::vector& value_ptr_list) { + EmbeddingVarCkptData partitioned_ckpt_data; + GeneratePartitionedCkptData(key_list, value_ptr_list, + &partitioned_ckpt_data); + Status s = + partitioned_ckpt_data.ExportToCkpt(tensor_name, writer, value_len); + return Status::OK(); + } + protected: int64 alloc_len_ = 0; int64 total_dims_ = 0; diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD index 23d12c295ca..fc1b2cd9c67 100644 --- a/tensorflow/core/kernels/BUILD +++ b/tensorflow/core/kernels/BUILD @@ -2907,7 +2907,7 @@ tf_kernel_library( hdrs = ["kv_variable_ops.h"], srcs = ["kv_variable_ops.cc", "kv_variable_lookup_ops.cc", - "kv_variable_save_restore_ops.cc"], + "kv_variable_restore_ops.cc"], copts = tf_copts() + ["-g"], deps = [ ":bounds_check", @@ -5453,14 +5453,14 @@ tf_kernel_library( "group_embedding/group_embedding_lookup_sparse_backward_ops.cu.cc",], deps = ["//third_party/eigen3", "//tensorflow/core/kernels:gpu_device_array", + "//tensorflow/core/util/tensor_bundle", ":training_op_helpers", ":variable_ops", + ":save_restore_tensor", "//tensorflow/core:embedding_gpu", "@sparsehash_c11//:dense_hash_map", "@libcuckoo//:libcuckoo", ":unique_ali_op", - ":save_restore_tensor", - "//tensorflow/core/util/tensor_bundle", "@com_github_google_leveldb//:leveldb",] + DYNAMIC_DEPS + mkl_deps() + if_cuda(["@cub_archive//:cub", ":fused_embedding_common_cuh", diff --git a/tensorflow/core/kernels/embedding_variable_ops_test.cc b/tensorflow/core/kernels/embedding_variable_ops_test.cc index 408a2bfd16c..eff4b77c2dc 100644 --- a/tensorflow/core/kernels/embedding_variable_ops_test.cc +++ b/tensorflow/core/kernels/embedding_variable_ops_test.cc @@ -118,86 +118,6 @@ std::vector AllTensorKeys(BundleReader* reader) { return ret; } -TEST(TensorBundleTest, TestEVShrinkL2) { - int64 value_size = 3; - int64 insert_num = 5; - Tensor value(DT_FLOAT, TensorShape({value_size})); - test::FillValues(&value, std::vector(value_size, 1.0)); - //float* fill_v = (float*)malloc(value_size * sizeof(float)); - EmbeddingConfig emb_config = - EmbeddingConfig(0, 0, 1, 1, "", 0, 0, 99999, 14.0); - auto storage = embedding::StorageFactory::Create( - embedding::StorageConfig( - StorageType::DRAM, - "", {1024, 1024, 1024, 1024}, - "light", - emb_config), - cpu_allocator(), - "name"); - auto emb_var = new EmbeddingVar("name", - storage, emb_config, - cpu_allocator()); - emb_var ->Init(value, 1); - - for (int64 i=0; i < insert_num; ++i) { - ValuePtr* value_ptr = nullptr; - Status s = emb_var->LookupOrCreateKey(i, &value_ptr); - typename TTypes::Flat vflat = emb_var->flat(value_ptr, i); - vflat += vflat.constant((float)i); - } - - int size = emb_var->Size(); - embedding::ShrinkArgs shrink_args; - emb_var->Shrink(shrink_args); - LOG(INFO) << "Before shrink size:" << size; - LOG(INFO) << "After shrink size:" << emb_var->Size(); - - ASSERT_EQ(emb_var->Size(), 2); -} - -TEST(TensorBundleTest, TestEVShrinkLockless) { - - int64 value_size = 64; - int64 insert_num = 30; 
- Tensor value(DT_FLOAT, TensorShape({value_size})); - test::FillValues(&value, std::vector(value_size, 9.0)); - float* fill_v = (float*)malloc(value_size * sizeof(float)); - - int steps_to_live = 5; - EmbeddingConfig emb_config = EmbeddingConfig(0, 0, 1, 1, "", steps_to_live); - auto storage = embedding::StorageFactory::Create( - embedding::StorageConfig( - StorageType::DRAM, - "", {1024, 1024, 1024, 1024}, - "normal", - emb_config), - cpu_allocator(), - "name"); - auto emb_var = new EmbeddingVar("name", - storage, emb_config, - cpu_allocator()); - emb_var ->Init(value, 1); - LOG(INFO) << "size:" << emb_var->Size(); - - for (int64 i=0; i < insert_num; ++i) { - ValuePtr* value_ptr = nullptr; - Status s = emb_var->LookupOrCreateKey(i, &value_ptr); - typename TTypes::Flat vflat = emb_var->flat(value_ptr, i); - emb_var->UpdateVersion(value_ptr, i); - } - - int size = emb_var->Size(); - embedding::ShrinkArgs shrink_args; - shrink_args.global_step = insert_num; - emb_var->Shrink(shrink_args); - - LOG(INFO) << "Before shrink size:" << size; - LOG(INFO) << "After shrink size: " << emb_var->Size(); - - ASSERT_EQ(size, insert_num); - ASSERT_EQ(emb_var->Size(), steps_to_live); -} - TEST(EmbeddingVariableTest, TestEmptyEV) { int64 value_size = 8; Tensor value(DT_FLOAT, TensorShape({value_size})); @@ -213,7 +133,9 @@ TEST(EmbeddingVariableTest, TestEmptyEV) { Tensor part_offset_tensor(DT_INT32, TensorShape({kSavedPartitionNum + 1})); BundleWriter writer(Env::Default(), Prefix("foo")); - DumpEmbeddingValues(variable, "var/part_0", &writer, &part_offset_tensor); + embedding::ShrinkArgs shrink_args; + shrink_args.global_step = 1; + variable->Save("var/part_0", Prefix("foo"), &writer, shrink_args); TF_ASSERT_OK(writer.Finish()); { @@ -288,7 +210,9 @@ TEST(EmbeddingVariableTest, TestEVExportSmallLockless) { LOG(INFO) << "size:" << variable->Size(); BundleWriter writer(Env::Default(), Prefix("foo")); - DumpEmbeddingValues(variable, "var/part_0", &writer, &part_offset_tensor); + embedding::ShrinkArgs shrink_args; + shrink_args.global_step = 1; + variable->Save("var/part_0", Prefix("foo"), &writer, shrink_args); TF_ASSERT_OK(writer.Finish()); { @@ -364,7 +288,9 @@ TEST(EmbeddingVariableTest, TestEVExportLargeLockless) { LOG(INFO) << "size:" << variable->Size(); BundleWriter writer(Env::Default(), Prefix("foo")); - DumpEmbeddingValues(variable, "var/part_0", &writer, &part_offset_tensor); + embedding::ShrinkArgs shrink_args; + shrink_args.global_step = 1; + variable->Save("var/part_0", Prefix("foo"), &writer, shrink_args); TF_ASSERT_OK(writer.Finish()); { @@ -444,15 +370,7 @@ TEST(EmbeddingVariableTest, TestMultiInsertion) { t.join(); } - std::vector tot_key_list; - std::vector tot_valueptr_list; - std::vector tot_version_list; - std::vector tot_freq_list; - embedding::Iterator* it = nullptr; - int64 total_size = variable->GetSnapshot(&tot_key_list, &tot_valueptr_list, &tot_version_list, &tot_freq_list, &it); - ASSERT_EQ(variable->Size(), 5); - ASSERT_EQ(variable->Size(), total_size); } void InsertAndLookup(EmbeddingVar* variable, @@ -511,9 +429,7 @@ TEST(EmbeddingVariableTest, TestBloomFilter) { std::vector version_list; std::vector freq_list; - embedding::Iterator* it = nullptr; - var->GetSnapshot(&keylist, &valuelist, &version_list, &freq_list, &it); - ASSERT_EQ(var->Size(), keylist.size()); + ASSERT_EQ(var->Size(), 1); } TEST(EmbeddingVariableTest, TestBloomCounterInt64) { @@ -1093,63 +1009,6 @@ TEST(EmbeddingVariableTest, TestSizeDBKV) { LOG(INFO) << "2 size:" << hashmap->Size(); } 
-TEST(EmbeddingVariableTest, TestSSDIterator) { - std::string temp_dir = testing::TmpDir(); - Allocator* alloc = ev_allocator(); - auto hashmap = new SSDHashKV(temp_dir, alloc); - hashmap->SetTotalDims(126); - ASSERT_EQ(hashmap->Size(), 0); - std::vector*> value_ptrs; - for (int64 i = 0; i < 10; ++i) { - auto tmp = new NormalContiguousValuePtr(alloc, 126); - tmp->SetValue((float)i, 126); - value_ptrs.emplace_back(tmp); - } - for (int64 i = 0; i < 10; i++) { - hashmap->Commit(i, value_ptrs[i]); - } - embedding::Iterator* it = hashmap->GetIterator(); - int64 index = 0; - float val_p[126] = {0.0}; - for (it->SeekToFirst(); it->Valid(); it->Next()) { - int64 key = -1; - it->Key((char*)&key, sizeof(int64)); - it->Value((char*)val_p, 126 * sizeof(float), 0); - ASSERT_EQ(key, index); - for (int i = 0; i < 126; i++) - ASSERT_EQ(val_p[i], key); - index++; - } -} - -TEST(EmbeddingVariableTest, TestLevelDBIterator) { - auto hashmap = new LevelDBKV(testing::TmpDir()); - hashmap->SetTotalDims(126); - ASSERT_EQ(hashmap->Size(), 0); - std::vector*> value_ptrs; - for (int64 i = 0; i < 10; ++i) { - ValuePtr* tmp = - new NormalContiguousValuePtr(ev_allocator(), 126); - tmp->SetValue((float)i, 126); - value_ptrs.emplace_back(tmp); - } - for (int64 i = 0; i < 10; i++) { - hashmap->Commit(i, value_ptrs[i]); - } - embedding::Iterator* it = hashmap->GetIterator(); - int64 index = 0; - float val_p[126] = {0.0}; - for (it->SeekToFirst(); it->Valid(); it->Next()) { - int64 key = -1; - it->Key((char*)&key, sizeof(int64)); - it->Value((char*)val_p, 126 * sizeof(float), 0); - ASSERT_EQ(key, index); - for (int i = 0; i < 126; i++) - ASSERT_EQ(val_p[i], key); - index++; - } -} - TEST(EmbeddingVariableTest, TestLRUCachePrefetch) { BatchCache* cache = new LRUCache(); int num_ids = 5; diff --git a/tensorflow/core/kernels/embedding_variable_performance_test.cc b/tensorflow/core/kernels/embedding_variable_performance_test.cc index ee04b4468f6..9b01e35840b 100644 --- a/tensorflow/core/kernels/embedding_variable_performance_test.cc +++ b/tensorflow/core/kernels/embedding_variable_performance_test.cc @@ -354,17 +354,10 @@ void PerfSave(Tensor& default_value, BundleWriter writer(Env::Default(), Prefix("foo")); timespec start, end; double total_time = 0.0; - if (steps_to_live != 0 || l2_weight_threshold != -1.0) { - clock_gettime(CLOCK_MONOTONIC, &start); - embedding::ShrinkArgs shrink_args; - shrink_args.global_step = 100; - ev->Shrink(shrink_args); - clock_gettime(CLOCK_MONOTONIC, &end); - total_time += (double)(end.tv_sec - start.tv_sec) * - 1000000000 + end.tv_nsec - start.tv_nsec; - } + embedding::ShrinkArgs shrink_args; + shrink_args.global_step = 100; clock_gettime(CLOCK_MONOTONIC, &start); - DumpEmbeddingValues(ev, "var", &writer, &part_offset_tensor); + ev->Save("var", Prefix("foo"), &writer, shrink_args); clock_gettime(CLOCK_MONOTONIC, &end); total_time += (double)(end.tv_sec - start.tv_sec) * 1000000000 + end.tv_nsec - start.tv_nsec; diff --git a/tensorflow/core/kernels/kv_variable_ops.h b/tensorflow/core/kernels/kv_variable_ops.h index b6b29acbedc..8e3572443ba 100644 --- a/tensorflow/core/kernels/kv_variable_ops.h +++ b/tensorflow/core/kernels/kv_variable_ops.h @@ -167,426 +167,6 @@ Status GetInputEmbeddingVar(OpKernelContext* ctx, int input, } } -template -void DumpSsdIndexMeta( - SsdRecordDescriptor& ssd_rec_desc, - const std::string& prefix, - const std::string& var_name) { - std::fstream fs; - std::string var_name_temp(var_name); - std::string new_str = "_"; - int64 pos = var_name_temp.find("/"); - while (pos != 
std::string::npos) { - var_name_temp.replace(pos, 1, new_str.data(), 1); - pos =var_name_temp.find("/"); - } - - std::string ssd_record_path = - prefix + "-" + var_name_temp + "-ssd_record"; - - BundleWriter ssd_record_writer(Env::Default(), - ssd_record_path); - typedef EVFreqDumpIterator Int64DataDumpIterator; - size_t bytes_limit = 8 << 20; - char* dump_buffer = new char[bytes_limit]; - - int64 num_of_keys = ssd_rec_desc.key_list.size(); - EVKeyDumpIterator keys_iter(ssd_rec_desc.key_list); - SaveTensorWithFixedBuffer( - "keys", - &ssd_record_writer, dump_buffer, - bytes_limit, &keys_iter, - TensorShape({num_of_keys})); - - Int64DataDumpIterator key_file_id_iter(ssd_rec_desc.key_file_id_list); - SaveTensorWithFixedBuffer( - "keys_file_id", - &ssd_record_writer, dump_buffer, - bytes_limit, &key_file_id_iter, - TensorShape({num_of_keys})); - - Int64DataDumpIterator key_offset_iter(ssd_rec_desc.key_offset_list); - SaveTensorWithFixedBuffer( - "keys_offset", - &ssd_record_writer, dump_buffer, - bytes_limit, &key_offset_iter, - TensorShape({num_of_keys})); - - int64 num_of_files = ssd_rec_desc.file_list.size(); - Int64DataDumpIterator files_iter(ssd_rec_desc.file_list); - SaveTensorWithFixedBuffer( - "files", - &ssd_record_writer, dump_buffer, - bytes_limit, &files_iter, - TensorShape({num_of_files})); - - Int64DataDumpIterator - invalid_record_count_iter(ssd_rec_desc.invalid_record_count_list); - SaveTensorWithFixedBuffer( - "invalid_record_count", - &ssd_record_writer, dump_buffer, - bytes_limit, &invalid_record_count_iter, - TensorShape({num_of_files})); - - Int64DataDumpIterator - record_count_iter(ssd_rec_desc.record_count_list); - SaveTensorWithFixedBuffer( - "record_count", - &ssd_record_writer, dump_buffer, - bytes_limit, &record_count_iter, - TensorShape({num_of_files})); - - ssd_record_writer.Finish(); - delete[] dump_buffer; -} - -template -void CopyEmbeddingfilesToCkptDir( - const SsdRecordDescriptor& ssd_rec_desc, - const std::string& prefix, - const std::string& var_name) { - std::string var_name_temp(var_name); - std::string new_str = "_"; - int64 pos = var_name_temp.find("/"); - while (pos != std::string::npos) { - var_name_temp.replace(pos, 1, new_str.data(), 1); - pos =var_name_temp.find("/"); - } - - std::string embedding_folder_path = - prefix + "-" + var_name_temp + "-emb_files/"; - Status s = Env::Default()->CreateDir(embedding_folder_path); - if (errors::IsAlreadyExists(s)) { - int64 undeleted_files, undeleted_dirs; - Env::Default()-> - DeleteRecursively(embedding_folder_path, - &undeleted_files, - &undeleted_dirs); - Env::Default()->CreateDir(embedding_folder_path); - } - - for (int64 i = 0; i < ssd_rec_desc.file_list.size(); i++) { - int64 file_id = ssd_rec_desc.file_list[i]; - std::stringstream old_ss; - old_ss << std::setw(4) << std::setfill('0') << file_id << ".emb"; - std::string file_path = ssd_rec_desc.file_prefix + old_ss.str(); - std::string file_name = file_path.substr(file_path.rfind("/")); - std::stringstream new_ss; - new_ss << file_id << ".emb"; - std::string new_file_path = embedding_folder_path + new_ss.str(); - Status s = Env::Default()->CopyFile(file_path, new_file_path); - if (!s.ok()) { - LOG(FATAL)<<"Copy file "< -Status DumpEmbeddingValues(EmbeddingVar* ev, - const string& tensor_key, BundleWriter* writer, - Tensor* part_offset_tensor, - const std::string& prefix = "") { - std::vector tot_key_list; - std::vector tot_valueptr_list; - std::vector tot_version_list; - std::vector tot_freq_list; - std::vector tot_key_filter_list; - std::vector 
tot_freq_filter_list; - std::vector tot_version_filter_list; - embedding::Iterator* it = nullptr; - int64 num_of_keys = 0; - //For the time being, only ev which uses SSD for storage, - //ev->IsUsePersistentStorage() will get true. - if (ev->IsUsePersistentStorage()) { - SsdRecordDescriptor ssd_rec_desc; - num_of_keys = - ev->GetSnapshotWithoutFetchPersistentEmb( - &tot_key_list, - &tot_valueptr_list, - &tot_version_list, - &tot_freq_list, - &ssd_rec_desc); - bool is_primary = (ev->GetEmbeddingIndex() == 0); - if (is_primary) { - DumpSsdIndexMeta(ssd_rec_desc, prefix, tensor_key); - CopyEmbeddingfilesToCkptDir(ssd_rec_desc, prefix, tensor_key); - } - } else { - num_of_keys = ev->GetSnapshot( - &tot_key_list, - &tot_valueptr_list, - &tot_version_list, - &tot_freq_list, &it); - } - - VLOG(1) << "EV:" << tensor_key << ", save size:" << num_of_keys; - int64 iterator_size = 0; - int64 filter_iterator_size = 0; - if (it != nullptr) { - it->SwitchToAdmitFeatures(); - ev->storage()->iterator_mutex_lock(); - for (it->SeekToFirst(); it->Valid(); it->Next()) { - ++iterator_size; - } - it->SwitchToFilteredFeatures(); - for (it->SeekToFirst(); it->Valid(); it->Next()) { - ++filter_iterator_size; - } - } - - std::vector > key_list_parts; - std::vector > valueptr_list_parts; - std::vector > version_list_parts; - std::vector > freq_list_parts; - - std::vector > key_filter_list_parts; - std::vector > version_filter_list_parts; - std::vector > freq_filter_list_parts; - - std::vector partitioned_tot_key_list; - std::vector partitioned_tot_valueptr_list; - std::vector partitioned_tot_version_list; - std::vector partitioned_tot_freq_list; - std::vector partitioned_tot_key_filter_list; - std::vector partitioned_tot_version_filter_list; - std::vector partitioned_tot_freq_filter_list; - std::vector part_filter_offset; - - key_list_parts.resize(kSavedPartitionNum); - valueptr_list_parts.resize(kSavedPartitionNum); - version_list_parts.resize(kSavedPartitionNum); - freq_list_parts.resize(kSavedPartitionNum); - key_filter_list_parts.resize(kSavedPartitionNum); - version_filter_list_parts.resize(kSavedPartitionNum); - freq_filter_list_parts.resize(kSavedPartitionNum); - part_filter_offset.resize(kSavedPartitionNum + 1); - //partitioned_tot_key_list.resize(tot_key_list.size()); - //partitioned_tot_valueptr_list.resize(tot_valueptr_list.size()); - - // save the ev with kSavedPartitionNum piece of tensor - // so that we can dynamically load ev with changed partition number - bool save_unfiltered_features = true; - TF_CHECK_OK(ReadBoolFromEnvVar( - "TF_EV_SAVE_FILTERED_FEATURES", true, &save_unfiltered_features)); - int64 filter_freq = ev->MinFreq(); - for (size_t i = 0; i < tot_key_list.size(); i++) { - for (int partid = 0; partid < kSavedPartitionNum; partid++) { - if (tot_key_list[i] % kSavedPartitionNum == partid) { - if (tot_valueptr_list[i] == reinterpret_cast(-1)) { - // only forward, no backward, bypass - } else if (tot_valueptr_list[i] == nullptr) { - if (filter_freq != 0) { - if (save_unfiltered_features) { - key_filter_list_parts[partid].push_back(tot_key_list[i]); - } - } else { - key_list_parts[partid].push_back(tot_key_list[i]); - valueptr_list_parts[partid].push_back( - ev->GetDefaultValue(tot_key_list[i])); - } - } else { - key_list_parts[partid].push_back(tot_key_list[i]); - valueptr_list_parts[partid].push_back(tot_valueptr_list[i]); - } - break; - } - } - } - - for (size_t i = 0; i < tot_version_list.size(); i++) { - for (int partid = 0; partid < kSavedPartitionNum; partid++) { - if (tot_key_list[i] % 
kSavedPartitionNum == partid) { - if (tot_valueptr_list[i] == reinterpret_cast(-1)) { - // only forward, no backward, bypass - } else if (tot_valueptr_list[i] == nullptr) { - if (filter_freq != 0) { - if (save_unfiltered_features) { - version_filter_list_parts[partid].push_back(tot_version_list[i]); - } - } else { - version_list_parts[partid].push_back(tot_version_list[i]); - } - } else { - version_list_parts[partid].push_back(tot_version_list[i]); - } - break; - } - } - } - - for (size_t i = 0; i < tot_freq_list.size(); i++) { - for (int partid = 0; partid < kSavedPartitionNum; partid++) { - if (tot_key_list[i] % kSavedPartitionNum == partid) { - if (tot_valueptr_list[i] == reinterpret_cast(-1)) { - // only forward, no backward, bypass - } else if (tot_valueptr_list[i] == nullptr) { - if (filter_freq != 0) { - if (save_unfiltered_features) { - freq_filter_list_parts[partid].push_back(tot_freq_list[i]); - } - } else { - freq_list_parts[partid].push_back(tot_freq_list[i]); - } - } else { - freq_list_parts[partid].push_back(tot_freq_list[i]); - } - break; - } - } - } - - auto part_offset_flat = part_offset_tensor->flat(); - part_offset_flat(0) = 0; - part_filter_offset[0] = 0; - int ptsize = 0; - for (int partid = 0; partid < kSavedPartitionNum; partid++) { - std::vector& key_list = key_list_parts[partid]; - std::vector& valueptr_list = valueptr_list_parts[partid]; - std::vector& version_list = version_list_parts[partid]; - std::vector& freq_list = freq_list_parts[partid]; - std::vector& key_filter_list = key_filter_list_parts[partid]; - std::vector& version_filter_list = - version_filter_list_parts[partid]; - std::vector& freq_filter_list = freq_filter_list_parts[partid]; - - ptsize += key_list.size(); - for (int inpid = 0; inpid < key_list.size(); inpid++) { - partitioned_tot_key_list.push_back(key_list[inpid]); - partitioned_tot_valueptr_list.push_back(valueptr_list[inpid]); - } - for (int inpid = 0; inpid < version_list.size(); inpid++) { - partitioned_tot_version_list.push_back(version_list[inpid]); - } - for (int inpid = 0; inpid < freq_list.size(); inpid++) { - partitioned_tot_freq_list.push_back(freq_list[inpid]); - } - for (int inpid = 0; inpid < key_filter_list.size(); inpid++) { - partitioned_tot_key_filter_list.push_back(key_filter_list[inpid]); - } - for (int inpid = 0; inpid < version_filter_list.size(); inpid++) { - partitioned_tot_version_filter_list.push_back(version_filter_list[inpid]); - } - for (int inpid = 0; inpid < freq_filter_list.size(); inpid++) { - partitioned_tot_freq_filter_list.push_back(freq_filter_list[inpid]); - } - - part_offset_flat(partid + 1) = part_offset_flat(partid) + key_list.size(); - part_filter_offset[partid + 1] = part_filter_offset[partid] + key_filter_list.size(); - } - // TODO: DB iterator not support partition_offset - if (it != nullptr) { - it->SetPartOffset((int32*)part_offset_tensor->data()); - } - writer->Add(tensor_key + "-partition_offset", *part_offset_tensor); - for(int i = 0; i < kSavedPartitionNum + 1; i++) { - part_offset_flat(i) = part_filter_offset[i]; - } - if (it != nullptr) { - it->SetPartFilterOffset((int32*)part_offset_tensor->data()); - } - writer->Add(tensor_key + "-partition_filter_offset", *part_offset_tensor); - - VLOG(1) << "EV before partition:" << tensor_key << ", keysize:" << tot_key_list.size() - << ", valueptr size:" << tot_valueptr_list.size(); - VLOG(1) << "EV after partition:" << tensor_key << ", ptsize:" << ptsize - << ", keysize:"<< partitioned_tot_key_list.size() - <<", valueptr size:" << 
partitioned_tot_valueptr_list.size(); - - size_t bytes_limit = 8 << 20; - char* dump_buffer = (char*)malloc(sizeof(char) * bytes_limit); - Status st; - if (it != nullptr) { - it->SwitchToAdmitFeatures(); - } - EVKeyDumpIterator ev_key_dump_iter(partitioned_tot_key_list); - st = SaveTensorWithFixedBuffer(tensor_key + "-keys", writer, dump_buffer, - bytes_limit, &ev_key_dump_iter, - TensorShape({partitioned_tot_key_list.size() + iterator_size}), - it); - if (!st.ok()) { - free(dump_buffer); - return st; - } - - EVValueDumpIterator ev_value_dump_iter(ev, partitioned_tot_valueptr_list); - st = SaveTensorWithFixedBuffer(tensor_key + "-values", writer, dump_buffer, - bytes_limit, &ev_value_dump_iter, - TensorShape({partitioned_tot_key_list.size() + iterator_size, ev->ValueLen()}), - it, ev->storage()->GetOffset(ev->GetEmbeddingIndex())); - if (!st.ok()) { - free(dump_buffer); - return st; - } - - EVVersionDumpIterator ev_version_dump_iter(partitioned_tot_version_list); - st = SaveTensorWithFixedBuffer(tensor_key + "-versions", writer, dump_buffer, - bytes_limit, &ev_version_dump_iter, - TensorShape({partitioned_tot_version_list.size() + iterator_size}), - it, -3); - if (!st.ok()) { - free(dump_buffer); - return st; - } - - EVFreqDumpIterator ev_freq_dump_iter(partitioned_tot_freq_list); - st = SaveTensorWithFixedBuffer(tensor_key + "-freqs", writer, dump_buffer, - bytes_limit, &ev_freq_dump_iter, - TensorShape({partitioned_tot_freq_list.size() + iterator_size}), - it, -2); - if (!st.ok()) { - free(dump_buffer); - return st; - } - if (it != nullptr) { - it->SwitchToFilteredFeatures(); - } - EVKeyDumpIterator ev_key_filter_dump_iter(partitioned_tot_key_filter_list); - st = SaveTensorWithFixedBuffer(tensor_key + "-keys_filtered", - writer, dump_buffer, bytes_limit, &ev_key_filter_dump_iter, - TensorShape({partitioned_tot_key_filter_list.size() - + filter_iterator_size}), it); - if (!st.ok()) { - free(dump_buffer); - return st; - } - - EVVersionDumpIterator ev_version_filter_dump_iter( - partitioned_tot_version_filter_list); - st = SaveTensorWithFixedBuffer(tensor_key + "-versions_filtered", - writer, dump_buffer, bytes_limit, &ev_version_filter_dump_iter, - TensorShape({partitioned_tot_version_filter_list.size() - + filter_iterator_size}), it, -3); - if (!st.ok()) { - free(dump_buffer); - return st; - } - - EVFreqDumpIterator ev_freq_filter_dump_iter( - partitioned_tot_freq_filter_list); - st = SaveTensorWithFixedBuffer(tensor_key + "-freqs_filtered", - writer, dump_buffer, bytes_limit, &ev_freq_filter_dump_iter, - TensorShape({partitioned_tot_freq_filter_list.size() - + filter_iterator_size}), it, -2); - if (!st.ok()) { - free(dump_buffer); - return st; - } - - free(dump_buffer); - - if (it != nullptr) { - ev->storage()->iterator_mutex_unlock(); - delete it; - } - - if (ev->IsSingleHbm() && tot_valueptr_list.size() > 0) { - TypedAllocator::Deallocate( - cpu_allocator(), tot_valueptr_list[0], - tot_valueptr_list.size() * ev->ValueLen()); - } - return Status::OK(); -} - Status MoveMatchingFiles( Env* env, const tstring& pattern, diff --git a/tensorflow/core/kernels/kv_variable_save_restore_ops.cc b/tensorflow/core/kernels/kv_variable_restore_ops.cc similarity index 86% rename from tensorflow/core/kernels/kv_variable_save_restore_ops.cc rename to tensorflow/core/kernels/kv_variable_restore_ops.cc index fa7e043ffd3..23a504eea5d 100644 --- a/tensorflow/core/kernels/kv_variable_save_restore_ops.cc +++ b/tensorflow/core/kernels/kv_variable_restore_ops.cc @@ -508,86 +508,6 @@ 
TF_CALL_GPU_NUMBER_TYPES(REGISTER_KERNELS_GPU) #undef REGISTER_KERNELS_GPU #endif // GOOGLE_CUDA -#undef REGISTER_KERNELS_ALL -#undef REGISTER_KERNELS - -// Op that outputs tensors of all keys and all values. -template -class KvResourceExportOp : public OpKernel { - public: - explicit KvResourceExportOp(OpKernelConstruction *ctx) : OpKernel(ctx) {} - - void Compute(OpKernelContext *ctx) override { - EmbeddingVar *ev = nullptr; - OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &ev)); - core::ScopedUnref unref_me(ev); - std::vector tot_key_list; - std::vector tot_valueptr_list; - std::vector tot_version_list; - std::vector tot_freq_list; - embedding::Iterator* it = nullptr; - int64 total_size = ev->GetSnapshot( - &tot_key_list, &tot_valueptr_list, &tot_version_list, - &tot_freq_list, &it); - - // Create an output tensor - Tensor *keys_output_tensor = NULL; - Tensor *values_output_tensor = NULL; - Tensor *versions_output_tensor = NULL; - Tensor *freq_output_tensor = NULL; - - OP_REQUIRES_OK(ctx, ctx->allocate_output( - 0, TensorShape({total_size}), &keys_output_tensor)); - OP_REQUIRES_OK(ctx, ctx->allocate_output( - 1, TensorShape({total_size, ev->ValueLen()}), - &values_output_tensor)); - OP_REQUIRES_OK(ctx, ctx->allocate_output( - 2, TensorShape({tot_version_list.size()}), - &versions_output_tensor)); - OP_REQUIRES_OK(ctx, ctx->allocate_output( - 3, TensorShape({tot_freq_list.size()}), - &freq_output_tensor)); - - auto keys_output = keys_output_tensor->template flat(); - auto val_matrix = values_output_tensor->matrix(); - auto versions_output = versions_output_tensor->template flat(); - auto freq_output = freq_output_tensor->template flat(); - - for(size_t i = 0; i < total_size; i++) { - keys_output(i) = tot_key_list[i]; - TValue *value = tot_valueptr_list[i]; - for(int64 m = 0; m < ev->ValueLen(); m++) { - val_matrix(i, m) = *(value + m); - } - if (tot_version_list.size() != 0) { - versions_output(i) = tot_version_list[i]; - } - if (tot_freq_list.size() != 0) { - freq_output(i) = tot_freq_list[i]; - } - } - } -}; - -#define REGISTER_KERNELS(dev, ktype, vtype) \ - REGISTER_KERNEL_BUILDER(Name("KvResourceExport") \ - .Device(DEVICE_##dev) \ - .TypeConstraint("Tkeys") \ - .TypeConstraint("Tvalues"), \ - KvResourceExportOp); -#define REGISTER_KERNELS_ALL(dev, type) \ - REGISTER_KERNELS(dev, int32, type) \ - REGISTER_KERNELS(dev, int64, type) -#define REGISTER_KERNELS_CPU(type) REGISTER_KERNELS_ALL(CPU, type) -TF_CALL_FLOAT_TYPES(REGISTER_KERNELS_CPU) -#undef REGISTER_KERNELS_CPU - -#if GOOGLE_CUDA -#define REGISTER_KERNELS_GPU(type) REGISTER_KERNELS_ALL(GPU, type) -TF_CALL_GPU_NUMBER_TYPES(REGISTER_KERNELS_GPU) -#undef REGISTER_KERNELS_GPU -#endif // GOOGLE_CUDA - #undef REGISTER_KERNELS_ALL #undef REGISTER_KERNELS } // namespace tensorflow diff --git a/tensorflow/core/kernels/save_restore_tensor.h b/tensorflow/core/kernels/save_restore_tensor.h index 55572eabfb2..4f69ebe3fb5 100644 --- a/tensorflow/core/kernels/save_restore_tensor.h +++ b/tensorflow/core/kernels/save_restore_tensor.h @@ -120,10 +120,6 @@ Status SaveTensorWithFixedBuffer(const string& tensor_name, size_t bytes_limit, DumpIterator* dump_iter, const TensorShape& dump_tensor_shape, - embedding::Iterator* it = nullptr, - // -1: save key, x_offset: save embedding(primary or slot offset) - // -2: save frequency, -3: save version - int64 value_offset = -1, bool use_shape = true) { bool dump_happened = false; size_t bytes_written = 0; @@ -149,55 +145,7 @@ Status SaveTensorWithFixedBuffer(const string& tensor_name, 
bytes_written += sizeof(T); total_bytes_written += sizeof(T); } - if (it != nullptr) { - int64 size = 0; - if (value_offset < 0) { - size = sizeof(T); - } else { - size = sizeof(T) * dump_tensor_shape.dim_size(1); - } - char val[size] = {0}; - for (it->SeekToFirst(); it->Valid(); it->Next()) { - int64 dim = 0; - void* start = nullptr; - if (value_offset < 0) { - if (value_offset == -1){ - it->Key(val, sizeof(T)); - } else if (value_offset == -2) { - it->Freq(val, sizeof(T)); - } else { - it->Version(val, sizeof(T)); - } - if (bytes_written + sizeof(T) > bytes_limit) { - dump_happened = true; - writer->AppendSegmentData(dump_buffer, bytes_written); - bytes_written = 0; - buffer_idx = 0; - } - key_dump_buffer[buffer_idx] = *((T*)val); - buffer_idx++; - bytes_written += sizeof(T); - total_bytes_written += sizeof(T); - - } else { - dim = dump_tensor_shape.dim_size(1); - it->Value(val, dim * sizeof(T), value_offset * sizeof(T)); - - for (int j = 0; j < dim; ++j) { - if (bytes_written + sizeof(T) > bytes_limit) { - dump_happened = true; - writer->AppendSegmentData(dump_buffer, bytes_written); - bytes_written = 0; - buffer_idx = 0; - } - key_dump_buffer[buffer_idx] = *((T*)val + j); - buffer_idx++; - bytes_written += sizeof(T); - total_bytes_written += sizeof(T); - } - } - } - } + if (!dump_happened) { VLOG(1) << tensor_name << " only one buffer written, size:" << bytes_written; writer->AddCompeleteData(dump_buffer, bytes_written); diff --git a/tensorflow/core/kernels/save_restore_v2_ops.cc b/tensorflow/core/kernels/save_restore_v2_ops.cc index 313f6b81825..ace7667864c 100644 --- a/tensorflow/core/kernels/save_restore_v2_ops.cc +++ b/tensorflow/core/kernels/save_restore_v2_ops.cc @@ -126,19 +126,14 @@ class SaveV2 : public OpKernel { LookupResource(context, HandleFromInput(context, variable_index), &variable)); const Tensor& global_step = context->input(3); - Tensor part_offset_tensor; - context->allocate_temp(DT_INT32, - TensorShape({kSavedPartitionNum + 1}), - &part_offset_tensor); TGlobalStep global_step_scalar = global_step.scalar()(); core::ScopedUnref s(variable); embedding::ShrinkArgs shrink_args; shrink_args.global_step = global_step_scalar; - OP_REQUIRES_OK(context, variable->Shrink(shrink_args)); const Tensor& prefix = context->input(0); const string& prefix_string = prefix.scalar()(); - OP_REQUIRES_OK(context, DumpEmbeddingValues(variable, tensor_name, - &writer, &part_offset_tensor, prefix_string)); + OP_REQUIRES_OK(context, variable->Save(tensor_name, + prefix_string, &writer, shrink_args)); } void Compute(OpKernelContext* context) override { @@ -304,19 +299,14 @@ class SaveV3 : public OpKernel { EmbeddingVar* variable, const string& tensor_name, BundleWriter& writer) { const Tensor& global_step = context->input(5); - Tensor part_offset_tensor; - context->allocate_temp(DT_INT32, - TensorShape({kSavedPartitionNum + 1}), - &part_offset_tensor); TGlobalStep global_step_scalar = global_step.scalar()(); core::ScopedUnref s(variable); embedding::ShrinkArgs shrink_args; shrink_args.global_step = global_step_scalar; - OP_REQUIRES_OK(context, variable->Shrink(shrink_args)); const Tensor& prefix = context->input(0); const string& prefix_string = prefix.scalar()(); - OP_REQUIRES_OK(context, DumpEmbeddingValues(variable, tensor_name, - &writer, &part_offset_tensor, prefix_string)); + OP_REQUIRES_OK(context, variable->Save(tensor_name, + prefix_string, &writer, shrink_args)); } void Compute(OpKernelContext* context) override { diff --git a/tensorflow/python/ops/embedding_variable_ops_test.py 
b/tensorflow/python/ops/embedding_variable_ops_test.py index fae4ef91380..d3e453df9d1 100644 --- a/tensorflow/python/ops/embedding_variable_ops_test.py +++ b/tensorflow/python/ops/embedding_variable_ops_test.py @@ -453,39 +453,6 @@ def testEmbeddingVariableForLookupInt32(self): save_path = saver.save(sess, model_path, global_step=12345) saver.restore(sess, save_path) - def testEmbeddingVariableForExport(self): - print("testEmbeddingVariableForExport") - with ops.device('/cpu:0'): - ev_config = variables.EmbeddingVariableOption(filter_option=variables.CounterFilter(filter_freq=1)) - var = variable_scope.get_embedding_variable("var_1", embedding_dim=3, - initializer=init_ops.ones_initializer(dtypes.float32), steps_to_live=10000, ev_option=ev_config) - emb = embedding_ops.embedding_lookup(var, math_ops.cast([0,1,2,5,6,7], dtypes.int64)) - fun = math_ops.multiply(emb, 0.0, name='multiply') - loss = math_ops.reduce_sum(fun, name='reduce_sum') - opt = adam.AdamOptimizer(0.01) - g_v = opt.compute_gradients(loss) - gs = training_util.get_or_create_global_step() - train_op = opt.apply_gradients(g_v, gs) - init = variables.global_variables_initializer() - keys, values, versions, freqs = var.export() - with self.test_session() as sess: - sess.run(ops.get_collection(ops.GraphKeys.EV_INIT_VAR_OPS)) - sess.run(ops.get_collection(ops.GraphKeys.EV_INIT_SLOT_OPS)) - sess.run([init]) - sess.run(train_op) - sess.run(train_op) - fetches = sess.run([keys, values, versions, freqs]) - print(fetches) - self.assertAllEqual([0, 1, 2, 5, 6, 7], fetches[0]) - self.assertAllEqual([[1., 1., 1.], - [1., 1., 1.], - [1., 1., 1.], - [1., 1., 1.], - [1., 1., 1.], - [1., 1., 1.]], fetches[1]) - self.assertAllEqual([1, 1, 1, 1, 1, 1], fetches[2]) - self.assertAllEqual([1, 1, 1, 1, 1, 1], fetches[3]) - def testEmbeddingVariableForGetShape(self): print("testEmbeddingVariableForGetShape") with ops.device("/cpu:0"): @@ -614,6 +581,35 @@ def testCategoricalColumnWithEmbeddingVariableFunction(self): for i in range(ids[col_name].shape.as_list()[0]): self.assertAllEqual(val_list[0][i], val_list[1][i]) + def testEmbeddinVariableForPartitionOffset(self): + print("testEmbeddinVariableForPartitionOffset") + checkpoint_directory = self.get_temp_dir() + with ops.device("/cpu:0"): + var = variable_scope.get_embedding_variable("var_1", embedding_dim = 3) + emb = embedding_ops.embedding_lookup(var, math_ops.cast([0, 1, 1000, 1001, 2, 1002], dtypes.int64)) + fun = math_ops.multiply(emb, 2.0, name='multiply') + loss = math_ops.reduce_sum(fun, name='reduce_sum') + opt = adagrad.AdagradOptimizer(0.1) + g_v = opt.compute_gradients(loss) + train_op = opt.apply_gradients(g_v) + saver = saver_module.Saver(sharded=True) + init = variables.global_variables_initializer() + with self.test_session() as sess: + sess.run([init]) + sess.run(train_op) + model_path = os.path.join(checkpoint_directory, "model.ckpt") + saver.save(sess, model_path) + + for name, shape in checkpoint_utils.list_variables(model_path): + if "partition_offset" in name: + self.assertEqual(shape[0], 1001) + part_offset = checkpoint_utils.load_variable(model_path, name) + self.assertEqual(part_offset[0], 0) + self.assertEqual(part_offset[1], 2) + self.assertEqual(part_offset[2], 4) + for i in range(3, len(part_offset)): + self.assertEqual(part_offset[i], 6) + def testEmbeddingVariableForL2FeatureEvictionFromContribFeatureColumn(self): print("testEmbeddingVariableForL2FeatureEvictionFromContribFeatureColumn") checkpoint_directory = self.get_temp_dir() @@ -1829,6 +1825,76 @@ def 
runTestAdagrad(self, var, g): del os.environ["TF_SSDHASH_ASYNC_COMPACTION"] + def testEmbeddingVariableForDramAndLevelDBSaveCkpt(self): + print("testEmbeddingVariableForDramAndLevelDBSaveCkpt") + checkpoint_directory = self.get_temp_dir() + def runTestAdagrad(self, var, g): + ids = array_ops.placeholder(dtypes.int64, name="ids") + emb = embedding_ops.embedding_lookup(var, ids) + fun = math_ops.multiply(emb, 2.0, name='multiply') + loss = math_ops.reduce_sum(fun, name='reduce_sum') + gs = training_util.get_or_create_global_step() + opt = adagrad.AdagradOptimizer(0.1) + g_v = opt.compute_gradients(loss) + train_op = opt.apply_gradients(g_v, global_step=gs) + saver = saver_module.Saver() + init = variables.global_variables_initializer() + model_path = os.path.join(checkpoint_directory, + "model1.ckpt") + tires = kv_variable_ops.lookup_tier(emb_var, + math_ops.cast([0,1,2,3,4,5,6,7,8,9,10,11], dtypes.int64)) + with self.test_session(graph=g) as sess: + sess.run(ops.get_collection(ops.GraphKeys.EV_INIT_VAR_OPS)) + sess.run(ops.get_collection(ops.GraphKeys.EV_INIT_SLOT_OPS)) + sess.run([init]) + sess.run([train_op], {ids:[0,1,2,3,4,5]}) + sess.run([train_op], {ids:[6,7,8,9,10,11]}) + sess.run([train_op], {ids:[0,1,2,3,4,5]}) + result = sess.run(tires) + for i in range(0, 12): + if i in range(0, 6): + self.assertEqual(result[i], 0) + else: + self.assertEqual(result[i], 1) + saver.save(sess, model_path) + for name, shape in checkpoint_utils.list_variables(model_path): + if name == "var_1-keys" or name == "var_1/Adagrad-keys": + self.assertEqual(shape[0], 12) + keys = checkpoint_utils.load_variable(model_path, name) + self.assertAllEqual(np.array([0,1,2,3,4,5,6,7,8,9,10,11]), keys) + if name == "var_1-freqs" or name == "var_1/Adagrad-freqs": + freqs = checkpoint_utils.load_variable(model_path, name) + self.assertAllEqual(np.array([2,2,2,2,2,2,1,1,1,1,1,1]), freqs) + if name == "var_1/Adagrad-values": + values = checkpoint_utils.load_variable(model_path, name) + for i in range(0, shape[0]): + for j in range(0, shape[1]): + if i < 6: + self.assertAlmostEqual(values[i][j], 8.1, delta=1e-05) + else: + self.assertAlmostEqual(values[i][j], 4.1, delta=1e-05) + if name == "var_1-values": + values = checkpoint_utils.load_variable(model_path, name) + for i in range(0, shape[0]): + for j in range(0, shape[1]): + if i < 6: + self.assertAlmostEqual(values[i][j], 0.8309542, delta=1e-05) + else: + self.assertAlmostEqual(values[i][j], 0.90122706, delta=1e-05) + + with ops.Graph().as_default() as g, ops.device('/cpu:0'): + storage_option = variables.StorageOption( + storage_type=config_pb2.StorageType.DRAM_LEVELDB, + storage_path = checkpoint_directory, + storage_size=[1024 * 6]) + ev_option = variables.EmbeddingVariableOption( + storage_option=storage_option) + emb_var = variable_scope.get_embedding_variable("var_1", + embedding_dim = 128, + initializer=init_ops.ones_initializer(dtypes.float32), + ev_option = ev_option) + runTestAdagrad(self, emb_var, g) + @test_util.run_gpu_only def testEmbeddingVariableForHBMandDRAMSaveCkpt(self): print("testEmbeddingVariableForHBMandDRAMSaveCkpt")