diff --git a/README.md b/README.md
index 042dd94..1b6c44f 100644
--- a/README.md
+++ b/README.md
@@ -22,10 +22,14 @@ |`pp_layout_table`| Table | `layout_table.onnx` |`table` |
 | `pp_layout_publaynet`| English | `layout_publaynet.onnx` | `text title list table figure` |
 | `pp_layout_cdla`| Chinese | `layout_cdla.onnx` | `text title figure figure_caption table table_caption`<br>`header footer reference equation` |
+| `yolov8n_layout_paper`| Paper | `yolov8n_layout_paper.onnx` | `text title figure figure_caption table table_caption`<br>`header footer reference equation` |
+| `yolov8n_layout_report`| Research report | `yolov8n_layout_report.onnx` | `text title header footer figure figure_caption table table_caption`<br>`toc` |
 
-Model source: [PaddleOCR layout analysis](https://github.com/PaddlePaddle/PaddleOCR/blob/133d67f27dc8a241d6b2e30a9f047a0fb75bebbe/ppstructure/layout/README_ch.md)
+PP model source: [PaddleOCR layout analysis](https://github.com/PaddlePaddle/PaddleOCR/blob/133d67f27dc8a241d6b2e30a9f047a0fb75bebbe/ppstructure/layout/README_ch.md)
 
-Model download: [Baidu Pan](https://pan.baidu.com/s/1PI9fksW6F6kQfJhwUkewWg?pwd=p29g) | [Google Drive](https://drive.google.com/drive/folders/1DAPWSN2zGQ-ED_Pz7RaJGTjfkN2-Mvsf?usp=sharing)
+yolov8n series source: [360LayoutAnalysis](https://github.com/360AILAB-NLP/360LayoutAnalysis)
+
+Model download: [link](https://github.com/RapidAI/RapidLayout/releases/tag/v0.0.0)
 
 ### Installation
 
 Since the model is small, the Chinese layout analysis model (`layout_cdla.onnx`) is pre-bundled in the whl package, so for Chinese layout analysis you can install and use it directly.
@@ -41,7 +45,7 @@ import cv2
 from rapid_layout import RapidLayout, VisLayout
 
 # See the table above for supported model_type values. Specifying a different model_type automatically downloads the matching model into the install directory.
-layout_engine = RapidLayout(box_threshold=0.5, model_type="pp_layout_cdla")
+layout_engine = RapidLayout(conf_thres=0.5, model_type="pp_layout_cdla")
 
 img = cv2.imread('test_images/layout.png')
@@ -55,18 +59,23 @@ if ploted_img is not None:
 - Usage:
   ```bash
   $ rapid_layout -h
-  usage: rapid_layout [-h] -img IMG_PATH [-m {pp_layout_cdla,pp_layout_publaynet,pp_layout_table}]
-                      [--box_threshold {pp_layout_cdla,pp_layout_publaynet,pp_layout_table}] [-v]
+  usage: rapid_layout [-h] -img IMG_PATH
+                      [-m {pp_layout_cdla,pp_layout_publaynet,pp_layout_table,yolov8n_layout_paper,yolov8n_layout_report}]
+                      [--conf_thres CONF_THRES] [--iou_thres IOU_THRES]
+                      [-v]
 
   options:
-    -h, --help            show this help message and exit
-    -img IMG_PATH, --img_path IMG_PATH
+    -h, --help            show this help message and exit
+    -img IMG_PATH, --img_path IMG_PATH
                           Path to image for layout.
-    -m {pp_layout_cdla,pp_layout_publaynet,pp_layout_table}, --model_type {pp_layout_cdla,pp_layout_publaynet,pp_layout_table}
+    -m {pp_layout_cdla,pp_layout_publaynet,pp_layout_table,yolov8n_layout_paper,yolov8n_layout_report}, --model_type {pp_layout_cdla,pp_layout_publaynet,pp_layout_table,yolov8n_layout_paper,yolov8n_layout_report}
                           Support model type
-    --box_threshold {pp_layout_cdla,pp_layout_publaynet,pp_layout_table}
-                          Box threshold, the range is [0, 1]
-    -v, --vis             Wheter to visualize the layout results.
+    --conf_thres CONF_THRES
+                          Confidence threshold, the range is [0, 1]
+    --iou_thres IOU_THRES
+                          IoU threshold, the range is [0, 1]
+    -v, --vis             Whether to visualize the layout results.
   ```
 - Example:
   ```bash
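To make the new Python API surface above concrete, here is a minimal usage sketch (commentary, not part of the patch). It assumes the updated package is installed; `conf_thres` and `iou_thres` replace the old `box_threshold`, and the image path is a placeholder:

```python
import cv2

from rapid_layout import RapidLayout

engine = RapidLayout(
    model_type="yolov8n_layout_paper",  # any key from the table above
    conf_thres=0.5,  # score filter, range [0, 1]
    iou_thres=0.5,   # NMS overlap filter, range [0, 1]
)

img = cv2.imread("test_images/layout.png")  # placeholder path
boxes, scores, class_names, elapse = engine(img)
print(f"{len(boxes)} regions detected in {elapse:.3f}s")
```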
diff --git a/demo.py b/demo.py
index b73f279..d5e07f3 100644
--- a/demo.py
+++ b/demo.py
@@ -5,7 +5,7 @@ from rapid_layout import RapidLayout, VisLayout
 
-layout_engine = RapidLayout(box_threshold=0.5, model_type="pp_layout_cdla")
+layout_engine = RapidLayout(model_type="yolov8n_layout_paper")
 
 img_path = "tests/test_files/layout.png"
 img = cv2.imread(img_path)
diff --git a/rapid_layout/config.yaml b/rapid_layout/config.yaml
deleted file mode 100644
index 33a85d2..0000000
--- a/rapid_layout/config.yaml
+++ /dev/null
@@ -1,24 +0,0 @@
-model_path: models/layout_cdla.onnx
-
-use_cuda: false
-CUDAExecutionProvider:
-  device_id: 0
-  arena_extend_strategy: kNextPowerOfTwo
-  cudnn_conv_algo_search: EXHAUSTIVE
-  do_copy_in_default_stream: true
-
-pre_process:
-  Resize:
-    size: [800, 608]
-  NormalizeImage:
-    std: [0.229, 0.224, 0.225]
-    mean: [0.485, 0.456, 0.406]
-    scale: 1./255.
-    order: hwc
-  ToCHWImage:
-  KeepKeys:
-    keep_keys: ['image']
-
-post_process:
-  score_threshold: 0.5
-  nms_threshold: 0.5
\ No newline at end of file
diff --git a/rapid_layout/main.py b/rapid_layout/main.py
index b392c92..14cb89f 100644
--- a/rapid_layout/main.py
+++ b/rapid_layout/main.py
@@ -14,11 +14,11 @@
     LoadImage,
     OrtInferSession,
     PicoDetPostProcess,
+    PPPreProcess,
     VisLayout,
-    create_operators,
+    YOLOv8PostProcess,
+    YOLOv8PreProcess,
     get_logger,
-    read_yaml,
-    transform,
 )
 
 ROOT_DIR = Path(__file__).resolve().parent
@@ -29,64 +29,86 @@
     "pp_layout_cdla": f"{ROOT_URL}/layout_cdla.onnx",
     "pp_layout_publaynet": f"{ROOT_URL}/layout_publaynet.onnx",
     "pp_layout_table": f"{ROOT_URL}/layout_table.onnx",
+    "yolov8n_layout_paper": f"{ROOT_URL}/yolov8n_layout_paper.onnx",
+    "yolov8n_layout_report": f"{ROOT_URL}/yolov8n_layout_report.onnx",
 }
 DEFAULT_MODEL_PATH = str(ROOT_DIR / "models" / "layout_cdla.onnx")
 
 
 class RapidLayout:
     def __init__(
         self,
         model_type: str = "pp_layout_cdla",
-        box_threshold: float = 0.5,
+        model_path: Union[str, Path, None] = None,
+        conf_thres: float = 0.5,
+        iou_thres: float = 0.5,
         use_cuda: bool = False,
     ):
-        config_path = str(ROOT_DIR / "config.yaml")
-        config = read_yaml(config_path)
-        config["model_path"] = self.get_model_path(model_type)
-        config["use_cuda"] = use_cuda
-
+        self.model_type = model_type
+        config = {
+            "model_path": self.get_model_path(model_type, model_path),
+            "use_cuda": use_cuda,
+        }
         self.session = OrtInferSession(config)
         labels = self.session.get_character_list()
         logger.info("%s contains %s", model_type, labels)
 
-        self.preprocess_op = create_operators(config["pre_process"])
+        # pp
+        self.pp_preprocess = PPPreProcess(img_size=(800, 608))
+        self.pp_postprocess = PicoDetPostProcess(labels, conf_thres, iou_thres)
+
+        # yolov8
+        self.yolov8_input_shape = (640, 640)
+        self.yolo_preprocess = YOLOv8PreProcess(img_size=self.yolov8_input_shape)
+        self.yolo_postprocess = YOLOv8PostProcess(labels, conf_thres, iou_thres)
 
-        config["post_process"]["score_threshold"] = box_threshold
-        self.postprocess_op = PicoDetPostProcess(labels, **config["post_process"])
         self.load_img = LoadImage()
 
+        self.pp_layout_type = [
+            "pp_layout_cdla",
+            "pp_layout_publaynet",
+            "pp_layout_table",
+        ]
+        self.yolov8_layout_type = ["yolov8n_layout_paper", "yolov8n_layout_report"]
+
     def __call__(
         self, img_content: Union[str, np.ndarray, bytes, Path]
     ) -> Tuple[Optional[np.ndarray], Optional[np.ndarray], Optional[np.ndarray], float]:
         img = self.load_img(img_content)
+        ori_img_shape = img.shape[:2]
 
-        ori_im = img.copy()
-        data = transform({"image": img}, self.preprocess_op)
-        img = data[0]
-        if img is None:
-            return None, None, None, 0.0
-
-        img = np.expand_dims(img, axis=0)
-        img = img.copy()
-
-        preds, elapse = 0, 1
-        starttime = time.time()
-        preds = self.session(img)
-
-        score_list, boxes_list = [], []
-        num_outs = int(len(preds) / 2)
-        for out_idx in range(num_outs):
-            score_list.append(preds[out_idx])
-            boxes_list.append(preds[out_idx + num_outs])
-
-        boxes, scores, class_names = self.postprocess_op(
-            ori_im, img, {"boxes": score_list, "boxes_num": boxes_list}
-        )
-        elapse = time.time() - starttime
-        return boxes, scores, class_names, elapse
+        if self.model_type in self.pp_layout_type:
+            return self.pp_layout(img, ori_img_shape)
+
+        if self.model_type in self.yolov8_layout_type:
+            return self.yolov8_layout(img, ori_img_shape)
+
+        raise ValueError(f"{self.model_type} is not supported.")
+
+    def pp_layout(self, img: np.ndarray, ori_img_shape: Tuple[int, int]):
+        s_time = time.time()
+
+        img = self.pp_preprocess(img)
+        preds = self.session(img)
+        boxes, scores, class_names = self.pp_postprocess(ori_img_shape, img, preds)
+
+        elapse = time.time() - s_time
+        return boxes, scores, class_names, elapse
+
+    def yolov8_layout(self, img: np.ndarray, ori_img_shape: Tuple[int, int]):
+        s_time = time.time()
+
+        input_tensor = self.yolo_preprocess(img)
+        outputs = self.session(input_tensor)
+        boxes, scores, class_names = self.yolo_postprocess(
+            outputs, ori_img_shape, self.yolov8_input_shape
+        )
+
+        elapse = time.time() - s_time
+        return boxes, scores, class_names, elapse
 
     @staticmethod
-    def get_model_path(model_type: str) -> str:
+    def get_model_path(model_type: str, model_path: Union[str, Path, None]) -> str:
+        if model_path is not None:
+            return str(model_path)
+
         model_url = KEY_TO_MODEL_URL.get(model_type, None)
         if model_url:
             model_path = DownloadModel.download(model_url)
@@ -110,12 +132,19 @@ def main():
         help="Support model type",
     )
     parser.add_argument(
-        "--box_threshold",
+        "--conf_thres",
         type=float,
         default=0.5,
-        choices=list(KEY_TO_MODEL_URL.keys()),
-        help="Box threshold, the range is [0, 1]",
+        help="Confidence threshold, the range is [0, 1]",
     )
+    parser.add_argument(
+        "--iou_thres",
+        type=float,
+        default=0.5,
+        help="IoU threshold, the range is [0, 1]",
+    )
     parser.add_argument(
         "-v",
         "--vis",
@@ -125,7 +154,7 @@ def main():
     args = parser.parse_args()
 
     layout_engine = RapidLayout(
-        model_type=args.model_type, box_threshold=args.box_threshold
+        model_type=args.model_type, conf_thres=args.conf_thres, iou_thres=args.iou_thres
     )
 
     img = cv2.imread(args.img_path)
diff --git a/rapid_layout/utils/__init__.py b/rapid_layout/utils/__init__.py
index beb7dac..0cadd2a 100644
--- a/rapid_layout/utils/__init__.py
+++ b/rapid_layout/utils/__init__.py
@@ -7,8 +7,8 @@
 from .infer_engine import OrtInferSession
 from .load_image import LoadImage
 from .logger import get_logger
-from .post_prepross import PicoDetPostProcess
-from .pre_procss import create_operators, transform
+from .post_prepross import PicoDetPostProcess, YOLOv8PostProcess
+from .pre_procss import PPPreProcess, YOLOv8PreProcess
 from .vis_res import VisLayout
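The new `model_path` argument short-circuits the `KEY_TO_MODEL_URL` download: anything other than `None` is used as the ONNX path directly. A brief sketch (commentary, not part of the patch; the local file path below is a hypothetical example):

```python
from rapid_layout import RapidLayout

# Reuse a locally cached model instead of downloading it;
# the path is hypothetical.
engine = RapidLayout(
    model_type="yolov8n_layout_report",
    model_path="models/yolov8n_layout_report.onnx",
)
```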
diff --git a/rapid_layout/utils/post_prepross.py b/rapid_layout/utils/post_prepross.py
index f134cf8..eacd006 100644
--- a/rapid_layout/utils/post_prepross.py
+++ b/rapid_layout/utils/post_prepross.py
@@ -1,39 +1,36 @@
 # -*- encoding: utf-8 -*-
 # @Author: SWHL
 # @Contact: liekkaskono@163.com
+from typing import List, Tuple
+
 import numpy as np
 
 
 class PicoDetPostProcess:
-    def __init__(
-        self,
-        labels,
-        strides=[8, 16, 32, 64],
-        score_threshold=0.4,
-        nms_threshold=0.5,
-        nms_top_k=1000,
-        keep_top_k=100,
-    ):
+    def __init__(self, labels, conf_thres=0.4, iou_thres=0.5):
         self.labels = labels
-        self.strides = strides
-        self.score_threshold = score_threshold
-        self.nms_threshold = nms_threshold
-        self.nms_top_k = nms_top_k
-        self.keep_top_k = keep_top_k
-
-    def __call__(self, ori_img, img, preds):
-        scores, raw_boxes = preds["boxes"], preds["boxes_num"]
+        self.strides = [8, 16, 32, 64]
+        self.conf_thres = conf_thres
+        self.iou_thres = iou_thres
+        self.nms_top_k = 1000
+        self.keep_top_k = 100
+
+    def __call__(self, ori_shape, img, preds):
+        scores, raw_boxes = [], []
+        num_outs = int(len(preds) / 2)
+        for out_idx in range(num_outs):
+            scores.append(preds[out_idx])
+            raw_boxes.append(preds[out_idx + num_outs])
+
         batch_size = raw_boxes[0].shape[0]
         reg_max = int(raw_boxes[0].shape[-1] / 4 - 1)
-        out_boxes_num = []
-        out_boxes_list = []
-        ori_shape, input_shape, scale_factor = self.img_info(ori_img, img)
+        out_boxes_num, out_boxes_list = [], []
+        ori_shape, input_shape, scale_factor = self.img_info(ori_shape, img)
 
         for batch_id in range(batch_size):
             # generate centers
-            decode_boxes = []
-            select_scores = []
+            decode_boxes, select_scores = [], []
             for stride, box_distribute, score in zip(self.strides, raw_boxes, scores):
                 box_distribute = box_distribute[batch_id]
                 score = score[batch_id]
@@ -71,19 +68,19 @@ def __call__(self, ori_img, img, preds):
             # nms
             bboxes = np.concatenate(decode_boxes, axis=0)
             confidences = np.concatenate(select_scores, axis=0)
-            picked_box_probs = []
-            picked_labels = []
+            picked_box_probs, picked_labels = [], []
             for class_index in range(0, confidences.shape[1]):
                 probs = confidences[:, class_index]
-                mask = probs > self.score_threshold
+                mask = probs > self.conf_thres
                 probs = probs[mask]
                 if probs.shape[0] == 0:
                     continue
+
                 subset_boxes = bboxes[mask, :]
                 box_probs = np.concatenate([subset_boxes, probs.reshape(-1, 1)], axis=1)
                 box_probs = self.hard_nms(
                     box_probs,
-                    iou_threshold=self.nms_threshold,
+                    iou_thres=self.iou_thres,
                     top_k=self.keep_top_k,
                 )
                 picked_box_probs.append(box_probs)
@@ -92,7 +89,6 @@ def __call__(self, ori_img, img, preds):
             if len(picked_box_probs) == 0:
                 out_boxes_list.append(np.empty((0, 4)))
                 out_boxes_num.append(0)
-
             else:
                 picked_box_probs = np.concatenate(picked_box_probs)
@@ -129,11 +125,6 @@ def __call__(self, ori_img, img, preds):
                 class_names.append(label)
         return np.array(boxes), np.array(scores), np.array(class_names)
 
-    def load_layout_dict(self, layout_dict_path):
-        with open(layout_dict_path, "r", encoding="utf-8") as fp:
-            labels = fp.readlines()
-        return [label.strip("\n") for label in labels]
-
     def warp_boxes(self, boxes, ori_shape):
         """Apply transform to boxes"""
         width, height = ori_shape[1], ori_shape[0]
@@ -158,8 +149,7 @@ def warp_boxes(self, boxes, ori_shape):
             return xy.astype(np.float32)
         return boxes
 
-    def img_info(self, ori_img, img):
-        origin_shape = ori_img.shape
+    def img_info(self, origin_shape, img):
         resize_shape = img.shape
         im_scale_y = resize_shape[2] / float(origin_shape[0])
         im_scale_x = resize_shape[3] / float(origin_shape[1])
@@ -195,11 +185,11 @@ def logsumexp(a, axis=None, b=None, keepdims=False):
         return np.exp(x - logsumexp(x, axis=axis, keepdims=True))
 
-    def hard_nms(self, box_scores, iou_threshold, top_k=-1, candidate_size=200):
+    def hard_nms(self, box_scores, iou_thres, top_k=-1, candidate_size=200):
         """
         Args:
             box_scores (N, 5): boxes in corner-form and probabilities.
-            iou_threshold: intersection over union threshold.
+            iou_thres: intersection over union threshold.
             top_k: keep top_k results. If k <= 0, keep all the results.
             candidate_size: only consider the candidates with the highest scores.
         Returns:
@@ -222,7 +212,7 @@ def hard_nms(self, box_scores, iou_threshold, top_k=-1, candidate_size=200):
                 rest_boxes,
                 np.expand_dims(current_box, axis=0),
             )
-            indexes = indexes[iou <= iou_threshold]
+            indexes = indexes[iou <= iou_thres]
 
         return box_scores[picked, :]
@@ -254,3 +244,135 @@ def area_of(left_top, right_bottom):
     """
     hw = np.clip(right_bottom - left_top, 0.0, None)
     return hw[..., 0] * hw[..., 1]
+
+
+class YOLOv8PostProcess:
+    def __init__(self, labels: List[str], conf_thres=0.7, iou_thres=0.5):
+        self.labels = labels
+        self.conf_threshold = conf_thres
+        self.iou_threshold = iou_thres
+        self.input_width, self.input_height = None, None
+        self.img_width, self.img_height = None, None
+
+    def __call__(
+        self, output, ori_img_shape: Tuple[int, int], img_shape: Tuple[int, int]
+    ):
+        self.img_height, self.img_width = ori_img_shape
+        self.input_height, self.input_width = img_shape
+
+        predictions = np.squeeze(output[0]).T
+
+        # Filter out object confidence scores below threshold
+        scores = np.max(predictions[:, 4:], axis=1)
+        predictions = predictions[scores > self.conf_threshold, :]
+        scores = scores[scores > self.conf_threshold]
+
+        if len(scores) == 0:
+            return [], [], []
+
+        # Get the class with the highest confidence
+        class_ids = np.argmax(predictions[:, 4:], axis=1)
+
+        # Get bounding boxes for each object
+        boxes = self.extract_boxes(predictions)
+
+        # Apply non-maxima suppression to suppress weak, overlapping bounding boxes
+        # indices = nms(boxes, scores, self.iou_threshold)
+        indices = multiclass_nms(boxes, scores, class_ids, self.iou_threshold)
+
+        labels = [self.labels[i] for i in class_ids[indices]]
+        return boxes[indices], scores[indices], labels
+
+    def extract_boxes(self, predictions):
+        # Extract boxes from predictions
+        boxes = predictions[:, :4]
+
+        # Scale boxes to original image dimensions
+        boxes = self.rescale_boxes(boxes)
+
+        # Convert boxes to xyxy format
+        boxes = xywh2xyxy(boxes)
+
+        return boxes
+
+    def rescale_boxes(self, boxes):
+        # Rescale boxes to original image dimensions
+        input_shape = np.array(
+            [self.input_width, self.input_height, self.input_width, self.input_height]
+        )
+        boxes = np.divide(boxes, input_shape, dtype=np.float32)
+        boxes *= np.array(
+            [self.img_width, self.img_height, self.img_width, self.img_height]
+        )
+        return boxes
+
+
+def nms(boxes, scores, iou_threshold):
+    # Sort by score, descending
+    sorted_indices = np.argsort(scores)[::-1]
+
+    keep_boxes = []
+    while sorted_indices.size > 0:
+        # Pick the box with the highest remaining score
+        box_id = sorted_indices[0]
+        keep_boxes.append(box_id)
+
+        # Compute IoU of the picked box with the rest
+        ious = compute_iou(boxes[box_id, :], boxes[sorted_indices[1:], :])
+
+        # Remove boxes with IoU over the threshold
+        keep_indices = np.where(ious < iou_threshold)[0]
+        sorted_indices = sorted_indices[keep_indices + 1]
+
+    return keep_boxes
+
+
+def multiclass_nms(boxes, scores, class_ids, iou_threshold):
+    unique_class_ids = np.unique(class_ids)
+
+    keep_boxes = []
+    for class_id in unique_class_ids:
+        class_indices = np.where(class_ids == class_id)[0]
+        class_boxes = boxes[class_indices, :]
+        class_scores = scores[class_indices]
+
+        class_keep_boxes = nms(class_boxes, class_scores, iou_threshold)
+        keep_boxes.extend(class_indices[class_keep_boxes])
+
+    return keep_boxes
+
+
+def compute_iou(box, boxes):
+    # Compute xmin, ymin, xmax, ymax of the intersection
+    xmin = np.maximum(box[0], boxes[:, 0])
+    ymin = np.maximum(box[1], boxes[:, 1])
+    xmax = np.minimum(box[2], boxes[:, 2])
+    ymax = np.minimum(box[3], boxes[:, 3])
+
+    # Compute intersection area
+    intersection_area = np.maximum(0, xmax - xmin) * np.maximum(0, ymax - ymin)
+
+    # Compute union area
+    box_area = (box[2] - box[0]) * (box[3] - box[1])
+    boxes_area = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
+    union_area = box_area + boxes_area - intersection_area
+
+    # Compute IoU
+    iou = intersection_area / union_area
+
+    return iou
+
+
+def xywh2xyxy(x):
+    # Convert bounding box (cx, cy, w, h) to corner form (x1, y1, x2, y2)
+    y = np.copy(x)
+    y[..., 0] = x[..., 0] - x[..., 2] / 2
+    y[..., 1] = x[..., 1] - x[..., 3] / 2
+    y[..., 2] = x[..., 0] + x[..., 2] / 2
+    y[..., 3] = x[..., 1] + x[..., 3] / 2
+    return y
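To sanity-check the box arithmetic these helpers implement, here is a toy example (commentary, not part of the patch; it assumes the module import path `rapid_layout/utils/post_prepross.py` shown above). Two 10x10 boxes whose centers are 5px apart in both axes overlap on a 5x5 patch, so IoU = 25 / (100 + 100 - 25) = 1/7:

```python
import numpy as np

from rapid_layout.utils.post_prepross import compute_iou, xywh2xyxy

box = xywh2xyxy(np.array([5.0, 5.0, 10.0, 10.0]))         # -> [0, 0, 10, 10]
others = xywh2xyxy(np.array([[10.0, 10.0, 10.0, 10.0]]))  # -> [[5, 5, 15, 15]]

iou = compute_iou(box, others)
assert np.isclose(iou[0], 25 / 175)  # 1/7 ≈ 0.143
```

With the default `iou_thres=0.5`, `multiclass_nms` would keep both of these boxes, since their overlap is well under the threshold and suppression only happens within the same class.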
diff --git a/rapid_layout/utils/pre_procss.py b/rapid_layout/utils/pre_procss.py
index f5e3e21..78ce748 100644
--- a/rapid_layout/utils/pre_procss.py
+++ b/rapid_layout/utils/pre_procss.py
@@ -2,7 +2,7 @@
 # @Author: SWHL
 # @Contact: liekkaskono@163.com
 from pathlib import Path
-from typing import Union
+from typing import Optional, Tuple, Union
 
 import cv2
 import numpy as np
@@ -10,94 +10,44 @@
 InputType = Union[str, np.ndarray, bytes, Path]
 
 
-def transform(data, ops=None):
-    """transform"""
-    if ops is None:
-        ops = []
-
-    for op in ops:
-        data = op(data)
-        if data is None:
-            return None
-    return data
-
-
-def create_operators(op_param_dict):
-    ops = []
-    for op_name, param in op_param_dict.items():
-        if param is None:
-            param = {}
-        op = eval(op_name)(**param)
-        ops.append(op)
-    return ops
-
-
-class Resize:
-    def __init__(self, size=(640, 640)):
-        self.size = size
-
-    def resize_image(self, img):
-        resize_h, resize_w = self.size
-        ori_h, ori_w = img.shape[:2]  # (h, w, c)
-        ratio_h = float(resize_h) / ori_h
-        ratio_w = float(resize_w) / ori_w
-        img = cv2.resize(img, (int(resize_w), int(resize_h)))
-        return img, [ratio_h, ratio_w]
-
-    def __call__(self, data):
-        img = data["image"]
-        if "polys" in data:
-            text_polys = data["polys"]
-
-        img_resize, [ratio_h, ratio_w] = self.resize_image(img)
-        if "polys" in data:
-            new_boxes = []
-            for box in text_polys:
-                new_box = []
-                for cord in box:
-                    new_box.append([cord[0] * ratio_w, cord[1] * ratio_h])
-                new_boxes.append(new_box)
-            data["polys"] = np.array(new_boxes, dtype=np.float32)
-        data["image"] = img_resize
-        return data
-
-
-class NormalizeImage:
-    def __init__(self, scale=None, mean=None, std=None, order="chw"):
-        if isinstance(scale, str):
-            scale = eval(scale)
-
-        self.scale = np.float32(scale if scale is not None else 1.0 / 255.0)
-        mean = mean if mean is not None else [0.485, 0.456, 0.406]
-        std = std if std is not None else [0.229, 0.224, 0.225]
-
-        shape = (3, 1, 1) if order == "chw" else (1, 1, 3)
-        self.mean = np.array(mean).reshape(shape).astype("float32")
-        self.std = np.array(std).reshape(shape).astype("float32")
-
-    def __call__(self, data):
-        img = np.array(data["image"])
-        assert isinstance(img, np.ndarray), "invalid input 'img' in NormalizeImage"
-        data["image"] = (img.astype("float32") * self.scale - self.mean) / self.std
-        return data
-
-
-class ToCHWImage:
-    def __init__(self, **kwargs):
-        pass
-
-    def __call__(self, data):
-        img = np.array(data["image"])
-        data["image"] = img.transpose((2, 0, 1))
-        return data
-
-
-class KeepKeys:
-    def __init__(self, keep_keys):
-        self.keep_keys = keep_keys
-
-    def __call__(self, data):
-        data_list = []
-        for key in self.keep_keys:
-            data_list.append(data[key])
-        return data_list
+class PPPreProcess:
+    def __init__(self, img_size: Tuple[int, int]):
+        self.size = img_size
+        self.mean = np.array([0.485, 0.456, 0.406])
+        self.std = np.array([0.229, 0.224, 0.225])
+        self.scale = 1 / 255.0
+
+    def __call__(self, img: Optional[np.ndarray] = None) -> np.ndarray:
+        if img is None:
+            raise ValueError("img is None.")
+
+        img = self.resize(img)
+        img = self.normalize(img)
+        img = self.permute(img)
+        img = np.expand_dims(img, axis=0)
+        return img.astype(np.float32)
+
+    def resize(self, img: np.ndarray) -> np.ndarray:
+        resize_h, resize_w = self.size
+        img = cv2.resize(img, (int(resize_w), int(resize_h)))
+        return img
+
+    def normalize(self, img: np.ndarray) -> np.ndarray:
+        return (img.astype("float32") * self.scale - self.mean) / self.std
+
+    def permute(self, img: np.ndarray) -> np.ndarray:
+        return img.transpose((2, 0, 1))
+
+
+class YOLOv8PreProcess:
+    def __init__(self, img_size: Tuple[int, int]):
+        self.img_size = img_size
+
+    def __call__(self, image: np.ndarray) -> np.ndarray:
+        input_img = cv2.resize(image, self.img_size)
+        input_img = input_img / 255.0
+        input_img = input_img.transpose(2, 0, 1)
+        input_tensor = input_img[np.newaxis, :, :, :].astype(np.float32)
+        return input_tensor
diff --git a/tests/test_layout.py b/tests/test_layout.py
index 1d3fa5c..fa56d7b 100644
--- a/tests/test_layout.py
+++ b/tests/test_layout.py
@@ -15,8 +15,6 @@ from rapid_layout import RapidLayout
 
 test_file_dir = cur_dir / "test_files"
 
-layout_engine = RapidLayout()
-
 img_path = test_file_dir / "layout.png"
 img = cv2.imread(str(img_path))
 
@@ -26,5 +24,15 @@
     "img_content", [img_path, str(img_path), open(img_path, "rb").read(), img]
 )
 def test_multi_input(img_content):
-    boxes, scores, class_names, *elapse = layout_engine(img_content)
+    engine = RapidLayout()
+    boxes, scores, class_names, *elapse = engine(img_content)
     assert len(boxes) == 15
+
+
+@pytest.mark.parametrize(
+    "img_content", [img_path, str(img_path), open(img_path, "rb").read(), img]
+)
+def test_yolov8_input(img_content):
+    engine = RapidLayout(model_type="yolov8n_layout_paper")
+    boxes, scores, class_names, *elapse = engine(img_content)
+    assert len(boxes) == 11
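Finally, a quick shape-contract check for the two new preprocess classes (commentary, not part of the patch; it assumes the module import path `rapid_layout/utils/pre_procss.py` shown above). PP models consume a normalized `(1, 3, 800, 608)` NCHW tensor, YOLOv8 models a `(1, 3, 640, 640)` one:

```python
import numpy as np

from rapid_layout.utils.pre_procss import PPPreProcess, YOLOv8PreProcess

dummy = np.zeros((720, 1280, 3), dtype=np.uint8)  # HWC BGR image

pp = PPPreProcess(img_size=(800, 608))  # (h, w), as used in main.py
assert pp(dummy).shape == (1, 3, 800, 608)

yolo = YOLOv8PreProcess(img_size=(640, 640))
out = yolo(dummy)
assert out.shape == (1, 3, 640, 640) and out.dtype == np.float32
```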