Skip to content

Commit

Permalink
up the version and some tests for distributions
Browse files Browse the repository at this point in the history
  • Loading branch information
wd60622 committed Jan 22, 2024
1 parent 12761b7 commit 44d92a2
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 2 deletions.
24 changes: 23 additions & 1 deletion conjugate/distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -556,15 +556,37 @@ def _sample_beta_1d(self, variance, size: int, random_state=None) -> NUMERIC:
return stats.norm(self.mu, sigma).rvs(size=size, random_state=random_state)

def _sample_beta_nd(self, variance, size: int, random_state=None) -> NUMERIC:
variance = (self.delta_inverse[None, ...].T * variance).T
return np.stack(
[
stats.multivariate_normal(self.mu, v * self.delta_inverse).rvs(
stats.multivariate_normal(self.mu, v).rvs(
size=1, random_state=random_state
)
for v in variance
]
)

def sample_mean(
self,
size: int,
return_variance: bool = False,
random_state=None,
) -> Union[NUMERIC, Tuple[NUMERIC, NUMERIC]]:
"""Sample the mean from the normal distribution.
Args:
size: number of samples
return_variance: whether to return variance as well
random_state: random state
Returns:
samples from the normal distribution and optionally variance
"""
return self.sample_beta(
size=size, return_variance=return_variance, random_state=random_state
)

def sample_beta(
self, size: int, return_variance: bool = False, random_state=None
) -> Union[NUMERIC, Tuple[NUMERIC, NUMERIC]]:
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "conjugate-models"
version = "0.6.0"
version = "0.7.0"
description = "Bayesian Conjugate Models in Python"
authors = ["Will Dean <wd60622@gmail.com>"]
license = "MIT"
Expand Down
27 changes: 27 additions & 0 deletions tests/test_distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
MultivariateStudentT,
InverseWishart,
NormalInverseWishart,
InverseGamma,
NormalInverseGamma,
)


Expand Down Expand Up @@ -247,3 +249,28 @@ def test_normal_inverse_wishart() -> None:

mean = distribution.sample_mean(size=1)
assert mean.shape == (1, 2)


@pytest.mark.parametrize("n_features", [1, 2, 3])
@pytest.mark.parametrize("n_samples", [1, 2, 10])
def test_normal_inverse_gamma(n_features, n_samples) -> None:
mu = np.zeros(n_features)
delta_inverse = np.eye(n_features)
distribution = NormalInverseGamma(
mu=mu,
alpha=1,
beta=1,
delta_inverse=delta_inverse,
)

assert isinstance(distribution.inverse_gamma, InverseGamma)

variance = distribution.sample_variance(size=n_samples)
assert variance.shape == (n_samples,)

mean = distribution.sample_mean(size=n_samples)

if n_features == 1:
assert mean.shape == (n_samples,)
else:
assert mean.shape == (n_samples, n_features)

0 comments on commit 44d92a2

Please sign in to comment.