From 5ab6c75ea4198d3a8a81e28aa0d4380184638abf Mon Sep 17 00:00:00 2001
From: Pat Alt <55311242+pat-alt@users.noreply.github.com>
Date: Wed, 24 May 2023 13:30:35 +0200
Subject: [PATCH] all polished

---
 README.md                      |  4 +--
 experiments/circles.jl         | 14 +++++++++
 experiments/gmsc.jl            | 21 ++++++++++++++
 experiments/mnist.jl           | 52 ++++++++++++++++++++++++++++++++++
 experiments/moons.jl           | 16 +++++++++++
 experiments/run_experiments.jl |  9 ++++--
 experiments/setup.jl           |  7 ++++-
 7 files changed, 118 insertions(+), 5 deletions(-)

diff --git a/README.md b/README.md
index dd43df3b..935518de 100644
--- a/README.md
+++ b/README.md
@@ -20,8 +20,6 @@ We used [Quarto](https://quarto.org/) notebooks for prototyping and running expe
 - [MNIST](notebooks/mnist.qmd)
 - [GMSC](notebooks/gmsc.qmd)
 
-Instead of looking at the notebooks directly, you may choose to browse the HTML book contained inside the `docs` folder. The book is automatically generated from the notebooks and includes all code chunks and their outputs. It is currently not possible to view the book online, but you can download the `docs/` folder and open the `index.html` file in your browser.
-
 ## Inspecting the Results
 
 All results have been carefully reported either in the paper itself or in the supplementary material. In addition, we have released our results as binary files. These will be made publicly available after the review process.
@@ -32,6 +30,8 @@ To reproduce the results, you need to install the package, which will automatica
 
 However, provided that the package is indeed installed, you can reproduce the results by either running the experiments in the `experiments/` folder or using the notebooks listed above for a more interactive process.
 
+**Note**: All experiments were run on `julia-1.8.5`. Since pre-trained models were serialised on that version, they may not be compatible with newer versions of Julia.
+
 ### Command Line
 
 The `experiments/` folder contains separate Julia scripts for each dataset and a [run_experiments.jl](experiments/run_experiments.jl) script that calls the individual scripts. You can either run these scripts inside a Julia session or execute them from the command line, as described below.
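
The command-line workflow described in the README hinges on the `DATANAME` environment variable, which `run_experiments.jl` (patched below) reads to select datasets. A minimal sketch of driving a single experiment from a Julia session, assuming the project environment is already instantiated:

    ENV["DATANAME"] = "mnist"  # or "all", "linearly_separable", "moons", "circles", "gmsc"
    include("experiments/run_experiments.jl")
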
diff --git a/experiments/circles.jl b/experiments/circles.jl
index e69de29b..c8f9e6fd 100644
--- a/experiments/circles.jl
+++ b/experiments/circles.jl
@@ -0,0 +1,14 @@
+n_obs = Int(1000 / (1.0 - test_size))
+counterfactual_data, test_data = train_test_split(load_circles(n_obs; noise=0.05, factor=0.5); test_size=test_size)
+run_experiment(
+    counterfactual_data, test_data; dataname="Circles",
+    n_hidden=32,
+    α=[1.0, 1.0, 1e-2],
+    sampling_batch_size=nothing,
+    sampling_steps=20,
+    λ₁=0.25,
+    λ₂=0.75,
+    λ₃=0.75,
+    opt=Flux.Optimise.Descent(0.01),
+    use_class_loss=false,
+)
\ No newline at end of file
diff --git a/experiments/gmsc.jl b/experiments/gmsc.jl
index e69de29b..58d6b85b 100644
--- a/experiments/gmsc.jl
+++ b/experiments/gmsc.jl
@@ -0,0 +1,21 @@
+counterfactual_data, test_data = train_test_split(load_gmsc(nothing); test_size=test_size)
+run_experiment(
+    counterfactual_data, test_data; dataname="GMSC",
+    n_hidden=128,
+    activation=Flux.swish,
+    builder=MLJFlux.@builder Flux.Chain(
+        Dense(n_in, n_hidden, activation),
+        Dense(n_hidden, n_hidden, activation),
+        Dense(n_hidden, n_out),
+    ),
+    α=[1.0, 1.0, 1e-1],
+    sampling_batch_size=nothing,
+    sampling_steps=30,
+    use_ensembling=true,
+    λ₁=0.1,
+    λ₂=0.5,
+    λ₃=0.5,
+    opt=Flux.Optimise.Descent(0.05),
+    use_class_loss=false,
+    use_variants=false,
+)
\ No newline at end of file
diff --git a/experiments/mnist.jl b/experiments/mnist.jl
index e69de29b..b1847303 100644
--- a/experiments/mnist.jl
+++ b/experiments/mnist.jl
@@ -0,0 +1,52 @@
+function pre_process(x; noise::Float32=0.03f0)
+    ϵ = Float32.(randn(size(x)) * noise)
+    x += ϵ
+    return x
+end
+
+# Training data:
+n_obs = 10000
+counterfactual_data = load_mnist(n_obs)
+counterfactual_data.X = pre_process.(counterfactual_data.X)
+
+# VAE (trained on full dataset):
+using CounterfactualExplanations.Models: load_mnist_vae
+vae = load_mnist_vae()
+counterfactual_data.generative_model = vae
+
+# Test data:
+test_data = load_mnist_test()
+
+# Generators:
+eccco_generator = ECCCoGenerator(
+    λ=[0.1, 0.25, 0.25],
+    temp=0.1,
+    opt=nothing,
+    use_class_loss=true,
+    nsamples=10,
+    nmin=10,
+)
+Λ = eccco_generator.λ
+generator_dict = Dict(
+    "Wachter" => WachterGenerator(λ=Λ[1], opt=eccco_generator.opt),
+    "REVISE" => REVISEGenerator(λ=Λ[1], opt=eccco_generator.opt),
+    "Schut" => GreedyGenerator(η=2.0),
+    "ECCCo" => eccco_generator,
+)
+
+# Run:
+run_experiment(
+    counterfactual_data, test_data; dataname="MNIST",
+    n_hidden=128,
+    activation=Flux.swish,
+    builder=MLJFlux.@builder Flux.Chain(
+        Dense(n_in, n_hidden, activation),
+        Dense(n_hidden, n_out),
+    ),
+    𝒟x=Uniform(-1.0, 1.0),
+    α=[1.0, 1.0, 1e-2],
+    sampling_batch_size=10,
+    sampling_steps=25,
+    use_ensembling=true,
+    generators=generator_dict,
+)
\ No newline at end of file
diff --git a/experiments/moons.jl b/experiments/moons.jl
index e69de29b..01e124b1 100644
--- a/experiments/moons.jl
+++ b/experiments/moons.jl
@@ -0,0 +1,16 @@
+n_obs = Int(2500 / (1.0 - test_size))
+counterfactual_data, test_data = train_test_split(load_moons(n_obs); test_size=test_size)
+run_experiment(
+    counterfactual_data, test_data; dataname="Moons",
+    epochs=500,
+    n_hidden=32,
+    activation=Flux.relu,
+    α=[1.0, 1.0, 1e-1],
+    sampling_batch_size=10,
+    sampling_steps=30,
+    λ₁=0.25,
+    λ₂=0.75,
+    λ₃=0.75,
+    opt=Flux.Optimise.Descent(0.05),
+    use_class_loss=false,
+)
\ No newline at end of file
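
Several of the scripts above pass `sampling_batch_size=nothing`. This is a sentinel: as the `setup.jl` hunk further below shows, `run_experiment` then falls back to the model's training batch size. A self-contained sketch of that fallback, with an illustrative `n_obs`:

    n_obs = 1000                                          # illustrative sample count
    batch_size = minimum([Int(round(n_obs / 10)), 128])   # as computed in setup.jl
    sampling_batch_size = nothing                         # sentinel, as in circles.jl and gmsc.jl
    sampling_batch_size = isnothing(sampling_batch_size) ? batch_size : sampling_batch_size
    # Here sampling_batch_size == 100, i.e. the training batch size is reused for sampling.
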
diff --git a/experiments/run_experiments.jl b/experiments/run_experiments.jl
index 7e9edcab..36726597 100644
--- a/experiments/run_experiments.jl
+++ b/experiments/run_experiments.jl
@@ -1,33 +1,38 @@
 include("setup.jl")
 
 # User inputs:
-if ENV("DATANAME") == "all"
+if ENV["DATANAME"] == "all"
     datanames = ["linearly_separable", "moons", "circles", "mnist", "gmsc"]
 else
-    datanames = [ENV("DATANAME")]
+    datanames = [ENV["DATANAME"]]
 end
 
 # Linearly Separable
 if "linearly_separable" in datanames
+    @info "Running linearly separable experiment."
     include("linearly_separable.jl")
 end
 
 # Moons
 if "moons" in datanames
+    @info "Running moons experiment."
     include("moons.jl")
 end
 
 # Circles
 if "circles" in datanames
+    @info "Running circles experiment."
     include("circles.jl")
 end
 
 # MNIST
 if "mnist" in datanames
+    @info "Running MNIST experiment."
     include("mnist.jl")
 end
 
 # GMSC
 if "gmsc" in datanames
+    @info "Running GMSC experiment."
     include("gmsc.jl")
 end
diff --git a/experiments/setup.jl b/experiments/setup.jl
index c641e286..0f5bde9f 100644
--- a/experiments/setup.jl
+++ b/experiments/setup.jl
@@ -11,6 +11,7 @@ test_size = 0.2
 
 # Artifacts:
 using LazyArtifacts
+@warn "Models were pre-trained on `julia-1.8.5` and may not work on other versions."
 artifact_path = joinpath(artifact"results-paper-submission-1.8.5", "results-paper-submission-1.8.5")
 pretrained_path = joinpath(artifact_path, "results")
 
@@ -44,6 +45,7 @@ function run_experiment(
     use_class_loss=false,
     use_variants=true,
     n_individuals=25,
+    generators=nothing,
 )
 
     # SETUP ----------
@@ -58,6 +60,7 @@ function run_experiment(
 
     # Model parameters:
     batch_size = minimum([Int(round(n_obs / 10)), 128])
+    sampling_batch_size = isnothing(sampling_batch_size) ? batch_size : sampling_batch_size
     _loss = Flux.Losses.crossentropy    # loss function
     _finaliser = Flux.softmax           # finaliser function
 
@@ -188,7 +191,9 @@ function run_experiment(
     CSV.write(joinpath(params_path, "$(save_name)_generator_params.csv"), generator_params)
 
     # Benchmark generators:
-    if use_variants
+    if !isnothing(generators)
+        generator_dict = generators
+    elseif use_variants
         generator_dict = Dict(
             "Wachter" => WachterGenerator(λ=λ₁, opt=opt),
             "REVISE" => REVISEGenerator(λ=λ₁, opt=opt),
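
The new `generators` keyword lets callers replace the default benchmark variants wholesale, as `experiments/mnist.jl` does above. A hypothetical minimal override, with illustrative generator choices and λ values (not taken from any of the experiments):

    custom_generators = Dict(
        "Wachter" => WachterGenerator(λ=0.1, opt=Flux.Optimise.Descent(0.01)),
        "ECCCo" => ECCCoGenerator(λ=[0.1, 0.25, 0.25], use_class_loss=true),
    )
    run_experiment(
        counterfactual_data, test_data; dataname="MNIST",
        generators=custom_generators,
    )
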