From fbdaba21951e8b508af6328112d654ec326ed863 Mon Sep 17 00:00:00 2001 From: xinhaoc Date: Mon, 15 Jan 2024 11:22:08 -0500 Subject: [PATCH] hip --- include/flexflow/accessor.h | 12 +- include/flexflow/utils/cuda_helper.h | 14 ++ inference/file_loader.cc | 12 +- src/ops/arg_topk.cpp | 12 ++ src/ops/argmax.cpp | 12 +- src/ops/element_unary.cpp | 6 + src/ops/fused.cpp | 16 ++ src/ops/inc_multihead_self_attention.cpp | 24 +++ src/ops/inc_multihead_self_attention.cu | 14 +- src/ops/kernels/decompress_kernels.cpp | 23 ++- src/ops/kernels/element_binary_kernels.cpp | 20 ++- src/ops/kernels/embedding_kernels.cpp | 154 +++++++++++++++++- src/ops/kernels/linear_kernels.cpp | 23 +++ src/ops/kernels/linear_kernels.cu | 3 +- src/ops/kernels/rms_norm_kernels.cpp | 6 + src/ops/kernels/rms_norm_kernels.cu | 2 +- src/ops/kernels/softmax.cpp | 11 +- src/ops/layer_norm.cpp | 9 + src/ops/linear.cc | 16 +- src/ops/residual_layer_norm.cpp | 12 ++ src/ops/residual_layer_norm.cu | 5 +- src/ops/sigmoid_silu_multi.cpp | 8 + src/ops/softmax.cc | 8 +- src/ops/spec_inc_multihead_self_attention.cpp | 13 ++ src/ops/spec_inc_multihead_self_attention.cu | 10 +- src/ops/tree_inc_multihead_self_attention.cpp | 18 ++ src/parallel_ops/combine.cc | 4 +- src/parallel_ops/kernels/combine_kernels.cpp | 3 + .../kernels/reduction_kernels.cpp | 10 ++ .../kernels/replicate_kernels.cpp | 3 + src/parallel_ops/reduction.cc | 4 +- src/parallel_ops/replicate.cc | 4 +- src/runtime/accessor.cc | 18 +- src/runtime/ffconst_utils.cc | 2 +- src/runtime/hip_helper.cpp | 42 ++++- src/runtime/initializer_kernel.cpp | 11 ++ src/runtime/model.cc | 2 +- src/runtime/parallel_tensor.cc | 8 +- 38 files changed, 511 insertions(+), 63 deletions(-) diff --git a/include/flexflow/accessor.h b/include/flexflow/accessor.h index 26ed006646..121140c926 100644 --- a/include/flexflow/accessor.h +++ b/include/flexflow/accessor.h @@ -8,7 +8,7 @@ #include #elif defined(FF_USE_HIP_CUDA) #include -#include +#include #elif defined(FF_USE_HIP_ROCM) #include #include @@ -18,6 +18,12 @@ namespace FlexFlow { +#if defined(FF_USE_CUDA) || defined(FF_USE_HIP_CUDA) +typedef __nv_bfloat16 __ff_bfloat16; +#elif defined(FF_USE_HIP_ROCM) +typedef hip_bfloat16 __ff_bfloat16; +#endif + template using AccessorRO = Legion::FieldAccessor>; @@ -64,7 +70,7 @@ class GenericTensorAccessorW { float *get_float_ptr() const; double *get_double_ptr() const; half *get_half_ptr() const; - __nv_bfloat16 *get_bfloat16_ptr() const; + __ff_bfloat16 *get_bfloat16_ptr() const; char *get_byte_ptr() const; DataType data_type; Legion::Domain domain; @@ -84,7 +90,7 @@ class GenericTensorAccessorR { float const *get_float_ptr() const; double const *get_double_ptr() const; half const *get_half_ptr() const; - __nv_bfloat16 const *get_bfloat16_ptr() const; + __ff_bfloat16 const *get_bfloat16_ptr() const; char const *get_byte_ptr() const; DataType data_type; Legion::Domain domain; diff --git a/include/flexflow/utils/cuda_helper.h b/include/flexflow/utils/cuda_helper.h index f8bf67b3e1..f3038860d7 100644 --- a/include/flexflow/utils/cuda_helper.h +++ b/include/flexflow/utils/cuda_helper.h @@ -161,6 +161,20 @@ T *download_tensor(T const *ptr, size_t num_elements); template bool download_tensor(T const *ptr, T *dst, size_t num_elements); +// data type for cublasgemm +template +struct cublasAlphaBetaType { + using type = float; // default +}; +template <> +struct cublasAlphaBetaType { + using type = half; +}; +template <> +struct cublasAlphaBetaType<__nv_bfloat16> { + using type = float; +}; + cudnnStatus_t 
cudnnSetTensorDescriptorFromDomain(cudnnTensorDescriptor_t tensor, Legion::Domain domain, DataType data_type = DT_FLOAT); diff --git a/inference/file_loader.cc b/inference/file_loader.cc index 9f9bce56b7..60b2c80b06 100644 --- a/inference/file_loader.cc +++ b/inference/file_loader.cc @@ -127,7 +127,7 @@ void load_attention_weights_multi_query(DT *ptr, ///////////////////////bfloat16 function/////////////////////// -void load_from_file_b16(__nv_bfloat16 *ptr, size_t size, std::string filepath) { +void load_from_file_b16(__ff_bfloat16 *ptr, size_t size, std::string filepath) { std::ifstream in(filepath, std::ios::in | std::ios::binary); if (!in.good()) { std::cout << "Could not open file: " << filepath << std::endl; @@ -156,7 +156,7 @@ void load_from_file_b16(__nv_bfloat16 *ptr, size_t size, std::string filepath) { in.close(); } -void load_attention_weights_v2_b16(__nv_bfloat16 *ptr, +void load_attention_weights_v2_b16(__ff_bfloat16 *ptr, int num_heads, int num_kv_heads, size_t hidden_dim, @@ -290,7 +290,7 @@ void load_attention_weights_v2_b16(__nv_bfloat16 *ptr, } } -void load_attention_bias_v2_b16(__nv_bfloat16 *ptr, +void load_attention_bias_v2_b16(__ff_bfloat16 *ptr, int num_heads, int num_kv_heads, size_t hidden_dim, @@ -382,8 +382,8 @@ void FileDataLoader::load_single_weight_tensor_b16(FFModel *ff, dims_vec.push_back(weight->dims[i]); volume *= weight->dims[i]; } - assert(data_type_size(weight->data_type) == sizeof(__nv_bfloat16)); - __nv_bfloat16 *data = (__nv_bfloat16 *)malloc(sizeof(__nv_bfloat16) * volume); + assert(data_type_size(weight->data_type) == sizeof(__ff_bfloat16)); + __ff_bfloat16 *data = (__ff_bfloat16 *)malloc(sizeof(__ff_bfloat16) * volume); std::string weight_filename = removeGuidOperatorName(std::string(l->name)); @@ -446,7 +446,7 @@ void FileDataLoader::load_single_weight_tensor_b16(FFModel *ff, // Copy the weight data from the buffer to the weight's ParallelTensor ParallelTensor weight_pt; ff->get_parallel_tensor_from_tensor(weight, weight_pt); - weight_pt->set_tensor<__nv_bfloat16>(ff, dims_vec, data); + weight_pt->set_tensor<__ff_bfloat16>(ff, dims_vec, data); // Free buffer memory delete data; diff --git a/src/ops/arg_topk.cpp b/src/ops/arg_topk.cpp index f431d3d4bf..17cc7534ad 100644 --- a/src/ops/arg_topk.cpp +++ b/src/ops/arg_topk.cpp @@ -515,6 +515,18 @@ void ArgTopK::forward_kernel_wrapper(ArgTopKMeta const *m, m->sorted, m->speculative_decoding ? bc : nullptr, stream); + } else if (input.data_type == DT_B16) { + ArgTopK::forward_kernel(m, + input.get_bfloat16_ptr(), + m->speculative_decoding ? probs.get_float_ptr() + : nullptr, + indices.get_int32_ptr(), + batch_size, + length, + k, + m->sorted, + m->speculative_decoding ? bc : nullptr, + stream); } else { assert(false && "Unsupported data type"); } diff --git a/src/ops/argmax.cpp b/src/ops/argmax.cpp index 8a1cf0b3b0..7437b8b969 100644 --- a/src/ops/argmax.cpp +++ b/src/ops/argmax.cpp @@ -466,6 +466,16 @@ void ArgMax::forward_kernel_wrapper(ArgMaxMeta const *m, length, batch_size, stream); + } else if (input.data_type == DT_B16) { + ArgMax::forward_kernel(m, + input.get_bfloat16_ptr(), + indices.get_int32_ptr(), + m->probs, + m->beam_search ? 
parent.get_int32_ptr() + : nullptr, + length, + batch_size, + stream); } else { assert(false && "Unsupported data type"); } @@ -491,7 +501,7 @@ ArgMaxMeta::ArgMaxMeta(FFHandler handler, : OpMeta(handler, op) { DataType data_type = op->data_type; size_t prob_size = batch_size; - assert(data_type == DT_FLOAT || data_type == DT_HALF); + assert(data_type == DT_FLOAT || data_type == DT_HALF || data_type == DT_B16); size_t total_size = prob_size * sizeof(float); gpu_mem_allocator.create_legion_instance(reserveInst, total_size); probs = gpu_mem_allocator.allocate_instance(prob_size); diff --git a/src/ops/element_unary.cpp b/src/ops/element_unary.cpp index e20200420f..06a4d3c36a 100644 --- a/src/ops/element_unary.cpp +++ b/src/ops/element_unary.cpp @@ -314,6 +314,12 @@ template void int64_t *output_ptr, size_t num_elements); +template void ElementUnary::forward_kernel_wrapper( + ElementUnaryMeta const *m, + hip_bfloat16 const *input_ptr, + hip_bfloat16 *output_ptr, + size_t num_elements); + template void ElementUnary::backward_kernel_wrapper(ElementUnaryMeta const *m, float const *input_ptr, diff --git a/src/ops/fused.cpp b/src/ops/fused.cpp index 3282bc57d9..af98f153e2 100644 --- a/src/ops/fused.cpp +++ b/src/ops/fused.cpp @@ -428,6 +428,11 @@ __host__ void FusedOp::forward_task(Task const *task, m, my_input_accessor[0].get_float_ptr(), my_output_accessor[0].get_float_ptr()); + } else if (m->input_type == DT_B16) { + Kernels::Softmax::forward_kernel_wrapper( + m, + my_input_accessor[0].get_bfloat16_ptr(), + my_output_accessor[0].get_bfloat16_ptr()); } break; } @@ -815,6 +820,12 @@ __host__ void my_input_accessor[0].get_float_ptr(), my_output_accessor[0].get_float_ptr(), my_input_accessor[0].domain.get_volume()); + } else if (m->data_type == DT_B16) { + ElementUnary::forward_kernel_wrapper( + m, + my_input_accessor[0].get_bfloat16_ptr(), + my_output_accessor[0].get_bfloat16_ptr(), + my_input_accessor[0].domain.get_volume()); } else { assert(false && "Unsupported data type in ElementUnary forward"); } @@ -1039,6 +1050,11 @@ __host__ void m, my_input_accessor[0].get_float_ptr(), my_output_accessor[0].get_float_ptr()); + } else if (m->input_type == DT_B16) { + Kernels::Softmax::forward_kernel_wrapper( + m, + my_input_accessor[0].get_bfloat16_ptr(), + my_output_accessor[0].get_bfloat16_ptr()); } break; } diff --git a/src/ops/inc_multihead_self_attention.cpp b/src/ops/inc_multihead_self_attention.cpp index d60386f927..4d63ba628d 100644 --- a/src/ops/inc_multihead_self_attention.cpp +++ b/src/ops/inc_multihead_self_attention.cpp @@ -802,6 +802,23 @@ void IncMultiHeadSelfAttention::inference_kernel_wrapper( output.get_float_ptr(), bias_ptr, stream); + } else if (input.data_type == DT_B16) { + if (m->offload) { + pre_build_weight_kernel(m, weight, input.data_type, stream); + } + hip_bfloat16 const *bias_ptr = + use_bias ? bias.get_bfloat16_ptr() + : static_cast(nullptr); + Kernels::IncMultiHeadAttention::inference_kernel( + m, + bc, + shard_id, + input.get_bfloat16_ptr(), + m->offload ? 
static_cast<hip_bfloat16 const *>(m->weight_ptr) + : weight.get_bfloat16_ptr(), + output.get_bfloat16_ptr(), + bias_ptr, + stream); } else { assert(false && "Unspported data type"); } @@ -1098,4 +1115,11 @@ template void Kernels::IncMultiHeadAttention::pre_build_weight_kernel( DataType data_type, hipStream_t stream); +template void + Kernels::IncMultiHeadAttention::pre_build_weight_kernel<hip_bfloat16>( + IncMultiHeadSelfAttentionMeta const *m, + GenericTensorAccessorR const weight, + DataType data_type, + hipStream_t stream); + }; // namespace FlexFlow diff --git a/src/ops/inc_multihead_self_attention.cu b/src/ops/inc_multihead_self_attention.cu index 37ee080eec..af77ac9f6d 100644 --- a/src/ops/inc_multihead_self_attention.cu +++ b/src/ops/inc_multihead_self_attention.cu @@ -513,7 +513,8 @@ void compute_qkv_kernel(IncMultiHeadSelfAttentionMeta const *m, // Step 1: Compute QKV projections { - float alpha = 1.0f, beta = 0.0f; + typename cublasAlphaBetaType<DT>
::type alpha = 1.0; + typename cublasAlphaBetaType<DT>
::type beta = 0.0; // after transpositions int m_q = m->qProjSize * m->num_q_heads; int m_k = m->kProjSize * m->num_q_heads; @@ -645,7 +646,8 @@ void compute_o_prod_bias(IncMultiHeadSelfAttentionMeta const *m, #endif // Project to output, save result directly on output tensor { - float alpha = 1.0f, beta = 0.0f; + typename cublasAlphaBetaType
<DT>::type alpha = 1.0; + typename cublasAlphaBetaType<DT>
::type beta = 0.0; // after transpositions int m_ = m->oProjSize; int k = m->vProjSize * m->num_q_heads; @@ -925,7 +927,7 @@ void compute_attention_kernel_prompt(IncMultiHeadSelfAttentionMeta const *m, cublasComputeType_t compute_type = CUBLAS_COMPUTE_16F; if (m->output_type[0] == DT_FLOAT) { compute_type = CUBLAS_COMPUTE_32F_FAST_16F; - }else if (m->output_type[0] == DT_B16) { + } else if (m->output_type[0] == DT_B16) { compute_type = CUBLAS_COMPUTE_32F; } #endif @@ -951,7 +953,8 @@ void compute_attention_kernel_prompt(IncMultiHeadSelfAttentionMeta const *m, // Step 1: compute query-key product QK.T/sqrt(d_k) { // Scale by sqrt(d_k) as per the original attention paper - float alpha = 1.0f, beta = 0.0f; + typename cublasAlphaBetaType
<DT>::type alpha = 1.0; + typename cublasAlphaBetaType<DT>
::type beta = 0.0; if (*m->qk_prod_scaling) { alpha = static_cast<DT>
(1.0f / sqrt(m->kProjSize)); } @@ -1082,7 +1085,8 @@ void compute_attention_kernel_prompt(IncMultiHeadSelfAttentionMeta const *m, // Step 5: Matmul softmax(QK.T/sqrt(d_k)) by V. Implemented as V @ // softmax(QK.T/sqrt(d_k)).T { - float alpha = 1.0f, beta = 0.0f; + typename cublasAlphaBetaType
<DT>::type alpha = 1.0; + typename cublasAlphaBetaType<DT>
::type beta = 0.0; // after transpositions int m_ = m->vProjSize; int n = num_new_tokens; diff --git a/src/ops/kernels/decompress_kernels.cpp b/src/ops/kernels/decompress_kernels.cpp index 22bf93d449..5127e0842e 100644 --- a/src/ops/kernels/decompress_kernels.cpp +++ b/src/ops/kernels/decompress_kernels.cpp @@ -54,10 +54,20 @@ template __global__ void decompress_int4_general_weights( char const *input_weight_ptr, float *weight_ptr, int in_dim, int valueSize); template __global__ void decompress_int4_general_weights( char const *input_weight_ptr, half *weight_ptr, int in_dim, int valueSize); +template __global__ void + decompress_int4_general_weights(char const *input_weight_ptr, + hip_bfloat16 *weight_ptr, + int in_dim, + int valueSize); template __global__ void decompress_int8_general_weights( char const *input_weight_ptr, float *weight_ptr, int in_dim, int valueSize); template __global__ void decompress_int8_general_weights( char const *input_weight_ptr, half *weight_ptr, int in_dim, int valueSize); +template __global__ void + decompress_int8_general_weights(char const *input_weight_ptr, + hip_bfloat16 *weight_ptr, + int in_dim, + int valueSize); template __global__ void decompress_int4_attention_weights(char *input_weight_ptr, float *weight_ptr, @@ -71,7 +81,12 @@ template __global__ void int qProjSize, int qSize, int num_heads); - +template __global__ void + decompress_int4_attention_weights(char *input_weight_ptr, + hip_bfloat16 *weight_ptr, + int qProjSize, + int qSize, + int num_heads); template __global__ void decompress_int8_attention_weights(char *input_weight_ptr, float *weight_ptr, @@ -86,5 +101,11 @@ template __global__ void int qSize, int num_heads); +template __global__ void + decompress_int8_attention_weights(char *input_weight_ptr, + hip_bfloat16 *weight_ptr, + int qProjSize, + int qSize, + int num_heads); } // namespace Kernels }; // namespace FlexFlow \ No newline at end of file diff --git a/src/ops/kernels/element_binary_kernels.cpp b/src/ops/kernels/element_binary_kernels.cpp index a65372de85..05467182e7 100644 --- a/src/ops/kernels/element_binary_kernels.cpp +++ b/src/ops/kernels/element_binary_kernels.cpp @@ -82,8 +82,24 @@ void forward_kernel_wrapper(ElementBinaryMeta const *m, } // print_tensor(in1_ptr, in1_domain.get_volume(), "input1:"); // print_tensor(in2_ptr, in2_domain.get_volume(), "input2:"); - Internal::forward_kernel( - m, in1.get_float_ptr(), in2.get_float_ptr(), out.get_float_ptr(), stream); + if (out.data_type == DT_HALF) { + Internal::forward_kernel( + m, in1.get_half_ptr(), in2.get_half_ptr(), out.get_half_ptr(), stream); + } else if (out.data_type == DT_FLOAT) { + Internal::forward_kernel(m, + in1.get_float_ptr(), + in2.get_float_ptr(), + out.get_float_ptr(), + stream); + } else if (out.data_type == DT_B16) { + Internal::forward_kernel(m, + in1.get_bfloat16_ptr(), + in2.get_bfloat16_ptr(), + out.get_bfloat16_ptr(), + stream); + } else { + assert(false && "Unsupported data type"); + } // print_tensor(out_ptr, in1_domain.get_volume(), "output:"); if (m->profiling) { checkCUDA(hipEventRecord(t_end, stream)); diff --git a/src/ops/kernels/embedding_kernels.cpp b/src/ops/kernels/embedding_kernels.cpp index ee4a6fcea1..7384f4e7e6 100644 --- a/src/ops/kernels/embedding_kernels.cpp +++ b/src/ops/kernels/embedding_kernels.cpp @@ -60,7 +60,7 @@ void forward_kernel_wrapper(EmbeddingMeta const *m, m->aggr, output.domain.get_volume(), stream); - } else if (weight.data_type == DT_HALF) { + } else if (weight.data_type == DT_DOUBLE) { 
Internal::forward_kernel(input.get_int32_ptr(), output.get_double_ptr(), weight.get_double_ptr(), @@ -70,6 +70,16 @@ void forward_kernel_wrapper(EmbeddingMeta const *m, m->aggr, output.domain.get_volume(), stream); + } else if (weight.data_type == DT_B16) { + Internal::forward_kernel(input.get_int32_ptr(), + output.get_bfloat16_ptr(), + weight.get_bfloat16_ptr(), + in_dim, + out_dim, + batch_size, + m->aggr, + output.domain.get_volume(), + stream); } else { assert(false && "Unsupported DataType in Embedding"); } @@ -104,6 +114,16 @@ void forward_kernel_wrapper(EmbeddingMeta const *m, m->aggr, output.domain.get_volume(), stream); + } else if (weight.data_type == DT_B16) { + Internal::forward_kernel(input.get_int64_ptr(), + output.get_bfloat16_ptr(), + weight.get_bfloat16_ptr(), + in_dim, + out_dim, + batch_size, + m->aggr, + output.domain.get_volume(), + stream); } else { assert(false && "Unsupported DataType in Embedding"); } @@ -162,6 +182,16 @@ void backward_kernel_wrapper(EmbeddingMeta const *m, m->aggr, output.domain.get_volume(), stream); + } else if (m->output_type[0] == DT_B16) { + Internal::backward_kernel(input.get_int32_ptr(), + output.get_bfloat16_ptr(), + weight_grad.get_bfloat16_ptr(), + in_dim, + out_dim, + batch_size, + m->aggr, + output.domain.get_volume(), + stream); } else { assert(false && "Unsupported DataType in Embedding"); } @@ -196,6 +226,16 @@ void backward_kernel_wrapper(EmbeddingMeta const *m, m->aggr, output.domain.get_volume(), stream); + } else if (m->output_type[0] == DT_B16) { + Internal::backward_kernel(input.get_int64_ptr(), + output.get_bfloat16_ptr(), + weight_grad.get_bfloat16_ptr(), + in_dim, + out_dim, + batch_size, + m->aggr, + output.domain.get_volume(), + stream); } else { assert(false && "Unsupported DataType in Embedding"); } @@ -332,6 +372,50 @@ __global__ void embed_backward_no_aggr(int64_t const *input, } } +template <> +__global__ void + embed_backward_no_aggr(int const *input, + hip_bfloat16 const *output, + hip_bfloat16 *embed, + int out_dim, + int batch_size) { + CUDA_KERNEL_LOOP(i, batch_size * out_dim) { + int idx = i / out_dim; + int off = i % out_dim; + int wordIdx = input[idx]; +#if __CUDA_ARCH__ >= 700 + atomicAdd(embed + wordIdx * out_dim + off, output[i]); +#else + assert(false); + // TODO: this implementation may result in race condition + // so we use an assertion failure to warn users + embed[wordIdx * out_dim + off] += output[i]; +#endif + } +} + +template <> +__global__ void + embed_backward_no_aggr(int64_t const *input, + hip_bfloat16 const *output, + hip_bfloat16 *embed, + int out_dim, + int batch_size) { + CUDA_KERNEL_LOOP(i, batch_size * out_dim) { + int idx = i / out_dim; + int off = i % out_dim; + int64_t wordIdx = input[idx]; +#if __CUDA_ARCH__ >= 700 + atomicAdd(embed + wordIdx * out_dim + off, output[i]); +#else + assert(false); + // TODO: this implementation may result in race condition + // so we use an assertion failure to warn users + embed[wordIdx * out_dim + off] += output[i]; +#endif + } +} + template __global__ void embed_backward_with_aggr(TI const *input, TD const *output, @@ -426,6 +510,74 @@ __global__ void embed_backward_with_aggr(int64_t const *input, } } +template <> +__global__ void + embed_backward_with_aggr(int const *input, + hip_bfloat16 const *output, + hip_bfloat16 *embed, + int out_dim, + int in_dim, + int batch_size, + AggrMode aggr) { + hip_bfloat16 scale = 1.0f / in_dim; + CUDA_KERNEL_LOOP(i, batch_size * out_dim) { + int idx = i / out_dim; + int off = i % out_dim; + hip_bfloat16 gradient; + 
if (aggr == AGGR_MODE_SUM) { + gradient = output[i]; + } else { + assert(aggr == AGGR_MODE_AVG); + gradient = output[i] * scale; + } + for (int j = 0; j < in_dim; j++) { + int wordIdx = input[idx * in_dim + j]; +#if __CUDA_ARCH__ >= 700 + atomicAdd(embed + wordIdx * out_dim + off, gradient); +#else + assert(false); + // TODO: this implementation may result in race condition + // so we use an assertion failure to warn users + embed[wordIdx * out_dim + off] += gradient; +#endif + } + } +} + +template <> +__global__ void + embed_backward_with_aggr(int64_t const *input, + hip_bfloat16 const *output, + hip_bfloat16 *embed, + int out_dim, + int in_dim, + int batch_size, + AggrMode aggr) { + hip_bfloat16 scale = 1.0f / in_dim; + CUDA_KERNEL_LOOP(i, batch_size * out_dim) { + int idx = i / out_dim; + int off = i % out_dim; + hip_bfloat16 gradient; + if (aggr == AGGR_MODE_SUM) { + gradient = output[i]; + } else { + assert(aggr == AGGR_MODE_AVG); + gradient = output[i] * scale; + } + for (int j = 0; j < in_dim; j++) { + int64_t wordIdx = input[idx * in_dim + j]; +#if __CUDA_ARCH__ >= 700 + atomicAdd(embed + wordIdx * out_dim + off, gradient); +#else + assert(false); + // TODO: this implementation may result in race condition + // so we use an assertion failure to warn users + embed[wordIdx * out_dim + off] += gradient; +#endif + } + } +} + /*static*/ template void forward_kernel(TI const *input_ptr, diff --git a/src/ops/kernels/linear_kernels.cpp b/src/ops/kernels/linear_kernels.cpp index 072eb5e96b..3cd9862bba 100644 --- a/src/ops/kernels/linear_kernels.cpp +++ b/src/ops/kernels/linear_kernels.cpp @@ -124,6 +124,16 @@ void forward_kernel_wrapper(LinearMeta const *m, out_dim, batch_size, stream); + } else if (m->input_type[0] == DT_B16) { + Internal::forward_kernel(m, + input_ptr, + output_ptr, + weight_ptr, + bias_ptr, + in_dim, + out_dim, + batch_size, + stream); } if (m->profiling) { @@ -189,6 +199,19 @@ void backward_kernel_wrapper(LinearMeta const *m, out_dim, batch_size, stream); + } else if (m->input_type[0] == DT_B16) { + Internal::backward_kernel(m, + input_ptr, + input_grad_ptr, + output_ptr, + output_grad_ptr, + kernel_ptr, + kernel_grad_ptr, + bias_grad_ptr, + in_dim, + out_dim, + batch_size, + stream); } if (m->profiling) { diff --git a/src/ops/kernels/linear_kernels.cu b/src/ops/kernels/linear_kernels.cu index db75a5bf32..ad416b433c 100644 --- a/src/ops/kernels/linear_kernels.cu +++ b/src/ops/kernels/linear_kernels.cu @@ -333,7 +333,8 @@ void forward_kernel(LinearMeta const *m, } checkCUDA(cublasSetStream(m->handle.blas, stream)); checkCUDNN(cudnnSetStream(m->handle.dnn, stream)); - float alpha = 1.0f, beta = 0.0f; + typename cublasAlphaBetaType
<DT>::type alpha = 1.0; + typename cublasAlphaBetaType<DT>
::type beta = 0.0; cudaDataType_t input_type = ff_to_cuda_datatype(m->input_type[0]); cudaDataType_t weight_type = m->offload ? ff_to_cuda_datatype(m->weight_ptr_type) diff --git a/src/ops/kernels/rms_norm_kernels.cpp b/src/ops/kernels/rms_norm_kernels.cpp index 24ab7051e6..c81e73dd77 100644 --- a/src/ops/kernels/rms_norm_kernels.cpp +++ b/src/ops/kernels/rms_norm_kernels.cpp @@ -190,6 +190,12 @@ void forward_kernel_wrapper(RMSNormMeta const *m, weight.get_float_ptr(), output.get_float_ptr(), stream); + } else if (output.data_type == DT_B16) { + forward_kernel(m, + input.get_bfloat16_ptr(), + weight.get_bfloat16_ptr(), + output.get_bfloat16_ptr(), + stream); } else { assert(false && "Unsupported data type"); } diff --git a/src/ops/kernels/rms_norm_kernels.cu b/src/ops/kernels/rms_norm_kernels.cu index 67a0b156b5..19661724d2 100644 --- a/src/ops/kernels/rms_norm_kernels.cu +++ b/src/ops/kernels/rms_norm_kernels.cu @@ -137,7 +137,7 @@ __global__ void template __global__ void NormKernel(int64_t N, T const *X, T const *rstd, T *Y) { using T_ACC = T; - int64_t const i = blockIdx.x; + const int64_t i = blockIdx.x; for (int64_t j = threadIdx.x; j < N; j += blockDim.x) { int64_t const index = i * N + j; Y[index] = static_cast(X[index]) * static_cast(rstd[i]); diff --git a/src/ops/kernels/softmax.cpp b/src/ops/kernels/softmax.cpp index 89c9f14a01..e785d404d7 100644 --- a/src/ops/kernels/softmax.cpp +++ b/src/ops/kernels/softmax.cpp @@ -107,7 +107,10 @@ template void forward_kernel_wrapper(SoftmaxMeta const *m, template void forward_kernel_wrapper(SoftmaxMeta const *m, half const *input_ptr, half *output_ptr); - +template void + forward_kernel_wrapper(SoftmaxMeta const *m, + hip_bfloat16 const *input_ptr, + hip_bfloat16 *output_ptr); template void backward_kernel_wrapper(SoftmaxMeta const *m, float *input_grad_ptr, float const *output_grad_ptr, @@ -116,7 +119,11 @@ template void backward_kernel_wrapper(SoftmaxMeta const *m, half *input_grad_ptr, half const *output_grad_ptr, size_t num_elements); - +template void + backward_kernel_wrapper(SoftmaxMeta const *m, + hip_bfloat16 *input_grad_ptr, + hip_bfloat16 const *output_grad_ptr, + size_t num_elements); namespace Internal { template void forward_kernel(SoftmaxMeta const *m, diff --git a/src/ops/layer_norm.cpp b/src/ops/layer_norm.cpp index 07dbdb3dfb..ff420901e4 100644 --- a/src/ops/layer_norm.cpp +++ b/src/ops/layer_norm.cpp @@ -182,6 +182,15 @@ void LayerNorm::forward_kernel_wrapper(LayerNormMeta const *m, gamma.get_half_ptr(), m->use_bias ? beta.get_half_ptr() : nullptr, stream); + } else if (m->input_type[0] == DT_B16) { + LayerNorm::forward_kernel( + m, + input.get_bfloat16_ptr(), + output.get_bfloat16_ptr(), + m->elementwise_affine ? gamma.get_bfloat16_ptr() : nullptr, + (m->elementwise_affine && m->use_bias) ? 
beta.get_bfloat16_ptr() + : nullptr, + stream); } else { assert(false && "unsupport datatype in layernorm"); } diff --git a/src/ops/linear.cc b/src/ops/linear.cc index 8cdc392c96..cce2fb9c9e 100644 --- a/src/ops/linear.cc +++ b/src/ops/linear.cc @@ -157,7 +157,7 @@ Op *Linear::create_operator_from_layer( Linear::Linear(FFModel &model, Linear const &other, - ParallelTensor const input, + const ParallelTensor input, bool allocate_weights) : Linear(model, other.layer_guid, @@ -175,7 +175,7 @@ Linear::Linear(FFModel &model, Linear::Linear(FFModel &model, LinearParams const ¶ms, - ParallelTensor const input, + const ParallelTensor input, char const *name, bool allocate_weights) : Linear(model, @@ -439,10 +439,10 @@ OpMeta *Linear::init_task(Task const *task, } \ } else if (output.data_type == DT_B16) { \ if (linear->quantization_type != DT_NONE) { \ - return init_task_with_dim<__nv_bfloat16, char, DIM>( \ + return init_task_with_dim<__ff_bfloat16, char, DIM>( \ task, regions, ctx, runtime); \ } else { \ - return init_task_with_dim<__nv_bfloat16, __nv_bfloat16, DIM>( \ + return init_task_with_dim<__ff_bfloat16, __ff_bfloat16, DIM>( \ task, regions, ctx, runtime); \ } \ } else { \ @@ -714,10 +714,10 @@ void Linear::forward_task(Task const *task, } \ } else if (m->output_type[0] == DT_B16) { \ if (m->quantization_type != DT_NONE) { \ - return forward_task_with_dim<__nv_bfloat16, char, DIM>( \ + return forward_task_with_dim<__ff_bfloat16, char, DIM>( \ task, regions, ctx, runtime); \ } else { \ - return forward_task_with_dim<__nv_bfloat16, __nv_bfloat16, DIM>( \ + return forward_task_with_dim<__ff_bfloat16, __ff_bfloat16, DIM>( \ task, regions, ctx, runtime); \ } \ } else { \ @@ -876,7 +876,7 @@ void Linear::backward_task(Task const *task, } else if (m->output_type[0] == DT_FLOAT) { \ return backward_task_with_dim(task, regions, ctx, runtime); \ } else if (m->output_type[0] == DT_B16) { \ - return backward_task_with_dim<__nv_bfloat16, DIM>( \ + return backward_task_with_dim<__ff_bfloat16, DIM>( \ task, regions, ctx, runtime); \ } else { \ assert(false && "Unsupported data type"); \ @@ -1360,7 +1360,7 @@ bool LinearParams::is_valid(ParallelTensorShape const &input_shape) const { * It takes a the input tensor as a parameter, instead of the input's * ParallelTensorShape. */ -void LinearParams::solve_dims(ParallelTensor const input, +void LinearParams::solve_dims(const ParallelTensor input, ParallelDim output_dims[MAX_TENSOR_DIM], int *output_ndims, ParallelDim kernel_dims[MAX_TENSOR_DIM], diff --git a/src/ops/residual_layer_norm.cpp b/src/ops/residual_layer_norm.cpp index f1b7a537b0..09823adf27 100644 --- a/src/ops/residual_layer_norm.cpp +++ b/src/ops/residual_layer_norm.cpp @@ -230,6 +230,18 @@ void ResidualLayerNorm::inference_kernel_wrapper( m->elementwise_affine ? gamma.get_half_ptr() : nullptr, (m->elementwise_affine && m->use_bias) ? beta.get_half_ptr() : nullptr, stream); + } else if (m->input_type[0] == DT_B16) { + ResidualLayerNorm::inference_kernel( + m, + input.get_bfloat16_ptr(), + residual1.get_bfloat16_ptr(), + m->use_two_residuals ? residual2.get_bfloat16_ptr() : nullptr, + added_output.get_bfloat16_ptr(), + output.get_bfloat16_ptr(), + m->elementwise_affine ? gamma.get_bfloat16_ptr() : nullptr, + (m->elementwise_affine && m->use_bias) ? 
beta.get_bfloat16_ptr() + : nullptr, + stream); } else { assert(false && "unsupport datatype in layernorm"); } diff --git a/src/ops/residual_layer_norm.cu b/src/ops/residual_layer_norm.cu index cbb7e241f3..fc0bd5d5a2 100644 --- a/src/ops/residual_layer_norm.cu +++ b/src/ops/residual_layer_norm.cu @@ -234,9 +234,10 @@ void ResidualLayerNorm::inference_kernel_wrapper( added_output.get_bfloat16_ptr(), output.get_bfloat16_ptr(), m->elementwise_affine ? gamma.get_bfloat16_ptr() : nullptr, - (m->elementwise_affine && m->use_bias) ? beta.get_bfloat16_ptr() : nullptr, + (m->elementwise_affine && m->use_bias) ? beta.get_bfloat16_ptr() + : nullptr, stream); - }else { + } else { assert(false && "unsupport datatype in layernorm"); } diff --git a/src/ops/sigmoid_silu_multi.cpp b/src/ops/sigmoid_silu_multi.cpp index 7b7f30a288..b446ed0065 100644 --- a/src/ops/sigmoid_silu_multi.cpp +++ b/src/ops/sigmoid_silu_multi.cpp @@ -101,6 +101,14 @@ void SigmoidSiluMulti::inference_kernel_wrapper( input1.get_half_ptr(), input2.get_half_ptr(), output.get_half_ptr()); + } else if (m->input_type[0] == DT_B16) { + SigmoidSiluMultiKernel<<>>(input1.domain.get_volume(), + input1.get_bfloat16_ptr(), + input2.get_bfloat16_ptr(), + output.get_bfloat16_ptr()); } else { assert(false && "unsupport datatype in SigmoidSiluMulti"); } diff --git a/src/ops/softmax.cc b/src/ops/softmax.cc index 101eff9417..87598de634 100644 --- a/src/ops/softmax.cc +++ b/src/ops/softmax.cc @@ -52,7 +52,7 @@ SoftmaxParams Softmax::get_params() const { return params; } -Tensor FFModel::softmax(Tensor const _input, +Tensor FFModel::softmax(const Tensor _input, int dim, DataType data_type, char const *name) { @@ -93,7 +93,7 @@ Op *Softmax::create_operator_from_layer( } Softmax::Softmax(FFModel &model, - ParallelTensor const _input, + const ParallelTensor _input, int _dim, char const *name) : Op(model, @@ -117,7 +117,7 @@ Softmax::Softmax(FFModel &model, Softmax::Softmax(FFModel &model, SoftmaxParams const ¶ms, - ParallelTensor const input, + const ParallelTensor input, char const *name) : Softmax(model, input, params.dim, name) {} @@ -370,7 +370,7 @@ void Softmax::backward_task(Task const *task, } else if (m->output_type == DT_FLOAT) { \ return backward_task_with_dim(task, regions, ctx, runtime); \ } else if (m->output_type == DT_B16) { \ - return backward_task_with_dim<__nv_bfloat16, DIM>( \ + return backward_task_with_dim<__ff_bfloat16, DIM>( \ task, regions, ctx, runtime); \ } else { \ assert(false && "Unsupported data type"); \ diff --git a/src/ops/spec_inc_multihead_self_attention.cpp b/src/ops/spec_inc_multihead_self_attention.cpp index b1687d12a2..a7b8d4535c 100644 --- a/src/ops/spec_inc_multihead_self_attention.cpp +++ b/src/ops/spec_inc_multihead_self_attention.cpp @@ -562,6 +562,19 @@ void SpecIncMultiHeadSelfAttention::inference_kernel_wrapper( output.get_float_ptr(), bias_ptr, stream); + } else if (input.data_type == DT_B16) { + float const *bias_ptr = use_bias + ? 
bias.get_bfloat16_ptr() + : static_cast(nullptr); + Kernels::SpecIncMultiHeadAttention::inference_kernel( + m, + bc, + shard_id, + input.get_bfloat16_ptr(), + weight.get_bfloat16_ptr(), + output.get_bfloat16_ptr(), + bias_ptr, + stream); } else { assert(false && "Unspported data type"); } diff --git a/src/ops/spec_inc_multihead_self_attention.cu b/src/ops/spec_inc_multihead_self_attention.cu index cc1a08ab9b..885a09c980 100644 --- a/src/ops/spec_inc_multihead_self_attention.cu +++ b/src/ops/spec_inc_multihead_self_attention.cu @@ -480,7 +480,7 @@ void compute_attention_kernel_prompt(SpecIncMultiHeadSelfAttentionMeta const *m, cublasComputeType_t compute_type = CUBLAS_COMPUTE_16F; if (m->output_type[0] == DT_FLOAT) { compute_type = CUBLAS_COMPUTE_32F_FAST_16F; - }else if (m->output_type[0] == DT_B16) { + } else if (m->output_type[0] == DT_B16) { compute_type = CUBLAS_COMPUTE_32F; } #endif @@ -535,7 +535,8 @@ void compute_attention_kernel_prompt(SpecIncMultiHeadSelfAttentionMeta const *m, int strideC = num_new_tokens * total_tokens; // a flag of using this scaling alpha - float alpha = 1.0f, beta = 0.0f; + typename cublasAlphaBetaType
<DT>::type alpha = 1.0; + typename cublasAlphaBetaType<DT>
::type beta = 0.0; if (*m->qk_prod_scaling) { alpha = static_cast<DT>
(1.0f / sqrt(m->kProjSize)); } @@ -791,9 +792,10 @@ void SpecIncMultiHeadSelfAttention::inference_kernel_wrapper( output.get_float_ptr(), bias_ptr, stream); - }else if (input.data_type == DT_B16) { + } else if (input.data_type == DT_B16) { __nv_bfloat16 const *bias_ptr = - use_bias ? bias.get_bfloat16_ptr() : static_cast<__nv_bfloat16 const *>(nullptr); + use_bias ? bias.get_bfloat16_ptr() + : static_cast<__nv_bfloat16 const *>(nullptr); Kernels::SpecIncMultiHeadSelfAttention::inference_kernel( m, bc, diff --git a/src/ops/tree_inc_multihead_self_attention.cpp b/src/ops/tree_inc_multihead_self_attention.cpp index 26291fb3b4..4f399aafc5 100644 --- a/src/ops/tree_inc_multihead_self_attention.cpp +++ b/src/ops/tree_inc_multihead_self_attention.cpp @@ -570,6 +570,24 @@ void TreeIncMultiHeadSelfAttention::inference_kernel_wrapper( output.get_float_ptr(), bias_ptr, stream); + } else if (input.data_type == DT_B16) { + if (m->offload) { + pre_build_weight_kernel(m, weight, input.data_type, stream); + } + + hip_bfloat16 const *bias_ptr = + use_bias ? bias.get_bfloat16_ptr() + : static_cast(nullptr); + Kernels::TreeIncMultiHeadAttention::inference_kernel( + m, + bc, + shard_id, + input.get_bfloat16_ptr(), + m->offload ? static_cast(m->weight_ptr) + : weight.get_bfloat16_ptr(), + output.get_bfloat16_ptr(), + bias_ptr, + stream); } else { assert(false && "Unspported data type"); } diff --git a/src/parallel_ops/combine.cc b/src/parallel_ops/combine.cc index ffe57006fc..8bd963b3d3 100644 --- a/src/parallel_ops/combine.cc +++ b/src/parallel_ops/combine.cc @@ -72,7 +72,7 @@ Combine::Combine(FFModel &model, name) {} Combine::Combine(FFModel &model, - ParallelTensor const _input, + const ParallelTensor _input, int _combine_legion_dim, int _combine_degree, char const *name) @@ -366,7 +366,7 @@ void Combine::forward_task(Task const *task, if (data_type == DT_HALF) { forward_task_with_type(task, regions, ctx, runtime); } else if (data_type == DT_B16) { - forward_task_with_type<__nv_bfloat16>(task, regions, ctx, runtime); + forward_task_with_type<__ff_bfloat16>(task, regions, ctx, runtime); } else if (data_type == DT_FLOAT) { forward_task_with_type(task, regions, ctx, runtime); } else if (data_type == DT_DOUBLE) { diff --git a/src/parallel_ops/kernels/combine_kernels.cpp b/src/parallel_ops/kernels/combine_kernels.cpp index d6e9568223..b593a92a7d 100644 --- a/src/parallel_ops/kernels/combine_kernels.cpp +++ b/src/parallel_ops/kernels/combine_kernels.cpp @@ -57,6 +57,9 @@ template void forward_kernel(half const *input_ptr, template void forward_kernel(float const *input_ptr, float *output_ptr, size_t num_elements); +template void forward_kernel(hip_bfloat16 const *input_ptr, + hip_bfloat16 *output_ptr, + size_t num_elements); template void forward_kernel(double const *input_ptr, double *output_ptr, size_t num_elements); diff --git a/src/parallel_ops/kernels/reduction_kernels.cpp b/src/parallel_ops/kernels/reduction_kernels.cpp index 2a3fe5cca1..ade5d9b402 100644 --- a/src/parallel_ops/kernels/reduction_kernels.cpp +++ b/src/parallel_ops/kernels/reduction_kernels.cpp @@ -78,6 +78,12 @@ template __global__ void reduction_forward_kernel(half const *input_ptr, half *output_ptr, size_t num_elements, size_t num_replicas); +template __global__ void + reduction_forward_kernel(hip_bfloat16 const *input_ptr, + hip_bfloat16 *output_ptr, + size_t num_elements, + size_t num_replicas); + template void forward_kernel(float const *input_ptr, float *output_ptr, size_t num_elements, @@ -86,6 +92,10 @@ template void 
forward_kernel(half const *input_ptr, half *output_ptr, size_t num_elements, size_t num_replicas); +template void forward_kernel(hip_bfloat16 const *input_ptr, + hip_bfloat16 *output_ptr, + size_t num_elements, + size_t num_replicas); template void backward_kernel(float const *output_grad_ptr, float *input_grad_ptr, size_t num_elements); diff --git a/src/parallel_ops/kernels/replicate_kernels.cpp b/src/parallel_ops/kernels/replicate_kernels.cpp index 1647f014be..7d6b2fc63a 100644 --- a/src/parallel_ops/kernels/replicate_kernels.cpp +++ b/src/parallel_ops/kernels/replicate_kernels.cpp @@ -73,6 +73,9 @@ template void forward_kernel(float const *input_ptr, template void forward_kernel(half const *input_ptr, half *output_ptr, size_t num_elements); +template void forward_kernel(hip_bfloat16 const *input_ptr, + hip_bfloat16 *output_ptr, + size_t num_elements); template __global__ void replicate_backward_kernel(float const *input_ptr, float *output_ptr, diff --git a/src/parallel_ops/reduction.cc b/src/parallel_ops/reduction.cc index d53cae330b..bd7cb4085b 100644 --- a/src/parallel_ops/reduction.cc +++ b/src/parallel_ops/reduction.cc @@ -60,7 +60,7 @@ ReductionParams Reduction::get_params() const { } Reduction::Reduction(FFModel &model, - ParallelTensor const _input, + const ParallelTensor _input, int _reduction_legion_dim, int _reduction_degree, char const *name) @@ -381,7 +381,7 @@ void Reduction::forward_task(Task const *task, num_elements, num_replicas); } else if (input.data_type == DT_B16) { - forward_kernel<__nv_bfloat16>(input.get_bfloat16_ptr(), + forward_kernel<__ff_bfloat16>(input.get_bfloat16_ptr(), output.get_bfloat16_ptr(), num_elements, num_replicas); diff --git a/src/parallel_ops/replicate.cc b/src/parallel_ops/replicate.cc index d9f7b2263b..18e4d26f26 100644 --- a/src/parallel_ops/replicate.cc +++ b/src/parallel_ops/replicate.cc @@ -59,7 +59,7 @@ ReplicateParams Replicate::get_params() const { } Replicate::Replicate(FFModel &model, - ParallelTensor const _input, + const ParallelTensor _input, int _replicate_legion_dim, int _replicate_degree, char const *name) @@ -374,7 +374,7 @@ void Replicate::forward_task(Task const *task, output.get_float_ptr(), input_domain.get_volume()); } else if (input.data_type == DT_B16) { - forward_kernel<__nv_bfloat16>(input.get_bfloat16_ptr(), + forward_kernel<__ff_bfloat16>(input.get_bfloat16_ptr(), output.get_bfloat16_ptr(), input_domain.get_volume()); } else { diff --git a/src/runtime/accessor.cc b/src/runtime/accessor.cc index 9bb6c08eea..30aef93df9 100644 --- a/src/runtime/accessor.cc +++ b/src/runtime/accessor.cc @@ -77,12 +77,12 @@ half const *GenericTensorAccessorR::get_half_ptr() const { } } -__nv_bfloat16 const *GenericTensorAccessorR::get_bfloat16_ptr() const { +__ff_bfloat16 const *GenericTensorAccessorR::get_bfloat16_ptr() const { if (data_type == DT_B16) { - return static_cast<__nv_bfloat16 const *>(ptr); + return static_cast<__ff_bfloat16 const *>(ptr); } else { assert(false && "Invalid Accessor Type"); - return static_cast<__nv_bfloat16 const *>(nullptr); + return static_cast<__ff_bfloat16 const *>(nullptr); } } @@ -174,12 +174,12 @@ half *GenericTensorAccessorW::get_half_ptr() const { } } -__nv_bfloat16 *GenericTensorAccessorW::get_bfloat16_ptr() const { +__ff_bfloat16 *GenericTensorAccessorW::get_bfloat16_ptr() const { if (data_type == DT_B16) { - return static_cast<__nv_bfloat16 *>(ptr); + return static_cast<__ff_bfloat16 *>(ptr); } else { assert(false && "Invalid Accessor Type"); - return static_cast<__nv_bfloat16 *>(nullptr); + 
return static_cast<__ff_bfloat16 *>(nullptr); } } @@ -290,7 +290,7 @@ GenericTensorAccessorR break; } case DT_B16: { - ptr = helperGetTensorPointerRO<__nv_bfloat16>( + ptr = helperGetTensorPointerRO<__ff_bfloat16>( region, req, fid, ctx, runtime); break; } @@ -341,7 +341,7 @@ GenericTensorAccessorW break; } case DT_B16: { - ptr = helperGetTensorPointerWO<__nv_bfloat16>( + ptr = helperGetTensorPointerWO<__ff_bfloat16>( region, req, fid, ctx, runtime); break; } @@ -392,7 +392,7 @@ GenericTensorAccessorW break; } case DT_B16: { - ptr = helperGetTensorPointerRW<__nv_bfloat16>( + ptr = helperGetTensorPointerRW<__ff_bfloat16>( region, req, fid, ctx, runtime); break; } diff --git a/src/runtime/ffconst_utils.cc b/src/runtime/ffconst_utils.cc index 764c938dc3..232d711d82 100644 --- a/src/runtime/ffconst_utils.cc +++ b/src/runtime/ffconst_utils.cc @@ -214,7 +214,7 @@ size_t data_type_size(DataType type) { case DT_HALF: return sizeof(half); case DT_B16: - return sizeof(__nv_bfloat16); + return sizeof(__ff_bfloat16); case DT_FLOAT: return sizeof(float); case DT_DOUBLE: diff --git a/src/runtime/hip_helper.cpp b/src/runtime/hip_helper.cpp index fb94135c8f..00a6730de0 100644 --- a/src/runtime/hip_helper.cpp +++ b/src/runtime/hip_helper.cpp @@ -299,6 +299,33 @@ __host__ void checkCUDA(hipHostFree(host_ptr)); } +template <> +__host__ void save_tensor(hip_bfloat16 const *ptr, + size_t num_elements, + char const *file_name) { + hipStream_t stream; + checkCUDA(get_legion_stream(&stream)); + hip_bfloat16 *host_ptr; + checkCUDA(hipHostMalloc(&host_ptr, + sizeof(hip_bfloat16) * num_elements, + hipHostMallocPortable | hipHostMallocMapped)); + checkCUDA(hipMemcpyAsync(host_ptr, + ptr, + sizeof(hip_bfloat16) * num_elements, + hipMemcpyDeviceToHost, + stream)); + checkCUDA(hipDeviceSynchronize()); + FILE *tensor_file; + tensor_file = fopen(file_name, "w"); + assert(tensor_file != NULL); + for (unsigned i = 0; i < num_elements; i++) { + fprintf(tensor_file, "%.9f, ", (float)host_ptr[i]); + } + + fclose(tensor_file); + checkCUDA(hipHostFree(host_ptr)); +} + template <> __host__ void save_tensor(int32_t const *ptr, size_t num_elements, @@ -489,6 +516,8 @@ miopenDataType_t ff_to_cudnn_datatype(DataType type) { switch (type) { case DT_HALF: return miopenHalf; + case DT_B16: + return miopenBFloat16; case DT_FLOAT: return miopenFloat; case DT_DOUBLE: @@ -510,6 +539,10 @@ hipblasDatatype_t ff_to_cuda_datatype(DataType type) { return HIPBLAS_R_64F; case DT_INT32: return HIPBLAS_R_32I; + case DT_B16: + return HIPBLAS_R_16B; + case DT_HALF: + return HIPBLAS_R_16F; default: assert(false && "Unspoorted cuda data type"); } @@ -520,6 +553,8 @@ ncclDataType_t ff_to_nccl_datatype(DataType type) { switch (type) { case DT_HALF: return ncclHalf; + case DT_B16: + return ncclBfloat16; case DT_FLOAT: return ncclFloat; case DT_DOUBLE: @@ -540,6 +575,9 @@ void handle_unimplemented_hip_kernel(OperatorType op_type) { template __global__ void assign_kernel(half *ptr, coord_t size, half value); +template __global__ void assign_kernel(hip_bfloat16 *ptr, + coord_t size, + hip_bfloat16 value); template __global__ void assign_kernel(float *ptr, coord_t size, float value); template __global__ void @@ -609,7 +647,9 @@ template __host__ void save_tensor(int64_t const *ptr, char const *file_name); template __host__ void save_tensor(half const *ptr, size_t rect, char const *file_name); - +template __host__ void save_tensor(hip_bfloat16 const *ptr, + size_t rect, + char const *file_name); template __host__ float *download_tensor(float const *ptr, size_t 
num_elements); template __host__ half *download_tensor(half const *ptr, diff --git a/src/runtime/initializer_kernel.cpp b/src/runtime/initializer_kernel.cpp index 1005d93cec..2e55591729 100644 --- a/src/runtime/initializer_kernel.cpp +++ b/src/runtime/initializer_kernel.cpp @@ -259,6 +259,17 @@ void ZeroInitializer::init_task(Task const *task, w, domain.get_volume(), 0.0f); + } else if (meta->data_types[i] == DT_B16) { + hip_bfloat16 *w = helperGetTensorPointerWO( + regions[i], task->regions[i], FID_DATA, ctx, runtime); + hipLaunchKernelGGL(HIP_KERNEL_NAME(assign_kernel), + GET_BLOCKS(domain.get_volume()), + CUDA_NUM_THREADS, + 0, + stream, + w, + domain.get_volume(), + 0.0f); } else if (meta->data_types[i] == DT_INT32) { int32_t *w = helperGetTensorPointerWO( regions[i], task->regions[i], FID_DATA, ctx, runtime); diff --git a/src/runtime/model.cc b/src/runtime/model.cc index f31be1bcdb..f184a41934 100644 --- a/src/runtime/model.cc +++ b/src/runtime/model.cc @@ -1948,7 +1948,7 @@ void FFModel::map_tensor_with_dim2(ParallelTensor tensor, allocator.allocate_field(sizeof(half), FID_DATA); break; case DT_B16: - allocator.allocate_field(sizeof(__nv_bfloat16), FID_DATA); + allocator.allocate_field(sizeof(__ff_bfloat16), FID_DATA); break; case DT_FLOAT: allocator.allocate_field(sizeof(float), FID_DATA); diff --git a/src/runtime/parallel_tensor.cc b/src/runtime/parallel_tensor.cc index e329593970..d3cc8a0cb1 100644 --- a/src/runtime/parallel_tensor.cc +++ b/src/runtime/parallel_tensor.cc @@ -847,10 +847,10 @@ template bool ParallelTensorBase::get_tensor(FFModel const *ff, half *data, bool get_gradients); -template bool ParallelTensorBase::set_tensor<__nv_bfloat16>( - FFModel const *ff, std::vector const &dims, __nv_bfloat16 const *data); -template bool ParallelTensorBase::get_tensor<__nv_bfloat16>(FFModel const *ff, - __nv_bfloat16 *data, +template bool ParallelTensorBase::set_tensor<__ff_bfloat16>( + FFModel const *ff, std::vector const &dims, __ff_bfloat16 const *data); +template bool ParallelTensorBase::get_tensor<__ff_bfloat16>(FFModel const *ff, + __ff_bfloat16 *data, bool get_gradients); template bool ParallelTensorBase::set_tensor(FFModel const *ff,
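
Note on the cublasAlphaBetaType change (illustrative sketch, not part of the patch): the recurring `typename cublasAlphaBetaType<DT>::type alpha` pattern above replaces the hard-coded `float alpha = 1.0f, beta = 0.0f` so that the GEMM scaling factors match the compute type per data type: the FP16 path takes half alpha/beta, while bfloat16 inputs run with 32-bit accumulation (CUBLAS_COMPUTE_32F above), so alpha/beta stay float. The standalone C++ sketch below mirrors that trait; `fake_half` and `fake_bfloat16` are hypothetical stand-ins for half and __nv_bfloat16/hip_bfloat16 so it compiles without CUDA or HIP headers, and `DT` is assumed to be the kernels' template parameter.

// Illustrative sketch only: mirrors the cublasAlphaBetaType trait added in
// include/flexflow/utils/cuda_helper.h, using plain host types.
#include <iostream>
#include <type_traits>

struct fake_half {};      // stand-in for half
struct fake_bfloat16 {};  // stand-in for __nv_bfloat16 / hip_bfloat16

template <typename T>
struct cublasAlphaBetaType {
  using type = float; // default: float alpha/beta (e.g. FP32 GEMM)
};
template <>
struct cublasAlphaBetaType<fake_half> {
  using type = fake_half; // FP16 compute path: alpha/beta are half
};
template <>
struct cublasAlphaBetaType<fake_bfloat16> {
  using type = float; // BF16 inputs accumulate in FP32: alpha/beta stay float
};

// Usage pattern the patch substitutes for `float alpha = 1.0f, beta = 0.0f`.
template <typename DT>
void report_gemm_scalar_type(char const *name) {
  using scalar_t = typename cublasAlphaBetaType<DT>::type;
  std::cout << name << ": alpha/beta are "
            << (std::is_same<scalar_t, float>::value ? "float" : "half")
            << std::endl;
}

int main() {
  report_gemm_scalar_type<float>("float");             // float
  report_gemm_scalar_type<fake_half>("half");          // half
  report_gemm_scalar_type<fake_bfloat16>("bfloat16");  // float
  return 0;
}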