-
Notifications
You must be signed in to change notification settings - Fork 2
/
CosineEmbeddingCriterion.lua
149 lines (121 loc) · 3.76 KB
/
CosineEmbeddingCriterion.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
local CosineEmbeddingCriterion, parent = torch.class('nn.CosineEmbeddingCriterion', 'nn.Criterion')

--- Create a cosine embedding criterion over a pair of inputs.
-- The per-pair loss is 1 - cos for label 1 and max(0, cos - margin) for
-- label -1, as computed in updateOutput.
-- @param margin optional margin used for label -1 pairs (defaults to 0)
function CosineEmbeddingCriterion:__init(margin)
   parent.__init(self)
   self.margin = margin or 0
   -- one gradient tensor per input of the pair
   self.gradInput = {torch.Tensor(), torch.Tensor()}
   -- average the summed loss over the number of labels by default
   self.sizeAverage = true
end
--- Forward pass: compute the cosine embedding loss.
-- For each row i, with cos_i the cosine similarity of input[1][i] and
-- input[2][i] (1e-12 is added to each squared norm before inverting it):
--   y[i] ==  1  ->  loss_i = 1 - cos_i
--   y[i] == -1  ->  loss_i = max(0, cos_i - self.margin)
-- NOTE(review): rows whose label is neither 1 nor -1 are left at cos_i;
-- presumably labels are always +/-1 -- confirm with callers.
-- The per-row losses are summed and, when self.sizeAverage is set,
-- divided by y:size(1).
-- Side effect: fills self.w1, self.w22, self.w32, self.w and
-- self._outputs, which updateGradInput reads afterwards.
-- @param input table {input1, input2}: either 1D tensors (a single pair)
--   or 2D tensors whose rows are paired
-- @param y tensor of per-row labels, or a single number for one pair
-- @return the scalar loss (also stored in self.output)
function CosineEmbeddingCriterion:updateOutput(input,y)
local input1, input2 = input[1], input[2]
-- keep backward compatibility
-- (a scalar label is stored in a reusable 1-element tensor self._y)
if type(y) == 'number' then
self._y = self._y or input1.new(1)
self._y[1] = y
y = self._y
end
-- treat a single pair as a batch with one row
if input1:dim() == 1 then
input1 = input1:view(1,-1)
input2 = input2:view(1,-1)
end
-- lazily allocate work buffers of the same tensor type as the input
if not self.buffer then
self.buffer = input1.new()
self.w1 = input1.new()
self.w22 = input1.new()
self.w = input1.new()
self.w32 = input1.new()
self._outputs = input1.new()
-- comparison operators behave differently from cuda/c implementations
-- (_idx receives the results of y.eq / y.le used as masks)
if input1:type() == 'torch.CudaTensor' then
self._idx = input1.new()
else
self._idx = torch.ByteTensor()
end
end
-- w1 = row-wise dot products input1 . input2 (one column)
self.buffer:cmul(input1,input2)
self.w1:sum(self.buffer,2)
local epsilon = 1e-12
-- w22 = 1 / (|input1|^2 + epsilon)
self.buffer:cmul(input1,input1)
self.w22:sum(self.buffer,2):add(epsilon)
-- self._outputs is also used as a temporary buffer
-- (filled with ones so cdiv below computes a reciprocal)
self._outputs:resizeAs(self.w22):fill(1)
self.w22:cdiv(self._outputs, self.w22)
self.w:resizeAs(self.w22):copy(self.w22)
-- w32 = 1 / (|input2|^2 + epsilon)
self.buffer:cmul(input2,input2)
self.w32:sum(self.buffer,2):add(epsilon)
self.w32:cdiv(self._outputs, self.w32)
-- w = sqrt(w22 * w32), i.e. approximately 1 / (|input1| |input2|)
self.w:cmul(self.w32)
self.w:sqrt()
-- per-row cosine similarity
self._outputs:cmul(self.w1,self.w)
-- from here on self._outputs is a 1D view (one entry per row) of that column
self._outputs = self._outputs:select(2,1)
-- rows labelled -1: max(0, cos - margin)
y.eq(self._idx,y,-1)
self._outputs[self._idx] = self._outputs[self._idx]:add(-self.margin):cmax(0)
-- rows labelled 1: 1 - cos
y.eq(self._idx,y,1)
self._outputs[self._idx] = self._outputs[self._idx]:mul(-1):add(1)
self.output = self._outputs:sum()
if self.sizeAverage then
self.output = self.output/y:size(1)
end
return self.output
end
--- Backward pass of the cosine embedding loss.
-- Reuses the buffers that updateOutput filled for the same (input, y)
-- (self.w1, self.w22, self.w32, self.w, self._outputs, self.buffer, self._idx),
-- so updateOutput must have been called first.
-- @param input table {v1, v2}: either 1D tensors (a single pair) or 2D
--   tensors whose rows are paired
-- @param y tensor of per-row labels (1 or -1), or a single number
-- @return self.gradInput, a table {dL/dv1, dL/dv2} shaped like the inputs
function CosineEmbeddingCriterion:updateGradInput(input, y)
   local v1 = input[1]
   local v2 = input[2]
   local not_batch = false

   -- keep backward compatibility: wrap a scalar label in a 1-element tensor.
   -- Must use v1 here: `input1` is not a local of this function, so indexing
   -- it would hit an undefined global whenever self._y is not yet set.
   if type(y) == 'number' then
      self._y = self._y or v1.new(1)
      self._y[1] = y
      y = self._y
   end

   -- treat a single pair as a batch with one row
   if v1:dim() == 1 then
      v1 = v1:view(1,-1)
      v2 = v2:view(1,-1)
      not_batch = true
   end

   local gw1 = self.gradInput[1]
   local gw2 = self.gradInput[2]

   -- d cos / d v1 = (v2 - (v1.v2) / |v1|^2 * v1) / (|v1| |v2|), built from the
   -- forward buffers w1 = v1.v2, w22 = 1/|v1|^2, w32 = 1/|v2|^2 and
   -- w = 1/(|v1| |v2|) (each squared norm carries a 1e-12 epsilon);
   -- d cos / d v2 is the symmetric expression.
   gw1:resizeAs(v1):copy(v2)
   gw2:resizeAs(v1):copy(v1)

   self.w = self.w:expandAs(v1)

   self.buffer:cmul(self.w1,self.w22)
   self.buffer = self.buffer:expandAs(v1)
   gw1:addcmul(-1,self.buffer,v1)
   gw1:cmul(self.w)

   self.buffer:cmul(self.w1,self.w32)
   self.buffer = self.buffer:expandAs(v1)
   gw2:addcmul(-1,self.buffer,v2)
   gw2:cmul(self.w)

   -- rows whose forward loss is <= 0 get zero gradient
   -- (e.g. label -1 rows with cos <= margin, where the hinge is inactive)
   -- self._idx = self._outputs <= 0
   y.le(self._idx,self._outputs,0)
   self._idx = self._idx:view(-1,1):expand(gw1:size())
   gw1[self._idx] = 0
   gw2[self._idx] = 0

   -- label 1 rows have loss 1 - cos, so their gradient is -d cos
   y.eq(self._idx,y,1)
   self._idx = self._idx:view(-1,1):expand(gw2:size())
   gw1[self._idx] = gw1[self._idx]:mul(-1)
   gw2[self._idx] = gw2[self._idx]:mul(-1)

   if self.sizeAverage then
      gw1:div(y:size(1))
      gw2:div(y:size(1))
   end

   -- give 1D inputs 1D gradients back
   if not_batch then
      self.gradInput[1] = gw1:select(1,1)
      self.gradInput[2] = gw2:select(1,1)
   end

   -- fix for torch bug
   -- https://github.com/torch/torch7/issues/289
   self.buffer:resize()
   return self.gradInput
end
--- Convert the criterion's tensors to the given tensor type.
-- The comparison-mask buffer _idx is rebuilt afterwards, because its
-- element type depends on the backend.
-- @param type tensor type name, e.g. 'torch.FloatTensor' or 'torch.CudaTensor'
-- @return self
function CosineEmbeddingCriterion:type(type)
   -- clear the mask before conversion; presumably so parent.type leaves it
   -- alone, since it is recreated below anyway
   self._idx = nil
   parent.type(self, type)
   -- comparison operators behave differently from cuda/c implementations:
   -- CUDA comparisons write into a CudaTensor, CPU ones into a ByteTensor
   local on_cuda = (type == 'torch.CudaTensor')
   self._idx = on_cuda and torch.CudaTensor() or torch.ByteTensor()
   return self
end