-
Notifications
You must be signed in to change notification settings - Fork 0
/
frontapp.py
87 lines (73 loc) · 3.38 KB
/
frontapp.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
import argparse
import time
import subprocess
import streamlit as st
from generate import gen_sample
def subprocess_getoutput(stmt):
result = subprocess.getoutput(stmt)
# 执行失败不需要特殊处理,因为该方法无法判断失败成功,只负责将结果进行返回
return result # 返回执行结果,但是结果返回的是一个str字符串(不论有多少行)
def writer():
st.set_page_config(
page_title="古诗生成 DEMO"
)
st.markdown(
"""
## 古诗生成 DEMO
"""
)
st.sidebar.subheader("配置参数")
option = st.sidebar.selectbox(
'诗歌格式',
['五言绝句', '五言律诗', '七言绝句', '七言律诗'])
num = st.sidebar.number_input('生成诗歌数量', min_value=1, max_value=10, step=1)
title_num = 0
full_length = 65
if option == '五言绝句':
title_num = 4
elif option == '五言律诗':
title_num = 8
elif option == '七言绝句':
title_num = 4
else:
title_num = 8
title = st.sidebar.text_input('藏头字', placeholder='对应格式的藏头字数量:最多{}个字'.format(title_num))
if option == '五言绝句':
full_length = 26 + (4 - len(title))
elif option == '五言律诗':
full_length = 50 + (8 - len(title))
elif option == '七言绝句':
full_length = 34 + (4 - len(title))
else:
full_length = 65 + (8 - len(title))
parser = argparse.ArgumentParser()
parser.add_argument('--length', default=full_length, type=int, help='生成文本长度')
parser.add_argument('--batch_size', default=1, type=int, required=False, help='生成的batch size')
parser.add_argument('--nsamples', default=num, type=int, help='生成样本数量')
parser.add_argument('--temperature', default=1, type=float, required=False, help='生成温度')
parser.add_argument('--topk', default=8, type=int, required=False, help='最高几选一')
parser.add_argument('--topp', default=0, type=float, required=False, help='最高积累概率')
parser.add_argument('--model_config', default='config/model_config_small.json', type=str, required=False,
help='模型参数')
parser.add_argument('--tokenizer_path', default='vocab/vocab_guwen.txt', type=str, required=False, help='词表路径')
parser.add_argument('--model_path', default='model/final_model', type=str, required=False, help='模型路径')
parser.add_argument('--prefix', default=option + '[SEP]' + title, type=str, help='生成文本前缀')
parser.add_argument('--no_wordpiece', action='store_true', help='不做word piece切词')
parser.add_argument('--fast_pattern', action='store_true', help='采用更加快的方式生成文本')
parser.add_argument('--repetition_penalty', default=1.0, type=float, required=False)
args = parser.parse_args()
if st.button("点击开始生成"):
start_message = st.empty()
start_message.write("正在作诗,请等待...")
start_time = time.time()
content = gen_sample(args)
print(content)
end_time = time.time()
start_message.write("作诗完成,耗时{}s".format(end_time - start_time))
for i in range(num):
st.text("第{}个结果".format(i + 1))
st.text(content[i])
else:
st.stop()
if __name__ == '__main__':
writer()