Revert "[dynamo] guarded config (pytorch#111299)" (pytorch#115386)
This reverts commit 5927e9c.

Differential Revision: [D51959266](https://our.internmc.facebook.com/intern/diff/D51959266)
Pull Request resolved: pytorch#115386
Approved by: https://github.com/yanboliang, https://github.com/malfet
ghstack dependencies: pytorch#115384, pytorch#115401, pytorch#115385
davidberard98 authored and pytorchmergebot committed Dec 11, 2023
Parent: 6db7b30 · Commit: 5c0976f
Showing 13 changed files with 93 additions and 497 deletions.
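Most of the removed tests below pair torch._dynamo.config.patch with torch._dynamo.testing.CompileCounter to count recompiles under different dynamic-shape settings. As a rough, hedged sketch of that idiom (not part of this commit; it simply mirrors the pattern in the deleted tests and assumes a current PyTorch build, with fn and cnt as illustrative names):

import torch
import torch._dynamo
import torch._dynamo.testing

def fn(a, b):
    return a - b * 10

cnt = torch._dynamo.testing.CompileCounter()

# With static shapes forced, each distinct input size compiles a new frame.
with torch._dynamo.config.patch(
    automatic_dynamic_shapes=False, assume_static_by_default=True
):
    opt_fn = torch._dynamo.optimize(cnt)(fn)
    opt_fn(torch.randn(2), torch.randn(2))
    opt_fn(torch.randn(3), torch.randn(3))

print(cnt.frame_count)  # 2: one compile per static input shape, as the deleted tests assert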
7 changes: 0 additions & 7 deletions test/dynamo/test_comptime.py
@@ -230,13 +230,6 @@ def _(ctx):
'obj_weakref': None
'guarded_class': None
}
global '' CONFIG_HASH_MATCH
{
'guard_types': None,
'code': None,
'obj_weakref': None
'guarded_class': None
}
shape_env '' SHAPE_ENV
{
'guard_types': None,
200 changes: 0 additions & 200 deletions test/dynamo/test_config.py
@@ -110,206 +110,6 @@ def test_config_hash(self):
assert changed_hash != newest_hash
assert newest_hash == starting_hash

@disable_cache_limit()
def test_no_saved_config(self):
def fn(a, b):
return a - b * 10

torch._dynamo.reset()
cnt_dynamic = torch._dynamo.testing.CompileCounter()
with torch._dynamo.config.patch(
automatic_dynamic_shapes=False, assume_static_by_default=True
):
opt_fn_static_shape = torch._dynamo.optimize(
cnt_dynamic, save_config=False
)(fn)
opt_fn_static_shape(torch.randn(2), torch.randn(2))
opt_fn_static_shape(torch.randn(3), torch.randn(3))

self.assertEqual(cnt_dynamic.frame_count, 2)

with torch._dynamo.config.patch(
automatic_dynamic_shapes=True, assume_static_by_default=False
):
for i in range(2, 12):
opt_fn_static_shape(
torch.randn(i), torch.randn(i)
) # will be recompiled under new config

self.assertEqual(cnt_dynamic.frame_count, 3)

@disable_cache_limit()
def test_no_saved_config_nested(self):
def fn(a, b):
return a - b * 10

torch._dynamo.reset()
cnt_dynamic = torch._dynamo.testing.CompileCounter()
cnt_dynamic_1 = torch._dynamo.testing.CompileCounter()
with torch._dynamo.config.patch(
automatic_dynamic_shapes=True, assume_static_by_default=False
):
opt_fn_static_shape = torch._dynamo.optimize(cnt_dynamic, dynamic=False)(fn)

# Will trigger a recompile because the function is compiled with static shapes
opt_fn_static_shape(torch.randn(2), torch.randn(2))
opt_fn_static_shape(torch.randn(3), torch.randn(3))

self.assertEqual(cnt_dynamic.frame_count, 2)

opt_fn_try_dynamic = torch._dynamo.optimize(
cnt_dynamic_1, save_config=False
)(opt_fn_static_shape)

for i in range(2, 6):
opt_fn_try_dynamic(torch.randn(i), torch.randn(i))
self.assertEqual(cnt_dynamic_1.frame_count, 1)

# save_config=False means whatever config is active at call time is used
with torch._dynamo.config.patch(
automatic_dynamic_shapes=False, assume_static_by_default=True
):
for i in range(6, 12):
opt_fn_try_dynamic(torch.randn(i), torch.randn(i))
self.assertEqual(cnt_dynamic_1.frame_count, 7)

@disable_cache_limit()
def test_config_changed_from_guarded_config_1(self):
def fn(a, b):
return a - b * 10

torch._dynamo.reset()

cnt_dynamic = torch._dynamo.testing.CompileCounter()
with torch._dynamo.config.patch(
automatic_dynamic_shapes=False, assume_static_by_default=True
):
opt_fn_static_shape = torch._dynamo.optimize(cnt_dynamic)(fn)
res = opt_fn_static_shape(torch.randn(2), torch.randn(2))
opt_fn_static_shape(torch.randn(3), torch.randn(3))

self.assertEqual(cnt_dynamic.frame_count, 2)

with torch._dynamo.config.patch(
automatic_dynamic_shapes=True, assume_static_by_default=False
):
for i in range(2, 12):
# Only sizes 4-11 will now be recompiled under the old config;
# 2-3 have already been compiled under the old config
# and hence will hit the cache
opt_fn_static_shape(torch.randn(i), torch.randn(i))

self.assertEqual(cnt_dynamic.frame_count, 10)

@disable_cache_limit()
def test_config_changed_from_guarded_config_2(self):
def fn(a, b):
return a - b * 10

torch._dynamo.reset()

cnt_dynamic = torch._dynamo.testing.CompileCounter()
with torch._dynamo.config.patch(
automatic_dynamic_shapes=True, assume_static_by_default=False
):
opt_fn_dynamic_shape = torch._dynamo.optimize(cnt_dynamic)(fn)
opt_fn_dynamic_shape(torch.randn(2), torch.randn(2))
opt_fn_dynamic_shape(torch.randn(3), torch.randn(3))

self.assertEqual(cnt_dynamic.frame_count, 1)

with torch._dynamo.config.patch(
automatic_dynamic_shapes=False, assume_static_by_default=True
):
for i in range(2, 12):
opt_fn_dynamic_shape(
torch.randn(i), torch.randn(i)
) # will not be recompiled due to automatic dynamic shapes

self.assertEqual(cnt_dynamic.frame_count, 1)

@disable_cache_limit()
def test_nested_compile_outer_wins(self):
def fn(a, b):
return a - b * 10

torch._dynamo.reset()

cnt_dynamic = torch._dynamo.testing.CompileCounter()
cnt_dynamic_1 = torch._dynamo.testing.CompileCounter()
with torch._dynamo.config.patch(
automatic_dynamic_shapes=False, assume_static_by_default=True
):
opt_fn_static_shape = torch._dynamo.optimize(cnt_dynamic)(fn)
opt_fn_static_shape(torch.randn(2), torch.randn(2))
opt_fn_static_shape(torch.randn(3), torch.randn(3))

self.assertEqual(cnt_dynamic.frame_count, 2)

with torch._dynamo.config.patch(
automatic_dynamic_shapes=True, assume_static_by_default=False
):
opt_fn_dynamic = torch._dynamo.optimize(cnt_dynamic_1)(
lambda x, y: opt_fn_static_shape(x, y)
)
for i in range(2, 12):
opt_fn_dynamic(
torch.randn(i), torch.randn(i)
) # will be recompiled under new config

self.assertEqual(cnt_dynamic.frame_count, 2)
self.assertEqual(cnt_dynamic_1.frame_count, 1)

@disable_cache_limit()
def test_nested_fn_does_not_inherit_outer_config(self):
def g1(x):
return x + 1

def g2(x):
return x * 2

def f(x):
x = g1(x)
torch._dynamo.graph_break()
return g2(x)

torch._dynamo.reset()

cnt_dynamic = torch._dynamo.testing.CompileCounter()
cnt_dynamic_1 = torch._dynamo.testing.CompileCounter()

opt_fn_static_shape = torch._dynamo.optimize(cnt_dynamic, dynamic=False)(f)
opt_fn_static_shape(torch.randn(2))
opt_fn_static_shape(torch.randn(3))
self.assertEqual(cnt_dynamic.frame_count, 4) # 2 compiles * 2 graphs

opt_fn_dynamic = torch._dynamo.optimize(cnt_dynamic_1, dynamic=True)(g2)

for i in range(2, 12):
opt_fn_dynamic(
torch.randn(i),
) # will be recompiled under new config

self.assertEqual(cnt_dynamic_1.frame_count, 1)

@disable_cache_limit()
def test_multiple_compile_recompiles(self):
cnt_dynamic = torch._dynamo.testing.CompileCounter()

def f(dynamic, compile_count):
@torch._dynamo.optimize(cnt_dynamic, dynamic=dynamic)
def g(x):
return x + 1

for i in range(2, 12):
g(torch.randn(i)) # will be recompiled under new config
self.assertEqual(cnt_dynamic.frame_count, compile_count)
cnt_dynamic.clear()

f(dynamic=True, compile_count=1) # first compile
f(dynamic=False, compile_count=10) # recompile
f(dynamic=True, compile_count=0) # reuse first compile product


if __name__ == "__main__":
from torch._dynamo.test_case import run_tests
4 changes: 1 addition & 3 deletions test/dynamo/test_misc.py
@@ -950,9 +950,7 @@ def fn(x, y):

# Filter out id-matches that won't reproduce run to run
guard_code = filter(
lambda line: not any(
banned in line for banned in ["id", "lookup_backend", "config_hash"]
),
lambda line: "id" not in line and "lookup_backend" not in line,
sorted(guard_code),
)
guard_code_str = "\n".join(guard_code)
108 changes: 34 additions & 74 deletions test/dynamo/test_recompiles.py
@@ -37,23 +37,15 @@ def run_foo_6_times_and_count_recompiles(dynamic=None):

return cnt

@patch.object(torch._dynamo.config, "automatic_dynamic_shapes", False)
@patch.object(torch._dynamo.config, "assume_static_by_default", True)
def run_without_automatic():
with torch._dynamo.config.patch(
{
"automatic_dynamic_shapes": False,
"assume_static_by_default": True,
}
):
return run_foo_6_times_and_count_recompiles()
return run_foo_6_times_and_count_recompiles()

@patch.object(torch._dynamo.config, "automatic_dynamic_shapes", True)
@patch.object(torch._dynamo.config, "assume_static_by_default", True)
def run_with_automatic():
with torch._dynamo.config.patch(
{
"automatic_dynamic_shapes": True,
"assume_static_by_default": True,
}
):
return run_foo_6_times_and_count_recompiles()
return run_foo_6_times_and_count_recompiles()

without = run_without_automatic()
self.assertEqual(without.frame_count, 5)
@@ -107,23 +99,15 @@ def run_foo_6_times_and_count_recompiles():

return cnt

@patch.object(torch._dynamo.config, "automatic_dynamic_shapes", False)
@patch.object(torch._dynamo.config, "assume_static_by_default", True)
def run_without_automatic():
with torch._dynamo.config.patch(
{
"automatic_dynamic_shapes": False,
"assume_static_by_default": True,
}
):
return run_foo_6_times_and_count_recompiles()
return run_foo_6_times_and_count_recompiles()

@patch.object(torch._dynamo.config, "automatic_dynamic_shapes", True)
@patch.object(torch._dynamo.config, "assume_static_by_default", True)
def run_with_automatic():
with torch._dynamo.config.patch(
{
"automatic_dynamic_shapes": True,
"assume_static_by_default": True,
}
):
return run_foo_6_times_and_count_recompiles()
return run_foo_6_times_and_count_recompiles()

without = run_without_automatic()
self.assertEqual(without.frame_count, 5)
@@ -162,23 +146,15 @@ def run_foo_6_times_and_count_recompiles_swap_types():

return cnt

@patch.object(torch._dynamo.config, "automatic_dynamic_shapes", False)
@patch.object(torch._dynamo.config, "assume_static_by_default", True)
def run_without_automatic():
with torch._dynamo.config.patch(
{
"automatic_dynamic_shapes": False,
"assume_static_by_default": True,
}
):
return run_foo_6_times_and_count_recompiles_swap_types()
return run_foo_6_times_and_count_recompiles_swap_types()

@patch.object(torch._dynamo.config, "automatic_dynamic_shapes", True)
@patch.object(torch._dynamo.config, "assume_static_by_default", True)
def run_with_automatic():
with torch._dynamo.config.patch(
{
"automatic_dynamic_shapes": True,
"assume_static_by_default": True,
}
):
return run_foo_6_times_and_count_recompiles_swap_types()
return run_foo_6_times_and_count_recompiles_swap_types()

without = run_without_automatic()
self.assertEqual(without.frame_count, 5)
@@ -275,45 +251,29 @@ def run_foo_6_times_and_count_recompiles():

return cnt

@patch.object(torch._dynamo.config, "force_parameter_static_shapes", True)
@patch.object(torch._dynamo.config, "automatic_dynamic_shapes", False)
@patch.object(torch._dynamo.config, "assume_static_by_default", True)
def run_static_comp_default_param():
with torch._dynamo.config.patch(
{
"force_parameter_static_shapes": True,
"automatic_dynamic_shapes": False,
"assume_static_by_default": True,
}
):
return run_foo_6_times_and_count_recompiles()
return run_foo_6_times_and_count_recompiles()

@patch.object(torch._dynamo.config, "force_parameter_static_shapes", True)
@patch.object(torch._dynamo.config, "automatic_dynamic_shapes", True)
@patch.object(torch._dynamo.config, "assume_static_by_default", True)
def run_dynamic_comp_default_param():
with torch._dynamo.config.patch(
{
"force_parameter_static_shapes": True,
"automatic_dynamic_shapes": True,
"assume_static_by_default": True,
}
):
return run_foo_6_times_and_count_recompiles()
return run_foo_6_times_and_count_recompiles()

@patch.object(torch._dynamo.config, "force_parameter_static_shapes", False)
@patch.object(torch._dynamo.config, "automatic_dynamic_shapes", False)
@patch.object(torch._dynamo.config, "assume_static_by_default", True)
def run_static_comp_dynamic_param():
with torch._dynamo.config.patch(
{
"force_parameter_static_shapes": False,
"automatic_dynamic_shapes": False,
"assume_static_by_default": True,
}
):
return run_foo_6_times_and_count_recompiles()
return run_foo_6_times_and_count_recompiles()

@patch.object(torch._dynamo.config, "force_parameter_static_shapes", False)
@patch.object(torch._dynamo.config, "automatic_dynamic_shapes", True)
@patch.object(torch._dynamo.config, "assume_static_by_default", True)
def run_dynamic_comp_dynamic_param():
with torch._dynamo.config.patch(
{
"force_parameter_static_shapes": False,
"automatic_dynamic_shapes": True,
"assume_static_by_default": True,
}
):
return run_foo_6_times_and_count_recompiles()
return run_foo_6_times_and_count_recompiles()

torch._dynamo.reset()
static_comp_default_param = run_static_comp_default_param()
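In test_recompiles.py the revert appears to swap the nested torch._dynamo.config.patch({...}) context managers back to unittest.mock.patch.object decorators on the helper functions. Both override dynamo config flags for the duration of the call; a minimal sketch of the two equivalent styles (illustrative helper names, assuming a current PyTorch build):

from unittest.mock import patch

import torch
import torch._dynamo

# Style restored by the revert: patch config attributes with decorators.
@patch.object(torch._dynamo.config, "automatic_dynamic_shapes", False)
@patch.object(torch._dynamo.config, "assume_static_by_default", True)
def run_decorated():
    # The patched values are visible for the duration of this call.
    return torch._dynamo.config.automatic_dynamic_shapes  # False

# Style removed by the revert: an explicit config.patch context manager.
def run_with_context_manager():
    with torch._dynamo.config.patch(
        {"automatic_dynamic_shapes": False, "assume_static_by_default": True}
    ):
        return torch._dynamo.config.automatic_dynamic_shapes  # False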
(Diffs for the remaining changed files are not shown here.)
