-
Notifications
You must be signed in to change notification settings - Fork 2
/
Criterion.py
75 lines (55 loc) · 2.3 KB
/
Criterion.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
from __future__ import division
import math
import torch
import torch.nn as nn
from torch.autograd import Variable
class Criterion(object):
    """Class-weighted NLL criterion that also reports accuracy and input grads.

    The wrapped ``nn.NLLLoss`` sums (does not average) the per-sample losses.
    ``model``, ``factor`` and ``l1_crit`` are stored on the instance but are not
    used by ``loss``; they were only referenced by an L1-regularisation term
    that was already disabled in this file.

    Attributes are documented inline in ``__init__``.
    """

    def __init__(self, model, opt, weights=(0.05, 0.95), factor=0.0005):
        """Build the loss modules.

        Args:
            model: Stored as ``self.model``; not otherwise used here.
            opt: Options object; only ``opt.gpus`` is read. When it is truthy
                the NLL criterion is moved to CUDA.
            weights: Per-class weights handed to ``nn.NLLLoss`` (any sequence
                of floats).
            factor: Stored as ``self.factor``; not otherwise used here.
        """
        super(Criterion, self).__init__()
        # Copy into a list so any sequence (tuple, list, ...) is accepted.
        class_weights = torch.FloatTensor(list(weights))
        # size_average=False -> the batch loss is a sum over samples.
        self.crit = nn.NLLLoss(weight=class_weights, size_average=False)
        if opt.gpus:
            self.crit.cuda()
        self.l1_crit = nn.L1Loss(size_average=False)
        self.model = model
        self.factor = factor
        # One entry per loss() call: column 1 of the generator output.
        # NOTE: this list is never cleared here, so it grows with every call.
        self.records = []

    def loss(self, scores, labels, generator, eval=False):
        """Compute the weighted loss for one batch and, when training, backprop.

        Args:
            scores: Output whose ``.data`` is fed to ``generator``. It is
                detached from its history and re-wrapped so the gradient with
                respect to it can be returned to the caller.
            labels: Target class indices; flattened with ``view(-1)``.
            generator: Callable mapping the scores to a ``(batch, classes)``
                output that is passed to ``nn.NLLLoss`` (presumably
                log-probabilities -- confirm against the caller).
            eval: When true, no backward pass is run and no gradient is
                returned.

        Returns:
            A tuple ``(loss_data, grad_output, num_correct)``:
            the summed loss as a Python float, the gradient of
            ``loss / batch_size`` with respect to the detached scores (``None``
            when ``eval`` is true), and the number of correct argmax
            predictions as a Python int.
        """
        # Fresh leaf so that .grad is populated on it after backward().
        detached = Variable(scores.data, requires_grad=(not eval), volatile=eval)
        output = generator(detached)
        self.records.append(output[:, 1].data)

        batch_size = output.size(0)
        targets = labels.view(-1)
        batch_loss = self.crit(output, targets)

        predicted = output.max(1)[1]
        # int(...) gives the same Python number whether .sum() returns a plain
        # int (old PyTorch) or a 0-dim tensor (newer PyTorch).
        num_correct = int(predicted.data.eq(targets.data).sum())

        # Indexing a 0-dim tensor with [0] raises on newer PyTorch; prefer
        # .item() when it exists and fall back to the old form otherwise.
        if hasattr(batch_loss, "item"):
            loss_data = batch_loss.item()
        else:
            loss_data = batch_loss.data[0]

        if not eval:
            # Backprop the per-sample mean, while reporting the summed loss.
            batch_loss.div(batch_size).backward()

        grad_output = None if detached.grad is None else detached.grad.data
        return loss_data, grad_output, num_correct