
jax.jit(foo, device=dev) vs jax.device_put and interpreted vs JIT-ed #22963

Answered by yashk2810
erick-xanadu asked this question in Q&A

> `jax.jit(foo, device=dev)(x)` vs `x = jax.device_put(x, dev); jax.jit(foo)(x)`

Yes, these should be equivalent, but only in the single-device case. If `x` is sharded, they won't be equivalent. Also, the `device=` argument of `jax.jit` is deprecated, so you should use the latter form.
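A minimal sketch of the recommended single-device pattern: commit the input with `jax.device_put`, then call the jitted function without any `device=` argument. The function body of `foo` and the choice of `jax.devices()[0]` as `dev` are assumptions for illustration only.

```python
import jax
import jax.numpy as jnp

def foo(x):
    # Hypothetical stand-in for `foo` from the question.
    return jnp.sin(x) * 2.0

# Assumed target device: the first visible device (the only one on a CPU-only machine).
dev = jax.devices()[0]

x = jnp.arange(4.0)

# Place the input explicitly, then jit without `device=`.
x_on_dev = jax.device_put(x, dev)
y = jax.jit(foo)(x_on_dev)

# The result is produced on the device the input was committed to.
print(y.devices())
```

This avoids the deprecated argument entirely; the placement decision lives with the data rather than with the `jax.jit` call.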

> I can see this working by `jax.jit` just determining the target by looking at the current device placement of all inputs.

Yes, that's correct: JAX uses "computation follows data" semantics, so a jitted computation runs on the device(s) its inputs are committed to.
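To see "computation follows data" in action, the sketch below commits an input to each visible device in turn and checks where the jitted output lands. The function `double` is a hypothetical example; on a CPU-only machine the loop runs once.

```python
import jax
import jax.numpy as jnp

@jax.jit
def double(x):
    # Illustrative function; nothing here names a target device.
    return x * 2

placements = []
# Loop over every visible device (just one on a CPU-only machine).
for d in jax.devices():
    x = jax.device_put(jnp.ones(3), d)  # commit the input to device d
    y = double(x)                        # jit infers placement from x
    placements.append((d, y.devices()))

for d, out_devices in placements:
    print(d, "->", out_devices)
```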

> Does this mean that the Python line `x + x` will generate an MLIR

`+` is jitted internally, so it will be executed on the default device of the default backend that is present.
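A small sketch, assuming an uncommitted input (created without `jax.device_put`), that shows `x + x` running on the default device and then prints the MLIR (StableHLO) for an equivalent explicitly jitted function. The `lambda a: a + a` wrapper is an illustration; its lowered module is not necessarily identical to what JAX builds internally when dispatching `+`.

```python
import jax
import jax.numpy as jnp

# Uncommitted array: no explicit device placement.
x = jnp.arange(3.0)

# Dispatched to the default device of the default backend.
y = x + x

print(jax.default_backend())  # backend name, e.g. "cpu"
print(y.devices())

# Lower an explicitly jitted equivalent to inspect its MLIR text.
lowered = jax.jit(lambda a: a + a).lower(x)
text = lowered.as_text()
print(text)
```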

Answer selected by erick-xanadu