diff --git a/docs/source/checkpointing/local/api.rst b/docs/source/checkpointing/local/api.rst
index 6d8a361..d80f1c0 100644
--- a/docs/source/checkpointing/local/api.rst
+++ b/docs/source/checkpointing/local/api.rst
@@ -9,6 +9,5 @@ API documentation
    api/base_ckpt_manager
    api/local_ckpt_manager
    api/replication
-   api/group_utils
    api/base_state_dict
    api/basic_state_dict
diff --git a/docs/source/checkpointing/local/api/basic_state_dict.rst b/docs/source/checkpointing/local/api/basic_state_dict.rst
index 0af4cd7..debc6f3 100644
--- a/docs/source/checkpointing/local/api/basic_state_dict.rst
+++ b/docs/source/checkpointing/local/api/basic_state_dict.rst
@@ -2,6 +2,6 @@ BasicTensorAwareStateDict
 =========================
 
 .. automodule:: nvidia_resiliency_ext.checkpointing.local.basic_state_dict
-    :members:
+    :members: BasicTensorAwareStateDict
     :undoc-members:
     :show-inheritance:
diff --git a/docs/source/checkpointing/local/api/group_utils.rst b/docs/source/checkpointing/local/api/group_utils.rst
deleted file mode 100644
index 1fd98f6..0000000
--- a/docs/source/checkpointing/local/api/group_utils.rst
+++ /dev/null
@@ -1,7 +0,0 @@
-GroupWrapper
-============
-
-.. automodule:: nvidia_resiliency_ext.checkpointing.local.replication.group_utils
-    :members: GroupWrapper, parse_group_sequence
-    :undoc-members:
-    :show-inheritance:
diff --git a/docs/source/checkpointing/local/usage_guide.rst b/docs/source/checkpointing/local/usage_guide.rst
index 34b2459..94a3588 100644
--- a/docs/source/checkpointing/local/usage_guide.rst
+++ b/docs/source/checkpointing/local/usage_guide.rst
@@ -17,8 +17,15 @@ Requirements
 
 Requirements for `LocalCheckpointManager`
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 - The directory specified by the `root_local_ckpt_dir` parameter must have enough storage capacity to hold
-  at least two checkpoint parts (clean and dirty) per rank.
+  at least two checkpoint parts (clean and dirty) per rank, multiplied by the replication factor
+  defined by the replication strategy (`repl_strategy`).
+- If a local checkpoint was created with replication enabled, it is recommended to also enable replication
+  when loading that checkpoint. In that case the replication parameters
+  (i.e. the world size, `replication_jump` and `replication_factor`) must be the same as during save.
+  If replication is disabled during load, the replicas are ignored even if they are available,
+  which may make it impossible to recover from an otherwise complete checkpoint.
+- All training ranks must call the `LocalCheckpointManager` methods (`save`, `load`, `find_latest`) collectively;
+  otherwise the training ends up in a corrupted state (an NCCL collective hang or an out-of-memory error
+  during tensor allocation).
 
 Requirements for `BasicTensorAwareStateDict`
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -54,11 +61,11 @@ This is a basic reference and may omit specific implementation details:
 
     if iteration != -1:
         ta_state_dict, ckpt_part_id = ckpt_manager.load()
         # Use the loaded state_dict to resume training
-        state_dict = ta_state_dict.state_dict
+        model.load_state_dict(ta_state_dict.state_dict)
     else:
         # An iteration value of -1 indicates that no local checkpoint was found.
         # In this case, either return an error or initialize the model from scratch.
-        return -1
+        print('Starting training from scratch')
 
     # Training loop
     while True:
@@ -95,7 +102,6 @@ The retrieval mechanism is seamlessly integrated into the LocalCheckpointManager
 Asynchronous Checkpoint Saving
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 The `LocalCheckpointManager` supports both synchronous and asynchronous saving,
->>>>>>> skierat/documentation
 controlled by the `is_async` parameter in the `save(...)` method.
 
 - Synchronous Save: When `is_async` is set to `False`, the `save(...)` method
diff --git a/docs/source/inprocess/usage_guide.rst b/docs/source/inprocess/usage_guide.rst
index 9c4a20b..0f3fb42 100644
--- a/docs/source/inprocess/usage_guide.rst
+++ b/docs/source/inprocess/usage_guide.rst
@@ -497,7 +497,7 @@ Known issues
    participating in a Gloo collective stops making forward progress, the
    remaining ranks would wait till :py:class:`ProcessGroupGloo` timeout is
    exceeded; a workaround is to specify a short timeout for the ``gloo``
-   backend to enable faster restarts
+   backend to enable faster restarts.
 2. NCCL collective kernel termination is implemented by periodically checking
    a flag residing in mapped memory, and exiting from the kernel if the flag
    is set. Interval of checking for this flag is controlled by
@@ -521,3 +521,14 @@ Known issues
 5. NCCL net plugins must be disabled by setting ``NCCL_NET_PLUGIN``
    environment variable to ``"none"``. This issue will be addressed in future
    NCCL versions.
+6. :py:class:`nvidia_resiliency_ext.inprocess.Wrapper` is not fully compatible with
+   :py:func:`torch.distributed.run`, which automatically terminates all worker
+   processes if any one of them fails. In this case,
+   :py:class:`nvidia_resiliency_ext.inprocess.Wrapper` can only recover from transient
+   faults that don't cause termination of worker processes.
+7. By default, the PyTorch NCCL Watchdog forcefully terminates the process if an
+   NCCL call returns an error, or if the CUDA context was corrupted. Forceful
+   termination of the worker process prevents :py:class:`nvidia_resiliency_ext.inprocess.Wrapper`
+   from restarting the wrapper function. A workaround is to set the
+   ``TORCH_NCCL_RETHROW_CUDA_ERRORS`` environment variable to ``0``, so that CUDA
+   and NCCL errors are not rethrown by the PyTorch NCCL Watchdog.
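+
+   A minimal sketch of this workaround (illustration only; the variable must be
+   set before the NCCL process group is created, e.g. at the very top of the
+   training script or in the job environment):
+
+   .. code-block:: python
+
+      import os
+
+      # Keep the NCCL Watchdog from rethrowing CUDA/NCCL errors, so that the
+      # wrapped function can be restarted instead of the process being killed.
+      os.environ['TORCH_NCCL_RETHROW_CUDA_ERRORS'] = '0'
+
+      # ... then initialize torch.distributed and apply inprocess.Wrapper as usual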
diff --git a/examples/checkpointing/local_ckpt.py b/examples/checkpointing/local_ckpt.py
index 1b12718..0398457 100644
--- a/examples/checkpointing/local_ckpt.py
+++ b/examples/checkpointing/local_ckpt.py
@@ -10,14 +10,17 @@
 from nvidia_resiliency_ext.checkpointing.async_ckpt.core import AsyncCallsQueue
 from nvidia_resiliency_ext.checkpointing.local.basic_state_dict import BasicTensorAwareStateDict
-from nvidia_resiliency_ext.checkpointing.local.ckpt_managers.local_manager import \
-    LocalCheckpointManager
-from nvidia_resiliency_ext.checkpointing.local.replication.group_utils import parse_group_sequence, GroupWrapper
-from nvidia_resiliency_ext.checkpointing.local.replication.strategies import CliqueReplicationStrategy
+from nvidia_resiliency_ext.checkpointing.local.ckpt_managers.local_manager import (
+    LocalCheckpointManager,
+)
+from nvidia_resiliency_ext.checkpointing.local.replication.strategies import (
+    CliqueReplicationStrategy,
+)
 
 # Set up basic logging configuration
 logging.basicConfig(level=logging.INFO)
 
+
 def parse_args():
     parser = argparse.ArgumentParser(
         description='Local Checkpointing Basic Example',
@@ -44,9 +47,13 @@ def parse_args():
         '--replication_jump',
         default=4,
         type=int,
-        help="Specifies `k` such that replica of rank `n` is stored on ranks"
-        "`n+k`, `n+2k`, ..., `n+rk`. `r` is the --replication-factor"
-        "Needs to be specified if using --replication and have the same value on all ranks",
+        help=(
+            "Specifies `J`, the spacing between ranks storing replicas of a given rank's data. "
+            "Replicas for rank `n` may be on ranks `n+J`, `n+2J`, ..., or `n-J`, `n-2J`, etc. "
+            "This flag has an effect only if --replication is used, "
+            "and must be consistent across all ranks. "
+            "The default value of 4 is for demonstration purposes and can be adjusted as needed."
+ ), ) parser.add_argument( '--replication_factor', @@ -89,24 +96,19 @@ def init_distributed_backend(backend="nccl"): logging.error(f"Error initializing the distributed backend: {e}") raise + def create_checkpoint_manager(args): if args.replication: logging.info("Creating CliqueReplicationStrategy.") - repl_process_groups_ranks : List[List[int]] = parse_group_sequence( - replication_jump=args.replication_jump, - replication_factor=args.replication_factor, - world_size=dist.get_world_size() + repl_strategy = CliqueReplicationStrategy.from_replication_params( + args.replication_jump, args.replication_factor ) - repl_process_groups: List[torch.distributed.ProcessGroup] = [ - torch.distributed.new_group(g) for g in repl_process_groups_ranks - ] - my_process_group = GroupWrapper.from_list_of_groups(repl_process_groups) - repl_strategy = CliqueReplicationStrategy(my_process_group) else: repl_strategy = None return LocalCheckpointManager(args.ckpt_dir, repl_strategy=repl_strategy) + def save(args, ckpt_manager, async_queue, model, iteration): # Create Tensor-Aware State Dict ta_state_dict = BasicTensorAwareStateDict(model.state_dict()) @@ -133,6 +135,7 @@ def load(args, ckpt_manager): logging.info(f"Successfully loaded checkpoint part (id: {ckpt_part_id})") return ta_state_dict.state_dict + def main(): args = parse_args() logging.info(f'{args}') @@ -147,7 +150,7 @@ def main(): ckpt_manager = create_checkpoint_manager(args) async_queue = AsyncCallsQueue() if args.async_save else None - iteration = 123 # training iteration (used as training state id) + iteration = 123 # training iteration (used as training state id) # Local checkpointing save save(args, ckpt_manager, async_queue, model, iteration) diff --git a/src/nvidia_resiliency_ext/checkpointing/local/ckpt_managers/base_manager.py b/src/nvidia_resiliency_ext/checkpointing/local/ckpt_managers/base_manager.py index 72492dc..ead69e2 100644 --- a/src/nvidia_resiliency_ext/checkpointing/local/ckpt_managers/base_manager.py +++ b/src/nvidia_resiliency_ext/checkpointing/local/ckpt_managers/base_manager.py @@ -63,7 +63,6 @@ def __init__(self, ckpt_id): super().__init__(message) - class BaseCheckpointManager(ABC): """ The Base Checkpoint Manager provides an interface for integrating different checkpoint managers, @@ -159,6 +158,8 @@ def find_latest(self): If no complete checkpoints are found, the method returns -1. + All training ranks have to call this method at once. + Returns: int: The iteration number of the most recent complete checkpoint, or -1 if no checkpoints are available. @@ -197,6 +198,8 @@ def load(self) -> Tuple[TensorAwareStateDict, str]: Ensure that `find_latest()` has been called first to identify the latest checkpoint. + All training ranks have to call this method at once. + Returns: Tuple[TensorAwareStateDict, str] - `state_dict`: The state dictionary loaded from the most recent complete checkpoint. @@ -231,6 +234,8 @@ def save( and the function returns an `AsyncRequest` object. Otherwise, the save operation is completed synchronously. + All training ranks have to call this method at once. + Args: state_dict (dict): The state dictionary to be saved. iteration (int): The iteration number for identifying the checkpoint. 
diff --git a/src/nvidia_resiliency_ext/checkpointing/local/ckpt_managers/local_manager.py b/src/nvidia_resiliency_ext/checkpointing/local/ckpt_managers/local_manager.py index b021b99..58158b8 100644 --- a/src/nvidia_resiliency_ext/checkpointing/local/ckpt_managers/local_manager.py +++ b/src/nvidia_resiliency_ext/checkpointing/local/ckpt_managers/local_manager.py @@ -38,11 +38,24 @@ class LocalCheckpointManager(BaseCheckpointManager): """Local Checkpoint Manager designed for handling checkpoints on local storage devices - like SSDs or RAM disks.""" + like SSDs or RAM disks. + + Args: + root_local_ckpt_dir (str, Path): root checkpoint directory on local storage. + Checkpoints from different iterations can be saved within the same root directory, + as each will have a unique name + session_id (str, optional): adds additional identification opportunity for local + checkpoints used in different training workloads. An example use case + is the `root_local_ckpt_dir` being configured by the cluster administrator + (e.g. /tmp/...) and `session_id` configured by the end user for + differentiating different local checkpoints. + repl_strategy (ReplicationStrategy, optional): strategy used to perform local checkpoint + shards replication. + """ def __init__( self, - root_local_ckpt_dir, + root_local_ckpt_dir: Union[str, Path], session_id: str = '', repl_strategy: Optional[ReplicationStrategy] = None, ): diff --git a/src/nvidia_resiliency_ext/checkpointing/local/replication/strategies.py b/src/nvidia_resiliency_ext/checkpointing/local/replication/strategies.py index ea894fc..f870ba2 100644 --- a/src/nvidia_resiliency_ext/checkpointing/local/replication/strategies.py +++ b/src/nvidia_resiliency_ext/checkpointing/local/replication/strategies.py @@ -191,6 +191,39 @@ def retrieve_execute(self, *args, **kwargs): def from_replication_params( cls, replication_jump: int = torch.cuda.device_count(), replication_factor: int = 2 ) -> 'CliqueReplicationStrategy': + """Instantiates process groups necessary for checkpoint replication. + + Training ranks are divided into `W // F` distinct groups of size `F`, where + `W` is the world size + and `F` is the `replication_factor`. + Each group consists of ranks: + + `n`, `n + J`, `n + 2J`, ..., `n + (F - 1)J`, + + where `J` is the `replication_jump` and `n = aJF + b`, with: + - `a = 0, 1, ..., (W / (JF)) - 1` + - `b = 0, 1, ..., J - 1`. + + Checkpoint shards are exchanged and fully replicated within each group. + + **Important:** The world size (`W`) must be divisible by `J * F`. + + This grouping enables replication across different failure domains by specifying + `J` equal to the failure blast radius. + + **Example:** + For a world size of 32, `replication_jump = 8`, and `replication_factor = 2`, + the replication groups (cliques) are: + + 0-8, 1-9, 2-10, 3-11, 4-12, 5-13, 6-14, 7-15, + 16-24, 17-25, 18-26, 19-27, 20-28, 21-29, 22-30, 23-31 + + Args: + replication_jump (int, optional): `J` in the formula above. Represents the gap between + successive ranks storing replicas of a given rank's data. + replication_factor (int, optional): `F` in the formula above. Denotes the number of + ranks storing replicas of a given rank's data. 
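+
+        **Usage sketch** (illustration only, mirroring `examples/checkpointing/local_ckpt.py`;
+        assumes `torch.distributed` is already initialized and that the world size
+        satisfies the divisibility requirement above):
+
+        .. code-block:: python
+
+            repl_strategy = CliqueReplicationStrategy.from_replication_params(
+                replication_jump=8, replication_factor=2
+            )
+            # '/tmp/local_ckpt' is just an example path on node-local storage
+            ckpt_manager = LocalCheckpointManager('/tmp/local_ckpt', repl_strategy=repl_strategy)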
+ """ logger.debug(f'Initializing {cls.__name__}') repl_process_groups_ranks: List[List[int]] = parse_group_sequence( replication_jump=replication_jump, @@ -252,8 +285,39 @@ def _eager_build(self) -> EagerT: class LazyCliqueReplicationStrategy(LazyReplicationStrategyBuilder[CliqueReplicationStrategy]): - """Lazy version of CliqueReplicationStrategy allowing to delay process group formation.""" + """Lazy version of CliqueReplicationStrategy allowing to delay process group formation. + + Training ranks are divided into `W // F` distinct groups of size `F`, where + `W` is the world size + and `F` is the `replication_factor`. + Each group consists of ranks: + + `n`, `n + J`, `n + 2J`, ..., `n + (F - 1)J`, + + where `J` is the `replication_jump` and `n = aJF + b`, with: + - `a = 0, 1, ..., (W / (JF)) - 1` + - `b = 0, 1, ..., J - 1`. + + Checkpoint shards are exchanged and fully replicated within each group. + **Important:** The world size (`W`) must be divisible by `J * F`. + + This grouping enables replication across different failure domains by specifying + `J` equal to the failure blast radius. + + **Example:** + For a world size of 32, `replication_jump = 8`, and `replication_factor = 2`, + the replication groups (cliques) are: + + 0-8, 1-9, 2-10, 3-11, 4-12, 5-13, 6-14, 7-15, + 16-24, 17-25, 18-26, 19-27, 20-28, 21-29, 22-30, 23-31 + + Args: + replication_jump (int, optional): `J` in the formula above. Represents the gap between + successive ranks storing replicas of a given rank's data. + replication_factor (int, optional): `F` in the formula above. Denotes the number of + ranks storing replicas of a given rank's data. + """ def __init__( self, replication_jump: int = torch.cuda.device_count(), replication_factor: int = 2 ): diff --git a/src/nvidia_resiliency_ext/ptl_resiliency/local_checkpoint_callback.py b/src/nvidia_resiliency_ext/ptl_resiliency/local_checkpoint_callback.py index 0405b2a..e3b5fe3 100644 --- a/src/nvidia_resiliency_ext/ptl_resiliency/local_checkpoint_callback.py +++ b/src/nvidia_resiliency_ext/ptl_resiliency/local_checkpoint_callback.py @@ -48,7 +48,22 @@ class LocalCheckpointCallback(pl.callbacks.ModelCheckpoint): - """ModelCheckpoint without most of the functionality. Only train_batch_end simple save.""" + """ModelCheckpoint with basic functionality. Only train_batch_end simple save. + + Simple callback for initiating local checkpoint save in `on_train_batch_end` method. + Since local checkpoints are ephemeral, they shouldn't be used for "major" checkpoint + types like `on_train_epoch_end`. + + This callback must be used in conjunction with the HierarchicalCheckpointIO, + since the only thing this callback really does is passing some options + to `trainer.save_checkpoint` which can be captured with HierarchicalCheckpointIO. + + Args: + every_n_train_steps (int, optional): controls local checkpointing interval in terms + of train iterations. Same semantic as in PTL ModelCheckpoint. + train_time_interval (int, optional): controls local checkpointing interval in terms + of wall time. Same semantics as in PTL ModelCheckpoint. + """ def __init__( self, @@ -89,9 +104,13 @@ def _save_last_checkpoint( class HierarchicalCheckpointIO(_WrappingCheckpointIO): """Wrapper for a global CheckpointIO enabling local checkpointing. - Must be used in conjunction with LocalCheckpointCallback. + Based on the presence of local checkpointing options in saving `storage_options`, + routes the save to the original global CheckpointIO or the local checkpoint manager. 
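+
+    A rough wiring sketch (illustration only; `MyHierarchicalCheckpointIO` and the
+    trainer setup below are hypothetical, and a concrete subclass must implement
+    the tensor-aware state dict conversions appropriate for its checkpoints):
+
+    .. code-block:: python
+
+        class MyHierarchicalCheckpointIO(HierarchicalCheckpointIO):
+            def to_tensor_aware_state_dict(self, checkpoint):
+                # wrap the checkpoint dict for local (tensor-aware) saving
+                return BasicTensorAwareStateDict(checkpoint)
+
+            def from_tensor_aware_state_dict(self, tensor_aware_checkpoint, **kwargs):
+                # recover the plain checkpoint dict on load
+                return tensor_aware_checkpoint.state_dict
+
+        checkpoint_io = MyHierarchicalCheckpointIO(
+            wrapped_checkpoint_io, local_ckpt_manager, get_global_ckpt_iteration_fn
+        )
+        trainer = pl.Trainer(
+            plugins=[checkpoint_io],
+            callbacks=[LocalCheckpointCallback(every_n_train_steps=100)],
+        )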
+ + Must be used in conjunction with LocalCheckpointCallback which *initiates* + local checkpoint saving during training. - Arguments: + Args: wrapped_checkpoint_io (CheckpointIO): global CheckpointIO to wrap local_ckpt_manager (BaseCheckpointManager): local manager to use for local checkpoints get_global_ckpt_iteration_fn (Callable[[_PATH], int]): a function that @@ -175,7 +194,6 @@ def get_partial_wrapper_constructor( @abstractmethod def to_tensor_aware_state_dict(self, checkpoint: Dict[str, Any]) -> TensorAwareStateDict: - # TODO: consider providing a default implementation for a simple state dict raise NotImplementedError def from_tensor_aware_state_dict(self, tensor_aware_checkpoint: TensorAwareStateDict, **kwargs): diff --git a/tests/fault_tolerance/sim-multinode/.env b/tests/fault_tolerance/sim-multinode/.env deleted file mode 100644 index 2430667..0000000 --- a/tests/fault_tolerance/sim-multinode/.env +++ /dev/null @@ -1,2 +0,0 @@ -FT_DIR=/workspace/fault_tolerance -TEST_BASE_DIR=/workspace/fault_tolerance/tests/sim-multinode \ No newline at end of file diff --git a/tests/fault_tolerance/sim-multinode/README.txt b/tests/fault_tolerance/sim-multinode/README.txt deleted file mode 100644 index ae873eb..0000000 --- a/tests/fault_tolerance/sim-multinode/README.txt +++ /dev/null @@ -1,10 +0,0 @@ -This is test of fault handing in simulated (dockerized) multi node setup. -You need docker and docker-compose installed + NVIDIA extensions that allow to use GPU. -There should be GPU available, one is enough. - -Run command (current dir should be repo root dir) -/fault_tolerance$ docker-compose -f tests/sim-multinode/compose-c10d.yaml up --force-recreate --build -OR -/fault_tolerance$ docker-compose -f tests/sim-multinode/compose-etcd.yaml up --force-recreate --build - -Expected result is "SUCCEEDED (...)" line printed at the end. 
\ No newline at end of file diff --git a/tests/fault_tolerance/sim-multinode/compose-c10d.yaml b/tests/fault_tolerance/sim-multinode/compose-c10d.yaml deleted file mode 100644 index 17277f4..0000000 --- a/tests/fault_tolerance/sim-multinode/compose-c10d.yaml +++ /dev/null @@ -1,39 +0,0 @@ -version: '3' - -networks: - app-tier: - driver: bridge - -services: - - worker00: &worker - hostname: worker00 - build: ../../ - image: 'fault_tol_img' - command: 'ft_launcher --rdzv_backend=c10d --rdzv_endpoint=worker00:2323 --fault-tol-cfg-path=${TEST_BASE_DIR}/ft.yaml --max-restarts=100 --nnodes=4 --nproc-per-node=4 ${TEST_BASE_DIR}/test_worker.py' - networks: - - app-tier - ports: - - 2323:2323 - deploy: - resources: - reservations: - devices: - - driver: nvidia - count: all - capabilities: [gpu] - - worker01: - <<: *worker - hostname: worker01 - ports: [] - - worker02: - <<: *worker - hostname: worker02 - ports: [] - - worker03: - <<: *worker - hostname: worker03 - ports: [] \ No newline at end of file diff --git a/tests/fault_tolerance/sim-multinode/compose-etcd.yaml b/tests/fault_tolerance/sim-multinode/compose-etcd.yaml deleted file mode 100644 index 4f7ab4e..0000000 --- a/tests/fault_tolerance/sim-multinode/compose-etcd.yaml +++ /dev/null @@ -1,47 +0,0 @@ -version: '3' - -networks: - app-tier: - driver: bridge - -services: - - etcd-server: - image: 'bitnami/etcd:3.4.30' - environment: - - ALLOW_NONE_AUTHENTICATION=yes - - ETCD_ADVERTISE_CLIENT_URLS=http://etcd-server:2379 - - ETCD_ENABLE_V2=1 - ports: - - 2379:2379 - - 2380:2380 - networks: - - app-tier - - worker00: &worker - hostname: worker00 - build: ../../ - image: 'fault_tol_img' - command: 'ft_launcher --rdzv_backend=etcd --rdzv_endpoint=etcd-server:2379 --fault-tol-cfg-path=${TEST_BASE_DIR}/ft.yaml --max-restarts=100 --nnodes=4 --nproc-per-node=4 ${TEST_BASE_DIR}/test_worker.py' - networks: - - app-tier - deploy: - resources: - reservations: - devices: - - driver: nvidia - count: all - capabilities: [gpu] - - worker01: - <<: *worker - hostname: worker01 - - worker02: - <<: *worker - hostname: worker02 - - worker03: - <<: *worker - hostname: worker03 - diff --git a/tests/fault_tolerance/sim-multinode/ft.yaml b/tests/fault_tolerance/sim-multinode/ft.yaml deleted file mode 100644 index 02948cc..0000000 --- a/tests/fault_tolerance/sim-multinode/ft.yaml +++ /dev/null @@ -1,3 +0,0 @@ -fault_tolerance: - initial_rank_heartbeat_timeout: 3 - rank_heartbeat_timeout: 3 \ No newline at end of file diff --git a/tests/fault_tolerance/sim-multinode/test_worker.py b/tests/fault_tolerance/sim-multinode/test_worker.py deleted file mode 100644 index 93ddab3..0000000 --- a/tests/fault_tolerance/sim-multinode/test_worker.py +++ /dev/null @@ -1,152 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import json -import os -import random -import signal -import socket -import sys -import threading -import time - -import torch -import torch.distributed as dist - -from nvidia_resiliency_ext import fault_tolerance - -# Tester script for multi-node FT test -# Initializes fault tolerance, periodically sends heartbeats -# Crashes or hangs random rank after some delay -# Checks if workload is restarted after a failure and the same ranks land on the same nodes - -SIM_FAULT_BASE_DELAY = 2 -SIM_FAULT_MAX_RAND_DELAY = 3 -MAX_RUN_INTERVAL = 60 -STATE_FILE_PATH_PATT = "/workspace/_rank{rank}-assignment-test-state.json" -NUM_RUNS = 16 - - -def _print_on_rank0(msg): - if dist.get_rank() == 0: - print(msg) - - -def _setup_simulated_fault(): - rng = random.Random() - - fault_type = rng.choice(['rank_killed', 'rank_hanged']) - - rank_to_fail = rng.randint(0, dist.get_world_size() - 1) - rank_to_fail = torch.tensor([rank_to_fail]) - dist.broadcast(rank_to_fail, 0) - rank_to_fail = int(rank_to_fail.item()) - - rank = torch.distributed.get_rank() - if rank != rank_to_fail: - return - - if fault_type == 'rank_killed': - target_pid = os.getpid() - target_sig = signal.SIGKILL - elif fault_type == 'rank_hanged': - target_pid = os.getpid() - target_sig = signal.SIGSTOP - else: - raise Exception(f"Unknown fault type {fault_type}") - - delay = SIM_FAULT_BASE_DELAY + SIM_FAULT_MAX_RAND_DELAY * random.random() - - def __fault_thread(): - time.sleep(delay) - print( - f"\n####\nSimulating fault: {fault_type}; rank to fail: {rank_to_fail}\n#####\n", - file=sys.stderr, - ) - os.kill(target_pid, target_sig) - - fault_sim_thread = threading.Thread(target=__fault_thread) - fault_sim_thread.daemon = True - fault_sim_thread.start() - - -def _load_state(rank): - fn = STATE_FILE_PATH_PATT.format(rank=rank) - if not os.path.exists(fn): - return {} - with open(fn, "r") as f: - return json.load(f) - - -def _get_new_state(prev_state): - curr_state = { - 'is_failed': False, - 'is_finished': False, - 'node': socket.gethostname(), - 'run_idx': prev_state['run_idx'] + 1 if prev_state else 0, - 'start_time': time.monotonic(), - } - if prev_state: - if prev_state['node'] != curr_state['node']: - curr_state['is_finished'] = True - curr_state['is_failed'] = True - curr_state['failure_reason'] = "Ranks assignment changed." - elif (curr_state['start_time'] - prev_state['start_time']) > MAX_RUN_INTERVAL: - curr_state['is_finished'] = True - curr_state['is_failed'] = True - curr_state['failure_reason'] = "Interval between runs exceeded limit." 
- - if curr_state['run_idx'] == NUM_RUNS: - curr_state['is_finished'] = True - - return curr_state - - -def _save_state(rank, state): - fn = STATE_FILE_PATH_PATT.format(rank=rank) - with open(fn, "w") as f: - json.dump(state, f) - - -if __name__ == '__main__': - dist.init_process_group(backend="gloo") - rank = dist.get_rank() - - print(f"Running rank {rank} on node {socket.gethostname()}.") - - loaded_state = _load_state(rank) - curr_state = _get_new_state(loaded_state) - _save_state(rank, curr_state) - - dist.barrier() - _print_on_rank0(f"### RUN {curr_state['run_idx']}/{NUM_RUNS} ###") - - if curr_state['is_finished']: - dist.barrier() - if curr_state['is_failed']: - print(f"TEST FAILED rank={rank} final state={curr_state}") - else: - assert curr_state['run_idx'] == NUM_RUNS - _print_on_rank0("TEST SUCCEEDED") - sys.exit(0) # return 0 so launcher wont respawn the workload - - ft_client = fault_tolerance.RankMonitorClient() - ft_client.init_workload_monitoring() - - _setup_simulated_fault() - - for i in range(1000000): - ft_client.send_heartbeat() - time.sleep(0.5) diff --git a/tests/ptl_resiliency/func/nemo20/local_ckpt_test.sh b/tests/ptl_resiliency/func/nemo20/local_ckpt_test.sh new file mode 100755 index 0000000..0c13323 --- /dev/null +++ b/tests/ptl_resiliency/func/nemo20/local_ckpt_test.sh @@ -0,0 +1,114 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +#!/bin/bash +# +# Local checkpointing functional test, using NeMo 2.0 and FT callback. +# Based on `ft_test.sh`. +# This script should be run on a node with 8 GPUs and with docker available. +# It runs 4 containers that "simulate" 4 nodes, each with 2 GPUs. +# +# Should be run from the root of the repository: +# ./tests/ptl_resiliency/func/nemo20/local_ckpt_test.sh +# +# Expected result is "All tests passed." printed at the end. +# + +# TODO: consider merging with `ft_test.sh` + +set -x # show commands as they are executed +set -o pipefail # pipelined commands exit code is 1 if any command in a pipe fails + +REPO_ROOT=$(git rev-parse --show-toplevel) +FT_CONT_OUT_DIR="/mnt/ft_test_storage" +TOKENIZER_PATH="/mnt/nvdl/datasets/ft/models/llama/tokenizer.model" +echo EXTRA_CONTAINER_MOUNTS=${EXTRA_CONTAINER_MOUNTS} +CONTAINER_MOUNTS="-v ./ft_test_storage:${FT_CONT_OUT_DIR} -v ${TOKENIZER_PATH}:${TOKENIZER_PATH}:ro ${EXTRA_CONTAINER_MOUNTS}" + +FT_TEST_BASE_IMG="${FT_TEST_BASE_IMG:-''}" +TEST_IMG="ft_test_nemo_img" +CONTAINER_COMMON_ARGS="--rm --net testnet --ipc=host --ulimit memlock=-1 --ulimit stack=67108864 ${CONTAINER_MOUNTS}" + +RDZV_PORT=2323 +COMMON_LAUNCHER_ARGS="--rdzv_backend=c10d --rdzv_endpoint=node0:${RDZV_PORT} --nnodes=4 --nproc-per-node=2" + +LAUNCHERS_WAIT_EXIT_CODE=0 + +function run_on_simulated_nodes { + # Run the same command on 4 simulated nodes, each with 2 GPUs + # NOTE: node0 is the rendezvous host node, hence it needs to expose the port. 
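+    # Each node's output is captured in ./node<i>log.txt via `tee`, so that
+    # `assert_log_contains` below can grep the logs from all simulated nodes.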
+ cmd=$1 + LAUNCHERS_WAIT_EXIT_CODE=0 + rm -f ./node*log.txt + { docker run -h node0 -p ${RDZV_PORT}:${RDZV_PORT} --gpus='"device=0,1"' ${CONTAINER_COMMON_ARGS} ${TEST_IMG} bash -c "${cmd}" 2>&1 | tee "./node0log.txt" ; } & + { docker run -h node1 --gpus='"device=2,3"' ${CONTAINER_COMMON_ARGS} ${TEST_IMG} bash -c "${cmd}" 2>&1 | tee "./node1log.txt" ; } & + { docker run -h node2 --gpus='"device=4,5"' ${CONTAINER_COMMON_ARGS} ${TEST_IMG} bash -c "${cmd}" 2>&1 | tee "./node2log.txt" ; } & + { docker run -h node3 --gpus='"device=6,7"' ${CONTAINER_COMMON_ARGS} ${TEST_IMG} bash -c "${cmd}" 2>&1 | tee "./node3log.txt" ; } & + wait `jobs -p` + LAUNCHERS_WAIT_EXIT_CODE=$? +} + +function assert_log_contains { + expected_str="$1" + if ! grep -q "${expected_str}" ./node*log.txt ; then + echo "Expected string not found in logs from nodes: ${expected_str}" + exit 1 + fi +} + +####### PREPARE TEST ENVIRONMENT ##### + +set -e # exit on error during initialization + +# Build the test container with current sources +docker build --build-arg BASE_IMG="${FT_TEST_BASE_IMG}" -f ${REPO_ROOT}/tests/ptl_resiliency/func/nemo20/Dockerfile.ft_test -t ${TEST_IMG} ${REPO_ROOT} + +# Network for the containers +docker network create testnet || echo "Network 'testnet' already exists" + +set +e # some errors are expected in the tests + +######## TEST STAGE 1: LOCAL CKPT SAVE ######### + +mkdir -p ft_test_storage +docker run ${CONTAINER_COMMON_ARGS} ${TEST_IMG} bash -c "rm -rf ${FT_CONT_OUT_DIR}/default ${FT_CONT_OUT_DIR}/lightning_logs ${FT_CONT_OUT_DIR}/local_ckpt" + +run_on_simulated_nodes \ + "MEGATRON_LOGGING_LEVEL=10 ft_launcher --ft-param-initial_rank_heartbeat_timeout=600 --ft-param-rank_heartbeat_timeout=600 ${COMMON_LAUNCHER_ARGS} \ + ./tests/ptl_resiliency/func/nemo20/test_local_ckpt_llama3.py \ + --tokenizer-path=${TOKENIZER_PATH} \ + --log-dir=${FT_CONT_OUT_DIR} \ + --num-nodes=4 \ + --num-gpus=2 \ + --max-steps=100 \ + --local-checkpoint-interval 20" + +rm -rf ${FT_CONT_OUT_DIR}/local_ckpt/node1 + +######## TEST STAGE 2: LOCAL CKPT LOAD ######### + +run_on_simulated_nodes \ + "MEGATRON_LOGGING_LEVEL=10 ft_launcher --ft-param-initial_rank_heartbeat_timeout=600 --ft-param-rank_heartbeat_timeout=600 ${COMMON_LAUNCHER_ARGS} \ + ./tests/ptl_resiliency/func/nemo20/test_local_ckpt_llama3.py \ + --tokenizer-path=${TOKENIZER_PATH} \ + --log-dir=${FT_CONT_OUT_DIR} \ + --num-nodes=4 \ + --num-gpus=2 \ + --max-steps=200 \ + --local-checkpoint-interval 20" + +echo "LOADING DONE" +assert_log_contains "Resuming from a local checkpoint" +echo "All tests passed" \ No newline at end of file