[Feature] Avoid some recompiles of ReplayBuffer.extend\sample
#2504
ghstack-source-id: f50d4ecbc9c2fe0f5334e3cabc0e51cd4f930b80 Pull Request resolved: #2504
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/rl/2504
Note: Links to docs will display an error until the docs builds have been completed.
❌ 3 New Failures, 4 Unrelated Failures
As of commit fff3e79 with merge base 5244a90:
NEW FAILURES - The following jobs have failed:
BROKEN TRUNK - The following jobs failed but were present on the merge base:
👉 Rebase onto the `viable/strict` branch to avoid these failures
This comment was automatically generated by Dr. CI and updates every 15 minutes.

Name | Max | Mean | Ops | Ops on Repo HEAD | Change |
---|---|---|---|---|---|
test_simple | 0.4163s | 0.4150s | 2.4095 Ops/s | 2.3972 Ops/s | |
test_transformed | 0.6800s | 0.6123s | 1.6332 Ops/s | 1.6983 Ops/s | |
test_serial | 1.4355s | 1.3625s | 0.7339 Ops/s | 0.7396 Ops/s | |
test_parallel | 1.3234s | 1.3126s | 0.7618 Ops/s | 0.7444 Ops/s | |
test_step_mdp_speed[True-True-True-True-True] | 0.2202ms | 28.6828μs | 34.8641 KOps/s | 33.3941 KOps/s | |
test_step_mdp_speed[True-True-True-True-False] | 49.1120μs | 17.0422μs | 58.6780 KOps/s | 56.1536 KOps/s | |
test_step_mdp_speed[True-True-True-False-True] | 51.7360μs | 15.9134μs | 62.8401 KOps/s | 60.3302 KOps/s | |
test_step_mdp_speed[True-True-True-False-False] | 48.6210μs | 9.3864μs | 106.5366 KOps/s | 102.8456 KOps/s | |
test_step_mdp_speed[True-True-False-True-True] | 0.1770ms | 31.0716μs | 32.1838 KOps/s | 31.0121 KOps/s | |
test_step_mdp_speed[True-True-False-True-False] | 0.2966ms | 20.4966μs | 48.7886 KOps/s | 50.6166 KOps/s | |
test_step_mdp_speed[True-True-False-False-True] | 66.6440μs | 18.0803μs | 55.3087 KOps/s | 53.7792 KOps/s | |
test_step_mdp_speed[True-True-False-False-False] | 49.1320μs | 11.2942μs | 88.5411 KOps/s | 84.3360 KOps/s | |
test_step_mdp_speed[True-False-True-True-True] | 75.2310μs | 32.7150μs | 30.5670 KOps/s | 29.0582 KOps/s | |
test_step_mdp_speed[True-False-True-True-False] | 54.0010μs | 21.2299μs | 47.1034 KOps/s | 45.2020 KOps/s | |
test_step_mdp_speed[True-False-True-False-True] | 68.2980μs | 17.6991μs | 56.4999 KOps/s | 53.2725 KOps/s | |
test_step_mdp_speed[True-False-True-False-False] | 67.2060μs | 11.2500μs | 88.8889 KOps/s | 86.2189 KOps/s | |
test_step_mdp_speed[True-False-False-True-True] | 77.7460μs | 34.5965μs | 28.9047 KOps/s | 27.1853 KOps/s | |
test_step_mdp_speed[True-False-False-True-False] | 53.3400μs | 23.1814μs | 43.1380 KOps/s | 41.7715 KOps/s | |
test_step_mdp_speed[True-False-False-False-True] | 0.1557ms | 19.8938μs | 50.2670 KOps/s | 47.4466 KOps/s | |
test_step_mdp_speed[True-False-False-False-False] | 35.3760μs | 13.2501μs | 75.4713 KOps/s | 72.0331 KOps/s | |
test_step_mdp_speed[False-True-True-True-True] | 73.4070μs | 33.0431μs | 30.2635 KOps/s | 28.8098 KOps/s | |
test_step_mdp_speed[False-True-True-True-False] | 74.2190μs | 21.0633μs | 47.4759 KOps/s | 44.9797 KOps/s | |
test_step_mdp_speed[False-True-True-False-True] | 55.8540μs | 21.1426μs | 47.2979 KOps/s | 44.8018 KOps/s | |
test_step_mdp_speed[False-True-True-False-False] | 55.0730μs | 13.0808μs | 76.4480 KOps/s | 73.5095 KOps/s | |
test_step_mdp_speed[False-True-False-True-True] | 93.1740μs | 35.0233μs | 28.5525 KOps/s | 27.2876 KOps/s | |
test_step_mdp_speed[False-True-False-True-False] | 68.7890μs | 23.4231μs | 42.6929 KOps/s | 40.9925 KOps/s | |
test_step_mdp_speed[False-True-False-False-True] | 97.6956ms | 26.5549μs | 37.6578 KOps/s | 41.5636 KOps/s | |
test_step_mdp_speed[False-True-False-False-False] | 49.8440μs | 14.8641μs | 67.2763 KOps/s | 63.4518 KOps/s | |
test_step_mdp_speed[False-False-True-True-True] | 82.8050μs | 36.4681μs | 27.4212 KOps/s | 25.9151 KOps/s | |
test_step_mdp_speed[False-False-True-True-False] | 60.2630μs | 25.2497μs | 39.6044 KOps/s | 37.5584 KOps/s | |
test_step_mdp_speed[False-False-True-False-True] | 62.0160μs | 23.0438μs | 43.3956 KOps/s | 40.8839 KOps/s | |
test_step_mdp_speed[False-False-True-False-False] | 40.9470μs | 15.1518μs | 65.9986 KOps/s | 63.3810 KOps/s | |
test_step_mdp_speed[False-False-False-True-True] | 0.1012ms | 38.2845μs | 26.1202 KOps/s | 24.5447 KOps/s | |
test_step_mdp_speed[False-False-False-True-False] | 59.3310μs | 27.3042μs | 36.6244 KOps/s | 35.5084 KOps/s | |
test_step_mdp_speed[False-False-False-False-True] | 58.0090μs | 24.6640μs | 40.5449 KOps/s | 38.7338 KOps/s | |
test_step_mdp_speed[False-False-False-False-False] | 64.8800μs | 17.1673μs | 58.2502 KOps/s | 56.9587 KOps/s | |
test_values[generalized_advantage_estimate-True-True] | 10.1535ms | 9.5063ms | 105.1936 Ops/s | 105.6888 Ops/s | |
test_values[vec_generalized_advantage_estimate-True-True] | 37.6139ms | 35.5793ms | 28.1062 Ops/s | 29.7845 Ops/s | |
test_values[td0_return_estimate-False-False] | 0.2564ms | 0.1933ms | 5.1727 KOps/s | 5.9017 KOps/s | |
test_values[td1_return_estimate-False-False] | 23.8754ms | 23.4087ms | 42.7192 Ops/s | 42.3657 Ops/s | |
test_values[vec_td1_return_estimate-False-False] | 42.3615ms | 35.7409ms | 27.9791 Ops/s | 29.7523 Ops/s | |
test_values[td_lambda_return_estimate-True-False] | 37.3725ms | 33.4722ms | 29.8756 Ops/s | 29.3473 Ops/s | |
test_values[vec_td_lambda_return_estimate-True-False] | 37.5785ms | 35.5888ms | 28.0987 Ops/s | 29.7561 Ops/s | |
test_gae_speed[generalized_advantage_estimate-False-1-512] | 8.4962ms | 8.3241ms | 120.1325 Ops/s | 119.5437 Ops/s | |
test_gae_speed[vec_generalized_advantage_estimate-True-1-512] | 2.2399ms | 2.0081ms | 497.9938 Ops/s | 511.4219 Ops/s | |
test_gae_speed[vec_generalized_advantage_estimate-False-1-512] | 0.5165ms | 0.3524ms | 2.8378 KOps/s | 2.7788 KOps/s | |
test_gae_speed[vec_generalized_advantage_estimate-True-32-512] | 48.6522ms | 45.0280ms | 22.2084 Ops/s | 22.5544 Ops/s | |
test_gae_speed[vec_generalized_advantage_estimate-False-32-512] | 3.7938ms | 3.0376ms | 329.2069 Ops/s | 329.4319 Ops/s | |
test_dqn_speed[False-None] | 2.7828ms | 1.3709ms | 729.4722 Ops/s | 727.8601 Ops/s | |
test_dqn_speed[False-backward] | 1.9284ms | 1.8323ms | 545.7668 Ops/s | 533.8205 Ops/s | |
test_dqn_speed[True-None] | 0.7501ms | 0.4735ms | 2.1120 KOps/s | 2.1262 KOps/s | |
test_dqn_speed[True-backward] | 0.9490ms | 0.8955ms | 1.1166 KOps/s | 1.1225 KOps/s | |
test_dqn_speed[reduce-overhead-None] | 0.7364ms | 0.4803ms | 2.0821 KOps/s | 2.0979 KOps/s | |
test_dqn_speed[reduce-overhead-backward] | 0.9536ms | 0.8969ms | 1.1150 KOps/s | 1.1107 KOps/s | |
test_ddpg_speed[False-None] | 3.4658ms | 2.8613ms | 349.4952 Ops/s | 356.4170 Ops/s | |
test_ddpg_speed[False-backward] | 4.1322ms | 3.9862ms | 250.8673 Ops/s | 253.7159 Ops/s | |
test_ddpg_speed[True-None] | 1.4948ms | 1.0216ms | 978.8905 Ops/s | 994.8040 Ops/s | |
test_ddpg_speed[True-backward] | 1.9682ms | 1.9052ms | 524.8913 Ops/s | 446.9471 Ops/s | |
test_ddpg_speed[reduce-overhead-None] | 1.1724ms | 1.0160ms | 984.2819 Ops/s | 994.6256 Ops/s | |
test_ddpg_speed[reduce-overhead-backward] | 1.9916ms | 1.9199ms | 520.8561 Ops/s | 530.5882 Ops/s | |
test_sac_speed[False-None] | 9.5701ms | 8.1056ms | 123.3716 Ops/s | 125.4971 Ops/s | |
test_sac_speed[False-backward] | 12.5792ms | 10.8777ms | 91.9313 Ops/s | 93.0994 Ops/s | |
test_sac_speed[True-None] | 2.1152ms | 1.8587ms | 538.0132 Ops/s | 532.8648 Ops/s | |
test_sac_speed[True-backward] | 3.6684ms | 3.5626ms | 280.6945 Ops/s | 278.8707 Ops/s | |
test_sac_speed[reduce-overhead-None] | 2.0956ms | 1.8790ms | 532.2030 Ops/s | 537.1854 Ops/s | |
test_sac_speed[reduce-overhead-backward] | 3.7312ms | 3.5961ms | 278.0801 Ops/s | 282.1230 Ops/s | |
test_redq_speed[False-None] | 15.4134ms | 12.8024ms | 78.1103 Ops/s | 77.1090 Ops/s | |
test_redq_speed[False-backward] | 24.2314ms | 22.1244ms | 45.1990 Ops/s | 44.4683 Ops/s | |
test_redq_speed[True-None] | 6.1211ms | 4.6795ms | 213.6970 Ops/s | 202.8666 Ops/s | |
test_redq_speed[True-backward] | 13.1361ms | 12.4698ms | 80.1935 Ops/s | 81.9055 Ops/s | |
test_redq_speed[reduce-overhead-None] | 12.3760ms | 4.9883ms | 200.4680 Ops/s | 207.2844 Ops/s | |
test_redq_speed[reduce-overhead-backward] | 13.3105ms | 12.5634ms | 79.5965 Ops/s | 81.0529 Ops/s | |
test_redq_deprec_speed[False-None] | 14.7993ms | 13.1741ms | 75.9066 Ops/s | 78.9891 Ops/s | |
test_redq_deprec_speed[False-backward] | 20.1031ms | 19.0969ms | 52.3644 Ops/s | 54.2687 Ops/s | |
test_redq_deprec_speed[True-None] | 4.9542ms | 3.6910ms | 270.9318 Ops/s | 279.8576 Ops/s | |
test_redq_deprec_speed[True-backward] | 8.3832ms | 8.0469ms | 124.2717 Ops/s | 119.7055 Ops/s | |
test_redq_deprec_speed[reduce-overhead-None] | 4.6763ms | 3.5868ms | 278.8035 Ops/s | 277.6527 Ops/s | |
test_redq_deprec_speed[reduce-overhead-backward] | 9.6927ms | 8.3654ms | 119.5405 Ops/s | 119.7957 Ops/s | |
test_td3_speed[False-None] | 36.9779ms | 8.3883ms | 119.2136 Ops/s | 125.8498 Ops/s | |
test_td3_speed[False-backward] | 11.7708ms | 10.6218ms | 94.1456 Ops/s | 96.1111 Ops/s | |
test_td3_speed[True-None] | 2.0945ms | 1.8218ms | 548.8988 Ops/s | 572.4395 Ops/s | |
test_td3_speed[True-backward] | 4.1117ms | 3.5100ms | 284.9002 Ops/s | 294.2988 Ops/s | |
test_td3_speed[reduce-overhead-None] | 1.9795ms | 1.8066ms | 553.5305 Ops/s | 572.2651 Ops/s | |
test_td3_speed[reduce-overhead-backward] | 3.7711ms | 3.6053ms | 277.3680 Ops/s | 293.9764 Ops/s | |
test_cql_speed[False-None] | 39.4885ms | 36.4748ms | 27.4162 Ops/s | 28.1819 Ops/s | |
test_cql_speed[False-backward] | 48.8861ms | 46.7136ms | 21.4071 Ops/s | 21.1829 Ops/s | |
test_cql_speed[True-None] | 17.4333ms | 16.0874ms | 62.1604 Ops/s | 62.2612 Ops/s | |
test_cql_speed[True-backward] | 25.7086ms | 23.6697ms | 42.2481 Ops/s | 43.4875 Ops/s | |
test_cql_speed[reduce-overhead-None] | 18.0498ms | 16.6484ms | 60.0660 Ops/s | 62.6842 Ops/s | |
test_cql_speed[reduce-overhead-backward] | 24.8894ms | 23.1761ms | 43.1479 Ops/s | 44.1836 Ops/s | |
test_a2c_speed[False-None] | 9.0947ms | 7.5043ms | 133.2564 Ops/s | 139.0123 Ops/s | |
test_a2c_speed[False-backward] | 17.6603ms | 15.3810ms | 65.0154 Ops/s | 68.6003 Ops/s | |
test_a2c_speed[True-None] | 3.9146ms | 3.3258ms | 300.6794 Ops/s | 301.0125 Ops/s | |
test_a2c_speed[True-backward] | 10.5584ms | 9.9599ms | 100.4029 Ops/s | 101.5352 Ops/s | |
test_a2c_speed[reduce-overhead-None] | 4.1489ms | 3.3884ms | 295.1225 Ops/s | 297.5162 Ops/s | |
test_a2c_speed[reduce-overhead-backward] | 10.1533ms | 9.8087ms | 101.9503 Ops/s | 101.9094 Ops/s | |
test_ppo_speed[False-None] | 8.6283ms | 7.4142ms | 134.8767 Ops/s | 135.6811 Ops/s | |
test_ppo_speed[False-backward] | 15.5033ms | 15.0901ms | 66.2687 Ops/s | 67.8905 Ops/s | |
test_ppo_speed[True-None] | 5.1997ms | 3.7382ms | 267.5114 Ops/s | 265.9081 Ops/s | |
test_ppo_speed[True-backward] | 12.1863ms | 10.1894ms | 98.1411 Ops/s | 100.4458 Ops/s | |
test_ppo_speed[reduce-overhead-None] | 4.1525ms | 3.7253ms | 268.4350 Ops/s | 266.7200 Ops/s | |
test_ppo_speed[reduce-overhead-backward] | 10.3312ms | 9.6531ms | 103.5933 Ops/s | 103.4742 Ops/s | |
test_reinforce_speed[False-None] | 7.9058ms | 6.4485ms | 155.0744 Ops/s | 153.4187 Ops/s | |
test_reinforce_speed[False-backward] | 10.1263ms | 9.7472ms | 102.5941 Ops/s | 101.7271 Ops/s | |
test_reinforce_speed[True-None] | 3.6245ms | 2.6495ms | 377.4362 Ops/s | 368.4567 Ops/s | |
test_reinforce_speed[True-backward] | 11.0073ms | 8.6826ms | 115.1733 Ops/s | 115.6296 Ops/s | |
test_reinforce_speed[reduce-overhead-None] | 3.0285ms | 2.6955ms | 370.9941 Ops/s | 374.7667 Ops/s | |
test_reinforce_speed[reduce-overhead-backward] | 9.1403ms | 8.7072ms | 114.8470 Ops/s | 113.5205 Ops/s | |
test_iql_speed[False-None] | 35.2701ms | 32.4785ms | 30.7896 Ops/s | 30.4212 Ops/s | |
test_iql_speed[False-backward] | 46.8691ms | 45.3211ms | 22.0648 Ops/s | 22.0379 Ops/s | |
test_iql_speed[True-None] | 11.4810ms | 10.7676ms | 92.8714 Ops/s | 90.4795 Ops/s | |
test_iql_speed[True-backward] | 23.8691ms | 22.1153ms | 45.2176 Ops/s | 44.5736 Ops/s | |
test_iql_speed[reduce-overhead-None] | 11.8985ms | 10.9461ms | 91.3565 Ops/s | 88.5938 Ops/s | |
test_iql_speed[reduce-overhead-backward] | 22.8757ms | 21.8683ms | 45.7282 Ops/s | 45.0874 Ops/s | |
test_rb_sample[TensorDictReplayBuffer-ListStorage-RandomSampler-4000] | 5.0681ms | 4.8582ms | 205.8372 Ops/s | 197.8925 Ops/s | |
test_rb_sample[TensorDictReplayBuffer-LazyMemmapStorage-RandomSampler-10000] | 2.3352ms | 0.4834ms | 2.0686 KOps/s | 2.0797 KOps/s | |
test_rb_sample[TensorDictReplayBuffer-LazyTensorStorage-RandomSampler-10000] | 0.7510ms | 0.4631ms | 2.1592 KOps/s | 2.2060 KOps/s | |
test_rb_sample[TensorDictReplayBuffer-ListStorage-SamplerWithoutReplacement-4000] | 5.1545ms | 4.7271ms | 211.5450 Ops/s | 208.7561 Ops/s | |
test_rb_sample[TensorDictReplayBuffer-LazyMemmapStorage-SamplerWithoutReplacement-10000] | 2.6756ms | 0.4779ms | 2.0924 KOps/s | 2.1104 KOps/s | |
test_rb_sample[TensorDictReplayBuffer-LazyTensorStorage-SamplerWithoutReplacement-10000] | 0.7741ms | 0.4565ms | 2.1904 KOps/s | 2.1803 KOps/s | |
test_rb_sample[TensorDictReplayBuffer-LazyMemmapStorage-sampler6-10000] | 1.8074ms | 1.5727ms | 635.8507 Ops/s | 624.4942 Ops/s | |
test_rb_sample[TensorDictReplayBuffer-LazyTensorStorage-sampler7-10000] | 1.8214ms | 1.5244ms | 656.0112 Ops/s | 649.8708 Ops/s | |
test_rb_sample[TensorDictPrioritizedReplayBuffer-ListStorage-None-4000] | 7.3833ms | 4.8168ms | 207.6050 Ops/s | 197.5627 Ops/s | |
test_rb_sample[TensorDictPrioritizedReplayBuffer-LazyMemmapStorage-None-10000] | 1.2437ms | 0.6223ms | 1.6068 KOps/s | 1.6222 KOps/s | |
test_rb_sample[TensorDictPrioritizedReplayBuffer-LazyTensorStorage-None-10000] | 0.8746ms | 0.5973ms | 1.6741 KOps/s | 1.6917 KOps/s | |
test_rb_iterate[TensorDictReplayBuffer-ListStorage-RandomSampler-4000] | 5.1187ms | 4.6895ms | 213.2408 Ops/s | 206.8919 Ops/s | |
test_rb_iterate[TensorDictReplayBuffer-LazyMemmapStorage-RandomSampler-10000] | 1.0070ms | 0.4845ms | 2.0640 KOps/s | 2.0710 KOps/s | |
test_rb_iterate[TensorDictReplayBuffer-LazyTensorStorage-RandomSampler-10000] | 0.6349ms | 0.4559ms | 2.1935 KOps/s | 2.1528 KOps/s | |
test_rb_iterate[TensorDictReplayBuffer-ListStorage-SamplerWithoutReplacement-4000] | 6.9956ms | 4.6821ms | 213.5808 Ops/s | 208.5125 Ops/s | |
test_rb_iterate[TensorDictReplayBuffer-LazyMemmapStorage-SamplerWithoutReplacement-10000] | 3.0759ms | 0.5331ms | 1.8757 KOps/s | 2.1064 KOps/s | |
test_rb_iterate[TensorDictReplayBuffer-LazyTensorStorage-SamplerWithoutReplacement-10000] | 0.6965ms | 0.4596ms | 2.1757 KOps/s | 2.2203 KOps/s | |
test_rb_iterate[TensorDictPrioritizedReplayBuffer-ListStorage-None-4000] | 5.6248ms | 4.8763ms | 205.0722 Ops/s | 203.7481 Ops/s | |
test_rb_iterate[TensorDictPrioritizedReplayBuffer-LazyMemmapStorage-None-10000] | 0.8201ms | 0.6193ms | 1.6146 KOps/s | 1.6098 KOps/s | |
test_rb_iterate[TensorDictPrioritizedReplayBuffer-LazyTensorStorage-None-10000] | 8.1579ms | 0.6080ms | 1.6446 KOps/s | 1.6734 KOps/s | |
test_rb_populate[TensorDictReplayBuffer-ListStorage-RandomSampler-400] | 6.0510ms | 4.3319ms | 230.8460 Ops/s | 36.0966 Ops/s | |
test_rb_populate[TensorDictReplayBuffer-LazyMemmapStorage-RandomSampler-400] | 8.1175ms | 2.3456ms | 426.3366 Ops/s | 453.2282 Ops/s | |
test_rb_populate[TensorDictReplayBuffer-LazyTensorStorage-RandomSampler-400] | 5.9596ms | 1.3445ms | 743.7734 Ops/s | 742.6659 Ops/s | |
test_rb_populate[TensorDictReplayBuffer-ListStorage-SamplerWithoutReplacement-400] | 0.4396s | 12.9550ms | 77.1901 Ops/s | 230.1522 Ops/s | |
test_rb_populate[TensorDictReplayBuffer-LazyMemmapStorage-SamplerWithoutReplacement-400] | 8.2310ms | 2.3546ms | 424.6969 Ops/s | 439.2781 Ops/s | |
test_rb_populate[TensorDictReplayBuffer-LazyTensorStorage-SamplerWithoutReplacement-400] | 4.4607ms | 1.3292ms | 752.3144 Ops/s | 785.3182 Ops/s | |
test_rb_populate[TensorDictPrioritizedReplayBuffer-ListStorage-None-400] | 5.5856ms | 4.3052ms | 232.2789 Ops/s | 220.2234 Ops/s | |
test_rb_populate[TensorDictPrioritizedReplayBuffer-LazyMemmapStorage-None-400] | 10.5374ms | 2.6357ms | 379.4098 Ops/s | 397.8122 Ops/s | |
test_rb_populate[TensorDictPrioritizedReplayBuffer-LazyTensorStorage-None-400] | 2.0697ms | 1.4069ms | 710.7631 Ops/s | 733.6930 Ops/s |

Name | Max | Mean | Ops | Ops on Repo HEAD | Change |
---|---|---|---|---|---|
test_simple | 0.7041s | 0.7022s | 1.4241 Ops/s | 1.4165 Ops/s | |
test_transformed | 1.0550s | 0.9572s | 1.0447 Ops/s | 1.0589 Ops/s | |
test_serial | 2.1496s | 2.0544s | 0.4868 Ops/s | 0.4869 Ops/s | |
test_parallel | 2.0600s | 1.9591s | 0.5105 Ops/s | 0.5289 Ops/s | |
test_step_mdp_speed[True-True-True-True-True] | 71.1310μs | 36.1031μs | 27.6984 KOps/s | 25.8558 KOps/s | |
test_step_mdp_speed[True-True-True-True-False] | 0.1614ms | 21.3459μs | 46.8473 KOps/s | 44.5098 KOps/s | |
test_step_mdp_speed[True-True-True-False-True] | 47.9400μs | 19.5307μs | 51.2015 KOps/s | 48.5176 KOps/s | |
test_step_mdp_speed[True-True-True-False-False] | 44.7810μs | 11.5173μs | 86.8258 KOps/s | 82.9814 KOps/s | |
test_step_mdp_speed[True-True-False-True-True] | 0.1443ms | 39.3936μs | 25.3849 KOps/s | 24.0970 KOps/s | |
test_step_mdp_speed[True-True-False-True-False] | 50.2210μs | 23.6092μs | 42.3564 KOps/s | 40.3139 KOps/s | |
test_step_mdp_speed[True-True-False-False-True] | 0.1947ms | 22.6559μs | 44.1385 KOps/s | 42.4887 KOps/s | |
test_step_mdp_speed[True-True-False-False-False] | 0.1843ms | 14.1837μs | 70.5034 KOps/s | 67.2520 KOps/s | |
test_step_mdp_speed[True-False-True-True-True] | 0.2236ms | 42.0192μs | 23.7986 KOps/s | 22.4917 KOps/s | |
test_step_mdp_speed[True-False-True-True-False] | 0.2048ms | 26.0460μs | 38.3936 KOps/s | 35.9477 KOps/s | |
test_step_mdp_speed[True-False-True-False-True] | 53.4010μs | 22.5759μs | 44.2950 KOps/s | 41.5471 KOps/s | |
test_step_mdp_speed[True-False-True-False-False] | 40.5110μs | 14.1449μs | 70.6971 KOps/s | 67.1623 KOps/s | |
test_step_mdp_speed[True-False-False-True-True] | 72.6020μs | 44.0196μs | 22.7172 KOps/s | 21.3566 KOps/s | |
test_step_mdp_speed[True-False-False-True-False] | 55.7010μs | 28.7983μs | 34.7243 KOps/s | 32.7255 KOps/s | |
test_step_mdp_speed[True-False-False-False-True] | 53.2710μs | 24.9307μs | 40.1113 KOps/s | 38.1760 KOps/s | |
test_step_mdp_speed[True-False-False-False-False] | 43.5710μs | 16.5816μs | 60.3078 KOps/s | 56.7203 KOps/s | |
test_step_mdp_speed[False-True-True-True-True] | 78.1820μs | 41.8832μs | 23.8759 KOps/s | 23.6771 KOps/s | |
test_step_mdp_speed[False-True-True-True-False] | 0.1692ms | 26.3442μs | 37.9590 KOps/s | 35.9572 KOps/s | |
test_step_mdp_speed[False-True-True-False-True] | 57.7810μs | 27.0127μs | 37.0197 KOps/s | 35.2293 KOps/s | |
test_step_mdp_speed[False-True-True-False-False] | 78.5910μs | 16.5726μs | 60.3407 KOps/s | 57.3203 KOps/s | |
test_step_mdp_speed[False-True-False-True-True] | 77.3120μs | 44.0312μs | 22.7112 KOps/s | 21.4173 KOps/s | |
test_step_mdp_speed[False-True-False-True-False] | 55.3510μs | 28.5188μs | 35.0646 KOps/s | 33.1333 KOps/s | |
test_step_mdp_speed[False-True-False-False-True] | 3.4731ms | 29.4755μs | 33.9264 KOps/s | 31.8791 KOps/s | |
test_step_mdp_speed[False-True-False-False-False] | 50.8310μs | 18.7196μs | 53.4200 KOps/s | 50.0960 KOps/s | |
test_step_mdp_speed[False-False-True-True-True] | 74.4210μs | 46.3817μs | 21.5602 KOps/s | 20.3542 KOps/s | |
test_step_mdp_speed[False-False-True-True-False] | 59.3510μs | 31.0395μs | 32.2170 KOps/s | 30.3985 KOps/s | |
test_step_mdp_speed[False-False-True-False-True] | 58.0610μs | 28.7239μs | 34.8142 KOps/s | 32.1220 KOps/s | |
test_step_mdp_speed[False-False-True-False-False] | 79.5410μs | 18.1688μs | 55.0394 KOps/s | 49.7276 KOps/s | |
test_step_mdp_speed[False-False-False-True-True] | 84.0320μs | 48.2478μs | 20.7263 KOps/s | 19.6078 KOps/s | |
test_step_mdp_speed[False-False-False-True-False] | 57.0800μs | 33.6445μs | 29.7226 KOps/s | 28.5329 KOps/s | |
test_step_mdp_speed[False-False-False-False-True] | 64.3210μs | 31.6450μs | 31.6005 KOps/s | 30.1759 KOps/s | |
test_step_mdp_speed[False-False-False-False-False] | 60.7310μs | 20.6651μs | 48.3908 KOps/s | 44.7350 KOps/s | |
test_values[generalized_advantage_estimate-True-True] | 23.6861ms | 23.1210ms | 43.2507 Ops/s | 42.5579 Ops/s | |
test_values[vec_generalized_advantage_estimate-True-True] | 93.8948ms | 2.7448ms | 364.3194 Ops/s | 332.8076 Ops/s | |
test_values[td0_return_estimate-False-False] | 85.6810μs | 62.5202μs | 15.9948 KOps/s | 16.0819 KOps/s | |
test_values[td1_return_estimate-False-False] | 51.9276ms | 51.4050ms | 19.4534 Ops/s | 19.2002 Ops/s | |
test_values[vec_td1_return_estimate-False-False] | 1.3193ms | 1.0371ms | 964.2261 Ops/s | 962.7549 Ops/s | |
test_values[td_lambda_return_estimate-True-False] | 88.7955ms | 83.6268ms | 11.9579 Ops/s | 12.1039 Ops/s | |
test_values[vec_td_lambda_return_estimate-True-False] | 1.3551ms | 1.0375ms | 963.8774 Ops/s | 967.9770 Ops/s | |
test_gae_speed[generalized_advantage_estimate-False-1-512] | 23.1675ms | 22.7821ms | 43.8941 Ops/s | 43.5682 Ops/s | |
test_gae_speed[vec_generalized_advantage_estimate-True-1-512] | 1.0849ms | 0.7110ms | 1.4064 KOps/s | 1.4081 KOps/s | |
test_gae_speed[vec_generalized_advantage_estimate-False-1-512] | 0.7933ms | 0.6243ms | 1.6017 KOps/s | 1.5905 KOps/s | |
test_gae_speed[vec_generalized_advantage_estimate-True-32-512] | 1.7212ms | 1.4605ms | 684.6894 Ops/s | 695.7065 Ops/s | |
test_gae_speed[vec_generalized_advantage_estimate-False-32-512] | 0.8191ms | 0.6375ms | 1.5685 KOps/s | 1.5284 KOps/s | |
test_dqn_speed[False-None] | 1.5203ms | 1.3153ms | 760.3013 Ops/s | 647.8920 Ops/s | |
test_dqn_speed[False-backward] | 1.9478ms | 1.8190ms | 549.7661 Ops/s | 547.3781 Ops/s | |
test_dqn_speed[True-None] | 0.7629ms | 0.5564ms | 1.7972 KOps/s | 1.7780 KOps/s | |
test_dqn_speed[True-backward] | 1.1823ms | 1.0171ms | 983.1583 Ops/s | 809.2453 Ops/s | |
test_dqn_speed[reduce-overhead-None] | 1.7042ms | 0.5712ms | 1.7508 KOps/s | 1.4191 KOps/s | |
test_dqn_speed[reduce-overhead-backward] | 1.0955ms | 1.0150ms | 985.2162 Ops/s | 804.2388 Ops/s | |
test_ddpg_speed[False-None] | 3.2608ms | 2.7421ms | 364.6894 Ops/s | 365.3727 Ops/s | |
test_ddpg_speed[False-backward] | 4.0969ms | 3.9597ms | 252.5440 Ops/s | 247.4289 Ops/s | |
test_ddpg_speed[True-None] | 1.4331ms | 1.2541ms | 797.3699 Ops/s | 776.8978 Ops/s | |
test_ddpg_speed[True-backward] | 2.5226ms | 2.2669ms | 441.1263 Ops/s | 382.1305 Ops/s | |
test_ddpg_speed[reduce-overhead-None] | 1.5390ms | 1.2579ms | 794.9626 Ops/s | 793.8581 Ops/s | |
test_ddpg_speed[reduce-overhead-backward] | 2.4342ms | 2.2599ms | 442.5002 Ops/s | 443.0260 Ops/s | |
test_sac_speed[False-None] | 8.1350ms | 7.5123ms | 133.1146 Ops/s | 132.5480 Ops/s | |
test_sac_speed[False-backward] | 11.4770ms | 10.8908ms | 91.8207 Ops/s | 91.9327 Ops/s | |
test_sac_speed[True-None] | 2.4791ms | 2.1238ms | 470.8611 Ops/s | 463.4598 Ops/s | |
test_sac_speed[True-backward] | 4.2743ms | 4.1259ms | 242.3708 Ops/s | 232.1486 Ops/s | |
test_sac_speed[reduce-overhead-None] | 2.4792ms | 2.1245ms | 470.6945 Ops/s | 476.8346 Ops/s | |
test_sac_speed[reduce-overhead-backward] | 4.2860ms | 4.0755ms | 245.3679 Ops/s | 249.2840 Ops/s | |
test_redq_speed[False-None] | 16.1329ms | 11.1624ms | 89.5864 Ops/s | 75.3317 Ops/s | |
test_redq_speed[False-backward] | 19.5463ms | 18.5645ms | 53.8662 Ops/s | 53.0877 Ops/s | |
test_redq_speed[True-None] | 4.3581ms | 3.7134ms | 269.2978 Ops/s | 254.8512 Ops/s | |
test_redq_speed[True-backward] | 10.2910ms | 9.1759ms | 108.9806 Ops/s | 105.8938 Ops/s | |
test_redq_speed[reduce-overhead-None] | 3.9702ms | 3.6956ms | 270.5889 Ops/s | 263.2169 Ops/s | |
test_redq_speed[reduce-overhead-backward] | 10.1759ms | 9.3242ms | 107.2473 Ops/s | 105.0482 Ops/s | |
test_redq_deprec_speed[False-None] | 11.4784ms | 10.8522ms | 92.1473 Ops/s | 91.1206 Ops/s | |
test_redq_deprec_speed[False-backward] | 16.4515ms | 16.0551ms | 62.2856 Ops/s | 61.5394 Ops/s | |
test_redq_deprec_speed[True-None] | 3.7022ms | 3.2913ms | 303.8308 Ops/s | 297.7743 Ops/s | |
test_redq_deprec_speed[True-backward] | 7.8783ms | 7.4457ms | 134.3062 Ops/s | 126.2340 Ops/s | |
test_redq_deprec_speed[reduce-overhead-None] | 3.6852ms | 3.3007ms | 302.9657 Ops/s | 292.5213 Ops/s | |
test_redq_deprec_speed[reduce-overhead-backward] | 8.0161ms | 7.5338ms | 132.7346 Ops/s | 131.8465 Ops/s | |
test_td3_speed[False-None] | 7.7137ms | 7.5420ms | 132.5905 Ops/s | 130.4172 Ops/s | |
test_td3_speed[False-backward] | 11.2323ms | 10.6081ms | 94.2679 Ops/s | 94.9340 Ops/s | |
test_td3_speed[True-None] | 2.0706ms | 1.9894ms | 502.6738 Ops/s | 499.0550 Ops/s | |
test_td3_speed[True-backward] | 4.1430ms | 3.8363ms | 260.6659 Ops/s | 260.5497 Ops/s | |
test_td3_speed[reduce-overhead-None] | 2.0003ms | 1.9649ms | 508.9382 Ops/s | 499.2014 Ops/s | |
test_td3_speed[reduce-overhead-backward] | 4.1460ms | 3.8566ms | 259.2937 Ops/s | 263.0388 Ops/s | |
test_cql_speed[False-None] | 31.7065ms | 26.3268ms | 37.9841 Ops/s | 39.1085 Ops/s | |
test_cql_speed[False-backward] | 41.4241ms | 36.9320ms | 27.0768 Ops/s | 28.4551 Ops/s | |
test_cql_speed[True-None] | 12.1801ms | 11.6075ms | 86.1514 Ops/s | 88.2029 Ops/s | |
test_cql_speed[True-backward] | 18.3614ms | 17.5700ms | 56.9153 Ops/s | 56.9250 Ops/s | |
test_cql_speed[reduce-overhead-None] | 12.2809ms | 11.5305ms | 86.7268 Ops/s | 88.2139 Ops/s | |
test_cql_speed[reduce-overhead-backward] | 18.8248ms | 17.8393ms | 56.0562 Ops/s | 56.7907 Ops/s | |
test_a2c_speed[False-None] | 6.1024ms | 5.4643ms | 183.0060 Ops/s | 178.9941 Ops/s | |
test_a2c_speed[False-backward] | 12.7038ms | 12.2542ms | 81.6044 Ops/s | 80.0081 Ops/s | |
test_a2c_speed[True-None] | 3.3791ms | 3.1388ms | 318.5949 Ops/s | 319.4137 Ops/s | |
test_a2c_speed[True-backward] | 9.2093ms | 8.7904ms | 113.7603 Ops/s | 110.9467 Ops/s | |
test_a2c_speed[reduce-overhead-None] | 3.3326ms | 3.1150ms | 321.0321 Ops/s | 315.8121 Ops/s | |
test_a2c_speed[reduce-overhead-backward] | 9.2675ms | 8.7550ms | 114.2209 Ops/s | 113.0240 Ops/s | |
test_ppo_speed[False-None] | 6.0515ms | 5.7217ms | 174.7740 Ops/s | 174.0278 Ops/s | |
test_ppo_speed[False-backward] | 13.4679ms | 12.7976ms | 78.1399 Ops/s | 76.8019 Ops/s | |
test_ppo_speed[True-None] | 3.7148ms | 3.5424ms | 282.2945 Ops/s | 280.2530 Ops/s | |
test_ppo_speed[True-backward] | 9.1318ms | 8.5671ms | 116.7255 Ops/s | 115.3405 Ops/s | |
test_ppo_speed[reduce-overhead-None] | 4.0117ms | 3.5439ms | 282.1714 Ops/s | 285.1664 Ops/s | |
test_ppo_speed[reduce-overhead-backward] | 8.8672ms | 8.4894ms | 117.7935 Ops/s | 115.0187 Ops/s | |
test_reinforce_speed[False-None] | 4.8908ms | 4.4806ms | 223.1862 Ops/s | 213.8430 Ops/s | |
test_reinforce_speed[False-backward] | 7.6582ms | 7.4064ms | 135.0186 Ops/s | 127.9899 Ops/s | |
test_reinforce_speed[True-None] | 2.5574ms | 2.2968ms | 435.3915 Ops/s | 433.6441 Ops/s | |
test_reinforce_speed[True-backward] | 7.8296ms | 7.4783ms | 133.7209 Ops/s | 133.6615 Ops/s | |
test_reinforce_speed[reduce-overhead-None] | 2.5274ms | 2.3106ms | 432.7862 Ops/s | 423.1785 Ops/s | |
test_reinforce_speed[reduce-overhead-backward] | 7.7904ms | 7.3653ms | 135.7715 Ops/s | 135.6560 Ops/s | |
test_iql_speed[False-None] | 21.6331ms | 20.5482ms | 48.6662 Ops/s | 49.7373 Ops/s | |
test_iql_speed[False-backward] | 32.6285ms | 31.8282ms | 31.4187 Ops/s | 32.1354 Ops/s | |
test_iql_speed[True-None] | 7.9432ms | 7.2306ms | 138.3006 Ops/s | 126.3156 Ops/s | |
test_iql_speed[True-backward] | 17.2399ms | 16.3307ms | 61.2345 Ops/s | 61.0730 Ops/s | |
test_iql_speed[reduce-overhead-None] | 7.7781ms | 7.1677ms | 139.5140 Ops/s | 137.3571 Ops/s | |
test_iql_speed[reduce-overhead-backward] | 17.0612ms | 16.3590ms | 61.1283 Ops/s | 61.7007 Ops/s | |
test_rb_sample[TensorDictReplayBuffer-ListStorage-RandomSampler-4000] | 6.5417ms | 6.2072ms | 161.1031 Ops/s | 163.3355 Ops/s | |
test_rb_sample[TensorDictReplayBuffer-LazyMemmapStorage-RandomSampler-10000] | 1.9403ms | 0.2799ms | 3.5727 KOps/s | 3.0749 KOps/s | |
test_rb_sample[TensorDictReplayBuffer-LazyTensorStorage-RandomSampler-10000] | 0.5943ms | 0.2907ms | 3.4395 KOps/s | 3.2784 KOps/s | |
test_rb_sample[TensorDictReplayBuffer-ListStorage-SamplerWithoutReplacement-4000] | 6.5842ms | 6.0521ms | 165.2320 Ops/s | 169.2039 Ops/s | |
test_rb_sample[TensorDictReplayBuffer-LazyMemmapStorage-SamplerWithoutReplacement-10000] | 2.8991ms | 0.3033ms | 3.2967 KOps/s | 3.1308 KOps/s | |
test_rb_sample[TensorDictReplayBuffer-LazyTensorStorage-SamplerWithoutReplacement-10000] | 0.4795ms | 0.2459ms | 4.0668 KOps/s | 3.2971 KOps/s | |
test_rb_sample[TensorDictReplayBuffer-LazyMemmapStorage-sampler6-10000] | 1.6520ms | 1.3577ms | 736.5445 Ops/s | 748.5058 Ops/s | |
test_rb_sample[TensorDictReplayBuffer-LazyTensorStorage-sampler7-10000] | 1.5588ms | 1.3131ms | 761.5423 Ops/s | 783.4220 Ops/s | |
test_rb_sample[TensorDictPrioritizedReplayBuffer-ListStorage-None-4000] | 6.5750ms | 6.1916ms | 161.5090 Ops/s | 163.8144 Ops/s | |
test_rb_sample[TensorDictPrioritizedReplayBuffer-LazyMemmapStorage-None-10000] | 2.3627ms | 0.4592ms | 2.1777 KOps/s | 2.0119 KOps/s | |
test_rb_sample[TensorDictPrioritizedReplayBuffer-LazyTensorStorage-None-10000] | 0.6997ms | 0.4143ms | 2.4138 KOps/s | 2.6344 KOps/s | |
test_rb_iterate[TensorDictReplayBuffer-ListStorage-RandomSampler-4000] | 6.3223ms | 6.0368ms | 165.6494 Ops/s | 166.2714 Ops/s | |
test_rb_iterate[TensorDictReplayBuffer-LazyMemmapStorage-RandomSampler-10000] | 2.3773ms | 0.3089ms | 3.2370 KOps/s | 3.0152 KOps/s | |
test_rb_iterate[TensorDictReplayBuffer-LazyTensorStorage-RandomSampler-10000] | 0.5884ms | 0.2815ms | 3.5529 KOps/s | 3.0106 KOps/s | |
test_rb_iterate[TensorDictReplayBuffer-ListStorage-SamplerWithoutReplacement-4000] | 6.3788ms | 6.0426ms | 165.4908 Ops/s | 165.7667 Ops/s | |
test_rb_iterate[TensorDictReplayBuffer-LazyMemmapStorage-SamplerWithoutReplacement-10000] | 0.9369ms | 0.2872ms | 3.4821 KOps/s | 3.8742 KOps/s | |
test_rb_iterate[TensorDictReplayBuffer-LazyTensorStorage-SamplerWithoutReplacement-10000] | 0.4650ms | 0.2563ms | 3.9021 KOps/s | 3.7454 KOps/s | |
test_rb_iterate[TensorDictPrioritizedReplayBuffer-ListStorage-None-4000] | 6.5257ms | 6.2113ms | 160.9967 Ops/s | 164.8756 Ops/s | |
test_rb_iterate[TensorDictPrioritizedReplayBuffer-LazyMemmapStorage-None-10000] | 2.6330ms | 0.4296ms | 2.3278 KOps/s | 2.0392 KOps/s | |
test_rb_iterate[TensorDictPrioritizedReplayBuffer-LazyTensorStorage-None-10000] | 0.7671ms | 0.4480ms | 2.2322 KOps/s | 2.5265 KOps/s | |
test_rb_populate[TensorDictReplayBuffer-ListStorage-RandomSampler-400] | 7.0933ms | 5.4541ms | 183.3478 Ops/s | 30.7749 Ops/s | |
test_rb_populate[TensorDictReplayBuffer-LazyMemmapStorage-RandomSampler-400] | 9.1596ms | 2.3112ms | 432.6826 Ops/s | 499.3654 Ops/s | |
test_rb_populate[TensorDictReplayBuffer-LazyTensorStorage-RandomSampler-400] | 7.4025ms | 1.2340ms | 810.3923 Ops/s | 774.5385 Ops/s | |
test_rb_populate[TensorDictReplayBuffer-ListStorage-SamplerWithoutReplacement-400] | 0.5060s | 15.4912ms | 64.5526 Ops/s | 179.8423 Ops/s | |
test_rb_populate[TensorDictReplayBuffer-LazyMemmapStorage-SamplerWithoutReplacement-400] | 9.3371ms | 2.0393ms | 490.3653 Ops/s | 494.0052 Ops/s | |
test_rb_populate[TensorDictReplayBuffer-LazyTensorStorage-SamplerWithoutReplacement-400] | 6.6882ms | 1.2384ms | 807.5002 Ops/s | 821.2939 Ops/s | |
test_rb_populate[TensorDictPrioritizedReplayBuffer-ListStorage-None-400] | 7.2938ms | 5.6903ms | 175.7367 Ops/s | 177.1479 Ops/s | |
test_rb_populate[TensorDictPrioritizedReplayBuffer-LazyMemmapStorage-None-400] | 11.6546ms | 2.2438ms | 445.6762 Ops/s | 453.7258 Ops/s | |
test_rb_populate[TensorDictPrioritizedReplayBuffer-LazyTensorStorage-None-400] | 2.4261ms | 1.2535ms | 797.7436 Ops/s | 671.5712 Ops/s |
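
The Change column is empty in the tables above, so the comparison between the PR and repo HEAD has to be worked out from the two Ops columns. A minimal sketch of that arithmetic follows; the helper names `parse_ops` and `relative_change` are hypothetical and not part of the benchmark tooling, and the unit handling assumes only the `Ops/s` and `KOps/s` suffixes that appear in these tables.

```python
def parse_ops(text):
    """Convert a cell such as '2.0686 KOps/s' or '230.8460 Ops/s' to ops per second."""
    value, unit = text.split()
    # Only the two units that occur in the tables above are handled (assumption).
    scale = {"Ops/s": 1.0, "KOps/s": 1e3}[unit]
    return float(value) * scale


def relative_change(ops_pr, ops_head):
    """Relative throughput change of the PR versus repo HEAD (positive means faster)."""
    pr = parse_ops(ops_pr)
    head = parse_ops(ops_head)
    return (pr - head) / head


# Two rows of test_rb_populate taken from the first table.
list_random = relative_change("230.8460 Ops/s", "36.0966 Ops/s")
list_without_replacement = relative_change("77.1901 Ops/s", "230.1522 Ops/s")
print(f"ListStorage-RandomSampler-400: {list_random:+.1%}")
print(f"ListStorage-SamplerWithoutReplacement-400: {list_without_replacement:+.1%}")
```

Both of these rows also show a Max far above the Mean in at least one run, so a single comparison like this is probably dominated by outliers and should be read with caution.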
ghstack-source-id: 9f3ab17a572ffd28a30ad4dd46305b2face65bef Pull Request resolved: #2504
This change avoids recompiles for back-to-back calls to `ReplayBuffer.extend` and `.sample` in cases where `LazyTensorStorage`, `RoundRobinWriter`, and `RandomSampler` are used and the data type is either tensor or pytree. ghstack-source-id: 70d47e2f3d34949ea648b8c1a351593774b88ce0 Pull Request resolved: #2504
FWIW, I looked into disabling the multiprocessing types like you suggested. I tried temporarily changing ...
With that change, I was still getting recompile records like this:
I also tried to make a minimal reproducer:

```python
import torch
import torch.multiprocessing as mp

# Log every recompilation triggered by torch.compile.
torch._logging.set_logs(recompiles=True)


class MyStorage:
    def __init__(self):
        # Shared integer counter holding the current storage length.
        self._len_value = mp.Value("i", 0)

    def extend(self, batch_size):
        self._len_value.value += batch_size

    def sample(self, batch_size):
        # Draw random indices in [0, len(self)).
        return torch.randint(0, len(self), (batch_size,))

    def __len__(self):
        return self._len_value.value


storage = MyStorage()


@torch.compile
def fn(batch_size):
    storage.extend(batch_size)
    return storage.sample(batch_size)


for idx in range(10):
    print('-----------------------')
    res = fn(idx+1)
    print(f'res: {res}')
    print(f'len: {len(storage)}')
```

Output:
So I'm guessing ...
This change avoids recompiles for back-to-back calls to `ReplayBuffer.extend` and `.sample` in cases where `LazyTensorStorage`, `RoundRobinWriter`, and `RandomSampler` are used and the data type is either tensor or pytree. ghstack-source-id: e199d0d19448070633a006c3ddd4031c9168a9cd Pull Request resolved: #2504
This change avoids recompiles for back-to-back calls to `ReplayBuffer.extend` and `.sample` in cases where `LazyTensorStorage`, `RoundRobinWriter`, and `RandomSampler` are used and the data type is either tensor or pytree. ghstack-source-id: 02c3066b8064cb549d75ea1cdeee4cae559bef81 Pull Request resolved: #2504
This change avoids recompiles for back-to-back calls to `ReplayBuffer.extend` and `.sample` in cases where `LazyTensorStorage`, `RoundRobinWriter`, and `RandomSampler` are used and the data type is either tensor or pytree. ghstack-source-id: 05f2d0271fb8d48e3974542653ed7fa39e70b88f Pull Request resolved: #2504
This change avoids recompiles for back-to-back calls to `ReplayBuffer.extend` and `.sample` in cases where `LazyTensorStorage`, `RoundRobinWriter`, and `RandomSampler` are used and the data type is either tensor or pytree. ghstack-source-id: 1e82a3363c59edc2e03528bac71ea0bd2cfec4fe Pull Request resolved: #2504
LGTM, I like it (and I'm learning a lot too!)
For the compile tests, we may want to decorate them with a PyTorch version check: we have an olddeps CI workflow that runs with PT 2.0 for backward-compatibility checks. It's fine to skip the compile tests there, as they are not expected to succeed with earlier versions of PT.
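One way to gate the compile tests is a version-based skip marker. The sketch below is illustrative only: the helper names and the `(2, 5)` threshold are assumptions, not taken from this PR, and the parsing deliberately ignores local or dev suffixes such as `+cu121`.

```python
import re

# Hypothetical minimum (major, minor) release for running compile tests.
MIN_COMPILE_VERSION = (2, 5)


def parse_release(version_string):
    """Return (major, minor) from strings like '2.5.0+cu121' or '2.6.0.dev20241001'."""
    match = re.match(r"(\d+)\.(\d+)", version_string)
    if match is None:
        raise ValueError(f"Unrecognized version string: {version_string!r}")
    return int(match.group(1)), int(match.group(2))


def supports_compile_tests(version_string, minimum=MIN_COMPILE_VERSION):
    """True if the given version is at least the (assumed) minimum release."""
    return parse_release(version_string) >= minimum


# Usage in a test module (requires pytest and torch to be installed):
#
#   import pytest
#   import torch
#
#   @pytest.mark.skipif(
#       not supports_compile_tests(torch.__version__),
#       reason="torch.compile tests are only expected to pass on recent PyTorch",
#   )
#   def test_extend_sample_compile():  # hypothetical test name
#       ...
```

Comparing `(major, minor)` tuples rather than raw strings avoids the lexicographic pitfall where `"2.10"` would sort before `"2.5"`.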
This change avoids recompiles for back-to-back calls to `ReplayBuffer.extend` and `.sample` in cases where `LazyTensorStorage`, `RoundRobinWriter`, and `RandomSampler` are used and the data type is either tensor or pytree. ghstack-source-id: e6d20c91b3d3173ae87a2bc12c7b2ad0d9a936a6 Pull Request resolved: #2504
```diff
@@ -144,12 +144,29 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
     def _empty(self):
         ...

     # NOTE: This property is used to enable compiled Storages. Calling
```
I added this note to summarize what I know and don't know. I agree that it makes sense to merge this, but I think I can figure out what's going on here if I spend a little more time on it, so I'll submit another PR when I do.
I'm happy with these changes!
This change avoids recompiles for back-to-back calls to `ReplayBuffer.extend` and `.sample` in cases where `LazyTensorStorage`, `RoundRobinWriter`, and `RandomSampler` are used and the data type is either tensor or pytree. ghstack-source-id: d306cb9f47bdbfb81988589f4f4d923c8427eaa0 Pull Request resolved: #2504
Part of #2501
Stack from ghstack (oldest at bottom):
- [Feature] Avoid some recompiles of `ReplayBuffer.extend\sample` #2504