Skip to content

Commit

Permalink
[Refactor] Make CompositeDistribution a tensordict-exclusive class
Browse files Browse the repository at this point in the history
ghstack-source-id: 56c1dd2ad856a18613ec1a4c0ca70aedd28a52e3
Pull Request resolved: #1112
  • Loading branch information
vmoens committed Nov 27, 2024
1 parent c842730 commit b4b8b31
Show file tree
Hide file tree
Showing 2 changed files with 229 additions and 121 deletions.
262 changes: 193 additions & 69 deletions tensordict/nn/distributions/composite.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,44 +13,38 @@


class CompositeDistribution(d.Distribution):
"""A composition of distributions.
"""A composite distribution that groups multiple distributions together using the TensorDict interface.
Groups distributions together with the TensorDict interface. Methods
(``log_prob_composite``, ``entropy_composite``, ``cdf``, ``icdf``, ``rsample``, ``sample`` etc.)
will return a tensordict, possibly modified in-place if the input was a tensordict.
This class allows for operations such as `log_prob_composite`, `entropy_composite`, `cdf`, `icdf`, `rsample`, and `sample`
to be performed on a collection of distributions, returning a TensorDict. The input TensorDict may be modified in-place.
Args:
params (TensorDictBase): a nested key-tensor map where the root entries
point to the sample names, and the leaves are the distribution parameters.
Entry names must match those of ``distribution_map``.
distribution_map (Dict[NestedKey, Type[torch.distribution.Distribution]]):
indicated the distribution types to be used. The names of the distributions
will match the names of the samples in the tensordict.
params (TensorDictBase): A nested key-tensor map where the root entries correspond to sample names, and the leaves
are the distribution parameters. Entry names must match those specified in `distribution_map`.
distribution_map (Dict[NestedKey, Type[torch.distributions.Distribution]]): Specifies the distribution types to be used.
The names of the distributions should match the sample names in the `TensorDict`.
Keyword Arguments:
name_map (Dict[NestedKey, NestedKey]]): a dictionary representing where each
sample should be written. If not provided, the key names from ``distribution_map``
will be used.
extra_kwargs (Dict[NestedKey, Dict]): a possibly incomplete dictionary of
extra keyword arguments for the distributions to be built.
aggregate_probabilities (bool): if ``True``, the :meth:`~.log_prob` and :meth:`~.entropy` methods will
sum the probabilities and entropies of the individual distributions and return a single tensor.
If ``False``, the single log-probabilities will be registered in the input tensordict (for :meth:`~.log_prob`)
or retuned as leaves of the output tensordict (for :meth:`~.entropy`).
This parameter can be overridden at runtime by passing the ``aggregate_probabilities`` argument to
``log_prob`` and ``entropy``.
Defaults to ``False``.
log_prob_key (NestedKey, optional): key where to write the log_prob.
Defaults to `'sample_log_prob'`.
entropy_key (NestedKey, optional): key where to write the entropy.
Defaults to `'entropy'`.
.. note::
In this distribution class, the batch-size of the input tensordict containing the params
(``params``) is indicative of the batch_shape of the distribution. For instance,
the ``"sample_log_prob"`` entry resulting from a call to ``log_prob``
will be of the shape of the params (+ any supplementary batch dimension).
name_map (Dict[NestedKey, NestedKey], optional): A mapping of where each sample should be written. If not provided,
the key names from `distribution_map` will be used.
extra_kwargs (Dict[NestedKey, Dict], optional): A dictionary of additional keyword arguments for constructing the distributions.
aggregate_probabilities (bool, optional): If `True`, the `log_prob` and `entropy` methods will sum the probabilities and entropies
of the individual distributions and return a single tensor. If `False`, individual log-probabilities will be stored in the input
TensorDict (for `log_prob`) or returned as leaves of the output TensorDict (for `entropy`). This can be overridden at runtime
by passing the `aggregate_probabilities` argument to `log_prob` and `entropy`. Defaults to `False`.
log_prob_key (NestedKey, optional): The key where the log probability will be stored. Defaults to `'sample_log_prob'`.
entropy_key (NestedKey, optional): The key where the entropy will be stored. Defaults to `'entropy'`.
inplace (bool, optional): Whether to modify the input TensorDict in-place. Defaults to `True`.
.. warning:: The default value of ``inplace`` will switch to ``False`` in v0.9 in the constructor.
include_sum (bool, optional): Whether to include the summed log-probability in the output TensorDict. Defaults to `True`.
.. warning:: The default value of ``include_sum`` will switch to ``False`` in v0.9 in the constructor.
.. note:: The batch size of the input TensorDict containing the parameters (`params`) determines the batch shape of
the distribution. For example, the `"sample_log_prob"` entry resulting from a call to `log_prob` will have the
shape of the parameters plus any additional batch dimensions.
Examples:
>>> params = TensorDict({
Expand Down Expand Up @@ -88,6 +82,8 @@ def __init__(
aggregate_probabilities: bool | None = None,
log_prob_key: NestedKey = "sample_log_prob",
entropy_key: NestedKey = "entropy",
inplace: bool | None = None,
include_sum: bool | None = None,
):
self._batch_shape = params.shape
if extra_kwargs is None:
Expand Down Expand Up @@ -122,6 +118,8 @@ def __init__(
self.entropy_key = entropy_key

self.aggregate_probabilities = aggregate_probabilities
self.include_sum = include_sum
self.inplace = inplace

@property
def aggregate_probabilities(self):
Expand Down Expand Up @@ -223,16 +221,32 @@ def rsample(self, shape=None) -> TensorDictBase:
)

def log_prob(
self, sample: TensorDictBase, *, aggregate_probabilities: bool | None = None
self,
sample: TensorDictBase,
*,
aggregate_probabilities: bool | None = None,
include_sum: bool | None = None,
inplace: bool | None = None,
) -> torch.Tensor | TensorDictBase: # noqa: D417
"""Computes and returns the summed log-prob.
"""Compute the summed log-probability of a given sample.
Args:
sample (TensorDictBase): the sample to compute the log probability.
sample (TensorDictBase): The input sample to compute the log probability for.
Keyword Args:
aggregate_probabilities (bool, optional): if provided, overrides the default ``aggregate_probabilities``
from the class.
include_sum (bool, optional): Whether to include the summed log-probability in the output TensorDict.
Defaults to ``self.include_sum`` which is set through the class constructor (``True`` by default).
Has no effect if ``aggregate_probabilities`` is set to ``True``.
.. warning:: The default value of ``include_sum`` will switch to ``False`` in v0.9 in the constructor.
inplace (bool, optional): Whether to update the input sample in-place or return a new TensorDict.
Defaults to ``self.inplace`` which is set through the class constructor (``True`` by default).
Has no effect if ``aggregate_probabilities`` is set to ``True``.
.. warning:: The default value of ``inplace`` will switch to ``False`` in v0.9 in the constructor.
If ``self.aggregate_probabilities`` is ``True``, this method will return a single tensor with
the summed log-probabilities. If ``self.aggregate_probabilities`` is ``False``, this method will
Expand All @@ -243,7 +257,9 @@ def log_prob(
if aggregate_probabilities is None:
aggregate_probabilities = self.aggregate_probabilities
if not aggregate_probabilities:
return self.log_prob_composite(sample, include_sum=True)
return self.log_prob_composite(
sample, include_sum=include_sum, inplace=inplace
)
slp = 0.0
for name, dist in self.dists.items():
lp = dist.log_prob(sample.get(name))
Expand All @@ -253,47 +269,105 @@ def log_prob(
return slp

def log_prob_composite(
self, sample: TensorDictBase, include_sum=True
self,
sample: TensorDictBase,
*,
include_sum: bool | None = None,
inplace: bool | None = None,
) -> TensorDictBase:
"""Writes a ``<sample>_log_prob`` entry for each sample in the input tensordict, along with a ``"sample_log_prob"`` entry with the summed log-prob.
"""Computes the log-probability of each component in the input sample and returns a TensorDict with individual log-probabilities.
Args:
sample (TensorDictBase): The input sample to compute the log probabilities for.
This method is called by the :meth:`~.log_prob` method when ``self.aggregate_probabilities`` is ``False``.
Keyword Args:
include_sum (bool, optional): Whether to include the summed log-probability in the output TensorDict.
Defaults to ``self.include_sum`` which is set through the class constructor (``True`` by default).
.. warning:: The default value of ``include_sum`` will switch to ``False`` in v0.9 in the constructor.
inplace (bool, optional): Whether to update the input sample in-place or return a new TensorDict.
Defaults to ``self.inplace`` which is set through the class constructor (``True`` by default).
.. warning:: The default value of ``inplace`` will switch to ``False`` in v0.9 in the constructor.
Returns:
TensorDictBase: A TensorDict containing the individual log-probabilities for each component in the input sample,
along with a "sample_log_prob" entry containing the summed log-probability if `include_sum` is True.
"""
slp = 0.0
if include_sum is None:
include_sum = self.include_sum

if include_sum is None:
include_sum = True
warnings.warn(
"`include_sum` wasn't set when building the `CompositeDistribution` or when calling log_prob_composite. "
"The current default is ``True`` but from v0.9 it will be changed to ``False``. Please adapt your call to `log_prob_composite` accordingly.",
category=DeprecationWarning,
)
if inplace is None:
inplace = self.inplace
if inplace is None:
inplace = True
warnings.warn(
"`inplace` wasn't set when building the `CompositeDistribution` or when calling log_prob_composite. "
"The current default is ``True`` but from v0.9 it will be changed to ``False``. Please adapt your call to `log_prob_composite` accordingly.",
category=DeprecationWarning,
)
if include_sum:
slp = 0.0
d = {}
for name, dist in self.dists.items():
d[_add_suffix(name, "_log_prob")] = lp = dist.log_prob(sample.get(name))
if lp.ndim > sample.ndim:
lp = lp.flatten(sample.ndim, -1).sum(-1)
slp = slp + lp
if include_sum:
if lp.ndim > sample.ndim:
lp = lp.flatten(sample.ndim, -1).sum(-1)
slp = slp + lp
if include_sum:
d[self.log_prob_key] = slp
sample.update(d)
if inplace:
sample.update(d)
else:
return sample.empty(recurse=True).update(d).filter_empty_()
return sample

def entropy(
self, samples_mc: int = 1, *, aggregate_probabilities: bool | None = None
self,
samples_mc: int = 1,
*,
aggregate_probabilities: bool | None = None,
include_sum: bool | None = None,
) -> torch.Tensor | TensorDictBase: # noqa: D417
"""Computes and returns the summed entropies.
"""Computes and returns the entropy of the composite distribution.
This method calculates the entropy for each component distribution and optionally sums them.
Args:
samples_mc (int): the number samples to draw if the entropy does not have a closed form formula.
Defaults to ``1``.
samples_mc (int): The number of samples to draw if the entropy does not have a closed-form solution.
Defaults to `1`.
Keyword Args:
aggregate_probabilities (bool, optional): if provided, overrides the default ``aggregate_probabilities``
from the class.
If ``self.aggregate_probabilities`` is ``True``, this method will return a single tensor with
the summed entropies. If ``self.aggregate_probabilities`` is ``False``, this method will call
the `:meth:`~.entropy_composite` method and return a tensordict with the entropies of each sample
in the input tensordict along with an ``entropy`` entry with the summed entropy. In both cases,
the output shape will match the shape of the distribution ``batch_shape``.
aggregate_probabilities (bool, optional): If provided, overrides the default `aggregate_probabilities`
setting from the class. Determines whether to return a single summed entropy tensor or a TensorDict
with individual entropies. Defaults to ``False`` if not set in the class.
include_sum (bool, optional): Whether to include the summed entropy in the output TensorDict.
Defaults to `self.include_sum`, which is set through the class constructor. Has no effect if
`aggregate_probabilities` is set to `True`.
.. warning:: The default value of `include_sum` will switch to `False` in v0.9 in the constructor.
Returns:
torch.Tensor or TensorDictBase: If `aggregate_probabilities` is `True`, returns a single tensor with
the summed entropies. If `aggregate_probabilities` is `False`, returns a TensorDict with the entropies
of each component distribution.
.. note:: If a distribution does not implement a closed-form solution for entropy, Monte Carlo sampling is used
to estimate it.
"""
if aggregate_probabilities is None:
aggregate_probabilities = self.aggregate_probabilities
if not aggregate_probabilities:
return self.entropy_composite(samples_mc, include_sum=True)
return self.entropy_composite(samples_mc, include_sum=include_sum)
se = 0.0
for _, dist in self.dists.items():
try:
Expand All @@ -306,11 +380,44 @@ def entropy(
se = se + e
return se

def entropy_composite(self, samples_mc=1, include_sum=True) -> TensorDictBase:
"""Writes a ``<sample>_entropy`` entry for each sample in the input tensordict, along with a ``"entropy"`` entry with the summed entropies.
def entropy_composite(
self,
samples_mc=1,
*,
include_sum: bool | None = None,
) -> TensorDictBase:
"""Computes the entropy for each component distribution and returns a TensorDict with individual entropies.
This method is used by the `entropy` method when `self.aggregate_probabilities` is `False`.
Args:
samples_mc (int): The number of samples to draw if the entropy does not have a closed-form solution.
Defaults to `1`.
Keyword Args:
include_sum (bool, optional): Whether to include the summed entropy in the output TensorDict.
Defaults to `self.include_sum`, which is set through the class constructor.
.. warning:: The default value of `include_sum` will switch to `False` in v0.9 in the constructor.
This method is called by the :meth:`~.entropy` method when ``self.aggregate_probabilities`` is ``False``.
Returns:
TensorDictBase: A TensorDict containing the individual entropies for each component distribution,
along with an "entropy" entry containing the summed entropies if `include_sum` is `True`.
.. note:: If a distribution does not implement a closed-form solution for entropy, Monte Carlo sampling is used
to estimate it.
"""
if include_sum is None:
include_sum = self.include_sum

if include_sum is None:
include_sum = True
warnings.warn(
"`include_sum` wasn't set when building the `CompositeDistribution` or when calling log_prob_composite. "
"The current default is ``True`` but from v0.9 it will be changed to ``False``. Please adapt your call to `log_prob_composite` accordingly.",
category=DeprecationWarning,
)

se = 0.0
d = {}
for name, dist in self.dists.items():
Expand All @@ -320,9 +427,10 @@ def entropy_composite(self, samples_mc=1, include_sum=True) -> TensorDictBase:
x = dist.rsample((samples_mc,))
e = -dist.log_prob(x).mean(0)
d[_add_suffix(name, "_entropy")] = e
if e.ndim > len(self.batch_shape):
e = e.flatten(len(self.batch_shape), -1).sum(-1)
se = se + e
if include_sum:
if e.ndim > len(self.batch_shape):
e = e.flatten(len(self.batch_shape), -1).sum(-1)
se = se + e
if include_sum:
d[self.entropy_key] = se
return TensorDict(
Expand All @@ -331,6 +439,16 @@ def entropy_composite(self, samples_mc=1, include_sum=True) -> TensorDictBase:
)

def cdf(self, sample: TensorDictBase) -> TensorDictBase:
"""Computes the cumulative distribution function (CDF) for each component distribution in the composite distribution.
This method calculates the CDF for each component distribution and updates the input TensorDict with the results.
Args:
sample (TensorDictBase): A TensorDict containing samples for which to compute the CDF.
Returns:
TensorDictBase: The input TensorDict updated with `<sample_name>_cdf` entries for each component distribution.
"""
cdfs = {
_add_suffix(name, "_cdf"): dist.cdf(sample.get(name))
for name, dist in self.dists.items()
Expand All @@ -339,14 +457,20 @@ def cdf(self, sample: TensorDictBase) -> TensorDictBase:
return sample

def icdf(self, sample: TensorDictBase) -> TensorDictBase:
"""Computes the inverse CDF.
"""Computes the inverse cumulative distribution function (inverse CDF) for each component distribution.
Requires the input tensordict to have one of `<sample_name>+'_cdf'` entry
or a `<sample_name>` entry.
This method requires the input TensorDict to have either a `<sample_name>_cdf` entry or a `<sample_name>` entry
for each component distribution. It calculates the inverse CDF and updates the TensorDict with the results.
Args:
sample (TensorDictBase): a tensordict containing `<sample>_log_prob` where
`<sample>` is the name of the sample provided during construction.
sample (TensorDictBase): A TensorDict containing either `<sample_name>_cdf` or `<sample_name>` entries
for each component distribution.
Returns:
TensorDictBase: The input TensorDict updated with `<sample_name>_icdf` entries for each component distribution.
Raises:
KeyError: If neither `<sample_name>` nor `<sample_name>_cdf` can be found in the input TensorDict for a component distribution.
"""
for name, dist in self.dists.items():
# TODO: v0.7: remove the None
Expand Down
Loading

0 comments on commit b4b8b31

Please sign in to comment.