diff --git a/tests/modules/test_feature_extraction.py b/tests/modules/test_feature_extraction.py index d7e36965..20066481 100644 --- a/tests/modules/test_feature_extraction.py +++ b/tests/modules/test_feature_extraction.py @@ -205,6 +205,8 @@ def test_feature_extraction_with_checkpoint(model_name, length_target): features_only=True, ) + print(ms.get_context("device_target")) + assert isinstance(model, nn.Cell), "Loading checkpoint error" x = ms.Tensor(np.random.randn(8, 3, 32, 32), dtype=ms.float32)