diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index ce6850eb9da..07115cfea3c 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -1203,6 +1203,7 @@ tf_gen_op_libs( "encode_proto_ops", "experimental_dataset_ops", "feature_column_ops", + "file_slice_sendrecv_ops", "function_ops", "functional_ops", "fused_embedding_ops", @@ -1465,6 +1466,7 @@ cc_library( ":encode_proto_ops_op_lib", ":experimental_dataset_ops_op_lib", ":feature_column_ops_op_lib", + ":file_slice_sendrecv_ops_op_lib", ":function_ops_op_lib", ":functional_ops_op_lib", ":fused_embedding_ops_op_lib", diff --git a/tensorflow/core/framework/rendezvous.h b/tensorflow/core/framework/rendezvous.h index 3c2b20379c8..3aa65534272 100644 --- a/tensorflow/core/framework/rendezvous.h +++ b/tensorflow/core/framework/rendezvous.h @@ -82,6 +82,8 @@ class Rendezvous : public core::RefCounted { friend class FuseRecvOp; friend class SliceSendOp; friend class SliceRecvOp; + friend class FileSliceSendOp; + friend class FileSliceRecvOp; friend class RefSendOp; friend class RefRecvOp; string buf_; diff --git a/tensorflow/core/graph/graph.cc b/tensorflow/core/graph/graph.cc index d9709d39f3f..59b25ee7c36 100644 --- a/tensorflow/core/graph/graph.cc +++ b/tensorflow/core/graph/graph.cc @@ -70,12 +70,14 @@ const std::unordered_map& Node::kNodeClassTable = {"_HostSend", NC_HOST_SEND}, {"_RefSend", NC_REF_SEND}, {"_SliceSend", NC_SLICE_SEND}, + {"_FileSliceSend", NC_FILE_SLICE_SEND}, {"_Recv", NC_RECV}, {"_HostRecv", NC_HOST_RECV}, {"_RefRecv", NC_REF_RECV}, {"_FuseRecv", NC_FUSE_RECV}, {"_HostFuseRecv", NC_HOST_FUSE_RECV}, {"_SliceRecv", NC_SLICE_RECV}, + {"_FileSliceRecv", NC_FILE_SLICE_RECV}, {"Const", NC_CONSTANT}, {"HostConst", NC_CONSTANT}, {"Variable", NC_VARIABLE}, diff --git a/tensorflow/core/graph/graph.h b/tensorflow/core/graph/graph.h index 0baf8f257a9..bd6d18cfc7c 100644 --- a/tensorflow/core/graph/graph.h +++ b/tensorflow/core/graph/graph.h @@ -220,15 +220,19 @@ class Node { bool IsSend() const { return class_ == NC_SEND || class_ == NC_HOST_SEND || class_ == NC_REF_SEND || - class_ == NC_SLICE_SEND; } + class_ == NC_SLICE_SEND || + class_ == NC_FILE_SLICE_SEND; } bool IsSliceSend() const { return class_ == NC_SLICE_SEND; } + bool IsFileSliceSend() const { return class_ == NC_FILE_SLICE_SEND; } bool IsRecv() const { return class_ == NC_RECV || class_ == NC_HOST_RECV || class_ == NC_REF_RECV || - class_ == NC_SLICE_RECV; } + class_ == NC_SLICE_RECV || + class_ == NC_FILE_SLICE_RECV; } bool IsFuseRecv() const { return class_ == NC_FUSE_RECV || class_ == NC_HOST_FUSE_RECV; } bool IsSliceRecv() const {return class_ == NC_SLICE_RECV; } + bool IsFileSliceRecv() const { return class_ == NC_FILE_SLICE_RECV; } bool IsConstant() const { return class_ == NC_CONSTANT; } bool IsStage() const { return class_ == NC_TENSOR_BUFFER_PUT; } bool IsUnstage() const { return class_ == NC_TENSOR_BUFFER_TAKE; } @@ -339,12 +343,14 @@ class Node { NC_HOST_SEND, NC_REF_SEND, NC_SLICE_SEND, + NC_FILE_SLICE_SEND, NC_RECV, NC_HOST_RECV, NC_REF_RECV, NC_FUSE_RECV, NC_HOST_FUSE_RECV, NC_SLICE_RECV, + NC_FILE_SLICE_RECV, NC_CONSTANT, NC_VARIABLE, NC_KV_VAR_HANDLE, @@ -851,8 +857,10 @@ inline bool IsLoopCond(const Node* node) { return node->IsLoopCond(); } inline bool IsControlTrigger(const Node* n) { return n->IsControlTrigger(); } inline bool IsSend(const Node* node) { return node->IsSend(); } inline bool IsSliceSend(const Node* node) { return node->IsSliceSend(); } +inline bool IsFileSliceSend(const Node* node) { return node->IsFileSliceSend(); } inline bool IsRecv(const Node* node) { return node->IsRecv(); } inline bool IsSliceRecv(const Node* node) { return node->IsSliceRecv(); } +inline bool IsFileSliceRecv(const Node* node) { return node->IsFileSliceRecv(); } inline bool IsFuseRecv(const Node* node) { return node->IsFuseRecv(); } inline bool IsHostSend(const Node* node) { return node->IsHostSend(); } inline bool IsHostRecv(const Node* node) { return node->IsHostRecv(); } diff --git a/tensorflow/core/grappler/op_types.cc b/tensorflow/core/grappler/op_types.cc index 1201623ffcd..fd72927bd79 100644 --- a/tensorflow/core/grappler/op_types.cc +++ b/tensorflow/core/grappler/op_types.cc @@ -265,6 +265,10 @@ bool IsExp(const NodeDef& node) { return node.op() == "Exp"; } bool IsFakeParam(const NodeDef& node) { return node.op() == "FakeParam"; } +bool IsFileSliceRecv(const NodeDef& node) { return node.op() == "_FileSliceRecv"; } + +bool IsFileSliceSend(const NodeDef& node) { return node.op() == "_FileSliceSend"; } + bool IsFill(const NodeDef& node) { return node.op() == "Fill"; } bool IsFloorDiv(const NodeDef& node) { return node.op() == "FloorDiv"; } @@ -454,7 +458,8 @@ bool IsReciprocalGrad(const NodeDef& node) { } bool IsRecv(const NodeDef& node) { - return node.op() == "_Recv" || node.op() == "_HostRecv" || IsSliceRecv(node); + return node.op() == "_Recv" || node.op() == "_HostRecv" || + IsSliceRecv(node) || IsFileSliceRecv(node); } bool IsFuseRecv(const NodeDef& node) { @@ -502,7 +507,8 @@ bool IsSelect(const NodeDef& node) { return node.op() == "Select"; } bool IsSeluGrad(const NodeDef& node) { return node.op() == "SeluGrad"; } bool IsSend(const NodeDef& node) { - return node.op() == "_Send" || node.op() == "_HostSend" || IsSliceSend(node); + return node.op() == "_Send" || node.op() == "_HostSend" || + IsSliceSend(node) || IsFileSliceSend(node); } bool IsShape(const NodeDef& node) { return node.op() == "Shape"; } diff --git a/tensorflow/core/grappler/op_types.h b/tensorflow/core/grappler/op_types.h index 737581fd412..10968ad2547 100644 --- a/tensorflow/core/grappler/op_types.h +++ b/tensorflow/core/grappler/op_types.h @@ -80,6 +80,8 @@ bool IsExit(const NodeDef& node); bool IsExp(const NodeDef& node); bool IsFakeParam(const NodeDef& node); bool IsFill(const NodeDef& node); +bool IsFileSliceRecv(const NodeDef& node); +bool IsFileSliceSend(const NodeDef& node); bool IsFloorDiv(const NodeDef& node); bool IsFloorMod(const NodeDef& node); bool IsFusedBatchNorm(const NodeDef& node); diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD index 36721527cc2..4e6868a9897 100644 --- a/tensorflow/core/kernels/BUILD +++ b/tensorflow/core/kernels/BUILD @@ -5423,6 +5423,7 @@ cc_library( name = "required", deps = [ ":no_op", + ":file_slice_sendrecv_ops", ":fuserecv_ops", ":sendrecv_ops", ":slice_sendrecv_ops", @@ -5446,10 +5447,33 @@ tf_kernel_library( deps = REQUIRED_DEPS, ) +cc_library( + name = "slice_sendrecv_utils", + hdrs = [ + "slice_sendrecv_utils.h" + ], + srcs = [ + "slice_sendrecv_utils.cc", + ], + deps = [ + "//tensorflow/core:framework", + ] +) + tf_kernel_library( name = "slice_sendrecv_ops", prefix = "slice_sendrecv_ops", - deps = REQUIRED_DEPS, + deps = REQUIRED_DEPS + [ + ":slice_sendrecv_utils", + ], +) + +tf_kernel_library( + name = "file_slice_sendrecv_ops", + prefix = "file_slice_sendrecv_ops", + deps = REQUIRED_DEPS + [ + ":slice_sendrecv_utils", + ], ) tf_kernel_library( @@ -5534,6 +5558,26 @@ tf_cc_test( ], ) +tf_cc_test( + name = "file_slice_sendrecv_ops_test", + srcs = ["file_slice_sendrecv_ops_test.cc"], + linkstatic = tf_kernel_tests_linkstatic(), # Required for benchmarking + deps = [ + ":control_flow_ops", + ":cwise_op", + ":file_slice_sendrecv_ops", + ":logging_ops", + ":ops_testutil", + ":ops_util", + ":slice_sendrecv_ops", + ":whole_file_read_ops", + "//tensorflow/core:framework", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", + ], +) + tf_kernel_library( name = "fuserecv_ops", prefix = "fuserecv_ops", diff --git a/tensorflow/core/kernels/file_slice_sendrecv_ops.cc b/tensorflow/core/kernels/file_slice_sendrecv_ops.cc new file mode 100644 index 00000000000..6bfe54363f9 --- /dev/null +++ b/tensorflow/core/kernels/file_slice_sendrecv_ops.cc @@ -0,0 +1,482 @@ +/* Copyright 2023 The DeepRec Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/kernels/file_slice_sendrecv_ops.h" +#include "tensorflow/core/kernels/slice_sendrecv_utils.h" +#include "tensorflow/core/lib/gtl/stl_util.h" +#include "tensorflow/core/lib/io/path.h" +#include "tensorflow/core/platform/env.h" + +namespace tensorflow { + +//------------------------------------------------------------------------------ +// Functions of FileSliceSendOp. + +FileSliceSendOp::FileSliceSendOp(OpKernelConstruction* ctx) : OpKernel(ctx) { + string send_device; + OP_REQUIRES_OK(ctx, ctx->GetAttr("send_device", &send_device)); + string recv_device; + OP_REQUIRES_OK(ctx, ctx->GetAttr("recv_device", &recv_device)); + uint64 send_device_incarnation; + OP_REQUIRES_OK( + ctx, ctx->GetAttr("send_device_incarnation", + reinterpret_cast(&send_device_incarnation))); + string tensor_name; + OP_REQUIRES_OK(ctx, ctx->GetAttr("tensor_name", &tensor_name)); + key_prefix_ = \ + slice_sendrecv::GetSliceRendezvousKeyPrefix(send_device, + recv_device, send_device_incarnation, tensor_name); + + if (!ctx->GetAttr("_hostmem_sendrecv", &hostmem_sendrecv_).ok()) { + hostmem_sendrecv_ = false; + } + OP_REQUIRES_OK(ctx, ctx->GetAttr("slice_size", &slice_size_)); +} + +void FileSliceSendOp::Compute(OpKernelContext* ctx) { + OP_REQUIRES(ctx, ctx->rendezvous() != nullptr, + errors::Internal("Op kernel context needs to provide a rendezvous.")); + + const Tensor& file_path_t = ctx->input(0); + if (!ctx->is_input_dead()) { + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(file_path_t.shape()), + errors::InvalidArgument("file_path is not a scalar: ", + file_path_t.shape().DebugString())); + } + + FrameAndIter frame_iter = \ + slice_sendrecv::GetFrameAndIter(ctx, hostmem_sendrecv_); + + // get element_bytes. + uint64 element_bytes = 0; + OP_REQUIRES_OK(ctx, GetElementBytes(ctx, file_path_t, element_bytes)); + + // send total_bytes. + // total_bytes is the TotalBytes of the Tensor that contains the contents of + // the file. please refer Tensor::TotalBytes() + uint64 total_bytes = element_bytes + sizeof(tstring); + OP_REQUIRES_OK(ctx, SendTotalBytes(ctx, frame_iter, total_bytes)); + // if input is dead, only send total_bytes dead tensor. + if (ctx->is_input_dead()) { + return; + } + + // if total bytes is smaller than slice size, send directly. + if (total_bytes <= slice_size_) { + Rendezvous::Args args; + args.device_context = ctx->op_device_context(); + args.alloc_attrs = ctx->input_alloc_attr(0); + + Rendezvous::ParsedKey parsed_key; + slice_sendrecv::GetSliceRendezvousKey(key_prefix_, "_transfer_data", + frame_iter, &parsed_key.buf_); + VLOG(2) << "FileSliceSend " << parsed_key.buf_; + OP_REQUIRES_OK(ctx, Rendezvous::ParseKey(parsed_key.buf_, &parsed_key)); + Tensor data_t; + OP_REQUIRES_OK(ctx, + ctx->allocate_temp(DT_STRING, TensorShape({}), &data_t)); + if (element_bytes > 0) { + OP_REQUIRES_OK(ctx, ReadFileToString(Env::Default(), + file_path_t.scalar()(), data_t.scalar().data())); + } + OP_REQUIRES_OK(ctx, ctx->rendezvous()->Send(parsed_key,args, data_t, + ctx->is_input_dead())); + return; + } + + // send shape, in order to match the behavior of 'SliceSend'. + OP_REQUIRES_OK(ctx, SendScalarShape(ctx, frame_iter)); + + // send element bytes, in order to match the behavior of 'SliceSend'. + OP_REQUIRES_OK(ctx, SendElementBytes(ctx, frame_iter, element_bytes)); + + // send data. + OP_REQUIRES_OK(ctx, SendFileSlice(ctx, frame_iter, file_path_t, element_bytes)); +} + +Status FileSliceSendOp::GetElementBytes(OpKernelContext* ctx, + const Tensor& file_path_t, + uint64& element_bytes) { + + if (ctx->is_input_dead()) { + element_bytes = 0; + return Status::OK(); + } + + const string& file_path = file_path_t.scalar()(); + Env* env = Env::Default(); + + if (env->FileExists(file_path) != Status::OK()) { + element_bytes = 0; + return Status::OK(); + } + + return env->GetFileSize(file_path, &element_bytes); +} + +Status FileSliceSendOp::SendUInt64MetaMsg(OpKernelContext* ctx, + const FrameAndIter& frame_iter, + const string& name, + const uint64 val) { + Rendezvous::Args args; + args.device_context = ctx->op_device_context(); + args.alloc_attrs = AllocatorAttributes(); + + Rendezvous::ParsedKey parsed_key; + Tensor val_t; + TF_RETURN_IF_ERROR(ctx->allocate_temp(DT_UINT64, TensorShape({}), &val_t)); + val_t.scalar()() = val; + slice_sendrecv::GetSliceRendezvousKey(key_prefix_, name, frame_iter, + &parsed_key.buf_); + VLOG(2) << "FileSliceSend " << parsed_key.buf_; + TF_RETURN_IF_ERROR(Rendezvous::ParseKey(parsed_key.buf_, &parsed_key)); + return ctx->rendezvous()->Send(parsed_key, args, val_t, ctx->is_input_dead()); +} + +Status FileSliceSendOp::SendTotalBytes(OpKernelContext* ctx, + const FrameAndIter& frame_iter, + const uint64 total_bytes) { + return SendUInt64MetaMsg(ctx, frame_iter, "_slice_transfer_totalbytes", + total_bytes); +} + +Status FileSliceSendOp::SendScalarShape(OpKernelContext* ctx, + const FrameAndIter& frame_iter) { + Rendezvous::Args args; + args.device_context = ctx->op_device_context(); + args.alloc_attrs = AllocatorAttributes(); + Rendezvous::ParsedKey parsed_key; + + Tensor shape_t; + TF_RETURN_IF_ERROR(ctx->allocate_temp(DT_INT64, TensorShape({0}), &shape_t)); + slice_sendrecv::GetSliceRendezvousKey(key_prefix_, + "_slice_transfer_shape", frame_iter, &parsed_key.buf_); + VLOG(2) << "FileSliceSend " << parsed_key.buf_; + TF_RETURN_IF_ERROR(Rendezvous::ParseKey(parsed_key.buf_, &parsed_key)); + + return ctx->rendezvous()->Send(parsed_key, args, shape_t, + ctx->is_input_dead()); +} + +Status FileSliceSendOp::SendElementBytes(OpKernelContext* ctx, + const FrameAndIter& frame_iter, + const uint64 element_bytes) { + return SendUInt64MetaMsg(ctx, frame_iter, "_slice_transfer_elements_bytes", + element_bytes); +} + +Status FileSliceSendOp::SendFileSlice(OpKernelContext* ctx, + const FrameAndIter& frame_iter, + const Tensor& file_path_t, + const uint64 element_bytes) { + Rendezvous::Args args; + args.device_context = ctx->op_device_context(); + args.alloc_attrs = AllocatorAttributes(); + Rendezvous::ParsedKey parsed_key; + + std::unique_ptr file; + Env* env = Env::Default(); + const string& file_path = file_path_t.scalar()(); + TF_RETURN_IF_ERROR(env->NewRandomAccessFile(file_path, &file)); + + // Slice Send. + int64 slice_num = element_bytes / slice_size_; + if (element_bytes % slice_size_ != 0) { + slice_num += 1; + } + Tensor data_t; + for (int64 i = 0; i < slice_num; i++) { + TF_RETURN_IF_ERROR(ctx->allocate_temp(DT_STRING, TensorShape({}), &data_t)); + uint64 start = i * slice_size_; + uint64 copy_size = slice_size_; + if (start > element_bytes - slice_size_) { + copy_size = element_bytes - start; + } + TF_RETURN_IF_ERROR(ReadFileSlice(file, start, copy_size, data_t)); + std::string tensor_name_suffix = \ + strings::StrCat("_slice_transfer_data_", std::to_string(0), "_", + std::to_string(i)); + slice_sendrecv::GetSliceRendezvousKey(key_prefix_, tensor_name_suffix, + frame_iter, &parsed_key.buf_); + VLOG(2) << "FileSliceSend " << parsed_key.buf_; + TF_RETURN_IF_ERROR(Rendezvous::ParseKey(parsed_key.buf_, &parsed_key)); + TF_RETURN_IF_ERROR(ctx->rendezvous()->Send(parsed_key, args, data_t, + ctx->is_input_dead())); + } + + + return Status::OK(); +} + +Status FileSliceSendOp::ReadFileSlice( + const std::unique_ptr& file, + const uint64 pos, const uint64 offset, + Tensor& data_t) { + string* data_s = data_t.scalar().data(); + gtl::STLStringResizeUninitialized(data_s, offset); + char* data_p = gtl::string_as_array(data_s); + StringPiece result; + TF_RETURN_IF_ERROR(file->Read(pos, offset, &result, data_p)); + if (result.data() != data_p) { + memmove(data_p, result.data(), result.size()); + } + + return Status::OK(); +} + +REGISTER_KERNEL_BUILDER(Name("_FileSliceSend").Device(DEVICE_CPU), + FileSliceSendOp); +REGISTER_KERNEL_BUILDER(Name("_FileSliceSend").Device(DEVICE_DEFAULT), + FileSliceSendOp); + +//------------------------------------------------------------------------------ +// Functions of FileSliceRecvOp. + +FileSliceRecvOp::FileSliceRecvOp(OpKernelConstruction* ctx) : OpKernel(ctx) { + string send_device; + OP_REQUIRES_OK(ctx, ctx->GetAttr("send_device", &send_device)); + string recv_device; + OP_REQUIRES_OK(ctx, ctx->GetAttr("recv_device", &recv_device)); + uint64 send_device_incarnation; + OP_REQUIRES_OK( + ctx, ctx->GetAttr("send_device_incarnation", + reinterpret_cast(&send_device_incarnation))); + string tensor_name; + OP_REQUIRES_OK(ctx, ctx->GetAttr("tensor_name", &tensor_name)); + key_prefix_ = \ + slice_sendrecv::GetSliceRendezvousKeyPrefix(send_device, + recv_device, send_device_incarnation, tensor_name); + if (!ctx->GetAttr("_hostmem_sendrecv", &hostmem_sendrecv_).ok()) { + hostmem_sendrecv_ = false; + } + OP_REQUIRES_OK(ctx, ctx->GetAttr("recv_dir", &recv_dir_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("slice_size", &slice_size_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("timeout_ms", &timeout_ms_)); +} + +void FileSliceRecvOp::Compute(OpKernelContext* ctx) { + OP_REQUIRES(ctx, ctx->rendezvous() != nullptr, + errors::Internal("Op kernel context needs to provide a rendezvous.")); + + FrameAndIter frame_iter = \ + slice_sendrecv::GetFrameAndIter(ctx, hostmem_sendrecv_); + + bool is_dead = false; + uint64 total_bytes = 0; + OP_REQUIRES_OK(ctx, RecvTotalBytes(ctx, frame_iter, is_dead, total_bytes)); + if (is_dead) { + return; + } + + // Create file path output. + Env* env = Env::Default(); + if (!env->FileExists(recv_dir_).ok()) { + OP_REQUIRES_OK(ctx, env->RecursivelyCreateDir(recv_dir_)); + } + const string &filename = GenerateRecvFileName(ctx->op_kernel().name()); + const string &file_path = io::JoinPath(recv_dir_, "tempfilerecv-"+filename); + Tensor* file_path_t = nullptr; + OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &file_path_t)); + file_path_t->scalar()() = file_path; + + // if total bytes is smaller than slice size, recv directly. + if (total_bytes <= slice_size_) { + OP_REQUIRES_OK(ctx, RecvFile(ctx, frame_iter, file_path)); + return; + } + + // recv shape, in order to match the behavior of 'SliceRecv'. + TensorShape shape; + OP_REQUIRES_OK(ctx, RecvShape(ctx, frame_iter, shape)); + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(shape), + errors::InvalidArgument( + "FileSliceRecv only supports receiving a tensor with a scalar shape.")); + + // recv element_bytes, in order to match the behavior of 'SliceRecv'. + uint64 element_bytes = 0; + OP_REQUIRES_OK(ctx, RecvElementBytes(ctx, frame_iter, element_bytes)); + + // recv data. + OP_REQUIRES_OK(ctx, RecvFileSlice(ctx, frame_iter, element_bytes, file_path)); +} + +Status FileSliceRecvOp::RecvUInt64MetaMsg(OpKernelContext* ctx, + const FrameAndIter& frame_iter, + const string& name, bool &is_dead, + uint64& val) { + Rendezvous::Args args; + args.device_context = ctx->op_device_context(); + args.alloc_attrs = AllocatorAttributes(); + if (ctx->is_eager()) { + // NOTE(fishx): Only set cancellation_manager in eager mode. Because in + // Tensorflow 1.x, session (or graph_mgr) will abort the underlying + // rendezvous if it encounters any error. + args.cancellation_manager = ctx->cancellation_manager(); + } + + Rendezvous::ParsedKey parsed_key; + Tensor val_t; + slice_sendrecv::GetSliceRendezvousKey(key_prefix_, name, frame_iter, + &parsed_key.buf_); + VLOG(2) << "FileSliceRecv " << parsed_key.buf_; + TF_RETURN_IF_ERROR(Rendezvous::ParseKey(parsed_key.buf_, &parsed_key)); + TF_RETURN_IF_ERROR( + ctx->rendezvous()->Recv(parsed_key, args, &val_t, &is_dead, timeout_ms_)); + if (!is_dead) { + val = val_t.scalar()(); + } + + return Status::OK(); +} + +Status FileSliceRecvOp::RecvTotalBytes(OpKernelContext* ctx, + const FrameAndIter& frame_iter, + bool& is_dead, uint64& total_bytes) { + return RecvUInt64MetaMsg(ctx, frame_iter, "_slice_transfer_totalbytes", + is_dead, total_bytes); +} + +string FileSliceRecvOp::GenerateRecvFileName(const string& op_name) { + const std::vector& file_name_vec = absl::StrSplit(op_name, "/"); + return absl::StrJoin(file_name_vec, "_"); +} + +Status FileSliceRecvOp::RecvShape(OpKernelContext* ctx, + const FrameAndIter& frame_iter, + TensorShape& shape) { + Rendezvous::Args args; + args.device_context = ctx->op_device_context(); + args.alloc_attrs = AllocatorAttributes(); + if (ctx->is_eager()) { + // NOTE(fishx): Only set cancellation_manager in eager mode. Because in + // Tensorflow 1.x, session (or graph_mgr) will abort the underlying + // rendezvous if it encounters any error. + args.cancellation_manager = ctx->cancellation_manager(); + } + + Rendezvous::ParsedKey parsed_key; + slice_sendrecv::GetSliceRendezvousKey(key_prefix_, "_slice_transfer_shape", + frame_iter, &parsed_key.buf_); + VLOG(2) << "FileSliceRecv " << parsed_key.buf_; + TF_RETURN_IF_ERROR(Rendezvous::ParseKey(parsed_key.buf_, &parsed_key)); + + Tensor shape_t; + bool is_dead; + TF_RETURN_IF_ERROR(ctx->rendezvous()->Recv(parsed_key, args, &shape_t, + &is_dead, timeout_ms_)); + // This shouldn't be a dead tensor. + CHECK_EQ(is_dead, false); + auto shape_vec = shape_t.vec(); + const int64 num_elements = shape_t.NumElements(); + for (int64 i = 0; i < num_elements; i++) { + shape.AddDim(shape_vec(i)); + } + + return Status::OK(); +} + +Status FileSliceRecvOp::RecvElementBytes(OpKernelContext* ctx, + const FrameAndIter& frame_iter, + uint64& element_bytes) { + bool is_dead = false; + Status s = \ + RecvUInt64MetaMsg(ctx, frame_iter, "_slice_transfer_elements_bytes", is_dead, + element_bytes); + CHECK_EQ(is_dead, false); + + return s; +} + +Status FileSliceRecvOp::RecvFile(OpKernelContext* ctx, + const FrameAndIter& frame_iter, + const string& file_path) { + Rendezvous::Args args; + args.device_context = ctx->op_device_context(); + args.alloc_attrs = ctx->output_alloc_attr(0); + if (ctx->is_eager()) { + // NOTE(fishx): Only set cancellation_manager in eager mode. Because in + // Tensorflow 1.x, session (or graph_mgr) will abort the underlying + // rendezvous if it encounters any error. + args.cancellation_manager = ctx->cancellation_manager(); + } + + Rendezvous::ParsedKey parsed_key; + slice_sendrecv::GetSliceRendezvousKey(key_prefix_, "_transfer_data", + frame_iter, &parsed_key.buf_); + VLOG(2) << "FileSliceRecv " << parsed_key.buf_; + TF_RETURN_IF_ERROR(Rendezvous::ParseKey(parsed_key.buf_, &parsed_key)); + Tensor data_t; + bool is_dead = false; + TF_RETURN_IF_ERROR(ctx->rendezvous()->Recv(parsed_key, args, &data_t, + &is_dead, timeout_ms_)); + + // This shouldn't be a dead tensor. + CHECK_EQ(is_dead, false); + + // Write data_t to file. + Env* env = Env::Default(); + return WriteStringToFile(env, file_path, data_t.scalar()()); +} + +Status FileSliceRecvOp::RecvFileSlice(OpKernelContext* ctx, + const FrameAndIter& frame_iter, + const uint64 element_bytes, + const string& file_path) { + // create file + Env* env = Env::Default(); + std::unique_ptr file_ptr; + TF_RETURN_IF_ERROR(env->NewWritableFile(file_path, &file_ptr)); + + Rendezvous::Args args; + args.device_context = ctx->op_device_context(); + args.alloc_attrs = ctx->output_alloc_attr(0); + if (ctx->is_eager()) { + // NOTE(fishx): Only set cancellation_manager in eager mode. Because in + // Tensorflow 1.x, session (or graph_mgr) will abort the underlying + // rendezvous if it encounters any error. + args.cancellation_manager = ctx->cancellation_manager(); + } + Rendezvous::ParsedKey parsed_key; + + int64 slice_num = element_bytes / slice_size_; + if (element_bytes % slice_size_ != 0) { + slice_num += 1; + } + Tensor data_t; + bool is_dead = false; + for (int64 i = 0; i < slice_num; i++) { + std::string tensor_name_suffix = \ + strings::StrCat("_slice_transfer_data_", std::to_string(0), "_", + std::to_string(i)); + slice_sendrecv::GetSliceRendezvousKey(key_prefix_, tensor_name_suffix, + frame_iter, &parsed_key.buf_); + VLOG(2) << "FileSliceRecv " << parsed_key.buf_; + TF_RETURN_IF_ERROR(Rendezvous::ParseKey(parsed_key.buf_, &parsed_key)); + TF_RETURN_IF_ERROR(ctx->rendezvous()->Recv(parsed_key, args, &data_t, + &is_dead, timeout_ms_)); + // This shouldn't be a dead tensor. + CHECK_EQ(is_dead, false); + file_ptr->Append(data_t.scalar()()); + } + + return Status::OK(); +} + +REGISTER_KERNEL_BUILDER(Name("_FileSliceRecv").Device(DEVICE_CPU), + FileSliceRecvOp); +REGISTER_KERNEL_BUILDER(Name("_FileSliceRecv").Device(DEVICE_DEFAULT), + FileSliceRecvOp); + +}; // End of namespace tensorflow diff --git a/tensorflow/core/kernels/file_slice_sendrecv_ops.h b/tensorflow/core/kernels/file_slice_sendrecv_ops.h new file mode 100644 index 00000000000..6701196d481 --- /dev/null +++ b/tensorflow/core/kernels/file_slice_sendrecv_ops.h @@ -0,0 +1,98 @@ +/* Copyright 2023 The DeepRec Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_FILE_SLICE_SENDRECV_OPS_H_ +#define TENSORFLOW_CORE_KERNELS_FILE_SLICE_SENDRECV_OPS_H_ + +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/platform/macros.h" + +namespace tensorflow { + +class FileSliceSendOp : public OpKernel { + public: + explicit FileSliceSendOp(OpKernelConstruction* ctx); + void Compute(OpKernelContext* ctx) override; + + private: + // Variables. + string key_prefix_; + bool hostmem_sendrecv_; + int32 slice_size_; + + // Functions. + Status GetElementBytes(OpKernelContext* ctx, const Tensor& file_path_t, + uint64& element_bytes); + + Status SendUInt64MetaMsg(OpKernelContext* ctx, const FrameAndIter& frame_iter, + const string& name, const uint64 val); + + Status SendTotalBytes(OpKernelContext* ctx, const FrameAndIter& frame_iter, + const uint64 total_bytes); + + Status SendScalarShape(OpKernelContext* ctx, const FrameAndIter& frame_iter); + + Status SendElementBytes(OpKernelContext* ctx, const FrameAndIter& frame_iter, + const uint64 element_bytes); + + Status SendFileSlice(OpKernelContext* ctx, const FrameAndIter& frame_iter, + const Tensor& file_path_t, const uint64 element_bytes); + + Status ReadFileSlice(const std::unique_ptr& file, + const uint64 pos, const uint64 offset, Tensor& data_t); + + TF_DISALLOW_COPY_AND_ASSIGN(FileSliceSendOp); +}; + +class FileSliceRecvOp: public OpKernel { + public: + explicit FileSliceRecvOp(OpKernelConstruction* ctx); + void Compute(OpKernelContext* ctx) override; + + private: + // Variables. + string key_prefix_; + bool hostmem_sendrecv_; + string recv_dir_; + int32 slice_size_; + int64 timeout_ms_; + + // Functions. + Status RecvUInt64MetaMsg(OpKernelContext* ctx, const FrameAndIter& frame_iter, + const string& name, bool &is_dead, uint64& val); + + Status RecvTotalBytes(OpKernelContext* ctx, const FrameAndIter& frame_iter, + bool& is_dead, uint64& total_bytes); + + string GenerateRecvFileName(const string& op_name); + + Status RecvFile(OpKernelContext* ctx, const FrameAndIter& frame_iter, + const string& file_path); + + Status RecvShape(OpKernelContext* ctx, const FrameAndIter& frame_iter, + TensorShape& shape); + + Status RecvElementBytes(OpKernelContext* ctx, const FrameAndIter& frame_iter, + uint64& element_bytes); + + Status RecvFileSlice(OpKernelContext* ctx, const FrameAndIter& frame_iter, + const uint64 element_bytes, const string& file_path); + + TF_DISALLOW_COPY_AND_ASSIGN(FileSliceRecvOp); +}; + +}; // End of namespace tensorflow + +#endif // End of macro TENSORFLOW_CORE_KERNELS_FILE_SLICE_SENDRECV_OPS_H_ diff --git a/tensorflow/core/kernels/file_slice_sendrecv_ops_test.cc b/tensorflow/core/kernels/file_slice_sendrecv_ops_test.cc new file mode 100644 index 00000000000..931cd152253 --- /dev/null +++ b/tensorflow/core/kernels/file_slice_sendrecv_ops_test.cc @@ -0,0 +1,483 @@ +/* Copyright 2023 The DeepRec Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/graph/node_builder.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/test_benchmark.h" + +namespace tensorflow { + +namespace { +// Implement a trivial version of the Rendezvous interface, to avoid +// clouding the benchmark results with the time spent in the various +// implementations, and to avoid the duplicate-send or duplicate-recv +// errors that would arise from running either benchmark in a loop. +class DummyRendezvous : public Rendezvous { + // Functions. + Status Send(const ParsedKey& key, const Args& args, const Tensor& val, + const bool is_dead) override { + std::string key_str = { key.FullKey().data(), key.FullKey().size() }; + mutex_lock l(mu_); + // consumer does not reach. + if (kv_.count(key_str) == 0) { + struct Var var; + var.type = send; + var.args = args; + var.data = val; + var.is_dead = is_dead; + + kv_[key_str] = var; + return Status::OK(); + } + + auto var = kv_[key_str]; + CHECK_EQ(var.type, recv); + var.done(Status::OK(), args, var.args, val, is_dead); + kv_.erase(key_str); + return Status::OK(); + } + void RecvAsync(const ParsedKey& key, const Args& args, + DoneCallback done) override { + std::string key_str = { key.FullKey().data(), key.FullKey().size() }; + + mutex_lock l(mu_); + // producer does not reach. + if (kv_.count(key_str) == 0) { + struct Var var; + var.type = recv; + var.args = args; + var.done = done; + + kv_[key_str] = var; + return; + } + + // auto var = kv_[key_str]; + auto var = kv_[key_str]; + CHECK_EQ(var.type, send); + done(Status::OK(), var.args, args, var.data, var.is_dead); + kv_.erase(key_str); + } + void StartAbort(const Status& status) override {} + + private: + enum RendezvousType { + send, + recv + }; + // Type define. + struct Var { + RendezvousType type; + Args args; + Tensor data; + bool is_dead; + DoneCallback done; + }; + + // Variables. + mutex mu_; + std::unordered_map kv_ GUARDED_BY(mu_); +}; + +//------------------------------------------------------------------------------ +// Utils. +Node* FileSliceSend(Graph* g, Node* filename, const string& tensor, + const string& sender, const uint64 sender_incarnation, + const string& receiver, const int32 slice_size) { + Node* ret; + TF_CHECK_OK(NodeBuilder(g->NewName("FileSliceSend"), "_FileSliceSend") + .Input(filename, 0) + .Attr("tensor_name", tensor) + .Attr("send_device", sender) + .Attr("send_device_incarnation", + static_cast(sender_incarnation)) + .Attr("recv_device", receiver) + .Attr("slice_size", slice_size) + .Finalize(g, &ret)); + + return ret; +} + +Node* FileSliceRecv(Graph* g, const string& tensor, const string& sender, + const uint64 sender_incarnation, const string& receiver, + const string& recv_dir, const int32 slice_size, + const int64 timeout_ms) { + Node* ret; + TF_CHECK_OK(NodeBuilder(g->NewName("FileSliceRecv"), "_FileSliceRecv") + .Attr("tensor_name", tensor) + .Attr("send_device", sender) + .Attr("send_device_incarnation", + static_cast(sender_incarnation)) + .Attr("recv_device", receiver) + .Attr("recv_dir", recv_dir) + .Attr("slice_size", slice_size) + .Attr("timeout_ms", timeout_ms) + .Finalize(g, &ret)); + + return ret; +} + +Node* SliceSend(Graph* g, Node* input, const string& tensor, + const string& sender, const uint64 sender_incarnation, + const string& receiver, const int32 slice_size) { + Node* ret; + TF_CHECK_OK(NodeBuilder(g->NewName("n"), "_SliceSend") + .Input(input, 0) + .Attr("tensor_name", tensor) + .Attr("send_device", sender) + .Attr("send_device_incarnation", + static_cast(sender_incarnation)) + .Attr("recv_device", receiver) + .Attr("slice_size", slice_size) + .Finalize(g, &ret)); + return ret; +} + +Node* SliceRecv(Graph* g, const string& tensor, const string& sender, + const uint64 sender_incarnation, const string& receiver, + const int32 slice_size, const int64 timeout_ms) { + Node* ret; + TF_CHECK_OK(NodeBuilder(g->NewName("n"), "_SliceRecv") + .Attr("tensor_type", DT_STRING) + .Attr("tensor_name", tensor) + .Attr("send_device", sender) + .Attr("send_device_incarnation", + static_cast(sender_incarnation)) + .Attr("recv_device", receiver) + .Attr("slice_size", slice_size) + .Attr("timeout_ms", timeout_ms) + .Finalize(g, &ret)); + return ret; +} + +Node* ReadFile(Graph* g, Node* filename) { + Node* ret; + TF_CHECK_OK(NodeBuilder(g->NewName("ReadFile"), "ReadFile") + .Input(filename, 0) + .Finalize(g, &ret)); + + return ret; +} + +Node* WriteFile(Graph* g, Node* filename, Node* contents) { + Node* ret; + TF_CHECK_OK(NodeBuilder(g->NewName("WriteFile"), "WriteFile") + .Input(filename, 0) + .Input(contents, 0) + .Finalize(g, &ret)); + + return ret; +} + +Node* Equal(Graph* g, Node* x, Node* y) { + Node* ret; + TF_CHECK_OK(NodeBuilder(g->NewName("Equal"), "Equal") + .Input(x) + .Input(y) + .Finalize(g, &ret)); + return ret; +} + +Node* Assert(Graph* g, Node* condition, + std::vector& data) { + Node* ret; + TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Assert") + .Input(condition) + .Input(data) + .Finalize(g, &ret)); + return ret; +} + +//------------------------------------------------------------------------------ +// Graph Constructor. + +static Graph* TransferFile(const std::string& test_type, + const int32 slice_size) { + Graph* g = new Graph(OpRegistry::Global()); + const int64 timeout_ms = 5000; + std::string recv_dir = "/tmp/FileSliceTransferTestRecv"; + std::string filename = "/tmp/FileSliceTransferTestSend/send_" + test_type; + std::string contents = \ + "The quick brown fox jumps over the lazy dog."; // 44 chars. + + // send filename node. + Tensor filename_t(DT_STRING, TensorShape({})); + filename_t.scalar().setConstant(filename); + Node* filename_n = test::graph::Constant(g, filename_t); + + // contents node. + Tensor contents_t(DT_STRING, TensorShape({})); + contents_t.scalar().setConstant(contents); + Node* contents_n = test::graph::Constant(g, contents_t); + + Node* write_file_n = WriteFile(g, filename_n, contents_n); + Node* send_n = \ + FileSliceSend(g, filename_n, test_type, "/cpu:0", 1, "/cpu:0", slice_size); + g->AddControlEdge(write_file_n, send_n); + + Node* recv_n = FileSliceRecv(g, test_type, "/cpu:0", 1, "/cpu:0", recv_dir, + slice_size, timeout_ms); + Node* read_file_n = ReadFile(g, recv_n); + Node* equal_n = Equal(g, contents_n, read_file_n); + + std::vector data_out; + data_out.emplace_back(contents_n, 0); + data_out.emplace_back(read_file_n, 0); + Assert(g, equal_n, data_out); + + return g; +} + +static Graph* FileSliceSendTransferFileToSliceRecv(const std::string& test_type, + const int32 slice_size) { + Graph* g = new Graph(OpRegistry::Global()); + const int64 timeout_ms = 5000; + std::string filename = "/tmp/FileSliceTransferTestSend/send_" + test_type; + std::string contents = \ + "The quick brown fox jumps over the lazy dog."; // 44 chars. + + // send filename node. + Tensor filename_t(DT_STRING, TensorShape({})); + filename_t.scalar().setConstant(filename); + Node* filename_n = test::graph::Constant(g, filename_t); + + // contents node. + Tensor contents_t(DT_STRING, TensorShape({})); + contents_t.scalar().setConstant(contents); + Node* contents_n = test::graph::Constant(g, contents_t); + + Node* write_file_n = WriteFile(g, filename_n, contents_n); + Node* send_n = \ + FileSliceSend(g, filename_n, test_type, "/cpu:0", 1, "/cpu:0", slice_size); + g->AddControlEdge(write_file_n, send_n); + + Node* recv_n = \ + SliceRecv(g, test_type, "/cpu:0", 1, "/cpu:0", slice_size, timeout_ms); + Node* equal_n = Equal(g, contents_n, recv_n); + + std::vector data_out; + data_out.emplace_back(contents_n, 0); + data_out.emplace_back(recv_n, 0); + Assert(g, equal_n, data_out); + + return g; +} + +static Graph* SliceSendTransferFileToFileSliceRecv(const std::string& test_type, + const int32 slice_size) { + Graph* g = new Graph(OpRegistry::Global()); + const int64 timeout_ms = 5000; + std::string recv_dir = "/tmp/FileSliceTransferTestRecv"; + std::string contents = \ + "The quick brown fox jumps over the lazy dog."; // 44 chars. + + // contents node. + Tensor contents_t(DT_STRING, TensorShape({})); + contents_t.scalar().setConstant(contents); + Node* contents_n = test::graph::Constant(g, contents_t); + + Node* send_n = \ + SliceSend(g, contents_n, test_type, "/cpu:0", 1, "/cpu:0", slice_size); + + Node* recv_n = FileSliceRecv(g, test_type, "/cpu:0", 1, "/cpu:0", recv_dir, + slice_size, timeout_ms); + Node* read_file_n = ReadFile(g, recv_n); + Node* equal_n = Equal(g, contents_n, read_file_n); + + std::vector data_out; + data_out.emplace_back(contents_n, 0); + data_out.emplace_back(read_file_n, 0); + Assert(g, equal_n, data_out); + + return g; +} + +static Graph* TransferDeadTensor() { + Graph* g = new Graph(OpRegistry::Global()); + const int32 slice_size = 1024; + const int64 timeout_ms = 5000; + std::string recv_dir = "/tmp/FileSliceTransferTestRecv"; + std::string filename = "/tmp/FileSliceTransferTestSend/send_dead_tensor"; + + // val + Tensor val_t(DT_STRING, TensorShape({})); + val_t.scalar()() = filename; + Node* val_n = test::graph::Constant(g, val_t); + + Tensor pred_t(DT_BOOL, TensorShape({})); + pred_t.scalar()() = true; + Node* pred_n = test::graph::Constant(g, pred_t); + + Node* switch_n = test::graph::Switch(g, val_n, pred_n); + FileSliceSend(g, switch_n, "dead_tensor", "/cpu:0", 1, "/cpu:0", slice_size); + FileSliceRecv(g, "dead_tensor", "/cpu:0", 1, "/cpu:0", recv_dir, slice_size, + timeout_ms); + + return g; +} + +static Graph* FileSliceSendTransferDeadTensorToSliceRecv() { + Graph* g = new Graph(OpRegistry::Global()); + const int32 slice_size = 1024; + const int64 timeout_ms = 5000; + std::string recv_dir = "/tmp/FileSliceTransferTestRecv"; + std::string filename = "/tmp/FileSliceTransferTestSend/send_dead_tensor"; + + // val + Tensor val_t(DT_STRING, TensorShape({})); + val_t.scalar()() = filename; + Node* val_n = test::graph::Constant(g, val_t); + + Tensor pred_t(DT_BOOL, TensorShape({})); + pred_t.scalar()() = true; + Node* pred_n = test::graph::Constant(g, pred_t); + + Node* switch_n = test::graph::Switch(g, val_n, pred_n); + FileSliceSend(g, switch_n, "dead_tensor", "/cpu:0", 1, "/cpu:0", slice_size); + SliceRecv(g, "dead_tensor", "/cpu:0", 1, "/cpu:0", slice_size, timeout_ms); + + return g; +} + +static Graph* SliceSendTransferDeadTensorToFileSliceRecv() { + Graph* g = new Graph(OpRegistry::Global()); + const int32 slice_size = 1024; + const int64 timeout_ms = 5000; + std::string recv_dir = "/tmp/FileSliceTransferTestRecv"; + std::string contents = \ + "The quick brown fox jumps over the lazy dog."; // 44 chars. + + // val + Tensor val_t(DT_STRING, TensorShape({})); + val_t.scalar()() = contents; + Node* val_n = test::graph::Constant(g, val_t); + + Tensor pred_t(DT_BOOL, TensorShape({})); + pred_t.scalar()() = true; + Node* pred_n = test::graph::Constant(g, pred_t); + + Node* switch_n = test::graph::Switch(g, val_n, pred_n); + SliceSend(g, switch_n, "dead_tensor", "/cpu:0", 1, "/cpu:0", slice_size); + FileSliceRecv(g, "dead_tensor", "/cpu:0", 1, "/cpu:0", recv_dir, slice_size, + timeout_ms); + + return g; +} + +static Graph* TransferSmallFile() { + return TransferFile("small_file", 1024); +} + +static Graph* TransferBigFile() { + return TransferFile("big_file", 16); +} + +static Graph* FileSliceSendTransferSmallFileToSliceRecv() { + return FileSliceSendTransferFileToSliceRecv("small_file", 1024); +} + +static Graph* FileSliceSendTransferBigFileToSliceRecv() { + return FileSliceSendTransferFileToSliceRecv("big_file", 16); +} + +static Graph* SliceSendTransferSmallFileToFileSliceRecv() { + return SliceSendTransferFileToFileSliceRecv("small_file", 1024); +} + +static Graph* SliceSendTransferBigFileToFileSliceRecv() { + return SliceSendTransferFileToFileSliceRecv("big_file", 16); +} + +//------------------------------------------------------------------------------ +// Test Function. + +static void BM_TransferSmallFile(int iters) { + testing::UseRealTime(); + testing::ItemsProcessed(static_cast(iters)); + test::Benchmark("cpu", TransferSmallFile(), nullptr, nullptr, + new DummyRendezvous).Run(iters); +} + +static void BM_TransferBigFile(int iters) { + testing::UseRealTime(); + testing::ItemsProcessed(static_cast(iters)); + test::Benchmark("cpu", TransferBigFile(), nullptr, nullptr, + new DummyRendezvous).Run(iters); +} + +static void BM_FileSliceSendTransferSmallFileToSliceRecv(int iters) { + testing::UseRealTime(); + testing::ItemsProcessed(static_cast(iters)); + test::Benchmark("cpu", FileSliceSendTransferSmallFileToSliceRecv(), nullptr, + nullptr, new DummyRendezvous).Run(iters); +} + +static void BM_FileSliceSendTransferBigFileToSliceRecv(int iters) { + testing::UseRealTime(); + testing::ItemsProcessed(static_cast(iters)); + test::Benchmark("cpu", FileSliceSendTransferBigFileToSliceRecv(), nullptr, + nullptr, new DummyRendezvous).Run(iters); +} + +static void BM_SliceSendTransferSmallFileToFileSliceRecv(int iters) { + testing::UseRealTime(); + testing::ItemsProcessed(static_cast(iters)); + test::Benchmark("cpu", SliceSendTransferSmallFileToFileSliceRecv(), nullptr, + nullptr, new DummyRendezvous).Run(iters); +} + +static void BM_SliceSendTransferBigFileToFileSliceRecv(int iters) { + testing::UseRealTime(); + testing::ItemsProcessed(static_cast(iters)); + test::Benchmark("cpu", SliceSendTransferBigFileToFileSliceRecv(), nullptr, + nullptr, new DummyRendezvous).Run(iters); +} + +static void BM_TransferDeadTensor(int iters) { + testing::UseRealTime(); + testing::ItemsProcessed(static_cast(iters)); + test::Benchmark("cpu", TransferDeadTensor(), nullptr, nullptr, + new DummyRendezvous).Run(iters); +} + +static void BM_FileSliceSendTransferDeadTensorToSliceRecv(int iters) { + testing::UseRealTime(); + testing::ItemsProcessed(static_cast(iters)); + test::Benchmark("cpu", FileSliceSendTransferDeadTensorToSliceRecv(), nullptr, + nullptr, new DummyRendezvous).Run(iters); +} + +static void BM_SliceSendTransferDeadTensorToFileSliceRecv(int iters) { + testing::UseRealTime(); + testing::ItemsProcessed(static_cast(iters)); + test::Benchmark("cpu", SliceSendTransferDeadTensorToFileSliceRecv(), nullptr, + nullptr, new DummyRendezvous).Run(iters); +} + +BENCHMARK(BM_TransferSmallFile); +BENCHMARK(BM_TransferBigFile); +BENCHMARK(BM_FileSliceSendTransferSmallFileToSliceRecv); +BENCHMARK(BM_FileSliceSendTransferBigFileToSliceRecv); +BENCHMARK(BM_SliceSendTransferSmallFileToFileSliceRecv); +BENCHMARK(BM_SliceSendTransferBigFileToFileSliceRecv); +BENCHMARK(BM_TransferDeadTensor); +BENCHMARK(BM_FileSliceSendTransferDeadTensorToSliceRecv); +BENCHMARK(BM_SliceSendTransferDeadTensorToFileSliceRecv); + +} // End of anonymous namespace + +} // End of namespace tensorflow diff --git a/tensorflow/core/kernels/slice_sendrecv_ops.cc b/tensorflow/core/kernels/slice_sendrecv_ops.cc index f09f314ae10..25f1a4e8738 100644 --- a/tensorflow/core/kernels/slice_sendrecv_ops.cc +++ b/tensorflow/core/kernels/slice_sendrecv_ops.cc @@ -14,41 +14,10 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/core/kernels/slice_sendrecv_ops.h" +#include "tensorflow/core/kernels/slice_sendrecv_utils.h" namespace tensorflow { -//------------------------------------------------------------------------------ -// Utils. -static string GetSliceRendezvousKeyPrefix(const string& send_device, - const string& recv_device, - const uint64 send_device_incarnation, - const string& tensor_name) { - return strings::StrCat(send_device, ";", - strings::FpToString(send_device_incarnation), ";", - recv_device, ";", tensor_name); -} - -static void GetSliceRendezvousKey(const string& key_prefix, - const string& tensor_name_suffix, - const FrameAndIter& frame_iter, string* key) { - key->clear(); - strings::StrAppend(key, key_prefix, tensor_name_suffix, ";", - frame_iter.frame_id, ":", frame_iter.iter_id); -} - -static FrameAndIter GetFrameAndIter(OpKernelContext* ctx, - bool hostmem_sendrecv) { - if (hostmem_sendrecv && ctx->call_frame() != nullptr) { - // Host memory send/recv pairs are added by - // common_runtime/memory_types.cc. When the pair of nodes are - // added inside a function, we need to use the function call frame - // to formulate the unique rendezvous key. - return FrameAndIter(reinterpret_cast(ctx->call_frame()), 0); - } else { - return ctx->frame_iter(); - } -} - //------------------------------------------------------------------------------ // Functions of SliceSendOp. @@ -64,8 +33,9 @@ SliceSendOp::SliceSendOp(OpKernelConstruction* ctx) : OpKernel(ctx) { string tensor_name; OP_REQUIRES_OK(ctx, ctx->GetAttr("tensor_name", &tensor_name)); key_prefix_ = \ - GetSliceRendezvousKeyPrefix(send_device, recv_device, - send_device_incarnation, tensor_name); + slice_sendrecv::GetSliceRendezvousKeyPrefix(send_device, + recv_device, send_device_incarnation, tensor_name); + if (!ctx->GetAttr("_hostmem_sendrecv", &hostmem_sendrecv_).ok()) { hostmem_sendrecv_ = false; } @@ -79,7 +49,8 @@ void SliceSendOp::Compute(OpKernelContext* ctx) { errors::Internal("Op kernel context needs to provide a rendezvous.")); const Tensor& input_t = ctx->input(0); - FrameAndIter frame_iter = GetFrameAndIter(ctx, hostmem_sendrecv_); + FrameAndIter frame_iter = \ + slice_sendrecv::GetFrameAndIter(ctx, hostmem_sendrecv_); // send total_bytes. OP_REQUIRES_OK(ctx, SendTotalBytes(ctx, frame_iter, input_t)); @@ -95,8 +66,8 @@ void SliceSendOp::Compute(OpKernelContext* ctx) { args.alloc_attrs = ctx->input_alloc_attr(0); Rendezvous::ParsedKey parsed_key; - GetSliceRendezvousKey(key_prefix_, "_transfer_data", frame_iter, - &parsed_key.buf_); + slice_sendrecv::GetSliceRendezvousKey(key_prefix_, "_transfer_data", + frame_iter, &parsed_key.buf_); VLOG(2) << "SliceSend " << parsed_key.buf_; OP_REQUIRES_OK(ctx, Rendezvous::ParseKey(parsed_key.buf_, &parsed_key)); OP_REQUIRES_OK(ctx, ctx->rendezvous()->Send(parsed_key, args, input_t, @@ -124,11 +95,11 @@ Status SliceSendOp::SendTotalBytes(OpKernelContext* ctx, Rendezvous::ParsedKey parsed_key; Tensor total_bytes_t; - TF_RETURN_IF_ERROR(ctx->allocate_temp(DT_INT64, TensorShape({}), + TF_RETURN_IF_ERROR(ctx->allocate_temp(DT_UINT64, TensorShape({}), &total_bytes_t)); - total_bytes_t.scalar()() = input_t.TotalBytes(); - GetSliceRendezvousKey(key_prefix_, "_slice_transfer_totalbytes", frame_iter, - &parsed_key.buf_); + total_bytes_t.scalar()() = input_t.TotalBytes(); + slice_sendrecv::GetSliceRendezvousKey(key_prefix_, + "_slice_transfer_totalbytes", frame_iter, &parsed_key.buf_); VLOG(2) << "SliceSend " << parsed_key.buf_; TF_RETURN_IF_ERROR(Rendezvous::ParseKey(parsed_key.buf_, &parsed_key)); return ctx->rendezvous()->Send(parsed_key, args, total_bytes_t, @@ -152,8 +123,8 @@ Status SliceSendOp::SendShape(OpKernelContext* ctx, for (int i = 0; i < rank; i++) { shape_vec(i) = shape.dim_size(i); } - GetSliceRendezvousKey(key_prefix_, "_slice_transfer_shape", frame_iter, - &parsed_key.buf_); + slice_sendrecv::GetSliceRendezvousKey(key_prefix_, + "_slice_transfer_shape", frame_iter, &parsed_key.buf_); VLOG(2) << "SliceSend " << parsed_key.buf_; TF_RETURN_IF_ERROR(Rendezvous::ParseKey(parsed_key.buf_, &parsed_key)); return ctx->rendezvous()->Send(parsed_key, args, shape_t, @@ -168,21 +139,21 @@ Status SliceSendOp::SendString(OpKernelContext* ctx, args.alloc_attrs = AllocatorAttributes(); Rendezvous::ParsedKey parsed_key; - // send elements size. - Tensor elements_size_t; - TF_RETURN_IF_ERROR(ctx->allocate_temp(DT_INT64, input_t.shape(), - &elements_size_t)); + // send elements bytes. + Tensor elements_bytes_t; + TF_RETURN_IF_ERROR(ctx->allocate_temp(DT_UINT64, input_t.shape(), + &elements_bytes_t)); int64 num_elements = input_t.NumElements(); auto input_flat = input_t.flat(); - auto elements_size_flat = elements_size_t.flat(); + auto elements_bytes_flat = elements_bytes_t.flat(); for (int64 i = 0; i < num_elements; i++) { - elements_size_flat(i) = input_flat(i).size(); + elements_bytes_flat(i) = input_flat(i).size(); } - GetSliceRendezvousKey(key_prefix_, "_slice_transfer_elements_size", - frame_iter, &parsed_key.buf_); + slice_sendrecv::GetSliceRendezvousKey(key_prefix_, + "_slice_transfer_elements_bytes", frame_iter, &parsed_key.buf_); VLOG(2) << "SliceSend " << parsed_key.buf_; TF_RETURN_IF_ERROR(Rendezvous::ParseKey(parsed_key.buf_, &parsed_key)); - TF_RETURN_IF_ERROR(ctx->rendezvous()->Send(parsed_key, args, elements_size_t, + TF_RETURN_IF_ERROR(ctx->rendezvous()->Send(parsed_key, args, elements_bytes_t, ctx->is_input_dead())); // send data. @@ -196,8 +167,8 @@ Status SliceSendOp::SendString(OpKernelContext* ctx, data_t.scalar()() = elem; std::string tensor_name_suffix = \ strings::StrCat("_slice_transfer_data_", std::to_string(i)); - GetSliceRendezvousKey(key_prefix_, tensor_name_suffix, frame_iter, - &parsed_key.buf_); + slice_sendrecv::GetSliceRendezvousKey(key_prefix_, tensor_name_suffix, + frame_iter, &parsed_key.buf_); VLOG(2) << "SliceSend " << parsed_key.buf_; TF_RETURN_IF_ERROR(Rendezvous::ParseKey(parsed_key.buf_, &parsed_key)); TF_RETURN_IF_ERROR(ctx->rendezvous()->Send(parsed_key, args, data_t, @@ -218,7 +189,10 @@ Status SliceSendOp::SendStringSlice(OpKernelContext* ctx, args.alloc_attrs = ctx->input_alloc_attr(0); Rendezvous::ParsedKey parsed_key; - int64 slice_num = (elem.size() + slice_size_ - 1) / slice_size_; + int64 slice_num = elem.size() / slice_size_; + if (elem.size() % slice_size_ != 0) { + slice_num += 1; + } Tensor data_t; for (int64 i = 0; i < slice_num; i++) { TF_RETURN_IF_ERROR(ctx->allocate_temp(DT_STRING, TensorShape({}), &data_t)); @@ -231,8 +205,8 @@ Status SliceSendOp::SendStringSlice(OpKernelContext* ctx, std::string tensor_name_suffix = \ strings::StrCat("_slice_transfer_data_", std::to_string(index), "_", std::to_string(i)); - GetSliceRendezvousKey(key_prefix_, tensor_name_suffix, frame_iter, - &parsed_key.buf_); + slice_sendrecv::GetSliceRendezvousKey(key_prefix_, tensor_name_suffix, + frame_iter, &parsed_key.buf_); VLOG(2) << "SliceSend " << parsed_key.buf_; TF_RETURN_IF_ERROR(Rendezvous::ParseKey(parsed_key.buf_, &parsed_key)); TF_RETURN_IF_ERROR(ctx->rendezvous()->Send(parsed_key, args, data_t, @@ -252,12 +226,15 @@ Status SliceSendOp::SendBasicType(OpKernelContext* ctx, // send data. Tensor data_t; - int64 bytes_num = input_t.TotalBytes(); - int64 slice_num = (bytes_num + slice_size_ - 1) / slice_size_; + size_t bytes_num = input_t.TotalBytes(); + int64 slice_num = bytes_num / slice_size_; + if (bytes_num % slice_size_ != 0) { + slice_num += 1; + } unsigned char* input_base = reinterpret_cast(input_t.data()); for (int64 i = 0; i < slice_num; i++) { - int64 start = i * slice_size_; - int64 copy_size = slice_size_; + size_t start = i * slice_size_; + size_t copy_size = slice_size_; if (start > bytes_num - slice_size_) { copy_size = bytes_num - start; } @@ -267,8 +244,8 @@ Status SliceSendOp::SendBasicType(OpKernelContext* ctx, std::memcpy(data_base, input_base+start, copy_size); std::string tensor_name_suffix = \ strings::StrCat("_slice_transfer_data_", std::to_string(i)); - GetSliceRendezvousKey(key_prefix_, tensor_name_suffix, frame_iter, - &parsed_key.buf_); + slice_sendrecv::GetSliceRendezvousKey(key_prefix_, tensor_name_suffix, + frame_iter, &parsed_key.buf_); VLOG(2) << "SliceSend " << parsed_key.buf_; TF_RETURN_IF_ERROR(Rendezvous::ParseKey(parsed_key.buf_, &parsed_key)); TF_RETURN_IF_ERROR(ctx->rendezvous()->Send(parsed_key, args, data_t, @@ -296,8 +273,8 @@ SliceRecvOp::SliceRecvOp(OpKernelConstruction* ctx) : OpKernel(ctx) { string tensor_name; OP_REQUIRES_OK(ctx, ctx->GetAttr("tensor_name", &tensor_name)); key_prefix_ = \ - GetSliceRendezvousKeyPrefix(send_device, recv_device, - send_device_incarnation, tensor_name); + slice_sendrecv::GetSliceRendezvousKeyPrefix(send_device, + recv_device, send_device_incarnation, tensor_name); if (!ctx->GetAttr("_hostmem_sendrecv", &hostmem_sendrecv_).ok()) { hostmem_sendrecv_ = false; } @@ -311,11 +288,12 @@ void SliceRecvOp::Compute(OpKernelContext* ctx) { ctx, ctx->rendezvous() != nullptr, errors::Internal("Op kernel context needs to provide a rendezvous.")); - FrameAndIter frame_iter = GetFrameAndIter(ctx, hostmem_sendrecv_); + FrameAndIter frame_iter = \ + slice_sendrecv::GetFrameAndIter(ctx, hostmem_sendrecv_); bool is_dead; // recv total_bytes. - int64 total_bytes; + uint64 total_bytes; OP_REQUIRES_OK(ctx, RecvTotalBytes(ctx, frame_iter, is_dead, total_bytes)); if (is_dead) { return; @@ -334,8 +312,8 @@ void SliceRecvOp::Compute(OpKernelContext* ctx) { } Rendezvous::ParsedKey parsed_key; - GetSliceRendezvousKey(key_prefix_, "_transfer_data", frame_iter, - &parsed_key.buf_); + slice_sendrecv::GetSliceRendezvousKey(key_prefix_, "_transfer_data", + frame_iter, &parsed_key.buf_); VLOG(2) << "SliceRecv " << parsed_key.buf_; OP_REQUIRES_OK(ctx, Rendezvous::ParseKey(parsed_key.buf_, &parsed_key)); Tensor data_t; @@ -364,7 +342,7 @@ void SliceRecvOp::Compute(OpKernelContext* ctx) { Status SliceRecvOp::RecvTotalBytes(OpKernelContext* ctx, const FrameAndIter& frame_iter, - bool& is_dead, int64& total_bytes) { + bool& is_dead, uint64& total_bytes) { Rendezvous::Args args; args.device_context = ctx->op_device_context(); args.alloc_attrs = AllocatorAttributes(); @@ -377,14 +355,14 @@ Status SliceRecvOp::RecvTotalBytes(OpKernelContext* ctx, Rendezvous::ParsedKey parsed_key; Tensor total_bytes_t; - GetSliceRendezvousKey(key_prefix_, "_slice_transfer_totalbytes", frame_iter, - &parsed_key.buf_); + slice_sendrecv::GetSliceRendezvousKey(key_prefix_, + "_slice_transfer_totalbytes", frame_iter, &parsed_key.buf_); VLOG(2) << "SliceRecv " << parsed_key.buf_; TF_RETURN_IF_ERROR(Rendezvous::ParseKey(parsed_key.buf_, &parsed_key)); TF_RETURN_IF_ERROR(ctx->rendezvous()->Recv(parsed_key, args, &total_bytes_t, &is_dead, timeout_ms_)); if (!is_dead) { - total_bytes = total_bytes_t.scalar()(); + total_bytes = total_bytes_t.scalar()(); } return Status::OK(); @@ -404,8 +382,8 @@ Status SliceRecvOp::RecvShape(OpKernelContext* ctx, } Rendezvous::ParsedKey parsed_key; - GetSliceRendezvousKey(key_prefix_, "_slice_transfer_shape", frame_iter, - &parsed_key.buf_); + slice_sendrecv::GetSliceRendezvousKey(key_prefix_, "_slice_transfer_shape", + frame_iter, &parsed_key.buf_); VLOG(2) << "SliceRecv " << parsed_key.buf_; TF_RETURN_IF_ERROR(Rendezvous::ParseKey(parsed_key.buf_, &parsed_key)); @@ -439,27 +417,27 @@ Status SliceRecvOp::RecvString(OpKernelContext* ctx, Rendezvous::ParsedKey parsed_key; bool is_dead; - // recv elements size. - GetSliceRendezvousKey(key_prefix_, "_slice_transfer_elements_size", - frame_iter, &parsed_key.buf_); + // recv elements bytes. + slice_sendrecv::GetSliceRendezvousKey(key_prefix_, + "_slice_transfer_elements_bytes", frame_iter, &parsed_key.buf_); VLOG(2) << "SliceRecv " << parsed_key.buf_; TF_RETURN_IF_ERROR(Rendezvous::ParseKey(parsed_key.buf_, &parsed_key)); - Tensor elements_size_t; - TF_RETURN_IF_ERROR(ctx->rendezvous()->Recv(parsed_key, args, &elements_size_t, + Tensor elements_bytes_t; + TF_RETURN_IF_ERROR(ctx->rendezvous()->Recv(parsed_key, args, &elements_bytes_t, &is_dead, timeout_ms_)); // This shouldn't be a dead tensor. CHECK_EQ(is_dead, false); - auto elements_size_flat = elements_size_t.flat(); + auto elements_bytes_flat = elements_bytes_t.flat(); int64 num_elements = shape.num_elements(); args.alloc_attrs = ctx->output_alloc_attr(0); Tensor data_t; auto output_flat = output_t->flat(); for (int64 i = 0; i < num_elements; i++) { - if (elements_size_flat(i) <= slice_size_) { + if (elements_bytes_flat(i) <= slice_size_) { std::string tensor_name_suffix = \ strings::StrCat("_slice_transfer_data_", std::to_string(i)); - GetSliceRendezvousKey(key_prefix_, tensor_name_suffix, frame_iter, - &parsed_key.buf_); + slice_sendrecv::GetSliceRendezvousKey(key_prefix_, tensor_name_suffix, + frame_iter, &parsed_key.buf_); VLOG(2) << "SliceRecv " << parsed_key.buf_; TF_RETURN_IF_ERROR(Rendezvous::ParseKey(parsed_key.buf_, &parsed_key)); TF_RETURN_IF_ERROR(ctx->rendezvous()->Recv(parsed_key, args, &data_t, @@ -469,7 +447,7 @@ Status SliceRecvOp::RecvString(OpKernelContext* ctx, output_flat(i) = data_t.scalar()(); } else { TF_RETURN_IF_ERROR(RecvStringSlice(ctx, frame_iter, i, - elements_size_flat(i), output_flat)); + elements_bytes_flat(i), output_flat)); } } @@ -478,7 +456,8 @@ Status SliceRecvOp::RecvString(OpKernelContext* ctx, Status SliceRecvOp::RecvStringSlice(OpKernelContext* ctx, const FrameAndIter& frame_iter, - const int64 index, const int64 element_size, + const int64 index, + const uint64 element_bytes, TTypes::Flat& output_flat) { Rendezvous::Args args; args.device_context = ctx->op_device_context(); @@ -491,15 +470,18 @@ Status SliceRecvOp::RecvStringSlice(OpKernelContext* ctx, } Rendezvous::ParsedKey parsed_key; - int64 slice_num = (element_size + slice_size_ - 1) / slice_size_; + int64 slice_num = element_bytes / slice_size_; + if (element_bytes % slice_size_ != 0) { + slice_num += 1; + } Tensor data_t; bool is_dead = false; for (int64 i = 0; i < slice_num; i++) { std::string tensor_name_suffix = \ strings::StrCat("_slice_transfer_data_", std::to_string(index), "_", std::to_string(i)); - GetSliceRendezvousKey(key_prefix_, tensor_name_suffix, frame_iter, - &parsed_key.buf_); + slice_sendrecv::GetSliceRendezvousKey(key_prefix_, tensor_name_suffix, + frame_iter, &parsed_key.buf_); VLOG(2) << "SliceRecv " << parsed_key.buf_; TF_RETURN_IF_ERROR(Rendezvous::ParseKey(parsed_key.buf_, &parsed_key)); TF_RETURN_IF_ERROR(ctx->rendezvous()->Recv(parsed_key, args, &data_t, @@ -514,7 +496,7 @@ Status SliceRecvOp::RecvStringSlice(OpKernelContext* ctx, Status SliceRecvOp::RecvBasicType(OpKernelContext* ctx, const FrameAndIter& frame_iter, - const int64 total_bytes, + const uint64 total_bytes, Tensor*& output_t) { Rendezvous::Args args; args.device_context = ctx->op_device_context(); @@ -529,19 +511,22 @@ Status SliceRecvOp::RecvBasicType(OpKernelContext* ctx, Tensor data_t; bool is_dead = false; - int64 slice_num = (total_bytes + slice_size_ - 1) / slice_size_; + int64 slice_num = total_bytes / slice_size_; + if (total_bytes % slice_size_ != 0) { + slice_num += 1; + } unsigned char* output_base = \ reinterpret_cast(output_t->data()); for (int64 i = 0; i < slice_num; i++) { - int64 start = i * slice_size_; - int64 copy_size = slice_size_; + uint64 start = i * slice_size_; + uint64 copy_size = slice_size_; if (start > total_bytes - slice_size_) { copy_size = total_bytes - start; } std::string tensor_name_suffix = \ strings::StrCat("_slice_transfer_data_", std::to_string(i)); - GetSliceRendezvousKey(key_prefix_, tensor_name_suffix, frame_iter, - &parsed_key.buf_); + slice_sendrecv::GetSliceRendezvousKey(key_prefix_, tensor_name_suffix, + frame_iter, &parsed_key.buf_); VLOG(2) << "SliceSend " << parsed_key.buf_; TF_RETURN_IF_ERROR(Rendezvous::ParseKey(parsed_key.buf_, &parsed_key)); TF_RETURN_IF_ERROR(ctx->rendezvous()->Recv(parsed_key, args, &data_t, diff --git a/tensorflow/core/kernels/slice_sendrecv_ops.h b/tensorflow/core/kernels/slice_sendrecv_ops.h index df55c080aa1..43429bff32f 100644 --- a/tensorflow/core/kernels/slice_sendrecv_ops.h +++ b/tensorflow/core/kernels/slice_sendrecv_ops.h @@ -66,7 +66,7 @@ class SliceRecvOp : public OpKernel { // Fucntions. Status RecvTotalBytes(OpKernelContext* ctx, const FrameAndIter& frame_iter, - bool& is_dead, int64& total_bytes); + bool& is_dead, uint64& total_bytes); Status RecvShape(OpKernelContext* ctx, const FrameAndIter& frame_iter, TensorShape& shape); @@ -75,11 +75,11 @@ class SliceRecvOp : public OpKernel { const TensorShape& shape, Tensor*& output_t); Status RecvStringSlice(OpKernelContext* ctx, const FrameAndIter& frame_iter, - const int64 index, const int64 element_size, + const int64 index, const uint64 element_bytes, TTypes::Flat& output_flat); Status RecvBasicType(OpKernelContext* ctx, const FrameAndIter& frame_iter, - const int64 total_bytes, Tensor*& output_t); + const uint64 total_bytes, Tensor*& output_t); TF_DISALLOW_COPY_AND_ASSIGN(SliceRecvOp); }; diff --git a/tensorflow/core/kernels/slice_sendrecv_utils.cc b/tensorflow/core/kernels/slice_sendrecv_utils.cc new file mode 100644 index 00000000000..56c2166c650 --- /dev/null +++ b/tensorflow/core/kernels/slice_sendrecv_utils.cc @@ -0,0 +1,53 @@ +/* Copyright 2023 The DeepRec Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/kernels/slice_sendrecv_utils.h" + +namespace tensorflow { + +namespace slice_sendrecv { + +string GetSliceRendezvousKeyPrefix(const string& send_device, + const string& recv_device, + const uint64 send_device_incarnation, + const string& tensor_name) { + return strings::StrCat(send_device, ";", + strings::FpToString(send_device_incarnation), ";", + recv_device, ";", tensor_name); +} + +void GetSliceRendezvousKey(const string& key_prefix, + const string& tensor_name_suffix, + const FrameAndIter& frame_iter, string* key) { + key->clear(); + strings::StrAppend(key, key_prefix, tensor_name_suffix, ";", + frame_iter.frame_id, ":", frame_iter.iter_id); +} + +FrameAndIter GetFrameAndIter(OpKernelContext* ctx, bool hostmem_sendrecv) { + if (hostmem_sendrecv && ctx->call_frame() != nullptr) { + // Host memory send/recv pairs are added by + // common_runtime/memory_types.cc. When the pair of nodes are + // added inside a function, we need to use the function call frame + // to formulate the unique rendezvous key. + return FrameAndIter(reinterpret_cast(ctx->call_frame()), 0); + } else { + return ctx->frame_iter(); + } +} + +}; // End of namespace slice_sendrecv + +}; // End of namespace tensorflow diff --git a/tensorflow/core/kernels/slice_sendrecv_utils.h b/tensorflow/core/kernels/slice_sendrecv_utils.h new file mode 100644 index 00000000000..3605eece2ca --- /dev/null +++ b/tensorflow/core/kernels/slice_sendrecv_utils.h @@ -0,0 +1,41 @@ +/* Copyright 2023 The DeepRec Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_SLICE_SENDRECV_UTILS_H_ +#define TENSORFLOW_CORE_KERNELS_SLICE_SENDRECV_UTILS_H_ + +#include "tensorflow/core/framework/op_kernel.h" + +namespace tensorflow { + +namespace slice_sendrecv { + +extern string GetSliceRendezvousKeyPrefix(const string& send_device, + const string& recv_device, + const uint64 send_device_incarnation, + const string& tensor_name); + +extern void GetSliceRendezvousKey(const string& key_prefix, + const string& tensor_name_suffix, + const FrameAndIter& frame_iter, string* key); + +extern FrameAndIter GetFrameAndIter(OpKernelContext* ctx, + bool hostmem_sendrecv); + +}; // End of namespace slice_sendrecv + +}; // End of namespace tensorflow + +#endif // End of macro TENSORFLOW_CORE_KERNELS_SLICE_SENDRECV_UTILS_H_ diff --git a/tensorflow/core/ops/file_slice_sendrecv_ops.cc b/tensorflow/core/ops/file_slice_sendrecv_ops.cc new file mode 100644 index 00000000000..c7eb20d1358 --- /dev/null +++ b/tensorflow/core/ops/file_slice_sendrecv_ops.cc @@ -0,0 +1,77 @@ +/* Copyright 2023 The DeepRec Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/shape_inference.h" +#include "tensorflow/core/framework/common_shape_fns.h" + +namespace tensorflow { +REGISTER_OP("_FileSliceSend") + .Input("file_path: string") + .Attr("tensor_name: string") + .Attr("send_device: string") + .Attr("send_device_incarnation: int") + .Attr("recv_device: string") + .Attr("client_terminated: bool = false") + .Attr("slice_size: int >= 1") + .SetIsStateful() + .SetShapeFn(shape_inference::UnknownShape) + .Doc(R"doc( +Sends the file from send_device to recv_device. +Supports sending the file of any size. + +file_path: The file to send. +tensor_name: The name of the tensor to send. +send_device: The name of the device sending the tensor. +send_device_incarnation: The current incarnation of send_device. +recv_device: The name of the device receiving the tensor. +client_terminated: If set to true, this indicates that the node was added + to the graph as a result of a client-side feed or fetch of Tensor data, + in which case the corresponding send or recv is expected to be managed + locally by the caller. +slice_size: The maximum number of bytes transferred at one time. +)doc"); + +REGISTER_OP("_FileSliceRecv") + .Output("file_path: string") + .Attr("tensor_name: string") + .Attr("send_device: string") + .Attr("send_device_incarnation: int") + .Attr("recv_device: string") + .Attr("client_terminated: bool = false") + .Attr("recv_dir: string") + .Attr("slice_size: int >= 1") + .Attr("timeout_ms: int >= 0 = 300000") + .SetIsStateful() + .SetShapeFn(shape_inference::UnknownShape) + .Doc(R"doc( +Receives the file from send_device on recv_device. +Supports recving the file of any size. + +file_path: The file to receive. +tensor_name: The name of the tensor to receive. +send_device: The name of the device sending the tensor. +send_device_incarnation: The current incarnation of send_device. +recv_device: The name of the device receiving the tensor. +client_terminated: If set to true, this indicates that the node was added + to the graph as a result of a client-side feed or fetch of Tensor data, + in which case the corresponding send or recv is expected to be managed + locally by the caller. +recv_dir: the directory to store received file. +slice_size: The maximum number of bytes transferred at one time. +timeout_ms: The maximum wait time for receiving a tensor. +)doc"); + +}; // End of namespace tensorflow