[XLA:CPU] Refactor CpuExecutable so LLVM errors can be propagated
Otherwise we'd crash on cases like a non-existent CustomCall target.

PiperOrigin-RevId: 562563302
d0k authored and tensorflow-jenkins committed Sep 7, 2023
1 parent 7ce14c3 commit 921b0f9
Showing 5 changed files with 73 additions and 40 deletions.
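
The core of the change is a standard C++ pattern: a constructor cannot report failure through a return value, so fallible work moves out of the constructor into a static factory that returns a StatusOr. Below is a minimal sketch of that pattern, using hypothetical names (JitArtifact, LookupSymbol) as stand-ins for the real XLA classes; it illustrates the refactoring, it is not the commit's code.

    // Minimal sketch of the constructor-to-factory refactoring. All names
    // here (JitArtifact, LookupSymbol) are hypothetical, not XLA APIs.
    #include <memory>
    #include <string>

    #include "absl/status/statusor.h"
    #include "absl/strings/str_cat.h"

    class JitArtifact {
     public:
      // Fallible work happens here, where an error can be returned to the
      // caller instead of crashing the process with a CHECK.
      static absl::StatusOr<std::unique_ptr<JitArtifact>> Create(
          const std::string& entry_name) {
        std::unique_ptr<JitArtifact> artifact(new JitArtifact());
        void* sym = LookupSymbol(entry_name);  // may legitimately fail
        if (sym == nullptr) {
          return absl::InvalidArgumentError(
              absl::StrCat("Symbol ", entry_name, " not found."));
        }
        artifact->entry_point_ = sym;
        return artifact;
      }

     private:
      JitArtifact() = default;  // the constructor itself never fails
      // Stand-in for the JIT symbol lookup; always fails in this sketch.
      static void* LookupSymbol(const std::string& /*name*/) { return nullptr; }
      void* entry_point_ = nullptr;
    };

In the commit, CpuExecutable's constructor becomes private and does only infallible member initialization, while the two Create overloads take over the work that could fail.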
1 change: 1 addition & 0 deletions tensorflow/compiler/tests/BUILD
@@ -2388,6 +2388,7 @@ tf_xla_py_strict_test(
"//tensorflow/compiler/tf2xla/python:xla",
"//tensorflow/python/eager:def_function",
"//tensorflow/python/framework:dtypes",
"//tensorflow/python/framework:errors",
"//tensorflow/python/framework:ops",
"//tensorflow/python/framework:tensor_spec",
"//tensorflow/python/ops:random_ops",
17 changes: 17 additions & 0 deletions tensorflow/compiler/tests/xla_custom_call_ops_test.py
@@ -18,6 +18,7 @@
 from tensorflow.compiler.tf2xla.python import xla
 from tensorflow.python.eager import def_function
 from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors_impl
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import tensor_spec
 from tensorflow.python.ops import random_ops
@@ -46,6 +47,22 @@ def f(x, y):
       self.assertIn('custom_call_target="my_call"', hlo)
       self.assertIn('backend_config="my_backend_config"', hlo)
 
+  def testXlaCustomCallOpDoesntExist(self):
+    with ops.device('device:{}:0'.format(self.device)):
+
+      def f():
+        return xla.custom_call(
+            args=(1, 2),
+            target_name='my_non_existing_call_target',
+            dtype=dtypes.int32,
+            shape=(),
+            backend_config='my_backend_config',
+        )
+
+      with self.assertRaises(errors_impl.InvalidArgumentError):
+        compiled_f = def_function.function(f, jit_compile=True)
+        compiled_f()
+
   def testXlaCustomCallV2Op(self):
     with ops.device('device:{}:0'.format(self.device)):
 
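Note the shape of the new test above: def_function.function(f, jit_compile=True) only builds a Python callable; XLA compilation, and therefore the symbol lookup that now fails cleanly, happens on the first call. That is why compiled_f() sits inside the assertRaises block, and why the error surfaces as an InvalidArgumentError rather than a process crash.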
11 changes: 7 additions & 4 deletions tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
@@ -1404,9 +1404,12 @@ CpuCompiler::CompileLegacyCpuExecutable(std::unique_ptr<HloModule> module) {
                                       std::move(llvm_context));
   cantFail((*jit)->AddModule(std::move(thread_safe_module)));
 
-  auto cpu_executable = std::make_unique<CpuExecutable>(
-      std::move(*jit), std::move(assignment), std::move(module), function_name,
-      std::move(hlo_profile_printer_data), std::move(hlo_profile_index_map));
+  TF_ASSIGN_OR_RETURN(
+      auto cpu_executable,
+      CpuExecutable::Create(std::move(*jit), std::move(assignment),
+                            std::move(module), function_name,
+                            std::move(hlo_profile_printer_data),
+                            std::move(hlo_profile_index_map)));
 
   if (embed_ir_in_executable) {
     cpu_executable->set_ir_module_string(ir_module_string);
@@ -1508,7 +1511,7 @@ CpuCompiler::CompileXlaRuntimeCpuExecutable(
                                   obj_file);
   }
 
-  return std::make_unique<CpuExecutable>(
+  return CpuExecutable::Create(
       std::move(hlo_module), std::move(hlo_profile_printer_data),
       std::move(hlo_profile_index_map), std::move(assignment),
       std::move(xla_runtime_executable));
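The call sites switch from constructing the executable directly to unwrapping the StatusOr with TF_ASSIGN_OR_RETURN, so a failed Create becomes an early return from the compiler rather than a crash. Roughly, and only as a simplified sketch (the real macro uses a hidden unique temporary and handles more forms), the first call site above expands to:

    // Hand-expanded TF_ASSIGN_OR_RETURN(auto cpu_executable, Create(...)):
    auto statusor = CpuExecutable::Create(/* ... same arguments ... */);
    if (!statusor.ok()) return statusor.status();  // propagate the error
    auto cpu_executable = std::move(statusor).value();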
61 changes: 34 additions & 27 deletions tensorflow/compiler/xla/service/cpu/cpu_executable.cc
@@ -58,53 +58,60 @@ namespace cpu {

 namespace runtime = ::xla::runtime;
 
-CpuExecutable::CpuExecutable(
+StatusOr<std::unique_ptr<CpuExecutable>> CpuExecutable::Create(
     std::unique_ptr<SimpleOrcJIT> jit,
     std::unique_ptr<const BufferAssignment> assignment,
     std::unique_ptr<HloModule> hlo_module,
     const std::string& entry_function_name,
     std::unique_ptr<HloProfilePrinterData> hlo_profile_printer_data,
-    std::unique_ptr<HloProfileIndexMap> hlo_profile_index_map)
-    : Executable(std::move(hlo_module), std::move(hlo_profile_printer_data),
-                 std::move(hlo_profile_index_map)),
-      jit_(std::move(jit)),
-      assignment_(std::move(assignment)),
-      module_name_(entry_function_name) {
-  if (assignment_) {
-    buffer_assignment_ =
-        std::make_shared<BufferAssignmentProto>(assignment_->ToProto());
-  }
-  if (has_module()) {
-    XlaDebugInfoManager::Get()->RegisterModule(shared_module(),
-                                               buffer_assignment_);
-  }
+    std::unique_ptr<HloProfileIndexMap> hlo_profile_index_map) {
+  std::unique_ptr<CpuExecutable> executable(new CpuExecutable(
+      std::move(hlo_module), std::move(hlo_profile_printer_data),
+      std::move(hlo_profile_index_map), std::move(assignment)));
+  executable->jit_ = std::move(jit);
+  executable->module_name_ = entry_function_name;
 
   // Resolve symbols in the constructor rather than at execution time to avoid
   // races because FindSymbol is not thread safe.
   llvm::Expected<llvm::orc::ExecutorSymbolDef> sym =
-      jit_->FindCompiledSymbol(entry_function_name);
+      executable->jit_->FindCompiledSymbol(entry_function_name);
-  // We expect to find the symbol provided with entry_function_name; otherwise
-  // this is an internal error.
-  CHECK(sym->getAddress()) << "Symbol " << entry_function_name << " not found.";
+  if (!sym) {
+    return absl::InvalidArgumentError(
+        absl::StrCat("Symbol ", entry_function_name, " not found."));
+  }
   // getAddress can do work under the hood in the jit, so it needs to be
   // guarded by the mutex.
-  compute_function_ =
+  executable->compute_function_ =
       reinterpret_cast<ComputeFunctionType>(sym->getAddress().getValue());
   VLOG(1) << "compute_function_ at address "
-          << reinterpret_cast<void*>(compute_function_);
-  jit_->DoneCompiling();
+          << reinterpret_cast<void*>(executable->compute_function_);
+  executable->jit_->DoneCompiling();
+  return executable;
 }
 
-CpuExecutable::CpuExecutable(
+StatusOr<std::unique_ptr<CpuExecutable>> CpuExecutable::Create(
     std::unique_ptr<HloModule> hlo_module,
     std::unique_ptr<HloProfilePrinterData> hlo_profile_printer_data,
     std::unique_ptr<HloProfileIndexMap> hlo_profile_index_map,
     std::unique_ptr<const BufferAssignment> assignment,
-    std::unique_ptr<XlaRuntimeCpuExecutable> xla_runtime_executable)
+    std::unique_ptr<XlaRuntimeCpuExecutable> xla_runtime_executable) {
+  std::unique_ptr<CpuExecutable> executable(new CpuExecutable(
+      std::move(hlo_module), std::move(hlo_profile_printer_data),
+      std::move(hlo_profile_index_map), std::move(assignment)));
+  executable->xla_runtime_executable_ = std::move(xla_runtime_executable);
+  return executable;
+}
+
+CpuExecutable::CpuExecutable(
+    std::unique_ptr<HloModule> hlo_module,
+    std::unique_ptr<HloProfilePrinterData> hlo_profile_printer_data,
+    std::unique_ptr<HloProfileIndexMap> hlo_profile_index_map,
+    std::unique_ptr<const BufferAssignment> assignment)
     : Executable(std::move(hlo_module), std::move(hlo_profile_printer_data),
                  std::move(hlo_profile_index_map)),
-      assignment_(std::move(assignment)),
-      xla_runtime_executable_(std::move(xla_runtime_executable)) {
+      assignment_(std::move(assignment)) {
   if (assignment_) {
     buffer_assignment_ =
         std::make_shared<BufferAssignmentProto>(assignment_->ToProto());
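One detail worth noting in the new Create above: FindCompiledSymbol returns an llvm::Expected, whose failure state carries an llvm::Error that is meant to be consumed. If one wanted to surface the underlying LLVM error text instead of a fixed message, a common bridging pattern (shown only as an illustrative sketch, not what this commit does) would be:

    // Hypothetical alternative: fold the llvm::Error's message into the
    // Status. takeError() consumes the error, satisfying llvm::Expected's
    // checked-error invariant.
    if (!sym) {
      return absl::InternalError(llvm::toString(sym.takeError()));
    }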
@@ -328,9 +335,9 @@ StatusOr<std::unique_ptr<Executable>> CpuExecutable::LoadFromObjFile(
       std::move(executable_ptr), xla_framework_mapping,
       std::move(*ffi_modules_state));
 
-  return std::unique_ptr<Executable>(new CpuExecutable(
-      std::move(hlo_module), nullptr, nullptr, std::move(buffer_assignment),
-      std::move(xla_runtime_executable)));
+  return CpuExecutable::Create(std::move(hlo_module), nullptr, nullptr,
+                               std::move(buffer_assignment),
+                               std::move(xla_runtime_executable));
 }
 
 StatusOr<ExecutionOutput> CpuExecutable::CreateResultShapedBuffer(
23 changes: 14 additions & 9 deletions tensorflow/compiler/xla/service/cpu/cpu_executable.h
@@ -138,14 +138,15 @@ class XlaRuntimeCpuExecutable {
 // architecture, so JIT-ed code and host code share the same ABI.
 class CpuExecutable : public Executable {
  public:
-  CpuExecutable(std::unique_ptr<SimpleOrcJIT> jit,
-                std::unique_ptr<const BufferAssignment> assignment,
-                std::unique_ptr<HloModule> hlo_module,
-                const std::string& entry_function_name,
-                std::unique_ptr<HloProfilePrinterData> hlo_profile_printer_data,
-                std::unique_ptr<HloProfileIndexMap> hlo_profile_index_map);
-  // XLA Runtime constructor.
-  CpuExecutable(
+  static StatusOr<std::unique_ptr<CpuExecutable>> Create(
+      std::unique_ptr<SimpleOrcJIT> jit,
+      std::unique_ptr<const BufferAssignment> assignment,
+      std::unique_ptr<HloModule> hlo_module,
+      const std::string& entry_function_name,
+      std::unique_ptr<HloProfilePrinterData> hlo_profile_printer_data,
+      std::unique_ptr<HloProfileIndexMap> hlo_profile_index_map);
+  // XLA Runtime factory method.
+  static StatusOr<std::unique_ptr<CpuExecutable>> Create(
       std::unique_ptr<HloModule> hlo_module,
       std::unique_ptr<HloProfilePrinterData> hlo_profile_printer_data,
       std::unique_ptr<HloProfileIndexMap> hlo_profile_index_map,
@@ -257,7 +258,7 @@ class CpuExecutable : public Executable {
   const InstructionValueSet& GetRootValueSet() const;
 
   // The JIT containing compiled modules.
-  const std::unique_ptr<SimpleOrcJIT> jit_;
+  std::unique_ptr<SimpleOrcJIT> jit_;
 
   // Buffer assignment for the buffers we need to allocate.
   const std::unique_ptr<const BufferAssignment> assignment_;
@@ -281,6 +282,10 @@ class CpuExecutable : public Executable {
   // If not null, XLA Runtime is enabled.
   std::unique_ptr<XlaRuntimeCpuExecutable> xla_runtime_executable_;
 
+  CpuExecutable(std::unique_ptr<HloModule> hlo_module,
+                std::unique_ptr<HloProfilePrinterData> hlo_profile_printer_data,
+                std::unique_ptr<HloProfileIndexMap> hlo_profile_index_map,
+                std::unique_ptr<const BufferAssignment> assignment);
   CpuExecutable(const CpuExecutable&) = delete;
   CpuExecutable& operator=(const CpuExecutable&) = delete;
 };
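
A side effect of the refactoring is visible in the header diff: jit_ loses its const qualifier. A const data member can only be set in a constructor's initializer list, and the factory now assigns it after the (now minimal) private constructor has run. A short sketch of the language constraint:

    #include <memory>

    struct Holder {
      const std::unique_ptr<int> p_;  // can be initialized, never reassigned
    };

    int main() {
      Holder h;                            // ok: p_ default-initializes to null
      // h.p_ = std::make_unique<int>(1);  // ill-formed: p_ is const
    }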
