From 5fd1c0788a76e5668f975ae5ea80e69263bb4046 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9Celiebak=E2=80=9D?= Date: Thu, 22 Aug 2024 04:19:57 +0000 Subject: [PATCH 1/9] add support for s3 checkpoint: need a fix on check_path_is_local --- src/nanotron/config/config.py | 22 ++ src/nanotron/helpers.py | 3 +- src/nanotron/s3_checkpoints/__init__.py | 4 + src/nanotron/s3_checkpoints/fsspec.py | 38 ++ src/nanotron/s3_checkpoints/s3_mover.py | 439 ++++++++++++++++++++++++ src/nanotron/serialize/main.py | 108 ++++-- src/nanotron/trainer.py | 71 +++- src/nanotron/utils.py | 6 + 8 files changed, 642 insertions(+), 49 deletions(-) create mode 100644 src/nanotron/s3_checkpoints/__init__.py create mode 100644 src/nanotron/s3_checkpoints/fsspec.py create mode 100644 src/nanotron/s3_checkpoints/s3_mover.py diff --git a/src/nanotron/config/config.py b/src/nanotron/config/config.py index de0fa3c0..8cdc8715 100644 --- a/src/nanotron/config/config.py +++ b/src/nanotron/config/config.py @@ -2,6 +2,7 @@ import os from dataclasses import dataclass, fields from pathlib import Path +from datasets.download.streaming_download_manager import xPath from typing import List, Optional, Type, Union import dacite @@ -91,6 +92,22 @@ def __post_init__(self): self.hf_dataset_splits = "train" +@dataclass +class S3UploadArgs: + """Arguments related to uploading checkpoints on s3""" + + upload_s3_path: xPath + remove_after_upload: bool + s5cmd_numworkers: Optional[int] + s5cmd_concurrency: Optional[int] + s5cmd_path: Optional[xPath] + + def __post_init__(self): + if isinstance(self.upload_s3_path, str): + self.upload_s3_path = xPath(self.upload_s3_path) + if isinstance(self.s5cmd_path, str): + self.s5cmd_path = xPath(self.s5cmd_path) + @dataclass class NanosetDatasetsArgs: dataset_folder: Union[str, dict, List[str]] @@ -338,6 +355,7 @@ class Config: data_stages: Optional[List[DatasetStageArgs]] = None profiler: Optional[ProfilerArgs] = None lighteval: Optional[LightEvalConfig] = None + s3_upload : Optional[S3UploadArgs] = None @classmethod def create_empty(cls): @@ -345,6 +363,10 @@ def create_empty(cls): return cls(**{f.name: None for f in cls_fields}) def __post_init__(self): + + if self.s3_upload is not None: + self.s3_upload.__post_init__() + # Some final sanity checks across separate arguments sections: if self.profiler is not None and self.profiler.profiler_export_path is not None: assert self.tokens.train_steps < 10 diff --git a/src/nanotron/helpers.py b/src/nanotron/helpers.py index a82f0294..892ac03c 100644 --- a/src/nanotron/helpers.py +++ b/src/nanotron/helpers.py @@ -52,8 +52,9 @@ def _vocab_size_with_padding(orig_vocab_size: int, pg_size: int, make_vocab_size multiple = make_vocab_size_divisible_by * pg_size after = int(ceil(orig_vocab_size / multiple) * multiple) - + print("hello") if after != orig_vocab_size: + print("i'm in") log_rank( f"[Vocab Size Padding] Padded vocab (size: {orig_vocab_size}) with {after - orig_vocab_size} dummy tokens (new size: {after})", logger=logger, diff --git a/src/nanotron/s3_checkpoints/__init__.py b/src/nanotron/s3_checkpoints/__init__.py new file mode 100644 index 00000000..0b32a02a --- /dev/null +++ b/src/nanotron/s3_checkpoints/__init__.py @@ -0,0 +1,4 @@ +from .fsspec import check_path_is_local, fs_copy, fs_open +from .s3_mover import S3Mover + +__all__ = ["S3Mover", "fs_open", "fs_copy", "check_path_is_local"] \ No newline at end of file diff --git a/src/nanotron/s3_checkpoints/fsspec.py b/src/nanotron/s3_checkpoints/fsspec.py new file mode 100644 index 00000000..01786489 
--- /dev/null +++ b/src/nanotron/s3_checkpoints/fsspec.py @@ -0,0 +1,38 @@ +import contextlib +from pathlib import Path +from typing import Tuple, Union + +import fsspec +from fsspec.implementations import local + + +def get_filesystem_and_path(path: Path, storage_options=None) -> Tuple[fsspec.AbstractFileSystem, str]: + # Use supported filesystems in `fsspec`. If you need another one, please use `fsspec.registry.register_implementation` + # DO NOT USE `mode` argument as it adds a suffix `0.part` when using `mode="w"`. + fs, _, paths = fsspec.core.get_fs_token_paths(str(path), storage_options=storage_options) + assert len(paths) == 1 + return fs, paths[0] + + +@contextlib.contextmanager +def fs_open( + file: Union[str, Path], + mode="r", +): + # TODO @thomasw21: pass storage options + fs, path = get_filesystem_and_path(file) + with fs.open(path, mode=mode) as f: + yield f + + +def fs_copy( + input_file: Union[str, Path], + output_file: Union[str, Path], +): + """Copy file from input to output (possibly on s3/other fs)""" + with fs_open(input_file, mode="rb") as fi, fs_open(output_file, mode="wb") as fo: + fo.write(fi.read()) + + +def check_path_is_local(path: Path, storage_options=None) -> bool: + return isinstance(get_filesystem_and_path(path=path, storage_options=storage_options)[0], local.LocalFileSystem) diff --git a/src/nanotron/s3_checkpoints/s3_mover.py b/src/nanotron/s3_checkpoints/s3_mover.py new file mode 100644 index 00000000..73c9b793 --- /dev/null +++ b/src/nanotron/s3_checkpoints/s3_mover.py @@ -0,0 +1,439 @@ +import glob +import json +import os +import subprocess +import time +from datetime import datetime +from enum import Enum +from typing import Optional, Union + +import torch +from datasets.download.streaming_download_manager import xPath +from filelock import FileLock, Timeout +from nanotron import distributed as dist +from nanotron import logging +from nanotron.distributed import ProcessGroup +from nanotron.logging import human_format + +logger = logging.get_logger(__name__) + + +class S3Mover: + #TODO @eliebak update the doc to state that it also the function use to download it to the disk with start_downloading + """Take care of uploading a checkpoint to S3 in the background and remove it from the disk. + + Args: + local_path: Path to the checkpoints on the local disk + s3_path: Path to the checkpoints on S3 + remove_after_upload: If True, remove the checkpoint from the disk after uploading it to S3 + s5cmd_numworkers: Number of workers to use for the s5cmd command + s5cmd_concurrency: Concurrency to use for the s5cmd command + s5cmd_path: Path to the s5cmd command + dummy: If True, don't actually upload/remove/etc anything. Useful for simpler multi-processing node and only uploading from one process. + + Usage: + # Create a mover - use dummy=True for all the process that shouldn't do anything (e.g. all but one per node) + mover = S3Mover(local_path=/scratch/my-checkpoints, + s3_path=s3://my-bucket/my-checkpoints, + remove_after_upload=True, + s5cmd_numworkers=96, + s5cmd_concurrency=10, + s5cmd_path=/admin/user/my/bin/s5cmd, + dummy=False) + + while training: + # from times to times update the state + mover_status = mover.update() + ... 
+ + # When saving a checkpoint, check if the previous checkpoint has been uploaded and removed + # in a distributed setting + """ + + class S3MoverState(Enum): + IDLE = "IDLE" + UPLOADING = "UPLOADING" + DOWNLOADING = "DOWNLOADING" + REMOVING_CHECKPOINT = "REMOVING_CHECKPOINT" + + class DummyPopen: + def __init__(self, *args, **kwargs): + pass + + def poll(self): + return 0 + + def communicate(self): + return ("", "") + + def __init__( + self, + local_path: xPath, + s3_path: xPath, + # duplicate_checkpoint_path: Optional[xPath] = None, + post_upload_callback: Optional[callable] = None, + remove_after_upload: Optional[bool] = True, + s5cmd_numworkers: Optional[int] = None, + s5cmd_concurrency: Optional[int] = None, + s5cmd_path: Optional[str] = None, + s5cmd_credentials: Optional[str] = None, + clean_up_local_on_start: bool = False, + dummy: bool = False, + s3_region: str = "us-east-1", + ): + self.process: Optional[Union[subprocess.Popen, S3Mover.DummyPopen]] = None + self.remove_after_upload = remove_after_upload + self.s5cmd_numworkers = s5cmd_numworkers + self.s5cmd_concurrency = s5cmd_concurrency + self.s5cmd_path = s5cmd_path if s5cmd_path is not None else "s5cmd" + self.s5cmd_credentials = s5cmd_credentials + self.lock_file = None + self.dummy = dummy + self.s3_region = s3_region + self.post_upload_callback = post_upload_callback + self.post_upload_callback_outputs = None + + local_path = str(local_path) + if not local_path.startswith("/scratch/"): + self._warning(f"The local path is not on the scratch drive: {local_path}") + if not local_path.endswith("/"): + local_path += "/" + + s3_path = str(s3_path) + if not s3_path.endswith("/"): + s3_path += "/" + + self.local_path = local_path + self.s3_path = s3_path + + s3_bucket, s3_prefix = s3_path.replace("s3://", "").split("/", maxsplit=1) + self.s3_path_direct_link = f"https://s3.console.aws.amazon.com/s3/buckets/{s3_bucket}?region={self.s3_region}&prefix={s3_prefix}&showversions=false" + + self._reset_state() + if clean_up_local_on_start: + self._start_removing() + + def _warning(self, message): + if self.dummy: + return + logger.warning(message) + + def _info(self, message): + if self.dummy: + return + logger.info(message) + + def _reset_state(self): + self.state = self.S3MoverState.IDLE + self.num_uploaded_files = 0 + if self.lock_file is not None: + self._release_lock() + self.lock_file = None + self.stdout = "" + self.start_time: datetime = None + self.cmd = "" + + def _popen(self, cmd: list): + self.stdout = "" + self.start_time = datetime.now() + self.cmd = cmd + if self.dummy: + return self.DummyPopen(cmd) + else: + process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT) + os.set_blocking(process.stdout.fileno(), False) + return process + + def _acquire_lock(self, file_path: str) -> bool: + if self.dummy: + return True + if file_path.endswith("/"): + lock_file_path = file_path[:-1] + ".lock" + else: + lock_file_path = file_path + ".lock" + self.lock_file = FileLock(lock_file_path) + try: + self.lock_file.acquire(timeout=1) + except Timeout: + message = f"[S3] The checkpoint files {lock_file_path} are currently locked by another process. 
" + self._warning(message) + return False + return True + + def get_state_as_int(self) -> int: + """Return the state as an int""" + if self.state == self.S3MoverState.IDLE: + return 0 + elif self.state == self.S3MoverState.UPLOADING: + return 1 + elif self.state == self.S3MoverState.DOWNLOADING: + return 2 + elif self.state == self.S3MoverState.REMOVING_CHECKPOINT: + return 3 + else: + return -1 + + def _release_lock(self): + if self.dummy: + return + if self.lock_file is not None and self.lock_file.is_locked: + self.lock_file.release() + + def get_current_stdout(self) -> str: + """Return the current stdout of the process if any""" + if self.process is None or isinstance(self.process, self.DummyPopen): + return "" + try: + stdout = self.process.stdout.read() + except ValueError: + stdout = "" # The buffer is already closed: "ValueError: read of closed file" + if stdout: + self.stdout += stdout.decode() + return self.stdout + + def wait_for_completion(self): + while self.state != self.S3MoverState.IDLE: + _ = self.update() + time.sleep(0.5) + + def distributed_wait_for_completion(self, group: Optional[ProcessGroup] = None): + """Wait for the previous checkpoint to be fully uploaded and removed in a distributed setting. + Will wait for all process to be ready + """ + if group is None: + group = dist.torch_dist.distributed_c10d._get_default_group() + + test_tensor = torch.tensor([self.is_previous_save_finished()], device=torch.device("cuda")) + test_tensor_list = [torch.zeros_like(test_tensor) for _ in range(group.size())] + dist.all_gather(test_tensor_list, test_tensor, group=group, async_op=False) + dist.barrier() + all_saved = sum(bool(tensor.item()) for tensor in test_tensor_list) + if all_saved != group.size() and self.state != self.S3MoverState.IDLE: + self._warning( + f"Waiting previous checkpoint saving is finished - S3Mover {dist.get_rank(group)} still in {self.state} state.", + ) + while all_saved != group.size(): + stdout = self.get_current_stdout() + stdout_lines = [lst for lst in stdout.split("\n") if lst] + if self.state != self.S3MoverState.IDLE: + self._warning( + f"[S3] Waiting {self.state.value}: {all_saved} / {group.size()}. 
Stdout: {len(stdout_lines)} end: {stdout_lines[-1:]}", + ) + # sync all our saves on NCCL we could do a dist barrier later but this helps us not loosing NCCL connections down the line + test_tensor = torch.tensor([self.is_previous_save_finished()], device=torch.device("cuda")) + test_tensor_list = [torch.zeros_like(test_tensor) for _ in range(group.size())] + dist.all_gather(test_tensor_list, test_tensor, group=group, async_op=False) + dist.barrier() + all_saved = sum(bool(tensor.item()) for tensor in test_tensor_list) + time.sleep(1) + + def is_previous_save_finished(self) -> bool: + """Return True if a potential previous checkpoint has been fully uploaded to S3 + and removed from the drive + """ + self.update() + return self.state == self.S3MoverState.IDLE + + def _start_downloading(self, sub_folder: Optional[str] = None) -> (bool, str): + self._warning( + f"[S3] Downloading checkpoint in background from {self.s3_path} to {self.local_path} (direct link: {self.s3_path_direct_link})" + ) + cmd = [self.s5cmd_path, "--json"] + if self.s5cmd_credentials is not None: + cmd += ["--credentials-file", self.s5cmd_credentials] + if self.s5cmd_numworkers is not None: + cmd += ["--numworkers", str(self.s5cmd_numworkers)] + cmd += ["cp"] + if self.s5cmd_concurrency is not None: + cmd += ["--concurrency", str(self.s5cmd_concurrency)] + cmd += [self.s3_path + "*", self.local_path] + + self.process = self._popen(cmd) + self.state = self.S3MoverState.DOWNLOADING + + return True + + def _post_downloading(self) -> bool: + self.get_current_stdout() + s5cmd_results = [json.loads(i) for i in self.stdout.split("\n") if i] + total_files = len([i for i in s5cmd_results if i["success"]]) + total_not_downloaded_files = len([i for i in s5cmd_results if not i["success"]]) + if total_not_downloaded_files == 0: + all_upload = "all files" + success = True + else: + all_upload = "not all files" + success = False + total_size = sum(i["object"]["size"] for i in s5cmd_results if "size" in i["object"]) + total_time = (datetime.now() - self.start_time).total_seconds() + self._warning( + f"[S3] Successfully downloaded {total_files} files for a total of {human_format(total_size)}B in {total_time}" + f"sec ({all_upload}) from S3 at {self.s3_path} to {self.local_path}" + f"(direct link: {self.s3_path_direct_link})" + ) + return success + + def _start_uploading( + self, + ) -> (bool, str): + # Get a file lock on the first file + local_files = glob.glob(self.full_local_path + "/**/*.*", recursive=True) + + locked = self._acquire_lock(local_files[0]) + if not locked: + return False + + if not os.path.exists(self.full_local_path): + message = f"[S3] Checkpoint {self.full_local_path} does not exist, cannot upload to S3" + self._warning(message) + return False + + self._warning( + f"[S3] Uploading checkpoint in background from {self.full_local_path} to {self.full_s3_path} (direct link: {self.s3_path_direct_link})" + ) + cmd = [self.s5cmd_path, "--json"] + if self.s5cmd_credentials is not None: + cmd += ["--credentials-file", self.s5cmd_credentials] + if self.s5cmd_numworkers is not None: + cmd += ["--numworkers", str(self.s5cmd_numworkers)] + cmd += ["cp", "--exclude", "*.lock", "--exclude", "*.lock.*"] + if self.s5cmd_concurrency is not None: + cmd += ["--concurrency", str(self.s5cmd_concurrency)] + cmd += [self.full_local_path, self.full_s3_path] + + self.process = self._popen(cmd) + self.state = self.S3MoverState.UPLOADING + + return True + + def _post_uploading(self) -> bool: + self.get_current_stdout() + s5cmd_results = 
[json.loads(i) for i in self.stdout.split("\n") if i] + local_files = glob.glob(self.full_local_path + "/**/*.?*", recursive=True) + total_files = len([i for i in s5cmd_results if i["success"]]) + self.num_uploaded_files = total_files + if len(local_files) == total_files: + all_upload = "all files" + success = True + else: + all_upload = f"not all files: {len(local_files)} out of {total_files}" + success = False + total_size = sum(i["object"]["size"] for i in s5cmd_results if "size" in i["object"]) + total_time = (datetime.now() - self.start_time).total_seconds() + self._warning( + f"[S3] Successfully uploaded {total_files} files for a total of {human_format(total_size)}B in {total_time} sec" + f"({all_upload}) from {self.full_local_path} to S3 at {self.full_s3_path} " + f"(direct link: {self.s3_path_direct_link})" + ) + if self.post_upload_callback: + self.post_upload_callback_outputs = self.post_upload_callback(uploaded_files=s5cmd_results) + self._release_lock() + return success + + def _start_removing(self) -> (bool, str): + top_dir_in_local_checkpoint = [dir for dir in glob.glob(self.local_path + "/*") if os.path.isdir(dir)] + names_dir = [os.path.basename(dir) for dir in top_dir_in_local_checkpoint] + if len(names_dir) == 0: + # If the local is already empty or if we have already started duplicating in another process we skip with a noop + self._warning("[S3] Local checkpoint empty. skipping removal") + cmd = ["echo", "'skipping'"] + self.process = self._popen(cmd) + self.state = self.S3MoverState.REMOVING_CHECKPOINT + return True + + self._warning(f"[S3] Removing checkpoint in background: {names_dir}") + locked = self._acquire_lock(top_dir_in_local_checkpoint[0]) + if not locked: + return False + cmd = ["rm", "-rfv"] + top_dir_in_local_checkpoint + self.process = self._popen(cmd) + self.state = self.S3MoverState.REMOVING_CHECKPOINT + return True + + def _post_removing(self) -> bool: + self.get_current_stdout() + local_files = [ + loc_f + for loc_f in self.stdout.split("\n") + if "directory" not in loc_f.lower() and loc_f and ".lock" not in loc_f + ] + if len(local_files) == self.num_uploaded_files: + all_removed = "all files" + success = True + else: + all_removed = "not all files" + success = False + self._release_lock() + total_time = (datetime.now() - self.start_time).total_seconds() + self._warning( + f"[S3] Successfully removed {len(local_files)} local files ({all_removed}) from {self.local_path} (uploaded to {self.s3_path_direct_link}) in {total_time}" + ) + return success + + def update(self) -> (str, str): + """Update the state of the mover: UPLOADING => REMOVING_DUPLICATED => DUPLICATING => REMOVING_CHECKPOINT => IDLE + + Returns: + (str, str): The state and the stdout of the process if any + """ + if self.process is None: + self._reset_state() + return self.state, self.stdout + + return_code = self.process.poll() + if return_code is None: + # Still running + return self.state, self.stdout + if return_code != 0: + self.get_current_stdout() + self._warning( + f"[S3] Error running command {self.cmd} during process {self.state.value}, " + f"return code {return_code}, return message {self.stdout}" + ) + return self.state, self.stdout + if self.state == self.S3MoverState.DOWNLOADING: + self._post_downloading() + self._reset_state() + elif self.state == self.S3MoverState.UPLOADING: + self._post_uploading() + if self.remove_after_upload: + self._start_removing() + else: + self._reset_state() + elif self.state == self.S3MoverState.REMOVING_CHECKPOINT: + self._post_removing() + 
self._reset_state() + + return self.state.value, self.stdout + + def start_uploading(self, sub_folder=None): + """Start uploading last saved checkpoint to S3 in the background. + + After running this method, you should call regularly `update` to update the + state to duplicating and then removing. + + For a blocking upload, call `wait_for_completion` or `distributed_wait_for_completion` after calling this method. + """ + self.update() + if self.state != self.S3MoverState.IDLE: + message = "[S3] Cannot move to S3 as the previous checkpoint has not been uploaded and removed" + self._warning(message) + return False + self.full_local_path = self.local_path + (f"/{sub_folder}" if sub_folder else "") + self.full_s3_path = self.s3_path + (f"/{sub_folder}" if sub_folder else "") + return self._start_uploading() + + def start_downloading(self): + """Start downloading a checkpoint from S3 in the background. + + After running this method, you should call regularly `update` to update the + state. + + For a blocking download, call `wait_for_completion` or `distributed_wait_for_completion` after calling this method. + """ + self.update() + if self.state != self.S3MoverState.IDLE: + message = f"[S3] Cannot download from S3 as the state is not IDLE but {self.state.value}" + self._warning(message) + return False + return self._start_downloading() diff --git a/src/nanotron/serialize/main.py b/src/nanotron/serialize/main.py index 286008ac..da192c94 100644 --- a/src/nanotron/serialize/main.py +++ b/src/nanotron/serialize/main.py @@ -1,6 +1,9 @@ from pathlib import Path from typing import Optional, cast +from datasets.download.streaming_download_manager import xPath +import os +from nanotron.s3_checkpoints import S3Mover, check_path_is_local, fs_open import torch from torch import nn from torch.nn.parallel import DistributedDataParallel @@ -251,33 +254,80 @@ def parse_ckpt_path(config: Config) -> Optional[Path]: Path to checkpoint or None if no checkpoint. 
""" load_from_candidate = config.checkpoints.resume_checkpoint_path - if load_from_candidate is None: - return None - - latest_meta_path: Path = config.checkpoints.resume_checkpoint_path / "latest.txt" - if latest_meta_path.exists(): - with open(config.checkpoints.resume_checkpoint_path / "latest.txt", mode="r") as fi: - # TODO @thomasw21: make a better structure system so that we get typing correct - load_from_candidate = int(fi.read()) - checkpoint_path = config.checkpoints.resume_checkpoint_path / str(load_from_candidate) - - elif (config.checkpoints.resume_checkpoint_path / MODEL_CONFIG_FILE_NAME).exists(): - # we assume that the checkpoint path is a path to a checkpoint - checkpoint_path = config.checkpoints.resume_checkpoint_path - - else: - log_rank( - f"No previous checkpoint found in: {latest_meta_path}", - logger=logger, - level=logging.INFO, - rank=0, - ) - return None + print(load_from_candidate) + if load_from_candidate is not None: + if check_path_is_local(load_from_candidate): + print("I'M LOCAL") + latest_meta_path: xPath = config.checkpoints.resume_checkpoint_path / "latest.txt" + if latest_meta_path.exists(): + with fs_open(config.checkpoints.resume_checkpoint_path / "latest.txt", mode="r") as fi: + # TODO @thomasw21: make a better structure system so that we get typing correct + load_from_candidate = int(fi.read()) + checkpoint_path = config.checkpoints.resume_checkpoint_path / str(load_from_candidate) + + elif (config.checkpoints.resume_checkpoint_path / "model_config.json").exists(): + # we assume that the checkpoint path is a path to a checkpoint + checkpoint_path = config.checkpoints.resume_checkpoint_path + + else: + log_rank( + f"No previous checkpoint found in: {latest_meta_path}", + logger=logger, + level=logging.INFO, + rank=0, + ) + return None - log_rank( - f"Loading checkpoint from {checkpoint_path}", - logger=logger, - level=logging.INFO, - rank=0, - ) - return checkpoint_path + log_rank( + f"Loading checkpoint from {checkpoint_path}", + logger=logger, + level=logging.INFO, + rank=0, + ) + else: + # elif check_path_is_s3(str(load_from_candidate)): + print("I'M S3") + latest_meta_path = config.checkpoints.resume_checkpoint_path / "latest.txt" + if latest_meta_path.exists(): + # if latest.txt exists, we assume that the checkpoint path is a path to a folder containing the checkpoint + with fs_open(latest_meta_path, mode="r") as fi: + latest_iteration = int(fi.read()) + s3_path = config.checkpoints.resume_checkpoint_path / str(latest_iteration) # load_path + checkpoint_path = config.checkpoints.checkpoints_path / str(latest_iteration) # save_path + elif config.checkpoints.resume_checkpoint_path.exists(): + print("INNNN") + # we assume that the checkpoint path is a path to a checkpoint + s3_path = config.checkpoints.resume_checkpoint_path # load_path + checkpoint_path = config.checkpoints.checkpoints_path / load_from_candidate.name # save_path + else: + log_rank( + f"No previous checkpoint found in: {config.checkpoints.resume_checkpoint_path}\n Initializing from scratch.", + logger=logger, + level=logging.WARNING, + rank=0, + ) + return None + log_rank( + f"Downloading checkpoint from S3 in {checkpoint_path} ", + logger=logger, + level=logging.WARNING, + rank=0, + ) + # Download checkpoint from S3 + s3_mover = S3Mover( + local_path=os.path.join(checkpoint_path), + s3_path=os.path.join(s3_path), + s5cmd_numworkers=config.s3_upload.s5cmd_numworkers, + s5cmd_concurrency=config.s3_upload.s5cmd_concurrency, + s5cmd_path=config.s3_upload.s5cmd_path, + 
dummy=bool(int(os.environ.get("LOCAL_RANK", None)) != 0), + ) + s3_mover.distributed_wait_for_completion(config.parallel_context.world_pg) + s3_mover.start_downloading() + s3_mover.distributed_wait_for_completion(config.parallel_context.world_pg) + + # Replace S3 path with local path in config + return checkpoint_path + # else: + # raise Exception(f"{load_from_candidate} should be either a local link or a s3 link.") + # return None \ No newline at end of file diff --git a/src/nanotron/trainer.py b/src/nanotron/trainer.py index 70d023fb..6c54068d 100644 --- a/src/nanotron/trainer.py +++ b/src/nanotron/trainer.py @@ -19,6 +19,8 @@ cast, ) +from nanotron.s3_checkpoints import S3Mover, check_path_is_local +from nanotron.utils import check_path_is_s3 import torch from torch.nn.parallel import DistributedDataParallel from torch.utils.data import DataLoader @@ -149,11 +151,12 @@ def __init__( expert_parallel_size=self.config.parallelism.expert_parallel_size, ) - self.pre_init() - # Set log levels set_ranks_logging_level(parallel_context=self.parallel_context, logging_config=self.config.logging) + self.pre_init() + + # Log benchmark info if os.environ.get("NANOTRON_BENCHMARK", "0") == "1": log_throughput(self.config, self.parallel_context) @@ -255,10 +258,27 @@ def __init__( self.post_init() def pre_init(self): - pass + print("in pre init") + self.init_checkpoint_path = parse_ckpt_path(config=self.config) + print(self.init_checkpoint_path) def post_init(self): - pass + # S3 Mover and save initial state + if self.config.s3_upload is not None: + # Only local rank 0 should upload + dummy = bool(int(os.environ.get("LOCAL_RANK", None)) != 0) + self.s3_mover = S3Mover( + local_path=self.config.checkpoints.checkpoints_path, + s3_path=self.config.s3_upload.upload_s3_path, + # duplicate_checkpoint_path=self.config.checkpoints.resume_checkpoint_path, + remove_after_upload=self.config.s3_upload.remove_after_upload, + s5cmd_numworkers=self.config.s3_upload.s5cmd_numworkers, + s5cmd_concurrency=self.config.s3_upload.s5cmd_concurrency, + s5cmd_path=self.config.s3_upload.s5cmd_path, + dummy=dummy, + ) + else: + self.s3_mover = None def pre_training(self, *args, **kwargs): self._print_training_plan() @@ -281,11 +301,15 @@ def pre_training(self, *args, **kwargs): ) def post_train_step(self): - pass - def post_training(self): - pass + # Update our background upload/removal of checkpoints + if self.s3_mover is not None: + self.s3_mover.update() + def post_training(self): + if self.s3_mover is not None: + self.s3_mover.distributed_wait_for_completion(group=self.parallel_context.world_pg) + def _print_training_plan(self): if hasattr(self.config, "data_stages") and self.config.data_stages is not None: stages_info = "".join( @@ -689,20 +713,22 @@ def _init_model_instance(self) -> NanotronModel: def _load_model_checkpoint(self, model: NanotronModel) -> NanotronModel: unwrapped_model = model.module if isinstance(model, DistributedDataParallel) else model - # Load or initialize model weights - self.init_checkpoint_path = parse_ckpt_path(config=self.config) + # Load or initialize model weights reloaded_from_checkpoint = False if self.init_checkpoint_path is not None: - # Reload from a training checkpoint - log_rank(f"Loading weights from {self.init_checkpoint_path}", logger=logger, level=logging.INFO, rank=0) - self.param_shard_metadata = load_weights( - model=unwrapped_model, parallel_context=self.parallel_context, root_folder=self.init_checkpoint_path - ) - reloaded_from_checkpoint = True + # Load from a pre existing 
checkpoint + if check_path_is_local(self.init_checkpoint_path): + # Reload from a training checkpoint + log_rank(f"Loading weights from {self.init_checkpoint_path}", logger=logger, level=logging.INFO, rank=0) + self.param_shard_metadata = load_weights( + model=unwrapped_model, parallel_context=self.parallel_context, root_folder=self.init_checkpoint_path + ) + reloaded_from_checkpoint=True if not reloaded_from_checkpoint: + # TODO @eliebak add s3 support also here log_rank("No checkpoint path provided.", logger=logger, level=logging.INFO) if isinstance(self.config.model.init_method, ExistingCheckpointInit): - # Initialize model from an pretrained model checkpoint + # Initialize model from an pretrained model checkpoint (without optimizer, lr_scheduler...) self.param_shard_metadata = load_weights( model=unwrapped_model, parallel_context=self.parallel_context, @@ -830,11 +856,18 @@ def setup_log_writers( return loggerwriter - def pre_save_checkpoint(self): - pass + def pre_save_checkpoint(self) -> Path: + if self.s3_mover is not None: + self.s3_mover.distributed_wait_for_completion(self.parallel_context.world_pg) + if self.s3_mover.post_upload_callback_outputs is not None: + slurm_job_id, slurm_log = self.s3_mover.post_upload_callback_outputs + self.log_object({"job_id": slurm_job_id, "log": slurm_log}, "slurm_eval") def post_save_checkpoint(self): - pass + # Upload to S3 + if self.s3_mover is not None: + self.s3_mover.start_uploading() + def save_checkpoint(self) -> Path: self.pre_save_checkpoint() diff --git a/src/nanotron/utils.py b/src/nanotron/utils.py index b3831801..a99289da 100644 --- a/src/nanotron/utils.py +++ b/src/nanotron/utils.py @@ -3,6 +3,7 @@ import os import random import socket +import re from contextlib import ExitStack, contextmanager from typing import ContextManager, List, Optional @@ -160,3 +161,8 @@ def find_free_port(min_port: int = 2000, max_port: int = 65000) -> int: return port except OSError: continue + +def check_path_is_s3(path:str) -> bool: + #TODO maybe replace by a better method ? 
+ s3_pattern = r'^s3://|^https?://[\w\-\.]+\.s3\.amazonaws\.com/|^https?://s3\.amazonaws\.com/[\w\-\.]+' + return bool(re.match(s3_pattern, path)) From e24cce15f736887a9f9a16663b1442256ba6b1dc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9Celiebak=E2=80=9D?= Date: Thu, 22 Aug 2024 23:53:31 +0000 Subject: [PATCH 2/9] fix path to xpath, delete debugging print, working on 1 gpu --- src/nanotron/config/config.py | 6 +++--- src/nanotron/serialize/main.py | 12 ++++-------- src/nanotron/trainer.py | 7 +++---- 3 files changed, 10 insertions(+), 15 deletions(-) diff --git a/src/nanotron/config/config.py b/src/nanotron/config/config.py index 8cdc8715..0744dd69 100644 --- a/src/nanotron/config/config.py +++ b/src/nanotron/config/config.py @@ -163,14 +163,14 @@ class CheckpointsArgs: checkpoint_interval: int save_initial_state: Optional[bool] = False save_final_state: Optional[bool] = False - resume_checkpoint_path: Optional[Path] = None + resume_checkpoint_path: Optional[xPath] = None checkpoints_path_is_shared_file_system: Optional[bool] = False def __post_init__(self): if isinstance(self.checkpoints_path, str): - self.checkpoints_path = Path(self.checkpoints_path) + self.checkpoints_path = xPath(self.checkpoints_path) if isinstance(self.resume_checkpoint_path, str): - self.resume_checkpoint_path = Path(self.resume_checkpoint_path) + self.resume_checkpoint_path = xPath(self.resume_checkpoint_path) @dataclass diff --git a/src/nanotron/serialize/main.py b/src/nanotron/serialize/main.py index da192c94..3c93c69a 100644 --- a/src/nanotron/serialize/main.py +++ b/src/nanotron/serialize/main.py @@ -244,7 +244,7 @@ def load( return checkpoint_metadata -def parse_ckpt_path(config: Config) -> Optional[Path]: +def parse_ckpt_path(config: Config, parallel_context: ParallelContext) -> Optional[Path]: """Parse checkpoint path from config and download checkpoint from S3 if needed. Args: @@ -254,10 +254,8 @@ def parse_ckpt_path(config: Config) -> Optional[Path]: Path to checkpoint or None if no checkpoint. 
""" load_from_candidate = config.checkpoints.resume_checkpoint_path - print(load_from_candidate) if load_from_candidate is not None: if check_path_is_local(load_from_candidate): - print("I'M LOCAL") latest_meta_path: xPath = config.checkpoints.resume_checkpoint_path / "latest.txt" if latest_meta_path.exists(): with fs_open(config.checkpoints.resume_checkpoint_path / "latest.txt", mode="r") as fi: @@ -286,7 +284,6 @@ def parse_ckpt_path(config: Config) -> Optional[Path]: ) else: # elif check_path_is_s3(str(load_from_candidate)): - print("I'M S3") latest_meta_path = config.checkpoints.resume_checkpoint_path / "latest.txt" if latest_meta_path.exists(): # if latest.txt exists, we assume that the checkpoint path is a path to a folder containing the checkpoint @@ -295,7 +292,6 @@ def parse_ckpt_path(config: Config) -> Optional[Path]: s3_path = config.checkpoints.resume_checkpoint_path / str(latest_iteration) # load_path checkpoint_path = config.checkpoints.checkpoints_path / str(latest_iteration) # save_path elif config.checkpoints.resume_checkpoint_path.exists(): - print("INNNN") # we assume that the checkpoint path is a path to a checkpoint s3_path = config.checkpoints.resume_checkpoint_path # load_path checkpoint_path = config.checkpoints.checkpoints_path / load_from_candidate.name # save_path @@ -322,12 +318,12 @@ def parse_ckpt_path(config: Config) -> Optional[Path]: s5cmd_path=config.s3_upload.s5cmd_path, dummy=bool(int(os.environ.get("LOCAL_RANK", None)) != 0), ) - s3_mover.distributed_wait_for_completion(config.parallel_context.world_pg) + s3_mover.distributed_wait_for_completion(parallel_context.world_pg) s3_mover.start_downloading() - s3_mover.distributed_wait_for_completion(config.parallel_context.world_pg) + s3_mover.distributed_wait_for_completion(parallel_context.world_pg) # Replace S3 path with local path in config - return checkpoint_path + return checkpoint_path # else: # raise Exception(f"{load_from_candidate} should be either a local link or a s3 link.") # return None \ No newline at end of file diff --git a/src/nanotron/trainer.py b/src/nanotron/trainer.py index 6c54068d..0a7f57f1 100644 --- a/src/nanotron/trainer.py +++ b/src/nanotron/trainer.py @@ -150,11 +150,12 @@ def __init__( data_parallel_size=self.config.parallelism.dp, expert_parallel_size=self.config.parallelism.expert_parallel_size, ) + + self.pre_init() # Set log levels set_ranks_logging_level(parallel_context=self.parallel_context, logging_config=self.config.logging) - self.pre_init() # Log benchmark info @@ -258,9 +259,7 @@ def __init__( self.post_init() def pre_init(self): - print("in pre init") - self.init_checkpoint_path = parse_ckpt_path(config=self.config) - print(self.init_checkpoint_path) + self.init_checkpoint_path = parse_ckpt_path(config=self.config, parallel_context=self.parallel_context) def post_init(self): # S3 Mover and save initial state From d5a36507c888589ee0525c0831f503e298a1b519 Mon Sep 17 00:00:00 2001 From: elie <97572401+eliebak@users.noreply.github.com> Date: Fri, 23 Aug 2024 05:45:31 +0200 Subject: [PATCH 3/9] fixing comment --- src/nanotron/serialize/main.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/nanotron/serialize/main.py b/src/nanotron/serialize/main.py index 3c93c69a..4365504b 100644 --- a/src/nanotron/serialize/main.py +++ b/src/nanotron/serialize/main.py @@ -283,7 +283,6 @@ def parse_ckpt_path(config: Config, parallel_context: ParallelContext) -> Option rank=0, ) else: - # elif check_path_is_s3(str(load_from_candidate)): latest_meta_path = 
config.checkpoints.resume_checkpoint_path / "latest.txt" if latest_meta_path.exists(): # if latest.txt exists, we assume that the checkpoint path is a path to a folder containing the checkpoint @@ -326,4 +325,4 @@ def parse_ckpt_path(config: Config, parallel_context: ParallelContext) -> Option return checkpoint_path # else: # raise Exception(f"{load_from_candidate} should be either a local link or a s3 link.") - # return None \ No newline at end of file + # return None From a62daad39bbfbbee67a0e71e2a89eaa046263b10 Mon Sep 17 00:00:00 2001 From: elie <97572401+eliebak@users.noreply.github.com> Date: Fri, 23 Aug 2024 05:46:12 +0200 Subject: [PATCH 4/9] fixing comment --- src/nanotron/serialize/main.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/src/nanotron/serialize/main.py b/src/nanotron/serialize/main.py index 4365504b..e9ed2572 100644 --- a/src/nanotron/serialize/main.py +++ b/src/nanotron/serialize/main.py @@ -321,8 +321,4 @@ def parse_ckpt_path(config: Config, parallel_context: ParallelContext) -> Option s3_mover.start_downloading() s3_mover.distributed_wait_for_completion(parallel_context.world_pg) - # Replace S3 path with local path in config return checkpoint_path - # else: - # raise Exception(f"{load_from_candidate} should be either a local link or a s3 link.") - # return None From 17468b22c21b133308ae0c0111c09fc28b9b12ed Mon Sep 17 00:00:00 2001 From: elie <97572401+eliebak@users.noreply.github.com> Date: Fri, 23 Aug 2024 05:47:23 +0200 Subject: [PATCH 5/9] no need for check_path_is_s3 anymore --- src/nanotron/utils.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/src/nanotron/utils.py b/src/nanotron/utils.py index a99289da..cb187f77 100644 --- a/src/nanotron/utils.py +++ b/src/nanotron/utils.py @@ -161,8 +161,3 @@ def find_free_port(min_port: int = 2000, max_port: int = 65000) -> int: return port except OSError: continue - -def check_path_is_s3(path:str) -> bool: - #TODO maybe replace by a better method ? 
- s3_pattern = r'^s3://|^https?://[\w\-\.]+\.s3\.amazonaws\.com/|^https?://s3\.amazonaws\.com/[\w\-\.]+' - return bool(re.match(s3_pattern, path)) From 5f95a4f08f9dfd7299c689b76c2eb4c213a8cc21 Mon Sep 17 00:00:00 2001 From: elie <97572401+eliebak@users.noreply.github.com> Date: Fri, 23 Aug 2024 17:03:26 +0200 Subject: [PATCH 6/9] no need for check_path_is_s3 anymore --- src/nanotron/trainer.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/nanotron/trainer.py b/src/nanotron/trainer.py index 0a7f57f1..5626de1e 100644 --- a/src/nanotron/trainer.py +++ b/src/nanotron/trainer.py @@ -20,7 +20,6 @@ ) from nanotron.s3_checkpoints import S3Mover, check_path_is_local -from nanotron.utils import check_path_is_s3 import torch from torch.nn.parallel import DistributedDataParallel from torch.utils.data import DataLoader From 2651a178a4e1433c5227e66c2832d18d70259743 Mon Sep 17 00:00:00 2001 From: elie <97572401+eliebak@users.noreply.github.com> Date: Sat, 24 Aug 2024 19:06:32 +0200 Subject: [PATCH 7/9] remove debugging print --- src/nanotron/helpers.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/nanotron/helpers.py b/src/nanotron/helpers.py index 892ac03c..761fffc2 100644 --- a/src/nanotron/helpers.py +++ b/src/nanotron/helpers.py @@ -52,7 +52,6 @@ def _vocab_size_with_padding(orig_vocab_size: int, pg_size: int, make_vocab_size multiple = make_vocab_size_divisible_by * pg_size after = int(ceil(orig_vocab_size / multiple) * multiple) - print("hello") if after != orig_vocab_size: print("i'm in") log_rank( From ab04877fa19d8f1565338bf09569c3b1fe970285 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9Celiebak=E2=80=9D?= Date: Wed, 4 Sep 2024 07:28:34 +0000 Subject: [PATCH 8/9] add datasets by default + s3 flavor --- pyproject.toml | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/pyproject.toml b/pyproject.toml index 6a0cfb83..9794ab78 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -21,6 +21,7 @@ dependencies = [ "safetensors", "dacite", "tqdm", + "datasets", ] [tool.setuptools.packages.find] @@ -53,6 +54,12 @@ nanosets = [ "numba", ] +s3 = [ + "boto3", + "s3fs", + "s5cmd", +] + [build-system] requires = [ "setuptools", From b17daf3b24c4ad47e85fd0c4239a80980b7c5051 Mon Sep 17 00:00:00 2001 From: Phuc Nguyen Date: Mon, 9 Sep 2024 16:32:42 +0000 Subject: [PATCH 9/9] minor changes and add an example config for s3 --- README.md | 1 + .../config_tiny_llama_with_s3_upload.yaml | 115 ++++++++++++++++++ src/nanotron/helpers.py | 9 +- src/nanotron/utils.py | 1 - 4 files changed, 119 insertions(+), 7 deletions(-) create mode 100644 examples/config_tiny_llama_with_s3_upload.yaml diff --git a/README.md b/README.md index 7a22f12a..6bdc917d 100644 --- a/README.md +++ b/README.md @@ -75,6 +75,7 @@ You can find more examples in the [`/examples`](/examples) directory: | `mamba` | Train an example Mamba model | | `moe` | Train an example Mixture-of-Experts (MoE) model | | `mup` | Use spectral µTransfer to scale up your model | +| `examples/config_tiny_llama_with_s3_upload.yaml` | For automatically uploading checkpoints to S3 | We're working on adding more examples soon! Feel free to add a PR to add your own example. 
🚀 diff --git a/examples/config_tiny_llama_with_s3_upload.yaml b/examples/config_tiny_llama_with_s3_upload.yaml new file mode 100644 index 00000000..d71cba73 --- /dev/null +++ b/examples/config_tiny_llama_with_s3_upload.yaml @@ -0,0 +1,115 @@ +checkpoints: + checkpoint_interval: 10 + checkpoints_path: checkpoints + checkpoints_path_is_shared_file_system: false + resume_checkpoint_path: s3://phuc-experiments/temp/config_tiny_llama_with_s3_upload + save_initial_state: false +data_stages: +- data: + dataset: + dataset_overwrite_cache: false + dataset_processing_num_proc_per_process: 1 + hf_dataset_config_name: null + hf_dataset_or_datasets: stas/openwebtext-10k + hf_dataset_splits: train + text_column_name: text + num_loading_workers: 1 + seed: 42 + name: Stable Training Stage + start_training_step: 1 +- data: + dataset: + dataset_overwrite_cache: false + dataset_processing_num_proc_per_process: 1 + hf_dataset_config_name: null + hf_dataset_or_datasets: stas/openwebtext-10k + hf_dataset_splits: train + text_column_name: text + num_loading_workers: 1 + seed: 42 + name: Annealing Phase + start_training_step: 10 +general: + benchmark_csv_path: null + consumed_train_samples: null + ignore_sanity_checks: true + project: debug + run: tiny_llama_%date_%jobid + seed: 42 + step: null +lighteval: null +logging: + iteration_step_info_interval: 1 + log_level: info + log_level_replica: info +model: + ddp_bucket_cap_mb: 25 + dtype: bfloat16 + init_method: + std: 0.025 + make_vocab_size_divisible_by: 1 + model_config: + bos_token_id: 1 + eos_token_id: 2 + hidden_act: silu + hidden_size: 16 + initializer_range: 0.02 + intermediate_size: 64 + is_llama_config: true + max_position_embeddings: 256 + num_attention_heads: 4 + num_hidden_layers: 2 + num_key_value_heads: 4 + pad_token_id: null + pretraining_tp: 1 + rms_norm_eps: 1.0e-05 + rope_scaling: null + tie_word_embeddings: true + use_cache: true + vocab_size: 256 +optimizer: + accumulate_grad_in_fp32: true + clip_grad: 1.0 + learning_rate_scheduler: + learning_rate: 0.0003 + lr_decay_starting_step: null + lr_decay_steps: 13 + lr_decay_style: cosine + lr_warmup_steps: 2 + lr_warmup_style: linear + min_decay_lr: 1.0e-05 + optimizer_factory: + adam_beta1: 0.9 + adam_beta2: 0.95 + adam_eps: 1.0e-08 + name: adamW + torch_adam_is_fused: true + weight_decay: 0.01 + zero_stage: 0 +parallelism: + dp: 1 + expert_parallel_size: 1 + pp: 1 + pp_engine: 1f1b + tp: 1 + tp_linear_async_communication: true + tp_mode: REDUCE_SCATTER +profiler: null +tokenizer: + tokenizer_max_length: null + tokenizer_name_or_path: robot-test/dummy-tokenizer-wordlevel + tokenizer_revision: null +tokens: + batch_accumulation_per_replica: 1 + limit_test_batches: 0 + limit_val_batches: 0 + micro_batch_size: 2 + sequence_length: 256 + train_steps: 30 + val_check_interval: -1 +s3_upload: + remove_after_upload: true + s5cmd_concurrency: 5 + s5cmd_numworkers: 16 + s5cmd_path: /fsx/nouamane/miniconda/envs/2-1-cu121/bin/s5cmd + upload_s3_path: s3://phuc-experiments/temp/config_tiny_llama_with_s3_upload diff --git a/src/nanotron/helpers.py b/src/nanotron/helpers.py index 761fffc2..fc41bfb3 100644 --- a/src/nanotron/helpers.py +++ b/src/nanotron/helpers.py @@ -53,7 +53,6 @@ def _vocab_size_with_padding(orig_vocab_size: int, pg_size: int, make_vocab_size multiple = make_vocab_size_divisible_by * pg_size after = int(ceil(orig_vocab_size / multiple) * multiple) if after != orig_vocab_size: - print("i'm in") log_rank( f"[Vocab Size Padding] Padded vocab (size: {orig_vocab_size}) with {after - orig_vocab_size} 
dummy tokens (new size: {after})", logger=logger, @@ -147,10 +146,8 @@ def lr_lambda(current_step: int, initial_lr: float): / lr_decay_steps ) elif lr_scheduler_args.lr_decay_style == "1-sqrt": - lmbda = ( - lr_scheduler_args.min_decay_lr - + (initial_lr - lr_scheduler_args.min_decay_lr) - * (1 - math.sqrt((current_step - lr_decay_starting_step) / lr_decay_steps)) + lmbda = lr_scheduler_args.min_decay_lr + (initial_lr - lr_scheduler_args.min_decay_lr) * ( + 1 - math.sqrt((current_step - lr_decay_starting_step) / lr_decay_steps) ) else: raise ValueError(f"Unknown decay style {lr_scheduler_args.lr_decay_style}") @@ -693,7 +690,7 @@ def is_resume_from_training(): else: next_stage = next((s for s in config.data_stages if s.start_training_step > stage.start_training_step), None) total_train_steps = next_stage.start_training_step - + if metadata.last_train_step > stage.start_training_step: # NOTE: if the last_train_step is larger than the start_training_step of the current stage, # it means that the training has already passed this stage diff --git a/src/nanotron/utils.py b/src/nanotron/utils.py index cb187f77..b3831801 100644 --- a/src/nanotron/utils.py +++ b/src/nanotron/utils.py @@ -3,7 +3,6 @@ import os import random import socket -import re from contextlib import ExitStack, contextmanager from typing import ContextManager, List, Optional
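
For reference, here is a minimal usage sketch of the S3Mover class that this series introduces, mirroring the docstring and constructor in src/nanotron/s3_checkpoints/s3_mover.py. The bucket path, scratch directory, s5cmd binary location, worker counts, and the "100" sub-folder below are placeholder values, and the LOCAL_RANK check simply mirrors the pattern used in trainer.py; adapt them to your setup. This is an illustration of the API added by the patches, not part of the patch series itself.

    import os

    from nanotron.s3_checkpoints import S3Mover

    # Only one process per node should actually upload; the others run in dummy mode
    # (same LOCAL_RANK convention as in trainer.py, with a "0" fallback for single-process runs).
    is_local_rank_zero = int(os.environ.get("LOCAL_RANK", "0")) == 0

    mover = S3Mover(
        local_path="/scratch/my-checkpoints",      # placeholder local checkpoint directory
        s3_path="s3://my-bucket/my-checkpoints",   # placeholder S3 destination
        remove_after_upload=True,                  # delete the local copy once it is on S3
        s5cmd_numworkers=96,
        s5cmd_concurrency=10,
        s5cmd_path="/path/to/s5cmd",               # placeholder; point at your s5cmd binary
        dummy=not is_local_rank_zero,
    )

    # After saving a checkpoint under local_path/<step>, start a background upload.
    mover.start_uploading(sub_folder="100")

    # Poll from the training loop so the state machine advances
    # (UPLOADING -> REMOVING_CHECKPOINT -> IDLE) and stdout is drained.
    state, stdout = mover.update()

    # Before the next save (or at shutdown), block until the previous upload
    # and removal have finished.
    mover.wait_for_completion()

    # In multi-process training, distributed_wait_for_completion(group) is used
    # instead so that every rank waits for the upload together.
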