Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

support table structure model inference #764

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file added configs/table/example.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added configs/table/example_structure.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
1 change: 1 addition & 0 deletions mindocr/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from .rec_robustscanner import *
from .rec_svtr import *
from .rec_visionlan import *
from .table_master import *

__all__ = []
__all__.extend(builder.__all__)
Expand Down
57 changes: 57 additions & 0 deletions mindocr/models/table_master.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
from ._registry import register_model
from .backbones.mindcv_models.utils import load_pretrained
from .base_model import BaseModel


def _cfg(url="", **kwargs):
return {"url": url, **kwargs}


default_cfgs = {
"table_master": _cfg(
url="https://download-mindspore.osinfra.cn/toolkits/mindocr/tablemaster/table_master-78bf35bb.ckpt"
),
}


class TableMaster(BaseModel):
def __init__(self, config):
BaseModel.__init__(self, config)


@register_model
def table_master(
pretrained: bool = True,
**kwargs
):
model_config = {
"type": "table",
"transform": None,
"backbone": {
"name": "table_resnet_extra",
"gcb_config": {
"ratio": 0.0625,
"headers": 1,
"att_scale": False,
"fusion_type": "channel_add",
"layers": [False, True, True, True],
},
"layers": [1, 2, 5, 3],
},
"head": {
"name": "TableMasterHead",
"out_channels": 43,
"hidden_size": 512,
"headers": 8,
"dropout": 0.0,
"d_ff": 2024,
"max_text_length": 500,
"loc_reg_num": 4
},
}
model = TableMaster(model_config)
if pretrained:
default_cfg = default_cfgs["table_master"]
load_pretrained(model, default_cfg)

return model
2 changes: 2 additions & 0 deletions mindocr/utils/visualize.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ def show_img(img: np.array, is_bgr_img=True, title="img", show=True, save_path=N
if show:
plt.show()
if save_path is not None:
os.makedirs(os.path.dirname(save_path), exist_ok=True)
plt.savefig(save_path) # , bbox_inches='tight', dpi=460)


Expand Down Expand Up @@ -73,6 +74,7 @@ def show_imgs(
plt.show()

if save_path is not None:
os.makedirs(os.path.dirname(save_path), exist_ok=True)
plt.savefig(save_path, bbox_inches="tight", dpi=300, pad_inches=0)


Expand Down
4 changes: 4 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -21,3 +21,7 @@ requests>=2.31.0
pycocotools>=2.0.2
setuptools-scm
albumentations
beautifulsoup4
pandas
tablepyxl
lxml
30 changes: 30 additions & 0 deletions tools/infer/text/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,36 @@ def create_parser():
"due to padding or resizing to the same shape.",
)
parser.add_argument("--kie_batch_num", type=int, default=8)
parser.add_argument(
"--table_algorithm",
type=str,
default="TABLE_MASTER",
choices=["TABLE_MASTER"],
help="table structure recognition algorithm",
)
parser.add_argument(
"--table_model_dir",
type=str,
help="directory containing the table structure recognition model checkpoint best.ckpt, "
"or path to a specific checkpoint file.",
)
parser.add_argument(
"--table_amp_level",
type=str,
default="O2",
choices=["O0", "O1", "O2", "O3"],
help="Auto Mixed Precision level. This setting only works on GPU and Ascend",
)
parser.add_argument(
"--table_char_dict_path",
type=str,
default="./mindocr/utils/dict/table_master_structure_dict.txt",
help="path to character dictionary for table structure recognition. "
"If None, will pick according to table_algorithm and table_model_dir.",
)
parser.add_argument(
"--table_max_len", type=int, default=480, help="max length of the input image for table structure recognition."
)

return parser

Expand Down
13 changes: 13 additions & 0 deletions tools/infer/text/postprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,16 @@ def __init__(self, task="det", algo="DB", rec_char_dict_path=None, **kwargs):
elif task == "ser":
class_path = "mindocr/utils/dict/class_list_xfun.txt"
postproc_cfg = dict(name="VQASerTokenLayoutLMPostProcess", class_path=class_path)
elif task == "table":
table_char_dict_path = kwargs.get(
"table_char_dict_path", "mindocr/utils/dict/table_master_structure_dict.txt"
)
postproc_cfg = dict(
name="TableMasterLabelDecode",
character_dict_path=table_char_dict_path,
merge_no_span_structure=True,
box_shape="pad",
)

postproc_cfg.update(kwargs)
self.task = task
Expand Down Expand Up @@ -142,3 +152,6 @@ def __call__(self, pred, data=None, **kwargs):
pred, segment_offset_ids=kwargs.get("segment_offset_ids"), ocr_infos=kwargs.get("ocr_infos")
)
return output
elif self.task == "table":
output = self.postprocess(pred, labels=kwargs.get("labels"))
return output
118 changes: 118 additions & 0 deletions tools/infer/text/predict_table_recognition.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
"""
Infer table from images with structure model and ocr model.

Example:
$ python tools/infer/text/predict_table_recognition.py --image_dir {path_to_img} --table_algorithm TABLE_MASTER
"""
import logging
import os
import sys
from typing import Union

import cv2
import numpy as np
from config import parse_args
from predict_system import TextSystem
from predict_table_structure import StructureAnalyzer
from utils import TableMasterMatcher, get_image_paths

from mindocr.utils.logger import set_logger

__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, os.path.abspath(os.path.join(__dir__, "../../../")))

logger = logging.getLogger("mindocr")


class TableAnalyzer:
"""
Model inference class for table structure analysis and match with ocr result.
Example:
>>> args = parse_args()
>>> analyzer = TableAnalyzer(args)
>>> img_path = "path/to/image.jpg"
>>> pred_html, time_prof = analyzer(img_path)
"""

def __init__(self, args):
self.text_system = TextSystem(args)
self.table_structure = StructureAnalyzer(args)
self.match = TableMasterMatcher()

def _structure(self, img_or_path: Union[str, np.ndarray], do_visualize: bool = True):
structure_res, elapse = self.table_structure(img_or_path, do_visualize)
return structure_res, elapse

def _text_ocr(self, img_or_path: Union[str, np.ndarray], do_visualize: bool = True):
boxes, text_scores, time_prof = self.text_system(img_or_path, do_visualize)
if isinstance(img_or_path, str):
img = cv2.imread(img_or_path)
elif isinstance(img_or_path, np.ndarray):
img = img_or_path
else:
raise ValueError("Invalid input type, should be str or np.ndarray.")
h, w = img.shape[:2]
r_boxes = []
for box in boxes:
x_min = max(0, box[:, 0].min() - 1)
x_max = min(w, box[:, 0].max() + 1)
y_min = max(0, box[:, 1].min() - 1)
y_max = min(h, box[:, 1].max() + 1)
box = [x_min, y_min, x_max, y_max]
r_boxes.append(box)
dt_boxes = np.array(r_boxes)
return dt_boxes, text_scores, time_prof

def __call__(self, img_or_path: Union[str, np.ndarray], do_visualize: bool = True):
boxes, text_scores, ocr_time_prof = self._text_ocr(img_or_path, do_visualize)
structure_res, struct_time_prof = self._structure(img_or_path, do_visualize)
pred_html = self.match(structure_res, boxes, text_scores)
time_prof = {
"ocr": ocr_time_prof,
"table": struct_time_prof,
}
return pred_html, time_prof


def parse_html_table(html_table):
from bs4 import BeautifulSoup

soup = BeautifulSoup(html_table, "html.parser")
table = soup.find("table")
if not table:
raise ValueError("No table found in the HTML string.")
return table


def to_excel(html_table, excel_path):
from tablepyxl import tablepyxl

table = parse_html_table(html_table)
tablepyxl.document_to_xl(str(table), excel_path)


def to_csv(html_table, csv_path):
import pandas as pd

table = parse_html_table(html_table)
df = pd.read_html(str(table))[0]
df.to_csv(csv_path, index=False)


def main():
args = parse_args()
set_logger(name="mindocr")
analyzer = TableAnalyzer(args)
img_paths = get_image_paths(args.image_dir)
save_dir = args.draw_img_save_dir
for i, img_path in enumerate(img_paths):
logger.info(f"Infering {i+1}/{len(img_paths)}: {img_path}")
pred_html, time_prof = analyzer(img_path, do_visualize=True)
logger.info(f"Time profile: {time_prof}")
img_name = os.path.basename(img_path).rsplit(".", 1)[0]
to_csv(pred_html, os.path.join(save_dir, f"{img_name}.csv"))
logger.info(f"Done! All structure results are saved to {args.draw_img_save_dir}")


if __name__ == "__main__":
main()
120 changes: 120 additions & 0 deletions tools/infer/text/predict_table_structure.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
"""
Infer table structure from images.

Example:
$ python tools/infer/text/predict_table_structure.py --image_dir {path_to_img} --table_algorithm TABLE_MASTER
"""
import logging
import os
import sys
import time
from typing import Dict, Union

import numpy as np
from config import parse_args
from postprocess import Postprocessor
from preprocess import Preprocessor
from utils import get_ckpt_file, get_image_paths

from mindspore import Tensor

from mindocr.models import build_model
from mindocr.utils.logger import set_logger
from mindocr.utils.visualize import draw_boxes, show_imgs

__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, os.path.abspath(os.path.join(__dir__, "../../../")))


algo_to_model_name = {
"TABLE_MASTER": "table_master",
}
logger = logging.getLogger("mindocr")


class StructureAnalyzer:
"""
Model inference class for table structure analysis.
Example:
>>> args = parse_args()
>>> analyzer = StructureAnalyzer(args)
>>> img_path = "path/to/image.jpg"
>>> (structure_str_list, bbox_list), elapsed_time = analyzer(img_path)
"""

def __init__(self, args):
ckpt_dir = args.table_model_dir
if ckpt_dir is None:
pretrained = True
ckpt_load_path = None
else:
ckpt_load_path = get_ckpt_file(ckpt_dir)
pretrained = False
if args.table_algorithm not in algo_to_model_name:
raise ValueError(
f"Invalid table algorithm {args.table_algorithm}. "
f"Supported table algorithms are {list(algo_to_model_name.keys())}"
)
model_name = algo_to_model_name[args.table_algorithm]

self.model = build_model(
model_name, pretrained=pretrained, ckpt_load_path=ckpt_load_path, amp_level=args.table_amp_level
)
self.model.set_train(False)
self.preprocess = Preprocessor(task="table", table_max_len=args.table_max_len)
self.postprocess = Postprocessor(task="table", table_char_dict_path=args.table_char_dict_path)
self.vis_dir = args.draw_img_save_dir
os.makedirs(self.vis_dir, exist_ok=True)

def __call__(
self,
img_or_path: Union[str, np.ndarray, Dict],
do_visualize: bool = True,
):
"""
Perform model inference.
Args:
img_or_path (Union[str, np.ndarray, Dict]): Input image or image path.
do_visualize (bool): Whether to visualize the result.
Returns:
Structure string list, bounding box list, and elapsed time.
"""
time_profile = {}
start_time = time.time()
data = self.preprocess(img_or_path)
input_np = data["image"]
if len(input_np.shape) == 3:
input_np = Tensor(np.expand_dims(input_np, axis=0))

net_pred = self.model(input_np)
shape_list = np.expand_dims(data["shape"], axis=0)
post_result = self.postprocess(net_pred, labels=[shape_list])
structure_str_list = post_result["structure_batch_list"][0][0]
structure_str_list = ["<html>", "<body>", "<table>"] + structure_str_list + ["</table>", "</body>", "</html>"]
bbox_list = post_result["bbox_batch_list"][0]
elapse = time.time() - start_time
time_profile["structure"] = elapse

if do_visualize:
vst = time.time()
img_name = os.path.basename(data.get("img_path", "input.png")).rsplit(".", 1)[0]
save_path = os.path.join(self.vis_dir, img_name + "_structure.png")
structure_vis = draw_boxes(img_or_path, bbox_list, draw_type="rectangle")
show_imgs([structure_vis], show=False, save_path=save_path)
time_profile["vis"] = time.time() - vst
return (structure_str_list, bbox_list), time_profile


def main():
args = parse_args()
set_logger(name="mindocr")
analyzer = StructureAnalyzer(args)
img_paths = get_image_paths(args.image_dir)
for i, img_path in enumerate(img_paths):
logger.info(f"Inferring {i+1}/{len(img_paths)}: {img_path}")
_ = analyzer(img_path, do_visualize=True)
logger.info(f"Done! All structure results are saved to {args.draw_img_save_dir}")


if __name__ == "__main__":
main()
Loading
Loading