From e919596f4fdda612ea6453b1d013bf6513b33d7e Mon Sep 17 00:00:00 2001 From: "chenbangduo.cbd" Date: Wed, 10 Apr 2024 10:41:20 +0800 Subject: [PATCH] [EmbeddingVar] Add KvResourceCleanUpOp. Signed-off-by: chenbangduo.cbd --- .../core/framework/embedding/embedding_var.h | 6 +++ .../embedding/feature_descriptor_impl.h | 1 + .../framework/embedding/multi_tier_storage.h | 5 ++ .../framework/embedding/single_tier_storage.h | 38 ++++++++++++++ tensorflow/core/framework/embedding/storage.h | 2 + tensorflow/core/framework/variable.proto | 2 + tensorflow/core/kernels/kv_variable_ops.cc | 51 +++++++++++++++++++ .../core/kernels/kv_variable_restore_ops.cc | 4 +- tensorflow/core/ops/kv_variable_ops.cc | 9 ++++ tensorflow/python/ops/kv_variable_ops.py | 27 +++++++++- 10 files changed, 141 insertions(+), 4 deletions(-) diff --git a/tensorflow/core/framework/embedding/embedding_var.h b/tensorflow/core/framework/embedding/embedding_var.h index 81941bc9ff9..a74fe725f4d 100644 --- a/tensorflow/core/framework/embedding/embedding_var.h +++ b/tensorflow/core/framework/embedding/embedding_var.h @@ -733,6 +733,12 @@ class EmbeddingVar : public ResourceBase { return filter_; } + void CleanUp() { + if (emb_config_.is_primary() && emb_config_.primary_emb_index == 0) { + storage_->CleanUp(); + } + } + protected: ~EmbeddingVar() override { // When dynamic dimension embedding is used, diff --git a/tensorflow/core/framework/embedding/feature_descriptor_impl.h b/tensorflow/core/framework/embedding/feature_descriptor_impl.h index 6996d22f447..444d5990cd1 100644 --- a/tensorflow/core/framework/embedding/feature_descriptor_impl.h +++ b/tensorflow/core/framework/embedding/feature_descriptor_impl.h @@ -74,6 +74,7 @@ class NonFreqDescriptor: public BaseFreqDescriptor { public: int64 GetFreq(void* value_ptr) override { LOG(FATAL)<<"Can not get freq from NonFreqCounter."; + return 0; } BaseFreqDescriptor* Clone() override { diff --git a/tensorflow/core/framework/embedding/multi_tier_storage.h b/tensorflow/core/framework/embedding/multi_tier_storage.h index e27521f1a65..f2faab6d8f8 100644 --- a/tensorflow/core/framework/embedding/multi_tier_storage.h +++ b/tensorflow/core/framework/embedding/multi_tier_storage.h @@ -248,6 +248,11 @@ class MultiTierStorage : public Storage { FeatureDescriptor* hbm_feat_desc, FeatureDescriptor* dram_feat_desc); #endif //GOOGL_CUDA + + void CleanUp() { + LOG(FATAL) << "Function [CleanUp] of MultiTierStorage is not implemented."; + } + private: virtual Status EvictionWithDelayedDestroy(K* evict_ids, int64 evict_size) {} diff --git a/tensorflow/core/framework/embedding/single_tier_storage.h b/tensorflow/core/framework/embedding/single_tier_storage.h index 1c6bdd90790..a3df786ad73 100644 --- a/tensorflow/core/framework/embedding/single_tier_storage.h +++ b/tensorflow/core/framework/embedding/single_tier_storage.h @@ -307,6 +307,18 @@ class SingleTierStorage : public Storage { false/*to_dram*/, is_incr, restore_buff); return s; } + + void CleanUp() override { + std::vector key_list; + std::vector value_ptr_list; + kv_->GetSnapshot(&key_list, &value_ptr_list); + + int list_size = key_list.size(); + for (int i = 0; i < list_size; i++) { + kv_->Remove(key_list[i]); + feat_desc_->Deallocate(value_ptr_list[i]); + } + } protected: virtual void Shrink(std::vector& key_list, @@ -453,6 +465,10 @@ class HbmStorage : public SingleTierStorage { GPUHashTable* HashTable() override { return SingleTierStorage::kv_->HashTable(); } + + void CleanUp() override { + LOG(FATAL) << "Function [CleanUp] of HbmStorage is not implemented."; + } protected: Status RestoreFeatures(int64 key_num, int bucket_num, int64 partition_id, int64 partition_num, int64 value_len, bool is_filter, @@ -495,6 +511,11 @@ class HbmStorageWithCpuKv: public SingleTierStorage { Status TryInsert(K key, void* value_ptr) { return SingleTierStorage::kv_->Insert(key, value_ptr); } + + void CleanUp() override { + LOG(FATAL) << "Function [CleanUp] of HbmStorageWithCPUKv is not implemented."; + } + public: friend class HbmDramStorage; friend class HbmDramSsdStorage; @@ -521,6 +542,10 @@ class PmemMemkindStorage : public SingleTierStorage { } ~PmemMemkindStorage() override {} + void CleanUp() override { + LOG(FATAL) << "Function [CleanUp] of PmemMemkindStorage is not implemented."; + } + TF_DISALLOW_COPY_AND_ASSIGN(PmemMemkindStorage); }; @@ -537,6 +562,10 @@ class PmemLibpmemStorage : public SingleTierStorage { return SingleTierStorage::kv_->Commit(keys, value_ptr); } + void CleanUp() override { + LOG(FATAL) << "Function [CleanUp] of PmemLibpmemStorage is not implemented."; + } + TF_DISALLOW_COPY_AND_ASSIGN(PmemLibpmemStorage); protected: @@ -577,6 +606,11 @@ class LevelDBStore : public SingleTierStorage { key_list, emb_index, value_len, leveldb_kv, SingleTierStorage::feat_desc_); } + + void CleanUp() override { + LOG(FATAL) << "Function [CleanUp] of LevelDBStorage is not implemented."; + } + public: friend class DramLevelDBStore; }; @@ -646,6 +680,10 @@ class SsdHashStorage : public SingleTierStorage { reinterpret_cast*>(SingleTierStorage::kv_); ssd_kv->SetSsdRecordDescriptor(ssd_rec_desc); } + + void CleanUp() override { + LOG(FATAL) << "Function [CleanUp] of SsdHashStorage is not implemented."; + } public: friend class DramSsdHashStorage; #if GOOGLE_CUDA diff --git a/tensorflow/core/framework/embedding/storage.h b/tensorflow/core/framework/embedding/storage.h index 559588af7e1..0fa2a4e7411 100644 --- a/tensorflow/core/framework/embedding/storage.h +++ b/tensorflow/core/framework/embedding/storage.h @@ -210,6 +210,8 @@ class Storage { return Status::OK(); } + virtual void CleanUp() = 0; + protected: virtual Status RestoreSSD(int64 emb_index, int64 emb_slot_num, int64 value_len, diff --git a/tensorflow/core/framework/variable.proto b/tensorflow/core/framework/variable.proto index 5f9e0f16b5d..4129cbab6a0 100644 --- a/tensorflow/core/framework/variable.proto +++ b/tensorflow/core/framework/variable.proto @@ -76,6 +76,8 @@ message VariableDef { bool is_embedding_var = 91; string initialize_op_for_restore = 92; + + string clean_up_op_name = 93; } message SaveSliceInfoDef { diff --git a/tensorflow/core/kernels/kv_variable_ops.cc b/tensorflow/core/kernels/kv_variable_ops.cc index b7567ffe924..1d44c7203c0 100644 --- a/tensorflow/core/kernels/kv_variable_ops.cc +++ b/tensorflow/core/kernels/kv_variable_ops.cc @@ -557,5 +557,56 @@ TF_CALL_GPU_NUMBER_TYPES(REGISTER_KERNELS_GPU) #undef REGISTER_KERNELS_ALL #undef REGISTER_KERNELS + +template +class KvResourceCleanUpOp : public OpKernel { + public: + explicit KvResourceCleanUpOp(OpKernelConstruction* ctx) + :OpKernel(ctx) { + TF_CHECK_OK(ReadBoolFromEnvVar("ENABLE_EV_CLEAN_UP", false, &enable_)); + } + + void Compute(OpKernelContext* ctx) { + if (!enable_) { + return; + } + + EmbeddingVar* ev = nullptr; + Status s = LookupResource(ctx, HandleFromInput(ctx, 0), &ev); + + if (s.ok()) { + ev->Unref(); + ev->CleanUp(); + } + } + + private: + bool enable_; +}; + +#define REGISTER_KERNELS(dev, ktype, vtype) \ + REGISTER_KERNEL_BUILDER(Name("KvResourceCleanUp") \ + .Device(DEVICE_##dev) \ + .TypeConstraint("Tkeys") \ + .TypeConstraint("dtype"), \ + KvResourceCleanUpOp); + +#define REGISTER_KERNELS_ALL(dev, type) \ + REGISTER_KERNELS(dev, int32, type) \ + REGISTER_KERNELS(dev, int64, type) + +#define REGISTER_KERNELS_CPU(type) REGISTER_KERNELS_ALL(CPU, type) +TF_CALL_FLOAT_TYPES(REGISTER_KERNELS_CPU) +#undef REGISTER_KERNELS_CPU + +#if GOOGLE_CUDA +#define REGISTER_KERNELS_GPU(type) REGISTER_KERNELS_ALL(GPU, type) +TF_CALL_FLOAT_TYPES(REGISTER_KERNELS_GPU) +#undef REGISTER_KERNELS_GPU +#endif // End of macro GOOGLE_CUDA + +#undef REGISTER_KERNELS_ALL +#undef REGISTER_KERNELS + } // namespace tensorflow diff --git a/tensorflow/core/kernels/kv_variable_restore_ops.cc b/tensorflow/core/kernels/kv_variable_restore_ops.cc index e16db9b4cd6..0a0165595f0 100644 --- a/tensorflow/core/kernels/kv_variable_restore_ops.cc +++ b/tensorflow/core/kernels/kv_variable_restore_ops.cc @@ -376,8 +376,8 @@ class KvResourceImportV3Op: public AsyncOpKernel { // EV should not be initialized at this time. if (ev->IsInitialized()) { - LOG(ERROR) << "Import parameter for EV (" << name_string - << ") failed, this EV has already been initialized."; + LOG(WARNING) << "EV (" << name_string + << ") has already been initialized."; } auto do_compute = [this, context, file_name_string, ev, diff --git a/tensorflow/core/ops/kv_variable_ops.cc b/tensorflow/core/ops/kv_variable_ops.cc index 4d003b4213e..527efd1dcc4 100644 --- a/tensorflow/core/ops/kv_variable_ops.cc +++ b/tensorflow/core/ops/kv_variable_ops.cc @@ -893,4 +893,13 @@ REGISTER_OP("KvResourceLookupResource") }) .Doc(R"doc()doc"); +REGISTER_OP("KvResourceCleanUp") + .Input("resource_handle: resource") + .Attr("Tkeys: {int64, int32}") + .Attr("dtype: type = DT_FLOAT") + .SetShapeFn([](InferenceContext* c) { + return Status::OK(); + }) + .Doc(R"doc()doc"); + } // namespace tensorflow diff --git a/tensorflow/python/ops/kv_variable_ops.py b/tensorflow/python/ops/kv_variable_ops.py index 840aadf2541..90b6aae449c 100644 --- a/tensorflow/python/ops/kv_variable_ops.py +++ b/tensorflow/python/ops/kv_variable_ops.py @@ -382,7 +382,17 @@ def _init_from_args(self, dtype=self._dtype)) if initial_value is not None: with ops.name_scope("Assign") as n, ops.colocate_with(self._handle): - with ops.control_dependencies(None if self._is_primary else [self._primary.initializer]): + if self._is_primary: + self._clean_up_op = \ + gen_kv_variable_ops.kv_resource_clean_up(self._handle, \ + Tkeys=self._invalid_key_type, dtype=self._dtype) + else: + self._clean_up_op = self._primary.clean_up_op + + control_dep = [self._clean_up_op] + if not self._is_primary: + control_dep.append(self._primary.initializer) + with ops.control_dependencies(control_dep): self._init_op = gen_kv_variable_ops.initialize_kv_variable_v2_op( self._handle, self._primary._handle, @@ -450,7 +460,10 @@ def export(self): def create_init_op_for_restore(self, name, initial_value, invalid_key, rank): - with ops.control_dependencies(None if self._is_primary else [self._primary._init_op_for_restore]): + control_dep = [self._clean_up_op] + if not self._is_primary: + control_dep.append(self._primary._init_op_for_restore) + with ops.control_dependencies(control_dep): self._initializer_for_restore = gen_kv_variable_ops.initialize_kv_variable_v2_op( self._handle, self._primary._handle, @@ -494,6 +507,11 @@ def create_init_op_for_restore(self, name, initial_value, invalid_key, rank): def need_counts(self): return (self._record_freq or (self._filter_freq > 0) or self._is_multi_tier) + + @property + def clean_up_op(self): + return self._clean_up_op + @property def gather_op(self): return self._gather_op @@ -585,6 +603,9 @@ def _init_from_proto(self, variable_def, import_scope=None): self._primary_handle = g.as_graph_element( ops.prepend_name_scope( primary_name, import_scope=import_scope)) + + self._clean_up_op = g.as_graph_element(ops.prepend_name_scope( + variable_def.clean_up_op_name, import_scope=import_scope)) self._dtype = dtypes.as_dtype(self._handle.op.get_attr("dtype")) self._invalid_key = -1 self._steps_to_live = init_op.get_attr("steps_to_live") @@ -913,6 +934,8 @@ def to_proto(self, export_scope=None): self._save_slice_info.to_proto(export_scope=export_scope)) var_def.initialize_op_for_restore = ops.strip_name_scope( self._init_op_for_restore.name, export_scope) + var_def.clean_up_op_name = \ + ops.strip_name_scope(self._clean_up_op.name, export_scope) return var_def else: return None