[EmbeddingVar] Add KvResourceCleanUpOp.
Signed-off-by: chenbangduo.cbd <chenbangduo.cbd@alibaba-inc.com>
JackMoriarty committed Apr 11, 2024
1 parent cf16856 commit e919596
Showing 10 changed files with 141 additions and 4 deletions.
6 changes: 6 additions & 0 deletions tensorflow/core/framework/embedding/embedding_var.h
@@ -733,6 +733,12 @@ class EmbeddingVar : public ResourceBase {
    return filter_;
  }

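  // Delegate to the underlying storage, but only from the primary
  // EmbeddingVar with primary_emb_index == 0, so shared storage is
  // cleaned up at most once.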
  void CleanUp() {
    if (emb_config_.is_primary() && emb_config_.primary_emb_index == 0) {
      storage_->CleanUp();
    }
  }

 protected:
  ~EmbeddingVar() override {
    // When dynamic dimension embedding is used,
@@ -74,6 +74,7 @@ class NonFreqDescriptor: public BaseFreqDescriptor {
 public:
  int64 GetFreq(void* value_ptr) override {
    LOG(FATAL) << "Can not get freq from NonFreqCounter.";
    return 0;
  }

  BaseFreqDescriptor* Clone() override {
5 changes: 5 additions & 0 deletions tensorflow/core/framework/embedding/multi_tier_storage.h
@@ -248,6 +248,11 @@ class MultiTierStorage : public Storage<K, V> {
      FeatureDescriptor<V>* hbm_feat_desc,
      FeatureDescriptor<V>* dram_feat_desc);
#endif  // GOOGLE_CUDA

  void CleanUp() override {
    LOG(FATAL) << "Function [CleanUp] of MultiTierStorage is not implemented.";
  }

 private:
  virtual Status EvictionWithDelayedDestroy(K* evict_ids, int64 evict_size) {}

38 changes: 38 additions & 0 deletions tensorflow/core/framework/embedding/single_tier_storage.h
@@ -307,6 +307,18 @@ class SingleTierStorage : public Storage<K, V> {
                      false/*to_dram*/, is_incr, restore_buff);
    return s;
  }

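  // Snapshot every key/value pointer, then drop each entry from the kv map
  // and free its feature block.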
  void CleanUp() override {
    std::vector<K> key_list;
    std::vector<void*> value_ptr_list;
    kv_->GetSnapshot(&key_list, &value_ptr_list);

    int list_size = key_list.size();
    for (int i = 0; i < list_size; i++) {
      kv_->Remove(key_list[i]);
      feat_desc_->Deallocate(value_ptr_list[i]);
    }
  }

 protected:
  virtual void Shrink(std::vector<K>& key_list,
@@ -453,6 +465,10 @@ class HbmStorage : public SingleTierStorage<K, V> {
  GPUHashTable<K, V>* HashTable() override {
    return SingleTierStorage<K, V>::kv_->HashTable();
  }

  void CleanUp() override {
    LOG(FATAL) << "Function [CleanUp] of HbmStorage is not implemented.";
  }

 protected:
  Status RestoreFeatures(int64 key_num, int bucket_num, int64 partition_id,
                         int64 partition_num, int64 value_len, bool is_filter,
@@ -495,6 +511,11 @@ class HbmStorageWithCpuKv: public SingleTierStorage<K, V> {
  Status TryInsert(K key, void* value_ptr) {
    return SingleTierStorage<K, V>::kv_->Insert(key, value_ptr);
  }

  void CleanUp() override {
    LOG(FATAL) << "Function [CleanUp] of HbmStorageWithCpuKv is not implemented.";
  }

 public:
  friend class HbmDramStorage<K, V>;
  friend class HbmDramSsdStorage<K, V>;
@@ -521,6 +542,10 @@ class PmemMemkindStorage : public SingleTierStorage<K, V> {
  }
  ~PmemMemkindStorage() override {}

  void CleanUp() override {
    LOG(FATAL) << "Function [CleanUp] of PmemMemkindStorage is not implemented.";
  }

  TF_DISALLOW_COPY_AND_ASSIGN(PmemMemkindStorage);
};

@@ -537,6 +562,10 @@ class PmemLibpmemStorage : public SingleTierStorage<K, V> {
    return SingleTierStorage<K, V>::kv_->Commit(keys, value_ptr);
  }

  void CleanUp() override {
    LOG(FATAL) << "Function [CleanUp] of PmemLibpmemStorage is not implemented.";
  }

  TF_DISALLOW_COPY_AND_ASSIGN(PmemLibpmemStorage);

protected:
@@ -577,6 +606,11 @@ class LevelDBStore : public SingleTierStorage<K, V> {
                          key_list, emb_index, value_len,
                          leveldb_kv, SingleTierStorage<K, V>::feat_desc_);
  }

  void CleanUp() override {
    LOG(FATAL) << "Function [CleanUp] of LevelDBStore is not implemented.";
  }

 public:
  friend class DramLevelDBStore<K, V>;
};
@@ -646,6 +680,10 @@ class SsdHashStorage : public SingleTierStorage<K, V> {
        reinterpret_cast<SSDHashKV<K, V>*>(SingleTierStorage<K, V>::kv_);
    ssd_kv->SetSsdRecordDescriptor(ssd_rec_desc);
  }

  void CleanUp() override {
    LOG(FATAL) << "Function [CleanUp] of SsdHashStorage is not implemented.";
  }

 public:
  friend class DramSsdHashStorage<K, V>;
#if GOOGLE_CUDA
2 changes: 2 additions & 0 deletions tensorflow/core/framework/embedding/storage.h
@@ -210,6 +210,8 @@ class Storage {
    return Status::OK();
  }

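  // Implemented by each concrete storage tier; tiers where clean-up is not
  // supported stub this out with LOG(FATAL) (see single_tier_storage.h).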
  virtual void CleanUp() = 0;

 protected:
  virtual Status RestoreSSD(int64 emb_index, int64 emb_slot_num,
                            int64 value_len,
2 changes: 2 additions & 0 deletions tensorflow/core/framework/variable.proto
@@ -76,6 +76,8 @@ message VariableDef {
  bool is_embedding_var = 91;

  string initialize_op_for_restore = 92;

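  // Name of the KvResourceCleanUp op associated with this embedding variable.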
  string clean_up_op_name = 93;
}

message SaveSliceInfoDef {
Expand Down
51 changes: 51 additions & 0 deletions tensorflow/core/kernels/kv_variable_ops.cc
@@ -557,5 +557,56 @@ TF_CALL_GPU_NUMBER_TYPES(REGISTER_KERNELS_GPU)

#undef REGISTER_KERNELS_ALL
#undef REGISTER_KERNELS

template <typename TKey, typename TValue>
class KvResourceCleanUpOp : public OpKernel {
 public:
  explicit KvResourceCleanUpOp(OpKernelConstruction* ctx)
      : OpKernel(ctx) {
    TF_CHECK_OK(ReadBoolFromEnvVar("ENABLE_EV_CLEAN_UP", false, &enable_));
  }

  void Compute(OpKernelContext* ctx) override {
    if (!enable_) {
      return;
    }

    EmbeddingVar<TKey, TValue>* ev = nullptr;
    Status s = LookupResource(ctx, HandleFromInput(ctx, 0), &ev);

    if (s.ok()) {
      // Clean up first, then drop the reference taken by LookupResource.
      ev->CleanUp();
      ev->Unref();
    }
  }

 private:
  bool enable_;
};

#define REGISTER_KERNELS(dev, ktype, vtype)          \
  REGISTER_KERNEL_BUILDER(Name("KvResourceCleanUp")  \
                            .Device(DEVICE_##dev)            \
                            .TypeConstraint<ktype>("Tkeys")  \
                            .TypeConstraint<vtype>("dtype"), \
                          KvResourceCleanUpOp<ktype, vtype>);

#define REGISTER_KERNELS_ALL(dev, type) \
  REGISTER_KERNELS(dev, int32, type)    \
  REGISTER_KERNELS(dev, int64, type)

#define REGISTER_KERNELS_CPU(type) REGISTER_KERNELS_ALL(CPU, type)
TF_CALL_FLOAT_TYPES(REGISTER_KERNELS_CPU)
#undef REGISTER_KERNELS_CPU

#if GOOGLE_CUDA
#define REGISTER_KERNELS_GPU(type) REGISTER_KERNELS_ALL(GPU, type)
TF_CALL_FLOAT_TYPES(REGISTER_KERNELS_GPU)
#undef REGISTER_KERNELS_GPU
#endif // End of macro GOOGLE_CUDA

#undef REGISTER_KERNELS_ALL
#undef REGISTER_KERNELS

} // namespace tensorflow
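The kernel above is gated by the ENABLE_EV_CLEAN_UP environment variable, read once in the kernel constructor; without it, Compute is a no-op. A minimal sketch of enabling it from Python (ordinary process setup; nothing DeepRec-specific is assumed):

import os

# Must be set before the session constructs the KvResourceCleanUp kernels,
# since ReadBoolFromEnvVar is consulted only in the kernel constructor.
os.environ["ENABLE_EV_CLEAN_UP"] = "true"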

4 changes: 2 additions & 2 deletions tensorflow/core/kernels/kv_variable_restore_ops.cc
@@ -376,8 +376,8 @@ class KvResourceImportV3Op: public AsyncOpKernel {

    // EV should not be initialized at this time.
    if (ev->IsInitialized()) {
-     LOG(ERROR) << "Import parameter for EV (" << name_string
-                << ") failed, this EV has already been initialized.";
+     LOG(WARNING) << "EV (" << name_string
+                  << ") has already been initialized.";
    }

    auto do_compute = [this, context, file_name_string, ev,
9 changes: 9 additions & 0 deletions tensorflow/core/ops/kv_variable_ops.cc
@@ -893,4 +893,13 @@ REGISTER_OP("KvResourceLookupResource")
    })
    .Doc(R"doc()doc");

REGISTER_OP("KvResourceCleanUp")
.Input("resource_handle: resource")
.Attr("Tkeys: {int64, int32}")
.Attr("dtype: type = DT_FLOAT")
.SetShapeFn([](InferenceContext* c) {
return Status::OK();
})
.Doc(R"doc()doc");

} // namespace tensorflow
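For reference, a hedged sketch of invoking the new op through the generated binding (`ev_handle` is a hypothetical EV resource handle; the attr names mirror the registration above and the wiring in kv_variable_ops.py below):

from tensorflow.python.framework import dtypes
from tensorflow.python.ops import gen_kv_variable_ops

# Tkeys and dtype must match the EV's key/value types. The op has no
# outputs, so the returned Operation serves purely as a control dependency.
clean_up_op = gen_kv_variable_ops.kv_resource_clean_up(
    ev_handle, Tkeys=dtypes.int64, dtype=dtypes.float32)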
27 changes: 25 additions & 2 deletions tensorflow/python/ops/kv_variable_ops.py
@@ -382,7 +382,17 @@ def _init_from_args(self,
              dtype=self._dtype))
    if initial_value is not None:
      with ops.name_scope("Assign") as n, ops.colocate_with(self._handle):
-       with ops.control_dependencies(None if self._is_primary else [self._primary.initializer]):
+       if self._is_primary:
+         self._clean_up_op = \
+           gen_kv_variable_ops.kv_resource_clean_up(self._handle,
+             Tkeys=self._invalid_key_type, dtype=self._dtype)
+       else:
+         self._clean_up_op = self._primary.clean_up_op
+
+       control_dep = [self._clean_up_op]
+       if not self._is_primary:
+         control_dep.append(self._primary.initializer)
+       with ops.control_dependencies(control_dep):
          self._init_op = gen_kv_variable_ops.initialize_kv_variable_v2_op(
            self._handle,
            self._primary._handle,
@@ -450,7 +460,10 @@ def export(self):


  def create_init_op_for_restore(self, name, initial_value, invalid_key, rank):
-   with ops.control_dependencies(None if self._is_primary else [self._primary._init_op_for_restore]):
+   control_dep = [self._clean_up_op]
+   if not self._is_primary:
+     control_dep.append(self._primary._init_op_for_restore)
+   with ops.control_dependencies(control_dep):
      self._initializer_for_restore = gen_kv_variable_ops.initialize_kv_variable_v2_op(
        self._handle,
        self._primary._handle,
@@ -494,6 +507,11 @@ def create_init_op_for_restore(self, name, initial_value, invalid_key, rank):

  def need_counts(self):
    return (self._record_freq or (self._filter_freq > 0) or self._is_multi_tier)

  @property
  def clean_up_op(self):
    return self._clean_up_op
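A hedged illustration of the sharing contract above (`primary_ev` and `slot_ev` are hypothetical primary/slot EmbeddingVariable instances):

# Slot EVs do not build their own KvResourceCleanUp node in _init_from_args;
# they alias the primary's, so clean-up runs at most once per EV family.
assert slot_ev.clean_up_op is primary_ev.clean_up_op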

  @property
  def gather_op(self):
    return self._gather_op
@@ -585,6 +603,9 @@ def _init_from_proto(self, variable_def, import_scope=None):
      self._primary_handle = g.as_graph_element(
          ops.prepend_name_scope(
              primary_name, import_scope=import_scope))

      self._clean_up_op = g.as_graph_element(ops.prepend_name_scope(
          variable_def.clean_up_op_name, import_scope=import_scope))
      self._dtype = dtypes.as_dtype(self._handle.op.get_attr("dtype"))
      self._invalid_key = -1
      self._steps_to_live = init_op.get_attr("steps_to_live")
@@ -913,6 +934,8 @@ def to_proto(self, export_scope=None):
            self._save_slice_info.to_proto(export_scope=export_scope))
      var_def.initialize_op_for_restore = ops.strip_name_scope(
          self._init_op_for_restore.name, export_scope)
      var_def.clean_up_op_name = ops.strip_name_scope(
          self._clean_up_op.name, export_scope)
      return var_def
    else:
      return None
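A hedged round-trip sketch of the new field (`ev` is a hypothetical EmbeddingVariable; the plumbing is the to_proto/_init_from_proto pair above):

# Exporting records the clean-up op's name in VariableDef...
var_def = ev.to_proto(export_scope=None)
assert var_def.clean_up_op_name == ev.clean_up_op.name
# ...and an importer resolves the name back to the graph element, as
# _init_from_proto does via g.as_graph_element(...).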
