Skip to content

Commit

Permalink
Merge pull request #191 from uma-pi1/fix_pretrain_reciprocal
Browse files Browse the repository at this point in the history
fix loading of pretrained reciprocal relations model
  • Loading branch information
AdrianKs authored Mar 18, 2021
2 parents e3c11e3 + f6c8167 commit d063477
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 5 deletions.
5 changes: 4 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
#### Februar 2021
#### March 2021
- PR [#191](https://github.com/uma-pi1/kge/pull/191): Fix loading of pretrained embeddings with reciprocal relation models

#### February 2021
- [27e8a32](https://github.com/uma-pi1/kge/commit/27e8a323d208106d7b75f4e003ea4b73c1c5d58d): improve validation time by allowing bulk KvsAll index lookup and improved history computation
- PR [#154](https://github.com/uma-pi1/kge/pull/154): store checkpoint containing the initialized model for reproducibility
- [9e88117](https://github.com/uma-pi1/kge/commit/9e88117b3bf3f91b1c22f17d88eae2f77b5e3d3e): Add Transformer model and learning rate warmup (thanks nluedema)
Expand Down
8 changes: 4 additions & 4 deletions kge/model/kge_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,6 @@ def __init__(

self.dim: int = self.get_option("dim")


@staticmethod
def create(
config: Config,
Expand Down Expand Up @@ -283,7 +282,9 @@ def create(
)
return embedder
except:
config.log(f"Failed to create embedder {embedder_type} (class {class_name}).")
config.log(
f"Failed to create embedder {embedder_type} (class {class_name})."
)
raise

def _intersect_ids_with_pretrained_embedder(
Expand Down Expand Up @@ -616,8 +617,7 @@ def penalty(self, **kwargs) -> List[Tensor]:
(triples[:, S].view(-1, 1), triples[:, O].view(-1, 1)), dim=1
)
entity_penalty_result = self.get_s_embedder().penalty(
indexes=entity_indexes,
**kwargs,
indexes=entity_indexes, **kwargs,
)
if not weighted:
# backwards compatibility
Expand Down
4 changes: 4 additions & 0 deletions kge/model/reciprocal_relations_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,10 @@ def __init__(
# Using a dataset with twice the number of relations to initialize base model
alt_dataset = dataset.shallow_copy()
alt_dataset._num_relations = dataset.num_relations() * 2
reciprocal_relation_ids = [
rel_id + "_reciprocal" for rel_id in alt_dataset.relation_ids()
]
alt_dataset._meta["relation_ids"].extend(reciprocal_relation_ids)
base_model = KgeModel.create(
config=config,
dataset=alt_dataset,
Expand Down

0 comments on commit d063477

Please sign in to comment.