Adding new device flag for compilation cache

If the device flag is specified, the cache will only be used for the listed device types. By default (flag unset), this change does not impact how the compilation cache works.

This flag is being added to prevent errors when running on hardware with multiple device types.
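
Example usage (a sketch: these flags are parsed from the TF_XLA_FLAGS environment variable like the other mark-for-compilation flags; the path is illustrative):

  TF_XLA_FLAGS="--tf_xla_persistent_cache_directory=/tmp/xla_cache --tf_xla_persistent_cache_device_types=GPU" python my_program.py

With a setting like this, compiled executables are persisted only when the device type being compiled for matches GPU; for any other device type the cache directory resolves to "" and nothing is saved or loaded.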

PiperOrigin-RevId: 555311811
tensorflower-gardener committed Aug 9, 2023
1 parent ee02ea9 commit fa481b4
Showing 6 changed files with 142 additions and 4 deletions.
1 change: 1 addition & 0 deletions tensorflow/compiler/jit/BUILD
@@ -357,6 +357,7 @@ cc_library(
"//tensorflow/core/tfrt/common:pjrt_util",
"//tensorflow/core/tpu:tpu_defs",
"//tensorflow/tsl/framework:device_id_utils",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/log",
"@com_google_absl//absl/log:check",
"@com_google_absl//absl/status",
6 changes: 6 additions & 0 deletions tensorflow/compiler/jit/flags.cc
@@ -157,6 +157,11 @@ void AppendMarkForCompilationPassFlagsInternal(std::vector<Flag>* flag_list) {
&mark_for_compilation_flags->tf_xla_persistent_cache_directory,
"If non-empty, JIT-compiled executables are saved to and loaded "
"from the specified file system directory path. Empty by default."),
Flag("tf_xla_persistent_cache_device_types",
&mark_for_compilation_flags->tf_xla_persistent_cache_device_types,
"If non-empty, the persistent cache will only be used for the "
"specified devices (comma separated). Each device type should be "
"able to be converted to `DeviceType`."),
Flag("tf_xla_disable_strict_signature_checks",
&mark_for_compilation_flags->tf_xla_disable_strict_signature_checks,
"If true, entires loaded into the XLA compile cache will not have "
@@ -214,6 +219,7 @@ void AllocateAndParseFlags() {
->tf_xla_disable_resource_variable_safety_checks_for_debugging = false;
mark_for_compilation_flags->tf_xla_deterministic_cluster_names = false;
mark_for_compilation_flags->tf_xla_persistent_cache_directory = "";
mark_for_compilation_flags->tf_xla_persistent_cache_device_types = "";
mark_for_compilation_flags->tf_xla_disable_strict_signature_checks = false;
mark_for_compilation_flags->tf_xla_persistent_cache_prefix =
"xla_compile_cache";
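Note: each comma-separated entry must name a device type that `DeviceType` can wrap; the tests added below exercise values built from DEVICE_XLA_GPU and DEVICE_GPU_XLA_JIT as well as the literal "GPU,CPU".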
5 changes: 5 additions & 0 deletions tensorflow/compiler/jit/flags.h
@@ -96,6 +96,11 @@ struct MarkForCompilationPassFlags {
// specified file system directory path.
std::string tf_xla_persistent_cache_directory;

// If non-empty, the persistent cache will only be used for the specified
// devices (comma separated). Each device type must be convertible to
// `DeviceType`.
std::string tf_xla_persistent_cache_device_types;

// If true, entries loaded into the XLA compile cache will not have their
// signatures checked strictly. This should generally not be disabled except
// for debugging. Defaults to false.
30 changes: 28 additions & 2 deletions tensorflow/compiler/jit/xla_platform_info.cc
@@ -22,7 +22,10 @@ limitations under the License.
#include <utility>
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/status/status.h"
#include "absl/strings/str_split.h"
#include "absl/strings/string_view.h"
#include "tensorflow/compiler/jit/device_executable_persistor.h"
#include "tensorflow/compiler/jit/flags.h"
#include "tensorflow/compiler/jit/pjrt_device_compiler_client.h"
@@ -63,8 +66,11 @@ XlaDeviceCompiler* CreateXlaDeviceCompiler(

PjRtDeviceCompiler* CreatePjRtDeviceCompiler(DeviceType compilation_device_type,
xla::PjRtClient* pjrt_client) {
std::string persistent_cache_directory =
GetPersistentCacheDirectory(compilation_device_type);

PjRtDeviceExecutablePersistor::Config persistor_config(
GetMarkForCompilationPassFlags()->tf_xla_persistent_cache_directory,
persistent_cache_directory,
GetMarkForCompilationPassFlags()->tf_xla_disable_strict_signature_checks,
GetMarkForCompilationPassFlags()->tf_xla_persistent_cache_prefix);

@@ -142,6 +148,23 @@ Status GetCompilationDeviceTypeAndPjRtClient(
}
} // namespace

std::string GetPersistentCacheDirectory(
const DeviceType& compilation_device_type) {
// If persistent cache device types are specified, ensure they include the
// compilation device type.
if (!GetMarkForCompilationPassFlags()
->tf_xla_persistent_cache_device_types.empty() &&
!absl::c_any_of(absl::StrSplit(GetMarkForCompilationPassFlags()
->tf_xla_persistent_cache_device_types,
','),
[&](absl::string_view device) {
return compilation_device_type == DeviceType(device);
})) {
return "";
}
return GetMarkForCompilationPassFlags()->tf_xla_persistent_cache_directory;
}

xla::StatusOr<std::optional<std::set<int>>> ParseVisibleDeviceList(
absl::string_view visible_device_list) {
std::set<int> gpu_ids;
@@ -166,8 +189,11 @@ xla::StatusOr<std::optional<std::set<int>>> ParseVisibleDeviceList(
Status BuildXlaDeviceCompiler(DeviceBase* device, FunctionLibraryRuntime* flr,
const XlaPlatformInfo& platform_info,
XlaDeviceCompiler** xla_device_compiler) {
std::string persistent_cache_directory =
GetPersistentCacheDirectory(platform_info.device_type());

XlaDeviceExecutablePersistor::Config persistor_config(
GetMarkForCompilationPassFlags()->tf_xla_persistent_cache_directory,
persistent_cache_directory,
GetMarkForCompilationPassFlags()->tf_xla_disable_strict_signature_checks,
GetMarkForCompilationPassFlags()->tf_xla_persistent_cache_prefix);

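The device-type gate in GetPersistentCacheDirectory above reduces to a comma-split plus an any-of comparison. A self-contained sketch of the same pattern (hypothetical helper name, not part of this commit):

#include "absl/algorithm/container.h"
#include "absl/strings/str_split.h"
#include "absl/strings/string_view.h"

// Returns true if `device_type` appears in the comma-separated `allowed`
// list. An empty list means "no restriction", mirroring the flag's default.
bool DeviceTypeAllowed(absl::string_view allowed,
                       absl::string_view device_type) {
  if (allowed.empty()) return true;
  return absl::c_any_of(
      absl::StrSplit(allowed, ','),
      [&](absl::string_view candidate) { return candidate == device_type; });
}

// e.g. DeviceTypeAllowed("GPU,CPU", "GPU") == true
//      DeviceTypeAllowed("GPU,CPU", "TPU") == false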
5 changes: 5 additions & 0 deletions tensorflow/compiler/jit/xla_platform_info.h
@@ -135,6 +135,11 @@ Status GetOrCreatePjRtDeviceCompilerAndProfiler(
// Returns information about the platform from kernel context.
XlaPlatformInfo XlaPlatformInfoFromDevice(DeviceBase* device);

// Returns the persistent cache directory for executables that target the
// given device, based on the XLA flags. Returns "" if persistent caching
// should not be used for the device.
std::string GetPersistentCacheDirectory(
const DeviceType& compilation_device_type);

// Returns allocator from platform info if non-null, or populate and return a
// pointer to the allocator adapter with allocator from context.
//
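A minimal caller-side sketch of the new helper (assumes only the declaration above; illustrative, not part of the diff):

// Resolve the effective cache directory for a GPU compilation.
std::string dir = GetPersistentCacheDirectory(DeviceType(DEVICE_GPU));
if (dir.empty()) {
  // Persistent caching is disabled for this device type; executables are
  // neither saved to nor loaded from disk.
}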
99 changes: 97 additions & 2 deletions tensorflow/compiler/jit/xla_platform_info_test.cc
@@ -45,6 +45,10 @@ class XlaPlatformInfoTest : public ::testing::Test {
protected:
void SetUp() override {
tensorflow::GetXlaDeviceFlags()->tf_xla_enable_xla_devices = true;
tensorflow::GetMarkForCompilationPassFlags()
->tf_xla_persistent_cache_directory = "";
tensorflow::GetMarkForCompilationPassFlags()
->tf_xla_persistent_cache_device_types = "";
}

DeviceSetup device_setup_;
@@ -73,6 +77,29 @@ TEST_F(XlaPlatformInfoTest, BuildXlaDeviceCompilerXlaDeviceMetadata) {
EXPECT_EQ(xla_device_compiler->client(), metadata->client());
}

TEST_F(XlaPlatformInfoTest, BuildXlaDeviceCompilerXlaDeviceCacheEnabled) {
tensorflow::GetMarkForCompilationPassFlags()
->tf_xla_persistent_cache_directory = "/tmp/xla_cache";
tensorflow::GetMarkForCompilationPassFlags()
->tf_xla_persistent_cache_device_types = DEVICE_XLA_GPU;
device_setup_.AddDevicesAndSetUp({DEVICE_XLA_GPU});

Device* device = device_setup_.GetDevice(DEVICE_XLA_GPU);
const XlaDevice::Metadata* metadata = nullptr;
TF_CHECK_OK(XlaDevice::GetMetadataFromDevice(device, &metadata));
XlaPlatformInfo platform_info = XlaPlatformInfoFromDevice(device);

XlaDeviceCompiler* xla_device_compiler = nullptr;
TF_EXPECT_OK(BuildXlaDeviceCompiler(device, device_setup_.flr(),
platform_info, &xla_device_compiler));
core::ScopedUnref xla_device_compiler_ref(xla_device_compiler);

EXPECT_EQ(xla_device_compiler->device_type(), metadata->jit_device_type());
EXPECT_EQ(xla_device_compiler->client(), metadata->client());
EXPECT_EQ(xla_device_compiler->persistor()->persistent_cache_directory(),
"/tmp/xla_cache");
}

TEST_F(XlaPlatformInfoTest, BuildXlaDeviceCompilerNonXlaDevice) {
device_setup_.AddDevicesAndSetUp({DEVICE_GPU});
Device* device = device_setup_.GetDevice(DEVICE_GPU);
@@ -115,7 +142,12 @@ TEST_F(XlaPlatformInfoTest, GetOrCreatePjRtDeviceCompilerAndProfilerXlaDevice) {
EXPECT_EQ(pjrt_device_compiler->client(), pjrt_client);
}

TEST_F(XlaPlatformInfoTest, GetOrCreatePjRtDeviceCompilerAndProfilerGpuDevice) {
TEST_F(XlaPlatformInfoTest,
GetOrCreatePjRtDeviceCompilerAndProfilerGpuDeviceCacheEnabled) {
tensorflow::GetMarkForCompilationPassFlags()
->tf_xla_persistent_cache_directory = "/tmp/xla_cache";
tensorflow::GetMarkForCompilationPassFlags()
->tf_xla_persistent_cache_device_types = DEVICE_GPU_XLA_JIT;
device_setup_.AddDevicesAndSetUp({DEVICE_GPU});
Device* device = device_setup_.GetDevice(DEVICE_GPU);
XlaPlatformInfo platform_info = XlaPlatformInfoFromDevice(device);
@@ -131,6 +163,8 @@ TEST_F(XlaPlatformInfoTest, GetOrCreatePjRtDeviceCompilerAndProfilerGpuDevice) {
TF_EXPECT_OK(GetOrCreatePjRtDeviceCompilerAndProfiler(
ctx, platform_info, device_setup_.flr(), &pjrt_device_compiler,
&profiler));
EXPECT_EQ(pjrt_device_compiler->persistor()->persistent_cache_directory(),
"/tmp/xla_cache");
core::ScopedUnref pjrt_device_compiler_ref(pjrt_device_compiler);
core::ScopedUnref profiler_ref(profiler);
}
@@ -161,9 +195,42 @@ TEST_F(XlaPlatformInfoTest, BuildXlaDeviceCompilerTpuDevice) {
EXPECT_EQ(xla_device_compiler->client(), nullptr);
}

TEST_F(XlaPlatformInfoTest, BuildXlaDeviceCompilerNoCompilationCache) {
DeviceType compilation_device_type = DeviceType(DEVICE_TPU_XLA_JIT);
tensorflow::GetMarkForCompilationPassFlags()
->tf_xla_persistent_cache_directory = "/tmp/xla_cache";
tensorflow::GetMarkForCompilationPassFlags()
->tf_xla_persistent_cache_device_types = DEVICE_XLA_GPU;

// Instead of creating/initializing a TPU device, create a dummy platform_info
// and use a nullptr for Device for testing purposes. Only
// XlaPlatformInfo::device_type() is needed to build the appropriate
// XlaDeviceCompiler.
Device* device = nullptr;
XlaPlatformInfo platform_info(DeviceType(DEVICE_TPU), /*platform_id=*/nullptr,
/*xla_device_metadata=*/nullptr,
/*pjrt_device_metadata=*/nullptr,
/*device_allocator=*/nullptr);

XlaDeviceCompiler* xla_device_compiler = nullptr;
TF_EXPECT_OK(BuildXlaDeviceCompiler(device, nullptr, platform_info,
&xla_device_compiler));
core::ScopedUnref xla_device_compiler_ref(xla_device_compiler);

EXPECT_EQ(xla_device_compiler->device_type(), compilation_device_type);
// Check to make sure compilation cache path is empty.
EXPECT_TRUE(
xla_device_compiler->persistor()->persistent_cache_directory().empty());
}

// TODO(b/255826209): Look into using an actual TPU device for the unit test,
// and move this out of OSS.
TEST_F(XlaPlatformInfoTest, GetOrCreatePjRtDeviceCompilerAndProfilerTpuDevice) {
TEST_F(XlaPlatformInfoTest,
GetOrCreatePjRtDeviceCompilerAndProfilerTpuDeviceNoCompilationCache) {
tensorflow::GetMarkForCompilationPassFlags()
->tf_xla_persistent_cache_directory = "/tmp/xla_cache";
tensorflow::GetMarkForCompilationPassFlags()
->tf_xla_persistent_cache_device_types = DEVICE_GPU_XLA_JIT;
DeviceType device_type = DeviceType(DEVICE_TPU);
DeviceType compilation_device_type = DeviceType(DEVICE_TPU_XLA_JIT);
// Use a CPU PjRtClient instead of a TPU one just for testing whether
@@ -196,6 +263,34 @@ TEST_F(XlaPlatformInfoTest, GetOrCreatePjRtDeviceCompilerAndProfilerTpuDevice) {

EXPECT_EQ(pjrt_device_compiler->device_type(), compilation_device_type);
EXPECT_EQ(pjrt_device_compiler->client(), pjrt_client);
EXPECT_TRUE(
pjrt_device_compiler->persistor()->persistent_cache_directory().empty());
}

TEST_F(XlaPlatformInfoTest, GetPersistentCacheDirectoryMultiple) {
tensorflow::GetMarkForCompilationPassFlags()
->tf_xla_persistent_cache_directory = "/tmp/xla_cache";
tensorflow::GetMarkForCompilationPassFlags()
->tf_xla_persistent_cache_device_types = "GPU,CPU";
DeviceType device_gpu = DeviceType(DEVICE_GPU);
EXPECT_EQ(GetPersistentCacheDirectory(device_gpu), "/tmp/xla_cache");
DeviceType device_cpu = DeviceType(DEVICE_CPU);
EXPECT_EQ(GetPersistentCacheDirectory(device_cpu), "/tmp/xla_cache");
DeviceType device_tpu = DeviceType(DEVICE_TPU);
EXPECT_TRUE(GetPersistentCacheDirectory(device_tpu).empty());
}

TEST_F(XlaPlatformInfoTest, GetPersistentCacheDirectoryNoDeviceTypes) {
tensorflow::GetMarkForCompilationPassFlags()
->tf_xla_persistent_cache_directory = "/tmp/xla_cache";
tensorflow::GetMarkForCompilationPassFlags()
->tf_xla_persistent_cache_device_types = "";
DeviceType device_gpu = DeviceType(DEVICE_GPU);
EXPECT_EQ(GetPersistentCacheDirectory(device_gpu), "/tmp/xla_cache");
DeviceType device_cpu = DeviceType(DEVICE_CPU);
EXPECT_EQ(GetPersistentCacheDirectory(device_cpu), "/tmp/xla_cache");
DeviceType device_tpu = DeviceType(DEVICE_TPU);
EXPECT_EQ(GetPersistentCacheDirectory(device_tpu), "/tmp/xla_cache");
}

} // namespace
