Skip to content

Commit

Permalink
Add a check for device memory allocator in FFI.
Browse files Browse the repository at this point in the history
When the device allocator isn't set, attempting to allocate memory results in a segfault. This is user error, but it's probably worth throwing an error instead of crashing.

PiperOrigin-RevId: 660731970
  • Loading branch information
dfm authored and tensorflower-gardener committed Aug 8, 2024
1 parent d70e4b5 commit 4dbed9e
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 0 deletions.
13 changes: 13 additions & 0 deletions third_party/xla/xla/ffi/api/ffi_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -882,6 +882,19 @@ TEST(FfiTest, ScratchAllocator) {
EXPECT_EQ(allocator.count, 0);
}

TEST(FfiTest, ScratchAllocatorUnimplemented) {
auto fn = [&](ScratchAllocator scratch_allocator) {
auto mem = scratch_allocator.Allocate(1024);
EXPECT_FALSE(mem.has_value());
return Error::Success();
};
auto handler = Ffi::Bind().Ctx<ScratchAllocator>().To(fn);
CallFrame call_frame =
CallFrameBuilder(/*num_args=*/0, /*num_rets=*/0).Build();
auto status = Call(*handler, call_frame);
TF_ASSERT_OK(status);
}

//===----------------------------------------------------------------------===//
// Performance benchmarks are below.
//===----------------------------------------------------------------------===//
Expand Down
10 changes: 10 additions & 0 deletions third_party/xla/xla/ffi/ffi_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -470,6 +470,11 @@ static XLA_FFI_Error* XLA_FFI_DeviceMemory_Allocate(
InvalidArgument("Unsupported alignment: %d", args->alignment)};
}

if (ABSL_PREDICT_FALSE(args->ctx->allocator == nullptr)) {
return new XLA_FFI_Error{
Unimplemented("No device memory allocator available on this platform")};
}

absl::StatusOr<stream_executor::OwningDeviceMemory> memory =
args->ctx->allocator->Allocate(args->ctx->device_ordinal, args->size);
if (!memory.ok()) {
Expand All @@ -486,6 +491,11 @@ static XLA_FFI_Error* XLA_FFI_DeviceMemory_Free(
"XLA_FFI_DeviceMemory_Free_Args",
XLA_FFI_DeviceMemory_Free_Args_STRUCT_SIZE, args->struct_size));

if (ABSL_PREDICT_FALSE(args->ctx->allocator == nullptr)) {
return new XLA_FFI_Error{
Unimplemented("No device memory allocator available on this platform")};
}

absl::Status status = args->ctx->allocator->Deallocate(
args->ctx->device_ordinal,
stream_executor::DeviceMemoryBase(args->data, args->size));
Expand Down

0 comments on commit 4dbed9e

Please sign in to comment.