Skip to content

Commit

Permalink
squeeze now raises if a concrete axis is not 1
Browse files Browse the repository at this point in the history
  • Loading branch information
Jonas Rauber committed Mar 14, 2020
1 parent 088aae1 commit d0ea6af
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 2 deletions.
8 changes: 8 additions & 0 deletions eagerpy/tensor/jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,6 +291,14 @@ def log_softmax(self: TensorType, axis: int = -1) -> TensorType:
return type(self)(jax.nn.log_softmax(self.raw, axis=axis))

def squeeze(self: TensorType, axis: Optional[AxisAxes] = None) -> TensorType:
if axis is not None:
# workaround for https://github.com/google/jax/issues/2284
axis = (axis,) if isinstance(axis, int) else axis
shape = self.shape
if any(shape[i] != 1 for i in axis):
raise ValueError(
"cannot select an axis to squeeze out which has size not equal to one"
)
return type(self)(self.raw.squeeze(axis=axis))

def expand_dims(self: TensorType, axis: int) -> TensorType:
Expand Down
4 changes: 4 additions & 0 deletions eagerpy/tensor/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,6 +338,10 @@ def squeeze(self: TensorType, axis: Optional[AxisAxes] = None) -> TensorType:
axis = (axis,)
x = self.raw
for i in sorted(axis, reverse=True):
if x.shape[i] != 1:
raise ValueError(
"cannot select an axis to squeeze out which has size not equal to one"
)
x = x.squeeze(dim=i)
return type(self)(x)

Expand Down
15 changes: 13 additions & 2 deletions tests/test_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import numpy as np
import eagerpy as ep
from eagerpy import Tensor
from eagerpy.types import Shape
from eagerpy.types import Shape, AxisAxes

# make sure there are no undecorated tests in the "special tests" section below
# -> /\n\ndef test_
Expand Down Expand Up @@ -382,6 +382,17 @@ def test_flatten(dummy: Tensor) -> None:
assert ep.flatten(t, start=1, end=-2).shape == (16, 3 * 32, 32)


@pytest.mark.parametrize("axis", [None, 0, 1, (0, 1)])
def test_squeeze_not_one(dummy: Tensor, axis: Optional[AxisAxes]) -> None:
t = ep.zeros(dummy, (3, 4, 5))
if axis is None:
t.squeeze(axis=axis)
else:
with pytest.raises(Exception):
# squeezing specifc axis should fail if they are not 1
t.squeeze(axis=axis)


###############################################################################
# special tests
# - decorated with compare_*
Expand Down Expand Up @@ -1231,7 +1242,7 @@ def test_expand_dims(t: Tensor, axis: int) -> Tensor:

@pytest.mark.parametrize("axis", [None, 0, 1, (0, 1)])
@compare_all
def test_squeeze(t: Tensor, axis: Optional[int]) -> Tensor:
def test_squeeze(t: Tensor, axis: Optional[AxisAxes]) -> Tensor:
t = t.expand_dims(axis=0).expand_dims(axis=1)
return ep.squeeze(t, axis=axis)

Expand Down

0 comments on commit d0ea6af

Please sign in to comment.