Skip to content

Commit

Permalink
typing: convert_frame (#130670)
Browse files Browse the repository at this point in the history
Summary:
X-link: pytorch/pytorch#130670
Approved by: https://github.com/Skylion007
ghstack dependencies: #130669

Reviewed By: atalman

Differential Revision: D59842185

Pulled By: aorenste

fbshipit-source-id: fd76404791ed6cf3ebc9a2adffc3857dc892b3ad
  • Loading branch information
aorenste authored and facebook-github-bot committed Jul 17, 2024
1 parent 736bd91 commit 97770e9
Showing 1 changed file with 25 additions and 2 deletions.
27 changes: 25 additions & 2 deletions userbenchmark/dynamo/dynamobench/_dynamo/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,19 +37,25 @@
DefaultDict,
Deque,
Dict,
Iterable,
Iterator,
KeysView,
List,
Optional,
overload,
Set,
Tuple,
Type,
TypeVar,
Union,
ValuesView,
)
from typing_extensions import TypeGuard

from ..utils.hooks import RemovableHandle

T = TypeVar("T")

try:
import numpy as np
except ModuleNotFoundError:
Expand Down Expand Up @@ -498,6 +504,23 @@ def clear(self):
self.values.clear()


@overload
def istype(obj: object, allowed_types: Type[T]) -> TypeGuard[T]:
...


@overload
def istype(
obj: object, allowed_types: Tuple[Type[List[T]], Type[Tuple[T, ...]]]
) -> TypeGuard[T]:
...


@overload
def istype(obj: object, allowed_types: Iterable[type]) -> bool:
...


def istype(obj, allowed_types):
"""isinstance() without subclasses"""
if isinstance(allowed_types, (tuple, list, set)):
Expand Down Expand Up @@ -631,7 +654,7 @@ def is_numpy_ndarray(value):

def istensor(obj):
"""Check of obj is a tensor"""
tensor_list = (
tensor_list: Tuple[type, ...] = (
torch.Tensor,
torch.nn.Parameter,
*config.traceable_tensor_subclasses,
Expand Down Expand Up @@ -1061,7 +1084,7 @@ def rot_n_helper(n):
return fn


common_constant_types = {
common_constant_types: Set[type] = {
int,
float,
complex,
Expand Down

0 comments on commit 97770e9

Please sign in to comment.