Charles Chan committed on
Commit
edfb894
·
1 Parent(s): c2aa18b
Files changed (1) hide show
  1. app.py +15 -18
app.py CHANGED
@@ -15,13 +15,13 @@ if "data_list" not in st.session_state:
15
  if not st.session_state.data_list:
16
  try:
17
  with st.spinner("正在读取数据库..."):
18
- st.session_state.converter = OpenCC('tw2s') # 'tw2s.json' 表示繁体中文到简体中文的转换
19
  dataset = load_dataset("rorubyy/attack_on_titan_wiki_chinese")
20
  data_list = []
21
  answer_list = []
22
  for example in dataset["train"]:
23
- converted_answer = st.session_state.converter.convert(example["Answer"])
24
- converted_question = st.session_state.converter.convert(example["Question"])
25
  answer_list.append(converted_answer)
26
  data_list.append({"Question": converted_question, "Answer": converted_answer})
27
  st.session_state.answer_list = answer_list
@@ -63,7 +63,7 @@ def answer_question(repo_id, temperature, max_length, question):
63
  if repo_id != st.session_state.repo_id or temperature != st.session_state.temperature or max_length != st.session_state.max_length:
64
  try:
65
  with st.spinner("正在初始化 Gemma 模型..."):
66
- llm = HuggingFaceHub(repo_id=repo_id, model_kwargs={"temperature": temperature, "max_length": max_length})
67
  st.success("Gemma 模型初始化完成!")
68
  print("Gemma 模型初始化完成!")
69
  st.session_state.repo_id = repo_id
@@ -91,7 +91,7 @@ def answer_question(repo_id, temperature, max_length, question):
91
  print("本地数据集筛选完成!")
92
 
93
  with st.spinner("正在生成答案..."):
94
- answer = llm.invoke(prompt)
95
  # 去掉 prompt 的内容
96
  answer = answer.replace(prompt, "").strip()
97
  st.success("答案已经生成!")
@@ -113,6 +113,13 @@ with col2:
113
 
114
  st.divider()
115
 
 
 
 
 
 
 
 
116
  col3, col4 = st.columns(2)
117
  with col3:
118
  if st.button("使用原数据集中的随机问题"):
@@ -120,9 +127,7 @@ with col3:
120
  random_index = random.randint(0, dataset_size - 1)
121
  # 读取随机问题
122
  random_question = st.session_state.data_list[random_index]["Question"]
123
- random_question = st.session_state.converter.convert(random_question)
124
  origin_answer = st.session_state.data_list[random_index]["Answer"]
125
- origin_answer = st.session_state.converter.convert(origin_answer)
126
  print('[]' + str(random_index) + '/' + str(dataset_size) + ']random_question: ' + random_question)
127
  print('origin_answer: ' + origin_answer)
128
 
@@ -130,20 +135,12 @@ with col3:
130
  st.write(random_question)
131
  st.write("原始答案:")
132
  st.write(origin_answer)
133
- result = answer_question(gemma, float(temperature), int(max_length), random_question)
134
- print('prompt: ' + result["prompt"])
135
- print('answer: ' + result["answer"])
136
- st.write("生成答案:")
137
- st.write(result["answer"])
138
 
139
  with col4:
140
- question = st.text_area("请输入问题", "Gemma 有哪些特点?")
141
  if st.button("提交输入的问题"):
142
  if not question:
143
  st.warning("请输入问题!")
144
  else:
145
- result = answer_question(gemma, float(temperature), int(max_length), question)
146
- print('prompt: ' + result["prompt"])
147
- print('answer: ' + result["answer"])
148
- st.write("生成答案:")
149
- st.write(result["answer"])
 
15
  if not st.session_state.data_list:
16
  try:
17
  with st.spinner("正在读取数据库..."):
18
+ converter = OpenCC('tw2s') # 'tw2s.json' 表示繁体中文到简体中文的转换
19
  dataset = load_dataset("rorubyy/attack_on_titan_wiki_chinese")
20
  data_list = []
21
  answer_list = []
22
  for example in dataset["train"]:
23
+ converted_answer = converter.convert(example["Answer"])
24
+ converted_question = converter.convert(example["Question"])
25
  answer_list.append(converted_answer)
26
  data_list.append({"Question": converted_question, "Answer": converted_answer})
27
  st.session_state.answer_list = answer_list
 
63
  if repo_id != st.session_state.repo_id or temperature != st.session_state.temperature or max_length != st.session_state.max_length:
64
  try:
65
  with st.spinner("正在初始化 Gemma 模型..."):
66
+ st.session_state.llm = HuggingFaceHub(repo_id=repo_id, model_kwargs={"temperature": temperature, "max_length": max_length})
67
  st.success("Gemma 模型初始化完成!")
68
  print("Gemma 模型初始化完成!")
69
  st.session_state.repo_id = repo_id
 
91
  print("本地数据集筛选完成!")
92
 
93
  with st.spinner("正在生成答案..."):
94
+ answer = st.session_state.llm.invoke(prompt)
95
  # 去掉 prompt 的内容
96
  answer = answer.replace(prompt, "").strip()
97
  st.success("答案已经生成!")
 
113
 
114
  st.divider()
115
 
116
+ def generate_answer(repo_id, temperature, max_length, question):
117
+ result = answer_question(repo_id, float(temperature), int(max_length), question)
118
+ print('prompt: ' + result["prompt"])
119
+ print('answer: ' + result["answer"])
120
+ st.write("生成答案:")
121
+ st.write(result["answer"])
122
+
123
  col3, col4 = st.columns(2)
124
  with col3:
125
  if st.button("使用原数据集中的随机问题"):
 
127
  random_index = random.randint(0, dataset_size - 1)
128
  # 读取随机问题
129
  random_question = st.session_state.data_list[random_index]["Question"]
 
130
  origin_answer = st.session_state.data_list[random_index]["Answer"]
 
131
  print('[]' + str(random_index) + '/' + str(dataset_size) + ']random_question: ' + random_question)
132
  print('origin_answer: ' + origin_answer)
133
 
 
135
  st.write(random_question)
136
  st.write("原始答案:")
137
  st.write(origin_answer)
138
+ generate_answer(gemma, float(temperature), int(max_length), random_question)
 
 
 
 
139
 
140
  with col4:
141
+ question = st.text_area("请输入问题", "《进击的巨人》中都有哪些主要角色?")
142
  if st.button("提交输入的问题"):
143
  if not question:
144
  st.warning("请输入问题!")
145
  else:
146
+ generate_answer(gemma, float(temperature), int(max_length), question)