-
Notifications
You must be signed in to change notification settings - Fork 11
/
main.lua
90 lines (71 loc) · 3.2 KB
/
main.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
require 'paths'
require 'optim'
require 'nn'
------------[[Specify your options here]]--------------
-- Command-line interface for the training script.
-- Each spec is {flag, default, help}; specs are registered in the order listed,
-- so `--help` output is unchanged.
local OPTION_SPECS = {
   {'-learningRate', 0.01, 'initial learning rate for sgd'},
   {'-momentum', 0.9, 'momentum term of sgd'},
   {'-maxEpochs', 120, 'Max # Epochs'},
   {'-batchSize', 32, 'batch size'},
   {'-nbClasses', 38, '# of classes'},
   {'-nbChannels', 3, '# of channels'},
   {'-backend', 'cudnn', 'Options: cudnn | nn'},
   {'-model', 'alexnet', 'Options: alexnet | vgg | resnet'},
   {'-depth', 'A', 'For vgg depth: A | B | C | D, For resnet depth: 18 | 34 | 50 | 101 | ... Not applicable for alexnet'},
   {'-retrain', 'none', 'Path to model to finetune'},
   {'-save', '.', 'Path to save models'},
   {'-data', 'datasets/crowdai/', 'Path to folder with train and val directories'},
}

local parser = torch.CmdLine()
parser:text()
parser:text('Torch-7 PlantVillage Challenge Training script')
parser:text()
for _, spec in ipairs(OPTION_SPECS) do
   parser:option(spec[1], spec[2], spec[3])
end

-- Parsed values, accessed below as opt.<flag name without the leading '-'>.
local opt = parser:parse(arg or {})
-----------[[Model and criterion here]]---------------
--- Build the network.
-- With -retrain <file>: load a saved nn container and replace its final
-- nn.Linear classifier with a fresh nn.Linear(inputDim, opt.nbClasses)
-- followed by nn.LogSoftMax.
-- Otherwise: build a new model via require('models/' .. opt.model)(opt).
local net
if opt.retrain ~= 'none' then
   assert(paths.filep(opt.retrain), 'File not found: ' .. opt.retrain)
   -- On a GPU backend this script saves a cuda()/cudnn.convert()-ed network,
   -- so cunn/cudnn must be loaded before torch.load can deserialize those
   -- classes. require is cached, so the later require in the backend setup is harmless.
   if opt.backend ~= 'nn' then
      require 'cunn'; require 'cudnn'
   end
   print('Loading model from file: ' .. opt.retrain);
   net = torch.load(opt.retrain)
   -- The retrain path itself appends a LogSoftMax, so checkpoints produced by
   -- this script end in (nn|cudnn).LogSoftMax. Drop a trailing *SoftMax layer so
   -- such a model can be fine-tuned again; the assert below still requires the
   -- classifier underneath to be an nn.Linear.
   local last = net:get(#net.modules)
   if torch.type(last):match('SoftMax$') then
      net:remove(#net.modules)
   end
   -- Replace the final fully connected layer with one sized for opt.nbClasses,
   -- keeping the same input dimension (weight is outputDim x inputDim).
   local orig = net:get(#net.modules)
   assert(torch.type(orig) == 'nn.Linear',
      'expected last layer to be fully connected')
   net:remove(#net.modules) --remove original layer
   net:add(nn.Linear(orig.weight:size(2), opt.nbClasses))
   -- ClassNLLCriterion (used below) consumes log-probabilities.
   net:add(nn.LogSoftMax())
else
   -- models/<name>.lua is expected to return a constructor taking the option table.
   local createModel = require('models/'..opt.model)
   print('Creating new '..opt.model..' model')
   net = createModel(opt)
end
-- Negative log-likelihood loss over the network's log-probability outputs.
local criterion = nn.ClassNLLCriterion()

-- Any backend other than plain 'nn' means: move to the GPU and use cuDNN kernels.
local onGpu = (opt.backend ~= 'nn')
if onGpu then
   require 'cunn'
   require 'cudnn'
   -- Per the cudnn.torch README: pick the fastest conv algorithms and enable
   -- the built-in auto-tuner.
   cudnn.fastest = true
   cudnn.benchmark = true
   net = net:cuda()
   -- Swap supported nn modules for their cudnn equivalents (in place).
   cudnn.convert(net, cudnn)
   criterion = criterion:cuda()
end
-- Dataset helper; presumably defines the global DataGen used below.
require 'datasets/plantvillage.lua'
-- NOTE(review): dgen is assigned without `local`, i.e. as a global; kept that
-- way in case train.lua or the dataset module reads it globally. It is also
-- passed to Trainer explicitly.
dgen = DataGen(opt.data)
----[[Get your trainer object and start training]]-----
require 'train.lua'
local trainer = Trainer(net, criterion, dgen, opt)
local bestValLoss = math.huge

-- Create the output directory up front. A bad -save path then fails
-- immediately instead of after the first full epoch of training.
if not paths.dirp(opt.save) then
   assert(paths.mkdir(opt.save), 'Could not create save directory: ' .. opt.save)
end

for n_epoch = 1, opt.maxEpochs do
   trainer:train()                     -- one pass over the training set
   local valLoss = trainer:validate()  -- loss on the validation set

   -- Periodic checkpoint every 10 epochs. Also checkpoint the final epoch, so
   -- the last weights are kept when maxEpochs is not a multiple of 10.
   if n_epoch % 10 == 0 or n_epoch == opt.maxEpochs then
      local save_path = paths.concat(opt.save, opt.model..'_'..n_epoch..'.t7')
      -- clearState() drops intermediate buffers so the file holds only the weights/structure.
      torch.save(save_path, net:clearState())
      print("Checkpointing Model")
   end

   -- Best-model selection: keep a copy of the model with the lowest validation
   -- loss seen so far. (This is not early stopping; training always runs for
   -- opt.maxEpochs epochs.)
   if valLoss < bestValLoss then
      bestValLoss = valLoss
      print(('Current Best Validation Loss %.5f. Saving the model.'):format(bestValLoss))
      local save_path = paths.concat(opt.save, opt.model..'_best.t7')
      torch.save(save_path, net:clearState())
   end

   print("Epoch "..n_epoch.." complete")
end