Skip to content

Commit

Permalink
RWMH: Correct get-setAdaptationRange doc
Browse files Browse the repository at this point in the history
* Correct get-setAdaptationRange doc

for the RandomWalkMetropolisHastings class.

* Changes following Michaël's review

* Fix bug hiding expansion/shrink factor accessors

* Tests: RandomWalkMetropolisHastings adaptation

- setAdapatationPeriod
- setAdaptationRange
- setAdaptationExpansionFactor
- setAdaptationShrinkFactor
  • Loading branch information
josephmure committed Oct 20, 2023
1 parent b8886f3 commit 0bfaed5
Show file tree
Hide file tree
Showing 3 changed files with 112 additions and 14 deletions.
45 changes: 45 additions & 0 deletions lib/test/t_RandomWalkMetropolisHastings_std.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,51 @@ int main(int, char *[])
//std::cout << "acceptance rate=" << rwmh.getAcceptanceRate() << std::endl;
assert_almost_equal(rwmh.getAcceptanceRate(), 0.28, 0.1, 0.0); // Empirical acceptance rate observed when executing the code


// Trick RandomWalkMetropolisHastings into being a simple random walk
// with Uniform(-1, 1) step: every "proposal" is automatically accepted.
const SymbolicFunction logdensity("x", "1");
Interval support(1);
support.setFiniteLowerBound({false});
support.setFiniteUpperBound({false});
const Uniform proposal(-1.0, 1.0);
RandomWalkMetropolisHastings rw(logdensity, support, {0.0}, proposal);

// The acceptance rate is 1 in this trivial case,
// so every adaptation step will multiply the adaptation factor
// by the expansion factor.
rw.setAdaptationExpansionFactor(2.0);
rw.setAdaptationPeriod(10);
rw.getSample(100);
assert_almost_equal(rw.getAdaptationFactor(), 1024.0, 0.0, 0.0);

// Check that the adaptation factor is really taken into account.
// We lengthen the adaptation period to get a longer period witout adaptation.
// We then compare the standard deviation of the step lengths with
// their theoretical standard deviation considering the 1024 adaptation factor.
rw.setAdaptationPeriod(100);
const Sample constantAdapationFactorSample(rw.getSample(99));
Indices notTaken(1);
const Indices up(notTaken.complement(99)); // [1, 2, ..., 98]
notTaken[0] = 98;
const Indices down(notTaken.complement(99)); // [0, 1, ..., 97]
const Sample steps(constantAdapationFactorSample.select(up) - constantAdapationFactorSample.select(down));
const Point ref_std = Uniform(-1024.0, 1024.0).getStandardDeviation();
assert_almost_equal(steps.computeStandardDeviation(), ref_std, 0.1, 0.0);

// At the next realization, once again the adaptation factor is multiplied by 2.
rw.getRealization();

// We now change the adaptation range
// to an interval with lower bound larger than 1 (the acceptance rate)
// This way, every adaptation step will multiply the adaptation factor
// by the shrink factor.
rw.setAdaptationRange(Interval(1.1, 1.2));
rw.setAdaptationPeriod(10);
rw.setAdaptationShrinkFactor(0.5);
const Sample decreasing_step_sample(rw.getSample(100));
assert_almost_equal(rw.getAdaptationFactor(), 2.0, 0.0, 0.0);

}
catch (TestFailed & ex)
{
Expand Down
40 changes: 27 additions & 13 deletions python/src/RandomWalkMetropolisHastings_doc.i.in
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ marginalIndices : sequence of int, optional

Notes
-----
The random walk Metropolis-Hastings algorithm
The random walk Metropolis-Hastings algorithm
is a Markov Chain Monte-Carlo algorithm.
It draws candidates for the
next state of the chain as follows: denoting the current state by
Expand Down Expand Up @@ -143,8 +143,8 @@ period : positive int

// ---------------------------------------------------------------------

%feature("docstring") OT::RandomWalkMetropolisHastings::getExpansionFactor
"Get the expansion factor.
%feature("docstring") OT::RandomWalkMetropolisHastings::getAdaptationExpansionFactor
"Get the adaptation expansion factor.

Returns
-------
Expand All @@ -153,8 +153,8 @@ expansionFactor : float

// ---------------------------------------------------------------------

%feature("docstring") OT::RandomWalkMetropolisHastings::setExpansionFactor
"Set the expansion factor.
%feature("docstring") OT::RandomWalkMetropolisHastings::setAdaptationExpansionFactor
"Set the adaptation expansion factor.

Parameters
----------
Expand All @@ -164,27 +164,41 @@ expansionFactor : float, :math:`e > 1`
// ---------------------------------------------------------------------

%feature("docstring") OT::RandomWalkMetropolisHastings::getAdaptationRange
"Get the range.
"Get the expected range for the acceptance rate.
During burn-in, at the end of every adaptation period,
if the acceptance rate does not belong to this range,
the adaptation factor is multiplied
either by the expansion factor
(if the acceptance rate is larger than the *upperBound*)
or by the shrink factor
(if the acceptance rate is smaller than the *lowerBound*).

Returns
-------
range : :class:`~openturns.Interval` of dimension 1
Range :math:`[m,M]` of the adaptation factor."
Range [*lowerBound*, *upperBound*] of the expected acceptance rate."

// ---------------------------------------------------------------------

%feature("docstring") OT::RandomWalkMetropolisHastings::setAdaptationRange
"Set the range.
"Set the expected range for the acceptance rate.
During burn-in, at the end of every adaptation period,
if the acceptance rate does not belong to this range,
the adaptation factor is multiplied
either by the expansion factor
(if the acceptance rate is larger than the *upperBound*)
or by the shrink factor
(if the acceptance rate is smaller than the *lowerBound*).

Parameters
----------
range : :class:`~openturns.Interval` of dimension 1
Range :math:`[m,M]` of the adaptation factor."
Range [*lowerBound*, *upperBound*] of the expected acceptance rate."

// ---------------------------------------------------------------------

%feature("docstring") OT::RandomWalkMetropolisHastings::getShrinkFactor
"Get the shrink factor.
%feature("docstring") OT::RandomWalkMetropolisHastings::getAdaptationShrinkFactor
"Get the adaptation shrink factor.

Returns
-------
Expand All @@ -193,8 +207,8 @@ shrinkFactor : float

// ---------------------------------------------------------------------

%feature("docstring") OT::RandomWalkMetropolisHastings::setShrinkFactor
"Set the shrink factor.
%feature("docstring") OT::RandomWalkMetropolisHastings::setAdaptationShrinkFactor
"Set the adaptation shrink factor.

Parameters
----------
Expand Down
41 changes: 40 additions & 1 deletion python/test/t_RandomWalkMetropolisHastings_std.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import openturns as ot
import openturns.testing as ott
from math import exp
from math import exp, inf

ot.TESTPREAMBLE()

Expand Down Expand Up @@ -162,3 +162,42 @@ def post_den(alpha_beta):
mh.setLikelihood(likelihood, obs)
sampler = ot.Gibbs(mh_coll)
parameters_sample = sampler.getSample(2000)


# Trick RandomWalkMetropolisHastings into being a simple random walk
# with Uniform(-1, 1) step: every "proposal" is automatically accepted.
logdensity = ot.SymbolicFunction("x", "1")
support = ot.Interval([-inf], [inf])
proposal = ot.Uniform(-1.0, 1.0)
rw = ot.RandomWalkMetropolisHastings(logdensity, support, [0.0], proposal)

# The acceptance rate is 1 in this trivial case,
# so every adaptation step will multiply the adaptation factor
# by the expansion factor.
rw.setAdaptationExpansionFactor(2.0)
rw.setAdaptationPeriod(10)
rw.getSample(100)
ott.assert_almost_equal(rw.getAdaptationFactor(), 2.0**10, 0.0, 0.0)

# Check that the adaptation factor is really taken into account.
# We lengthen the adaptation period to get a longer period witout adaptation.
# We then compare the standard deviation of the step lengths with
# their theoretical standard deviation considering the 1024 adaptation factor.
rw.setAdaptationPeriod(100)
constantAdapationFactorSample = rw.getSample(99)
steps = constantAdapationFactorSample[1:] - constantAdapationFactorSample[:-1]
ref_std = ot.Uniform(-(2.0**10), 2.0**10).getStandardDeviation()
ott.assert_almost_equal(steps.computeStandardDeviation(), ref_std, 0.1, 0.0)

# At the next realization, once again the adaptation factor is multiplied by 2.
rw.getRealization()

# We now change the adaptation range
# to an interval with lower bound larger than 1 (the acceptance rate)
# This way, every adaptation step will multiply the adaptation factor
# by the shrink factor.
rw.setAdaptationRange(ot.Interval(1.1, 1.2))
rw.setAdaptationPeriod(10)
rw.setAdaptationShrinkFactor(0.5)
rw.getSample(100)
ott.assert_almost_equal(rw.getAdaptationFactor(), 2.0, 0.0, 0.0)

0 comments on commit 0bfaed5

Please sign in to comment.