diff --git a/CHANGELOG.md b/CHANGELOG.md index 8fbe010fa..886db76ca 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,4 +1,7 @@ -#### Februar 2021 +#### March 2021 +- PR [#191](https://github.com/uma-pi1/kge/pull/191): Fix loading of pretrained embeddings with reciprocal relation models + +#### February 2021 - [27e8a32](https://github.com/uma-pi1/kge/commit/27e8a323d208106d7b75f4e003ea4b73c1c5d58d): improve validation time by allowing bulk KvsAll index lookup and improved history computation - PR [#154](https://github.com/uma-pi1/kge/pull/154): store checkpoint containing the initialized model for reproducibility - [9e88117](https://github.com/uma-pi1/kge/commit/9e88117b3bf3f91b1c22f17d88eae2f77b5e3d3e): Add Transformer model and learning rate warmup (thanks nluedema) diff --git a/kge/model/kge_model.py b/kge/model/kge_model.py index cc8679880..96b27a725 100644 --- a/kge/model/kge_model.py +++ b/kge/model/kge_model.py @@ -254,7 +254,6 @@ def __init__( self.dim: int = self.get_option("dim") - @staticmethod def create( config: Config, @@ -283,7 +282,9 @@ def create( ) return embedder except: - config.log(f"Failed to create embedder {embedder_type} (class {class_name}).") + config.log( + f"Failed to create embedder {embedder_type} (class {class_name})." + ) raise def _intersect_ids_with_pretrained_embedder( @@ -616,8 +617,7 @@ def penalty(self, **kwargs) -> List[Tensor]: (triples[:, S].view(-1, 1), triples[:, O].view(-1, 1)), dim=1 ) entity_penalty_result = self.get_s_embedder().penalty( - indexes=entity_indexes, - **kwargs, + indexes=entity_indexes, **kwargs, ) if not weighted: # backwards compatibility diff --git a/kge/model/reciprocal_relations_model.py b/kge/model/reciprocal_relations_model.py index d151f8aea..4f47fd400 100644 --- a/kge/model/reciprocal_relations_model.py +++ b/kge/model/reciprocal_relations_model.py @@ -26,6 +26,10 @@ def __init__( # Using a dataset with twice the number of relations to initialize base model alt_dataset = dataset.shallow_copy() alt_dataset._num_relations = dataset.num_relations() * 2 + reciprocal_relation_ids = [ + rel_id + "_reciprocal" for rel_id in alt_dataset.relation_ids() + ] + alt_dataset._meta["relation_ids"].extend(reciprocal_relation_ids) base_model = KgeModel.create( config=config, dataset=alt_dataset,