[IncrCKPT] fix the bug of importing embedding variable.
Signed-off-by: chenbangduo.cbd <chenbangduo.cbd@alibaba-inc.com>
JackMoriarty committed Mar 25, 2024
1 parent 6dae552 commit 6d272d8
Showing 2 changed files with 82 additions and 22 deletions.
50 changes: 28 additions & 22 deletions tensorflow/core/framework/embedding/embedding_var_restore.cc
@@ -102,45 +102,48 @@ void CheckpointLoader<K, V>::RestoreInternal(
   Tensor part_filter_offset_tensor;
   if (!restore_args_.m_is_oldform) {
     /****** InitPartOffsetTensor ******/
-    TensorShape part_offset_shape, part_filter_offset_shape;
-    DataType part_offset_type, part_filter_offset_type;
+    TensorShape part_offset_shape;
+    DataType part_offset_type;
     string offset_tensor_name;
     if (!restore_args_.m_is_incr) {
       offset_tensor_name = name_string + kPartOffsetTensorSuffsix;
     } else {
       offset_tensor_name = name_string + kIncrPartOffsetTensorSuffsix;
     }
 
-    string offset_filter_tensor_name =
-        name_string + kPartFilterOffsetTensorSuffsix;
-
     Status s = reader_->LookupDtypeAndShape(
         offset_tensor_name, &part_offset_type, &part_offset_shape);
     if (!s.ok()) {
       LOG(ERROR) << "EV restoring fail:" << s.error_message();
     }
-    s = reader_->LookupDtypeAndShape(offset_filter_tensor_name,
-                                     &part_filter_offset_type,
-                                     &part_filter_offset_shape);
-    if (!s.ok()) {
-      LOG(ERROR) << "EV restoring fail: " << s.error_message();
-    }
     part_offset_tensor =
         Tensor(cpu_allocator(), part_offset_type, part_offset_shape);
-    part_filter_offset_tensor = Tensor(
-        cpu_allocator(), part_filter_offset_type, part_filter_offset_shape);
     s = reader_->Lookup(offset_tensor_name, &part_offset_tensor);
     if (!s.ok()) {
       LOG(ERROR) << "EV restoring fail:" << s.error_message();
     }
 
-    s = reader_->Lookup(offset_filter_tensor_name,
-                        &part_filter_offset_tensor);
-    if (!s.ok()) {
-      LOG(ERROR) << "EV restoring fail: " << s.error_message();
-    }
+    if (restore_args_.m_has_filter) {
+      TensorShape part_filter_offset_shape;
+      DataType part_filter_offset_type;
+      string offset_filter_tensor_name =
+          name_string + kPartFilterOffsetTensorSuffsix;
+      s = reader_->LookupDtypeAndShape(offset_filter_tensor_name,
+                                       &part_filter_offset_type,
+                                       &part_filter_offset_shape);
+      if (!s.ok()) {
+        LOG(ERROR) << "EV restoring fail: " << s.error_message();
+      }
+      part_filter_offset_tensor = \
+          Tensor(cpu_allocator(), part_filter_offset_type,
+                 part_filter_offset_shape);
+      s = reader_->Lookup(offset_filter_tensor_name,
+                          &part_filter_offset_tensor);
+      if (!s.ok()) {
+        LOG(ERROR) << "EV restoring fail: " << s.error_message();
+      }
+    }
   }
-  auto part_offset_flat = part_offset_tensor.flat<int32>();
-  auto part_filter_offset_flat = part_filter_offset_tensor.flat<int32>();
 
   if (restore_args_.m_is_oldform) {
     VLOG(1) << "old form, EV name:" << name_string
@@ -164,6 +167,7 @@ void CheckpointLoader<K, V>::RestoreInternal(
VLOG(1) << "new form checkpoint... :" << name_string
<< " , partition_id:" << restore_args_.m_partition_id
<< " , partition_num:" << restore_args_.m_partition_num;
auto part_offset_flat = part_offset_tensor.flat<int32>();
for (size_t i = 0; i < restore_args_.m_loaded_parts.size(); i++) {
int subpart_id = restore_args_.m_loaded_parts[i];
size_t value_unit_bytes = sizeof(V) * restore_args_.m_old_dim;
@@ -183,6 +187,7 @@ void CheckpointLoader<K, V>::RestoreInternal(
                               new_dim, emb_config, device);
 
     if (restore_args_.m_has_filter) {
+      auto part_filter_offset_flat = part_filter_offset_tensor.flat<int32>();
       Status s = EVRestoreFilteredFeatures(
           subpart_id, new_dim, restore_buff, part_filter_offset_flat,
           emb_config, device);
@@ -444,7 +449,7 @@ Status CheckpointLoader<K, V>::EVInitTensorNameAndShape(
   }
   st = reader_->LookupHeader(restore_args_.m_tensor_version + "_filtered",
                              sizeof(K) * version_filter_shape.dim_size(0));
-  if (!st.ok()) {
+  if (!st.ok() && st.code() != error::NOT_FOUND) {
     return st;
   }
   st = reader_->LookupTensorShape(restore_args_.m_tensor_freq + "_filtered",
@@ -463,7 +468,8 @@ Status CheckpointLoader<K, V>::EVInitTensorNameAndShape(
       return st;
     }
   }
-  return st;
+
+  return Status::OK();
 }
 #define REGISTER_KERNELS(ktype, vtype) \
   template Status CheckpointLoader<ktype, vtype>::EVInitTensorNameAndShape(\
@@ -644,4 +650,4 @@ TF_CALL_FLOAT_TYPES(REGISTER_KERNELS_ALL_INDEX)
 #undef REGISTER_KERNELS_ALL_INDEX
 #undef REGISTER_KERNELS
 
-}// namespace tensorflow
\ No newline at end of file
+}// namespace tensorflow
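For context on the RestoreInternal hunks above: the partition filter offset tensors exist in a checkpoint only when the embedding variable was saved with a frequency filter, so allocating and looking them up unconditionally fails for unfiltered variables. Below is a minimal standalone sketch of that guarded-lookup pattern, not the DeepRec implementation itself; OffsetReader and the literal "-partition_offset" / "-partition_filter_offset" suffix strings are illustrative stand-ins for the real reader type and the kPartOffsetTensorSuffsix / kPartFilterOffsetTensorSuffsix constants.

// Sketch only: a stand-in reader interface with the two calls used above.
#include <string>

#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/errors.h"

namespace tensorflow {

struct OffsetReader {
  Status LookupDtypeAndShape(const string& name, DataType* dtype,
                             TensorShape* shape);
  Status Lookup(const string& name, Tensor* out);
};

Status LoadPartitionOffsets(OffsetReader* reader, const string& name,
                            bool has_filter, Tensor* part_offset,
                            Tensor* part_filter_offset) {
  TensorShape shape;
  DataType dtype;
  // The plain partition offsets exist for every embedding variable.
  TF_RETURN_IF_ERROR(
      reader->LookupDtypeAndShape(name + "-partition_offset", &dtype, &shape));
  *part_offset = Tensor(cpu_allocator(), dtype, shape);
  TF_RETURN_IF_ERROR(reader->Lookup(name + "-partition_offset", part_offset));

  // The filter offsets are only written for variables saved with a
  // frequency filter; reading them unconditionally is what broke restore.
  if (has_filter) {
    TF_RETURN_IF_ERROR(reader->LookupDtypeAndShape(
        name + "-partition_filter_offset", &dtype, &shape));
    *part_filter_offset = Tensor(cpu_allocator(), dtype, shape);
    TF_RETURN_IF_ERROR(
        reader->Lookup(name + "-partition_filter_offset", part_filter_offset));
  }
  return Status::OK();
}

}  // namespace tensorflow

Unlike this sketch, the code in the diff logs a lookup failure and continues rather than returning early; the has_filter guard is the part the commit adds.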
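The EVInitTensorNameAndShape hunks apply the same idea to error handling: a NOT_FOUND from a "_filtered" lookup simply means the checkpoint was written without a filter, and the function must end with an explicit Status::OK() so a benign non-OK status from the final lookup is not leaked to the caller. A sketch of that shape, assuming the OffsetReader stand-in above also declared a hypothetical Status LookupHeader(const string& name, size_t bytes) to mirror the call in the diff:

// Sketch only: tolerate an absent optional tensor, propagate real errors.
Status CheckOptionalFilteredHeader(OffsetReader* reader,
                                   const string& version_name,
                                   size_t expected_bytes) {
  Status st = reader->LookupHeader(version_name + "_filtered", expected_bytes);
  if (!st.ok() && st.code() != error::NOT_FOUND) {
    return st;  // a genuine read error, not just "tensor absent"
  }
  // ... further optional lookups, each tolerating NOT_FOUND ...
  // Ending with an explicit OK matters: st may still hold NOT_FOUND here,
  // and "return st;" would wrongly fail the whole restore.
  return Status::OK();
}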
54 changes: 54 additions & 0 deletions tensorflow/python/training/incr_ckpt_test.py
@@ -451,5 +451,59 @@ def testIncrementalSaverForResourceVariable(self):
     saver.build()
     incr_saver = incr_saver_module._get_incremental_saver(True, saver)
 
+  def testIncrementalSaverSaveAndRestore(self):
+    tmp_path = self.get_temp_dir()
+    full_ckpt_dir = os.path.join(tmp_path, "model.ckpt")
+    incr_ckpt_dir = os.path.join(tmp_path, "incr.ckpt")
+    full_ckpt_path = None
+    incr_ckpt_path = None
+
+    # construct the graph
+    emb_var = variable_scope.get_embedding_variable("emb", embedding_dim=3,
+        initializer=init_ops.ones_initializer(dtypes.float32))
+    emb = embedding_ops.embedding_lookup(emb_var,
+        math_ops.cast([0, 1, 2, 3, 4], dtypes.int64))
+    loss = math_ops.reduce_sum(emb, name='reduce_sum')
+    opt = adagrad.AdagradOptimizer(0.1)
+    g_v = opt.compute_gradients(loss)
+    train_op = opt.apply_gradients(g_v)
+    init = variables.global_variables_initializer()
+    saver = saver_module.Saver(sharded=True, incremental_save_restore=True)
+    incr_saver = \
+        incr_saver_module.IncrementalSaver(sharded=True,
+            saver_def=saver.saver_def, defer_build=True)
+    incr_saver.build(saver._builder.filename_tensor)
+
+    # generate the full ckpt and the incr ckpt.
+    full_ckpt_value = None
+    incr_ckpt_value = None
+    with self.test_session() as sess:
+      sess.run(ops.get_collection(ops.GraphKeys.EV_INIT_VAR_OPS))
+      sess.run(ops.get_collection(ops.GraphKeys.EV_INIT_SLOT_OPS))
+      sess.run([init])
+      sess.run([train_op])
+      full_ckpt_path = saver.save(sess, full_ckpt_dir, global_step=10)
+      full_ckpt_value = sess.run([emb])
+      print("full_ckpt: {}".format(full_ckpt_value))
+      sess.run([train_op])
+      incr_ckpt_path = \
+          incr_saver.incremental_save(sess, incr_ckpt_dir, global_step=20)
+      incr_ckpt_value = sess.run([emb])
+      print("incr_ckpt: {}".format(incr_ckpt_value))
+
+    # check the values after restoring the parameters.
+    with self.test_session() as sess:
+      sess.run(ops.get_collection(ops.GraphKeys.EV_INIT_VAR_OPS))
+      sess.run(ops.get_collection(ops.GraphKeys.EV_INIT_SLOT_OPS))
+      sess.run([init])
+      saver.restore(sess, full_ckpt_path)
+      restore_full_ckpt_value = sess.run([emb])
+      print("restore_full_ckpt: {}".format(restore_full_ckpt_value))
+      incr_saver.incremental_restore(sess, full_ckpt_path, incr_ckpt_path)
+      restore_incr_ckpt_value = sess.run([emb])
+      print("restore_incr_ckpt: {}".format(restore_incr_ckpt_value))
+      self.assertAllClose(full_ckpt_value, restore_full_ckpt_value)
+      self.assertAllClose(incr_ckpt_value, restore_incr_ckpt_value)
+
 if __name__ == "__main__":
   googletest.main()