diff --git a/tensorflow/core/distributed_runtime/base_rendezvous_mgr.cc b/tensorflow/core/distributed_runtime/base_rendezvous_mgr.cc index 17935eb8982..ead121b30c8 100644 --- a/tensorflow/core/distributed_runtime/base_rendezvous_mgr.cc +++ b/tensorflow/core/distributed_runtime/base_rendezvous_mgr.cc @@ -34,11 +34,13 @@ limitations under the License. #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/types.h" +#include "tensorflow/core/util/env_var.h" namespace tensorflow { namespace { uint64 kGlobalStepId = 0x100000000000000uLL; + int64 kFlowControlMaxSize = 16; } // namespace anonymous static void StartAbortRendevous(Rendezvous* rendez, const Status& s) { @@ -127,6 +129,23 @@ void BaseRendezvousMgr::FuseRecvLocalAsync( rendez->FuseRecvLocalAsync(parsed_keys, std::move(done_cb)); } +void BaseRendezvousMgr::FlowControlRecvLocalAsync(int64 step_id, + const StringPiece& tag, const Rendezvous::ParsedKey& parsed, + Rendezvous::DoneCallback done) { + auto rendez = FindOrCreate(step_id); + using namespace std::placeholders; + Rendezvous::DoneCallback done_cb = std::bind( + [rendez](Rendezvous::DoneCallback done, + // Begin unbound arguments. + const Status& s, const Rendezvous::Args& send_args, + const Rendezvous::Args& recv_args, const Tensor& v, bool dead) { + rendez->Unref(); + done(s, send_args, recv_args, v, dead); + }, + std::move(done), _1, _2, _3, _4, _5); + rendez->FlowControlRecvLocalAsync(tag, parsed, std::move(done_cb)); +} + void BaseRendezvousMgr::Cleanup(int64 step_id) { Rendezvous* rendez = nullptr; { @@ -174,7 +193,17 @@ BaseRemoteRendezvous::BaseRemoteRendezvous(const WorkerEnv* env, int64 step_id) : env_(env), step_id_(step_id), local_(NewLocalRendezvous()), - session_(nullptr) {} + session_(nullptr), + flow_control_num_(0) { + Status s = ReadInt64FromEnvVar("REMOTE_RENDEZVOUS_FLOW_CONTROL_MAX_SIZE", + kFlowControlMaxSize, &flow_control_max_size_); + if (!s.ok()) { + LOG(ERROR) << "Read REMOTE_RENDEZVOUS_FLOW_CONTROL_MAX_SIZE env error: " + << s.error_message(); + } + VLOG(2) << "BaseRemoteRendezvous set flow control max size: " + << flow_control_max_size_; +} BaseRemoteRendezvous::~BaseRemoteRendezvous() { CHECK(active_.empty()); @@ -221,6 +250,16 @@ Status BaseRemoteRendezvous::Initialize(WorkerSession* session) { std::move(fuse_call.done)); } + std::vector deferred_flow_control_calls; + { + mutex_lock l(mu_); + std::swap(deferred_flow_control_calls, deferred_flow_control_calls_); + } + for (auto& fc_call : deferred_flow_control_calls) { + FlowControlRecvLocalAsyncInternal(fc_call.tag, fc_call.parsed, + std::move(fc_call.done)); + } + return Status::OK(); } @@ -271,6 +310,43 @@ Status BaseRemoteRendezvous::Send(const ParsedKey& parsed, return local_->Send(parsed, args, val, mu, is_dead); } +Status BaseRemoteRendezvous::FlowControlSend(const StringPiece& tag, + const ParsedKey& parsed, + const Args& args, + const Tensor& val, + const bool is_dead, + const int64 timeout_millis) { + VLOG(1) << "BaseRemoteRendezvous FlowControlSend " << this << " " + << parsed.FullKey(); + const std::string tag_string(tag.data(), tag.size()); + { + mutex_lock l(mu_); + while(status_.ok() && flow_control_num_ >= flow_control_max_size_) { + if (flow_control_cv_.wait_for( + l, std::chrono::milliseconds(timeout_millis)) == \ + std::cv_status::timeout) { + return errors::DeadlineExceeded("FlowControlSend has timed out."); + } + } + + if (!status_.ok()) return status_; + DCHECK(is_initialized_locked()); + if (!IsLocalDevice(session_->worker_name, parsed.src_device)) { + return errors::InvalidArgument( + "Invalid rendezvous key (src): ", parsed.FullKey(), " @ ", + session_->worker_name); + } + + flow_control_num_++; + if (flow_control_counters_.count(tag_string) == 0) { + flow_control_counters_[tag_string] = 0; + } + flow_control_counters_[tag_string]++; + } + // Buffers "val" and "device_context" in local_. + return local_->Send(parsed, args, val, is_dead); +} + Status BaseRemoteRendezvous::ValidateDevices(const ParsedKey& parsed, bool is_src) { // Cache session pointer to avoid repeatedly taking & releasing the lock @@ -413,6 +489,63 @@ void BaseRemoteRendezvous::RecvAsync(const ParsedKey& parsed, } } +void BaseRemoteRendezvous::FlowControlRecvAsync(const StringPiece& tag, + const ParsedKey& parsed, + const Args& recv_args, + DoneCallback done) { + VLOG(1) << "RemoteRendezvous FlowControlRecvAsync " << this + << " " << tag << " " << parsed.FullKey(); + + Status s = ValidateDevices(parsed, false /*!is_src*/); + if (s.ok() && !is_initialized()) { + s.Update(errors::Internal( + "FlowControlRecvAsync called when uninitialized (key:", + parsed.FullKey(), ").")); + } + if (!s.ok()) { + done(s, Args(), recv_args, Tensor(), false); + return; + } + + // Are src and dst in the same worker? + if (IsSameWorker(parsed.src, parsed.dst)) { + // Recv the tensor from local_. + local_->RecvAsync( + parsed, recv_args, + [this, tag, parsed, done]( + const Status& status, const Rendezvous::Args& send_args, + const Rendezvous::Args& recv_args, const Tensor& in, bool is_dead) { + VLOG(2) << "RemoteRendezvous Finished Recv " << this << " " + << parsed.FullKey(); + Tensor* out = new Tensor; + StatusCallback final_callback = [done, send_args, recv_args, out, + is_dead](const Status& s) { + done(s, send_args, recv_args, *out, is_dead); + delete out; + }; + + if (status.ok()) { + SameWorkerRecvDone(parsed, send_args, recv_args, in, out, + std::move(final_callback)); + const std::string tag_string(tag.data(), tag.size()); + { + mutex_lock l(mu_); + flow_control_num_--; + DCHECK(flow_control_counters_.count(tag_string) != 0); + flow_control_counters_[tag_string]--; + } + flow_control_cv_.notify_one(); + } else { + final_callback(status); + } + }); + return; + } else { + FlowControlRecvFromRemoteAsync(tag, parsed, recv_args, std::move(done)); + } + +} + void BaseRemoteRendezvous::RecvLocalAsync(const ParsedKey& parsed, DoneCallback done) { { @@ -600,6 +733,58 @@ void BaseRemoteRendezvous::FuseRecvLocalAsyncInternal( } } +void BaseRemoteRendezvous::FlowControlRecvLocalAsync(const StringPiece& tag, + const ParsedKey& parsed, + DoneCallback done) { + { + mutex_lock l(mu_); + if (!is_initialized_locked()) { + // FlowControlRecvLocalAsync can be called (due to an incoming RecvTensor + // RPC from a remote worker) before the RunStep (or PartialRunStep) RPC + // from the master arrives. RecvLocalAsync thus buffers the arguments + // until after the RemoteRendezvous is Initialize()'d, when it completes + // the rendezvous logic. At some point after Initialize() is called, a + // Tensor is produced locally that will then be sent in response to the + // incoming RPC. + DeferredFlowControlCall call(tag, parsed, std::move(done)); + deferred_flow_control_calls_.push_back(call); + return; + } + } + FlowControlRecvLocalAsyncInternal(tag, parsed, std::move(done)); +} + +void BaseRemoteRendezvous::FlowControlRecvLocalAsyncInternal( + const StringPiece& tag, const ParsedKey& parsed, DoneCallback done) { + Status s = ValidateDevices(parsed, true /* is_src */); + if (!s.ok()) { + done(s, Args(), Args(), Tensor(), false); + return; + } + + using namespace std::placeholders; + Rendezvous::DoneCallback done_cb = std::bind( + [this, tag](Rendezvous::DoneCallback done, + // Begin unbound arguments. + const Status& s, const Rendezvous::Args& send_args, + const Rendezvous::Args& recv_args, const Tensor& v, bool dead) { + done(s, send_args, recv_args, v, dead); + if (s.ok()) { + const std::string tag_string(tag.data(), tag.size()); + { + mutex_lock l(mu_); + flow_control_num_--; + DCHECK(flow_control_counters_.count(tag_string) != 0); + flow_control_counters_[tag_string]--; + } + flow_control_cv_.notify_one(); + } + }, + std::move(done), _1, _2, _3, _4, _5); + + local_->RecvAsync(parsed, Args(), std::move(done_cb)); +} + void BaseRemoteRendezvous::FuseRecvFromRemoteAsync( const std::vector& parsed_keys, const Rendezvous::Args& args, @@ -607,6 +792,12 @@ void BaseRemoteRendezvous::FuseRecvFromRemoteAsync( CHECK(false) << "FuseRecvFromRemoteAsync Unimplemented"; } +void BaseRemoteRendezvous::FlowControlRecvFromRemoteAsync( + const StringPiece& tag, const Rendezvous::ParsedKey& parsed, + const Rendezvous::Args& args, DoneCallback done) { + CHECK(false) << "FlowControlRecvFromRemoteAsync Unimplemented."; +} + void BaseRemoteRendezvous::RecvAsync(const ParsedKey& parsed, const Rendezvous::Args& recv_args, RefDoneCallback done) { @@ -636,6 +827,19 @@ void BaseRemoteRendezvous::RecvAsync(const ParsedKey& parsed, } } +int64 BaseRemoteRendezvous::GetAllFlowControlItemNum() { + mutex_lock l(mu_); + return flow_control_num_; +} + +int64 BaseRemoteRendezvous::GetFlowControlItemNum(StringPiece tag) { + const std::string tag_string(tag.data(), tag.size()); + mutex_lock l(mu_); + if (flow_control_counters_.count(tag_string) == 0) + return 0; + return flow_control_counters_[tag_string]; +} + void BaseRemoteRendezvous::StartAbort(const Status& s) { CHECK(!s.ok()); // Use a "derived" status as the status for the rendezvous. Derived @@ -656,7 +860,10 @@ void BaseRemoteRendezvous::StartAbort(const Status& s) { } active_.clear(); } + flow_control_num_ = 0; + flow_control_counters_.clear(); } + flow_control_cv_.notify_all(); } void BaseRemoteRendezvous::RegisterCall(BaseRecvTensorCall* call, @@ -707,4 +914,8 @@ BaseRemoteRendezvous::DeferredFuseCall::DeferredFuseCall( const std::vector& parsed_keys, FuseDoneCallback done) : parsed_keys(parsed_keys), done(std::move(done)) {} +BaseRemoteRendezvous::DeferredFlowControlCall::DeferredFlowControlCall( + const StringPiece& tag, const ParsedKey& parsed, DoneCallback done) + : tag(tag), parsed(parsed), done(std::move(done)) {} + } // end namespace tensorflow diff --git a/tensorflow/core/distributed_runtime/base_rendezvous_mgr.h b/tensorflow/core/distributed_runtime/base_rendezvous_mgr.h index b65e59436c0..fc72d9bedfc 100644 --- a/tensorflow/core/distributed_runtime/base_rendezvous_mgr.h +++ b/tensorflow/core/distributed_runtime/base_rendezvous_mgr.h @@ -17,6 +17,7 @@ limitations under the License. #define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_BASE_RENDEZVOUS_MGR_H_ #include +#include #include #include "tensorflow/core/distributed_runtime/rendezvous_mgr_interface.h" @@ -86,6 +87,10 @@ class BaseRendezvousMgr : public RendezvousMgrInterface { const std::vector& parsed_keys, Rendezvous::FuseDoneCallback done) override; + void FlowControlRecvLocalAsync(int64 step_id, const StringPiece& tag, + const Rendezvous::ParsedKey& parsed, + Rendezvous::DoneCallback done) override; + // Removes rendezvous for "step_id". // // TODO(zhifengc): Have a background thread in worker that @@ -140,6 +145,11 @@ class BaseRemoteRendezvous : public RemoteRendezvous { Status Send(const ParsedKey& key, const Rendezvous::Args& args, Tensor* val, mutex* mu, const bool is_dead) override; + Status FlowControlSend(const StringPiece& tag, const ParsedKey& key, + const Args& args, const Tensor& val, + const bool is_dead, + const int64 timeout_millis) override; + // This method is called only by the RecvOp. It tests to see // whether the value will be produced by a local or remote device // and handles accordingly. In the local case it forwards to @@ -147,6 +157,10 @@ class BaseRemoteRendezvous : public RemoteRendezvous { void RecvAsync(const ParsedKey& key, const Rendezvous::Args& args, DoneCallback done) override; + void FlowControlRecvAsync(const StringPiece& tag, + const ParsedKey& parsed_key, + const Args& args, DoneCallback done) override; + void StartAbort(const Status& status) override; // This method is called only by the local Worker, forwarded through @@ -171,10 +185,18 @@ class BaseRemoteRendezvous : public RemoteRendezvous { void FuseRecvLocalSync(const std::vector& parsed_keys, FuseDoneCallback done); + void FlowControlRecvLocalAsync(const StringPiece& tag, + const ParsedKey& parsed, DoneCallback done); + // For ref send/recv void RecvAsync(const ParsedKey& key, const Rendezvous::Args& args, RefDoneCallback done) override; + // Obtain statistical information + int64 GetAllFlowControlItemNum() override; + + int64 GetFlowControlItemNum(StringPiece tag) override; + protected: virtual void RecvFromRemoteAsync(const Rendezvous::ParsedKey& parsed, const Rendezvous::Args& args, @@ -185,6 +207,10 @@ class BaseRemoteRendezvous : public RemoteRendezvous { const Rendezvous::Args& args, FuseDoneCallback done); + virtual void FlowControlRecvFromRemoteAsync(const StringPiece& tag, + const Rendezvous::ParsedKey& parsed, + const Rendezvous::Args& args, DoneCallback done); + // Returns true if "src" and "dst" are located in the same worker, // and hence may use a local rendezvous. virtual bool IsSameWorker(DeviceNameUtils::ParsedName src, @@ -210,6 +236,12 @@ class BaseRemoteRendezvous : public RemoteRendezvous { mutable mutex mu_; + // For Flow Control. + int64 flow_control_max_size_; + int64 flow_control_num_ GUARDED_BY(mu_); + std::unordered_map flow_control_counters_ GUARDED_BY(mu_); + tensorflow::condition_variable flow_control_cv_; + // Status given by StartAbort() if any. Status status_ GUARDED_BY(mu_); WorkerSession* session_ GUARDED_BY(mu_); // Not owned. @@ -233,6 +265,16 @@ class BaseRemoteRendezvous : public RemoteRendezvous { }; std::vector deferred_fuse_calls_ GUARDED_BY(mu_); + struct DeferredFlowControlCall { + const StringPiece tag; + const ParsedKey parsed; + DoneCallback done; + + DeferredFlowControlCall(const StringPiece& tag, const ParsedKey& parsed, + DoneCallback done); + }; + std::vector deferred_flow_control_calls_ GUARDED_BY(mu_); + typedef std::function InactiveCallback; // Active outstanding RecvTensor calls. @@ -262,6 +304,9 @@ class BaseRemoteRendezvous : public RemoteRendezvous { void FuseRecvLocalAsyncInternal(const std::vector& parsed_keys, FuseDoneCallback done); + void FlowControlRecvLocalAsyncInternal(const StringPiece& tag, + const ParsedKey& parsed, + DoneCallback done); TF_DISALLOW_COPY_AND_ASSIGN(BaseRemoteRendezvous); }; diff --git a/tensorflow/core/distributed_runtime/rendezvous_mgr_interface.h b/tensorflow/core/distributed_runtime/rendezvous_mgr_interface.h index caf4af97ac2..abc971c4552 100644 --- a/tensorflow/core/distributed_runtime/rendezvous_mgr_interface.h +++ b/tensorflow/core/distributed_runtime/rendezvous_mgr_interface.h @@ -40,6 +40,11 @@ class RemoteRendezvous : public Rendezvous { public: // Fully construct the RemoteRendezvous. virtual Status Initialize(WorkerSession* session) = 0; + + // Obtain statistical information + virtual int64 GetAllFlowControlItemNum() = 0; + + virtual int64 GetFlowControlItemNum(StringPiece tag) = 0; }; // RendezvousMgr keeps track of a set of local rendezvous instances. @@ -87,7 +92,11 @@ class RendezvousMgrInterface { virtual void FuseRecvLocalAsync( int64 step_id, const std::vector& parsed_keys, - Rendezvous::FuseDoneCallback done) = 0; + Rendezvous::FuseDoneCallback done) = 0; + + virtual void FlowControlRecvLocalAsync(int64 step_id, const StringPiece& tag, + const Rendezvous::ParsedKey& parsed, + Rendezvous::DoneCallback done) = 0; // Removes rendezvous for "step_id". // diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_remote_worker.cc b/tensorflow/core/distributed_runtime/rpc/grpc_remote_worker.cc index ba95e80b496..c3fb6a8ee6c 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_remote_worker.cc +++ b/tensorflow/core/distributed_runtime/rpc/grpc_remote_worker.cc @@ -63,6 +63,7 @@ class GrpcRemoteWorker : cleanupall_(Method(GrpcWorkerMethod::kCleanupAll)), recvtensor_(Method(GrpcWorkerMethod::kRecvTensor)), fuserecvtensor_(Method(GrpcWorkerMethod::kFuseRecvTensor)), + flowcontrolrecvtensor_(Method(GrpcWorkerMethod::kFlowControlRecvTensor)), recvbuf_(Method(GrpcWorkerMethod::kRecvBuf)), logging_(Method(GrpcWorkerMethod::kLogging)), tracing_(Method(GrpcWorkerMethod::kTracing)), @@ -210,6 +211,14 @@ class GrpcRemoteWorker : IssueRequest(request, response, fuserecvtensor_, done, call_opts); } + void FlowControlRecvTensorAsync(CallOptions* call_opts, + const FlowControlRecvTensorRequest* request, + TensorResponse* response, + StatusCallback done) { + VLOG(1) << "FlowControlRecvTensorAsync req: " << request->DebugString(); + IssueRequest(request, response, flowcontrolrecvtensor_, done, call_opts); + } + void RecvTensorAsync(CallOptions* call_opts, const RecvTensorRequest* request, TensorResponse* response, StatusCallback done) override { VLOG(1) << "RecvTensorAsync req: " << request->DebugString(); @@ -341,6 +350,7 @@ class GrpcRemoteWorker : const ::grpc::string cleanupall_; const ::grpc::string recvtensor_; const ::grpc::string fuserecvtensor_; + const ::grpc::string flowcontrolrecvtensor_; const ::grpc::string recvbuf_; const ::grpc::string logging_; const ::grpc::string tracing_; diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_worker_interface.h b/tensorflow/core/distributed_runtime/rpc/grpc_worker_interface.h index 20f1d2b5a62..2c885fec75d 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_worker_interface.h +++ b/tensorflow/core/distributed_runtime/rpc/grpc_worker_interface.h @@ -6,6 +6,8 @@ namespace tensorflow { class CallOptions; class FuseTensorResponse; class FuseRecvTensorRequest; +class FlowControlRecvTensorRequest; +class TensorResponse; class GrpcWorkerInterface { public: @@ -13,6 +15,10 @@ class GrpcWorkerInterface { const FuseRecvTensorRequest* request, FuseTensorResponse* response, StatusCallback done) = 0; + + virtual void FlowControlRecvTensorAsync(CallOptions* call_opts, + const FlowControlRecvTensorRequest* request, + TensorResponse* response, StatusCallback done) = 0; }; } // namespace tensorflow diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_worker_service.cc b/tensorflow/core/distributed_runtime/rpc/grpc_worker_service.cc index ef4fbeab438..3bdacc29a12 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_worker_service.cc +++ b/tensorflow/core/distributed_runtime/rpc/grpc_worker_service.cc @@ -170,6 +170,15 @@ class GrpcWorkerServiceThread { EnqueueFuseRecvTensorRequestRaw(); } + // Support FlowControlRecv + for (int i = 0; + i < gtl::FindWithDefault( + queue_depth_, static_cast(GrpcWorkerMethod::kFlowControlRecvTensor), + 1000); + ++i) { + EnqueueFlowControlRecvTensorRequestRaw(); + } + void* tag; bool ok; @@ -312,6 +321,24 @@ class GrpcWorkerServiceThread { EnqueueFuseRecvTensorRequestRaw(); } + void FlowControlRecvTensorHandlerRaw( + WorkerCall* call) { + Schedule([this, call]() { + CallOptions* call_opts = new CallOptions; + call->SetCancelCallback([call_opts]() { call_opts->StartCancel(); }); + + worker_->GrpcFlowControlRecvTensorAsync(call_opts, &call->request, + &call->response, + [call, call_opts + ](const Status& s) { + call->ClearCancelCallback(); + delete call_opts; + call->SendResponse(ToGrpcStatus(s)); + }); + }); + EnqueueFlowControlRecvTensorRequestRaw(); + } + void RecvBufHandler(WorkerCall* call) { Schedule([this, call]() { CallOptions* call_opts = new CallOptions; @@ -394,6 +421,19 @@ class GrpcWorkerServiceThread { } } + void EnqueueFlowControlRecvTensorRequestRaw() { + mutex_lock l(shutdown_mu_); + if (!is_shutdown_) { + Call:: + EnqueueRequestForMethod( + worker_service_, cq_.get(), + static_cast(GrpcWorkerMethod::kFlowControlRecvTensor), + &GrpcWorkerServiceThread::FlowControlRecvTensorHandlerRaw, + true /* supports cancel*/); + } + } + GrpcWorker* const worker_ = nullptr; // Not owned. std::unique_ptr<::grpc::ServerCompletionQueue> cq_; std::unique_ptr thread_; @@ -746,6 +786,128 @@ void GrpcWorker::GrpcFuseRecvTensorAsync(CallOptions* opts, }); } +// GrpcFlowControlRecvTensorAsync: unlike the other Worker methods, which use +// protocol buffers for a response object, to avoid extra protocol buffer +// serialization overhead we generate our response directly into a +// ::grpc::ByteBuffer object +void GrpcWorker::GrpcFlowControlRecvTensorAsync(CallOptions* opts, + const FlowControlRecvTensorRequest* request, + ::grpc::ByteBuffer* response, StatusCallback done) { + VLOG(1) << "GrpcFlowControlRecvTensorAsync req: " << request->DebugString(); + const int64 request_id = request->request_id(); + const int64 step_id = request->step_id(); + + bool cache_enabled = (response_cache_ != nullptr && request_id != 0); + + auto do_response = [response, done, cache_enabled](const Tensor& tensor, + bool is_dead, + const Status& status) { + if (status.ok()) { + grpc::EncodeTensorToByteBuffer(is_dead, tensor, cache_enabled, response); + } + done(status); + }; + + // If response cache is enabled and the response cache already contains the + // request, we delegate this retry request to the response cache. Otherwise, + // we add the request to the response cache and start the computation to + // retrieve the requested data. + if (cache_enabled && + response_cache_->QueueRequest(request_id, step_id, do_response)) { + return; + } + + auto rendezvous_done = [this, request_id, do_response, cache_enabled]( + const Tensor& tensor, bool is_dead, + const Status& status) { + if (cache_enabled) { + // Data is ready. Process all pending requests in the response cache. + response_cache_->OnRequestFinished(request_id, tensor, is_dead, status); + } else { + do_response(tensor, is_dead, status); + } + }; + + auto fail = [&rendezvous_done](const Status& status) { + rendezvous_done(Tensor(), false, status); + }; + + Status s = recent_request_ids_.TrackUnique( + request_id, "RecvTensor (GrpcWorker)", *request); + if (!s.ok()) { + fail(s); + return; + } + + const string& key = request->rendezvous_key(); + TRACEPRINTF("RecvTensor: %lld %s", step_id, key.c_str()); + Rendezvous::ParsedKey parsed; + s = Rendezvous::ParseKey(key, &parsed); + Device* src_dev = nullptr; + if (s.ok()) { + s = PrepareRecvTensor(parsed, &src_dev); + } + if (!s.ok()) { + fail(s); + return; + } + + // Request the tensor associated with the rendezvous key. + // Note that we log the cancellation here but do not abort the current step. + // gRPC can generate cancellations in response to transient network failures, + // and aborting the step eliminates the opportunity for client side retries. + // Repeated client failures will eventually cause the step to be aborted by + // the client. + opts->SetCancelCallback( + [step_id]() { LOG(WARNING) << "RecvTensor cancelled for " << step_id; }); + StringPiece tag = request->tag(); + env_->rendezvous_mgr->FlowControlRecvLocalAsync( + step_id, tag, parsed, + [opts, rendezvous_done, src_dev, request]( + const Status& status, const Rendezvous::Args& send_args, + const Rendezvous::Args& recv_args, const Tensor& val, + const bool is_dead) { + opts->ClearCancelCallback(); + if (status.ok()) { + // DMA can only be used for Tensors that do not fall into + // the following three odd edge cases: 1) a zero-size + // buffer, 2) a dead tensor which has an uninit value, and + // 3) the tensor has the on_host allocation attribute, + // i.e. it's in CPU RAM *independent of its assigned + // device type*. + const bool on_host = send_args.alloc_attrs.on_host(); + { + // Non-DMA cases. + if (src_dev->tensorflow_gpu_device_info() && (!on_host)) { + DeviceContext* send_dev_context = send_args.device_context; + AllocatorAttributes alloc_attrs; + alloc_attrs.set_gpu_compatible(true); + alloc_attrs.set_on_host(true); + Allocator* alloc = src_dev->GetAllocator(alloc_attrs); + Tensor* copy = new Tensor(alloc, val.dtype(), val.shape()); + CHECK(send_dev_context) + << "send dev name: " << src_dev->name() + << " gpu_info: " << src_dev->tensorflow_gpu_device_info(); + // "val" is on an accelerator device. Uses the device_context to + // fill the copy on host. + StatusCallback copy_ready = [rendezvous_done, copy, + is_dead](const Status& s) { + // The value is now ready to be returned on the wire. + rendezvous_done(*copy, is_dead, s); + delete copy; + }; + + CopyDeviceToHost(&val, alloc, alloc, request->rendezvous_key(), + src_dev, copy, send_dev_context, copy_ready); + return; + } + } + } + + rendezvous_done(val, is_dead, status); + }); +} + namespace { // If RecvBufRespExtra.tensor_content is a single large string, then gRPC // can stall on the recv side when the string buffer needs to be enlarged, diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_worker_service.h b/tensorflow/core/distributed_runtime/rpc/grpc_worker_service.h index 69759c420cc..48941d438c9 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_worker_service.h +++ b/tensorflow/core/distributed_runtime/rpc/grpc_worker_service.h @@ -51,6 +51,10 @@ class GrpcWorker : public Worker { ::grpc::ByteBuffer* response, StatusCallback done); + virtual void GrpcFlowControlRecvTensorAsync(CallOptions* opts, + const FlowControlRecvTensorRequest* request, + ::grpc::ByteBuffer* response, StatusCallback done); + void LoggingAsync(const LoggingRequest* request, LoggingResponse* response, StatusCallback done) override; diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_worker_service_impl.cc b/tensorflow/core/distributed_runtime/rpc/grpc_worker_service_impl.cc index 515d6e90beb..2095540e36a 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_worker_service_impl.cc +++ b/tensorflow/core/distributed_runtime/rpc/grpc_worker_service_impl.cc @@ -48,6 +48,8 @@ const char* GrpcWorkerMethodName(GrpcWorkerMethod id) { return "/tensorflow.WorkerService/RecvTensor"; case GrpcWorkerMethod::kFuseRecvTensor: return "/tensorflow.WorkerService/FuseRecvTensor"; + case GrpcWorkerMethod::kFlowControlRecvTensor: + return "/tensorflow.WorkerService/FlowControlRecvTensor"; case GrpcWorkerMethod::kRecvBuf: return "/tensorflow.WorkerService/RecvBuf"; case GrpcWorkerMethod::kLogging: diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_worker_service_impl.h b/tensorflow/core/distributed_runtime/rpc/grpc_worker_service_impl.h index ff8e1c07cb4..ad77ee0fd80 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_worker_service_impl.h +++ b/tensorflow/core/distributed_runtime/rpc/grpc_worker_service_impl.h @@ -80,6 +80,7 @@ enum class GrpcWorkerMethod { kCleanupAll, kRecvTensor, kFuseRecvTensor, + kFlowControlRecvTensor, kRecvBuf, kLogging, kTracing, diff --git a/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.cc b/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.cc index 69f1481f59e..267bf09e66f 100644 --- a/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.cc +++ b/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.cc @@ -53,6 +53,10 @@ class RpcRemoteRendezvous : public BaseRemoteRendezvous { const Rendezvous::Args& args, FuseDoneCallback done) override; + void FlowControlRecvFromRemoteAsync(const StringPiece& tag, + const Rendezvous::ParsedKey& parsed, const Rendezvous::Args& recv_args, + DoneCallback done) override; + private: ~RpcRemoteRendezvous() override {} @@ -529,6 +533,247 @@ void RpcRemoteRendezvous::FuseRecvFromRemoteAsync( }); } + + +class FlowControlRpcRecvTensorCall : public BaseRecvTensorCall { + public: + FlowControlRpcRecvTensorCall() + : wi_(nullptr), dst_device_(nullptr) {} + + void Init(WorkerInterface* wi, int64 step_id, const StringPiece& tag, + const StringPiece& key, AllocatorAttributes alloc_attrs, + Device* dst_device, const Rendezvous::Args& recv_args, + Rendezvous::DoneCallback done) { + wi_ = wi; + grpc_wi_ = dynamic_cast(wi_); + alloc_attrs_ = alloc_attrs; + dst_device_ = dst_device; + recv_args_ = recv_args; + done_ = std::move(done); + req_.set_step_id(step_id); + req_.set_tag(tag.data(), tag.size()); + req_.set_request_id(GetUniqueRequestId()); + req_.set_rendezvous_key(key.data(), key.size()); + } + + void Reset() { + // The FlowControlRpcRemoteRendezvous using this object is responsible for + // calling ReleaseWorker() before Reset(). + DCHECK_EQ(static_cast(nullptr), wi_) + << "Leaking WorkerInterface in RpcRecvTensorCall::Reset()."; + + alloc_attrs_ = AllocatorAttributes(); + dst_device_ = nullptr; + // We don't clear opts_ and assume that Init will set up the state for + // opts_ appropriately. + req_.Clear(); + resp_.Clear(); + { + mutex_lock l(mu_); + status_ = Status::OK(); + } + done_ = nullptr; + } + + ~FlowControlRpcRecvTensorCall() override { + // Since only the FlowControlRpcRecvTensorFreeList will delete an + // FlowControlRpcRecvTensorCall, and it always sets this->wi_ to null when + // a call object is released to it, we can assert that this->wi_ is + // always null at the point of deletion. + CHECK_EQ(static_cast(nullptr), wi_) + << "Leaking WorkerInterface in FlowControlRpcRecvTensorCall destructor."; + } + + void Start(std::function recv_done) override { + StartRTCall(std::move(recv_done)); + } + + void StartAbort(const Status& s) override { + { + mutex_lock l(mu_); + status_.Update(s); + } + opts_.StartCancel(); + } + + Status status() const override { + mutex_lock l(mu_); + return status_; + } + + void ReleaseWorker(WorkerCacheInterface* worker_cache) { + DCHECK_NE(static_cast(nullptr), wi_) + << "FlowControlRpcRecvTensorCall::ReleaseWorker() called twice."; + worker_cache->ReleaseWorker(src_worker_, wi_); + wi_ = nullptr; + grpc_wi_ = nullptr; + } + + const Tensor& tensor() const { return resp_.tensor(); } + + bool is_dead() const { return resp_.metadata().is_dead(); } + + Device* dst_device() const { return dst_device_; } + const Rendezvous::Args recv_args() const { return recv_args_; } + const Rendezvous::DoneCallback& done() const { return done_; } + + private: + friend class RpcRemoteRendezvous; + + // Start the main RecvTensor call, checking for an async abort. + void StartRTCall(std::function recv_done) { + resp_.InitAlloc(dst_device_, alloc_attrs_); + using namespace std::placeholders; + StatusCallback cb = std::bind( + [this](std::function recv_done, + // Begin unbound arguments. + const Status& s) { + if (!s.ok()) { + mutex_lock l(mu_); + status_.Update(s); + } + recv_done(); + }, + std::move(recv_done), _1); + grpc_wi_->FlowControlRecvTensorAsync(&opts_, &req_, &resp_, std::move(cb)); + } + + string src_worker_; + string src_rel_device_; + WorkerInterface* wi_; // Not owned. + GrpcWorkerInterface* grpc_wi_; + AllocatorAttributes alloc_attrs_; + Device* dst_device_; + CallOptions opts_; + FlowControlRecvTensorRequest req_; + TensorResponse resp_; + Rendezvous::Args recv_args_; + Rendezvous::DoneCallback done_; + + mutable mutex mu_; + Status status_ GUARDED_BY(mu_); + + TF_DISALLOW_COPY_AND_ASSIGN(FlowControlRpcRecvTensorCall); +}; + +class FlowControlRpcRecvTensorFreeList { + public: + FlowControlRpcRecvTensorFreeList() {} + ~FlowControlRpcRecvTensorFreeList() { + for (size_t i = 0; i < objects_.size(); i++) { + delete objects_[i]; + } + } + + FlowControlRpcRecvTensorCall* New() { + { + mutex_lock l(mu_); + if (!objects_.empty()) { + FlowControlRpcRecvTensorCall* result = objects_.back(); + objects_.pop_back(); + return result; + } + } + return new FlowControlRpcRecvTensorCall; + } + + void Release(FlowControlRpcRecvTensorCall* obj) { + obj->Reset(); + { + mutex_lock l(mu_); + if (objects_.size() < kMaxObjects) { + objects_.push_back(obj); + return; + } + } + delete obj; + } + + private: + static const int kMaxObjects = 1000; + + mutex mu_; + std::vector objects_ GUARDED_BY(mu_); +}; + +static FlowControlRpcRecvTensorFreeList* get_flow_control_call_freelist() { + static FlowControlRpcRecvTensorFreeList* call_freelist = \ + new FlowControlRpcRecvTensorFreeList(); + return call_freelist; +} + +void RpcRemoteRendezvous::FlowControlRecvFromRemoteAsync( + const StringPiece& tag, const Rendezvous::ParsedKey& parsed, + const Rendezvous::Args& recv_args, DoneCallback done) { + CHECK(is_initialized()); + Status s; + + // Prepare a FlowControlRecvTensor call that can handle being aborted. + FlowControlRpcRecvTensorCall* call = get_flow_control_call_freelist()->New(); + + // key.src_device identifies a remote device. + if (!DeviceNameUtils::SplitDeviceName(parsed.src_device, &call->src_worker_, + &call->src_rel_device_)) { + s = errors::Internal(parsed.src_device, + " is invalid remote source device."); + } + + WorkerSession* sess = session(); + WorkerInterface* rwi = + sess->worker_cache->GetOrCreateWorker(call->src_worker_); + if (s.ok() && rwi == nullptr) { + s = errors::Internal("No worker known as ", call->src_worker_); + } + + Device* dst_device; + if (s.ok()) { + s = sess->device_mgr()->LookupDevice(parsed.dst_device, &dst_device); + } + if (!s.ok()) { + if (rwi != nullptr) { + sess->worker_cache->ReleaseWorker(call->src_worker_, rwi); + } + get_flow_control_call_freelist()->Release(call); + done(s, Args(), recv_args, Tensor{}, false); + return; + } + + call->Init(rwi, step_id_, tag, parsed.FullKey(), recv_args.alloc_attrs, + dst_device, recv_args, std::move(done)); + + // Record "call" in active_ so that it can be aborted cleanly. + RegisterCall(call, recv_args); + + // RendezvousMgr already aborted, shouldn't send RPC call any more + if (!call->status().ok()) { + // NOTE: `*sess` can potentially be deleted before we return from + // `call->done()(...)`, so we must release the worker before calling the + // callback. + call->ReleaseWorker(sess->worker_cache.get()); + call->done()(call->status(), Args(), Args(), Tensor(), false); + get_flow_control_call_freelist()->Release(call); + return; + } + + // Start "call". + Ref(); + call->Start([this, call]() { + // Removes "call" from active_. Prevent StartAbort(). + DeregisterCall(call); + // If StartAbort was called prior to DeregisterCall, then the + // current status should be bad. + Status s = call->status(); + // NOTE: `*session()` can potentially be deleted before we return from + // `call->done()(...)`, so we must release the worker before calling the + // callback. + call->ReleaseWorker(session()->worker_cache.get()); + call->done()(s, Args(), call->recv_args(), call->tensor(), call->is_dead()); + get_flow_control_call_freelist()->Release(call); + Unref(); + }); + +} + } // namespace RpcRendezvousMgr::RpcRendezvousMgr(const WorkerEnv* env) diff --git a/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr_test.cc b/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr_test.cc index 5021853ce23..75f41ab3057 100644 --- a/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr_test.cc +++ b/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr_test.cc @@ -211,6 +211,32 @@ TEST_F(RpcRendezvousMgrTest, CleanupAll) { } } +TEST_F(RpcRendezvousMgrTest, FlowControlSend) { + setenv("REMOTE_RENDEZVOUS_FLOW_CONTROL_MAX_SIZE", "2", 1); + const int64 step_id = 123; + const Rendezvous::ParsedKey key = MakeKey(Rendezvous::CreateKey( + "/job:mnist/replica:1/task:2/cpu:0", 7890, + "/job:mnist/replica:1/task:2/cpu:1", "foo", FrameAndIter(0, 0))); + { + RemoteRendezvous* rendez = rmgr_.Find(step_id); + TF_ASSERT_OK(rendez->Initialize(&worker_session_)); + core::ScopedUnref unref(rendez); + Rendezvous::Args args; + TF_ASSERT_OK( + rendez->FlowControlSend("TEST", key, args, V("peach_0"), false)); + TF_ASSERT_OK( + rendez->FlowControlSend("TEST", key, args, V("peach_1"), false)); + + EXPECT_NE( + rendez->FlowControlSend("TEST", key, args, V("peach_2"), false, 100), + Status::OK()); + EXPECT_EQ(rendez->GetAllFlowControlItemNum(), 2); + EXPECT_EQ(rendez->GetFlowControlItemNum("TEST"), 2); + } + + unsetenv("REMOTE_RENDEZVOUS_FLOW_CONTROL_MAX_SIZE"); +} + class DummyDeviceContext : public DeviceContext { public: explicit DummyDeviceContext(int stream_id) : stream_id_(stream_id) {} diff --git a/tensorflow/core/framework/rendezvous.cc b/tensorflow/core/framework/rendezvous.cc index e4db066a562..4d1adf1a070 100644 --- a/tensorflow/core/framework/rendezvous.cc +++ b/tensorflow/core/framework/rendezvous.cc @@ -146,6 +146,47 @@ Status Rendezvous::Recv(const ParsedKey& key, const Args& args, Tensor* val, return Recv(key, args, val, is_dead, no_timeout); } +Status Rendezvous::FlowControlSend(const StringPiece& tag, const ParsedKey& key, + const Args& args, const Tensor& val, + const bool is_dead) { + int64 no_timeout = 300000; + return FlowControlSend(tag, key, args, val, is_dead, no_timeout); +} + +Status Rendezvous::FlowControlRecv(const StringPiece& tag, const ParsedKey& key, + const Args& args, Tensor* val, bool* is_dead, + int64 timeout_ms) { + Status ret; + Notification n; + FlowControlRecvAsync(tag, key, args, [&ret, &n, val, is_dead]( + const Status& s, const Args& send_args, + const Args& recv_args, const Tensor& v, + const bool dead) { + ret = s; + *val = v; + *is_dead = dead; + n.Notify(); + }); + if (timeout_ms > 0) { + int64 timeout_us = timeout_ms * 1000; + bool notified = WaitForNotificationWithTimeout(&n, timeout_us); + if (!notified) { + return Status(error::DEADLINE_EXCEEDED, + "Timed out waiting for notification"); + } + } else { + n.WaitForNotification(); + } + return ret; +} + +Status Rendezvous::FlowControlRecv(const StringPiece& tag, const ParsedKey& key, + const Args& args, Tensor* val, + bool* is_dead) { + const int64 no_timeout = 0; + return FlowControlRecv(tag, key, args, val, is_dead, no_timeout); +} + class LocalRendezvousImpl : public Rendezvous { public: explicit LocalRendezvousImpl() {} diff --git a/tensorflow/core/framework/rendezvous.h b/tensorflow/core/framework/rendezvous.h index 3aa65534272..106c0f26b32 100644 --- a/tensorflow/core/framework/rendezvous.h +++ b/tensorflow/core/framework/rendezvous.h @@ -108,6 +108,17 @@ class Rendezvous : public core::RefCounted { virtual Status Send(const ParsedKey& key, const Args& args, Tensor* ref_val, mutex* ref_mu, const bool is_dead) { return Status::OK(); } + virtual Status FlowControlSend(const StringPiece& tag, const ParsedKey& key, + const Args& args, const Tensor& val, + const bool is_dead, + const int64 timeout_millis) { + return errors::Unimplemented("[Rendezvous] unimplement FlowControlSend."); + } + + virtual Status FlowControlSend(const StringPiece& tag, const ParsedKey& key, + const Args& args, const Tensor& val, + const bool is_dead); + // Callback provided by a tensor consumer waiting on the rendezvous. // It will be invoked when the tensor is available, or when a non-OK // status arises in the production of that tensor. It also gets @@ -139,12 +150,27 @@ class Rendezvous : public core::RefCounted { virtual void FuseRecvAsync(const std::vector& parsed_keys, const Args& args, FuseDoneCallback done) {} + // Local rendezvous does not need this. + virtual void FlowControlRecvAsync(const StringPiece& tag, + const ParsedKey& parsed_key, const Args& args, + DoneCallback done) { + CHECK(false) << "[Rendezvous] unimplement FlowControlRecvAsync."; + } + // Synchronous wrapper for RecvAsync. Status Recv(const ParsedKey& key, const Args& args, Tensor* val, bool* is_dead, int64 timeout_ms); Status Recv(const ParsedKey& key, const Args& args, Tensor* val, bool* is_dead); + // Synchronous wrapper for FlowControlRecvAsync. + Status FlowControlRecv(const StringPiece& tag, const ParsedKey& key, + const Args& args, Tensor* val, bool* is_dead, + int64 timeout_ms); + + Status FlowControlRecv(const StringPiece& tag, const ParsedKey& key, + const Args& args, Tensor* val, bool* is_dead); + // Aborts all pending and future Send/Recv with the given "status". // // StartAbort() does not wait for ongoing calls to finish. diff --git a/tensorflow/core/kernels/file_slice_sendrecv_ops.cc b/tensorflow/core/kernels/file_slice_sendrecv_ops.cc index 6bfe54363f9..a919238a5ee 100644 --- a/tensorflow/core/kernels/file_slice_sendrecv_ops.cc +++ b/tensorflow/core/kernels/file_slice_sendrecv_ops.cc @@ -33,11 +33,10 @@ FileSliceSendOp::FileSliceSendOp(OpKernelConstruction* ctx) : OpKernel(ctx) { OP_REQUIRES_OK( ctx, ctx->GetAttr("send_device_incarnation", reinterpret_cast(&send_device_incarnation))); - string tensor_name; - OP_REQUIRES_OK(ctx, ctx->GetAttr("tensor_name", &tensor_name)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("tensor_name", &tensor_name_)); key_prefix_ = \ slice_sendrecv::GetSliceRendezvousKeyPrefix(send_device, - recv_device, send_device_incarnation, tensor_name); + recv_device, send_device_incarnation, tensor_name_); if (!ctx->GetAttr("_hostmem_sendrecv", &hostmem_sendrecv_).ok()) { hostmem_sendrecv_ = false; @@ -212,8 +211,9 @@ Status FileSliceSendOp::SendFileSlice(OpKernelContext* ctx, frame_iter, &parsed_key.buf_); VLOG(2) << "FileSliceSend " << parsed_key.buf_; TF_RETURN_IF_ERROR(Rendezvous::ParseKey(parsed_key.buf_, &parsed_key)); - TF_RETURN_IF_ERROR(ctx->rendezvous()->Send(parsed_key, args, data_t, - ctx->is_input_dead())); + TF_RETURN_IF_ERROR( + ctx->rendezvous()->FlowControlSend(tensor_name_, parsed_key, args, data_t, + ctx->is_input_dead())); } @@ -253,11 +253,10 @@ FileSliceRecvOp::FileSliceRecvOp(OpKernelConstruction* ctx) : OpKernel(ctx) { OP_REQUIRES_OK( ctx, ctx->GetAttr("send_device_incarnation", reinterpret_cast(&send_device_incarnation))); - string tensor_name; - OP_REQUIRES_OK(ctx, ctx->GetAttr("tensor_name", &tensor_name)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("tensor_name", &tensor_name_)); key_prefix_ = \ slice_sendrecv::GetSliceRendezvousKeyPrefix(send_device, - recv_device, send_device_incarnation, tensor_name); + recv_device, send_device_incarnation, tensor_name_); if (!ctx->GetAttr("_hostmem_sendrecv", &hostmem_sendrecv_).ok()) { hostmem_sendrecv_ = false; } @@ -464,8 +463,9 @@ Status FileSliceRecvOp::RecvFileSlice(OpKernelContext* ctx, frame_iter, &parsed_key.buf_); VLOG(2) << "FileSliceRecv " << parsed_key.buf_; TF_RETURN_IF_ERROR(Rendezvous::ParseKey(parsed_key.buf_, &parsed_key)); - TF_RETURN_IF_ERROR(ctx->rendezvous()->Recv(parsed_key, args, &data_t, - &is_dead, timeout_ms_)); + TF_RETURN_IF_ERROR( + ctx->rendezvous()->FlowControlRecv(tensor_name_, parsed_key, args, + &data_t, &is_dead, timeout_ms_)); // This shouldn't be a dead tensor. CHECK_EQ(is_dead, false); file_ptr->Append(data_t.scalar()()); diff --git a/tensorflow/core/kernels/file_slice_sendrecv_ops.h b/tensorflow/core/kernels/file_slice_sendrecv_ops.h index 6701196d481..df7e6c646f8 100644 --- a/tensorflow/core/kernels/file_slice_sendrecv_ops.h +++ b/tensorflow/core/kernels/file_slice_sendrecv_ops.h @@ -28,6 +28,7 @@ class FileSliceSendOp : public OpKernel { private: // Variables. + string tensor_name_; string key_prefix_; bool hostmem_sendrecv_; int32 slice_size_; @@ -63,6 +64,7 @@ class FileSliceRecvOp: public OpKernel { private: // Variables. + string tensor_name_; string key_prefix_; bool hostmem_sendrecv_; string recv_dir_; diff --git a/tensorflow/core/kernels/file_slice_sendrecv_ops_test.cc b/tensorflow/core/kernels/file_slice_sendrecv_ops_test.cc index 931cd152253..62f5596bb62 100644 --- a/tensorflow/core/kernels/file_slice_sendrecv_ops_test.cc +++ b/tensorflow/core/kernels/file_slice_sendrecv_ops_test.cc @@ -50,6 +50,13 @@ class DummyRendezvous : public Rendezvous { kv_.erase(key_str); return Status::OK(); } + + Status FlowControlSend(const StringPiece& tag, const ParsedKey& key, + const Args& args, const Tensor& val, + const bool is_dead) { + return Send(key, args, val, is_dead); + } + void RecvAsync(const ParsedKey& key, const Args& args, DoneCallback done) override { std::string key_str = { key.FullKey().data(), key.FullKey().size() }; @@ -72,6 +79,12 @@ class DummyRendezvous : public Rendezvous { done(Status::OK(), var.args, args, var.data, var.is_dead); kv_.erase(key_str); } + + void FlowControlRecvAsync(const StringPiece& tag, const ParsedKey& parsed_key, + const Args& args, DoneCallback done) { + RecvAsync(parsed_key, args, done); + } + void StartAbort(const Status& status) override {} private: diff --git a/tensorflow/core/kernels/slice_sendrecv_ops.cc b/tensorflow/core/kernels/slice_sendrecv_ops.cc index 25f1a4e8738..ee0e5426cbc 100644 --- a/tensorflow/core/kernels/slice_sendrecv_ops.cc +++ b/tensorflow/core/kernels/slice_sendrecv_ops.cc @@ -30,11 +30,10 @@ SliceSendOp::SliceSendOp(OpKernelConstruction* ctx) : OpKernel(ctx) { OP_REQUIRES_OK( ctx, ctx->GetAttr("send_device_incarnation", reinterpret_cast(&send_device_incarnation))); - string tensor_name; - OP_REQUIRES_OK(ctx, ctx->GetAttr("tensor_name", &tensor_name)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("tensor_name", &tensor_name_)); key_prefix_ = \ slice_sendrecv::GetSliceRendezvousKeyPrefix(send_device, - recv_device, send_device_incarnation, tensor_name); + recv_device, send_device_incarnation, tensor_name_); if (!ctx->GetAttr("_hostmem_sendrecv", &hostmem_sendrecv_).ok()) { hostmem_sendrecv_ = false; @@ -171,8 +170,9 @@ Status SliceSendOp::SendString(OpKernelContext* ctx, frame_iter, &parsed_key.buf_); VLOG(2) << "SliceSend " << parsed_key.buf_; TF_RETURN_IF_ERROR(Rendezvous::ParseKey(parsed_key.buf_, &parsed_key)); - TF_RETURN_IF_ERROR(ctx->rendezvous()->Send(parsed_key, args, data_t, - ctx->is_input_dead())); + TF_RETURN_IF_ERROR( + ctx->rendezvous()->FlowControlSend(tensor_name_, parsed_key, args, + data_t, ctx->is_input_dead())); } else { TF_RETURN_IF_ERROR(SendStringSlice(ctx, frame_iter, elem, i)); } @@ -209,8 +209,9 @@ Status SliceSendOp::SendStringSlice(OpKernelContext* ctx, frame_iter, &parsed_key.buf_); VLOG(2) << "SliceSend " << parsed_key.buf_; TF_RETURN_IF_ERROR(Rendezvous::ParseKey(parsed_key.buf_, &parsed_key)); - TF_RETURN_IF_ERROR(ctx->rendezvous()->Send(parsed_key, args, data_t, - ctx->is_input_dead())); + TF_RETURN_IF_ERROR( + ctx->rendezvous()->FlowControlSend(tensor_name_, parsed_key, args, data_t, + ctx->is_input_dead())); } return Status::OK(); @@ -248,8 +249,9 @@ Status SliceSendOp::SendBasicType(OpKernelContext* ctx, frame_iter, &parsed_key.buf_); VLOG(2) << "SliceSend " << parsed_key.buf_; TF_RETURN_IF_ERROR(Rendezvous::ParseKey(parsed_key.buf_, &parsed_key)); - TF_RETURN_IF_ERROR(ctx->rendezvous()->Send(parsed_key, args, data_t, - ctx->is_input_dead())); + TF_RETURN_IF_ERROR( + ctx->rendezvous()->FlowControlSend(tensor_name_, parsed_key, args, data_t, + ctx->is_input_dead())); } return Status::OK(); @@ -270,11 +272,10 @@ SliceRecvOp::SliceRecvOp(OpKernelConstruction* ctx) : OpKernel(ctx) { OP_REQUIRES_OK( ctx, ctx->GetAttr("send_device_incarnation", reinterpret_cast(&send_device_incarnation))); - string tensor_name; - OP_REQUIRES_OK(ctx, ctx->GetAttr("tensor_name", &tensor_name)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("tensor_name", &tensor_name_)); key_prefix_ = \ slice_sendrecv::GetSliceRendezvousKeyPrefix(send_device, - recv_device, send_device_incarnation, tensor_name); + recv_device, send_device_incarnation, tensor_name_); if (!ctx->GetAttr("_hostmem_sendrecv", &hostmem_sendrecv_).ok()) { hostmem_sendrecv_ = false; } @@ -440,8 +441,9 @@ Status SliceRecvOp::RecvString(OpKernelContext* ctx, frame_iter, &parsed_key.buf_); VLOG(2) << "SliceRecv " << parsed_key.buf_; TF_RETURN_IF_ERROR(Rendezvous::ParseKey(parsed_key.buf_, &parsed_key)); - TF_RETURN_IF_ERROR(ctx->rendezvous()->Recv(parsed_key, args, &data_t, - &is_dead, timeout_ms_)); + TF_RETURN_IF_ERROR( + ctx->rendezvous()->FlowControlRecv(tensor_name_, parsed_key, args, + &data_t, &is_dead, timeout_ms_)); // This shouldn't be a dead tensor. CHECK_EQ(is_dead, false); output_flat(i) = data_t.scalar()(); @@ -484,8 +486,9 @@ Status SliceRecvOp::RecvStringSlice(OpKernelContext* ctx, frame_iter, &parsed_key.buf_); VLOG(2) << "SliceRecv " << parsed_key.buf_; TF_RETURN_IF_ERROR(Rendezvous::ParseKey(parsed_key.buf_, &parsed_key)); - TF_RETURN_IF_ERROR(ctx->rendezvous()->Recv(parsed_key, args, &data_t, - &is_dead, timeout_ms_)); + TF_RETURN_IF_ERROR( + ctx->rendezvous()->FlowControlRecv(tensor_name_, parsed_key, args, + &data_t, &is_dead, timeout_ms_)); // This shouldn't be a dead tensor. CHECK_EQ(is_dead, false); output_flat(index) += data_t.scalar()(); @@ -529,8 +532,9 @@ Status SliceRecvOp::RecvBasicType(OpKernelContext* ctx, frame_iter, &parsed_key.buf_); VLOG(2) << "SliceSend " << parsed_key.buf_; TF_RETURN_IF_ERROR(Rendezvous::ParseKey(parsed_key.buf_, &parsed_key)); - TF_RETURN_IF_ERROR(ctx->rendezvous()->Recv(parsed_key, args, &data_t, - &is_dead, timeout_ms_)); + TF_RETURN_IF_ERROR( + ctx->rendezvous()->FlowControlRecv(tensor_name_, parsed_key, args, + &data_t, &is_dead, timeout_ms_)); // This shouldn't be a dead tensor. CHECK_EQ(is_dead, false); auto data_base = data_t.data(); diff --git a/tensorflow/core/kernels/slice_sendrecv_ops.h b/tensorflow/core/kernels/slice_sendrecv_ops.h index 43429bff32f..12e583e5551 100644 --- a/tensorflow/core/kernels/slice_sendrecv_ops.h +++ b/tensorflow/core/kernels/slice_sendrecv_ops.h @@ -28,6 +28,7 @@ class SliceSendOp : public OpKernel { private: // Variables. + string tensor_name_; string key_prefix_; bool hostmem_sendrecv_; int32 slice_size_; @@ -58,6 +59,7 @@ class SliceRecvOp : public OpKernel { private: // Variable. + string tensor_name_; string key_prefix_; bool hostmem_sendrecv_; int32 slice_size_; diff --git a/tensorflow/core/kernels/slice_sendrecv_ops_test.cc b/tensorflow/core/kernels/slice_sendrecv_ops_test.cc index 5693ed57918..0eeb6d98c36 100644 --- a/tensorflow/core/kernels/slice_sendrecv_ops_test.cc +++ b/tensorflow/core/kernels/slice_sendrecv_ops_test.cc @@ -50,6 +50,13 @@ class DummyRendezvous : public Rendezvous { kv_.erase(key_str); return Status::OK(); } + + Status FlowControlSend(const StringPiece& tag, const ParsedKey& key, + const Args& args, const Tensor& val, + const bool is_dead) { + return Send(key, args, val, is_dead); + } + void RecvAsync(const ParsedKey& key, const Args& args, DoneCallback done) override { std::string key_str = { key.FullKey().data(), key.FullKey().size() }; @@ -72,6 +79,12 @@ class DummyRendezvous : public Rendezvous { done(Status::OK(), var.args, args, var.data, var.is_dead); kv_.erase(key_str); } + + void FlowControlRecvAsync(const StringPiece& tag, const ParsedKey& parsed_key, + const Args& args, DoneCallback done) { + RecvAsync(parsed_key, args, done); + } + void StartAbort(const Status& status) override {} private: diff --git a/tensorflow/core/protobuf/worker.proto b/tensorflow/core/protobuf/worker.proto index 65ec7ffe4bc..fa18fec180c 100644 --- a/tensorflow/core/protobuf/worker.proto +++ b/tensorflow/core/protobuf/worker.proto @@ -441,6 +441,52 @@ message MarkRecvFinishedRequest { message MarkRecvFinishedResponse {} +//////////////////////////////////////////////////////////////////////////////// +// +// FlowControlRecvTensor method request messages +// +//////////////////////////////////////////////////////////////////////////////// + +message FlowControlRecvTensorRequest { + // The step in which the tensor will be produced. + // + // REQUIRED: This must eventually correspond to the `step_id` passed + // into a RunGraph call on the same WorkerService. + int64 step_id = 1; + + string tag = 2; + + // A key identifying the channel to receive tensors from. A RecvTensor request + // retrieves one tensor from the channel, but multiple tensors can be sent and + // received over the same channel with multiple RecvTensor requests. See + // rendezvous.h for details. + string rendezvous_key = 3; + + // If true, use an out-of-band DMA mechanism to transfer the + // received tensor. + bool dma_ok = 4; + + // Optional information on client-side device locality. + DeviceLocality client_locality = 5; + + // Optional information on server-side device locality. + DeviceLocality server_locality = 6; + + // Optional information needed by the RPC subsystem. + google.protobuf.Any transport_options = 7; + + // Unique identifier for this request. Every RecvTensorRequest must have a + // unique request_id, and retried RecvTensorRequests must have the same + // request_id. If request_id is zero, retry detection and response cache + // are disabled. + // + // Retried RecvTensorRequests are problematic because a RecvTensor with no + // corresponding sender will wait forever, and the tensor may have been + // delivered to a previous retry. Workers use request_ids to reject retried + // RecvTensor requests instead of waiting forever. + int64 request_id = 8; +} + //////////////////////////////////////////////////////////////////////////////// // // Logging method request/response messages diff --git a/tensorflow/core/protobuf/worker_service.proto b/tensorflow/core/protobuf/worker_service.proto index 07a64c55ad8..f8e2f313573 100644 --- a/tensorflow/core/protobuf/worker_service.proto +++ b/tensorflow/core/protobuf/worker_service.proto @@ -72,6 +72,11 @@ service WorkerService { // FuseRecvTensor Method } + // See worker.proto for details. + rpc FlowControlRecvTensor(FlowControlRecvTensorRequest) return (RecvTensorResponse) { + // FlowControlRecvTensor Method + } + // See worker.proto for details. rpc Logging(LoggingRequest) returns (LoggingResponse);