From a8294e8c8641cc7b6f25a0f6eb1629051864babe Mon Sep 17 00:00:00 2001 From: Gabriele Oliaro Date: Tue, 3 Sep 2024 14:15:52 +0000 Subject: [PATCH] fix hip build --- config/config.inc | 2 +- src/ops/add_bias_residual_layer_norm.cc | 2 +- src/ops/aggregate.cc | 2 +- src/ops/aggregate_spec.cc | 2 +- src/ops/arg_topk.cc | 2 +- src/ops/argmax.cc | 2 +- src/ops/argmax.cpp | 491 ++++++++++++++---- src/ops/argmax.cu | 4 +- src/ops/attention.cc | 2 +- src/ops/fused.cpp | 4 + src/ops/group_by.cc | 2 +- src/ops/inc_multihead_self_attention.cc | 2 +- src/ops/layer_norm.cc | 2 +- src/ops/linear.cc | 2 +- src/ops/lora_linear.cc | 2 +- src/ops/reduce.cc | 2 +- src/ops/reshape.cc | 2 +- src/ops/residual_layer_norm.cc | 2 +- src/ops/residual_rms_norm.cc | 2 +- src/ops/rms_norm.cc | 2 +- src/ops/sampling.cc | 2 +- src/ops/sigmoid_silu_multi.cc | 2 +- src/ops/softmax.cc | 2 +- src/ops/spec_inc_multihead_self_attention.cc | 2 +- src/ops/split.cc | 2 +- src/ops/topk.cc | 2 +- src/ops/transpose.cc | 2 +- src/ops/tree_inc_multihead_self_attention.cc | 2 +- src/parallel_ops/allreduce.cc | 6 +- src/parallel_ops/combine.cc | 6 +- src/parallel_ops/fused_parallel_op.cc | 2 +- .../kernels/parallel_identity_kernels.cpp | 97 ++++ src/parallel_ops/parallel_identity.cc | 6 +- src/parallel_ops/partition.cc | 6 +- src/parallel_ops/reduction.cc | 6 +- src/parallel_ops/replicate.cc | 6 +- src/runtime/request_manager.cc | 2 +- 37 files changed, 524 insertions(+), 162 deletions(-) create mode 100644 src/parallel_ops/kernels/parallel_identity_kernels.cpp diff --git a/config/config.inc b/config/config.inc index 7d7b2db9cf..6431eaf136 100644 --- a/config/config.inc +++ b/config/config.inc @@ -197,7 +197,7 @@ fi # set ROCM path if [ -n "$ROCM_PATH" ]; then - SET_ROCM_PATH="-DROCM_PATH=${ROCM_PATH}" + SET_ROCM_PATH="-DROCM_PATH=${ROCM_PATH} -DHIP_ROOT_DIR=${ROCM_PATH}" fi ADD_ROCM_TO_PATH="" diff --git a/src/ops/add_bias_residual_layer_norm.cc b/src/ops/add_bias_residual_layer_norm.cc index da72f0e0fb..7a1da2e974 100644 --- a/src/ops/add_bias_residual_layer_norm.cc +++ b/src/ops/add_bias_residual_layer_norm.cc @@ -60,7 +60,7 @@ AddBiasResidualLayerNormParams AddBiasResidualLayerNorm::get_params() const { params.eps = this->eps; params.use_bias = this->use_bias; params.inplace_residual = this->inplace_residual; - if (this->name != nullptr) { + if (strlen(this->name) < MAX_OPNAME) { strcpy(params.name, this->name); } return params; diff --git a/src/ops/aggregate.cc b/src/ops/aggregate.cc index 96cbcb8bb5..c83b738a0e 100644 --- a/src/ops/aggregate.cc +++ b/src/ops/aggregate.cc @@ -85,7 +85,7 @@ AggregateParams Aggregate::get_params() const { AggregateParams params; params.n = this->n; params.lambda_bal = this->lambda_bal; - if (this->name != nullptr) { + if (strlen(this->name) < MAX_OPNAME) { strcpy(params.name, this->name); } return params; diff --git a/src/ops/aggregate_spec.cc b/src/ops/aggregate_spec.cc index 413b27e94a..6ea3ff3747 100644 --- a/src/ops/aggregate_spec.cc +++ b/src/ops/aggregate_spec.cc @@ -84,7 +84,7 @@ AggregateSpecParams AggregateSpec::get_params() const { AggregateSpecParams params; params.n = this->n; params.lambda_bal = this->lambda_bal; - if (this->name != nullptr) { + if (strlen(this->name) < MAX_OPNAME) { strcpy(params.name, this->name); } return params; diff --git a/src/ops/arg_topk.cc b/src/ops/arg_topk.cc index 53b259a703..534bac2419 100644 --- a/src/ops/arg_topk.cc +++ b/src/ops/arg_topk.cc @@ -112,7 +112,7 @@ ArgTopKParams ArgTopK::get_params() const { params.k = this->k; params.sorted = this->sorted; params.speculative_decoding = this->speculative_decoding; - if (this->name != nullptr) { + if (strlen(this->name) < MAX_OPNAME) { strcpy(params.name, this->name); } return params; diff --git a/src/ops/argmax.cc b/src/ops/argmax.cc index 81ffbc6a88..4123e50e7e 100644 --- a/src/ops/argmax.cc +++ b/src/ops/argmax.cc @@ -91,7 +91,7 @@ Op *ArgMax::create_operator_from_layer( ArgMaxParams ArgMax::get_params() const { ArgMaxParams params; params.beam_search = this->beam_search; - if (this->name != nullptr) { + if (strlen(this->name) < MAX_OPNAME) { strcpy(params.name, this->name); } return params; diff --git a/src/ops/argmax.cpp b/src/ops/argmax.cpp index 233f36304a..60d44cdf2b 100644 --- a/src/ops/argmax.cpp +++ b/src/ops/argmax.cpp @@ -21,28 +21,316 @@ namespace FlexFlow { -__global__ void init_offset(int batch_size, - int vocab_size, - int total_eles, - int *d_offsets) { - CUDA_KERNEL_LOOP(i, total_eles) { - if (i % vocab_size == 0) { - d_offsets[i / vocab_size] = i; +using Legion::coord_t; + +enum class HeapType { kMinHeap, kMaxHeap }; +enum class PreferIndices { kLower, kHigher }; + +template +struct Entry { + int index; + T value; +}; + +template +struct LinearData { + typedef Entry Entry; + + __device__ Entry &operator[](std::size_t index) const { + return data[index]; + } + + __device__ int get_index(int i) const { + return data[i].index; + } + __device__ T get_value(int i) const { + return data[i].value; + } + + Entry *const data; +}; + +template +struct IndirectLinearData { + typedef Entry Entry; + + __device__ Entry &operator[](std::size_t index) const { + return data[index]; + } + + __device__ int get_index(int i) const { + return backing_data[data[i].index].index; + } + __device__ T get_value(int i) const { + return data[i].value; + } + + Entry *const data; + Entry *const backing_data; +}; + +template +struct StridedData { + typedef Entry Entry; + + __device__ Entry &operator[](std::size_t index) const { + return data[index * blockDim.x + threadIdx.x]; + } + + __device__ int get_index(int i) const { + return (*this)[i].index; + } + __device__ T get_value(int i) const { + return (*this)[i].value; + } + + Entry *const data; +}; + +// A heap of Entry that can either work as a min-heap or as a max-heap. +template + class Data, + typename T> +struct IndexedHeap { + typedef typename Data::Entry Entry; + Data const data; + __device__ IndexedHeap(Data const &d) : data(d) {} + + __device__ bool is_above(int left, int right) { + T left_value = data.get_value(left); + T right_value = data.get_value(right); + if (left_value == right_value) { + if (preferIndices == PreferIndices::kLower) { + return data.get_index(left) < data.get_index(right); + } else { + return data.get_index(left) > data.get_index(right); + } } + if (heapType == HeapType::kMinHeap) { + return left_value < right_value; + } else { + return left_value > right_value; + } + } + + __device__ void assign(int i, Entry const &entry) { + data[i] = entry; } + + __device__ void push_up(int i) { + int child = i; + int parent; + for (; child > 0; child = parent) { + parent = (child - 1) / 2; + if (!is_above(child, parent)) { + // Heap property satisfied. + break; + } + swap(child, parent); + } + } + + __device__ void swap(int a, int b) { + auto tmp = data[b]; + data[b] = data[a]; + data[a] = tmp; + } + + __device__ void push_root_down(int k) { + push_down(0, k); + } + + // MAX-HEAPIFY in Cormen + __device__ void push_down(int node, int k) { + while (true) { + int const left = 2 * node + 1; + int const right = left + 1; + int smallest = node; + if (left < k && is_above(left, smallest)) { + smallest = left; + } + if (right < k && is_above(right, smallest)) { + smallest = right; + } + if (smallest == node) { + break; + } + swap(smallest, node); + node = smallest; + } + } + + // BUILD-MAX-HEAPIFY in Cormen + __device__ void build(int k) { + for (int node = (k - 1) / 2; node >= 0; node--) { + push_down(node, k); + } + } + + // HEAP-EXTRACT-MAX in Cormen + __device__ void remove_root(int k) { + data[0] = data[k - 1]; + push_root_down(k - 1); + } + + // in-place HEAPSORT in Cormen + // This method destroys the heap property. + __device__ void sort(int k) { + for (int slot = k - 1; slot > 0; slot--) { + // This is like remove_root but we insert the element at the end. + swap(slot, 0); + // Heap is now an element smaller. + push_root_down(/*k=*/slot); + } + } + + __device__ void replace_root(Entry const &entry, int k) { + data[0] = entry; + push_root_down(k); + } + + __device__ Entry const &root() { + return data[0]; + } +}; + +template + class Data, + typename T> +__device__ IndexedHeap + make_indexed_heap(typename Data::Entry *data) { + return IndexedHeap{Data{data}}; } -template -__global__ void copy_result(hipcub::KeyValuePair *d_out, - int *indices, - float *prob_ptr, - int batch_size, - bool beam_search) { - CUDA_KERNEL_LOOP(i, batch_size) { - indices[i] = d_out[i].key; - if (beam_search) { - prob_ptr[i] = static_cast(d_out[i].value); +// heapArgTopK walks over [input, input+length) with `step_size` stride starting +// at `start_index`. It builds a top-`k` heap that is stored in `heap_entries` +// using `Accessor` to access elements in `heap_entries`. If sorted=true, the +// elements will be sorted at the end. +template class Data = LinearData> +__device__ void heapArgTopK(T const *__restrict__ input, + int length, + int k, + Entry *__restrict__ heap_entries, + bool sorted = false, + int start_index = 0, + int step_size = 1) { + assert(k <= length); + + auto heap = + make_indexed_heap( + heap_entries); + + int heap_end_index = start_index + k * step_size; + if (heap_end_index > length) { + heap_end_index = length; + } + // Initialize the min-heap. + for (int index = start_index, slot = 0; index < heap_end_index; + index += step_size, slot++) { + heap.assign(slot, {index, input[index]}); + } + + heap.build(k); + + // Now iterate over the remaining items. + // If an item is smaller than the min element, it is not amongst the top k. + // Otherwise, replace the min element with it and push upwards. + for (int index = heap_end_index; index < length; index += step_size) { + // We prefer elements with lower indices. This is given here. + // Later elements automatically have higher indices, so can be discarded. + if (input[index] > heap.root().value) { + // This element should replace the min. + heap.replace_root({index, input[index]}, k); + } + } + + // Sort if wanted. + if (sorted) { + heap.sort(k); + } +} + +// mergeShards performs a top-k merge on `num_shards` many sorted streams that +// are sorted and stored in `entries` in a strided way: +// |s_1 1st|s_2 1st|...s_{num_shards} 1st|s_1 2nd|s_2 2nd|... +// The overall top k elements are written to `top_k_values` and their indices +// to top_k_indices. +// `top_k_heap` is used as temporary storage for the merge heap. +template +__device__ void mergeShards(int num_shards, + int k, + Entry *__restrict__ entries, + Entry *__restrict__ top_k_heap, + float *top_k_values, + int *top_k_indices) { + // If k < num_shards, we can use a min-heap with k elements to get the top k + // of the sorted blocks. + // If k > num_shards, we can initialize a min-heap with the top element from + // each sorted block. + int const heap_size = k < num_shards ? k : num_shards; + + // Min-heap part. + { + auto min_heap = IndexedHeap{IndirectLinearData{top_k_heap, entries}}; + // Initialize the heap as a min-heap. + for (int slot = 0; slot < heap_size; slot++) { + min_heap.assign(slot, {slot, entries[slot].value}); } + min_heap.build(heap_size); + + // Now perform top k with the remaining shards (if num_shards > heap_size). + for (int shard = heap_size; shard < num_shards; shard++) { + auto const entry = entries[shard]; + auto const root = min_heap.root(); + if (entry.value < root.value) { + continue; + } + if (entry.value == root.value && + entry.index > entries[root.index].index) { + continue; + } + // This element should replace the min. + min_heap.replace_root({shard, entry.value}, heap_size); + } + } + + // Max-part. + { + // Turn the min-heap into a max-heap in-place. + auto max_heap = IndexedHeap{IndirectLinearData{top_k_heap, entries}}; + // Heapify into a max heap. + max_heap.build(heap_size); + + // Now extract the minimum k-1 times. + // k is treated specially. + int const last_k = k - 1; + for (int rank = 0; rank < last_k; rank++) { + Entry const &max_element = max_heap.root(); + top_k_values[rank] = __half2float(max_element.value); + int shard_index = max_element.index; + top_k_indices[rank] = entries[shard_index].index; + int next_shard_index = shard_index + num_shards; + // For rank < k-1, each top k heap still contains at least 1 element, + // so we can draw a replacement. + max_heap.replace_root({next_shard_index, entries[next_shard_index].value}, + heap_size); + } + + // rank == last_k. + Entry const &max_element = max_heap.root(); + top_k_values[last_k] = __half2float(max_element.value); + int shard_index = max_element.index; + top_k_indices[last_k] = entries[shard_index].index; } } @@ -61,6 +349,50 @@ __global__ void compute_sparse_categorical_crossentropy_loss( } } +template +__global__ void argmax_forward_kernel(T const *__restrict__ input, + size_t shared_memory_size, + int length, + int k, + float *__restrict__ output, + int *__restrict__ indices) { + __shared__ char shared_memory[48 << 10]; + int const batch_index = blockIdx.x; + T const *batch_input = input + batch_index * length; + int const thread_index = threadIdx.x; + int const thread_count = blockDim.x; + Entry *shared_entries = (Entry *)shared_memory; + heapArgTopK( + batch_input, length, k, shared_entries, true, thread_index, thread_count); + __syncthreads(); + if (thread_index == 0) { + int const offset = batch_index * k; + auto batch_output = output + offset; + auto batch_indices = indices + offset; + Entry *top_k_heap = shared_entries + thread_count * k; + mergeShards(thread_count, + k, + shared_entries, + top_k_heap, + batch_output, + batch_indices); + } +} + +template +__global__ void copy_result(hipcub::KeyValuePair *d_out, + int *indices, + float *prob_ptr, + int batch_size, + bool beam_search) { + CUDA_KERNEL_LOOP(i, batch_size) { + indices[i] = d_out[i].key; + if (beam_search) { + prob_ptr[i] = static_cast(d_out[i].value); + } + } +} + /*static*/ template void ArgMax::forward_kernel(ArgMaxMeta const *m, @@ -75,40 +407,41 @@ void ArgMax::forward_kernel(ArgMaxMeta const *m, hipStream_t stream) { checkCUDNN(miopenSetStream(m->handle.dnn, stream)); - DT alpha = 1.0f, beta = 0.0f; + if (m->beam_search) { // set all parents id zero in arg top1 case. checkCUDA(hipMemsetAsync(parent, 0, batch_size * sizeof(int), stream)); } - if (m->beam_search) { - // set all parents id zero in arg top1 case. - checkCUDA(hipMemsetAsync(parent, 0, batch_size * sizeof(int), stream)); + int num_shards = 0; + int k = 1; + { + constexpr auto shared_memory_size = 48 << 10; + auto const heap_size = k * sizeof(Entry
); + // shared_memory_size = (num_shards + 1) * heap_size <=> + num_shards = shared_memory_size / heap_size - 1; + assert(num_shards > 0); + if (num_shards > CUDA_NUM_THREADS) { + num_shards = CUDA_NUM_THREADS; + } } - size_t temp_storage_bytes = m->temp_storage_bytes; - // use cub - checkCUDA(hipcub::DeviceSegmentedReduce::ArgMax( - m->d_temp_storage, - temp_storage_bytes, - input_ptr, - static_cast *>(m->d_out), - batch_size, - m->d_offsets, - m->d_offsets + 1, - stream)); - - // copy dout to incides - int parallelism = batch_size; - hipLaunchKernelGGL(HIP_KERNEL_NAME(copy_result), - GET_BLOCKS(parallelism), - min(CUDA_NUM_THREADS, parallelism), + // We are limited by the amount of shared memory we have per block. + size_t shared_memory_size = (num_shards + 1) * k * sizeof(Entry
); + // size_t num_blocks = (batch_size + num_shards - 1) / num_shards; + size_t num_blocks = batch_size; + assert(num_shards >= (size_t)k); + num_shards = k; + + hipLaunchKernelGGL(argmax_forward_kernel, + num_blocks, + num_shards, 0, stream, - static_cast *>(m->d_out), - indices_ptr, + input_ptr, + shared_memory_size, + length, + k, prob_ptr, - batch_size, - m->beam_search); - // print_tensor(indices_ptr, 32, "argmax op"); + indices_ptr); // compute cross-entropy loss if there is a finetuning request assert(loss != nullptr); @@ -142,12 +475,11 @@ void ArgMax::forward_kernel(ArgMaxMeta const *m, stream)); // copy loss to d_loss checkCUDA(hipMemsetAsync(m->d_loss, 0, sizeof(float), stream)); - hipLaunchKernelGGL( - HIP_KERNEL_NAME(compute_sparse_categorical_crossentropy_loss), - GET_BLOCKS(num_bwd_tokens), - min(CUDA_NUM_THREADS, num_bwd_tokens), - 0, - stream, + compute_sparse_categorical_crossentropy_loss<<>>( input_ptr, static_cast(m->handle.workSpace), m->d_loss, @@ -214,7 +546,6 @@ void ArgMax::forward_kernel_wrapper(ArgMaxMeta const *m, checkCUDA(hipEventElapsedTime(&elapsed, t_start, t_end)); checkCUDA(hipEventDestroy(t_start)); checkCUDA(hipEventDestroy(t_end)); - printf("[ArgMax] forward time = %.2lfms\n", elapsed); } } @@ -228,70 +559,12 @@ ArgMaxMeta::ArgMaxMeta(FFHandler handler, MemoryAllocator &gpu_mem_allocator) : OpMeta(handler, op) { DataType data_type = op->data_type; - hipStream_t stream; - checkCUDA(get_legion_stream(&stream)); - - size_t d_offsets_size = batch_size; size_t prob_size = batch_size; assert(data_type == DT_FLOAT || data_type == DT_HALF); - size_t total_size = - d_offsets_size * sizeof(int) + - (data_type == DT_FLOAT - ? sizeof(hipcub::KeyValuePair) * batch_size - : sizeof(hipcub::KeyValuePair) * batch_size) + - prob_size * sizeof(float); + size_t total_size = prob_size * sizeof(float); gpu_mem_allocator.create_legion_instance(reserveInst, total_size); - d_offsets = gpu_mem_allocator.allocate_instance(d_offsets_size); - d_out = data_type == DT_FLOAT - ? gpu_mem_allocator.allocate_instance_untyped( - batch_size * sizeof(hipcub::KeyValuePair)) - : gpu_mem_allocator.allocate_instance_untyped( - batch_size * sizeof(hipcub::KeyValuePair)); probs = gpu_mem_allocator.allocate_instance(prob_size); - // init offset - int parallelism = total_ele; - hipLaunchKernelGGL(HIP_KERNEL_NAME(init_offset), - GET_BLOCKS(parallelism), - min(CUDA_NUM_THREADS, parallelism), - 0, - stream, - batch_size, - total_ele / batch_size, - total_ele, - d_offsets); - - if (data_type == DT_FLOAT) { - checkCUDA(hipcub::DeviceSegmentedReduce::ArgMax( - d_temp_storage, - temp_storage_bytes, - input.get_float_ptr(), - static_cast *>(d_out), - batch_size, - d_offsets, - d_offsets + 1, - stream)); - - } else if (data_type == DT_HALF) { - checkCUDA(hipcub::DeviceSegmentedReduce::ArgMax( - d_temp_storage, - temp_storage_bytes, - input.get_half_ptr(), - static_cast *>(d_out), - batch_size, - d_offsets, - d_offsets + 1, - stream)); - } - - gpu_mem_allocator.create_legion_instance(reserveInst, temp_storage_bytes); - d_temp_storage = - gpu_mem_allocator.allocate_instance_untyped(temp_storage_bytes); - - // allocate space for loss on device - gpu_mem_allocator.create_legion_instance(reserveInst, sizeof(float)); - d_loss = gpu_mem_allocator.allocate_instance(1); } - ArgMaxMeta::~ArgMaxMeta(void) { if (reserveInst != Realm::RegionInstance::NO_INST) { reserveInst.destroy(); diff --git a/src/ops/argmax.cu b/src/ops/argmax.cu index 78f432acee..8a2e2da2d0 100644 --- a/src/ops/argmax.cu +++ b/src/ops/argmax.cu @@ -72,7 +72,7 @@ void ArgMax::forward_kernel(ArgMaxMeta const *m, float *loss, cudaStream_t stream) { checkCUDNN(cudnnSetStream(m->handle.dnn, stream)); - DT alpha = 1.0f, beta = 0.0f; + if (m->beam_search) { // set all parents id zero in arg top1 case. checkCUDA(cudaMemsetAsync(parent, 0, batch_size * sizeof(int), stream)); @@ -89,7 +89,7 @@ void ArgMax::forward_kernel(ArgMaxMeta const *m, m->d_offsets + 1, stream)); - // copy dout to incides + // copy dout to indices int parallelism = batch_size; copy_result<<bias; params.add_bias_kv = this->add_bias_kv; params.add_zero_attn = this->add_zero_attn; - if (this->name != nullptr) { + if (strlen(this->name) < MAX_OPNAME) { strcpy(params.name, this->name); } return params; diff --git a/src/ops/fused.cpp b/src/ops/fused.cpp index ccabab1cc8..9f826cd611 100644 --- a/src/ops/fused.cpp +++ b/src/ops/fused.cpp @@ -15,6 +15,7 @@ #include "flexflow/ops/fused.h" #include "flexflow/accessor.h" +#include "flexflow/ffconst_utils.h" #include "flexflow/model.h" #include "flexflow/ops/add_bias_residual_layer_norm.h" #include "flexflow/ops/batch_norm.h" @@ -43,6 +44,7 @@ #include "flexflow/ops/spec_inc_multihead_self_attention.h" #include "flexflow/ops/tree_inc_multihead_self_attention.h" #include "flexflow/parallel_ops/kernels/allreduce_kernels.h" +#include "flexflow/parallel_ops/kernels/parallel_identity_kernels.h" #include "flexflow/utils/hip_helper.h" #include @@ -92,6 +94,7 @@ __host__ void if (bc->num_tokens == 0) { return; } + assert(metas->numOperators == fused->numOperators); assert(regions.size() == task->regions.size()); bool softmax_grad_additional_region = @@ -1763,6 +1766,7 @@ __host__ void FusedOp::backward_task(Task const *task, if (fused->op_input_source[i + ioff] == SOURCE_INPUT) { my_input_accessor[i] = input_accessor[my_off]; my_input_grad_accessor[i] = input_grad_accessor[my_off]; + assert(my_input_grad_accessor[i].domain == my_input_accessor[i].domain); } else if (fused->op_input_source[i + ioff] == SOURCE_OUTPUT) { my_input_accessor[i] = output_accessor[my_off]; my_input_grad_accessor[i] = output_grad_accessor[my_off]; diff --git a/src/ops/group_by.cc b/src/ops/group_by.cc index 7ebfb8af66..03b9a5199b 100644 --- a/src/ops/group_by.cc +++ b/src/ops/group_by.cc @@ -99,7 +99,7 @@ Group_byParams Group_by::get_params() const { Group_byParams params; params.n = this->n; params.alpha = this->alpha; - if (this->name != nullptr) { + if (strlen(this->name) < MAX_OPNAME) { strcpy(params.name, this->name); } return params; diff --git a/src/ops/inc_multihead_self_attention.cc b/src/ops/inc_multihead_self_attention.cc index a13a82135f..8219cf9e1f 100644 --- a/src/ops/inc_multihead_self_attention.cc +++ b/src/ops/inc_multihead_self_attention.cc @@ -1061,7 +1061,7 @@ IncMultiHeadSelfAttentionParams IncMultiHeadSelfAttention::get_params() const { params.quantization_type = this->quantization_type; params.offload = this->offload; params.num_kv_heads = this->num_kv_heads; - if (this->name != nullptr) { + if (strlen(this->name) < MAX_OPNAME) { strcpy(params.name, this->name); } diff --git a/src/ops/layer_norm.cc b/src/ops/layer_norm.cc index eea10a5cae..3161987d60 100644 --- a/src/ops/layer_norm.cc +++ b/src/ops/layer_norm.cc @@ -57,7 +57,7 @@ LayerNormParams LayerNorm::get_params() const { params.elementwise_affine = this->elementwise_affine; params.eps = this->eps; params.use_bias = this->use_bias; - if (this->name != nullptr) { + if (strlen(this->name) < MAX_OPNAME) { strcpy(params.name, this->name); } return params; diff --git a/src/ops/linear.cc b/src/ops/linear.cc index a7e6e671fc..20ad762b62 100644 --- a/src/ops/linear.cc +++ b/src/ops/linear.cc @@ -1433,7 +1433,7 @@ LinearParams Linear::get_params() const { params.kernel_reg_lambda = this->kernel_reg_lambda; params.quantization_type = this->quantization_type; params.offload = this->offload; - if (this->name != nullptr) { + if (strlen(this->name) < MAX_OPNAME) { strcpy(params.name, this->name); } diff --git a/src/ops/lora_linear.cc b/src/ops/lora_linear.cc index 681df8b1b8..fde6bc2b28 100644 --- a/src/ops/lora_linear.cc +++ b/src/ops/lora_linear.cc @@ -1278,7 +1278,7 @@ LoraLinearParams LoraLinear::get_params() const { LoraLinearParams params; params.layer_guid = this->layer_guid; params.type = this->op_type; - if (this->name != nullptr) { + if (strlen(this->name) < MAX_OPNAME) { strcpy(params.name, this->name); } params.peft_configs = this->peft_configs; diff --git a/src/ops/reduce.cc b/src/ops/reduce.cc index 454a35caf4..1c0566e9ca 100644 --- a/src/ops/reduce.cc +++ b/src/ops/reduce.cc @@ -41,7 +41,7 @@ ReduceParams Reduce::get_params() const { } params.keepdims = keepdims; params.layer_guid = this->layer_guid; - if (this->name != nullptr) { + if (strlen(this->name) < MAX_OPNAME) { strcpy(params.name, this->name); } return params; diff --git a/src/ops/reshape.cc b/src/ops/reshape.cc index 7ebe29a6f6..4e7fd2eb96 100644 --- a/src/ops/reshape.cc +++ b/src/ops/reshape.cc @@ -296,7 +296,7 @@ ReshapeParams Reshape::get_params() const { ReshapeParams params; params.shape = shape_vec; params.layer_guid = this->layer_guid; - if (this->name != nullptr) { + if (strlen(this->name) < MAX_OPNAME) { strcpy(params.name, this->name); } return params; diff --git a/src/ops/residual_layer_norm.cc b/src/ops/residual_layer_norm.cc index eae3996358..2a30d12d6d 100644 --- a/src/ops/residual_layer_norm.cc +++ b/src/ops/residual_layer_norm.cc @@ -65,7 +65,7 @@ ResidualLayerNormParams ResidualLayerNorm::get_params() const { params.use_bias = this->use_bias; params.use_two_residuals = this->use_two_residuals; params.inplace_residual = this->inplace_residual; - if (this->name != nullptr) { + if (strlen(this->name) < MAX_OPNAME) { strcpy(params.name, this->name); } return params; diff --git a/src/ops/residual_rms_norm.cc b/src/ops/residual_rms_norm.cc index 6d83bf8e87..744902f908 100644 --- a/src/ops/residual_rms_norm.cc +++ b/src/ops/residual_rms_norm.cc @@ -57,7 +57,7 @@ ResidualRMSNormParams ResidualRMSNorm::get_params() const { params.eps = this->eps; params.dim = this->dim; params.inplace_residual = this->inplace_residual; - if (this->name != nullptr) { + if (strlen(this->name) < MAX_OPNAME) { strcpy(params.name, this->name); } return params; diff --git a/src/ops/rms_norm.cc b/src/ops/rms_norm.cc index 0519ccfb55..8dadd7dcc3 100644 --- a/src/ops/rms_norm.cc +++ b/src/ops/rms_norm.cc @@ -53,7 +53,7 @@ RMSNormParams RMSNorm::get_params() const { params.layer_guid = this->layer_guid; params.eps = this->eps; params.dim = this->dim; - if (this->name != nullptr) { + if (strlen(this->name) < MAX_OPNAME) { strcpy(params.name, this->name); } return params; diff --git a/src/ops/sampling.cc b/src/ops/sampling.cc index 44094b70d0..0358a2cd31 100644 --- a/src/ops/sampling.cc +++ b/src/ops/sampling.cc @@ -88,7 +88,7 @@ Op *Sampling::create_operator_from_layer( SamplingParams Sampling::get_params() const { SamplingParams params; params.top_p = this->top_p; - if (this->name != nullptr) { + if (strlen(this->name) < MAX_OPNAME) { strcpy(params.name, this->name); } return params; diff --git a/src/ops/sigmoid_silu_multi.cc b/src/ops/sigmoid_silu_multi.cc index 66306b2d1c..e7c2fea19c 100644 --- a/src/ops/sigmoid_silu_multi.cc +++ b/src/ops/sigmoid_silu_multi.cc @@ -52,7 +52,7 @@ bool SigmoidSiluMultiParams::is_valid( SigmoidSiluMultiParams SigmoidSiluMulti::get_params() const { SigmoidSiluMultiParams params; params.layer_guid = this->layer_guid; - if (this->name != nullptr) { + if (strlen(this->name) < MAX_OPNAME) { strcpy(params.name, this->name); } return params; diff --git a/src/ops/softmax.cc b/src/ops/softmax.cc index 6a80c06c24..a02d88b98b 100644 --- a/src/ops/softmax.cc +++ b/src/ops/softmax.cc @@ -86,7 +86,7 @@ SoftmaxParams Softmax::get_params() const { SoftmaxParams params; params.layer_guid = this->layer_guid; params.dim = this->dim; - if (this->name != nullptr) { + if (strlen(this->name) < MAX_OPNAME) { strcpy(params.name, this->name); } return params; diff --git a/src/ops/spec_inc_multihead_self_attention.cc b/src/ops/spec_inc_multihead_self_attention.cc index 68d3a4c205..52da51fb26 100644 --- a/src/ops/spec_inc_multihead_self_attention.cc +++ b/src/ops/spec_inc_multihead_self_attention.cc @@ -850,7 +850,7 @@ SpecIncMultiHeadSelfAttentionParams params.scaling_factor = this->scaling_factor; params.qk_prod_scaling = this->qk_prod_scaling; params.position_bias = this->position_bias; - if (this->name != nullptr) { + if (strlen(this->name) < MAX_OPNAME) { strcpy(params.name, this->name); } diff --git a/src/ops/split.cc b/src/ops/split.cc index 7c6b631b20..92cfbd49e9 100644 --- a/src/ops/split.cc +++ b/src/ops/split.cc @@ -50,7 +50,7 @@ SplitParams Split::get_params() const { SplitParams params; params.splits = this->splits; params.legion_axis = this->legion_axis; - if (this->name != nullptr) { + if (strlen(this->name) < MAX_OPNAME) { strcpy(params.name, this->name); } return params; diff --git a/src/ops/topk.cc b/src/ops/topk.cc index c49e722d22..0e88befa68 100644 --- a/src/ops/topk.cc +++ b/src/ops/topk.cc @@ -87,7 +87,7 @@ TopKParams TopK::get_params() const { TopKParams params; params.k = this->k; params.sorted = this->sorted; - if (this->name != nullptr) { + if (strlen(this->name) < MAX_OPNAME) { strcpy(params.name, this->name); } return params; diff --git a/src/ops/transpose.cc b/src/ops/transpose.cc index 30f7243157..bffde477de 100644 --- a/src/ops/transpose.cc +++ b/src/ops/transpose.cc @@ -51,7 +51,7 @@ TransposeParams Transpose::get_params() const { for (int i = 0; i < outputs[0]->num_dims; i++) { params.perm.push_back(this->perm[i]); } - if (this->name != nullptr) { + if (strlen(this->name) < MAX_OPNAME) { strcpy(params.name, this->name); } return params; diff --git a/src/ops/tree_inc_multihead_self_attention.cc b/src/ops/tree_inc_multihead_self_attention.cc index df722a3d51..132a48be40 100644 --- a/src/ops/tree_inc_multihead_self_attention.cc +++ b/src/ops/tree_inc_multihead_self_attention.cc @@ -924,7 +924,7 @@ TreeIncMultiHeadSelfAttentionParams params.qk_prod_scaling = this->qk_prod_scaling; params.position_bias = this->position_bias; params.tensor_parallelism_degree = this->tensor_parallelism_degree; - if (this->name != nullptr) { + if (strlen(this->name) < MAX_OPNAME) { strcpy(params.name, this->name); } return params; diff --git a/src/parallel_ops/allreduce.cc b/src/parallel_ops/allreduce.cc index 6cd3e3d482..52c4ec2e28 100644 --- a/src/parallel_ops/allreduce.cc +++ b/src/parallel_ops/allreduce.cc @@ -46,9 +46,7 @@ using namespace FlexFlow::Kernels::AllReduce; /* Params */ bool operator==(AllReduceParams const &lhs, AllReduceParams const &rhs) { return lhs.allreduce_legion_dim == rhs.allreduce_legion_dim && - ((lhs.name == NULL && rhs.name == NULL) || - (lhs.name != NULL && rhs.name != NULL && - std::strcmp(lhs.name, rhs.name) == 0)); + std::strcmp(lhs.name, rhs.name) == 0; } bool AllReduceParams::is_valid(ParallelTensorShape const &input) const { @@ -58,7 +56,7 @@ bool AllReduceParams::is_valid(ParallelTensorShape const &input) const { AllReduceParams AllReduce::get_params() const { AllReduceParams params; params.allreduce_legion_dim = this->allreduce_dim; - if (this->name != nullptr) { + if (strlen(this->name) < MAX_OPNAME) { strcpy(params.name, this->name); } return params; diff --git a/src/parallel_ops/combine.cc b/src/parallel_ops/combine.cc index ead9af836c..ce9c032350 100644 --- a/src/parallel_ops/combine.cc +++ b/src/parallel_ops/combine.cc @@ -45,9 +45,7 @@ using namespace FlexFlow::Kernels::Combine; bool operator==(CombineParams const &lhs, CombineParams const &rhs) { return lhs.combine_legion_dim == rhs.combine_legion_dim && lhs.combine_degree == rhs.combine_degree && - ((lhs.name == NULL && rhs.name == NULL) || - (lhs.name != NULL && rhs.name != NULL && - std::strcmp(lhs.name, rhs.name) == 0)); + std::strcmp(lhs.name, rhs.name) == 0; } bool CombineParams::is_valid(ParallelTensorShape const &input) const { @@ -61,7 +59,7 @@ CombineParams Combine::get_params() const { CombineParams params; params.combine_legion_dim = this->combine_dim; params.combine_degree = this->combine_degree; - if (this->name != nullptr) { + if (strlen(this->name) < MAX_OPNAME) { strcpy(params.name, this->name); } return params; diff --git a/src/parallel_ops/fused_parallel_op.cc b/src/parallel_ops/fused_parallel_op.cc index 1a76cbfc40..dec7b20fb2 100644 --- a/src/parallel_ops/fused_parallel_op.cc +++ b/src/parallel_ops/fused_parallel_op.cc @@ -59,7 +59,7 @@ FusedParallelOpParams FusedParallelOp::get_params() const { std::vector ops(std::begin(this->parallel_ops), std::end(this->parallel_ops)); params.parallel_ops = ops; - if (this->name != nullptr) { + if (strlen(this->name) < MAX_OPNAME) { strcpy(params.name, this->name); } return params; diff --git a/src/parallel_ops/kernels/parallel_identity_kernels.cpp b/src/parallel_ops/kernels/parallel_identity_kernels.cpp new file mode 100644 index 0000000000..8378231fb2 --- /dev/null +++ b/src/parallel_ops/kernels/parallel_identity_kernels.cpp @@ -0,0 +1,97 @@ +/* Copyright 2023 CMU, Facebook, LANL, MIT, NVIDIA, and Stanford (alphabetical) + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "flexflow/parallel_ops/kernels/parallel_identity_kernels.h" +#include "flexflow/ffconst_utils.h" +#include "flexflow/utils/hip_helper.h" +#include + +namespace FlexFlow { + +ParallelIdentityMeta::ParallelIdentityMeta(FFHandler handle, + ParallelIdentity const *reduct) + : OpMeta(handle, reduct) {} + +namespace Kernels { +namespace ParallelIdentity { + +void forward_kernel_wrapper(ParallelIdentityMeta const *m, + GenericTensorAccessorR const &input, + GenericTensorAccessorW const &output) { + hipStream_t stream; + checkCUDA(get_legion_stream(&stream)); + assert(input.data_type == output.data_type); + assert(input.domain == output.domain); + size_t data_size = data_type_size(input.data_type); + // copy input to output + checkCUDA(hipMemcpyAsync(output.ptr, + input.ptr, + input.domain.get_volume() * data_size, + hipMemcpyDeviceToDevice, + stream)); +} + +void backward_kernel_wrapper(ParallelIdentityMeta const *m, + GenericTensorAccessorW const &input_grad, + GenericTensorAccessorR const &output_grad) { + assert(false && "To be implemented"); +} + +void inference_kernel_wrapper(ParallelIdentityMeta const *m, + BatchConfig const *bc, + GenericTensorAccessorR const &input, + GenericTensorAccessorW const &output) { + hipStream_t stream; + checkCUDA(get_legion_stream(&stream)); + assert(input.data_type == output.data_type); + assert(input.domain == output.domain); + size_t hidden_dim_size = input.domain.hi()[0] - input.domain.lo()[0] + 1; + size_t num_elements = bc->num_active_tokens(); + size_t data_size = data_type_size(input.data_type); + checkCUDA(hipMemcpyAsync(output.ptr, + input.ptr, + hidden_dim_size * num_elements * data_size, + hipMemcpyDeviceToDevice, + stream)); +} + +void peft_bwd_kernel_wrapper(ParallelIdentityMeta const *m, + BatchConfig const *bc, + GenericTensorAccessorW const &input_grad, + GenericTensorAccessorR const &output_grad) { + hipStream_t stream; + checkCUDA(get_legion_stream(&stream)); + assert(input_grad.data_type == output_grad.data_type); + assert(input_grad.domain == output_grad.domain); + size_t hidden_dim_size = + input_grad.domain.hi()[0] - input_grad.domain.lo()[0] + 1; + size_t num_elements = bc->num_active_tokens() * hidden_dim_size; +#ifdef FF_USE_NCCL + ncclDataType_t nccl_data_type = ff_to_nccl_datatype(input_grad.data_type); + checkNCCL(ncclAllReduce(output_grad.ptr, + input_grad.ptr, + num_elements, + nccl_data_type, + ncclSum, + m->handle.ncclComm, + stream)); +#else + assert(false && "Must enable FF_USE_NCCL to use ParallelIdentity operators"); +#endif +} + +} // namespace ParallelIdentity +} // namespace Kernels +} // namespace FlexFlow diff --git a/src/parallel_ops/parallel_identity.cc b/src/parallel_ops/parallel_identity.cc index 56f9b1beac..883910ae09 100644 --- a/src/parallel_ops/parallel_identity.cc +++ b/src/parallel_ops/parallel_identity.cc @@ -47,9 +47,7 @@ using namespace FlexFlow::Kernels::ParallelIdentity; bool operator==(ParallelIdentityParams const &lhs, ParallelIdentityParams const &rhs) { return lhs.parallel_identity_legion_dim == rhs.parallel_identity_legion_dim && - ((lhs.name == NULL && rhs.name == NULL) || - (lhs.name != NULL && rhs.name != NULL && - std::strcmp(lhs.name, rhs.name) == 0)); + std::strcmp(lhs.name, rhs.name) == 0; } bool ParallelIdentityParams::is_valid(ParallelTensorShape const &input) const { @@ -59,7 +57,7 @@ bool ParallelIdentityParams::is_valid(ParallelTensorShape const &input) const { ParallelIdentityParams ParallelIdentity::get_params() const { ParallelIdentityParams params; params.parallel_identity_legion_dim = this->parallel_identity_dim; - if (this->name != nullptr) { + if (strlen(this->name) < MAX_OPNAME) { strcpy(params.name, this->name); } return params; diff --git a/src/parallel_ops/partition.cc b/src/parallel_ops/partition.cc index 303e45c4ad..fddf739599 100644 --- a/src/parallel_ops/partition.cc +++ b/src/parallel_ops/partition.cc @@ -45,9 +45,7 @@ using namespace FlexFlow::Kernels::Repartition; bool operator==(RepartitionParams const &lhs, RepartitionParams const &rhs) { return lhs.repartition_legion_dim == rhs.repartition_legion_dim && lhs.repartition_degree == rhs.repartition_degree && - ((lhs.name == NULL && rhs.name == NULL) || - (lhs.name != NULL && rhs.name != NULL && - std::strcmp(lhs.name, rhs.name) == 0)); + std::strcmp(lhs.name, rhs.name) == 0; } bool RepartitionParams::is_valid(ParallelTensorShape const &input) const { @@ -63,7 +61,7 @@ RepartitionParams Repartition::get_params() const { RepartitionParams params; params.repartition_legion_dim = this->repartition_dim; params.repartition_degree = this->repartition_degree; - if (this->name != nullptr) { + if (strlen(this->name) < MAX_OPNAME) { strcpy(params.name, this->name); } return params; diff --git a/src/parallel_ops/reduction.cc b/src/parallel_ops/reduction.cc index 26c0dfeeab..7306e04334 100644 --- a/src/parallel_ops/reduction.cc +++ b/src/parallel_ops/reduction.cc @@ -46,9 +46,7 @@ using namespace FlexFlow::Kernels::Reduction; bool operator==(ReductionParams const &lhs, ReductionParams const &rhs) { return lhs.reduction_legion_dim == rhs.reduction_legion_dim && lhs.reduction_degree == rhs.reduction_degree && - ((lhs.name == NULL && rhs.name == NULL) || - (lhs.name != NULL && rhs.name != NULL && - std::strcmp(lhs.name, rhs.name) == 0)); + std::strcmp(lhs.name, rhs.name) == 0; } bool ReductionParams::is_valid(ParallelTensorShape const &input) const { @@ -59,7 +57,7 @@ ReductionParams Reduction::get_params() const { ReductionParams params; params.reduction_legion_dim = this->reduction_dim; params.reduction_degree = this->reduction_degree; - if (this->name != nullptr) { + if (strlen(this->name) < MAX_OPNAME) { strcpy(params.name, this->name); } return params; diff --git a/src/parallel_ops/replicate.cc b/src/parallel_ops/replicate.cc index edee8875de..38215fc903 100644 --- a/src/parallel_ops/replicate.cc +++ b/src/parallel_ops/replicate.cc @@ -45,9 +45,7 @@ using namespace FlexFlow::Kernels::Replicate; bool operator==(ReplicateParams const &lhs, ReplicateParams const &rhs) { return lhs.replicate_legion_dim == rhs.replicate_legion_dim && lhs.replicate_degree == rhs.replicate_degree && - ((lhs.name == NULL && rhs.name == NULL) || - (lhs.name != NULL && rhs.name != NULL && - std::strcmp(lhs.name, rhs.name) == 0)); + std::strcmp(lhs.name, rhs.name) == 0; } bool ReplicateParams::is_valid(ParallelTensorShape const &input) const { @@ -58,7 +56,7 @@ ReplicateParams Replicate::get_params() const { ReplicateParams params; params.replicate_legion_dim = this->replicate_dim; params.replicate_degree = this->replicate_degree; - if (this->name != nullptr) { + if (strlen(this->name) < MAX_OPNAME) { strcpy(params.name, this->name); } return params; diff --git a/src/runtime/request_manager.cc b/src/runtime/request_manager.cc index 8a2a2496e1..31a32dd3c8 100644 --- a/src/runtime/request_manager.cc +++ b/src/runtime/request_manager.cc @@ -2515,7 +2515,7 @@ std::vector> // must in this branch. int layer_slot = i - processed_whole_layer_tokens; int layer_slot_total = treeLayers[layer_num]; - if ((first_layer_slot == layer_slot)) { + if (first_layer_slot == layer_slot) { verifiedTree.push_back(output); new_committed_tokens.push_back(std::make_pair( input.second, committed_tokens.at(guid).at(i).second));