diff --git a/README.md b/README.md
index 29b20f71..9dd2369b 100644
--- a/README.md
+++ b/README.md
@@ -551,11 +551,22 @@ We benchmark the performance of:
compared with the following libraries:
-- OpTree ([`@v0.8.0`](https://github.com/metaopt/optree/tree/v0.8.0))
+- OpTree ([`@v0.9.0`](https://github.com/metaopt/optree/tree/v0.9.0))
- JAX XLA ([`jax[cpu] == 0.4.6`](https://pypi.org/project/jax/0.4.6))
-- PyTorch ([`torch == 1.13.1`](https://pypi.org/project/torch/1.13.1))
+- PyTorch ([`torch == 2.0.0`](https://pypi.org/project/torch/2.0.0))
- DM-Tree ([`dm-tree == 0.1.8`](https://pypi.org/project/dm-tree/0.1.8))
+| Average Time Cost (↓) | OpTree (v0.9.0) | JAX XLA (v0.4.6) | PyTorch (v2.0.0) | DM-Tree (v0.1.8) |
+| :------------------------- | --------------: | ---------------: | ---------------: | ---------------: |
+| Tree Flatten | x1.00 | 2.33 | 22.05 | 1.12 |
+| Tree UnFlatten | x1.00 | 2.69 | 4.28 | 16.23 |
+| Tree Flatten with Path | x1.00 | 16.16 | Not Supported | 27.59 |
+| Tree Copy | x1.00 | 2.56 | 9.97 | 11.02 |
+| Tree Map | x1.00 | 2.56 | 9.58 | 10.62 |
+| Tree Map (nargs) | x1.00 | 2.89 | Not Supported | 31.33 |
+| Tree Map with Path | x1.00 | 7.23 | Not Supported | 19.66 |
+| Tree Map with Path (nargs) | x1.00 | 6.56 | Not Supported | 29.61 |
+
All results are reported on a workstation with an AMD Ryzen 9 5950X CPU @ 4.45GHz in an isolated virtual environment with Python 3.10.9.
Run with the following commands:
@@ -586,152 +597,152 @@ Please refer to [`benchmark.py`](https://github.com/metaopt/optree/blob/HEAD/ben
| Module | Nodes | OpTree (μs) | JAX XLA (μs) | PyTorch (μs) | DM-Tree (μs) | Speedup (J / O) | Speedup (P / O) | Speedup (D / O) |
| :-------- | ----: | ----------: | -----------: | -----------: | -----------: | --------------: | --------------: | --------------: |
-| TinyMLP | 53 | 30.18 | 68.69 | 577.79 | 31.55 | 2.28 | 19.15 | 1.05 |
-| AlexNet | 188 | 97.68 | 242.74 | 2102.67 | 118.22 | 2.49 | 21.53 | 1.21 |
-| ResNet18 | 698 | 346.24 | 787.04 | 7769.05 | 407.51 | 2.27 | 22.44 | 1.18 |
-| ResNet34 | 1242 | 663.70 | 1431.62 | 13989.08 | 712.72 | 2.16 | 21.08 | 1.07 |
-| ResNet50 | 1702 | 882.40 | 1906.07 | 19243.43 | 966.05 | 2.16 | 21.81 | 1.09 |
-| ResNet101 | 3317 | 1847.35 | 3953.69 | 39870.71 | 2031.28 | 2.14 | 21.58 | 1.10 |
-| ResNet152 | 4932 | 2678.84 | 5588.23 | 56023.10 | 2874.76 | 2.09 | 20.91 | 1.07 |
-| ViT-H/14 | 3420 | 1947.77 | 4467.48 | 40057.71 | 2195.20 | 2.29 | 20.57 | 1.13 |
-| Swin-B | 2881 | 1763.83 | 3985.11 | 35818.71 | 1968.06 | 2.26 | 20.31 | 1.12 |
-| | | | | | **Average** | **2.24** | **21.04** | **1.11** |
+| TinyMLP | 53 | 29.70 | 71.06 | 583.66 | 31.32 | 2.39 | 19.65 | 1.05 |
+| AlexNet | 188 | 103.92 | 262.56 | 2304.36 | 119.61 | 2.53 | 22.17 | 1.15 |
+| ResNet18 | 698 | 368.06 | 852.69 | 8440.31 | 420.43 | 2.32 | 22.93 | 1.14 |
+| ResNet34 | 1242 | 644.96 | 1461.55 | 14498.81 | 712.81 | 2.27 | 22.48 | 1.11 |
+| ResNet50 | 1702 | 919.95 | 2080.58 | 20995.96 | 1006.42 | 2.26 | 22.82 | 1.09 |
+| ResNet101 | 3317 | 1806.36 | 3996.90 | 40314.12 | 1955.48 | 2.21 | 22.32 | 1.08 |
+| ResNet152 | 4932 | 2656.92 | 5812.38 | 57775.53 | 2826.92 | 2.19 | 21.75 | 1.06 |
+| ViT-H/14 | 3420 | 1863.50 | 4418.24 | 41334.64 | 2128.71 | 2.37 | 22.18 | 1.14 |
+| Swin-B | 2881 | 1631.06 | 3944.13 | 36131.54 | 2032.77 | 2.42 | 22.15 | 1.25 |
+| | | | | | **Average** | **2.33** | **22.05** | **1.12** |
-
+
### Tree UnFlatten
| Module | Nodes | OpTree (μs) | JAX XLA (μs) | PyTorch (μs) | DM-Tree (μs) | Speedup (J / O) | Speedup (P / O) | Speedup (D / O) |
| :-------- | ----: | ----------: | -----------: | -----------: | -----------: | --------------: | --------------: | --------------: |
-| TinyMLP | 53 | 54.91 | 134.31 | 234.37 | 913.43 | 2.45 | 4.27 | 16.64 |
-| AlexNet | 188 | 210.76 | 565.19 | 929.61 | 3808.84 | 2.68 | 4.41 | 18.07 |
-| ResNet18 | 698 | 703.22 | 1727.14 | 3184.54 | 11643.08 | 2.46 | 4.53 | 16.56 |
-| ResNet34 | 1242 | 1312.01 | 3147.73 | 5762.68 | 20852.70 | 2.40 | 4.39 | 15.89 |
-| ResNet50 | 1702 | 1758.62 | 4177.30 | 7891.72 | 27874.16 | 2.38 | 4.49 | 15.85 |
-| ResNet101 | 3317 | 3753.81 | 8226.49 | 15362.37 | 53974.51 | 2.19 | 4.09 | 14.38 |
-| ResNet152 | 4932 | 5313.30 | 12205.85 | 24068.88 | 80256.68 | 2.30 | 4.53 | 15.10 |
-| ViT-H/14 | 3420 | 3994.53 | 10016.00 | 17411.04 | 66000.54 | 2.51 | 4.36 | 16.52 |
-| Swin-B | 2881 | 3584.82 | 8940.27 | 15582.13 | 56003.34 | 2.49 | 4.35 | 15.62 |
-| | | | | | **Average** | **2.42** | **4.38** | **16.07** |
+| TinyMLP | 53 | 55.13 | 152.07 | 231.94 | 940.11 | 2.76 | 4.21 | 17.05 |
+| AlexNet | 188 | 226.29 | 678.29 | 972.90 | 4195.04 | 3.00 | 4.30 | 18.54 |
+| ResNet18 | 698 | 766.54 | 1953.26 | 3137.86 | 12049.88 | 2.55 | 4.09 | 15.72 |
+| ResNet34 | 1242 | 1309.22 | 3526.12 | 5759.16 | 20966.75 | 2.69 | 4.40 | 16.01 |
+| ResNet50 | 1702 | 1914.96 | 5002.83 | 8369.43 | 29597.10 | 2.61 | 4.37 | 15.46 |
+| ResNet101 | 3317 | 3672.61 | 9633.29 | 15683.16 | 57240.20 | 2.62 | 4.27 | 15.59 |
+| ResNet152 | 4932 | 5407.58 | 13970.88 | 23074.68 | 82072.54 | 2.58 | 4.27 | 15.18 |
+| ViT-H/14 | 3420 | 4013.18 | 11146.31 | 17633.07 | 66723.58 | 2.78 | 4.39 | 16.63 |
+| Swin-B | 2881 | 3595.34 | 9505.31 | 15054.88 | 57310.03 | 2.64 | 4.19 | 15.94 |
+| | | | | | **Average** | **2.69** | **4.28** | **16.23** |
-
+
### Tree Flatten with Path
| Module | Nodes | OpTree (μs) | JAX XLA (μs) | PyTorch (μs) | DM-Tree (μs) | Speedup (J / O) | Speedup (P / O) | Speedup (D / O) |
| :-------- | ----: | ----------: | -----------: | -----------: | -----------: | --------------: | --------------: | --------------: |
-| TinyMLP | 53 | 35.55 | 522.00 | N/A | 915.30 | 14.68 | N/A | 25.75 |
-| AlexNet | 188 | 113.89 | 2005.85 | N/A | 3503.25 | 17.61 | N/A | 30.76 |
-| ResNet18 | 698 | 432.31 | 7052.73 | N/A | 12239.45 | 16.31 | N/A | 28.31 |
-| ResNet34 | 1242 | 812.18 | 12657.12 | N/A | 21703.50 | 15.58 | N/A | 26.72 |
-| ResNet50 | 1702 | 1105.15 | 17173.43 | N/A | 29293.02 | 15.54 | N/A | 26.51 |
-| ResNet101 | 3317 | 2182.68 | 33455.81 | N/A | 56810.61 | 15.33 | N/A | 26.03 |
-| ResNet152 | 4932 | 3272.97 | 49550.72 | N/A | 84535.23 | 15.14 | N/A | 25.83 |
-| ViT-H/14 | 3420 | 2287.13 | 37485.28 | N/A | 63024.61 | 16.39 | N/A | 27.56 |
-| Swin-B | 2881 | 2092.16 | 33942.08 | N/A | 52744.88 | 16.22 | N/A | 25.21 |
-| | | | | | **Average** | **15.87** | N/A | **26.96** |
+| TinyMLP | 53 | 36.49 | 543.67 | N/A | 919.13 | 14.90 | N/A | 25.19 |
+| AlexNet | 188 | 115.44 | 2185.21 | N/A | 3752.11 | 18.93 | N/A | 32.50 |
+| ResNet18 | 698 | 431.84 | 7106.55 | N/A | 12286.70 | 16.46 | N/A | 28.45 |
+| ResNet34 | 1242 | 845.61 | 13431.99 | N/A | 22860.48 | 15.88 | N/A | 27.03 |
+| ResNet50 | 1702 | 1166.27 | 18426.52 | N/A | 31225.05 | 15.80 | N/A | 26.77 |
+| ResNet101 | 3317 | 2312.77 | 34770.49 | N/A | 59346.86 | 15.03 | N/A | 25.66 |
+| ResNet152 | 4932 | 3304.74 | 50557.25 | N/A | 85847.91 | 15.30 | N/A | 25.98 |
+| ViT-H/14 | 3420 | 2235.25 | 37473.53 | N/A | 64105.24 | 16.76 | N/A | 28.68 |
+| Swin-B | 2881 | 1970.25 | 32205.83 | N/A | 55177.50 | 16.35 | N/A | 28.01 |
+| | | | | | **Average** | **16.16** | N/A | **27.59** |
-
+
### Tree Copy
| Module | Nodes | OpTree (μs) | JAX XLA (μs) | PyTorch (μs) | DM-Tree (μs) | Speedup (J / O) | Speedup (P / O) | Speedup (D / O) |
| :-------- | ----: | ----------: | -----------: | -----------: | -----------: | --------------: | --------------: | --------------: |
-| TinyMLP | 53 | 91.28 | 211.31 | 833.84 | 952.72 | 2.31 | 9.13 | 10.44 |
-| AlexNet | 188 | 320.20 | 825.66 | 3118.08 | 3938.46 | 2.58 | 9.74 | 12.30 |
-| ResNet18 | 698 | 1113.83 | 2578.31 | 11325.44 | 12068.00 | 2.31 | 10.17 | 10.83 |
-| ResNet34 | 1242 | 2050.00 | 4836.56 | 20324.52 | 22749.73 | 2.36 | 9.91 | 11.10 |
-| ResNet50 | 1702 | 2897.93 | 6121.16 | 27563.39 | 28840.04 | 2.11 | 9.51 | 9.95 |
-| ResNet101 | 3317 | 5456.58 | 12306.26 | 53733.62 | 56140.12 | 2.26 | 9.85 | 10.29 |
-| ResNet152 | 4932 | 8044.33 | 18873.23 | 79896.03 | 83215.06 | 2.35 | 9.93 | 10.34 |
-| ViT-H/14 | 3420 | 6046.78 | 14451.37 | 58204.01 | 70966.61 | 2.39 | 9.63 | 11.74 |
-| Swin-B | 2881 | 5173.48 | 13174.36 | 51701.17 | 60053.21 | 2.55 | 9.99 | 11.61 |
-| | | | | | **Average** | **2.36** | **9.76** | **10.96** |
+| TinyMLP | 53 | 89.81 | 232.26 | 845.20 | 981.48 | 2.59 | 9.41 | 10.93 |
+| AlexNet | 188 | 334.58 | 959.32 | 3360.46 | 4316.05 | 2.87 | 10.04 | 12.90 |
+| ResNet18 | 698 | 1128.11 | 2840.71 | 11471.07 | 12297.07 | 2.52 | 10.17 | 10.90 |
+| ResNet34 | 1242 | 2160.57 | 5333.10 | 20563.06 | 21901.91 | 2.47 | 9.52 | 10.14 |
+| ResNet50 | 1702 | 2746.84 | 6823.88 | 29705.99 | 28927.88 | 2.48 | 10.81 | 10.53 |
+| ResNet101 | 3317 | 5762.05 | 13481.45 | 56968.78 | 60115.93 | 2.34 | 9.89 | 10.43 |
+| ResNet152 | 4932 | 8151.21 | 20805.61 | 81024.06 | 84079.57 | 2.55 | 9.94 | 10.31 |
+| ViT-H/14 | 3420 | 5963.61 | 15665.91 | 59813.52 | 68377.82 | 2.63 | 10.03 | 11.47 |
+| Swin-B | 2881 | 5401.59 | 14255.33 | 53361.77 | 62317.07 | 2.64 | 9.88 | 11.54 |
+| | | | | | **Average** | **2.56** | **9.97** | **11.02** |
-
+
### Tree Map
| Module | Nodes | OpTree (μs) | JAX XLA (μs) | PyTorch (μs) | DM-Tree (μs) | Speedup (J / O) | Speedup (P / O) | Speedup (D / O) |
| :-------- | ----: | ----------: | -----------: | -----------: | -----------: | --------------: | --------------: | --------------: |
-| TinyMLP | 53 | 99.25 | 229.26 | 848.37 | 968.28 | 2.31 | 8.55 | 9.76 |
-| AlexNet | 188 | 332.74 | 853.33 | 3142.39 | 3992.22 | 2.56 | 9.44 | 12.00 |
-| ResNet18 | 698 | 1190.94 | 2760.28 | 11399.95 | 12294.02 | 2.32 | 9.57 | 10.32 |
-| ResNet34 | 1242 | 2286.53 | 4925.70 | 20423.57 | 23204.74 | 2.15 | 8.93 | 10.15 |
-| ResNet50 | 1702 | 2968.51 | 6622.94 | 27807.01 | 29259.40 | 2.23 | 9.37 | 9.86 |
-| ResNet101 | 3317 | 5851.06 | 13132.59 | 53999.13 | 57251.12 | 2.24 | 9.23 | 9.78 |
-| ResNet152 | 4932 | 8682.55 | 19346.59 | 80462.95 | 84364.39 | 2.23 | 9.27 | 9.72 |
-| ViT-H/14 | 3420 | 6695.68 | 16045.45 | 58313.07 | 68415.82 | 2.40 | 8.71 | 10.22 |
-| Swin-B | 2881 | 5747.50 | 13757.05 | 52229.81 | 61017.78 | 2.39 | 9.09 | 10.62 |
-| | | | | | **Average** | **2.32** | **9.13** | **10.27** |
+| TinyMLP | 53 | 95.13 | 243.86 | 867.34 | 1026.99 | 2.56 | 9.12 | 10.80 |
+| AlexNet | 188 | 348.44 | 987.57 | 3398.32 | 4354.81 | 2.83 | 9.75 | 12.50 |
+| ResNet18 | 698 | 1190.62 | 2982.66 | 11719.94 | 12559.01 | 2.51 | 9.84 | 10.55 |
+| ResNet34 | 1242 | 2205.87 | 5417.60 | 20935.72 | 22308.51 | 2.46 | 9.49 | 10.11 |
+| ResNet50 | 1702 | 3128.48 | 7579.55 | 30372.71 | 31638.67 | 2.42 | 9.71 | 10.11 |
+| ResNet101 | 3317 | 6173.05 | 14846.57 | 59167.85 | 60245.42 | 2.41 | 9.58 | 9.76 |
+| ResNet152 | 4932 | 8641.22 | 22000.74 | 84018.65 | 86182.21 | 2.55 | 9.72 | 9.97 |
+| ViT-H/14 | 3420 | 6211.79 | 17077.49 | 59790.25 | 69763.86 | 2.75 | 9.63 | 11.23 |
+| Swin-B | 2881 | 5673.66 | 14339.69 | 53309.17 | 59764.61 | 2.53 | 9.40 | 10.53 |
+| | | | | | **Average** | **2.56** | **9.58** | **10.62** |
-
+
### Tree Map (nargs)
| Module | Nodes | OpTree (μs) | JAX XLA (μs) | PyTorch (μs) | DM-Tree (μs) | Speedup (J / O) | Speedup (P / O) | Speedup (D / O) |
| :-------- | ----: | ----------: | -----------: | -----------: | -----------: | --------------: | --------------: | --------------: |
-| TinyMLP | 53 | 133.87 | 344.99 | N/A | 3599.07 | 2.58 | N/A | 26.89 |
-| AlexNet | 188 | 445.41 | 1310.77 | N/A | 14207.10 | 2.94 | N/A | 31.90 |
-| ResNet18 | 698 | 1599.16 | 4239.56 | N/A | 49255.49 | 2.65 | N/A | 30.80 |
-| ResNet34 | 1242 | 3066.14 | 8115.79 | N/A | 88568.31 | 2.65 | N/A | 28.89 |
-| ResNet50 | 1702 | 3951.48 | 10557.52 | N/A | 127232.92 | 2.67 | N/A | 32.20 |
-| ResNet101 | 3317 | 7801.80 | 20208.53 | N/A | 235961.43 | 2.59 | N/A | 30.24 |
-| ResNet152 | 4932 | 11489.21 | 29375.98 | N/A | 349007.54 | 2.56 | N/A | 30.38 |
-| ViT-H/14 | 3420 | 8319.66 | 23204.11 | N/A | 266190.21 | 2.79 | N/A | 32.00 |
-| Swin-B | 2881 | 7259.47 | 20098.17 | N/A | 226166.17 | 2.77 | N/A | 31.15 |
-| | | | | | **Average** | **2.69** | N/A | **30.49** |
+| TinyMLP | 53 | 137.06 | 389.96 | N/A | 3908.77 | 2.85 | N/A | 28.52 |
+| AlexNet | 188 | 467.24 | 1496.96 | N/A | 15395.13 | 3.20 | N/A | 32.95 |
+| ResNet18 | 698 | 1603.79 | 4534.01 | N/A | 50323.76 | 2.83 | N/A | 31.38 |
+| ResNet34 | 1242 | 2907.64 | 8435.33 | N/A | 90389.23 | 2.90 | N/A | 31.09 |
+| ResNet50 | 1702 | 4183.77 | 11382.51 | N/A | 121777.01 | 2.72 | N/A | 29.11 |
+| ResNet101 | 3317 | 7721.13 | 22247.85 | N/A | 238755.17 | 2.88 | N/A | 30.92 |
+| ResNet152 | 4932 | 11508.05 | 31429.39 | N/A | 360257.74 | 2.73 | N/A | 31.30 |
+| ViT-H/14 | 3420 | 8294.20 | 24524.86 | N/A | 270514.87 | 2.96 | N/A | 32.61 |
+| Swin-B | 2881 | 7074.62 | 20854.80 | N/A | 241120.41 | 2.95 | N/A | 34.08 |
+| | | | | | **Average** | **2.89** | N/A | **31.33** |
-
+
### Tree Map with Path
| Module | Nodes | OpTree (μs) | JAX XLA (μs) | PyTorch (μs) | DM-Tree (μs) | Speedup (J / O) | Speedup (P / O) | Speedup (D / O) |
| :-------- | ----: | ----------: | -----------: | -----------: | -----------: | --------------: | --------------: | --------------: |
-| TinyMLP | 53 | 104.70 | 703.83 | N/A | 1998.00 | 6.72 | N/A | 19.08 |
-| AlexNet | 188 | 352.30 | 2668.73 | N/A | 7681.19 | 7.58 | N/A | 21.80 |
-| ResNet18 | 698 | 1289.51 | 9342.79 | N/A | 25497.31 | 7.25 | N/A | 19.77 |
-| ResNet34 | 1242 | 2366.46 | 16746.52 | N/A | 45254.59 | 7.08 | N/A | 19.12 |
-| ResNet50 | 1702 | 3399.11 | 23574.46 | N/A | 60494.27 | 6.94 | N/A | 17.80 |
-| ResNet101 | 3317 | 6329.82 | 43955.95 | N/A | 118725.60 | 6.94 | N/A | 18.76 |
-| ResNet152 | 4932 | 9307.87 | 64777.45 | N/A | 174764.97 | 6.96 | N/A | 18.78 |
-| ViT-H/14 | 3420 | 6705.10 | 48862.92 | N/A | 139617.21 | 7.29 | N/A | 20.82 |
-| Swin-B | 2881 | 5780.20 | 41703.04 | N/A | 115003.61 | 7.21 | N/A | 19.90 |
-| | | | | | **Average** | **7.11** | N/A | **19.54** |
+| TinyMLP | 53 | 109.82 | 778.30 | N/A | 2186.40 | 7.09 | N/A | 19.91 |
+| AlexNet | 188 | 365.16 | 2939.36 | N/A | 8355.37 | 8.05 | N/A | 22.88 |
+| ResNet18 | 698 | 1308.26 | 9529.58 | N/A | 25758.24 | 7.28 | N/A | 19.69 |
+| ResNet34 | 1242 | 2527.21 | 18084.89 | N/A | 45942.32 | 7.16 | N/A | 18.18 |
+| ResNet50 | 1702 | 3226.03 | 22935.53 | N/A | 61275.34 | 7.11 | N/A | 18.99 |
+| ResNet101 | 3317 | 6663.52 | 46878.89 | N/A | 126642.14 | 7.04 | N/A | 19.01 |
+| ResNet152 | 4932 | 9378.19 | 66136.44 | N/A | 176981.01 | 7.05 | N/A | 18.87 |
+| ViT-H/14 | 3420 | 7033.69 | 50418.37 | N/A | 142508.11 | 7.17 | N/A | 20.26 |
+| Swin-B | 2881 | 6078.15 | 43173.22 | N/A | 116612.71 | 7.10 | N/A | 19.19 |
+| | | | | | **Average** | **7.23** | N/A | **19.66** |
-
+
### Tree Map with Path (nargs)
| Module | Nodes | OpTree (μs) | JAX XLA (μs) | PyTorch (μs) | DM-Tree (μs) | Speedup (J / O) | Speedup (P / O) | Speedup (D / O) |
| :-------- | ----: | ----------: | -----------: | -----------: | -----------: | --------------: | --------------: | --------------: |
-| TinyMLP | 53 | 138.19 | 828.53 | N/A | 3599.32 | 6.00 | N/A | 26.05 |
-| AlexNet | 188 | 461.91 | 3138.59 | N/A | 14069.17 | 6.79 | N/A | 30.46 |
-| ResNet18 | 698 | 1702.79 | 10890.25 | N/A | 49456.32 | 6.40 | N/A | 29.04 |
-| ResNet34 | 1242 | 3115.89 | 19356.46 | N/A | 88955.96 | 6.21 | N/A | 28.55 |
-| ResNet50 | 1702 | 4422.25 | 26205.69 | N/A | 121569.30 | 5.93 | N/A | 27.49 |
-| ResNet101 | 3317 | 8334.83 | 50909.37 | N/A | 241862.38 | 6.11 | N/A | 29.02 |
-| ResNet152 | 4932 | 12208.52 | 75327.94 | N/A | 351472.89 | 6.17 | N/A | 28.79 |
-| ViT-H/14 | 3420 | 9320.75 | 56869.65 | N/A | 266430.78 | 6.10 | N/A | 28.58 |
-| Swin-B | 2881 | 7472.11 | 49260.03 | N/A | 233154.60 | 6.59 | N/A | 31.20 |
-| | | | | | **Average** | **6.26** | N/A | **28.80** |
+| TinyMLP | 53 | 146.05 | 917.00 | N/A | 3940.61 | 6.28 | N/A | 26.98 |
+| AlexNet | 188 | 489.27 | 3560.76 | N/A | 15434.71 | 7.28 | N/A | 31.55 |
+| ResNet18 | 698 | 1712.79 | 11171.44 | N/A | 50219.86 | 6.52 | N/A | 29.32 |
+| ResNet34 | 1242 | 3112.83 | 21024.58 | N/A | 95505.71 | 6.75 | N/A | 30.68 |
+| ResNet50 | 1702 | 4220.70 | 26600.82 | N/A | 121897.57 | 6.30 | N/A | 28.88 |
+| ResNet101 | 3317 | 8631.34 | 54372.37 | N/A | 236555.54 | 6.30 | N/A | 27.41 |
+| ResNet152 | 4932 | 12710.49 | 77643.13 | N/A | 353600.32 | 6.11 | N/A | 27.82 |
+| ViT-H/14 | 3420 | 8753.09 | 58712.71 | N/A | 286365.36 | 6.71 | N/A | 32.72 |
+| Swin-B | 2881 | 7359.29 | 50112.23 | N/A | 228866.66 | 6.81 | N/A | 31.10 |
+| | | | | | **Average** | **6.56** | N/A | **29.61** |
-
+
--------------------------------------------------------------------------------
diff --git a/benchmark.py b/benchmark.py
index 733cb7dc..8e1116f7 100755
--- a/benchmark.py
+++ b/benchmark.py
@@ -15,15 +15,18 @@
# limitations under the License.
# ==============================================================================
-# pylint: disable=missing-module-docstring,missing-function-docstring,invalid-name
+# pylint: disable=missing-module-docstring,missing-class-docstring,missing-function-docstring
+# pylint: disable=invalid-name
from __future__ import annotations
import argparse
import sys
-from collections import OrderedDict, namedtuple
+import textwrap
+import timeit
+from collections import OrderedDict
from itertools import count
-from typing import Any, Iterable
+from typing import Any, Iterable, NamedTuple, Sequence
import jax
import pandas as pd
@@ -52,56 +55,6 @@ def colored( # pylint: disable=function-redefined,unused-argument
return text
-torch_utils_pytree._register_pytree_node( # pylint: disable=protected-access
- dict,
- lambda d: tuple(zip(*sorted(d.items())))[::-1] if d else ((), ()),
- lambda values, keys: dict(zip(keys, values)),
-)
-
-
-Tensors = namedtuple('Tensors', ['parameters', 'buffers'])
-Containers = namedtuple('Containers', ['parameters', 'buffers'])
-
-
-def get_none() -> None:
- return None
-
-
-def tiny_mlp() -> nn.Module:
- return nn.Sequential(
- nn.Linear(1, 1, bias=True),
- nn.BatchNorm1d(1, affine=True, track_running_stats=True),
- nn.ReLU(),
- nn.Linear(1, 1, bias=False),
- nn.Sigmoid(),
- )
-
-
-def extract(module: nn.Module, unordered: bool) -> Any:
- if isinstance(module, (nn.Sequential, nn.ModuleList)):
- return [extract(submodule, unordered=unordered) for submodule in module]
-
- dict_factory = dict if unordered else OrderedDict
- extracted = dict_factory(
- [
- (name, extract(submodule, unordered=unordered))
- for name, submodule in module.named_children()
- ],
- )
- extracted.update(
- tensors=Tensors(
- parameters=tuple(t.data for t in module.parameters(recurse=False)),
- buffers=[t.data for t in module.buffers(recurse=False)],
- ),
- containers=Containers(
- parameters=dict_factory(module._parameters), # pylint: disable=protected-access
- buffers=dict_factory(module._buffers), # pylint: disable=protected-access
- ),
- )
-
- return extracted
-
-
cmark = '✔' # pylint: disable=invalid-name
xmark = '✘' # pylint: disable=invalid-name
tie = '~' # pylint: disable=invalid-name
@@ -128,198 +81,320 @@ def fnp3(p, x, y, z):
return None
"""
-STMTS = OrderedDict(
+
+class BenchmarkCase(NamedTuple):
+ name: str
+ stmt: str
+ init_stmt: str = ''
+
+ def timeit(
+ self,
+ number: int,
+ repeat: int = 5,
+ globals: dict[str, Any] | None = None, # pylint: disable=redefined-builtin
+ ) -> float:
+ init_warmup = (
+ INIT.strip()
+ + '\n\n'
+ + textwrap.dedent(
+ f"""
+ {self.init_stmt.strip()}
+
+ for _ in range({number} // 10):
+ {self.stmt.strip()}
+ """,
+ ).strip()
+ )
+
+ times = timeit.repeat(
+ self.stmt.strip(),
+ setup=init_warmup,
+ number=number,
+ repeat=repeat,
+ globals=globals,
+ )
+ return min(times) / number
+
+
+BENCHMARK_CASES: dict[str, Sequence[BenchmarkCase]] = OrderedDict(
[
(
'Tree Flatten',
- OrderedDict(
- [
- ('OpTree(default)', ('optree.tree_leaves(x)', '')),
- ('OpTree(NoneIsNode)', ('optree.tree_leaves(x, none_is_leaf=False)', '')),
- ('OpTree(NoneIsLeaf)', ('optree.tree_leaves(x, none_is_leaf=True)', '')),
- ('JAX XLA', ('jax.tree_util.tree_leaves(x)', '')),
- ('PyTorch', ('torch_utils_pytree.tree_flatten(x)[0]', '')),
- ('DM-Tree', ('dm_tree.flatten(x)', '')),
- ],
- ),
+ [
+ BenchmarkCase(
+ name='OpTree(default)',
+ stmt='optree.tree_leaves(x)',
+ ),
+ BenchmarkCase(
+ name='OpTree(NoneIsNode)',
+ stmt='optree.tree_leaves(x, none_is_leaf=False)',
+ ),
+ BenchmarkCase(
+ name='OpTree(NoneIsLeaf)',
+ stmt='optree.tree_leaves(x, none_is_leaf=True)',
+ ),
+ BenchmarkCase(
+ name='JAX XLA',
+ stmt='jax.tree_util.tree_leaves(x)',
+ ),
+ BenchmarkCase(
+ name='PyTorch',
+ stmt='torch_utils_pytree.tree_flatten(x)[0]',
+ ),
+ BenchmarkCase(
+ name='DM-Tree',
+ stmt='dm_tree.flatten(x)',
+ ),
+ ],
),
(
'Tree UnFlatten',
- OrderedDict(
- [
- (
- 'OpTree(NoneIsNode)', # default
- (
- 'optree.tree_unflatten(spec, flat)',
- 'flat, spec = optree.tree_flatten(x, none_is_leaf=False)',
- ),
- ),
- (
- 'OpTree(NoneIsLeaf)',
- (
- 'optree.tree_unflatten(spec, flat)',
- 'flat, spec = optree.tree_flatten(x, none_is_leaf=True)',
- ),
- ),
- (
- 'JAX XLA',
- (
- 'jax.tree_util.tree_unflatten(spec, flat)',
- 'flat, spec = jax.tree_util.tree_flatten(x)',
- ),
- ),
- (
- 'PyTorch',
- (
- 'torch_utils_pytree.tree_unflatten(flat, spec)',
- 'flat, spec = torch_utils_pytree.tree_flatten(x)',
- ),
- ),
- (
- 'DM-Tree',
- (
- 'dm_tree.unflatten_as(spec, flat)',
- 'flat, spec = dm_tree.flatten(x), x',
- ),
- ),
- ],
- ),
+ [
+ BenchmarkCase(
+ name='OpTree(NoneIsNode)', # default
+ stmt='optree.tree_unflatten(spec, flat)',
+ init_stmt='flat, spec = optree.tree_flatten(x, none_is_leaf=False)',
+ ),
+ BenchmarkCase(
+ name='OpTree(NoneIsLeaf)',
+ stmt='optree.tree_unflatten(spec, flat)',
+ init_stmt='flat, spec = optree.tree_flatten(x, none_is_leaf=True)',
+ ),
+ BenchmarkCase(
+ name='JAX XLA',
+ stmt='jax.tree_util.tree_unflatten(spec, flat)',
+ init_stmt='flat, spec = jax.tree_util.tree_flatten(x)',
+ ),
+ BenchmarkCase(
+ name='PyTorch',
+ stmt='torch_utils_pytree.tree_unflatten(flat, spec)',
+ init_stmt='flat, spec = torch_utils_pytree.tree_flatten(x)',
+ ),
+ BenchmarkCase(
+ name='DM-Tree',
+ stmt='dm_tree.unflatten_as(spec, flat)',
+ init_stmt='flat, spec = dm_tree.flatten(x), x',
+ ),
+ ],
),
(
'Tree Flatten with Path',
- OrderedDict(
- [
- ('OpTree(default)', ('optree.tree_flatten_with_path(x)', '')),
- (
- 'OpTree(NoneIsNode)',
- ('optree.tree_flatten_with_path(x, none_is_leaf=False)', ''),
- ),
- (
- 'OpTree(NoneIsLeaf)',
- ('optree.tree_flatten_with_path(x, none_is_leaf=True)', ''),
- ),
- ('JAX XLA', ('jax.tree_util.tree_flatten_with_path(x)', '')),
- ('DM-Tree', ('dm_tree.flatten_with_path(x)', '')),
- ],
- ),
+ [
+ BenchmarkCase(
+ name='OpTree(default)',
+ stmt='optree.tree_flatten_with_path(x)',
+ ),
+ BenchmarkCase(
+ name='OpTree(NoneIsNode)',
+ stmt='optree.tree_flatten_with_path(x, none_is_leaf=False)',
+ ),
+ BenchmarkCase(
+ name='OpTree(NoneIsLeaf)',
+ stmt='optree.tree_flatten_with_path(x, none_is_leaf=True)',
+ ),
+ BenchmarkCase(
+ name='JAX XLA',
+ stmt='jax.tree_util.tree_flatten_with_path(x)',
+ ),
+ BenchmarkCase(
+ name='DM-Tree',
+ stmt='dm_tree.flatten_with_path(x)',
+ ),
+ ],
),
(
'Tree Copy',
- OrderedDict(
- [
- (
- 'OpTree(default)',
- ('optree.tree_unflatten(*optree.tree_flatten(x)[::-1])', ''),
- ),
- (
- 'OpTree(NoneIsNode)',
- (
- 'optree.tree_unflatten(*optree.tree_flatten(x, none_is_leaf=False)[::-1])',
- '',
- ),
- ),
- (
- 'OpTree(NoneIsLeaf)',
- (
- 'optree.tree_unflatten(*optree.tree_flatten(x, none_is_leaf=True)[::-1])',
- '',
- ),
- ),
- (
- 'JAX XLA',
- ('jax.tree_util.tree_unflatten(*jax.tree_util.tree_flatten(x)[::-1])', ''),
- ),
- (
- 'PyTorch',
- (
- 'torch_utils_pytree.tree_unflatten(*torch_utils_pytree.tree_flatten(x))',
- '',
- ),
- ),
- (
- 'DM-Tree',
- ('dm_tree.unflatten_as(x, dm_tree.flatten(x))', ''),
- ),
- ],
- ),
+ [
+ BenchmarkCase(
+ name='OpTree(default)',
+ stmt='optree.tree_unflatten(*optree.tree_flatten(x)[::-1])',
+ ),
+ BenchmarkCase(
+ name='OpTree(NoneIsNode)',
+ stmt='optree.tree_unflatten(*optree.tree_flatten(x, none_is_leaf=False)[::-1])',
+ ),
+ BenchmarkCase(
+ name='OpTree(NoneIsLeaf)',
+ stmt='optree.tree_unflatten(*optree.tree_flatten(x, none_is_leaf=True)[::-1])',
+ ),
+ BenchmarkCase(
+ name='JAX XLA',
+ stmt='jax.tree_util.tree_unflatten(*jax.tree_util.tree_flatten(x)[::-1])',
+ ),
+ BenchmarkCase(
+ name='PyTorch',
+ stmt='torch_utils_pytree.tree_unflatten(*torch_utils_pytree.tree_flatten(x))',
+ ),
+ BenchmarkCase(
+ name='DM-Tree',
+ stmt='dm_tree.unflatten_as(x, dm_tree.flatten(x))',
+ ),
+ ],
),
(
'Tree Map',
- OrderedDict(
- [
- ('OpTree(default)', ('optree.tree_map(fn1, x)', '')),
- ('OpTree(NoneIsNode)', ('optree.tree_map(fn1, x, none_is_leaf=False)', '')),
- ('OpTree(NoneIsLeaf)', ('optree.tree_map(fn1, x, none_is_leaf=True)', '')),
- ('JAX XLA', ('jax.tree_util.tree_map(fn1, x)', '')),
- ('PyTorch', ('torch_utils_pytree.tree_map(fn1, x)', '')),
- ('DM-Tree', ('dm_tree.map_structure(fn1, x)', '')),
- ],
- ),
+ [
+ BenchmarkCase(
+ name='OpTree(default)',
+ stmt='optree.tree_map(fn1, x)',
+ ),
+ BenchmarkCase(
+ name='OpTree(NoneIsNode)',
+ stmt='optree.tree_map(fn1, x, none_is_leaf=False)',
+ ),
+ BenchmarkCase(
+ name='OpTree(NoneIsLeaf)',
+ stmt='optree.tree_map(fn1, x, none_is_leaf=True)',
+ ),
+ BenchmarkCase(
+ name='JAX XLA',
+ stmt='jax.tree_util.tree_map(fn1, x)',
+ ),
+ BenchmarkCase(
+ name='PyTorch',
+ stmt='torch_utils_pytree.tree_map(fn1, x)',
+ ),
+ BenchmarkCase(
+ name='DM-Tree',
+ stmt='dm_tree.map_structure(fn1, x)',
+ ),
+ ],
),
(
'Tree Map (nargs)',
- OrderedDict(
- [
- ('OpTree(default)', ('optree.tree_map(fn3, x, y, z)', '')),
- (
- 'OpTree(NoneIsNode)',
- ('optree.tree_map(fn3, x, y, z, none_is_leaf=False)', ''),
- ),
- (
- 'OpTree(NoneIsLeaf)',
- ('optree.tree_map(fn3, x, y, z, none_is_leaf=True)', ''),
- ),
- ('JAX XLA', ('jax.tree_util.tree_map(fn3, x, y, z)', '')),
- ('DM-Tree', ('dm_tree.map_structure_up_to(x, fn3, x, y, z)', '')),
- ],
- ),
+ [
+ BenchmarkCase(
+ name='OpTree(default)',
+ stmt='optree.tree_map(fn3, x, y, z)',
+ ),
+ BenchmarkCase(
+ name='OpTree(NoneIsNode)',
+ stmt='optree.tree_map(fn3, x, y, z, none_is_leaf=False)',
+ ),
+ BenchmarkCase(
+ name='OpTree(NoneIsLeaf)',
+ stmt='optree.tree_map(fn3, x, y, z, none_is_leaf=True)',
+ ),
+ BenchmarkCase(
+ name='JAX XLA',
+ stmt='jax.tree_util.tree_map(fn3, x, y, z)',
+ ),
+ BenchmarkCase(
+ name='DM-Tree',
+ stmt='dm_tree.map_structure_up_to(x, fn3, x, y, z)',
+ ),
+ ],
),
(
'Tree Map with Path',
- OrderedDict(
- [
- ('OpTree(default)', ('optree.tree_map_with_path(fnp1, x)', '')),
- (
- 'OpTree(NoneIsNode)',
- ('optree.tree_map_with_path(fnp1, x, none_is_leaf=False)', ''),
- ),
- (
- 'OpTree(NoneIsLeaf)',
- ('optree.tree_map_with_path(fnp1, x, none_is_leaf=True)', ''),
- ),
- ('JAX XLA', ('jax.tree_util.tree_map_with_path(fnp1, x)', '')),
- ('DM-Tree', ('dm_tree.map_structure_with_path(fnp1, x)', '')),
- ],
- ),
+ [
+ BenchmarkCase(
+ name='OpTree(default)',
+ stmt='optree.tree_map_with_path(fnp1, x)',
+ ),
+ BenchmarkCase(
+ name='OpTree(NoneIsNode)',
+ stmt='optree.tree_map_with_path(fnp1, x, none_is_leaf=False)',
+ ),
+ BenchmarkCase(
+ name='OpTree(NoneIsLeaf)',
+ stmt='optree.tree_map_with_path(fnp1, x, none_is_leaf=True)',
+ ),
+ BenchmarkCase(
+ name='JAX XLA',
+ stmt='jax.tree_util.tree_map_with_path(fnp1, x)',
+ ),
+ BenchmarkCase(
+ name='DM-Tree',
+ stmt='dm_tree.map_structure_with_path(fnp1, x)',
+ ),
+ ],
),
(
'Tree Map with Path (nargs)',
- OrderedDict(
- [
- ('OpTree(default)', ('optree.tree_map_with_path(fnp3, x, y, z)', '')),
- (
- 'OpTree(NoneIsNode)',
- (
- 'optree.tree_map_with_path(fnp3, x, y, z, none_is_leaf=False)',
- '',
- ),
- ),
- (
- 'OpTree(NoneIsLeaf)',
- (
- 'optree.tree_map_with_path(fnp3, x, y, z, none_is_leaf=True)',
- '',
- ),
- ),
- ('JAX XLA', ('jax.tree_util.tree_map_with_path(fnp3, x, y, z)', '')),
- ('DM-Tree', ('dm_tree.map_structure_with_path(fnp3, x, y, z)', '')),
- ],
- ),
+ [
+ BenchmarkCase(
+ name='OpTree(default)',
+ stmt='optree.tree_map_with_path(fnp3, x, y, z)',
+ ),
+ BenchmarkCase(
+ name='OpTree(NoneIsNode)',
+ stmt='optree.tree_map_with_path(fnp3, x, y, z, none_is_leaf=False)',
+ ),
+ BenchmarkCase(
+ name='OpTree(NoneIsLeaf)',
+ stmt='optree.tree_map_with_path(fnp3, x, y, z, none_is_leaf=True)',
+ ),
+ BenchmarkCase(
+ name='JAX XLA',
+ stmt='jax.tree_util.tree_map_with_path(fnp3, x, y, z)',
+ ),
+ BenchmarkCase(
+ name='DM-Tree',
+ stmt='dm_tree.map_structure_with_path(fnp3, x, y, z)',
+ ),
+ ],
),
],
)
+torch_utils_pytree._register_pytree_node( # pylint: disable=protected-access
+ dict,
+ lambda d: tuple(zip(*sorted(d.items())))[::-1] if d else ((), ()),
+ lambda values, keys: dict(zip(keys, values)),
+)
+
+
+class Tensors(NamedTuple):
+ parameters: tuple[torch.Tensor, ...]
+ buffers: list[torch.Tensor]
+
+
+class Containers(NamedTuple):
+ parameters: dict[str, torch.Tensor | None]
+ buffers: dict[str, torch.Tensor | None]
+
+
+def get_none() -> None:
+ return None
+
+
+def tiny_mlp() -> nn.Module:
+ return nn.Sequential(
+ nn.Linear(1, 1, bias=True),
+ nn.BatchNorm1d(1, affine=True, track_running_stats=True),
+ nn.ReLU(),
+ nn.Linear(1, 1, bias=False),
+ nn.Sigmoid(),
+ )
+
+
+def extract(module: nn.Module, unordered: bool) -> Any:
+ if isinstance(module, (nn.Sequential, nn.ModuleList)):
+ return [extract(submodule, unordered=unordered) for submodule in module]
+
+ dict_factory = dict if unordered else OrderedDict
+ extracted = dict_factory(
+ [
+ (name, extract(submodule, unordered=unordered))
+ for name, submodule in module.named_children()
+ ],
+ )
+ extracted.update(
+ tensors=Tensors(
+ parameters=tuple(t.data for t in module.parameters(recurse=False)),
+ buffers=[t.data for t in module.buffers(recurse=False)],
+ ),
+ containers=Containers(
+ parameters=dict_factory(module._parameters), # pylint: disable=protected-access
+ buffers=dict_factory(module._buffers), # pylint: disable=protected-access
+ ),
+ )
+
+ return extracted
+
+
def cprint(text: str = '') -> None:
text = (
text.replace(
@@ -439,58 +514,26 @@ def check(tree: Any) -> None:
print(flush=True)
-def timeit(
- stmt: str,
- init_stmt: str,
- number: int,
- repeat: int = 5,
- globals: dict[str, Any] | None = None, # pylint: disable=redefined-builtin
-) -> float:
- import timeit # pylint: disable=redefined-outer-name,import-outside-toplevel
-
- globals = globals or {}
- init_warmup = f"""
-{INIT.strip()}
-
-{init_stmt.strip()}
-
-for _ in range({number} // 10):
- {stmt.strip()}
-"""
-
- times = timeit.repeat(
- stmt.strip(),
- setup=init_warmup,
- globals=globals,
- number=number,
- repeat=repeat,
- )
- return min(times) / number
-
-
def compare( # pylint: disable=too-many-locals
subject: str,
- stmts: dict[str, tuple[str, str]],
+ cases: Sequence[BenchmarkCase],
number: int,
repeat: int = 5,
globals: dict[str, Any] | None = None, # pylint: disable=redefined-builtin
) -> dict[str, float]:
times_us = OrderedDict(
- [
- (lib, 10e6 * timeit(stmt, init_stmt, number, repeat=repeat, globals=globals))
- for lib, (stmt, init_stmt) in stmts.items()
- ],
+        # NOTE(review): min(times)/number is seconds per call; μs conversion is 1e6.
+        # The previous code used 10e6, inflating all reported "μs" figures tenfold
+        # (speedup ratios were unaffected). Confirm against published tables before merging.
+        [(case.name, 1e6 * case.timeit(number, repeat=repeat, globals=globals)) for case in cases],
)
base_time = next(iter(times_us.values()))
best_time = min(times_us.values())
- speedups = {lib: time / base_time for lib, time in times_us.items()}
- best_speedups = {lib: time / best_time for lib, time in times_us.items()}
+ speedups = {name: time / base_time for name, time in times_us.items()}
+ best_speedups = {name: time / best_time for name, time in times_us.items()}
labels = {
- lib: (cmark if speedup == 1.0 else tie if speedup < 1.1 else ' ')
- for lib, speedup in best_speedups.items()
+ name: (cmark if speedup == 1.0 else tie if speedup < 1.1 else ' ')
+ for name, speedup in best_speedups.items()
}
colors = {
- lib: (
+ name: (
'green'
if speedup == 1.0
else 'cyan'
@@ -499,18 +542,18 @@ def compare( # pylint: disable=too-many-locals
if speedup < 4.0
else 'red'
)
- for lib, speedup in best_speedups.items()
+ for name, speedup in best_speedups.items()
}
- attrs = {lib: ('bold',) if color == 'green' else () for lib, color in colors.items()}
+ attrs = {name: ('bold',) if color == 'green' else () for name, color in colors.items()}
cprint(f'### {subject} ###')
- for lib, (stmt, _) in stmts.items():
- label = labels[lib]
- color = colors[lib]
- attr = attrs[lib]
- label = colored(label + ' ' + lib, color=color, attrs=attr)
- time = colored(f'{times_us[lib]:8.2f}μs', attrs=attr)
- speedup = speedups[lib]
+ for name, stmt, _ in cases:
+ label = labels[name]
+ color = colors[name]
+ attr = attrs[name]
+ label = colored(label + ' ' + name, color=color, attrs=attr)
+ time = colored(f'{times_us[name]:8.2f}μs', attrs=attr)
+ speedup = speedups[name]
if speedup != 1.0:
speedup = ' -- ' + colored(f'x{speedup:.2f}'.ljust(6), color=color, attrs=attr)
else:
@@ -573,10 +616,10 @@ def benchmark( # pylint: disable=too-many-locals
z = optree.tree_map(lambda t: (t, None), x) # pylint: disable=invalid-name
check(x)
- for subject, stmts in STMTS.items():
+ for subject, cases in BENCHMARK_CASES.items():
times_us = compare(
subject,
- stmts,
+ cases,
number=number,
repeat=repeat,
globals={'x': x, 'y': y, 'z': z},
diff --git a/conda-recipe.yaml b/conda-recipe.yaml
index ca3d615e..07295596 100644
--- a/conda-recipe.yaml
+++ b/conda-recipe.yaml
@@ -39,7 +39,7 @@ dependencies:
- pybind11
# Benchmark
- - pytorch::pytorch >= 1.13, < 1.14.0a0
+ - pytorch::pytorch >= 2.0, < 2.1.0a0
- pytorch::torchvision
- pytorch::pytorch-mutex = *=*cpu*
- jax >= 0.4.6, < 0.5.0a0
diff --git a/pyproject.toml b/pyproject.toml
index 747706a0..b77b18b8 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -70,7 +70,7 @@ lint = [
test = ['pytest', 'pytest-cov', 'pytest-xdist']
benchmark = [
'jax[cpu] >= 0.4.6, < 0.5.0a0',
- 'torch >= 1.13, < 1.14.0a0',
+ 'torch >= 2.0, < 2.1.0a0',
'torchvision',
'dm-tree >= 0.1, < 0.2.0a0',
'pandas',