diff --git a/DiarizationLM/unsloth/1_finetune.py b/DiarizationLM/unsloth/1_finetune.py index 03edff7..74a3531 100644 --- a/DiarizationLM/unsloth/1_finetune.py +++ b/DiarizationLM/unsloth/1_finetune.py @@ -8,7 +8,7 @@ import dataprep import torch from transformers import TrainingArguments -from trl import SFTTrainer +from trl import SFTTrainer, DataCollatorForCompletionOnlyLM from unsloth import FastLanguageModel from unsloth import is_bfloat16_supported @@ -53,6 +53,10 @@ def run_training() -> None: ############################################################################ # Train the model ############################################################################ + collator = DataCollatorForCompletionOnlyLM( + config.PROMPT_SUFFIX.rstrip(), + tokenizer=tokenizer) + trainer = SFTTrainer( model=model, tokenizer=tokenizer, @@ -61,6 +65,7 @@ def run_training() -> None: max_seq_length=config.MAX_SEQ_LENGTH, dataset_num_proc=2, packing=False, + data_collator=collator, args=TrainingArguments( per_device_train_batch_size=16, gradient_accumulation_steps=1,