From 08ab93944297cb41df60754932df29642ef8e3f6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rub=C3=A9n=20Ballester?= Date: Wed, 17 Jul 2024 19:28:05 +0200 Subject: [PATCH] Adding black + some minor changes in wect and the README.md --- README.md | 2 +- dect/ect.py | 18 ++++------ dect/wect.py | 98 ++++++++++++++++++++++++++-------------------------- 3 files changed, 56 insertions(+), 62 deletions(-) diff --git a/README.md b/README.md index 6f2bdec..8c9ac64 100644 --- a/README.md +++ b/README.md @@ -70,7 +70,7 @@ from dect.directions import generate_uniform_2d_directions v = generate_uniform_2d_directions(num_thetas=64) -layer = ECTLayer(ECTConfig(), V=v) +layer = ECTLayer(ECTConfig(), v=v) points_coordinates = torch.tensor( [[0.5, 0.0], [-0.5, 0.0], [0.5, 0.5]], requires_grad=True diff --git a/dect/ect.py b/dect/ect.py index 363506d..81edeb2 100644 --- a/dect/ect.py +++ b/dect/ect.py @@ -108,9 +108,7 @@ def compute_ecc( return segment_add_coo(ecc, index) -def compute_ect_points( - batch: Batch, v: torch.FloatTensor, lin: torch.FloatTensor -): +def compute_ect_points(batch: Batch, v: torch.FloatTensor, lin: torch.FloatTensor): """Computes the Euler Characteristic Transform of a batch of point clouds. Parameters @@ -127,9 +125,7 @@ def compute_ect_points( return compute_ecc(nh, batch.batch, lin) -def compute_ect_edges( - batch: Batch, v: torch.FloatTensor, lin: torch.FloatTensor -): +def compute_ect_edges(batch: Batch, v: torch.FloatTensor, lin: torch.FloatTensor): """Computes the Euler Characteristic Transform of a batch of graphs. Parameters @@ -162,9 +158,7 @@ def compute_ect_edges( ) -def compute_ect_faces( - batch: Batch, v: torch.FloatTensor, lin: torch.FloatTensor -): +def compute_ect_faces(batch: Batch, v: torch.FloatTensor, lin: torch.FloatTensor): """Computes the Euler Characteristic Transform of a batch of meshes. Parameters @@ -225,9 +219,9 @@ def __init__(self, config: ECTConfig, v=None): super().__init__() self.config = config self.lin = nn.Parameter( - torch.linspace( - -config.radius, config.radius, config.bump_steps - ).view(-1, 1, 1, 1), + torch.linspace(-config.radius, config.radius, config.bump_steps).view( + -1, 1, 1, 1 + ), requires_grad=False, ) diff --git a/dect/wect.py b/dect/wect.py index 7498e12..b2a188c 100644 --- a/dect/wect.py +++ b/dect/wect.py @@ -5,7 +5,7 @@ from torch import nn from torch_scatter import segment_add_coo -from ect import ECTConfig, Batch, normalize +from dect.ect import ECTConfig, Batch, normalize def compute_wecc( @@ -17,26 +17,26 @@ def compute_wecc( ): """Computes the weighted Euler Characteristic curve. - Parameters - ---------- - nh : torch.FloatTensor - The node heights, computed as the inner product of the node coordinates - x and the direction vector v. - index: torch.LongTensor - The index that indicates to which pointcloud a node height belongs. For - the node heights it is the same as the batch index, for the higher order - simplices it will have to be recomputed. - lin: torch.FloatTensor - The discretization of the interval [-1,1] each node height falls in this - range due to rescaling in normalizing the data. - weight: torch.FloatTensor - The weight of the node, edge or face. It is the maximum of the node - weights for the edges and faces. - scale: torch.FloatTensor - A single number that scales the sigmoid function by multiplying the - sigmoid with the scale. With high (100>) values, the ect will resemble a - discrete ECT and with lower values it will smooth the ECT. - """ + Parameters + ---------- + nh : torch.FloatTensor + The node heights, computed as the inner product of the node coordinates + x and the direction vector v. + index: torch.LongTensor + The index that indicates to which pointcloud a node height belongs. For + the node heights it is the same as the batch index, for the higher order + simplices it will have to be recomputed. + lin: torch.FloatTensor + The discretization of the interval [-1,1] each node height falls in this + range due to rescaling in normalizing the data. + weight: torch.FloatTensor + The weight of the node, edge or face. It is the maximum of the node + weights for the edges and faces. + scale: torch.FloatTensor + A single number that scales the sigmoid function by multiplying the + sigmoid with the scale. With high (100>) values, the ect will resemble a + discrete ECT and with lower values it will smooth the ECT. + """ ecc = torch.nn.functional.sigmoid(scale * torch.sub(lin, nh)) * weight.view( 1, -1, 1 ) @@ -52,19 +52,19 @@ def compute_wect( ): """Computes the Weighted Euler Characteristic Transform of a batch of point clouds. - Parameters - ---------- - batch : Batch - A batch of data containing the node coordinates, batch index, edge_index, face, and - node weights. - v: torch.FloatTensor - The direction vector that contains the directions. - lin: torch.FloatTensor - The discretization of the interval [-1,1] each node height falls in this - range due to rescaling in normalizing the data. - wect_type: str - The type of WECT to compute. Can be "points", "edges", or "faces". - """ + Parameters + ---------- + batch : Batch + A batch of data containing the node coordinates, batch index, edge_index, face, and + node weights. + v: torch.FloatTensor + The direction vector that contains the directions. + lin: torch.FloatTensor + The discretization of the interval [-1,1] each node height falls in this + range due to rescaling in normalizing the data. + wect_type: str + The type of WECT to compute. Can be "points", "edges", or "faces". + """ nh = batch.x @ v if wect_type in ["edges", "faces"]: edge_weights, _ = batch.node_weights[batch.edge_index].max(axis=0) @@ -132,21 +132,21 @@ def forward(self, batch: Batch): """Forward method for the ECT Layer. - Parameters - ---------- - batch : Batch - A batch of data containing the node coordinates, edges, faces, - batch index, and node_weights. It should follow the pytorch geometric conventions. - - Returns - ---------- - wect: torch.FloatTensor - Returns the WECT of each data object in the batch. If the layer is - initialized with v of the shape [ndims,num_thetas], the returned WECT - has shape [batch,num_thetas,bump_steps]. In case the layer is - initialized with v of the form [n_channels, ndims, num_thetas] the - returned WECT has the shape [batch,n_channels,num_thetas,bump_steps] - """ + Parameters + ---------- + batch : Batch + A batch of data containing the node coordinates, edges, faces, + batch index, and node_weights. It should follow the pytorch geometric conventions. + + Returns + ---------- + wect: torch.FloatTensor + Returns the WECT of each data object in the batch. If the layer is + initialized with v of the shape [ndims,num_thetas], the returned WECT + has shape [batch,num_thetas,bump_steps]. In case the layer is + initialized with v of the form [n_channels, ndims, num_thetas] the + returned WECT has the shape [batch,n_channels,num_thetas,bump_steps] + """ # Movedim for geotorch wect = compute_wect( batch, self.v.movedim(-1, -2), self.lin, self.config.ect_type