Skip to content

Commit

Permalink
[Tests] Skip deprecation warning tests on FB fbcode
Browse files Browse the repository at this point in the history
ghstack-source-id: fb0cc381a670377667194324f5b019076b8e762d
Pull Request resolved: #1128
  • Loading branch information
vmoens committed Dec 4, 2024
1 parent d7529ab commit 22da679
Showing 1 changed file with 20 additions and 6 deletions.
26 changes: 20 additions & 6 deletions test/test_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import argparse
import contextlib
import copy
import os
import pickle
import unittest
import weakref
Expand Down Expand Up @@ -65,6 +66,7 @@
except ImportError:
from tensordict.utils import Buffer


# Capture all warnings
pytestmark = [
pytest.mark.filterwarnings("error"),
Expand All @@ -80,6 +82,18 @@
),
]

PYTORCH_TEST_FBCODE = os.getenv("PYTORCH_TEST_FBCODE")
if PYTORCH_TEST_FBCODE:
pytestmark.append(
pytest.mark.filterwarnings("ignore:aggregate_probabilities"),
)
pytestmark.append(
pytest.mark.filterwarnings("ignore:include_sum"),
)
pytestmark.append(
pytest.mark.filterwarnings("ignore:inplace"),
)


class TestInteractionType:
@pytest.mark.parametrize(
Expand Down Expand Up @@ -1091,17 +1105,17 @@ def test_probtdseq_multdist(self, include_sum, aggregate_probabilities, inplace)

v = tdm(TensorDict(x=torch.randn(10, 3)))
assert set(v.keys()) == {"x", "loc", "y", "loc2", "z"}
if aggregate_probabilities is None:
if aggregate_probabilities is None and not PYTORCH_TEST_FBCODE:
cm0 = pytest.warns(
expected_warning=DeprecationWarning, match="aggregate_probabilities"
)
else:
cm0 = contextlib.nullcontext()
if include_sum is None:
if include_sum is None and not PYTORCH_TEST_FBCODE:
cm1 = pytest.warns(expected_warning=DeprecationWarning, match="include_sum")
else:
cm1 = contextlib.nullcontext()
if inplace is None:
if inplace is None and not PYTORCH_TEST_FBCODE:
cm2 = pytest.warns(expected_warning=DeprecationWarning, match="inplace")
else:
cm2 = contextlib.nullcontext()
Expand Down Expand Up @@ -1150,17 +1164,17 @@ def test_probtdseq_intermediate_dist(

v = tdm(TensorDict(x=torch.randn(10, 3)))
assert set(v.keys()) == {"x", "loc", "y", "loc2"}
if aggregate_probabilities is None:
if aggregate_probabilities is None and not PYTORCH_TEST_FBCODE:
cm0 = pytest.warns(
expected_warning=DeprecationWarning, match="aggregate_probabilities"
)
else:
cm0 = contextlib.nullcontext()
if include_sum is None:
if include_sum is None and not PYTORCH_TEST_FBCODE:
cm1 = pytest.warns(expected_warning=DeprecationWarning, match="include_sum")
else:
cm1 = contextlib.nullcontext()
if inplace is None:
if inplace is None and not PYTORCH_TEST_FBCODE:
cm2 = pytest.warns(expected_warning=DeprecationWarning, match="inplace")
else:
cm2 = contextlib.nullcontext()
Expand Down

1 comment on commit 22da679

@github-actions
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Performance Alert ⚠️

Possible performance regression was detected for benchmark 'GPU Benchmark Results'.
Benchmark result of this commit is worse than the previous benchmark result exceeding threshold 2.

Benchmark suite Current: 22da679 Previous: d7529ab Ratio
benchmarks/common/common_ops_test.py::test_membership_stacked_nested_last 126929.41792517334 iter/sec (stddev: 5.781257540359028e-7) 334359.27259861614 iter/sec (stddev: 3.566247617085364e-7) 2.63
benchmarks/common/common_ops_test.py::test_membership_stacked_nested_leaf_last 126888.70214145185 iter/sec (stddev: 6.22221391115816e-7) 331574.41501332517 iter/sec (stddev: 3.7617805884965205e-7) 2.61

This comment was automatically generated by workflow using github-action-benchmark.

CC: @vmoens

Please sign in to comment.