Skip to content

Commit

Permalink
support table structure model inference, and table inference concaten…
Browse files Browse the repository at this point in the history
…ated with OCR model. (#764)
  • Loading branch information
hongziqi authored Nov 15, 2024
1 parent dda7b75 commit 9d2f8f8
Show file tree
Hide file tree
Showing 14 changed files with 977 additions and 1 deletion.
Binary file added configs/table/example.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added configs/table/example_structure.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
1 change: 1 addition & 0 deletions mindocr/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from .rec_robustscanner import *
from .rec_svtr import *
from .rec_visionlan import *
from .table_master import *

__all__ = []
__all__.extend(builder.__all__)
Expand Down
57 changes: 57 additions & 0 deletions mindocr/models/table_master.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
from ._registry import register_model
from .backbones.mindcv_models.utils import load_pretrained
from .base_model import BaseModel


def _cfg(url="", **kwargs):
return {"url": url, **kwargs}


default_cfgs = {
"table_master": _cfg(
url="https://download-mindspore.osinfra.cn/toolkits/mindocr/tablemaster/table_master-78bf35bb.ckpt"
),
}


class TableMaster(BaseModel):
def __init__(self, config):
BaseModel.__init__(self, config)


@register_model
def table_master(
pretrained: bool = True,
**kwargs
):
model_config = {
"type": "table",
"transform": None,
"backbone": {
"name": "table_resnet_extra",
"gcb_config": {
"ratio": 0.0625,
"headers": 1,
"att_scale": False,
"fusion_type": "channel_add",
"layers": [False, True, True, True],
},
"layers": [1, 2, 5, 3],
},
"head": {
"name": "TableMasterHead",
"out_channels": 43,
"hidden_size": 512,
"headers": 8,
"dropout": 0.0,
"d_ff": 2024,
"max_text_length": 500,
"loc_reg_num": 4
},
}
model = TableMaster(model_config)
if pretrained:
default_cfg = default_cfgs["table_master"]
load_pretrained(model, default_cfg)

return model
2 changes: 2 additions & 0 deletions mindocr/utils/visualize.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ def show_img(img: np.array, is_bgr_img=True, title="img", show=True, save_path=N
if show:
plt.show()
if save_path is not None:
os.makedirs(os.path.dirname(save_path), exist_ok=True)
plt.savefig(save_path) # , bbox_inches='tight', dpi=460)


Expand Down Expand Up @@ -73,6 +74,7 @@ def show_imgs(
plt.show()

if save_path is not None:
os.makedirs(os.path.dirname(save_path), exist_ok=True)
plt.savefig(save_path, bbox_inches="tight", dpi=300, pad_inches=0)


Expand Down
4 changes: 4 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -21,3 +21,7 @@ requests>=2.31.0
pycocotools>=2.0.2
setuptools-scm
albumentations
beautifulsoup4
pandas
tablepyxl
lxml
30 changes: 30 additions & 0 deletions tools/infer/text/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,36 @@ def create_parser():
"due to padding or resizing to the same shape.",
)
parser.add_argument("--kie_batch_num", type=int, default=8)
parser.add_argument(
"--table_algorithm",
type=str,
default="TABLE_MASTER",
choices=["TABLE_MASTER"],
help="table structure recognition algorithm",
)
parser.add_argument(
"--table_model_dir",
type=str,
help="directory containing the table structure recognition model checkpoint best.ckpt, "
"or path to a specific checkpoint file.",
)
parser.add_argument(
"--table_amp_level",
type=str,
default="O2",
choices=["O0", "O1", "O2", "O3"],
help="Auto Mixed Precision level. This setting only works on GPU and Ascend",
)
parser.add_argument(
"--table_char_dict_path",
type=str,
default="./mindocr/utils/dict/table_master_structure_dict.txt",
help="path to character dictionary for table structure recognition. "
"If None, will pick according to table_algorithm and table_model_dir.",
)
parser.add_argument(
"--table_max_len", type=int, default=480, help="max length of the input image for table structure recognition."
)

return parser

Expand Down
13 changes: 13 additions & 0 deletions tools/infer/text/postprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,16 @@ def __init__(self, task="det", algo="DB", rec_char_dict_path=None, **kwargs):
elif task == "ser":
class_path = "mindocr/utils/dict/class_list_xfun.txt"
postproc_cfg = dict(name="VQASerTokenLayoutLMPostProcess", class_path=class_path)
elif task == "table":
table_char_dict_path = kwargs.get(
"table_char_dict_path", "mindocr/utils/dict/table_master_structure_dict.txt"
)
postproc_cfg = dict(
name="TableMasterLabelDecode",
character_dict_path=table_char_dict_path,
merge_no_span_structure=True,
box_shape="pad",
)

postproc_cfg.update(kwargs)
self.task = task
Expand Down Expand Up @@ -142,3 +152,6 @@ def __call__(self, pred, data=None, **kwargs):
pred, segment_offset_ids=kwargs.get("segment_offset_ids"), ocr_infos=kwargs.get("ocr_infos")
)
return output
elif self.task == "table":
output = self.postprocess(pred, labels=kwargs.get("labels"))
return output
118 changes: 118 additions & 0 deletions tools/infer/text/predict_table_recognition.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
"""
Infer table from images with structure model and ocr model.
Example:
$ python tools/infer/text/predict_table_recognition.py --image_dir {path_to_img} --table_algorithm TABLE_MASTER
"""
import logging
import os
import sys
from typing import Union

import cv2
import numpy as np
from config import parse_args
from predict_system import TextSystem
from predict_table_structure import StructureAnalyzer
from utils import TableMasterMatcher, get_image_paths

from mindocr.utils.logger import set_logger

__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, os.path.abspath(os.path.join(__dir__, "../../../")))

logger = logging.getLogger("mindocr")


class TableAnalyzer:
"""
Model inference class for table structure analysis and match with ocr result.
Example:
>>> args = parse_args()
>>> analyzer = TableAnalyzer(args)
>>> img_path = "path/to/image.jpg"
>>> pred_html, time_prof = analyzer(img_path)
"""

def __init__(self, args):
self.text_system = TextSystem(args)
self.table_structure = StructureAnalyzer(args)
self.match = TableMasterMatcher()

def _structure(self, img_or_path: Union[str, np.ndarray], do_visualize: bool = True):
structure_res, elapse = self.table_structure(img_or_path, do_visualize)
return structure_res, elapse

def _text_ocr(self, img_or_path: Union[str, np.ndarray], do_visualize: bool = True):
boxes, text_scores, time_prof = self.text_system(img_or_path, do_visualize)
if isinstance(img_or_path, str):
img = cv2.imread(img_or_path)
elif isinstance(img_or_path, np.ndarray):
img = img_or_path
else:
raise ValueError("Invalid input type, should be str or np.ndarray.")
h, w = img.shape[:2]
r_boxes = []
for box in boxes:
x_min = max(0, box[:, 0].min() - 1)
x_max = min(w, box[:, 0].max() + 1)
y_min = max(0, box[:, 1].min() - 1)
y_max = min(h, box[:, 1].max() + 1)
box = [x_min, y_min, x_max, y_max]
r_boxes.append(box)
dt_boxes = np.array(r_boxes)
return dt_boxes, text_scores, time_prof

def __call__(self, img_or_path: Union[str, np.ndarray], do_visualize: bool = True):
boxes, text_scores, ocr_time_prof = self._text_ocr(img_or_path, do_visualize)
structure_res, struct_time_prof = self._structure(img_or_path, do_visualize)
pred_html = self.match(structure_res, boxes, text_scores)
time_prof = {
"ocr": ocr_time_prof,
"table": struct_time_prof,
}
return pred_html, time_prof


def parse_html_table(html_table):
from bs4 import BeautifulSoup

soup = BeautifulSoup(html_table, "html.parser")
table = soup.find("table")
if not table:
raise ValueError("No table found in the HTML string.")
return table


def to_excel(html_table, excel_path):
from tablepyxl import tablepyxl

table = parse_html_table(html_table)
tablepyxl.document_to_xl(str(table), excel_path)


def to_csv(html_table, csv_path):
import pandas as pd

table = parse_html_table(html_table)
df = pd.read_html(str(table))[0]
df.to_csv(csv_path, index=False)


def main():
args = parse_args()
set_logger(name="mindocr")
analyzer = TableAnalyzer(args)
img_paths = get_image_paths(args.image_dir)
save_dir = args.draw_img_save_dir
for i, img_path in enumerate(img_paths):
logger.info(f"Infering {i+1}/{len(img_paths)}: {img_path}")
pred_html, time_prof = analyzer(img_path, do_visualize=True)
logger.info(f"Time profile: {time_prof}")
img_name = os.path.basename(img_path).rsplit(".", 1)[0]
to_csv(pred_html, os.path.join(save_dir, f"{img_name}.csv"))
logger.info(f"Done! All structure results are saved to {args.draw_img_save_dir}")


if __name__ == "__main__":
main()
120 changes: 120 additions & 0 deletions tools/infer/text/predict_table_structure.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
"""
Infer table structure from images.
Example:
$ python tools/infer/text/predict_table_structure.py --image_dir {path_to_img} --table_algorithm TABLE_MASTER
"""
import logging
import os
import sys
import time
from typing import Dict, Union

import numpy as np
from config import parse_args
from postprocess import Postprocessor
from preprocess import Preprocessor
from utils import get_ckpt_file, get_image_paths

from mindspore import Tensor

from mindocr.models import build_model
from mindocr.utils.logger import set_logger
from mindocr.utils.visualize import draw_boxes, show_imgs

__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, os.path.abspath(os.path.join(__dir__, "../../../")))


algo_to_model_name = {
"TABLE_MASTER": "table_master",
}
logger = logging.getLogger("mindocr")


class StructureAnalyzer:
"""
Model inference class for table structure analysis.
Example:
>>> args = parse_args()
>>> analyzer = StructureAnalyzer(args)
>>> img_path = "path/to/image.jpg"
>>> (structure_str_list, bbox_list), elapsed_time = analyzer(img_path)
"""

def __init__(self, args):
ckpt_dir = args.table_model_dir
if ckpt_dir is None:
pretrained = True
ckpt_load_path = None
else:
ckpt_load_path = get_ckpt_file(ckpt_dir)
pretrained = False
if args.table_algorithm not in algo_to_model_name:
raise ValueError(
f"Invalid table algorithm {args.table_algorithm}. "
f"Supported table algorithms are {list(algo_to_model_name.keys())}"
)
model_name = algo_to_model_name[args.table_algorithm]

self.model = build_model(
model_name, pretrained=pretrained, ckpt_load_path=ckpt_load_path, amp_level=args.table_amp_level
)
self.model.set_train(False)
self.preprocess = Preprocessor(task="table", table_max_len=args.table_max_len)
self.postprocess = Postprocessor(task="table", table_char_dict_path=args.table_char_dict_path)
self.vis_dir = args.draw_img_save_dir
os.makedirs(self.vis_dir, exist_ok=True)

def __call__(
self,
img_or_path: Union[str, np.ndarray, Dict],
do_visualize: bool = True,
):
"""
Perform model inference.
Args:
img_or_path (Union[str, np.ndarray, Dict]): Input image or image path.
do_visualize (bool): Whether to visualize the result.
Returns:
Structure string list, bounding box list, and elapsed time.
"""
time_profile = {}
start_time = time.time()
data = self.preprocess(img_or_path)
input_np = data["image"]
if len(input_np.shape) == 3:
input_np = Tensor(np.expand_dims(input_np, axis=0))

net_pred = self.model(input_np)
shape_list = np.expand_dims(data["shape"], axis=0)
post_result = self.postprocess(net_pred, labels=[shape_list])
structure_str_list = post_result["structure_batch_list"][0][0]
structure_str_list = ["<html>", "<body>", "<table>"] + structure_str_list + ["</table>", "</body>", "</html>"]
bbox_list = post_result["bbox_batch_list"][0]
elapse = time.time() - start_time
time_profile["structure"] = elapse

if do_visualize:
vst = time.time()
img_name = os.path.basename(data.get("img_path", "input.png")).rsplit(".", 1)[0]
save_path = os.path.join(self.vis_dir, img_name + "_structure.png")
structure_vis = draw_boxes(img_or_path, bbox_list, draw_type="rectangle")
show_imgs([structure_vis], show=False, save_path=save_path)
time_profile["vis"] = time.time() - vst
return (structure_str_list, bbox_list), time_profile


def main():
args = parse_args()
set_logger(name="mindocr")
analyzer = StructureAnalyzer(args)
img_paths = get_image_paths(args.image_dir)
for i, img_path in enumerate(img_paths):
logger.info(f"Inferring {i+1}/{len(img_paths)}: {img_path}")
_ = analyzer(img_path, do_visualize=True)
logger.info(f"Done! All structure results are saved to {args.draw_img_save_dir}")


if __name__ == "__main__":
main()
Loading

0 comments on commit 9d2f8f8

Please sign in to comment.