train.py
import logging
import os

import torch
import wandb
from tqdm.auto import tqdm

from utils.utils import get_device, save_checkpoint
from validate import early_stop_validation, validate_fn


def train_fn(loader, model, optimizer, loss_fn, scaler, config):
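    """Train the model for one epoch with mixed-precision (AMP) updates.

    Logs the per-batch loss and the epoch-mean loss to wandb and returns the
    loss of the final batch.
    """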
    device = get_device(config)
    logging.info("Training model...")
    loop = tqdm(
        loader,
        position=1,
        leave=False,
        postfix={"loss": 0.0},
        desc="Training One Epoch: ",
    )
    closs = 0.0
    model.train()
    for data, targets in loop:
        data, targets = data.to(device), targets.long().to(device)

        # forward
        optimizer.zero_grad()
        with torch.cuda.amp.autocast():
            predictions = model(data)
            loss = loss_fn(predictions, targets)

        # backward
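        # GradScaler scales the loss so small fp16 gradients do not underflow;
        # scaler.step() unscales them before the optimizer update, and
        # scaler.update() adjusts the scale factor for the next iteration.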
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        # update tqdm loop
        loop.set_postfix(loss=loss.item())

        # wandb logging
        wandb.log({"batch loss": loss.item()})
        closs += loss.item()
wandb.log({"loss": closs / config.hyperparameters.batch_size})
    loop.close()
    return loss.item()


def train_loop(
    train_loader,
    val_loader,
    model,
    optimizer,
    scheduler,
    loss_fn,
    scaler,
    stopping,
    global_metrics,
    label_metrics,
    config,
):
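    """Run training from config.project.epoch to config.project.num_epochs.

    Each epoch trains once, validates once, builds a checkpoint (model,
    optimizer, and scheduler state), and feeds the validation loss to the
    early-stopping callback; training stops early when the callback triggers.
    """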
    checkpoint_dir = "data/checkpoints/"
    os.makedirs(checkpoint_dir, exist_ok=True)
    name = f"{config.project.time}_best_checkpoint.pth.tar"
    best_checkpoint_path = os.path.join(checkpoint_dir, name)

    outer_loop = tqdm(
        range(config.project.epoch, config.project.num_epochs),
        position=0,
        postfix={"loss": 0.0, "val_loss": 0.0},
        desc="Starting Training: ",
    )
    for epoch in outer_loop:
        outer_loop.set_description_str(f"EPOCH {epoch}|{config.project.num_epochs}: ")
logging.info(f"Starting epoch {epoch}...")
config.project.epoch = epoch
wandb.log({"epoch": epoch})
train_loss = train_fn(train_loader, model, optimizer, loss_fn, scaler, config)
# check accuracy
val_loss = validate_fn(
val_loader,
model,
loss_fn,
scheduler,
global_metrics,
label_metrics,
config,
)
outer_loop.set_postfix(loss=train_loss, val_loss=val_loss)
# save model
logging.info("Saving trained weights...")
checkpoint = {
"epoch": epoch,
"train_loss": train_loss,
"val_loss": val_loss,
"state_dict": model.state_dict(),
"optimizer": optimizer.state_dict(),
"scheduler": scheduler.state_dict(),
}
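        # update the early-stopping state with this epoch's validation loss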
        stopping(
            val_loss,
            checkpoint,
            checkpoint_path=best_checkpoint_path,
            epoch=epoch,
        )
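        # until early stopping triggers, save the latest checkpoint and continue;
        # once it does, run a final validation pass and leave the loop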
        if not stopping.early_stop:
            save_checkpoint(checkpoint)
            continue
        early_stop_validation(val_loader, model, global_metrics, label_metrics, config)
        break

    wandb.finish()
    logging.info("Training finished...")