"""LangGraph-based data analyst agent for the Bitext customer-support dataset.

Classifies user queries, routes them to structured or unstructured analysis
agents backed by pandas tools, tracks a per-thread user profile, and can
recommend follow-up queries.
"""

import json
import os
from enum import Enum
from typing import Any, Dict, List, Optional, TypedDict

import pandas as pd
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
from langchain_core.tools import tool
from langchain_openai import ChatOpenAI
from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import END, START, StateGraph
from langgraph.prebuilt import ToolNode
from pydantic import BaseModel, Field


class QueryType(str, Enum):
    STRUCTURED = "structured"
    UNSTRUCTURED = "unstructured"
    OUT_OF_SCOPE = "out_of_scope"
    RECOMMEND_QUERY = "recommend_query"


class AnalysisType(str, Enum):
    QUANTITATIVE = "quantitative"
    QUALITATIVE = "qualitative"
    OUT_OF_SCOPE = "out_of_scope"


class AgentState(TypedDict):
    messages: List[Any]
    query_type: Optional[str]
    analysis_result: Optional[Dict[str, Any]]
    user_profile: Optional[Dict[str, Any]]
    session_context: Optional[Dict[str, Any]]
    recommendations: Optional[List[str]]


class UserProfile(BaseModel):
    interests: List[str] = Field(default_factory=list)
    query_history: List[str] = Field(default_factory=list)
    preferences: Dict[str, Any] = Field(default_factory=dict)
    expertise_level: str = "beginner"


class DatasetManager:
    """Singleton that lazily loads the Bitext dataset into a shared DataFrame."""

    _instance = None
    _df = None

    def __new__(cls):
        if cls._instance is None:
            cls._instance = super(DatasetManager, cls).__new__(cls)
        return cls._instance

    def get_dataset(self) -> pd.DataFrame:
        if self._df is None:
            from datasets import load_dataset

            dataset = load_dataset(
                "bitext/Bitext-customer-support-llm-chatbot-training-dataset"
            )
            self._df = pd.DataFrame(dataset["train"])
        return self._df


@tool
def get_category_distribution() -> Dict[str, int]:
    """Get the distribution of categories in the dataset."""
    df = DatasetManager().get_dataset()
    return df["category"].value_counts().to_dict()


@tool
def get_intent_distribution() -> Dict[str, int]:
    """Get the distribution of intents in the dataset."""
    df = DatasetManager().get_dataset()
    return df["intent"].value_counts().to_dict()


@tool
def get_dataset_stats() -> Dict[str, Any]:
    """Get basic statistics about the dataset."""
    df = DatasetManager().get_dataset()
    return {
        "total_records": len(df),
        "unique_categories": len(df["category"].unique()),
        "unique_intents": len(df["intent"].unique()),
        "columns": df.columns.tolist(),
    }


@tool
def get_examples_by_category(category: str, n: int = 5) -> List[Dict[str, Any]]:
    """Get examples from a specific category."""
    df = DatasetManager().get_dataset()
    filtered_df = df[df["category"].str.lower() == category.lower()]
    if filtered_df.empty:
        return []
    return filtered_df.head(n).to_dict("records")


@tool
def get_examples_by_intent(intent: str, n: int = 5) -> List[Dict[str, Any]]:
    """Get examples from a specific intent."""
    df = DatasetManager().get_dataset()
    filtered_df = df[df["intent"].str.lower() == intent.lower()]
    if filtered_df.empty:
        return []
    return filtered_df.head(n).to_dict("records")


@tool
def search_conversations(query: str, n: int = 5) -> List[Dict[str, Any]]:
    """Search for conversations containing specific keywords."""
    df = DatasetManager().get_dataset()
    # NOTE: the raw Bitext dataset stores the customer text in "instruction"
    # and the agent reply in "response" (there are no "customer"/"agent"
    # columns), so search both of those columns.
    mask = df["instruction"].str.contains(query, case=False, na=False) | df[
        "response"
    ].str.contains(query, case=False, na=False)
    filtered_df = df[mask]
    return filtered_df.head(n).to_dict("records")


@tool
def get_category_summary(category: str) -> Dict[str, Any]:
    """Get a summary of conversations in a specific category."""
    df = DatasetManager().get_dataset()
    filtered_df = df[df["category"].str.lower() == category.lower()]
    if filtered_df.empty:
        return {"error": f"No data found for category: {category}"}

    return {
        "category": category,
        "count": len(filtered_df),
        "unique_intents": filtered_df["intent"].nunique(),
        "intents": filtered_df["intent"].value_counts().to_dict(),
        "sample_conversations": filtered_df.head(3).to_dict("records"),
    }


@tool
def get_intent_summary(intent: str) -> Dict[str, Any]:
    """Get a summary of conversations for a specific intent."""
    df = DatasetManager().get_dataset()
    filtered_df = df[df["intent"].str.lower() == intent.lower()]
    if filtered_df.empty:
        return {"error": f"No data found for intent: {intent}"}

    return {
        "intent": intent,
        "count": len(filtered_df),
        "categories": filtered_df["category"].value_counts().to_dict(),
        "sample_conversations": filtered_df.head(3).to_dict("records"),
    }


@tool
def update_user_profile(
    interests: List[str], preferences: Dict[str, Any], expertise_level: str = "beginner"
) -> Dict[str, Any]:
    """Update the user's profile with new information."""
    return {
        "interests": interests,
        "preferences": preferences,
        "expertise_level": expertise_level,
        "updated": True,
    }


structured_tools = [
    get_category_distribution,
    get_intent_distribution,
    get_dataset_stats,
    get_examples_by_category,
    get_examples_by_intent,
    search_conversations,
]

unstructured_tools = [
    get_category_summary,
    get_intent_summary,
    search_conversations,
    get_examples_by_category,
    get_examples_by_intent,
]

memory_tools = [update_user_profile]


class DataAnalystAgent:
    def __init__(self, api_key: str, model_name: Optional[str] = None):
        # Use the Nebius OpenAI-compatible endpoint when the caller passes the
        # Nebius key; otherwise fall back to the standard OpenAI client.
        is_nebius = os.environ.get("NEBIUS_API_KEY") == api_key

        if is_nebius:
            self.llm = ChatOpenAI(
                api_key=api_key,
                model=model_name or "Qwen/Qwen3-30B-A3B",
                base_url="https://api.studio.nebius.com/v1",
                temperature=0,
            )
        else:
            self.llm = ChatOpenAI(
                api_key=api_key, model=model_name or "gpt-4o", temperature=0
            )

        self.memory = MemorySaver()
        self.graph = self._build_graph()

    def _build_graph(self) -> StateGraph:
        """Build the LangGraph workflow."""
        builder = StateGraph(AgentState)

        # Nodes: the classifier routes each query, the two agents call their
        # tool sets, and the terminal nodes summarize, recommend, or decline.
        builder.add_node("classifier", self._classify_query)
        builder.add_node("structured_agent", self._structured_agent)
        builder.add_node("unstructured_agent", self._unstructured_agent)
        builder.add_node("structured_tools", ToolNode(structured_tools))
        builder.add_node("unstructured_tools", ToolNode(unstructured_tools))
        builder.add_node("summarizer", self._update_summary)
        builder.add_node("recommender", self._recommend_queries)
        builder.add_node("out_of_scope", self._handle_out_of_scope)

        builder.add_edge(START, "classifier")

        # Route based on the classifier's decision.
        builder.add_conditional_edges(
            "classifier",
            self._route_query,
            {
                "structured": "structured_agent",
                "unstructured": "unstructured_agent",
                "out_of_scope": "out_of_scope",
                "recommend_query": "recommender",
            },
        )

        # Each agent either calls tools or hands off to the summarizer.
        builder.add_conditional_edges(
            "structured_agent",
            self._should_use_tools,
            {"tools": "structured_tools", "end": "summarizer"},
        )

        builder.add_conditional_edges(
            "unstructured_agent",
            self._should_use_tools,
            {"tools": "unstructured_tools", "end": "summarizer"},
        )

        # Tool results feed back into the agent that requested them.
        builder.add_edge("structured_tools", "structured_agent")
        builder.add_edge("unstructured_tools", "unstructured_agent")

        builder.add_edge("summarizer", END)
        builder.add_edge("out_of_scope", END)
        builder.add_edge("recommender", END)

        return builder.compile(checkpointer=self.memory)

    def _classify_query(self, state: AgentState) -> AgentState:
        """Classify the user query into different types."""
        last_message = state["messages"][-1]
        user_query = last_message.content.lower()

        # Keyword heuristics: recommendation requests first, then clearly
        # off-topic subjects, then qualitative analysis; default to structured.
        if any(
            word in user_query
            for word in [
                "what should i",
                "what to query",
                "recommend",
                "suggest",
                "advise",
                "what next",
                "what can i ask",
            ]
        ):
            query_type = "recommend_query"

        elif any(
            word in user_query
            for word in [
                "weather",
                "news",
                "sports",
                "politics",
                "cooking",
                "travel",
                "music",
                "movies",
                "games",
                "programming",
                "code",
            ]
        ) and not any(
            word in user_query
            for word in ["category", "intent", "customer", "support", "data", "records"]
        ):
            query_type = "out_of_scope"

        elif any(
            word in user_query
            for word in [
                "summarize",
                "summary",
                "patterns",
                "insights",
                "analysis",
                "analyze",
                "themes",
                "trends",
                "what patterns",
                "understand",
            ]
        ):
            query_type = "unstructured"

        else:
            query_type = "structured"

        # Double-check apparent out-of-scope queries with a cheap yes/no call
        # so borderline questions are not rejected on keywords alone.
        if query_type == "out_of_scope":
            simple_prompt = f"""
            Is this question about customer support data analysis?
            Question: "{last_message.content}"

            Answer only "yes" or "no".
            """

            try:
                response = self.llm.invoke([HumanMessage(content=simple_prompt)])
                if "yes" in response.content.lower():
                    query_type = "structured"
            except Exception:
                pass

        state["query_type"] = query_type
        return state

    def _route_query(self, state: AgentState) -> str:
        """Route to the appropriate agent based on classification."""
        return state["query_type"]

    def _structured_agent(self, state: AgentState) -> AgentState:
        """Handle structured/quantitative queries."""
        system_prompt = """
        You are a data analyst that MUST use tools to answer questions about
        customer support data. You have access to these tools:

        - get_category_distribution: Get category counts
        - get_intent_distribution: Get intent counts
        - get_dataset_stats: Get basic dataset statistics
        - get_examples_by_category: Get examples from a category
        - get_examples_by_intent: Get examples from an intent
        - search_conversations: Search for conversations with keywords

        IMPORTANT: Always use the appropriate tool to get real data.
        Do NOT make up or guess answers. Use tools to get actual numbers.

        For questions about:
        - "How many categories" or "category distribution" → use get_category_distribution
        - "How many intents" or "intent distribution" → use get_intent_distribution
        - "Total records" or "dataset size" → use get_dataset_stats
        - "Examples of X" → use get_examples_by_category or get_examples_by_intent
        - "Search for X" → use search_conversations
        """

        llm_with_tools = self.llm.bind_tools(structured_tools)
        messages = [SystemMessage(content=system_prompt)] + state["messages"]
        response = llm_with_tools.invoke(messages)

        state["messages"].append(response)
        return state

    def _unstructured_agent(self, state: AgentState) -> AgentState:
        """Handle unstructured/qualitative queries."""
        system_prompt = """
        You are a data analyst that MUST use tools to provide insights about
        customer support data. You have access to these tools:

        - get_category_summary: Get detailed summary of a category
        - get_intent_summary: Get detailed summary of an intent
        - search_conversations: Search conversations for patterns
        - get_examples_by_category: Get examples to analyze patterns
        - get_examples_by_intent: Get examples to analyze patterns

        IMPORTANT: Always use the appropriate tool to get real data.
        Do NOT make up or guess insights. Use tools to get actual data first.

        For questions about:
        - "Summarize X category" → use get_category_summary
        - "Analyze X intent" → use get_intent_summary
        - "Patterns in X" → use get_examples_by_category or search_conversations
        """

        llm_with_tools = self.llm.bind_tools(unstructured_tools)
        messages = [SystemMessage(content=system_prompt)] + state["messages"]
        response = llm_with_tools.invoke(messages)

        state["messages"].append(response)
        return state

    def _should_use_tools(self, state: AgentState) -> str:
        """Determine whether the agent should use tools or end."""
        last_message = state["messages"][-1]

        # If the model requested tool calls, run them.
        if hasattr(last_message, "tool_calls") and last_message.tool_calls:
            return "tools"

        # Fallback heuristic: early in the conversation, data-sounding
        # questions should still go through the tool node even if the model
        # did not emit an explicit tool call.
        messages = state["messages"]
        human_messages = [msg for msg in messages if isinstance(msg, HumanMessage)]

        if len(human_messages) >= 1:
            last_human_msg = human_messages[-1].content.lower()

            needs_tools = any(
                word in last_human_msg
                for word in [
                    "how many",
                    "show me",
                    "examples",
                    "distribution",
                    "categories",
                    "intents",
                    "records",
                    "statistics",
                    "stats",
                    "count",
                    "total",
                    "billing",
                    "refund",
                    "payment",
                    "technical",
                    "support",
                ]
            )

            ai_messages = [msg for msg in messages if not isinstance(msg, HumanMessage)]
            if needs_tools and len(ai_messages) <= 1:
                return "tools"

        return "end"

    def _update_summary(self, state: AgentState) -> AgentState:
        """Update the user profile/summary based on the interaction."""
        user_profile = state.get("user_profile", {})
        last_human_message = None

        # Find the most recent human message in the conversation.
        for msg in reversed(state["messages"]):
            if isinstance(msg, HumanMessage):
                last_human_message = msg
                break

        if last_human_message:
            system_prompt = """
            Based on the user's question, extract information about their
            interests and update their profile. Consider:
            - What categories/intents they're interested in
            - Their level of technical detail preference
            - Types of analysis they prefer

            Return a JSON with:
            {
                "interests": ["list of topics they seem interested in"],
                "preferences": {"any preferences about analysis style"},
                "expertise_level": "beginner/intermediate/advanced"
            }

            If no clear information can be extracted, return empty lists/dicts.
            """

            messages = [
                SystemMessage(content=system_prompt),
                HumanMessage(content=f"User question: {last_human_message.content}"),
            ]

            try:
                response = self.llm.invoke(messages)
                profile_update = json.loads(response.content)

                if not user_profile:
                    user_profile = {
                        "interests": [],
                        "preferences": {},
                        "expertise_level": "beginner",
                        "query_history": [],
                    }

                # Merge new interests with existing ones, deduplicated.
                new_interests = profile_update.get("interests", [])
                existing_interests = user_profile.get("interests", [])
                user_profile["interests"] = list(
                    set(existing_interests + new_interests)
                )

                user_profile["preferences"].update(
                    profile_update.get("preferences", {})
                )

                if profile_update.get("expertise_level"):
                    user_profile["expertise_level"] = profile_update["expertise_level"]

                if "query_history" not in user_profile:
                    user_profile["query_history"] = []
                user_profile["query_history"].append(last_human_message.content)

                # Keep only the ten most recent queries.
                user_profile["query_history"] = user_profile["query_history"][-10:]

            except Exception:
                # If the model's reply is not valid JSON, still record the query.
                if not user_profile:
                    user_profile = {"query_history": []}
                if "query_history" not in user_profile:
                    user_profile["query_history"] = []
                user_profile["query_history"].append(last_human_message.content)
                user_profile["query_history"] = user_profile["query_history"][-10:]

        state["user_profile"] = user_profile
        return state

    def _recommend_queries(self, state: AgentState) -> AgentState:
        """Recommend next queries based on conversation history and user profile."""
        user_profile = state.get("user_profile", {})
        query_history = user_profile.get("query_history", [])
        interests = user_profile.get("interests", [])

        df = DatasetManager().get_dataset()
        categories = df["category"].unique().tolist()
        intents = df["intent"].unique()[:20].tolist()

        system_prompt = f"""
        You are a query recommendation assistant. Based on the user's conversation
        history and interests, suggest relevant follow-up questions they could ask
        about the customer support dataset.

        User's query history: {query_history}
        User's interests: {interests}

        Available categories: {categories}
        Sample intents: {intents}

        Suggest 3-5 relevant questions the user might want to ask next. Consider:
        - Natural follow-ups to their previous questions
        - Related categories or intents they haven't explored
        - Different types of analysis (if they've only done quantitative,
          suggest qualitative and vice versa)

        Be conversational and explain why each suggestion might be interesting.
        Start with "Based on your previous queries, you might want to..."
        """

        messages = [SystemMessage(content=system_prompt)]

        if state["messages"]:
            messages.extend(state["messages"])
        else:
            messages.append(HumanMessage(content="What should I query next?"))

        response = self.llm.invoke(messages)
        state["messages"].append(response)

        return state

    def _handle_out_of_scope(self, state: AgentState) -> AgentState:
        """Handle queries that are out of scope."""
        response = AIMessage(
            content="I'm sorry, but I can only answer questions about the customer "
            "support dataset. Please ask questions about categories, intents, "
            "conversation examples, or data statistics."
        )
        state["messages"].append(response)
        return state

    def invoke(self, message: str, thread_id: str) -> Dict[str, Any]:
        """Invoke the agent with a message and thread ID."""
        config = {"configurable": {"thread_id": thread_id}}

        input_state = {"messages": [HumanMessage(content=message)]}

        result = self.graph.invoke(input_state, config)

        return result

    def get_conversation_history(self, thread_id: str) -> List[Dict[str, Any]]:
        """Get conversation history for a thread."""
        config = {"configurable": {"thread_id": thread_id}}

        try:
            state = self.graph.get_state(config)
            if state and state.values.get("messages"):
                return [
                    {
                        "role": (
                            "human" if isinstance(msg, HumanMessage) else "assistant"
                        ),
                        "content": msg.content,
                    }
                    for msg in state.values["messages"]
                ]
        except Exception:
            pass

        return []

    def get_user_profile(self, thread_id: str) -> Dict[str, Any]:
        """Get user profile for a thread."""
        config = {"configurable": {"thread_id": thread_id}}

        try:
            state = self.graph.get_state(config)
            if state and state.values.get("user_profile"):
                return state.values["user_profile"]
        except Exception:
            pass

        return {}
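

if __name__ == "__main__":
    # Minimal local smoke test, not part of the library API: a sketch showing
    # how the agent is intended to be driven. It assumes either NEBIUS_API_KEY
    # or OPENAI_API_KEY is set in the environment (OPENAI_API_KEY is an assumed
    # name for the non-Nebius case), plus network access and the `datasets`
    # package so the Bitext dataset can be downloaded on first use.
    api_key = os.environ.get("NEBIUS_API_KEY") or os.environ.get("OPENAI_API_KEY", "")
    agent = DataAnalystAgent(api_key=api_key)

    # Each thread_id gets its own checkpointed conversation state via MemorySaver.
    result = agent.invoke("How many categories are in the dataset?", thread_id="demo")
    print(result["messages"][-1].content)

    # The summarizer node persists a lightweight per-thread user profile.
    print(agent.get_user_profile("demo"))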