diff --git a/torchrl/envs/common.py b/torchrl/envs/common.py index a7c004bfcc5..0611af20b45 100644 --- a/torchrl/envs/common.py +++ b/torchrl/envs/common.py @@ -390,6 +390,11 @@ def __init__( self.batch_size = torch.Size(batch_size) self._run_type_checks = run_type_checks self._allow_done_after_reset = allow_done_after_reset + @wraps(check_env_specs_func) + def check_env_specs(self, *args, **kwargs): + return check_env_specs_func(self, *args, **kwargs) + + check_env_specs.__doc__ = check_env_specs_func.__doc__ @classmethod def __new__(cls, *args, _inplace_update=False, _batch_locked=True, **kwargs):