feat: add jsd loss and asymmetric loss (mindspore-lab#682)
wcrzlh authored Jul 10, 2023
1 parent c391b57 commit e32643a
Showing 5 changed files with 318 additions and 2 deletions.
4 changes: 3 additions & 1 deletion mindcv/loss/__init__.py
@@ -1,7 +1,9 @@
""" loss init """
from . import binary_cross_entropy_smooth, cross_entropy_smooth, loss_factory
from . import asymmetric, binary_cross_entropy_smooth, cross_entropy_smooth, jsd, loss_factory
from .asymmetric import AsymmetricLossMultilabel, AsymmetricLossSingleLabel
from .binary_cross_entropy_smooth import BinaryCrossEntropySmooth
from .cross_entropy_smooth import CrossEntropySmooth
from .jsd import JSDCrossEntropy
from .loss_factory import create_loss

__all__ = []
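With these module-level imports in place, the new classes are reachable from the package root; a usage sketch, not part of the diff:

from mindcv.loss import AsymmetricLossMultilabel, AsymmetricLossSingleLabel, JSDCrossEntropy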
83 changes: 83 additions & 0 deletions mindcv/loss/asymmetric.py
@@ -0,0 +1,83 @@
import numpy as np

import mindspore.nn as nn
from mindspore import Tensor, ops


class AsymmetricLossMultilabel(nn.LossBase):
    """Asymmetric loss for multi-label classification, following
    "Asymmetric Loss For Multi-Label Classification" (https://arxiv.org/abs/2009.14119).
    """

    def __init__(self, gamma_neg=4, gamma_pos=1, clip=0.05, eps=1e-8):
        super(AsymmetricLossMultilabel, self).__init__()
        self.gamma_neg = gamma_neg
        self.gamma_pos = gamma_pos
        self.clip = clip
        self.eps = eps

    def construct(self, logits, labels):
        """
        logits: output from models
        labels: multi-label binarized vector
        """
        x_sigmoid = ops.Sigmoid()(logits)
        xs_pos = x_sigmoid
        xs_neg = 1 - x_sigmoid

        if self.clip > 0:
            # shift the negative probability by the clip margin, capped at 1
            xs_neg = ops.clip_by_value(xs_neg + self.clip, clip_value_max=Tensor(1.0))

        los_pos = labels * ops.log(ops.clip_by_value(xs_pos, clip_value_min=Tensor(self.eps)))
        los_neg = (1 - labels) * ops.log(ops.clip_by_value(xs_neg, clip_value_min=Tensor(self.eps)))

        loss = los_pos + los_neg

        if self.gamma_pos > 0 and self.gamma_neg > 0:
            # one-sided focusing: gamma_pos on positives, gamma_neg on negatives
            pt0 = xs_pos * labels
            pt1 = xs_neg * (1 - labels)
            pt = pt0 + pt1
            one_sided_gamma = self.gamma_pos * labels + self.gamma_neg * (1 - labels)
            one_sided_w = ops.pow(1 - pt, one_sided_gamma)

            loss *= one_sided_w

        return -loss.sum()


class AsymmetricLossSingleLabel(nn.LossBase):
    """Single-label variant of the asymmetric loss; label smoothing is applied to the one-hot targets."""

    def __init__(self, gamma_pos=1, gamma_neg=4, eps=0.1, reduction="mean", smoothing=0.1):
        super(AsymmetricLossSingleLabel, self).__init__()

        self.eps = eps
        self.logsoftmax = nn.LogSoftmax(axis=-1)
        self.targets_classes = []
        self.gamma_pos = gamma_pos
        self.gamma_neg = gamma_neg
        self.reduction = reduction
        self.smoothing = smoothing

    def construct(self, logits, labels):
        num_classes = logits.shape[-1]
        log_preds = self.logsoftmax(logits)
        # scatter the class indices into a one-hot target matrix
        labels_e = ops.ExpandDims()(labels, 1)
        labels_e_shape = labels_e.shape
        targets = ops.tensor_scatter_elements(
            ops.ZerosLike()(logits), labels_e, Tensor(np.ones(labels_e_shape, dtype=np.float32)), 1
        )

        anti_targets = 1 - targets
        xs_pos = ops.exp(log_preds)
        xs_neg = 1 - xs_pos
        xs_pos = xs_pos * targets
        xs_neg = xs_neg * anti_targets

        asymmetric_w = ops.pow(1 - xs_pos - xs_neg, self.gamma_pos * targets + self.gamma_neg * anti_targets)

        log_preds = log_preds * asymmetric_w

        targets = targets * (1 - self.smoothing) + self.smoothing / num_classes

        loss = -targets * log_preds
        loss = ops.ReduceSum()(loss, -1)

        if self.reduction == "mean":
            loss = loss.mean()

        return loss
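Taken together, the two classes implement the asymmetric-focusing idea: separate gamma exponents for positives and negatives, plus (in the multi-label case) a probability clip on the negative side. A minimal usage sketch on dummy data; the shapes, class counts, and values below are illustrative assumptions, not part of the commit:

import numpy as np
import mindspore as ms
from mindcv.loss import AsymmetricLossMultilabel, AsymmetricLossSingleLabel

logits = ms.Tensor(np.random.randn(4, 10), ms.float32)

# multi-label: targets are a multi-hot matrix with the same shape as logits
multi_hot = ms.Tensor(np.random.randint(0, 2, size=(4, 10)), ms.float32)
print(AsymmetricLossMultilabel()(logits, multi_hot))  # scalar: negative sum over all elements

# single-label: targets are class indices; smoothing is applied to the one-hot targets inside
class_idx = ms.Tensor(np.random.randint(0, 10, size=(4,)), ms.int32)
print(AsymmetricLossSingleLabel()(logits, class_idx))  # scalar: mean over the batch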
52 changes: 52 additions & 0 deletions mindcv/loss/jsd.py
@@ -0,0 +1,52 @@
from mindspore import nn, ops

from .cross_entropy_smooth import CrossEntropySmooth


class JSDCrossEntropy(nn.LossBase):
    """
    JSD loss is implemented according to "AugMix: A Simple Data Processing Method to Improve Robustness and Uncertainty"
    https://arxiv.org/abs/1912.02781
    Please note that JSD loss should be used when "aug_splits = 3".
    """

    def __init__(self, num_splits=3, alpha=12, smoothing=0.1, weight=None, reduction="mean", aux_factor=0.0):
        super().__init__()
        self.num_splits = num_splits
        self.alpha = alpha
        self.smoothing = smoothing
        self.weight = weight
        self.reduction = reduction
        self.kldiv = ops.KLDivLoss(reduction="batchmean")
        self.map = ops.Map()

        self.softmax = ops.Softmax(axis=1)
        self.aux_factor = aux_factor

    def construct(self, logits, labels):
        if self.training:
            split_size = logits.shape[0] // self.num_splits
            log_split = ops.split(logits, 0, self.num_splits)

            loss = ops.cross_entropy(
                log_split[0],
                labels[:split_size],
                weight=self.weight,
                reduction=self.reduction,
                label_smoothing=self.smoothing,
            )

            probs = self.map(self.softmax, log_split)
            stack_probs = ops.stack(probs)
            clip_probs = ops.clip_by_value(stack_probs.mean(axis=0), 1e-7, 1)
            log_probs = ops.log(clip_probs)

            for p_split in probs:
                loss += self.alpha * self.kldiv(log_probs, p_split) / self.num_splits

            return loss
        else:
            return CrossEntropySmooth(
                smoothing=self.smoothing, aux_factor=self.aux_factor, reduction=self.reduction, weight=self.weight
            )(logits, labels)
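At train time the incoming batch is expected to be the concatenation of num_splits views of the same samples (e.g. clean plus two AugMix augmentations): a smoothed cross-entropy is computed on the first split, and the alpha-weighted Jensen-Shannon term (the averaged KL of each split's softmax against the clipped mixture) enforces consistency across splits; at eval time it falls back to CrossEntropySmooth. A hedged usage sketch, assuming a MindSpore version matching the ops calls above; the batch layout and sizes are illustrative:

import numpy as np
import mindspore as ms
from mindcv.loss import JSDCrossEntropy

num_splits, split_size, num_classes = 3, 2, 10
# rows 0-1: first (clean) split; rows 2-3 and 4-5: two augmented splits of the same samples
logits = ms.Tensor(np.random.randn(num_splits * split_size, num_classes), ms.float32)
labels = ms.Tensor(np.random.randint(0, num_classes, size=(num_splits * split_size,)), ms.int32)

loss_fn = JSDCrossEntropy(num_splits=num_splits, alpha=12, smoothing=0.1)
loss_fn.set_train(True)  # select the training branch (CE on split 0 + JS consistency term)
print(loss_fn(logits, labels))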
8 changes: 8 additions & 0 deletions mindcv/loss/loss_factory.py
@@ -3,8 +3,10 @@

 from mindspore import Tensor
 
+from .asymmetric import AsymmetricLossMultilabel, AsymmetricLossSingleLabel
 from .binary_cross_entropy_smooth import BinaryCrossEntropySmooth
 from .cross_entropy_smooth import CrossEntropySmooth
+from .jsd import JSDCrossEntropy
 
 __all__ = ["create_loss"]

@@ -50,6 +52,12 @@ def create_loss(
         loss = BinaryCrossEntropySmooth(
             smoothing=label_smoothing, aux_factor=aux_factor, reduction=reduction, weight=weight, pos_weight=None
         )
+    elif name == "asl_single_label":
+        loss = AsymmetricLossSingleLabel(smoothing=label_smoothing)
+    elif name == "asl_multi_label":
+        loss = AsymmetricLossMultilabel()
+    elif name == "jsd":
+        loss = JSDCrossEntropy(smoothing=label_smoothing, aux_factor=aux_factor, reduction=reduction, weight=weight)
     else:
         raise NotImplementedError

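The new factory names are the strings matched above; assuming the factory's defaults for the remaining arguments, usage looks like this (a sketch, argument values illustrative):

from mindcv.loss import create_loss

asl_loss = create_loss(name="asl_single_label", label_smoothing=0.1)
jsd_loss = create_loss(name="jsd", label_smoothing=0.1, reduction="mean", aux_factor=0.0)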
173 changes: 172 additions & 1 deletion tests/modules/test_loss.py
@@ -6,7 +6,7 @@
 import pytest
 
 import mindspore as ms
-from mindspore import nn
+from mindspore import Tensor, nn
 from mindspore.common.initializer import Normal
 from mindspore.nn import TrainOneStepCell, WithLossCell
@@ -120,5 +120,176 @@ def test_loss(mode, name, reduction, label_smoothing, aux_factor, weight, double
    assert cur_loss < begin_loss, "Loss does NOT decrease"


@pytest.mark.parametrize("mode", [0, 1])
@pytest.mark.parametrize("name", ["asl_single_label"])
@pytest.mark.parametrize("label_smoothing", [0.0, 0.1])
def test_asl(mode, name, label_smoothing, reduction="mean", aux_factor=0.0, weight=None, double_aux=False):
    weight = None
    print(
        f"mode={mode}; loss_name={name}; has_weight=False; reduction={reduction};\
        label_smoothing={label_smoothing}; aux_factor={aux_factor}"
    )
    ms.set_context(mode=mode)

    bs = 8
    num_classes = c = 10
    # create data
    x = ms.Tensor(np.random.randn(bs, 1, 32, 32), ms.float32)
    # logits = ms.Tensor(np.random.rand(bs, c), ms.float32)
    y = np.random.randint(0, c, size=(bs))
    y_onehot = np.eye(c)[y]
    y = ms.Tensor(y, ms.int32)  # C
    y_onehot = ms.Tensor(y_onehot, ms.float32)  # N, C
    if name == "BCE":
        label = y_onehot
    else:
        label = y

    if weight is not None:
        weight = np.random.randn(c)
        weight = weight / weight.sum()  # normalize
        weight = ms.Tensor(weight, ms.float32)

    # set network
    aux_head = aux_factor > 0.0
    aux_head2 = aux_head and double_aux
    network = SimpleCNN(in_channels=1, num_classes=num_classes, aux_head=aux_head, aux_head2=aux_head2)

    # set loss
    net_loss = create_loss(
        name=name, weight=weight, reduction=reduction, label_smoothing=label_smoothing, aux_factor=aux_factor
    )

    # optimize
    net_with_loss = WithLossCell(network, net_loss)

    net_opt = create_optimizer(network.trainable_params(), "adam", lr=0.001, weight_decay=1e-7)
    train_network = TrainOneStepCell(net_with_loss, net_opt)

    train_network.set_train()

    begin_loss = train_network(x, label)
    for _ in range(10):
        cur_loss = train_network(x, label)

    print("begin loss: {}, end loss: {}".format(begin_loss, cur_loss))

    assert cur_loss < begin_loss, "Loss does NOT decrease"


@pytest.mark.parametrize("mode", [0, 1])
@pytest.mark.parametrize("name", ["asl_single_label"])
def test_asl_single_label_random(mode, name, reduction="mean", label_smoothing=0.1, aux_factor=0.0, weight=None):
    weight = None
    print(
        f"mode={mode}; loss_name={name}; has_weight=False; reduction={reduction};\
        label_smoothing={label_smoothing}; aux_factor={aux_factor}"
    )
    ms.set_context(mode=mode)

    net_loss = create_loss(
        name=name, weight=weight, reduction=reduction, label_smoothing=label_smoothing, aux_factor=aux_factor
    )

    # logits and labels
    logits = Tensor(
        [
            [0.38317937, 0.82873726, 0.8164871, 0.6443424],
            [0.77863216, 0.17288171, 0.69345415, 0.26514006],
            [0.14249292, 0.38524792, 0.97271717, 0.90531427],
        ],
        ms.float32,
    )
    labels = Tensor([1, 1, 1], ms.int32)

    output_expected = Tensor(1.1247127, ms.float32)

    output = net_loss(logits, labels)

    assert np.allclose(output_expected.asnumpy(), output.asnumpy())


@pytest.mark.parametrize("mode", [0, 1])
@pytest.mark.parametrize("name", ["asl_single_label"])
def test_asl_single_label_zero(mode, name, reduction="mean", label_smoothing=0.1, aux_factor=0.0, weight=None):
    weight = None
    print(
        f"mode={mode}; loss_name={name}; has_weight=False; reduction={reduction};\
        label_smoothing={label_smoothing}; aux_factor={aux_factor}"
    )
    ms.set_context(mode=mode)

    net_loss = create_loss(
        name=name, weight=weight, reduction=reduction, label_smoothing=label_smoothing, aux_factor=aux_factor
    )

    # logits and labels
    logits = Tensor(np.zeros((3, 5)), ms.float32)
    labels = Tensor(np.zeros((3,)), ms.int32)

    output_expected = Tensor(1.1847522, ms.float32)

    output = net_loss(logits, labels)

    assert np.allclose(output_expected.asnumpy(), output.asnumpy())
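As a sanity check on that constant: all-zero logits over 5 classes give softmax probability 0.2 for every class, so the expected value can be reproduced by hand (a sketch mirroring the construct logic above, not part of the commit):

import numpy as np

p = 0.2  # softmax of all-zero logits over 5 classes
log_p = np.log(p)  # identical log-probability for every class
w_pos = (1 - p) ** 1  # asymmetric weight at the target class (gamma_pos=1)
w_neg = p ** 4  # asymmetric weight at the 4 other classes (gamma_neg=4)
t_pos = 1 * 0.9 + 0.1 / 5  # smoothed target at the true class (smoothing=0.1)
t_neg = 0 * 0.9 + 0.1 / 5  # smoothed target at the remaining classes

print(-(t_pos * log_p * w_pos + 4 * t_neg * log_p * w_neg))  # ~1.1847522, matching output_expected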


@pytest.mark.parametrize("mode", [0, 1])
@pytest.mark.parametrize("name", ["asl_multi_label"])
def test_asl_multi_label_random(mode, name, reduction="mean", label_smoothing=0.1, aux_factor=0.0, weight=None):
    weight = None
    print(
        f"mode={mode}; loss_name={name}; has_weight=False; reduction={reduction};\
        label_smoothing={label_smoothing}; aux_factor={aux_factor}"
    )
    ms.set_context(mode=mode)

    net_loss = create_loss(
        name=name, weight=weight, reduction=reduction, label_smoothing=label_smoothing, aux_factor=aux_factor
    )

    # logits and labels
    logits = Tensor(
        [
            [0.38317937, 0.82873726, 0.8164871, 0.6443424],
            [0.77863216, 0.17288171, 0.69345415, 0.26514006],
            [0.14249292, 0.38524792, 0.97271717, 0.90531427],
        ],
        ms.float32,
    )
    labels = Tensor([[1, 1, 0, 0], [0, 0, 1, 1], [1, 0, 1, 0]], ms.int32)

    output_expected = Tensor(1.8657642, ms.float32)

    output = net_loss(logits, labels)

    assert np.allclose(output_expected.asnumpy(), output.asnumpy())


@pytest.mark.parametrize("mode", [0, 1])
@pytest.mark.parametrize("name", ["asl_multi_label"])
def test_asl_multi_label_zero(mode, name, reduction="mean", label_smoothing=0.1, aux_factor=0.0, weight=None):
    weight = None
    print(
        f"mode={mode}; loss_name={name}; has_weight=False; reduction={reduction};\
        label_smoothing={label_smoothing}; aux_factor={aux_factor}"
    )
    ms.set_context(mode=mode)

    net_loss = create_loss(
        name=name, weight=weight, reduction=reduction, label_smoothing=label_smoothing, aux_factor=aux_factor
    )

    # logits and labels
    logits = Tensor(np.zeros((3, 5)), ms.float32)
    labels = Tensor(np.zeros((3, 5)), ms.int32)

    output_expected = Tensor(0.3677258, ms.float32)

    output = net_loss(logits, labels)

    assert np.allclose(output_expected.asnumpy(), output.asnumpy())
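The multi-label constant can be verified the same way: zero logits give sigmoid outputs of 0.5, the clip shifts the negative probability to 0.55, and with all-zero labels every one of the 3 x 5 entries contributes the same negative term with gamma_neg = 4 (again a hand computation, not part of the commit):

import numpy as np

p_neg = 0.5 + 0.05  # sigmoid(0) shifted by clip=0.05
log_term = np.log(p_neg)  # los_neg for every entry (labels are all zero)
w = (1 - p_neg) ** 4  # one-sided focusing weight with gamma_neg=4
print(-(15 * log_term * w))  # ~0.3677258, matching output_expected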


if __name__ == "__main__":
    test_loss(0, "BCE", "mean", 0.1, 0.1, None, True)
