# Captured from a Hugging Face Space (status at capture time: Sleeping)
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

# ------------------ Models ------------------
# Conversational model used to generate chatbot replies.
chatbot_model = "microsoft/DialoGPT-medium"
tokenizer = AutoTokenizer.from_pretrained(chatbot_model)
model = AutoModelForCausalLM.from_pretrained(chatbot_model)

# Text classifier used to label the emotion of each user message.
emotion_pipeline = pipeline(
    "text-classification",
    model="j-hartmann/emotion-english-distilroberta-base",
)

# In-process chat histories keyed by session id. Each value is a list of
# (user_message, bot_response) tuples; the dict lives only in memory.
chat_histories = {}
def chatbot_response(message, session_id="default", max_history_turns=5, max_new_tokens=200):
    """Generate a reply to *message* and classify the emotion of *message*.

    The most recent exchanges stored for ``session_id`` are fed back to the
    model as context, so the bot can follow a multi-turn conversation. The new
    exchange is then appended to that session's history.

    Args:
        message: The user's message.
        session_id: Key into ``chat_histories``; each session keeps its own history.
        max_history_turns: Number of most recent (message, response) pairs used
            as context. ``0`` (or less) generates from *message* alone.
        max_new_tokens: Maximum number of tokens generated for the reply.

    Returns:
        ``(response, emotion_label, confidence_score)``.
    """
    history = chat_histories.setdefault(session_id, [])

    # Generate response. Every turn in the prompt is terminated by the EOS token,
    # matching how the original single-turn prompt was built.
    eos = tokenizer.eos_token
    recent = history[-max_history_turns:] if max_history_turns > 0 else []
    prompt = "".join(f"{user}{eos}{bot}{eos}" for user, bot in recent) + message + eos
    encoded = tokenizer(prompt, return_tensors="pt")
    input_ids = encoded["input_ids"]
    attention_mask = encoded["attention_mask"]

    # Keep prompt + reply within the model's position limit by dropping the
    # oldest tokens first.
    limit = getattr(model.config, "max_position_embeddings", 1024) - max_new_tokens
    if limit > 0 and input_ids.shape[-1] > limit:
        input_ids = input_ids[:, -limit:]
        attention_mask = attention_mask[:, -limit:]

    # max_new_tokens (unlike max_length) does not count the prompt, so a long
    # message still leaves room for a reply.
    output = model.generate(
        input_ids,
        attention_mask=attention_mask,
        max_new_tokens=max_new_tokens,
        pad_token_id=tokenizer.eos_token_id,
    )
    # generate() returns prompt + continuation; decode only the continuation.
    response = tokenizer.decode(output[0, input_ids.shape[-1]:], skip_special_tokens=True)

    # Detect emotion. truncation=True keeps over-long messages within the
    # classifier's maximum input length.
    top = emotion_pipeline(message, truncation=True)[0]
    emotion = top["label"]
    score = float(top["score"])

    # Store history
    history.append((message, response))
    return response, emotion, score
# ------------------ Web Interface ------------------
with gr.Blocks() as demo:
    gr.Markdown("# 🤖 Mental Health Chatbot")
    with gr.Row():
        with gr.Column():
            chatbot = gr.Chatbot()
            msg = gr.Textbox(label="Your Message")
            session_id = gr.Textbox(label="Session ID", value="default")
            btn = gr.Button("Send")
            clear_btn = gr.Button("Clear History")
        with gr.Column():
            emotion_out = gr.Textbox(label="Detected Emotion")
            score_out = gr.Number(label="Confidence Score")

    def respond(message, chat_history, sid):
        """Handle one user turn from the web UI.

        Returns values for (msg, chatbot, emotion_out, score_out): the message
        box is cleared and the new (message, response) pair is appended to the
        visible chat.
        """
        # The Chatbot value may be None before the first turn.
        chat_history = list(chat_history or [])
        # Ignore blank submissions rather than sending an empty turn to the model;
        # gr.update() with no arguments leaves the emotion/score outputs unchanged.
        if not message or not message.strip():
            return "", chat_history, gr.update(), gr.update()
        response, emotion, score = chatbot_response(message, sid)
        chat_history.append((message, response))
        return "", chat_history, emotion, score

    def clear_history(sid):
        """Forget the stored history for *sid* and reset the UI outputs."""
        # Without this, the session's turns in chat_histories would survive a
        # "clear" and keep accumulating.
        chat_histories.pop(sid, None)
        return [], "", 0.0

    btn.click(respond, [msg, chatbot, session_id], [msg, chatbot, emotion_out, score_out])
    msg.submit(respond, [msg, chatbot, session_id], [msg, chatbot, emotion_out, score_out])
    clear_btn.click(
        clear_history,
        inputs=[session_id],
        outputs=[chatbot, emotion_out, score_out],
    )
# ------------------ API Endpoint ------------------
# Standalone Interface wrapping chatbot_response directly: the two inputs map
# to (message, session_id) and the three outputs to the returned tuple.
api_interface = gr.Interface(
    fn=chatbot_response,
    inputs=[
        gr.Textbox(label="Message"),
        gr.Textbox(label="Session ID", value="default"),
    ],
    outputs=[
        gr.Textbox(label="Chatbot Response"),
        gr.Textbox(label="Detected Emotion"),
        gr.Number(label="Confidence Score"),
    ],
)
# Launch Gradio interface and API
if __name__ == "__main__":
    # launch() blocks the calling thread by default, which would stop the second
    # app from ever starting; run the first server without locking the thread.
    demo.launch(
        share=True,
        server_name="0.0.0.0",
        server_port=7860,
        prevent_thread_lock=True,
    )
    # This call blocks, keeping the process (and both servers) alive.
    # NOTE(review): hosted environments such as HF Spaces typically expose a
    # single port, so 7861 may be unreachable there — confirm the deployment.
    api_interface.launch(share=True, server_name="0.0.0.0", server_port=7861)