Allow exporting decoder models using optimum-cli (#422)
* refactor(export): add root config class

* feat(decoder): accept batch_size = None

* feat(decoder): accept num_cores = None

* feat(cli): support exporting decoder models

* tests: use NeuronDefaultConfig

* doc: update export cli section

* test: add export decoder cli test

* Apply suggestions from code review

Co-authored-by: Michael Benayoun <mickbenayoun@gmail.com>

* feat(decoder): check that the host has neuron devices

* ci(inf2): move generation tests up

* test(decoder): extend fixture scope to session

---------

Co-authored-by: Michael Benayoun <mickbenayoun@gmail.com>
dacorvo and michaelbenayoun authored Jan 25, 2024
1 parent 2709183 commit 7d0dbb5
Showing 17 changed files with 219 additions and 94 deletions.
8 changes: 4 additions & 4 deletions .github/workflows/test_inf2.yml
@@ -43,6 +43,10 @@ jobs:
run: |
source aws_neuron_venv_pytorch/bin/activate
HF_TOKEN=${{ secrets.HF_TOKEN_OPTIMUM_NEURON_CI }} pytest -m is_inferentia_test tests/cli
- name: Run generation tests
run: |
source aws_neuron_venv_pytorch/bin/activate
HF_TOKEN=${{ secrets.HF_TOKEN_OPTIMUM_NEURON_CI }} pytest -m is_inferentia_test tests/generation
- name: Run exporters tests
run: |
source aws_neuron_venv_pytorch/bin/activate
@@ -51,10 +55,6 @@
run: |
source aws_neuron_venv_pytorch/bin/activate
HF_TOKEN=${{ secrets.HF_TOKEN_OPTIMUM_NEURON_CI }} pytest -m is_inferentia_test tests/inference
- name: Run generation tests
run: |
source aws_neuron_venv_pytorch/bin/activate
HF_TOKEN=${{ secrets.HF_TOKEN_OPTIMUM_NEURON_CI }} pytest -m is_inferentia_test tests/generation
- name: Run pipelines tests
run: |
source aws_neuron_venv_pytorch/bin/activate
20 changes: 11 additions & 9 deletions docs/source/guides/export_model.mdx
@@ -25,7 +25,7 @@ optimum-cli export neuron \
--model bert-base-uncased \
--sequence_length 128 \
--batch_size 1 \
  bert_neuron/
```

Check out the help for more options:
@@ -36,7 +36,7 @@ optimum-cli export neuron --help

## Why compile to Neuron model?

AWS provides two generations of the Inferentia accelerator built for machine learning inference with higher throughput, lower latency but lower cost: [inf2 (NeuronCore-v2)](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/general/arch/neuron-hardware/inf2-arch.html) and [inf1 (NeuronCore-v1)](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/general/arch/neuron-hardware/inf1-arch.html#aws-inf1-arch).

In production environments, to deploy 🤗 [Transformers](https://huggingface.co/docs/transformers/index) models on Neuron devices, you need to compile your models and export them to a serialized format before inference. Through Ahead-Of-Time (AOT) compilation with Neuron Compiler( [neuronx-cc](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/release-notes/compiler/neuronx-cc/index.html) or [neuron-cc](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/release-notes/compiler/neuron-cc/neuron-cc.html) ), your models will be converted to serialized and optimized [TorchScript modules](https://pytorch.org/docs/stable/generated/torch.jit.ScriptModule.html).

@@ -49,18 +49,18 @@ To understand a little bit more about the compilation, here are general steps ex
</Tip>

Although pre-compilation avoids overhead during the inference, traced Neuron module has some limitations:
* Traced Neuron module will be static, which requires fixed input shapes and data types used passed during the compilation. As the model won't be dynamically recompiled, the inference will fail if any of the above conditions change.
* Traced Neuron module will be static, which requires fixed input shapes and data types used during the compilation. As the model won't be dynamically recompiled, the inference will fail if any of the above conditions change.
(*But these limitations could be bypass with [dynamic batching](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/frameworks/torch/torch-neuronx/api-reference-guide/inference/api-torch-neuronx-trace.html#dynamic-batching) and [bucketing](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/general/appnotes/torch-neuron/bucketing-app-note.html#bucketing-app-note)*).
* Neuron models are hardware-specialized, which means:
  * Models traced with Neuron can no longer be executed in a non-Neuron environment.
* Models compiled for inf1 (NeuronCore-v1) are not compatible with inf2 (NeuronCore-v2), and vice versa.

In this guide, we'll show you how to export your models to serialized models optimized for Neuron devices.

<Tip>

🤗 Optimum provides support for the Neuron export by leveraging configuration objects.
These configuration objects come ready-made for a number of model architectures, and are designed to be easily extendable to other architectures.

**To check the supported architectures, go to the [configuration reference page](../package_reference/configuration).**

@@ -89,7 +89,7 @@ optimum-cli export neuron --help

usage: optimum-cli export neuron [-h] -m MODEL [--task TASK] [--atol ATOL] [--cache_dir CACHE_DIR] [--trust-remote-code]
[--compiler_workdir COMPILER_WORKDIR] [--disable-validation] [--auto_cast {none,matmul,all}]
[--auto_cast_type {bf16,fp16,tf32}] [--dynamic-batch-size] [--unet UNET]
[--auto_cast_type {bf16,fp16,tf32}] [--dynamic-batch-size] [--num_cores NUM_CORES] [--unet UNET]
[--output_hidden_states] [--output_attentions] [--batch_size BATCH_SIZE]
[--sequence_length SEQUENCE_LENGTH] [--num_beams NUM_BEAMS] [--num_choices NUM_CHOICES]
[--num_channels NUM_CHANNELS] [--width WIDTH] [--height HEIGHT]
@@ -137,6 +137,8 @@ Optional arguments:
--dynamic-batch-size Enable dynamic batch size for neuron compiled model. If this option is enabled, the input batch size can
be a multiple of the batch size during the compilation, but it comes with a potential tradeoff in terms
of latency.
--num_cores NUM_CORES
The number of cores on which the model should be deployed (text-generation only).
--unet UNET UNet model ID on huggingface.co or path on disk to load model from. This will replace the unet in the
original Stable Diffusion pipeline.
--output_hidden_states
@@ -173,7 +175,7 @@ Exporting a checkpoint can be done as follows:
optimum-cli export neuron --model distilbert-base-uncased-distilled-squad --batch_size 1 --sequence_length 16 distilbert_base_uncased_squad_neuron/
```
You should see the following logs which validate the model on Neuron deivces by comparing with PyTorch model on CPU:
You should see the following logs which validate the model on Neuron devices by comparing with PyTorch model on CPU:
```bash
Validating Neuron model...
@@ -192,7 +194,7 @@ As you can see, the task was automatically detected. This was possible because t
optimum-cli export neuron --model local_path --task question-answering --batch_size 1 --sequence_length 16 --dynamic-batch-size distilbert_base_uncased_squad_neuron/
```
Note that providing the `--task` argument for a model on the Hub will disable the automatic task detection. The resulting `model.neuron` file can then be loaded and run on Neuron devices.
## Exporting a model to Neuron via NeuronModel
@@ -204,7 +206,7 @@ You will also be able to export your models to Neuron format with `optimum.neuro
>>> input_shapes = {"batch_size": 1, "sequence_length": 64} # mandatory shapes
>>> model = NeuronModelForSequenceClassification.from_pretrained(
... "distilbert-base-uncased-finetuned-sst-2-english", export=True, **input_shapes
... )

# Save the model
>>> model.save_pretrained("./distilbert-base-uncased-finetuned-sst-2-english_neuron/")
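As documented above, the new `--num_cores` option only applies to text-generation (decoder) exports. A hypothetical end-to-end invocation could look like the following sketch; the checkpoint name and shape values are illustrative and not taken from this commit:

```bash
optimum-cli export neuron \
  --model gpt2 \
  --task text-generation \
  --batch_size 1 \
  --sequence_length 128 \
  --num_cores 2 \
  --auto_cast_type fp16 \
  gpt2_neuron/
```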
6 changes: 6 additions & 0 deletions optimum/commands/export/neuronx.py
@@ -109,6 +109,12 @@ def parse_args_neuronx(parser: "ArgumentParser"):
action="store_true",
help="Enable dynamic batch size for neuron compiled model. If this option is enabled, the input batch size can be a multiple of the batch size during the compilation, but it comes with a potential tradeoff in terms of latency.",
)
optional_group.add_argument(
"--num_cores",
type=int,
default=None,
help="The number of cores on which the model should be deployed (text-generation only).",
)
optional_group.add_argument(
"--unet",
default=None,
4 changes: 2 additions & 2 deletions optimum/exporters/neuron/__init__.py
@@ -24,7 +24,7 @@
"normalize_input_shapes",
"normalize_stable_diffusion_input_shapes",
],
"base": ["NeuronConfig"],
"base": ["NeuronDefaultConfig"],
"convert": ["export", "export_models", "validate_model_outputs", "validate_models_outputs"],
"utils": [
"DiffusersPretrainedConfig",
@@ -40,7 +40,7 @@
normalize_input_shapes,
normalize_stable_diffusion_input_shapes,
)
from .base import NeuronConfig
from .base import NeuronDefaultConfig
from .convert import export, export_models, validate_model_outputs, validate_models_outputs
from .utils import (
DiffusersPretrainedConfig,
43 changes: 37 additions & 6 deletions optimum/exporters/neuron/__main__.py
@@ -24,6 +24,7 @@
from requests.exceptions import ConnectionError as RequestsConnectionError
from transformers import AutoConfig, PretrainedConfig

from ...neuron import NeuronModelForCausalLM
from ...neuron.utils import (
DECODER_NAME,
DIFFUSION_MODEL_TEXT_ENCODER_2_NAME,
@@ -41,6 +42,7 @@
from ...utils.save_utils import maybe_save_preprocessors
from ..error_utils import AtolError, OutputMatchError, ShapeError
from ..tasks import TasksManager
from .base import NeuronDecoderConfig
from .convert import export_models, validate_models_outputs
from .model_configs import * # noqa: F403
from .utils import (
@@ -106,7 +108,7 @@ def infer_task(task: str, model_name_or_path: str) -> str:
return task


def normalize_input_shapes(task: str, args: argparse.Namespace) -> Dict[str, int]:
def get_input_shapes_and_config_class(task: str, args: argparse.Namespace) -> Dict[str, int]:
config = AutoConfig.from_pretrained(args.model)

model_type = config.model_type.replace("_", "-")
@@ -116,9 +118,9 @@ def normalize_input_shapes(task: str, args: argparse.Namespace) -> Dict[str, int
neuron_config_constructor = TasksManager.get_exporter_config_constructor(
model_type=model_type, exporter="neuron", task=task
)
mandatory_axes = neuron_config_constructor.func.get_mandatory_axes_for_task(task)
input_shapes = {name: getattr(args, name) for name in mandatory_axes}
return input_shapes
input_args = neuron_config_constructor.func.get_input_args_for_task(task)
input_shapes = {name: getattr(args, name) for name in input_args}
return input_shapes, neuron_config_constructor.func


def normalize_sentence_transformers_input_shapes(args: argparse.Namespace) -> Dict[str, int]:
@@ -457,6 +459,19 @@ def main_export(
)


def decoder_export(
model_name_or_path: str,
output: Union[str, Path],
**kwargs,
):
output = Path(output)
if not output.parent.exists():
output.parent.mkdir(parents=True)

model = NeuronModelForCausalLM.from_pretrained(model_name_or_path, export=True, **kwargs)
model.save_pretrained(output)


def main():
parser = ArgumentParser(f"Hugging Face Optimum {NEURON_COMPILER} exporter")

@@ -468,7 +483,6 @@ def main():
task = infer_task(args.task, args.model)
is_stable_diffusion = "stable-diffusion" in task
is_sentence_transformers = args.library_name == "sentence_transformers"
compiler_kwargs = infer_compiler_kwargs(args)

if is_stable_diffusion:
input_shapes = normalize_stable_diffusion_input_shapes(args)
@@ -477,9 +491,26 @@ def main():
input_shapes = normalize_sentence_transformers_input_shapes(args)
submodels = None
else:
input_shapes = normalize_input_shapes(task, args)
input_shapes, neuron_config_class = get_input_shapes_and_config_class(task, args)
if NeuronDecoderConfig in inspect.getmro(neuron_config_class):
# TODO: warn about ignored args:
# dynamic_batch_size, compiler_workdir, optlevel,
# atol, disable_validation, library_name
decoder_export(
model_name_or_path=args.model,
output=args.output,
task=task,
cache_dir=args.cache_dir,
trust_remote_code=args.trust_remote_code,
subfolder=args.subfolder,
auto_cast_type=args.auto_cast_type,
num_cores=args.num_cores,
**input_shapes,
)
return
submodels = None

compiler_kwargs = infer_compiler_kwargs(args)
optional_outputs = customize_optional_outputs(args)
optlevel = parse_optlevel(args)

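The `decoder_export` helper introduced above simply routes the CLI arguments to `NeuronModelForCausalLM`. A rough Python-level equivalent, sketched here with an illustrative checkpoint and shapes, would be:

```python
from optimum.neuron import NeuronModelForCausalLM

# Compile a text-generation checkpoint for Neuron devices; batch_size,
# sequence_length, num_cores and auto_cast_type mirror the CLI options.
model = NeuronModelForCausalLM.from_pretrained(
    "gpt2",            # illustrative model ID
    export=True,
    batch_size=1,
    sequence_length=128,
    num_cores=2,
    auto_cast_type="fp16",
)
model.save_pretrained("gpt2_neuron/")
```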
77 changes: 53 additions & 24 deletions optimum/exporters/neuron/base.py
@@ -38,9 +38,46 @@ class MissingMandatoryAxisDimension(ValueError):
pass


class NeuronConfig(ExportConfig, ABC):
class NeuronConfig(ExportConfig):
"""Base class for Neuron exportable models
Class attributes:
- INPUT_ARGS (`Tuple[Union[str, Tuple[Union[str, Tuple[str]]]]]`) -- A tuple where each element is either:
- An argument name, for instance "batch_size" or "sequence_length", that indicates that the argument can
be passed to export the model,
- Or a tuple containing two elements:
- The first one is either a string or a tuple of strings and specifies for which task(s) the argument is relevant
- The second one is the argument name.
Input arguments can be mandatory for some export types, as specified in child classes.
Args:
task (`str`):
The task the model should be exported for.
"""

INPUT_ARGS = ()

@classmethod
def get_input_args_for_task(cls, task: str) -> Tuple[str]:
axes = []
for axis in cls.INPUT_ARGS:
if isinstance(axis, tuple):
tasks, name = axis
if not isinstance(tasks, tuple):
tasks = (tasks,)
if task not in tasks:
continue
else:
name = axis
axes.append(name)
return tuple(axes)


class NeuronDefaultConfig(NeuronConfig, ABC):
"""
Base class for Neuron exportable model describing metadata on how to export the model through the TorchScript format.
Base class for configuring the export of Neuron TorchScript models.
Class attributes:
@@ -50,14 +87,14 @@ class NeuronConfig(ExportConfig, ABC):
[`~optimum.utils.DummyInputGenerator`] specifying how to create dummy inputs.
- ATOL_FOR_VALIDATION (`Union[float, Dict[str, float]]`) -- A float or a dictionary mapping task names to float,
where the float values represent the absolute tolerance value to use during model conversion validation.
- MANDATORY_AXES (`Tuple[Union[str, Tuple[Union[str, Tuple[str]]]]]`) -- A tuple where each element is either:
- An axis name, for instance "batch_size" or "sequence_length", that indicates that the axis dimension is
needed to export the model,
- INPUT_ARGS (`Tuple[Union[str, Tuple[Union[str, Tuple[str]]]]]`) -- A tuple where each element is either:
- An argument name, for instance "batch_size" or "sequence_length", that indicates that the argument MUST
be passed to export the model,
- Or a tuple containing two elements:
- The first one is either a string or a tuple of strings and specifies for which task(s) the axis is needed
- The second one is the axis name.
- The first one is either a string or a tuple of strings and specifies for which task(s) the argument must be passed
- The second one is the argument name.
For example: `MANDATORY_AXES = ("batch_size", "sequence_length", ("multiple-choice", "num_choices"))` means that
For example: `INPUT_ARGS = ("batch_size", "sequence_length", ("multiple-choice", "num_choices"))` means that
to export the model, the batch size and sequence length values always need to be specified, and that a value
for the number of possible choices is needed when the task is multiple-choice.
@@ -74,13 +111,12 @@ class NeuronConfig(ExportConfig, ABC):
The data type of float tensors, could be ["fp32", "fp16", "bf16"], default to "fp32".
The rest of the arguments are used to specify the shape of the inputs the model can take.
They are required or not depending on the model the `NeuronConfig` is designed for.
They are required or not depending on the model the `NeuronDefaultConfig` is designed for.
"""

NORMALIZED_CONFIG_CLASS = None
DUMMY_INPUT_GENERATOR_CLASSES = ()
ATOL_FOR_VALIDATION: Union[float, Dict[str, float]] = 1e-5
MANDATORY_AXES = ()
MODEL_TYPE = None

_TASK_TO_COMMON_OUTPUTS = {
@@ -165,18 +201,7 @@ def __init__(

@classmethod
def get_mandatory_axes_for_task(cls, task: str) -> Tuple[str]:
axes = []
for axis in cls.MANDATORY_AXES:
if isinstance(axis, tuple):
tasks, name = axis
if not isinstance(tasks, tuple):
tasks = (tasks,)
if task not in tasks:
continue
else:
name = axis
axes.append(name)
return tuple(axes)
return cls.get_input_args_for_task(task)

@property
def task(self) -> str:
@@ -343,12 +368,15 @@ def forward(self, *input):
return ModelWrapper(model, list(dummy_inputs.keys()))


class NeuronDecoderConfig(ExportConfig):
class NeuronDecoderConfig(NeuronConfig):
"""
Base class for configuring the export of Neuron Decoder models
Class attributes:
- INPUT_ARGS (`Tuple[Union[str, Tuple[Union[str, Tuple[str]]]]]`) -- A tuple where each element is either:
- An argument name, for instance "batch_size" or "sequence_length", that indicates that the argument can
be passed to export the model,
- NEURONX_CLASS (`str`) -- the name of the transformers-neuronx class to instantiate for the model.
It is a full class name defined relatively to the transformers-neuronx module, e.g. `gpt2.model.GPT2ForSampling`
[`~optimum.utils.DummyInputGenerator`] specifying how to create dummy inputs.
@@ -359,9 +387,10 @@ class NeuronDecoderConfig(ExportConfig):
task (`str`): The task the model should be exported for.
"""

INPUT_ARGS = ("batch_size", "sequence_length")
NEURONX_CLASS = None

def __init__(self, task):
def __init__(self, task: str):
if not is_transformers_neuronx_available():
raise ModuleNotFoundError(
"The mandatory transformers-neuronx package is missing. Please install optimum[neuronx]."
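To make the task-dependent resolution in `get_input_args_for_task` concrete, here is a minimal sketch with a hypothetical subclass (the class name, task names, and argument set are only for illustration):

```python
from optimum.exporters.neuron.base import NeuronConfig

class ExampleNeuronConfig(NeuronConfig):
    # "num_choices" is only relevant for the multiple-choice task.
    INPUT_ARGS = ("batch_size", "sequence_length", ("multiple-choice", "num_choices"))

print(ExampleNeuronConfig.get_input_args_for_task("text-classification"))
# ('batch_size', 'sequence_length')
print(ExampleNeuronConfig.get_input_args_for_task("multiple-choice"))
# ('batch_size', 'sequence_length', 'num_choices')
```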