diff --git a/userbenchmark/dynamo/dynamobench/_dynamo/utils.py b/userbenchmark/dynamo/dynamobench/_dynamo/utils.py index cf4082630f..75f9cd0f04 100644 --- a/userbenchmark/dynamo/dynamobench/_dynamo/utils.py +++ b/userbenchmark/dynamo/dynamobench/_dynamo/utils.py @@ -2575,3 +2575,18 @@ class Invalid(dict): # type: ignore[type-arg] pass return RemovableHandle(Invalid()) + + +# Returns a "proxy" (new object with the same class and dict) for (non-GraphModule) nn.Module's. +# Attribute changes to the original object/proxy will be reflected in the other. +# This is useful for cases where we want a keep-alive reference to a module without increasing +# its reference count. +def nn_module_proxy(mod): + if not isinstance(mod, torch.nn.Module): + return mod + if isinstance(mod, torch.fx.GraphModule): + # Dynamo-generated GM's shouldn't contain user-created GM's + return mod + proxy = mod.__class__.__new__(mod.__class__) + proxy.__dict__ = mod.__dict__ + return proxy