Skip to content

Commit

Permalink
[XLA:Python] Improve the error message for the case where the previou…
Browse files Browse the repository at this point in the history
…s permissive None treedef behavior is encountered.

PiperOrigin-RevId: 680797730
  • Loading branch information
hawkinsp authored and tensorflower-gardener committed Oct 1, 2024
1 parent 2846a34 commit 50b0da2
Showing 1 changed file with 8 additions and 3 deletions.
11 changes: 8 additions & 3 deletions third_party/xla/xla/python/pytree.cc
Original file line number Diff line number Diff line change
Expand Up @@ -595,9 +595,14 @@ nb::list PyTreeDef::FlattenUpTo(nb::handle xs) const {

case PyTreeKind::kNone:
if (!object.is_none()) {
throw std::invalid_argument(
absl::StrFormat("Expected None, got %s.",
nb::cast<std::string_view>(nb::repr(object))));
throw std::invalid_argument(absl::StrFormat(
"Expected None, got %s.\n\n"
"In previous releases of JAX, flatten-up-to used to "
"consider None to be a tree-prefix of non-None values. To obtain "
"the previous behavior, you can usually write:\n"
" jax.tree.map(lambda x, y: None if x is None else f(x, y), a, "
"b, is_leaf=lambda x: x is None)",
nb::cast<std::string_view>(nb::repr(object))));
}
break;

Expand Down

0 comments on commit 50b0da2

Please sign in to comment.