From fc883ec85f8bbd08fe934d1fa19a22b7051b3c04 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Thu, 3 Aug 2023 16:53:03 +0800 Subject: [PATCH] modify skip test conditions --- tests/test_cnn/test_wrappers.py | 37 +++++++++++++++------------------ 1 file changed, 17 insertions(+), 20 deletions(-) diff --git a/tests/test_cnn/test_wrappers.py b/tests/test_cnn/test_wrappers.py index 94c51ce15c..8c76ccbdd4 100644 --- a/tests/test_cnn/test_wrappers.py +++ b/tests/test_cnn/test_wrappers.py @@ -4,6 +4,8 @@ import pytest import torch import torch.nn as nn +from mmengine.utils import digit_version +from mmengine.utils.dl_utils import TORCH_VERSION from mmcv.cnn.bricks import (Conv2d, Conv3d, ConvTranspose2d, ConvTranspose3d, Linear, MaxPool2d, MaxPool3d) @@ -376,24 +378,19 @@ def test_nn_op_forward_called(): nn_module_forward.assert_called_with(x_normal) -@patch('mmcv.cnn.bricks.wrappers.TORCH_VERSION', (1, 10)) +@pytest.mark.skipif( + digit_version(TORCH_VERSION) < digit_version('1.10'), + reason='MaxPool2d and MaxPool3d will fail fx for torch<=1.9') def test_fx_compatibility(): - try: - from torch import fx - - # ensure the fx trace can pass the network - for Net in (MaxPool2d, MaxPool3d): - net = Net(1) - gm_module = fx.symbolic_trace(net) - print(gm_module.code) - for Net in (Linear, ): - net = Net(1, 1) - gm_module = fx.symbolic_trace(net) - print(gm_module.code) - for Net in (Conv2d, ConvTranspose2d, Conv3d, ConvTranspose3d): - net = Net(1, 1, 1) - gm_module = fx.symbolic_trace(net) - print(gm_module.code) - except ImportError: - # torch.fx might not be available - pass + from torch import fx + + # ensure the fx trace can pass the network + for Net in (MaxPool2d, MaxPool3d): + net = Net(1) + gm_module = fx.symbolic_trace(net) # noqa: F841 + for Net in (Linear, ): + net = Net(1, 1) + gm_module = fx.symbolic_trace(net) # noqa: F841 + for Net in (Conv2d, ConvTranspose2d, Conv3d, ConvTranspose3d): + net = Net(1, 1, 1) + gm_module = fx.symbolic_trace(net) # noqa: F841