File size: 12,145 Bytes
f639c56
dc2d325
24f29f0
f639c56
 
1fe11cb
 
 
 
f639c56
1fe11cb
 
 
 
c6828a0
f639c56
56fb484
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0f16c64
f639c56
713a6e6
 
 
683b6ad
f639c56
0f16c64
 
c6828a0
0f16c64
c6828a0
0f16c64
 
 
 
d89427c
 
713a6e6
 
d89427c
713a6e6
d89427c
332a246
 
d89427c
 
 
 
332a246
713a6e6
683b6ad
 
55ff70d
683b6ad
d89427c
 
 
 
683b6ad
d89427c
332a246
d89427c
332a246
d89427c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fa13896
683b6ad
d89427c
 
 
 
332a246
683b6ad
d89427c
332a246
6aebc39
332a246
d89427c
 
683b6ad
713a6e6
 
 
332a246
 
713a6e6
 
 
 
 
 
683b6ad
 
713a6e6
 
 
332a246
 
713a6e6
 
 
 
 
 
683b6ad
 
713a6e6
 
 
f639c56
56fb484
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
713a6e6
 
f639c56
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cc2a812
 
f639c56
 
 
cc2a812
f639c56
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
import os
import sys
import re
import gradio as gr
import json
import tempfile
import base64
import io
from typing import List, Dict, Any, Optional, Tuple, Union
import logging
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from shared import initialize_llm, setup_database_connection, create_agent

# Prefer the real LangChain message classes. When the package is missing,
# define minimal stand-ins that expose the same ``content`` attribute so the
# rest of the module can build chat history either way.
try:
    from langchain_core.messages import AIMessage, HumanMessage
except ImportError:
    LANGCHAIN_AVAILABLE = False

    class HumanMessage:
        """Minimal stand-in that stores the text of a user turn."""

        def __init__(self, content):
            self.content = content

    class AIMessage:
        """Minimal stand-in that stores the text of an assistant turn."""

        def __init__(self, content):
            self.content = content
else:
    LANGCHAIN_AVAILABLE = True

# Logging setup: named logger for this module, root logger at INFO level.
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)

def create_ui():
    """Build the Gradio layout and return the widgets that callers wire up.

    The page has a heading, a two-column row (chat panel plus input controls
    at scale 2, chart panel at scale 1) and a static status row.

    Returns:
        A 6-tuple ``(demo, chatbot, chart_display, question_input,
        submit_button, streaming_output_display)``.
    """
    # Stylesheet handed to gr.Blocks; the class names match the
    # ``elem_classes`` used on the chat and chart components below.
    page_css = """
    .gradio-container {
        max-width: 1200px !important;
    }
    .chat-container {
        height: 600px;
        overflow-y: auto;
    }
    .chart-container {
        height: 600px;
        overflow-y: auto;
    }
    """
    page_title = "🤖 SQL Database Assistant"
    status_text = "### ✅ System Status\n- **Database**: Ready\n- **AI Model**: Ready\n- **API**: Available"

    with gr.Blocks(title=page_title, css=page_css) as demo:
        gr.Markdown(f"# {page_title}")
        gr.Markdown("Ask questions about your database in natural language!")

        with gr.Row():
            # Left column: conversation history and the question controls.
            with gr.Column(scale=2):
                chatbot = gr.Chatbot(
                    type="messages",
                    label="Chat",
                    height=500,
                    elem_classes="chat-container",
                )

                with gr.Row():
                    question_input = gr.Textbox(
                        placeholder="Type your question here...",
                        label="Ask your question",
                        scale=4,
                        lines=2,
                    )
                    submit_button = gr.Button("Send", scale=1, variant="primary")

                # Hidden Markdown component; created here and returned to the caller.
                streaming_output_display = gr.Markdown(visible=False)

            # Right column: chart area.
            with gr.Column(scale=1):
                # NOTE(review): confirm that gr.Plot accepts `height` in the
                # installed Gradio version.
                chart_display = gr.Plot(
                    height=500,
                    label="Charts",
                    elem_classes="chart-container",
                )

        # Static status text shown under the main row.
        with gr.Row():
            gr.Markdown(status_text, elem_id="status")

    return demo, chatbot, chart_display, question_input, submit_button, streaming_output_display

# ... (rest of the existing code unchanged) ...

def create_application():
    """Create the Gradio application and wire its event handlers.

    Builds the UI through ``create_ui()`` and registers the same two-step
    chain on the textbox submit and on the button click: ``user_message``
    records the user's turn and clears the input, then ``bot_response`` asks
    the agent and updates the chat and the chart.

    Returns:
        The ``gr.Blocks`` object produced by ``create_ui()`` with the event
        handlers attached.
    """
    demo, chatbot, chart_display, question_input, submit_button, _ = create_ui()

    if os.getenv('SPACE_ID'):
        # The earlier code here ran
        #     demo = gr.mount_gradio_app(api.app, "/api", lambda: True)
        # which passes a path string where mount_gradio_app expects the Blocks
        # object, and then replaced `demo` with the call's result. The
        # `with demo:` wiring below and `demo.launch()` both need `demo` to
        # stay the Blocks object, so the call is not made here.
        # NOTE(review): serving the `api` app next to this UI presumably needs
        # a parent web app at the serving entry point — confirm and add there.
        logger.warning(
            "SPACE_ID is set, but the /api app is not mounted by create_application(); "
            "mount it from the serving entry point instead."
        )

    def user_message(user_input: str, chat_history: List[Dict[str, str]]) -> Tuple[str, List[Dict[str, str]]]:
        """Append the user's text to a copy of the history and clear the input box.

        Blank (or missing) input leaves the history as it was.
        """
        # Work on a copy so the list handed in by Gradio is never mutated.
        history = list(chat_history or [])

        if not user_input or not user_input.strip():
            return "", history

        logger.info("User message: %s", user_input)
        history.append({"role": "user", "content": user_input})
        return "", history

    def _to_pair_history(messages: List[Dict[str, str]]) -> List[List[str]]:
        """Collapse adjacent user/assistant messages into ``[user, assistant]`` pairs."""
        pairs: List[List[str]] = []
        idx = 0
        while idx < len(messages) - 1:
            first, second = messages[idx], messages[idx + 1]
            if (
                isinstance(first, dict)
                and first.get("role") == "user"
                and isinstance(second, dict)
                and second.get("role") == "assistant"
            ):
                pairs.append([first.get("content", ""), second.get("content", "")])
                idx += 2
            else:
                # Unpaired message: step past it and keep scanning.
                idx += 1
        return pairs

    async def bot_response(chat_history: List[Dict[str, str]]) -> Tuple[List[Dict[str, str]], Optional[go.Figure]]:
        """Answer the pending user turn; return the updated history and an optional chart.

        If the last message is not a non-empty user turn, the history is
        returned unchanged with no chart.
        """
        if not chat_history:
            return chat_history, None

        last = chat_history[-1]
        if not isinstance(last, dict) or last.get("role") != "user" or not last.get("content"):
            return chat_history, None

        history = list(chat_history)
        try:
            question = last["content"]
            logger.info("Processing question: %s", question)

            # The trailing user turn never forms a pair, so only earlier
            # exchanges are passed along as context.
            assistant_message, chart_fig = await stream_agent_response(
                question, _to_pair_history(history)
            )

            history.append({"role": "assistant", "content": assistant_message})
            logger.info("Response generation complete")
            return history, chart_fig

        except Exception as e:
            error_msg = f"## ❌ Error\n\nError al procesar la solicitud:\n\n```\n{str(e)}\n```"
            logger.error(error_msg, exc_info=True)
            # Show the failure in the chat instead of leaving the turn unanswered.
            history.append({"role": "assistant", "content": error_msg})
            return history, None

    # Event handlers
    with demo:
        # Enter in the textbox: the follow-up step is exposed as API endpoint "ask".
        question_input.submit(
            fn=user_message,
            inputs=[question_input, chatbot],
            outputs=[question_input, chatbot],
            queue=True
        ).then(
            fn=bot_response,
            inputs=[chatbot],
            outputs=[chatbot, chart_display],
            api_name="ask"
        )

        # Send button: same chain, without a named API endpoint.
        submit_button.click(
            fn=user_message,
            inputs=[question_input, chatbot],
            outputs=[question_input, chatbot],
            queue=True
        ).then(
            fn=bot_response,
            inputs=[chatbot],
            outputs=[chatbot, chart_display]
        )

    return demo

# Fenced ```sql blocks in the agent's answer; the group captures the query text.
_SQL_BLOCK_RE = re.compile(r'```sql\s*(.*?)\s*```', re.DOTALL)
# A query re-run for charting must start as a read (SELECT, or WITH for CTEs)...
_SQL_READ_START_RE = re.compile(r'^\s*(select|with)\b', re.IGNORECASE)
# ...and must not contain a data- or schema-changing keyword anywhere.
_SQL_WRITE_KEYWORD_RE = re.compile(
    r'\b(insert|update|delete|drop|alter|create|truncate|merge|grant|revoke|replace\s+into)\b',
    re.IGNORECASE,
)


def _is_read_only_query(sql_query: str) -> bool:
    """Return True when *sql_query* looks like a single read-only statement.

    This is a conservative text check, not a SQL parser: anything containing
    an inner ``;`` or a write keyword is rejected, even inside a literal.
    """
    statement = sql_query.strip().rstrip(';').strip()
    if not statement or ';' in statement:
        return False
    if not _SQL_READ_START_RE.match(statement):
        return False
    return _SQL_WRITE_KEYWORD_RE.search(statement) is None


def _extract_response_text(response: Any) -> str:
    """Pull the answer out of the agent result and make sure it is a string."""
    if hasattr(response, 'output') and response.output:
        text = response.output
    elif isinstance(response, dict) and 'output' in response:
        text = response['output']
    else:
        text = response
    # The regex search downstream needs a str, so coerce anything else.
    return text if isinstance(text, str) else str(text)


def _build_chart(rows: Any) -> Optional[go.Figure]:
    """Plot the first two columns of *rows*; return None when not chartable.

    Only a non-empty list with at least two columns is charted: a bar chart
    for up to 20 rows, a line chart with markers beyond that.
    """
    if not isinstance(rows, list) or not rows:
        return None

    df = pd.DataFrame(rows)
    if len(df.columns) < 2:
        return None

    x_label, y_label = str(df.columns[0]), str(df.columns[1])
    if len(df) <= 20:
        trace = go.Bar(x=df.iloc[:, 0], y=df.iloc[:, 1], name=y_label)
    else:
        trace = go.Scatter(
            x=df.iloc[:, 0],
            y=df.iloc[:, 1],
            mode='lines+markers',
            name=y_label,
        )

    fig = go.Figure()
    fig.add_trace(trace)
    fig.update_layout(
        title=f"{df.columns[0]} vs {df.columns[1]}",
        xaxis_title=x_label,
        yaxis_title=y_label,
    )
    return fig


async def stream_agent_response(question: str, chat_history: List[List[str]]) -> Tuple[str, Optional[go.Figure]]:
    """Run a question through the SQL agent and return its answer with an optional chart.

    Args:
        question: The user's question.
        chat_history: Earlier exchanges as ``[user_text, assistant_text]``
            pairs; may be empty.

    Returns:
        ``(response_text, chart_fig)``. ``chart_fig`` is None unless the
        answer contains a fenced ``sql`` block whose last query is read-only
        and yields chartable rows. Setup and processing failures are returned
        as a Markdown error string with no chart; nothing is raised.
    """
    # Each setup helper returns (value, error); stop at the first error.
    llm, llm_error = initialize_llm()
    if llm_error:
        return f"**LLM Error:** {llm_error}", None

    db_connection, db_error = setup_database_connection()
    if db_error:
        return f"**Database Error:** {db_error}", None

    agent, agent_error = create_agent(llm, db_connection)
    if agent_error:
        return f"**Agent Error:** {agent_error}", None

    try:
        logger.info("Processing question: %s", question)

        input_data: Dict[str, Any] = {"input": question}
        if chat_history:
            # Flatten the pairs into alternating human/AI message objects.
            formatted_history = []
            for human, ai in chat_history:
                formatted_history.append(HumanMessage(content=human))
                formatted_history.append(AIMessage(content=ai))
            input_data["chat_history"] = formatted_history

        # NOTE(review): this is a synchronous call inside a coroutine; if it is
        # slow it will hold up the event loop. Consider running it in an
        # executor once the agent/connection are confirmed thread-safe.
        response = agent.invoke(input_data)
        response_text = _extract_response_text(response)

        chart_fig = None
        sql_matches = _SQL_BLOCK_RE.findall(response_text)
        if sql_matches:
            sql_query = sql_matches[-1].strip()
            if not _is_read_only_query(sql_query):
                # The query text comes from model output; never re-run anything
                # that could modify data just to draw a chart.
                logger.warning("Skipping chart: extracted SQL is not a single read-only statement")
            else:
                try:
                    logger.info("Executing SQL query: %s", sql_query)
                    chart_fig = _build_chart(db_connection.run(sql_query))
                except Exception as e:
                    # Charting is best-effort; the text answer is still returned.
                    logger.warning("Could not create chart: %s", e)

        return response_text, chart_fig

    except Exception as e:
        error_msg = f"**Error processing question:** {str(e)}"
        logger.error(error_msg, exc_info=True)
        return error_msg, None

# Build the application once at import time; the module-level name `demo` is
# what get_app() returns and what the __main__ guard launches.
demo = create_application()

# Hugging Face Spaces configuration
def get_app():
    """Return the module-level Gradio app, adjusted when running on Hugging Face Spaces.

    When the ``SPACE_ID`` environment variable is set, the app's ``title``
    and ``description`` attributes are overwritten with demo text first.
    """
    running_on_spaces = os.getenv('SPACE_ID')
    if not running_on_spaces:
        return demo

    # Spaces-specific settings
    demo.title = "🤖 Asistente de Base de Datos SQL (Demo)"
    demo.description = """
        Este es un demo del asistente de base de datos SQL. 
        Para usar la versión completa con conexión a base de datos, clona este espacio y configura las variables de entorno.
        """
    return demo

# Local development entry point
if __name__ == "__main__":
    # Simplified launch settings for local runs on Gradio 5.x.
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "debug": True,
        "share": False,
    }
    demo.launch(**launch_options)