Skip to content

Commit

Permalink
check compatibility between cuda version and float 8
Browse files Browse the repository at this point in the history
  • Loading branch information
xadupre committed Nov 2, 2023
1 parent 7cc6d53 commit 9782c26
Showing 1 changed file with 7 additions and 0 deletions.
7 changes: 7 additions & 0 deletions tools/ci_build/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import contextlib
import json
import os
import packaging.version as pv
import platform
import re
import shlex
Expand Down Expand Up @@ -1085,6 +1086,12 @@ def generate_build_tree(
if args.use_cuda:
nvcc_threads = number_of_nvcc_threads(args)
cmake_args.append("-Donnxruntime_NVCC_THREADS=" + str(nvcc_threads))
if not args.disable_float8_types and args.cuda_version:
if pv.Version(args.cuda_version) < pv.Version("11.8"):
raise BuildError(
f"Float 8 types require CUDA>=11.8. They must be disabled on CUDA=={args.cuda_version}. "
f"See option disable_types."
)
if args.use_rocm:
cmake_args.append("-Donnxruntime_ROCM_HOME=" + rocm_home)
cmake_args.append("-Donnxruntime_ROCM_VERSION=" + args.rocm_version)
Expand Down

0 comments on commit 9782c26

Please sign in to comment.