forked from torch/nn
-
Notifications
You must be signed in to change notification settings - Fork 0
/
MixtureTable.lua
170 lines (154 loc) · 5.59 KB
/
MixtureTable.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
-- nn.MixtureTable: combines expert outputs, weighting each expert by the
-- corresponding gater output (see updateOutput).
-- `parent` is the nn.Module class; it is used below for the super calls
-- (__init, type, clearState).
local MixtureTable, parent = torch.class('nn.MixtureTable', 'nn.Module')
--- Construct a MixtureTable.
-- @param dim optional; the dimension of the expert input that indexes the
--   experts. When nil, the first updateOutput call fills it in: 1 if the
--   gater input is 1-D, otherwise 2.
function MixtureTable:__init(dim)
   parent.__init(self)
   self.dim = dim
   -- shape storages: `size` is filled in by updateOutput (gater view shape),
   -- `size2` by updateGradInput (gradOutput view shape)
   self.size, self.size2 = torch.LongStorage(), torch.LongStorage()
   self.batchSize = 0
   self.backwardSetup = false
   self.gradInput = {}
end
--- Forward pass: sum of the expert outputs, each scaled by its gate.
-- @param input {gaterInput, expertInputs}. gaterInput is 1-D (non-batch) or
--   has the experts along dimension 2 (batch). expertInputs is either a
--   table holding one tensor per expert, or a single tensor whose dimension
--   self.dim indexes the experts.
-- @return self.output, shaped like a single expert's output
function MixtureTable:updateOutput(input)
   local gaterInput, expertInputs = table.unpack(input)

   -- lazily allocated buffers, of the same tensor type as input[1]
   self._gaterView = self._gaterView or input[1].new()
   self._expert = self._expert or input[1].new()
   self._expertView = self._expertView or input[1].new()

   -- dimG: the gater dimension that indexes the experts
   self.dimG = 2
   local batchSize = gaterInput:size(1)
   if gaterInput:dim() < 2 then
      self.dimG = 1
      self.dim = self.dim or 1
      batchSize = 1
   end
   -- NOTE(review): self.dim is written back here, so a default chosen on the
   -- first call (1 for non-batch, 2 for batch) sticks for later calls.
   self.dim = self.dim or 2

   -- Decide the mode from the actual input on every call, instead of
   -- latching table mode forever once it has been seen.
   self.table = torch.type(expertInputs) == 'table'

   -- Shapes are recomputed on every call (cheap) instead of only when the
   -- batch size changes. A batch-size-keyed cache goes stale when expert
   -- sizes change at a constant batch size, and after clearState() has
   -- emptied self.output.
   local nDim
   if self.table then
      if gaterInput:size(self.dimG) ~= #expertInputs then
         error"Should be one gater output per expert"
      end
      -- the gater view gets an extra (expert) dimension next to the experts'
      nDim = expertInputs[1]:dim() + 1
   else
      nDim = expertInputs:dim()
   end
   self.size:resize(nDim):fill(1)
   if self.dimG > 1 then
      self.size[1] = gaterInput:size(1)
   end
   self.size[self.dim] = gaterInput:size(self.dimG)
   self.batchSize = batchSize
   -- make the backward pass refresh any shapes derived from the last input
   self.backwardSetup = false
   self._gaterView:view(gaterInput, self.size)

   if self.table then
      self.output:resizeAs(expertInputs[1])
      self.output:zero()
      -- multiply-accumulate each gater output with its commensurate expert
      for i, expertInput in ipairs(expertInputs) do
         local gate = self._gaterView:select(self.dim, i):expandAs(expertInput)
         self.output:addcmul(expertInput, gate)
      end
   else
      -- broadcast the gates over the expert tensor, then reduce over experts
      self._expert:cmul(self._gaterView:expandAs(expertInputs), expertInputs)
      self.output:sum(self._expert, self.dim)
      self.output:resizeAs(expertInputs:select(self.dim, 1))
   end
   return self.output
end
--- Backward pass: gradients with respect to the gater and the experts.
-- Expects updateOutput to have just run on the same input. It reuses
-- self._gaterView, self._expert, self._expertView, self.dim, self.dimG
-- and self.table from that call.
-- @param input {gaterInput, expertInputs}, as given to updateOutput
-- @param gradOutput gradient with respect to self.output
-- @return self.gradInput = {gaterGradInput, expertGradInputs}
function MixtureTable:updateGradInput(input, gradOutput)
   local gaterInput, expertInputs = table.unpack(input)
   nn.utils.recursiveResizeAs(self.gradInput, input)
   local gaterGradInput, expertGradInputs = table.unpack(self.gradInput)

   -- lazily allocated buffers, of the same tensor type as input[1]
   self._sum = self._sum or input[1].new()
   self._expertView2 = self._expertView2 or input[1].new()
   self._expert2 = self._expert2 or input[1].new()

   -- Gradient buffers and shapes are derived from the current input on every
   -- call. The old one-shot setup was guarded by self.backwardSetup, which
   -- forward only reset on a batch-size change, so it kept stale shapes when
   -- expert sizes changed at a constant batch size.
   if self.table then
      for i, expertInput in ipairs(expertInputs) do
         local expertGradInput = expertGradInputs[i] or expertInput:clone()
         expertGradInput:resizeAs(expertInput)
         expertGradInputs[i] = expertGradInput
      end
      -- Drop buffers left over from a call with more experts, so the loop
      -- below never reads a missing expertInputs[i].
      for i = #expertGradInputs, #expertInputs + 1, -1 do
         expertGradInputs[i] = nil
      end
      gaterGradInput:resizeAs(gaterInput)

      -- like CMulTable, but with broadcasting
      for i, expertGradInput in ipairs(expertGradInputs) do
         -- gater gradient: gradOutput .* expertInputs[i], summed over every
         -- non-batch element
         self._expert:cmul(gradOutput, expertInputs[i])
         if self.dimG == 1 then
            self._expertView:view(self._expert, -1)
         else
            self._expertView:view(self._expert, gradOutput:size(1), -1)
         end
         self._sum:sum(self._expertView, self.dimG)
         if self.dimG == 1 then
            gaterGradInput[i] = self._sum:select(self.dimG, 1)
         else
            gaterGradInput:select(self.dimG, i):copy(self._sum:select(self.dimG, 1))
         end
         -- expert gradient: gradOutput scaled by this expert's gate
         local gate = self._gaterView:select(self.dim, i):expandAs(expertGradInput)
         expertGradInput:cmul(gate, gradOutput)
      end
   else
      -- view gradOutput with a singleton at the expert dimension so that it
      -- broadcasts against expertInputs
      self.size2:resize(expertInputs:dim())
      self.size2:copy(expertInputs:size())
      self.size2[self.dim] = 1
      gaterGradInput:resizeAs(gaterInput)

      -- gater updateGradInput
      self._expertView:view(gradOutput, self.size2)
      local expandedGradOutput = self._expertView:expandAs(expertInputs)
      self._expert:cmul(expandedGradOutput, expertInputs)
      -- move the expert dimension (self.dim) to position dimG, then flatten
      -- the remaining dimensions so they can be summed away
      local expert = self._expert:transpose(self.dim, self.dimG)
      if not expert:isContiguous() then
         self._expert2:resizeAs(expert)
         self._expert2:copy(expert)
         expert = self._expert2
      end
      if self.dimG == 1 then
         self._expertView2:view(expert, gaterInput:size(1), -1)
      else
         self._expertView2:view(expert, gaterInput:size(1), gaterInput:size(2), -1)
      end
      gaterGradInput:sum(self._expertView2, self.dimG + 1)
      gaterGradInput:resizeAs(gaterInput)

      -- expert updateGradInput
      expertGradInputs:cmul(self._gaterView:expandAs(expertInputs), expandedGradOutput)
   end
   -- kept for compatibility with code that inspects the flag
   self.backwardSetup = true
   return self.gradInput
end
--- Cast the module to another tensor type.
-- The scratch buffers are discarded rather than converted. The forward and
-- backward passes re-create them lazily from their input (`x = x or
-- input[1].new()`), so they come back with the new type.
function MixtureTable:type(type, tensorCache)
   local scratch = {
      '_gaterView', '_expert', '_expertView',
      '_sum', '_expert2', '_expertView2',
   }
   for _, field in ipairs(scratch) do
      self[field] = nil
   end
   return parent.type(self, type, tensorCache)
end
--- Release intermediate state (scratch buffers, output, gradInput).
-- The cached batch size is reset as well. The parent clearState empties
-- self.output, and the forward pass only re-derived its shapes (and resized
-- self.output) when the batch size changed. Without this reset, the next
-- same-batch-size forward in table mode would accumulate into an empty
-- output tensor.
-- @return self (whatever parent.clearState returns)
function MixtureTable:clearState()
   nn.utils.clear(self, {
      '_gaterView',
      '_expert',
      '_expertView',
      '_sum',
      '_expert2',
      '_expertView2',
   })
   -- force the next forward/backward to recompute all cached shapes
   self.batchSize = 0
   self.backwardSetup = false
   return parent.clearState(self)
end