# 10_gradio_app.py
# Gradio demo: image classification with an ONNX model.
import time
from urllib.request import urlopen
import gradio as gr
import numpy as np
import onnxruntime as ort
import torch
from PIL import Image
from imagenet_classes import IMAGENET2012_CLASSES
def read_image(image: Image.Image):
    """Turn a PIL image into a float32, channel-first batch of one.

    Args:
        image: Source picture; it is converted to RGB before anything else.

    Returns:
        A float32 NumPy array laid out as (1, 3, height, width). Pixel
        values are passed through as they come out of the RGB image; no
        scaling or normalisation is applied here.
    """
    rgb_pixels = np.array(image.convert("RGB"))
    # HWC -> CHW, matching the channel-first layout fed to the model.
    channels_first = rgb_pixels.astype(np.float32).transpose(2, 0, 1)
    # Prepend the batch axis.
    return channels_first[np.newaxis, ...]
# Load the ONNX model once at import time so every call to predict() reuses
# the same session; inference is restricted to the CPU execution provider.
providers = ["CPUExecutionProvider"]
session = ort.InferenceSession("merged_model_compose.onnx", providers=providers)
# Names of the model's first input and first output, used by predict() to
# feed the image tensor and fetch the result via session.run().
input_name = session.get_inputs()[0].name
output_name = session.get_outputs()[0].name
# The label list never changes while the process runs, so build it once
# instead of re-materialising it on every request.
_IMAGENET_LABELS = list(IMAGENET2012_CLASSES.values())


def predict(img):
    """Classify an image and return its five most probable labels.

    Args:
        img: Image handed over by the Gradio image input, or ``None`` when
            the user submitted the form without providing one.

    Returns:
        A dict mapping class name to probability (a Python float) for the
        five highest-scoring classes.

    Raises:
        gr.Error: If ``img`` is ``None``; Gradio shows the message to the
            user instead of an opaque internal error.
    """
    if img is None:
        raise gr.Error("Please provide an image to classify.")

    raw_output = session.run([output_name], {input_name: read_image(img)})[0]
    # Softmax over dim 1 — assumes the model emits (batch, num_classes)
    # scores; TODO confirm against the exported graph.
    probabilities = torch.from_numpy(raw_output).softmax(dim=1)
    top5_probabilities, top5_class_indices = torch.topk(probabilities, k=5)

    # Only one image is submitted per call, so read row 0 of the batch.
    # tolist() yields plain ints/floats, ready for indexing and for Gradio.
    return {
        _IMAGENET_LABELS[class_index]: probability
        for class_index, probability in zip(
            top5_class_indices[0].tolist(), top5_probabilities[0].tolist()
        )
    }
# Sample picture offered as a ready-made example in the interface.
example_image = "beignets-task-guide.png"

# Web UI: a PIL image goes in, predict() runs, and the top five labels are
# displayed together with their scores.
iface = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs=gr.Label(num_top_classes=5),
    title="Image Classification with ONNX using EVA02 model",
    description="Blog post: https://dicksonneoh.com/portfolio/supercharge_your_pytorch_image_models/",
    examples=[example_image],
)

# Start the web server only when the file is run as a script, not on import.
if __name__ == "__main__":
    iface.launch()