Commit: Merge branch 'dev' into main
ZeningLin committed Oct 22, 2022
2 parents 712907a + bb4afd0 commit 50e1ca9
Showing 21 changed files with 198 additions and 128 deletions.
2 changes: 1 addition & 1 deletion data/SROIE_dataset.py
@@ -119,7 +119,7 @@ def __getitem__(self, index):
         ocr_text_filter = []
         seg_index = 0
         for text in ocr_text:
-            if text == "" or text.isspace:
+            if text == "" or text.isspace():
                 continue
             curr_tokens = self.tokenizer.tokenize(text.lower())
             if len(curr_tokens) == 0:
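The fix here (calling `isspace()` instead of referencing it) closes a classic Python pitfall: without parentheses, `text.isspace` is a bound method object, which is always truthy, so the old condition sent every string through `continue` and filtered out all OCR text. A standalone illustration (not from the repo):

```python
text = "receipt total"

# A bound method is an object, and objects are truthy, so the old
# test `text == "" or text.isspace` matched every non-empty string.
assert bool(text.isspace) is True

# Calling the method performs the actual whitespace check.
assert text.isspace() is False
assert "   ".isspace() is True
```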
2 changes: 1 addition & 1 deletion deployment/inference_EPHOIE.py
@@ -172,7 +172,7 @@ def inference_pipe(
         DEVICE,
         NUM_CLASSES,
         image_bytes=image_bytes,
-        parse_mode=PARSE_MODE
+        parse_mode=PARSE_MODE,
     )

     with open(image_dir.replace(".jpg", ".json"), "w") as f:
2 changes: 1 addition & 1 deletion deployment/inference_preporcessing.py
@@ -144,7 +144,7 @@ def generate_batch(
     parse_mode: str = None,
 ):
     image = Image.open(io.BytesIO(image_bytes))
-    image = image.convert('RGB')
+    image = image.convert("RGB")

     status_code, return_text_list, return_coor_list = ocr_extraction(
         image_bytes=image_bytes, ocr_url=ocr_url, parse_mode=parse_mode
4 changes: 1 addition & 3 deletions deployment/module_load.py
@@ -9,9 +9,7 @@
 from model.ViBERTgrid_net import ViBERTgridNet


-def inference_init(
-    dir_config: str = "./deployment/config/network_config.yaml"
-):
+def inference_init(dir_config: str = "./deployment/config/network_config.yaml"):
     with open(dir_config, "r") as c:
         hyp = yaml.load(c, Loader=yaml.FullLoader)

2 changes: 1 addition & 1 deletion eval_FUNSD.py
@@ -61,7 +61,7 @@ def evaluation_FUNSD(
         pred_gt_dict.update({pred_label.detach(): gt_label.detach()})

     p, r, f, report = BIO_F1_criteria(
-        pred_gt_dict=pred_gt_dict, tag_to_idx=TAG_TO_IDX, average="macro"
+        pred_gt_list=pred_gt_dict, tag_to_idx=TAG_TO_IDX, average="macro"
     )

     return p, r, f, report
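The keyword rename matches the new `BIO_F1_criteria` signature in pipeline/criteria.py (updated later in this commit), which now expects a list of `(pred, gt)` tuples rather than a tensor-keyed dict; note that this call site still passes the dict built a few lines up. A sketch of the collection pattern the renamed parameter appears to expect (dummy tensors, illustrative names):

```python
import torch

# Stand-ins for per-batch predictions and ground-truth labels.
pred_labels = [torch.tensor([0, 1, 2]), torch.tensor([1, 1, 0])]
gt_labels = [torch.tensor([0, 1, 1]), torch.tensor([1, 0, 0])]

# The new API consumes (pred, gt) tuples appended batch by batch.
pred_gt_list = [
    (pred.detach(), gt.detach()) for pred, gt in zip(pred_labels, gt_labels)
]
```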
2 changes: 1 addition & 1 deletion model/BERTgrid_generator.py
@@ -180,7 +180,7 @@ def BERT_embedding(
                 curr_batch_aggre_embeddings.append(curr_embedding.unsqueeze(0))

                 prev_seg_index = curr_seg_index.int().item()

         if self.grid_mode == "mean":
             mean_embeddings /= num_tok
             curr_batch_aggre_embeddings.append(mean_embeddings.unsqueeze(0))
14 changes: 10 additions & 4 deletions model/crf.py
@@ -21,6 +21,7 @@ def argmax(vec):
     _, idx = torch.max(vec, 1)
     return idx.item()

+
 # Compute log sum exp in a numerically stable way for the forward algorithm
 def log_sum_exp(vec):
     max_score = vec[0, argmax(vec)]
@@ -43,7 +44,6 @@ def __init__(self, tag_to_ix):
         self.transitions.data[tag_to_ix[START_TAG], :] = -10000
         self.transitions.data[:, tag_to_ix[STOP_TAG]] = -10000

-
     def _forward_alg(self, feats):
         device = self.transitions.device
@@ -81,7 +81,12 @@ def _score_sentence(self, feats, tags):
         # Gives the score of a provided tag sequence
         score = torch.zeros(1, device=device)
         tags = torch.cat(
-            [torch.tensor([self.tag_to_ix[START_TAG]], dtype=torch.long, device=device), tags]
+            [
+                torch.tensor(
+                    [self.tag_to_ix[START_TAG]], dtype=torch.long, device=device
+                ),
+                tags,
+            ]
         )
         for i, feat in enumerate(feats):
             score = score + self.transitions[tags[i + 1], tags[i]] + feat[tags[i + 1]]
@@ -94,7 +99,9 @@ def _viterbi_decode(self, feats):
         backpointers = []

         # Initialize the viterbi variables in log space
-        init_vvars = torch.full((1, self.tagset_size), -10000.0, device=self.transitions.device)
+        init_vvars = torch.full(
+            (1, self.tagset_size), -10000.0, device=self.transitions.device
+        )
         init_vvars[0][self.tag_to_ix[START_TAG]] = 0

         # forward_var at step i holds the viterbi variables for step i-1
@@ -148,4 +155,3 @@ def inference(self, feats):  # dont confuse this with _forward_alg above.
         # Find the best path, given the features.
         score, tag_seq = self._viterbi_decode(feats)
         return score, tag_seq
-
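For reference, the "numerically stable" comment above `log_sum_exp` refers to the standard max-subtraction identity: log Σᵢ exp(vᵢ) = m + log Σᵢ exp(vᵢ - m) with m = max vᵢ, which keeps `exp()` from overflowing on large scores. A self-contained sketch of the same trick (illustrative, not the repo's exact helper):

```python
import torch

def log_sum_exp_stable(vec: torch.Tensor) -> torch.Tensor:
    # log(sum(exp(v))) = max(v) + log(sum(exp(v - max(v))));
    # shifting by the max keeps exp() in a safe numeric range.
    max_score = vec.max()
    return max_score + torch.log(torch.exp(vec - max_score).sum())

scores = torch.tensor([1000.0, 999.0, 998.0])
print(log_sum_exp_stable(scores))      # tensor(1000.4076), no overflow
print(torch.logsumexp(scores, dim=0))  # built-in equivalent for comparison
```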
2 changes: 1 addition & 1 deletion model/field_type_classification_head.py
@@ -610,7 +610,7 @@ def __init__(
             type of classifier, `single` for a single layer perceptron, `multi` for a MLP
         work_mode: str, optional
             work mode of the model, controls the return values, `train`, `eval` or `inference`
         """
         super().__init__()

11 changes: 6 additions & 5 deletions model/semantic_segmentation_head.py
@@ -81,13 +81,14 @@ def forward(self, x):
 class SemanticSegmentationBinaryClassifier(nn.Module):
     """binaly classifier used in auxiliary semantic segmentation head
-    Parameters
-    ----------
-    in_channels : int
-        number of channels of the input feature
+
+    Parameters
+    ----------
+    in_channels : int
+        number of channels of the input feature
+
     """

     def __init__(self, in_channels: int) -> None:
-
         super().__init__()
         self.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=1, kernel_size=1)

23 changes: 15 additions & 8 deletions pipeline/criteria.py
@@ -6,7 +6,7 @@
     classification_report,
 )

-from typing import Dict
+from typing import Dict, List, Tuple


 @torch.no_grad()
@@ -22,12 +22,16 @@ def token_classification_criteria(gt_label: torch.Tensor, pred_label: torch.Tensor):


 @torch.no_grad()
-def BIO_F1_criteria(pred_gt_dict: Dict[torch.Tensor, torch.Tensor], tag_to_idx: Dict, average: str = "micro"):
+def BIO_F1_criteria(
+    pred_gt_list: List[Tuple[torch.Tensor, torch.Tensor]],
+    tag_to_idx: Dict,
+    average: str = "micro",
+):
     idx_to_tag = {v: k for k, v in tag_to_idx.items()}

     pred_list = list()
     label_list = list()
-    for pred, label in pred_gt_dict.items():
+    for (pred, label) in pred_gt_list:
         if len(pred.shape) != 1 and pred.shape[1] != 1:
             pred = pred.argmax(dim=1)
         if len(pred.shape) != 1:
@@ -49,11 +53,14 @@ def BIO_F1_criteria(pred_gt_dict: Dict[torch.Tensor, torch.Tensor], tag_to_idx:


 @torch.no_grad()
-def token_F1_criteria(pred_gt_dict: Dict[torch.Tensor, torch.Tensor]):
-    pred_label: torch.Tensor
-    gt_label: torch.Tensor
-    pred_label = torch.cat(list(pred_gt_dict.keys()), dim=0)
-    gt_label = torch.cat(list(pred_gt_dict.values()), dim=0)
+def token_F1_criteria(pred_gt_list: List[Tuple[torch.Tensor, torch.Tensor]]):
+    pred_label = list()
+    gt_label = list()
+    for item in pred_gt_list:
+        pred_label.append(item[0])
+        gt_label.append(item[1])
+    pred_label = torch.cat(pred_label, dim=0)
+    gt_label = torch.cat(gt_label, dim=0)

     num_classes = pred_label.shape[1]
     pred_label = pred_label.int()
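Replacing `Dict[torch.Tensor, torch.Tensor]` with `List[Tuple[torch.Tensor, torch.Tensor]]` sidesteps a real pitfall: tensors hash by object identity, not by value, so a tensor-keyed dict cannot be looked up by an equal-valued tensor and only ever worked here by accident of each key being a distinct object. A short standalone demonstration (not from the repo):

```python
import torch

a = torch.tensor([1, 2, 3])
b = torch.tensor([1, 2, 3])  # equal values, different object

d = {a: "first"}
print(a in d)  # True  -- found via object identity
print(b in d)  # False -- identical content, but a different hash

# A list of (pred, gt) tuples has none of these surprises: it keeps
# every batch, preserves insertion order, and pickles cleanly for
# torch.distributed.all_gather_object.
pairs = [(a, b)]
```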
1 change: 1 addition & 0 deletions pipeline/distributed_utils.py
@@ -24,6 +24,7 @@ def reduce_loss(loss, average=True):

     return loss

+
 def is_dist_avail_and_initialized():
     if not torch.distributed.is_available():
         return False
5 changes: 2 additions & 3 deletions pipeline/funsd_data_preprocessing.py
@@ -30,7 +30,6 @@ def annotation_parsing_word(dir_annotation: str, dir_save: str):
             top = coors[1]
             right = coors[2]
             bot = coors[3]
-

             curr_row_dict = {
                 "left": [left],
@@ -64,7 +63,7 @@ def annotation_parsing_seg(dir_annotation: str, dir_save: str):
                 seg_text = Literal["N/A"]
             if seg_text == "NA":
                 seg_text = Literal["NA"]

             data_class = seg["label"]
             pos_neg = 2 if data_class == 0 else 1

@@ -129,4 +128,4 @@ def run_annotation_parser(dir_funsd_root: str, mode: str):
     parser.add_argument("--mode", type=str, help="label data level, word or seg")
     args = parser.parse_args()

-    run_annotation_parser(args.root, args.mode)
\ No newline at end of file
+    run_annotation_parser(args.root, args.mode)
2 changes: 1 addition & 1 deletion pipeline/sroie_data_preprocessing.py
@@ -495,7 +495,7 @@ def data_parser_multiprocessing(
         |___img: images
         |___box: txt files that contain OCR results
         |___key: txt files that contain key info labels
     ___test_raw
         |___img: images
         |___box: txt files that contain OCR results
27 changes: 16 additions & 11 deletions pipeline/train_val_utils.py
@@ -405,7 +405,7 @@ def validate(
     method_precision_sum = torch.zeros(1, device=device)

     model.eval()
-    pred_gt_dict = dict()
+    pred_gt_list = list()
     mean_validate_loss = torch.zeros(1).to(device)
     for step, validate_batch in enumerate(validate_loader):
         (
@@ -517,7 +517,7 @@
         num_gt += torch.tensor(curr_num_gt, device=device)
         num_det += torch.tensor(curr_num_det, device=device)

-        pred_gt_dict.update({pred_label.detach(): gt_label.detach()})
+        pred_gt_list.append((pred_label.detach(), gt_label.detach()))

         validate_loss = reduce_loss(validate_loss)
         validate_loss_value = validate_loss.item()
@@ -541,14 +541,19 @@
         torch.distributed.all_reduce(method_precision_sum)
         torch.distributed.all_reduce(method_recall_sum)

-        pred_gt_dict_syn = [None for _ in range(num_proc)]
+        pred_gt_list_syn = [None for _ in range(num_proc)]
         torch.distributed.all_gather_object(
-            object_list=pred_gt_dict_syn, obj=pred_gt_dict
+            object_list=pred_gt_list_syn, obj=pred_gt_list
         )
-        pred_gt_dict_ = dict()
-        for p_g_d in pred_gt_dict_syn:
-            for k, v in p_g_d.items():
-                pred_gt_dict_.update({k: v})
+        pred_gt_list_ = list()
+        for p_g_d in pred_gt_list_syn:
+            for p_g_item in p_g_d:
+                pred_gt_list_.append(p_g_item)
+        del pred_gt_list_syn
+    else:
+        pred_gt_list_ = pred_gt_list
+
+    del pred_gt_list

     num_gt = int(num_gt.item())
     num_det = int(num_det.item())
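In the distributed branch above, each process contributes its local list and the gathered per-rank lists are flattened into one. A minimal sketch of that pattern (assumes `torch.distributed.init_process_group` has already run, e.g. under torchrun):

```python
import torch
import torch.distributed as dist

# Dummy stand-ins for this rank's (pred, gt) pairs.
local_pairs = [(torch.zeros(3), torch.zeros(3))]

world_size = dist.get_world_size()
gathered = [None for _ in range(world_size)]

# Every rank receives every other rank's picklable object.
dist.all_gather_object(object_list=gathered, obj=local_pairs)

# Flatten the per-rank lists, mirroring the nested loop in the diff.
all_pairs = [pair for rank_list in gathered for pair in rank_list]
```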
@@ -558,7 +563,7 @@
     if eval_mode == "seqeval":
         assert tag_to_idx is not None
         precision, recall, F1, report = BIO_F1_criteria(
-            pred_gt_dict=pred_gt_dict_, tag_to_idx=tag_to_idx, average=seqeval_average
+            pred_gt_list=pred_gt_list_, tag_to_idx=tag_to_idx, average=seqeval_average
         )
         print(report)
         print(
@@ -588,7 +593,7 @@
     elif eval_mode == "seq_and_str":
         assert tag_to_idx is not None
         token_precision, token_recall, token_F1, report = BIO_F1_criteria(
-            pred_gt_dict=pred_gt_dict_, tag_to_idx=tag_to_idx, average=seqeval_average
+            pred_gt_list=pred_gt_list_, tag_to_idx=tag_to_idx, average=seqeval_average
         )
         print("==> token level result")
         print(report)
@@ -620,7 +625,7 @@

     else:
         result_dict: Dict
-        result_dict = token_F1_criteria(pred_gt_dict=pred_gt_dict_)
+        result_dict = token_F1_criteria(pred_gt_list=pred_gt_list_)
         num_classes = result_dict["num_classes"]
         precision = 0.0
         recall = 0.0
5 changes: 1 addition & 4 deletions pipeline/transform.py
@@ -259,10 +259,7 @@ def return_batch(

     # create a tensor with shape batch_image_shape and all values zero
     batched_imgs = images[0].new_full(batch_image_shape, 0)
-    for (
-        img,
-        pad_img,
-    ) in zip(
+    for (img, pad_img,) in zip(
         images,
         batched_imgs,
     ):
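For context, `return_batch` builds a fixed-size batch by allocating a zero tensor and copying each variable-sized image into it; the loop being reformatted here performs that copy. A self-contained sketch of the underlying padding pattern (illustrative shapes):

```python
import torch

# Three "images" with the same channel count but different spatial sizes.
images = [torch.ones(3, 4, 5), torch.ones(3, 6, 2), torch.ones(3, 5, 5)]

# The batch canvas takes the maximum extent along each spatial dimension.
max_h = max(img.shape[1] for img in images)
max_w = max(img.shape[2] for img in images)
batched_imgs = images[0].new_full((len(images), 3, max_h, max_w), 0)

# Copy each image into the top-left corner of its zeroed slot.
for img, pad_img in zip(images, batched_imgs):
    pad_img[..., : img.shape[1], : img.shape[2]].copy_(img)

print(batched_imgs.shape)  # torch.Size([3, 3, 6, 5])
```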
35 changes: 26 additions & 9 deletions train_EPHOIE.py
@@ -8,7 +8,8 @@
 import torch
 from transformers import BertTokenizer, RobertaTokenizer

-from data.EPHOIE_dataset import load_train_dataset_multi_gpu as EPHOIE_load_train
+from data.EPHOIE_dataset import load_train_dataset_multi_gpu as EPHOIE_load_train_multi
+from data.EPHOIE_dataset import load_train_dataset as EPHOIE_load_train
 from model.ViBERTgrid_net import ViBERTgridNet
 from pipeline.train_val_utils import (
     train_one_epoch,
@@ -155,6 +156,11 @@ def train(args):
     eval_mode = hyp["eval_mode"]
     tag_mode = hyp["tag_mode"]

+    if classifier_mode == "crf":
+        assert (
+            eval_mode == "seqeval"
+        ), "When using the crf classifier, only the seqeval metric is available"
+
     if tag_mode == "BIO":
         map_dict = TAG_TO_IDX_BIO
     else:
@@ -170,12 +176,21 @@
     print(f"==> tokenizer {bert_version} loaded")

     print(f"==> loading datasets")
-    train_loader, val_loader, train_sampler = EPHOIE_load_train(
-        root=data_root,
-        batch_size=batch_size,
-        num_workers=num_workers,
-        tokenizer=tokenizer,
-    )
+    if args.distributed:
+        train_loader, val_loader, train_sampler = EPHOIE_load_train_multi(
+            root=data_root,
+            batch_size=batch_size,
+            num_workers=num_workers,
+            tokenizer=tokenizer,
+        )
+    else:
+        train_loader, val_loader = EPHOIE_load_train(
+            root=data_root,
+            batch_size=batch_size,
+            num_workers=num_workers,
+            tokenizer=tokenizer,
+        )
+
     print(f"==> dataset loaded")

     print(f"==> creating model {backbone} | {bert_version}")
@@ -296,7 +311,7 @@
         f"{curr_time.tm_year:04d}-{curr_time.tm_mon:02d}-{curr_time.tm_mday:02d}"
     )
     curr_time_h += (
-        f"_{curr_time.tm_hour:02d}:{curr_time.tm_min:02d}:{curr_time.tm_sec:02d}"
+        f"_{curr_time.tm_hour:02d}-{curr_time.tm_min:02d}-{curr_time.tm_sec:02d}"
     )
     comment = (
         comment_exp
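The colon-to-dash change in the time suffix is more than cosmetic: `curr_time_h` ends up in a log/checkpoint name, and colons are not allowed in Windows file names. A standalone sketch of the corrected stamp (assuming `curr_time = time.localtime()`, as the `tm_*` fields imply):

```python
import time

curr_time = time.localtime()
curr_time_h = (
    f"{curr_time.tm_year:04d}-{curr_time.tm_mon:02d}-{curr_time.tm_mday:02d}"
    f"_{curr_time.tm_hour:02d}-{curr_time.tm_min:02d}-{curr_time.tm_sec:02d}"
)
print(curr_time_h)  # e.g. 2022-10-22_14-05-09, safe in a file name
```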
@@ -321,6 +336,7 @@
         device=device,
         epoch=0,
         logger=logger,
+        distributed=args.distributed,
         eval_mode=eval_mode,
         tag_to_idx=map_dict,
         category_list=EPHOIE_CLASS_LIST,
@@ -361,6 +377,7 @@
             device=device,
             epoch=epoch,
             logger=logger,
+            distributed=args.distributed,
             eval_mode=eval_mode,
             tag_to_idx=map_dict,
             category_list=EPHOIE_CLASS_LIST,
@@ -371,7 +388,7 @@
         if F1 > top_F1:
             top_F1 = F1

-        if F1 > top_F1_tresh or (epoch % 400 == 0 and epoch != start_epoch):
+        if F1 > top_F1_tresh or (epoch % 10 == 0 and epoch != start_epoch):
             top_F1_tresh = F1
             if save_top is not None:
                 if not os.path.exists(save_top) and is_main_process():