|
import html
import json
import os
import uuid
from datetime import datetime
from typing import Dict

import pandas as pd
import streamlit as st
from datasets import load_dataset
from dotenv import load_dotenv

from langgraph_agent import DataAnalystAgent
|
|
|
|
|
# Pull environment variables (API keys) from a local .env file when present.
load_dotenv()

# Page configuration must be the first Streamlit call in the script.
st.set_page_config(
    page_icon="π€",
    page_title="π€ LangGraph Data Analyst Agent",
    initial_sidebar_state="expanded",
    layout="wide",
)
|
|
|
|
|
# Inject the app-wide CSS theme: color variables, header banner, info/success/
# error/memory cards, chat-bubble styles, and the "thinking" pulse animation.
# unsafe_allow_html is required so the raw <style> block reaches the browser.
st.markdown(
    """
    <style>
    /* Main theme colors */
    :root {
        --primary-color: #1f77b4;
        --secondary-color: #ff7f0e;
        --success-color: #2ca02c;
        --error-color: #d62728;
        --warning-color: #ff9800;
        --background-color: #0e1117;
        --card-background: #262730;
    }

    /* Custom styling for the main container */
    .main-header {
        background: linear-gradient(90deg, #1f77b4 0%, #ff7f0e 100%);
        padding: 2rem 1rem;
        border-radius: 10px;
        margin-bottom: 2rem;
        text-align: center;
        color: white;
        box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);
    }

    .main-header h1 {
        margin: 0;
        font-size: 2.5rem;
        font-weight: 700;
        text-shadow: 2px 2px 4px rgba(0,0,0,0.3);
    }

    .main-header p {
        margin: 0.5rem 0 0 0;
        font-size: 1.2rem;
        opacity: 0.9;
    }

    /* Card styling */
    .info-card {
        background: var(--card-background);
        padding: 1.5rem;
        border-radius: 10px;
        border-left: 4px solid var(--primary-color);
        margin: 1rem 0;
        box-shadow: 0 2px 4px rgba(0, 0, 0, 0.1);
    }

    .success-card {
        background: linear-gradient(90deg,
            rgba(44, 160, 44, 0.1) 0%,
            rgba(44, 160, 44, 0.05) 100%);
        border-left: 4px solid var(--success-color);
        padding: 1rem;
        border-radius: 8px;
        margin: 1rem 0;
    }

    .error-card {
        background: linear-gradient(90deg,
            rgba(214, 39, 40, 0.1) 0%,
            rgba(214, 39, 40, 0.05) 100%);
        border-left: 4px solid var(--error-color);
        padding: 1rem;
        border-radius: 8px;
        margin: 1rem 0;
    }

    .memory-card {
        background: linear-gradient(90deg,
            rgba(255, 127, 14, 0.1) 0%,
            rgba(255, 127, 14, 0.05) 100%);
        border-left: 4px solid var(--secondary-color);
        padding: 1rem;
        border-radius: 8px;
        margin: 1rem 0;
    }

    /* Chat message styling */
    .user-message {
        background: linear-gradient(90deg,
            rgba(31, 119, 180, 0.1) 0%,
            rgba(31, 119, 180, 0.05) 100%);
        padding: 1rem;
        border-radius: 10px;
        margin: 0.5rem 0;
        border-left: 4px solid var(--primary-color);
    }

    .assistant-message {
        background: linear-gradient(90deg,
            rgba(255, 127, 14, 0.1) 0%,
            rgba(255, 127, 14, 0.05) 100%);
        padding: 1rem;
        border-radius: 10px;
        margin: 0.5rem 0;
        border-left: 4px solid var(--secondary-color);
    }

    .session-info {
        background: var(--card-background);
        padding: 1rem;
        border-radius: 8px;
        margin: 0.5rem 0;
        border: 1px solid rgba(255, 255, 255, 0.1);
        font-size: 0.9rem;
    }

    /* Animation for thinking indicator */
    @keyframes pulse {
        0% { opacity: 1; }
        50% { opacity: 0.5; }
        100% { opacity: 1; }
    }

    .thinking-indicator {
        animation: pulse 2s infinite;
    }
    </style>
    """,
    unsafe_allow_html=True,
)
|
|
|
|
|
|
|
def get_api_configuration():
    """Return the LLM API key from the environment, or halt the script.

    NEBIUS_API_KEY takes precedence over OPENAI_API_KEY. When neither is
    set, setup instructions are rendered and ``st.stop()`` ends the run,
    so callers only ever see a non-empty key.
    """
    api_key = os.environ.get("NEBIUS_API_KEY") or os.environ.get("OPENAI_API_KEY")

    # Guard clause: a configured key means nothing else to do.
    if api_key:
        return api_key

    st.markdown(
        """
        <div class="error-card">
            <h3>π API Key Configuration Required</h3>

            <h4>For Local Development:</h4>
            <ol>
                <li>Create a <code>.env</code> file in your project directory</li>
                <li>Add your API key: <code>NEBIUS_API_KEY=your_api_key_here</code></li>
                <li>Or use OpenAI: <code>OPENAI_API_KEY=your_api_key_here</code></li>
                <li>Restart the application</li>
            </ol>

            <h4>For Deployment:</h4>
            <ol>
                <li>Set environment variable <code>NEBIUS_API_KEY</code> or
                <code>OPENAI_API_KEY</code></li>
                <li>Restart your application</li>
            </ol>
        </div>
        """,
        unsafe_allow_html=True,
    )
    st.stop()
|
|
|
|
|
|
|
@st.cache_resource
def get_agent(api_key: str) -> DataAnalystAgent:
    """Construct the LangGraph agent once per API key; Streamlit reuses it across reruns."""
    agent = DataAnalystAgent(api_key=api_key)
    return agent
|
|
|
|
|
|
|
@st.cache_data
def load_bitext_dataset():
    """Download the Bitext customer-support dataset as a DataFrame.

    Returns ``None`` (after surfacing the error in the UI) when the
    download or conversion fails; the result is cached by Streamlit.
    """
    source = "bitext/Bitext-customer-support-llm-chatbot-training-dataset"
    try:
        raw = load_dataset(source)
        return pd.DataFrame(raw["train"])
    except Exception as e:
        st.error(f"Error loading dataset: {e}")
        return None
|
|
|
|
|
|
|
def initialize_session():
    """Seed session-state defaults on first run (id, history, profile, thread)."""
    # session_id must exist before current_thread_id, which defaults to it.
    if "session_id" not in st.session_state:
        st.session_state["session_id"] = str(uuid.uuid4())

    # st.session_state is dict-like, so setdefault only writes missing keys.
    st.session_state.setdefault("conversation_history", [])
    st.session_state.setdefault("user_profile", {})
    st.session_state.setdefault("current_thread_id", st.session_state["session_id"])
|
|
|
|
|
def create_new_session():
    """Reset state to an empty conversation under a brand-new thread ID."""
    fresh_id = str(uuid.uuid4())
    st.session_state.session_id = fresh_id
    st.session_state.current_thread_id = fresh_id
    st.session_state.conversation_history = []
    st.session_state.user_profile = {}
|
|
|
|
|
def format_conversation_message(role: str, content: str, timestamp: str = None):
    """Render one chat message as an HTML card for st.markdown.

    Args:
        role: ``"human"`` selects the user bubble; any other value is
            treated as the assistant.
        content: Message text. It is HTML-escaped before interpolation,
            because the result is rendered with ``unsafe_allow_html=True``
            and user/model text could otherwise inject markup into the page.
        timestamp: Display time (HH:MM:SS); defaults to the current time.

    Returns:
        An HTML snippet string.
    """
    if timestamp is None:
        timestamp = datetime.now().strftime("%H:%M:%S")

    # Escape untrusted text so it displays literally instead of acting as markup.
    safe_content = html.escape(content)

    # Both bubbles share one template; only the CSS class and speaker differ.
    if role == "human":
        css_class, speaker = "user-message", "π€ You"
    else:
        css_class, speaker = "assistant-message", "π€ Agent"

    return f"""
    <div class="{css_class}">
        <strong>{speaker} ({timestamp}):</strong><br>
        {safe_content}
    </div>
    """
|
|
|
|
|
def display_user_profile(profile: Dict):
    """Show the agent's remembered user profile in a collapsible panel."""
    # Nothing learned yet -> render nothing at all.
    if not profile:
        return

    with st.expander("π§ What I Remember About You", expanded=False):
        left, right = st.columns(2)

        with left:
            st.markdown("**Your Interests:**")
            if interests := profile.get("interests", []):
                for topic in interests:
                    st.write(f"β’ {topic}")
            else:
                st.write("_No interests recorded yet_")

            st.markdown("**Expertise Level:**")
            level = profile.get("expertise_level", "beginner")
            st.write(f"β’ {level.title()}")

        with right:
            st.markdown("**Your Preferences:**")
            if prefs := profile.get("preferences", {}):
                for name, setting in prefs.items():
                    st.write(f"β’ {name}: {setting}")
            else:
                st.write("_No preferences recorded yet_")

            st.markdown("**Recent Query Topics:**")
            if recent := profile.get("query_history", []):
                # Show only the last three queries, truncated for display.
                for past in recent[-3:]:
                    st.write(f"β’ {past[:50]}...")
            else:
                st.write("_No query history yet_")
|
|
|
|
|
def main():
    """Top-level Streamlit page: header, setup, sidebar, chat loop, and help."""
    # --- Page header banner ---
    st.markdown(
        """
        <div class="main-header">
            <h1>π€ LangGraph Data Analyst Agent</h1>
            <p>Intelligent Analysis with Memory & Recommendations</p>
        </div>
        """,
        unsafe_allow_html=True,
    )

    # --- One-time setup: session state, API key, agent, dataset ---
    initialize_session()

    api_key = get_api_configuration()  # st.stop()s the script when no key is set

    agent = get_agent(api_key)  # cached across reruns via st.cache_resource

    with st.spinner("π Loading dataset..."):
        df = load_bitext_dataset()

    if df is None:
        # Dataset download failed; show an error card and abandon the page.
        st.markdown(
            """
            <div class="error-card">
                <h3>β Dataset Loading Failed</h3>
                <p>Failed to load dataset. Please check your connection and try again.</p>
            </div>
            """,
            unsafe_allow_html=True,
        )
        return

    st.markdown(
        f"""
        <div class="success-card">
            <h3>β
            System Ready</h3>
            <p>Dataset loaded with <strong>{len(df):,}</strong> records.
            LangGraph agent initialized with memory.</p>
        </div>
        """,
        unsafe_allow_html=True,
    )

    # --- Sidebar: session management, dataset stats, example queries ---
    with st.sidebar:
        st.markdown("## βοΈ Session Management")

        st.markdown("### π Session Control")

        col1, col2 = st.columns(2)
        with col1:
            if st.button("π New Session", use_container_width=True):
                create_new_session()
                st.rerun()

        with col2:
            if st.button("π Refresh", use_container_width=True):
                st.rerun()

        # Show which conversation thread the user is currently attached to.
        st.markdown(
            f"""
            <div class="session-info">
                <strong>Current Session:</strong><br>
                <code>{st.session_state.current_thread_id[:8]}...</code><br>
                <strong>Messages:</strong> {len(st.session_state.conversation_history)}
            </div>
            """,
            unsafe_allow_html=True,
        )

        st.markdown("### π Join Existing Session")
        custom_thread_id = st.text_input(
            "Enter Session ID:",
            placeholder="Enter full session ID to join...",
            help="Use this to resume a previous conversation",
        )

        if st.button("π Join Session") and custom_thread_id:
            st.session_state.current_thread_id = custom_thread_id

            # Pull the joined thread's history and profile from the agent's store.
            history = agent.get_conversation_history(custom_thread_id)
            st.session_state.conversation_history = history

            profile = agent.get_user_profile(custom_thread_id)
            st.session_state.user_profile = profile
            st.success(f"Joined session: {custom_thread_id[:8]}...")
            st.rerun()

        st.markdown("---")

        st.markdown("### π Dataset Info")
        col1, col2 = st.columns(2)
        with col1:
            st.metric("π Records", f"{len(df):,}")
        with col2:
            st.metric("π Categories", len(df["category"].unique()))

        st.metric("π― Intents", len(df["intent"].unique()))

        st.markdown("### π‘ Try These Queries")
        example_queries = [
            "What are the most common categories?",
            "Show me examples of billing issues",
            "Summarize the refund category",
            "What should I query next?",
            "What do you remember about me?",
        ]

        for query in example_queries:
            # NOTE(review): hash() is salted per process, but widget keys only
            # need to be stable within a single run, so this is acceptable.
            if st.button(f"π¬ {query}", key=f"example_{hash(query)}"):
                st.session_state.pending_query = query
                st.rerun()

    # --- Profile panel (only once the agent has learned something) ---
    if st.session_state.user_profile:
        display_user_profile(st.session_state.user_profile)

    # --- Dataset overview expander ---
    with st.expander("π Dataset Information", expanded=False):
        st.markdown("### Dataset Details")

        metrics_col1, metrics_col2, metrics_col3, metrics_col4 = st.columns(4)
        with metrics_col1:
            st.metric("Total Records", f"{len(df):,}")
        with metrics_col2:
            st.metric("Columns", len(df.columns))
        with metrics_col3:
            st.metric("Categories", len(df["category"].unique()))
        with metrics_col4:
            st.metric("Intents", len(df["intent"].unique()))

        st.markdown("### Sample Data")
        st.dataframe(df.head(), use_container_width=True)

        st.markdown("### Category Distribution")
        st.bar_chart(df["category"].value_counts())

    # --- Chat input ---
    st.markdown("## π¬ Chat with the Agent")

    # A sidebar example button stores its text in pending_query and reruns;
    # consume (and delete) it here so it is submitted exactly once.
    has_pending_query = hasattr(st.session_state, "pending_query")
    if has_pending_query:
        user_question = st.session_state.pending_query
        delattr(st.session_state, "pending_query")
    else:
        user_question = st.text_input(
            "Ask your question:",
            placeholder="e.g., What are the most common customer issues?",
            key="user_input",
            help="Ask about statistics, examples, insights, or request recommendations",
        )

    col1, col2, col3 = st.columns([1, 2, 1])
    with col2:
        submit_clicked = st.button("π Send Message", use_container_width=True)

    # --- Submit the question to the agent ---
    if (submit_clicked or has_pending_query) and user_question:
        timestamp = datetime.now().strftime("%H:%M:%S")
        st.session_state.conversation_history.append(
            {"role": "human", "content": user_question, "timestamp": timestamp}
        )

        # Temporary "thinking" banner shown while the graph runs; cleared in finally.
        thinking_placeholder = st.empty()
        thinking_placeholder.markdown(
            """
            <div class="thinking-indicator">
                <div class="info-card">
                    βοΈ <strong>Agent is thinking...</strong>
                    Processing your query through the LangGraph workflow.
                </div>
            </div>
            """,
            unsafe_allow_html=True,
        )

        try:
            result = agent.invoke(user_question, st.session_state.current_thread_id)

            # Walk the returned messages newest-first and pick the last
            # assistant reply, skipping tool-call placeholders and human turns.
            assistant_response = None
            for msg in reversed(result["messages"]):
                if (
                    hasattr(msg, "content")
                    and msg.content
                    # type(user_question) is str, so this skips plain-string
                    # entries -- presumably raw human input mixed into the
                    # message list; verify against the agent's message types.
                    and not isinstance(msg, type(user_question))
                ):
                    if not hasattr(msg, "tool_calls") or not msg.tool_calls:
                        if "human" not in str(type(msg)).lower():
                            content = msg.content

                            # Strip a <think>...</think> reasoning block some
                            # models emit; keep only the visible answer after it.
                            if "<think>" in content and "</think>" in content:
                                parts = content.split("</think>")
                                if len(parts) > 1:
                                    content = parts[1].strip()

                            assistant_response = content
                            break

            if not assistant_response:
                assistant_response = "I processed your query but couldn't generate a response. Please try again."

            st.session_state.conversation_history.append(
                {
                    "role": "assistant",
                    "content": assistant_response,
                    "timestamp": datetime.now().strftime("%H:%M:%S"),
                }
            )

            # Keep the locally cached profile in sync with the agent's memory.
            if result.get("user_profile"):
                st.session_state.user_profile = result["user_profile"]

        except Exception as e:
            # Surface agent failures as a chat message instead of crashing the app.
            error_msg = f"Sorry, I encountered an error: {str(e)}"
            st.session_state.conversation_history.append(
                {
                    "role": "assistant",
                    "content": error_msg,
                    "timestamp": datetime.now().strftime("%H:%M:%S"),
                }
            )

        finally:
            thinking_placeholder.empty()

        # Rerun so the freshly appended messages are rendered below.
        st.rerun()

    # --- Transcript and per-conversation actions ---
    if st.session_state.conversation_history:
        st.markdown("## π Conversation")

        for i, message in enumerate(st.session_state.conversation_history):
            message_html = format_conversation_message(
                message["role"], message["content"], message.get("timestamp", "")
            )
            st.markdown(message_html, unsafe_allow_html=True)

            # Divider between messages, but not after the last one.
            if i < len(st.session_state.conversation_history) - 1:
                st.markdown("---")

        col1, col2, col3 = st.columns(3)

        with col1:
            if st.button("ποΈ Clear Chat"):
                st.session_state.conversation_history = []
                st.rerun()

        with col2:
            if st.button("πΎ Export Chat"):
                chat_data = {
                    "session_id": st.session_state.current_thread_id,
                    "timestamp": datetime.now().isoformat(),
                    "conversation": st.session_state.conversation_history,
                    "user_profile": st.session_state.user_profile,
                }
                st.download_button(
                    label="π₯ Download JSON",
                    data=json.dumps(chat_data, indent=2),
                    file_name=f"chat_export_{st.session_state.current_thread_id[:8]}.json",
                    mime="application/json",
                )

        with col3:
            if st.button("π€ Get Recommendations"):
                st.session_state.pending_query = "What should I query next?"
                st.rerun()

    # --- Static usage help ---
    with st.expander("π How to Use This Agent", expanded=False):
        st.markdown(
            """
            ### π― Query Types Supported:

            **Structured Queries (Quantitative):**
            - "How many records are in each category?"
            - "Show me 5 examples of billing issues"
            - "What are the most common intents?"

            **Unstructured Queries (Qualitative):**
            - "Summarize the refund category"
            - "What patterns do you see in payment issues?"
            - "Analyze customer sentiment in billing conversations"

            **Memory & Recommendations:**
            - "What do you remember about me?"
            - "What should I query next?"
            - "Advise me what to explore"

            ### π§ Memory Features:
            - **Session Persistence:** Your conversations are saved across page reloads
            - **User Profile:** The agent learns about your interests and preferences
            - **Query History:** Past queries influence future recommendations
            - **Cross-Session:** Use session IDs to resume conversations later

            ### π§ Advanced Features:
            - **Multi-Agent Architecture:** Separate agents for different query types
            - **Tool Usage:** Dynamic tool selection based on your needs
            - **Interactive Recommendations:** Collaborative query refinement
            """
        )
|
|
|
|
|
# Standard script entry point: run the Streamlit app.
if __name__ == "__main__":
    main()
|
|