Skip to content

Commit

Permalink
[Embedding] Refactor ShrinkPolicy class and shrink functions in Embedding. (#861)
Browse files Browse the repository at this point in the history


Signed-off-by: lixy9474 <lxy268263@alibaba-inc.com>
  • Loading branch information
lixy9474 authored May 18, 2023
1 parent f178454 commit 0c3e0ad
Show file tree
Hide file tree
Showing 14 changed files with 176 additions and 169 deletions.
12 changes: 3 additions & 9 deletions tensorflow/core/framework/embedding/dram_leveldb_storage.h
Original file line number Diff line number Diff line change
Expand Up @@ -156,15 +156,9 @@ class DramLevelDBStore : public MultiTierStorage<K, V> {
return Status::OK();
}

Status Shrink(int64 value_len) override {
dram_->Shrink(value_len);
leveldb_->Shrink(value_len);
return Status::OK();
}

Status Shrink(int64 global_step, int64 steps_to_live) override {
dram_->Shrink(global_step, steps_to_live);
leveldb_->Shrink(global_step, steps_to_live);
// Runs one shrink pass over both tiers with the same arguments:
// DRAM first, then the LevelDB tier. Per-tier results are not
// surfaced; this always reports OK.
Status Shrink(const ShrinkArgs& shrink_args) override {
dram_->Shrink(shrink_args);
leveldb_->Shrink(shrink_args);
return Status::OK();
}

Expand Down
12 changes: 3 additions & 9 deletions tensorflow/core/framework/embedding/dram_pmem_storage.h
Original file line number Diff line number Diff line change
Expand Up @@ -154,15 +154,9 @@ class DramPmemStorage : public MultiTierStorage<K, V> {
return Status::OK();
}

Status Shrink(int64 value_len) override {
dram_->Shrink(value_len);
pmem_->Shrink(value_len);
return Status::OK();
}

Status Shrink(int64 global_step, int64 steps_to_live) override {
dram_->Shrink(global_step, steps_to_live);
pmem_->Shrink(global_step, steps_to_live);
// Runs one shrink pass over both tiers with the same arguments:
// DRAM first, then the PMEM tier. Per-tier results are not
// surfaced; this always reports OK.
Status Shrink(const ShrinkArgs& shrink_args) override {
dram_->Shrink(shrink_args);
pmem_->Shrink(shrink_args);
return Status::OK();
}

Expand Down
12 changes: 3 additions & 9 deletions tensorflow/core/framework/embedding/dram_ssd_storage.h
Original file line number Diff line number Diff line change
Expand Up @@ -147,15 +147,9 @@ class DramSsdHashStorage : public MultiTierStorage<K, V> {
return Status::OK();
}

Status Shrink(int64 value_len) override {
dram_->Shrink(value_len);
ssd_hash_->Shrink(value_len);
return Status::OK();
}

Status Shrink(int64 global_step, int64 steps_to_live) override {
dram_->Shrink(global_step, steps_to_live);
ssd_hash_->Shrink(global_step, steps_to_live);
// Runs one shrink pass over both tiers with the same arguments:
// DRAM first, then the SSD hash tier. Per-tier results are not
// surfaced; this always reports OK.
Status Shrink(const ShrinkArgs& shrink_args) override {
dram_->Shrink(shrink_args);
ssd_hash_->Shrink(shrink_args);
return Status::OK();
}

Expand Down
15 changes: 4 additions & 11 deletions tensorflow/core/framework/embedding/embedding_var.h
Original file line number Diff line number Diff line change
Expand Up @@ -373,10 +373,6 @@ class EmbeddingVar : public ResourceBase {
return emb_config_.steps_to_live;
}

float GetL2WeightThreshold() {
return emb_config_.l2_weight_threshold;
}

// Forwards the multi-level query to the underlying storage object.
bool IsMultiLevel() {
return storage_->IsMultiLevel();
}
Expand Down Expand Up @@ -553,13 +549,10 @@ class EmbeddingVar : public ResourceBase {
return storage_;
}

Status Shrink() {
return storage_->Shrink(value_len_);
}

Status Shrink(int64 gs) {
if (emb_config_.steps_to_live > 0) {
return storage_->Shrink(gs, emb_config_.steps_to_live);
// Triggers a shrink pass on the underlying storage. Only the primary
// embedding runs it (value_len_ is written into shrink_args first, so
// downstream policies know the value length); for non-primary configs
// this is a no-op that returns OK. shrink_args is mutated in place.
Status Shrink(embedding::ShrinkArgs& shrink_args) {
if (emb_config_.is_primary()) {
shrink_args.value_len = value_len_;
return storage_->Shrink(shrink_args);
} else {
return Status::OK();
}
Expand Down
45 changes: 26 additions & 19 deletions tensorflow/core/framework/embedding/globalstep_shrink_policy.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,37 +26,44 @@ namespace embedding {
template<typename K, typename V>
class GlobalStepShrinkPolicy : public ShrinkPolicy<K, V> {
 public:
  // Evicts entries whose recorded step is more than `steps_to_live`
  // behind the current global step. `alloc` and `kv` are borrowed, not
  // owned: `kv` is the table being shrunk and `alloc` is handed to the
  // base class to destroy evicted ValuePtrs.
  GlobalStepShrinkPolicy(int64 steps_to_live,
                         Allocator* alloc,
                         KVInterface<K, V>* kv)
      // Base class first: it is constructed before members regardless of
      // list order, so listing it first avoids a -Wreorder warning.
      : ShrinkPolicy<K, V>(alloc),
        steps_to_live_(steps_to_live),
        kv_(kv) {}

  TF_DISALLOW_COPY_AND_ASSIGN(GlobalStepShrinkPolicy);

  // One eviction pass: first frees the ValuePtrs queued by the previous
  // pass, then snapshots the table and marks stale entries for deletion
  // based on shrink_args.global_step.
  void Shrink(const ShrinkArgs& shrink_args) override {
    ShrinkPolicy<K, V>::ReleaseValuePtrs();
    std::vector<K> key_list;
    std::vector<ValuePtr<V>*> value_list;
    kv_->GetSnapshot(&key_list, &value_list);
    FilterToDelete(shrink_args.global_step, key_list, value_list);
  }

 private:
  // Removes entries older than steps_to_live_. Entries that were never
  // stamped (GetStep() == -1) are stamped with the current step instead
  // of being removed, giving them a full lifetime from now on.
  void FilterToDelete(int64 global_step,
                      const std::vector<K>& key_list,
                      const std::vector<ValuePtr<V>*>& value_list) {
    for (size_t i = 0; i < key_list.size(); ++i) {
      int64 version = value_list[i]->GetStep();
      if (version == -1) {
        value_list[i]->SetStep(global_step);
      } else if (global_step - version > steps_to_live_) {
        kv_->Remove(key_list[i]);
        // Destruction is deferred to the next Shrink() call via the
        // base-class to_delete_ queue.
        ShrinkPolicy<K, V>::EmplacePointer(value_list[i]);
      }
    }
  }

  int64 steps_to_live_;
  KVInterface<K, V>* kv_;  // not owned
};
} // embedding
} // tensorflow
Expand Down
15 changes: 4 additions & 11 deletions tensorflow/core/framework/embedding/hbm_dram_ssd_storage.h
Original file line number Diff line number Diff line change
Expand Up @@ -387,17 +387,10 @@ class HbmDramSsdStorage : public MultiTierStorage<K, V> {
LOG(FATAL)<<"HbmDramSsdStorage dosen't support GetSnaoshot.";
}

Status Shrink(int64 value_len) override {
hbm_->Shrink(value_len);
dram_->Shrink(value_len);
ssd_->Shrink(value_len);
return Status::OK();
}

Status Shrink(int64 global_step, int64 steps_to_live) override {
hbm_->Shrink(global_step, steps_to_live);
dram_->Shrink(global_step, steps_to_live);
ssd_->Shrink(global_step, steps_to_live);
// Runs one shrink pass over all three tiers with the same arguments:
// HBM, then DRAM, then SSD. Per-tier results are not surfaced; this
// always reports OK.
Status Shrink(const ShrinkArgs& shrink_args) override {
hbm_->Shrink(shrink_args);
dram_->Shrink(shrink_args);
ssd_->Shrink(shrink_args);
return Status::OK();
}

Expand Down
12 changes: 3 additions & 9 deletions tensorflow/core/framework/embedding/hbm_dram_storage.h
Original file line number Diff line number Diff line change
Expand Up @@ -352,15 +352,9 @@ class HbmDramStorage : public MultiTierStorage<K, V> {
return temp_hbm_key_list.size() + temp_dram_key_list.size();
}

Status Shrink(int64 value_len) override {
hbm_->Shrink(value_len);
dram_->Shrink(value_len);
return Status::OK();
}

Status Shrink(int64 global_step, int64 steps_to_live) override {
hbm_->Shrink(global_step, steps_to_live);
dram_->Shrink(global_step, steps_to_live);
// Runs one shrink pass over both tiers with the same arguments:
// HBM first, then DRAM. Per-tier results are not surfaced; this
// always reports OK.
Status Shrink(const ShrinkArgs& shrink_args) override {
hbm_->Shrink(shrink_args);
dram_->Shrink(shrink_args);
return Status::OK();
}

Expand Down
49 changes: 27 additions & 22 deletions tensorflow/core/framework/embedding/l2weight_shrink_policy.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,46 +27,51 @@ template<typename K, typename V>
class L2WeightShrinkPolicy : public ShrinkPolicy<K, V> {
public:
L2WeightShrinkPolicy(float l2_weight_threshold,
int64 primary_index, int64 primary_offset,
KVInterface<K, V>* kv, Allocator* alloc,
int slot_num)
: l2_weight_threshold_(l2_weight_threshold),
primary_index_(primary_index), primary_offset_(primary_offset),
ShrinkPolicy<K, V>(kv, alloc, slot_num) {}
int64 index,
int64 offset,
Allocator* alloc,
KVInterface<K, V>* kv)
: index_(index),
offset_(offset),
kv_(kv),
l2_weight_threshold_(l2_weight_threshold),
ShrinkPolicy<K, V>(alloc) {}

TF_DISALLOW_COPY_AND_ASSIGN(L2WeightShrinkPolicy);

void Shrink(int64 value_len) {
ShrinkPolicy<K, V>::ReleaseDeleteValues();
ShrinkPolicy<K, V>::GetSnapshot();
FilterToDelete(value_len);
void Shrink(const ShrinkArgs& shrink_args) override {
ShrinkPolicy<K, V>::ReleaseValuePtrs();
std::vector<K> key_list;
std::vector<ValuePtr<V>*> value_list;
kv_->GetSnapshot(&key_list, &value_list);
FilterToDelete(shrink_args.value_len,
key_list, value_list);
}

private:
void FilterToDelete(int64 value_len) {
for (int64 i = 0; i < ShrinkPolicy<K, V>::key_list_.size(); ++i) {
V* val = ShrinkPolicy<K, V>::value_list_[i]->GetValue(
primary_index_, primary_offset_);
private:
void FilterToDelete(int64 value_len,
const std::vector<K>& key_list,
const std::vector<ValuePtr<V>*>& value_list) {
for (int64 i = 0; i < key_list.size(); ++i) {
V* val = value_list[i]->GetValue(index_, offset_);
if (val != nullptr) {
V l2_weight = (V)0.0;
for (int64 j = 0; j < value_len; j++) {
l2_weight += val[j] * val[j];
}
l2_weight *= (V)0.5;
if (l2_weight < (V)l2_weight_threshold_) {
ShrinkPolicy<K, V>::kv_->Remove(ShrinkPolicy<K, V>::key_list_[i]);
ShrinkPolicy<K, V>::to_delete_.emplace_back(
ShrinkPolicy<K, V>::value_list_[i]);
kv_->Remove(key_list[i]);
ShrinkPolicy<K, V>::EmplacePointer(value_list[i]);
}
}
}
ShrinkPolicy<K, V>::key_list_.clear();
ShrinkPolicy<K, V>::value_list_.clear();
}

private:
int64 primary_index_; // Shrink only handle primary slot
int64 primary_offset_;
int64 index_;
int64 offset_;
KVInterface<K, V>* kv_;
float l2_weight_threshold_;
};
} // embedding
Expand Down
59 changes: 37 additions & 22 deletions tensorflow/core/framework/embedding/shrink_policy.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,40 +23,55 @@ namespace tensorflow {
template<typename V>
class ValuePtr;

class Allocator;

namespace embedding {
// Plain-data bundle of the inputs a shrink pass may need: the current
// training step (read by GlobalStepShrinkPolicy) and the embedding
// value length (read by L2WeightShrinkPolicy). Both default to 0.
struct ShrinkArgs {
  ShrinkArgs() = default;

  ShrinkArgs(int64 global_step, int64 value_len)
      : global_step(global_step), value_len(value_len) {}

  int64 global_step = 0;
  int64 value_len = 0;
};

template<typename K, typename V>
class ShrinkPolicy {
public:
ShrinkPolicy(KVInterface<K, V>* kv, Allocator* alloc, int slot_num)
: kv_(kv), alloc_(alloc),
slot_num_(slot_num), shrink_count_(0) {}
ShrinkPolicy(Allocator* alloc): alloc_(alloc) {}
virtual ~ShrinkPolicy() {}

TF_DISALLOW_COPY_AND_ASSIGN(ShrinkPolicy);

inline Status GetSnapshot() {
shrink_count_ = (shrink_count_ + 1) % slot_num_;
return kv_->GetSnapshot(&key_list_, &value_list_);
virtual void Shrink(const ShrinkArgs& shrink_args) = 0;

protected:
void EmplacePointer(ValuePtr<V>* value_ptr) {
to_delete_.emplace_back(value_ptr);
}

void ReleaseDeleteValues() {
if (shrink_count_ == 0) {
for (auto it : to_delete_) {
it->Destroy(alloc_);
delete it;
}
to_delete_.clear();

void ReleaseValuePtrs() {
for (auto it : to_delete_) {
it->Destroy(alloc_);
delete it;
}
to_delete_.clear();
}

protected:
std::vector<K> key_list_;
std::vector<ValuePtr<V>*> value_list_;
protected:
std::vector<ValuePtr<V>*> to_delete_;

KVInterface<K, V>* kv_;
private:
Allocator* alloc_;
int slot_num_;
int shrink_count_;
};

template<typename K, typename V>
class NonShrinkPolicy : public ShrinkPolicy<K, V> {
 public:
  // No-op policy for storages configured without eviction. Passing
  // nullptr as the allocator is safe because Shrink() never queues
  // ValuePtrs, so the base class never uses it.
  NonShrinkPolicy() : ShrinkPolicy<K, V>(nullptr) {}
  TF_DISALLOW_COPY_AND_ASSIGN(NonShrinkPolicy);

  // Intentionally does nothing. `override` added so the compiler
  // verifies this matches the base-class virtual.
  void Shrink(const ShrinkArgs& shrink_args) override {}
};
} // embedding
} // tensorflow
Expand Down
Loading

0 comments on commit 0c3e0ad

Please sign in to comment.