-
Notifications
You must be signed in to change notification settings - Fork 9
/
test.py
47 lines (31 loc) · 1.34 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
import glob
import os
import random

import matplotlib.pyplot as plt
import tensorflow as tf

from dehaze_net import gman_net
# Builds a network via dehaze_net.gman_net() at import time.
# NOTE(review): this module-level `net` is not used anywhere in this script —
# evaluate() receives its model as a parameter, and the call at the bottom
# passes the model loaded from 'trained_model' instead. Presumably a leftover;
# confirm gman_net() has no needed side effects before removing it.
net = gman_net()
def _load_test_image(img_path):
    """Read one JPEG file and return it as a float32 batch of one image in [0, 1].

    Landscape images (width > height) are resized to 1080x1920 (HxW) and
    portrait images (height > width) to 1920x1080; square images keep their
    original size.
    """
    img = tf.io.read_file(img_path)
    img = tf.io.decode_jpeg(img, channels=3)  # (height, width, 3), uint8
    height, width = img.shape[0], img.shape[1]
    if width > height:
        img = tf.image.resize(img, size=(1080, 1920), antialias=True)
    elif width < height:
        img = tf.image.resize(img, size=(1920, 1080), antialias=True)
    # tf.image.resize yields float32, but the square (un-resized) path is still
    # uint8; cast explicitly so every path normalises the same dtype.
    img = tf.cast(img, tf.float32) / 255.0
    # Add a leading batch axis: (H, W, 3) -> (1, H, W, 3), as the model expects.
    return tf.expand_dims(img, axis=0)


def _show_pair(hazy, dehazed):
    """Plot the hazy input and the dehazed output side by side, then free the figure."""
    fig = plt.figure(figsize=(80, 80))
    panels = [('Hazy Image', hazy), ('Dehazed Image', dehazed)]
    for position, (caption, image) in enumerate(panels, start=1):
        plt.subplot(1, 2, position)
        plt.title(caption, fontsize=65, y=1.045)
        plt.imshow(image)
        plt.axis('off')
    plt.show()
    # Without this, each test image leaves an 80x80-inch figure registered in
    # pyplot (non-inline backends keep them alive), steadily growing memory.
    plt.close(fig)


def evaluate(net, test_img_path):
    """Run *net* on every ``*.jpg`` in *test_img_path* and display each result.

    Images are visited in random order. Each one is loaded and normalised by
    ``_load_test_image``, passed through ``net(batch, training=False)``, and
    shown next to the model output.

    Args:
        net: Callable model invoked as ``net(batch, training=False)``; element
            ``[0]`` of its output is displayed as an image.
        test_img_path: Directory searched (non-recursively) for ``*.jpg`` files.

    Returns:
        None. Results are only displayed via matplotlib.
    """
    test_imgs = glob.glob(os.path.join(test_img_path, '*.jpg'))
    random.shuffle(test_imgs)
    for img_path in test_imgs:
        hazy = _load_test_image(img_path)
        dehazed = net(hazy, training=False)
        # Drop the batch axis before plotting.
        _show_pair(hazy[0], dehazed[0])
# Load the saved model without compiling it — only forward passes are needed
# here, so no optimizer/loss configuration is required — then visualise its
# output on the hazy test images.
test_net = tf.keras.models.load_model(
    'trained_model',
    compile=False,
)
evaluate(
    net=test_net,
    test_img_path='../input/hazy-test-images',
)