Skip to content

Commit

Permalink
Merge pull request #346 from nspope/mom-fixup
Browse files Browse the repository at this point in the history
Minor fixes for central moment matching
  • Loading branch information
hyanwong authored Dec 13, 2023
2 parents cbd3d68 + 529d1ca commit 50f53b4
Show file tree
Hide file tree
Showing 6 changed files with 176 additions and 53 deletions.
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ tqdm
daiquiri
msprime>=1.0.0
scipy
numba
numba>=0.58.0
appdirs
pre-commit
pytest
Expand Down
6 changes: 6 additions & 0 deletions tests/test_hypergeo.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,8 +78,14 @@ def _2f1_validate(a_i, b_i, a_j, b_j, y, mu, offset=1.0):

def test_2f1(self, a_i, b_i, a_j, b_j, y, mu):
pars = [a_i, b_i, a_j, b_j, y, mu]
A = a_j
B = a_i + a_j + y
C = a_j + y + 1
z = (mu - b_j) / (mu + b_i)
f, *_ = hypergeo._hyp2f1(*pars)
ff = hypergeo._hyp2f1_fast(A, B, C, z)
check = float(mpmath.log(self._2f1_validate(*pars)))
assert np.isclose(f, ff)
assert np.isclose(f, check, rtol=2e-2)

def test_grad(self, a_i, b_i, a_j, b_j, y, mu):
Expand Down
18 changes: 18 additions & 0 deletions tests/test_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -427,3 +427,21 @@ def test_bad_arguments(self):
method="variational_gamma",
max_iterations=-1,
)

def test_match_central_moments(self):
ts = msprime.simulate(8, mutation_rate=5, recombination_rate=5, random_seed=2)
ts0 = tsdate.date(
ts,
mutation_rate=5,
population_size=1,
method="variational_gamma",
method_of_moments=False,
)
ts1 = tsdate.date(
ts,
mutation_rate=5,
population_size=1,
method="variational_gamma",
method_of_moments=True,
)
assert np.any(np.not_equal(ts0.nodes_time, ts1.nodes_time))
83 changes: 60 additions & 23 deletions tsdate/approx.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,10 +172,11 @@ def sufficient_statistics(a_i, b_i, a_j, b_j, y_ij, mu_ij):
@numba.njit("UniTuple(f8, 7)(f8, f8, f8, f8, f8, f8)")
def taylor_approximation(a_i, b_i, a_j, b_j, y_ij, mu_ij):
"""
Calculate gamma sufficient statistics for the PDF proportional to
Calculate sufficient statistics for the PDF proportional to
:math:`Ga(t_j | a_j, b_j) Ga(t_i | a_i, b_i) Po(y_{ij} |
\\mu_{ij} t_i - t_j)`, where :math:`i` is the parent and :math:`j` is
the child.
the child. The logarithmic moments are approximated via a Taylor
expansion around the mean.
:param float a_i: the shape parameter of the cavity distribution for the parent
:param float b_i: the rate parameter of the cavity distribution for the parent
Expand All @@ -184,7 +185,8 @@ def taylor_approximation(a_i, b_i, a_j, b_j, y_ij, mu_ij):
:param float y_ij: the number of mutations on the edge
:param float mu_ij: the span-weighted mutation rate of the edge
:return: normalizing constant, E[t_i], E[log t_i], E[t_j], E[log t_j]
:return: normalizing constant, E[t_i], E[log t_i], V[t_i],
E[t_j], E[log t_j], V[t_j]
"""

a = a_j
Expand All @@ -193,30 +195,26 @@ def taylor_approximation(a_i, b_i, a_j, b_j, y_ij, mu_ij):
t = mu_ij + b_i
z = (mu_ij - b_j) / t

assert a > 0
assert b > 0
assert c > 0
assert t > 0

f0, _, _, _, _ = hypergeo._hyp2f1(a_i, b_i, a_j + 0, b_j, y_ij, mu_ij)
f1, _, _, _, _ = hypergeo._hyp2f1(a_i, b_i, a_j + 1, b_j, y_ij, mu_ij)
f2, _, _, _, _ = hypergeo._hyp2f1(a_i, b_i, a_j + 2, b_j, y_ij, mu_ij)
f0 = hypergeo._hyp2f1_fast(a, b, c, z)
f1 = hypergeo._hyp2f1_fast(a + 1, b + 1, c + 1, z)
f2 = hypergeo._hyp2f1_fast(a + 2, b + 2, c + 2, z)
s1 = a * b / c
s2 = s1 * (a + 1) * (b + 1) / (c + 1)
d1 = s1 * np.exp(f1 - f0)
d2 = s2 * np.exp(f2 - f0)

logl = f0 + hypergeo._betaln(y_ij + 1, a) + hypergeo._gammaln(b) - b * np.log(t)

mn_i = d1 * z / t + b / t
mn_j = d1 / t
sq_i = z / t**2 * (d2 * z + 2 * d1 * (1 + b)) + b * (1 + b) / t**2
sq_j = d2 / t**2
va_i = sq_i - mn_i**2
va_j = sq_j - mn_j**2
ln_i = np.log(mn_i) - va_i / 2 / mn_i**2
ln_j = np.log(mn_j) - va_j / 2 / mn_j**2

mn_i = mn_j * z + b / t
sq_i = sq_j * z**2 + (b + 1) * (mn_i + mn_j * z) / t
va_i = sq_i - mn_i**2
ln_i = np.log(mn_i) - va_i / 2 / mn_i**2

return logl, mn_i, ln_i, va_i, mn_j, ln_j, va_j


Expand All @@ -237,6 +235,32 @@ def _valid_sufficient_statistics(t_i, ln_t_i, t_j, ln_t_j):
return True


@numba.njit("b1(f8, f8, f8, f8, f8, f8)")
def _valid_parameterization(a_i, b_i, a_j, b_j, y, mu):
"""Uses shape / rate parameterization"""
a = a_j
b = a_i + a_j + y
c = a_j + y + 1
s = mu - b_j
t = mu + b_i
# check that 2F1 argument is not unity under some transformation
if t <= 0.0:
return False
z = s / t
if z >= 1.0:
return False
if z / (z - 1) >= 1.0:
return False
# check that 2F1 is positive
if a <= 0:
return False
if b <= 0:
return False
if c <= 0:
return False
return True


@numba.njit("Tuple((f8, f8[:], f8[:]))(f8[:], f8[:], f8[:], b1)")
def gamma_projection(pars_i, pars_j, pars_ij, min_kl):
"""
Expand All @@ -256,24 +280,37 @@ def gamma_projection(pars_i, pars_j, pars_ij, min_kl):
:return: gamma natural parameters for parent and child
"""

# switch from natural to canonical parameterization
a_i, b_i = pars_i
a_j, b_j = pars_j
y_ij, mu_ij = pars_ij
a_i += 1
a_j += 1

# skip update, zeroing out message
if not _valid_parameterization(a_i, b_i, a_j, b_j, y_ij, mu_ij):
return np.nan, pars_i, pars_j

# if min_kl:
# logconst, t_i, ln_t_i, t_j, ln_t_j = sufficient_statistics(
# a_i, b_i, a_j, b_j, y_ij, mu_ij
# )
# if not _valid_sufficient_statistics(t_i, ln_t_i, t_j, ln_t_j):
# logconst, t_i, ln_t_i, _, t_j, ln_t_j, _ = taylor_approximation(
# a_i, b_i, a_j, b_j, y_ij, mu_ij
# )
# proj_i = approximate_gamma_kl(t_i, ln_t_i)
# proj_j = approximate_gamma_kl(t_j, ln_t_j)

if min_kl:
logconst, t_i, ln_t_i, t_j, ln_t_j = sufficient_statistics(
a_i + 1.0, b_i, a_j + 1.0, b_j, y_ij, mu_ij
logconst, t_i, ln_t_i, _, t_j, ln_t_j, _ = taylor_approximation(
a_i, b_i, a_j, b_j, y_ij, mu_ij
)
if not _valid_sufficient_statistics(t_i, ln_t_i, t_j, ln_t_j):
logconst, t_i, ln_t_i, _, t_j, ln_t_j, _ = taylor_approximation(
a_i + 1.0, b_i, a_j + 1.0, b_j, y_ij, mu_ij
)
proj_i = approximate_gamma_kl(t_i, ln_t_i)
proj_j = approximate_gamma_kl(t_j, ln_t_j)
else:
# TODO: test
logconst, t_i, _, va_t_i, t_j, _, va_t_j = taylor_approximation(
a_i + 1.0, b_i, a_j + 1.0, b_j, y_ij, mu_ij
a_i, b_i, a_j, b_j, y_ij, mu_ij
)
proj_i = approximate_gamma_mom(t_i, va_t_i)
proj_j = approximate_gamma_mom(t_j, va_t_j)
Expand Down
34 changes: 29 additions & 5 deletions tsdate/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1013,7 +1013,14 @@ def factorize(edge_list, fixed_nodes):
@staticmethod
@numba.njit("f8(i4[:, :], f8[:, :], f8[:, :], f8[:, :, :], f8[:], f8[:], f8, b1)")
def propagate(
edges, likelihoods, posterior, messages, scale, log_partition, max_shape, min_kl
edges,
likelihoods,
posterior,
messages,
scale,
log_partition,
max_shape,
min_kl,
):
"""
Update approximating factors for each edge, returning average relative
Expand Down Expand Up @@ -1042,16 +1049,16 @@ def propagate(

def cavity_damping(x, y):
d = 1.0
if x[0] - y[0] < lower:
if (y[0] > 0.0) and (x[0] - y[0] < lower):
d = min(d, (x[0] - lower) / y[0])
if x[1] - y[1] < 0.0:
if (y[1] > 0.0) and (x[1] - y[1] < 0.0):
d = min(d, x[1] / y[1])
assert 0.0 < d <= 1.0
return d

def posterior_damping(x):
assert x[0] > -1.0 and x[1] > 0.0
d = min(1.0, upper / abs(x[0]))
d = min(1.0, upper / abs(x[0])) if (x[0] > 0) else 1.0
assert 0.0 < d <= 1.0
return d

Expand Down Expand Up @@ -1300,6 +1307,16 @@ def date(
:rtype: tskit.TreeSequence or (tskit.TreeSequence, dict)
"""

# TODO: docstrings for variational gamma parameters
"""
:param bool method_of_moments: If ``True`` match central moments in variational gamma
algorithm, otherwise match sufficient statistics. Matching central moments
is faster, but introduces a small amount of bias. Default: ``False``.
:param float max_shape: The maximum allowed shape for the posterior in the
variational gamma algorithm. The shape parameter is the inverse of the
variance for ``log(age)``. Default: ``1000``.
"""

# check valid method - raise error if unknown.
check_method(method)

Expand Down Expand Up @@ -1554,13 +1571,13 @@ def variational_dates(
*,
max_iterations=20,
max_shape=1000,
method_of_moments=False,
global_prior=True,
eps=1e-6,
progress=False,
num_threads=None, # Unused, matches get_dates()
probability_space=None, # Can only be None, simply to match get_dates()
ignore_oldest_root=False, # Can only be False, simply to match get_dates()
min_kl=True, # Minimize KL divergence or match central moments
):
"""
Infer dates for the nodes in a tree sequence using expectation propagation,
Expand Down Expand Up @@ -1647,6 +1664,9 @@ def variational_dates(
fixed_node_set=fixed_nodes,
)

# minimize KL divergence or match central moments
min_kl = not method_of_moments

dynamic_prog = ExpectationPropagation(priors, liklhd, progress=progress)
for _ in tqdm(
np.arange(max_iterations),
Expand All @@ -1655,6 +1675,10 @@ def variational_dates(
):
dynamic_prog.iterate(max_shape=max_shape, min_kl=min_kl)

num_skipped = np.sum(np.isnan(dynamic_prog.log_partition))
if num_skipped > 0:
logging.info(f"Skipped {num_skipped} messages with invalid posterior updates.")

posterior = priors.clone_with_new_data(
grid_data=dynamic_prog.posterior[priors.nonfixed_nodes, :]
)
Expand Down
Loading

0 comments on commit 50f53b4

Please sign in to comment.