Skip to content

Commit

Permalink
Add video predictor
Browse files Browse the repository at this point in the history
  • Loading branch information
giswqs committed Sep 14, 2024
1 parent f0d2e7c commit 7040d0e
Showing 1 changed file with 154 additions and 1 deletion.
155 changes: 154 additions & 1 deletion samgeo/samgeo2.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,9 @@
import numpy as np
from PIL.Image import Image
from typing import Any, Dict, List, Optional, Tuple, Union
from sam2.sam2_image_predictor import SAM2ImagePredictor
from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator
from sam2.sam2_image_predictor import SAM2ImagePredictor
from sam2.sam2_video_predictor import SAM2VideoPredictor

from . import common

Expand All @@ -21,6 +22,7 @@ def __init__(
device: Optional[str] = None,
empty_cache: bool = True,
automatic: bool = True,
video: bool = False,
mode: str = "eval",
hydra_overrides_extra: Optional[List[str]] = None,
apply_postprocessing: bool = False,
Expand Down Expand Up @@ -54,6 +56,7 @@ def __init__(
device (Optional[str]): The device to use (e.g., "cpu", "cuda", "mps"). Defaults to None.
empty_cache (bool): Whether to empty the cache. Defaults to True.
automatic (bool): Whether to use automatic mask generation. Defaults to True.
video (bool): Whether to use video prediction. Defaults to False.
mode (str): The mode to use. Defaults to "eval".
hydra_overrides_extra (Optional[List[str]]): Additional Hydra overrides. Defaults to None.
apply_postprocessing (bool): Whether to apply postprocessing. Defaults to False.
Expand Down Expand Up @@ -129,6 +132,9 @@ def __init__(
self.model_id = model_id
self.device = device

if video:
automatic = False

if automatic:
self.mask_generator = SAM2AutomaticMaskGenerator.from_pretrained(
model_id,
Expand All @@ -154,11 +160,22 @@ def __init__(
multimask_output=multimask_output,
**kwargs,
)
elif video:
self.predictor = SAM2VideoPredictor.from_pretrained(
model_id,
device=device,
mode=mode,
hydra_overrides_extra=hydra_overrides_extra,
apply_postprocessing=apply_postprocessing,
**kwargs,
)
else:
self.predictor = SAM2ImagePredictor.from_pretrained(
model_id,
device=device,
mode=mode,
hydra_overrides_extra=hydra_overrides_extra,
apply_postprocessing=apply_postprocessing,
mask_threshold=mask_threshold,
max_hole_area=max_hole_area,
max_sprinkle_area=max_sprinkle_area,
Expand Down Expand Up @@ -574,3 +591,139 @@ def predict_batch(
return_logits=return_logits,
normalize_coords=normalize_coords,
)

@torch.inference_mode()
def init_state(
self,
video_path: str,
offload_video_to_cpu: bool = False,
offload_state_to_cpu: bool = False,
async_loading_frames: bool = False,
) -> Any:
"""Initialize an inference state.
Args:
video_path (str): The path to the video file.
offload_video_to_cpu (bool): Whether to offload the video to CPU.
Defaults to False.
offload_state_to_cpu (bool): Whether to offload the state to CPU.
Defaults to False.
async_loading_frames (bool): Whether to load frames asynchronously.
Defaults to False.
Returns:
Any: The initialized inference state.
"""
return self.predictor.init_state(
video_path,
offload_video_to_cpu=offload_video_to_cpu,
offload_state_to_cpu=offload_state_to_cpu,
async_loading_frames=async_loading_frames,
)

@torch.inference_mode()
def add_new_points_or_box(
self,
inference_state: Any,
frame_idx: int,
obj_id: int,
points: Optional[np.ndarray] = None,
labels: Optional[np.ndarray] = None,
clear_old_points: bool = True,
normalize_coords: bool = True,
box: Optional[np.ndarray] = None,
) -> Any:
"""Add new points or a box to the inference state.
Args:
inference_state (Any): The current inference state.
frame_idx (int): The frame index.
obj_id (int): The object ID.
points (Optional[np.ndarray]): The points to add. Defaults to None.
labels (Optional[np.ndarray]): The labels for the points. Defaults to None.
clear_old_points (bool): Whether to clear old points. Defaults to True.
normalize_coords (bool): Whether to normalize the coordinates. Defaults to True.
box (Optional[np.ndarray]): The bounding box to add. Defaults to None.
Returns:
Any: The updated inference state.
"""
return self.predictor.add_new_points_or_box(
inference_state,
frame_idx,
obj_id,
points=points,
labels=labels,
clear_old_points=clear_old_points,
normalize_coords=normalize_coords,
box=box,
)

@torch.inference_mode()
def add_new_mask(
self,
inference_state: Any,
frame_idx: int,
obj_id: int,
mask: np.ndarray,
) -> Any:
"""Add a new mask to the inference state.
Args:
inference_state (Any): The current inference state.
frame_idx (int): The frame index.
obj_id (int): The object ID.
mask (np.ndarray): The mask to add.
Returns:
Any: The updated inference state.
"""
return self.predictor.add_new_mask(inference_state, frame_idx, obj_id, mask)

@torch.inference_mode()
def propagate_in_video_preflight(self, inference_state: Any) -> Any:
"""Propagate the inference state in video preflight.
Args:
inference_state (Any): The current inference state.
Returns:
Any: The propagated inference state.
"""
return self.predictor.propagate_in_video_preflight(inference_state)

@torch.inference_mode()
def propagate_in_video(
self,
inference_state: Any,
start_frame_idx: Optional[int] = None,
max_frame_num_to_track: Optional[int] = None,
reverse: bool = False,
) -> Any:
"""Propagate the inference state in video.
Args:
inference_state (Any): The current inference state.
start_frame_idx (Optional[int]): The starting frame index. Defaults to None.
max_frame_num_to_track (Optional[int]): The maximum number of frames
to track. Defaults to None.
reverse (bool): Whether to propagate in reverse. Defaults to False.
Returns:
Any: The propagated inference state.
"""
return self.predictor.propagate_in_video(
inference_state,
start_frame_idx=start_frame_idx,
max_frame_num_to_track=max_frame_num_to_track,
reverse=reverse,
)

@torch.inference_mode()
def reset_state(self, inference_state: Any) -> None:
"""Remove all input points or mask in all frames throughout the video.
Args:
inference_state (Any): The current inference state.
"""
self.predictor.reset_state(inference_state)

0 comments on commit 7040d0e

Please sign in to comment.