-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathTest_main.py
79 lines (58 loc) · 2.72 KB
/
Test_main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
import os
import torch
import cv2
from utils.feature_extractor import featureExtractor
from utils.data_loader import TestDataset
from torch.utils.data import Dataset, DataLoader
import numpy as np
# Inference device: the GPU when PyTorch can see one, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
def run_testing_on_dataset(trained_model, dataset_dir, GT_blurry, threshold=0.5):
    """Return the fraction of images in *dataset_dir* whose blur prediction matches *GT_blurry*.

    Every entry directly inside ``dataset_dir`` is read as a grayscale image
    and classified with ``is_image_blurry``. Entries OpenCV cannot decode are
    skipped and do not count towards the accuracy.

    Args:
        trained_model: model forwarded unchanged to ``is_image_blurry``.
        dataset_dir: directory containing the test images.
        GT_blurry: ground-truth label shared by every image in the directory
            (True for a blurry folder, False for a sharp one).
        threshold: decision threshold forwarded to ``is_image_blurry``
            (defaults to the previously hard-coded 0.5).

    Returns:
        Accuracy in [0, 1] over the decodable images, or ``float("nan")``
        when the directory contains no readable image.
    """
    # Sort so progress output is reproducible; os.listdir order is arbitrary.
    img_list = sorted(os.listdir(dataset_dir))
    correct_prediction_count = 0
    evaluated_count = 0
    for ind, image_name in enumerate(img_list):
        print("Blurry Image Prediction: %d / %d images processed.." % (ind, len(img_list)))
        # Flag 0 asks OpenCV for a single-channel grayscale image.
        img = cv2.imread(os.path.join(dataset_dir, image_name), 0)
        if img is None:
            # cv2.imread returns None instead of raising for anything it cannot
            # decode (sub-directories, stray files such as .DS_Store).
            print("Skipping unreadable file: %s" % image_name)
            continue
        evaluated_count += 1
        prediction = is_image_blurry(trained_model, img, threshold=threshold)
        if prediction == GT_blurry:
            correct_prediction_count += 1
    if evaluated_count == 0:
        # No readable image: accuracy is undefined, avoid ZeroDivisionError.
        return float("nan")
    return correct_prediction_count / evaluated_count
def is_image_blurry(trained_model, img, threshold=0.5):
    """Classify a single grayscale image as blurry (True) or sharp (False).

    Blur features are extracted from ``img`` with ``featureExtractor``. Each
    extracted feature is classified by ``trained_model``, and the predicted
    class indices are averaged. The image is reported as blurry when that
    mean is below ``threshold``.

    Args:
        trained_model: callable returning per-class scores (class along
            dim 1); if it is a ``torch.nn.Module`` it is put into eval mode.
        img: image array as returned by ``cv2.imread`` (must not be None).
        threshold: cut-off applied to the mean predicted class index.

    Returns:
        ``True`` if the image is judged blurry (including when no features
        could be extracted), otherwise ``False``.
    """
    feature_extractor = featureExtractor()
    # Resize the image by the downsampling factor.
    # NOTE(review): the image's own height/width are passed here; confirm how
    # resize_image derives the downsampling from these arguments.
    feature_extractor.resize_image(img, np.shape(img)[0], np.shape(img)[1])
    # compute the image ROI using local entropy filter
    feature_extractor.compute_roi()
    # extract the blur features using DCT transform coefficients
    extracted_features = np.array(feature_extractor.extract_feature())
    if len(extracted_features) == 0:
        # Nothing to classify: this script counts such images as blurry.
        return True

    test_data_loader = DataLoader(TestDataset(extracted_features), batch_size=1, shuffle=False)
    if isinstance(trained_model, torch.nn.Module):
        # nn.Module has no .test(); eval() is the inference-mode switch
        # (disables dropout, uses running BatchNorm statistics).
        trained_model.eval()

    accumulator = []
    # Inference only: skip autograd graph construction to save memory/time.
    with torch.no_grad():
        for input_data in test_data_loader:
            x = input_data.to(device).float()
            output = trained_model(x)
            # Index of the highest-scoring class along dim 1.
            _, predicted_label = torch.max(output, 1)
            accumulator.append(predicted_label.item())
    # A low mean class index means blurry. Presumably class 0 = blurry and
    # 1 = sharp; confirm against the training labels.
    return bool(np.mean(accumulator) < threshold)
if __name__ == '__main__':
    # NOTE: torch.load unpickles the file, which can execute arbitrary code;
    # only load checkpoints from a trusted source.
    # map_location lets a GPU-saved checkpoint load on a CPU-only machine and
    # puts the weights on the same device the inputs are sent to.
    # NOTE(review): on PyTorch >= 2.6 torch.load defaults to weights_only=True;
    # if this checkpoint pickles a full nn.Module, weights_only=False is needed.
    checkpoint = torch.load('./trained_model/trained_model', map_location=device)
    # NOTE(review): the key name suggests a state_dict, but the value is called
    # directly as a model in is_image_blurry; confirm it holds a full module.
    trained_model = checkpoint['model_state']

    # (folder label used in the report, dataset directory, ground truth "is blurry")
    test_sets = [
        ('blurry', './dataset/defocused_blurred/', True),
        ('sharp', './dataset/sharp/', False),
        ('motion blur', './dataset/motion_blurred/', True),
    ]
    accuracies = [
        (label, run_testing_on_dataset(trained_model, dataset_dir, GT_blurry=gt_blurry))
        for label, dataset_dir, gt_blurry in test_sets
    ]

    print("========================================")
    for label, accuracy in accuracies:
        print('Test accuracy on %s folder = ' % label)
        print(accuracy)