forked from krasserm/super-resolution
-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy path run_WDSR.py
69 lines (48 loc) · 1.89 KB
/
run_WDSR.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import os
import matplotlib.pyplot as plt
from data import DIV2K
from model.wdsr import wdsr_b
from train import WdsrTrainer
# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------
depth = 32             # number of residual blocks in the WDSR-B model
scale = 4              # super-resolution factor
downgrade = 'bicubic'  # downgrade operator passed to the DIV2K loader

# The demo at the end of this script reloads the model from `weights_file`,
# so create its directory up front, before save_weights writes the file.
weights_dir = f'weights/wdsr-b-{depth}-x{scale}'
os.makedirs(weights_dir, exist_ok=True)
weights_file = os.path.join(weights_dir, 'weights.h5')
# DIV2K training and validation subsets, both built with the configured
# scale and downgrade operator.
div2k_train = DIV2K(subset='train', scale=scale, downgrade=downgrade)
div2k_valid = DIV2K(subset='valid', scale=scale, downgrade=downgrade)

# Training pipeline: batches of 16 with random_transform enabled
# (presumably random augmentation -- see DIV2K.dataset to confirm).
train_ds = div2k_train.dataset(random_transform=True, batch_size=16)

# Validation pipeline: one image per batch, no random transforms;
# repeat_count=1 presumably means a single pass over the subset -- confirm.
valid_ds = div2k_valid.dataset(
    batch_size=1,
    repeat_count=1,
    random_transform=False,
)
# Build a WDSR-B model with attention enabled and wrap it in a trainer that
# stores its checkpoints under .ckpt/.
attention = True
trainer = WdsrTrainer(
    model=wdsr_b(scale=scale, num_res_blocks=depth, attention=attention),
    checkpoint_dir=f'.ckpt/wdsr-b-{depth}-x{scale}',
)

# Train for 300,000 steps. Every 1,000 steps the model is evaluated on the
# first 10 images of the DIV2K validation set; with save_best_only=True a
# checkpoint is kept only when the evaluation PSNR improves.
trainer.train(
    train_ds,
    valid_ds.take(10),
    steps=300_000,
    evaluate_every=1_000,
    save_best_only=True,
    model_name='wdsr_attention',
)
# Restore the checkpoint with the highest validation PSNR seen during training.
trainer.restore()

# Evaluate on the full validation set.
# `.3f` prints three decimal places; `:3f` would only set a minimum field
# width of 3 and still print six decimals.
psnr = trainer.evaluate(valid_ds)
print(f'PSNR = {psnr.numpy():.3f}')

# Save the restored weights to their own file (needed for the demo), then
# rebuild an identically configured model and load those weights into it.
trainer.model.save_weights(weights_file)
model = wdsr_b(scale=scale, num_res_blocks=depth, attention=attention)
model.load_weights(weights_file)
from model import resolve_single
from utils import load_image, plot_sample


def resolve_and_plot(lr_image_path):
    """Super-resolve one low-resolution image and plot it beside the result.

    Uses the module-level ``model`` whose weights were loaded from
    ``weights_file`` earlier in this script.

    Args:
        lr_image_path: path of the low-resolution image to load.
    """
    lr = load_image(lr_image_path)
    sr = resolve_single(model, lr)
    plot_sample(lr, sr)


resolve_and_plot('demo/0869x4-crop.png')

# NOTE(review): plot_sample is assumed to draw on pyplot figures without
# displaying them. When this file runs as a plain script (not a notebook
# with inline plotting), figures are not shown until plt.show() is called;
# the call returns immediately if no figures are open.
plt.show()