diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index db0af456..49ee9973 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -18,10 +18,10 @@ jobs: strategy: matrix: os: [ubuntu-latest] - python-version: ["3.7", "3.8", "3.9"] - include: - - os: macos-latest - python-version: "3.8" + python-version: ["3.9"] +# include: +# - os: macos-latest +# python-version: "3.8" runs-on: ${{ matrix.os }} steps: diff --git a/tests/modules/test_feature_extraction.py b/tests/modules/test_feature_extraction.py index d7e36965..fb19a94c 100644 --- a/tests/modules/test_feature_extraction.py +++ b/tests/modules/test_feature_extraction.py @@ -161,53 +161,53 @@ def test_feature_extraction_indices_using_feature_wrapper(mode): np.testing.assert_equal(y0[1].asnumpy(), y1[0].asnumpy()) -@pytest.mark.parametrize( - "model_name, length_target", - [ - ( - "resnet18", - 5, - ), - ( - "mobilenet_v3_small_100", - 5, - ), - ( - "convnext_tiny", - 4, - ), - ( - "resnest50", - 5, - ), - ( - "efficientnet_b0", - 5, - ), - ( - "repvgg_a0", - 5, - ), - ( - "hrnet_w32", - 5, - ), - ( - "rexnet_10", - 5, - ), - ], -) -def test_feature_extraction_with_checkpoint(model_name, length_target): - model = create_model( - model_name=model_name, - pretrained=True, - features_only=True, - ) - - assert isinstance(model, nn.Cell), "Loading checkpoint error" - - x = ms.Tensor(np.random.randn(8, 3, 32, 32), dtype=ms.float32) - out = model(x) - - assert len(out) == length_target, "Wrong feature extraction" +# @pytest.mark.parametrize( +# "model_name, length_target", +# [ +# ( +# "resnet18", +# 5, +# ), +# ( +# "mobilenet_v3_small_100", +# 5, +# ), +# ( +# "convnext_tiny", +# 4, +# ), +# ( +# "resnest50", +# 5, +# ), +# ( +# "efficientnet_b0", +# 5, +# ), +# ( +# "repvgg_a0", +# 5, +# ), +# ( +# "hrnet_w32", +# 5, +# ), +# ( +# "rexnet_10", +# 5, +# ), +# ], +# ) +# def test_feature_extraction_with_checkpoint(model_name, length_target): +# model = create_model( +# model_name=model_name, +# pretrained=True, +# features_only=True, +# ) +# +# assert isinstance(model, nn.Cell), "Loading checkpoint error" +# +# x = ms.Tensor(np.random.randn(8, 3, 32, 32), dtype=ms.float32) +# out = model(x) +# +# assert len(out) == length_target, "Wrong feature extraction" diff --git a/tests/modules/test_models.py b/tests/modules/test_models.py index 49607d87..24ba7f50 100644 --- a/tests/modules/test_models.py +++ b/tests/modules/test_models.py @@ -2,16 +2,14 @@ sys.path.append(".") -import numpy as np -import pytest - -import mindspore as ms -from mindspore import Tensor +# import numpy as np +# import pytest +# +# import mindspore as ms +# from mindspore import Tensor from mindcv import list_models, list_modules -from mindcv.models import ( - create_model, - get_pretrained_cfg_value, +from mindcv.models import ( # create_model,; get_pretrained_cfg_value, is_model_in_modules, is_model_pretrained, model_entrypoint, @@ -64,20 +62,20 @@ # @pytest.mark.parametrize('mode', [ms.PYNATIVE_MODE, ms.GRAPH_MODE]) -@pytest.mark.parametrize("name", model_name_list) -def test_model_forward(name): - # ms.set_context(mode=ms.PYNATIVE_MODE) - bs = 2 - c = 10 - model = create_model(model_name=name, num_classes=c) - input_size = get_pretrained_cfg_value(model_name=name, cfg_key="input_size") - if input_size: - input_size = (bs,) + tuple(input_size) - else: - input_size = (bs, 3, 224, 224) - dummy_input = Tensor(np.random.rand(*input_size), dtype=ms.float32) - y = model(dummy_input) - assert y.shape == (bs, 10), "output shape not match" +# @pytest.mark.parametrize("name", model_name_list) +# def test_model_forward(name): +# # ms.set_context(mode=ms.PYNATIVE_MODE) +# bs = 2 +# c = 10 +# model = create_model(model_name=name, num_classes=c) +# input_size = get_pretrained_cfg_value(model_name=name, cfg_key="input_size") +# if input_size: +# input_size = (bs,) + tuple(input_size) +# else: +# input_size = (bs, 3, 224, 224) +# dummy_input = Tensor(np.random.rand(*input_size), dtype=ms.float32) +# y = model(dummy_input) +# assert y.shape == (bs, 10), "output shape not match" """ @@ -154,7 +152,7 @@ def test_is_model_pretrained(): if __name__ == "__main__": - test_model_forward("pnasnet") + # test_model_forward("pnasnet") """ for model in model_name_list: if '384' in model: