From 6bee8963a38276714a31ce97aaa77d1c61cb0f05 Mon Sep 17 00:00:00 2001 From: Jonas Rauber Date: Thu, 13 Feb 2020 23:00:24 +0100 Subject: [PATCH] changed return type of norms.l0 to same as input like lx (was always int) --- eagerpy/norms.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/eagerpy/norms.py b/eagerpy/norms.py index f8f9b1f..60f2148 100644 --- a/eagerpy/norms.py +++ b/eagerpy/norms.py @@ -8,7 +8,7 @@ def l0( x: TensorType, axis: Optional[AxisAxes] = None, keepdims: bool = False ) -> TensorType: - return (x != 0).sum(axis=axis, keepdims=keepdims) + return (x != 0).sum(axis=axis, keepdims=keepdims).astype(x.dtype) def l1(