feat: support custom amp cast list (#685)

The-truthh authored Jul 11, 2023
1 parent 2f24f90 commit d627dc2
Showing 4 changed files with 173 additions and 8 deletions.
8 changes: 8 additions & 0 deletions config.py
@@ -240,6 +240,14 @@ def create_parser():
                            'Choice: O0 - all FP32, O1 - only cast ops in white-list to FP16, '
                            'O2 - cast all ops except for blacklist to FP16, '
                            'O3 - cast all ops to FP16. (default="O0").')
    group.add_argument('--amp_cast_list', type=str, default=None,
                       help='At the cell level, customize the black-list or white-list used to cast cells to '
                            'FP16, depending on the value of "amp_level". If None, use the built-in '
                            'black-list and white-list. (default=None) '
                            'If amp_level="O0" or "O3", this argument has no effect. '
                            'If amp_level="O1", cast all cells in the white-list to FP16. '
                            'If amp_level="O2", cast all cells except those in the black-list to FP16. '
                            'Example: "[nn.Conv1d, nn.Conv2d]" or "[nn.BatchNorm1d, nn.BatchNorm2d]".')
    group.add_argument('--loss_scale_type', type=str, default='fixed',
                       choices=['fixed', 'dynamic', 'auto'],
                       help='The type of loss scale (default="fixed")')
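For context, a minimal sketch (not part of this commit) of the flag's round-trip: the string passed on the command line is later eval()-ed into cell classes inside mindcv/utils/amp.py, so "nn" must resolve in the evaluating module. Only the argument name comes from the diff; the rest is illustrative.

    import argparse

    from mindspore import nn  # needed so eval() can resolve "nn.Conv1d"

    parser = argparse.ArgumentParser()
    parser.add_argument('--amp_cast_list', type=str, default=None)
    args = parser.parse_args(['--amp_cast_list', '[nn.Conv1d, nn.Conv2d]'])

    cast_list = tuple(eval(args.amp_cast_list))  # tuple form suits isinstance() checks
    print(cast_list)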
115 changes: 115 additions & 0 deletions mindcv/utils/amp.py
@@ -1,4 +1,7 @@
""" auto mixed precision related functions """
from mindspore import dtype as mstype
from mindspore import nn
from mindspore.ops import functional as F

# from mindspore.amp import LossScaler  # this line of code leads to "get rank id error" in modelarts
try:
@@ -61,3 +64,115 @@ def unscale(self, inputs):

    def adjust(self, grads_finite):
        return True


AMP_WHITE_LIST = (
    nn.Dense,
    nn.Conv1d,
    nn.Conv2d,
    nn.Conv3d,
    nn.Conv1dTranspose,
    nn.Conv2dTranspose,
    nn.Conv3dTranspose,
)

AMP_BLACK_LIST = (
    nn.BatchNorm1d,
    nn.BatchNorm2d,
    nn.BatchNorm3d,
    nn.LayerNorm,
)


class _OutputTo16(nn.Cell):
    "Wrap cell for amp. Cast network output back to float16"

    def __init__(self, op):
        super(_OutputTo16, self).__init__(auto_prefix=False)
        self._op = op

    def construct(self, x):
        return F.cast(self._op(x), mstype.float16)


class _OutputTo32(nn.Cell):
    "Wrap cell for amp. Cast network output back to float32"

    def __init__(self, op):
        super(_OutputTo32, self).__init__(auto_prefix=False)
        self._op = op

    def construct(self, x):
        return F.cast(self._op(x), mstype.float32)


def _auto_white_list(network, white_list=None):
    """Process the white list of the network."""
    if white_list is None:
        white_list = AMP_WHITE_LIST
    cells = network.name_cells()
    change = False
    for name in cells:
        subcell = cells[name]
        if subcell == network:
            continue
        elif isinstance(subcell, white_list):
            # Run the matched cell in FP16, then cast its output back to FP32.
            network._cells[name] = _OutputTo32(subcell.to_float(mstype.float16))
            change = True
        else:
            _auto_white_list(subcell, white_list)
    if isinstance(network, nn.SequentialCell) and change:
        network.cell_list = list(network.cells())


def _auto_black_list(network, black_list=None):
    """Process the black list of the network."""
    if black_list is None:
        black_list = AMP_BLACK_LIST
    network.to_float(mstype.float16)
    cells = network.name_cells()
    change = False
    for name in cells:
        subcell = cells[name]
        if subcell == network:
            continue
        elif isinstance(subcell, black_list):
            # Keep the matched cell in FP32, then cast its output back to FP16.
            network._cells[name] = _OutputTo16(subcell.to_float(mstype.float32))
            change = True
        else:
            _auto_black_list(subcell, black_list)
    if isinstance(network, nn.SequentialCell) and change:
        network.cell_list = list(network.cells())


def auto_mixed_precision(network, amp_level="O0", amp_cast_list=None):
    """
    Auto mixed precision function.

    Args:
        network (Cell): Definition of the network.
        amp_level (str): Supports ["O0", "O1", "O2", "O3"]. Default: "O0".

            - "O0": Do not change.
            - "O1": Cast the cells in the white-list to float16; keep the remaining cells in float32.
            - "O2": Cast the network to float16; keep the cells in the black-list in float32.
            - "O3": Cast the network to float16.
        amp_cast_list (str, optional): At the cell level, a custom list of cells to cast to FP16,
            e.g. "[nn.Conv1d, nn.Conv2d]". Default: None.

    Raises:
        ValueError: If the amp level is not supported.
    """
    if amp_cast_list is not None:
        amp_cast_list = eval(amp_cast_list)  # e.g. "[nn.Conv1d, nn.Conv2d]" -> [nn.Conv1d, nn.Conv2d]
        if isinstance(amp_cast_list, list):
            amp_cast_list = tuple(amp_cast_list)  # isinstance() checks require a tuple

    if amp_level == "O0":
        pass
    elif amp_level == "O1":
        _auto_white_list(network, amp_cast_list)
    elif amp_level == "O2":
        _auto_black_list(network, amp_cast_list)
    elif amp_level == "O3":
        network.to_float(mstype.float16)
    else:
        raise ValueError("The amp level {} is not supported".format(amp_level))
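A hedged usage sketch of the new function; the toy network below is illustrative, and only auto_mixed_precision and its arguments come from this diff.

    from mindspore import nn

    from mindcv.utils.amp import auto_mixed_precision

    net = nn.SequentialCell([nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU()])

    # O1 with a custom white-list: only nn.Conv2d cells run in FP16, each
    # wrapped in _OutputTo32 so downstream cells still receive FP32.
    auto_mixed_precision(net, amp_level="O1", amp_cast_list="[nn.Conv2d]")

    # O2 with a custom black-list would instead be (on a freshly built network):
    # auto_mixed_precision(net, amp_level="O2", amp_cast_list="[nn.BatchNorm2d]")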
50 changes: 43 additions & 7 deletions mindcv/utils/trainer_factory.py
@@ -1,10 +1,14 @@
import logging
from typing import Union
from typing import Optional, Union

import mindspore as ms
from mindspore import Tensor, context, nn
from mindspore import Tensor, context
from mindspore import dtype as mstype
from mindspore import nn
from mindspore.ops import functional as F
from mindspore.train import DynamicLossScaleManager, FixedLossScaleManager, Model

from .amp import auto_mixed_precision
from .train_step import TrainStep

__all__ = [
@@ -29,22 +33,53 @@ def get_metrics(num_classes):
    return metrics


def require_customized_train_step(ema: bool = False, clip_grad: bool = False, gradient_accumulation_steps: int = 1):
def require_customized_train_step(
    ema: bool = False,
    clip_grad: bool = False,
    gradient_accumulation_steps: int = 1,
    amp_cast_list: Optional[str] = None,
):
    if ema:
        return True
    if clip_grad:
        return True
    if gradient_accumulation_steps > 1:
        return True
    if amp_cast_list:
        return True
    return False


def add_loss_network(network, loss_fn, amp_level):
    """Add loss network."""

    class WithLossCell(nn.Cell):
        "Wrap loss for amp. Cast network output back to float32"

        def __init__(self, backbone, loss_fn):
            super(WithLossCell, self).__init__(auto_prefix=False)
            self._backbone = backbone
            self._loss_fn = loss_fn

        def construct(self, data, label):
            out = self._backbone(data)
            label = F.mixed_precision_cast(mstype.float32, label)
            return self._loss_fn(F.mixed_precision_cast(mstype.float32, out), label)

    if amp_level == "O2" or amp_level == "O3":
        network = WithLossCell(network, loss_fn)
    else:
        network = nn.WithLossCell(network, loss_fn)
    return network


def create_trainer(
    network: nn.Cell,
    loss: nn.Cell,
    optimizer: nn.Cell,
    metrics: Union[dict, set],
    amp_level: str,
    amp_cast_list: str,
    loss_scale_type: str,
    loss_scale: float = 1.0,
    drop_overflow_update: bool = False,
@@ -62,6 +97,7 @@ def create_trainer(
        optimizer: The optimizer for training.
        metrics: The metrics for model evaluation.
        amp_level: The level of auto mixed precision training.
        amp_cast_list: At the cell level, the custom cells to cast to FP16.
        loss_scale_type: The type of loss scale.
        loss_scale: The value of loss scale.
        drop_overflow_update: Whether to execute optimizer if there is an overflow.
@@ -84,7 +120,7 @@ def create_trainer(
    if gradient_accumulation_steps < 1:
        raise ValueError("`gradient_accumulation_steps` must be >= 1!")

    if not require_customized_train_step(ema, clip_grad, gradient_accumulation_steps):
    if not require_customized_train_step(ema, clip_grad, gradient_accumulation_steps, amp_cast_list):
        mindspore_kwargs = dict(
            network=network,
            loss_fn=loss,
@@ -111,8 +147,9 @@
            raise ValueError(f"Loss scale type only support ['fixed', 'dynamic', 'auto'], but got {loss_scale_type}.")
        model = Model(**mindspore_kwargs)
    else:  # require customized train step
        net_with_loss = nn.WithLossCell(network, loss)
        ms.amp.auto_mixed_precision(net_with_loss, amp_level=amp_level)
        eval_network = nn.WithEvalCell(network, loss, amp_level in ["O2", "O3", "auto"])
        auto_mixed_precision(network, amp_level, amp_cast_list)
        net_with_loss = add_loss_network(network, loss, amp_level)
        train_step_kwargs = dict(
            network=net_with_loss,
            optimizer=optimizer,
@@ -145,7 +182,6 @@
        )
        train_step_kwargs["scale_sense"] = update_cell
        train_step_cell = TrainStep(**train_step_kwargs).set_train()
        eval_network = nn.WithEvalCell(network, loss, amp_level in ["O2", "O3", "auto"])
        model = Model(train_step_cell, eval_network=eval_network, metrics=metrics, eval_indexes=[0, 1, 2])
        # todo: do we need to set model._loss_scale_manager
    return model
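The practical effect of the new argument on dispatch, sketched from the function shown above: any non-empty cast list forces the customized TrainStep path, presumably because ms.Model's built-in amp handling cannot apply a custom cast list.

    from mindcv.utils.trainer_factory import require_customized_train_step

    # Defaults: MindSpore's built-in Model(...) training path is used.
    assert not require_customized_train_step()

    # A non-empty cast list (like ema/clip_grad/accumulation) forces TrainStep.
    assert require_customized_train_step(amp_cast_list="[nn.Conv2d]")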
8 changes: 7 additions & 1 deletion train.py
@@ -208,7 +208,12 @@ def train(args):
    if (
        args.loss_scale_type == "fixed"
        and args.drop_overflow_update is False
        and not require_customized_train_step(args.ema, args.clip_grad, args.gradient_accumulation_steps)
        and not require_customized_train_step(
            args.ema,
            args.clip_grad,
            args.gradient_accumulation_steps,
            args.amp_cast_list,
        )
    ):
        optimizer_loss_scale = args.loss_scale
    else:
@@ -236,6 +241,7 @@
        optimizer,
        metrics,
        amp_level=args.amp_level,
        amp_cast_list=args.amp_cast_list,
        loss_scale_type=args.loss_scale_type,
        loss_scale=args.loss_scale,
        drop_overflow_update=args.drop_overflow_update,
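The train.py condition, restated as a self-contained sketch (the else branch is elided in this diff, so the 1.0 fallback below is an assumption, and the toy args only mirror the parser in config.py):

    from types import SimpleNamespace

    from mindcv.utils.trainer_factory import require_customized_train_step

    args = SimpleNamespace(
        loss_scale_type="fixed", drop_overflow_update=False, loss_scale=1024.0,
        ema=False, clip_grad=False, gradient_accumulation_steps=1, amp_cast_list="[nn.Conv2d]",
    )

    # Pass the loss scale to the optimizer only when the built-in trainer path
    # is used with a fixed scale and no overflow handling.
    use_optimizer_loss_scale = (
        args.loss_scale_type == "fixed"
        and args.drop_overflow_update is False
        and not require_customized_train_step(
            args.ema, args.clip_grad, args.gradient_accumulation_steps, args.amp_cast_list
        )
    )
    # With amp_cast_list set, the customized TrainStep path is required,
    # so the optimizer does not receive the loss scale directly.
    optimizer_loss_scale = args.loss_scale if use_optimizer_loss_scale else 1.0  # fallback assumed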
