# app.py — Hugging Face Space by safwansajad
# Last commit: fa4258c "Update app.py" (verified), 5.02 kB
import gradio as gr
from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer
# Conversational model: DialoGPT generates the bot's replies.
print("Loading DialoGPT model...")
chatbot_model = "microsoft/DialoGPT-medium"
tokenizer, model = (
    AutoTokenizer.from_pretrained(chatbot_model),
    AutoModelForCausalLM.from_pretrained(chatbot_model),
)

# Classifier that labels the emotion expressed in each user message.
print("Loading emotion detection model...")
emotion_pipeline = pipeline(
    "text-classification",
    model="j-hartmann/emotion-english-distilroberta-base",
)

# In-memory conversation store: session_id -> list of (message, response)
# tuples. It is a plain module-level dict, so nothing is persisted across
# process restarts.
chat_histories = {}
def chatbot_response(message, history=None, session_id="default", max_new_tokens=200):
    """Generate a chatbot reply and classify the emotion of the user's message.

    Args:
        message: The user's text for this turn.
        history: Accepted for compatibility with the UI handler; not used.
            The reply is generated from ``message`` alone.
        session_id: Key into ``chat_histories``; the session is created on
            first use.
        max_new_tokens: Upper bound on the number of tokens generated for the
            reply, independent of how long ``message`` is.

    Returns:
        A 4-tuple ``(response, emotion, score, session_history)``:
        ``emotion`` is the top label from the emotion classifier, ``score`` is
        its confidence as a float, and ``session_history`` is the session's
        list of ``(message, response)`` pairs, including this turn.
    """
    session_history = chat_histories.setdefault(session_id, [])

    # DialoGPT expects every turn to be terminated by the EOS token.
    encoded = tokenizer(message + tokenizer.eos_token, return_tensors="pt")
    input_ids = encoded["input_ids"]
    attention_mask = encoded["attention_mask"]

    # Leave room in the model's context window for the reply. Keep the *tail*
    # of an over-long prompt so the terminating EOS token survives.
    context_size = getattr(model.config, "max_position_embeddings", 1024)
    max_prompt_tokens = max(1, context_size - max_new_tokens)
    input_ids = input_ids[:, -max_prompt_tokens:]
    attention_mask = attention_mask[:, -max_prompt_tokens:]

    # Use max_new_tokens rather than max_length=200. max_length also counts
    # prompt tokens, so a long message left no budget for the reply.
    output = model.generate(
        input_ids,
        attention_mask=attention_mask,
        max_new_tokens=max_new_tokens,
        pad_token_id=tokenizer.eos_token_id,
    )
    # Decode only the newly generated tokens, not the echoed prompt.
    response = tokenizer.decode(output[0, input_ids.shape[-1]:], skip_special_tokens=True)

    # truncation=True: messages longer than the classifier's input limit
    # would otherwise raise instead of being classified.
    top_prediction = emotion_pipeline(message, truncation=True)[0]
    emotion = top_prediction["label"]
    score = float(top_prediction["score"])

    session_history.append((message, response))
    return response, emotion, score, session_history
def api_chatbot_response(message, session_id="default"):
    """Run one chat turn and package the result as a JSON-friendly dict.

    The dict has the keys ``bot_response``, ``emotion``, ``emotion_score``
    and ``session_id``. The session's history is updated as a side effect of
    ``chatbot_response``.
    """
    reply, label, confidence, _session_history = chatbot_response(message, None, session_id)
    payload = {"bot_response": reply, "emotion": label, "emotion_score": confidence}
    payload["session_id"] = session_id
    return payload
def get_chat_history(session_id="default"):
    """Return the stored (message, response) list for a session.

    The stored list object itself is returned, not a copy. An unknown session
    yields a fresh empty list and is not created.
    """
    return chat_histories.get(session_id, [])
def clear_history(session_id="default"):
    """Reset a session's stored history and return a status message.

    Unknown sessions are reported, not created. A known session is rebound
    to a new empty list rather than cleared in place.
    """
    if session_id not in chat_histories:
        return f"Session {session_id} not found"
    chat_histories[session_id] = []
    return f"History cleared for session {session_id}"
# Define UI interface
with gr.Blocks(title="Mental Health Chatbot") as ui_interface:
    gr.Markdown("# 🧠 Mental Health Chatbot")
    with gr.Row():
        with gr.Column(scale=3):
            chatbot = gr.Chatbot(height=400, label="Conversation")
            with gr.Row():
                message = gr.Textbox(placeholder="Type your message here...", label="You", show_label=False)
                submit_btn = gr.Button("Send")
            with gr.Row():
                session_id = gr.Textbox(value="default", label="Session ID")
                clear_btn = gr.Button("Clear Chat")
        with gr.Column(scale=1):
            emotion_label = gr.Textbox(label="Emotion Detected")
            emotion_score = gr.Number(label="Confidence Score")

    # Set up event handlers
    def respond(message, chat_history, session_id):
        """Handle one user turn from the UI.

        Returns one value per output component wired below:
        (textbox value, conversation, emotion label, confidence score).
        Blank messages are ignored: the conversation is returned unchanged
        and the emotion widgets are left as they are.
        """
        # The Chatbot component may hand us None before its first update.
        chat_history = list(chat_history or [])
        if not message or not message.strip():
            return "", chat_history, gr.update(), gr.update()
        response, emotion, score, _ = chatbot_response(message, chat_history, session_id)
        chat_history.append((message, response))
        return "", chat_history, emotion, score

    def clear_chat(session_id):
        """Clear the session's server-side history and reset the UI widgets."""
        clear_history(session_id)
        # Exactly one value per output: chatbot, emotion_label, emotion_score.
        # (The previous lambda returned four values for three outputs.)
        return [], "", 0

    chat_inputs = [message, chatbot, session_id]
    chat_outputs = [message, chatbot, emotion_label, emotion_score]
    submit_btn.click(respond, chat_inputs, chat_outputs)
    message.submit(respond, chat_inputs, chat_outputs)
    clear_btn.click(
        clear_chat,
        [session_id],
        [chatbot, emotion_label, emotion_score]
    )
# Define API interface
api_interface = gr.Interface(
    fn=api_chatbot_response,
    inputs=[
        gr.Textbox(label="Message"),
        gr.Textbox(label="Session ID", value="default")
    ],
    outputs=gr.JSON(label="Response"),
    title="Mental Health Chatbot API",
    description="Send a message to get chatbot response with emotion analysis",
    examples=[
        ["I'm feeling sad today", "user1"],
        ["I'm so excited about my new job!", "user2"],
        ["I'm worried about my exam tomorrow", "user3"]
    ],
    # api_chatbot_response records each turn in chat_histories. Cached
    # examples (Gradio's default on Hugging Face Spaces) would run it at
    # startup, pre-filling the user1-3 sessions. Later clicks would then get
    # a stale cached reply and skip history recording.
    cache_examples=False,
)
# Read-only endpoint: returns the stored (message, response) pairs for a session.
history_api = gr.Interface(
    get_chat_history,
    gr.Textbox(value="default", label="Session ID"),
    gr.JSON(label="Chat History"),
    description="Get chat history for a specific session",
    title="Chat History API",
)
# Endpoint that resets a session's history and reports the outcome as text.
clear_api = gr.Interface(
    clear_history,
    gr.Textbox(value="default", label="Session ID"),
    gr.Textbox(label="Result"),
    description="Clear chat history for a specific session",
    title="Clear History API",
)
# Combine all interfaces into a single tabbed app (one tab per interface).
demo = gr.TabbedInterface(
    interface_list=[ui_interface, api_interface, history_api, clear_api],
    tab_names=["Chat UI", "Chat API", "History API", "Clear API"],
)

# Start the server only when executed as a script, not on import.
if __name__ == "__main__":
    demo.launch()