Update app.py
Browse files
app.py
CHANGED
|
@@ -103,35 +103,35 @@ def display_optimization_results(result_data):
|
|
| 103 |
success = result["succeed"]
|
| 104 |
prompt = result["prompt"]
|
| 105 |
|
| 106 |
-
with st.expander(f"
|
| 107 |
-
st.markdown("
|
| 108 |
st.code(prompt, language="text")
|
| 109 |
st.markdown("<br>", unsafe_allow_html=True)
|
| 110 |
|
| 111 |
col1, col2 = st.columns(2)
|
| 112 |
with col1:
|
| 113 |
-
st.markdown(f"
|
| 114 |
with col2:
|
| 115 |
-
st.markdown(f"
|
| 116 |
|
| 117 |
-
st.markdown("
|
| 118 |
for idx, answer in enumerate(result["answers"]):
|
| 119 |
-
st.markdown(f"
|
| 120 |
st.text(answer["question"])
|
| 121 |
-
st.markdown("
|
| 122 |
st.text(answer["answer"])
|
| 123 |
st.markdown("---")
|
| 124 |
|
| 125 |
-
#
|
| 126 |
success_count = sum(1 for r in result_data if r["succeed"])
|
| 127 |
total_rounds = len(result_data)
|
| 128 |
|
| 129 |
-
st.markdown("###
|
| 130 |
col1, col2 = st.columns(2)
|
| 131 |
with col1:
|
| 132 |
-
st.metric("
|
| 133 |
with col2:
|
| 134 |
-
st.metric("
|
| 135 |
|
| 136 |
|
| 137 |
def main():
|
|
@@ -144,69 +144,68 @@ def main():
|
|
| 144 |
"""
|
| 145 |
<div style="background-color: #f0f2f6; padding: 20px; border-radius: 10px; margin-bottom: 25px">
|
| 146 |
<div style="display: flex; justify-content: space-between; align-items: center; margin-bottom: 10px">
|
| 147 |
-
<h1 style="margin: 0;">SPO |
|
| 148 |
</div>
|
| 149 |
<div style="display: flex; gap: 20px; align-items: center">
|
| 150 |
<a href="https://arxiv.org/pdf/2502.06855" target="_blank" style="text-decoration: none;">
|
| 151 |
-
<img src="https://img.shields.io/badge
|
| 152 |
</a>
|
| 153 |
<a href="https://github.com/geekan/MetaGPT/blob/main/examples/spo/README.md" target="_blank" style="text-decoration: none;">
|
| 154 |
-
<img src="https://img.shields.io/badge/GitHub
|
| 155 |
</a>
|
| 156 |
-
<span style="color: #666;"
|
| 157 |
</div>
|
| 158 |
</div>
|
| 159 |
""",
|
| 160 |
unsafe_allow_html=True
|
| 161 |
)
|
| 162 |
|
| 163 |
-
#
|
| 164 |
with st.sidebar:
|
| 165 |
-
st.header("
|
| 166 |
|
| 167 |
-
#
|
| 168 |
settings_path = Path("metagpt/ext/spo/settings")
|
| 169 |
existing_templates = [f.stem for f in settings_path.glob("*.yaml")]
|
| 170 |
-
|
| 171 |
-
template_mode = st.radio("Template Mode", ["Use Existing", "Create New"])
|
| 172 |
|
| 173 |
existing_templates = get_all_templates()
|
| 174 |
|
| 175 |
-
if template_mode == "Use Existing":
|
| 176 |
-
template_name = st.selectbox("
|
| 177 |
is_new_template = False
|
| 178 |
else:
|
| 179 |
-
template_name = st.text_input("
|
| 180 |
is_new_template = True
|
| 181 |
|
| 182 |
-
# LLM
|
| 183 |
-
st.subheader("LLM
|
| 184 |
|
| 185 |
-
base_url = st.text_input("
|
| 186 |
-
api_key = st.text_input("API
|
| 187 |
|
| 188 |
opt_model = st.selectbox(
|
| 189 |
-
"
|
| 190 |
)
|
| 191 |
-
opt_temp = st.slider("
|
| 192 |
|
| 193 |
eval_model = st.selectbox(
|
| 194 |
-
"
|
| 195 |
)
|
| 196 |
-
eval_temp = st.slider("
|
| 197 |
|
| 198 |
exec_model = st.selectbox(
|
| 199 |
-
"
|
| 200 |
)
|
| 201 |
-
exec_temp = st.slider("
|
| 202 |
|
| 203 |
-
#
|
| 204 |
-
st.subheader("
|
| 205 |
-
initial_round = st.number_input("
|
| 206 |
-
max_rounds = st.number_input("
|
| 207 |
|
| 208 |
-
#
|
| 209 |
-
st.header("
|
| 210 |
|
| 211 |
if template_name:
|
| 212 |
template_real_name = get_template_path(template_name, is_new_template)
|
|
@@ -220,30 +219,30 @@ def main():
|
|
| 220 |
st.session_state.current_template = template_name
|
| 221 |
st.session_state.qas = template_data.get("qa", [])
|
| 222 |
|
| 223 |
-
#
|
| 224 |
-
prompt = st.text_area("
|
| 225 |
-
requirements = st.text_area("
|
| 226 |
|
| 227 |
-
#
|
| 228 |
-
st.subheader("
|
| 229 |
|
| 230 |
-
#
|
| 231 |
-
if st.button("
|
| 232 |
st.session_state.qas.append({"question": "", "answer": ""})
|
| 233 |
|
| 234 |
-
#
|
| 235 |
new_qas = []
|
| 236 |
for i in range(len(st.session_state.qas)):
|
| 237 |
-
st.markdown(f"
|
| 238 |
col1, col2, col3 = st.columns([45, 45, 10])
|
| 239 |
|
| 240 |
with col1:
|
| 241 |
question = st.text_area(
|
| 242 |
-
f"
|
| 243 |
)
|
| 244 |
with col2:
|
| 245 |
answer = st.text_area(
|
| 246 |
-
f"
|
| 247 |
)
|
| 248 |
with col3:
|
| 249 |
if st.button("🗑️", key=f"delete_{i}"):
|
|
@@ -252,20 +251,20 @@ def main():
|
|
| 252 |
|
| 253 |
new_qas.append({"question": question, "answer": answer})
|
| 254 |
|
| 255 |
-
#
|
| 256 |
-
if st.button("
|
| 257 |
new_template_data = {"prompt": prompt, "requirements": requirements, "count": None, "qa": new_qas}
|
| 258 |
|
| 259 |
save_yaml_template(template_path, new_template_data, is_new_template)
|
| 260 |
|
| 261 |
st.session_state.qas = new_qas
|
| 262 |
-
st.success(f"
|
| 263 |
|
| 264 |
-
st.subheader("
|
| 265 |
preview_data = {"qa": new_qas, "requirements": requirements, "prompt": prompt}
|
| 266 |
st.code(yaml.dump(preview_data, allow_unicode=True), language="yaml")
|
| 267 |
|
| 268 |
-
st.subheader("
|
| 269 |
log_container = st.empty()
|
| 270 |
|
| 271 |
class StreamlitSink:
|
|
@@ -289,8 +288,8 @@ def main():
|
|
| 289 |
)
|
| 290 |
_logger.add(METAGPT_ROOT / "logs/{time:YYYYMMDD}.txt", level="DEBUG")
|
| 291 |
|
| 292 |
-
#
|
| 293 |
-
if st.button("
|
| 294 |
try:
|
| 295 |
# Initialize LLM
|
| 296 |
SPO_LLM.initialize(
|
|
@@ -315,37 +314,35 @@ def main():
|
|
| 315 |
with st.spinner("Optimizing prompts..."):
|
| 316 |
optimizer.optimize()
|
| 317 |
|
| 318 |
-
st.success("
|
| 319 |
-
|
| 320 |
-
st.header("Optimization Results")
|
| 321 |
-
|
| 322 |
prompt_path = optimizer.root_path / "prompts"
|
| 323 |
result_data = optimizer.data_utils.load_results(prompt_path)
|
| 324 |
|
| 325 |
st.session_state.optimization_results = result_data
|
| 326 |
|
| 327 |
except Exception as e:
|
| 328 |
-
st.error(f"
|
| 329 |
-
_logger.error(f"
|
| 330 |
|
| 331 |
if st.session_state.optimization_results:
|
| 332 |
-
st.header("Optimization Results")
|
| 333 |
display_optimization_results(st.session_state.optimization_results)
|
| 334 |
|
| 335 |
st.markdown("---")
|
| 336 |
-
st.subheader("
|
| 337 |
col1, col2 = st.columns(2)
|
| 338 |
|
| 339 |
with col1:
|
| 340 |
-
test_prompt = st.text_area("
|
| 341 |
|
| 342 |
with col2:
|
| 343 |
-
test_question = st.text_area("
|
| 344 |
|
| 345 |
-
if st.button("
|
| 346 |
if test_prompt and test_question:
|
| 347 |
try:
|
| 348 |
-
with st.spinner("
|
| 349 |
SPO_LLM.initialize(
|
| 350 |
optimize_kwargs={"model": opt_model, "temperature": opt_temp, "base_url": base_url,
|
| 351 |
"api_key": api_key},
|
|
@@ -368,13 +365,13 @@ def main():
|
|
| 368 |
finally:
|
| 369 |
loop.close()
|
| 370 |
|
| 371 |
-
st.subheader("
|
| 372 |
st.markdown(response)
|
| 373 |
|
| 374 |
except Exception as e:
|
| 375 |
-
st.error(f"
|
| 376 |
else:
|
| 377 |
-
st.warning("
|
| 378 |
|
| 379 |
|
| 380 |
if __name__ == "__main__":
|
|
|
|
| 103 |
success = result["succeed"]
|
| 104 |
prompt = result["prompt"]
|
| 105 |
|
| 106 |
+
with st.expander(f"轮次 {round_num} {':white_check_mark:' if success else ':x:'}"):
|
| 107 |
+
st.markdown("**提示词:**")
|
| 108 |
st.code(prompt, language="text")
|
| 109 |
st.markdown("<br>", unsafe_allow_html=True)
|
| 110 |
|
| 111 |
col1, col2 = st.columns(2)
|
| 112 |
with col1:
|
| 113 |
+
st.markdown(f"**状态:** {'成功 ✅ ' if success else '失败 ❌ '}")
|
| 114 |
with col2:
|
| 115 |
+
st.markdown(f"**令牌数:** {result['tokens']}")
|
| 116 |
|
| 117 |
+
st.markdown("**回答:**")
|
| 118 |
for idx, answer in enumerate(result["answers"]):
|
| 119 |
+
st.markdown(f"**问题 {idx + 1}:**")
|
| 120 |
st.text(answer["question"])
|
| 121 |
+
st.markdown("**答案:**")
|
| 122 |
st.text(answer["answer"])
|
| 123 |
st.markdown("---")
|
| 124 |
|
| 125 |
+
# 总结
|
| 126 |
success_count = sum(1 for r in result_data if r["succeed"])
|
| 127 |
total_rounds = len(result_data)
|
| 128 |
|
| 129 |
+
st.markdown("### 总结")
|
| 130 |
col1, col2 = st.columns(2)
|
| 131 |
with col1:
|
| 132 |
+
st.metric("总轮次", total_rounds)
|
| 133 |
with col2:
|
| 134 |
+
st.metric("成功轮次", success_count)
|
| 135 |
|
| 136 |
|
| 137 |
def main():
|
|
|
|
| 144 |
"""
|
| 145 |
<div style="background-color: #f0f2f6; padding: 20px; border-radius: 10px; margin-bottom: 25px">
|
| 146 |
<div style="display: flex; justify-content: space-between; align-items: center; margin-bottom: 10px">
|
| 147 |
+
<h1 style="margin: 0;">SPO | 自监督提示词优化 🤖</h1>
|
| 148 |
</div>
|
| 149 |
<div style="display: flex; gap: 20px; align-items: center">
|
| 150 |
<a href="https://arxiv.org/pdf/2502.06855" target="_blank" style="text-decoration: none;">
|
| 151 |
+
<img src="https://img.shields.io/badge/论文-PDF-red.svg" alt="论文">
|
| 152 |
</a>
|
| 153 |
<a href="https://github.com/geekan/MetaGPT/blob/main/examples/spo/README.md" target="_blank" style="text-decoration: none;">
|
| 154 |
+
<img src="https://img.shields.io/badge/GitHub-仓库-blue.svg" alt="GitHub">
|
| 155 |
</a>
|
| 156 |
+
<span style="color: #666;">一个自监督提示词优化框架</span>
|
| 157 |
</div>
|
| 158 |
</div>
|
| 159 |
""",
|
| 160 |
unsafe_allow_html=True
|
| 161 |
)
|
| 162 |
|
| 163 |
+
# 侧边栏配置
|
| 164 |
with st.sidebar:
|
| 165 |
+
st.header("配置")
|
| 166 |
|
| 167 |
+
# 模板选择/创建
|
| 168 |
settings_path = Path("metagpt/ext/spo/settings")
|
| 169 |
existing_templates = [f.stem for f in settings_path.glob("*.yaml")]
|
| 170 |
+
template_mode = st.radio("模板模式", ["使用现有", "创建新模板"])
|
|
|
|
| 171 |
|
| 172 |
existing_templates = get_all_templates()
|
| 173 |
|
| 174 |
+
if template_mode == "使用现有":
|
| 175 |
+
template_name = st.selectbox("选择模板", existing_templates)
|
| 176 |
is_new_template = False
|
| 177 |
else:
|
| 178 |
+
template_name = st.text_input("新模板名称")
|
| 179 |
is_new_template = True
|
| 180 |
|
| 181 |
+
# LLM 设置
|
| 182 |
+
st.subheader("LLM 设置")
|
| 183 |
|
| 184 |
+
base_url = st.text_input("基础 URL", value="https://api.example.com")
|
| 185 |
+
api_key = st.text_input("API 密钥", type="password")
|
| 186 |
|
| 187 |
opt_model = st.selectbox(
|
| 188 |
+
"优化模型", ["gpt-4o-mini", "gpt-4o", "deepseek-chat", "claude-3-5-sonnet-20240620"], index=0
|
| 189 |
)
|
| 190 |
+
opt_temp = st.slider("优化温度", 0.0, 1.0, 0.7)
|
| 191 |
|
| 192 |
eval_model = st.selectbox(
|
| 193 |
+
"评估模型", ["gpt-4o-mini", "gpt-4o", "deepseek-chat", "claude-3-5-sonnet-20240620"], index=0
|
| 194 |
)
|
| 195 |
+
eval_temp = st.slider("评估温度", 0.0, 1.0, 0.3)
|
| 196 |
|
| 197 |
exec_model = st.selectbox(
|
| 198 |
+
"执行模型", ["gpt-4o-mini", "gpt-4o", "deepseek-chat", "claude-3-5-sonnet-20240620"], index=0
|
| 199 |
)
|
| 200 |
+
exec_temp = st.slider("执行温度", 0.0, 1.0, 0.0)
|
| 201 |
|
| 202 |
+
# 优化器设置
|
| 203 |
+
st.subheader("优化器设置")
|
| 204 |
+
initial_round = st.number_input("初始轮次", 1, 100, 1)
|
| 205 |
+
max_rounds = st.number_input("最大轮次", 1, 100, 10)
|
| 206 |
|
| 207 |
+
# 主要内容区域
|
| 208 |
+
st.header("模板配置")
|
| 209 |
|
| 210 |
if template_name:
|
| 211 |
template_real_name = get_template_path(template_name, is_new_template)
|
|
|
|
| 219 |
st.session_state.current_template = template_name
|
| 220 |
st.session_state.qas = template_data.get("qa", [])
|
| 221 |
|
| 222 |
+
# 编辑模板部分
|
| 223 |
+
prompt = st.text_area("提示词", template_data.get("prompt", ""), height=100)
|
| 224 |
+
requirements = st.text_area("要求", template_data.get("requirements", ""), height=100)
|
| 225 |
|
| 226 |
+
# 问答部分
|
| 227 |
+
st.subheader("问答示例")
|
| 228 |
|
| 229 |
+
# 添加新问答按钮
|
| 230 |
+
if st.button("添加新问答"):
|
| 231 |
st.session_state.qas.append({"question": "", "answer": ""})
|
| 232 |
|
| 233 |
+
# 编辑问答
|
| 234 |
new_qas = []
|
| 235 |
for i in range(len(st.session_state.qas)):
|
| 236 |
+
st.markdown(f"**问答 #{i + 1}**")
|
| 237 |
col1, col2, col3 = st.columns([45, 45, 10])
|
| 238 |
|
| 239 |
with col1:
|
| 240 |
question = st.text_area(
|
| 241 |
+
f"问题 {i + 1}", st.session_state.qas[i].get("question", ""), key=f"q_{i}", height=100
|
| 242 |
)
|
| 243 |
with col2:
|
| 244 |
answer = st.text_area(
|
| 245 |
+
f"答案 {i + 1}", st.session_state.qas[i].get("answer", ""), key=f"a_{i}", height=100
|
| 246 |
)
|
| 247 |
with col3:
|
| 248 |
if st.button("🗑️", key=f"delete_{i}"):
|
|
|
|
| 251 |
|
| 252 |
new_qas.append({"question": question, "answer": answer})
|
| 253 |
|
| 254 |
+
# 保存模板按钮
|
| 255 |
+
if st.button("保存模板"):
|
| 256 |
new_template_data = {"prompt": prompt, "requirements": requirements, "count": None, "qa": new_qas}
|
| 257 |
|
| 258 |
save_yaml_template(template_path, new_template_data, is_new_template)
|
| 259 |
|
| 260 |
st.session_state.qas = new_qas
|
| 261 |
+
st.success(f"模板已保存到 {template_path}")
|
| 262 |
|
| 263 |
+
st.subheader("当前模板预览")
|
| 264 |
preview_data = {"qa": new_qas, "requirements": requirements, "prompt": prompt}
|
| 265 |
st.code(yaml.dump(preview_data, allow_unicode=True), language="yaml")
|
| 266 |
|
| 267 |
+
st.subheader("优化日志")
|
| 268 |
log_container = st.empty()
|
| 269 |
|
| 270 |
class StreamlitSink:
|
|
|
|
| 288 |
)
|
| 289 |
_logger.add(METAGPT_ROOT / "logs/{time:YYYYMMDD}.txt", level="DEBUG")
|
| 290 |
|
| 291 |
+
# 开始优化按钮
|
| 292 |
+
if st.button("开始优化"):
|
| 293 |
try:
|
| 294 |
# Initialize LLM
|
| 295 |
SPO_LLM.initialize(
|
|
|
|
| 314 |
with st.spinner("Optimizing prompts..."):
|
| 315 |
optimizer.optimize()
|
| 316 |
|
| 317 |
+
st.success("优化完成!")
|
| 318 |
+
st.header("优化结果")
|
|
|
|
|
|
|
| 319 |
prompt_path = optimizer.root_path / "prompts"
|
| 320 |
result_data = optimizer.data_utils.load_results(prompt_path)
|
| 321 |
|
| 322 |
st.session_state.optimization_results = result_data
|
| 323 |
|
| 324 |
except Exception as e:
|
| 325 |
+
st.error(f"发生错误:{str(e)}")
|
| 326 |
+
_logger.error(f"优化过程中出错:{str(e)}")
|
| 327 |
|
| 328 |
if st.session_state.optimization_results:
|
| 329 |
+
st.header("优化结果")
|
| 330 |
display_optimization_results(st.session_state.optimization_results)
|
| 331 |
|
| 332 |
st.markdown("---")
|
| 333 |
+
st.subheader("测试优化后的提示词")
|
| 334 |
col1, col2 = st.columns(2)
|
| 335 |
|
| 336 |
with col1:
|
| 337 |
+
test_prompt = st.text_area("优化后的提示词", value="", height=200, key="test_prompt")
|
| 338 |
|
| 339 |
with col2:
|
| 340 |
+
test_question = st.text_area("你的问题", value="", height=200, key="test_question")
|
| 341 |
|
| 342 |
+
if st.button("测试提示词"):
|
| 343 |
if test_prompt and test_question:
|
| 344 |
try:
|
| 345 |
+
with st.spinner("正在生成回答..."):
|
| 346 |
SPO_LLM.initialize(
|
| 347 |
optimize_kwargs={"model": opt_model, "temperature": opt_temp, "base_url": base_url,
|
| 348 |
"api_key": api_key},
|
|
|
|
| 365 |
finally:
|
| 366 |
loop.close()
|
| 367 |
|
| 368 |
+
st.subheader("回答:")
|
| 369 |
st.markdown(response)
|
| 370 |
|
| 371 |
except Exception as e:
|
| 372 |
+
st.error(f"生成回答时出错:{str(e)}")
|
| 373 |
else:
|
| 374 |
+
st.warning("请输入提示词和问题。")
|
| 375 |
|
| 376 |
|
| 377 |
if __name__ == "__main__":
|