From a65cfd1d1535014f701859141dc67ec1fb914cb9 Mon Sep 17 00:00:00 2001
From: AdrianKs
Date: Mon, 31 Aug 2020 10:38:21 +0200
Subject: [PATCH] compute the backward pass only once for all three slots in
 negative sampling

---
 kge/config-default.yaml |  7 +++++++
 kge/job/train.py        | 32 +++++++++++++++++++++++++-------
 2 files changed, 32 insertions(+), 7 deletions(-)

diff --git a/kge/config-default.yaml b/kge/config-default.yaml
index 34975b1d2..9f8840cbe 100644
--- a/kge/config-default.yaml
+++ b/kge/config-default.yaml
@@ -312,6 +312,13 @@ negative_sampling:
   # 'triple' (e.g., for TransE or RotatE in the current implementation).
   implementation: triple
 
+  # Whether to compute the backward pass after scoring against each slot
+  # (S, P, O) or only once after scoring against all three slots. Setting
+  # this to False does not work with reciprocal relations models.
+  # - True: slower, but needs less memory
+  # - False: faster, but needs more memory
+  backward_pass_per_slot: True
+
   # Perform training in chunks of the specified size. When set, process each
   # batch in chunks of at most this size. This reduces memory consumption but
   # may increase runtime. Useful when there are many negative samples and/or
diff --git a/kge/job/train.py b/kge/job/train.py
index 48b6e207e..ba9d41c8c 100644
--- a/kge/job/train.py
+++ b/kge/job/train.py
@@ -779,6 +779,7 @@ def __init__(self, config, dataset, parent_job=None, model=None):
             "'{}' scoring function ...".format(self._implementation)
         )
         self.type_str = "negative_sampling"
+        self.backward_pass_per_slot = self.config.get("negative_sampling.backward_pass_per_slot")
 
         if self.__class__ == TrainingJobNegativeSampling:
             for f in Job.job_created_hooks:
@@ -851,6 +852,13 @@ def _process_batch(self, batch_index, batch) -> TrainingJob._ProcessBatchResult:
             triples = batch_triples[chunk_indexes]
 
             # process the chunk
+            if not self.backward_pass_per_slot:
+                positive_scores = self.model.score_spo(
+                    triples[:, S],
+                    triples[:, P],
+                    triples[:, O],
+                )
+            loss_values_torch = []
             for slot in [S, P, O]:
                 num_samples = self._sampler.num_samples[slot]
                 if num_samples <= 0:
@@ -872,12 +880,15 @@ def _process_batch(self, batch_index, batch) -> TrainingJob._ProcessBatchResult:
                 # compute the scores
                 forward_time -= time.time()
                 scores = torch.empty((chunk_size, num_samples + 1), device=self.device)
-                scores[:, 0] = self.model.score_spo(
-                    triples[:, S],
-                    triples[:, P],
-                    triples[:, O],
-                    direction=SLOT_STR[slot],
-                )
+                if not self.backward_pass_per_slot:
+                    scores[:, 0] = positive_scores
+                else:
+                    scores[:, 0] = self.model.score_spo(
+                        triples[:, S],
+                        triples[:, P],
+                        triples[:, O],
+                        direction=SLOT_STR[slot],
+                    )
                 forward_time += time.time()
                 scores[:, 1:] = batch_negative_samples[slot].score(
                     self.model, indexes=chunk_indexes
@@ -891,12 +902,19 @@ def _process_batch(self, batch_index, batch) -> TrainingJob._ProcessBatchResult:
                     self.loss(scores, labels[slot], num_negatives=num_samples)
                     / batch_size
                 )
+                if not self.backward_pass_per_slot:
+                    loss_values_torch.append(loss_value_torch)
                 loss_value += loss_value_torch.item()
                 forward_time += time.time()
 
                 # backward pass for this chunk
+                if self.backward_pass_per_slot:
+                    backward_time -= time.time()
+                    loss_value_torch.backward()
+                    backward_time += time.time()
+            if not self.backward_pass_per_slot:
                 backward_time -= time.time()
-                loss_value_torch.backward()
+                torch.autograd.backward(loss_values_torch)
                 backward_time += time.time()
 
         # all done
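
A note for reviewers on why the backward_pass_per_slot: False path has to collect
the three slot losses and release them with a single torch.autograd.backward call:
the positive triples are then scored only once, so all three slot losses share one
forward graph, and a per-slot .backward() would free the shared buffers after the
first slot and fail on the next. The sketch below is illustrative only and not part
of the patch; w, shared, and the scale factors are hypothetical stand-ins for the
model parameters, the shared positive scores, and the S/P/O slot losses.

    import torch

    # hypothetical stand-ins: model parameters and positive scores computed once
    w = torch.randn(5, requires_grad=True)
    shared = (w ** 2).sum()  # like the positive scores: one forward pass, reused

    # three losses that reuse the shared forward result, like the S, P, O losses
    losses = [scale * shared for scale in (1.0, 2.0, 3.0)]

    # calling losses[0].backward() and then losses[1].backward() would raise
    # "Trying to backward through the graph a second time": the first call
    # frees the buffers of the shared subgraph. A single multi-root call
    # backpropagates through the shared part exactly once:
    torch.autograd.backward(losses)

    print(w.grad)  # same gradient as backpropagating the sum of the losses

This is also where the trade-off documented in config-default.yaml comes from: the
single call keeps the graphs of all three slots alive at once (more memory) but
needs only one positive forward pass and one backward pass (faster).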
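
For contrast, a sketch of the default backward_pass_per_slot: True path, again with
hypothetical stand-ins and not part of the patch: each slot re-scores the positives
with its own direction, so each loss owns a private graph that can be backpropagated
and freed inside the slot loop. Peak memory holds only one slot's graph, at the cost
of three positive forward passes and three backward passes.

    import torch

    w = torch.randn(5, requires_grad=True)

    # backward_pass_per_slot: True -- one forward and one backward per slot
    for scale in (1.0, 2.0, 3.0):  # stand-ins for the S, P, O slots
        loss = scale * (w ** 2).sum()  # positives re-scored: fresh graph per slot
        loss.backward()  # grads accumulate into w.grad; graph freed right away

Because gradients accumulate in .grad across calls, both settings end up with the
same accumulated gradients (for models where scoring with and without a direction
agrees); they differ only in peak memory and runtime.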