From 887cba1dea1730281f256e507c4b765239a8406f Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Mon, 22 May 2023 11:45:43 -0400 Subject: [PATCH] Simplify LeNet (#394) * simpler lenet * better version * update & simplify * update & simplify * project+manifest * use JLD2 and Flux.state * make directory, tidy * manifest * add show statements --- README.md | 2 +- vision/conv_mnist/Manifest.toml | 1279 ++++++++++++++++++++++--------- vision/conv_mnist/Project.toml | 13 +- vision/conv_mnist/conv_mnist.jl | 366 +++++---- 4 files changed, 1109 insertions(+), 551 deletions(-) diff --git a/README.md b/README.md index 3239652a..a543a021 100644 --- a/README.md +++ b/README.md @@ -66,7 +66,7 @@ Flux v0.13 is the latest right now, marked with ☀️; models upgraded to use **Vision** * MNIST * [Simple multi-layer perceptron](vision/mlp_mnist) ☀️ v0.13 + - * [Simple ConvNet (LeNet)](vision/conv_mnist) ☀️ v0.13 + * [Simple ConvNet (LeNet)](vision/conv_mnist) ☀️ v0.13 + * [Variational Auto-Encoder](vision/vae_mnist) ☀️ v0.13 + * [Deep Convolutional Generative Adversarial Networks](vision/dcgan_mnist) ☀️ v0.13 + * [Conditional Deep Convolutional Generative Adversarial Networks](vision/cdcgan_mnist) ☀️ v0.13 diff --git a/vision/conv_mnist/Manifest.toml b/vision/conv_mnist/Manifest.toml index 06f118a4..daad7b48 100644 --- a/vision/conv_mnist/Manifest.toml +++ b/vision/conv_mnist/Manifest.toml @@ -1,622 +1,1199 @@ # This file is machine-generated - editing it directly is not advised -[[AbstractFFTs]] +julia_version = "1.9.0-rc3" +manifest_format = "2.0" +project_hash = "d6992a1d25242425878729ae27146921261f25a9" + +[[deps.AbstractFFTs]] deps = ["LinearAlgebra"] -git-tree-sha1 = "485ee0867925449198280d4af84bdb46a2a404d0" +git-tree-sha1 = "16b6dbc4cf7caee4e1e75c49485ec67b667098a0" uuid = "621f4979-c628-5d54-868e-fcf4e3e8185c" -version = "1.0.1" +version = "1.3.1" +weakdeps = ["ChainRulesCore"] + + [deps.AbstractFFTs.extensions] + AbstractFFTsChainRulesCoreExt = "ChainRulesCore" + +[[deps.Accessors]] +deps = ["Compat", "CompositionsBase", "ConstructionBase", "Dates", "InverseFunctions", "LinearAlgebra", "MacroTools", "Requires", "Test"] +git-tree-sha1 = "a4f8669e46c8cdf68661fe6bb0f7b89f51dd23cf" +uuid = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697" +version = "0.1.30" + + [deps.Accessors.extensions] + AccessorsAxisKeysExt = "AxisKeys" + AccessorsIntervalSetsExt = "IntervalSets" + AccessorsStaticArraysExt = "StaticArrays" + AccessorsStructArraysExt = "StructArrays" + + [deps.Accessors.weakdeps] + AxisKeys = "94b1ba4f-4ee9-5380-92f1-94cde586c3c5" + IntervalSets = "8197267c-284f-5f27-9208-e0e47529a953" + StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" + StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a" + +[[deps.Adapt]] +deps = ["LinearAlgebra", "Requires"] +git-tree-sha1 = "76289dc51920fdc6e0013c872ba9551d54961c24" +uuid = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" +version = "3.6.2" +weakdeps = ["StaticArrays"] -[[AbstractTrees]] -git-tree-sha1 = "03e0550477d86222521d254b741d470ba17ea0b5" -uuid = "1520ce14-60c1-5f80-bbc7-55ef81b5835c" -version = "0.3.4" + [deps.Adapt.extensions] + AdaptStaticArraysExt = "StaticArrays" -[[Adapt]] -deps = ["LinearAlgebra"] -git-tree-sha1 = "ffcfa2d345aaee0ef3d8346a073d5dd03c983ebe" -uuid = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" -version = "3.2.0" +[[deps.ArgCheck]] +git-tree-sha1 = "a3a402a35a2f7e0b87828ccabbd5ebfbebe356b4" +uuid = "dce04be8-c92d-5529-be00-80e4d2c0e197" +version = "2.3.0" -[[ArgTools]] +[[deps.ArgTools]] uuid = "0dad84c5-d112-42e6-8d28-ef12dabb789f" +version = "1.1.1" -[[Artifacts]] +[[deps.Artifacts]] uuid = "56f22d72-fd6d-98f1-02f0-08ddc0907c33" -[[BFloat16s]] -deps = ["LinearAlgebra", "Test"] -git-tree-sha1 = "4af69e205efc343068dc8722b8dfec1ade89254a" -uuid = "ab4f0b2a-ad5b-11e8-123f-65d77653426b" +[[deps.Atomix]] +deps = ["UnsafeAtomics"] +git-tree-sha1 = "c06a868224ecba914baa6942988e2f2aade419be" +uuid = "a9b6321e-bd34-4604-b9c9-b65b8de01458" version = "0.1.0" -[[BSON]] -git-tree-sha1 = "db18b5ea04686f73d269e10bdb241947c40d7d6f" -uuid = "fbb218c0-5317-5bc6-957e-2ee96dd4b1f0" -version = "0.3.2" - -[[Base64]] -uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f" +[[deps.BFloat16s]] +deps = ["LinearAlgebra", "Printf", "Random", "Test"] +git-tree-sha1 = "dbf84058d0a8cbbadee18d25cf606934b22d7c66" +uuid = "ab4f0b2a-ad5b-11e8-123f-65d77653426b" +version = "0.4.2" -[[BinDeps]] -deps = ["Libdl", "Pkg", "SHA", "URIParser", "Unicode"] -git-tree-sha1 = "1289b57e8cf019aede076edab0587eb9644175bd" -uuid = "9e28174c-4ba2-5203-b857-d8d62c4213ee" -version = "1.0.2" +[[deps.BangBang]] +deps = ["Compat", "ConstructionBase", "Future", "InitialValues", "LinearAlgebra", "Requires", "Setfield", "Tables", "ZygoteRules"] +git-tree-sha1 = "7fe6d92c4f281cf4ca6f2fba0ce7b299742da7ca" +uuid = "198e06fe-97b7-11e9-32a5-e1d131e6ad66" +version = "0.3.37" -[[BinaryProvider]] -deps = ["Libdl", "Logging", "SHA"] -git-tree-sha1 = "ecdec412a9abc8db54c0efc5548c64dfce072058" -uuid = "b99e7846-7c00-51b0-8f62-c81ae34c0232" -version = "0.5.10" +[[deps.Base64]] +uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f" -[[Blosc]] -deps = ["Blosc_jll"] -git-tree-sha1 = "84cf7d0f8fd46ca6f1b3e0305b4b4a37afe50fd6" -uuid = "a74b3585-a348-5f62-a45c-50e91977d574" -version = "0.7.0" +[[deps.Baselet]] +git-tree-sha1 = "aebf55e6d7795e02ca500a689d326ac979aaf89e" +uuid = "9718e550-a3fa-408a-8086-8db961cd8217" +version = "0.1.1" -[[Blosc_jll]] -deps = ["Libdl", "Lz4_jll", "Pkg", "Zlib_jll", "Zstd_jll"] -git-tree-sha1 = "aa9ef39b54a168c3df1b2911e7797e4feee50fbe" -uuid = "0b7ba130-8d10-5ba8-a3d6-c5182647fed9" -version = "1.14.3+1" +[[deps.BitFlags]] +git-tree-sha1 = "43b1a4a8f797c1cddadf60499a8a077d4af2cd2d" +uuid = "d1d4a3ce-64b1-5f1a-9ba4-7e7e69966f35" +version = "0.1.7" -[[BufferedStreams]] -deps = ["Compat", "Test"] -git-tree-sha1 = "5d55b9486590fdda5905c275bb21ce1f0754020f" +[[deps.BufferedStreams]] +git-tree-sha1 = "bb065b14d7f941b8617bc323063dbe79f55d16ea" uuid = "e1450e63-4bb3-523b-b2a4-4ffa8c0fd77d" -version = "1.0.0" +version = "1.1.0" -[[CEnum]] -git-tree-sha1 = "215a9aa4a1f23fbd05b92769fdd62559488d70e9" +[[deps.CEnum]] +git-tree-sha1 = "eb4cb44a499229b3b8426dcfb5dd85333951ff90" uuid = "fa961155-64e5-5f13-b03f-caf6b980ea82" -version = "0.4.1" +version = "0.4.2" -[[CRC32c]] -uuid = "8bf52ea8-c179-5cab-976a-9e18b702a9bc" +[[deps.CSV]] +deps = ["CodecZlib", "Dates", "FilePathsBase", "InlineStrings", "Mmap", "Parsers", "PooledArrays", "PrecompileTools", "SentinelArrays", "Tables", "Unicode", "WeakRefStrings", "WorkerUtilities"] +git-tree-sha1 = "ed28c86cbde3dc3f53cf76643c2e9bc11d56acc7" +uuid = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b" +version = "0.10.10" -[[CUDA]] -deps = ["AbstractFFTs", "Adapt", "BFloat16s", "CEnum", "CompilerSupportLibraries_jll", "DataStructures", "ExprTools", "GPUArrays", "GPUCompiler", "LLVM", "LazyArtifacts", "Libdl", "LinearAlgebra", "Logging", "MacroTools", "Memoize", "NNlib", "Printf", "Random", "Reexport", "Requires", "SparseArrays", "Statistics", "TimerOutputs"] -git-tree-sha1 = "2d90e6c29706856928f02e11ae15e71889905e34" +[[deps.CUDA]] +deps = ["AbstractFFTs", "Adapt", "BFloat16s", "CEnum", "CUDA_Driver_jll", "CUDA_Runtime_Discovery", "CUDA_Runtime_jll", "CompilerSupportLibraries_jll", "ExprTools", "GPUArrays", "GPUCompiler", "KernelAbstractions", "LLVM", "LazyArtifacts", "Libdl", "LinearAlgebra", "Logging", "Preferences", "Printf", "Random", "Random123", "RandomNumbers", "Reexport", "Requires", "SparseArrays", "SpecialFunctions", "UnsafeAtomicsLLVM"] +git-tree-sha1 = "280893f920654ebfaaaa1999fbd975689051f890" uuid = "052768ef-5323-5732-b1bb-66c8b64840ba" -version = "2.6.1" +version = "4.2.0" -[[ChainRules]] -deps = ["ChainRulesCore", "Compat", "LinearAlgebra", "Random", "Reexport", "Requires", "Statistics"] -git-tree-sha1 = "e01f521443e3700f40ad3c7c1c6aa3a6940aaea1" +[[deps.CUDA_Driver_jll]] +deps = ["Artifacts", "JLLWrappers", "LazyArtifacts", "Libdl", "Pkg"] +git-tree-sha1 = "498f45593f6ddc0adff64a9310bb6710e851781b" +uuid = "4ee394cb-3365-5eb0-8335-949819d2adfc" +version = "0.5.0+1" + +[[deps.CUDA_Runtime_Discovery]] +deps = ["Libdl"] +git-tree-sha1 = "bcc4a23cbbd99c8535a5318455dcf0f2546ec536" +uuid = "1af6417a-86b4-443c-805f-a4643ffb695f" +version = "0.2.2" + +[[deps.CUDA_Runtime_jll]] +deps = ["Artifacts", "CUDA_Driver_jll", "JLLWrappers", "LazyArtifacts", "Libdl", "TOML"] +git-tree-sha1 = "5248d9c45712e51e27ba9b30eebec65658c6ce29" +uuid = "76a88914-d11a-5bdc-97e0-2f5a05c973a2" +version = "0.6.0+0" + +[[deps.CUDNN_jll]] +deps = ["Artifacts", "CUDA_Runtime_jll", "JLLWrappers", "LazyArtifacts", "Libdl", "TOML"] +git-tree-sha1 = "2918fbffb50e3b7a0b9127617587afa76d4276e8" +uuid = "62b44479-cb7b-5706-934f-f13b2eb2e645" +version = "8.8.1+0" + +[[deps.ChainRules]] +deps = ["Adapt", "ChainRulesCore", "Compat", "Distributed", "GPUArraysCore", "IrrationalConstants", "LinearAlgebra", "Random", "RealDot", "SparseArrays", "Statistics", "StructArrays"] +git-tree-sha1 = "8bae903893aeeb429cf732cf1888490b93ecf265" uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2" -version = "0.7.54" +version = "1.49.0" -[[ChainRulesCore]] +[[deps.ChainRulesCore]] deps = ["Compat", "LinearAlgebra", "SparseArrays"] -git-tree-sha1 = "de4f08843c332d355852721adb1592bce7924da3" +git-tree-sha1 = "e30f2f4e20f7f186dc36529910beaedc60cfa644" uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" -version = "0.9.29" +version = "1.16.0" + +[[deps.Chemfiles]] +deps = ["Chemfiles_jll", "DocStringExtensions"] +git-tree-sha1 = "9126d0271c337ca5ed02ba92f2dec087c4260d4a" +uuid = "46823bd8-5fb3-5f92-9aa0-96921f3dd015" +version = "0.10.31" -[[CodecZlib]] +[[deps.Chemfiles_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] +git-tree-sha1 = "d4e54b053fc584e7a0f37e9d3a5c4500927b343a" +uuid = "78a364fa-1a3c-552a-b4bb-8fa0f9c1fcca" +version = "0.10.3+0" + +[[deps.CodecZlib]] deps = ["TranscodingStreams", "Zlib_jll"] -git-tree-sha1 = "ded953804d019afa9a3f98981d99b33e3db7b6da" +git-tree-sha1 = "9c209fb7536406834aa938fb149964b985de6c83" uuid = "944b1d66-785c-5afd-91f1-9de20f533193" -version = "0.7.0" +version = "0.7.1" -[[ColorTypes]] +[[deps.ColorSchemes]] +deps = ["ColorTypes", "ColorVectorSpace", "Colors", "FixedPointNumbers", "PrecompileTools", "Random"] +git-tree-sha1 = "be6ab11021cd29f0344d5c4357b163af05a48cba" +uuid = "35d6a980-a343-548e-a6ea-1d62b119f2f4" +version = "3.21.0" + +[[deps.ColorTypes]] deps = ["FixedPointNumbers", "Random"] -git-tree-sha1 = "5e9769a17f17b587c951d57ba4319782b40c3513" +git-tree-sha1 = "eb7f0f8307f71fac7c606984ea5fb2817275d6e4" uuid = "3da002f7-5984-5a60-b8a6-cbb66c0b333f" -version = "0.10.10" +version = "0.11.4" -[[Colors]] -deps = ["ColorTypes", "FixedPointNumbers", "InteractiveUtils", "Reexport"] -git-tree-sha1 = "ac5f2213e56ed8a34a3dd2f681f4df1166b34929" +[[deps.ColorVectorSpace]] +deps = ["ColorTypes", "FixedPointNumbers", "LinearAlgebra", "SpecialFunctions", "Statistics", "TensorCore"] +git-tree-sha1 = "600cc5508d66b78aae350f7accdb58763ac18589" +uuid = "c3611d14-8923-5661-9e6a-0046d554d3a4" +version = "0.9.10" + +[[deps.Colors]] +deps = ["ColorTypes", "FixedPointNumbers", "Reexport"] +git-tree-sha1 = "fc08e5930ee9a4e03f84bfb5211cb54e7769758a" uuid = "5ae59095-9a9b-59fe-a467-6f913c188581" -version = "0.12.6" +version = "0.12.10" -[[CommonSubexpressions]] +[[deps.CommonSubexpressions]] deps = ["MacroTools", "Test"] git-tree-sha1 = "7b8a93dba8af7e3b42fecabf646260105ac373f7" uuid = "bbf7d656-a473-5ed7-a52c-81e309532950" version = "0.3.0" -[[Compat]] -deps = ["Base64", "Dates", "DelimitedFiles", "Distributed", "InteractiveUtils", "LibGit2", "Libdl", "LinearAlgebra", "Markdown", "Mmap", "Pkg", "Printf", "REPL", "Random", "SHA", "Serialization", "SharedArrays", "Sockets", "SparseArrays", "Statistics", "Test", "UUIDs", "Unicode"] -git-tree-sha1 = "919c7f3151e79ff196add81d7f4e45d91bbf420b" +[[deps.Compat]] +deps = ["UUIDs"] +git-tree-sha1 = "7a60c856b9fa189eb34f5f8a6f6b5529b7942957" uuid = "34da2185-b29b-5c13-b0c7-acf172513d20" -version = "3.25.0" +version = "4.6.1" +weakdeps = ["Dates", "LinearAlgebra"] -[[CompilerSupportLibraries_jll]] + [deps.Compat.extensions] + CompatLinearAlgebraExt = "LinearAlgebra" + +[[deps.CompilerSupportLibraries_jll]] deps = ["Artifacts", "Libdl"] uuid = "e66e0078-7015-5450-92f7-15fbd957f2ae" +version = "1.0.2+0" + +[[deps.CompositionsBase]] +git-tree-sha1 = "802bb88cd69dfd1509f6670416bd4434015693ad" +uuid = "a33af91c-f02d-484b-be07-31d278c5ca2b" +version = "0.1.2" +weakdeps = ["InverseFunctions"] + + [deps.CompositionsBase.extensions] + CompositionsBaseInverseFunctionsExt = "InverseFunctions" -[[DataAPI]] -git-tree-sha1 = "dfb3b7e89e395be1e25c2ad6d7690dc29cc53b1d" +[[deps.ConcurrentUtilities]] +deps = ["Serialization", "Sockets"] +git-tree-sha1 = "96d823b94ba8d187a6d8f0826e731195a74b90e9" +uuid = "f0e56b4a-5159-44fe-b623-3e5288b988bb" +version = "2.2.0" + +[[deps.ConstructionBase]] +deps = ["LinearAlgebra"] +git-tree-sha1 = "738fec4d684a9a6ee9598a8bfee305b26831f28c" +uuid = "187b0558-2788-49d3-abe0-74a17ed4e7c9" +version = "1.5.2" + + [deps.ConstructionBase.extensions] + ConstructionBaseIntervalSetsExt = "IntervalSets" + ConstructionBaseStaticArraysExt = "StaticArrays" + + [deps.ConstructionBase.weakdeps] + IntervalSets = "8197267c-284f-5f27-9208-e0e47529a953" + StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" + +[[deps.ContextVariablesX]] +deps = ["Compat", "Logging", "UUIDs"] +git-tree-sha1 = "25cc3803f1030ab855e383129dcd3dc294e322cc" +uuid = "6add18c4-b38d-439d-96f6-d6bc489c04c5" +version = "0.1.3" + +[[deps.Crayons]] +git-tree-sha1 = "249fe38abf76d48563e2f4556bebd215aa317e15" +uuid = "a8cc5b0e-0ffa-5ad4-8c14-923d3ee1735f" +version = "4.1.1" + +[[deps.DataAPI]] +git-tree-sha1 = "8da84edb865b0b5b0100c0666a9bc9a0b71c553c" uuid = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a" -version = "1.6.0" +version = "1.15.0" -[[DataDeps]] -deps = ["BinaryProvider", "HTTP", "Libdl", "Reexport", "SHA", "p7zip_jll"] -git-tree-sha1 = "4f0e41ff461d42cfc62ff0de4f1cd44c6e6b3771" +[[deps.DataDeps]] +deps = ["HTTP", "Libdl", "Reexport", "SHA", "p7zip_jll"] +git-tree-sha1 = "bc0a264d3e7b3eeb0b6fc9f6481f970697f29805" uuid = "124859b0-ceae-595e-8997-d05f6a7a8dfe" -version = "0.7.7" +version = "0.7.10" -[[DataStructures]] +[[deps.DataFrames]] +deps = ["Compat", "DataAPI", "Future", "InlineStrings", "InvertedIndices", "IteratorInterfaceExtensions", "LinearAlgebra", "Markdown", "Missings", "PooledArrays", "PrettyTables", "Printf", "REPL", "Random", "Reexport", "SentinelArrays", "SnoopPrecompile", "SortingAlgorithms", "Statistics", "TableTraits", "Tables", "Unicode"] +git-tree-sha1 = "aa51303df86f8626a962fccb878430cdb0a97eee" +uuid = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" +version = "1.5.0" + +[[deps.DataStructures]] deps = ["Compat", "InteractiveUtils", "OrderedCollections"] -git-tree-sha1 = "4437b64df1e0adccc3e5d1adbc3ac741095e4677" +git-tree-sha1 = "d1fff3a548102f48987a52a2e0d114fa97d730f0" uuid = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" -version = "0.18.9" +version = "0.18.13" + +[[deps.DataValueInterfaces]] +git-tree-sha1 = "bfc1187b79289637fa0ef6d4436ebdfe6905cbd6" +uuid = "e2d170a0-9d28-54be-80f0-106bbe20a464" +version = "1.0.0" -[[Dates]] +[[deps.Dates]] deps = ["Printf"] uuid = "ade2ca70-3891-5945-98fb-dc099432e06a" -[[DelimitedFiles]] +[[deps.DefineSingletons]] +git-tree-sha1 = "0fba8b706d0178b4dc7fd44a96a92382c9065c2c" +uuid = "244e2a9f-e319-4986-a169-4d1fe445cd52" +version = "0.1.2" + +[[deps.DelimitedFiles]] deps = ["Mmap"] +git-tree-sha1 = "9e2f36d3c96a820c678f2f1f1782582fcf685bae" uuid = "8bb1440f-4735-579b-a4ab-409b98df4dab" +version = "1.9.1" -[[DiffResults]] -deps = ["StaticArrays"] -git-tree-sha1 = "c18e98cba888c6c25d1c3b048e4b3380ca956805" +[[deps.DiffResults]] +deps = ["StaticArraysCore"] +git-tree-sha1 = "782dd5f4561f5d267313f23853baaaa4c52ea621" uuid = "163ba53b-c6d8-5494-b064-1a9d43ac40c5" -version = "1.0.3" +version = "1.1.0" -[[DiffRules]] -deps = ["NaNMath", "Random", "SpecialFunctions"] -git-tree-sha1 = "214c3fcac57755cfda163d91c58893a8723f93e9" +[[deps.DiffRules]] +deps = ["IrrationalConstants", "LogExpFunctions", "NaNMath", "Random", "SpecialFunctions"] +git-tree-sha1 = "a4ad7ef19d2cdc2eff57abbbe68032b1cd0bd8f8" uuid = "b552c78f-8df3-52c6-915a-8e097449b14b" -version = "1.0.2" +version = "1.13.0" -[[Distributed]] +[[deps.Distributed]] deps = ["Random", "Serialization", "Sockets"] uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b" -[[Downloads]] -deps = ["ArgTools", "LibCURL", "NetworkOptions"] +[[deps.DocStringExtensions]] +deps = ["LibGit2"] +git-tree-sha1 = "2fb1e02f2b635d0845df5d7c167fec4dd739b00d" +uuid = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" +version = "0.9.3" + +[[deps.Downloads]] +deps = ["ArgTools", "FileWatching", "LibCURL", "NetworkOptions"] uuid = "f43a241f-c20a-4ad4-852c-f6b1247861c6" +version = "1.6.0" -[[ExprTools]] -git-tree-sha1 = "10407a39b87f29d47ebaca8edbc75d7c302ff93e" +[[deps.ExprTools]] +git-tree-sha1 = "c1d06d129da9f55715c6c212866f5b1bddc5fa00" uuid = "e2ba6199-217a-4e67-a87a-7c52f15ade04" -version = "0.1.3" +version = "0.1.9" + +[[deps.FLoops]] +deps = ["BangBang", "Compat", "FLoopsBase", "InitialValues", "JuliaVariables", "MLStyle", "Serialization", "Setfield", "Transducers"] +git-tree-sha1 = "ffb97765602e3cbe59a0589d237bf07f245a8576" +uuid = "cc61a311-1640-44b5-9fba-1b764f453329" +version = "0.2.1" + +[[deps.FLoopsBase]] +deps = ["ContextVariablesX"] +git-tree-sha1 = "656f7a6859be8673bf1f35da5670246b923964f7" +uuid = "b9860ae5-e623-471e-878b-f6a53c775ea6" +version = "0.1.1" -[[FileIO]] +[[deps.FileIO]] deps = ["Pkg", "Requires", "UUIDs"] -git-tree-sha1 = "8800ec70aee7292931a3d3c10a3be3445b9c6141" +git-tree-sha1 = "299dc33549f68299137e51e6d49a13b5b1da9673" uuid = "5789e2e9-d7fb-5bc7-8068-2c6fae9b9549" -version = "1.6.2" +version = "1.16.1" + +[[deps.FilePathsBase]] +deps = ["Compat", "Dates", "Mmap", "Printf", "Test", "UUIDs"] +git-tree-sha1 = "e27c4ebe80e8699540f2d6c805cc12203b614f12" +uuid = "48062228-2e41-5def-b9a4-89aafe57970f" +version = "0.9.20" + +[[deps.FileWatching]] +uuid = "7b1f6079-737a-58dc-b8bc-7a2ca5c1b5ee" -[[FillArrays]] -deps = ["LinearAlgebra", "Random", "SparseArrays"] -git-tree-sha1 = "4705cc4e212c3c978c60b1b18118ec49b4d731fd" +[[deps.FillArrays]] +deps = ["LinearAlgebra", "Random", "SparseArrays", "Statistics"] +git-tree-sha1 = "fc86b4fd3eff76c3ce4f5e96e2fdfa6282722885" uuid = "1a297f60-69ca-5386-bcde-b61e274b549b" -version = "0.11.5" +version = "1.0.0" -[[FixedPointNumbers]] +[[deps.FixedPointNumbers]] deps = ["Statistics"] git-tree-sha1 = "335bfdceacc84c5cdf16aadc768aa5ddfc5383cc" uuid = "53c48c17-4a7d-5ca2-90c5-79b7896eea93" version = "0.8.4" -[[Flux]] -deps = ["AbstractTrees", "Adapt", "CUDA", "CodecZlib", "Colors", "DelimitedFiles", "Functors", "Juno", "LinearAlgebra", "MacroTools", "NNlib", "Pkg", "Printf", "Random", "Reexport", "SHA", "Statistics", "StatsBase", "Test", "ZipFile", "Zygote"] -git-tree-sha1 = "6be915c29e53d41311272540b094df89c53bf350" -repo-rev = "master" -repo-url = "https://github.com/FluxML/Flux.jl.git" +[[deps.Flux]] +deps = ["Adapt", "CUDA", "ChainRulesCore", "Functors", "LinearAlgebra", "MLUtils", "MacroTools", "NNlib", "NNlibCUDA", "OneHotArrays", "Optimisers", "Preferences", "ProgressLogging", "Random", "Reexport", "SparseArrays", "SpecialFunctions", "Statistics", "Zygote", "cuDNN"] +git-tree-sha1 = "64005071944bae14fc145661f617eb68b339189c" uuid = "587475ba-b771-5e3f-ad9e-33799f191a9c" -version = "0.12.0-dev" +version = "0.13.16" + + [deps.Flux.extensions] + AMDGPUExt = "AMDGPU" + + [deps.Flux.weakdeps] + AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e" -[[ForwardDiff]] -deps = ["CommonSubexpressions", "DiffResults", "DiffRules", "NaNMath", "Random", "SpecialFunctions", "StaticArrays"] -git-tree-sha1 = "d48a40c0f54f29a5c8748cfb3225719accc72b77" +[[deps.FoldsThreads]] +deps = ["Accessors", "FunctionWrappers", "InitialValues", "SplittablesBase", "Transducers"] +git-tree-sha1 = "eb8e1989b9028f7e0985b4268dabe94682249025" +uuid = "9c68100b-dfe1-47cf-94c8-95104e173443" +version = "0.1.1" + +[[deps.Formatting]] +deps = ["Printf"] +git-tree-sha1 = "8339d61043228fdd3eb658d86c926cb282ae72a8" +uuid = "59287772-0a20-5a39-b81b-1366585eb4c0" +version = "0.4.2" + +[[deps.ForwardDiff]] +deps = ["CommonSubexpressions", "DiffResults", "DiffRules", "LinearAlgebra", "LogExpFunctions", "NaNMath", "Preferences", "Printf", "Random", "SpecialFunctions"] +git-tree-sha1 = "00e252f4d706b3d55a8863432e742bf5717b498d" uuid = "f6369f11-7733-5829-9624-2563aa707210" -version = "0.10.16" +version = "0.10.35" +weakdeps = ["StaticArrays"] + + [deps.ForwardDiff.extensions] + ForwardDiffStaticArraysExt = "StaticArrays" + +[[deps.FunctionWrappers]] +git-tree-sha1 = "d62485945ce5ae9c0c48f124a84998d755bae00e" +uuid = "069b7b12-0de2-55c6-9aab-29f3d0a68a2e" +version = "1.1.3" -[[Functors]] -deps = ["MacroTools"] -git-tree-sha1 = "a7bb2af991c43dcf5c3455d276dd83976799634f" +[[deps.Functors]] +deps = ["LinearAlgebra"] +git-tree-sha1 = "478f8c3145bb91d82c2cf20433e8c1b30df454cc" uuid = "d9f16b24-f501-4c13-a1f2-28368ffc5196" -version = "0.2.1" +version = "0.4.4" + +[[deps.Future]] +deps = ["Random"] +uuid = "9fa8497b-333b-5362-9e8d-4d0656e87820" -[[GPUArrays]] -deps = ["AbstractFFTs", "Adapt", "LinearAlgebra", "Printf", "Random", "Serialization"] -git-tree-sha1 = "f99a25fe0313121f2f9627002734c7d63b4dd3bd" +[[deps.GPUArrays]] +deps = ["Adapt", "GPUArraysCore", "LLVM", "LinearAlgebra", "Printf", "Random", "Reexport", "Serialization", "Statistics"] +git-tree-sha1 = "9ade6983c3dbbd492cf5729f865fe030d1541463" uuid = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7" -version = "6.2.0" +version = "8.6.6" -[[GPUCompiler]] -deps = ["DataStructures", "ExprTools", "InteractiveUtils", "LLVM", "Libdl", "Logging", "Scratch", "Serialization", "TimerOutputs", "UUIDs"] -git-tree-sha1 = "ef2839b063e158672583b9c09d2cf4876a8d3d55" +[[deps.GPUArraysCore]] +deps = ["Adapt"] +git-tree-sha1 = "1cd7f0af1aa58abc02ea1d872953a97359cb87fa" +uuid = "46192b85-c4d5-4398-a991-12ede77f4527" +version = "0.1.4" + +[[deps.GPUCompiler]] +deps = ["ExprTools", "InteractiveUtils", "LLVM", "Libdl", "Logging", "Scratch", "TimerOutputs", "UUIDs"] +git-tree-sha1 = "e9a9173cd77e16509cdf9c1663fda19b22a518b7" uuid = "61eb1bfa-7361-4325-ad38-22787b887f55" -version = "0.10.0" +version = "0.19.3" -[[GZip]] +[[deps.GZip]] deps = ["Libdl"] git-tree-sha1 = "039be665faf0b8ae36e089cd694233f5dee3f7d6" uuid = "92fee26a-97fe-5a0c-ad85-20a5f3185b63" version = "0.5.1" -[[Graphics]] +[[deps.Glob]] +git-tree-sha1 = "97285bbd5230dd766e9ef6749b80fc617126d496" +uuid = "c27321d9-0574-5035-807b-f59d2c89b15c" +version = "1.3.1" + +[[deps.Graphics]] deps = ["Colors", "LinearAlgebra", "NaNMath"] -git-tree-sha1 = "2c1cf4df419938ece72de17f368a021ee162762e" +git-tree-sha1 = "d61890399bc535850c4bf08e4e0d3a7ad0f21cbd" uuid = "a2bd30eb-e257-5431-a919-1863eab51364" -version = "1.1.0" +version = "1.1.2" -[[HDF5]] -deps = ["Blosc", "Compat", "HDF5_jll", "Libdl", "Mmap", "Random", "Requires"] -git-tree-sha1 = "8a21f34a34491833bcda29a3ec2188b4ec6e558f" +[[deps.HDF5]] +deps = ["Compat", "HDF5_jll", "Libdl", "Mmap", "Random", "Requires", "UUIDs"] +git-tree-sha1 = "3dab31542b3da9f25a6a1d11159d4af8fdce7d67" uuid = "f67ccb44-e63f-5c2f-98bd-6dc0ccc4ba2f" -version = "0.15.4" +version = "0.16.14" -[[HDF5_jll]] +[[deps.HDF5_jll]] deps = ["Artifacts", "JLLWrappers", "LibCURL_jll", "Libdl", "OpenSSL_jll", "Pkg", "Zlib_jll"] -git-tree-sha1 = "fd83fa0bde42e01952757f01149dd968c06c4dba" +git-tree-sha1 = "4cc2bb72df6ff40b055295fdef6d92955f9dede8" uuid = "0234f1f7-429e-5d53-9886-15a909be8d59" -version = "1.12.0+1" +version = "1.12.2+2" -[[HTTP]] -deps = ["Base64", "Dates", "IniFile", "MbedTLS", "NetworkOptions", "Sockets", "URIs"] -git-tree-sha1 = "c9f380c76d8aaa1fa7ea9cf97bddbc0d5b15adc2" +[[deps.HTTP]] +deps = ["Base64", "CodecZlib", "ConcurrentUtilities", "Dates", "Logging", "LoggingExtras", "MbedTLS", "NetworkOptions", "OpenSSL", "Random", "SimpleBufferStream", "Sockets", "URIs", "UUIDs"] +git-tree-sha1 = "877b7bc42729aa2c90bbbf5cb0d4294bd6d42e5a" uuid = "cd3eb016-35fb-5094-929b-558a96fad6f3" -version = "0.9.5" +version = "1.9.1" -[[IRTools]] +[[deps.IRTools]] deps = ["InteractiveUtils", "MacroTools", "Test"] -git-tree-sha1 = "c67e7515a11f726f44083e74f218d134396d6510" +git-tree-sha1 = "eac00994ce3229a464c2847e956d77a2c64ad3a5" uuid = "7869d1d1-7146-5819-86e3-90919afe41df" -version = "0.4.2" +version = "0.4.10" + +[[deps.ImageBase]] +deps = ["ImageCore", "Reexport"] +git-tree-sha1 = "b51bb8cae22c66d0f6357e3bcb6363145ef20835" +uuid = "c817782e-172a-44cc-b673-b171935fbb9e" +version = "0.1.5" -[[ImageCore]] -deps = ["AbstractFFTs", "Colors", "FixedPointNumbers", "Graphics", "MappedArrays", "MosaicViews", "OffsetArrays", "PaddedViews", "Reexport"] -git-tree-sha1 = "79badd979fbee9b8980cd995cd5a86a9e93b8ad7" +[[deps.ImageCore]] +deps = ["AbstractFFTs", "ColorVectorSpace", "Colors", "FixedPointNumbers", "Graphics", "MappedArrays", "MosaicViews", "OffsetArrays", "PaddedViews", "Reexport"] +git-tree-sha1 = "acf614720ef026d38400b3817614c45882d75500" uuid = "a09fc81d-aa75-5fe9-8630-4744c3626534" -version = "0.8.20" +version = "0.9.4" -[[IniFile]] -deps = ["Test"] -git-tree-sha1 = "098e4d2c533924c921f9f9847274f2ad89e018b8" -uuid = "83e8ac13-25f8-5344-8a64-a9f2b223428f" -version = "0.5.0" +[[deps.ImageShow]] +deps = ["Base64", "ColorSchemes", "FileIO", "ImageBase", "ImageCore", "OffsetArrays", "StackViews"] +git-tree-sha1 = "ce28c68c900eed3cdbfa418be66ed053e54d4f56" +uuid = "4e3cecfd-b093-5904-9786-8bbb286a6a31" +version = "0.3.7" -[[InteractiveUtils]] +[[deps.InitialValues]] +git-tree-sha1 = "4da0f88e9a39111c2fa3add390ab15f3a44f3ca3" +uuid = "22cec73e-a1b8-11e9-2c92-598750a2cf9c" +version = "0.3.1" + +[[deps.InlineStrings]] +deps = ["Parsers"] +git-tree-sha1 = "9cc2baf75c6d09f9da536ddf58eb2f29dedaf461" +uuid = "842dd82b-1e85-43dc-bf29-5d0ee9dffc48" +version = "1.4.0" + +[[deps.InteractiveUtils]] deps = ["Markdown"] uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240" -[[JLLWrappers]] -git-tree-sha1 = "a431f5f2ca3f4feef3bd7a5e94b8b8d4f2f647a0" +[[deps.InternedStrings]] +deps = ["Random", "Test"] +git-tree-sha1 = "eb05b5625bc5d821b8075a77e4c421933e20c76b" +uuid = "7d512f48-7fb1-5a58-b986-67e6dc259f01" +version = "0.7.0" + +[[deps.InverseFunctions]] +deps = ["Test"] +git-tree-sha1 = "6667aadd1cdee2c6cd068128b3d226ebc4fb0c67" +uuid = "3587e190-3f89-42d0-90ee-14403ec27112" +version = "0.1.9" + +[[deps.InvertedIndices]] +git-tree-sha1 = "0dc7b50b8d436461be01300fd8cd45aa0274b038" +uuid = "41ab1584-1d38-5bbf-9106-f11c6c58b48f" +version = "1.3.0" + +[[deps.IrrationalConstants]] +git-tree-sha1 = "630b497eafcc20001bba38a4651b327dcfc491d2" +uuid = "92d709cd-6900-40b7-9082-c6be49f344b6" +version = "0.2.2" + +[[deps.IteratorInterfaceExtensions]] +git-tree-sha1 = "a3f24677c21f5bbe9d2a714f95dcd58337fb2856" +uuid = "82899510-4779-5014-852e-03e436cf321d" +version = "1.0.0" + +[[deps.JLD2]] +deps = ["FileIO", "MacroTools", "Mmap", "OrderedCollections", "Pkg", "Printf", "Reexport", "Requires", "TranscodingStreams", "UUIDs"] +git-tree-sha1 = "42c17b18ced77ff0be65957a591d34f4ed57c631" +uuid = "033835bb-8acc-5ee8-8aae-3f567f8a3819" +version = "0.4.31" + +[[deps.JLLWrappers]] +deps = ["Preferences"] +git-tree-sha1 = "abc9885a7ca2052a736a600f7fa66209f96506e1" uuid = "692b3bcd-3c85-4b1f-b108-f13ce0eb3210" -version = "1.2.0" +version = "1.4.1" + +[[deps.JSON3]] +deps = ["Dates", "Mmap", "Parsers", "SnoopPrecompile", "StructTypes", "UUIDs"] +git-tree-sha1 = "84b10656a41ef564c39d2d477d7236966d2b5683" +uuid = "0f8b85d8-7281-11e9-16c2-39a750bddbf1" +version = "1.12.0" + +[[deps.JuliaVariables]] +deps = ["MLStyle", "NameResolution"] +git-tree-sha1 = "49fb3cb53362ddadb4415e9b73926d6b40709e70" +uuid = "b14d175d-62b4-44ba-8fb7-3064adc8c3ec" +version = "0.2.4" -[[Juno]] -deps = ["Base64", "Logging", "Media", "Profile"] -git-tree-sha1 = "07cb43290a840908a771552911a6274bc6c072c7" -uuid = "e5e0dc1b-0480-54bc-9374-aad01c23163d" -version = "0.8.4" +[[deps.KernelAbstractions]] +deps = ["Adapt", "Atomix", "InteractiveUtils", "LinearAlgebra", "MacroTools", "PrecompileTools", "SparseArrays", "StaticArrays", "UUIDs", "UnsafeAtomics", "UnsafeAtomicsLLVM"] +git-tree-sha1 = "47be64f040a7ece575c2b5f53ca6da7b548d69f4" +uuid = "63c18a36-062a-441e-b654-da1e3ab1ce7c" +version = "0.9.4" -[[LLVM]] -deps = ["CEnum", "Libdl", "Printf", "Unicode"] -git-tree-sha1 = "b616937c31337576360cb9fb872ec7633af7b194" +[[deps.LLVM]] +deps = ["CEnum", "LLVMExtra_jll", "Libdl", "Printf", "Unicode"] +git-tree-sha1 = "26a31cdd9f1f4ea74f649a7bf249703c687a953d" uuid = "929cbde3-209d-540e-8aea-75f648917ca0" -version = "3.6.0" +version = "5.1.0" + +[[deps.LLVMExtra_jll]] +deps = ["Artifacts", "JLLWrappers", "LazyArtifacts", "Libdl", "TOML"] +git-tree-sha1 = "09b7505cc0b1cee87e5d4a26eea61d2e1b0dcd35" +uuid = "dad2f222-ce93-54a1-a47d-0025e8a3acab" +version = "0.0.21+0" + +[[deps.LaTeXStrings]] +git-tree-sha1 = "f2355693d6778a178ade15952b7ac47a4ff97996" +uuid = "b964fa9f-0449-5b57-a5c2-d3ea65f4040f" +version = "1.3.0" -[[LazyArtifacts]] +[[deps.LazyArtifacts]] deps = ["Artifacts", "Pkg"] uuid = "4af54fe1-eca0-43a8-85a7-787d91b784e3" -[[LibCURL]] +[[deps.LazyModules]] +git-tree-sha1 = "a560dd966b386ac9ae60bdd3a3d3a326062d3c3e" +uuid = "8cdb02fc-e678-4876-92c5-9defec4f444e" +version = "0.3.1" + +[[deps.LibCURL]] deps = ["LibCURL_jll", "MozillaCACerts_jll"] uuid = "b27032c2-a3e7-50c8-80cd-2d36dbcbfd21" +version = "0.6.3" -[[LibCURL_jll]] +[[deps.LibCURL_jll]] deps = ["Artifacts", "LibSSH2_jll", "Libdl", "MbedTLS_jll", "Zlib_jll", "nghttp2_jll"] uuid = "deac9b47-8bc7-5906-a0fe-35ac56dc84c0" +version = "7.84.0+0" -[[LibGit2]] +[[deps.LibGit2]] deps = ["Base64", "NetworkOptions", "Printf", "SHA"] uuid = "76f85450-5226-5b5a-8eaa-529ad045b433" -[[LibSSH2_jll]] +[[deps.LibSSH2_jll]] deps = ["Artifacts", "Libdl", "MbedTLS_jll"] uuid = "29816b5a-b9ab-546f-933c-edad1886dfa8" +version = "1.10.2+0" -[[Libdl]] +[[deps.Libdl]] uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb" -[[LinearAlgebra]] -deps = ["Libdl"] +[[deps.Libiconv_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] +git-tree-sha1 = "c7cb1f5d892775ba13767a87c7ada0b980ea0a71" +uuid = "94ce4f54-9a6c-5748-9c1c-f9c7231a4531" +version = "1.16.1+2" + +[[deps.LinearAlgebra]] +deps = ["Libdl", "OpenBLAS_jll", "libblastrampoline_jll"] uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" -[[Logging]] +[[deps.LogExpFunctions]] +deps = ["DocStringExtensions", "IrrationalConstants", "LinearAlgebra"] +git-tree-sha1 = "0a1b7c2863e44523180fdb3146534e265a91870b" +uuid = "2ab3a3ac-af41-5b50-aa03-7779005ae688" +version = "0.3.23" + + [deps.LogExpFunctions.extensions] + LogExpFunctionsChainRulesCoreExt = "ChainRulesCore" + LogExpFunctionsChangesOfVariablesExt = "ChangesOfVariables" + LogExpFunctionsInverseFunctionsExt = "InverseFunctions" + + [deps.LogExpFunctions.weakdeps] + ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" + ChangesOfVariables = "9e997f8a-9a97-42d5-a9f1-ce6bfc15e2c0" + InverseFunctions = "3587e190-3f89-42d0-90ee-14403ec27112" + +[[deps.Logging]] uuid = "56ddb016-857b-54e1-b83d-db4d58db5568" -[[Lz4_jll]] -deps = ["Libdl", "Pkg"] -git-tree-sha1 = "51b1db0732bbdcfabb60e36095cc3ed9c0016932" -uuid = "5ced341a-0733-55b8-9ab6-a4889d929147" -version = "1.9.2+2" +[[deps.LoggingExtras]] +deps = ["Dates", "Logging"] +git-tree-sha1 = "cedb76b37bc5a6c702ade66be44f831fa23c681e" +uuid = "e6f89c97-d47a-5376-807f-9c37f3926c36" +version = "1.0.0" -[[MAT]] +[[deps.MAT]] deps = ["BufferedStreams", "CodecZlib", "HDF5", "SparseArrays"] -git-tree-sha1 = "5c62992f3d46b8dce69bdd234279bb5a369db7d5" +git-tree-sha1 = "6eff5740c8ab02c90065719579c7aa0eb40c9f69" uuid = "23992714-dd62-5051-b70f-ba57cb901cac" -version = "0.10.1" +version = "0.10.4" -[[MLDatasets]] -deps = ["BinDeps", "ColorTypes", "DataDeps", "DelimitedFiles", "FixedPointNumbers", "GZip", "MAT", "Requires"] -git-tree-sha1 = "7b1a2d0ccd45e1474d2f9f6e4582e69f910e4175" +[[deps.MLDatasets]] +deps = ["CSV", "Chemfiles", "DataDeps", "DataFrames", "DelimitedFiles", "FileIO", "FixedPointNumbers", "GZip", "Glob", "HDF5", "ImageShow", "JLD2", "JSON3", "LazyModules", "MAT", "MLUtils", "NPZ", "Pickle", "Printf", "Requires", "SparseArrays", "Tables"] +git-tree-sha1 = "498b37aa3ebb4407adea36df1b244fa4e397de5e" uuid = "eb30cadb-4394-5ae3-aed4-317e484a6458" -version = "0.5.5" +version = "0.7.9" + +[[deps.MLStyle]] +git-tree-sha1 = "bc38dff0548128765760c79eb7388a4b37fae2c8" +uuid = "d8e11817-5142-5d16-987a-aa16d5891078" +version = "0.4.17" + +[[deps.MLUtils]] +deps = ["ChainRulesCore", "Compat", "DataAPI", "DelimitedFiles", "FLoops", "FoldsThreads", "NNlib", "Random", "ShowCases", "SimpleTraits", "Statistics", "StatsBase", "Tables", "Transducers"] +git-tree-sha1 = "ca31739905ddb08c59758726e22b9e25d0d1521b" +uuid = "f1d291b0-491e-4a28-83b9-f70985020b54" +version = "0.4.2" -[[MacroTools]] +[[deps.MacroTools]] deps = ["Markdown", "Random"] -git-tree-sha1 = "6a8a2a625ab0dea913aba95c11370589e0239ff0" +git-tree-sha1 = "42324d08725e200c23d4dfb549e0d5d89dede2d2" uuid = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" -version = "0.5.6" +version = "0.5.10" -[[MappedArrays]] -deps = ["FixedPointNumbers"] -git-tree-sha1 = "b92bd220c95a8bbe89af28f11201fd080e0e3fe7" +[[deps.MappedArrays]] +git-tree-sha1 = "e8b359ef06ec72e8c030463fe02efe5527ee5142" uuid = "dbb5928d-eab1-5f90-85c2-b9b0edb7c900" -version = "0.3.0" +version = "0.4.1" -[[Markdown]] +[[deps.Markdown]] deps = ["Base64"] uuid = "d6f4376e-aef5-505a-96c1-9c027394607a" -[[MbedTLS]] -deps = ["Dates", "MbedTLS_jll", "Random", "Sockets"] -git-tree-sha1 = "1c38e51c3d08ef2278062ebceade0e46cefc96fe" +[[deps.MbedTLS]] +deps = ["Dates", "MbedTLS_jll", "MozillaCACerts_jll", "Random", "Sockets"] +git-tree-sha1 = "03a9b9718f5682ecb107ac9f7308991db4ce395b" uuid = "739be429-bea8-5141-9913-cc70e7f3736d" -version = "1.0.3" +version = "1.1.7" -[[MbedTLS_jll]] +[[deps.MbedTLS_jll]] deps = ["Artifacts", "Libdl"] uuid = "c8ffd9c3-330d-5841-b78e-0817d7145fa1" +version = "2.28.2+0" -[[Media]] -deps = ["MacroTools", "Test"] -git-tree-sha1 = "75a54abd10709c01f1b86b84ec225d26e840ed58" -uuid = "e89f7d12-3494-54d1-8411-f7d8b9ae1f27" -version = "0.5.0" - -[[Memoize]] -deps = ["MacroTools"] -git-tree-sha1 = "2b1dfcba103de714d31c033b5dacc2e4a12c7caa" -uuid = "c03570c3-d221-55d1-a50c-7939bbd78826" -version = "0.4.4" +[[deps.MicroCollections]] +deps = ["BangBang", "InitialValues", "Setfield"] +git-tree-sha1 = "629afd7d10dbc6935ec59b32daeb33bc4460a42e" +uuid = "128add7d-3638-4c79-886c-908ea0c25c34" +version = "0.1.4" -[[Missings]] +[[deps.Missings]] deps = ["DataAPI"] -git-tree-sha1 = "f8c673ccc215eb50fcadb285f522420e29e69e1c" +git-tree-sha1 = "f66bdc5de519e8f8ae43bdc598782d35a25b1272" uuid = "e1d29d7a-bbdc-5cf2-9ac0-f12de2c33e28" -version = "0.4.5" +version = "1.1.0" -[[Mmap]] +[[deps.Mmap]] uuid = "a63ad114-7e13-5084-954f-fe012c677804" -[[MosaicViews]] -deps = ["MappedArrays", "OffsetArrays", "PaddedViews"] -git-tree-sha1 = "614e8d77264d20c1db83661daadfab38e8e4b77e" +[[deps.MosaicViews]] +deps = ["MappedArrays", "OffsetArrays", "PaddedViews", "StackViews"] +git-tree-sha1 = "7b86a5d4d70a9f5cdf2dacb3cbe6d251d1a61dbe" uuid = "e94cdb99-869f-56ef-bcf0-1ae2bcbe0389" -version = "0.2.4" +version = "0.3.4" -[[MozillaCACerts_jll]] +[[deps.MozillaCACerts_jll]] uuid = "14a3606d-f60d-562e-9121-12d972cd8159" +version = "2022.10.11" -[[NNlib]] -deps = ["ChainRulesCore", "Compat", "LinearAlgebra", "Pkg", "Requires", "Statistics"] -git-tree-sha1 = "5ce2e4b2bfe3811811e7db4b6a148439806fd2f8" +[[deps.NNlib]] +deps = ["Adapt", "Atomix", "ChainRulesCore", "GPUArraysCore", "KernelAbstractions", "LinearAlgebra", "Pkg", "Random", "Requires", "Statistics"] +git-tree-sha1 = "99e6dbb50d8a96702dc60954569e9fe7291cc55d" uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" -version = "0.7.16" +version = "0.8.20" + + [deps.NNlib.extensions] + NNlibAMDGPUExt = "AMDGPU" + + [deps.NNlib.weakdeps] + AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e" + +[[deps.NNlibCUDA]] +deps = ["Adapt", "CUDA", "LinearAlgebra", "NNlib", "Random", "Statistics", "cuDNN"] +git-tree-sha1 = "f94a9684394ff0d325cc12b06da7032d8be01aaf" +uuid = "a00861dc-f156-4864-bf3c-e6376f28a68d" +version = "0.2.7" -[[NaNMath]] -git-tree-sha1 = "bfe47e760d60b82b66b61d2d44128b62e3a369fb" +[[deps.NPZ]] +deps = ["FileIO", "ZipFile"] +git-tree-sha1 = "60a8e272fe0c5079363b28b0953831e2dd7b7e6f" +uuid = "15e1cf62-19b3-5cfa-8e77-841668bca605" +version = "0.4.3" + +[[deps.NaNMath]] +deps = ["OpenLibm_jll"] +git-tree-sha1 = "0877504529a3e5c3343c6f8b4c0381e57e4387e4" uuid = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3" -version = "0.3.5" +version = "1.0.2" + +[[deps.NameResolution]] +deps = ["PrettyPrint"] +git-tree-sha1 = "1a0fa0e9613f46c9b8c11eee38ebb4f590013c5e" +uuid = "71a1bf82-56d0-4bbc-8a3c-48b961074391" +version = "0.1.5" -[[NetworkOptions]] +[[deps.NetworkOptions]] uuid = "ca575930-c2e3-43a9-ace4-1e988b2c1908" +version = "1.2.0" -[[OffsetArrays]] +[[deps.OffsetArrays]] deps = ["Adapt"] -git-tree-sha1 = "b3dfef5f2be7d7eb0e782ba9146a5271ee426e90" +git-tree-sha1 = "82d7c9e310fe55aa54996e6f7f94674e2a38fcb4" uuid = "6fe1bfb0-de20-5000-8ca7-80f57d26f881" -version = "1.6.2" +version = "1.12.9" + +[[deps.OneHotArrays]] +deps = ["Adapt", "ChainRulesCore", "Compat", "GPUArraysCore", "LinearAlgebra", "NNlib"] +git-tree-sha1 = "f511fca956ed9e70b80cd3417bb8c2dde4b68644" +uuid = "0b1bfda6-eb8a-41d2-88d8-f5af5cad476f" +version = "0.2.3" + +[[deps.OpenBLAS_jll]] +deps = ["Artifacts", "CompilerSupportLibraries_jll", "Libdl"] +uuid = "4536629a-c528-5b80-bd46-f80d51c5b363" +version = "0.3.21+4" -[[OpenSSL_jll]] +[[deps.OpenLibm_jll]] +deps = ["Artifacts", "Libdl"] +uuid = "05823500-19ac-5b8b-9628-191a04bc5112" +version = "0.8.1+0" + +[[deps.OpenSSL]] +deps = ["BitFlags", "Dates", "MozillaCACerts_jll", "OpenSSL_jll", "Sockets"] +git-tree-sha1 = "51901a49222b09e3743c65b8847687ae5fc78eb2" +uuid = "4d8831e6-92b7-49fb-bdf8-b643e874388c" +version = "1.4.1" + +[[deps.OpenSSL_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] -git-tree-sha1 = "71bbbc616a1d710879f5a1021bcba65ffba6ce58" +git-tree-sha1 = "9ff31d101d987eb9d66bd8b176ac7c277beccd09" uuid = "458c3c95-2e84-50aa-8efc-19380b2a3a95" -version = "1.1.1+6" +version = "1.1.20+0" -[[OpenSpecFun_jll]] +[[deps.OpenSpecFun_jll]] deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "Libdl", "Pkg"] -git-tree-sha1 = "9db77584158d0ab52307f8c04f8e7c08ca76b5b3" +git-tree-sha1 = "13652491f6856acfd2db29360e1bbcd4565d04f1" uuid = "efe28fd5-8261-553b-a9e1-b2916fc3738e" -version = "0.5.3+4" +version = "0.5.5+0" + +[[deps.Optimisers]] +deps = ["ChainRulesCore", "Functors", "LinearAlgebra", "Random", "Statistics"] +git-tree-sha1 = "6a01f65dd8583dee82eecc2a19b0ff21521aa749" +uuid = "3bd65402-5787-11e9-1adc-39752487f4e2" +version = "0.2.18" -[[OrderedCollections]] -git-tree-sha1 = "4fa2ba51070ec13fcc7517db714445b4ab986bdf" +[[deps.OrderedCollections]] +git-tree-sha1 = "d321bf2de576bf25ec4d3e4360faca399afca282" uuid = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" -version = "1.4.0" +version = "1.6.0" -[[PaddedViews]] +[[deps.PaddedViews]] deps = ["OffsetArrays"] -git-tree-sha1 = "0fa5e78929aebc3f6b56e1a88cf505bb00a354c4" +git-tree-sha1 = "0fac6313486baae819364c52b4f483450a9d793f" uuid = "5432bcbf-9aad-5242-b902-cca2824c8663" -version = "0.5.8" +version = "0.5.12" + +[[deps.Parsers]] +deps = ["Dates", "PrecompileTools", "UUIDs"] +git-tree-sha1 = "7302075e5e06da7d000d9bfa055013e3e85578ca" +uuid = "69de0a69-1ddd-5017-9359-2bf0b02dc9f0" +version = "2.5.9" + +[[deps.Pickle]] +deps = ["DataStructures", "InternedStrings", "Serialization", "SparseArrays", "Strided", "StringEncodings", "ZipFile"] +git-tree-sha1 = "e6a34eb1dc0c498f0774bbfbbbeff2de101f4235" +uuid = "fbb45041-c46e-462f-888f-7c521cafbc2c" +version = "0.3.2" -[[Pkg]] -deps = ["Artifacts", "Dates", "Downloads", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "Serialization", "TOML", "Tar", "UUIDs"] +[[deps.Pkg]] +deps = ["Artifacts", "Dates", "Downloads", "FileWatching", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "Serialization", "TOML", "Tar", "UUIDs", "p7zip_jll"] uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" +version = "1.9.0" + +[[deps.PooledArrays]] +deps = ["DataAPI", "Future"] +git-tree-sha1 = "a6062fe4063cdafe78f4a0a81cfffb89721b30e7" +uuid = "2dfb63ee-cc39-5dd5-95bd-886bf059d720" +version = "1.4.2" + +[[deps.PrecompileTools]] +deps = ["Preferences"] +git-tree-sha1 = "259e206946c293698122f63e2b513a7c99a244e8" +uuid = "aea7be01-6a6a-4083-8856-8a6e6704d82a" +version = "1.1.1" + +[[deps.Preferences]] +deps = ["TOML"] +git-tree-sha1 = "7eb1686b4f04b82f96ed7a4ea5890a4f0c7a09f1" +uuid = "21216c6a-2e73-6563-6e65-726566657250" +version = "1.4.0" + +[[deps.PrettyPrint]] +git-tree-sha1 = "632eb4abab3449ab30c5e1afaa874f0b98b586e4" +uuid = "8162dcfd-2161-5ef2-ae6c-7681170c5f98" +version = "0.2.0" + +[[deps.PrettyTables]] +deps = ["Crayons", "Formatting", "LaTeXStrings", "Markdown", "Reexport", "StringManipulation", "Tables"] +git-tree-sha1 = "213579618ec1f42dea7dd637a42785a608b1ea9c" +uuid = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d" +version = "2.2.4" -[[Printf]] +[[deps.Printf]] deps = ["Unicode"] uuid = "de0858da-6303-5e67-8744-51eddeeeb8d7" -[[Profile]] -deps = ["Printf"] -uuid = "9abbd945-dff8-562f-b5e8-e1ebf5ef1b79" +[[deps.ProgressLogging]] +deps = ["Logging", "SHA", "UUIDs"] +git-tree-sha1 = "80d919dee55b9c50e8d9e2da5eeafff3fe58b539" +uuid = "33c8b6b6-d38a-422a-b730-caa89a2f386c" +version = "0.1.4" -[[ProgressMeter]] +[[deps.ProgressMeter]] deps = ["Distributed", "Printf"] -git-tree-sha1 = "6e9c89cba09f6ef134b00e10625590746ba1e036" +git-tree-sha1 = "d7a7aef8f8f2d537104f170139553b14dfe39fe9" uuid = "92933f4c-e287-5a05-a399-4b506db050ca" -version = "1.5.0" - -[[ProtoBuf]] -deps = ["Compat", "Logging"] -git-tree-sha1 = "9ecf92287404ebe5666a1c0488c3aaf90bbb5ff4" -uuid = "3349acd9-ac6a-5e09-bcdb-63829b23a429" -version = "0.10.0" +version = "1.7.2" -[[REPL]] +[[deps.REPL]] deps = ["InteractiveUtils", "Markdown", "Sockets", "Unicode"] uuid = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb" -[[Random]] -deps = ["Serialization"] +[[deps.Random]] +deps = ["SHA", "Serialization"] uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" -[[Reexport]] -git-tree-sha1 = "57d8440b0c7d98fc4f889e478e80f268d534c9d5" +[[deps.Random123]] +deps = ["Random", "RandomNumbers"] +git-tree-sha1 = "552f30e847641591ba3f39fd1bed559b9deb0ef3" +uuid = "74087812-796a-5b5d-8853-05524746bad3" +version = "1.6.1" + +[[deps.RandomNumbers]] +deps = ["Random", "Requires"] +git-tree-sha1 = "043da614cc7e95c703498a491e2c21f58a2b8111" +uuid = "e6cf234a-135c-5ec9-84dd-332b85af5143" +version = "1.5.3" + +[[deps.RealDot]] +deps = ["LinearAlgebra"] +git-tree-sha1 = "9f0a1b71baaf7650f4fa8a1d168c7fb6ee41f0c9" +uuid = "c1ae055f-0cd5-4b69-90a6-9a35b1a98df9" +version = "0.1.0" + +[[deps.Reexport]] +git-tree-sha1 = "45e428421666073eab6f2da5c9d310d99bb12f9b" uuid = "189a3867-3050-52da-a836-e630ba90ab69" -version = "1.0.0" +version = "1.2.2" -[[Requires]] +[[deps.Requires]] deps = ["UUIDs"] -git-tree-sha1 = "4036a3bd08ac7e968e27c203d45f5fff15020621" +git-tree-sha1 = "838a3a4188e2ded87a4f9f184b4b0d78a1e91cb7" uuid = "ae029012-a4dd-5104-9daa-d747884805df" -version = "1.1.3" +version = "1.3.0" -[[SHA]] +[[deps.SHA]] uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce" +version = "0.7.0" -[[Scratch]] +[[deps.Scratch]] deps = ["Dates"] -git-tree-sha1 = "ad4b278adb62d185bbcb6864dc24959ab0627bf6" +git-tree-sha1 = "30449ee12237627992a99d5e30ae63e4d78cd24a" uuid = "6c6a2e73-6563-6170-7368-637461726353" -version = "1.0.3" +version = "1.2.0" + +[[deps.SentinelArrays]] +deps = ["Dates", "Random"] +git-tree-sha1 = "77d3c4726515dca71f6d80fbb5e251088defe305" +uuid = "91c51154-3ec4-41a3-a24f-3f23e20d615c" +version = "1.3.18" -[[Serialization]] +[[deps.Serialization]] uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b" -[[SharedArrays]] -deps = ["Distributed", "Mmap", "Random", "Serialization"] -uuid = "1a1011a3-84de-559e-8e89-a11a2f7dc383" +[[deps.Setfield]] +deps = ["ConstructionBase", "Future", "MacroTools", "StaticArraysCore"] +git-tree-sha1 = "e2cc6d8c88613c05e1defb55170bf5ff211fbeac" +uuid = "efcf1570-3423-57d1-acb7-fd33fddbac46" +version = "1.1.1" + +[[deps.ShowCases]] +git-tree-sha1 = "7f534ad62ab2bd48591bdeac81994ea8c445e4a5" +uuid = "605ecd9f-84a6-4c9e-81e2-4798472b76a3" +version = "0.1.0" + +[[deps.SimpleBufferStream]] +git-tree-sha1 = "874e8867b33a00e784c8a7e4b60afe9e037b74e1" +uuid = "777ac1f9-54b0-4bf8-805c-2214025038e7" +version = "1.1.0" + +[[deps.SimpleTraits]] +deps = ["InteractiveUtils", "MacroTools"] +git-tree-sha1 = "5d7e3f4e11935503d3ecaf7186eac40602e7d231" +uuid = "699a6c99-e7fa-54fc-8d76-47d257e15c1d" +version = "0.9.4" -[[Sockets]] +[[deps.SnoopPrecompile]] +deps = ["Preferences"] +git-tree-sha1 = "e760a70afdcd461cf01a575947738d359234665c" +uuid = "66db9d55-30c0-4569-8b51-7e840670fc0c" +version = "1.0.3" + +[[deps.Sockets]] uuid = "6462fe0b-24de-5631-8697-dd941f90decc" -[[SortingAlgorithms]] -deps = ["DataStructures", "Random", "Test"] -git-tree-sha1 = "03f5898c9959f8115e30bc7226ada7d0df554ddd" +[[deps.SortingAlgorithms]] +deps = ["DataStructures"] +git-tree-sha1 = "a4ada03f999bd01b3a25dcaa30b2d929fe537e00" uuid = "a2af1166-a08f-5f64-846c-94a0d3cef48c" -version = "0.3.1" +version = "1.1.0" -[[SparseArrays]] -deps = ["LinearAlgebra", "Random"] +[[deps.SparseArrays]] +deps = ["Libdl", "LinearAlgebra", "Random", "Serialization", "SuiteSparse_jll"] uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" -[[SpecialFunctions]] -deps = ["ChainRulesCore", "OpenSpecFun_jll"] -git-tree-sha1 = "5919936c0e92cff40e57d0ddf0ceb667d42e5902" +[[deps.SpecialFunctions]] +deps = ["IrrationalConstants", "LogExpFunctions", "OpenLibm_jll", "OpenSpecFun_jll"] +git-tree-sha1 = "ef28127915f4229c971eb43f3fc075dd3fe91880" uuid = "276daf66-3868-5448-9aa4-cd146d93841b" -version = "1.3.0" +version = "2.2.0" +weakdeps = ["ChainRulesCore"] + + [deps.SpecialFunctions.extensions] + SpecialFunctionsChainRulesCoreExt = "ChainRulesCore" + +[[deps.SplittablesBase]] +deps = ["Setfield", "Test"] +git-tree-sha1 = "e08a62abc517eb79667d0a29dc08a3b589516bb5" +uuid = "171d559e-b47b-412a-8079-5efa626c420e" +version = "0.1.15" + +[[deps.StackViews]] +deps = ["OffsetArrays"] +git-tree-sha1 = "46e589465204cd0c08b4bd97385e4fa79a0c770c" +uuid = "cae243ae-269e-4f55-b966-ac2d0dc13c15" +version = "0.1.1" -[[StaticArrays]] -deps = ["LinearAlgebra", "Random", "Statistics"] -git-tree-sha1 = "9da72ed50e94dbff92036da395275ed114e04d49" +[[deps.StaticArrays]] +deps = ["LinearAlgebra", "Random", "StaticArraysCore", "Statistics"] +git-tree-sha1 = "c262c8e978048c2b095be1672c9bee55b4619521" uuid = "90137ffa-7385-5640-81b9-e52037218182" -version = "1.0.1" +version = "1.5.24" + +[[deps.StaticArraysCore]] +git-tree-sha1 = "6b7ba252635a5eff6a0b0664a41ee140a1c9e72a" +uuid = "1e83bf80-4336-4d27-bf5d-d5a4f845583c" +version = "1.4.0" -[[Statistics]] +[[deps.Statistics]] deps = ["LinearAlgebra", "SparseArrays"] uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" +version = "1.9.0" -[[StatsBase]] -deps = ["DataAPI", "DataStructures", "LinearAlgebra", "Missings", "Printf", "Random", "SortingAlgorithms", "SparseArrays", "Statistics"] -git-tree-sha1 = "400aa43f7de43aeccc5b2e39a76a79d262202b76" +[[deps.StatsAPI]] +deps = ["LinearAlgebra"] +git-tree-sha1 = "45a7769a04a3cf80da1c1c7c60caf932e6f4c9f7" +uuid = "82ae8749-77ed-4fe6-ae5f-f523153014b0" +version = "1.6.0" + +[[deps.StatsBase]] +deps = ["DataAPI", "DataStructures", "LinearAlgebra", "LogExpFunctions", "Missings", "Printf", "Random", "SortingAlgorithms", "SparseArrays", "Statistics", "StatsAPI"] +git-tree-sha1 = "75ebe04c5bed70b91614d684259b661c9e6274a4" uuid = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" -version = "0.33.3" +version = "0.34.0" + +[[deps.Strided]] +deps = ["LinearAlgebra", "TupleTools"] +git-tree-sha1 = "a7a664c91104329c88222aa20264e1a05b6ad138" +uuid = "5e0ebb24-38b0-5f93-81fe-25c709ecae67" +version = "1.2.3" + +[[deps.StringEncodings]] +deps = ["Libiconv_jll"] +git-tree-sha1 = "33c0da881af3248dafefb939a21694b97cfece76" +uuid = "69024149-9ee7-55f6-a4c4-859efe599b68" +version = "0.3.6" + +[[deps.StringManipulation]] +git-tree-sha1 = "46da2434b41f41ac3594ee9816ce5541c6096123" +uuid = "892a3eda-7b42-436c-8928-eab12a02cf0e" +version = "0.3.0" -[[TOML]] +[[deps.StructArrays]] +deps = ["Adapt", "DataAPI", "GPUArraysCore", "StaticArraysCore", "Tables"] +git-tree-sha1 = "521a0e828e98bb69042fec1809c1b5a680eb7389" +uuid = "09ab397b-f2b6-538f-b94a-2f83cf4a842a" +version = "0.6.15" + +[[deps.StructTypes]] +deps = ["Dates", "UUIDs"] +git-tree-sha1 = "ca4bccb03acf9faaf4137a9abc1881ed1841aa70" +uuid = "856f2bd8-1eba-4b0a-8007-ebc267875bd4" +version = "1.10.0" + +[[deps.SuiteSparse_jll]] +deps = ["Artifacts", "Libdl", "Pkg", "libblastrampoline_jll"] +uuid = "bea87d4a-7f5b-5778-9afe-8cc45184846c" +version = "5.10.1+6" + +[[deps.TOML]] deps = ["Dates"] uuid = "fa267f1f-6049-4f14-aa54-33bafae1ed76" +version = "1.0.3" + +[[deps.TableTraits]] +deps = ["IteratorInterfaceExtensions"] +git-tree-sha1 = "c06b2f539df1c6efa794486abfb6ed2022561a39" +uuid = "3783bdb8-4a98-5b6b-af9a-565f29a5fe9c" +version = "1.0.1" -[[Tar]] +[[deps.Tables]] +deps = ["DataAPI", "DataValueInterfaces", "IteratorInterfaceExtensions", "LinearAlgebra", "OrderedCollections", "TableTraits", "Test"] +git-tree-sha1 = "1544b926975372da01227b382066ab70e574a3ec" +uuid = "bd369af6-aec1-5ad0-b16a-f7cc5008161c" +version = "1.10.1" + +[[deps.Tar]] deps = ["ArgTools", "SHA"] uuid = "a4e569a6-e804-4fa4-b0f3-eef7a1d5b13e" +version = "1.10.0" -[[TensorBoardLogger]] -deps = ["CRC32c", "ColorTypes", "FileIO", "FixedPointNumbers", "ImageCore", "ProtoBuf", "Requires", "StatsBase"] -git-tree-sha1 = "40bdbb1a241c94189044b6649035225b19d79cf2" -uuid = "899adc3e-224a-11e9-021f-63837185c80f" -version = "0.1.15" +[[deps.TensorCore]] +deps = ["LinearAlgebra"] +git-tree-sha1 = "1feb45f88d133a655e001435632f019a9a1bcdb6" +uuid = "62fd8b95-f654-4bbd-a8a5-9c27f68ccd50" +version = "0.1.1" -[[Test]] +[[deps.Test]] deps = ["InteractiveUtils", "Logging", "Random", "Serialization"] uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40" -[[TimerOutputs]] -deps = ["Printf"] -git-tree-sha1 = "32cdbe6cd2d214c25a0b88f985c9e0092877c236" +[[deps.TimerOutputs]] +deps = ["ExprTools", "Printf"] +git-tree-sha1 = "f548a9e9c490030e545f72074a41edfd0e5bcdd7" uuid = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f" -version = "0.5.8" +version = "0.5.23" -[[TranscodingStreams]] +[[deps.TranscodingStreams]] deps = ["Random", "Test"] -git-tree-sha1 = "7c53c35547de1c5b9d46a4797cf6d8253807108c" +git-tree-sha1 = "9a6ae7ed916312b41236fcef7e0af564ef934769" uuid = "3bb67fe8-82b1-5028-8e26-92a6c54297fa" -version = "0.9.5" +version = "0.9.13" -[[URIParser]] -deps = ["Unicode"] -git-tree-sha1 = "53a9f49546b8d2dd2e688d216421d050c9a31d0d" -uuid = "30578b45-9adc-5946-b283-645ec420af67" -version = "0.4.1" +[[deps.Transducers]] +deps = ["Adapt", "ArgCheck", "BangBang", "Baselet", "CompositionsBase", "DefineSingletons", "Distributed", "InitialValues", "Logging", "Markdown", "MicroCollections", "Requires", "Setfield", "SplittablesBase", "Tables"] +git-tree-sha1 = "25358a5f2384c490e98abd565ed321ffae2cbb37" +uuid = "28d57a85-8fef-5791-bfe6-a80928e7c999" +version = "0.4.76" -[[URIs]] -git-tree-sha1 = "7855809b88d7b16e9b029afd17880930626f54a2" +[[deps.TupleTools]] +git-tree-sha1 = "3c712976c47707ff893cf6ba4354aa14db1d8938" +uuid = "9d95972d-f1c8-5527-a6e0-b4b365fa01f6" +version = "1.3.0" + +[[deps.URIs]] +git-tree-sha1 = "074f993b0ca030848b897beff716d93aca60f06a" uuid = "5c2747f8-b7ea-4ff2-ba2e-563bfd36b1d4" -version = "1.2.0" +version = "1.4.2" -[[UUIDs]] +[[deps.UUIDs]] deps = ["Random", "SHA"] uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4" -[[Unicode]] +[[deps.Unicode]] uuid = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5" -[[ZipFile]] +[[deps.UnsafeAtomics]] +git-tree-sha1 = "6331ac3440856ea1988316b46045303bef658278" +uuid = "013be700-e6cd-48c3-b4a1-df204f14c38f" +version = "0.2.1" + +[[deps.UnsafeAtomicsLLVM]] +deps = ["LLVM", "UnsafeAtomics"] +git-tree-sha1 = "ea37e6066bf194ab78f4e747f5245261f17a7175" +uuid = "d80eeb9a-aca5-4d75-85e5-170c8b632249" +version = "0.1.2" + +[[deps.WeakRefStrings]] +deps = ["DataAPI", "InlineStrings", "Parsers"] +git-tree-sha1 = "b1be2855ed9ed8eac54e5caff2afcdb442d52c23" +uuid = "ea10d353-3f73-51f8-a26c-33c1cb351aa5" +version = "1.4.2" + +[[deps.WorkerUtilities]] +git-tree-sha1 = "cd1659ba0d57b71a464a29e64dbc67cfe83d54e7" +uuid = "76eceee3-57b5-4d4a-8e66-0e911cebbf60" +version = "1.6.1" + +[[deps.ZipFile]] deps = ["Libdl", "Printf", "Zlib_jll"] -git-tree-sha1 = "c3a5637e27e914a7a445b8d0ad063d701931e9f7" +git-tree-sha1 = "f492b7fe1698e623024e873244f10d89c95c340a" uuid = "a5390f91-8eb1-5f08-bee0-b1d1ffed6cea" -version = "0.9.3" +version = "0.10.1" -[[Zlib_jll]] +[[deps.Zlib_jll]] deps = ["Libdl"] uuid = "83775a58-1f1d-513f-b197-d71354ab007a" +version = "1.2.13+0" -[[Zstd_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] -git-tree-sha1 = "2c1332c54931e83f8f94d310fa447fd743e8d600" -uuid = "3161d3a3-bdf6-5164-811a-617609db77b4" -version = "1.4.8+0" - -[[Zygote]] -deps = ["AbstractFFTs", "ChainRules", "DiffRules", "Distributed", "FillArrays", "ForwardDiff", "IRTools", "InteractiveUtils", "LinearAlgebra", "MacroTools", "NaNMath", "Random", "Requires", "SpecialFunctions", "Statistics", "ZygoteRules"] -git-tree-sha1 = "52835a83f7c899cfcb95f796d584201812887ea8" +[[deps.Zygote]] +deps = ["AbstractFFTs", "ChainRules", "ChainRulesCore", "DiffRules", "Distributed", "FillArrays", "ForwardDiff", "GPUArrays", "GPUArraysCore", "IRTools", "InteractiveUtils", "LinearAlgebra", "LogExpFunctions", "MacroTools", "NaNMath", "PrecompileTools", "Random", "Requires", "SparseArrays", "SpecialFunctions", "Statistics", "ZygoteRules"] +git-tree-sha1 = "ebac1ae9f048c669317ad48c9bed815790a468d8" uuid = "e88e6eb3-aa80-5325-afca-941959d7151f" -version = "0.6.3" +version = "0.6.61" + + [deps.Zygote.extensions] + ZygoteColorsExt = "Colors" + ZygoteDistancesExt = "Distances" + ZygoteTrackerExt = "Tracker" -[[ZygoteRules]] -deps = ["MacroTools"] -git-tree-sha1 = "9e7a1e8ca60b742e508a315c17eef5211e7fbfd7" + [deps.Zygote.weakdeps] + Colors = "5ae59095-9a9b-59fe-a467-6f913c188581" + Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7" + Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" + +[[deps.ZygoteRules]] +deps = ["ChainRulesCore", "MacroTools"] +git-tree-sha1 = "977aed5d006b840e2e40c0b48984f7463109046d" uuid = "700de1a5-db45-46bc-99cf-38207098b444" -version = "0.2.1" +version = "0.2.3" + +[[deps.cuDNN]] +deps = ["CEnum", "CUDA", "CUDNN_jll"] +git-tree-sha1 = "ec954b59f6b0324543f2e3ed8118309ac60cb75b" +uuid = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd" +version = "1.0.3" + +[[deps.libblastrampoline_jll]] +deps = ["Artifacts", "Libdl"] +uuid = "8e850b90-86db-534c-a0d3-1478176c7d93" +version = "5.7.0+0" -[[nghttp2_jll]] +[[deps.nghttp2_jll]] deps = ["Artifacts", "Libdl"] uuid = "8e850ede-7688-5339-a07c-302acd2aaf8d" +version = "1.48.0+0" -[[p7zip_jll]] +[[deps.p7zip_jll]] deps = ["Artifacts", "Libdl"] uuid = "3f19e933-33d8-53b3-aaab-bd5110c3b7a0" +version = "17.4.0+0" diff --git a/vision/conv_mnist/Project.toml b/vision/conv_mnist/Project.toml index da364e39..0193d664 100644 --- a/vision/conv_mnist/Project.toml +++ b/vision/conv_mnist/Project.toml @@ -1,14 +1,13 @@ [deps] -BSON = "fbb218c0-5317-5bc6-957e-2ee96dd4b1f0" CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" -Logging = "56ddb016-857b-54e1-b83d-db4d58db5568" +JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819" MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458" ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca" -TensorBoardLogger = "899adc3e-224a-11e9-021f-63837185c80f" [compat] -CUDA = "2.4.0" -Flux = "0.13" -MLDatasets = "0.6" -julia = "1.5" +CUDA = "3, 4" +Flux = "0.13.16" +JLD2 = "0.4.31" +MLDatasets = "0.7" +julia = "1.6" diff --git a/vision/conv_mnist/conv_mnist.jl b/vision/conv_mnist/conv_mnist.jl index 3b153c75..c920ada4 100644 --- a/vision/conv_mnist/conv_mnist.jl +++ b/vision/conv_mnist/conv_mnist.jl @@ -1,240 +1,222 @@ -# # Classification of MNIST dataset using ConvNet +# Classification of MNIST dataset using a convolutional network, +# which is a variant of the original LeNet from 1998. -# In this tutorial, we build a convolutional neural network (ConvNet or CNN) known as [LeNet5](https://en.wikipedia.org/wiki/LeNet) -# to classify [MNIST](http://yann.lecun.com/exdb/mnist/) handwritten digits. +# This example uses a GPU if you have one. +# And demonstrates how to save model state. -# LeNet5 is one of the earliest CNNs. It was originally used for recognizing handwritten characters. At a high level LeNet (LeNet-5) consists of two parts: +using MLDatasets, Flux, JLD2, CUDA # this will install everything if necc. -# * A convolutional encoder consisting of two convolutional layers. -# * A dense block consisting of three fully-connected layers. +folder = "runs" # sub-directory in which to save +isdir(folder) || mkdir(folder) +filename = joinpath(folder, "lenet.jld2") -# The basic units in each convolutional block are a convolutional layer, a sigmoid activation function, -# and a subsequent average pooling operation. Each convolutional layer uses a 5×5 kernel and a sigmoid activation function. -# These layers map spatially arranged inputs to a number of two-dimensional feature maps, typically increasing the number of channels. -# The first convolutional layer has 6 output channels, while the second has 16. -# Each 2×2 pooling operation (stride 2) reduces dimensionality by a factor of 4 via spatial downsampling. -# The convolutional block emits an output with shape given by (width, height, number of channels, batch size). +#===== DATA =====# -# ![LeNet-5](../conv_mnist/docs/LeNet-5.png) +# Calling MLDatasets.MNIST() will dowload the dataset if necessary, +# and return a struct containing it. +# It takes a few seconds to read from disk each time, so do this once: -# Source: https://d2l.ai/chapter_convolutional-neural-networks/lenet.html +train_data = MLDatasets.MNIST() # i.e. split=:train +test_data = MLDatasets.MNIST(split=:test) -# >**Note:** The original architecture of Lenet5 used the sigmoind activation function. However, this is a a modernized version since it uses the RELU activation function instead. +# train_data.features is a 28×28×60000 Array{Float32, 3} of the images. +# Flux needs a 4D array, with the 3rd dim for channels -- here trivial, grayscale. +# Combine the reshape needed with other pre-processing: -# If you need more information about how CNNs work and related technical concepts, check out the following resources: +function loader(data::MNIST=train_data; batchsize::Int=64) + x4dim = reshape(data.features, 28,28,1,:) # insert trivial channel dim + yhot = Flux.onehotbatch(data.targets, 0:9) # make a 10×60000 OneHotMatrix + Flux.DataLoader((x4dim, yhot); batchsize, shuffle=true) |> gpu +end -# * [Gradient-Based Learning Applied to Document Recognition](http://yann.lecun.com/exdb/publis/pdf/lecun-01a.pdf) . This is LeNet5 original paper by Yann LeCunn and others. -# * [Convolutional Neural Networks for Visual Recognition](https://cs231n.github.io/convolutional-networks/). -# * [Neural Networks in Flux.jl with Huda Nassar (working with the MNIST dataset)](https://youtu.be/Oxi0Pfmskus). -# * [Dive into Deep Learning", 2020](https://d2l.ai/chapter_convolutional-neural-networks/lenet.html). +loader() # returns a DataLoader, with first element a tuple like this: +x1, y1 = first(loader()); # (28×28×1×64 Array{Float32, 3}, 10×64 OneHotMatrix(::Vector{UInt32})) -# This example demonstrates Flux’s Convolution and pooling layers, the usage of TensorBoardLogger, -# how to write out the saved model to the file `mnist_conv.bson`, -# and also combines various packages from the Julia ecosystem with Flux. +# If you are using a GPU, these should be CuArray{Float32, 3} etc. +# If not, the `gpu` function does nothing (except complain the first time). +#===== MODEL =====# -# To run this example, we need the following packages: +# LeNet has two convolutional layers, and our modern version has relu nonlinearities. +# After each conv layer there's a pooling step. Finally, there are some fully connected layers: -using Flux -using Flux.Data: DataLoader -using Flux.Optimise: Optimiser, WeightDecay -using Flux: onehotbatch, onecold, flatten -using Flux.Losses: logitcrossentropy -using Statistics, Random -using Logging: with_logger -using TensorBoardLogger: TBLogger, tb_overwrite, set_step!, set_step_increment! -using ProgressMeter: @showprogress -import MLDatasets -import BSON -using CUDA +lenet = Chain( + Conv((5, 5), 1=>6, relu), + MaxPool((2, 2)), + Conv((5, 5), 6=>16, relu), + MaxPool((2, 2)), + Flux.flatten, + Dense(256 => 120, relu), + Dense(120 => 84, relu), + Dense(84 => 10), +) |> gpu -# We set default values for the arguments for the function `train`: +# Notice that most of the parameters are in the final Dense layers. -Base.@kwdef mutable struct Args - η = 3e-4 ## learning rate - λ = 0 ## L2 regularizer param, implemented as weight decay - batchsize = 128 ## batch size - epochs = 10 ## number of epochs - seed = 0 ## set seed > 0 for reproducibility - use_cuda = true ## if true use cuda (if available) - infotime = 1 ## report every `infotime` epochs - checktime = 5 ## Save the model every `checktime` epochs. Set to 0 for no checkpoints. - tblogger = true ## log training with tensorboard - savepath = "runs/" ## results path -end +y1hat = lenet(x1) # try it out + +sum(softmax(y1hat); dims=1) -# ## Data +# Each column of softmax(y1hat) may be thought of as the network's probabilities +# that an input image is in each of 10 classes. To find its most likely answer, +# we can look for the largest output in each column, without needing softmax first. +# At the moment, these don't resemble the true values at all: -# We create the function `get_data` to load the MNIST train and test data from [MLDatasets](https://github.com/JuliaML/MLDatasets.jl) and reshape them so that they are in the shape that Flux expects. +@show hcat(Flux.onecold(y1hat, 0:9), Flux.onecold(y1, 0:9)) -function get_data(args) - xtrain, ytrain = MLDatasets.MNIST(:train)[:] - xtest, ytest = MLDatasets.MNIST(:test)[:] +#===== METRICS =====# - xtrain = reshape(xtrain, 28, 28, 1, :) - xtest = reshape(xtest, 28, 28, 1, :) +# We're going to log accuracy and loss during training. There's no advantage to +# calculating these on minibatches, since MNIST is small enough to do it at once. - ytrain, ytest = onehotbatch(ytrain, 0:9), onehotbatch(ytest, 0:9) +using Statistics: mean # standard library - train_loader = DataLoader((xtrain, ytrain), batchsize=args.batchsize, shuffle=true) - test_loader = DataLoader((xtest, ytest), batchsize=args.batchsize) - - return train_loader, test_loader +function loss_and_accuracy(model, data::MNIST=test_data) + (x,y) = only(loader(data; batchsize=length(data))) # make one big batch + ŷ = model(x) + loss = Flux.logitcrossentropy(ŷ, y) # did not include softmax in the model + acc = round(100 * mean(Flux.onecold(ŷ) .== Flux.onecold(y)); digits=2) + (; loss, acc, split=data.split) # return a NamedTuple end -# The function `get_data` performs the following tasks: +@show loss_and_accuracy(lenet); # accuracy about 10%, before training -# * **Loads MNIST dataset:** Loads the train and test set tensors. The shape of the train data is `28x28x60000` and the test data is `28x28x10000`. -# * **Reshapes the train and test data:** Notice that we reshape the data so that we can pass it as arguments for the input layer of the model. -# * **One-hot encodes the train and test labels:** Creates a batch of one-hot vectors so we can pass the labels of the data as arguments for the loss function. For this example, we use the [logitcrossentropy](https://fluxml.ai/Flux.jl/stable/models/losses/#Flux.Losses.logitcrossentropy) function and it expects data to be one-hot encoded. -# * **Creates mini-batches of data:** Creates two DataLoader objects (train and test) that handle data mini-batches of size `128 ` (as defined above). We create these two objects so that we can pass the entire data set through the loss function at once when training our model. Also, it shuffles the data points during each iteration (`shuffle=true`). +#===== TRAINING =====# -# ## Model +# Let's collect some hyper-parameters in a NamedTuple, just to write them in one place. +# Global variables are fine -- we won't access this from inside any fast loops. -# We create the LeNet5 "constructor". It uses Flux's built-in [Convolutional and pooling layers](https://fluxml.ai/Flux.jl/stable/models/layers/#Convolution-and-Pooling-Layers): +settings = (; + eta = 3e-4, # learning rate + lambda = 1e-2, # for weight decay + batchsize = 128, + epochs = 10, +) +train_log = [] +# Initialise the storage needed for the optimiser: -function LeNet5(; imgsize=(28,28,1), nclasses=10) - out_conv_size = (imgsize[1]÷4 - 3, imgsize[2]÷4 - 3, 16) - - return Chain( - Conv((5, 5), imgsize[end]=>6, relu), - MaxPool((2, 2)), - Conv((5, 5), 6=>16, relu), - MaxPool((2, 2)), - flatten, - Dense(prod(out_conv_size), 120, relu), - Dense(120, 84, relu), - Dense(84, nclasses) - ) +opt_rule = OptimiserChain(WeightDecay(settings.lambda), Adam(settings.eta)) +opt_state = Flux.setup(opt_rule, lenet); + +for epoch in 1:settings.epochs + # @time will show a much longer time for the first epoch, due to compilation + @time for (x,y) in loader(batchsize=settings.batchsize) + grads = Flux.gradient(m -> Flux.logitcrossentropy(m(x), y), lenet) + Flux.update!(opt_state, lenet, grads[1]) + end + + # Logging & saving, but not on every epoch + if epoch % 2 == 1 + loss, acc, _ = loss_and_accuracy(lenet) + test_loss, test_acc, _ = loss_and_accuracy(lenet, test_data) + @info "logging:" epoch acc test_acc + nt = (; epoch, loss, acc, test_loss, test_acc) # make a NamedTuple + push!(train_log, nt) + end + if epoch % 5 == 0 + JLD2.jldsave(filename; lenet_state = Flux.state(lenet) |> cpu) + println("saved to ", filename, " after ", epoch, " epochs") + end end -# ## Loss function +@show train_log; -# We use the function [logitcrossentropy](https://fluxml.ai/Flux.jl/stable/models/losses/#Flux.Losses.logitcrossentropy) to compute the difference between -# the predicted and actual values (loss). +# We can re-run the quick sanity-check of predictions: +y1hat = lenet(x1) +@show hcat(Flux.onecold(y1hat, 0:9), Flux.onecold(y1, 0:9)) -loss(ŷ, y) = logitcrossentropy(ŷ, y) +#===== INSPECTION =====# -# Also, we create the function `eval_loss_accuracy` to output the loss and the accuracy during training: +using ImageCore, ImageInTerminal -function eval_loss_accuracy(loader, model, device) - l = 0f0 - acc = 0 - ntot = 0 - for (x, y) in loader - x, y = x |> device, y |> device - ŷ = model(x) - l += loss(ŷ, y) * size(x)[end] - acc += sum(onecold(ŷ |> cpu) .== onecold(y |> cpu)) - ntot += size(x)[end] - end - return (loss = l/ntot |> round4, acc = acc/ntot*100 |> round4) -end +xtest, ytest = only(loader(test_data, batchsize=length(test_data))); -# ## Utility functions -# We need a couple of functions to obtain the total number of the model's parameters. Also, we create a function to round numbers to four digits. +# There are many ways to look at images, you won't need ImageInTerminal if working in a notebook. +# ImageCore.Gray is a special type, whick interprets numbers between 0.0 and 1.0 as shades: -num_params(model) = sum(length, Flux.params(model)) -round4(x) = round(x, digits=4) +xtest[:,:,1,5] .|> Gray |> transpose |> cpu -# ## Train the model +Flux.onecold(ytest, 0:9)[5] # true label, should match! -# Finally, we define the function `train` that calls the functions defined above to train the model. +# Let's look for the image whose classification is least certain. +# First, in each column of probabilities, ask for the largest one. +# Then, over all images, ask for the lowest such probability, and its index. -function train(; kws...) - args = Args(; kws...) - args.seed > 0 && Random.seed!(args.seed) - use_cuda = args.use_cuda && CUDA.functional() - - if use_cuda - device = gpu - @info "Training on GPU" - else - device = cpu - @info "Training on CPU" - end +ptest = softmax(lenet(xtest)) +max_p = maximum(ptest; dims=1) +_, i = findmin(vec(max_p)) - ## DATA - train_loader, test_loader = get_data(args) - @info "Dataset MNIST: $(train_loader.nobs) train and $(test_loader.nobs) test examples" +xtest[:,:,1,i] .|> Gray |> transpose |> cpu - ## MODEL AND OPTIMIZER - model = LeNet5() |> device - @info "LeNet5 model: $(num_params(model)) trainable params" - - ps = Flux.params(model) +Flux.onecold(ytest, 0:9)[i] # true classification +ptest[:,i] # probabilities of all outcomes +Flux.onecold(ptest[:,i], 0:9) # uncertain prediction - opt = ADAM(args.η) - if args.λ > 0 ## add weight decay, equivalent to L2 regularization - opt = Optimiser(WeightDecay(args.λ), opt) - end - - ## LOGGING UTILITIES - if args.tblogger - tblogger = TBLogger(args.savepath, tb_overwrite) - set_step_increment!(tblogger, 0) ## 0 auto increment since we manually set_step! - @info "TensorBoard logging at \"$(args.savepath)\"" - end - - function report(epoch) - train = eval_loss_accuracy(train_loader, model, device) - test = eval_loss_accuracy(test_loader, model, device) - println("Epoch: $epoch Train: $(train) Test: $(test)") - if args.tblogger - set_step!(tblogger, epoch) - with_logger(tblogger) do - @info "train" loss=train.loss acc=train.acc - @info "test" loss=test.loss acc=test.acc - end - end - end - - ## TRAINING - @info "Start Training" - report(0) - for epoch in 1:args.epochs - @showprogress for (x, y) in train_loader - x, y = x |> device, y |> device - gs = Flux.gradient(ps) do - ŷ = model(x) - loss(ŷ, y) - end - - Flux.Optimise.update!(opt, ps, gs) - end - - ## Printing and logging - epoch % args.infotime == 0 && report(epoch) - if args.checktime > 0 && epoch % args.checktime == 0 - !ispath(args.savepath) && mkpath(args.savepath) - modelpath = joinpath(args.savepath, "model.bson") - let model = cpu(model) ## return model to cpu before serialization - BSON.@save modelpath model epoch - end - @info "Model saved in \"$(modelpath)\"" - end - end -end +#===== ARRAY SIZES =====# -# The function `train` performs the following tasks: +# A layer like Conv((5, 5), 1=>6) takes 5x5 patches of an image, and matches them to each +# of 6 different 5x5 filters, placed at every possible position. These filters are here: -# * Checks whether there is a GPU available and uses it for training the model. Otherwise, it uses the CPU. -# * Loads the MNIST data using the function `get_data`. -# * Creates the model and uses the [ADAM optimiser](https://fluxml.ai/Flux.jl/stable/training/optimisers/#Flux.Optimise.ADAM) with weight decay. -# * Loads the [TensorBoardLogger.jl](https://github.com/JuliaLogging/TensorBoardLogger.jl) for logging data to Tensorboard. -# * Creates the function `report` for computing the loss and accuracy during the training loop. It outputs these values to the TensorBoardLogger. -# * Runs the training loop using [Flux’s training routine](https://fluxml.ai/Flux.jl/stable/training/training/#Training). For each epoch (step), it executes the following: -# * Computes the model’s predictions. -# * Computes the loss. -# * Updates the model’s parameters. -# * Saves the model `model.bson` every `checktime` epochs (defined as argument above.) +Conv((5, 5), 1=>6).weight |> summary # 5×5×1×6 Array{Float32, 4} -# ## Run the example +# This layer can accept any size of image; let's trace the sizes with the actual input: -# We call the function `train`: +#= -if abspath(PROGRAM_FILE) == @__FILE__ - train() -end +julia> x1 |> size +(28, 28, 1, 64) + +julia> lenet[1](x1) |> size # after Conv((5, 5), 1=>6, relu), +(24, 24, 6, 64) + +julia> lenet[1:2](x1) |> size # after MaxPool((2, 2)) +(12, 12, 6, 64) + +julia> lenet[1:3](x1) |> size # after Conv((5, 5), 6 => 16, relu) +(8, 8, 16, 64) + +julia> lenet[1:4](x1) |> size # after MaxPool((2, 2)) +(4, 4, 16, 64) + +julia> lenet[1:5](x1) |> size # after Flux.flatten +(256, 64) + +=# + +# Flux.flatten is just reshape, preserving the batch dimesion (64) while combining others (4*4*16). +# This 256 must match the Dense(256 => 120). Here is how to automate this, with Flux.outputsize: + +lenet2 = Flux.@autosize (28, 28, 1, 1) Chain( + Conv((5, 5), 1=>6, relu), + MaxPool((2, 2)), + Conv((5, 5), _=>16, relu), + MaxPool((2, 2)), + Flux.flatten, + Dense(_ => 120, relu), + Dense(_ => 84, relu), + Dense(_ => 10), +) + +# Check that this indeed accepts input the same size as above: + +@show lenet2(cpu(x1)) |> size; + +#===== LOADING =====# + +# During training, the code above saves the model state to disk. Load the last version: + +loaded_state = JLD2.load(filename, "lenet_state"); + +# Now you would normally re-create the model, and copy all parameters into that. +# We can use lenet2 from just above: + +Flux.loadmodel!(lenet2, loaded_state) + +# Check that it now agrees with the earlier, trained, model: + +@show lenet2(cpu(x1)) ≈ cpu(lenet(x1); + + +#===== THE END =====#