Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
RRobert92 committed Aug 9, 2023
1 parent c52dc88 commit 1a950d7
Show file tree
Hide file tree
Showing 7 changed files with 77 additions and 17 deletions.
51 changes: 43 additions & 8 deletions tardis_pytorch/dist_pytorch/datasets/augmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,8 +171,9 @@ class BuildGraph:
K (int): Number of maximum connections per node.
"""

def __init__(self, K=2):
def __init__(self, K=2, mesh=False):
self.K = K
self.mesh = mesh

def __call__(self, coord: np.ndarray) -> np.ndarray:
"""
Expand Down Expand Up @@ -201,15 +202,49 @@ def __call__(self, coord: np.ndarray) -> np.ndarray:
# build the connectivity matrix
N = coord.shape[0]
graph = np.zeros((N, N))
for i in range(N):
for j in indices[i]:
if class_id[i] == class_id[j]: # check class ID before adding edges
graph[i, j] = 1
# graph[j, i] = 1
if self.mesh:
for i in range(N):
for j in indices[i]:
if class_id[i] == class_id[j]: # check class ID before adding edges
graph[i, j] = 1
# graph[j, i] = 1
else:
all_idx = np.unique(coord[:, 0])
for i in all_idx:
points_in_contour = np.where(coord[:, 0] == i)[0].tolist()

for j in points_in_contour:
# Self-connection
graph[j, j] = 1

# First point in contour
if j == points_in_contour[0]: # First point
if (j + 1) <= (len(coord) - 1):
graph[j, j + 1] = 1
graph[j + 1, j] = 1
# Last point
elif j == points_in_contour[len(points_in_contour) - 1]:
graph[j, j - 1] = 1
graph[j - 1, j] = 1
else: # Point in the middle
graph[j, j + 1] = 1
graph[j + 1, j] = 1
graph[j, j - 1] = 1
graph[j - 1, j] = 1

# Check euclidean distance between fist and last point
ends_distance = np.linalg.norm(
coord[points_in_contour[0]][1:] - coord[points_in_contour[-1]][1:]
)

# If < 2 nm pixel size, connect
if ends_distance < 2:
graph[points_in_contour[0], points_in_contour[-1]] = 1
graph[points_in_contour[-1], points_in_contour[0]] = 1

# Ensure self-connection
range_ = list(range(len(graph)))
graph[range_, range_] = 1
np.fill_diagonal(graph, 1)

return graph


Expand Down
6 changes: 3 additions & 3 deletions tardis_pytorch/dist_pytorch/datasets/patches.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
# MIT License 2021 - 2023 #
#######################################################################

from typing import Tuple, Union
from typing import Tuple, Union, List

import numpy as np
import torch
Expand Down Expand Up @@ -162,7 +162,7 @@ def points_in_patch(self, coord: np.ndarray, patch_center: np.ndarray) -> bool:

return coord_idx

def optimal_patches(self, coord: np.ndarray, random=False) -> list[bool]:
def optimal_patches(self, coord: np.ndarray, random=False) -> List[bool]:
"""
Main class function to compute optimal patch size.
Expand Down Expand Up @@ -389,7 +389,7 @@ def patched_dataset(
mesh=6,
random=False,
voxel_size=None,
) -> Union[Tuple[list, list, list, list, list], Tuple[list, list, list, list]]:
) -> Union[Tuple[List, List, List, List, List], Tuple[List, List, List, List]]:
coord_patch = []
graph_patch = []
output_idx = []
Expand Down
2 changes: 1 addition & 1 deletion tardis_pytorch/dist_pytorch/sparse_dist.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def embed_input(self, coords: torch.tensor) -> torch.tensor:

return x, idx

def forward(self, coords: torch.tensor, idx=None) -> torch.tensor:
def forward(self, coords: torch.tensor) -> torch.tensor:
"""
Forward pass for the SparseDIST.
Expand Down
2 changes: 1 addition & 1 deletion tardis_pytorch/dist_pytorch/sparse_model/embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -347,7 +347,7 @@ def __init__(self, n_out: int, sigma: list, knn: int, _device):
self.sigma = sigma
self._device = _device

def forward(self, input_coord: np.ndarray) -> Union[torch.tensor, list]:
def forward(self, input_coord: torch.tensor) -> Union[torch.tensor, list]:
with torch.no_grad():
# Get all ij element from row and col
input_coord = input_coord.cpu().detach().numpy()
Expand Down
2 changes: 1 addition & 1 deletion tardis_pytorch/dist_pytorch/sparse_model/modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def forward(self, x: torch.tensor, indices: list) -> torch.tensor:
else:
k = torch.einsum('ik,ijk->ik', a, b[indices[1]])

return torch.sigmoid(self.gate_o(x)) * self.linear_o(self.norm_o(a))
return torch.sigmoid(self.gate_o(x)) * self.linear_o(self.norm_o(k))


def sparse_to_dense(x: list, numpy=False) -> np.ndarray:
Expand Down
13 changes: 13 additions & 0 deletions tardis_pytorch/predict_mem.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,14 @@
help="Directory with images for prediction with CNN model.",
show_default=True,
)
@click.option(
"-ch",
"--checkpoint",
default=None,
type=str,
help="Optional list of pre-trained weights",
show_default=True,
)
@click.option(
"-out",
"--output_format",
Expand Down Expand Up @@ -133,6 +141,7 @@
@click.version_option(version=version)
def main(
dir: str,
checkpoint: str,
output_format: str,
patch_size: int,
rotate: bool,
Expand All @@ -151,9 +160,13 @@ def main(
else:
instances = True

if checkpoint is not None:
checkpoint = [checkpoint, None]

predictor = DataSetPredictor(
predict="Membrane",
dir_=dir,
checkpoint=checkpoint,
output_format=output_format,
patch_size=patch_size,
cnn_threshold=cnn_threshold,
Expand Down
18 changes: 15 additions & 3 deletions tardis_pytorch/utils/load_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,20 +89,32 @@ def __init__(self, src_am: str, src_img: Optional[str] = None):
if frame == 10000:
break

if not any(
binary = False
spatial_graph = ""
if any(
[
True
for i in ["AmiraMesh 3D ASCII", "# ASCII Spatial Graph"]
if i not in am
]
):
self.spatial_graph = None
else:
if "AmiraMesh BINARY-LITTLE-ENDIAN 3.0" not in am:
spatial_graph = None
else:
binary = True
if spatial_graph is not None:
self.spatial_graph = (
open(src_am, "r", encoding="iso-8859-1").read().split("\n")
)
self.spatial_graph = [x for x in self.spatial_graph if x != ""]

if binary:
return None
# self.spatial_graph = self.am_decode(self.spatial_graph)

def __am_decode(self, am: str) -> str:
pass

def __get_segments(self) -> Union[np.ndarray, None]:
"""
Helper class function to read segment data from amira file.
Expand Down

0 comments on commit 1a950d7

Please sign in to comment.