Skip to content

Commit

Permalink
fixed a bug in TensorFlow getitem caught by the new test
Browse files Browse the repository at this point in the history
  • Loading branch information
Jonas Rauber committed Feb 11, 2020
1 parent d854fc8 commit 5c61171
Showing 1 changed file with 4 additions and 1 deletion.
5 changes: 4 additions & 1 deletion eagerpy/tensor/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -508,7 +508,10 @@ def __ge__(self: TensorType, other: TensorOrScalar) -> TensorType:
def __getitem__(self: TensorType, index: Any) -> TensorType:
if isinstance(index, tuple):
index = tuple(x.raw if isinstance(x, Tensor) else x for x in index)
basic = all(x is None or x is Ellipsis or isinstance(x, int) for x in index)
basic = all(
x is None or x is Ellipsis or isinstance(x, int) or isinstance(x, slice)
for x in index
)
if not basic:
# workaround for missing support for this in TensorFlow
# TODO: maybe convert each index individually and then stack them instead
Expand Down

0 comments on commit 5c61171

Please sign in to comment.