# app.py — Hugging Face Space by safwansajad
# Last commit: bc974ff ("Update app.py"), 3.56 kB
import json
import os

import gradio as gr
import torch
from fastapi import FastAPI, Request
from fastapi import HTTPException
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers import pipeline
# Initialize FastAPI
app = FastAPI()

# CORS is enforced by browsers. Allowed origins come from the
# CORS_ALLOW_ORIGINS environment variable (comma-separated) and default to "*".
# Credentials are enabled only for an explicit origin list: FastAPI/Starlette
# require that credentials not be combined with the "*" wildcard, since that
# would let any website issue credentialed requests.
_cors_origins = [
    origin.strip()
    for origin in os.environ.get("CORS_ALLOW_ORIGINS", "*").split(",")
    if origin.strip()
] or ["*"]
app.add_middleware(
    CORSMiddleware,
    allow_origins=_cors_origins,
    allow_credentials="*" not in _cors_origins,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Load the conversational model: tokenizer first, then the causal-LM weights.
print("Loading chatbot model...")
chatbot_model = "microsoft/DialoGPT-medium"
tokenizer, model = (
    AutoTokenizer.from_pretrained(chatbot_model),
    AutoModelForCausalLM.from_pretrained(chatbot_model),
)

# Load the emotion classifier used to label each user message.
# NOTE(review): `return_all_scores` is deprecated in recent transformers, but
# replacing it with `top_k=1` changes the nesting of the returned list and would
# break the `[0]["label"]` access in generate_response; verify before migrating.
print("Loading emotion detection model...")
_emotion_model_id = "bhadresh-savani/distilbert-base-uncased-emotion"
emotion_pipeline = pipeline(
    task="text-classification",
    model=_emotion_model_id,
    return_all_scores=False,
)

# Per-session conversation state, keyed by session_id.
chat_history = dict()
# Define API endpoint for chatbot
@app.post("/api/chat")
async def chat_endpoint(request: Request):
    """Handle one chat turn posted as a JSON object.

    Expected body: ``{"message": str, "session_id": str | number}``. Both keys
    are optional; a missing or null ``message`` is treated as ``""`` and a
    missing or null ``session_id`` as ``"default"``.

    Returns:
        A JSON object with the generated ``response``, the detected ``emotion``
        label, its ``score`` as a float, and the ``session_id`` used.

    Raises:
        HTTPException: 400 if the body is not valid JSON, is not a JSON object,
            ``message`` is not a string, or ``session_id`` is an array/object.
    """
    try:
        data = await request.json()
    except ValueError as err:  # JSONDecodeError and UnicodeDecodeError subclass ValueError
        raise HTTPException(status_code=400, detail="Request body must be valid JSON") from err
    if not isinstance(data, dict):
        raise HTTPException(status_code=400, detail="Request body must be a JSON object")

    user_input = data.get("message", "")
    if user_input is None:
        user_input = ""
    if not isinstance(user_input, str):
        raise HTTPException(status_code=400, detail="'message' must be a string")

    session_id = data.get("session_id", "default")
    if session_id is None:
        session_id = "default"
    if isinstance(session_id, (dict, list)):
        # Unhashable JSON values cannot be used as keys of the per-session history dict.
        raise HTTPException(status_code=400, detail="'session_id' must be a string or number")

    # Model inference is synchronous and slow; calling it directly inside this
    # coroutine would block the event loop for the whole generation.
    response, emotion, score = await run_in_threadpool(generate_response, user_input, session_id)
    return {
        "response": response,
        "emotion": emotion,
        "score": float(score),
        "session_id": session_id,
    }
def generate_response(user_input, session_id="default", max_length=200, history_limit=20):
    """Generate a chatbot reply to ``user_input`` and classify the message's emotion.

    Each reply is conditioned only on the current message. The encoded message
    is also recorded under ``session_id`` in ``chat_history``, but it is not fed
    back into generation, and at most ``history_limit`` entries are kept per
    session so the dict cannot grow without bound.

    Args:
        user_input: The user's message. ``None`` is treated as ``""``.
        session_id: Key under which the encoded message is recorded in ``chat_history``.
        max_length: ``max_length`` passed to ``model.generate`` (prompt plus
            generated tokens). The prompt is cut to its last ``max_length // 2``
            tokens so that a long message still leaves room for a reply.
        history_limit: Maximum number of encoded messages retained per session.

    Returns:
        A ``(response, emotion, score)`` tuple. If emotion detection raises,
        ``emotion`` is ``"unknown"`` and ``score`` is ``0.0``.
    """
    if user_input is None:
        user_input = ""

    history = chat_history.setdefault(session_id, [])

    # Encode the current message followed by the tokenizer's EOS token.
    bot_input_ids = tokenizer.encode(user_input + tokenizer.eos_token, return_tensors="pt")
    # Keep the tail (which ends with the EOS appended above). Otherwise a long
    # message could reach max_length and leave no room for generated tokens.
    max_prompt_tokens = max(1, max_length // 2)
    bot_input_ids = bot_input_ids[:, -max_prompt_tokens:]

    # Record the message, retaining only the most recent `history_limit` entries.
    history.append(bot_input_ids)
    del history[: max(0, len(history) - history_limit)]

    with torch.no_grad():
        output_ids = model.generate(
            bot_input_ids,
            # A single unpadded sequence: every position is attended to.
            attention_mask=torch.ones_like(bot_input_ids),
            max_length=max_length,
            pad_token_id=tokenizer.eos_token_id,
            no_repeat_ngram_size=3,
            do_sample=True,
            top_k=50,
            top_p=0.95,
            temperature=0.7,
        )

    # The output repeats the prompt; decode only the newly generated tokens.
    response = tokenizer.decode(output_ids[0, bot_input_ids.shape[-1]:], skip_special_tokens=True)

    # Emotion detection is best-effort: any failure yields a neutral fallback.
    try:
        # truncation=True keeps overlong messages within the classifier's input limit.
        emotion_result = emotion_pipeline(user_input, truncation=True)[0]
        emotion = emotion_result["label"]
        score = emotion_result["score"]
    except Exception as e:
        print(f"Error detecting emotion: {e}")
        emotion = "unknown"
        score = 0.0

    return response, emotion, score
# Gradio Interface
def gradio_generate(user_input):
    """Adapt generate_response to the three text outputs of the Gradio UI.

    Uses the default session and formats the emotion score to four decimals.
    """
    reply, label, confidence = generate_response(user_input)
    formatted_confidence = "{:.4f}".format(confidence)
    return reply, label, formatted_confidence
# Create Gradio interface: one message box in, three read-only text boxes out
# (reply, emotion label, formatted emotion score). Components are created in the
# same order as before: input first, then outputs.
_message_box = gr.Textbox(label="Enter your message", placeholder="Type your message here...")
_result_boxes = [
    gr.Textbox(label=box_label)
    for box_label in ("Chatbot Response", "Emotion Detected", "Emotion Score")
]
iface = gr.Interface(
    fn=gradio_generate,
    inputs=_message_box,
    outputs=_result_boxes,
    title="Mental Health Chatbot",
    description="A simple mental health chatbot with emotion detection",
    allow_flagging="never",
)
# Mount the Gradio UI at the root path; mount_gradio_app returns the app to serve.
app = gr.mount_gradio_app(app, iface, path="/")


def _run_server():
    """Serve the combined FastAPI + Gradio app with uvicorn on 0.0.0.0:7860."""
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)


# Start the server only when this file is executed as a script.
if __name__ == "__main__":
    _run_server()