Commit

Fixed Embeddings PyTorch Device
w11wo committed May 20, 2024
1 parent c4e7ae5 commit 25fcae2
Showing 2 changed files with 22 additions and 0 deletions.
19 changes: 19 additions & 0 deletions unsupervised_learning/ConGen/README.md
@@ -78,6 +78,25 @@ python train_con_gen.py \
--teacher-temp 0.5
```

## ConGen with Cohere Embeddings

### NusaBERT Base

```sh
python train_con_gen_cohere.py \
--model-name LazarusNLP/NusaBERT-base \
--train-dataset-name Cohere/wikipedia-2023-11-embed-multilingual-v3 \
--max-seq-length 128 \
--max-train-samples 1000000 \
--num-epochs 20 \
--train-batch-size 128 \
--early-stopping-patience 7 \
--learning-rate 1e-4 \
--queue-size 65536 \
--student-temp 0.5 \
--teacher-temp 0.5
```

## References

```bibtex
3 changes: 3 additions & 0 deletions unsupervised_learning/ConGen/train_con_gen_cohere.py
@@ -88,6 +88,9 @@ def main(args: Args):
    )
    model = SentenceTransformer(modules=[word_embedding_model, pooling_model, dense_model])

    # move encoded texts to the same device as model
    encoded_texts = encoded_texts.to(model._target_device)

    # create instance queue
    text_in_queue = np.random.RandomState(16349).choice(
        train_ds[args.train_text_column], args.queue_size, replace=False
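
The three added lines above move the precomputed embeddings onto the model's device before training. As a minimal sketch of that pattern in plain PyTorch (not the code from this commit): the small `torch.nn.Linear` layer and the random tensor below are hypothetical stand-ins for the `SentenceTransformer` model and `encoded_texts`, and the device is read from the module's parameters rather than from the private `_target_device` attribute the commit uses.

```python
import torch

# Hypothetical stand-ins: a tiny linear layer plays the role of the model,
# and a random tensor plays the role of the precomputed embeddings.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = torch.nn.Linear(4, 2).to(device)
encoded_texts = torch.randn(8, 4)  # tensors are created on the CPU by default

# Look up where the model's parameters live and move the tensor there.
# Tensor.to() returns a tensor on the target device instead of modifying
# the original in place, so the result has to be reassigned.
model_device = next(model.parameters()).device
encoded_texts = encoded_texts.to(model_device)

# Both operands now share a device, so the forward pass avoids a
# device-mismatch error.
output = model(encoded_texts)
```

On a CPU-only machine the `.to()` call leaves the tensor where it is, so the same code runs unchanged with or without a GPU.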