Traceback (most recent call last):
File "/Users/freedom/tf-master/lib/python3.9/site-packages/numpy/core/fromnumeric.py", line 3154, in ndim
return a.ndim
AttributeError: 'list' object has no attribute 'ndim'
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "/tmp/f_jax.py", line 21, in <module>
c_jit = foo_jit(a, b)
File "/Users/freedom/tf-master/lib/python3.9/site-packages/jax/_src/traceback_util.py", line 166, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/Users/freedom/tf-master/lib/python3.9/site-packages/jax/_src/pjit.py", line 253, in cache_miss
outs, out_flat, out_tree, args_flat, jaxpr = _python_pjit_helper(
File "/Users/freedom/tf-master/lib/python3.9/site-packages/jax/_src/pjit.py", line 161, in _python_pjit_helper
args_flat, _, params, in_tree, out_tree, _ = infer_params_fn(
File "/Users/freedom/tf-master/lib/python3.9/site-packages/jax/_src/api.py", line 324, in infer_params
return pjit.common_infer_params(pjit_info_args, *args, **kwargs)
File "/Users/freedom/tf-master/lib/python3.9/site-packages/jax/_src/pjit.py", line 491, in common_infer_params
jaxpr, consts, canonicalized_out_shardings_flat = _pjit_jaxpr(
File "/Users/freedom/tf-master/lib/python3.9/site-packages/jax/_src/pjit.py", line 969, in _pjit_jaxpr
jaxpr, final_consts, out_type = _create_pjit_jaxpr(
File "/Users/freedom/tf-master/lib/python3.9/site-packages/jax/_src/linear_util.py", line 345, in memoized_fun
ans = call(fun, *args)
File "/Users/freedom/tf-master/lib/python3.9/site-packages/jax/_src/pjit.py", line 922, in _create_pjit_jaxpr
jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_dynamic(
File "/Users/freedom/tf-master/lib/python3.9/site-packages/jax/_src/profiler.py", line 314, in wrapper
return func(*args, **kwargs)
File "/Users/freedom/tf-master/lib/python3.9/site-packages/jax/_src/interpreters/partial_eval.py", line 2155, in trace_to_jaxpr_dynamic
jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(
File "/Users/freedom/tf-master/lib/python3.9/site-packages/jax/_src/interpreters/partial_eval.py", line 2177, in trace_to_subjaxpr_dynamic
ans = fun.call_wrapped(*in_tracers_)
File "/Users/freedom/tf-master/lib/python3.9/site-packages/jax/_src/linear_util.py", line 188, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File "/tmp/f_jax.py", line 14, in foo_jit
return jnp.tile(a, b)
File "/Users/freedom/tf-master/lib/python3.9/site-packages/jax/_src/numpy/lax_numpy.py", line 1813, in tile
result = broadcast_to(reshape(A, [j for i in A_shape for j in [1, i]]),
File "/Users/freedom/tf-master/lib/python3.9/site-packages/jax/_src/numpy/lax_numpy.py", line 1191, in broadcast_to
return util._broadcast_to(array, shape)
File "/Users/freedom/tf-master/lib/python3.9/site-packages/jax/_src/numpy/util.py", line 388, in _broadcast_to
if not isinstance(shape, tuple) and np.ndim(shape) == 0:
File "<__array_function__ internals>", line 180, in ndim
File "/Users/freedom/tf-master/lib/python3.9/site-packages/numpy/core/fromnumeric.py", line 3156, in ndim
return asarray(a).ndim
File "/Users/freedom/tf-master/lib/python3.9/site-packages/jax/_src/core.py", line 611, in __array__
raise TracerArrayConversionError(self)
jax._src.traceback_util.UnfilteredStackTrace: jax.errors.TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on traced array with shape int32[].
The error occurred while tracing the function foo_jit at /tmp/f_jax.py:10 for jit. This concrete value was not available in Python because it depends on the value of the argument b.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerArrayConversionError
The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.
--------------------
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/tmp/f_jax.py", line 21, in <module>
c_jit = foo_jit(a, b)
File "/tmp/f_jax.py", line 14, in foo_jit
return jnp.tile(a, b)
File "/Users/freedom/tf-master/lib/python3.9/site-packages/jax/_src/numpy/lax_numpy.py", line 1813, in tile
result = broadcast_to(reshape(A, [j for i in A_shape for j in [1, i]]),
File "/Users/freedom/tf-master/lib/python3.9/site-packages/jax/_src/numpy/lax_numpy.py", line 1191, in broadcast_to
return util._broadcast_to(array, shape)
File "/Users/freedom/tf-master/lib/python3.9/site-packages/jax/_src/numpy/util.py", line 388, in _broadcast_to
if not isinstance(shape, tuple) and np.ndim(shape) == 0:
File "<__array_function__ internals>", line 180, in ndim
File "/Users/freedom/tf-master/lib/python3.9/site-packages/numpy/core/fromnumeric.py", line 3156, in ndim
return asarray(a).ndim
jax.errors.TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on traced array with shape int32[].
The error occurred while tracing the function foo_jit at /tmp/f_jax.py:10 for jit. This concrete value was not available in Python because it depends on the value of the argument b.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerArrayConversionError
What jax/jaxlib version are you using?
jax/jaxlib v0.4.14
Which accelerator(s) are you using?
CPU/GPU
Additional system info
Python 3.9/3.10, macOS
NVIDIA GPU info
No response
For jax.jit, the output shapes must only depend on the input shapes, not the input values.
Your foo_jit violates that: the output shape depends on values in b.
The usual way to mitigate this is to declare such arguments as "static", which causes recompilation for different values and requires the argument to be hashable.
Repro:
import jax
import jax.numpy as jnp
import numpy as np
from functools import partial

def foo(a, b):
    return jnp.tile(a, b)

@partial(jax.jit, static_argnames="b")  # declare `b` as static
def foo_jit(a, b):
    return jnp.tile(a, b)

a = np.ndarray([1, 20, 20])
b = (64, 1, 1)  # a tuple is hashable

c = foo(a, b)
print(c.shape)
# (64, 20, 20)

c_jit = foo_jit(a, b)
print(c_jit.shape)
# (64, 20, 20)
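A minimal sketch of the recompilation behaviour mentioned above (the name foo_traced and the trace counter are illustrative, not part of the original repro):

import jax
import jax.numpy as jnp
import numpy as np
from functools import partial

traces = 0

@partial(jax.jit, static_argnames="b")
def foo_traced(a, b):
    global traces
    traces += 1  # Python side effects run only while JAX traces the function
    return jnp.tile(a, b)

a = np.ones((1, 20, 20))
foo_traced(a, (64, 1, 1))
foo_traced(a, (64, 1, 1))  # same static value: served from the compilation cache
foo_traced(a, (32, 1, 1))  # new static value: traced and compiled again
print(traces)
# 2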
Thanks for answering the question! @clemisch's answer touches all the relevant pieces: this is working as expected.
FYI, printing from a jit'ed function does not work as usual. See the docs about jax.debug.print.
Keep in mind jax.debug.print is only necessary for printing traced runtime values such as array contents. Static values like array shapes and dtypes can be printed at trace time with a standard Python print.
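A minimal sketch of the difference (the function scaled_sum is illustrative, not from this issue):

import jax
import jax.numpy as jnp

@jax.jit
def scaled_sum(x):
    # Shapes and dtypes are static, so a plain Python print works at trace time
    # (it runs once per trace, not on every call).
    print("trace-time shape:", x.shape, x.dtype)
    y = 2.0 * x
    # The values inside y are tracers during tracing; jax.debug.print defers
    # the actual printing until the compiled function runs.
    jax.debug.print("runtime sum: {}", jnp.sum(y))
    return y

scaled_sum(jnp.arange(4.0))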
I'm going to close this, because it's working as expected. Feel free to comment here or open another issue if you have additional questions!
Description
I ran into a possible bug when trying to make a Keras IO example Keras-Core based and backend-agnostic (keras-team/keras-core#623).
With the code snippets below, foo(a, b) returns the expected tiled array, while foo_jit(a, b) fails with the error messages shown above.
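A minimal sketch of the failing snippet, reconstructed from the traceback and the repro above (the original /tmp/f_jax.py may differ in details such as how a is built):

import jax
import jax.numpy as jnp
import numpy as np

def foo(a, b):
    return jnp.tile(a, b)

@jax.jit  # no static argument declared, so the elements of `b` are traced
def foo_jit(a, b):
    return jnp.tile(a, b)

a = np.ones((1, 20, 20))
b = (64, 1, 1)

c = foo(a, b)          # works outside jit: b is a concrete tuple
print(c.shape)         # (64, 20, 20)

c_jit = foo_jit(a, b)  # raises TracerArrayConversionError while tracing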