diff --git a/src/nanotron/config/config.py b/src/nanotron/config/config.py index 8cdc8715..0744dd69 100644 --- a/src/nanotron/config/config.py +++ b/src/nanotron/config/config.py @@ -163,14 +163,14 @@ class CheckpointsArgs: checkpoint_interval: int save_initial_state: Optional[bool] = False save_final_state: Optional[bool] = False - resume_checkpoint_path: Optional[Path] = None + resume_checkpoint_path: Optional[xPath] = None checkpoints_path_is_shared_file_system: Optional[bool] = False def __post_init__(self): if isinstance(self.checkpoints_path, str): - self.checkpoints_path = Path(self.checkpoints_path) + self.checkpoints_path = xPath(self.checkpoints_path) if isinstance(self.resume_checkpoint_path, str): - self.resume_checkpoint_path = Path(self.resume_checkpoint_path) + self.resume_checkpoint_path = xPath(self.resume_checkpoint_path) @dataclass diff --git a/src/nanotron/serialize/main.py b/src/nanotron/serialize/main.py index da192c94..3c93c69a 100644 --- a/src/nanotron/serialize/main.py +++ b/src/nanotron/serialize/main.py @@ -244,7 +244,7 @@ def load( return checkpoint_metadata -def parse_ckpt_path(config: Config) -> Optional[Path]: +def parse_ckpt_path(config: Config, parallel_context: ParallelContext) -> Optional[Path]: """Parse checkpoint path from config and download checkpoint from S3 if needed. Args: @@ -254,10 +254,8 @@ def parse_ckpt_path(config: Config) -> Optional[Path]: Path to checkpoint or None if no checkpoint. """ load_from_candidate = config.checkpoints.resume_checkpoint_path - print(load_from_candidate) if load_from_candidate is not None: if check_path_is_local(load_from_candidate): - print("I'M LOCAL") latest_meta_path: xPath = config.checkpoints.resume_checkpoint_path / "latest.txt" if latest_meta_path.exists(): with fs_open(config.checkpoints.resume_checkpoint_path / "latest.txt", mode="r") as fi: @@ -286,7 +284,6 @@ def parse_ckpt_path(config: Config) -> Optional[Path]: ) else: # elif check_path_is_s3(str(load_from_candidate)): - print("I'M S3") latest_meta_path = config.checkpoints.resume_checkpoint_path / "latest.txt" if latest_meta_path.exists(): # if latest.txt exists, we assume that the checkpoint path is a path to a folder containing the checkpoint @@ -295,7 +292,6 @@ def parse_ckpt_path(config: Config) -> Optional[Path]: s3_path = config.checkpoints.resume_checkpoint_path / str(latest_iteration) # load_path checkpoint_path = config.checkpoints.checkpoints_path / str(latest_iteration) # save_path elif config.checkpoints.resume_checkpoint_path.exists(): - print("INNNN") # we assume that the checkpoint path is a path to a checkpoint s3_path = config.checkpoints.resume_checkpoint_path # load_path checkpoint_path = config.checkpoints.checkpoints_path / load_from_candidate.name # save_path @@ -322,12 +318,12 @@ def parse_ckpt_path(config: Config) -> Optional[Path]: s5cmd_path=config.s3_upload.s5cmd_path, dummy=bool(int(os.environ.get("LOCAL_RANK", None)) != 0), ) - s3_mover.distributed_wait_for_completion(config.parallel_context.world_pg) + s3_mover.distributed_wait_for_completion(parallel_context.world_pg) s3_mover.start_downloading() - s3_mover.distributed_wait_for_completion(config.parallel_context.world_pg) + s3_mover.distributed_wait_for_completion(parallel_context.world_pg) # Replace S3 path with local path in config - return checkpoint_path + return checkpoint_path # else: # raise Exception(f"{load_from_candidate} should be either a local link or a s3 link.") # return None \ No newline at end of file diff --git a/src/nanotron/trainer.py b/src/nanotron/trainer.py index 6c54068d..0a7f57f1 100644 --- a/src/nanotron/trainer.py +++ b/src/nanotron/trainer.py @@ -150,11 +150,12 @@ def __init__( data_parallel_size=self.config.parallelism.dp, expert_parallel_size=self.config.parallelism.expert_parallel_size, ) + + self.pre_init() # Set log levels set_ranks_logging_level(parallel_context=self.parallel_context, logging_config=self.config.logging) - self.pre_init() # Log benchmark info @@ -258,9 +259,7 @@ def __init__( self.post_init() def pre_init(self): - print("in pre init") - self.init_checkpoint_path = parse_ckpt_path(config=self.config) - print(self.init_checkpoint_path) + self.init_checkpoint_path = parse_ckpt_path(config=self.config, parallel_context=self.parallel_context) def post_init(self): # S3 Mover and save initial state