-
Notifications
You must be signed in to change notification settings - Fork 0
/
train.py
123 lines (95 loc) · 3.32 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
import os
import time
from glob import glob
import torch
from torch.utils.data import DataLoader
import torch.nn as nn
from preprocess import DriveDataset
from model import build_unet
from loss import DiceLoss, DiceBCELoss
from utils import seeding, create_dir, epoch_time
def train(model, loader, optimizer, loss_fn, device):
    """Run one training epoch and return the mean per-batch loss.

    Args:
        model: Network to optimise; it is put into training mode here.
        loader: Iterable of ``(x, y)`` batches, e.g. a ``DataLoader``. It does
            not need to define ``__len__``; batches are counted as consumed.
        optimizer: Optimizer over ``model``'s parameters; stepped once per batch.
        loss_fn: Callable ``loss_fn(y_pred, y)`` returning a single-element
            tensor (``.backward()`` and ``.item()`` are called on it).
        device: Device that inputs and targets are moved to.

    Returns:
        The arithmetic mean of the per-batch loss values. Every batch is
        weighted equally, regardless of how many samples it holds.

    Raises:
        ValueError: If ``loader`` yields no batches (otherwise the average
            would be a division by zero).
    """
    model.train()
    total_loss = 0.0
    num_batches = 0
    for x, y in loader:
        # Cast both tensors to float32 on the target device, whatever dtype
        # the loader produced.
        x = x.to(device, dtype=torch.float32)
        y = y.to(device, dtype=torch.float32)

        optimizer.zero_grad()
        y_pred = model(x)
        loss = loss_fn(y_pred, y)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        num_batches += 1

    if num_batches == 0:
        raise ValueError("train() received a loader that yielded no batches")
    return total_loss / num_batches


def evaluate(model, loader, loss_fn, device):
    """Return the mean per-batch loss of ``model`` over ``loader`` without training.

    The model is switched to evaluation mode and no gradients are tracked, so
    parameters and their ``.grad`` fields are left untouched.

    Args:
        model: Network to evaluate; it is put into evaluation mode here.
        loader: Iterable of ``(x, y)`` batches, e.g. a ``DataLoader``. It does
            not need to define ``__len__``; batches are counted as consumed.
        loss_fn: Callable ``loss_fn(y_pred, y)`` returning a single-element
            tensor (``.item()`` is called on it).
        device: Device that inputs and targets are moved to.

    Returns:
        The arithmetic mean of the per-batch loss values. Every batch is
        weighted equally, regardless of how many samples it holds.

    Raises:
        ValueError: If ``loader`` yields no batches (otherwise the average
            would be a division by zero).
    """
    model.eval()
    total_loss = 0.0
    num_batches = 0
    with torch.no_grad():
        for x, y in loader:
            # Cast both tensors to float32 on the target device, whatever
            # dtype the loader produced.
            x = x.to(device, dtype=torch.float32)
            y = y.to(device, dtype=torch.float32)

            y_pred = model(x)
            total_loss += loss_fn(y_pred, y).item()
            num_batches += 1

    if num_batches == 0:
        raise ValueError("evaluate() received a loader that yielded no batches")
    return total_loss / num_batches


if __name__ == "__main__":
    # Seeding
    seeding(42)

    # Directories
    create_dir("files")

    # Load dataset. Images and masks are paired purely by sorted filename
    # order, so each image list must line up one-to-one with its mask list.
    train_x = sorted(glob("data/train/image/*"))
    train_y = sorted(glob("data/train/mask/*"))
    # NOTE(review): the validation split is read from data/test, so the
    # "best" checkpoint below is selected on the test set. A separate
    # held-out validation split would avoid leaking test data into model
    # selection.
    valid_x = sorted(glob("data/test/image/*"))
    valid_y = sorted(glob("data/test/mask/*"))

    # Fail early with a clear message instead of a ZeroDivisionError or a
    # silently misaligned image/mask pairing later on.
    for split_dir, images, masks in (
        ("data/train", train_x, train_y),
        ("data/test", valid_x, valid_y),
    ):
        if not images:
            raise FileNotFoundError(f"No images found under {split_dir}/image/")
        if len(images) != len(masks):
            raise ValueError(
                f"{split_dir}: found {len(images)} images but {len(masks)} masks; "
                "they are paired by sorted order and must match one-to-one"
            )

    data_str = f"Dataset Size:\nTrain: {len(train_x)} - Valid: {len(valid_x)}\n"
    print(data_str)

    # Hyperparameters
    batch_size = 2
    num_epochs = 50
    lr = 1e-4
    checkpoint_path = "files/checkpoint.pth"

    # Dataset and loader
    train_dataset = DriveDataset(train_x, train_y)
    valid_dataset = DriveDataset(valid_x, valid_y)

    train_loader = DataLoader(
        dataset=train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=2,
    )
    valid_loader = DataLoader(
        dataset=valid_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=2,
    )

    # Developed on a GTX 1060 6GB; fall back to CPU when no CUDA device is
    # available instead of failing at the first .to(device) call.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = build_unet()
    model = model.to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    # `verbose` is deprecated since PyTorch 2.2; LR reductions are reported
    # manually in the training loop instead.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=5)
    loss_fn = DiceBCELoss()

    # Training the model
    best_valid_loss = float("inf")

    for epoch in range(num_epochs):
        start_time = time.time()

        train_loss = train(model, train_loader, optimizer, loss_fn, device)
        valid_loss = evaluate(model, valid_loader, loss_fn, device)

        # Feed the validation loss to the plateau scheduler; without this call
        # the learning rate is never reduced.
        prev_lr = optimizer.param_groups[0]["lr"]
        scheduler.step(valid_loss)
        new_lr = optimizer.param_groups[0]["lr"]
        if new_lr < prev_lr:
            print(f"Reducing learning rate from {prev_lr:.2e} to {new_lr:.2e}")

        # Saving the model whenever validation loss improves
        if valid_loss < best_valid_loss:
            data_str = f"Valid loss improved from {best_valid_loss:2.4f} to {valid_loss:2.4f}. Saving checkpoint: {checkpoint_path}"
            print(data_str)

            best_valid_loss = valid_loss
            torch.save(model.state_dict(), checkpoint_path)

        end_time = time.time()
        epoch_mins, epoch_secs = epoch_time(start_time, end_time)

        data_str = f'Epoch: {epoch+1:02} | Epoch Time: {epoch_mins}m {epoch_secs}s\n'
        data_str += f'\tTrain Loss: {train_loss:.3f}\n'
        data_str += f'\t Val. Loss: {valid_loss:.3f}\n'
        print(data_str)