diff --git a/onnxruntime/contrib_ops/cuda/math/gemm_float8.cu b/onnxruntime/contrib_ops/cuda/math/gemm_float8.cu
index df25342342cd..d0c52aa57074 100644
--- a/onnxruntime/contrib_ops/cuda/math/gemm_float8.cu
+++ b/onnxruntime/contrib_ops/cuda/math/gemm_float8.cu
@@ -28,7 +28,7 @@ int32_t TypeSize(int32_t element_type) {
     case ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16:
     case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16:
       return 2;
-#if (!defined(DISABLE_FLOAT8_TYPES) && (CUDA_VERSION >= 11080))
+#if !defined(DISABLE_FLOAT8_TYPES)
     case ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E4M3FN:
     case ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E5M2:
       return 1;
@@ -97,12 +97,16 @@ Status GemmFloat8::ComputeInternal(OpKernelContext* ctx) const {
   }
   auto first_type = input_A->GetElementType();
+#if !defined(DISABLE_FLOAT8_TYPES)
   bool is_float8 = first_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E4M3FN ||
                    first_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E5M2;
   if (!is_float8)
+#endif
     return ComputeRowMajor(ctx, n_inputs, has_bias, has_scales, input_A,
                            input_B, input_C, scale_A, scale_B, scale_Y);
+#if !defined(DISABLE_FLOAT8_TYPES)
   return ComputeColMajor(ctx, n_inputs, has_bias, has_scales, input_A,
                          input_B, input_C, scale_A, scale_B, scale_Y);
+#endif
 }
 
 Status GemmFloat8::ComputeRowMajor(
@@ -197,10 +201,13 @@ Status GemmFloat8::ComputeGemm(
   switch (d_cuda_type) {
     case CUDA_R_16F:
       switch (a_cuda_type) {
+#if !defined(DISABLE_FLOAT8_TYPES)
+        // Note: CUDA_R_8F_E4M3 is defined with CUDA>=11.8
         case CUDA_R_8F_E4M3:
         case CUDA_R_8F_E5M2:
           compute_type = CUBLAS_COMPUTE_32F_FAST_TF32;
           break;
+#endif
         default:
           compute_type = CUBLAS_COMPUTE_32F_FAST_16F;
           break;
@@ -267,7 +274,7 @@ Status GemmFloat8::ComputeGemm(
                                                    sizeof(p_scale_b)));
 
   // float 8
-#if CUDA_VERSION >= 11080
+#if !defined(DISABLE_FLOAT8_TYPES)
   if (dtype_Y == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E4M3FN ||
       dtype_Y == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E5M2) {
     // For FP8 output, cuBLAS requires C_type to be same as bias_type
diff --git a/onnxruntime/core/providers/cuda/cuda_common.cc b/onnxruntime/core/providers/cuda/cuda_common.cc
index 288ca8e97e34..33f2938940e4 100644
--- a/onnxruntime/core/providers/cuda/cuda_common.cc
+++ b/onnxruntime/core/providers/cuda/cuda_common.cc
@@ -62,7 +62,8 @@ const char* CudaDataTypeToString(cudaDataType_t dt) {
       return "CUDA_R_16BF";
     case CUDA_R_32F:
       return "CUDA_R_32F";
-#if (CUDA_VERSION >= 11080)
+#if !defined(DISABLE_FLOAT8_TYPES)
+    // Note: CUDA_R_8F_E4M3 is defined with CUDA>=11.8
     case CUDA_R_8F_E4M3:
       return "CUDA_R_8F_E4M3";
     case CUDA_R_8F_E5M2:
@@ -101,7 +102,7 @@ cudaDataType_t ToCudaDataType(int32_t element_type) {
       return CUDA_R_16F;
     case ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16:
       return CUDA_R_16BF;
-#if (!defined(DISABLE_FLOAT8_TYPES) && (CUDA_VERSION >= 11080))
+#if !defined(DISABLE_FLOAT8_TYPES)
     case ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E4M3FN:
       return CUDA_R_8F_E4M3;
     case ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E5M2:
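
For context, the pattern both files converge on is to gate every FP8 code path behind the single build-level macro DISABLE_FLOAT8_TYPES instead of scattering CUDA_VERSION comparisons through the kernels: the enumerators CUDA_R_8F_E4M3 and CUDA_R_8F_E5M2 first appear in the CUDA 11.8 headers, so a build against an older toolkit is expected to define DISABLE_FLOAT8_TYPES, after which none of these cases are compiled at all. A minimal standalone sketch of the pattern follows (illustrative only, not part of the PR; DemoDataTypeToString is a made-up name):

// Illustrative sketch, not part of the PR. Assumes the build defines
// DISABLE_FLOAT8_TYPES whenever the toolkit is older than CUDA 11.8,
// which is where CUDA_R_8F_E4M3 / CUDA_R_8F_E5M2 first appear.
#include <library_types.h>  // cudaDataType_t and the CUDA_R_* enumerators

const char* DemoDataTypeToString(cudaDataType_t dt) {
  switch (dt) {
    case CUDA_R_16F:
      return "CUDA_R_16F";
    case CUDA_R_32F:
      return "CUDA_R_32F";
#if !defined(DISABLE_FLOAT8_TYPES)
    // Only compiled when FP8 is enabled; on CUDA < 11.8 these enumerators
    // do not exist, so an unguarded reference would break the build.
    case CUDA_R_8F_E4M3:
      return "CUDA_R_8F_E4M3";
    case CUDA_R_8F_E5M2:
      return "CUDA_R_8F_E5M2";
#endif
    default:
      return "(unhandled cudaDataType_t)";
  }
}

Compiling with -DDISABLE_FLOAT8_TYPES removes the FP8 cases entirely, which is also why the diff restructures ComputeInternal so that the column-major FP8 path disappears along with them.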