precommit #232

Merged (1 commit, Sep 9, 2024)

8 changes: 4 additions & 4 deletions src/nanotron/config/config.py
@@ -2,13 +2,13 @@
import os
from dataclasses import dataclass, fields
from pathlib import Path
from datasets.download.streaming_download_manager import xPath
from typing import List, Optional, Type, Union

import dacite
import torch
import yaml
from dacite import from_dict
from datasets.download.streaming_download_manager import xPath
from yaml.loader import SafeLoader

from nanotron.config.lighteval_config import LightEvalConfig
@@ -108,6 +108,7 @@ def __post_init__(self):
if isinstance(self.s5cmd_path, str):
self.s5cmd_path = xPath(self.s5cmd_path)


@dataclass
class NanosetDatasetsArgs:
dataset_folder: Union[str, List[str]]
@@ -151,7 +152,6 @@ class CheckpointsArgs:
checkpoints_path: where to save the checkpoints
checkpoint_interval: how often to save the checkpoints
resume_checkpoint_path: if you want to load from a specific checkpoint path

"""

checkpoints_path: Path
@@ -350,15 +350,15 @@ class Config:
data_stages: Optional[List[DatasetStageArgs]] = None
profiler: Optional[ProfilerArgs] = None
lighteval: Optional[LightEvalConfig] = None
s3_upload : Optional[S3UploadArgs] = None
s3_upload: Optional[S3UploadArgs] = None

@classmethod
def create_empty(cls):
cls_fields = fields(cls)
return cls(**{f.name: None for f in cls_fields})

def __post_init__(self):

if self.s3_upload is not None:
self.s3_upload.__post_init__()

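The config.py hunks above mostly normalize formatting (import order, `s3_upload : ` → `s3_upload:`), but they also show the nested-config pattern where `Config.__post_init__` re-runs `s3_upload.__post_init__()`. Here is a minimal, self-contained sketch of that pattern; `S3UploadSketch` and `ConfigSketch` are illustrative stand-ins, not the real nanotron classes, and `Path` stands in for `xPath`:

```python
# Illustrative sketch only: stand-in classes mirroring the pattern in config.py.
from dataclasses import dataclass, fields
from pathlib import Path
from typing import Optional


@dataclass
class S3UploadSketch:
    upload_s3_path: str
    s5cmd_path: Optional[str] = None

    def __post_init__(self):
        # Normalize string paths, the way config.py wraps s5cmd_path in xPath.
        if isinstance(self.s5cmd_path, str):
            self.s5cmd_path = Path(self.s5cmd_path)


@dataclass
class ConfigSketch:
    checkpoints_path: Optional[Path] = None
    s3_upload: Optional[S3UploadSketch] = None

    @classmethod
    def create_empty(cls):
        # Same trick as Config.create_empty(): every field set to None.
        return cls(**{f.name: None for f in fields(cls)})

    def __post_init__(self):
        # Re-run the nested post-init, as Config.__post_init__ does for s3_upload.
        if self.s3_upload is not None:
            self.s3_upload.__post_init__()
```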
2 changes: 1 addition & 1 deletion src/nanotron/helpers.py
@@ -69,7 +69,7 @@ def init_random_states(parallel_config: ParallelismArgs, tp_pg: ProcessGroup):
{"tp_synced": get_synced_random_state(random_state=get_current_random_state(), pg=tp_pg)}
)
else:
# We don't need to sync across TP when using sequence parallel (REDUCE_SCATTER)
# NOTE: We don't need to sync across TP when using sequence parallel (REDUCE_SCATTER)
random_states = RandomStates({})
return random_states

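The helpers.py hunk only rewords a comment, but the rule it documents is easy to miss: random states are synced across the TP group only when TP runs in ALL_REDUCE mode, not under sequence parallel (REDUCE_SCATTER). A self-contained sketch of that decision, using a fixed seed as a stand-in for nanotron's broadcast-based `get_synced_random_state`:

```python
# Sketch of the rule documented in init_random_states; TPMode and the plain dict
# stand in for nanotron's TensorParallelLinearMode and RandomStates.
from enum import Enum, auto

import torch


class TPMode(Enum):
    ALL_REDUCE = auto()      # classic tensor parallelism
    REDUCE_SCATTER = auto()  # sequence parallelism


def init_random_states_sketch(tp_mode: TPMode, seed: int = 42) -> dict:
    if tp_mode is TPMode.ALL_REDUCE:
        # Every TP rank sees the full activations, so stochastic ops (dropout, init)
        # must draw identical numbers: give the whole TP group one shared state.
        gen = torch.Generator().manual_seed(seed)
        return {"tp_synced": gen.get_state()}
    # REDUCE_SCATTER: each rank owns a distinct sequence shard, nothing to sync across TP.
    return {}
```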
2 changes: 1 addition & 1 deletion src/nanotron/s3_checkpoints/__init__.py
@@ -1,4 +1,4 @@
from .fsspec import check_path_is_local, fs_copy, fs_open
from .s3_mover import S3Mover

__all__ = ["S3Mover", "fs_open", "fs_copy", "check_path_is_local"]
__all__ = ["S3Mover", "fs_open", "fs_copy", "check_path_is_local"]
2 changes: 1 addition & 1 deletion src/nanotron/s3_checkpoints/fsspec.py
@@ -19,7 +19,7 @@ def fs_open(
file: Union[str, Path],
mode="r",
):
# TODO @thomasw21: pass storage options
# TODO @thomasw21: pass storage options.
fs, path = get_filesystem_and_path(file)
with fs.open(path, mode=mode) as f:
yield f
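For context, `fs_open` is a small context manager that resolves a URL or local path to an fsspec filesystem and yields an open handle; this hunk only punctuates the TODO. A hedged sketch of the same idea using plain fsspec (the real helper goes through nanotron's `get_filesystem_and_path` and, per the TODO, should eventually forward storage options):

```python
# Sketch only: plain-fsspec version of the fs_open pattern.
from contextlib import contextmanager
from pathlib import Path
from typing import Union

import fsspec


@contextmanager
def fs_open_sketch(file: Union[str, Path], mode: str = "r"):
    # Resolve "s3://bucket/key", "file:///tmp/x", or a bare local path to (filesystem, path).
    fs, path = fsspec.core.url_to_fs(str(file))
    with fs.open(path, mode=mode) as f:
        yield f
```

Usage would look like `with fs_open_sketch("s3://bucket/ckpt/config.yaml") as f: cfg = f.read()`; remote schemes need the matching fsspec backend (e.g. s3fs) installed.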
6 changes: 3 additions & 3 deletions src/nanotron/s3_checkpoints/s3_mover.py
@@ -10,6 +10,7 @@
import torch
from datasets.download.streaming_download_manager import xPath
from filelock import FileLock, Timeout

from nanotron import distributed as dist
from nanotron import logging
from nanotron.distributed import ProcessGroup
@@ -19,7 +20,7 @@


class S3Mover:
#TODO @eliebak update the doc to state that it also the function use to download it to the disk with start_downloading
# TODO @eliebak update the doc to state that it also the function use to download it to the disk with start_downloading
"""Take care of uploading a checkpoint to S3 in the background and remove it from the disk.

Args:
@@ -70,7 +71,6 @@ def __init__(
self,
local_path: xPath,
s3_path: xPath,
# duplicate_checkpoint_path: Optional[xPath] = None,
post_upload_callback: Optional[callable] = None,
remove_after_upload: Optional[bool] = True,
s5cmd_numworkers: Optional[int] = None,
@@ -219,7 +219,7 @@ def distributed_wait_for_completion(self, group: Optional[ProcessGroup] = None):
self._warning(
f"[S3] Waiting {self.state.value}: {all_saved} / {group.size()}. Stdout: {len(stdout_lines)} end: {stdout_lines[-1:]}",
)
# sync all our saves on NCCL we could do a dist barrier later but this helps us not loosing NCCL connections down the line
# sync all our saves on NCCL we could do a dist barrier later but this helps us not losing NCCL connections down the line
test_tensor = torch.tensor([self.is_previous_save_finished()], device=torch.device("cuda"))
test_tensor_list = [torch.zeros_like(test_tensor) for _ in range(group.size())]
dist.all_gather(test_tensor_list, test_tensor, group=group, async_op=False)
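The corrected comment ("not losing NCCL connections") refers to the polling loop in `distributed_wait_for_completion`: instead of a hard `dist.barrier`, every rank periodically all-gathers a completion flag, which keeps NCCL traffic alive while slow ranks finish uploading. A minimal sketch of that pattern, with `is_done` standing in for `S3Mover.is_previous_save_finished`:

```python
# Sketch of the all-gather-a-flag pattern used by distributed_wait_for_completion.
import time

import torch
import torch.distributed as dist


def wait_for_all_ranks(is_done, group=None, poll_seconds: float = 5.0) -> None:
    world_size = dist.get_world_size(group=group)
    while True:
        flag = torch.tensor([bool(is_done())], device=torch.device("cuda"))
        flags = [torch.zeros_like(flag) for _ in range(world_size)]
        dist.all_gather(flags, flag, group=group, async_op=False)
        if all(bool(t.item()) for t in flags):
            break  # every rank reports its previous save as finished
        time.sleep(poll_seconds)  # keep polling (and keep NCCL busy) until all are done
```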
9 changes: 4 additions & 5 deletions src/nanotron/serialize/main.py
@@ -1,10 +1,9 @@
import os
from pathlib import Path
from typing import Optional, cast
from datasets.download.streaming_download_manager import xPath
import os

from nanotron.s3_checkpoints import S3Mover, check_path_is_local, fs_open
import torch
from datasets.download.streaming_download_manager import xPath
from torch import nn
from torch.nn.parallel import DistributedDataParallel
from torch.optim.lr_scheduler import LambdaLR
@@ -13,11 +12,11 @@
from nanotron import logging
from nanotron import optim as optim
from nanotron.config import Config
from nanotron.constants import MODEL_CONFIG_FILE_NAME
from nanotron.distributed import get_global_rank
from nanotron.logging import log_rank
from nanotron.parallel import ParallelContext
from nanotron.parallel.parameters import NanotronParameter
from nanotron.s3_checkpoints import S3Mover, check_path_is_local, fs_open
from nanotron.sanity_checks import (
assert_tensor_synced_across_pg,
check_optim_state_in_sync,
@@ -43,7 +42,7 @@

Version 1:
- serialize -> dumps every process weights in individual files
- load -> assume topology is exactly the same
- load -> assume topology is exactly the same.
"""


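The module docstring kept by this hunk describes the "Version 1" scheme: each process dumps its own weight shard, and loading assumes exactly the same topology. A hedged illustration of what that implies; the file names and layout below are made up for the example, not nanotron's actual on-disk format:

```python
# Illustrative only: per-rank shard files that can only be reloaded under the same topology.
from pathlib import Path

import torch


def save_rank_shard(state_dict: dict, root: Path, tp_rank: int, pp_rank: int) -> Path:
    root.mkdir(parents=True, exist_ok=True)
    shard = root / f"weights_tp{tp_rank}_pp{pp_rank}.pt"
    torch.save(state_dict, shard)
    return shard


def load_rank_shard(root: Path, tp_rank: int, pp_rank: int) -> dict:
    # Works only if the run that reads the shard has the same (tp, pp) layout
    # as the run that wrote it -- exactly the "same topology" assumption above.
    return torch.load(root / f"weights_tp{tp_rank}_pp{pp_rank}.pt")
```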
32 changes: 16 additions & 16 deletions src/nanotron/trainer.py
@@ -19,7 +19,6 @@
cast,
)

from nanotron.s3_checkpoints import S3Mover, check_path_is_local
import torch
from torch.nn.parallel import DistributedDataParallel
from torch.utils.data import DataLoader
@@ -77,6 +76,7 @@
tie_parameters,
)
from nanotron.random import set_random_seed
from nanotron.s3_checkpoints import S3Mover, check_path_is_local
from nanotron.sanity_checks import (
after_optim_step_sanity_checks,
after_tbi_sanity_checks,
@@ -149,14 +149,12 @@ def __init__(
data_parallel_size=self.config.parallelism.dp,
expert_parallel_size=self.config.parallelism.expert_parallel_size,
)

self.pre_init()

# Set log levels
set_ranks_logging_level(parallel_context=self.parallel_context, logging_config=self.config.logging)



# Log benchmark info
if os.environ.get("NANOTRON_BENCHMARK", "0") == "1":
log_throughput(self.config, self.parallel_context)
@@ -263,12 +261,11 @@ def pre_init(self):
def post_init(self):
# S3 Mover and save initial state
if self.config.s3_upload is not None:
# Only local rank 0 should upload
# NOTE: Only local rank 0 should upload
dummy = bool(int(os.environ.get("LOCAL_RANK", None)) != 0)
self.s3_mover = S3Mover(
local_path=self.config.checkpoints.checkpoints_path,
s3_path=self.config.s3_upload.upload_s3_path,
# duplicate_checkpoint_path=self.config.checkpoints.resume_checkpoint_path,
remove_after_upload=self.config.s3_upload.remove_after_upload,
s5cmd_numworkers=self.config.s3_upload.s5cmd_numworkers,
s5cmd_concurrency=self.config.s3_upload.s5cmd_concurrency,
@@ -307,7 +304,7 @@ def post_train_step(self):
def post_training(self):
if self.s3_mover is not None:
self.s3_mover.distributed_wait_for_completion(group=self.parallel_context.world_pg)

def _print_training_plan(self):
if hasattr(self.config, "data_stages") and self.config.data_stages is not None:
stages_info = "".join(
@@ -464,10 +461,10 @@ def train(
self.save_checkpoint()

dist.barrier() # let's wait for everyone before leaving

if self.config.checkpoints.save_final_state:
self.save_checkpoint()

self.post_training()

def training_step(
@@ -711,17 +708,21 @@ def _init_model_instance(self) -> NanotronModel:
def _load_model_checkpoint(self, model: NanotronModel) -> NanotronModel:
unwrapped_model = model.module if isinstance(model, DistributedDataParallel) else model

# Load or initialize model weights
# Load or initialize model weights
reloaded_from_checkpoint = False
if self.init_checkpoint_path is not None:
# Load from a pre existing checkpoint
# Load from a pre existing checkpoint
if check_path_is_local(self.init_checkpoint_path):
# Reload from a training checkpoint
log_rank(f"Loading weights from {self.init_checkpoint_path}", logger=logger, level=logging.INFO, rank=0)
# Reload from a training checkpoint
log_rank(
f"Loading weights from {self.init_checkpoint_path}", logger=logger, level=logging.INFO, rank=0
)
self.param_shard_metadata = load_weights(
model=unwrapped_model, parallel_context=self.parallel_context, root_folder=self.init_checkpoint_path
model=unwrapped_model,
parallel_context=self.parallel_context,
root_folder=self.init_checkpoint_path,
)
reloaded_from_checkpoint=True
reloaded_from_checkpoint = True
if not reloaded_from_checkpoint:
log_rank("No checkpoint path provided.", logger=logger, level=logging.INFO, rank=0)
if isinstance(self.config.model.init_method, ExistingCheckpointInit):
@@ -865,7 +866,6 @@ def post_save_checkpoint(self):
if self.s3_mover is not None:
self.s3_mover.start_uploading()


def save_checkpoint(self) -> Path:
self.pre_save_checkpoint()

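For context on the `post_init` hunk above: only local rank 0 should upload, so other local ranks are expected to construct a no-op ("dummy") mover. A sketch of that wiring follows; note that passing `dummy=dummy` to `S3Mover` is an assumption here, since the keyword is not visible in the truncated hunk:

```python
# Sketch of the S3 upload wiring in DistributedTrainer.post_init (assumptions noted above).
import os

from nanotron.s3_checkpoints import S3Mover


def build_s3_mover(config):
    if config.s3_upload is None:
        return None
    # Only local rank 0 should really upload; other local ranks get a dummy mover.
    dummy = int(os.environ.get("LOCAL_RANK", "0")) != 0
    return S3Mover(
        local_path=config.checkpoints.checkpoints_path,
        s3_path=config.s3_upload.upload_s3_path,
        remove_after_upload=config.s3_upload.remove_after_upload,
        s5cmd_numworkers=config.s3_upload.s5cmd_numworkers,
        s5cmd_concurrency=config.s3_upload.s5cmd_concurrency,
        dummy=dummy,  # assumed kwarg, not shown in the diff
    )
```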