From 5c0976fa047ff514bee45101b7f6d70b4a0eea20 Mon Sep 17 00:00:00 2001 From: David Berard Date: Fri, 8 Dec 2023 11:23:10 -0800 Subject: [PATCH] Revert "[dynamo] guarded config (#111299)" (#115386) This reverts commit 5927e9cbf2ac18aaaaecaab02258b7a35ac10969. Differential Revision: [D51959266](https://our.internmc.facebook.com/intern/diff/D51959266) Pull Request resolved: https://github.com/pytorch/pytorch/pull/115386 Approved by: https://github.com/yanboliang, https://github.com/malfet ghstack dependencies: #115384, #115401, #115385 --- test/dynamo/test_comptime.py | 7 - test/dynamo/test_config.py | 200 ------------------------ test/dynamo/test_misc.py | 4 +- test/dynamo/test_recompiles.py | 108 ++++--------- test/dynamo/test_subclasses.py | 36 ++--- torch/_dynamo/config.py | 3 +- torch/_dynamo/convert_frame.py | 72 ++------- torch/_dynamo/eval_frame.py | 113 ++----------- torch/_dynamo/guards.py | 12 -- torch/_dynamo/output_graph.py | 2 - torch/testing/_internal/common_utils.py | 6 +- torch/utils/_config_module.py | 22 +-- torch/utils/_config_typing.pyi | 5 +- 13 files changed, 93 insertions(+), 497 deletions(-) diff --git a/test/dynamo/test_comptime.py b/test/dynamo/test_comptime.py index 1365dac3cc750..45f2a6c6ad9a9 100644 --- a/test/dynamo/test_comptime.py +++ b/test/dynamo/test_comptime.py @@ -230,13 +230,6 @@ def _(ctx): 'obj_weakref': None 'guarded_class': None } - global '' CONFIG_HASH_MATCH - { - 'guard_types': None, - 'code': None, - 'obj_weakref': None - 'guarded_class': None - } shape_env '' SHAPE_ENV { 'guard_types': None, diff --git a/test/dynamo/test_config.py b/test/dynamo/test_config.py index 64542bfa901b0..5ccd46a1ff20d 100644 --- a/test/dynamo/test_config.py +++ b/test/dynamo/test_config.py @@ -110,206 +110,6 @@ def test_config_hash(self): assert changed_hash != newest_hash assert newest_hash == starting_hash - @disable_cache_limit() - def test_no_saved_config(self): - def fn(a, b): - return a - b * 10 - - torch._dynamo.reset() - cnt_dynamic = torch._dynamo.testing.CompileCounter() - with torch._dynamo.config.patch( - automatic_dynamic_shapes=False, assume_static_by_default=True - ): - opt_fn_static_shape = torch._dynamo.optimize( - cnt_dynamic, save_config=False - )(fn) - opt_fn_static_shape(torch.randn(2), torch.randn(2)) - opt_fn_static_shape(torch.randn(3), torch.randn(3)) - - self.assertEqual(cnt_dynamic.frame_count, 2) - - with torch._dynamo.config.patch( - automatic_dynamic_shapes=True, assume_static_by_default=False - ): - for i in range(2, 12): - opt_fn_static_shape( - torch.randn(i), torch.randn(i) - ) # will be recompiled under new config - - self.assertEqual(cnt_dynamic.frame_count, 3) - - @disable_cache_limit() - def test_no_saved_config_nested(self): - def fn(a, b): - return a - b * 10 - - torch._dynamo.reset() - cnt_dynamic = torch._dynamo.testing.CompileCounter() - cnt_dynamic_1 = torch._dynamo.testing.CompileCounter() - with torch._dynamo.config.patch( - automatic_dynamic_shapes=True, assume_static_by_default=False - ): - opt_fn_static_shape = torch._dynamo.optimize(cnt_dynamic, dynamic=False)(fn) - - # Will trigger recompile as compiled as static - opt_fn_static_shape(torch.randn(2), torch.randn(2)) - opt_fn_static_shape(torch.randn(3), torch.randn(3)) - - self.assertEqual(cnt_dynamic.frame_count, 2) - - opt_fn_try_dynamic = torch._dynamo.optimize( - cnt_dynamic_1, save_config=False - )(opt_fn_static_shape) - - for i in range(2, 6): - opt_fn_try_dynamic(torch.randn(i), torch.randn(i)) - self.assertEqual(cnt_dynamic_1.frame_count, 1) - - # Saved config = 
False will use whatever config is available - with torch._dynamo.config.patch( - automatic_dynamic_shapes=False, assume_static_by_default=True - ): - for i in range(6, 12): - opt_fn_try_dynamic(torch.randn(i), torch.randn(i)) - self.assertEqual(cnt_dynamic_1.frame_count, 7) - - @disable_cache_limit() - def test_config_changed_from_guarded_config_1(self): - def fn(a, b): - return a - b * 10 - - torch._dynamo.reset() - - cnt_dynamic = torch._dynamo.testing.CompileCounter() - with torch._dynamo.config.patch( - automatic_dynamic_shapes=False, assume_static_by_default=True - ): - opt_fn_static_shape = torch._dynamo.optimize(cnt_dynamic)(fn) - res = opt_fn_static_shape(torch.randn(2), torch.randn(2)) - opt_fn_static_shape(torch.randn(3), torch.randn(3)) - - self.assertEqual(cnt_dynamic.frame_count, 2) - - with torch._dynamo.config.patch( - automatic_dynamic_shapes=True, assume_static_by_default=False - ): - for i in range(2, 12): - # Only 4-11 will now be recompiled under old config - # 2-3 have been already been compiled under old config - # and hence will hit cache - opt_fn_static_shape(torch.randn(i), torch.randn(i)) - - self.assertEqual(cnt_dynamic.frame_count, 10) - - @disable_cache_limit() - def test_config_changed_from_guarded_config_2(self): - def fn(a, b): - return a - b * 10 - - torch._dynamo.reset() - - cnt_dynamic = torch._dynamo.testing.CompileCounter() - with torch._dynamo.config.patch( - automatic_dynamic_shapes=True, assume_static_by_default=False - ): - opt_fn_dynamic_shape = torch._dynamo.optimize(cnt_dynamic)(fn) - opt_fn_dynamic_shape(torch.randn(2), torch.randn(2)) - opt_fn_dynamic_shape(torch.randn(3), torch.randn(3)) - - self.assertEqual(cnt_dynamic.frame_count, 1) - - with torch._dynamo.config.patch( - automatic_dynamic_shapes=False, assume_static_by_default=True - ): - for i in range(2, 12): - opt_fn_dynamic_shape( - torch.randn(i), torch.randn(i) - ) # will not be recompiled due to automatic dynamic shapes - - self.assertEqual(cnt_dynamic.frame_count, 1) - - @disable_cache_limit() - def test_nested_compile_outer_wins(self): - def fn(a, b): - return a - b * 10 - - torch._dynamo.reset() - - cnt_dynamic = torch._dynamo.testing.CompileCounter() - cnt_dynamic_1 = torch._dynamo.testing.CompileCounter() - with torch._dynamo.config.patch( - automatic_dynamic_shapes=False, assume_static_by_default=True - ): - opt_fn_static_shape = torch._dynamo.optimize(cnt_dynamic)(fn) - opt_fn_static_shape(torch.randn(2), torch.randn(2)) - opt_fn_static_shape(torch.randn(3), torch.randn(3)) - - self.assertEqual(cnt_dynamic.frame_count, 2) - - with torch._dynamo.config.patch( - automatic_dynamic_shapes=True, assume_static_by_default=False - ): - opt_fn_dynamic = torch._dynamo.optimize(cnt_dynamic_1)( - lambda x, y: opt_fn_static_shape(x, y) - ) - for i in range(2, 12): - opt_fn_dynamic( - torch.randn(i), torch.randn(i) - ) # will be recompiled under new config - - self.assertEqual(cnt_dynamic.frame_count, 2) - self.assertEqual(cnt_dynamic_1.frame_count, 1) - - @disable_cache_limit() - def test_nested_fn_does_not_inherit_outer_config(self): - def g1(x): - return x + 1 - - def g2(x): - return x * 2 - - def f(x): - x = g1(x) - torch._dynamo.graph_break() - return g2(x) - - torch._dynamo.reset() - - cnt_dynamic = torch._dynamo.testing.CompileCounter() - cnt_dynamic_1 = torch._dynamo.testing.CompileCounter() - - opt_fn_static_shape = torch._dynamo.optimize(cnt_dynamic, dynamic=False)(f) - opt_fn_static_shape(torch.randn(2)) - opt_fn_static_shape(torch.randn(3)) - 
self.assertEqual(cnt_dynamic.frame_count, 4) # 2 compiles * 2 graphs - - opt_fn_dynamic = torch._dynamo.optimize(cnt_dynamic_1, dynamic=True)(g2) - - for i in range(2, 12): - opt_fn_dynamic( - torch.randn(i), - ) # will be recompiled under new config - - self.assertEqual(cnt_dynamic_1.frame_count, 1) - - @disable_cache_limit() - def test_multiple_compile_recompiles(self): - cnt_dynamic = torch._dynamo.testing.CompileCounter() - - def f(dynamic, compile_count): - @torch._dynamo.optimize(cnt_dynamic, dynamic=dynamic) - def g(x): - return x + 1 - - for i in range(2, 12): - g(torch.randn(i)) # will be recompiled under new config - self.assertEqual(cnt_dynamic.frame_count, compile_count) - cnt_dynamic.clear() - - f(dynamic=True, compile_count=1) # first compile - f(dynamic=False, compile_count=10) # recompile - f(dynamic=True, compile_count=0) # reuse first compile product - if __name__ == "__main__": from torch._dynamo.test_case import run_tests diff --git a/test/dynamo/test_misc.py b/test/dynamo/test_misc.py index 046645ffafbde..bfa9ff7d01baf 100644 --- a/test/dynamo/test_misc.py +++ b/test/dynamo/test_misc.py @@ -950,9 +950,7 @@ def fn(x, y): # Filter out id-matches that won't reproduce run to run guard_code = filter( - lambda line: not any( - banned in line for banned in ["id", "lookup_backend", "config_hash"] - ), + lambda line: "id" not in line and "lookup_backend" not in line, sorted(guard_code), ) guard_code_str = "\n".join(guard_code) diff --git a/test/dynamo/test_recompiles.py b/test/dynamo/test_recompiles.py index 171bf9020d728..ff39d0c8052a0 100644 --- a/test/dynamo/test_recompiles.py +++ b/test/dynamo/test_recompiles.py @@ -37,23 +37,15 @@ def run_foo_6_times_and_count_recompiles(dynamic=None): return cnt + @patch.object(torch._dynamo.config, "automatic_dynamic_shapes", False) + @patch.object(torch._dynamo.config, "assume_static_by_default", True) def run_without_automatic(): - with torch._dynamo.config.patch( - { - "automatic_dynamic_shapes": False, - "assume_static_by_default": True, - } - ): - return run_foo_6_times_and_count_recompiles() + return run_foo_6_times_and_count_recompiles() + @patch.object(torch._dynamo.config, "automatic_dynamic_shapes", True) + @patch.object(torch._dynamo.config, "assume_static_by_default", True) def run_with_automatic(): - with torch._dynamo.config.patch( - { - "automatic_dynamic_shapes": True, - "assume_static_by_default": True, - } - ): - return run_foo_6_times_and_count_recompiles() + return run_foo_6_times_and_count_recompiles() without = run_without_automatic() self.assertEqual(without.frame_count, 5) @@ -107,23 +99,15 @@ def run_foo_6_times_and_count_recompiles(): return cnt + @patch.object(torch._dynamo.config, "automatic_dynamic_shapes", False) + @patch.object(torch._dynamo.config, "assume_static_by_default", True) def run_without_automatic(): - with torch._dynamo.config.patch( - { - "automatic_dynamic_shapes": False, - "assume_static_by_default": True, - } - ): - return run_foo_6_times_and_count_recompiles() + return run_foo_6_times_and_count_recompiles() + @patch.object(torch._dynamo.config, "automatic_dynamic_shapes", True) + @patch.object(torch._dynamo.config, "assume_static_by_default", True) def run_with_automatic(): - with torch._dynamo.config.patch( - { - "automatic_dynamic_shapes": True, - "assume_static_by_default": True, - } - ): - return run_foo_6_times_and_count_recompiles() + return run_foo_6_times_and_count_recompiles() without = run_without_automatic() self.assertEqual(without.frame_count, 5) @@ -162,23 +146,15 @@ def 
run_foo_6_times_and_count_recompiles_swap_types(): return cnt + @patch.object(torch._dynamo.config, "automatic_dynamic_shapes", False) + @patch.object(torch._dynamo.config, "assume_static_by_default", True) def run_without_automatic(): - with torch._dynamo.config.patch( - { - "automatic_dynamic_shapes": False, - "assume_static_by_default": True, - } - ): - return run_foo_6_times_and_count_recompiles_swap_types() + return run_foo_6_times_and_count_recompiles_swap_types() + @patch.object(torch._dynamo.config, "automatic_dynamic_shapes", True) + @patch.object(torch._dynamo.config, "assume_static_by_default", True) def run_with_automatic(): - with torch._dynamo.config.patch( - { - "automatic_dynamic_shapes": True, - "assume_static_by_default": True, - } - ): - return run_foo_6_times_and_count_recompiles_swap_types() + return run_foo_6_times_and_count_recompiles_swap_types() without = run_without_automatic() self.assertEqual(without.frame_count, 5) @@ -275,45 +251,29 @@ def run_foo_6_times_and_count_recompiles(): return cnt + @patch.object(torch._dynamo.config, "force_parameter_static_shapes", True) + @patch.object(torch._dynamo.config, "automatic_dynamic_shapes", False) + @patch.object(torch._dynamo.config, "assume_static_by_default", True) def run_static_comp_default_param(): - with torch._dynamo.config.patch( - { - "force_parameter_static_shapes": True, - "automatic_dynamic_shapes": False, - "assume_static_by_default": True, - } - ): - return run_foo_6_times_and_count_recompiles() + return run_foo_6_times_and_count_recompiles() + @patch.object(torch._dynamo.config, "force_parameter_static_shapes", True) + @patch.object(torch._dynamo.config, "automatic_dynamic_shapes", True) + @patch.object(torch._dynamo.config, "assume_static_by_default", True) def run_dynamic_comp_default_param(): - with torch._dynamo.config.patch( - { - "force_parameter_static_shapes": True, - "automatic_dynamic_shapes": True, - "assume_static_by_default": True, - } - ): - return run_foo_6_times_and_count_recompiles() + return run_foo_6_times_and_count_recompiles() + @patch.object(torch._dynamo.config, "force_parameter_static_shapes", False) + @patch.object(torch._dynamo.config, "automatic_dynamic_shapes", False) + @patch.object(torch._dynamo.config, "assume_static_by_default", True) def run_static_comp_dynamic_param(): - with torch._dynamo.config.patch( - { - "force_parameter_static_shapes": False, - "automatic_dynamic_shapes": False, - "assume_static_by_default": True, - } - ): - return run_foo_6_times_and_count_recompiles() + return run_foo_6_times_and_count_recompiles() + @patch.object(torch._dynamo.config, "force_parameter_static_shapes", False) + @patch.object(torch._dynamo.config, "automatic_dynamic_shapes", True) + @patch.object(torch._dynamo.config, "assume_static_by_default", True) def run_dynamic_comp_dynamic_param(): - with torch._dynamo.config.patch( - { - "force_parameter_static_shapes": False, - "automatic_dynamic_shapes": True, - "assume_static_by_default": True, - } - ): - return run_foo_6_times_and_count_recompiles() + return run_foo_6_times_and_count_recompiles() torch._dynamo.reset() static_comp_default_param = run_static_comp_default_param() diff --git a/test/dynamo/test_subclasses.py b/test/dynamo/test_subclasses.py index 5199d12b9f63d..1c616628441b3 100644 --- a/test/dynamo/test_subclasses.py +++ b/test/dynamo/test_subclasses.py @@ -275,17 +275,17 @@ def __torch_function__(cls, func, types, args=(), kwargs=None): def sigmoid(self): return None - with 
torch._dynamo.config.patch("traceable_tensor_subclasses", {LocalSubclass}): - - @torch.compile(backend="eager", fullgraph=True) - def fn(x): - x.sigmoid() + @torch.compile(backend="eager", fullgraph=True) + def fn(x): + x.sigmoid() msg = ( "Accessing overridden method/attribute sigmoid on a tensor" " subclass with a __torch_function__ override is not supported" ) - with self.assertRaisesRegex(torch._dynamo.exc.Unsupported, msg): + with torch._dynamo.config.patch( + "traceable_tensor_subclasses", {LocalSubclass} + ), self.assertRaisesRegex(torch._dynamo.exc.Unsupported, msg): x = torch.ones(2, 2).as_subclass(LocalSubclass) fn(x) @@ -299,17 +299,17 @@ def __torch_function__(cls, func, types, args=(), kwargs=None): ndim = 10 - with torch._dynamo.config.patch("traceable_tensor_subclasses", {LocalSubclass}): - - @torch.compile(backend="eager", fullgraph=True) - def fn(x): - return x.ndim + @torch.compile(backend="eager", fullgraph=True) + def fn(x): + return x.ndim msg = ( "Accessing overridden method/attribute ndim on a tensor" " subclass with a __torch_function__ override is not supported" ) - with self.assertRaisesRegex(torch._dynamo.exc.Unsupported, msg): + with torch._dynamo.config.patch( + "traceable_tensor_subclasses", {LocalSubclass} + ), self.assertRaisesRegex(torch._dynamo.exc.Unsupported, msg): x = torch.ones(2, 2).as_subclass(LocalSubclass) fn(x) @@ -332,17 +332,17 @@ def ndim(self): def ndim(self, value): self._ndim = value - with torch._dynamo.config.patch("traceable_tensor_subclasses", {LocalSubclass}): - - @torch.compile(backend="eager", fullgraph=True) - def fn(x): - return x.ndim + @torch.compile(backend="eager", fullgraph=True) + def fn(x): + return x.ndim msg = ( "Accessing overridden method/attribute ndim on a tensor" " subclass with a __torch_function__ override is not supported" ) - with self.assertRaisesRegex(torch._dynamo.exc.Unsupported, msg): + with torch._dynamo.config.patch( + "traceable_tensor_subclasses", {LocalSubclass} + ), self.assertRaisesRegex(torch._dynamo.exc.Unsupported, msg): x = torch.ones(2, 2).as_subclass(LocalSubclass) fn(x) diff --git a/torch/_dynamo/config.py b/torch/_dynamo/config.py index b9ca840931cf7..9d6ac54092d11 100644 --- a/torch/_dynamo/config.py +++ b/torch/_dynamo/config.py @@ -14,8 +14,7 @@ # or use the environment variable TORCH_LOGS="dynamo,aot,inductor" (use a prefix + to indicate higher verbosity) # see this design doc for more detailed info # Design doc: https://docs.google.com/document/d/1ZRfTWKa8eaPq1AxaiHrq4ASTPouzzlPiuquSBEJYwS8/edit# -# the name of a file to write the logs to (currently unused) -# TODO(jon-chuang): use setup_log_file in setup_compile_debug +# the name of a file to write the logs to # [@compile_ignored: debug] log_file_name = None diff --git a/torch/_dynamo/convert_frame.py b/torch/_dynamo/convert_frame.py index cee3ac372b45c..77808923ed7ec 100644 --- a/torch/_dynamo/convert_frame.py +++ b/torch/_dynamo/convert_frame.py @@ -85,7 +85,6 @@ log = logging.getLogger(__name__) bytecode_log = torch._logging.getArtifactLogger(__name__, "bytecode") -recompiles_log = torch._logging.getArtifactLogger(__name__, "recompiles") GlobalStateGuard = torch._C._dynamo.guards.GlobalStateGuard @@ -379,23 +378,21 @@ def format_guard_failures(): }, ) - with config.patch(_patch_config_if_changed()): - compiled_product = _compile( - frame.f_code, - frame.f_globals, - frame.f_locals, - frame.f_builtins, - compiler_fn, - one_graph, - export, - export_constraints, - hooks, - cache_size, - frame, - frame_state=frame_state, - 
compile_id=compile_id, - ) - return compiled_product + return _compile( + frame.f_code, + frame.f_globals, + frame.f_locals, + frame.f_builtins, + compiler_fn, + one_graph, + export, + export_constraints, + hooks, + cache_size, + frame, + frame_state=frame_state, + compile_id=compile_id, + ) _convert_frame_assert._torchdynamo_orig_callable = compiler_fn # type: ignore[attr-defined] @@ -406,45 +403,6 @@ def _clone_with_backend(backend): return _convert_frame_assert -def _patch_config_if_changed(): - """ - Will return {} if the ambient config is the same as the compile-time. - Else, returns the compile-time config saved on the code object. - """ - patch: Dict[str, Any] = {} - eval_frame = torch._dynamo.eval_frame - eval_frame._maybe_init_guarded_config_cache() - if eval_frame.config_cache.saved_config_and_hash is None: - return patch - - saved = eval_frame.config_cache.saved_config_and_hash - saved_config, saved_config_hash = saved.config, saved.hash - current_config_hash = config.get_hash() - assert current_config_hash is not None - - if saved_config_hash != current_config_hash: - patch = saved_config - if recompiles_log.isEnabledFor(logging.DEBUG): - recompiles_log.debug( - ( - "Current config does not match config saved when compiling\n" - "Saved hash: %s, Current hash: %s\nRestoring saved config." - ), - saved_config_hash.hex(), - current_config_hash.hex(), - ) - config_dict_ref = config.shallow_copy_dict() - for key in patch: - if patch[key] != config_dict_ref[key]: - recompiles_log.debug( - "* %s=%s (prev: %s)", - key, - patch[key], - config_dict_ref[key], - ) - return patch - - def maybe_cprofile(func): if config.cprofile: return cprofile_wrapper(func) diff --git a/torch/_dynamo/eval_frame.py b/torch/_dynamo/eval_frame.py index 5bdb314f9f894..9bbea28758fd2 100644 --- a/torch/_dynamo/eval_frame.py +++ b/torch/_dynamo/eval_frame.py @@ -14,7 +14,6 @@ import traceback import types import warnings -from dataclasses import dataclass from enum import Enum from os.path import dirname, join from typing import ( @@ -287,66 +286,19 @@ def innermost_fn(fn): return unaltered_fn -# The config to restore to should dynamo compile / recompile when -# executing from the compiled function's _TorchDynamoContext -config_cache = threading.local() - - -@dataclass -class ConfigAndHash: - config: Dict[str, Any] - hash: bytes - - -def _maybe_init_guarded_config_cache(): - if not hasattr(config_cache, "saved_config_and_hash"): - # Optional[ConfigAndHash] - config_cache.saved_config_and_hash = None - - @contextlib.contextmanager -def restore_guarded_dynamo_config( - first_ctx: bool, saved_config_and_hash: ConfigAndHash -): - _maybe_init_guarded_config_cache() - # Set exactly once from top-level compile - is_top_level = False - try: - if first_ctx and config_cache.saved_config_and_hash is None: - is_top_level = True - config_cache.saved_config_and_hash = saved_config_and_hash - log.debug( - "Setting top-level compile config hash: %s", - saved_config_and_hash.hash.hex(), - ) - else: - log.debug("Ignoring inner dynamo compile config and hash") +def enable_dynamic(enable: Optional[bool] = None, export: bool = False): + if enable is None: yield - finally: - if is_top_level: - log.debug( - "Unsetting top-level compile config hash: %s", - config_cache.saved_config_and_hash.hash.hex(), - ) - config_cache.saved_config_and_hash = None - - -def _get_config_and_hash(dynamic=None): - if dynamic is None: - updates = {} - elif dynamic: - updates = {"assume_static_by_default": False} - else: - updates = 
{"automatic_dynamic_shapes": False, "assume_static_by_default": True} - return ConfigAndHash(*config.get_config_and_hash_with_updates(updates)) - - -def get_saved_else_current_config_hash() -> bytes: - _maybe_init_guarded_config_cache() - if config_cache.saved_config_and_hash is not None: - return config_cache.saved_config_and_hash.hash + elif enable: + # Assume everything is dynamic by default + with config.patch(assume_static_by_default=False): + yield else: - return config.get_hash() + with config.patch( + automatic_dynamic_shapes=False, assume_static_by_default=True + ): + yield class _TorchDynamoContext: @@ -361,7 +313,6 @@ def __init__( export=False, dynamic=None, compiler_config=None, - save_config=True, ): super().__init__() assert callable(callback) or callback is False or callback is None @@ -373,19 +324,8 @@ def __init__( self.export = export self.dynamic = dynamic self.compiler_config = compiler_config - self.save_config = save_config and first_ctx - if self.save_config: - self.save_and_hash_config() patch_fn() - def save_and_hash_config(self): - # save current value of dynamo configs - self.saved_config_and_hash = _get_config_and_hash(self.dynamic) - log.debug( - "Saving dynamo config and hash for new compiled object(s). Hash: %s", - self.saved_config_and_hash.hash.hex(), - ) - def __enter__(self): if config.raise_on_ctx_manager_usage: raise RuntimeError( @@ -399,19 +339,15 @@ def __enter__(self): self.backend_cache_manager.__enter__() self.backend_ctx = self.extra_ctx_ctor() self.backend_ctx.__enter__() - if self.save_config: - self.dynamo_config_ctx = restore_guarded_dynamo_config( - self.first_ctx, self.saved_config_and_hash - ) - self.dynamo_config_ctx.__enter__() + self.dynamic_ctx = enable_dynamic(self.dynamic, self.export) + self.dynamic_ctx.__enter__() def __exit__(self, exc_type, exc_val, exc_tb): assert self.prior is not unset set_eval_frame(self.prior) self.prior = unset # TODO: This is totally not the right way to chain contexts manually - if self.save_config: - self.dynamo_config_ctx.__exit__(exc_type, exc_val, exc_tb) + self.dynamic_ctx.__exit__(exc_type, exc_val, exc_tb) self.backend_ctx.__exit__(exc_type, exc_val, exc_tb) self.backend_cache_manager.__exit__(exc_type, exc_val, exc_tb) @@ -492,17 +428,13 @@ def _fn(*args, **kwargs): backend_cache_manager.__enter__() backend_ctx = backend_ctx_ctor() backend_ctx.__enter__() - if self.save_config: - dynamo_config_ctx = restore_guarded_dynamo_config( - self.first_ctx, self.saved_config_and_hash - ) - dynamo_config_ctx.__enter__() + dynamic_ctx = enable_dynamic(self.dynamic, self.export) + dynamic_ctx.__enter__() try: return fn(*args, **kwargs) finally: set_eval_frame(prior) - if self.save_config: - dynamo_config_ctx.__exit__(None, None, None) + dynamic_ctx.__exit__(None, None, None) backend_ctx.__exit__(None, None, None) backend_cache_manager.__exit__(None, None, None) @@ -572,7 +504,6 @@ def __init__( *, export=False, dynamic=None, - save_config=True, compiler_config=None, ): def on_enter(): @@ -587,7 +518,6 @@ def on_enter(): export=export, dynamic=dynamic, compiler_config=compiler_config, - save_config=save_config, ) @@ -677,7 +607,6 @@ def _optimize_catch_errors( export=False, dynamic=None, compiler_config=None, - save_config=True, ): return OptimizeContext( catch_errors_wrapper(compile_fn, hooks), @@ -686,7 +615,6 @@ def _optimize_catch_errors( export=export, dynamic=dynamic, compiler_config=compiler_config, - save_config=save_config, ) @@ -732,7 +660,6 @@ def optimize( guard_fail_fn=None, disable=False, 
dynamic=None, - save_config=True, ): """ The main entrypoint of TorchDynamo. Do graph capture and call @@ -753,9 +680,7 @@ def optimize( dynamic: If True, upfront compile as dynamic a kernel as possible. If False, disable all dynamic shapes support (always specialize). If None, automatically detect when sizes vary and generate dynamic kernels upon recompile. - save_config: If True, recompiling this function will first restore the dynamo config - at the time when `optimize` was first called, for the duration of the compilation - process. + Example Usage:: @torch._dynamo.optimize() @@ -783,14 +708,12 @@ def toy_example(a, b): backend, dynamic=dynamic, hooks=hooks, - save_config=save_config, ) return _optimize_catch_errors( convert_frame.convert_frame(backend, hooks=hooks), hooks, backend_ctx_ctor, dynamic=dynamic, - save_config=save_config, compiler_config=backend.get_compiler_config() if hasattr(backend, "get_compiler_config") else None, @@ -1486,7 +1409,6 @@ def optimize_assert( export=False, export_constraints=None, dynamic=None, - save_config=True, ): """ The same as `torch._dynamo.optimize(backend, nopython=True)` @@ -1504,7 +1426,6 @@ def optimize_assert( backend_ctx_ctor, export=export, dynamic=dynamic, - save_config=save_config, ) diff --git a/torch/_dynamo/guards.py b/torch/_dynamo/guards.py index c77fbfcdfc60b..e79cada9120ed 100644 --- a/torch/_dynamo/guards.py +++ b/torch/_dynamo/guards.py @@ -107,9 +107,6 @@ def uninteresting_files(): "___skip_backend_check": ( lambda: torch._dynamo.eval_frame.guarded_backend_cache.skip_backend_check_for_run_only_mode ), - "___compile_config_hash": ( - lambda: torch._dynamo.eval_frame.get_saved_else_current_config_hash().hex() - ), "___odict_getitem": collections.OrderedDict.__getitem__, "___dict_param_key_ids": dict_param_key_ids, "___dict_const_keys": dict_const_keys, @@ -601,15 +598,6 @@ def BACKEND_MATCH(self, guard: Guard): ] self._produce_guard_code(guard, code) - def CONFIG_HASH_MATCH(self, guard: Guard): - """Guard on the hash of the compiled function's dynamo config""" - - assert guard.source is GuardSource.GLOBAL - code = [ - f"___compile_config_hash() == '{torch._dynamo.eval_frame.get_saved_else_current_config_hash().hex()}'" - ] - self._produce_guard_code(guard, code) - def SHAPE_ENV(self, guard: Guard): # Let's handle ShapeEnv guards. To do this, we will resolve # shape variables to sources from tracked_fakes. 
This must happen after diff --git a/torch/_dynamo/output_graph.py b/torch/_dynamo/output_graph.py index a41aec07ee38b..66cd7acb52321 100644 --- a/torch/_dynamo/output_graph.py +++ b/torch/_dynamo/output_graph.py @@ -384,8 +384,6 @@ def init_ambient_guards(self): self.guards.add(GlobalStateSource().make_guard(GuardBuilder.BACKEND_MATCH)) - self.guards.add(GlobalStateSource().make_guard(GuardBuilder.CONFIG_HASH_MATCH)) - def add_cleanup_hook(self, fn: Callable[[], Any]): self.cleanup_hooks.append(fn) diff --git a/torch/testing/_internal/common_utils.py b/torch/testing/_internal/common_utils.py index 6b4bffdc9db06..1bf3b03355b68 100644 --- a/torch/testing/_internal/common_utils.py +++ b/torch/testing/_internal/common_utils.py @@ -2747,12 +2747,12 @@ def _run_with_retry(self, result=None, num_runs_left=0, report_only=True, num_re supress_errors = torch._dynamo.config.suppress_errors with unittest.mock.patch("torch._dynamo.config.suppress_errors", supress_errors): if TEST_WITH_TORCHINDUCTOR: - super_run = torch._dynamo.optimize("inductor", save_config=False)(super_run) + super_run = torch._dynamo.optimize("inductor")(super_run) elif TEST_WITH_AOT_EAGER: - super_run = torch._dynamo.optimize("aot_eager_decomp_partition", save_config=False)(super_run) + super_run = torch._dynamo.optimize("aot_eager_decomp_partition")(super_run) elif TEST_WITH_TORCHDYNAMO: # TorchDynamo optimize annotation - super_run = torch._dynamo.optimize("eager", save_config=False, nopython=nopython)(super_run) + super_run = torch._dynamo.optimize("eager", nopython=nopython)(super_run) super_run(result=result) diff --git a/torch/utils/_config_module.py b/torch/utils/_config_module.py index 20050f7d2dcbe..4412048508e52 100644 --- a/torch/utils/_config_module.py +++ b/torch/utils/_config_module.py @@ -9,7 +9,7 @@ import unittest import warnings from types import FunctionType, ModuleType -from typing import Any, Dict, Optional, Set, Tuple, Union +from typing import Any, Dict, Optional, Set, Union from unittest import mock # Types saved/loaded in configs @@ -172,23 +172,6 @@ def codegen_config(self) -> str: lines.append(f"{mod}.{k} = {v!r}") return "\n".join(lines) - def get_config_and_hash_with_updates( - self, updates: Dict[str, Any] - ) -> Tuple[Dict[str, Any], bytes]: - """Hashes the configs that are not compile_ignored, along with updates""" - if any(k in self._compile_ignored_keys for k in updates): - raise ValueError("update keys cannot be @compile_ignored") - cfg = { - k: v for k, v in self._config.items() if k not in self._compile_ignored_keys - } - cfg.update(updates) - hashed = self._get_hash(cfg) - return cfg, hashed - - def _get_hash(self, config: Dict[str, Any]) -> bytes: - string_to_hash = repr(sorted(config.items())) - return hashlib.md5(string_to_hash.encode("utf-8")).digest() - def get_hash(self) -> bytes: """Hashes the configs that are not compile_ignored""" if self._is_dirty or self._hash_digest is None: @@ -197,7 +180,8 @@ def get_hash(self) -> bytes: for k, v in self._config.items() if k not in self._compile_ignored_keys } - self._hash_digest = self._get_hash(dict_to_hash) + string_to_hash = repr(sorted(dict_to_hash.items())) + self._hash_digest = hashlib.md5(string_to_hash.encode("utf-8")).digest() self._is_dirty = False return self._hash_digest diff --git a/torch/utils/_config_typing.pyi b/torch/utils/_config_typing.pyi index b0d7b80d405a3..c31eb5f34a59d 100644 --- a/torch/utils/_config_typing.pyi +++ b/torch/utils/_config_typing.pyi @@ -1,4 +1,4 @@ -from typing import Any, Dict, Optional, Tuple, 
TYPE_CHECKING, Union +from typing import Any, Dict, Optional, TYPE_CHECKING, Union """ This was semi-automatically generated by running @@ -24,9 +24,6 @@ assert TYPE_CHECKING, "Do not use at runtime" def save_config() -> bytes: ... def codegen_config() -> str: ... -def get_config_and_hash_with_updates( - updates: Dict[str, Any] -) -> Tuple[Dict[str, Any], bytes]: ... def get_hash() -> bytes: ... def to_dict() -> Dict[str, Any]: ... def shallow_copy_dict() -> Dict[str, Any]: ...
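
For context on the behavior this revert restores: the `dynamic=` argument is again routed through plain `torch._dynamo.config.patch(...)` settings (see the restored `enable_dynamic` context manager in eval_frame.py above) rather than a config hash saved on the compiled object and checked by a CONFIG_HASH_MATCH guard. The sketch below is modeled on the reverted test_recompiles.py/test_config.py tests and only uses APIs that appear in the diff (`CompileCounter`, `torch._dynamo.optimize`, `torch._dynamo.config.patch`, `torch._dynamo.reset`); the helper names, the size loop, and the exact frame counts are illustrative assumptions, not code taken from this patch.

import torch
import torch._dynamo
import torch._dynamo.testing


def foo(a, b):
    return a - b * 10


def compile_foo_at_several_sizes():
    # CompileCounter is a backend that records how many frames dynamo compiles.
    cnt = torch._dynamo.testing.CompileCounter()
    opt_foo = torch._dynamo.optimize(cnt)(foo)
    for i in range(2, 8):
        opt_foo(torch.randn(i), torch.randn(i))
    return cnt


torch._dynamo.reset()
with torch._dynamo.config.patch(
    automatic_dynamic_shapes=False, assume_static_by_default=True
):
    # Static-by-default (what dynamic=False maps to): each new input size
    # triggers a fresh recompile.
    static = compile_foo_at_several_sizes()

torch._dynamo.reset()
with torch._dynamo.config.patch(
    automatic_dynamic_shapes=True, assume_static_by_default=True
):
    # Automatic dynamic shapes: after the first size change, dynamo recompiles
    # once with a dynamic kernel and then reuses it for later sizes.
    automatic = compile_foo_at_several_sizes()

# Expect static.frame_count to grow with the number of distinct sizes and
# automatic.frame_count to stay small; exact values depend on the build.
print(static.frame_count, automatic.frame_count)

With the guarded-config machinery removed, it is the config in effect at call time, not a snapshot captured when `optimize` was first applied, that governs compilation and recompilation, which is why the `save_config=` plumbing and the removed test_config.py cases above go away together.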