-
Notifications
You must be signed in to change notification settings - Fork 57
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
support table structure model inference, and table inference concatenated with OCR model. (#764)
- Loading branch information
Showing
14 changed files
with
977 additions
and
1 deletion.
There are no files selected for viewing
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
from ._registry import register_model | ||
from .backbones.mindcv_models.utils import load_pretrained | ||
from .base_model import BaseModel | ||
|
||
|
||
def _cfg(url="", **kwargs): | ||
return {"url": url, **kwargs} | ||
|
||
|
||
# Checkpoint config for each model builder in this module, keyed by model name.
default_cfgs = dict(
    table_master=_cfg(
        url="https://download-mindspore.osinfra.cn/toolkits/mindocr/tablemaster/table_master-78bf35bb.ckpt"
    ),
)
|
||
|
||
class TableMaster(BaseModel):
    """TableMaster model; construction is delegated entirely to ``BaseModel``."""

    def __init__(self, config):
        super().__init__(config)
|
||
|
||
@register_model
def table_master(pretrained: bool = True, **kwargs):
    """Build a ``TableMaster`` model from its fixed architecture config.

    Args:
        pretrained: When true, load the weights referenced by
            ``default_cfgs["table_master"]`` into the model.
        **kwargs: Accepted for call compatibility; not used by this builder.

    Returns:
        The constructed ``TableMaster`` instance.
    """
    backbone_cfg = {
        "name": "table_resnet_extra",
        "gcb_config": {
            "ratio": 0.0625,
            "headers": 1,
            "att_scale": False,
            "fusion_type": "channel_add",
            "layers": [False, True, True, True],
        },
        "layers": [1, 2, 5, 3],
    }
    head_cfg = {
        "name": "TableMasterHead",
        "out_channels": 43,
        "hidden_size": 512,
        "headers": 8,
        "dropout": 0.0,
        "d_ff": 2024,
        "max_text_length": 500,
        "loc_reg_num": 4,
    }
    model_config = {
        "type": "table",
        "transform": None,
        "backbone": backbone_cfg,
        "head": head_cfg,
    }

    model = TableMaster(model_config)
    if pretrained:
        load_pretrained(model, default_cfgs["table_master"])
    return model
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -21,3 +21,7 @@ requests>=2.31.0 | |
pycocotools>=2.0.2 | ||
setuptools-scm | ||
albumentations | ||
beautifulsoup4 | ||
pandas | ||
tablepyxl | ||
lxml |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,118 @@ | ||
""" | ||
Infer table from images with structure model and ocr model. | ||
Example: | ||
$ python tools/infer/text/predict_table_recognition.py --image_dir {path_to_img} --table_algorithm TABLE_MASTER | ||
""" | ||
import logging | ||
import os | ||
import sys | ||
from typing import Union | ||
|
||
import cv2 | ||
import numpy as np | ||
from config import parse_args | ||
from predict_system import TextSystem | ||
from predict_table_structure import StructureAnalyzer | ||
from utils import TableMasterMatcher, get_image_paths | ||
|
||
from mindocr.utils.logger import set_logger | ||
|
||
# Absolute directory of this script. Kept under a private name on purpose: a
# module-level ``__dir__`` is the PEP 562 hook that ``dir(module)`` calls, so
# binding that name to a string makes ``dir()`` on this module raise TypeError.
_SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
# Put the directory three levels above this script at the front of the import path.
# NOTE(review): this runs after the ``mindocr`` import above, so it cannot help
# resolve that import -- confirm whether the insertion should precede it.
sys.path.insert(0, os.path.abspath(os.path.join(_SCRIPT_DIR, "../../../")))

logger = logging.getLogger("mindocr")
|
||
|
||
class TableAnalyzer:
    """
    Model inference class for table structure analysis and match with ocr result.

    Runs the text system (OCR) and the table structure model on the same input
    and hands both results to ``TableMasterMatcher``.

    Example:
        >>> args = parse_args()
        >>> analyzer = TableAnalyzer(args)
        >>> img_path = "path/to/image.jpg"
        >>> pred_html, time_prof = analyzer(img_path)
    """

    def __init__(self, args):
        self.text_system = TextSystem(args)
        self.table_structure = StructureAnalyzer(args)
        self.match = TableMasterMatcher()

    def _structure(self, img_or_path: Union[str, np.ndarray], do_visualize: bool = True):
        """Run the table structure model; returns ``(structure_res, elapse)`` unchanged."""
        structure_res, elapse = self.table_structure(img_or_path, do_visualize)
        return structure_res, elapse

    @staticmethod
    def _load_image(img_or_path: Union[str, np.ndarray]) -> np.ndarray:
        """Return the input as an image array, reading it from disk when given a path.

        Raises:
            ValueError: If the input is neither ``str`` nor ``np.ndarray``, or if
                the path cannot be read as an image.
        """
        if isinstance(img_or_path, np.ndarray):
            return img_or_path
        if not isinstance(img_or_path, str):
            raise ValueError("Invalid input type, should be str or np.ndarray.")
        img = cv2.imread(img_or_path)
        # cv2.imread signals a missing or undecodable file by returning None
        # instead of raising; surface that here rather than as an AttributeError.
        if img is None:
            raise ValueError(f"Failed to read image from {img_or_path!r}.")
        return img

    def _text_ocr(self, img_or_path: Union[str, np.ndarray], do_visualize: bool = True):
        """Run OCR and convert each detected box to a clipped axis-aligned rectangle.

        Each box returned by the text system is indexed as ``box[:, 0]`` (x) and
        ``box[:, 1]`` (y); its extent is padded by one pixel and clamped to the
        image bounds.

        Returns:
            ``(dt_boxes, text_scores, time_prof)`` where ``dt_boxes`` is an array
            of ``[x_min, y_min, x_max, y_max]`` rows with shape ``(N, 4)``.
        """
        boxes, text_scores, time_prof = self.text_system(img_or_path, do_visualize)
        h, w = self._load_image(img_or_path).shape[:2]
        r_boxes = [
            [
                max(0, box[:, 0].min() - 1),
                max(0, box[:, 1].min() - 1),
                min(w, box[:, 0].max() + 1),
                min(h, box[:, 1].max() + 1),
            ]
            for box in boxes
        ]
        # reshape keeps the (N, 4) layout even when no text box was detected.
        dt_boxes = np.array(r_boxes).reshape(-1, 4)
        return dt_boxes, text_scores, time_prof

    def __call__(self, img_or_path: Union[str, np.ndarray], do_visualize: bool = True):
        """Run OCR and structure analysis, then match them.

        Returns:
            ``(pred_html, time_prof)``: the matcher's output and a dict holding
            the OCR time profile under ``"ocr"`` and the structure time profile
            under ``"table"``.
        """
        boxes, text_scores, ocr_time_prof = self._text_ocr(img_or_path, do_visualize)
        structure_res, struct_time_prof = self._structure(img_or_path, do_visualize)
        pred_html = self.match(structure_res, boxes, text_scores)
        time_prof = {
            "ocr": ocr_time_prof,
            "table": struct_time_prof,
        }
        return pred_html, time_prof
|
||
|
||
def parse_html_table(html_table):
    """Return the first ``<table>`` element found in ``html_table``.

    Raises:
        ValueError: If the HTML string contains no table.
    """
    from bs4 import BeautifulSoup

    table_tag = BeautifulSoup(html_table, "html.parser").find("table")
    if table_tag:
        return table_tag
    raise ValueError("No table found in the HTML string.")
|
||
|
||
def to_excel(html_table, excel_path):
    """Write the first table of ``html_table`` to ``excel_path`` using tablepyxl."""
    from tablepyxl import tablepyxl

    table_markup = str(parse_html_table(html_table))
    tablepyxl.document_to_xl(table_markup, excel_path)
|
||
|
||
def to_csv(html_table, csv_path):
    """Write the first table of ``html_table`` to ``csv_path`` without an index column.

    Raises:
        ValueError: If ``html_table`` contains no table (from ``parse_html_table``).
    """
    from io import StringIO

    import pandas as pd

    table = parse_html_table(html_table)
    # pandas 2.1 deprecated passing literal HTML to read_html; a file-like
    # wrapper is accepted by every pandas version and avoids the FutureWarning.
    frames = pd.read_html(StringIO(str(table)))
    frames[0].to_csv(csv_path, index=False)
|
||
|
||
def main():
    """Run table recognition on every image under ``args.image_dir``.

    For each image, the predicted table is written as ``<image name>.csv`` into
    ``args.draw_img_save_dir``.
    """
    args = parse_args()
    set_logger(name="mindocr")
    analyzer = TableAnalyzer(args)
    img_paths = get_image_paths(args.image_dir)
    save_dir = args.draw_img_save_dir
    total = len(img_paths)
    for i, img_path in enumerate(img_paths, start=1):
        logger.info(f"Inferring {i}/{total}: {img_path}")
        pred_html, time_prof = analyzer(img_path, do_visualize=True)
        logger.info(f"Time profile: {time_prof}")
        # Strip only the last extension so dotted file names keep their stem.
        img_name = os.path.basename(img_path).rsplit(".", 1)[0]
        to_csv(pred_html, os.path.join(save_dir, f"{img_name}.csv"))
    logger.info(f"Done! All structure results are saved to {save_dir}")


if __name__ == "__main__":
    main()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,120 @@ | ||
""" | ||
Infer table structure from images. | ||
Example: | ||
$ python tools/infer/text/predict_table_structure.py --image_dir {path_to_img} --table_algorithm TABLE_MASTER | ||
""" | ||
import logging | ||
import os | ||
import sys | ||
import time | ||
from typing import Dict, Union | ||
|
||
import numpy as np | ||
from config import parse_args | ||
from postprocess import Postprocessor | ||
from preprocess import Preprocessor | ||
from utils import get_ckpt_file, get_image_paths | ||
|
||
from mindspore import Tensor | ||
|
||
from mindocr.models import build_model | ||
from mindocr.utils.logger import set_logger | ||
from mindocr.utils.visualize import draw_boxes, show_imgs | ||
|
||
# Absolute directory of this script. Kept under a private name on purpose: a
# module-level ``__dir__`` is the PEP 562 hook that ``dir(module)`` calls, so
# binding that name to a string makes ``dir()`` on this module raise TypeError.
_SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
# Put the directory three levels above this script at the front of the import path.
# NOTE(review): this runs after the ``mindocr`` imports above, so it cannot help
# resolve them -- confirm whether the insertion should precede those imports.
sys.path.insert(0, os.path.abspath(os.path.join(_SCRIPT_DIR, "../../../")))


# Maps the ``table_algorithm`` argument value to the model name given to ``build_model``.
algo_to_model_name = {
    "TABLE_MASTER": "table_master",
}
logger = logging.getLogger("mindocr")
|
||
|
||
class StructureAnalyzer:
    """
    Model inference class for table structure analysis.

    Example:
        >>> args = parse_args()
        >>> analyzer = StructureAnalyzer(args)
        >>> img_path = "path/to/image.jpg"
        >>> (structure_str_list, bbox_list), time_profile = analyzer(img_path)
    """

    def __init__(self, args):
        """Build the model and its pre/post-processors from parsed arguments.

        Reads ``table_algorithm``, ``table_model_dir``, ``table_amp_level``,
        ``table_max_len``, ``table_char_dict_path`` and ``draw_img_save_dir``
        from ``args``, and creates the visualization directory.

        Raises:
            ValueError: If ``args.table_algorithm`` is not a supported algorithm.
        """
        # Validate the cheap argument first so a typo fails before any checkpoint lookup.
        if args.table_algorithm not in algo_to_model_name:
            raise ValueError(
                f"Invalid table algorithm {args.table_algorithm}. "
                f"Supported table algorithms are {list(algo_to_model_name.keys())}"
            )
        model_name = algo_to_model_name[args.table_algorithm]

        # Without a local checkpoint directory, fall back to the pretrained weights.
        ckpt_dir = args.table_model_dir
        if ckpt_dir is None:
            pretrained = True
            ckpt_load_path = None
        else:
            pretrained = False
            ckpt_load_path = get_ckpt_file(ckpt_dir)

        self.model = build_model(
            model_name, pretrained=pretrained, ckpt_load_path=ckpt_load_path, amp_level=args.table_amp_level
        )
        self.model.set_train(False)
        self.preprocess = Preprocessor(task="table", table_max_len=args.table_max_len)
        self.postprocess = Postprocessor(task="table", table_char_dict_path=args.table_char_dict_path)
        self.vis_dir = args.draw_img_save_dir
        os.makedirs(self.vis_dir, exist_ok=True)

    def __call__(
        self,
        img_or_path: Union[str, np.ndarray, Dict],
        do_visualize: bool = True,
    ):
        """
        Perform model inference.

        Args:
            img_or_path (Union[str, np.ndarray, Dict]): Input image or image path.
            do_visualize (bool): Whether to save a visualization of the predicted boxes.

        Returns:
            ``((structure_str_list, bbox_list), time_profile)``. The structure list
            is wrapped in ``<html><body><table>`` tokens. ``time_profile`` is a dict
            with the seconds spent on preprocessing, inference and postprocessing
            under ``"structure"`` and, when visualizing, the drawing time under
            ``"vis"``.
        """
        time_profile = {}
        start_time = time.time()
        data = self.preprocess(img_or_path)
        input_np = data["image"]
        # Add the batch axis for a single image.
        if len(input_np.shape) == 3:
            input_np = np.expand_dims(input_np, axis=0)

        # Always wrap in a Tensor: previously an already-batched array skipped
        # the conversion and reached the model as a raw ndarray.
        net_pred = self.model(Tensor(input_np))
        shape_list = np.expand_dims(data["shape"], axis=0)
        post_result = self.postprocess(net_pred, labels=[shape_list])
        structure_str_list = post_result["structure_batch_list"][0][0]
        structure_str_list = ["<html>", "<body>", "<table>"] + structure_str_list + ["</table>", "</body>", "</html>"]
        bbox_list = post_result["bbox_batch_list"][0]
        elapse = time.time() - start_time
        time_profile["structure"] = elapse

        if do_visualize:
            vst = time.time()
            # Inputs without a recorded path are saved under the fallback name "input".
            img_name = os.path.basename(data.get("img_path", "input.png")).rsplit(".", 1)[0]
            save_path = os.path.join(self.vis_dir, img_name + "_structure.png")
            structure_vis = draw_boxes(img_or_path, bbox_list, draw_type="rectangle")
            show_imgs([structure_vis], show=False, save_path=save_path)
            time_profile["vis"] = time.time() - vst
        return (structure_str_list, bbox_list), time_profile
|
||
|
||
def main():
    """Run table structure inference, with visualization, on every image under ``args.image_dir``."""
    args = parse_args()
    set_logger(name="mindocr")
    analyzer = StructureAnalyzer(args)
    img_paths = get_image_paths(args.image_dir)
    for idx, img_path in enumerate(img_paths, start=1):
        logger.info(f"Inferring {idx}/{len(img_paths)}: {img_path}")
        # Only the saved visualization is needed here; the return value is discarded.
        analyzer(img_path, do_visualize=True)
    logger.info(f"Done! All structure results are saved to {args.draw_img_save_dir}")


if __name__ == "__main__":
    main()
Oops, something went wrong.