From 9b4bfbf688f5a778966640f3bfc1f8b5d92b7344 Mon Sep 17 00:00:00 2001 From: nsosio Date: Thu, 16 Nov 2023 16:35:51 +0000 Subject: [PATCH] replace hardcoded cuda version with cuda from system --- benchmark.sh | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/benchmark.sh b/benchmark.sh index d2ca1d3c..c6204a0b 100755 --- a/benchmark.sh +++ b/benchmark.sh @@ -87,6 +87,20 @@ check_jq() { fi } +get_torch_cuda_version() { + # Get the full CUDA version using nvcc + CUDA_VERSION=$(nvcc --version | grep "release" | awk '{print $6}') + + # Remove dots from the CUDA version + CUDA_VERSION_NUMERIC=$(echo "${CUDA_VERSION}" | tr -d '.') + + # Set TORCH_CUDA_VERSION variable + TORCH_CUDA_VERSION="cu${CUDA_VERSION_NUMERIC}" + + # Return the dynamically set variable + echo "${TORCH_CUDA_VERSION}" +} + # Function to download models download_models() { echo -e "\nDownloading models..." @@ -126,7 +140,7 @@ run_benchmarks() { if [ "$DEVICE" == "cpu" ] || [ "$USE_NVIDIA" == true ]; then # Run Rust benchmarks if [ "$DEVICE" == "gpu" ] && [ "$PLATFORM" != "Darwin" ]; then - TORCH_CUDA_VERSION=cu117 + TORCH_CUDA_VERSION=$(get_torch_cuda_version) fi cargo run --release --bin sample \ --manifest-path="$DIR/rust_bench/llama2-burn/Cargo.toml" \