Skip to content

Commit

Permalink
[Embedding] Refactor the code of Save Op for EmbeddingVariable. (#900)
Browse files Browse the repository at this point in the history
Signed-off-by: lixy9474 <lxy268263@alibaba-inc.com>
  • Loading branch information
lixy9474 authored Aug 7, 2023
1 parent 2065fc0 commit 4cd9ed8
Show file tree
Hide file tree
Showing 31 changed files with 1,191 additions and 1,554 deletions.
10 changes: 10 additions & 0 deletions tensorflow/core/framework/embedding/config.proto
Original file line number Diff line number Diff line change
Expand Up @@ -46,3 +46,13 @@ enum EmbeddingVariableType {
IMMUTABLE = 0;
MUTABLE = 1;
}

enum ValuePtrStatus {
OK = 0;
IS_DELETED = 1;
}

enum ValuePosition {
IN_DRAM = 0;
NOT_IN_DRAM = 1;
}
86 changes: 45 additions & 41 deletions tensorflow/core/framework/embedding/dram_leveldb_storage.h
Original file line number Diff line number Diff line change
Expand Up @@ -111,14 +111,6 @@ class DramLevelDBStore : public MultiTierStorage<K, V> {
return false;
}

void iterator_mutex_lock() override {
leveldb_->get_mutex()->lock();
}

void iterator_mutex_unlock() override {
leveldb_->get_mutex()->unlock();
}

int64 Size() const override {
int64 total_size = dram_->Size();
total_size += leveldb_->Size();
Expand All @@ -145,46 +137,58 @@ class DramLevelDBStore : public MultiTierStorage<K, V> {
return -1;
}

Status GetSnapshot(std::vector<K>* key_list,
std::vector<ValuePtr<V>*>* value_ptr_list) override {
{
mutex_lock l(*(dram_->get_mutex()));
TF_CHECK_OK(dram_->GetSnapshot(key_list, value_ptr_list));
Status Save(
const string& tensor_name,
const string& prefix,
BundleWriter* writer,
const EmbeddingConfig& emb_config,
ShrinkArgs& shrink_args,
int64 value_len,
V* default_value) override {
std::vector<K> key_list, tmp_leveldb_key_list;
std::vector<ValuePtr<V>*> value_ptr_list, tmp_leveldb_value_list;
TF_CHECK_OK(dram_->GetSnapshot(&key_list, &value_ptr_list));

TF_CHECK_OK(leveldb_->GetSnapshot(
&tmp_leveldb_key_list, &tmp_leveldb_value_list));

for (int64 i = 0; i < tmp_leveldb_value_list.size(); i++) {
tmp_leveldb_value_list[i]->SetPtr((V*)ValuePosition::NOT_IN_DRAM);
tmp_leveldb_value_list[i]->SetInitialized(emb_config.primary_emb_index);
}
{
mutex_lock l(*(leveldb_->get_mutex()));
TF_CHECK_OK(leveldb_->GetSnapshot(key_list, value_ptr_list));

std::vector<K> leveldb_key_list;
for (int64 i = 0; i < tmp_leveldb_key_list.size(); i++) {
Status s = dram_->Contains(tmp_leveldb_key_list[i]);
if (!s.ok()) {
key_list.emplace_back(tmp_leveldb_key_list[i]);
leveldb_key_list.emplace_back(tmp_leveldb_key_list[i]);
value_ptr_list.emplace_back(tmp_leveldb_value_list[i]);
}
}
return Status::OK();
}

Status Shrink(const ShrinkArgs& shrink_args) override {
dram_->Shrink(shrink_args);
leveldb_->Shrink(shrink_args);
return Status::OK();
}
ValueIterator<V>* value_iter =
leveldb_->GetValueIterator(
leveldb_key_list, emb_config.emb_index, value_len);

int64 GetSnapshot(std::vector<K>* key_list,
std::vector<V* >* value_list,
std::vector<int64>* version_list,
std::vector<int64>* freq_list,
const EmbeddingConfig& emb_config,
FilterPolicy<K, V, EmbeddingVar<K, V>>* filter,
embedding::Iterator** it) override {
{
mutex_lock l(*(dram_->get_mutex()));
std::vector<ValuePtr<V>*> value_ptr_list;
std::vector<K> key_list_tmp;
TF_CHECK_OK(dram_->GetSnapshot(&key_list_tmp, &value_ptr_list));
MultiTierStorage<K, V>::SetListsForCheckpoint(
key_list_tmp, value_ptr_list, emb_config,
key_list, value_list, version_list, freq_list);
}
{
mutex_lock l(*(leveldb_->get_mutex()));
*it = leveldb_->GetIterator();
TF_CHECK_OK((Storage<K, V>::SaveToCheckpoint(
tensor_name, writer,
emb_config,
value_len, default_value,
key_list,
value_ptr_list,
value_iter)));
}
return key_list->size();

for (auto it: tmp_leveldb_value_list) {
delete it;
}

delete value_iter;

return Status::OK();
}

Status Eviction(K* evict_ids, int64 evict_size) override {
Expand Down
82 changes: 32 additions & 50 deletions tensorflow/core/framework/embedding/dram_pmem_storage.h
Original file line number Diff line number Diff line change
Expand Up @@ -150,59 +150,41 @@ class DramPmemStorage : public MultiTierStorage<K, V> {
return -1;
}

Status GetSnapshot(std::vector<K>* key_list,
std::vector<ValuePtr<V>* >* value_ptr_list) override {
{
mutex_lock l(*(dram_->get_mutex()));
TF_CHECK_OK(dram_->GetSnapshot(key_list, value_ptr_list));
}
{
mutex_lock l(*(pmem_->get_mutex()));
TF_CHECK_OK(pmem_->GetSnapshot(key_list, value_ptr_list));
Status Save(
const string& tensor_name,
const string& prefix,
BundleWriter* writer,
const EmbeddingConfig& emb_config,
ShrinkArgs& shrink_args,
int64 value_len,
V* default_value) override {
std::vector<K> key_list, tmp_pmem_key_list;
std::vector<ValuePtr<V>*> value_ptr_list, tmp_pmem_value_list;

TF_CHECK_OK(dram_->GetSnapshot(&key_list, &value_ptr_list));
dram_->Shrink(key_list, value_ptr_list, shrink_args, value_len);

TF_CHECK_OK(pmem_->GetSnapshot(&tmp_pmem_key_list,
&tmp_pmem_value_list));
pmem_->Shrink(tmp_pmem_key_list, tmp_pmem_value_list,
shrink_args, value_len);

for (int64 i = 0; i < tmp_pmem_key_list.size(); i++) {
Status s = dram_->Contains(tmp_pmem_key_list[i]);
if (!s.ok()) {
key_list.emplace_back(tmp_pmem_key_list[i]);
value_ptr_list.emplace_back(tmp_pmem_value_list[i]);
}
}
return Status::OK();
}

Status Shrink(const ShrinkArgs& shrink_args) override {
dram_->Shrink(shrink_args);
pmem_->Shrink(shrink_args);
return Status::OK();
}
TF_CHECK_OK((Storage<K, V>::SaveToCheckpoint(
tensor_name, writer,
emb_config,
value_len, default_value,
key_list,
value_ptr_list)));

void iterator_mutex_lock() override {
return;
}

void iterator_mutex_unlock() override {
return;
}

int64 GetSnapshot(std::vector<K>* key_list,
std::vector<V* >* value_list,
std::vector<int64>* version_list,
std::vector<int64>* freq_list,
const EmbeddingConfig& emb_config,
FilterPolicy<K, V, EmbeddingVar<K, V>>* filter,
embedding::Iterator** it) override {
{
mutex_lock l(*(dram_->get_mutex()));
std::vector<ValuePtr<V>*> value_ptr_list;
std::vector<K> key_list_tmp;
TF_CHECK_OK(dram_->GetSnapshot(&key_list_tmp, &value_ptr_list));
MultiTierStorage<K, V>::SetListsForCheckpoint(
key_list_tmp, value_ptr_list, emb_config,
key_list, value_list, version_list, freq_list);
}
{
mutex_lock l(*(pmem_->get_mutex()));
std::vector<ValuePtr<V>*> value_ptr_list;
std::vector<K> key_list_tmp;
TF_CHECK_OK(pmem_->GetSnapshot(&key_list_tmp, &value_ptr_list));
MultiTierStorage<K, V>::SetListsForCheckpoint(
key_list_tmp, value_ptr_list, emb_config,
key_list, value_list, version_list, freq_list);
}
return key_list->size();
return Status::OK();
}

Status Eviction(K* evict_ids, int64 evict_size) override {
Expand Down
75 changes: 13 additions & 62 deletions tensorflow/core/framework/embedding/dram_ssd_storage.h
Original file line number Diff line number Diff line change
Expand Up @@ -144,70 +144,21 @@ class DramSsdHashStorage : public MultiTierStorage<K, V> {
return true;
}

Status GetSnapshot(std::vector<K>* key_list,
std::vector<ValuePtr<V>*>* value_ptr_list) override {
{
mutex_lock l(*(dram_->get_mutex()));
TF_CHECK_OK(dram_->GetSnapshot(key_list, value_ptr_list));
}
{
mutex_lock l(*(ssd_hash_->get_mutex()));
TF_CHECK_OK(ssd_hash_->GetSnapshot(key_list, value_ptr_list));
}
return Status::OK();
}

Status Shrink(const ShrinkArgs& shrink_args) override {
dram_->Shrink(shrink_args);
ssd_hash_->Shrink(shrink_args);
return Status::OK();
}

int64 GetSnapshot(std::vector<K>* key_list,
std::vector<V* >* value_list,
std::vector<int64>* version_list,
std::vector<int64>* freq_list,
Status Save(
const string& tensor_name,
const string& prefix,
BundleWriter* writer,
const EmbeddingConfig& emb_config,
FilterPolicy<K, V, EmbeddingVar<K, V>>* filter,
embedding::Iterator** it) override {
{
mutex_lock l(*(dram_->get_mutex()));
std::vector<ValuePtr<V>*> value_ptr_list;
std::vector<K> key_list_tmp;
TF_CHECK_OK(dram_->GetSnapshot(&key_list_tmp, &value_ptr_list));
MultiTierStorage<K, V>::SetListsForCheckpoint(
key_list_tmp, value_ptr_list, emb_config,
key_list, value_list, version_list, freq_list);
}
{
mutex_lock l(*(ssd_hash_->get_mutex()));
*it = ssd_hash_->GetIterator();
}
return key_list->size();
}
ShrinkArgs& shrink_args,
int64 value_len,
V* default_value) override {
dram_->Save(tensor_name, prefix, writer, emb_config,
shrink_args, value_len, default_value);

int64 GetSnapshotWithoutFetchPersistentEmb(
std::vector<K>* key_list,
std::vector<V*>* value_list,
std::vector<int64>* version_list,
std::vector<int64>* freq_list,
const EmbeddingConfig& emb_config,
SsdRecordDescriptor<K>* ssd_rec_desc) override {
{
mutex_lock l(*(dram_->get_mutex()));
std::vector<ValuePtr<V>*> value_ptr_list;
std::vector<K> temp_key_list;
TF_CHECK_OK(dram_->GetSnapshot(&temp_key_list, &value_ptr_list));
MultiTierStorage<K, V>::SetListsForCheckpoint(
temp_key_list, value_ptr_list, emb_config,
key_list, value_list, version_list,
freq_list);
}
{
mutex_lock l(*(ssd_hash_->get_mutex()));
ssd_hash_->SetSsdRecordDescriptor(ssd_rec_desc);
}
return key_list->size() + ssd_rec_desc->key_list.size();
ssd_hash_->Save(tensor_name, prefix, writer, emb_config,
shrink_args, value_len, default_value);

return Status::OK();
}

Status RestoreSSD(int64 emb_index, int64 emb_slot_num, int64 value_len,
Expand Down
10 changes: 10 additions & 0 deletions tensorflow/core/framework/embedding/embedding_config.h
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,16 @@ struct EmbeddingConfig {
return emb_index == primary_emb_index;
}

bool is_save_freq() const {
return filter_freq != 0 ||
record_freq ||
normal_fix_flag == 1;
}

bool is_save_version() const {
return steps_to_live != 0 || record_version;
}

int64 total_num(int alloc_len) {
return block_num *
(1 + (1 - normal_fix_flag) * slot_num) *
Expand Down
32 changes: 8 additions & 24 deletions tensorflow/core/framework/embedding/embedding_var.h
Original file line number Diff line number Diff line change
Expand Up @@ -582,30 +582,14 @@ class EmbeddingVar : public ResourceBase {
emb_config_, device, reader, this, filter_);
}

int64 GetSnapshot(std::vector<K>* key_list,
std::vector<V* >* value_list,
std::vector<int64>* version_list,
std::vector<int64>* freq_list,
embedding::Iterator** it = nullptr) {
// for Interface Compatible
// TODO Multi-tiered Embedding should use iterator in 'GetSnapshot' caller
embedding::Iterator* _it = nullptr;
it = (it == nullptr) ? &_it : it;
return storage_->GetSnapshot(
key_list, value_list, version_list,
freq_list, emb_config_, filter_, it);
}

int64 GetSnapshotWithoutFetchPersistentEmb(
std::vector<K>* key_list,
std::vector<V*>* value_list,
std::vector<int64>* version_list,
std::vector<int64>* freq_list,
SsdRecordDescriptor<K>* ssd_rec_desc) {
return storage_->
GetSnapshotWithoutFetchPersistentEmb(
key_list, value_list, version_list,
freq_list, emb_config_, ssd_rec_desc);
Status Save(const string& tensor_name,
const string& prefix,
BundleWriter* writer,
embedding::ShrinkArgs& shrink_args) {
return storage_->Save(tensor_name, prefix,
writer, emb_config_,
shrink_args, value_len_,
default_value_);
}

mutex* mu() {
Expand Down
Loading

0 comments on commit 4cd9ed8

Please sign in to comment.