Skip to content

Commit

Permalink
[Feature] Avoid some recompiles of ReplayBuffer.extend/sample
Browse files Browse the repository at this point in the history
This change avoids recompiles for back-to-back calls to
`ReplayBuffer.extend` and `.sample` in cases where `LazyTensorStorage`,
`RoundRobinWriter`, and `RandomSampler` are used and the data type is
either tensor or pytree.

ghstack-source-id: 1e82a3363c59edc2e03528bac71ea0bd2cfec4fe
Pull Request resolved: #2504
  • Loading branch information
kurtamohler committed Oct 23, 2024
1 parent 5244a90 commit 55713cf
Show file tree
Hide file tree
Showing 4 changed files with 134 additions and 3 deletions.
27 changes: 27 additions & 0 deletions test/_utils_internal.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,12 @@
from __future__ import annotations

import contextlib
import logging
import os

import os.path
import time
import unittest
from functools import wraps

# Get relative file path
Expand Down Expand Up @@ -204,6 +206,31 @@ def f_retry(*args, **kwargs):
return deco_retry


# After calling this function, any log record whose name contains 'record_name'
# and is emitted from the logger that has qualified name 'logger_qname' is
# appended to the 'records' list.
# NOTE: This function is based on testing utilities for 'torch._logging'
def capture_log_records(records, logger_qname, record_name):
assert isinstance(records, list)
logger = logging.getLogger(logger_qname)

class EmitWrapper:
def __init__(self, old_emit):
self.old_emit = old_emit

def __call__(self, record):
nonlocal records
self.old_emit(record)
if record_name in record.name:
records.append(record)

for handler in logger.handlers:
new_emit = EmitWrapper(handler.emit)
contextlib.ExitStack().enter_context(
unittest.mock.patch.object(handler, "emit", new_emit)
)


@pytest.fixture
def dtype_fixture():
dtype = torch.get_default_dtype()
Expand Down
74 changes: 73 additions & 1 deletion test/test_rb.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,12 @@
import pytest
import torch

from _utils_internal import CARTPOLE_VERSIONED, get_default_devices, make_tc
from _utils_internal import (
capture_log_records,
CARTPOLE_VERSIONED,
get_default_devices,
make_tc,
)

from mocking_classes import CountingEnv
from packaging import version
Expand Down Expand Up @@ -399,6 +404,73 @@ def data_iter():
) if cond else contextlib.nullcontext():
rb.extend(data2)

def test_extend_sample_recompile(
self, rb_type, sampler, writer, storage, size, datatype
):
if _os_is_windows:
# Compiling on Windows requires "cl" compiler to be installed.
# <https://github.com/pytorch/pytorch/blob/8231180147a096a703d8891756068c89365292e0/torch/_inductor/cpp_builder.py#L143>
# Our Windows CI jobs do not have "cl", so skip this test.
pytest.skip("This test does not support Windows.")
if rb_type is not ReplayBuffer:
pytest.skip(
"Only replay buffer of type 'ReplayBuffer' is currently supported."
)
if sampler is not RandomSampler:
pytest.skip("Only sampler of type 'RandomSampler' is currently supported.")
if storage is not LazyTensorStorage:
pytest.skip(
"Only storage of type 'LazyTensorStorage' is currently supported."
)
if writer is not RoundRobinWriter:
pytest.skip(
"Only writer of type 'RoundRobinWriter' is currently supported."
)
if datatype == "tensordict":
pytest.skip("'tensordict' datatype is not currently supported.")

torch._dynamo.reset()

storage_size = 10 * size
rb = self._get_rb(
rb_type=rb_type,
sampler=sampler,
writer=writer,
storage=storage,
size=storage_size,
)
data_size = size
data = self._get_data(datatype, size=data_size)

@torch.compile
def extend_and_sample(data):
rb.extend(data)
return rb.sample()

# Number of times to extend the replay buffer
num_extend = 30

# NOTE: The first two calls to 'extend' and 'sample' currently cause
# recompilations, so avoid capturing those for now.
num_extend_before_capture = 2

for _ in range(num_extend_before_capture):
extend_and_sample(data)

try:
torch._logging.set_logs(recompiles=True)
records = []
capture_log_records(records, "torch._dynamo", "recompiles")

for _ in range(num_extend - num_extend_before_capture):
extend_and_sample(data)

assert len(rb) == storage_size
assert len(records) == 0

finally:
torch._logging.set_logs()

def test_sample(self, rb_type, sampler, writer, storage, size, datatype):
if rb_type is RemoteTensorDictReplayBuffer and _os_is_windows:
pytest.skip(
Expand Down
27 changes: 26 additions & 1 deletion test/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

import torch

from _utils_internal import get_default_devices
from _utils_internal import capture_log_records, get_default_devices
from torchrl._utils import _rng_decorator, get_binary_env_var, implement_for

from torchrl.envs.libs.gym import gym_backend, GymWrapper, set_gym_backend
Expand Down Expand Up @@ -380,6 +380,31 @@ def test_rng_decorator(device):
torch.testing.assert_close(s0b, s1b)


# Check that 'capture_log_records' captures records emitted when torch
# recompiles a function.
def test_capture_log_records_recompile():
torch.compiler.reset()

# This function recompiles each time it is called with a different string
# input.
@torch.compile
def str_to_tensor(s):
return bytes(s, "utf8")

str_to_tensor("a")

try:
torch._logging.set_logs(recompiles=True)
records = []
capture_log_records(records, "torch._dynamo", "recompiles")
str_to_tensor("b")

finally:
torch._logging.set_logs()

assert len(records) == 1


if __name__ == "__main__":
args, unknown = argparse.ArgumentParser().parse_known_args()
pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)
9 changes: 8 additions & 1 deletion torchrl/data/replay_buffers/storages.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,12 +144,19 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
def _empty(self):
...

# NOTE: This property is used to enable compiled Storages. A `len(self)`
# call can cause recompiles, but for some reason, wrapping the call in a
# `property` decorated function avoids the recompiles.
@property
def len(self):
return len(self)

def _rand_given_ndim(self, batch_size):
# a method to return random indices given the storage ndim
if self.ndim == 1:
return torch.randint(
0,
len(self),
self.len,
(batch_size,),
generator=self._rng,
device=getattr(self, "device", None),
Expand Down

0 comments on commit 55713cf

Please sign in to comment.