Skip to content

Commit

Permalink
replace hardcoded cuda version with cuda from system
Browse files Browse the repository at this point in the history
  • Loading branch information
nsosio committed Nov 16, 2023
1 parent b32365e commit 9b4bfbf
Showing 1 changed file with 15 additions and 1 deletion.
16 changes: 15 additions & 1 deletion benchmark.sh
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,20 @@ check_jq() {
fi
}

get_torch_cuda_version() {
# Get the full CUDA version using nvcc
CUDA_VERSION=$(nvcc --version | grep "release" | awk '{print $6}')

# Remove dots from the CUDA version
CUDA_VERSION_NUMERIC=$(echo "${CUDA_VERSION}" | tr -d '.')

# Set TORCH_CUDA_VERSION variable
TORCH_CUDA_VERSION="cu${CUDA_VERSION_NUMERIC}"

# Return the dynamically set variable
echo "${TORCH_CUDA_VERSION}"
}

# Function to download models
download_models() {
echo -e "\nDownloading models..."
Expand Down Expand Up @@ -126,7 +140,7 @@ run_benchmarks() {
if [ "$DEVICE" == "cpu" ] || [ "$USE_NVIDIA" == true ]; then
# Run Rust benchmarks
if [ "$DEVICE" == "gpu" ] && [ "$PLATFORM" != "Darwin" ]; then
TORCH_CUDA_VERSION=cu117
TORCH_CUDA_VERSION=$(get_torch_cuda_version)
fi
cargo run --release --bin sample \
--manifest-path="$DIR/rust_bench/llama2-burn/Cargo.toml" \
Expand Down

0 comments on commit 9b4bfbf

Please sign in to comment.