Don't recompute scaling factor in activation checkpointing
Summary:
Add the policy "layer_based_auto_wrap_policy_float8_training". It skips recomputing the float8 scaling factor (a scalar) during activation checkpointing, which reduces latency.

To enable it, update the config file as shown in P1690229394.
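For illustration only, here is a minimal sketch of the user-facing side (this is not the internal config referenced above): it builds a context_fn with PyTorch's selective activation checkpointing helpers in torch.utils.checkpoint (assumed available in recent PyTorch releases) so that selected ops are saved rather than recomputed. The op list below is a placeholder for whichever ops produce the float8 scaling factor.

import torch
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import CheckpointImpl
from torch.utils.checkpoint import (
    CheckpointPolicy,
    create_selective_checkpoint_contexts,
)
from torchtnt.utils.prepare_module import ActivationCheckpointParams

# Placeholder op list: substitute the op(s) that actually compute the float8 scale.
_OPS_TO_SAVE = [torch.ops.aten.abs.default, torch.ops.aten.max.default]


def _save_scale_policy(ctx, op, *args, **kwargs):
    # Keep the outputs of the listed ops from the forward pass instead of
    # recomputing them during the checkpointed backward.
    if op in _OPS_TO_SAVE:
        return CheckpointPolicy.MUST_SAVE
    return CheckpointPolicy.PREFER_RECOMPUTE


ac_params = ActivationCheckpointParams(
    checkpoint_impl=CheckpointImpl.NO_REENTRANT,
    # A fresh pair of context managers must be created on every forward call.
    context_fn=lambda: create_selective_checkpoint_contexts(_save_scale_policy),
)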

Reviewed By: yoyoyocmu

Differential Revision: D65360604
y-sq authored and facebook-github-bot committed Dec 3, 2024
1 parent dc5bafd commit 8d28ecc
Showing 1 changed file with 19 additions and 1 deletion.
20 changes: 19 additions & 1 deletion torchtnt/utils/prepare_module.py
@@ -8,7 +8,17 @@
 
 from dataclasses import asdict, dataclass
 from functools import partial
-from typing import Any, Callable, cast, Dict, Iterable, Optional, Union
+from typing import (
+    Any,
+    Callable,
+    cast,
+    ContextManager,
+    Dict,
+    Iterable,
+    Optional,
+    Tuple,
+    Union,
+)
 
 import torch
 import torch.distributed as dist
@@ -165,6 +175,9 @@ class ActivationCheckpointParams:
     checkpoint_impl: CheckpointImpl
     check_fn: Callable[[torch.nn.Module], bool] = lambda _: True
     auto_wrap_policy: Optional[Callable[[torch.nn.Module, bool, int], bool]] = None
+    context_fn: Optional[Callable[[], Tuple[ContextManager, ContextManager]]] = (
+        None  # pyre-ignore
+    )
 
 
 def prepare_ddp(
@@ -357,9 +370,14 @@ def prepare_module(
         checkpoint_impl = activation_checkpoint_params.checkpoint_impl
         check_fn = activation_checkpoint_params.check_fn
         auto_wrap_policy = activation_checkpoint_params.auto_wrap_policy
+        context_fn = activation_checkpoint_params.context_fn
+        additional_params = {}
+        if context_fn:
+            additional_params["context_fn"] = context_fn
         custom_checkpoint_wrapper = partial(
             checkpoint_wrapper,
             checkpoint_impl=checkpoint_impl,
+            **additional_params,
         )
         apply_activation_checkpointing(
             module,
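For context, the new branch above threads context_fn into checkpoint_wrapper as an extra keyword argument, which checkpoint_wrapper passes through to the underlying non-reentrant torch.utils.checkpoint.checkpoint call. A rough hand-written equivalent follows (a sketch only; the tiny model and the addmm op are toy placeholders, not part of this change):

from functools import partial

import torch
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    apply_activation_checkpointing,
    checkpoint_wrapper,
    CheckpointImpl,
)
from torch.utils.checkpoint import create_selective_checkpoint_contexts

# Toy stand-in for a real model.
model = torch.nn.Sequential(
    torch.nn.Linear(8, 8),
    torch.nn.ReLU(),
    torch.nn.Linear(8, 8),
)


def make_contexts():
    # Save (don't recompute) addmm outputs; purely illustrative.
    return create_selective_checkpoint_contexts([torch.ops.aten.addmm.default])


# Roughly what prepare_module builds when a context_fn is configured:
wrapper = partial(
    checkpoint_wrapper,
    checkpoint_impl=CheckpointImpl.NO_REENTRANT,
    context_fn=make_contexts,
)
apply_activation_checkpointing(model, checkpoint_wrapper_fn=wrapper)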
