diff --git a/tensordict/nn/probabilistic.py b/tensordict/nn/probabilistic.py index 65257f74b..497394a99 100644 --- a/tensordict/nn/probabilistic.py +++ b/tensordict/nn/probabilistic.py @@ -56,10 +56,7 @@ class InteractionType(StrEnum): @classmethod def from_str(cls, type_str: str) -> InteractionType: """Return the interaction_type with name matched to the provided string (case insensitive).""" - for member_type in cls: - if member_type.name == type_str.upper(): - return member_type - raise ValueError(f"The provided interaction type {type_str} is unsupported!") + return cls(type_str.lower()) _INTERACTION_TYPE: InteractionType | None = None @@ -74,13 +71,16 @@ class set_interaction_type(_DecoratorContextManager): """Sets all ProbabilisticTDModules sampling to the desired type. Args: - type (InteractionType): sampling type to use when the policy is being called. + type (InteractionType or str): sampling type to use when the policy is being called. + """ def __init__( - self, type: InteractionType | None = InteractionType.DETERMINISTIC + self, type: InteractionType | str | None = InteractionType.DETERMINISTIC ) -> None: super().__init__() + if isinstance(type, str): + type = InteractionType(type.lower()) self.type = type def clone(self) -> set_interaction_type: