-
Notifications
You must be signed in to change notification settings - Fork 16
/
test_model.py
168 lines (132 loc) · 6.57 KB
/
test_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
import pandas as pd
import numpy as np
import sys
import torch
import torch.nn as nn
from sklearn.metrics import confusion_matrix
import time
""" LOCAL IMPORTS """
from src.data_preprocessing import remove_misc
from supervised_product_matching.model_preprocessing import remove_stop_words, character_bert_preprocess_batch, bert_preprocess_batch
from src.common import Common
# Architecture to evaluate; must match one of the model-selection branches below.
using_model = "characterbert"
# Both positional arguments are required; fail with a usage message rather than an IndexError.
if len(sys.argv) < 3:
    sys.exit('Usage: <script> FOLDER MODEL_NAME  (loads ./models/FOLDER/MODEL_NAME.pt)')
# Get the folder name in models
FOLDER = sys.argv[1]
# Get the model name from the terminal
MODEL_NAME = sys.argv[2]
def split_test_data(df):
    """Separate a test DataFrame into title pairs and their labels.

    The frame is cleaned with ``remove_misc`` and converted to a NumPy array.
    Columns 0 and 1 are returned as the pair data; column 2 is returned as the
    label vector, cast to ``float32``.

    Returns:
        tuple: ``(pairs, labels)``.
    """
    table = remove_misc(df).to_numpy()
    pairs = table[:, :2]
    labels = table[:, 2].astype('float32')
    return pairs, labels
# Each CSV is split into (title pairs, labels) by split_test_data.
test_laptop_data, test_laptop_labels = split_test_data(pd.read_csv('data/test/final_laptop_test_data.csv')) # General laptop test data
test_gb_space_data, test_gb_space_labels = split_test_data(pd.read_csv('data/test/final_gb_space_laptop_test.csv')) # Same titles; Substituted storage attributes
test_gb_no_space_data, test_gb_no_space_labels = split_test_data(pd.read_csv('data/test/final_gb_no_space_laptop_test.csv')) # Same titles; Substituted storage attributes
test_retailer_gb_space_data, test_retailer_gb_space_labels = split_test_data(pd.read_csv('data/test/final_retailer_gb_space_test.csv')) # Different titles; Substituted storage attributes
test_retailer_gb_no_space_data, test_retailer_gb_no_space_labels = split_test_data(pd.read_csv('data/test/final_retailer_gb_no_space_test.csv')) # Different titles; Substituted storage attributes
print('Loaded all test files')

# Initialize the model: import the SiameseNetwork / forward_prop pair for the chosen architecture.
if using_model == "characterbert":
    from supervised_product_matching.model_architectures.characterbert_classifier import SiameseNetwork, forward_prop
elif using_model == "bert":
    from supervised_product_matching.model_architectures.bert_classifier import SiameseNetwork, forward_prop
elif using_model == "scaled characterbert concat":
    from supervised_product_matching.model_architectures.characterbert_transformer_concat import SiameseNetwork, forward_prop
elif using_model == "scaled characterbert add":
    from supervised_product_matching.model_architectures.characterbert_transformer_add import SiameseNetwork, forward_prop
else:
    raise ValueError('Unknown using_model: {!r}'.format(using_model))
# Every architecture is placed on Common.device, consistently across branches.
# NOTE(review): confirm the "scaled characterbert concat" model does not manage device placement itself.
net = SiameseNetwork().to(Common.device)

weights_path = './models/{}/{}.pt'.format(FOLDER, MODEL_NAME)
if torch.cuda.is_available():
    net.load_state_dict(torch.load(weights_path))
else:
    # Remap any GPU-saved tensors so the checkpoint loads on a CPU-only machine.
    net.load_state_dict(torch.load(weights_path, map_location=torch.device('cpu')))

# Using cross-entropy because we are making a classifier
criterion = nn.CrossEntropyLoss()
print("************* Validating *************")
# The size of each mini-batch (not used by the validation/inference code in this file)
BATCH_SIZE = 32
# The size of the validation mini-batch
VAL_BATCH_SIZE = 16
# Number of batches over which running loss and accuracy are averaged before being reset
PERIOD = 50
def _precision_recall_f1(tp, fp, fn):
    """Return ``(precision, recall, f1)`` computed from confusion-matrix counts.

    A ratio whose denominator is zero is reported as 0.0 (scikit-learn's
    ``zero_division`` convention), instead of producing NaN or raising
    ZeroDivisionError.
    """
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1_score = 2 * ((precision * recall) / (precision + recall)) if precision + recall else 0.0
    return float(precision), float(recall), float(f1_score)


def validation(data, labels, name):
    '''
    Validate the model on one labelled test set and print metrics.

    ``data`` is processed in slices of ``VAL_BATCH_SIZE`` rows through
    ``forward_prop`` with the module-level ``net`` and ``criterion``.

    Per-batch output:
        For every batch, prints the batch loss, accuracy, precision, recall
        and F1, plus the running loss and accuracy averaged over the current
        window. The window is reset every ``PERIOD`` batches.

    Final output:
        Confusion counts accumulate over the whole set and drive a final
        precision/recall/F1 line.

    Args:
        data: Array of title pairs, sliced along axis 0.
        labels: Label array aligned with ``data`` (0/1 values).
        name: Prefix for every printed line.
    '''
    running_loss = 0.0
    running_accuracy = 0.0
    current_batch = 0
    running_tn = 0
    running_fp = 0
    running_fn = 0
    running_tp = 0
    # Evaluation only: skip building the autograd graph.
    with torch.no_grad():
        for i, position in enumerate(range(0, len(data), VAL_BATCH_SIZE)):
            current_batch += 1
            # Slicing clamps at the end, so the last (possibly partial) batch needs no special case.
            batch_data = data[position:position + VAL_BATCH_SIZE]
            batch_labels = labels[position:position + VAL_BATCH_SIZE]
            # Forward propagation
            loss, forward = forward_prop(batch_data, batch_labels, net, criterion)
            # Predicted class per row, moved to CPU for NumPy/sklearn
            y_pred = torch.argmax(forward, dim=1).cpu().numpy()
            # Calculate accuracy
            accuracy = np.sum(y_pred == batch_labels) / float(batch_labels.shape[0])
            # labels=[0, 1] keeps the matrix 2x2 even when a batch contains only one class
            confusion = confusion_matrix(batch_labels, y_pred, labels=[0, 1])
            tn, fp, fn, tp = confusion.ravel()
            running_tn += tn
            running_fp += fp
            running_fn += fn
            running_tp += tp
            precision, recall, f1_score = _precision_recall_f1(tp, fp, fn)
            # Add to running loss and accuracy (reset every PERIOD batches)
            batch_loss = loss.item()
            running_loss += batch_loss
            running_accuracy += accuracy
            # Print statistics every batch
            print('%s Batch: %5d, Loss: %.6f, Accuracy: %.6f, Running Loss: %.6f, Running Accuracy: %.6f, Precision: %.3f, Recall: %.3f, F1 Score: %.3f' %
                  (name, i + 1, batch_loss, accuracy, running_loss / current_batch, running_accuracy / current_batch, precision, recall, f1_score))
            # Clear our running variables every PERIOD batches
            if current_batch == PERIOD:
                current_batch = 0
                running_loss = 0
                running_accuracy = 0
    # Get the statistics for the whole data (safe even if data was empty)
    final_precision, final_recall, final_f1_score = _precision_recall_f1(running_tp, running_fp, running_fn)
    print('%s: Precision: %.3f, Recall: %.3f, F1 Score: %.3f' % (name, final_precision, final_recall, final_f1_score))
def inference():
    '''
    Test the model on one pair of titles typed by the user.

    Reads two titles from stdin and removes stop words from each. The pair is
    then run through ``net``. Prints the predicted class index and the two
    output scores.
    '''
    title1 = remove_stop_words(input('First title: '))
    title2 = remove_stop_words(input('Second title: '))
    data = np.array([title1, title2]).reshape(1, 2)
    # Presumably bert_preprocess_batch (imported but otherwise unused) is the tokenizer
    # for the "bert" architecture; all CharacterBERT variants use the character-level one.
    # TODO confirm against the bert_classifier module.
    if using_model == "bert":
        preprocess = bert_preprocess_batch
    else:
        preprocess = character_bert_preprocess_batch
    with torch.no_grad():
        forward = net(*preprocess(data))
    # .numpy() only works on CPU tensors, so move the output off the GPU first.
    np_forward = forward.cpu().numpy()[0]
    print('Output: {}'.format(torch.argmax(forward, dim=1).item()))
    # NOTE(review): these are the raw net outputs. criterion is CrossEntropyLoss, which
    # expects logits, so they may not be softmax probabilities. Confirm before trusting the label.
    print('Softmax: Negative {:.4f}%, Positive {:.4f}%'.format(np_forward[0], np_forward[1]))
# Choose between scoring the labelled test sets and classifying titles interactively.
user_input = input('Would you like to validate, or manually test the model? (validate/test) ')
# Put the model in evaluation mode (affects layers such as dropout and batch norm).
net.eval()
if user_input.strip().lower() == 'validate':
    validation(test_laptop_data, test_laptop_labels, 'Test Laptop (General)')
    validation(test_gb_space_data, test_gb_space_labels, 'Test Laptop (Same Title) (Space)')
    validation(test_gb_no_space_data, test_gb_no_space_labels, 'Test Laptop (Same Title) (No Space)')
    validation(test_retailer_gb_space_data, test_retailer_gb_space_labels, 'Test Laptop (Different Title) (Space)')
    validation(test_retailer_gb_no_space_data, test_retailer_gb_no_space_labels, 'Test Laptop (Different Title) (No Space)')
else:
    # Classify pairs until the user ends input (Ctrl-D) or interrupts (Ctrl-C), then exit cleanly.
    try:
        while True:
            inference()
    except (EOFError, KeyboardInterrupt):
        print()