File size: 3,840 Bytes
fa4258c
43bfd5c
96ae8e4
 
b746dc4
a0428bd
43bfd5c
56226b9
 
a0428bd
b746dc4
a0428bd
bc974ff
43bfd5c
 
fa4258c
43bfd5c
 
fa4258c
 
 
43bfd5c
 
bc974ff
43bfd5c
fa4258c
43bfd5c
 
 
 
fa4258c
 
 
 
 
 
43bfd5c
fa4258c
 
96ae8e4
fa4258c
 
43bfd5c
bc974ff
fa4258c
 
 
bc974ff
 
b746dc4
96ae8e4
fa4258c
 
 
 
bc974ff
96ae8e4
fa4258c
43bfd5c
fa4258c
96ae8e4
 
fa4258c
96ae8e4
 
 
 
 
 
 
fa4258c
96ae8e4
 
 
 
 
 
43bfd5c
96ae8e4
 
 
 
 
 
fa4258c
96ae8e4
 
 
 
 
 
fa4258c
96ae8e4
 
fa4258c
96ae8e4
 
 
 
fa4258c
96ae8e4
 
 
fa4258c
96ae8e4
 
 
bc974ff
 
96ae8e4
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
import gradio as gr
from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer
from fastapi import FastAPI
import json

# Load chatbot model.
# NOTE: from_pretrained downloads weights on first run, so module import can
# take a while and requires network access until the HF cache is warm.
print("Loading DialoGPT model...")
chatbot_model = "microsoft/DialoGPT-medium"
tokenizer = AutoTokenizer.from_pretrained(chatbot_model)
model = AutoModelForCausalLM.from_pretrained(chatbot_model)

# Load emotion detection model (7-class English emotion classifier).
print("Loading emotion detection model...")
emotion_pipeline = pipeline("text-classification", model="j-hartmann/emotion-english-distilroberta-base")

# Store chat histories.
# Maps session_id -> list of (user_message, bot_response) tuples.
# NOTE(review): in-memory and process-local — histories are lost on restart
# and not shared across workers; confirm that is acceptable for deployment.
chat_histories = {}

def chatbot_response(message, history=None, session_id="default"):
    """Generate a chatbot reply and detect the emotion of the user message.

    Args:
        message: The user's input text.
        history: Unused; accepted for Gradio chat-callback compatibility.
        session_id: Key identifying which stored conversation to append to.

    Returns:
        Tuple of (bot reply text, emotion label, emotion confidence as float,
        the session's full list of (message, response) pairs).
    """
    # Initialize session if it doesn't exist.
    if session_id not in chat_histories:
        chat_histories[session_id] = []

    # Generate chatbot response. Using the tokenizer's __call__ gives us an
    # attention_mask to pass alongside input_ids (avoids unreliable output /
    # warnings when pad_token == eos_token, as it does for DialoGPT).
    encoded = tokenizer(message + tokenizer.eos_token, return_tensors="pt")
    input_ids = encoded["input_ids"]
    # Bug fix: the previous max_length=200 capped input+output *combined*,
    # so a long user message left little or no room for the reply (or raised
    # an error). max_new_tokens bounds only the generated continuation.
    output = model.generate(
        input_ids,
        attention_mask=encoded["attention_mask"],
        max_new_tokens=128,
        pad_token_id=tokenizer.eos_token_id,
    )
    # Decode only the newly generated tokens (everything after the prompt).
    response = tokenizer.decode(output[:, input_ids.shape[-1]:][0], skip_special_tokens=True)

    # Detect emotion of the *user* message (not the bot reply).
    emotion_result = emotion_pipeline(message)
    emotion = emotion_result[0]["label"]
    score = float(emotion_result[0]["score"])

    # Store the completed turn in this session's history.
    chat_histories[session_id].append((message, response))

    return response, emotion, score, chat_histories[session_id]

def api_chatbot_response(message: str, session_id: str = "default"):
    """API-facing wrapper: run the chatbot and package the result as a dict.

    Returns a JSON-serializable mapping with the bot reply, the detected
    emotion label and score, and the session id that was used.
    """
    bot_reply, detected_emotion, confidence, _history = chatbot_response(
        message, None, session_id
    )
    payload = {
        "bot_response": bot_reply,
        "emotion": detected_emotion,
        "emotion_score": confidence,
        "session_id": session_id,
    }
    return payload

def get_chat_history(session_id: str = "default"):
    """Return the (message, response) pairs recorded for *session_id*.

    Unknown sessions yield an empty list rather than an error.
    """
    return chat_histories.get(session_id, [])

def clear_history(session_id: str = "default"):
    """Drop all stored turns for *session_id* and report the outcome.

    Returns a status dict: "success" when the session existed and was
    cleared, "error" when no such session is stored.
    """
    # Guard clause: unknown session -> error payload, nothing to clear.
    if session_id not in chat_histories:
        return {"status": "error", "message": f"Session {session_id} not found"}
    chat_histories[session_id] = []
    return {"status": "success", "message": f"History cleared for session {session_id}"}

# Create FastAPI app (serves the JSON endpoints registered below).
app = FastAPI()

# Create Gradio app. This Blocks UI is documentation-only: it renders the
# API reference for the three endpoints; it has no interactive components.
with gr.Blocks() as gradio_interface:
    gr.Markdown("# API Documentation")
    gr.Markdown("Use these endpoints in your React Native app:")
    
    with gr.Accordion("Chat Endpoint"):
        gr.Markdown("""
        **POST /chat**
        - Parameters: `message`, `session_id` (optional)
        - Returns: JSON with bot_response, emotion, emotion_score, session_id
        """)
    
    with gr.Accordion("History Endpoint"):
        gr.Markdown("""
        **GET /history**
        - Parameters: `session_id` (optional)
        - Returns: JSON array of message/response pairs
        """)
    
    with gr.Accordion("Clear History Endpoint"):
        gr.Markdown("""
        **POST /clear_history**
        - Parameters: `session_id` (optional)
        - Returns: JSON with status and message
        """)

# Mount Gradio interface at the root path; the API routes below coexist
# with it on the same FastAPI app.
gradio_app = gr.mount_gradio_app(app, gradio_interface, path="/")

# Add API endpoints
@app.post("/chat")
async def chat_endpoint(message: str, session_id: str = "default"):
    """POST /chat — run the chatbot on *message*, return the JSON payload."""
    # NOTE(review): plain str params on a POST route are bound as query
    # parameters by FastAPI, not a JSON body — confirm the client sends them
    # that way.
    result = api_chatbot_response(message, session_id)
    return result

@app.get("/history")
async def history_endpoint(session_id: str = "default"):
    """GET /history — return the stored (message, response) pairs for a session."""
    return get_chat_history(session_id=session_id)

@app.post("/clear_history")
async def clear_endpoint(session_id: str = "default"):
    """POST /clear_history — wipe a session's stored history; returns a status dict."""
    return clear_history(session_id=session_id)

# Script entry point: serve the combined FastAPI + Gradio app.
if __name__ == "__main__":
    import uvicorn  # local import: only needed when run directly, not when mounted elsewhere
    # 0.0.0.0 exposes the server on all interfaces; 7860 is the conventional
    # Gradio / HF Spaces port.
    uvicorn.run(app, host="0.0.0.0", port=7860)