diff --git a/keras_retinanet/models/resnet.py b/keras_retinanet/models/resnet.py index dde93433f..a8e87364b 100644 --- a/keras_retinanet/models/resnet.py +++ b/keras_retinanet/models/resnet.py @@ -26,6 +26,28 @@ custom_objects.update(keras_resnet.custom_objects) +def download_imagenet(backbone): + allowed_backbones = [50, 101, 152] + if backbone not in allowed_backbones: + raise ValueError('Backbone (\'{}\') not in allowed backbones ({}).'.format(backbone, allowed_backbones)) + + filename = resnet_filename.format(backbone) + resource = resnet_resource.format(backbone) + if backbone == 50: + checksum = '3e9f4e4f77bbe2c9bec13b53ee1c2319' + elif backbone == 101: + checksum = '05dc86924389e5b401a9ea0348a3213c' + elif backbone == 152: + checksum = '6ee11ef2b135592f8031058820bb9e71' + + return keras.applications.imagenet_utils.get_file( + filename, + resource, + cache_subdir='models', + md5_hash=checksum + ) + + def resnet_retinanet(num_classes, backbone=50, inputs=None, weights='imagenet', **kwargs): allowed_backbones = [50, 101, 152] if backbone not in allowed_backbones: @@ -37,21 +59,7 @@ def resnet_retinanet(num_classes, backbone=50, inputs=None, weights='imagenet', # determine which weights to load if weights == 'imagenet': - filename = resnet_filename.format(backbone) - resource = resnet_resource.format(backbone) - if backbone == 50: - checksum = '3e9f4e4f77bbe2c9bec13b53ee1c2319' - elif backbone == 101: - checksum = '05dc86924389e5b401a9ea0348a3213c' - elif backbone == 152: - checksum = '6ee11ef2b135592f8031058820bb9e71' - - weights_path = keras.applications.imagenet_utils.get_file( - filename, - resource, - cache_subdir='models', - md5_hash=checksum - ) + weights_path = download_imagenet(backbone) elif weights is None: weights_path = None else: