diff --git a/test/test_nn.py b/test/test_nn.py index ff22382a9..580f115b0 100644 --- a/test/test_nn.py +++ b/test/test_nn.py @@ -6,6 +6,7 @@ import argparse import contextlib import copy +import os import pickle import unittest import weakref @@ -65,6 +66,7 @@ except ImportError: from tensordict.utils import Buffer + # Capture all warnings pytestmark = [ pytest.mark.filterwarnings("error"), @@ -80,6 +82,18 @@ ), ] +PYTORCH_TEST_FBCODE = os.getenv("PYTORCH_TEST_FBCODE") +if PYTORCH_TEST_FBCODE: + pytestmark.append( + pytest.mark.filterwarnings("ignore:aggregate_probabilities"), + ) + pytestmark.append( + pytest.mark.filterwarnings("ignore:include_sum"), + ) + pytestmark.append( + pytest.mark.filterwarnings("ignore:inplace"), + ) + class TestInteractionType: @pytest.mark.parametrize( @@ -1091,17 +1105,17 @@ def test_probtdseq_multdist(self, include_sum, aggregate_probabilities, inplace) v = tdm(TensorDict(x=torch.randn(10, 3))) assert set(v.keys()) == {"x", "loc", "y", "loc2", "z"} - if aggregate_probabilities is None: + if aggregate_probabilities is None and not PYTORCH_TEST_FBCODE: cm0 = pytest.warns( expected_warning=DeprecationWarning, match="aggregate_probabilities" ) else: cm0 = contextlib.nullcontext() - if include_sum is None: + if include_sum is None and not PYTORCH_TEST_FBCODE: cm1 = pytest.warns(expected_warning=DeprecationWarning, match="include_sum") else: cm1 = contextlib.nullcontext() - if inplace is None: + if inplace is None and not PYTORCH_TEST_FBCODE: cm2 = pytest.warns(expected_warning=DeprecationWarning, match="inplace") else: cm2 = contextlib.nullcontext() @@ -1150,17 +1164,17 @@ def test_probtdseq_intermediate_dist( v = tdm(TensorDict(x=torch.randn(10, 3))) assert set(v.keys()) == {"x", "loc", "y", "loc2"} - if aggregate_probabilities is None: + if aggregate_probabilities is None and not PYTORCH_TEST_FBCODE: cm0 = pytest.warns( expected_warning=DeprecationWarning, match="aggregate_probabilities" ) else: cm0 = contextlib.nullcontext() - if include_sum is None: + if include_sum is None and not PYTORCH_TEST_FBCODE: cm1 = pytest.warns(expected_warning=DeprecationWarning, match="include_sum") else: cm1 = contextlib.nullcontext() - if inplace is None: + if inplace is None and not PYTORCH_TEST_FBCODE: cm2 = pytest.warns(expected_warning=DeprecationWarning, match="inplace") else: cm2 = contextlib.nullcontext()