diff --git a/jaxlib/jax.bzl b/jaxlib/jax.bzl index d1decfd3a885..7df848aad843 100644 --- a/jaxlib/jax.bzl +++ b/jaxlib/jax.bzl @@ -267,10 +267,10 @@ def jax_test( deps = [ "//jax", "//jax:test_util", - ] + deps + if_building_jaxlib(["//jaxlib/cuda:gpu_only_test_deps"]) + select({ - "//jax:enable_build_cuda_plugin_from_source": ["//jax_plugins:gpu_plugin_only_test_deps"], - "//conditions:default": [], - }), + ] + deps + if_building_jaxlib([ + "//jaxlib/cuda:gpu_only_test_deps", + "//jax_plugins:gpu_plugin_only_test_deps", + ]), data = data, shard_count = test_shards, tags = test_tags,