Skip to content

Commit

Permalink
[xla:ffi] Add support for ThreadPool to external FFI
Browse files Browse the repository at this point in the history
WARNING: It is unsafe to block inside an FFI handler while waiting for work submitted to the thread pool; completion must instead be signaled with an AsyncValue. Support for this is coming soon!
PiperOrigin-RevId: 669139776
  • Loading branch information
ezhulenev authored and tensorflower-gardener committed Aug 30, 2024
1 parent 729db88 commit 2288d10
Show file tree
Hide file tree
Showing 6 changed files with 176 additions and 2 deletions.
1 change: 1 addition & 0 deletions third_party/xla/xla/ffi/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,7 @@ cc_library(
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
"@eigen_archive//:eigen3",
"@local_tsl//tsl/platform:logging",
"@local_tsl//tsl/platform:statusor",
],
Expand Down
3 changes: 3 additions & 0 deletions third_party/xla/xla/ffi/api/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,9 @@ xla_cc_test(
"@com_google_absl//absl/log:check",
"@com_google_absl//absl/status",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/synchronization",
"@eigen_archive//:eigen3",
"@local_tsl//tsl/platform:env",
"@local_tsl//tsl/platform:errors",
"@local_tsl//tsl/platform:status_matchers",
"@local_tsl//tsl/platform:test",
Expand Down
32 changes: 32 additions & 0 deletions third_party/xla/xla/ffi/api/c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -534,6 +534,37 @@ XLA_FFI_DEFINE_STRUCT_TRAITS(XLA_FFI_DeviceMemory_Free_Args, data);
typedef XLA_FFI_Error* XLA_FFI_DeviceMemory_Free(
XLA_FFI_DeviceMemory_Free_Args* args);

//===----------------------------------------------------------------------===//
// ThreadPool
//===----------------------------------------------------------------------===//

// A function pointer for a task to be scheduled on a thread pool. XLA runtime
// will call this function with a user-defined `data` pointer on one of the
// runtime-managed threads. For XLA:CPU backends the task will be invoked on
// a thread pool that runs all compute tasks (Eigen thread pool).
//
// The `data` pointer is passed through to the task unchanged. The runtime does
// not free it, so the task (or its owner) is responsible for releasing
// whatever `data` points to.
//
// IMPORTANT: Users must not rely on any particular execution order or the
// number of available threads. Tasks can be executed in the caller thread, or
// in a thread pool with size `1`, and it is unsafe to assume that all scheduled
// tasks can be executed in parallel.
//
// WARNING: It is unsafe to block inside an FFI handler while waiting for a
// scheduled task to complete.
typedef void XLA_FFI_Task(void* data);

// Arguments for `XLA_FFI_ThreadPool_Schedule`.
struct XLA_FFI_ThreadPool_Schedule_Args {
// Size of this struct as seen by the caller. Callers set it to
// `XLA_FFI_ThreadPool_Schedule_Args_STRUCT_SIZE`; the runtime validates it
// before reading the remaining fields.
size_t struct_size;
// Head of the extension chain; `nullptr` when no extensions are passed.
XLA_FFI_Extension_Base* extension_start;

// Execution context of the FFI call on whose behalf the task is scheduled.
XLA_FFI_ExecutionContext* ctx;
// Task to run on the runtime-managed thread pool.
XLA_FFI_Task* task;
// Opaque user pointer forwarded to `task` when it runs.
void* data;
};

// NOTE(review): presumably defines `XLA_FFI_ThreadPool_Schedule_Args_STRUCT_SIZE`
// from the offset of the last field (`data`) - confirm against the macro
// definition.
XLA_FFI_DEFINE_STRUCT_TRAITS(XLA_FFI_ThreadPool_Schedule_Args, data);

// Schedules a task to be executed on a thread pool managed by XLA runtime.
// Returns an error if thread pool is not available.
//
// Returns `nullptr` when the task was handed to the thread pool. A non-null
// result is an error object that the caller must destroy; in that case the
// task was not scheduled and the caller still owns `args->data`.
typedef XLA_FFI_Error* XLA_FFI_ThreadPool_Schedule(
XLA_FFI_ThreadPool_Schedule_Args* args);

//===----------------------------------------------------------------------===//
// Metadata extension
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -577,6 +608,7 @@ struct XLA_FFI_Api {
_XLA_FFI_API_STRUCT_FIELD(XLA_FFI_State_Get);
_XLA_FFI_API_STRUCT_FIELD(XLA_FFI_DeviceMemory_Allocate);
_XLA_FFI_API_STRUCT_FIELD(XLA_FFI_DeviceMemory_Free);
_XLA_FFI_API_STRUCT_FIELD(XLA_FFI_ThreadPool_Schedule);
};

#undef _XLA_FFI_API_STRUCT_FIELD
Expand Down
74 changes: 72 additions & 2 deletions third_party/xla/xla/ffi/api/ffi.h
Original file line number Diff line number Diff line change
Expand Up @@ -920,8 +920,6 @@ struct CtxDecoding<PlatformStream<T>> {
// the particular call to FFI handler.
class ScratchAllocator {
public:
ScratchAllocator(const XLA_FFI_Api* api, XLA_FFI_ExecutionContext* ctx,
DiagnosticEngine& diagnostic);
~ScratchAllocator();

ScratchAllocator(ScratchAllocator&&) = default;
Expand All @@ -930,6 +928,11 @@ class ScratchAllocator {
std::optional<void*> Allocate(size_t size, size_t alignment = 1);

private:
friend struct CtxDecoding<ScratchAllocator>;

ScratchAllocator(const XLA_FFI_Api* api, XLA_FFI_ExecutionContext* ctx,
DiagnosticEngine& diagnostic);

struct Allocation {
size_t size;
void* data;
Expand Down Expand Up @@ -997,6 +1000,73 @@ inline ScratchAllocator::~ScratchAllocator() {
}
}

//===----------------------------------------------------------------------===//
// ThreadPool
//===----------------------------------------------------------------------===//

// A handle to the thread pool managed by the XLA runtime, available to FFI
// handlers as an execution-context argument (`Ffi::Bind().Ctx<ThreadPool>()`).
//
// Instances can only be created by `CtxDecoding<ThreadPool>`. The handle keeps
// a reference to the `DiagnosticEngine` it was created with, so it must not
// outlive the FFI call it was decoded for.
//
// WARNING: It is unsafe to block inside an FFI handler while waiting for work
// submitted via `Schedule` to complete.
class ThreadPool {
 public:
  // Schedules `f` to be invoked (with no arguments) on the runtime thread pool.
  //
  // The callable is copied (for lvalues) or moved (for rvalues) into a
  // heap-allocated object that is destroyed right after it has been invoked.
  //
  // This function never reports failure to the caller: if the runtime cannot
  // schedule the task, a diagnostic is emitted and `f` is executed inline in
  // the calling thread before `Schedule` returns.
  template <typename F>
  void Schedule(F&& f) {
    // `new auto(...)` deduces the decayed callable type. Using `F` directly
    // would form a pointer-to-reference (ill-formed) whenever an lvalue
    // callable is passed and `F` is deduced as `T&`.
    ScheduleOwned(new auto(std::forward<F>(f)));
  }

 private:
  friend struct CtxDecoding<ThreadPool>;

  ThreadPool(const XLA_FFI_Api* api, XLA_FFI_ExecutionContext* ctx,
             DiagnosticEngine& diagnostic);

  // Submits heap-allocated callable `fn` to the runtime. Ownership of `fn` is
  // transferred to the type-erased task, which deletes it after invocation.
  template <typename Fn>
  void ScheduleOwned(Fn* fn) {
    // Captureless lambda converted to a plain function pointer so that it can
    // cross the C API boundary.
    XLA_FFI_Task* task = +[](void* data) {
      auto* callable = static_cast<Fn*>(data);
      (*callable)();
      delete callable;
    };

    XLA_FFI_ThreadPool_Schedule_Args args;
    args.struct_size = XLA_FFI_ThreadPool_Schedule_Args_STRUCT_SIZE;
    args.extension_start = nullptr;
    args.ctx = ctx_;
    args.task = task;
    args.data = fn;

    if (XLA_FFI_Error* error = api_->XLA_FFI_ThreadPool_Schedule(&args)) {
      diagnostic_.Emit("Failed to schedule task on a thread pool: ")
          << internal::GetErrorMessage(api_, error);
      internal::DestroyError(api_, error);

      // If thread pool is not available, we execute the task in the caller
      // thread. We choose not to return error from `Schedule` for consistency
      // with Eigen thread pool implementation, and because it would make
      // recursive work scheduling more difficult.
      task(fn);
    }
  }

  const XLA_FFI_Api* api_;
  XLA_FFI_ExecutionContext* ctx_;
  DiagnosticEngine& diagnostic_;
};

// Execution-context decoding that hands a `ThreadPool` handle to an FFI
// handler.
//
// Example: Ffi::Bind().Ctx<ThreadPool>()
//                     .To([](ThreadPool thread_pool) { ... });
template <>
struct CtxDecoding<ThreadPool> {
  using Type = ThreadPool;

  // Wraps the call's API table, execution context and diagnostic engine into
  // a `ThreadPool`. Decoding always produces a value.
  static std::optional<Type> Decode(const XLA_FFI_Api* api,
                                    XLA_FFI_ExecutionContext* ctx,
                                    DiagnosticEngine& diagnostic) {
    // The handle is constructed here (not inside `std::optional`) because the
    // `ThreadPool` constructor is private and only this struct is a friend.
    return Type(api, ctx, diagnostic);
  }
};

// Stores the FFI API table, the execution context of the current call and the
// diagnostic engine used to report scheduling failures. Nothing is validated
// or copied: `diagnostic` is kept by reference, so it must outlive this
// object.
inline ThreadPool::ThreadPool(const XLA_FFI_Api* api,
XLA_FFI_ExecutionContext* ctx,
DiagnosticEngine& diagnostic)
: api_(api), ctx_(ctx), diagnostic_(diagnostic) {}

//===----------------------------------------------------------------------===//
// Type Registration
//===----------------------------------------------------------------------===//
Expand Down
38 changes: 38 additions & 0 deletions third_party/xla/xla/ffi/api/ffi_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ limitations under the License.
#include "absl/log/check.h"
#include "absl/status/status.h"
#include "absl/strings/match.h"
#include "absl/synchronization/blocking_counter.h"
#include "xla/ffi/call_frame.h"
#include "xla/ffi/execution_context.h"
#include "xla/ffi/execution_state.h"
Expand All @@ -38,9 +39,14 @@ limitations under the License.
#include "xla/stream_executor/device_memory_allocator.h"
#include "xla/tsl/lib/core/status_test_util.h"
#include "xla/xla_data.pb.h"
#include "tsl/platform/env.h"
#include "tsl/platform/status_matchers.h"
#include "tsl/platform/test.h"
#include "tsl/platform/test_benchmark.h"
#include "tsl/platform/threadpool.h"

#define EIGEN_USE_THREADS
#include "unsupported/Eigen/CXX11/Tensor"

namespace xla::ffi {

Expand Down Expand Up @@ -1039,6 +1045,38 @@ TEST(FfiTest, ScratchAllocatorUnimplemented) {
TF_ASSERT_OK(status);
}

// Checks that a task scheduled through the FFI `ThreadPool` handle really runs
// on the thread pool supplied via the CPU call options.
TEST(FfiTest, ThreadPool) {
  tsl::thread::ThreadPool workers(tsl::Env::Default(), "XLAEigen", 2);
  Eigen::ThreadPoolDevice eigen_device(workers.AsEigenThreadPool(),
                                       workers.NumThreads());

  CallOptions options;
  options.backend_options = CallOptions::CpuOptions{&eigen_device};

  auto schedule_and_wait = [&](ThreadPool thread_pool) {
    // Two-phase handshake: the task cannot start until the handler releases
    // it, and the handler cannot return until the task has finished. If the
    // task were executed inline inside `Schedule`, the first wait would never
    // be released and the test would deadlock.
    absl::BlockingCounter release_task(1);
    absl::BlockingCounter task_done(1);

    thread_pool.Schedule([&] {
      release_task.Wait();
      task_done.DecrementCount();
    });

    release_task.DecrementCount();
    task_done.Wait();

    return Error::Success();
  };

  auto handler = Ffi::Bind().Ctx<ThreadPool>().To(schedule_and_wait);
  CallFrame call_frame =
      CallFrameBuilder(/*num_args=*/0, /*num_rets=*/0).Build();

  TF_ASSERT_OK(Call(*handler, call_frame, options));
}

TEST(FfiTest, Metadata) {
auto api = GetXlaFfiApi();
auto handler = Ffi::BindTo([]() { return Error::Success(); });
Expand Down
30 changes: 30 additions & 0 deletions third_party/xla/xla/ffi/ffi_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ limitations under the License.
#include <cstddef>
#include <cstdint>
#include <exception>
#include <new>
#include <string>
#include <string_view>
#include <utility>
Expand Down Expand Up @@ -49,6 +50,9 @@ limitations under the License.
#include "tsl/platform/logging.h"
#include "tsl/platform/statusor.h"

#define EIGEN_USE_THREADS
#include "unsupported/Eigen/CXX11/Tensor"

//===----------------------------------------------------------------------===//
// XLA FFI C structs definition
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -636,6 +640,31 @@ static XLA_FFI_Error* XLA_FFI_DeviceMemory_Free(
return nullptr;
}

// Implementation of the `XLA_FFI_ThreadPool_Schedule` C API entry point.
//
// Enqueues `args->task` on the intra-op thread pool found in the CPU backend
// context of `args->ctx`; when it runs, the task is called with `args->data`.
// Returns `nullptr` on success, or a newly allocated error when the struct
// size check fails, the execution context has no CPU backend context, or that
// context has no intra-op thread pool.
static XLA_FFI_Error* XLA_FFI_ThreadPool_Schedule(
    XLA_FFI_ThreadPool_Schedule_Args* args) {
  XLA_FFI_RETURN_IF_ERROR(ActualStructSizeIsGreaterOrEqual(
      "XLA_FFI_ThreadPool_Schedule_Args",
      XLA_FFI_ThreadPool_Schedule_Args_STRUCT_SIZE, args->struct_size));

  using CpuContext = XLA_FFI_ExecutionContext::CpuContext;
  auto* cpu_context = std::get_if<CpuContext>(&args->ctx->backend_context);

  // The backend context holds a different alternative (not a CPU context).
  if (ABSL_PREDICT_FALSE(cpu_context == nullptr)) {
    return new XLA_FFI_Error{
        Unimplemented("XLA FFI CPU context is not available")};
  }

  // A CPU context exists, but it was not given a thread pool to run on.
  if (ABSL_PREDICT_FALSE(cpu_context->intra_op_thread_pool == nullptr)) {
    return new XLA_FFI_Error{
        Unimplemented("No intra-op thread pool available on this platform")};
  }

  // Capture the function pointer and its argument by value: `args` is owned
  // by the caller and may be gone by the time the task runs.
  XLA_FFI_Task* fn = args->task;
  void* user_data = args->data;
  cpu_context->intra_op_thread_pool->enqueueNoNotification(
      [fn, user_data] { fn(user_data); });

  return nullptr;
}

//===----------------------------------------------------------------------===//
// XLA FFI Internal Api Implementation
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -740,6 +769,7 @@ static XLA_FFI_Api api = {
XLA_FFI_State_Get,
XLA_FFI_DeviceMemory_Allocate,
XLA_FFI_DeviceMemory_Free,
XLA_FFI_ThreadPool_Schedule,
};

// Returns a pointer to the file-local, statically allocated `api` table.
const XLA_FFI_Api* GetXlaFfiApi() { return &api; }
Expand Down

0 comments on commit 2288d10

Please sign in to comment.