-
Notifications
You must be signed in to change notification settings - Fork 0
/
trainer.py
32 lines (19 loc) · 864 Bytes
/
trainer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
import pandas as pd
def train(args, model, dataloader, criterion, optimizer, device, summary_path='summary.csv'):
    """Train ``model`` for ``args.epochs`` epochs and log the mean loss per epoch.

    Each epoch calls ``model.train()``, then for every ``(x, y)`` batch from
    ``dataloader`` moves both to ``device``, runs forward/backward, and takes
    one optimizer step. The epoch loss is the mean of ``loss.item()`` over the
    batches. After every epoch the accumulated summary is written to
    ``summary_path`` as CSV, so completed epochs are on disk even if a later
    epoch fails.

    Args:
        args: Object with an integer ``epochs`` attribute.
        model: Callable with a ``train()`` method; invoked as ``model(x)``.
        dataloader: Re-iterable of ``(x, y)`` pairs whose elements support
            ``.to(device)``.
        criterion: Callable ``criterion(pred, y)`` returning an object with
            ``backward()`` and ``item()``.
        optimizer: Object with ``zero_grad()`` and ``step()``.
        device: Passed through unchanged to ``x.to`` / ``y.to``.
        summary_path: Destination of the CSV summary (columns ``Epoch``,
            ``Loss``). Defaults to ``'summary.csv'`` in the working directory.

    Returns:
        pandas.DataFrame with one ``(Epoch, Loss)`` row per completed epoch
        (empty, with those columns, when ``args.epochs`` is 0).

    Raises:
        ValueError: If ``dataloader`` yields no batches in an epoch, since the
            mean loss would be undefined.
    """
    columns = ['Epoch', 'Loss']
    records = []
    summary = pd.DataFrame(columns=columns)
    for epoch in range(args.epochs):
        print(f'Epoch {epoch}')
        model.train()
        total_loss = 0.0
        n_batches = 0
        for x, y in dataloader:
            x, y = x.to(device), y.to(device)
            optimizer.zero_grad()
            pred = model(x)
            loss = criterion(pred, y)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            n_batches += 1
        # Guard the division: an empty loader previously crashed with a bare
        # ZeroDivisionError that did not say what was wrong.
        if n_batches == 0:
            raise ValueError(f'dataloader yielded no batches in epoch {epoch}')
        train_loss = total_loss / n_batches
        print(f'Epoch {epoch} | Loss: {train_loss}')
        records.append((epoch, train_loss))
        # Build from plain rows rather than pd.concat onto an empty frame:
        # avoids pandas' empty-entry FutureWarning and keeps Epoch an int column.
        summary = pd.DataFrame(records, columns=columns)
        summary.to_csv(summary_path, index=False)
    return summary