Skip to content

Commit

Permalink
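Fix bug: pass the batch config to the ArgMax inference tasks as a future (BatchConfigFuture) and return early when the batch has no tokens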
Browse files (browse the repository at this point in the history)
  • Loading branch information
goliaro committed Jul 27, 2023
1 parent 06b1e8b commit 02ae7ce
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 7 deletions.
2 changes: 1 addition & 1 deletion include/flexflow/ops/argmax.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ class ArgMax : public Op {
void forward(FFModel const &) override;
void backward(FFModel const &) override;
Legion::FutureMap inference(FFModel const &,
BatchConfig const &,
BatchConfigFuture const &,
std::vector<ParallelTensor> const &,
std::vector<ParallelTensor> const &,
MachineView const *mv = nullptr) override;
Expand Down
1 change: 0 additions & 1 deletion src/ops/arg_topk.cc
Original file line number Diff line number Diff line change
Expand Up @@ -302,7 +302,6 @@ InferenceResult
assert(regions.size() == 2);
assert(task->regions.size() == 2);
// const ArgTopK* topk = (const ArgTopK*) task->args;
// BatchConfig const *bc = (BatchConfig *)task->args;
BatchConfig const *bc = BatchConfig::from_future(task->futures[0]);
if (bc->num_tokens == 0) {
// Directly return for empty batch config
Expand Down
22 changes: 17 additions & 5 deletions src/ops/argmax.cc
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,7 @@ void ArgMax::forward(FFModel const &ff) {
}

FutureMap ArgMax::inference(FFModel const &ff,
BatchConfig const &bc,
BatchConfigFuture const &bc,
std::vector<ParallelTensor> const &batch_inputs,
std::vector<ParallelTensor> const &batch_outputs,
MachineView const *mv) {
Expand All @@ -270,12 +270,13 @@ FutureMap ArgMax::inference(FFModel const &ff,
if (beam_search) {
IndexLauncher launcher(ARGMAX_BEAM_INF_TASK_ID,
parallel_is,
TaskArgument(&bc, sizeof(BatchConfig)),
TaskArgument(nullptr, 0),
argmap,
Predicate::TRUE_PRED,
false /*must*/,
0 /*mapper_id*/,
machine_view_hash);
launcher.add_future(bc);
launcher.add_region_requirement(RegionRequirement(batch_inputs[0]->part,
0 /*projection id*/,
READ_WRITE,
Expand Down Expand Up @@ -307,12 +308,13 @@ FutureMap ArgMax::inference(FFModel const &ff,
} else {
IndexLauncher launcher(ARGMAX_NORM_INF_TASK_ID,
parallel_is,
TaskArgument(&bc, sizeof(BatchConfig)),
TaskArgument(nullptr, 0),
argmap,
Predicate::TRUE_PRED,
false /*must*/,
0 /*mapper_id*/,
machine_view_hash);
launcher.add_future(bc);
launcher.add_region_requirement(RegionRequirement(batch_inputs[0]->part,
0 /*projection id*/,
READ_WRITE,
Expand Down Expand Up @@ -344,7 +346,12 @@ BeamInferenceResult
Runtime *runtime) {
assert(regions.size() == 4);
assert(task->regions.size() == 4);
BatchConfig const *bc = (BatchConfig *)task->args;
BatchConfig const *bc = BatchConfig::from_future(task->futures[0]);
if (bc->num_tokens == 0) {
// Directly return for empty batch config
BeamInferenceResult ir;
return ir;
}
ArgMaxMeta const *m = *((ArgMaxMeta **)task->local_args);

GenericTensorAccessorW input = helperGetGenericTensorAccessorRW(
Expand Down Expand Up @@ -378,8 +385,13 @@ InferenceResult
Runtime *runtime) {
assert(regions.size() == 3);
assert(task->regions.size() == 3);
BatchConfig const *bc = (BatchConfig *)task->args;
ArgMaxMeta const *m = *((ArgMaxMeta **)task->local_args);
BatchConfig const *bc = BatchConfig::from_future(task->futures[0]);
if (bc->num_tokens == 0) {
// Directly return for empty batch config
InferenceResult ir;
return ir;
}

GenericTensorAccessorW input = helperGetGenericTensorAccessorRW(
m->input_type[0], regions[0], task->regions[0], FID_DATA, ctx, runtime);
Expand Down

0 comments on commit 02ae7ce

Please sign in to comment.