From dfd0aace1a77d6b7f04f414bdc8ea748a9d0d2f2 Mon Sep 17 00:00:00 2001 From: nzteb Date: Thu, 14 May 2020 15:04:58 +0200 Subject: [PATCH] Add se loss --- kge/config-default.yaml | 4 ++++ kge/util/loss.py | 13 +++++++++++++ 2 files changed, 17 insertions(+) diff --git a/kge/config-default.yaml b/kge/config-default.yaml index 7d03f42d7..7ace2f788 100644 --- a/kge/config-default.yaml +++ b/kge/config-default.yaml @@ -133,6 +133,10 @@ train: # true label. Computed once for each positive and once for each negative # triple in the batch. See loss_arg for parameters. # + # Squared error (se, all training types): Calculate squared error between + # the score of a triple and its true value (0, 1). Computed once for each + # positive and once for each negative triple in the batch. + # # Generally, the loss values are averaged over the batch elements (e.g., # positive triple for 1vsAll and negative_sampling, sp- or po-pair for # KvsAll). If multiple loss values arise for each batch element (e.g., when diff --git a/kge/util/loss.py b/kge/util/loss.py index eb57b86f6..4a0cdb7ca 100644 --- a/kge/util/loss.py +++ b/kge/util/loss.py @@ -41,6 +41,7 @@ def create(config: Config): "ce", "kl", "soft_margin", + "se", ], ) if config.get("train.loss") == "bce": @@ -81,6 +82,8 @@ def create(config: Config): return MarginRankingKgeLoss(config, margin=margin) elif config.get("train.loss") == "soft_margin": return SoftMarginKgeLoss(config) + elif config.get("train.loss") == "se": + return SEKgeLoss(config) else: raise ValueError( "invalid value train.loss={}".format(config.get("train.loss")) @@ -259,3 +262,13 @@ def __call__(self, scores, labels, **kwargs): ) else: raise ValueError("train.type for margin ranking.") + + +class SEKgeLoss(KgeLoss): + def __init__(self, config, reduction="sum", **kwargs): + super().__init__(config) + self._loss = torch.nn.MSELoss(reduction=reduction, **kwargs) + + def __call__(self, scores, labels, **kwargs): + labels = self._labels_as_matrix(scores, labels) + 
        return self._loss(scores, labels)