Inference: add argmax operator (#888)
* add argmax operator

* support spec infer.

* format

* remove redundant

* half precision

* fix

* fix

* hip_rocm
xinhaoc authored Jul 27, 2023
1 parent aef158a commit 6b7e6f0
Showing 15 changed files with 882 additions and 5 deletions.
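At a glance, the commit adds a dedicated argmax operator to FlexFlow's inference path and switches the LLaMA and OPT model builders over to it. A minimal usage sketch of the new FFModel::argmax API (signature in the model.h diff below); here ff and dense stand in for a model and its logits tensor, as in the model files changed by this commit:

// Greedy (incremental) decoding: take the argmax over the vocabulary.
Tensor token = ff.argmax(dense, /*beam_search=*/false);

// Beam-search mode for speculative inference: run argmax on the softmax
// output so the op can also report token probabilities and parent ids.
Tensor softmax = ff.softmax(dense, -1);
Tensor beam_token = ff.argmax(softmax, /*beam_search=*/true);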
1 change: 1 addition & 0 deletions include/flexflow/ffconst.h
@@ -163,6 +163,7 @@ enum OperatorType {
   OP_GATHER, // https://pytorch.org/docs/stable/generated/torch.gather.html
   OP_RMS_NORM,
   OP_BEAM_TOPK,
+  OP_ARGMAX,
   OP_INC_MULTIHEAD_SELF_ATTENTION,
   OP_SPEC_INC_MULTIHEAD_SELF_ATTENTION,
   OP_TREE_INC_MULTIHEAD_SELF_ATTENTION,
7 changes: 7 additions & 0 deletions include/flexflow/model.h
@@ -138,6 +138,9 @@ enum TaskIDs {
   ARG_TOPK_INF_TASK_ID,
   SAMPLING_INIT_TASK_ID,
   SAMPLING_INF_TASK_ID,
+  ARGMAX_INIT_TASK_ID,
+  ARGMAX_BEAM_INF_TASK_ID,
+  ARGMAX_NORM_INF_TASK_ID,
   TRANSPOSE_INIT_TASK_ID,
   TRANSPOSE_FWD_TASK_ID,
   TRANSPOSE_BWD_TASK_ID,
@@ -315,6 +318,7 @@ class BeamTopK;
 class SpecIncMultiHeadSelfAttention;
 class IncMultiQuerySelfAttention;
 class Sampling;
+class ArgMax;
 class Combine;
 class Repartition;
 class Reduction;
@@ -615,6 +619,7 @@ class FFModel {
                  int k,
                  bool sorted,
                  char const *name = NULL);
+  Tensor argmax(const Tensor input, bool beam_search, char const *name = NULL);
   Tensor sampling(const Tensor input, float top_p, char const *name = NULL);
   Tensor multihead_attention(const Tensor query,
                              const Tensor key,
@@ -1067,6 +1072,8 @@ class FFModel {
                          BeamTopK *>,
       std::unordered_map<std::pair<ParallelTensorShape, SamplingParams>,
                          Sampling *>,
+      std::unordered_map<std::pair<ParallelTensorShape, ArgMaxParams>,
+                         ArgMax *>,
       std::unordered_map<
           std::pair<ParallelTensorShape, SpecIncMultiHeadSelfAttentionParams>,
           SpecIncMultiHeadSelfAttention *>,
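The two unordered_map entries added above extend FFModel's cache of materialized parallel operators, keyed by (input shape, op params). A rough sketch of the get-or-create lookup this enables; the member name cached_argmax_ops and the surrounding code are illustrative, not part of the diff:

// Hypothetical sketch: reuse an ArgMax node when one with the same
// input shape and params already exists in the cache.
ArgMaxParams params;
params.beam_search = beam_search;
std::pair<ParallelTensorShape, ArgMaxParams> key(input->get_shape(), params);
auto it = cached_argmax_ops.find(key);   // illustrative member name
ArgMax *op;
if (it != cached_argmax_ops.end()) {
  op = it->second;                        // cache hit: reuse the operator
} else {
  op = new ArgMax(*this, params, input, name);
  cached_argmax_ops[key] = op;            // cache miss: memoize the new node
}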
2 changes: 2 additions & 0 deletions include/flexflow/operator_params.h
@@ -4,6 +4,7 @@
#include "flexflow/ops/aggregate_params.h"
#include "flexflow/ops/aggregate_spec_params.h"
#include "flexflow/ops/arg_topk_params.h"
#include "flexflow/ops/argmax_params.h"
#include "flexflow/ops/attention_params.h"
#include "flexflow/ops/batch_matmul_params.h"
#include "flexflow/ops/beam_topk_params.h"
@@ -73,6 +74,7 @@ using OperatorParameters = mp::variant<AggregateParams,
                                        TopKParams,
                                        ArgTopKParams,
                                        SamplingParams,
+                                       ArgMaxParams,
                                        SoftmaxParams,
                                        TransposeParams,
                                        RepartitionParams,
109 changes: 109 additions & 0 deletions include/flexflow/ops/argmax.h
@@ -0,0 +1,109 @@
#ifndef _FLEXFLOW_ARG_MAX_H_
#define _FLEXFLOW_ARG_MAX_H_

#include "flexflow/inference.h"
#include "flexflow/model.h"
#include "flexflow/node.h"
#include "flexflow/ops/argmax_params.h"

namespace FlexFlow {

class ArgMaxMeta : public OpMeta {
public:
  bool beam_search;
  float *probs;
#if defined(FF_USE_CUDA) || defined(FF_USE_HIP_CUDA)
  cudnnTensorDescriptor_t inputTensor, outputTensor;
  cudnnReduceTensorDescriptor_t reduceMaxDesc;
#else
  miopenTensorDescriptor_t inputTensor, outputTensor;
  miopenReduceTensorDescriptor_t reduceMaxDesc;
#endif
  ArgMaxMeta(FFHandler handler,
             Op const *op,
             Legion::Domain const &input_domain,
             Legion::Domain const &output_domain,
             GenericTensorAccessorW input);
};

class ArgMax : public Op {
public:
  using Params = ArgMaxParams;
  using Input = ParallelTensor;
  ArgMax(FFModel &model,
         const ParallelTensor input,
         bool beam_search,
         char const *name);
  ArgMax(FFModel &model, ArgMax const &other, const ParallelTensor input);
  ArgMax(FFModel &model,
         Params const &params,
         Input const input,
         char const *name = nullptr);
  void init(FFModel const &) override;
  void init_inference(FFModel const &,
                      std::vector<ParallelTensor> const &,
                      std::vector<ParallelTensor> const &,
                      MachineView const *mv = nullptr) override;
  void forward(FFModel const &) override;
  void backward(FFModel const &) override;
  Legion::FutureMap inference(FFModel const &,
                              BatchConfig const &,
                              std::vector<ParallelTensor> const &,
                              std::vector<ParallelTensor> const &,
                              MachineView const *mv = nullptr) override;
  void print_layer(FFModel const &model) override {
    assert(0);
  }
  static Op *
      create_operator_from_layer(FFModel &model,
                                 Layer const *layer,
                                 std::vector<ParallelTensor> const &inputs);

  static OpMeta *init_task(Legion::Task const *task,
                           std::vector<Legion::PhysicalRegion> const &regions,
                           Legion::Context ctx,
                           Legion::Runtime *runtime);
  static BeamInferenceResult
      inference_task_beam(Legion::Task const *task,
                          std::vector<Legion::PhysicalRegion> const &regions,
                          Legion::Context ctx,
                          Legion::Runtime *runtime);
  static InferenceResult
      inference_task_norm(Legion::Task const *task,
                          std::vector<Legion::PhysicalRegion> const &regions,
                          Legion::Context ctx,
                          Legion::Runtime *runtime);
  void serialize(Legion::Serializer &s) const override;
  static PCG::Node deserialize(FFModel &ff,
                               Legion::Deserializer &d,
                               ParallelTensor inputs[],
                               int num_inputs);
  Op *materialize(FFModel &ff,
                  ParallelTensor inputs[],
                  int num_inputs) const override;
  bool measure_operator_cost(Simulator *sim,
                             MachineView const &pc,
                             CostMetrics &cost_metrics) const override;
  template <typename DT>
  static void forward_kernel(ArgMaxMeta const *m,
                             DT *input_ptr,
                             int *indices_ptr,
                             DT *prob_ptr,
                             int *parent_ptr,
                             int length,
                             int batch_size,
                             ffStream_t stream);
  static void forward_kernel_wrapper(ArgMaxMeta const *m,
                                     GenericTensorAccessorW const &input,
                                     GenericTensorAccessorW const &indices,
                                     GenericTensorAccessorW const &value,
                                     GenericTensorAccessorW const &parent);
  Params get_params() const;

public:
  bool beam_search;
};

}; // namespace FlexFlow

#endif
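ArgMaxMeta carries a cudnnReduceTensorDescriptor_t (or its MIOpen counterpart under HIP), which points at cuDNN's max-reduction with flattened indices as the mechanism behind the argmax. A sketch of how that descriptor might be configured and used, assuming float inputs; this is illustrative and not necessarily the commit's exact kernel code (the header also declares a templated forward_kernel), and indices_size_bytes / workspace / value_ptr are placeholder names:

// Configure a max-reduction that also emits the flattened index of the
// maximum element -- i.e., the argmax over the reduced dimension.
checkCUDNN(cudnnCreateReduceTensorDescriptor(&m->reduceMaxDesc));
checkCUDNN(cudnnSetReduceTensorDescriptor(m->reduceMaxDesc,
                                          CUDNN_REDUCE_TENSOR_MAX,
                                          CUDNN_DATA_FLOAT,
                                          CUDNN_PROPAGATE_NAN,
                                          CUDNN_REDUCE_TENSOR_FLATTENED_INDICES,
                                          CUDNN_32BIT_INDICES));

// In the forward path: reduce (length, batch) -> (1, batch), writing the
// max values to value_ptr and the argmax indices to indices_ptr.
float alpha = 1.0f, beta = 0.0f;
checkCUDNN(cudnnReduceTensor(m->handle.dnn,
                             m->reduceMaxDesc,
                             indices_ptr,          // int32 indices out
                             indices_size_bytes,
                             workspace,
                             workspace_size_bytes,
                             &alpha,
                             m->inputTensor,
                             input_ptr,
                             &beta,
                             m->outputTensor,
                             value_ptr));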
24 changes: 24 additions & 0 deletions include/flexflow/ops/argmax_params.h
@@ -0,0 +1,24 @@
#ifndef _FLEXFLOW_ARGMAX_PARAMS_H
#define _FLEXFLOW_ARGMAX_PARAMS_H

#include "flexflow/ffconst.h"
#include "flexflow/parallel_tensor.h"

namespace FlexFlow {

struct ArgMaxParams {
  bool beam_search;
  bool is_valid(ParallelTensorShape const &) const;
};
bool operator==(ArgMaxParams const &, ArgMaxParams const &);

} // namespace FlexFlow

namespace std {
template <>
struct hash<FlexFlow::ArgMaxParams> {
  size_t operator()(FlexFlow::ArgMaxParams const &) const;
};
} // namespace std

#endif // _FLEXFLOW_ARGMAX_PARAMS_H
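The params struct is a single flag, so the declared operator== and std::hash specialization are small. A plausible implementation for the corresponding .cc file, sketched in the style of FlexFlow's other params types rather than copied from the commit (the assumption that argmax places no extra constraints on the input shape is mine):

namespace FlexFlow {

bool operator==(ArgMaxParams const &lhs, ArgMaxParams const &rhs) {
  return lhs.beam_search == rhs.beam_search;
}

bool ArgMaxParams::is_valid(ParallelTensorShape const &input) const {
  return input.is_valid(); // assumed: no extra shape constraints for argmax
}

} // namespace FlexFlow

namespace std {
size_t hash<FlexFlow::ArgMaxParams>::operator()(
    FlexFlow::ArgMaxParams const &params) const {
  return hash<bool>{}(params.beam_search);
}
} // namespace std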
6 changes: 4 additions & 2 deletions inference/models/llama.cc
@@ -176,15 +176,17 @@ void LLAMA::create_llama_model(FFModel &ff,
   Tensor output;
   if (mode == BEAM_SEARCH_MODE) {
     Tensor softmax = ff.softmax(dense, -1);
-    output = ff.beam_top_k(softmax, llama_config.max_beam_width, false);
+    // output = ff.beam_top_k(softmax, llama_config.max_beam_width, false);
+    output = ff.argmax(softmax, /*beam_Search*/ true);
   } else {
     // Tensor softmax = ff.softmax(dense, -1);
     if (samplingConfig.do_sample) {
       dense = ff.scalar_truediv(dense, samplingConfig.temperature, false);
       Tensor softmax = ff.softmax(dense, -1);
       output = ff.sampling(softmax, samplingConfig.topp);
     } else {
-      output = ff.arg_top_k(dense, /*k=*/1, false);
+      // output = ff.arg_top_k(dense, /*k=*/1, false);
+      output = ff.argmax(dense, /*beam_Search*/ false);
     }
   }

6 changes: 4 additions & 2 deletions inference/models/opt.cc
@@ -215,9 +215,11 @@ void OPT::create_opt_model(FFModel &ff,
   Tensor output;
   if (mode == BEAM_SEARCH_MODE) {
     Tensor softmax = ff.softmax(lm_head, -1);
-    output = ff.beam_top_k(softmax, opt_config.max_beam_width, false);
+    // output = ff.beam_top_k(softmax, opt_config.max_beam_width, false);
+    output = ff.argmax(softmax, /*beam_Search*/ true);
   } else {
-    output = ff.arg_top_k(lm_head, /*k=*/1, false);
+    // output = ff.arg_top_k(lm_head, /*k=*/1, false);
+    output = ff.argmax(lm_head, /*beam_Search*/ false);
   }

   //------------------- compile the model --------------------------------
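Tying the pieces together: the two inference task IDs registered in model.h (ARGMAX_BEAM_INF_TASK_ID and ARGMAX_NORM_INF_TASK_ID) correspond to the op's two static task bodies, inference_task_beam and inference_task_norm. A simplified sketch of how ArgMax::inference might pick between them based on the beam_search flag; the real launcher also wires up region requirements, argument maps, and accessors, which are omitted here:

// Simplified: dispatch to the beam-search or greedy inference task.
TaskID task_id = beam_search ? ARGMAX_BEAM_INF_TASK_ID
                             : ARGMAX_NORM_INF_TASK_ID;
IndexLauncher launcher(task_id,
                       parallel_is,
                       TaskArgument(nullptr, 0),
                       argmap,
                       Predicate::TRUE_PRED,
                       false /*must_epoch*/,
                       0 /*mapper_id*/,
                       machine_view_hash);
// Region requirements for input/indices/value/parent omitted.
return runtime->execute_index_space(ctx, launcher);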