From e32643a8b2675c5ee5c657747d1cd5108ce89a9b Mon Sep 17 00:00:00 2001
From: Chaoran Wei <77485245+wcrzlh@users.noreply.github.com>
Date: Mon, 10 Jul 2023 10:39:19 +0800
Subject: [PATCH] feat: add jsd loss and asymmetric loss (#682)

---
 mindcv/loss/__init__.py     |   4 +-
 mindcv/loss/asymmetric.py   |  86 ++++++++++++++++++
 mindcv/loss/jsd.py          |  52 +++++++++++
 mindcv/loss/loss_factory.py |   8 ++
 tests/modules/test_loss.py  | 173 +++++++++++++++++++++++++++++++++++-
 5 files changed, 321 insertions(+), 2 deletions(-)
 create mode 100644 mindcv/loss/asymmetric.py
 create mode 100644 mindcv/loss/jsd.py

diff --git a/mindcv/loss/__init__.py b/mindcv/loss/__init__.py
index a3b496aec..fcd521b6c 100644
--- a/mindcv/loss/__init__.py
+++ b/mindcv/loss/__init__.py
@@ -1,7 +1,9 @@
 """ loss init """
-from . import binary_cross_entropy_smooth, cross_entropy_smooth, loss_factory
+from . import asymmetric, binary_cross_entropy_smooth, cross_entropy_smooth, jsd, loss_factory
+from .asymmetric import AsymmetricLossMultilabel, AsymmetricLossSingleLabel
 from .binary_cross_entropy_smooth import BinaryCrossEntropySmooth
 from .cross_entropy_smooth import CrossEntropySmooth
+from .jsd import JSDCrossEntropy
 from .loss_factory import create_loss
 
 __all__ = []
diff --git a/mindcv/loss/asymmetric.py b/mindcv/loss/asymmetric.py
new file mode 100644
index 000000000..493c90aab
--- /dev/null
+++ b/mindcv/loss/asymmetric.py
@@ -0,0 +1,86 @@
+import numpy as np
+
+import mindspore.nn as nn
+from mindspore import Tensor, ops
+
+
+class AsymmetricLossMultilabel(nn.LossBase):
+    def __init__(self, gamma_neg=4, gamma_pos=1, clip=0.05, eps=1e-8):
+        super(AsymmetricLossMultilabel, self).__init__()
+        self.gamma_neg = gamma_neg
+        self.gamma_pos = gamma_pos
+        self.clip = clip
+        self.eps = eps
+
+    def construct(self, logits, labels):
+        """
+        logits: output from models
+        labels: multi-label binarized vector
+        """
+        x_sigmoid = ops.Sigmoid()(logits)
+        xs_pos = x_sigmoid
+        xs_neg = 1 - x_sigmoid
+
+        if self.clip > 0:
+            xs_neg = ops.clip_by_value(xs_neg + self.clip, clip_value_max=Tensor(1.0))
+
+        los_pos = labels * ops.log(ops.clip_by_value(xs_pos, clip_value_min=Tensor(self.eps)))
+        los_neg = (1 - labels) * ops.log(ops.clip_by_value(xs_neg, clip_value_min=Tensor(self.eps)))
+
+        loss = los_pos + los_neg
+
+        if self.gamma_pos > 0 or self.gamma_neg > 0:  # focusing applies if either gamma is positive
+            pt0 = xs_pos * labels
+            pt1 = xs_neg * (1 - labels)
+            pt = pt0 + pt1
+            one_sided_gamma = self.gamma_pos * labels + self.gamma_neg * (1 - labels)
+            one_sided_w = ops.pow(1 - pt, one_sided_gamma)
+
+            loss *= one_sided_w
+
+        return -loss.sum()
+
+
+class AsymmetricLossSingleLabel(nn.LossBase):
+    def __init__(self, gamma_pos=1, gamma_neg=4, eps=0.1, reduction="mean", smoothing=0.1):
+        super(AsymmetricLossSingleLabel, self).__init__()
+
+        self.eps = eps
+        self.logsoftmax = nn.LogSoftmax(axis=-1)
+        self.targets_classes = []
+        self.gamma_pos = gamma_pos
+        self.gamma_neg = gamma_neg
+        self.reduction = reduction
+        self.smoothing = smoothing
+
+    def construct(self, logits, labels):
+        num_classes = logits.shape[-1]
+        log_preds = self.logsoftmax(logits)
+        labels_e = ops.ExpandDims()(labels, 1)
+        labels_e_shape = labels_e.shape
+        targets = ops.tensor_scatter_elements(
+            ops.ZerosLike()(logits), labels_e, Tensor(np.ones(labels_e_shape, dtype=np.float32)), 1
+        )
+
+        anti_targets = 1 - targets
+        xs_pos = ops.exp(log_preds)
+        xs_neg = 1 - xs_pos
+        xs_pos = xs_pos * targets
+        xs_neg = xs_neg * anti_targets
+
+        asymmetric_w = ops.pow(1 - xs_pos - xs_neg, self.gamma_pos * targets + self.gamma_neg * anti_targets)
+
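+        # Asymmetric focusing: weight each class by (1 - p_t)^gamma, where the
+        # target class uses gamma_pos and all other classes use gamma_neg, so
+        # well-classified negatives are down-weighted more aggressively.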
+        log_preds = log_preds * asymmetric_w
+
+        targets = targets * (1 - self.smoothing) + self.smoothing / num_classes
+
+        loss = -targets * log_preds
+        loss = ops.ReduceSum()(loss, -1)
+
+        if self.reduction == "mean":
+            loss = loss.mean()
+
+        return loss
diff --git a/mindcv/loss/jsd.py b/mindcv/loss/jsd.py
new file mode 100644
index 000000000..46f96950e
--- /dev/null
+++ b/mindcv/loss/jsd.py
@@ -0,0 +1,52 @@
+from mindspore import nn, ops
+
+from .cross_entropy_smooth import CrossEntropySmooth
+
+
+class JSDCrossEntropy(nn.LossBase):
+    """
+    JSD loss, implemented according to "AugMix: A Simple Data Processing Method to Improve Robustness and Uncertainty",
+    https://arxiv.org/abs/1912.02781
+
+    Please note that JSD loss should only be used when training with "aug_splits = 3".
+    """
+
+    def __init__(self, num_splits=3, alpha=12, smoothing=0.1, weight=None, reduction="mean", aux_factor=0.0):
+        super().__init__()
+        self.num_splits = num_splits
+        self.alpha = alpha
+        self.smoothing = smoothing
+        self.weight = weight
+        self.reduction = reduction
+        self.kldiv = ops.KLDivLoss(reduction="batchmean")
+        self.map = ops.Map()
+
+        self.softmax = ops.Softmax(axis=1)
+        self.aux_factor = aux_factor
+
+    def construct(self, logits, labels):
+        if self.training:
+            split_size = logits.shape[0] // self.num_splits
+            log_split = ops.split(logits, 0, self.num_splits)
+
+            loss = ops.cross_entropy(
+                log_split[0],
+                labels[:split_size],
+                weight=self.weight,
+                reduction=self.reduction,
+                label_smoothing=self.smoothing,
+            )
+
+            probs = self.map(self.softmax, log_split)
+            stack_probs = ops.stack(probs)
+            clip_probs = ops.clip_by_value(stack_probs.mean(axis=0), 1e-7, 1)
+            log_probs = ops.log(clip_probs)
+
+            for p_split in probs:
+                loss += self.alpha * self.kldiv(log_probs, p_split) / self.num_splits
+
+            return loss
+        else:
+            return CrossEntropySmooth(
+                smoothing=self.smoothing, aux_factor=self.aux_factor, reduction=self.reduction, weight=self.weight
+            )(logits, labels)
diff --git a/mindcv/loss/loss_factory.py b/mindcv/loss/loss_factory.py
index c1582c4ad..54659ec95 100644
--- a/mindcv/loss/loss_factory.py
+++ b/mindcv/loss/loss_factory.py
@@ -3,8 +3,10 @@
 
 from mindspore import Tensor
 
+from .asymmetric import AsymmetricLossMultilabel, AsymmetricLossSingleLabel
 from .binary_cross_entropy_smooth import BinaryCrossEntropySmooth
 from .cross_entropy_smooth import CrossEntropySmooth
+from .jsd import JSDCrossEntropy
 
 __all__ = ["create_loss"]
 
@@ -50,6 +52,12 @@ def create_loss(
         loss = BinaryCrossEntropySmooth(
             smoothing=label_smoothing, aux_factor=aux_factor, reduction=reduction, weight=weight, pos_weight=None
         )
+    elif name == "asl_single_label":
+        loss = AsymmetricLossSingleLabel(smoothing=label_smoothing)
+    elif name == "asl_multi_label":
+        loss = AsymmetricLossMultilabel()
+    elif name == "jsd":
+        loss = JSDCrossEntropy(smoothing=label_smoothing, aux_factor=aux_factor, reduction=reduction, weight=weight)
     else:
         raise NotImplementedError
 
diff --git a/tests/modules/test_loss.py b/tests/modules/test_loss.py
index db756bfbf..6b031631a 100644
--- a/tests/modules/test_loss.py
+++ b/tests/modules/test_loss.py
@@ -6,7 +6,7 @@
 import pytest
 
 import mindspore as ms
-from mindspore import nn
+from mindspore import Tensor, nn
 from mindspore.common.initializer import Normal
 from mindspore.nn import TrainOneStepCell, WithLossCell
 
@@ -120,5 +120,176 @@ def test_loss(mode, name, reduction, label_smoothing, aux_factor, weight, double_aux):
     assert cur_loss < begin_loss, "Loss does NOT decrease"
 
 
+@pytest.mark.parametrize("mode", [0, 1])
+@pytest.mark.parametrize("name", ["asl_single_label"])
["asl_single_label"]) +@pytest.mark.parametrize("label_smoothing", [0.0, 0.1]) +def test_asl(mode, name, label_smoothing, reduction="mean", aux_factor=0.0, weight=None, double_aux=False): + weight = None + print( + f"mode={mode}; loss_name={name}; has_weight=False; reduction={reduction};\ + label_smoothing={label_smoothing}; aux_factor={aux_factor}" + ) + ms.set_context(mode=mode) + + bs = 8 + num_classes = c = 10 + # create data + x = ms.Tensor(np.random.randn(bs, 1, 32, 32), ms.float32) + # logits = ms.Tensor(np.random.rand(bs, c), ms.float32) + y = np.random.randint(0, c, size=(bs)) + y_onehot = np.eye(c)[y] + y = ms.Tensor(y, ms.int32) # C + y_onehot = ms.Tensor(y_onehot, ms.float32) # N, C + if name == "BCE": + label = y_onehot + else: + label = y + + if weight is not None: + weight = np.random.randn(c) + weight = weight / weight.sum() # normalize + weight = ms.Tensor(weight, ms.float32) + + # set network + aux_head = aux_factor > 0.0 + aux_head2 = aux_head and double_aux + network = SimpleCNN(in_channels=1, num_classes=num_classes, aux_head=aux_head, aux_head2=aux_head2) + + # set loss + net_loss = create_loss( + name=name, weight=weight, reduction=reduction, label_smoothing=label_smoothing, aux_factor=aux_factor + ) + + # optimize + net_with_loss = WithLossCell(network, net_loss) + + net_opt = create_optimizer(network.trainable_params(), "adam", lr=0.001, weight_decay=1e-7) + train_network = TrainOneStepCell(net_with_loss, net_opt) + + train_network.set_train() + + begin_loss = train_network(x, label) + for _ in range(10): + cur_loss = train_network(x, label) + + print("begin loss: {}, end loss: {}".format(begin_loss, cur_loss)) + + assert cur_loss < begin_loss, "Loss does NOT decrease" + + +@pytest.mark.parametrize("mode", [0, 1]) +@pytest.mark.parametrize("name", ["asl_single_label"]) +def test_asl_single_label_random(mode, name, reduction="mean", label_smoothing=0.1, aux_factor=0.0, weight=None): + weight = None + print( + f"mode={mode}; loss_name={name}; has_weight=False; reduction={reduction};\ + label_smoothing={label_smoothing}; aux_factor={aux_factor}" + ) + ms.set_context(mode=mode) + + net_loss = create_loss( + name=name, weight=weight, reduction=reduction, label_smoothing=label_smoothing, aux_factor=aux_factor + ) + + # logits and labels + logits = Tensor( + [ + [0.38317937, 0.82873726, 0.8164871, 0.6443424], + [0.77863216, 0.17288171, 0.69345415, 0.26514006], + [0.14249292, 0.38524792, 0.97271717, 0.90531427], + ], + ms.float32, + ) + labels = Tensor([1, 1, 1], ms.int32) + + output_expected = Tensor(1.1247127, ms.float32) + + output = net_loss(logits, labels) + + assert np.allclose(output_expected.asnumpy(), output.asnumpy()) + + +@pytest.mark.parametrize("mode", [0, 1]) +@pytest.mark.parametrize("name", ["asl_single_label"]) +def test_asl_single_label_zero(mode, name, reduction="mean", label_smoothing=0.1, aux_factor=0.0, weight=None): + weight = None + print( + f"mode={mode}; loss_name={name}; has_weight=False; reduction={reduction};\ + label_smoothing={label_smoothing}; aux_factor={aux_factor}" + ) + ms.set_context(mode=mode) + + net_loss = create_loss( + name=name, weight=weight, reduction=reduction, label_smoothing=label_smoothing, aux_factor=aux_factor + ) + + # logits and labels + logits = Tensor(np.zeros((3, 5)), ms.float32) + labels = Tensor(np.zeros((3,)), ms.int32) + + output_expected = Tensor(1.1847522, ms.float32) + + output = net_loss(logits, labels) + + assert np.allclose(output_expected.asnumpy(), output.asnumpy()) + + +@pytest.mark.parametrize("mode", 
+@pytest.mark.parametrize("name", ["asl_multi_label"])
+def test_asl_multi_label_random(mode, name, reduction="mean", label_smoothing=0.1, aux_factor=0.0, weight=None):
+    weight = None
+    print(
+        f"mode={mode}; loss_name={name}; has_weight=False; reduction={reduction};\
+        label_smoothing={label_smoothing}; aux_factor={aux_factor}"
+    )
+    ms.set_context(mode=mode)
+
+    net_loss = create_loss(
+        name=name, weight=weight, reduction=reduction, label_smoothing=label_smoothing, aux_factor=aux_factor
+    )
+
+    # logits and labels
+    logits = Tensor(
+        [
+            [0.38317937, 0.82873726, 0.8164871, 0.6443424],
+            [0.77863216, 0.17288171, 0.69345415, 0.26514006],
+            [0.14249292, 0.38524792, 0.97271717, 0.90531427],
+        ],
+        ms.float32,
+    )
+    labels = Tensor([[1, 1, 0, 0], [0, 0, 1, 1], [1, 0, 1, 0]], ms.int32)
+
+    output_expected = Tensor(1.8657642, ms.float32)
+
+    output = net_loss(logits, labels)
+
+    assert np.allclose(output_expected.asnumpy(), output.asnumpy())
+
+
+@pytest.mark.parametrize("mode", [0, 1])
+@pytest.mark.parametrize("name", ["asl_multi_label"])
+def test_asl_multi_label_zero(mode, name, reduction="mean", label_smoothing=0.1, aux_factor=0.0, weight=None):
+    weight = None
+    print(
+        f"mode={mode}; loss_name={name}; has_weight=False; reduction={reduction};\
+        label_smoothing={label_smoothing}; aux_factor={aux_factor}"
+    )
+    ms.set_context(mode=mode)
+
+    net_loss = create_loss(
+        name=name, weight=weight, reduction=reduction, label_smoothing=label_smoothing, aux_factor=aux_factor
+    )
+
+    # logits and labels
+    logits = Tensor(np.zeros((3, 5)), ms.float32)
+    labels = Tensor(np.zeros((3, 5)), ms.int32)
+
+    output_expected = Tensor(0.3677258, ms.float32)
+
+    output = net_loss(logits, labels)
+
+    assert np.allclose(output_expected.asnumpy(), output.asnumpy())
+
+
 if __name__ == "__main__":
     test_loss(0, "BCE", "mean", 0.1, 0.1, None, True)
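-- 
Usage sketch (not part of the diff; assumes a MindCV environment with this
patch applied — shapes, values, and batch sizes below are illustrative only):

    import numpy as np
    import mindspore as ms
    from mindcv.loss import create_loss

    logits = ms.Tensor(np.random.randn(8, 10), ms.float32)

    # "asl_single_label" expects class-index labels, like CE.
    asl = create_loss(name="asl_single_label", label_smoothing=0.1)
    print(asl(logits, ms.Tensor(np.random.randint(0, 10, size=(8,)), ms.int32)))

    # "asl_multi_label" expects a multi-hot target with the same shape as logits.
    asl_ml = create_loss(name="asl_multi_label")
    print(asl_ml(logits, ms.Tensor(np.random.randint(0, 2, size=(8, 10)), ms.float32)))

    # "jsd" assumes the batch stacks num_splits augmentation views of the same
    # images (aug_splits = 3), so the batch axis must be divisible by 3.
    jsd = create_loss(name="jsd", label_smoothing=0.1)
    jsd.set_train(True)  # the JSD branch only runs in training mode
    print(jsd(ms.Tensor(np.random.randn(24, 10), ms.float32),
              ms.Tensor(np.random.randint(0, 10, size=(24,)), ms.int32)))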