Skip to content

Commit

Permalink
improve flags
Browse files Browse the repository at this point in the history
  • Loading branch information
xadupre committed Nov 3, 2023
1 parent 9782c26 commit 90dc79b
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 4 deletions.
11 changes: 9 additions & 2 deletions onnxruntime/contrib_ops/cuda/math/gemm_float8.cu
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ int32_t TypeSize(int32_t element_type) {
case ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16:
case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16:
return 2;
#if (!defined(DISABLE_FLOAT8_TYPES) && (CUDA_VERSION >= 11080))
#if !defined(DISABLE_FLOAT8_TYPES)
case ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E4M3FN:
case ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E5M2:
return 1;
Expand Down Expand Up @@ -97,12 +97,16 @@ Status GemmFloat8::ComputeInternal(OpKernelContext* ctx) const {
}

auto first_type = input_A->GetElementType();
#if !defined(DISABLE_FLOAT8_TYPES)
bool is_float8 = first_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E4M3FN || first_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E5M2;
if (!is_float8)
#endif
return ComputeRowMajor(ctx, n_inputs, has_bias, has_scales, input_A, input_B,
input_C, scale_A, scale_B, scale_Y);
#if !defined(DISABLE_FLOAT8_TYPES)
return ComputeColMajor(ctx, n_inputs, has_bias, has_scales, input_A, input_B,
input_C, scale_A, scale_B, scale_Y);
#endif
}

Status GemmFloat8::ComputeRowMajor(
Expand Down Expand Up @@ -197,10 +201,13 @@ Status GemmFloat8::ComputeGemm(
switch (d_cuda_type) {
case CUDA_R_16F:
switch (a_cuda_type) {
#if !defined(DISABLE_FLOAT8_TYPES)
// Note: CUDA_R_8F_E4M3 is defined with CUDA>=11.8
case CUDA_R_8F_E4M3:
case CUDA_R_8F_E5M2:
compute_type = CUBLAS_COMPUTE_32F_FAST_TF32;
break;
#endif
default:
compute_type = CUBLAS_COMPUTE_32F_FAST_16F;
break;
Expand Down Expand Up @@ -267,7 +274,7 @@ Status GemmFloat8::ComputeGemm(
sizeof(p_scale_b)));

// float 8
#if CUDA_VERSION >= 11080
#if !defined(DISABLE_FLOAT8_TYPES)
if (dtype_Y == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E4M3FN ||
dtype_Y == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E5M2) {
// For FP8 output, cuBLAS requires C_type to be same as bias_type
Expand Down
5 changes: 3 additions & 2 deletions onnxruntime/core/providers/cuda/cuda_common.cc
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,8 @@ const char* CudaDataTypeToString(cudaDataType_t dt) {
return "CUDA_R_16BF";
case CUDA_R_32F:
return "CUDA_R_32F";
#if (CUDA_VERSION >= 11080)
#if !defined(DISABLE_FLOAT8_TYPES)
// Note: CUDA_R_8F_E4M3 is defined with CUDA>=11.8
case CUDA_R_8F_E4M3:
return "CUDA_R_8F_E4M3";
case CUDA_R_8F_E5M2:
Expand Down Expand Up @@ -101,7 +102,7 @@ cudaDataType_t ToCudaDataType(int32_t element_type) {
return CUDA_R_16F;
case ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16:
return CUDA_R_16BF;
#if (!defined(DISABLE_FLOAT8_TYPES) && (CUDA_VERSION >= 11080))
#if !defined(DISABLE_FLOAT8_TYPES)
case ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E4M3FN:
return CUDA_R_8F_E4M3;
case ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E5M2:
Expand Down

0 comments on commit 90dc79b

Please sign in to comment.