diff --git a/.github/workflows/doc-build.yml b/.github/workflows/doc-build.yml
index 481ef2010..29b67559b 100644
--- a/.github/workflows/doc-build.yml
+++ b/.github/workflows/doc-build.yml
@@ -51,7 +51,14 @@ jobs:
- name: Make documentation
shell: bash
run: |
- doc-builder build optimum.neuron docs/source/ --repo_name optimum-neuron --build_dir neuron-doc-build/ --version ${{ env.VERSION }} --version_tag_suffix "" --html --clean
+ doc-builder build optimum.neuron docs/source/ \
+ --repo_name optimum-neuron \
+ --build_dir neuron-doc-build/ \
+ --version ${{ env.VERSION }} \
+ --version_tag_suffix "" \
+ --html \
+ --clean \
+ --notebook_dir docs/notebooks/
cd neuron-doc-build/
mv optimum.neuron optimum-neuron
- doc-builder push optimum-neuron --doc_build_repo_id "hf-doc-build/doc-build" --token "${{ secrets.HF_DOC_BUILD_PUSH }}" --commit_msg "Updated with commit $COMMIT_SHA See: https://github.com/huggingface/optimum-neuron/commit/$COMMIT_SHA" --n_retries 5
\ No newline at end of file
+ doc-builder push optimum-neuron --doc_build_repo_id "hf-doc-build/doc-build" --token "${{ secrets.HF_DOC_BUILD_PUSH }}" --commit_msg "Updated with commit $COMMIT_SHA See: https://github.com/huggingface/optimum-neuron/commit/$COMMIT_SHA" --n_retries 5
diff --git a/.github/workflows/doc-pr-build.yml b/.github/workflows/doc-pr-build.yml
index 450d0f182..a206771b5 100644
--- a/.github/workflows/doc-pr-build.yml
+++ b/.github/workflows/doc-pr-build.yml
@@ -36,7 +36,14 @@ jobs:
- name: Make documentation
shell: bash
run: |
- doc-builder build optimum.neuron docs/source/ --repo_name optimum-neuron --build_dir neuron-doc-build/ --version pr_${{ env.PR_NUMBER }} --version_tag_suffix "" --html --clean
+ doc-builder build optimum.neuron docs/source/ \
+ --repo_name optimum-neuron \
+ --build_dir neuron-doc-build/ \
+ --version pr_${{ env.PR_NUMBER }} \
+ --version_tag_suffix "" \
+ --html \
+ --clean \
+ --notebook_dir docs/notebooks/
- name: Save commit_sha & pr_number
run: |
diff --git a/docs/source/_toctree.yml b/docs/source/_toctree.yml
index 400626399..0573fafdb 100644
--- a/docs/source/_toctree.yml
+++ b/docs/source/_toctree.yml
@@ -14,6 +14,8 @@
title: Fine-tune BERT for Text Classification on AWS Trainium
- local: training_tutorials/finetune_llm
title: Fine-tune Llama 3 8B on AWS Trainium
+    - local: training_tutorials/sft_lora_finetune_llm
+      title: Fine-tune Llama 3 8B with LoRA and the SFTTrainer
title: Training Tutorials
- sections:
- local: inference_tutorials/notebooks
diff --git a/docs/source/training_tutorials/finetune_llm.mdx b/docs/source/training_tutorials/finetune_llm.mdx
index 4928ca37c..65eae65b9 100644
--- a/docs/source/training_tutorials/finetune_llm.mdx
+++ b/docs/source/training_tutorials/finetune_llm.mdx
@@ -45,7 +45,7 @@ And many others!
Before starting this tutorial, you will need to setup your environment:
-1. Create an AWS Trainium instance. You can follow this [guide](https://huggingface.co/docs/optimum-neuron/guides/setup_aws_instance) to create one.
+1. Create an AWS Trainium instance. **You will need a `trn1.32xlarge`, which contains 16 Neuron Devices.** You can follow this [guide](https://huggingface.co/docs/optimum-neuron/guides/setup_aws_instance) to create one.
2. Make sure you are logged in on the Hugging Face Hub:
```bash
huggingface-cli login --token YOUR_TOKEN
@@ -53,7 +53,7 @@ huggingface-cli login --token YOUR_TOKEN
3. Check that you have access to the model. Some open source models are gated, meaning that users need to apply to the model owner to be able to use the model weights. Here we will be training Llama-3 8B, for which there are two possibilities:
* The official gated repo: [`meta-llama/Meta-Llama-3-8B`](https://huggingface.co/meta-llama/Meta-Llama-3-8B)
* The non-official un-gated repo: [`NousResearch/Meta-Llama-3-8B`](https://huggingface.co/NousResearch/Meta-Llama-3-8B)
-4. Clone the Optimum Neuron repository, **which contains the [complete script](https://github.com/huggingface/optimum-neuron/docs/source/training_tutorials/finetune_llm.py) described in this tutorial:**
+4. Clone the Optimum Neuron repository, **which contains the [complete script](https://github.com/huggingface/optimum-neuron/blob/main/docs/source/training_tutorials/finetune_llm.py) described in this tutorial:**
```bash
git clone https://github.com/huggingface/optimum-neuron.git
```
@@ -68,7 +68,10 @@ Example:
{
"instruction": "What is world of warcraft",
"context": "",
- "response": "World of warcraft is a massive online multi player role playing game. It was released in 2004 by bizarre entertainment"
+ "response": (
+ "World of warcraft is a massive online multi player role playing game. "
+ "It was released in 2004 by blizarre entertainment"
+ )
}
```
@@ -98,7 +101,7 @@ def format_dolly(sample):
return prompt
```
-In addition to formatting our samples, we also want to pack multiple samples to one sequence to have a more efficient training. In other words, we are stacking multiple samples to one sequence and split them with an EOS Token. Packing/stacking samples can be done during training or before. Here, we will do it before training to save time.
+In addition to formatting our samples, we also want to pack multiple samples into one sequence for more efficient training. In other words, we concatenate multiple samples into one sequence and separate them with an EOS token. Packing can be done either before or during training.
The following function `pack_dataset` takes a `dataset` and a `chunk_length` and returns a packed dataset:
@@ -181,16 +184,6 @@ dataset = dataset.map(
lm_dataset = pack_dataset(dataset, chunk_length=2048) # We use 2048 as the maximum length for packing
```
-After we processed the datasets we are going save it to disk. You could also save it to S3 or the Hugging Face Hub for later use.
-
-_Note: Packing and preprocessing your dataset can be run outside of the Trainium instance._
-
-```python
-# save train_dataset to disk
-dataset_path = "tokenized_dolly"
-lm_dataset.save_to_disk(dataset_path)
-```
-
## 3. Fine-tune Llama on AWS Trainium using the `NeuronTrainer`
Normally you would use the **[Trainer](https://huggingface.co/docs/transformers/main/en/main_classes/trainer#transformers.Trainer)** and **[TrainingArguments](https://huggingface.co/docs/transformers/main/en/main_classes/trainer#transformers.TrainingArguments)** classes to fine-tune PyTorch-based transformer models.
@@ -244,16 +237,18 @@ The key points here are:
## 4. Launch Training
-We prepared a script called [finetune_llm.py](https://github.com/huggingface/optimum-neuron/docs/source/training_tutorials/finetune_llm.py) summing up everything mentioned in this tutorial.
+We prepared a script called [finetune_llm.py](https://github.com/huggingface/optimum-neuron/blob/main/docs/source/training_tutorials/finetune_llm.py) summing up everything mentioned in this tutorial.
-This script is a minimalistic version of our official example training script to run causal language modeling fine-tuning, called [run_clm.py](https://github.com/huggingface/optimum-neuron/blob/main/examples/language-modeling/run_clm.py). For the sake of this tutorial, we tried to get rid of anything that is not necessary, but if you want to do more custom things, maybe the solution is already implemented in `run_clm.py`!
+This script is a minimalistic version of our official example training script to run causal language modeling fine-tuning, called [run_clm.py](https://github.com/huggingface/optimum-neuron/blob/main/examples/language-modeling/run_clm.py). For the sake of this tutorial, we tried to get rid of anything that is not necessary, and added the formatting step necessary for fine-tuning, but if you want to do more custom things, maybe the solution is already implemented in `run_clm.py`!
Also, these scripts are more designed as templates than final scripts. Feel free to take `finetune_llm.py` or `run_clm.py` and adapt them to your own needs!
+PyTorch Neuron uses `torch_xla`. It evaluates operations lazily during the execution of the training loop, which means it builds a symbolic graph in the background, and the graph is executed on the hardware only when a tensor is printed, transferred to CPU, or when `xm.mark_step()` is called. During execution, multiple graphs can be built depending on control flow, and it can take time to compile each graph sequentially. To alleviate that, the Neuron SDK provides `neuron_parallel_compile`, a tool which performs a fast trial run that builds all the graphs and compiles them in parallel. This step is usually called precompilation.
+
### Precompilation
When training models on AWS Trainium we first need to compile our model with our training arguments.
@@ -266,8 +261,7 @@ The compilation command simply consists in calling your script as an input to th
```bash
MALLOC_ARENA_MAX=64 XLA_USE_BF16=1 neuron_parallel_compile torchrun --nproc_per_node=32 finetune_llm.py \
- --model_id {model_id} \
- --dataset_path {dataset_path} \
+ --model_id meta-llama/Meta-Llama-3-8B \
--bf16 True \
--learning_rate 5e-5 \
--output_dir dolly_llama \
@@ -305,8 +299,7 @@ Launch the training, with the following command.
```bash
MALLOC_ARENA_MAX=64 XLA_USE_BF16=1 torchrun --nproc_per_node=32 finetune_llm.py \
- --model_id {model_id} \
- --dataset_path {dataset_path} \
+ --model_id meta-llama/Meta-Llama-3-8B \
--bf16 True \
--learning_rate 5e-5 \
--output_dir dolly_llama \
diff --git a/docs/source/training_tutorials/finetune_llm.py b/docs/source/training_tutorials/finetune_llm.py
index d3fd2bfd0..f291779a2 100644
--- a/docs/source/training_tutorials/finetune_llm.py
+++ b/docs/source/training_tutorials/finetune_llm.py
@@ -1,9 +1,8 @@
from dataclasses import dataclass, field
from functools import partial
from itertools import chain
-from typing import Optional
-from datasets import load_dataset, load_from_disk
+from datasets import load_dataset
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
@@ -17,10 +16,6 @@
from optimum.neuron.distributed import lazy_load_for_parallelism
-# Load dataset from the hub
-dataset = load_dataset("databricks/databricks-dolly-15k", split="train")
-
-
def format_dolly(sample):
instruction = f"### Instruction\n{sample['instruction']}"
context = f"### Context\n{sample['context']}" if len(sample["context"]) > 0 else None
@@ -70,9 +65,7 @@ def chunk(sample, chunk_length=chunk_length):
return lm_dataset
-def create_and_save_dataset(model_id: str, dataset_path: str):
- tokenizer = AutoTokenizer.from_pretrained(model_id)
-
+def prepare_dataset(tokenizer, dataset):
# template dataset to add prompt to each sample
def template_dataset(sample):
sample["text"] = f"{format_dolly(sample)}{tokenizer.eos_token}"
@@ -89,15 +82,16 @@ def template_dataset(sample):
# chunk dataset
lm_dataset = pack_dataset(dataset, chunk_length=2048) # We use 2048 as the maximum length for packing
- # save train_dataset to disk
- lm_dataset.save_to_disk(dataset_path)
+ return lm_dataset
def training_function(script_args, training_args):
- # load dataset
- dataset = load_from_disk(script_args.dataset_path)
-
tokenizer = AutoTokenizer.from_pretrained(script_args.model_id)
+
+ # Load dataset from the hub and prepare it for training.
+ dataset = load_dataset("databricks/databricks-dolly-15k", split="train")
+ dataset = prepare_dataset(tokenizer, dataset)
+
with lazy_load_for_parallelism(tensor_parallel_size=training_args.tensor_parallel_size):
model = AutoModelForCausalLM.from_pretrained(script_args.model_id)
@@ -122,20 +116,12 @@ class ScriptArguments:
default="meta-llama/Meta-Llama-3-8B",
metadata={"help": "The model that you want to train from the Hugging Face hub."},
)
- dataset_path: Optional[str] = field(
- metadata={"help": "Path to the preprocessed and tokenized dataset."},
- default=None,
- )
def main():
parser = HfArgumentParser([ScriptArguments, TrainingArguments])
script_args, training_args = parser.parse_args_into_dataclasses()
- if script_args.dataset_path is None:
- create_and_save_dataset(script_args.model_id, "tokenized_dolly")
- script_args.dataset_path = "tokenized_dolly"
-
# set seed
set_seed(training_args.seed)
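Note on the packing step that `prepare_dataset` delegates to `pack_dataset`: the idea can be illustrated on plain lists of token ids, without a tokenizer or the `datasets` library. This is a minimal sketch, not the implementation in `finetune_llm.py`. The helper name `pack_token_ids` is hypothetical, it appends the EOS id after tokenization (the script appends the EOS token to the text instead), and it simply drops the trailing remainder rather than carrying it over to the next batch.

```python
from itertools import chain


def pack_token_ids(sequences, eos_token_id, chunk_length):
    """Concatenate tokenized samples, each terminated by EOS, and cut fixed-size chunks."""
    # Flatten all samples into one token stream, with an EOS id after each sample.
    stream = list(chain.from_iterable(seq + [eos_token_id] for seq in sequences))
    # Keep only as many tokens as fill complete chunks (the remainder is dropped here).
    usable = (len(stream) // chunk_length) * chunk_length
    return [stream[i:i + chunk_length] for i in range(0, usable, chunk_length)]


samples = [[1, 2, 3], [4, 5], [6, 7, 8, 9]]
print(pack_token_ids(samples, eos_token_id=0, chunk_length=4))
# [[1, 2, 3, 0], [4, 5, 0, 6], [7, 8, 9, 0]]
```

Every chunk has exactly `chunk_length` tokens, so no padding is needed, at the cost of samples occasionally being split across two chunks.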
diff --git a/docs/source/training_tutorials/sft_lora_finetune_llm.mdx b/docs/source/training_tutorials/sft_lora_finetune_llm.mdx
new file mode 100644
index 000000000..dc3df1ace
--- /dev/null
+++ b/docs/source/training_tutorials/sft_lora_finetune_llm.mdx
@@ -0,0 +1,430 @@
+
+
+# Supervised Fine-Tuning of Llama 3 8B on one AWS Trainium instance
+
+_Note: The complete script for this tutorial can be downloaded [here](https://github.com/huggingface/optimum-neuron/blob/main/docs/source/training_tutorials/sft_lora_finetune_llm.py)._
+
+This tutorial will teach you how to fine-tune open source LLMs like [Llama 3](https://huggingface.co/meta-llama/Meta-Llama-3-8B) on AWS Trainium. In our example, we are going to leverage the [Optimum Neuron](https://huggingface.co/docs/optimum-neuron/index), [Transformers](https://huggingface.co/docs/transformers/index) and [Datasets](https://huggingface.co/docs/datasets/index) libraries.
+
+You will learn how to:
+
+1. [Setup AWS Environment](#1-setup-aws-environment)
+2. [Load and prepare the dataset](#2-load-and-prepare-the-dataset)
+3. [Supervised Fine-Tuning of Llama on AWS Trainium with the `NeuronSFTTrainer`](#3-supervised-fine-tuning-of-llama-on-aws-trainium-with-the-neuronsfttrainer)
+4. [Launch Training](#4-launch-training)
+5. [Evaluate and test fine-tuned Llama model](#5-evaluate-and-test-fine-tuned-llama-model)
+
+
+
+While we will use `Llama-3 8B` in this tutorial, it is completely possible to use other models, simply by switching the `model_id`.
+
+
+
+## 1. Setup AWS Environment
+
+Before starting this tutorial, you will need to set up your environment:
+
+1. Create an AWS Trainium instance. **You will need a `trn1.32xlarge`, which contains 16 Neuron Devices.** You can follow this [guide](https://huggingface.co/docs/optimum-neuron/guides/setup_aws_instance) to create one.
+2. Make sure you are logged in on the Hugging Face Hub:
+```bash
+huggingface-cli login --token YOUR_TOKEN
+```
+3. Check that you have access to the model. Some open source models are gated, meaning that users need to apply to the model owner to be able to use the model weights. Here we will be training Llama-3 8B, for which there are two possibilities:
+ * The official gated repo: [`meta-llama/Meta-Llama-3-8B`](https://huggingface.co/meta-llama/Meta-Llama-3-8B)
+ * The non-official un-gated repo: [`NousResearch/Meta-Llama-3-8B`](https://huggingface.co/NousResearch/Meta-Llama-3-8B)
+4. Clone the Optimum Neuron repository, **which contains the [complete script](https://github.com/huggingface/optimum-neuron/blob/main/docs/source/training_tutorials/sft_lora_finetune_llm.py) described in this tutorial:**
+```bash
+git clone https://github.com/huggingface/optimum-neuron.git
+```
+
+## 2. Load and prepare the dataset
+
+For this tutorial, we will use [Dolly](https://huggingface.co/datasets/databricks/databricks-dolly-15k), an open source dataset of instruction-following records on categories outlined in the [InstructGPT paper](https://arxiv.org/abs/2203.02155), including brainstorming, classification, closed QA, generation, information extraction, open QA, and summarization.
+
+Example:
+
+```python
+{
+ "instruction": "What is world of warcraft",
+ "context": "",
+ "response": (
+ "World of warcraft is a massive online multi player role playing game. "
+ "It was released in 2004 by blizarre entertainment"
+ )
+}
+```
+
+We can use the `load_dataset()` method from the 🤗 Datasets library to load the `dolly` dataset very easily.
+
+```python
+from datasets import load_dataset
+from random import randrange
+
+# Load dataset from the hub
+dataset = load_dataset("databricks/databricks-dolly-15k", split="train")
+
+print(f"dataset size: {len(dataset)}")
+print(dataset[randrange(len(dataset))])
+# dataset size: 15011
+```
+
+To instruction fine-tune our model we need to:
+
+  1. Convert our structured examples into a collection of tasks described via instructions.
+
+  2. (Optional) Pack multiple examples into one sequence for more efficient training. In other words, we concatenate multiple examples into one sequence,
+  separated by the EOS token.
+
+We could do this manually, but we will use the `NeuronSFTTrainer` instead.
+
+## 3. Supervised Fine-Tuning of Llama on AWS Trainium with the `NeuronSFTTrainer`
+
+Normally you would use the **[SFTConfig](https://huggingface.co/docs/trl/main/en/sft_trainer#trl.SFTConfig)** and **[SFTTrainer](https://huggingface.co/docs/trl/main/en/sft_trainer)** classes to perform supervised fine-tuning of PyTorch-based transformer models.
+
+Instead, here we will be using the [`~optimum.neuron.NeuronSFTConfig`] and [`~optimum.neuron.NeuronSFTTrainer`]. These classes replicate the ones from the `trl` library while making sure they work properly on Neuron cores.
+
+### Formatting our dataset
+
+There are multiple ways to give a dataset to the `NeuronSFTTrainer`, and one of them consists in providing a formatting function.
+For `dolly` without packing the examples it looks as follows:
+
+```python
+def format_dolly(examples):
+    output_text = []
+    for i in range(len(examples["instruction"])):
+        instruction = f"### Instruction\n{examples['instruction'][i]}"
+        context = f"### Context\n{examples['context'][i]}" if len(examples["context"][i]) > 0 else None
+        response = f"### Answer\n{examples['response'][i]}"
+        prompt = "\n\n".join([part for part in [instruction, context, response] if part is not None])
+        output_text.append(prompt)
+    return output_text
+```
+
+### Preparing the model
+
+Since Llama-3 8B is a big model, it will not fit on a single `trn1.32xlarge` instance, even with distributed training. To actually fine-tune an 8B model using only one Trainium instance, we need to use both LoRA and distributed training.
+
+
+
+If you want to know more about distributed training you can take a look at the [documentation](https://huggingface.co/docs/optimum-neuron/guides/distributed_training).
+
+
+
+Here, we will use tensor parallelism in conjunction with LoRA.
+Our training code will look as follows:
+
+```python
+from datasets import load_dataset
+from peft import LoraConfig
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+from optimum.neuron import NeuronSFTConfig, NeuronSFTTrainer
+from optimum.neuron.distributed import lazy_load_for_parallelism
+
+# Define the tensor_parallel_size
+tensor_parallel_size = 2
+
+dataset = load_dataset("databricks/databricks-dolly-15k", split="train")
+
+model_id = "meta-llama/Meta-Llama-3-8B"
+
+tokenizer = AutoTokenizer.from_pretrained(model_id)
+# Reuse the EOS token for padding
+tokenizer.pad_token = tokenizer.eos_token
+
+# Lazily load the model so that each worker only loads the weights it needs
+with lazy_load_for_parallelism(tensor_parallel_size=tensor_parallel_size):
+    model = AutoModelForCausalLM.from_pretrained(model_id)
+
+config = LoraConfig(
+    r=16,
+    lora_alpha=16,
+    lora_dropout=0.05,
+    target_modules=[
+        "q_proj",
+        "gate_proj",
+        "v_proj",
+        "o_proj",
+        "k_proj",
+        "up_proj",
+        "down_proj"
+    ],
+    bias="none",
+    task_type="CAUSAL_LM",
+)
+
+# training_args is an instance of NeuronTrainingArguments
+args = training_args.to_dict()
+sft_config = NeuronSFTConfig(
+    max_seq_length=1024,
+    packing=False,
+    **args,
+)
+
+trainer = NeuronSFTTrainer(
+    args=sft_config,
+    model=model,
+    peft_config=config,
+    tokenizer=tokenizer,
+    train_dataset=dataset,
+    formatting_func=format_dolly,
+)
+
+# Start training
+trainer.train()
+
+trainer.save_model() # Saves the tokenizer too for easy upload
+```
+
+The key points here are:
+
+- We use the `lazy_load_for_parallelism` context manager to lazily load the model. This will not load the full model weights on each worker, but instead only load the required weights (sharded or full). **This is much more memory efficient, and often mandatory to use.**
+- We define a `LoraConfig` that specifies which layers should have adapters, and the hyperparameters for these adapters.
+- We create a [`~optimum.neuron.NeuronSFTConfig`] from regular `NeuronTrainingArguments`. Here we specify that we do not want to pack our examples, and that the max sequence length should be `1024`, meaning that every example will be either padded or truncated to a length of `1024`.
+- We use the [`~optimum.neuron.NeuronSFTTrainer`] to perform training. It will take the lazily loaded model, along with the LoRA `config`, `sft_config` and `format_dolly`, and prepare the dataset and model for supervised fine-tuning.
+
+## 4. Launch Training
+
+We prepared a script called [sft_lora_finetune_llm.py](https://github.com/huggingface/optimum-neuron/blob/main/docs/source/training_tutorials/sft_lora_finetune_llm.py) summing up everything mentioned in this tutorial.
+
+PyTorch Neuron uses `torch_xla`. It evaluates operations lazily during the execution of the training loop, which means it builds a symbolic graph in the background, and the graph is executed on the hardware only when a tensor is printed, transferred to CPU, or when `xm.mark_step()` is called. During execution, multiple graphs can be built depending on control flow, and it can take time to compile each graph sequentially. To alleviate that, the Neuron SDK provides `neuron_parallel_compile`, a tool which performs a fast trial run that builds all the graphs and compiles them in parallel. This step is usually called precompilation.
+
+### Precompilation
+
+When training models on AWS Trainium we first need to compile our model with our training arguments.
+
+To ease this step, we added a [model cache repository](https://huggingface.co/aws-neuron/optimum-neuron-cache), which allows us to use precompiled models from the Hugging Face Hub to skip the compilation step. But be careful: every change in the model configuration might lead to a new compilation, which could result in some cache misses.
+
+
+
+To learn more about the caching system, and how you can create your own private cache repository, check this [guide](https://huggingface.co/docs/optimum-neuron/guides/cache_system).
+
+
+
+The compilation command simply consists in calling your script as an input to the `neuron_parallel_compile` utility:
+
+```bash
+#!/bin/bash
+set -ex
+
+export NEURON_FUSE_SOFTMAX=1
+export NEURON_RT_ASYNC_EXEC_MAX_INFLIGHT_REQUESTS=3
+export MALLOC_ARENA_MAX=64
+export NEURON_CC_FLAGS="--model-type=transformer --distribution-strategy=llm-training --enable-saturate-infinity --cache_dir=/home/ubuntu/cache_dir_neuron/"
+
+PROCESSES_PER_NODE=8
+
+NUM_EPOCHS=1
+TP_DEGREE=2
+PP_DEGREE=1
+BS=1
+GRADIENT_ACCUMULATION_STEPS=8
+LOGGING_STEPS=1
+MODEL_NAME="meta-llama/Meta-Llama-3-8B"
+OUTPUT_DIR=output-$SLURM_JOB_ID
+
+if [ "$NEURON_EXTRACT_GRAPHS_ONLY" = "1" ]; then
+ MAX_STEPS=$((LOGGING_STEPS + 5))
+else
+ MAX_STEPS=-1
+fi
+
+
+XLA_USE_BF16=1 neuron_parallel_compile torchrun --nproc_per_node $PROCESSES_PER_NODE docs/source/training_tutorials/sft_lora_finetune_llm.py \
+ --model_id $MODEL_NAME \
+ --num_train_epochs $NUM_EPOCHS \
+ --do_train \
+ --learning_rate 5e-5 \
+ --warmup_ratio 0.03 \
+ --max_steps $MAX_STEPS \
+ --per_device_train_batch_size $BS \
+ --per_device_eval_batch_size $BS \
+ --gradient_accumulation_steps $GRADIENT_ACCUMULATION_STEPS \
+ --gradient_checkpointing true \
+ --bf16 \
+ --zero_1 false \
+ --tensor_parallel_size $TP_DEGREE \
+ --pipeline_parallel_size $PP_DEGREE \
+ --logging_steps $LOGGING_STEPS \
+ --save_total_limit 1 \
+ --output_dir $OUTPUT_DIR \
+ --lr_scheduler_type "constant" \
+ --overwrite_output_dir
+```
+
+
+
+Make sure to run this precompilation phase for around 10 training steps. It is usually enough to accumulate and compile all the graphs that will be needed during the actual training.
+
+
+
+_Note: Compiling without a cache can take a while. It will also create dummy files in the output directory during compilation; you will have to remove them afterwards. We also set `MALLOC_ARENA_MAX=64` to limit the CPU memory allocation and avoid potential crashes; don't remove it for now._
+
+```bash
+# remove dummy artifacts which are created by the precompilation command
+rm -rf dolly_llama
+```
+
+### Actual Training
+
+After compilation is done, we can start our actual training with a similar command; we just need to remove the use of `neuron_parallel_compile`.
+
+We will use `torchrun` to launch our training script. `torchrun` is a PyTorch utility that launches the training script as multiple worker processes. We can pass the number of processes as the `nproc_per_node` argument alongside our hyperparameters.
+
+The script is otherwise identical to the compilation one: `MAX_STEPS` is only capped when `NEURON_EXTRACT_GRAPHS_ONLY=1` is set in the environment; otherwise it is `-1`, and the training runs for `NUM_EPOCHS` epochs.
+
+Launch the training, with the following command.
+
+```bash
+#!/bin/bash
+set -ex
+
+export NEURON_FUSE_SOFTMAX=1
+export NEURON_RT_ASYNC_EXEC_MAX_INFLIGHT_REQUESTS=3
+export MALLOC_ARENA_MAX=64
+export NEURON_CC_FLAGS="--model-type=transformer --distribution-strategy=llm-training --enable-saturate-infinity --cache_dir=/home/ubuntu/cache_dir_neuron/"
+
+PROCESSES_PER_NODE=8
+
+NUM_EPOCHS=1
+TP_DEGREE=2
+PP_DEGREE=1
+BS=1
+GRADIENT_ACCUMULATION_STEPS=8
+LOGGING_STEPS=1
+MODEL_NAME="meta-llama/Meta-Llama-3-8B"
+OUTPUT_DIR=output-$SLURM_JOB_ID
+
+if [ "$NEURON_EXTRACT_GRAPHS_ONLY" = "1" ]; then
+ MAX_STEPS=$((LOGGING_STEPS + 5))
+else
+ MAX_STEPS=-1
+fi
+
+
+XLA_USE_BF16=1 torchrun --nproc_per_node $PROCESSES_PER_NODE docs/source/training_tutorials/sft_lora_finetune_llm.py \
+ --model_id $MODEL_NAME \
+ --num_train_epochs $NUM_EPOCHS \
+ --do_train \
+ --learning_rate 5e-5 \
+ --warmup_ratio 0.03 \
+ --max_steps $MAX_STEPS \
+ --per_device_train_batch_size $BS \
+ --per_device_eval_batch_size $BS \
+ --gradient_accumulation_steps $GRADIENT_ACCUMULATION_STEPS \
+ --gradient_checkpointing true \
+ --bf16 \
+ --zero_1 false \
+ --tensor_parallel_size $TP_DEGREE \
+ --pipeline_parallel_size $PP_DEGREE \
+ --logging_steps $LOGGING_STEPS \
+ --save_total_limit 1 \
+ --output_dir $OUTPUT_DIR \
+ --lr_scheduler_type "constant" \
+ --overwrite_output_dir
+```
+
+That's it, we successfully trained Llama-3 8B on AWS Trainium!
+
+But before we can share and test our model, we need to consolidate the checkpoint. Since we used tensor parallelism during training, the saved checkpoints are sharded.
+
+### Consolidate the Checkpoint
+
+The Optimum CLI provides a way of doing that very easily via the `optimum-cli neuron consolidate [sharded_checkpoint] [output_dir]` command:
+
+```bash
+optimum-cli neuron consolidate dolly_llama dolly_llama
+```
+
+## 5. Evaluate and test fine-tuned Llama model
+
+As for training, to be able to run inference on AWS Trainium or AWS Inferentia2 we need to compile our model. In this case, we will use our Trainium instance for the inference test, but we recommend that customers switch to Inferentia2 (`inf2.24xlarge`) for inference.
+
+Optimum Neuron implements classes similar to the Transformers AutoModel classes for easy inference. We will use the `NeuronModelForCausalLM` class to load our vanilla transformers checkpoint and convert it to Neuron.
+
+```python
+from optimum.neuron import NeuronModelForCausalLM
+from transformers import AutoTokenizer
+
+compiler_args = {"num_cores": 2, "auto_cast_type": 'fp16'}
+input_shapes = {"batch_size": 1, "sequence_length": 2048}
+
+tokenizer = AutoTokenizer.from_pretrained("dolly_llama")
+model = NeuronModelForCausalLM.from_pretrained(
+    "dolly_llama",
+    export=True,
+    **compiler_args,
+    **input_shapes)
+```
+
+_Note: Inference compilation can take ~25 minutes. Luckily, you only need to run this once, since you can save the model afterwards. If you are going to run on Inferentia2, you need to recompile, because the compilation is parameter and hardware specific._
+
+```python
+# Uncomment if you want to save the compiled model
+# model.save_pretrained("compiled_dolly_llama")
+```
+
+We can now test inference, but we have to make sure we format our input with the prompt format we used for fine-tuning. Therefore we created a helper method, which accepts a `dict` with our `instruction` and optionally a `context`.
+
+```python
+def format_dolly_inference(sample):
+    instruction = f"### Instruction\n{sample['instruction']}"
+    context = f"### Context\n{sample['context']}" if "context" in sample else None
+    response = "### Answer\n"
+    prompt = "\n\n".join([part for part in [instruction, context, response] if part is not None])
+    return prompt
+
+
+def generate(sample):
+    prompt = format_dolly_inference(sample)
+    inputs = tokenizer(prompt, return_tensors="pt")
+    outputs = model.generate(
+        **inputs,
+        max_new_tokens=512,
+        do_sample=True,
+        temperature=0.9,
+        top_k=50,
+        top_p=0.9
+    )
+    return tokenizer.decode(outputs[0], skip_special_tokens=False)[len(prompt):]
+```
+
+Let's test inference. First we test without a context.
+
+_Note: Inference is not expected to be super fast on AWS Trainium using 2 cores. For Inference we recommend using Inferentia2._
+
+```python
+prompt = {
+ "instruction": "Can you tell me something about AWS?"
+}
+res = generate(prompt)
+
+print(res)
+```
+
+> AWS stands for Amazon Web Services. AWS is a suite of remote computing services offered by Amazon. The most widely used of these include Amazon Elastic Compute Cloud (Amazon EC2), which provides resizable compute capacity in the cloud; Amazon Simple Storage Service (Amazon S3), which is an object storage service; and Amazon Elastic Block Store (Amazon EBS), which is designed to provide high performance, durable block storage volumes for use with AWS instances. AWS also provides other services, such as AWS Identity and Access Management (IAM), a service that enables organizations to control access to their AWS resources, and AWS Key Management Service (AWS KMS), which helps customers create and control the use of encryption keys.
+
+That looks correct. Now, let's add some context, e.g. as you would do for RAG applications:
+
+```python
+prompt = {
+ "instruction": "How can I train models on AWS Trainium?",
+ "context": "🤗 Optimum Neuron is the interface between the 🤗 Transformers library and AWS Accelerators including [AWS Trainium](https://aws.amazon.com/machine-learning/trainium/?nc1=h_ls) and [AWS Inferentia](https://aws.amazon.com/machine-learning/inferentia/?nc1=h_ls). It provides a set of tools enabling easy model loading, training and inference on single- and multi-Accelerator settings for different downstream tasks."
+}
+res = generate(prompt)
+
+print(res)
+```
+
+> You can use the Optimum Neuron interface to train models on AWS Trainium.
+
+Awesome, our model also correctly uses the provided context. We are done. Congrats on fine-tuning Llama on AWS Trainium.
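Note on the launch script in section 4 of the tutorial above: the number of data-parallel replicas and the effective batch size follow from the values the script sets. The sketch below makes two assumptions: that the data-parallel size is the number of processes divided by the product of the tensor- and pipeline-parallel degrees, and that `per_device_train_batch_size` applies per data-parallel replica. The helper name `effective_batch_size` is hypothetical and not part of Optimum Neuron.

```python
def effective_batch_size(num_processes, tp_degree, pp_degree, per_device_batch_size, grad_accum_steps):
    """Return (data_parallel_size, sequences consumed per optimizer step)."""
    model_parallel_size = tp_degree * pp_degree
    if num_processes % model_parallel_size != 0:
        raise ValueError("num_processes must be divisible by tp_degree * pp_degree")
    # Each group of tp_degree * pp_degree processes holds one replica of the model.
    dp_size = num_processes // model_parallel_size
    return dp_size, dp_size * per_device_batch_size * grad_accum_steps


# Values taken from the tutorial's launch script:
# PROCESSES_PER_NODE=8, TP_DEGREE=2, PP_DEGREE=1, BS=1, GRADIENT_ACCUMULATION_STEPS=8
dp_size, batch = effective_batch_size(8, 2, 1, 1, 8)
print(dp_size, batch)  # 4 32
```

Under these assumptions, changing `TP_DEGREE` without adjusting `GRADIENT_ACCUMULATION_STEPS` also changes the effective batch size, which is worth keeping in mind when comparing runs.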
diff --git a/docs/source/training_tutorials/sft_lora_finetune_llm.py b/docs/source/training_tutorials/sft_lora_finetune_llm.py
new file mode 100644
index 000000000..9c383ff85
--- /dev/null
+++ b/docs/source/training_tutorials/sft_lora_finetune_llm.py
@@ -0,0 +1,87 @@
+from dataclasses import dataclass, field
+
+from datasets import load_dataset
+from peft import LoraConfig
+from transformers import (
+    AutoModelForCausalLM,
+    AutoTokenizer,
+    set_seed,
+)
+
+from optimum.neuron import NeuronHfArgumentParser as HfArgumentParser
+from optimum.neuron import NeuronSFTConfig, NeuronSFTTrainer, NeuronTrainingArguments
+from optimum.neuron.distributed import lazy_load_for_parallelism
+
+
+def format_dolly(examples):
+    output_text = []
+    for i in range(len(examples["instruction"])):
+        instruction = f"### Instruction\n{examples['instruction'][i]}"
+        context = f"### Context\n{examples['context'][i]}" if len(examples["context"][i]) > 0 else None
+        response = f"### Answer\n{examples['response'][i]}"
+        prompt = "\n\n".join([part for part in [instruction, context, response] if part is not None])
+        output_text.append(prompt)
+    return output_text
+
+
+def training_function(script_args, training_args):
+    dataset = load_dataset("databricks/databricks-dolly-15k", split="train")
+
+    tokenizer = AutoTokenizer.from_pretrained(script_args.model_id)
+    tokenizer.pad_token = tokenizer.eos_token
+
+    with lazy_load_for_parallelism(tensor_parallel_size=training_args.tensor_parallel_size):
+        model = AutoModelForCausalLM.from_pretrained(script_args.model_id)
+
+    config = LoraConfig(
+        r=16,
+        lora_alpha=16,
+        lora_dropout=0.05,
+        target_modules=["q_proj", "gate_proj", "v_proj", "o_proj", "k_proj", "up_proj", "down_proj"],
+        bias="none",
+        task_type="CAUSAL_LM",
+    )
+
+    args = training_args.to_dict()
+    sft_config = NeuronSFTConfig(
+        max_seq_length=1024,
+        packing=False,
+        **args,
+    )
+
+    trainer = NeuronSFTTrainer(
+        args=sft_config,
+        model=model,
+        peft_config=config,
+        tokenizer=tokenizer,
+        train_dataset=dataset,
+        formatting_func=format_dolly,
+    )
+
+    # Start training
+    trainer.train()
+
+    trainer.save_model()  # Saves the tokenizer too for easy upload
+
+
+@dataclass
+class ScriptArguments:
+    model_id: str = field(
+        default="meta-llama/Meta-Llama-3-8B",
+        metadata={"help": "The model that you want to train from the Hugging Face hub."},
+    )
+
+
+def main():
+    parser = HfArgumentParser([ScriptArguments, NeuronTrainingArguments])
+    script_args, training_args = parser.parse_args_into_dataclasses()
+
+    # set seed
+    set_seed(training_args.seed)
+
+    # run training function
+    training_function(script_args, training_args)
+
+
+if __name__ == "__main__":
+    main()
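Note on the `LoraConfig` used in the script above: the number of trainable adapter parameters can be estimated by hand, since a rank-`r` adapter on a linear layer with `in` inputs and `out` outputs adds `r * (in + out)` weights. This is a back-of-the-envelope sketch, not something the script computes. The layer shapes below are assumptions taken from the public Llama 3 8B configuration (hidden size 4096, MLP size 14336, 32 layers, 8 key/value heads of dimension 128), and the helper name `lora_param_count` is hypothetical.

```python
def lora_param_count(r, shapes, num_layers):
    """Count LoRA weights: each adapted (in, out) linear layer gets A (r x in) and B (out x r)."""
    per_layer = sum(r * (in_features + out_features) for in_features, out_features in shapes.values())
    return per_layer * num_layers


# Assumed (in_features, out_features) of the targeted modules in one Llama 3 8B decoder layer.
LLAMA3_8B_SHAPES = {
    "q_proj": (4096, 4096),
    "k_proj": (4096, 1024),
    "v_proj": (4096, 1024),
    "o_proj": (4096, 4096),
    "gate_proj": (4096, 14336),
    "up_proj": (4096, 14336),
    "down_proj": (14336, 4096),
}

print(lora_param_count(16, LLAMA3_8B_SHAPES, 32))  # 41943040
```

With `bias="none"` no bias terms are trained, so under these assumptions roughly 42 million parameters are updated, a small fraction of the 8B base model.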
diff --git a/optimum/neuron/utils/peft_utils.py b/optimum/neuron/utils/peft_utils.py
index 355c83846..4855f2ef2 100644
--- a/optimum/neuron/utils/peft_utils.py
+++ b/optimum/neuron/utils/peft_utils.py
@@ -174,6 +174,10 @@ class DummyModule(torch.nn.Module):
def state_dict(self):
return output_state_dict
+        adapter_shards_dir_model = os.path.join(output_dir, "adapter_shards", "model")
+        # exist_ok avoids a race when several workers try to create the directory at once.
+        os.makedirs(adapter_shards_dir_model, exist_ok=True)
+
dummy_mod = DummyModule()
neuronx_distributed.trainer.save_checkpoint(
output_dir,