Skip to content

Commit

Permalink
[stream_executor] NFC: Clean up xla/stream_executor:executor_cache de…
Browse files Browse the repository at this point in the history
…pendencies and fix warnings

PiperOrigin-RevId: 567520429
  • Loading branch information
ezhulenev authored and tensorflower-gardener committed Sep 22, 2023
1 parent acd860c commit b73d62a
Show file tree
Hide file tree
Showing 8 changed files with 63 additions and 91 deletions.
5 changes: 2 additions & 3 deletions tensorflow/c/experimental/stream_executor/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -45,10 +45,9 @@ cc_library(
"//tensorflow/core/common_runtime/device:device_utils",
"//tensorflow/core/platform:strcat",
"@com_google_absl//absl/functional:any_invocable",
"@local_xla//xla/stream_executor:executor_cache",
"@local_xla//xla/stream_executor",
"@local_xla//xla/stream_executor:multi_platform_manager",
"@local_xla//xla/stream_executor:platform",
"@local_xla//xla/stream_executor:stream_executor_pimpl",
],
)

Expand All @@ -66,7 +65,7 @@ cc_library(
"//tensorflow/c:c_api_macros",
"//tensorflow/c:tf_status",
"//tensorflow/c:tf_status_helper",
"@local_xla//xla/stream_executor:executor_cache",
"@local_xla//xla/stream_executor",
],
)

Expand Down
3 changes: 1 addition & 2 deletions tensorflow/c/experimental/stream_executor/stream_executor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,7 @@ limitations under the License.
#include "xla/stream_executor/multi_platform_manager.h"
#include "xla/stream_executor/platform.h"
#include "xla/stream_executor/stream.h"
#include "xla/stream_executor/stream_executor_internal.h"
#include "xla/stream_executor/stream_executor_pimpl.h"
#include "xla/stream_executor/stream_executor.h"
#include "tensorflow/core/common_runtime/device/device_utils.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/errors.h"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ limitations under the License.
#include "tensorflow/c/tf_status_helper.h"
#include "xla/stream_executor/executor_cache.h"
#include "xla/stream_executor/platform.h"
#include "xla/stream_executor/stream_executor.h"

namespace stream_executor {

Expand Down
1 change: 0 additions & 1 deletion third_party/xla/xla/service/gpu/runtime/collectives.cc
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@ limitations under the License.
#include "xla/service/gpu/runtime/support.h"
#include "xla/service/gpu/thunk.h"
#include "xla/service/service_executable_run_options.h"
#include "xla/stream_executor/stream.h"

namespace xla {
namespace gpu {
Expand Down
1 change: 0 additions & 1 deletion third_party/xla/xla/service/gpu/runtime/collectives.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ limitations under the License.
#include "xla/runtime/custom_call_registry.h"
#include "xla/service/gpu/nccl_collective_thunk.h"
#include "xla/stream_executor/event.h"
#include "xla/stream_executor/stream.h"

namespace xla {
namespace gpu {
Expand Down
51 changes: 7 additions & 44 deletions third_party/xla/xla/stream_executor/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,7 @@ cc_library(
"//xla/stream_executor/platform",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/container:node_hash_map",
"@com_google_absl//absl/functional:any_invocable",
"@com_google_absl//absl/log:check",
"@com_google_absl//absl/memory",
Expand Down Expand Up @@ -406,63 +407,25 @@ cc_library(

cc_library(
name = "executor_cache",
srcs = [
"device_description.h",
"device_memory.h",
"device_options.h",
"event.h",
"executor_cache.cc",
"launch_dim.h",
"plugin.h",
"plugin_registry.h",
"stream_executor_pimpl.h",
"temporary_device_memory.h",
"temporary_memory_manager.h",
],
hdrs = [
"blas.h",
"executor_cache.h",
"fft.h",
"kernel.h",
"kernel_cache_config.h",
"kernel_spec.h",
"platform.h",
"stream.h",
"stream_executor_internal.h",
"trace_listener.h",
],
srcs = ["executor_cache.cc"],
hdrs = ["executor_cache.h"],
visibility = ["//visibility:public"],
deps = [
":allocator_stats",
":data_type",
":device_description",
":device_description_proto_cc",
":device_memory",
":device_options",
":fft",
":kernel_cache_config",
":kernel_spec",
":launch_dim",
":plugin",
":stream_executor_headers",
":platform",
":stream_executor_pimpl_header",
"//xla/stream_executor/platform",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/container:node_hash_map",
"@com_google_absl//absl/functional:any_invocable",
"@com_google_absl//absl/log",
"@com_google_absl//absl/log:check",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/status",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
"@com_google_absl//absl/synchronization",
"@com_google_absl//absl/types:optional",
"@com_google_absl//absl/types:span",
"@local_tsl//tsl/platform:env",
"@local_tsl//tsl/platform:errors",
"@local_tsl//tsl/platform:logging",
"@local_tsl//tsl/platform:status",
"@local_tsl//tsl/platform:statusor",
"@local_tsl//tsl/protobuf:dnn_proto_cc",
],
)

Expand Down
65 changes: 34 additions & 31 deletions third_party/xla/xla/stream_executor/executor_cache.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,20 +16,27 @@ limitations under the License.
#include "xla/stream_executor/executor_cache.h"

#include <memory>
#include <utility>

#include "absl/log/log.h"
#include "absl/status/status.h"
#include "absl/strings/str_format.h"
#include "absl/synchronization/mutex.h"
#include "xla/stream_executor/platform.h"
#include "xla/stream_executor/stream_executor_pimpl.h"
#include "tsl/platform/statusor.h"

namespace stream_executor {

ExecutorCache::ExecutorCache() = default;
ExecutorCache::~ExecutorCache() { DestroyAllExecutors(); }

tsl::StatusOr<StreamExecutor*> ExecutorCache::GetOrCreate(
const StreamExecutorConfig& config,
const std::function<ExecutorFactory>& factory) {
const StreamExecutorConfig& config, const ExecutorFactory& factory) {
// In the fast path case, the cache already has an entry and we can just
// return after Get() which only takes a shared lock and not a unique lock.
// If we need to create, we take a unique lock on cache_.
auto fast_result = Get(config);
if (fast_result.ok()) {
if (auto fast_result = Get(config); fast_result.ok()) {
return fast_result;
}

Expand All @@ -38,7 +45,7 @@ tsl::StatusOr<StreamExecutor*> ExecutorCache::GetOrCreate(
absl::MutexLock lock{&mutex_};
entry = &cache_[config.ordinal];
// Release the map lock; the address of 'entry' is stable because
// std::map guarantees reference stability.
// absl::node_hash_map guarantees reference stability.
}

// Acquire the per-Entry mutex without holding the map mutex. Initializing
Expand Down Expand Up @@ -70,47 +77,43 @@ tsl::StatusOr<StreamExecutor*> ExecutorCache::Get(
{
absl::ReaderMutexLock lock{&mutex_};

{
if (config.gpu_stream) {
// Need to iterate through all stored executors.
for (auto& [ordinal, e] : cache_) {
absl::ReaderMutexLock l{&e.configurations_mutex};
for (auto& [c, executor] : e.configurations) {
if (executor->FindAllocatedStream(config.gpu_stream)) {
return executor.get();
}
// If gpu stream is not nullptr we have to find StreamExecutor that owns it,
// and return NOT_FOUND error if we can't find it.
if (config.gpu_stream) {
for (auto& [ordinal, e] : cache_) {
absl::ReaderMutexLock l{&e.configurations_mutex};
for (auto& [c, executor] : e.configurations) {
if (executor->FindAllocatedStream(config.gpu_stream)) {
return executor.get();
}
}
return tsl::Status(
absl::StatusCode::kNotFound,
absl::StrFormat("No executors own stream %p", config.gpu_stream));
}
return absl::NotFoundError(
absl::StrFormat("No executors own stream %p", config.gpu_stream));
}

auto it = cache_.find(config.ordinal);
if (it != cache_.end()) {
if (auto it = cache_.find(config.ordinal); it != cache_.end()) {
entry = &it->second;
} else {
return tsl::Status(
absl::StatusCode::kNotFound,
absl::StrFormat("No executors registered for ordinal %d",
config.ordinal));
return absl::NotFoundError(absl::StrFormat(
"No executors registered for ordinal %d", config.ordinal));
}
}

absl::ReaderMutexLock lock{&entry->configurations_mutex};
if (entry->configurations.empty()) {
return tsl::Status(absl::StatusCode::kNotFound,
absl::StrFormat("No executors registered for ordinal %d",
config.ordinal));
return absl::NotFoundError(absl::StrFormat(
"No executors registered for ordinal %d", config.ordinal));
}
for (const auto& iter : entry->configurations) {
if (iter.first.device_options == config.device_options) {

for (auto& [entry_config, entry_executor] : entry->configurations) {
if (entry_config.device_options == config.device_options) {
VLOG(2) << "hit in cache for device ordinal " << config.ordinal;
return iter.second.get();
return entry_executor.get();
}
}
return tsl::Status(absl::StatusCode::kNotFound,
"No executor found with a matching config.");

return absl::NotFoundError("No executor found with a matching config.");
}

void ExecutorCache::DestroyAllExecutors() {
Expand Down
27 changes: 18 additions & 9 deletions third_party/xla/xla/stream_executor/executor_cache.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,28 +17,37 @@ limitations under the License.
#define XLA_STREAM_EXECUTOR_EXECUTOR_CACHE_H_

#include <functional>
#include <map>
#include <memory>
#include <utility>
#include <vector>

#include "absl/base/thread_annotations.h"
#include "absl/container/node_hash_map.h"
#include "absl/synchronization/mutex.h"
#include "xla/stream_executor/stream_executor_pimpl.h"
#include "tsl/platform/status.h"
#include "xla/stream_executor/platform.h"
#include "xla/stream_executor/platform/port.h"
#include "tsl/platform/statusor.h"

namespace stream_executor {

// Forward declare.
class StreamExecutor;

// Utility class to allow Platform objects to manage cached StreamExecutors.
// Thread-safe.
class ExecutorCache {
public:
ExecutorCache() {}
using ExecutorFactory =
std::function<tsl::StatusOr<std::unique_ptr<StreamExecutor>>()>;

ExecutorCache();
~ExecutorCache();

// Looks up 'config' in the cache. Returns a pointer to the existing executor,
// if already present, or creates it using 'factory', if it does not.
// Factories may be executed concurrently for different device ordinals.
typedef tsl::StatusOr<std::unique_ptr<StreamExecutor>> ExecutorFactory();
tsl::StatusOr<StreamExecutor*> GetOrCreate(
const StreamExecutorConfig& config,
const std::function<ExecutorFactory>& factory);
tsl::StatusOr<StreamExecutor*> GetOrCreate(const StreamExecutorConfig& config,
const ExecutorFactory& factory);

// Returns a pointer to the described executor (if one with a matching config
// has been created), or a NOT_FOUND status.
Expand Down Expand Up @@ -70,7 +79,7 @@ class ExecutorCache {
// We key off of ordinal (instead of just looking up all fields in the
// StreamExecutorConfig) for a slight improvement in lookup time.
absl::Mutex mutex_;
std::map<int, Entry> cache_ ABSL_GUARDED_BY(mutex_);
absl::node_hash_map<int, Entry> cache_ ABSL_GUARDED_BY(mutex_);

SE_DISALLOW_COPY_AND_ASSIGN(ExecutorCache);
};
Expand Down

0 comments on commit b73d62a

Please sign in to comment.