trainer.py
import os
import random
import time

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F
import wandb
from tqdm import tqdm

from dataset import CLASSES
from functions import dice_coef
from tools.streamlit.visualize import visualize_prediction

WRIST_CLASSES = [
    'Trapezium', 'Trapezoid', 'Capitate', 'Hamate',
    'Scaphoid', 'Lunate', 'Triquetrum', 'Pisiform'
]
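# These eight wrist bones occupy channels 19:27 of the full CLASSES list,
# which is why the ROI code paths below slice `[:, 19:27]`.
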
def convert_seconds_to_hms(seconds):
    """Convert a duration in seconds to an HH:MM:SS string."""
    hours = seconds // 3600
    minutes = (seconds % 3600) // 60
    seconds = seconds % 60
    return f"{int(hours):02d}:{int(minutes):02d}:{int(seconds):02d}"
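# For example, convert_seconds_to_hms(3725) returns "01:02:05".
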
def train(model, data_loader, val_loader, criterion, optimizer, scheduler,
          num_epochs, val_interval, save_dir, use_roi=False):
    print('Start training..')
    best_dice = 0.
    total_start_time = time.time()  # start of total training time
    best_model_path = None          # path of the current best checkpoint
    model = model.cuda()

    for epoch in range(num_epochs):
        epoch_start_time = time.time()
        model.train()
        epoch_loss = 0
        class_losses = torch.zeros(len(CLASSES)).cuda()

        with tqdm(total=len(data_loader), desc=f"Epoch [{epoch+1}/{num_epochs}]") as pbar:
            for step, (images, masks) in enumerate(data_loader):
                images, masks = images.cuda(), masks.cuda()
                # outputs = model(images)['out']
                outputs = model(images)

                # In ROI mode, only the wrist-bone classes (channels 19:27)
                # contribute to the loss.
                if use_roi:
                    loss = criterion(outputs[:, 19:27], masks[:, 19:27])
                else:
                    loss = criterion(outputs, masks)

                # Per-class loss
                for c in range(len(CLASSES)):
                    class_losses[c] += criterion(outputs[:, c:c+1], masks[:, c:c+1]).item()

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                epoch_loss += loss.item()
                pbar.set_postfix(loss=round(loss.item(), 4))
                pbar.update(1)

        # Per-epoch averages
        epoch_loss = epoch_loss / len(data_loader)
        class_losses = class_losses / len(data_loader)

        # Epoch timing and remaining-time estimate, formatted as HH:MM:SS
        epoch_time = time.time() - epoch_start_time
        remaining_epochs = num_epochs - (epoch + 1)
        estimated_remaining_time = epoch_time * remaining_epochs
        epoch_str = convert_seconds_to_hms(epoch_time)
        remaining_str = convert_seconds_to_hms(estimated_remaining_time)
        print(
            f'Epoch {epoch+1} completed in {epoch_str}. '
            f'Estimated remaining time: {remaining_str}.'
        )

        # wandb logging - train metrics
        metrics = {
            "total/train_loss": epoch_loss,  # overall train loss
            **{f"train_loss_per_class/{c}": cls_loss.item()
               for c, cls_loss in zip(CLASSES, class_losses)},  # per-class train loss
            "epoch": epoch + 1
        }

        # Run validation every val_interval epochs
        if (epoch + 1) % val_interval == 0:
            # In ROI mode, the Dice coefficient is computed over the
            # wrist-bone classes only.
            if use_roi:
                dice, val_loss, class_val_losses, worst_samples, dices_per_class = validation_roi(
                    epoch + 1, model, val_loader, criterion)
                cls = WRIST_CLASSES
            else:
                dice, val_loss, class_val_losses, worst_samples, dices_per_class = validation(
                    epoch + 1, model, val_loader, criterion)
                cls = CLASSES

            # Step the LR scheduler once per validation cycle
            scheduler.step()

            # Log the current learning rate
            current_lr = optimizer.param_groups[0]['lr']
            metrics["learning_rate"] = current_lr

            # wandb logging - validation metrics
            metrics.update({
                "total/val_loss": val_loss,  # overall validation loss
                "total/val_dice": dice,      # overall validation Dice
                **{f"val_loss_per_class/{c}": cls_loss
                   for c, cls_loss in zip(CLASSES, class_val_losses)},  # per-class validation loss
                **{f"val_dice_per_class/{c}": d.item()
                   for c, d in zip(cls, dices_per_class)}  # per-class validation Dice
            })

            # Visualize the worst-scoring validation samples
            for idx, (img, pred, true_mask, dice_score) in enumerate(worst_samples):
                if use_roi:
                    fig = visualize_prediction(img, pred[19:27], true_mask[19:27])
                else:
                    fig = visualize_prediction(img, pred, true_mask)
                metrics[f"worst_sample_{idx+1}"] = wandb.Image(
                    fig, caption=f"Dice Score: {dice_score:.4f}")
                plt.close(fig)

            # Save the best model, keep only the latest best checkpoint,
            # and record it in the wandb run summary.
            if best_dice < dice:
                print(f"Best performance at epoch: {epoch + 1}, {best_dice:.4f} -> {dice:.4f}")
                best_dice = dice
                if best_model_path and os.path.exists(best_model_path):
                    os.remove(best_model_path)
                best_model_path = os.path.join(save_dir, f"best_dice_{best_dice:.4f}.pt")
                save_model(model, best_model_path)
                wandb.run.summary.update({
                    "best_dice": best_dice,
                    "best_epoch": epoch + 1,
                    "best_model_path": best_model_path
                })

        # wandb logging
        wandb.log(metrics)

    total_time = time.time() - total_start_time
    total_str = convert_seconds_to_hms(total_time)
    print(f'Total training completed in {total_str}.')

def validation(epoch, model, data_loader, criterion, thr=0.5, num_worst_samples=4):
    print(f'Start validation #{epoch:2d}')
    model.eval()

    dices = []
    samples = []
    total_loss = 0
    class_losses = torch.zeros(len(CLASSES)).cuda()

    with torch.no_grad():
        for step, (images, masks) in tqdm(enumerate(data_loader), total=len(data_loader)):
            images, masks = images.cuda(), masks.cuda()
            outputs = model(images)

            output_h, output_w = outputs.size(-2), outputs.size(-1)
            mask_h, mask_w = masks.size(-2), masks.size(-1)

            # If the prediction and ground truth differ in size,
            # resize the prediction to match the ground truth.
            if output_h != mask_h or output_w != mask_w:
                outputs = F.interpolate(outputs, size=(mask_h, mask_w), mode="bilinear")

            # Overall loss
            loss = criterion(outputs, masks)
            total_loss += loss.item()

            # Per-class loss
            for c in range(len(CLASSES)):
                class_losses[c] += criterion(outputs[:, c:c+1], masks[:, c:c+1]).item()

            outputs = torch.sigmoid(outputs)
            outputs = (outputs > thr).detach().cpu()
            masks = masks.detach().cpu()

            # Dice score for each image in the batch
            batch_dices = dice_coef(outputs, masks)
            dices.append(batch_dices)

            # Collect samples for worst-case analysis
            for i in range(len(images)):
                sample_dice = batch_dices[i].mean().item()
                samples.append((
                    images[i].cpu().numpy().transpose(1, 2, 0),
                    outputs[i].numpy(),
                    masks[i].cpu().numpy(),
                    sample_dice
                ))

    dices = torch.cat(dices, 0)
    dices_per_class = torch.mean(dices, 0)

    dice_str = "\n".join(
        f"{c:<12}: {d.item():.4f}"
        for c, d in zip(CLASSES, dices_per_class)
    )
    print(dice_str)

    avg_loss = total_loss / len(data_loader)
    class_losses = class_losses / len(data_loader)
    avg_dice = torch.mean(dices_per_class).item()

    # Keep the lowest-scoring samples
    worst_samples = sorted(samples, key=lambda x: x[3])[:num_worst_samples]

    return avg_dice, avg_loss, class_losses.cpu().numpy(), worst_samples, dices_per_class

def validation_roi(epoch, model, data_loader, criterion, thr=0.5, num_worst_samples=4):
    print(f'Start validation #{epoch:2d}')
    model.eval()

    dices = []
    samples = []
    total_loss = 0
    class_losses = torch.zeros(len(CLASSES)).cuda()

    with torch.no_grad():
        for step, (images, masks) in tqdm(enumerate(data_loader), total=len(data_loader)):
            images, masks = images.cuda(), masks.cuda()
            # outputs = model(images)['out']
            outputs = model(images)

            output_h, output_w = outputs.size(-2), outputs.size(-1)
            mask_h, mask_w = masks.size(-2), masks.size(-1)

            # If the prediction and ground truth differ in size,
            # resize the prediction to match the ground truth.
            if output_h != mask_h or output_w != mask_w:
                outputs = F.interpolate(outputs, size=(mask_h, mask_w), mode="bilinear")

            # Overall loss (all classes)
            loss = criterion(outputs, masks)
            total_loss += loss.item()

            # Per-class loss
            for c in range(len(CLASSES)):
                class_losses[c] += criterion(outputs[:, c:c+1], masks[:, c:c+1]).item()

            outputs = torch.sigmoid(outputs)
            outputs = (outputs > thr).detach().cpu()
            masks = masks.detach().cpu()

            # Dice score per image, restricted to the wrist-bone channels
            batch_dices = dice_coef(outputs[:, 19:27], masks[:, 19:27])
            dices.append(batch_dices)

            # Collect samples for worst-case analysis
            for i in range(len(images)):
                sample_dice = batch_dices[i].mean().item()
                samples.append((
                    images[i].cpu().numpy().transpose(1, 2, 0),
                    outputs[i].numpy(),
                    masks[i].cpu().numpy(),
                    sample_dice
                ))

    dices = torch.cat(dices, 0)
    dices_per_class = torch.mean(dices, 0)

    # Mean Dice over the wrist classes
    avg_dice_target = torch.mean(dices_per_class).item()

    dice_str = "\n".join(
        f"{c:<12}: {d.item():.4f}"
        for c, d in zip(WRIST_CLASSES, dices_per_class)
    )
    print(dice_str)

    avg_loss = total_loss / len(data_loader)
    class_losses = class_losses / len(data_loader)

    # Keep the lowest-scoring samples
    worst_samples = sorted(samples, key=lambda x: x[3])[:num_worst_samples]

    return avg_dice_target, avg_loss, class_losses.cpu().numpy(), worst_samples, dices_per_class

def save_model(model, model_path):
    # Saves the entire module object, not just the state_dict.
    torch.save(model, model_path)
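# Counterpart for loading the full-module checkpoint saved above (recent
# PyTorch versions may require passing weights_only=False to torch.load):
# model = torch.load(model_path, weights_only=False)
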
def set_seed(seed=123):
    """Fix all RNG seeds for reproducible runs."""
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    np.random.seed(seed)
    random.seed(seed)
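
# Minimal usage sketch (illustrative only). The model, criterion, optimizer,
# scheduler, and data loaders below are placeholders, not the project's actual
# training configuration; substitute your own components. train() logs to
# wandb, so wandb.init() must be called first.
if __name__ == "__main__":
    import torch.nn as nn
    import torch.optim as optim

    set_seed(123)

    # Placeholder 1x1-conv "model" standing in for a real segmentation network;
    # its output has len(CLASSES) channels at the input resolution.
    model = nn.Conv2d(3, len(CLASSES), kernel_size=1)
    criterion = nn.BCEWithLogitsLoss()
    optimizer = optim.Adam(model.parameters(), lr=1e-4)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=50)

    # train_loader / valid_loader must yield (image, mask) batches shaped
    # (N, 3, H, W) and (N, len(CLASSES), H, W) respectively.
    # wandb.init(project="segmentation")  # hypothetical project name
    # train(model, train_loader, valid_loader, criterion, optimizer, scheduler,
    #       num_epochs=50, val_interval=5, save_dir="./checkpoints", use_roi=False)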