Change from using scale to using amax and dividing within kernel #5

Open · wants to merge 5 commits into main

Conversation

@drisspg (Owner) commented Mar 1, 2024

Summary
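For context, a minimal PyTorch sketch of what the kernel now does semantically, assuming the interface takes the tensor's amax and derives the scale inside the kernel rather than taking a precomputed scale (the function and argument names below are illustrative, not the actual API in this repo):

```python
import torch

def saturated_cast_ref(x: torch.Tensor, amax: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
    # Derive the scale from amax inside the "kernel" instead of passing a precomputed scale.
    fp8_max = torch.finfo(dtype).max
    scale = fp8_max / amax.to(torch.float32)
    # Scale in fp32, saturate to the representable range, then downcast.
    y = x.to(torch.float32) * scale
    y = torch.clamp(y, min=-fp8_max, max=fp8_max)
    return y.to(dtype)

x = torch.randn(2048, 2048, dtype=torch.bfloat16, device="cuda")
y = saturated_cast_ref(x, x.abs().max(), torch.float8_e4m3fn)
```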

No hit to perf:

```
❯ python /home/drisspg/meta/driss_torch/benchmarks/benchmark_saturated_casting.py
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 22/22 [01:14<00:00,  3.40s/it]
  num_rows    num_cols  high_precision_dtype    low_precision_dtype      cuda_time    pytorch_time    compiled_pytorch_time
----------  ----------  ----------------------  ---------------------  -----------  --------------  -----------------------
       512         512  torch.bfloat16          torch.float8_e4m3fn        5.46512         72.2336                  59.9502
       512         512  torch.bfloat16          torch.float8_e5m2          5.50805         72.5867                  59.5001
      1024        1024  torch.bfloat16          torch.float8_e4m3fn        5.63875         72.5901                  64.8824
      1024        1024  torch.bfloat16          torch.float8_e5m2          5.61068         73.4559                  64.987
      2048        2048  torch.bfloat16          torch.float8_e4m3fn        6.92928         73.2639                  21.8212
      2048        2048  torch.bfloat16          torch.float8_e5m2          6.92831         73.238                   21.7541
      1024        8192  torch.bfloat16          torch.float8_e4m3fn       11.6602         125.987                   21.8744
      1024        8192  torch.bfloat16          torch.float8_e5m2         11.662          126.608                   21.9473
      8192        1280  torch.bfloat16          torch.float8_e4m3fn       13.9878         160.16                    21.9679
      8192        1280  torch.bfloat16          torch.float8_e5m2         13.9841         160.584                   22.0702
      8192        7168  torch.bfloat16          torch.float8_e4m3fn       81.9416         936.225                   85.1023
      8192        7168  torch.bfloat16          torch.float8_e5m2         82.5054         938.797                   85.1112
      3584        8192  torch.bfloat16          torch.float8_e4m3fn       43.5017         486.14                    44.7887
      3584        8192  torch.bfloat16          torch.float8_e5m2         43.4985         486.918                   44.8074
      2048      109760  torch.bfloat16          torch.float8_e4m3fn      301.522         3493.18                   317.802
      2048      109760  torch.bfloat16          torch.float8_e5m2        300.25          3495.29                   319.173
         1        3232  torch.bfloat16          torch.float8_e4m3fn        5.54213         71.9274                  64.7877
         1        3232  torch.bfloat16          torch.float8_e5m2          5.62434         72.1416                  64.597
      2048           1  torch.bfloat16          torch.float8_e4m3fn        5.59367         71.5838                  63.776
      2048           1  torch.bfloat16          torch.float8_e5m2          5.59475         71.9973                  63.966
     14144        2048  torch.bfloat16          torch.float8_e4m3fn       42.9653         479.963                   44.3175
     14144        2048  torch.bfloat16          torch.float8_e5m2         42.9723         480.981                   44.2864
```

Still need to dive back into ncu to see if there is more we can do to achieve better BW utilization.
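
As a rough sanity check on BW, achieved bandwidth can be estimated from the table above, assuming the kernel streams the bf16 input (2 bytes/element) and writes the fp8 output (1 byte/element), and that cuda_time is reported in microseconds:

```python
def achieved_bw_gbps(num_rows: int, num_cols: int, time_us: float) -> float:
    # bf16 read (2 B/elem) + fp8 write (1 B/elem); ignores the amax read.
    bytes_moved = num_rows * num_cols * (2 + 1)
    return bytes_moved / (time_us * 1e-6) / 1e9

# 8192x7168 e4m3 row from the first table: ~2150 GB/s, to be compared against
# the GPU's peak memory bandwidth to judge utilization.
print(achieved_bw_gbps(8192, 7168, 81.9416))
```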

For comparison, a second run of the same benchmark:

```
❯ python benchmarks/benchmark_saturated_casting.py
100%|██████████████████████████████████████████████████████████████████| 22/22 [01:09<00:00,  3.16s/it]
  num_rows    num_cols  high_precision_dtype    low_precision_dtype      cuda_time    pytorch_time    compiled_pytorch_time
----------  ----------  ----------------------  ---------------------  -----------  --------------  -----------------------
       512         512  torch.bfloat16          torch.float8_e4m3fn        6.32126         70.6295                  59.8
       512         512  torch.bfloat16          torch.float8_e5m2          6.44788         72.2925                  58.8291
      1024        1024  torch.bfloat16          torch.float8_e4m3fn        6.5061          70.3212                  64.892
      1024        1024  torch.bfloat16          torch.float8_e5m2          6.51736         70.5548                  65.027
      2048        2048  torch.bfloat16          torch.float8_e4m3fn        8.75586         99.9261                  21.3831
      2048        2048  torch.bfloat16          torch.float8_e5m2          8.82279         99.6768                  21.5061
      1024        8192  torch.bfloat16          torch.float8_e4m3fn       14.4368         166.601                   21.5972
      1024        8192  torch.bfloat16          torch.float8_e5m2         14.5325         167.111                   21.2745
      8192        1280  torch.bfloat16          torch.float8_e4m3fn       17.2475         207.092                   21.407
      8192        1280  torch.bfloat16          torch.float8_e5m2         17.4947         207.493                   21.5125
      8192        7168  torch.bfloat16          torch.float8_e4m3fn       94.3791        1129.12                    93.1578
      8192        7168  torch.bfloat16          torch.float8_e5m2         95.0259        1129.95                    93.2293
      3584        8192  torch.bfloat16          torch.float8_e4m3fn       49.7809         588.253                   48.8965
      3584        8192  torch.bfloat16          torch.float8_e5m2         50.1227         588.554                   48.8997
      2048      109760  torch.bfloat16          torch.float8_e4m3fn      346.729         4188.53                   343.26
      2048      109760  torch.bfloat16          torch.float8_e5m2        349.481         4190.3                    343.243
         1        3232  torch.bfloat16          torch.float8_e4m3fn        6.36614         69.5251                  65.228
         1        3232  torch.bfloat16          torch.float8_e5m2          6.3741          69.8858                  64.997
      2048           1  torch.bfloat16          torch.float8_e4m3fn        6.45917         70.7204                  64.1961
      2048           1  torch.bfloat16          torch.float8_e5m2          6.45133         69.8905                  64.0061
     14144        2048  torch.bfloat16          torch.float8_e4m3fn       49.2119         580.919                   48.2678
     14144        2048  torch.bfloat16          torch.float8_e5m2         49.3998         581.265                   48.2766
```