Commit 8e44f04 — "better version"
mcabbott committed Feb 12, 2023 · 1 parent f419ab4
Showing 1 changed file with 92 additions and 78 deletions: vision/conv_mnist/conv_mnist.jl
using MLDatasets, Flux, CUDA, BSON  # this will install everything if necessary

#===== DATA =====#

# Calling MLDatasets.MNIST() will download the dataset if necessary,
# and return a struct containing it.
# It takes a few seconds to read from disk each time, so do this once:

train_data = MLDatasets.MNIST()  # i.e. split=:train
test_data = MLDatasets.MNIST(split=:test)

# train_data.features is a 28×28×60000 Array{Float32, 3} of the images.
# Flux needs a 4D array, with the 3rd dim for channels -- here trivial, grayscale.
# Combine the reshape needed with other pre-processing:

function loader(data::MNIST=train_data; batchsize::Int=64)
    x, y = data[:]  # this is a NamedTuple of (features, targets)
    x4dim = reshape(x, 28,28,1,:)    # insert channel dim
    yhot = Flux.onehotbatch(y, 0:9)  # make a OneHotMatrix, one column per image
    Flux.DataLoader((x4dim, yhot); batchsize, shuffle=true) |> gpu
end

loader()  # returns a DataLoader, with first element a tuple like this:

x1, y1 = first(loader()); # (28×28×1×64 Array{Float32, 4}, 10×64 OneHotMatrix(::Vector{UInt32}))
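# Aside (not in the original script): a tiny illustration of what Flux.onehotbatch produces.
# Each label becomes one column, with a single 1 in the row for that class:

Flux.onehotbatch([3, 7, 0], 0:9)                       # 10×3 OneHotMatrix
Flux.onecold(Flux.onehotbatch([3, 7, 0], 0:9), 0:9)    # recovers [3, 7, 0]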

# If you are using a GPU, these should be CuArray{Float32, 4} etc.

#===== MODEL =====#

# LeNet has two convolutional layers, and our modern version has relu nonlinearities.
# After each conv layer there's a pooling step. Finally, there are some fully connected layers:

lenet = Chain(
    Conv((5, 5), 1=>6, relu),
    MaxPool((2, 2)),
    Conv((5, 5), 6=>16, relu),
    MaxPool((2, 2)),
    Flux.flatten,
    Dense(256 => 120, relu),
    Dense(120 => 84, relu),
    Dense(84 => 10),
) |> gpu

y1hat = lenet(x1)  # try it out

softmax(y1hat)
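# Aside (not in the original script): this LeNet is small, and most of its parameters
# sit in the first Dense layer. Counting them all should give 44_426:

sum(length, Flux.params(lenet))  # 44_426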

hcat(Flux.onecold(y1hat, 0:9), Flux.onecold(y1, 0:9))  # each row is (predicted digit, true digit)

using Statistics: mean # standard library

function loss_and_accuracy(model, data::MNIST=test_data)
    (x,y) = only(loader(data; batchsize=0))  # batchsize=0 means one big batch
    ŷ = model(x)
    loss = Flux.logitcrossentropy(ŷ, y)  # did not include softmax in the model
    acc = round(100 * mean(Flux.onecold(ŷ) .== Flux.onecold(y)); digits=2)
    (; loss, acc, split=data.split)  # return a NamedTuple
end

loss_and_accuracy(lenet)  # accuracy about 10%, before training
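# Aside (not in the original script): since the model's last layer has no softmax,
# logitcrossentropy is the right loss; it applies the softmax itself, so these agree:

Flux.logitcrossentropy(y1hat, y1) ≈ Flux.crossentropy(softmax(y1hat), y1)  # true, up to rounding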

#===== TRAINING =====#

# Let's collect some hyper-parameters in a NamedTuple, just to write them in one place.
# Global variables are fine -- we won't access this from inside any fast loops.

settings = (;
    eta = 3e-4,     # learning rate
    lambda = 1e-2,  # for weight decay
    batchsize = 128,
    epochs = 10,
)

train_log = []

# Initialise the storage needed for the optimiser:

opt_rule = OptimiserChain(WeightDecay(TRAIN.lambda), Adam(TRAIN.eta))
opt_rule = OptimiserChain(WeightDecay(settings.lambda), Adam(settings.eta))
opt_state = Flux.setup(opt_rule, lenet);
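# Aside (a minimal sketch, not in the original script): the same setup / gradient / update!
# pattern on a toy layer, to show what one step of the loop below does:

toy = Dense(3 => 2)
toy_state = Flux.setup(OptimiserChain(WeightDecay(1e-2), Adam(1e-3)), toy)
toy_grad = Flux.gradient(m -> sum(abs2, m(ones(Float32, 3))), toy)[1]
Flux.update!(toy_state, toy, toy_grad)  # mutates toy's weights and toy_state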

for epoch in 1:settings.epochs
    @time for (x,y) in loader(batchsize=settings.batchsize)
        grads = Flux.gradient(m -> Flux.logitcrossentropy(m(x), y), lenet)
        Flux.update!(opt_state, lenet, grads[1])
    end

    # Logging & saving, not every epoch
    if epoch % 2 == 0
        loss, acc, _ = loss_and_accuracy(lenet, train_data)
        test_loss, test_acc, _ = loss_and_accuracy(lenet, test_data)
        @info "logging:" epoch acc test_acc
        nt = (; epoch, loss, acc, test_loss, test_acc)
        push!(train_log, nt)
    end
    if epoch % 5 == 0
        name = joinpath("runs", "lenet.bson")
        BSON.@save name lenet epoch
    end
end

train_log
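# Aside (not in the original script): the checkpoint written above can be read back,
# assuming the file exists -- note that @load re-binds the names it is given:
# BSON.@load joinpath("runs", "lenet.bson") lenet epoch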

# We can re-run the quick sanity-check of predictions:
y1hat = lenet(x1)
hcat(Flux.onecold(y1hat, 0:9), Flux.onecold(y1, 0:9))

#===== INSPECTION =====#

using ImageInTerminal, ImageCore

xtest, ytest = only(loader(test_data, batchsize=0))

# There are many ways to look at images; you won't need ImageInTerminal if working in a notebook.
# ImageCore.Gray is a special type, which interprets numbers between 0.0 and 1.0 as shades:

xtest[:,:,1,5] .|> Gray |> transpose # should display a 4

Flux.onecold(ytest, 0:9)[5]  # it's coded as being a 4

# Let's look for the image whose classification is least certain.
# First, in each column of probabilities, ask for the largest one.
_, i = findmin(vec(max_p))
xtest[:,:,1,i] .|> Gray |> transpose

Flux.onecold(ytest, 0:9)[i] # true classification
Flux.onecold(ptest[:,i], 0:9)  # uncertain prediction

# Next, let's look for the most confident, yet wrong, prediction.
# Often this will look quite ambiguous to you too.

iwrong = findall(Flux.onecold(lenet(xtest)) .!= Flux.onecold(ytest))

max_p = maximum(ptest[:,iwrong]; dims=1)
_, k = findmax(vec(max_p)) # now max not min
i = iwrong[k]

xtest[:,:,1,i] .|> Gray |> transpose

Flux.onecold(ytest, 0:9)[i] # true classification
Flux.onecold(ptest[:,i], 0:9) # prediction

#===== SIZES =====#

# Maybe... at first I had this above, but it makes things long.

# A layer like Conv((5, 5), 1=>6) takes 5x5 patches of an image, and matches them to each
# of 6 different 5x5 filters, placed at every possible position. These filters are here:

Conv((5, 5), 1=>6).weight |> summary  # 5×5×1×6 Array{Float32, 4}

# The layers can accept any size of image; lenet[1] is just the first Conv layer, and
# lenet[1:2] includes the pooling step after it. Let's trace the sizes with the actual input:

#=
julia> x1 |> size
(28, 28, 1, 64)
julia> lenet[1](x1) |> size
(24, 24, 6, 64)
julia> lenet[1:2](x1) |> size
(12, 12, 6, 64)
julia> lenet[1:3](x1) |> size
(8, 8, 16, 64)
julia> lenet[1:4](x1) |> size
(4, 4, 16, 64)
julia> lenet[1:4](x1) |> Flux.flatten |> size
(256, 64)
=#

# Flux.flatten is just reshape, preserving the batch dimension (64) while combining the others (4*4*16).
# This 256 must match the Dense(256 => 120). (See Flux.outputsize for ways to automate this.)
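# Aside (not in the original script): Flux.outputsize does the same size calculation without
# running real data, e.g. for the convolutional part of the network on one 28×28 grayscale image:

Flux.outputsize(Chain(Conv((5,5), 1=>6, relu), MaxPool((2,2)),
                      Conv((5,5), 6=>16, relu), MaxPool((2,2)), Flux.flatten), (28, 28, 1, 1))  # (256, 1)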

#===== THE END =====#
