From 50b0da21c4513ff7c0980223ce0a68934eaff419 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Mon, 30 Sep 2024 18:37:34 -0700 Subject: [PATCH] [XLA:Python] Improve the error message for the case where the previous permissive None treedef behavior is encountered. PiperOrigin-RevId: 680797730 --- third_party/xla/xla/python/pytree.cc | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/third_party/xla/xla/python/pytree.cc b/third_party/xla/xla/python/pytree.cc index 5592a96454821d..e5662f9f5da674 100644 --- a/third_party/xla/xla/python/pytree.cc +++ b/third_party/xla/xla/python/pytree.cc @@ -595,9 +595,14 @@ nb::list PyTreeDef::FlattenUpTo(nb::handle xs) const { case PyTreeKind::kNone: if (!object.is_none()) { - throw std::invalid_argument( - absl::StrFormat("Expected None, got %s.", - nb::cast(nb::repr(object)))); + throw std::invalid_argument(absl::StrFormat( + "Expected None, got %s.\n\n" + "In previous releases of JAX, flatten-up-to used to " + "consider None to be a tree-prefix of non-None values. To obtain " + "the previous behavior, you can usually write:\n" + " jax.tree.map(lambda x, y: None if x is None else f(x, y), a, " + "b, is_leaf=lambda x: x is None)", + nb::cast(nb::repr(object)))); } break;