-
Notifications
You must be signed in to change notification settings - Fork 0
/
train_model.py
117 lines (100 loc) · 4.67 KB
/
train_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
import argparse
import os
import time
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
import timm
from torchvision import datasets, transforms
from datetime import datetime
from tqdm import tqdm
def parse_args(argv=None) -> argparse.Namespace:
    """Parse the command-line options of the training script.

    Args:
        argv: Optional list of argument strings to parse instead of the
            process command line. When ``None`` (the default) argparse
            falls back to ``sys.argv[1:]``, so existing ``parse_args()``
            callers behave exactly as before.

    Returns:
        An ``argparse.Namespace`` with the attributes ``data_dir``,
        ``epochs``, ``batch_size``, ``learning_rate``, ``workers`` and
        ``model_name``.
    """
    parser = argparse.ArgumentParser(description="Train a model on ImageNet using timm")
    parser.add_argument('--data-dir', type=str, default='/oscar/data/tserre/data/ImageNet/ILSVRC/Data/CLS-LOC',
                        help='Path to ImageNet dataset')
    parser.add_argument('--epochs', type=int, default=90, help='Number of training epochs')
    parser.add_argument('--batch-size', type=int, default=256, help='Batch size for training')
    parser.add_argument('--learning-rate', type=float, default=0.1, help='Initial learning rate')
    parser.add_argument('--workers', type=int, default=4, help='Number of data loading workers')
    parser.add_argument('--model-name', type=str, default='resnet34', help='Model name from timm')
    # Forwarding argv lets tests and other scripts drive the parser without
    # touching sys.argv.
    return parser.parse_args(argv)
def _run_epoch(model, loader, criterion, device, optimizer=None, desc=None):
    """Run one full pass over ``loader`` and return ``(mean_loss, accuracy)``.

    Args:
        model: Network to train or evaluate.
        loader: Iterable yielding ``(images, labels)`` batches.
        criterion: Loss callable applied to ``(outputs, labels)``.
        device: Device the batches are moved to before the forward pass.
        optimizer: When given, the pass is a training pass (gradients are
            computed and the optimizer is stepped per batch). When ``None``
            the model is put in eval mode and gradients are disabled.
        desc: Label shown on the tqdm progress bar.

    Returns:
        Tuple of the per-sample average loss and the fraction of samples
        whose top-1 prediction equals the label.
    """
    training = optimizer is not None
    model.train(training)

    loss_sum = 0.0
    num_correct = 0
    num_seen = 0
    with torch.set_grad_enabled(training):
        for images, labels in tqdm(loader, desc=desc):
            # The loaders are built with pin_memory=True, so the copies can
            # be issued asynchronously; on CPU this flag is a no-op.
            images = images.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)

            if training:
                optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            if training:
                loss.backward()
                optimizer.step()

            # criterion returns a batch mean; weight it by the batch size so
            # a smaller final batch does not skew the epoch average.
            loss_sum += loss.item() * images.size(0)
            _, predicted = outputs.max(1)
            num_seen += labels.size(0)
            num_correct += predicted.eq(labels).sum().item()

    return loss_sum / num_seen, num_correct / num_seen


def main():
    """Train a timm model on an ImageFolder dataset and write checkpoints.

    Reads its configuration from ``parse_args()``, creates a timestamped
    directory under ``checkpoints/``, then for every epoch runs a training
    pass and a validation pass, steps the LR scheduler and prints the
    metrics. The model ``state_dict`` is saved as ``model_epoch_<n>.pth``
    after every 5th epoch, after the final epoch, and whenever validation
    accuracy improves on the best value seen so far.
    """
    args = parse_args()
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    checkpoint_dir = os.path.join('checkpoints', f"{args.model_name}_{timestamp}")
    os.makedirs(checkpoint_dir, exist_ok=True)
    print(checkpoint_dir)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Standard ImageNet transforms
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    train_transforms = transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    val_transforms = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        normalize,
    ])

    # ImageNet dataset loaders
    train_dataset = datasets.ImageFolder(os.path.join(args.data_dir, 'train'), transform=train_transforms)
    val_dataset = datasets.ImageFolder(os.path.join(args.data_dir, 'val'), transform=val_transforms)
    train_loader = data.DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True,
                                   num_workers=args.workers, pin_memory=True)
    val_loader = data.DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False,
                                 num_workers=args.workers, pin_memory=True)

    # Model, criterion, and optimizer
    model = timm.create_model(args.model_name, pretrained=False, num_classes=1000)
    model = model.to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=args.learning_rate, momentum=0.9, weight_decay=1e-4)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)

    best_acc = 0.0
    for epoch in range(args.epochs):
        train_loss, train_acc = _run_epoch(
            model, train_loader, criterion, device,
            optimizer=optimizer,
            desc=f"Training Epoch {epoch+1}/{args.epochs}",
        )
        val_loss, val_acc = _run_epoch(model, val_loader, criterion, device, desc="Validating")

        scheduler.step()

        print(f"Epoch [{epoch+1}/{args.epochs}] - Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}, "
              f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}")

        # Save a checkpoint every 5 epochs (counted 1-based, so epochs 5, 10,
        # ...), on the last epoch so the final weights are never lost, and
        # whenever a new best validation accuracy is reached.
        is_best = val_acc > best_acc
        is_periodic = (epoch + 1) % 5 == 0
        is_last = epoch + 1 == args.epochs
        if is_periodic or is_last or is_best:
            checkpoint_path = os.path.join(checkpoint_dir, f"model_epoch_{epoch+1}.pth")
            torch.save(model.state_dict(), checkpoint_path)
        if is_best:
            best_acc = val_acc
# Run training only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()