-
Notifications
You must be signed in to change notification settings - Fork 2
/
model_testing.py
66 lines (41 loc) · 2.02 KB
/
model_testing.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
from ALL_model import *
from torch.autograd import Variable
from dataloader import read_bci_data
import pandas as pd
import torch
import numpy
import os
def testing(x_test,y_test,device,model):
model.eval()
with torch.no_grad():
model.to(device)
n = x_test.shape[0]
x_test = x_test.astype("float32")
y_test = y_test.astype("float32").reshape(y_test.shape[0],)
x_test, y_test = Variable(torch.from_numpy(x_test)),Variable(torch.from_numpy(y_test))
x_test,y_test = x_test.to(device),y_test.to(device)
y_pred_test = model(x_test)
correct = (torch.max(y_pred_test,1)[1]==y_test).sum().item()
# print("testing accuracy:",correct/n)
return correct/n
if __name__ == "__main__":
    # Evaluate every saved checkpoint on the BCI test split and print a table
    # of accuracies (rows: network architecture, columns: activation).
    model_list = [
        EEGNet_ReLU,
        EEGNet_LeakyReLU,
        EEGNet_ELU,
        DeepConvNet_ReLU,
        DeepConvNet_LeakyReLU,
        DeepConvNet_ELU,
    ]
    # Checkpoint file names, in the same order as model_list.
    model_file_path = [
        "EEGNet_checkpoint_ReLU.rar",
        "EEGNet_checkpoint_LeakyReLU.rar",
        "EEGNet_checkpoint_ELU.rar",
        "DeepConvNet_checkpoint_ReLU.rar",
        "DeepConvNet_checkpoint_LeakyReLU.rar",
        "DeepConvNet_checkpoint_ELU.rar",
    ]
    ReLU_accuracy = []
    LeakyReLU_accuracy = []
    ELU_accuracy = []

    # Build the checkpoint directory with os.path.join so the script also
    # works on systems that do not use "\\" as the path separator.
    checkpoint_dir = os.path.join(
        os.path.abspath(os.path.dirname(__file__)), "checkpoint"
    )
    # Prefer the first GPU, but fall back to the CPU when CUDA is unavailable.
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    # The dataset does not depend on the model, so load it once up front.
    _, _, test_data, test_label = read_bci_data()

    for model_class, checkpoint_name in zip(model_list, model_file_path):
        filepath = os.path.join(checkpoint_dir, checkpoint_name)
        # NOTE(review): the argument 2 is presumably the number of output
        # classes; confirm against ALL_model.
        model = model_class(2)
        # map_location lets checkpoints saved on a GPU load on CPU-only hosts.
        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        model.load_state_dict(torch.load(filepath, map_location=device))
        testing_accuracy = testing(test_data, test_label, device, model)

        # "LeakyReLU" must be tested first because it contains "ReLU".
        if "LeakyReLU" in checkpoint_name:
            LeakyReLU_accuracy.append(testing_accuracy)
        elif "ReLU" in checkpoint_name:
            ReLU_accuracy.append(testing_accuracy)
        else:
            ELU_accuracy.append(testing_accuracy)

    # Each list holds the EEGNet result first and the DeepConvNet result
    # second, matching the order of model_list.
    df = pd.DataFrame(
        {
            "ReLU": ReLU_accuracy,
            "LeakyReLU": LeakyReLU_accuracy,
            "ELU": ELU_accuracy,
        },
        index=["EEGNet", "DeepConvNet"],
    )
    print(df)