From d98b6f9c541061d448a8acb8e636a11a46240e72 Mon Sep 17 00:00:00 2001
From: Rishin Raj
Date: Thu, 26 Sep 2024 23:17:06 +0530
Subject: [PATCH] Fix issue with no of prompt less than FBS (#132)

Signed-off-by: Rishin Raj
---
 .../generation/text_generation_inference.py | 31 +++++++++++++------
 1 file changed, 22 insertions(+), 9 deletions(-)

diff --git a/QEfficient/generation/text_generation_inference.py b/QEfficient/generation/text_generation_inference.py
index 6a10444a..75281800 100755
--- a/QEfficient/generation/text_generation_inference.py
+++ b/QEfficient/generation/text_generation_inference.py
@@ -170,16 +170,29 @@ def get_input_prompts(prompt: str, prompts_txt_file_path: str) -> List[str]:
 
 
 def fix_prompts(prompt: List[str], batch_size: int, full_batch_size: int = None):
-    if len(prompt) < batch_size or (full_batch_size is not None and len(prompt) < full_batch_size):
+    """
+    Adjusts the list of prompts to match the required batch size.
+
+    ``Mandatory`` Args:
+        prompt (List[str]): List of input prompts.
+        batch_size (int): The batch size to process at a time.
+
+    ``Optional`` Args:
+        full_batch_size (Optional[int]): The full batch size if different from batch_size.
+
+    Returns:
+        List[str]: Adjusted list of prompts.
+    """
+    exec_batch_size = full_batch_size if full_batch_size is not None else batch_size
+
+    if len(prompt) < exec_batch_size:
         logger.warning("Number of prompts are less than batch size/full batch size, repeating to required batch size")
-        prompt = prompt * -(batch_size // -len(prompt))  # Repeat prompt to required size
-        prompt = prompt[:batch_size]  # Truncate prompts to required size
-    elif not full_batch_size:
-        if (len(prompt) % batch_size) > 0:
-            logger.warning(
-                "Number of prompts are not multiple of batch size, dropping last incomplete batch from given input prompts"
-            )
-            prompt = prompt[: batch_size * (len(prompt) // batch_size)]
+        prompt = (prompt * (exec_batch_size // len(prompt) + 1))[:exec_batch_size]
+    elif full_batch_size is None and len(prompt) % batch_size != 0:
+        logger.warning(
+            "Number of prompts are not multiple of batch size, dropping last incomplete batch from given input prompts"
+        )
+        prompt = prompt[: batch_size * (len(prompt) // batch_size)]
     return prompt
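
Review note: a minimal, self-contained sketch of the patched fix_prompts behaviour, useful for sanity-checking the new repeat/truncate logic. The function body mirrors the patch; the surrounding imports, the logging setup, and the example calls at the bottom are illustrative assumptions, not code from the repository.

    import logging
    from typing import List, Optional

    logger = logging.getLogger(__name__)


    def fix_prompts(prompt: List[str], batch_size: int, full_batch_size: Optional[int] = None) -> List[str]:
        # full_batch_size (FBS, the continuous-batching size) takes precedence when given.
        exec_batch_size = full_batch_size if full_batch_size is not None else batch_size

        if len(prompt) < exec_batch_size:
            logger.warning("Number of prompts are less than batch size/full batch size, repeating to required batch size")
            # Repeat the prompt list until it covers exec_batch_size, then truncate to exactly that size.
            prompt = (prompt * (exec_batch_size // len(prompt) + 1))[:exec_batch_size]
        elif full_batch_size is None and len(prompt) % batch_size != 0:
            logger.warning("Number of prompts are not multiple of batch size, dropping last incomplete batch from given input prompts")
            # Drop the trailing incomplete batch.
            prompt = prompt[: batch_size * (len(prompt) // batch_size)]
        return prompt


    # Fewer prompts than FBS: repeated up to full_batch_size (the case fixed by #132;
    # the old code truncated to batch_size instead).
    assert fix_prompts(["a", "b"], batch_size=1, full_batch_size=5) == ["a", "b", "a", "b", "a"]
    # Fewer prompts than batch size, no FBS: repeated up to batch_size.
    assert fix_prompts(["a"], batch_size=2) == ["a", "a"]
    # Not a multiple of batch size, no FBS: last incomplete batch dropped.
    assert fix_prompts(["a", "b", "c"], batch_size=2) == ["a", "b"]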