# Extracted from a Hugging Face file view (page chrome kept here as a comment):
#   author: Gonalb
#   commit: d765e31 "add chat history"
#   size:   19.4 kB
import chainlit as cl
import pandas as pd
import time
from typing import Dict, Any
import os
import json
from datetime import datetime
import hashlib
import uuid
from agents.table_selection import table_selection_agent
from agents.data_retrieval import sample_data_retrieval_agent
from agents.sql_generation import sql_generation_agent
from agents.validation import query_validation_and_optimization
from agents.execution import execution_agent
from utils.bigquery_utils import init_bigquery_connection
from utils.feedback_utils import save_feedback_to_bigquery
# Root directory for persisted chat history. The functions below store one
# sub-directory per user ID and one JSON file per conversation inside it.
# The path is relative, so it resolves against the process working directory.
HISTORY_DIR = "chat_history"
# Created at import time so later reads/writes can assume the root exists.
os.makedirs(HISTORY_DIR, exist_ok=True)
# Function to get or create a user ID
def get_user_id():
    """Return the user ID stored in the Chainlit session, creating one if absent.

    When the session has no ``"user_id"`` yet, a random UUID4 string is
    generated, stored in the session, and a matching directory is created
    under ``HISTORY_DIR``.

    Returns:
        str: the existing or newly generated user ID.
    """
    # NOTE(review): in a real deployment this would presumably come from
    # authentication or a cookie rather than a random per-session value.
    known_id = cl.user_session.get("user_id")
    if known_id:
        return known_id

    fresh_id = str(uuid.uuid4())
    cl.user_session.set("user_id", fresh_id)

    # Each user gets a dedicated directory for their conversation files.
    os.makedirs(os.path.join(HISTORY_DIR, fresh_id), exist_ok=True)
    return fresh_id
# Function to save chat history
def save_chat_history(user_id, conversation_id, user_message, assistant_response):
    """Append one user/assistant exchange to a conversation's JSON history file.

    The file lives at ``HISTORY_DIR/<user_id>/<conversation_id>.json`` and is
    created (together with the user directory) on first use.

    Args:
        user_id: directory name identifying the user.
        conversation_id: file stem identifying the conversation.
        user_message: text the user sent.
        assistant_response: text the assistant replied with.

    Raises:
        OSError: if the directory or the history file cannot be written.
    """
    user_dir = os.path.join(HISTORY_DIR, user_id)
    os.makedirs(user_dir, exist_ok=True)
    history_file = os.path.join(user_dir, f"{conversation_id}.json")

    # Load the existing history, if any. A file that cannot be parsed must not
    # make every later message in this conversation fail, so it is set aside
    # (kept for inspection) and a fresh history is started.
    history = None
    try:
        with open(history_file, "r", encoding="utf-8") as f:
            history = json.load(f)
    except FileNotFoundError:
        pass
    except ValueError:  # json.JSONDecodeError / UnicodeDecodeError
        os.replace(history_file, f"{history_file}.corrupt")

    if not isinstance(history, dict):
        history = {}
    history.setdefault("conversation_id", conversation_id)
    history.setdefault("created_at", datetime.now().isoformat())
    if not isinstance(history.get("messages"), list):
        history["messages"] = []

    history["messages"].append({
        "timestamp": datetime.now().isoformat(),
        "user": user_message,
        "assistant": assistant_response,
    })

    # Write to a sibling temp file and swap it in, so an interrupted write can
    # never leave a truncated history file behind. The ".tmp" suffix keeps the
    # temp file out of listings that filter on ".json".
    tmp_file = f"{history_file}.tmp"
    with open(tmp_file, "w", encoding="utf-8") as f:
        json.dump(history, f, indent=2)
    os.replace(tmp_file, history_file)
# Function to load chat history for a user
def load_chat_history(user_id, conversation_id=None):
    """Load one conversation, or list metadata for all of a user's conversations.

    Args:
        user_id: directory name identifying the user under ``HISTORY_DIR``.
        conversation_id: when given, load only this conversation.

    Returns:
        * ``[]`` if the user has no history directory.
        * With ``conversation_id``: the parsed JSON content of that
          conversation, or ``None`` if the file is missing or unreadable.
        * Without ``conversation_id``: a list of dicts with keys
          ``conversation_id``, ``created_at`` and ``message_count``, sorted by
          ``created_at`` descending. Files that cannot be parsed or lack those
          fields are skipped instead of aborting the whole listing.
    """

    def _read_json(path):
        # Missing, unreadable, or malformed files all count as "no data".
        try:
            with open(path, "r", encoding="utf-8") as f:
                return json.load(f)
        except (OSError, ValueError):
            return None

    user_dir = os.path.join(HISTORY_DIR, user_id)
    if not os.path.isdir(user_dir):
        return []

    if conversation_id:
        return _read_json(os.path.join(user_dir, f"{conversation_id}.json"))

    # Metadata only: callers get counts, not the message bodies.
    conversations = []
    for filename in os.listdir(user_dir):
        if not filename.endswith(".json"):
            continue
        data = _read_json(os.path.join(user_dir, filename))
        try:
            conversations.append({
                "conversation_id": data["conversation_id"],
                "created_at": data["created_at"],
                "message_count": len(data["messages"]),
            })
        except (KeyError, TypeError):
            # Not a history file in the expected shape; ignore it.
            continue

    return sorted(conversations, key=lambda x: x["created_at"], reverse=True)
@cl.on_chat_start
async def start():
    """Initialise a chat session and optionally resume a stored conversation.

    Stores a fresh ``conversation_id`` in the session, then, if the user has
    stored conversations, offers one button per conversation plus a
    "start new" button. Choosing a stored conversation switches the session
    to that ID and replays its messages. Finally a message carrying the
    "Clear Chat History" action is sent.
    """
    # NOTE(review): if cl.user_session is scoped to a single chat session,
    # get_user_id() yields a new ID every time and the "welcome back" branch
    # can never be reached - confirm how user identity should persist.
    user_id = get_user_id()

    # Default to a brand-new conversation; replaced below if the user resumes.
    conversation_id = str(uuid.uuid4())
    cl.user_session.set("conversation_id", conversation_id)

    conversations = load_chat_history(user_id)

    if conversations:
        await cl.Message(
            content=f"Welcome back! You have {len(conversations)} previous conversations.",
        ).send()

        # AskUserMessage only collects free text; a choice between options is
        # asked with AskActionMessage, the same API the query handler in this
        # file uses for feedback buttons.
        actions = [
            cl.Action(
                name="select_conversation",
                payload={"value": "new"},
                label="Start a new conversation",
            )
        ]
        actions.extend(
            cl.Action(
                name="select_conversation",
                payload={"value": conv["conversation_id"]},
                label=f"Conversation from {conv['created_at'][:10]} ({conv['message_count']} messages)",
            )
            for conv in conversations
        )

        res = await cl.AskActionMessage(
            content="Would you like to continue a previous conversation or start a new one?",
            actions=actions,
            timeout=30,
        ).send()

        # No answer (timeout) means "new". Only IDs that were actually listed
        # are accepted, because the value comes back from the client and is
        # later used to build file paths.
        selected_conv_id = ((res or {}).get("payload") or {}).get("value", "new")
        known_ids = {conv["conversation_id"] for conv in conversations}

        if selected_conv_id in known_ids:
            cl.user_session.set("conversation_id", selected_conv_id)

            conversation = load_chat_history(user_id, selected_conv_id)
            if conversation:
                await cl.Message(content="Loading your previous conversation...").send()
                for msg in conversation["messages"]:
                    await cl.Message(content=msg["user"], author="You").send()
                    await cl.Message(content=msg["assistant"]).send()
    else:
        # First-time user
        await cl.Message(
            content="Welcome to the E-commerce Analytics Assistant! How can I help you today?",
        ).send()

    # Actions are rendered as buttons on a message, so the clear-history
    # action is attached to a short message instead of being sent on its own.
    await cl.Message(
        content="You can clear the saved history of this conversation at any time.",
        actions=[
            cl.Action(
                name="clear_history",
                payload={},
                label="Clear Chat History",
                tooltip="Erase all previous conversation history",
            )
        ],
    ).send()
@cl.on_message
async def on_message(message: cl.Message):
    """Echo the incoming message back and record the exchange in chat history.

    The reply is a placeholder (``"You said: ..."``) meant to be replaced by
    real processing logic.
    """
    # NOTE(review): a second ``@cl.on_message`` handler named ``on_message`` is
    # defined later in this file; it presumably replaces this registration, in
    # which case this handler never runs - confirm which one is intended.
    session = cl.user_session
    current_user = session.get("user_id")
    current_conversation = session.get("conversation_id")

    # Placeholder processing: mirror the user's text.
    reply_text = f"You said: {message.content}"
    await cl.Message(content=reply_text).send()

    # Persist the user/assistant pair for this conversation.
    save_chat_history(
        current_user,
        current_conversation,
        message.content,
        reply_text,
    )
@cl.action_callback("clear_history")
async def clear_history_callback(action):
    """Delete the current conversation's history file and start a new conversation.

    Only the file of the conversation active in this session is removed;
    other stored conversations of the user are left untouched. Afterwards a
    fresh ``conversation_id`` is stored in the session.

    Args:
        action: the triggering Chainlit action (unused).
    """
    user_id = cl.user_session.get("user_id")
    conversation_id = cl.user_session.get("conversation_id")

    # Without both IDs there is no file to locate (and os.path.join would
    # raise TypeError on None), so the deletion step is skipped.
    if user_id and conversation_id:
        history_file = os.path.join(HISTORY_DIR, user_id, f"{conversation_id}.json")
        try:
            os.remove(history_file)
        except FileNotFoundError:
            # Nothing was saved for this conversation yet.
            pass

    await cl.Message(content="Your current conversation history has been cleared.").send()

    # Subsequent messages are recorded under a fresh conversation.
    cl.user_session.set("conversation_id", str(uuid.uuid4()))
async def _save_detailed_feedback(feedback_text):
    """Store the free-text explanation that follows a negative rating."""
    client = cl.user_session.get("client")
    original_query = cl.user_session.get("original_query")
    generated_sql = cl.user_session.get("generated_sql")
    optimized_sql = cl.user_session.get("optimized_sql")

    # Leave feedback mode before saving: if the save raises, the next message
    # must still be treated as a normal query rather than as feedback again.
    cl.user_session.set("awaiting_feedback", False)

    success = save_feedback_to_bigquery(
        client,
        original_query,
        generated_sql,
        optimized_sql,
        f"negative: {feedback_text}",
    )

    if success:
        await cl.Message(content="Thanks for your detailed feedback! I've saved it to improve future responses.", author="SQL Assistant").send()
    else:
        await cl.Message(content="Thanks for your feedback! (Note: There was an issue saving it to the database)", author="SQL Assistant").send()


def _results_to_dataframe(execution_result, client, optimized_sql):
    """Convert a non-empty execution result into a pandas DataFrame.

    Tuple rows carry no column names, so the names are read from the query's
    result schema; if that fails, generic ``Column_<i>`` names are used.
    """
    if not isinstance(execution_result[0], tuple):
        return pd.DataFrame(execution_result)

    try:
        # NOTE(review): this presumably submits the query a second time only
        # to read its schema - consider returning the schema from execution.
        schema = client.query(optimized_sql).result().schema
        column_names = [field.name for field in schema]
        return pd.DataFrame(execution_result, columns=column_names)
    except Exception:
        # Fallback to generic column names
        columns = [f"Column_{i}" for i in range(len(execution_result[0]))]
        return pd.DataFrame(execution_result, columns=columns)


async def _request_result_feedback(df):
    """Ask whether the displayed result was helpful and record the answer."""
    num_rows = len(df)
    num_cols = len(df.columns)

    res = await cl.AskActionMessage(
        content=f"The query returned {num_rows} rows and {num_cols} columns.\n\nWas this result helpful?",
        actions=[
            cl.Action(name="feedback", payload={"value": "positive"}, label="👍 Good results"),
            cl.Action(name="feedback", payload={"value": "negative"}, label="👎 Not what I wanted"),
        ],
    ).send()

    if not res:
        # No answer was given.
        return

    feedback_value = res.get("payload", {}).get("value")
    client = cl.user_session.get("client")
    original_query = cl.user_session.get("original_query")
    generated_sql = cl.user_session.get("generated_sql")
    optimized_sql = cl.user_session.get("optimized_sql")

    if feedback_value == "positive":
        success = save_feedback_to_bigquery(
            client,
            original_query,
            generated_sql,
            optimized_sql,
            "positive",
        )
        if success:
            await cl.Message(content="Thanks for your positive feedback! I've saved it to improve future responses.", author="SQL Assistant").send()
        else:
            await cl.Message(content="Thanks for your feedback! (Note: There was an issue saving it to the database)", author="SQL Assistant").send()
    elif feedback_value == "negative":
        await cl.Message(content="I'm sorry the results weren't what you expected. Please type your feedback about what was wrong.", author="SQL Assistant").send()

        # The next user message is interpreted as the detailed feedback.
        cl.user_session.set("awaiting_feedback", True)

        # Record the bare negative rating right away in case no details follow.
        save_feedback_to_bigquery(
            client,
            original_query,
            generated_sql,
            optimized_sql,
            "negative",
        )


def _persist_exchange(user_message, replies):
    """Best-effort: append this exchange to the session's chat history file."""
    user_id = cl.user_session.get("user_id")
    conversation_id = cl.user_session.get("conversation_id")
    if not (user_id and conversation_id and replies):
        return
    try:
        save_chat_history(user_id, conversation_id, user_message, "\n\n".join(replies))
    except (OSError, ValueError):
        # History is a convenience; a storage problem must not break the chat.
        pass


@cl.on_message
async def on_message(message: cl.Message):
    """Handle a user message: detailed feedback, or a new analytics question.

    If the session is awaiting detailed feedback, the message text is saved
    as negative feedback and nothing else happens. Otherwise the text is run
    through the agent pipeline (table selection, sample data retrieval, SQL
    generation, validation/optimization, execution), progress and results are
    sent to the chat, the exchange is appended to the chat history, and the
    user is asked to rate the result.

    Args:
        message: the incoming Chainlit message; its ``content`` is the query
            or the feedback text.
    """
    query = message.content

    if cl.user_session.get("awaiting_feedback", False):
        await _save_detailed_feedback(query)
        return

    client = cl.user_session.get("client")

    # Kept in the session so feedback can be tied to the question later.
    cl.user_session.set("original_query", query)

    thinking_msg = await cl.Message(content="🤔 Thinking...", author="SQL Assistant").send()

    replies = []  # assistant texts of this turn, later written to history

    async def say(content):
        await cl.Message(content=content, author="SQL Assistant").send()
        replies.append(content)

    async def progress(text):
        thinking_msg.content = text
        await thinking_msg.update()

    result_df = None  # set only once results were displayed successfully

    try:
        # Step 1: Analyze relevant tables
        await progress("🔍 Analyzing relevant tables...")
        state = {"sql_query": query, "client": client}
        tables_state = table_selection_agent(state)
        relevant_tables = tables_state.get("relevant_tables", [])

        # Short pauses between steps are purely cosmetic (pacing in the UI).
        await cl.sleep(1)
        if relevant_tables:
            tables_text = "I've identified these relevant tables for your query:\n\n"
            tables_text += "\n".join([f"- `{table}`" for table in relevant_tables])
            await say(tables_text)

        # Step 2: Retrieve sample data
        await progress("📊 Retrieving sample data...")
        await cl.sleep(1)
        state.update(tables_state)
        sample_data_state = sample_data_retrieval_agent(state)

        # Step 3: Generate SQL
        await progress("💻 Generating SQL query...")
        await cl.sleep(1)
        state.update(sample_data_state)
        sql_state = sql_generation_agent(state)
        generated_sql = sql_state.get("generated_sql", "No SQL generated")
        cl.user_session.set("generated_sql", generated_sql)
        await say(f"Here's the SQL query I generated:\n\n```sql\n{generated_sql}\n```")

        # Step 4: Optimize SQL
        await progress("🔧 Optimizing the query...")
        await cl.sleep(1)
        state.update(sql_state)
        optimization_state = query_validation_and_optimization(state)
        optimized_sql = optimization_state.get("optimized_sql", "No optimized SQL")
        cl.user_session.set("optimized_sql", optimized_sql)
        await say(f"Here's the optimized version of the query:\n\n```sql\n{optimized_sql}\n```")

        # Step 5: Execute query
        await progress("⚙️ Executing query...")
        await cl.sleep(1)
        state.update(optimization_state)
        execution_state = execution_agent(state)
        execution_result = execution_state.get("execution_result", {})

        if isinstance(execution_result, dict) and "error" in execution_result:
            error_msg = execution_result.get("error", "Unknown error occurred")
            await say(f"❌ Error executing query: {error_msg}")
        elif not execution_result:
            await say("✅ Query executed successfully but returned no results.")
        else:
            try:
                frame = _results_to_dataframe(execution_result, client, optimized_sql)
                await say("✅ Query executed successfully! Here are the results:")
                await cl.Message(
                    content="",
                    elements=[cl.Dataframe(data=frame)],
                    author="SQL Assistant",
                ).send()
                # The table itself is not stored in history; keep its size.
                replies.append(f"The query returned {len(frame)} rows and {len(frame.columns)} columns.")
                result_df = frame
            except Exception as e:
                await say(f"❌ Error formatting results: {e}")
    except Exception as e:
        # Top-level boundary: report the failure in the chat instead of
        # letting the handler die silently.
        await progress(f"❌ Error: {e}")
        await say(f"I encountered an error while processing your query: {e}")

    # Saved before the feedback prompt so the exchange is recorded even if the
    # user never answers it.
    _persist_exchange(query, replies)

    if result_df is not None:
        try:
            await _request_result_feedback(result_df)
        except Exception as e:
            await cl.Message(
                content=f"❌ Error formatting results: {e}",
                author="SQL Assistant",
            ).send()
# Callback handlers for actions
@cl.action_callback("feedback")
async def on_feedback_action(action):
    """Record positive feedback sent through a ``feedback`` action.

    Only a payload value of ``"positive"`` is acted upon; any other value is
    ignored by this callback.

    Args:
        action: the triggering action; ``action.payload["value"]`` carries
            the rating.
    """
    rating = action.payload.get("value")

    session = cl.user_session
    bq_client = session.get("client")
    question = session.get("original_query")
    sql_generated = session.get("generated_sql")
    sql_optimized = session.get("optimized_sql")

    if rating != "positive":
        return

    saved = save_feedback_to_bigquery(
        bq_client,
        question,
        sql_generated,
        sql_optimized,
        "positive",
    )

    # Same acknowledgement either way, with a note when persisting failed.
    reply = (
        "Thanks for your positive feedback! I've saved it to improve future responses."
        if saved
        else "Thanks for your feedback! (Note: There was an issue saving it to the database)"
    )
    await cl.Message(content=reply, author="SQL Assistant").send()
@cl.action_callback("feedback_bad")
async def on_feedback_bad(action):
    """Ask for details about a negative rating and store them.

    The user is prompted for free text (up to 300 seconds). The feedback is
    saved as ``"negative: <text>"``, or as plain ``"negative"`` when no text
    was provided.

    Args:
        action: the triggering Chainlit action (unused).
    """
    # NOTE(review): no action named "feedback_bad" is created in the code
    # visible here - confirm this callback is still reachable.

    # AskUserMessage collects a free-text reply by itself; it takes no form
    # elements, and the reply text is returned under the "output" key.
    res = await cl.AskUserMessage(
        content="I'm sorry the results weren't what you expected. Could you please provide more details about what was wrong?",
        author="SQL Assistant",
        timeout=300,
    ).send()

    feedback_details = "negative"
    details = (res.get("output") or "").strip() if res else ""
    if details:
        feedback_details = f"negative: {details}"

    client = cl.user_session.get("client")
    original_query = cl.user_session.get("original_query")
    generated_sql = cl.user_session.get("generated_sql")
    optimized_sql = cl.user_session.get("optimized_sql")

    # Save the feedback to BigQuery
    success = save_feedback_to_bigquery(
        client,
        original_query,
        generated_sql,
        optimized_sql,
        feedback_details,
    )

    if success:
        await cl.Message(content="Thanks for your detailed feedback! I've saved it to improve future responses.", author="SQL Assistant").send()
    else:
        await cl.Message(content="Thanks for your feedback! (Note: There was an issue saving it to the database)", author="SQL Assistant").send()
# Intentionally a no-op: executing this module directly does nothing.
if __name__ == "__main__":
    # Note: Chainlit uses its own CLI command to run the app
    # You'll run this with: chainlit run new_app.py -w
    pass