Skip to content

Commit

Permalink
reshape now accepts single ints (in particular -1), not just tuples
Browse files Browse the repository at this point in the history
  • Loading branch information
Jonas Rauber committed Feb 11, 2020
1 parent 946f33c commit 36775b2
Show file tree
Hide file tree
Showing 7 changed files with 28 additions and 7 deletions.
4 changes: 2 additions & 2 deletions eagerpy/framework.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import overload, Sequence, Callable, Tuple, Any, Optional, cast
from typing import overload, Sequence, Callable, Tuple, Any, Optional, cast, Union
from typing_extensions import Literal

from .types import Axes, AxisAxes, Shape, ShapeOrScalar
Expand Down Expand Up @@ -348,7 +348,7 @@ def value_aux_and_grad(
return t.value_aux_and_grad(f, *args, **kwargs)


def reshape(t: TensorType, shape: Shape) -> TensorType:
def reshape(t: TensorType, shape: Union[Shape, int]) -> TensorType:
return t.reshape(shape)


Expand Down
4 changes: 3 additions & 1 deletion eagerpy/tensor/jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,9 @@ def item(self) -> Union[int, float, bool]:
def shape(self) -> Shape:
return cast(Tuple, self.raw.shape)

def reshape(self: TensorType, shape: Shape) -> TensorType:
def reshape(self: TensorType, shape: Union[Shape, int]) -> TensorType:
if isinstance(shape, int):
shape = (shape,)
return type(self)(self.raw.reshape(shape))

def astype(self: TensorType, dtype: Any) -> TensorType:
Expand Down
4 changes: 3 additions & 1 deletion eagerpy/tensor/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,9 @@ def item(self) -> Union[int, float, bool]:
def shape(self: TensorType) -> Shape:
return cast(Tuple, self.raw.shape)

def reshape(self: TensorType, shape: Shape) -> TensorType:
def reshape(self: TensorType, shape: Union[Shape, int]) -> TensorType:
if isinstance(shape, int):
shape = (shape,)
return type(self)(self.raw.reshape(shape))

def astype(self: TensorType, dtype: Any) -> TensorType:
Expand Down
4 changes: 3 additions & 1 deletion eagerpy/tensor/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,9 @@ def item(self) -> Union[int, float, bool]:
def shape(self) -> Shape:
return self.raw.shape

def reshape(self: TensorType, shape: Shape) -> TensorType:
def reshape(self: TensorType, shape: Union[Shape, int]) -> TensorType:
if isinstance(shape, int):
shape = (shape,)
return type(self)(self.raw.reshape(shape))

def astype(self: TensorType, dtype: Any) -> TensorType:
Expand Down
2 changes: 1 addition & 1 deletion eagerpy/tensor/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,7 @@ def shape(self: TensorType) -> Shape:
...

@abstractmethod
def reshape(self: TensorType, shape: Shape) -> TensorType:
def reshape(self: TensorType, shape: Union[Shape, int]) -> TensorType:
...

@abstractmethod
Expand Down
4 changes: 3 additions & 1 deletion eagerpy/tensor/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,9 @@ def item(self: TensorType) -> Union[int, float, bool]:
def shape(self: TensorType) -> Shape:
return tuple(self.raw.shape.as_list())

def reshape(self: TensorType, shape: Shape) -> TensorType:
def reshape(self: TensorType, shape: Union[Shape, int]) -> TensorType:
if isinstance(shape, int):
shape = (shape,)
return type(self)(tf.reshape(self.raw, shape))

def astype(self: TensorType, dtype: Any) -> TensorType:
Expand Down
13 changes: 13 additions & 0 deletions tests/test_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -655,6 +655,19 @@ def test_reshape(t: Tensor) -> Tensor:
return ep.reshape(t, shape)


@compare_all
def test_reshape_minus_1(t: Tensor) -> Tensor:
return ep.reshape(t, -1)


@compare_all
def test_reshape_int(t: Tensor) -> Tensor:
n = 1
for k in t.shape:
n *= k
return ep.reshape(t, n)


@compare_all
def test_clip(t: Tensor) -> Tensor:
return ep.clip(t, 2, 3.5)
Expand Down

0 comments on commit 36775b2

Please sign in to comment.