Added All v4 Dataset Results and CachedMNRL Loss Training
w11wo committed May 15, 2024
1 parent dab15bb commit 66b3e22
Showing 8 changed files with 272 additions and 21 deletions.
20 changes: 19 additions & 1 deletion README.md

Large diffs are not rendered by default.

35 changes: 33 additions & 2 deletions docs/index.md

Large diffs are not rendered by default.

49 changes: 34 additions & 15 deletions docs/training/all.md
@@ -4,21 +4,25 @@ Inspired by [all-MiniLM-L6-v2](https://huggingface.co/sentence-transformers/all-

## Training Data

| Dataset | Task | Data Instance | Number of Training Tuples |
| ------------------------------------------------------------------------------------------------------------------ | :----------------------------: | :-------------------------------------------: | :-----------------------: |
| [indonli](https://huggingface.co/datasets/indonli) | Natural Language Inference | `(premise, entailment, contradiction)` | 3,914 |
| [indolem/indo_story_cloze](https://huggingface.co/datasets/indolem/indo_story_cloze) | Commonsense Reasoning | `(context, correct ending, incorrect ending)` | 1,000 |
| [unicamp-dl/mmarco](https://huggingface.co/datasets/unicamp-dl/mmarco) | Passage Retrieval | `(query, positive passage, negative passage)` | 100,000 |
| [miracl/miracl](https://huggingface.co/datasets/miracl/miracl) | Passage Retrieval | `(query, positive passage, negative passage)` | 8,086 |
| [SEACrowd/wrete](https://huggingface.co/datasets/SEACrowd/wrete) | Textual Entailment | `(sentenceA, sentenceB)` | 183 |
| [SEACrowd/indolem_ntp](https://huggingface.co/datasets/SEACrowd/indolem_ntp) | Textual Entailment | `(tweet, next tweet)` | 5,681 |
| [khalidalt/tydiqa-goldp](https://huggingface.co/datasets/khalidalt/tydiqa-goldp) | Extractive Question-Answering | `(question, passage)`, `(question, answer)` | 11,404 |
| [SEACrowd/facqa](https://huggingface.co/datasets/SEACrowd/facqa) | Extractive Question-Answering | `(question, passage)`, `(question, answer)` | 4,990 |
| *included in v2* |
| [indonesian-nlp/lfqa_id](https://huggingface.co/datasets/indonesian-nlp/lfqa_id) | Open-domain Question-Answering | `(question, answer)` | 226,147 |
| [jakartaresearch/indoqa](https://huggingface.co/datasets/jakartaresearch/indoqa) | Extractive Question-Answering | `(question, passage)`, `(question, answer)` | 6,498 |
| [jakartaresearch/id-paraphrase-detection](https://huggingface.co/datasets/jakartaresearch/id-paraphrase-detection) | Paraphrase | `(sentence, rephrased sentence)` | 4,076 |
| **Total** | | | **371,979** |
| Dataset | Task | Data Instance | Number of Training Tuples |
| -------------------------------------------------------------------------------------------------------------------------- | :----------------------------: | :-------------------------------------------: | :-----------------------: |
| [indonli](https://huggingface.co/datasets/indonli) | Natural Language Inference | `(premise, entailment, contradiction)` | 3,914 |
| [indolem/indo_story_cloze](https://huggingface.co/datasets/indolem/indo_story_cloze) | Commonsense Reasoning | `(context, correct ending, incorrect ending)` | 1,000 |
| [unicamp-dl/mmarco](https://huggingface.co/datasets/unicamp-dl/mmarco) | Passage Retrieval | `(query, positive passage, negative passage)` | 100,000 |
| [miracl/miracl](https://huggingface.co/datasets/miracl/miracl) | Passage Retrieval | `(query, positive passage, negative passage)` | 8,086 |
| [SEACrowd/wrete](https://huggingface.co/datasets/SEACrowd/wrete) | Textual Entailment | `(sentenceA, sentenceB)` | 183 |
| [SEACrowd/indolem_ntp](https://huggingface.co/datasets/SEACrowd/indolem_ntp) | Textual Entailment | `(tweet, next tweet)` | 5,681 |
| [khalidalt/tydiqa-goldp](https://huggingface.co/datasets/khalidalt/tydiqa-goldp) | Extractive Question-Answering | `(question, passage)`, `(question, answer)` | 11,404 |
| [SEACrowd/facqa](https://huggingface.co/datasets/SEACrowd/facqa) | Extractive Question-Answering | `(question, passage)`, `(question, answer)` | 4,990 |
| *included in v2* |
| [indonesian-nlp/lfqa_id](https://huggingface.co/datasets/indonesian-nlp/lfqa_id) | Open-domain Question-Answering | `(question, answer)` | 226,147 |
| [jakartaresearch/indoqa](https://huggingface.co/datasets/jakartaresearch/indoqa) | Extractive Question-Answering | `(question, passage)`, `(question, answer)` | 6,498 |
| [jakartaresearch/id-paraphrase-detection](https://huggingface.co/datasets/jakartaresearch/id-paraphrase-detection) | Paraphrase | `(sentence, rephrased sentence)` | 4,076 |
| *included in v3* |
| [LazarusNLP/multilingual-NLI-26lang-2mil7-id](https://huggingface.co/datasets/LazarusNLP/multilingual-NLI-26lang-2mil7-id) | Natural Language Inference | `(premise, entailment, hypothesis)` | 41,924 |
| *included in v4* |
| [nthakur/swim-ir-monolingual](https://huggingface.co/datasets/nthakur/swim-ir-monolingual) | Passage Retrieval | `(query, positive passage, negative passage)` | 227,145 |
| **Total** | | | **641,048** |
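The "Data Instance" column shows the tuple each dataset contributes as a sentence-transformers `InputExample`. Purely as an illustration (the real conversion lives in `training/all/all_datasets.py`; the indonli column names and 0/1/2 label scheme below are assumptions), a `(premise, entailment, contradiction)` triplet could be assembled like this:

```python
from collections import defaultdict

from datasets import load_dataset
from sentence_transformers import InputExample

ds = load_dataset("indonli", split="train", trust_remote_code=True)

# Assumed label scheme: 0 = entailment, 1 = neutral, 2 = contradiction
hypotheses = defaultdict(dict)
for row in ds:
    hypotheses[row["premise"]][row["label"]] = row["hypothesis"]

# Keep only premises that have both an entailment and a contradiction
triplets = [
    InputExample(texts=[premise, hyps[0], hyps[2]])
    for premise, hyps in hypotheses.items()
    if 0 in hyps and 2 in hyps
]
```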

## All Supervised Datasets with MultipleNegativesRankingLoss

@@ -46,6 +50,21 @@ python train_all_mnrl.py \
--learning-rate 2e-5
```
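As context for the command above, here is a minimal sketch of the training loop it drives; the model name is taken from the command, while the example pairs are made up:

```python
from torch.utils.data import DataLoader
from sentence_transformers import InputExample, SentenceTransformer, losses

model = SentenceTransformer("indobenchmark/indobert-base-p1")

# (anchor, positive) pairs: every other positive in a batch acts as an
# in-batch negative, which is why large batch sizes help this loss.
pairs = [
    InputExample(texts=["ibu kota Indonesia", "Jakarta adalah ibu kota Indonesia."]),
    InputExample(texts=["gunung tertinggi di dunia", "Everest adalah gunung tertinggi di dunia."]),
]
# Triplets with an explicit hard negative use the same loss, but pairs and triplets
# are batched separately -- hence --train-batch-size-pairs vs --train-batch-size-triplets.
loader = DataLoader(pairs, shuffle=True, batch_size=2)
loss = losses.MultipleNegativesRankingLoss(model)
model.fit(train_objectives=[(loader, loss)], epochs=1, show_progress_bar=False)
```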

## All Supervised Datasets with CachedMultipleNegativesRankingLoss

### IndoBERT Base

```sh
python train_all_cached_mnrl.py \
--model-name indobenchmark/indobert-base-p1 \
--max-seq-length 128 \
--num-epochs 5 \
--train-batch-size-pairs 384 \
--train-batch-size-triplets 256 \
--mini-batch-size 320 \
--learning-rate 2e-5
```
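The cached variant targets memory, not quality: MultipleNegativesRankingLoss benefits from a large pool of in-batch negatives, but encoding a 384-pair batch in one forward pass may not fit on a single GPU. A minimal sketch, using the same sentence-transformers API as the script this commit adds:

```python
from sentence_transformers import SentenceTransformer, losses

model = SentenceTransformer("indobenchmark/indobert-base-p1")

# Embeddings are computed in chunks of `mini_batch_size` with gradient caching
# (GradCache), so the in-batch-negative pool stays at the DataLoader's full batch size.
train_loss = losses.CachedMultipleNegativesRankingLoss(model, mini_batch_size=320)
```

Lowering `--mini-batch-size` trades throughput for peak memory; the computed gradients match full-batch training, so the logical batch size is unchanged.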

## References

```bibtex
3 changes: 2 additions & 1 deletion requirements.txt
@@ -4,4 +4,5 @@ git+https://github.com/w11wo/SCT.git
git+https://github.com/embeddings-benchmark/mteb.git
datasets
scikit-learn
datargs
datargs
nusacrowd
21 changes: 19 additions & 2 deletions training/all/README.md
@@ -19,8 +19,10 @@ Inspired by [all-MiniLM-L6-v2](https://huggingface.co/sentence-transformers/all-
| [jakartaresearch/indoqa](https://huggingface.co/datasets/jakartaresearch/indoqa) | Extractive Question-Answering | `(question, passage)`, `(question, answer)` | 6,498 |
| [jakartaresearch/id-paraphrase-detection](https://huggingface.co/datasets/jakartaresearch/id-paraphrase-detection) | Paraphrase | `(sentence, rephrased sentence)` | 4,076 |
| *included in v3* |
| [LazarusNLP/multilingual-NLI-26lang-2mil7-id](https://huggingface.co/datasets/LazarusNLP/multilingual-NLI-26lang-2mil7-id) | Natural Language Inference | `(premise, entailement hypothesis)` | 41,924 |
| **Total** | | | **413,903** |
| [LazarusNLP/multilingual-NLI-26lang-2mil7-id](https://huggingface.co/datasets/LazarusNLP/multilingual-NLI-26lang-2mil7-id) | Natural Language Inference | `(premise, entailment, hypothesis)` | 41,924 |
| *included in v4* |
| [nthakur/swim-ir-monolingual](https://huggingface.co/datasets/nthakur/swim-ir-monolingual) | Passage Retrieval | `(query, positive passage, negative passage)` | 227,145 |
| **Total** | | | **641,048** |

## All Supervised Datasets with MultipleNegativesRankingLoss

@@ -48,6 +50,21 @@ python train_all_mnrl.py \
--learning-rate 2e-5
```

## All Supervised Datasets with CachedMultipleNegativesRankingLoss

### IndoBERT Base

```sh
python train_all_cached_mnrl.py \
--model-name indobenchmark/indobert-base-p1 \
--max-seq-length 128 \
--num-epochs 5 \
--train-batch-size-pairs 384 \
--train-batch-size-triplets 256 \
--mini-batch-size 320 \
--learning-rate 2e-5
```

## References

```bibtex
36 changes: 36 additions & 0 deletions training/all/all_datasets.py
@@ -196,6 +196,42 @@ def train_samples() -> List[InputExample]:
        return train_samples


@dataclass
class SwimIR:
    dataset = load_dataset("nthakur/swim-ir-monolingual", "id", split="train")

    @staticmethod
    def train_samples() -> List[InputExample]:
        train_data = {}
        train_samples = []

        for datum in SwimIR.dataset:
            query = datum["query"].strip()
            answer = datum["text"].strip()
            title = datum["title"].strip()

            if title not in train_data:
                train_data[title] = {query: [answer]}
            elif title in train_data and query not in train_data[title]:
                train_data[title][query] = [answer]
            else:
                train_data[title][query].append(answer)

        for title, queries in train_data.items():
            passage_queries = list(queries.keys())
            # cannot get a negative sample if only 1 query in that passage
            if len(passage_queries) > 1:
                for query, answers in queries.items():
                    positive = random.choice(answers)
                    # get a random negative sample from a different query
                    random_query = random.choice([q for q in passage_queries if q != query])
                    negative = random.choice(queries[random_query])

                    train_samples.append(InputExample(texts=[query, positive, negative]))

        return train_samples


@dataclass
class IndoStoryCloze:
    dataset = load_dataset("indolem/indo_story_cloze", split="train", trust_remote_code=True)
127 changes: 127 additions & 0 deletions training/all/train_all_cached_mnrl.py
@@ -0,0 +1,127 @@
from dataclasses import dataclass
import math

from datargs import parse
from datasets import load_dataset
from sentence_transformers import SentenceTransformer, InputExample, models, losses
from sentence_transformers.evaluation import EmbeddingSimilarityEvaluator

from all_datasets import (
    IndoNLI,
    IndoStoryCloze,
    mMARCO,
    MIRACL,
    SwimIR,
    MultilingualNLI,
    WReTE,
    IndoLEMNTP,
    TyDiQA,
    FacQA,
    LFQAID,
    IndoQA,
    ParaphraseDetection,
)
from MultiDatasetDataLoader import MultiDatasetDataLoader


@dataclass
class Args:
    # data args
    model_name: str = "indobenchmark/indobert-base-p1"
    # train
    max_seq_length: int = 128
    # test
    test_dataset_name: str = "LazarusNLP/stsb_mt_id"
    test_dataset_split: str = "validation"
    test_text_column_1: str = "text_1"
    test_text_column_2: str = "text_2"
    test_label_column: str = "correlation"
    # training args
    num_epochs: int = 5
    train_batch_size_pairs: int = 384
    train_batch_size_triplets: int = 256
    test_batch_size: int = 32
    mini_batch_size: int = 128
    learning_rate: float = 2e-5
    warmup_ratio: float = 0.1
    output_path: str = "exp/all-indobert-base"
    use_amp: bool = True
    # huggingface hub args
    hub_model_id: str = "LazarusNLP/all-indobert-base"
    hub_private_repo: bool = True


def main(args: Args):
    # Load datasets
    raw_datasets = {
        "indonli": IndoNLI,
        "indolem/indo_story_cloze": IndoStoryCloze,
        "unicamp-dl/mmarco": mMARCO,
        "miracl/miracl": MIRACL,
        "nthakur/swim-ir-monolingual": SwimIR,
        "LazarusNLP/multilingual-NLI-26lang-2mil7-id": MultilingualNLI,
        "SEACrowd/wrete": WReTE,
        "SEACrowd/indolem_ntp": IndoLEMNTP,
        "khalidalt/tydiqa-goldp": TyDiQA,
        "SEACrowd/facqa": FacQA,
        "indonesian-nlp/lfqa_id": LFQAID,
        "jakartaresearch/indoqa": IndoQA,
        "jakartaresearch/id-paraphrase-detection": ParaphraseDetection,
    }

    train_ds = [ds.train_samples() for ds in raw_datasets.values()]
    test_ds = load_dataset(args.test_dataset_name, split=args.test_dataset_split)

    # Initialize model with mean pooling
    word_embedding_model = models.Transformer(args.model_name, max_seq_length=args.max_seq_length)
    dimension = word_embedding_model.get_word_embedding_dimension()
    pooling_model = models.Pooling(dimension, pooling_mode="mean")
    model = SentenceTransformer(modules=[word_embedding_model, pooling_model])

    # DataLoader to batch the data
    train_dataloader = MultiDatasetDataLoader(
        train_ds, batch_size_pairs=args.train_batch_size_pairs, batch_size_triplets=args.train_batch_size_triplets
    )

    warmup_steps = math.ceil(
        len(train_dataloader) * args.num_epochs * args.warmup_ratio
    )  # 10% of train data for warm-up

    # Set up test data for evaluation
    test_data = [
        InputExample(
            texts=[data[args.test_text_column_1], data[args.test_text_column_2]],
            label=float(data[args.test_label_column]) / 5.0,
        )
        for data in test_ds
    ]

    evaluator = EmbeddingSimilarityEvaluator.from_input_examples(test_data, batch_size=args.test_batch_size)

    # Use the cached (GradCache) multiple negatives ranking loss
    train_loss = losses.CachedMultipleNegativesRankingLoss(model, mini_batch_size=args.mini_batch_size)

    # Call the fit method
    model.fit(
        train_objectives=[(train_dataloader, train_loss)],
        evaluator=evaluator,
        epochs=args.num_epochs,
        warmup_steps=warmup_steps,
        show_progress_bar=True,
        optimizer_params={"lr": args.learning_rate, "eps": 1e-6},
        output_path=args.output_path,
        save_best_model=True,
        use_amp=args.use_amp,
    )

    # Save model to HuggingFace Hub
    model.save_to_hub(
        args.hub_model_id,
        private=args.hub_private_repo,
        train_datasets=list(raw_datasets.keys()),
    )


if __name__ == "__main__":
    args = parse(Args)
    main(args)
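For reference, the kebab-case flags in the README commands are generated by datargs: each field of `Args` becomes a flag derived from its name, with the annotated type and default. A tiny sketch with a hypothetical `MiniArgs` dataclass:

```python
from dataclasses import dataclass

from datargs import parse


@dataclass
class MiniArgs:
    # hypothetical example field; exposed on the CLI as --mini-batch-size
    mini_batch_size: int = 128


# e.g. `python demo.py --mini-batch-size 320`
args = parse(MiniArgs)
print(args.mini_batch_size)
```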
2 changes: 2 additions & 0 deletions training/all/train_all_mnrl.py
@@ -11,6 +11,7 @@
    IndoStoryCloze,
    mMARCO,
    MIRACL,
    SwimIR,
    MultilingualNLI,
    WReTE,
    IndoLEMNTP,
@@ -56,6 +57,7 @@ def main(args: Args):
        "indolem/indo_story_cloze": IndoStoryCloze,
        "unicamp-dl/mmarco": mMARCO,
        "miracl/miracl": MIRACL,
        "nthakur/swim-ir-monolingual": SwimIR,
        "LazarusNLP/multilingual-NLI-26lang-2mil7-id": MultilingualNLI,
        "SEACrowd/wrete": WReTE,
        "SEACrowd/indolem_ntp": IndoLEMNTP,
