From 25e1991178a37f1cd02f7ea96053c04295155abd Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 23 Oct 2024 16:20:05 -0700 Subject: [PATCH] Add cluster register barrier feature - synchronized connects. This has a few benefits: * Simplify state changes during init to be an all-or-nothing op: Nobody will send heartbeats until everybody in the cluster has acked. That's fewer error flows to reason about during init. * This helps with consistent in-sync restarts. There are no weird staggered / cascading restart flows regardless of what the scheduler does. Concretely: 1. Task invokes connect. 2. Block until all tasks connect. In the meantime, tasks may restart multiple times without bringing down the service because initialization has not completed and there is no state that can be corrupted. 3. All tasks connect together. 4. Start sending heartbeats. --- Previously, we started sending heartbeats immediately after (1), which meant that tasks restarting before (2) would result in service crashes, causing additional unnecessary restarts + scheduling overhead. 
PiperOrigin-RevId: 689142223 --- .../coordination/coordination_service.cc | 145 ++++++++++++++---- .../coordination/coordination_service.h | 3 + .../coordination_service_rpc_handler.cc | 2 +- .../coordination/coordination_service_test.cc | 97 ++++++++++++ .../tsl/protobuf/coordination_config.proto | 4 + 5 files changed, 217 insertions(+), 34 deletions(-) diff --git a/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service.cc b/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service.cc index ee17fefe9cf2f8..adec4173877473 100644 --- a/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service.cc +++ b/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service.cc @@ -121,6 +121,8 @@ class CoordinationServiceStandaloneImpl : public CoordinationServiceInterface { absl::Status RegisterTask(const CoordinatedTask& task, uint64_t incarnation) override; + void RegisterTaskAsync(const CoordinatedTask& task, uint64_t incarnation, + StatusCallback done) override; void WaitForAllTasks(const CoordinatedTask& task, const DeviceInfo& devices, StatusCallback done) override; void ShutdownTaskAsync(const CoordinatedTask& task, @@ -155,6 +157,13 @@ class CoordinationServiceStandaloneImpl : public CoordinationServiceInterface { const DeviceInfo& ListClusterDevices() override ABSL_EXCLUSIVE_LOCKS_REQUIRED(state_mu_); uint64_t GetServiceIncarnation() override; + void BarrierAsyncLocked( + std::string barrier_id, absl::Duration timeout, + const CoordinatedTask& task, + const std::vector& participating_tasks, + StatusCallback done) ABSL_EXCLUSIVE_LOCKS_REQUIRED(state_mu_); + StatusCallback ConnectAfterBarrierPasses(absl::string_view task_name, + StatusCallback done); // Checks if any task has stopped sending heartbeats. void CheckHeartbeatTimeout(); // Checks if any barrier has timed out. 
@@ -199,7 +208,7 @@ class CoordinationServiceStandaloneImpl : public CoordinationServiceInterface { std::string_view barrier_id, absl::Duration timeout, const CoordinatedTask& task, const std::vector& participating_tasks, - StatusCallback done); + StatusCallback done) ABSL_EXCLUSIVE_LOCKS_REQUIRED(state_mu_); // Initializes a new barrier. Returns false if the barrier should fail // immediately. bool InitializeBarrier( @@ -210,6 +219,8 @@ class CoordinationServiceStandaloneImpl : public CoordinationServiceInterface { void PassBarrier(std::string_view barrier_id, const absl::Status& result, BarrierState* barrier) ABSL_EXCLUSIVE_LOCKS_REQUIRED(state_mu_); + // Post-barrier hook to connect all tasks. + void ConnectAllTasks() ABSL_EXCLUSIVE_LOCKS_REQUIRED(state_mu_); // Post-barrier hook to aggregate device info. void AggregateClusterDevices() ABSL_EXCLUSIVE_LOCKS_REQUIRED(state_mu_); // Post-shutdown barrier hook to disconnect tasks that acked and propagate @@ -277,6 +288,15 @@ class CoordinationServiceStandaloneImpl : public CoordinationServiceInterface { CoordinatedTaskState GetState() { return state_; } absl::Status GetStatus() { return status_; } uint64_t GetTaskIncarnation() { return task_incarnation_; } + void SetTaskIncarnation(uint64_t task_incarnation) { + task_incarnation_ = task_incarnation; + } + void Connect() { + SetConnected(task_incarnation_); + LOG(INFO) << task_name_ + << " has connected to coordination service. 
Incarnation: " + << task_incarnation_; + } void SetConnected(uint64_t task_incarnation); void Disconnect(uint64_t grace_period_duration_us); absl::Status RecordHeartbeat(uint64_t task_incarnation); @@ -321,6 +341,8 @@ class CoordinationServiceStandaloneImpl : public CoordinationServiceInterface { Env& env_; const uint64_t service_incarnation_ = random::New64(); const uint64_t heartbeat_timeout_ms_; + bool cluster_register_with_barrier_ = false; + const absl::Duration cluster_register_timeout_; const absl::Duration shutdown_barrier_timeout_; // If a task restarts with a new incarnation, we may allow it to reconnect // silently if configured. This is useful when we know that a task can @@ -473,6 +495,9 @@ CoordinationServiceStandaloneImpl::CoordinationServiceStandaloneImpl( ? config.heartbeat_timeout_in_ms() : kDefaultHeartbeatTimeoutMs; }()), + cluster_register_with_barrier_(config.cluster_register_with_barrier()), + cluster_register_timeout_( + absl::Milliseconds(config.cluster_register_timeout_in_ms())), shutdown_barrier_timeout_( absl::Milliseconds(config.shutdown_barrier_timeout_in_ms())), allow_new_incarnation_to_reconnect_( @@ -599,6 +624,10 @@ void CoordinationServiceStandaloneImpl::StartCheckStaleness() { } void CoordinationServiceStandaloneImpl::Stop() { + // Prevent recursion. 
+ if (shutting_down_) { + return; + } { absl::MutexLock l(&kv_mu_); for (const auto& [key, get_kv_callbacks] : get_cb_) { @@ -664,25 +693,56 @@ void CoordinationServiceStandaloneImpl::LogConnectStatusLocked() const { absl::Status CoordinationServiceStandaloneImpl::RegisterTask( const CoordinatedTask& task, uint64_t incarnation) { + absl::Notification done; + absl::Status status; + RegisterTaskAsync(task, incarnation, [&](absl::Status s) { + status = s; + done.Notify(); + }); + done.WaitForNotification(); + return status; +} + +StatusCallback CoordinationServiceStandaloneImpl::ConnectAfterBarrierPasses( + absl::string_view task_name, StatusCallback done) { + return [this, task = std::string(task_name), + done = std::move(done)](absl::Status s) mutable { + state_mu_.AssertHeld(); + if (s.ok()) { + // Connect task to service. + cluster_state_[task]->Connect(); + done(s); + } else { + done(s); + // Initialization failed, stop service now. + Stop(); + } + }; +} + +void CoordinationServiceStandaloneImpl::RegisterTaskAsync( + const CoordinatedTask& task, uint64_t incarnation, StatusCallback done) { const std::string task_name = GetTaskName(task); std::string error_message; absl::MutexLock l(&state_mu_); if (ServiceHasStopped()) { - return MakeCoordinationError(absl::InternalError(absl::StrCat( + done(MakeCoordinationError(absl::InternalError(absl::StrCat( "Coordination service has stopped. RegisterTask() from task: ", task_name, " failed. This usually implies an earlier error that caused " "coordination service to shut down before the workers disconnect " "gracefully. Check the task leader's logs for an earlier error or " "scheduler events (e.g. preemption, eviction) to debug the root " - "cause."))); + "cause.")))); + return; } if (!cluster_state_.contains(task_name)) { // Note: return early here as unexpected task register errors should not // be propagated to other tasks. 
- return MakeCoordinationError(absl::InvalidArgumentError( - absl::StrCat("Unexpected task registered with task_name=", task_name))); + done(MakeCoordinationError(absl::InvalidArgumentError(absl::StrCat( + "Unexpected task registered with task_name=", task_name)))); + return; } auto* task_cluster_state = cluster_state_[task_name].get(); @@ -700,12 +760,26 @@ absl::Status CoordinationServiceStandaloneImpl::RegisterTask( // an unavailable error state, but has now restarted (possibly with // a new incarnation). This is only allowed if configured with // `allow_new_incarnation_to_reconnect`. - task_cluster_state->SetConnected(incarnation); - LOG(INFO) << task_name - << " has connected to coordination service. Incarnation: " - << incarnation; + if (cluster_register_with_barrier_) { + // Impose barrier so that all tasks can register together. + // Note: it is possible that the same task restarts multiple times and + // registers itself with new incarnations. + // That is okay; in this code branch, the tasks are not connected yet, + // and the barrier has not succeeded yet. + // There is no state that needs to be cleaned up. + task_cluster_state->SetTaskIncarnation(incarnation); + BarrierAsyncLocked("[Init]Wait_for_all_tasks_to_register", + cluster_register_timeout_, task, {}, + ConnectAfterBarrierPasses(task_name, std::move(done))); + return; + } + task_cluster_state->SetTaskIncarnation(incarnation); + task_cluster_state->Connect(); + // TODO(b/369222279): Think about the barrier case - may need periodic + // reporting of stragglers. LogConnectStatusLocked(); - return absl::OkStatus(); + done(absl::OkStatus()); + return; } else if (task_state == CoordinatedTaskState::TASKSTATE_CONNECTED) { // This may happen if the service processes the initial RegisterTask(), // but the agent did not receive the response so the agent retries again. 
@@ -713,12 +787,10 @@ absl::Status CoordinationServiceStandaloneImpl::RegisterTask( // This should be a no-op, but we update the last heartbeat timestamp // to give a longer grace period for the agent to start sending // heartbeats. - task_cluster_state->SetConnected(incarnation); - LOG(INFO) << task_name - << " has connected to coordination service with the same " - << "incarnation again: " << incarnation; + task_cluster_state->Connect(); LogConnectStatusLocked(); - return absl::OkStatus(); + done(absl::OkStatus()); + return; } else { error_message = absl::StrCat(task_name, @@ -739,7 +811,7 @@ absl::Status CoordinationServiceStandaloneImpl::RegisterTask( MakeCoordinationError(absl::AbortedError(error_message), task); SetTaskError(task_name, error); PropagateError(error, task); - return error; + done(error); } void CoordinationServiceStandaloneImpl::WaitForAllTasks( @@ -906,7 +978,9 @@ absl::Status CoordinationServiceStandaloneImpl::RecordHeartbeat( // stopping heartbeats. return MakeCoordinationError(absl::InvalidArgumentError( absl::StrCat("Task with task_name=", task_name, - " must be registered before sending heartbeat messages"))); + " must be registered before sending heartbeat messages. " + "The service might have restarted, please restart / reset " + "and register again."))); } VLOG(10) << "Record heartbeat from task: " << task_name << "at incarnation: " << incarnation << "at " << absl::Now(); @@ -1206,22 +1280,19 @@ bool CoordinationServiceStandaloneImpl::ValidateBarrierArgs( absl::Status error = MakeCoordinationError(absl::InvalidArgumentError( absl::StrCat("A non-participating task (", GetTaskName(task), ") called the barrier: ", barrier_id))); - { - absl::MutexLock l(&state_mu_); - // Check if coordination service has stopped. If so, return an error - // immediately. 
- if (ServiceHasStopped()) { - done(MakeCoordinationError(absl::InternalError( - "Barrier requested after coordination service has shut down."))); - return false; - } - auto pair = barriers_.try_emplace(barrier_id); - auto it = pair.first; - auto* barrier = &it->second; - // Make sure subsequent calls fail and existing waiting tasks receive the - // error. - PassBarrier(barrier_id, error, barrier); + // Check if coordination service has stopped. If so, return an error + // immediately. + if (ServiceHasStopped()) { + done(MakeCoordinationError(absl::InternalError( + "Barrier requested after coordination service has shut down."))); + return false; } + auto pair = barriers_.try_emplace(barrier_id); + auto it = pair.first; + auto* barrier = &it->second; + // Make sure subsequent calls fail and existing waiting tasks receive the + // error. + PassBarrier(barrier_id, error, barrier); done(error); return false; } @@ -1305,6 +1376,15 @@ void CoordinationServiceStandaloneImpl::BarrierAsync( std::string barrier_id, absl::Duration timeout, const CoordinatedTask& task, const std::vector& participating_tasks, StatusCallback done) { + absl::MutexLock l(&state_mu_); + return BarrierAsyncLocked(barrier_id, timeout, task, participating_tasks, + std::move(done)); +}; + +void CoordinationServiceStandaloneImpl::BarrierAsyncLocked( + std::string barrier_id, absl::Duration timeout, const CoordinatedTask& task, + const std::vector& participating_tasks, + StatusCallback done) { VLOG(3) << "Task " << GetTaskName(task) << " invoked BarrierAsync(" << barrier_id << ")."; @@ -1313,7 +1393,6 @@ void CoordinationServiceStandaloneImpl::BarrierAsync( return; // Exit early if args are wrong. } - absl::MutexLock l(&state_mu_); // Check if coordination service has stopped. If so, return an error // immediately. 
if (ServiceHasStopped()) { diff --git a/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service.h b/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service.h index 38b7ae23b734f6..af9eae6a34ec62 100644 --- a/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service.h +++ b/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service.h @@ -122,8 +122,11 @@ class CoordinationServiceInterface { // - InvalidArgument: Unexpected task request. // - Aborted: (1) task is in error state, or (2) task is in connected state // with a different incarnation, indicating that it restarted. + // - DeadlineExceeded: waited too long for straggler tasks to register. virtual absl::Status RegisterTask(const tensorflow::CoordinatedTask& task, uint64_t incarnation) = 0; + virtual void RegisterTaskAsync(const tensorflow::CoordinatedTask& task, + uint64_t incarnation, StatusCallback done) = 0; // Wait for all tasks to be up and running, and register local device // info. 
The callback is invoked when all tasks are up and registered, or some diff --git a/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service_rpc_handler.cc b/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service_rpc_handler.cc index 0294d2bbbaab75..b043795731b5b1 100644 --- a/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service_rpc_handler.cc +++ b/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service_rpc_handler.cc @@ -66,7 +66,7 @@ void CoordinationServiceRpcHandler::RegisterTaskAsync( const uint64_t incarnation = request->incarnation(); const uint64_t leader_incarnation = service_->GetServiceIncarnation(); response->set_leader_incarnation(leader_incarnation); - done(service_->RegisterTask(task, incarnation)); + service_->RegisterTaskAsync(task, incarnation, done); } void CoordinationServiceRpcHandler::HeartbeatAsync( diff --git a/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service_test.cc b/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service_test.cc index 2abb86377ecbee..0c93904d8c0931 100644 --- a/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service_test.cc +++ b/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service_test.cc @@ -236,6 +236,7 @@ class CoordinateTwoTasksTest : public ::testing::Test { void EnableCoordinationService( bool has_service_to_client_connection = true, bool enable_shutdown_barrier = false, + bool enable_register_barrier = false, bool set_worker_job_recoverable = false, bool allow_new_incarnation_to_reconnect = false) { CoordinationServiceConfig config = @@ -256,6 +257,11 @@ class CoordinateTwoTasksTest : public ::testing::Test { config.set_shutdown_barrier_timeout_in_ms(kShutdownBarrierTimeout / absl::Milliseconds(1)); } + if (enable_register_barrier) { + config.set_cluster_register_with_barrier(true); + 
config.set_cluster_register_timeout_in_ms(absl::Seconds(1) / + absl::Milliseconds(1)); + } if (allow_new_incarnation_to_reconnect) { config.set_allow_new_incarnation_to_reconnect(true); } @@ -1729,6 +1735,7 @@ TEST_F(CoordinateTwoTasksTest, TEST_F(CoordinateTwoTasksTest, UnrecoverableTaskPropagatesError) { EnableCoordinationService(/*has_service_to_client_connection=*/true, /*enable_shutdown_barrier=*/false, + /*enable_register_barrier=*/false, /*set_worker_job_recoverable=*/false); TF_EXPECT_OK(coord_service_->RegisterTask(task_0_, incarnation_0_)); @@ -1746,6 +1753,7 @@ TEST_F(CoordinateTwoTasksTest, UnrecoverableTaskPropagatesError) { TEST_F(CoordinateTwoTasksTest, RecoverableTaskWillNotPropagateError) { EnableCoordinationService(/*has_service_to_client_connection=*/true, /*enable_shutdown_barrier=*/false, + /*enable_register_barrier=*/false, /*set_worker_job_recoverable=*/true); TF_EXPECT_OK(coord_service_->RegisterTask(task_0_, incarnation_0_)); @@ -1765,6 +1773,7 @@ TEST_F(CoordinateTwoTasksTest, RecoverableTaskReportErrorResetAndRegisterAgain) { EnableCoordinationService(/*has_service_to_client_connection=*/true, /*enable_shutdown_barrier=*/false, + /*enable_register_barrier=*/false, /*set_worker_job_recoverable=*/true); TF_EXPECT_OK(coord_service_->RegisterTask(task_0_, incarnation_0_)); @@ -1789,6 +1798,7 @@ TEST_F(CoordinateTwoTasksTest, TEST_F(CoordinateTwoTasksTest, UnavailableTaskCanReconnect) { EnableCoordinationService(/*has_service_to_client_connection=*/true, /*enable_shutdown_barrier=*/false, + /*enable_register_barrier=*/false, /*set_worker_job_recoverable=*/false, /*allow_new_incarnation_to_reconnect=*/true); @@ -1971,4 +1981,91 @@ TEST_F(CoordinateTwoTasksTest, LatePollingTaskCanGetError) { HasSubstr("test_error_from_task_0")))); } +TEST_F(CoordinateTwoTasksTest, + RegisterWithBarrier_OldHeartbeat_ServiceNotStopped) { + EnableCoordinationService(/*has_service_to_client_connection=*/false, + /*enable_shutdown_barrier=*/false, + 
/*enable_register_barrier=*/true); + // Service restarted. + // Old task 0 sends an unexpected heartbeat, which should fail. + // Crucially, this should not stop the service, so future API calls should not + // trigger an internal error (which occurs if service has shut down). + ASSERT_THAT(coord_service_->RecordHeartbeat(task_0_, incarnation_0_ - 1), + StatusIs(absl::StatusCode::kInvalidArgument)); + absl::Status task0_status = absl::InternalError("uninitialized_status"); + // Task 0 registers first. + coord_service_->RegisterTaskAsync(task_0_, incarnation_0_, + [](const absl::Status& s) {}); + // Task 0 restarts with a new incarnation, and registers again. + // This should be allowed since all tasks have not joined the cluster yet. + coord_service_->RegisterTaskAsync( + task_0_, /*incarnation=*/incarnation_0_ + 1, + [&](const absl::Status& s) { task0_status = s; }); + // Now all tasks will register in a synchronized fashion due to the barrier. + EXPECT_OK(coord_service_->RegisterTask(task_1_, incarnation_1_)); + EXPECT_OK(task0_status); +} + +TEST_F(CoordinateTwoTasksTest, + RegisterWithBarrier_RestartBeforeBarrier_Succeeds) { + EnableCoordinationService(/*has_service_to_client_connection=*/false, + /*enable_shutdown_barrier=*/false, + /*enable_register_barrier=*/true); + absl::Status task0_status = absl::InternalError("uninitialized_status"); + // Task 0 registers first. + coord_service_->RegisterTaskAsync(task_0_, incarnation_0_, + [](const absl::Status& s) {}); + // Task 0 restarts with a new incarnation, and registers again. + // This should be allowed since all tasks have not joined the cluster yet. + coord_service_->RegisterTaskAsync( + task_0_, /*incarnation=*/incarnation_0_ + 1, + [&](const absl::Status& s) { task0_status = s; }); + // Now all tasks will register in a synchronized fashion due to the barrier. + ASSERT_OK(coord_service_->RegisterTask(task_1_, incarnation_1_)); + ASSERT_OK(task0_status); + // Task 0 joins again with the same incarnation. 
+ // This is okay, it didn't restart, probably sent RPC twice due to network + // retries. + EXPECT_OK(coord_service_->RegisterTask(task_0_, incarnation_0_ + 1)); +} + +TEST_F(CoordinateTwoTasksTest, RegisterWithBarrier_RestartAfterBarrier_Fails) { + EnableCoordinationService(/*has_service_to_client_connection=*/false, + /*enable_shutdown_barrier=*/false, + /*enable_register_barrier=*/true); + absl::Status task0_status = absl::InternalError("uninitialized_status"); + // Task 0 registers first. + coord_service_->RegisterTaskAsync( + task_0_, incarnation_0_, + [&](const absl::Status& s) { task0_status = s; }); + // Now all tasks will register in a synchronized fashion due to the barrier. + ASSERT_OK(coord_service_->RegisterTask(task_1_, incarnation_1_)); + ASSERT_OK(task0_status); + + // Task 0 restarts again with a new incarnation. + // This should fail since this happens after the initial register barrier + // (i.e. all tasks already acked once). + ASSERT_THAT(coord_service_->RegisterTask(task_0_, incarnation_0_ + 2), + StatusIs(absl::StatusCode::kAborted)); + // Service should have stopped due to the previous registration failure. + // Check for internal error code. + absl::Notification n; + absl::Status barrier_status; + coord_service_->BarrierAsync("barrier_id", absl::Seconds(10), task_0_, {}, + [&](const absl::Status& s) { + barrier_status = s; // Store before Notify() so the waiter observes it. + n.Notify(); + }); + n.WaitForNotification(); + EXPECT_THAT(barrier_status, StatusIs(absl::StatusCode::kInternal)); +} + +TEST_F(CoordinateTwoTasksTest, RegisterWithBarrier_Timeout) { + EnableCoordinationService(/*has_service_to_client_connection=*/false, + /*enable_shutdown_barrier=*/false, + /*enable_register_barrier=*/true); + // Task 0 joins without task 1. Times out eventually. 
+ EXPECT_THAT(coord_service_->RegisterTask(task_0_, incarnation_0_), + StatusIs(absl::StatusCode::kDeadlineExceeded)); +} } // namespace tsl diff --git a/third_party/xla/xla/tsl/protobuf/coordination_config.proto b/third_party/xla/xla/tsl/protobuf/coordination_config.proto index 23aff65eb67985..645c992c64df42 100644 --- a/third_party/xla/xla/tsl/protobuf/coordination_config.proto +++ b/third_party/xla/xla/tsl/protobuf/coordination_config.proto @@ -29,6 +29,10 @@ message CoordinationServiceConfig { // Maximum wait time for all members in the cluster to be registered. int64 cluster_register_timeout_in_ms = 4; + // Denotes if we should synchronize the agents' register attempts by blocking + // on a barrier. This is useful for synchronized restarts. + bool cluster_register_with_barrier = 14; + // Heartbeat timeout, if a task does not record heartbeat in this time // window, it will be considered disconnected. // Note: This is also used as a grace period to accept any heartbeats after