From 6b7e6f0ca158bb33685e6ed2fd77b9e867c2ab53 Mon Sep 17 00:00:00 2001 From: xinhaoc <99570243+xinhaoc@users.noreply.github.com> Date: Thu, 27 Jul 2023 14:36:09 -0400 Subject: [PATCH] Inference: add argmax operator (#888) * add argmax operator * support spec infer. * format * remove redundant * half precision * fix * fix * hip_rocm --- include/flexflow/ffconst.h | 1 + include/flexflow/model.h | 7 + include/flexflow/operator_params.h | 2 + include/flexflow/ops/argmax.h | 109 +++++++ include/flexflow/ops/argmax_params.h | 24 ++ inference/models/llama.cc | 6 +- inference/models/opt.cc | 6 +- src/ops/argmax.cc | 442 +++++++++++++++++++++++++++ src/ops/argmax.cpp | 69 +++++ src/ops/argmax.cu | 151 +++++++++ src/runtime/cuda_helper.cu | 1 + src/runtime/ffconst_utils.cc | 2 + src/runtime/graph.cc | 5 + src/runtime/model.cc | 59 +++- src/runtime/operator_params.cc | 3 + 15 files changed, 882 insertions(+), 5 deletions(-) create mode 100644 include/flexflow/ops/argmax.h create mode 100644 include/flexflow/ops/argmax_params.h create mode 100644 src/ops/argmax.cc create mode 100644 src/ops/argmax.cpp create mode 100644 src/ops/argmax.cu diff --git a/include/flexflow/ffconst.h b/include/flexflow/ffconst.h index 65fa23569b..7521613477 100644 --- a/include/flexflow/ffconst.h +++ b/include/flexflow/ffconst.h @@ -163,6 +163,7 @@ enum OperatorType { OP_GATHER, // https://pytorch.org/docs/stable/generated/torch.gather.html OP_RMS_NORM, OP_BEAM_TOPK, + OP_ARGMAX, OP_INC_MULTIHEAD_SELF_ATTENTION, OP_SPEC_INC_MULTIHEAD_SELF_ATTENTION, OP_TREE_INC_MULTIHEAD_SELF_ATTENTION, diff --git a/include/flexflow/model.h b/include/flexflow/model.h index a95c229a08..0e98b6e8ad 100644 --- a/include/flexflow/model.h +++ b/include/flexflow/model.h @@ -138,6 +138,9 @@ enum TaskIDs { ARG_TOPK_INF_TASK_ID, SAMPLING_INIT_TASK_ID, SAMPLING_INF_TASK_ID, + ARGMAX_INIT_TASK_ID, + ARGMAX_BEAM_INF_TASK_ID, + ARGMAX_NORM_INF_TASK_ID, TRANSPOSE_INIT_TASK_ID, TRANSPOSE_FWD_TASK_ID, TRANSPOSE_BWD_TASK_ID, @@ -315,6 +318,7 @@ class BeamTopK; class SpecIncMultiHeadSelfAttention; class IncMultiQuerySelfAttention; class Sampling; +class ArgMax; class Combine; class Repartition; class Reduction; @@ -615,6 +619,7 @@ class FFModel { int k, bool sorted, char const *name = NULL); + Tensor argmax(const Tensor input, bool beam_search, char const *name = NULL); Tensor sampling(const Tensor input, float top_p, char const *name = NULL); Tensor multihead_attention(const Tensor query, const Tensor key, @@ -1067,6 +1072,8 @@ class FFModel { BeamTopK *>, std::unordered_map, Sampling *>, + std::unordered_map, + ArgMax *>, std::unordered_map< std::pair, SpecIncMultiHeadSelfAttention *>, diff --git a/include/flexflow/operator_params.h b/include/flexflow/operator_params.h index 5c2101d190..982d5482a0 100644 --- a/include/flexflow/operator_params.h +++ b/include/flexflow/operator_params.h @@ -4,6 +4,7 @@ #include "flexflow/ops/aggregate_params.h" #include "flexflow/ops/aggregate_spec_params.h" #include "flexflow/ops/arg_topk_params.h" +#include "flexflow/ops/argmax_params.h" #include "flexflow/ops/attention_params.h" #include "flexflow/ops/batch_matmul_params.h" #include "flexflow/ops/beam_topk_params.h" @@ -73,6 +74,7 @@ using OperatorParameters = mp::variant const &, + std::vector const &, + MachineView const *mv = nullptr) override; + void forward(FFModel const &) override; + void backward(FFModel const &) override; + Legion::FutureMap inference(FFModel const &, + BatchConfig const &, + std::vector const &, + std::vector const &, + MachineView const *mv = 
nullptr) override; + void print_layer(FFModel const &model) override { + assert(0); + } + static Op * + create_operator_from_layer(FFModel &model, + Layer const *layer, + std::vector const &inputs); + + static OpMeta *init_task(Legion::Task const *task, + std::vector const ®ions, + Legion::Context ctx, + Legion::Runtime *runtime); + static BeamInferenceResult + inference_task_beam(Legion::Task const *task, + std::vector const ®ions, + Legion::Context ctx, + Legion::Runtime *runtime); + static InferenceResult + inference_task_norm(Legion::Task const *task, + std::vector const ®ions, + Legion::Context ctx, + Legion::Runtime *runtime); + void serialize(Legion::Serializer &s) const override; + static PCG::Node deserialize(FFModel &ff, + Legion::Deserializer &d, + ParallelTensor inputs[], + int num_inputs); + Op *materialize(FFModel &ff, + ParallelTensor inputs[], + int num_inputs) const override; + bool measure_operator_cost(Simulator *sim, + MachineView const &pc, + CostMetrics &cost_metrics) const override; + template + static void forward_kernel(ArgMaxMeta const *m, + DT *input_ptr, + int *indices_ptr, + DT *prob_ptr, + int *parent_ptr, + int length, + int batch_size, + ffStream_t stream); + static void forward_kernel_wrapper(ArgMaxMeta const *m, + GenericTensorAccessorW const &input, + GenericTensorAccessorW const &indices, + GenericTensorAccessorW const &value, + GenericTensorAccessorW const &parent); + Params get_params() const; + +public: + bool beam_search; +}; + +}; // namespace FlexFlow + +#endif \ No newline at end of file diff --git a/include/flexflow/ops/argmax_params.h b/include/flexflow/ops/argmax_params.h new file mode 100644 index 0000000000..a8f629619f --- /dev/null +++ b/include/flexflow/ops/argmax_params.h @@ -0,0 +1,24 @@ +#ifndef _FLEXFLOW_ARGMAX_PARAMS_H +#define _FLEXFLOW_ARGMAX_PARAMS_H + +#include "flexflow/ffconst.h" +#include "flexflow/parallel_tensor.h" + +namespace FlexFlow { + +struct ArgMaxParams { + bool beam_search; + bool is_valid(ParallelTensorShape const &) const; +}; +bool operator==(ArgMaxParams const &, ArgMaxParams const &); + +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::ArgMaxParams const &) const; +}; +} // namespace std + +#endif // _FLEXFLOW_ARGMAX_PARAMS_H \ No newline at end of file diff --git a/inference/models/llama.cc b/inference/models/llama.cc index 588d6d264c..e4cd54192d 100644 --- a/inference/models/llama.cc +++ b/inference/models/llama.cc @@ -176,7 +176,8 @@ void LLAMA::create_llama_model(FFModel &ff, Tensor output; if (mode == BEAM_SEARCH_MODE) { Tensor softmax = ff.softmax(dense, -1); - output = ff.beam_top_k(softmax, llama_config.max_beam_width, false); + // output = ff.beam_top_k(softmax, llama_config.max_beam_width, false); + output = ff.argmax(softmax, /*beam_Search*/ true); } else { // Tensor softmax = ff.softmax(dense, -1); if (samplingConfig.do_sample) { @@ -184,7 +185,8 @@ void LLAMA::create_llama_model(FFModel &ff, Tensor softmax = ff.softmax(dense, -1); output = ff.sampling(softmax, samplingConfig.topp); } else { - output = ff.arg_top_k(dense, /*k=*/1, false); + // output = ff.arg_top_k(dense, /*k=*/1, false); + output = ff.argmax(dense, /*beam_Search*/ false); } } diff --git a/inference/models/opt.cc b/inference/models/opt.cc index 94aeb7f2bd..05cee2bf9d 100644 --- a/inference/models/opt.cc +++ b/inference/models/opt.cc @@ -215,9 +215,11 @@ void OPT::create_opt_model(FFModel &ff, Tensor output; if (mode == BEAM_SEARCH_MODE) { Tensor softmax = ff.softmax(lm_head, -1); - 
output = ff.beam_top_k(softmax, opt_config.max_beam_width, false); + // output = ff.beam_top_k(softmax, opt_config.max_beam_width, false); + output = ff.argmax(softmax, /*beam_Search*/ true); } else { - output = ff.arg_top_k(lm_head, /*k=*/1, false); + // output = ff.arg_top_k(lm_head, /*k=*/1, false); + output = ff.argmax(lm_head, /*beam_Search*/ false); } //------------------- compile the model -------------------------------- diff --git a/src/ops/argmax.cc b/src/ops/argmax.cc new file mode 100644 index 0000000000..754337448e --- /dev/null +++ b/src/ops/argmax.cc @@ -0,0 +1,442 @@ +/* Copyright 2023 CMU, Facebook, LANL, MIT, NVIDIA, and Stanford (alphabetical) + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "flexflow/ops/argmax.h" +#include "flexflow/model.h" +#include "flexflow/utils/hash_utils.h" +#include "legion/legion_utilities.h" +#if defined(FF_USE_CUDA) || defined(FF_USE_HIP_CUDA) +#include "flexflow/utils/cuda_helper.h" +#else +#include "flexflow/utils/hip_helper.h" +#endif + +namespace FlexFlow { +// declare Legion names +using Legion::ArgumentMap; +using Legion::Context; +using Legion::coord_t; +using Legion::Domain; +using Legion::FutureMap; +using Legion::IndexLauncher; +using Legion::InlineLauncher; +using Legion::Machine; +using Legion::Memory; +using Legion::PhysicalRegion; +using Legion::Predicate; +using Legion::Rect; +using Legion::RegionRequirement; +using Legion::Runtime; +using Legion::Task; +using Legion::TaskArgument; +using Legion::TaskLauncher; +using PCG::Node; + +Tensor FFModel::argmax(const Tensor input, bool beam_search, char const *name) { + Layer *li = new Layer(this, + OP_ARGMAX, + input->data_type, + name, + 1 /*inputs*/, + 0 /*weights*/, + beam_search ? 
3 : 2 /*outputs*/, + input); + { + int numdims = input->num_dims; + int dims[MAX_TENSOR_DIM]; + for (int i = 0; i < numdims; i++) { + dims[i] = input->dims[i]; + } + // now just support 1 output + dims[0] = 1; + // li->outputs[0] = create_tensor_legion_ordering( + // numdims, dims, input->data_type, li, 0, true /*create_grad*/); + li->outputs[0] = create_tensor_legion_ordering( + numdims, dims, DT_INT32, li, 0, false /*create_grad*/); + // logits + li->outputs[1] = create_tensor_legion_ordering( + numdims, dims, input->data_type, li, 1, false /*create_grad*/); + + if (beam_search) { + // parent id + li->outputs[2] = create_tensor_legion_ordering( + numdims, dims, DT_INT32, li, 1, false /*create_grad*/); + } + } + li->add_int_property("beam_search", beam_search); + layers.push_back(li); + // outputs[0] = li->outputs[0]; + // outputs[1] = li->outputs[1]; + return li->outputs[0]; +} + +Op *ArgMax::create_operator_from_layer( + FFModel &model, + Layer const *layer, + std::vector const &inputs) { + long long value; + layer->get_int_property("beam_search", value); + bool beam_search = (bool)value; + return new ArgMax(model, inputs[0], beam_search, layer->name); +} + +ArgMaxParams ArgMax::get_params() const { + ArgMaxParams params; + params.beam_search = this->beam_search; + return params; +} + +bool ArgMaxParams::is_valid(ParallelTensorShape const &) const { + return true; +} + +bool operator==(ArgMaxParams const &lhs, ArgMaxParams const &rhs) { + return lhs.beam_search == rhs.beam_search; +} + +ArgMax::ArgMax(FFModel &model, + const ParallelTensor _input, + bool _beam_search, + char const *name) + : Op(model, + OP_ARGMAX, + _input->data_type, + name, + 1 /*inputs*/, + 0 /*weights*/, + _beam_search ? 3 : 2 /*outputs*/, + _input), + beam_search(_beam_search) { + int numdim = inputs[0]->num_dims; + ParallelDim dims[MAX_TENSOR_DIM]; + for (int i = 0; i < numdim; i++) { + dims[i] = inputs[0]->dims[i]; + } + dims[0].size = 1; + assert(inputs[0]->dims[0].degree == 1); + assert(inputs[0]->dims[0].parallel_idx == -1); + // outputs[0] = model.create_parallel_tensor_legion_ordering( + // numdim, dims, _input->data_type, this, 0 /*owner_idx*/); + outputs[0] = model.create_parallel_tensor_legion_ordering( + numdim, dims, DT_INT32, this, 0 /*owner_idx*/); + outputs[1] = model.create_parallel_tensor_legion_ordering( + numdim, dims, _input->data_type, this, 1 /*owner_idx*/); + if (_beam_search) { + outputs[2] = model.create_parallel_tensor_legion_ordering( + numdim, dims, DT_INT32, this, 2 /*owner_idx*/); + } +} + +ArgMax::ArgMax(FFModel &model, ArgMax const &other, const ParallelTensor input) + : ArgMax(model, input, other.beam_search, other.name) {} + +ArgMax::ArgMax(FFModel &model, + ArgMaxParams const ¶ms, + const ParallelTensor input, + char const *name) + : ArgMax(model, input, params.beam_search, name) {} + +void ArgMax::init_inference(FFModel const &ff, + std::vector const &batch_inputs, + std::vector const &batch_outputs, + MachineView const *mv) { + assert(check_output_input_weight_same_parallel_is()); + parallel_is = batch_outputs[0]->parallel_is; + ArgumentMap argmap; + Context ctx = ff.config.lg_ctx; + Runtime *runtime = ff.config.lg_hlr; + MachineView const *view = mv ? 
mv : &batch_outputs[0]->machine_view; + size_t machine_view_hash = view->hash(); + set_argumentmap_for_init_inference(ff, argmap, batch_outputs[0]); + IndexLauncher launcher(ARGMAX_INIT_TASK_ID, + parallel_is, + TaskArgument(this, sizeof(ArgMax)), + argmap, + Predicate::TRUE_PRED, + false /*must*/, + 0 /*mapper_id*/, + machine_view_hash); + launcher.add_region_requirement(RegionRequirement(batch_inputs[0]->part, + 0 /*projection id*/, + READ_WRITE, + EXCLUSIVE, + batch_inputs[0]->region)); + launcher.add_field(0, FID_DATA); + launcher.add_region_requirement(RegionRequirement(batch_outputs[0]->part, + 0 /*projection id*/, + WRITE_ONLY, + EXCLUSIVE, + batch_outputs[0]->region)); + launcher.add_field(1, FID_DATA); + launcher.add_region_requirement(RegionRequirement(batch_outputs[1]->part, + 0 /*projection id*/, + WRITE_ONLY, + EXCLUSIVE, + batch_outputs[1]->region)); + launcher.add_field(2, FID_DATA); + FutureMap fm = runtime->execute_index_space(ctx, launcher); + fm.wait_all_results(); + set_opmeta_from_futuremap_inference(ff, fm, batch_outputs[0]); +} + +void ArgMax::init(FFModel const &ff) { + assert(check_output_input_weight_same_parallel_is()); + parallel_is = outputs[0]->parallel_is; + ArgumentMap argmap; + Context ctx = ff.config.lg_ctx; + Runtime *runtime = ff.config.lg_hlr; + set_argumentmap_for_init(ff, argmap); + IndexLauncher launcher(ARGMAX_INIT_TASK_ID, + parallel_is, + TaskArgument(this, sizeof(ArgMax)), + argmap, + Predicate::TRUE_PRED, + false /*must*/, + 0 /*mapper_id*/, + outputs[0]->machine_view.hash()); + launcher.add_region_requirement(RegionRequirement(inputs[0]->part, + 0 /*projection id*/, + READ_WRITE, + EXCLUSIVE, + inputs[0]->region)); + launcher.add_field(0, FID_DATA); + launcher.add_region_requirement(RegionRequirement(outputs[0]->part, + 0 /*projection id*/, + WRITE_ONLY, + EXCLUSIVE, + outputs[0]->region)); + launcher.add_field(1, FID_DATA); + FutureMap fm = runtime->execute_index_space(ctx, launcher); + fm.wait_all_results(); + set_opmeta_from_futuremap(ff, fm); +} + +OpMeta *ArgMax::init_task(Task const *task, + std::vector const ®ions, + Context ctx, + Runtime *runtime) { + ArgMax *s = (ArgMax *)task->args; + FFHandler handle = *((FFHandler *)task->local_args); + GenericTensorAccessorW acc_input = + helperGetGenericTensorAccessorRW(s->inputs[0]->data_type, + regions[0], + task->regions[0], + FID_DATA, + ctx, + runtime); + Domain input_domain = runtime->get_index_space_domain( + ctx, task->regions[0].region.get_index_space()); + Domain output_domain = runtime->get_index_space_domain( + ctx, task->regions[2].region.get_index_space()); + + ArgMaxMeta *m = + new ArgMaxMeta(handle, s, input_domain, output_domain, acc_input); + m->profiling = s->profiling; + m->beam_search = s->beam_search; + return m; +} + +void ArgMax::forward(FFModel const &ff) { + // ArgMax does not support forward + assert(false); +} + +FutureMap ArgMax::inference(FFModel const &ff, + BatchConfig const &bc, + std::vector const &batch_inputs, + std::vector const &batch_outputs, + MachineView const *mv) { + ArgumentMap argmap; + Context ctx = ff.config.lg_ctx; + Runtime *runtime = ff.config.lg_hlr; + parallel_is = batch_outputs[0]->parallel_is; + MachineView const *view = mv ? 
mv : &batch_outputs[0]->machine_view; + set_argumentmap_for_inference(ff, argmap, batch_outputs[0]); + size_t machine_view_hash = view->hash(); + /* std::cout << "ArgMax op machine_view: " << *(MachineView const *)mv + << std::endl; */ + if (beam_search) { + IndexLauncher launcher(ARGMAX_BEAM_INF_TASK_ID, + parallel_is, + TaskArgument(&bc, sizeof(BatchConfig)), + argmap, + Predicate::TRUE_PRED, + false /*must*/, + 0 /*mapper_id*/, + machine_view_hash); + launcher.add_region_requirement(RegionRequirement(batch_inputs[0]->part, + 0 /*projection id*/, + READ_WRITE, + EXCLUSIVE, + batch_inputs[0]->region)); + launcher.add_field(0, FID_DATA); + launcher.add_region_requirement( + RegionRequirement(batch_outputs[0]->part, + 0 /*projection id*/, + WRITE_ONLY, + EXCLUSIVE, + batch_outputs[0]->region)); + launcher.add_field(1, FID_DATA); + launcher.add_region_requirement( + RegionRequirement(batch_outputs[1]->part, + 0 /*projection id*/, + WRITE_ONLY, + EXCLUSIVE, + batch_outputs[1]->region)); + launcher.add_field(2, FID_DATA); + launcher.add_region_requirement( + RegionRequirement(batch_outputs[2]->part, + 0 /*projection id*/, + WRITE_ONLY, + EXCLUSIVE, + batch_outputs[2]->region)); + launcher.add_field(3, FID_DATA); + return runtime->execute_index_space(ctx, launcher); + } else { + IndexLauncher launcher(ARGMAX_NORM_INF_TASK_ID, + parallel_is, + TaskArgument(&bc, sizeof(BatchConfig)), + argmap, + Predicate::TRUE_PRED, + false /*must*/, + 0 /*mapper_id*/, + machine_view_hash); + launcher.add_region_requirement(RegionRequirement(batch_inputs[0]->part, + 0 /*projection id*/, + READ_WRITE, + EXCLUSIVE, + batch_inputs[0]->region)); + launcher.add_field(0, FID_DATA); + launcher.add_region_requirement( + RegionRequirement(batch_outputs[0]->part, + 0 /*projection id*/, + WRITE_ONLY, + EXCLUSIVE, + batch_outputs[0]->region)); + launcher.add_field(1, FID_DATA); + launcher.add_region_requirement( + RegionRequirement(batch_outputs[1]->part, + 0 /*projection id*/, + WRITE_ONLY, + EXCLUSIVE, + batch_outputs[1]->region)); + launcher.add_field(2, FID_DATA); + return runtime->execute_index_space(ctx, launcher); + } +} + +BeamInferenceResult + ArgMax::inference_task_beam(Task const *task, + std::vector const ®ions, + Context ctx, + Runtime *runtime) { + assert(regions.size() == 4); + assert(task->regions.size() == 4); + BatchConfig const *bc = (BatchConfig *)task->args; + ArgMaxMeta const *m = *((ArgMaxMeta **)task->local_args); + + GenericTensorAccessorW input = helperGetGenericTensorAccessorRW( + m->input_type[0], regions[0], task->regions[0], FID_DATA, ctx, runtime); + GenericTensorAccessorW indices = helperGetGenericTensorAccessorWO( + DT_INT32, regions[1], task->regions[1], FID_DATA, ctx, runtime); + int batch_size = bc->num_active_tokens(); + GenericTensorAccessorW value = helperGetGenericTensorAccessorWO( + m->input_type[0], regions[2], task->regions[1], FID_DATA, ctx, runtime); + GenericTensorAccessorW parent = helperGetGenericTensorAccessorWO( + DT_INT32, regions[3], task->regions[1], FID_DATA, ctx, runtime); + ArgMax::forward_kernel_wrapper(m, input, indices, value, parent); + + BeamInferenceResult ir; + download_tensor( + indices.get_int32_ptr(), ir.token_ids, batch_size); + if (m->input_type[0] == DT_FLOAT) { + download_tensor(value.get_float_ptr(), ir.probs, batch_size); + } else if (m->input_type[0] == DT_HALF) { + download_tensor(m->probs, ir.probs, batch_size); + } + + download_tensor(parent.get_int32_ptr(), ir.parent_id, batch_size); + return ir; +} + +InferenceResult + 
ArgMax::inference_task_norm(Task const *task, + std::vector const ®ions, + Context ctx, + Runtime *runtime) { + assert(regions.size() == 3); + assert(task->regions.size() == 3); + BatchConfig const *bc = (BatchConfig *)task->args; + ArgMaxMeta const *m = *((ArgMaxMeta **)task->local_args); + + GenericTensorAccessorW input = helperGetGenericTensorAccessorRW( + m->input_type[0], regions[0], task->regions[0], FID_DATA, ctx, runtime); + GenericTensorAccessorW indices = helperGetGenericTensorAccessorWO( + DT_INT32, regions[1], task->regions[1], FID_DATA, ctx, runtime); + GenericTensorAccessorW value = helperGetGenericTensorAccessorWO( + m->input_type[0], regions[2], task->regions[1], FID_DATA, ctx, runtime); + GenericTensorAccessorW parent; + int batch_size = bc->num_active_tokens(); + ArgMax::forward_kernel_wrapper(m, input, indices, value, parent); + InferenceResult ir; + download_tensor( + indices.get_int32_ptr(), ir.token_ids, batch_size); + return ir; +} + +void ArgMax::backward(FFModel const &ff) { + // ArgMax does not support backward + assert(false); +} + +void ArgMax::serialize(Legion::Serializer &sez) const { + sez.serialize(this->beam_search); +} + +Node ArgMax::deserialize(FFModel &ff, + Legion::Deserializer &dez, + ParallelTensor inputs[], + int num_inputs) { + assert(num_inputs == 1); + bool beam_search; + dez.deserialize(beam_search); + ArgMaxParams params; + params.beam_search = beam_search; + return ff.get_or_create_node(inputs[0], params); +} + +Op *ArgMax::materialize(FFModel &ff, + ParallelTensor inputs[], + int num_inputs) const { + ArgMaxParams params = get_params(); + return new ArgMax(ff, params, inputs[0], this->name); +} + +bool ArgMax::measure_operator_cost(Simulator *sim, + MachineView const &mv, + CostMetrics &cost_metrics) const { + return false; +} + +}; // namespace FlexFlow + +namespace std { +size_t hash::operator()( + FlexFlow::ArgMaxParams const ¶ms) const { + size_t key = 0; + hash_combine(key, params.beam_search); + return key; +} +}; // namespace std \ No newline at end of file diff --git a/src/ops/argmax.cpp b/src/ops/argmax.cpp new file mode 100644 index 0000000000..1395a1cdeb --- /dev/null +++ b/src/ops/argmax.cpp @@ -0,0 +1,69 @@ +/* Copyright 2023 CMU, Facebook, LANL, MIT, NVIDIA, and Stanford (alphabetical) + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "flexflow/ops/argmax.h" +#include "flexflow/ffconst_utils.h" +#include "flexflow/utils/hip_helper.h" +#include + +namespace FlexFlow { + +/*static*/ +template +void ArgMax::forward_kernel(ArgMaxMeta const *m, + DT *input_ptr, + int *indices_ptr, + DT *prob_ptr, + int *parent_ptr, + int length, + int batch_size, + ffStream_t stream) {} + +/*static*/ +void ArgMax::forward_kernel_wrapper(ArgMaxMeta const *m, + GenericTensorAccessorW const &input, + GenericTensorAccessorW const &indices, + GenericTensorAccessorW const &value, + GenericTensorAccessorW const &parent) { + hipStream_t stream; + checkCUDA(get_legion_stream(&stream)); + + hipEvent_t t_start, t_end; + if (m->profiling) { + hipEventCreate(&t_start); + hipEventCreate(&t_end); + hipEventRecord(t_start, stream); + } + + handle_unimplemented_hip_kernel(OP_RMS_NORM); + + if (m->profiling) { + hipEventRecord(t_end, stream); + checkCUDA(hipEventSynchronize(t_end)); + float elapsed = 0; + checkCUDA(hipEventElapsedTime(&elapsed, t_start, t_end)); + hipEventDestroy(t_start); + hipEventDestroy(t_end); + } +} + +ArgMaxMeta::ArgMaxMeta(FFHandler handler, + Op const *op, + Legion::Domain const &input_domain, + Legion::Domain const &output_domain, + GenericTensorAccessorW input) + : OpMeta(handler, op) {} + +}; // namespace FlexFlow \ No newline at end of file diff --git a/src/ops/argmax.cu b/src/ops/argmax.cu new file mode 100644 index 0000000000..99487ea380 --- /dev/null +++ b/src/ops/argmax.cu @@ -0,0 +1,151 @@ +/* Copyright 2023 CMU, Facebook, LANL, MIT, NVIDIA, and Stanford (alphabetical) + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "flexflow/ffconst_utils.h" +#include "flexflow/ops/argmax.h" +#include "flexflow/utils/cuda_helper.h" + +namespace FlexFlow { + +__global__ void + half_2_float_array(half *ptr, float *ptr_f, int num_of_elements) { + CUDA_KERNEL_LOOP(i, num_of_elements) { + ptr_f[i] = __half2float(ptr[i]); + } +} + +/*static*/ +template +void ArgMax::forward_kernel(ArgMaxMeta const *m, + DT *input_ptr, + int *indices_ptr, + DT *prob_ptr, + int *parent, + int const length, + int const batch_size, + cudaStream_t stream) { + + checkCUDNN(cudnnSetStream(m->handle.dnn, stream)); + DT alpha = 1.0f, beta = 0.0f; + if (m->beam_search) { + // set all parents id zero in arg top1 case. 
+ checkCUDA(cudaMemset(parent, 0, batch_size * sizeof(int))); + } + checkCUDNN(cudnnReduceTensor(m->handle.dnn, + m->reduceMaxDesc, + indices_ptr /*indices*/, + batch_size * sizeof(int) /*indicesSizeInBytes*/, + m->handle.workSpace, + m->handle.workSpaceSize, + &alpha, + m->inputTensor, + input_ptr, + &beta, + m->outputTensor, + prob_ptr)); +} + +/*static*/ +void ArgMax::forward_kernel_wrapper(ArgMaxMeta const *m, + GenericTensorAccessorW const &input, + GenericTensorAccessorW const &indices, + GenericTensorAccessorW const &value, + GenericTensorAccessorW const &parent) { + cudaStream_t stream; + checkCUDA(get_legion_stream(&stream)); + + cudaEvent_t t_start, t_end; + if (m->profiling) { + cudaEventCreate(&t_start); + cudaEventCreate(&t_end); + cudaEventRecord(t_start, stream); + } + int length = input.domain.hi()[0] - input.domain.lo()[0] + 1; + int batch_size = input.domain.get_volume() / length; + + if (input.data_type == DT_HALF) { + ArgMax::forward_kernel(m, + input.get_half_ptr(), + indices.get_int32_ptr(), + value.get_half_ptr(), + m->beam_search ? parent.get_int32_ptr() + : nullptr, + length, + batch_size, + stream); + if (m->beam_search) { + half_2_float_array<<>>( + value.get_half_ptr(), m->probs, batch_size); + } + + } else if (input.data_type == DT_FLOAT) { + ArgMax::forward_kernel(m, + input.get_float_ptr(), + indices.get_int32_ptr(), + value.get_float_ptr(), + m->beam_search ? parent.get_int32_ptr() + : nullptr, + length, + batch_size, + stream); + } else { + assert(false && "Unsupported data type"); + } + + if (m->profiling) { + cudaEventRecord(t_end, stream); + checkCUDA(cudaEventSynchronize(t_end)); + float elapsed = 0; + checkCUDA(cudaEventElapsedTime(&elapsed, t_start, t_end)); + cudaEventDestroy(t_start); + cudaEventDestroy(t_end); + printf("[ArgMax] forward time = %.2lfms\n", elapsed); + } +} + +ArgMaxMeta::ArgMaxMeta(FFHandler handler, + Op const *op, + Legion::Domain const &input_domain, + Legion::Domain const &output_domain, + GenericTensorAccessorW input) + : OpMeta(handler, op) { + DataType data_type = op->data_type; + checkCUDNN(cudnnCreateTensorDescriptor(&inputTensor)); + checkCUDNN(cudnnCreateTensorDescriptor(&outputTensor)); + checkCUDNN(cudnnCreateReduceTensorDescriptor(&reduceMaxDesc)); + + // Float and Half use save type, according to + // https://docs.nvidia.com/deeplearning/cudnn/api/index.html#cudnnReduceTensor:~:text=not%20coordinate%20tuples.-,The%20data%20types%20of%20the%20tensors,.,-Note%3A + cudnnDataType_t cudnn_data_type = CUDNN_DATA_FLOAT; + + checkCUDNN( + cudnnSetReduceTensorDescriptor(reduceMaxDesc, + CUDNN_REDUCE_TENSOR_MAX, + cudnn_data_type, + CUDNN_PROPAGATE_NAN, + CUDNN_REDUCE_TENSOR_FLATTENED_INDICES, + CUDNN_32BIT_INDICES)); + checkCUDNN(cudnnSetTensorDescriptorFromDomain( + outputTensor, output_domain, data_type)); + checkCUDNN( + cudnnSetTensorDescriptorFromDomain(inputTensor, input_domain, data_type)); + + checkCUDA(cudaMalloc(&probs, sizeof(float) * BatchConfig::MAX_NUM_TOKENS)); +} + +}; // namespace FlexFlow \ No newline at end of file diff --git a/src/runtime/cuda_helper.cu b/src/runtime/cuda_helper.cu index dff5157a8a..da22a245f1 100644 --- a/src/runtime/cuda_helper.cu +++ b/src/runtime/cuda_helper.cu @@ -219,6 +219,7 @@ __host__ void cudaHostAllocPortable | cudaHostAllocMapped)); checkCUDA(cudaMemcpyAsync( host_ptr, ptr, sizeof(T) * num_elements, cudaMemcpyDeviceToHost, stream)); + cudaDeviceSynchronize(); int idx = 0; printf("%s", prefix); for (idx = 0; idx < num_elements; idx++) { diff --git a/src/runtime/ffconst_utils.cc 
b/src/runtime/ffconst_utils.cc index a777605daf..35ec59ce03 100644 --- a/src/runtime/ffconst_utils.cc +++ b/src/runtime/ffconst_utils.cc @@ -180,6 +180,8 @@ std::string get_operator_type_name(OperatorType type) { return "Identity"; case OP_SAMPLING: return "Sampling"; + case OP_ARGMAX: + return "ArgMax"; // Parallel Ops case OP_REPARTITION: return "Repartition"; diff --git a/src/runtime/graph.cc b/src/runtime/graph.cc index 16bccc25df..a82add4b62 100644 --- a/src/runtime/graph.cc +++ b/src/runtime/graph.cc @@ -17,6 +17,7 @@ #include "flexflow/ffconst_utils.h" #include "flexflow/ops/aggregate.h" #include "flexflow/ops/arg_topk.h" +#include "flexflow/ops/argmax.h" #include "flexflow/ops/attention.h" #include "flexflow/ops/batch_matmul.h" #include "flexflow/ops/beam_topk.h" @@ -2924,6 +2925,10 @@ void FFModel::deserialize_graph_optimal_view( node = Sampling::deserialize(*this, dez, inputs, num_inputs); break; } + case OP_ARGMAX: { + node = ArgMax::deserialize(*this, dez, inputs, num_inputs); + break; + } case OP_GROUP_BY: { node = Group_by::deserialize(*this, dez, inputs, num_inputs); break; diff --git a/src/runtime/model.cc b/src/runtime/model.cc index 22515a2bb0..66cad1f248 100644 --- a/src/runtime/model.cc +++ b/src/runtime/model.cc @@ -25,6 +25,7 @@ #include "flexflow/ops/aggregate.h" #include "flexflow/ops/aggregate_spec.h" #include "flexflow/ops/arg_topk.h" +#include "flexflow/ops/argmax.h" #include "flexflow/ops/attention.h" #include "flexflow/ops/batch_matmul.h" #include "flexflow/ops/batch_norm.h" @@ -2943,6 +2944,11 @@ Op *FFModel::create_operator_from_layer( operators.push_back(op); return op; } + case OP_ARGMAX: { + Op *op = ArgMax::create_operator_from_layer(*this, layer, inputs); + operators.push_back(op); + return op; + } case OP_GROUP_BY: { Op *op = Group_by::create_operator_from_layer(*this, layer, inputs); operators.push_back(op); @@ -2984,7 +2990,8 @@ void FFModel::create_operators_from_layers() { // add a combine before arg_topk if (config.computationMode == COMP_MODE_INFERENCE && config.tensor_parallelism_degree > 1 && - (l->op_type == OP_ARG_TOPK || l->op_type == OP_SOFTMAX)) { + (l->op_type == OP_ARG_TOPK || l->op_type == OP_SOFTMAX || + l->op_type == OP_ARGMAX)) { std::vector partitioned_inputs; assert(inputs.size() == 1); Combine *comb = new Combine(*this, @@ -5444,6 +5451,56 @@ void register_flexflow_internal_tasks(Runtime *runtime, registrar); } } + // ArgMax task + { + TaskVariantRegistrar registrar(ARGMAX_INIT_TASK_ID, "ArgMax Init"); + registrar.add_constraint(ProcessorConstraint(Processor::TOC_PROC)); + registrar.set_leaf(); + if (pre_register) { + Runtime::preregister_task_variant( + registrar, "ArgMax Init Task"); + } else { + if (enable_control_replication) { + registrar.global_registration = false; + } + runtime->register_task_variant(registrar); + } + } + { + TaskVariantRegistrar registrar(ARGMAX_BEAM_INF_TASK_ID, + "ArgMax Beam Inference"); + registrar.add_constraint(ProcessorConstraint(Processor::TOC_PROC)); + registrar.set_leaf(); + if (pre_register) { + Runtime::preregister_task_variant( + registrar, "ArgMax Inference Task Beam"); + } else { + if (enable_control_replication) { + registrar.global_registration = false; + } + runtime->register_task_variant(registrar); + } + } + { + TaskVariantRegistrar registrar(ARGMAX_NORM_INF_TASK_ID, + "ArgMax Norm Inference"); + registrar.add_constraint(ProcessorConstraint(Processor::TOC_PROC)); + registrar.set_leaf(); + if (pre_register) { + Runtime::preregister_task_variant( + registrar, "ArgMax Inference Task 
Norm"); + } else { + if (enable_control_replication) { + registrar.global_registration = false; + } + runtime + ->register_task_variant( + registrar); + } + } // Transpose task { TaskVariantRegistrar registrar(TRANSPOSE_INIT_TASK_ID, "Transpose Init"); diff --git a/src/runtime/operator_params.cc b/src/runtime/operator_params.cc index 8fb8c89b10..bf817f5351 100644 --- a/src/runtime/operator_params.cc +++ b/src/runtime/operator_params.cc @@ -2,6 +2,7 @@ #include "flexflow/ops/aggregate.h" #include "flexflow/ops/aggregate_spec.h" #include "flexflow/ops/arg_topk.h" +#include "flexflow/ops/argmax.h" #include "flexflow/ops/attention.h" #include "flexflow/ops/batch_matmul.h" #include "flexflow/ops/batch_norm.h" @@ -133,6 +134,8 @@ tl::optional get_op_parameters(Op const *op) { return ((BeamTopK *)op)->get_params(); case OP_SAMPLING: return ((Sampling *)op)->get_params(); + case OP_ARGMAX: + return ((ArgMax *)op)->get_params(); // TODO: implement the get_params() function for the operators below and // uncomment the lines below