From e267b529f39da4552ca7ec18159ade6dcec2da0f Mon Sep 17 00:00:00 2001 From: xinhaoc Date: Sun, 9 Jul 2023 14:02:28 +0000 Subject: [PATCH 01/16] init --- include/flexflow/ffconst.h | 1 + include/flexflow/model.h | 8 + include/flexflow/operator_params.h | 1 + include/flexflow/ops/sampling.h | 91 +++++++ include/flexflow/ops/sampling_params.h | 24 ++ src/ops/sampling.cc | 351 +++++++++++++++++++++++++ src/ops/sampling.cu | 315 ++++++++++++++++++++++ src/runtime/ffconst_utils.cc | 2 + src/runtime/graph.cc | 5 + src/runtime/model.cc | 21 ++ 10 files changed, 819 insertions(+) create mode 100644 include/flexflow/ops/sampling.h create mode 100644 include/flexflow/ops/sampling_params.h create mode 100644 src/ops/sampling.cc create mode 100644 src/ops/sampling.cu diff --git a/include/flexflow/ffconst.h b/include/flexflow/ffconst.h index 0b572a9674..5b30867169 100644 --- a/include/flexflow/ffconst.h +++ b/include/flexflow/ffconst.h @@ -167,6 +167,7 @@ enum OperatorType { OP_SPEC_INC_MULTIHEAD_SELF_ATTENTION, OP_TREE_INC_MULTIHEAD_SELF_ATTENTION, OP_INC_MULTIQUERY_SELF_ATTENTION, + OP_SAMPLING, // Parallel Ops OP_REPARTITION, OP_COMBINE, diff --git a/include/flexflow/model.h b/include/flexflow/model.h index 1277b29b3d..2d56f87713 100644 --- a/include/flexflow/model.h +++ b/include/flexflow/model.h @@ -135,6 +135,8 @@ enum TaskIDs { TOPK_BWD_TASK_ID, ARG_TOPK_INIT_TASK_ID, ARG_TOPK_INF_TASK_ID, + SAMPLING_INIT_TASK_ID, + SAMPLING_INF_TASK_ID, TRANSPOSE_INIT_TASK_ID, TRANSPOSE_FWD_TASK_ID, TRANSPOSE_BWD_TASK_ID, @@ -307,6 +309,7 @@ class RMSNorm; class BeamTopK; class SpecIncMultiHeadSelfAttention; class IncMultiQuerySelfAttention; +class Sampling; class Combine; class Repartition; class Reduction; @@ -606,6 +609,9 @@ class FFModel { int k, bool sorted, char const *name = NULL); + Tensor sampling(const Tensor input, + float top_p, + char const *name = NULL); Tensor multihead_attention(const Tensor query, const Tensor key, const Tensor value, @@ -1050,6 +1056,8 @@ class FFModel { IncMultiQuerySelfAttention *>, std::unordered_map, BeamTopK *>, + std::unordered_map, + Sampling *>, std::unordered_map< std::pair, SpecIncMultiHeadSelfAttention *>, diff --git a/include/flexflow/operator_params.h b/include/flexflow/operator_params.h index 8c52dfb584..3d3edcbc01 100644 --- a/include/flexflow/operator_params.h +++ b/include/flexflow/operator_params.h @@ -30,6 +30,7 @@ #include "flexflow/ops/spec_inc_multihead_self_attention_params.h" #include "flexflow/ops/split_params.h" #include "flexflow/ops/topk_params.h" +#include "flexflow/ops/sampling_params.h" #include "flexflow/ops/transpose_params.h" #include "flexflow/ops/tree_inc_multihead_self_attention_params.h" #include "flexflow/parallel_ops/combine_params.h" diff --git a/include/flexflow/ops/sampling.h b/include/flexflow/ops/sampling.h new file mode 100644 index 0000000000..e57ed0b870 --- /dev/null +++ b/include/flexflow/ops/sampling.h @@ -0,0 +1,91 @@ +#ifndef _FLEXFLOW_SAMPLING_TOPK_H_ +#define _FLEXFLOW_SAMPLING_TOPK_H_ + +#include "flexflow/inference.h" +#include "flexflow/model.h" +#include "flexflow/node.h" +#include "flexflow/ops/sampling_params.h" + +namespace FlexFlow { + +class SamplingMeta : public OpMeta { +public: + float top_p; + void *cumsum_ptr; + void *sampled; + SamplingMeta(FFHandler handle, Op const *op); +}; + +class Sampling : public Op { +public: + using Params = SamplingParams; + using Input = ParallelTensor; + Sampling(FFModel &model, + const ParallelTensor input, + float top_p, + char const *name); + Sampling(FFModel &model, Sampling 
const &other, const ParallelTensor input); + Sampling(FFModel &model, + Params const ¶ms, + Input const input, + char const *name = nullptr); + void init(FFModel const &) override; + void init_inference(FFModel const &, + std::vector const &, + std::vector const &, + MachineView const *mv = nullptr) override; + void forward(FFModel const &) override; + void backward(FFModel const &) override; + Legion::FutureMap inference(FFModel const &, + BatchConfig const &, + std::vector const &, + std::vector const &, + MachineView const *mv = nullptr) override; + void print_layer(FFModel const &model) override { + assert(0); + } + static Op * + create_operator_from_layer(FFModel &model, + Layer const *layer, + std::vector const &inputs); + + static OpMeta *init_task(Legion::Task const *task, + std::vector const ®ions, + Legion::Context ctx, + Legion::Runtime *runtime); + static InferenceResult + inference_task(Legion::Task const *task, + std::vector const ®ions, + Legion::Context ctx, + Legion::Runtime *runtime); + void serialize(Legion::Serializer &s) const override; + static PCG::Node deserialize(FFModel &ff, + Legion::Deserializer &d, + ParallelTensor inputs[], + int num_inputs); + Op *materialize(FFModel &ff, + ParallelTensor inputs[], + int num_inputs) const override; + bool measure_operator_cost(Simulator *sim, + MachineView const &pc, + CostMetrics &cost_metrics) const override; + template + static void forward_kernel(SamplingMeta const *m, + DT *input_ptr, + int *indices_ptr, + float top_p, + int length, + int batch_size, + ffStream_t stream); + static void forward_kernel_wrapper(SamplingMeta const *m, + GenericTensorAccessorW const &input, + GenericTensorAccessorW const &indices); + Params get_params() const; + +public: + float top_p; +}; + +}; // namespace FlexFlow + +#endif \ No newline at end of file diff --git a/include/flexflow/ops/sampling_params.h b/include/flexflow/ops/sampling_params.h new file mode 100644 index 0000000000..1449ddbf54 --- /dev/null +++ b/include/flexflow/ops/sampling_params.h @@ -0,0 +1,24 @@ +#ifndef _FLEXFLOW_SAMPLING_PARAMS_H +#define _FLEXFLOW_SAMPLING_PARAMS_H + +#include "flexflow/ffconst.h" +#include "flexflow/parallel_tensor.h" + +namespace FlexFlow { + +struct SamplingParams { + float top_p; + bool is_valid(ParallelTensorShape const &) const; +}; +bool operator==(SamplingParams const &, SamplingParams const &); + +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::SamplingParams const &) const; +}; +} // namespace std + +#endif // _FLEXFLOW_SAMPLING_PARAMS_H \ No newline at end of file diff --git a/src/ops/sampling.cc b/src/ops/sampling.cc new file mode 100644 index 0000000000..2b544580da --- /dev/null +++ b/src/ops/sampling.cc @@ -0,0 +1,351 @@ +/* Copyright 2023 CMU, Facebook, LANL, MIT, NVIDIA, and Stanford (alphabetical) + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "flexflow/ops/sampling.h" +#include "flexflow/model.h" +#include "flexflow/utils/hash_utils.h" +#include "legion/legion_utilities.h" +#if defined(FF_USE_CUDA) || defined(FF_USE_HIP_CUDA) +#include "flexflow/utils/cuda_helper.h" +#else +#include "flexflow/utils/hip_helper.h" +#endif + +namespace FlexFlow { +// declare Legion names +using Legion::ArgumentMap; +using Legion::Context; +using Legion::coord_t; +using Legion::Domain; +using Legion::FutureMap; +using Legion::IndexLauncher; +using Legion::InlineLauncher; +using Legion::Machine; +using Legion::Memory; +using Legion::PhysicalRegion; +using Legion::Predicate; +using Legion::Rect; +using Legion::RegionRequirement; +using Legion::Runtime; +using Legion::Task; +using Legion::TaskArgument; +using Legion::TaskLauncher; +using PCG::Node; + +// For an input tensor, computes the top k entries in each row +// (resp. vector along the last dimension). Thus, +// values.shape = indices.shape = input.shape[:-1] + [k] +Tensor FFModel::sampling(const Tensor input, float top_p, char const *name) { + Layer *li = new Layer(this, + OP_SAMPLING, + input->data_type, + name, + 1 /*inputs*/, + 0 /*weights*/, + 1 /*outputs*/, + input); + { + int numdims = input->num_dims; + int dims[MAX_TENSOR_DIM]; + for (int i = 0; i < numdims; i++) { + dims[i] = input->dims[i]; + } + // now just support 1 output + dims[0] = 1; + // li->outputs[0] = create_tensor_legion_ordering( + // numdims, dims, input->data_type, li, 0, true /*create_grad*/); + li->outputs[0] = create_tensor_legion_ordering( + numdims, dims, DT_INT32, li, 0, false /*create_grad*/); + } + layers.push_back(li); + li->add_float_property("top_p", top_p); + // outputs[0] = li->outputs[0]; + // outputs[1] = li->outputs[1]; + return li->outputs[0]; +} + +Op *Sampling::create_operator_from_layer( + FFModel &model, + Layer const *layer, + std::vector const &inputs) { + float top_p; + layer->get_float_property("top_p", top_p); + return new Sampling(model, inputs[0], top_p, layer->name); +} + +SamplingParams Sampling::get_params() const { + SamplingParams params; + params.top_p = this->top_p; + return params; +} + +bool SamplingParams::is_valid(ParallelTensorShape const &) const { + // topk is always valid + return true; +} + +bool operator==(SamplingParams const &lhs, SamplingParams const &rhs) { + return lhs.top_p == rhs.top_p; +} + +Sampling::Sampling(FFModel &model, + const ParallelTensor _input, + float _top_p, + char const *name) + : Op(model, + OP_SAMPLING, + _input->data_type, + name, + 1 /*inputs*/, + 0 /*weights*/, + 1 /*outputs*/, + _input), + top_p(_top_p) { + int numdim = inputs[0]->num_dims; + ParallelDim dims[MAX_TENSOR_DIM]; + for (int i = 0; i < numdim; i++) { + dims[i] = inputs[0]->dims[i]; + } + dims[0].size = 1; + assert(inputs[0]->dims[0].degree == 1); + assert(inputs[0]->dims[0].parallel_idx == -1); + // outputs[0] = model.create_parallel_tensor_legion_ordering( + // numdim, dims, _input->data_type, this, 0 /*owner_idx*/); + outputs[0] = model.create_parallel_tensor_legion_ordering( + numdim, dims, DT_INT32, this, 0 /*owner_idx*/); +} + +Sampling::Sampling(FFModel &model, + Sampling const &other, + const ParallelTensor input) + : Sampling(model, input, other.top_p, other.name) {} + +Sampling::Sampling(FFModel &model, + SamplingParams const ¶ms, + const ParallelTensor input, + char const *name) + : Sampling(model, input, params.top_p, name) {} + +void Sampling::init_inference(FFModel const &ff, + std::vector const &batch_inputs, + std::vector const &batch_outputs, + 
MachineView const *mv) { + assert(check_output_input_weight_same_parallel_is()); + parallel_is = batch_outputs[0]->parallel_is; + ArgumentMap argmap; + Context ctx = ff.config.lg_ctx; + Runtime *runtime = ff.config.lg_hlr; + MachineView const *view = mv ? mv : &batch_outputs[0]->machine_view; + size_t machine_view_hash = view->hash(); + set_argumentmap_for_init_inference(ff, argmap, batch_outputs[0]); + IndexLauncher launcher(ARG_TOPK_INIT_TASK_ID, + parallel_is, + TaskArgument(this, sizeof(Sampling)), + argmap, + Predicate::TRUE_PRED, + false /*must*/, + 0 /*mapper_id*/, + machine_view_hash); + launcher.add_region_requirement(RegionRequirement(batch_inputs[0]->part, + 0 /*projection id*/, + READ_WRITE, + EXCLUSIVE, + batch_inputs[0]->region)); + launcher.add_field(0, FID_DATA); + launcher.add_region_requirement(RegionRequirement(batch_outputs[0]->part, + 0 /*projection id*/, + WRITE_ONLY, + EXCLUSIVE, + batch_outputs[0]->region)); + launcher.add_field(1, FID_DATA); + // launcher.add_region_requirement(RegionRequirement(batch_outputs[1]->part, + // 0 /*projection id*/, + // WRITE_ONLY, + // EXCLUSIVE, + // batch_outputs[1]->region)); + // launcher.add_field(2, FID_DATA); + FutureMap fm = runtime->execute_index_space(ctx, launcher); + fm.wait_all_results(); + set_opmeta_from_futuremap_inference(ff, fm, batch_outputs[0]); +} + +void Sampling::init(FFModel const &ff) { + assert(check_output_input_weight_same_parallel_is()); + parallel_is = outputs[0]->parallel_is; + ArgumentMap argmap; + Context ctx = ff.config.lg_ctx; + Runtime *runtime = ff.config.lg_hlr; + set_argumentmap_for_init(ff, argmap); + IndexLauncher launcher(ARG_TOPK_INIT_TASK_ID, + parallel_is, + TaskArgument(this, sizeof(Sampling)), + argmap, + Predicate::TRUE_PRED, + false /*must*/, + 0 /*mapper_id*/, + outputs[0]->machine_view.hash()); + launcher.add_region_requirement(RegionRequirement(inputs[0]->part, + 0 /*projection id*/, + READ_WRITE, + EXCLUSIVE, + inputs[0]->region)); + launcher.add_field(0, FID_DATA); + launcher.add_region_requirement(RegionRequirement(outputs[0]->part, + 0 /*projection id*/, + WRITE_ONLY, + EXCLUSIVE, + outputs[0]->region)); + launcher.add_field(1, FID_DATA); + // launcher.add_region_requirement(RegionRequirement(outputs[1]->part, + // 0 /*projection id*/, + // WRITE_ONLY, + // EXCLUSIVE, + // outputs[1]->region)); + // launcher.add_field(2, FID_DATA); + FutureMap fm = runtime->execute_index_space(ctx, launcher); + fm.wait_all_results(); + set_opmeta_from_futuremap(ff, fm); +} + +OpMeta *Sampling::init_task(Task const *task, + std::vector const ®ions, + Context ctx, + Runtime *runtime) { + Sampling *s = (Sampling *)task->args; + FFHandler handle = *((FFHandler *)task->local_args); + SamplingMeta *m = new SamplingMeta(handle, s); + m->profiling = s->profiling; + m->top_p = s->top_p; + return m; +} + +void Sampling::forward(FFModel const &ff) { + // Sampling does not support forward + assert(false); +} + +FutureMap Sampling::inference(FFModel const &ff, + BatchConfig const &bc, + std::vector const &batch_inputs, + std::vector const &batch_outputs, + MachineView const *mv) { + ArgumentMap argmap; + Context ctx = ff.config.lg_ctx; + Runtime *runtime = ff.config.lg_hlr; + parallel_is = batch_outputs[0]->parallel_is; + MachineView const *view = mv ? 
mv : &batch_outputs[0]->machine_view; + set_argumentmap_for_inference(ff, argmap, batch_outputs[0]); + size_t machine_view_hash = view->hash(); + /* std::cout << "Sampling op machine_view: " << *(MachineView const *)mv + << std::endl; */ + IndexLauncher launcher(ARG_TOPK_INF_TASK_ID, + parallel_is, + TaskArgument(NULL, 0), + argmap, + Predicate::TRUE_PRED, + false /*must*/, + 0 /*mapper_id*/, + machine_view_hash); + launcher.add_region_requirement(RegionRequirement(batch_inputs[0]->part, + 0 /*projection id*/, + READ_WRITE, + EXCLUSIVE, + batch_inputs[0]->region)); + launcher.add_field(0, FID_DATA); + launcher.add_region_requirement(RegionRequirement(batch_outputs[0]->part, + 0 /*projection id*/, + WRITE_ONLY, + EXCLUSIVE, + batch_outputs[0]->region)); + launcher.add_field(1, FID_DATA); + // launcher.add_region_requirement(RegionRequirement(batch_outputs[1]->part, + // 0 /*projection id*/, + // WRITE_ONLY, + // EXCLUSIVE, + // batch_outputs[1]->region)); + // launcher.add_field(2, FID_DATA); + return runtime->execute_index_space(ctx, launcher); +} + +InferenceResult + Sampling::inference_task(Task const *task, + std::vector const ®ions, + Context ctx, + Runtime *runtime) { + assert(regions.size() == 2); + assert(task->regions.size() == 2); + // const Sampling* topk = (const Sampling*) task->args; + SamplingMeta const *m = *((SamplingMeta **)task->local_args); + + GenericTensorAccessorW input = helperGetGenericTensorAccessorRW( + m->input_type[0], regions[0], task->regions[0], FID_DATA, ctx, runtime); + GenericTensorAccessorW indices = helperGetGenericTensorAccessorWO( + DT_INT32, regions[1], task->regions[1], FID_DATA, ctx, runtime); + + Sampling::forward_kernel_wrapper(m, input, indices); + + int length = input.domain.hi()[0] - input.domain.lo()[0] + 1; + int batch_size = input.domain.get_volume() / length; + + InferenceResult ir; + download_tensor( + indices.get_int32_ptr(), ir.token_ids, batch_size); + return ir; +} + +void Sampling::backward(FFModel const &ff) { + // Sampling does not support backward + assert(false); +} + +void Sampling::serialize(Legion::Serializer &sez) const { + sez.serialize(this->top_p); +} + +Node Sampling::deserialize(FFModel &ff, + Legion::Deserializer &dez, + ParallelTensor inputs[], + int num_inputs) { + assert(num_inputs == 1); + float top_p; + dez.deserialize(top_p); + SamplingParams params; + params.top_p = top_p; + return ff.get_or_create_node(inputs[0], params); +} + +Op *Sampling::materialize(FFModel &ff, + ParallelTensor inputs[], + int num_inputs) const { + SamplingParams params = get_params(); + return new Sampling(ff, params, inputs[0], this->name); +} + +bool Sampling::measure_operator_cost(Simulator *sim, + MachineView const &mv, + CostMetrics &cost_metrics) const { + return false; +} + +}; // namespace FlexFlow + +namespace std { +size_t hash::operator()( + FlexFlow::SamplingParams const ¶ms) const { + size_t key = 0; + hash_combine(key, params.top_p); + return key; +} +}; // namespace std \ No newline at end of file diff --git a/src/ops/sampling.cu b/src/ops/sampling.cu new file mode 100644 index 0000000000..904b1a0cf6 --- /dev/null +++ b/src/ops/sampling.cu @@ -0,0 +1,315 @@ +/* Copyright 2023 CMU, Facebook, LANL, MIT, NVIDIA, and Stanford (alphabetical) + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "flexflow/ops/sampling.h" +#include "flexflow/utils/cuda_helper.h" +#include +#include + +namespace FlexFlow { + +template +__global__ void mask_value_above_top_p(DT *input_ptr, + DT *cumsum_ptr, + float top_p, + int total_eles) { + CUDA_KERNEL_LOOP(i, total_eles) { + if ((cumsum_ptr[i] - input_ptr[i]) > static_cast
<DT>(top_p)) {
+      input_ptr[i] = 0.0;
+    }
+  }
+}
+
+template <typename DT>
+__global__ void re_normalized(DT *input_ptr, DT div, int length) {
+  CUDA_KERNEL_LOOP(i, length) {
+    input_ptr[i] /= div;
+  }
+}
+
+template <typename DT>
+__global__ void sampleMultinomialOnce(long long N, DT *input_ptr) {
+  extern __shared__ unsigned char my_smem[];
+  __shared__ bool found;
+  __shared__ unsigned foundPos;
+
+  float *smem = reinterpret_cast<float *>(my_smem);
+
+  float accZero = static_cast<float>(0);
+  DT zero = static_cast<DT>
(0); + + for (int64_t curDist = blockIdx.x; curDist < distributions; + curDist += gridDim.x) { + + float sum = accZero; + DT val; + + for (int cat = threadIdx.x; cat < N; cat += blockDim.x) { + val = dist[curDist * stride_dist + cat * stride_categories]; + CUDA_KERNEL_ASSERT(!at::_isnan(val)); + CUDA_KERNEL_ASSERT(!_isinf(val)); + CUDA_KERNEL_ASSERT(!(val < zero)); + sum = sum + static_cast(val); + } + + + //sum + sum = BlockReduceSum(sum, smem); + + if (threadIdx.x == 0) { + foundPos = 0; + smem[0] = sum; + smem[1] = sampled[curDist]; + } + + __syncthreads(); + sum = smem[0]; + + DT sample = static_cast
(smem[1]); + __syncthreads(); + + if (sum == accZero) { + // Choose the first element + if (threadIdx.x == 0) { + dest[curDist] = 0; + } + + continue; + } + + //ELSE + int chunks = (categories + (int)blockDim.x - 1) / blockDim.x; + float prevHighProb = accZero; + + found = false; + for (int chunk = 0; chunk < chunks && !found; ++chunk) { + + int cat = chunk * blockDim.x + threadIdx.x; + float dist_val = cat < categories ? + static_cast(dist[curDist * stride_dist + cat * stride_categories]) / sum : + accZero; + + smem[threadIdx.x] = dist_val; + __syncthreads(); + + // Perform an inclusive prefix sum of the shared memory contents + for (int offset = 1; offset < blockDim.x; offset *= 2) { + float val = accZero; + + if (threadIdx.x >= offset) { + val = smem[threadIdx.x - offset] + smem[threadIdx.x]; + } + + __syncthreads(); + if (threadIdx.x >= offset) { + smem[threadIdx.x] = val; + } + __syncthreads(); + } + + // Each thread will check to see if the sample falls in its + // bucket + DT curBucket = + static_cast
<DT>(smem[threadIdx.x] + prevHighProb);
+      DT prevBucket = static_cast<DT>
( + threadIdx.x == 0 ? prevHighProb + : smem[threadIdx.x - 1] + prevHighProb); + bool inBucket = + (cat < categories) && + (!(sample >= curBucket) && + (sample >= prevBucket) && + (dist_val > zero)); + + if (inBucket) { + // We're done; we have the sample + // Torch indices are 1-based + atomicMax(&foundPos, cat); + found = true; + } + + // Store the previous scan's high value for future use + prevHighProb = prevHighProb + smem[blockDim.x - 1]; + __syncthreads(); + } + + if (threadIdx.x == 0) { + if (found) { + dest[curDist] = foundPos; + } else { + // This should address a rare bug where we don't select a valid index. This likely occurs when + // due to floating point arithmetic rounding errors, our cumulative sum does not add up to 1, but + // and our uniform sample is greater than this value. In this case we likely have unitialized memory + // in dest[curDist]. So basically we will loop through the distribution and pick the largest index + // where the distribution is non-zero. This is obviously terribly inefficient, but due to the + // rarity in which this occurs, this should not be an issue. + for (int cat = categories - 1; cat >= 0; --cat) { + if (dist[curDist * stride_dist + cat * stride_categories] > zero) { + dest[curDist] = cat; + break; + } + } + } + } + + + } +} + + +/*static*/ +template +void Sampling::forward_kernel(SamplingMeta const *m, + DT *input_ptr, + int *indices_ptr, + float top_p, + int length, + int batch_size, + cudaStream_t stream) { + // 1. sort + // 2. cumsum + // how to do it in parallel? + + checkCUDA(cudaMemcpy(static_cast
<DT *>(m->origin_ptr),
+                       input_ptr,
+                       sizeof(DT) * 15 * length,
+                       cudaMemcpyDeviceToDevice));
+
+  std::cout << "asdqs: " << length << "\n";
+
+  for (int i = 0; i < 15; i++) {
+    thrust::sort(thrust::device,
+                 input_ptr + i * length,
+                 input_ptr + (i + 1) * length,
+                 thrust::greater<DT>());
+    thrust::sort(thrust::device,
+                 static_cast<DT *>(m->origin_ptr) + i * length,
+                 static_cast<DT *>(m->origin_ptr) + (i + 1) * length,
+                 thrust::greater<DT>());
+    thrust::inclusive_scan(thrust::device,
+                           input_ptr + i * length,
+                           input_ptr + (i + 1) * length,
+                           static_cast<DT *>(m->cumsum_ptr) + i * length);
+  }
+  std::cout << "sdsd"
+            << "\n";
+
+  // 3. mask
+  int parallelism = 15 * length;
+  mask_value_above_top_p<DT>
+      <<<GET_BLOCKS(parallelism), min(CUDA_NUM_THREADS, parallelism), 0, stream>>>(
+          input_ptr, static_cast<DT *>(m->cumsum_ptr), top_p, parallelism);
+
+  // 4. sum/div
+  std::cout << "sadsd2www"
+            << "\n";
+  for (int i = 0; i < 15; i++) {
+    DT sum = thrust::reduce(
+        thrust::device, input_ptr + i * length, input_ptr + (i + 1) * length);
+    parallelism = length;
+
+    re_normalized<DT>
+        <<<GET_BLOCKS(parallelism), min(CUDA_NUM_THREADS, parallelism), 0, stream>>>(
+            input_ptr + i * length, sum, length);
+  }
+  std::cout << "sdds332"
+            << "\n";
+
+  // 5.multinominal
+  for (int i = 0; i < 15; i++) {
+    parallelism = length;
+    DT random = static_cast<DT>(((float)std::rand()) / RAND_MAX);
+    thrust::inclusive_scan(thrust::device,
+                           input_ptr + i * length,
+                           input_ptr + (i + 1) * length,
+                           static_cast<DT *>(m->cumsum_ptr) + i * length);
+
+    // find_idx<DT>
+    //     <<<GET_BLOCKS(parallelism), min(CUDA_NUM_THREADS, parallelism), 0, stream>>>(
+    //         static_cast<DT *>(m->cumsum_ptr) + i * length,
+    //         static_cast<DT *>(m->origin_ptr) + i * length,
+    //         random,
+    //         length,
+    //         indices_ptr,
+    //         i);
+    for (int j = 0; j < length; j++) {
+      if ((static_cast<DT *>(m->cumsum_ptr) + i * length)[j] >= random) {
+        indices_ptr[i] = (static_cast<DT *>
(m->origin_ptr) + i * length)[i]; + printf("k value is:%d. %f\n", i, indices_ptr[i]); + break; + } + } + } + // print_tensor((int *)indices_ptr, 15, "sdsdasd"); + assert(false); +} + +/*static*/ +void Sampling::forward_kernel_wrapper(SamplingMeta const *m, + GenericTensorAccessorW const &input, + GenericTensorAccessorW const &indices) { + cudaStream_t stream; + checkCUDA(get_legion_stream(&stream)); + + cudaEvent_t t_start, t_end; + if (m->profiling) { + cudaEventCreate(&t_start); + cudaEventCreate(&t_end); + cudaEventRecord(t_start, stream); + } + int length = input.domain.hi()[0] - input.domain.lo()[0] + 1; + int batch_size = input.domain.get_volume() / length; + + if (input.data_type == DT_HALF) { + Sampling::forward_kernel(m, + input.get_half_ptr(), + indices.get_int32_ptr(), + m->top_p, + length, + batch_size, + stream); + } else if (input.data_type == DT_FLOAT) { + Sampling::forward_kernel(m, + input.get_float_ptr(), + indices.get_int32_ptr(), + m->top_p, + length, + batch_size, + stream); + } else { + assert(false && "Unsupported data type"); + } + + if (m->profiling) { + cudaEventRecord(t_end, stream); + checkCUDA(cudaEventSynchronize(t_end)); + float elapsed = 0; + checkCUDA(cudaEventElapsedTime(&elapsed, t_start, t_end)); + cudaEventDestroy(t_start); + cudaEventDestroy(t_end); + printf("[Sampling] forward time = %.2lfms\n", elapsed); + } +} + +SamplingMeta::SamplingMeta(FFHandler handler, Op const *op) + : OpMeta(handler, op) { + checkCUDA(cudaMalloc(&cumsum_ptr, 15 * 32000 * sizeof(float))); + checkCUDA(cudaMalloc(&sampled, 15 * 32000 * sizeof(float))); +} + +}; // namespace FlexFlow \ No newline at end of file diff --git a/src/runtime/ffconst_utils.cc b/src/runtime/ffconst_utils.cc index d2b68595bd..bb50a4d969 100644 --- a/src/runtime/ffconst_utils.cc +++ b/src/runtime/ffconst_utils.cc @@ -178,6 +178,8 @@ std::string get_operator_type_name(OperatorType type) { return "GELU"; case OP_IDENTITY: return "Identity"; + case OP_SAMPLING: + return "Sampling"; // Parallel Ops case OP_REPARTITION: return "Repartition"; diff --git a/src/runtime/graph.cc b/src/runtime/graph.cc index e8a1b6f9f1..067a846c58 100644 --- a/src/runtime/graph.cc +++ b/src/runtime/graph.cc @@ -44,6 +44,7 @@ #include "flexflow/ops/spec_inc_multihead_self_attention.h" #include "flexflow/ops/split.h" #include "flexflow/ops/topk.h" +#include "flexflow/ops/sampling.h" #include "flexflow/ops/transpose.h" #include "flexflow/ops/tree_inc_multihead_self_attention.h" #include "flexflow/parallel_ops/combine.h" @@ -2866,6 +2867,10 @@ void FFModel::deserialize_graph_optimal_view( node = BeamTopK::deserialize(*this, dez, inputs, num_inputs); break; } + case OP_SAMPLING: { + node = Sampling::deserialize(*this, dez, inputs, num_inputs); + break; + } case OP_GROUP_BY: { node = Group_by::deserialize(*this, dez, inputs, num_inputs); break; diff --git a/src/runtime/model.cc b/src/runtime/model.cc index 64c3a2eb61..96fedae9d4 100644 --- a/src/runtime/model.cc +++ b/src/runtime/model.cc @@ -56,6 +56,7 @@ #include "flexflow/ops/spec_inc_multihead_self_attention.h" #include "flexflow/ops/split.h" #include "flexflow/ops/topk.h" +#include "flexflow/ops/sampling.h" #include "flexflow/ops/transpose.h" #include "flexflow/ops/tree_inc_multihead_self_attention.h" #include "flexflow/parallel_ops/combine.h" @@ -2929,6 +2930,11 @@ Op *FFModel::create_operator_from_layer( operators.push_back(op); return op; } + case OP_SAMPLING: { + Op *op = Sampling::create_operator_from_layer(*this, layer, inputs); + operators.push_back(op); + return op; + } case 
OP_GROUP_BY: { Op *op = Group_by::create_operator_from_layer(*this, layer, inputs); operators.push_back(op); @@ -4687,6 +4693,21 @@ void register_flexflow_internal_tasks() { BeamTopK::inference_task>( registrar, "BeamTopK Inference Task"); } + // Sampling task + { + TaskVariantRegistrar registrar(SAMPLING_INIT_TASK_ID, "Sampling Init"); + registrar.add_constraint(ProcessorConstraint(Processor::TOC_PROC)); + registrar.set_leaf(); + Runtime::preregister_task_variant( + registrar, "Sampling Init Task"); + } + { + TaskVariantRegistrar registrar(SAMPLING_INF_TASK_ID, "Sampling Inference"); + registrar.add_constraint(ProcessorConstraint(Processor::TOC_PROC)); + registrar.set_leaf(); + Runtime::preregister_task_variant( + registrar, "Sampling Inference Task"); + } // Transpose task { TaskVariantRegistrar registrar(TRANSPOSE_INIT_TASK_ID, "Transpose Init"); From 65c070e1bc7d47c4f1a927f1fd94eba78c3aef2e Mon Sep 17 00:00:00 2001 From: xinhaoc Date: Mon, 10 Jul 2023 04:52:18 +0000 Subject: [PATCH 02/16] sort --- include/flexflow/ops/sampling.h | 10 +- inference/models/llama.cc | 4 +- src/ops/sampling.cc | 14 +- src/ops/sampling.cu | 294 +++++++++-------------------- src/ops/sampling_ref.cu | 316 ++++++++++++++++++++++++++++++++ 5 files changed, 418 insertions(+), 220 deletions(-) create mode 100644 src/ops/sampling_ref.cu diff --git a/include/flexflow/ops/sampling.h b/include/flexflow/ops/sampling.h index e57ed0b870..50830ce4a6 100644 --- a/include/flexflow/ops/sampling.h +++ b/include/flexflow/ops/sampling.h @@ -11,9 +11,13 @@ namespace FlexFlow { class SamplingMeta : public OpMeta { public: float top_p; - void *cumsum_ptr; - void *sampled; - SamplingMeta(FFHandler handle, Op const *op); + void *sorted_logits; + int *sorted_idx; + int *begin_offset; + int *end_offset; + int *idx; + void *d_temp_storage; + SamplingMeta(FFHandler handle, Op const *op, int batch_size, int total_ele); }; class Sampling : public Op { diff --git a/inference/models/llama.cc b/inference/models/llama.cc index 1e61f43a98..3b4ddbd6d3 100644 --- a/inference/models/llama.cc +++ b/inference/models/llama.cc @@ -241,7 +241,9 @@ void LLAMA::create_llama_model(FFModel &ff, Tensor softmax = ff.softmax(dense, -1); output = ff.beam_top_k(softmax, llama_config.max_beam_width, false); } else { - output = ff.arg_top_k(dense, /*k=*/1, false); + // output = ff.arg_top_k(dense, /*k=*/1, false); + Tensor softmax = ff.softmax(dense, -1); + output = ff.sampling(softmax, 0.95); } // Compile the model diff --git a/src/ops/sampling.cc b/src/ops/sampling.cc index 2b544580da..bb63c6ff90 100644 --- a/src/ops/sampling.cc +++ b/src/ops/sampling.cc @@ -150,7 +150,7 @@ void Sampling::init_inference(FFModel const &ff, MachineView const *view = mv ? 
mv : &batch_outputs[0]->machine_view; size_t machine_view_hash = view->hash(); set_argumentmap_for_init_inference(ff, argmap, batch_outputs[0]); - IndexLauncher launcher(ARG_TOPK_INIT_TASK_ID, + IndexLauncher launcher(SAMPLING_INIT_TASK_ID, parallel_is, TaskArgument(this, sizeof(Sampling)), argmap, @@ -188,7 +188,7 @@ void Sampling::init(FFModel const &ff) { Context ctx = ff.config.lg_ctx; Runtime *runtime = ff.config.lg_hlr; set_argumentmap_for_init(ff, argmap); - IndexLauncher launcher(ARG_TOPK_INIT_TASK_ID, + IndexLauncher launcher(SAMPLING_INIT_TASK_ID, parallel_is, TaskArgument(this, sizeof(Sampling)), argmap, @@ -225,7 +225,13 @@ OpMeta *Sampling::init_task(Task const *task, Runtime *runtime) { Sampling *s = (Sampling *)task->args; FFHandler handle = *((FFHandler *)task->local_args); - SamplingMeta *m = new SamplingMeta(handle, s); + GenericTensorAccessorW acc_input = helperGetGenericTensorAccessorRW( + s->inputs[0]->data_type, regions[0], task->regions[0], FID_DATA, ctx, runtime); + + int length = acc_input.domain.hi()[0] - acc_input.domain.lo()[0] + 1; + int batch_size = acc_input.domain.get_volume() / length; + + SamplingMeta *m = new SamplingMeta(handle, s, batch_size, length * batch_size); m->profiling = s->profiling; m->top_p = s->top_p; return m; @@ -250,7 +256,7 @@ FutureMap Sampling::inference(FFModel const &ff, size_t machine_view_hash = view->hash(); /* std::cout << "Sampling op machine_view: " << *(MachineView const *)mv << std::endl; */ - IndexLauncher launcher(ARG_TOPK_INF_TASK_ID, + IndexLauncher launcher(SAMPLING_INF_TASK_ID, parallel_is, TaskArgument(NULL, 0), argmap, diff --git a/src/ops/sampling.cu b/src/ops/sampling.cu index 904b1a0cf6..88192f4677 100644 --- a/src/ops/sampling.cu +++ b/src/ops/sampling.cu @@ -13,10 +13,10 @@ * limitations under the License. */ +#include "cub/cub.cuh" #include "flexflow/ops/sampling.h" #include "flexflow/utils/cuda_helper.h" -#include -#include +#include "flexflow/ffconst_utils.h" namespace FlexFlow { @@ -32,142 +32,22 @@ __global__ void mask_value_above_top_p(DT *input_ptr, } } -template -__global__ void re_normalized(DT *input_ptr, DT div, int length) { - CUDA_KERNEL_LOOP(i, length) { - input_ptr[i] /= div; - } -} - -template -__global__ void sampleMultinomialOnce(long long N, DT *input_ptr) { - extern __shared__ unsigned char my_smem[]; - __shared__ bool found; - __shared__ unsigned foundPos; - - float *smem = reinterpret_cast(my_smem); - - float accZero = static_cast(0); - DT zero = static_cast
(0); - - for (int64_t curDist = blockIdx.x; curDist < distributions; - curDist += gridDim.x) { - - float sum = accZero; - DT val; - - for (int cat = threadIdx.x; cat < N; cat += blockDim.x) { - val = dist[curDist * stride_dist + cat * stride_categories]; - CUDA_KERNEL_ASSERT(!at::_isnan(val)); - CUDA_KERNEL_ASSERT(!_isinf(val)); - CUDA_KERNEL_ASSERT(!(val < zero)); - sum = sum + static_cast(val); - } - - - //sum - sum = BlockReduceSum(sum, smem); - - if (threadIdx.x == 0) { - foundPos = 0; - smem[0] = sum; - smem[1] = sampled[curDist]; - } - - __syncthreads(); - sum = smem[0]; - - DT sample = static_cast
(smem[1]); - __syncthreads(); - - if (sum == accZero) { - // Choose the first element - if (threadIdx.x == 0) { - dest[curDist] = 0; - } - - continue; - } - - //ELSE - int chunks = (categories + (int)blockDim.x - 1) / blockDim.x; - float prevHighProb = accZero; - - found = false; - for (int chunk = 0; chunk < chunks && !found; ++chunk) { - - int cat = chunk * blockDim.x + threadIdx.x; - float dist_val = cat < categories ? - static_cast(dist[curDist * stride_dist + cat * stride_categories]) / sum : - accZero; - - smem[threadIdx.x] = dist_val; - __syncthreads(); - - // Perform an inclusive prefix sum of the shared memory contents - for (int offset = 1; offset < blockDim.x; offset *= 2) { - float val = accZero; - - if (threadIdx.x >= offset) { - val = smem[threadIdx.x - offset] + smem[threadIdx.x]; - } - - __syncthreads(); - if (threadIdx.x >= offset) { - smem[threadIdx.x] = val; - } - __syncthreads(); - } - - // Each thread will check to see if the sample falls in its - // bucket - DT curBucket = - static_cast
(smem[threadIdx.x] + prevHighProb); - DT prevBucket = static_cast
( - threadIdx.x == 0 ? prevHighProb - : smem[threadIdx.x - 1] + prevHighProb); - bool inBucket = - (cat < categories) && - (!(sample >= curBucket) && - (sample >= prevBucket) && - (dist_val > zero)); - - if (inBucket) { - // We're done; we have the sample - // Torch indices are 1-based - atomicMax(&foundPos, cat); - found = true; - } - - // Store the previous scan's high value for future use - prevHighProb = prevHighProb + smem[blockDim.x - 1]; - __syncthreads(); - } - - if (threadIdx.x == 0) { - if (found) { - dest[curDist] = foundPos; - } else { - // This should address a rare bug where we don't select a valid index. This likely occurs when - // due to floating point arithmetic rounding errors, our cumulative sum does not add up to 1, but - // and our uniform sample is greater than this value. In this case we likely have unitialized memory - // in dest[curDist]. So basically we will loop through the distribution and pick the largest index - // where the distribution is non-zero. This is obviously terribly inefficient, but due to the - // rarity in which this occurs, this should not be an issue. - for (int cat = categories - 1; cat >= 0; --cat) { - if (dist[curDist * stride_dist + cat * stride_categories] > zero) { - dest[curDist] = cat; - break; - } - } - } +__global__ void init_idxs(int batch_size, + int vocab_size, + int total_eles, + int *idx, + int *begin_offset, + int *end_offset) { + CUDA_KERNEL_LOOP(i, total_eles) { + idx[i] = i % vocab_size; + if (i % vocab_size == 0) { + // printf("adfadf :%d\n", i); + begin_offset[i / vocab_size] = i; + end_offset[i / vocab_size] = i; } - - } } - /*static*/ template void Sampling::forward_kernel(SamplingMeta const *m, @@ -180,81 +60,62 @@ void Sampling::forward_kernel(SamplingMeta const *m, // 1. sort // 2. cumsum // how to do it in parallel? - - checkCUDA(cudaMemcpy(static_cast
(m->origin_ptr), - input_ptr, - sizeof(DT) * 15 * length, - cudaMemcpyDeviceToDevice)); - - std::cout << "asdqs: " << length << "\n"; - - for (int i = 0; i < 15; i++) { - thrust::sort(thrust::device, - input_ptr + i * length, - input_ptr + (i + 1) * length, - thrust::greater
()); - thrust::sort(thrust::device, - static_cast
(m->origin_ptr) + i * length, - static_cast
(m->origin_ptr) + (i + 1) * length, - thrust::greater
()); - thrust::inclusive_scan(thrust::device, - input_ptr + i * length, - input_ptr + (i + 1) * length, - static_cast
(m->cumsum_ptr) + i * length); - } - std::cout << "sdsd" - << "\n"; - - // 3. mask - int parallelism = 15 * length; - mask_value_above_top_p
<<>>( - input_ptr, static_cast
(m->cumsum_ptr), top_p, parallelism); - - // 4. sum/div - std::cout << "sadsd2www" - << "\n"; - for (int i = 0; i < 15; i++) { - DT sum = thrust::reduce( - thrust::device, input_ptr + i * length, input_ptr + (i + 1) * length); - parallelism = length; - - re_normalized
<<>>(input_ptr + i * length, sum, length); - } - std::cout << "sdds332" - << "\n"; - - // 5.multinominal - for (int i = 0; i < 15; i++) { - parallelism = length; - DT random = static_cast
(((float)std::rand()) / RAND_MAX); - thrust::inclusive_scan(thrust::device, - input_ptr + i * length, - input_ptr + (i + 1) * length, - static_cast
(m->cumsum_ptr) + i * length); - - // find_idx
<<>>(static_cast
(m->cumsum_ptr) + i * length, - // static_cast
(m->origin_ptr) + i * length, - // random, - // length, - // indices_ptr, - // i); - for (int j = 0; j < length; j++) { - if ((static_cast
(m->cumsum_ptr) + i * length)[j] >= random) { - indices_ptr[i] = (static_cast
(m->origin_ptr) + i * length)[i]; - printf("k value is:%d. %f\n", i, indices_ptr[i]); - break; - } - } - } + // init + print_tensor((float *)input_ptr+ 32000, 32, "inputttt"); + std::cout<< "meta " << length << ", " << batch_size << "\n"; + int parallelism = length * batch_size; + init_idxs<<>>(batch_size, + length, + length * batch_size, + m->idx, + m->begin_offset, + m->end_offset); + + checkCUDA(cudaDeviceSynchronize()); + // print_tensor(m->begin_offset, 64, "ofsset"); + // print_tensor(m->end_offset, 64, "ofsset"); + + std::cout<<"-------------------------sampling kernel _--------------------" << "\n"; + // sort + size_t temp_storage_bytes = 0; + void *d_temp_storage = nullptr; + cub::DeviceSegmentedRadixSort::SortPairsDescending( + m->d_temp_storage, + temp_storage_bytes, + input_ptr, + static_cast
<DT *>(m->sorted_logits),
+      m->idx,
+      m->sorted_idx,
+      length * batch_size,
+      batch_size,
+      m->begin_offset,
+      m->end_offset + 1,
+      0,              // begin_bit
+      sizeof(DT) * 8, // end_bit = sizeof(KeyT) * 8
+      stream);
+
+
+  checkCUDA(cudaDeviceSynchronize());
+  cudaMalloc(&d_temp_storage, temp_storage_bytes);
+  cub::DeviceSegmentedRadixSort::SortPairsDescending(
+      d_temp_storage,
+      temp_storage_bytes,
+      input_ptr,
+      static_cast<DT *>
(m->sorted_logits), + m->idx, + m->sorted_idx, + length * batch_size, + batch_size, + m->begin_offset, + m->end_offset + 1, + 0, // begin_bit + sizeof(DT) * 8, // end_bit = sizeof(KeyT) * 8 + stream); + print_tensor((float *)m->sorted_logits + 32000, 32, "after sort"); + print_tensor(m->sorted_idx+ 32000, 32, "after sort"); // print_tensor((int *)indices_ptr, 15, "sdsdasd"); assert(false); } @@ -306,10 +167,19 @@ void Sampling::forward_kernel_wrapper(SamplingMeta const *m, } } -SamplingMeta::SamplingMeta(FFHandler handler, Op const *op) +SamplingMeta::SamplingMeta(FFHandler handler, + Op const *op, + int batch_size, + int total_ele) : OpMeta(handler, op) { - checkCUDA(cudaMalloc(&cumsum_ptr, 15 * 32000 * sizeof(float))); - checkCUDA(cudaMalloc(&sampled, 15 * 32000 * sizeof(float))); + DataType data_type = op->data_type; + checkCUDA(cudaMalloc(&begin_offset, (batch_size + 1) * sizeof(int))); + checkCUDA(cudaMalloc(&end_offset, (batch_size + 1) * sizeof(int))); + checkCUDA(cudaMalloc(&idx, total_ele * sizeof(int))); + + checkCUDA(cudaMalloc(&sorted_idx, total_ele * sizeof(int))); + checkCUDA(cudaMalloc(&sorted_logits, total_ele * data_type_size(data_type))); + } }; // namespace FlexFlow \ No newline at end of file diff --git a/src/ops/sampling_ref.cu b/src/ops/sampling_ref.cu new file mode 100644 index 0000000000..7918d4e76a --- /dev/null +++ b/src/ops/sampling_ref.cu @@ -0,0 +1,316 @@ +// /* Copyright 2023 CMU, Facebook, LANL, MIT, NVIDIA, and Stanford (alphabetical) +// * +// * Licensed under the Apache License, Version 2.0 (the "License"); +// * you may not use this file except in compliance with the License. +// * You may obtain a copy of the License at +// * +// * http://www.apache.org/licenses/LICENSE-2.0 +// * +// * Unless required by applicable law or agreed to in writing, software +// * distributed under the License is distributed on an "AS IS" BASIS, +// * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// * See the License for the specific language governing permissions and +// * limitations under the License. +// */ + +// #include "flexflow/ops/sampling.h" +// #include "flexflow/utils/cuda_helper.h" +// #include +// #include +// #include "cub/cub.cuh" + +// namespace FlexFlow { + +// template +// __global__ void mask_value_above_top_p(DT *input_ptr, +// DT *cumsum_ptr, +// float top_p, +// int total_eles) { +// CUDA_KERNEL_LOOP(i, total_eles) { +// if ((cumsum_ptr[i] - input_ptr[i]) > static_cast
(top_p)) { +// input_ptr[i] = 0.0; +// } +// } +// } + +// template +// __global__ void re_normalized(DT *input_ptr, DT div, int length) { +// CUDA_KERNEL_LOOP(i, length) { +// input_ptr[i] /= div; +// } +// } + +// template +// __global__ void sampleMultinomialOnce(long long N, DT *input_ptr) { +// extern __shared__ unsigned char my_smem[]; +// __shared__ bool found; +// __shared__ unsigned foundPos; + +// float *smem = reinterpret_cast(my_smem); + +// float accZero = static_cast(0); +// DT zero = static_cast
(0); + +// for (int64_t curDist = blockIdx.x; curDist < distributions; +// curDist += gridDim.x) { + +// float sum = accZero; +// DT val; + +// for (int cat = threadIdx.x; cat < N; cat += blockDim.x) { +// val = dist[curDist * stride_dist + cat * stride_categories]; +// CUDA_KERNEL_ASSERT(!at::_isnan(val)); +// CUDA_KERNEL_ASSERT(!_isinf(val)); +// CUDA_KERNEL_ASSERT(!(val < zero)); +// sum = sum + static_cast(val); +// } + + +// //sum +// sum = BlockReduceSum(sum, smem); + +// if (threadIdx.x == 0) { +// foundPos = 0; +// smem[0] = sum; +// smem[1] = sampled[curDist]; +// } + +// __syncthreads(); +// sum = smem[0]; + +// DT sample = static_cast
(smem[1]); +// __syncthreads(); + +// if (sum == accZero) { +// // Choose the first element +// if (threadIdx.x == 0) { +// dest[curDist] = 0; +// } + +// continue; +// } + +// //ELSE +// int chunks = (categories + (int)blockDim.x - 1) / blockDim.x; +// float prevHighProb = accZero; + +// found = false; +// for (int chunk = 0; chunk < chunks && !found; ++chunk) { + +// int cat = chunk * blockDim.x + threadIdx.x; +// float dist_val = cat < categories ? +// static_cast(dist[curDist * stride_dist + cat * stride_categories]) / sum : +// accZero; + +// smem[threadIdx.x] = dist_val; +// __syncthreads(); + +// // Perform an inclusive prefix sum of the shared memory contents +// for (int offset = 1; offset < blockDim.x; offset *= 2) { +// float val = accZero; + +// if (threadIdx.x >= offset) { +// val = smem[threadIdx.x - offset] + smem[threadIdx.x]; +// } + +// __syncthreads(); +// if (threadIdx.x >= offset) { +// smem[threadIdx.x] = val; +// } +// __syncthreads(); +// } + +// // Each thread will check to see if the sample falls in its +// // bucket +// DT curBucket = +// static_cast
(smem[threadIdx.x] + prevHighProb); +// DT prevBucket = static_cast
( +// threadIdx.x == 0 ? prevHighProb +// : smem[threadIdx.x - 1] + prevHighProb); +// bool inBucket = +// (cat < categories) && +// (!(sample >= curBucket) && +// (sample >= prevBucket) && +// (dist_val > zero)); + +// if (inBucket) { +// // We're done; we have the sample +// // Torch indices are 1-based +// atomicMax(&foundPos, cat); +// found = true; +// } + +// // Store the previous scan's high value for future use +// prevHighProb = prevHighProb + smem[blockDim.x - 1]; +// __syncthreads(); +// } + +// if (threadIdx.x == 0) { +// if (found) { +// dest[curDist] = foundPos; +// } else { +// // This should address a rare bug where we don't select a valid index. This likely occurs when +// // due to floating point arithmetic rounding errors, our cumulative sum does not add up to 1, but +// // and our uniform sample is greater than this value. In this case we likely have unitialized memory +// // in dest[curDist]. So basically we will loop through the distribution and pick the largest index +// // where the distribution is non-zero. This is obviously terribly inefficient, but due to the +// // rarity in which this occurs, this should not be an issue. +// for (int cat = categories - 1; cat >= 0; --cat) { +// if (dist[curDist * stride_dist + cat * stride_categories] > zero) { +// dest[curDist] = cat; +// break; +// } +// } +// } +// } + + +// } +// } + + +// /*static*/ +// template +// void Sampling::forward_kernel(SamplingMeta const *m, +// DT *input_ptr, +// int *indices_ptr, +// float top_p, +// int length, +// int batch_size, +// cudaStream_t stream) { +// // 1. sort +// // 2. cumsum +// // how to do it in parallel? + +// checkCUDA(cudaMemcpy(static_cast
(m->origin_ptr), +// input_ptr, +// sizeof(DT) * 15 * length, +// cudaMemcpyDeviceToDevice)); + +// std::cout << "asdqs: " << length << "\n"; + +// for (int i = 0; i < 15; i++) { +// thrust::sort(thrust::device, +// input_ptr + i * length, +// input_ptr + (i + 1) * length, +// thrust::greater
()); +// thrust::sort(thrust::device, +// static_cast
(m->origin_ptr) + i * length, +// static_cast
(m->origin_ptr) + (i + 1) * length, +// thrust::greater
()); +// thrust::inclusive_scan(thrust::device, +// input_ptr + i * length, +// input_ptr + (i + 1) * length, +// static_cast
(m->cumsum_ptr) + i * length); +// } +// std::cout << "sdsd" +// << "\n"; + +// // 3. mask +// int parallelism = 15 * length; +// mask_value_above_top_p
<<>>( +// input_ptr, static_cast
(m->cumsum_ptr), top_p, parallelism); + +// // 4. sum/div +// std::cout << "sadsd2www" +// << "\n"; +// for (int i = 0; i < 15; i++) { +// DT sum = thrust::reduce( +// thrust::device, input_ptr + i * length, input_ptr + (i + 1) * length); +// parallelism = length; + +// re_normalized
<<>>(input_ptr + i * length, sum, length); +// } +// std::cout << "sdds332" +// << "\n"; + +// // 5.multinominal +// for (int i = 0; i < 15; i++) { +// parallelism = length; +// DT random = static_cast
(((float)std::rand()) / RAND_MAX); +// thrust::inclusive_scan(thrust::device, +// input_ptr + i * length, +// input_ptr + (i + 1) * length, +// static_cast
(m->cumsum_ptr) + i * length); + +// // find_idx
<<>>(static_cast
(m->cumsum_ptr) + i * length, +// // static_cast
(m->origin_ptr) + i * length, +// // random, +// // length, +// // indices_ptr, +// // i); +// for (int j = 0; j < length; j++) { +// if ((static_cast
(m->cumsum_ptr) + i * length)[j] >= random) { +// indices_ptr[i] = (static_cast
(m->origin_ptr) + i * length)[i]; +// printf("k value is:%d. %f\n", i, indices_ptr[i]); +// break; +// } +// } +// } +// // print_tensor((int *)indices_ptr, 15, "sdsdasd"); +// assert(false); +// } + +// /*static*/ +// void Sampling::forward_kernel_wrapper(SamplingMeta const *m, +// GenericTensorAccessorW const &input, +// GenericTensorAccessorW const &indices) { +// cudaStream_t stream; +// checkCUDA(get_legion_stream(&stream)); + +// cudaEvent_t t_start, t_end; +// if (m->profiling) { +// cudaEventCreate(&t_start); +// cudaEventCreate(&t_end); +// cudaEventRecord(t_start, stream); +// } +// int length = input.domain.hi()[0] - input.domain.lo()[0] + 1; +// int batch_size = input.domain.get_volume() / length; + +// if (input.data_type == DT_HALF) { +// Sampling::forward_kernel(m, +// input.get_half_ptr(), +// indices.get_int32_ptr(), +// m->top_p, +// length, +// batch_size, +// stream); +// } else if (input.data_type == DT_FLOAT) { +// Sampling::forward_kernel(m, +// input.get_float_ptr(), +// indices.get_int32_ptr(), +// m->top_p, +// length, +// batch_size, +// stream); +// } else { +// assert(false && "Unsupported data type"); +// } + +// if (m->profiling) { +// cudaEventRecord(t_end, stream); +// checkCUDA(cudaEventSynchronize(t_end)); +// float elapsed = 0; +// checkCUDA(cudaEventElapsedTime(&elapsed, t_start, t_end)); +// cudaEventDestroy(t_start); +// cudaEventDestroy(t_end); +// printf("[Sampling] forward time = %.2lfms\n", elapsed); +// } +// } + +// SamplingMeta::SamplingMeta(FFHandler handler, Op const *op) +// : OpMeta(handler, op) { +// checkCUDA(cudaMalloc(&cumsum_ptr, 15 * 32000 * sizeof(float))); +// checkCUDA(cudaMalloc(&sampled, 15 * 32000 * sizeof(float))); +// } + +// }; // namespace FlexFlow \ No newline at end of file From 2cdcac68e9428d68069bb3a448363f5381f1934d Mon Sep 17 00:00:00 2001 From: xinhaoc Date: Tue, 11 Jul 2023 05:51:17 +0000 Subject: [PATCH 03/16] . 
--- include/flexflow/ops/sampling.h | 3 + inference/models/llama.cc | 7 +- src/ops/kernels/rms_norm_kernels.cu | 1 + src/ops/sampling.cu | 217 ++++++++++++++++++++++++++-- 4 files changed, 209 insertions(+), 19 deletions(-) diff --git a/include/flexflow/ops/sampling.h b/include/flexflow/ops/sampling.h index 50830ce4a6..fbc4e4bf37 100644 --- a/include/flexflow/ops/sampling.h +++ b/include/flexflow/ops/sampling.h @@ -5,6 +5,8 @@ #include "flexflow/model.h" #include "flexflow/node.h" #include "flexflow/ops/sampling_params.h" +#include +#include namespace FlexFlow { @@ -17,6 +19,7 @@ class SamplingMeta : public OpMeta { int *end_offset; int *idx; void *d_temp_storage; + curandState *state; SamplingMeta(FFHandler handle, Op const *op, int batch_size, int total_ele); }; diff --git a/inference/models/llama.cc b/inference/models/llama.cc index 3b4ddbd6d3..9e7fe2d892 100644 --- a/inference/models/llama.cc +++ b/inference/models/llama.cc @@ -241,9 +241,10 @@ void LLAMA::create_llama_model(FFModel &ff, Tensor softmax = ff.softmax(dense, -1); output = ff.beam_top_k(softmax, llama_config.max_beam_width, false); } else { - // output = ff.arg_top_k(dense, /*k=*/1, false); - Tensor softmax = ff.softmax(dense, -1); - output = ff.sampling(softmax, 0.95); + output = ff.arg_top_k(dense, /*k=*/1, false); + // dense = ff.scalar_truediv(dense, 0.8, false); + // Tensor softmax = ff.softmax(dense, -1); + // output = ff.sampling(softmax, 0.95); } // Compile the model diff --git a/src/ops/kernels/rms_norm_kernels.cu b/src/ops/kernels/rms_norm_kernels.cu index 44e6288529..042bf83c59 100644 --- a/src/ops/kernels/rms_norm_kernels.cu +++ b/src/ops/kernels/rms_norm_kernels.cu @@ -123,6 +123,7 @@ __global__ void elewise_apply_weights(int64_t batch_size, } template + void forward_kernel(RMSNormMeta const *m, T const *input_ptr, T const *weight_ptr, diff --git a/src/ops/sampling.cu b/src/ops/sampling.cu index 88192f4677..091b987d7b 100644 --- a/src/ops/sampling.cu +++ b/src/ops/sampling.cu @@ -14,12 +14,30 @@ */ #include "cub/cub.cuh" +#include "flexflow/ffconst_utils.h" #include "flexflow/ops/sampling.h" #include "flexflow/utils/cuda_helper.h" -#include "flexflow/ffconst_utils.h" +#include +#include namespace FlexFlow { +struct BlockPrefixCallbackOp { + // Running prefix + float running_total; + // Constructor + __device__ BlockPrefixCallbackOp(float running_total) + : running_total(running_total) {} + // Callback operator to be entered by the first warp of threads in the block. + // Thread-0 is responsible for returning a value for seeding the block-wide + // scan. 
+ __device__ float operator()(float block_aggregate) { + float old_prefix = running_total; + running_total += block_aggregate; + return old_prefix; + } +}; + template __global__ void mask_value_above_top_p(DT *input_ptr, DT *cumsum_ptr, @@ -48,21 +66,166 @@ __global__ void init_idxs(int batch_size, } } +__global__ void + init_random_kernel(curandState *state, int batch_size, long rand) { + CUDA_KERNEL_LOOP(i, batch_size) { + curand_init(rand, i, 0, &state[i]); + } +} + +// multinominal and gather +template +__global__ void sampling_topp_kernel(int batch_size, + int const vocab_size, + curandState *state, + DT *sorted_logits, + int *sorted_idx, + int *indices_ptr, + float topp) { + int const vocab_id = threadIdx.x; + int const batch_idx = blockIdx.x; + __shared__ float random_n; + __shared__ float renormalized_sum; + __shared__ long long result_idx; + + // random num + if (threadIdx.x == 0) { + // number must < topp + random_n = curand_uniform(state + batch_idx) * topp; + printf("batch idx: %d, %f\n", batch_idx, random_n); + } + + __syncthreads(); + + // cumsum; + typedef cub::BlockScan BlockScan; + typedef cub::BlockScan BlockScanMultiNominal; + typedef cub::BlockReduce BlockReduce; + __shared__ typename BlockScan::TempStorage temp_storage; + __shared__ typename BlockReduce::TempStorage reduce_temp_storage; + __shared__ typename BlockScan::TempStorage multinominal_temp_storage; + + int offset = batch_idx * vocab_size; + float prefix_sum = 0.0f; + int end = ((vocab_size + BLOCK_SIZE - 1) / BLOCK_SIZE) * BLOCK_SIZE; + BlockPrefixCallbackOp prefix_op(0); + float sum; + result_idx = vocab_size; + + for (long long j = threadIdx.x; j < vocab_size; j += blockDim.x) { + float logit = (float)sorted_logits[offset + j]; + // float logit = (j < vocab_size) ? (float)sorted_logits[offset + j] : 0.f; + BlockScan(temp_storage).InclusiveSum(logit, prefix_sum, prefix_op); + + prefix_sum /= topp; + + if (prefix_sum >= random_n) { + atomicMin(&result_idx, j); + } + + // if (blockIdx.x == 0 && j == 276){ + // printf("batch idx afterward aaaaaa: %f, %.10f, %.10f, %.10f, %.10f\n", + // topp, prefix_sum, logit, (float)sorted_logits[offset + j], random_n); + // } + // if (blockIdx.x == 1 && j == 39){ + // printf("batch idx afterward aaaaaa11111: %f, %.10f, %.10f, %.10f, + // %.10f\n", topp, prefix_sum, logit, (float)sorted_logits[offset + j], + // random_n); + // } + + // // mask + // sorted_logits[offset + j] = + // (prefix_sum - (float)sorted_logits[offset + j] > topp) + // ? (DT)0 + // : sorted_logits[offset + j]; + + // //get sum and divide + // sum += (float)sorted_logits[offset + j]; + // __syncthreads(); + // if (blockIdx.x == 0 && j > 31990) { + // printf( + // "batch idx afterward after:%d, %.20f, %.20f\n", j, prefix_sum, + // logit); + // } + // if (blockIdx.x == 0 && j > 1022 && j < 1028) { + // printf( + // "batch idx afterward before:%d, %,20f, %.20f\n", j, prefix_sum, + // logit); + // } + } + + indices_ptr[batch_idx] = sorted_idx[offset + result_idx]; + // if meet latency issue, this part can also be removed because the sum is + // very close to topp. 
+ // float temp_sum = BlockReduce(reduce_temp_storage).Sum(sum); + // __syncthreads(); + // if(threadIdx.x == 0){ + // renormalized_sum = temp_sum; + // } + // __syncthreads(); + + // renormalized and multinominal + // result_idx = vocab_size; + // BlockPrefixCallbackOp prefix_op_2(0); + // prefix_sum = 0.0f; + // for (long long j = threadIdx.x; j < vocab_size; j += blockDim.x) { + // float logit = (float)sorted_logits[offset + j] / topp; + // BlockScanMultiNominal(multinominal_temp_storage).InclusiveSum(logit, + // prefix_sum, prefix_op_2); + + // if(prefix_sum >= random_n){ + // atomicMin(&result_idx, j); + // } + + // if (blockIdx.x == 0 && j == 1023){ + // printf("batch idx afterward aaaaaa: %f, %.10f, %.10f, %.10f, %.10f\n", + // topp, prefix_sum, logit, (float)sorted_logits[offset + j], random_n); + // } + // if (blockIdx.x == 1 && j == 39){ + // printf("batch idx afterward aaaaaa11111: %f, %.10f, %.10f, %.10f, + // %.10f\n", topp, prefix_sum, logit, (float)sorted_logits[offset + j], + // random_n); + // } + // } + // indices_ptr[batch_idx] = (int)result_idx; + + // __syncthreads(); + + if (blockIdx.x == 0 && threadIdx.x == 0) { + printf("batch idx afterward aaaaaa: %d\n", result_idx); + } + if (blockIdx.x == 0 && threadIdx.x == 0) { + printf("batch idx afterward aaaaaa0000: %d\n", result_idx); + } + if (blockIdx.x == 1 && threadIdx.x == 1) { + printf("batch idx afterward aaaaaa11111: %d\n", result_idx); + } + + // if (threadIdx.x == 1) { + // printf("batch idx afterward: %d, %f, %f, %d\n", batch_idx, prefix_sum, + // (float)sorted_logits[offset], offset); + // } + + // mask, div + + // select +} + /*static*/ template void Sampling::forward_kernel(SamplingMeta const *m, DT *input_ptr, int *indices_ptr, - float top_p, - int length, - int batch_size, + float const top_p, + int const length, + int const batch_size, cudaStream_t stream) { // 1. sort // 2. cumsum // how to do it in parallel? 
// init - print_tensor((float *)input_ptr+ 32000, 32, "inputttt"); - std::cout<< "meta " << length << ", " << batch_size << "\n"; + print_tensor((float *)input_ptr + 32000, 32, "inputttt"); + std::cout << "meta " << length << ", " << batch_size << "\n"; int parallelism = length * batch_size; init_idxs<<begin_offset, m->end_offset); - checkCUDA(cudaDeviceSynchronize()); + checkCUDA(cudaDeviceSynchronize()); // print_tensor(m->begin_offset, 64, "ofsset"); - // print_tensor(m->end_offset, 64, "ofsset"); + // print_tensor(m->end_offset, 64, "ofsset"); - std::cout<<"-------------------------sampling kernel _--------------------" << "\n"; + std::cout << "-------------------------sampling kernel _--------------------" + << "\n"; // sort size_t temp_storage_bytes = 0; void *d_temp_storage = nullptr; @@ -97,9 +261,9 @@ void Sampling::forward_kernel(SamplingMeta const *m, sizeof(DT) * 8, // end_bit = sizeof(KeyT) * 8 stream); - - checkCUDA(cudaDeviceSynchronize()); - cudaMalloc(&d_temp_storage, temp_storage_bytes); + checkCUDA(cudaDeviceSynchronize()); + cudaMalloc(&d_temp_storage, temp_storage_bytes); + // sort cub::DeviceSegmentedRadixSort::SortPairsDescending( d_temp_storage, temp_storage_bytes, @@ -114,10 +278,31 @@ void Sampling::forward_kernel(SamplingMeta const *m, 0, // begin_bit sizeof(DT) * 8, // end_bit = sizeof(KeyT) * 8 stream); - print_tensor((float *)m->sorted_logits + 32000, 32, "after sort"); - print_tensor(m->sorted_idx+ 32000, 32, "after sort"); + print_tensor((float *)m->sorted_logits, 32, "after sort 0"); + // print_tensor((float *)m->sorted_logits + 31990, 32, "after sort 1"); // print_tensor((int *)indices_ptr, 15, "sdsdasd"); - assert(false); + + // random + + parallelism = batch_size; + init_random_kernel<<>>(m->state, batch_size, rand()); + sampling_topp_kernel + <<>>(batch_size, + length, + m->state, + static_cast
(m->sorted_logits), + m->sorted_idx, + indices_ptr, + 0.95f); + + checkCUDA(cudaDeviceSynchronize()); + // print_tensor((float *)m->sorted_logits + 32000, 32, "after sort"); + // topk / topp mask some value and renormalize + + // sampling } /*static*/ @@ -179,7 +364,7 @@ SamplingMeta::SamplingMeta(FFHandler handler, checkCUDA(cudaMalloc(&sorted_idx, total_ele * sizeof(int))); checkCUDA(cudaMalloc(&sorted_logits, total_ele * data_type_size(data_type))); - + cudaMalloc(&state, sizeof(curandState) * batch_size); } }; // namespace FlexFlow \ No newline at end of file From b00ec992cff2b5e7629b5ec99d2602f4c1be3447 Mon Sep 17 00:00:00 2001 From: xinhaoc Date: Tue, 11 Jul 2023 05:56:24 +0000 Subject: [PATCH 04/16] del --- src/ops/sampling_ref.cu | 316 ---------------------------------------- 1 file changed, 316 deletions(-) delete mode 100644 src/ops/sampling_ref.cu diff --git a/src/ops/sampling_ref.cu b/src/ops/sampling_ref.cu deleted file mode 100644 index 7918d4e76a..0000000000 --- a/src/ops/sampling_ref.cu +++ /dev/null @@ -1,316 +0,0 @@ -// /* Copyright 2023 CMU, Facebook, LANL, MIT, NVIDIA, and Stanford (alphabetical) -// * -// * Licensed under the Apache License, Version 2.0 (the "License"); -// * you may not use this file except in compliance with the License. -// * You may obtain a copy of the License at -// * -// * http://www.apache.org/licenses/LICENSE-2.0 -// * -// * Unless required by applicable law or agreed to in writing, software -// * distributed under the License is distributed on an "AS IS" BASIS, -// * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// * See the License for the specific language governing permissions and -// * limitations under the License. -// */ - -// #include "flexflow/ops/sampling.h" -// #include "flexflow/utils/cuda_helper.h" -// #include -// #include -// #include "cub/cub.cuh" - -// namespace FlexFlow { - -// template -// __global__ void mask_value_above_top_p(DT *input_ptr, -// DT *cumsum_ptr, -// float top_p, -// int total_eles) { -// CUDA_KERNEL_LOOP(i, total_eles) { -// if ((cumsum_ptr[i] - input_ptr[i]) > static_cast
(top_p)) { -// input_ptr[i] = 0.0; -// } -// } -// } - -// template -// __global__ void re_normalized(DT *input_ptr, DT div, int length) { -// CUDA_KERNEL_LOOP(i, length) { -// input_ptr[i] /= div; -// } -// } - -// template -// __global__ void sampleMultinomialOnce(long long N, DT *input_ptr) { -// extern __shared__ unsigned char my_smem[]; -// __shared__ bool found; -// __shared__ unsigned foundPos; - -// float *smem = reinterpret_cast(my_smem); - -// float accZero = static_cast(0); -// DT zero = static_cast
(0); - -// for (int64_t curDist = blockIdx.x; curDist < distributions; -// curDist += gridDim.x) { - -// float sum = accZero; -// DT val; - -// for (int cat = threadIdx.x; cat < N; cat += blockDim.x) { -// val = dist[curDist * stride_dist + cat * stride_categories]; -// CUDA_KERNEL_ASSERT(!at::_isnan(val)); -// CUDA_KERNEL_ASSERT(!_isinf(val)); -// CUDA_KERNEL_ASSERT(!(val < zero)); -// sum = sum + static_cast(val); -// } - - -// //sum -// sum = BlockReduceSum(sum, smem); - -// if (threadIdx.x == 0) { -// foundPos = 0; -// smem[0] = sum; -// smem[1] = sampled[curDist]; -// } - -// __syncthreads(); -// sum = smem[0]; - -// DT sample = static_cast
(smem[1]); -// __syncthreads(); - -// if (sum == accZero) { -// // Choose the first element -// if (threadIdx.x == 0) { -// dest[curDist] = 0; -// } - -// continue; -// } - -// //ELSE -// int chunks = (categories + (int)blockDim.x - 1) / blockDim.x; -// float prevHighProb = accZero; - -// found = false; -// for (int chunk = 0; chunk < chunks && !found; ++chunk) { - -// int cat = chunk * blockDim.x + threadIdx.x; -// float dist_val = cat < categories ? -// static_cast(dist[curDist * stride_dist + cat * stride_categories]) / sum : -// accZero; - -// smem[threadIdx.x] = dist_val; -// __syncthreads(); - -// // Perform an inclusive prefix sum of the shared memory contents -// for (int offset = 1; offset < blockDim.x; offset *= 2) { -// float val = accZero; - -// if (threadIdx.x >= offset) { -// val = smem[threadIdx.x - offset] + smem[threadIdx.x]; -// } - -// __syncthreads(); -// if (threadIdx.x >= offset) { -// smem[threadIdx.x] = val; -// } -// __syncthreads(); -// } - -// // Each thread will check to see if the sample falls in its -// // bucket -// DT curBucket = -// static_cast
(smem[threadIdx.x] + prevHighProb); -// DT prevBucket = static_cast
( -// threadIdx.x == 0 ? prevHighProb -// : smem[threadIdx.x - 1] + prevHighProb); -// bool inBucket = -// (cat < categories) && -// (!(sample >= curBucket) && -// (sample >= prevBucket) && -// (dist_val > zero)); - -// if (inBucket) { -// // We're done; we have the sample -// // Torch indices are 1-based -// atomicMax(&foundPos, cat); -// found = true; -// } - -// // Store the previous scan's high value for future use -// prevHighProb = prevHighProb + smem[blockDim.x - 1]; -// __syncthreads(); -// } - -// if (threadIdx.x == 0) { -// if (found) { -// dest[curDist] = foundPos; -// } else { -// // This should address a rare bug where we don't select a valid index. This likely occurs when -// // due to floating point arithmetic rounding errors, our cumulative sum does not add up to 1, but -// // and our uniform sample is greater than this value. In this case we likely have unitialized memory -// // in dest[curDist]. So basically we will loop through the distribution and pick the largest index -// // where the distribution is non-zero. This is obviously terribly inefficient, but due to the -// // rarity in which this occurs, this should not be an issue. -// for (int cat = categories - 1; cat >= 0; --cat) { -// if (dist[curDist * stride_dist + cat * stride_categories] > zero) { -// dest[curDist] = cat; -// break; -// } -// } -// } -// } - - -// } -// } - - -// /*static*/ -// template -// void Sampling::forward_kernel(SamplingMeta const *m, -// DT *input_ptr, -// int *indices_ptr, -// float top_p, -// int length, -// int batch_size, -// cudaStream_t stream) { -// // 1. sort -// // 2. cumsum -// // how to do it in parallel? - -// checkCUDA(cudaMemcpy(static_cast
(m->origin_ptr), -// input_ptr, -// sizeof(DT) * 15 * length, -// cudaMemcpyDeviceToDevice)); - -// std::cout << "asdqs: " << length << "\n"; - -// for (int i = 0; i < 15; i++) { -// thrust::sort(thrust::device, -// input_ptr + i * length, -// input_ptr + (i + 1) * length, -// thrust::greater
<DT>());
-//     thrust::sort(thrust::device,
-//                  static_cast<DT *>(m->origin_ptr) + i * length,
-//                  static_cast<DT *>(m->origin_ptr) + (i + 1) * length,
-//                  thrust::greater<DT>());
-//     thrust::inclusive_scan(thrust::device,
-//                            input_ptr + i * length,
-//                            input_ptr + (i + 1) * length,
-//                            static_cast<DT *>
(m->cumsum_ptr) + i * length); -// } -// std::cout << "sdsd" -// << "\n"; - -// // 3. mask -// int parallelism = 15 * length; -// mask_value_above_top_p
<<>>( -// input_ptr, static_cast
(m->cumsum_ptr), top_p, parallelism); - -// // 4. sum/div -// std::cout << "sadsd2www" -// << "\n"; -// for (int i = 0; i < 15; i++) { -// DT sum = thrust::reduce( -// thrust::device, input_ptr + i * length, input_ptr + (i + 1) * length); -// parallelism = length; - -// re_normalized
<<>>(input_ptr + i * length, sum, length); -// } -// std::cout << "sdds332" -// << "\n"; - -// // 5.multinominal -// for (int i = 0; i < 15; i++) { -// parallelism = length; -// DT random = static_cast
(((float)std::rand()) / RAND_MAX); -// thrust::inclusive_scan(thrust::device, -// input_ptr + i * length, -// input_ptr + (i + 1) * length, -// static_cast
(m->cumsum_ptr) + i * length); - -// // find_idx
<<>>(static_cast
(m->cumsum_ptr) + i * length, -// // static_cast
(m->origin_ptr) + i * length, -// // random, -// // length, -// // indices_ptr, -// // i); -// for (int j = 0; j < length; j++) { -// if ((static_cast
(m->cumsum_ptr) + i * length)[j] >= random) { -// indices_ptr[i] = (static_cast
(m->origin_ptr) + i * length)[i]; -// printf("k value is:%d. %f\n", i, indices_ptr[i]); -// break; -// } -// } -// } -// // print_tensor((int *)indices_ptr, 15, "sdsdasd"); -// assert(false); -// } - -// /*static*/ -// void Sampling::forward_kernel_wrapper(SamplingMeta const *m, -// GenericTensorAccessorW const &input, -// GenericTensorAccessorW const &indices) { -// cudaStream_t stream; -// checkCUDA(get_legion_stream(&stream)); - -// cudaEvent_t t_start, t_end; -// if (m->profiling) { -// cudaEventCreate(&t_start); -// cudaEventCreate(&t_end); -// cudaEventRecord(t_start, stream); -// } -// int length = input.domain.hi()[0] - input.domain.lo()[0] + 1; -// int batch_size = input.domain.get_volume() / length; - -// if (input.data_type == DT_HALF) { -// Sampling::forward_kernel(m, -// input.get_half_ptr(), -// indices.get_int32_ptr(), -// m->top_p, -// length, -// batch_size, -// stream); -// } else if (input.data_type == DT_FLOAT) { -// Sampling::forward_kernel(m, -// input.get_float_ptr(), -// indices.get_int32_ptr(), -// m->top_p, -// length, -// batch_size, -// stream); -// } else { -// assert(false && "Unsupported data type"); -// } - -// if (m->profiling) { -// cudaEventRecord(t_end, stream); -// checkCUDA(cudaEventSynchronize(t_end)); -// float elapsed = 0; -// checkCUDA(cudaEventElapsedTime(&elapsed, t_start, t_end)); -// cudaEventDestroy(t_start); -// cudaEventDestroy(t_end); -// printf("[Sampling] forward time = %.2lfms\n", elapsed); -// } -// } - -// SamplingMeta::SamplingMeta(FFHandler handler, Op const *op) -// : OpMeta(handler, op) { -// checkCUDA(cudaMalloc(&cumsum_ptr, 15 * 32000 * sizeof(float))); -// checkCUDA(cudaMalloc(&sampled, 15 * 32000 * sizeof(float))); -// } - -// }; // namespace FlexFlow \ No newline at end of file From b9b57f7343fe9d1b5ee537c1b5e024ee24b38db8 Mon Sep 17 00:00:00 2001 From: xinhaoc Date: Wed, 12 Jul 2023 02:44:04 +0000 Subject: [PATCH 05/16] . --- src/runtime/model.cc | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/runtime/model.cc b/src/runtime/model.cc index 7f7d65e7ae..99fbe2ede3 100644 --- a/src/runtime/model.cc +++ b/src/runtime/model.cc @@ -52,11 +52,11 @@ #include "flexflow/ops/reshape.h" #include "flexflow/ops/reverse.h" #include "flexflow/ops/rms_norm.h" +#include "flexflow/ops/sampling.h" #include "flexflow/ops/softmax.h" #include "flexflow/ops/spec_inc_multihead_self_attention.h" #include "flexflow/ops/split.h" #include "flexflow/ops/topk.h" -#include "flexflow/ops/sampling.h" #include "flexflow/ops/transpose.h" #include "flexflow/ops/tree_inc_multihead_self_attention.h" #include "flexflow/parallel_ops/allreduce.h" @@ -4764,7 +4764,8 @@ void register_flexflow_internal_tasks() { TaskVariantRegistrar registrar(SAMPLING_INF_TASK_ID, "Sampling Inference"); registrar.add_constraint(ProcessorConstraint(Processor::TOC_PROC)); registrar.set_leaf(); - Runtime::preregister_task_variant( + Runtime::preregister_task_variant( registrar, "Sampling Inference Task"); } // Transpose task From 800b1cd5d63a69f7ff3a8a35089d3e53568bf0b6 Mon Sep 17 00:00:00 2001 From: xinhaoc Date: Thu, 13 Jul 2023 07:08:04 +0000 Subject: [PATCH 06/16] finish impl. 
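
This patch wires up top-p (nucleus) sampling end to end: the model emits
softmax probabilities, each request's row is sorted in descending order with
cub::DeviceSegmentedRadixSort, and sampling_topp_kernel uses a block-wide
cumulative sum (cub::BlockScan) plus a per-request curand draw taken inside
the top-p mass to pick the winning token, which avoids an explicit
renormalization pass. As a reference for the intended selection rule only
(this host-side helper and its name are illustrative and not part of the
patch), the same behavior can be written as:

    #include <algorithm>
    #include <numeric>
    #include <random>
    #include <vector>

    // Reference top-p selection for one probability vector `probs`
    // (already softmax-normalized): sort descending, accumulate, and return
    // the original index of the first token whose cumulative probability
    // reaches a uniform draw in [0, top_p).
    int sample_top_p_reference(std::vector<float> const &probs,
                               float top_p,
                               std::mt19937 &rng) {
      std::vector<int> order(probs.size());
      std::iota(order.begin(), order.end(), 0);
      std::sort(order.begin(), order.end(),
                [&](int a, int b) { return probs[a] > probs[b]; });

      float const threshold =
          std::uniform_real_distribution<float>(0.0f, top_p)(rng);

      float cumulative = 0.0f;
      for (int idx : order) {
        cumulative += probs[idx];
        if (cumulative >= threshold) {
          return idx; // index into the original (unsorted) vocabulary
        }
      }
      return order.back(); // guard against floating-point shortfall
    }

The GPU path does the same thing batched: one segmented sort covers all
requests, one thread block per request computes the prefix sum, and the
selected position is mapped back to a vocabulary id through sorted_idx.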
--- include/flexflow/inference.h | 12 ++ include/flexflow/operator_params.h | 1 + include/flexflow/ops/sampling.h | 3 +- inference/incr_decoding/incr_decoding.cc | 25 +++ inference/models/llama.cc | 12 +- inference/models/llama.h | 3 +- inference/spec_infer/spec_infer.cc | 3 + src/ops/fused.cu | 5 +- src/ops/sampling.cc | 6 +- src/ops/sampling.cu | 234 ++++++++--------------- src/runtime/model.cc | 2 +- src/runtime/operator_params.cc | 3 + 12 files changed, 139 insertions(+), 170 deletions(-) diff --git a/include/flexflow/inference.h b/include/flexflow/inference.h index a1846c96dc..192573fffa 100644 --- a/include/flexflow/inference.h +++ b/include/flexflow/inference.h @@ -65,6 +65,18 @@ struct BeamTree { treeLayer treeLayers[BeamSearchBatchConfig::MAX_BEAM_DEPTH + 1]; }; +struct GenerationConfig { + bool do_sample = false; + float temperature = 0.8; + float topp = 0.6; + GenerationConfig(bool _do_sample, float _temperature, float _topp) { + temperature = _temperature > 0 ? _temperature : temperature; + topp = _topp > 0 ? _topp : topp; + do_sample = _do_sample; + } + GenerationConfig() {} +}; + // struct BeamTree_v2 { // std::vector tokens; // std::vector parent_ids; diff --git a/include/flexflow/operator_params.h b/include/flexflow/operator_params.h index 96a62e0bd4..4386a09d20 100644 --- a/include/flexflow/operator_params.h +++ b/include/flexflow/operator_params.h @@ -72,6 +72,7 @@ using OperatorParameters = mp::variantop_num_inputs[op] == 1); assert(fused->op_num_weights[op] == 0); assert(fused->op_num_outputs[op] == 1); @@ -1268,7 +1269,7 @@ __host__ void FusedOp::backward_task(Task const *task, case OP_RELU: case OP_SIGMOID: case OP_TANH: - case OP_ELU: { + case OP_ELU:{ assert(fused->op_num_inputs[op] == 1); assert(fused->op_num_weights[op] == 0); assert(fused->op_num_outputs[op] == 1); diff --git a/src/ops/sampling.cc b/src/ops/sampling.cc index bb63c6ff90..02c34aa719 100644 --- a/src/ops/sampling.cc +++ b/src/ops/sampling.cc @@ -92,7 +92,6 @@ SamplingParams Sampling::get_params() const { } bool SamplingParams::is_valid(ParallelTensorShape const &) const { - // topk is always valid return true; } @@ -119,6 +118,7 @@ Sampling::Sampling(FFModel &model, dims[i] = inputs[0]->dims[i]; } dims[0].size = 1; + std::cout << "degree: " << inputs[0]->dims[0].degree << "\n"; assert(inputs[0]->dims[0].degree == 1); assert(inputs[0]->dims[0].parallel_idx == -1); // outputs[0] = model.create_parallel_tensor_legion_ordering( @@ -231,7 +231,7 @@ OpMeta *Sampling::init_task(Task const *task, int length = acc_input.domain.hi()[0] - acc_input.domain.lo()[0] + 1; int batch_size = acc_input.domain.get_volume() / length; - SamplingMeta *m = new SamplingMeta(handle, s, batch_size, length * batch_size); + SamplingMeta *m = new SamplingMeta(handle, s, batch_size, length * batch_size, acc_input); m->profiling = s->profiling; m->top_p = s->top_p; return m; @@ -292,7 +292,7 @@ InferenceResult Runtime *runtime) { assert(regions.size() == 2); assert(task->regions.size() == 2); - // const Sampling* topk = (const Sampling*) task->args; + const Sampling* sampling = (const Sampling*) task->args; SamplingMeta const *m = *((SamplingMeta **)task->local_args); GenericTensorAccessorW input = helperGetGenericTensorAccessorRW( diff --git a/src/ops/sampling.cu b/src/ops/sampling.cu index 091b987d7b..2637d71857 100644 --- a/src/ops/sampling.cu +++ b/src/ops/sampling.cu @@ -22,6 +22,7 @@ namespace FlexFlow { +constexpr int SamplingNumThreads = 1024; struct BlockPrefixCallbackOp { // Running prefix float running_total; @@ -38,18 
+39,6 @@ struct BlockPrefixCallbackOp { } }; -template -__global__ void mask_value_above_top_p(DT *input_ptr, - DT *cumsum_ptr, - float top_p, - int total_eles) { - CUDA_KERNEL_LOOP(i, total_eles) { - if ((cumsum_ptr[i] - input_ptr[i]) > static_cast
(top_p)) { - input_ptr[i] = 0.0; - } - } -} - __global__ void init_idxs(int batch_size, int vocab_size, int total_eles, @@ -82,133 +71,48 @@ __global__ void sampling_topp_kernel(int batch_size, int *sorted_idx, int *indices_ptr, float topp) { - int const vocab_id = threadIdx.x; + // int const vocab_id = threadIdx.x; int const batch_idx = blockIdx.x; __shared__ float random_n; - __shared__ float renormalized_sum; + // __shared__ float renormalized_sum; __shared__ long long result_idx; // random num if (threadIdx.x == 0) { // number must < topp random_n = curand_uniform(state + batch_idx) * topp; - printf("batch idx: %d, %f\n", batch_idx, random_n); + if(blockIdx.x == 0){ + printf("batch idx: %d, %f\n", batch_idx, random_n); + } + } __syncthreads(); // cumsum; typedef cub::BlockScan BlockScan; - typedef cub::BlockScan BlockScanMultiNominal; - typedef cub::BlockReduce BlockReduce; __shared__ typename BlockScan::TempStorage temp_storage; - __shared__ typename BlockReduce::TempStorage reduce_temp_storage; - __shared__ typename BlockScan::TempStorage multinominal_temp_storage; int offset = batch_idx * vocab_size; float prefix_sum = 0.0f; - int end = ((vocab_size + BLOCK_SIZE - 1) / BLOCK_SIZE) * BLOCK_SIZE; BlockPrefixCallbackOp prefix_op(0); - float sum; - result_idx = vocab_size; + // float sum; + result_idx = vocab_size - 1; for (long long j = threadIdx.x; j < vocab_size; j += blockDim.x) { float logit = (float)sorted_logits[offset + j]; - // float logit = (j < vocab_size) ? (float)sorted_logits[offset + j] : 0.f; BlockScan(temp_storage).InclusiveSum(logit, prefix_sum, prefix_op); - prefix_sum /= topp; - if (prefix_sum >= random_n) { atomicMin(&result_idx, j); } - - // if (blockIdx.x == 0 && j == 276){ - // printf("batch idx afterward aaaaaa: %f, %.10f, %.10f, %.10f, %.10f\n", - // topp, prefix_sum, logit, (float)sorted_logits[offset + j], random_n); - // } - // if (blockIdx.x == 1 && j == 39){ - // printf("batch idx afterward aaaaaa11111: %f, %.10f, %.10f, %.10f, - // %.10f\n", topp, prefix_sum, logit, (float)sorted_logits[offset + j], - // random_n); - // } - - // // mask - // sorted_logits[offset + j] = - // (prefix_sum - (float)sorted_logits[offset + j] > topp) - // ? (DT)0 - // : sorted_logits[offset + j]; - - // //get sum and divide - // sum += (float)sorted_logits[offset + j]; - // __syncthreads(); - // if (blockIdx.x == 0 && j > 31990) { - // printf( - // "batch idx afterward after:%d, %.20f, %.20f\n", j, prefix_sum, - // logit); - // } - // if (blockIdx.x == 0 && j > 1022 && j < 1028) { - // printf( - // "batch idx afterward before:%d, %,20f, %.20f\n", j, prefix_sum, - // logit); - // } } - indices_ptr[batch_idx] = sorted_idx[offset + result_idx]; - // if meet latency issue, this part can also be removed because the sum is - // very close to topp. 
- // float temp_sum = BlockReduce(reduce_temp_storage).Sum(sum); - // __syncthreads(); - // if(threadIdx.x == 0){ - // renormalized_sum = temp_sum; - // } - // __syncthreads(); - - // renormalized and multinominal - // result_idx = vocab_size; - // BlockPrefixCallbackOp prefix_op_2(0); - // prefix_sum = 0.0f; - // for (long long j = threadIdx.x; j < vocab_size; j += blockDim.x) { - // float logit = (float)sorted_logits[offset + j] / topp; - // BlockScanMultiNominal(multinominal_temp_storage).InclusiveSum(logit, - // prefix_sum, prefix_op_2); - // if(prefix_sum >= random_n){ - // atomicMin(&result_idx, j); - // } - - // if (blockIdx.x == 0 && j == 1023){ - // printf("batch idx afterward aaaaaa: %f, %.10f, %.10f, %.10f, %.10f\n", - // topp, prefix_sum, logit, (float)sorted_logits[offset + j], random_n); - // } - // if (blockIdx.x == 1 && j == 39){ - // printf("batch idx afterward aaaaaa11111: %f, %.10f, %.10f, %.10f, - // %.10f\n", topp, prefix_sum, logit, (float)sorted_logits[offset + j], - // random_n); - // } + // if (blockIdx.x == 0 && threadIdx.x == 0) { + // printf("batch idx afterward aaaaaa: %d\n", result_idx); // } - // indices_ptr[batch_idx] = (int)result_idx; - - // __syncthreads(); - if (blockIdx.x == 0 && threadIdx.x == 0) { - printf("batch idx afterward aaaaaa: %d\n", result_idx); - } - if (blockIdx.x == 0 && threadIdx.x == 0) { - printf("batch idx afterward aaaaaa0000: %d\n", result_idx); - } - if (blockIdx.x == 1 && threadIdx.x == 1) { - printf("batch idx afterward aaaaaa11111: %d\n", result_idx); - } - - // if (threadIdx.x == 1) { - // printf("batch idx afterward: %d, %f, %f, %d\n", batch_idx, prefix_sum, - // (float)sorted_logits[offset], offset); - // } - - // mask, div - - // select } /*static*/ @@ -222,53 +126,15 @@ void Sampling::forward_kernel(SamplingMeta const *m, cudaStream_t stream) { // 1. sort // 2. cumsum - // how to do it in parallel? - // init - print_tensor((float *)input_ptr + 32000, 32, "inputttt"); - std::cout << "meta " << length << ", " << batch_size << "\n"; - int parallelism = length * batch_size; - init_idxs<<>>(batch_size, - length, - length * batch_size, - m->idx, - m->begin_offset, - m->end_offset); - - checkCUDA(cudaDeviceSynchronize()); - // print_tensor(m->begin_offset, 64, "ofsset"); - // print_tensor(m->end_offset, 64, "ofsset"); - std::cout << "-------------------------sampling kernel _--------------------" << "\n"; - // sort - size_t temp_storage_bytes = 0; - void *d_temp_storage = nullptr; + + size_t temp_storage_bytes = m->temp_storage_bytes; cub::DeviceSegmentedRadixSort::SortPairsDescending( m->d_temp_storage, temp_storage_bytes, input_ptr, - static_cast
(m->sorted_logits), - m->idx, - m->sorted_idx, - length * batch_size, - batch_size, - m->begin_offset, - m->end_offset + 1, - 0, // begin_bit - sizeof(DT) * 8, // end_bit = sizeof(KeyT) * 8 - stream); - - checkCUDA(cudaDeviceSynchronize()); - cudaMalloc(&d_temp_storage, temp_storage_bytes); - // sort - cub::DeviceSegmentedRadixSort::SortPairsDescending( - d_temp_storage, - temp_storage_bytes, input_ptr, - static_cast
(m->sorted_logits), m->idx, m->sorted_idx, length * batch_size, @@ -278,25 +144,19 @@ void Sampling::forward_kernel(SamplingMeta const *m, 0, // begin_bit sizeof(DT) * 8, // end_bit = sizeof(KeyT) * 8 stream); - print_tensor((float *)m->sorted_logits, 32, "after sort 0"); - // print_tensor((float *)m->sorted_logits + 31990, 32, "after sort 1"); - // print_tensor((int *)indices_ptr, 15, "sdsdasd"); - - // random - - parallelism = batch_size; + int parallelism = batch_size; init_random_kernel<<>>(m->state, batch_size, rand()); - sampling_topp_kernel - <<>>(batch_size, + sampling_topp_kernel + <<>>(batch_size, length, m->state, - static_cast
(m->sorted_logits), + input_ptr, m->sorted_idx, indices_ptr, - 0.95f); + top_p); checkCUDA(cudaDeviceSynchronize()); // print_tensor((float *)m->sorted_logits + 32000, 32, "after sort"); @@ -355,7 +215,8 @@ void Sampling::forward_kernel_wrapper(SamplingMeta const *m, SamplingMeta::SamplingMeta(FFHandler handler, Op const *op, int batch_size, - int total_ele) + int total_ele, + GenericTensorAccessorW input) : OpMeta(handler, op) { DataType data_type = op->data_type; checkCUDA(cudaMalloc(&begin_offset, (batch_size + 1) * sizeof(int))); @@ -365,6 +226,61 @@ SamplingMeta::SamplingMeta(FFHandler handler, checkCUDA(cudaMalloc(&sorted_idx, total_ele * sizeof(int))); checkCUDA(cudaMalloc(&sorted_logits, total_ele * data_type_size(data_type))); cudaMalloc(&state, sizeof(curandState) * batch_size); + + cudaStream_t stream; + checkCUDA(get_legion_stream(&stream)); + + + //init offset + int parallelism = total_ele; + init_idxs<<>>(batch_size, + total_ele / batch_size, + total_ele, + idx, + begin_offset, + end_offset); + + + + //init sort function + if(data_type == DT_FLOAT){ + cub::DeviceSegmentedRadixSort::SortPairsDescending( + d_temp_storage, + temp_storage_bytes, + input.get_float_ptr(), + input.get_float_ptr(), + idx, + idx, + total_ele, + batch_size, + begin_offset, + end_offset + 1, + 0, // begin_bit + data_type_size(data_type) * 8, // end_bit = sizeof(KeyT) * 8 + stream); + }else if(data_type == DT_HALF){ + cub::DeviceSegmentedRadixSort::SortPairsDescending( + d_temp_storage, + temp_storage_bytes, + input.get_half_ptr(), + input.get_half_ptr(), + idx, + idx, + total_ele, + batch_size, + begin_offset, + end_offset + 1, + 0, // begin_bit + data_type_size(data_type) * 8, // end_bit = sizeof(KeyT) * 8 + stream); + }else{ + assert(false && "input type in float and half"); + } + + checkCUDA(cudaMalloc(&d_temp_storage, temp_storage_bytes)); } }; // namespace FlexFlow \ No newline at end of file diff --git a/src/runtime/model.cc b/src/runtime/model.cc index 99fbe2ede3..9428a3fd89 100644 --- a/src/runtime/model.cc +++ b/src/runtime/model.cc @@ -2983,7 +2983,7 @@ void FFModel::create_operators_from_layers() { Op *op = nullptr; // add a combine before arg_topk if (config.computationMode == COMP_MODE_INFERENCE && - config.tensor_parallelism_degree > 1 && l->op_type == OP_ARG_TOPK) { + config.tensor_parallelism_degree > 1 && (l->op_type == OP_ARG_TOPK || l->op_type == OP_SOFTMAX)) { std::vector partitioned_inputs; assert(inputs.size() == 1); Combine *comb = new Combine(*this, diff --git a/src/runtime/operator_params.cc b/src/runtime/operator_params.cc index 6b61d5ac7a..0366fba4e8 100644 --- a/src/runtime/operator_params.cc +++ b/src/runtime/operator_params.cc @@ -2,6 +2,7 @@ #include "flexflow/ops/aggregate.h" #include "flexflow/ops/aggregate_spec.h" #include "flexflow/ops/arg_topk.h" +#include "flexflow/ops/sampling.h" #include "flexflow/ops/attention.h" #include "flexflow/ops/batch_matmul.h" #include "flexflow/ops/batch_norm.h" @@ -130,6 +131,8 @@ tl::optional get_op_parameters(Op const *op) { return ((ArgTopK *)op)->get_params(); case OP_BEAM_TOPK: return ((BeamTopK *)op)->get_params(); + case OP_SAMPLING: + return ((Sampling *)op)->get_params(); // TODO: implement the get_params() function for the operators below and // uncomment the lines below From 0874a49e557fde97198189a4a2c892f242a6dc9e Mon Sep 17 00:00:00 2001 From: xinhaoc Date: Thu, 13 Jul 2023 07:20:54 +0000 Subject: [PATCH 07/16] clean up, format, hip_rocm --- include/flexflow/model.h | 8 +-- include/flexflow/ops/sampling.h | 6 +- 
inference/models/llama.cc | 6 +- inference/models/llama.h | 2 +- src/ops/fused.cu | 4 +- src/ops/kernels/rms_norm_kernels.cu | 1 - src/ops/sampling.cc | 16 +++-- src/ops/sampling.cpp | 66 ++++++++++++++++++++ src/ops/sampling.cu | 93 +++++++++++++---------------- src/runtime/ffconst_utils.cc | 2 +- src/runtime/graph.cc | 2 +- src/runtime/model.cc | 3 +- src/runtime/operator_params.cc | 2 +- 13 files changed, 138 insertions(+), 73 deletions(-) create mode 100644 src/ops/sampling.cpp diff --git a/include/flexflow/model.h b/include/flexflow/model.h index cab268b91d..afea5e6c96 100644 --- a/include/flexflow/model.h +++ b/include/flexflow/model.h @@ -615,9 +615,7 @@ class FFModel { int k, bool sorted, char const *name = NULL); - Tensor sampling(const Tensor input, - float top_p, - char const *name = NULL); + Tensor sampling(const Tensor input, float top_p, char const *name = NULL); Tensor multihead_attention(const Tensor query, const Tensor key, const Tensor value, @@ -1067,8 +1065,8 @@ class FFModel { IncMultiQuerySelfAttention *>, std::unordered_map, BeamTopK *>, - std::unordered_map, - Sampling *>, + std::unordered_map, + Sampling *>, std::unordered_map< std::pair, SpecIncMultiHeadSelfAttention *>, diff --git a/include/flexflow/ops/sampling.h b/include/flexflow/ops/sampling.h index 1ad193082b..544b728575 100644 --- a/include/flexflow/ops/sampling.h +++ b/include/flexflow/ops/sampling.h @@ -21,7 +21,11 @@ class SamplingMeta : public OpMeta { void *d_temp_storage; size_t temp_storage_bytes; curandState *state; - SamplingMeta(FFHandler handle, Op const *op, int batch_size, int total_ele, GenericTensorAccessorW input); + SamplingMeta(FFHandler handle, + Op const *op, + int batch_size, + int total_ele, + GenericTensorAccessorW input); }; class Sampling : public Op { diff --git a/inference/models/llama.cc b/inference/models/llama.cc index 1975bb861d..8b7f5b2dc6 100644 --- a/inference/models/llama.cc +++ b/inference/models/llama.cc @@ -212,12 +212,12 @@ void LLAMA::create_llama_model(FFModel &ff, output = ff.beam_top_k(softmax, llama_config.max_beam_width, false); } else { // Tensor softmax = ff.softmax(dense, -1); - - if(generationConfig.do_sample){ + + if (generationConfig.do_sample) { dense = ff.scalar_truediv(dense, generationConfig.temperature, false); Tensor softmax = ff.softmax(dense, -1); output = ff.sampling(softmax, generationConfig.topp); - }else{ + } else { output = ff.arg_top_k(dense, /*k=*/1, false); } } diff --git a/inference/models/llama.h b/inference/models/llama.h index 043b80c1ad..31959e4938 100644 --- a/inference/models/llama.h +++ b/inference/models/llama.h @@ -108,7 +108,7 @@ class LLAMA { std::string const &weight_file_path, InferenceMode mode, GenerationConfig generationConfig, - bool use_full_precision = false ); + bool use_full_precision = false); }; }; // namespace FlexFlow diff --git a/src/ops/fused.cu b/src/ops/fused.cu index 896d47018e..ef6c856871 100644 --- a/src/ops/fused.cu +++ b/src/ops/fused.cu @@ -749,7 +749,7 @@ __host__ void case OP_SIGMOID: case OP_TANH: case OP_ELU: - case OP_SCALAR_TRUE_DIV: { + case OP_SCALAR_TRUE_DIV: { assert(fused->op_num_inputs[op] == 1); assert(fused->op_num_weights[op] == 0); assert(fused->op_num_outputs[op] == 1); @@ -1269,7 +1269,7 @@ __host__ void FusedOp::backward_task(Task const *task, case OP_RELU: case OP_SIGMOID: case OP_TANH: - case OP_ELU:{ + case OP_ELU: { assert(fused->op_num_inputs[op] == 1); assert(fused->op_num_weights[op] == 0); assert(fused->op_num_outputs[op] == 1); diff --git a/src/ops/kernels/rms_norm_kernels.cu 
b/src/ops/kernels/rms_norm_kernels.cu index 042bf83c59..44e6288529 100644 --- a/src/ops/kernels/rms_norm_kernels.cu +++ b/src/ops/kernels/rms_norm_kernels.cu @@ -123,7 +123,6 @@ __global__ void elewise_apply_weights(int64_t batch_size, } template - void forward_kernel(RMSNormMeta const *m, T const *input_ptr, T const *weight_ptr, diff --git a/src/ops/sampling.cc b/src/ops/sampling.cc index 02c34aa719..dcab50ffb2 100644 --- a/src/ops/sampling.cc +++ b/src/ops/sampling.cc @@ -225,13 +225,19 @@ OpMeta *Sampling::init_task(Task const *task, Runtime *runtime) { Sampling *s = (Sampling *)task->args; FFHandler handle = *((FFHandler *)task->local_args); - GenericTensorAccessorW acc_input = helperGetGenericTensorAccessorRW( - s->inputs[0]->data_type, regions[0], task->regions[0], FID_DATA, ctx, runtime); + GenericTensorAccessorW acc_input = + helperGetGenericTensorAccessorRW(s->inputs[0]->data_type, + regions[0], + task->regions[0], + FID_DATA, + ctx, + runtime); int length = acc_input.domain.hi()[0] - acc_input.domain.lo()[0] + 1; - int batch_size = acc_input.domain.get_volume() / length; + int batch_size = acc_input.domain.get_volume() / length; - SamplingMeta *m = new SamplingMeta(handle, s, batch_size, length * batch_size, acc_input); + SamplingMeta *m = + new SamplingMeta(handle, s, batch_size, length * batch_size, acc_input); m->profiling = s->profiling; m->top_p = s->top_p; return m; @@ -292,7 +298,7 @@ InferenceResult Runtime *runtime) { assert(regions.size() == 2); assert(task->regions.size() == 2); - const Sampling* sampling = (const Sampling*) task->args; + Sampling const *sampling = (Sampling const *)task->args; SamplingMeta const *m = *((SamplingMeta **)task->local_args); GenericTensorAccessorW input = helperGetGenericTensorAccessorRW( diff --git a/src/ops/sampling.cpp b/src/ops/sampling.cpp new file mode 100644 index 0000000000..a652e46fe1 --- /dev/null +++ b/src/ops/sampling.cpp @@ -0,0 +1,66 @@ +/* Copyright 2023 CMU, Facebook, LANL, MIT, NVIDIA, and Stanford (alphabetical) + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "flexflow/ops/sampling.h" +#include "flexflow/ffconst_utils.h" +#include "flexflow/utils/hip_helper.h" +#include + +namespace FlexFlow { + +/*static*/ +template +void Sampling::forward_kernel(SamplingMeta const *m, + DT *input_ptr, + int *indices_ptr, + float const top_p, + int const length, + int const batch_size, + hipStream_t stream) {} + +/*static*/ +void Sampling::forward_kernel_wrapper(SamplingMeta const *m, + GenericTensorAccessorW const &input, + GenericTensorAccessorW const &indices) { + hipStream_t stream; + checkCUDA(get_legion_stream(&stream)); + + hipEvent_t t_start, t_end; + if (m->profiling) { + hipEventCreate(&t_start); + hipEventCreate(&t_end); + hipEventRecord(t_start, stream); + } + + handle_unimplemented_hip_kernel(OP_RMS_NORM); + + if (m->profiling) { + hipEventRecord(t_end, stream); + checkCUDA(hipEventSynchronize(t_end)); + float elapsed = 0; + checkCUDA(hipEventElapsedTime(&elapsed, t_start, t_end)); + hipEventDestroy(t_start); + hipEventDestroy(t_end); + } +} + +SamplingMeta::SamplingMeta(FFHandler handler, + Op const *op, + int batch_size, + int total_ele, + GenericTensorAccessorW input) + : OpMeta(handler, op) {} + +}; // namespace FlexFlow \ No newline at end of file diff --git a/src/ops/sampling.cu b/src/ops/sampling.cu index 2637d71857..185f7d02b3 100644 --- a/src/ops/sampling.cu +++ b/src/ops/sampling.cu @@ -48,7 +48,6 @@ __global__ void init_idxs(int batch_size, CUDA_KERNEL_LOOP(i, total_eles) { idx[i] = i % vocab_size; if (i % vocab_size == 0) { - // printf("adfadf :%d\n", i); begin_offset[i / vocab_size] = i; end_offset[i / vocab_size] = i; } @@ -81,10 +80,9 @@ __global__ void sampling_topp_kernel(int batch_size, if (threadIdx.x == 0) { // number must < topp random_n = curand_uniform(state + batch_idx) * topp; - if(blockIdx.x == 0){ - printf("batch idx: %d, %f\n", batch_idx, random_n); - } - + // if (blockIdx.x == 0) { + // printf("batch idx: %d, random num%f\n", batch_idx, random_n); + // } } __syncthreads(); @@ -110,9 +108,8 @@ __global__ void sampling_topp_kernel(int batch_size, indices_ptr[batch_idx] = sorted_idx[offset + result_idx]; // if (blockIdx.x == 0 && threadIdx.x == 0) { - // printf("batch idx afterward aaaaaa: %d\n", result_idx); + // printf("selected idx: %d\n", result_idx); // } - } /*static*/ @@ -126,9 +123,7 @@ void Sampling::forward_kernel(SamplingMeta const *m, cudaStream_t stream) { // 1. sort // 2. 
cumsum - std::cout << "-------------------------sampling kernel _--------------------" - << "\n"; - + size_t temp_storage_bytes = m->temp_storage_bytes; cub::DeviceSegmentedRadixSort::SortPairsDescending( m->d_temp_storage, @@ -151,15 +146,14 @@ void Sampling::forward_kernel(SamplingMeta const *m, stream>>>(m->state, batch_size, rand()); sampling_topp_kernel <<>>(batch_size, - length, - m->state, - input_ptr, - m->sorted_idx, - indices_ptr, - top_p); + length, + m->state, + input_ptr, + m->sorted_idx, + indices_ptr, + top_p); checkCUDA(cudaDeviceSynchronize()); - // print_tensor((float *)m->sorted_logits + 32000, 32, "after sort"); // topk / topp mask some value and renormalize // sampling @@ -230,8 +224,7 @@ SamplingMeta::SamplingMeta(FFHandler handler, cudaStream_t stream; checkCUDA(get_legion_stream(&stream)); - - //init offset + // init offset int parallelism = total_ele; init_idxs<< 1 && (l->op_type == OP_ARG_TOPK || l->op_type == OP_SOFTMAX)) { + config.tensor_parallelism_degree > 1 && + (l->op_type == OP_ARG_TOPK || l->op_type == OP_SOFTMAX)) { std::vector partitioned_inputs; assert(inputs.size() == 1); Combine *comb = new Combine(*this, diff --git a/src/runtime/operator_params.cc b/src/runtime/operator_params.cc index 0366fba4e8..8fb8c89b10 100644 --- a/src/runtime/operator_params.cc +++ b/src/runtime/operator_params.cc @@ -2,7 +2,6 @@ #include "flexflow/ops/aggregate.h" #include "flexflow/ops/aggregate_spec.h" #include "flexflow/ops/arg_topk.h" -#include "flexflow/ops/sampling.h" #include "flexflow/ops/attention.h" #include "flexflow/ops/batch_matmul.h" #include "flexflow/ops/batch_norm.h" @@ -29,6 +28,7 @@ #include "flexflow/ops/reshape.h" #include "flexflow/ops/reverse.h" #include "flexflow/ops/rms_norm.h" +#include "flexflow/ops/sampling.h" #include "flexflow/ops/softmax.h" #include "flexflow/ops/spec_inc_multihead_self_attention.h" #include "flexflow/ops/split.h" From 33a082768a970752147120600f6271a969a67e43 Mon Sep 17 00:00:00 2001 From: xinhaoc Date: Thu, 13 Jul 2023 07:23:37 +0000 Subject: [PATCH 08/16] format --- include/flexflow/operator_params.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/include/flexflow/operator_params.h b/include/flexflow/operator_params.h index 4386a09d20..2a368ef381 100644 --- a/include/flexflow/operator_params.h +++ b/include/flexflow/operator_params.h @@ -26,11 +26,11 @@ #include "flexflow/ops/reduce_params.h" #include "flexflow/ops/reshape_params.h" #include "flexflow/ops/rms_norm_params.h" +#include "flexflow/ops/sampling_params.h" #include "flexflow/ops/softmax_params.h" #include "flexflow/ops/spec_inc_multihead_self_attention_params.h" #include "flexflow/ops/split_params.h" #include "flexflow/ops/topk_params.h" -#include "flexflow/ops/sampling_params.h" #include "flexflow/ops/transpose_params.h" #include "flexflow/ops/tree_inc_multihead_self_attention_params.h" #include "flexflow/parallel_ops/allreduce_params.h" @@ -84,6 +84,6 @@ using OperatorParameters = mp::variant get_op_parameters(Op const *op); -}; // namespace FlexFlow +}; // namespace FlexFlow #endif // _OPERATOR_PARAMS_H From 68c0911003a34b9f67d2c6c0240a2848281e36d5 Mon Sep 17 00:00:00 2001 From: xinhaoc Date: Thu, 13 Jul 2023 07:26:04 +0000 Subject: [PATCH 09/16] . 
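
This commit only reflows a namespace comment in operator_params.h. For context
on the sort path set up in the preceding patches: cub's segmented sort follows
the usual two-pass idiom: a first call with a null workspace pointer only
reports temp_storage_bytes (done once in the SamplingMeta constructor, which
then allocates d_temp_storage), and the calls in forward_kernel reuse that
workspace on every step. A standalone sketch of the idiom (a toy wrapper under
assumed buffer names, not FlexFlow code):

    #include <cub/cub.cuh>

    // Sort each segment's keys in descending order, carrying values along.
    // Pass 1 (d_temp_storage == nullptr) only computes the workspace size;
    // pass 2 performs the sort. SamplingMeta splits the two passes so the
    // allocation happens once at operator-init time instead of per step.
    void segmented_sort_desc(float const *d_keys_in, float *d_keys_out,
                             int const *d_vals_in, int *d_vals_out,
                             int num_items, int num_segments,
                             int const *d_begin_offsets,
                             int const *d_end_offsets,
                             cudaStream_t stream) {
      void *d_temp_storage = nullptr;
      size_t temp_storage_bytes = 0;
      cub::DeviceSegmentedRadixSort::SortPairsDescending(
          d_temp_storage, temp_storage_bytes,
          d_keys_in, d_keys_out, d_vals_in, d_vals_out,
          num_items, num_segments, d_begin_offsets, d_end_offsets,
          0 /*begin_bit*/, sizeof(float) * 8 /*end_bit*/, stream);
      cudaMalloc(&d_temp_storage, temp_storage_bytes);
      cub::DeviceSegmentedRadixSort::SortPairsDescending(
          d_temp_storage, temp_storage_bytes,
          d_keys_in, d_keys_out, d_vals_in, d_vals_out,
          num_items, num_segments, d_begin_offsets, d_end_offsets,
          0 /*begin_bit*/, sizeof(float) * 8 /*end_bit*/, stream);
      cudaFree(d_temp_storage);
    }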
--- include/flexflow/operator_params.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/include/flexflow/operator_params.h b/include/flexflow/operator_params.h index 2a368ef381..5c2101d190 100644 --- a/include/flexflow/operator_params.h +++ b/include/flexflow/operator_params.h @@ -84,6 +84,6 @@ using OperatorParameters = mp::variant get_op_parameters(Op const *op); -}; // namespace FlexFlow +}; // namespace FlexFlow #endif // _OPERATOR_PARAMS_H From 469da8598031fc542baed31a3892876573a09671 Mon Sep 17 00:00:00 2001 From: xinhaoc Date: Thu, 13 Jul 2023 18:58:28 +0000 Subject: [PATCH 10/16] fix half precision. --- include/flexflow/ops/sampling.h | 4 ++++ src/ops/sampling.cu | 42 +++++++++++++++------------------ 2 files changed, 23 insertions(+), 23 deletions(-) diff --git a/include/flexflow/ops/sampling.h b/include/flexflow/ops/sampling.h index 544b728575..2f0c152cfc 100644 --- a/include/flexflow/ops/sampling.h +++ b/include/flexflow/ops/sampling.h @@ -5,8 +5,10 @@ #include "flexflow/model.h" #include "flexflow/node.h" #include "flexflow/ops/sampling_params.h" +#if defined(FF_USE_CUDA) || defined(FF_USE_HIP_CUDA) #include #include +#endif namespace FlexFlow { @@ -20,7 +22,9 @@ class SamplingMeta : public OpMeta { int *idx; void *d_temp_storage; size_t temp_storage_bytes; +#if defined(FF_USE_CUDA) || defined(FF_USE_HIP_CUDA) curandState *state; +#endif SamplingMeta(FFHandler handle, Op const *op, int batch_size, diff --git a/src/ops/sampling.cu b/src/ops/sampling.cu index 185f7d02b3..a631ab8919 100644 --- a/src/ops/sampling.cu +++ b/src/ops/sampling.cu @@ -80,9 +80,7 @@ __global__ void sampling_topp_kernel(int batch_size, if (threadIdx.x == 0) { // number must < topp random_n = curand_uniform(state + batch_idx) * topp; - // if (blockIdx.x == 0) { - // printf("batch idx: %d, random num%f\n", batch_idx, random_n); - // } + // printf("batch idx: %d, random num%f\n", batch_idx, random_n); } __syncthreads(); @@ -98,7 +96,7 @@ __global__ void sampling_topp_kernel(int batch_size, result_idx = vocab_size - 1; for (long long j = threadIdx.x; j < vocab_size; j += blockDim.x) { - float logit = (float)sorted_logits[offset + j]; + float logit = (float)(sorted_logits[offset + j]); BlockScan(temp_storage).InclusiveSum(logit, prefix_sum, prefix_op); prefix_sum /= topp; if (prefix_sum >= random_n) { @@ -107,8 +105,8 @@ __global__ void sampling_topp_kernel(int batch_size, } indices_ptr[batch_idx] = sorted_idx[offset + result_idx]; - // if (blockIdx.x == 0 && threadIdx.x == 0) { - // printf("selected idx: %d\n", result_idx); + // if (threadIdx.x == 0) { + // printf("selected idx: %d, %d\n", blockIdx.x, result_idx); // } } @@ -125,11 +123,11 @@ void Sampling::forward_kernel(SamplingMeta const *m, // 2. cumsum size_t temp_storage_bytes = m->temp_storage_bytes; - cub::DeviceSegmentedRadixSort::SortPairsDescending( + checkCUDA(cub::DeviceSegmentedRadixSort::SortPairsDescending( m->d_temp_storage, temp_storage_bytes, input_ptr, - input_ptr, + static_cast
(m->sorted_logits), m->idx, m->sorted_idx, length * batch_size, @@ -138,24 +136,23 @@ void Sampling::forward_kernel(SamplingMeta const *m, m->end_offset + 1, 0, // begin_bit sizeof(DT) * 8, // end_bit = sizeof(KeyT) * 8 - stream); + stream)); int parallelism = batch_size; init_random_kernel<<>>(m->state, batch_size, rand()); sampling_topp_kernel - <<>>(batch_size, - length, - m->state, - input_ptr, - m->sorted_idx, - indices_ptr, - top_p); + <<>>( + batch_size, + length, + m->state, + static_cast
(m->sorted_logits), + m->sorted_idx, + indices_ptr, + top_p); - checkCUDA(cudaDeviceSynchronize()); // topk / topp mask some value and renormalize - // sampling } @@ -238,7 +235,7 @@ SamplingMeta::SamplingMeta(FFHandler handler, // init sort function if (data_type == DT_FLOAT) { - cub::DeviceSegmentedRadixSort::SortPairsDescending( + checkCUDA(cub::DeviceSegmentedRadixSort::SortPairsDescending( d_temp_storage, temp_storage_bytes, input.get_float_ptr(), @@ -251,9 +248,9 @@ SamplingMeta::SamplingMeta(FFHandler handler, end_offset + 1, 0, // begin_bit data_type_size(data_type) * 8, // end_bit = sizeof(KeyT) * 8 - stream); + stream)); } else if (data_type == DT_HALF) { - cub::DeviceSegmentedRadixSort::SortPairsDescending( + checkCUDA(cub::DeviceSegmentedRadixSort::SortPairsDescending( d_temp_storage, temp_storage_bytes, input.get_half_ptr(), @@ -266,11 +263,10 @@ SamplingMeta::SamplingMeta(FFHandler handler, end_offset + 1, 0, // begin_bit data_type_size(data_type) * 8, // end_bit = sizeof(KeyT) * 8 - stream); + stream)); } else { assert(false && "input type in float and half"); } - checkCUDA(cudaMalloc(&d_temp_storage, temp_storage_bytes)); } From 147b48ba7cdb83c695bca2db7f13514896e003f8 Mon Sep 17 00:00:00 2001 From: xinhaoc Date: Fri, 14 Jul 2023 00:56:35 +0000 Subject: [PATCH 11/16] try torch1. --- conda/pytorch-gpu.yml | 2 +- inference/incr_decoding/incr_decoding.cc | 2 +- inference/models/llama.cc | 1 - src/ops/sampling.cc | 18 ------------------ src/ops/sampling.cu | 8 +------- 5 files changed, 3 insertions(+), 28 deletions(-) diff --git a/conda/pytorch-gpu.yml b/conda/pytorch-gpu.yml index e6702a6572..6ce93d328f 100644 --- a/conda/pytorch-gpu.yml +++ b/conda/pytorch-gpu.yml @@ -7,4 +7,4 @@ dependencies: - numpy>=1.16.0 - pip - pip: - - torch>=1.13.1 + - torch==1.13.1 diff --git a/inference/incr_decoding/incr_decoding.cc b/inference/incr_decoding/incr_decoding.cc index efa13f9e74..9e4728eac0 100644 --- a/inference/incr_decoding/incr_decoding.cc +++ b/inference/incr_decoding/incr_decoding.cc @@ -165,7 +165,6 @@ void FlexFlow::top_level_task(Task const *task, ffconfig.tensor_parallelism_degree = tensor_parallelism_degree; ffconfig.pipeline_parallelism_degree = pipeline_parallelism_degree; - std::cout << "workers: " << num_devices << "\n"; assert(data_parallelism_degree * tensor_parallelism_degree * pipeline_parallelism_degree == ffconfig.numNodes * ffconfig.workersPerNode); @@ -236,6 +235,7 @@ void FlexFlow::top_level_task(Task const *task, assert(fm.get_future_map_domain().get_volume() == 1); Future future = fm.get_future(0); ir = future.get_result(); + // assert(false); } // Execution fence diff --git a/inference/models/llama.cc b/inference/models/llama.cc index 8b7f5b2dc6..f6942537f4 100644 --- a/inference/models/llama.cc +++ b/inference/models/llama.cc @@ -212,7 +212,6 @@ void LLAMA::create_llama_model(FFModel &ff, output = ff.beam_top_k(softmax, llama_config.max_beam_width, false); } else { // Tensor softmax = ff.softmax(dense, -1); - if (generationConfig.do_sample) { dense = ff.scalar_truediv(dense, generationConfig.temperature, false); Tensor softmax = ff.softmax(dense, -1); diff --git a/src/ops/sampling.cc b/src/ops/sampling.cc index dcab50ffb2..5695149d8e 100644 --- a/src/ops/sampling.cc +++ b/src/ops/sampling.cc @@ -170,12 +170,6 @@ void Sampling::init_inference(FFModel const &ff, EXCLUSIVE, batch_outputs[0]->region)); launcher.add_field(1, FID_DATA); - // launcher.add_region_requirement(RegionRequirement(batch_outputs[1]->part, - // 0 /*projection id*/, - // WRITE_ONLY, - // 
EXCLUSIVE, - // batch_outputs[1]->region)); - // launcher.add_field(2, FID_DATA); FutureMap fm = runtime->execute_index_space(ctx, launcher); fm.wait_all_results(); set_opmeta_from_futuremap_inference(ff, fm, batch_outputs[0]); @@ -208,12 +202,6 @@ void Sampling::init(FFModel const &ff) { EXCLUSIVE, outputs[0]->region)); launcher.add_field(1, FID_DATA); - // launcher.add_region_requirement(RegionRequirement(outputs[1]->part, - // 0 /*projection id*/, - // WRITE_ONLY, - // EXCLUSIVE, - // outputs[1]->region)); - // launcher.add_field(2, FID_DATA); FutureMap fm = runtime->execute_index_space(ctx, launcher); fm.wait_all_results(); set_opmeta_from_futuremap(ff, fm); @@ -282,12 +270,6 @@ FutureMap Sampling::inference(FFModel const &ff, EXCLUSIVE, batch_outputs[0]->region)); launcher.add_field(1, FID_DATA); - // launcher.add_region_requirement(RegionRequirement(batch_outputs[1]->part, - // 0 /*projection id*/, - // WRITE_ONLY, - // EXCLUSIVE, - // batch_outputs[1]->region)); - // launcher.add_field(2, FID_DATA); return runtime->execute_index_space(ctx, launcher); } diff --git a/src/ops/sampling.cu b/src/ops/sampling.cu index a631ab8919..e97df1b8a3 100644 --- a/src/ops/sampling.cu +++ b/src/ops/sampling.cu @@ -73,7 +73,6 @@ __global__ void sampling_topp_kernel(int batch_size, // int const vocab_id = threadIdx.x; int const batch_idx = blockIdx.x; __shared__ float random_n; - // __shared__ float renormalized_sum; __shared__ long long result_idx; // random num @@ -92,7 +91,6 @@ __global__ void sampling_topp_kernel(int batch_size, int offset = batch_idx * vocab_size; float prefix_sum = 0.0f; BlockPrefixCallbackOp prefix_op(0); - // float sum; result_idx = vocab_size - 1; for (long long j = threadIdx.x; j < vocab_size; j += blockDim.x) { @@ -120,8 +118,6 @@ void Sampling::forward_kernel(SamplingMeta const *m, int const batch_size, cudaStream_t stream) { // 1. sort - // 2. cumsum - size_t temp_storage_bytes = m->temp_storage_bytes; checkCUDA(cub::DeviceSegmentedRadixSort::SortPairsDescending( m->d_temp_storage, @@ -142,6 +138,7 @@ void Sampling::forward_kernel(SamplingMeta const *m, min(CUDA_NUM_THREADS, parallelism), 0, stream>>>(m->state, batch_size, rand()); + // sampling sampling_topp_kernel <<>>( batch_size, @@ -151,9 +148,6 @@ void Sampling::forward_kernel(SamplingMeta const *m, m->sorted_idx, indices_ptr, top_p); - - // topk / topp mask some value and renormalize - // sampling } /*static*/ From 9d55ca45e835ed343d21e26c5243762747169a97 Mon Sep 17 00:00:00 2001 From: xinhaoc Date: Fri, 14 Jul 2023 02:31:28 +0000 Subject: [PATCH 12/16] . 
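
This commit just reverts the torch pin introduced in the previous commit. As a
side note on the sampling kernels in this series: the per-request randomness
used by sampling_topp_kernel comes from a curandState array (one state per
batch slot, allocated in the SamplingMeta constructor and re-seeded with
rand() by init_random_kernel before each sampling step). init_random_kernel
itself is not shown in these hunks; a typical initializer of that shape (an
assumption for illustration, not necessarily the FlexFlow kernel) is:

    #include <curand_kernel.h>

    // One curandState per batch slot. Each thread seeds its own state so the
    // later curand_uniform(state + batch_idx) calls in the sampling kernel
    // draw independent sequences per request.
    __global__ void init_random_states(curandState *states,
                                       int num_states,
                                       unsigned long long seed) {
      int i = blockIdx.x * blockDim.x + threadIdx.x;
      if (i < num_states) {
        curand_init(seed, /*subsequence=*/i, /*offset=*/0, &states[i]);
      }
    }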
--- conda/pytorch-gpu.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/conda/pytorch-gpu.yml b/conda/pytorch-gpu.yml index 6ce93d328f..e6702a6572 100644 --- a/conda/pytorch-gpu.yml +++ b/conda/pytorch-gpu.yml @@ -7,4 +7,4 @@ dependencies: - numpy>=1.16.0 - pip - pip: - - torch==1.13.1 + - torch>=1.13.1 From 95f7076eec4513aff02fd8bfd25571ac22a3055a Mon Sep 17 00:00:00 2001 From: xinhaoc Date: Sun, 16 Jul 2023 15:55:12 -0400 Subject: [PATCH 13/16] batch size --- include/flexflow/ops/sampling.h | 3 ++- src/ops/sampling.cc | 10 ++++------ src/ops/sampling.cpp | 3 ++- src/ops/sampling.cu | 9 ++++----- 4 files changed, 12 insertions(+), 13 deletions(-) diff --git a/include/flexflow/ops/sampling.h b/include/flexflow/ops/sampling.h index 2f0c152cfc..8ffa6a290a 100644 --- a/include/flexflow/ops/sampling.h +++ b/include/flexflow/ops/sampling.h @@ -95,7 +95,8 @@ class Sampling : public Op { ffStream_t stream); static void forward_kernel_wrapper(SamplingMeta const *m, GenericTensorAccessorW const &input, - GenericTensorAccessorW const &indices); + GenericTensorAccessorW const &indices, + int batch_size); Params get_params() const; public: diff --git a/src/ops/sampling.cc b/src/ops/sampling.cc index 5695149d8e..8c01464042 100644 --- a/src/ops/sampling.cc +++ b/src/ops/sampling.cc @@ -252,7 +252,7 @@ FutureMap Sampling::inference(FFModel const &ff, << std::endl; */ IndexLauncher launcher(SAMPLING_INF_TASK_ID, parallel_is, - TaskArgument(NULL, 0), + TaskArgument(&bc, sizeof(BatchConfig)), argmap, Predicate::TRUE_PRED, false /*must*/, @@ -280,7 +280,7 @@ InferenceResult Runtime *runtime) { assert(regions.size() == 2); assert(task->regions.size() == 2); - Sampling const *sampling = (Sampling const *)task->args; + BatchConfig const *bc = (BatchConfig *)task->args; SamplingMeta const *m = *((SamplingMeta **)task->local_args); GenericTensorAccessorW input = helperGetGenericTensorAccessorRW( @@ -288,10 +288,8 @@ InferenceResult GenericTensorAccessorW indices = helperGetGenericTensorAccessorWO( DT_INT32, regions[1], task->regions[1], FID_DATA, ctx, runtime); - Sampling::forward_kernel_wrapper(m, input, indices); - - int length = input.domain.hi()[0] - input.domain.lo()[0] + 1; - int batch_size = input.domain.get_volume() / length; + int batch_size = bc->num_active_tokens(); + Sampling::forward_kernel_wrapper(m, input, indices, batch_size); InferenceResult ir; download_tensor( diff --git a/src/ops/sampling.cpp b/src/ops/sampling.cpp index a652e46fe1..4901fe400c 100644 --- a/src/ops/sampling.cpp +++ b/src/ops/sampling.cpp @@ -33,7 +33,8 @@ void Sampling::forward_kernel(SamplingMeta const *m, /*static*/ void Sampling::forward_kernel_wrapper(SamplingMeta const *m, GenericTensorAccessorW const &input, - GenericTensorAccessorW const &indices) { + GenericTensorAccessorW const &indices, + int batch_size) { hipStream_t stream; checkCUDA(get_legion_stream(&stream)); diff --git a/src/ops/sampling.cu b/src/ops/sampling.cu index e97df1b8a3..5b3a674794 100644 --- a/src/ops/sampling.cu +++ b/src/ops/sampling.cu @@ -151,9 +151,10 @@ void Sampling::forward_kernel(SamplingMeta const *m, } /*static*/ -void Sampling::forward_kernel_wrapper(SamplingMeta const *m, - GenericTensorAccessorW const &input, - GenericTensorAccessorW const &indices) { +void Sampling::forward_kernel_wrapper( + SamplingMeta const *m, + GenericTensorAccessorW const &input, + GenericTensorAccessorW const &indices int batch_size) { cudaStream_t stream; checkCUDA(get_legion_stream(&stream)); @@ -163,8 +164,6 @@ void 
Sampling::forward_kernel_wrapper(SamplingMeta const *m, cudaEventCreate(&t_end); cudaEventRecord(t_start, stream); } - int length = input.domain.hi()[0] - input.domain.lo()[0] + 1; - int batch_size = input.domain.get_volume() / length; if (input.data_type == DT_HALF) { Sampling::forward_kernel(m, From 10aa0fae40396f583a80cd6893beca0dcedbedbe Mon Sep 17 00:00:00 2001 From: xinhaoc Date: Mon, 17 Jul 2023 02:56:39 +0000 Subject: [PATCH 14/16] fix --- src/ops/sampling.cu | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/src/ops/sampling.cu b/src/ops/sampling.cu index 5b3a674794..a91263a621 100644 --- a/src/ops/sampling.cu +++ b/src/ops/sampling.cu @@ -151,10 +151,10 @@ void Sampling::forward_kernel(SamplingMeta const *m, } /*static*/ -void Sampling::forward_kernel_wrapper( - SamplingMeta const *m, - GenericTensorAccessorW const &input, - GenericTensorAccessorW const &indices int batch_size) { +void Sampling::forward_kernel_wrapper(SamplingMeta const *m, + GenericTensorAccessorW const &input, + GenericTensorAccessorW const &indices, + int batch_size) { cudaStream_t stream; checkCUDA(get_legion_stream(&stream)); @@ -164,6 +164,7 @@ void Sampling::forward_kernel_wrapper( cudaEventCreate(&t_end); cudaEventRecord(t_start, stream); } + int length = input.domain.hi()[0] - input.domain.lo()[0] + 1; if (input.data_type == DT_HALF) { Sampling::forward_kernel(m, From 2d4ca810803603d9c08e45d7914c32b4f0d77520 Mon Sep 17 00:00:00 2001 From: xinhaoc Date: Mon, 17 Jul 2023 12:06:10 -0400 Subject: [PATCH 15/16] rename GenerationConfig SamplingConfig --- include/flexflow/inference.h | 6 +++--- inference/incr_decoding/incr_decoding.cc | 4 ++-- inference/models/llama.cc | 8 ++++---- inference/models/llama.h | 2 +- inference/spec_infer/spec_infer.cc | 6 +++--- 5 files changed, 13 insertions(+), 13 deletions(-) diff --git a/include/flexflow/inference.h b/include/flexflow/inference.h index 192573fffa..0c5274e15b 100644 --- a/include/flexflow/inference.h +++ b/include/flexflow/inference.h @@ -65,16 +65,16 @@ struct BeamTree { treeLayer treeLayers[BeamSearchBatchConfig::MAX_BEAM_DEPTH + 1]; }; -struct GenerationConfig { +struct SamplingConfig { bool do_sample = false; float temperature = 0.8; float topp = 0.6; - GenerationConfig(bool _do_sample, float _temperature, float _topp) { + SamplingConfig(bool _do_sample, float _temperature, float _topp) { temperature = _temperature > 0 ? _temperature : temperature; topp = _topp > 0 ? 
_topp : topp; do_sample = _do_sample; } - GenerationConfig() {} + SamplingConfig() {} }; // struct BeamTree_v2 { diff --git a/inference/incr_decoding/incr_decoding.cc b/inference/incr_decoding/incr_decoding.cc index 9e4728eac0..17fc58c53a 100644 --- a/inference/incr_decoding/incr_decoding.cc +++ b/inference/incr_decoding/incr_decoding.cc @@ -172,7 +172,7 @@ void FlexFlow::top_level_task(Task const *task, assert(model_type != ModelType::UNKNOWN && "Invalid LLM model type passed (or no type was passed)."); - GenerationConfig generationConfig(do_sample, temperature, topp); + SamplingConfig samplingConfig(do_sample, temperature, topp); InferenceManager im(ffconfig, BatchConfig::MAX_NUM_TOKENS); RequestManager rm(model_type, file_paths.tokenizer_file_path, @@ -186,7 +186,7 @@ void FlexFlow::top_level_task(Task const *task, file_paths.llm_config_file_path, file_paths.llm_weight_file_path, INC_DECODING_MODE, - generationConfig, + samplingConfig, use_full_precision); } else if (model_type == ModelType::OPT) { OPT::create_opt_model(model, diff --git a/inference/models/llama.cc b/inference/models/llama.cc index f6942537f4..06dfaebcb1 100644 --- a/inference/models/llama.cc +++ b/inference/models/llama.cc @@ -24,7 +24,7 @@ void LLAMA::create_llama_model(FFModel &ff, std::string const &model_config_file_path, std::string const &weight_file_path, InferenceMode mode, - GenerationConfig generationConfig, + SamplingConfig samplingConfig, bool use_full_precision) { // do not apply cpu offload in beam search model. Config llama_config(model_config_file_path); @@ -212,10 +212,10 @@ void LLAMA::create_llama_model(FFModel &ff, output = ff.beam_top_k(softmax, llama_config.max_beam_width, false); } else { // Tensor softmax = ff.softmax(dense, -1); - if (generationConfig.do_sample) { - dense = ff.scalar_truediv(dense, generationConfig.temperature, false); + if (samplingConfig.do_sample) { + dense = ff.scalar_truediv(dense, samplingConfig.temperature, false); Tensor softmax = ff.softmax(dense, -1); - output = ff.sampling(softmax, generationConfig.topp); + output = ff.sampling(softmax, samplingConfig.topp); } else { output = ff.arg_top_k(dense, /*k=*/1, false); } diff --git a/inference/models/llama.h b/inference/models/llama.h index 31959e4938..6f80194d72 100644 --- a/inference/models/llama.h +++ b/inference/models/llama.h @@ -107,7 +107,7 @@ class LLAMA { std::string const &model_config_file_path, std::string const &weight_file_path, InferenceMode mode, - GenerationConfig generationConfig, + SamplingConfig samplingConfig, bool use_full_precision = false); }; diff --git a/inference/spec_infer/spec_infer.cc b/inference/spec_infer/spec_infer.cc index de7605234a..a4c3dc64f9 100644 --- a/inference/spec_infer/spec_infer.cc +++ b/inference/spec_infer/spec_infer.cc @@ -199,7 +199,7 @@ void FlexFlow::top_level_task(Task const *task, } // Create SentencePiece tokenizer or OPT tokenizer - GenerationConfig generationConfig; + SamplingConfig samplingConfig; InferenceManager im(ffconfig, BatchConfig::MAX_NUM_TOKENS); RequestManager rm(model_types.llm_model_type, file_paths.tokenizer_file_path, @@ -214,7 +214,7 @@ void FlexFlow::top_level_task(Task const *task, file_paths.llm_config_file_path, file_paths.llm_weight_file_path, TREE_VERIFY_MODE, - generationConfig, + samplingConfig, use_full_precision); } else if (model_types.llm_model_type == ModelType::OPT) { OPT::create_opt_model(tree_model, @@ -247,7 +247,7 @@ void FlexFlow::top_level_task(Task const *task, file_paths.ssm_config_file_paths[ssm_id], 
file_paths.ssm_weight_file_paths[ssm_id], BEAM_SEARCH_MODE, - generationConfig, + samplingConfig, use_full_precision); } else if (model_types.ssm_model_types[ssm_id] == ModelType::OPT) { OPT::create_opt_model(beam_model, From 6a4a2679cf68dc0adf3420790fb6831b7730b778 Mon Sep 17 00:00:00 2001 From: Gabriele Oliaro Date: Tue, 18 Jul 2023 18:10:00 -0400 Subject: [PATCH 16/16] bug fix --- src/runtime/model.cc | 25 ++++++++++++++++++++----- 1 file changed, 20 insertions(+), 5 deletions(-) diff --git a/src/runtime/model.cc b/src/runtime/model.cc index c914ec0f7d..22515a2bb0 100644 --- a/src/runtime/model.cc +++ b/src/runtime/model.cc @@ -5418,16 +5418,31 @@ void register_flexflow_internal_tasks(Runtime *runtime, TaskVariantRegistrar registrar(SAMPLING_INIT_TASK_ID, "Sampling Init"); registrar.add_constraint(ProcessorConstraint(Processor::TOC_PROC)); registrar.set_leaf(); - Runtime::preregister_task_variant( - registrar, "Sampling Init Task"); + if (pre_register) { + Runtime::preregister_task_variant( + registrar, "Sampling Init Task"); + } else { + if (enable_control_replication) { + registrar.global_registration = false; + } + runtime->register_task_variant(registrar); + } } { TaskVariantRegistrar registrar(SAMPLING_INF_TASK_ID, "Sampling Inference"); registrar.add_constraint(ProcessorConstraint(Processor::TOC_PROC)); registrar.set_leaf(); - Runtime::preregister_task_variant( - registrar, "Sampling Inference Task"); + if (pre_register) { + Runtime::preregister_task_variant( + registrar, "Sampling Inference Task"); + } else { + if (enable_control_replication) { + registrar.global_registration = false; + } + runtime->register_task_variant( + registrar); + } } // Transpose task {