train.py
import os
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"   ## Suppress TensorFlow INFO and WARNING logs
import numpy as np
import cv2
from glob import glob
from sklearn.utils import shuffle
import tensorflow as tf
from tensorflow.keras.callbacks import ModelCheckpoint, CSVLogger, ReduceLROnPlateau, EarlyStopping, TensorBoard
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.metrics import Recall, Precision
from model import deeplabv3_plus
from metrics import dice_loss, dice_coef, iou
""" Global parameters """
H = 512
W = 512
""" Creating a directory """
def create_dir(path):
if not os.path.exists(path):
os.makedirs(path)
def shuffling(x, y):
x, y = shuffle(x, y, random_state=42)
return x, y
def load_data(path):
    x = sorted(glob(os.path.join(path, "image", "*.png")))
    y = sorted(glob(os.path.join(path, "mask", "*.png")))
    return x, y
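
## Expected directory layout, inferred from the glob patterns above and the
## paths set in __main__:
##   new_data/train/image/*.png   new_data/train/mask/*.png
##   new_data/test/image/*.png    new_data/test/mask/*.png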
def read_image(path):
    path = path.decode()
    x = cv2.imread(path, cv2.IMREAD_COLOR)   ## (H, W, 3)
    x = x / 255.0                            ## Scale pixel values to [0, 1]
    x = x.astype(np.float32)
    return x

def read_mask(path):
    path = path.decode()
    x = cv2.imread(path, cv2.IMREAD_GRAYSCALE)   ## (H, W)
    x = x / 255.0        ## Scale to [0, 1]; assumes binary masks stored as 0/255
    x = x.astype(np.float32)
    x = np.expand_dims(x, axis=-1)               ## (H, W, 1)
    return x
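
## Note: neither reader resizes its input, so every image and mask is assumed
## to already be H x W (512 x 512). If your data varies in size, a resize such
## as `x = cv2.resize(x, (W, H))` (not part of the original pipeline) would be
## needed before normalization.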
def tf_parse(x, y):
    def _parse(x, y):
        x = read_image(x)
        y = read_mask(y)
        return x, y

    ## numpy_function outputs lose static shape information, so restore it.
    x, y = tf.numpy_function(_parse, [x, y], [tf.float32, tf.float32])
    x.set_shape([H, W, 3])
    y.set_shape([H, W, 1])
    return x, y
def tf_dataset(X, Y, batch=2):
    dataset = tf.data.Dataset.from_tensor_slices((X, Y))
    dataset = dataset.map(tf_parse)
    dataset = dataset.batch(batch)
    dataset = dataset.prefetch(10)
    return dataset
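
## A fixed prefetch buffer of 10 is used above; `tf.data.AUTOTUNE` is a common
## alternative if you prefer to let TensorFlow tune the buffer size itself.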
if __name__ == "__main__":
    """ Seeding """
    np.random.seed(42)
    tf.random.set_seed(42)

    """ Directory for storing files """
    create_dir("files")

    """ Hyperparameters """
    batch_size = 2
    lr = 1e-4
    num_epochs = 20
    model_path = os.path.join("files", "model.h5")
    csv_path = os.path.join("files", "data.csv")

    """ Dataset """
    dataset_path = "new_data"
    train_path = os.path.join(dataset_path, "train")
    valid_path = os.path.join(dataset_path, "test")

    train_x, train_y = load_data(train_path)
    train_x, train_y = shuffling(train_x, train_y)
    valid_x, valid_y = load_data(valid_path)

    print(f"Train: {len(train_x)} - {len(train_y)}")
    print(f"Valid: {len(valid_x)} - {len(valid_y)}")

    train_dataset = tf_dataset(train_x, train_y, batch=batch_size)
    valid_dataset = tf_dataset(valid_x, valid_y, batch=batch_size)
""" Model """
model = deeplabv3_plus((H, W, 3))
model.compile(loss=dice_loss, optimizer=Adam(lr), metrics=[dice_coef, iou, Recall(), Precision()])
callbacks = [
ModelCheckpoint(model_path, verbose=1, save_best_only=True),
ReduceLROnPlateau(monitor='val_loss', factor=0.1, patience=5, min_lr=1e-7, verbose=1),
CSVLogger(csv_path),
TensorBoard(),
EarlyStopping(monitor='val_loss', patience=20, restore_best_weights=False),
]
model.fit(
train_dataset,
epochs=num_epochs,
validation_data=valid_dataset,
callbacks=callbacks
)
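
## A minimal sketch of reloading the saved model for inference (assumes the
## custom loss and metrics imported from metrics.py above; the custom_objects
## keys must match the names used in compile()):
##
##   from tensorflow.keras.models import load_model
##   model = load_model("files/model.h5", custom_objects={
##       "dice_loss": dice_loss, "dice_coef": dice_coef, "iou": iou})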