safwansajad's picture
Update app.py
6b8ac59 verified
raw
history blame
2.23 kB
import gradio as gr
import uvicorn
from fastapi import FastAPI, Request
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer
# Create the FastAPI app that serves both the JSON API and the Gradio UI.
app = FastAPI()

# CORS middleware so browser-based clients on other origins can call the API.
# Credentials are disabled on purpose: the CORS spec does not allow a wildcard
# origin ("*") on credentialed requests, and the /analyze endpoint only reads
# the JSON request body (no cookies or auth headers), so none are needed.
# In production, replace "*" with the client's actual origin(s).
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
# --- Models (loaded once at import time) ------------------------------------
# Conversational model: the tokenizer and causal LM come from the same checkpoint.
chatbot_model = "microsoft/DialoGPT-medium"
tokenizer, model = (
    AutoTokenizer.from_pretrained(chatbot_model),
    AutoModelForCausalLM.from_pretrained(chatbot_model),
)

# Text-classification pipeline used to label the emotion of the user's message.
emotion_pipeline = pipeline(
    task="text-classification",
    model="bhadresh-savani/distilbert-base-uncased-emotion",
)
def _chat_reply(user_input):
    """Return DialoGPT's single-turn reply to *user_input*.

    The prompt is the message followed by the tokenizer's EOS token; only the
    tokens generated after the prompt are decoded and returned.
    """
    encoded = tokenizer(user_input + tokenizer.eos_token, return_tensors="pt")
    prompt_len = encoded["input_ids"].shape[-1]
    # Passing the attention mask explicitly avoids the transformers warning
    # about a missing mask; for a single unpadded sequence it is all ones.
    # NOTE(review): max_length bounds prompt + reply together, so a prompt of
    # 200+ tokens leaves no room for a reply; consider max_new_tokens.
    output = model.generate(
        **encoded,
        max_length=200,
        pad_token_id=tokenizer.eos_token_id,
    )
    return tokenizer.decode(output[0, prompt_len:], skip_special_tokens=True)


def _detect_emotion(text):
    """Return ``(label, score)`` for the top emotion predicted for *text*.

    ``truncation=True`` makes inputs longer than the classifier's maximum
    sequence length get truncated instead of raising an error.
    """
    top = emotion_pipeline(text, truncation=True)[0]
    return top["label"], top["score"]


def generate_response(user_input):
    """Return ``(reply, emotion_label)`` for *user_input*.

    This is also the Gradio interface function; its two return values feed
    the "Chatbot Response" and "Emotion Detected" outputs.
    """
    response = _chat_reply(user_input)
    emotion, _score = _detect_emotion(user_input)
    return response, emotion


def _analyze(user_input):
    """Run each model exactly once and build the /analyze response payload."""
    response = _chat_reply(user_input)
    emotion, score = _detect_emotion(user_input)
    return {
        "response": response,
        "emotion": emotion,
        "score": score,
    }


# Create API endpoint
@app.post("/analyze")
async def analyze_text(request: Request):
    """Analyze a JSON body of the form ``{"text": "..."}``.

    Returns ``{"response", "emotion", "score"}`` on success. Returns
    ``{"error": ...}`` (still HTTP 200, as before) when the body is not valid
    JSON, or when ``text`` is missing, not a string, or only whitespace.
    """
    try:
        data = await request.json()
    except ValueError:  # json.JSONDecodeError / UnicodeDecodeError
        return {"error": "Invalid JSON body"}
    user_input = data.get("text", "") if isinstance(data, dict) else ""
    if not isinstance(user_input, str) or not user_input.strip():
        return {"error": "No text provided"}
    # Model inference is synchronous and compute-heavy; run it in a worker
    # thread so it does not block the event loop serving other requests.
    return await run_in_threadpool(_analyze, user_input)
# Optional Gradio UI for manual testing in the browser. It calls the same
# generate_response function used by the API and shows its two return values.
# live=False: with live=True every change to the textbox re-ran a full
# DialoGPT generation plus emotion classification. Now the user submits
# explicitly.
iface = gr.Interface(
    fn=generate_response,
    inputs=gr.Textbox(label="Enter your message"),
    outputs=[
        gr.Textbox(label="Chatbot Response"),
        gr.Textbox(label="Emotion Detected"),
    ],
    live=False,
)

# Mount the Gradio UI at the site root. The /analyze route was registered on
# `app` earlier, so it is matched before this catch-all mount.
app = gr.mount_gradio_app(app, iface, path="/")
if __name__ == "__main__":
    # Direct launch (e.g. `python app.py`): serve the combined FastAPI + Gradio
    # app on all network interfaces, port 7860.
    uvicorn.run(app, port=7860, host="0.0.0.0")