diff --git a/userbenchmark/release-test/setup_env.sh b/userbenchmark/release-test/setup_env.sh index eb34e415e7..07bf43c090 100644 --- a/userbenchmark/release-test/setup_env.sh +++ b/userbenchmark/release-test/setup_env.sh @@ -29,14 +29,19 @@ conda update --all -y sudo ln -sf /usr/local/cuda-${CUDA_VERSION} /usr/local/cuda conda uninstall -y pytorch torchvision pytorch-cuda -conda uninstall -y pytorch torchvision cudatoolkit +conda uninstall -y pytorch torchvision # make sure we have a clean environment without pytorch pip uninstall -y torch torchvision # install magma conda install -y -c pytorch ${MAGMA_VERSION} -conda install --force-reinstall -v -y pytorch=${PYTORCH_VERSION} torchvision pytorch-cuda=${CUDA_VERSION} -c ${PYTORCH_CHANNEL} -c nvidia -pip install --force-reinstall numpy + +# install pip version of pytorch and torchvision +if [[ ${PYTORCH_CHANNEL} == "pytorch-test" ]]; then + pip3 install torch torchvision --index-url https://download.pytorch.org/whl/test/cu121 +else + pip3 install torch torchvision +fi python -c 'import torch; print(torch.__version__); print(torch.version.git_version)'