From 217353a3a2fc919d2330d9c74f78e40d96d72b35 Mon Sep 17 00:00:00 2001 From: Shay Aharon <80472096+shaydeci@users.noreply.github.com> Date: Wed, 22 May 2024 11:47:42 +0300 Subject: [PATCH] Feature/sg 1442 sliding window inference for yolonas (#1979) * wip * wip * wip2 * working version, hard coded nms params * moved post prediction callback to utils * moved back to wrapper * added abstract class, small refactoring for pipeline * rolled back customizable detector, solved pretrained weights setting of proccessing for the wrapper * temp cleanup * support for fuse model in predict * example added for predict * added support for forward wrappers in trainer * added test for validation forward wrapper * added option for None as post prediction callback in DetectionMetrics * wip adding set_model before using wrapper * commit changes before removal of validation during training support * refined docs * removed old test for forward wrapper, fixed defaults * fixed test and added clarifications * forward wrapper test removed * updated wrong threshold extraction and test result * fixed docstring format --- .../sliding_sindow_detection_predict.py | 20 + .../training/metrics/detection_metrics.py | 5 +- .../detection_models/customizable_detector.py | 15 + ...liding_window_detection_forward_wrapper.py | 392 ++++++++++++++++++ .../training/pipelines/pipelines.py | 57 ++- .../detection_sliding_window_wrapper_test.py | 30 ++ 6 files changed, 506 insertions(+), 13 deletions(-) create mode 100644 src/super_gradients/examples/predict/sliding_sindow_detection_predict.py create mode 100644 src/super_gradients/training/models/detection_models/sliding_window_detection_forward_wrapper.py create mode 100644 tests/unit_tests/detection_sliding_window_wrapper_test.py diff --git a/src/super_gradients/examples/predict/sliding_sindow_detection_predict.py b/src/super_gradients/examples/predict/sliding_sindow_detection_predict.py new file mode 100644 index 0000000000..12e53a7f48 --- /dev/null +++ 
"""Example: sliding-window detection on a single image with a pretrained YOLO-NAS-S model."""
import torch
from super_gradients.common.object_names import Models
from super_gradients.training import models

# Note that currently only YoloX, PPYoloE and YOLO-NAS are supported.
from super_gradients.training.models.detection_models.sliding_window_detection_forward_wrapper import SlidingWindowInferenceDetectionWrapper

# Load the detector first, then move it to the GPU when one is available to speed up inference.
detector = models.get(Models.YOLO_NAS_S, pretrained_weights="coco")
device = "cuda" if torch.cuda.is_available() else "cpu"
detector = detector.to(device)

# Wrap the detector so that inference is run tile by tile (640px tiles, 160px stride).
model = SlidingWindowInferenceDetectionWrapper(model=detector, tile_size=640, tile_step=160, tile_nms_conf=0.35)

image_url = "https://images.pexels.com/photos/7968254/pexels-photo-7968254.jpeg?auto=compress&cs=tinysrgb&w=1260&h=750&dpr=2"
predictions = model.predict(image_url, skip_image_resizing=True)

predictions.show()
# The relative path means the file is written to the current working directory.
predictions.save(output_path="2.jpg")
def get_dataset_processing_params(self) -> dict:
    """Return the dataset processing parameters currently stored on this model.

    The returned keys are ``class_names``, ``image_processor``, ``iou``, ``conf``, ``nms_top_k``,
    ``max_predictions``, ``multi_label_per_box`` and ``class_agnostic_nms``, so the dict can be
    unpacked as keyword arguments by code that mirrors these settings (the sliding window
    wrapper does exactly that).

    :return: A new dict holding the stored class names, image processor and NMS defaults.
    """
    return dict(
        class_names=self._class_names,
        image_processor=self._image_processor,
        iou=self._default_nms_iou,
        # The confidence threshold has its own attribute; it previously (and wrongly)
        # reported the IoU threshold, so consumers filtered predictions with the wrong value.
        conf=self._default_nms_conf,
        nms_top_k=self._default_nms_top_k,
        max_predictions=self._default_max_predictions,
        multi_label_per_box=self._default_multi_label_per_box,
        class_agnostic_nms=self._default_class_agnostic_nms,
    )
class SlidingWindowInferenceDetectionWrapper(HasPredict, nn.Module):
    """
    Implements a sliding window inference wrapper for a customizable detector.

    Every image is cut into square tiles. The wrapped model and the tile-level post-prediction callback are run on
    each tile, the resulting boxes are shifted back to full-image coordinates, and a final class-wise NMS merges the
    detections of all tiles that belong to the same image.

    :param tile_size: (int) The size of each square tile (in pixels) used in the sliding window.
    :param tile_step: (int) The step size (in pixels) between consecutive tiles in the sliding window.
    :param model: (CustomizableDetector) The detection model to which the sliding window inference is applied.
    :param min_tile_threshold: (int) Minimum dimension size for edge tiles before padding is applied.
     If the remainder of the image (after the full tiles have been applied) is smaller than this threshold,
     it will not be processed.
    :param tile_nms_iou: (Optional[float]) IoU threshold for Non-Maximum Suppression (NMS) of bounding boxes.
     Defaults to the model's internal setting if None.
    :param tile_nms_conf: (Optional[float]) Confidence threshold for predictions to consider in post-processing.
     Defaults to the model's internal setting if None.
    :param tile_nms_top_k: (Optional[int]) Maximum number of top-scoring detections to consider for NMS in each tile.
     Defaults to the model's internal setting if None.
    :param tile_nms_max_predictions: (Optional[int]) Maximum number of detections to return from each tile.
     Defaults to the model's internal setting if None.
    :param tile_nms_multi_label_per_box: (Optional[bool]) Allows multiple labels per box if True. Each anchor can produce
     multiple labels of different classes that pass the confidence threshold. Only the highest-scoring class is considered
     per anchor if False. Defaults to the model's internal setting if None.
    :param tile_nms_class_agnostic_nms: (Optional[bool]) Performs class-agnostic NMS if True, where the IoU of boxes across
     different classes is considered. Performs class-specific NMS if False. Defaults to the model's internal setting if None.
    :raises ValueError: If `model` is None, or if `tile_size` / `tile_step` is not a positive number.
    """

    def __init__(
        self,
        tile_size: int,
        tile_step: int,
        model: Optional[CustomizableDetector],
        min_tile_threshold: int = 30,
        tile_nms_iou: Optional[float] = None,
        tile_nms_conf: Optional[float] = None,
        tile_nms_top_k: Optional[int] = None,
        tile_nms_max_predictions: Optional[int] = None,
        tile_nms_multi_label_per_box: Optional[bool] = None,
        tile_nms_class_agnostic_nms: Optional[bool] = None,
    ):
        # Fail early with a clear message: the processing params are read from the model right below,
        # and a non-positive tile geometry would only fail later inside the tiling arithmetic.
        if model is None:
            raise ValueError("SlidingWindowInferenceDetectionWrapper requires a model to wrap, got None.")
        if tile_size <= 0 or tile_step <= 0:
            raise ValueError(f"tile_size and tile_step must be positive, got tile_size={tile_size}, tile_step={tile_step}.")

        super().__init__()
        self.tile_size = tile_size
        self.tile_step = tile_step
        self.min_tile_threshold = min_tile_threshold

        # GENERAL DEFAULTS
        self._class_names: Optional[List[str]] = None
        self._image_processor: Optional[Processing] = None
        self._default_nms_iou: float = 0.7
        self._default_nms_conf: float = 0.5
        self._default_nms_top_k: int = 1024
        self._default_max_predictions = 300
        self._default_multi_label_per_box = True
        self._default_class_agnostic_nms = False

        # TAKE PROCESSING PARAMS FROM THE WRAPPED MODEL IF THEY ARE AVAILABLE, OTHERWISE USE THE GENERAL DEFAULTS
        self.model = model
        self.set_dataset_processing_params(**self.model.get_dataset_processing_params())

        # OVERRIDE WITH ANY EXPLICITLY PASSED PROCESSING PARAMS.
        # Arguments left as None keep the values taken from the wrapped model above.
        self.set_dataset_processing_params(
            iou=tile_nms_iou,
            conf=tile_nms_conf,
            nms_top_k=tile_nms_top_k,
            max_predictions=tile_nms_max_predictions,
            multi_label_per_box=tile_nms_multi_label_per_box,
            class_agnostic_nms=tile_nms_class_agnostic_nms,
        )

    def forward(self, inputs: torch.Tensor, sliding_window_post_prediction_callback: Optional[DetectionPostPredictionCallback] = None) -> List[torch.Tensor]:
        """Run the wrapped model tile by tile and merge the per-tile detections of every image.

        :param inputs: Batch of images; its shape is unpacked as (batch, channels, height, width).
        :param sliding_window_post_prediction_callback: (Optional) Callback applied to the raw model output of every tile.
         When None, the callback built from this wrapper's processing parameters is used.
        :return: One tensor per image. Columns 0-3 hold the box coordinates in full-image pixels, column 4 the score
         and column 5 the class label. Images without detections get an empty (0, 6) tensor.
        """
        callback = self.sliding_window_post_prediction_callback if sliding_window_post_prediction_callback is None else sliding_window_post_prediction_callback
        # IoU used to merge the tiles. Callbacks that do not expose `nms_threshold` fall back to the wrapper's own IoU.
        merge_iou = getattr(callback, "nms_threshold", self._default_nms_iou)

        final_detections = []
        for img_idx in range(inputs.shape[0]):
            single_image = inputs[img_idx : img_idx + 1]  # Keep the batch dimension: the model expects 4D input
            image_detections = []
            for tile, (start_x, start_y) in self._generate_tiles(single_image, self.tile_size, self.tile_step):
                # Apply the model and the local (per-tile) NMS
                tile_detections = callback(self.model(tile))
                for detections in tile_detections:
                    # A callback may report "nothing found" either as None or as an empty tensor.
                    if detections is None or len(detections) == 0:
                        continue
                    # Shift the boxes from tile coordinates to full-image coordinates.
                    detections[:, :4] += detections.new_tensor([start_x, start_y, start_x, start_y])
                    image_detections.append(detections)

            if not image_detections:
                final_detections.append(torch.empty((0, 6), device=inputs.device))  # Empty tensor for images with no detections
                continue

            # Global NMS: overlapping tiles report the same object more than once.
            merged = torch.cat(image_detections, dim=0)
            idx_to_keep = torchvision.ops.boxes.batched_nms(boxes=merged[:, :4], scores=merged[:, 4], idxs=merged[:, 5], iou_threshold=merge_iou)
            final_detections.append(merged[idx_to_keep])
        return final_detections

    def _tiled_extent(self, size: int, tile_size: int, tile_step: int) -> int:
        """Return the length that has to be covered by tiles along one axis of length `size`."""
        remainder = (size - tile_size) % tile_step
        # A remainder below the threshold is dropped; otherwise the axis is extended so the remainder gets tiles too.
        extent = size if remainder < self.min_tile_threshold else size - remainder + tile_size
        # An axis shorter than one tile must still yield a (padded) tile, otherwise the image would be skipped silently.
        return max(extent, tile_size)

    def _generate_tiles(self, image, tile_size, tile_step):
        """Cut a single image into square tiles.

        :param image:     Tensor whose shape is unpacked as (batch, channels, height, width).
        :param tile_size: Side length of every tile, in pixels.
        :param tile_step: Distance between the origins of two consecutive tiles, in pixels.
        :return:          List of (tile, (x, y)) pairs, where (x, y) is the top-left corner of the tile in the image.
        """
        _, _, h, w = image.shape

        # Calculate the end points for the grid
        max_y = self._tiled_extent(h, tile_size, tile_step)
        max_x = self._tiled_extent(w, tile_size, tile_step)

        # Ensure that the image has enough padding if needed
        if max_y > h or max_x > w:
            # new_zeros keeps the dtype and device of the input, so half-precision inputs stay half-precision.
            padded_image = image.new_zeros((image.shape[0], image.shape[1], max(max_y, h), max(max_x, w)))
            padded_image[:, :, :h, :w] = image  # Place the original image in the padded one
        else:
            padded_image = image

        tiles = []
        for y in range(0, max_y - tile_size + 1, tile_step):
            for x in range(0, max_x - tile_size + 1, tile_step):
                tiles.append((padded_image[:, :, y : y + tile_size, x : x + tile_size], (x, y)))
        return tiles

    def get_post_prediction_callback(
        self, *, conf: float, iou: float, nms_top_k: int, max_predictions: int, multi_label_per_box: bool, class_agnostic_nms: bool
    ) -> DetectionPostPredictionCallback:
        """
        Get a post prediction callback for this model. The call is delegated to the wrapped model.

        :param conf:                A minimum confidence threshold for predictions to be used in post-processing.
        :param iou:                 A IoU threshold for boxes non-maximum suppression.
        :param nms_top_k:           The maximum number of detections to consider for the NMS applied on each tile.
        :param max_predictions:     The maximum number of detections to return in each tile.
        :param multi_label_per_box: If True, each anchor can produce multiple labels of different classes.
                                    If False, each anchor can produce only one label of the class with the highest score.
        :param class_agnostic_nms:  If True, perform class-agnostic NMS (i.e IoU of boxes of different classes is checked).
                                    If False NMS is performed separately for each class.
        :return:                    The callback created by the wrapped model.
        """
        return self.model.get_post_prediction_callback(
            conf=conf,
            iou=iou,
            nms_top_k=nms_top_k,
            max_predictions=max_predictions,
            multi_label_per_box=multi_label_per_box,
            class_agnostic_nms=class_agnostic_nms,
        )

    @resolve_param("image_processor", ProcessingFactory())
    def set_dataset_processing_params(
        self,
        class_names: Optional[List[str]] = None,
        image_processor: Optional[Processing] = None,
        iou: Optional[float] = None,
        conf: Optional[float] = None,
        nms_top_k: Optional[int] = None,
        max_predictions: Optional[int] = None,
        multi_label_per_box: Optional[bool] = None,
        class_agnostic_nms: Optional[bool] = None,
    ) -> None:
        """Set the processing parameters for the dataset.

        Arguments left as None keep their current value. The tile-level callback used by `forward` is rebuilt from
        the resulting values, and the same values become the defaults of `predict` / `predict_webcam`.

        :param class_names:         (Optional) Names of the dataset the model was trained on.
        :param image_processor:     (Optional) Image processing objects to reproduce the dataset preprocessing used for training.
        :param iou:                 (Optional) IoU threshold for the nms algorithm applied.
        :param conf:                (Optional) Below the confidence threshold, prediction are discarded
        :param nms_top_k:           (Optional) The maximum number of detections to consider for NMS in each tile.
        :param max_predictions:     (Optional) The maximum number of detections to return in each tile.
        :param multi_label_per_box: (Optional) If True, each anchor can produce multiple labels of different classes.
                                    If False, each anchor can produce only one label of the class with the highest score.
        :param class_agnostic_nms:  (Optional) If True, perform class-agnostic NMS (i.e IoU of boxes of different classes is checked).
                                    If False NMS is performed separately for each class.
        """
        if class_names is not None:
            self._class_names = tuple(class_names)
        if image_processor is not None:
            self._image_processor = image_processor

        # Persist the values: `_get_pipeline` reads these attributes, so without storing them the
        # prediction pipeline would silently fall back to the hard-coded general defaults.
        if iou is not None:
            self._default_nms_iou = float(iou)
        if conf is not None:
            self._default_nms_conf = float(conf)
        if nms_top_k is not None:
            self._default_nms_top_k = int(nms_top_k)
        if max_predictions is not None:
            self._default_max_predictions = int(max_predictions)
        if multi_label_per_box is not None:
            self._default_multi_label_per_box = bool(multi_label_per_box)
        if class_agnostic_nms is not None:
            self._default_class_agnostic_nms = bool(class_agnostic_nms)

        self.sliding_window_post_prediction_callback = self.get_post_prediction_callback(
            iou=self._default_nms_iou,
            conf=self._default_nms_conf,
            nms_top_k=self._default_nms_top_k,
            max_predictions=self._default_max_predictions,
            multi_label_per_box=self._default_multi_label_per_box,
            class_agnostic_nms=self._default_class_agnostic_nms,
        )

        # A cached pipeline was built from the previous parameters; drop it so the next predict call rebuilds it.
        type(self)._get_pipeline.cache_clear()

    def get_processing_params(self) -> Optional[Processing]:
        """Return the image processor currently set on the wrapper (None when it was never set)."""
        return self._image_processor

    # maxsize=1: at most one pipeline is kept. The cache is cleared by `set_dataset_processing_params`.
    @lru_cache(maxsize=1)
    def _get_pipeline(
        self,
        *,
        iou: Optional[float] = None,
        conf: Optional[float] = None,
        fuse_model: bool = True,
        skip_image_resizing: bool = False,
        nms_top_k: Optional[int] = None,
        max_predictions: Optional[int] = None,
        multi_label_per_box: Optional[bool] = None,
        class_agnostic_nms: Optional[bool] = None,
        fp16: bool = True,
    ) -> SlidingWindowDetectionPipeline:
        """Instantiate the prediction pipeline of this model.

        :param iou:                 (Optional) IoU threshold for the nms algorithm.
                                    If None, the default value associated to the training is used.
        :param conf:                (Optional) Below the confidence threshold, prediction are discarded.
                                    If None, the default value associated to the training is used.
        :param fuse_model:          If True, create a copy of the model, and fuse some of its layers to increase performance. This increases memory usage.
        :param skip_image_resizing: If True, the image processor will not resize the images.
        :param nms_top_k:           (Optional) The maximum number of detections to consider for NMS for each tile.
        :param max_predictions:     (Optional) The maximum number of detections to return for each tile.
        :param multi_label_per_box: (Optional) If True, each anchor can produce multiple labels of different classes.
                                    If False, each anchor can produce only one label of the class with the highest score.
        :param class_agnostic_nms:  (Optional) If True, perform class-agnostic NMS (i.e IoU of boxes of different classes is checked).
                                    If False NMS is performed separately for each class.
        :param fp16:                If True, use mixed precision for inference.
        :raises RuntimeError:       If the class names, image processor or NMS thresholds were never set.
        """
        required_params = (self._class_names, self._image_processor, self._default_nms_iou, self._default_nms_conf)
        if any(param is None for param in required_params):
            raise RuntimeError(
                "You must set the dataset processing parameters before calling predict.\n"
                "Please call "
                "`model.set_dataset_processing_params(...)` first or do so on self.model. "
            )

        iou = self._default_nms_iou if iou is None else iou
        conf = self._default_nms_conf if conf is None else conf
        nms_top_k = self._default_nms_top_k if nms_top_k is None else nms_top_k
        max_predictions = self._default_max_predictions if max_predictions is None else max_predictions
        multi_label_per_box = self._default_multi_label_per_box if multi_label_per_box is None else multi_label_per_box
        class_agnostic_nms = self._default_class_agnostic_nms if class_agnostic_nms is None else class_agnostic_nms

        # Ensure that the image size is divisible by 32.
        if isinstance(self._image_processor, ComposeProcessing) and skip_image_resizing:
            image_processor = self._image_processor.get_equivalent_compose_without_resizing(
                auto_padding=DetectionAutoPadding(shape_multiple=(32, 32), pad_value=0)
            )
        else:
            image_processor = self._image_processor

        pipeline = SlidingWindowDetectionPipeline(
            model=self,
            image_processor=image_processor,
            post_prediction_callback=self.get_post_prediction_callback(
                iou=iou,
                conf=conf,
                nms_top_k=nms_top_k,
                max_predictions=max_predictions,
                multi_label_per_box=multi_label_per_box,
                class_agnostic_nms=class_agnostic_nms,
            ),
            class_names=self._class_names,
            fuse_model=fuse_model,
            fp16=fp16,
        )
        return pipeline

    def predict(
        self,
        images: ImageSource,
        iou: Optional[float] = None,
        conf: Optional[float] = None,
        batch_size: int = 32,
        fuse_model: bool = True,
        skip_image_resizing: bool = False,
        nms_top_k: Optional[int] = None,
        max_predictions: Optional[int] = None,
        multi_label_per_box: Optional[bool] = None,
        class_agnostic_nms: Optional[bool] = None,
        fp16: bool = True,
    ) -> ImagesDetectionPrediction:
        """Predict an image or a list of images.

        :param images:              Images to predict.
        :param iou:                 (Optional) IoU threshold for the nms algorithm. If None, the default value associated to the training is used.
        :param conf:                (Optional) Below the confidence threshold, prediction are discarded.
                                    If None, the default value associated to the training is used.
        :param batch_size:          Maximum number of images to process at the same time.
        :param fuse_model:          If True, create a copy of the model, and fuse some of its layers to increase performance. This increases memory usage.
        :param skip_image_resizing: If True, the image processor will not resize the images.
        :param nms_top_k:           (Optional) The maximum number of detections to consider for NMS.
        :param max_predictions:     (Optional) The maximum number of detections to return.
        :param multi_label_per_box: (Optional) If True, each anchor can produce multiple labels of different classes.
                                    If False, each anchor can produce only one label of the class with the highest score.
        :param class_agnostic_nms:  (Optional) If True, perform class-agnostic NMS (i.e IoU of boxes of different classes is checked).
                                    If False NMS is performed separately for each class.
        :param fp16:                If True, use mixed precision for inference.
        """
        pipeline = self._get_pipeline(
            iou=iou,
            conf=conf,
            fuse_model=fuse_model,
            skip_image_resizing=skip_image_resizing,
            nms_top_k=nms_top_k,
            max_predictions=max_predictions,
            multi_label_per_box=multi_label_per_box,
            class_agnostic_nms=class_agnostic_nms,
            fp16=fp16,
        )
        return pipeline(images, batch_size=batch_size)  # type: ignore

    def predict_webcam(
        self,
        iou: Optional[float] = None,
        conf: Optional[float] = None,
        fuse_model: bool = True,
        skip_image_resizing: bool = False,
        nms_top_k: Optional[int] = None,
        max_predictions: Optional[int] = None,
        multi_label_per_box: Optional[bool] = None,
        class_agnostic_nms: Optional[bool] = None,
        fp16: bool = True,
    ):
        """Predict using webcam.

        :param iou:                 (Optional) IoU threshold for the nms algorithm. If None, the default value associated to the training is used.
        :param conf:                (Optional) Below the confidence threshold, prediction are discarded.
                                    If None, the default value associated to the training is used.
        :param fuse_model:          If True, create a copy of the model, and fuse some of its layers to increase performance. This increases memory usage.
        :param skip_image_resizing: If True, the image processor will not resize the images.
        :param nms_top_k:           (Optional) The maximum number of detections to consider for NMS.
        :param max_predictions:     (Optional) The maximum number of detections to return.
        :param multi_label_per_box: (Optional) If True, each anchor can produce multiple labels of different classes.
                                    If False, each anchor can produce only one label of the class with the highest score.
        :param class_agnostic_nms:  (Optional) If True, perform class-agnostic NMS (i.e IoU of boxes of different classes is checked).
                                    If False NMS is performed separately for each class.
        :param fp16:                If True, use mixed precision for inference.
        """
        pipeline = self._get_pipeline(
            iou=iou,
            conf=conf,
            fuse_model=fuse_model,
            skip_image_resizing=skip_image_resizing,
            nms_top_k=nms_top_k,
            max_predictions=max_predictions,
            multi_label_per_box=multi_label_per_box,
            class_agnostic_nms=class_agnostic_nms,
            fp16=fp16,
        )
        pipeline.predict_webcam()

    def get_input_channels(self) -> int:
        """Return the number of input channels reported by the wrapped model."""
        return self.model.get_input_channels()
def _prep_inputs_for_model(self, preprocessed_images: List[np.ndarray]) -> torch.Tensor:
    """Convert a batch of preprocessed images into the tensor that is passed to the model.

    The images are stacked into a single array, moved to the pipeline's device and cast to the pipeline's dtype.
    Models implementing `SupportsInputShapeCheck` get the chance to reject the shape, and the model is fused
    beforehand when `self.fuse_model` is set.

    :param preprocessed_images: Images after preprocessing, one array per image.
    :return:                    The batch tensor, on `self.device` and with dtype `self.dtype`.
    """
    stacked_images = np.array(preprocessed_images)
    model_inputs = torch.from_numpy(stacked_images).to(self.device).to(self.dtype)

    # Let the model validate the shape before any heavier work (fusing, forward pass) is done.
    if isinstance(self.model, SupportsInputShapeCheck):
        self.model.validate_input_shape(model_inputs.size())

    if self.fuse_model:
        self._fuse_model(model_inputs)

    return model_inputs
class SlidingWindowDetectionPipeline(DetectionPipeline):
    """Detection pipeline for a model that performs sliding window inference in its forward pass.

    The pipeline's post prediction callback is handed to the model through the
    `sliding_window_post_prediction_callback` keyword of the forward call, and the model output is converted to
    predictions without calling the callback again.
    """

    def pass_images_through_model(self, preprocessed_images: List[np.ndarray]) -> List[Prediction]:
        """Run the model on a batch of preprocessed images and decode what it returns.

        :param preprocessed_images: Images after preprocessing, one array per image.
        :return:                    The decoded predictions of the batch.
        """
        with eval_mode(self.model):
            with torch.no_grad():
                with torch.cuda.amp.autocast(enabled=self.fp16):
                    batch = self._prep_inputs_for_model(preprocessed_images)
                    # The callback travels with the forward call instead of being applied on the output afterwards.
                    detections = self.model(batch, sliding_window_post_prediction_callback=self.post_prediction_callback)
                    return self._decode_model_output(detections, model_input=batch)

    def _decode_model_output(self, model_output: Union[List, Tuple, torch.Tensor], model_input: np.ndarray) -> List[DetectionPrediction]:
        """Convert the model output into detection predictions.

        Unlike the parent pipeline, the post prediction callback is NOT applied here: the output is forwarded as-is
        to `_decode_detection_model_output`.

        :param model_output:    Output of the model's forward call.
        :param model_input:     Model input (i.e. images after preprocessing).
        :return:                Predicted Bboxes.
        """
        return self._decode_detection_model_output(model_input=model_input, post_nms_predictions=model_output)

    def _fuse_model(self, input_example: torch.Tensor):
        """Replace `self.model` by a deep copy that is prepared for conversion, then disable further fusing.

        :param input_example: Example input; only its last two dimensions are used, as the input size.
        """
        logger.info("Fusing some of the model's layers. If this takes too much memory, you can deactivate it by setting `fuse_model=False`")
        fused_copy = copy.deepcopy(self.model)
        self.model = fused_copy
        fused_copy.eval()
        # The conversion is prepared on the `.model` attribute of the copy, not on the copy itself.
        fused_copy.model.prep_model_for_conversion(input_size=input_example.shape[-2:])
        self.fuse_model = False
metric_values = trainer.test(model=model, test_loader=dl, test_metrics_list=[metric]) + self.assertAlmostEqual(metric_values[metric.map_str], 0.342, delta=0.001) + + +if __name__ == "__main__": + unittest.main()