diff --git a/eagerpy/framework.py b/eagerpy/framework.py index 40a24e8..49a241d 100644 --- a/eagerpy/framework.py +++ b/eagerpy/framework.py @@ -1,4 +1,4 @@ -from typing import overload, Sequence, Callable, Tuple, Any, Optional, cast +from typing import overload, Sequence, Callable, Tuple, Any, Optional, cast, Union from typing_extensions import Literal from .types import Axes, AxisAxes, Shape, ShapeOrScalar @@ -348,7 +348,7 @@ def value_aux_and_grad( return t.value_aux_and_grad(f, *args, **kwargs) -def reshape(t: TensorType, shape: Shape) -> TensorType: +def reshape(t: TensorType, shape: Union[Shape, int]) -> TensorType: return t.reshape(shape) diff --git a/eagerpy/tensor/jax.py b/eagerpy/tensor/jax.py index 14c9a73..0828114 100644 --- a/eagerpy/tensor/jax.py +++ b/eagerpy/tensor/jax.py @@ -108,7 +108,9 @@ def item(self) -> Union[int, float, bool]: def shape(self) -> Shape: return cast(Tuple, self.raw.shape) - def reshape(self: TensorType, shape: Shape) -> TensorType: + def reshape(self: TensorType, shape: Union[Shape, int]) -> TensorType: + if isinstance(shape, int): + shape = (shape,) return type(self)(self.raw.reshape(shape)) def astype(self: TensorType, dtype: Any) -> TensorType: diff --git a/eagerpy/tensor/numpy.py b/eagerpy/tensor/numpy.py index fadbf52..0e03c83 100644 --- a/eagerpy/tensor/numpy.py +++ b/eagerpy/tensor/numpy.py @@ -56,7 +56,9 @@ def item(self) -> Union[int, float, bool]: def shape(self: TensorType) -> Shape: return cast(Tuple, self.raw.shape) - def reshape(self: TensorType, shape: Shape) -> TensorType: + def reshape(self: TensorType, shape: Union[Shape, int]) -> TensorType: + if isinstance(shape, int): + shape = (shape,) return type(self)(self.raw.reshape(shape)) def astype(self: TensorType, dtype: Any) -> TensorType: diff --git a/eagerpy/tensor/pytorch.py b/eagerpy/tensor/pytorch.py index 3a942fc..62ed64d 100644 --- a/eagerpy/tensor/pytorch.py +++ b/eagerpy/tensor/pytorch.py @@ -71,7 +71,9 @@ def item(self) -> Union[int, float, bool]: def shape(self) -> Shape: return self.raw.shape - def reshape(self: TensorType, shape: Shape) -> TensorType: + def reshape(self: TensorType, shape: Union[Shape, int]) -> TensorType: + if isinstance(shape, int): + shape = (shape,) return type(self)(self.raw.reshape(shape)) def astype(self: TensorType, dtype: Any) -> TensorType: diff --git a/eagerpy/tensor/tensor.py b/eagerpy/tensor/tensor.py index 108230f..0c9e846 100644 --- a/eagerpy/tensor/tensor.py +++ b/eagerpy/tensor/tensor.py @@ -235,7 +235,7 @@ def shape(self: TensorType) -> Shape: ... @abstractmethod - def reshape(self: TensorType, shape: Shape) -> TensorType: + def reshape(self: TensorType, shape: Union[Shape, int]) -> TensorType: ... @abstractmethod diff --git a/eagerpy/tensor/tensorflow.py b/eagerpy/tensor/tensorflow.py index ccc68e9..8b9ec9b 100644 --- a/eagerpy/tensor/tensorflow.py +++ b/eagerpy/tensor/tensorflow.py @@ -103,7 +103,9 @@ def item(self: TensorType) -> Union[int, float, bool]: def shape(self: TensorType) -> Shape: return tuple(self.raw.shape.as_list()) - def reshape(self: TensorType, shape: Shape) -> TensorType: + def reshape(self: TensorType, shape: Union[Shape, int]) -> TensorType: + if isinstance(shape, int): + shape = (shape,) return type(self)(tf.reshape(self.raw, shape)) def astype(self: TensorType, dtype: Any) -> TensorType: diff --git a/tests/test_main.py b/tests/test_main.py index c81bfe7..ae0a4bd 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -655,6 +655,19 @@ def test_reshape(t: Tensor) -> Tensor: return ep.reshape(t, shape) +@compare_all +def test_reshape_minus_1(t: Tensor) -> Tensor: + return ep.reshape(t, -1) + + +@compare_all +def test_reshape_int(t: Tensor) -> Tensor: + n = 1 + for k in t.shape: + n *= k + return ep.reshape(t, n) + + @compare_all def test_clip(t: Tensor) -> Tensor: return ep.clip(t, 2, 3.5)