Skip to content

Commit

Permalink
[Mihai Neagu] Write predict.py script
Browse files Browse the repository at this point in the history
  • Loading branch information
MihaiNeagu committed Nov 16, 2023
1 parent 3d308f9 commit d770a45
Show file tree
Hide file tree
Showing 2 changed files with 131 additions and 4 deletions.
76 changes: 72 additions & 4 deletions model_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,10 @@
from torchvision import models
import json

import numpy as np

from PIL import Image, ImageFile

import data_loaders

import pathlib
Expand Down Expand Up @@ -144,11 +148,11 @@ def save_model(model, save_dir):
torch.save(checkpoint, f'{save_dir}/checkpoint.pth')


def load_model(save_dir):
if not pathlib.Path(save_dir).exists():
raise FileNotFoundError(f'Save directory {save_dir} not found.')
def load_model(checkpoint_path):
if not pathlib.Path(checkpoint_path).exists():
raise FileNotFoundError(f'Save directory {checkpoint_path} not found.')

checkpoint = torch.load(f'{save_dir}/checkpoint.pth')
checkpoint = torch.load(checkpoint_path)

device = get_device(checkpoint['gpu'])

Expand All @@ -168,3 +172,67 @@ def load_model(save_dir):
model.eval()

return model


def predict(model, path_to_image, gpu, top_k):
''' Predict the class (or classes) of an image using a trained deep learning model.
'''
if not pathlib.Path(path_to_image).exists():
raise FileNotFoundError(f'Image {path_to_image} not found')

cat_to_name = get_cat_to_name()

device = get_device(gpu)

model.eval()

image = process_image(path_to_image).to(device=device).float().unsqueeze(0)

with torch.no_grad():
ps = torch.exp(model.forward(image))

probs, classes = ps.topk(top_k, dim=1)

idx_to_class = {v: k for k, v in model.class_to_idx.items()}

top_labels = [cat_to_name[idx_to_class[idx]] for idx in classes.cpu().numpy().tolist()[0]]

probabilities = probs.cpu().numpy()[0] * 100

print(f'Flower name: {top_labels[0]} with a probability of {probabilities[0]}')
print(f'Top {top_k} most likely classes: ', top_labels)



def crop_center(img, size=(224, 224)):
width, height = img.size
left = (width - size[0]) / 2
top = (height - size[1]) / 2
right = (width + size[0]) / 2
bottom = (height + size[1]) / 2

cropped_img = img.crop((left, top, right, bottom))

return cropped_img


def process_image(image):
''' Scales, crops, and normalizes a PIL image for a PyTorch model,
returns an Numpy array
'''
with Image.open(image) as pil_image:
pil_image.thumbnail((256, 256))

pil_image = crop_center(pil_image)

np_image = np.array(pil_image)
np_image = np_image / 255

means = np.array([0.485, 0.456, 0.406])
deviations = np.array([0.229, 0.224, 0.225])

np_image = (np_image - means) / deviations

np_image = np_image.transpose((2, 0, 1))

return torch.tensor(np_image)
59 changes: 59 additions & 0 deletions predict.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
import argparse
import model_utils


def main():
parser = argparse.ArgumentParser(
description='Use this script to predict an image class using a saved model',
add_help=True
)

parser.add_argument(
'--top_k',
default=3,
action='store',
type=int,
help="Top K most likely classes"
)

parser.add_argument(
'--category_names',
default='cat_to_name.json',
action='store',
type=str,
help="Path to a JSON file containing the mapping between numeric category and class names"
)

parser.add_argument(
'--gpu',
default=True,
action='store_true',
help="If passed, will use the GPU"
)

parser.add_argument(
'path_to_image',
action='store',
type=str,
help="Path to an image for which to predict class"
)

parser.add_argument(
'checkpoint',
action='store',
type=str,
help="Path to a checkpoint file containing a trained model"
)

results = parser.parse_args()

model_utils.predict(
model=model_utils.load_model(results.checkpoint),
path_to_image=results.path_to_image,
gpu=results.gpu,
top_k=results.top_k
)


if __name__ == '__main__':
main()

0 comments on commit d770a45

Please sign in to comment.