diff --git a/third_party/xla/xla/ffi/BUILD b/third_party/xla/xla/ffi/BUILD
index 2e9deaaf4066e7..c9f95edb6ef95a 100644
--- a/third_party/xla/xla/ffi/BUILD
+++ b/third_party/xla/xla/ffi/BUILD
@@ -164,6 +164,7 @@ cc_library(
         "@com_google_absl//absl/status:statusor",
         "@com_google_absl//absl/strings",
         "@com_google_absl//absl/strings:str_format",
+        "@eigen_archive//:eigen3",
         "@local_tsl//tsl/platform:logging",
         "@local_tsl//tsl/platform:statusor",
     ],
diff --git a/third_party/xla/xla/ffi/api/BUILD b/third_party/xla/xla/ffi/api/BUILD
index 0af899a77c4d9a..45c8db2a1b52f3 100644
--- a/third_party/xla/xla/ffi/api/BUILD
+++ b/third_party/xla/xla/ffi/api/BUILD
@@ -89,6 +89,9 @@ xla_cc_test(
         "@com_google_absl//absl/log:check",
         "@com_google_absl//absl/status",
         "@com_google_absl//absl/strings",
+        "@com_google_absl//absl/synchronization",
+        "@eigen_archive//:eigen3",
+        "@local_tsl//tsl/platform:env",
         "@local_tsl//tsl/platform:errors",
         "@local_tsl//tsl/platform:status_matchers",
         "@local_tsl//tsl/platform:test",
diff --git a/third_party/xla/xla/ffi/api/c_api.h b/third_party/xla/xla/ffi/api/c_api.h
index 6c95003951c4c6..c4a7482c85354b 100644
--- a/third_party/xla/xla/ffi/api/c_api.h
+++ b/third_party/xla/xla/ffi/api/c_api.h
@@ -534,6 +534,37 @@ XLA_FFI_DEFINE_STRUCT_TRAITS(XLA_FFI_DeviceMemory_Free_Args, data);
 typedef XLA_FFI_Error* XLA_FFI_DeviceMemory_Free(
     XLA_FFI_DeviceMemory_Free_Args* args);

+//===----------------------------------------------------------------------===//
+// ThreadPool
+//===----------------------------------------------------------------------===//
+
+// A function pointer for a task to be scheduled on a thread pool. XLA runtime
+// will call this function with a user-defined `data` pointer on one of the
+// runtime-managed threads. For XLA:CPU backends the task will be invoked on
+// a thread pool that runs all compute tasks (Eigen thread pool).
+//
+// IMPORTANT: Users must not rely on any particular execution order or the
+// number of available threads. Tasks can be executed in the caller thread, or
+// in a thread pool with size `1`, and it is unsafe to assume that all
+// scheduled tasks can be executed in parallel.
+typedef void XLA_FFI_Task(void* data);
+
+struct XLA_FFI_ThreadPool_Schedule_Args {
+  size_t struct_size;
+  XLA_FFI_Extension_Base* extension_start;
+
+  XLA_FFI_ExecutionContext* ctx;
+  XLA_FFI_Task* task;
+  void* data;
+};
+
+XLA_FFI_DEFINE_STRUCT_TRAITS(XLA_FFI_ThreadPool_Schedule_Args, data);
+
+// Schedules a task to be executed on a thread pool managed by XLA runtime.
+// Returns an error if thread pool is not available.
+typedef XLA_FFI_Error* XLA_FFI_ThreadPool_Schedule(
+    XLA_FFI_ThreadPool_Schedule_Args* args);
+
 //===----------------------------------------------------------------------===//
 // Metadata extension
 //===----------------------------------------------------------------------===//
@@ -577,6 +608,7 @@ struct XLA_FFI_Api {
   _XLA_FFI_API_STRUCT_FIELD(XLA_FFI_State_Get);
   _XLA_FFI_API_STRUCT_FIELD(XLA_FFI_DeviceMemory_Allocate);
   _XLA_FFI_API_STRUCT_FIELD(XLA_FFI_DeviceMemory_Free);
+  _XLA_FFI_API_STRUCT_FIELD(XLA_FFI_ThreadPool_Schedule);
 };

 #undef _XLA_FFI_API_STRUCT_FIELD
diff --git a/third_party/xla/xla/ffi/api/ffi.h b/third_party/xla/xla/ffi/api/ffi.h
index abb517a7a8d0e8..6990090a4cda8e 100644
--- a/third_party/xla/xla/ffi/api/ffi.h
+++ b/third_party/xla/xla/ffi/api/ffi.h
@@ -920,8 +920,6 @@ struct CtxDecoding<UserData<T>> {
 // the particular call to FFI handler.
 class ScratchAllocator {
  public:
-  ScratchAllocator(const XLA_FFI_Api* api, XLA_FFI_ExecutionContext* ctx,
-                   DiagnosticEngine& diagnostic);
   ~ScratchAllocator();

   ScratchAllocator(ScratchAllocator&&) = default;
@@ -930,6 +928,11 @@ class ScratchAllocator {
   std::optional<void*> Allocate(size_t size, size_t alignment = 1);

  private:
+  friend struct CtxDecoding<ScratchAllocator>;
+
+  ScratchAllocator(const XLA_FFI_Api* api, XLA_FFI_ExecutionContext* ctx,
+                   DiagnosticEngine& diagnostic);
+
   struct Allocation {
     size_t size;
     void* data;
@@ -997,6 +1000,73 @@ inline ScratchAllocator::~ScratchAllocator() {
   }
 }

+//===----------------------------------------------------------------------===//
+// ThreadPool
+//===----------------------------------------------------------------------===//
+
+class ThreadPool {
+ public:
+  template <typename F>
+  void Schedule(F&& f) {
+    XLA_FFI_Task* task = +[](void* data) {
+      auto* f = reinterpret_cast<F*>(data);
+      (*f)();
+      delete f;
+    };
+
+    F* data = new F(std::forward<F>(f));
+
+    XLA_FFI_ThreadPool_Schedule_Args args;
+    args.struct_size = XLA_FFI_ThreadPool_Schedule_Args_STRUCT_SIZE;
+    args.extension_start = nullptr;
+    args.ctx = ctx_;
+    args.task = task;
+    args.data = data;
+
+    if (XLA_FFI_Error* error = api_->XLA_FFI_ThreadPool_Schedule(&args)) {
+      diagnostic_.Emit("Failed to schedule task on a thread pool: ")
+          << internal::GetErrorMessage(api_, error);
+      internal::DestroyError(api_, error);
+
+      // If thread pool is not available, we execute the task in the caller
+      // thread. We choose not to return error from `Schedule` for consistency
+      // with Eigen thread pool implementation, and because it would make
+      // recursive work scheduling more difficult.
+      task(data);
+    }
+  }
+
+ private:
+  friend struct CtxDecoding<ThreadPool>;
+
+  ThreadPool(const XLA_FFI_Api* api, XLA_FFI_ExecutionContext* ctx,
+             DiagnosticEngine& diagnostic);
+
+  const XLA_FFI_Api* api_;
+  XLA_FFI_ExecutionContext* ctx_;
+  DiagnosticEngine& diagnostic_;
+};
+
+// Context decoding for thread pool.
+//
+// Example: Ffi::Bind().Ctx<ThreadPool>()
+//            .To([](ThreadPool thread_pool) { ... });
+template <>
+struct CtxDecoding<ThreadPool> {
+  using Type = ThreadPool;
+
+  static std::optional<Type> Decode(const XLA_FFI_Api* api,
+                                    XLA_FFI_ExecutionContext* ctx,
+                                    DiagnosticEngine& diagnostic) {
+    return ThreadPool(api, ctx, diagnostic);
+  }
+};
+
+inline ThreadPool::ThreadPool(const XLA_FFI_Api* api,
+                              XLA_FFI_ExecutionContext* ctx,
+                              DiagnosticEngine& diagnostic)
+    : api_(api), ctx_(ctx), diagnostic_(diagnostic) {}
+
 //===----------------------------------------------------------------------===//
 // Type Registration
 //===----------------------------------------------------------------------===//
diff --git a/third_party/xla/xla/ffi/api/ffi_test.cc b/third_party/xla/xla/ffi/api/ffi_test.cc
index ccf3125b7067da..27dcff09504d03 100644
--- a/third_party/xla/xla/ffi/api/ffi_test.cc
+++ b/third_party/xla/xla/ffi/api/ffi_test.cc
@@ -28,6 +28,7 @@ limitations under the License.
 #include "absl/log/check.h"
 #include "absl/status/status.h"
 #include "absl/strings/match.h"
+#include "absl/synchronization/blocking_counter.h"
 #include "xla/ffi/call_frame.h"
 #include "xla/ffi/execution_context.h"
 #include "xla/ffi/execution_state.h"
@@ -38,9 +39,14 @@ limitations under the License.
 #include "xla/stream_executor/device_memory_allocator.h"
 #include "xla/tsl/lib/core/status_test_util.h"
 #include "xla/xla_data.pb.h"
+#include "tsl/platform/env.h"
 #include "tsl/platform/status_matchers.h"
 #include "tsl/platform/test.h"
 #include "tsl/platform/test_benchmark.h"
+#include "tsl/platform/threadpool.h"
+
+#define EIGEN_USE_THREADS
+#include "unsupported/Eigen/CXX11/Tensor"

 namespace xla::ffi {

@@ -1039,6 +1045,38 @@ TEST(FfiTest, ScratchAllocatorUnimplemented) {
   TF_ASSERT_OK(status);
 }

+TEST(FfiTest, ThreadPool) {
+  tsl::thread::ThreadPool pool(tsl::Env::Default(), "XLAEigen", 2);
+  Eigen::ThreadPoolDevice device(pool.AsEigenThreadPool(), pool.NumThreads());
+
+  auto fn = [&](ThreadPool thread_pool) {
+    // Use a pair of blocking counters to check that the scheduled task was
+    // executed on a thread pool (it would deadlock if executed inline).
+    absl::BlockingCounter prepare(1);
+    absl::BlockingCounter execute(1);
+
+    thread_pool.Schedule([&] {
+      prepare.Wait();
+      execute.DecrementCount();
+    });
+
+    prepare.DecrementCount();
+    execute.Wait();
+
+    return Error::Success();
+  };
+
+  auto handler = Ffi::Bind().Ctx<ThreadPool>().To(fn);
+  CallFrame call_frame =
+      CallFrameBuilder(/*num_args=*/0, /*num_rets=*/0).Build();
+
+  CallOptions options;
+  options.backend_options = CallOptions::CpuOptions{&device};
+
+  auto status = Call(*handler, call_frame, options);
+  TF_ASSERT_OK(status);
+}
+
 TEST(FfiTest, Metadata) {
   auto api = GetXlaFfiApi();
   auto handler = Ffi::BindTo([]() { return Error::Success(); });
diff --git a/third_party/xla/xla/ffi/ffi_api.cc b/third_party/xla/xla/ffi/ffi_api.cc
index 8b1ae332679d80..de747370ea8ed1 100644
--- a/third_party/xla/xla/ffi/ffi_api.cc
+++ b/third_party/xla/xla/ffi/ffi_api.cc
@@ -18,6 +18,7 @@ limitations under the License.
 #include <...>
 #include <...>
 #include <...>
+#include <...>
 #include <...>
 #include <...>
 #include <...>
@@ -49,6 +50,9 @@ limitations under the License.
 #include "tsl/platform/logging.h"
 #include "tsl/platform/statusor.h"

+#define EIGEN_USE_THREADS
+#include "unsupported/Eigen/CXX11/Tensor"
+
 //===----------------------------------------------------------------------===//
 // XLA FFI C structs definition
 //===----------------------------------------------------------------------===//
@@ -636,6 +640,31 @@ static XLA_FFI_Error* XLA_FFI_DeviceMemory_Free(
   return nullptr;
 }

+static XLA_FFI_Error* XLA_FFI_ThreadPool_Schedule(
+    XLA_FFI_ThreadPool_Schedule_Args* args) {
+  XLA_FFI_RETURN_IF_ERROR(ActualStructSizeIsGreaterOrEqual(
+      "XLA_FFI_ThreadPool_Schedule_Args",
+      XLA_FFI_ThreadPool_Schedule_Args_STRUCT_SIZE, args->struct_size));
+
+  auto* cpu = std::get_if<XLA_FFI_ExecutionContext::CpuContext>(
+      &args->ctx->backend_context);
+
+  if (ABSL_PREDICT_FALSE(cpu == nullptr)) {
+    return new XLA_FFI_Error{
+        Unimplemented("XLA FFI CPU context is not available")};
+  }
+
+  if (ABSL_PREDICT_FALSE(cpu->intra_op_thread_pool == nullptr)) {
+    return new XLA_FFI_Error{
+        Unimplemented("No intra-op thread pool available on this platform")};
+  }
+
+  cpu->intra_op_thread_pool->enqueueNoNotification(
+      [task = args->task, data = args->data] { (*task)(data); });
+
+  return nullptr;
+}
+
 //===----------------------------------------------------------------------===//
 // XLA FFI Internal Api Implementation
 //===----------------------------------------------------------------------===//
@@ -740,6 +769,7 @@ static XLA_FFI_Api api = {
     XLA_FFI_State_Get,
     XLA_FFI_DeviceMemory_Allocate,
     XLA_FFI_DeviceMemory_Free,
+    XLA_FFI_ThreadPool_Schedule,
 };

 const XLA_FFI_Api* GetXlaFfiApi() { return &api; }