import gradio as gr
import uvicorn
from fastapi import FastAPI
from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer
# Load models
chatbot_model = "microsoft/DialoGPT-medium"
tokenizer = AutoTokenizer.from_pretrained(chatbot_model)
model = AutoModelForCausalLM.from_pretrained(chatbot_model)
emotion_pipeline = pipeline("text-classification", model="j-hartmann/emotion-english-distilroberta-base")
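# The pipeline returns a list like [{"label": "joy", "score": 0.98}]; this model
# classifies text into anger, disgust, fear, joy, neutral, sadness, and surprise.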
# Store chat histories
chat_histories = {}
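# Note: this in-memory store is per-process, so histories reset whenever the app restarts.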
def chatbot_response(message, session_id="default"):
    """Core function that handles both chat and emotion analysis."""
    if session_id not in chat_histories:
        chat_histories[session_id] = []

    # Generate the chatbot reply with DialoGPT
    input_ids = tokenizer.encode(message + tokenizer.eos_token, return_tensors="pt")
    output = model.generate(input_ids, max_length=200, pad_token_id=tokenizer.eos_token_id)
    response = tokenizer.decode(output[:, input_ids.shape[-1]:][0], skip_special_tokens=True)

    # Detect the emotion of the user's message
    emotion_result = emotion_pipeline(message)
    emotion = emotion_result[0]["label"]
    score = float(emotion_result[0]["score"])

    # Store history
    chat_histories[session_id].append((message, response))
    return response, emotion, score
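# Sketch of direct usage (the emotion and score below are illustrative, not real output):
#   reply, emotion, score = chatbot_response("I aced my exam!", "user-42")
#   # emotion -> e.g. "joy", score -> e.g. 0.97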
# ------------------ API Interface ------------------
def api_predict(message: str, session_id: str = "default"):
    """Endpoint backing /predict; returns a JSON-serializable dict."""
    response, emotion, score = chatbot_response(message, session_id)
    return {
        "response": response,
        "emotion": emotion,
        "score": score,
        "session_id": session_id,
    }
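# Example response shape (field values are illustrative):
#   {"response": "...", "emotion": "sadness", "score": 0.91, "session_id": "default"}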
# ------------------ Web Interface ------------------
with gr.Blocks(title="Mental Health Chatbot") as web_interface:
    gr.Markdown("# 🤖 Mental Health Chatbot")
    with gr.Row():
        with gr.Column():
            chatbot = gr.Chatbot(height=400)
            msg = gr.Textbox(placeholder="Type your message...", label="You")
            with gr.Row():
                session_id = gr.Textbox(label="Session ID", value="default")
                submit_btn = gr.Button("Send", variant="primary")
                clear_btn = gr.Button("Clear")
        with gr.Column():
            emotion_out = gr.Textbox(label="Detected Emotion")
            score_out = gr.Number(label="Confidence Score")

    def respond(message, chat_history, session_id):
        response, emotion, score = chatbot_response(message, session_id)
        chat_history.append((message, response))
        return "", chat_history, emotion, score

    submit_btn.click(
        respond,
        [msg, chatbot, session_id],
        [msg, chatbot, emotion_out, score_out],
    )
    msg.submit(
        respond,
        [msg, chatbot, session_id],
        [msg, chatbot, emotion_out, score_out],
    )
    def clear_session(s_id):
        # Actually discard the stored history for this session, then reset the UI
        chat_histories.pop(s_id, None)
        return [], "", 0.0

    clear_btn.click(
        clear_session,
        [session_id],
        [chatbot, emotion_out, score_out],
    )
# ------------------ Mount Interfaces ------------------
app = FastAPI()

# Mount the JSON API first so Starlette matches /predict before the root mount
app = gr.mount_gradio_app(
    app,
    gr.Interface(
        fn=api_predict,
        inputs=[gr.Textbox(), gr.Textbox()],
        outputs=gr.JSON(),
        title="API Predict",
        description="Use this endpoint for programmatic access",
    ),
    path="/predict",
)

# Mount the web UI at the root
app = gr.mount_gradio_app(app, web_interface, path="/")
# ------------------ Launch ------------------
if __name__ == "__main__":
    # A mounted FastAPI app has no .launch(); serve it with uvicorn instead
    uvicorn.run(app, host="0.0.0.0", port=7860)
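# Sketch of a programmatic call to the mounted API. This assumes the app is
# served locally on port 7860; the exact route depends on the Gradio version
# (older releases expose /api/predict, newer ones /run/predict under the mount):
#
#   import requests
#   r = requests.post(
#       "http://localhost:7860/predict/run/predict",
#       json={"data": ["I can't sleep lately", "default"]},
#   )
#   print(r.json()["data"][0])  # {"response": ..., "emotion": ..., "score": ...}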