From f193dad2010a83e744345aea59dd24dd9f8cc184 Mon Sep 17 00:00:00 2001
From: Nate Pope
Date: Mon, 4 Dec 2023 13:17:21 -0800
Subject: [PATCH 1/9] Fix possible division by zero

Add faster version of 2f1 without derivatives
Option for faster central moment matching, document
Add test
Rename option in API to method_of_moments
---
 tests/test_hypergeo.py  |  6 ++++
 tests/test_inference.py | 18 ++++++++++
 tsdate/approx.py        | 29 +++++++---------
 tsdate/core.py          | 17 +++++++---
 tsdate/hypergeo.py      | 75 ++++++++++++++++++++++++++++++-----------
 5 files changed, 106 insertions(+), 39 deletions(-)

diff --git a/tests/test_hypergeo.py b/tests/test_hypergeo.py
index 5f6d5875..d2abfd23 100644
--- a/tests/test_hypergeo.py
+++ b/tests/test_hypergeo.py
@@ -78,8 +78,14 @@ def _2f1_validate(a_i, b_i, a_j, b_j, y, mu, offset=1.0):

     def test_2f1(self, a_i, b_i, a_j, b_j, y, mu):
         pars = [a_i, b_i, a_j, b_j, y, mu]
+        A = a_j
+        B = a_i + a_j + y
+        C = a_j + y + 1
+        z = (mu - b_j) / (mu + b_i)
         f, *_ = hypergeo._hyp2f1(*pars)
+        ff = hypergeo._hyp2f1_fast(A, B, C, z)
         check = float(mpmath.log(self._2f1_validate(*pars)))
+        assert np.isclose(f, ff)
         assert np.isclose(f, check, rtol=2e-2)

     def test_grad(self, a_i, b_i, a_j, b_j, y, mu):
diff --git a/tests/test_inference.py b/tests/test_inference.py
index 2ad9f59d..d4d1bc08 100644
--- a/tests/test_inference.py
+++ b/tests/test_inference.py
@@ -427,3 +427,21 @@ def test_bad_arguments(self):
             method="variational_gamma",
             max_iterations=-1,
         )
+
+    def test_match_central_moments(self):
+        ts = msprime.simulate(8, mutation_rate=5, recombination_rate=5, random_seed=2)
+        ts0 = tsdate.date(
+            ts,
+            mutation_rate=5,
+            population_size=1,
+            method="variational_gamma",
+            method_of_moments=False,
+        )
+        ts1 = tsdate.date(
+            ts,
+            mutation_rate=5,
+            population_size=1,
+            method="variational_gamma",
+            method_of_moments=True,
+        )
+        assert np.any(np.not_equal(ts0.nodes_time, ts1.nodes_time))
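The new `test_2f1` assertions above compare the full `_hyp2f1` implementation against the fast Laplace shortcut under the edge-likelihood parameterization A = a_j, B = a_i + a_j + y, C = a_j + y + 1, z = (mu - b_j)/(mu + b_i). A standalone cross-check of the same quantity, as an editor's sketch with made-up parameter values (mpmath is already a dependency of the test module):

```python
import mpmath

# Made-up cavity/edge parameters in the shape/rate parameterization.
a_i, b_i, a_j, b_j, y, mu = 1.5, 0.8, 2.0, 1.2, 3.0, 0.9
A = a_j
B = a_i + a_j + y
C = a_j + y + 1
z = (mu - b_j) / (mu + b_i)

# Reference log-value; hypergeo._hyp2f1_fast(A, B, C, z) should agree to
# within a few percent (the test uses rtol=2e-2 on the log scale).
print(float(mpmath.log(mpmath.hyp2f1(A, B, C, z))))
```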
diff --git a/tsdate/approx.py b/tsdate/approx.py
index ad083633..bd84fdc3 100644
--- a/tsdate/approx.py
+++ b/tsdate/approx.py
@@ -172,10 +172,11 @@ def sufficient_statistics(a_i, b_i, a_j, b_j, y_ij, mu_ij):
 @numba.njit("UniTuple(f8, 7)(f8, f8, f8, f8, f8, f8)")
 def taylor_approximation(a_i, b_i, a_j, b_j, y_ij, mu_ij):
     """
-    Calculate gamma sufficient statistics for the PDF proportional to
+    Calculate sufficient statistics for the PDF proportional to
     :math:`Ga(t_j | a_j, b_j) Ga(t_i | a_i, b_i) Po(y_{ij} |
     \\mu_{ij} (t_i - t_j))`, where :math:`i` is the parent and :math:`j` is
-    the child.
+    the child. The logarithmic moments are approximated via a Taylor
+    expansion around the mean.

     :param float a_i: the shape parameter of the cavity distribution for the parent
     :param float b_i: the rate parameter of the cavity distribution for the parent
@@ -184,7 +185,8 @@ def taylor_approximation(a_i, b_i, a_j, b_j, y_ij, mu_ij):
     :param float y_ij: the number of mutations on the edge
     :param float mu_ij: the span-weighted mutation rate of the edge

-    :return: normalizing constant, E[t_i], E[log t_i], E[t_j], E[log t_j]
+    :return: normalizing constant, E[t_i], E[log t_i], V[t_i],
+        E[t_j], E[log t_j], V[t_j]
     """

     a = a_j
@@ -193,14 +195,9 @@ def taylor_approximation(a_i, b_i, a_j, b_j, y_ij, mu_ij):
     t = mu_ij + b_i
     z = (mu_ij - b_j) / t

-    assert a > 0
-    assert b > 0
-    assert c > 0
-    assert t > 0
-
-    f0, _, _, _, _ = hypergeo._hyp2f1(a_i, b_i, a_j + 0, b_j, y_ij, mu_ij)
-    f1, _, _, _, _ = hypergeo._hyp2f1(a_i, b_i, a_j + 1, b_j, y_ij, mu_ij)
-    f2, _, _, _, _ = hypergeo._hyp2f1(a_i, b_i, a_j + 2, b_j, y_ij, mu_ij)
+    f0 = hypergeo._hyp2f1_fast(a, b, c, z)
+    f1 = hypergeo._hyp2f1_fast(a + 1, b + 1, c + 1, z)
+    f2 = hypergeo._hyp2f1_fast(a + 2, b + 2, c + 2, z)
     s1 = a * b / c
     s2 = s1 * (a + 1) * (b + 1) / (c + 1)
     d1 = s1 * np.exp(f1 - f0)
@@ -208,15 +205,16 @@ def taylor_approximation(a_i, b_i, a_j, b_j, y_ij, mu_ij):

     logl = f0 + hypergeo._betaln(y_ij + 1, a) + hypergeo._gammaln(b) - b * np.log(t)

-    mn_i = d1 * z / t + b / t
     mn_j = d1 / t
-    sq_i = z / t**2 * (d2 * z + 2 * d1 * (1 + b)) + b * (1 + b) / t**2
     sq_j = d2 / t**2
-    va_i = sq_i - mn_i**2
     va_j = sq_j - mn_j**2
-    ln_i = np.log(mn_i) - va_i / 2 / mn_i**2
     ln_j = np.log(mn_j) - va_j / 2 / mn_j**2

+    mn_i = mn_j * z + b / t
+    sq_i = sq_j * z**2 + (b + 1) * (mn_i + mn_j * z) / t
+    va_i = sq_i - mn_i**2
+    ln_i = np.log(mn_i) - va_i / 2 / mn_i**2
+
     return logl, mn_i, ln_i, va_i, mn_j, ln_j, va_j

@@ -271,7 +269,6 @@ def gamma_projection(pars_i, pars_j, pars_ij, min_kl):
         proj_i = approximate_gamma_kl(t_i, ln_t_i)
         proj_j = approximate_gamma_kl(t_j, ln_t_j)
     else:
-        # TODO: test
         logconst, t_i, _, va_t_i, t_j, _, va_t_j = taylor_approximation(
             a_i + 1.0, b_i, a_j + 1.0, b_j, y_ij, mu_ij
         )
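The delta-method step that gives `taylor_approximation` its name is visible in the `ln_i`/`ln_j` lines above: with m = E[t] and v = V[t], the log-moment is expanded to second order around the mean,

$$\mathbb{E}[\log t] \;\approx\; \log m - \frac{v}{2m^{2}}.$$

This is also why one function can serve both projection paths in `gamma_projection`: `approximate_gamma_kl` consumes (E[t], E[log t]), while `approximate_gamma_mom` needs only (E[t], V[t]).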
diff --git a/tsdate/core.py b/tsdate/core.py
index 1fdf40fe..fd2dbbf5 100644
--- a/tsdate/core.py
+++ b/tsdate/core.py
@@ -1042,16 +1042,16 @@ def propagate(
         def cavity_damping(x, y):
             d = 1.0
-            if x[0] - y[0] < lower:
+            if (y[0] > 0.0) and (x[0] - y[0] < lower):
                 d = min(d, (x[0] - lower) / y[0])
-            if x[1] - y[1] < 0.0:
+            if (y[1] > 0.0) and (x[1] - y[1] < 0.0):
                 d = min(d, x[1] / y[1])
             assert 0.0 < d <= 1.0
             return d

         def posterior_damping(x):
             assert x[0] > -1.0 and x[1] > 0.0
-            d = min(1.0, upper / abs(x[0]))
+            d = min(1.0, upper / abs(x[0])) if (x[0] > 0) else 1.0
             assert 0.0 < d <= 1.0
             return d
@@ -1274,6 +1274,12 @@ def date(
         from the inside algorithm in addition to the dated tree sequence. If
         ``return_posteriors`` is also ``True``, then the marginal likelihood
         will be the last element of the tuple.
+    :param bool method_of_moments: If ``True``, match central moments in the
+        variational gamma algorithm; otherwise match sufficient statistics. Matching
+        central moments is faster, but introduces a small amount of bias. Default: ``False``.
+    :param float max_shape: The maximum allowed shape for the posterior in the
+        variational gamma algorithm. The shape parameter is the inverse of the
+        variance for ``log(age)``. Default: ``1000``.
     :param float eps: Specify minimum distance separating time points. Also
         specifies the error factor in time difference calculations. Default: 1e-6
     :param int num_threads: The number of threads to use. A simpler unthreaded algorithm
@@ -1554,13 +1560,13 @@ def variational_dates(
     *,
     max_iterations=20,
     max_shape=1000,
+    method_of_moments=False,
     global_prior=True,
     eps=1e-6,
     progress=False,
     num_threads=None,  # Unused, matches get_dates()
     probability_space=None,  # Can only be None, simply to match get_dates()
     ignore_oldest_root=False,  # Can only be False, simply to match get_dates()
-    min_kl=True,  # Minimize KL divergence or match central moments
 ):
     """
     Infer dates for the nodes in a tree sequence using expectation propagation,
@@ -1647,6 +1653,9 @@ def variational_dates(
         fixed_node_set=fixed_nodes,
     )

+    # minimize KL divergence or match central moments
+    min_kl = not method_of_moments
+
     dynamic_prog = ExpectationPropagation(priors, liklhd, progress=progress)
     for _ in tqdm(
         np.arange(max_iterations),
diff --git a/tsdate/hypergeo.py b/tsdate/hypergeo.py
index 6b9a6590..943c6789 100644
--- a/tsdate/hypergeo.py
+++ b/tsdate/hypergeo.py
@@ -38,19 +38,19 @@
 _ptr_dbl = _PTR(_dbl)
 _gammaln_addr = get_cython_function_address("scipy.special.cython_special", "gammaln")
 _gammaln_functype = ctypes.CFUNCTYPE(_dbl, _dbl)
-_gammaln_float64 = _gammaln_functype(_gammaln_addr)
+_gammaln_f8 = _gammaln_functype(_gammaln_addr)


 class Invalid2F1(Exception):
     pass


-@numba.njit("float64(float64)")
+@numba.njit("f8(f8)")
 def _gammaln(x):
-    return _gammaln_float64(x)
+    return _gammaln_f8(x)


-@numba.njit("float64(float64)")
+@numba.njit("f8(f8)")
 def _digamma(x):
     """
     Digamma (psi) function, from asymptotic series expansion.
@@ -74,7 +74,7 @@ def _digamma(x):
     )


-@numba.njit("float64(float64)")
+@numba.njit("f8(f8)")
 def _trigamma(x):
     """
     Trigamma function, from asymptotic series expansion
@@ -100,12 +100,12 @@ def _trigamma(x):
     )


-@numba.njit("float64(float64, float64)")
+@numba.njit("f8(f8, f8)")
 def _betaln(p, q):
     return _gammaln(p) + _gammaln(q) - _gammaln(p + q)


-@numba.njit("boolean(float64, float64, float64, float64, float64, float64, float64)")
+@numba.njit("b1(f8, f8, f8, f8, f8, f8, f8)")
 def _is_valid_2f1(f1, f2, a, b, c, z, tol):
     """
     Use the contiguous relation between the Gauss hypergeometric function and
@@ -127,7 +127,7 @@ def _is_valid_2f1(f1, f2, a, b, c, z, tol):
     return numer / denom < tol


-@numba.njit("UniTuple(float64, 5)(float64, float64, float64, float64)")
+@numba.njit("UniTuple(f8, 5)(f8, f8, f8, f8)")
 def _hyp2f1_taylor_series(a, b, c, z):
     """
     Evaluate a Gaussian hypergeometric function, via its Taylor series at the
@@ -198,7 +198,7 @@ def _hyp2f1_taylor_series(a, b, c, z):
     return val, da, db, dc, dz


-@numba.njit("UniTuple(float64, 5)(float64, float64, float64, float64)")
+@numba.njit("UniTuple(f8, 5)(f8, f8, f8, f8)")
 def _hyp2f1_laplace_approx(a, b, c, x):
     """
     Approximate a Gaussian hypergeometric function, using Laplace's method
@@ -269,7 +269,50 @@ def _hyp2f1_laplace_approx(a, b, c, x):
     return f, df_da, df_db, df_dc, df_dx


+@numba.njit("f8(f8, f8, f8, f8)")
+def _hyp2f1_fast(a, b, c, x):
+    """
+    Approximate a Gaussian hypergeometric function, using Laplace's method
+    as per Butler & Wood 2002 Annals of Statistics.
+
+    Shortcut bypassing the lengthy derivative computation.
+    """
+
+    assert c > 0.0
+    assert a >= 0.0
+    assert b >= 0.0
+    assert c >= a
+    assert x < 1.0
+
+    if x == 0.0:
+        return 0.0
+
+    s = 0.0
+    if x < 0.0:
+        s = -b * log(1 - x)
+        a = c - a
+        x = x / (x - 1)
+
+    t = x * (b - a) - c
+    u = np.sqrt(t**2 - 4 * a * x * (c - b)) - t
+    y = 2 * a / u
+    yy = y**2 / a
+    my = (1 - y) ** 2 / (c - a)
+    ymy = x**2 * b * yy * my / (1 - x * y) ** 2
+    r = yy + my - ymy
+    f = (
+        s
+        + (c - 1 / 2) * log(c)
+        - log(r) / 2
+        + a * (log(y) - log(a))
+        + (c - a) * (log(1 - y) - log(c - a))
+        - b * log(1 - x * y)
+    )
+
+    return f
+
+
-# @numba.njit("UniTuple(float64, 5)(float64, float64, float64, float64)")
+# @numba.njit("UniTuple(f8, 5)(f8, f8, f8, f8)")
 # def _hyp2f1_laplace_recurrence(a, b, c, x):
 #     """
 #     Use contiguous relations to stabilize the calculation of 2F1
 #     return v, da, db, dc, dx


-@numba.njit(
-    "UniTuple(float64, 5)(float64, float64, float64, float64, float64, float64)"
-)
+@numba.njit("UniTuple(f8, 5)(f8, f8, f8, f8, f8, f8)")
 def _hyp2f1_dlmf1581(a_i, b_i, a_j, b_j, y, mu):
     """
     DLMF 15.8.1, series expansion with Pfaff transformation
@@ -332,9 +373,7 @@ def _hyp2f1_dlmf1581(a_i, b_i, a_j, b_j, y, mu):
     return val, da_i, db_i, da_j, db_j


-@numba.njit(
-    "UniTuple(float64, 5)(float64, float64, float64, float64, float64, float64)"
-)
+@numba.njit("UniTuple(f8, 5)(f8, f8, f8, f8, f8, f8)")
 def _hyp2f1_dlmf1521(a_i, b_i, a_j, b_j, y, mu):
     """
     DLMF 15.2.1, series expansion without transformation
@@ -356,9 +395,7 @@ def _hyp2f1_dlmf1521(a_i, b_i, a_j, b_j, y, mu):
     return val, da_i, db_i, da_j, db_j


-@numba.njit(
-    "UniTuple(float64, 5)(float64, float64, float64, float64, float64, float64)"
-)
+@numba.njit("UniTuple(f8, 5)(f8, f8, f8, f8, f8, f8)")
 def _hyp2f1(a_i, b_i, a_j, b_j, y, mu):
     """
     Evaluates:
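For reference, the closed form that the new `_hyp2f1_fast` evaluates in log space is the Laplace approximation of Butler & Wood (2002); transcribing the function body into formula form (editor's rendering), the saddlepoint is

$$\hat{y} \;=\; \frac{2a}{\sqrt{\tau^{2} - 4ax(c-b)} - \tau}, \qquad \tau = x(b-a) - c,$$

and the approximation is

$${}_2F_1(a, b;\, c;\, x) \;\approx\; c^{\,c-1/2}\, r^{-1/2} \left(\frac{\hat{y}}{a}\right)^{a} \left(\frac{1-\hat{y}}{c-a}\right)^{c-a} (1 - x\hat{y})^{-b},$$

$$r \;=\; \frac{\hat{y}^{2}}{a} + \frac{(1-\hat{y})^{2}}{c-a} - \frac{x^{2}\, b\, \hat{y}^{2} (1-\hat{y})^{2}}{a\,(c-a)\,(1 - x\hat{y})^{2}}.$$

For x < 0 the function first applies a Pfaff transformation (the `s`, `a = c - a`, `x = x / (x - 1)` branch) so that the argument lands in [0, 1).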
From b49b96076eb60a5aace8d11c2d244e736d27fdc1 Mon Sep 17 00:00:00 2001
From: Nate Pope
Date: Wed, 6 Dec 2023 11:52:53 -0800
Subject: [PATCH 2/9] Zero out message if hypergeometric argument is unity;
 switch to Taylor approximation instead of derivatives

---
 tsdate/approx.py   | 53 ++++++++++++++++++++++++++++++++++++------
 tsdate/hypergeo.py | 21 +++++++++---------
 2 files changed, 57 insertions(+), 17 deletions(-)

diff --git a/tsdate/approx.py b/tsdate/approx.py
index bd84fdc3..fe35ec51 100644
--- a/tsdate/approx.py
+++ b/tsdate/approx.py
@@ -235,6 +235,31 @@ def _valid_sufficient_statistics(t_i, ln_t_i, t_j, ln_t_j):
     return True


+@numba.njit("b1(f8, f8, f8, f8, f8, f8)")
+def _valid_parameterization(a_i, b_i, a_j, b_j, y, mu):
+    """Uses shape / rate parameterization"""
+    a = a_j
+    b = a_i + a_j + y
+    c = a_j + y + 1
+    s = mu - b_j
+    t = mu + b_i
+    # check that 2F1 argument is not unity under some transformation
+    if np.isclose(t, 0.0):
+        return False
+    if np.isclose(s / t, 1.0):
+        return False
+    if np.isclose(-s / (t - s), 1.0):
+        return False
+    # check that 2F1 is positive
+    if a <= 0:
+        return False
+    if b <= 0:
+        return False
+    if c <= 0:
+        return False
+    return True
+
+
 @numba.njit("Tuple((f8, f8[:], f8[:]))(f8[:], f8[:], f8[:], b1)")
 def gamma_projection(pars_i, pars_j, pars_ij, min_kl):
     """
@@ -254,23 +279,37 @@ def gamma_projection(pars_i, pars_j, pars_ij, min_kl):
     :return: gamma natural parameters for parent and child
     """

+    # switch from natural to canonical parameterization
     a_i, b_i = pars_i
     a_j, b_j = pars_j
     y_ij, mu_ij = pars_ij
+    a_i += 1
+    a_j += 1
+
+    # skip update, zeroing out message
+    if not _valid_parameterization(a_i, b_i, a_j, b_j, y_ij, mu_ij):
+        return np.nan, pars_i, pars_j

+    # if min_kl:
+    #     logconst, t_i, ln_t_i, t_j, ln_t_j = sufficient_statistics(
+    #         a_i, b_i, a_j, b_j, y_ij, mu_ij
+    #     )
+    #     if not _valid_sufficient_statistics(t_i, ln_t_i, t_j, ln_t_j):
+    #         logconst, t_i, ln_t_i, _, t_j, ln_t_j, _ = taylor_approximation(
+    #             a_i, b_i, a_j, b_j, y_ij, mu_ij
+    #         )
+    #     proj_i = approximate_gamma_kl(t_i, ln_t_i)
+    #     proj_j = approximate_gamma_kl(t_j, ln_t_j)
     if min_kl:
-        logconst, t_i, ln_t_i, t_j, ln_t_j = sufficient_statistics(
-            a_i + 1.0, b_i, a_j + 1.0, b_j, y_ij, mu_ij
+        logconst, t_i, ln_t_i, _, t_j, ln_t_j, _ = taylor_approximation(
+            a_i, b_i, a_j, b_j, y_ij, mu_ij
         )
-        if not _valid_sufficient_statistics(t_i, ln_t_i, t_j, ln_t_j):
-            logconst, t_i, ln_t_i, _, t_j, ln_t_j, _ = taylor_approximation(
-                a_i + 1.0, b_i, a_j + 1.0, b_j, y_ij, mu_ij
-            )
         proj_i = approximate_gamma_kl(t_i, ln_t_i)
         proj_j = approximate_gamma_kl(t_j, ln_t_j)
     else:
         logconst, t_i, _, va_t_i, t_j, _, va_t_j = taylor_approximation(
-            a_i + 1.0, b_i, a_j + 1.0, b_j, y_ij, mu_ij
+            a_i, b_i, a_j, b_j, y_ij, mu_ij
         )
         proj_i = approximate_gamma_mom(t_i, va_t_i)
         proj_j = approximate_gamma_mom(t_j, va_t_j)
diff --git a/tsdate/hypergeo.py b/tsdate/hypergeo.py
index 943c6789..e6ed4ffc 100644
--- a/tsdate/hypergeo.py
+++ b/tsdate/hypergeo.py
@@ -215,7 +215,7 @@ def _hyp2f1_laplace_approx(a, b, c, x):
     assert c >= a
     assert 1.0 > x >= 0.0

-    if x == 0.0:
+    if np.isclose(x, 0.0):
         return 0.0, 0.0, 0.0, 0.0, a * b / c

     # Equations 19, 24, 25 in Butler & Wood
@@ -284,7 +284,7 @@ def _hyp2f1_fast(a, b, c, x):
     assert c >= a
     assert x < 1.0

-    if x == 0.0:
+    if np.isclose(x, 0.0):
         return 0.0

     s = 0.0
@@ -301,15 +301,14 @@ def _hyp2f1_fast(a, b, c, x):
     ymy = x**2 * b * yy * my / (1 - x * y) ** 2
     r = yy + my - ymy
     f = (
-        s
-        + (c - 1 / 2) * log(c)
+        +(c - 1 / 2) * log(c)
         - log(r) / 2
         + a * (log(y) - log(a))
         + (c - a) * (log(1 - y) - log(c - a))
         - b * log(1 - x * y)
     )

-    return f
+    return f + s


 # @numba.njit("UniTuple(f8, 5)(f8, f8, f8, f8)")
@@ -354,17 +353,19 @@ def _hyp2f1_dlmf1581(a_i, b_i, a_j, b_j, y, mu):
     DLMF 15.8.1, series expansion with Pfaff transformation
     """

+    a = y + 1
     b = a_i + a_j + y
     c = a_j + y + 1
     z = (b_j - mu) / (b_i + b_j)
-    scale = -b * np.log(1 - z / (z - 1))
+    s = (mu - b_j) / (mu + b_i)
+    scale = -b * np.log(1 - s)

     # 2F1(y+1, b; c; z) via series expansion
-    val, _, db, dc, dz = _hyp2f1_laplace_approx(y + 1, b, c, z)
+    val, _, db, dc, dz = _hyp2f1_laplace_approx(a, b, c, z)

     # map gradient to parameters
-    da_i = db - np.log(1 - z / (z - 1))
-    da_j = db + dc - np.log(1 - z / (z - 1))
+    da_i = db - np.log(1 - s)
+    da_j = db + dc - np.log(1 - s)
     db_i = z * (b / (mu + b_i) - dz / (b_i + b_j))
     db_j = (z - 1) * (b / (mu + b_i) - dz / (b_i + b_j))

@@ -416,7 +417,7 @@ def _hyp2f1(a_i, b_i, a_j, b_j, y, mu):
     and dividing the gradient by the function value.
     """
     z = (mu - b_j) / (mu + b_i)
-    assert z < 1.0, "Invalid hypergeometric function argument"
+    assert z < 1.0 and not np.isclose(z, 1.0), "Invalid argument"
     if z > 0.0:
         return _hyp2f1_dlmf1521(a_i, b_i, a_j, b_j, y, mu)
     else:

From e7a1fdcdff59ae1bab9c0199bacf7771cd4ccc82 Mon Sep 17 00:00:00 2001
From: Nate Pope
Date: Wed, 6 Dec 2023 12:04:29 -0800
Subject: [PATCH 3/9] Add a temporary debugging mode

---
 tsdate/core.py | 41 +++++++++++++++++++++++++++++++++++++----
 1 file changed, 37 insertions(+), 4 deletions(-)

diff --git a/tsdate/core.py b/tsdate/core.py
index fd2dbbf5..07a5fe83 100644
--- a/tsdate/core.py
+++ b/tsdate/core.py
@@ -1011,9 +1011,19 @@ def factorize(edge_list, fixed_nodes):
         return internal, external

     @staticmethod
-    @numba.njit("f8(i4[:, :], f8[:, :], f8[:, :], f8[:, :, :], f8[:], f8[:], f8, b1)")
+    @numba.njit(
+        "f8(i4[:, :], f8[:, :], f8[:, :], f8[:, :, :], f8[:], f8[:], f8, b1, b1)"
+    )
     def propagate(
-        edges, likelihoods, posterior, messages, scale, log_partition, max_shape, min_kl
+        edges,
+        likelihoods,
+        posterior,
+        messages,
+        scale,
+        log_partition,
+        max_shape,
+        min_kl,
+        debug,
     ):
         """
         Update approximating factors for each edge, returning average relative
@@ -1056,41 +1066,61 @@ def posterior_damping(x):
             return d

         for i, p, c in edges:
+            if debug:
+                print("---\nedge:", i, "parent:", p, "child:", c)
             # Damped downdate to ensure proper cavity distributions
             parent_message = messages[i, 0] * scale[p]
             child_message = messages[i, 1] * scale[c]
+            if debug:
+                print("p-mess:", parent_message, "c-mess:", child_message)
             parent_delta = cavity_damping(posterior[p], parent_message)
             child_delta = cavity_damping(posterior[c], child_message)
             delta = min(parent_delta, child_delta)
+            if debug:
+                print("delta:", delta)
             # The cavity posteriors: the approximation omitting the variational
             # factors for this edge.
             parent_cavity = posterior[p] - delta * parent_message
             child_cavity = posterior[c] - delta * child_message
+            if debug:
+                print("p-cavi:", parent_cavity, "c-cavi:", child_cavity)
             # The edge likelihood, scaled by the damping factor
             edge_likelihood = delta * likelihoods[i]
+            if debug:
+                print("e-like:", edge_likelihood)
             # The target posterior: the cavity multiplied by the edge
             # likelihood then projected onto a gamma via moment matching.
             logconst, parent_post, child_post = approx.gamma_projection(
                 parent_cavity, child_cavity, edge_likelihood, min_kl
             )
+            if debug:
+                print("logconst:", logconst)
+            if debug:
+                print("p-post:", parent_post, "c-post:", child_post)
             # The messages: the difference in natural parameters between the
             # target and cavity posteriors.
             messages[i, 0] += (parent_post - posterior[p]) / scale[p]
             messages[i, 1] += (child_post - posterior[c]) / scale[c]
+            if debug:
+                print("p-updt:", messages[i, 0], "c-updt:", messages[i, 1])
             # Contribution to the marginal likelihood from the edge
             log_partition[i] = logconst  # TODO: incomplete
             # Constrain the messages so that the gamma shape parameter for each
             # posterior is bounded (e.g. set a maximum precision for log(age)).
             parent_eta = posterior_damping(parent_post)
             child_eta = posterior_damping(child_post)
+            if debug:
+                print("p-scal:", parent_eta, "c-scal:", child_eta)
             posterior[p] = parent_eta * parent_post
             posterior[c] = child_eta * child_post
             scale[p] *= parent_eta
             scale[c] *= child_eta
+            if debug:
+                print("p-end:", posterior[p], "c-end:", posterior[c])
         return 0.0  # TODO, placeholder

-    def iterate(self, max_shape=1000, min_kl=True):
+    def iterate(self, max_shape=1000, min_kl=True, debug=False):
         """
         Update edge factors from leaves to root then from root to leaves,
         and return approximate log marginal likelihood (TODO)
@@ -1105,6 +1135,7 @@ def iterate(self, max_shape=1000, min_kl=True, debug=False):
             self.log_partition,
             max_shape,
             min_kl,
+            debug,
         )
         self.propagate(
             self.edges[::-1],
@@ -1115,6 +1146,7 @@ def iterate(self, max_shape=1000, min_kl=True, debug=False):
             self.log_partition,
             max_shape,
             min_kl,
+            debug,
         )
         # TODO

@@ -1567,6 +1599,7 @@ def variational_dates(
     num_threads=None,  # Unused, matches get_dates()
     probability_space=None,  # Can only be None, simply to match get_dates()
     ignore_oldest_root=False,  # Can only be False, simply to match get_dates()
+    debug=False,  # Print a ton of extra information
 ):
     """
     Infer dates for the nodes in a tree sequence using expectation propagation,
@@ -1662,7 +1695,7 @@ def variational_dates(
         desc="Expectation Propagation",
         disable=not progress,
     ):
-        dynamic_prog.iterate(max_shape=max_shape, min_kl=min_kl)
+        dynamic_prog.iterate(max_shape=max_shape, min_kl=min_kl, debug=debug)

     posterior = priors.clone_with_new_data(
         grid_data=dynamic_prog.posterior[priors.nonfixed_nodes, :]
From 5e3087c734396eb21cbb1ecda72fe922a82d9135 Mon Sep 17 00:00:00 2001
From: Nate Pope
Date: Wed, 6 Dec 2023 12:14:52 -0800
Subject: [PATCH 4/9] Another debugging insert to print skipped messages

---
 tsdate/core.py | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/tsdate/core.py b/tsdate/core.py
index 07a5fe83..bbc6c2cf 100644
--- a/tsdate/core.py
+++ b/tsdate/core.py
@@ -1696,6 +1696,10 @@ def variational_dates(
         disable=not progress,
     ):
         dynamic_prog.iterate(max_shape=max_shape, min_kl=min_kl, debug=debug)
+        # (DEBUG) check how many messages were skipped
+        skipped = np.sum(np.isnan(dynamic_prog.log_partition))
+        print("Skipped", skipped, "messages")
+        # (END DEBUG)

     posterior = priors.clone_with_new_data(
         grid_data=dynamic_prog.posterior[priors.nonfixed_nodes, :]

From 242b9ee69b6db77dc23d56d351dd8da3f404f1a9 Mon Sep 17 00:00:00 2001
From: Nate Pope
Date: Fri, 8 Dec 2023 14:07:06 -0800
Subject: [PATCH 5/9] Replace np.isclose with inequality checks

---
 tsdate/approx.py   | 7 ++++---
 tsdate/hypergeo.py | 6 +++---
 2 files changed, 7 insertions(+), 6 deletions(-)

diff --git a/tsdate/approx.py b/tsdate/approx.py
index fe35ec51..d3efd81a 100644
--- a/tsdate/approx.py
+++ b/tsdate/approx.py
@@ -244,11 +244,12 @@ def _valid_parameterization(a_i, b_i, a_j, b_j, y, mu):
     s = mu - b_j
     t = mu + b_i
     # check that 2F1 argument is not unity under some transformation
-    if np.isclose(t, 0.0):
+    if t <= 0.0:
         return False
-    if np.isclose(s / t, 1.0):
+    z = s / t
+    if z >= 1.0:
         return False
-    if np.isclose(-s / (t - s), 1.0):
+    if z / (z - 1) >= 1.0:
         return False
     # check that 2F1 is positive
     if a <= 0:
diff --git a/tsdate/hypergeo.py b/tsdate/hypergeo.py
index e6ed4ffc..b650bbcc 100644
--- a/tsdate/hypergeo.py
+++ b/tsdate/hypergeo.py
@@ -215,7 +215,7 @@ def _hyp2f1_laplace_approx(a, b, c, x):
     assert c >= a
     assert 1.0 > x >= 0.0

-    if np.isclose(x, 0.0):
+    if x == 0.0:
         return 0.0, 0.0, 0.0, 0.0, a * b / c

     # Equations 19, 24, 25 in Butler & Wood
@@ -284,7 +284,7 @@ def _hyp2f1_fast(a, b, c, x):
     assert c >= a
     assert x < 1.0

-    if np.isclose(x, 0.0):
+    if x == 0.0:
         return 0.0

     s = 0.0
@@ -417,7 +417,7 @@ def _hyp2f1(a_i, b_i, a_j, b_j, y, mu):
     and dividing the gradient by the function value.
     """
     z = (mu - b_j) / (mu + b_i)
-    assert z < 1.0 and not np.isclose(z, 1.0), "Invalid argument"
+    assert z < 1.0, "Invalid argument"
     if z > 0.0:
         return _hyp2f1_dlmf1521(a_i, b_i, a_j, b_j, y, mu)
     else:
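The inequality checks that this commit substitutes in `_valid_parameterization` line up with the expansions available downstream. With z = (mu - b_j)/(mu + b_i), the Pfaff transformation behind DLMF 15.8.1 is

$${}_2F_1(a, b;\, c;\, z) \;=\; (1-z)^{-b}\; {}_2F_1\!\left(c-a,\, b;\, c;\, \frac{z}{z-1}\right),$$

so a usable expansion requires both z < 1 and z/(z - 1) < 1; when either bound fails, `gamma_projection` zeroes out the message, returning NaN for the normalizer and leaving the input natural parameters unchanged.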
From 9b35494552328ec283405a99002564c203a83985 Mon Sep 17 00:00:00 2001
From: Nate Pope
Date: Mon, 11 Dec 2023 12:46:10 -0800
Subject: [PATCH 6/9] Remove debugging inserts; move skipped messages warning
 into logging

---
 tsdate/core.py | 36 ++++++------------------------------
 1 file changed, 6 insertions(+), 30 deletions(-)

diff --git a/tsdate/core.py b/tsdate/core.py
index bbc6c2cf..41b4b0c7 100644
--- a/tsdate/core.py
+++ b/tsdate/core.py
@@ -1023,7 +1023,6 @@ def propagate(
         log_partition,
         max_shape,
         min_kl,
-        debug,
     ):
         """
         Update approximating factors for each edge, returning average relative
@@ -1066,61 +1065,41 @@ def posterior_damping(x):
             return d

         for i, p, c in edges:
-            if debug:
-                print("---\nedge:", i, "parent:", p, "child:", c)
             # Damped downdate to ensure proper cavity distributions
             parent_message = messages[i, 0] * scale[p]
             child_message = messages[i, 1] * scale[c]
-            if debug:
-                print("p-mess:", parent_message, "c-mess:", child_message)
             parent_delta = cavity_damping(posterior[p], parent_message)
             child_delta = cavity_damping(posterior[c], child_message)
             delta = min(parent_delta, child_delta)
-            if debug:
-                print("delta:", delta)
             # The cavity posteriors: the approximation omitting the variational
             # factors for this edge.
             parent_cavity = posterior[p] - delta * parent_message
             child_cavity = posterior[c] - delta * child_message
-            if debug:
-                print("p-cavi:", parent_cavity, "c-cavi:", child_cavity)
             # The edge likelihood, scaled by the damping factor
             edge_likelihood = delta * likelihoods[i]
-            if debug:
-                print("e-like:", edge_likelihood)
             # The target posterior: the cavity multiplied by the edge
             # likelihood then projected onto a gamma via moment matching.
             logconst, parent_post, child_post = approx.gamma_projection(
                 parent_cavity, child_cavity, edge_likelihood, min_kl
             )
-            if debug:
-                print("logconst:", logconst)
-            if debug:
-                print("p-post:", parent_post, "c-post:", child_post)
             # The messages: the difference in natural parameters between the
             # target and cavity posteriors.
             messages[i, 0] += (parent_post - posterior[p]) / scale[p]
             messages[i, 1] += (child_post - posterior[c]) / scale[c]
-            if debug:
-                print("p-updt:", messages[i, 0], "c-updt:", messages[i, 1])
             # Contribution to the marginal likelihood from the edge
             log_partition[i] = logconst  # TODO: incomplete
             # Constrain the messages so that the gamma shape parameter for each
             # posterior is bounded (e.g. set a maximum precision for log(age)).
             parent_eta = posterior_damping(parent_post)
             child_eta = posterior_damping(child_post)
-            if debug:
-                print("p-scal:", parent_eta, "c-scal:", child_eta)
             posterior[p] = parent_eta * parent_post
             posterior[c] = child_eta * child_post
             scale[p] *= parent_eta
             scale[c] *= child_eta
-            if debug:
-                print("p-end:", posterior[p], "c-end:", posterior[c])
         return 0.0  # TODO, placeholder

-    def iterate(self, max_shape=1000, min_kl=True, debug=False):
+    def iterate(self, max_shape=1000, min_kl=True):
         """
         Update edge factors from leaves to root then from root to leaves,
         and return approximate log marginal likelihood (TODO)
@@ -1135,7 +1114,6 @@ def iterate(self, max_shape=1000, min_kl=True):
             self.log_partition,
             max_shape,
             min_kl,
-            debug,
         )
         self.propagate(
             self.edges[::-1],
@@ -1146,7 +1124,6 @@ def iterate(self, max_shape=1000, min_kl=True):
             self.log_partition,
             max_shape,
             min_kl,
-            debug,
         )
         # TODO

@@ -1599,7 +1576,6 @@ def variational_dates(
     num_threads=None,  # Unused, matches get_dates()
     probability_space=None,  # Can only be None, simply to match get_dates()
     ignore_oldest_root=False,  # Can only be False, simply to match get_dates()
-    debug=False,  # Print a ton of extra information
 ):
     """
     Infer dates for the nodes in a tree sequence using expectation propagation,
@@ -1695,11 +1671,11 @@ def variational_dates(
         desc="Expectation Propagation",
         disable=not progress,
     ):
-        dynamic_prog.iterate(max_shape=max_shape, min_kl=min_kl, debug=debug)
-        # (DEBUG) check how many messages were skipped
-        skipped = np.sum(np.isnan(dynamic_prog.log_partition))
-        print("Skipped", skipped, "messages")
-        # (END DEBUG)
+        dynamic_prog.iterate(max_shape=max_shape, min_kl=min_kl)
+
+    num_skipped = np.sum(np.isnan(dynamic_prog.log_partition))
+    if num_skipped > 0:
+        logging.info(f"Skipped {num_skipped} messages with invalid posterior updates.")

     posterior = priors.clone_with_new_data(
         grid_data=dynamic_prog.posterior[priors.nonfixed_nodes, :]
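Because skipped edges are recorded as NaN in the log-partition array, the new logging hook reduces to a NaN count. A minimal sketch of the same bookkeeping, with made-up values (editor's illustration):

```python
import logging

import numpy as np

# One entry per edge; NaN marks an update skipped by gamma_projection.
log_partition = np.array([-0.7, np.nan, -1.3, np.nan])
num_skipped = np.sum(np.isnan(log_partition))
if num_skipped > 0:
    logging.info(f"Skipped {num_skipped} messages with invalid posterior updates.")
```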
From 7343fe8f233809c78a60c1e7ee58b93fa3644cf0 Mon Sep 17 00:00:00 2001
From: Nate Pope
Date: Mon, 11 Dec 2023 12:52:57 -0800
Subject: [PATCH 7/9] Numba signature bug

---
 tsdate/core.py | 4 +---
 1 file changed, 1 insertion(+), 3 deletions(-)

diff --git a/tsdate/core.py b/tsdate/core.py
index 41b4b0c7..afb2806f 100644
--- a/tsdate/core.py
+++ b/tsdate/core.py
@@ -1011,9 +1011,7 @@ def factorize(edge_list, fixed_nodes):
         return internal, external

     @staticmethod
-    @numba.njit(
-        "f8(i4[:, :], f8[:, :], f8[:, :], f8[:, :, :], f8[:], f8[:], f8, b1, b1)"
-    )
+    @numba.njit("f8(i4[:, :], f8[:, :], f8[:, :], f8[:, :, :], f8[:], f8[:], f8, b1)")
     def propagate(
         edges,
         likelihoods,
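The bug here was a stale signature string: commit 3 widened it to nine arguments for `debug`, and the revert in commit 6 removed the argument but left the trailing `b1` in place. Numba string signatures must match the argument list exactly, as in this minimal sketch (editor's illustration, hypothetical function):

```python
import numba


# "f8(f8, b1)" declares one float64 and one boolean argument; dropping a
# parameter without updating the string makes compilation fail.
@numba.njit("f8(f8, b1)")
def damp(x, apply_damping):
    return 0.5 * x if apply_damping else x


print(damp(2.0, True))  # 1.0
```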
From 375dc04d813a975d0ce33e9347f289148346aa6b Mon Sep 17 00:00:00 2001
From: Nate Pope
Date: Mon, 11 Dec 2023 13:03:48 -0800
Subject: [PATCH 8/9] Set numba minimum version to 0.58.0

---
 requirements.txt | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/requirements.txt b/requirements.txt
index 73c66ced..0997b109 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -6,7 +6,7 @@ tqdm
 daiquiri
 msprime>=1.0.0
 scipy
-numba
+numba>=0.58.0
 appdirs
 pre-commit
 pytest

From 529d1ca7153098df33f5f172a1961df5d2e84f01 Mon Sep 17 00:00:00 2001
From: Nate Pope
Date: Tue, 12 Dec 2023 13:36:11 -0800
Subject: [PATCH 9/9] Pull vgamma parameters from date docstring

---
 tsdate/core.py | 16 ++++++++++------
 1 file changed, 10 insertions(+), 6 deletions(-)

diff --git a/tsdate/core.py b/tsdate/core.py
index afb2806f..7a0c810a 100644
--- a/tsdate/core.py
+++ b/tsdate/core.py
@@ -1281,12 +1281,6 @@ def date(
         from the inside algorithm in addition to the dated tree sequence. If
         ``return_posteriors`` is also ``True``, then the marginal likelihood
         will be the last element of the tuple.
-    :param bool method_of_moments: If ``True``, match central moments in the
-        variational gamma algorithm; otherwise match sufficient statistics. Matching
-        central moments is faster, but introduces a small amount of bias. Default: ``False``.
-    :param float max_shape: The maximum allowed shape for the posterior in the
-        variational gamma algorithm. The shape parameter is the inverse of the
-        variance for ``log(age)``. Default: ``1000``.
     :param float eps: Specify minimum distance separating time points. Also
         specifies the error factor in time difference calculations. Default: 1e-6
     :param int num_threads: The number of threads to use. A simpler unthreaded algorithm
@@ -1313,6 +1307,16 @@ def date(
     :rtype: tskit.TreeSequence or (tskit.TreeSequence, dict)
     """

+    # TODO: docstrings for variational gamma parameters
+    """
+    :param bool method_of_moments: If ``True``, match central moments in the
+        variational gamma algorithm; otherwise match sufficient statistics. Matching
+        central moments is faster, but introduces a small amount of bias. Default: ``False``.
+    :param float max_shape: The maximum allowed shape for the posterior in the
+        variational gamma algorithm. The shape parameter is the inverse of the
+        variance for ``log(age)``. Default: ``1000``.
+    """
+
     # check valid method - raise error if unknown.
     check_method(method)
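Taken together, the series leaves the new options reachable through the public entry point. A usage sketch along the lines of `test_match_central_moments` (editor's illustration; parameter values made up):

```python
import msprime

import tsdate

ts = msprime.simulate(8, mutation_rate=5, recombination_rate=5, random_seed=2)

# method_of_moments=True matches central moments (faster, slightly biased);
# the default False matches sufficient statistics by minimizing KL divergence.
dated = tsdate.date(
    ts,
    mutation_rate=5,
    population_size=1,
    method="variational_gamma",
    method_of_moments=True,
    max_shape=1000,  # bounds the posterior shape, i.e. the precision of log(age)
)
print(dated.nodes_time)
```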