From 996865f038df52acefa4dcba97047fe3181e1280 Mon Sep 17 00:00:00 2001
From: Deqiang Chen
Date: Mon, 24 Jul 2023 20:56:51 -0700
Subject: [PATCH] Create an async_while kernel that may dispatch an iteration
 before the previous iteration is complete.

PiperOrigin-RevId: 550759127
---
 tensorflow/core/tfrt/mlrt/kernel/kernel.cc    | 294 +++++++++++++++++-
 .../core/tfrt/mlrt/kernel/kernel_test.cc      | 281 +++++++++++++++++
 2 files changed, 574 insertions(+), 1 deletion(-)

diff --git a/tensorflow/core/tfrt/mlrt/kernel/kernel.cc b/tensorflow/core/tfrt/mlrt/kernel/kernel.cc
index 0c2e0bf12a67c8..b47be8d5aa6520 100644
--- a/tensorflow/core/tfrt/mlrt/kernel/kernel.cc
+++ b/tensorflow/core/tfrt/mlrt/kernel/kernel.cc
@@ -50,6 +50,297 @@ namespace tensorflow {
 namespace tf_mlrt {
 namespace {
 
+// AsyncWhileOp dispatches the body function repeatedly until the body
+// function returns a predicate value of false. Each invocation of the body
+// function corresponds to an iteration in a while loop. The body function is
+// expected to have the following input signature (predicate_promise,
+// mutable_tensor0_future, mutable_tensor0_promise, mutable_tensor1_future,
+// mutable_tensor1_promise, ..., immutable_tensors). AsyncWhileOp dispatches
+// the next iteration as soon as the previous iteration has set its
+// predicate_promise. Hence, if the body function sets predicate_promise
+// before setting any of its other promises, multiple iterations can run in
+// parallel via this op.
+class AsyncWhileOp : mlrt::KernelFrame {
+ public:
+  using KernelFrame::KernelFrame;
+
+  static constexpr char kName[] = "tf_mlrt.async_while";
+
+  mlrt::bc::Function body_function() const {
+    uint32_t func_idx = attributes().GetAs<uint32_t>(0);
+    return execution_context()
+        .loaded_executable()
+        .executable()
+        .functions()[func_idx];
+  }
+
+  // Arguments that remain unchanged between iterations are called immutable
+  // (invariants). Immutables are all at the bottom of the argument list.
+  // immutable_size() returns the number of immutables.
+  uint32_t immutable_size() const { return attributes().GetAs<uint32_t>(1); }
+
+  void Invoke();
+
+ private:
+  // This utility function is used when an iteration has set its
+  // predicate_promise. If the predicate is true, it dispatches the next
+  // iteration. If the predicate is false, it sets up the AsyncWhileOp's
+  // return futures via final_promises.
+  static void OnPredicateReady(
+      tensorflow::tfrt_stub::FallbackTensor predicate,
+      std::vector<mlrt::AsyncHandle> async_handles,
+      std::vector<mlrt::Future> mutable_tensor_futures,
+      std::vector<tensorflow::tfrt_stub::FallbackTensor> immutable_tensors,
+      std::vector<mlrt::Promise> final_promises, mlrt::bc::Function body_fn,
+      mlrt::ExecutionContext& execution_context, uint32_t counter);
+
+  // A utility function to populate the results in final_promises.
+  static void PopulateFinalPromise(
+      std::vector<mlrt::Promise>& final_promises,
+      const std::vector<mlrt::Future>& mutable_tensor_futures,
+      const std::vector<tensorflow::tfrt_stub::FallbackTensor>&
+          immutable_tensors);
+};
+
+void AsyncWhileOp::OnPredicateReady(
+    tensorflow::tfrt_stub::FallbackTensor predicate,
+    std::vector<mlrt::AsyncHandle> async_handles,
+    std::vector<mlrt::Future> mutable_tensor_futures,
+    std::vector<tensorflow::tfrt_stub::FallbackTensor> immutable_tensors,
+    std::vector<mlrt::Promise> final_promises, mlrt::bc::Function body_fn,
+    mlrt::ExecutionContext& execution_context, uint32_t counter) {
+  // final_promises[0] contains the final predicate and serves a purpose
+  // similar to an async_handle: the caller can wait on it to know that the
+  // program is complete.
+  DCHECK_EQ(final_promises.size(),
+            mutable_tensor_futures.size() + immutable_tensors.size() + 1);
+
+  // [predicate_promise; arg0_future, arg0_promise, arg1_future, arg1_promise,
+  // ..., immutable_args]
+  const uint32_t body_argument_size =
+      1 + 2 * mutable_tensor_futures.size() + immutable_tensors.size();
+  DCHECK_EQ(body_fn.input_regs().size(), body_argument_size);
+
+  tsl::profiler::TraceMe trace_me([&]() {
+    return tsl::profiler::TraceMeEncode(
+        "tf_mlrt.AsyncWhileOp.OnPredicateReady",
+        {{"counter", counter}, {"name", body_fn.name().Get()}});
+  });
+
+  bool predicate_value = predicate.tensor().scalar<bool>()();
+  if (!predicate_value) {
+    // No more iterations.
+    if (async_handles.empty()) {
+      // The initial predicate is false.
+      PopulateFinalPromise(final_promises, mutable_tensor_futures,
+                           immutable_tensors);
+    } else {
+      // Iterations have ended. Wait for all futures to be ready.
+      mlrt::Future await_all = mlrt::AwaitAll(absl::MakeSpan(async_handles));
+      std::move(await_all).Then(
+          [final_promises = std::move(final_promises),
+           variant_tensor_futures = std::move(mutable_tensor_futures),
+           async_handles = std::move(async_handles),
+           immutable_tensors](absl::Status status) mutable {
+            if (status.ok()) {
+              PopulateFinalPromise(final_promises, variant_tensor_futures,
+                                   immutable_tensors);
+              return;
+            } else {
+              for (auto& final_promise : final_promises) {
+                std::move(final_promise).SetError(status);
+              }
+            }
+          });
+    }
+    return;
+  }
+
+  // Proceed to schedule the next iteration, n+1.
+  // Creates arguments for dispatching the next iteration.
+  std::vector<mlrt::Value> body_args;
+  body_args.resize(body_argument_size);
+
+  // Set predicate_promise.
+  auto arg_iter = body_args.begin();
+  auto predicate_promise =
+      mlrt::Promise::Allocate<tensorflow::tfrt_stub::FallbackTensor>();
+  auto predicate_future = predicate_promise.GetFuture();
+  arg_iter->Set(std::move(predicate_promise));
+  ++arg_iter;
+
+  // The current iteration n receives mutable tensor values in futures from
+  // iteration n-1 and creates promises to return those mutable tensors after
+  // updating them in the current iteration.
+  std::vector<mlrt::Future> next_futures;
+  next_futures.reserve(mutable_tensor_futures.size());
+
+  for (auto& mutable_tensor : mutable_tensor_futures) {
+    // Future from the previous iteration as input to the current iteration.
+    arg_iter->Set(std::move(mutable_tensor));
+    ++arg_iter;
+
+    // Promise to return values from the current iteration.
+    auto next_promise =
+        mlrt::Promise::Allocate<tensorflow::tfrt_stub::FallbackTensor>();
+    next_futures.push_back(next_promise.GetFuture());
+    arg_iter->Set(std::move(next_promise));
+    ++arg_iter;
+  }
+
+  // Tensors that remain unchanged across iterations are copied over due to
+  // asynchronous execution between iterations.
+  for (auto& immutable_tensor : immutable_tensors) {
+    arg_iter->Set(immutable_tensor);
+    arg_iter++;
+  }
+
+  // Launch this iteration.
+  auto [promise, handle] = mlrt::AsyncHandle::Allocate(execution_context);
+  auto& thread_execution_context = handle.execution_context();
+  thread_execution_context.set_exit_handler(
+      [&execution_context = thread_execution_context,
+       promise = std::move(promise)]() mutable {
+        std::move(promise).Finish(execution_context.status());
+      });
+
+  thread_execution_context.CallByMove(body_fn, absl::MakeSpan(body_args),
+                                      absl::Span<mlrt::Value>());
+
+  thread_execution_context.work_queue()->AddTask(
+      [&execution_context = thread_execution_context]() {
+        mlrt::Execute(execution_context);
+      });
+
+  // Save the handle for this iteration.
+  async_handles.push_back(std::move(handle));
+
+  std::move(predicate_future)
+      .Then([futures = std::move(next_futures),
+             immutable_tensors = std::move(immutable_tensors),
+             final_promises = std::move(final_promises),
+             body_args = std::move(body_args),
+             async_handles = std::move(async_handles), body_fn, counter,
+             &execution_context = thread_execution_context](
+                absl::StatusOr<tensorflow::tfrt_stub::FallbackTensor>
+                    predicate_result) mutable {
+        if (!predicate_result.ok()) {
+          auto status = predicate_result.status();
+          mlrt::Future await_all =
+              mlrt::AwaitAll(absl::MakeSpan(async_handles));
+          std::move(await_all).Then(
+              [final_promises = std::move(final_promises),
+               async_handles = std::move(async_handles), status]() mutable {
+                for (auto& final_promise : final_promises) {
+                  std::move(final_promise).SetError(status);
+                }
+              });
+          execution_context.Fail(status);
+          return;
+        }
+
+        // Keep body_args alive for thread execution.
+        OnPredicateReady(*predicate_result, std::move(async_handles),
+                         std::move(futures), immutable_tensors,
+                         std::move(final_promises), body_fn, execution_context,
+                         ++counter);
+      });
+}
+
+void AsyncWhileOp::PopulateFinalPromise(
+    std::vector<mlrt::Promise>& final_promises,
+    const std::vector<mlrt::Future>& mutable_tensor_futures,
+    const std::vector<tensorflow::tfrt_stub::FallbackTensor>&
+        immutable_tensors) {
+  // The final predicate needs to be a tensor, not a bool, so that await_all
+  // can be used.
+  tensorflow::Tensor final_predicate_tensor(false);
+
+  auto final_promise_iter = final_promises.begin();
+  std::move(*final_promise_iter)
+      .Set<tensorflow::tfrt_stub::FallbackTensor>(
+          tensorflow::tfrt_stub::FallbackTensor(
+              std::move(final_predicate_tensor)));
+  final_promise_iter++;
+  for (auto& mutable_tensor_future : mutable_tensor_futures) {
+    DCHECK(mutable_tensor_future.IsReady());
+    std::move(*final_promise_iter)
+        .Set<tensorflow::tfrt_stub::FallbackTensor>(
+            std::move(mutable_tensor_future
+                          .Get<tensorflow::tfrt_stub::FallbackTensor>()));
+    final_promise_iter++;
+  }
+  for (auto& immutable_tensor : immutable_tensors) {
+    std::move(*final_promise_iter)
+        .Set<tensorflow::tfrt_stub::FallbackTensor>(immutable_tensor);
+    final_promise_iter++;
+  }
+}
+
+void AsyncWhileOp::Invoke() {
+  mlrt::bc::Function body_fn = body_function();
+
+  // Arguments: [initial_predicate, %variant0, %variant1, ..., %invariant0, ...]
+  //
+  // Results: [final_predicate, %variant0, %variant1, ..., %invariant0, ...]
+  DCHECK_EQ(arguments().size(), results().size());
+
+  // The body arguments are [predicate_promise; arg0_future, arg0_promise,
+  // arg1_future, arg1_promise, ..., invariant_args].
+  // Minus 1 below because the predicate is not a tensor argument.
+  const uint32_t immutable_tensor_size = immutable_size();
+  const uint32_t mutable_tensor_size =
+      arguments().size() - immutable_tensor_size - 1;
+
+  const uint32_t body_argument_size =
+      1 + (2 * mutable_tensor_size) + immutable_tensor_size;
+  DCHECK_EQ(body_fn.input_regs().size(), body_argument_size);
+  DCHECK_EQ(body_fn.output_regs().size(), 0);
+
+  tsl::profiler::TraceMe trace_me([&]() {
+    return tsl::profiler::TraceMeEncode("tf_mlrt.async_while",
+                                        {{"name", body_fn.name().Get()}});
+  });
+
+  // Save the futures of the final results.
+  // The last iteration will set the promises.
+  std::vector<mlrt::Promise> final_promises;
+  final_promises.reserve(arguments().size());
+  for (int i = 0; i < arguments().size(); ++i) {
+    final_promises.push_back(
+        mlrt::Promise::Allocate<tensorflow::tfrt_stub::FallbackTensor>());
+    results()[i] = final_promises.back().GetFuture();
+  }
+
+  // Populate the input arguments into a list of dummy futures to bootstrap
+  // the first iteration.
+  std::vector<mlrt::Future> mutable_tensor_futures;
+  mutable_tensor_futures.reserve(mutable_tensor_size);
+
+  // Plus 1 because the very first argument is the boolean predicate.
+  auto arg_iter = arguments().begin() + 1;
+  for (int i = 0; i < mutable_tensor_size; ++i) {
+    auto tensor_promise =
+        mlrt::Promise::Allocate<tensorflow::tfrt_stub::FallbackTensor>();
+    mutable_tensor_futures.push_back(tensor_promise.GetFuture());
+    std::move(tensor_promise)
+        .Set<tensorflow::tfrt_stub::FallbackTensor>(
+            arg_iter->Get<tensorflow::tfrt_stub::FallbackTensor>());
+    arg_iter++;
+  }
+
+  std::vector<tensorflow::tfrt_stub::FallbackTensor> immutable_tensors;
+  immutable_tensors.reserve(immutable_tensor_size);
+  for (int i = 0; i < immutable_tensor_size; ++i) {
+    immutable_tensors.push_back(
+        arg_iter->Get<tensorflow::tfrt_stub::FallbackTensor>());
+    arg_iter++;
+  }
+  OnPredicateReady(arguments()[0].Get<tensorflow::tfrt_stub::FallbackTensor>(),
+                   /*async_handles=*/{}, std::move(mutable_tensor_futures),
+                   immutable_tensors, std::move(final_promises), body_fn,
+                   execution_context(),
+                   /*counter=*/0);
+}
+
 struct MapFnOp : mlrt::KernelFrame {
   using KernelFrame::KernelFrame;
 
@@ -142,7 +433,7 @@ void MapFnOp::Invoke() {
             body_arg_last_uses.begin() + 2 * num_tensor_list_or_flow_in() + 2,
             true);
 
-  // Copy the invairant arguments (after max_iteration +
+  // Copy the invariant arguments (after max_iteration +
   // tensor_list_or_flow_ins)
   auto arg_iter = body_args.begin() + 2 * num_tensor_list_or_flow_in() + 2;
   for (int j = num_tensor_list_or_flow_in() + 1; j < arguments().size();
@@ -684,6 +975,7 @@ void RegisterTfMlrtKernels(mlrt::KernelRegistry& registry) {
   registry.Register();
   registry.Register("tfrt_fallback_sync.executeop");
   registry.Register();
+  registry.Register<AsyncWhileOp>();
   registry.Register();
   registry.Register();
   registry.Register("tf_mlrt.set_resource", &SetResource);
diff --git a/tensorflow/core/tfrt/mlrt/kernel/kernel_test.cc b/tensorflow/core/tfrt/mlrt/kernel/kernel_test.cc
index 3ec567743fdf21..6e090a7a49f257 100644
--- a/tensorflow/core/tfrt/mlrt/kernel/kernel_test.cc
+++ b/tensorflow/core/tfrt/mlrt/kernel/kernel_test.cc
@@ -2438,6 +2438,287 @@ TEST(KernelTest, PromiseReturn) {
       output.Get<tensorflow::tfrt_stub::FallbackTensor>().tensor(), expected);
 }
 
+// A function body for AsyncWhile.
+void TestAsyncWhileFnBody(mlrt::KernelFrame frame) {
+  ASSERT_EQ(frame.arguments().size(), 4);
+
+  auto predicate_promise =
+      std::move(frame.arguments()[0].Get<mlrt::Promise>());
+  auto prev_loop_count_future = frame.arguments()[1].Get<mlrt::Future>();
+  auto next_loop_count_promise =
+      std::move(frame.arguments()[2].Get<mlrt::Promise>());
+
+  int32_t max_iteration = frame.arguments()[3]
+                              .Get<tensorflow::tfrt_stub::FallbackTensor>()
+                              .tensor()
+                              .scalar<int32_t>()();
+
+  for (; !prev_loop_count_future.IsReady();) {
+    // wait for future to be ready
+  }
+  int32_t prev_loop_count =
+      prev_loop_count_future.Get<tensorflow::tfrt_stub::FallbackTensor>()
+          .tensor()
+          .scalar<int32_t>()();
+  tensorflow::Tensor next_loop_count(DT_INT32, {});
+  next_loop_count.scalar<int32_t>()() = prev_loop_count + 1;
+
+  tensorflow::Tensor predicate(DT_BOOL, {});
+  predicate.scalar<bool>()() = prev_loop_count + 1 < max_iteration;
+  std::move(predicate_promise)
+      .Set<tensorflow::tfrt_stub::FallbackTensor>(std::move(predicate));
+
+  std::move(next_loop_count_promise)
+      .Set<tensorflow::tfrt_stub::FallbackTensor>(std::move(next_loop_count));
+}
+
+mlrt::bc::Buffer CreateAsyncWhileExecutable() {
+  mlrt::bc::Buffer buffer;
+  mlrt::bc::Allocator allocator(&buffer);
+
+  auto executable_ctor = mlrt::bc::New<mlrt::bc::Executable>(&allocator);
+  mlrt::testing::SymbolTable kernels;
+  std::vector<std::string> kernel_names = {"tf_mlrt.async_while",
+                                           "tf_mlrt.await_all",
+                                           "test_async_while_body", "return"};
+  executable_ctor.construct_kernel_names(kernel_names.size())
+      .Assign(kernel_names);
+  kernels.Def(kernel_names);
+  mlrt::testing::AttributeTable attributes(
+      executable_ctor.construct_attributes(1));
+
+  attributes.Add("body_idx", 1);
+  attributes.Add("invariant_size", 1);
+  auto functions_ctor = executable_ctor.construct_functions(2);
+
+  {
+    auto function_ctor = functions_ctor.ConstructAt(0);
+    function_ctor.construct_name("main");
+    mlrt::testing::SymbolTable regs;
+    function_ctor.construct_input_regs(3).Assign(
+        regs.Def({"initial_predicate", "loop_count", "max_iterations"}));
+
+    auto kernels_ctor = function_ctor.construct_kernels(3);
+    {
+      auto kernel_ctor = kernels_ctor.ConstructAt(0);
+      kernel_ctor.set_code(kernels.Use("tf_mlrt.async_while"));
+      kernel_ctor.construct_attributes(2).Assign(
+          {attributes.GetHandle("body_idx"),
+           attributes.GetHandle("invariant_size")});
+      kernel_ctor.construct_arguments(3).Assign(
+          regs.Use({"initial_predicate", "loop_count", "max_iterations"}));
+      kernel_ctor.construct_results(3).Assign(
+          regs.Def({"last_predicate_future", "final_loop_count_future",
+                    "final_max_iterations_future"}));
+    }
+    {
+      auto kernel_ctor = kernels_ctor.ConstructAt(1);
+      kernel_ctor.set_code(kernels.Use("tf_mlrt.await_all"));
+      kernel_ctor.construct_arguments(3).Assign(
+          regs.Use({"last_predicate_future", "final_loop_count_future",
+                    "final_max_iterations_future"}));
+      kernel_ctor.construct_last_uses(3).Assign({true, true, true});
+      kernel_ctor.construct_results(3).Assign(regs.Def(
+          {"last_predicate", "final_loop_count", "final_max_iterations"}));
+    }
+    {
+      auto kernel_ctor = kernels_ctor.ConstructAt(2);
+      kernel_ctor.set_code(kernels.Use("return"));
+      kernel_ctor.construct_arguments(1).Assign(
+          {regs.Use("final_loop_count")});
+    }
+    function_ctor.set_num_regs(regs.size());
+    function_ctor.construct_output_regs(1).Assign(
+        {regs.Use("final_loop_count")});
+  }
+  {
+    auto function_ctor = functions_ctor.ConstructAt(1);
+    function_ctor.construct_name("body_function");
+
+    mlrt::testing::SymbolTable regs;
+
+    function_ctor.construct_input_regs(4).Assign(
+        regs.Def({"predicate_promise", "prev_loop_count_future",
+                  "loop_count_promise", "max_iterations"}));
+    auto kernels_ctor = function_ctor.construct_kernels(2);
+    {
+      auto kernel_ctor = kernels_ctor.ConstructAt(0);
+      kernel_ctor.set_code(kernels.Use("test_async_while_body"));
+      kernel_ctor.construct_arguments(4).Assign(
+          regs.Use({"predicate_promise", "prev_loop_count_future",
+                    "loop_count_promise", "max_iterations"}));
+    }
+    {
+      auto kernel_ctor = kernels_ctor.ConstructAt(1);
+      kernel_ctor.set_code(kernels.Use("return"));
+    }
+    function_ctor.set_num_regs(regs.size());
+  }
+  return buffer;
+}
+
+struct AsyncWhileOpTestParams {
+  bool initial_predicate;
+  int final_result;
+};
+class AsyncWhileOpTestFixture
+    : public ::testing::TestWithParam<AsyncWhileOpTestParams> {};
+TEST_P(AsyncWhileOpTestFixture, AsyncWhileOp) {
+  auto params = GetParam();
+  auto buffer = CreateAsyncWhileExecutable();
+
+  mlrt::bc::Executable executable(buffer.data());
+
+  mlrt::KernelRegistry registry;
+  RegisterTfMlrtKernels(registry);
+  registry.Register("test_async_while_body", TestAsyncWhileFnBody);
+
+  mlrt::LoadedExecutable loaded_executable(executable, registry);
+
+  auto work_queue = tfrt::CreateMultiThreadedWorkQueue(
+      /*num_threads=*/4, /*num_blocking_threads=*/4);
+  mlrt::ExecutionContext execution_context(&loaded_executable);
+  execution_context.set_work_queue(work_queue.get());
+
+  tensorflow::SessionOptions session_options;
+  tensorflow::FunctionDefLibrary fdef_lib;
+  TF_ASSERT_OK_AND_ASSIGN(auto fallback_state, tfrt_stub::FallbackState::Create(
+                                                   session_options, fdef_lib));
+
+  std::function<void(std::function<void()>)> runner =
+      [](const std::function<void()>& f) { f(); };
+  tfrt_stub::OpKernelRunnerTable runner_table;
+  tfd::FallbackResourceArray resource_array;
+  tfd::KernelFallbackCompatRequestState fallback_request_state(
+      &runner, &fallback_state->device_manager(), /*step_id=*/0, &runner_table,
+      &resource_array, /*user_intra_op_threadpool=*/nullptr,
+      /*model_metadata=*/std::nullopt,
+      &fallback_state->process_function_library_runtime());
+
+  tfrt::ResourceContext resource_context;
+
+  auto tf_context = std::make_unique<Context>(&fallback_request_state,
+                                              &resource_context);
+  execution_context.AddUserContext(std::move(tf_context));
+
+  std::vector<mlrt::Value> args;
+  args.resize(3);
+
+  // The initial predicate comes from the test parameter.
+  tensorflow::Tensor initial_predicate_tensor{DT_BOOL, {}};
+  initial_predicate_tensor.scalar<bool>()() = params.initial_predicate;
+  args.at(0).Set(
+      tfrt_stub::FallbackTensor(std::move(initial_predicate_tensor)));
+
+  tensorflow::Tensor loop_count_tensor{DT_INT32, {}};
+  loop_count_tensor.scalar<int32_t>()() = 0;
+  args.at(1).Set(tfrt_stub::FallbackTensor(std::move(loop_count_tensor)));
+
+  tensorflow::Tensor max_iteration_tensor{DT_INT32, {}};
+  max_iteration_tensor.scalar<int32_t>()() = 2;
+  args.at(2).Set(tfrt_stub::FallbackTensor(std::move(max_iteration_tensor)));
+
+  mlrt::Value result;
+
+  absl::Notification notification;
+  execution_context.set_exit_handler(
+      [&notification]() { notification.Notify(); });
+
+  std::vector<uint8_t> last_uses = {true, true, true};
+  execution_context.Call(executable.functions()[0], last_uses,
+                         absl::MakeSpan(args), absl::MakeSpan(&result, 1));
+  mlrt::Execute(execution_context);
+
+  notification.WaitForNotification();
+
+  ASSERT_OK(execution_context.status());
+
+  tensorflow::Tensor expected(tensorflow::DT_INT32, {});
+  expected.scalar<int32_t>()() = params.final_result;
+
+  auto& to_be = result.Get<tensorflow::tfrt_stub::FallbackTensor>();
+  tensorflow::test::ExpectEqual(to_be.tensor(), expected);
+}
+
+INSTANTIATE_TEST_SUITE_P(
+    AsyncWhileOpTestSuite, AsyncWhileOpTestFixture,
+    ::testing::ValuesIn<AsyncWhileOpTestParams>({{true, 2}, {false, 0}}));
+
+// An AsyncWhile body function that triggers a failure.
+void TestAsyncWhileFnBodyError(mlrt::KernelFrame frame) {
+  ASSERT_EQ(frame.arguments().size(), 4);
+
+  frame.execution_context().Fail(absl::InternalError("Test error"));
+}
+TEST(KernelTest, AsyncWhileOpError) {
+  auto buffer = CreateAsyncWhileExecutable();
+
+  mlrt::bc::Executable executable(buffer.data());
+
+  mlrt::KernelRegistry registry;
+  RegisterTfMlrtKernels(registry);
+  registry.Register("test_async_while_body", TestAsyncWhileFnBodyError);
+
+  mlrt::LoadedExecutable loaded_executable(executable, registry);
+
+  auto work_queue = tfrt::CreateMultiThreadedWorkQueue(
+      /*num_threads=*/4, /*num_blocking_threads=*/4);
+  mlrt::ExecutionContext execution_context(&loaded_executable);
+  execution_context.set_work_queue(work_queue.get());
+
+  tensorflow::SessionOptions session_options;
+  tensorflow::FunctionDefLibrary fdef_lib;
+  TF_ASSERT_OK_AND_ASSIGN(auto fallback_state, tfrt_stub::FallbackState::Create(
+                                                   session_options, fdef_lib));
+
+  std::function<void(std::function<void()>)> runner =
+      [](const std::function<void()>& f) { f(); };
+  tfrt_stub::OpKernelRunnerTable runner_table;
+  tfd::FallbackResourceArray resource_array;
+  tfd::KernelFallbackCompatRequestState fallback_request_state(
+      &runner, &fallback_state->device_manager(), /*step_id=*/0, &runner_table,
+      &resource_array, /*user_intra_op_threadpool=*/nullptr,
+      /*model_metadata=*/std::nullopt,
+      &fallback_state->process_function_library_runtime());
+
+  tfrt::ResourceContext resource_context;
+
+  auto tf_context = std::make_unique<Context>(&fallback_request_state,
+                                              &resource_context);
+  execution_context.AddUserContext(std::move(tf_context));
+
+  std::vector<mlrt::Value> args;
+  args.resize(3);
+
+  // initial predicate is true
+  tensorflow::Tensor initial_predicate_tensor{DT_BOOL, {}};
+  initial_predicate_tensor.scalar<bool>()() = true;
+  args.at(0).Set(
+      tfrt_stub::FallbackTensor(std::move(initial_predicate_tensor)));
+
+  tensorflow::Tensor loop_count_tensor{DT_INT32, {}};
+  loop_count_tensor.scalar<int32_t>()() = 0;
+  args.at(1).Set(tfrt_stub::FallbackTensor(std::move(loop_count_tensor)));
+
+  tensorflow::Tensor max_iteration_tensor{DT_INT32, {}};
+  max_iteration_tensor.scalar<int32_t>()() = 2;
+  args.at(2).Set(tfrt_stub::FallbackTensor(std::move(max_iteration_tensor)));
+
+  mlrt::Value result;
+
+  absl::Notification notification;
+  execution_context.set_exit_handler(
+      [&notification]() { notification.Notify(); });
+
+  std::vector<uint8_t> last_uses = {true, true, true};
+  execution_context.Call(executable.functions()[0], last_uses,
+                         absl::MakeSpan(args), absl::MakeSpan(&result, 1));
+  mlrt::Execute(execution_context);
+
+  notification.WaitForNotification();
+  EXPECT_THAT(
+      execution_context.status(),
+      ::tsl::testing::StatusIs(absl::StatusCode::kInternal, "Test error"));
+}
+
 }  // namespace
 }  // namespace tf_mlrt
 }  // namespace tensorflow