Continuous batching (quic#73)
* - Added Continuous batching feature
- Refactored text generation module

Signed-off-by: quic-rishinr <quic_rishinr@quicinc.com>
Signed-off-by: Rishin Raj <quic_rishinr@quicinc.com>

* Added cherry-picked continuous batching changes.

Signed-off-by: Rishin Raj <rishinr@qti.qualcomm.com>
Signed-off-by: Rishin Raj <quic_rishinr@quicinc.com>

* Updated the assert condition for bs > 1 and full batch size > 1.
Fixed QPC path creation for non-CB execution.
Added a condition to check that CB is enabled only for supported architectures.
Added formatting changes.

Signed-off-by: quic-rishinr <quic_rishinr@quicinc.com>
Signed-off-by: Rishin Raj <quic_rishinr@quicinc.com>
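
A minimal sketch of the kind of check this entry describes; the helper name and message are illustrative, not the repository's actual code:

    from typing import Optional

    def validate_batch_args(batch_size: int, full_batch_size: Optional[int]) -> None:
        # batch_size and full_batch_size are mutually exclusive knobs:
        # continuous batching manages concurrency through full_batch_size,
        # so a regular batch size > 1 is rejected whenever it is set.
        if full_batch_size is not None:
            assert batch_size == 1, "batch_size must be 1 when full_batch_size is set"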

* update full-batch-size args

Signed-off-by: vbaddi <quic_vbaddi@quicinc.com>
Signed-off-by: Rishin Raj <quic_rishinr@quicinc.com>

* update unique cache dir to include arg naming

Signed-off-by: vbaddi <quic_vbaddi@quicinc.com>
Signed-off-by: Rishin Raj <quic_rishinr@quicinc.com>

* include env variable in QEff_MODELS_DIR to allow override

Signed-off-by: vbaddi <quic_vbaddi@quicinc.com>
Signed-off-by: Rishin Raj <quic_rishinr@quicinc.com>

* Small bug fix

Signed-off-by: Rishin Raj <rishinr@qti.qualcomm.com>
Signed-off-by: Rishin Raj <quic_rishinr@quicinc.com>

* Fixed output issue with FBS > 1.
Cherry-picked the support for Mixtral.
Added CB support for Starcoder.

Signed-off-by: Rishin Raj <rishinr@qti.qualcomm.com>
Signed-off-by: Rishin Raj <quic_rishinr@quicinc.com>

* Lint & Format (quic#53)

* Lint & Format

- Added linting and formatting github actions
- Formatted entire codebase
- Fixed linter errors
- Removed `# noqa` with fix

Signed-off-by: Ilango Rajagopal <quic_irajagop@quicinc.com>

* Split test config into multiple lines

Signed-off-by: Ilango Rajagopal <quic_irajagop@quicinc.com>

* Fix external repo for workflow

Signed-off-by: Ilango Rajagopal <quic_irajagop@quicinc.com>

* Format newly added files

Signed-off-by: Ilango Rajagopal <quic_irajagop@quicinc.com>

---------

Signed-off-by: Ilango Rajagopal <quic_irajagop@quicinc.com>
Signed-off-by: Rishin Raj <quic_rishinr@quicinc.com>

* enabling `export_and_compile` for `QEFFAutoModelForCausalLM` (quic#48)

* enabling export_and_compile

Signed-off-by: Onkar Chougule <quic_ochougul@quicinc.com>

* Cleaned API usage, integrated export into compile, addressed comments

Signed-off-by: Onkar Chougule <quic_ochougul@quicinc.com>

* removed src, simplified automodelclass

Signed-off-by: Onkar Chougule <quic_ochougul@quicinc.com>

* added base directory in place of src

Signed-off-by: Onkar Chougule <quic_ochougul@quicinc.com>

* replaced src/auto with transformers/models/modeling_auto

Signed-off-by: Onkar Chougule <quic_ochougul@quicinc.com>

* ran linter and formatter

Signed-off-by: Onkar Chougule <quic_ochougul@quicinc.com>

* removed commented code

Signed-off-by: Onkar Chougule <quic_ochougul@quicinc.com>

* fixed typos

Signed-off-by: Onkar Chougule <quic_ochougul@quicinc.com>

* fixed testing script

Signed-off-by: Onkar Chougule <quic_ochougul@quicinc.com>

* removed unittest dependency; using pytest only in all tests

Signed-off-by: Onkar Chougule <quic_ochougul@quicinc.com>

* added test report for display in the Jenkins view

Signed-off-by: Onkar Chougule <quic_ochougul@quicinc.com>

* updated Jenkinsfile to capture test data in XML for the Jenkins view

Signed-off-by: Onkar Chougule <quic_ochougul@quicinc.com>

* fixed HL tests

Signed-off-by: Onkar Chougule <quic_ochougul@quicinc.com>

* fixed cloud tests

Signed-off-by: Onkar Chougule <quic_ochougul@quicinc.com>

* made ctx_len a default argument in the exec_kv function; fixed tests/cloud/test_infer

Signed-off-by: Onkar Chougule <quic_ochougul@quicinc.com>

* run ruff formatter

Signed-off-by: Onkar Chougule <quic_ochougul@quicinc.com>

* fixed type hint

Signed-off-by: Onkar Chougule <quic_ochougul@quicinc.com>

* moved use_cache assignment to __init__ so that models initialized via __init__ will also have the flag set to True

Signed-off-by: Onkar Chougule <quic_ochougul@quicinc.com>

* ran ruff formatter

Signed-off-by: Onkar Chougule <quic_ochougul@quicinc.com>

---------

Signed-off-by: Onkar Chougule <quic_ochougul@quicinc.com>
Signed-off-by: Rishin Raj <quic_rishinr@quicinc.com>

* - Added Continuous batching feature
- Refactored text generation module

Signed-off-by: quic-rishinr <quic_rishinr@quicinc.com>
Signed-off-by: Rishin Raj <quic_rishinr@quicinc.com>

* Updated the assert condition for bs > 1 and full batch size > 1.
Fixed QPC path creation for non-CB execution.
Added a condition to check that CB is enabled only for supported architectures.
Added formatting changes.

Signed-off-by: quic-rishinr <quic_rishinr@quicinc.com>
Signed-off-by: Rishin Raj <quic_rishinr@quicinc.com>

* update unique cache dir to include arg naming

Signed-off-by: vbaddi <quic_vbaddi@quicinc.com>
Signed-off-by: Rishin Raj <quic_rishinr@quicinc.com>

* Small bug fix

Signed-off-by: Rishin Raj <rishinr@qti.qualcomm.com>
Signed-off-by: Rishin Raj <quic_rishinr@quicinc.com>

* rebased the code against mainline

Signed-off-by: Rishin Raj <rishinr@qti.qualcomm.com>
Signed-off-by: Rishin Raj <quic_rishinr@quicinc.com>

* Created a separate file for scatter and gather CB ops, adhering to PR 55 (see the sketch below)

Signed-off-by: Rishin Raj <quic_rishinr@quicinc.com>
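
A minimal PyTorch illustration of the scatter/gather pattern such CB ops implement; shapes and names here are illustrative, not the repository's actual ops:

    import torch

    # Each active sequence owns one slot (a "batch index") in a shared KV cache:
    # decode scatters the step's new key into its slot, and attention gathers
    # the slot's full history back out.
    full_batch_size, n_heads, ctx_len, head_dim = 4, 2, 8, 16
    kv_cache = torch.zeros(full_batch_size, n_heads, ctx_len, head_dim)

    batch_index = torch.tensor([2])  # slot owned by the sequence being decoded
    position_id = torch.tensor([5])  # its current write position
    new_key = torch.randn(1, n_heads, head_dim)

    # Scatter: write this step's key at (slot, position).
    kv_cache[batch_index, :, position_id] = new_key

    # Gather: read the slot's history for attention, shape (1, n_heads, ctx_len, head_dim).
    keys_for_attention = kv_cache[batch_index]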

* Formatted the code using linter

Signed-off-by: Rishin Raj <quic_rishinr@quicinc.com>

* removed runtime_args

Signed-off-by: Rishin Raj <quic_rishinr@quicinc.com>

* Added FBS flag in execute module

Signed-off-by: Rishin Raj <quic_rishinr@quicinc.com>

* Aligning HL tests with continuous batching

Signed-off-by: Abukhoyer Shaik <quic_abukhoye@quicinc.com>

* Removed cache path from the infer and export modules; updated default cache path in constants

Signed-off-by: Rishin Raj <quic_rishinr@quicinc.com>

* Rebased against main

Signed-off-by: Rishin Raj <quic_rishinr@quicinc.com>

* Lint and format

Signed-off-by: Rishin Raj <quic_rishinr@quicinc.com>

* Removed base_dir_name from export.py
Removed TODO from infer
Removed CB-specific scatter and gather op from cts_scatter_gather.py
Moved the CB model-architecture check into the export_for_cloud module and changed it to raise NotImplementedError
Commented out custom_opsets usage in export_onnx_model
Lint fix on conftest.py
Removed print statement from text_generation_inference.py

Signed-off-by: Rishin Raj <quic_rishinr@quicinc.com>

* Lint and format

Signed-off-by: Rishin Raj <quic_rishinr@quicinc.com>

* Adding test configs

Signed-off-by: Rishin Raj <quic_rishinr@quicinc.com>

* Added support for the PyTorch input handler; added support for fetching FBS from the QPC

Signed-off-by: Rishin Raj <quic_rishinr@quicinc.com>

* Added back check_and_assign_cache_dir in infer and export; reverted custom_opsets in export_utils; minor fix in text_generation_inference

Signed-off-by: Rishin Raj <quic_rishinr@quicinc.com>

* Linter formatting and minor bug fixes

Signed-off-by: Rishin Raj <quic_rishinr@quicinc.com>

* Addressed review comments and fixed the issue with total decode token calculation

Signed-off-by: Rishin Raj <quic_rishinr@quicinc.com>

* Linter and formatting

Signed-off-by: Rishin Raj <quic_rishinr@quicinc.com>

* rebased against mainline

Signed-off-by: Rishin Raj <quic_rishinr@quicinc.com>

* Lint formatting

Signed-off-by: Rishin Raj <quic_rishinr@quicinc.com>

* Making CI run (I)

Signed-off-by: Abukhoyer Shaik <quic_abukhoye@quicinc.com>

* Added "Transformed models and QPC storage" section to the README.
Removed Constants.CACHE_DIR.
Added FBS and BS check in the compiler helper.
Renamed “perfill time” print statement to “Average prefill time”.
Added CB transform class.
Updated modeling file to adhere to CBTransform changes.
Renamed QEff cache folder from qeff_models to qeff_cache.
Other review changes.

Signed-off-by: Rishin Raj <quic_rishinr@quicinc.com>

* Ran linter and added some missing changes

Signed-off-by: Rishin Raj <quic_rishinr@quicinc.com>

* Added CI changes and some missing changes

Signed-off-by: Rishin Raj <quic_rishinr@quicinc.com>

* Updated logic for initializing transform classes for PyTorch transforms

Signed-off-by: Rishin Raj <quic_rishinr@quicinc.com>

* Rebased and updated docstring

Signed-off-by: Rishin Raj <quic_rishinr@quicinc.com>

* Resolved testing bugs

Signed-off-by: Abukhoyer Shaik <quic_abukhoye@quicinc.com>

* adding some models to the JSON config file

Signed-off-by: Abukhoyer Shaik <quic_abukhoye@quicinc.com>

* Added streamer for CB

Signed-off-by: Rishin Raj <quic_rishinr@quicinc.com>

* Updated generated ID len

Signed-off-by: Rishin Raj <quic_rishinr@quicinc.com>

* removed streamer for CB

Signed-off-by: Rishin Raj <quic_rishinr@quicinc.com>

* Changed test configs

Signed-off-by: Abukhoyer Shaik <quic_abukhoye@quicinc.com>

* Lint format fix

Signed-off-by: Rishin Raj <quic_rishinr@quicinc.com>

* updated generated output print logic

Signed-off-by: Rishin Raj <quic_rishinr@quicinc.com>

* added extra line between full batch size prints

Signed-off-by: Onkar Chougule <quic_ochougul@quicinc.com>

* removed commented code

Signed-off-by: Onkar Chougule <quic_ochougul@quicinc.com>

* removed commented lines

Signed-off-by: Onkar Chougule <quic_ochougul@quicinc.com>

* added infer docstring back

Signed-off-by: Onkar Chougule <quic_ochougul@quicinc.com>

---------

Signed-off-by: quic-rishinr <quic_rishinr@quicinc.com>
Signed-off-by: Rishin Raj <quic_rishinr@quicinc.com>
Signed-off-by: Rishin Raj <rishinr@qti.qualcomm.com>
Signed-off-by: vbaddi <quic_vbaddi@quicinc.com>
Signed-off-by: Ilango Rajagopal <quic_irajagop@quicinc.com>
Signed-off-by: Onkar Chougule <quic_ochougul@quicinc.com>
Signed-off-by: Abukhoyer Shaik <quic_abukhoye@quicinc.com>
Co-authored-by: Rishin Raj <rishinr@qti.qualcomm.com>
Co-authored-by: Vinayak Baddi <quic_vbaddi@quicinc.com>
Co-authored-by: Ilango Rajagopal <quic_irajagop@quicinc.com>
Co-authored-by: Onkar Chougule <168134249+ochougul@users.noreply.github.com>
Co-authored-by: Abukhoyer Shaik <quic_abukhoye@quicinc.com>
Co-authored-by: Onkar Chougule <quic_ochougul@quicinc.com>
7 people authored Aug 30, 2024
1 parent 643bb2c commit 614c4cd
Showing 33 changed files with 1,914 additions and 832 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -88,3 +88,4 @@ cython_debug/
# Local Files
cache_dir
qeff_models
.vscode/*
8 changes: 7 additions & 1 deletion QEfficient/cloud/compile.py
@@ -69,7 +69,13 @@
    default=-1,
    help=" Effort level to reduce the on-chip memory",
)

parser.add_argument(
    "--full_batch_size",
    "--full-batch-size",
    type=int,
    default=None,
    help="Set full batch size to enable continuous batching mode, default is None",
)
# FIXME(ochougul): Allow extra compilation arguments
args = parser.parse_args()
QEfficient.compile(**vars(args))
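
Both spellings map to a single full_batch_size destination, so QEfficient.compile(**vars(args)) receives the value either way; a quick standalone check of that argparse behavior (illustrative):

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--full_batch_size", "--full-batch-size", type=int, default=None)
    print(parser.parse_args(["--full-batch-size", "4"]))  # Namespace(full_batch_size=4)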
211 changes: 110 additions & 101 deletions QEfficient/cloud/execute.py
@@ -1,101 +1,110 @@
# -----------------------------------------------------------------------------
#
# Copyright (c) 2024 Qualcomm Innovation Center, Inc. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# -----------------------------------------------------------------------------

import argparse
from typing import List, Optional

from QEfficient.generation.text_generation_inference import cloud_ai_100_exec_kv
from QEfficient.utils import load_hf_tokenizer
from QEfficient.utils.constants import Constants


def main(
    model_name: str,
    qpc_path: str,
    device_group: Optional[List[int]] = None,
    local_model_dir: Optional[str] = None,
    prompt: Optional[str] = None,  # type: ignore
    prompts_txt_file_path: Optional[str] = None,
    generation_len: Optional[int] = None,
    cache_dir: Optional[str] = Constants.CACHE_DIR,
    hf_token: Optional[str] = None,
) -> None:
    """
    Helper function used by execute CLI app to run the Model on ``Cloud AI 100`` Platform.

    ``Mandatory`` Args:
        :model_name (str): Hugging Face Model Card name, Example: ``gpt2``.
        :qpc_path (str): Path to the generated binary after compilation.
    ``Optional`` Args:
        :device_group (List[int]): Device Ids to be used for compilation. If len(device_group) > 1, multi-card setup is enabled. ``Defaults to None.``
        :local_model_dir (str): Path to custom model weights and config files. ``Defaults to None.``
        :prompt (str): Sample prompt for the model text generation. ``Defaults to None.``
        :prompts_txt_file_path (str): Path to txt file for multiple input prompts. ``Defaults to None.``
        :generation_len (int): Number of tokens to be generated. ``Defaults to None.``
        :cache_dir (str): Cache dir where downloaded HuggingFace files are stored. ``Defaults to Constants.CACHE_DIR.``
        :hf_token (str): HuggingFace login token to access private repos. ``Defaults to None.``

    .. code-block:: bash

        python -m QEfficient.cloud.execute OPTIONS
    """
    tokenizer = load_hf_tokenizer(
        pretrained_model_name_or_path=(local_model_dir if local_model_dir else model_name),
        cache_dir=cache_dir,
        hf_token=hf_token,
    )

    # Execute
    cloud_ai_100_exec_kv(
        tokenizer=tokenizer,
        qpc_path=qpc_path,
        device_id=device_group,
        prompt=prompt,
        prompts_txt_file_path=prompts_txt_file_path,
        generation_len=generation_len,
    )


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Execution script.")
    parser.add_argument(
        "--model_name", "--model-name", required=False, type=str, help="HF model card name for tokenizing the inputs"
    )
    parser.add_argument("--qpc_path", "--qpc-path", required=True, help="Path to generated QPC")
    parser.add_argument(
        "--device_group",
        "--device-group",
        type=lambda device_ids: [int(x) for x in device_ids.strip("[]").split(",")],
        help="Cloud AI 100 device ids (comma-separated) e.g. [0]",
    )
    parser.add_argument(
        "--prompt",
        type=lambda prompt: prompt.split("|"),
        help="Input prompt, if executing for batch size>1, pass input prompts in single string but separate with pipe (|) symbol",
    )
    parser.add_argument(
        "--prompts_txt_file_path",
        "--prompts-txt-file-path",
        type=str,
        help="File path for taking input prompts from txt file, sample prompts.txt file present in examples folder",
    )
    parser.add_argument("--generation_len", "--generation-len", type=int, help="Number of tokens to generate")
    parser.add_argument(
        "--local-model-dir", "--local_model_dir", required=False, help="Path to custom model weights and config files"
    )
    parser.add_argument(
        "--cache-dir",
        "--cache_dir",
        default=Constants.CACHE_DIR,
        required=False,
        help="Cache dir to store HF Downloads",
    )
    parser.add_argument(
        "--hf-token", "--hf_token", default=None, type=str, required=False, help="HF token id for private HF models"
    )
    args = parser.parse_args()
    main(**args.__dict__)
# -----------------------------------------------------------------------------
#
# Copyright (c) 2024 Qualcomm Innovation Center, Inc. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# -----------------------------------------------------------------------------

import argparse
from typing import List, Optional

from QEfficient.generation.text_generation_inference import cloud_ai_100_exec_kv
from QEfficient.utils import load_hf_tokenizer


def main(
    model_name: str,
    qpc_path: str,
    device_group: List[int] = None,
    local_model_dir: Optional[str] = None,
    prompt: Optional[str] = None,  # type: ignore
    prompts_txt_file_path: Optional[str] = None,
    generation_len: Optional[int] = None,
    cache_dir: Optional[str] = None,
    hf_token: Optional[str] = None,
    full_batch_size: Optional[int] = None,
):
    """
    Helper function used by execute CLI app to run the Model on ``Cloud AI 100`` Platform.

    ``Mandatory`` Args:
        :model_name (str): Hugging Face Model Card name, Example: ``gpt2``.
        :qpc_path (str): Path to the generated binary after compilation.
    ``Optional`` Args:
        :device_group (List[int]): Device Ids to be used for compilation. If len(device_group) > 1, multi-card setup is enabled. ``Defaults to None.``
        :local_model_dir (str): Path to custom model weights and config files. ``Defaults to None.``
        :prompt (str): Sample prompt for the model text generation. ``Defaults to None.``
        :prompts_txt_file_path (str): Path to txt file for multiple input prompts. ``Defaults to None.``
        :generation_len (int): Number of tokens to be generated. ``Defaults to None.``
        :cache_dir (str): Cache dir where downloaded HuggingFace files are stored. ``Defaults to None.``
        :hf_token (str): HuggingFace login token to access private repos. ``Defaults to None.``
        :full_batch_size (int): Set full batch size to enable continuous batching mode. ``Defaults to None.``

    .. code-block:: bash

        python -m QEfficient.cloud.execute OPTIONS
    """
    tokenizer = load_hf_tokenizer(
        pretrained_model_name_or_path=(local_model_dir if local_model_dir else model_name),
        cache_dir=cache_dir,
        hf_token=hf_token,
    )

    # Execute
    cloud_ai_100_exec_kv(
        tokenizer=tokenizer,
        qpc_path=qpc_path,
        device_id=device_group,
        prompt=prompt,
        prompts_txt_file_path=prompts_txt_file_path,
        generation_len=generation_len,
        full_batch_size=full_batch_size,
    )


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Execution script.")
    parser.add_argument(
        "--model_name", "--model-name", required=False, type=str, help="HF model card name for tokenizing the inputs"
    )
    parser.add_argument("--qpc_path", "--qpc-path", required=True, help="Path to generated QPC")
    parser.add_argument(
        "--device_group",
        "--device-group",
        type=lambda device_ids: [int(x) for x in device_ids.strip("[]").split(",")],
        help="Cloud AI 100 device ids (comma-separated) e.g. [0]",
    )
    parser.add_argument(
        "--prompt",
        type=lambda prompt: prompt.split("|"),
        help="Input prompt, if executing for batch size>1, pass input prompts in single string but separate with pipe (|) symbol",
    )
    parser.add_argument(
        "--prompts_txt_file_path",
        "--prompts-txt-file-path",
        type=str,
        help="File path for taking input prompts from txt file, sample prompts.txt file present in examples folder",
    )
    parser.add_argument("--generation_len", "--generation-len", type=int, help="Number of tokens to generate")
    parser.add_argument(
        "--local-model-dir", "--local_model_dir", required=False, help="Path to custom model weights and config files"
    )
    parser.add_argument(
        "--cache-dir",
        "--cache_dir",
        default=None,
        required=False,
        help="Cache dir to store HF Downloads",
    )
    parser.add_argument(
        "--full_batch_size",
        "--full-batch-size",
        type=int,
        default=None,
        help="Set full batch size to enable continuous batching mode, default is None",
    )
    parser.add_argument(
        "--hf-token", "--hf_token", default=None, type=str, required=False, help="HF token id for private HF models"
    )
    args = parser.parse_args()
    main(**args.__dict__)
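
A hedged example of driving this entry point programmatically rather than via the CLI; the QPC path is a placeholder, and the prompt list mirrors what the CLI's pipe-splitting produces:

    from QEfficient.cloud.execute import main

    main(
        model_name="gpt2",
        qpc_path="qeff_cache/gpt2/qpcs",         # placeholder path to a compiled QPC
        device_group=[0],
        prompt=["My name is", "The sun rises"],  # CLI form: "My name is|The sun rises"
        generation_len=32,
        full_batch_size=4,                       # continuous batching with 4 slots
    )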
22 changes: 20 additions & 2 deletions QEfficient/cloud/export.py
@@ -25,6 +25,7 @@ def get_onnx_model_path(
    tokenizer: Optional[Union[PreTrainedTokenizerFast, PreTrainedTokenizer]] = None,
    hf_token: Optional[str] = None,
    local_model_dir: Optional[str] = None,
    full_batch_size: Optional[int] = None,
):
    """
    exports the model to onnx if pre-exported file is not found and returns onnx_model_path
@@ -36,8 +37,9 @@
        :tokenizer (Union[PreTrainedTokenizer, PreTrainedTokenizerFast]): Pass model tokenizer. ``Defaults to None.``
        :hf_token (str): HuggingFace login token to access private repos. ``Defaults to None.``
        :local_model_dir (str): Path to custom model weights and config files. ``Defaults to None.``
        :full_batch_size (int): Set full batch size to enable continuous batching mode. ``Defaults to None.``
    """
    onnx_path_exists, onnx_dir_path, onnx_model_path = onnx_exists(model_name)
    onnx_path_exists, onnx_dir_path, onnx_model_path = onnx_exists(model_name, full_batch_size)
    if onnx_path_exists:
        logger.info(f"Pre-exported ONNX files found at {onnx_dir_path}! Jumping to Compilation")
    else:
@@ -55,6 +57,7 @@
            form_factor="cloud",
            hf_token=hf_token,
            cache_dir=cache_dir,
            full_batch_size=full_batch_size,
        )  # type: ignore
    logger.info(f"Generated onnx_path: {onnx_model_path}, onnx_dir_path: {onnx_dir_path}")
    return onnx_model_path
@@ -65,6 +68,7 @@ def main(
    cache_dir: Optional[str] = None,
    hf_token: Optional[str] = None,
    local_model_dir: Optional[str] = None,
    full_batch_size: Optional[int] = None,
) -> None:
    """
    Helper function used by export CLI app for exporting to ONNX Model.
@@ -76,14 +80,21 @@
        :cache_dir (str): Cache dir where downloaded HuggingFace files are stored. ``Defaults to None.``
        :hf_token (str): HuggingFace login token to access private repos. ``Defaults to None.``
        :local_model_dir (str): Path to custom model weights and config files. ``Defaults to None.``
        :full_batch_size (int): Set full batch size to enable continuous batching mode. ``Defaults to None.``

    .. code-block:: bash

        python -m QEfficient.cloud.export OPTIONS
    """
    cache_dir = check_and_assign_cache_dir(local_model_dir, cache_dir)
    get_onnx_model_path(model_name=model_name, cache_dir=cache_dir, hf_token=hf_token, local_model_dir=local_model_dir)
    get_onnx_model_path(
        model_name=model_name,
        cache_dir=cache_dir,
        hf_token=hf_token,
        local_model_dir=local_model_dir,
        full_batch_size=full_batch_size,
    )


if __name__ == "__main__":
@@ -101,5 +112,12 @@
    parser.add_argument(
        "--hf-token", "--hf_token", default=None, type=str, required=False, help="HF token id for private HF models"
    )
    parser.add_argument(
        "--full_batch_size",
        "--full-batch-size",
        type=int,
        default=None,
        help="Set full batch size to enable continuous batching mode, default is None",
    )
    args = parser.parse_args()
    main(**args.__dict__)
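
Per the diff, onnx_exists() now keys the pre-export check on full_batch_size as well, so CB and non-CB exports are cached separately; a minimal programmatic use, with values illustrative:

    from QEfficient.cloud.export import get_onnx_model_path

    # Exports the model for continuous batching, or reuses a matching
    # pre-exported ONNX when one is found in the cache.
    onnx_path = get_onnx_model_path(model_name="gpt2", full_batch_size=4)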
21 changes: 19 additions & 2 deletions QEfficient/cloud/infer.py
@@ -26,6 +26,7 @@ def main(
    aic_enable_depth_first: bool = False,
    mos: int = -1,
    batch_size: int = 1,
    full_batch_size: Optional[int] = None,
    prompt_len: int = 32,
    ctx_len: int = 128,
    generation_len: Optional[int] = None,
@@ -36,6 +37,10 @@
    hf_token: Optional[str] = None,
) -> None:
    """
    1. Check if compiled qpc for given config already exists, if it does jump to execute, else
    2. Check if exported ONNX file already exists, if true, jump to compilation -> execution, else
    3. Check if HF model exists in cache, if true, start transform -> export -> compilation -> execution, else,
    4. Download HF model -> transform -> export -> compile -> execute

    ``Mandatory`` Args:
        :model_name (str): Hugging Face Model Card name, Example: ``gpt2``
        :num_cores (int): Number of cores to compile model on.
@@ -46,6 +51,7 @@
        :aic_enable_depth_first (bool): Enables ``DFS`` with default memory size. ``Defaults to False.``
        :mos (int): Effort level to reduce the on-chip memory. ``Defaults to -1.``
        :batch_size (int): Batch size to compile the model for. ``Defaults to 1.``
        :full_batch_size (int): Set full batch size to enable continuous batching mode. ``Defaults to None.``
        :prompt_len (int): Prompt length for the model to compile. ``Defaults to 32.``
        :ctx_len (int): Maximum context length to compile the model. ``Defaults to 128.``
        :generation_len (int): Number of tokens to be generated. ``Defaults to None.``
@@ -68,15 +74,17 @@
    )

    qpc_dir_path = get_qpc_dir_path(
        model_name, num_cores, mos, batch_size, prompt_len, ctx_len, mxfp6, mxint8, device_group
        model_name, num_cores, mos, batch_size, prompt_len, ctx_len, mxfp6, mxint8, device_group, full_batch_size
    )

    # Handle qpc generation
    if qpc_exists(qpc_dir_path):
        logger.info(f"Pre-compiled qpc found at {qpc_dir_path}! Executing with given prompt")
    else:
        # Handle onnx model generation
        onnx_model_path = get_onnx_model_path(model_name, cache_dir, tokenizer, hf_token, local_model_dir)
        onnx_model_path = get_onnx_model_path(
            model_name, cache_dir, tokenizer, hf_token, local_model_dir, full_batch_size
        )

        #########
        # Compile
@@ -95,6 +103,7 @@
            aic_enable_depth_first=aic_enable_depth_first,
            mos=mos,
            device_group=device_group,
            full_batch_size=full_batch_size,
        )

    #########
@@ -107,6 +116,7 @@
        prompt=prompt,
        prompts_txt_file_path=prompts_txt_file_path,
        generation_len=generation_len,
        full_batch_size=full_batch_size,
    )


@@ -181,6 +191,13 @@
        action="store_true",
        help="pass to print info logs",
    )
    parser.add_argument(
        "--full_batch_size",
        "--full-batch-size",
        type=int,
        default=None,
        help="Set full batch size to enable continuous batching mode, default is None",
    )

    args = parser.parse_args()
    if args.verbose:
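
Tying the steps together, a hypothetical end-to-end run of the infer entry point, assuming prompt follows the same pipe-split list convention as execute; all values are illustrative:

    from QEfficient.cloud.infer import main

    main(
        model_name="gpt2",
        num_cores=14,                            # illustrative core count
        device_group=[0],
        batch_size=1,                            # keep 1 when full_batch_size is set
        full_batch_size=4,                       # continuous batching with 4 slots
        prompt=["My name is", "The sun rises"],
        generation_len=32,
    )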
(Diff for the remaining changed files is not rendered here.)