Skip to content

Commit

Permalink
Bug fix and added notebooks
Browse files Browse the repository at this point in the history
  • Loading branch information
ErnstRoell committed Jul 3, 2024
1 parent eb4ec58 commit 9f6d014
Show file tree
Hide file tree
Showing 2 changed files with 286 additions and 8 deletions.
16 changes: 8 additions & 8 deletions dect/ect.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,11 @@ class ECTConfig:
Configuration of the ECT Layer.
"""

num_thetas: int = 32
bump_steps: int = 32
radius: float = 1.1
ect_type: str = "points"
num_features: int = 3
normalized: bool = False
fixed: bool = False
fixed: bool = True


@dataclass()
Expand Down Expand Up @@ -214,14 +212,15 @@ def __init__(self, config: ECTConfig, v=None):
# The set of directions is added
# TODO: Requires testing.
if config.fixed:
self.v = nn.Parameter(v, requires_grad=False)
self.v = nn.Parameter(v.movedim(-1, -2), requires_grad=False)
else:
self.v = nn.Parameter(torch.zeros_like(v))
geotorch.constraints.sphere(self, "v")
# Movedim to make geotorch happy, me not happy.
self.v = nn.Parameter(torch.zeros_like(v.movedim(-1, -2)))
geotorch.constraints.sphere(self, "v", radius=config.radius)
# Since geotorch randomizes the vector during initialization, we
# assign the values after registering it with spherical constraints.
# See Geotorch documentation for examples.
self.v = nn.Parameter(v, requires_grad=True)
self.v = v.movedim(-1, -2)

if config.ect_type == "points":
self.compute_ect = compute_ect_points
Expand All @@ -232,7 +231,8 @@ def __init__(self, config: ECTConfig, v=None):

def forward(self, batch: Batch):
"""Forward method for the ECT Layer."""
ect = self.compute_ect(batch, self.v, self.lin)
# Movedim for geotorch.
ect = self.compute_ect(batch, self.v.movedim(-1, -2), self.lin)
if self.config.normalized:
return normalize(ect)
return ect.squeeze()
Loading

0 comments on commit 9f6d014

Please sign in to comment.