diff --git a/userbenchmark/dynamo/dynamobench/_dynamo/utils.py b/userbenchmark/dynamo/dynamobench/_dynamo/utils.py
index c7e93eb251..82af2b4cb7 100644
--- a/userbenchmark/dynamo/dynamobench/_dynamo/utils.py
+++ b/userbenchmark/dynamo/dynamobench/_dynamo/utils.py
@@ -28,13 +28,29 @@
 from contextlib import contextmanager
 from functools import lru_cache, wraps
 from pathlib import Path
-from typing import Any, Dict, Optional, Set, Tuple, Union
+from typing import (
+    Any,
+    Callable,
+    cast,
+    ClassVar,
+    Counter,
+    DefaultDict,
+    Dict,
+    Iterator,
+    List,
+    Optional,
+    Set,
+    Tuple,
+    Type,
+    Union,
+    ValuesView,
+)
 
 try:
     import numpy as np
 except ModuleNotFoundError:
-    np = None
+    np = None  # type: ignore[assignment]
 
 try:
     import torch._logging
@@ -45,7 +61,12 @@
 
 # NOTE: Make sure `NP_SUPPORTED_MODULES` and `NP_TO_TNP_MODULE` are in sync.
 if np:
-    NP_SUPPORTED_MODULES = (np, np.fft, np.linalg, np.random)
+    NP_SUPPORTED_MODULES: Tuple[types.ModuleType, ...] = (
+        np,
+        np.fft,
+        np.linalg,
+        np.random,
+    )
 
     NP_TO_TNP_MODULE = {
         np: tnp,
@@ -54,7 +75,7 @@
         np.random: tnp.random,
     }
 else:
-    NP_SUPPORTED_MODULES = {}
+    NP_SUPPORTED_MODULES = tuple()
 
     NP_TO_TNP_MODULE = {}
 
 from torch._subclasses.fake_tensor import FakeTensor, is_fake, maybe_get_fake_mode
@@ -73,17 +94,17 @@
 from torch.utils._pytree import tree_map_only
 
 
-counters = collections.defaultdict(collections.Counter)
+counters: DefaultDict[str, Counter[str]] = collections.defaultdict(collections.Counter)
 troubleshooting_url = "https://pytorch.org/docs/master/compile/troubleshooting.html"
 nnmodule_doc_url = "https://pytorch.org/docs/master/compile/nn-module.html"
 nnmodule_doc_url_msg = f"See {nnmodule_doc_url} for more information and limitations."
 
 log = logging.getLogger(__name__)
 
 # profiling compilation time by function
-compilation_time_metrics = {}
+compilation_time_metrics: Dict[str, List[float]] = {}
 
 # profiling compilation time by frame phase
-frame_phase_timing = {}
+frame_phase_timing: Dict[str, Dict[str, float]] = {}
 
 timer_counter = itertools.count()
@@ -172,7 +193,7 @@ def increment_op_count(cnt):
 # entire_frame_compile:8.574629999999999
 # backend_compile:5.26806
 def print_time_report():
-    total = 0
+    total = 0.0
     total_by_key = {}
     for timings in frame_phase_timing.values():
         for key, timing in timings.items():
@@ -378,7 +399,7 @@ def write_record_to_file(filename, exec_record):
         with open(filename, "wb") as f:
             exec_record.dump(f)
     except Exception:
-        log.error("Unable to write execution record %s", filename, exc_info=1)
+        log.error("Unable to write execution record %s", filename, exc_info=True)
 
 
 def count_calls(g: fx.Graph):
@@ -454,7 +475,7 @@ def is_typing(value):
     #
     # NB: we intentionally ignore classes that inherit from Generic, since they
     # can be used as both TypingVariable as well as UserDefinedClassVariable.
-    return isinstance(value, typing._Final) or value is typing.Generic
+    return isinstance(value, typing._Final) or value is typing.Generic  # type: ignore[attr-defined]
 
 
 def is_numpy_int_type(value):
@@ -524,7 +545,7 @@ def make_cell(val=None):
     def f():
         return x
 
-    assert len(f.__closure__) == 1
+    assert f.__closure__ is not None and len(f.__closure__) == 1
     return f.__closure__[0]
 
 
@@ -581,6 +602,7 @@ def create(scope, name, val):
 
 class CleanupManager(ExactWeakKeyDictionary):
     count = 0
+    instance: ClassVar["CleanupManager"]
 
     def _remove_id(self, idx):
         for hook in self.values[idx]:
@@ -651,6 +673,7 @@ def torch_clone(x):
 
 
 def clone_inputs(example_inputs):
+    res: Union[Dict[Any, Any], List[Any]]
     if type(example_inputs) is dict:
         res = dict(example_inputs)
         for key, value in res.items():
@@ -757,7 +780,7 @@ class Marker:
     # frustrating ones e.g. torch.return_types.max
     assert cls.__module__ == "torch.return_types"
     obj = cls(map(Marker, range(cls.n_fields)))
-    fields = [None] * cls.n_fields
+    fields: List[Optional[str]] = [None] * cls.n_fields
     for name in dir(obj):
         if name[0] != "_" and isinstance(getattr(obj, name), Marker):
             fields[getattr(obj, name).index] = name
@@ -881,10 +904,10 @@ def check_numpy_ndarray_args(args, kwargs):
     )
 
 
-dict_values = type(dict().values())
-odict_values = type(collections.OrderedDict().values())
-tuple_iterator = type(iter(tuple()))
-tuple_iterator_len = tuple_iterator.__length_hint__
+dict_values: Type[ValuesView[Any]] = type(dict().values())
+odict_values: Type[ValuesView[Any]] = type(collections.OrderedDict().values())
+tuple_iterator: Type[Iterator[Any]] = type(iter(tuple()))
+tuple_iterator_len = tuple_iterator.__length_hint__  # type: ignore[attr-defined]
 object_new = object.__new__
@@ -923,15 +946,20 @@ def _get_fake_tensor(vt):
 
 
 def iter_contains(items, search, tx, check_tensor_identity=False):
-    from .variables import BuiltinVariable, ConstantVariable, TensorVariable
+    from .variables import (
+        BuiltinVariable,
+        ConstantVariable,
+        TensorVariable,
+        VariableTracker,
+    )
 
     if search.is_python_constant():
-        found = any(
+        found_const = any(
             x.is_python_constant()
             and x.as_python_constant() == search.as_python_constant()
             for x in items
         )
-        return ConstantVariable.create(found)
+        return ConstantVariable.create(found_const)
 
     must_check_tensor_id = False
     if check_tensor_identity and isinstance(search, TensorVariable):
@@ -939,7 +967,7 @@
         # Match of Tensor means match of FakeTensor
         search = _get_fake_tensor(search)
 
-    found = None
+    found: Optional[VariableTracker] = None
     for x in items:
         if must_check_tensor_id:
             if isinstance(x, TensorVariable):
@@ -1255,10 +1283,10 @@ def disable_cache_limit():
 orig_code_map = ExactWeakKeyDictionary()
 
 # keep a record of code_obj -> list of guard failure reasons for logging
-guard_failures = collections.defaultdict(list)
+guard_failures: DefaultDict[Any, List[Any]] = collections.defaultdict(list)
 
 # Keep a record of graph break reasons for logging
-graph_break_reasons = list()
+graph_break_reasons: List["torch._dynamo.output_graph.GraphCompileReason"] = list()
 
 # keep record of compiled code, if we are in "error if recompile"
 # to track code that dynamo has compiled previously
@@ -1385,6 +1413,8 @@ def extract_fake_example_value(node, required=True):
     if "example_value" in node.meta and is_fake(node.meta["example_value"]):
         return node.meta["example_value"]
     elif required:
+        from torch._dynamo.exc import unimplemented
+
         unimplemented("`FakeTensor` example value was required but not available")
     else:
         return None
@@ -1456,7 +1486,7 @@ def get_fake_value(node, tx, allow_non_graph_fake=False):
     except Unsupported:
         raise
     except RuntimeError as e:
-        cause = e
+        cause: BaseException = e
         if e.__cause__ is not None:
             cause = e.__cause__
 
@@ -1636,7 +1666,7 @@ def import_submodule(mod: types.ModuleType):
     """
     Ensure all the files in a given submodule are imported
     """
-    for filename in sorted(os.listdir(os.path.dirname(mod.__file__))):
+    for filename in sorted(os.listdir(os.path.dirname(cast(str, mod.__file__)))):
         if filename.endswith(".py") and filename[0] != "_":
             importlib.import_module(f"{mod.__name__}.{filename[:-3]}")
 
@@ -1681,8 +1711,10 @@ def tensor_static_reason_to_message(reason: TensorStaticReason):
 
 
 def tensor_always_has_static_shape(
-    tensor: Union[torch.Tensor, Any], is_tensor: bool, guard_source: "GuardSource"
-) -> Tuple[bool, TensorStaticReason]:
+    tensor: Union[torch.Tensor, Any],
+    is_tensor: bool,
+    guard_source: "torch._guards.GuardSource",
+) -> Tuple[bool, Optional[TensorStaticReason]]:
     """
     Given a tensor, source, and is_tensor flag, determine if a shape should be static.
 
@@ -1825,6 +1857,7 @@ def to_numpy_helper(value):
 
 def numpy_to_tensor(value):
     """Convert tnp.ndarray to tensor, leave other types intact. If a list/tuple, loop through it to convert."""
+    assert np is not None
     if isinstance(value, np.ndarray):
         return torch.as_tensor(value)
     if isinstance(value, tnp.ndarray):
@@ -1879,7 +1912,7 @@ def __call__(self, *args, **kwargs):
 class numpy_operator_wrapper:
     """Implements dunder methods for tnp.ndarray via functions from the operator library"""
 
-    def __init__(self, op: str):
+    def __init__(self, op: Callable[..., Any]):
         self.op = op
         self.__name__ = f"wrapped_{op.__name__}"
 
@@ -2052,7 +2085,7 @@ def nextline(lineno, col):
         # left^^^^^ right^^^^^
         # -2 since end_lineno is 1-indexed and because we added an extra
         # bracket to `segment` when calling ast.parse
-        cur_lineno = expr.left.end_lineno - 2
+        cur_lineno = cast(int, expr.left.end_lineno) - 2
         cur_col = normalize(cur_lineno, expr.left.end_col_offset)
         cur_lineno, cur_col = next_valid_char(cur_lineno, cur_col)
 
@@ -2084,13 +2117,13 @@ def nextline(lineno, col):
         # value^^^^^ slice^^^^^
         # subscript^^^^^^^^^^^^^^^^^^^^
        # find left bracket (first '[' after value)
-        left_lineno = expr.value.end_lineno - 2
+        left_lineno = cast(int, expr.value.end_lineno) - 2
         left_col = normalize(left_lineno, expr.value.end_col_offset)
         left_lineno, left_col = next_valid_char(left_lineno, left_col)
         while lines[left_lineno][left_col] != "[":
             left_lineno, left_col = increment(left_lineno, left_col)
         # find right bracket (final character of expression)
-        right_lineno = expr.end_lineno - 2
+        right_lineno = cast(int, expr.end_lineno) - 2
         right_col = normalize(right_lineno, expr.end_col_offset)
         return _Anchors(left_lineno, left_col, right_lineno, right_col)
     elif isinstance(expr, ast.Call):
@@ -2098,13 +2131,13 @@ def nextline(lineno, col):
         # func^^^^^
         # call^^^^^^^^^^^^^^^^^^^^^^^^
         # find left bracket (first '(' after func)
-        left_lineno = expr.func.end_lineno - 2
+        left_lineno = cast(int, expr.func.end_lineno) - 2
         left_col = normalize(left_lineno, expr.func.end_col_offset)
         left_lineno, left_col = next_valid_char(left_lineno, left_col)
         while lines[left_lineno][left_col] != "(":
             left_lineno, left_col = increment(left_lineno, left_col)
         # find right bracket (final character of expression)
-        right_lineno = expr.end_lineno - 2
+        right_lineno = cast(int, expr.end_lineno) - 2
         right_col = normalize(right_lineno, expr.end_col_offset)
         return _Anchors(left_lineno, left_col, right_lineno, right_col)
@@ -2126,6 +2159,7 @@ def get_instruction_source_311(code: types.CodeType, inst: dis.Instruction) -> s
     Python's `traceback` module doesn't handle multi-line expressions
     (and their anchor extraction code is not completely correct).
     """
+    assert inst.positions is not None
     if inst.positions.lineno is None:
         return ""
     # The rstrip + "\n" pattern is used throughout this function to handle
@@ -2180,7 +2214,7 @@ def get_instruction_source_311(code: types.CodeType, inst: dis.Instruction) -> s
         markers = [marker.replace("~", "^") for marker in markers]
     else:
         # make markers mutable
-        markers = [list(marker) for marker in markers]
+        mutable_markers: List[List[str]] = [list(marker) for marker in markers]
 
         # anchor positions do not take start_offset into account
         if anchors.left_end_lineno == 0:
@@ -2189,24 +2223,24 @@ def get_instruction_source_311(code: types.CodeType, inst: dis.Instruction) -> s
             anchors.right_start_offset += start_offset
 
         # Turn `~`` markers between anchors to `^`
-        for line in range(len(markers)):
-            for col in range(len(markers[line])):
-                if line < anchors.left_end_lineno:
+        for lineno in range(len(markers)):
+            for col in range(len(mutable_markers[lineno])):
+                if lineno < anchors.left_end_lineno:
                     continue
-                if line == anchors.left_end_lineno and col < anchors.left_end_offset:
+                if lineno == anchors.left_end_lineno and col < anchors.left_end_offset:
                     continue
                 if (
-                    line == anchors.right_start_lineno
+                    lineno == anchors.right_start_lineno
                     and col >= anchors.right_start_offset
                 ):
                     continue
-                if line > anchors.right_start_lineno:
+                if lineno > anchors.right_start_lineno:
                     continue
-                if markers[line][col] == "~":
-                    markers[line][col] = "^"
+                if mutable_markers[lineno][col] == "~":
+                    mutable_markers[lineno][col] = "^"
 
         # make markers into strings again
-        markers = ["".join(marker) for marker in markers]
+        markers = ["".join(marker) for marker in mutable_markers]
 
     result = ""
     for i in range(len(markers)):
@@ -2247,7 +2281,7 @@ def is_tensor_base_attr_getter(value):
     return (
         isinstance(value, types.MethodWrapperType)
        and value.__name__ == "__get__"
-        and value.__self__.__objclass__ is torch._C._TensorBase
+        and value.__self__.__objclass__ is torch._C._TensorBase  # type: ignore[attr-defined]
     )