Jeremy Live commited on
Commit
56fb484
·
1 Parent(s): c6828a0

API solved v2

Browse files
Files changed (1) hide show
  1. app.py +181 -0
app.py CHANGED
@@ -14,6 +14,81 @@ import plotly.graph_objects as go
14
  from plotly.subplots import make_subplots
15
  from shared import initialize_llm, setup_database_connection, create_agent
16
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
  # ... (resto del código existente sin cambios) ...
18
 
19
  def create_application():
@@ -121,6 +196,112 @@ def create_application():
121
 
122
  return demo
123
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
124
  # Create the application
125
  demo = create_application()
126
 
 
14
  from plotly.subplots import make_subplots
15
  from shared import initialize_llm, setup_database_connection, create_agent
16
 
17
+ try:
18
+ from langchain_core.messages import HumanMessage, AIMessage
19
+ LANGCHAIN_AVAILABLE = True
20
+ except ImportError:
21
+ # Fallback if langchain not available
22
+ class HumanMessage:
23
+ def __init__(self, content):
24
+ self.content = content
25
+
26
+ class AIMessage:
27
+ def __init__(self, content):
28
+ self.content = content
29
+ LANGCHAIN_AVAILABLE = False
30
+
31
+ # Configure logging
32
+ logging.basicConfig(level=logging.INFO)
33
+ logger = logging.getLogger(__name__)
34
+
35
+ def create_ui():
36
+ """Create the Gradio UI components."""
37
+ # Custom CSS for styling
38
+ custom_css = """
39
+ .gradio-container {
40
+ max-width: 1200px !important;
41
+ }
42
+ .chat-container {
43
+ height: 600px;
44
+ overflow-y: auto;
45
+ }
46
+ .chart-container {
47
+ height: 600px;
48
+ overflow-y: auto;
49
+ }
50
+ """
51
+
52
+ with gr.Blocks(css=custom_css, title="🤖 SQL Database Assistant") as demo:
53
+ gr.Markdown("# 🤖 SQL Database Assistant")
54
+ gr.Markdown("Ask questions about your database in natural language!")
55
+
56
+ with gr.Row():
57
+ with gr.Column(scale=2):
58
+ chatbot = gr.Chatbot(
59
+ label="Chat",
60
+ elem_classes="chat-container",
61
+ type="messages",
62
+ height=500
63
+ )
64
+
65
+ with gr.Row():
66
+ question_input = gr.Textbox(
67
+ label="Ask your question",
68
+ placeholder="Type your question here...",
69
+ lines=2,
70
+ scale=4
71
+ )
72
+ submit_button = gr.Button("Send", variant="primary", scale=1)
73
+
74
+ streaming_output_display = gr.Markdown(visible=False)
75
+
76
+ with gr.Column(scale=1):
77
+ chart_display = gr.Plot(
78
+ label="Charts",
79
+ elem_classes="chart-container",
80
+ height=500
81
+ )
82
+
83
+ # Status indicators
84
+ with gr.Row():
85
+ status_indicator = gr.Markdown(
86
+ "### ✅ System Status\n- **Database**: Ready\n- **AI Model**: Ready\n- **API**: Available",
87
+ elem_id="status"
88
+ )
89
+
90
+ return demo, chatbot, chart_display, question_input, submit_button, streaming_output_display
91
+
92
  # ... (resto del código existente sin cambios) ...
93
 
94
  def create_application():
 
196
 
197
  return demo
198
 
199
+ async def stream_agent_response(question: str, chat_history: List[List[str]]) -> Tuple[str, Optional[go.Figure]]:
200
+ """Process a question through the SQL agent and return response with optional chart."""
201
+
202
+ # Initialize components
203
+ llm, llm_error = initialize_llm()
204
+ if llm_error:
205
+ return f"**LLM Error:** {llm_error}", None
206
+
207
+ db_connection, db_error = setup_database_connection()
208
+ if db_error:
209
+ return f"**Database Error:** {db_error}", None
210
+
211
+ agent, agent_error = create_agent(llm, db_connection)
212
+ if agent_error:
213
+ return f"**Agent Error:** {agent_error}", None
214
+
215
+ try:
216
+ logger.info(f"Processing question: {question}")
217
+
218
+ # Prepare the input with chat history
219
+ input_data = {"input": question}
220
+ if chat_history:
221
+ # Format chat history for the agent
222
+ formatted_history = []
223
+ for human, ai in chat_history:
224
+ formatted_history.extend([
225
+ HumanMessage(content=human),
226
+ AIMessage(content=ai)
227
+ ])
228
+ input_data["chat_history"] = formatted_history
229
+
230
+ # Execute the agent
231
+ response = agent.invoke(input_data)
232
+
233
+ # Extract the response text
234
+ if hasattr(response, 'output') and response.output:
235
+ response_text = response.output
236
+ elif isinstance(response, dict) and 'output' in response:
237
+ response_text = response['output']
238
+ elif isinstance(response, str):
239
+ response_text = response
240
+ else:
241
+ response_text = str(response)
242
+
243
+ # Check for SQL queries in the response
244
+ sql_pattern = r'```sql\s*(.*?)\s*```'
245
+ sql_matches = re.findall(sql_pattern, response_text, re.DOTALL)
246
+
247
+ chart_fig = None
248
+ if sql_matches:
249
+ # Try to execute the SQL and create a chart
250
+ try:
251
+ sql_query = sql_matches[-1].strip()
252
+ logger.info(f"Executing SQL query: {sql_query}")
253
+
254
+ # Execute the query
255
+ result = db_connection.run(sql_query)
256
+
257
+ if result:
258
+ # Convert result to DataFrame
259
+ import pandas as pd
260
+ if isinstance(result, list) and result:
261
+ df = pd.DataFrame(result)
262
+
263
+ # Determine chart type based on data
264
+ if len(df.columns) >= 2:
265
+ # Simple bar chart for categorical data
266
+ fig = go.Figure()
267
+
268
+ if len(df) <= 20: # Bar chart for smaller datasets
269
+ fig.add_trace(go.Bar(
270
+ x=df.iloc[:, 0],
271
+ y=df.iloc[:, 1],
272
+ name=str(df.columns[1])
273
+ ))
274
+ fig.update_layout(
275
+ title=f"{df.columns[0]} vs {df.columns[1]}",
276
+ xaxis_title=str(df.columns[0]),
277
+ yaxis_title=str(df.columns[1])
278
+ )
279
+ else: # Line chart for larger datasets
280
+ fig.add_trace(go.Scatter(
281
+ x=df.iloc[:, 0],
282
+ y=df.iloc[:, 1],
283
+ mode='lines+markers',
284
+ name=str(df.columns[1])
285
+ ))
286
+ fig.update_layout(
287
+ title=f"{df.columns[0]} vs {df.columns[1]}",
288
+ xaxis_title=str(df.columns[0]),
289
+ yaxis_title=str(df.columns[1])
290
+ )
291
+
292
+ chart_fig = fig
293
+
294
+ except Exception as e:
295
+ logger.warning(f"Could not create chart: {e}")
296
+ # Continue without chart
297
+
298
+ return response_text, chart_fig
299
+
300
+ except Exception as e:
301
+ error_msg = f"**Error processing question:** {str(e)}"
302
+ logger.error(error_msg, exc_info=True)
303
+ return error_msg, None
304
+
305
  # Create the application
306
  demo = create_application()
307