Spaces:
Sleeping
Sleeping
# --- Imports --- | |
import os | |
import json | |
import gradio as gr | |
import matplotlib.pyplot as plt | |
import tempfile | |
import io | |
import re | |
import networkx as nx | |
from datetime import datetime | |
from contextlib import redirect_stdout | |
from langchain_community.chat_models import ChatOpenAI | |
from langchain.agents import initialize_agent, Tool, AgentType | |
from langchain_community.tools import DuckDuckGoSearchRun | |
import openai | |
# --- Pre-create rating log file --- | |
log_filename = "rating_log.txt"
try:
    # Mode "x" creates the file exclusively: it fails instead of truncating
    # when the log already exists, so there is no window between an
    # "exists?" check and the create in which another process could lose data.
    with open(log_filename, "x", encoding="utf-8") as f:
        f.write("=== Rating Log Initialized ===\n")
except FileExistsError:
    # The log is already there (e.g. from an earlier run); keep its contents.
    pass
# --- Setup API keys --- | |
# The planner / specialist LLM requires an OpenAI key; fail fast without it.
openai_api_key = os.environ.get("OPENAI_API_KEY")
if not openai_api_key:
    raise ValueError("OPENAI_API_KEY environment variable is not set.")
llm = ChatOpenAI(temperature=0, model="gpt-4", openai_api_key=openai_api_key)

# The rating client talks to OpenRouter and reads its key from the
# "OpenRouter" environment variable.
openrouter_key = os.environ.get("OpenRouter")
if openrouter_key:
    openai_rater = openai.OpenAI(api_key=openrouter_key, base_url="https://openrouter.ai/api/v1")
else:
    # Do NOT construct the client with api_key=None: the openai client then
    # falls back to OPENAI_API_KEY, which would send the OpenAI secret to
    # openrouter.ai. Leave the rater unset instead; rating calls will fail
    # and be reported as a rating error while the rest of the app keeps working.
    print("Warning: 'OpenRouter' environment variable is not set; answer rating is disabled.")
    openai_rater = None
# --- Helpers --- | |
def safe_file_or_none(path):
    """Return ``path`` when it is a string naming an existing regular file, else ``None``."""
    if not isinstance(path, str):
        return None
    if not os.path.isfile(path):
        return None
    return path
def remove_ansi(text):
    """Return ``text`` with ANSI ``ESC[...m`` (colour/style) sequences removed."""
    sgr_sequence = re.compile(r'\x1b\[[0-9;]*m')
    cleaned = sgr_sequence.sub('', text)
    return cleaned
# --- Rating function --- | |
def rate_answer_rater(question, final_answer):
    """Ask the rater model to score an answer and append the result to the rating log.

    Args:
        question: The user's original question. It is included in the rater
            prompt so the answer is judged against what was asked, and it is
            written to the log.
        final_answer: The answer text to be rated.

    Returns:
        The rater's response text (stripped), or a string starting with
        ``"Rating error: "`` if the rating request itself failed.
    """
    prompt = (
        "Rate this answer 1-5 stars with explanation:\n\n"
        f"Question: {question}\n\nAnswer: {final_answer}"
    )
    try:
        response = openai_rater.chat.completions.create(
            model="mistral/ministral-8b",
            messages=[{"role": "user", "content": prompt}],
        )
        # The message content can be None (e.g. an empty completion);
        # treat that as an empty rating rather than crashing on .strip().
        rating_text = (response.choices[0].message.content or "").strip()
    except Exception as e:
        # Best-effort feature: report the failure in place of a rating.
        return f"Rating error: {e}"

    # Logging is kept separate from the request so that a failed write
    # cannot discard a rating that was already obtained.
    timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    try:
        with open(log_filename, "a", encoding="utf-8") as log_file:
            log_file.write(f"\n---\nTimestamp: {timestamp}\nQuestion: {question}\nAnswer: {final_answer}\nRating Response: {rating_text}\n")
    except OSError as log_error:
        print(f"Could not write rating log: {log_error}")
    return rating_text
# --- Word map generation --- | |
def generate_wordmap(text):
    """Render a word cloud of ``text`` to a temporary PNG file.

    Args:
        text: The text to visualise.

    Returns:
        The path of the generated PNG, or ``None`` when ``text`` is empty or
        not a string, the ``wordcloud`` package is unavailable, or rendering
        fails. The caller owns the returned file (it is not auto-deleted).
    """
    if not isinstance(text, str) or not text.strip():
        return None

    image_path = None
    try:
        from wordcloud import WordCloud
        cloud = WordCloud(width=800, height=400, background_color="white").generate(text)
        # Only the unique name is needed; close the handle immediately so it
        # is not leaked and the file can be reopened for writing by path.
        with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
            image_path = tmp.name
        cloud.to_file(image_path)
        return image_path
    except Exception:
        # Best-effort visual: any failure means "no word map". Remove the
        # orphaned temp file so failed attempts do not accumulate on disk.
        if image_path is not None:
            try:
                os.remove(image_path)
            except OSError:
                pass
        return None
# --- Reasoning tree generation --- | |
def generate_reasoning_tree(trace: str):
    """Draw the non-blank lines of ``trace`` as a chain of nodes and save it as a PNG.

    A "Start" node is followed by one node per non-blank line
    (``Step_1``, ``Step_2``, ...), each linked to the previous one and
    labelled with the line's text. The "Start" node itself has no label.

    Args:
        trace: Multi-line reasoning trace text.

    Returns:
        The path of the generated temporary PNG, or ``None`` if anything
        fails. The caller owns the returned file (it is not auto-deleted).
    """
    fig = None
    image_path = None
    try:
        graph = nx.DiGraph()
        previous = "Start"
        graph.add_node(previous)

        steps = (line for line in trace.splitlines() if line.strip())
        for number, line in enumerate(steps, start=1):
            node_id = f"Step_{number}"
            graph.add_node(node_id, label=line)
            graph.add_edge(previous, node_id)
            previous = node_id

        pos = nx.spring_layout(graph)
        fig, ax = plt.subplots(figsize=(10, 5))
        labels = nx.get_node_attributes(graph, 'label')
        nx.draw(graph, pos, with_labels=False, node_size=3000, node_color='lightblue', ax=ax)
        nx.draw_networkx_labels(graph, pos, labels=labels, font_size=8, ax=ax)

        # Only the unique name is needed; close the handle right away so it
        # is not leaked while the image is written by path.
        with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
            image_path = tmp.name
        # Save this specific figure rather than pyplot's implicit "current"
        # figure, which is shared global state.
        fig.savefig(image_path)
        return image_path
    except Exception:
        # Best-effort visual: any failure means "no tree". Remove the
        # orphaned temp file so failed attempts do not accumulate on disk.
        if image_path is not None:
            try:
                os.remove(image_path)
            except OSError:
                pass
        return None
    finally:
        # Always release the figure, including when drawing or saving failed.
        if fig is not None:
            plt.close(fig)
# --- Define specialist tools --- | |
def simple_tool(prompt_prefix):
    """Build a one-argument callable that sends ``prompt_prefix`` plus the query to ``llm``.

    The returned function formats the prompt as the prefix, a blank line,
    then the query, and returns whatever ``llm.predict`` returns.
    """
    def run_with_prefix(query):
        full_prompt = f"{prompt_prefix}\n\n{query}"
        return llm.predict(full_prompt)

    return run_with_prefix
# Specialist tools: each wraps the shared LLM behind a role-setting prompt prefix.
legal_tool = Tool(
    name="LegalAnalystAgent",
    func=simple_tool("You are a legal analyst."),
    description="Legal analysis",
)
financial_tool = Tool(
    name="FinancialMarketsAgent",
    func=simple_tool("You are a financial markets analyst."),
    description="Financial insights",
)
lending_tool = Tool(
    name="LendingSpecialistAgent",
    func=simple_tool("You are a lending specialist."),
    description="Lending guidance",
)
credit_tool = Tool(
    name="CreditSpecialistAgent",
    func=simple_tool("You are a credit specialist."),
    description="Credit evaluation",
)

# Web search tool backed by DuckDuckGo.
research_agent = DuckDuckGoSearchRun()
research_tool = Tool(
    name="ResearchAgent",
    func=research_agent.run,
    description="Web search",
)

# Tools offered to the planner, which picks among them when "Auto" is selected.
planner_tools = [research_tool, legal_tool, financial_tool, lending_tool, credit_tool]
planner_agent = initialize_agent(
    planner_tools,
    llm=llm,
    agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
    verbose=True,
)
# --- Main agent logic --- | |
def agent_query(user_input, selected_agent, retry_threshold):
    """Run the selected agent on ``user_input`` and build the outputs for the UI.

    Args:
        user_input: The user's question.
        selected_agent: ``"Auto"`` to use the planner agent, or the name of
            one specific agent to call directly.
        retry_threshold: Value of the "Retry Rating Threshold" slider. It is
            accepted for interface compatibility but is not used by this
            function.

    Returns:
        A 5-tuple: (trace text with the rating appended, word-map image path
        or None, reasoning-tree image path or None, visibility update for the
        word map, visibility update for the tree). On any exception the first
        element is ``"Error: <message>"``, both paths are None, and both
        images are hidden.
    """
    try:
        # The planner runs with verbose=True and prints its reasoning to
        # stdout; capture that output so it can be shown as the trace.
        # NOTE(review): redirect_stdout swaps sys.stdout process-wide, so
        # overlapping requests may interleave their traces - confirm whether
        # this app serves concurrent requests.
        captured = io.StringIO()
        with redirect_stdout(captured):
            if selected_agent == "Auto":
                result = planner_agent.run(user_input)
            else:
                agent_map = {
                    "ResearchAgent": research_agent.run,
                    "LegalAnalystAgent": legal_tool.func,
                    "FinancialMarketsAgent": financial_tool.func,
                    "LendingSpecialistAgent": lending_tool.func,
                    "CreditSpecialistAgent": credit_tool.func,
                }
                agent_fn = agent_map.get(selected_agent)
                result = agent_fn(user_input) if agent_fn else "Invalid agent selected."

        # Strip terminal colour codes once, up front, so they do not leak
        # into the displayed text or the word map (previously only the
        # reasoning tree received cleaned text).
        trace_output = remove_ansi(captured.getvalue())

        final_answer_str = result or "(No answer produced.)"
        if "Final Answer:" not in trace_output:
            trace_output += f"\n\nFinal Answer: {final_answer_str}"

        wordmap_path = generate_wordmap(trace_output)
        reasoning_tree_path = generate_reasoning_tree(trace_output)
        rating_text = rate_answer_rater(user_input, final_answer_str)

        return (
            trace_output + f"\n\n⭐ Rating: {rating_text}",
            wordmap_path,
            reasoning_tree_path,
            gr.update(visible=bool(wordmap_path)),
            gr.update(visible=bool(reasoning_tree_path)),
        )
    except Exception as e:
        # UI boundary: show the error in the text box instead of raising.
        return f"Error: {e}", None, None, gr.update(visible=False), gr.update(visible=False)
# --- Gradio UI --- | |
def _rating_log_path():
    """Return the rating log path that is offered for download on page load."""
    return "rating_log.txt"


with gr.Blocks(theme=gr.themes.Glass()) as demo:
    # Page header.
    gr.Markdown("# Financial Services Multi-Agent Assistant")
    gr.Markdown("Select an agent or use Auto for automatic routing.")

    # Inputs: question, agent choice, and the rating-threshold slider.
    with gr.Row():
        input_box = gr.Textbox(label="Your Question")
    with gr.Row():
        agent_choices = [
            "Auto", "ResearchAgent", "LegalAnalystAgent",
            "FinancialMarketsAgent", "LendingSpecialistAgent", "CreditSpecialistAgent",
        ]
        agent_selector = gr.Dropdown(label="Choose Agent", choices=agent_choices, value="Auto")
    with gr.Row():
        retry_slider = gr.Slider(
            label="Retry Rating Threshold",
            minimum=1.0,
            maximum=5.0,
            step=0.1,
            value=4.0,
        )

    # Actions: submit button and the rating-log download widget.
    with gr.Row():
        submit_btn = gr.Button("Submit")
        download_btn = gr.File(label="Download Rating Log")

    # Outputs: the text trace and the two generated images.
    with gr.Row():
        output_text = gr.Textbox(label="Agent Reasoning + Final Answer", lines=20)
    with gr.Row():
        output_wordmap = gr.Image(label="Word Map", visible=True)
        output_tree_image = gr.Image(label="Reasoning Tree", visible=True)

    # agent_query returns five values: text, two image paths, then one
    # visibility update per image - so each image component is listed twice.
    query_outputs = [
        output_text, output_wordmap, output_tree_image,
        output_wordmap, output_tree_image,
    ]
    submit_btn.click(
        fn=agent_query,
        inputs=[input_box, agent_selector, retry_slider],
        outputs=query_outputs,
    )

    # Populate the download widget with the log file when the page loads.
    demo.load(_rating_log_path, None, download_btn)

if __name__ == "__main__":
    demo.launch(share=True)