Skip to content

Commit

Permalink
fixed compilation bug and others
Browse files Browse the repository at this point in the history
  • Loading branch information
goliaro committed Jan 30, 2023
1 parent 7216f22 commit 6cd6b67
Show file tree
Hide file tree
Showing 9 changed files with 150 additions and 58 deletions.
100 changes: 50 additions & 50 deletions examples/cpp/inference/mixture_of_experts/moe.cc
Original file line number Diff line number Diff line change
Expand Up @@ -93,14 +93,14 @@ void FlexFlow::top_level_task(Task const *task,
std::vector<PhysicalRegion> const &regions,
Context ctx,
Runtime *runtime) {
/* // Inference parameters
// Inference parameters
int total_requests =
256; // total number of requests processed as part of the simulation
int request_tensor_size = 4; // request tensor dimensions
bool poisson_distribution = true;
double lambda = 25; // average number of request arrivals per second
int num_requests_per_batch = 5;
int num_inflight_batches = 10; */
int num_inflight_batches = 10;

//-----------------------------------------------------------------

Expand Down Expand Up @@ -130,20 +130,20 @@ void FlexFlow::top_level_task(Task const *task,
Tensor t = create_moe(&ff, &moeConfig, input);
t = ff.dense(t, OUT_DIM, AC_MODE_RELU);

/* InferenceManager im(&ff, num_requests_per_batch, num_inflight_batches);
im.compile_model_and_allocate_buffer(); */
InferenceManager im(&ff, num_requests_per_batch, num_inflight_batches);
im.compile_model_and_allocate_buffer();

Optimizer *optimizer = new SGDOptimizer(&ff, 0.001f);
/* Optimizer *optimizer = new SGDOptimizer(&ff, 0.001f);
std::vector<MetricsType> metrics;
metrics.push_back(METRICS_ACCURACY);
metrics.push_back(METRICS_SPARSE_CATEGORICAL_CROSSENTROPY);
ff.compile(optimizer, LOSS_SPARSE_CATEGORICAL_CROSSENTROPY, metrics);
ff.compile(optimizer, LOSS_SPARSE_CATEGORICAL_CROSSENTROPY, metrics); */

// Data Loader
ParallelTensor input_pt, label_pt;
/* ParallelTensor input_pt, label_pt;
ff.get_parallel_tensor_from_tensor(input, input_pt);
ff.get_parallel_tensor_from_tensor(ff.label_tensor, label_pt);
DataLoader data_loader(ff, moeConfig, input_pt, label_pt);
DataLoader data_loader(ff, moeConfig, input_pt, label_pt); */

ff.init_operators();

Expand All @@ -160,52 +160,52 @@ void FlexFlow::top_level_task(Task const *task,

///////////////////////////////////////////////////////////////////////////////////

// int index = 0;
// int processed_requests = 0;
// Generator data_generator(
// total_requests, request_tensor_size, poisson_distribution, lambda);
// while (processed_requests < total_requests) {
// vector<vector<double>> req = data_generator.get_requests();
// int iterations = req.size();
// for (int iter = 0; iter < iterations; iter++) {
// // data_loader.next_batch(ff);
// runtime->begin_trace(ctx, 111 /*trace_id*/);
// im.inference((index++) % num_inflight_batches);
// runtime->end_trace(ctx, 111 /*trace_id*/);
// }
// processed_requests += iterations;
// }

for (int epoch = 0; epoch < ffConfig.epochs; epoch++) {
data_loader.reset();
ff.reset_metrics();
int iterations = TRAIN_SAMPLES / ffConfig.batchSize;

int index = 0;
int processed_requests = 0;
Generator data_generator(
total_requests, request_tensor_size, poisson_distribution, lambda);
while (processed_requests < total_requests) {
vector<vector<double>> req = data_generator.get_requests();
int iterations = req.size();
for (int iter = 0; iter < iterations; iter++) {
data_loader.next_batch(ff);
if (epoch > 0) {
runtime->begin_trace(ctx, 111 /*trace_id*/);
}
ff.forward();
ff.zero_gradients();
// ff.backward();
ff.update();
// ff.recompile_on_condition(r);
if (epoch > 0) {
runtime->end_trace(ctx, 111 /*trace_id*/);
}
// data_loader.next_batch(ff);
runtime->begin_trace(ctx, 111 /*trace_id*/);
im.inference((index++) % num_inflight_batches);
runtime->end_trace(ctx, 111 /*trace_id*/);
}

// TODO: Do properly
ff.reset_metrics();
// iterations = TEST_SAMPLES / ffConfig.batchSize;
// for (int iter = 0; iter < iterations; iter++) {
// data_loader.next_batch(ff);
// ff.forward();
// ff.backward();
// }
processed_requests += iterations;
}

// for (int epoch = 0; epoch < ffConfig.epochs; epoch++) {
// data_loader.reset();
// ff.reset_metrics();
// int iterations = TRAIN_SAMPLES / ffConfig.batchSize;

// for (int iter = 0; iter < iterations; iter++) {
// data_loader.next_batch(ff);
// if (epoch > 0) {
// runtime->begin_trace(ctx, 111 /*trace_id*/);
// }
// ff.forward();
// ff.zero_gradients();
// // ff.backward();
// ff.update();
// // ff.recompile_on_condition(r);
// if (epoch > 0) {
// runtime->end_trace(ctx, 111 /*trace_id*/);
// }
// }

// // TODO: Do properly
// ff.reset_metrics();
// // iterations = TEST_SAMPLES / ffConfig.batchSize;
// // for (int iter = 0; iter < iterations; iter++) {
// // data_loader.next_batch(ff);
// // ff.forward();
// // ff.backward();
// // }
// }

///////////////////////////////////////////////////////////////////////////////////

// End timer
Expand Down
4 changes: 4 additions & 0 deletions include/flexflow/ops/noop.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,10 @@ class NoOp : public Op {
char const *name = NULL);
void init(FFModel const &) override;
void forward(FFModel const &) override;
// Inference entry point, declared alongside forward()/backward().
// Takes the model, two vectors of ParallelTensor and an optional MachineView
// pointer that defaults to nullptr. Marked `override`, so it replaces a
// virtual declared in a base class.
void inference(FFModel const &,
std::vector<ParallelTensor> const &,
std::vector<ParallelTensor> const &,
MachineView const *mv = nullptr) override;
void backward(FFModel const &) override;
void print_layer(FFModel const &model) override {
assert(0);
Expand Down
4 changes: 4 additions & 0 deletions include/flexflow/ops/softmax.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,10 @@ class Softmax : public Op {
char const *name = nullptr);
void init(FFModel const &) override;
void forward(FFModel const &) override;
// Inference entry point, declared alongside forward()/backward().
// Takes the model, two vectors of ParallelTensor and an optional MachineView
// pointer that defaults to nullptr. Marked `override`, so it replaces a
// virtual declared in a base class.
void inference(FFModel const &,
std::vector<ParallelTensor> const &,
std::vector<ParallelTensor> const &,
MachineView const *mv = nullptr) override;
void backward(FFModel const &) override;
bool get_int_parameter(PMParameter, int *) const override;
void print_layer(FFModel const &model) override {
Expand Down
4 changes: 4 additions & 0 deletions include/flexflow/parallel_ops/partition.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,10 @@ class Repartition : public ParallelOp {
void create_input_partition(FFModel &model) override;
void init(FFModel const &) override;
void forward(FFModel const &) override;
// Inference entry point, declared alongside forward()/backward().
// Takes the model, two vectors of ParallelTensor and an optional MachineView
// pointer that defaults to nullptr. Marked `override`, so it replaces a
// virtual declared in a base class.
void inference(FFModel const &,
std::vector<ParallelTensor> const &,
std::vector<ParallelTensor> const &,
MachineView const *mv = nullptr) override;
void backward(FFModel const &) override;
bool get_int_parameter(PMParameter, int *) const override;
bool append_parallel_op_info(
Expand Down
5 changes: 5 additions & 0 deletions src/ops/noop.cc
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,11 @@ void NoOp::init(FFModel const &ff) {

// Forward pass for NoOp: the body is empty, so no work is performed.
void NoOp::forward(FFModel const &ff) {}

// Inference hook for NoOp. The body is empty: none of the arguments are read
// and no task is launched.
void NoOp::inference(FFModel const &ff,
                     std::vector<ParallelTensor> const &batch_inputs,
                     std::vector<ParallelTensor> const &batch_outputs,
                     MachineView const *mv) {
  // Nothing to execute for this operator.
}

// Backward pass for NoOp: the body is empty, so no work is performed.
void NoOp::backward(FFModel const &ff) {}

bool NoOp::measure_operator_cost(Simulator *sim,
Expand Down
32 changes: 32 additions & 0 deletions src/ops/softmax.cc
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,38 @@ OpMeta *Softmax::init_task(Task const *task,
return m;
}

// Launches SOFTMAX_FWD_TASK_ID as an index launch over `parallel_is`.
//
// Region 0 is inputs[0] (READ_ONLY) and region 1 is outputs[0] (WRITE_ONLY),
// both with field FID_DATA. The value passed as the launcher's final
// argument is mv->hash() when `mv` is non-null, otherwise the hash of
// outputs[0]->machine_view.
//
// NOTE(review): `batch_inputs` and `batch_outputs` are not referenced in this
// body; the launch binds the operator's own inputs[0]/outputs[0]. Confirm
// whether the per-batch tensors are supposed to be used here instead.
void Softmax::inference(FFModel const &ff,
                        std::vector<ParallelTensor> const &batch_inputs,
                        std::vector<ParallelTensor> const &batch_outputs,
                        MachineView const *mv) {
  ArgumentMap arg_map;
  Context legion_ctx = ff.config.lg_ctx;
  Runtime *legion_rt = ff.config.lg_hlr;
  set_argumentmap_for_forward(ff, arg_map);

  // Prefer the caller-supplied machine view; fall back to the output's own.
  size_t view_hash;
  if (mv) {
    view_hash = mv->hash();
  } else {
    view_hash = outputs[0]->machine_view.hash();
  }

  IndexLauncher fwd_launcher(SOFTMAX_FWD_TASK_ID,
                             parallel_is,
                             TaskArgument(NULL, 0),
                             arg_map,
                             Predicate::TRUE_PRED,
                             false /*must*/,
                             0 /*mapper_id*/,
                             view_hash);

  // Region 0: input, read-only.
  RegionRequirement input_req(inputs[0]->part,
                              0 /*projection id*/,
                              READ_ONLY,
                              EXCLUSIVE,
                              inputs[0]->region);
  fwd_launcher.add_region_requirement(input_req);
  fwd_launcher.add_field(0, FID_DATA);

  // Region 1: output, write-only.
  RegionRequirement output_req(outputs[0]->part,
                               0 /*projection id*/,
                               WRITE_ONLY,
                               EXCLUSIVE,
                               outputs[0]->region);
  fwd_launcher.add_region_requirement(output_req);
  fwd_launcher.add_field(1, FID_DATA);

  legion_rt->execute_index_space(legion_ctx, fwd_launcher);
}

void Softmax::forward(FFModel const &ff) {
ArgumentMap argmap;
Context ctx = ff.config.lg_ctx;
Expand Down
44 changes: 39 additions & 5 deletions src/parallel_ops/partition.cc
Original file line number Diff line number Diff line change
Expand Up @@ -137,11 +137,45 @@ void Repartition::create_input_partition(FFModel &ff) {
outputs[0]->parallel_is,
inputs[0]->region,
input_lp);
ff.create_disjoint_partition(inputs[0]->num_dims,
inputs[0]->dims,
inputs[0]->parallel_is,
outputs[0]->region_grad,
output_grad_lp);
if (ff.config.computationMode == COMP_MODE_TRAINING) {
ff.create_disjoint_partition(inputs[0]->num_dims,
inputs[0]->dims,
inputs[0]->parallel_is,
outputs[0]->region_grad,
output_grad_lp);
}
}

// Launches REPARTITION_FWD_TASK_ID as an index launch over
// outputs[0]->parallel_is, passing the tensor's DataType as the task argument.
//
// Requires exactly one input and one output with matching data types
// (enforced with assert). Region 0 is inputs[0]->region accessed through
// `input_lp` (READ_ONLY); region 1 is outputs[0] (WRITE_ONLY); both use field
// FID_DATA. The value passed as the launcher's final argument is mv->hash()
// when `mv` is non-null, otherwise the hash of outputs[0]->machine_view.
//
// NOTE(review): `batch_inputs` and `batch_outputs` are not referenced in this
// body; the launch binds the operator's own inputs[0]/outputs[0]. Confirm
// whether the per-batch tensors are supposed to be used here instead.
void Repartition::inference(FFModel const &ff,
                            std::vector<ParallelTensor> const &batch_inputs,
                            std::vector<ParallelTensor> const &batch_outputs,
                            MachineView const *mv) {
  // Preconditions on the operator's shape.
  assert(numOutputs == 1);
  assert(numInputs == 1);
  assert(inputs[0]->data_type == outputs[0]->data_type);

  ArgumentMap arg_map;
  Context legion_ctx = ff.config.lg_ctx;
  Runtime *legion_rt = ff.config.lg_hlr;
  DataType data_type = inputs[0]->data_type;

  // Prefer the caller-supplied machine view; fall back to the output's own.
  size_t view_hash;
  if (mv) {
    view_hash = mv->hash();
  } else {
    view_hash = outputs[0]->machine_view.hash();
  }

  IndexLauncher fwd_launcher(REPARTITION_FWD_TASK_ID,
                             outputs[0]->parallel_is,
                             TaskArgument(&data_type, sizeof(data_type)),
                             arg_map,
                             Predicate::TRUE_PRED,
                             false /*must*/,
                             0 /*mapper_id*/,
                             view_hash);

  // Region 0: input region viewed through input_lp, read-only.
  RegionRequirement input_req(
      input_lp, 0 /*projection id*/, READ_ONLY, EXCLUSIVE, inputs[0]->region);
  fwd_launcher.add_region_requirement(input_req);
  fwd_launcher.add_field(0, FID_DATA);

  // Region 1: output, write-only.
  RegionRequirement output_req(outputs[0]->part,
                               0 /*projection id*/,
                               WRITE_ONLY,
                               EXCLUSIVE,
                               outputs[0]->region);
  fwd_launcher.add_region_requirement(output_req);
  fwd_launcher.add_field(1, FID_DATA);

  legion_rt->execute_index_space(legion_ctx, fwd_launcher);
}

void Repartition::forward(FFModel const &ff) {
Expand Down
7 changes: 7 additions & 0 deletions src/runtime/inference_manager.cc
Original file line number Diff line number Diff line change
Expand Up @@ -61,12 +61,19 @@ void InferenceManager::inference(int index) {
assert(index < max_num_inflight_batches);
for (size_t o = 0; o < model->operators.size(); o++) {
Op *op = model->operators[o];
if (op->op_type == OP_WEIGHT) {
continue;
}
std::vector<ParallelTensor> inputs(op->numInputs);
std::vector<ParallelTensor> outputs(op->numOutputs);
for (int i = 0; i < op->numInputs; i++) {
assert(op->inputs[i] != nullptr);
assert(tensor_buffer[op->inputs[i]].size() > index);
inputs[i] = tensor_buffer[op->inputs[i]][index];
}
for (int i = 0; i < op->numOutputs; i++) {
assert(op->outputs[i] != nullptr);
assert(tensor_buffer[op->outputs[i]].size() > index);
outputs[i] = tensor_buffer[op->outputs[i]][index];
}
op->inference(*model, inputs, outputs);
Expand Down
8 changes: 5 additions & 3 deletions src/runtime/model.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3114,9 +3114,11 @@ void FFModel::compile(LossType loss_type,
assert(false && "Unsupported dim");
}
}
// init optimizer
assert(optimizer != NULL);
optimizer->init();
if (config.computationMode == COMP_MODE_TRAINING) {
// init optimizer
assert(optimizer != NULL);
optimizer->init();
}

#ifdef FF_USE_NCCL
if (config.computationMode == COMP_MODE_TRAINING) {
Expand Down

0 comments on commit 6cd6b67

Please sign in to comment.