diff --git a/run_train.py b/run_train.py
index 021d955d..a1a60940 100644
--- a/run_train.py
+++ b/run_train.py
@@ -7,10 +7,12 @@
 torchrun --nproc_per_node=8 run_train.py --config-file examples/config_tiny_llama.yaml
 ```
 """
+import time
 import argparse
 from typing import Dict, cast
 
 import numpy as np
+
 from nanotron import logging
 from nanotron.config import DataArgs, DatasetStageArgs, NanosetDatasetsArgs, PretrainDatasetsArgs
 from nanotron.data.dataloader_builder import build_nanoset_dataloader
@@ -27,12 +29,13 @@
 from nanotron.logging import log_rank
 from nanotron.parallel.pipeline_parallel.utils import get_input_output_pp_ranks
 from nanotron.trainer import DistributedTrainer
+import nanotron.trainer
 from nanotron.utils import main_rank_first
 from torch.utils.data import DataLoader
 
 try:
     from huggingface_hub import __version__ as hf_hub_version
-    from transformers import AutoTokenizer
+    from transformers import AutoModelForCausalLM, AutoTokenizer
     from transformers import __version__ as tf_version
 except ImportError:
     hf_hub_version = None
@@ -60,6 +63,10 @@ def get_dataloader_from_data_stage(
     # First, we need to know which ranks to feed the dataloader to
     input_pp_rank, output_pp_rank = get_input_output_pp_ranks(model=trainer.model)
+    print("--" * 40)
+    print(data.dataset)
+    print(type(data.dataset))
+    print("--" * 40)
 
     # Case 1: Dummy data generator
     if data.dataset is None:
         log_rank("Using dummy data generator", logger=logger, level=logging.INFO, rank=0)
@@ -142,6 +149,13 @@ def get_dataloader_from_data_stage(
     # Case 3: Nanosets
     elif isinstance(data.dataset, NanosetDatasetsArgs):
         # Get tokenizer cardinality
+        # sleep_seconds = 600
+        # print(f"Sleeping for {sleep_seconds} seconds")
+        # time.sleep(sleep_seconds)
+
+        print(trainer.config.tokenizer.tokenizer_name_or_path)
+        # model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-Nemo-Base-2407")
+        # del model
         tokenizer = AutoTokenizer.from_pretrained(trainer.config.tokenizer.tokenizer_name_or_path)
         token_size = 4 if len(tokenizer) > np.iinfo(np.uint16).max + 1 else 2
         del tokenizer
@@ -233,5 +247,6 @@ def get_args():
     trainer = DistributedTrainer(config_file)
     dataloader = get_dataloader(trainer)
 
-    # Train
-    trainer.train(dataloader)
+    config = nanotron.trainer.get_config_from_file(config_file)
+    trainer.train(dataloader, validation_args=config.validation)
+
diff --git a/src/nanotron/config/config.py b/src/nanotron/config/config.py
index adc1eafd..4295228b 100644
--- a/src/nanotron/config/config.py
+++ b/src/nanotron/config/config.py
@@ -45,6 +45,7 @@ class LoggingArgs:
     log_level: Optional[str] = None
     log_level_replica: Optional[str] = None
     iteration_step_info_interval: Optional[int] = 1
+    iteration_step_validation_interval: Optional[int] = None
 
     def __post_init__(self):
         if self.log_level is None:
@@ -230,6 +231,14 @@ class TokenizerArgs:
     tokenizer_max_length: Optional[int] = None
 
 
+@dataclass
+class ValidationArgs:
+    """Arguments related to the validation"""
+    # datasets: Optional[Union[DataArgs, List[DataArgs]]] = None
+    # datasets: Optional[Union[DataArgs, list[DataArgs]]] = None
+    datasets: Optional[Union[DataArgs, list]] = None
+
+
 @dataclass
 class TokensArgs:
     """Arguments related to the tokens, sequence, batch and steps of the training"""
@@ -340,6 +349,8 @@ class Config:
     data_stages: Optional[List[DatasetStageArgs]] = None
     profiler: Optional[ProfilerArgs] = None
     lighteval: Optional[LightEvalConfig] = None
+    validation: Optional[ValidationArgs] = None
+
 
     @classmethod
     def create_empty(cls):
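For reference, a minimal sketch of how the new knobs introduced above (LoggingArgs.iteration_step_validation_interval and the validation section parsed into ValidationArgs) might be set in a training YAML. The dataset folder is hypothetical and the exact schema of validation.datasets is still unsettled in this patch (the trainer.py change below wraps each entry in NanosetDatasetsArgs(...)), so treat this as an assumed example rather than a finalized interface:

    logging:
      iteration_step_info_interval: 1
      iteration_step_validation_interval: 100  # assumed: run validation every 100 training steps
    validation:
      datasets:
        - datasets/my_validation_set  # hypothetical tokenized dataset folder
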
diff --git a/src/nanotron/trainer.py b/src/nanotron/trainer.py
index 21251a32..e4dab6ff 100644
--- a/src/nanotron/trainer.py
+++ b/src/nanotron/trainer.py
@@ -19,10 +19,16 @@
     cast,
 )
 
+import numpy as np
 import torch
 from torch.nn.parallel import DistributedDataParallel
 from torch.utils.data import DataLoader
+from transformers import AutoTokenizer
+
+
+from nanotron.parallel.pipeline_parallel.utils import get_input_output_pp_ranks
+from nanotron.config import DataArgs, DatasetStageArgs, NanosetDatasetsArgs, PretrainDatasetsArgs
 
 from nanotron import distributed as dist
 from nanotron import logging
 from nanotron.config import (
@@ -54,6 +60,7 @@
     log_rank,
     set_ranks_logging_level,
 )
+from nanotron.data.dataloader_builder import build_nanoset_dataloader
 from nanotron.models import NanotronModel, build_model
 from nanotron.models.base import check_model_has_grad
 from nanotron.models.llama import LlamaForTraining, RotaryEmbedding
@@ -91,6 +98,7 @@
     save,
     save_random_states,
 )
+from nanotron.utils import main_rank_first
 from nanotron.serialize.metadata import DataStageMetadata, TrainingMetadata
 from nanotron.serialize.optimizer import load_optimizer
 
@@ -111,6 +119,47 @@
     wandb = None
 
 
+
+
+def _quick_dataloader_builder(data, trainer):
+    print(data)
+    print(type(data))
+    # assert isinstance(data.dataset, NanosetDatasetsArgs), f"only supporting nanosets for now data = {data}"
+    consumed_validation_samples = 0
+
+    # First, we need to know which ranks to feed the dataloader to
+    input_pp_rank, output_pp_rank = get_input_output_pp_ranks(model=trainer.model)
+
+    tokenizer = AutoTokenizer.from_pretrained(trainer.config.tokenizer.tokenizer_name_or_path)
+
+    token_size = 4 if len(tokenizer) > np.iinfo(np.uint16).max + 1 else 2
+    del tokenizer
+
+    from nanotron.data.nanoset import Nanoset
+
+    with main_rank_first(trainer.parallel_context.world_pg):
+        train_dataset = Nanoset(
+            dataset_folders=data.dataset.dataset_folder,
+            dataset_weights=data.dataset.dataset_weights,
+            sequence_length=trainer.sequence_length,
+            token_size=token_size,
+            train_split_num_samples=trainer.config.tokens.train_steps * trainer.global_batch_size,
+            random_seed=data.seed,
+        )
+
+    train_dataloader = build_nanoset_dataloader(
+        train_dataset,
+        trainer.sequence_length,
+        parallel_context=trainer.parallel_context,
+        input_pp_rank=input_pp_rank,
+        output_pp_rank=output_pp_rank,
+        micro_batch_size=trainer.micro_batch_size,
+        consumed_train_samples=consumed_validation_samples,
+        dataloader_num_workers=data.num_loading_workers,
+        dataloader_drop_last=True,
+    )
+    return train_dataloader
+
+
 class DistributedTrainer:
     def __init__(
         self,
@@ -397,13 +446,13 @@ def train(
         ],
         **kwargs,
     ) -> None:
+        self.pre_training(**kwargs)
         if self.config.checkpoints.save_initial_state and self.init_checkpoint_path is None:
             self.save_checkpoint()
 
         self.pipeline_engine: PipelineEngine = self.config.parallelism.pp_engine
-
         self.pipeline_engine.nb_microbatches = self.n_micro_batches_per_batch
 
         # TODO @nouamanetazi: refactor this
@@ -439,6 +488,25 @@ def train(
             if (self.iteration_step - 1) % self.config.logging.iteration_step_info_interval == 0:
                 self.train_step_logs(outputs=outputs, loss_avg=loss_avg)
+            if (self.config.logging.iteration_step_validation_interval is not None) and \
+               ((self.iteration_step - 1) % self.config.logging.iteration_step_validation_interval == 0):
+                # assert "validation_args" in kwargs
+                validation_args = kwargs["validation_args"]
+                print(f"validation_args = {validation_args}")
+                dataset_list = validation_args.datasets
+                print(f"validation -> iteration_step - 1 = {self.iteration_step - 1}")
1}") + print(f"dataset_list = {dataset_list}") + print(f"len(dataset_list) = {len(dataset_list)}") + print(f"dataset_list[0] = {dataset_list[0]}") + for vds in dataset_list: + print("****" * 20) + vds = NanosetDatasetsArgs(vds) + print(vds) + print("****" * 20) + dataloader = _quick_dataloader_builder(vds, self) + validation_output = self.validation_step(dataloader) + print(f"validation_output = {validation_output}") + print(f"validation_output[0] = {validation_output[0]}") # Checkpoint if self.iteration_step % self.config.checkpoints.checkpoint_interval == 0: self.save_checkpoint() @@ -550,6 +618,31 @@ def validation_step(self, dataloader: Iterator[Dict[str, Union[torch.Tensor, Ten batch=(next(dataloader) for _ in range(self.limit_val_batches)), nb_microbatches=self.limit_val_batches, ) + """ + # Compute DP average loss and overlap with optimizer step + if isinstance(outputs[0]["loss"], torch.Tensor): + # This is an average on only one data rank. + loss_avg = torch.stack( + [output["loss"] for output in outputs] + ).sum() # already divided by n_micro_batches_per_batch + # sync loss across DP + handle = dist.all_reduce(loss_avg, group=self.parallel_context.dp_pg, async_op=True, op=dist.ReduceOp.AVG) + else: + loss_avg = None + handle = None + + # Apply gradient + self.optimizer.step() + self.optimizer.zero_grad() + + # Update the learning rate + self.lr_scheduler.step() + + after_optim_step_sanity_checks(self.config, self.parallel_context, self.unwrapped_model, self.grad_accumulator) + + if handle is not None: + handle.wait() + """ return outputs def train_step_logs( @@ -602,18 +695,13 @@ def train_step_logs( total, used, free = shutil.disk_usage("/") log_entries.extend( [ - LogItem( - "cuda_memory_allocated", torch.cuda.memory_allocated(), "human_format" - ), # / 1024**2, ".2f"), - LogItem( - "cuda_max_memory_reserved", torch.cuda.max_memory_reserved(), "human_format" - ), # / 1024**2, ".2f"), - LogItem("hd_total_memory_tb", total, "human_format"), # / (2**40), ".2f"), - LogItem("hd_used_memory_tb", used, "human_format"), # / (2**40), ".2f"), - LogItem("hd_free_memory_tb", free, "human_format"), # / (2**40), ".2f"), + LogItem("cuda_memory_allocated", torch.cuda.memory_allocated(), "human_format"), + LogItem("cuda_max_memory_reserved", torch.cuda.max_memory_reserved(), "human_format"), + LogItem("hd_total_memory_tb", total, "human_format"), + LogItem("hd_used_memory_tb", used, "human_format"), + LogItem("hd_free_memory_tb", free, "human_format"), ] ) - # NOTE: only one rank writes to wandb if dist.get_rank(self.parallel_context.world_pg) == self.logger_ranks[0] and wandb is not None: wandb.log(