diff --git a/userbenchmark/dynamo/dynamobench/_dynamo/utils.py b/userbenchmark/dynamo/dynamobench/_dynamo/utils.py index d7a40fa2b..2fec4c5cf 100644 --- a/userbenchmark/dynamo/dynamobench/_dynamo/utils.py +++ b/userbenchmark/dynamo/dynamobench/_dynamo/utils.py @@ -37,19 +37,25 @@ DefaultDict, Deque, Dict, + Iterable, Iterator, KeysView, List, Optional, + overload, Set, Tuple, Type, + TypeVar, Union, ValuesView, ) +from typing_extensions import TypeGuard from ..utils.hooks import RemovableHandle +T = TypeVar("T") + try: import numpy as np except ModuleNotFoundError: @@ -498,6 +504,23 @@ def clear(self): self.values.clear() +@overload +def istype(obj: object, allowed_types: Type[T]) -> TypeGuard[T]: + ... + + +@overload +def istype( + obj: object, allowed_types: Tuple[Type[List[T]], Type[Tuple[T, ...]]] +) -> TypeGuard[T]: + ... + + +@overload +def istype(obj: object, allowed_types: Iterable[type]) -> bool: + ... + + def istype(obj, allowed_types): """isinstance() without subclasses""" if isinstance(allowed_types, (tuple, list, set)): @@ -631,7 +654,7 @@ def is_numpy_ndarray(value): def istensor(obj): """Check of obj is a tensor""" - tensor_list = ( + tensor_list: Tuple[type, ...] = ( torch.Tensor, torch.nn.Parameter, *config.traceable_tensor_subclasses, @@ -1061,7 +1084,7 @@ def rot_n_helper(n): return fn -common_constant_types = { +common_constant_types: Set[type] = { int, float, complex,