Skip to content

Commit

Permalink
Add namedtuple pytree serialization (#123648)
Browse files Browse the repository at this point in the history
Summary:
Fixes pytorch/pytorch#123388 (comment)

X-link: pytorch/pytorch#123648
Approved by: https://github.com/desertfire

Reviewed By: DanilBaibak

Differential Revision: D55961061

Pulled By: angelayi

fbshipit-source-id: c34b610a6835a5f04deda92e5d8658299fb2ca3c
  • Loading branch information
angelayi authored and facebook-github-bot committed Apr 10, 2024
1 parent fc72ed4 commit 34ce4aa
Showing 1 changed file with 9 additions and 1 deletion.
10 changes: 9 additions & 1 deletion userbenchmark/dynamo/dynamobench/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -1133,7 +1133,15 @@ def load(cls, model, example_inputs, device):
# copy.deepcopy is required to prevent any surprising side-effect,
# see https://github.com/pytorch/pytorch/issues/113029
example_outputs = copy.deepcopy(model)(*example_args, **example_kwargs)
_register_dataclass_output_as_pytree(example_outputs)

if pytree._is_namedtuple_instance(example_outputs):
typ = type(example_outputs)
pytree._register_namedtuple(
typ,
serialized_type_name=f"{typ.__module__}.{typ.__name__}",
)
else:
_register_dataclass_output_as_pytree(example_outputs)

gm = torch.export._trace._export(
model,
Expand Down

0 comments on commit 34ce4aa

Please sign in to comment.