Skip to content

Commit

Permalink
IMPORTANT: train on completions only
Browse files (browse the repository at this point in the history)
  • Loading branch information
wq2012 committed Jul 27, 2024
1 parent c2427ea commit d341d61
Showing 1 changed file with 6 additions and 1 deletion.
7 changes: 6 additions & 1 deletion DiarizationLM/unsloth/1_finetune.py
Original file line number | Diff line number | Diff line change
@@ -8,7 +8,7 @@
import dataprep
import torch
from transformers import TrainingArguments
-from trl import SFTTrainer
+from trl import SFTTrainer, DataCollatorForCompletionOnlyLM
from unsloth import FastLanguageModel
from unsloth import is_bfloat16_supported

@@ -53,6 +53,10 @@ def run_training() -> None:
############################################################################
# Train the model
############################################################################
+collator = DataCollatorForCompletionOnlyLM(
+config.PROMPT_SUFFIX.rstrip(),
+tokenizer=tokenizer)

trainer = SFTTrainer(
model=model,
tokenizer=tokenizer,
@@ -61,6 +65,7 @@ def run_training() -> None:
max_seq_length=config.MAX_SEQ_LENGTH,
dataset_num_proc=2,
packing=False,
+data_collator=collator,
args=TrainingArguments(
per_device_train_batch_size=16,
gradient_accumulation_steps=1,
Expand Down

0 comments on commit d341d61

Please sign in to comment.