Fuse optimization #1248

Open · wants to merge 19 commits into base: BertMLM_fixes
1 change: 1 addition & 0 deletions examples/python/pytorch/mt5/mt5_ff.py
@@ -122,6 +122,7 @@ def top_level_task():
input_names = ["input_ids", "attention_mask"]

print("Tracing the model...")
print(batch_size)
hf_model = PyTorchModel(
model, is_hf_model=True, input_names=input_names,
batch_size=batch_size, seq_length=seq_length,
3 changes: 3 additions & 0 deletions include/flexflow/config.h
@@ -124,6 +124,7 @@ class FFConfig {
size_t workSpaceSize;
Legion::Context lg_ctx;
Legion::Runtime *lg_hlr;
Legion::IndexSpaceT<1> all_gpu_task_is;
Legion::FieldSpace field_space;
bool syntheticInput, profiling, perform_fusion;
size_t simulator_work_space_size;
@@ -137,6 +138,8 @@ class FFConfig {
bool enable_parameter_parallel;
bool enable_attribute_parallel;
bool enable_inplace_optimizations;
int data_parallelism_degree;
int tensor_parallelism_degree;
// Control Tensor Op Math Conversion
bool allow_tensor_op_math_conversion;
std::string dataset_path;
2 changes: 2 additions & 0 deletions include/flexflow/ffconst.h
@@ -157,6 +157,7 @@ enum OperatorType {
OP_REPLICATE,
OP_REDUCTION,
OP_PIPELINE,
OP_ALLREDUCE,
OP_FUSED_PARALLEL,
OP_INVALID,
};
@@ -189,6 +190,7 @@ enum PMParameter {
PM_COMBINE_DEGREE, // Combine
PM_REDUCTION_DIM, // Reduction
PM_REDUCTION_DEGREE, // Reduction
PM_ALLREDUCE_DIM, // AllReduce
PM_SOFTMAX_DIM, // Softmax
PM_NUM_HEADS, // MultiHeadAttention
PM_INVALID,
2 changes: 2 additions & 0 deletions include/flexflow/flexflow_c.h
@@ -95,6 +95,8 @@ void flexflow_model_compute_metrics(flexflow_model_t handle);

void flexflow_model_update(flexflow_model_t handle);

void flexflow_model_unified_update(flexflow_model_t handle);

void flexflow_model_compile(flexflow_model_t handle,
enum LossType loss_type,
int *metrics,
12 changes: 12 additions & 0 deletions include/flexflow/model.h
@@ -151,6 +151,7 @@ enum TaskIDs {
// Optimizer with NCCL
SGD_UPD_NCCL_TASK_ID,
ADAM_UPD_NCCL_TASK_ID,
ADAM_UNIFY_UPD_NCCL_TASK_ID,
// Initializer
GLOROT_INIT_TASK_ID,
ZERO_INIT_TASK_ID,
@@ -190,6 +191,10 @@ enum TaskIDs {
PIPELINE_INIT_TASK_ID,
PIPELINE_FWD_TASK_ID,
PIPELINE_BWD_TASK_ID,
ALLREDUCE_INIT_TASK_ID,
ALLREDUCE_INF_TASK_ID,
ALLREDUCE_FWD_TASK_ID,
ALLREDUCE_BWD_TASK_ID,
FUSED_PARALLELOP_INIT_TASK_ID,
FUSED_PARALLELOP_FWD_TASK_ID,
FUSED_PARALLELOP_BWD_TASK_ID,
@@ -273,6 +278,7 @@ class Split;
class TopK;
class Transpose;
class Combine;
class AllReduce;
class Repartition;
class Reduction;
class Replicate;
@@ -777,6 +783,7 @@ class FFModel {
void get_metrics();
void backward(int seq_length = -1);
void update();
void unified_update();
bool apply_fusion(std::vector<Op *> const &operators,
std::vector<Op *> &new_operators);
Op *get_final_operator() const;
@@ -828,6 +835,8 @@ class FFModel {
Legion::IndexSpace get_task_is(Legion::Domain const &domain) const;
Legion::IndexSpace get_task_is(ParallelConfig const &pc) const;
Legion::IndexSpace get_task_is(MachineView const &view) const;
bool is_transformer_block(int layer_idx) const;
bool is_mlp_block(int layer_idx) const;
void create_operators_from_layers();
Op *create_operator_from_layer(Layer *layer,
std::vector<ParallelTensor> const &inputs);
@@ -854,6 +863,7 @@ class FFModel {
int metrics_input;
ParallelTensor parallel_label_tensor;
Tensor label_tensor;
int num_inputs = 0;

std::vector<Layer *> layers;
std::vector<Op *> operators;
@@ -923,6 +933,8 @@ class FFModel {
Replicate *>,
std::unordered_map<std::pair<ParallelTensorShape, ReductionParams>,
Reduction *>,
std::unordered_map<std::pair<ParallelTensorShape, AllReduceParams>,
AllReduce *>,
std::unordered_map<std::pair<ParallelTensorShape, CombineParams>,
Combine *>,
std::unordered_map<std::pair<ParallelTensorShape, FusedParallelOpParams>,
2 changes: 2 additions & 0 deletions include/flexflow/operator_params.h
@@ -7,6 +7,7 @@
#include "flexflow/ops/batch_matmul_params.h"
#include "flexflow/ops/cast_params.h"
#include "flexflow/ops/concat_params.h"
#include "flexflow/parallel_ops/allreduce_params.h"
#include "flexflow/ops/conv_2d_params.h"
#include "flexflow/ops/dropout_params.h"
#include "flexflow/ops/element_binary_params.h"
@@ -62,6 +63,7 @@ using OperatorParameters = mp::variant<AggregateParams,
ReplicateParams,
ReductionParams,
CombineParams,
AllReduceParams,
FusedParallelOpParams>;

tl::optional<OperatorParameters> get_op_parameters(Op const *op);
7 changes: 7 additions & 0 deletions include/flexflow/ops/dropout.h
@@ -5,6 +5,13 @@
#include "flexflow/node.h"
#include "flexflow/operator.h"
#include "flexflow/ops/dropout_params.h"
#if defined(FF_USE_CUDA) || defined(FF_USE_HIP_CUDA)
#include <curand.h>
#include <curand_kernel.h>
#elif defined(FF_USE_HIP_ROCM)
#include <hiprand/hiprand.h>
#include <hiprand/hiprand_kernel.h>
#endif

namespace FlexFlow {

5 changes: 5 additions & 0 deletions include/flexflow/ops/element_binary.h
@@ -53,6 +53,11 @@ class ElementBinary : public Op {
bool measure_operator_cost(Simulator *sim,
MachineView const &pc,
CostMetrics &cost_metrics) const override;
void serialize(Legion::Serializer &) const override;
static PCG::Node deserialize(FFModel &ff,
Legion::Deserializer &d,
ParallelTensor inputs[],
int num_inputs);
Params get_params() const;

public:
1 change: 1 addition & 0 deletions include/flexflow/ops/element_binary_params.h
@@ -8,6 +8,7 @@ namespace FlexFlow {

struct ElementBinaryParams {
OperatorType type;
bool inplace_a;

bool is_valid(
std::pair<ParallelTensorShape, ParallelTensorShape> const &) const;
16 changes: 12 additions & 4 deletions include/flexflow/ops/kernels/dropout_kernels.h
@@ -5,6 +5,7 @@
#include "flexflow/fftype.h"
#include "flexflow/op_meta.h"
#include "flexflow/ops/dropout.h"
#include "flexflow/accessor.h"

namespace FlexFlow {

@@ -17,33 +18,40 @@ class DropoutMeta : public OpMeta {
~DropoutMeta(void);
Realm::RegionInstance reserveInst;
#if defined(FF_USE_CUDA) || defined(FF_USE_HIP_CUDA)
curandState *state;
cudnnTensorDescriptor_t inputTensor, outputTensor;
cudnnDropoutDescriptor_t dropoutDesc;
#else
miopenTensorDescriptor_t inputTensor, outputTensor;
miopenDropoutDescriptor_t dropoutDesc;
hiprandState *state;
#endif
void *reserveSpace, *dropoutStates;
size_t reserveSpaceSize, dropoutStateSize;
size_t num_elements;
long long seed;
float rate;
};

namespace Kernels {
namespace Dropout {
void forward_kernel_wrapper(DropoutMeta *m,
float const *input_ptr,
float *output_ptr);
GenericTensorAccessorR const &input,
GenericTensorAccessorW const &output);
void backward_kernel_wrapper(DropoutMeta *m,
float const *output_grad_ptr,
float *input_grad_ptr);
GenericTensorAccessorR const &output_grad,
GenericTensorAccessorW const &input_grad);

namespace Internal {
void forward_kernel(DropoutMeta *m,
float const *input_ptr,
float *output_ptr,
size_t num_elements,
ffStream_t stream);
void backward_kernel(DropoutMeta *m,
float const *output_grad_ptr,
float *input_grad_ptr,
size_t num_elements,
ffStream_t stream);
} // namespace Internal
} // namespace Dropout
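
Editor's note: the wrapper signatures above switch from raw float pointers to GenericTensorAccessorR/W, which couple the data pointer with type information. The snippet below is a minimal standalone sketch of that idea only; the type and member names are assumptions for illustration, not FlexFlow's actual accessor definition (see flexflow/accessor.h).

// Standalone sketch: an accessor bundles the raw pointer with a data-type tag,
// so kernel wrappers can check element types instead of taking a bare float*.
// All names here are illustrative assumptions, not FlexFlow's accessor API.
#include <cassert>
#include <cstdio>

enum class DataTypeSketch { FLOAT, HALF };

struct ReadOnlyAccessorSketch {
  DataTypeSketch data_type;
  void const *ptr;
  float const *get_float_ptr() const {
    assert(data_type == DataTypeSketch::FLOAT); // typed access instead of an unchecked cast
    return static_cast<float const *>(ptr);
  }
};

int main() {
  float buf[3] = {0.25f, 0.5f, 0.75f};
  ReadOnlyAccessorSketch input{DataTypeSketch::FLOAT, buf};
  std::printf("first input element: %f\n", input.get_float_ptr()[0]);
  return 0;
}
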
23 changes: 23 additions & 0 deletions include/flexflow/optimizer.h
@@ -18,6 +18,7 @@

#include "flexflow/parallel_tensor.h"
#include "legion.h"
#include "accessor.h"

namespace FlexFlow {

@@ -30,6 +31,7 @@ class Optimizer {
virtual void init(void) = 0;
virtual void next(void) = 0;
virtual void update(const ParallelTensor p) = 0;
virtual void unified_update(std::vector<ParallelTensor> const parameters) = 0;
FFModel const *model;
};

@@ -43,6 +45,7 @@ class SGDOptimizer : public Optimizer {
void init(void);
void next(void);
void update(const ParallelTensor p);
void unified_update(std::vector<ParallelTensor> const parameters);
void set_weight_decay(double _weight_decay);
static void ps_update_task(Legion::Task const *task,
std::vector<Legion::PhysicalRegion> const &regions,
@@ -60,6 +63,11 @@ class SGDOptimizer : public Optimizer {
std::vector<Legion::PhysicalRegion> const &regions,
Legion::Context ctx,
Legion::Runtime *runtime);
static void
nccl_unified_update_task(Legion::Task const *task,
std::vector<Legion::PhysicalRegion> const &regions,
Legion::Context ctx,
Legion::Runtime *runtime);
static void nccl_update_task_gpu(SGDOptimizer const *op,
OpMeta const *meta,
float const *w_grad_ptr,
@@ -85,6 +93,7 @@ class AdamOptimizer : public Optimizer {
void init(void);
void next(void);
void update(const ParallelTensor p);
void unified_update(std::vector<ParallelTensor> const parameters);
void set_weight_decay(double _weight_decay);
static void ps_update_task(Legion::Task const *task,
std::vector<Legion::PhysicalRegion> const &regions,
@@ -103,17 +112,31 @@ class AdamOptimizer : public Optimizer {
std::vector<Legion::PhysicalRegion> const &regions,
Legion::Context ctx,
Legion::Runtime *runtime);
static void
nccl_unified_update_task(Legion::Task const *task,
std::vector<Legion::PhysicalRegion> const &regions,
Legion::Context ctx,
Legion::Runtime *runtime);
static void nccl_update_task_gpu(AdamOptimizer const *op,
OpMeta const *meta,
float const *w_grad_ptr,
size_t size,
float *w_ptr,
float *v_ptr,
float *m_ptr);
static void nccl_unified_update_task_gpu(AdamOptimizer const *op,
OpMeta const *meta,
GenericTensorAccessorR *accWGrads,
size_t *size,
GenericTensorAccessorW *accWs,
GenericTensorAccessorW *accVs,
GenericTensorAccessorW *accMs);
#endif
double alpha, beta1, beta2, weight_decay, epsilon;
double alpha_t, beta1_t, beta2_t;
std::map<Legion::LogicalRegion, ParallelTensor> v_values, m_values;
size_t reservedWorkSpaceSize = 0;
int parameters_num = 0;
};

}; // namespace FlexFlow
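
Editor's note: the new unified_update entry points, together with the added reservedWorkSpaceSize and parameters_num fields, suggest the optimizer can apply one fused update over all parameters instead of launching a separate NCCL update task per weight tensor. The snippet below is a simplified, self-contained CPU illustration of that batching idea for Adam; the names and the vector-based layout are assumptions and do not reflect the actual Legion/NCCL task bodies in this PR.

// Simplified sketch: one "unified" Adam step over several parameter tensors,
// applied in a single pass instead of one update call per tensor.
// All names and the std::vector-based layout are illustrative assumptions.
#include <cmath>
#include <cstdio>
#include <vector>

struct ParamSlot {
  std::vector<float> w, grad, m, v; // weight, gradient, Adam first/second moments
};

static void adam_unified_update(std::vector<ParamSlot> &params,
                                float alpha_t, float beta1, float beta2,
                                float epsilon, float weight_decay) {
  for (ParamSlot &p : params) { // a single task body walks every parameter tensor
    for (size_t i = 0; i < p.w.size(); ++i) {
      float g = p.grad[i] + weight_decay * p.w[i];
      p.m[i] = beta1 * p.m[i] + (1.0f - beta1) * g;
      p.v[i] = beta2 * p.v[i] + (1.0f - beta2) * g * g;
      p.w[i] -= alpha_t * p.m[i] / (std::sqrt(p.v[i]) + epsilon);
    }
  }
}

int main() {
  std::vector<ParamSlot> params(2);
  for (ParamSlot &p : params) {
    p.w.assign(4, 1.0f);
    p.grad.assign(4, 0.5f);
    p.m.assign(4, 0.0f);
    p.v.assign(4, 0.0f);
  }
  adam_unified_update(params, /*alpha_t=*/1e-3f, 0.9f, 0.999f, 1e-8f, 0.0f);
  std::printf("w[0] after one unified step: %f\n", params[0].w[0]);
  return 0;
}
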
57 changes: 57 additions & 0 deletions include/flexflow/parallel_ops/allreduce.h
@@ -0,0 +1,57 @@
#ifndef _FLEXFLOW_ALLREDUCE_H
#define _FLEXFLOW_ALLREDUCE_H

#include "flexflow/layer.h"
#include "flexflow/node.h"
#include "flexflow/op_meta.h"
#include "flexflow/operator.h"
#include "flexflow/parallel_ops/allreduce_params.h"
#include "parallel_op.h"

namespace FlexFlow {

class AllReduce : public ParallelOp {
public:
using Params = AllReduceParams;
using Input = ParallelTensor;

AllReduce(FFModel &model,
const ParallelTensor input,
int allreduce_legion_dim,
char const *name = NULL);
AllReduce(FFModel &model,
Params const &params,
Input const input,
char const *name = nullptr);
void create_input_partition(FFModel &model) override;
void init(FFModel const &) override;
void forward(FFModel const &) override;
void backward(FFModel const &) override;
bool get_int_parameter(PMParameter, int *) const override;
bool append_parallel_op_info(
std::vector<ParallelOpInfo> &parallel_ops) const override;
static OpMeta *init_task(Legion::Task const *task,
std::vector<Legion::PhysicalRegion> const &regions,
Legion::Context ctx,
Legion::Runtime *runtime);
static void forward_task(Legion::Task const *task,
std::vector<Legion::PhysicalRegion> const &regions,
Legion::Context ctx,
Legion::Runtime *runtime);
static void backward_task(Legion::Task const *task,
std::vector<Legion::PhysicalRegion> const &regions,
Legion::Context ctx,
Legion::Runtime *runtime);
bool measure_operator_cost(Simulator *sim,
MachineView const &pc,
CostMetrics &cost_metrics) const override;

Params get_params() const;

public:
int allreduce_dim;
};

}; // namespace FlexFlow

#endif // _FLEXFLOW_ALLREDUCE_H
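
Editor's note: the forward/backward task bodies of the new AllReduce parallel op are declared above but not shown in this hunk. Semantically, an all-reduce sums a tensor across all replicas and leaves the identical result on each of them; on GPUs this typically maps to a single ncclAllReduce call. The snippet below is a minimal CPU illustration of that semantics only, not the FlexFlow implementation.

// Minimal CPU illustration of all-reduce semantics (element-wise sum across
// replicas, result broadcast back to every replica). Not the FlexFlow/NCCL path.
#include <cstdio>
#include <vector>

static void all_reduce_sum(std::vector<std::vector<float>> &replicas) {
  if (replicas.empty()) {
    return;
  }
  size_t n = replicas[0].size();
  std::vector<float> sum(n, 0.0f);
  for (auto const &r : replicas) { // reduce: element-wise sum over replicas
    for (size_t i = 0; i < n; ++i) {
      sum[i] += r[i];
    }
  }
  for (auto &r : replicas) {       // broadcast: every replica receives the sum
    r = sum;
  }
}

int main() {
  // Two "GPUs" each hold a partial gradient for the same weight tensor.
  std::vector<std::vector<float>> replicas = {{1.0f, 2.0f}, {3.0f, 4.0f}};
  all_reduce_sum(replicas);
  std::printf("replica 0 after all-reduce: %f %f\n", replicas[0][0], replicas[0][1]);
  return 0;
}
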
22 changes: 22 additions & 0 deletions include/flexflow/parallel_ops/allreduce_params.h
@@ -0,0 +1,22 @@
#ifndef _FLEXFLOW_ALLREDUCE_PARAMS_H
#define _FLEXFLOW_ALLREDUCE_PARAMS_H

namespace FlexFlow {

struct AllReduceParams {
int allreduce_legion_dim;
char name[MAX_OPNAME];
bool is_valid(ParallelTensorShape const &) const;
};
bool operator==(AllReduceParams const &, AllReduceParams const &);

} // namespace FlexFlow

namespace std {
template <>
struct hash<FlexFlow::AllReduceParams> {
size_t operator()(FlexFlow::AllReduceParams const &) const;
};
} // namespace std

#endif // _FLEXFLOW_ALLREDUCE_PARAMS_H
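
Editor's note: the operator== and std::hash bodies declared above live in the accompanying .cc file, which is not part of this hunk. Below is a minimal sketch of the usual pattern for such a params struct; it keys only on allreduce_legion_dim (the name field is omitted for brevity), and the hash choice is an assumption rather than the PR's actual implementation.

// Sketch of the equality/hash pattern for a small params struct.
// The struct name and hash choice below are illustrative; the PR's .cc file may differ.
#include <cstddef>
#include <cstdio>
#include <functional>

struct AllReduceParamsSketch {
  int allreduce_legion_dim;
};

bool operator==(AllReduceParamsSketch const &lhs, AllReduceParamsSketch const &rhs) {
  return lhs.allreduce_legion_dim == rhs.allreduce_legion_dim;
}

namespace std {
template <>
struct hash<AllReduceParamsSketch> {
  size_t operator()(AllReduceParamsSketch const &p) const {
    return std::hash<int>{}(p.allreduce_legion_dim);
  }
};
} // namespace std

int main() {
  AllReduceParamsSketch a{1}, b{1};
  std::printf("equal: %d, hash: %zu\n", a == b, std::hash<AllReduceParamsSketch>{}(a));
  return 0;
}
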