Commit

add doc
JingyaHuang committed Oct 22, 2024
1 parent 3679d53 commit ae767c1
Showing 2 changed files with 99 additions and 0 deletions.
5 changes: 5 additions & 0 deletions docs/source/package_reference/modeling.mdx
@@ -68,6 +68,11 @@ The following Neuron model classes are available for natural language processing
[[autodoc]] modeling.NeuronModelForCausalLM
- forward

### NeuronModelForSeq2SeqLM

[[autodoc]] modeling_seq2seq.NeuronModelForSeq2SeqLM
- forward

## Computer Vision

The following Neuron model classes are available for computer vision tasks.
94 changes: 94 additions & 0 deletions optimum/neuron/modeling_seq2seq.py
@@ -26,6 +26,7 @@
import torch
from huggingface_hub import snapshot_download
from transformers import AutoConfig, AutoModelForSeq2SeqLM, GenerationConfig
from transformers.file_utils import add_start_docstrings, add_start_docstrings_to_model_forward
from transformers.generation.logits_process import LogitsProcessorList
from transformers.generation.stopping_criteria import StoppingCriteriaList
from transformers.utils import ModelOutput
@@ -58,6 +59,35 @@

logger = logging.getLogger(__name__)

_TOKENIZER_FOR_DOC = "AutoTokenizer"

NEURON_SEQ2SEQ_MODEL_START_DOCSTRING = r"""
    This model inherits from [`~neuron.modeling.NeuronTracedModel`]. Check the superclass documentation for the generic
    methods the library implements for all its models (such as downloading or saving).

    Args:
        encoder (`torch.jit._script.ScriptModule`): [torch.jit._script.ScriptModule](https://pytorch.org/docs/stable/generated/torch.jit.ScriptModule.html) is the TorchScript module of the encoder with embedded NEFF (Neuron Executable File Format) compiled by the neuron(x) compiler.
        decoder (`torch.jit._script.ScriptModule`): [torch.jit._script.ScriptModule](https://pytorch.org/docs/stable/generated/torch.jit.ScriptModule.html) is the TorchScript module of the decoder with embedded NEFF (Neuron Executable File Format) compiled by the neuron(x) compiler.
        config (`transformers.PretrainedConfig`): [PretrainedConfig](https://huggingface.co/docs/transformers/main_classes/configuration#transformers.PretrainedConfig) is the model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`optimum.neuron.modeling.NeuronTracedModel.from_pretrained`] method to load the model weights.
"""

NEURON_SEQ2SEQ_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.Tensor` of shape `({0})`):
            Indices of input sequence tokens in the vocabulary.
            Indices can be obtained using [`AutoTokenizer`](https://huggingface.co/docs/transformers/autoclass_tutorial#autotokenizer).
            See [`PreTrainedTokenizer.encode`](https://huggingface.co/docs/transformers/main_classes/tokenizer#transformers.PreTrainedTokenizerBase.encode) and
            [`PreTrainedTokenizer.__call__`](https://huggingface.co/docs/transformers/main_classes/tokenizer#transformers.PreTrainedTokenizerBase.__call__) for details.
            [What are input IDs?](https://huggingface.co/docs/transformers/glossary#input-ids)
        attention_mask (`Union[torch.Tensor, None]` of shape `({0})`, defaults to `None`):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](https://huggingface.co/docs/transformers/glossary#attention-mask)
"""


class NeuronModelForConditionalGeneration(NeuronTracedModel, ABC):
    base_model_prefix = "neuron_model"
@@ -371,10 +401,74 @@ def _combine_encoder_decoder_config(self, encoder_config: "PretrainedConfig", de
return combined_config


TRANSLATION_EXAMPLE = r"""
*(The following models are compiled with the neuronx compiler and can only be run on INF2.)*
Example of text-to-text generation with a small T5 model:
```python
from transformers import {processor_class}
from optimum.neuron import {model_class}
neuron_model = {model_class}.from_pretrained("{checkpoint_regular}", export=True, dynamic_batch_size=False, batch_size=1, sequence_length=64, num_beams=4)
neuron_model.save_pretrained("t5_small_neuronx")
del neuron_model
neuron_model = {model_class}.from_pretrained("t5_small_neuronx")
tokenizer = {processor_class}.from_pretrained("t5_small_neuronx")
inputs = tokenizer("translate English to German: Let's eat good food.", return_tensors="pt")
output = neuron_model.generate(
    **inputs,
    num_return_sequences=1,
)
results = [tokenizer.decode(t, skip_special_tokens=True) for t in output]
```
Example of text-to-text generation with tensor parallelism:
(For large models, tensor parallelism is needed to fit them into Neuron cores. Below is an example run on an `inf2.24xlarge` instance.)
```python
from transformers import {processor_class}
from optimum.neuron import {model_class}
# 1. compile
if __name__ == "__main__":  # compulsory for parallel tracing, since the API will spawn multiple processes
    neuron_model = {model_class}.from_pretrained(
        "{checkpoint_tp}", export=True, tensor_parallel_size=8, dynamic_batch_size=False, batch_size=1, sequence_length=128, num_beams=4,
    )
    neuron_model.save_pretrained("flan_t5_xl_neuronx_tp8")
    del neuron_model
# 2. inference
neuron_model = {model_class}.from_pretrained("flan_t5_xl_neuronx_tp8")
tokenizer = {processor_class}.from_pretrained("flan_t5_xl_neuronx_tp8")
inputs = tokenizer("translate English to German: Let's eat good food.", return_tensors="pt")
output = neuron_model.generate(
    **inputs,
    num_return_sequences=1,
)
results = [tokenizer.decode(t, skip_special_tokens=True) for t in output]
```
"""


@add_start_docstrings(
    """
    Neuron Sequence-to-sequence model with a language modeling head for text2text-generation tasks.
    """,
    NEURON_SEQ2SEQ_MODEL_START_DOCSTRING,
)
class NeuronModelForSeq2SeqLM(NeuronModelForConditionalGeneration, NeuronGenerationMixin):
    auto_model_class = AutoModelForSeq2SeqLM
    main_input_name = "input_ids"

    @add_start_docstrings_to_model_forward(
        NEURON_SEQ2SEQ_INPUTS_DOCSTRING.format("batch_size, sequence_length")
        + TRANSLATION_EXAMPLE.format(
            processor_class=_TOKENIZER_FOR_DOC,
            model_class="NeuronModelForSeq2SeqLM",
            checkpoint_regular="google-t5/t5-small",
            checkpoint_tp="google/flan-t5-xl",
        )
    )
    def forward(
        self,
        attention_mask: Optional[torch.FloatTensor] = None,
