diff --git a/scripts/cpn_inference.py b/scripts/cpn_inference.py index c873f21..a4827c6 100644 --- a/scripts/cpn_inference.py +++ b/scripts/cpn_inference.py @@ -250,6 +250,8 @@ def main(): parser.add_argument('--inputs_glob', default='*.*', type=str, help='Should `inputs` specify a directory, this' 'can be used to filter for specific files ' 'in that directory.') + parser.add_argument('--inputs_method', default='imageio', help='Method used for loading non-hdf5 inputs.') + parser.add_argument('--inputs_dataset', default=None, help='Dataset name for hdf5 inputs.') parser.add_argument('-m', '--models', default='/models', help='Model directory or filename.') parser.add_argument('--devices', default='auto', type=str, help='Devices.') parser.add_argument('--accelerator', default='auto', type=str, help='Accelerator.') @@ -279,6 +281,11 @@ def main(): 'Note: Intended for smaller images!') parser.add_argument('--truncated_images', action='store_true', help='Whether to support truncated images.') parser.add_argument('-p', '--properties', nargs='*', help='Region properties') + parser.add_argument('--spacing', default=1., type=float, help='The pixel spacing. Relevant for pixel-based ' + 'region properties.') + parser.add_argument('--separator', default='-', type=str, + help='Separator string for region properties that are written to multiple columns. ' + 'Default is "-" as in bbox-0, bbox-1, bbox-2, bbox-4.') args, unknown = parser.parse_known_args() if args.truncated_images: @@ -320,9 +327,15 @@ def main(): ) for src in inputs: - dst = join(outputs, splitext(basename(src))[0] + '{ext}') + prefix, ext = splitext(basename(src)) + dst = join(outputs, prefix + '{ext}') print(src, '-->', dst.format(ext='.*'), flush=True) - img = cd.load_image(src) + if ext in ('.h5', '.hdf5'): + assert args.inputs_dataset is not None, 'Please specify the dataset name for hdf5 inputs via --dataset ' + print('Read from h5:', args.inputs_dataset) + img = cd.from_h5(src, args.inputs_dataset) + else: + img = cd.load_image(src, method=args.inputs_method) y = cd.asnumpy(apply_model( img, models, trainer, crop_size=args.tile_size, @@ -353,10 +366,12 @@ def main(): if do_props: if args.flat_labels: assert flat_labels is not None - cd.data.labels2property_table(flat_labels, props).to_csv(dst.format(ext='_flat.csv')) + cd.data.labels2property_table(flat_labels, props, spacing=args.spacing, + separator=args.separator).to_csv(dst.format(ext='_flat.csv')) if args.labels or not args.flat_labels: assert labels is not None - cd.data.labels2property_table(labels, props).to_csv(dst.format(ext='.csv')) + cd.data.labels2property_table(labels, props, spacing=args.spacing, separator=args.separator).to_csv( + dst.format(ext='.csv')) if args.demo_figure: from matplotlib import pyplot as plt