Skip to content

Commit

Permalink
Teach dynamo about torch.func.jvp (#119926)
Browse files Browse the repository at this point in the history
Summary:
List of changes:
- Replace JVP_NESTING by torch._C._functorch.maybe_current_level()
- Remove all increment nesting functions from wrap_fx_proxy_cls
- fwAD.make_dual receives the dual_level as keyword argument
- Add jvp_increment_nesting, set_fwd_grad_enabled and dual_level context managers to dynamo

X-link: pytorch/pytorch#119926
Approved by: https://github.com/zou3519

Reviewed By: huydhn

Differential Revision: D55273902

fbshipit-source-id: ce4d4f2f74a3c5545de62b13fae5f5a954f6cc3a
  • Loading branch information
guilhermeleobas authored and facebook-github-bot committed Mar 23, 2024
1 parent 9d64293 commit d8fafe6
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions userbenchmark/dynamo/dynamobench/_dynamo/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -877,12 +877,12 @@ def is_namedtuple(obj):


def is_namedtuple_cls(cls):
"""Test if an object is a namedtuple or a torch.return_types.* quasi-namedtuple"""
"""Test if an object is a namedtuple or a (torch.return_types|torch.autograd.forward_ad).* quasi-namedtuple"""
try:
if issubclass(cls, tuple):
bases = getattr(cls, "__bases__", []) or [None]
module = getattr(cls, "__module__", None)
return module == "torch.return_types" or (
return module in ("torch.return_types", "torch.autograd.forward_ad") or (
bases[0] is tuple and hasattr(cls, "_make") and hasattr(cls, "_fields")
)
except TypeError:
Expand Down

0 comments on commit d8fafe6

Please sign in to comment.