File size: 3,840 Bytes
fa4258c 43bfd5c 96ae8e4 b746dc4 a0428bd 43bfd5c 56226b9 a0428bd b746dc4 a0428bd bc974ff 43bfd5c fa4258c 43bfd5c fa4258c 43bfd5c bc974ff 43bfd5c fa4258c 43bfd5c fa4258c 43bfd5c fa4258c 96ae8e4 fa4258c 43bfd5c bc974ff fa4258c bc974ff b746dc4 96ae8e4 fa4258c bc974ff 96ae8e4 fa4258c 43bfd5c fa4258c 96ae8e4 fa4258c 96ae8e4 fa4258c 96ae8e4 43bfd5c 96ae8e4 fa4258c 96ae8e4 fa4258c 96ae8e4 fa4258c 96ae8e4 fa4258c 96ae8e4 fa4258c 96ae8e4 bc974ff 96ae8e4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 |
import json

import gradio as gr
from fastapi import FastAPI
from fastapi.concurrency import run_in_threadpool
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
# Conversational model used to generate the bot's replies.
print("Loading DialoGPT model...")
chatbot_model = "microsoft/DialoGPT-medium"
tokenizer, model = (
    AutoTokenizer.from_pretrained(chatbot_model),
    AutoModelForCausalLM.from_pretrained(chatbot_model),
)

# Classifier applied to each incoming user message to label its emotion.
print("Loading emotion detection model...")
EMOTION_MODEL_NAME = "j-hartmann/emotion-english-distilroberta-base"
emotion_pipeline = pipeline(task="text-classification", model=EMOTION_MODEL_NAME)

# In-process conversation store: session_id -> list of (user message, bot reply)
# tuples. Nothing is persisted; contents are lost when the process exits.
chat_histories = {}
def chatbot_response(message, history=None, session_id="default", max_new_tokens=200):
    """Generate a chatbot reply and classify the emotion of the user's message.

    Args:
        message: The user's message text.
        history: Ignored. The reply is generated from ``message`` alone; the
            per-session history is recorded but not fed back to the model.
        session_id: Key into ``chat_histories``; the session is created on
            first use.
        max_new_tokens: Upper bound on the number of tokens generated for the
            reply (independent of how long ``message`` is).

    Returns:
        A tuple ``(response, emotion, score, session_history)``. ``emotion`` is
        the top label from the emotion classifier, ``score`` is its confidence
        as a float, and ``session_history`` is the live list of
        ``(message, response)`` pairs stored for this session.
    """
    # setdefault is a single dict operation, unlike check-then-insert.
    session_history = chat_histories.setdefault(session_id, [])

    # The user turn is terminated with the EOS token, which DialoGPT uses as
    # its turn separator.
    input_ids = tokenizer.encode(message + tokenizer.eos_token, return_tensors="pt")
    # Keep only the most recent tokens so prompt + reply fit in the model's
    # position embeddings. Truncating from the left preserves the trailing EOS.
    max_prompt_tokens = max(1, model.config.max_position_embeddings - max_new_tokens)
    input_ids = input_ids[:, -max_prompt_tokens:]
    # max_new_tokens (unlike max_length) bounds only the reply, so a long
    # message no longer leaves zero room for generation.
    output = model.generate(
        input_ids,
        max_new_tokens=max_new_tokens,
        pad_token_id=tokenizer.eos_token_id,
    )
    # Decode only the newly generated tokens, not the echoed prompt.
    response = tokenizer.decode(output[0, input_ids.shape[-1]:], skip_special_tokens=True)

    # truncation=True lets over-long messages be classified instead of raising
    # when they exceed the classifier's maximum sequence length.
    top = emotion_pipeline(message, truncation=True)[0]
    emotion = top["label"]
    score = float(top["score"])

    session_history.append((message, response))
    return response, emotion, score, session_history
def api_chatbot_response(message: str, session_id: str = "default"):
    """Run one chat turn and package the result as a JSON-serialisable dict.

    The session history that ``chatbot_response`` also returns is dropped
    here; it remains available separately through ``get_chat_history``.
    """
    reply, label, confidence, _session_history = chatbot_response(message, None, session_id)
    payload = {
        "bot_response": reply,
        "emotion": label,
        "emotion_score": confidence,
        "session_id": session_id,
    }
    return payload
def get_chat_history(session_id: str = "default"):
    """Return the stored (message, response) list for ``session_id``.

    Unknown sessions yield a fresh empty list. For known sessions the stored
    list object itself is returned (not a copy).
    """
    return chat_histories.get(session_id, [])
def clear_history(session_id: str = "default"):
    """Empty the stored history of ``session_id``.

    The session key is kept (bound to a new empty list). Returns a status dict:
    ``"success"`` when the session existed, ``"error"`` when it did not.
    """
    if session_id not in chat_histories:
        return {"status": "error", "message": f"Session {session_id} not found"}
    chat_histories[session_id] = []
    return {"status": "success", "message": f"History cleared for session {session_id}"}
# Create FastAPI app
app = FastAPI()

# Endpoint reference shown on the Gradio page, as (accordion title, markdown
# body) pairs rendered in this order.
_ENDPOINT_DOCS = (
    (
        "Chat Endpoint",
        """
**POST /chat**
- Parameters: `message`, `session_id` (optional)
- Returns: JSON with bot_response, emotion, emotion_score, session_id
""",
    ),
    (
        "History Endpoint",
        """
**GET /history**
- Parameters: `session_id` (optional)
- Returns: JSON array of message/response pairs
""",
    ),
    (
        "Clear History Endpoint",
        """
**POST /clear_history**
- Parameters: `session_id` (optional)
- Returns: JSON with status and message
""",
    ),
)

# Create Gradio app: a static page made only of Markdown that documents the HTTP API.
with gr.Blocks() as gradio_interface:
    gr.Markdown("# API Documentation")
    gr.Markdown("Use these endpoints in your React Native app:")
    for _doc_title, _doc_body in _ENDPOINT_DOCS:
        with gr.Accordion(_doc_title):
            gr.Markdown(_doc_body)
# Add API endpoints.
# NOTE: these routes must be registered BEFORE the Gradio app is mounted at "/".
# Starlette checks routes in registration order, and a mount at the root path
# matches every request path, so any route added after it would be shadowed.


@app.post("/chat")
async def chat_endpoint(message: str, session_id: str = "default"):
    """Run one chat turn for ``session_id`` and return the structured reply.

    Model inference (generation + emotion classification) is a blocking call,
    so it is offloaded to the threadpool. Run directly on the event loop, it
    would stall every other request until the reply was generated.
    """
    return await run_in_threadpool(api_chatbot_response, message, session_id)


@app.get("/history")
async def history_endpoint(session_id: str = "default"):
    """Return the stored (message, response) pairs for ``session_id``."""
    return get_chat_history(session_id)


@app.post("/clear_history")
async def clear_endpoint(session_id: str = "default"):
    """Empty the history of ``session_id``; reports an error if it is unknown."""
    return clear_history(session_id)


# Mount the Gradio documentation page last so it only receives requests that
# none of the API routes above matched.
gradio_app = gr.mount_gradio_app(app, gradio_interface, path="/")
if __name__ == "__main__":
    # Serve the FastAPI app (API routes and the mounted Gradio page) on all
    # network interfaces.
    import uvicorn

    server_options = {"host": "0.0.0.0", "port": 7860}
    uvicorn.run(app, **server_options)