Skip to content

Commit

Permalink
Fix issue with no of prompt less than FBS (quic#132)
Browse files Browse the repository at this point in the history
Signed-off-by: Rishin Raj <rishinr@qti.qualcomm.com>
  • Loading branch information
quic-rishinr authored Sep 26, 2024
1 parent fbc7b8f commit d98b6f9
Showing 1 changed file with 22 additions and 9 deletions.
31 changes: 22 additions & 9 deletions QEfficient/generation/text_generation_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,16 +170,29 @@ def get_input_prompts(prompt: str, prompts_txt_file_path: str) -> List[str]:


def fix_prompts(prompt: List[str], batch_size: int, full_batch_size: int = None):
    """
    Adjusts the list of prompts to match the required batch size.

    The effective size is ``full_batch_size`` when it is given, otherwise ``batch_size``.
    If fewer prompts than the effective size are supplied, the prompts are repeated
    (in order) and truncated to exactly that size. Otherwise, when ``full_batch_size``
    is not given, trailing prompts that do not fill a complete batch are dropped.
    When ``full_batch_size`` is given and enough prompts are supplied, the list is
    returned unchanged.

    ``Mandatory`` Args:
        :prompt (List[str]): List of input prompts. Must not be empty.
        :batch_size (int): The batch size to process at a time.

    ``Optional`` Args:
        :full_batch_size (Optional[int]): The full batch size if different from batch_size.

    Returns:
        :List[str]: Adjusted list of prompts.

    Raises:
        :ValueError: If ``prompt`` is empty, since there is nothing to repeat.
    """
    # An empty list cannot be repeated up to a batch; fail with a clear message
    # instead of a ZeroDivisionError from the repeat-count computation below.
    if not prompt:
        raise ValueError("At least one prompt is required to fill a batch")

    exec_batch_size = full_batch_size if full_batch_size is not None else batch_size

    if len(prompt) < exec_batch_size:
        logger.warning("Number of prompts are less than batch size/full batch size, repeating to required batch size")
        # Ceiling division: the smallest number of copies that covers exec_batch_size.
        repeats = -(-exec_batch_size // len(prompt))
        prompt = (prompt * repeats)[:exec_batch_size]
    elif full_batch_size is None and len(prompt) % batch_size != 0:
        logger.warning(
            "Number of prompts are not multiple of batch size, dropping last incomplete batch from given input prompts"
        )
        # Keep only the prompts that form complete batches.
        prompt = prompt[: batch_size * (len(prompt) // batch_size)]

    return prompt

Expand Down

0 comments on commit d98b6f9

Please sign in to comment.