# test.py (forked from face-analysis/emonet)
import argparse
from pathlib import Path

import torch
from torch.utils.data import DataLoader
from torchvision import transforms
from emonet.models import EmoNet
from emonet.data import AffectNet
from emonet.data_augmentation import DataAugmentor
from emonet.metrics import CCC, PCC, RMSE, SAGR, ACC
from emonet.evaluation import evaluate, evaluate_flip
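
# Let cuDNN autotune convolution algorithms for the fixed input size (speeds up repeated forward passes).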
torch.backends.cudnn.benchmark = True

# Parse arguments.
parser = argparse.ArgumentParser()
parser.add_argument('--nclasses', type=int, default=8, choices=[5,8], help='Number of emotional classes to test the model on. Please use 5 or 8.')
args = parser.parse_args()

# Parameters of the experiment.
n_expression = args.nclasses
batch_size = 32
n_workers = 16
device = 'cuda:0'
image_size = 256
subset = 'test'
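
# CCC, PCC, RMSE and SAGR score the continuous valence/arousal predictions; ACC scores the discrete expression classes.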
metrics_valence_arousal = {'CCC':CCC, 'PCC':PCC, 'RMSE':RMSE, 'SAGR':SAGR}
metrics_expression = {'ACC':ACC}

# Create the data loaders.
transform_image = transforms.Compose([transforms.ToTensor()])
transform_image_shape_no_flip = DataAugmentor(image_size, image_size)
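# Index permutation that maps each of the 68 facial landmarks to its horizontally mirrored counterpart.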
flipping_indices = [16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0,
                    26, 25, 24, 23, 22, 21, 20, 19, 18, 17, 27, 28, 29, 30,
                    35, 34, 33, 32, 31, 45, 44, 43, 42, 47, 46, 39, 38, 37,
                    36, 41, 40, 54, 53, 52, 51, 50, 49, 48, 59, 58, 57, 56,
                    55, 64, 63, 62, 61, 60, 67, 66, 65]
transform_image_shape_flip = DataAugmentor(image_size, image_size, mirror=True, shape_mirror_indx=flipping_indices, flipping_probability=1.0)
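# With flipping_probability=1.0 every image is mirrored, so the two loaders below yield
# original/mirrored pairs for flip-averaged evaluation.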
print(f'Testing the model on {n_expression} emotional classes')
print('Loading the data')
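# Adjust root_path below to wherever your copy of AffectNet is stored.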
test_dataset_no_flip = AffectNet(root_path='~/datasets/new_affectnet/', subset=subset, n_expression=n_expression,
                                 transform_image_shape=transform_image_shape_no_flip, transform_image=transform_image)
test_dataset_flip = AffectNet(root_path='~/datasets/new_affectnet/', subset=subset, n_expression=n_expression,
                              transform_image_shape=transform_image_shape_flip, transform_image=transform_image)
test_dataloader_no_flip = DataLoader(test_dataset_no_flip, batch_size=batch_size, shuffle=False, num_workers=n_workers)
test_dataloader_flip = DataLoader(test_dataset_flip, batch_size=batch_size, shuffle=False, num_workers=n_workers)

# Load the model.
state_dict_path = Path(__file__).parent.joinpath('pretrained', f'emonet_{n_expression}.pth')
print(f'Loading the model from {state_dict_path}.')
state_dict = torch.load(str(state_dict_path), map_location='cpu')
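# Strip the 'module.' prefix left by nn.DataParallel so the keys match a single-GPU model.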
state_dict = {k.replace('module.',''):v for k,v in state_dict.items()}
net = EmoNet(n_expression=n_expression).to(device)
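# strict=False tolerates checkpoint keys that do not match the model definition.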
net.load_state_dict(state_dict, strict=False)
net.eval()
print(f'Testing on {subset}-set')
print('------------------------')
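# evaluate_flip combines predictions from the original and horizontally flipped images (test-time augmentation).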
evaluate_flip(net, test_dataloader_no_flip, test_dataloader_flip, device=device, metrics_valence_arousal=metrics_valence_arousal, metrics_expression=metrics_expression)
# Uncomment to evaluate on the un-flipped loader only:
# evaluate(net, test_dataloader_no_flip, device=device, metrics_valence_arousal=metrics_valence_arousal, metrics_expression=metrics_expression)
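
# Example invocation (expects weights at ./pretrained/emonet_{5,8}.pth relative to this script):
#   python test.py --nclasses 8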