diff --git a/Dockerfile b/Dockerfile index a843bda..ac8a04a 100644 --- a/Dockerfile +++ b/Dockerfile @@ -29,7 +29,7 @@ RUN python3.11 -m pip install -r requirements.txt # Install JAX with CUDA support. HPC is on CUDA 11, and JAX 0.2.25 is the latest version for that RUN python3.11 -m pip install --upgrade \ - "jax[cuda11_pip]==0.4.25" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html \ + "jax[cuda11_pip]==0.4.24" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html \ optax # Set the environment variables