Feature request: Extend the remove broadcast + squeeze pass #3635
The Thunder definition has the cast ops explicit in the trace. These are currently not cancelled out because they are separated by the squeeze op. We should be able to extend the consecutive cast pass to handle that: https://github.com/NVIDIA/Fuser/blob/main/csrc/preseg_passes/consecutive_cast.cpp
Naoya also mentioned that the broadcast and squeeze in this pattern could cancel each other out: https://github.com/NVIDIA/Fuser/blob/main/csrc/preseg_passes/remove_bcast_squeeze.cpp
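To make the broadcast/squeeze cancellation condition concrete, here is a toy shape-level model. The helpers `broadcast_shape`, `squeeze_shape` and `cancels` are hypothetical illustrations, not nvFuser's API, and the flag-equality rule is only a conservative sufficient condition: the squeeze must remove exactly the size-1 dims the broadcast added.

```python
# Hypothetical toy model: a shape is a tuple of ints and flags is a tuple of bools.
# broadcast inserts a size-1 dim at every output position whose flag is True;
# squeeze drops every dim whose flag is True (each such dim must have size 1).

def broadcast_shape(shape, flags):
    # Every False flag must be filled by exactly one input dim.
    if len(flags) - sum(flags) != len(shape):
        raise ValueError("flags must leave one slot per input dim")
    it = iter(shape)
    return tuple(1 if f else next(it) for f in flags)

def squeeze_shape(shape, flags):
    if len(flags) != len(shape):
        raise ValueError("need one flag per dim")
    if any(f and d != 1 for d, f in zip(shape, flags)):
        raise ValueError("can only squeeze size-1 dims")
    return tuple(d for d, f in zip(shape, flags) if not f)

def cancels(bcast_flags, squeeze_flags):
    # Conservative rule: squeeze(broadcast(x, b), s) is x when s == b.
    return tuple(bcast_flags) == tuple(squeeze_flags)

flags = (False, True, False)
y = broadcast_shape((4, 8), flags)    # (4, 1, 8)
assert squeeze_shape(y, flags) == (4, 8) and cancels(flags, flags)

# Squeezing a *different* size-1 dim does not give back the input shape.
z = broadcast_shape((4, 1), (True, False, False))   # (1, 4, 1)
print(squeeze_shape(z, (False, False, True)))       # (1, 4), not (4, 1)
```

The second case is why a pass has to track which dims the broadcast introduced rather than just matching ranks.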
Since broadcast and squeeze don't directly affect the values computed in the fusion, I think they'll commute with most ops, and we should be able to move all of the broadcasts and squeezes toward the inputs or outputs in a pass that runs before we combine bcast+squeeze. That way they'd be adjacent, and we could remove them and the consecutive casts as normal afterward.
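A minimal sketch of this reordering idea on a hypothetical toy IR, where a fusion is a linear chain of ops applied left to right. The op names, the `POINTWISE` set and both functions are assumptions for illustration only. Squeezes are moved toward the inputs past unary pointwise ops (valid because `squeeze(cast(x)) == cast(squeeze(x))`), after which adjacent broadcast/squeeze pairs with identical flags are dropped.

```python
# Hypothetical toy IR: each op is a tuple (kind, arg); the chain runs left to right.
# Only unary pointwise ops are treated as commuting with squeeze in this sketch.
POINTWISE = {"cast", "neg", "exp"}

def hoist_squeezes(chain):
    """Move every squeeze toward the inputs past unary pointwise ops."""
    chain = list(chain)
    changed = True
    while changed:  # terminates: squeezes only ever move left
        changed = False
        for i in range(1, len(chain)):
            if chain[i][0] == "squeeze" and chain[i - 1][0] in POINTWISE:
                chain[i - 1], chain[i] = chain[i], chain[i - 1]
                changed = True
    return chain

def cancel_bcast_squeeze(chain):
    """Drop adjacent broadcast -> squeeze pairs that use identical flags."""
    out = []
    for op in chain:
        if op[0] == "squeeze" and out and out[-1][0] == "bcast" and out[-1][1] == op[1]:
            out.pop()  # the pair is a no-op; nested pairs also unwind
        else:
            out.append(op)
    return out

chain = [
    ("bcast", (False, True)),
    ("cast", "float32"),
    ("squeeze", (False, True)),
    ("cast", "bfloat16"),
]
print(cancel_bcast_squeeze(hoist_squeezes(chain)))
# [('cast', 'float32'), ('cast', 'bfloat16')]
```

After the pair is removed, the two casts are adjacent, which is the shape a consecutive cast pass can then fold. Binary pointwise ops and reductions need more care (every input must change shape consistently, and reductions change which dims exist), so a real pass would have to restrict which ops it moves across.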
An orthogonal note.
We'll need the input to the second kernel
I think we can do the same for the consecutive cast pass as well, i.e., casts should be able to move across meta operations.
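A rough sketch of folding casts across meta operations, again on a hypothetical toy IR. It assumes the chain contains only `bcast`, `squeeze` and `cast` ops, and it assumes the `LOSSLESS` table below lists float conversions that are exact. The folding rule used here (`a -> b -> c` equals `a -> c` when `a -> b` is exact) is an illustrative choice, not a description of what `consecutive_cast.cpp` does.

```python
# Assumption: casting src -> dst is exact (no rounding) for these float pairs.
LOSSLESS = {
    ("bfloat16", "float32"), ("float16", "float32"),
    ("bfloat16", "float64"), ("float16", "float64"),
    ("float32", "float64"),
}
META = {"bcast", "squeeze"}

def fold_casts(input_dtype, chain):
    """Fold a chain of casts, treating bcast/squeeze as transparent."""
    if any(op[0] not in META | {"cast"} for op in chain):
        raise ValueError("sketch only handles meta ops and casts")
    # Casts commute with bcast/squeeze, so keep the shape ops in order up front.
    meta = [op for op in chain if op[0] in META]
    path = [input_dtype]  # dtypes the value passes through after folding
    for op in chain:
        if op[0] != "cast":
            continue
        target = op[1]
        if target == path[-1]:
            continue  # cast to the current dtype is a no-op
        if len(path) >= 2 and (path[-2], path[-1]) in LOSSLESS:
            path[-1] = target  # a -> b -> t behaves like a -> t when a -> b is exact
            if path[-1] == path[-2]:
                path.pop()  # a -> a is a no-op
        else:
            path.append(target)
    return meta + [("cast", d) for d in path[1:]]

# Upcast, squeeze, downcast: only the squeeze survives.
print(fold_casts("bfloat16", [("cast", "float32"), ("squeeze", (False, True)), ("cast", "bfloat16")]))
# [('squeeze', (False, True))]

# A lossy round trip (float32 -> bfloat16 -> float32) must be kept.
print(fold_casts("float32", [("cast", "bfloat16"), ("cast", "float32")]))
# [('cast', 'bfloat16'), ('cast', 'float32')]
```

The upcast-then-downcast case matches the bf16 pattern one would expect around a Thunder trace; whether a real pass may also fold other chains depends on its precision rules.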
Here's a pattern in the Mistral RoPE backward function:
This is currently segmented into two segments, one reduction and one pointwise.
It seems the second segment should consist only of meta operations, but it's probably not detected as such because of the type cast ops. I think it should be safe to ignore the type cast ops and remove the broadcast and squeeze ops. With that, this segment would be just a no-op segment.
Note that while this is part of the backward function of the Mistral RoPE, the perf impact is likely small, since it's only a small part of the overall fusion, as shown below. The above section corresponds to the upper-right vertical sequence from T3 to T89.
mistral_bwd.pdf