Skip to content

Commit

Permalink
Merge pull request #623 from DeepRegNet/512-add_loss
Browse files Browse the repository at this point in the history
  • Loading branch information
mathpluscode authored Jan 26, 2021
2 parents 035ba43 + 5ea1a29 commit a182f9d
Show file tree
Hide file tree
Showing 4 changed files with 96 additions and 10 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ compatible with the updates.

### Added

- Added global NCC loss
- Added the docs on registry for backbone models.
- Added backward compatible config parser.
- Added tests so that test coverage is 100%.
Expand Down
55 changes: 55 additions & 0 deletions deepreg/loss/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,3 +343,58 @@ class LocalNormalizedCrossCorrelationLoss(
NegativeLossMixin, LocalNormalizedCrossCorrelation
):
"""Revert the sign of LocalNormalizedCrossCorrelation."""


class GlobalNormalizedCrossCorrelation(tf.keras.losses.Loss):
"""
Global squared zero-normalized cross-correlation.
Compute the squared cross-correlation between the reference and moving images
y_true and y_pred have to be at least 4d tensor, including batch axis.
Reference:
- Zero-normalized cross-correlation (ZNCC):
https://en.wikipedia.org/wiki/Cross-correlation
"""

def __init__(
self,
reduction: str = tf.keras.losses.Reduction.AUTO,
name: str = "GlobalNormalizedCrossCorrelation",
):
"""
Init.
:param reduction: using AUTO reduction,
calling the loss like `loss(y_true, y_pred)` will return a scalar tensor.
:param name: name of the loss
"""
super().__init__(reduction=reduction, name=name)

def call(self, y_true: tf.Tensor, y_pred: tf.Tensor) -> tf.Tensor:
"""
Return loss for a batch.
:param y_true: shape = (batch, ...)
:param y_pred: shape = (batch, ...)
:return: shape = (batch,)
"""

axis = [a for a in range(1, len(y_true.shape))]
mu_pred = tf.reduce_mean(y_pred, axis=axis, keepdims=True)
mu_true = tf.reduce_mean(y_true, axis=axis, keepdims=True)
var_pred = tf.math.reduce_variance(y_pred, axis=axis)
var_true = tf.math.reduce_variance(y_true, axis=axis)
numerator = tf.abs(
tf.reduce_mean((y_pred - mu_pred) * (y_true - mu_true), axis=axis)
)

return (numerator * numerator + EPS) / (var_pred * var_true + EPS)


@REGISTRY.register_loss(name="gncc")
class GlobalNormalizedCrossCorrelationLoss(
NegativeLossMixin, GlobalNormalizedCrossCorrelation
):
"""Revert the sign of GlobalNormalizedCrossCorrelation."""
21 changes: 11 additions & 10 deletions docs/source/docs/registered_classes.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,16 +28,17 @@ The category is `model_class`. Registered keys and values are as following.

The category is `loss_class`. Registered keys and values are as following.

| key | value |
| :-------------- | :------------------------------------------------------- |
| "bending" | `deepreg.loss.deform.BendingEnergy` |
| "cross-entropy" | `deepreg.loss.label.CrossEntropy` |
| "dice" | `deepreg.loss.label.DiceLoss` |
| "gmi" | `deepreg.loss.image.GlobalMutualInformationLoss` |
| "gradient" | `deepreg.loss.deform.GradientNorm` |
| "jaccard" | `deepreg.loss.label.JaccardLoss` |
| "lncc" | `deepreg.loss.image.LocalNormalizedCrossCorrelationLoss` |
| "ssd" | `deepreg.loss.image.SumSquaredDifference` |
| key | value |
| :-------------- | :-------------------------------------------------------- |
| "bending" | `deepreg.loss.deform.BendingEnergy` |
| "cross-entropy" | `deepreg.loss.label.CrossEntropy` |
| "dice" | `deepreg.loss.label.DiceLoss` |
| "gmi" | `deepreg.loss.image.GlobalMutualInformationLoss` |
| "gncc" | `deepreg.loss.image.GlobalNormalizedCrossCorrelationLoss` |
| "gradient" | `deepreg.loss.deform.GradientNorm` |
| "jaccard" | `deepreg.loss.label.JaccardLoss` |
| "lncc" | `deepreg.loss.image.LocalNormalizedCrossCorrelationLoss` |
| "ssd" | `deepreg.loss.image.SumSquaredDifference` |

## Data Augmentation

Expand Down
29 changes: 29 additions & 0 deletions test/unit/test_loss_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,3 +118,32 @@ def test_get_config(self):
name="LocalNormalizedCrossCorrelation",
)
assert got == expected


class TestGlobalNormalizedCrossCorrelation:
@pytest.mark.parametrize(
"y_true,y_pred,shape,expected",
[
(0.6, 0.3, (3, 3), 1),
(0.6, 0.3, (3, 3, 3), 1),
(0.6, -0.3, (3, 3, 3), 1),
(0.6, 0.3, (3, 3, 3, 3), 1),
],
)
def test_output(self, y_true, y_pred, shape, expected):

y_true = y_true * tf.ones(shape=shape)
y_pred = y_pred * tf.ones(shape=shape)

pad_width = tuple([(0, 0)] + [(1, 1)] * (len(shape) - 1))
y_true = np.pad(y_true, pad_width=pad_width)
y_pred = np.pad(y_pred, pad_width=pad_width)

got = image.GlobalNormalizedCrossCorrelation().call(
y_true,
y_pred,
)

expected = expected * tf.ones(shape=(shape[0],))

assert is_equal_tf(got, expected)

0 comments on commit a182f9d

Please sign in to comment.