diff --git a/eagerpy/norms.py b/eagerpy/norms.py index f8f9b1f..60f2148 100644 --- a/eagerpy/norms.py +++ b/eagerpy/norms.py @@ -8,7 +8,7 @@ def l0( x: TensorType, axis: Optional[AxisAxes] = None, keepdims: bool = False ) -> TensorType: - return (x != 0).sum(axis=axis, keepdims=keepdims) + return (x != 0).sum(axis=axis, keepdims=keepdims).astype(x.dtype) def l1(