-
Notifications
You must be signed in to change notification settings - Fork 0
/
train.py
68 lines (40 loc) · 1.77 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
"""
@author: sourav
We define the loss functions and write the training loop in this module.
"""
import numpy as np
import tensorflow as tf
from tensorflow import image
from tensorflow import keras
from tensorflow import math
from keras import losses
import losses.MeanSquaredError as MSE
from keras import layers
from keras import utils
from keras import metrics
from keras import backend as K
from keras import initialiazers
import os
import random
import matplotlib.pyplot as plt
###########################################################################################################################
# defining the CB loss function
def norm_CB(z, l_cutoff = 0.495, u_cutoff = 0.505):
    """Elementwise normalizing constant C(z) of the continuous Bernoulli.

    The closed form ``2*atanh(1-2z)/(1-2z)`` is a 0/0 singularity at
    z = 0.5, so for z inside ``(l_cutoff, u_cutoff)`` it is replaced by
    its (even) Taylor expansion around 0.5.

    Args:
        z: tensor of probabilities; expected in (0, 1) — TODO confirm
            caller guarantees this (values are clipped here anyway).
        l_cutoff: lower edge of the Taylor-approximation interval.
        u_cutoff: upper edge of the Taylor-approximation interval.

    Returns:
        A tensor the same shape as ``z`` holding the normalizing constant.
    """
    # True where z is close enough to 0.5 that the analytic form is unstable.
    gate = math.logical_and(math.greater(z, l_cutoff), math.greater(u_cutoff, z))
    # Clip away exact 0 and 1, where atanh diverges.
    z = tf.clip_by_value(z, clip_value_min = K.epsilon(), clip_value_max = 1 - K.epsilon())
    # BUG FIX: the original referenced an undefined name `z_reg` (NameError
    # on first call); the analytic branch must be evaluated on the clipped z.
    # NOTE(review): at z == 0.5 this branch is NaN, but tf.where selects the
    # Taylor branch there, so the returned VALUES are finite.
    norm_reg = (2*math.atanh(1 - 2*z))/(1 - 2*z)
    # Taylor expansion of the same expression around z = 0.5 (even powers).
    norm_taylor = (2.0
                   + (8.0/3.0)*math.pow(z - 0.5, 2)
                   + (32.0/5.0)*math.pow(z - 0.5, 4)
                   + (128.0/7.0)*math.pow(z - 0.5, 6))
    norm = tf.where(gate, norm_taylor, norm_reg)
    return norm
@tf.function
def CB_logloss(true, pred):
    """Binary cross-entropy corrected by the continuous-Bernoulli constant.

    Both tensors are flattened per sample; the mean log normalizing
    constant of the CB distribution (see ``norm_CB``) is subtracted from
    the per-sample BCE, and the batch mean is returned.
    """
    true_flat = layers.Flatten()(true)
    pred_flat = layers.Flatten()(pred)
    per_sample_bce = losses.binary_crossentropy(true_flat, pred_flat)
    mean_log_norm = tf.reduce_mean(math.log(norm_CB(pred_flat)), axis=-1)
    per_sample_loss = per_sample_bce - mean_log_norm
    return tf.reduce_mean(per_sample_loss)
##########################################################################################################################
# SSIM loss
def SSIMloss(true, pred):
    """Structural-similarity loss: 1 - SSIM, with a dynamic range of 1.0."""
    similarity = image.ssim(true, pred, 1.0)
    return 1 - similarity
##########################################################################################################################