forked from Guim3/IcGAN
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrainEncoder.lua
346 lines (273 loc) · 12 KB
/
trainEncoder.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
-- This file trains an encoder net that learns to map an image X to a noise vector Z
-- (encoder Z, type Z) or an encoder that maps an image X to an attribute vector Y
-- (encoder Y, type Y).
-- Encoder Z reads the dataset generated by generateEncoderDataset.lua; encoder Y reads
-- the CelebA dumps (images.dmp / imLabels.dmp) or the MNIST luarocks package.
-- Usage example: type=Z th trainEncoder.lua
-- Other options can be overridden through environment variables (see getParameters).
require 'image'
require 'nn'
require 'optim'
-- Tensors created with torch.Tensor(...) default to 32-bit floats.
torch.setdefaulttensortype('torch.FloatTensor')
--- Builds the option table for this script.
-- Reads the mandatory `type` environment variable ('Z' or 'Y', case-insensitive), runs
-- cfg/mainConfig.lua with section 1 (encoder Z) or 2 (encoder Y), then lets environment
-- variables override any key present in `opt` (numeric strings are converted to numbers).
-- Side effects: sets the global `opt` and, when display is enabled, the global `display`.
-- @treturn table opt
local function getParameters()
   -- `opt` is deliberately global: the config chunk below runs in the global environment.
   -- NOTE(review): presumably cfg/mainConfig.lua fills this global table — confirm.
   opt = {}
   -- Type of encoder must be passed as argument to decide what kind of
   -- encoder will be trained (encoder Z [type Z] or encoder Y [type Y])
   opt.type = os.getenv('type')
   assert(opt.type, "Parameter 'type' not specified. It is necessary to set the encoder type: 'Z' or 'Y'.\nExample: type=Z th trainEncoder.lua")
   assert(string.upper(opt.type)=='Z' or string.upper(opt.type)=='Y',"Parameter 'type' must be 'Z' (encoder Z) or 'Y' (encoder Y).")

   -- Load parameters from config file: section 1 = encoder Z, section 2 = encoder Y
   local section = (string.upper(opt.type) == 'Z') and 1 or 2
   assert(loadfile("cfg/mainConfig.lua"))(section)

   -- One-line argument parser: environment variables override the defaults.
   -- (Only existing keys are reassigned, which is allowed while iterating with pairs.)
   for k, _ in pairs(opt) do opt[k] = tonumber(os.getenv(k)) or os.getenv(k) or opt[k] end

   -- 0 is truthy in Lua: without this, "display=0" would still enable the display server
   -- and every `if opt.display` check downstream.
   if opt.display == 0 then opt.display = false end

   print(opt)
   if opt.display then display = require 'display' end
   return opt
end
--- Loads the encoder-Z dataset written by generateEncoderDataset.lua.
-- Expects the file `path .. 'groundtruth.dmp'`, a serialized table holding the Z vectors and
-- either the images themselves (`storeAsTensor` set, field `X`) or their file names
-- (`relativePath`, `imNames`, `imSize`); in the latter case every image is loaded here.
-- @tparam string path directory containing groundtruth.dmp (including the trailing slash)
-- @return X images, one per entry along the first dimension
-- @return Z the Z vectors stored in the dump, aligned with X
local function readDatasetZ(path)
   local dump = torch.load(path..'groundtruth.dmp')
   local Z = dump.Z

   if dump.storeAsTensor then
      -- Images were serialized directly as a tensor.
      local images = dump.X
      assert(Z:size(1)==images:size(1), "groundtruth.dmp is corrupted, number of images and Z vectors is not equal. Create the dataset again.")
      return images, Z
   end

   -- Only file names were stored: read every image from disk.
   local names = dump.imNames
   assert(Z:size(1)==#names, "groundtruth.dmp is corrupted, number of images and Z vectors is not equal. Create the dataset again.")
   local dir = dump.relativePath
   local firstImage = image.load(dir..names[1])
   local dims = dump.imSize
   local images = torch.Tensor(#names, dims[1], dims[2], dims[3])
   images[{{1}}] = firstImage
   for idx = 2, #names do
      images[{{idx}}] = image.load(dir..names[idx])
   end
   return images, Z
end
--- Loads the images X and attribute vectors Y used to train encoder Y.
-- CelebA: `path` must contain images.dmp (obtained running data/preprocess_celebA.lua) and
-- imLabels.dmp (obtained running trainGAN.lua via data/donkey_celebA.lua).
-- MNIST: when `path` is 'mnist' or 'mnist/' (case-insensitive) the mnist luarocks package
-- is used instead.
-- @tparam string path dataset directory (with trailing slash) or 'mnist'
-- @tparam number imSize target side; images whose size differs are rescaled to imSize x imSize
-- @return X images mapped to [-1, 1]
-- @return Y attribute vectors (MNIST: 10 entries, 1 at the true class and -1 elsewhere)
local function readDatasetY(path, imSize)
   local X, Y
   local lowerPath = string.lower(path)
   if lowerPath == 'mnist' or lowerPath == 'mnist/' then
      local mnist = require 'mnist'
      local trainSet = mnist.traindataset()
      local nSamples = trainSet.data:size(1)
      X = torch.Tensor(nSamples, 1, imSize, imSize)
      local resize = trainSet.data:size(2) ~= imSize
      Y = torch.IntTensor(nSamples, 10):fill(-1)
      for i = 1, nSamples do
         -- Read MNIST images. Convert to float *before* any arithmetic: the raw pixels are
         -- 0-255 integers (the original code needed :float() after image.scale), and in-place
         -- div/mul/add on that integer view would truncate, wrap around and also overwrite
         -- trainSet.data.
         local im = trainSet.data[{{i}}]:float()
         if resize then
            im = image.scale(im, imSize, imSize)
         end
         im:div(255):mul(2):add(-1) -- change [0, 255] to [-1, 1]
         X[{{i}}]:copy(im)
         -- Read MNIST labels: class 0-9 -> +1 at position class+1 (all others stay -1)
         local class = trainSet.label[i]
         Y[{{i},{class+1}}] = 1
      end
   else
      print('Loading images X from '..path..'images.dmp')
      local data = torch.load(path..'images.dmp')
      -- data is N x C x H x W (sizes 1-4 are read below); size estimate assumes 4-byte floats
      print(('Done. Loaded %.2f GB (%d images).'):format((4*data:size(1)*data:size(2)*data:size(3)*data:size(4))/2^30, data:size(1)))
      if data:size(3) ~= imSize then
         -- Resize images, keeping the channel dimension.
         X = torch.Tensor(data:size(1), data:size(2), imSize, imSize)
         for i = 1, data:size(1) do
            -- data[i] is a single C x H x W image, the layout image.scale expects.
            local im = image.scale(data[i], imSize, imSize)
            im:mul(2):add(-1) -- change [0, 1] to [-1, 1]
            X[i]:copy(im)
         end
      else
         X = data
         X:mul(2):add(-1) -- change [0, 1] to [-1, 1] (in place)
      end
      print('Loading attributes Y from '..path..'imLabels.dmp')
      Y = torch.load(path..'imLabels.dmp')
      print(('Done. Loaded %d attributes'):format(Y:size(1)))
   end
   return X, Y
end
--- Splits aligned inputs and targets (samples along dimension 1) without shuffling.
-- The first floor(split * N) samples become the training part, the rest the test part.
-- @param x inputs, N samples along the first dimension
-- @param y targets aligned with x
-- @tparam number split fraction of samples used for training
-- @return xTrain, yTrain, xTest, yTest
local function splitTrainTest(x, y, split)
   local total = x:size(1)
   local cut = math.floor(split * total)
   local trainRange = {{1, cut}}
   local testRange = {{cut + 1, total}}
   return x[trainRange], y[trainRange], x[testRange], y[testRange]
end
--- Builds the encoder network and its training criterion.
-- Architecture based on "Autoencoding beyond pixels using a learned similarity metric"
-- (VAE/GAN hybrid): nConvLayers blocks of [5x5 conv, stride 2, pad 2 -> SpatialBatchNorm -> ReLU],
-- doubling the filter count after the first block, then flatten ->
-- Linear(FCsz) -> BatchNorm -> ReLU -> Linear(outputSize).
-- @param inputSize size of one sample: {channels, height, width}
-- @tparam number nFiltersBase filters of the first conv layer (e.g. 64)
-- @tparam number outputSize length of the output vector
-- @tparam number nConvLayers number of conv blocks (e.g. 3)
-- @tparam[opt] number FCsz hidden FC width; defaults to the flattened conv output size
-- @return encoder nn.Sequential
-- @return criterion nn.MSECriterion
local function getEncoder(inputSize, nFiltersBase, outputSize, nConvLayers, FCsz)
   -- Output side of a kernel-5, stride-2, pad-2 convolution: floor((n + 2*2 - 5) / 2) + 1.
   -- Equals n/2 for even n, but rounds up for odd n (e.g. 7 -> 4).
   local function convOut(n) return math.floor((n + 2*2 - 5) / 2) + 1 end

   local encoder = nn.Sequential()
   local height, width = inputSize[2], inputSize[3]

   -- 1st conv block (nFiltersBase filters), e.g. 32x32 -> 16x16
   encoder:add(nn.SpatialConvolution(inputSize[1], nFiltersBase, 5, 5, 2, 2, 2, 2))
   encoder:add(nn.SpatialBatchNormalization(nFiltersBase))
   encoder:add(nn.ReLU(true))
   height, width = convOut(height), convOut(width)

   -- Remaining conv blocks double the filters each time, e.g. 16x16 -> 8x8 -> 4x4
   local nFilters = nFiltersBase
   for _ = 2, nConvLayers do
      encoder:add(nn.SpatialConvolution(nFilters, nFilters*2, 5, 5, 2, 2, 2, 2))
      encoder:add(nn.SpatialBatchNormalization(nFilters*2))
      encoder:add(nn.ReLU(true))
      nFilters = nFilters * 2
      height, width = convOut(height), convOut(width)
   end

   -- Reshape each sample's C x H x W feature map into one vector.
   encoder:add(nn.View(-1):setNumInputDims(3))
   -- Track the spatial size layer by layer instead of assuming (size/2^nConvLayers)^2:
   -- the closed form is fractional (and wrong) for non-power-of-two inputs such as 28x28.
   local inputFilterFC = height * width * nFilters
   if FCsz == nil then FCsz = inputFilterFC end
   encoder:add(nn.Linear(inputFilterFC, FCsz))
   encoder:add(nn.BatchNormalization(FCsz))
   encoder:add(nn.ReLU(true))
   encoder:add(nn.Linear(FCsz, outputSize))

   local criterion = nn.MSECriterion()
   return encoder, criterion
end
--- Fills the reusable mini-batch buffers with batchSize samples picked through `shuffle`.
-- Samples x[shuffle[batch]] .. x[shuffle[batch + batchSize - 1]] (and the matching y rows)
-- are copied into batchX / batchY. The copy is timed with the global timer `data_tm`.
-- @return batchX, batchY (the same buffers that were passed in)
local function assignBatches(batchX, batchY, x, y, batch, batchSize, shuffle)
   data_tm:reset(); data_tm:resume()
   local lastPos = batch + batchSize - 1
   local picked = shuffle[{{batch, lastPos}}]:long()
   batchX:copy(x:index(1, picked))
   batchY:copy(y:index(1, picked))
   data_tm:stop()
   return batchX, batchY
end
--- Prepares the state for the live error plot.
-- @param disp display setting; any truthy value (in Lua this includes 0) enables plotting
-- @tparam string title appended to the plot title
-- @return errorData empty table to collect plot rows, or nil when disabled
-- @return errorDispConfig plot options table, or nil when disabled
local function displayConfig(disp, title)
   if not disp then
      -- Plotting disabled: nothing to accumulate, nothing to configure.
      return nil, nil
   end
   local plotOptions = {
      title = 'Encoder error - ' .. title,
      win = 1,
      labels = { 'Epoch', 'Train error', 'Test error' },
      ylabel = "Error",
      legend = 'always',
   }
   return {}, plotOptions
end
--- Script entry point: trains an encoder Z or Y and checkpoints it after every epoch.
-- Reads options (getParameters), loads the dataset for the selected encoder type, splits it
-- into train/test, builds the encoder (getEncoder), optionally moves everything to the GPU
-- (converting to cudnn when it can be loaded), then runs Adam over shuffled mini-batches.
-- After each epoch the encoder is saved to opt.outputPath .. opt.name .. '_<epoch>epochs.t7'.
-- Side effect: defines the global timer `data_tm`, which assignBatches reads.
function main()
   local opt = getParameters() -- getParameters already prints the resolved options

   -- Set timers
   local epoch_tm = torch.Timer()
   local tm = torch.Timer()
   data_tm = torch.Timer() -- global on purpose: assignBatches uses it

   -- Read dataset
   local X, Y
   if string.upper(opt.type) == 'Z' then
      X, Y = readDatasetZ(opt.datasetPath)
   else
      X, Y = readDatasetY(opt.datasetPath, opt.loadSize)
   end

   -- Split train and test. The targets are Z vectors (encoder Z) or Y vectors (encoder Y).
   -- X: #samples x im3 x im2 x im1
   -- Z: #samples x 100 x 1 x 1
   -- Y: #samples x ny
   local xTrain, yTrain, xTest, yTest = splitTrainTest(X, Y, opt.split)

   -- Set network architecture
   local encoder, criterion = getEncoder(xTrain[1]:size(), opt.nf, yTrain:size(2), opt.nConvLayers, opt.FCsz)

   -- Mini-batch buffers, reused for every batch
   local batchX = torch.Tensor(opt.batchSize, xTrain:size(2), xTrain:size(3), xTrain:size(4))
   local batchY = torch.Tensor(opt.batchSize, yTrain:size(2))

   -- Copy variables to GPU
   if opt.gpu > 0 then
      require 'cunn'
      cutorch.setDevice(opt.gpu)
      batchX = batchX:cuda(); batchY = batchY:cuda()
      if pcall(require, 'cudnn') then
         require 'cudnn'
         cudnn.benchmark = true
         cudnn.convert(encoder, cudnn)
      end
      encoder:cuda()
      criterion:cuda()
   end
   local params, gradParams = encoder:getParameters() -- This has to be performed always after the cuda call

   -- Most recent train/test errors (written by the closure and the test evaluation below)
   local errorTrain, errorTest

   -- Loss/gradient closure for optim; declared here so it captures locals, not globals.
   local function optimFunction(_)
      -- Gradients accumulate across backward calls, so reset them first.
      gradParams:zero()
      local outputs = encoder:forward(batchX)
      errorTrain = criterion:forward(outputs, batchY)
      local dloss_doutput = criterion:backward(outputs, batchY)
      encoder:backward(batchX, dloss_doutput)
      return errorTrain, gradParams
   end

   local optimState = {
      learningRate = opt.lr,
      beta1 = opt.beta1,
   }

   local nTrainSamples = xTrain:size(1)
   local nTestSamples = xTest:size(1)
   -- The batch loop below only visits complete batches, so an epoch has
   -- floor(nTrainSamples / batchSize) iterations (not ceil).
   local nBatchesPerEpoch = math.floor(nTrainSamples / opt.batchSize)

   -- Initialize display configuration (if enabled)
   local errorData, errorDispConfig = displayConfig(opt.display, opt.name)
   paths.mkdir(opt.outputPath)

   -- Train network
   local batchIterations = 0 -- global iteration counter (display cadence and plot x-axis)
   for epoch = 1, opt.nEpochs do
      epoch_tm:reset()
      local shuffle = torch.randperm(nTrainSamples)
      for batch = 1, nTrainSamples - opt.batchSize + 1, opt.batchSize do
         tm:reset()
         batchX, batchY = assignBatches(batchX, batchY, xTrain, yTrain, batch, opt.batchSize, shuffle)
         if opt.display == 2 and batchIterations % 20 == 0 then
            display.image(image.toDisplayTensor(batchX,0,torch.round(math.sqrt(opt.batchSize))), {win=2, title='Train mini-batch'})
         end

         -- Update network
         optim.adam(optimFunction, params, optimState)

         -- Every 20 iterations: test error on a random test mini-batch, then plot
         if opt.display and batchIterations % 20 == 0 then
            batchX, batchY = assignBatches(batchX, batchY, xTest, yTest, torch.random(1, nTestSamples - opt.batchSize + 1), opt.batchSize, torch.randperm(nTestSamples))
            -- Inference mode, so BatchNormalization uses its running statistics instead of
            -- updating them with test data; training mode is restored right after.
            encoder:evaluate()
            local outputs = encoder:forward(batchX)
            errorTest = criterion:forward(outputs, batchY)
            encoder:training()
            table.insert(errorData,
            {
               batchIterations / nBatchesPerEpoch, -- x-axis (epochs)
               errorTrain, -- y-axis for label1
               errorTest   -- y-axis for label2
            })
            display.plot(errorData, errorDispConfig)
            if opt.display == 2 then
               display.image(image.toDisplayTensor(batchX,0,torch.round(math.sqrt(opt.batchSize))), {win=3, title='Test mini-batch'})
            end
         end

         -- Verbose (every iteration)
         print(('Epoch: [%d][%4d / %4d] Error (train): %.4f Error (test): %.4f '
                .. ' Time: %.3f s Data time: %.3f s'):format(
               epoch, math.floor((batch - 1) / opt.batchSize),
               nBatchesPerEpoch,
               errorTrain or -1,
               errorTest or -1,
               tm:time().real, data_tm:time().real))
         batchIterations = batchIterations + 1
      end
      print(('End of epoch %d / %d \t Time Taken: %.3f s'):format(
            epoch, opt.nEpochs, epoch_tm:time().real))

      -- Store network
      torch.save(opt.outputPath .. opt.name .. '_' .. epoch .. 'epochs.t7', encoder:clearState())
      -- The error history only exists when plotting is enabled; make sure its folder exists.
      if errorData then
         paths.mkdir('checkpoints')
         torch.save('checkpoints/' .. opt.name .. '_error.t7', errorData)
      end
   end
end

main()