Skip to content

Commit

Permalink
[Mihai Neagu] Write train.py script
Browse files Browse the repository at this point in the history
  • Loading branch information
MihaiNeagu committed Nov 16, 2023
1 parent f6f86f1 commit 3d308f9
Show file tree
Hide file tree
Showing 6 changed files with 305 additions and 0 deletions.
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
/.idea
/playground.py

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
Expand Down
Binary file removed checkpoint.pth
Binary file not shown.
51 changes: 51 additions & 0 deletions data_loaders.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
from torchvision import datasets, transforms
import torch
import pathlib


def get_data_loaders(data_dir):
if not pathlib.Path(data_dir).exists():
raise FileNotFoundError(f'Data directory not found {data_dir}')

train_dir = data_dir + '/train'
valid_dir = data_dir + '/valid'
test_dir = data_dir + '/test'

# Define transforms for the training, validation, and testing sets
train_transforms = transforms.Compose([transforms.RandomRotation(30),
transforms.RandomResizedCrop(224),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406],
[0.229, 0.224, 0.225])])
test_transforms = transforms.Compose([transforms.Resize(255),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406],
[0.229, 0.224, 0.225])])

validation_transforms = transforms.Compose([transforms.RandomRotation(30),
transforms.RandomResizedCrop(224),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406],
[0.229, 0.224, 0.225])])

# Load the datasets with ImageFolder
image_datasets = {
'train': datasets.ImageFolder(train_dir, transform=train_transforms),
'test': datasets.ImageFolder(test_dir, transform=test_transforms),
'validation': datasets.ImageFolder(valid_dir, transform=validation_transforms)
}

# Using the image datasets and the trainforms, define the dataloaders
data_loaders = {
'train': torch.utils.data.DataLoader(image_datasets['train'], batch_size=64, shuffle=True),
'test': torch.utils.data.DataLoader(image_datasets['test'], batch_size=64),
'validation': torch.utils.data.DataLoader(image_datasets['validation'], batch_size=64, shuffle=True)
}

return {
'datasets': image_datasets,
'data_loaders': data_loaders
}
2 changes: 2 additions & 0 deletions model_checkpoints/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
*
!.gitignore
170 changes: 170 additions & 0 deletions model_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,170 @@
import torch
from torch import nn, optim
from torchvision import models
import json

import data_loaders

import pathlib


def get_cat_to_name():
with open('cat_to_name.json', 'r') as f:
cat_to_name = json.load(f)

return cat_to_name


def train_model(arch, learning_rate, data_dir, hidden_units, epochs, save_dir=None, gpu=True):
'''
Trains a model given the provided hyperparameters
:param arch:
:return: trained model
'''

device = get_device(gpu)

if save_dir is not None and not pathlib.Path(save_dir).exists():
raise FileNotFoundError(f'Save directory {save_dir} not found.')

if arch not in models.list_models():
raise RuntimeError(f'No such model {arch}. Available models: {models.list_models()}')

data = data_loaders.get_data_loaders(data_dir)

dataloaders = data['data_loaders']
datasets = data['datasets']


model = models.get_model(arch, weights=models.get_model_weights(arch))

for param in model.parameters():
param.requires_grad = False

classifier = get_classifier(hidden_units)

model.classifier = classifier
criterion = nn.NLLLoss()
optimizer = optim.Adam(model.classifier.parameters(), lr=learning_rate)
model.to(device)

steps = 0
running_loss = 0
print_every = 5

for epoch in range(epochs):
for inputs, labels in dataloaders['train']:
steps += 1

inputs, labels = inputs.to(device), labels.to(device)

optimizer.zero_grad()

logps = model.forward(inputs)
loss = criterion(logps, labels)
loss.backward()
optimizer.step()

running_loss += loss.item()

if steps % print_every == 0:
test_loss = 0
accuracy = 0
model.eval()
with torch.no_grad():
for inputs, labels in dataloaders['test']:
inputs, labels = inputs.to(device), labels.to(device)
logps = model.forward(inputs)
batch_loss = criterion(logps, labels)

test_loss += batch_loss.item()

# Calculate accuracy
ps = torch.exp(logps)
top_p, top_class = ps.topk(1, dim=1)
equals = top_class == labels.view(*top_class.shape)
accuracy += torch.mean(equals.type(torch.FloatTensor)).item()

print(f"Epoch {epoch + 1}/{epochs}.. "
f"Train loss: {running_loss / print_every:.3f}.. "
f"Test loss: {test_loss / len(dataloaders['test']):.3f}.. "
f"Test accuracy: {accuracy / len(dataloaders['test']):.3f}")
running_loss = 0
model.train()

model.arch = arch
model.hidden_units = hidden_units
model.optimizer = optimizer
model.epochs = epochs
model.gpu = gpu
model.class_to_idx = datasets['train'].class_to_idx

if save_dir is not None:
save_model(model, save_dir)

return model


def get_classifier(hidden_units):
cat_to_name = get_cat_to_name()
classifier = nn.Sequential(nn.Linear(1024, hidden_units),
nn.ReLU(),
nn.Dropout(0.2),
nn.Linear(hidden_units, len(cat_to_name)),
nn.LogSoftmax(dim=1))
return classifier


def get_device(gpu):
if gpu and torch.cuda.is_available():
device = 'cuda'
else:
if gpu:
print(f'Using cpu since cuda is not available')

device = 'cpu'

return device


def save_model(model, save_dir):
if not pathlib.Path(save_dir).exists():
raise FileNotFoundError(f'Save directory {save_dir} not found.')

checkpoint = {
'arch': model.arch,
'hidden_units': model.hidden_units,
'model_state_dict': model.state_dict(),
'optimizer_state_dict': model.optimizer.state_dict(),
'epochs': model.epochs,
'gpu': model.gpu,
'class_to_idx': model.class_to_idx
}

torch.save(checkpoint, f'{save_dir}/checkpoint.pth')


def load_model(save_dir):
if not pathlib.Path(save_dir).exists():
raise FileNotFoundError(f'Save directory {save_dir} not found.')

checkpoint = torch.load(f'{save_dir}/checkpoint.pth')

device = get_device(checkpoint['gpu'])

model = models.get_model(checkpoint['arch'], weights=models.get_model_weights(checkpoint['arch']))
model.classifier = get_classifier(checkpoint['hidden_units'])
model.load_state_dict(checkpoint['model_state_dict'])
model.optimizer = optim.Adam(model.classifier.parameters(), lr=0.003)
model.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

model.arch = checkpoint['arch']
model.hidden_units = checkpoint['hidden_units']
model.epochs = checkpoint['epochs']
model.gpu = checkpoint['gpu']
model.class_to_idx = checkpoint['class_to_idx']

model.to(device)
model.eval()

return model
79 changes: 79 additions & 0 deletions train.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
import argparse
import model_utils


def main():
parser = argparse.ArgumentParser(
description='Use this script to train a model used for classifying flowers',
add_help=True
)

parser.add_argument(
'--save_dir',
default='./model_checkpoints',
action='store',
type=str,
help="Directory path for saving checkpoints"
)

parser.add_argument(
'--arch',
action='store',
default='densenet121',
type=str,
help="The architecture for the pretrained network on top of which to train classifier"
)

parser.add_argument(
'--learning_rate',
action='store',
default=0.003,
type=float,
help="The learning rate for backpropagation"
)

parser.add_argument(
'--hidden_units',
default=256,
action='store',
type=int,
help="The hidden units of the classifier"
)

parser.add_argument(
'--epochs',
default=2,
action='store',
type=int,
help="How many epochs should the classifier iterate"
)

parser.add_argument(
'--gpu',
default=True,
action='store_true',
help="If passed, will use the GPU"
)

parser.add_argument(
'data_dir',
action='store',
type=str,
help="Data directory path"
)

results = parser.parse_args()

model_utils.train_model(
arch=results.arch,
data_dir=results.data_dir,
save_dir=results.save_dir,
learning_rate=results.learning_rate,
hidden_units=results.hidden_units,
epochs=results.epochs,
gpu=results.gpu
)


if __name__ == '__main__':
main()

0 comments on commit 3d308f9

Please sign in to comment.