diff --git a/userbenchmark/dynamo/dynamobench/_dynamo/utils.py b/userbenchmark/dynamo/dynamobench/_dynamo/utils.py index 6ff258ff5..8130b6aa7 100644 --- a/userbenchmark/dynamo/dynamobench/_dynamo/utils.py +++ b/userbenchmark/dynamo/dynamobench/_dynamo/utils.py @@ -63,7 +63,6 @@ import torch.utils._pytree as pytree from torch import fx from torch._C import ( - _get_function_stack_at, _instruction_counter, _len_torch_function_stack, _pop_torch_function_stack, @@ -3087,7 +3086,9 @@ def is_parameter_freezing(): def get_torch_function_mode_stack(filter_ignored=True): from .variables.torch_function import IGNORED_MODES - stack = [_get_function_stack_at(i) for i in range(_len_torch_function_stack())] + stack = [ + get_torch_function_mode_stack_at(i) for i in range(_len_torch_function_stack()) + ] if filter_ignored: stack = [mode for mode in stack if type(mode) not in IGNORED_MODES] @@ -3107,6 +3108,11 @@ def set_torch_function_mode_stack(stack): _push_on_torch_function_stack(mode) +def clear_torch_function_mode_stack(): + for i in range(_len_torch_function_stack()): + _pop_torch_function_stack() + + def verify_guard_fn_signature(value): fn = value.__metadata_guard__ sig = inspect.signature(fn)