Possible bug in jnp.tile() #16994

Closed
freedomtan opened this issue Aug 7, 2023 · 3 comments
Labels
bug Something isn't working


Description

I ran into a possible bug while porting a Keras IO example to Keras-Core to make it backend-agnostic (keras-team/keras-core#623).

With the following code snippet:

import jax
import jax.numpy as jnp
import numpy as np

def foo(a, b):
    print(f"a.shape {a.shape}")
    print(f"b.shape {b.shape}")
    return jnp.tile(a, b)

@jax.jit
def foo_jit(a, b):
    print(f"a.shape {a.shape}")
    print(f"b.shape {b.shape}")
    return jnp.tile(a, b)

a = np.ndarray([1, 20, 20])
b = np.array([64, 1, 1])
c = foo(a, b)
print(c.shape)

c_jit = foo_jit(a, b)
print(c_jit.shape)

For foo(a, b), I got:

a.shape ((1, 20, 20), <class 'numpy.ndarray'>)
b.shape ((3,), <class 'numpy.ndarray'>)

and the expected output shape `(64, 20, 20)`.

For foo_jit(a, b), I got:

a.shape ((1, 20, 20), Traced<ShapedArray(float32[1,20,20])>with<DynamicJaxprTrace(level=1/0)>)
b.shape ((3,), Traced<ShapedArray(int32[3])>with<DynamicJaxprTrace(level=1/0)>)

and the following error:

Traceback (most recent call last):
  File "/Users/freedom/tf-master/lib/python3.9/site-packages/numpy/core/fromnumeric.py", line 3154, in ndim
    return a.ndim
AttributeError: 'list' object has no attribute 'ndim'

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/tmp/f_jax.py", line 21, in <module>
    c_jit = foo_jit(a, b)
  File "/Users/freedom/tf-master/lib/python3.9/site-packages/jax/_src/traceback_util.py", line 166, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/Users/freedom/tf-master/lib/python3.9/site-packages/jax/_src/pjit.py", line 253, in cache_miss
    outs, out_flat, out_tree, args_flat, jaxpr = _python_pjit_helper(
  File "/Users/freedom/tf-master/lib/python3.9/site-packages/jax/_src/pjit.py", line 161, in _python_pjit_helper
    args_flat, _, params, in_tree, out_tree, _ = infer_params_fn(
  File "/Users/freedom/tf-master/lib/python3.9/site-packages/jax/_src/api.py", line 324, in infer_params
    return pjit.common_infer_params(pjit_info_args, *args, **kwargs)
  File "/Users/freedom/tf-master/lib/python3.9/site-packages/jax/_src/pjit.py", line 491, in common_infer_params
    jaxpr, consts, canonicalized_out_shardings_flat = _pjit_jaxpr(
  File "/Users/freedom/tf-master/lib/python3.9/site-packages/jax/_src/pjit.py", line 969, in _pjit_jaxpr
    jaxpr, final_consts, out_type = _create_pjit_jaxpr(
  File "/Users/freedom/tf-master/lib/python3.9/site-packages/jax/_src/linear_util.py", line 345, in memoized_fun
    ans = call(fun, *args)
  File "/Users/freedom/tf-master/lib/python3.9/site-packages/jax/_src/pjit.py", line 922, in _create_pjit_jaxpr
    jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_dynamic(
  File "/Users/freedom/tf-master/lib/python3.9/site-packages/jax/_src/profiler.py", line 314, in wrapper
    return func(*args, **kwargs)
  File "/Users/freedom/tf-master/lib/python3.9/site-packages/jax/_src/interpreters/partial_eval.py", line 2155, in trace_to_jaxpr_dynamic
    jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(
  File "/Users/freedom/tf-master/lib/python3.9/site-packages/jax/_src/interpreters/partial_eval.py", line 2177, in trace_to_subjaxpr_dynamic
    ans = fun.call_wrapped(*in_tracers_)
  File "/Users/freedom/tf-master/lib/python3.9/site-packages/jax/_src/linear_util.py", line 188, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/tmp/f_jax.py", line 14, in foo_jit
    return jnp.tile(a, b)
  File "/Users/freedom/tf-master/lib/python3.9/site-packages/jax/_src/numpy/lax_numpy.py", line 1813, in tile
    result = broadcast_to(reshape(A, [j for i in A_shape for j in [1, i]]),
  File "/Users/freedom/tf-master/lib/python3.9/site-packages/jax/_src/numpy/lax_numpy.py", line 1191, in broadcast_to
    return util._broadcast_to(array, shape)
  File "/Users/freedom/tf-master/lib/python3.9/site-packages/jax/_src/numpy/util.py", line 388, in _broadcast_to
    if not isinstance(shape, tuple) and np.ndim(shape) == 0:
  File "<__array_function__ internals>", line 180, in ndim
  File "/Users/freedom/tf-master/lib/python3.9/site-packages/numpy/core/fromnumeric.py", line 3156, in ndim
    return asarray(a).ndim
  File "/Users/freedom/tf-master/lib/python3.9/site-packages/jax/_src/core.py", line 611, in __array__
    raise TracerArrayConversionError(self)
jax._src.traceback_util.UnfilteredStackTrace: jax.errors.TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on traced array with shape int32[].
The error occurred while tracing the function foo_jit at /tmp/f_jax.py:10 for jit. This concrete value was not available in Python because it depends on the value of the argument b.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerArrayConversionError

The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/tmp/f_jax.py", line 21, in <module>
    c_jit = foo_jit(a, b)
  File "/tmp/f_jax.py", line 14, in foo_jit
    return jnp.tile(a, b)
  File "/Users/freedom/tf-master/lib/python3.9/site-packages/jax/_src/numpy/lax_numpy.py", line 1813, in tile
    result = broadcast_to(reshape(A, [j for i in A_shape for j in [1, i]]),
  File "/Users/freedom/tf-master/lib/python3.9/site-packages/jax/_src/numpy/lax_numpy.py", line 1191, in broadcast_to
    return util._broadcast_to(array, shape)
  File "/Users/freedom/tf-master/lib/python3.9/site-packages/jax/_src/numpy/util.py", line 388, in _broadcast_to
    if not isinstance(shape, tuple) and np.ndim(shape) == 0:
  File "<__array_function__ internals>", line 180, in ndim
  File "/Users/freedom/tf-master/lib/python3.9/site-packages/numpy/core/fromnumeric.py", line 3156, in ndim
    return asarray(a).ndim
jax.errors.TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on traced array with shape int32[].
The error occurred while tracing the function foo_jit at /tmp/f_jax.py:10 for jit. This concrete value was not available in Python because it depends on the value of the argument b.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerArrayConversionError
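The innermost frames show NumPy's `np.ndim` calling `asarray()`, and therefore `__array__()`, on a traced element of `b`. The following is a minimal sketch that isolates that one step. It assumes a recent JAX 0.4.x release, and the function name `to_numpy` is illustrative. Converting a tracer to a NumPy array fails under `jit`, while the same conversion on a concrete array succeeds.

```python
import jax
import jax.numpy as jnp
import numpy as np

@jax.jit
def to_numpy(b):
    # Under jit, `b` is a tracer: its shape and dtype are known, but its
    # values are not, so NumPy cannot materialize it as an ndarray.
    return np.asarray(b)

try:
    to_numpy(jnp.array([64, 1, 1]))
    raised = False
except jax.errors.TracerArrayConversionError:
    raised = True
print(raised)  # True

# Outside jit the array is concrete, so the same conversion works.
eager = np.asarray(jnp.array([64, 1, 1]))
print(eager.tolist())  # [64, 1, 1]
```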

What jax/jaxlib version are you using?

jax/jaxlib v0.4.14

Which accelerator(s) are you using?

CPU/GPU

Additional system info

Python 3.9/3.10, macOS

NVIDIA GPU info

No response

freedomtan added the bug label Aug 7, 2023
Contributor

clemisch commented Aug 7, 2023

For jax.jit, output shapes must depend only on input shapes, not on input values.

Your foo_jit violates that: the output shape depends on the values in b, not just on its shape.
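One way to see this dependence is `jax.eval_shape`, which computes output shapes from abstract inputs without running the computation. The following is a minimal sketch, and the helper name `tiled_shape` is illustrative. The shape of `a` is held fixed and only the values in the reps tuple change, yet the output shape changes with them. That is exactly what `jit` cannot accommodate when those values are traced.

```python
import jax
import jax.numpy as jnp
import numpy as np

a = np.zeros((1, 20, 20), dtype=np.float32)

def tiled_shape(reps):
    # `reps` is a plain Python tuple, so it stays concrete while the lambda
    # is traced; only `a` is turned into an abstract value.
    return jax.eval_shape(lambda x: jnp.tile(x, reps), a).shape

print(tiled_shape((64, 1, 1)))  # (64, 20, 20)
print(tiled_shape((32, 1, 1)))  # (32, 20, 20)
```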

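The usual way to handle this is to declare such arguments as "static" (via static_argnums or static_argnames). JAX then treats the argument's value as a compile-time constant. It recompiles for each distinct value, and the argument must be hashable: a tuple works, but a NumPy array does not.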

Repro with the fix:

import jax
import jax.numpy as jnp
import numpy as np
from functools import partial

def foo(a, b):
    return jnp.tile(a, b)

@partial(jax.jit, static_argnames="b")   # declare `b` as static
def foo_jit(a, b):
    return jnp.tile(a, b)

a = np.ndarray([1, 20, 20])
b = (64, 1, 1)   # a tuple is hashable
c = foo(a, b)
print(c.shape)
# (64, 20, 20)

c_jit = foo_jit(a, b)
print(c_jit.shape)
# (64, 20, 20)
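Both consequences of declaring `b` static can be checked directly. The following is a minimal sketch, assuming a recent JAX 0.4.x release. A module-level counter, `trace_count`, is an illustrative addition that is not part of the original repro. It is incremented by a Python side effect, which runs only while JAX traces `foo_jit`, so it counts (re)compilations. The last call passes a NumPy array as the static argument. The exact exception type raised for an unhashable static argument may vary by version, so the sketch catches both `ValueError` and `TypeError`.

```python
from functools import partial

import jax
import jax.numpy as jnp
import numpy as np

trace_count = 0  # illustrative: counts how often foo_jit is traced

@partial(jax.jit, static_argnames="b")
def foo_jit(a, b):
    global trace_count
    trace_count += 1  # Python side effect: runs only during tracing
    return jnp.tile(a, b)

a = np.zeros((1, 20, 20), dtype=np.float32)

foo_jit(a, (64, 1, 1))  # first call: traces and compiles
foo_jit(a, (64, 1, 1))  # same static value: served from the cache
foo_jit(a, (32, 1, 1))  # new static value: traces and compiles again
print(trace_count)  # 2

try:
    foo_jit(a, np.array([64, 1, 1]))  # ndarrays are not hashable
    unhashable_rejected = False
except (ValueError, TypeError):
    unhashable_rejected = True
print(unhashable_rejected)  # True
```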

Contributor

clemisch commented Aug 7, 2023

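FYI, Python's print does not behave as usual inside a jit'ed function: it runs only while the function is being traced, and it shows tracers rather than runtime values. See the docs about jax.debug.print.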
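The following is a minimal sketch contrasting the two, and the function name `f` is illustrative. On the first call, the plain `print` shows a tracer and `jax.debug.print` shows the actual values. On the second call, the inputs have the same shape and dtype, so the compiled version is reused and only `jax.debug.print` fires.

```python
import jax
import jax.numpy as jnp

@jax.jit
def f(x):
    # Runs only while JAX traces f: prints the tracer, not the numbers.
    print("print:", x)
    # Staged into the compiled computation: prints runtime values every call.
    jax.debug.print("debug.print: {}", x)
    return x * 2

f(jnp.arange(3))       # both lines print; the plain print shows a tracer
f(jnp.arange(3) + 10)  # cache hit: only the jax.debug.print line appears
```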

Collaborator

jakevdp commented Aug 7, 2023

Thanks for answering the question! @clemisch's answer covers all the relevant pieces: this is working as expected.

> FYI, printing from a jit'ed function does not work as usual. See the docs about jax.debug.print.

Keep in mind that jax.debug.print is only necessary for printing traced runtime values such as array contents. Static values such as array shapes and dtypes can be printed at trace time with a standard Python print.
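Because shapes are static at trace time, they can also be used to build the reps argument of `jnp.tile` inside a jitted function without marking anything static. The following is a minimal sketch that assumes the repeat count can be derived from the leading dimension of another input. The names `tile_to_batch`, `template`, and `x` are illustrative.

```python
import jax
import jax.numpy as jnp

@jax.jit
def tile_to_batch(template, x):
    # Shapes and dtypes are ordinary Python values during tracing, so a
    # plain print works here and x.shape[0] is a concrete int.
    print("trace-time shapes:", template.shape, x.shape, x.dtype)
    return jnp.tile(template, (x.shape[0], 1, 1))

template = jnp.zeros((1, 20, 20))
x = jnp.ones((64, 8))
out = tile_to_batch(template, x)
print(out.shape)  # (64, 20, 20)
```

This sidesteps the original error because jnp.tile receives a tuple of Python ints instead of a traced array. It only applies when the repeat counts really are functions of input shapes.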

I'm going to close this, because it's working as expected. Feel free to comment here or open another issue if you have additional questions!

jakevdp self-assigned this Aug 7, 2023
jakevdp closed this as completed Aug 7, 2023