From 3dbe08375e66ecf8583a6434bd2268ef8c7926b3 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Tue, 26 Nov 2024 14:59:28 +0000 Subject: [PATCH] Update (base update) [ghstack-poisoned] --- tensordict/nn/distributions/composite.py | 3 +- test/test_nn.py | 37 +++++++++++++++++++----- 2 files changed, 32 insertions(+), 8 deletions(-) diff --git a/tensordict/nn/distributions/composite.py b/tensordict/nn/distributions/composite.py index 5b14dad5a..a68136014 100644 --- a/tensordict/nn/distributions/composite.py +++ b/tensordict/nn/distributions/composite.py @@ -314,7 +314,8 @@ def log_prob_composite( "The current default is ``True`` but from v0.9 it will be changed to ``False``. Please adapt your call to `log_prob_composite` accordingly.", category=DeprecationWarning, ) - slp = 0.0 + if include_sum: + slp = 0.0 d = {} for name, dist in self.dists.items(): d[_add_suffix(name, "_log_prob")] = lp = dist.log_prob(sample.get(name)) diff --git a/test/test_nn.py b/test/test_nn.py index f3f875d76..4bab031a9 100644 --- a/test/test_nn.py +++ b/test/test_nn.py @@ -12,6 +12,7 @@ import pytest import torch + from tensordict import NonTensorData, NonTensorStack, tensorclass, TensorDict from tensordict._C import unravel_key_list from tensordict.nn import ( @@ -2254,7 +2255,9 @@ def test_log_prob(self): assert isinstance(lp, torch.Tensor) assert lp.requires_grad - def test_log_prob_composite(self): + @pytest.mark.parametrize("inplace", [None, True, False]) + @pytest.mark.parametrize("include_sum", [None, True, False]) + def test_log_prob_composite(self, inplace, include_sum): params = TensorDict( { "cont": { @@ -2273,12 +2276,25 @@ def test_log_prob_composite(self): }, extra_kwargs={("nested", "disc"): {"temperature": torch.tensor(1.0)}}, aggregate_probabilities=False, + inplace=inplace, + include_sum=include_sum, ) + if include_sum is None: + include_sum = True + if inplace is None: + inplace = True sample = dist.rsample((4,)) - sample = dist.log_prob_composite(sample, include_sum=True) - assert sample.get("cont_log_prob").requires_grad - assert sample.get(("nested", "disc_log_prob")).requires_grad - assert "sample_log_prob" in sample.keys() + sample_lp = dist.log_prob_composite(sample) + assert sample_lp.get("cont_log_prob").requires_grad + assert sample_lp.get(("nested", "disc_log_prob")).requires_grad + if inplace: + assert sample_lp is sample + else: + assert sample_lp is not sample + if include_sum: + assert "sample_log_prob" in sample_lp.keys() + else: + assert "sample_log_prob" not in sample_lp.keys() def test_entropy(self): params = TensorDict( @@ -2304,7 +2320,8 @@ def test_entropy(self): assert isinstance(ent, torch.Tensor) assert ent.requires_grad - def test_entropy_composite(self): + @pytest.mark.parametrize("include_sum", [None, True, False]) + def test_entropy_composite(self, include_sum): params = TensorDict( { "cont": { @@ -2322,12 +2339,18 @@ def test_entropy_composite(self): ("nested", "disc"): distributions.Categorical, }, aggregate_probabilities=False, + include_sum=include_sum, ) + if include_sum is None: + include_sum = True sample = dist.entropy() assert sample.shape == params.shape == dist._batch_shape assert sample.get("cont_entropy").requires_grad assert sample.get(("nested", "disc_entropy")).requires_grad - assert "entropy" in sample.keys() + if include_sum: + assert "entropy" in sample.keys() + else: + assert "entropy" not in sample.keys() def test_cdf(self): params = TensorDict(