A Python script to evaluate and plot the discretized Wasserstein-2 gradient flow, starting at an empirical measure, of a Maximum-Mean-Discrepancy-regularized f-divergence functional whose target is also an empirical measure.
This repository provides the method `MMD_reg_f_div_flow` (from the file `MMD_reg_fDiv_ParticleFlows.py`) used to produce the numerical experiments for the paper Wasserstein Gradient Flows for Moreau Envelopes of f-Divergences in Reproducing Kernel Hilbert Spaces by Sebastian Neumayer, Viktor Stein, Gabriele Steidl and Nicolaj Rux.
If you use this code please cite this preprint, preferably like this:
@unpublished{NSSR24,
author = {Neumayer, Sebastian and Stein, Viktor and Steidl, Gabriele and Rux, Nicolaj},
title = {Wasserstein Gradient Flows for {M}oreau Envelopes of $f$-Divergences in Reproducing Kernel {H}ilbert Spaces},
note = {ArXiv preprint},
volume = {arXiv:2402.04613},
year = {2024},
month = {Feb},
url = {https://arxiv.org/abs/2402.04613},
doi = {10.48550/arXiv.2402.04613}
}
The other Python files contain auxiliary functions.
Scripts to exactly reproduce the figures in the preprint will follow soon. An example file is `AlphaComparison.py`.
This code is written and maintained by Viktor Stein. Any comments, feedback, questions and bug reports are welcome, either sent to the author directly or posted on the GitHub issue tracker.
- Required packages
- Options of the main method
- Supported kernels
- Supported $f$-divergences / entropy functions
- Supported targets
This script requires the following Python packages. We tested the code with Python 3.11.7 and the following package versions:
- torch 2.1.2
- scipy 1.12.0
- numpy 1.26.3
- pillow 10.2.0 (if you want to generate a gif of the evolution of the flow)
- matplotlib 3.8.2
- pot 0.9.3 (if you want to evaluate the exact Wasserstein-2 loss along the flow)
- scikit-learn 1.4.1.post1 (for more targets, via sklearn.datasets)
- https://github.com/gmgeorg/torchlambertw/
The code is usually also compatible with somewhat earlier or later versions of these packages.
Parameter | Type | Explanation |
---|---|---|
a | float | divergence parameter |
s | float | kernel parameter > 0 |
N | int | number of prior particles |
M | int | number of target particles |
lambd | float | regularization parameter > 0 |
step_size | float | step size for Euler forward discretization |
max_time | float | maximal time horizon for simulation |
plot | boolean | decide whether to plot particles along the evolution |
arrows | boolean | decide whether to plot arrows at particles to show their gradients |
timeline | boolean | decide whether to plot timeline of functional value along the flow |
kern | function | kernel (see below) |
primal | boolean | decide whether to solve the primal problem |
dual | boolean | decide whether to solve the dual problem |
div | class entr_fnc | entropy function |
target_name | string | name of the target measure nu |
verbose | boolean | decide whether to print warnings and information |
compute_W2 | boolean | decide whether to compute W2 dist of particles to target along flow |
save_opts | boolean | decide whether to save minimizers and gradients along the flow |
st | int | random state for reproducibility |
annealing | boolean | decide whether to use the annealing heuristic |
annealing_factor | int | factor by which to divide lambda |
tight | boolean | decide whether to use the tight variational formulation |
line_search | string | step size choice for the exponential GD for the tight formulation, either 'const', 'armijo', 'Polyak' or 'two_way' |
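
To make the roles of `step_size` and `max_time` concrete, here is a minimal NumPy sketch of a forward Euler particle flow. It is not the repository's `MMD_reg_f_div_flow`: as a simplifying assumption it follows the flow of the plain functional ½ MMD² with a Gaussian kernel rather than the MMD-regularized f-divergence. The helper names (`gauss_gram`, `mmd_sq`, `witness_grad`, `euler_particle_flow`) and the way `s` enters the kernel are hypothetical.

```python
import numpy as np


def gauss_gram(x, y, s):
    """Gram matrix of exp(-|x - y|^2 / (2 s)); this parametrization is an assumption."""
    diff = x[:, None, :] - y[None, :, :]
    return np.exp(-np.sum(diff**2, axis=-1) / (2.0 * s))


def mmd_sq(x, y, s):
    """Squared MMD between the empirical measures on x and y (biased V-statistic)."""
    return (gauss_gram(x, x, s).mean() + gauss_gram(y, y, s).mean()
            - 2.0 * gauss_gram(x, y, s).mean())


def witness_grad(x, y, s):
    """Gradient of mean_j k(., x_j) - mean_j k(., y_j), evaluated at each particle x_i."""
    def mean_kernel_grad(z):
        diff = x[:, None, :] - z[None, :, :]                # shape (N, len(z), d)
        k = np.exp(-np.sum(diff**2, axis=-1) / (2.0 * s))   # shape (N, len(z))
        return (-(diff / s) * k[:, :, None]).mean(axis=1)   # shape (N, d)
    return mean_kernel_grad(x) - mean_kernel_grad(y)


def euler_particle_flow(x0, y, s=1.0, step_size=0.1, max_time=10.0):
    """Forward Euler: max_time / step_size explicit steps along the negative gradient."""
    n_steps = int(round(max_time / step_size))
    x = x0.copy()
    for _ in range(n_steps):
        x = x - step_size * witness_grad(x, y, s)
    return x


st = 0                                   # random state, as in the option `st`
rng = np.random.default_rng(st)
N, M = 40, 40                            # numbers of prior and target particles
prior = 0.3 * rng.standard_normal((N, 2))
target = 0.3 * rng.standard_normal((M, 2)) + np.array([1.0, 1.0])
final = euler_particle_flow(prior, target, s=1.0, step_size=0.1, max_time=20.0)
print(mmd_sq(prior, target, 1.0), mmd_sq(final, target, 1.0))
```

With these settings the loop runs 200 Euler steps. Halving `step_size` at fixed `max_time` doubles the number of steps, which is the trade-off the two options control.
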
The following kernels are all radial and twice differentiable and hence fulfill all assumptions in the paper.
We denote the reLU by
Kernel | Name | Expression |
---|---|---|
inverse multiquadric | imq | |
Gauss | gauss | |
Matérn- | matern | |
Matérn- | matern2 | |
| | compact | |
Another Spline | compact2 | |
inverse log | inv_log | |
inverse quadric | inv_quad | |
Student t | student | |
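
As an illustration of what a radial kernel with parameter `s` looks like in code, the following sketch implements two entries of the table in NumPy. The function names reuse the table's `gauss` and `imq`, but the formulas are common textbook conventions. How `s` enters the repository's kernels is an assumption here, and `sq_dists` is a hypothetical helper.

```python
import numpy as np


def sq_dists(x, y):
    """Pairwise squared Euclidean distances between the rows of x and y."""
    diff = x[:, None, :] - y[None, :, :]
    return np.sum(diff**2, axis=-1)


def gauss(x, y, s):
    """Gaussian kernel exp(-|x - y|^2 / (2 s)) (assumed parametrization)."""
    return np.exp(-sq_dists(x, y) / (2.0 * s))


def imq(x, y, s):
    """Inverse multiquadric kernel (s + |x - y|^2)^(-1/2) (assumed parametrization)."""
    return 1.0 / np.sqrt(s + sq_dists(x, y))


rng = np.random.default_rng(0)
pts = rng.standard_normal((5, 2))
K_gauss = gauss(pts, pts, s=1.0)
K_imq = imq(pts, pts, s=2.0)
# Both Gram matrices are symmetric and positive semidefinite.
print(np.linalg.eigvalsh(K_gauss).min(), np.linalg.eigvalsh(K_imq).min())
```

Both functions depend on their arguments only through the distance between the points, which is what "radial" means.
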
I also implemented the following two "$W_2$-metrizing kernels", which metrize the Wasserstein-2 distance on
Kernel | Name | Expression |
---|---|---|
| | W2_1 | |
| | W2_2 | |
The following entropy functions each have an infinite recession constant if
Entropy | Name | Expression |
---|---|---|
Kullback-Leibler | tsa , | |
Tsallis- | tsa | |
Jeffreys | jeffreys | |
| | chi_entr | |
Below we list some other implemented entropy functions with a finite recession constant. For even more entropy functions we refer to Table 1 in the above-mentioned preprint.
Entropy | Name | Expression |
---|---|---|
Burg | reverse_kl | |
Jensen-Shannon | jensen_shannon | |
total variation | tv | |
Matusita | matusita | |
Kafka | kafka | |
Marton | marton | |
perimeter | per | |
equality indicator | equality_indicator | |
zero | zero | |
Lindsay | lind | |
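
The difference between a finite and an infinite recession constant can be checked numerically. The sketch below uses the standard textbook forms of the Kullback-Leibler, Tsallis and total variation entropy functions. That these coincide with the repository's `tsa` and `tv` classes is an assumption, and the names `kl`, `tsallis`, `tv` and `recession_ratio` are hypothetical. The recession constant is the limit of f(x)/x as x grows.

```python
import numpy as np


def kl(x):
    """Kullback-Leibler entropy function x log x - x + 1, with 0 log 0 := 0."""
    x = np.asarray(x, dtype=float)
    safe = np.where(x > 0, x, 1.0)          # avoids log(0) warnings
    return np.where(x > 0, x * np.log(safe), 0.0) - x + 1.0


def tsallis(x, a):
    """Tsallis entropy function (x^a - a x + a - 1) / (a - 1) for a != 1."""
    x = np.asarray(x, dtype=float)
    return (x**a - a * x + a - 1.0) / (a - 1.0)


def tv(x):
    """Total variation entropy function |x - 1|."""
    return np.abs(np.asarray(x, dtype=float) - 1.0)


def recession_ratio(f, x_large):
    """f(x) / x at a large x, a crude numerical proxy for the recession constant."""
    return f(x_large) / x_large


xs = np.array([0.5, 1.0, 2.0, 3.0])
# Tsallis with a close to 1 approaches Kullback-Leibler.
print(np.max(np.abs(tsallis(xs, 1.0 + 1e-6) - kl(xs))))
# The ratio stabilizes near 1 for tv but keeps growing for kl.
print(recession_ratio(tv, 1e4), recession_ratio(tv, 1e8))
print(recession_ratio(kl, 1e4), recession_ratio(kl, 1e8))
```

The first printed value is consistent with the table listing Kullback-Leibler under the name `tsa`: it is the limiting case of the Tsallis family.
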
- `bananas`: the two parabolas in the gif at the top
- `circles`: three circles
- `cross`: four versions of Neal's funnel arranged in a cross shape
- `GMM`: two exactly equal Gaussians which have a symmetry axis at $y = -x$
- `four_wells`: a sum of four Gaussians, which don't have a symmetry axis. The initial measure is initialized at one of the Gaussians.
- `swiss_role_2d`:

We also include some target measures from sklearn.datasets: `moons`, `annulus` and the three-dimensional data sets `swiss_role_3d` and `s_curve`.
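
For orientation, this is roughly how point clouds of this kind can be drawn with scikit-learn's sample generators. The sketch only shows the library calls. Which generators, noise levels and rescalings the repository uses for each target name is not stated above, so the values of `M`, `st` and `noise` below are assumptions.

```python
from sklearn.datasets import make_moons, make_s_curve, make_swiss_roll

M = 200   # number of target particles (assumed value)
st = 42   # random state for reproducibility (assumed value)

# Each generator returns the samples and a second array (labels or positions).
moons, _ = make_moons(n_samples=M, noise=0.05, random_state=st)
s_curve, _ = make_s_curve(n_samples=M, noise=0.05, random_state=st)
swiss_roll_3d, _ = make_swiss_roll(n_samples=M, noise=0.05, random_state=st)

print(moons.shape, s_curve.shape, swiss_roll_3d.shape)
```

Fixing `random_state` makes the sampled target identical across runs, which matters when comparing flows for different parameters.
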
Copyright (c) 2024 Viktor Stein
This software was written by Viktor Stein. It was developed at the Institute of Mathematics, TU Berlin. The author acknowledges support by the German Research Foundation within the project VI screen.
This is free software. You can redistribute it and/or modify it under the terms of the GNU General Public License as published by the Free Software Foundation; either version 2 of the License, or (at your option) any later version. If not stated otherwise, this applies to all files contained in this package and its sub-directories.
This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for more details.