Skip to content

Commit

Permalink
Support speculative inference (beam-search argmax path).
Browse files Browse the repository at this point in the history
  • Loading branch information
xinhaoc committed Jul 26, 2023
1 parent d0d61b5 commit 29776bb
Show file tree
Hide file tree
Showing 10 changed files with 262 additions and 92 deletions.
5 changes: 3 additions & 2 deletions include/flexflow/model.h
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,8 @@ enum TaskIDs {
SAMPLING_INIT_TASK_ID,
SAMPLING_INF_TASK_ID,
ARGMAX_INIT_TASK_ID,
ARGMAX_INF_TASK_ID,
ARGMAX_BEAM_INF_TASK_ID,
ARGMAX_NORM_INF_TASK_ID,
TRANSPOSE_INIT_TASK_ID,
TRANSPOSE_FWD_TASK_ID,
TRANSPOSE_BWD_TASK_ID,
Expand Down Expand Up @@ -618,7 +619,7 @@ class FFModel {
int k,
bool sorted,
char const *name = NULL);
Tensor argmax(const Tensor input, char const *name = NULL);
Tensor argmax(const Tensor input, bool beam_search, char const *name = NULL);
Tensor sampling(const Tensor input, float top_p, char const *name = NULL);
Tensor multihead_attention(const Tensor query,
const Tensor key,
Expand Down
25 changes: 19 additions & 6 deletions include/flexflow/ops/argmax.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ namespace FlexFlow {

class ArgMaxMeta : public OpMeta {
public:
bool beam_search;
#if defined(FF_USE_CUDA) || defined(FF_USE_HIP_CUDA)
cudnnTensorDescriptor_t inputTensor, outputTensor;
cudnnReduceTensorDescriptor_t reduceMaxDesc;
Expand All @@ -28,7 +29,10 @@ class ArgMax : public Op {
public:
using Params = ArgMaxParams;
using Input = ParallelTensor;
ArgMax(FFModel &model, const ParallelTensor input, char const *name);
ArgMax(FFModel &model,
const ParallelTensor input,
bool beam_search,
char const *name);
ArgMax(FFModel &model, ArgMax const &other, const ParallelTensor input);
ArgMax(FFModel &model,
Params const &params,
Expand Down Expand Up @@ -58,11 +62,16 @@ class ArgMax : public Op {
std::vector<Legion::PhysicalRegion> const &regions,
Legion::Context ctx,
Legion::Runtime *runtime);
static BeamInferenceResult
inference_task_beam(Legion::Task const *task,
std::vector<Legion::PhysicalRegion> const &regions,
Legion::Context ctx,
Legion::Runtime *runtime);
static InferenceResult
inference_task(Legion::Task const *task,
std::vector<Legion::PhysicalRegion> const &regions,
Legion::Context ctx,
Legion::Runtime *runtime);
inference_task_norm(Legion::Task const *task,
std::vector<Legion::PhysicalRegion> const &regions,
Legion::Context ctx,
Legion::Runtime *runtime);
void serialize(Legion::Serializer &s) const override;
static PCG::Node deserialize(FFModel &ff,
Legion::Deserializer &d,
Expand All @@ -78,16 +87,20 @@ class ArgMax : public Op {
static void forward_kernel(ArgMaxMeta const *m,
DT *input_ptr,
int *indices_ptr,
float *prob_ptr,
int *parent_ptr,
int length,
int batch_size,
ffStream_t stream);
static void forward_kernel_wrapper(ArgMaxMeta const *m,
GenericTensorAccessorW const &input,
GenericTensorAccessorW const &indices,
int batch_size);
GenericTensorAccessorW const &value,
GenericTensorAccessorW const &parent);
Params get_params() const;

public:
bool beam_search;
};

}; // namespace FlexFlow
Expand Down
2 changes: 1 addition & 1 deletion include/flexflow/ops/argmax_params.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
namespace FlexFlow {

struct ArgMaxParams {
OperatorType op_type;
bool beam_search;
bool is_valid(ParallelTensorShape const &) const;
};
bool operator==(ArgMaxParams const &, ArgMaxParams const &);
Expand Down
6 changes: 3 additions & 3 deletions inference/models/llama.cc
Original file line number Diff line number Diff line change
Expand Up @@ -175,9 +175,9 @@ void LLAMA::create_llama_model(FFModel &ff,

Tensor output;
if (mode == BEAM_SEARCH_MODE) {
// Tensor softmax = ff.softmax(dense, -1);
Tensor softmax = ff.softmax(dense, -1);
// output = ff.beam_top_k(softmax, llama_config.max_beam_width, false);
output = ff.argmax(dense);
output = ff.argmax(softmax, /*beam_Search*/true);
} else {
// Tensor softmax = ff.softmax(dense, -1);
if (samplingConfig.do_sample) {
Expand All @@ -186,7 +186,7 @@ void LLAMA::create_llama_model(FFModel &ff,
output = ff.sampling(softmax, samplingConfig.topp);
} else {
// output = ff.arg_top_k(dense, /*k=*/1, false);
output = ff.argmax(dense);
output = ff.argmax(dense, /*beam_Search*/false);
}
}

Expand Down
6 changes: 3 additions & 3 deletions inference/models/opt.cc
Original file line number Diff line number Diff line change
Expand Up @@ -214,12 +214,12 @@ void OPT::create_opt_model(FFModel &ff,

Tensor output;
if (mode == BEAM_SEARCH_MODE) {
// Tensor softmax = ff.softmax(lm_head, -1);
Tensor softmax = ff.softmax(lm_head, -1);
// output = ff.beam_top_k(softmax, opt_config.max_beam_width, false);
output = ff.argmax(lm_head);
output = ff.argmax(softmax, /*beam_Search*/ true);
} else {
// output = ff.arg_top_k(lm_head, /*k=*/1, false);
output = ff.argmax(lm_head);
output = ff.argmax(lm_head, /*beam_Search*/ false);
}

//------------------- compile the model --------------------------------
Expand Down
Loading

0 comments on commit 29776bb

Please sign in to comment.