diff --git a/conjugate/distributions.py b/conjugate/distributions.py index e054615..3dd35bb 100644 --- a/conjugate/distributions.py +++ b/conjugate/distributions.py @@ -556,15 +556,37 @@ def _sample_beta_1d(self, variance, size: int, random_state=None) -> NUMERIC: return stats.norm(self.mu, sigma).rvs(size=size, random_state=random_state) def _sample_beta_nd(self, variance, size: int, random_state=None) -> NUMERIC: + variance = (self.delta_inverse[None, ...].T * variance).T return np.stack( [ - stats.multivariate_normal(self.mu, v * self.delta_inverse).rvs( + stats.multivariate_normal(self.mu, v).rvs( size=1, random_state=random_state ) for v in variance ] ) + def sample_mean( + self, + size: int, + return_variance: bool = False, + random_state=None, + ) -> Union[NUMERIC, Tuple[NUMERIC, NUMERIC]]: + """Sample the mean from the normal distribution. + + Args: + size: number of samples + return_variance: whether to return variance as well + random_state: random state + + Returns: + samples from the normal distribution and optionally variance + + """ + return self.sample_beta( + size=size, return_variance=return_variance, random_state=random_state + ) + def sample_beta( self, size: int, return_variance: bool = False, random_state=None ) -> Union[NUMERIC, Tuple[NUMERIC, NUMERIC]]: diff --git a/pyproject.toml b/pyproject.toml index 2195d4f..cfc5eb9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "conjugate-models" -version = "0.6.0" +version = "0.7.0" description = "Bayesian Conjugate Models in Python" authors = ["Will Dean "] license = "MIT" diff --git a/tests/test_distributions.py b/tests/test_distributions.py index d995850..6354a7c 100644 --- a/tests/test_distributions.py +++ b/tests/test_distributions.py @@ -23,6 +23,8 @@ MultivariateStudentT, InverseWishart, NormalInverseWishart, + InverseGamma, + NormalInverseGamma, ) @@ -247,3 +249,28 @@ def test_normal_inverse_wishart() -> None: mean = distribution.sample_mean(size=1) assert mean.shape == (1, 2) + + +@pytest.mark.parametrize("n_features", [1, 2, 3]) +@pytest.mark.parametrize("n_samples", [1, 2, 10]) +def test_normal_inverse_gamma(n_features, n_samples) -> None: + mu = np.zeros(n_features) + delta_inverse = np.eye(n_features) + distribution = NormalInverseGamma( + mu=mu, + alpha=1, + beta=1, + delta_inverse=delta_inverse, + ) + + assert isinstance(distribution.inverse_gamma, InverseGamma) + + variance = distribution.sample_variance(size=n_samples) + assert variance.shape == (n_samples,) + + mean = distribution.sample_mean(size=n_samples) + + if n_features == 1: + assert mean.shape == (n_samples,) + else: + assert mean.shape == (n_samples, n_features)