Skip to content

Commit

Permalink
Merge branch 'refactor_mcmc_samplers' into v1-dev
Browse files Browse the repository at this point in the history
  • Loading branch information
guilgautier committed Dec 1, 2021
2 parents f57f8f9 + 3dc9a32 commit 3c1d230
Show file tree
Hide file tree
Showing 8 changed files with 1,065 additions and 686 deletions.
93 changes: 32 additions & 61 deletions dppy/finite/dpp.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,12 @@
from dppy.finite.exact_samplers.vfx_samplers import vfx_sampler_dpp, vfx_sampler_k_dpp

# MCMC
from dppy.mcmc_sampling import dpp_sampler_mcmc, zonotope_sampler
from dppy.finite.mcmc_samplers.add_delete_sampler import add_delete_sampler
from dppy.finite.mcmc_samplers.add_exchange_delete_sampler import (
add_exchange_delete_sampler,
)
from dppy.finite.mcmc_samplers.exchange_sampler import exchange_sampler
from dppy.finite.mcmc_samplers.zonotope_sampler import zonotope_sampler

# UTILS
from dppy.utils import (
Expand Down Expand Up @@ -447,31 +452,31 @@ def _select_sampler_exact_k_dpp(method):
default = samplers["spectral"]
return samplers.get(method.lower(), default)

def sample_mcmc(self, mode, **params):
def sample_mcmc(self, method="aed", random_state=None, **params):
"""Run a MCMC with stationary distribution the corresponding :class:`FiniteDPP <FiniteDPP>` object.
:param string mode:
:param string method:
- ``"AED"`` Add-Exchange-Delete
- ``"AD"`` Add-Delete
- ``"E"`` Exchange
- ``"aed"`` add-exchange-delete
- ``"ad"`` add-delete
- ``"e"`` exchange
- ``"zonotope"`` Zonotope sampling
:param dict params:
Dictionary containing the parameters for MCMC samplers with keys
``"random_state"`` (default None)
- If ``mode="AED","AD","E"``
- If ``method="aed","ad","e"``
+ ``"s_init"`` (default None) Starting state of the Markov chain
+ ``"nb_iter"`` (default 10) Number of iterations of the chain
+ ``"T_max"`` (default None) Time horizon
+ ``"size"`` (default None) Size of the initial sample for ``mode="AD"/"E"``
+ ``"size"`` (default None) Size of the initial sample for ``method="AD"/"E"``
* :math:`\\operatorname{rank}(\\mathbf{K})=\\operatorname{trace}(\\mathbf{K})` for projection :math:`\\mathbf{K}` (correlation) kernel and ``mode="E"``
* :math:`\\operatorname{rank}(\\mathbf{K})=\\operatorname{trace}(\\mathbf{K})` for projection :math:`\\mathbf{K}` (correlation) kernel and ``method="E"``
- If ``mode="zonotope"``:
- If ``method="zonotope"``:
+ ``"lin_obj"`` linear objective in main optimization problem (default np.random.randn(N))
+ ``"x_0"`` initial point in zonotope (default A*u, u~U[0,1]^n)
Expand Down Expand Up @@ -499,56 +504,26 @@ def sample_mcmc(self, mode, **params):
- :py:meth:`~FiniteDPP.flush_samples`
"""

self.sampling_mode = mode

if self.sampling_mode == "zonotope":
if self.A_zono is not None:
chain = zonotope_sampler(self.A_zono, **params)
else:
err_print = [
"Invalid `mode=zonotope` parameter",
"DPP must be defined via `A_zono`",
"Given: {}".format(self.params_keys),
]
raise ValueError(" ".join(err_print))

elif self.sampling_mode == "E":
if self.kernel_type == "correlation" and self.projection:
self.compute_K()
size = params.get("size", None)
rank = np.rint(np.trace(self.K)).astype(int)
# |sample| = Tr(K) = rank(K) a.s. for projection DPP(K)
if size == rank:
chain = dpp_sampler_mcmc(self.K, self.sampling_mode, **params)
else:
raise ValueError(
"size={} != rank={} for projection correlation K kernel".format(
size, rank
)
)
else:
self.compute_L()
chain = dpp_sampler_mcmc(self.L, self.sampling_mode, **params)

elif self.sampling_mode in ("AED", "AD"):
self.compute_L()
chain = dpp_sampler_mcmc(self.L, self.sampling_mode, **params)

else:
err_print = [
"Invalid `mode` parameter, choose among:",
"- `AED`: Add-Exchange-Delete",
"- `AD`: Add-Delete",
"- `E`: Exchange",
"- `zonotope`: projection correlation kernel only",
"Given: {}".format(self.sampling_mode),
]
raise ValueError("\n".join(err_print))
rng = check_random_state(random_state)
sampler = self._select_sampler_mcmc_dpp(method)
chain = sampler(self, rng, **params)

self.list_of_samples.append(chain)
self.sampling_mode = method
return chain[-1]

def sample_mcmc_k_dpp(self, size, mode="E", **params):
@staticmethod
def _select_sampler_mcmc_dpp(method):
samplers = {
"aed": add_exchange_delete_sampler,
"ad": add_delete_sampler,
"e": exchange_sampler,
"zonotope": zonotope_sampler,
}
default = samplers["aed"]
return samplers.get(method.lower(), default)

def sample_mcmc_k_dpp(self, size, method="e", random_state=None, **params):
"""Calls :py:meth:`~sample_mcmc` with ``mode="E"`` and ``params["size"] = size``
.. seealso::
Expand All @@ -558,13 +533,9 @@ def sample_mcmc_k_dpp(self, size, mode="E", **params):
- :py:meth:`~FiniteDPP.sample_exact_k_dpp`
- :py:meth:`~FiniteDPP.flush_samples`
"""

self.sampling_mode = "E"

self.size_k_dpp = size
params["size"] = size

return self.sample_mcmc(self.sampling_mode, **params)
return self.sample_mcmc(method="e", random_state=None, **params)

def compute_K(self):
"""Alias of :py:meth:`~dppy.finite.dpp.FiniteDPP.compute_correlation_kernel`"""
Expand Down
190 changes: 190 additions & 0 deletions dppy/finite/mcmc_samplers/add_delete_sampler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,190 @@
import time

from dppy.utils import check_random_state, det_ST


def add_delete_sampler(dpp, random_state=None, **params):
dpp.compute_L()
kernel = dpp.L
rng = check_random_state(random_state)
s0 = params.pop("s_init", None)
if s0 is None:
s0 = initialize_add_delete_sampler(kernel, rng, **params)
return add_delete_sampler_core(kernel, s0, rng, **params)


def add_delete_sampler_core(
kernel, s_init, random_state=None, nb_iter=10, T_max=None, **kwargs
):
"""MCMC sampler for generic DPP(kernel), it performs local moves by removing/adding one element at a time.
:param kernel:
Kernel matrix
:type kernel:
array_like
:param s_init:
Initial sample.
:type s_init:
list
:param nb_iter:
Maximum number of iterations performed by the the algorithm.
Default is 10.
:type nb_iter:
int
:param T_max:
Maximum running time of the algorithm (in seconds).
Default is None.
:type T_max:
float
:param random_state:
:type random_state:
None, np.random, int, np.random.RandomState
:return:
list of `nb_iter` approximate samples of DPP(kernel)
:rtype:
array_like
.. seealso::
Algorithm 1 in :cite:`LiJeSr16c`
"""
rng = check_random_state(random_state)

# Initialization
N = kernel.shape[0] # Number of elements

# Initialization
S0, det_S0 = s_init, det_ST(kernel, s_init)
chain = [S0] # Initialize the collection (list) of sample

# Evaluate running time...
t_start = time.time() if T_max else 0

for _ in range(nb_iter):

# With proba 1/2 try to add/delete an element
if rng.rand() < 0.5:

# Perform the potential add/delete move S1 = S0 +/- s
S1 = S0.copy() # S1 = S0
s = rng.choice(N) # Uniform item in [N]
if s in S1:
S1.remove(s) # S1 = S0 - s
else:
S1.append(s) # S1 = SO + s

# Accept_reject the move
det_S1 = det_ST(kernel, S1) # det K_S1
if rng.rand() < det_S1 / det_S0:
S0, det_S0 = S1, det_S1

chain.append(S0)

if T_max and (time.time() - t_start < T_max):
break

return chain


def add_delete_sampler_refactored(
kernel, s_init, nb_iter=10, T_max=None, random_state=None
):
"""MCMC sampler for generic DPP(kernel), it performs local moves by removing/adding one element at a time.
:param kernel:
Kernel matrix
:type kernel:
array_like
:param s_init:
Initial sample.
:type s_init:
list
:param nb_iter:
Maximum number of iterations performed by the the algorithm.
Default is 10.
:type nb_iter:
int
:param T_max:
Maximum running time of the algorithm (in seconds).
Default is None.
:type T_max:
float
:param random_state:
:type random_state:
None, np.random, int, np.random.RandomState
:return:
list of `nb_iter` approximate samples of DPP(kernel)
:rtype:
list of lists
.. seealso::
Algorithm 1 in :cite:`LiJeSr16c`
"""

# Initialization
rng = check_random_state(random_state)

N = kernel.shape[0]
items = s_init + [i for i in range(N) if i not in s_init]

det_S0, size, add_or_del = det_ST(kernel, s_init), len(s_init), 0
chain = [s_init]

t_start = time.time() if T_max else 0

for _ in range(nb_iter):

# With proba 1/2 try to add/delete an element
if rng.rand() < 0.5:

s = rng.randint(0, N) # Uniform item in [N]
if s >= size: # S += s
items[s], items[size] = items[size], items[s]
add_or_del = 1
else: # S -= s
items[s], items[size - 1] = items[size - 1], items[s]
add_or_del = -1

# Accept_reject the move
det_S1 = det_ST(kernel, items[: size + add_or_del])
if rng.rand() < det_S1 / det_S0:
det_S0 = det_S1
size += add_or_del

chain.append(items[:size])

if T_max and (time.time() - t_start < T_max):
break

return chain


def initialize_add_delete_sampler(
kernel, random_state=None, size=None, nb_trials=100, tol=1e-9, **kwargs
):
rng = check_random_state(random_state)
N = kernel.shape[0]

for _ in range(nb_trials):
_size = rng.randint(1, N + 1) if size is None else size
S0 = rng.choice(N, size=_size, replace=False)
det_S0 = det_ST(kernel, S0)
if det_S0 > tol:
return S0.tolist()

raise ValueError(
"Failed to initialize add-delete sampler. After {} random trials, no initial set S0 satisfies det L_S0 > {}. If you are sampling from a k-DPP, make sure size k <= rank(L). You may consider passing your own initial state s_init.".format(
nb_trials, tol
)
)
Loading

0 comments on commit 3c1d230

Please sign in to comment.