Skip to content

Commit

Permalink
added tvae option for gaussian denoising (#28)
Browse files Browse the repository at this point in the history
* added tvae option for gaussian denoising

* black pass

* ignore complexity of example demonstration

* yabp-yet another black pass

* [GAUDEN-EXAMPLE] Revise argument parser and GF visualization

* Use subparser to choose between model
* Visualize singleton means instead of last layer weights for TVAE

* [GAUDEN-EXAMPLE] Update readme

[no ci]

Co-authored-by: Jakob Drefs <drefs.jakob@gmail.com>
  • Loading branch information
mknull and jdrefs authored Sep 21, 2022
1 parent 3e0e901 commit d9ed912
Show file tree
Hide file tree
Showing 4 changed files with 128 additions and 62 deletions.
46 changes: 21 additions & 25 deletions examples/gaussian-denoising/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,25 +16,21 @@ The example additionally requires `ffmpeg`, `imageio`, and `tvutil` to be instal


## Get started
To start the experiment, run `python main.py`. To see possible options, run, e.g.:
To start the experiment, run `python main.py <model>` with `<model>` being one of `bsc` or `tvae`. To see possible options, run, e.g.,

```bash
$ python main.py -h
usage: main.py [-h] [--clean_image CLEAN_IMAGE] [--rescale RESCALE] [--noise_level NOISE_LEVEL]
[--patch_height PATCH_HEIGHT] [--patch_width PATCH_WIDTH] [--Ksize KSIZE]
[--selection {fitness,uniform}] [--crossover] [--no_parents NO_PARENTS]
[--no_children NO_CHILDREN] [--no_generations NO_GENERATIONS] [-H H] [--no_epochs NO_EPOCHS]
[--merge_every MERGE_EVERY] [--output_directory OUTPUT_DIRECTORY] [--viz_every VIZ_EVERY]
[--gif_framerate GIF_FRAMERATE]

Gaussian Denoising with BSC
$ python main.py bsc -h
usage: Gaussian Denoising bsc [-h] [--clean_image CLEAN_IMAGE] [--rescale RESCALE] [--noise_level NOISE_LEVEL] [--patch_height PATCH_HEIGHT]
[--patch_width PATCH_WIDTH] [--Ksize KSIZE] [--selection {fitness,uniform}] [--crossover] [--no_parents NO_PARENTS]
[--no_children NO_CHILDREN] [--no_generations NO_GENERATIONS] [--no_epochs NO_EPOCHS] [--merge_every MERGE_EVERY]
[--output_directory OUTPUT_DIRECTORY] [--viz_every VIZ_EVERY] [--gif_framerate GIF_FRAMERATE] [-H H]

optional arguments:
-h, --help show this help message and exit
--clean_image CLEAN_IMAGE
Full path to clean image (png, jpg, ... file) (default: ./img/house.png)
--rescale RESCALE If specified, the size of the clean image will be rescaled by this factor (only for
demonstration purposes to minimize computational effort) (default: 0.5)
--rescale RESCALE If specified, the size of the clean image will be rescaled by this factor (only for demonstration purposes to minimize computational
effort) (default: 0.5)
--noise_level NOISE_LEVEL
Standard deviation of the additive white Gaussian noise (default: 25)
--patch_height PATCH_HEIGHT
Expand All @@ -44,29 +40,26 @@ optional arguments:
--Ksize KSIZE Size of the K sets (i.e., S=|K|) (default: 50)
--selection {fitness,uniform}
Selection operator (default: fitness)
--crossover Whether to apply crossover. Must be False if no_children is specified (default:
False)
--crossover Whether to apply crossover. Must be False if no_children is specified (default: False)
--no_parents NO_PARENTS
Number of parental states to select per generation (default: 20)
--no_children NO_CHILDREN
Number of children to evolve per generation (default: 2)
--no_generations NO_GENERATIONS
Number of generations to evolve (default: 1)
-H H Number of generative fields to learn (dictionary size) (default: 32)
--no_epochs NO_EPOCHS
Number of epochs to train (default: 40)
--merge_every MERGE_EVERY
Generate reconstructed image by merging image patches every Xth epoch (will be set
equal to viz_every if not specified) (default: None)
Generate reconstructed image by merging image patches every Xth epoch (will be set equal to viz_every if not specified) (default:
None)
--output_directory OUTPUT_DIRECTORY
Directory to write H5 training output and visualizations to (will be
output/<TIMESTAMP> if not specified) (default: None)
Directory to write H5 training output and visualizations to (will be output/<TIMESTAMP> if not specified) (default: None)
--viz_every VIZ_EVERY
Create visualizations every Xth epoch. (default: 1)
--gif_framerate GIF_FRAMERATE
If specified, the training output will be additionally saved as animated gif. The
framerate is given in frames per second. If not specified, no gif will be produced.
(default: None)
If specified, the training output will be additionally saved as animated gif. The framerate is given in frames per second. If not
specified, no gif will be produced. (default: None)
-H H Number of generative fields to learn (dictionary size) (default: 32)
```
Expand All @@ -75,19 +68,22 @@ optional arguments:
For distributed execution on multiple CPU cores (requires MPI to be installed), run with `mpirun -n <n_proc> python main.py ...`, e.g.:
```bash
env TVO_MPI=1 mpirun -n 4 python main.py
env TVO_MPI=1 mpirun -n 4 python main.py bsc
```
To run on GPU (requires cudatoolkit to be installed), run, e.g.:
```bash
env TVO_GPU=0 python main.py
env TVO_GPU=0 python main.py bsc
```
# Note
The default hyperparameters in this examples are chosen s.t. examplary executions of the algorithm on a standard personal computer can be performed in short time. For full resolution images and improved performance, larger models and, in turn, larger compute ressources (on the order of hundreds of CPU cores) are required. For details, see [1].
The default hyperparameters in this examples are chosen s.t. examplary executions of the algorithm on a standard personal computer can be performed in short time. For full resolution images and improved performance, larger models and, in turn, larger compute ressources (on the order of hundreds of CPU cores) are required. For details, see [1,2].
# Reference
[1] Evolutionary Variational Optimization of Generative Models. Jakob Drefs, Enrico Guiraud, Jörg Lücke. _Journal of Machine Learning Research_ 23(21):1-51, 2022. [(online access)](https://www.jmlr.org/papers/v23/20-233.html)
[2] Direct Evolutionary Optimization of Variational Autoencoders With Binary Latents.
Jakob Drefs*, Enrico Guiraud*, Filippos Panagiotou, Jörg Lücke. In _Joint European Conference on Machine Learning and Knowledge Discovery in Databases_, accepted, 2022. *Joint first authorship.
68 changes: 43 additions & 25 deletions examples/gaussian-denoising/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

import tvo
from tvo.exp import EVOConfig, ExpConfig, Training
from tvo.models import BSC
from tvo.models import BSC, GaussianTVAE
from tvo.utils.parallel import pprint, broadcast, barrier, bcast_shape, gather_from_processes
from tvo.utils.param_init import init_W_data_mean, init_sigma2_default
from tvo.utils.model_protocols import Reconstructor
Expand All @@ -30,6 +30,7 @@
store_as_h5,
get_epochs_from_every,
eval_fn,
get_singleton_means,
)
from viz import Visualizer

Expand All @@ -38,7 +39,7 @@
dtype_device_kwargs = {"dtype": PRECISION, "device": DEVICE}


def gaussian_denoising_example():
def gaussian_denoising_example(): # noqa: C901

# initialize MPI (if executed with env TVO_MPI=...), otherwise pass
comm_rank = init_processes()[0]
Expand Down Expand Up @@ -85,27 +86,39 @@ def gaussian_denoising_example():
pprint("Initializing model")

# initialize model
W_init = (
init_W_data_mean(data=train_data, H=args.H, dtype=PRECISION, device=DEVICE).contiguous()
if comm_rank == 0
else to.zeros((D, args.H), dtype=PRECISION, device=DEVICE)
)
sigma2_init = (
init_sigma2_default(train_data, PRECISION, DEVICE)
if comm_rank == 0
else to.zeros((1), dtype=PRECISION, device=DEVICE)
)
barrier()
broadcast(W_init)
broadcast(sigma2_init)
model = BSC(
H=args.H,
D=D,
W_init=W_init,
sigma2_init=sigma2_init,
pies_init=to.full((args.H,), 2.0 / args.H, **dtype_device_kwargs),
precision=PRECISION,
)
if args.model == "bsc":
W_init = (
init_W_data_mean(data=train_data, H=args.H, dtype=PRECISION, device=DEVICE).contiguous()
if comm_rank == 0
else to.zeros((D, args.H), dtype=PRECISION, device=DEVICE)
)
sigma2_init = (
init_sigma2_default(train_data, PRECISION, DEVICE)
if comm_rank == 0
else to.zeros((1), dtype=PRECISION, device=DEVICE)
)
barrier()
broadcast(W_init)
broadcast(sigma2_init)

model = BSC(
H=args.H,
D=D,
W_init=W_init,
sigma2_init=sigma2_init,
pies_init=to.full((args.H,), 2.0 / args.H, **dtype_device_kwargs),
precision=PRECISION,
)
elif args.model == "tvae":
model = GaussianTVAE(
shape=[
D,
]
+ args.inner_net_shape,
min_lr=0.0001,
max_lr=0.01,
)

assert isinstance(model, Reconstructor)

pprint("Initializing experiment")
Expand Down Expand Up @@ -148,7 +161,7 @@ def gaussian_denoising_example():
noisy_image=noisy,
patch_size=(args.patch_height, patch_width),
sort_gfs=True,
ncol_gfs=4,
ncol_gfs=3,
gif_framerate=args.gif_framerate,
)
if comm_rank == 0
Expand Down Expand Up @@ -189,10 +202,15 @@ def gaussian_denoising_example():

# visualize epoch
if comm_rank == 0:
if args.model == "bsc":
gfs = model.theta["W"]
elif args.model == "tvae":
gfs = get_singleton_means(model.theta).T

visualizer.process_epoch(
epoch=epoch,
pies=model.theta["pies"],
gfs=model.theta["W"],
gfs=gfs,
rec=imgs["mean"] if merge else None,
)
barrier()
Expand Down
53 changes: 42 additions & 11 deletions examples/gaussian-denoising/params.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,14 +90,27 @@
)


experiment_parser = argparse.ArgumentParser(add_help=False)
experiment_parser.add_argument(
bsc_parser = argparse.ArgumentParser(add_help=False)
bsc_parser.add_argument(
"-H",
type=int,
help="Number of generative fields to learn (dictionary size)",
default=32,
)


tvae_parser = argparse.ArgumentParser(add_help=False)
tvae_parser.add_argument(
"--inner_net_shape",
nargs="+",
type=int,
help="Decoder shape (...,H1,H0) excluding number of observables. "
"Full network shape will be (patch_height*patch_width*no_channels,...,H1,H0)",
default=[25, 25],
)


experiment_parser = argparse.ArgumentParser(add_help=False)
experiment_parser.add_argument(
"--no_epochs",
type=int,
Expand Down Expand Up @@ -143,17 +156,35 @@


def get_args():
parser = argparse.ArgumentParser(
description="Gaussian Denoising with BSC",
parser = argparse.ArgumentParser(prog="Gaussian Denoising")
algo_parsers = parser.add_subparsers(help="Select model to train", dest="model", required=True)
comm_parents = [
awgn_parser,
patch_parser,
variational_parser,
experiment_parser,
output_parser,
viz_parser,
]

algo_parsers.add_parser(
"bsc",
help="Run experiment with BSC",
parents=comm_parents
+ [
bsc_parser,
],
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
parents=[
awgn_parser,
patch_parser,
variational_parser,
experiment_parser,
output_parser,
viz_parser,
)

algo_parsers.add_parser(
"tvae",
help="Run experiment with TVAE",
parents=comm_parents
+ [
tvae_parser,
],
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)

return parser.parse_args()
23 changes: 22 additions & 1 deletion examples/gaussian-denoising/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,9 @@
from skimage.metrics import peak_signal_noise_ratio
from typing import Dict, Union

from tvo import get_run_policy
from tvo import get_run_policy, get_device
from tvo.utils.parallel import init_processes as _init_processes
from tvo.models import GaussianTVAE


def init_processes() -> Tuple[int, int]:
Expand Down Expand Up @@ -125,3 +126,23 @@ def eval_fn(
data_range=data_range,
)
)


def get_singleton_means(theta: Dict[str, to.Tensor]) -> to.Tensor:
"""Initialize TVAE model with parameters `theta` and compute NN output for NN input vectors
corresponding to singleton states (only one active unit per unit vector).
:param theta: Dictionary with TVAE model parameters
:return: Decoded means
"""
n_layers = len(tuple(k for k in theta.keys() if k.startswith("W_")))
W = tuple(theta[f"W_{ind_layer}"].clone().detach() for ind_layer in range(n_layers))
b = tuple(theta[f"b_{ind_layer}"].clone().detach() for ind_layer in range(n_layers))
sigma2 = float(theta["sigma2"])
H0 = W[0].shape[0]
m = GaussianTVAE(W_init=W, b_init=b, sigma2_init=sigma2)
singletons = to.eye(H0).to(get_device())
means = m.forward(singletons).detach().cpu()
D = W[-1].shape[-1]
assert means.shape == (H0, D)
return means

0 comments on commit d9ed912

Please sign in to comment.