diff --git a/include/flexflow/ops/beam_topk.h b/include/flexflow/ops/beam_topk.h index 76404bfb6d..57ab5c1074 100644 --- a/include/flexflow/ops/beam_topk.h +++ b/include/flexflow/ops/beam_topk.h @@ -82,7 +82,7 @@ class BeamTopK : public Op { float *output_ptr, int *indices_ptr, int *parent_ptr, - size_t batch_size, + int batch_size, int length, bool sorted, ffStream_t stream); @@ -92,7 +92,7 @@ class BeamTopK : public Op { float *output_ptr, int *indices_ptr, int *parent_ptr, - size_t batch_size, + int batch_size, int length, bool sorted); Params get_params() const; diff --git a/src/ops/arg_topk.cc b/src/ops/arg_topk.cc index a604c016d2..c1bbb65f1e 100644 --- a/src/ops/arg_topk.cc +++ b/src/ops/arg_topk.cc @@ -311,9 +311,6 @@ InferenceResult int batch_size = bc->num_active_tokens(); ArgTopK::forward_kernel_wrapper(m, input, indices, batch_size); - int length = input.domain.hi()[0] - input.domain.lo()[0] + 1; - batch_size = input.domain.get_volume() / length; - InferenceResult ir; download_tensor( indices.get_int32_ptr(), ir.token_ids, batch_size); diff --git a/src/ops/beam_topk.cc b/src/ops/beam_topk.cc index db507c1729..0920105acc 100644 --- a/src/ops/beam_topk.cc +++ b/src/ops/beam_topk.cc @@ -379,12 +379,9 @@ BeamInferenceResult // total token nums size_t tokens_per_request = in1_domain.hi()[1] - in1_domain.lo()[1] + 1; - size_t batch_size = in1_domain.get_volume() / length; - + int batch_size = bc->num_active_tokens(); // std::cout << "beam search topk params: " << length << ", " << k << ", " // << batch_size << "\n"; - assert(out2_domain.get_volume() / k == batch_size); - // std::vector beam_width; // std::unordered_map sub_requests = bc->sub_requests; // for (int i = 0; i < bc->MAX_NUM_REQUESTS; i++) { diff --git a/src/ops/beam_topk.cpp b/src/ops/beam_topk.cpp index 1817eae4da..248ab188da 100644 --- a/src/ops/beam_topk.cpp +++ b/src/ops/beam_topk.cpp @@ -479,7 +479,7 @@ void BeamTopK::forward_kernel(BeamTopKMeta const *m, float *output_ptr, int *indices_ptr, int *parent_ptr, - size_t batch_size, + int batch_size, int length, bool sorted, hipStream_t stream) { @@ -630,7 +630,7 @@ void BeamTopK::forward_kernel_wrapper(BeamTopKMeta const *m, float *output_ptr, int *indices_ptr, int *parent_ptr, - size_t batch_size, + int batch_size, int length, bool sorted) { hipStream_t stream; diff --git a/src/ops/beam_topk.cu b/src/ops/beam_topk.cu index 9a5cd86486..ceddb55f2d 100644 --- a/src/ops/beam_topk.cu +++ b/src/ops/beam_topk.cu @@ -511,7 +511,7 @@ void BeamTopK::forward_kernel(BeamTopKMeta const *m, float *output_ptr, int *indices_ptr, int *parent_ptr, - size_t batch_size, + int batch_size, int length, bool sorted, cudaStream_t stream) { @@ -662,7 +662,7 @@ void BeamTopK::forward_kernel_wrapper(BeamTopKMeta const *m, float *output_ptr, int *indices_ptr, int *parent_ptr, - size_t batch_size, + int batch_size, int length, bool sorted) { cudaStream_t stream;