# Hugging Face Spaces page header captured with this file (not code):
# Spaces:
# Sleeping
# Sleeping
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

# Conversational model: its tokenizer and causal LM produce the chatbot replies.
chatbot_model = "microsoft/DialoGPT-medium"
tokenizer = AutoTokenizer.from_pretrained(chatbot_model)
model = AutoModelForCausalLM.from_pretrained(chatbot_model)

# Text classifier that labels the emotion of each incoming user message.
emotion_pipeline = pipeline(
    task="text-classification",
    model="j-hartmann/emotion-english-distilroberta-base",
)

# (message, response) pairs recorded per conversation, keyed by session id.
chat_histories = {}
def chatbot_response(message, session_id="default", *, max_new_tokens=200):
    """Generate a chatbot reply to ``message`` and classify the message's emotion.

    The ``(message, reply)`` pair is appended to ``chat_histories[session_id]``,
    creating that session's list on first use.

    Args:
        message: The user's text.
        session_id: Key under which the exchange is recorded in ``chat_histories``.
        max_new_tokens: Upper bound on the number of tokens generated for the reply.

    Returns:
        A ``(response, emotion, score)`` tuple: the decoded reply text, the top
        label from the emotion classifier, and that label's score as a float.
    """
    history = chat_histories.setdefault(session_id, [])

    # Tokenize with an explicit attention mask so generate() does not have to
    # guess it (pad_token_id is set to the EOS id below). The trailing EOS marks
    # the end of the user's turn; the reply is generated after it.
    encoded = tokenizer(message + tokenizer.eos_token, return_tensors="pt")
    input_ids = encoded["input_ids"]
    attention_mask = encoded["attention_mask"]

    # Keep the prompt inside the model's context window while leaving room for
    # the reply. Truncate from the left so the end-of-turn EOS is preserved.
    context_size = getattr(model.config, "max_position_embeddings", None)
    if context_size:
        budget = max(context_size - max_new_tokens, 1)
        input_ids = input_ids[:, -budget:]
        attention_mask = attention_mask[:, -budget:]

    # max_new_tokens (rather than max_length, which also counts prompt tokens)
    # so a long message cannot use up the whole budget and leave no reply.
    output = model.generate(
        input_ids,
        attention_mask=attention_mask,
        max_new_tokens=max_new_tokens,
        pad_token_id=tokenizer.eos_token_id,
    )
    prompt_len = input_ids.shape[-1]
    response = tokenizer.decode(output[0, prompt_len:], skip_special_tokens=True)

    # truncation=True keeps over-long messages within the classifier's input limit.
    top = emotion_pipeline(message, truncation=True)[0]
    emotion = top["label"]
    score = float(top["score"])

    # NOTE(review): the history is recorded but never fed back into generation,
    # so each reply is produced from the current message alone.
    history.append((message, response))
    return response, emotion, score
# ------------------ API Interface ------------------
def api_predict(message: str, session_id: str = "default"):
    """Run the chatbot for one message and package the outcome as a JSON-ready dict.

    Returns:
        A dict with the keys "response", "emotion", "score" and "session_id",
        the last echoing the session the exchange was recorded under.
    """
    reply, label, confidence = chatbot_response(message, session_id)
    return dict(
        response=reply,
        emotion=label,
        score=confidence,
        session_id=session_id,
    )
# ------------------ Web Interface ------------------
with gr.Blocks(title="Mental Health Chatbot") as web_interface:
    gr.Markdown("# 🤖 Mental Health Chatbot")

    with gr.Row():
        # Left column: conversation, message box and session controls.
        with gr.Column():
            chatbot = gr.Chatbot(height=400)
            msg = gr.Textbox(placeholder="Type your message...", label="You")
            with gr.Row():
                session_id = gr.Textbox(label="Session ID", value="default")
                submit_btn = gr.Button("Send", variant="primary")
                clear_btn = gr.Button("Clear")
        # Right column: emotion analysis of the latest message.
        with gr.Column():
            emotion_out = gr.Textbox(label="Detected Emotion")
            score_out = gr.Number(label="Confidence Score")

    def respond(message, chat_history, session):
        """Send ``message`` to the chatbot and append the exchange to the chat panel.

        Returns a 4-tuple matching the event outputs: the cleared input box, the
        updated chat history, the detected emotion label and its score. Blank
        messages are ignored: nothing is generated and the emotion outputs are
        left unchanged.
        """
        # Copy rather than mutate the incoming value; also tolerates None.
        history = list(chat_history or [])
        if not (message or "").strip():
            return "", history, gr.update(), gr.update()
        response, emotion, score = chatbot_response(message, session)
        history.append((message, response))
        return "", history, emotion, score

    def clear_session(session):
        """Reset the chat panel and emotion outputs and forget the session's stored history."""
        # Remove the server-side history too, not just the on-screen transcript.
        chat_histories.pop(session, None)
        return [], "", 0.0

    # Clicking "Send" and pressing Enter in the message box run the same handler.
    for trigger in (submit_btn.click, msg.submit):
        trigger(
            respond,
            [msg, chatbot, session_id],
            [msg, chatbot, emotion_out, score_out],
        )

    clear_btn.click(clear_session, [session_id], [chatbot, emotion_out, score_out])
# ------------------ Mount Interfaces ------------------
# JSON endpoint for programmatic access; inputs mirror api_predict(message, session_id).
api_interface = gr.Interface(
    fn=api_predict,
    inputs=[
        gr.Textbox(label="Message"),
        gr.Textbox(label="Session ID", value="default"),
    ],
    outputs=gr.JSON(),
    title="API Predict",
    description="Use this endpoint for programmatic access",
)

# Order matters: a mount at "/" matches every path, and the ASGI router
# dispatches to the first matching route. Mount the more specific "/predict"
# first, otherwise its requests are swallowed by the UI app mounted at "/".
app = gr.mount_gradio_app(gr.routes.App(), api_interface, path="/predict")
app = gr.mount_gradio_app(app, web_interface, path="/")
# ------------------ Launch ------------------
if __name__ == "__main__":
    # `app` (built with gr.mount_gradio_app) is a FastAPI application, and
    # FastAPI apps have no `.launch()` method, so it cannot be started that way.
    # Launch the Blocks UI directly. To serve the full mounted app (UI at "/"
    # plus the "/predict" interface), run `app` under an ASGI server instead,
    # e.g. `uvicorn <module>:app`.
    web_interface.launch()