Skip to content

Commit

Permalink
change batch_size to num_active_tokens
Browse files Browse the repository at this point in the history
  • Loading branch information
xinhaoc committed Jul 16, 2023
1 parent 53c5617 commit 5ba8c85
Show file tree
Hide file tree
Showing 5 changed files with 7 additions and 13 deletions.
4 changes: 2 additions & 2 deletions include/flexflow/ops/beam_topk.h
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ class BeamTopK : public Op {
float *output_ptr,
int *indices_ptr,
int *parent_ptr,
- size_t batch_size,
+ int batch_size,
int length,
bool sorted,
ffStream_t stream);
Expand All @@ -92,7 +92,7 @@ class BeamTopK : public Op {
float *output_ptr,
int *indices_ptr,
int *parent_ptr,
- size_t batch_size,
+ int batch_size,
int length,
bool sorted);
Params get_params() const;
Expand Down
3 changes: 0 additions & 3 deletions src/ops/arg_topk.cc
Original file line number Diff line number Diff line change
Expand Up @@ -311,9 +311,6 @@ InferenceResult
int batch_size = bc->num_active_tokens();
ArgTopK::forward_kernel_wrapper(m, input, indices, batch_size);

- int length = input.domain.hi()[0] - input.domain.lo()[0] + 1;
- batch_size = input.domain.get_volume() / length;
-
InferenceResult ir;
download_tensor<BatchConfig::TokenId>(
indices.get_int32_ptr(), ir.token_ids, batch_size);
Expand Down
5 changes: 1 addition & 4 deletions src/ops/beam_topk.cc
Original file line number Diff line number Diff line change
Expand Up @@ -379,12 +379,9 @@ BeamInferenceResult

// total token nums
size_t tokens_per_request = in1_domain.hi()[1] - in1_domain.lo()[1] + 1;
- size_t batch_size = in1_domain.get_volume() / length;
-
+ int batch_size = bc->num_active_tokens();
// std::cout << "beam search topk params: " << length << ", " << k << ", "
// << batch_size << "\n";
- assert(out2_domain.get_volume() / k == batch_size);
-
// std::vector<int> beam_width;
// std::unordered_map<size_t, int> sub_requests = bc->sub_requests;
// for (int i = 0; i < bc->MAX_NUM_REQUESTS; i++) {
Expand Down
4 changes: 2 additions & 2 deletions src/ops/beam_topk.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -479,7 +479,7 @@ void BeamTopK::forward_kernel(BeamTopKMeta const *m,
float *output_ptr,
int *indices_ptr,
int *parent_ptr,
- size_t batch_size,
+ int batch_size,
int length,
bool sorted,
hipStream_t stream) {
Expand Down Expand Up @@ -630,7 +630,7 @@ void BeamTopK::forward_kernel_wrapper(BeamTopKMeta const *m,
float *output_ptr,
int *indices_ptr,
int *parent_ptr,
- size_t batch_size,
+ int batch_size,
int length,
bool sorted) {
hipStream_t stream;
Expand Down
4 changes: 2 additions & 2 deletions src/ops/beam_topk.cu
Original file line number Diff line number Diff line change
Expand Up @@ -511,7 +511,7 @@ void BeamTopK::forward_kernel(BeamTopKMeta const *m,
float *output_ptr,
int *indices_ptr,
int *parent_ptr,
- size_t batch_size,
+ int batch_size,
int length,
bool sorted,
cudaStream_t stream) {
Expand Down Expand Up @@ -662,7 +662,7 @@ void BeamTopK::forward_kernel_wrapper(BeamTopKMeta const *m,
float *output_ptr,
int *indices_ptr,
int *parent_ptr,
- size_t batch_size,
+ int batch_size,
int length,
bool sorted) {
cudaStream_t stream;
Expand Down

0 comments on commit 5ba8c85

Please sign in to comment.