Skip to content

Commit

Permalink
Add cluster register barrier feature - synchronized connects.
Browse files Browse the repository at this point in the history
This has a few benefits:
* Simplify state changes during init to be all-or-nothing op: Nobody will send heartbeat until everybody in the cluster has acked. That's less error flows to reason about during init.

* This helps with consistent in-sync restarts. There's no weird staggering / cascading restart flows regardless of what the scheduler does.

Concretely:
1. Task invokes connect.
2. Block until all tasks connect. In the meantime, tasks may restart multiple times without bringing down the service because initialization has not completed and there's nothing stateful that is corrupted.
3.  All tasks connect together.
4. Start sending heartbeats

---

Previously, we start sending heartbeats immediately after (1), which means that tasks restarting before (2) will result in service crashes, causing additional unnecessary restarts + scheduling overhead.

PiperOrigin-RevId: 689142223
  • Loading branch information
tensorflower-gardener committed Oct 23, 2024
1 parent ed9d15b commit 25e1991
Show file tree
Hide file tree
Showing 5 changed files with 217 additions and 34 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,8 @@ class CoordinationServiceStandaloneImpl : public CoordinationServiceInterface {

absl::Status RegisterTask(const CoordinatedTask& task,
uint64_t incarnation) override;
void RegisterTaskAsync(const CoordinatedTask& task, uint64_t incarnation,
StatusCallback done) override;
void WaitForAllTasks(const CoordinatedTask& task, const DeviceInfo& devices,
StatusCallback done) override;
void ShutdownTaskAsync(const CoordinatedTask& task,
Expand Down Expand Up @@ -155,6 +157,13 @@ class CoordinationServiceStandaloneImpl : public CoordinationServiceInterface {
const DeviceInfo& ListClusterDevices() override
ABSL_EXCLUSIVE_LOCKS_REQUIRED(state_mu_);
uint64_t GetServiceIncarnation() override;
void BarrierAsyncLocked(
std::string barrier_id, absl::Duration timeout,
const CoordinatedTask& task,
const std::vector<CoordinatedTask>& participating_tasks,
StatusCallback done) ABSL_EXCLUSIVE_LOCKS_REQUIRED(state_mu_);
StatusCallback ConnectAfterBarrierPasses(absl::string_view task_name,
StatusCallback done);
// Checks if any task has stopped sending heartbeats.
void CheckHeartbeatTimeout();
// Checks if any barrier has timed out.
Expand Down Expand Up @@ -199,7 +208,7 @@ class CoordinationServiceStandaloneImpl : public CoordinationServiceInterface {
std::string_view barrier_id, absl::Duration timeout,
const CoordinatedTask& task,
const std::vector<CoordinatedTask>& participating_tasks,
StatusCallback done);
StatusCallback done) ABSL_EXCLUSIVE_LOCKS_REQUIRED(state_mu_);
// Initializes a new barrier. Returns false if the barrier should fail
// immediately.
bool InitializeBarrier(
Expand All @@ -210,6 +219,8 @@ class CoordinationServiceStandaloneImpl : public CoordinationServiceInterface {
void PassBarrier(std::string_view barrier_id, const absl::Status& result,
BarrierState* barrier)
ABSL_EXCLUSIVE_LOCKS_REQUIRED(state_mu_);
// Post-barrier hook to connect all tasks.
void ConnectAllTasks() ABSL_EXCLUSIVE_LOCKS_REQUIRED(state_mu_);
// Post-barrier hook to aggregate device info.
void AggregateClusterDevices() ABSL_EXCLUSIVE_LOCKS_REQUIRED(state_mu_);
// Post-shutdown barrier hook to disconnect tasks that acked and propagate
Expand Down Expand Up @@ -277,6 +288,15 @@ class CoordinationServiceStandaloneImpl : public CoordinationServiceInterface {
CoordinatedTaskState GetState() { return state_; }
absl::Status GetStatus() { return status_; }
uint64_t GetTaskIncarnation() { return task_incarnation_; }
void SetTaskIncarnation(uint64_t task_incarnation) {
task_incarnation_ = task_incarnation;
}
void Connect() {
SetConnected(task_incarnation_);
LOG(INFO) << task_name_
<< " has connected to coordination service. Incarnation: "
<< task_incarnation_;
}
void SetConnected(uint64_t task_incarnation);
void Disconnect(uint64_t grace_period_duration_us);
absl::Status RecordHeartbeat(uint64_t task_incarnation);
Expand Down Expand Up @@ -321,6 +341,8 @@ class CoordinationServiceStandaloneImpl : public CoordinationServiceInterface {
Env& env_;
const uint64_t service_incarnation_ = random::New64();
const uint64_t heartbeat_timeout_ms_;
bool cluster_register_with_barrier_ = false;
const absl::Duration cluster_register_timeout_;
const absl::Duration shutdown_barrier_timeout_;
// If a task restarts with a new incarnation, we may allow it to reconnect
// silently if configured. This is useful when we know that a task can
Expand Down Expand Up @@ -473,6 +495,9 @@ CoordinationServiceStandaloneImpl::CoordinationServiceStandaloneImpl(
? config.heartbeat_timeout_in_ms()
: kDefaultHeartbeatTimeoutMs;
}()),
cluster_register_with_barrier_(config.cluster_register_with_barrier()),
cluster_register_timeout_(
absl::Milliseconds(config.cluster_register_timeout_in_ms())),
shutdown_barrier_timeout_(
absl::Milliseconds(config.shutdown_barrier_timeout_in_ms())),
allow_new_incarnation_to_reconnect_(
Expand Down Expand Up @@ -599,6 +624,10 @@ void CoordinationServiceStandaloneImpl::StartCheckStaleness() {
}

void CoordinationServiceStandaloneImpl::Stop() {
// Prevent recursion.
if (shutting_down_) {
return;
}
{
absl::MutexLock l(&kv_mu_);
for (const auto& [key, get_kv_callbacks] : get_cb_) {
Expand Down Expand Up @@ -664,25 +693,56 @@ void CoordinationServiceStandaloneImpl::LogConnectStatusLocked() const {

absl::Status CoordinationServiceStandaloneImpl::RegisterTask(
const CoordinatedTask& task, uint64_t incarnation) {
absl::Notification done;
absl::Status status;
RegisterTaskAsync(task, incarnation, [&](absl::Status s) {
status = s;
done.Notify();
});
done.WaitForNotification();
return status;
}

StatusCallback CoordinationServiceStandaloneImpl::ConnectAfterBarrierPasses(
absl::string_view task_name, StatusCallback done) {
return [this, task = std::string(task_name),
done = std::move(done)](absl::Status s) mutable {
state_mu_.AssertHeld();
if (s.ok()) {
// Connect task to service.
cluster_state_[task]->Connect();
done(s);
} else {
done(s);
// Initialization failed, stop service now.
Stop();
}
};
}

void CoordinationServiceStandaloneImpl::RegisterTaskAsync(
const CoordinatedTask& task, uint64_t incarnation, StatusCallback done) {
const std::string task_name = GetTaskName(task);

std::string error_message;
absl::MutexLock l(&state_mu_);
if (ServiceHasStopped()) {
return MakeCoordinationError(absl::InternalError(absl::StrCat(
done(MakeCoordinationError(absl::InternalError(absl::StrCat(
"Coordination service has stopped. RegisterTask() from task: ",
task_name,
" failed. This usually implies an earlier error that caused "
"coordination service to shut down before the workers disconnect "
"gracefully. Check the task leader's logs for an earlier error or "
"scheduler events (e.g. preemption, eviction) to debug the root "
"cause.")));
"cause."))));
return;
}
if (!cluster_state_.contains(task_name)) {
// Note: return early here as unexpected task register errors should not
// be propagated to other tasks.
return MakeCoordinationError(absl::InvalidArgumentError(
absl::StrCat("Unexpected task registered with task_name=", task_name)));
done(MakeCoordinationError(absl::InvalidArgumentError(absl::StrCat(
"Unexpected task registered with task_name=", task_name))));
return;
}

auto* task_cluster_state = cluster_state_[task_name].get();
Expand All @@ -700,25 +760,37 @@ absl::Status CoordinationServiceStandaloneImpl::RegisterTask(
// an unavailable error state, but has now restarted (possibly with
// a new incarnation). This is only allowed if configured with
// `allow_new_incarnation_to_reconnect`.
task_cluster_state->SetConnected(incarnation);
LOG(INFO) << task_name
<< " has connected to coordination service. Incarnation: "
<< incarnation;
if (cluster_register_with_barrier_) {
// Impose barrier so that all tasks can register together.
// Note: it is possible that the same task restarts multiple times and
// registers itself with new incarnations.
// That is okay; in this code branch, the tasks are not connected yet,
// and the barrier has not succeeded yet.
// There is no state that needs to be cleaned up.
task_cluster_state->SetTaskIncarnation(incarnation);
BarrierAsyncLocked("[Init]Wait_for_all_tasks_to_register",
cluster_register_timeout_, task, {},
ConnectAfterBarrierPasses(task_name, std::move(done)));
return;
}
task_cluster_state->SetTaskIncarnation(incarnation);
task_cluster_state->Connect();
// TODO(b/369222279): Think about the barrier case - may need periodic
// reporting of stragglers.
LogConnectStatusLocked();
return absl::OkStatus();
done(absl::OkStatus());
return;
} else if (task_state == CoordinatedTaskState::TASKSTATE_CONNECTED) {
// This may happen if the service processes the initial RegisterTask(),
// but the agent did not receive the response so the agent retries again.
if (task_cluster_state->GetTaskIncarnation() == incarnation) {
// This should be a no-op, but we update the last heartbeat timestamp
// to give a longer grace period for the agent to start sending
// heartbeats.
task_cluster_state->SetConnected(incarnation);
LOG(INFO) << task_name
<< " has connected to coordination service with the same "
<< "incarnation again: " << incarnation;
task_cluster_state->Connect();
LogConnectStatusLocked();
return absl::OkStatus();
done(absl::OkStatus());
return;
} else {
error_message =
absl::StrCat(task_name,
Expand All @@ -739,7 +811,7 @@ absl::Status CoordinationServiceStandaloneImpl::RegisterTask(
MakeCoordinationError(absl::AbortedError(error_message), task);
SetTaskError(task_name, error);
PropagateError(error, task);
return error;
done(error);
}

void CoordinationServiceStandaloneImpl::WaitForAllTasks(
Expand Down Expand Up @@ -906,7 +978,9 @@ absl::Status CoordinationServiceStandaloneImpl::RecordHeartbeat(
// stopping heartbeats.
return MakeCoordinationError(absl::InvalidArgumentError(
absl::StrCat("Task with task_name=", task_name,
" must be registered before sending heartbeat messages")));
" must be registered before sending heartbeat messages. "
"The service might have restarted, please restart / reset "
"and register again.")));
}
VLOG(10) << "Record heartbeat from task: " << task_name
<< "at incarnation: " << incarnation << "at " << absl::Now();
Expand Down Expand Up @@ -1206,22 +1280,19 @@ bool CoordinationServiceStandaloneImpl::ValidateBarrierArgs(
absl::Status error = MakeCoordinationError(absl::InvalidArgumentError(
absl::StrCat("A non-participating task (", GetTaskName(task),
") called the barrier: ", barrier_id)));
{
absl::MutexLock l(&state_mu_);
// Check if coordination service has stopped. If so, return an error
// immediately.
if (ServiceHasStopped()) {
done(MakeCoordinationError(absl::InternalError(
"Barrier requested after coordination service has shut down.")));
return false;
}
auto pair = barriers_.try_emplace(barrier_id);
auto it = pair.first;
auto* barrier = &it->second;
// Make sure subsequent calls fail and existing waiting tasks receive the
// error.
PassBarrier(barrier_id, error, barrier);
// Check if coordination service has stopped. If so, return an error
// immediately.
if (ServiceHasStopped()) {
done(MakeCoordinationError(absl::InternalError(
"Barrier requested after coordination service has shut down.")));
return false;
}
auto pair = barriers_.try_emplace(barrier_id);
auto it = pair.first;
auto* barrier = &it->second;
// Make sure subsequent calls fail and existing waiting tasks receive the
// error.
PassBarrier(barrier_id, error, barrier);
done(error);
return false;
}
Expand Down Expand Up @@ -1305,6 +1376,15 @@ void CoordinationServiceStandaloneImpl::BarrierAsync(
std::string barrier_id, absl::Duration timeout, const CoordinatedTask& task,
const std::vector<CoordinatedTask>& participating_tasks,
StatusCallback done) {
absl::MutexLock l(&state_mu_);
return BarrierAsyncLocked(barrier_id, timeout, task, participating_tasks,
std::move(done));
};

void CoordinationServiceStandaloneImpl::BarrierAsyncLocked(
std::string barrier_id, absl::Duration timeout, const CoordinatedTask& task,
const std::vector<CoordinatedTask>& participating_tasks,
StatusCallback done) {
VLOG(3) << "Task " << GetTaskName(task) << " invoked BarrierAsync("
<< barrier_id << ").";

Expand All @@ -1313,7 +1393,6 @@ void CoordinationServiceStandaloneImpl::BarrierAsync(
return; // Exit early if args are wrong.
}

absl::MutexLock l(&state_mu_);
// Check if coordination service has stopped. If so, return an error
// immediately.
if (ServiceHasStopped()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -122,8 +122,11 @@ class CoordinationServiceInterface {
// - InvalidArgument: Unexpected task request.
// - Aborted: (1) task is in error state, or (2) task is in connected state
// with a different incarnation, indicating that it restarted.
// - DeadlineExceeded: waited too long for straggler tasks to register.
virtual absl::Status RegisterTask(const tensorflow::CoordinatedTask& task,
uint64_t incarnation) = 0;
virtual void RegisterTaskAsync(const tensorflow::CoordinatedTask& task,
uint64_t incarnation, StatusCallback done) = 0;

// Wait for all tasks to be up and running, and register local device
// info. The callback is invoked when all tasks are up and registered, or some
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ void CoordinationServiceRpcHandler::RegisterTaskAsync(
const uint64_t incarnation = request->incarnation();
const uint64_t leader_incarnation = service_->GetServiceIncarnation();
response->set_leader_incarnation(leader_incarnation);
done(service_->RegisterTask(task, incarnation));
service_->RegisterTaskAsync(task, incarnation, done);
}

void CoordinationServiceRpcHandler::HeartbeatAsync(
Expand Down
Loading

0 comments on commit 25e1991

Please sign in to comment.