diff --git a/config.py b/config.py
index e0208b07..b78e5a9b 100644
--- a/config.py
+++ b/config.py
@@ -240,6 +240,14 @@ def create_parser():
                             'Choice: O0 - all FP32, O1 - only cast ops in white-list to FP16, '
                             'O2 - cast all ops except for blacklist to FP16, '
                             'O3 - cast all ops to FP16. (default="O0").')
+    group.add_argument('--amp_cast_list', type=str, default=None,
+                       help='At the cell level, customize the black-list or white-list used to cast cells to '
+                            'FP16, based on the value of the "amp_level" argument. If None, use the built-in '
+                            'black-list and white-list. (default=None) '
+                            'If amp_level="O0" or "O3", this argument has no effect. '
+                            'If amp_level="O1", cast all cells in the white-list to FP16. '
+                            'If amp_level="O2", cast all cells except for the black-list to FP16. '
+                            'Example: "[nn.Conv1d, nn.Conv2d]" or "[nn.BatchNorm1d, nn.BatchNorm2d]".')
     group.add_argument('--loss_scale_type', type=str, default='fixed',
                        choices=['fixed', 'dynamic', 'auto'],
                        help='The type of loss scale (default="fixed")')
diff --git a/mindcv/utils/amp.py b/mindcv/utils/amp.py
index 92ee5177..391b718d 100644
--- a/mindcv/utils/amp.py
+++ b/mindcv/utils/amp.py
@@ -1,4 +1,7 @@
 """ auto mixed precision related functions """
+from mindspore import dtype as mstype
+from mindspore import nn
+from mindspore.ops import functional as F
 
 # from mindspore.amp import LossScaler  # this line of code leads to “get rank id error” in modelarts
 try:
@@ -61,3 +64,115 @@ def unscale(self, inputs):
 
     def adjust(self, grads_finite):
         return True
+
+
+AMP_WHITE_LIST = (
+    nn.Dense,
+    nn.Conv1d,
+    nn.Conv2d,
+    nn.Conv3d,
+    nn.Conv1dTranspose,
+    nn.Conv2dTranspose,
+    nn.Conv3dTranspose,
+)
+
+AMP_BLACK_LIST = (
+    nn.BatchNorm1d,
+    nn.BatchNorm2d,
+    nn.BatchNorm3d,
+    nn.LayerNorm,
+)
+
+
+class _OutputTo16(nn.Cell):
+    "Wrap cell for amp. Cast network output back to float16"
+
+    def __init__(self, op):
+        super(_OutputTo16, self).__init__(auto_prefix=False)
+        self._op = op
+
+    def construct(self, x):
+        return F.cast(self._op(x), mstype.float16)
+
+
+class _OutputTo32(nn.Cell):
+    "Wrap cell for amp. Cast network output back to float32"
+
+    def __init__(self, op):
+        super(_OutputTo32, self).__init__(auto_prefix=False)
+        self._op = op
+
+    def construct(self, x):
+        return F.cast(self._op(x), mstype.float32)
+
+
+def _auto_white_list(network, white_list=None):
+    """Process the white list of the network."""
+    if white_list is None:
+        white_list = AMP_WHITE_LIST
+    cells = network.name_cells()
+    change = False
+    for name in cells:
+        subcell = cells[name]
+        if subcell == network:
+            continue
+        elif isinstance(subcell, white_list):
+            network._cells[name] = _OutputTo32(subcell.to_float(mstype.float16))
+            change = True
+        else:
+            _auto_white_list(subcell, white_list)
+    if isinstance(network, nn.SequentialCell) and change:
+        network.cell_list = list(network.cells())
+
+
+def _auto_black_list(network, black_list=None):
+    """Process the black list of the network."""
+    if black_list is None:
+        black_list = AMP_BLACK_LIST
+    network.to_float(mstype.float16)
+    cells = network.name_cells()
+    change = False
+    for name in cells:
+        subcell = cells[name]
+        if subcell == network:
+            continue
+        elif isinstance(subcell, black_list):
+            network._cells[name] = _OutputTo16(subcell.to_float(mstype.float32))
+            change = True
+        else:
+            _auto_black_list(subcell, black_list)
+    if isinstance(network, nn.SequentialCell) and change:
+        network.cell_list = list(network.cells())
+
+
+def auto_mixed_precision(network, amp_level="O0", amp_cast_list=None):
+    """
+    Auto mixed precision function.
+    Args:
+        network (Cell): Definition of the network.
+        amp_level (str): Supports ["O0", "O1", "O2", "O3"]. Default: "O0".
+
+            - "O0": Do not change.
+            - "O1": Cast the operators in white_list to float16; the remaining operators are kept in float32.
+            - "O2": Cast the network to float16; keep the operators in black_list running in float32.
+            - "O3": Cast the network to float16.
+        amp_cast_list: At the cell level, a custom list of cells to cast to FP16.
+
+    Raises:
+        ValueError: If the amp level is not supported.
+    """
+    if amp_cast_list is not None:
+        amp_cast_list = eval(amp_cast_list)
+        if isinstance(amp_cast_list, list):
+            amp_cast_list = tuple(amp_cast_list)
+
+    if amp_level == "O0":
+        pass
+    elif amp_level == "O1":
+        _auto_white_list(network, amp_cast_list)
+    elif amp_level == "O2":
+        _auto_black_list(network, amp_cast_list)
+    elif amp_level == "O3":
+        network.to_float(mstype.float16)
+    else:
+        raise ValueError("The amp level {} is not supported".format(amp_level))
diff --git a/mindcv/utils/trainer_factory.py b/mindcv/utils/trainer_factory.py
index 5679a92d..db47a48e 100644
--- a/mindcv/utils/trainer_factory.py
+++ b/mindcv/utils/trainer_factory.py
@@ -1,10 +1,14 @@
 import logging
-from typing import Union
+from typing import Optional, Union
 
 import mindspore as ms
-from mindspore import Tensor, context, nn
+from mindspore import Tensor, context
+from mindspore import dtype as mstype
+from mindspore import nn
+from mindspore.ops import functional as F
 from mindspore.train import DynamicLossScaleManager, FixedLossScaleManager, Model
 
+from .amp import auto_mixed_precision
 from .train_step import TrainStep
 
 __all__ = [
@@ -29,22 +33,53 @@ def get_metrics(num_classes):
     return metrics
 
 
-def require_customized_train_step(ema: bool = False, clip_grad: bool = False, gradient_accumulation_steps: int = 1):
+def require_customized_train_step(
+    ema: bool = False,
+    clip_grad: bool = False,
+    gradient_accumulation_steps: int = 1,
+    amp_cast_list: Optional[str] = None,
+):
     if ema:
         return True
     if clip_grad:
         return True
     if gradient_accumulation_steps > 1:
         return True
+    if amp_cast_list:
+        return True
     return False
 
 
+def add_loss_network(network, loss_fn, amp_level):
+    """Add a loss network."""
+
+    class WithLossCell(nn.Cell):
+        "Wrap loss for amp. Cast network output back to float32"
+
+        def __init__(self, backbone, loss_fn):
+            super(WithLossCell, self).__init__(auto_prefix=False)
+            self._backbone = backbone
+            self._loss_fn = loss_fn
+
+        def construct(self, data, label):
+            out = self._backbone(data)
+            label = F.mixed_precision_cast(mstype.float32, label)
+            return self._loss_fn(F.mixed_precision_cast(mstype.float32, out), label)
+
+    if amp_level == "O2" or amp_level == "O3":
+        network = WithLossCell(network, loss_fn)
+    else:
+        network = nn.WithLossCell(network, loss_fn)
+    return network
+
+
 def create_trainer(
     network: nn.Cell,
     loss: nn.Cell,
     optimizer: nn.Cell,
     metrics: Union[dict, set],
     amp_level: str,
+    amp_cast_list: str,
     loss_scale_type: str,
     loss_scale: float = 1.0,
     drop_overflow_update: bool = False,
@@ -62,6 +97,7 @@ def create_trainer(
         optimizer: The optimizer for training.
         metrics: The metrics for model evaluation.
         amp_level: The level of auto mixing precision training.
+        amp_cast_list: At the cell level, the custom list of cells to cast to FP16.
         loss_scale_type: The type of loss scale.
         loss_scale: The value of loss scale.
         drop_overflow_update: Whether to execute optimizer if there is an overflow.
@@ -84,7 +120,7 @@ def create_trainer(
     if gradient_accumulation_steps < 1:
         raise ValueError("`gradient_accumulation_steps` must be >= 1!")
 
-    if not require_customized_train_step(ema, clip_grad, gradient_accumulation_steps):
+    if not require_customized_train_step(ema, clip_grad, gradient_accumulation_steps, amp_cast_list):
         mindspore_kwargs = dict(
             network=network,
             loss_fn=loss,
@@ -111,8 +147,9 @@
             raise ValueError(f"Loss scale type only support ['fixed', 'dynamic', 'auto'], but got{loss_scale_type}.")
         model = Model(**mindspore_kwargs)
     else:  # require customized train step
-        net_with_loss = nn.WithLossCell(network, loss)
-        ms.amp.auto_mixed_precision(net_with_loss, amp_level=amp_level)
+        eval_network = nn.WithEvalCell(network, loss, amp_level in ["O2", "O3", "auto"])
+        auto_mixed_precision(network, amp_level, amp_cast_list)
+        net_with_loss = add_loss_network(network, loss, amp_level)
         train_step_kwargs = dict(
             network=net_with_loss,
             optimizer=optimizer,
@@ -145,7 +182,6 @@
             )
             train_step_kwargs["scale_sense"] = update_cell
         train_step_cell = TrainStep(**train_step_kwargs).set_train()
-        eval_network = nn.WithEvalCell(network, loss, amp_level in ["O2", "O3", "auto"])
         model = Model(train_step_cell, eval_network=eval_network, metrics=metrics, eval_indexes=[0, 1, 2])
         # todo: do we need to set model._loss_scale_manager
     return model
diff --git a/train.py b/train.py
index ed22748d..644948a5 100644
--- a/train.py
+++ b/train.py
@@ -208,7 +208,12 @@ def train(args):
     if (
         args.loss_scale_type == "fixed"
         and args.drop_overflow_update is False
-        and not require_customized_train_step(args.ema, args.clip_grad, args.gradient_accumulation_steps)
+        and not require_customized_train_step(
+            args.ema,
+            args.clip_grad,
+            args.gradient_accumulation_steps,
+            args.amp_cast_list,
+        )
     ):
         optimizer_loss_scale = args.loss_scale
     else:
@@ -236,6 +241,7 @@ def train(args):
         optimizer,
         metrics,
         amp_level=args.amp_level,
+        amp_cast_list=args.amp_cast_list,
         loss_scale_type=args.loss_scale_type,
         loss_scale=args.loss_scale,
         drop_overflow_update=args.drop_overflow_update,
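
Usage note (not part of the diff): a minimal sketch of how the new cast-list option is expected to be exercised, both through the Python API added in mindcv/utils/amp.py and through the --amp_cast_list flag added in config.py. The toy network and the exact cells listed below are illustrative assumptions, not code from this PR.

    # Minimal sketch; the SequentialCell below is a toy assumption, not part of the PR.
    from mindspore import nn

    from mindcv.utils.amp import auto_mixed_precision

    net = nn.SequentialCell([nn.Conv2d(3, 16, 3), nn.BatchNorm2d(16), nn.ReLU()])

    # amp_level="O1" with a custom white-list: only the listed cell types are cast to FP16
    # (their outputs are cast back to FP32). The string is eval'ed inside mindcv/utils/amp.py,
    # where `nn` is imported, so "nn.*" names resolve there.
    auto_mixed_precision(net, amp_level="O1", amp_cast_list="[nn.Conv2d]")

    # Equivalent via the command line (other training flags omitted):
    #   python train.py --amp_level O1 --amp_cast_list "[nn.Conv2d]"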