diff --git a/userbenchmark/dynamo/dynamobench/_dynamo/utils.py b/userbenchmark/dynamo/dynamobench/_dynamo/utils.py index da6a0c1a83..b684107b03 100644 --- a/userbenchmark/dynamo/dynamobench/_dynamo/utils.py +++ b/userbenchmark/dynamo/dynamobench/_dynamo/utils.py @@ -1059,8 +1059,11 @@ def iter_contains(items, search, tx, check_tensor_identity=False): return found -def tensor_to_id(value): - return [id(k) if isinstance(k, torch.Tensor) else k for k in value.keys()] +def tensor_or_module_to_id(value): + return [ + id(k) if isinstance(k, (torch.Tensor, torch.nn.Module)) else k + for k in value.keys() + ] def const_repr(x, *, local) -> str: