diff --git a/samgeo/samgeo2.py b/samgeo/samgeo2.py index ad3f4318..6021758f 100644 --- a/samgeo/samgeo2.py +++ b/samgeo/samgeo2.py @@ -211,10 +211,190 @@ def generate( self.masks = masks # Store the masks as a list of dictionaries self.batch = False - # if output is not None: - # # Save the masks to the output path. The output is either a binary mask or a mask of objects with unique values. - # self.save_masks( - # output, foreground, unique, erosion_kernel, mask_multiplier, **kwargs - # ) + if output is not None: + # Save the masks to the output path. The output is either a binary mask or a mask of objects with unique values. + self.save_masks( + output, foreground, unique, erosion_kernel, mask_multiplier, **kwargs + ) return masks + + def save_masks( + self, + output=None, + foreground=True, + unique=True, + erosion_kernel=None, + mask_multiplier=255, + **kwargs, + ): + """Save the masks to the output path. The output is either a binary mask or a mask of objects with unique values. + + Args: + output (str, optional): The path to the output image. Defaults to None, saving the masks to SamGeo.objects. + foreground (bool, optional): Whether to generate the foreground mask. Defaults to True. + unique (bool, optional): Whether to assign a unique value to each object. Defaults to True. + erosion_kernel (tuple, optional): The erosion kernel for filtering object masks and extract borders. + Such as (3, 3) or (5, 5). Set to None to disable it. Defaults to None. + mask_multiplier (int, optional): The mask multiplier for the output mask, which is usually a binary mask [0, 1]. + You can use this parameter to scale the mask to a larger range, for example [0, 255]. Defaults to 255. + + """ + + if self.masks is None: + raise ValueError("No masks found. Please run generate() first.") + + h, w, _ = self.image.shape + masks = self.masks + + # Set output image data type based on the number of objects + if len(masks) < 255: + dtype = np.uint8 + elif len(masks) < 65535: + dtype = np.uint16 + else: + dtype = np.uint32 + + # Generate a mask of objects with unique values + if unique: + # Sort the masks by area in ascending order + sorted_masks = sorted(masks, key=(lambda x: x["area"]), reverse=False) + + # Create an output image with the same size as the input image + objects = np.zeros( + ( + sorted_masks[0]["segmentation"].shape[0], + sorted_masks[0]["segmentation"].shape[1], + ) + ) + # Assign a unique value to each object + for index, ann in enumerate(sorted_masks): + m = ann["segmentation"] + objects[m] = index + 1 + + # Generate a binary mask + else: + if foreground: # Extract foreground objects only + resulting_mask = np.zeros((h, w), dtype=dtype) + else: + resulting_mask = np.ones((h, w), dtype=dtype) + resulting_borders = np.zeros((h, w), dtype=dtype) + + for m in masks: + mask = (m["segmentation"] > 0).astype(dtype) + resulting_mask += mask + + # Apply erosion to the mask + if erosion_kernel is not None: + mask_erode = cv2.erode(mask, erosion_kernel, iterations=1) + mask_erode = (mask_erode > 0).astype(dtype) + edge_mask = mask - mask_erode + resulting_borders += edge_mask + + resulting_mask = (resulting_mask > 0).astype(dtype) + resulting_borders = (resulting_borders > 0).astype(dtype) + objects = resulting_mask - resulting_borders + objects = objects * mask_multiplier + + objects = objects.astype(dtype) + self.objects = objects + + if output is not None: # Save the output image + common.array_to_image(self.objects, output, self.source, **kwargs) + + def show_masks( + self, figsize=(12, 10), cmap="binary_r", axis="off", foreground=True, **kwargs + ): + """Show the binary mask or the mask of objects with unique values. + + Args: + figsize (tuple, optional): The figure size. Defaults to (12, 10). + cmap (str, optional): The colormap. Defaults to "binary_r". + axis (str, optional): Whether to show the axis. Defaults to "off". + foreground (bool, optional): Whether to show the foreground mask only. Defaults to True. + **kwargs: Other arguments for save_masks(). + """ + + import matplotlib.pyplot as plt + + if self.batch: + self.objects = cv2.imread(self.masks) + else: + if self.objects is None: + self.save_masks(foreground=foreground, **kwargs) + + plt.figure(figsize=figsize) + plt.imshow(self.objects, cmap=cmap) + plt.axis(axis) + plt.show() + + def show_anns( + self, + figsize=(12, 10), + axis="off", + alpha=0.35, + output=None, + blend=True, + **kwargs, + ): + """Show the annotations (objects with random color) on the input image. + + Args: + figsize (tuple, optional): The figure size. Defaults to (12, 10). + axis (str, optional): Whether to show the axis. Defaults to "off". + alpha (float, optional): The alpha value for the annotations. Defaults to 0.35. + output (str, optional): The path to the output image. Defaults to None. + blend (bool, optional): Whether to show the input image. Defaults to True. + """ + + import matplotlib.pyplot as plt + + anns = self.masks + + if self.image is None: + print("Please run generate() first.") + return + + if anns is None or len(anns) == 0: + return + + plt.figure(figsize=figsize) + plt.imshow(self.image) + + sorted_anns = sorted(anns, key=(lambda x: x["area"]), reverse=True) + + ax = plt.gca() + ax.set_autoscale_on(False) + + img = np.ones( + ( + sorted_anns[0]["segmentation"].shape[0], + sorted_anns[0]["segmentation"].shape[1], + 4, + ) + ) + img[:, :, 3] = 0 + for ann in sorted_anns: + m = ann["segmentation"] + color_mask = np.concatenate([np.random.random(3), [alpha]]) + img[m] = color_mask + ax.imshow(img) + + if "dpi" not in kwargs: + kwargs["dpi"] = 100 + + if "bbox_inches" not in kwargs: + kwargs["bbox_inches"] = "tight" + + plt.axis(axis) + + self.annotations = (img[:, :, 0:3] * 255).astype(np.uint8) + + if output is not None: + if blend: + array = common.blend_images( + self.annotations, self.image, alpha=alpha, show=False + ) + else: + array = self.annotations + common.array_to_image(array, output, self.source)