Skip to content

Commit

Permalink
Fix Path to xPath, delete debugging prints, working on 1 GPU
Browse files Browse the repository at this point in the history
  • Loading branch information
eliebak committed Aug 22, 2024
1 parent 5fd1c07 commit e24cce1
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 15 deletions.
6 changes: 3 additions & 3 deletions src/nanotron/config/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,14 +163,14 @@ class CheckpointsArgs:
checkpoint_interval: int
save_initial_state: Optional[bool] = False
save_final_state: Optional[bool] = False
resume_checkpoint_path: Optional[Path] = None
resume_checkpoint_path: Optional[xPath] = None
checkpoints_path_is_shared_file_system: Optional[bool] = False

def __post_init__(self):
if isinstance(self.checkpoints_path, str):
self.checkpoints_path = Path(self.checkpoints_path)
self.checkpoints_path = xPath(self.checkpoints_path)
if isinstance(self.resume_checkpoint_path, str):
self.resume_checkpoint_path = Path(self.resume_checkpoint_path)
self.resume_checkpoint_path = xPath(self.resume_checkpoint_path)


@dataclass
Expand Down
12 changes: 4 additions & 8 deletions src/nanotron/serialize/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,7 @@ def load(
return checkpoint_metadata


def parse_ckpt_path(config: Config) -> Optional[Path]:
def parse_ckpt_path(config: Config, parallel_context: ParallelContext) -> Optional[Path]:
"""Parse checkpoint path from config and download checkpoint from S3 if needed.
Args:
Expand All @@ -254,10 +254,8 @@ def parse_ckpt_path(config: Config) -> Optional[Path]:
Path to checkpoint or None if no checkpoint.
"""
load_from_candidate = config.checkpoints.resume_checkpoint_path
print(load_from_candidate)
if load_from_candidate is not None:
if check_path_is_local(load_from_candidate):
print("I'M LOCAL")
latest_meta_path: xPath = config.checkpoints.resume_checkpoint_path / "latest.txt"
if latest_meta_path.exists():
with fs_open(config.checkpoints.resume_checkpoint_path / "latest.txt", mode="r") as fi:
Expand Down Expand Up @@ -286,7 +284,6 @@ def parse_ckpt_path(config: Config) -> Optional[Path]:
)
else:
# elif check_path_is_s3(str(load_from_candidate)):
print("I'M S3")
latest_meta_path = config.checkpoints.resume_checkpoint_path / "latest.txt"
if latest_meta_path.exists():
# if latest.txt exists, we assume that the checkpoint path is a path to a folder containing the checkpoint
Expand All @@ -295,7 +292,6 @@ def parse_ckpt_path(config: Config) -> Optional[Path]:
s3_path = config.checkpoints.resume_checkpoint_path / str(latest_iteration) # load_path
checkpoint_path = config.checkpoints.checkpoints_path / str(latest_iteration) # save_path
elif config.checkpoints.resume_checkpoint_path.exists():
print("INNNN")
# we assume that the checkpoint path is a path to a checkpoint
s3_path = config.checkpoints.resume_checkpoint_path # load_path
checkpoint_path = config.checkpoints.checkpoints_path / load_from_candidate.name # save_path
Expand All @@ -322,12 +318,12 @@ def parse_ckpt_path(config: Config) -> Optional[Path]:
s5cmd_path=config.s3_upload.s5cmd_path,
dummy=bool(int(os.environ.get("LOCAL_RANK", None)) != 0),
)
s3_mover.distributed_wait_for_completion(config.parallel_context.world_pg)
s3_mover.distributed_wait_for_completion(parallel_context.world_pg)
s3_mover.start_downloading()
s3_mover.distributed_wait_for_completion(config.parallel_context.world_pg)
s3_mover.distributed_wait_for_completion(parallel_context.world_pg)

# Replace S3 path with local path in config
return checkpoint_path
return checkpoint_path
# else:
# raise Exception(f"{load_from_candidate} should be either a local link or a s3 link.")
# return None
7 changes: 3 additions & 4 deletions src/nanotron/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,11 +150,12 @@ def __init__(
data_parallel_size=self.config.parallelism.dp,
expert_parallel_size=self.config.parallelism.expert_parallel_size,
)

self.pre_init()

# Set log levels
set_ranks_logging_level(parallel_context=self.parallel_context, logging_config=self.config.logging)

self.pre_init()


# Log benchmark info
Expand Down Expand Up @@ -258,9 +259,7 @@ def __init__(
self.post_init()

def pre_init(self):
print("in pre init")
self.init_checkpoint_path = parse_ckpt_path(config=self.config)
print(self.init_checkpoint_path)
self.init_checkpoint_path = parse_ckpt_path(config=self.config, parallel_context=self.parallel_context)

def post_init(self):
# S3 Mover and save initial state
Expand Down

0 comments on commit e24cce1

Please sign in to comment.