-
Notifications
You must be signed in to change notification settings - Fork 0
/
image_to_masks.py
61 lines (49 loc) · 2.06 KB
/
image_to_masks.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
import os

import cv2
import numpy as np
import torch
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor
class ImageToMasks:
    """Segment an image with Segment Anything (SAM) and export per-mask cutouts.

    The SAM model is loaded once in ``__init__``. ``get_masks`` runs SAM's
    automatic mask generator and keeps the largest masks. ``get_cutouts`` and
    ``save_cutouts`` turn those masks into images in which every pixel outside
    the mask is zero.
    """

    def __init__(
        self,
        model_name="vit_h",
        checkpoint="./models/SAM/sam_vit_h_4b8939.pth",
        device=None,
    ):
        """Load a SAM model and move it to a torch device.

        Args:
            model_name: Key into ``sam_model_registry`` (e.g. ``"vit_h"``).
            checkpoint: Path to the weights file for ``model_name``.
            device: Torch device string. ``None`` (the default) picks
                ``"cuda"`` when available and ``"cpu"`` otherwise.
        """
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = device
        self.model_name = model_name
        self.checkpoint = checkpoint
        self.sam = sam_model_registry[self.model_name](checkpoint=self.checkpoint)
        self.sam = self.sam.to(device=self.device)

    def read_image(self, file):
        """Read an image file and return it as an RGB array.

        Args:
            file: Path to the image (``str`` or ``os.PathLike``).

        Returns:
            The decoded image with channels in RGB order.

        Raises:
            FileNotFoundError: If OpenCV cannot read ``file``. ``cv2.imread``
                reports this by returning ``None`` rather than raising.
        """
        image = cv2.imread(os.fspath(file))
        if image is None:
            raise FileNotFoundError(f"Could not read image: {file}")
        # OpenCV decodes into BGR channel order; convert to RGB.
        return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    def get_masks(self, image, count=5):
        """Run SAM's automatic mask generator and return the largest masks.

        Args:
            image: RGB image array, as returned by ``read_image``.
            count: Maximum number of masks to return; ``None`` returns all.

        Returns:
            Up to ``count`` mask records from
            ``SamAutomaticMaskGenerator.generate``, largest ``"area"`` first.
        """
        mask_generator = SamAutomaticMaskGenerator(model=self.sam)
        return self._select_largest(mask_generator.generate(image), count)

    @staticmethod
    def _select_largest(masks, count):
        """Return up to ``count`` mask records sorted by ``"area"``, largest first."""
        # sorted() is stable, so records with equal area keep their input order.
        return sorted(masks, key=lambda record: record["area"], reverse=True)[:count]

    def get_cutouts(self, image, count=5, masks=None):
        """Return one cutout image per mask.

        Args:
            image: Image array of shape (H, W) or (H, W, C).
            count: Number of largest masks to generate when ``masks`` is None.
            masks: Optional precomputed mask records, each with a
                ``"segmentation"`` array broadcastable to (H, W). When given,
                SAM is not run and the records are used in the order provided.

        Returns:
            A list of arrays with ``image``'s shape and dtype. In each array,
            every pixel outside the corresponding mask is 0.
        """
        if masks is None:
            masks = self.get_masks(image, count=count)
        return [self._apply_mask(image, record["segmentation"]) for record in masks]

    @staticmethod
    def _apply_mask(image, segmentation):
        """Zero every pixel of ``image`` that lies outside ``segmentation``."""
        image = np.asarray(image)
        keep = np.asarray(segmentation, dtype=bool)
        if image.ndim == 3:
            # Broadcast the (H, W) mask across every channel, including alpha.
            keep = keep[:, :, np.newaxis]
        # Multiplying by a bool array preserves image.dtype (bool promotes to it).
        return image * keep

    def save_cutouts(self, image, output_folder, count=5):
        """Write the cutouts of ``image`` to ``output_folder`` as PNG files.

        Files are named ``segment_1.png``, ``segment_2.png``, ... in order of
        decreasing mask area. ``output_folder`` is created if it is missing.

        Args:
            image: RGB image array, as returned by ``read_image``.
            output_folder: Destination directory.
            count: Maximum number of cutouts to write.

        Returns:
            The list of file paths written.

        Raises:
            OSError: If OpenCV fails to write a file.
        """
        os.makedirs(output_folder, exist_ok=True)
        saved = []
        for index, cutout in enumerate(self.get_cutouts(image, count=count), start=1):
            filename = os.path.join(output_folder, f"segment_{index}.png")
            # Cutouts are assumed RGB (as from read_image); imwrite expects BGR.
            if not cv2.imwrite(filename, cv2.cvtColor(cutout, cv2.COLOR_RGB2BGR)):
                raise OSError(f"Failed to write {filename}")
            print(f"Saved: {filename}")
            saved.append(filename)
        return saved
if __name__ == "__main__":
    # Command-line entry point. With no arguments it processes the original
    # hard-coded example image and writes the cutouts into ./images.
    import argparse

    parser = argparse.ArgumentParser(
        description="Segment an image with SAM and save the largest masks as cutouts."
    )
    parser.add_argument(
        "image",
        nargs="?",
        default="images/telephone_booth.jpg",
        help="input image path (default: %(default)s)",
    )
    parser.add_argument(
        "output_folder",
        nargs="?",
        default="./images",
        help="folder that receives segment_N.png files (default: %(default)s)",
    )
    args = parser.parse_args()

    segmenter = ImageToMasks()
    image = segmenter.read_image(args.image)
    segmenter.save_cutouts(image, args.output_folder)