Skip to content

Commit

Permalink
Update model loading
Browse files Browse the repository at this point in the history
  • Loading branch information
ericup committed Aug 26, 2023
1 parent 9495b54 commit 207ce67
Showing 1 changed file with 5 additions and 2 deletions.
7 changes: 5 additions & 2 deletions scripts/cpn_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,10 @@ def apply_model(img, models, trainer, crop_size=(768, 768), strides=(384, 384),
h_tiles, w_tiles = tile_loader.num_slices_per_axis
nms_thresh = None
for model_name in models:
model = cd.models.LitCpn(model_name)
if model_name.endswith('.ckpt'):
model = cd.models.LitCpn.load_from_checkpoint(model_name, map_location='cpu')
else:
model = cd.load_model(model_name, map_location='cpu')
model.eval()
model.requires_grad_(False)
nms_thresh = kwargs.get('nms_thresh', model.model.nms_thresh)
Expand Down Expand Up @@ -301,7 +304,7 @@ def main():
outputs = args.outputs
makedirs(outputs, exist_ok=True)
if isdir(args.models):
models = sorted(glob(join(args.models, '*.pt*')))
models = sorted(glob(join(args.models, '*.pt'))) + sorted(glob(join(args.models, '*.ckpt')))
else:
models = sorted(glob(args.models))

Expand Down

0 comments on commit 207ce67

Please sign in to comment.