Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

sasarkar/qwen finetuning bucketing #1130

Open
wants to merge 12 commits into
base: main
Choose a base branch
from
Open
2 changes: 2 additions & 0 deletions examples/trl/sft.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ class ScriptArguments:
streaming: Optional[bool] = field(default=True, metadata={"help": "whether to stream the dataset"})
shuffle_buffer: Optional[int] = field(default=5000, metadata={"help": "the shuffle buffer size"})
num_workers: Optional[int] = field(default=4, metadata={"help": "the number of workers"})
num_buckets: Optional[int] = field(default=-1, metadata={"help": "whether to use bucketing for SFTTrainer"})
validation_split_percentage: Optional[int] = field(
default=5,
metadata={
Expand Down Expand Up @@ -188,6 +189,7 @@ def create_datasets(tokenizer, args, seed=None):
tokenizer=tokenizer,
args=training_args,
formatting_func=formatting_func,
num_buckets=script_args.num_buckets,
)
train_result = trainer.train()
trainer.save_model(training_args.output_dir)
Expand Down
67 changes: 66 additions & 1 deletion optimum/habana/trl/trainer/sft_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,10 @@
import dataclasses
import inspect
import warnings
from typing import Callable, Dict, List, Optional, Tuple, Union
from collections.abc import Mapping
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import numpy as np
import torch
import torch.nn as nn
from accelerate import PartialState
Expand All @@ -28,6 +30,7 @@
PreTrainedModel,
PreTrainedTokenizerBase,
)
from transformers.data.data_collator import pad_without_fast_tokenizer_warning
from transformers.trainer_callback import TrainerCallback
from transformers.trainer_utils import EvalPrediction
from trl import SFTTrainer
Expand All @@ -46,6 +49,46 @@
from .sft_config import GaudiSFTConfig


class BucketedDataCollatorForLanguageModeling(DataCollatorForLanguageModeling):
    """Language-modeling collator that pads every batch up to a bucket boundary.

    The caller must assign ``buckets`` after construction: a 1-D NumPy array of
    sequence lengths sorted in ascending order. Each batch is padded to the
    smallest bucket that fits its longest example, so the number of distinct
    padded lengths stays bounded by the number of buckets.
    """

    def _get_bucketed_len(self, examples):
        """Return the padded length to use for ``examples``.

        Args:
            examples: Sequence of mappings, each holding an ``"input_ids"`` entry.

        Returns:
            The smallest bucket that is >= the longest ``input_ids`` in the batch.
            If the longest example exceeds every known bucket, its length is
            appended to ``self.buckets`` as a new largest bucket and returned.
        """
        max_sentence_len = max(len(k["input_ids"]) for k in examples)
        if max_sentence_len > self.buckets[-1]:
            # Longer than anything seen so far: grow the bucket list so later
            # batches of similar length reuse this shape.
            self.buckets = np.append(self.buckets, max_sentence_len)
            return max_sentence_len
        # ``buckets`` is ascending, so the left insertion point is the index of the
        # first bucket >= max_sentence_len. (The previous
        # ``np.argmin(np.where(...))`` always evaluated to 0, i.e. always chose the
        # smallest bucket.)
        bucket_idx = np.searchsorted(self.buckets, max_sentence_len, side="left")
        return int(self.buckets[bucket_idx])

    # copied from https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/data/data_collator.py#L758
    # change is pad_to_multiple_of=self.pad_to_multiple_of -> pad_to_multiple_of=bucketed_len
    def torch_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
        """Pad ``examples`` to their bucket length and attach labels.

        Args:
            examples: Batch of tokenized examples; only mapping-style examples
                are supported.

        Returns:
            The padded batch with a ``"labels"`` entry added.

        Raises:
            NotImplementedError: If the examples are not mappings.
        """
        # Handle dict or lists with proper padding and conversion to tensor.
        if not isinstance(examples[0], Mapping):
            # https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/data/data_collator.py#L765
            # Raise explicitly: an ``assert`` would be stripped under ``python -O``.
            raise NotImplementedError("This path has not been implemented/tested yet")

        bucketed_len = self._get_bucketed_len(examples)
        # The bucket is >= the longest example, so padding to a multiple of it
        # pads the batch to exactly ``bucketed_len``.
        batch = pad_without_fast_tokenizer_warning(
            self.tokenizer,
            examples,
            return_tensors="pt",
            pad_to_multiple_of=bucketed_len,  # self.pad_to_multiple_of
        )

        # If special token mask has been preprocessed, pop it from the dict.
        special_tokens_mask = batch.pop("special_tokens_mask", None)
        if self.mlm:
            batch["input_ids"], batch["labels"] = self.torch_mask_tokens(
                batch["input_ids"], special_tokens_mask=special_tokens_mask
            )
        else:
            labels = batch["input_ids"].clone()
            if self.tokenizer.pad_token_id is not None:
                # Replace pad-token positions with -100 in the labels.
                labels[labels == self.tokenizer.pad_token_id] = -100
            batch["labels"] = labels
        return batch


class GaudiSFTTrainer(SFTTrainer, GaudiTrainer):
def __init__(
self,
Expand Down Expand Up @@ -75,14 +118,20 @@ def __init__(
model_init_kwargs: Optional[Dict] = None,
dataset_kwargs: Optional[Dict] = None,
eval_packing: Optional[bool] = None,
num_buckets: Optional[int] = -1,
):
"""
Copied from SFTTrainer.__init__: https://github.com/huggingface/trl/blob/v0.9.6/trl/trainer/sft_trainer.py#L116
The only differences are:
- add new args gaudi_config
- use GaudiTrainer instead of Trainer
- cast peft model to bf16.
- add new arg num_buckets: number of buckets; > 0 enables bucketing, <= 0 disables it
"""
if num_buckets > 0:
assert (
data_collator is None
), "For bucketing (num_buckets > 0), we only support data_collator=None (later it becomes DataCollatorForLanguageModeling)"
if args is None:
output_dir = "tmp_trainer"
warnings.warn(f"No `SFTConfig` passed, using `output_dir={output_dir}`.")
Expand Down Expand Up @@ -380,8 +429,24 @@ def make_inputs_require_grad(module, input, output):
elif self.args.max_steps == -1 and args.packing:
self.train_dataset.infinite = False

if num_buckets > 0:
train_dataloader = self.get_train_dataloader()
batched_sentence_lengths = [batch["input_ids"].shape[1] for batch in train_dataloader]
buckets = self._get_buckets(batched_sentence_lengths, num_buckets=num_buckets)
self.data_collator = BucketedDataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
self.data_collator.buckets = buckets

if any(isinstance(callback, RichProgressCallback) for callback in self.callback_handler.callbacks):
for callback in self.callback_handler.callbacks:
# Remove the PrinterCallback to avoid duplicated prints in case we passed a `RichProgressCallback`
if callback.__class__.__name__ == "PrinterCallback":
self.callback_handler.pop_callback(callback)

def _get_buckets(self, sentence_lengths, num_buckets):
    """Derive bucket boundaries from the observed sequence lengths.

    The lengths are cut at ``num_buckets`` evenly spaced percentiles (the 0th
    percentile is dropped), always picking an actually observed length
    ("lower" percentile method). Duplicate boundaries are collapsed, so fewer
    than ``num_buckets`` values may be returned.

    Args:
        sentence_lengths: Sequence of observed sequence lengths.
        num_buckets: Number of percentile cut points to request.

    Returns:
        Sorted 1-D NumPy array of unique bucket boundaries.
    """
    percentiles = np.linspace(0, 100, num_buckets + 1)
    try:
        # ``method`` replaced the deprecated ``interpolation`` keyword in NumPy 1.22.
        boundaries = np.percentile(sentence_lengths, percentiles, method="lower")
    except TypeError:
        # Older NumPy only knows the keyword under its previous name.
        boundaries = np.percentile(sentence_lengths, percentiles, interpolation="lower")
    return np.unique(boundaries[1:])
131 changes: 131 additions & 0 deletions tests/test_sft_train.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
import os
import subprocess
import tempfile
from pathlib import Path

import pytest


@pytest.mark.parametrize("model_name,expected", [("Qwen/Qwen2-7B", (30.12, 4.8347)), ("Qwen/Qwen2-72B", (6.969, 3.6))])
def test_sft_train(model_name, expected):
    """Run the TRL SFT example with bucketing and check throughput and perplexity.

    The example script is launched through ``gaudi_spawn.py`` on 8 devices with
    DeepSpeed. ``expected`` is ``(train_samples_per_second, perplexity)``; the run
    must reach at least 90% of the expected throughput and stay below 105% of the
    expected perplexity.
    """
    ds_config = """{
    "train_batch_size": "auto",
    "train_micro_batch_size_per_gpu" :"auto",
    "gradient_accumulation_steps": "auto",
    "gradient_clipping": "auto",
    "bf16": {
        "enabled": "auto"
    },
    "zero_optimization": {
        "stage": 3,
        "overlap_comm": true,
        "stage3_gather_16bit_weights_on_model_save": true
    },
    "flops_profiler": {
        "enabled": false,
        "profile_step": 1,
        "module_depth": -1,
        "top_modules": 1,
        "detailed": true,
        "output_file": null
    }
}
"""
    env_variables = os.environ.copy()
    path_to_example_dir = Path(__file__).resolve().parent.parent / "examples"
    filename = f"{path_to_example_dir / 'trl' / 'sft.py'}"
    gaudispawn_filename = f"{path_to_example_dir / 'gaudi_spawn.py'}"

    command = [
        "python3",
        gaudispawn_filename,
        "--world_size",
        "8",
        "--use_deepspeed",
        filename,
        "--model_name_or_path",
        model_name,
        "--dataset_name",
        "philschmid/dolly-15k-oai-style",
        "--streaming",
        "False",
        "--bf16",
        "True",
        "--output_dir",
        "./model_qwen",
        "--num_train_epochs",
        "1",
        "--per_device_train_batch_size",
        "8",
        "--evaluation_strategy",
        "no",
        "--save_strategy",
        "no",
        "--learning_rate",
        "3e-4",
        "--warmup_ratio",
        "0.03",
        "--lr_scheduler_type",
        "cosine",
        "--max_grad_norm",
        "0.3",
        "--logging_steps",
        "1",
        "--do_train",
        "--do_eval",
        "--use_habana",
        "--use_lazy_mode",
        "--throughput_warmup_steps",
        "3",
        "--lora_r",
        "4",
        "--lora_alpha=16",
        "--lora_dropout=0.05",
        "--lora_target_modules",
        "q_proj",
        "v_proj",
        "k_proj",
        "o_proj",
        "--max_seq_length",
        "512",
        "--adam_epsilon",
        "1e-08",
        "--packing",
        "False",
        # Spell the flag exactly as the script declares it (`num_buckets`) rather
        # than relying on argparse prefix abbreviation of `--num_bucket`.
        "--num_buckets",
        "8",
        "--subset",
        "''",
    ]
    is_72b = "72" in model_name
    if is_72b:
        command += [
            "--max_steps",
            "50",
            "--gradient_checkpointing",
            "True",
            "--pipelining_fwd_bwd",
            "True",
        ]
    else:
        command += ["--max_steps", "100"]
    env_variables["DEEPSPEED_HPU_ZERO3_SYNC_MARK_STEP_REQUIRED"] = "1"

    # The temporary DeepSpeed config must outlive the subprocess, so the run
    # happens inside the ``with`` block.
    with tempfile.NamedTemporaryFile() as fp:
        fp.write(str.encode(ds_config))
        fp.flush()
        if is_72b:
            # Only the 72B run is given the ZeRO-3 config file.
            command += ["--deepspeed", fp.name]
        proc = subprocess.run(
            command, env=env_variables, stdout=subprocess.PIPE, stderr=subprocess.PIPE, universal_newlines=True
        )

    assert proc.returncode == 0, f"Got these from process: stderr={proc.stderr}, stdout={proc.stdout}"
    alllines = proc.stdout.split("\n")

    def _last_metric(metric_name):
        # Metrics are read from the last stdout line mentioning them, taking the
        # text after the final "=".
        matching = [line for line in alllines if metric_name in line]
        assert matching, f"No '{metric_name}' line found in stdout: {proc.stdout}"
        return float(matching[-1].split("=")[-1])

    train_samples_per_second = _last_metric("train_samples_per_second")
    perplexity = _last_metric("perplexity")
    train_samples_per_second_expected, perplexity_expected = expected
    assert (
        train_samples_per_second > 0.9 * train_samples_per_second_expected
    ), f"Got {train_samples_per_second}, expected 0.9*{train_samples_per_second_expected}"
    assert perplexity < 1.05 * perplexity_expected, f"Got {perplexity}, expected 1.05*{perplexity_expected}"
Loading