Charles Chan
committed on
Commit
·
edfb894
1
Parent(s):
c2aa18b
coding
Browse files
app.py
CHANGED
@@ -15,13 +15,13 @@ if "data_list" not in st.session_state:
|
|
15 |
if not st.session_state.data_list:
|
16 |
try:
|
17 |
with st.spinner("正在读取数据库..."):
|
18 |
-
|
19 |
dataset = load_dataset("rorubyy/attack_on_titan_wiki_chinese")
|
20 |
data_list = []
|
21 |
answer_list = []
|
22 |
for example in dataset["train"]:
|
23 |
-
converted_answer =
|
24 |
-
converted_question =
|
25 |
answer_list.append(converted_answer)
|
26 |
data_list.append({"Question": converted_question, "Answer": converted_answer})
|
27 |
st.session_state.answer_list = answer_list
|
@@ -63,7 +63,7 @@ def answer_question(repo_id, temperature, max_length, question):
|
|
63 |
if repo_id != st.session_state.repo_id or temperature != st.session_state.temperature or max_length != st.session_state.max_length:
|
64 |
try:
|
65 |
with st.spinner("正在初始化 Gemma 模型..."):
|
66 |
-
llm = HuggingFaceHub(repo_id=repo_id, model_kwargs={"temperature": temperature, "max_length": max_length})
|
67 |
st.success("Gemma 模型初始化完成!")
|
68 |
print("Gemma 模型初始化完成!")
|
69 |
st.session_state.repo_id = repo_id
|
@@ -91,7 +91,7 @@ def answer_question(repo_id, temperature, max_length, question):
|
|
91 |
print("本地数据集筛选完成!")
|
92 |
|
93 |
with st.spinner("正在生成答案..."):
|
94 |
-
answer = llm.invoke(prompt)
|
95 |
# 去掉 prompt 的内容
|
96 |
answer = answer.replace(prompt, "").strip()
|
97 |
st.success("答案已经生成!")
|
@@ -113,6 +113,13 @@ with col2:
|
|
113 |
|
114 |
st.divider()
|
115 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
116 |
col3, col4 = st.columns(2)
|
117 |
with col3:
|
118 |
if st.button("使用原数据集中的随机问题"):
|
@@ -120,9 +127,7 @@ with col3:
|
|
120 |
random_index = random.randint(0, dataset_size - 1)
|
121 |
# 读取随机问题
|
122 |
random_question = st.session_state.data_list[random_index]["Question"]
|
123 |
-
random_question = st.session_state.converter.convert(random_question)
|
124 |
origin_answer = st.session_state.data_list[random_index]["Answer"]
|
125 |
-
origin_answer = st.session_state.converter.convert(origin_answer)
|
126 |
print('[]' + str(random_index) + '/' + str(dataset_size) + ']random_question: ' + random_question)
|
127 |
print('origin_answer: ' + origin_answer)
|
128 |
|
@@ -130,20 +135,12 @@ with col3:
|
|
130 |
st.write(random_question)
|
131 |
st.write("原始答案:")
|
132 |
st.write(origin_answer)
|
133 |
-
|
134 |
-
print('prompt: ' + result["prompt"])
|
135 |
-
print('answer: ' + result["answer"])
|
136 |
-
st.write("生成答案:")
|
137 |
-
st.write(result["answer"])
|
138 |
|
139 |
with col4:
|
140 |
-
question = st.text_area("请输入问题", "
|
141 |
if st.button("提交输入的问题"):
|
142 |
if not question:
|
143 |
st.warning("请输入问题!")
|
144 |
else:
|
145 |
-
|
146 |
-
print('prompt: ' + result["prompt"])
|
147 |
-
print('answer: ' + result["answer"])
|
148 |
-
st.write("生成答案:")
|
149 |
-
st.write(result["answer"])
|
|
|
15 |
if not st.session_state.data_list:
|
16 |
try:
|
17 |
with st.spinner("正在读取数据库..."):
|
18 |
+
converter = OpenCC('tw2s') # 'tw2s.json' 表示繁体中文到简体中文的转换
|
19 |
dataset = load_dataset("rorubyy/attack_on_titan_wiki_chinese")
|
20 |
data_list = []
|
21 |
answer_list = []
|
22 |
for example in dataset["train"]:
|
23 |
+
converted_answer = converter.convert(example["Answer"])
|
24 |
+
converted_question = converter.convert(example["Question"])
|
25 |
answer_list.append(converted_answer)
|
26 |
data_list.append({"Question": converted_question, "Answer": converted_answer})
|
27 |
st.session_state.answer_list = answer_list
|
|
|
63 |
if repo_id != st.session_state.repo_id or temperature != st.session_state.temperature or max_length != st.session_state.max_length:
|
64 |
try:
|
65 |
with st.spinner("正在初始化 Gemma 模型..."):
|
66 |
+
st.session_state.llm = HuggingFaceHub(repo_id=repo_id, model_kwargs={"temperature": temperature, "max_length": max_length})
|
67 |
st.success("Gemma 模型初始化完成!")
|
68 |
print("Gemma 模型初始化完成!")
|
69 |
st.session_state.repo_id = repo_id
|
|
|
91 |
print("本地数据集筛选完成!")
|
92 |
|
93 |
with st.spinner("正在生成答案..."):
|
94 |
+
answer = st.session_state.llm.invoke(prompt)
|
95 |
# 去掉 prompt 的内容
|
96 |
answer = answer.replace(prompt, "").strip()
|
97 |
st.success("答案已经生成!")
|
|
|
113 |
|
114 |
st.divider()
|
115 |
|
116 |
+
def generate_answer(repo_id, temperature, max_length, question):
|
117 |
+
result = answer_question(repo_id, float(temperature), int(max_length), question)
|
118 |
+
print('prompt: ' + result["prompt"])
|
119 |
+
print('answer: ' + result["answer"])
|
120 |
+
st.write("生成答案:")
|
121 |
+
st.write(result["answer"])
|
122 |
+
|
123 |
col3, col4 = st.columns(2)
|
124 |
with col3:
|
125 |
if st.button("使用原数据集中的随机问题"):
|
|
|
127 |
random_index = random.randint(0, dataset_size - 1)
|
128 |
# 读取随机问题
|
129 |
random_question = st.session_state.data_list[random_index]["Question"]
|
|
|
130 |
origin_answer = st.session_state.data_list[random_index]["Answer"]
|
|
|
131 |
print('[]' + str(random_index) + '/' + str(dataset_size) + ']random_question: ' + random_question)
|
132 |
print('origin_answer: ' + origin_answer)
|
133 |
|
|
|
135 |
st.write(random_question)
|
136 |
st.write("原始答案:")
|
137 |
st.write(origin_answer)
|
138 |
+
generate_answer(gemma, float(temperature), int(max_length), random_question)
|
|
|
|
|
|
|
|
|
139 |
|
140 |
with col4:
|
141 |
+
question = st.text_area("请输入问题", "《进击的巨人》中都有哪些主要角色?")
|
142 |
if st.button("提交输入的问题"):
|
143 |
if not question:
|
144 |
st.warning("请输入问题!")
|
145 |
else:
|
146 |
+
generate_answer(gemma, float(temperature), int(max_length), question)
|
|
|
|
|
|
|
|