From e583fb93fc99fcab02cebe3a44c5da6f1900ee0e Mon Sep 17 00:00:00 2001
From: DeHors
Date: Wed, 6 Mar 2024 22:45:49 +0800
Subject: [PATCH] update ocr

---
 AUTHORS.md                    |   2 +
 EduNLP/SIF/parser/ocr.py      | 144 ++++++++++++++++++++++++++++++++++
 EduNLP/SIF/segment/segment.py |  17 +++--
 tests/test_sif/test_ocr.py    |  28 ++++++++
 4 files changed, 186 insertions(+), 5 deletions(-)
 create mode 100644 EduNLP/SIF/parser/ocr.py
 create mode 100644 tests/test_sif/test_ocr.py

diff --git a/AUTHORS.md b/AUTHORS.md
index 870aed9c..bcafe6e0 100644
--- a/AUTHORS.md
+++ b/AUTHORS.md
@@ -22,4 +22,6 @@
 
 [Shangzi Xue](https://github.com/ShangziXue)
 
+[Heng Yu](https://github.com/GNEHUY)
+
 The stared contributors are the corresponding authors.
diff --git a/EduNLP/SIF/parser/ocr.py b/EduNLP/SIF/parser/ocr.py
new file mode 100644
index 00000000..55f3af00
--- /dev/null
+++ b/EduNLP/SIF/parser/ocr.py
@@ -0,0 +1,144 @@
+# coding: utf-8
+# 2024/3/5 @ yuheng
+import json
+
+import requests
+
+from EduNLP.utils import image2base64
+
+# Endpoint of the external formula recognition service.
+FORMULA_OCR_URL = "https://formula-recognition-service-47-production.env.iai.bdaa.pro/v1"
+# Seconds to wait for the service, so an unresponsive endpoint cannot block the caller forever.
+FORMULA_OCR_TIMEOUT = 30
+
+
+class FormulaRecognitionError(Exception):
+    """Exception raised when formula recognition fails."""
+
+    def __init__(self, message="Formula recognition failed"):
+        self.message = message
+        super().__init__(self.message)
+
+
+def ocr_formula_figure(image_PIL_or_base64, is_base64=False):
+    """
+    Recognizes mathematical formulas in an image and returns their LaTeX representation.
+
+    Parameters
+    ----------
+    image_PIL_or_base64 : PngImageFile or str
+        The PngImageFile if is_base64 is False, or the base64 encoded string of the image if is_base64 is True.
+    is_base64 : bool, optional
+        Indicates whether image_PIL_or_base64 is a PngImageFile or a base64 encoded string, by default False.
+
+    Returns
+    -------
+    latex : str
+        The LaTeX representation of the mathematical formula recognized in the image.
+
+    Raises
+    ------
+    FormulaRecognitionError
+        If the request cannot be completed, if the service does not answer with a 200 status code,
+        if the response cannot be parsed, or if the image is not recognized as a mathematical formula.
+
+    Examples
+    --------
+    >>> from PIL import Image  # doctest: +SKIP
+    >>> image_PIL = Image.open("path/to/your/image.jpg")  # doctest: +SKIP
+    >>> print(ocr_formula_figure(image_PIL))  # doctest: +SKIP
+
+    Or, starting from a base64 encoded string:
+
+    >>> image_base64 = "base64_encoded_image_string"
+    >>> print(ocr_formula_figure(image_base64, is_base64=True))  # doctest: +SKIP
+
+    Notes
+    -----
+    This function relies on the external service at ``FORMULA_OCR_URL`` and on the `requests` library
+    to make HTTP requests. Make sure the required libraries are installed before use.
+    """
+    if is_base64:
+        image = image_PIL_or_base64
+    else:
+        image = image2base64(image_PIL_or_base64)
+
+    payload = [{
+        'qid': 0,
+        'image': image
+    }]
+
+    try:
+        resp = requests.post(FORMULA_OCR_URL, data=json.dumps(payload), timeout=FORMULA_OCR_TIMEOUT)
+    except requests.RequestException as e:
+        raise FormulaRecognitionError(f"Request to formula recognition service failed: {e}") from e
+
+    if resp.status_code != 200:
+        raise FormulaRecognitionError(f"HTTP error {resp.status_code}: {resp.text}")
+
+    try:
+        result = json.loads(resp.content)['data']
+        recognized = all(result[key] == 1 for key in ('success', 'is_formula', 'detect_formula'))
+        latex = result['latex'] if recognized else None
+    except (ValueError, KeyError, TypeError) as e:
+        # Malformed JSON, or a payload that lacks the expected fields.
+        raise FormulaRecognitionError(f"Error processing response: {e}") from e
+
+    if not recognized:
+        raise FormulaRecognitionError("Image is not recognized as a formula")
+
+    return latex
+
+
+def ocr(src, is_base64=False, figure_instances: dict = None):
+    r"""
+    Recognizes the mathematical formula in a formula figure referenced by a SIF tag.
+
+    Parameters
+    ----------
+    src : str
+        The tag to resolve: ``\FormFigureBase64{...}`` wrapping the base64 encoded image if is_base64 is True,
+        or ``\FormFigureID{...}`` wrapping a figure identifier if is_base64 is False.
+    is_base64 : bool, optional
+        Indicates whether src wraps a base64 encoded image or a figure identifier, by default False.
+    figure_instances : dict, optional
+        A dictionary mapping figure identifiers to their corresponding PngImageFile, by default None.
+        This is only required and used if is_base64 is False.
+
+    Returns
+    -------
+    formula_figure_latex : str or None
+        The LaTeX representation of the formula recognized within the figure. None if is_base64 is False
+        and figure_instances is missing or does not contain the specified figure identifier.
+
+    Raises
+    ------
+    FormulaRecognitionError
+        Propagated from `ocr_formula_figure` when the service call fails or no formula is recognized.
+
+    Examples
+    --------
+    >>> src_base64 = r"\FormFigureBase64{base64_encoded_image_string}"
+    >>> print(ocr(src_base64, is_base64=True))  # doctest: +SKIP
+
+    Or, looking the figure up by its identifier:
+
+    >>> from PIL import Image  # doctest: +SKIP
+    >>> image_PIL = Image.open("path/to/your/image.jpg")  # doctest: +SKIP
+    >>> figure_instances = {"figure1": image_PIL}  # doctest: +SKIP
+    >>> src_id = r"\FormFigureID{figure1}"
+    >>> print(ocr(src_id, figure_instances=figure_instances))  # doctest: +SKIP
+    """
+    if is_base64:
+        # The image is embedded in the tag itself, so no figure lookup is needed.
+        figure = src[len(r"\FormFigureBase64") + 1: -1]
+        return ocr_formula_figure(figure, is_base64=True)
+
+    if figure_instances is None:
+        return None
+    figure_id = src[len(r"\FormFigureID") + 1: -1]
+    try:
+        figure = figure_instances[figure_id]
+    except KeyError:
+        return None
+    return ocr_formula_figure(figure, is_base64=False)
diff --git a/EduNLP/SIF/segment/segment.py b/EduNLP/SIF/segment/segment.py
index 3b3bf227..517e8b2a 100644
--- a/EduNLP/SIF/segment/segment.py
+++ b/EduNLP/SIF/segment/segment.py
@@ -5,6 +5,7 @@
 import re
 from contextlib import contextmanager
 from ..constants import Symbol, TEXT_SYMBOL, FORMULA_SYMBOL, FIGURE_SYMBOL, QUES_MARK_SYMBOL, TAG_SYMBOL, SEP_SYMBOL
+from ..parser.ocr import ocr
 
 
 class TextSegment(str):
@@ -93,7 +94,7 @@ class SegmentList(object):
     >>> SegmentList(test_item)
     ['如图所示,则三角形', 'ABC', '的面积是', '\\\\SIFBlank', '。', \\FigureID{1}]
     """
-    def __init__(self, item, figures: dict = None):
+    def __init__(self, item, figures: dict = None, convert_image_to_latex=False):
         self._segments = []
         self._text_segments = []
         self._formula_segments = []
@@ -112,9 +113,15 @@ def __init__(self, item, figures: dict = None):
             if not re.match(r"\$.+?\$", segment):
                 self.append(TextSegment(segment))
             elif re.match(r"\$\\FormFigureID\{.+?}\$", segment):
-                self.append(FigureFormulaSegment(segment[1:-1], is_base64=False, figure_instances=figures))
+                if convert_image_to_latex:
+                    self.append(LatexFormulaSegment(ocr(segment[1:-1], is_base64=False, figure_instances=figures)))
+                else:
+                    self.append(FigureFormulaSegment(segment[1:-1], is_base64=False, figure_instances=figures))
             elif re.match(r"\$\\FormFigureBase64\{.+?}\$", segment):
-                self.append(FigureFormulaSegment(segment[1:-1], is_base64=True, figure_instances=figures))
+                if convert_image_to_latex:
+                    self.append(LatexFormulaSegment(ocr(segment[1:-1], is_base64=True, figure_instances=figures)))
+                else:
+                    self.append(FigureFormulaSegment(segment[1:-1], is_base64=True, figure_instances=figures))
             elif re.match(r"\$\\FigureID\{.+?}\$", segment):
                 self.append(FigureSegment(segment[1:-1], is_base64=False, figure_instances=figures))
             elif re.match(r"\$\\FigureBase64\{.+?}\$", segment):
@@ -271,7 +278,7 @@ def describe(self):
         }
 
 
-def seg(item, figures=None, symbol=None):
+def seg(item, figures=None, symbol=None, convert_image_to_latex=False):
     r"""
     It is a interface for SegmentList. And show it in an appropriate way.
 
@@ -346,7 +353,7 @@ def seg(item, figures=None, symbol=None):
     >>> s2.text_segments
     ['已知', ',则以下说法中正确的是']
     """
-    segments = SegmentList(item, figures)
+    segments = SegmentList(item, figures, convert_image_to_latex)
     if symbol is not None:
         segments.symbolize(symbol)
     return segments
diff --git a/tests/test_sif/test_ocr.py b/tests/test_sif/test_ocr.py
new file mode 100644
index 00000000..c010e75c
--- /dev/null
+++ b/tests/test_sif/test_ocr.py
@@ -0,0 +1,28 @@
+# 2024/3/5 @ yuheng
+
+import pytest
+
+from EduNLP.SIF.segment import seg
+
+
+def test_ocr(figure0, figure1, figure0_base64, figure1_base64):
+    seg(
+        r"如图所示,则$\FormFigureID{0}$的面积是$\SIFBlank$。$\FigureID{1}$",
+        figures={
+            "0": figure0,
+            "1": figure1
+        },
+        convert_image_to_latex=True
+    )
+    s = seg(
+        r"如图所示,则$\FormFigureBase64{%s}$的面积是$\SIFBlank$。$\FigureBase64{%s}$" % (figure0_base64, figure1_base64),
+        figures=True,
+        convert_image_to_latex=True
+    )
+    with pytest.raises(TypeError):
+        s.append("123")
+    seg_test_text = seg(
+        r"如图所示,有三组$\textf{机器人,bu}$在踢$\textf{足球,b}$",
+        figures=True
+    )
+    assert seg_test_text.text_segments == ['如图所示,有三组机器人在踢足球']