-
Notifications
You must be signed in to change notification settings - Fork 0
/
loss.py
55 lines (49 loc) · 1.89 KB
/
loss.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
import torch
import torch.nn as nn
def L1_loss(pred, target):
    """Return the mean absolute error between ``pred`` and ``target``.

    Thin wrapper over PyTorch's L1 loss with mean reduction.
    """
    return nn.functional.l1_loss(pred, target, reduction="mean")
def MSE(pred, target):
    """Return the mean squared error between ``pred`` and ``target``.

    Thin wrapper over PyTorch's MSE loss with mean reduction.
    """
    return nn.functional.mse_loss(pred, target, reduction="mean")
def RMSE(pred, target, metrics=None):
    """Return the root-mean-squared error between ``pred`` and ``target``.

    The mean is taken over every element of ``pred - target``.
    ``metrics`` is accepted but ignored (kept so existing call sites work).
    """
    mean_squared_error = ((pred - target) ** 2).mean()
    return mean_squared_error ** 0.5
def building_segments(input, pred, target):
    """Return the fraction of non-zero ``input`` pixels where ``pred`` is also non-zero.

    Non-zero ``input`` pixels presumably mark buildings (TODO confirm against
    the data pipeline). The ratio is pooled over the whole batch: both counts
    are summed over all elements before dividing, so it is not a per-sample
    mean. ``input`` and ``pred`` must be torch tensors that broadcast against
    each other.

    ``target`` is accepted but unused.

    Returns:
        0-dim floating tensor. It is ``nan`` when ``input`` has no non-zero
        elements (0 / 0).
    """
    building_mask = input != 0
    overlap_count = (building_mask & (pred != 0)).sum()
    # For debugging, print overlap_count, building_mask.sum() and the ratio.
    return overlap_count / building_mask.sum()
def roi_rmse_loss(input, pred, target):
    """Return the RMSE between ``pred`` and ``target`` inside the region of interest.

    The ROI is every pixel where channel 0 of ``input`` equals zero. Pixels
    outside the ROI (presumably building footprints, which are handled by
    ``building_segments`` — TODO confirm) contribute nothing.

    Args:
        input: 4-D tensor indexed as ``input[:, 0, :, :]``; only channel 0 is read.
        pred: predictions, broadcastable against ``input[:, 0:1, :, :]``.
        target: ground truth, broadcastable against ``pred``.

    Returns:
        0-dim tensor ``sqrt(sum of ROI squared errors / number of ROI elements)``.
        It is 0 when the ROI is empty or the prediction is exact inside it.
    """
    roi_mask = input[:, 0, :, :].unsqueeze(1) == 0
    squared_error = torch.where(roi_mask, (pred - target) ** 2, 0)
    # Normalise by the number of ROI elements. Counting non-zero errors would
    # drop exactly-predicted ROI pixels from the mean (inflating it) and
    # would give 0/0 = nan for a perfect prediction. clamp_min(1) makes an
    # empty ROI yield 0 instead of nan.
    roi_count = roi_mask.expand_as(squared_error).sum().clamp_min(1)
    # NOTE: the derivative of x ** 0.5 is infinite at x == 0, as in RMSE.
    return (squared_error.sum() / roi_count) ** 0.5