diff --git a/include/flexflow/ops/argmax.h b/include/flexflow/ops/argmax.h
index d6d15f2a3c..709861f51c 100644
--- a/include/flexflow/ops/argmax.h
+++ b/include/flexflow/ops/argmax.h
@@ -47,7 +47,7 @@ class ArgMax : public Op {
   void forward(FFModel const &) override;
   void backward(FFModel const &) override;
   Legion::FutureMap inference(FFModel const &,
-                              BatchConfig const &,
+                              BatchConfigFuture const &,
                               std::vector<ParallelTensor> const &,
                               std::vector<ParallelTensor> const &,
                               MachineView const *mv = nullptr) override;
diff --git a/src/ops/arg_topk.cc b/src/ops/arg_topk.cc
index b30114830d..b877a9f96d 100644
--- a/src/ops/arg_topk.cc
+++ b/src/ops/arg_topk.cc
@@ -302,7 +302,6 @@ InferenceResult
   assert(regions.size() == 2);
   assert(task->regions.size() == 2);
   // const ArgTopK* topk = (const ArgTopK*) task->args;
-  // BatchConfig const *bc = (BatchConfig *)task->args;
   BatchConfig const *bc = BatchConfig::from_future(task->futures[0]);
   if (bc->num_tokens == 0) {
     // Directly return for empty batch config
diff --git a/src/ops/argmax.cc b/src/ops/argmax.cc
index 754337448e..8598a71d50 100644
--- a/src/ops/argmax.cc
+++ b/src/ops/argmax.cc
@@ -254,7 +254,7 @@ void ArgMax::forward(FFModel const &ff) {
 }
 
 FutureMap ArgMax::inference(FFModel const &ff,
-                            BatchConfig const &bc,
+                            BatchConfigFuture const &bc,
                             std::vector<ParallelTensor> const &batch_inputs,
                             std::vector<ParallelTensor> const &batch_outputs,
                             MachineView const *mv) {
@@ -270,12 +270,13 @@ FutureMap ArgMax::inference(FFModel const &ff,
   if (beam_search) {
     IndexLauncher launcher(ARGMAX_BEAM_INF_TASK_ID,
                            parallel_is,
-                           TaskArgument(&bc, sizeof(BatchConfig)),
+                           TaskArgument(nullptr, 0),
                            argmap,
                            Predicate::TRUE_PRED,
                            false /*must*/,
                            0 /*mapper_id*/,
                            machine_view_hash);
+    launcher.add_future(bc);
     launcher.add_region_requirement(RegionRequirement(batch_inputs[0]->part,
                                                       0 /*projection id*/,
                                                       READ_WRITE,
@@ -307,12 +308,13 @@ FutureMap ArgMax::inference(FFModel const &ff,
   } else {
     IndexLauncher launcher(ARGMAX_NORM_INF_TASK_ID,
                            parallel_is,
-                           TaskArgument(&bc, sizeof(BatchConfig)),
+                           TaskArgument(nullptr, 0),
                            argmap,
                            Predicate::TRUE_PRED,
                            false /*must*/,
                            0 /*mapper_id*/,
                            machine_view_hash);
+    launcher.add_future(bc);
     launcher.add_region_requirement(RegionRequirement(batch_inputs[0]->part,
                                                       0 /*projection id*/,
                                                       READ_WRITE,
@@ -344,7 +346,12 @@ BeamInferenceResult
                  Runtime *runtime) {
   assert(regions.size() == 4);
   assert(task->regions.size() == 4);
-  BatchConfig const *bc = (BatchConfig *)task->args;
+  BatchConfig const *bc = BatchConfig::from_future(task->futures[0]);
+  if (bc->num_tokens == 0) {
+    // Directly return for empty batch config
+    BeamInferenceResult ir;
+    return ir;
+  }
   ArgMaxMeta const *m = *((ArgMaxMeta **)task->local_args);
   GenericTensorAccessorW input = helperGetGenericTensorAccessorRW(
       m->input_type[0], regions[0], task->regions[0], FID_DATA, ctx, runtime);
@@ -378,8 +385,13 @@ InferenceResult
                  Runtime *runtime) {
   assert(regions.size() == 3);
   assert(task->regions.size() == 3);
-  BatchConfig const *bc = (BatchConfig *)task->args;
   ArgMaxMeta const *m = *((ArgMaxMeta **)task->local_args);
+  BatchConfig const *bc = BatchConfig::from_future(task->futures[0]);
+  if (bc->num_tokens == 0) {
+    // Directly return for empty batch config
+    InferenceResult ir;
+    return ir;
+  }
   GenericTensorAccessorW input = helperGetGenericTensorAccessorRW(
       m->input_type[0], regions[0], task->regions[0], FID_DATA, ctx, runtime);