File size: 12,145 Bytes
f639c56 dc2d325 24f29f0 f639c56 1fe11cb f639c56 1fe11cb c6828a0 f639c56 56fb484 0f16c64 f639c56 713a6e6 683b6ad f639c56 0f16c64 c6828a0 0f16c64 c6828a0 0f16c64 d89427c 713a6e6 d89427c 713a6e6 d89427c 332a246 d89427c 332a246 713a6e6 683b6ad 55ff70d 683b6ad d89427c 683b6ad d89427c 332a246 d89427c 332a246 d89427c fa13896 683b6ad d89427c 332a246 683b6ad d89427c 332a246 6aebc39 332a246 d89427c 683b6ad 713a6e6 332a246 713a6e6 683b6ad 713a6e6 332a246 713a6e6 683b6ad 713a6e6 f639c56 56fb484 713a6e6 f639c56 cc2a812 f639c56 cc2a812 f639c56 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 |
import os
import sys
import re
import gradio as gr
import json
import tempfile
import base64
import io
from typing import List, Dict, Any, Optional, Tuple, Union
import logging
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from shared import initialize_llm, setup_database_connection, create_agent
try:
    from langchain_core.messages import HumanMessage, AIMessage
except ImportError:
    # langchain_core is optional: fall back to minimal stand-ins that only
    # carry the message text on `.content`.
    class HumanMessage:
        """Fallback stand-in for ``langchain_core.messages.HumanMessage``."""

        def __init__(self, content):
            self.content = content

    class AIMessage:
        """Fallback stand-in for ``langchain_core.messages.AIMessage``."""

        def __init__(self, content):
            self.content = content

    LANGCHAIN_AVAILABLE = False
else:
    LANGCHAIN_AVAILABLE = True

# Module-wide logging: INFO-level basic configuration plus this module's logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
def create_ui():
    """Build the Gradio layout for the SQL assistant.

    Lays out a chat column (chatbot, question textbox, Send button and a hidden
    Markdown area) next to a chart column, followed by a static status panel.

    Returns:
        Tuple ``(demo, chatbot, chart_display, question_input, submit_button,
        streaming_output_display)``: the ``gr.Blocks`` app plus the components
        that the caller wires event handlers to.
    """
    # Custom CSS: cap the page width and make the chat/chart panes scrollable
    # with a fixed height.
    custom_css = """
.gradio-container {
max-width: 1200px !important;
}
.chat-container {
height: 600px;
overflow-y: auto;
}
.chart-container {
height: 600px;
overflow-y: auto;
}
"""
    with gr.Blocks(css=custom_css, title="🤖 SQL Database Assistant") as demo:
        gr.Markdown("# 🤖 SQL Database Assistant")
        gr.Markdown("Ask questions about your database in natural language!")
        with gr.Row():
            with gr.Column(scale=2):
                chatbot = gr.Chatbot(
                    label="Chat",
                    elem_classes="chat-container",
                    # History is a list of {"role": ..., "content": ...} dicts.
                    type="messages",
                    height=500,
                )
                with gr.Row():
                    question_input = gr.Textbox(
                        label="Ask your question",
                        placeholder="Type your question here...",
                        lines=2,
                        scale=4,
                    )
                    submit_button = gr.Button("Send", variant="primary", scale=1)
                # Hidden placeholder; returned to the caller but never shown here.
                streaming_output_display = gr.Markdown(visible=False)
            with gr.Column(scale=1):
                # NOTE: gr.Plot does not document a `height` parameter, and passing
                # one can raise TypeError. The pane height comes from the
                # `.chart-container` CSS rule above instead.
                chart_display = gr.Plot(
                    label="Charts",
                    elem_classes="chart-container",
                )
        # Status panel. The text is a fixed string; nothing here probes the
        # database or the model.
        with gr.Row():
            gr.Markdown(
                "### ✅ System Status\n- **Database**: Ready\n- **AI Model**: Ready\n- **API**: Available",
                elem_id="status",
            )
    return demo, chatbot, chart_display, question_input, submit_button, streaming_output_display
def create_application():
    """Create the Gradio app and wire the chat events to the SQL agent.

    The layout comes from ``create_ui()``. Handlers are registered for pressing
    Enter in the question box and for clicking Send. When the ``SPACE_ID``
    environment variable is set, the function also makes a best-effort attempt
    to mount the app onto ``api.app``.

    Returns:
        The ``gr.Blocks`` application (never the mounted host app).
    """
    demo, chatbot, chart_display, question_input, submit_button, _streaming_output = create_ui()

    def user_message(user_input: str, chat_history: List[Dict[str, str]]) -> Tuple[str, List[Dict[str, str]]]:
        """Append the user's question to the messages-format history and clear the textbox.

        Whitespace-only input leaves the history unchanged.
        """
        if not user_input.strip():
            return "", chat_history
        logger.info(f"User message: {user_input}")
        if chat_history is None:
            chat_history = []
        chat_history.append({"role": "user", "content": user_input})
        return "", chat_history

    async def bot_response(chat_history: List[Dict[str, str]]) -> Tuple[List[Dict[str, str]], Optional[go.Figure]]:
        """Answer the trailing user message and append the assistant reply.

        Returns the (possibly updated) history and an optional chart figure.
        Nothing is appended, and the figure is None, unless the last message is
        a non-empty user turn.
        """
        if not chat_history:
            return chat_history, None
        last = chat_history[-1]
        if not isinstance(last, dict) or last.get("role") != "user" or not last.get("content"):
            return chat_history, None
        try:
            question = last["content"]
            logger.info(f"Processing question: {question}")
            # Convert earlier turns into [user, assistant] pairs for the agent.
            # Messages that are not part of a user->assistant pair are skipped.
            prior = chat_history[:-1]
            pair_history: List[List[str]] = []
            i = 0
            while i < len(prior) - 1:
                first, second = prior[i], prior[i + 1]
                if (
                    isinstance(first, dict)
                    and first.get("role") == "user"
                    and isinstance(second, dict)
                    and second.get("role") == "assistant"
                ):
                    pair_history.append([first.get("content", ""), second.get("content", "")])
                    i += 2
                else:
                    i += 1
            assistant_message, chart_fig = await stream_agent_response(question, pair_history)
            chat_history.append({"role": "assistant", "content": assistant_message})
            logger.info("Response generation complete")
            return chat_history, chart_fig
        except Exception as e:
            error_msg = f"## ❌ Error\n\nError al procesar la solicitud:\n\n```\n{str(e)}\n```"
            logger.error(error_msg, exc_info=True)
            # Surface the failure in the chat instead of leaving the question unanswered.
            chat_history.append({"role": "assistant", "content": error_msg})
            return chat_history, None

    with demo:
        # Enter in the textbox: record the question, then answer it. The answering
        # step is exposed in Gradio's API under the name "ask".
        question_input.submit(
            fn=user_message,
            inputs=[question_input, chatbot],
            outputs=[question_input, chatbot],
            queue=True,
        ).then(
            fn=bot_response,
            inputs=[chatbot],
            outputs=[chatbot, chart_display],
            api_name="ask",
        )
        # Send button: same two-step flow.
        submit_button.click(
            fn=user_message,
            inputs=[question_input, chatbot],
            outputs=[question_input, chatbot],
            queue=True,
        ).then(
            fn=bot_response,
            inputs=[chatbot],
            outputs=[chatbot, chart_display],
        )

    # On Hugging Face Spaces, try to mount this UI onto the project's `api.app`.
    # gr.mount_gradio_app takes (host_app, blocks, path) and returns the host
    # app, so `demo` must NOT be rebound: callers use it as a gr.Blocks (launch()).
    # The mount happens after the handlers are attached so the mounted UI includes
    # them.
    # NOTE(review): mount_gradio_app expects a FastAPI app, and it only matters
    # if `api.app` is what actually gets served — confirm against the api module.
    if os.getenv('SPACE_ID'):
        try:
            import api
            gr.mount_gradio_app(api.app, demo, path="/")
        except Exception:
            # Best-effort: a missing or incompatible `api` module must not take
            # down the Gradio UI.
            logger.warning("Could not mount the Gradio app onto api.app; serving the Gradio UI only", exc_info=True)
    return demo
# Fenced ```sql blocks in the agent's markdown answer; the last one is re-run for a chart.
_SQL_BLOCK_PATTERN = re.compile(r'```sql\s*(.*?)\s*```', re.DOTALL)
# A query is re-executed for charting only if it starts with SELECT or WITH...
_READ_ONLY_START = re.compile(r'^\s*(?:select|with)\b', re.IGNORECASE)
# ...and contains none of these state-changing keywords. This also catches
# data-modifying CTEs and SELECT ... INTO. A false positive only means "no chart".
_WRITE_KEYWORDS = re.compile(
    r'\b(?:insert|update|delete|merge|into|drop|alter|create|truncate|grant|revoke)\b',
    re.IGNORECASE,
)
# Results with at most this many rows are drawn as bars; larger ones as a line chart.
_MAX_BAR_ROWS = 20


def _is_read_only_query(sql: str) -> bool:
    """Return True if *sql* looks like a single read-only SELECT/WITH statement.

    The check is conservative on purpose. Anything with several statements or
    with a write keyword is rejected, so SQL taken from LLM output can never
    modify data or repeat a write the agent already performed.
    """
    statement = sql.strip().rstrip(';').strip()
    if not statement or ';' in statement:
        return False
    return bool(_READ_ONLY_START.match(statement)) and not _WRITE_KEYWORDS.search(statement)


def _extract_response_text(response: Any) -> str:
    """Pull the answer text out of whatever ``agent.invoke`` returned."""
    if hasattr(response, 'output') and response.output:
        return response.output
    if isinstance(response, dict) and 'output' in response:
        return response['output']
    if isinstance(response, str):
        return response
    return str(response)


def _rows_from_result(result: Any) -> Optional[list]:
    """Normalize a ``db_connection.run`` result into a non-empty list of rows, or None.

    A list is used as-is. A string is parsed with ``ast.literal_eval``, which
    handles a stringified list of rows (presumably LangChain ``SQLDatabase.run``
    output — confirm). Strings that are not plain literals, e.g. ones containing
    ``Decimal(...)`` or ``datetime(...)`` reprs, yield None.
    """
    import ast

    if isinstance(result, str):
        if not result.strip():
            return None
        try:
            result = ast.literal_eval(result)
        except (ValueError, SyntaxError):
            return None
    if isinstance(result, list) and result:
        return result
    return None


def _build_chart(df: pd.DataFrame) -> Optional[go.Figure]:
    """Plot column 1 against column 0 of *df*; None if it has fewer than two columns."""
    if len(df.columns) < 2:
        return None
    x_col, y_col = df.columns[0], df.columns[1]
    x_values, y_values = df.iloc[:, 0], df.iloc[:, 1]
    fig = go.Figure()
    if len(df) <= _MAX_BAR_ROWS:
        fig.add_trace(go.Bar(x=x_values, y=y_values, name=str(y_col)))
    else:
        fig.add_trace(go.Scatter(x=x_values, y=y_values, mode='lines+markers', name=str(y_col)))
    fig.update_layout(
        title=f"{x_col} vs {y_col}",
        xaxis_title=str(x_col),
        yaxis_title=str(y_col),
    )
    return fig


def _chart_for_answer(response_text: str, db_connection: Any) -> Optional[go.Figure]:
    """Re-run the last read-only ```sql block in *response_text* and chart its rows.

    Chart failures are best-effort: they are logged as warnings and yield None.
    """
    sql_matches = _SQL_BLOCK_PATTERN.findall(response_text)
    if not sql_matches:
        return None
    sql_query = sql_matches[-1].strip()
    if not _is_read_only_query(sql_query):
        logger.info("Skipping chart: SQL block is not a single read-only query")
        return None
    try:
        logger.info(f"Executing SQL query: {sql_query}")
        rows = _rows_from_result(db_connection.run(sql_query))
        if rows is None:
            return None
        return _build_chart(pd.DataFrame(rows))
    except Exception as e:
        logger.warning(f"Could not create chart: {e}")
        return None


async def stream_agent_response(question: str, chat_history: List[List[str]]) -> Tuple[str, Optional[go.Figure]]:
    """Answer *question* with the SQL agent and, when possible, chart the data behind it.

    Despite the name, the answer is returned in one piece; it is not streamed.

    Args:
        question: The user's natural-language question.
        chat_history: Earlier turns as ``[user_text, assistant_text]`` pairs. They
            are passed to the agent as alternating HumanMessage/AIMessage objects.

    Returns:
        ``(markdown_text, figure)``. If LLM, database or agent setup fails, or the
        agent raises, the text is an error message and the figure is None. The
        figure is also None when the answer has no chartable read-only ```sql block.
    """
    import asyncio

    llm, llm_error = initialize_llm()
    if llm_error:
        return f"**LLM Error:** {llm_error}", None
    db_connection, db_error = setup_database_connection()
    if db_error:
        return f"**Database Error:** {db_error}", None
    agent, agent_error = create_agent(llm, db_connection)
    if agent_error:
        return f"**Agent Error:** {agent_error}", None

    try:
        logger.info(f"Processing question: {question}")
        input_data: Dict[str, Any] = {"input": question}
        if chat_history:
            formatted_history = []
            for human, ai in chat_history:
                formatted_history.extend([HumanMessage(content=human), AIMessage(content=ai)])
            input_data["chat_history"] = formatted_history
        # agent.invoke and db_connection.run are synchronous (LLM and DB round
        # trips). Run them in a worker thread so they do not block the event loop
        # that serves every other Gradio session.
        response = await asyncio.to_thread(agent.invoke, input_data)
        response_text = _extract_response_text(response)
        chart_fig = await asyncio.to_thread(_chart_for_answer, response_text, db_connection)
        return response_text, chart_fig
    except Exception as e:
        error_msg = f"**Error processing question:** {str(e)}"
        logger.error(error_msg, exc_info=True)
        return error_msg, None
# Create the application once at import time. This module-level `demo` is what
# `get_app()` returns and what the `__main__` block launches.
demo = create_application()
# Hugging Face Spaces configuration
def get_app():
    """Return the module-level Gradio app, relabelled when running on Hugging Face Spaces.

    When the ``SPACE_ID`` environment variable is set, the app's ``title`` and
    ``description`` attributes are overwritten with demo text before returning.
    Otherwise the app is returned untouched.

    NOTE(review): gr.Blocks may not display ``description`` anywhere — confirm.
    """
    in_space = os.getenv('SPACE_ID')
    if not in_space:
        return demo
    demo.title = "🤖 Asistente de Base de Datos SQL (Demo)"
    demo.description = """
Este es un demo del asistente de base de datos SQL.
Para usar la versión completa con conexión a base de datos, clona este espacio y configura las variables de entorno.
"""
    return demo
# Local development entry point
if __name__ == "__main__":
    # Simplified launch settings for local development on Gradio 5.x: listen on
    # all interfaces on port 7860, with debug output and no public share link.
    launch_options = dict(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
        debug=True,
    )
    demo.launch(**launch_options)
|