Skip to content

Commit

Permalink
fix dbnet eval on ctw1500
Browse files Browse the repository at this point in the history
  • Loading branch information
alien-0119 committed Nov 15, 2024
1 parent dda7b75 commit 6413f29
Show file tree
Hide file tree
Showing 15 changed files with 995 additions and 2 deletions.
Binary file added configs/table/example.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added configs/table/example_structure.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
1 change: 1 addition & 0 deletions mindocr/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from .rec_robustscanner import *
from .rec_svtr import *
from .rec_visionlan import *
from .table_master import *

__all__ = []
__all__.extend(builder.__all__)
Expand Down
57 changes: 57 additions & 0 deletions mindocr/models/table_master.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
from ._registry import register_model
from .backbones.mindcv_models.utils import load_pretrained
from .base_model import BaseModel


def _cfg(url="", **kwargs):
return {"url": url, **kwargs}


default_cfgs = {
"table_master": _cfg(
url="https://download-mindspore.osinfra.cn/toolkits/mindocr/tablemaster/table_master-78bf35bb.ckpt"
),
}


class TableMaster(BaseModel):
def __init__(self, config):
BaseModel.__init__(self, config)


@register_model
def table_master(
pretrained: bool = True,
**kwargs
):
model_config = {
"type": "table",
"transform": None,
"backbone": {
"name": "table_resnet_extra",
"gcb_config": {
"ratio": 0.0625,
"headers": 1,
"att_scale": False,
"fusion_type": "channel_add",
"layers": [False, True, True, True],
},
"layers": [1, 2, 5, 3],
},
"head": {
"name": "TableMasterHead",
"out_channels": 43,
"hidden_size": 512,
"headers": 8,
"dropout": 0.0,
"d_ff": 2024,
"max_text_length": 500,
"loc_reg_num": 4
},
}
model = TableMaster(model_config)
if pretrained:
default_cfg = default_cfgs["table_master"]
load_pretrained(model, default_cfg)

return model
19 changes: 18 additions & 1 deletion mindocr/postprocess/det_db_postprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,12 @@ def _extract_preds(self, pred: np.ndarray, bitmap: np.ndarray):
continue

poly = Polygon(points)
poly = np.array(expand_poly(points, distance=poly.area * self._expand_ratio / poly.length))
poly_list = expand_poly(points, distance=poly.area * self._expand_ratio / poly.length)
# fix the problem that np.array cannot handle the list with two sublists of different length from numpy 1.22
if self._is_uneven_nested_list(poly_list):
poly = np.array(poly_list, dtype=object)
else:
poly = np.array(poly_list)
if self._out_poly and len(poly) > 1:
continue
poly = poly.reshape(-1, 2)
Expand All @@ -134,6 +139,18 @@ def _extract_preds(self, pred: np.ndarray, bitmap: np.ndarray):
return polys, scores
return np.array(polys), np.array(scores).astype(np.float32)

def _is_uneven_nested_list(self, arr_list):
if not isinstance(arr_list, list):
return False

first_length = len(arr_list[0]) if isinstance(arr_list[0], list) else None

for sublist in arr_list:
if not isinstance(sublist, list) or len(sublist) != first_length:
return True

return False

@staticmethod
def _fit_box(contour):
"""
Expand Down
2 changes: 2 additions & 0 deletions mindocr/utils/visualize.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ def show_img(img: np.array, is_bgr_img=True, title="img", show=True, save_path=N
if show:
plt.show()
if save_path is not None:
os.makedirs(os.path.dirname(save_path), exist_ok=True)
plt.savefig(save_path) # , bbox_inches='tight', dpi=460)


Expand Down Expand Up @@ -73,6 +74,7 @@ def show_imgs(
plt.show()

if save_path is not None:
os.makedirs(os.path.dirname(save_path), exist_ok=True)
plt.savefig(save_path, bbox_inches="tight", dpi=300, pad_inches=0)


Expand Down
4 changes: 4 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -21,3 +21,7 @@ requests>=2.31.0
pycocotools>=2.0.2
setuptools-scm
albumentations
beautifulsoup4
pandas
tablepyxl
lxml
30 changes: 30 additions & 0 deletions tools/infer/text/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,36 @@ def create_parser():
"due to padding or resizing to the same shape.",
)
parser.add_argument("--kie_batch_num", type=int, default=8)
parser.add_argument(
"--table_algorithm",
type=str,
default="TABLE_MASTER",
choices=["TABLE_MASTER"],
help="table structure recognition algorithm",
)
parser.add_argument(
"--table_model_dir",
type=str,
help="directory containing the table structure recognition model checkpoint best.ckpt, "
"or path to a specific checkpoint file.",
)
parser.add_argument(
"--table_amp_level",
type=str,
default="O2",
choices=["O0", "O1", "O2", "O3"],
help="Auto Mixed Precision level. This setting only works on GPU and Ascend",
)
parser.add_argument(
"--table_char_dict_path",
type=str,
default="./mindocr/utils/dict/table_master_structure_dict.txt",
help="path to character dictionary for table structure recognition. "
"If None, will pick according to table_algorithm and table_model_dir.",
)
parser.add_argument(
"--table_max_len", type=int, default=480, help="max length of the input image for table structure recognition."
)

return parser

Expand Down
13 changes: 13 additions & 0 deletions tools/infer/text/postprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,16 @@ def __init__(self, task="det", algo="DB", rec_char_dict_path=None, **kwargs):
elif task == "ser":
class_path = "mindocr/utils/dict/class_list_xfun.txt"
postproc_cfg = dict(name="VQASerTokenLayoutLMPostProcess", class_path=class_path)
elif task == "table":
table_char_dict_path = kwargs.get(
"table_char_dict_path", "mindocr/utils/dict/table_master_structure_dict.txt"
)
postproc_cfg = dict(
name="TableMasterLabelDecode",
character_dict_path=table_char_dict_path,
merge_no_span_structure=True,
box_shape="pad",
)

postproc_cfg.update(kwargs)
self.task = task
Expand Down Expand Up @@ -142,3 +152,6 @@ def __call__(self, pred, data=None, **kwargs):
pred, segment_offset_ids=kwargs.get("segment_offset_ids"), ocr_infos=kwargs.get("ocr_infos")
)
return output
elif self.task == "table":
output = self.postprocess(pred, labels=kwargs.get("labels"))
return output
118 changes: 118 additions & 0 deletions tools/infer/text/predict_table_recognition.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
"""
Infer table from images with structure model and ocr model.
Example:
$ python tools/infer/text/predict_table_recognition.py --image_dir {path_to_img} --table_algorithm TABLE_MASTER
"""
import logging
import os
import sys
from typing import Union

import cv2
import numpy as np
from config import parse_args
from predict_system import TextSystem
from predict_table_structure import StructureAnalyzer
from utils import TableMasterMatcher, get_image_paths

from mindocr.utils.logger import set_logger

__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, os.path.abspath(os.path.join(__dir__, "../../../")))

logger = logging.getLogger("mindocr")


class TableAnalyzer:
"""
Model inference class for table structure analysis and match with ocr result.
Example:
>>> args = parse_args()
>>> analyzer = TableAnalyzer(args)
>>> img_path = "path/to/image.jpg"
>>> pred_html, time_prof = analyzer(img_path)
"""

def __init__(self, args):
self.text_system = TextSystem(args)
self.table_structure = StructureAnalyzer(args)
self.match = TableMasterMatcher()

def _structure(self, img_or_path: Union[str, np.ndarray], do_visualize: bool = True):
structure_res, elapse = self.table_structure(img_or_path, do_visualize)
return structure_res, elapse

def _text_ocr(self, img_or_path: Union[str, np.ndarray], do_visualize: bool = True):
boxes, text_scores, time_prof = self.text_system(img_or_path, do_visualize)
if isinstance(img_or_path, str):
img = cv2.imread(img_or_path)
elif isinstance(img_or_path, np.ndarray):
img = img_or_path
else:
raise ValueError("Invalid input type, should be str or np.ndarray.")
h, w = img.shape[:2]
r_boxes = []
for box in boxes:
x_min = max(0, box[:, 0].min() - 1)
x_max = min(w, box[:, 0].max() + 1)
y_min = max(0, box[:, 1].min() - 1)
y_max = min(h, box[:, 1].max() + 1)
box = [x_min, y_min, x_max, y_max]
r_boxes.append(box)
dt_boxes = np.array(r_boxes)
return dt_boxes, text_scores, time_prof

def __call__(self, img_or_path: Union[str, np.ndarray], do_visualize: bool = True):
boxes, text_scores, ocr_time_prof = self._text_ocr(img_or_path, do_visualize)
structure_res, struct_time_prof = self._structure(img_or_path, do_visualize)
pred_html = self.match(structure_res, boxes, text_scores)
time_prof = {
"ocr": ocr_time_prof,
"table": struct_time_prof,
}
return pred_html, time_prof


def parse_html_table(html_table):
from bs4 import BeautifulSoup

soup = BeautifulSoup(html_table, "html.parser")
table = soup.find("table")
if not table:
raise ValueError("No table found in the HTML string.")
return table


def to_excel(html_table, excel_path):
from tablepyxl import tablepyxl

table = parse_html_table(html_table)
tablepyxl.document_to_xl(str(table), excel_path)


def to_csv(html_table, csv_path):
import pandas as pd

table = parse_html_table(html_table)
df = pd.read_html(str(table))[0]
df.to_csv(csv_path, index=False)


def main():
args = parse_args()
set_logger(name="mindocr")
analyzer = TableAnalyzer(args)
img_paths = get_image_paths(args.image_dir)
save_dir = args.draw_img_save_dir
for i, img_path in enumerate(img_paths):
logger.info(f"Infering {i+1}/{len(img_paths)}: {img_path}")
pred_html, time_prof = analyzer(img_path, do_visualize=True)
logger.info(f"Time profile: {time_prof}")
img_name = os.path.basename(img_path).rsplit(".", 1)[0]
to_csv(pred_html, os.path.join(save_dir, f"{img_name}.csv"))
logger.info(f"Done! All structure results are saved to {args.draw_img_save_dir}")


if __name__ == "__main__":
main()
Loading

0 comments on commit 6413f29

Please sign in to comment.