diff --git a/Dockerfile b/Dockerfile index ac8a04a..86b3774 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,5 +1,5 @@ # Description: Dockerfile for JAX with CUDA support -FROM nvidia/cuda:11.1.1-cudnn8-devel-ubuntu20.04 +FROM nvidia/cuda:11.7.1-cudnn8-devel-ubuntu20.04 # Set the working directory WORKDIR /workspace @@ -29,7 +29,7 @@ RUN python3.11 -m pip install -r requirements.txt # Install JAX with CUDA support. HPC is on CUDA 11, and JAX 0.2.25 is the latest version for that RUN python3.11 -m pip install --upgrade \ - "jax[cuda11_pip]==0.4.24" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html \ + "jax[cuda11_pip]==0.4.25" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html \ optax # Set the environment variables