
Commit b8d5df7

Update

[ghstack-poisoned]

wconstab committed Jun 4, 2024
2 parents fa10c10 + 4dd7103
Showing 5 changed files with 36 additions and 6 deletions.
4 changes: 4 additions & 0 deletions test_runner.py
@@ -276,6 +276,10 @@ def run_tests(args):
f"Skipping test {test_flavor.test_name} that requires {test_flavor.ngpu} gpus,"
f" because --ngpu arg is {args.ngpu}"
)
elif args.ngpu == 8 and test_flavor.ngpu != 8:
logger.info(
f"Skipping non-8gpu test {test_flavor.test_name} on 8-gpu runner"
)
else:
run_test(test_flavor, full_path, args.output_dir)

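For reference, a minimal self-contained sketch of the selection rule this hunk puts in place; the TestFlavor dataclass and the sample flavors are hypothetical stand-ins, and only the two skip conditions come from the diff above.

# Hedged sketch: which tests an N-GPU runner would execute under the new gating.
from dataclasses import dataclass

@dataclass
class TestFlavor:  # hypothetical stand-in for the runner's test definition
    test_name: str
    ngpu: int

def should_run(flavor: TestFlavor, available_ngpu: int) -> bool:
    if flavor.ngpu > available_ngpu:
        return False  # not enough GPUs on this runner
    if available_ngpu == 8 and flavor.ngpu != 8:
        return False  # the 8-GPU runner only runs tests that need all 8 GPUs
    return True

flavors = [TestFlavor("2d_parallel", 4), TestFlavor("3d_parallel", 8)]
print([f.test_name for f in flavors if should_run(f, available_ngpu=8)])  # ['3d_parallel']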
18 changes: 18 additions & 0 deletions torchtitan/checkpoint.py
@@ -7,6 +7,7 @@
import enum
import os
import re
import shutil
import time
from multiprocessing import get_context
from typing import Any, Dict
@@ -110,6 +111,7 @@ def __init__(
) -> None:
ckpt_config = job_config.checkpoint
self.enable_checkpoint = ckpt_config.enable_checkpoint
self.keep_latest_k = ckpt_config.keep_latest_k

if not self.enable_checkpoint:
return
@@ -313,6 +315,7 @@ def save(self, curr_step: int, force: bool = False) -> None:
else:
dcp.save(self.states, checkpoint_id=checkpoint_id)
self.reset()
self._purge_stale_checkpoints()

logger.info(
"Finished saving the checkpoint (or staging if async is enabled)"
@@ -364,3 +367,18 @@ def load(self, step: int = -1) -> bool:
f"Finished loading the checkpoint in {time.monotonic() - begin:.2f} seconds."
)
return True

def _purge_stale_checkpoints(self):
if self.keep_latest_k > 0:
discovered_checkpoints = []
for filename in os.listdir(self.folder):
match = re.search(r"step-(\d+)", filename)
path = os.path.join(self.folder, filename)
discovered_checkpoints.append((int(match.group(1)), path))

discovered_checkpoints.sort()
to_delete = discovered_checkpoints[: -1 * self.keep_latest_k]

for _, path in to_delete:
logger.info(f"Deleting old checkpoint {path}")
shutil.rmtree(path, ignore_errors=True)
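
A runnable, self-contained sketch of the purge rule above, exercised on a throwaway directory; the if-match guard is an extra safety check for unrelated entries, which the committed code assumes are not present in the checkpoint folder.

# Hedged sketch: keep only the newest k step-<n> checkpoints in a folder.
import os
import re
import shutil
import tempfile

folder = tempfile.mkdtemp()
for step in (100, 200, 300, 400):
    os.makedirs(os.path.join(folder, f"step-{step}"))

keep_latest_k = 2
discovered_checkpoints = []
for filename in os.listdir(folder):
    match = re.search(r"step-(\d+)", filename)
    if match:  # skip unrelated entries; the diff assumes none exist
        discovered_checkpoints.append((int(match.group(1)), os.path.join(folder, filename)))

discovered_checkpoints.sort()
for _, path in discovered_checkpoints[: -1 * keep_latest_k]:  # everything but the newest k
    shutil.rmtree(path, ignore_errors=True)

print(sorted(os.listdir(folder)))  # ['step-300', 'step-400']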
10 changes: 9 additions & 1 deletion torchtitan/config_manager.py
@@ -408,7 +408,15 @@ def __init__(self):
"disabled" is the default mode.
""",
)

self.parser.add_argument(
"--checkpoint.keep_latest_k",
type=int,
default=0,
help="""
Keeps only the latest k checkpoints, purging older ones. If 0, all checkpoints are kept.
0 is the default value.
""",
)
# activation checkpointing configs
self.parser.add_argument(
"--activation_checkpoint.mode",
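As a usage note, the new option would be passed as --checkpoint.keep_latest_k 3 on the training command line (the flag name comes from the diff; the value 3 is only an example). Below is a small, hedged sketch with plain argparse, not torchtitan's JobConfig, showing how such a dotted flag parses.

# Hedged sketch: reading a dotted checkpoint flag with vanilla argparse.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--checkpoint.keep_latest_k", type=int, default=0)
args = parser.parse_args(["--checkpoint.keep_latest_k", "3"])

# argparse keeps the dot in the destination name, so getattr is needed here.
keep_latest_k = getattr(args, "checkpoint.keep_latest_k")
print(keep_latest_k)  # 3; the default 0 means every checkpoint is kept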
7 changes: 3 additions & 4 deletions torchtitan/parallelisms/parallelize_llama.py
@@ -190,16 +190,15 @@ def pipeline_llama_manual(
splits = job_config.experimental.pipeline_parallel_split_points
start_layer = splits[stage_idx - 1] if stage_idx > 0 else None
stop_layer = splits[stage_idx] if stage_idx < pp_size - 1 else None

if pp_rank > 0:
model.tok_embeddings = None

drop_layers = True
drop_layers = start_layer is not None
for name in list(model.layers.keys()):
# we keep layers in a contiguous region between start (inclusive) and stop (exclusive)
if start_layer is None or f"layers.{name}" == start_layer:
if f"layers.{name}" == start_layer:
drop_layers = False
if stop_layer is not None and f"layers.{name}" == stop_layer:
if f"layers.{name}" == stop_layer:
drop_layers = True
if drop_layers:
del model.layers[name]
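To make the revised drop_layers logic concrete (keep a contiguous block of layers from the start split, inclusive, to the stop split, exclusive; the first stage has no start split and keeps from the top), here is a self-contained sketch on a toy layer dict; the layer names and split points are made up.

# Hedged sketch of the keep-between-splits rule, applied to a dict that
# stands in for model.layers.
def prune_layers(layers: dict, start_layer, stop_layer) -> dict:
    kept = dict(layers)
    drop_layers = start_layer is not None  # the first stage starts in "keep" mode
    for name in list(kept.keys()):
        if f"layers.{name}" == start_layer:
            drop_layers = False
        if f"layers.{name}" == stop_layer:
            drop_layers = True
        if drop_layers:
            del kept[name]
    return kept

layers = {str(i): f"block{i}" for i in range(8)}
# Middle stage of a 3-way split at layers.3 and layers.6 keeps layers 3, 4, 5.
print(list(prune_layers(layers, "layers.3", "layers.6")))  # ['3', '4', '5']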
3 changes: 2 additions & 1 deletion train.py
@@ -221,7 +221,8 @@ def loss_fn(pred, labels):
model, world_mesh, parallel_dims, job_config
)

model.to_empty(device="cuda")
init_device = "cpu" if job_config.checkpoint.create_seed_checkpoint else "cuda"
model.to_empty(device=init_device)

if parallel_dims.pp_enabled:
pp_schedule = build_pipeline_schedule(job_config, parallel_dims, stage, loss_fn)
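For context on the change above, a hedged standalone illustration of to_empty(): a module built on the meta device is materialized with uninitialized storage on the chosen device, presumably so a seed checkpoint can be produced without a GPU.

# Hedged sketch: meta-device construction followed by to_empty(), mirroring
# the cpu-vs-cuda choice made for seed checkpoints.
import torch

with torch.device("meta"):
    m = torch.nn.Linear(4, 4)

init_device = "cpu"  # a normal training run would use "cuda" here
m = m.to_empty(device=init_device)
print(m.weight.device)  # cpu (values stay uninitialized until a real init or load)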
