From 54acff042a003087f24f9dc0e695d6895f6fee1d Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Thu, 12 Sep 2024 18:13:51 -0700 Subject: [PATCH] Reverts 5adafde02706dfe24e3a56744d0baecfdc3df90b PiperOrigin-RevId: 674083626 --- third_party/xla/xla/python/pytree.cc | 12 +++++++++--- third_party/xla/xla/python/xla_client.py | 2 +- 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/third_party/xla/xla/python/pytree.cc b/third_party/xla/xla/python/pytree.cc index 5592a96454821d..65bfb3fe5305e4 100644 --- a/third_party/xla/xla/python/pytree.cc +++ b/third_party/xla/xla/python/pytree.cc @@ -595,9 +595,15 @@ nb::list PyTreeDef::FlattenUpTo(nb::handle xs) const { case PyTreeKind::kNone: if (!object.is_none()) { - throw std::invalid_argument( - absl::StrFormat("Expected None, got %s.", - nb::cast(nb::repr(object)))); + PythonDeprecationWarning( + /*stacklevel=*/3, + "In a future release of JAX, flatten-up-to will no longer " + "consider None to be a tree-prefix of non-None values, got: " + "%s.\n\n" + "To preserve the current behavior, you can usually write:\n" + " jax.tree.map(lambda x, y: None if x is None else f(x, y), a, " + "b, is_leaf=lambda x: x is None)", + nb::cast(nb::repr(object))); } break; diff --git a/third_party/xla/xla/python/xla_client.py b/third_party/xla/xla/python/xla_client.py index a92abf945c4d4a..ede0f7749b9b3a 100644 --- a/third_party/xla/xla/python/xla_client.py +++ b/third_party/xla/xla/python/xla_client.py @@ -50,7 +50,7 @@ # Just an internal arbitrary increasing number to help with backward-compatible # changes. In JAX, reference this via jax._src.lib.xla_extension_version. -_version = 284 +_version = 283 # Version number for MLIR:Python components. mlir_api_version = 57