Change from using scale to using amax and dividing within kernel #5

Open
wants to merge 5 commits into base: main
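
As I read the title, the change moves from the host precomputing a scale (scale = FP8_MAX / amax) that the cast kernel multiplies by, to passing amax itself and doing the division inside the kernel. A minimal eager-mode PyTorch sketch of the two interfaces, assuming a saturating clamp-then-cast and illustrative function names (the actual change lives in this repo's CUDA kernels):

    import torch

    FP8_DTYPE = torch.float8_e4m3fn
    FP8_MAX = torch.finfo(FP8_DTYPE).max  # 448.0 for e4m3fn

    # Before (as assumed here): the caller precomputes the scale, the kernel multiplies by it.
    def cast_with_scale(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
        return (x * scale).clamp(-FP8_MAX, FP8_MAX).to(FP8_DTYPE)

    # After: the caller passes amax, and the "kernel" divides by it and rescales to the fp8 range.
    def cast_with_amax(x: torch.Tensor, amax: torch.Tensor) -> torch.Tensor:
        return (x / amax * FP8_MAX).clamp(-FP8_MAX, FP8_MAX).to(FP8_DTYPE)

    x = torch.randn(1024, 1024, dtype=torch.bfloat16)
    amax = x.abs().max().float()
    y_scale = cast_with_scale(x, FP8_MAX / amax)
    y_amax = cast_with_amax(x, amax)

The two are algebraically identical and differ only in intermediate rounding, which is what the numerics discussion in the Mar 1 commits below refers to.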

Commits on Feb 29, 2024

  1. remove add fluff

    drisspg committed Feb 29, 2024
    Commit e98b59e

Commits on Mar 1, 2024

  1. I am convinced of the numerics, amax it is; let's change all the kernels to use this directly

    drisspg committed Mar 1, 2024
    Commit eb7220c
  2. go all the way

    drisspg committed Mar 1, 2024
    Commit 4abee6e
  3. ❯ python /home/drisspg/meta/driss_torch/benchmarks/benchmark_saturated_casting.py
    
    100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 22/22 [01:14<00:00,  3.40s/it]
      num_rows    num_cols  high_precision_dtype    low_precision_dtype      cuda_time    pytorch_time    compiled_pytorch_time
    ----------  ----------  ----------------------  ---------------------  -----------  --------------  -----------------------
           512         512  torch.bfloat16          torch.float8_e4m3fn        5.46512         72.2336                  59.9502
           512         512  torch.bfloat16          torch.float8_e5m2          5.50805         72.5867                  59.5001
          1024        1024  torch.bfloat16          torch.float8_e4m3fn        5.63875         72.5901                  64.8824
          1024        1024  torch.bfloat16          torch.float8_e5m2          5.61068         73.4559                  64.987
          2048        2048  torch.bfloat16          torch.float8_e4m3fn        6.92928         73.2639                  21.8212
          2048        2048  torch.bfloat16          torch.float8_e5m2          6.92831         73.238                   21.7541
          1024        8192  torch.bfloat16          torch.float8_e4m3fn       11.6602         125.987                   21.8744
          1024        8192  torch.bfloat16          torch.float8_e5m2         11.662          126.608                   21.9473
          8192        1280  torch.bfloat16          torch.float8_e4m3fn       13.9878         160.16                    21.9679
          8192        1280  torch.bfloat16          torch.float8_e5m2         13.9841         160.584                   22.0702
          8192        7168  torch.bfloat16          torch.float8_e4m3fn       81.9416         936.225                   85.1023
          8192        7168  torch.bfloat16          torch.float8_e5m2         82.5054         938.797                   85.1112
          3584        8192  torch.bfloat16          torch.float8_e4m3fn       43.5017         486.14                    44.7887
          3584        8192  torch.bfloat16          torch.float8_e5m2         43.4985         486.918                   44.8074
          2048      109760  torch.bfloat16          torch.float8_e4m3fn      301.522         3493.18                   317.802
          2048      109760  torch.bfloat16          torch.float8_e5m2        300.25          3495.29                   319.173
             1        3232  torch.bfloat16          torch.float8_e4m3fn        5.54213         71.9274                  64.7877
             1        3232  torch.bfloat16          torch.float8_e5m2          5.62434         72.1416                  64.597
          2048           1  torch.bfloat16          torch.float8_e4m3fn        5.59367         71.5838                  63.776
          2048           1  torch.bfloat16          torch.float8_e5m2          5.59475         71.9973                  63.966
         14144        2048  torch.bfloat16          torch.float8_e4m3fn       42.9653         479.963                   44.3175
         14144        2048  torch.bfloat16          torch.float8_e5m2         42.9723         480.981                   44.2864
    drisspg committed Mar 1, 2024
    Commit 1dc439d
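
The first Mar 1 commit asserts that the numerics hold up when the kernel divides by amax directly instead of multiplying by a precomputed scale. A small eager-mode check along those lines (an illustrative sketch, not the repo's actual test) could look like:

    import torch

    FP8_DTYPE = torch.float8_e4m3fn
    FP8_MAX = torch.finfo(FP8_DTYPE).max

    torch.manual_seed(0)
    x = torch.randn(4096, 4096, dtype=torch.bfloat16)
    amax = x.abs().max().float()
    scale = FP8_MAX / amax

    # Multiply-by-precomputed-scale path vs. divide-by-amax path.
    via_scale = (x.float() * scale).clamp(-FP8_MAX, FP8_MAX).to(FP8_DTYPE)
    via_amax = (x.float() / amax * FP8_MAX).clamp(-FP8_MAX, FP8_MAX).to(FP8_DTYPE)

    # The two paths differ only in how the intermediate product/quotient rounds,
    # so after the fp8 rounding step the outputs should agree on (nearly) every element.
    mismatches = (via_scale.to(torch.float32) != via_amax.to(torch.float32)).sum().item()
    print(f"mismatching elements: {mismatches} / {x.numel()}")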
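
The timing columns above compare the custom CUDA kernel (cuda_time) against an eager PyTorch cast (pytorch_time) and a torch.compile'd version of it (compiled_pytorch_time). A rough, self-contained harness for the latter two columns, assuming torch.utils.benchmark and a clamp-then-cast eager reference (benchmark_saturated_casting.py in the repo is the authoritative script; this sketch needs a CUDA device):

    import torch
    from torch.utils.benchmark import Timer

    FP8_DTYPE = torch.float8_e4m3fn
    FP8_MAX = torch.finfo(FP8_DTYPE).max

    def eager_cast(x: torch.Tensor, amax: torch.Tensor) -> torch.Tensor:
        # Saturated cast reference: rescale by amax, clamp to the fp8 range, then convert.
        return (x / amax * FP8_MAX).clamp(-FP8_MAX, FP8_MAX).to(FP8_DTYPE)

    compiled_cast = torch.compile(eager_cast)

    x = torch.randn(8192, 7168, dtype=torch.bfloat16, device="cuda")
    amax = x.abs().max().float()

    for name, fn in [("pytorch", eager_cast), ("compiled_pytorch", compiled_cast)]:
        t = Timer(stmt="fn(x, amax)", globals={"fn": fn, "x": x, "amax": amax})
        print(name, t.blocked_autorange().median * 1e6, "us")  # median per call, in microseconds

The cuda_time column would be measured the same way with the repo's kernel substituted for fn.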

Commits on Mar 4, 2024

  1. slightly hurts eager perf but needed for the matmul:

    ❯ python benchmarks/benchmark_saturated_casting.py
    100%|██████████████████████████████████████████████████████████████████| 22/22 [01:09<00:00,  3.16s/it]
      num_rows    num_cols  high_precision_dtype    low_precision_dtype      cuda_time    pytorch_time    compiled_pytorch_time
    ----------  ----------  ----------------------  ---------------------  -----------  --------------  -----------------------
           512         512  torch.bfloat16          torch.float8_e4m3fn        6.32126         70.6295                  59.8
           512         512  torch.bfloat16          torch.float8_e5m2          6.44788         72.2925                  58.8291
          1024        1024  torch.bfloat16          torch.float8_e4m3fn        6.5061          70.3212                  64.892
          1024        1024  torch.bfloat16          torch.float8_e5m2          6.51736         70.5548                  65.027
          2048        2048  torch.bfloat16          torch.float8_e4m3fn        8.75586         99.9261                  21.3831
          2048        2048  torch.bfloat16          torch.float8_e5m2          8.82279         99.6768                  21.5061
          1024        8192  torch.bfloat16          torch.float8_e4m3fn       14.4368         166.601                   21.5972
          1024        8192  torch.bfloat16          torch.float8_e5m2         14.5325         167.111                   21.2745
          8192        1280  torch.bfloat16          torch.float8_e4m3fn       17.2475         207.092                   21.407
          8192        1280  torch.bfloat16          torch.float8_e5m2         17.4947         207.493                   21.5125
          8192        7168  torch.bfloat16          torch.float8_e4m3fn       94.3791        1129.12                    93.1578
          8192        7168  torch.bfloat16          torch.float8_e5m2         95.0259        1129.95                    93.2293
          3584        8192  torch.bfloat16          torch.float8_e4m3fn       49.7809         588.253                   48.8965
          3584        8192  torch.bfloat16          torch.float8_e5m2         50.1227         588.554                   48.8997
          2048      109760  torch.bfloat16          torch.float8_e4m3fn      346.729         4188.53                   343.26
          2048      109760  torch.bfloat16          torch.float8_e5m2        349.481         4190.3                    343.243
             1        3232  torch.bfloat16          torch.float8_e4m3fn        6.36614         69.5251                  65.228
             1        3232  torch.bfloat16          torch.float8_e5m2          6.3741          69.8858                  64.997
          2048           1  torch.bfloat16          torch.float8_e4m3fn        6.45917         70.7204                  64.1961
          2048           1  torch.bfloat16          torch.float8_e5m2          6.45133         69.8905                  64.0061
         14144        2048  torch.bfloat16          torch.float8_e4m3fn       49.2119         580.919                   48.2678
         14144        2048  torch.bfloat16          torch.float8_e5m2         49.3998         581.265                   48.2766
    drisspg committed Mar 4, 2024
    Commit 6882d2d
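
"Needed for the matmul" presumably means that even though the cast kernel now divides by amax internally, a downstream fp8 matmul still has to unscale its result, so a per-operand scale (amax / FP8_MAX) must still be produced for it; that extra bookkeeping is a plausible source of the small eager-mode regression in the table above. A hedged sketch of the dequantization math in plain eager ops (illustrative only, not the repo's actual matmul path):

    import torch

    FP8_DTYPE = torch.float8_e4m3fn
    FP8_MAX = torch.finfo(FP8_DTYPE).max

    def amax_cast(x: torch.Tensor):
        # Return the fp8 tensor plus the scale a downstream matmul needs to undo the rescaling.
        amax = x.abs().max().float()
        x_fp8 = (x.float() / amax * FP8_MAX).clamp(-FP8_MAX, FP8_MAX).to(FP8_DTYPE)
        return x_fp8, amax / FP8_MAX  # x ~= x_fp8 * (amax / FP8_MAX)

    a = torch.randn(512, 256, dtype=torch.bfloat16)
    b = torch.randn(256, 128, dtype=torch.bfloat16)
    a_fp8, a_scale = amax_cast(a)
    b_fp8, b_scale = amax_cast(b)

    # Emulate a scaled matmul: multiply in higher precision, then apply both operand scales.
    out = (a_fp8.to(torch.bfloat16) @ b_fp8.to(torch.bfloat16)) * a_scale * b_scale
    print((out - a @ b).abs().max())  # small, dominated by fp8 quantization error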