From 4d24e0838ce80e80cc6dc5ac7f80a107d5da1941 Mon Sep 17 00:00:00 2001 From: hongziqi <1102229410@qq.com> Date: Tue, 17 Dec 2024 17:35:00 +0800 Subject: [PATCH] update config to add argument:ser_class_dict_path --- tools/infer/text/config.py | 6 ++++++ tools/infer/text/postprocess.py | 2 +- tools/infer/text/predict_ser.py | 6 ++---- tools/infer/text/preprocess.py | 3 ++- 4 files changed, 11 insertions(+), 6 deletions(-) diff --git a/tools/infer/text/config.py b/tools/infer/text/config.py index 47979760e..c82114ddd 100644 --- a/tools/infer/text/config.py +++ b/tools/infer/text/config.py @@ -158,6 +158,12 @@ def create_parser(): type=str, help="directory containing the ser model checkpoint best.ckpt, or path to a specific checkpoint file.", ) + parser.add_argument( + "--ser_class_dict_path", + type=str, + default="./mindocr/utils/dict/class_list_xfun.txt", + help="path to class dictionary for structure recognition. ", + ) parser.add_argument( "--kie_batch_mode", type=str2bool, diff --git a/tools/infer/text/postprocess.py b/tools/infer/text/postprocess.py index fa90dd44a..8a7efb54a 100644 --- a/tools/infer/text/postprocess.py +++ b/tools/infer/text/postprocess.py @@ -79,7 +79,7 @@ def __init__(self, task="det", algo="DB", rec_char_dict_path=None, **kwargs): else: raise ValueError(f"No postprocess config defined for {algo}. Please check the algorithm name.") elif task == "ser": - class_path = "mindocr/utils/dict/class_list_xfun.txt" + class_path = kwargs.get("ser_class_dict_path", "mindocr/utils/dict/class_list_xfun.txt") postproc_cfg = dict(name="VQASerTokenLayoutLMPostProcess", class_path=class_path) elif task == "layout": if algo == "LAYOUTLMV3": diff --git a/tools/infer/text/predict_ser.py b/tools/infer/text/predict_ser.py index f12f915d4..a9b6da755 100644 --- a/tools/infer/text/predict_ser.py +++ b/tools/infer/text/predict_ser.py @@ -68,11 +68,9 @@ def __init__(self, args): ) self.model.set_train(False) - self.preprocess = Preprocessor( - task="ser", - ) + self.preprocess = Preprocessor(task="ser", ser_class_dict_path=args.ser_class_dict_path) - self.postprocess = Postprocessor(task="ser") + self.postprocess = Postprocessor(task="ser", ser_class_dict_path=args.ser_class_dict_path) self.batch_mode = args.kie_batch_mode self.batch_num = args.kie_batch_num diff --git a/tools/infer/text/preprocess.py b/tools/infer/text/preprocess.py index 47c6288dd..d88b3c8c1 100644 --- a/tools/infer/text/preprocess.py +++ b/tools/infer/text/preprocess.py @@ -166,6 +166,7 @@ def __init__(self, task="det", algo="DB", **kwargs): {"ToCHWImage": None}, ] elif task == "ser": + class_path = kwargs.get("ser_class_dict_path", "mindocr/utils/dict/class_list_xfun.txt") pipeline = [ {"DecodeImage": {"img_mode": "RGB", "infer_mode": True, "to_float32": False}}, { @@ -173,7 +174,7 @@ def __init__(self, task="det", algo="DB", **kwargs): "contains_re": False, "infer_mode": True, "algorithm": "LayoutXLM", - "class_path": "mindocr/utils/dict/class_list_xfun.txt", + "class_path": class_path, "order_method": "tb-yx", } },