Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor ordering workaround to fix multi-peak models #8

Merged
merged 4 commits into from
May 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 19 additions & 11 deletions peak_performance/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ def define_model_normal(time: np.ndarray, intensity: np.ndarray) -> pm.Model:

def double_model_mean_prior(time):
"""
Function creating prior probability distributions for double peaks using a ZeroSumNormal distribution.
Function creating prior probability distributions for multi-peaks using a ZeroSumNormal distribution.

Parameters
----------
Expand All @@ -203,23 +203,31 @@ def double_model_mean_prior(time):
Returns
-------
mean
Normally distributed prior for the ordered means of the double peak model.
Normally distributed prior for the ordered means of the multi-peak model.
diff
Difference between meanmean and mean.
Difference between the group mean and peak-wise mean.
meanmean
Normally distributed prior for the mean of the double peak means.
Normally distributed prior for the group mean of the peak means.
"""
pmodel = pm.modelcontext(None)
meanmean = pm.Normal("meanmean", mu=np.min(time) + np.ptp(time) / 2, sigma=np.ptp(time) / 6)
diff = pm.ZeroSumNormal(
"diff",
sigma=1,
shape=(2,), # currently no dims due to bug with ordered transformation
diff_unsorted = pm.ZeroSumNormal(
"diff_unsorted",
sigma=2,
# Support arbitrary number of subpeaks
shape=len(pmodel.coords["subpeak"]),
# NOTE: As of PyMC v5.14, the OrderedTransform and ZeroSumTransform are incompatible.
# See https://github.com/pymc-devs/pymc/issues/6975.
# As a workaround we'll call pt.sort a few lines below.
)
diff = pm.Deterministic("diff", pt.sort(diff_unsorted), dims="subpeak")
mean = pm.Normal(
"mean",
mu=meanmean + diff,
sigma=1,
transform=pm.distributions.transforms.ordered,
meanmean + diff,
# Introduce a small jitter to the subpeak means to decouple them
# from the strictly asymmetric ZeroSumNormal entries.
# This reduces the chances of unwanted bimodality.
sigma=0.01,
dims=("subpeak",),
)
return mean, diff, meanmean
Expand Down
1 change: 1 addition & 0 deletions peak_performance/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -489,6 +489,7 @@ def sampling(pmodel, **sample_kwargs):
idata
Inference data object.
"""
sample_kwargs.setdefault("chains", 4)
sample_kwargs.setdefault("tune", 2000)
sample_kwargs.setdefault("draws", 2000)
# check if nutpie is available; if so, use it to enhance performance
Expand Down
57 changes: 47 additions & 10 deletions peak_performance/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import arviz as az
import numpy as np
import pymc as pm
import pytensor.tensor as pt
import pytest
import scipy.integrate
import scipy.stats as st
Expand All @@ -26,6 +27,43 @@ def test_initial_guesses():
pass


def test_zsn_sorting():
"""This tests a workaround that we rely on for multi-peak models."""
coords = {
"thing": ["left", "right"],
}
with pm.Model(coords=coords) as pmodel:
hyper = pm.Normal("hyper", mu=0, sigma=3)
diff = pm.ZeroSumNormal(
"diff",
sigma=1,
shape=2,
)
# Create a sorted deterministic without using transforms
diff_sorted = pm.Deterministic("diff_sorted", pt.sort(diff), dims="thing")
pos = pm.Deterministic(
"pos",
hyper + diff_sorted,
dims="thing",
)
# Observe the two things in incorrect order to provoke the model 😈
dat = pm.Data("dat", [0.2, -0.3], dims="thing")
pm.Normal("L", pos, observed=dat, dims="thing")

# Check draws from the prior
drawn = pm.draw(diff_sorted, draws=69)
np.testing.assert_array_less(drawn[:, 0], drawn[:, 1])

# And check MCMC draws too
with pmodel:
idata = pm.sample(
chains=1, tune=10, draws=69, step=pm.Metropolis(), compute_convergence_checks=False
)
sampled = idata.posterior["diff_sorted"].stack(sample=("chain", "draw")).values.T
np.testing.assert_array_less(sampled[:, 0], sampled[:, 1])
pass


class TestDistributions:
def test_normal_posterior(self):
x = np.linspace(-5, 10, 10000)
Expand Down Expand Up @@ -158,21 +196,20 @@ def test_double_skew_normal_posterior(self):


@pytest.mark.parametrize(
"model_type", ["normal", "skew_normal", "double_normal", "double_skew_normal"]
"model_type,define_func",
[
("normal", models.define_model_normal),
("skew_normal", models.define_model_skew),
("double_normal", models.define_model_double_normal),
("double_skew_normal", models.define_model_double_skew_normal),
],
)
def test_pymc_sampling(model_type):
def test_pymc_sampling(model_type, define_func):
timeseries = np.load(
Path(__file__).absolute().parent.parent / "example" / "A2t2R1Part1_132_85.9_86.1.npy"
)

if model_type == models.ModelType.Normal:
pmodel = models.define_model_normal(timeseries[0], timeseries[1])
elif model_type == models.ModelType.SkewNormal:
pmodel = models.define_model_skew(timeseries[0], timeseries[1])
elif model_type == models.ModelType.DoubleNormal:
pmodel = models.define_model_double_normal(timeseries[0], timeseries[1])
elif model_type == models.ModelType.DoubleSkewNormal:
pmodel = models.define_model_double_skew_normal(timeseries[0], timeseries[1])
pmodel = define_func(timeseries[0], timeseries[1])
with pmodel:
idata = pm.sample(cores=2, chains=2, tune=3, draws=5)
if model_type in [models.ModelType.DoubleNormal, models.ModelType.DoubleSkewNormal]:
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "peak_performance"
version = "0.6.5"
version = "0.7.0"
authors = [
{name = "Jochen Nießer", email = "j.niesser@fz-juelich.de"},
{name = "Michael Osthege", email = "m.osthege@fz-juelich.de"},
Expand Down
Loading