Skip to content

Commit

Permalink
PR tensorflow#15285: Ensure only one device is visible in pjrt_c_api_…
Browse files Browse the repository at this point in the history
…gpu_test

Imported from GitHub PR openxla/xla#15285

The test fails when the number of available devices is more than 1. This patch fixes that by ensuring that only one device is visible to the test.
Copybara import of the project:

--
587bebe70c7d298008eff0c65dfcfa901e1fe21a by Shraiysh Vaishay <svaishay@nvidia.com>:

Ensure only one device is visible in pjrt_c_api_gpu_test

The test fails when the number of available devices is more than 1.
This patch fixes that by ensuring that only one device is visible
to the test.

Merging this change closes tensorflow#15285

PiperOrigin-RevId: 655861635
  • Loading branch information
shraiysh authored and tensorflower-gardener committed Jul 25, 2024
1 parent a3a1ce3 commit f4107ff
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 1 deletion.
1 change: 1 addition & 0 deletions third_party/xla/xla/pjrt/c/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -376,6 +376,7 @@ xla_test(
"//xla:literal",
"//xla:literal_util",
"//xla:shape_util",
"//xla/client:client_library",
"//xla/ffi:execution_context",
"//xla/ffi:ffi_api",
"//xla/ffi:type_id_registry",
Expand Down
5 changes: 4 additions & 1 deletion third_party/xla/xla/pjrt/c/pjrt_c_api_gpu_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ limitations under the License.
#include "absl/container/flat_hash_map.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "xla/client/client_library.h"
#include "xla/ffi/api/ffi.h"
#include "xla/ffi/execution_context.h"
#include "xla/ffi/ffi_api.h"
Expand Down Expand Up @@ -215,6 +216,7 @@ TEST(PjrtCApiGpuKVStoreTest, CreateClientWithKVCallback) {
auto kv_store = std::make_shared<xla::InMemoryKeyValueStore>();
std::shared_ptr<::pjrt::PJRT_KeyValueCallbackData> kv_callback_data =
::pjrt::ConvertToCKeyValueCallbacks(kv_store);
xla::ClientLibrary::DestroyLocalInstances();

int num_nodes = 2;
std::vector<std::thread> threads;
Expand All @@ -225,7 +227,8 @@ TEST(PjrtCApiGpuKVStoreTest, CreateClientWithKVCallback) {
kv_store = kv_store] {
absl::flat_hash_map<std::string, xla::PjRtValueType> options = {
{"num_nodes", static_cast<int64_t>(num_nodes)},
{"node_id", static_cast<int64_t>(i)}};
{"node_id", static_cast<int64_t>(i)},
{"visible_devices", std::vector<int64_t>({0})}};
TF_ASSERT_OK_AND_ASSIGN(std::vector<PJRT_NamedValue> c_options,
::pjrt::ConvertToPjRtNamedValueList(options));
TF_ASSERT_OK_AND_ASSIGN(
Expand Down

0 comments on commit f4107ff

Please sign in to comment.