From 3cb5855aec59a0d04a077b6eacbdb174eacbbc82 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Wed, 13 Nov 2024 10:47:31 +0000 Subject: [PATCH] [Doc] Add doc on export with nested keys ghstack-source-id: 9c95e2dba6751d93c20c66d0dba0d4219dc61c0b Pull Request resolved: https://github.com/pytorch/tensordict/pull/1085 --- tutorials/sphinx_tuto/export.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/tutorials/sphinx_tuto/export.py b/tutorials/sphinx_tuto/export.py index 1d1e5f30b..df8e3fda5 100644 --- a/tutorials/sphinx_tuto/export.py +++ b/tutorials/sphinx_tuto/export.py @@ -132,6 +132,25 @@ # and the FX graph: print("fx graph:", model_export.graph_module.print_readable()) +################################################## +# Working with nested keys +# ~~~~~~~~~~~~~~~~~~~~~~~~ +# +# Nested keys are a core feature of the tensordict library, and being able to export modules that read and write +# nested entries is therefore an important feature to support. +# Because keyword arguments must be regualar strings, it is not possible for :class:`~tensordict.nn.dispatch` to work +# directly with them. Instead, ``dispatch`` will unpack nested keys joined with a regular underscore (`"_"`), as the +# following example shows. + +model_nested = Seq( + Mod(lambda x: x + 1, in_keys=[("some", "key")], out_keys=["hidden"]), + Mod(lambda x: x - 1, in_keys=["hidden"], out_keys=[("some", "output")]), +).select_out_keys(("some", "output")) + +model_nested_export = export(model_nested, args=(), kwargs={"some_key": x}) +print("exported module with nested input:", model_nested_export.module()) + + ################################################## # Note that the callable returned by `module()` is a pure python callable that can be in turn compiled using # :func:`~torch.compile`.