diff --git a/docs/source/_static/img/mcts_forest.png b/docs/source/_static/img/mcts_forest.png new file mode 100644 index 00000000000..2ee7014dc54 Binary files /dev/null and b/docs/source/_static/img/mcts_forest.png differ diff --git a/docs/source/reference/data.rst b/docs/source/reference/data.rst index 9b072cc9664..c95edd6156f 100644 --- a/docs/source/reference/data.rst +++ b/docs/source/reference/data.rst @@ -975,7 +975,72 @@ The following classes are deprecated and just point to the classes above: Trees and Forests ----------------- -TorchRL offers a set of classes and functions that can be used to represent trees and forests efficiently. +TorchRL offers a set of classes and functions that can be used to represent trees and forests efficiently, +which is particularly useful for Monte Carlo Tree Search (MCTS) algorithms. + +TensorDictMap +~~~~~~~~~~~~~ + +At its core, the MCTS API relies on the :class:`~torchrl.data.TensorDictMap` which acts like a storage where indices can +be any numerical object. In traditional storages (e.g., :class:`~torchrl.data.TensorStorage`), only integer indices +are allowed: + + >>> storage = TensorStorage(...) + >>> data = storage[3] + +:class:`~torchrl.data.TensorDictMap` allows us to make more advanced queries in the storage. The typical example is +when we have a storage containing a set of MDPs and we want to rebuild a trajectory given its initial observation, action +pair. In tensor terms, this could be written with the following pseudocode: + + >>> next_state = storage[observation, action] + +(if there is more than one next state associated with this pair one could return a stack of ``next_states`` instead). +This API would make sense but it would be restrictive: allowing observations or actions that are composed of +multiple tensors may be hard to implement. Instead, we provide a tensordict containing these values and let the storage +know what ``in_keys`` to look at to query the next state: + + >>> td = TensorDict(observation=observation, action=action) + >>> next_td = storage[td] + +Of course, this class also allows us to extend the storage with new data: + + >>> storage[td] = next_state + +This comes in handy because it allows us to represent complex rollout structures where different actions are undertaken +at a given node (ie, for a given observation). All `(observation, action)` pairs that have been observed may lead us to +a (set of) rollout that we can use further. + +MCTSForest +~~~~~~~~~~ + +Building a tree from an initial observation then becomes just a matter of organizing data efficiently. +The :class:`~torchrl.data.MCTSForest` has at its core two storages: a first storage links observations to hashes and +indices of actions encountered in the past in the dataset: + + >>> data = TensorDict(observation=observation) + >>> metadata = forest.node_map[data] + >>> index = metadata["_index"] + +where ``forest`` is a :class:`~torchrl.data.MCTSForest` instance. +Then, a second storage keeps track of the actions and results associated with the observation: + + >>> next_data = forest.data_map[index] + +The ``next_data`` entry can have any shape, but it will usually match the shape of ``index`` (since at each index +corresponds one action). Once ``next_data`` is obtrained, it can be put together with ``data`` to form a set of nodes, +and the tree can be expanded for each of these. The following figure shows how this is done. + +.. figure:: /_static/img/collector-copy.png + + Building a :class:`~torchrl.data.Tree` from a :class:`~torchrl.data.MCTSForest` object. + The flowchart represents a tree being built from an initial observation `o`. The :class:`~torchrl.data.MCTSForest.get_tree` + method passed the input data structure (the root node) to the ``node_map`` :class:`~torchrl.data.TensorDictMap` instance + that returns a set of hashes and indices. These indices are then used to query the corresponding tuples of + actions, next observations, rewards etc. that are associated with the root node. + A vertex is created from each of them (possibly with a longer rollout when a compact representation is asked). + The stack of vertices is then used to build up the tree further, and these vertices are stacked together and make + up the branches of the tree at the root. This process is repeated for a given depth or until the tree cannot be + expanded anymore. .. currentmodule:: torchrl.data @@ -985,11 +1050,13 @@ TorchRL offers a set of classes and functions that can be used to represent tree BinaryToDecimal HashToInt + MCTSForeset QueryModule RandomProjectionHash SipHash TensorDictMap TensorMap + Tree Reinforcement Learning From Human Feedback (RLHF) diff --git a/test/test_storage_map.py b/test/test_storage_map.py index c5f748e2e9d..6c9348e5ae9 100644 --- a/test/test_storage_map.py +++ b/test/test_storage_map.py @@ -5,13 +5,14 @@ import argparse import functools import importlib.util +from typing import Tuple import pytest import torch -from tensordict import TensorDict -from torchrl.data import LazyTensorStorage, ListStorage +from tensordict import assert_close, TensorDict +from torchrl.data import LazyTensorStorage, ListStorage, MCTSForest from torchrl.data.map import ( BinaryToDecimal, QueryModule, @@ -238,6 +239,240 @@ def test_map_rollout(self): assert not contains[rollout.shape[-1] :].any() +class TestMCTSForest: + def dummy_rollouts(self) -> Tuple[TensorDict, ...]: + """ + ├── 0 + │ ├── 16 + │ ├── 17 + │ ├── 18 + │ ├── 19 + │ └── 20 + ├── 1 + ├── 2 + ├── 3 + │ ├── 6 + │ ├── 7 + │ ├── 8 + │ ├── 9 + │ └── 10 + ├── 4 + │ ├── 11 + │ ├── 12 + │ ├── 13 + │ │ ├── 21 + │ │ ├── 22 + │ │ ├── 23 + │ │ ├── 24 + │ │ └── 25 + │ ├── 14 + │ └── 15 + └── 5 + + """ + + states0 = torch.arange(6) + actions0 = torch.full((5,), 0) + + states1 = torch.cat([torch.tensor([3]), torch.arange(6, 11)]) + actions1 = torch.full((5,), 1) + + states2 = torch.cat([torch.tensor([4]), torch.arange(11, 16)]) + actions2 = torch.full((5,), 2) + + states3 = torch.cat([torch.tensor([0]), torch.arange(16, 21)]) + actions3 = torch.full((5,), 3) + + states4 = torch.cat([torch.tensor([13]), torch.arange(21, 26)]) + actions4 = torch.full((5,), 4) + + return ( + self._make_td(states0, actions0), + self._make_td(states1, actions1), + self._make_td(states2, actions2), + self._make_td(states3, actions3), + self._make_td(states4, actions4), + ) + + def _state0(self) -> TensorDict: + return self.dummy_rollouts()[0][0] + + @staticmethod + def _make_td(state: torch.Tensor, action: torch.Tensor) -> TensorDict: + done = torch.zeros_like(action, dtype=torch.bool).unsqueeze(-1) + reward = action.clone() + + return TensorDict( + { + "observation": state[:-1], + "action": action, + "done": torch.zeros_like(done), + "next": { + "observation": state[1:], + "done": done, + "reward": reward, + }, + } + ).auto_batch_size_() + + def _make_forest(self) -> MCTSForest: + r0, r1, r2, r3, r4 = self.dummy_rollouts() + assert r0.shape + forest = MCTSForest() + forest.extend(r0) + forest.extend(r1) + forest.extend(r2) + forest.extend(r3) + forest.extend(r4) + return forest + + def _make_forest_intersect(self) -> MCTSForest: + """ + ├── 0 + │ ├── 16 + │ ├── 17 + │ ├── 18 + │ ├── 19───────│ + │ │ └── 26 │ + │ └── 20 │ + ├── 1 │ + ├── 2 │ + ├── 3 │ + │ ├── 6 │ + │ ├── 7 │ + │ ├── 8 │ + │ ├── 9 │ + │ └── 10 │ + ├── 4 │ + │ ├── 11 │ + │ ├── 12 │ + │ ├── 13 │ + │ │ ├── 21 │ + │ │ ├── 22 │ + │ │ ├── 23 │ + │ │ ├── 24 ──│ + │ │ └── 25 + │ ├── 14 + │ └── 15 + └── 5 + """ + forest = self._make_forest() + states5 = torch.cat([torch.tensor([24]), torch.tensor([19, 26])]) + actions5 = torch.full((2,), 5) + rollout5 = self._make_td(states5, actions5) + forest.extend(rollout5) + return forest + + @staticmethod + def make_labels(tree): + if tree.rollout is not None: + s = torch.cat( + [ + tree.rollout["observation"][:1], + tree.rollout["next", "observation"], + ] + ) + a = tree.rollout["action"].tolist() + s = s.tolist() + return f"node {tree.node_id}: states {s}, actions {a}" + return f"node {tree.node_id}" + + def test_forest_build(self): + r0, *_ = self.dummy_rollouts() + forest = self._make_forest() + tree = forest.get_tree(r0[0]) + tree.plot(make_labels=self.make_labels) + + def test_forest_vertices(self): + r0, *_ = self.dummy_rollouts() + forest = self._make_forest() + + tree = forest.get_tree(r0[0]) + assert tree.num_vertices() == 9 # (0, 20, 3, 10, 4, 13, 25, 15, 5) + + tree = forest.get_tree(r0[0], compact=False) + assert tree.num_vertices() == 26 + + def test_forest_rebuild_rollout(self): + r0, r1, r2, r3, r4 = self.dummy_rollouts() + forest = self._make_forest() + + tree = forest.get_tree(r0[0]) + assert_close(tree.rollout_from_path((0, 0, 0)), r0, intersection=True) + assert_close(tree.rollout_from_path((0, 1))[-5:], r1, intersection=True) + assert_close(tree.rollout_from_path((0, 0, 1, 0))[-5:], r2, intersection=True) + assert_close(tree.rollout_from_path((1,))[-5:], r3, intersection=True) + assert_close(tree.rollout_from_path((0, 0, 1, 1))[-5:], r4, intersection=True) + + def test_forest_check_hashes(self): + r0, *_ = self.dummy_rollouts() + forest = self._make_forest() + tree = forest.get_tree(r0[0]) + nodes = range(tree.num_vertices()) + hashes = set() + for n in nodes: + vertex = tree.get_vertex_by_id(n) + node_hash = vertex.hash + if node_hash is not None: + assert isinstance(node_hash, int) + hashes.add(node_hash) + else: + assert vertex is tree + assert len(hashes) == tree.num_vertices() - 1 + + def test_forest_check_ids(self): + r0, *_ = self.dummy_rollouts() + forest = self._make_forest() + tree = forest.get_tree(r0[0]) + nodes = range(tree.num_vertices()) + for n in nodes: + vertex = tree.get_vertex_by_id(n) + node_id = vertex.node_id + assert isinstance(node_id, int) + assert node_id == n + + # Ideally, we'd like to have only views but because we index the storage with a tensor + # we actually get regular, single-storage tensors + # def test_forest_view(self): + # import tensordict.base + # r0, *_ = self.dummy_rollouts() + # forest = self._make_forest() + # tree = forest.get_tree(r0[0]) + # dataptr = set() + # # Check that all tensors point to the same storage (ie, that we only have views) + # for k, v in tree.items(True, True, is_leaf=tensordict.base._NESTED_TENSORS_AS_LISTS): + # if isinstance(k, tuple) and "rollout" in k: + # dataptr.add(v.storage().data_ptr()) + # assert len(dataptr) == 1, k + + def test_forest_intersect(self): + state0 = self._state0() + forest = self._make_forest_intersect() + tree = forest.get_tree(state0) + subtree = forest.get_tree(TensorDict(observation=19)) + + # subtree.plot(make_labels=make_labels) + # tree.plot(make_labels=make_labels) + assert tree.get_vertex_by_id(2).num_children == 2 + assert tree.get_vertex_by_id(1).num_children == 2 + assert tree.get_vertex_by_id(3).num_children == 2 + assert tree.get_vertex_by_id(8).num_children == 2 + assert tree.get_vertex_by_id(10).num_children == 2 + assert tree.get_vertex_by_id(12).num_children == 2 + + # Test contains + assert subtree in tree + + def test_forest_intersect_vertices(self): + state0 = self._state0() + forest = self._make_forest_intersect() + tree = forest.get_tree(state0) + assert len(tree.vertices(key_type="path")) > len(tree.vertices(key_type="hash")) + assert len(tree.vertices(key_type="id")) == len(tree.vertices(key_type="hash")) + with pytest.raises(ValueError, match="key_type must be"): + tree.vertices(key_type="another key type") + + if __name__ == "__main__": args, unknown = argparse.ArgumentParser().parse_known_args() pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown) diff --git a/test/test_transforms.py b/test/test_transforms.py index a5a4fad4e40..1c1903a3b1b 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -7098,7 +7098,7 @@ def test_tensordictprimer_batching(self, batched_class, break_when_any_done): torch.manual_seed(0) env.set_seed(0) r1 = env.rollout(100, break_when_any_done=break_when_any_done) - tensordict.tensordict.assert_allclose_td(r0, r1) + tensordict.assert_close(r0, r1) def test_callable_default_value(self): def create_tensor(): diff --git a/torchrl/data/__init__.py b/torchrl/data/__init__.py index cd6a535f8b7..639ed820e86 100644 --- a/torchrl/data/__init__.py +++ b/torchrl/data/__init__.py @@ -6,11 +6,13 @@ from .map import ( BinaryToDecimal, HashToInt, + MCTSForest, QueryModule, RandomProjectionHash, SipHash, TensorDictMap, TensorMap, + Tree, ) from .postprocs import MultiStep from .replay_buffers import ( diff --git a/torchrl/data/map/__init__.py b/torchrl/data/map/__init__.py index 1fba910884d..c9bc25477c2 100644 --- a/torchrl/data/map/__init__.py +++ b/torchrl/data/map/__init__.py @@ -6,3 +6,4 @@ from .hash import BinaryToDecimal, RandomProjectionHash, SipHash from .query import HashToInt, QueryModule from .tdstorage import TensorDictMap, TensorMap +from .tree import MCTSForest, Tree diff --git a/torchrl/data/map/hash.py b/torchrl/data/map/hash.py index f5ba93e900f..01988dc43be 100644 --- a/torchrl/data/map/hash.py +++ b/torchrl/data/map/hash.py @@ -7,9 +7,10 @@ from typing import Callable, List import torch +from torch.nn import Module -class BinaryToDecimal(torch.nn.Module): +class BinaryToDecimal(Module): """A Module to convert binaries encoded tensors to decimals. This is a utility class that allow to convert a binary encoding tensor (e.g. `1001`) to @@ -71,7 +72,7 @@ def forward(self, features: torch.Tensor) -> torch.Tensor: return aggregated_digits -class SipHash(torch.nn.Module): +class SipHash(Module): """A Module to Compute SipHash values for given tensors. A hash function module based on SipHash implementation in python. diff --git a/torchrl/data/map/query.py b/torchrl/data/map/query.py index 73680999c8c..ff0fb4dfe24 100644 --- a/torchrl/data/map/query.py +++ b/torchrl/data/map/query.py @@ -12,7 +12,7 @@ from tensordict import NestedKey, TensorDictBase from tensordict.nn.common import TensorDictModuleBase from torchrl._utils import logger as torchrl_logger -from torchrl.data.map import SipHash +from torchrl.data.map.hash import SipHash K = TypeVar("K") V = TypeVar("V") @@ -82,7 +82,8 @@ class QueryModule(TensorDictModuleBase): provided but no aggregator is passed, it will default to ``SipHash``. clone (bool, optional): if ``True``, a shallow clone of the input TensorDict will be returned. This can be used to retrieve the integer index within the storage, - corresponding to a given input tensordict. + corresponding to a given input tensordict. This can be overridden at runtime by + providing the ``clone`` argument to the forward method. Defaults to ``False``. d Examples: @@ -168,8 +169,10 @@ def __init__( def forward( self, tensordict: TensorDictBase, + *, extend: bool = True, write_hash: bool = True, + clone: bool | None = None, ) -> TensorDictBase: hash_values = [] @@ -186,7 +189,8 @@ def forward( td_hash_value = self.hash_to_int(hash_values, extend=extend) - if self.clone: + clone = clone if clone is not None else self.clone + if clone: output = tensordict.copy() else: output = tensordict diff --git a/torchrl/data/map/tdstorage.py b/torchrl/data/map/tdstorage.py index 77dbc229b9e..a601f1e3261 100644 --- a/torchrl/data/map/tdstorage.py +++ b/torchrl/data/map/tdstorage.py @@ -171,10 +171,12 @@ def from_tensordict_pair( dest, in_keys: List[NestedKey], out_keys: List[NestedKey] | None = None, + max_size: int = 1000, storage_constructor: type | None = None, hash_module: Callable | None = None, collate_fn: Callable[[Any], Any] | None = None, write_fn: Callable[[Any, Any], Any] | None = None, + consolidated: bool | None = None, ): """Creates a new TensorDictStorage from a pair of tensordicts (source and dest) using pre-defined rules of thumb. @@ -185,16 +187,20 @@ def from_tensordict_pair( out_keys (List[NestedKey]): a list of keys to return in the output tensordict. All keys absent from out_keys, even if present in ``dest``, will not be stored in the storage. Defaults to ``None`` (all keys are registered). + max_size (int, optional): the maximum number of elements in the storage. Ignored if the + ``storage_constructor`` is passed. Defaults to ``1000``. storage_constructor (type, optional): a type of tensor storage. Defaults to :class:`~tensordict.nn.storage.LazyDynamicStorage`. Other options include :class:`~tensordict.nn.storage.FixedStorage`. - hash_module (Callable, optional): a hash function to use in the :class:`~tensordict.nn.storage.QueryModule`. - Defaults to :class:`SipHash` for low-dimensional inputs, and :class:`~tensordict.nn.storage.RandomProjectionHash` + hash_module (Callable, optional): a hash function to use in the :class:`~torchrl.data.map.QueryModule`. + Defaults to :class:`SipHash` for low-dimensional inputs, and :class:`~torchrl.data.map.RandomProjectionHash` for larger inputs. collate_fn (callable, optional): a function to use to collate samples from the storage. Defaults to a custom value for each known storage type (stack for :class:`~torchrl.data.ListStorage`, identity for :class:`~torchrl.data.TensorStorage` subtypes and others). + consolidated (bool, optional): whether to consolidate the storage in a single storage tensor. + Defaults to ``False``. Examples: >>> # The following example requires torchrl and gymnasium to be installed @@ -242,7 +248,13 @@ def from_tensordict_pair( # Build key_to_storage if storage_constructor is None: - storage_constructor = functools.partial(LazyTensorStorage, 1000) + storage_constructor = functools.partial( + LazyTensorStorage, max_size, consolidated=bool(consolidated) + ) + elif consolidated is not None: + storage_constructor = functools.partial( + storage_constructor, consolidated=consolidated + ) storage = storage_constructor() result = cls( query_module=query_module, @@ -257,8 +269,10 @@ def clear(self) -> None: for mem in self.storage.values(): mem.clear() - def _to_index(self, item: TensorDictBase, extend: bool) -> torch.Tensor: - item = self.query_module(item, extend=extend) + def _to_index( + self, item: TensorDictBase, extend: bool, clone: bool | None = None + ) -> torch.Tensor: + item = self.query_module(item, extend=extend, clone=clone) return item[self.index_key] def _maybe_add_batch( @@ -280,9 +294,10 @@ def _maybe_remove_batch(self, item: TensorDictBase) -> TensorDictBase: return item def __getitem__(self, item: TensorDictBase) -> TensorDictBase: + item = item.copy() item, _ = self._maybe_add_batch(item, None) - index = self._to_index(item, extend=False) + index = self._to_index(item, extend=False, clone=False) res = self.storage[index] res = self.collate_fn(res) @@ -316,7 +331,7 @@ def __len__(self): def contains(self, item: TensorDictBase) -> torch.Tensor: item, _ = self._maybe_add_batch(item, None) - index = self._to_index(item, extend=False) + index = self._to_index(item, extend=False, clone=True) res = self.storage.contains(index) res = self._maybe_remove_batch(res) diff --git a/torchrl/data/map/tree.py b/torchrl/data/map/tree.py new file mode 100644 index 00000000000..efbf6f79553 --- /dev/null +++ b/torchrl/data/map/tree.py @@ -0,0 +1,670 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +from __future__ import annotations + +from collections import deque + +from typing import Any, Callable, Dict, List, Literal, Tuple + +import torch +from tensordict import ( + merge_tensordicts, + NestedKey, + TensorClass, + TensorDict, + TensorDictBase, +) +from torchrl.data.map.tdstorage import TensorDictMap +from torchrl.data.map.utils import _plot_plotly_box, _plot_plotly_tree +from torchrl.data.replay_buffers.storages import ListStorage +from torchrl.envs.common import EnvBase + + +class Tree(TensorClass["nocast"]): + """Representation of a single MCTS (Monte Carlo Tree Search) Tree. + + This class encapsulates the data and behavior of a tree node in an MCTS algorithm. + It includes attributes for storing information about the node, such as its children, + visit count, and rollout data. Methods are provided for traversing the tree, + computing statistics, and visualizing the tree structure. + + It is somewhat indistinguishable from a node or a vertex - we use the term "Tree" when talking about + a node with children, "node" or "vertex" when talking about a place in the tree where a branching occurs. + A node in the tree is defined primarily by its ``hash`` value. Usually, a ``hash`` is determined by a unique + combination of state (or observation) and action. If one observation (found in the ``node`` attribute) has more than + one action associated, each branch will be stored in the ``subtree`` attribute as a stack of ``Tree`` instances. + + Attributes: + count (int): The number of visits to this node. + index (torch.Tensor): Indices of the child nodes in the data map. + hash (torch.Tensor): A hash value for this node. + It may be the case that ``hash`` is ``None`` in the specific case where the root of the tree + has more than one action associated. In that case, each subtree branch will have a different action + associated and a hash correspoding to the ``(observation, action)`` pair. + node_id (int): A unique identifier for this node. + rollout (TensorDict): Rollout data following the observation encoded in this node, in a TED format. + If there are multiple actions taken at this node, subtrees are stored in the corresponding + entry. Rollouts can be reconstructed using the :meth:`~.rollout_from_path` method. + node (TensorDict): Data defining this node (e.g., observations) before the next branching. + Entries usually matches the ``in_keys`` in ``MCTSForeset.node_map``. + subtree (Tree): A stack of subtrees produced when actions are taken. + num_children (int): The number of child nodes (read-only). + is_terminal (bool): whether the tree has children nodes (read-only). + If the tree is compact, ``is_terminal == True`` means that there are more than one child node in + ``self.subtree``. + + Methods: + __contains__: Whether another tree can be found in the tree. + vertices: Returns a dictionary containing all vertices in the tree. Keys must be paths, ids or hashes. + num_vertices: Returns the total number of vertices in the tree, with or without duplicates. + edges: Returns a list of edges in the tree. + valid_paths: Yields all valid paths in the tree. + max_length: Returns the maximum length of any path in the tree. + rollout_from_path: Reconstructs a rollout from a given path. + plot: Visualizes the tree using a specified backend and figure type. + get_node_by_id: returns the vertex given by its id in the tree. + get_node_by_hash: returns the vertex given by its hash in the forest. + + """ + + count: int = None + index: torch.Tensor | None = None + # The hash is None if the node has more than one action associated + hash: int | None = None + node_id: int | None = None + + # rollout following the observation encoded in node, in a TorchRL (TED) format + rollout: TensorDict | None = None + + # The data specifying the node + node: TensorDict | None = None + + # Stack of subtrees. A subtree is produced when an action is taken. + subtree: "Tree" = None + + @property + def num_children(self) -> int: + """Number of children of this node. + + Equates to the number of elements in the ``self.subtree`` stack. + """ + return len(self.subtree) if self.subtree is not None else 0 + + @property + def is_terminal(self): + """Returns True if the the tree has no children nodes.""" + return self.subtree is None + + def get_vertex_by_id(self, id: int) -> Tree: + """Goes through the tree and returns the node corresponding the given id.""" + q = deque() + q.append(self) + while len(q): + tree = q.popleft() + if tree.node_id == id: + return tree + if tree.subtree is not None: + q.extend(tree.subtree.unbind(0)) + raise ValueError(f"Node with id {id} not found.") + + def get_vertex_by_hash(self, hash: int) -> Tree: + """Goes through the tree and returns the node corresponding the given hash.""" + q = deque() + q.append(self) + while len(q): + tree = q.popleft() + if tree.hash == hash: + return tree + if tree.subtree is not None: + q.extend(tree.subtree.unbind(0)) + raise ValueError(f"Node with hash {hash} not found.") + + def __contains__(self, other: Tree) -> bool: + hash = other.hash + for vertex in self.vertices().values(): + if vertex.hash == hash: + return True + else: + return False + + def vertices( + self, *, key_type: Literal["id", "hash", "path"] = "hash" + ) -> Dict[int | Tuple[int], Tree]: + """Returns a map containing the vertices of the Tree. + + Keyword args: + key_type (Literal["id", "hash", "path"], optional): Specifies the type of key to use for the vertices. + + - "id": Use the vertex ID as the key. + - "hash": Use a hash of the vertex as the key. + - "path": Use the path to the vertex as the key. This may lead to a dictionary with a longer length than + when ``"id"`` or ``"hash"`` are used as the same node may be part of multiple trajectories. + Defaults to ``"hash"``. + + Defaults to an empty string, which may imply a default behavior. + + Returns: + Dict[int | Tuple[int], Tree]: A dictionary mapping keys to Tree vertices. + + """ + memo = set() + result = {} + q = deque() + cur_path = () + q.append((self, cur_path)) + use_hash = key_type == "hash" + use_id = key_type == "id" + use_path = key_type == "path" + while len(q): + tree, cur_path = q.popleft() + h = tree.hash + if h in memo and not use_path: + continue + memo.add(h) + r = tree.rollout + if r is not None: + r = r["next", "observation"] + if use_path: + result[cur_path] = tree + elif use_id: + result[tree.node_id] = tree + elif use_hash: + result[tree.node_id] = tree + else: + raise ValueError( + f"key_type must be either 'hash', 'id' or 'path'. Got {key_type}." + ) + + n = int(tree.num_children) + for i in range(n): + cur_path_tree = cur_path + (i,) + q.append((tree.subtree[i], cur_path_tree)) + return result + + def num_vertices(self, *, count_repeat: bool = False) -> int: + """Returns the number of unique vertices in the Tree. + + Keyword Args: + count_repeat (bool, optional): Determines whether to count repeated vertices. + - If ``False``, counts each unique vertex only once. + - If ``True``, counts vertices multiple times if they appear in different paths. + Defaults to ``False``. + + Returns: + int: The number of unique vertices in the Tree. + + """ + return len( + { + v.node_id + for v in self.vertices( + key_type="hash" if not count_repeat else "path" + ).values() + } + ) + + def edges(self) -> List[Tuple[int, int]]: + result = [] + q = deque() + parent = self.node_id + q.append((self, parent)) + while len(q): + tree, parent = q.popleft() + n = int(tree.num_children) + for i in range(n): + node = tree.subtree[i] + node_id = node.node_id + result.append((parent, node_id)) + q.append((node, node_id)) + return result + + def valid_paths(self): + q = deque() + cur_path = () + q.append((self, cur_path)) + while len(q): + tree, cur_path = q.popleft() + n = int(tree.num_children) + if not n: + yield cur_path + for i in range(n): + cur_path_tree = cur_path + (i,) + q.append((tree.subtree[i], cur_path_tree)) + + def max_length(self): + return max(*(len(path) for path in self.valid_paths())) + + def rollout_from_path(self, path: Tuple[int]) -> TensorDictBase | None: + r = self.rollout + tree = self + rollouts = [] + if r is not None: + rollouts.append(r) + for i in path: + tree = tree.subtree[i] + r = tree.rollout + if r is not None: + rollouts.append(r) + if rollouts: + return torch.cat(rollouts, dim=-1) + + @staticmethod + def _label(info: List[str], tree: "Tree", root=False): + labels = [] + for key in info: + if key == "hash": + hash = tree.hash + if hash is not None: + hash = hash.item() + v = f"hash={hash}" + elif root: + v = f"{key}=None" + else: + v = f"{key}={tree.rollout[key].mean().item()}" + + labels.append(v) + return ", ".join(labels) + + def plot( + self: Tree, + backend: str = "plotly", + figure: str = "tree", + info: List[str] = None, + make_labels: Callable[[Any], Any] | None = None, + ): + if backend == "plotly": + if figure == "box": + _plot_plotly_box(self) + return + elif figure == "tree": + _plot_plotly_tree(self, make_labels=make_labels) + return + else: + pass + raise NotImplementedError( + f"Unkown plotting backend {backend} with figure {figure}." + ) + + +class MCTSForest: + """A collection of MCTS trees. + + The class is aimed at storing rollouts in a storage, and produce trees based on a given root + in that dataset. + + Keyword Args: + data_map (TensorDictMap, optional): the storage to use to store the data + (observation, reward, states etc). If not provided, it is lazily + initialized using :meth:`~torchrl.data.map.tdstorage.TensorDictMap.from_tensordict_pair`. + node_map (TensorDictMap, optional): TODO + done_keys (list of NestedKey): the done keys of the environment. If not provided, + defaults to ``("done", "terminated", "truncated")``. + The :meth:`~.get_keys_from_env` can be used to automatically determine the keys. + action_keys (list of NestedKey): the action keys of the environment. If not provided, + defaults to ``("action",)``. + The :meth:`~.get_keys_from_env` can be used to automatically determine the keys. + reward_keys (list of NestedKey): the reward keys of the environment. If not provided, + defaults to ``("reward",)``. + The :meth:`~.get_keys_from_env` can be used to automatically determine the keys. + observation_keys (list of NestedKey): the observation keys of the environment. If not provided, + defaults to ``("observation",)``. + The :meth:`~.get_keys_from_env` can be used to automatically determine the keys. + consolidated (bool, optional): if ``True``, the data_map storage will be consolidated on disk. + Defaults to ``False``. + + Examples: + >>> from torchrl.envs import GymEnv + >>> import torch + >>> from tensordict import TensorDict, LazyStackedTensorDict + >>> from torchrl.data import TensorDictMap, ListStorage + >>> from torchrl.data.map.tree import MCTSForest + >>> + >>> from torchrl.envs import PendulumEnv, CatTensors, UnsqueezeTransform, StepCounter + >>> # Create the MCTS Forest + >>> forest = MCTSForest() + >>> # Create an environment. We're using a stateless env to be able to query it at any given state (like an oracle) + >>> env = PendulumEnv() + >>> obs_keys = list(env.observation_spec.keys(True, True)) + >>> state_keys = set(env.full_state_spec.keys(True, True)) - set(obs_keys) + >>> # Appending transforms to get an "observation" key that concatenates the observations together + >>> env = env.append_transform( + ... UnsqueezeTransform( + ... in_keys=obs_keys, + ... out_keys=[("unsqueeze", key) for key in obs_keys], + ... dim=-1 + ... ) + ... ) + >>> env = env.append_transform( + ... CatTensors([("unsqueeze", key) for key in obs_keys], "observation") + ... ) + >>> env = env.append_transform(StepCounter()) + >>> env.set_seed(0) + >>> # Get a reset state, then make a rollout out of it + >>> reset_state = env.reset() + >>> rollout0 = env.rollout(6, auto_reset=False, tensordict=reset_state.clone()) + >>> # Append the rollout to the forest. We're removing the state entries for clarity + >>> rollout0 = rollout0.copy() + >>> rollout0.exclude(*state_keys, inplace=True).get("next").exclude(*state_keys, inplace=True) + >>> forest.extend(rollout0) + >>> # The forest should have 6 elements (the length of the rollout) + >>> assert len(forest) == 6 + >>> # Let's make another rollout from the same reset state + >>> rollout1 = env.rollout(6, auto_reset=False, tensordict=reset_state.clone()) + >>> rollout1.exclude(*state_keys, inplace=True).get("next").exclude(*state_keys, inplace=True) + >>> forest.extend(rollout1) + >>> assert len(forest) == 12 + >>> # Let's make another final rollout from an intermediate step in the second rollout + >>> rollout1b = env.rollout(6, auto_reset=False, tensordict=rollout1[3].exclude("next")) + >>> rollout1b.exclude(*state_keys, inplace=True) + >>> rollout1b.get("next").exclude(*state_keys, inplace=True) + >>> forest.extend(rollout1b) + >>> assert len(forest) == 18 + >>> # Since we have 2 rollouts starting at the same state, our tree should have two + >>> # branches if we produce it from the reset entry. Take the state, and call `get_tree`: + >>> r = rollout0[0] + >>> # Let's get the compact tree that follows the initial reset. A compact tree is + >>> # a tree where nodes that have a single child are collapsed. + >>> tree = forest.get_tree(r) + >>> print(tree.max_length()) + 2 + >>> print(list(tree.valid_paths())) + [(0,), (1, 0), (1, 1)] + >>> from tensordict import assert_close + >>> # We can manually rebuild the tree + >>> assert_close( + ... rollout1, + ... torch.cat([tree.subtree[1].rollout, tree.subtree[1].subtree[0].rollout]), + ... intersection=True, + ... ) + True + >>> # Or we can rebuild it using the dedicated method + >>> assert_close( + ... rollout1, + ... tree.rollout_from_path((1, 0)), + ... intersection=True, + ... ) + True + >>> tree.plot() + >>> tree = forest.get_tree(r, compact=False) + >>> print(tree.max_length()) + 9 + >>> print(list(tree.valid_paths())) + [(0, 0, 0, 0, 0, 0), (1, 0, 0, 0, 0, 0), (1, 0, 0, 1, 0, 0, 0, 0, 0)] + >>> assert_close( + ... rollout1, + ... tree.rollout_from_path((1, 0, 0, 0, 0, 0)), + ... intersection=True, + ... ) + True + """ + + def __init__( + self, + *, + data_map: TensorDictMap | None = None, + node_map: TensorDictMap | None = None, + done_keys: List[NestedKey] | None = None, + reward_keys: List[NestedKey] = None, + observation_keys: List[NestedKey] = None, + action_keys: List[NestedKey] = None, + consolidated: bool | None = None, + ): + + self.data_map = data_map + + self.node_map = node_map + + self.done_keys = done_keys + self.action_keys = action_keys + self.reward_keys = reward_keys + self.observation_keys = observation_keys + self.consolidated = consolidated + + @property + def done_keys(self): + done_keys = getattr(self, "_done_keys", None) + if done_keys is None: + self._done_keys = done_keys = ("done", "terminated", "truncated") + return done_keys + + @done_keys.setter + def done_keys(self, value): + self._done_keys = value + + @property + def reward_keys(self): + reward_keys = getattr(self, "_reward_keys", None) + if reward_keys is None: + self._reward_keys = reward_keys = ("reward",) + return reward_keys + + @reward_keys.setter + def reward_keys(self, value): + self._reward_keys = value + + @property + def action_keys(self): + action_keys = getattr(self, "_action_keys", None) + if action_keys is None: + self._action_keys = action_keys = ("action",) + return action_keys + + @action_keys.setter + def action_keys(self, value): + self._action_keys = value + + @property + def observation_keys(self): + observation_keys = getattr(self, "_observation_keys", None) + if observation_keys is None: + self._observation_keys = observation_keys = ("observation",) + return observation_keys + + @observation_keys.setter + def observation_keys(self, value): + self._observation_keys = value + + def get_keys_from_env(self, env: EnvBase): + """Writes missing done, action and reward keys to the Forest given an environment. + + Existing keys are not overwritten. + """ + if getattr(self, "_reward_keys", None) is None: + self.reward_keys = env.reward_keys + if getattr(self, "_done_keys", None) is None: + self.done_keys = env.done_keys + if getattr(self, "_action_keys", None) is None: + self.action_keys = env.action_keys + if getattr(self, "_observation_keys", None) is None: + self.observation_keys = env.observation_keys + + @classmethod + def _write_fn_stack(cls, new, old=None): + if old is None: + result = new.apply(lambda x: x.unsqueeze(0), filter_empty=False) + result.set( + "count", torch.ones(result.shape, dtype=torch.int, device=result.device) + ) + else: + + def cat(name, x, y): + if name == "count": + return x + if y.ndim < x.ndim: + y = y.unsqueeze(0) + result = torch.cat([x, y], 0).unique(dim=0, sorted=False) + return result + + result = old.named_apply(cat, new, default=None) + result.set_("count", old.get("count") + 1) + return result + + def _make_storage(self, source, dest): + try: + self.data_map = TensorDictMap.from_tensordict_pair( + source, + dest, + in_keys=[*self.observation_keys, *self.action_keys], + consolidated=self.consolidated, + ) + except KeyError as err: + raise KeyError( + "A KeyError occurred during data map creation. This could be due to the wrong setting of a key in the MCTSForest constructor. Scroll up for more info." + ) from err + + def _make_storage_branches(self, source, dest): + self.node_map = TensorDictMap.from_tensordict_pair( + source, + dest, + in_keys=[*self.observation_keys], + out_keys=[ + *self.data_map.query_module.out_keys, # hash and index + # *self.action_keys, + # *[("next", rk) for rk in self.reward_keys], + "count", + ], + storage_constructor=ListStorage, + collate_fn=TensorDict.lazy_stack, + write_fn=self._write_fn_stack, + ) + + def extend(self, rollout): + source, dest = ( + rollout.exclude("next").copy(), + rollout.select("next", *self.action_keys).copy(), + ) + + if self.data_map is None: + self._make_storage(source, dest) + + # We need to set the action somewhere to keep track of what action lead to what child + # # Set the action in the 'next' + # dest[1:] = source[:-1].exclude(*self.done_keys) + + self.data_map[source] = dest + value = source + if self.node_map is None: + self._make_storage_branches(source, dest) + self.node_map[source] = TensorDict.lazy_stack(value.unbind(0)) + + def get_child(self, root: TensorDictBase) -> TensorDictBase: + return self.data_map[root] + + def _make_local_tree( + self, + root: TensorDictBase, + index: torch.Tensor | None = None, + compact: bool = True, + ) -> Tuple[Tree, torch.Tensor | None, torch.Tensor | None]: + root = root.select(*self.node_map.in_keys) + node_meta = None + if root in self.node_map: + node_meta = self.node_map[root] + if index is None: + node_meta = self.node_map[root] + index = node_meta["_index"] + elif index is not None: + pass + else: + return None + steps = [] + while index.numel() <= 1: + index = index.squeeze() + d = self.data_map.storage[index] + steps.append(merge_tensordicts(d, root, callback_exist=lambda *x: None)) + d = d["next"] + if d in self.node_map: + root = d.select(*self.node_map.in_keys) + node_meta = self.node_map[root] + index = node_meta["_index"] + if not compact: + break + else: + index = None + break + rollout = None + if steps: + rollout = torch.stack(steps, -1) + # Will be populated later + hash = node_meta["_hash"] + return ( + Tree( + rollout=rollout, + count=node_meta["count"], + node=root, + index=index, + hash=None, + subtree=None, + ), + index, + hash, + ) + + # The recursive implementation is slower and less compatible with compile + # def _make_tree(self, root: TensorDictBase, index: torch.Tensor|None=None)->Tree: + # tree, indices = self._make_local_tree(root, index=index) + # subtrees = [] + # if indices is not None: + # for i in indices: + # subtree = self._make_tree(tree.node, index=i) + # subtrees.append(subtree) + # subtrees = TensorDict.lazy_stack(subtrees) + # tree.subtree = subtrees + # return tree + def _make_tree_iter( + self, root, index=None, max_depth: int | None = None, compact: bool = True + ): + q = deque() + memo = {} + tree, indices, hash = self._make_local_tree(root, index=index) + tree.node_id = 0 + + result = tree + depth = 0 + counter = 1 + if indices is not None: + q.append((tree, indices, hash, depth)) + del tree, indices + + while len(q): + tree, indices, hash, depth = q.popleft() + extend = max_depth is None or depth < max_depth + subtrees = [] + for i, h in zip(indices, hash): + # TODO: remove the .item() + h = h.item() + subtree, subtree_indices, subtree_hash = memo.get(h, (None,) * 3) + if subtree is None: + subtree, subtree_indices, subtree_hash = self._make_local_tree( + tree.node, index=i, compact=compact + ) + subtree.node_id = counter + counter += 1 + subtree.hash = h + memo[h] = (subtree, subtree_indices, subtree_hash) + + subtrees.append(subtree) + if extend and subtree_indices is not None: + q.append((subtree, subtree_indices, subtree_hash, depth + 1)) + subtrees = TensorDict.lazy_stack(subtrees) + tree.subtree = subtrees + + return result + + def get_tree( + self, + root, + *, + max_depth: int | None = None, + compact: bool = True, + ) -> Tree: + return self._make_tree_iter(root=root, max_depth=max_depth, compact=compact) + + @classmethod + def valid_paths(cls, tree: Tree): + yield from tree.valid_paths() + + def __len__(self): + return len(self.data_map) diff --git a/torchrl/data/map/utils.py b/torchrl/data/map/utils.py new file mode 100644 index 00000000000..570214f1cb2 --- /dev/null +++ b/torchrl/data/map/utils.py @@ -0,0 +1,103 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +from __future__ import annotations + +from typing import Callable, List + +from tensordict import NestedKey + + +def _plot_plotly_tree( + tree: "Tree", make_labels: Callable[[Tree], str] | None = None # noqa: F821 +): + import plotly.graph_objects as go + from igraph import Graph + + if make_labels is None: + + def make_labels(tree): + return str((tree.node_id, tree.hash)) + + nr_vertices = tree.num_vertices() + vertices = tree.vertices() + + v_label = [make_labels(subtree) for subtree in vertices.values()] + G = Graph(nr_vertices, tree.edges()) + + layout = G.layout_sugiyama(range(nr_vertices)) + + position = {k: layout[k] for k in range(nr_vertices)} + # Y = [layout[k][1] for k in range(nr_vertices)] + # M = max(Y) + + # es = EdgeSeq(G) # sequence of edges + E = [e.tuple for e in G.es] # list of edges + + L = len(position) + Xn = [position[k][0] for k in range(L)] + # Yn = [2 * M - position[k][1] for k in range(L)] + Yn = [position[k][1] for k in range(L)] + Xe = [] + Ye = [] + for edge in E: + Xe += [position[edge[0]][0], position[edge[1]][0], None] + # Ye += [2 * M - position[edge[0]][1], 2 * M - position[edge[1]][1], None] + Ye += [position[edge[0]][1], position[edge[1]][1], None] + + labels = v_label + fig = go.Figure() + fig.add_trace( + go.Scatter( + x=Xe, + y=Ye, + mode="lines", + line={"color": "rgb(210,210,210)", "width": 5}, + hoverinfo="none", + ) + ) + fig.add_trace( + go.Scatter( + x=Xn, + y=Yn, + mode="markers+text", + name="bla", + marker={ + "symbol": "circle-dot", + "size": 40, + "color": "#6175c1", # '#DB4551', + "line": {"color": "rgb(50,50,50)", "width": 1}, + }, + text=labels, + hoverinfo="text", + textposition="middle right", + opacity=0.8, + ) + ) + fig.show() + + +def _plot_plotly_box(tree: "Tree", info: List[NestedKey] = None): # noqa: F821 + import plotly.graph_objects as go + + if info is None: + info = ["hash", ("next", "reward")] + + parents = [""] + labels = [tree._label(info, tree, root=True)] + + _tree = tree + + def extend(tree: "Tree", parent): # noqa: F821 + children = tree.subtree + if children is None: + return + for child in children: + labels.append(tree._label(info, child)) + parents.append(parent) + extend(child, labels[-1]) + + extend(_tree, labels[-1]) + fig = go.Figure(go.Treemap(labels=labels, parents=parents)) + fig.show() diff --git a/torchrl/data/replay_buffers/storages.py b/torchrl/data/replay_buffers/storages.py index dc2ddff0ef9..f029099eb7c 100644 --- a/torchrl/data/replay_buffers/storages.py +++ b/torchrl/data/replay_buffers/storages.py @@ -911,6 +911,8 @@ class LazyTensorStorage(TensorStorage): compilable (bool, optional): whether the storage is compilable. If ``True``, the writer cannot be shared between multiple processes. Defaults to ``False``. + consolidated (bool, optional): if ``True``, the storage will be consolidated after + its first expansion. Defaults to ``False``. Examples: >>> data = TensorDict({ @@ -971,6 +973,7 @@ def __init__( device: torch.device = "cpu", ndim: int = 1, compilable: bool = False, + consolidated: bool = False, ): super().__init__( storage=None, @@ -979,6 +982,7 @@ def __init__( ndim=ndim, compilable=compilable, ) + self.consolidated = consolidated def _init( self, @@ -1003,7 +1007,11 @@ def max_size_along_dim0(data_shape): if is_tensor_collection(data): out = data.to(self.device) - out = torch.empty_like(out.expand(max_size_along_dim0(data.shape))) + out: TensorDictBase = torch.empty_like( + out.expand(max_size_along_dim0(data.shape)) + ) + if self.consolidated: + out = out.consolidate() else: # if Tensor, we just create a MemoryMappedTensor of the desired shape, device and dtype out = tree_map( @@ -1014,6 +1022,8 @@ def max_size_along_dim0(data_shape): ), data, ) + if self.consolidated: + raise ValueError("Cannot consolidate non-tensordict storages.") self._storage = out self.initialized = True