diff --git a/README.md b/README.md index 29b20f71..9dd2369b 100644 --- a/README.md +++ b/README.md @@ -551,11 +551,22 @@ We benchmark the performance of: compared with the following libraries: -- OpTree ([`@v0.8.0`](https://github.com/metaopt/optree/tree/v0.8.0)) +- OpTree ([`@v0.9.0`](https://github.com/metaopt/optree/tree/v0.9.0)) - JAX XLA ([`jax[cpu] == 0.4.6`](https://pypi.org/project/jax/0.4.6)) -- PyTorch ([`torch == 1.13.1`](https://pypi.org/project/torch/1.13.1)) +- PyTorch ([`torch == 2.0.0`](https://pypi.org/project/torch/2.0.0)) - DM-Tree ([`dm-tree == 0.1.8`](https://pypi.org/project/dm-tree/0.1.8)) +| Average Time Cost (↓) | OpTree (v0.9.0) | JAX XLA (v0.4.6) | PyTorch (v2.0.0) | DM-Tree (v0.1.8) | +| :------------------------- | --------------: | ---------------: | ---------------: | ---------------: | +| Tree Flatten | x1.00 | 2.33 | 22.05 | 1.12 | +| Tree UnFlatten | x1.00 | 2.69 | 4.28 | 16.23 | +| Tree Flatten with Path | x1.00 | 16.16 | Not Supported | 27.59 | +| Tree Copy | x1.00 | 2.56 | 9.97 | 11.02 | +| Tree Map | x1.00 | 2.56 | 9.58 | 10.62 | +| Tree Map (nargs) | x1.00 | 2.89 | Not Supported | 31.33 | +| Tree Map with Path | x1.00 | 7.23 | Not Supported | 19.66 | +| Tree Map with Path (nargs) | x1.00 | 6.56 | Not Supported | 29.61 | + All results are reported on a workstation with an AMD Ryzen 9 5950X CPU @ 4.45GHz in an isolated virtual environment with Python 3.10.9. Run with the following commands: @@ -586,152 +597,152 @@ Please refer to [`benchmark.py`](https://github.com/metaopt/optree/blob/HEAD/ben | Module | Nodes | OpTree (μs) | JAX XLA (μs) | PyTorch (μs) | DM-Tree (μs) | Speedup (J / O) | Speedup (P / O) | Speedup (D / O) | | :-------- | ----: | ----------: | -----------: | -----------: | -----------: | --------------: | --------------: | --------------: | -| TinyMLP | 53 | 30.18 | 68.69 | 577.79 | 31.55 | 2.28 | 19.15 | 1.05 | -| AlexNet | 188 | 97.68 | 242.74 | 2102.67 | 118.22 | 2.49 | 21.53 | 1.21 | -| ResNet18 | 698 | 346.24 | 787.04 | 7769.05 | 407.51 | 2.27 | 22.44 | 1.18 | -| ResNet34 | 1242 | 663.70 | 1431.62 | 13989.08 | 712.72 | 2.16 | 21.08 | 1.07 | -| ResNet50 | 1702 | 882.40 | 1906.07 | 19243.43 | 966.05 | 2.16 | 21.81 | 1.09 | -| ResNet101 | 3317 | 1847.35 | 3953.69 | 39870.71 | 2031.28 | 2.14 | 21.58 | 1.10 | -| ResNet152 | 4932 | 2678.84 | 5588.23 | 56023.10 | 2874.76 | 2.09 | 20.91 | 1.07 | -| ViT-H/14 | 3420 | 1947.77 | 4467.48 | 40057.71 | 2195.20 | 2.29 | 20.57 | 1.13 | -| Swin-B | 2881 | 1763.83 | 3985.11 | 35818.71 | 1968.06 | 2.26 | 20.31 | 1.12 | -| | | | | | **Average** | **2.24** | **21.04** | **1.11** | +| TinyMLP | 53 | 29.70 | 71.06 | 583.66 | 31.32 | 2.39 | 19.65 | 1.05 | +| AlexNet | 188 | 103.92 | 262.56 | 2304.36 | 119.61 | 2.53 | 22.17 | 1.15 | +| ResNet18 | 698 | 368.06 | 852.69 | 8440.31 | 420.43 | 2.32 | 22.93 | 1.14 | +| ResNet34 | 1242 | 644.96 | 1461.55 | 14498.81 | 712.81 | 2.27 | 22.48 | 1.11 | +| ResNet50 | 1702 | 919.95 | 2080.58 | 20995.96 | 1006.42 | 2.26 | 22.82 | 1.09 | +| ResNet101 | 3317 | 1806.36 | 3996.90 | 40314.12 | 1955.48 | 2.21 | 22.32 | 1.08 | +| ResNet152 | 4932 | 2656.92 | 5812.38 | 57775.53 | 2826.92 | 2.19 | 21.75 | 1.06 | +| ViT-H/14 | 3420 | 1863.50 | 4418.24 | 41334.64 | 2128.71 | 2.37 | 22.18 | 1.14 | +| Swin-B | 2881 | 1631.06 | 3944.13 | 36131.54 | 2032.77 | 2.42 | 22.15 | 1.25 | +| | | | | | **Average** | **2.33** | **22.05** | **1.12** |
- +
### Tree UnFlatten | Module | Nodes | OpTree (μs) | JAX XLA (μs) | PyTorch (μs) | DM-Tree (μs) | Speedup (J / O) | Speedup (P / O) | Speedup (D / O) | | :-------- | ----: | ----------: | -----------: | -----------: | -----------: | --------------: | --------------: | --------------: | -| TinyMLP | 53 | 54.91 | 134.31 | 234.37 | 913.43 | 2.45 | 4.27 | 16.64 | -| AlexNet | 188 | 210.76 | 565.19 | 929.61 | 3808.84 | 2.68 | 4.41 | 18.07 | -| ResNet18 | 698 | 703.22 | 1727.14 | 3184.54 | 11643.08 | 2.46 | 4.53 | 16.56 | -| ResNet34 | 1242 | 1312.01 | 3147.73 | 5762.68 | 20852.70 | 2.40 | 4.39 | 15.89 | -| ResNet50 | 1702 | 1758.62 | 4177.30 | 7891.72 | 27874.16 | 2.38 | 4.49 | 15.85 | -| ResNet101 | 3317 | 3753.81 | 8226.49 | 15362.37 | 53974.51 | 2.19 | 4.09 | 14.38 | -| ResNet152 | 4932 | 5313.30 | 12205.85 | 24068.88 | 80256.68 | 2.30 | 4.53 | 15.10 | -| ViT-H/14 | 3420 | 3994.53 | 10016.00 | 17411.04 | 66000.54 | 2.51 | 4.36 | 16.52 | -| Swin-B | 2881 | 3584.82 | 8940.27 | 15582.13 | 56003.34 | 2.49 | 4.35 | 15.62 | -| | | | | | **Average** | **2.42** | **4.38** | **16.07** | +| TinyMLP | 53 | 55.13 | 152.07 | 231.94 | 940.11 | 2.76 | 4.21 | 17.05 | +| AlexNet | 188 | 226.29 | 678.29 | 972.90 | 4195.04 | 3.00 | 4.30 | 18.54 | +| ResNet18 | 698 | 766.54 | 1953.26 | 3137.86 | 12049.88 | 2.55 | 4.09 | 15.72 | +| ResNet34 | 1242 | 1309.22 | 3526.12 | 5759.16 | 20966.75 | 2.69 | 4.40 | 16.01 | +| ResNet50 | 1702 | 1914.96 | 5002.83 | 8369.43 | 29597.10 | 2.61 | 4.37 | 15.46 | +| ResNet101 | 3317 | 3672.61 | 9633.29 | 15683.16 | 57240.20 | 2.62 | 4.27 | 15.59 | +| ResNet152 | 4932 | 5407.58 | 13970.88 | 23074.68 | 82072.54 | 2.58 | 4.27 | 15.18 | +| ViT-H/14 | 3420 | 4013.18 | 11146.31 | 17633.07 | 66723.58 | 2.78 | 4.39 | 16.63 | +| Swin-B | 2881 | 3595.34 | 9505.31 | 15054.88 | 57310.03 | 2.64 | 4.19 | 15.94 | +| | | | | | **Average** | **2.69** | **4.28** | **16.23** |
- +
### Tree Flatten with Path | Module | Nodes | OpTree (μs) | JAX XLA (μs) | PyTorch (μs) | DM-Tree (μs) | Speedup (J / O) | Speedup (P / O) | Speedup (D / O) | | :-------- | ----: | ----------: | -----------: | -----------: | -----------: | --------------: | --------------: | --------------: | -| TinyMLP | 53 | 35.55 | 522.00 | N/A | 915.30 | 14.68 | N/A | 25.75 | -| AlexNet | 188 | 113.89 | 2005.85 | N/A | 3503.25 | 17.61 | N/A | 30.76 | -| ResNet18 | 698 | 432.31 | 7052.73 | N/A | 12239.45 | 16.31 | N/A | 28.31 | -| ResNet34 | 1242 | 812.18 | 12657.12 | N/A | 21703.50 | 15.58 | N/A | 26.72 | -| ResNet50 | 1702 | 1105.15 | 17173.43 | N/A | 29293.02 | 15.54 | N/A | 26.51 | -| ResNet101 | 3317 | 2182.68 | 33455.81 | N/A | 56810.61 | 15.33 | N/A | 26.03 | -| ResNet152 | 4932 | 3272.97 | 49550.72 | N/A | 84535.23 | 15.14 | N/A | 25.83 | -| ViT-H/14 | 3420 | 2287.13 | 37485.28 | N/A | 63024.61 | 16.39 | N/A | 27.56 | -| Swin-B | 2881 | 2092.16 | 33942.08 | N/A | 52744.88 | 16.22 | N/A | 25.21 | -| | | | | | **Average** | **15.87** | N/A | **26.96** | +| TinyMLP | 53 | 36.49 | 543.67 | N/A | 919.13 | 14.90 | N/A | 25.19 | +| AlexNet | 188 | 115.44 | 2185.21 | N/A | 3752.11 | 18.93 | N/A | 32.50 | +| ResNet18 | 698 | 431.84 | 7106.55 | N/A | 12286.70 | 16.46 | N/A | 28.45 | +| ResNet34 | 1242 | 845.61 | 13431.99 | N/A | 22860.48 | 15.88 | N/A | 27.03 | +| ResNet50 | 1702 | 1166.27 | 18426.52 | N/A | 31225.05 | 15.80 | N/A | 26.77 | +| ResNet101 | 3317 | 2312.77 | 34770.49 | N/A | 59346.86 | 15.03 | N/A | 25.66 | +| ResNet152 | 4932 | 3304.74 | 50557.25 | N/A | 85847.91 | 15.30 | N/A | 25.98 | +| ViT-H/14 | 3420 | 2235.25 | 37473.53 | N/A | 64105.24 | 16.76 | N/A | 28.68 | +| Swin-B | 2881 | 1970.25 | 32205.83 | N/A | 55177.50 | 16.35 | N/A | 28.01 | +| | | | | | **Average** | **16.16** | N/A | **27.59** |
- +
### Tree Copy | Module | Nodes | OpTree (μs) | JAX XLA (μs) | PyTorch (μs) | DM-Tree (μs) | Speedup (J / O) | Speedup (P / O) | Speedup (D / O) | | :-------- | ----: | ----------: | -----------: | -----------: | -----------: | --------------: | --------------: | --------------: | -| TinyMLP | 53 | 91.28 | 211.31 | 833.84 | 952.72 | 2.31 | 9.13 | 10.44 | -| AlexNet | 188 | 320.20 | 825.66 | 3118.08 | 3938.46 | 2.58 | 9.74 | 12.30 | -| ResNet18 | 698 | 1113.83 | 2578.31 | 11325.44 | 12068.00 | 2.31 | 10.17 | 10.83 | -| ResNet34 | 1242 | 2050.00 | 4836.56 | 20324.52 | 22749.73 | 2.36 | 9.91 | 11.10 | -| ResNet50 | 1702 | 2897.93 | 6121.16 | 27563.39 | 28840.04 | 2.11 | 9.51 | 9.95 | -| ResNet101 | 3317 | 5456.58 | 12306.26 | 53733.62 | 56140.12 | 2.26 | 9.85 | 10.29 | -| ResNet152 | 4932 | 8044.33 | 18873.23 | 79896.03 | 83215.06 | 2.35 | 9.93 | 10.34 | -| ViT-H/14 | 3420 | 6046.78 | 14451.37 | 58204.01 | 70966.61 | 2.39 | 9.63 | 11.74 | -| Swin-B | 2881 | 5173.48 | 13174.36 | 51701.17 | 60053.21 | 2.55 | 9.99 | 11.61 | -| | | | | | **Average** | **2.36** | **9.76** | **10.96** | +| TinyMLP | 53 | 89.81 | 232.26 | 845.20 | 981.48 | 2.59 | 9.41 | 10.93 | +| AlexNet | 188 | 334.58 | 959.32 | 3360.46 | 4316.05 | 2.87 | 10.04 | 12.90 | +| ResNet18 | 698 | 1128.11 | 2840.71 | 11471.07 | 12297.07 | 2.52 | 10.17 | 10.90 | +| ResNet34 | 1242 | 2160.57 | 5333.10 | 20563.06 | 21901.91 | 2.47 | 9.52 | 10.14 | +| ResNet50 | 1702 | 2746.84 | 6823.88 | 29705.99 | 28927.88 | 2.48 | 10.81 | 10.53 | +| ResNet101 | 3317 | 5762.05 | 13481.45 | 56968.78 | 60115.93 | 2.34 | 9.89 | 10.43 | +| ResNet152 | 4932 | 8151.21 | 20805.61 | 81024.06 | 84079.57 | 2.55 | 9.94 | 10.31 | +| ViT-H/14 | 3420 | 5963.61 | 15665.91 | 59813.52 | 68377.82 | 2.63 | 10.03 | 11.47 | +| Swin-B | 2881 | 5401.59 | 14255.33 | 53361.77 | 62317.07 | 2.64 | 9.88 | 11.54 | +| | | | | | **Average** | **2.56** | **9.97** | **11.02** |
- +
### Tree Map | Module | Nodes | OpTree (μs) | JAX XLA (μs) | PyTorch (μs) | DM-Tree (μs) | Speedup (J / O) | Speedup (P / O) | Speedup (D / O) | | :-------- | ----: | ----------: | -----------: | -----------: | -----------: | --------------: | --------------: | --------------: | -| TinyMLP | 53 | 99.25 | 229.26 | 848.37 | 968.28 | 2.31 | 8.55 | 9.76 | -| AlexNet | 188 | 332.74 | 853.33 | 3142.39 | 3992.22 | 2.56 | 9.44 | 12.00 | -| ResNet18 | 698 | 1190.94 | 2760.28 | 11399.95 | 12294.02 | 2.32 | 9.57 | 10.32 | -| ResNet34 | 1242 | 2286.53 | 4925.70 | 20423.57 | 23204.74 | 2.15 | 8.93 | 10.15 | -| ResNet50 | 1702 | 2968.51 | 6622.94 | 27807.01 | 29259.40 | 2.23 | 9.37 | 9.86 | -| ResNet101 | 3317 | 5851.06 | 13132.59 | 53999.13 | 57251.12 | 2.24 | 9.23 | 9.78 | -| ResNet152 | 4932 | 8682.55 | 19346.59 | 80462.95 | 84364.39 | 2.23 | 9.27 | 9.72 | -| ViT-H/14 | 3420 | 6695.68 | 16045.45 | 58313.07 | 68415.82 | 2.40 | 8.71 | 10.22 | -| Swin-B | 2881 | 5747.50 | 13757.05 | 52229.81 | 61017.78 | 2.39 | 9.09 | 10.62 | -| | | | | | **Average** | **2.32** | **9.13** | **10.27** | +| TinyMLP | 53 | 95.13 | 243.86 | 867.34 | 1026.99 | 2.56 | 9.12 | 10.80 | +| AlexNet | 188 | 348.44 | 987.57 | 3398.32 | 4354.81 | 2.83 | 9.75 | 12.50 | +| ResNet18 | 698 | 1190.62 | 2982.66 | 11719.94 | 12559.01 | 2.51 | 9.84 | 10.55 | +| ResNet34 | 1242 | 2205.87 | 5417.60 | 20935.72 | 22308.51 | 2.46 | 9.49 | 10.11 | +| ResNet50 | 1702 | 3128.48 | 7579.55 | 30372.71 | 31638.67 | 2.42 | 9.71 | 10.11 | +| ResNet101 | 3317 | 6173.05 | 14846.57 | 59167.85 | 60245.42 | 2.41 | 9.58 | 9.76 | +| ResNet152 | 4932 | 8641.22 | 22000.74 | 84018.65 | 86182.21 | 2.55 | 9.72 | 9.97 | +| ViT-H/14 | 3420 | 6211.79 | 17077.49 | 59790.25 | 69763.86 | 2.75 | 9.63 | 11.23 | +| Swin-B | 2881 | 5673.66 | 14339.69 | 53309.17 | 59764.61 | 2.53 | 9.40 | 10.53 | +| | | | | | **Average** | **2.56** | **9.58** | **10.62** |
- +
### Tree Map (nargs) | Module | Nodes | OpTree (μs) | JAX XLA (μs) | PyTorch (μs) | DM-Tree (μs) | Speedup (J / O) | Speedup (P / O) | Speedup (D / O) | | :-------- | ----: | ----------: | -----------: | -----------: | -----------: | --------------: | --------------: | --------------: | -| TinyMLP | 53 | 133.87 | 344.99 | N/A | 3599.07 | 2.58 | N/A | 26.89 | -| AlexNet | 188 | 445.41 | 1310.77 | N/A | 14207.10 | 2.94 | N/A | 31.90 | -| ResNet18 | 698 | 1599.16 | 4239.56 | N/A | 49255.49 | 2.65 | N/A | 30.80 | -| ResNet34 | 1242 | 3066.14 | 8115.79 | N/A | 88568.31 | 2.65 | N/A | 28.89 | -| ResNet50 | 1702 | 3951.48 | 10557.52 | N/A | 127232.92 | 2.67 | N/A | 32.20 | -| ResNet101 | 3317 | 7801.80 | 20208.53 | N/A | 235961.43 | 2.59 | N/A | 30.24 | -| ResNet152 | 4932 | 11489.21 | 29375.98 | N/A | 349007.54 | 2.56 | N/A | 30.38 | -| ViT-H/14 | 3420 | 8319.66 | 23204.11 | N/A | 266190.21 | 2.79 | N/A | 32.00 | -| Swin-B | 2881 | 7259.47 | 20098.17 | N/A | 226166.17 | 2.77 | N/A | 31.15 | -| | | | | | **Average** | **2.69** | N/A | **30.49** | +| TinyMLP | 53 | 137.06 | 389.96 | N/A | 3908.77 | 2.85 | N/A | 28.52 | +| AlexNet | 188 | 467.24 | 1496.96 | N/A | 15395.13 | 3.20 | N/A | 32.95 | +| ResNet18 | 698 | 1603.79 | 4534.01 | N/A | 50323.76 | 2.83 | N/A | 31.38 | +| ResNet34 | 1242 | 2907.64 | 8435.33 | N/A | 90389.23 | 2.90 | N/A | 31.09 | +| ResNet50 | 1702 | 4183.77 | 11382.51 | N/A | 121777.01 | 2.72 | N/A | 29.11 | +| ResNet101 | 3317 | 7721.13 | 22247.85 | N/A | 238755.17 | 2.88 | N/A | 30.92 | +| ResNet152 | 4932 | 11508.05 | 31429.39 | N/A | 360257.74 | 2.73 | N/A | 31.30 | +| ViT-H/14 | 3420 | 8294.20 | 24524.86 | N/A | 270514.87 | 2.96 | N/A | 32.61 | +| Swin-B | 2881 | 7074.62 | 20854.80 | N/A | 241120.41 | 2.95 | N/A | 34.08 | +| | | | | | **Average** | **2.89** | N/A | **31.33** |
- +
### Tree Map with Path | Module | Nodes | OpTree (μs) | JAX XLA (μs) | PyTorch (μs) | DM-Tree (μs) | Speedup (J / O) | Speedup (P / O) | Speedup (D / O) | | :-------- | ----: | ----------: | -----------: | -----------: | -----------: | --------------: | --------------: | --------------: | -| TinyMLP | 53 | 104.70 | 703.83 | N/A | 1998.00 | 6.72 | N/A | 19.08 | -| AlexNet | 188 | 352.30 | 2668.73 | N/A | 7681.19 | 7.58 | N/A | 21.80 | -| ResNet18 | 698 | 1289.51 | 9342.79 | N/A | 25497.31 | 7.25 | N/A | 19.77 | -| ResNet34 | 1242 | 2366.46 | 16746.52 | N/A | 45254.59 | 7.08 | N/A | 19.12 | -| ResNet50 | 1702 | 3399.11 | 23574.46 | N/A | 60494.27 | 6.94 | N/A | 17.80 | -| ResNet101 | 3317 | 6329.82 | 43955.95 | N/A | 118725.60 | 6.94 | N/A | 18.76 | -| ResNet152 | 4932 | 9307.87 | 64777.45 | N/A | 174764.97 | 6.96 | N/A | 18.78 | -| ViT-H/14 | 3420 | 6705.10 | 48862.92 | N/A | 139617.21 | 7.29 | N/A | 20.82 | -| Swin-B | 2881 | 5780.20 | 41703.04 | N/A | 115003.61 | 7.21 | N/A | 19.90 | -| | | | | | **Average** | **7.11** | N/A | **19.54** | +| TinyMLP | 53 | 109.82 | 778.30 | N/A | 2186.40 | 7.09 | N/A | 19.91 | +| AlexNet | 188 | 365.16 | 2939.36 | N/A | 8355.37 | 8.05 | N/A | 22.88 | +| ResNet18 | 698 | 1308.26 | 9529.58 | N/A | 25758.24 | 7.28 | N/A | 19.69 | +| ResNet34 | 1242 | 2527.21 | 18084.89 | N/A | 45942.32 | 7.16 | N/A | 18.18 | +| ResNet50 | 1702 | 3226.03 | 22935.53 | N/A | 61275.34 | 7.11 | N/A | 18.99 | +| ResNet101 | 3317 | 6663.52 | 46878.89 | N/A | 126642.14 | 7.04 | N/A | 19.01 | +| ResNet152 | 4932 | 9378.19 | 66136.44 | N/A | 176981.01 | 7.05 | N/A | 18.87 | +| ViT-H/14 | 3420 | 7033.69 | 50418.37 | N/A | 142508.11 | 7.17 | N/A | 20.26 | +| Swin-B | 2881 | 6078.15 | 43173.22 | N/A | 116612.71 | 7.10 | N/A | 19.19 | +| | | | | | **Average** | **7.23** | N/A | **19.66** |
- +
### Tree Map with Path (nargs) | Module | Nodes | OpTree (μs) | JAX XLA (μs) | PyTorch (μs) | DM-Tree (μs) | Speedup (J / O) | Speedup (P / O) | Speedup (D / O) | | :-------- | ----: | ----------: | -----------: | -----------: | -----------: | --------------: | --------------: | --------------: | -| TinyMLP | 53 | 138.19 | 828.53 | N/A | 3599.32 | 6.00 | N/A | 26.05 | -| AlexNet | 188 | 461.91 | 3138.59 | N/A | 14069.17 | 6.79 | N/A | 30.46 | -| ResNet18 | 698 | 1702.79 | 10890.25 | N/A | 49456.32 | 6.40 | N/A | 29.04 | -| ResNet34 | 1242 | 3115.89 | 19356.46 | N/A | 88955.96 | 6.21 | N/A | 28.55 | -| ResNet50 | 1702 | 4422.25 | 26205.69 | N/A | 121569.30 | 5.93 | N/A | 27.49 | -| ResNet101 | 3317 | 8334.83 | 50909.37 | N/A | 241862.38 | 6.11 | N/A | 29.02 | -| ResNet152 | 4932 | 12208.52 | 75327.94 | N/A | 351472.89 | 6.17 | N/A | 28.79 | -| ViT-H/14 | 3420 | 9320.75 | 56869.65 | N/A | 266430.78 | 6.10 | N/A | 28.58 | -| Swin-B | 2881 | 7472.11 | 49260.03 | N/A | 233154.60 | 6.59 | N/A | 31.20 | -| | | | | | **Average** | **6.26** | N/A | **28.80** | +| TinyMLP | 53 | 146.05 | 917.00 | N/A | 3940.61 | 6.28 | N/A | 26.98 | +| AlexNet | 188 | 489.27 | 3560.76 | N/A | 15434.71 | 7.28 | N/A | 31.55 | +| ResNet18 | 698 | 1712.79 | 11171.44 | N/A | 50219.86 | 6.52 | N/A | 29.32 | +| ResNet34 | 1242 | 3112.83 | 21024.58 | N/A | 95505.71 | 6.75 | N/A | 30.68 | +| ResNet50 | 1702 | 4220.70 | 26600.82 | N/A | 121897.57 | 6.30 | N/A | 28.88 | +| ResNet101 | 3317 | 8631.34 | 54372.37 | N/A | 236555.54 | 6.30 | N/A | 27.41 | +| ResNet152 | 4932 | 12710.49 | 77643.13 | N/A | 353600.32 | 6.11 | N/A | 27.82 | +| ViT-H/14 | 3420 | 8753.09 | 58712.71 | N/A | 286365.36 | 6.71 | N/A | 32.72 | +| Swin-B | 2881 | 7359.29 | 50112.23 | N/A | 228866.66 | 6.81 | N/A | 31.10 | +| | | | | | **Average** | **6.56** | N/A | **29.61** |
- +
-------------------------------------------------------------------------------- diff --git a/benchmark.py b/benchmark.py index 733cb7dc..8e1116f7 100755 --- a/benchmark.py +++ b/benchmark.py @@ -15,15 +15,18 @@ # limitations under the License. # ============================================================================== -# pylint: disable=missing-module-docstring,missing-function-docstring,invalid-name +# pylint: disable=missing-module-docstring,missing-class-docstring,missing-function-docstring +# pylint: disable=invalid-name from __future__ import annotations import argparse import sys -from collections import OrderedDict, namedtuple +import textwrap +import timeit +from collections import OrderedDict from itertools import count -from typing import Any, Iterable +from typing import Any, Iterable, NamedTuple, Sequence import jax import pandas as pd @@ -52,56 +55,6 @@ def colored( # pylint: disable=function-redefined,unused-argument return text -torch_utils_pytree._register_pytree_node( # pylint: disable=protected-access - dict, - lambda d: tuple(zip(*sorted(d.items())))[::-1] if d else ((), ()), - lambda values, keys: dict(zip(keys, values)), -) - - -Tensors = namedtuple('Tensors', ['parameters', 'buffers']) -Containers = namedtuple('Containers', ['parameters', 'buffers']) - - -def get_none() -> None: - return None - - -def tiny_mlp() -> nn.Module: - return nn.Sequential( - nn.Linear(1, 1, bias=True), - nn.BatchNorm1d(1, affine=True, track_running_stats=True), - nn.ReLU(), - nn.Linear(1, 1, bias=False), - nn.Sigmoid(), - ) - - -def extract(module: nn.Module, unordered: bool) -> Any: - if isinstance(module, (nn.Sequential, nn.ModuleList)): - return [extract(submodule, unordered=unordered) for submodule in module] - - dict_factory = dict if unordered else OrderedDict - extracted = dict_factory( - [ - (name, extract(submodule, unordered=unordered)) - for name, submodule in module.named_children() - ], - ) - extracted.update( - tensors=Tensors( - parameters=tuple(t.data for t in module.parameters(recurse=False)), - buffers=[t.data for t in module.buffers(recurse=False)], - ), - containers=Containers( - parameters=dict_factory(module._parameters), # pylint: disable=protected-access - buffers=dict_factory(module._buffers), # pylint: disable=protected-access - ), - ) - - return extracted - - cmark = '✔' # pylint: disable=invalid-name xmark = '✘' # pylint: disable=invalid-name tie = '~' # pylint: disable=invalid-name @@ -128,198 +81,320 @@ def fnp3(p, x, y, z): return None """ -STMTS = OrderedDict( + +class BenchmarkCase(NamedTuple): + name: str + stmt: str + init_stmt: str = '' + + def timeit( + self, + number: int, + repeat: int = 5, + globals: dict[str, Any] | None = None, # pylint: disable=redefined-builtin + ) -> float: + init_warmup = ( + INIT.strip() + + '\n\n' + + textwrap.dedent( + f""" + {self.init_stmt.strip()} + + for _ in range({number} // 10): + {self.stmt.strip()} + """, + ).strip() + ) + + times = timeit.repeat( + self.stmt.strip(), + setup=init_warmup, + number=number, + repeat=repeat, + globals=globals, + ) + return min(times) / number + + +BENCHMARK_CASES: dict[str, Sequence[BenchmarkCase]] = OrderedDict( [ ( 'Tree Flatten', - OrderedDict( - [ - ('OpTree(default)', ('optree.tree_leaves(x)', '')), - ('OpTree(NoneIsNode)', ('optree.tree_leaves(x, none_is_leaf=False)', '')), - ('OpTree(NoneIsLeaf)', ('optree.tree_leaves(x, none_is_leaf=True)', '')), - ('JAX XLA', ('jax.tree_util.tree_leaves(x)', '')), - ('PyTorch', ('torch_utils_pytree.tree_flatten(x)[0]', '')), - ('DM-Tree', ('dm_tree.flatten(x)', '')), - ], - ), + [ + BenchmarkCase( + name='OpTree(default)', + stmt='optree.tree_leaves(x)', + ), + BenchmarkCase( + name='OpTree(NoneIsNode)', + stmt='optree.tree_leaves(x, none_is_leaf=False)', + ), + BenchmarkCase( + name='OpTree(NoneIsLeaf)', + stmt='optree.tree_leaves(x, none_is_leaf=True)', + ), + BenchmarkCase( + name='JAX XLA', + stmt='jax.tree_util.tree_leaves(x)', + ), + BenchmarkCase( + name='PyTorch', + stmt='torch_utils_pytree.tree_flatten(x)[0]', + ), + BenchmarkCase( + name='DM-Tree', + stmt='dm_tree.flatten(x)', + ), + ], ), ( 'Tree UnFlatten', - OrderedDict( - [ - ( - 'OpTree(NoneIsNode)', # default - ( - 'optree.tree_unflatten(spec, flat)', - 'flat, spec = optree.tree_flatten(x, none_is_leaf=False)', - ), - ), - ( - 'OpTree(NoneIsLeaf)', - ( - 'optree.tree_unflatten(spec, flat)', - 'flat, spec = optree.tree_flatten(x, none_is_leaf=True)', - ), - ), - ( - 'JAX XLA', - ( - 'jax.tree_util.tree_unflatten(spec, flat)', - 'flat, spec = jax.tree_util.tree_flatten(x)', - ), - ), - ( - 'PyTorch', - ( - 'torch_utils_pytree.tree_unflatten(flat, spec)', - 'flat, spec = torch_utils_pytree.tree_flatten(x)', - ), - ), - ( - 'DM-Tree', - ( - 'dm_tree.unflatten_as(spec, flat)', - 'flat, spec = dm_tree.flatten(x), x', - ), - ), - ], - ), + [ + BenchmarkCase( + name='OpTree(NoneIsNode)', # default + stmt='optree.tree_unflatten(spec, flat)', + init_stmt='flat, spec = optree.tree_flatten(x, none_is_leaf=False)', + ), + BenchmarkCase( + name='OpTree(NoneIsLeaf)', + stmt='optree.tree_unflatten(spec, flat)', + init_stmt='flat, spec = optree.tree_flatten(x, none_is_leaf=True)', + ), + BenchmarkCase( + name='JAX XLA', + stmt='jax.tree_util.tree_unflatten(spec, flat)', + init_stmt='flat, spec = jax.tree_util.tree_flatten(x)', + ), + BenchmarkCase( + name='PyTorch', + stmt='torch_utils_pytree.tree_unflatten(flat, spec)', + init_stmt='flat, spec = torch_utils_pytree.tree_flatten(x)', + ), + BenchmarkCase( + name='DM-Tree', + stmt='dm_tree.unflatten_as(spec, flat)', + init_stmt='flat, spec = dm_tree.flatten(x), x', + ), + ], ), ( 'Tree Flatten with Path', - OrderedDict( - [ - ('OpTree(default)', ('optree.tree_flatten_with_path(x)', '')), - ( - 'OpTree(NoneIsNode)', - ('optree.tree_flatten_with_path(x, none_is_leaf=False)', ''), - ), - ( - 'OpTree(NoneIsLeaf)', - ('optree.tree_flatten_with_path(x, none_is_leaf=True)', ''), - ), - ('JAX XLA', ('jax.tree_util.tree_flatten_with_path(x)', '')), - ('DM-Tree', ('dm_tree.flatten_with_path(x)', '')), - ], - ), + [ + BenchmarkCase( + name='OpTree(default)', + stmt='optree.tree_flatten_with_path(x)', + ), + BenchmarkCase( + name='OpTree(NoneIsNode)', + stmt='optree.tree_flatten_with_path(x, none_is_leaf=False)', + ), + BenchmarkCase( + name='OpTree(NoneIsLeaf)', + stmt='optree.tree_flatten_with_path(x, none_is_leaf=True)', + ), + BenchmarkCase( + name='JAX XLA', + stmt='jax.tree_util.tree_flatten_with_path(x)', + ), + BenchmarkCase( + name='DM-Tree', + stmt='dm_tree.flatten_with_path(x)', + ), + ], ), ( 'Tree Copy', - OrderedDict( - [ - ( - 'OpTree(default)', - ('optree.tree_unflatten(*optree.tree_flatten(x)[::-1])', ''), - ), - ( - 'OpTree(NoneIsNode)', - ( - 'optree.tree_unflatten(*optree.tree_flatten(x, none_is_leaf=False)[::-1])', - '', - ), - ), - ( - 'OpTree(NoneIsLeaf)', - ( - 'optree.tree_unflatten(*optree.tree_flatten(x, none_is_leaf=True)[::-1])', - '', - ), - ), - ( - 'JAX XLA', - ('jax.tree_util.tree_unflatten(*jax.tree_util.tree_flatten(x)[::-1])', ''), - ), - ( - 'PyTorch', - ( - 'torch_utils_pytree.tree_unflatten(*torch_utils_pytree.tree_flatten(x))', - '', - ), - ), - ( - 'DM-Tree', - ('dm_tree.unflatten_as(x, dm_tree.flatten(x))', ''), - ), - ], - ), + [ + BenchmarkCase( + name='OpTree(default)', + stmt='optree.tree_unflatten(*optree.tree_flatten(x)[::-1])', + ), + BenchmarkCase( + name='OpTree(NoneIsNode)', + stmt='optree.tree_unflatten(*optree.tree_flatten(x, none_is_leaf=False)[::-1])', + ), + BenchmarkCase( + name='OpTree(NoneIsLeaf)', + stmt='optree.tree_unflatten(*optree.tree_flatten(x, none_is_leaf=True)[::-1])', + ), + BenchmarkCase( + name='JAX XLA', + stmt='jax.tree_util.tree_unflatten(*jax.tree_util.tree_flatten(x)[::-1])', + ), + BenchmarkCase( + name='PyTorch', + stmt='torch_utils_pytree.tree_unflatten(*torch_utils_pytree.tree_flatten(x))', + ), + BenchmarkCase( + name='DM-Tree', + stmt='dm_tree.unflatten_as(x, dm_tree.flatten(x))', + ), + ], ), ( 'Tree Map', - OrderedDict( - [ - ('OpTree(default)', ('optree.tree_map(fn1, x)', '')), - ('OpTree(NoneIsNode)', ('optree.tree_map(fn1, x, none_is_leaf=False)', '')), - ('OpTree(NoneIsLeaf)', ('optree.tree_map(fn1, x, none_is_leaf=True)', '')), - ('JAX XLA', ('jax.tree_util.tree_map(fn1, x)', '')), - ('PyTorch', ('torch_utils_pytree.tree_map(fn1, x)', '')), - ('DM-Tree', ('dm_tree.map_structure(fn1, x)', '')), - ], - ), + [ + BenchmarkCase( + name='OpTree(default)', + stmt='optree.tree_map(fn1, x)', + ), + BenchmarkCase( + name='OpTree(NoneIsNode)', + stmt='optree.tree_map(fn1, x, none_is_leaf=False)', + ), + BenchmarkCase( + name='OpTree(NoneIsLeaf)', + stmt='optree.tree_map(fn1, x, none_is_leaf=True)', + ), + BenchmarkCase( + name='JAX XLA', + stmt='jax.tree_util.tree_map(fn1, x)', + ), + BenchmarkCase( + name='PyTorch', + stmt='torch_utils_pytree.tree_map(fn1, x)', + ), + BenchmarkCase( + name='DM-Tree', + stmt='dm_tree.map_structure(fn1, x)', + ), + ], ), ( 'Tree Map (nargs)', - OrderedDict( - [ - ('OpTree(default)', ('optree.tree_map(fn3, x, y, z)', '')), - ( - 'OpTree(NoneIsNode)', - ('optree.tree_map(fn3, x, y, z, none_is_leaf=False)', ''), - ), - ( - 'OpTree(NoneIsLeaf)', - ('optree.tree_map(fn3, x, y, z, none_is_leaf=True)', ''), - ), - ('JAX XLA', ('jax.tree_util.tree_map(fn3, x, y, z)', '')), - ('DM-Tree', ('dm_tree.map_structure_up_to(x, fn3, x, y, z)', '')), - ], - ), + [ + BenchmarkCase( + name='OpTree(default)', + stmt='optree.tree_map(fn3, x, y, z)', + ), + BenchmarkCase( + name='OpTree(NoneIsNode)', + stmt='optree.tree_map(fn3, x, y, z, none_is_leaf=False)', + ), + BenchmarkCase( + name='OpTree(NoneIsLeaf)', + stmt='optree.tree_map(fn3, x, y, z, none_is_leaf=True)', + ), + BenchmarkCase( + name='JAX XLA', + stmt='jax.tree_util.tree_map(fn3, x, y, z)', + ), + BenchmarkCase( + name='DM-Tree', + stmt='dm_tree.map_structure_up_to(x, fn3, x, y, z)', + ), + ], ), ( 'Tree Map with Path', - OrderedDict( - [ - ('OpTree(default)', ('optree.tree_map_with_path(fnp1, x)', '')), - ( - 'OpTree(NoneIsNode)', - ('optree.tree_map_with_path(fnp1, x, none_is_leaf=False)', ''), - ), - ( - 'OpTree(NoneIsLeaf)', - ('optree.tree_map_with_path(fnp1, x, none_is_leaf=True)', ''), - ), - ('JAX XLA', ('jax.tree_util.tree_map_with_path(fnp1, x)', '')), - ('DM-Tree', ('dm_tree.map_structure_with_path(fnp1, x)', '')), - ], - ), + [ + BenchmarkCase( + name='OpTree(default)', + stmt='optree.tree_map_with_path(fnp1, x)', + ), + BenchmarkCase( + name='OpTree(NoneIsNode)', + stmt='optree.tree_map_with_path(fnp1, x, none_is_leaf=False)', + ), + BenchmarkCase( + name='OpTree(NoneIsLeaf)', + stmt='optree.tree_map_with_path(fnp1, x, none_is_leaf=True)', + ), + BenchmarkCase( + name='JAX XLA', + stmt='jax.tree_util.tree_map_with_path(fnp1, x)', + ), + BenchmarkCase( + name='DM-Tree', + stmt='dm_tree.map_structure_with_path(fnp1, x)', + ), + ], ), ( 'Tree Map with Path (nargs)', - OrderedDict( - [ - ('OpTree(default)', ('optree.tree_map_with_path(fnp3, x, y, z)', '')), - ( - 'OpTree(NoneIsNode)', - ( - 'optree.tree_map_with_path(fnp3, x, y, z, none_is_leaf=False)', - '', - ), - ), - ( - 'OpTree(NoneIsLeaf)', - ( - 'optree.tree_map_with_path(fnp3, x, y, z, none_is_leaf=True)', - '', - ), - ), - ('JAX XLA', ('jax.tree_util.tree_map_with_path(fnp3, x, y, z)', '')), - ('DM-Tree', ('dm_tree.map_structure_with_path(fnp3, x, y, z)', '')), - ], - ), + [ + BenchmarkCase( + name='OpTree(default)', + stmt='optree.tree_map_with_path(fnp3, x, y, z)', + ), + BenchmarkCase( + name='OpTree(NoneIsNode)', + stmt='optree.tree_map_with_path(fnp3, x, y, z, none_is_leaf=False)', + ), + BenchmarkCase( + name='OpTree(NoneIsLeaf)', + stmt='optree.tree_map_with_path(fnp3, x, y, z, none_is_leaf=True)', + ), + BenchmarkCase( + name='JAX XLA', + stmt='jax.tree_util.tree_map_with_path(fnp3, x, y, z)', + ), + BenchmarkCase( + name='DM-Tree', + stmt='dm_tree.map_structure_with_path(fnp3, x, y, z)', + ), + ], ), ], ) +torch_utils_pytree._register_pytree_node( # pylint: disable=protected-access + dict, + lambda d: tuple(zip(*sorted(d.items())))[::-1] if d else ((), ()), + lambda values, keys: dict(zip(keys, values)), +) + + +class Tensors(NamedTuple): + parameters: tuple[torch.Tensor, ...] + buffers: list[torch.Tensor] + + +class Containers(NamedTuple): + parameters: dict[str, torch.Tensor | None] + buffers: dict[str, torch.Tensor | None] + + +def get_none() -> None: + return None + + +def tiny_mlp() -> nn.Module: + return nn.Sequential( + nn.Linear(1, 1, bias=True), + nn.BatchNorm1d(1, affine=True, track_running_stats=True), + nn.ReLU(), + nn.Linear(1, 1, bias=False), + nn.Sigmoid(), + ) + + +def extract(module: nn.Module, unordered: bool) -> Any: + if isinstance(module, (nn.Sequential, nn.ModuleList)): + return [extract(submodule, unordered=unordered) for submodule in module] + + dict_factory = dict if unordered else OrderedDict + extracted = dict_factory( + [ + (name, extract(submodule, unordered=unordered)) + for name, submodule in module.named_children() + ], + ) + extracted.update( + tensors=Tensors( + parameters=tuple(t.data for t in module.parameters(recurse=False)), + buffers=[t.data for t in module.buffers(recurse=False)], + ), + containers=Containers( + parameters=dict_factory(module._parameters), # pylint: disable=protected-access + buffers=dict_factory(module._buffers), # pylint: disable=protected-access + ), + ) + + return extracted + + def cprint(text: str = '') -> None: text = ( text.replace( @@ -439,58 +514,26 @@ def check(tree: Any) -> None: print(flush=True) -def timeit( - stmt: str, - init_stmt: str, - number: int, - repeat: int = 5, - globals: dict[str, Any] | None = None, # pylint: disable=redefined-builtin -) -> float: - import timeit # pylint: disable=redefined-outer-name,import-outside-toplevel - - globals = globals or {} - init_warmup = f""" -{INIT.strip()} - -{init_stmt.strip()} - -for _ in range({number} // 10): - {stmt.strip()} -""" - - times = timeit.repeat( - stmt.strip(), - setup=init_warmup, - globals=globals, - number=number, - repeat=repeat, - ) - return min(times) / number - - def compare( # pylint: disable=too-many-locals subject: str, - stmts: dict[str, tuple[str, str]], + cases: Sequence[BenchmarkCase], number: int, repeat: int = 5, globals: dict[str, Any] | None = None, # pylint: disable=redefined-builtin ) -> dict[str, float]: times_us = OrderedDict( - [ - (lib, 10e6 * timeit(stmt, init_stmt, number, repeat=repeat, globals=globals)) - for lib, (stmt, init_stmt) in stmts.items() - ], + [(case.name, 10e6 * case.timeit(number, repeat=repeat, globals=globals)) for case in cases], ) base_time = next(iter(times_us.values())) best_time = min(times_us.values()) - speedups = {lib: time / base_time for lib, time in times_us.items()} - best_speedups = {lib: time / best_time for lib, time in times_us.items()} + speedups = {name: time / base_time for name, time in times_us.items()} + best_speedups = {name: time / best_time for name, time in times_us.items()} labels = { - lib: (cmark if speedup == 1.0 else tie if speedup < 1.1 else ' ') - for lib, speedup in best_speedups.items() + name: (cmark if speedup == 1.0 else tie if speedup < 1.1 else ' ') + for name, speedup in best_speedups.items() } colors = { - lib: ( + name: ( 'green' if speedup == 1.0 else 'cyan' @@ -499,18 +542,18 @@ def compare( # pylint: disable=too-many-locals if speedup < 4.0 else 'red' ) - for lib, speedup in best_speedups.items() + for name, speedup in best_speedups.items() } - attrs = {lib: ('bold',) if color == 'green' else () for lib, color in colors.items()} + attrs = {name: ('bold',) if color == 'green' else () for name, color in colors.items()} cprint(f'### {subject} ###') - for lib, (stmt, _) in stmts.items(): - label = labels[lib] - color = colors[lib] - attr = attrs[lib] - label = colored(label + ' ' + lib, color=color, attrs=attr) - time = colored(f'{times_us[lib]:8.2f}μs', attrs=attr) - speedup = speedups[lib] + for name, stmt, _ in cases: + label = labels[name] + color = colors[name] + attr = attrs[name] + label = colored(label + ' ' + name, color=color, attrs=attr) + time = colored(f'{times_us[name]:8.2f}μs', attrs=attr) + speedup = speedups[name] if speedup != 1.0: speedup = ' -- ' + colored(f'x{speedup:.2f}'.ljust(6), color=color, attrs=attr) else: @@ -573,10 +616,10 @@ def benchmark( # pylint: disable=too-many-locals z = optree.tree_map(lambda t: (t, None), x) # pylint: disable=invalid-name check(x) - for subject, stmts in STMTS.items(): + for subject, cases in BENCHMARK_CASES.items(): times_us = compare( subject, - stmts, + cases, number=number, repeat=repeat, globals={'x': x, 'y': y, 'z': z}, diff --git a/conda-recipe.yaml b/conda-recipe.yaml index ca3d615e..07295596 100644 --- a/conda-recipe.yaml +++ b/conda-recipe.yaml @@ -39,7 +39,7 @@ dependencies: - pybind11 # Benchmark - - pytorch::pytorch >= 1.13, < 1.14.0a0 + - pytorch::pytorch >= 2.0, < 2.1.0a0 - pytorch::torchvision - pytorch::pytorch-mutex = *=*cpu* - jax >= 0.4.6, < 0.5.0a0 diff --git a/pyproject.toml b/pyproject.toml index 747706a0..b77b18b8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -70,7 +70,7 @@ lint = [ test = ['pytest', 'pytest-cov', 'pytest-xdist'] benchmark = [ 'jax[cpu] >= 0.4.6, < 0.5.0a0', - 'torch >= 1.13, < 1.14.0a0', + 'torch >= 2.0, < 2.1.0a0', 'torchvision', 'dm-tree >= 0.1, < 0.2.0a0', 'pandas',