Skip to content

Commit

Permalink
streamlining streamlining
Browse files Browse the repository at this point in the history
  • Loading branch information
pat-alt committed Aug 21, 2023
1 parent 8e0fd65 commit f2ab5fa
Show file tree
Hide file tree
Showing 6 changed files with 58 additions and 48 deletions.
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
/.quarto/
/Manifest.toml
/replicated/
**/.CondaPkg/
**/.CondaPkg
/dev/rebuttal/www

# Tex
Expand Down
14 changes: 8 additions & 6 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,7 @@ The `experiments/` folder contains separate Julia scripts for each dataset and a
To run the experiment for a single dataset, (e.g. `linearly_separable`) simply run the following command:

```shell
DATANAME=linearly_separable
julia experiments/run_experiments.jl
julia experiments/run_experiments.jl -- data=linearly_separable
```

We use the following identifiers:
Expand All @@ -51,14 +50,17 @@ We use the following identifiers:
- `mnist` (*MNIST* data)
- `gmsc` (*GMSC* data)

To run all experiments at once you can instead just specify `DATANAME=all`.
To run all experiments at once you can instead run

```shell
julia experiments/run_experiments.jl -- run-all
```

Pre-trained versions of all of our black-box models have been archived as `Pkg` [artifacts](https://pkgdocs.julialang.org/v1/artifacts/) and are used by default. Should you wish to retrain the models as well, simply run the following command:
Pre-trained versions of all of our black-box models have been archived as `Pkg` [artifacts](https://pkgdocs.julialang.org/v1/artifacts/) and are used by default. Should you wish to retrain the models as well, simply use the `retrain` flag as follows:

```shell
DATANAME=linearly_separable
RETRAIN=true
julia experiments/run_experiments.jl
julia experiments/run_experiments.jl -- retrain
```

When running the experiments from the command line, the parameter choices used in the main paper are applied by default. To have control over these choices, we recommend you instead rely on the notebooks.
Expand Down
5 changes: 3 additions & 2 deletions experiments/run_experiments.jl
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
include("setup.jl")

# User inputs:
if ENV["DATANAME"] == "all"
if "run-all" in ARGS
datanames = ["linearly_separable", "moons", "circles", "mnist", "gmsc"]
else
datanames = [ENV["DATANAME"]]
datanames = [ARGS[findall(contains.(ARGS, "data="))] |> x -> replace(x, "data=" => "")]
end
datanames = ["linearly_separable", "moons", "circles", "mnist", "gmsc"]

# Linearly Separable
if "linearly_separable" in datanames
Expand Down
6 changes: 1 addition & 5 deletions experiments/setup.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,7 @@ isdir(params_path) || mkdir(params_path)
test_size = 0.2

# Constants:
if ENV["RETRAIN"] == "true"
const RETRAIN = true
else
const RETRAIN = false
end
const RETRAIN = "retrain" ARGS ? true : false

# Artifacts:
using LazyArtifacts
Expand Down
78 changes: 45 additions & 33 deletions notebooks/Manifest.toml
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
# This file is machine-generated - editing it directly is not advised

julia_version = "1.9.0"
julia_version = "1.9.2"
manifest_format = "2.0"
project_hash = "5dac8702e1bf52ac1887686257c409c28f8872ae"
project_hash = "8709f4905068b3fffb11e99abcb4fafbd6d48a97"

[[deps.ANSIColoredPrinters]]
git-tree-sha1 = "574baf8110975760d391c710b6341da1afa48d8c"
Expand Down Expand Up @@ -122,6 +122,12 @@ git-tree-sha1 = "c06a868224ecba914baa6942988e2f2aade419be"
uuid = "a9b6321e-bd34-4604-b9c9-b65b8de01458"
version = "0.1.0"

[[deps.AtomsBase]]
deps = ["LinearAlgebra", "PeriodicTable", "Printf", "Requires", "StaticArrays", "Unitful", "UnitfulAtomic"]
git-tree-sha1 = "c9804781ca49261c8eb6ce4b62f171cfa3d900f0"
uuid = "a963bdd2-2df7-4f54-a1ee-49d51e6be12a"
version = "0.3.4"

[[deps.Automa]]
deps = ["TranscodingStreams"]
git-tree-sha1 = "ef9997b3d5547c48b41c7bd8899e812a917b409d"
Expand Down Expand Up @@ -265,9 +271,9 @@ version = "0.6.0+0"

[[deps.CUDNN_jll]]
deps = ["Artifacts", "CUDA_Runtime_jll", "JLLWrappers", "LazyArtifacts", "Libdl", "TOML"]
git-tree-sha1 = "c30b29597102341a1ea4c2175c4acae9ae522c9d"
git-tree-sha1 = "75923dce4275ead3799b238e10178a68c07dbd3b"
uuid = "62b44479-cb7b-5706-934f-f13b2eb2e645"
version = "8.9.2+0"
version = "8.9.4+0"

[[deps.Cairo]]
deps = ["Cairo_jll", "Colors", "Glib_jll", "Graphics", "Libdl", "Pango_jll"]
Expand Down Expand Up @@ -342,10 +348,10 @@ uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
version = "1.16.0"

[[deps.Chemfiles]]
deps = ["Chemfiles_jll", "DocStringExtensions"]
git-tree-sha1 = "6951fe6a535a07041122a3a6860a63a7a83e081e"
deps = ["AtomsBase", "Chemfiles_jll", "DocStringExtensions", "PeriodicTable", "Unitful", "UnitfulAtomic"]
git-tree-sha1 = "82fe5e341c793cb51149d993307da9543824b206"
uuid = "46823bd8-5fb3-5f92-9aa0-96921f3dd015"
version = "0.10.40"
version = "0.10.41"

[[deps.Chemfiles_jll]]
deps = ["Artifacts", "JLLWrappers", "Libdl"]
Expand Down Expand Up @@ -443,7 +449,7 @@ weakdeps = ["Dates", "LinearAlgebra"]
[[deps.CompilerSupportLibraries_jll]]
deps = ["Artifacts", "Libdl"]
uuid = "e66e0078-7015-5450-92f7-15fbd957f2ae"
version = "1.0.2+0"
version = "1.0.5+0"

[[deps.CompositionsBase]]
git-tree-sha1 = "802bb88cd69dfd1509f6670416bd4434015693ad"
Expand Down Expand Up @@ -480,7 +486,7 @@ uuid = "992eb4ea-22a4-4c89-a5bb-47a3300528ab"
version = "0.2.18"

[[deps.ConformalPrediction]]
deps = ["CategoricalArrays", "ChainRules", "ComputationalResources", "Flux", "LazyArtifacts", "LinearAlgebra", "MLJBase", "MLJEnsembles", "MLJFlux", "MLJModelInterface", "MLUtils", "NaturalSort", "Plots", "ProgressMeter", "Random", "Serialization", "StatsBase", "Tables"]
deps = ["CategoricalArrays", "ChainRules", "Flux", "LazyArtifacts", "LinearAlgebra", "MLJBase", "MLJEnsembles", "MLJFlux", "MLJModelInterface", "MLUtils", "NaturalSort", "Plots", "Serialization", "StatsBase"]
path = "../../ConformalPrediction.jl"
uuid = "98bfc277-1877-43dc-819b-a3e38c30242f"
version = "0.1.8"
Expand Down Expand Up @@ -693,9 +699,9 @@ version = "0.5.2"

[[deps.EvoTrees]]
deps = ["BSON", "CUDA", "CategoricalArrays", "Distributions", "MLJModelInterface", "NetworkLayout", "Random", "RecipesBase", "Statistics", "StatsBase", "Tables"]
git-tree-sha1 = "5023442c1f797c0fd6677b1a1886ab44f43f3378"
git-tree-sha1 = "a1fa1d1743478394a0a7188d054b67546e4ca143"
uuid = "f6006082-12f8-11e9-0c9c-0d5d367ab1e5"
version = "0.16.0"
version = "0.16.1"

[[deps.ExactPredicates]]
deps = ["IntervalArithmetic", "Random", "StaticArraysCore", "Test"]
Expand Down Expand Up @@ -1292,9 +1298,9 @@ version = "0.1.5"

[[deps.JLLWrappers]]
deps = ["Artifacts", "Preferences"]
git-tree-sha1 = "a7e91ef94114d5bc8952bcaa8d6ff952cf709808"
git-tree-sha1 = "7e5d6779a1e09a36db2a7b6cff50942a0a7d0fca"
uuid = "692b3bcd-3c85-4b1f-b108-f13ce0eb3210"
version = "1.4.2"
version = "1.5.0"

[[deps.JSON]]
deps = ["Dates", "Mmap", "Parsers", "Unicode"]
Expand Down Expand Up @@ -1403,8 +1409,8 @@ uuid = "b964fa9f-0449-5b57-a5c2-d3ea65f4040f"
version = "1.3.0"

[[deps.LaplaceRedux]]
deps = ["CSV", "Compat", "ComputationalResources", "DataFrames", "Flux", "LinearAlgebra", "MLJ", "MLJBase", "MLJFlux", "MLJModelInterface", "MLUtils", "Parameters", "Plots", "ProgressMeter", "Random", "Serialization", "Statistics", "Tables", "Tullio", "Zygote"]
path = "../../LaplaceRedux.jl"
deps = ["Flux", "LinearAlgebra", "Parameters", "Plots", "Zygote"]
git-tree-sha1 = "a4adebbeafb96d0864b4833c254013f66dc6e0ee"
uuid = "c52c1a26-f7c5-402b-80be-ba1e638ad478"
version = "0.1.2"

Expand Down Expand Up @@ -1555,9 +1561,9 @@ version = "0.6.1"

[[deps.LogExpFunctions]]
deps = ["DocStringExtensions", "IrrationalConstants", "LinearAlgebra"]
git-tree-sha1 = "c3ce8e7420b3a6e071e0fe4745f5d4300e37b13f"
git-tree-sha1 = "5ab83e1679320064c29e8973034357655743d22d"
uuid = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
version = "0.3.24"
version = "0.3.25"

[deps.LogExpFunctions.extensions]
LogExpFunctionsChainRulesCoreExt = "ChainRulesCore"
Expand Down Expand Up @@ -1631,9 +1637,9 @@ version = "0.19.2"

[[deps.MLJBase]]
deps = ["CategoricalArrays", "CategoricalDistributions", "ComputationalResources", "Dates", "DelimitedFiles", "Distributed", "Distributions", "InteractiveUtils", "InvertedIndices", "LinearAlgebra", "LossFunctions", "MLJModelInterface", "Missings", "OrderedCollections", "Parameters", "PrettyTables", "ProgressMeter", "Random", "ScientificTypes", "Serialization", "StatisticalTraits", "Statistics", "StatsBase", "Tables"]
git-tree-sha1 = "2c9d6b9c627a80f6e6acbc6193026f455581fd04"
git-tree-sha1 = "0b7307d1a7214ec3c0ba305571e713f9492ea984"
uuid = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
version = "0.21.13"
version = "0.21.14"

[[deps.MLJDecisionTreeInterface]]
deps = ["CategoricalArrays", "DecisionTree", "MLJModelInterface", "Random", "Tables"]
Expand Down Expand Up @@ -1661,9 +1667,9 @@ version = "0.5.1"

[[deps.MLJModelInterface]]
deps = ["Random", "ScientificTypesBase", "StatisticalTraits"]
git-tree-sha1 = "e89d1ea12c5a50057bfb0c124d905669e5ed4ec9"
git-tree-sha1 = "03ae109be87f460fe3c96b8a0dbbf9c7bf840bd5"
uuid = "e80e1ace-859a-464e-9ed9-23947d8ae3ea"
version = "1.9.1"
version = "1.9.2"

[[deps.MLJModels]]
deps = ["CategoricalArrays", "CategoricalDistributions", "Combinatorics", "Dates", "Distances", "Distributions", "InteractiveUtils", "LinearAlgebra", "MLJModelInterface", "Markdown", "OrderedCollections", "Parameters", "Pkg", "PrettyPrinting", "REPL", "Random", "RelocatableFolders", "ScientificTypes", "StatisticalTraits", "Statistics", "StatsBase", "Tables"]
Expand All @@ -1690,9 +1696,9 @@ version = "0.4.3"

[[deps.MacroTools]]
deps = ["Markdown", "Random"]
git-tree-sha1 = "42324d08725e200c23d4dfb549e0d5d89dede2d2"
git-tree-sha1 = "9ee1618cbf5240e6d4e0371d6f24065083f60c48"
uuid = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
version = "0.5.10"
version = "0.5.11"

[[deps.Makie]]
deps = ["Animations", "Base64", "ColorBrewer", "ColorSchemes", "ColorTypes", "Colors", "Contour", "DelaunayTriangulation", "Distributions", "DocStringExtensions", "Downloads", "FFMPEG", "FileIO", "FixedPointNumbers", "Formatting", "FreeType", "FreeTypeAbstraction", "GeometryBasics", "GridLayoutBase", "ImageIO", "InteractiveUtils", "IntervalSets", "Isoband", "KernelDensity", "LaTeXStrings", "LinearAlgebra", "MacroTools", "MakieCore", "Markdown", "Match", "MathTeXEngine", "Observables", "OffsetArrays", "Packing", "PlotUtils", "PolygonOps", "PrecompileTools", "Printf", "REPL", "Random", "RelocatableFolders", "Setfield", "ShaderAbstractions", "Showoff", "SignedDistanceFields", "SparseArrays", "StableHashTraits", "Statistics", "StatsBase", "StatsFuns", "StructArrays", "TriplotBase", "UnicodeFun"]
Expand Down Expand Up @@ -1999,9 +2005,9 @@ version = "2.0.2"

[[deps.Optimisers]]
deps = ["ChainRulesCore", "Functors", "LinearAlgebra", "Random", "Statistics"]
git-tree-sha1 = "16776280310aa5553c370b9c7b17f34aadaf3c8e"
git-tree-sha1 = "c1fc26bab5df929a5172f296f25d7d08688fd25b"
uuid = "3bd65402-5787-11e9-1adc-39752487f4e2"
version = "0.2.19"
version = "0.2.20"

[[deps.Opus_jll]]
deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"]
Expand Down Expand Up @@ -2067,6 +2073,12 @@ git-tree-sha1 = "716e24b21538abc91f6205fd1d8363f39b442851"
uuid = "69de0a69-1ddd-5017-9359-2bf0b02dc9f0"
version = "2.7.2"

[[deps.PeriodicTable]]
deps = ["Base64", "Test", "Unitful"]
git-tree-sha1 = "9a9731f346797126271405971dfdf4709947718b"
uuid = "7b2266bf-644c-5ea3-82d8-af4bbd25a884"
version = "1.1.4"

[[deps.Permutations]]
deps = ["Combinatorics", "LinearAlgebra", "Random"]
git-tree-sha1 = "6e6cab1c54ae2382bcc48866b91cf949cea703a1"
Expand Down Expand Up @@ -2099,7 +2111,7 @@ version = "0.42.2+0"
[[deps.Pkg]]
deps = ["Artifacts", "Dates", "Downloads", "FileWatching", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "Serialization", "TOML", "Tar", "UUIDs", "p7zip_jll"]
uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
version = "1.9.0"
version = "1.9.2"

[[deps.PkgTemplates]]
deps = ["Dates", "InteractiveUtils", "LibGit2", "Mocking", "Mustache", "Parameters", "Pkg", "REPL", "UUIDs"]
Expand Down Expand Up @@ -2851,12 +2863,6 @@ git-tree-sha1 = "4d4ed7f294cda19382ff7de4c137d24d16adc89b"
uuid = "981d1d27-644d-49a2-9326-4793e63143c3"
version = "0.1.0"

[[deps.Tullio]]
deps = ["ChainRulesCore", "DiffRules", "LinearAlgebra", "Requires"]
git-tree-sha1 = "7871a39eac745697ee512a87eeff06a048a7905b"
uuid = "bc48ee85-29a4-5162-ae0b-a64e1601d4bc"
version = "0.3.5"

[[deps.TupleTools]]
git-tree-sha1 = "3c712976c47707ff893cf6ba4354aa14db1d8938"
uuid = "9d95972d-f1c8-5527-a6e0-b4b365fa01f6"
Expand Down Expand Up @@ -2905,6 +2911,12 @@ version = "1.16.3"
ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9"
InverseFunctions = "3587e190-3f89-42d0-90ee-14403ec27112"

[[deps.UnitfulAtomic]]
deps = ["Unitful"]
git-tree-sha1 = "903be579194534af1c4b4778d1ace676ca042238"
uuid = "a7773ee8-282e-5fa2-be4e-bd808c38a91a"
version = "1.0.0"

[[deps.UnitfulLatexify]]
deps = ["LaTeXStrings", "Latexify", "Unitful"]
git-tree-sha1 = "e2d817cc500e960fdbafcf988ac8436ba3208bfd"
Expand Down Expand Up @@ -3189,7 +3201,7 @@ version = "0.15.1+0"
[[deps.libblastrampoline_jll]]
deps = ["Artifacts", "Libdl"]
uuid = "8e850b90-86db-534c-a0d3-1478176c7d93"
version = "5.7.0+0"
version = "5.8.0+0"

[[deps.libfdk_aac_jll]]
deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"]
Expand Down
1 change: 0 additions & 1 deletion notebooks/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ ECCCo = "0232c203-4013-4b0d-ad96-43e3e11ac3bf"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
Images = "916415d5-f1e6-5110-898d-aaa5f9f070e0"
JointEnergyModels = "48c56d24-211d-4463-bbc0-7a701b291131"
LaplaceRedux = "c52c1a26-f7c5-402b-80be-ba1e638ad478"
LazyArtifacts = "4af54fe1-eca0-43a8-85a7-787d91b784e3"
MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458"
MLJ = "add582a8-e3ab-11e8-2d5e-e98b27df1bc7"
Expand Down

0 comments on commit f2ab5fa

Please sign in to comment.