File size: 19,443 Bytes
05e3517
 
 
 
b2b0814
 
 
 
 
05e3517
 
 
 
 
 
 
 
 
b2b0814
 
 
 
 
d765e31
b2b0814
d765e31
b2b0814
 
 
 
 
d765e31
b2b0814
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
05e3517
b2b0814
 
d765e31
b2b0814
 
 
d765e31
b2b0814
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d765e31
b2b0814
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d765e31
 
 
b2b0814
 
 
 
 
 
 
 
 
 
 
 
 
d765e31
 
b2b0814
 
 
 
 
 
 
 
 
 
 
d765e31
05e3517
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
import chainlit as cl
import pandas as pd
import time
from typing import Dict, Any
import os
import json
from datetime import datetime
import hashlib
import uuid

from agents.table_selection import table_selection_agent
from agents.data_retrieval import sample_data_retrieval_agent
from agents.sql_generation import sql_generation_agent
from agents.validation import query_validation_and_optimization
from agents.execution import execution_agent
from utils.bigquery_utils import init_bigquery_connection
from utils.feedback_utils import save_feedback_to_bigquery

# Root directory for persisted chat history. Each user gets a sub-directory
# named after their user ID, holding one "<conversation_id>.json" file per
# conversation.
HISTORY_DIR = "chat_history"
# Created as a side effect of importing this module; the path is relative,
# so it resolves against the process's current working directory.
os.makedirs(HISTORY_DIR, exist_ok=True)

def get_user_id():
    """Return the user ID stored in the Chainlit session, minting one if absent.

    When the session has no (truthy) ``"user_id"`` entry, a random UUID4
    string is generated, stored in the session under ``"user_id"``, and a
    matching directory is created under ``HISTORY_DIR``.

    Returns:
        The existing or newly created user ID.
    """
    known_id = cl.user_session.get("user_id")
    if known_id:
        return known_id

    # No authentication is involved: a random UUID stands in for an identity.
    fresh_id = str(uuid.uuid4())
    cl.user_session.set("user_id", fresh_id)

    # Make sure the per-user history directory exists from the start.
    os.makedirs(os.path.join(HISTORY_DIR, fresh_id), exist_ok=True)
    return fresh_id

def save_chat_history(user_id, conversation_id, user_message, assistant_response):
    """Append one user/assistant exchange to a conversation's JSON history file.

    The file lives at ``HISTORY_DIR/<user_id>/<conversation_id>.json``. On the
    first call for a conversation it is created with ``conversation_id``,
    ``created_at`` and an empty ``messages`` list; every call appends a
    ``{"timestamp", "user", "assistant"}`` entry.

    An existing file that is not valid JSON, or that does not hold a dict with
    a ``messages`` list, is moved aside to ``<file>.corrupt`` and a fresh
    history is started, so one damaged file does not make every later save
    fail. The update is written to a temporary file and then swapped in with
    ``os.replace``, so a failed write leaves the previous history intact.

    Args:
        user_id: Name of the user's sub-directory under ``HISTORY_DIR``.
        conversation_id: Base name of the conversation's JSON file.
        user_message: Text stored under ``"user"``.
        assistant_response: Text stored under ``"assistant"``.

    Raises:
        OSError: If the directory or file cannot be created or written.
        TypeError: If the message values are not JSON serializable.
    """
    user_dir = os.path.join(HISTORY_DIR, user_id)
    os.makedirs(user_dir, exist_ok=True)
    history_file = os.path.join(user_dir, f"{conversation_id}.json")

    history = None
    if os.path.exists(history_file):
        try:
            with open(history_file, "r", encoding="utf-8") as f:
                loaded = json.load(f)
        except ValueError:
            # Covers json.JSONDecodeError and UnicodeDecodeError.
            loaded = None

        if isinstance(loaded, dict) and isinstance(loaded.get("messages"), list):
            history = loaded
        else:
            # Keep the unreadable content for inspection instead of overwriting it.
            os.replace(history_file, f"{history_file}.corrupt")

    if history is None:
        history = {
            "conversation_id": conversation_id,
            "created_at": datetime.now().isoformat(),
            "messages": [],
        }

    history["messages"].append({
        "timestamp": datetime.now().isoformat(),
        "user": user_message,
        "assistant": assistant_response,
    })

    # Write to a sibling temp file first: opening the real file with "w" would
    # truncate it before json.dump has proven it can serialize the data.
    tmp_file = f"{history_file}.tmp"
    try:
        with open(tmp_file, "w", encoding="utf-8") as f:
            json.dump(history, f, indent=2)
    except Exception:
        # Do not leave a half-written temp file behind; re-raise for the caller.
        try:
            os.remove(tmp_file)
        except OSError:
            pass
        raise
    os.replace(tmp_file, history_file)

def load_chat_history(user_id, conversation_id=None):
    """Load one conversation, or list summaries of all of a user's conversations.

    Args:
        user_id: Name of the user's sub-directory under ``HISTORY_DIR``.
        conversation_id: When given (and truthy), load that conversation's
            full history; otherwise return summaries of every conversation.

    Returns:
        * ``[]`` when the user has no history directory (in either mode).
        * With ``conversation_id``: the parsed JSON content of the
          conversation file, or ``None`` when the file is missing or cannot
          be parsed.
        * Without ``conversation_id``: a list of
          ``{"conversation_id", "created_at", "message_count"}`` dicts,
          sorted by ``created_at`` in descending order. Files that cannot be
          read, parsed, or that lack the expected keys are skipped so a
          single damaged file does not hide the remaining conversations.
    """
    user_dir = os.path.join(HISTORY_DIR, user_id)

    # isdir rather than exists: a stray file with the user's name must not
    # reach os.listdir below.
    if not os.path.isdir(user_dir):
        return []

    if conversation_id:
        history_file = os.path.join(user_dir, f"{conversation_id}.json")
        try:
            with open(history_file, "r", encoding="utf-8") as f:
                return json.load(f)
        except (FileNotFoundError, ValueError):
            # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
            return None

    conversations = []
    for filename in os.listdir(user_dir):
        if not filename.endswith(".json"):
            continue
        try:
            with open(os.path.join(user_dir, filename), "r", encoding="utf-8") as f:
                data = json.load(f)
            conversations.append({
                "conversation_id": data["conversation_id"],
                "created_at": data["created_at"],
                "message_count": len(data["messages"]),
            })
        except (OSError, ValueError, KeyError, TypeError):
            # Unreadable or malformed history file: leave it out of the listing.
            continue

    return sorted(conversations, key=lambda x: x["created_at"], reverse=True)

@cl.on_chat_start
async def start():
    """Set up a new chat session and optionally resume an earlier conversation.

    Steps:
      1. Resolve the user ID and make sure the session holds a ``"client"``
         entry, which the message and feedback handlers read.
      2. Assign a fresh conversation ID to the session.
      3. If the user has saved conversations, offer them as action buttons
         (plus "Start a new conversation") and, when one is chosen, switch
         the session to it and replay its messages.
      4. Offer a "Clear Chat History" action.
    """
    user_id = get_user_id()

    # Nothing else in this module stores "client" in the session, yet the
    # message and feedback handlers read it from there.
    if cl.user_session.get("client") is None:
        try:
            # NOTE(review): assumes init_bigquery_connection() takes no
            # arguments and returns the client object — confirm against
            # utils.bigquery_utils.
            cl.user_session.set("client", init_bigquery_connection())
        except Exception as exc:
            await cl.Message(
                content=f"⚠️ Could not initialize the BigQuery connection: {exc}",
            ).send()

    # Every session starts on a new conversation unless the user picks an old one.
    conversation_id = str(uuid.uuid4())
    cl.user_session.set("conversation_id", conversation_id)

    conversations = load_chat_history(user_id)

    if conversations:
        await cl.Message(
            content=f"Welcome back! You have {len(conversations)} previous conversations.",
        ).send()

        # The choice is presented as action buttons through AskActionMessage,
        # the same mechanism this module uses for its feedback prompt.
        actions = [
            cl.Action(
                name="select_conversation",
                payload={"value": "new"},
                label="Start a new conversation",
            )
        ]
        actions.extend(
            cl.Action(
                name="select_conversation",
                payload={"value": conv["conversation_id"]},
                label=f"Conversation from {conv['created_at'][:10]} ({conv['message_count']} messages)",
            )
            for conv in conversations
        )

        res = await cl.AskActionMessage(
            content="Would you like to continue a previous conversation or start a new one?",
            actions=actions,
            timeout=30,
        ).send()

        # No answer (timeout) or a missing value both mean "new conversation".
        selected_conv_id = ((res or {}).get("payload") or {}).get("value", "new")

        # Only accept IDs that were actually offered: the value ends up in a
        # file path, so it must not be taken from the client unchecked.
        offered_ids = {conv["conversation_id"] for conv in conversations}
        if selected_conv_id != "new" and selected_conv_id in offered_ids:
            conversation = load_chat_history(user_id, selected_conv_id)
            if conversation:
                cl.user_session.set("conversation_id", selected_conv_id)

                await cl.Message(content="Loading your previous conversation...").send()
                for msg in conversation["messages"]:
                    await cl.Message(content=msg["user"], author="You").send()
                    await cl.Message(content=msg["assistant"]).send()
    else:
        # First-time user (or no readable history).
        await cl.Message(
            content="Welcome to the E-commerce Analytics Assistant! How can I help you today?",
        ).send()

    # Actions are shown attached to a message, so the clear-history button
    # gets a small carrier message of its own.
    await cl.Message(
        content="Use the button below to clear your chat history.",
        actions=[
            cl.Action(
                name="clear_history",
                payload={},
                label="Clear Chat History",
                tooltip="Erase all previous conversation history",
            )
        ],
    ).send()

@cl.on_message
async def on_message(message: cl.Message):
    """Echo the incoming message back and record the exchange in chat history.

    NOTE(review): a second ``on_message`` function, also decorated with
    ``@cl.on_message``, is defined later in this module. Presumably only one
    of the two ends up handling messages — confirm which, and fold this
    history-saving step into it if this placeholder is the one being dropped.
    """
    # Read the session identifiers up front, before anything is sent.
    uid = cl.user_session.get("user_id")
    conv_id = cl.user_session.get("conversation_id")

    # Placeholder reply; the real processing logic is meant to go here.
    reply = f"You said: {message.content}"
    await cl.Message(content=reply).send()

    # Persist the user text together with the reply that was just sent.
    save_chat_history(uid, conv_id, message.content, reply)

@cl.action_callback("clear_history")
async def clear_history_callback(action):
    """Delete the current conversation's history file and rotate the conversation ID.

    Only the file of the conversation currently stored in the session is
    removed; other conversations of the user are left alone. A confirmation
    message is sent only when a file was actually deleted. In every case the
    session is switched to a newly generated conversation ID.

    Args:
        action: The triggering action (not used by this handler).
    """
    session = cl.user_session
    uid = session.get("user_id")
    conv_id = session.get("conversation_id")

    history_path = os.path.join(HISTORY_DIR, uid, f"{conv_id}.json")

    if os.path.exists(history_path):
        os.remove(history_path)
        await cl.Message(content="Your current conversation history has been cleared.").send()

    # Subsequent messages go into a brand-new conversation.
    session.set("conversation_id", str(uuid.uuid4()))

async def _send_assistant_message(content, **kwargs):
    """Send *content* to the chat as "SQL Assistant" and return the text sent."""
    await cl.Message(content=content, author="SQL Assistant", **kwargs).send()
    return content


async def _set_status(status_msg, text):
    """Replace the text of the in-place progress message."""
    status_msg.content = text
    await status_msg.update()


def _save_session_feedback(feedback):
    """Pass *feedback* to save_feedback_to_bigquery with the session's query context."""
    return save_feedback_to_bigquery(
        cl.user_session.get("client"),
        cl.user_session.get("original_query"),
        cl.user_session.get("generated_sql"),
        cl.user_session.get("optimized_sql"),
        feedback,
    )


async def _acknowledge_feedback(saved, success_text):
    """Thank the user, mentioning when the feedback could not be saved."""
    if saved:
        await _send_assistant_message(success_text)
    else:
        await _send_assistant_message(
            "Thanks for your feedback! (Note: There was an issue saving it to the database)"
        )


def _results_to_dataframe(client, optimized_sql, rows):
    """Convert a non-empty execution result into a DataFrame for display."""
    # Only index into the result when it is a real sequence; any other
    # container is handed to pandas unchanged.
    if isinstance(rows, (list, tuple)) and isinstance(rows[0], tuple):
        try:
            # NOTE(review): this submits the query a second time just to read
            # the column names from its schema — consider having the
            # execution step return the schema with the rows.
            schema = client.query(optimized_sql).result().schema
            return pd.DataFrame(rows, columns=[field.name for field in schema])
        except Exception:
            # Schema unavailable or mismatched: fall back to generic names.
            generic_names = [f"Column_{i}" for i in range(len(rows[0]))]
            return pd.DataFrame(rows, columns=generic_names)
    return pd.DataFrame(rows)


async def _ask_for_feedback(summary):
    """Ask whether the result was helpful and record the answer."""
    res = await cl.AskActionMessage(
        content=f"{summary}\n\nWas this result helpful?",
        actions=[
            cl.Action(name="feedback", payload={"value": "positive"}, label="👍 Good results"),
            cl.Action(name="feedback", payload={"value": "negative"}, label="👎 Not what I wanted"),
        ],
    ).send()
    if not res:
        return

    feedback_value = res.get("payload", {}).get("value")
    if feedback_value == "positive":
        saved = _save_session_feedback("positive")
        await _acknowledge_feedback(
            saved,
            "Thanks for your positive feedback! I've saved it to improve future responses.",
        )
    elif feedback_value == "negative":
        await _send_assistant_message(
            "I'm sorry the results weren't what you expected. Please type your feedback about what was wrong."
        )
        # The user's next chat message is treated as the detailed feedback.
        cl.user_session.set("awaiting_feedback", True)
        # Record the bare rating right away in case no details follow.
        _save_session_feedback("negative")


def _persist_exchange(user_message, transcript):
    """Save the exchange to the chat-history file when the session has IDs."""
    user_id = cl.user_session.get("user_id")
    conversation_id = cl.user_session.get("conversation_id")
    if not (user_id and conversation_id and transcript):
        return
    try:
        save_chat_history(user_id, conversation_id, user_message, "\n\n".join(transcript))
    except (OSError, ValueError, TypeError):
        # History is a convenience: a storage problem is logged, but must not
        # turn an answered query into a failure.
        import logging

        logging.getLogger(__name__).exception("Could not save chat history")


@cl.on_message
async def on_message(message: cl.Message):
    """Handle a user message: either detailed feedback or a new data question.

    If the session flag ``"awaiting_feedback"`` is set, the message text is
    saved as ``"negative: <text>"`` feedback, the flag is cleared, and the
    handler returns.

    Otherwise the text is run through the pipeline — table selection, sample
    data retrieval, SQL generation, validation/optimization, execution — with
    a progress message updated before each step and the intermediate results
    (tables, generated SQL, optimized SQL, result table) sent to the chat.
    After a successful result the user is asked for feedback. The assistant's
    text replies are saved to the conversation's chat history, so that
    resuming a conversation has something to replay.

    NOTE(review): an earlier ``on_message`` in this module is also decorated
    with ``@cl.on_message``; presumably this later definition is the one that
    handles messages — confirm and remove the placeholder.

    Args:
        message: The incoming chat message; only ``message.content`` is used.
    """
    query = message.content

    # A pending "what was wrong?" prompt turns this message into feedback.
    if cl.user_session.get("awaiting_feedback", False):
        saved = _save_session_feedback(f"negative: {query}")
        cl.user_session.set("awaiting_feedback", False)
        await _acknowledge_feedback(
            saved,
            "Thanks for your detailed feedback! I've saved it to improve future responses.",
        )
        return

    client = cl.user_session.get("client")

    # Kept in the session so feedback can be tied to the question later.
    cl.user_session.set("original_query", query)

    thinking_msg = await cl.Message(content="🤔 Thinking...", author="SQL Assistant").send()

    # Text of the assistant's replies, collected for the chat history.
    transcript = []

    # NOTE(review): the agent functions below are called synchronously inside
    # this coroutine; if they block on network calls, the event loop is
    # stalled for their duration.
    try:
        # Step 1: pick the relevant tables.
        await _set_status(thinking_msg, "🔍 Analyzing relevant tables...")
        state = {"sql_query": query, "client": client}
        tables_state = table_selection_agent(state)
        relevant_tables = tables_state.get("relevant_tables", [])

        # Short pause so the progress text is visible before the next message.
        await cl.sleep(1)
        if relevant_tables:
            tables_text = "I've identified these relevant tables for your query:\n\n"
            tables_text += "\n".join(f"- `{table}`" for table in relevant_tables)
            transcript.append(await _send_assistant_message(tables_text))

        # Step 2: fetch sample data for those tables.
        await _set_status(thinking_msg, "📊 Retrieving sample data...")
        await cl.sleep(1)
        state.update(tables_state)
        sample_data_state = sample_data_retrieval_agent(state)

        # Step 3: generate the SQL.
        await _set_status(thinking_msg, "💻 Generating SQL query...")
        await cl.sleep(1)
        state.update(sample_data_state)
        sql_state = sql_generation_agent(state)
        generated_sql = sql_state.get("generated_sql", "No SQL generated")
        cl.user_session.set("generated_sql", generated_sql)
        transcript.append(await _send_assistant_message(
            f"Here's the SQL query I generated:\n\n```sql\n{generated_sql}\n```"
        ))

        # Step 4: validate and optimize the SQL.
        await _set_status(thinking_msg, "🔧 Optimizing the query...")
        await cl.sleep(1)
        state.update(sql_state)
        optimization_state = query_validation_and_optimization(state)
        optimized_sql = optimization_state.get("optimized_sql", "No optimized SQL")
        cl.user_session.set("optimized_sql", optimized_sql)
        transcript.append(await _send_assistant_message(
            f"Here's the optimized version of the query:\n\n```sql\n{optimized_sql}\n```"
        ))

        # Step 5: execute.
        await _set_status(thinking_msg, "⚙️ Executing query...")
        await cl.sleep(1)
        state.update(optimization_state)
        execution_state = execution_agent(state)
        execution_result = execution_state.get("execution_result", {})

        if isinstance(execution_result, dict) and "error" in execution_result:
            error_msg = execution_result.get("error", "Unknown error occurred")
            transcript.append(await _send_assistant_message(
                f"❌ Error executing query: {error_msg}"
            ))
        elif not execution_result:
            transcript.append(await _send_assistant_message(
                "✅ Query executed successfully but returned no results."
            ))
        else:
            try:
                df = _results_to_dataframe(client, optimized_sql, execution_result)

                transcript.append(await _send_assistant_message(
                    "✅ Query executed successfully! Here are the results:"
                ))
                await cl.Message(
                    content="",
                    elements=[cl.Dataframe(data=df)],
                    author="SQL Assistant",
                ).send()

                summary = f"The query returned {len(df)} rows and {len(df.columns)} columns."
                transcript.append(summary)
                await _ask_for_feedback(summary)
            except Exception as e:
                transcript.append(await _send_assistant_message(
                    f"❌ Error formatting results: {str(e)}"
                ))

    except Exception as e:
        # Any pipeline failure is shown both in the progress message and as a reply.
        await _set_status(thinking_msg, f"❌ Error: {str(e)}")
        transcript.append(await _send_assistant_message(
            f"I encountered an error while processing your query: {str(e)}"
        ))
    finally:
        # Runs on success and failure alike, once per handled question.
        _persist_exchange(query, transcript)

# Callback handlers for actions
@cl.action_callback("feedback")
async def on_feedback_action(action):
    """Handle a click on a "feedback" action.

    Only a payload value of ``"positive"`` is acted on: the rating is passed
    to ``save_feedback_to_bigquery`` together with the query and SQL kept in
    the session, and the user is thanked. Any other value is ignored here.

    Args:
        action: The clicked action; its ``payload["value"]`` holds the rating.
    """
    if action.payload.get("value") != "positive":
        return

    session = cl.user_session
    saved = save_feedback_to_bigquery(
        session.get("client"),
        session.get("original_query"),
        session.get("generated_sql"),
        session.get("optimized_sql"),
        "positive",
    )

    if saved:
        thanks = "Thanks for your positive feedback! I've saved it to improve future responses."
    else:
        thanks = "Thanks for your feedback! (Note: There was an issue saving it to the database)"
    await cl.Message(content=thanks, author="SQL Assistant").send()

@cl.action_callback("feedback_bad")
async def on_feedback_bad(action):
    """Collect free-text details for negative feedback and save them.

    The user is asked what was wrong and given up to 300 seconds to reply.
    A non-empty reply is saved as ``"negative: <text>"``; when there is no
    reply the feedback is saved as plain ``"negative"``. The user is then
    thanked, with a note if saving failed.

    NOTE(review): no action named "feedback_bad" is created anywhere in this
    module, so this callback looks unused — confirm before relying on it.

    Args:
        action: The triggering action (not used by this handler).
    """
    # The question is a plain text prompt; the user's answer is read from
    # the reply itself.
    res = await cl.AskUserMessage(
        content="I'm sorry the results weren't what you expected. Could you please provide more details about what was wrong?",
        author="SQL Assistant",
        timeout=300,
    ).send()

    # NOTE(review): assumes the reply dict carries the typed text under
    # "output" — confirm against the installed Chainlit version.
    reply_text = ((res.get("output") if res else None) or "").strip()
    feedback_details = f"negative: {reply_text}" if reply_text else "negative"

    client = cl.user_session.get("client")
    original_query = cl.user_session.get("original_query")
    generated_sql = cl.user_session.get("generated_sql")
    optimized_sql = cl.user_session.get("optimized_sql")

    success = save_feedback_to_bigquery(
        client,
        original_query,
        generated_sql,
        optimized_sql,
        feedback_details,
    )

    if success:
        await cl.Message(content="Thanks for your detailed feedback! I've saved it to improve future responses.", author="SQL Assistant").send()
    else:
        await cl.Message(content="Thanks for your feedback! (Note: There was an issue saving it to the database)", author="SQL Assistant").send()

# No-op entry point: executing this file directly does nothing, because the
# guarded body is only `pass`. The app is started through the Chainlit CLI.
if __name__ == "__main__":
    # Note: Chainlit uses its own CLI command to run the app
    # You'll run this with: chainlit run new_app.py -w
    pass