diff --git a/builddeps/test-requirements.txt b/builddeps/test-requirements.txt index dc22d7d2..347925f1 100644 --- a/builddeps/test-requirements.txt +++ b/builddeps/test-requirements.txt @@ -1,11 +1,11 @@ absl-py numpy +jax jaxlib https://github.com/wsmoses/jax-md/archive/45059b8f63dad0b5cb171feafff71b82162487e7.tar.gz # maxtext can't be installed concurrently, but installing it fixes # https://github.com/wsmoses/maxtext/archive/bc50722be7d89e4003bd830b80e4ac968be658eb.tar.gz -jax[cuda12_pip]; sys_platform == 'linux' -jax; sys_platform != 'linux' +jax-cuda12-plugin[with_cuda]; sys_platform == 'linux' requests; sys_platform == 'linux' # -f https://storage.googleapis.com/jax-releases/libtpu_releases.html # libtpu-nightly == 0.1.dev20240729; sys_platform == 'linux'