-
Notifications
You must be signed in to change notification settings - Fork 0
/
testing-generatore.py
68 lines (51 loc) · 2.13 KB
/
testing-generatore.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
from model.generator import Generator
from model.discriminator import Discriminator
from utils.function import prepare_data, ritagliare_centro, create_dir, create_graphic_testing
from utils.parameters import *
import torchvision.utils as vutils
from tqdm import tqdm
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
# Evaluate the trained generator on the test set: for every test batch,
# reconstruct the masked centre of each image, score the generated patches
# with the trained discriminator, and save a side-by-side image
# (real | cropped | reconstructed) per test sample.
#
# NOTE(review): img_size, device, TEST_RESULT, input_real, input_cropped,
# real_center, label, real_label/fake_label and ricostruzione are not defined
# in this file; they presumably come from `from utils.parameters import *`
# (confirm there).

test_dir = "./dataset/testing/"

create_dir(TEST_RESULT)
criterion = nn.BCELoss()

# map_location lets checkpoints saved on a GPU be loaded on the configured
# device (e.g. a CPU-only machine); load_state_dict then copies the values
# into the model parameters.
generator = Generator()
generator.load_state_dict(torch.load("./log/generator.pt", map_location=device))
generator.eval()
discriminator = Discriminator()
discriminator.load_state_dict(torch.load("./log/discriminator.pt", map_location=device))
discriminator.eval()

dataloader = prepare_data(test_dir)

# Centre region: starts a quarter of the way in and spans half the image side.
# Computed once so that cropping the real centre and pasting the generated
# patch always use the same indices (two different formulas disagree for
# some sizes not divisible by 4, e.g. 131, causing a shape mismatch).
center_start = int(img_size / 4)
center_end = center_start + int(img_size / 2)

num_img = 0
perdita = 0.0  # running sum of per-batch discriminator BCE losses
total = 0.0    # number of batches processed
for real_cpu, _ in dataloader:
    real_center_cpu = real_cpu[:, :, center_start:center_end, center_start:center_end]
    batch_size = real_cpu.size(0)
    # Tensor.to() is not in-place, so the result must be reassigned; using
    # `device` instead of a hard-coded .cuda() also works without a GPU.
    real_cpu = real_cpu.to(device)
    real_center_cpu = real_center_cpu.to(device)
    # Locate and crop the centre of the real image.
    input_real, input_cropped, real_center = ritagliare_centro(
        input_real, input_cropped, real_cpu, real_center, real_center_cpu
    )
    with torch.no_grad():
        # Target for the discriminator: every generated patch labelled "fake".
        label.resize_((batch_size, 1, 1, 1)).fill_(fake_label)
        fake = generator(input_cropped)
        output = discriminator(fake)
        errD_fake = criterion(output, label)  # BCELoss default: scalar mean
        # Paste the generated centre back into a copy of the cropped input.
        recon_image = input_cropped.clone()
        recon_image[:, :, center_start:center_end, center_start:center_end] = fake
    total += 1
    perdita += errD_fake.item()
    # Running mean of the loss (x100) after each batch; plotted at the end.
    ricostruzione.append(100 * perdita / total)
    for i in range(batch_size):
        num_img += 1
        vutils.save_image([input_real[i], input_cropped[i], recon_image[i]],
                          TEST_RESULT + f"ricostruite_{num_img}.png")

if total == 0:
    # Avoid a bare ZeroDivisionError when the test folder yields no batches.
    raise SystemExit(f"No test images found in {test_dir}")

# NOTE(review): the value labelled "Accuratezza" is the mean discriminator
# BCE loss on generated patches (x100), not an accuracy.
print(f"Accuratezza: {100 * perdita / total}")
create_graphic_testing(ricostruzione)