diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD index a678c6fbea5881..9e34e0f0f2f0b7 100644 --- a/tensorflow/compiler/tests/BUILD +++ b/tensorflow/compiler/tests/BUILD @@ -2388,6 +2388,7 @@ tf_xla_py_strict_test( "//tensorflow/compiler/tf2xla/python:xla", "//tensorflow/python/eager:def_function", "//tensorflow/python/framework:dtypes", + "//tensorflow/python/framework:errors", "//tensorflow/python/framework:ops", "//tensorflow/python/framework:tensor_spec", "//tensorflow/python/ops:random_ops", diff --git a/tensorflow/compiler/tests/xla_custom_call_ops_test.py b/tensorflow/compiler/tests/xla_custom_call_ops_test.py index d185b684aee070..164b6d7c875d8b 100644 --- a/tensorflow/compiler/tests/xla_custom_call_ops_test.py +++ b/tensorflow/compiler/tests/xla_custom_call_ops_test.py @@ -18,6 +18,7 @@ from tensorflow.compiler.tf2xla.python import xla from tensorflow.python.eager import def_function from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors_impl from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_spec from tensorflow.python.ops import random_ops @@ -46,6 +47,22 @@ def f(x, y): self.assertIn('custom_call_target="my_call"', hlo) self.assertIn('backend_config="my_backend_config"', hlo) + def testXlaCustomCallOpDoesntExist(self): + with ops.device('device:{}:0'.format(self.device)): + + def f(): + return xla.custom_call( + args=(1, 2), + target_name='my_non_existing_call_target', + dtype=dtypes.int32, + shape=(), + backend_config='my_backend_config', + ) + + with self.assertRaises(errors_impl.InvalidArgumentError): + compiled_f = def_function.function(f, jit_compile=True) + compiled_f() + def testXlaCustomCallV2Op(self): with ops.device('device:{}:0'.format(self.device)): diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc index e21aa7c85dc6a4..6362969d4243cb 100644 --- 
a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc @@ -1404,9 +1404,12 @@ CpuCompiler::CompileLegacyCpuExecutable(std::unique_ptr<HloModule> module) { std::move(llvm_context)); cantFail((*jit)->AddModule(std::move(thread_safe_module))); - auto cpu_executable = std::make_unique<CpuExecutable>( - std::move(*jit), std::move(assignment), std::move(module), function_name, - std::move(hlo_profile_printer_data), std::move(hlo_profile_index_map)); + TF_ASSIGN_OR_RETURN( + auto cpu_executable, + CpuExecutable::Create(std::move(*jit), std::move(assignment), + std::move(module), function_name, + std::move(hlo_profile_printer_data), + std::move(hlo_profile_index_map))); if (embed_ir_in_executable) { cpu_executable->set_ir_module_string(ir_module_string); @@ -1508,7 +1511,7 @@ CpuCompiler::CompileXlaRuntimeCpuExecutable( obj_file); } - return std::make_unique<CpuExecutable>( + return CpuExecutable::Create( std::move(hlo_module), std::move(hlo_profile_printer_data), std::move(hlo_profile_index_map), std::move(assignment), std::move(xla_runtime_executable)); diff --git a/tensorflow/compiler/xla/service/cpu/cpu_executable.cc b/tensorflow/compiler/xla/service/cpu/cpu_executable.cc index a358c25fa471d6..c63c9e937cd165 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_executable.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_executable.cc @@ -58,53 +58,60 @@ namespace cpu { namespace runtime = ::xla::runtime; -CpuExecutable::CpuExecutable( +StatusOr<std::unique_ptr<CpuExecutable>> CpuExecutable::Create( std::unique_ptr<SimpleOrcJIT> jit, std::unique_ptr<BufferAssignment> assignment, std::unique_ptr<HloModule> hlo_module, const std::string& entry_function_name, std::unique_ptr<HloProfilePrinterData> hlo_profile_printer_data, - std::unique_ptr<HloProfileIndexMap> hlo_profile_index_map) - : Executable(std::move(hlo_module), std::move(hlo_profile_printer_data), - std::move(hlo_profile_index_map)), - jit_(std::move(jit)), - assignment_(std::move(assignment)), - module_name_(entry_function_name) { - if (assignment_) { - buffer_assignment_ = 
std::make_shared<BufferAssignmentProto>(assignment_->ToProto()); - } - if (has_module()) { - XlaDebugInfoManager::Get()->RegisterModule(shared_module(), - buffer_assignment_); - } + std::unique_ptr<HloProfileIndexMap> hlo_profile_index_map) { + std::unique_ptr<CpuExecutable> executable(new CpuExecutable( + std::move(hlo_module), std::move(hlo_profile_printer_data), + std::move(hlo_profile_index_map), std::move(assignment))); + executable->jit_ = std::move(jit); + executable->module_name_ = entry_function_name; // Resolve symbols in the constructor rather than at execution time to avoid // races because FindSymbol is not thread safe. llvm::Expected<llvm::orc::ExecutorSymbolDef> sym = - jit_->FindCompiledSymbol(entry_function_name); + executable->jit_->FindCompiledSymbol(entry_function_name); // We expect to find the symbol provided with entry_function_name; otherwise // this is an internal error. - CHECK(sym->getAddress()) << "Symbol " << entry_function_name << " not found."; + if (!sym) { + return absl::InvalidArgumentError( + absl::StrCat("Symbol ", entry_function_name, " not found.")); + } // getAddress can do work under the hood in the jit, so it needs to be // guarded by the mutex. 
- compute_function_ = + executable->compute_function_ = reinterpret_cast<ComputeFunctionType>(sym->getAddress().getValue()); VLOG(1) << "compute_function_ at address " - << reinterpret_cast<void*>(compute_function_); - jit_->DoneCompiling(); + << reinterpret_cast<void*>(executable->compute_function_); + executable->jit_->DoneCompiling(); + return executable; } -CpuExecutable::CpuExecutable( +StatusOr<std::unique_ptr<CpuExecutable>> CpuExecutable::Create( std::unique_ptr<HloModule> hlo_module, std::unique_ptr<HloProfilePrinterData> hlo_profile_printer_data, std::unique_ptr<HloProfileIndexMap> hlo_profile_index_map, std::unique_ptr<BufferAssignment> assignment, - std::unique_ptr<XlaRuntimeCpuExecutable> xla_runtime_executable) + std::unique_ptr<XlaRuntimeCpuExecutable> xla_runtime_executable) { + std::unique_ptr<CpuExecutable> executable(new CpuExecutable( + std::move(hlo_module), std::move(hlo_profile_printer_data), + std::move(hlo_profile_index_map), std::move(assignment))); + executable->xla_runtime_executable_ = std::move(xla_runtime_executable); + return executable; +} + +CpuExecutable::CpuExecutable( + std::unique_ptr<HloModule> hlo_module, + std::unique_ptr<HloProfilePrinterData> hlo_profile_printer_data, + std::unique_ptr<HloProfileIndexMap> hlo_profile_index_map, + std::unique_ptr<BufferAssignment> assignment) : Executable(std::move(hlo_module), std::move(hlo_profile_printer_data), std::move(hlo_profile_index_map)), - assignment_(std::move(assignment)), - xla_runtime_executable_(std::move(xla_runtime_executable)) { + assignment_(std::move(assignment)) { if (assignment_) { buffer_assignment_ = std::make_shared<BufferAssignmentProto>(assignment_->ToProto()); @@ -328,9 +335,9 @@ StatusOr<std::unique_ptr<Executable>> CpuExecutable::LoadFromObjFile( std::move(executable_ptr), xla_framework_mapping, std::move(*ffi_modules_state)); - return std::unique_ptr<Executable>(new CpuExecutable( - std::move(hlo_module), nullptr, nullptr, std::move(buffer_assignment), - std::move(xla_runtime_executable))); + return CpuExecutable::Create(std::move(hlo_module), nullptr, nullptr, + std::move(buffer_assignment), + std::move(xla_runtime_executable)); } StatusOr<ScopedShapedBuffer> CpuExecutable::CreateResultShapedBuffer( diff --git a/tensorflow/compiler/xla/service/cpu/cpu_executable.h 
b/tensorflow/compiler/xla/service/cpu/cpu_executable.h index 46e705ebe458e9..1eb61241be5860 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_executable.h +++ b/tensorflow/compiler/xla/service/cpu/cpu_executable.h @@ -138,14 +138,15 @@ class XlaRuntimeCpuExecutable { // architecture, so JIT-ed code and host code share the same ABI. class CpuExecutable : public Executable { public: - CpuExecutable(std::unique_ptr<SimpleOrcJIT> jit, - std::unique_ptr<BufferAssignment> assignment, - std::unique_ptr<HloModule> hlo_module, - const std::string& entry_function_name, - std::unique_ptr<HloProfilePrinterData> hlo_profile_printer_data, - std::unique_ptr<HloProfileIndexMap> hlo_profile_index_map); - // XLA Runtime constructor. - CpuExecutable( + static StatusOr<std::unique_ptr<CpuExecutable>> Create( + std::unique_ptr<SimpleOrcJIT> jit, + std::unique_ptr<BufferAssignment> assignment, + std::unique_ptr<HloModule> hlo_module, + const std::string& entry_function_name, + std::unique_ptr<HloProfilePrinterData> hlo_profile_printer_data, + std::unique_ptr<HloProfileIndexMap> hlo_profile_index_map); + // XLA Runtime factory method. + static StatusOr<std::unique_ptr<CpuExecutable>> Create( std::unique_ptr<HloModule> hlo_module, std::unique_ptr<HloProfilePrinterData> hlo_profile_printer_data, std::unique_ptr<HloProfileIndexMap> hlo_profile_index_map, @@ -257,7 +258,7 @@ class CpuExecutable : public Executable { const InstructionValueSet& GetRootValueSet() const; // The JIT containing compiled modules. - const std::unique_ptr<SimpleOrcJIT> jit_; + std::unique_ptr<SimpleOrcJIT> jit_; // Buffer assignment for the buffers we need to allocate. const std::unique_ptr<BufferAssignment> assignment_; @@ -281,6 +282,10 @@ class CpuExecutable : public Executable { // If not null, XLA Runtime is enabled. std::unique_ptr<XlaRuntimeCpuExecutable> xla_runtime_executable_; + CpuExecutable(std::unique_ptr<HloModule> hlo_module, + std::unique_ptr<HloProfilePrinterData> hlo_profile_printer_data, + std::unique_ptr<HloProfileIndexMap> hlo_profile_index_map, + std::unique_ptr<BufferAssignment> assignment); CpuExecutable(const CpuExecutable&) = delete; CpuExecutable& operator=(const CpuExecutable&) = delete; };