diff --git a/hawk_eye/inference/find_targets.py b/hawk_eye/inference/find_targets.py index 0fc1283..93dbb8a 100755 --- a/hawk_eye/inference/find_targets.py +++ b/hawk_eye/inference/find_targets.py @@ -259,41 +259,67 @@ def find_targets( def globalize_boxes( - results: List[postprocess.BoundingBox], img_size: int + results: List[postprocess.BoundingBox], tile_size: int, exclusion_region: int = 5 ) -> List[inference_types.Target]: """Take the normalized detections on a _tile_ and gloabalize them to pixel space of the original large image. Args: results: A list of the detections for the tiles. - img_size: The size of the tile whihc is needed to unnormalize the detections. + tile_size: The size of the tile which is needed to unnormalize the detections. + exclusion_region: The number of pixels from the edge of the tile in which + a target will be thrown out. Returns: A list of the globalized boxes """ final_targets = [] - img_size = torch.Tensor([img_size] * 4) + img_size = torch.Tensor([tile_size] * 4) for coords, bboxes in results: for box in bboxes: relative_coords = box.box * img_size - relative_coords += torch.Tensor(2 * list(coords)).int() - final_targets.append( - inference_types.Target( - x=int(relative_coords[0]), - y=int(relative_coords[1]), - width=int(relative_coords[2] - relative_coords[0]), - height=int(relative_coords[3] - relative_coords[1]), - shape=inference_types.Shape[ - config.OD_CLASSES[box.class_id].upper().replace("-", "_") - ], + if resolve_boxes(relative_coords, tile_size, exclusion_region): + relative_coords += torch.Tensor(2 * list(coords)).int() + final_targets.append( + inference_types.Target( + x=int(relative_coords[0]), + y=int(relative_coords[1]), + width=int(relative_coords[2] - relative_coords[0]), + height=int(relative_coords[3] - relative_coords[1]), + shape=inference_types.Shape[ + config.OD_CLASSES[box.class_id].upper().replace("-", "_") + ], + ) ) - ) return final_targets +def resolve_boxes(relative_coords, tile_size: int, exclusion_region: int) -> bool: + """Finds targets that are within a boundary at the edge of the tile. + If true, the target is within the boundary. + + Examples:: + + >>> resolve_boxes(torch.Tensor([0, 10, 60, 70]), 512, 5) + True + >>> resolve_boxes(torch.Tensor([10, 10, 60, 70]), 512, 5) + False + """ + + if ( + int(relative_coords[1]) - exclusion_region <= 0 + or int(relative_coords[3]) + exclusion_region >= tile_size + or int(relative_coords[0]) - exclusion_region <= 0 + or int(relative_coords[2]) + exclusion_region >= tile_size + ): + return True + + return False + + def visualize_image( image_name: str, image: np.ndarray,