From 934f00379eb81c596bb0e61a2d93bb0373335164 Mon Sep 17 00:00:00 2001 From: "Yanbo Liang (Meta Employee)" Date: Mon, 29 Jan 2024 11:48:18 -0800 Subject: [PATCH] Make trace_rules.lookup only handle function + callable type (#118366) Summary: Step by step changes to unblock #118264 X-link: https://github.com/pytorch/pytorch/pull/118366 Approved by: https://github.com/angelayi Reviewed By: clee2000 Differential Revision: D53171121 Pulled By: yanboliang fbshipit-source-id: afc55186a795dc7e52906a9f7febc3587fd867d8 --- .../dynamo/dynamobench/_dynamo/utils.py | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/userbenchmark/dynamo/dynamobench/_dynamo/utils.py b/userbenchmark/dynamo/dynamobench/_dynamo/utils.py index c3ed283984..6655ad7354 100644 --- a/userbenchmark/dynamo/dynamobench/_dynamo/utils.py +++ b/userbenchmark/dynamo/dynamobench/_dynamo/utils.py @@ -518,8 +518,16 @@ def is_numpy_float_type(value): ) +def is_function_or_wrapper(value): + return ( + is_function(value) + or isinstance(value, functools._lru_cache_wrapper) + and is_function(inspect.getattr_static(value, "__wrapped__")) + ) + + def is_function(value): - return istype( + return isinstance( value, ( types.FunctionType, @@ -530,6 +538,12 @@ def is_function(value): ) +def unwrap_if_wrapper(value): + if isinstance(value, functools._lru_cache_wrapper): + value = inspect.getattr_static(value, "__wrapped__") + return value + + def is_numpy_ndarray(value): if not np: return False