# Spaces status: Runtime error
# Import modules | |
from typing import TypedDict, Dict | |
from langgraph.graph import StateGraph, END | |
from langchain_core.prompts import ChatPromptTemplate | |
from langchain_core.runnables.graph import MermaidDrawMethod | |
# from IPython.display import Image, display | |
import gradio as gr | |
import os | |
from langchain_groq import ChatGroq | |
# Define the State data structure
class State(TypedDict):
    """Shared state passed between the nodes of the support workflow.

    Every field is a plain string. ``query`` is the caller's input; the
    remaining fields are written by the workflow's node functions.
    """

    # Raw customer query supplied when the graph is invoked.
    query: str
    # Category label written by ``categorize``.
    category: str
    # Sentiment label written by ``analyze_sentiment``.
    sentiment: str
    # Final reply written by one of the handlers or by ``escalate``.
    response: str
# Function to get the language model
def get_llm(api_key=None):
    """Build the Groq chat model used by every LLM-backed node.

    Args:
        api_key: Groq API key. Surrounding whitespace is stripped. When the
            key is ``None`` or empty after stripping, the ``GROQ_API_KEY``
            environment variable is used instead. A blank Gradio textbox
            submits ``""``, so an ``is None`` check alone would never
            reach that fallback.

    Returns:
        A ``ChatGroq`` client for ``llama-3.3-70b-versatile`` with
        ``temperature=0``.
    """
    if isinstance(api_key, str):
        # Keys pasted into the UI often carry stray spaces/newlines.
        api_key = api_key.strip()
    if not api_key:
        # Treat a blank key like a missing one so the env-var fallback
        # still applies when the input field is left empty.
        api_key = os.getenv('GROQ_API_KEY')
    return ChatGroq(
        temperature=0,
        groq_api_key=api_key,
        model_name="llama-3.3-70b-versatile",
    )
# Define the processing functions
def categorize(state: State, llm) -> State:
    """Ask the LLM to classify the query as Technical, Billing or General.

    The stripped reply text is stored in ``state["category"]`` and the same
    (mutated) state is returned.

    The prompt explicitly asks for the bare category name, matching the
    instruction style of the sentiment prompt. The router compares this
    value against the category names, so a verbose answer such as "This is
    a Billing issue." would otherwise fall through to the general handler.

    Args:
        state: Workflow state; ``state["query"]`` is read.
        llm: Chat model composed with the prompt via ``|``. Its reply is
            expected to expose ``.content``.

    Returns:
        The same ``state`` object, with ``category`` set.
    """
    prompt = ChatPromptTemplate.from_template(
        "Categorize the following customer query into one of these categories: "
        "Technical, Billing, General. "
        "Respond with only the category name. Query: {query}"
    )
    chain = prompt | llm
    category = chain.invoke({"query": state["query"]}).content.strip()
    state["category"] = category
    return state
def analyze_sentiment(state: State, llm) -> State:
    """Label the query's sentiment using the LLM.

    The model is asked to answer with 'Positive', 'Neutral' or 'Negative'.
    Its stripped reply is stored verbatim in ``state["sentiment"]`` (it is
    not validated here), and the same (mutated) state is returned.

    Args:
        state: Workflow state; ``state["query"]`` is read.
        llm: Chat model composed with the prompt via ``|``.

    Returns:
        The same ``state`` object, with ``sentiment`` set.
    """
    sentiment_prompt = ChatPromptTemplate.from_template(
        "Analyze the sentiment of the following customer query. "
        "Respond with either 'Positive', 'Neutral', or 'Negative'. Query: {query}"
    )
    reply = (sentiment_prompt | llm).invoke({"query": state["query"]})
    state["sentiment"] = reply.content.strip()
    return state
def handle_technical(state: State, llm) -> State:
    """Generate a technical-support reply for the query.

    The model's stripped reply text is written to ``state["response"]`` and
    the same (mutated) state is returned.

    Args:
        state: Workflow state; ``state["query"]`` is read.
        llm: Chat model composed with the prompt via ``|``.

    Returns:
        The same ``state`` object, with ``response`` set.
    """
    template = ChatPromptTemplate.from_template(
        "Provide a technical support response to the following query: {query}"
    )
    answer = (template | llm).invoke({"query": state["query"]})
    reply_text = answer.content.strip()
    state["response"] = reply_text
    return state
def handle_billing(state: State, llm) -> State:
    """Generate a billing-support reply for the query.

    The model's stripped reply text is written to ``state["response"]`` and
    the same (mutated) state is returned.

    Args:
        state: Workflow state; ``state["query"]`` is read.
        llm: Chat model composed with the prompt via ``|``.

    Returns:
        The same ``state`` object, with ``response`` set.
    """
    template = ChatPromptTemplate.from_template(
        "Provide a billing-related support response to the following query: {query}"
    )
    answer = (template | llm).invoke({"query": state["query"]})
    reply_text = answer.content.strip()
    state["response"] = reply_text
    return state
def handle_general(state: State, llm) -> State:
    """Generate a general-support reply for the query.

    This is the router's fallback branch for queries that are neither
    technical nor billing. The model's stripped reply text is written to
    ``state["response"]`` and the same (mutated) state is returned.

    Args:
        state: Workflow state; ``state["query"]`` is read.
        llm: Chat model composed with the prompt via ``|``.

    Returns:
        The same ``state`` object, with ``response`` set.
    """
    template = ChatPromptTemplate.from_template(
        "Provide a general support response to the following query: {query}"
    )
    answer = (template | llm).invoke({"query": state["query"]})
    reply_text = answer.content.strip()
    state["response"] = reply_text
    return state
def escalate(state: State) -> State:
    """Hand the query to a human agent instead of generating a reply.

    No LLM call is made. ``state["response"]`` is overwritten with a fixed
    escalation notice and the same (mutated) state is returned.

    Args:
        state: Workflow state.

    Returns:
        The same ``state`` object, with ``response`` set.
    """
    notice = "This query has been escalated to a human agent due to its negative sentiment."
    state["response"] = notice
    return state
def _normalize_label(label) -> str:
    """Reduce an LLM-produced label to a bare lower-case word.

    Strips surrounding whitespace, quotes, backticks, markdown emphasis
    (``*``) and trailing punctuation. For example, ``"**Negative**."``
    becomes ``"negative"``. ``None`` is treated as an empty string.
    """
    return (label or "").strip(" \t\r\n\"'`*.,;:!").lower()


def route_query(state: State) -> str:
    """Choose the next node for a query from its sentiment and category.

    A negative sentiment always wins and routes to ``"escalate"``.
    Otherwise the category selects the matching handler, and anything
    unrecognised goes to ``"handle_general"``.

    Both labels come straight from an LLM, which may decorate them
    (``"Negative."``, ``"**Technical**"``, ``"'Billing'"``). They are
    normalised before comparison so that decoration does not silently skip
    escalation or misroute the query.

    Args:
        state: Workflow state; ``sentiment`` and ``category`` are read.
            Missing or ``None`` values are treated as empty.

    Returns:
        One of ``"escalate"``, ``"handle_technical"``,
        ``"handle_billing"`` or ``"handle_general"``.
    """
    if _normalize_label(state.get("sentiment")) == "negative":
        return "escalate"
    category = _normalize_label(state.get("category"))
    if category == "technical":
        return "handle_technical"
    if category == "billing":
        return "handle_billing"
    return "handle_general"
# Function to compile the workflow
def get_workflow(llm):
    """Assemble and compile the customer-support graph.

    Flow: categorize -> analyze_sentiment -> (``route_query`` selects
    handle_technical / handle_billing / handle_general / escalate) -> END.
    The entry point is ``categorize``.

    Args:
        llm: Chat model passed to every LLM-backed node.

    Returns:
        The result of ``StateGraph.compile()``.
    """

    def with_llm(node_fn):
        # A fresh closure per call binds node_fn correctly. This avoids the
        # late-binding pitfall of creating lambdas directly inside a loop.
        return lambda state: node_fn(state, llm)

    llm_nodes = {
        "categorize": categorize,
        "analyze_sentiment": analyze_sentiment,
        "handle_technical": handle_technical,
        "handle_billing": handle_billing,
        "handle_general": handle_general,
    }
    # Branch targets of the router; each of them finishes the run.
    terminal_nodes = ("handle_technical", "handle_billing", "handle_general", "escalate")

    graph = StateGraph(State)
    for node_name, node_fn in llm_nodes.items():
        graph.add_node(node_name, with_llm(node_fn))
    graph.add_node("escalate", escalate)

    graph.add_edge("categorize", "analyze_sentiment")
    graph.add_conditional_edges(
        "analyze_sentiment",
        route_query,
        {target: target for target in terminal_nodes},
    )
    for node_name in terminal_nodes:
        graph.add_edge(node_name, END)

    graph.set_entry_point("categorize")
    return graph.compile()
# Gradio interface function
def run_customer_support(query: str, api_key: str) -> Dict[str, str]:
    """Run one customer query through the support workflow.

    A new model and a newly compiled graph are built on each call, using
    the key passed in.

    Args:
        query: The customer's message.
        api_key: Groq API key forwarded to ``get_llm``.

    Returns:
        ``{"Response": reply}``, where ``reply`` is the final state's
        ``response`` with surrounding whitespace stripped, or ``""`` if the
        graph produced none.
    """
    graph = get_workflow(get_llm(api_key))
    final_state = graph.invoke({"query": query})
    reply = final_state.get("response", "")
    return {"Response": reply.strip()}
# Create the Gradio interface
gr_interface = gr.Interface(
    fn=run_customer_support,
    inputs=[
        gr.Textbox(lines=2, label="Customer Query", placeholder="Enter your customer support query here..."),
        # The API key is a secret: type="password" masks it in the browser
        # instead of echoing it in plain text.
        gr.Textbox(label="GROQ API Key", placeholder="Enter your GROQ API key", type="password"),
    ],
    outputs=gr.JSON(label="Response"),
    title="Customer Support Chatbot",
    description="Enter your query to receive assistance.",
)
# Launch the Gradio interface
gr_interface.launch()