-
Notifications
You must be signed in to change notification settings - Fork 15
/
test.py
69 lines (52 loc) · 2.26 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
from __future__ import print_function

import time
import os
import sys
import logging

# NOTE(review): `js` is not a standard-library module and is not used in this
# file; it looks like a truncated `import json` -- confirm and correct.
import js
import tensorflow as tf
import numpy as np
import cv2

from data.data_loader import *
from model.cgan_model import cgan
def build_model(args):
    """Create a TensorFlow session and build the cGAN model on top of it.

    Args:
        args: Parsed option object; handed to ``cgan`` unchanged.

    Returns:
        The ``cgan`` instance, after ``build_model()`` has been called on it.
    """
    sess = tf.Session()
    # Construct the model the same way test() does. The previous bare
    # ``cgan()`` call passed neither the session nor the options, and the
    # resulting objects were dropped without being returned.
    model = cgan(sess, args)
    model.build_model()
    return model
def test(args):
    """Deblur every test image with the trained generator and save the results.

    Builds the cGAN model, loads weights from ``args.checkpoint_dir``, runs the
    generator on each path returned by ``read_data_path_custom`` and writes
    each output to ``args.result_dir`` as ``sharp_<original file name>``.

    Args:
        args: Parsed option object. Reads ``checkpoint_dir``,
            ``data_path_test``, ``img_type``, ``img_h``, ``img_w``,
            ``result_dir`` and, when present, ``resize_or_crop``.
    """
    # The context manager closes the session even if inference fails part-way.
    with tf.Session() as sess:
        model = cgan(sess, args)
        model.build_model()
        model.sess.run(tf.global_variables_initializer())
        model.load_weights(args.checkpoint_dir)

        # The command-line option is spelled ``img_type`` (was ``imge_type``).
        dataset = read_data_path_custom(args.data_path_test, image_type=args.img_type)
        image_size = (args.img_h, args.img_w)
        # ``resize_or_crop`` may be missing from the option object; fall back
        # to None instead of raising AttributeError.
        # NOTE(review): assumes read_image accepts None for this parameter -- confirm.
        resize_or_crop = getattr(args, 'resize_or_crop', None)

        if not os.path.exists(args.result_dir):
            # makedirs also creates missing parent directories.
            os.makedirs(args.result_dir)

        for data in dataset:
            logging.info("%s image deblur starts", data)
            blur_img = read_image(data, resize_or_crop=resize_or_crop, image_size=image_size)
            logging.debug("%s image was loaded", data)

            feed_dict_G = {model.input['blur_img']: blur_img}
            G_out = model.G_output(feed_dict=feed_dict_G)
            logging.debug("The image was converted")

            # basename handles the platform's path separator, unlike split('/').
            out_path = os.path.join(args.result_dir, 'sharp_' + os.path.basename(data))
            # Presumably rescales generator output from [-1, 1] to [0, 255];
            # the output range is not visible here -- confirm.
            cv2.imwrite(out_path, (G_out[0] + 1.0) / 2.0 * 255.0)
            logging.info("%s Image save was completed", data)
if __name__ == '__main__':
    # Script entry point: parse the command line, configure logging, run test().
    import argparse

    parser = argparse.ArgumentParser(description='')

    # Input / output locations.
    parser.add_argument('--data_path_test', type=str, default=None)
    parser.add_argument('--result_dir', type=str, default='./result_dir')
    parser.add_argument('--checkpoint_dir', type=str, default='./checkpoints/')
    parser.add_argument('--model_name', type=str, default='DeblurGAN.model')

    # Image options.
    parser.add_argument('--img_type', type=str, default='png')
    parser.add_argument('--img_h', type=int, default=256)
    parser.add_argument('--img_w', type=int, default=256)
    parser.add_argument('--img_c', type=int, default=3)
    # test() reads args.resize_or_crop, which was never defined here.
    # NOTE(review): the accepted values are defined by read_image -- confirm.
    parser.add_argument('--resize_or_crop', type=str, default=None)

    # Mode flags.
    parser.add_argument('--is_test', action='store_true')
    parser.add_argument('--debug', action='store_true')

    args = parser.parse_args()

    log_format = '[%(asctime)s %(levelname)s] %(message)s'
    level = logging.DEBUG if args.debug else logging.INFO
    logging.basicConfig(level=level, format=log_format, stream=sys.stderr)
    # Logger names form a dot-separated hierarchy with no wildcard support:
    # setting the level on "DeblurGAN_TEST" covers its descendants, whereas
    # "DeblurGAN_TEST.*" only addressed a child logger literally named "*".
    logging.getLogger("DeblurGAN_TEST").setLevel(level)

    test(args)