-
Notifications
You must be signed in to change notification settings - Fork 32
/
crossDatasetTrainTest.m
62 lines (55 loc) · 2.04 KB
/
crossDatasetTrainTest.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
% clc;
% Cross-dataset experiment: train one model on trainDatabase, then evaluate
% every saved epoch checkpoint on each of the remaining databases.
clear;
warning('off', 'all');

% Project helpers and tools on the path; vl_setupnn is the MatConvNet setup
% routine (dagnn.DagNN is used further down this script).
addpath('support_methods/');
addpath(genpath('tools'));
vl_setupnn;

%% experiment configuration
Databases          = {'ChallengeDB_release', 'LIVE', 'CSIQ', 'TID2013'};
trainDatabase      = 'LIVE';
testDatabase       = setdiff(Databases, trainDatabase);  % all other databases (setdiff returns them sorted)
ModelType          = 'AlexNet';  % one of: ResNet | AlexNet | S_CNN
trainingPropertion = 1.0;        % forwarded to TrainModel as 'trainingpropertion'
testingPropertion  = 1.0;        % testModel later receives 1 - testingPropertion
patchNum           = 50;

%% parameters for PQR
quantizationMethod = 'uniform';  % alternative: LloydMax
beta               = 64;
bins               = 5;          % bins = 1 gives scalar quality-score regression
% Per-architecture patch geometry and number of training epochs.
% NOTE(review): patchSize presumably matches each network's expected input
% resolution — confirm against TrainModel.
switch ModelType
    case 'AlexNet'
        patchSize = 227;
        patchStep = 16;
        epoch = 20;
    case 'ResNet'
        patchSize = 224;
        patchStep = 16;
        epoch = 10;
    case 'S_CNN'
        patchSize = 64;
        patchStep = 8;
        epoch = 40;
    otherwise
        % Fail fast: without this branch an unknown ModelType leaves
        % patchSize/patchStep/epoch undefined and the script dies later with
        % an unrelated "undefined variable" error.
        error('crossDatasetTrainTest:unknownModelType', ...
            'Unsupported ModelType ''%s''; expected AlexNet, ResNet or S_CNN.', ModelType);
end
% Seed is forwarded to TrainModel and is part of the checkpoint directory name;
% per the original note it has no other role in cross-dataset evaluation.
seed = 1; % no use in cross dataset evaluation

% Name/value options for TrainModel, listed in the same order as before.
trainArgs = { ...
    'Model',              ModelType, ...
    'Database',           trainDatabase, ...
    'PatchNum',           patchNum, ...
    'seed',               seed, ...
    'trainingpropertion', trainingPropertion, ...
    'quantizationMethod', quantizationMethod, ...
    'bins',               bins, ...
    'beta',               beta, ...
    'patchSize',          patchSize, ...
    'patchStep',          patchStep, ...
    'epoch',              epoch};
TrainModel(trainArgs{:});
% Evaluate every saved epoch checkpoint on each held-out database and append
% the best SRCC/PLCC per database to result/crossDatabase.txt.
% NOTE(review): the reported values are maxima over epochs measured on the
% test database itself, and best SRCC and best PLCC may come from different
% epochs — treat them as optimistic.

% Checkpoint directory depends only on the configuration, so build it once.
% ('Propertion' spelling must match the directory TrainModel writes.)
modelDir = fullfile('data', [ModelType '_' trainDatabase '_TrainPropertion' num2str(trainingPropertion) ...
    '_PatchSize' num2str(patchSize) '_PatchNum' num2str(patchNum) '_Seed' num2str(seed) '_bins' num2str(bins) ...
    '_beta' num2str(beta)]);

% fopen cannot create missing directories; make sure the output folder exists.
resultDir = 'result';
if ~exist(resultDir, 'dir')
    mkdir(resultDir);
end

for d = 1:numel(testDatabase)
    SRCC = zeros(1, epoch);
    PLCC = zeros(1, epoch);
    for i = 1:epoch
        netStruct = load(fullfile(modelDir, ['net-epoch-' num2str(i) '.mat']));
        net = dagnn.DagNN.loadobj(netStruct.net);
        move(net, 'gpu')
        net.mode = 'test';
        [SRCC(i), PLCC(i)] = testModel(ModelType, net, testDatabase{d}, seed, 1 - testingPropertion);
    end

    % Use the magnitude for both metrics: the correlation sign can flip
    % between databases, and SRCC was already compared by magnitude, so PLCC
    % must be treated the same way to stay consistent.
    best_SRCC = max(abs(SRCC));
    best_PLCC = max(abs(PLCC));

    [fid, openMsg] = fopen(fullfile(resultDir, 'crossDatabase.txt'), 'a');
    if fid < 0
        error('crossDatasetTrainTest:openFailed', ...
            'Cannot open %s for appending: %s', fullfile(resultDir, 'crossDatabase.txt'), openMsg);
    end
    fprintf(fid, 'TrainDatabase: %s; TestDatabase: %s; Model: %s; bins = %d; patches = %d;\n', ...
        trainDatabase, testDatabase{d}, ModelType, bins, patchNum);
    fprintf(fid, 'SRCC: %.4f; PLCC: %.4f \n', best_SRCC, best_PLCC);
    fclose(fid);
end