diff --git a/hezar/models/text_detection/craft/craft_text_detection.py b/hezar/models/text_detection/craft/craft_text_detection.py index 8fd0734b..4d3c8cb2 100644 --- a/hezar/models/text_detection/craft/craft_text_detection.py +++ b/hezar/models/text_detection/craft/craft_text_detection.py @@ -3,16 +3,19 @@ import torch import torch.nn as nn import torch.nn.functional as F -from torchvision import models from ....constants import Backends from ....registry import register_model +from ....utils import is_backend_available from ...model import Model from ...model_outputs import TextDetectionOutput from .craft_text_detection_config import CraftTextDetectionConfig from .craft_utils import adjust_result_coordinates, get_detection_boxes, polys2boxes +if is_backend_available(Backends.TORCHVISION): + from torchvision.models import vgg16_bn + _required_backends = [ Backends.OPENCV, Backends.PILLOW, @@ -141,7 +144,7 @@ def post_process( class VGG16BN(nn.Module): def __init__(self): super(VGG16BN, self).__init__() - vgg_pretrained_features = models.vgg16_bn().features + vgg_pretrained_features = vgg16_bn().features self.slice1 = nn.Sequential() self.slice2 = nn.Sequential()