Inference: Sampling result #854

Merged
merged 19 commits on Jul 19, 2023
1 change: 1 addition & 0 deletions include/flexflow/ffconst.h
@@ -167,6 +167,7 @@ enum OperatorType {
OP_SPEC_INC_MULTIHEAD_SELF_ATTENTION,
OP_TREE_INC_MULTIHEAD_SELF_ATTENTION,
OP_INC_MULTIQUERY_SELF_ATTENTION,
OP_SAMPLING,
// Parallel Ops
OP_REPARTITION,
OP_COMBINE,
12 changes: 12 additions & 0 deletions include/flexflow/inference.h
@@ -65,6 +65,18 @@ struct BeamTree {
treeLayer treeLayers[BeamSearchBatchConfig::MAX_BEAM_DEPTH + 1];
};

struct SamplingConfig {
bool do_sample = false;
float temperature = 0.8;
float topp = 0.6;
SamplingConfig(bool _do_sample, float _temperature, float _topp) {
temperature = _temperature > 0 ? _temperature : temperature;
topp = _topp > 0 ? _topp : topp;
do_sample = _do_sample;
}
SamplingConfig() {}
};

// struct BeamTree_v2 {
// std::vector<BatchConfig::TokenId> tokens;
// std::vector<int> parent_ids;
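A note on the SamplingConfig constructor above: non-positive values fall back to the defaults (temperature 0.8, top-p 0.6), so a caller can pass 0 to mean "use the default". A minimal sketch of that behavior, assuming the struct is declared inside the FlexFlow namespace like the rest of inference.h:

#include "flexflow/inference.h"

int main() {
  // Non-positive arguments keep the defaults (the "> 0" checks fail, so nothing is overwritten).
  FlexFlow::SamplingConfig fallback(/*do_sample=*/true, /*temperature=*/0.0f, /*topp=*/0.0f);
  // fallback.temperature == 0.8f, fallback.topp == 0.6f

  // Positive arguments are taken as given.
  FlexFlow::SamplingConfig custom(/*do_sample=*/true, /*temperature=*/0.7f, /*topp=*/0.95f);
  // custom.temperature == 0.7f, custom.topp == 0.95f
  return 0;
}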
6 changes: 6 additions & 0 deletions include/flexflow/model.h
@@ -136,6 +136,8 @@ enum TaskIDs {
TOPK_BWD_TASK_ID,
ARG_TOPK_INIT_TASK_ID,
ARG_TOPK_INF_TASK_ID,
SAMPLING_INIT_TASK_ID,
SAMPLING_INF_TASK_ID,
TRANSPOSE_INIT_TASK_ID,
TRANSPOSE_FWD_TASK_ID,
TRANSPOSE_BWD_TASK_ID,
@@ -312,6 +314,7 @@ class RMSNorm;
class BeamTopK;
class SpecIncMultiHeadSelfAttention;
class IncMultiQuerySelfAttention;
class Sampling;
class Combine;
class Repartition;
class Reduction;
@@ -612,6 +615,7 @@ class FFModel {
int k,
bool sorted,
char const *name = NULL);
Tensor sampling(const Tensor input, float top_p, char const *name = NULL);
Tensor multihead_attention(const Tensor query,
const Tensor key,
const Tensor value,
@@ -1061,6 +1065,8 @@ class FFModel {
IncMultiQuerySelfAttention *>,
std::unordered_map<std::pair<ParallelTensorShape, BeamTopKParams>,
BeamTopK *>,
std::unordered_map<std::pair<ParallelTensorShape, SamplingParams>,
Sampling *>,
std::unordered_map<
std::pair<ParallelTensorShape, SpecIncMultiHeadSelfAttentionParams>,
SpecIncMultiHeadSelfAttention *>,
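For context, the new sampling() entry point is meant to be called on a probability tensor; the llama.cc change further down wires it up roughly like the sketch below (identifiers such as dense and samplingConfig come from that file, not from model.h):

// Sketch of the intended call pattern; see inference/models/llama.cc below.
Tensor scaled = ff.scalar_truediv(dense, samplingConfig.temperature, false);
Tensor probs  = ff.softmax(scaled, -1);
Tensor output = ff.sampling(probs, samplingConfig.topp);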
2 changes: 2 additions & 0 deletions include/flexflow/operator_params.h
@@ -26,6 +26,7 @@
#include "flexflow/ops/reduce_params.h"
#include "flexflow/ops/reshape_params.h"
#include "flexflow/ops/rms_norm_params.h"
#include "flexflow/ops/sampling_params.h"
#include "flexflow/ops/softmax_params.h"
#include "flexflow/ops/spec_inc_multihead_self_attention_params.h"
#include "flexflow/ops/split_params.h"
@@ -71,6 +72,7 @@ using OperatorParameters = mp::variant<AggregateParams,
SplitParams,
TopKParams,
ArgTopKParams,
SamplingParams,
SoftmaxParams,
TransposeParams,
RepartitionParams,
108 changes: 108 additions & 0 deletions include/flexflow/ops/sampling.h
@@ -0,0 +1,108 @@
#ifndef _FLEXFLOW_SAMPLING_TOPK_H_
#define _FLEXFLOW_SAMPLING_TOPK_H_

#include "flexflow/inference.h"
#include "flexflow/model.h"
#include "flexflow/node.h"
#include "flexflow/ops/sampling_params.h"
#if defined(FF_USE_CUDA) || defined(FF_USE_HIP_CUDA)
#include <curand.h>
#include <curand_kernel.h>
#endif

namespace FlexFlow {

class SamplingMeta : public OpMeta {
public:
float top_p;
void *sorted_logits;
int *sorted_idx;
int *begin_offset;
int *end_offset;
int *idx;
void *d_temp_storage;
size_t temp_storage_bytes;
#if defined(FF_USE_CUDA) || defined(FF_USE_HIP_CUDA)
curandState *state;
#endif
SamplingMeta(FFHandler handle,
Op const *op,
int batch_size,
int total_ele,
GenericTensorAccessorW input);
};

class Sampling : public Op {
public:
using Params = SamplingParams;
using Input = ParallelTensor;
Sampling(FFModel &model,
const ParallelTensor input,
float top_p,
char const *name);
Sampling(FFModel &model, Sampling const &other, const ParallelTensor input);
Sampling(FFModel &model,
Params const &params,
Input const input,
char const *name = nullptr);
void init(FFModel const &) override;
void init_inference(FFModel const &,
std::vector<ParallelTensor> const &,
std::vector<ParallelTensor> const &,
MachineView const *mv = nullptr) override;
void forward(FFModel const &) override;
void backward(FFModel const &) override;
Legion::FutureMap inference(FFModel const &,
BatchConfig const &,
std::vector<ParallelTensor> const &,
std::vector<ParallelTensor> const &,
MachineView const *mv = nullptr) override;
void print_layer(FFModel const &model) override {
assert(0);
}
static Op *
create_operator_from_layer(FFModel &model,
Layer const *layer,
std::vector<ParallelTensor> const &inputs);

static OpMeta *init_task(Legion::Task const *task,
std::vector<Legion::PhysicalRegion> const &regions,
Legion::Context ctx,
Legion::Runtime *runtime);
static InferenceResult
inference_task(Legion::Task const *task,
std::vector<Legion::PhysicalRegion> const &regions,
Legion::Context ctx,
Legion::Runtime *runtime);
void serialize(Legion::Serializer &s) const override;
static PCG::Node deserialize(FFModel &ff,
Legion::Deserializer &d,
ParallelTensor inputs[],
int num_inputs);
Op *materialize(FFModel &ff,
ParallelTensor inputs[],
int num_inputs) const override;
bool measure_operator_cost(Simulator *sim,
MachineView const &pc,
CostMetrics &cost_metrics) const override;
template <typename DT>
static void forward_kernel(SamplingMeta const *m,
DT *input_ptr,
int *indices_ptr,
float top_p,
int length,
int batch_size,
ffStream_t stream);
static void forward_kernel_wrapper(SamplingMeta const *m,
GenericTensorAccessorW const &input,
GenericTensorAccessorW const &indices,
int batch_size);
Params get_params() const;

public:
float top_p;
};

}; // namespace FlexFlow

#endif
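The CUDA/HIP kernel behind forward_kernel is not part of this header. As a reference for what nucleus (top-p) sampling computes, here is a small host-side C++ sketch (a hypothetical helper, not the GPU implementation): sort the probabilities in descending order, keep the smallest prefix whose cumulative mass reaches top_p, and draw one token from that renormalized prefix.

#include <algorithm>
#include <numeric>
#include <random>
#include <vector>

// Hypothetical reference for top-p (nucleus) sampling over one distribution;
// the actual operator performs this on the GPU.
int sample_top_p(std::vector<float> const &probs, float top_p, std::mt19937 &rng) {
  if (probs.empty()) {
    return -1;
  }
  std::vector<int> idx(probs.size());
  std::iota(idx.begin(), idx.end(), 0);
  // Sort token indices by probability, highest first.
  std::sort(idx.begin(), idx.end(),
            [&](int a, int b) { return probs[a] > probs[b]; });

  // Keep the smallest prefix whose cumulative probability reaches top_p.
  float cumulative = 0.0f;
  size_t cutoff = 0;
  while (cutoff < idx.size()) {
    cumulative += probs[idx[cutoff]];
    ++cutoff;
    if (cumulative >= top_p) {
      break;
    }
  }

  // Draw one token from the truncated, renormalized distribution.
  std::uniform_real_distribution<float> dist(0.0f, cumulative);
  float r = dist(rng);
  float running = 0.0f;
  for (size_t i = 0; i < cutoff; ++i) {
    running += probs[idx[i]];
    if (r <= running) {
      return idx[i];
    }
  }
  return idx[cutoff - 1];
}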
24 changes: 24 additions & 0 deletions include/flexflow/ops/sampling_params.h
@@ -0,0 +1,24 @@
#ifndef _FLEXFLOW_SAMPLING_PARAMS_H
#define _FLEXFLOW_SAMPLING_PARAMS_H

#include "flexflow/ffconst.h"
#include "flexflow/parallel_tensor.h"

namespace FlexFlow {

struct SamplingParams {
float top_p;
bool is_valid(ParallelTensorShape const &) const;
};
bool operator==(SamplingParams const &, SamplingParams const &);

} // namespace FlexFlow

namespace std {
template <>
struct hash<FlexFlow::SamplingParams> {
size_t operator()(FlexFlow::SamplingParams const &) const;
};
} // namespace std

#endif // _FLEXFLOW_SAMPLING_PARAMS_H
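The header only declares is_valid, operator==, and the std::hash specialization; their definitions live elsewhere in the PR. A plausible sketch of what they boil down to, following the pattern of the other *_params headers (an assumption, not the actual source):

// Sketch only; the real definitions are in the operator's .cc file.
namespace FlexFlow {

bool SamplingParams::is_valid(ParallelTensorShape const &input) const {
  return input.is_valid();  // assumes ParallelTensorShape provides is_valid()
}

bool operator==(SamplingParams const &lhs, SamplingParams const &rhs) {
  return lhs.top_p == rhs.top_p;
}

} // namespace FlexFlow

namespace std {
size_t hash<FlexFlow::SamplingParams>::operator()(
    FlexFlow::SamplingParams const &params) const {
  return hash<float>{}(params.top_p);
}
} // namespace std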
25 changes: 25 additions & 0 deletions inference/incr_decoding/incr_decoding.cc
@@ -38,6 +38,9 @@ void parse_input_args(char **argv,
ModelType &llm_model_type,
bool &use_full_precision,
bool &verbose,
bool &do_sample,
float &temperature,
float &topp,
int &data_parallelism_degree,
int &tensor_parallelism_degree,
int &pipeline_parallelism_degree) {
@@ -109,6 +112,18 @@
verbose = true;
continue;
}
if (!strcmp(argv[i], "--do-sample")) {
do_sample = true;
continue;
}
if (!strcmp(argv[i], "--temperature")) {
temperature = std::stof(argv[++i]);
continue;
}
if (!strcmp(argv[i], "--topp")) {
topp = std::stof(argv[++i]);
continue;
}
}
}
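With these flags, sampling can be enabled at launch time, for example by appending something like --do-sample --temperature 0.8 --topp 0.95 to the usual incr_decoding command line (the values here are only illustrative).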

@@ -124,6 +139,9 @@ void FlexFlow::top_level_task(Task const *task,
ModelType model_type;
bool use_full_precision = false;
bool verbose = false;
bool do_sample = false;
float temperature = 0.0f;
float topp = 0.0f;
size_t num_devices = ffconfig.workersPerNode * ffconfig.numNodes;
int data_parallelism_degree = 1, tensor_parallelism_degree = 1,
pipeline_parallelism_degree = 1;
@@ -137,19 +155,24 @@ void FlexFlow::top_level_task(Task const *task,
model_type,
use_full_precision,
verbose,
do_sample,
temperature,
topp,
data_parallelism_degree,
tensor_parallelism_degree,
pipeline_parallelism_degree);
ffconfig.data_parallelism_degree = data_parallelism_degree;
ffconfig.tensor_parallelism_degree = tensor_parallelism_degree;
ffconfig.pipeline_parallelism_degree = pipeline_parallelism_degree;

assert(data_parallelism_degree * tensor_parallelism_degree *
pipeline_parallelism_degree ==
ffconfig.numNodes * ffconfig.workersPerNode);

assert(model_type != ModelType::UNKNOWN &&
"Invalid LLM model type passed (or no type was passed).");

SamplingConfig samplingConfig(do_sample, temperature, topp);
InferenceManager im(ffconfig, BatchConfig::MAX_NUM_TOKENS);
RequestManager rm(model_type,
file_paths.tokenizer_file_path,
@@ -163,6 +186,7 @@ void FlexFlow::top_level_task(Task const *task,
file_paths.llm_config_file_path,
file_paths.llm_weight_file_path,
INC_DECODING_MODE,
samplingConfig,
use_full_precision);
} else if (model_type == ModelType::OPT) {
OPT::create_opt_model(model,
@@ -211,6 +235,7 @@ void FlexFlow::top_level_task(Task const *task,
assert(fm.get_future_map_domain().get_volume() == 1);
Future future = fm.get_future(0);
ir = future.get_result<InferenceResult>();
// assert(false);
}

// Execution fence
10 changes: 9 additions & 1 deletion inference/models/llama.cc
@@ -24,6 +24,7 @@ void LLAMA::create_llama_model(FFModel &ff,
std::string const &model_config_file_path,
std::string const &weight_file_path,
InferenceMode mode,
SamplingConfig samplingConfig,
bool use_full_precision) {
// do not apply cpu offload in beam search model.
Config llama_config(model_config_file_path);
@@ -210,7 +211,14 @@
Tensor softmax = ff.softmax(dense, -1);
output = ff.beam_top_k(softmax, llama_config.max_beam_width, false);
} else {
output = ff.arg_top_k(dense, /*k=*/1, false);
// Tensor softmax = ff.softmax(dense, -1);
if (samplingConfig.do_sample) {
dense = ff.scalar_truediv(dense, samplingConfig.temperature, false);
Tensor softmax = ff.softmax(dense, -1);
output = ff.sampling(softmax, samplingConfig.topp);
} else {
output = ff.arg_top_k(dense, /*k=*/1, false);
}
}

// Compile the model
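In the sampling branch above, the logits are first divided by the temperature (values below 1 sharpen the distribution, values above 1 flatten it), turned into probabilities with a softmax, and then passed to the new sampling operator, which draws the next token from the smallest set of tokens whose cumulative probability reaches top_p. The greedy arg_top_k(k=1) path remains the default when sampling is disabled.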
1 change: 1 addition & 0 deletions inference/models/llama.h
@@ -107,6 +107,7 @@ class LLAMA {
std::string const &model_config_file_path,
std::string const &weight_file_path,
InferenceMode mode,
SamplingConfig samplingConfig,
bool use_full_precision = false);
};

3 changes: 3 additions & 0 deletions inference/spec_infer/spec_infer.cc
@@ -199,6 +199,7 @@ void FlexFlow::top_level_task(Task const *task,
}

// Create SentencePiece tokenizer or OPT tokenizer
SamplingConfig samplingConfig;
InferenceManager im(ffconfig, BatchConfig::MAX_NUM_TOKENS);
RequestManager rm(model_types.llm_model_type,
file_paths.tokenizer_file_path,
@@ -213,6 +214,7 @@
file_paths.llm_config_file_path,
file_paths.llm_weight_file_path,
TREE_VERIFY_MODE,
samplingConfig,
use_full_precision);
} else if (model_types.llm_model_type == ModelType::OPT) {
OPT::create_opt_model(tree_model,
@@ -245,6 +247,7 @@
file_paths.ssm_config_file_paths[ssm_id],
file_paths.ssm_weight_file_paths[ssm_id],
BEAM_SEARCH_MODE,
samplingConfig,
use_full_precision);
} else if (model_types.ssm_model_types[ssm_id] == ModelType::OPT) {
OPT::create_opt_model(beam_model,
3 changes: 2 additions & 1 deletion src/ops/fused.cu
@@ -748,7 +748,8 @@ __host__ void
case OP_RELU:
case OP_SIGMOID:
case OP_TANH:
case OP_ELU: {
case OP_ELU:
case OP_SCALAR_TRUE_DIV: {
assert(fused->op_num_inputs[op] == 1);
assert(fused->op_num_weights[op] == 0);
assert(fused->op_num_outputs[op] == 1);
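The fused kernel now treats OP_SCALAR_TRUE_DIV as another single-input, single-output element-wise case, presumably so that the scalar_truediv used for temperature scaling in the sampling path can still execute when it ends up inside a fused operator.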