Commit: all polished

pat-alt committed May 24, 2023
1 parent a2eea56 commit 5ab6c75
Showing 7 changed files with 118 additions and 5 deletions.
4 changes: 2 additions & 2 deletions README.md
@@ -20,8 +20,6 @@ We used [Quarto](https://quarto.org/) notebooks for prototyping and running experiments:
 - [MNIST](notebooks/mnist.qmd)
 - [GMSC](notebooks/gmsc.qmd)
 
-Instead of looking at the notebooks directly, you may choose to browse the HTML book contained inside the `docs` folder. The book is automatically generated from the notebooks and includes all code chunks and their outputs. It is currently not possible to view the book online, but you can download the `docs/` folder and open the `index.html` file in your browser.
-
 ## Inspecting the Results
 
 All results have been carefully reported either in the paper itself or in the supplementary material. In addition, we have released our results as binary files. These will be made publicly available after the review process.
@@ -32,6 +30,8 @@ To reproduce the results, you need to install the package, which will automatically …
 
 However, provided that the package is indeed installed, you can reproduce the results by either running the experiments in the `experiments/` folder or using the notebooks listed above for a more interactive process.
 
+**Note**: All experiments were run on `julia-1.8.5`. Since pre-trained models were serialised on that version, they may not be compatible with newer versions of Julia.
+
 ### Command Line
 
 The `experiments/` folder contains separate Julia scripts for each dataset and a [run_experiments.jl](experiments/run_experiments.jl) that calls the individual scripts. You can either run these scripts inside a Julia session or just use the command line to execute them, as described below.
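The entry point `run_experiments.jl` dispatches on a `DATANAME` environment variable. A minimal shell sketch of how it might be invoked (the `--project` flag, repository-root working directory, and the default dataset chosen here are assumptions, not confirmed by this diff):

```shell
# Assumed invocation from the repository root:
#   DATANAME=all julia --project=. experiments/run_experiments.jl
#
# run_experiments.jl reads the DATANAME environment variable; this mirrors
# its dispatch logic in shell form:
DATANAME="${DATANAME:-circles}"
if [ "$DATANAME" = "all" ]; then
    datasets="linearly_separable moons circles mnist gmsc"
else
    datasets="$DATANAME"
fi
for d in $datasets; do
    echo "Running $d experiment."
done
```

Setting `DATANAME=all` runs every dataset in sequence; any single dataset name restricts the run to that script.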
14 changes: 14 additions & 0 deletions experiments/circles.jl
@@ -0,0 +1,14 @@
n_obs = Int(1000 / (1.0 - test_size))
counterfactual_data, test_data = train_test_split(load_circles(n_obs; noise=0.05, factor=0.5); test_size=test_size)
run_experiment(
    counterfactual_data, test_data; dataname="Circles",
    n_hidden=32,
    α=[1.0, 1.0, 1e-2],
    sampling_batch_size=nothing,
    sampling_steps=20,
    λ₁=0.25,
    λ₂=0.75,
    λ₃=0.75,
    opt=Flux.Optimise.Descent(0.01),
    use_class_loss=false,
)
21 changes: 21 additions & 0 deletions experiments/gmsc.jl
@@ -0,0 +1,21 @@
counterfactual_data, test_data = train_test_split(load_gmsc(nothing); test_size=test_size)
run_experiment(
    counterfactual_data, test_data; dataname="GMSC",
    n_hidden=128,
    activation=Flux.swish,
    builder=MLJFlux.@builder Flux.Chain(
        Dense(n_in, n_hidden, activation),
        Dense(n_hidden, n_hidden, activation),
        Dense(n_hidden, n_out),
    ),
    α=[1.0, 1.0, 1e-1],
    sampling_batch_size=nothing,
    sampling_steps=30,
    use_ensembling=true,
    λ₁=0.1,
    λ₂=0.5,
    λ₃=0.5,
    opt=Flux.Optimise.Descent(0.05),
    use_class_loss=false,
    use_variants=false,
)
52 changes: 52 additions & 0 deletions experiments/mnist.jl
@@ -0,0 +1,52 @@
function pre_process(x; noise::Float32=0.03f0)
    ϵ = Float32.(randn(size(x)) * noise)
    x += ϵ
    return x
end

# Training data:
n_obs = 10000
counterfactual_data = load_mnist(n_obs)
counterfactual_data.X = pre_process.(counterfactual_data.X)

# VAE (trained on full dataset):
using CounterfactualExplanations.Models: load_mnist_vae
vae = load_mnist_vae()
counterfactual_data.generative_model = vae

# Test data:
test_data = load_mnist_test()

# Generators:
eccco_generator = ECCCoGenerator(
    λ=[0.1, 0.25, 0.25],
    temp=0.1,
    opt=nothing,
    use_class_loss=true,
    nsamples=10,
    nmin=10,
)
Λ = eccco_generator.λ
generator_dict = Dict(
    "Wachter" => WachterGenerator(λ=Λ[1], opt=eccco_generator.opt),
    "REVISE" => REVISEGenerator(λ=Λ[1], opt=eccco_generator.opt),
    "Schut" => GreedyGenerator(η=2.0),
    "ECCCo" => eccco_generator,
)

# Run:
run_experiment(
    counterfactual_data, test_data; dataname="MNIST",
    n_hidden=128,
    activation=Flux.swish,
    builder=MLJFlux.@builder Flux.Chain(
        Dense(n_in, n_hidden, activation),
        Dense(n_hidden, n_out),
    ),
    𝒟x=Uniform(-1.0, 1.0),
    α=[1.0, 1.0, 1e-2],
    sampling_batch_size=10,
    sampling_steps=25,
    use_ensembling=true,
    generators=generator_dict,
)
16 changes: 16 additions & 0 deletions experiments/moons.jl
@@ -0,0 +1,16 @@
n_obs = Int(2500 / (1.0 - test_size))
counterfactual_data, test_data = train_test_split(load_moons(n_obs); test_size=test_size)
run_experiment(
    counterfactual_data, test_data; dataname="Moons",
    epochs=500,
    n_hidden=32,
    activation=Flux.relu,
    α=[1.0, 1.0, 1e-1],
    sampling_batch_size=10,
    sampling_steps=30,
    λ₁=0.25,
    λ₂=0.75,
    λ₃=0.75,
    opt=Flux.Optimise.Descent(0.05),
    use_class_loss=false,
)
9 changes: 7 additions & 2 deletions experiments/run_experiments.jl
@@ -1,33 +1,38 @@
 include("setup.jl")
 
 # User inputs:
-if ENV("DATANAME") == "all"
+if ENV["DATANAME"] == "all"
     datanames = ["linearly_separable", "moons", "circles", "mnist", "gmsc"]
 else
-    datanames = [ENV("DATANAME")]
+    datanames = [ENV["DATANAME"]]
 end
 
 # Linearly Separable
 if "linearly_separable" in datanames
+    @info "Running linearly separable experiment."
     include("linearly_separable.jl")
 end
 
 # Moons
 if "moons" in datanames
+    @info "Running moons experiment."
     include("moons.jl")
 end
 
 # Circles
 if "circles" in datanames
+    @info "Running circles experiment."
     include("circles.jl")
 end
 
 # MNIST
 if "mnist" in datanames
+    @info "Running MNIST experiment."
     include("mnist.jl")
 end
 
 # GMSC
 if "gmsc" in datanames
+    @info "Running GMSC experiment."
     include("gmsc.jl")
 end
7 changes: 6 additions & 1 deletion experiments/setup.jl
@@ -11,6 +11,7 @@ test_size = 0.2
 
 # Artifacts:
 using LazyArtifacts
+@warn "Models were pre-trained on `julia-1.8.5` and may not work on other versions."
 artifact_path = joinpath(artifact"results-paper-submission-1.8.5", "results-paper-submission-1.8.5")
 pretrained_path = joinpath(artifact_path, "results")
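The `@warn` added here mirrors the README's serialisation caveat: artifacts were produced on `julia-1.8.5`. A minimal shell sketch of a version guard one could put in front of the scripts (the `check_version` helper is an illustrative assumption, not part of the repository):

```shell
# Artifacts were serialised on julia-1.8.5, so guard against version drift.
pinned="1.8.5"
check_version() {
    # succeed only when the supplied version string matches the pinned one;
    # in practice the argument would come from: julia --version | awk '{print $3}'
    [ "$1" = "$pinned" ]
}
if check_version "1.8.5"; then
    echo "ok: Julia matches the pinned version ($pinned)"
fi
if ! check_version "1.9.0"; then
    echo "warning: artifacts were serialised on julia-$pinned"
fi
```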

@@ -44,6 +45,7 @@ function run_experiment(
     use_class_loss=false,
     use_variants=true,
     n_individuals=25,
+    generators=nothing,
 )

# SETUP ----------
@@ -58,6 +60,7 @@
 
     # Model parameters:
     batch_size = minimum([Int(round(n_obs / 10)), 128])
+    sampling_batch_size = isnothing(sampling_batch_size) ? batch_size : sampling_batch_size
     _loss = Flux.Losses.crossentropy # loss function
     _finaliser = Flux.softmax # finaliser function

@@ -188,7 +191,9 @@
     CSV.write(joinpath(params_path, "$(save_name)_generator_params.csv"), generator_params)
 
     # Benchmark generators:
-    if use_variants
+    if !isnothing(generators)
+        generator_dict = generators
+    elseif use_variants
         generator_dict = Dict(
             "Wachter" => WachterGenerator(λ=λ₁, opt=opt),
             "REVISE" => REVISEGenerator(λ=λ₁, opt=opt),
