From 896b7cabe11a150148020fadd694f6333e9a8ce6 Mon Sep 17 00:00:00 2001
From: Tianyu Liu
Date: Tue, 17 Dec 2024 15:52:27 -0800
Subject: [PATCH] [BE] add integration test for the generation script

ghstack-source-id: 42013cd4bc32a60e5c79c4b3ed976584eb17efad
Pull Request resolved: https://github.com/pytorch/torchtitan/pull/741
---
 scripts/generate/test_generate.py | 21 ++++++++++--------
 tests/integration_tests.py        | 37 +++++++++++++++++++++++++------
 torchtitan/utils.py               | 18 ++++++++++-----
 train.py                          |  4 +++-
 4 files changed, 57 insertions(+), 23 deletions(-)

diff --git a/scripts/generate/test_generate.py b/scripts/generate/test_generate.py
index f46c0967..210bd673 100644
--- a/scripts/generate/test_generate.py
+++ b/scripts/generate/test_generate.py
@@ -52,7 +52,7 @@ def apply_tp_minus_sp(model: nn.Module, tp_mesh: DeviceMesh):
         },
     )
 
-    for layer_id, transformer_block in model.layers.items():
+    for _, transformer_block in model.layers.items():
         layer_plan = {
             "attention.wq": ColwiseParallel(),
             "attention.wk": ColwiseParallel(),
@@ -81,6 +81,7 @@ def test_generate(
     batch_size: int = 1,
     top_k: Optional[int] = None,
     seed: Optional[int] = None,
+    deterministic: bool = False,
 ):
     init_logger()
     color = utils.Color
@@ -95,13 +96,6 @@ def test_generate(
             "The input prompt is empty; the model will respond from an empty sequence."
         )
 
-    utils.set_determinism(seed)
-
-    if seed is None:
-        logger.info("Deterministic sampling off")
-    else:
-        logger.info(f"Deterministic sampling on. Using seed: {seed}")
-
     world_size = int(os.environ.get("WORLD_SIZE", 1))
     local_rank = int(os.environ.get("LOCAL_RANK", 0))
     device = torch.device(f"{device_type}:{local_rank}")
@@ -128,6 +122,7 @@ def test_generate(
     logger.info(f"Init model on init_device: {init_device}")
     model = model_cls.from_model_args(model_config)
 
+    world_mesh = None
     # Init distributed env
    if world_size > 1:
         utils.init_distributed(config)
@@ -147,6 +142,8 @@ def test_generate(
         # sequences would require https://github.com/pytorch/torchtitan/pull/686
         apply_tp_minus_sp(model, world_mesh["tp"])
 
+    utils.set_determinism(world_mesh, device, seed, deterministic)
+
     # materialize model
     model.to_empty(device=device_type)
     model.eval()
@@ -276,8 +273,13 @@ def test_generate(
         "--top_k", type=int, help="Prune to select from top_k probabilities. Optional"
     )
     parser.add_argument("--seed", type=int, help="Random seed for reproducibility")
+    parser.add_argument(
+        "--deterministic",
+        action="store_true",
+        help="Use deterministic algorithms wherever possible, may be slower",
+    )
 
-    parser.add_argument("--prompt", type=str, help="Input prompt")
+    parser.add_argument("--prompt", type=str, default="", help="Input prompt")
 
     parser.add_argument(
         "--out",
@@ -297,6 +299,7 @@ def test_generate(
         batch_size=args.batch_size,
         top_k=args.top_k,
         seed=args.seed,
+        deterministic=args.deterministic,
     )
 
     if torch.distributed.is_initialized():
diff --git a/tests/integration_tests.py b/tests/integration_tests.py
index a81e2a68..ca58e4d2 100755
--- a/tests/integration_tests.py
+++ b/tests/integration_tests.py
@@ -362,6 +362,19 @@ def build_test_list():
             "fsdp+tp+cp",
             ngpu=8,
         ),
+        OverrideDefinitions(
+            [
+                [
+                    "--checkpoint.enable_checkpoint",
+                    "--experimental.pipeline_parallel_degree 2",
+                    "--training.enable_cpu_offload True",
+                    "--optimizer.early_step_in_backward",
+                ],
+            ],
+            "Enable CPU Offload with PP",
+            "enable_cpu_offload+PP",
+            ngpu=4,
+        ),
         OverrideDefinitions(
             [
                 [
@@ -376,14 +389,14 @@ def build_test_list():
             [
                 [
                     "--checkpoint.enable_checkpoint",
-                    "--experimental.pipeline_parallel_degree 2",
-                    "--training.enable_cpu_offload True",
-                    "--optimizer.early_step_in_backward",
+                ],
+                [
+                    # placeholder for the generation script's generate step
                 ],
             ],
-            "Enable CPU Offload with PP",
-            "enable_cpu_offload+PP",
-            ngpu=4,
+            "Generation script test",
+            "test_generate",
+            ngpu=2,
         ),
     ]
     return integration_tests_flavors
@@ -406,7 +419,7 @@ def run_test(test_flavor: OverrideDefinitions, full_path: str, output_dir: str):
     model_flavor_arg = f"--model.flavor {test_flavor.model_flavor}"
     all_ranks = ",".join(map(str, range(test_flavor.ngpu)))
 
-    for override_arg in test_flavor.override_args:
+    for idx, override_arg in enumerate(test_flavor.override_args):
         cmd = f"CONFIG_FILE={full_path} NGPU={test_flavor.ngpu} LOG_RANK={all_ranks} ./run_llama_train.sh"
         # dump compile trace for debugging purposes
         cmd = f'TORCH_TRACE="{output_dir}/{test_name}/compile_trace" ' + cmd
@@ -422,6 +435,16 @@ def run_test(test_flavor: OverrideDefinitions, full_path: str, output_dir: str):
         logger.info(
             f"=====Integration test, flavor : {test_flavor.test_descr}, command : {cmd}====="
         )
+
+        # save checkpoint (idx == 0) and load it for generation (idx == 1)
+        if test_name == "test_generate" and idx == 1:
+            cmd = (
+                f"CONFIG_FILE={full_path} NGPU={test_flavor.ngpu} LOG_RANK={all_ranks} "
+                f"CHECKPOINT_DIR={output_dir}/{test_name}/checkpoint/step-10 "
+                "PROMPT='What is the meaning of life?' "
+                f"./scripts/generate/run_llama_generate.sh --out > {output_dir}/{test_name}/generated_output.json"
+            )
+
         result = _run_cmd(cmd)
         logger.info(result.stdout)
         if result.returncode != 0:
diff --git a/torchtitan/utils.py b/torchtitan/utils.py
index c377b7f5..8a065153 100644
--- a/torchtitan/utils.py
+++ b/torchtitan/utils.py
@@ -20,7 +20,6 @@ from torch._utils import _get_available_device_type, _get_device_module
 from torch.distributed.device_mesh import DeviceMesh
 from torch.distributed.tensor import DTensor
 
-from torchtitan.config_manager import JobConfig
 from torchtitan.logging import logger
 
 
@@ -54,9 +53,10 @@ def _warn_overwrite_env(env, val):
 
 
 def set_determinism(
-    world_mesh: DeviceMesh,
+    world_mesh: Optional[DeviceMesh],
     device: torch.device,
-    job_config: JobConfig,
+    seed: Optional[int] = None,
+    deterministic: bool = False,
 ) -> None:
     """
     Set the same DTensor manual seed for all ranks within the same DTensor SPMD group, but different
@@ -67,8 +67,8 @@ def set_determinism(
 
     Set Determinism flags for increased reproducibility with loss of performance.
     """
-    if job_config.training.deterministic:
-        logger.info("Deterministic training enabled (expect perf degradation).")
+    if deterministic:
+        logger.info("Deterministic algorithm enabled (expect perf degradation).")
         torch.use_deterministic_algorithms(True)
         torch.backends.cudnn.deterministic = True
         torch.backends.cudnn.benchmark = False
@@ -76,9 +76,15 @@ def set_determinism(
         # https://pytorch.org/docs/stable/generated/torch.use_deterministic_algorithms.html
         os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
 
+    if not world_mesh:
+        if seed is not None:
+            torch.manual_seed(seed)
+            os.environ["PYTHONHASHSEED"] = str(seed % 2**32)
+            logger.debug(f"Single-process job using seed: {seed}")
+        return
+
     # to ensure we can control which ranks have same or different seeds, all ranks agree on a starting seed.
     # if user provides one, we use this. Otherwise rank 0 rolls the dice and everyone else uses that.
-    seed = job_config.training.seed
     if seed is None:
         # Extract the seed for torch's main generator on rank 0 and standardize on using that to build
         # seeds for unique SPMD groups
diff --git a/train.py b/train.py
index 13b3290e..3b157ad1 100644
--- a/train.py
+++ b/train.py
@@ -73,7 +73,9 @@ def main(job_config: JobConfig):
         pp_mesh = world_mesh["pp"]
 
     # Set random seed, and maybe enable deterministic mode (mainly for debugging, expect perf loss)
-    utils.set_determinism(world_mesh, device, job_config)
+    utils.set_determinism(
+        world_mesh, device, job_config.training.seed, job_config.training.deterministic
+    )
     model_name = job_config.model.name
 
     # build tokenizer