From 5c61171fbb760e77bb0b0728487c50553d515dd5 Mon Sep 17 00:00:00 2001 From: Jonas Rauber Date: Tue, 11 Feb 2020 01:23:11 +0100 Subject: [PATCH] fixed a bug in TensorFlow getitem caught by the new test --- eagerpy/tensor/tensorflow.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/eagerpy/tensor/tensorflow.py b/eagerpy/tensor/tensorflow.py index aafd479..5cba22d 100644 --- a/eagerpy/tensor/tensorflow.py +++ b/eagerpy/tensor/tensorflow.py @@ -508,7 +508,10 @@ def __ge__(self: TensorType, other: TensorOrScalar) -> TensorType: def __getitem__(self: TensorType, index: Any) -> TensorType: if isinstance(index, tuple): index = tuple(x.raw if isinstance(x, Tensor) else x for x in index) - basic = all(x is None or x is Ellipsis or isinstance(x, int) for x in index) + basic = all( + x is None or x is Ellipsis or isinstance(x, int) or isinstance(x, slice) + for x in index + ) if not basic: # workaround for missing support for this in TensorFlow # TODO: maybe convert each index individually and then stack them instead