Add save video prediction blended
giswqs committed Sep 15, 2024
1 parent c38d149 commit e4b815f
Showing 2 changed files with 104 additions and 5 deletions.
2 changes: 1 addition & 1 deletion samgeo/common.py
@@ -3239,7 +3239,7 @@ def images_to_video(
height, width, _ = frame.shape
video_size = (width, height)

fourcc = cv2.VideoWriter_fourcc(*"mp4v") # Define the codec for mp4
fourcc = cv2.VideoWriter_fourcc(*"avc1") # Define the codec for mp4
video_writer = cv2.VideoWriter(output_video, fourcc, fps, video_size)

for image_path in images:
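The only change to common.py swaps the FourCC from "mp4v" (MPEG-4 Part 2) to "avc1" (H.264), so the resulting MP4s play natively in browsers and most players. One caveat, sketched below under the assumption that not every OpenCV build ships an H.264 encoder (the file name and frame size are illustrative, not from this commit): check that the writer actually opened and fall back if it did not.

import cv2
import numpy as np

# Illustrative sketch, not part of this commit: request H.264 ("avc1") and
# fall back to MPEG-4 Part 2 ("mp4v") if this OpenCV build cannot encode it.
def open_writer(path="demo.mp4", fps=30, size=(640, 480)):
    for codec in ("avc1", "mp4v"):
        writer = cv2.VideoWriter(path, cv2.VideoWriter_fourcc(*codec), fps, size)
        if writer.isOpened():
            return writer
    raise RuntimeError("no usable MP4 codec found")

writer = open_writer()
for _ in range(30):  # one second of black frames
    writer.write(np.zeros((480, 640, 3), dtype=np.uint8))
writer.release()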
107 changes: 103 additions & 4 deletions samgeo/samgeo2.py
@@ -2,7 +2,9 @@
import cv2
import torch
import numpy as np
+ import matplotlib.pyplot as plt
from PIL.Image import Image
+ from tqdm import tqdm
from typing import Any, Dict, List, Optional, Tuple, Union
from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator
from sam2.sam2_image_predictor import SAM2ImagePredictor
@@ -1047,6 +1049,10 @@ def _convert_prompts(self, prompts: Dict[int, Any]) -> Dict[int, Any]:
# Convert labels to np.int32 array
if "labels" in value:
value["labels"] = np.array(value["labels"], dtype=np.int32)
+ # Convert box to np.float32 array
+ if "box" in value:
+     value["box"] = np.array(value["box"], dtype=np.float32)

return prompts

def set_video(
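With the box branch in place, _convert_prompts normalizes all three prompt types, so callers can pass plain Python lists. A hedged example of the dict shape it accepts (coordinates illustrative; the box follows the xyxy pixel convention SAM 2 uses):

# Illustrative prompts dict: lists are fine on input, since _convert_prompts
# casts points/box to np.float32 and labels to np.int32.
prompts = {
    1: {  # object id 1: two foreground clicks on frame 0
        "points": [[210, 350], [250, 220]],
        "labels": [1, 1],  # 1 = foreground, 0 = background
        "frame_idx": 0,
    },
    2: {  # object id 2: a box-only prompt, possible after this change
        "box": [300, 0, 500, 400],  # (x_min, y_min, x_max, y_max)
        "frame_idx": 0,
    },
}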
@@ -1091,6 +1097,7 @@ def set_video(

self.video_path = output_dir
self._num_images = len(os.listdir(output_dir))
+ self._frame_names = sorted(os.listdir(output_dir))
self.inference_state = self.predictor.init_state(video_path=output_dir)

def predict_video(
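Caching the sorted frame listing in _frame_names gives the new blended renderer a way to map a frame index back to the image file that set_video extracted. A minimal usage sketch, assuming the package-level SamGeo2 class and a constructor flag that enables video mode (both assumptions, not shown in this diff):

from samgeo import SamGeo2  # assumed import path

sam = SamGeo2(video=True)  # constructor arguments are an assumption
sam.set_video("timelapse.mp4", output_dir="video_frames")
# After this call, sam._frame_names holds the sorted frame file names and
# sam.inference_state is initialized for prompting.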
@@ -1131,15 +1138,19 @@ def save_image_from_dict(data, output_path="output_image.png"):
predictor = self.predictor
inference_state = self.inference_state
for obj_id, prompt in prompts.items():
points = prompt["points"]
labels = prompt["labels"]
frame_idx = prompt["frame_idx"]

points = prompt.get("points", None)
labels = prompt.get("labels", None)
box = prompt.get("box", None)
frame_idx = prompt.get("frame_idx", None)

_, out_obj_ids, out_mask_logits = predictor.add_new_points_or_box(
inference_state=inference_state,
frame_idx=frame_idx,
obj_id=obj_id,
points=points,
labels=labels,
+ box=box,
)

video_segments = {}
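Reading each key with .get means a prompt may now carry points, a box, or both; whatever is absent simply becomes None, which add_new_points_or_box accepts. Continuing the sketch above:

# Register the prompts, then propagate masks through the whole clip.
sam.predict_video(prompts)
# sam.video_segments now maps frame_idx -> {obj_id: mask} for every frame.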
@@ -1202,8 +1213,96 @@ def save_image_from_dict(data, output_path="output_image.png"):
num_frames = len(self.video_segments)
num_digits = len(str(num_frames))

- for frame_idx, video_segment in self.video_segments.items():
+ # Initialize the tqdm progress bar
+ for frame_idx, video_segment in tqdm(
+     self.video_segments.items(), desc="Rendering frames", total=num_frames
+ ):
output_path = os.path.join(
output_dir, f"{str(frame_idx).zfill(num_digits)}.{img_ext}"
)
save_image_from_dict(video_segment, output_path)

+ def save_video_segments_blended(
+     self,
+     output_dir: str,
+     img_ext: str = "png",
+     dpi: int = 200,
+     frame_stride: int = 1,
+     output_video: Optional[str] = None,
+     fps: int = 30,
+ ) -> None:
+     """Save blended video segments to the output directory and optionally create a video.
+
+     Args:
+         output_dir (str): The directory to save the output images.
+         img_ext (str): The file extension for the output images. Defaults to "png".
+         dpi (int): The DPI (dots per inch) for the output images. Defaults to 200.
+         frame_stride (int): The stride for selecting frames to save. Defaults to 1.
+         output_video (Optional[str]): The path to the output video file. Defaults to None.
+         fps (int): The frames per second for the output video. Defaults to 30.
+     """
+
+     from PIL import Image
+
+     def show_mask(mask, ax, obj_id=None, random_color=False):
+         # Pick an RGBA color for this object: random, or a stable color
+         # from the tab10 colormap keyed by obj_id, at 60% opacity.
+         if random_color:
+             color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
+         else:
+             cmap = plt.get_cmap("tab10")
+             cmap_idx = 0 if obj_id is None else obj_id
+             color = np.array([*cmap(cmap_idx)[:3], 0.6])
+         # Broadcast the (h, w, 1) binary mask against the (1, 1, 4) color
+         # to get an RGBA overlay, then draw it on the given axes.
+         h, w = mask.shape[-2:]
+         mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
+         ax.imshow(mask_image)
+
+     if not os.path.exists(output_dir):
+         os.makedirs(output_dir)
+
+     plt.close("all")
+
+     video_segments = self.video_segments
+     video_dir = self.video_path
+     frame_names = self._frame_names
+     num_frames = len(frame_names)
+     num_digits = len(str(num_frames))
+
+     # Initialize the tqdm progress bar
+     for out_frame_idx in tqdm(
+         range(0, len(frame_names), frame_stride), desc="Rendering frames"
+     ):
+         image = Image.open(os.path.join(video_dir, frame_names[out_frame_idx]))
+
+         # Get original image dimensions
+         w, h = image.size
+
+         # Set DPI and calculate figure size based on the original image dimensions
+         figsize = (
+             w / dpi,
+             h / dpi,
+         )
+         figsize = (
+             figsize[0] * 1.3,
+             figsize[1] * 1.3,
+         )
+
+         # Create a figure with the exact size and DPI
+         fig = plt.figure(figsize=figsize, dpi=dpi)
+
+         # Disable axis to prevent whitespace
+         plt.axis("off")
+
+         # Display the original image
+         plt.imshow(image)
+
+         # Overlay masks for each object ID
+         for out_obj_id, out_mask in video_segments[out_frame_idx].items():
+             show_mask(out_mask, plt.gca(), obj_id=out_obj_id)
+
+         # Save the figure with no borders or extra padding
+         filename = f"{str(out_frame_idx).zfill(num_digits)}.{img_ext}"
+         filepath = os.path.join(output_dir, filename)
+         plt.savefig(filepath, dpi=dpi, pad_inches=0, bbox_inches="tight")
+         plt.close(fig)
+
+     if output_video is not None:
+         common.images_to_video(output_dir, output_video, fps=fps)
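A hedged end-to-end call for the new method (paths illustrative): render every frame with its masks blended on top, then let the patched images_to_video assemble an H.264 MP4.

sam.save_video_segments_blended(
    output_dir="blended_frames",
    img_ext="png",
    dpi=200,
    frame_stride=1,
    output_video="segmentation.mp4",
    fps=30,
)

The 1.3x figsize inflation appears to offset the margins that bbox_inches="tight" trims away at save time, keeping the rendered frames close to the source resolution.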
