Commit

update infer.
shibing624 committed Oct 14, 2024
1 parent a0ebc2f commit 9d6c2ca
Showing 2 changed files with 9 additions and 13 deletions.
8 changes: 4 additions & 4 deletions README.md
@@ -20,7 +20,7 @@

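**pycorrector**: a Chinese text error correction tool. It corrects phonetically similar, visually similar, and grammatical errors in Chinese text, and is developed with Python 3.8.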

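-**pycorrector** implements text correction with several models, including Kenlm, ConvSeq2Seq, BERT, MacBERT, ELECTRA, ERNIE, and Transformer, and evaluates each model on the SIGHAN dataset.
+**pycorrector** implements text correction with several models, including Kenlm, ConvSeq2Seq, BERT, MacBERT, ELECTRA, ERNIE, and GPT, and evaluates each model.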

**Guide**

@@ -43,8 +43,8 @@
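This project focuses on the following error types among them: "phonetic similarity, visual similarity, grammar, and proper-noun errors".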

## News
-[2024/10/14] v1.1.0: Added Chinese text correction models based on Qwen2.5 that correct extra characters, missing characters, wrong characters, word order, and grammar errors. Released the [shibing624/chinese-text-correction-1.5b](https://huggingface.co/shibing624/chinese-text-correction-1.5b) and [shibing624/chinese-text-correction-7b](https://huggingface.co/shibing624/chinese-text-correction-7b) models and their corresponding LoRA models. See [Release-v1.1.0](https://github.com/shibing624/pycorrector/releases/tag/1.1.0).
-[2023/11/07] v1.0.0: Added GPT-style models such as ChatGLM3 and LLaMA2 for Chinese text correction, and released [shibing624/chatglm3-6b-csc-chinese-lora](https://huggingface.co/shibing624/chatglm3-6b-csc-chinese-lora), a spelling and grammar correction model based on ChatGLM3-6B. Rewrote the implementations of DeepContext, ConvSeq2Seq, T5, and other models. See [Release-v1.0.0](https://github.com/shibing624/pycorrector/releases/tag/1.0.0).
+- [2024/10/14] v1.1.0: Added Chinese text correction models based on Qwen2.5 that correct extra characters, missing characters, wrong characters, word order, and grammar errors. Released the [shibing624/chinese-text-correction-1.5b](https://huggingface.co/shibing624/chinese-text-correction-1.5b) and [shibing624/chinese-text-correction-7b](https://huggingface.co/shibing624/chinese-text-correction-7b) models and their corresponding LoRA models. See [Release-v1.1.0](https://github.com/shibing624/pycorrector/releases/tag/1.1.0).
+- [2023/11/07] v1.0.0: Added GPT-style models such as ChatGLM3 and LLaMA2 for Chinese text correction, and released [shibing624/chatglm3-6b-csc-chinese-lora](https://huggingface.co/shibing624/chatglm3-6b-csc-chinese-lora), a spelling and grammar correction model based on ChatGLM3-6B. Rewrote the implementations of DeepContext, ConvSeq2Seq, T5, and other models. See [Release-v1.0.0](https://github.com/shibing624/pycorrector/releases/tag/1.0.0).


## Features
@@ -55,7 +55,7 @@
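* [T5 model](https://github.com/shibing624/pycorrector/tree/master/examples/t5): a PyTorch implementation of T5 for Chinese text correction, fine-tuned on Chinese correction datasets from the Langboat/mengzi-t5-base pretrained model. The model has considerable room for adaptation and performs well.
* [ERNIE_CSC model](https://github.com/shibing624/pycorrector/tree/master/examples/ernie_csc): a PaddlePaddle implementation of ERNIE_CSC for Chinese text correction, fine-tuned from ERNIE-1.0. The architecture is adapted to the Chinese spelling correction task and performs well.
* [MacBERT model](https://github.com/shibing624/pycorrector/tree/master/examples/macbert) (recommended): a PyTorch implementation of MacBERT4CSC for Chinese text correction. The model adds error detection and correction networks, is adapted to the Chinese spelling correction task, and performs well.
-* [GPT model](https://github.com/shibing624/pycorrector/tree/master/examples/gpt): a PyTorch implementation of ChatGLM/LLaMA models for Chinese text correction, fine-tuned on Chinese CSC and grammar correction datasets. It is adapted to the Chinese text correction task and performs well.
+* [GPT model](https://github.com/shibing624/pycorrector/tree/master/examples/gpt): a PyTorch implementation of ChatGLM/Qwen models for Chinese text correction, fine-tuned on Chinese CSC and grammar correction datasets. It is adapted to the Chinese text correction task and performs very well.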

- Further reading: [Chinese text correction: practice and principles](https://github.com/shibing624/pycorrector/blob/master/docs/correction_solution.md)
## Demo
14 changes: 5 additions & 9 deletions pycorrector/gpt/gpt_model.py
@@ -571,19 +571,15 @@ def predict(
                 return_tensors='pt',
                 padding=True,
             )
-            outputs = self.model.generate(inputs.to(self.device), **generation_kwargs, **kwargs)
+            input_ids = inputs.to(self.device)
+            outputs = self.model.generate(input_ids, **generation_kwargs, **kwargs)

             for input_text, generated_sequence in zip(batch, outputs):
                 # Decode text
+                prompt_len = len(input_ids[0])
+                generated_sequence = generated_sequence[prompt_len:]
                 gen_text = self.tokenizer.decode(generated_sequence, skip_special_tokens=True)
-                stop_str = self.tokenizer.eos_token or prompt_template.stop_str
-                pos = gen_text.find(stop_str)
-                if pos != -1:
-                    gen_text = gen_text[:pos]
-                if skip_prompt:
-                    gen_text = gen_text.split(input_text, 1)[-1]
-                if gen_text.startswith("\nassistant\n"):
-                    gen_text = gen_text.split("\nassistant\n", 1)[-1]
+                # logger.error(f"input_text: {input_text}, gen_text: {gen_text}")
                 all_outputs.append(gen_text)

         return all_outputs
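
Note on the gpt_model.py change: it appears to replace string-based cleanup of the decoded text with slicing the prompt tokens off each generated sequence before decoding. The sketch below illustrates that idea without any model or tokenizer. `strip_prompt_tokens`, `toy_decode`, and the toy vocabulary are hypothetical stand-ins, not code from the repository, and the sketch assumes that every prompt row is padded to the same length and that the generated sequence begins with the full padded prompt.

```python
from typing import Dict, List, Sequence

PAD_ID = 0  # hypothetical padding token id for this toy example

# Toy vocabulary standing in for a real tokenizer (assumption, not pycorrector's).
VOCAB: Dict[int, str] = {0: "", 1: "fix:", 2: "teh", 3: "the", 4: "cat", 5: "dog"}


def strip_prompt_tokens(
    prompt_ids: Sequence[Sequence[int]],
    output_ids: Sequence[Sequence[int]],
) -> List[List[int]]:
    """Drop the echoed prompt prefix from each generated sequence.

    Assumes all prompt rows share one padded length, so the length of the
    first row is the prefix length for every row in the batch.
    """
    prompt_len = len(prompt_ids[0])
    return [list(seq[prompt_len:]) for seq in output_ids]


def toy_decode(ids: Sequence[int]) -> str:
    """Join token strings, skipping padding (mimics skipping special tokens)."""
    return " ".join(VOCAB[i] for i in ids if i != PAD_ID)


# Two prompts padded to length 4; the first is left-padded with PAD_ID.
prompts = [[0, 1, 2, 4], [1, 2, 5, 5]]
# Simulated generation output: the padded prompt followed by new tokens.
new_tokens = [[3, 4, 0], [3, 5, 5]]
outputs = [p + n for p, n in zip(prompts, new_tokens)]

completions = [toy_decode(ids) for ids in strip_prompt_tokens(prompts, outputs)]
print(completions)  # ['the cat', 'the dog dog']
```

Slicing by token count avoids depending on the prompt text reappearing verbatim in the decoded string, which is what the removed `split(input_text, 1)` logic relied on.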
