[Refactor] composite_lp_aggregate to handle log-probs aggregates globally (#1181)

vmoens authored Jan 15, 2025
1 parent 8e63112 commit 0013e38
Showing 8 changed files with 883 additions and 473 deletions.
4 changes: 3 additions & 1 deletion docs/source/reference/nn.rst
@@ -198,6 +198,8 @@ to build distributions from network outputs and get summary statistics or sample
TensorDictModuleWrapper
CudaGraphModule
WrapModule
set_composite_lp_aggregate
composite_lp_aggregate

Ensembles
---------
@@ -249,10 +251,10 @@ Distributions
:toctree: generated/
:template: rl_template_noinherit.rst

NormalParamExtractor
AddStateIndependentNormalScale
CompositeDistribution
Delta
NormalParamExtractor
OneHotCategorical
TruncatedNormal

2 changes: 2 additions & 0 deletions tensordict/_td.py
@@ -2629,6 +2629,8 @@ def rename_key_(
old_key = unravel_key(old_key)
new_key = unravel_key(new_key)
if old_key == new_key:
if old_key not in self.keys(include_nested=isinstance(old_key, tuple)):
raise KeyError(f"Key {old_key} not found in tensordict.")
return self
if safe and (new_key in self.keys(include_nested=True)):
raise KeyError(f"key {new_key} already present in TensorDict.")
1 change: 1 addition & 0 deletions tensordict/nn/__init__.py
@@ -44,3 +44,4 @@
)

from .cudagraphs import CudaGraphModule
from .utils import composite_lp_aggregate, set_composite_lp_aggregate
206 changes: 116 additions & 90 deletions tensordict/nn/distributions/composite.py

Large diffs are not rendered by default.

398 changes: 180 additions & 218 deletions tensordict/nn/probabilistic.py

Large diffs are not rendered by default.

96 changes: 96 additions & 0 deletions tensordict/nn/utils.py
@@ -8,6 +8,7 @@
import functools
import inspect
import os
import warnings
from enum import Enum
from typing import Any, Callable

@@ -447,3 +448,98 @@ def __new__(cls, *values):

def _generate_next_value_(name, start, count, last_values):
return name.lower()


_composite_lp_aggregate = _ContextManager()


def composite_lp_aggregate(nowarn: bool = False) -> bool | None:
"""Returns whether a :class:`~tensordict.nn.CompositeDistribution` log-probabilities and entropies will be aggregated in a single tensor.
Args:
nowarn (bool, optional): whether to ignore warnings. Defaults to False.
.. seealso:: :func:`~tensordict.nn.set_composite_lp_aggregate`
"""
mode = _composite_lp_aggregate.get_mode()
if mode is None:
if not nowarn:
warnings.warn(
"Composite log-prob aggregation wasn't defined explicitly and ``composite_lp_aggregate()`` will "
"currently return ``True``. However, from v0.9, this behaviour will change and ``composite_lp_aggregate`` will "
"return ``False``. Please change your code accordingly by specifying the aggregation strategy via "
"`tensordict.nn.set_composite_lp_aggregate`.",
category=DeprecationWarning,
)
return True
return mode


class set_composite_lp_aggregate(_DecoratorContextManager):
"""Controls whether :class:`~tensordict.nn.CompositeDistribution` log-probabilities and entropies will be aggregated in a single tensor.
When :func:`~tensordict.nn.composite_lp_aggregate` returns ``True``, the log-probs / entropies of :class:`~tensordict.nn.CompositeDistribution`
will be summed into a single tensor with the shape of the root tensordict. This behaviour is being deprecated in favor of
non-aggregated log-probs, which offer more flexibility and a somewhat more natural API (tensordict samples, tensordict log-probs, tensordict entropies).
Example:
>>> _ = torch.manual_seed(0)
>>> from tensordict import TensorDict
>>> from tensordict.nn import CompositeDistribution, set_composite_lp_aggregate
>>> import torch
>>> from torch import distributions as d
>>> params = TensorDict({
... "cont": {"loc": torch.randn(3, 4), "scale": torch.rand(3, 4)},
... ("nested", "disc"): {"logits": torch.randn(3, 10)}
... }, [3])
>>> dist = CompositeDistribution(params,
... distribution_map={"cont": d.Normal, ("nested", "disc"): d.Categorical})
>>> sample = dist.sample((4,))
>>> with set_composite_lp_aggregate(False):
... lp = dist.log_prob(sample)
... print(lp)
TensorDict(
fields={
cont_log_prob: Tensor(shape=torch.Size([4, 3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
nested: TensorDict(
fields={
disc_log_prob: Tensor(shape=torch.Size([4, 3]), device=cpu, dtype=torch.float32, is_shared=False)},
batch_size=torch.Size([4, 3]),
device=None,
is_shared=False)},
batch_size=torch.Size([4, 3]),
device=None,
is_shared=False)
>>> with set_composite_lp_aggregate(True):
... lp = dist.log_prob(sample)
... print(lp)
tensor([[-2.0886, -1.2155, -0.0414],
[-2.8973, -5.5165, 2.4402],
[-0.2806, -1.2799, 3.1733],
[-3.0407, -4.3593, 0.5763]])
"""

def __init__(
self,
mode: bool = True,
) -> None:
super().__init__()
self.mode = mode

def clone(self) -> set_composite_lp_aggregate:
# override this method if your subclass takes __init__ parameters
return type(self)(self.mode)

def __enter__(self) -> None:
self.prev = _composite_lp_aggregate.get_mode()
_composite_lp_aggregate.set_mode(self.mode)

def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
_composite_lp_aggregate.set_mode(self.prev)

def set(self):
self.__enter__()

def unset(self):
return self.__exit__(None, None, None)
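A minimal usage sketch of the new switch (an editor's example, assuming the exports added to `tensordict/nn/__init__.py` above): the decorator form mirrors its use in `test_compile.py` below, and the per-leaf log-prob keys follow the docstring example.

```python
import torch
from torch import distributions as d
from tensordict import TensorDict
from tensordict.nn import (
    CompositeDistribution,
    composite_lp_aggregate,
    set_composite_lp_aggregate,
)

# The switch works as a decorator as well as a context manager.
@set_composite_lp_aggregate(False)
def tensordict_log_prob(dist, sample):
    # With aggregation disabled, log_prob returns a TensorDict of per-leaf log-probs.
    return dist.log_prob(sample)

params = TensorDict(
    {"cont": {"loc": torch.randn(3, 4), "scale": torch.rand(3, 4)}}, [3]
)
dist = CompositeDistribution(params, distribution_map={"cont": d.Normal})
sample = dist.sample((4,))
print(tensordict_log_prob(dist, sample)["cont_log_prob"].shape)  # torch.Size([4, 3, 4])

# Outside any explicit setting, composite_lp_aggregate() still returns True
# (with a DeprecationWarning unless nowarn=True) until the default flips in v0.9.
print(composite_lp_aggregate(nowarn=True))  # True
```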
2 changes: 2 additions & 0 deletions test/test_compile.py
@@ -27,6 +27,7 @@
CudaGraphModule,
InteractionType,
ProbabilisticTensorDictModule as Prob,
set_composite_lp_aggregate,
TensorDictModule,
TensorDictModule as Mod,
TensorDictSequential as Seq,
@@ -665,6 +666,7 @@ def test_dispatch_tensor(self, mode):
mod_compile = torch.compile(mod, fullgraph=_v2_5, mode=mode)
torch.testing.assert_close(mod(x=x, y=y), mod_compile(x=x, y=y))

@set_composite_lp_aggregate(False)
def test_prob_module_with_kwargs(self, mode):
kwargs = TensorDictParams(
TensorDict(scale=1.0, validate_args=NonTensorData(False)), no_convert=True

1 comment on commit 0013e38

@github-actions


⚠️ Performance Alert ⚠️

A possible performance regression was detected for benchmark 'CPU Benchmark Results'.
The benchmark results of this commit are worse than the previous results, exceeding the threshold of 2.

| Benchmark suite | Current: 0013e38 | Previous: 8e63112 | Ratio |
| --- | --- | --- | --- |
| benchmarks/common/common_ops_test.py::test_membership_stacked_nested_last | 75068.74805100504 iter/sec (stddev: 9.984056728442533e-7) | 164359.35247636584 iter/sec (stddev: 4.7312673233608383e-7) | 2.19 |
| benchmarks/common/common_ops_test.py::test_membership_stacked_nested_leaf_last | 75567.53017318413 iter/sec (stddev: 8.721267071351303e-7) | 164495.79232413115 iter/sec (stddev: 5.572106858218063e-7) | 2.18 |
| benchmarks/common/memmap_benchmarks_test.py::test_serialize_weights_pickle | 1.2713359511864577 iter/sec (stddev: 0.287109772646582) | 2.5479902302818336 iter/sec (stddev: 0.06477800517907545) | 2.00 |
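The ratio here appears to be previous throughput divided by current throughput; for the first row, 164359.35 / 75068.75 ≈ 2.19, above the alert threshold of 2.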

This comment was automatically generated by workflow using github-action-benchmark.

CC: @vmoens
