Skip to content

Commit

Permalink
Merge branch 'master' of github.com:jonasrauber/eagerpy
Browse files Browse the repository at this point in the history
  • Loading branch information
Jonas Rauber committed Jan 25, 2020
2 parents 59db214 + 527c816 commit ff55ceb
Show file tree
Hide file tree
Showing 6 changed files with 101 additions and 15 deletions.
3 changes: 3 additions & 0 deletions README.rst
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
.. image:: https://badge.fury.io/py/eagerpy.svg
:target: https://badge.fury.io/py/eagerpy

.. image:: https://codecov.io/gh/jonasrauber/eagerpy/branch/master/graph/badge.svg
:target: https://codecov.io/gh/jonasrauber/eagerpy

.. image:: https://img.shields.io/badge/code%20style-black-000000.svg
:target: https://github.com/ambv/black

Expand Down
2 changes: 0 additions & 2 deletions eagerpy/framework.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,8 +167,6 @@ def tile(t, multiples):


def matmul(x, y):
if not istensor(x):
return y.matmul(x)
return x.matmul(y)


Expand Down
12 changes: 10 additions & 2 deletions eagerpy/tensor/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,16 @@ def __truediv__(self, other):
def __rtruediv__(self, other):
return self.tensor.__rtruediv__(other)

@unwrapin
@wrapout
def __floordiv__(self, other):
return self.tensor.__floordiv__(other)

@unwrapin
@wrapout
def __rfloordiv__(self, other):
return self.tensor.__rfloordiv__(other)

@unwrapin
@wrapout
def __lt__(self, other):
Expand Down Expand Up @@ -185,8 +195,6 @@ def ndim(self):

@property
def T(self):
if self.ndim < 2:
return self
return self.transpose()

def value_and_grad(self, f, *args, **kwargs):
Expand Down
2 changes: 1 addition & 1 deletion eagerpy/tensor/jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ def onehot_like(self, indices, *, value=1):
assert indices.ndim == 1
x = self.backend.arange(self.tensor.shape[1]).reshape(1, -1)
indices = indices.reshape(-1, 1)
return x == indices
return (x == indices) * value

@wrapout
def from_numpy(self, a):
Expand Down
86 changes: 81 additions & 5 deletions tests/test_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,30 @@ def test_astensor_tensor(t):
assert (ep.astensor(t) == t).all()


def test_module():
assert ep.istensor(ep.numpy.tanh([3, 5]))
assert not ep.istensor(ep.numpy.tanh(3))


def test_module_dir():
assert "zeros" in dir(ep.numpy)


def test_repr(t):
assert not repr(t).startswith("<")


def test_format(dummy):
t = ep.arange(dummy, 5).sum()
return f"{t:.1f}" == "10.0"


@compare_equal
def test_item(t):
t = t.sum()
return t.item()


@compare_equal
def test_len(t):
return len(t)
Expand Down Expand Up @@ -146,25 +170,45 @@ def test_rmul_scalar(t):


@compare_allclose
def test_div(t1, t2):
def test_truediv(t1, t2):
return t1 / t2


@compare_allclose(rtol=1e-6)
def test_div_scalar(t):
def test_truediv_scalar(t):
return t / 3


@compare_allclose
def test_rdiv_scalar(t):
def test_rtruediv_scalar(t):
return 3 / (abs(t) + 1e-8)


@compare_allclose
def test_floordiv(t1, t2):
return t1 // t2


@compare_allclose(rtol=1e-6)
def test_floordiv_scalar(t):
return t // 3


@compare_allclose
def test_rfloordiv_scalar(t):
return 3 // (abs(t) + 1e-8)


@compare_all
def test_getitem(t):
return t[2]


def test_getitem_tuple(dummy):
t = ep.arange(dummy, 8).float32().reshape((2, 4))
return t[1, 3]


@compare_all
def test_getitem_slice(t):
return t[1:3]
Expand Down Expand Up @@ -386,11 +430,29 @@ def test_logical_and(t):
return ep.logical_and(t < 3, t > 1)


@compare_all
def test_logical_and_scalar(t):
return ep.logical_and(True, t < 3)


def test_logical_and_manual(t):
assert (ep.logical_and(t < 3, ep.ones_like(t).bool()) == (t < 3)).all()


@compare_all
def test_logical_or(t):
return ep.logical_or(t > 3, t < 1)


@compare_all
def test_logical_or_scalar(t):
return ep.logical_or(True, t < 1)


def test_logical_or_manual(t):
assert (ep.logical_or(t < 3, ep.zeros_like(t).bool()) == (t < 3)).all()


@compare_all
def test_logical_not(t):
return ep.logical_not(t > 3)
Expand Down Expand Up @@ -440,6 +502,14 @@ def test_full_like(t):
return ep.full_like(t, 5)


@pytest.mark.parametrize("value", [1, -1, 2])
@compare_all
def test_onehot_like(dummy, value):
t = ep.arange(dummy, 18).float32().reshape((6, 3))
indices = ep.arange(t, 6) // 2
return ep.onehot_like(t, indices, value=value)


@compare_all
def test_zeros_scalar(t):
return ep.zeros(t, 5)
Expand Down Expand Up @@ -496,10 +566,16 @@ def test_argsort(t):


@compare_all
def test_transpose(t):
def test_transpose(dummy):
t = ep.arange(dummy, 8).float32().reshape((2, 4))
return ep.transpose(t)


def test_transpose_1d(dummy):
t = ep.arange(dummy, 8).float32()
assert (ep.transpose(t) == t).all()


@compare_all
def test_transpose_axes(dummy):
t = ep.arange(dummy, 60).float32().reshape((3, 4, 5))
Expand Down Expand Up @@ -578,7 +654,7 @@ def test_expand_dims(t, axis):
return ep.expand_dims(t, axis)


@pytest.mark.parametrize("axis", [0, 1, (0, 1)])
@pytest.mark.parametrize("axis", [None, 0, 1, (0, 1)])
@compare_all
def test_squeeze(t, axis):
t = t.expand_dims(axis=0).expand_dims(axis=1)
Expand Down
11 changes: 6 additions & 5 deletions tests/test_norms.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,15 +24,16 @@ def test_1d(x1d, p):
assert_allclose(norms[p](x1d).numpy(), norm(x1d.numpy(), ord=p))


@pytest.mark.parametrize("p", [0, 1, 2, ep.inf])
@pytest.mark.parametrize("p", [0, 1, 2, 3, 4, ep.inf])
@pytest.mark.parametrize("axis", [0, 1, -1])
@pytest.mark.parametrize("keepdims", [False, True])
def test_2d(x2d, p, axis, keepdims):
assert_allclose(
lp(x2d, p, axis=axis, keepdims=keepdims).numpy(),
norm(x2d.numpy(), ord=p, axis=axis, keepdims=keepdims),
)
assert_allclose(
norms[p](x2d, axis=axis, keepdims=keepdims).numpy(),
norm(x2d.numpy(), ord=p, axis=axis, keepdims=keepdims),
)
if p in norms:
assert_allclose(
norms[p](x2d, axis=axis, keepdims=keepdims).numpy(),
norm(x2d.numpy(), ord=p, axis=axis, keepdims=keepdims),
)

0 comments on commit ff55ceb

Please sign in to comment.