-
Notifications
You must be signed in to change notification settings - Fork 49
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Work in progress] Add FP8 support in fwd_prefill #115
base: main_perf
Are you sure you want to change the base?
Conversation
feat: added fp32 output to input_helper passing feat: fp8 tests. small amount of error added fp8e5m2 type note: RuntimeError: "abs_cuda" not implemented for 'Float8_e4m3fnuz' enabled fp8 GEMMs fix: error down to < 0.1 added another fp8 dtype best accuracy is with no scaling improved accuracy to within < 0.02. issue related to torch side casting fix: passes if we allow v to be fp16 instead of fp8. otherwise we have error < 0.1 all error is < 0.07 feat: added per head scaling tensors progress towards implementing scaling tensors in kernel save issue: error caused by acc += tl.dot(p.to(v.type.element_ty), v)
Error: UnboundLocalError: local variable 'q_scale_stride_z' referenced before assignment. Fix: Initialize 'q_scale_stride_z' and 'kv_scale_stride_z' before assignment.
Warning: I don't know if this is the correct thing to do.
Warning - 2 test cases are failing due to this change: AssertionError: Tensor-likes are not close! FAILED test.py::test_op_prefill_fwd_impl[False-dtype1-True-bshd-0.0-False-4-6-6-1024-1023-32] Mismatched elements: 1 / 786432 (0.0%) Greatest absolute difference: 0.14855387806892395 at index (0, 309, 2, 18) (up to 0.1009 allowed) Greatest relative difference: 0.28865116834640503 at index (0, 309, 2, 18) (up to 0.09128 allowed) FAILED test.py::test_op_prefill_fwd_impl[False-dtype1-False-bshd-0.0-False-4-6-6-1024-1023-32] Mismatched elements: 1 / 786432 (0.0%) Greatest absolute difference: 0.14855387806892395 at index (0, 309, 2, 18) (up to 0.1009 allowed) Greatest relative difference: 0.28865116834640503 at index (0, 309, 2, 18) (up to 0.09128 allowed)
Two tests are still failling.
* Do not track gradients for scale factors. * Handle maximum absolute value equals to zero in per batch / head scaling method.
q and k were just converted to fp32 5 lines before.
The intention is to document fp8 module to other devs.
Now the function accepts multiple tensors as input arguments.
Warning: * "thd" varlen layout is still not supporded. * "bshd" and "bhsd" layouts only work when HQ == HK.
Add support to MQA and GQA.
Add support to "thd" varlen layout.
The intent is to make code review easier.
This commit also reduce whitespace changes to facilitate code review.
This commit reduces code review overhead and keeps things as they are before the introduction of fp8.
This is a temporary commit and it shoud be reverted before merging. This data will be studied and then deleted from commit history.
Error tolerance is greater than when comparing fp8 Triton with fp8 PyTorch reference implementation.
benchmarks/benchmark_fp8.py
Outdated
@@ -0,0 +1,143 @@ | |||
# Install the newest triton version with |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Let us remove the benchmarking code for now. Let us just deal with functionality and minimize the diff to main_perf
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sure, I'll remove this file.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Resolved by c0dd573.
@@ -0,0 +1,545 @@ | |||
Z,HQ,HK,N_CTX_Q,N_CTX_K,D_HEAD,causal,dropout_p,layout,use_exp2,scale_per_head,mismatched_elems,total_elems,mismatched_percentage,greatest_abs_diff,greatest_rel_diff |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We should not be commiting the csv files right?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
My intention is not to commit this file in any way. I just used GitHub as a shortcut to get data out of Citrix environment. I'll remove this and the other related shell script file.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Resolved by 0494786.
This reverts commit 9de6785.
No description provided.