-
Notifications
You must be signed in to change notification settings - Fork 0
/
app.py
47 lines (35 loc) · 1.23 KB
/
app.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
import io
import keras
import numpy as np
import streamlit as st
from PIL import Image
from tensorflow.keras.models import load_model
# Page header rendered at the top of the Streamlit app.
st.title("plant seedling classification")

# Size every uploaded image is resized to before prediction
# (passed to load_img as target_size by get_image()).
img_size = [224, 224]

# Location of the saved Keras model; test() loads it via get_model().
model_path = 'model.h5'

# Upload widget; only jpg and png files are accepted.
file = st.file_uploader('please upload an image', type=['jpg', 'png'])

# Class names indexed by the argmax of the model output.
# NOTE(review): order must match the label order used at training time — confirm.
class_label = [
    'Black-grass',
    'Charlock',
    'Cleavers',
    'Common Chickweed',
    'Common wheat',
    'Fat Hen',
    'Loose Silky-bent',
    'Maize',
    'Scentless Mayweed',
    'Shepherds Purse',
    'Small-flowered Cranesbill',
    'Sugar beet',
]
def get_model(path):
    """Load the saved Keras model stored at *path* and return it.

    NOTE(review): this reads the model from disk on every call, and test()
    calls it on every Predict click; consider caching (e.g. Streamlit's
    resource cache) if the installed Streamlit version supports it.
    """
    loaded = load_model(path)
    return loaded
def get_image(img_path, img_size):
    """Load an image and convert it into a normalised one-image batch.

    Args:
        img_path: path or file-like object accepted by keras ``load_img``
            (here, the Streamlit upload).
        img_size: target size handed to ``load_img``.

    Returns:
        A NumPy array with a leading batch axis of length 1 whose pixel
        values have been divided by 255.
    """
    pixels = keras.preprocessing.image.img_to_array(
        keras.preprocessing.image.load_img(img_path, target_size=img_size)
    )
    # Prepend a batch axis of length 1 so the array can go to model.predict().
    batch = pixels[np.newaxis, ...]
    return batch / 255.0
def test():
    """Classify the uploaded image and show the result in the app.

    Reads the module-level ``file`` upload together with ``model_path``,
    ``img_size`` and ``class_label``. Writes the predicted class name and
    then displays the preprocessed image. If no image has been uploaded yet,
    shows a warning and returns instead of crashing inside ``load_img``.
    """
    if file is None:
        # Nothing uploaded yet: load_img(None) would raise, so tell the user.
        st.warning('please upload an image first')
        return

    model = get_model(model_path)
    image = get_image(file, img_size)
    pred = model.predict(image)

    # get_image() builds a batch holding a single image, so take the argmax
    # of that single row. Indexing [0] avoids int() on an ndim-1 array,
    # which NumPy >= 1.25 deprecates.
    pred_index = int(np.argmax(pred, axis=1)[0])
    st.write('the predicted class of image is : {}'.format(class_label[pred_index]))
    st.image(image, use_column_width=True)
# Run the classifier only on the rerun triggered by pressing Predict.
if st.button('Predict'):
    test()