Merge branch 'inference' into fix_batch_size
xinhaoc authored Jul 19, 2023
2 parents 5c2e9ae + d3cd370 commit 3d8af13
Showing 18 changed files with 919 additions and 3 deletions.
1 change: 1 addition & 0 deletions include/flexflow/ffconst.h
@@ -167,6 +167,7 @@ enum OperatorType {
OP_SPEC_INC_MULTIHEAD_SELF_ATTENTION,
OP_TREE_INC_MULTIHEAD_SELF_ATTENTION,
OP_INC_MULTIQUERY_SELF_ATTENTION,
OP_SAMPLING,
// Parallel Ops
OP_REPARTITION,
OP_COMBINE,
12 changes: 12 additions & 0 deletions include/flexflow/inference.h
@@ -65,6 +65,18 @@ struct BeamTree {
treeLayer treeLayers[BeamSearchBatchConfig::MAX_BEAM_DEPTH + 1];
};

struct SamplingConfig {
bool do_sample = false;
float temperature = 0.8;
float topp = 0.6;
SamplingConfig(bool _do_sample, float _temperature, float _topp) {
temperature = _temperature > 0 ? _temperature : temperature;
topp = _topp > 0 ? _topp : topp;
do_sample = _do_sample;
}
SamplingConfig() {}
};

// struct BeamTree_v2 {
// std::vector<BatchConfig::TokenId> tokens;
// std::vector<int> parent_ids;
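The new SamplingConfig keeps caller-supplied values only when they are positive; a non-positive temperature or top-p falls back to the defaults of 0.8 and 0.6, and do_sample defaults to false. A minimal usage sketch (the values shown are illustrative, not defaults used anywhere in this commit):

SamplingConfig greedy;                      // do_sample = false: decoding stays greedy
SamplingConfig nucleus(true, 0.9f, 0.95f);  // sample with T = 0.9 and top-p = 0.95
SamplingConfig fallback(true, 0.0f, -1.0f); // non-positive inputs keep T = 0.8, top-p = 0.6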
6 changes: 6 additions & 0 deletions include/flexflow/model.h
@@ -136,6 +136,8 @@ enum TaskIDs {
TOPK_BWD_TASK_ID,
ARG_TOPK_INIT_TASK_ID,
ARG_TOPK_INF_TASK_ID,
SAMPLING_INIT_TASK_ID,
SAMPLING_INF_TASK_ID,
TRANSPOSE_INIT_TASK_ID,
TRANSPOSE_FWD_TASK_ID,
TRANSPOSE_BWD_TASK_ID,
@@ -312,6 +314,7 @@ class RMSNorm;
class BeamTopK;
class SpecIncMultiHeadSelfAttention;
class IncMultiQuerySelfAttention;
class Sampling;
class Combine;
class Repartition;
class Reduction;
@@ -612,6 +615,7 @@ class FFModel {
int k,
bool sorted,
char const *name = NULL);
Tensor sampling(const Tensor input, float top_p, char const *name = NULL);
Tensor multihead_attention(const Tensor query,
const Tensor key,
const Tensor value,
@@ -1061,6 +1065,8 @@ class FFModel {
IncMultiQuerySelfAttention *>,
std::unordered_map<std::pair<ParallelTensorShape, BeamTopKParams>,
BeamTopK *>,
std::unordered_map<std::pair<ParallelTensorShape, SamplingParams>,
Sampling *>,
std::unordered_map<
std::pair<ParallelTensorShape, SpecIncMultiHeadSelfAttentionParams>,
SpecIncMultiHeadSelfAttention *>,
2 changes: 2 additions & 0 deletions include/flexflow/operator_params.h
@@ -26,6 +26,7 @@
#include "flexflow/ops/reduce_params.h"
#include "flexflow/ops/reshape_params.h"
#include "flexflow/ops/rms_norm_params.h"
#include "flexflow/ops/sampling_params.h"
#include "flexflow/ops/softmax_params.h"
#include "flexflow/ops/spec_inc_multihead_self_attention_params.h"
#include "flexflow/ops/split_params.h"
@@ -71,6 +72,7 @@ using OperatorParameters = mp::variant<AggregateParams,
SplitParams,
TopKParams,
ArgTopKParams,
SamplingParams,
SoftmaxParams,
TransposeParams,
RepartitionParams,
108 changes: 108 additions & 0 deletions include/flexflow/ops/sampling.h
@@ -0,0 +1,108 @@
#ifndef _FLEXFLOW_SAMPLING_TOPK_H_
#define _FLEXFLOW_SAMPLING_TOPK_H_

#include "flexflow/inference.h"
#include "flexflow/model.h"
#include "flexflow/node.h"
#include "flexflow/ops/sampling_params.h"
#if defined(FF_USE_CUDA) || defined(FF_USE_HIP_CUDA)
#include <curand.h>
#include <curand_kernel.h>
#endif

namespace FlexFlow {

class SamplingMeta : public OpMeta {
public:
float top_p;
void *sorted_logits;
int *sorted_idx;
int *begin_offset;
int *end_offset;
int *idx;
void *d_temp_storage;
size_t temp_storage_bytes;
#if defined(FF_USE_CUDA) || defined(FF_USE_HIP_CUDA)
curandState *state;
#endif
SamplingMeta(FFHandler handle,
Op const *op,
int batch_size,
int total_ele,
GenericTensorAccessorW input);
};

class Sampling : public Op {
public:
using Params = SamplingParams;
using Input = ParallelTensor;
Sampling(FFModel &model,
const ParallelTensor input,
float top_p,
char const *name);
Sampling(FFModel &model, Sampling const &other, const ParallelTensor input);
Sampling(FFModel &model,
Params const &params,
Input const input,
char const *name = nullptr);
void init(FFModel const &) override;
void init_inference(FFModel const &,
std::vector<ParallelTensor> const &,
std::vector<ParallelTensor> const &,
MachineView const *mv = nullptr) override;
void forward(FFModel const &) override;
void backward(FFModel const &) override;
Legion::FutureMap inference(FFModel const &,
BatchConfig const &,
std::vector<ParallelTensor> const &,
std::vector<ParallelTensor> const &,
MachineView const *mv = nullptr) override;
void print_layer(FFModel const &model) override {
assert(0);
}
static Op *
create_operator_from_layer(FFModel &model,
Layer const *layer,
std::vector<ParallelTensor> const &inputs);

static OpMeta *init_task(Legion::Task const *task,
std::vector<Legion::PhysicalRegion> const &regions,
Legion::Context ctx,
Legion::Runtime *runtime);
static InferenceResult
inference_task(Legion::Task const *task,
std::vector<Legion::PhysicalRegion> const &regions,
Legion::Context ctx,
Legion::Runtime *runtime);
void serialize(Legion::Serializer &s) const override;
static PCG::Node deserialize(FFModel &ff,
Legion::Deserializer &d,
ParallelTensor inputs[],
int num_inputs);
Op *materialize(FFModel &ff,
ParallelTensor inputs[],
int num_inputs) const override;
bool measure_operator_cost(Simulator *sim,
MachineView const &pc,
CostMetrics &cost_metrics) const override;
template <typename DT>
static void forward_kernel(SamplingMeta const *m,
DT *input_ptr,
int *indices_ptr,
float top_p,
int length,
int batch_size,
ffStream_t stream);
static void forward_kernel_wrapper(SamplingMeta const *m,
GenericTensorAccessorW const &input,
GenericTensorAccessorW const &indices,
int batch_size);
Params get_params() const;

public:
float top_p;
};

}; // namespace FlexFlow

#endif
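The SamplingMeta fields (sorted logits and indices, per-request begin/end offsets, and a temp-storage pointer with a byte count) point to a sort-based, device-side implementation, presumably a segmented sort over the vocabulary followed by a running-sum cutoff at top_p; the kernel itself is not part of this diff. As a plain-CPU reference for what the operator computes on one row of softmax output, a hedged sketch of top-p (nucleus) sampling:

// Reference (CPU) sketch of top-p sampling over one row of softmax
// probabilities; illustrative only, not the CUDA kernel added by this commit.
#include <algorithm>
#include <numeric>
#include <random>
#include <vector>

int sample_top_p(std::vector<float> const &probs, float top_p, std::mt19937 &rng) {
  std::vector<int> idx(probs.size());
  std::iota(idx.begin(), idx.end(), 0);
  // Sort token indices by probability, highest first.
  std::sort(idx.begin(), idx.end(),
            [&](int a, int b) { return probs[a] > probs[b]; });
  // Keep the smallest prefix whose cumulative probability reaches top_p.
  float cum = 0.0f;
  size_t nucleus = 0;
  while (nucleus < idx.size()) {
    cum += probs[idx[nucleus++]];
    if (cum >= top_p) {
      break;
    }
  }
  // Draw from the truncated, implicitly renormalized distribution.
  std::uniform_real_distribution<float> dist(0.0f, cum);
  float r = dist(rng), acc = 0.0f;
  for (size_t i = 0; i < nucleus; i++) {
    acc += probs[idx[i]];
    if (r <= acc) {
      return idx[i];
    }
  }
  return idx[nucleus - 1];
}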
24 changes: 24 additions & 0 deletions include/flexflow/ops/sampling_params.h
@@ -0,0 +1,24 @@
#ifndef _FLEXFLOW_SAMPLING_PARAMS_H
#define _FLEXFLOW_SAMPLING_PARAMS_H

#include "flexflow/ffconst.h"
#include "flexflow/parallel_tensor.h"

namespace FlexFlow {

struct SamplingParams {
float top_p;
bool is_valid(ParallelTensorShape const &) const;
};
bool operator==(SamplingParams const &, SamplingParams const &);

} // namespace FlexFlow

namespace std {
template <>
struct hash<FlexFlow::SamplingParams> {
size_t operator()(FlexFlow::SamplingParams const &) const;
};
} // namespace std

#endif // _FLEXFLOW_SAMPLING_PARAMS_H
25 changes: 25 additions & 0 deletions inference/incr_decoding/incr_decoding.cc
@@ -38,6 +38,9 @@ void parse_input_args(char **argv,
ModelType &llm_model_type,
bool &use_full_precision,
bool &verbose,
bool &do_sample,
float &temperature,
float &topp,
int &data_parallelism_degree,
int &tensor_parallelism_degree,
int &pipeline_parallelism_degree) {
@@ -109,6 +112,18 @@
verbose = true;
continue;
}
if (!strcmp(argv[i], "--do-sample")) {
do_sample = true;
continue;
}
if (!strcmp(argv[i], "--temperature")) {
temperature = std::stof(argv[++i]);
continue;
}
if (!strcmp(argv[i], "--topp")) {
topp = std::stof(argv[++i]);
continue;
}
}
}
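With these flags, sampling is opt-in at launch time: something like appending --do-sample --temperature 0.8 --topp 0.95 to the usual incr_decoding command line (values illustrative) enables it, while omitting --do-sample keeps the existing greedy arg_top_k path. Flags left at their 0.0f defaults fall back to SamplingConfig's built-in 0.8 / 0.6.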

@@ -124,6 +139,9 @@ void FlexFlow::top_level_task(Task const *task,
ModelType model_type;
bool use_full_precision = false;
bool verbose = false;
bool do_sample = false;
float temperature = 0.0f;
float topp = 0.0f;
size_t num_devices = ffconfig.workersPerNode * ffconfig.numNodes;
int data_parallelism_degree = 1, tensor_parallelism_degree = 1,
pipeline_parallelism_degree = 1;
@@ -137,19 +155,24 @@ void FlexFlow::top_level_task(Task const *task,
model_type,
use_full_precision,
verbose,
do_sample,
temperature,
topp,
data_parallelism_degree,
tensor_parallelism_degree,
pipeline_parallelism_degree);
ffconfig.data_parallelism_degree = data_parallelism_degree;
ffconfig.tensor_parallelism_degree = tensor_parallelism_degree;
ffconfig.pipeline_parallelism_degree = pipeline_parallelism_degree;

assert(data_parallelism_degree * tensor_parallelism_degree *
pipeline_parallelism_degree ==
ffconfig.numNodes * ffconfig.workersPerNode);

assert(model_type != ModelType::UNKNOWN &&
"Invalid LLM model type passed (or no type was passed).");

SamplingConfig samplingConfig(do_sample, temperature, topp);
InferenceManager im(ffconfig, BatchConfig::MAX_NUM_TOKENS);
RequestManager rm(model_type,
file_paths.tokenizer_file_path,
@@ -163,6 +186,7 @@ void FlexFlow::top_level_task(Task const *task,
file_paths.llm_config_file_path,
file_paths.llm_weight_file_path,
INC_DECODING_MODE,
samplingConfig,
use_full_precision);
} else if (model_type == ModelType::OPT) {
OPT::create_opt_model(model,
@@ -211,6 +235,7 @@ void FlexFlow::top_level_task(Task const *task,
assert(fm.get_future_map_domain().get_volume() == 1);
Future future = fm.get_future(0);
ir = future.get_result<InferenceResult>();
// assert(false);
}

// Execution fence
10 changes: 9 additions & 1 deletion inference/models/llama.cc
@@ -24,6 +24,7 @@ void LLAMA::create_llama_model(FFModel &ff,
std::string const &model_config_file_path,
std::string const &weight_file_path,
InferenceMode mode,
SamplingConfig samplingConfig,
bool use_full_precision) {
// do not apply cpu offload in beam search model.
Config llama_config(model_config_file_path);
@@ -210,7 +211,14 @@ void LLAMA::create_llama_model(FFModel &ff,
Tensor softmax = ff.softmax(dense, -1);
output = ff.beam_top_k(softmax, llama_config.max_beam_width, false);
} else {
output = ff.arg_top_k(dense, /*k=*/1, false);
// Tensor softmax = ff.softmax(dense, -1);
if (samplingConfig.do_sample) {
dense = ff.scalar_truediv(dense, samplingConfig.temperature, false);
Tensor softmax = ff.softmax(dense, -1);
output = ff.sampling(softmax, samplingConfig.topp);
} else {
output = ff.arg_top_k(dense, /*k=*/1, false);
}
}

// Compile the model
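When do_sample is set, the logits are divided by the temperature via scalar_truediv, turned into probabilities with softmax, and handed to the new sampling operator, which draws a token id from the top-p nucleus; this is the standard softmax(logits / T) temperature scaling, where larger T flattens the distribution and smaller T sharpens it toward the greedy arg_top_k(/*k=*/1) choice taken when do_sample is false.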
1 change: 1 addition & 0 deletions inference/models/llama.h
@@ -107,6 +107,7 @@ class LLAMA {
std::string const &model_config_file_path,
std::string const &weight_file_path,
InferenceMode mode,
SamplingConfig samplingConfig,
bool use_full_precision = false);
};

3 changes: 3 additions & 0 deletions inference/spec_infer/spec_infer.cc
@@ -199,6 +199,7 @@ void FlexFlow::top_level_task(Task const *task,
}

// Create SentencePiece tokenizer or OPT tokenizer
SamplingConfig samplingConfig;
InferenceManager im(ffconfig, BatchConfig::MAX_NUM_TOKENS);
RequestManager rm(model_types.llm_model_type,
file_paths.tokenizer_file_path,
@@ -213,6 +214,7 @@
file_paths.llm_config_file_path,
file_paths.llm_weight_file_path,
TREE_VERIFY_MODE,
samplingConfig,
use_full_precision);
} else if (model_types.llm_model_type == ModelType::OPT) {
OPT::create_opt_model(tree_model,
@@ -245,6 +247,7 @@
file_paths.ssm_config_file_paths[ssm_id],
file_paths.ssm_weight_file_paths[ssm_id],
BEAM_SEARCH_MODE,
samplingConfig,
use_full_precision);
} else if (model_types.ssm_model_types[ssm_id] == ModelType::OPT) {
OPT::create_opt_model(beam_model,
3 changes: 2 additions & 1 deletion src/ops/fused.cu
@@ -748,7 +748,8 @@ __host__ void
case OP_RELU:
case OP_SIGMOID:
case OP_TANH:
case OP_ELU: {
case OP_ELU:
case OP_SCALAR_TRUE_DIV: {
assert(fused->op_num_inputs[op] == 1);
assert(fused->op_num_weights[op] == 0);
assert(fused->op_num_outputs[op] == 1);
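Adding OP_SCALAR_TRUE_DIV to this case group lets the operator share the single-input, zero-weight, single-output elementwise path of the fused kernel, presumably so the temperature division introduced in llama.cc can still execute when operators are fused.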