Skip to content

Commit

Permalink
[TSL] Remove TSL_STATIC_THREAD_LOCAL_POD macro.
Browse files (browse the repository at this point in the history)
PiperOrigin-RevId: 572168960
  • Loading branch information
chsigg authored and tensorflower-gardener committed Oct 10, 2023
1 parent 3effef2 commit c66f8f6
Show file tree
Hide file tree
Showing 6 changed files with 14 additions and 72 deletions.
8 changes: 0 additions & 8 deletions third_party/xla/third_party/tsl/tsl/platform/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -1905,11 +1905,3 @@ tsl_cc_test(
"//tsl/lib/core:status_test_util",
],
)

cc_library(
name = "static_threadlocal",
hdrs = [
"static_threadlocal.h",
],
visibility = ["//visibility:public"],
)
42 changes: 0 additions & 42 deletions third_party/xla/third_party/tsl/tsl/platform/static_threadlocal.h

This file was deleted.

1 change: 0 additions & 1 deletion third_party/xla/xla/stream_executor/cuda/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,6 @@ cc_library(
"@local_tsl//tsl/cuda",
"@local_tsl//tsl/cuda:cudart",
"@local_tsl//tsl/platform:env",
"@local_tsl//tsl/platform:static_threadlocal",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings:str_format",
Expand Down
17 changes: 7 additions & 10 deletions third_party/xla/xla/stream_executor/cuda/cuda_driver.cc
Original file line number Diff line number Diff line change
Expand Up @@ -46,13 +46,12 @@ limitations under the License.
#include "tsl/platform/errors.h"
#include "tsl/platform/logging.h"
#include "tsl/platform/stacktrace.h"
#include "tsl/platform/static_threadlocal.h"
#include "tsl/platform/status.h"
#include "tsl/platform/threadpool.h"

bool FLAGS_gpuexec_cuda_driver_inject_init_error = false;
bool FLAGS_gpuexec_cuda_sync_around_driver_calls = false;
bool FLAGS_gpuexec_cuda_device_0_only = false;
static constexpr bool FLAGS_gpuexec_cuda_driver_inject_init_error = false;
static constexpr bool FLAGS_gpuexec_cuda_sync_around_driver_calls = false;
static constexpr bool FLAGS_gpuexec_cuda_device_0_only = false;

#define RETURN_IF_CUDA_RES_ERROR(expr, ...) \
do { \
Expand Down Expand Up @@ -135,20 +134,18 @@ void SynchronizeOrDie() {
"Synchronize fail: ", tsl::CurrentStackTrace());
}

struct ThreadLocalData {
thread_local struct ThreadLocalData {
int64_t id;
GpuContext* context; // Only valid if id == a known good context.
int depth;
};

TSL_STATIC_THREAD_LOCAL_POD(ThreadLocalData, tls_data);
} tls_data = {};

} // namespace

ScopedActivateContext::ScopedActivateContext(GpuContext* cuda_context) {
if (FLAGS_gpuexec_cuda_sync_around_driver_calls) SynchronizeOrDie();

auto* tls = &tls_data.get();
auto* tls = &tls_data;

// If this is an outermost scope, we must not assume that the CUDA context has
// been left in the same state we left it. Other code may have run on this
Expand Down Expand Up @@ -187,7 +184,7 @@ ScopedActivateContext::ScopedActivateContext(GpuContext* cuda_context) {
ScopedActivateContext::~ScopedActivateContext() {
if (FLAGS_gpuexec_cuda_sync_around_driver_calls) SynchronizeOrDie();

auto* tls = &tls_data.get();
auto* tls = &tls_data;

if (kVerifyGpuContext) {
// Note that if kVerifyGpuContext is used, and contexts are deleted, it's
Expand Down
1 change: 0 additions & 1 deletion third_party/xla/xla/stream_executor/rocm/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,6 @@ cc_library(
"@local_tsl//tsl/platform:env",
"@local_tsl//tsl/platform:numbers",
"@local_tsl//tsl/platform:stacktrace",
"@local_tsl//tsl/platform:static_threadlocal",
]),
)

Expand Down
17 changes: 7 additions & 10 deletions third_party/xla/xla/stream_executor/rocm/rocm_driver.cc
Original file line number Diff line number Diff line change
Expand Up @@ -38,12 +38,11 @@ limitations under the License.
#include "tsl/platform/logging.h"
#include "tsl/platform/numbers.h"
#include "tsl/platform/stacktrace.h"
#include "tsl/platform/static_threadlocal.h"
#include "tsl/platform/threadpool.h"

bool FLAGS_gpuexec_rocm_driver_inject_init_error = false;
bool FLAGS_gpuexec_rocm_sync_around_driver_calls = false;
bool FLAGS_gpuexec_rocm_device_0_only = false;
static constexpr bool FLAGS_gpuexec_rocm_driver_inject_init_error = false;
static constexpr bool FLAGS_gpuexec_rocm_sync_around_driver_calls = false;
static constexpr bool FLAGS_gpuexec_rocm_device_0_only = false;

#define RETURN_IF_ROCM_ERROR(expr, ...) \
do { \
Expand Down Expand Up @@ -128,20 +127,18 @@ void SynchronizeOrDie() {
}
}

struct ThreadLocalData {
thread_local struct ThreadLocalData {
int current_device_ordinal;
GpuContext* context; // Only valid if id == a known good context.
int depth;
};

TSL_STATIC_THREAD_LOCAL_POD(ThreadLocalData, tls_data);
} tls_data = {};

} // namespace

ScopedActivateContext::ScopedActivateContext(GpuContext* hip_context) {
if (FLAGS_gpuexec_rocm_sync_around_driver_calls) SynchronizeOrDie();

auto* tls = &tls_data.get();
auto* tls = &tls_data;
if (tls->depth == 0) {
VLOG(3) << "ScopedActivateContext switching to "
<< hip_context->device_ordinal();
Expand Down Expand Up @@ -177,7 +174,7 @@ ScopedActivateContext::ScopedActivateContext(GpuContext* hip_context) {
ScopedActivateContext::~ScopedActivateContext() {
if (FLAGS_gpuexec_rocm_sync_around_driver_calls) SynchronizeOrDie();

auto* tls = &tls_data.get();
auto* tls = &tls_data;

if (kVerifyGpuContext) {
CHECK_EQ(CurrentContext(),
Expand Down

0 comments on commit c66f8f6

Please sign in to comment.