Skip to content

Commit

Permalink
Create an async_while kernel that may dispatch an iteration before
Browse files Browse the repository at this point in the history
the previous iteration is complete.

PiperOrigin-RevId: 550759127
  • Loading branch information
deqiangc authored and tensorflower-gardener committed Jul 25, 2023
1 parent f15c449 commit 996865f
Show file tree
Hide file tree
Showing 2 changed files with 574 additions and 1 deletion.
294 changes: 293 additions & 1 deletion tensorflow/core/tfrt/mlrt/kernel/kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,297 @@ namespace tensorflow {
namespace tf_mlrt {
namespace {

// AsyncWhileOp dispatches the body function repeatedly until the body function
// returns a predicate value of false. Each invocation of the body function
// corresponds to an iteration in a while loop. The body function is expected to
// have the following input signature (predicate_promise,
// mutable_tensor0_future, mutable_tensor0_promise, mutable_tensor1_future,
// mutable_tensor1_promise, ...., immutable_tensors). AsyncWhileOp dispatches
// the next iteration as soon as the previous iteration has set the
// predicate_promise. Hence, in the case that the body function sets
// predicate_promise earlier than setting any other promises, multiple
// iterations can run in parallel via this op.
class AsyncWhileOp : mlrt::KernelFrame {
 public:
  using KernelFrame::KernelFrame;

  static constexpr char kName[] = "tf_mlrt.async_while";

  // Returns the loop-body function, looked up in the loaded executable by the
  // function index stored as attribute 0 of this kernel.
  mlrt::bc::Function body_function() const {
    uint32_t func_idx = attributes().GetAs<uint32_t>(0);
    return execution_context()
        .loaded_executable()
        .executable()
        .functions()[func_idx];
  }

  // Arguments that remain unchanged between iterations are called
  // immutable (invariants). Immutables are all at the bottom of the argument
  // list. immutable_size (stored as attribute 1) is the number of immutables.
  uint32_t immutable_size() const { return attributes().GetAs<uint32_t>(1); }

  void Invoke();

 private:
  // This utility function is used when an iteration has set its
  // predicate_promise. If the predicate is true, it dispatches the next
  // iteration. If the predicate is false, it sets up the AsyncWhileOp's
  // return futures via final_promises.
  static void OnPredicateReady(
      tensorflow::tfrt_stub::FallbackTensor predicate,
      std::vector<mlrt::AsyncHandle> async_handles,
      std::vector<mlrt::Future> mutable_tensor_futures,
      std::vector<tensorflow::tfrt_stub::FallbackTensor> immutable_tensors,
      std::vector<mlrt::Promise> final_promises, mlrt::bc::Function body_fn,
      mlrt::ExecutionContext& execution_context, uint32_t counter);

  // A utility function to populate the results in final_promises.
  static void PopulateFinalPromise(
      std::vector<mlrt::Promise>& final_promises,
      const std::vector<mlrt::Future>& mutable_tensor_futures,
      const std::vector<tensorflow::tfrt_stub::FallbackTensor>&
          immutable_tensors);
};

void AsyncWhileOp::OnPredicateReady(
    tensorflow::tfrt_stub::FallbackTensor predicate,
    std::vector<mlrt::AsyncHandle> async_handles,
    std::vector<mlrt::Future> mutable_tensor_futures,
    std::vector<tensorflow::tfrt_stub::FallbackTensor> immutable_tensors,
    std::vector<mlrt::Promise> final_promises, mlrt::bc::Function body_fn,
    mlrt::ExecutionContext& execution_context, uint32_t counter) {
  // final_promises[0] contains the final predicate and serves something
  // similar as async_handle that the caller can wait on and know the program
  // is complete.
  DCHECK_EQ(final_promises.size(),
            mutable_tensor_futures.size() + immutable_tensors.size() + 1);

  // Body argument layout: [predicate_promise; arg0_future, arg0_promise,
  // arg1_future, arg1_promise, ..., immutable_args]
  const uint32_t body_argument_size =
      1 + 2 * mutable_tensor_futures.size() + immutable_tensors.size();
  DCHECK_EQ(body_fn.input_regs().size(), body_argument_size);

  tsl::profiler::TraceMe trace_me([&]() {
    return tsl::profiler::TraceMeEncode(
        "tf_mlrt.AsyncWhileOp.OnPredicateReady",
        {{"counter", counter}, {"name", body_fn.name().Get()}});
  });

  bool predicate_value = predicate.tensor().scalar<bool>()();
  if (!predicate_value) {
    // No more iterations.
    if (async_handles.empty()) {
      // Initial predicate is false: no iteration was ever dispatched, so the
      // bootstrap inputs are forwarded as the final results directly.
      PopulateFinalPromise(final_promises, mutable_tensor_futures,
                           immutable_tensors);
    } else {
      // Iterations have ended. Wait for all in-flight iterations (their
      // async handles) to complete before publishing the results.
      mlrt::Future await_all = mlrt::AwaitAll(absl::MakeSpan(async_handles));
      std::move(await_all).Then(
          [final_promises = std::move(final_promises),
           variant_tensor_futures = std::move(mutable_tensor_futures),
           async_handles = std::move(async_handles),
           immutable_tensors](absl::Status status) mutable {
            if (status.ok()) {
              PopulateFinalPromise(final_promises, variant_tensor_futures,
                                   immutable_tensors);
              return;
            } else {
              // Any iteration failed: propagate the error to every
              // caller-visible result promise.
              for (auto& final_promise : final_promises) {
                std::move(final_promise).SetError(status);
              }
            }
          });
    }
    return;
  }
  // Predicate is true: proceed to schedule the next iteration n+1.
  // Creates arguments for dispatching the next iteration.
  std::vector<mlrt::Value> body_args;
  body_args.resize(body_argument_size);

  // Set predicate_promise (body argument 0); its future drives the decision
  // to dispatch iteration n+2.
  auto arg_iter = body_args.begin();
  auto predicate_promise =
      mlrt::Promise::Allocate<tensorflow::tfrt_stub::FallbackTensor>();
  auto predicate_future = predicate_promise.GetFuture();
  arg_iter->Set(std::move(predicate_promise));
  ++arg_iter;

  // Current iteration n receives mutable tensor values in futures from
  // iteration n-1 and creates promises to return those mutable tensors after
  // updating them in the current iteration.
  std::vector<mlrt::Future> next_futures;
  next_futures.reserve(mutable_tensor_futures.size());

  for (auto& mutable_tensor : mutable_tensor_futures) {
    // Future from the previous iteration as input to the current iteration.
    arg_iter->Set(std::move(mutable_tensor));
    ++arg_iter;

    // Promise to return values from the current iteration.
    auto next_promise =
        mlrt::Promise::Allocate<tensorflow::tfrt_stub::FallbackTensor>();
    next_futures.push_back(next_promise.GetFuture());
    arg_iter->Set(std::move(next_promise));
    ++arg_iter;
  }

  // Tensors that remain unchanged across iterations are copied over because
  // iterations execute asynchronously and may overlap.
  for (auto& immutable_tensor : immutable_tensors) {
    arg_iter->Set(immutable_tensor);
    arg_iter++;
  }

  // Launch this iteration on its own execution context; the exit handler
  // finishes the async handle's promise with the iteration's final status.
  auto [promise, handle] = mlrt::AsyncHandle::Allocate(execution_context);
  auto& thread_execution_context = handle.execution_context();
  thread_execution_context.set_exit_handler(
      [&execution_context = thread_execution_context,
       promise = std::move(promise)]() mutable {
        std::move(promise).Finish(execution_context.status());
      });

  thread_execution_context.CallByMove(body_fn, absl::MakeSpan(body_args),
                                      absl::Span<mlrt::Value>());

  thread_execution_context.work_queue()->AddTask(
      [&execution_context = thread_execution_context]() {
        mlrt::Execute(execution_context);
      });

  // Save the handle so the final await-all can wait for this iteration too.
  async_handles.push_back(std::move(handle));

  std::move(predicate_future)
      .Then([futures = std::move(next_futures),
             immutable_tensors = std::move(immutable_tensors),
             final_promises = std::move(final_promises),
             body_args = std::move(body_args),
             async_handles = std::move(async_handles), body_fn, counter,
             &execution_context = thread_execution_context](
                absl::StatusOr<tensorflow::tfrt_stub::FallbackTensor>
                    predicate_result) mutable {
        if (!predicate_result.ok()) {
          // The body failed before producing a predicate: wait for all
          // in-flight iterations, then fail every result promise and the
          // iteration's execution context.
          auto status = predicate_result.status();
          mlrt::Future await_all =
              mlrt::AwaitAll(absl::MakeSpan(async_handles));
          std::move(await_all).Then([final_promises = std::move(final_promises),
                                     async_handles = std::move(async_handles),
                                     status]() mutable {
            for (auto& final_promise : final_promises) {
              std::move(final_promise).SetError(status);
            }
          });
          execution_context.Fail(status);
          return;
        }

        // Keep body_args alive for thread execution (they are captured above).
        // Recurse to either dispatch iteration n+1 or finalize the results.
        OnPredicateReady(*predicate_result, std::move(async_handles),
                         std::move(futures), immutable_tensors,
                         std::move(final_promises), body_fn, execution_context,
                         ++counter);
      });
}

void AsyncWhileOp::PopulateFinalPromise(
    std::vector<mlrt::Promise>& final_promises,
    const std::vector<mlrt::Future>& mutable_tensor_futures,
    const std::vector<tensorflow::tfrt_stub::FallbackTensor>&
        immutable_tensors) {
  // The final predicate must be delivered as a tensor rather than a plain
  // bool so that await_all can be used uniformly on all results.
  size_t result_index = 0;

  // Slot 0: the terminating (false) predicate.
  tensorflow::Tensor final_predicate_tensor(false);
  std::move(final_promises[result_index])
      .Set<tensorflow::tfrt_stub::FallbackTensor>(
          tensorflow::tfrt_stub::FallbackTensor(
              std::move(final_predicate_tensor)));
  ++result_index;

  // Next: the loop-carried (mutable) tensors, taken from their futures, which
  // must already be fulfilled by the last iteration.
  for (const auto& mutable_tensor_future : mutable_tensor_futures) {
    DCHECK(mutable_tensor_future.IsReady());
    std::move(final_promises[result_index])
        .Set<tensorflow::tfrt_stub::FallbackTensor>(
            std::move(mutable_tensor_future
                          .Get<tensorflow::tfrt_stub::FallbackTensor>()));
    ++result_index;
  }

  // Last: the invariant tensors, forwarded unchanged.
  for (const auto& immutable_tensor : immutable_tensors) {
    std::move(final_promises[result_index])
        .Set<tensorflow::tfrt_stub::FallbackTensor>(immutable_tensor);
    ++result_index;
  }
}

void AsyncWhileOp::Invoke() {
mlrt::bc::Function body_fn = body_function();

// Argument: [final_predicate, %variant0, %variant1, ..., %invariant0,...]
//
// Results: [final_predicate, %variant0, %variant1, ..., %invariant0,...]
//
DCHECK_EQ(arguments().size(), results().size());

// [predicate_promise; arg0_future, arg0_promise, arg1_future, arg1_promise,
// ..., invariant_args]
// minus 1 b/c predicate is not a tensor
const uint32_t immutable_tensor_size = immutable_size();
const uint32_t mutable_tensor_size =
arguments().size() - immutable_tensor_size - 1;

const uint32_t body_argument_size =
1 + (2 * mutable_tensor_size) + immutable_tensor_size;
DCHECK_EQ(body_fn.input_regs().size(), body_argument_size);
DCHECK_EQ(body_fn.output_regs().size(), 0);

tsl::profiler::TraceMe trace_me([&]() {
return tsl::profiler::TraceMeEncode("tf_mlrt.async_while",
{{"name", body_fn.name().Get()}});
});

// Save the future of final results. The last iteration will set the promises.
std::vector<mlrt::Promise> final_promises;
final_promises.reserve(arguments().size());
for (int i = 0; i < arguments().size(); ++i) {
final_promises.push_back(
mlrt::Promise::Allocate<tensorflow::tfrt_stub::FallbackTensor>());
results()[i] = final_promises.back().GetFuture();
}

// Populate input arguments into a list of dummy futures to bootstrap the
// first iteration.
std::vector<mlrt::Future> mutable_tensor_futures;
mutable_tensor_futures.reserve(mutable_tensor_size);

// Plus 1 because the very first argument is a boolean predicate .
auto arg_iter = arguments().begin() + 1;
for (int i = 0; i < mutable_tensor_size; ++i) {
auto tensor_promise =
mlrt::Promise::Allocate<tensorflow::tfrt_stub::FallbackTensor>();
mutable_tensor_futures.push_back(tensor_promise.GetFuture());
std::move(tensor_promise)
.Set<tensorflow::tfrt_stub::FallbackTensor>(
arg_iter->Get<tensorflow::tfrt_stub::FallbackTensor>());
arg_iter++;
}

std::vector<tensorflow::tfrt_stub::FallbackTensor> immutable_tensors;
immutable_tensors.reserve(immutable_tensor_size);
for (int i = 0; i < immutable_tensor_size; ++i) {
immutable_tensors.push_back(
arg_iter->Get<tensorflow::tfrt_stub::FallbackTensor>());
arg_iter++;
}
OnPredicateReady(arguments()[0].Get<tensorflow::tfrt_stub::FallbackTensor>(),
/*async_handles=*/{}, std::move(mutable_tensor_futures),
immutable_tensors, std::move(final_promises), body_fn,
execution_context(),
/*counter=*/0);
}

struct MapFnOp : mlrt::KernelFrame {
using KernelFrame::KernelFrame;

Expand Down Expand Up @@ -142,7 +433,7 @@ void MapFnOp::Invoke() {
body_arg_last_uses.begin() + 2 * num_tensor_list_or_flow_in() + 2,
true);

// Copy the invairant arguments (after max_iteration +
// Copy the invariant arguments (after max_iteration +
// tensor_list_or_flow_ins)
auto arg_iter = body_args.begin() + 2 * num_tensor_list_or_flow_in() + 2;
for (int j = num_tensor_list_or_flow_in() + 1; j < arguments().size();
Expand Down Expand Up @@ -684,6 +975,7 @@ void RegisterTfMlrtKernels(mlrt::KernelRegistry& registry) {
registry.Register<ExecuteOp>();
registry.Register<ExecuteOp>("tfrt_fallback_sync.executeop");
registry.Register<AsyncExecuteOp>();
registry.Register<AsyncWhileOp>();
registry.Register<ExecuteOpDevice>();
registry.Register<AsyncExecuteOpDevice>();
registry.Register("tf_mlrt.set_resource", &SetResource);
Expand Down
Loading

0 comments on commit 996865f

Please sign in to comment.