From e1d5826674480af1bfc31c2e0c27fdcfe2e8a2b5 Mon Sep 17 00:00:00 2001
From: Kurt Mohler
Date: Wed, 23 Oct 2024 11:48:13 -0700
Subject: [PATCH] [Feature] Avoid some recompiles of `ReplayBuffer.extend/sample`

This change avoids recompiles for back-to-back calls to
`ReplayBuffer.extend` and `.sample` in cases where `LazyTensorStorage`,
`RoundRobinWriter`, and `RandomSampler` are used and the data type is
either tensor or pytree.

ghstack-source-id: 70d47e2f3d34949ea648b8c1a351593774b88ce0
Pull Request resolved: https://github.com/pytorch/rl/pull/2504
---
 test/_utils_internal.py                 | 27 +++++++++
 test/test_rb.py                         | 74 ++++++++++++++++++++++++-
 test/test_utils.py                      | 27 ++++++++-
 torchrl/data/replay_buffers/storages.py |  2 +-
 4 files changed, 127 insertions(+), 3 deletions(-)

diff --git a/test/_utils_internal.py b/test/_utils_internal.py
index 51535afa606..48492459315 100644
--- a/test/_utils_internal.py
+++ b/test/_utils_internal.py
@@ -5,10 +5,12 @@
 from __future__ import annotations
 
 import contextlib
+import logging
 import os
 import os.path
 import time
+import unittest
 from functools import wraps
 
 # Get relative file path
@@ -204,6 +206,31 @@ def f_retry(*args, **kwargs):
     return deco_retry
 
 
+# After calling this function, any log record whose name contains 'record_name'
+# and is emitted from the logger that has qualified name 'logger_qname' is
+# appended to the 'records' list.
+# NOTE: This function is based on testing utilities for 'torch._logging'
+def capture_log_records(records, logger_qname, record_name):
+    assert isinstance(records, list)
+    logger = logging.getLogger(logger_qname)
+
+    class EmitWrapper:
+        def __init__(self, old_emit):
+            self.old_emit = old_emit
+
+        def __call__(self, record):
+            nonlocal records
+            self.old_emit(record)
+            if record_name in record.name:
+                records.append(record)
+
+    for handler in logger.handlers:
+        new_emit = EmitWrapper(handler.emit)
+        contextlib.ExitStack().enter_context(
+            unittest.mock.patch.object(handler, "emit", new_emit)
+        )
+
+
 @pytest.fixture
 def dtype_fixture():
     dtype = torch.get_default_dtype()
diff --git a/test/test_rb.py b/test/test_rb.py
index 0e10f534728..9db96e5d8c0 100644
--- a/test/test_rb.py
+++ b/test/test_rb.py
@@ -17,7 +17,12 @@
 import pytest
 import torch
 
-from _utils_internal import CARTPOLE_VERSIONED, get_default_devices, make_tc
+from _utils_internal import (
+    capture_log_records,
+    CARTPOLE_VERSIONED,
+    get_default_devices,
+    make_tc,
+)
 from mocking_classes import CountingEnv
 from packaging import version
 
@@ -399,6 +404,73 @@ def data_iter():
         ) if cond else contextlib.nullcontext():
             rb.extend(data2)
 
+    def test_extend_sample_recompile(
+        self, rb_type, sampler, writer, storage, size, datatype
+    ):
+        if _os_is_windows:
+            # Compiling on Windows requires "cl" compiler to be installed.
+            #
+            # Our Windows CI jobs do not have "cl", so skip this test.
+            pytest.skip("This test does not support Windows.")
+        if rb_type is not ReplayBuffer:
+            pytest.skip(
+                "Only replay buffer of type 'ReplayBuffer' is currently supported."
+            )
+        if sampler is not RandomSampler:
+            pytest.skip("Only sampler of type 'RandomSampler' is currently supported.")
+        if storage is not LazyTensorStorage:
+            pytest.skip(
+                "Only storage of type 'LazyTensorStorage' is currently supported."
+            )
+        if writer is not RoundRobinWriter:
+            pytest.skip(
+                "Only writer of type 'RoundRobinWriter' is currently supported."
+            )
+        if datatype == "tensordict":
+            pytest.skip("'tensordict' datatype is not currently supported.")
+
+        torch.compiler.reset()
+
+        storage_size = 10 * size
+        rb = self._get_rb(
+            rb_type=rb_type,
+            sampler=sampler,
+            writer=writer,
+            storage=storage,
+            size=storage_size,
+        )
+        data_size = size
+        data = self._get_data(datatype, size=data_size)
+
+        @torch.compile
+        def extend_and_sample(data):
+            rb.extend(data)
+            return rb.sample()
+
+        # Number of times to extend the replay buffer
+        num_extend = 30
+
+        # NOTE: The first two calls to 'extend' and 'sample' currently cause
+        # recompilations, so avoid capturing those for now.
+        num_extend_before_capture = 2
+
+        for _ in range(num_extend_before_capture):
+            extend_and_sample(data)
+
+        try:
+            torch._logging.set_logs(recompiles=True)
+            records = []
+            capture_log_records(records, "torch._dynamo", "recompiles")
+
+            for _ in range(num_extend - num_extend_before_capture):
+                extend_and_sample(data)
+
+            assert len(rb) == storage_size
+            assert len(records) == 0
+
+        finally:
+            torch._logging.set_logs()
+
     def test_sample(self, rb_type, sampler, writer, storage, size, datatype):
         if rb_type is RemoteTensorDictReplayBuffer and _os_is_windows:
             pytest.skip(
diff --git a/test/test_utils.py b/test/test_utils.py
index 4224a36b54f..af5dc09985c 100644
--- a/test/test_utils.py
+++ b/test/test_utils.py
@@ -14,7 +14,7 @@
 
 import torch
 
-from _utils_internal import get_default_devices
+from _utils_internal import capture_log_records, get_default_devices
 from torchrl._utils import _rng_decorator, get_binary_env_var, implement_for
 
 from torchrl.envs.libs.gym import gym_backend, GymWrapper, set_gym_backend
@@ -380,6 +380,31 @@ def test_rng_decorator(device):
     torch.testing.assert_close(s0b, s1b)
 
 
+# Check that 'capture_log_records' captures records emitted when torch
+# recompiles a function.
+def test_capture_log_records_recompile():
+    torch.compiler.reset()
+
+    # This function recompiles each time it is called with a different string
+    # input.
+    @torch.compile
+    def str_to_tensor(s):
+        return bytes(s, "utf8")
+
+    str_to_tensor("a")
+
+    try:
+        torch._logging.set_logs(recompiles=True)
+        records = []
+        capture_log_records(records, "torch._dynamo", "recompiles")
+        str_to_tensor("b")
+
+    finally:
+        torch._logging.set_logs()
+
+    assert len(records) == 1
+
+
 if __name__ == "__main__":
     args, unknown = argparse.ArgumentParser().parse_known_args()
     pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)
diff --git a/torchrl/data/replay_buffers/storages.py b/torchrl/data/replay_buffers/storages.py
index 4d47fd5265d..217229b5d9b 100644
--- a/torchrl/data/replay_buffers/storages.py
+++ b/torchrl/data/replay_buffers/storages.py
@@ -149,7 +149,7 @@ def _rand_given_ndim(self, batch_size):
         if self.ndim == 1:
             return torch.randint(
                 0,
-                len(self),
+                self._len,
                 (batch_size,),
                 generator=self._rng,
                 device=getattr(self, "device", None),
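
Usage sketch (editorial addition, not part of the patch): the snippet below illustrates the compiled extend/sample pattern this change is meant to keep recompile-free, using the supported LazyTensorStorage, RoundRobinWriter, and RandomSampler components with plain-tensor data. The buffer capacity, batch size, tensor shapes, and iteration counts are illustrative assumptions rather than values taken from the patch.

    import torch

    from torchrl.data import (
        LazyTensorStorage,
        RandomSampler,
        ReplayBuffer,
        RoundRobinWriter,
    )

    # Build a replay buffer from the components this change targets.
    rb = ReplayBuffer(
        storage=LazyTensorStorage(1000),  # illustrative capacity
        sampler=RandomSampler(),
        writer=RoundRobinWriter(),
        batch_size=32,  # illustrative sample size
    )

    @torch.compile
    def extend_and_sample(data):
        # Back-to-back extend and sample inside one compiled function.
        rb.extend(data)
        return rb.sample()

    data = torch.randn(100, 4)  # plain tensor data; pytrees are also supported

    # The first couple of calls may still trigger recompilations; after that,
    # repeated extend/sample calls are expected to reuse the compiled graph.
    for _ in range(10):
        sample = extend_and_sample(data)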