safwansajad's picture
Update app.py
96ae8e4 verified
raw
history blame
3.84 kB
import gradio as gr
from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer
from fastapi import FastAPI
import json
# Load the conversational model (DialoGPT-medium) used to generate replies.
# NOTE: these downloads/loads run at import time, before the server starts.
print("Loading DialoGPT model...")
chatbot_model = "microsoft/DialoGPT-medium"
tokenizer = AutoTokenizer.from_pretrained(chatbot_model)
model = AutoModelForCausalLM.from_pretrained(chatbot_model)
# Load the emotion-classification pipeline applied to each incoming user message.
print("Loading emotion detection model...")
emotion_pipeline = pipeline("text-classification", model="j-hartmann/emotion-english-distilroberta-base")
# In-memory conversation store: session_id -> list of (user_message, bot_response)
# pairs. Not persisted; cleared whenever the process restarts.
chat_histories = {}
def chatbot_response(message, history=None, session_id="default"):
    """Generate a chatbot reply for *message* and classify its emotion.

    Args:
        message: The user's message text.
        history: Unused; kept for Gradio-callback signature compatibility.
        session_id: Key selecting the per-user conversation in ``chat_histories``.

    Returns:
        Tuple of ``(response, emotion_label, emotion_score, session_history)``.
    """
    # Initialize session storage on first contact.
    if session_id not in chat_histories:
        chat_histories[session_id] = []

    # DialoGPT is a multi-turn model: per its model card, prior turns are
    # concatenated with EOS separators so the reply stays in context. The
    # previous implementation encoded only the current message, so the stored
    # history was never used. Bound context to the last 5 exchanges to keep
    # the prompt short.
    turns = []
    for past_user, past_bot in chat_histories[session_id][-5:]:
        turns.append(past_user + tokenizer.eos_token)
        turns.append(past_bot + tokenizer.eos_token)
    turns.append(message + tokenizer.eos_token)
    input_ids = tokenizer.encode("".join(turns), return_tensors="pt")

    # max_length counts prompt + generated tokens, so budget the reply
    # relative to the prompt length (a fixed 200 would starve or break
    # generation once the prompt grows). DialoGPT has no pad token, so EOS
    # is used for padding, as recommended by the model card.
    output = model.generate(
        input_ids,
        max_length=input_ids.shape[-1] + 200,
        pad_token_id=tokenizer.eos_token_id,
    )
    # Decode only the newly generated tokens (everything after the prompt).
    response = tokenizer.decode(output[:, input_ids.shape[-1]:][0], skip_special_tokens=True)

    # Classify the emotion of the *user* message, not the bot reply.
    emotion_result = emotion_pipeline(message)
    emotion = emotion_result[0]["label"]
    score = float(emotion_result[0]["score"])

    # Record the exchange so subsequent turns have context.
    chat_histories[session_id].append((message, response))
    return response, emotion, score, chat_histories[session_id]
def api_chatbot_response(message: str, session_id: str = "default"):
    """Run the chatbot on one message and package the result for the REST API.

    Returns a JSON-serializable dict with the reply, the detected emotion,
    its confidence score, and the session id echoed back.
    """
    reply, detected_emotion, confidence, _ = chatbot_response(message, None, session_id)
    payload = {"bot_response": reply}
    payload["emotion"] = detected_emotion
    payload["emotion_score"] = confidence
    payload["session_id"] = session_id
    return payload
def get_chat_history(session_id: str = "default"):
    """Return the stored (message, response) pairs for *session_id*.

    Unknown sessions yield an empty list rather than an error.
    """
    return chat_histories.get(session_id, [])
def clear_history(session_id: str = "default"):
    """Reset the conversation stored for *session_id*.

    Returns a status dict: "success" when the session existed and was
    emptied, "error" when the session was never seen.
    """
    # Guard clause: unknown sessions are reported, not silently created.
    if session_id not in chat_histories:
        return {"status": "error", "message": f"Session {session_id} not found"}
    chat_histories[session_id] = []
    return {"status": "success", "message": f"History cleared for session {session_id}"}
# FastAPI application hosting both the REST endpoints and the Gradio page.
app = FastAPI()

# The Gradio UI is purely static documentation describing the REST API,
# intended for developers integrating the endpoints (e.g. from React Native).
with gr.Blocks() as gradio_interface:
    gr.Markdown("# API Documentation")
    gr.Markdown("Use these endpoints in your React Native app:")
    with gr.Accordion("Chat Endpoint"):
        gr.Markdown("""
**POST /chat**
- Parameters: `message`, `session_id` (optional)
- Returns: JSON with bot_response, emotion, emotion_score, session_id
""")
    with gr.Accordion("History Endpoint"):
        gr.Markdown("""
**GET /history**
- Parameters: `session_id` (optional)
- Returns: JSON array of message/response pairs
""")
    with gr.Accordion("Clear History Endpoint"):
        gr.Markdown("""
**POST /clear_history**
- Parameters: `session_id` (optional)
- Returns: JSON with status and message
""")

# Serve the Gradio docs page at the site root of the FastAPI app.
gradio_app = gr.mount_gradio_app(app, gradio_interface, path="/")
# REST endpoints consumed by external clients. Each one is a thin async
# wrapper around the corresponding module-level function.

@app.post("/chat")
async def chat_endpoint(message: str, session_id: str = "default"):
    """Generate a bot reply (with emotion analysis) for one user message."""
    return api_chatbot_response(message, session_id)


@app.get("/history")
async def history_endpoint(session_id: str = "default"):
    """Fetch the (message, response) pairs stored for a session."""
    return get_chat_history(session_id)


@app.post("/clear_history")
async def clear_endpoint(session_id: str = "default"):
    """Drop the stored conversation for a session."""
    return clear_history(session_id)
if __name__ == "__main__":
    # Launch the ASGI server on all interfaces at port 7860 (the port
    # Hugging Face Spaces exposes by default).
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)