Skip to content

Commit

Permalink
Mixture prior works
Browse files Browse the repository at this point in the history
More streamlined numerical checks

Initialize gamma mixture from conditional coalescent prior

Add pdf

Update mixture.py to use natural parameterization

WIP

Moved fully into numba

Cleanup

Cleanup

More debugging

WIP

Working

wording

Add missing constant to loglikelihood

Skip prior update completely instead of components

Skip prior update completely instead of components

Remove verbose; use logweights in conditional posterior

Move mixture initialization to function

Docstrings and CLI

Remove some debugging inserts

Remove preemptive reference

Fix tests
  • Loading branch information
nspope committed Dec 29, 2023
1 parent 9c2c977 commit ced5d8b
Show file tree
Hide file tree
Showing 6 changed files with 433 additions and 70 deletions.
7 changes: 0 additions & 7 deletions tests/test_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -799,13 +799,6 @@ def test_variational_nosize(self):
with pytest.raises(ValueError, match="Must specify population size"):
variational_dates(ts, mutation_rate=1)

def test_variational_toomanysizes(self):
ts = utility_functions.two_tree_mutation_ts()
Ne = 1
priors = tsdate.build_prior_grid(ts, Ne, np.array([0, 1.2, 2]))
with pytest.raises(ValueError, match="Cannot specify"):
variational_dates(ts, mutation_rate=1, population_size=Ne, priors=priors)


class TestNodeGridValuesClass:
def test_init(self):
Expand Down
21 changes: 14 additions & 7 deletions tests/test_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -419,13 +419,13 @@ def test_nonglobal_priors(self):
priors = tsdate.prior.MixturePrior(ts, prior_distribution="gamma")
grid = priors.make_parameter_grid(population_size=1)
grid.grid_data[:] = [1.0, 0.0] # noninformative prior
tsdate.date(
ts,
mutation_rate=5,
method="variational_gamma",
priors=grid,
global_prior=False,
)
with pytest.raises(ValueError, match="not yet implemented"):
tsdate.date(
ts,
mutation_rate=5,
method="variational_gamma",
priors=grid,
)

def test_bad_arguments(self):
ts = utility_functions.two_tree_mutation_ts()
Expand All @@ -437,6 +437,13 @@ def test_bad_arguments(self):
method="variational_gamma",
max_iterations=-1,
)
with pytest.raises(ValueError, match="must be a positive integer"):
tsdate.date(
ts,
mutation_rate=5,
method="variational_gamma",
global_prior=False,
)

def test_match_central_moments(self):
ts = msprime.simulate(8, mutation_rate=5, recombination_rate=5, random_seed=2)
Expand Down
31 changes: 31 additions & 0 deletions tsdate/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,34 @@ def tsdate_cli_parser():
"but does not exactly minimize KL divergence in each EP update."
),
)
parser.add_argument(
"--max-iterations",
type=int,
help=(
"The number of iterations used in the expectation propagation "
"algorithm. Default: 20"
),
default=20,
)
parser.add_argument(
"--em-iterations",
type=int,
help=(
"The number of expectation-maximization iterations used to optimize the "
"global mixture prior at the end of each expectation propagation step. "
"Setting to zero disables optimization. Default: 10"
),
default=10,
)
parser.add_argument(
"--global-prior",
type=int,
help=(
"The number of components in the i.i.d. mixture prior for node "
"ages. Default: 1"
),
default=1,
)
parser.set_defaults(runner=run_date)

parser = subparsers.add_parser(
Expand Down Expand Up @@ -253,8 +281,11 @@ def run_date(args):
method=args.method,
eps=args.epsilon,
progress=args.progress,
max_iterations=args.max_iterations,
max_shape=args.max_shape,
match_central_moments=args.match_central_moments,
em_iterations=args.em_iterations,
global_prior=args.global_prior,
)
else:
params = dict(
Expand Down
Loading

0 comments on commit ced5d8b

Please sign in to comment.