update-docs (#12)
srogawski-nvidia authored Dec 19, 2024
1 parent b1da821 commit 7da4d2f
Showing 17 changed files with 265 additions and 292 deletions.
1 change: 0 additions & 1 deletion docs/source/checkpointing/local/api.rst
@@ -9,6 +9,5 @@ API documentation
api/base_ckpt_manager
api/local_ckpt_manager
api/replication
api/group_utils
api/base_state_dict
api/basic_state_dict
2 changes: 1 addition & 1 deletion docs/source/checkpointing/local/api/basic_state_dict.rst
@@ -2,6 +2,6 @@ BasicTensorAwareStateDict
=========================

.. automodule:: nvidia_resiliency_ext.checkpointing.local.basic_state_dict
:members: BasicTensorAwareStateDict
:undoc-members:
:show-inheritance:
7 changes: 0 additions & 7 deletions docs/source/checkpointing/local/api/group_utils.rst

This file was deleted.

14 changes: 10 additions & 4 deletions docs/source/checkpointing/local/usage_guide.rst
@@ -17,8 +17,15 @@ Requirements
Requirements for `LocalCheckpointManager`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
- The directory specified by the `root_local_ckpt_dir` parameter must have enough storage capacity to hold
at least two checkpoint parts (clean and dirty) per rank
multiplied by the replication factor defined by the replication strategy (`repl_strategy`).
- If a local checkpoint was created with replication enabled, it is recommended to also enable replication
when loading that checkpoint, in which case the replication parameters
(i.e. `world_size`, `--replication-jump`, and `--replication-factor`) must be the same as during the save
(see the sketch after this list).
If replication is disabled during load, the replicas are ignored even if available, which might make it
impossible to recover from an otherwise complete checkpoint.
- All training ranks must call the `LocalCheckpointManager` methods (`save`, `load`, `find_latest`) collectively;
otherwise the training can end up in a corrupted state (an NCCL collective hang or a tensor-allocation OOM).
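A minimal sketch that satisfies these requirements, using the `LocalCheckpointManager` and
`CliqueReplicationStrategy.from_replication_params` APIs shown elsewhere in this commit
(the directory path and parameter values are illustrative only):

    from nvidia_resiliency_ext.checkpointing.local.ckpt_managers.local_manager import (
        LocalCheckpointManager,
    )
    from nvidia_resiliency_ext.checkpointing.local.replication.strategies import (
        CliqueReplicationStrategy,
    )

    # The same replication parameters must be used for save and load, and the world
    # size must be divisible by REPLICATION_JUMP * REPLICATION_FACTOR.
    REPLICATION_JUMP = 4
    REPLICATION_FACTOR = 2
    repl_strategy = CliqueReplicationStrategy.from_replication_params(
        REPLICATION_JUMP, REPLICATION_FACTOR
    )

    # Every rank must create the manager and call save/load/find_latest collectively.
    # Each rank's local directory must hold 2 parts (clean and dirty) times the
    # replication factor, i.e. 4 checkpoint parts in this configuration.
    ckpt_manager = LocalCheckpointManager('/tmp/local_ckpt', repl_strategy=repl_strategy)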

Requirements for `BasicTensorAwareStateDict`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -54,11 +61,11 @@ This is a basic reference and may omit specific implementation details:
if iteration != -1:
ta_state_dict, ckpt_part_id = ckpt_manager.load()
# Use the loaded state_dict to resume training
model.load_state_dict(ta_state_dict.state_dict)
else:
# An iteration value of -1 indicates that no local checkpoint was found.
# In this case, either return an error or initialize the model from scratch.
print('Starting training from scratch')
# Training loop
while True:
@@ -95,7 +102,6 @@ The retrieval mechanism is seamlessly integrated into the LocalCheckpointManager
Asynchronous Checkpoint Saving
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
The `LocalCheckpointManager` supports both synchronous and asynchronous saving,
controlled by the `is_async` parameter in the `save(...)` method (a short sketch follows at the end of this section).

- Synchronous Save: When `is_async` is set to `False`, the `save(...)` method
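A minimal sketch of both modes, continuing the names used in the snippet above
(`ckpt_manager`, `ta_state_dict`, `iteration`); the `AsyncCallsQueue` method names follow
the example script in this commit and should be treated as assumptions rather than a verified API:

    from nvidia_resiliency_ext.checkpointing.async_ckpt.core import AsyncCallsQueue

    async_queue = AsyncCallsQueue()

    # Synchronous save: returns only after the local checkpoint part has been written.
    ckpt_manager.save(ta_state_dict, iteration, is_async=False)

    # Asynchronous save: returns an AsyncRequest that is scheduled on the queue and
    # finalized later (e.g. once per training iteration).
    async_request = ckpt_manager.save(ta_state_dict, iteration, is_async=True)
    async_queue.schedule_async_request(async_request)  # assumed AsyncCallsQueue method
    async_queue.maybe_finalize_async_calls(blocking=False)  # assumed finalization call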
13 changes: 12 additions & 1 deletion docs/source/inprocess/usage_guide.rst
@@ -497,7 +497,7 @@ Known issues
participating in a Gloo collective stops making forward progress, the
remaining ranks would wait till :py:class:`ProcessGroupGloo` timeout is
exceeded; a workaround is to specify a short timeout for the ``gloo``
backend to enable faster restarts.
2. NCCL collective kernel termination is implemented by periodically checking a
flag residing in mapped memory, and exiting from the kernel if the flag is
set. Interval of checking for this flag is controlled by
@@ -521,3 +521,14 @@ Known issues
5. NCCL net plugins must be disabled by setting ``NCCL_NET_PLUGIN`` environment
variable to ``"none"``. This issue will be addressed in future NCCL
versions.
6. :py:class:`nvidia_resiliency_ext.inprocess.Wrapper` is not fully compatible with
:py:func:`torch.distributed.run`. :py:func:`torch.distributed.run`
automatically terminates all worker processes if any one of them fails; in this
case :py:class:`nvidia_resiliency_ext.inprocess.Wrapper` can only recover from transient
faults that don't cause termination of worker processes.
7. By default, the PyTorch NCCL Watchdog forcefully terminates the process if an NCCL
call returns an error, or if the CUDA context was corrupted. Forceful
termination of the worker process prevents :py:class:`nvidia_resiliency_ext.inprocess.Wrapper`
from restarting the wrapper function. A workaround is to set
``TORCH_NCCL_RETHROW_CUDA_ERRORS`` environment variable to ``0``, to avoid
rethrowing CUDA and NCCL errors in PyTorch NCCL Watchdog.
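A minimal sketch of the environment-variable workarounds from items 5 and 7 above; both
variables must be set before NCCL and the PyTorch NCCL Watchdog are initialized (setting
them in the launch environment works equally well):

    import os

    # Known issue 5: disable NCCL net plugins.
    os.environ['NCCL_NET_PLUGIN'] = 'none'
    # Known issue 7: do not rethrow CUDA/NCCL errors in the PyTorch NCCL Watchdog.
    os.environ['TORCH_NCCL_RETHROW_CUDA_ERRORS'] = '0'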
37 changes: 20 additions & 17 deletions examples/checkpointing/local_ckpt.py
@@ -10,14 +10,17 @@

from nvidia_resiliency_ext.checkpointing.async_ckpt.core import AsyncCallsQueue
from nvidia_resiliency_ext.checkpointing.local.basic_state_dict import BasicTensorAwareStateDict
from nvidia_resiliency_ext.checkpointing.local.ckpt_managers.local_manager import (
LocalCheckpointManager,
)
from nvidia_resiliency_ext.checkpointing.local.replication.strategies import (
CliqueReplicationStrategy,
)

# Set up basic logging configuration
logging.basicConfig(level=logging.INFO)


def parse_args():
parser = argparse.ArgumentParser(
description='Local Checkpointing Basic Example',
@@ -44,9 +47,13 @@ def parse_args():
'--replication_jump',
default=4,
type=int,
help="Specifies `k` such that replica of rank `n` is stored on ranks"
"`n+k`, `n+2k`, ..., `n+rk`. `r` is the --replication-factor"
"Needs to be specified if using --replication and have the same value on all ranks",
help=(
"Specifies `J`, the spacing between ranks storing replicas of a given rank's data. "
"Replicas for rank `n` may be on ranks `n+J`, `n+2J`, ..., or `n-J`, `n-2J`, etc. "
"This flag has an effect only if --replication is used. "
"and must be consistent across all ranks. "
"The default value of 4 is for demonstration purposes and can be adjusted as needed."
),
)
parser.add_argument(
'--replication_factor',
@@ -89,24 +96,19 @@ def init_distributed_backend(backend="nccl"):
logging.error(f"Error initializing the distributed backend: {e}")
raise


def create_checkpoint_manager(args):
if args.replication:
logging.info("Creating CliqueReplicationStrategy.")
repl_strategy = CliqueReplicationStrategy.from_replication_params(
args.replication_jump, args.replication_factor
)
else:
repl_strategy = None

return LocalCheckpointManager(args.ckpt_dir, repl_strategy=repl_strategy)


def save(args, ckpt_manager, async_queue, model, iteration):
# Create Tensor-Aware State Dict
ta_state_dict = BasicTensorAwareStateDict(model.state_dict())
@@ -133,6 +135,7 @@ def load(args, ckpt_manager):
logging.info(f"Successfully loaded checkpoint part (id: {ckpt_part_id})")
return ta_state_dict.state_dict


def main():
args = parse_args()
logging.info(f'{args}')
@@ -147,7 +150,7 @@ def main():
ckpt_manager = create_checkpoint_manager(args)
async_queue = AsyncCallsQueue() if args.async_save else None

iteration = 123  # training iteration (used as training state id)

# Local checkpointing save
save(args, ckpt_manager, async_queue, model, iteration)
@@ -63,7 +63,6 @@ def __init__(self, ckpt_id):
super().__init__(message)



class BaseCheckpointManager(ABC):
"""
The Base Checkpoint Manager provides an interface for integrating different checkpoint managers,
@@ -159,6 +158,8 @@ def find_latest(self):
If no complete checkpoints are found, the method returns -1.
All training ranks have to call this method at once.
Returns:
int: The iteration number of the most recent complete checkpoint,
or -1 if no checkpoints are available.
@@ -197,6 +198,8 @@ def load(self) -> Tuple[TensorAwareStateDict, str]:
Ensure that `find_latest()` has been called first to identify the latest checkpoint.
All training ranks have to call this method at once.
Returns:
Tuple[TensorAwareStateDict, str]
- `state_dict`: The state dictionary loaded from the most recent complete checkpoint.
@@ -231,6 +234,8 @@ def save(
and the function returns an `AsyncRequest` object. Otherwise, the save operation
is completed synchronously.
All training ranks have to call this method at once.
Args:
state_dict (dict): The state dictionary to be saved.
iteration (int): The iteration number for identifying the checkpoint.
@@ -38,11 +38,24 @@

class LocalCheckpointManager(BaseCheckpointManager):
"""Local Checkpoint Manager designed for handling checkpoints on local storage devices
like SSDs or RAM disks.
Args:
root_local_ckpt_dir (str, Path): root checkpoint directory on local storage.
Checkpoints from different iterations can be saved within the same root directory,
as each will have a unique name
session_id (str, optional): adds additional identification opportunity for local
checkpoints used in different training workloads. An example use case
is the `root_local_ckpt_dir` being configured by the cluster administrator
(e.g. /tmp/...) and `session_id` configured by the end user for
differentiating different local checkpoints.
repl_strategy (ReplicationStrategy, optional): strategy used to perform local checkpoint
shards replication.
"""

def __init__(
self,
root_local_ckpt_dir: Union[str, Path],
session_id: str = '',
repl_strategy: Optional[ReplicationStrategy] = None,
):
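A short illustrative construction based on the docstring above; the path and session name
are hypothetical placeholders:

    # Root directory provisioned by the cluster administrator; session_id chosen by
    # the user to keep local checkpoints of different workloads apart.
    ckpt_manager = LocalCheckpointManager(
        '/tmp/local_ckpt',
        session_id='llm_pretrain_run1',
        repl_strategy=None,  # or a ReplicationStrategy instance
    )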
@@ -191,6 +191,39 @@ def retrieve_execute(self, *args, **kwargs):
def from_replication_params(
cls, replication_jump: int = torch.cuda.device_count(), replication_factor: int = 2
) -> 'CliqueReplicationStrategy':
"""Instantiates process groups necessary for checkpoint replication.
Training ranks are divided into `W // F` distinct groups of size `F`, where
`W` is the world size
and `F` is the `replication_factor`.
Each group consists of ranks:
`n`, `n + J`, `n + 2J`, ..., `n + (F - 1)J`,
where `J` is the `replication_jump` and `n = aJF + b`, with:
- `a = 0, 1, ..., (W / (JF)) - 1`
- `b = 0, 1, ..., J - 1`.
Checkpoint shards are exchanged and fully replicated within each group.
**Important:** The world size (`W`) must be divisible by `J * F`.
This grouping enables replication across different failure domains by specifying
`J` equal to the failure blast radius.
**Example:**
For a world size of 32, `replication_jump = 8`, and `replication_factor = 2`,
the replication groups (cliques) are:
0-8, 1-9, 2-10, 3-11, 4-12, 5-13, 6-14, 7-15,
16-24, 17-25, 18-26, 19-27, 20-28, 21-29, 22-30, 23-31
Args:
replication_jump (int, optional): `J` in the formula above. Represents the gap between
successive ranks storing replicas of a given rank's data.
replication_factor (int, optional): `F` in the formula above. Denotes the number of
ranks storing replicas of a given rank's data.
"""
logger.debug(f'Initializing {cls.__name__}')
repl_process_groups_ranks: List[List[int]] = parse_group_sequence(
replication_jump=replication_jump,
@@ -252,8 +285,39 @@ def _eager_build(self) -> EagerT:


class LazyCliqueReplicationStrategy(LazyReplicationStrategyBuilder[CliqueReplicationStrategy]):
"""Lazy version of CliqueReplicationStrategy allowing to delay process group formation."""
"""Lazy version of CliqueReplicationStrategy allowing to delay process group formation.
Training ranks are divided into `W // F` distinct groups of size `F`, where
`W` is the world size
and `F` is the `replication_factor`.
Each group consists of ranks:
`n`, `n + J`, `n + 2J`, ..., `n + (F - 1)J`,
where `J` is the `replication_jump` and `n = aJF + b`, with:
- `a = 0, 1, ..., (W / (JF)) - 1`
- `b = 0, 1, ..., J - 1`.
Checkpoint shards are exchanged and fully replicated within each group.
**Important:** The world size (`W`) must be divisible by `J * F`.
This grouping enables replication across different failure domains by specifying
`J` equal to the failure blast radius.
**Example:**
For a world size of 32, `replication_jump = 8`, and `replication_factor = 2`,
the replication groups (cliques) are:
0-8, 1-9, 2-10, 3-11, 4-12, 5-13, 6-14, 7-15,
16-24, 17-25, 18-26, 19-27, 20-28, 21-29, 22-30, 23-31
Args:
replication_jump (int, optional): `J` in the formula above. Represents the gap between
successive ranks storing replicas of a given rank's data.
replication_factor (int, optional): `F` in the formula above. Denotes the number of
ranks storing replicas of a given rank's data.
"""
def __init__(
self, replication_jump: int = torch.cuda.device_count(), replication_factor: int = 2
):
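An illustrative helper (not part of the library) that reproduces the grouping rule described
in the docstrings above; it only computes the rank cliques and does not create any process groups:

    from typing import List

    def replication_cliques(world_size: int, jump: int, factor: int) -> List[List[int]]:
        """Return replication groups for world size W, replication jump J and factor F."""
        assert world_size % (jump * factor) == 0, "W must be divisible by J * F"
        groups = []
        for a in range(world_size // (jump * factor)):
            for b in range(jump):
                n = a * jump * factor + b
                groups.append([n + i * jump for i in range(factor)])
        return groups

    # replication_cliques(32, 8, 2) yields [0, 8], [1, 9], ..., [7, 15],
    # then [16, 24], [17, 25], ..., [23, 31], matching the example above.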
@@ -48,7 +48,22 @@


class LocalCheckpointCallback(pl.callbacks.ModelCheckpoint):
"""ModelCheckpoint without most of the functionality. Only train_batch_end simple save."""
"""ModelCheckpoint with basic functionality. Only train_batch_end simple save.
Simple callback for initiating local checkpoint save in `on_train_batch_end` method.
Since local checkpoints are ephemeral, they shouldn't be used for "major" checkpoint
types like `on_train_epoch_end`.
This callback must be used in conjunction with the HierarchicalCheckpointIO,
since the only thing this callback really does is passing some options
to `trainer.save_checkpoint` which can be captured with HierarchicalCheckpointIO.
Args:
every_n_train_steps (int, optional): controls local checkpointing interval in terms
of train iterations. Same semantic as in PTL ModelCheckpoint.
train_time_interval (int, optional): controls local checkpointing interval in terms
of wall time. Same semantics as in PTL ModelCheckpoint.
"""

def __init__(
self,
@@ -89,9 +104,13 @@ def _save_last_checkpoint(
class HierarchicalCheckpointIO(_WrappingCheckpointIO):
"""Wrapper for a global CheckpointIO enabling local checkpointing.
Based on the presence of local checkpointing options in saving `storage_options`,
routes the save to the original global CheckpointIO or the local checkpoint manager.
Must be used in conjunction with LocalCheckpointCallback which *initiates*
local checkpoint saving during training.
Args:
wrapped_checkpoint_io (CheckpointIO): global CheckpointIO to wrap
local_ckpt_manager (BaseCheckpointManager): local manager to use for local checkpoints
get_global_ckpt_iteration_fn (Callable[[_PATH], int]): a function that
@@ -175,7 +194,6 @@ def get_partial_wrapper_constructor(

@abstractmethod
def to_tensor_aware_state_dict(self, checkpoint: Dict[str, Any]) -> TensorAwareStateDict:
# TODO: consider providing a default implementation for a simple state dict
raise NotImplementedError

def from_tensor_aware_state_dict(self, tensor_aware_checkpoint: TensorAwareStateDict, **kwargs):
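A rough wiring sketch of how these pieces could fit together in a Lightning training script.
The class names and constructor parameters come from the diffs above, while the import paths
are omitted and the concrete subclass and helper objects (`global_io`, `local_ckpt_manager`,
`get_iteration_from_global_ckpt`) are hypothetical placeholders:

    import pytorch_lightning as pl

    # HierarchicalCheckpointIO is abstract; a concrete subclass must define how a
    # Lightning checkpoint dict maps to a TensorAwareStateDict.
    class MyHierarchicalCheckpointIO(HierarchicalCheckpointIO):
        def to_tensor_aware_state_dict(self, checkpoint):
            return BasicTensorAwareStateDict(checkpoint)

    hierarchical_io = MyHierarchicalCheckpointIO(
        global_io,                       # the global CheckpointIO being wrapped
        local_ckpt_manager,              # e.g. a LocalCheckpointManager instance
        get_iteration_from_global_ckpt,  # maps a global checkpoint path to its iteration
    )

    trainer = pl.Trainer(
        plugins=[hierarchical_io],
        callbacks=[LocalCheckpointCallback(every_n_train_steps=100)],
    )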
2 changes: 0 additions & 2 deletions tests/fault_tolerance/sim-multinode/.env

This file was deleted.

10 changes: 0 additions & 10 deletions tests/fault_tolerance/sim-multinode/README.txt

This file was deleted.
