diff --git a/aehmc/metrics.py b/aehmc/metrics.py index 1a31792..3668cbb 100644 --- a/aehmc/metrics.py +++ b/aehmc/metrics.py @@ -1,9 +1,9 @@ from typing import Callable, Tuple import aesara.tensor as at -import aesara.tensor.slinalg as slinalg from aesara.tensor.random.utils import RandomStream from aesara.tensor.shape import shape_tuple +from aesara.tensor.slinalg import cholesky, solve_triangular from aesara.tensor.var import TensorVariable @@ -51,9 +51,9 @@ def gaussian_metric( dot, matmul = at.dot, lambda x, y: x * y elif inverse_mass_matrix.ndim == 2: shape = (shape_tuple(inverse_mass_matrix)[0],) - tril_inv = slinalg.cholesky(inverse_mass_matrix) + tril_inv = cholesky(inverse_mass_matrix) identity = at.eye(*shape) - mass_matrix_sqrt = slinalg.solve_lower_triangular(tril_inv, identity) + mass_matrix_sqrt = solve_triangular(tril_inv, identity, lower=True) dot, matmul = at.dot, at.dot else: raise ValueError( diff --git a/aehmc/proposals.py b/aehmc/proposals.py index eddf90a..c3bcbf1 100644 --- a/aehmc/proposals.py +++ b/aehmc/proposals.py @@ -40,7 +40,7 @@ def update(initial_energy, state): delta_energy = initial_energy - new_energy delta_energy = at.where(at.isnan(delta_energy), -np.inf, delta_energy) - is_transition_divergent = at.abs_(delta_energy) > divergence_threshold + is_transition_divergent = at.abs(delta_energy) > divergence_threshold weight = delta_energy log_p_accept = at.where( diff --git a/aehmc/utils.py b/aehmc/utils.py index 26665e9..58dbd7c 100644 --- a/aehmc/utils.py +++ b/aehmc/utils.py @@ -5,7 +5,7 @@ from aesara.graph.basic import Variable, ancestors from aesara.graph.fg import FunctionGraph from aesara.graph.rewriting.utils import rewrite_graph -from aesara.tensor.rewriting.shape import ShapeFeature +from aesara.tensor.rewriting.basic import ShapeFeature from aesara.tensor.var import TensorVariable diff --git a/setup.cfg b/setup.cfg index 719fce5..3863ac2 100644 --- a/setup.cfg +++ b/setup.cfg @@ -16,6 +16,11 @@ convention = numpy [tool:pytest] python_files=test*.py testpaths=tests +filterwarnings= + error:::aesara + error:::aeppl + error:::aemcmc + ignore:::xarray [coverage:run] omit =