From 68cf2681d2df6535d2ad950b8385c6a899fa839b Mon Sep 17 00:00:00 2001 From: Paul Balanca Date: Sun, 19 May 2024 22:57:00 +0100 Subject: [PATCH] Fix Numpy NdArray type checking for Python 3.8. (#107) --- jax_scaled_arithmetics/core/datatype.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/jax_scaled_arithmetics/core/datatype.py b/jax_scaled_arithmetics/core/datatype.py index 7dd3862..2d293fa 100644 --- a/jax_scaled_arithmetics/core/datatype.py +++ b/jax_scaled_arithmetics/core/datatype.py @@ -1,6 +1,6 @@ # Copyright (c) 2023 Graphcore Ltd. All rights reserved. from dataclasses import dataclass -from typing import Any, Optional, Union +from typing import TYPE_CHECKING, Any, Optional, Union import jax import jax.numpy as jnp @@ -13,7 +13,10 @@ from .pow2 import Pow2RoundMode, pow2_decompose from .typing import Array, ArrayTypes -GenericArray = Union[Array, np.ndarray[Any, Any]] +if TYPE_CHECKING: + GenericArray = Union[Array, np.ndarray[Any, Any]] +else: + GenericArray = Union[Array, np.ndarray] @register_pytree_node_class