Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update OCR module #157

Merged
merged 4 commits into from
Mar 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions AUTHORS.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,4 +22,6 @@

[Shangzi Xue](https://github.com/ShangziXue)

[Heng Yu](https://github.com/GNEHUY)

The starred contributors are the corresponding authors.
156 changes: 156 additions & 0 deletions EduNLP/SIF/parser/ocr.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,156 @@
# coding: utf-8
# 2024/3/5 @ yuheng
import json
import requests
from EduNLP.utils import image2base64


class FormulaRecognitionError(Exception):
    """Error signalling that a formula could not be recognized.

    The human-readable description is kept on ``message`` and is also
    passed to ``Exception`` so that ``str(error)`` yields the same text.
    """

    def __init__(self, message="Formula recognition failed"):
        super().__init__(message)
        self.message = message


def ocr_formula_figure(image_PIL_or_base64, is_base64=False, timeout=30):
    """
    Recognizes mathematical formulas in an image and returns their LaTeX representation.

    The image is posted (base64 encoded, wrapped in a JSON list) to an external
    formula recognition service and the LaTeX string is read from its JSON reply.

    Parameters
    ----------
    image_PIL_or_base64 : PngImageFile or str
        The PngImageFile if is_base64 is False, or the base64 encoded string of the image if is_base64 is True.
    is_base64 : bool, optional
        Indicates whether the image_PIL_or_base64 parameter is an PngImageFile or a base64 encoded string.
    timeout : float or tuple, optional
        Forwarded unchanged to ``requests.post`` so that an unresponsive service
        cannot block the caller forever, by default 30.

    Returns
    -------
    latex : str
        The LaTeX representation of the mathematical formula recognized in the image.

    Raises
    ------
    FormulaRecognitionError
        If the HTTP request does not return a 200 status code,
        if the response body is not valid JSON or lacks the expected fields,
        if the image is not recognized as a mathematical formula.
    requests.exceptions.RequestException
        Not caught here: errors raised by ``requests.post`` itself
        (connection failure, timeout, ...) propagate to the caller.

    Examples
    --------
    >>> import os
    >>> from PIL import Image
    >>> from EduNLP.utils import abs_current_dir, path_append
    >>> img_dir = os.path.abspath(path_append(abs_current_dir(__file__), "..", "..", "..", "asset", "_static"))
    >>> image_PIL = Image.open(path_append(img_dir, "item_ocr_formula.png", to_str=True))
    >>> print(ocr_formula_figure(image_PIL))
    f(x)=\\left (\\frac {1}{3}\\right )^{x}-\\sqrt {x}}
    >>> import os
    >>> from PIL import Image
    >>> from EduNLP.utils import abs_current_dir, path_append, image2base64
    >>> img_dir = os.path.abspath(path_append(abs_current_dir(__file__), "..", "..", "..", "asset", "_static"))
    >>> image_PIL = Image.open(path_append(img_dir, "item_ocr_formula.png", to_str=True))
    >>> image_base64 = image2base64(image_PIL)
    >>> print(ocr_formula_figure(image_base64, is_base64=True))
    f(x)=\\left (\\frac {1}{3}\\right )^{x}-\\sqrt {x}}

    Notes
    -----
    This function relies on an external service "https://formula-recognition-service-47-production.env.iai.bdaa.pro/v1",
    and the `requests` library to make HTTP requests. Make sure the required libraries are installed before use.
    """
    url = "https://formula-recognition-service-47-production.env.iai.bdaa.pro/v1"

    image = image_PIL_or_base64 if is_base64 else image2base64(image_PIL_or_base64)

    payload = [{
        'qid': 0,
        'image': image
    }]

    resp = requests.post(url, data=json.dumps(payload), timeout=timeout)

    if resp.status_code != 200:
        raise FormulaRecognitionError(f"HTTP error {resp.status_code}: {resp.text}")

    try:
        res = json.loads(resp.content)
    except (ValueError, TypeError) as e:
        raise FormulaRecognitionError(f"Error processing response: {e}") from e

    # A syntactically valid reply can still have an unexpected shape; report
    # that as a recognition error instead of leaking KeyError/TypeError.
    try:
        data = res['data']
        # Short-circuit evaluation: later flags are only read when earlier ones are 1.
        recognized = data['success'] == 1 and data['is_formula'] == 1 and data['detect_formula'] == 1
        latex = data['latex'] if recognized else None
    except (KeyError, TypeError) as e:
        raise FormulaRecognitionError(f"Error processing response: {e}") from e

    if not recognized:
        raise FormulaRecognitionError("Image is not recognized as a formula")

    return latex


def ocr(src, is_base64=False, figure_instances: dict = None):
    """
    Recognizes mathematical formulas within figures from a given source,
    which can be either a base64 string or an identifier for a figure within a provided dictionary.

    Parameters
    ----------
    src : str
        The figure markup without the surrounding ``$``, i.e. ``\\FormFigureBase64{...}``
        if is_base64 is True, or ``\\FormFigureID{...}`` if is_base64 is False.
        The content between the braces is the base64 encoded image or the figure identifier.
    is_base64 : bool, optional
        Indicates whether src carries a base64 encoded string or an identifier, by default False.
    figure_instances : dict, optional
        A dictionary mapping figure identifiers to their corresponding PngImageFile, by default None.
        When it is None no recognition is attempted in either mode.
        When is_base64 is True only its presence matters (any non-None value enables OCR);
        when is_base64 is False the figure is looked up in it.

    Returns
    -------
    formula_figure_latex : str or None
        The LaTeX representation of the mathematical formula recognized within the figure.
        Returns None if figure_instances is None or
        if the figure_instances dictionary does not contain the specified figure identifier when is_base64 is False.

    Examples
    --------
    >>> import os
    >>> from PIL import Image
    >>> from EduNLP.utils import abs_current_dir, path_append
    >>> img_dir = os.path.abspath(path_append(abs_current_dir(__file__), "..", "..", "..", "asset", "_static"))
    >>> image_PIL = Image.open(path_append(img_dir, "item_ocr_formula.png", to_str=True))
    >>> figure_instances = {"1": image_PIL}
    >>> src_id = r"$\\FormFigureID{1}$"
    >>> print(ocr(src_id[1:-1], figure_instances=figure_instances))
    f(x)=\\left (\\frac {1}{3}\\right )^{x}-\\sqrt {x}}
    >>> import os
    >>> from PIL import Image
    >>> from EduNLP.utils import abs_current_dir, path_append, image2base64
    >>> img_dir = os.path.abspath(path_append(abs_current_dir(__file__), "..", "..", "..", "asset", "_static"))
    >>> image_PIL = Image.open(path_append(img_dir, "item_ocr_formula.png", to_str=True))
    >>> image_base64 = image2base64(image_PIL)
    >>> src_base64 = r"$\\FormFigureBase64{%s}$" % (image_base64)
    >>> print(ocr(src_base64[1:-1], is_base64=True, figure_instances=True))
    f(x)=\\left (\\frac {1}{3}\\right )^{x}-\\sqrt {x}}

    Notes
    -----
    This function relies on `ocr_formula_figure` for the actual OCR (Optical Character Recognition) process.
    Ensure that `ocr_formula_figure` is correctly implemented and can handle base64 encoded strings and PngImageFile.
    """
    # Without figure_instances nothing is recognized, whatever the mode.
    if figure_instances is None:
        return None

    if is_base64:
        # Drop the leading "\FormFigureBase64{" and the trailing "}".
        figure = src[len(r"\FormFigureBase64") + 1: -1]
    else:
        # Drop the leading "\FormFigureID{" and the trailing "}".
        figure_id = src[len(r"\FormFigureID") + 1: -1]
        try:
            figure = figure_instances[figure_id]
        except KeyError:
            # Unknown identifier: documented to yield None rather than fail.
            return None

    return ocr_formula_figure(figure, is_base64)
17 changes: 12 additions & 5 deletions EduNLP/SIF/segment/segment.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import re
from contextlib import contextmanager
from ..constants import Symbol, TEXT_SYMBOL, FORMULA_SYMBOL, FIGURE_SYMBOL, QUES_MARK_SYMBOL, TAG_SYMBOL, SEP_SYMBOL
from ..parser.ocr import ocr


class TextSegment(str):
Expand Down Expand Up @@ -93,7 +94,7 @@ class SegmentList(object):
>>> SegmentList(test_item)
['如图所示,则三角形', 'ABC', '的面积是', '\\\\SIFBlank', '。', \\FigureID{1}]
"""
def __init__(self, item, figures: dict = None):
def __init__(self, item, figures: dict = None, convert_image_to_latex=False):
self._segments = []
self._text_segments = []
self._formula_segments = []
Expand All @@ -112,9 +113,15 @@ def __init__(self, item, figures: dict = None):
if not re.match(r"\$.+?\$", segment):
self.append(TextSegment(segment))
elif re.match(r"\$\\FormFigureID\{.+?}\$", segment):
self.append(FigureFormulaSegment(segment[1:-1], is_base64=False, figure_instances=figures))
if convert_image_to_latex:
self.append(LatexFormulaSegment(ocr(segment[1:-1], is_base64=False, figure_instances=figures)))
else:
self.append(FigureFormulaSegment(segment[1:-1], is_base64=False, figure_instances=figures))
elif re.match(r"\$\\FormFigureBase64\{.+?}\$", segment):
self.append(FigureFormulaSegment(segment[1:-1], is_base64=True, figure_instances=figures))
if convert_image_to_latex:
self.append(LatexFormulaSegment(ocr(segment[1:-1], is_base64=True, figure_instances=figures)))
else:
self.append(FigureFormulaSegment(segment[1:-1], is_base64=True, figure_instances=figures))
elif re.match(r"\$\\FigureID\{.+?}\$", segment):
self.append(FigureSegment(segment[1:-1], is_base64=False, figure_instances=figures))
elif re.match(r"\$\\FigureBase64\{.+?}\$", segment):
Expand Down Expand Up @@ -271,7 +278,7 @@ def describe(self):
}


def seg(item, figures=None, symbol=None):
def seg(item, figures=None, symbol=None, convert_image_to_latex=False):
r"""
It is a interface for SegmentList. And show it in an appropriate way.

Expand Down Expand Up @@ -346,7 +353,7 @@ def seg(item, figures=None, symbol=None):
>>> s2.text_segments
['已知', ',则以下说法中正确的是']
"""
segments = SegmentList(item, figures)
segments = SegmentList(item, figures, convert_image_to_latex)
if symbol is not None:
segments.symbolize(symbol)
return segments
4 changes: 2 additions & 2 deletions EduNLP/SIF/sif.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def to_sif(item, check_formula=True, parser: Parser = None):


def sif4sci(item: str, figures: (dict, bool) = None, mode: int = 2, symbol: str = None, tokenization=True,
tokenization_params=None, errors="raise"):
tokenization_params=None, convert_image_to_latex=False, errors="raise"):
r"""

Default to use linear Tokenizer, change the tokenizer by specifying tokenization_params
Expand Down Expand Up @@ -260,7 +260,7 @@ def sif4sci(item: str, figures: (dict, bool) = None, mode: int = 2, symbol: str
"Unknown mode %s, use only 0 or 1 or 2." % mode
)

ret = seg(item, figures, symbol)
ret = seg(item, figures, symbol, convert_image_to_latex)

if tokenization is True:
ret = tokenize(ret, **(tokenization_params if tokenization_params is not None else {}))
Expand Down
18 changes: 12 additions & 6 deletions EduNLP/Tokenizer/tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,8 @@ def __call__(self, items: Iterable, key=lambda x: x, **kwargs):
for item in items:
yield self._tokenize(item, key=key, **kwargs)

def _tokenize(self, item: Union[str, dict], key=lambda x: x, symbol: str = None, **kwargs):
def _tokenize(self, item: Union[str, dict], key=lambda x: x, symbol: str = None,
convert_image_to_latex=False, **kwargs):
"""Tokenize one item, return token list

Parameters
Expand All @@ -67,7 +68,8 @@ def _tokenize(self, item: Union[str, dict], key=lambda x: x, symbol: str = None,
determine how to get the text of item, by default lambdax: x
"""
symbol = self.symbol if symbol is None else symbol
return tokenize(seg(key(item), symbol=symbol, figures=self.figures),
return tokenize(seg(key(item), symbol=symbol, figures=self.figures,
convert_image_to_latex=convert_image_to_latex),
**self.tokenization_params, **kwargs).tokens


Expand Down Expand Up @@ -191,9 +193,11 @@ def __call__(self, items: Iterable, key=lambda x: x, **kwargs):
for item in items:
yield self._tokenize(item, key=key, **kwargs)

def _tokenize(self, item: Union[str, dict], key=lambda x: x, symbol: str = None, **kwargs):
def _tokenize(self, item: Union[str, dict], key=lambda x: x, symbol: str = None,
convert_image_to_latex=False, **kwargs):
symbol = self.symbol if symbol is None else symbol
return tokenize(seg(key(item), symbol=symbol), **self.tokenization_params, **kwargs).tokens
return tokenize(seg(key(item), symbol=symbol, convert_image_to_latex=convert_image_to_latex),
**self.tokenization_params, **kwargs).tokens


class AstFormulaTokenizer(Tokenizer):
Expand Down Expand Up @@ -235,11 +239,13 @@ def __call__(self, items: Iterable, key=lambda x: x, **kwargs):
for item in items:
yield self._tokenize(item, key=key, **kwargs)

def _tokenize(self, item: Union[str, dict], key=lambda x: x, symbol: str = None, **kwargs):
def _tokenize(self, item: Union[str, dict], key=lambda x: x, symbol: str = None,
convert_image_to_latex=False, **kwargs):
mode = kwargs.pop("mode", 0)
symbol = self.symbol if symbol is None else symbol
ret = sif4sci(key(item), figures=self.figures, mode=mode, symbol=symbol,
tokenization_params=self.tokenization_params, errors="ignore", **kwargs)
tokenization_params=self.tokenization_params, convert_image_to_latex=convert_image_to_latex,
errors="ignore", **kwargs)
ret = [] if ret is None else ret.tokens
return ret

Expand Down
Binary file added asset/_static/item_ocr_formula.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
62 changes: 62 additions & 0 deletions tests/test_sif/test_ocr.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
# 2024/3/5 @ yuheng

import pytest
import json

from EduNLP.SIF.segment import seg
from EduNLP.SIF.parser.ocr import ocr_formula_figure, FormulaRecognitionError
from unittest.mock import patch


def test_ocr(figure0, figure1, figure0_base64, figure1_base64):
    """Run segmentation with OCR conversion enabled for ID-based and base64 figures."""
    # Formula figure referenced by identifier and resolved through the figures dict.
    id_item = r"如图所示,则$\FormFigureID{0}$的面积是$\SIFBlank$。$\FigureID{1}$"
    id_figures = {"0": figure0, "1": figure1}
    seg(id_item, figures=id_figures, convert_image_to_latex=True)

    # Formula figure embedded directly as base64 data.
    base64_item = r"如图所示,则$\FormFigureBase64{%s}$的面积是$\SIFBlank$。$\FigureBase64{%s}$" % (
        figure0_base64, figure1_base64
    )
    base64_segments = seg(base64_item, figures=True, convert_image_to_latex=True)

    # Appending a plain str to the segment list is expected to be rejected.
    with pytest.raises(TypeError):
        base64_segments.append("123")

    # Items without any figure still segment as before.
    text_item = r"如图所示,有三组$\textf{机器人,bu}$在踢$\textf{足球,b}$"
    seg_test_text = seg(text_item, figures=True)
    expected_text = ['如图所示,有三组机器人在踢足球']
    assert seg_test_text.text_segments == expected_text


def test_ocr_formula_figure_exceptions(figure0_base64):
    """Check the FormulaRecognitionError raised for each failure mode, using a patched requests.post."""
    post_target = 'EduNLP.SIF.parser.ocr.requests.post'

    # Scenario 1: the service answers with a non-200 status code.
    with patch(post_target) as fake_post:
        fake_post.return_value.status_code = 404
        with pytest.raises(FormulaRecognitionError) as caught:
            ocr_formula_figure(figure0_base64, is_base64=True)
    assert "HTTP error 404" in str(caught.value)

    # Scenario 2: the service answers 200 but the body is not valid JSON.
    with patch(post_target) as fake_post:
        fake_post.return_value.status_code = 200
        fake_post.return_value.content = b"invalid_json_response"
        with pytest.raises(FormulaRecognitionError) as caught:
            ocr_formula_figure(figure0_base64, is_base64=True)
    assert "Error processing response" in str(caught.value)

    # Scenario 3: the service answers 200 but reports that the image is not a formula.
    not_formula_reply = {
        "data": {
            'success': 1,
            'is_formula': 0,
            'detect_formula': 0
        }
    }
    with patch(post_target) as fake_post:
        fake_post.return_value.status_code = 200
        fake_post.return_value.content = json.dumps(not_formula_reply).encode('utf-8')
        with pytest.raises(FormulaRecognitionError) as caught:
            ocr_formula_figure(figure0_base64, is_base64=True)
    assert "Image is not recognized as a formula" in str(caught.value)
Loading