-
Notifications
You must be signed in to change notification settings - Fork 486
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add AMD CK FP8 Kernels to Llama Dispatch
Summary: This diff replaces the slow `torch._scaled_mm` implementation with faster CK kernels added in D56963018. I also extend AMD support to rowwise fp8 quantization. Previously we didn't have an accelerated kernel to quantize tensors to fp8 on AMD, I updated the Triton functions to support AMD and use them for rowwise quantization and get excellent performance. Before this change we have: ``` INFO:root:BF16 T: 5966.47us, FLOPS: 483.74TF/s INFO:root:BF16 (G) T: 6017.79us, FLOPS: 479.61TF/s INFO:root:FP8D AMD T: 32761.27us, FLOPS: 88.10TF/s INFO:root:FP8D AMD (G) T: 33014.49us, FLOPS: 87.42TF/s ``` After we have: ``` INFO:root:BF16 T: 6006.43us, FLOPS: 480.52TF/s INFO:root:BF16 (G) T: 6045.48us, FLOPS: 477.42TF/s INFO:root:FP8D AMD CK T: 5894.74us, FLOPS: 489.63TF/s INFO:root:FP8D AMD CK (G) T: 5870.20us, FLOPS: 491.67TF/s INFO:root:FP8D rowwise AMD CK T: 3877.73us, FLOPS: 744.31TF/s INFO:root:FP8D rowwise AMD CK (G) T: 3892.07us, FLOPS: 741.56TF/s ``` When using LLama3 shapes the performance is: ``` INFO:root:BF16 T: 90416.92us, FLOPS: 474.26TF/s INFO:root:BF16 (G) T: 91022.73us, FLOPS: 471.10TF/s INFO:root:FP8D AMD CK T: 65453.99us, FLOPS: 655.13TF/s INFO:root:FP8D AMD CK (G) T: 63632.00us, FLOPS: 673.89TF/s INFO:root:FP8D rowwise AMD CK T: 60791.51us, FLOPS: 705.38TF/s INFO:root:FP8D rowwise AMD CK (G) T: 60916.44us, FLOPS: 703.93TF/s ``` Reviewed By: jianyuh, jiawenliu64 Differential Revision: D57739339
- Loading branch information
1 parent
ab05ca9
commit 262d4b9
Showing
2 changed files
with
27 additions
and
18 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters