
Tutorial example 03 performance issue #1122

Open
narendrachaudhary51 opened this issue May 14, 2024 · 26 comments · Fixed by #1185 or #2011

@narendrachaudhary51

Hi,

I am trying to run the benchmark in the Python tutorial 03-matrix-multiplication.py.
I do not see the expected performance with Triton. Even for larger square matrix sizes, Triton performance does not improve.

I am using the following XPU hardware:
Intel(R) Level-Zero, Intel(R) Data Center GPU Max 1100 1.3 [1.3.26918]

[benchmark plot (image)]

matmul-performance-fp16:
M N K rocBLAS Triton
0 256.0 256.0 256.0 0.108552 0.065673
1 384.0 384.0 384.0 0.363976 0.154921
2 512.0 512.0 512.0 0.857502 0.162093
3 640.0 640.0 640.0 1.664666 0.174643
4 768.0 768.0 768.0 2.829421 0.307113
5 896.0 896.0 896.0 4.453225 0.395680
6 1024.0 1024.0 1024.0 6.550690 0.316575
7 1152.0 1152.0 1152.0 9.213149 0.378349
8 1280.0 1280.0 1280.0 12.485583 0.444944
9 1408.0 1408.0 1408.0 15.589347 0.518402
10 1536.0 1536.0 1536.0 19.637789 0.576684
11 1664.0 1664.0 1664.0 23.102231 0.644403
12 1792.0 1792.0 1792.0 29.274080 0.726035
13 1920.0 1920.0 1920.0 32.775947 0.721385
14 2048.0 2048.0 2048.0 36.019792 0.722632
15 2176.0 2176.0 2176.0 43.476061 0.755484
16 2304.0 2304.0 2304.0 50.967526 0.797482
17 2432.0 2432.0 2432.0 59.675966 0.516829
18 2560.0 2560.0 2560.0 65.795927 0.524481
19 2688.0 2688.0 2688.0 72.057158 0.525320
20 2816.0 2816.0 2816.0 75.685494 0.397117
21 2944.0 2944.0 2944.0 78.605996 0.397091
22 3072.0 3072.0 3072.0 77.401139 0.392022
23 3200.0 3200.0 3200.0 93.879067 0.395438
24 3328.0 3328.0 3328.0 100.308276 0.366734
25 3456.0 3456.0 3456.0 107.904954 0.323059
26 3584.0 3584.0 3584.0 116.285356 0.305686
27 3712.0 3712.0 3712.0 97.801647 0.296061
28 3840.0 3840.0 3840.0 100.229800 0.293752
29 3968.0 3968.0 3968.0 103.656807 0.278301
30 4096.0 4096.0 4096.0 105.869743 0.265378

@vlad-penkin
Contributor

@narendrachaudhary51 could you please provide information on your environment?

@narendrachaudhary51
Author

  • I built Triton, IPEX and PyTorch from source using the "scripts/compile-triton.sh" and "scripts/test-triton.sh" scripts.
  • My oneAPI version is 2024.1.0
  • My OS is Rocky Linux 9.2 on a cluster, so the commands mentioned in the link did not work for me.
  • I tried the same with "yum info", but that does not seem to work either.

I was able to run the following command
xpu-smi discovery | grep "Device Name" | sed -n 's/.*Device Name: \(.*\)\s*|/\1/p' >gpu.txt
and gpu.txt contains this -
Intel(R) Data Center GPU Max 1100
Intel(R) Data Center GPU Max 1100

@fcharras

fcharras commented May 16, 2024

@narendrachaudhary51 I'm playing with this example too and noticed the same performance issue.

I think the grid search parameters have not been tuned. The grid search parameters currently used are the same as for CUDA, but I noticed that using a higher num_warps offers up to a 3x speedup on the Max Series GPU. The best parameters I have found so far are:

BLOCK_SIZE_M: 32, BLOCK_SIZE_N: 64, BLOCK_SIZE_K: 32, GROUP_SIZE_M: 8, num_warps: 16, num_ctas: 1, num_stages: 5

With those parameters, the 512 x 512 matmul performance is very close to torch.matmul performance (however, I unfortunately also noticed that performance becomes worse at higher dimensions).
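
For reference, a minimal sketch (not part of the tutorial itself) of how that configuration could be packaged as a triton.Config entry for the tutorial's autotuner; the variable name xpu_candidate is illustrative:

    import triton

    # The parameters reported above, wrapped so @triton.autotune can try them.
    xpu_candidate = triton.Config(
        {'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8},
        num_warps=16, num_ctas=1, num_stages=5,
    )
    # e.g. append this to the list of configs passed to @triton.autotune in the tutorial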

@fcharras

fcharras commented May 16, 2024

Here are the changes to the example that I find give better performance: #1139

It also improves performance at higher dimensions with a 3-4x speedup, but not to the point of reaching torch.matmul performance; e.g. I find torch.matmul 100x faster on the 4096 x 4096 example, and still 20x faster after this change:

In [6]: %time matmul(a, b)[0,0].cpu()
CPU times: user 17.3 ms, sys: 8.88 ms, total: 26.2 ms
Wall time: 25.6 ms
Out[6]: tensor(-64.9375, dtype=torch.float16)

In [7]: %time torch.matmul(a, b)[0,0].cpu()
CPU times: user 1.15 ms, sys: 788 µs, total: 1.94 ms
Wall time: 1.82 ms
Out[7]: tensor(-64.9375, dtype=torch.float16)
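
(As an aside, a sketch: for GPU kernels, triton.testing.do_bench gives more reliable timings than %time because it warms up, synchronizes, and averages over many runs; here a, b, matmul and torch are the objects from the session above.)

    import triton

    # Runtime in milliseconds over many synchronized runs
    triton_ms = triton.testing.do_bench(lambda: matmul(a, b))
    torch_ms = triton.testing.do_bench(lambda: torch.matmul(a, b))
    print(f"triton: {triton_ms:.3f} ms, torch: {torch_ms:.3f} ms")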

Maybe more work on the grid search could help achieve a good speedup at higher dimensions too.

@narendrachaudhary51
Author

@fcharras Thank you for your reply. I suspected that grid search parameters could be the cause.
I will play with grid search parameters and check performance. Do you have a guess on why more warps help with XPU performance?

@fcharras

I only have some intuition and a limited understanding of which XPU concepts warps (a CUDA-only concept) map to (execution units? sub-slices?), but I suspected that the "num_warps" value in the grid was too low in that regard. To be honest, I'm surprised that setting it so low is good for CUDA devices to begin with, so it was partly luck...

@fcharras

fcharras commented May 16, 2024

From (again) an entry-level understanding of it, matmul performance is a balance that is hard to achieve between global and local memory bandwidth, cache bandwidth and hit rate, and actual compute. I thought that if something has to be a bottleneck here, it might be the compute, because too few threads are being used, and that number is (I think) increased by raising num_warps.

@etiotto
Contributor

etiotto commented May 16, 2024

By default, the Triton kernel in the tutorial compiles the matrix-multiplication operation to a sequence of scalar floating-point multiply-add (FMA) instructions. We can force the tutorial to use 16 threads per warp, which allows our compiler to generate specialized HW instructions (DPAS) rather than scalar FMAs. Performance then improves ~7x for large problem sizes.

The code to change in the tutorial is:

    matmul_kernel[grid](
        a, b, c,  #
        M, N, K,  #
        a.stride(0), a.stride(1),  #
        b.stride(0), b.stride(1),  #
        c.stride(0), c.stride(1),  #
        ACTIVATION=activation,  #
        threads_per_warp=16,  # <<< add this parameter
    )

Performance is still lagging behind the torch.matmul implementation (which offloads that computation to the specialized oneDNN library). Future improvements are WIP.

I will post a PR to change the warp size used by the tutorial.

@vlad-penkin vlad-penkin linked a pull request May 16, 2024 that will close this issue
@fcharras

Thanks @etiotto, with this tip, and also the suggestion in #1139 (comment), I'm seeing almost equivalent wall time on the 512 x 512 example, and performance has improved a lot on the larger (4096 x 4096) example, albeit still below torch.matmul:

In [7]: %time torch.matmul(a, b)[0,0].cpu()
CPU times: user 0 ns, sys: 1.77 ms, total: 1.77 ms
Wall time: 1.67 ms
Out[7]: tensor(-64.9375, dtype=torch.float16)

In [8]: %time matmul(a, b)[0,0].cpu()
CPU times: user 4.79 ms, sys: 4.3 ms, total: 9.09 ms
Wall time: 8.77 ms
Out[8]: tensor(-64.9375, dtype=torch.float16)

Maybe the matmul with the experimental block-pointer approach (tutorial 09) will give better results?

@ogrisel

ogrisel commented May 17, 2024

maybe the matmul with the experimental block pointer approach (tutorial 09) will give better results?

That's an interesting question. Note, however, that the experimental examples, including tutorial 09 on the use of block pointers for matrix-matrix multiplication, have been removed from the upstream repo (triton-lang/triton#3371), but I am not sure why.

Still, I would be interested in the performance results on a Max Series GPU with a tuned grid and optimal threads_per_warp with block pointers.

@narendrachaudhary51
Author

narendrachaudhary51 commented May 17, 2024

I tried the change suggested by @etiotto and increased the number of warps. This gave me a speedup across the board.
I used the following autotune parameters and obtained the matmul performance shown below. It is currently ~100x faster than with the default parameters. However, we are still 7-8x slower than the torch version.
def get_xpu_autotune_config():
    return [
        # FIXME: Once tl.dot uses DPAS put back the workload commented out.
        # triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=3,
        #               num_warps=64),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4,
                      num_warps=32),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4,
                      num_warps=32),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4,
                      num_warps=32),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4,
                      num_warps=32),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4,
                      num_warps=32),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5,
                      num_warps=16),
        triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5,
                      num_warps=16),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=4,
                      num_warps=32),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=4,
                      num_warps=32),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=4,
                      num_warps=32)
    ]

[matmul performance plot (image)]

I tried several other configurations but was not able to go beyond this. Is there a way to inspect and analyze the Triton-generated code? We could then identify inefficiencies by comparing it with the torch implementation.

@ogrisel

ogrisel commented May 17, 2024

The rocBLAS legend actually refers to Intel oneDNN's kernel wrapped as pytorch XPU matmul, right?

@narendrachaudhary51
Author

@ogrisel Yes. It is the pytorch XPU matmul. It must be using the oneDNN kernel underneath.

@narendrachaudhary51
Author

Is the Triton-generated code using 2D loads when doing the following operations?

a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
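
(For comparison, a rough sketch only of how the same tile loads would look with Triton's block-pointer API, which describes the 2D block shape explicitly; the kernel variables such as a_ptr, stride_am, pid_m and accumulator are assumed from the tutorial's kernel, and the block sizes are illustrative.)

    a_block_ptr = tl.make_block_ptr(base=a_ptr, shape=(M, K), strides=(stride_am, stride_ak),
                                    offsets=(pid_m * BLOCK_SIZE_M, 0),
                                    block_shape=(BLOCK_SIZE_M, BLOCK_SIZE_K), order=(1, 0))
    b_block_ptr = tl.make_block_ptr(base=b_ptr, shape=(K, N), strides=(stride_bk, stride_bn),
                                    offsets=(0, pid_n * BLOCK_SIZE_N),
                                    block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_N), order=(1, 0))
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        # 2D block loads with boundary checking instead of per-element masks
        a = tl.load(a_block_ptr, boundary_check=(0, 1))
        b = tl.load(b_block_ptr, boundary_check=(0, 1))
        accumulator += tl.dot(a, b)
        a_block_ptr = tl.advance(a_block_ptr, (0, BLOCK_SIZE_K))
        b_block_ptr = tl.advance(b_block_ptr, (BLOCK_SIZE_K, 0))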

How can I dump the generated code?

@whitneywhtsang
Contributor

How can I dump the generated code?

The output of the different compilation stages can be found in TRITON_CACHE_DIR, or you can use MLIR_ENABLE_DUMP=1 to dump the IR before every MLIR pass Triton runs.
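
For example (a sketch; the cache path is illustrative), both can be set before any kernel is compiled, e.g. at the top of the tutorial script:

    import os

    os.environ["MLIR_ENABLE_DUMP"] = "1"                  # dump the IR before every MLIR pass
    os.environ["TRITON_CACHE_DIR"] = "/tmp/triton_cache"  # compiled stages end up here (illustrative path)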

@narendrachaudhary51
Author

narendrachaudhary51 commented May 27, 2024

@ogrisel Block-pointer-based matrix multiplication is faster than the previous implementation, which only reached a peak performance of 25 TFLOPS. But the current performance is still below oneDNN matrix multiplication.

[matmul-performance-block-pointer-fp16 plot (image)]

@vlad-penkin
Contributor

vlad-penkin commented Aug 9, 2024

@AshburnLee could you please retest the issue and report back the results with the top of the llvm-target branch with the

@AshburnLee
Contributor

AshburnLee commented Aug 19, 2024

I'm working on it.

@narendrachaudhary51
Author

@vlad-penkin I get the following output with the latest nightly build of PyTorch. I commented out intel_extension_for_pytorch.

WARNING:root:Wall time is used instead of elapsed_time (not supported). The timing measurements could be innacurate.
(I): Detected 104320 spills, recompiling the kernel using large GRF mode
(I): Kernel has now 97536 spills
(I): Detected 32000 spills, recompiling the kernel using large GRF mode
(I): Kernel has now 23808 spills
(I): Detected 32000 spills, recompiling the kernel using large GRF mode
(I): Kernel has now 24832 spills
(I): Detected 13760 spills, recompiling the kernel using large GRF mode
(I): Kernel has now 5888 spills
(I): Detected 12992 spills, recompiling the kernel using large GRF mode
(I): Kernel has now 5440 spills
(I): Detected 107648 spills, recompiling the kernel using large GRF mode
(I): Kernel has now 91648 spills
(I): Detected 56192 spills, recompiling the kernel using large GRF mode
(I): Kernel has now 47616 spills
L0 build module failed. Log:
error: total scratch space exceeds HW supported limit for kernel matmul_kernel: 336128 bytes (max permitted PTSS 262144 bytes)
error: backend compiler failed build.

Segmentation fault (core dumped)

@AshburnLee
Contributor

Segmentation fault (core dumped)

Same issue on my local machine.

@AshburnLee
Contributor

AshburnLee commented Aug 20, 2024

After a rebase and another try, here is the result:

[benchmark result (image)]

Note:

  • Agama Rolling (914.32)
  • PTDB 0.5.2
  • applied pytorch-upstream.txt & patch-pytorch.sh @vlad-penkin

@whitneywhtsang whitneywhtsang linked a pull request Aug 26, 2024 that will close this issue
@etiotto etiotto reopened this Aug 27, 2024
@etiotto
Contributor

etiotto commented Aug 27, 2024

I just expanded 10-experimental-block-pointer.py and added a benchmarking section. That tutorial uses block pointers, and the performance of the Triton-generated kernel is now much closer to that of the library (oneDNN) on a GPU Max 1100 card:

matmul-performance-fp16:
         M       N       K      oneDNN      Triton
0    256.0   256.0   256.0    4.032985    4.112063
1    384.0   384.0   384.0   10.889058   10.408659
2    512.0   512.0   512.0   21.236982   19.508391
3    640.0   640.0   640.0   33.781444   30.062384
4    768.0   768.0   768.0   41.330734   41.330734
5    896.0   896.0   896.0   56.550561   53.204376
6   1024.0  1024.0  1024.0   69.184396   65.472062
7   1152.0  1152.0  1152.0   84.186331   78.643196
8   1280.0  1280.0  1280.0  100.438311   91.658745
9   1408.0  1408.0  1408.0  114.397925   74.236951
10  1536.0  1536.0  1536.0  114.970764   83.576540
11  1664.0  1664.0  1664.0  140.129040   98.618210
12  1792.0  1792.0  1792.0  147.100846  108.006478
13  1920.0  1920.0  1920.0  143.859517   95.854388
14  2048.0  2048.0  2048.0  149.963941  103.944030
15  2176.0  2176.0  2176.0  137.892230  102.215355
16  2304.0  2304.0  2304.0  130.334506  102.881814
17  2432.0  2432.0  2432.0  161.549472  112.871665
18  2560.0  2560.0  2560.0  156.270644  122.568794
19  2688.0  2688.0  2688.0  160.882412  118.194522
20  2816.0  2816.0  2816.0  158.642184  119.490976
21  2944.0  2944.0  2944.0  158.524154  128.402019
22  3072.0  3072.0  3072.0  130.051270  109.070840
23  3200.0  3200.0  3200.0  162.153605  122.159258
24  3328.0  3328.0  3328.0  183.052956  133.240109
25  3456.0  3456.0  3456.0  173.613062  134.790491
26  3584.0  3584.0  3584.0  180.961785  140.372846
27  3712.0  3712.0  3712.0  187.821092  147.552047
28  3840.0  3840.0  3840.0  178.960510  151.820849
29  3968.0  3968.0  3968.0  183.021604  139.543144
30  4096.0  4096.0  4096.0  188.789775  148.846555

@narendrachaudhary51 can we close this issue?

@narendrachaudhary51
Author

This is good. However, performance still seems lower compared to oneDNN. I will try to replicate it on my side. Do I still need to patch upstream PyTorch and compile it, or is it possible to use the latest nightly PyTorch?

@vlad-penkin
Contributor

This is good. However, performance still seems lower compared to oneDNN. I will try to replicate on my side. Do I still need to patch the upstream pytorch and compile it? or is it possible to use the latest nightly pytorch?

For the time being, please use our PyTorch nightly wheels, which are built with 'XPUEvent elapsed_time(...)' support. Our 'XPUEvent elapsed_time(...)' patch, inherited from IPEX, works well with Triton but causes significant performance degradation in PT use cases.

We expect that the nightly PyTorch wheels will have 'XPUEvent elapsed_time(...)' enabled sometime in November 2024, following the release of oneAPI 2025.0.0.

@narendrachaudhary51
Author

narendrachaudhary51 commented Sep 2, 2024

I see the following performance with tutorial 03-matrix-multiplication.py.

✅ Triton and Torch match
matmul-performance-fp16:
M N K rocBLAS Triton
0 256.0 256.0 256.0 0.147863 0.138547
1 384.0 384.0 384.0 0.485206 0.451662
2 512.0 512.0 512.0 1.145325 1.040219
3 640.0 640.0 640.0 2.229993 1.968363
4 768.0 768.0 768.0 3.813825 3.322159
5 896.0 896.0 896.0 5.988511 5.119631
6 1024.0 1024.0 1024.0 8.840294 6.623564
7 1152.0 1152.0 1152.0 12.461816 9.168466
8 1280.0 1280.0 1280.0 16.876099 11.938756
9 1408.0 1408.0 1408.0 22.286945 11.808110
10 1536.0 1536.0 1536.0 28.085917 14.224635
11 1664.0 1664.0 1664.0 34.851498 14.088904
12 1792.0 1792.0 1792.0 42.090680 20.558901
13 1920.0 1920.0 1920.0 46.646869 18.776573
14 2048.0 2048.0 2048.0 53.553078 20.788403
15 2176.0 2176.0 2176.0 64.628405 23.061393
16 2304.0 2304.0 2304.0 69.723728 22.467316
17 2432.0 2432.0 2432.0 76.641771 23.700065
18 2560.0 2560.0 2560.0 83.397423 25.484694
19 2688.0 2688.0 2688.0 89.133204 25.116863
20 2816.0 2816.0 2816.0 90.920103 26.909569
21 2944.0 2944.0 2944.0 106.919421 26.048660
22 3072.0 3072.0 3072.0 101.450026 28.381275
23 3200.0 3200.0 3200.0 105.683250 26.986574
24 3328.0 3328.0 3328.0 112.480306 28.670089
25 3456.0 3456.0 3456.0 114.009226 29.442870
26 3584.0 3584.0 3584.0 117.231112 29.747002
27 3712.0 3712.0 3712.0 121.255919 31.131572
28 3840.0 3840.0 3840.0 123.464591 30.836395
29 3968.0 3968.0 3968.0 127.219857 29.903914
30 4096.0 4096.0 4096.0 129.087355 30.785845

And with tutorial 10-experimental-block-pointer.py

matmul-performance-fp16:
M N K oneDNN Triton
0 256.0 256.0 256.0 0.149553 0.148583
1 384.0 384.0 384.0 0.494516 0.492940
2 512.0 512.0 512.0 1.157381 1.158600
3 640.0 640.0 640.0 2.251031 2.253393
4 768.0 768.0 768.0 3.865471 3.869501
5 896.0 896.0 896.0 6.087495 6.100098
6 1024.0 1024.0 1024.0 8.975605 8.930044
7 1152.0 1152.0 1152.0 12.549223 12.424727
8 1280.0 1280.0 1280.0 17.077405 16.976155
9 1408.0 1408.0 1408.0 22.662444 20.787642
10 1536.0 1536.0 1536.0 28.934517 26.482439
11 1664.0 1664.0 1664.0 35.776467 33.022023
12 1792.0 1792.0 1792.0 42.317380 37.503231
13 1920.0 1920.0 1920.0 47.101591 44.193642
14 2048.0 2048.0 2048.0 55.083166 47.869894
15 2176.0 2176.0 2176.0 64.504927 56.232310
16 2304.0 2304.0 2304.0 70.907995 57.864855
17 2432.0 2432.0 2432.0 76.517352 65.067213
18 2560.0 2560.0 2560.0 84.734250 69.730570
19 2688.0 2688.0 2688.0 89.635363 75.368436
20 2816.0 2816.0 2816.0 91.010571 80.907230
21 2944.0 2944.0 2944.0 102.439338 86.769393
22 3072.0 3072.0 3072.0 104.393192 92.735069
23 3200.0 3200.0 3200.0 108.788432 90.169787
24 3328.0 3328.0 3328.0 113.708673 97.751175
25 3456.0 3456.0 3456.0 114.105404 100.760240
26 3584.0 3584.0 3584.0 117.340538 106.055255
27 3712.0 3712.0 3712.0 121.926329 106.022224
28 3840.0 3840.0 3840.0 123.629123 109.142698
29 3968.0 3968.0 3968.0 127.314797 111.455898
30 4096.0 4096.0 4096.0 129.398909 119.827576

My configuration is as follows:

Device
+-----------+--------------------------------------------------------------------------------------+
| Device ID | Device Information |
+-----------+--------------------------------------------------------------------------------------+
| 0 | Device Name: Intel(R) Data Center GPU Max 1550 |
| | Vendor Name: Intel(R) Corporation |
| | SOC UUID: 00000000-0000-0018-0000-002f0bd68086 |
| | PCI BDF Address: 0000:18:00.0 |
| | DRM Device: /dev/dri/card0 |
| | Function Type: physical |

Driver -
graphics-compute-runtime/hotfix_agama-ci-devel-914.33

OneAPI - 2024.2.1

@etiotto I can see that Triton is getting close to the performance of oneDNN. However, it seems both the oneDNN and Triton numbers are slower than yours.
@vlad-penkin I am using the latest nightly version of PyTorch. I suspect that is the reason I see the performance degradation in both the oneDNN and Triton numbers.

@vlad-penkin vlad-penkin assigned vlad-penkin and etiotto and unassigned etiotto and vlad-penkin Sep 9, 2024
@etiotto
Contributor

etiotto commented Sep 18, 2024

The numbers reported in #1122 (comment) are low for a PVC 1550. I just tried tutorial 10 again on a PVC 1100 and the throughput for both Triton and oneDNN is higher:

LIBIGC1_VERSION=1.0.17193.16-950
LEVEL_ZERO_VERSION=1.3.30049.10-950
AGAMA_VERSION=950
GPU_DEVICE=Intel(R) Data Center GPU Max 1100

matmul-performance-fp16:
         M       N       K      oneDNN      Triton
0    256.0   256.0   256.0    4.032985    4.112063
1    384.0   384.0   384.0   11.059200   10.564012
2    512.0   512.0   512.0   21.509251   19.737901
3    640.0   640.0   640.0   35.617391   29.789091
4    768.0   768.0   768.0   41.634635   41.330734
5    896.0   896.0   896.0   57.270952   52.891406
6   1024.0  1024.0  1024.0   69.184396   65.154234
7   1152.0  1152.0  1152.0   83.817098   78.643196
8   1280.0  1280.0  1280.0   98.550376   91.658745
9   1408.0  1408.0  1408.0  113.283659   73.922385
10  1536.0  1536.0  1536.0  118.272800   84.987772
11  1664.0  1664.0  1664.0  136.476390   83.832657
12  1792.0  1792.0  1792.0  148.008869   92.102836
13  1920.0  1920.0  1920.0  155.763382   97.545318
14  2048.0  2048.0  2048.0  146.886711  101.873037
15  2176.0  2176.0  2176.0  143.260675  102.540884
16  2304.0  2304.0  2304.0  135.294138  100.119435
17  2432.0  2432.0  2432.0  166.331703  114.488738
18  2560.0  2560.0  2560.0  168.175776  123.144563
19  2688.0  2688.0  2688.0  164.479379  118.714700
20  2816.0  2816.0  2816.0  158.327247  121.838028
21  2944.0  2944.0  2944.0  159.077605  125.422970
22  3072.0  3072.0  3072.0  132.185981  109.631785
23  3200.0  3200.0  3200.0  170.453600  126.497839
24  3328.0  3328.0  3328.0  180.118957  132.321733
25  3456.0  3456.0  3456.0  173.176057  137.173473
26  3584.0  3584.0  3584.0  187.201855  139.134075
27  3712.0  3712.0  3712.0  178.363230  141.666964
28  3840.0  3840.0  3840.0  181.437788  149.828280
29  3968.0  3968.0  3968.0  180.567219  144.034160
30  4096.0  4096.0  4096.0  184.571007  151.712016

@vlad-penkin vlad-penkin assigned vlad-penkin and unassigned etiotto Sep 23, 2024