forked from huggingface/optimum-habana
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Enable Sentence Transformer Trainer with Gaudi (huggingface#1111)
Co-authored-by: Daniel Huang <daniel1.huang@intel.com> Co-authored-by: regisss <15324346+regisss@users.noreply.github.com>
- Loading branch information
1 parent
c0606c8
commit ec90e05
Showing
21 changed files
with
2,581 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,65 @@ | ||
# Natural Language Inference | ||
|
||
Given two sentences (premise and hypothesis), the task of Natural Language Inference (NLI) is to decide if the premise entails the hypothesis, if they are contradiction, or if they are neutral. Commonly the NLI dataset in [SNLI](https://huggingface.co/datasets/stanfordnlp/snli) and [MultiNLI](https://huggingface.co/datasets/nyu-mll/multi_nli) are used. | ||
|
||
The paper in [Conneau et al.](https://arxiv.org/abs/1705.02364) shows that NLI data can be quite useful when training Sentence Embedding methods. In [Sentence-BERT-Paper](https://arxiv.org/abs/1908.10084) NLI as a first fine-tuning step for sentence embedding methods has been used. | ||
|
||
## Single-card Training | ||
|
||
To pre-train on the NLI task: | ||
|
||
1. Choose a pre-trained model `<model_name>` (for example: [bert-base-uncased](https://huggingface.co/google-bert/bert-base-uncased)). | ||
|
||
2. Load the training, validation, and test datasets. Below is an example of using the [AllNLI dataset](https://huggingface.co/datasets/sentence-transformers/all-nli) for training and validation, while the test set uses the STS Benchmark dataset. | ||
|
||
```python | ||
train_dataset = load_dataset("sentence-transformers/all-nli", "pair-class", split="train").select(range(10000)) | ||
eval_dataset = load_dataset("sentence-transformers/all-nli", "pair-class", split="dev").select(range(1000)) | ||
test_dataset = load_dataset("sentence-transformers/stsb", split="test") | ||
``` | ||
|
||
3. Choose one of the following scripts based on the loss model: | ||
|
||
a. **[training_nli.py](training_nli.py)**: | ||
|
||
> This example uses `sentence_transformers.losses.SoftmaxLoss` as described in the original [Sentence Transformers paper](https://arxiv.org/abs/1908.10084). | ||
|
||
b. **[training_nli_v2.py](training_nli_v2.py)**: | ||
|
||
> The `sentence_transformers.losses.SoftmaxLoss` as used in our original SBERT paper does not yield optimal performance. A better loss is `sentence_transformers.losses.MultipleNegativesRankingLoss`, where we provide pairs or triplets. In this script, we provide a triplet of the format: (anchor, entailment_sentence, contradiction_sentence). The NLI data provides such triplets. The `sentence_transformers.losses.MultipleNegativesRankingLoss` yields much higher performances and is more intuitive than `sentence_transformers.losses.SoftmaxLoss`. We have used this loss to train the paraphrase model in our [Making Monolingual Sentence Embeddings Multilingual using Knowledge Distillation](https://arxiv.org/abs/2004.09813) paper. | ||
|
||
c) **[training_nli_v3.py](training_nli_v3.py)** | ||
|
||
> Following the [GISTEmbed](https://arxiv.org/abs/2402.16829) paper, we can modify the in-batch negative selection from `sentence_transformers.losses.MultipleNegativesRankingLoss` using a guiding model. Candidate negative pairs are ignored during training if the guiding model considers the pair to be too similar. In practice, the `sentence_transformers.losses.GISTEmbedLoss` tends to produce a stronger training signal than `sentence_transformers.losses.MultipleNegativesRankingLoss` at the cost of some training overhead for running inference on the guiding model. | ||
|
||
4. Execute the script: | ||
|
||
```bash | ||
python training_nli.py bert-base-uncased | ||
``` | ||
|
||
## Multi-card Training | ||
|
||
For multi-card training you can use the script of [gaudi_spawn.py](https://github.com/huggingface/optimum-habana/blob/main/examples/gaudi_spawn.py) to execute. There are two options to run the multi-card training by using '--use_deepspeed' or '--use_mpi'. We take the option of '--use_deepspeed' for our example of multi-card training. | ||
|
||
```bash | ||
HABANA_VISIBLE_MODULES="2,3" python ../../gaudi_spawn.py --use_deepspeed --world_size 2 training_nli.py bert-base-uncased | ||
``` | ||
|
||
## Dataset | ||
|
||
We combine [SNLI](https://huggingface.co/datasets/stanfordnlp/snli) and [MultiNLI](https://huggingface.co/datasets/nyu-mll/multi_nli) into a dataset we call [AllNLI](https://huggingface.co/datasets/sentence-transformers/all-nli). These two datasets contain sentence pairs and one of three labels: entailment, neutral, contradiction: | ||
|
||
| Sentence A (Premise) | Sentence B (Hypothesis) | Label | | ||
| ------------------------------------------------------------------ | ------------------------------------------------------------------ | ------------- | | ||
| A soccer game with multiple males playing. | Some men are playing a sport. | entailment | | ||
| An older and younger man smiling. | Two men are smiling and laughing at the cats playing on the floor. | neutral | | ||
| A man inspects the uniform of a figure in some East Asian country. | The man is sleeping. | contradiction | | ||
|
||
We format AllNLI in a few different subsets, compatible with different loss functions. See [triplet subset of AllNLI](https://huggingface.co/datasets/sentence-transformers/all-nli/viewer/triplet) as example. | ||
|
||
## SoftmaxLoss | ||
|
||
<img src="https://raw.githubusercontent.com/UKPLab/sentence-transformers/master/docs/img/SBERT_SoftmaxLoss.png" alt="SBERT SoftmaxLoss" width="250"/> | ||
|
||
We pass the two sentences through our SentenceTransformer model and get the sentence embeddings _u_ and _v_. We then concatenate _u_, _v_ and _|u-v|_ to form one long vector. This vector is then passed to a softmax classifier, which predicts our three classes (entailment, neutral, contradiction). |
124 changes: 124 additions & 0 deletions
124
examples/sentence-transformers-training/nli/training_nli.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,124 @@ | ||
""" | ||
The system trains BERT (or any other transformer model like RoBERTa, DistilBERT etc.) on the SNLI + MultiNLI (AllNLI) dataset | ||
with softmax loss function. At every 100 training steps, the model is evaluated on the | ||
STS benchmark dataset | ||
""" | ||
|
||
import logging | ||
import sys | ||
from datetime import datetime | ||
|
||
from datasets import load_dataset | ||
from sentence_transformers import SentenceTransformer, losses | ||
from sentence_transformers.evaluation import EmbeddingSimilarityEvaluator | ||
from sentence_transformers.similarity_functions import SimilarityFunction | ||
|
||
from optimum.habana import ( | ||
SentenceTransformerGaudiTrainer, | ||
SentenceTransformerGaudiTrainingArguments, | ||
) | ||
from optimum.habana.sentence_transformers.modeling_utils import adapt_sentence_transformers_to_gaudi | ||
|
||
|
||
adapt_sentence_transformers_to_gaudi() | ||
|
||
|
||
def main(): | ||
# Set the log level to INFO to get more information | ||
logging.basicConfig(format="%(asctime)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO) | ||
|
||
# You can specify any Hugging Face pre-trained model here, for example, bert-base-uncased, roberta-base, xlm-roberta-base | ||
model_name = sys.argv[1] if len(sys.argv) > 1 else "bert-base-uncased" | ||
train_batch_size = 16 | ||
|
||
output_dir = ( | ||
"output/training_nli_" + model_name.replace("/", "-") + "-" + datetime.now().strftime("%Y-%m-%d_%H-%M-%S") | ||
) | ||
|
||
# 1. Here we define our SentenceTransformer model. If not already a Sentence Transformer model, it will automatically | ||
# create one with "mean" pooling. | ||
model = SentenceTransformer(model_name) | ||
|
||
# 2. Load the AllNLI dataset: https://huggingface.co/datasets/sentence-transformers/all-nli | ||
# We'll start with 10k training samples, but you can increase this to get a stronger model | ||
logging.info("Read AllNLI train dataset") | ||
train_dataset = load_dataset("sentence-transformers/all-nli", "pair-class", split="train").select(range(10000)) | ||
eval_dataset = load_dataset("sentence-transformers/all-nli", "pair-class", split="dev").select(range(1000)) | ||
logging.info(train_dataset) | ||
|
||
# 3. Define our training loss: https://sbert.net/docs/package_reference/sentence_transformer/losses.html#softmaxloss | ||
train_loss = losses.SoftmaxLoss( | ||
model=model, | ||
sentence_embedding_dimension=model.get_sentence_embedding_dimension(), | ||
num_labels=3, | ||
) | ||
|
||
# 4. Define an evaluator for use during training. This is useful to keep track of alongside the evaluation loss. | ||
stsb_eval_dataset = load_dataset("sentence-transformers/stsb", split="validation") | ||
dev_evaluator = EmbeddingSimilarityEvaluator( | ||
sentences1=stsb_eval_dataset["sentence1"], | ||
sentences2=stsb_eval_dataset["sentence2"], | ||
scores=stsb_eval_dataset["score"], | ||
main_similarity=SimilarityFunction.COSINE, | ||
name="sts-dev", | ||
) | ||
logging.info("Evaluation before training:") | ||
dev_evaluator(model) | ||
|
||
# 5. Define the training arguments | ||
args = SentenceTransformerGaudiTrainingArguments( | ||
# Required parameter: | ||
output_dir=output_dir, | ||
# Optional training parameters: | ||
num_train_epochs=1, | ||
per_device_train_batch_size=train_batch_size, | ||
per_device_eval_batch_size=train_batch_size, | ||
warmup_ratio=0.1, | ||
# fp16=True, # Set to False if you get an error that your GPU can't run on FP16 | ||
# bf16=False, # Set to True if you have a GPU that supports BF16 | ||
# Optional tracking/debugging parameters: | ||
evaluation_strategy="steps", | ||
eval_steps=100, | ||
save_strategy="steps", | ||
save_steps=100, | ||
save_total_limit=2, | ||
logging_steps=100, | ||
run_name="nli-v1", # Will be used in W&B if `wandb` is installed | ||
use_habana=True, | ||
gaudi_config_name="Habana/bert-base-uncased", | ||
use_lazy_mode=True, | ||
use_hpu_graphs=True, | ||
use_hpu_graphs_for_inference=False, | ||
use_hpu_graphs_for_training=True, | ||
dataloader_drop_last=True, | ||
) | ||
|
||
# 6. Create the trainer & start training | ||
trainer = SentenceTransformerGaudiTrainer( | ||
model=model, | ||
args=args, | ||
train_dataset=train_dataset, | ||
eval_dataset=eval_dataset, | ||
loss=train_loss, | ||
evaluator=dev_evaluator, | ||
) | ||
trainer.train() | ||
|
||
# 7. Evaluate the model performance on the STS Benchmark test dataset | ||
test_dataset = load_dataset("sentence-transformers/stsb", split="test") | ||
test_evaluator = EmbeddingSimilarityEvaluator( | ||
sentences1=test_dataset["sentence1"], | ||
sentences2=test_dataset["sentence2"], | ||
scores=test_dataset["score"], | ||
main_similarity=SimilarityFunction.COSINE, | ||
name="sts-test", | ||
) | ||
test_evaluator(model) | ||
|
||
# 8. Save the trained & evaluated model locally | ||
final_output_dir = f"{output_dir}/final" | ||
model.save(final_output_dir) | ||
|
||
|
||
if __name__ == "__main__": | ||
main() |
130 changes: 130 additions & 0 deletions
130
examples/sentence-transformers-training/nli/training_nli_v2.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,130 @@ | ||
""" | ||
The system trains BERT (or any other transformer model like RoBERTa, DistilBERT etc.) on the SNLI + MultiNLI (AllNLI) dataset | ||
with MultipleNegativesRankingLoss. Entailments are positive pairs and the contradiction on AllNLI dataset is added as a hard negative. | ||
At every 10% training steps, the model is evaluated on the STS benchmark dataset | ||
Usage: | ||
python training_nli_v2.py | ||
OR | ||
python training_nli_v2.py pretrained_transformer_model_name | ||
""" | ||
|
||
import logging | ||
import sys | ||
from datetime import datetime | ||
|
||
from datasets import load_dataset | ||
from sentence_transformers import SentenceTransformer, losses | ||
from sentence_transformers.evaluation import EmbeddingSimilarityEvaluator | ||
from sentence_transformers.similarity_functions import SimilarityFunction | ||
from sentence_transformers.training_args import BatchSamplers | ||
|
||
from optimum.habana import ( | ||
SentenceTransformerGaudiTrainer, | ||
SentenceTransformerGaudiTrainingArguments, | ||
) | ||
from optimum.habana.sentence_transformers.modeling_utils import adapt_sentence_transformers_to_gaudi | ||
|
||
|
||
adapt_sentence_transformers_to_gaudi() | ||
|
||
|
||
def main(): | ||
# Set the log level to INFO to get more information | ||
logging.basicConfig(format="%(asctime)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO) | ||
|
||
model_name = sys.argv[1] if len(sys.argv) > 1 else "distilroberta-base" | ||
train_batch_size = ( | ||
16 # The larger you select this, the better the results (usually). But it requires more GPU memory | ||
) | ||
|
||
# Save path of the model | ||
output_dir = ( | ||
"output/training_nli_v2_" + model_name.replace("/", "-") + "-" + datetime.now().strftime("%Y-%m-%d_%H-%M-%S") | ||
) | ||
|
||
# 1. Here we define our SentenceTransformer model. If not already a Sentence Transformer model, it will automatically | ||
# create one with "mean" pooling. | ||
model = SentenceTransformer(model_name) | ||
|
||
# 2. Load the AllNLI dataset: https://huggingface.co/datasets/sentence-transformers/all-nli | ||
# We'll start with 10k training samples, but you can increase this to get a stronger model | ||
logging.info("Read AllNLI train dataset") | ||
train_dataset = load_dataset("sentence-transformers/all-nli", "triplet", split="train").select(range(10000)) | ||
eval_dataset = load_dataset("sentence-transformers/all-nli", "triplet", split="dev").select(range(1000)) | ||
logging.info(train_dataset) | ||
|
||
# 3. Define our training loss: https://sbert.net/docs/package_reference/sentence_transformer/losses.html#multiplenegativesrankingloss | ||
train_loss = losses.MultipleNegativesRankingLoss(model) | ||
|
||
# 4. Define an evaluator for use during training. This is useful to keep track of alongside the evaluation loss. | ||
stsb_eval_dataset = load_dataset("sentence-transformers/stsb", split="validation") | ||
dev_evaluator = EmbeddingSimilarityEvaluator( | ||
sentences1=stsb_eval_dataset["sentence1"], | ||
sentences2=stsb_eval_dataset["sentence2"], | ||
scores=stsb_eval_dataset["score"], | ||
main_similarity=SimilarityFunction.COSINE, | ||
name="sts-dev", | ||
) | ||
logging.info("Evaluation before training:") | ||
dev_evaluator(model) | ||
|
||
# 5. Define the training arguments | ||
args = SentenceTransformerGaudiTrainingArguments( | ||
# Required parameter: | ||
output_dir=output_dir, | ||
# Optional training parameters: | ||
num_train_epochs=1, | ||
per_device_train_batch_size=train_batch_size, | ||
per_device_eval_batch_size=train_batch_size, | ||
warmup_ratio=0.1, | ||
# fp16=True, # Set to False if you get an error that your GPU can't run on FP16 | ||
# bf16=False, # Set to True if you have a GPU that supports BF16 | ||
batch_sampler=BatchSamplers.NO_DUPLICATES, | ||
# Optional tracking/debugging parameters: | ||
evaluation_strategy="steps", | ||
eval_steps=10, | ||
save_strategy="steps", | ||
save_steps=10, | ||
save_total_limit=2, | ||
logging_steps=100, | ||
run_name="nli-v2", # Will be used in W&B if `wandb` is installed | ||
use_habana=True, | ||
gaudi_config_name="Habana/bert-base-uncased", | ||
use_lazy_mode=True, | ||
use_hpu_graphs=True, | ||
use_hpu_graphs_for_inference=False, | ||
use_hpu_graphs_for_training=True, | ||
dataloader_drop_last=True, | ||
) | ||
|
||
# 6. Create the trainer & start training | ||
trainer = SentenceTransformerGaudiTrainer( | ||
model=model, | ||
args=args, | ||
train_dataset=train_dataset, | ||
eval_dataset=eval_dataset, | ||
loss=train_loss, | ||
evaluator=dev_evaluator, | ||
) | ||
trainer.train() | ||
|
||
# 7. Evaluate the model performance on the STS Benchmark test dataset | ||
test_dataset = load_dataset("sentence-transformers/stsb", split="test") | ||
test_evaluator = EmbeddingSimilarityEvaluator( | ||
sentences1=test_dataset["sentence1"], | ||
sentences2=test_dataset["sentence2"], | ||
scores=test_dataset["score"], | ||
main_similarity=SimilarityFunction.COSINE, | ||
name="sts-test", | ||
) | ||
test_evaluator(model) | ||
|
||
# 8. Save the trained & evaluated model locally | ||
final_output_dir = f"{output_dir}/final" | ||
model.save(final_output_dir) | ||
|
||
|
||
if __name__ == "__main__": | ||
main() |
Oops, something went wrong.