Skip to content

Commit

Permalink
Add mask and anns functions
Browse files Browse the repository at this point in the history
  • Loading branch information
giswqs committed Sep 13, 2024
1 parent 3590d05 commit 916e5aa
Showing 1 changed file with 185 additions and 5 deletions.
190 changes: 185 additions & 5 deletions samgeo/samgeo2.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,10 +211,190 @@ def generate(
self.masks = masks # Store the masks as a list of dictionaries
self.batch = False

# if output is not None:
# # Save the masks to the output path. The output is either a binary mask or a mask of objects with unique values.
# self.save_masks(
# output, foreground, unique, erosion_kernel, mask_multiplier, **kwargs
# )
if output is not None:
# Save the masks to the output path. The output is either a binary mask or a mask of objects with unique values.
self.save_masks(
output, foreground, unique, erosion_kernel, mask_multiplier, **kwargs
)

return masks

def save_masks(
self,
output=None,
foreground=True,
unique=True,
erosion_kernel=None,
mask_multiplier=255,
**kwargs,
):
"""Save the masks to the output path. The output is either a binary mask or a mask of objects with unique values.
Args:
output (str, optional): The path to the output image. Defaults to None, saving the masks to SamGeo.objects.
foreground (bool, optional): Whether to generate the foreground mask. Defaults to True.
unique (bool, optional): Whether to assign a unique value to each object. Defaults to True.
erosion_kernel (tuple, optional): The erosion kernel for filtering object masks and extract borders.
Such as (3, 3) or (5, 5). Set to None to disable it. Defaults to None.
mask_multiplier (int, optional): The mask multiplier for the output mask, which is usually a binary mask [0, 1].
You can use this parameter to scale the mask to a larger range, for example [0, 255]. Defaults to 255.
"""

if self.masks is None:
raise ValueError("No masks found. Please run generate() first.")

h, w, _ = self.image.shape
masks = self.masks

# Set output image data type based on the number of objects
if len(masks) < 255:
dtype = np.uint8
elif len(masks) < 65535:
dtype = np.uint16
else:
dtype = np.uint32

# Generate a mask of objects with unique values
if unique:
# Sort the masks by area in ascending order
sorted_masks = sorted(masks, key=(lambda x: x["area"]), reverse=False)

# Create an output image with the same size as the input image
objects = np.zeros(
(
sorted_masks[0]["segmentation"].shape[0],
sorted_masks[0]["segmentation"].shape[1],
)
)
# Assign a unique value to each object
for index, ann in enumerate(sorted_masks):
m = ann["segmentation"]
objects[m] = index + 1

# Generate a binary mask
else:
if foreground: # Extract foreground objects only
resulting_mask = np.zeros((h, w), dtype=dtype)
else:
resulting_mask = np.ones((h, w), dtype=dtype)
resulting_borders = np.zeros((h, w), dtype=dtype)

for m in masks:
mask = (m["segmentation"] > 0).astype(dtype)
resulting_mask += mask

# Apply erosion to the mask
if erosion_kernel is not None:
mask_erode = cv2.erode(mask, erosion_kernel, iterations=1)
mask_erode = (mask_erode > 0).astype(dtype)
edge_mask = mask - mask_erode
resulting_borders += edge_mask

resulting_mask = (resulting_mask > 0).astype(dtype)
resulting_borders = (resulting_borders > 0).astype(dtype)
objects = resulting_mask - resulting_borders
objects = objects * mask_multiplier

objects = objects.astype(dtype)
self.objects = objects

if output is not None: # Save the output image
common.array_to_image(self.objects, output, self.source, **kwargs)

def show_masks(
self, figsize=(12, 10), cmap="binary_r", axis="off", foreground=True, **kwargs
):
"""Show the binary mask or the mask of objects with unique values.
Args:
figsize (tuple, optional): The figure size. Defaults to (12, 10).
cmap (str, optional): The colormap. Defaults to "binary_r".
axis (str, optional): Whether to show the axis. Defaults to "off".
foreground (bool, optional): Whether to show the foreground mask only. Defaults to True.
**kwargs: Other arguments for save_masks().
"""

import matplotlib.pyplot as plt

if self.batch:
self.objects = cv2.imread(self.masks)
else:
if self.objects is None:
self.save_masks(foreground=foreground, **kwargs)

plt.figure(figsize=figsize)
plt.imshow(self.objects, cmap=cmap)
plt.axis(axis)
plt.show()

def show_anns(
self,
figsize=(12, 10),
axis="off",
alpha=0.35,
output=None,
blend=True,
**kwargs,
):
"""Show the annotations (objects with random color) on the input image.
Args:
figsize (tuple, optional): The figure size. Defaults to (12, 10).
axis (str, optional): Whether to show the axis. Defaults to "off".
alpha (float, optional): The alpha value for the annotations. Defaults to 0.35.
output (str, optional): The path to the output image. Defaults to None.
blend (bool, optional): Whether to show the input image. Defaults to True.
"""

import matplotlib.pyplot as plt

anns = self.masks

if self.image is None:
print("Please run generate() first.")
return

if anns is None or len(anns) == 0:
return

plt.figure(figsize=figsize)
plt.imshow(self.image)

sorted_anns = sorted(anns, key=(lambda x: x["area"]), reverse=True)

ax = plt.gca()
ax.set_autoscale_on(False)

img = np.ones(
(
sorted_anns[0]["segmentation"].shape[0],
sorted_anns[0]["segmentation"].shape[1],
4,
)
)
img[:, :, 3] = 0
for ann in sorted_anns:
m = ann["segmentation"]
color_mask = np.concatenate([np.random.random(3), [alpha]])
img[m] = color_mask
ax.imshow(img)

if "dpi" not in kwargs:
kwargs["dpi"] = 100

if "bbox_inches" not in kwargs:
kwargs["bbox_inches"] = "tight"

plt.axis(axis)

self.annotations = (img[:, :, 0:3] * 255).astype(np.uint8)

if output is not None:
if blend:
array = common.blend_images(
self.annotations, self.image, alpha=alpha, show=False
)
else:
array = self.annotations
common.array_to_image(array, output, self.source)

0 comments on commit 916e5aa

Please sign in to comment.