Skip to content

Commit

Permalink
Update RaveledParamsMap bad infer shape test
Browse files Browse the repository at this point in the history
This is in response to Aesara updates that now add static shape information at
the `Type`-level, instead of through `Op.infer_shape`, as the test previously
expected.
  • Loading branch information
brandonwillard committed Jan 3, 2023
1 parent b8b9641 commit ba81205
Showing 1 changed file with 21 additions and 9 deletions.
30 changes: 21 additions & 9 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
from copy import copy
from types import MethodType

import aesara.tensor as at
import numpy as np
import pytest
from aesara.graph.basic import Apply
from aesara.tensor.exceptions import ShapeError
from aesara.tensor.random.basic import NormalRV

from aehmc.utils import RaveledParamsMap

Expand Down Expand Up @@ -81,12 +80,25 @@ def test_RaveledParamsMap_dtype():


def test_RaveledParamsMap_bad_infer_shape():
bad_normal_op = copy(at.random.normal)

def bad_infer_shape(self, *args, **kwargs):
raise ShapeError()

bad_normal_op.infer_shape = MethodType(bad_infer_shape, bad_normal_op)
class BadNormalRV(NormalRV):
def make_node(self, *args, **kwargs):
res = super().make_node(*args, **kwargs)
# Drop static `Type`-level shape information
rv_out = res.outputs[1]
outputs = [
res.outputs[0].clone(),
at.tensor(dtype=rv_out.type.dtype, shape=(None,) * rv_out.type.ndim),
]
return Apply(
self,
res.inputs,
outputs,
)

def infer_shape(self, *args, **kwargs):
raise ShapeError()

bad_normal_op = BadNormalRV()

size = (3, 2)
beta_rv = bad_normal_op(0, 1, size=size, name="beta")
Expand Down

0 comments on commit ba81205

Please sign in to comment.