Jeremy Live
commited on
Commit
路
1fe11cb
1
Parent(s):
3c8de53
v1
Browse files- app.py +126 -1
- requirements.txt +2 -0
app.py
CHANGED
@@ -3,8 +3,15 @@ import sys
|
|
3 |
import re
|
4 |
import gradio as gr
|
5 |
import json
|
6 |
-
|
|
|
|
|
|
|
7 |
import logging
|
|
|
|
|
|
|
|
|
8 |
|
9 |
try:
|
10 |
# Intentar importar dependencias opcionales
|
@@ -29,6 +36,70 @@ logger = logging.getLogger(__name__)
|
|
29 |
|
30 |
# Configure logging
|
31 |
logging.basicConfig(level=logging.INFO)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
32 |
logger = logging.getLogger(__name__)
|
33 |
|
34 |
def check_environment():
|
@@ -400,7 +471,61 @@ async def stream_agent_response(question: str, chat_history: List) -> List[Dict]
|
|
400 |
db_connection, _ = setup_database_connection()
|
401 |
if db_connection:
|
402 |
query_result = execute_sql_query(sql_query, db_connection)
|
|
|
|
|
403 |
response_text += f"\n\n### 馃攳 Resultado de la consulta:\n```sql\n{sql_query}\n```\n\n{query_result}"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
404 |
else:
|
405 |
response_text += "\n\n鈿狅笍 No se pudo conectar a la base de datos para ejecutar la consulta."
|
406 |
|
|
|
3 |
import re
|
4 |
import gradio as gr
|
5 |
import json
|
6 |
+
import tempfile
|
7 |
+
import base64
|
8 |
+
import io
|
9 |
+
from typing import List, Dict, Any, Optional, Tuple, Union
|
10 |
import logging
|
11 |
+
import pandas as pd
|
12 |
+
import plotly.express as px
|
13 |
+
import plotly.graph_objects as go
|
14 |
+
from plotly.subplots import make_subplots
|
15 |
|
16 |
try:
|
17 |
# Intentar importar dependencias opcionales
|
|
|
36 |
|
37 |
# Configure logging
|
38 |
logging.basicConfig(level=logging.INFO)
|
39 |
+
|
40 |
+
def generate_chart(data: Union[Dict, List[Dict], pd.DataFrame],
|
41 |
+
chart_type: str,
|
42 |
+
x: str,
|
43 |
+
y: str = None,
|
44 |
+
title: str = "",
|
45 |
+
x_label: str = None,
|
46 |
+
y_label: str = None) -> str:
|
47 |
+
"""
|
48 |
+
Generate a chart from data and return it as an HTML string with embedded image.
|
49 |
+
|
50 |
+
Args:
|
51 |
+
data: The data to plot (can be a list of dicts or a pandas DataFrame)
|
52 |
+
chart_type: Type of chart to generate (bar, line, pie, scatter, histogram)
|
53 |
+
x: Column name for x-axis
|
54 |
+
y: Column name for y-axis (not needed for pie charts)
|
55 |
+
title: Chart title
|
56 |
+
x_label: Label for x-axis
|
57 |
+
y_label: Label for y-axis
|
58 |
+
|
59 |
+
Returns:
|
60 |
+
HTML string with embedded chart image
|
61 |
+
"""
|
62 |
+
try:
|
63 |
+
# Convert data to DataFrame if it's a list of dicts
|
64 |
+
if isinstance(data, list):
|
65 |
+
df = pd.DataFrame(data)
|
66 |
+
elif isinstance(data, dict):
|
67 |
+
df = pd.DataFrame([data])
|
68 |
+
else:
|
69 |
+
df = data
|
70 |
+
|
71 |
+
if not isinstance(df, pd.DataFrame):
|
72 |
+
return "Error: Data must be a dictionary, list of dictionaries, or pandas DataFrame"
|
73 |
+
|
74 |
+
# Generate the appropriate chart type
|
75 |
+
if chart_type == 'bar':
|
76 |
+
fig = px.bar(df, x=x, y=y, title=title)
|
77 |
+
elif chart_type == 'line':
|
78 |
+
fig = px.line(df, x=x, y=y, title=title)
|
79 |
+
elif chart_type == 'pie':
|
80 |
+
fig = px.pie(df, names=x, values=y, title=title)
|
81 |
+
elif chart_type == 'scatter':
|
82 |
+
fig = px.scatter(df, x=x, y=y, title=title)
|
83 |
+
elif chart_type == 'histogram':
|
84 |
+
fig = px.histogram(df, x=x, title=title)
|
85 |
+
else:
|
86 |
+
return "Error: Unsupported chart type. Use 'bar', 'line', 'pie', 'scatter', or 'histogram'"
|
87 |
+
|
88 |
+
# Update layout
|
89 |
+
fig.update_layout(
|
90 |
+
xaxis_title=x_label or x,
|
91 |
+
yaxis_title=y_label or (y if y != x else ''),
|
92 |
+
title=title or f"{chart_type.capitalize()} Chart of {x} vs {y}" if y else f"{chart_type.capitalize()} Chart of {x}",
|
93 |
+
template="plotly_white"
|
94 |
+
)
|
95 |
+
|
96 |
+
# Convert to HTML
|
97 |
+
return f"<div style='width: 100%;'>{fig.to_html(full_html=False)}</div>"
|
98 |
+
|
99 |
+
except Exception as e:
|
100 |
+
error_msg = f"Error generating chart: {str(e)}"
|
101 |
+
logger.error(error_msg, exc_info=True)
|
102 |
+
return f"<div style='color: red;'>{error_msg}</div>"
|
103 |
logger = logging.getLogger(__name__)
|
104 |
|
105 |
def check_environment():
|
|
|
471 |
db_connection, _ = setup_database_connection()
|
472 |
if db_connection:
|
473 |
query_result = execute_sql_query(sql_query, db_connection)
|
474 |
+
|
475 |
+
# Add the query and its result to the response
|
476 |
response_text += f"\n\n### 馃攳 Resultado de la consulta:\n```sql\n{sql_query}\n```\n\n{query_result}"
|
477 |
+
|
478 |
+
# Try to generate a chart if the result is tabular
|
479 |
+
try:
|
480 |
+
if isinstance(query_result, str) and '|' in query_result and '---' in query_result:
|
481 |
+
# Convert markdown table to DataFrame
|
482 |
+
from io import StringIO
|
483 |
+
import re
|
484 |
+
|
485 |
+
# Clean up the markdown table
|
486 |
+
lines = [line.strip() for line in query_result.split('\n')
|
487 |
+
if line.strip() and '---' not in line and '|' in line]
|
488 |
+
if len(lines) > 1: # At least header + 1 data row
|
489 |
+
# Get column names from the first line
|
490 |
+
columns = [col.strip() for col in lines[0].split('|')[1:-1]]
|
491 |
+
# Get data rows
|
492 |
+
data = []
|
493 |
+
for line in lines[1:]:
|
494 |
+
values = [val.strip() for val in line.split('|')[1:-1]]
|
495 |
+
if len(values) == len(columns):
|
496 |
+
data.append(dict(zip(columns, values)))
|
497 |
+
|
498 |
+
if data and len(columns) >= 2:
|
499 |
+
# Generate a chart based on the data
|
500 |
+
chart_type = 'bar' # Default chart type
|
501 |
+
if len(columns) == 2:
|
502 |
+
# Simple bar chart for two columns
|
503 |
+
chart_html = generate_chart(
|
504 |
+
data=data,
|
505 |
+
chart_type=chart_type,
|
506 |
+
x=columns[0],
|
507 |
+
y=columns[1],
|
508 |
+
title=f"{columns[1]} por {columns[0]}",
|
509 |
+
x_label=columns[0],
|
510 |
+
y_label=columns[1]
|
511 |
+
)
|
512 |
+
response_text += f"\n\n### 馃搳 Visualizaci贸n:\n{chart_html}"
|
513 |
+
elif len(columns) > 2:
|
514 |
+
# For multiple columns, create a line chart
|
515 |
+
chart_html = generate_chart(
|
516 |
+
data=data,
|
517 |
+
chart_type='line',
|
518 |
+
x=columns[0],
|
519 |
+
y=columns[1],
|
520 |
+
title=f"{', '.join(columns[1:])} por {columns[0]}",
|
521 |
+
x_label=columns[0],
|
522 |
+
y_label=", ".join(columns[1:])
|
523 |
+
)
|
524 |
+
response_text += f"\n\n### 馃搳 Visualizaci贸n:\n{chart_html}"
|
525 |
+
except Exception as e:
|
526 |
+
logger.error(f"Error generating chart: {str(e)}", exc_info=True)
|
527 |
+
# Don't fail the whole request if chart generation fails
|
528 |
+
response_text += "\n\n鈿狅笍 No se pudo generar la visualizaci贸n de los datos."
|
529 |
else:
|
530 |
response_text += "\n\n鈿狅笍 No se pudo conectar a la base de datos para ejecutar la consulta."
|
531 |
|
requirements.txt
CHANGED
@@ -12,3 +12,5 @@ python-dotenv==1.0.1
|
|
12 |
pymysql==1.1.0
|
13 |
numpy==1.26.4
|
14 |
python-multipart>=0.0.18 # Required by gradio
|
|
|
|
|
|
12 |
pymysql==1.1.0
|
13 |
numpy==1.26.4
|
14 |
python-multipart>=0.0.18 # Required by gradio
|
15 |
+
plotly==5.18.0 # For interactive charts
|
16 |
+
kaleido==0.2.1 # For saving plotly charts as images
|