-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest.py
43 lines (30 loc) · 1.13 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
import torch
import pickle
import spacy
from model import Sentiment
# Only the tokenizer of this pipeline is used (predict_class calls
# nlp.tokenizer), so a blank English pipeline is sufficient.
# spacy.load('en') relied on the shortcut link that spaCy v3.0 removed
# (error E941); spacy.blank('en') needs no downloaded model and works on v2/v3.
# NOTE(review): confirm this matches the tokenization used when TEXT's vocab
# was built.
nlp = spacy.blank('en')
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
def predict_class(model, TEXT, sentence, min_len=4, tokenizer=None, target_device=None):
    """Classify a single sentence and return the predicted class index.

    The sentence is tokenized, right-padded with ``'<pad>'`` up to ``min_len``
    tokens, mapped to indices through ``TEXT.vocab.stoi`` and fed to ``model``
    as a ``(seq_len, 1)`` LongTensor (a batch of one).

    Args:
        model: Network to evaluate; it is switched to eval mode here.
        TEXT: Object exposing ``vocab.stoi`` (token -> index). NOTE(review):
            presumably a legacy torchtext ``Field``; unknown-token handling
            depends on ``stoi`` (a plain dict would raise ``KeyError``).
        sentence: Raw input text.
        min_len: Minimum sequence length; shorter inputs are padded with
            ``'<pad>'``.
        tokenizer: Callable returning tokens that have a ``.text`` attribute.
            Defaults to the module-level spaCy ``nlp.tokenizer``.
        target_device: Device to place the input tensor on. Defaults to the
            module-level ``device``.

    Returns:
        int: Index of the highest-scoring class.
    """
    if tokenizer is None:
        tokenizer = nlp.tokenizer
    if target_device is None:
        target_device = device

    model.eval()
    tokenized = [tok.text for tok in tokenizer(sentence)]
    if len(tokenized) < min_len:
        tokenized += ['<pad>'] * (min_len - len(tokenized))
    indexed = [TEXT.vocab.stoi[t] for t in tokenized]
    # unsqueeze(1) adds the batch axis: (seq_len,) -> (seq_len, 1).
    tensor = torch.LongTensor(indexed).to(target_device).unsqueeze(1)
    # Inference only: skip autograd graph construction.
    with torch.no_grad():
        preds = model(tensor)
    print(preds)
    # Reduce over the last (class) axis. argmax(dim=0) reduced over the batch
    # axis when the model returns (batch, n_classes), giving one index per
    # class and making .item() fail; dim=-1 also handles a flat (n_classes,)
    # output.
    return preds.argmax(dim=-1).item()
# Model hyper-parameters. NOTE(review): presumably these must match the values
# used when best_model.pt was trained — confirm against the training script.
EMBEDDING_DIM = 400
HIDDEN_DIM = 400
EPOCH = 20  # NOTE(review): not used in this script.
OUTPUT_DIM = 3

# Context managers close the file handles (bare open() calls leaked them).
# NOTE(review): pickle.load can execute arbitrary code; only load these files
# from a trusted source.
with open("text.pkl", "rb") as text_file:
    TEXT = pickle.load(text_file)
with open("label.pkl", "rb") as label_file:
    LABEL = pickle.load(label_file)  # NOTE(review): not used below.

# NOTE(review): assumed to match the label ordering used during training
# (e.g. LABEL's vocabulary) — confirm.
ix_to_label = {0: 'negative', 1: 'neutral', 2: 'positive'}

# NOTE(review): the positional 2 and 0.5 are presumably Sentiment's layer count
# and dropout — confirm against model.py.
model = Sentiment(len(TEXT.vocab), EMBEDDING_DIM, HIDDEN_DIM, OUTPUT_DIM, 2, 0.5)
model.to(device)
# map_location lets a checkpoint saved on a GPU load on a CPU-only machine.
model.load_state_dict(torch.load('best_model.pt', map_location=device))

pred_class = predict_class(model, TEXT, "I love you")
print(f'Predicted class is: {ix_to_label[pred_class]}')