-
Notifications
You must be signed in to change notification settings - Fork 0
/
engiine.py
153 lines (117 loc) · 4.81 KB
/
engiine.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
import os
import torch
from tqdm.autonotebook import tqdm
from typing import Dict, List, Tuple
import wandb
import logger
def train_step(model: torch.nn.Module,
               dataloader: torch.utils.data.DataLoader,
               Accumulation_steps: int,
               loss_fn: torch.nn.Module,
               optimizer: torch.optim.Optimizer,
               device: torch.device,
               batch_cnt,
               exp_cnt,
               epoch) -> Tuple[float, float]:
    """Run one training epoch with gradient accumulation.

    Args:
        model: Network to train; switched to train mode here.
        dataloader: Yields ``(inputs, targets)`` batches.
        Accumulation_steps: Number of batches whose gradients are
            accumulated before each optimizer step.
        loss_fn: Criterion applied to the model outputs.
        optimizer: Stepped once per accumulation window.
        device: Device the batches are moved to.
        batch_cnt: Running batch counter used to pace logging.
        exp_cnt: Running example counter forwarded to the logger.
        epoch: Current epoch index forwarded to the logger.

    Returns:
        Tuple of (average loss per batch, average accuracy per batch).
    """
    model.train()
    train_loss, train_acc = 0.0, 0.0
    num_batches = len(dataloader)
    for step, (inputs, targets) in tqdm(enumerate(dataloader), total=num_batches):
        inputs, targets = inputs.to(device), targets.to(device)
        # Forward pass and loss on the raw (unscaled) batch.
        logits = model(inputs)
        loss = loss_fn(logits, targets)
        exp_cnt += len(inputs)
        batch_cnt += 1
        train_loss += loss.item()
        # Periodic logging of the unscaled loss, every 25 batches.
        if (batch_cnt + 1) % 25 == 0:
            logger.train_log(loss, exp_cnt, epoch)
        # Scale so the accumulated gradient averages over the window.
        loss = loss / Accumulation_steps
        loss.backward()
        # Step at the end of each window, and on the final (possibly short) one.
        if (step + 1) % Accumulation_steps == 0 or step + 1 == num_batches:
            optimizer.step()
            optimizer.zero_grad()
        # argmax over logits gives the same labels as argmax over softmax.
        predictions = logits.argmax(dim=1)
        train_acc += (predictions == targets).sum().item() / len(logits)
    # Averages per batch over the whole epoch.
    return train_loss / num_batches, train_acc / num_batches
###########################
def val_step(model: torch.nn.Module,
             val_dataloader: torch.utils.data.DataLoader,
             loss_fn: torch.nn.Module,
             device: torch.device) -> Tuple[float, float]:
    """Evaluate ``model`` over one full pass of ``val_dataloader``.

    Args:
        model: Network to evaluate; switched to eval mode here.
        val_dataloader: Yields ``(inputs, targets)`` batches.
        loss_fn: Criterion applied to the model outputs.
        device: Device the batches are moved to.

    Returns:
        Tuple of (average loss per batch, average accuracy per batch).
    """
    # Put model in eval mode
    model.eval()
    # Setup val loss and val accuracy values
    val_loss, val_acc = 0, 0
    # Fix: the original commented "Turn on inference context manager" but
    # never entered one, so validation tracked gradients. no_grad() disables
    # autograd here, saving memory and compute without changing the metrics.
    with torch.no_grad():
        # Loop through DataLoader batches
        for batch, (X, y) in enumerate(val_dataloader):
            # Send data to target device
            X, y = X.to(device), y.to(device)
            # 1. Forward pass
            val_pred_logits = model(X)
            # 2. Calculate and accumulate loss
            loss = loss_fn(val_pred_logits, y)
            val_loss += loss.item()
            # Calculate and accumulate accuracy
            val_pred_labels = val_pred_logits.argmax(dim=1)
            val_acc += ((val_pred_labels == y).sum().item() / len(val_pred_labels))
    # Adjust metrics to get average loss and accuracy per batch
    val_loss = val_loss / len(val_dataloader)
    val_acc = val_acc / len(val_dataloader)
    return val_loss, val_acc
########Main func for Training#############
def train(model: torch.nn.Module,
          train_dataloader: torch.utils.data.DataLoader,
          val_dataloader: torch.utils.data.DataLoader,
          Accumulation_steps,
          optimizer: torch.optim.Optimizer,
          loss_fn: torch.nn.Module,
          epochs: int,
          device: torch.device) -> Dict[str, List]:
    """Train and validate ``model`` for ``epochs`` epochs.

    Args:
        model: Network to train.
        train_dataloader: Training batches.
        val_dataloader: Validation batches.
        Accumulation_steps: Gradient-accumulation window passed to train_step.
        optimizer: Optimizer used by train_step.
        loss_fn: Criterion used for both training and validation.
        epochs: Number of epochs to run.
        device: Device batches are moved to.

    Returns:
        Dict with per-epoch lists: "train_loss", "train_acc",
        "val_loss", "val_acc".
    """
    results = {"train_loss": [],
               "train_acc": [],
               "val_loss": [],
               "val_acc": []}
    batch_cnt = 0
    exp_cnt = 0
    # Fix: watch() was called inside the epoch loop, re-registering wandb
    # hooks on the model every epoch; register them exactly once.
    wandb.watch(model, loss_fn, log="all", log_freq=100)
    for epoch in range(epochs):
        train_loss, train_acc = train_step(model=model,
                                           dataloader=train_dataloader,
                                           Accumulation_steps=Accumulation_steps,
                                           loss_fn=loss_fn,
                                           optimizer=optimizer,
                                           device=device,
                                           batch_cnt=batch_cnt,
                                           exp_cnt=exp_cnt,
                                           epoch=epoch)
        # NOTE(review): train_step increments batch_cnt/exp_cnt locally but
        # does not return them, so these counters stay 0 across epochs and
        # the logging cadence restarts each epoch — confirm whether
        # cumulative counts were intended.
        val_loss, val_acc = val_step(model=model,
                                     val_dataloader=val_dataloader,
                                     loss_fn=loss_fn,
                                     device=device)
        # Print out what's happening
        print(
            f"Epoch: {epoch+1} | "
            f"train_loss: {train_loss:.4f} | "
            f"train_acc: {train_acc:.4f} | "
            f"val_loss: {val_loss:.4f} | "
            f"val_acc: {val_acc:.4f} | "
        )
        # Fix: validation metrics were appended under the "train_*" keys,
        # leaving results["val_loss"] / results["val_acc"] permanently empty.
        results["train_loss"].append(train_loss)
        results["train_acc"].append(train_acc)
        results["val_loss"].append(val_loss)
        results["val_acc"].append(val_acc)
    return results