From a0393c0d8053996906498f3cdb48c296489251db Mon Sep 17 00:00:00 2001 From: smjleo Date: Tue, 15 Oct 2024 11:56:47 +0100 Subject: [PATCH] does nvrtc work as deps? --- .buildkite/gpu_pipeline.yml | 1 + builddeps/test-requirements.txt | 1 + test/BUILD | 2 +- 3 files changed, 3 insertions(+), 1 deletion(-) diff --git a/.buildkite/gpu_pipeline.yml b/.buildkite/gpu_pipeline.yml index 760d7707..2df91181 100644 --- a/.buildkite/gpu_pipeline.yml +++ b/.buildkite/gpu_pipeline.yml @@ -51,6 +51,7 @@ steps: export LD_LIBRARY_PATH="`pwd`/bazel-bin/test/test.runfiles/pypi_nvidia_cublas_cu12/site-packages/nvidia/cublas/lib:\$LD_LIBRARY_PATH" export LD_LIBRARY_PATH="`pwd`/bazel-bin/test/llama.runfiles/pypi_nvidia_cuda_cupti_cu12/site-packages/nvidia/cuda_cupti/lib:\$LD_LIBRARY_PATH" export LD_LIBRARY_PATH="`pwd`/bazel-bin/test/llama.runfiles/pypi_nvidia_cuda_runtime_cu12/site-packages/nvidia/cuda_runtime/lib:\$LD_LIBRARY_PATH" + export LD_LIBRARY_PATH="`pwd`/bazel-bin/test/llama.runfiles/pypi_nvidia_cuda_nvrtc_cu12/site-packages/nvidia/cuda_nvrtc/lib:\$LD_LIBRARY_PATH" export PATH="`pwd`/bazel-bin/test/llama.runfiles/pypi_nvidia_cuda_nvcc_cu12/site-packages/nvidia/cuda_nvcc/bin:\$PATH" export TF_CPP_MIN_LOG_LEVEL=0 diff --git a/builddeps/test-requirements.txt b/builddeps/test-requirements.txt index 569c677b..1e9ece29 100644 --- a/builddeps/test-requirements.txt +++ b/builddeps/test-requirements.txt @@ -5,5 +5,6 @@ jaxlib https://github.com/wsmoses/jax-md/archive/1188490610b95023f8a51166c3f6b92da31e78fe.tar.gz jax[cuda12_pip]; sys_platform == 'linux' requests; sys_platform == 'linux' +nvidia-cuda-nvrtc-cu12; sys_platform == 'linux' # -f https://storage.googleapis.com/jax-releases/libtpu_releases.html # libtpu-nightly == 0.1.dev20240729; sys_platform == 'linux' diff --git a/test/BUILD b/test/BUILD index 8b3eedad..37651097 100644 --- a/test/BUILD +++ b/test/BUILD @@ -63,7 +63,7 @@ TEST_DEPS = [ "@pypi_absl_py//:pkg", ] + select({ ":use_tpu": ["@pypi_libtpu_nightly//:pkg", "@pypi_requests//:pkg"], - "@bazel_tools//src/conditions:linux_x86_64": ["@pypi_jax_cuda12_plugin//:pkg"], + "@bazel_tools//src/conditions:linux_x86_64": ["@pypi_jax_cuda12_plugin//:pkg", "@pypi_nvidia_cuda_nvrtc_cu12//:pkg"], "//conditions:default": [] })