From f8f3aecf7a3677716db922b5c4f3fce2e6ee1273 Mon Sep 17 00:00:00 2001 From: "Yanbo Liang (Meta Employee)" Date: Mon, 11 Dec 2023 17:33:36 -0800 Subject: [PATCH] Refactor out TorchInGraphFunctionVariable and improve heuristic (#113432) Summary: This is splitted from #113009, please check https://github.com/pytorch/pytorch/pull/113009#issuecomment-1804417925 for more details. X-link: https://github.com/pytorch/pytorch/pull/113432 Approved by: https://github.com/ezyang, https://github.com/jansel Reviewed By: osalpekar Differential Revision: D52017969 Pulled By: yanboliang fbshipit-source-id: 0430cf5428cf086616f6290db3cf3cb065f0938d --- .../dynamo/dynamobench/_dynamo/utils.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/userbenchmark/dynamo/dynamobench/_dynamo/utils.py b/userbenchmark/dynamo/dynamobench/_dynamo/utils.py index 2c90d64f61..e5d93431e3 100644 --- a/userbenchmark/dynamo/dynamobench/_dynamo/utils.py +++ b/userbenchmark/dynamo/dynamobench/_dynamo/utils.py @@ -512,6 +512,19 @@ def is_numpy_float_type(value): ) +def is_function(value): + return istype( + value, + ( + types.FunctionType, + types.MethodType, + types.BuiltinFunctionType, + types.MethodDescriptorType, + types.WrapperDescriptorType, + ), + ) + + def is_numpy_ndarray(value): if not np: return False @@ -2268,11 +2281,14 @@ def get_static_address_type(t): def is_rng_state_getter_or_setter(value): getters = ( + # The following two functions are not identical, so don't remove anyone! + torch._C.Generator.get_state, torch.default_generator.get_state, torch.get_rng_state, torch.cuda.get_rng_state, ) setters = ( + torch._C.Generator.set_state, torch.default_generator.set_state, torch.set_rng_state, torch.cuda.set_rng_state,