diff --git a/mmdet/datasets/transforms/transforms.py b/mmdet/datasets/transforms/transforms.py index 4ac2bf75b54..c50b987db33 100644 --- a/mmdet/datasets/transforms/transforms.py +++ b/mmdet/datasets/transforms/transforms.py @@ -1766,8 +1766,10 @@ def _postprocess_results( results['masks'] = np.array( [results['masks'][i] for i in results['idx_mapper']]) results['masks'] = ori_masks.__class__( - results['masks'], ori_masks.height, ori_masks.width) - + results['masks'], + results['masks'][0].shape[0], + results['masks'][0].shape[1], + ) if (not len(results['idx_mapper']) and self.skip_img_without_anno): return None