Clarification regarding GPU performance #23630
-
Thanks for the report! We do document this here: https://jax.readthedocs.io/en/latest/xla_flags.html#configuring-xla-in-jax
But we could probably make this information more prominent. Do you have suggestions for what we could change so that you would have seen this warning? That said, I don't think time spent experimenting with them is necessarily wasted. For a particular pinned XLA version the flags won't change, and finding the best configuration for your particular model can be a useful way to optimize performance.
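As a rough illustration of the mechanism that page describes (flags are passed to XLA through the `XLA_FLAGS` environment variable, which has to be set before JAX initializes its backend), here is a minimal sketch. The helper `set_xla_flags` is hypothetical, not a JAX API, and the flag names are only examples: check them against the XLA version you have pinned, since names and defaults can change between releases.

```python
import os
import shlex


def set_xla_flags(flags: dict) -> str:
    """Merge `flags` into the XLA_FLAGS environment variable.

    Hypothetical helper: call it before `import jax`, because the value is
    normally read only when the backend is initialized.
    """
    existing = shlex.split(os.environ.get("XLA_FLAGS", ""))
    overridden = {f"--{name}" for name in flags}
    # Keep previously set flags unless we are overriding the same flag name.
    kept = [f for f in existing if f.split("=", 1)[0] not in overridden]
    # XLA boolean flags are conventionally spelled true/false.
    new = [
        f"--{name}={str(value).lower() if isinstance(value, bool) else value}"
        for name, value in flags.items()
    ]
    value = " ".join(kept + new)
    os.environ["XLA_FLAGS"] = value
    return value


# Example flag from the JAX GPU performance tips; verify it for your XLA version.
set_xla_flags({"xla_gpu_enable_latency_hiding_scheduler": True})
# import jax  # import only after XLA_FLAGS has been set
```

Keeping any flags that were already set in the environment makes it easier to layer a per-experiment change on top of a baseline configuration.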
-
Thanks so much for flagging this. Your points are right, and we need to improve this. We're looking into it.
-
We have full-fledged documentation on improving performance on GPUs. I have been working on a project with JAX and Equinox in which I tried these flags. None of them improved performance, and some of them led to segmentation faults. So I first opened an issue in the JAX repo, and subsequently one in the openXLA repo, where I was told not to rely on any of these XLA flags, as these APIs are neither stable nor safe.
Given that I wasted a lot of time experimenting with these, it would help to have a clarification from the JAX devs on whether or not to use them. If they are not reliable, we should add a note to the documentation (I'm happy to do that).
cc: @mattjj @jakevdp
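When experimenting with flags anyway, one way to keep a segfault from taking down the whole session is to run each configuration in its own interpreter. The sketch below assumes a POSIX system and a benchmark you supply as a Python source string; `run_with_flags` is a hypothetical helper, not part of JAX or XLA. A fresh process per configuration also matters because `XLA_FLAGS` is generally read once when the backend starts, so changing it later in the same process typically has no effect.

```python
import os
import subprocess
import sys


def run_with_flags(script: str, xla_flags: str, timeout: float = 600.0) -> dict:
    """Run `script` in a fresh Python process with XLA_FLAGS set.

    Hypothetical helper: a crash (e.g. SIGSEGV) only kills the child process,
    so a sweep over many flag combinations can keep going.
    """
    env = dict(os.environ, XLA_FLAGS=xla_flags)
    proc = subprocess.run(
        [sys.executable, "-c", script],
        env=env,
        capture_output=True,
        text=True,
        timeout=timeout,
    )
    return {
        "flags": xla_flags,
        "returncode": proc.returncode,
        "stdout": proc.stdout.strip(),
        # On POSIX, a negative return code -N means the child died from signal N.
        "crashed": proc.returncode < 0,
    }
```

For example, you could read a benchmark script of your own (say, a hypothetical `bench.py` that imports JAX, times a jitted step, and prints the result) and call `run_with_flags` once per candidate flag string, recording which configurations crash rather than losing the whole run.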