Skip to content

Commit

Permalink
Move ResNet ImageNet downloading in separate function.
Browse files Browse the repository at this point in the history
  • Loading branch information
Hans Gaiser authored and hgaiser committed Jan 9, 2018
1 parent 62c3741 commit c5860e7
Showing 1 changed file with 23 additions and 15 deletions.
38 changes: 23 additions & 15 deletions keras_retinanet/models/resnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,28 @@
custom_objects.update(keras_resnet.custom_objects)


def download_imagenet(backbone):
allowed_backbones = [50, 101, 152]
if backbone not in allowed_backbones:
raise ValueError('Backbone (\'{}\') not in allowed backbones ({}).'.format(backbone, allowed_backbones))

filename = resnet_filename.format(backbone)
resource = resnet_resource.format(backbone)
if backbone == 50:
checksum = '3e9f4e4f77bbe2c9bec13b53ee1c2319'
elif backbone == 101:
checksum = '05dc86924389e5b401a9ea0348a3213c'
elif backbone == 152:
checksum = '6ee11ef2b135592f8031058820bb9e71'

return keras.applications.imagenet_utils.get_file(
filename,
resource,
cache_subdir='models',
md5_hash=checksum
)


def resnet_retinanet(num_classes, backbone=50, inputs=None, weights='imagenet', **kwargs):
allowed_backbones = [50, 101, 152]
if backbone not in allowed_backbones:
Expand All @@ -37,21 +59,7 @@ def resnet_retinanet(num_classes, backbone=50, inputs=None, weights='imagenet',

# determine which weights to load
if weights == 'imagenet':
filename = resnet_filename.format(backbone)
resource = resnet_resource.format(backbone)
if backbone == 50:
checksum = '3e9f4e4f77bbe2c9bec13b53ee1c2319'
elif backbone == 101:
checksum = '05dc86924389e5b401a9ea0348a3213c'
elif backbone == 152:
checksum = '6ee11ef2b135592f8031058820bb9e71'

weights_path = keras.applications.imagenet_utils.get_file(
filename,
resource,
cache_subdir='models',
md5_hash=checksum
)
weights_path = download_imagenet(backbone)
elif weights is None:
weights_path = None
else:
Expand Down

0 comments on commit c5860e7

Please sign in to comment.