Commit

hip
xinhaoc committed Jan 15, 2024
1 parent badbf53 commit fbdaba2
Showing 38 changed files with 511 additions and 63 deletions.
12 changes: 9 additions & 3 deletions include/flexflow/accessor.h
@@ -8,7 +8,7 @@
#include <cuda_bf16.h>
#elif defined(FF_USE_HIP_CUDA)
#include <cuda_fp16.h>
#include <hip_bfloat16.h>
#include <cuda_bf16.h>
#elif defined(FF_USE_HIP_ROCM)
#include <hip/hip_fp16.h>
#include <hip_bfloat16.h>
@@ -18,6 +18,12 @@

namespace FlexFlow {

#if defined(FF_USE_CUDA) || defined(FF_USE_HIP_CUDA)
typedef __nv_bfloat16 __ff_bfloat16;
#elif defined(FF_USE_HIP_ROCM)
typedef hip_bfloat16 __ff_bfloat16;
#endif

template <typename FT, int N, typename T = Legion::coord_t>
using AccessorRO =
Legion::FieldAccessor<READ_ONLY, FT, N, T, Realm::AffineAccessor<FT, N, T>>;
@@ -64,7 +70,7 @@ class GenericTensorAccessorW {
float *get_float_ptr() const;
double *get_double_ptr() const;
half *get_half_ptr() const;
__nv_bfloat16 *get_bfloat16_ptr() const;
__ff_bfloat16 *get_bfloat16_ptr() const;
char *get_byte_ptr() const;
DataType data_type;
Legion::Domain domain;
@@ -84,7 +90,7 @@ class GenericTensorAccessorR {
float const *get_float_ptr() const;
double const *get_double_ptr() const;
half const *get_half_ptr() const;
__nv_bfloat16 const *get_bfloat16_ptr() const;
__ff_bfloat16 const *get_bfloat16_ptr() const;
char const *get_byte_ptr() const;
DataType data_type;
Legion::Domain domain;
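Note: the new __ff_bfloat16 alias above resolves to __nv_bfloat16 on CUDA and HIP-on-CUDA builds and to hip_bfloat16 on ROCm builds, so the accessor getters expose a single bfloat16 pointer type regardless of backend. A minimal usage sketch, not part of the commit, assuming the build defines one of the FF_USE_* macros so accessor.h picks the underlying type:

#include "flexflow/accessor.h"

// Convert a backend-neutral bfloat16 value to float; both __nv_bfloat16 and
// hip_bfloat16 provide a float conversion operator, so the same code compiles
// on either backend.
float bf16_to_float(FlexFlow::__ff_bfloat16 x) {
  return static_cast<float>(x);
}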
14 changes: 14 additions & 0 deletions include/flexflow/utils/cuda_helper.h
@@ -161,6 +161,20 @@ T *download_tensor(T const *ptr, size_t num_elements);
template <typename T>
bool download_tensor(T const *ptr, T *dst, size_t num_elements);

// alpha/beta scalar type to pass to cublasGemm-style calls
template <typename T>
struct cublasAlphaBetaType {
using type = float; // default
};
template <>
struct cublasAlphaBetaType<half> {
using type = half;
};
template <>
struct cublasAlphaBetaType<__nv_bfloat16> {
using type = float;
};

cudnnStatus_t cudnnSetTensorDescriptorFromDomain(cudnnTensorDescriptor_t tensor,
Legion::Domain domain,
DataType data_type = DT_FLOAT);
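Note: cublasGemmEx expects alpha/beta in the scalar type implied by the compute type (half for CUBLAS_COMPUTE_16F, float for CUBLAS_COMPUTE_32F), which is why the trait maps __nv_bfloat16 to float. A hedged sketch of how the trait would be used; gemm_sketch, its arguments, and the include of cuda_helper.h are illustrative assumptions, not the repository's API:

#include <cublas_v2.h>
#include <cuda_bf16.h>
#include "flexflow/utils/cuda_helper.h" // assumed to expose cublasAlphaBetaType

// Run C = A * B with alpha/beta declared in the scalar type cuBLAS expects
// for the chosen compute type.
template <typename DT>
cublasStatus_t gemm_sketch(cublasHandle_t handle,
                           DT const *A, DT const *B, DT *C,
                           int m, int n, int k,
                           cudaDataType_t data_type,
                           cublasComputeType_t compute_type) {
  typename cublasAlphaBetaType<DT>::type alpha = 1.0, beta = 0.0;
  return cublasGemmEx(handle, CUBLAS_OP_N, CUBLAS_OP_N,
                      m, n, k,
                      &alpha,
                      A, data_type, /*lda=*/m,
                      B, data_type, /*ldb=*/k,
                      &beta,
                      C, data_type, /*ldc=*/m,
                      compute_type, CUBLAS_GEMM_DEFAULT);
}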
12 changes: 6 additions & 6 deletions inference/file_loader.cc
@@ -127,7 +127,7 @@ void load_attention_weights_multi_query(DT *ptr,

///////////////////////bfloat16 function///////////////////////

void load_from_file_b16(__nv_bfloat16 *ptr, size_t size, std::string filepath) {
void load_from_file_b16(__ff_bfloat16 *ptr, size_t size, std::string filepath) {
std::ifstream in(filepath, std::ios::in | std::ios::binary);
if (!in.good()) {
std::cout << "Could not open file: " << filepath << std::endl;
@@ -156,7 +156,7 @@ void load_from_file_b16(__nv_bfloat16 *ptr, size_t size, std::string filepath) {
in.close();
}

void load_attention_weights_v2_b16(__nv_bfloat16 *ptr,
void load_attention_weights_v2_b16(__ff_bfloat16 *ptr,
int num_heads,
int num_kv_heads,
size_t hidden_dim,
@@ -290,7 +290,7 @@ void load_attention_weights_v2_b16(__nv_bfloat16 *ptr,
}
}

void load_attention_bias_v2_b16(__nv_bfloat16 *ptr,
void load_attention_bias_v2_b16(__ff_bfloat16 *ptr,
int num_heads,
int num_kv_heads,
size_t hidden_dim,
@@ -382,8 +382,8 @@ void FileDataLoader::load_single_weight_tensor_b16(FFModel *ff,
dims_vec.push_back(weight->dims[i]);
volume *= weight->dims[i];
}
assert(data_type_size(weight->data_type) == sizeof(__nv_bfloat16));
__nv_bfloat16 *data = (__nv_bfloat16 *)malloc(sizeof(__nv_bfloat16) * volume);
assert(data_type_size(weight->data_type) == sizeof(__ff_bfloat16));
__ff_bfloat16 *data = (__ff_bfloat16 *)malloc(sizeof(__ff_bfloat16) * volume);

std::string weight_filename = removeGuidOperatorName(std::string(l->name));

@@ -446,7 +446,7 @@ void FileDataLoader::load_single_weight_tensor_b16(FFModel *ff,
// Copy the weight data from the buffer to the weight's ParallelTensor
ParallelTensor weight_pt;
ff->get_parallel_tensor_from_tensor(weight, weight_pt);
weight_pt->set_tensor<__nv_bfloat16>(ff, dims_vec, data);
weight_pt->set_tensor<__ff_bfloat16>(ff, dims_vec, data);

// Free buffer memory
delete data;
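Note: the bfloat16 loader paths above only change the pointer type; the on-disk format remains raw 2-byte bfloat16 values. A hedged sketch of the kind of read load_from_file_b16 performs (helper name and signature are illustrative, not the repository's):

#include <cassert>
#include <fstream>
#include <string>
#include <vector>
#include "flexflow/accessor.h" // for FlexFlow::__ff_bfloat16

// Read `count` raw bfloat16 values from a binary weight file into a buffer.
std::vector<FlexFlow::__ff_bfloat16> read_bf16_weights(std::string const &path,
                                                       size_t count) {
  std::ifstream in(path, std::ios::in | std::ios::binary);
  assert(in.good() && "Could not open weight file");
  std::vector<FlexFlow::__ff_bfloat16> buf(count);
  in.read(reinterpret_cast<char *>(buf.data()),
          static_cast<std::streamsize>(count * sizeof(FlexFlow::__ff_bfloat16)));
  assert(static_cast<size_t>(in.gcount()) ==
         count * sizeof(FlexFlow::__ff_bfloat16));
  return buf;
}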
12 changes: 12 additions & 0 deletions src/ops/arg_topk.cpp
@@ -515,6 +515,18 @@ void ArgTopK::forward_kernel_wrapper(ArgTopKMeta const *m,
m->sorted,
m->speculative_decoding ? bc : nullptr,
stream);
} else if (input.data_type == DT_B16) {
ArgTopK::forward_kernel(m,
input.get_bfloat16_ptr(),
m->speculative_decoding ? probs.get_float_ptr()
: nullptr,
indices.get_int32_ptr(),
batch_size,
length,
k,
m->sorted,
m->speculative_decoding ? bc : nullptr,
stream);
} else {
assert(false && "Unsupported data type");
}
12 changes: 11 additions & 1 deletion src/ops/argmax.cpp
@@ -466,6 +466,16 @@ void ArgMax::forward_kernel_wrapper(ArgMaxMeta const *m,
length,
batch_size,
stream);
} else if (input.data_type == DT_B16) {
ArgMax::forward_kernel<hip_bfloat16>(m,
input.get_bfloat16_ptr(),
indices.get_int32_ptr(),
m->probs,
m->beam_search ? parent.get_int32_ptr()
: nullptr,
length,
batch_size,
stream);
} else {
assert(false && "Unsupported data type");
}
@@ -491,7 +501,7 @@
: OpMeta(handler, op) {
DataType data_type = op->data_type;
size_t prob_size = batch_size;
assert(data_type == DT_FLOAT || data_type == DT_HALF);
assert(data_type == DT_FLOAT || data_type == DT_HALF || data_type == DT_B16);
size_t total_size = prob_size * sizeof(float);
gpu_mem_allocator.create_legion_instance(reserveInst, total_size);
probs = gpu_mem_allocator.allocate_instance<float>(prob_size);
6 changes: 6 additions & 0 deletions src/ops/element_unary.cpp
@@ -314,6 +314,12 @@ template void
int64_t *output_ptr,
size_t num_elements);

template void ElementUnary::forward_kernel_wrapper<hip_bfloat16>(
ElementUnaryMeta const *m,
hip_bfloat16 const *input_ptr,
hip_bfloat16 *output_ptr,
size_t num_elements);

template void
ElementUnary::backward_kernel_wrapper<float>(ElementUnaryMeta const *m,
float const *input_ptr,
16 changes: 16 additions & 0 deletions src/ops/fused.cpp
@@ -428,6 +428,11 @@ __host__ void FusedOp::forward_task(Task const *task,
m,
my_input_accessor[0].get_float_ptr(),
my_output_accessor[0].get_float_ptr());
} else if (m->input_type == DT_B16) {
Kernels::Softmax::forward_kernel_wrapper(
m,
my_input_accessor[0].get_bfloat16_ptr(),
my_output_accessor[0].get_bfloat16_ptr());
}
break;
}
@@ -815,6 +820,12 @@ __host__ void
my_input_accessor[0].get_float_ptr(),
my_output_accessor[0].get_float_ptr(),
my_input_accessor[0].domain.get_volume());
} else if (m->data_type == DT_B16) {
ElementUnary::forward_kernel_wrapper(
m,
my_input_accessor[0].get_bfloat16_ptr(),
my_output_accessor[0].get_bfloat16_ptr(),
my_input_accessor[0].domain.get_volume());
} else {
assert(false && "Unsupported data type in ElementUnary forward");
}
@@ -1039,6 +1050,11 @@ __host__ void
m,
my_input_accessor[0].get_float_ptr(),
my_output_accessor[0].get_float_ptr());
} else if (m->input_type == DT_B16) {
Kernels::Softmax::forward_kernel_wrapper(
m,
my_input_accessor[0].get_bfloat16_ptr(),
my_output_accessor[0].get_bfloat16_ptr());
}
break;
}
24 changes: 24 additions & 0 deletions src/ops/inc_multihead_self_attention.cpp
@@ -802,6 +802,23 @@ void IncMultiHeadSelfAttention::inference_kernel_wrapper(
output.get_float_ptr(),
bias_ptr,
stream);
} else if (input.data_type == DT_B16) {
if (m->offload) {
pre_build_weight_kernel<hip_bfloat16>(m, weight, input.data_type, stream);
}
hip_bfloat16 const *bias_ptr =
use_bias ? bias.get_bfloat16_ptr()
: static_cast<hip_bfloat16 const *>(nullptr);
Kernels::IncMultiHeadAttention::inference_kernel(
m,
bc,
shard_id,
input.get_bfloat16_ptr(),
m->offload ? static_cast<hip_bfloat16 *>(m->weight_ptr)
: weight.get_bfloat16_ptr(),
output.get_bfloat16_ptr(),
bias_ptr,
stream);
} else {
assert(false && "Unspported data type");
}
@@ -1098,4 +1115,11 @@ template void Kernels::IncMultiHeadAttention::pre_build_weight_kernel<half>(
DataType data_type,
hipStream_t stream);

template void
Kernels::IncMultiHeadAttention::pre_build_weight_kernel<hip_bfloat16>(
IncMultiHeadSelfAttentionMeta const *m,
GenericTensorAccessorR const weight,
DataType data_type,
hipStream_t stream);

}; // namespace FlexFlow
14 changes: 9 additions & 5 deletions src/ops/inc_multihead_self_attention.cu
@@ -513,7 +513,8 @@ void compute_qkv_kernel(IncMultiHeadSelfAttentionMeta const *m,

// Step 1: Compute QKV projections
{
float alpha = 1.0f, beta = 0.0f;
typename cublasAlphaBetaType<DT>::type alpha = 1.0;
typename cublasAlphaBetaType<DT>::type beta = 0.0;
// after transpositions
int m_q = m->qProjSize * m->num_q_heads;
int m_k = m->kProjSize * m->num_q_heads;
@@ -645,7 +646,8 @@ void compute_o_prod_bias(IncMultiHeadSelfAttentionMeta const *m,
#endif
// Project to output, save result directly on output tensor
{
float alpha = 1.0f, beta = 0.0f;
typename cublasAlphaBetaType<DT>::type alpha = 1.0;
typename cublasAlphaBetaType<DT>::type beta = 0.0;
// after transpositions
int m_ = m->oProjSize;
int k = m->vProjSize * m->num_q_heads;
@@ -925,7 +927,7 @@ void compute_attention_kernel_prompt(IncMultiHeadSelfAttentionMeta const *m,
cublasComputeType_t compute_type = CUBLAS_COMPUTE_16F;
if (m->output_type[0] == DT_FLOAT) {
compute_type = CUBLAS_COMPUTE_32F_FAST_16F;
}else if (m->output_type[0] == DT_B16) {
} else if (m->output_type[0] == DT_B16) {
compute_type = CUBLAS_COMPUTE_32F;
}
#endif
@@ -951,7 +953,8 @@ void compute_attention_kernel_prompt(IncMultiHeadSelfAttentionMeta const *m,
// Step 1: compute query-key product QK.T/sqrt(d_k)
{
// Scale by sqrt(d_k) as per the original attention paper
float alpha = 1.0f, beta = 0.0f;
typename cublasAlphaBetaType<DT>::type alpha = 1.0;
typename cublasAlphaBetaType<DT>::type beta = 0.0;
if (*m->qk_prod_scaling) {
alpha = static_cast<DT>(1.0f / sqrt(m->kProjSize));
}
@@ -1082,7 +1085,8 @@ void compute_attention_kernel_prompt(IncMultiHeadSelfAttentionMeta const *m,
// Step 5: Matmul softmax(QK.T/sqrt(d_k)) by V. Implemented as V @
// softmax(QK.T/sqrt(d_k)).T
{
float alpha = 1.0f, beta = 0.0f;
typename cublasAlphaBetaType<DT>::type alpha = 1.0;
typename cublasAlphaBetaType<DT>::type beta = 0.0;
// after transpositions
int m_ = m->vProjSize;
int n = num_new_tokens;
23 changes: 22 additions & 1 deletion src/ops/kernels/decompress_kernels.cpp
@@ -54,10 +54,20 @@ template __global__ void decompress_int4_general_weights<float>(
char const *input_weight_ptr, float *weight_ptr, int in_dim, int valueSize);
template __global__ void decompress_int4_general_weights<half>(
char const *input_weight_ptr, half *weight_ptr, int in_dim, int valueSize);
template __global__ void
decompress_int4_general_weights<hip_bfloat16>(char const *input_weight_ptr,
hip_bfloat16 *weight_ptr,
int in_dim,
int valueSize);
template __global__ void decompress_int8_general_weights<float>(
char const *input_weight_ptr, float *weight_ptr, int in_dim, int valueSize);
template __global__ void decompress_int8_general_weights<half>(
char const *input_weight_ptr, half *weight_ptr, int in_dim, int valueSize);
template __global__ void
decompress_int8_general_weights<hip_bfloat16>(char const *input_weight_ptr,
hip_bfloat16 *weight_ptr,
int in_dim,
int valueSize);
template __global__ void
decompress_int4_attention_weights<float>(char *input_weight_ptr,
float *weight_ptr,
@@ -71,7 +81,12 @@ template __global__ void
int qProjSize,
int qSize,
int num_heads);

template __global__ void
decompress_int4_attention_weights<hip_bfloat16>(char *input_weight_ptr,
hip_bfloat16 *weight_ptr,
int qProjSize,
int qSize,
int num_heads);
template __global__ void
decompress_int8_attention_weights<float>(char *input_weight_ptr,
float *weight_ptr,
@@ -86,5 +101,11 @@ template __global__ void
int qSize,
int num_heads);

template __global__ void
decompress_int8_attention_weights<hip_bfloat16>(char *input_weight_ptr,
hip_bfloat16 *weight_ptr,
int qProjSize,
int qSize,
int num_heads);
} // namespace Kernels
}; // namespace FlexFlow
20 changes: 18 additions & 2 deletions src/ops/kernels/element_binary_kernels.cpp
@@ -82,8 +82,24 @@ void forward_kernel_wrapper(ElementBinaryMeta const *m,
}
// print_tensor<float>(in1_ptr, in1_domain.get_volume(), "input1:");
// print_tensor<float>(in2_ptr, in2_domain.get_volume(), "input2:");
Internal::forward_kernel(
m, in1.get_float_ptr(), in2.get_float_ptr(), out.get_float_ptr(), stream);
if (out.data_type == DT_HALF) {
Internal::forward_kernel(
m, in1.get_half_ptr(), in2.get_half_ptr(), out.get_half_ptr(), stream);
} else if (out.data_type == DT_FLOAT) {
Internal::forward_kernel(m,
in1.get_float_ptr(),
in2.get_float_ptr(),
out.get_float_ptr(),
stream);
} else if (out.data_type == DT_B16) {
Internal::forward_kernel(m,
in1.get_bfloat16_ptr(),
in2.get_bfloat16_ptr(),
out.get_bfloat16_ptr(),
stream);
} else {
assert(false && "Unsupported data type");
}
// print_tensor<float>(out_ptr, in1_domain.get_volume(), "output:");
if (m->profiling) {
checkCUDA(hipEventRecord(t_end, stream));
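Note: the same dispatch shape recurs across the operators touched here (arg_topk, argmax, element_unary, fused, element_binary): branch on the accessor's DataType and hand the matching typed pointer to the templated kernel. A condensed, illustrative sketch; dispatch_forward and forward_kernel_for are stand-ins, not the repository's API:

#include <cassert>
#include <hip/hip_runtime.h>
#include "flexflow/accessor.h"

// Stand-in for the per-operator Internal::forward_kernel templates; this
// sketch just copies input to output on the given stream.
template <typename DT>
void forward_kernel_for(DT const *in, DT *out, size_t n, hipStream_t stream) {
  hipMemcpyAsync(out, in, n * sizeof(DT), hipMemcpyDeviceToDevice, stream);
}

void dispatch_forward(FlexFlow::GenericTensorAccessorR const &in,
                      FlexFlow::GenericTensorAccessorW const &out,
                      hipStream_t stream) {
  size_t n = in.domain.get_volume();
  if (out.data_type == DT_HALF) {
    forward_kernel_for(in.get_half_ptr(), out.get_half_ptr(), n, stream);
  } else if (out.data_type == DT_FLOAT) {
    forward_kernel_for(in.get_float_ptr(), out.get_float_ptr(), n, stream);
  } else if (out.data_type == DT_B16) {
    // The branch this commit adds throughout: bfloat16 tensors go through the
    // backend-neutral get_bfloat16_ptr() accessor.
    forward_kernel_for(in.get_bfloat16_ptr(), out.get_bfloat16_ptr(), n, stream);
  } else {
    assert(false && "Unsupported data type");
  }
}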
(Diffs for the remaining changed files were not loaded on this page.)