Commit

work
kylematoba committed Oct 9, 2024
1 parent 7b7ead9 commit 6216fab
Showing 3 changed files with 128 additions and 14 deletions.
21 changes: 18 additions & 3 deletions run_train.py
@@ -7,10 +7,12 @@
torchrun --nproc_per_node=8 run_train.py --config-file examples/config_tiny_llama.yaml
```
"""
import time
import argparse
from typing import Dict, cast

import numpy as np

from nanotron import logging
from nanotron.config import DataArgs, DatasetStageArgs, NanosetDatasetsArgs, PretrainDatasetsArgs
from nanotron.data.dataloader_builder import build_nanoset_dataloader
@@ -27,12 +29,13 @@
from nanotron.logging import log_rank
from nanotron.parallel.pipeline_parallel.utils import get_input_output_pp_ranks
from nanotron.trainer import DistributedTrainer
import nanotron.trainer
from nanotron.utils import main_rank_first
from torch.utils.data import DataLoader

try:
from huggingface_hub import __version__ as hf_hub_version
from transformers import AutoTokenizer
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers import __version__ as tf_version
except ImportError:
hf_hub_version = None
@@ -60,6 +63,10 @@ def get_dataloader_from_data_stage(
# First, we need to know which ranks to feed the dataloader to
input_pp_rank, output_pp_rank = get_input_output_pp_ranks(model=trainer.model)

print("--" * 40)
print(data.dataset)
print(type(data.dataset))
print("--" * 40)
# Case 1: Dummy data generator
if data.dataset is None:
log_rank("Using dummy data generator", logger=logger, level=logging.INFO, rank=0)
@@ -142,6 +149,13 @@ def get_dataloader_from_data_stage(
# Case 3: Nanosets
elif isinstance(data.dataset, NanosetDatasetsArgs):
# Get tokenizer cardinality
# sleep_seconds = 600
# print(f"Sleeping for {sleep_seconds} seconds")
# time.sleep(sleep_seconds)

print(trainer.config.tokenizer.tokenizer_name_or_path)
# model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-Nemo-Base-2407")
# del model
tokenizer = AutoTokenizer.from_pretrained(trainer.config.tokenizer.tokenizer_name_or_path)
token_size = 4 if len(tokenizer) > np.iinfo(np.uint16).max + 1 else 2
del tokenizer
@@ -233,5 +247,6 @@ def get_args():
trainer = DistributedTrainer(config_file)
dataloader = get_dataloader(trainer)

# Train
trainer.train(dataloader)
config = nanotron.trainer.get_config_from_file(config_file)
trainer.train(dataloader, validation_args=config.validation)

11 changes: 11 additions & 0 deletions src/nanotron/config/config.py
@@ -45,6 +45,7 @@ class LoggingArgs:
log_level: Optional[str] = None
log_level_replica: Optional[str] = None
iteration_step_info_interval: Optional[int] = 1
iteration_step_validation_interval: Optional[int] = None

def __post_init__(self):
if self.log_level is None:
@@ -230,6 +231,14 @@ class TokenizerArgs:
tokenizer_max_length: Optional[int] = None


@dataclass
class ValidationArgs:
"""Arguments related to the validation"""
# datasets: Optional[Union[DataArgs, List[DataArgs]]] = None
# datasets: Optional[Union[DataArgs, list[DataArgs]]] = None
datasets: Optional[Union[DataArgs, list]] = None


@dataclass
class TokensArgs:
"""Arguments related to the tokens, sequence, batch and steps of the training"""
@@ -340,6 +349,8 @@ class Config:
data_stages: Optional[List[DatasetStageArgs]] = None
profiler: Optional[ProfilerArgs] = None
lighteval: Optional[LightEvalConfig] = None
validation: Optional[ValidationArgs] = None


@classmethod
def create_empty(cls):
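For reference (not part of the commit), a minimal sketch of how the new validation knobs could be populated in code, assuming ValidationArgs.datasets holds a list of DataArgs entries as the commented-out annotations suggest; the dataset path, seed, and interval below are placeholders.

```python
# Illustrative only: wiring the new validation fields from this commit.
# Assumes ValidationArgs.datasets is a list of DataArgs, per the commented-out
# annotations above; the WIP loop in trainer.py still re-wraps each entry in
# NanosetDatasetsArgs, so the final shape may change.
from nanotron.config import DataArgs, NanosetDatasetsArgs
from nanotron.config.config import LoggingArgs, ValidationArgs

logging_args = LoggingArgs(
    iteration_step_info_interval=1,
    iteration_step_validation_interval=100,  # run validation every 100 training steps
)
validation_args = ValidationArgs(
    datasets=[
        DataArgs(
            dataset=NanosetDatasetsArgs(dataset_folder="/path/to/validation/nanoset"),
            seed=1234,
            num_loading_workers=1,
        )
    ]
)
```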
110 changes: 99 additions & 11 deletions src/nanotron/trainer.py
@@ -19,10 +19,16 @@
cast,
)

import numpy as np
import torch
from torch.nn.parallel import DistributedDataParallel
from torch.utils.data import DataLoader

from transformers import AutoTokenizer


from nanotron.parallel.pipeline_parallel.utils import get_input_output_pp_ranks
from nanotron.config import DataArgs, DatasetStageArgs, NanosetDatasetsArgs, PretrainDatasetsArgs
from nanotron import distributed as dist
from nanotron import logging
from nanotron.config import (
@@ -54,6 +60,7 @@
log_rank,
set_ranks_logging_level,
)
from nanotron.data.dataloader_builder import build_nanoset_dataloader
from nanotron.models import NanotronModel, build_model
from nanotron.models.base import check_model_has_grad
from nanotron.models.llama import LlamaForTraining, RotaryEmbedding
@@ -91,6 +98,7 @@
save,
save_random_states,
)
from nanotron.utils import main_rank_first
from nanotron.serialize.metadata import DataStageMetadata, TrainingMetadata
from nanotron.serialize.optimizer import load_optimizer

@@ -111,6 +119,47 @@
wandb = None



def _quick_dataloader_builder(data, trainer):
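# Builds a validation dataloader for one Nanoset dataset config, mirroring the
# Nanoset branch of get_dataloader_from_data_stage in run_train.py; consumed
# samples start at 0 for every validation pass.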
print(data)
print(type(data))
# assert isinstance(data.dataset, NanosetDatasetsArgs), f"only supporting nanosets for now data = {data}"
consumed_validation_samples = 0

# First, we need to know which ranks to feed the dataloader to
input_pp_rank, output_pp_rank = get_input_output_pp_ranks(model=trainer.model)

tokenizer = AutoTokenizer.from_pretrained(trainer.config.tokenizer.tokenizer_name_or_path)

token_size = 4 if len(tokenizer) > np.iinfo(np.uint16).max + 1 else 2
del tokenizer

from nanotron.data.nanoset import Nanoset

with main_rank_first(trainer.parallel_context.world_pg):
train_dataset = Nanoset(
dataset_folders=data.dataset.dataset_folder,
dataset_weights=data.dataset.dataset_weights,
sequence_length=trainer.sequence_length,
token_size=token_size,
train_split_num_samples=trainer.config.tokens.train_steps * trainer.global_batch_size,
random_seed=data.seed,
)

train_dataloader = build_nanoset_dataloader(
train_dataset,
trainer.sequence_length,
parallel_context=trainer.parallel_context,
input_pp_rank=input_pp_rank,
output_pp_rank=output_pp_rank,
micro_batch_size=trainer.micro_batch_size,
consumed_train_samples=consumed_validation_samples,
dataloader_num_workers=data.num_loading_workers,
dataloader_drop_last=True,
)
return train_dataloader


class DistributedTrainer:
def __init__(
self,
@@ -397,13 +446,13 @@ def train(
],
**kwargs,
) -> None:

self.pre_training(**kwargs)

if self.config.checkpoints.save_initial_state and self.init_checkpoint_path is None:
self.save_checkpoint()

self.pipeline_engine: PipelineEngine = self.config.parallelism.pp_engine

self.pipeline_engine.nb_microbatches = self.n_micro_batches_per_batch

# TODO @nouamanetazi: refactor this
@@ -439,6 +488,25 @@ def train(
if (self.iteration_step - 1) % self.config.logging.iteration_step_info_interval == 0:
self.train_step_logs(outputs=outputs, loss_avg=loss_avg)

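# Periodic validation: every iteration_step_validation_interval steps, build a
# dataloader for each configured validation dataset and run validation_step on it.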
if (self.config.logging.iteration_step_validation_interval is not None) and \
((self.iteration_step - 1) % self.config.logging.iteration_step_validation_interval == 0):
# assert "validation_args" in kwargs
validation_args = kwargs["validation_args"]
print(f"validation_args = {validation_args}")
dataset_list = validation_args.datasets
print(f"validation -> iteration_step - 1 = {self.iteration_step - 1}")
print(f"dataset_list = {dataset_list}")
print(f"len(dataset_list) = {len(dataset_list)}")
print(f"dataset_list[0] = {dataset_list[0]}")
for vds in dataset_list:
print("****" * 20)
vds = NanosetDatasetsArgs(vds)
print(vds)
print("****" * 20)
dataloader = _quick_dataloader_builder(vds, self)
validation_output = self.validation_step(dataloader)
print(f"validation_output = {validation_output}")
print(f"validation_output[0] = {validation_output[0]}")
# Checkpoint
if self.iteration_step % self.config.checkpoints.checkpoint_interval == 0:
self.save_checkpoint()
@@ -550,6 +618,31 @@ def validation_step(self, dataloader: Iterator[Dict[str, Union[torch.Tensor, Ten
batch=(next(dataloader) for _ in range(self.limit_val_batches)),
nb_microbatches=self.limit_val_batches,
)
"""
# Compute DP average loss and overlap with optimizer step
if isinstance(outputs[0]["loss"], torch.Tensor):
# This is an average on only one data rank.
loss_avg = torch.stack(
[output["loss"] for output in outputs]
).sum() # already divided by n_micro_batches_per_batch
# sync loss across DP
handle = dist.all_reduce(loss_avg, group=self.parallel_context.dp_pg, async_op=True, op=dist.ReduceOp.AVG)
else:
loss_avg = None
handle = None
# Apply gradient
self.optimizer.step()
self.optimizer.zero_grad()
# Update the learning rate
self.lr_scheduler.step()
after_optim_step_sanity_checks(self.config, self.parallel_context, self.unwrapped_model, self.grad_accumulator)
if handle is not None:
handle.wait()
"""
return outputs

def train_step_logs(
@@ -602,18 +695,13 @@ def train_step_logs(
total, used, free = shutil.disk_usage("/")
log_entries.extend(
[
LogItem(
"cuda_memory_allocated", torch.cuda.memory_allocated(), "human_format"
), # / 1024**2, ".2f"),
LogItem(
"cuda_max_memory_reserved", torch.cuda.max_memory_reserved(), "human_format"
), # / 1024**2, ".2f"),
LogItem("hd_total_memory_tb", total, "human_format"), # / (2**40), ".2f"),
LogItem("hd_used_memory_tb", used, "human_format"), # / (2**40), ".2f"),
LogItem("hd_free_memory_tb", free, "human_format"), # / (2**40), ".2f"),
LogItem("cuda_memory_allocated", torch.cuda.memory_allocated(), "human_format"),
LogItem("cuda_max_memory_reserved", torch.cuda.max_memory_reserved(), "human_format"),
LogItem("hd_total_memory_tb", total, "human_format"),
LogItem("hd_used_memory_tb", used, "human_format"),
LogItem("hd_free_memory_tb", free, "human_format"),
]
)

# NOTE: only one rank writes to wandb
if dist.get_rank(self.parallel_context.world_pg) == self.logger_ranks[0] and wandb is not None:
wandb.log(
