-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathfunctions.py
39 lines (32 loc) · 1 KB
/
functions.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
import torch
import numpy as np
## Dice Coefficient
def dice_coef(y_true, y_pred, eps=0.0001):
    """Return the smoothed Dice coefficient between two tensors.

    Both inputs are flattened from dimension 2 onward and reduced over that
    flattened axis, so the result keeps the first two dimensions of the
    inputs (presumably batch and class/channel -- confirm against callers).

    Args:
        y_true: Ground-truth tensor with at least three dimensions.
        y_pred: Prediction tensor, broadcast-compatible with ``y_true``.
        eps: Smoothing term added to numerator and denominator so that two
            all-zero inputs yield 1 instead of a 0/0 division.

    Returns:
        Tensor of Dice scores, one per entry of the two leading dimensions.
    """
    y_true_f = y_true.flatten(2)
    y_pred_f = y_pred.flatten(2)
    intersection = torch.sum(y_true_f * y_pred_f, -1)
    denominator = torch.sum(y_true_f, -1) + torch.sum(y_pred_f, -1)
    return (2. * intersection + eps) / (denominator + eps)
## Mask -> RLE
def encode_mask_to_rle(mask):
    """Encode a binary mask as a run-length string.

    The mask is flattened in row-major order and every run of foreground
    pixels is written as a ``start length`` pair, with ``start`` being the
    1-based index of the run's first pixel.

    Args:
        mask: Array-like binary mask (NumPy array, nested list, ...). Any
            non-zero value counts as foreground, zero as background.

    Returns:
        Space-separated ``start length`` pairs; an empty string when the
        mask has no foreground pixels.
    """
    # Binarise so that values such as 255 or True behave like 1 and the
    # number of transitions below is always even.
    flat = (np.asarray(mask).ravel() != 0).astype(np.int8)
    # Pad with background on both sides so runs touching either border
    # still produce a rising and a falling edge.
    padded = np.concatenate([[0], flat, [0]])
    # 1-based positions where the value changes: even slots are run starts,
    # odd slots are the positions just past each run's end.
    runs = np.where(padded[1:] != padded[:-1])[0] + 1
    # Turn every end position into a run length.
    runs[1::2] -= runs[::2]
    return ' '.join(str(x) for x in runs)
# RLE -> Mask
def decode_rle_to_mask(rle, height, width):
    """Decode a run-length string into a 2-D binary mask.

    Args:
        rle: Space-separated ``start length`` pairs, where ``start`` is the
            1-based index into the row-major flattened mask.
        height: Number of rows in the output mask.
        width: Number of columns in the output mask.

    Returns:
        ``numpy.uint8`` array of shape ``(height, width)`` holding 1 inside
        the encoded runs and 0 elsewhere.

    Raises:
        ValueError: If ``rle`` does not contain an even number of values.
    """
    tokens = np.array(rle.split(), dtype=int)
    if tokens.size % 2:
        raise ValueError("RLE string must contain start/length pairs")
    starts = tokens[0::2] - 1  # convert 1-based starts to 0-based indices
    ends = starts + tokens[1::2]
    size = height * width
    # One extra slot lets a run that reaches the very last pixel place its
    # closing marker at index ``size`` instead of indexing out of bounds.
    delta = np.zeros(size + 1, dtype=np.int32)
    delta[starts] += 1
    delta[ends] -= 1
    # The running sum of the +1/-1 markers is 1 inside runs, 0 outside.
    mask = np.cumsum(delta[:size])
    return mask.reshape(height, width).astype(np.uint8)