This repository has been archived by the owner on Jul 4, 2022. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 1
/
DataLoader.lua
133 lines (110 loc) · 4.25 KB
/
DataLoader.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
require 'torch'
require 'image'
require 'torchx'
require 'xlua'
require 'dp'
require './Extensions.lua'
-- Make FloatTensor the default tensor type for everything torch creates after this point.
torch.setdefaulttensortype('torch.FloatTensor')
-- Module table that is populated and returned at the end of this file.
-- NOTE(review): assigned without `local`, so this also creates a global named
-- DataLoader; other files may depend on that global — confirm before localizing.
DataLoader = {}
--- Convert an image tensor to the requested number of channels.
--
-- A 2D tensor is treated as a single image plane; a 3D tensor is treated as
-- channels-first (its first dimension is the channel count).
--
-- depth == 1:
--   * 2D input                -> returned unchanged
--   * 3 or 4 channels         -> first plane of image.rgb2y over channels 1..3
--   * 2 channels              -> narrowed to the first channel
--   * 1 channel               -> returned unchanged
--   * anything else           -> dok.error
-- depth == 3:
--   * 2D input                -> plane replicated into a new 3-plane tensor
--   * 3 channels              -> returned unchanged
--   * 4 channels              -> narrowed to the first three channels
--   * 1 channel               -> channel replicated into a new 3-plane tensor
--   * anything else           -> dok.error
-- Any other depth (including nil) returns the image unchanged.
--
-- @param img   image tensor (2D or 3D)
-- @param depth desired channel count (1 or 3), or nil for "leave as is"
-- @return the converted image (may be the same tensor, a view, or a new tensor)
local function todepth(img, depth)
   if depth == 1 then
      if img:nDimension() == 2 then
         -- already a single plane, nothing to do
      elseif img:size(1) == 3 or img:size(1) == 4 then
         -- drop a possible 4th channel, then keep the luminance plane
         img = image.rgb2y(img:narrow(1,1,3))[1]
      elseif img:size(1) == 2 then
         -- keep only the first of the two channels
         img = img:narrow(1,1,1)
      elseif img:size(1) ~= 1 then
         dok.error('image loaded has wrong #channels', 'image.todepth')
      end
   elseif depth == 3 then
      -- Dimensionality must be checked BEFORE the channel count: for a 2D
      -- image size(1) is the number of rows, so a 3-row grayscale image would
      -- otherwise be mistaken for a 3-channel image and returned unconverted.
      if img:nDimension() == 2 then
         local imgrgb = img.new(3, img:size(1), img:size(2))
         imgrgb:select(1, 1):copy(img)
         imgrgb:select(1, 2):copy(img)
         imgrgb:select(1, 3):copy(img)
         img = imgrgb
      else
         local chan = img:size(1)
         if chan == 3 then
            -- already 3 channels, nothing to do
         elseif chan == 4 then
            -- drop the 4th channel
            img = img:narrow(1,1,3)
         elseif chan == 1 then
            -- replicate the single channel into three planes
            local imgrgb = img.new(3, img:size(2), img:size(3))
            imgrgb:select(1, 1):copy(img)
            imgrgb:select(1, 2):copy(img)
            imgrgb:select(1, 3):copy(img)
            img = imgrgb
         else
            dok.error('image loaded has wrong #channels', 'image.todepth')
         end
      end
   end
   return img
end
--- Load the 'normal' and 'leuko' image folders into a dp.DataSource.
--
-- Images are read from `<data_path>/normal` (class 1) and `<data_path>/leuko`
-- (class 2), rescaled to the requested size, converted to the requested
-- channel count, and written to random positions (torch.randperm) of one
-- input tensor. That tensor is then cut into consecutive train / test / valid
-- ranges, in that order.
--
-- @param data_path            directory containing the 'normal' and 'leuko' sub-directories
-- @param data_size            table {channels, height, width} of each stored image
-- @param equal_representation if truthy, use the same number of images from
--                             both classes (the smaller of the two counts)
-- @param test_percentage      fraction of all samples for the test set
--                             (multiplied by the sample count and floored)
-- @param valid_percentage     fraction of all samples for the validation set
--                             (multiplied by the sample count and floored)
-- @param verbose              if truthy, print a message and show a progress bar
-- @return a dp.DataSource with train, test and valid sets and classes {'normal', 'leuko'}
--
-- NOTE(review): a split whose count floors to 0 is passed to narrow() as-is;
-- confirm the Torch build in use accepts zero-length narrows.
local function loadData(data_path, data_size, equal_representation, test_percentage, valid_percentage, verbose)
   if verbose then print('Loading images') end
   local c, h, w = data_size[1], data_size[2], data_size[3]

   local normal = paths.indexdir(paths.concat(data_path, 'normal')) -- 1
   local leuko = paths.indexdir(paths.concat(data_path, 'leuko')) -- 2

   local num_normal, num_leuko = normal:size(), leuko:size()
   if equal_representation then
      num_normal = math.min(num_normal, num_leuko)
      num_leuko = num_normal
   end

   local size = num_normal + num_leuko
   local shuffle = torch.randperm(size)
   local input = torch.FloatTensor(size, c, h, w)
   local target = torch.IntTensor(size)

   -- Load the first `count` files of one class into their shuffled slots.
   -- `offset` is the number of samples already loaded (used for the shuffle
   -- position and for the progress bar).
   local function loadClass(files, count, offset, label)
      for i = 1, count do
         if verbose then xlua.progress(offset + i, size) end
         local img = image.load(files:filename(i))
         -- image.scale takes (src, width, height): width first, so the result
         -- is C x h x w and matches the layout of `input`.
         img = image.scale(img, w, h)
         -- honour the requested channel count instead of forcing 3
         img = todepth(img, c)
         local index = shuffle[offset + i]
         input[index]:copy(img)
         target[index] = label
         collectgarbage()
      end
   end

   -- 1. load images into input and target tensors
   loadClass(normal, num_normal, 0, 1)
   loadClass(leuko, num_leuko, num_normal, 2)

   -- 2. divide into train, test, and valid sets
   local num_valid = math.floor(size * valid_percentage)
   local num_test = math.floor(size * test_percentage)
   local num_train = size - num_valid - num_test

   -- 3. wrap a consecutive range of samples into dp.Views, then
   -- 4. wrap those views into a dp.DataSet
   local function makeSet(which_set, first, count)
      local inputs = dp.ImageView('bchw', input:narrow(1, first, count))
      local targets = dp.ClassView('b', target:narrow(1, first, count))
      targets:setClasses({'normal', 'leuko'})
      return dp.DataSet{inputs=inputs, targets=targets, which_set=which_set}
   end

   local train = makeSet('train', 1, num_train)
   local test = makeSet('test', num_train + 1, num_test)
   local valid = makeSet('valid', num_train + num_test + 1, num_valid)

   -- 5. wrap dp.DataSet into dp.DataSource
   local ds = dp.DataSource{train_set=train, test_set=test, valid_set=valid}
   ds:classes{'normal', 'leuko'}
   return ds
end
-- Public API: attach the functions callers may use to the module table.
DataLoader.loadData = loadData
-- NOTE(review): imagefromstring, outercrop and augment are not defined in the
-- visible part of this file; presumably they are globals provided by
-- './Extensions.lua' — confirm. If they are undefined at this point, the
-- corresponding fields are simply nil.
DataLoader.imagefromstring = imagefromstring
DataLoader.todepth = todepth
DataLoader.outercrop = outercrop
DataLoader.augment = augment
-- Hand the module table back to require().
return DataLoader