
[Work in progress] Add FP8 support in fwd_prefill #115

Draft
wants to merge 61 commits into main_perf

Conversation

brunomazzottiamd

No description provided.

alexkranias-amd and others added 30 commits December 9, 2024 10:09
feat: added fp32 output to input_helper

passing

feat: fp8 tests. small amount of error

added fp8e5m2 type

note: RuntimeError: "abs_cuda" not implemented for 'Float8_e4m3fnuz'
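
The usual workaround for this error is to upcast the fp8 tensor before any elementwise or reduction op that lacks an fp8 kernel. A minimal hedged sketch (not code from this PR; assumes a GPU is available):

```python
import torch

x_fp8 = torch.randn(4, 8, device="cuda").to(torch.float8_e4m3fnuz)

# x_fp8.abs() raises: RuntimeError: "abs_cuda" not implemented for 'Float8_e4m3fnuz'
max_abs = x_fp8.to(torch.float32).abs().amax()  # upcast first, then take abs/max
```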

enabled fp8 GEMMs

fix: error down to < 0.1

added another fp8 dtype

best accuracy is with no scaling

improved accuracy to within < 0.02; issue related to torch-side casting

fix: passes if we allow v to be fp16 instead of fp8. otherwise we have error < 0.1

all error is < 0.07

feat: added per head scaling tensors
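
For context, a minimal sketch of what per-head scaling usually looks like on the PyTorch side (assumed shapes and helper name, not the PR's actual code):

```python
import torch

FP8_DTYPE = torch.float8_e4m3fnuz
FP8_MAX = torch.finfo(FP8_DTYPE).max

def quantize_per_head(q: torch.Tensor):
    """q: (Z, H, N_CTX, D_HEAD) -> fp8 tensor plus one fp32 descale factor per (Z, H)."""
    max_abs = q.abs().amax(dim=(-2, -1))   # reduce over the seq and head-dim axes
    q_scale = max_abs / FP8_MAX            # descale factor, kept in fp32
    # note: max_abs == 0 is handled by a later commit in this PR
    q_fp8 = (q / q_scale[..., None, None]).to(FP8_DTYPE)
    return q_fp8, q_scale
```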

progress towards implementing scaling tensors in kernel

save

issue: error caused by acc += tl.dot(p.to(v.type.element_ty), v)
Error:
UnboundLocalError: local variable 'q_scale_stride_z' referenced before
assignment.

Fix:
Initialize 'q_scale_stride_z' and 'kv_scale_stride_z' before the conditional
assignment (see the sketch after the failing test output below).
Warning: I don't know if this is the correct thing to do.
Warning - 2 test cases are failing due to this change:
AssertionError: Tensor-likes are not close!

FAILED test.py::test_op_prefill_fwd_impl[False-dtype1-True-bshd-0.0-False-4-6-6-1024-1023-32]
Mismatched elements: 1 / 786432 (0.0%)
Greatest absolute difference: 0.14855387806892395 at index (0, 309, 2, 18) (up to 0.1009 allowed)
Greatest relative difference: 0.28865116834640503 at index (0, 309, 2, 18) (up to 0.09128 allowed)

FAILED test.py::test_op_prefill_fwd_impl[False-dtype1-False-bshd-0.0-False-4-6-6-1024-1023-32]
Mismatched elements: 1 / 786432 (0.0%)
Greatest absolute difference: 0.14855387806892395 at index (0, 309, 2, 18) (up to 0.1009 allowed)
Greatest relative difference: 0.28865116834640503 at index (0, 309, 2, 18) (up to 0.09128 allowed)
Two tests are still failing.
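
For reference, the kind of pre-initialization meant by the fix above, as a hedged sketch (variable names follow the commit message; the helper name and flag are assumptions, not the PR's actual code):

```python
import torch

def scale_batch_strides(q_scale: torch.Tensor, kv_scale: torch.Tensor,
                        scales_have_batch_dim: bool):
    # Define both strides up front so every branch leaves them bound; assigning
    # them only inside the if-branch is what triggered the UnboundLocalError.
    q_scale_stride_z = 0        # safe default when the scales carry no batch axis
    kv_scale_stride_z = 0
    if scales_have_batch_dim:   # assumed flag, not the PR's actual condition
        q_scale_stride_z = q_scale.stride(0)
        kv_scale_stride_z = kv_scale.stride(0)
    return q_scale_stride_z, kv_scale_stride_z
```
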
* Do not track gradients for scale factors.
* Handle a maximum absolute value equal to zero in the per batch / head
  scaling method.
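
A hedged sketch of those two points together (assumed helper name and shapes, not the PR's code):

```python
import torch

FP8_MAX = torch.finfo(torch.float8_e4m3fnuz).max

@torch.no_grad()                       # scale factors do not track gradients
def per_head_descale(x: torch.Tensor) -> torch.Tensor:
    """x: (Z, H, N_CTX, D_HEAD); returns one fp32 descale factor per (Z, H)."""
    max_abs = x.to(torch.float32).abs().amax(dim=(-2, -1))
    # an all-zero batch/head slice would otherwise produce a zero scale and a
    # division by zero at quantization time; fall back to 1.0 in that case
    max_abs = torch.where(max_abs == 0, torch.ones_like(max_abs), max_abs)
    return max_abs / FP8_MAX
```
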
Add support for the "thd" varlen layout.
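
For readers unfamiliar with it, the "thd" varlen layout packs every sequence in the batch along a single total-token axis and delimits them with cumulative sequence lengths (the standard flash-attention varlen convention; exact argument names in this repo may differ):

```python
import torch

seqlens = [5, 3]                                 # two sequences of different lengths
total_tokens, n_heads, d_head = sum(seqlens), 8, 64

q = torch.randn(total_tokens, n_heads, d_head)   # (T, H, D) -> "thd"
cu_seqlens_q = torch.tensor([0, 5, 8], dtype=torch.int32)  # prefix sums over seqlens
# sequence i occupies rows cu_seqlens_q[i] : cu_seqlens_q[i + 1] of q
```
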
The intent is to make code review easier.
This commit also reduces whitespace changes to facilitate code review.
This commit reduces code review overhead and keeps things as they are before the
introduction of fp8.
This is a temporary commit and it should be reverted before merging.
This data will be studied and then deleted from commit history.
The error tolerance is greater than the one used when comparing the fp8 Triton
kernel with the fp8 PyTorch reference implementation.
@@ -0,0 +1,143 @@
# Install the newest triton version with
Collaborator

Let us remove the benchmarking code for now. Let us just deal with functionality and minimize the diff to main_perf.

Author

Sure, I'll remove this file.

Author

Resolved by c0dd573.

@@ -0,0 +1,545 @@
Z,HQ,HK,N_CTX_Q,N_CTX_K,D_HEAD,causal,dropout_p,layout,use_exp2,scale_per_head,mismatched_elems,total_elems,mismatched_percentage,greatest_abs_diff,greatest_rel_diff
Collaborator

We should not be committing the csv files, right?

Author

My intention is not to commit this file in any way. I just used GitHub as a shortcut to get data out of the Citrix environment. I'll remove this and the other related shell script file.

Author

Resolved by 0494786.
