Skip to content

Commit

Permalink
Merge pull request #364 from nspope/add-hyperu
Browse files Browse the repository at this point in the history
Non-contemporary samples
  • Loading branch information
hyanwong authored Feb 7, 2024
2 parents b785d66 + 8b4eed0 commit 9490955
Show file tree
Hide file tree
Showing 10 changed files with 1,808 additions and 937 deletions.
265 changes: 195 additions & 70 deletions tests/test_approximations.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,12 +36,13 @@
from tsdate import hypergeo
from tsdate import prior

# TODO: better test set?
# TODO: test special case where child is fixed to age 0
_gamma_trio_test_cases = [ # [shape1, rate1, shape2, rate2, muts, rate]
[10.541, 0.0005, 10.552, 0.005, 1.0, 0.0151],
[10.541, 0.0065, 10.552, 0.005, 1.0, 0.0101],
[10.541, 0.0065, 10.552, 0.022, 1.0, 0.0051],
[10.541, 0.0265, 10.552, 0.022, 1.0, 0.0051],
[4, 4, 4, 4, 4, 4],
[2.0, 0.0005, 2.0, 0.005, 0.0, 0.001],
[2.0, 0.0005, 2.0, 0.005, 1.0, 0.001],
[2.0, 0.0005, 2.0, 0.005, 2.0, 0.001],
[2.0, 0.0005, 2.0, 0.005, 3.0, 0.001],
]


Expand Down Expand Up @@ -70,110 +71,221 @@ def pdf(t_i, t_j, a_i, b_i, a_j, b_j, y, mu):
* np.exp(-(t_i - t_j) * mu)
)

def test_sufficient_statistics(self, pars):
logconst, t_i, ln_t_i, t_j, ln_t_j = approx.sufficient_statistics(*pars)
@staticmethod
def pdf_rootward(t_i, t_j, a_i, b_i, y, mu):
"""
Target conditional distribution, proportional to the parent
marginals (gamma) and a Poisson mutation likelihood at a
fixed child age
"""
if t_i < t_j:
return 0.0
else:
return (
t_i ** (a_i - 1)
* np.exp(-t_i * b_i)
* (t_i - t_j) ** y
* np.exp(-(t_i - t_j) * mu)
)

@staticmethod
def pdf_leafward(t_i, t_j, a_j, b_j, y, mu):
"""
Target conditional distribution, proportional to the child
marginals (gamma) and a Poisson mutation likelihood at a
fixed parent age
"""
if t_i < t_j:
return 0.0
else:
return (
t_j ** (a_j - 1)
* np.exp(-t_j * b_j)
* (t_i - t_j) ** y
* np.exp(-(t_i - t_j) * mu)
)

@staticmethod
def pdf_truncated(t_i, low, upp, a_i, b_i):
"""
Target proportional to the node marginals (gamma) and an indicator
function
"""
if low < t_i < upp:
return np.exp(
np.log(t_i) * (a_i - 1)
- t_i * b_i
- scipy.special.gammaln(a_i)
+ np.log(b_i) * a_i
)
else:
return 0.0

def test_moments(self, pars):
"""
Test mean and variance when ages of both nodes are free
"""
logconst, t_i, _, var_t_i, t_j, _, var_t_j = approx.moments(*pars)
ck_normconst = scipy.integrate.dblquad(
lambda ti, tj: self.pdf(ti, tj, *pars),
lambda t_i, t_j: self.pdf(t_i, t_j, *pars),
0,
np.inf,
lambda tj: tj,
lambda t_j: t_j,
np.inf,
epsabs=0,
)[0]
assert np.isclose(logconst, np.log(ck_normconst), rtol=2e-3)
assert np.isclose(logconst, np.log(ck_normconst), rtol=2e-2)
ck_t_i = scipy.integrate.dblquad(
lambda ti, tj: ti * self.pdf(ti, tj, *pars) / ck_normconst,
lambda t_i, t_j: t_i * self.pdf(t_i, t_j, *pars) / ck_normconst,
0,
np.inf,
lambda tj: tj,
lambda t_j: t_j,
np.inf,
epsabs=0,
)[0]
assert np.isclose(t_i, ck_t_i, rtol=2e-3)
assert np.isclose(t_i, ck_t_i, rtol=2e-2)
ck_t_j = scipy.integrate.dblquad(
lambda ti, tj: tj * self.pdf(ti, tj, *pars) / ck_normconst,
lambda t_i, t_j: t_j * self.pdf(t_i, t_j, *pars) / ck_normconst,
0,
np.inf,
lambda tj: tj,
lambda t_j: t_j,
np.inf,
epsabs=0,
)[0]
assert np.isclose(t_j, ck_t_j, rtol=2e-3)
ck_ln_t_i = scipy.integrate.dblquad(
lambda ti, tj: np.log(ti) * self.pdf(ti, tj, *pars) / ck_normconst,
0,
np.inf,
lambda tj: tj,
assert np.isclose(t_j, ck_t_j, rtol=2e-2)
ck_var_t_i = (
scipy.integrate.dblquad(
lambda t_i, t_j: t_i**2 * self.pdf(t_i, t_j, *pars) / ck_normconst,
0,
np.inf,
lambda t_j: t_j,
np.inf,
epsabs=0,
)[0]
- ck_t_i**2
)
assert np.isclose(var_t_i, ck_var_t_i, rtol=2e-2)
ck_var_t_j = (
scipy.integrate.dblquad(
lambda t_i, t_j: t_j**2 * self.pdf(t_i, t_j, *pars) / ck_normconst,
0,
np.inf,
lambda t_j: t_j,
np.inf,
epsabs=0,
)[0]
- ck_t_j**2
)
assert np.isclose(var_t_j, ck_var_t_j, rtol=2e-2)

def test_rootward_moments(self, pars):
"""
Test mean and variance of parent age when child age is fixed to a nonzero value
"""
a_i, b_i, a_j, b_j, y, mu = pars
t_j = a_j / b_j # point "estimate" for child
pars_redux = (a_i, b_i, y, mu)
logconst, t_i, _, var_t_i = approx.rootward_moments(t_j, *pars_redux)
ck_normconst = scipy.integrate.quad(
lambda t_i: self.pdf_rootward(t_i, t_j, *pars_redux),
t_j,
np.inf,
epsabs=0,
)[0]
assert np.isclose(ln_t_i, ck_ln_t_i, rtol=2e-3)
ck_ln_t_j = scipy.integrate.dblquad(
lambda ti, tj: np.log(tj) * self.pdf(ti, tj, *pars) / ck_normconst,
0,
np.inf,
lambda tj: tj,
assert np.isclose(logconst, np.log(ck_normconst), rtol=2e-2)
ck_t_i = scipy.integrate.quad(
lambda t_i: t_i * self.pdf_rootward(t_i, t_j, *pars_redux) / ck_normconst,
t_j,
np.inf,
epsabs=0,
)[0]
assert np.isclose(ln_t_j, ck_ln_t_j, rtol=2e-3)
assert np.isclose(t_i, ck_t_i, rtol=2e-2)
ck_var_t_i = (
scipy.integrate.quad(
lambda t_i: t_i**2
* self.pdf_rootward(t_i, t_j, *pars_redux)
/ ck_normconst,
t_j,
np.inf,
epsabs=0,
)[0]
- ck_t_i**2
)
assert np.isclose(var_t_i, ck_var_t_i, rtol=2e-2)

def test_taylor_approximation(self, pars):
logconst, t_i, _, var_t_i, t_j, _, var_t_j = approx.taylor_approximation(*pars)
ck_normconst = scipy.integrate.dblquad(
lambda ti, tj: self.pdf(ti, tj, *pars),
0,
np.inf,
lambda tj: tj,
np.inf,
epsabs=0,
)[0]
assert np.isclose(logconst, np.log(ck_normconst), rtol=2e-3)
ck_t_i = scipy.integrate.dblquad(
lambda ti, tj: ti * self.pdf(ti, tj, *pars) / ck_normconst,
def test_leafward_moments(self, pars):
"""
Test mean and variance of child age when parent age is fixed to a nonzero value
"""
a_i, b_i, a_j, b_j, y, mu = pars
t_i = a_i / b_i # point "estimate" for parent
pars_redux = (a_j, b_j, y, mu)
logconst, t_j, _, var_t_j = approx.leafward_moments(t_i, *pars_redux)
ck_normconst = scipy.integrate.quad(
lambda t_j: self.pdf_leafward(t_i, t_j, *pars_redux),
0,
np.inf,
lambda tj: tj,
np.inf,
t_i,
epsabs=0,
)[0]
assert np.isclose(t_i, ck_t_i, rtol=2e-3)
ck_t_j = scipy.integrate.dblquad(
lambda ti, tj: tj * self.pdf(ti, tj, *pars) / ck_normconst,
assert np.isclose(logconst, np.log(ck_normconst), rtol=2e-2)
ck_t_j = scipy.integrate.quad(
lambda t_j: t_j * self.pdf_leafward(t_i, t_j, *pars_redux) / ck_normconst,
0,
np.inf,
lambda tj: tj,
np.inf,
t_i,
epsabs=0,
)[0]
assert np.isclose(t_j, ck_t_j, rtol=2e-3)
ck_var_t_i = (
scipy.integrate.dblquad(
lambda ti, tj: ti**2 * self.pdf(ti, tj, *pars) / ck_normconst,
assert np.isclose(t_j, ck_t_j, rtol=2e-2)
ck_var_t_j = (
scipy.integrate.quad(
lambda t_j: t_j**2
* self.pdf_leafward(t_i, t_j, *pars_redux)
/ ck_normconst,
0,
np.inf,
lambda tj: tj,
np.inf,
t_i,
epsabs=0,
)[0]
- ck_t_i**2
- ck_t_j**2
)
assert np.isclose(var_t_i, ck_var_t_i, rtol=1e-2)
ck_var_t_j = (
scipy.integrate.dblquad(
lambda ti, tj: tj**2 * self.pdf(ti, tj, *pars) / ck_normconst,
0,
np.inf,
lambda tj: tj,
np.inf,
assert np.isclose(var_t_j, ck_var_t_j, rtol=2e-2)

def test_truncated_moments(self, pars):
"""
Test mean and variance of child age when parent age is fixed to a nonzero value
"""
a_i, b_i, *_ = pars
upp = a_i / b_i * 2
low = a_i / b_i / 2
pars_redux = (low, upp, a_i, b_i)
logconst, t_i, _, var_t_i = approx.truncated_moments(*pars_redux)
ck_normconst = scipy.integrate.quad(
lambda t_i: self.pdf_truncated(t_i, *pars_redux),
low,
upp,
epsabs=0,
)[0]
assert np.isclose(logconst, np.log(ck_normconst), rtol=1e-4)
ck_t_i = scipy.integrate.quad(
lambda t_i: t_i * self.pdf_truncated(t_i, *pars_redux) / ck_normconst,
low,
upp,
epsabs=0,
)[0]
assert np.isclose(t_i, ck_t_i, rtol=1e-4)
ck_var_t_i = (
scipy.integrate.quad(
lambda t_i: t_i**2
* self.pdf_truncated(t_i, *pars_redux)
/ ck_normconst,
low,
upp,
epsabs=0,
)[0]
- ck_t_j**2
- ck_t_i**2
)
assert np.isclose(var_t_j, ck_var_t_j, rtol=1e-2)
assert np.isclose(var_t_i, ck_var_t_i, rtol=1e-4)

def test_approximate_gamma(self, pars):
_, t_i, ln_t_i, t_j, ln_t_j = approx.sufficient_statistics(*pars)
def test_approximate_gamma_kl(self, pars):
_, t_i, ln_t_i, _, t_j, ln_t_j, _ = approx.moments(*pars)
alpha_i, beta_i = approx.approximate_gamma_kl(t_i, ln_t_i)
alpha_j, beta_j = approx.approximate_gamma_kl(t_j, ln_t_j)
ck_t_i = (alpha_i + 1) / beta_i
Expand All @@ -185,6 +297,19 @@ def test_approximate_gamma(self, pars):
ck_ln_t_j = hypergeo._digamma(alpha_j + 1) - np.log(beta_j)
assert np.isclose(ln_t_j, ck_ln_t_j)

def test_approximate_gamma_mom(self, pars):
_, t_i, _, va_t_i, t_j, _, va_t_j = approx.moments(*pars)
alpha_i, beta_i = approx.approximate_gamma_mom(t_i, va_t_i)
alpha_j, beta_j = approx.approximate_gamma_mom(t_j, va_t_j)
ck_t_i = (alpha_i + 1) / beta_i
assert np.isclose(t_i, ck_t_i)
ck_t_j = (alpha_j + 1) / beta_j
assert np.isclose(t_j, ck_t_j)
ck_va_t_i = (alpha_i + 1) / beta_i**2
assert np.isclose(va_t_i, ck_va_t_i)
ck_va_t_j = (alpha_j + 1) / beta_j**2
assert np.isclose(va_t_j, ck_va_t_j)


class TestPriorMomentMatching:
"""
Expand Down
Loading

0 comments on commit 9490955

Please sign in to comment.