Skip to content

Commit

Permalink
update ocr
Browse files Browse the repository at this point in the history
  • Loading branch information
GNEHUY committed Mar 6, 2024
1 parent 504d147 commit e583fb9
Show file tree
Hide file tree
Showing 4 changed files with 169 additions and 5 deletions.
2 changes: 2 additions & 0 deletions AUTHORS.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,4 +22,6 @@

[Shangzi Xue](https://github.com/ShangziXue)

[Heng Yu](https://github.com/GNEHUY)

The starred contributors are the corresponding authors.
127 changes: 127 additions & 0 deletions EduNLP/SIF/parser/ocr.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
# coding: utf-8
# 2024/3/5 @ yuheng
import json
import requests
from EduNLP.utils import image2base64

class FormulaRecognitionError(Exception):
    """Error signalling that formula recognition did not produce a result.

    The human-readable description is kept on ``message`` and is also passed
    to ``Exception`` so that ``str(error)`` yields the same text.
    """

    def __init__(self, message="Formula recognition failed"):
        super().__init__(message)
        self.message = message

def ocr_formula_figure(image_PIL_or_base64, is_base64=False, timeout=30):
    """
    Recognize the formula in an image and return its LaTeX representation.

    The image is sent as a base64 string to an external HTTP recognition
    service, and the "latex" field of the JSON reply is returned.

    Parameters
    ----------
    image_PIL_or_base64 : PngImageFile or str
        The image object if is_base64 is False (it is converted with
        ``image2base64``), or the base64 encoded string of the image if
        is_base64 is True (it is sent unchanged).
    is_base64 : bool, optional
        Indicates whether image_PIL_or_base64 is already a base64 encoded
        string, by default False.
    timeout : float or tuple, optional
        Timeout in seconds forwarded to ``requests.post``, by default 30.
        It keeps an unresponsive service from blocking the caller forever.

    Returns
    -------
    latex
        The value of the "latex" field reported by the service for the image.

    Raises
    ------
    FormulaRecognitionError
        If the HTTP status code is not 200, if the response body is not valid
        JSON or lacks the expected fields, or if the service does not report
        the image as a successfully detected formula.
    requests.exceptions.RequestException
        Propagated unchanged from ``requests.post`` (for example on
        connection failure or timeout).

    Examples
    --------
    >>> from PIL import Image  # doctest: +SKIP
    >>> image_PIL = Image.open("path/to/your/image.jpg")  # doctest: +SKIP
    >>> print(ocr_formula_figure(image_PIL))  # doctest: +SKIP

    Or

    >>> image_base64 = "base64_encoded_image_string"
    >>> print(ocr_formula_figure(image_base64, is_base64=True))  # doctest: +SKIP

    Notes
    -----
    This function relies on the external service
    "https://formula-recognition-service-47-production.env.iai.bdaa.pro/v1"
    and on the `requests` library to make HTTP requests.
    """
    url = "https://formula-recognition-service-47-production.env.iai.bdaa.pro/v1"

    image = image_PIL_or_base64 if is_base64 else image2base64(image_PIL_or_base64)
    payload = [{
        'qid': 0,
        'image': image
    }]

    resp = requests.post(url, data=json.dumps(payload), timeout=timeout)

    if resp.status_code != 200:
        raise FormulaRecognitionError(f"HTTP error {resp.status_code}: {resp.text}")

    # Decode the body exactly once; json.JSONDecodeError is a ValueError.
    try:
        res = json.loads(resp.content)
    except (ValueError, TypeError) as e:
        raise FormulaRecognitionError(f"Error processing response: {e}") from e

    # A reply with an unexpected structure is reported through the same
    # exception type as every other recognition failure, not a bare KeyError.
    try:
        result = res['data']
        recognized = (
            result['success'] == 1
            and result['is_formula'] == 1
            and result['detect_formula'] == 1
        )
        latex = result['latex'] if recognized else None
    except (KeyError, TypeError) as e:
        raise FormulaRecognitionError(
            f"Error processing response: unexpected payload ({e!r})"
        ) from e

    if not recognized:
        raise FormulaRecognitionError("Image is not recognized as a formula")

    return latex

def ocr(src, is_base64=False, figure_instances: dict = None):
    r"""
    Recognize the formula of a formula-figure segment and return its LaTeX.

    The figure is taken either from the base64 payload embedded in ``src`` or
    from ``figure_instances`` using the identifier embedded in ``src``, and is
    then passed to ``ocr_formula_figure``.

    Parameters
    ----------
    src : str
        The segment source: ``\FormFigureBase64{...}`` if is_base64 is True,
        or ``\FormFigureID{...}`` if is_base64 is False. The text between the
        braces is used as the base64 string or as the figure identifier.
    is_base64 : bool, optional
        Indicates whether src carries a base64 encoded image or a figure
        identifier, by default False.
    figure_instances : dict, optional
        A mapping from figure identifiers to image objects, by default None.
        It is only consulted when is_base64 is False.

    Returns
    -------
    formula_figure_latex : str or None
        The value returned by ``ocr_formula_figure`` for the figure. None is
        returned, without running OCR, when is_base64 is False and either
        figure_instances is None or it does not contain the identifier.

    Raises
    ------
    FormulaRecognitionError
        Propagated from ``ocr_formula_figure`` when recognition fails.

    Examples
    --------
    >>> src_base64 = r"\FormFigureBase64{base64_encoded_image_string}"
    >>> print(ocr(src_base64, is_base64=True))  # doctest: +SKIP

    Or

    >>> from PIL import Image  # doctest: +SKIP
    >>> image_PIL = Image.open("path/to/your/image.jpg")  # doctest: +SKIP
    >>> figure_instances = {"figure1": image_PIL}  # doctest: +SKIP
    >>> src_id = r"\FormFigureID{figure1}"
    >>> print(ocr(src_id, figure_instances=figure_instances))  # doctest: +SKIP

    A figure identifier that cannot be resolved yields None:

    >>> print(ocr(r"\FormFigureID{figure1}"))
    None
    >>> print(ocr(r"\FormFigureID{figure1}", figure_instances={}))
    None
    """
    if is_base64:
        # Drop the leading "\FormFigureBase64{" and the trailing "}".
        # The payload is self-contained, so figure_instances is not needed.
        figure = src[len(r"\FormFigureBase64") + 1: -1]
        return ocr_formula_figure(figure, is_base64)

    if figure_instances is None:
        return None

    # Drop the leading "\FormFigureID{" and the trailing "}".
    figure_id = src[len(r"\FormFigureID") + 1: -1]
    try:
        figure = figure_instances[figure_id]
    except KeyError:
        # Unknown identifier: report "nothing recognized" as documented.
        return None
    return ocr_formula_figure(figure, is_base64)



17 changes: 12 additions & 5 deletions EduNLP/SIF/segment/segment.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import re
from contextlib import contextmanager
from ..constants import Symbol, TEXT_SYMBOL, FORMULA_SYMBOL, FIGURE_SYMBOL, QUES_MARK_SYMBOL, TAG_SYMBOL, SEP_SYMBOL
from ..parser.ocr import ocr


class TextSegment(str):
Expand Down Expand Up @@ -93,7 +94,7 @@ class SegmentList(object):
>>> SegmentList(test_item)
['如图所示,则三角形', 'ABC', '的面积是', '\\\\SIFBlank', '。', \\FigureID{1}]
"""
def __init__(self, item, figures: dict = None):
def __init__(self, item, figures: dict = None, convert_image_to_latex=False):
self._segments = []
self._text_segments = []
self._formula_segments = []
Expand All @@ -112,9 +113,15 @@ def __init__(self, item, figures: dict = None):
if not re.match(r"\$.+?\$", segment):
self.append(TextSegment(segment))
elif re.match(r"\$\\FormFigureID\{.+?}\$", segment):
self.append(FigureFormulaSegment(segment[1:-1], is_base64=False, figure_instances=figures))
if convert_image_to_latex:
self.append(LatexFormulaSegment(ocr(segment[1:-1], is_base64=False, figure_instances=figures)))
else:
self.append(FigureFormulaSegment(segment[1:-1], is_base64=False, figure_instances=figures))
elif re.match(r"\$\\FormFigureBase64\{.+?}\$", segment):
self.append(FigureFormulaSegment(segment[1:-1], is_base64=True, figure_instances=figures))
if convert_image_to_latex:
self.append(LatexFormulaSegment(ocr(segment[1:-1], is_base64=True, figure_instances=figures)))
else:
self.append(FigureFormulaSegment(segment[1:-1], is_base64=True, figure_instances=figures))
elif re.match(r"\$\\FigureID\{.+?}\$", segment):
self.append(FigureSegment(segment[1:-1], is_base64=False, figure_instances=figures))
elif re.match(r"\$\\FigureBase64\{.+?}\$", segment):
Expand Down Expand Up @@ -271,7 +278,7 @@ def describe(self):
}


def seg(item, figures=None, symbol=None):
def seg(item, figures=None, symbol=None, convert_image_to_latex=False):
r"""
It is a interface for SegmentList. And show it in an appropriate way.
Expand Down Expand Up @@ -346,7 +353,7 @@ def seg(item, figures=None, symbol=None):
>>> s2.text_segments
['已知', ',则以下说法中正确的是']
"""
segments = SegmentList(item, figures)
segments = SegmentList(item, figures, convert_image_to_latex)
if symbol is not None:
segments.symbolize(symbol)
return segments
28 changes: 28 additions & 0 deletions tests/test_sif/test_ocr.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
# 2024/3/5 @ yuheng

import pytest

from EduNLP.SIF.segment import seg


def test_ocr(figure0, figure1, figure0_base64, figure1_base64):
    """Exercise ``seg`` with image-to-LaTeX conversion for ID and base64 figures."""
    # Formula figure referenced by identifier and resolved through ``figures``.
    figures_by_id = {
        "0": figure0,
        "1": figure1
    }
    seg(
        r"如图所示,则$\FormFigureID{0}$的面积是$\SIFBlank$。$\FigureID{1}$",
        figures=figures_by_id,
        convert_image_to_latex=True
    )

    # Formula figure embedded in the item text as a base64 string.
    base64_item = (
        r"如图所示,则$\FormFigureBase64{%s}$的面积是$\SIFBlank$。$\FigureBase64{%s}$"
        % (figure0_base64, figure1_base64)
    )
    segments = seg(base64_item, figures=True, convert_image_to_latex=True)
    with pytest.raises(TypeError):
        segments.append("123")

    # Item without formula figures: only the text segments are checked.
    textf_segments = seg(
        r"如图所示,有三组$\textf{机器人,bu}$在踢$\textf{足球,b}$",
        figures=True
    )
    expected_text = ['如图所示,有三组机器人在踢足球']
    assert textf_segments.text_segments == expected_text

0 comments on commit e583fb9

Please sign in to comment.