Hi, I was reading the following snippet in the documentation, https://github.com/google/jax/blob/0b87bf48f97ace10c7aee19c8f980788891a2df7/docs/sharded-computation.md?plain=1#L243-L262, and saw that the example function was wrapped in `jax.jit` …
Hi - thanks for the question! Functions wrapped in `jax.jit` only operate on JAX arrays, not on NumPy arrays. But for historical reasons, JIT-compiled functions accept NumPy arrays as inputs and silently convert them to JAX arrays.
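Here is a minimal sketch of both points. It assumes a recent JAX release where `jax.Array` is the public array type, and the function names `scale` and `scale_with_numpy` are just illustrative. Passing a NumPy array into a jitted function works because JAX converts it on the way in. Inside the function, though, the argument is a traced JAX value, so calling a NumPy function such as `np.sin` on it raises `jax.errors.TracerArrayConversionError`.

```python
import numpy as np
import jax
import jax.numpy as jnp


@jax.jit
def scale(x):
    # Inside jit, x is a traced JAX value, so use jax.numpy operations on it.
    return jnp.sin(x) * 2


x_np = np.arange(4.0)  # a plain NumPy array
out = scale(x_np)      # accepted: JAX converts the NumPy input to a JAX array

print(type(x_np).__name__)          # the input is still a NumPy ndarray
print(isinstance(out, jax.Array))   # the output is a JAX array

# Converting explicitly up front makes the NumPy -> JAX step visible.
x_jax = jnp.asarray(x_np)
out_explicit = scale(x_jax)


@jax.jit
def scale_with_numpy(x):
    # np.sin tries to turn the traced value into a NumPy array, which fails.
    return np.sin(x) * 2


try:
    scale_with_numpy(x_np)
except jax.errors.TracerArrayConversionError:
    print("np.sin cannot be applied to a traced value inside jax.jit")
```

Calling `jnp.asarray` (or `jax.device_put`) yourself is optional, but it makes the conversion explicit instead of relying on the implicit one.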