From ba812056ea7da7b5971f6bd9471599199c8ff6d3 Mon Sep 17 00:00:00 2001 From: "Brandon T. Willard" Date: Tue, 3 Jan 2023 15:26:28 -0600 Subject: [PATCH] Update RaveledParamsMap bad infer shape test This is in response to Aesara updates that now add static shape information at the `Type`-level, instead of through `Op.infer_shape`, as the test previously expected. --- tests/test_utils.py | 30 +++++++++++++++++++++--------- 1 file changed, 21 insertions(+), 9 deletions(-) diff --git a/tests/test_utils.py b/tests/test_utils.py index 92367ea..d1bdf0d 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,10 +1,9 @@ -from copy import copy -from types import MethodType - import aesara.tensor as at import numpy as np import pytest +from aesara.graph.basic import Apply from aesara.tensor.exceptions import ShapeError +from aesara.tensor.random.basic import NormalRV from aehmc.utils import RaveledParamsMap @@ -81,12 +80,25 @@ def test_RaveledParamsMap_dtype(): def test_RaveledParamsMap_bad_infer_shape(): - bad_normal_op = copy(at.random.normal) - - def bad_infer_shape(self, *args, **kwargs): - raise ShapeError() - - bad_normal_op.infer_shape = MethodType(bad_infer_shape, bad_normal_op) + class BadNormalRV(NormalRV): + def make_node(self, *args, **kwargs): + res = super().make_node(*args, **kwargs) + # Drop static `Type`-level shape information + rv_out = res.outputs[1] + outputs = [ + res.outputs[0].clone(), + at.tensor(dtype=rv_out.type.dtype, shape=(None,) * rv_out.type.ndim), + ] + return Apply( + self, + res.inputs, + outputs, + ) + + def infer_shape(self, *args, **kwargs): + raise ShapeError() + + bad_normal_op = BadNormalRV() size = (3, 2) beta_rv = bad_normal_op(0, 1, size=size, name="beta")