import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
import torch
from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
import json
# Initialize FastAPI
app = FastAPI()
# Add CORS middleware to allow requests from your React Native app
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # In production, restrict this to your app's domain
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Load chatbot model
print("Loading chatbot model...")
chatbot_model = "microsoft/DialoGPT-medium"
tokenizer = AutoTokenizer.from_pretrained(chatbot_model)
model = AutoModelForCausalLM.from_pretrained(chatbot_model)
# Load emotion detection model
print("Loading emotion detection model...")
emotion_pipeline = pipeline(
    "text-classification",
    model="bhadresh-savani/distilbert-base-uncased-emotion",
    return_all_scores=False
)
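# With return_all_scores=False the pipeline returns one dict per input,
# e.g. (illustrative values): emotion_pipeline("I am so happy")
#   -> [{"label": "joy", "score": 0.99}]
# This model's labels are sadness, joy, love, anger, fear and surprise.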
# Store conversation history
chat_history = {}
# Define API endpoint for chatbot
@app.post("/api/chat")
async def chat_endpoint(request: Request):
    data = await request.json()
    user_input = data.get("message", "")
    session_id = data.get("session_id", "default")
    # Process input and get response
    response, emotion, score = generate_response(user_input, session_id)
    return {
        "response": response,
        "emotion": emotion,
        "score": float(score),
        "session_id": session_id
    }
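# Illustrative request/response for this endpoint (values are examples only):
#   POST /api/chat
#   {"message": "I feel overwhelmed today", "session_id": "abc123"}
# returns something like:
#   {"response": "...", "emotion": "sadness", "score": 0.97, "session_id": "abc123"}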
def generate_response(user_input, session_id="default"):
    # Initialize chat history for new sessions
    if session_id not in chat_history:
        chat_history[session_id] = None

    # Encode the user input, terminated with the end-of-sequence token
    new_input_ids = tokenizer.encode(user_input + tokenizer.eos_token, return_tensors="pt")

    # Prepend the stored conversation so the model sees the full context
    if chat_history[session_id] is not None:
        bot_input_ids = torch.cat([chat_history[session_id], new_input_ids], dim=-1)
    else:
        bot_input_ids = new_input_ids

    # Generate a response
    with torch.no_grad():
        chat_history_ids = model.generate(
            bot_input_ids,
            max_length=200,
            pad_token_id=tokenizer.eos_token_id,
            no_repeat_ngram_size=3,
            do_sample=True,
            top_k=50,
            top_p=0.95,
            temperature=0.7
        )

    # Save the full conversation (input + reply) for the next turn; for long
    # sessions the history should be truncated to stay under max_length
    chat_history[session_id] = chat_history_ids

    # Decode only the newly generated tokens
    response = tokenizer.decode(chat_history_ids[:, bot_input_ids.shape[-1]:][0], skip_special_tokens=True)

    # Detect the emotion of the user's message
    try:
        emotion_result = emotion_pipeline(user_input)[0]
        emotion = emotion_result["label"]
        score = emotion_result["score"]
    except Exception as e:
        print(f"Error detecting emotion: {e}")
        emotion = "unknown"
        score = 0.0

    return response, emotion, score
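# Quick smoke test once the models are loaded (sampled output varies):
#   response, emotion, score = generate_response("I had a rough day at work")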
# Gradio Interface
def gradio_generate(user_input):
    response, emotion, score = generate_response(user_input)
    return response, emotion, f"{score:.4f}"
# Create Gradio interface
iface = gr.Interface(
    fn=gradio_generate,
    inputs=gr.Textbox(label="Enter your message", placeholder="Type your message here..."),
    outputs=[
        gr.Textbox(label="Chatbot Response"),
        gr.Textbox(label="Emotion Detected"),
        gr.Textbox(label="Emotion Score")
    ],
    title="Mental Health Chatbot",
    description="A simple mental health chatbot with emotion detection",
    allow_flagging="never"
)
# Mount the Gradio app to FastAPI
app = gr.mount_gradio_app(app, iface, path="/")
# Run the app
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)
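# A minimal sketch of calling the API from a client (assumes the `requests`
# package and a server running locally on port 7860):
#
#   import requests
#   r = requests.post(
#       "http://localhost:7860/api/chat",
#       json={"message": "Hello", "session_id": "demo"},
#   )
#   print(r.json())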