Change from using scale to using amax and dividing within kernel #5
Open
drisspg wants to merge 5 commits into main from use-amax-instead-of-scale
Conversation
```
…d_casting.py
100%|████████████████████████████████| 22/22 [01:14<00:00, 3.40s/it]

  num_rows    num_cols  high_precision_dtype    low_precision_dtype      cuda_time    pytorch_time    compiled_pytorch_time
----------  ----------  ----------------------  ---------------------  -----------  --------------  -----------------------
       512         512  torch.bfloat16          torch.float8_e4m3fn        5.46512        72.2336                  59.9502
       512         512  torch.bfloat16          torch.float8_e5m2          5.50805        72.5867                  59.5001
      1024        1024  torch.bfloat16          torch.float8_e4m3fn        5.63875        72.5901                  64.8824
      1024        1024  torch.bfloat16          torch.float8_e5m2          5.61068        73.4559                  64.987
      2048        2048  torch.bfloat16          torch.float8_e4m3fn        6.92928        73.2639                  21.8212
      2048        2048  torch.bfloat16          torch.float8_e5m2          6.92831        73.238                   21.7541
      1024        8192  torch.bfloat16          torch.float8_e4m3fn       11.6602       125.987                    21.8744
      1024        8192  torch.bfloat16          torch.float8_e5m2         11.662        126.608                    21.9473
      8192        1280  torch.bfloat16          torch.float8_e4m3fn       13.9878       160.16                     21.9679
      8192        1280  torch.bfloat16          torch.float8_e5m2         13.9841       160.584                    22.0702
      8192        7168  torch.bfloat16          torch.float8_e4m3fn       81.9416       936.225                    85.1023
      8192        7168  torch.bfloat16          torch.float8_e5m2         82.5054       938.797                    85.1112
      3584        8192  torch.bfloat16          torch.float8_e4m3fn       43.5017       486.14                     44.7887
      3584        8192  torch.bfloat16          torch.float8_e5m2         43.4985       486.918                    44.8074
      2048      109760  torch.bfloat16          torch.float8_e4m3fn      301.522      3493.18                     317.802
      2048      109760  torch.bfloat16          torch.float8_e5m2        300.25       3495.29                     319.173
         1        3232  torch.bfloat16          torch.float8_e4m3fn        5.54213        71.9274                  64.7877
         1        3232  torch.bfloat16          torch.float8_e5m2          5.62434        72.1416                  64.597
      2048           1  torch.bfloat16          torch.float8_e4m3fn        5.59367        71.5838                  63.776
      2048           1  torch.bfloat16          torch.float8_e5m2          5.59475        71.9973                  63.966
     14144        2048  torch.bfloat16          torch.float8_e4m3fn       42.9653       479.963                    44.3175
     14144        2048  torch.bfloat16          torch.float8_e5m2         42.9723       480.981                    44.2864
```
```
❯ python benchmarks/benchmark_saturated_casting.py
100%|████████████████████████████████| 22/22 [01:09<00:00, 3.16s/it]

  num_rows    num_cols  high_precision_dtype    low_precision_dtype      cuda_time    pytorch_time    compiled_pytorch_time
----------  ----------  ----------------------  ---------------------  -----------  --------------  -----------------------
       512         512  torch.bfloat16          torch.float8_e4m3fn        6.32126        70.6295                  59.8
       512         512  torch.bfloat16          torch.float8_e5m2          6.44788        72.2925                  58.8291
      1024        1024  torch.bfloat16          torch.float8_e4m3fn        6.5061         70.3212                  64.892
      1024        1024  torch.bfloat16          torch.float8_e5m2          6.51736        70.5548                  65.027
      2048        2048  torch.bfloat16          torch.float8_e4m3fn        8.75586        99.9261                  21.3831
      2048        2048  torch.bfloat16          torch.float8_e5m2          8.82279        99.6768                  21.5061
      1024        8192  torch.bfloat16          torch.float8_e4m3fn       14.4368       166.601                    21.5972
      1024        8192  torch.bfloat16          torch.float8_e5m2         14.5325       167.111                    21.2745
      8192        1280  torch.bfloat16          torch.float8_e4m3fn       17.2475       207.092                    21.407
      8192        1280  torch.bfloat16          torch.float8_e5m2         17.4947       207.493                    21.5125
      8192        7168  torch.bfloat16          torch.float8_e4m3fn       94.3791      1129.12                     93.1578
      8192        7168  torch.bfloat16          torch.float8_e5m2         95.0259      1129.95                     93.2293
      3584        8192  torch.bfloat16          torch.float8_e4m3fn       49.7809       588.253                    48.8965
      3584        8192  torch.bfloat16          torch.float8_e5m2         50.1227       588.554                    48.8997
      2048      109760  torch.bfloat16          torch.float8_e4m3fn      346.729      4188.53                     343.26
      2048      109760  torch.bfloat16          torch.float8_e5m2        349.481      4190.3                      343.243
         1        3232  torch.bfloat16          torch.float8_e4m3fn        6.36614        69.5251                  65.228
         1        3232  torch.bfloat16          torch.float8_e5m2          6.3741         69.8858                  64.997
      2048           1  torch.bfloat16          torch.float8_e4m3fn        6.45917        70.7204                  64.1961
      2048           1  torch.bfloat16          torch.float8_e5m2          6.45133        69.8905                  64.0061
     14144        2048  torch.bfloat16          torch.float8_e4m3fn       49.2119       580.919                    48.2678
     14144        2048  torch.bfloat16          torch.float8_e5m2         49.3998       581.265                    48.2766
```
Summary
No hit to perf (see benchmark results above).
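For readers unfamiliar with the change: passing amax and dividing inside the kernel is mathematically equivalent (up to floating-point rounding) to premultiplying by `scale = FP8_MAX / amax`. The NumPy sketch below illustrates that equivalence for saturated casting; it is not the kernel code from this PR, and everything except the standard e4m3 finite max of 448.0 is illustrative.

```python
import numpy as np

# Largest finite value representable in torch.float8_e4m3fn.
FP8_E4M3_MAX = 448.0

def saturated_cast_with_scale(x: np.ndarray) -> np.ndarray:
    """Precompute a scale on the host, then multiply inside the 'kernel'."""
    amax = np.abs(x).max()
    scale = FP8_E4M3_MAX / amax
    # Saturate (clip) so no value exceeds the fp8 representable range.
    return np.clip(x * scale, -FP8_E4M3_MAX, FP8_E4M3_MAX)

def saturated_cast_with_amax(x: np.ndarray) -> np.ndarray:
    """Pass amax to the 'kernel' and perform the division inside it."""
    amax = np.abs(x).max()
    return np.clip(x / (amax / FP8_E4M3_MAX), -FP8_E4M3_MAX, FP8_E4M3_MAX)

x = np.linspace(-3.0, 3.0, 100, dtype=np.float32)
print(np.allclose(saturated_cast_with_scale(x),
                  saturated_cast_with_amax(x), rtol=1e-5))  # → True
```

In a real kernel the clipped result would then be bit-cast to the fp8 dtype; the sketch stops at the scaled float32 values to keep the comparison visible.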