Skip to content

Commit

Permalink
Improve the error handling in the backward and forward convergence (#358
Browse files Browse the repository at this point in the history
)

* problem: In the backward and forward convergence, the initial set of points (which uses, for example, only the first 10% of the data) may contain so few data points that the overlap is poor, which gives a terrible statistical error.
   solution: If the statistical error is too bad, use the bootstrap error instead, see choderalab/pymbar#519
* new kwarg error_tol in convergence.forward_backward_convergence() to allow the user to specify an error tolerance; if error > error_tol then switch to using the bootstrap error
* Update CHANGES
* add test
  • Loading branch information
xiki-tempula authored May 21, 2024
1 parent 46cc83b commit c8fe7b8
Show file tree
Hide file tree
Showing 3 changed files with 83 additions and 35 deletions.
2 changes: 2 additions & 0 deletions CHANGES
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ Changes
`None` (start from all zeros) as this change provides a sizable speedup (PR #357)

Enhancements
- `forward_backward_convergence` uses the bootstrap error when the statistical error
is too large. (PR #358)
- `BAR` result is used as initial guess for `MBAR` estimator. (PR #357)
- `forward_backward_convergence` uses the result from the previous step as the initial guess for the next step. (PR #357)

Expand Down
107 changes: 72 additions & 35 deletions src/alchemlyb/convergence/convergence.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
"""Functions for assessing convergence of free energy estimates and raw data."""

from typing import Any, List, Tuple
from warnings import warn

import numpy as np
import pandas as pd
from loguru import logger
from sklearn.base import BaseEstimator

from .. import concat
from ..estimators import BAR, TI, MBAR, FEP_ESTIMATORS, TI_ESTIMATORS
Expand All @@ -13,7 +15,9 @@
# Lookup table from the (upper-case) estimator name to the estimator class,
# used to instantiate an estimator from its string name.
estimators_dispatch = {"BAR": BAR, "TI": TI, "MBAR": MBAR}


def forward_backward_convergence(df_list, estimator="MBAR", num=10, **kwargs):
def forward_backward_convergence(
df_list, estimator="MBAR", num=10, error_tol: float = 3, **kwargs
):
"""Forward and backward convergence of the free energy estimate.
Generate the free energy estimate as a function of time in both directions,
Expand All @@ -35,6 +39,12 @@ def forward_backward_convergence(df_list, estimator="MBAR", num=10, **kwargs):
Lower case input is also accepted until release 2.0.0.
num : int
The number of time points.
error_tol : float
The maximum error tolerated for analytic error. If the analytic error is
bigger than the error tolerance, the bootstrap error will be used.
.. versionadded:: 2.3.0
kwargs : dict
Keyword arguments to be passed to the estimator.
Expand Down Expand Up @@ -93,23 +103,11 @@ def forward_backward_convergence(df_list, estimator="MBAR", num=10, **kwargs):
sample = []
for data in df_list:
sample.append(data[: len(data) // num * i])
sample = concat(sample)
result = my_estimator.fit(sample)
if estimator == "MBAR":
my_estimator.initial_f_k = result.delta_f_.iloc[0, :]
forward_list.append(result.delta_f_.iloc[0, -1])
if estimator.lower() == "bar":
error = np.sqrt(
sum(
[
result.d_delta_f_.iloc[i, i + 1] ** 2
for i in range(len(result.d_delta_f_) - 1)
]
)
)
forward_error_list.append(error)
else:
forward_error_list.append(result.d_delta_f_.iloc[0, -1])
mean, error = _forward_backward_convergence_estimate(
sample, estimator, my_estimator, error_tol, **kwargs
)
forward_list.append(mean)
forward_error_list.append(error)
logger.info(
"{:.2f} +/- {:.2f} kT".format(forward_list[-1], forward_error_list[-1])
)
Expand All @@ -122,23 +120,11 @@ def forward_backward_convergence(df_list, estimator="MBAR", num=10, **kwargs):
sample = []
for data in df_list:
sample.append(data[-len(data) // num * i :])
sample = concat(sample)
result = my_estimator.fit(sample)
if estimator == "MBAR":
my_estimator.initial_f_k = result.delta_f_.iloc[0, :]
backward_list.append(result.delta_f_.iloc[0, -1])
if estimator.lower() == "bar":
error = np.sqrt(
sum(
[
result.d_delta_f_.iloc[i, i + 1] ** 2
for i in range(len(result.d_delta_f_) - 1)
]
)
)
backward_error_list.append(error)
else:
backward_error_list.append(result.d_delta_f_.iloc[0, -1])
mean, error = _forward_backward_convergence_estimate(
sample, estimator, my_estimator, error_tol, **kwargs
)
backward_list.append(mean)
backward_error_list.append(error)
logger.info(
"{:.2f} +/- {:.2f} kT".format(backward_list[-1], backward_error_list[-1])
)
Expand All @@ -156,6 +142,57 @@ def forward_backward_convergence(df_list, estimator="MBAR", num=10, **kwargs):
return convergence


def _forward_backward_convergence_estimate(
    sample_list: List[pd.DataFrame],
    estimator: str,
    my_estimator: BaseEstimator,
    error_tol: float,
    **kwargs: Any,
) -> Tuple[float, float]:
    """Use estimator to run the estimation and return the mean and error.

    Parameters
    ----------
    sample_list : list of pandas.DataFrame
        The samples; they are concatenated before being passed to the
        estimator.
    estimator : str
        Name of the estimator. The name is matched case-insensitively so
        that every branch below (and the dispatch lookup) agrees on which
        estimator is in use.
    my_estimator : BaseEstimator
        The estimator object that is fitted on the concatenated samples.
        For MBAR, its ``initial_f_k`` attribute is overwritten with the
        first row of the fitted ``delta_f_``.
    error_tol : float
        The error tolerance. If the MBAR error is larger than this value,
        a second MBAR estimator with bootstrapping enabled is fitted and
        its error is returned instead.
    kwargs : dict
        Keyword arguments forwarded to the bootstrap estimator. A truthy
        ``n_bootstraps`` given here is honoured (otherwise 50 is used);
        ``initial_f_k`` is always replaced by the result of the first fit.

    Returns
    -------
    mean : float
        ``delta_f_`` between the first and the last state.
    error : float
        ``d_delta_f_`` between the first and the last state.
    """
    # Normalise once: the original mixed `== "MBAR"`, `.lower() == ...` and a
    # raw dispatch lookup, which disagreed for lower-case names.
    estimator_name = estimator.upper()

    sample = concat(sample_list)
    result = my_estimator.fit(sample)
    if estimator_name == "MBAR":
        # Store the fitted free energies on the estimator as ``initial_f_k``.
        my_estimator.initial_f_k = result.delta_f_.iloc[0, :]
    mean = result.delta_f_.iloc[0, -1]

    if estimator_name == "BAR":
        # Combine the adjacent-state errors (first off-diagonal) in quadrature.
        d_delta_f = result.d_delta_f_
        error = np.sqrt(
            sum(d_delta_f.iloc[i, i + 1] ** 2 for i in range(len(d_delta_f) - 1))
        )
    else:
        error = result.d_delta_f_.iloc[0, -1]

    if estimator_name == "MBAR" and error > error_tol:
        logger.warning(
            f"Statistical Error ({error}) bigger than error tolerance ({error_tol}), use bootstrap error instead."
        )
        # Merge into a dict rather than mixing explicit keywords with
        # **kwargs, so a user-supplied `n_bootstraps` or `initial_f_k` does not
        # raise "got multiple values for keyword argument".
        bootstrap_kwargs = dict(kwargs)
        bootstrap_kwargs["initial_f_k"] = result.delta_f_.iloc[0, :]
        if not bootstrap_kwargs.get("n_bootstraps"):
            bootstrap_kwargs["n_bootstraps"] = 50
        bootstraps_estimator = estimators_dispatch[estimator_name](**bootstrap_kwargs)
        bootstraps_estimator.fit(sample)
        error = bootstraps_estimator.d_delta_f_.iloc[0, -1]

    return mean, error


def _cummean(vals, out_length):
"""The cumulative mean of an array.
Expand Down
9 changes: 9 additions & 0 deletions src/alchemlyb/tests/test_convergence.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,15 @@ def test_convergence_wrong_cases(gmx_benzene_Coulomb_u_nk):
forward_backward_convergence(gmx_benzene_Coulomb_u_nk, "mbar")


def test_convergence_bootstrap(gmx_benzene_Coulomb_u_nk, caplog):
    """Check that a tiny ``error_tol`` switches MBAR to the bootstrap error.

    The estimator is named with the canonical upper-case key ("MBAR"): the
    lower-case spelling is documented as accepted only until release 2.0.0
    and is not a key of ``estimators_dispatch``.
    """
    # Reference run with the default error tolerance.
    normal_c = forward_backward_convergence(gmx_benzene_Coulomb_u_nk, "MBAR", num=2)
    # A tolerance of 0.01 is expected to be exceeded, so the bootstrap error
    # should be used; the assertions below verify this.
    bootstrap_c = forward_backward_convergence(
        gmx_benzene_Coulomb_u_nk, "MBAR", error_tol=0.01, num=2
    )
    assert "use bootstrap error instead." in caplog.text
    assert (bootstrap_c["Forward_Error"] != normal_c["Forward_Error"]).all()


def test_convergence_method(gmx_benzene_Coulomb_u_nk):
convergence = forward_backward_convergence(
gmx_benzene_Coulomb_u_nk, "MBAR", num=2, method="adaptive"
Expand Down

0 comments on commit c8fe7b8

Please sign in to comment.