Fix DeepSpeed crash with Sentence Transformer Trainer (#1328)
Co-authored-by: ZhengHongming888 <hongming.zheng@intel.com>
Co-authored-by: Yaser Afshar <yaser.afshar@intel.com>
3 people committed Sep 24, 2024
1 parent 1a52079 commit 091c8a5
Showing 12 changed files with 236 additions and 33 deletions.
28 changes: 26 additions & 2 deletions examples/sentence-transformers-training/nli/README.md
@@ -4,6 +4,8 @@ Given two sentences (premise and hypothesis), the task of Natural Language Infer

[Conneau et al.](https://arxiv.org/abs/1705.02364) show that NLI data can be quite useful for training sentence embedding methods. The [Sentence-BERT paper](https://arxiv.org/abs/1908.10084) uses NLI as a first fine-tuning step for sentence embedding methods.

# General Models

## Single-card Training

To pre-train on the NLI task:
@@ -46,7 +48,29 @@ For multi-card training you can use the script of [gaudi_spawn.py](https://githu
HABANA_VISIBLE_MODULES="2,3" python ../../gaudi_spawn.py --use_deepspeed --world_size 2 training_nli.py bert-base-uncased
```

## Dataset

# Large Models (intfloat/e5-mistral-7b-instruct)

## Single-card Training with LoRA and Gradient Checkpointing

Pretraining the `intfloat/e5-mistral-7b-instruct` model requires approximately 130GB of memory, which exceeds the capacity of a single HPU (Gaudi 2 with 98GB of memory). To address this, we can use LoRA and gradient checkpointing to reduce the memory requirements, making it feasible to train the model on a single HPU.

```bash
python training_nli.py intfloat/e5-mistral-7b-instruct --peft --lora_target_modules "q_proj" "k_proj" "v_proj" --learning_rate 1e-5
```
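
For reference, the `--peft` flag corresponds roughly to the following setup inside `training_nli.py` (a minimal sketch mirroring the changes in this commit; the rank and alpha values are the script's defaults, not tuning recommendations):

```python
from peft import LoraConfig, get_peft_model
from sentence_transformers import SentenceTransformer

# Wrap the Sentence Transformer in LoRA adapters and enable gradient
# checkpointing to cut activation memory, as the --peft code path does.
model = SentenceTransformer("intfloat/e5-mistral-7b-instruct")
peft_config = LoraConfig(
    r=16,
    lora_alpha=64,
    lora_dropout=0.05,
    bias="none",
    inference_mode=False,
    target_modules=["q_proj", "k_proj", "v_proj"],
)
model = get_peft_model(model, peft_config)
model.print_trainable_parameters()
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
```

After training, the script's `--peft` branch merges the adapters back into the base weights with `merge_and_unload()` and saves the merged model (see the end of `main()` in the diff below).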

## Multi-card Training with DeepSpeed ZeRO-2/3

Pretraining the `intfloat/e5-mistral-7b-instruct` model requires approximately 130GB of memory, which exceeds the capacity of a single HPU (Gaudi 2 with 98GB of memory). To address this, we can use DeepSpeed ZeRO stage 2 or stage 3, which partition optimizer states, gradients, and (at stage 3) model parameters across devices, reducing the per-device memory requirements.

Our tests have shown that training this model requires at least four HPUs when using DeepSpeed ZeRO-2.

```bash
python ../../gaudi_spawn.py --world_size 4 --use_deepspeed training_nli.py intfloat/e5-mistral-7b-instruct --deepspeed ds_config.json --bf16 --no-use_hpu_graphs_for_training --learning_rate 1e-7
```
The command above disables HPU graphs for training (so the run uses lazy mode), sets the learning rate to `1e-7`, and configures DeepSpeed through the `ds_config.json` file. To reduce memory usage further, change the `stage` to `3` (DeepSpeed ZeRO-3) in `ds_config.json`.
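
One way to do that, as an illustration only (not part of this commit), is to derive a ZeRO-3 variant from the `ds_config.json` added below; the output file name here is hypothetical:

```python
import json

# Read the ZeRO-2 config shipped with this example and switch to ZeRO-3,
# which additionally partitions the model parameters across devices.
with open("ds_config.json") as f:
    config = json.load(f)

config["zero_optimization"]["stage"] = 3

with open("ds_config_zero3.json", "w") as f:
    json.dump(config, f, indent=2)
```

The resulting file would then be passed via `--deepspeed ds_config_zero3.json` in place of `ds_config.json` in the launch command above.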

# Dataset

We combine [SNLI](https://huggingface.co/datasets/stanfordnlp/snli) and [MultiNLI](https://huggingface.co/datasets/nyu-mll/multi_nli) into a dataset we call [AllNLI](https://huggingface.co/datasets/sentence-transformers/all-nli). These two datasets contain sentence pairs and one of three labels: entailment, neutral, contradiction:

@@ -58,7 +82,7 @@ We combine [SNLI](https://huggingface.co/datasets/stanfordnlp/snli) and [MultiNL

We format AllNLI in a few different subsets, compatible with different loss functions. See the [triplet subset of AllNLI](https://huggingface.co/datasets/sentence-transformers/all-nli/viewer/triplet) as an example.
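
As a quick sanity check (a sketch, not part of the example scripts), the triplet subset can be loaded directly with `datasets`:

```python
from datasets import load_dataset

# Each row of the triplet subset holds an anchor, a positive and a negative sentence.
train_dataset = load_dataset("sentence-transformers/all-nli", "triplet", split="train")
print(train_dataset[0])
```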

## SoftmaxLoss
# SoftmaxLoss

<img src="https://raw.githubusercontent.com/UKPLab/sentence-transformers/master/docs/img/SBERT_SoftmaxLoss.png" alt="SBERT SoftmaxLoss" width="250"/>

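As a reference for this section, a minimal sketch of how SoftmaxLoss is typically constructed with Sentence Transformers for the three NLI labels (entailment, neutral, contradiction):

```python
from sentence_transformers import SentenceTransformer, losses

model = SentenceTransformer("bert-base-uncased")

# SoftmaxLoss classifies a (premise, hypothesis) pair from the concatenated
# sentence embeddings into one of the three NLI labels.
train_loss = losses.SoftmaxLoss(
    model=model,
    sentence_embedding_dimension=model.get_sentence_embedding_dimension(),
    num_labels=3,
)
```
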
16 changes: 16 additions & 0 deletions examples/sentence-transformers-training/nli/ds_config.json
@@ -0,0 +1,16 @@
{
"steps_per_print": 1,
"train_batch_size": "auto",
"train_micro_batch_size_per_gpu": "auto",
"gradient_accumulation_steps": "auto",
"bf16": {
"enabled": true
},
"gradient_clipping": 1.0,
"zero_optimization": {
"stage": 2,
"overlap_comm": false,
"reduce_scatter": false,
"contiguous_gradients": false
}
}
58 changes: 46 additions & 12 deletions examples/sentence-transformers-training/nli/training_nli.py
@@ -4,8 +4,8 @@
STS benchmark dataset
"""

import argparse
import logging
import sys
from datetime import datetime

from datasets import load_dataset
@@ -28,16 +28,43 @@ def main():
logging.basicConfig(format="%(asctime)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO)

# You can specify any Hugging Face pre-trained model here, for example, bert-base-uncased, roberta-base, xlm-roberta-base
model_name = sys.argv[1] if len(sys.argv) > 1 else "bert-base-uncased"
train_batch_size = 16
parser = argparse.ArgumentParser()
parser.add_argument("model_name", help="model name or path", default="bert-base-uncased", nargs="?")
parser.add_argument("--peft", help="use LoRA", action="store_true", default=False)
parser.add_argument("--lora_target_modules", nargs="+", default=["query", "key", "value"])
parser.add_argument("--bf16", help="use bf16", action="store_true", default=False)
parser.add_argument(
"--use_hpu_graphs_for_training",
help="use hpu graphs for training",
action=argparse.BooleanOptionalAction,
default=True,
)
parser.add_argument("--learning_rate", help="learning rate", type=float, default=5e-5)
parser.add_argument("--deepspeed", help="deepspeed config file", default=None)
parser.add_argument("--train_batch_size", help="train batch size", default=16, type=int)
args = parser.parse_args()

output_dir = (
"output/training_nli_" + model_name.replace("/", "-") + "-" + datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
"output/training_nli_" + args.model_name.replace("/", "-") + "-" + datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
)

# 1. Here we define our SentenceTransformer model. If not already a Sentence Transformer model, it will automatically
# create one with "mean" pooling.
model = SentenceTransformer(model_name)
model = SentenceTransformer(args.model_name)
if args.peft:
from peft import LoraConfig, get_peft_model

peft_config = LoraConfig(
r=16,
lora_alpha=64,
lora_dropout=0.05,
bias="none",
inference_mode=False,
target_modules=args.lora_target_modules,
)
model = get_peft_model(model, peft_config)
model.print_trainable_parameters()
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})

# 2. Load the AllNLI dataset: https://huggingface.co/datasets/sentence-transformers/all-nli
# We'll start with 10k training samples, but you can increase this to get a stronger model
@@ -66,16 +93,16 @@ def main():
dev_evaluator(model)

# 5. Define the training arguments
args = SentenceTransformerGaudiTrainingArguments(
stargs = SentenceTransformerGaudiTrainingArguments(
# Required parameter:
output_dir=output_dir,
# Optional training parameters:
num_train_epochs=1,
per_device_train_batch_size=train_batch_size,
per_device_eval_batch_size=train_batch_size,
per_device_train_batch_size=args.train_batch_size,
per_device_eval_batch_size=args.train_batch_size,
warmup_ratio=0.1,
# fp16=True, # Set to False if you get an error that your GPU can't run on FP16
# bf16=False, # Set to True if you have a GPU that supports BF16
bf16=args.bf16, # Set to True if you have a GPU that supports BF16
# Optional tracking/debugging parameters:
evaluation_strategy="steps",
eval_steps=100,
@@ -87,16 +114,18 @@ def main():
use_habana=True,
gaudi_config_name="Habana/bert-base-uncased",
use_lazy_mode=True,
use_hpu_graphs=True,
use_hpu_graphs=args.use_hpu_graphs_for_training,
use_hpu_graphs_for_inference=False,
use_hpu_graphs_for_training=True,
use_hpu_graphs_for_training=args.use_hpu_graphs_for_training,
dataloader_drop_last=True,
learning_rate=args.learning_rate,
deepspeed=args.deepspeed,
)

# 6. Create the trainer & start training
trainer = SentenceTransformerGaudiTrainer(
model=model,
args=args,
args=stargs,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
loss=train_loss,
@@ -119,6 +148,11 @@ def main():
final_output_dir = f"{output_dir}/final"
model.save(final_output_dir)

if args.peft:
model.eval()
model = model.merge_and_unload()
model.save_pretrained(f"{output_dir}/merged")


if __name__ == "__main__":
main()
29 changes: 27 additions & 2 deletions examples/sentence-transformers-training/sts/README.md
@@ -5,6 +5,8 @@ Semantic Textual Similarity (STS) assigns a score on the similarity of two texts
- **[training_stsbenchmark.py](training_stsbenchmark.py)** - This example shows how to create a SentenceTransformer model from scratch by using a pre-trained transformer model (e.g. [`distilbert-base-uncased`](https://huggingface.co/distilbert/distilbert-base-uncased)) together with a pooling layer.
- **[training_stsbenchmark_continue_training.py](training_stsbenchmark_continue_training.py)** - This example shows how to continue training on STS data for a previously created & trained SentenceTransformer model (e.g. [`all-mpnet-base-v2`](https://huggingface.co/sentence-transformers/all-mpnet-base-v2)).

# General Models

## Single-card Training

To fine-tune on the STS task:
@@ -33,7 +35,30 @@ For multi-card training you can use the script of [gaudi_spawn.py](https://githu
HABANA_VISIBLE_MODULES="2,3" python ../../gaudi_spawn.py --use_deepspeed --world_size 2 training_stsbenchmark.py bert-base-uncased
```

## Training data

# Large Models (intfloat/e5-mistral-7b-instruct)

## Single-card Training with LoRA and Gradient Checkpointing

Pretraining the `intfloat/e5-mistral-7b-instruct` model requires approximately 130GB of memory, which exceeds the capacity of a single HPU (Gaudi 2 with 98GB of memory). To address this, we can use LoRA and gradient checkpointing to reduce the memory requirements, making it feasible to train the model on a single HPU.

```bash
python training_stsbenchmark.py intfloat/e5-mistral-7b-instruct --peft --lora_target_modules "q_proj" "k_proj" "v_proj"
```

## Multi-card Training with DeepSpeed ZeRO-2/3

Pretraining the `intfloat/e5-mistral-7b-instruct` model requires approximately 130GB of memory, which exceeds the capacity of a single HPU (Gaudi 2 with 98GB of memory). To address this, we can use DeepSpeed ZeRO stage 2 or stage 3, which partition optimizer states, gradients, and (at stage 3) model parameters across devices, reducing the per-device memory requirements.

Our tests have shown that training this model requires at least four HPUs when using DeepSpeed ZeRO-2.

```bash
python ../../gaudi_spawn.py --world_size 4 --use_deepspeed training_stsbenchmark.py intfloat/e5-mistral-7b-instruct --deepspeed ds_config.json --bf16 --no-use_hpu_graphs_for_training --learning_rate 1e-7
```

The command above disables HPU graphs for training (so the run uses lazy mode), sets the learning rate to `1e-7`, and configures DeepSpeed through the `ds_config.json` file. To reduce memory usage further, change the `stage` to `3` (DeepSpeed ZeRO-3) in `ds_config.json`.

# Training data

Here is a simplified version of our training data:

@@ -70,7 +95,7 @@ train_dataset = load_dataset("sentence-transformers/stsb", split="train")
# })
```

## Loss Function
# Loss Function

<img src="https://raw.githubusercontent.com/UKPLab/sentence-transformers/master/docs/img/SBERT_Siamese_Network.png" alt="SBERT Siamese Network Architecture" width="250"/>

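The STS examples in Sentence Transformers typically train this siamese setup with `CosineSimilarityLoss` on (sentence1, sentence2, score) rows; a minimal sketch under that assumption:

```python
from datasets import load_dataset
from sentence_transformers import SentenceTransformer, losses

model = SentenceTransformer("distilbert-base-uncased")

# STSb rows are (sentence1, sentence2, score); the loss pushes the cosine
# similarity of the two sentence embeddings toward the (normalized) score.
train_dataset = load_dataset("sentence-transformers/stsb", split="train")
train_loss = losses.CosineSimilarityLoss(model=model)
```
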
16 changes: 16 additions & 0 deletions examples/sentence-transformers-training/sts/ds_config.json
@@ -0,0 +1,16 @@
{
"steps_per_print": 1,
"train_batch_size": "auto",
"train_micro_batch_size_per_gpu": "auto",
"gradient_accumulation_steps": "auto",
"bf16": {
"enabled": true
},
"gradient_clipping": 1.0,
"zero_optimization": {
"stage": 2,
"overlap_comm": false,
"reduce_scatter": false,
"contiguous_gradients": false
}
}
examples/sentence-transformers-training/sts/training_stsbenchmark.py
@@ -4,8 +4,8 @@
"""

import argparse
import logging
import sys
from datetime import datetime

from datasets import load_dataset
@@ -25,19 +25,48 @@ def main():
logging.basicConfig(format="%(asctime)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO)

# You can specify any Hugging Face pre-trained model here, for example, bert-base-uncased, roberta-base, xlm-roberta-base
model_name = sys.argv[1] if len(sys.argv) > 1 else "distilbert-base-uncased"
parser = argparse.ArgumentParser()
parser.add_argument("model_name", help="model name or path", default="distilbert-base-uncased", nargs="?")
parser.add_argument("--peft", help="use LoRA", action="store_true", default=False)
parser.add_argument("--lora_target_modules", nargs="+", default=["q_lin", "k_lin", "v_lin"])
parser.add_argument("--bf16", help="use bf16", action="store_true", default=False)
parser.add_argument(
"--use_hpu_graphs_for_training",
help="use hpu graphs for training",
action=argparse.BooleanOptionalAction,
default=True,
)
parser.add_argument("--learning_rate", help="learning rate", type=float, default=5e-5)
parser.add_argument("--deepspeed", help="deepspeed config file", default=None)
args = parser.parse_args()

train_batch_size = 16
num_epochs = 1
output_dir = (
"output/training_stsbenchmark_"
+ model_name.replace("/", "-")
+ args.model_name.replace("/", "-")
+ "-"
+ datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
)

# 1. Here we define our SentenceTransformer model. If not already a Sentence Transformer model, it will automatically
# create one with "mean" pooling.
model = SentenceTransformer(model_name)
model = SentenceTransformer(args.model_name)

if args.peft:
from peft import LoraConfig, get_peft_model

peft_config = LoraConfig(
r=16,
lora_alpha=64,
lora_dropout=0.05,
bias="none",
inference_mode=False,
target_modules=args.lora_target_modules,
)
model = get_peft_model(model, peft_config)
model.print_trainable_parameters()
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})

# 2. Load the STSB dataset: https://huggingface.co/datasets/sentence-transformers/stsb
train_dataset = load_dataset("sentence-transformers/stsb", split="train")
@@ -61,7 +90,7 @@
)

# 5. Define the training arguments
args = SentenceTransformerGaudiTrainingArguments(
stargs = SentenceTransformerGaudiTrainingArguments(
# Required parameter:
output_dir=output_dir,
# Optional training parameters:
@@ -70,7 +99,7 @@
per_device_eval_batch_size=train_batch_size,
warmup_ratio=0.1,
# fp16=True, # Set to False if you get an error that your GPU can't run on FP16
# bf16=True, # Set to True if you have a GPU that supports BF16
bf16=args.bf16, # Set to True if you have a GPU that supports BF16
# Optional tracking/debugging parameters:
evaluation_strategy="steps",
eval_steps=100,
@@ -82,16 +111,18 @@ def main():
use_habana=True,
gaudi_config_name="Habana/distilbert-base-uncased",
use_lazy_mode=True,
use_hpu_graphs=True,
use_hpu_graphs=args.use_hpu_graphs_for_training,
use_hpu_graphs_for_inference=False,
use_hpu_graphs_for_training=True,
use_hpu_graphs_for_training=args.use_hpu_graphs_for_training,
learning_rate=args.learning_rate,
deepspeed=args.deepspeed,
)

# 6. Create the trainer & start training
# trainer = SentenceTransformerTrainer(
trainer = SentenceTransformerGaudiTrainer(
model=model,
args=args,
args=stargs,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
loss=train_loss,
@@ -113,6 +144,11 @@ def main():
final_output_dir = f"{output_dir}/final"
model.save(final_output_dir)

if args.peft:
model.eval()
model = model.merge_and_unload()
model.save_pretrained(f"{output_dir}/merged")


if __name__ == "__main__":
main()
2 changes: 1 addition & 1 deletion optimum/habana/sentence_transformers/__init__.py
@@ -19,5 +19,5 @@
from .st_gaudi_trainer import SentenceTransformerGaudiTrainer
from .st_gaudi_training_args import SentenceTransformerGaudiTrainingArguments
from .st_gaudi_encoder import st_gaudi_encode
from .st_gaudi_transformer_tokenize import st_gaudi_transformer_tokenize
from .st_gaudi_transformer import st_gaudi_transformer_tokenize, st_gaudi_transformer_save
from .st_gaudi_data_collator import st_gaudi_data_collator_call
2 changes: 2 additions & 0 deletions optimum/habana/sentence_transformers/modeling_utils.py
@@ -25,6 +25,7 @@ def adapt_sentence_transformers_to_gaudi():
from optimum.habana.sentence_transformers import (
st_gaudi_data_collator_call,
st_gaudi_encode,
st_gaudi_transformer_save,
st_gaudi_transformer_tokenize,
)

@@ -33,6 +34,7 @@ def adapt_sentence_transformers_to_gaudi():
from sentence_transformers.models import Transformer

Transformer.tokenize = st_gaudi_transformer_tokenize
Transformer.save = st_gaudi_transformer_save

from sentence_transformers.data_collator import SentenceTransformerDataCollator
