From 7040d0efada908a2bc35f0f3777332c63426bb11 Mon Sep 17 00:00:00 2001 From: Qiusheng Wu Date: Sat, 14 Sep 2024 00:05:10 -0400 Subject: [PATCH] Add video predictor --- samgeo/samgeo2.py | 155 +++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 154 insertions(+), 1 deletion(-) diff --git a/samgeo/samgeo2.py b/samgeo/samgeo2.py index ad0981b2..d93178ff 100644 --- a/samgeo/samgeo2.py +++ b/samgeo/samgeo2.py @@ -4,8 +4,9 @@ import numpy as np from PIL.Image import Image from typing import Any, Dict, List, Optional, Tuple, Union -from sam2.sam2_image_predictor import SAM2ImagePredictor from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator +from sam2.sam2_image_predictor import SAM2ImagePredictor +from sam2.sam2_video_predictor import SAM2VideoPredictor from . import common @@ -21,6 +22,7 @@ def __init__( device: Optional[str] = None, empty_cache: bool = True, automatic: bool = True, + video: bool = False, mode: str = "eval", hydra_overrides_extra: Optional[List[str]] = None, apply_postprocessing: bool = False, @@ -54,6 +56,7 @@ def __init__( device (Optional[str]): The device to use (e.g., "cpu", "cuda", "mps"). Defaults to None. empty_cache (bool): Whether to empty the cache. Defaults to True. automatic (bool): Whether to use automatic mask generation. Defaults to True. + video (bool): Whether to use video prediction; when True, automatic mask generation is turned off. Defaults to False. mode (str): The mode to use. Defaults to "eval". hydra_overrides_extra (Optional[List[str]]): Additional Hydra overrides. Defaults to None. apply_postprocessing (bool): Whether to apply postprocessing. Defaults to False. 
@@ -129,6 +132,9 @@ def __init__( self.model_id = model_id self.device = device + if video: + automatic = False + if automatic: self.mask_generator = SAM2AutomaticMaskGenerator.from_pretrained( model_id, @@ -154,11 +160,22 @@ def __init__( multimask_output=multimask_output, **kwargs, ) + elif video: + self.predictor = SAM2VideoPredictor.from_pretrained( + model_id, + device=device, + mode=mode, + hydra_overrides_extra=hydra_overrides_extra, + apply_postprocessing=apply_postprocessing, + **kwargs, + ) else: self.predictor = SAM2ImagePredictor.from_pretrained( model_id, device=device, mode=mode, + hydra_overrides_extra=hydra_overrides_extra, + apply_postprocessing=apply_postprocessing, mask_threshold=mask_threshold, max_hole_area=max_hole_area, max_sprinkle_area=max_sprinkle_area, @@ -574,3 +591,139 @@ def predict_batch( return_logits=return_logits, normalize_coords=normalize_coords, ) + + @torch.inference_mode() + def init_state( + self, + video_path: str, + offload_video_to_cpu: bool = False, + offload_state_to_cpu: bool = False, + async_loading_frames: bool = False, + ) -> Any: + """Initialize an inference state. + + Args: + video_path (str): The path to the video file. + offload_video_to_cpu (bool): Whether to offload the video to CPU. + Defaults to False. + offload_state_to_cpu (bool): Whether to offload the state to CPU. + Defaults to False. + async_loading_frames (bool): Whether to load frames asynchronously. + Defaults to False. + + Returns: + Any: The initialized inference state. 
+ """ + return self.predictor.init_state( + video_path, + offload_video_to_cpu=offload_video_to_cpu, + offload_state_to_cpu=offload_state_to_cpu, + async_loading_frames=async_loading_frames, + ) + + @torch.inference_mode() + def add_new_points_or_box( + self, + inference_state: Any, + frame_idx: int, + obj_id: int, + points: Optional[np.ndarray] = None, + labels: Optional[np.ndarray] = None, + clear_old_points: bool = True, + normalize_coords: bool = True, + box: Optional[np.ndarray] = None, + ) -> Any: + """Add new points or a box to the inference state. + + Args: + inference_state (Any): The current inference state. + frame_idx (int): The frame index. + obj_id (int): The object ID. + points (Optional[np.ndarray]): The points to add. Defaults to None. + labels (Optional[np.ndarray]): The labels for the points. Defaults to None. + clear_old_points (bool): Whether to clear old points. Defaults to True. + normalize_coords (bool): Whether to normalize the coordinates. Defaults to True. + box (Optional[np.ndarray]): The bounding box to add. Defaults to None. + + Returns: + Any: The result of the underlying predictor's add_new_points_or_box call. + """ + return self.predictor.add_new_points_or_box( + inference_state, + frame_idx, + obj_id, + points=points, + labels=labels, + clear_old_points=clear_old_points, + normalize_coords=normalize_coords, + box=box, + ) + + @torch.inference_mode() + def add_new_mask( + self, + inference_state: Any, + frame_idx: int, + obj_id: int, + mask: np.ndarray, + ) -> Any: + """Add a new mask to the inference state. + + Args: + inference_state (Any): The current inference state. + frame_idx (int): The frame index. + obj_id (int): The object ID. + mask (np.ndarray): The mask to add. + + Returns: + Any: The result of the underlying predictor's add_new_mask call. + """ + return self.predictor.add_new_mask(inference_state, frame_idx, obj_id, mask) + + @torch.inference_mode() + def propagate_in_video_preflight(self, inference_state: Any) -> Any: + """Run the underlying predictor's propagate_in_video_preflight step. 
+ + Args: + inference_state (Any): The current inference state. + + Returns: + Any: The result of the underlying predictor's propagate_in_video_preflight call. + """ + return self.predictor.propagate_in_video_preflight(inference_state) + + @torch.inference_mode() + def propagate_in_video( + self, + inference_state: Any, + start_frame_idx: Optional[int] = None, + max_frame_num_to_track: Optional[int] = None, + reverse: bool = False, + ) -> Any: + """Propagate the inference state in video. + + Args: + inference_state (Any): The current inference state. + start_frame_idx (Optional[int]): The starting frame index. Defaults to None. + max_frame_num_to_track (Optional[int]): The maximum number of frames + to track. Defaults to None. + reverse (bool): Whether to propagate in reverse. Defaults to False. + + Returns: + Any: The result of the underlying predictor's propagate_in_video call. + """ + return self.predictor.propagate_in_video( + inference_state, + start_frame_idx=start_frame_idx, + max_frame_num_to_track=max_frame_num_to_track, + reverse=reverse, + ) + + @torch.inference_mode() + def reset_state(self, inference_state: Any) -> None: + """Remove all input points or mask in all frames throughout the video. + + Args: + inference_state (Any): The current inference state. + """ + self.predictor.reset_state(inference_state)