[FSDP][Docs] Tidy up FSDP ctor/api docs (pytorch#105847)
- This PR rewords the `BackwardPrefetch` docs to make the tradeoffs clear in the first sentence for each setting, with more technical details after.
- The only supported `_FSDPPolicy` is `ModuleWrapPolicy` at the time of writing this PR. We may add others in the future, such as in my other PR stack. This PR removes `_FSDPPolicy` from the public docs (see the usage sketch after this list).
- This PR adds more detail around `MixedPrecision`, such as explaining that layer norm and batch norm accumulate in fp32.
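
For context, a minimal hypothetical sketch (not part of this commit) of how `ModuleWrapPolicy` is typically passed to the FSDP constructor; it assumes a default process group has already been initialized (e.g. via torchrun):

import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import ModuleWrapPolicy

# Wrap each transformer encoder/decoder layer in its own FSDP unit.
model = nn.Transformer(d_model=256, nhead=8)
fsdp_model = FSDP(
    model,
    auto_wrap_policy=ModuleWrapPolicy(
        {nn.TransformerEncoderLayer, nn.TransformerDecoderLayer}
    ),
)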

Follow-ups:
- Why do we force batch norm modules to have FSDP applied separately? (E.g., was this because batch norm kernels previously did not support fp16/bf16?) Like layer norm, this just means that the affine parameters are kept in fp32. Both already accumulate in fp32 even with fp16/bf16 inputs.
- Check the `param_init_fn` + `sync_module_states=True` usage.
Pull Request resolved: pytorch#105847
Approved by: https://github.com/rohan-varma
awgu authored and pytorchmergebot committed Jul 25, 2023
1 parent 65bce81 commit 6655b65
Showing 2 changed files with 142 additions and 127 deletions.
168 changes: 88 additions & 80 deletions torch/distributed/fsdp/api.py
@@ -6,7 +6,7 @@
from dataclasses import dataclass
from enum import auto, Enum

from typing import Optional, Sequence
from typing import Optional, Sequence, Type

import torch
from torch.nn.modules.batchnorm import _BatchNorm
@@ -71,34 +71,39 @@ class ShardingStrategy(Enum):

class BackwardPrefetch(Enum):
"""
This configures explicit backward prefetching, which can improve throughput
but may slightly increase peak memory usage.
For a single process group using NCCL backend, any collectives, even if
issued in different streams, contend for the same per-device NCCL stream,
which is why the relative order in which the collectives are issued matters
for overlapping. The different backward prefetching settings correspond to
different orderings.
- ``BACKWARD_PRE``: This prefetches the next set of parameters before the
current set of parameters' gradient computation. This improves backward
pass throughput by overlapping communication (next all-gather) and
computation (current gradient computation).
- ``BACKWARD_POST``: This prefetches the next set of parameters after the
current set of parameters' gradient computation. This may improve
backward pass throughput by overlapping communication (current
reduce-scatter) and computation (next gradient computation).
Specifically, the next all-gather is reordered to be before the current
reduce-scatter.
.. note:: If the increase in peak memory usage from prefetching is an
issue, you may consider passing ``limit_all_gathers=True`` to the FSDP
constructor, which may help reduce peak memory usage in some cases.
This configures explicit backward prefetching, which improves throughput by
enabling communication and computation overlap in the backward pass at the
cost of slightly increased memory usage.
- ``BACKWARD_PRE``: This enables the most overlap but increases memory
usage the most. This prefetches the next set of parameters *before* the
current set of parameters' gradient computation. This overlaps the *next
all-gather* and the *current gradient computation*, and at the peak, it
holds the current set of parameters, next set of parameters, and current
set of gradients in memory.
- ``BACKWARD_POST``: This enables less overlap but requires less memory.
This prefetches the next set of parameters *after* the current
set of parameters' gradient computation. This overlaps the *current
reduce-scatter* and the *next gradient computation*, and it frees the
current set of parameters before allocating memory for the next set of
parameters, only holding the next set of parameters and current set of
gradients in memory at the peak.
- FSDP's ``backward_prefetch`` argument accepts ``None``, which disables
the backward prefetching altogether. This has no overlap and does not
increase memory usage. In general, we do not recommend this setting since
it may degrade throughput significantly.
For more technical context: For a single process group using NCCL backend,
any collectives, even if issued from different streams, contend for the
same per-device NCCL stream, which implies that the relative order in which
the collectives are issued matters for overlapping. The two backward
prefetching values correspond to different issue orders.
"""

# NOTE: For both modes, the ordering that defines "current" and "next" is
# not always correct in the current implementation, so this may cause some
# performance regression for some models.
# not always exact in the current implementation. A mistargeted prefetch
# simply means that the parameter memory is allocated earlier than needed,
# possibly increasing peak memory usage, but does not affect correctness.
BACKWARD_PRE = auto()
BACKWARD_POST = auto()
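
As an illustration (not part of the diff), a minimal sketch of selecting a prefetching mode at FSDP construction time; the process group and device setup are assumed to exist:

import torch.nn as nn
from torch.distributed.fsdp import BackwardPrefetch, FullyShardedDataParallel as FSDP

# BACKWARD_PRE: most overlap, highest peak memory; BACKWARD_POST: less of both;
# backward_prefetch=None disables prefetching entirely.
fsdp_model = FSDP(
    nn.Linear(1024, 1024),
    backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
    limit_all_gathers=True,  # can help cap peak memory if prefetching raises it
)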

@@ -109,37 +114,50 @@ class MixedPrecision:
This configures FSDP-native mixed precision training.
Attributes:
param_dtype (torch.dtype): This specifies the dtype for model
parameters, inputs (when ``cast_forward_inputs`` or
``cast_root_forward_inputs`` is set to
``True``), and therefore the dtype for computation.
However, outside the forward and backward passes, parameters are in
full precision. Model checkpointing always happens in full
precision.
reduce_dtype (torch.dtype): This specifies the dtype for gradient
reduction, which is permitted to differ from ``param_dtype``.
buffer_dtype (torch.dtype): This specifies the dtype for buffers. FSDP
does not shard buffers, casts them to ``buffer_dtype`` in the first
forward pass, and keeps them in that dtype thereafter. Model
checkpointing always happens in full precision.
keep_low_precision_grads (bool): This specifies whether to upcast
gradients back to the full parameter precision after the backward
pass. This may be set to ``False`` to save memory if using custom
optimizers that can perform the optimizer step in ``reduce_dtype``.
param_dtype (Optional[torch.dtype]): This specifies the dtype for model
parameters during forward and backward and thus the dtype for
forward and backward computation. Outside forward and backward, the
*sharded* parameters are kept in full precision (e.g. for the
optimizer step), and for model checkpointing, the parameters are
always saved in full precision. (Default: ``None``)
reduce_dtype (Optional[torch.dtype]): This specifies the dtype for
gradient reduction (i.e. reduce-scatter or all-reduce). If this is
``None`` but ``param_dtype`` is not ``None``, then this takes on
the ``param_dtype`` value, still running gradient reduction in low
precision. This is permitted to differ from ``param_dtype``, e.g.
to force gradient reduction to run in full precision. (Default:
``None``)
buffer_dtype (Optional[torch.dtype]): This specifies the dtype for
buffers. FSDP does not shard buffers. Rather, FSDP casts them to
``buffer_dtype`` in the first forward pass and keeps them in that
dtype thereafter. For model checkpointing, the buffers are saved
in full precision except for ``LOCAL_STATE_DICT``. (Default:
``None``)
keep_low_precision_grads (bool): If ``False``, then FSDP upcasts
gradients to full precision after the backward pass in preparation
for the optimizer step. If ``True``, then FSDP keeps the gradients
in the dtype used for gradient reduction, which can save memory if
using a custom optimizer that supports running in low precision.
(Default: ``False``)
cast_forward_inputs (bool): Cast floating point tensors in the forward
arguments and keyword arguments to ``param_dtype``.
cast_forward_inputs (bool): If ``True``, then this FSDP module casts
its forward args and kwargs to ``param_dtype``. This is to ensure
that parameter and input dtypes match for forward computation, as
required by many ops. This may need to be set to ``True`` when only
applying mixed precision to some but not all FSDP modules, in which
case a mixed-precision FSDP submodule needs to recast its inputs.
(Default: ``False``)
cast_root_forward_inputs (bool): Cast floating point tensors in the forward
arguments and keyword arguments to ``param_dtype`` for the root FSDP instance.
It takes precedence over ``cast_forward_inputs`` for the root FSDP instance.
(Default: ``True``)
_module_classes_to_ignore (Sequence[type]): Module classes to ignore
for mixed precision. This will make the specified ``nn.Module`` types ignore mixed precision,
by wrapping them in their own FSDP unit and setting ``mixed_precision=None``. Note that
this setting is only relevant for auto wrapping with ``auto_wrap_policy``, and that this
implies the ultimate wrapping of your FSDP module will be different than what the policy
specifies. Note that this API is experimental and subject to change.
cast_root_forward_inputs (bool): If ``True``, then the root FSDP module
casts its forward args and kwargs to ``param_dtype``, overriding
the value of ``cast_forward_inputs``. For non-root FSDP modules,
this does not do anything. (Default: ``True``)
_module_classes_to_ignore (Sequence[Type[nn.Module]]): This specifies
module classes to ignore for mixed precision when using an
``auto_wrap_policy``: Modules of these classes will have FSDP
applied to them separately with mixed precision disabled (meaning
that the final FSDP construction would deviate from the specified
policy). If ``auto_wrap_policy`` is not specified, then this does
not do anything. This API is experimental and subject to change.
(Default: ``(_BatchNorm,)``)
.. note:: This API is experimental and subject to change.
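
As a hypothetical illustration (not part of the diff) of ``cast_forward_inputs`` when mixed precision is applied to only some FSDP modules, the inner mixed-precision module recasts its fp32 inputs itself:

import torch
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, MixedPrecision

model = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 8))
# Only the second linear runs in bf16; it casts its own inputs because the
# fp32 root instance does not cast them for it.
model[1] = FSDP(
    model[1],
    mixed_precision=MixedPrecision(
        param_dtype=torch.bfloat16, cast_forward_inputs=True
    ),
)
model = FSDP(model)  # root FSDP instance without mixed precision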
Expand All @@ -148,27 +166,18 @@ class MixedPrecision:
.. note:: In ``summon_full_params``, parameters are forced to full
precision, but buffers are not.
.. note:: ``state_dict`` checkpoints parameters and buffers in full
precision. For buffers, this is only supported for
``StateDictType.FULL_STATE_DICT``.
.. note:: Layer norm and batch norm accumulate in ``float32`` even when
their inputs are in a low precision like ``float16`` or ``bfloat16``.
Disabling FSDP's mixed precision for those norm modules only means that
the affine parameters are kept in ``float32``. However, this incurs
separate all-gathers and reduce-scatters for those norm modules, which
may be inefficient, so if the workload permits, the user should prefer
to still apply mixed precision to those modules.
.. note:: Each low precision dtype must be specified explicitly. For
example, ``MixedPrecision(reduce_dtype=torch.float16)`` only specifies
the reduction dtype to be low precision, and FSDP will not cast
parameters or buffers.
.. note:: If a ``reduce_dtype`` is not specified, then gradient reduction
happens in ``param_dtype`` if specified or the original parameter dtype
otherwise.
.. note:: If the user passes a model with ``BatchNorm`` modules and an
``auto_wrap_policy`` to the FSDP constructor, then FSDP will disable
mixed precision for ``BatchNorm`` modules by wrapping them separately
in their own FSDP instance with mixed precision disabled. This is due
to some missing low precision ``BatchNorm`` kernels. If the user does
not use an ``auto_wrap_policy``, then the user must take care to not
use mixed precision for FSDP instances containing ``BatchNorm``
modules.
.. note:: By default, if the user passes a model with any ``_BatchNorm``
modules and specifies an ``auto_wrap_policy``, then the batch norm
modules will have FSDP applied to them separately with mixed precision
disabled. See the ``_module_classes_to_ignore`` argument.
.. note:: ``MixedPrecision`` has ``cast_root_forward_inputs=True`` and
``cast_forward_inputs=False`` by default. For the root FSDP instance,
@@ -214,7 +223,7 @@ class MixedPrecision:
keep_low_precision_grads: bool = False
cast_forward_inputs: bool = False
cast_root_forward_inputs: bool = True
_module_classes_to_ignore: Optional[Sequence[type]] = (_BatchNorm,)
_module_classes_to_ignore: Sequence[Type[torch.nn.Module]] = (_BatchNorm,)
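
For illustration (not part of the diff), a minimal sketch of a ``MixedPrecision`` configuration used together with an ``auto_wrap_policy``; the model choice and process-group setup are assumptions:

import torch
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, MixedPrecision
from torch.distributed.fsdp.wrap import ModuleWrapPolicy

# bf16 parameters/buffers for compute, fp32 gradient reduction; batch norm
# modules would be wrapped separately with mixed precision disabled by default.
mp = MixedPrecision(
    param_dtype=torch.bfloat16,
    reduce_dtype=torch.float32,
    buffer_dtype=torch.bfloat16,
)
fsdp_model = FSDP(
    nn.Transformer(d_model=256, nhead=8),
    auto_wrap_policy=ModuleWrapPolicy(
        {nn.TransformerEncoderLayer, nn.TransformerDecoderLayer}
    ),
    mixed_precision=mp,
)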


@dataclass
@@ -224,10 +233,9 @@ class CPUOffload:
Attributes:
offload_params (bool): This specifies whether to offload parameters to
CPU when not involved in computation. If enabled, this implicitly
offloads gradients to CPU as well. This is to support the optimizer
step, which requires parameters and gradients to be on the same
device.
CPU when not involved in computation. If ``True``, then this
offloads gradients to CPU as well, meaning that the optimizer step
runs on CPU.
"""

offload_params: bool = False
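
A minimal sketch (not part of the diff) of enabling parameter CPU offloading; the module is a stand-in and a process group is assumed to be initialized:

import torch.nn as nn
from torch.distributed.fsdp import CPUOffload, FullyShardedDataParallel as FSDP

# Sharded parameters (and therefore gradients) live on CPU outside of
# computation, so the optimizer step runs on CPU.
fsdp_model = FSDP(nn.Linear(8, 8), cpu_offload=CPUOffload(offload_params=True))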
