-
Hello, I am trying to understand how to verify whether the JAX compiler performs updates in place.

Some background: my program slows down significantly (by orders of magnitude) when I double the size of certain arrays used as buffers. This leads me to believe that writes to these buffers might not be happening in place; if they were, the buffer size shouldn't affect performance this much.

I am considering rewriting my code to ensure in-place updates, but I need guidance on how to check whether an array is being updated in place or copied. Are there specific operations or indicators in the HLO/jaxpr representations that I should look for? Any pointers or advice would be greatly appreciated. Thank you!
-
There's no direct way to check for this, but you can get an idea of what the compiler is doing with your code by printing the optimized HLO. For example:

    import jax

    def f(x):
        # Functional update; XLA may or may not perform it in place.
        x = x.at[0].set(1)
        return x.sum()

    x = jax.numpy.arange(10)
    # Print the optimized HLO of the compiled function.
    print(jax.jit(f).lower(x).compile().as_text())
Hope that helps!
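Reading the full HLO dump by eye gets tedious for larger functions, so here is a rough sketch of how the text could be scanned for a few operations of interest. The helper names `count_hlo_ops` and `optimized_hlo_counts` are made up for illustration, and the list of opcodes is only a guess at what is worth watching: an explicit `copy` instruction often hints that a buffer was duplicated, but fused copies will not show up this way, and the HLO text format is not a stable interface.

```python
import re
from collections import Counter

# Opcode names to look for; this list is an assumption, not an official set.
OPS_OF_INTEREST = ("copy", "dynamic-update-slice", "scatter")


def count_hlo_ops(hlo_text, ops=OPS_OF_INTEREST):
    """Count how often each opcode in `ops` appears as an instruction in HLO text."""
    # An instruction looks roughly like `name = type opcode(operands)`,
    # so match the opcode preceded by whitespace and followed by "(".
    pattern = re.compile(r"\s(" + "|".join(re.escape(op) for op in ops) + r")\(")
    return Counter(m.group(1) for m in pattern.finditer(hlo_text))


def optimized_hlo_counts(fn, *args):
    """Compile `fn` for `args` with jax.jit and count ops in the optimized HLO."""
    import jax  # imported lazily so count_hlo_ops works without JAX installed

    hlo_text = jax.jit(fn).lower(*args).compile().as_text()
    return count_hlo_ops(hlo_text)
```

Calling `optimized_hlo_counts(f, x)` with the `f` and `x` from the example above returns a `Counter`. Comparing the counts between two variants of the same function may show whether a rewrite removed a copy, though the result depends on the backend and JAX version, so treat it as a hint and not as proof.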