From aee6311918a6f6a8d388f4caa71dbf0b8b2e8f43 Mon Sep 17 00:00:00 2001 From: "chenbangduo.cbd" Date: Tue, 5 Mar 2024 15:52:19 +0800 Subject: [PATCH] [Embedding] Adjust the header file of embedding variable. Signed-off-by: chenbangduo.cbd --- tensorflow/core/framework/embedding/embedding_var.h | 1 - tensorflow/core/kernels/kv_variable_ops.cc | 1 + tensorflow/core/kernels/kv_variable_restore_ops.cc | 1 + tensorflow/core/kernels/training_ali_ops.cc | 8 ++++---- 4 files changed, 6 insertions(+), 5 deletions(-) diff --git a/tensorflow/core/framework/embedding/embedding_var.h b/tensorflow/core/framework/embedding/embedding_var.h index c0d26a2f4d8..81941bc9ff9 100644 --- a/tensorflow/core/framework/embedding/embedding_var.h +++ b/tensorflow/core/framework/embedding/embedding_var.h @@ -34,7 +34,6 @@ limitations under the License. #include "tensorflow/core/framework/embedding/gpu_hash_map_kv.h" #include "tensorflow/core/framework/embedding/embedding_config.h" #include "tensorflow/core/framework/embedding/storage.h" -#include "tensorflow/core/framework/embedding/storage_factory.h" #include "tensorflow/core/framework/typed_allocator.h" namespace tensorflow { diff --git a/tensorflow/core/kernels/kv_variable_ops.cc b/tensorflow/core/kernels/kv_variable_ops.cc index 5cd0ef140bd..b7567ffe924 100644 --- a/tensorflow/core/kernels/kv_variable_ops.cc +++ b/tensorflow/core/kernels/kv_variable_ops.cc @@ -23,6 +23,7 @@ limitations under the License. #include "tensorflow/core/framework/embedding/cache.h" #include "tensorflow/core/framework/embedding/config.pb.h" #include "tensorflow/core/framework/embedding/embedding_var.h" +#include "tensorflow/core/framework/embedding/storage_factory.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/resource_mgr.h" diff --git a/tensorflow/core/kernels/kv_variable_restore_ops.cc b/tensorflow/core/kernels/kv_variable_restore_ops.cc index 2eccf485ef8..e16db9b4cd6 100644 --- a/tensorflow/core/kernels/kv_variable_restore_ops.cc +++ b/tensorflow/core/kernels/kv_variable_restore_ops.cc @@ -25,6 +25,7 @@ limitations under the License. #include "tensorflow/core/framework/embedding/cache.h" #include "tensorflow/core/framework/embedding/config.pb.h" #include "tensorflow/core/framework/embedding/embedding_var.h" +#include "tensorflow/core/framework/embedding/storage_factory.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/resource_mgr.h" diff --git a/tensorflow/core/kernels/training_ali_ops.cc b/tensorflow/core/kernels/training_ali_ops.cc index 546b30e29dd..fc21ab610cf 100644 --- a/tensorflow/core/kernels/training_ali_ops.cc +++ b/tensorflow/core/kernels/training_ali_ops.cc @@ -236,7 +236,7 @@ class KvSparseApplyAdagradGPUOp : public OpKernel { T** dev_a = dev_v + task_size; CHECK(dev_a); CHECK(dev_v); - DeviceMemoryBase dev_v_ptr(dev_v, sizeof(T*) * task_size * 2); + se::DeviceMemoryBase dev_v_ptr(dev_v, sizeof(T*) * task_size * 2); stream->ThenMemcpy(&dev_v_ptr, v, sizeof(T*) * task_size * 2); int block_size = 128; @@ -1606,7 +1606,7 @@ class KvSparseApplyAdamGPUOp : public OpKernel { CHECK(dev_m_ptr); CHECK(dev_v_ptr); - DeviceMemoryBase dst_ptr(dev_var_ptr, sizeof(T*) * task_size * 3); + se::DeviceMemoryBase dst_ptr(dev_var_ptr, sizeof(T*) * task_size * 3); stream->ThenMemcpy(&dst_ptr, var_ptr, sizeof(T*) * task_size * 3); int block_size = 128; @@ -2579,7 +2579,7 @@ class KvSparseApplyAdamAsyncGPUOp : public OpKernel { CHECK(dev_m_ptr); CHECK(dev_v_ptr); - DeviceMemoryBase dst_ptr(dev_var_ptr, sizeof(T*) * task_size * 3); + se::DeviceMemoryBase dst_ptr(dev_var_ptr, sizeof(T*) * task_size * 3); stream->ThenMemcpy(&dst_ptr, var_ptr, sizeof(T*) * task_size * 3); int block_size = 128; @@ -3236,7 +3236,7 @@ class KvSparseApplyAdamWGPUOp : public OpKernel { CHECK(dev_m_ptr); CHECK(dev_v_ptr); - DeviceMemoryBase dst_ptr(dev_var_ptr, sizeof(T*) * task_size * 3); + se::DeviceMemoryBase dst_ptr(dev_var_ptr, sizeof(T*) * task_size * 3); stream->ThenMemcpy(&dst_ptr, var_ptr, sizeof(T*) * task_size * 3); int block_size = 128;