-
Notifications
You must be signed in to change notification settings - Fork 316
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
(Closes #35, updates #52) Changes for Jax-Metal #57
base: main
Are you sure you want to change the base?
Conversation
The new implementation of top_k is enabled with a flag. It should check the current device and use jax.lax.top_k if not running on jax-metal. This is not a general-purpose implementation of top_k but hopefully sufficient for sampling.
…ers code back for gpu/cpu; added jax-metal library to poetry"
Added back the scaled_rope that was accidentally deleted by me before. Tested with jax-metal and with jax-cpu. @nix sure, we have to wait for them. My hunch is that he will want to run the original code on "his" devices (read: TPU and jax-gpu) and we should try to come as close to that as possible but not force him to run watered down code. But now @Arrabonae has the choice. I will try to integrate the Frog branch next. |
Merging frog into main as his ideas have proven to be valid
Integrated the Frog branch and ran main.py on jax-metal as well as jax-cpu. Our lass is happy. Can someone validate on his rig, preferably on a jax-gpu, too? Note that I did not test the eval_main.py yet due to the README update by xjdr. |
Frog entropix/main.py successfully tested under
|
Based on #52
Added conditionals based on jax.extended.backend.get_backend(). Added jax-metal to toml. Verified that it is running on both Metal and CPU.