diff --git a/tests/unit/alexnet_model.py b/tests/unit/alexnet_model.py index 51e80e7f9e62..6fe84edf4eda 100644 --- a/tests/unit/alexnet_model.py +++ b/tests/unit/alexnet_model.py @@ -84,7 +84,7 @@ def cast_to_half(x): def cifar_trainset(fp16=False): torchvision = pytest.importorskip("torchvision", minversion="0.5.0") - import torchvision.transforms as transforms + from torchvision import transforms transform_list = [ transforms.ToTensor(),