diff --git a/ns_vfs/model/vision/_base.py b/ns_vfs/model/vision/_base.py
index 8a781ca..c9beb39 100644
--- a/ns_vfs/model/vision/_base.py
+++ b/ns_vfs/model/vision/_base.py
@@ -1,3 +1,5 @@
+import supervision as sv
+import numpy as np
 import abc
 
 
@@ -27,6 +29,22 @@ def get_weight(self):
         """Get weight."""
         return self._weight
 
+    def get_labels(self) -> list:
+        """Return detection labels."""
+        return self._labels
+
+    def get_detections(self) -> sv.Detections:
+        """Return sv.Detections."""
+        return self._detections
+
+    def get_confidence(self) -> np.ndarray:
+        """Return detection confidences."""
+        return self._confidence
+
+    def get_size(self) -> int:
+        """Return number of detections."""
+        return self._size
+
     @abc.abstractmethod
     def detect(self, frame) -> any:
         """Detect object in frame."""
diff --git a/ns_vfs/model/vision/grounding_dino.py b/ns_vfs/model/vision/grounding_dino.py
index 17e7b4b..078384c 100644
--- a/ns_vfs/model/vision/grounding_dino.py
+++ b/ns_vfs/model/vision/grounding_dino.py
@@ -1,14 +1,13 @@
 from __future__ import annotations
 
-import warnings
-
 from groundingdino.util.inference import Model
 from omegaconf import DictConfig
+import numpy as np
+import warnings
 
 from ns_vfs.model.vision._base import ComputerVisionDetector
 
 warnings.filterwarnings("ignore")
 
-import numpy as np
 
 class GroundingDino(ComputerVisionDetector):
@@ -59,11 +58,22 @@ def detect(self, frame_img: np.ndarray, classes: list) -> any:
 
         Returns:
             any: Detections.
         """
-        detections = self.model.predict_with_classes(
+        detected_obj = self.model.predict_with_classes(
             image=frame_img,
             classes=self._parse_class_name(class_names=classes),
             box_threshold=self._config.BOX_TRESHOLD,
             text_threshold=self._config.TEXT_TRESHOLD,
         )
-        return detections
+        self._labels = [
+            f"{classes[class_id] if class_id is not None else None} {confidence:0.2f}"
+            for _, _, confidence, class_id, _ in detected_obj
+        ]
+
+        self._detections = detected_obj
+
+        self._confidence = detected_obj.confidence
+
+        self._size = len(detected_obj)
+
+        return detected_obj
diff --git a/ns_vfs/model/vision/yolo.py b/ns_vfs/model/vision/yolo.py
index d79eb67..3c677be 100644
--- a/ns_vfs/model/vision/yolo.py
+++ b/ns_vfs/model/vision/yolo.py
@@ -1,14 +1,14 @@
 from __future__ import annotations
 
-import warnings
-
-from ultralytics import YOLO
 from omegaconf import DictConfig
+from ultralytics import YOLO
+import supervision as sv
+import numpy as np
+import warnings
 
 from ns_vfs.model.vision._base import ComputerVisionDetector
 
 warnings.filterwarnings("ignore")
 
-import numpy as np
 
 class Yolo(ComputerVisionDetector):
@@ -58,9 +58,21 @@ def detect(self, frame_img: np.ndarray, classes: list) -> any:
 
         """
         classes_reversed = {v: k for k, v in self.model.names.items()}
         class_ids = [classes_reversed[c] for c in classes]
-        detections = self.model.predict(
+        detected_obj = self.model.predict(
             source=frame_img,
             classes=class_ids
         )
-        return detections
+        self._labels = []
+        for i in range(len(detected_obj[0].boxes)):
+            class_id = int(detected_obj[0].boxes.cls[i])
+            confidence = float(detected_obj[0].boxes.conf[i])
+            self._labels.append(f"{detected_obj[0].names[class_id] if class_id is not None else None} {confidence:0.2f}")
+
+        self._detections = sv.Detections(xyxy=detected_obj[0].boxes.xyxy.cpu().detach().numpy())
+
+        self._confidence = detected_obj[0].boxes.conf.cpu().detach().numpy()
+
+        self._size = len(detected_obj[0].boxes)
+
+        return detected_obj
diff --git a/ns_vfs/video_to_automaton.py b/ns_vfs/video_to_automaton.py
index 2e161f6..92c3f38 100644
--- a/ns_vfs/video_to_automaton.py
+++ b/ns_vfs/video_to_automaton.py
@@ -64,29 +64,18 @@ def _sigmoid(self, x, k=1, x0=0) -> float:
 
     def _annotate_frame(
         self,
         frame_img: np.ndarray,
-        proposition: list,
-        detected_obj: any,
         output_dir: str | None = None,
     ) -> None:
         """Annotate frame with bounding box.
 
         Args:
             frame_img (np.ndarray): Frame image.
-            proposition (list): List of propositions.
-            detected_obj (any): Detected object.
             output_dir (str | None, optional): Output directory. Defaults to None.
         """
         box_annotator = sv.BoxAnnotator()
-        labels = []
-        for i in range(len(detected_obj[0].boxes)):
-            class_id = int(detected_obj[0].boxes.cls[i])
-            confidence = float(detected_obj[0].boxes.conf[i])
-            labels.append(f"{detected_obj[0].names[class_id] if class_id is not None else None} {confidence:0.2f}")
-
-        detections = sv.Detections(xyxy=detected_obj[0].boxes.xyxy.cpu().detach().numpy())
         annotated_frame = box_annotator.annotate(
-            scene=frame_img.copy(), detections=detections, labels=labels
+            scene=frame_img.copy(), detections=self._detector.get_detections(), labels=self._detector.get_labels()
         )
 
         sv.plot_image(annotated_frame, (16, 16))
@@ -161,15 +150,13 @@ def get_probabilistic_proposition_from_frame(
 
         Returns:
             float: Probabilistic proposition from frame.
         """
-        detected_obj = self._detector.detect(frame_img, [proposition])
-        if len(detected_obj[0].boxes) > 0:
+        self._detector.detect(frame_img, [proposition])
+        if self._detector.get_size() > 0:
             if is_annotation:
                 self._annotate_frame(
                     frame_img=frame_img,
-                    detected_obj=detected_obj,
-                    proposition=[proposition],
                 )
-            return self._mapping_probability(np.round(np.max(detected_obj[0].boxes.conf.cpu().detach().numpy()), 2))
+            return self._mapping_probability(np.round(np.max(self._detector.get_confidence()), 2))  # probability of the object in the frame
         else:
             return 0  # probability of the object in the frame is 0
diff --git a/run_frame_to_automata.py b/run_frame_to_automata.py
index 24a70ee..ff6762b 100644
--- a/run_frame_to_automata.py
+++ b/run_frame_to_automata.py
@@ -1,10 +1,9 @@
 from __future__ import annotations
 
 from ns_vfs.config.loader import load_config
+from ns_vfs.model.vision.grounding_dino import GroundingDino
 from ns_vfs.model.vision.yolo import Yolo
-from ns_vfs.processor.video_processor import (
-    VideoFrameWindowProcessor,
-)
+from ns_vfs.processor.video_processor import VideoFrameWindowProcessor
 from ns_vfs.video_to_automaton import VideotoAutomaton
 
 if __name__ == "__main__":
@@ -15,9 +14,14 @@
     config = load_config()
 
     frame2automaton = VideotoAutomaton(
-        detector=Yolo(
-            config=config.YOLO,
-            weight_path=config.YOLO.YOLO_CHECKPOINT_PATH,
+        # detector=Yolo(
+        #     config=config.YOLO,
+        #     weight_path=config.YOLO.YOLO_CHECKPOINT_PATH,
+        # ),
+        detector=GroundingDino(
+            config=config.GROUNDING_DINO,
+            weight_path=config.GROUNDING_DINO.GROUNDING_DINO_CHECKPOINT_PATH,
+            config_path=config.GROUNDING_DINO.GROUNDING_DINO_CONFIG_PATH,
         ),
         video_processor=VideoFrameWindowProcessor(
            video_path=sample_video_path,
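
Note (not part of the patch): the sketch below illustrates the call pattern this change introduces. detect() now caches labels, detections, confidences, and detection count on the detector instance, and callers such as get_probabilistic_proposition_from_frame read them back through the new getters instead of unpacking raw model output. StubDetector and the dummy frame are hypothetical stand-ins for illustration only; the real subclasses are GroundingDino and Yolo above.

import numpy as np


class StubDetector:
    """Hypothetical stand-in for a ComputerVisionDetector subclass."""

    def detect(self, frame_img: np.ndarray, classes: list) -> any:
        # A real detector runs its model here; we fake a single hit for
        # the first class, then cache state exactly as this diff does.
        self._labels = [f"{classes[0]} 0.87"]
        self._confidence = np.array([0.87])
        self._size = 1
        return None  # the real detectors return the raw model output

    def get_labels(self) -> list:
        return self._labels

    def get_confidence(self) -> np.ndarray:
        return self._confidence

    def get_size(self) -> int:
        return self._size


detector = StubDetector()
frame = np.zeros((480, 640, 3), dtype=np.uint8)  # dummy frame

# Same flow as get_probabilistic_proposition_from_frame after this diff:
detector.detect(frame, ["car"])
if detector.get_size() > 0:
    print(np.round(np.max(detector.get_confidence()), 2))  # -> 0.87

One trade-off worth noting: the detector is now stateful, so get_labels()/get_confidence()/get_size() reflect whichever frame was detected last; each detect() call must be paired with its reads before the next frame is processed.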