diff --git a/README.rst b/README.rst index 6a77149..4285ce4 100644 --- a/README.rst +++ b/README.rst @@ -1,6 +1,9 @@ .. image:: https://badge.fury.io/py/eagerpy.svg :target: https://badge.fury.io/py/eagerpy +.. image:: https://codecov.io/gh/jonasrauber/eagerpy/branch/master/graph/badge.svg + :target: https://codecov.io/gh/jonasrauber/eagerpy + .. image:: https://img.shields.io/badge/code%20style-black-000000.svg :target: https://github.com/ambv/black diff --git a/eagerpy/framework.py b/eagerpy/framework.py index e059026..92dbd08 100644 --- a/eagerpy/framework.py +++ b/eagerpy/framework.py @@ -167,8 +167,6 @@ def tile(t, multiples): def matmul(x, y): - if not istensor(x): - return y.matmul(x) return x.matmul(y) diff --git a/eagerpy/tensor/base.py b/eagerpy/tensor/base.py index d9dc373..a0543cc 100644 --- a/eagerpy/tensor/base.py +++ b/eagerpy/tensor/base.py @@ -115,6 +115,16 @@ def __truediv__(self, other): def __rtruediv__(self, other): return self.tensor.__rtruediv__(other) + @unwrapin + @wrapout + def __floordiv__(self, other): + return self.tensor.__floordiv__(other) + + @unwrapin + @wrapout + def __rfloordiv__(self, other): + return self.tensor.__rfloordiv__(other) + @unwrapin @wrapout def __lt__(self, other): @@ -185,8 +195,6 @@ def ndim(self): @property def T(self): - if self.ndim < 2: - return self return self.transpose() def value_and_grad(self, f, *args, **kwargs): diff --git a/eagerpy/tensor/jax.py b/eagerpy/tensor/jax.py index 32e439e..0f80863 100644 --- a/eagerpy/tensor/jax.py +++ b/eagerpy/tensor/jax.py @@ -154,7 +154,7 @@ def onehot_like(self, indices, *, value=1): assert indices.ndim == 1 x = self.backend.arange(self.tensor.shape[1]).reshape(1, -1) indices = indices.reshape(-1, 1) - return x == indices + return (x == indices) * value @wrapout def from_numpy(self, a): diff --git a/tests/test_main.py b/tests/test_main.py index b594dc1..1617ff6 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -70,6 +70,30 @@ def test_astensor_tensor(t): assert (ep.astensor(t) == t).all() +def test_module(): + assert ep.istensor(ep.numpy.tanh([3, 5])) + assert not ep.istensor(ep.numpy.tanh(3)) + + +def test_module_dir(): + assert "zeros" in dir(ep.numpy) + + +def test_repr(t): + assert not repr(t).startswith("<") + + +def test_format(dummy): + t = ep.arange(dummy, 5).sum() + return f"{t:.1f}" == "10.0" + + +@compare_equal +def test_item(t): + t = t.sum() + return t.item() + + @compare_equal def test_len(t): return len(t) @@ -146,25 +170,45 @@ def test_rmul_scalar(t): @compare_allclose -def test_div(t1, t2): +def test_truediv(t1, t2): return t1 / t2 @compare_allclose(rtol=1e-6) -def test_div_scalar(t): +def test_truediv_scalar(t): return t / 3 @compare_allclose -def test_rdiv_scalar(t): +def test_rtruediv_scalar(t): return 3 / (abs(t) + 1e-8) +@compare_allclose +def test_floordiv(t1, t2): + return t1 // t2 + + +@compare_allclose(rtol=1e-6) +def test_floordiv_scalar(t): + return t // 3 + + +@compare_allclose +def test_rfloordiv_scalar(t): + return 3 // (abs(t) + 1e-8) + + @compare_all def test_getitem(t): return t[2] +def test_getitem_tuple(dummy): + t = ep.arange(dummy, 8).float32().reshape((2, 4)) + return t[1, 3] + + @compare_all def test_getitem_slice(t): return t[1:3] @@ -386,11 +430,29 @@ def test_logical_and(t): return ep.logical_and(t < 3, t > 1) +@compare_all +def test_logical_and_scalar(t): + return ep.logical_and(True, t < 3) + + +def test_logical_and_manual(t): + assert (ep.logical_and(t < 3, ep.ones_like(t).bool()) == (t < 3)).all() + + @compare_all def test_logical_or(t): return ep.logical_or(t > 3, t < 1) +@compare_all +def test_logical_or_scalar(t): + return ep.logical_or(True, t < 1) + + +def test_logical_or_manual(t): + assert (ep.logical_or(t < 3, ep.zeros_like(t).bool()) == (t < 3)).all() + + @compare_all def test_logical_not(t): return ep.logical_not(t > 3) @@ -440,6 +502,14 @@ def test_full_like(t): return ep.full_like(t, 5) +@pytest.mark.parametrize("value", [1, -1, 2]) +@compare_all +def test_onehot_like(dummy, value): + t = ep.arange(dummy, 18).float32().reshape((6, 3)) + indices = ep.arange(t, 6) // 2 + return ep.onehot_like(t, indices, value=value) + + @compare_all def test_zeros_scalar(t): return ep.zeros(t, 5) @@ -496,10 +566,16 @@ def test_argsort(t): @compare_all -def test_transpose(t): +def test_transpose(dummy): + t = ep.arange(dummy, 8).float32().reshape((2, 4)) return ep.transpose(t) +def test_transpose_1d(dummy): + t = ep.arange(dummy, 8).float32() + assert (ep.transpose(t) == t).all() + + @compare_all def test_transpose_axes(dummy): t = ep.arange(dummy, 60).float32().reshape((3, 4, 5)) @@ -578,7 +654,7 @@ def test_expand_dims(t, axis): return ep.expand_dims(t, axis) -@pytest.mark.parametrize("axis", [0, 1, (0, 1)]) +@pytest.mark.parametrize("axis", [None, 0, 1, (0, 1)]) @compare_all def test_squeeze(t, axis): t = t.expand_dims(axis=0).expand_dims(axis=1) diff --git a/tests/test_norms.py b/tests/test_norms.py index 4437a46..9a30bbe 100644 --- a/tests/test_norms.py +++ b/tests/test_norms.py @@ -24,7 +24,7 @@ def test_1d(x1d, p): assert_allclose(norms[p](x1d).numpy(), norm(x1d.numpy(), ord=p)) -@pytest.mark.parametrize("p", [0, 1, 2, ep.inf]) +@pytest.mark.parametrize("p", [0, 1, 2, 3, 4, ep.inf]) @pytest.mark.parametrize("axis", [0, 1, -1]) @pytest.mark.parametrize("keepdims", [False, True]) def test_2d(x2d, p, axis, keepdims): @@ -32,7 +32,8 @@ def test_2d(x2d, p, axis, keepdims): lp(x2d, p, axis=axis, keepdims=keepdims).numpy(), norm(x2d.numpy(), ord=p, axis=axis, keepdims=keepdims), ) - assert_allclose( - norms[p](x2d, axis=axis, keepdims=keepdims).numpy(), - norm(x2d.numpy(), ord=p, axis=axis, keepdims=keepdims), - ) + if p in norms: + assert_allclose( + norms[p](x2d, axis=axis, keepdims=keepdims).numpy(), + norm(x2d.numpy(), ord=p, axis=axis, keepdims=keepdims), + )