From d0ea6af20ce8a70707ef9caf37beb7d52c5822d7 Mon Sep 17 00:00:00 2001 From: Jonas Rauber Date: Sat, 14 Mar 2020 15:09:13 +0100 Subject: [PATCH] squeeze now raises if a concrete axis is not 1 --- eagerpy/tensor/jax.py | 8 ++++++++ eagerpy/tensor/pytorch.py | 4 ++++ tests/test_main.py | 15 +++++++++++++-- 3 files changed, 25 insertions(+), 2 deletions(-) diff --git a/eagerpy/tensor/jax.py b/eagerpy/tensor/jax.py index 53a3cac..3840a0d 100644 --- a/eagerpy/tensor/jax.py +++ b/eagerpy/tensor/jax.py @@ -291,6 +291,14 @@ def log_softmax(self: TensorType, axis: int = -1) -> TensorType: return type(self)(jax.nn.log_softmax(self.raw, axis=axis)) def squeeze(self: TensorType, axis: Optional[AxisAxes] = None) -> TensorType: + if axis is not None: + # workaround for https://github.com/google/jax/issues/2284 + axis = (axis,) if isinstance(axis, int) else axis + shape = self.shape + if any(shape[i] != 1 for i in axis): + raise ValueError( + "cannot select an axis to squeeze out which has size not equal to one" + ) return type(self)(self.raw.squeeze(axis=axis)) def expand_dims(self: TensorType, axis: int) -> TensorType: diff --git a/eagerpy/tensor/pytorch.py b/eagerpy/tensor/pytorch.py index dd07757..17f082f 100644 --- a/eagerpy/tensor/pytorch.py +++ b/eagerpy/tensor/pytorch.py @@ -338,6 +338,10 @@ def squeeze(self: TensorType, axis: Optional[AxisAxes] = None) -> TensorType: axis = (axis,) x = self.raw for i in sorted(axis, reverse=True): + if x.shape[i] != 1: + raise ValueError( + "cannot select an axis to squeeze out which has size not equal to one" + ) x = x.squeeze(dim=i) return type(self)(x) diff --git a/tests/test_main.py b/tests/test_main.py index 7035390..507b3e6 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -4,7 +4,7 @@ import numpy as np import eagerpy as ep from eagerpy import Tensor -from eagerpy.types import Shape +from eagerpy.types import Shape, AxisAxes # make sure there are no undecorated tests in the "special tests" section below # -> /\n\ndef test_ @@ -382,6 +382,17 @@ def test_flatten(dummy: Tensor) -> None: assert ep.flatten(t, start=1, end=-2).shape == (16, 3 * 32, 32) +@pytest.mark.parametrize("axis", [None, 0, 1, (0, 1)]) +def test_squeeze_not_one(dummy: Tensor, axis: Optional[AxisAxes]) -> None: + t = ep.zeros(dummy, (3, 4, 5)) + if axis is None: + t.squeeze(axis=axis) + else: + with pytest.raises(Exception): + # squeezing specifc axis should fail if they are not 1 + t.squeeze(axis=axis) + + ############################################################################### # special tests # - decorated with compare_* @@ -1231,7 +1242,7 @@ def test_expand_dims(t: Tensor, axis: int) -> Tensor: @pytest.mark.parametrize("axis", [None, 0, 1, (0, 1)]) @compare_all -def test_squeeze(t: Tensor, axis: Optional[int]) -> Tensor: +def test_squeeze(t: Tensor, axis: Optional[AxisAxes]) -> Tensor: t = t.expand_dims(axis=0).expand_dims(axis=1) return ep.squeeze(t, axis=axis)