-
Notifications
You must be signed in to change notification settings - Fork 0
/
utils.py
109 lines (78 loc) · 2.74 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image
from torch.utils.data import Dataset
from constants import ImagesTransforms
def extract_patient(path: str) -> str:
    """Return the second-to-last "/"-separated component of ``path``.

    For a path shaped like ``<root>/<patient>/<file>`` this is the
    ``<patient>`` directory name. An ``IndexError`` is raised when
    ``path`` contains no "/" separator at all.
    """
    # Only the two right-most separators matter, so split from the right.
    components = path.rsplit("/", 2)
    return components[-2]
def calculate_iou(pred, target, threshold=0.5):
    """Compute the mean intersection-over-union of predictions and targets.

    ``pred`` is passed through a sigmoid and then binarised with
    ``threshold``; ``target`` is binarised with the same ``threshold``
    directly. Intersection and union are summed over dimensions 1, 2 and 3,
    so both tensors need at least four dimensions (presumably
    batch, channel, height, width -- confirm against the caller).

    Returns:
        A 0-dim tensor: the IoU averaged over the remaining (first) dimension.
    """
    # Smoothing term: keeps the ratio defined (and equal to 1) when both
    # masks are empty.
    eps = 1e-6
    reduce_dims = (1, 2, 3)

    pred_mask = torch.sigmoid(pred) > threshold
    target_mask = target > threshold

    overlap = torch.logical_and(pred_mask, target_mask).float().sum(reduce_dims)
    combined = torch.logical_or(pred_mask, target_mask).float().sum(reduce_dims)

    per_sample = (overlap + eps) / (combined + eps)
    return per_sample.mean()
def diagnosis(mask_path: str) -> bool:
    """Check whether a mask image contains any non-zero pixel.

    Args:
        mask_path: Path of the mask image file, read with ``cv2.imread``.

    Returns:
        ``True`` (tumor) when at least one pixel value is greater than 0,
        ``False`` (no tumor) when the mask is entirely black.

    Raises:
        OSError: If ``cv2.imread`` cannot read ``mask_path``.
    """
    mask = cv2.imread(mask_path)
    # cv2.imread reports an unreadable file by returning None instead of
    # raising, so fail here with a message that names the offending path.
    if mask is None:
        raise OSError(f"Could not read mask image: {mask_path}")
    return bool(np.max(mask) > 0)
def preprocess_image(image_path):
    """Load an image and turn it into a single-item batch for prediction.

    The image is opened with PIL, passed through
    ``ImagesTransforms.IMAGE_TRANSFORM`` and given a new leading dimension
    of size 1 via ``unsqueeze(0)``.

    Args:
        image_path: Path of the image file, handed to ``Image.open``.

    Returns:
        The transformed image with an extra leading dimension.
    """
    transform = ImagesTransforms.IMAGE_TRANSFORM
    # Image.open is lazy and keeps the underlying file open; the context
    # manager releases the handle once the transform has run.
    # NOTE(review): assumes IMAGE_TRANSFORM reads the pixel data eagerly
    # (its result is used as a tensor below) -- confirm.
    with Image.open(image_path) as image:
        transformed = transform(image)
    return transformed.unsqueeze(0)
def predict_mask(model, image_path, device, threshold=0.5):
    """Predict a binary mask for the image stored at ``image_path``.

    Args:
        model: The network to run. It is switched to eval mode and left
            in that mode when the function returns.
        image_path: Path of the input image, loaded via ``preprocess_image``.
        device: Device the input batch is moved to before the forward pass.
            NOTE(review): the model is presumably already on this device.
        threshold: Probability cut-off; values of ``sigmoid(output)``
            strictly greater than it are marked ``True``. Defaults to 0.5.

    Returns:
        A boolean NumPy array holding the thresholded prediction with the
        two leading dimensions squeezed away.
    """
    model.eval()
    batch = preprocess_image(image_path).to(device)
    with torch.no_grad():
        logits = model(batch)
        probabilities = torch.sigmoid(logits)
        binary_mask = probabilities > threshold
    # Remove the batch and channel dimensions (each is dropped only if it
    # has size 1) before handing the result to NumPy.
    return binary_mask.squeeze(0).squeeze(0).cpu().numpy()
def print_res(image, mask, predict_mask):
    """Show the image, its mask and the predicted mask side by side.

    Draws a 1x3 grid of panels titled "Original Image", "Mask" and
    "Predicted Mask" (both masks use the gray colormap), hides every
    panel's axes and finally calls ``plt.show()``.
    """
    _, axes = plt.subplots(1, 3, figsize=(20, 6))
    # One entry per panel: what to draw, its title, extra imshow options.
    panels = (
        (image, "Original Image", {}),
        (mask, "Mask", {"cmap": "gray"}),
        (predict_mask, "Predicted Mask", {"cmap": "gray"}),
    )
    for axis, (picture, title, imshow_options) in zip(axes, panels):
        axis.imshow(picture, **imshow_options)
        axis.set_title(title)
        axis.axis("off")
    plt.show()
class CustomDataset(Dataset):
    """PyTorch dataset backed by a Pandas dataframe.

    The dataframe's first column is read as the image path and its second
    column as the mask path. Each item is the (image, mask) pair opened
    with PIL, with the optional transforms applied.
    """

    def __init__(self, dataframe, image_transform=None, mask_transform=None):
        """Store the dataframe and the optional image/mask transforms."""
        self.dataframe = dataframe
        # A falsy transform (e.g. None) means the PIL image is returned as is.
        self.image_transform, self.mask_transform = image_transform, mask_transform

    def __len__(self):
        """Return the number of rows in the dataframe."""
        return len(self.dataframe)

    def __getitem__(self, idx):
        """Return the (image, mask) pair for row ``idx``."""
        frame = self.dataframe
        image_path, mask_path = frame.iloc[idx, 0], frame.iloc[idx, 1]

        # Both files are opened before any transform is applied.
        image, mask = Image.open(image_path), Image.open(mask_path)

        image = self.image_transform(image) if self.image_transform else image
        mask = self.mask_transform(mask) if self.mask_transform else mask
        return image, mask