Skip to content

Commit

Permalink
draw random crop from each patch (#84)
Browse files Browse the repository at this point in the history
  • Loading branch information
LorenzLamm authored Sep 18, 2024
1 parent 800f6ce commit 9e79ca1
Show file tree
Hide file tree
Showing 2 changed files with 116 additions and 3 deletions.
115 changes: 114 additions & 1 deletion src/membrain_seg/segmentation/dataloading/memseg_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import imageio as io
import numpy as np
from torch.utils.data import Dataset
from tqdm import tqdm

from membrain_seg.segmentation.dataloading.data_utils import read_nifti
from membrain_seg.segmentation.dataloading.memseg_augmentation import (
Expand All @@ -30,6 +31,9 @@ class CryoETMemSegDataset(Dataset):
aug_prob_to_one : bool, default False
A flag indicating whether the probability of augmentation should be
set to one or not.
patch_size : int, default 160
The size of the patches to be extracted from the images.
Methods
-------
Expand All @@ -54,6 +58,7 @@ def __init__(
label_folder: str,
train: bool = False,
aug_prob_to_one: bool = False,
patch_size: int = 160,
) -> None:
"""
Constructs all the necessary attributes for the CryoETMemSegDataset object.
Expand All @@ -69,9 +74,12 @@ def __init__(
aug_prob_to_one : bool, default False
A flag indicating whether the probability of augmentation should be set
to one or not.
patch_size : int, default 160
The size of the patches to be extracted from the images.
"""
self.train = train
self.img_folder, self.label_folder = img_folder, label_folder
self.patch_size = patch_size
self.initialize_imgs_paths()
self.load_data()
self.transforms = (
Expand Down Expand Up @@ -100,6 +108,7 @@ def __getitem__(self, idx: int) -> Dict[str, np.ndarray]:
"image": np.expand_dims(self.imgs[idx], 0),
"label": np.expand_dims(self.labels[idx], 0),
}
idx_dict = self.get_random_crop(idx_dict)
idx_dict = self.transforms(idx_dict)
idx_dict["dataset"] = self.dataset_labels[idx]
return idx_dict
Expand All @@ -115,6 +124,110 @@ def __len__(self) -> int:
"""
return len(self.data_paths)

def get_random_crop(self, idx_dict: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
"""
Returns a random crop from the image-label pair.
Parameters
----------
idx_dict : Dict[str, np.ndarray]
A dictionary containing an image and its corresponding label.
Returns
-------
Dict[str, np.ndarray]
A dictionary containing a random crop from the image and its corresponding
label.
"""
img, label = idx_dict["image"], idx_dict["label"]
x, y, z = img.shape[1:]

if x <= self.patch_size or y <= self.patch_size or z <= self.patch_size:
# pad with 2s on both sides
pad_x = max(self.patch_size - x, 0)
pad_y = max(self.patch_size - y, 0)
pad_z = max(self.patch_size - z, 0)
img = np.pad(
img,
(
(0, 0),
(pad_x // 2, pad_x // 2),
(pad_y // 2, pad_y // 2),
(pad_z // 2, pad_z // 2),
),
mode="constant",
constant_values=2,
)
label = np.pad(
label,
(
(0, 0),
(pad_x // 2, pad_x // 2),
(pad_y // 2, pad_y // 2),
(pad_z // 2, pad_z // 2),
),
mode="constant",
constant_values=0,
)
# make sure there was no rounding issue
if (
img.shape[1] < self.patch_size
or img.shape[2] < self.patch_size
or img.shape[3] < self.patch_size
):
img = np.pad(
img,
(
(0, 0),
(0, max(self.patch_size - img.shape[1], 0)),
(0, max(self.patch_size - img.shape[2], 0)),
(0, max(self.patch_size - img.shape[3], 0)),
),
mode="constant",
constant_values=2,
)
label = np.pad(
label,
(
(0, 0),
(0, max(self.patch_size - label.shape[1], 0)),
(0, max(self.patch_size - label.shape[2], 0)),
(0, max(self.patch_size - label.shape[3], 0)),
),
mode="constant",
constant_values=0,
)
assert (
img.shape[1] == self.patch_size
and img.shape[2] == self.patch_size
and img.shape[3] == self.patch_size
), f"Image shape is {img.shape} instead of {self.patch_size}"
return {"image": img, "label": label}

x_crop, y_crop, z_crop = self.patch_size, self.patch_size, self.patch_size
x_start = np.random.randint(0, x - x_crop)
y_start = np.random.randint(0, y - y_crop)
z_start = np.random.randint(0, z - z_crop)
img = img[
:,
x_start : x_start + x_crop,
y_start : y_start + y_crop,
z_start : z_start + z_crop,
]
label = label[
:,
x_start : x_start + x_crop,
y_start : y_start + y_crop,
z_start : z_start + z_crop,
]

assert (
img.shape[1] == self.patch_size
and img.shape[2] == self.patch_size
and img.shape[3] == self.patch_size
), f"Image shape is {img.shape} instead of {self.patch_size}"
return {"image": img, "label": label}

def load_data(self) -> None:
"""
Loads image-label pairs into memory from the specified directories.
Expand All @@ -127,7 +240,7 @@ def load_data(self) -> None:
self.imgs = []
self.labels = []
self.dataset_labels = []
for entry in self.data_paths:
for entry in tqdm(self.data_paths):
label = read_nifti(
entry[1]
) # TODO: Change this to be applicable to .mrc images
Expand Down
4 changes: 2 additions & 2 deletions src/membrain_seg/segmentation/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,8 +115,8 @@ def train(
)

checkpoint_callback_regular = ModelCheckpoint(
save_top_k=-1, # Save all checkpoints
every_n_epochs=100,
save_top_k=1, # Save all checkpoints
every_n_epochs=10,
dirpath="checkpoints/",
filename=checkpointing_name + "-{epoch}-{val_loss:.2f}",
verbose=True, # Print a message when a checkpoint is saved
Expand Down

0 comments on commit 9e79ca1

Please sign in to comment.