From d770a45141662a7768da972f78db32d84391fd20 Mon Sep 17 00:00:00 2001
From: Mihai Neagu
Date: Fri, 17 Nov 2023 01:14:20 +0200
Subject: [PATCH] [Mihai Neagu] Write predict.py script

---
 model_utils.py | 109 ++++++++++++++++++++++++++++++++++++++++++++++++--
 predict.py     |  62 ++++++++++++++++++++++++++++
 2 files changed, 167 insertions(+), 4 deletions(-)

diff --git a/model_utils.py b/model_utils.py
index f52c7c8f..124312bd 100644
--- a/model_utils.py
+++ b/model_utils.py
@@ -3,6 +3,10 @@ from torchvision import models
 
 import json
 
+import numpy as np
+
+from PIL import Image, ImageFile
+
 import data_loaders
 import pathlib
 
@@ -144,11 +148,11 @@ def save_model(model, save_dir):
     torch.save(checkpoint, f'{save_dir}/checkpoint.pth')
 
 
-def load_model(save_dir):
-    if not pathlib.Path(save_dir).exists():
-        raise FileNotFoundError(f'Save directory {save_dir} not found.')
+def load_model(checkpoint_path):
+    if not pathlib.Path(checkpoint_path).exists():
+        raise FileNotFoundError(f'Checkpoint file {checkpoint_path} not found.')
 
-    checkpoint = torch.load(f'{save_dir}/checkpoint.pth')
+    checkpoint = torch.load(checkpoint_path)
 
     device = get_device(checkpoint['gpu'])
 
@@ -168,3 +172,100 @@ def load_model(save_dir):
     model.eval()
 
     return model
+
+
+def predict(model, path_to_image, gpu, top_k, category_names=None):
+    ''' Predict the class (or classes) of an image using a trained deep learning model.
+
+        Prints the most likely flower name and the top_k class names, and
+        returns them so that callers can reuse the result.
+
+        model: trained network exposing a class_to_idx mapping
+        path_to_image: path of the image file to classify
+        gpu: forwarded to get_device() to select the compute device
+        top_k: number of most likely classes to report
+        category_names: optional path to a JSON file mapping category to
+            name; when None the mapping comes from get_cat_to_name()
+
+        Returns a (top_labels, probabilities) tuple, probabilities in percent.
+        Raises FileNotFoundError if path_to_image does not exist.
+    '''
+    if not pathlib.Path(path_to_image).exists():
+        raise FileNotFoundError(f'Image {path_to_image} not found')
+
+    if category_names is None:
+        cat_to_name = get_cat_to_name()
+    else:
+        # Honour the mapping file selected by the caller.
+        with open(category_names, 'r') as f:
+            cat_to_name = json.load(f)
+
+    device = get_device(gpu)
+
+    # The model has to live on the same device as the input tensor.
+    model.to(device)
+    model.eval()
+
+    image = process_image(path_to_image).to(device=device).float().unsqueeze(0)
+
+    with torch.no_grad():
+        # exp() assumes the network outputs log-probabilities -- TODO confirm
+        ps = torch.exp(model(image))
+
+    # topk() raises when asked for more entries than there are classes.
+    top_k = min(top_k, ps.shape[1])
+    probs, classes = ps.topk(top_k, dim=1)
+
+    idx_to_class = {v: k for k, v in model.class_to_idx.items()}
+
+    top_labels = [cat_to_name[idx_to_class[idx]] for idx in classes.cpu().numpy().tolist()[0]]
+
+    probabilities = probs.cpu().numpy()[0] * 100
+
+    print(f'Flower name: {top_labels[0]} with a probability of {probabilities[0]}')
+    print(f'Top {top_k} most likely classes: ', top_labels)
+
+    return top_labels, probabilities
+
+
+def crop_center(img, size=(224, 224)):
+    ''' Return the centered size[0] x size[1] region of a PIL image. '''
+    width, height = img.size
+    # Integer offsets keep the box exactly `size` pixels for odd dimensions.
+    left = (width - size[0]) // 2
+    top = (height - size[1]) // 2
+
+    return img.crop((left, top, left + size[0], top + size[1]))
+
+
+def process_image(image):
+    ''' Scales, crops, and normalizes the image stored at path `image` for a
+        PyTorch model, returns a torch tensor in (channels, height, width) order
+    '''
+    with Image.open(image) as source:
+        # Force three channels so grayscale, palette and RGBA files match
+        # the per-channel statistics used below.
+        pil_image = source.convert('RGB')
+
+    # Scale the shortest side to 256. thumbnail((256, 256)) bounds the
+    # longest side instead, which leaves non-square images smaller than the
+    # 224 pixel crop box.
+    width, height = pil_image.size
+    scale = 256 / min(width, height)
+    pil_image = pil_image.resize(
+        (max(256, round(width * scale)), max(256, round(height * scale)))
+    )
+
+    pil_image = crop_center(pil_image)
+
+    np_image = np.array(pil_image)
+    np_image = np_image / 255
+
+    means = np.array([0.485, 0.456, 0.406])
+    deviations = np.array([0.229, 0.224, 0.225])
+
+    np_image = (np_image - means) / deviations
+
+    np_image = np_image.transpose((2, 0, 1))
+
+    return torch.tensor(np_image)
diff --git a/predict.py b/predict.py
index e69de29b..9d5f23f3 100644
--- a/predict.py
+++ b/predict.py
@@ -0,0 +1,62 @@
+import argparse
+import model_utils
+
+
+def main():
+    ''' Parse the command line arguments and print the predicted classes. '''
+    parser = argparse.ArgumentParser(
+        description='Use this script to predict an image class using a saved model',
+        add_help=True
+    )
+
+    parser.add_argument(
+        '--top_k',
+        default=3,
+        action='store',
+        type=int,
+        help="Top K most likely classes"
+    )
+
+    parser.add_argument(
+        '--category_names',
+        default='cat_to_name.json',
+        action='store',
+        type=str,
+        help="Path to a JSON file containing the mapping between numeric category and class names"
+    )
+
+    # Off by default: with default=True the flag could never be disabled.
+    parser.add_argument(
+        '--gpu',
+        default=False,
+        action='store_true',
+        help="If passed, will use the GPU"
+    )
+
+    parser.add_argument(
+        'path_to_image',
+        action='store',
+        type=str,
+        help="Path to an image for which to predict class"
+    )
+
+    parser.add_argument(
+        'checkpoint',
+        action='store',
+        type=str,
+        help="Path to a checkpoint file containing a trained model"
+    )
+
+    results = parser.parse_args()
+
+    model_utils.predict(
+        model=model_utils.load_model(results.checkpoint),
+        path_to_image=results.path_to_image,
+        gpu=results.gpu,
+        top_k=results.top_k,
+        category_names=results.category_names
+    )
+
+
+if __name__ == '__main__':
+    main()