kgauvin603 commited on
Commit
a5f781d
·
verified ·
1 Parent(s): db5fd14

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +203 -0
app.py ADDED
@@ -0,0 +1,203 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # --- Imports ---
2
+ import os
3
+ import json
4
+ import gradio as gr
5
+ import matplotlib.pyplot as plt
6
+ import tempfile
7
+ import io
8
+ import re
9
+ import networkx as nx
10
+ from datetime import datetime
11
+ from contextlib import redirect_stdout
12
+ from langchain_community.chat_models import ChatOpenAI
13
+ from langchain.agents import initialize_agent, Tool, AgentType
14
+ from langchain_community.tools import DuckDuckGoSearchRun
15
+ import openai
16
+
17
+ # --- Pre-create rating log file ---
18
+ log_filename = "rating_log.txt"
19
+ if not os.path.exists(log_filename):
20
+ with open(log_filename, "w", encoding="utf-8") as f:
21
+ f.write("=== Rating Log Initialized ===\n")
22
+
23
+ # --- Setup API keys ---
24
+ openai_api_key = os.environ.get("OPENAI_API_KEY")
25
+ if not openai_api_key:
26
+ raise ValueError("OPENAI_API_KEY environment variable is not set.")
27
+ llm = ChatOpenAI(temperature=0, model="gpt-4", openai_api_key=openai_api_key)
28
+
29
+ openrouter_key = os.environ.get("OpenRouter")
30
+ openai_rater = openai.OpenAI(api_key=openrouter_key, base_url="https://openrouter.ai/api/v1")
31
+
32
+ # --- Helpers ---
33
+ def safe_file_or_none(path):
34
+ return path if isinstance(path, str) and os.path.isfile(path) else None
35
+
36
+ def remove_ansi(text):
37
+ return re.sub(r'\x1b\[[0-9;]*m', '', text)
38
+
39
+ # --- Rating function ---
40
+ def rate_answer_rater(question, final_answer):
41
+ try:
42
+ prompt = f"Rate this answer 1-5 stars with explanation:\n\n{final_answer}"
43
+ response = openai_rater.chat.completions.create(
44
+ model="mistral/ministral-8b",
45
+ messages=[{"role": "user", "content": prompt}]
46
+ )
47
+ rating_text = response.choices[0].message.content.strip()
48
+ timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
49
+ with open("rating_log.txt", "a", encoding="utf-8") as log_file:
50
+ log_file.write(f"\n---\nTimestamp: {timestamp}\nQuestion: {question}\nAnswer: {final_answer}\nRating Response: {rating_text}\n")
51
+ return rating_text
52
+ except Exception as e:
53
+ return f"Rating error: {e}"
54
+
55
+ # --- Word map generation ---
56
+ def generate_wordmap(text):
57
+ try:
58
+ from wordcloud import WordCloud
59
+ wc = WordCloud(width=800, height=400, background_color="white").generate(text)
60
+ tmpfile = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
61
+ wc.to_file(tmpfile.name)
62
+ return tmpfile.name
63
+ except Exception as e:
64
+ return None
65
+
66
+ # --- Reasoning tree generation ---
67
+ def generate_reasoning_tree(trace: str):
68
+ try:
69
+ G = nx.DiGraph()
70
+ step = 0
71
+ last_node = "Start"
72
+ G.add_node(last_node)
73
+
74
+ for line in trace.splitlines():
75
+ if line.strip():
76
+ step += 1
77
+ node_id = f"Step_{step}"
78
+ G.add_node(node_id, label=line)
79
+ G.add_edge(last_node, node_id)
80
+ last_node = node_id
81
+
82
+ pos = nx.spring_layout(G)
83
+ fig, ax = plt.subplots(figsize=(10, 5))
84
+ labels = nx.get_node_attributes(G, 'label')
85
+ nx.draw(G, pos, with_labels=False, node_size=3000, node_color='lightblue', ax=ax)
86
+ nx.draw_networkx_labels(G, pos, labels=labels, font_size=8, ax=ax)
87
+ tmpfile = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
88
+ plt.savefig(tmpfile.name)
89
+ plt.close(fig)
90
+ return tmpfile.name
91
+ except Exception as e:
92
+ return None
93
+
94
+ # --- Define specialist tools ---
95
+
96
+ def simple_tool(prompt_prefix):
97
+ return lambda query: llm.predict(f"{prompt_prefix}\n\n{query}")
98
+
99
+ legal_tool = Tool("LegalAnalystAgent", simple_tool("You are a legal analyst."), "Legal analysis")
100
+ financial_tool = Tool("FinancialMarketsAgent", simple_tool("You are a financial markets analyst."), "Financial insights")
101
+ lending_tool = Tool("LendingSpecialistAgent", simple_tool("You are a lending specialist."), "Lending guidance")
102
+ credit_tool = Tool("CreditSpecialistAgent", simple_tool("You are a credit specialist."), "Credit evaluation")
103
+ research_agent = DuckDuckGoSearchRun()
104
+ research_tool = Tool("ResearchAgent", research_agent.run, "Web search")
105
+
106
+ planner_tools = [
107
+ research_tool,
108
+ legal_tool,
109
+ financial_tool,
110
+ lending_tool,
111
+ credit_tool
112
+ ]
113
+
114
+ planner_agent = initialize_agent(
115
+ planner_tools,
116
+ llm=llm,
117
+ agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
118
+ verbose=True
119
+ )
120
+
121
+ # --- Main agent logic ---
122
+ def agent_query(user_input, selected_agent, retry_threshold):
123
+ try:
124
+ f = io.StringIO()
125
+ with redirect_stdout(f):
126
+
127
+ if selected_agent == "Auto":
128
+ result = planner_agent.run(user_input)
129
+ trace_output = f.getvalue()
130
+ else:
131
+ agent_map = {
132
+ "ResearchAgent": research_agent.run,
133
+ "LegalAnalystAgent": legal_tool.func,
134
+ "FinancialMarketsAgent": financial_tool.func,
135
+ "LendingSpecialistAgent": lending_tool.func,
136
+ "CreditSpecialistAgent": credit_tool.func
137
+ }
138
+ agent_fn = agent_map.get(selected_agent)
139
+ result = agent_fn(user_input) if agent_fn else "Invalid agent selected."
140
+ trace_output = f.getvalue()
141
+
142
+ final_answer_str = result or "(No answer produced.)"
143
+ if "Final Answer:" not in trace_output:
144
+ trace_output += f"\n\nFinal Answer: {final_answer_str}"
145
+
146
+ wordmap_path = generate_wordmap(trace_output)
147
+ reasoning_tree_path = generate_reasoning_tree(remove_ansi(trace_output))
148
+
149
+ rating_text = rate_answer_rater(user_input, final_answer_str)
150
+
151
+ return (
152
+ trace_output + f"\n\n⭐ Rating: {rating_text}",
153
+ wordmap_path,
154
+ reasoning_tree_path,
155
+ gr.update(visible=bool(wordmap_path)),
156
+ gr.update(visible=bool(reasoning_tree_path))
157
+ )
158
+
159
+ except Exception as e:
160
+ return f"Error: {e}", None, None, gr.update(visible=False), gr.update(visible=False)
161
+
162
+ # --- Gradio UI ---
163
+
164
+ demo = gr.Blocks(theme=gr.themes.Glass())
165
+
166
+ with demo:
167
+ gr.Markdown("# Financial Services Multi-Agent Assistant")
168
+ gr.Markdown("Select an agent or use Auto for automatic routing.")
169
+
170
+ with gr.Row():
171
+ input_box = gr.Textbox(label="Your Question")
172
+ with gr.Row():
173
+ agent_selector = gr.Dropdown(label="Choose Agent", choices=[
174
+ "Auto", "ResearchAgent", "LegalAnalystAgent",
175
+ "FinancialMarketsAgent", "LendingSpecialistAgent", "CreditSpecialistAgent"
176
+ ], value="Auto")
177
+ with gr.Row():
178
+ retry_slider = gr.Slider(label="Retry Rating Threshold", minimum=1.0, maximum=5.0, step=0.1, value=4.0)
179
+
180
+ with gr.Row():
181
+ submit_btn = gr.Button("Submit")
182
+ download_btn = gr.File(label="Download Rating Log")
183
+
184
+ with gr.Row():
185
+ output_text = gr.Textbox(label="Agent Reasoning + Final Answer", lines=20)
186
+
187
+ with gr.Row():
188
+ output_wordmap = gr.Image(label="Word Map", visible=True)
189
+ output_tree_image = gr.Image(label="Reasoning Tree", visible=True)
190
+
191
+ submit_btn.click(
192
+ fn=agent_query,
193
+ inputs=[input_box, agent_selector, retry_slider],
194
+ outputs=[
195
+ output_text, output_wordmap, output_tree_image,
196
+ output_wordmap, output_tree_image
197
+ ]
198
+ )
199
+
200
+ demo.load(lambda: "rating_log.txt", None, download_btn)
201
+
202
+ if __name__ == "__main__":
203
+ demo.launch(share=True)