Spaces:
Sleeping
Sleeping
import chainlit as cl | |
import pandas as pd | |
import time | |
from typing import Dict, Any | |
import os | |
import json | |
from datetime import datetime | |
import hashlib | |
import uuid | |
from agents.table_selection import table_selection_agent | |
from agents.data_retrieval import sample_data_retrieval_agent | |
from agents.sql_generation import sql_generation_agent | |
from agents.validation import query_validation_and_optimization | |
from agents.execution import execution_agent | |
from utils.bigquery_utils import init_bigquery_connection | |
from utils.feedback_utils import save_feedback_to_bigquery | |
# Directory (relative to the current working directory) that holds per-user
# chat history as HISTORY_DIR/<user_id>/<conversation_id>.json files.
# It is created as soon as this module is imported.
HISTORY_DIR = "chat_history"
os.makedirs(name=HISTORY_DIR, exist_ok=True)
# Function to get or create a user ID | |
def get_user_id(): | |
# Check if user already has an ID | |
user_id = cl.user_session.get("user_id") | |
if not user_id: | |
# Generate a new user ID if none exists | |
# In a real app, you might use authentication or cookies | |
user_id = str(uuid.uuid4()) | |
cl.user_session.set("user_id", user_id) | |
# Create a user directory for this user | |
user_dir = os.path.join(HISTORY_DIR, user_id) | |
os.makedirs(user_dir, exist_ok=True) | |
return user_id | |
def save_chat_history(user_id, conversation_id, user_message, assistant_response):
    """Append one user/assistant exchange to a conversation's JSON history file.

    The file is ``HISTORY_DIR/<user_id>/<conversation_id>.json``. If it does
    not exist yet, it is started with ``conversation_id``, a ``created_at``
    timestamp and an empty ``messages`` list. Each call appends
    ``{"timestamp", "user", "assistant"}`` to ``messages``.

    The updated document is first written to a sibling ``.tmp`` file and then
    moved over the real file with ``os.replace``. A crash mid-write therefore
    cannot leave a truncated JSON file behind, which would make later loads
    of this conversation fail.
    """
    user_dir = os.path.join(HISTORY_DIR, user_id)
    os.makedirs(user_dir, exist_ok=True)
    history_file = os.path.join(user_dir, f"{conversation_id}.json")

    # Load existing history if it exists; otherwise start a new document.
    if os.path.exists(history_file):
        with open(history_file, "r", encoding="utf-8") as f:
            history = json.load(f)
    else:
        history = {
            "conversation_id": conversation_id,
            "created_at": datetime.now().isoformat(),
            "messages": [],
        }

    history["messages"].append({
        "timestamp": datetime.now().isoformat(),
        "user": user_message,
        "assistant": assistant_response,
    })

    # Write-then-rename so readers never observe a half-written file.
    # The ".tmp" suffix keeps the file out of load_chat_history's listing.
    tmp_file = f"{history_file}.tmp"
    with open(tmp_file, "w", encoding="utf-8") as f:
        json.dump(history, f, indent=2)
    os.replace(tmp_file, history_file)
def load_chat_history(user_id, conversation_id=None):
    """Load saved chat history for ``user_id`` from ``HISTORY_DIR/<user_id>``.

    Args:
        user_id: name of the user's directory under ``HISTORY_DIR``.
        conversation_id: when given, load only that conversation.

    Returns:
        With ``conversation_id``: the parsed conversation dict, or ``None``
        if that conversation (or the user's directory) does not exist.

        Without it: a list of ``{"conversation_id", "created_at",
        "message_count"}`` summaries sorted by ``created_at``, newest first.
        The list is ``[]`` when the user has no history directory. ``.json``
        files that cannot be read or parsed, or that lack the expected keys,
        are skipped so one damaged file does not break the whole listing.
    """
    user_dir = os.path.join(HISTORY_DIR, user_id)

    if conversation_id:
        # Single conversation: always None (never []) when it is missing.
        history_file = os.path.join(user_dir, f"{conversation_id}.json")
        if not os.path.exists(history_file):
            return None
        with open(history_file, "r", encoding="utf-8") as f:
            return json.load(f)

    if not os.path.isdir(user_dir):
        return []

    conversations = []
    for filename in os.listdir(user_dir):
        if not filename.endswith(".json"):
            continue
        try:
            with open(os.path.join(user_dir, filename), "r", encoding="utf-8") as f:
                data = json.load(f)
            summary = {
                "conversation_id": data["conversation_id"],
                "created_at": data["created_at"],
                "message_count": len(data["messages"]),
            }
        except (OSError, ValueError, KeyError, TypeError):
            # ValueError covers json.JSONDecodeError / UnicodeDecodeError;
            # KeyError/TypeError cover JSON that is not a history document.
            continue
        conversations.append(summary)

    return sorted(conversations, key=lambda c: c["created_at"], reverse=True)
async def start():
    """Initialise a chat session.

    Ensures the session has a user ID and starts a fresh conversation ID. If
    the user has saved conversations, it offers to resume one (replaying its
    messages) or to start a new one. First-time users get a welcome message
    instead. Finally, it shows a "Clear Chat History" action button.

    NOTE(review): get_user_id() mints a new random ID whenever the session has
    none, so unless "user_id" survives across sessions the "welcome back"
    branch is never reached — confirm how user identity is meant to persist.
    """
    user_id = get_user_id()

    # Every session starts on a new conversation unless the user resumes one.
    conversation_id = str(uuid.uuid4())
    cl.user_session.set("conversation_id", conversation_id)

    conversations = load_chat_history(user_id)
    if conversations:
        await cl.Message(
            content=f"Welcome back! You have {len(conversations)} previous conversations.",
        ).send()

        # AskUserMessage only collects free text (it has no `select` option),
        # so the choice is offered as buttons through AskActionMessage.
        actions = [
            cl.Action(
                name="select_conversation",
                payload={"value": "new"},
                label="Start a new conversation",
            )
        ]
        actions += [
            cl.Action(
                name="select_conversation",
                payload={"value": conv["conversation_id"]},
                label=f"Conversation from {conv['created_at'][:10]} ({conv['message_count']} messages)",
            )
            for conv in conversations
        ]
        res = await cl.AskActionMessage(
            content="Would you like to continue a previous conversation or start a new one?",
            actions=actions,
            timeout=30,
        ).send()

        # No answer (timeout) keeps the new conversation started above.
        selected_conv_id = res.get("payload", {}).get("value") if res else None
        if selected_conv_id and selected_conv_id != "new":
            cl.user_session.set("conversation_id", selected_conv_id)
            conversation = load_chat_history(user_id, selected_conv_id)
            if conversation:
                await cl.Message(content="Loading your previous conversation...").send()
                for msg in conversation["messages"]:
                    await cl.Message(content=msg["user"], author="You").send()
                    await cl.Message(content=msg["assistant"]).send()
    else:
        await cl.Message(
            content="Welcome to the E-commerce Analytics Assistant! How can I help you today?",
        ).send()

    # Actions are displayed by attaching them to a message; cl.Action needs a
    # `payload` and has no `description` field (the hover text is `tooltip`).
    # The clear_history handler removes only the current conversation's file.
    await cl.Message(
        content="You can clear the current conversation's history at any time.",
        actions=[
            cl.Action(
                name="clear_history",
                payload={"value": "clear"},
                label="Clear Chat History",
                tooltip="Erase the current conversation's history",
            )
        ],
    ).send()
async def on_message(message: cl.Message):
    """Placeholder handler: echo the user's text and record the exchange.

    The reply is ``"You said: <text>"``. The pair is appended to the local
    chat-history file for the session's user and conversation IDs.

    NOTE: a second ``async def on_message`` later in this module rebinds the
    same name, so this placeholder is superseded by that definition.
    """
    session_ids = (
        cl.user_session.get("user_id"),
        cl.user_session.get("conversation_id"),
    )
    echo = f"You said: {message.content}"
    await cl.Message(content=echo).send()
    save_chat_history(*session_ids, message.content, echo)
async def clear_history_callback(action):
    """Erase the saved history of the *current* conversation and start a new one.

    Only ``HISTORY_DIR/<user_id>/<conversation_id>.json`` is deleted, and only
    if it exists. The user's other saved conversations are left untouched. A
    confirmation is sent either way, and the session then switches to a freshly
    generated conversation ID.
    """
    session = cl.user_session
    uid = session.get("user_id")
    current_conv = session.get("conversation_id")

    target = os.path.join(HISTORY_DIR, uid, f"{current_conv}.json")
    if os.path.exists(target):
        os.remove(target)

    await cl.Message(content="Your current conversation history has been cleared.").send()

    # Later messages are recorded under a brand-new conversation ID.
    session.set("conversation_id", str(uuid.uuid4()))
async def _set_status(status_msg, text):
    """Overwrite the in-progress status message's text and push it to the UI."""
    status_msg.content = text
    await status_msg.update()


def _results_to_dataframe(rows, client, sql):
    """Build a DataFrame from the execution agent's non-empty result rows.

    Tuple rows carry no column names, so the schema is looked up by running
    ``sql`` again through ``client``. If that lookup (or building the frame
    with those names) fails, generic ``Column_<i>`` names are used. Rows of
    any other shape are handed to ``pd.DataFrame`` as-is.
    """
    if not isinstance(rows[0], tuple):
        return pd.DataFrame(rows)
    try:
        # NOTE(review): this re-submits the whole query just to read its
        # schema; consider having the execution agent return column names.
        schema = client.query(sql).result().schema
        return pd.DataFrame(rows, columns=[field.name for field in schema])
    except Exception:
        return pd.DataFrame(rows, columns=[f"Column_{i}" for i in range(len(rows[0]))])


async def _save_detailed_feedback(feedback_text):
    """Save a typed-in negative-feedback reply and leave feedback mode.

    The question and SQL stored in the user session are attached. The
    ``awaiting_feedback`` flag is cleared once the save call has returned.
    """
    success = await cl.make_async(save_feedback_to_bigquery)(
        cl.user_session.get("client"),
        cl.user_session.get("original_query"),
        cl.user_session.get("generated_sql"),
        cl.user_session.get("optimized_sql"),
        f"negative: {feedback_text}",
    )
    cl.user_session.set("awaiting_feedback", False)
    if success:
        await cl.Message(content="Thanks for your detailed feedback! I've saved it to improve future responses.", author="SQL Assistant").send()
    else:
        await cl.Message(content="Thanks for your feedback! (Note: There was an issue saving it to the database)", author="SQL Assistant").send()


async def _ask_for_result_feedback(num_rows, num_cols):
    """Ask whether the displayed results were helpful and act on the answer.

    👍 saves ``"positive"`` feedback and thanks the user. 👎 asks the user to
    type what was wrong, sets ``awaiting_feedback`` so their next message is
    stored as detailed feedback, and saves an initial ``"negative"`` record.
    No answer (timeout) does nothing.
    """
    res = await cl.AskActionMessage(
        content=f"The query returned {num_rows} rows and {num_cols} columns.\n\nWas this result helpful?",
        actions=[
            cl.Action(name="feedback", payload={"value": "positive"}, label="👍 Good results"),
            cl.Action(name="feedback", payload={"value": "negative"}, label="👎 Not what I wanted"),
        ],
    ).send()
    if not res:
        return

    feedback_value = res.get("payload", {}).get("value")
    record = (
        cl.user_session.get("client"),
        cl.user_session.get("original_query"),
        cl.user_session.get("generated_sql"),
        cl.user_session.get("optimized_sql"),
    )
    # The BigQuery write is blocking; run it off the event loop.
    save = cl.make_async(save_feedback_to_bigquery)

    if feedback_value == "positive":
        if await save(*record, "positive"):
            await cl.Message(content="Thanks for your positive feedback! I've saved it to improve future responses.", author="SQL Assistant").send()
        else:
            await cl.Message(content="Thanks for your feedback! (Note: There was an issue saving it to the database)", author="SQL Assistant").send()
    elif feedback_value == "negative":
        await cl.Message(content="I'm sorry the results weren't what you expected. Please type your feedback about what was wrong.", author="SQL Assistant").send()
        # The user's next message will be routed to _save_detailed_feedback.
        cl.user_session.set("awaiting_feedback", True)
        # NOTE(review): the later detailed reply is saved as a second record;
        # the result of this initial save is not checked.
        await save(*record, "negative")


async def on_message(message: cl.Message):
    """Handle an incoming user message.

    Two modes, selected by the ``"awaiting_feedback"`` session flag:

    * Feedback mode (set after the user rated a result 👎): the message text
      is saved as detailed negative feedback and the flag is cleared.
    * Query mode: the question runs through the agent pipeline (table
      selection → sample data → SQL generation → validation/optimization →
      execution). Progress is shown in a status message, and the user is then
      asked to rate the results.

    The question and both SQL versions are stored in the user session
    (``original_query``, ``generated_sql``, ``optimized_sql``) so feedback can
    reference them. The agents and BigQuery calls are synchronous, so they
    run through ``cl.make_async`` to avoid blocking the event loop (and every
    other session) while they work.

    This definition rebinds the name ``on_message``, superseding the earlier
    placeholder echo handler. Unlike that one, it does not write the local
    chat-history files.
    """
    query = message.content

    if cl.user_session.get("awaiting_feedback", False):
        await _save_detailed_feedback(query)
        return

    # NOTE(review): no code shown here stores "client" in the session
    # (init_bigquery_connection is imported but not called) — confirm it is
    # set elsewhere, otherwise the agents receive client=None.
    client = cl.user_session.get("client")
    cl.user_session.set("original_query", query)

    thinking_msg = await cl.Message(content="🤔 Thinking...", author="SQL Assistant").send()

    try:
        # Step 1: pick the tables relevant to the question.
        await _set_status(thinking_msg, "🔍 Analyzing relevant tables...")
        state = {"sql_query": query, "client": client}
        tables_state = await cl.make_async(table_selection_agent)(state)
        relevant_tables = tables_state.get("relevant_tables", [])
        await cl.sleep(1)  # deliberate pause between steps, for UX
        if relevant_tables:
            tables_text = "I've identified these relevant tables for your query:\n\n"
            tables_text += "\n".join(f"- `{table}`" for table in relevant_tables)
            await cl.Message(content=tables_text, author="SQL Assistant").send()

        # Step 2: fetch sample rows from those tables.
        await _set_status(thinking_msg, "📊 Retrieving sample data...")
        await cl.sleep(1)
        state.update(tables_state)
        sample_data_state = await cl.make_async(sample_data_retrieval_agent)(state)

        # Step 3: generate SQL from the question plus samples.
        await _set_status(thinking_msg, "💻 Generating SQL query...")
        await cl.sleep(1)
        state.update(sample_data_state)
        sql_state = await cl.make_async(sql_generation_agent)(state)
        generated_sql = sql_state.get("generated_sql", "No SQL generated")
        cl.user_session.set("generated_sql", generated_sql)
        await cl.Message(
            content=f"Here's the SQL query I generated:\n\n```sql\n{generated_sql}\n```",
            author="SQL Assistant",
        ).send()

        # Step 4: validate and optimize the SQL.
        await _set_status(thinking_msg, "🔧 Optimizing the query...")
        await cl.sleep(1)
        state.update(sql_state)
        optimization_state = await cl.make_async(query_validation_and_optimization)(state)
        optimized_sql = optimization_state.get("optimized_sql", "No optimized SQL")
        cl.user_session.set("optimized_sql", optimized_sql)
        await cl.Message(
            content=f"Here's the optimized version of the query:\n\n```sql\n{optimized_sql}\n```",
            author="SQL Assistant",
        ).send()

        # Step 5: execute.
        await _set_status(thinking_msg, "⚙️ Executing query...")
        await cl.sleep(1)
        state.update(optimization_state)
        execution_state = await cl.make_async(execution_agent)(state)
        execution_result = execution_state.get("execution_result", {})

        if isinstance(execution_result, dict) and "error" in execution_result:
            error_msg = execution_result.get("error", "Unknown error occurred")
            await cl.Message(content=f"❌ Error executing query: {error_msg}", author="SQL Assistant").send()
        elif not execution_result:
            await cl.Message(content="✅ Query executed successfully but returned no results.", author="SQL Assistant").send()
        else:
            try:
                df = await cl.make_async(_results_to_dataframe)(execution_result, client, optimized_sql)
                await cl.Message(content="✅ Query executed successfully! Here are the results:", author="SQL Assistant").send()
                await cl.Message(content="", elements=[cl.Dataframe(data=df)], author="SQL Assistant").send()
                await _ask_for_result_feedback(len(df), len(df.columns))
            except Exception as e:
                await cl.Message(content=f"❌ Error formatting results: {str(e)}", author="SQL Assistant").send()
    except Exception as e:
        await _set_status(thinking_msg, f"❌ Error: {str(e)}")
        await cl.Message(
            content=f"I encountered an error while processing your query: {str(e)}",
            author="SQL Assistant",
        ).send()
# Callback handlers for actions.
# NOTE(review): no @cl.action_callback registration is visible for these
# handlers — confirm how they are wired to the "feedback" actions.
async def on_feedback_action(action):
    """Record positive feedback for the most recent query.

    Reads ``action.payload["value"]``. Only ``"positive"`` is acted on: the
    question and SQL stored in the user session are saved with
    ``save_feedback_to_bigquery`` and the user is thanked. Any other value is
    ignored here (NOTE(review): including "negative" — confirm intended).
    """
    choice = action.payload.get("value")
    session = cl.user_session
    record = (
        session.get("client"),
        session.get("original_query"),
        session.get("generated_sql"),
        session.get("optimized_sql"),
    )
    if choice != "positive":
        return

    saved = save_feedback_to_bigquery(*record, "positive")
    reply = (
        "Thanks for your positive feedback! I've saved it to improve future responses."
        if saved
        else "Thanks for your feedback! (Note: There was an issue saving it to the database)"
    )
    await cl.Message(content=reply, author="SQL Assistant").send()
async def on_feedback_bad(action):
    """Ask why a result was unhelpful and save the answer as negative feedback.

    Prompts the user for free-text details, waiting up to 300 seconds. A
    non-empty reply is saved as ``"negative: <reply>"``. On timeout or an empty
    reply, plain ``"negative"`` is saved. The question and SQL stored in the
    user session are attached via ``save_feedback_to_bigquery``, and the user
    is thanked according to whether the save succeeded.
    """
    # AskUserMessage collects a plain text reply, returned under "output". It
    # takes no `elements` argument, and Chainlit provides no `cl.Textarea`
    # element, so the previous form-style prompt could never be sent.
    res = await cl.AskUserMessage(
        content="I'm sorry the results weren't what you expected. Could you please provide more details about what was wrong?",
        author="SQL Assistant",
        timeout=300,
    ).send()

    details = res.get("output") if res else None
    feedback_details = f"negative: {details}" if details else "negative"

    # The BigQuery write is blocking; run it off the event loop.
    success = await cl.make_async(save_feedback_to_bigquery)(
        cl.user_session.get("client"),
        cl.user_session.get("original_query"),
        cl.user_session.get("generated_sql"),
        cl.user_session.get("optimized_sql"),
        feedback_details,
    )
    if success:
        await cl.Message(content="Thanks for your detailed feedback! I've saved it to improve future responses.", author="SQL Assistant").send()
    else:
        await cl.Message(content="Thanks for your feedback! (Note: There was an issue saving it to the database)", author="SQL Assistant").send()
# Chainlit apps are started through the Chainlit CLI rather than by running
# this module with `python`, e.g.:
#     chainlit run new_app.py -w
# Running it directly only prints that hint.
if __name__ == "__main__":
    import sys

    print(
        f"This is a Chainlit app; start it with: chainlit run {os.path.basename(__file__)} -w",
        file=sys.stderr,
    )