File size: 3,558 Bytes
81a2ae6
bc974ff
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b746dc4
a0428bd
bc974ff
56226b9
 
a0428bd
b746dc4
a0428bd
bc974ff
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b746dc4
bc974ff
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6b8ac59
81a2ae6
bc974ff
 
 
 
 
 
 
 
a0428bd
f3aaa59
6b8ac59
f3aaa59
bc974ff
 
 
 
 
81a2ae6
bc974ff
 
f3aaa59
 
 
 
 
bc974ff
 
 
b915000
 
bc974ff
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
import json

import gradio as gr
import torch
from fastapi import FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers import pipeline

# Initialize FastAPI
app = FastAPI()

# Add CORS middleware to allow requests from your React Native app.
#
# Credentials are deliberately disabled: a wildcard origin must not be combined
# with credentialed requests (with allow_credentials=True the middleware would
# echo back whatever Origin the caller sent, letting any website make
# cookie-bearing requests). The chat endpoint in this file only reads the JSON
# body, so nothing here needs cookies or Authorization headers.
# If credentials are ever required, list the exact origins instead of "*".
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # In production, restrict this to your app's domain
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Load chatbot model
# The tokenizer and the causal-LM weights are loaded from the same checkpoint
# name at import time, so importing this module blocks until loading is done.
print("Loading chatbot model...")
chatbot_model = "microsoft/DialoGPT-medium"
tokenizer = AutoTokenizer.from_pretrained(chatbot_model)
model = AutoModelForCausalLM.from_pretrained(chatbot_model)

# Load emotion detection model
# generate_response() reads the "label" and "score" keys of the first element
# returned by this pipeline.
# NOTE(review): return_all_scores=False presumably limits the output to the
# single top label -- confirm against the installed transformers version
# before changing it, since the result shape is what generate_response() indexes.
print("Loading emotion detection model...")
emotion_pipeline = pipeline(
    "text-classification", 
    model="bhadresh-savani/distilbert-base-uncased-emotion",
    return_all_scores=False
)

# Store conversation history
# Plain in-process dict keyed by session_id; generate_response() creates the
# per-session entry on first use and appends to it on every call.
chat_history = {}

# Define API endpoint for chatbot
@app.post("/api/chat")
async def chat_endpoint(request: Request):
    """Handle ``POST /api/chat``: produce a chatbot reply for one message.

    The request body must be a JSON object with these optional keys:

    * ``message`` -- the user's text (defaults to ``""``).
    * ``session_id`` -- key of the conversation (defaults to ``"default"``).

    Returns:
        A dict with the keys ``response``, ``emotion``, ``score`` (float) and
        ``session_id`` (echoed back as received).

    Raises:
        HTTPException: status 400 when the body is not valid JSON, is not a
            JSON object, ``message`` is not a string, or ``session_id`` is a
            JSON array/object (those cannot be used as a dictionary key).
    """
    # Starlette raises a ValueError subclass (JSONDecodeError or
    # UnicodeDecodeError) for an empty, malformed or badly encoded body.
    try:
        data = await request.json()
    except ValueError as exc:
        raise HTTPException(
            status_code=400, detail="Request body must be valid JSON"
        ) from exc

    if not isinstance(data, dict):
        raise HTTPException(
            status_code=400, detail="Request body must be a JSON object"
        )

    user_input = data.get("message", "")
    if not isinstance(user_input, str):
        raise HTTPException(status_code=400, detail="'message' must be a string")

    session_id = data.get("session_id", "default")
    if isinstance(session_id, (dict, list)):
        raise HTTPException(
            status_code=400, detail="'session_id' must be a string or number"
        )

    # Process input and get response.
    # NOTE: generate_response() is called synchronously, so the event loop is
    # busy for the duration of the model call.
    response, emotion, score = generate_response(user_input, session_id)

    return {
        "response": response,
        "emotion": emotion,
        "score": float(score),
        "session_id": session_id
    }

def generate_response(user_input, session_id="default"):
    """Generate a chatbot reply for *user_input* and classify its emotion.

    The conversation for ``session_id`` is kept in the module-level
    ``chat_history`` dict as a list of token-id tensors, one per user or bot
    turn. The stored turns are concatenated and used as the prompt, so the
    reply is conditioned on the earlier exchange. The oldest turns are
    discarded once the prompt would exceed the token budget computed below.

    Args:
        user_input: The user's message text.
        session_id: Key identifying the conversation. All callers that omit
            it share the single ``"default"`` conversation.

    Returns:
        A ``(response, emotion, score)`` tuple: the decoded reply text, the
        label that ``emotion_pipeline`` reports for ``user_input``, and that
        label's score. If emotion detection raises, ``emotion`` is
        ``"unknown"`` and ``score`` is ``0.0``.
    """
    # Token budget: leave room for the reply inside the model's position
    # limit. Falls back to 1024 (the GPT-2 default) if the config does not
    # expose the attribute.
    max_reply_tokens = 200
    context_window = getattr(model.config, "max_position_embeddings", 1024)
    max_prompt_tokens = max(1, context_window - max_reply_tokens)

    # Initialize chat history for new sessions
    turns = chat_history.setdefault(session_id, [])

    # Each turn is terminated by EOS, which is how turns are separated
    # inside the concatenated prompt.
    new_user_ids = tokenizer.encode(
        user_input + tokenizer.eos_token, return_tensors="pt"
    )
    turns.append(new_user_ids)

    # Forget the oldest turns once the conversation outgrows the budget; the
    # newest turn is always kept.
    while len(turns) > 1 and sum(t.shape[-1] for t in turns) > max_prompt_tokens:
        del turns[0]

    # A single over-long message is cut from the left so that its trailing
    # EOS token survives.
    bot_input_ids = torch.cat(turns, dim=-1)[:, -max_prompt_tokens:]

    # Generate a response. max_length counts the prompt as well, so it is
    # expressed relative to the prompt length; a fixed value would leave no
    # room for a reply once the prompt reaches it.
    with torch.no_grad():
        output_ids = model.generate(
            bot_input_ids,
            # pad and EOS share one id, so the mask cannot be inferred from
            # the ids; every prompt token is real input.
            attention_mask=torch.ones_like(bot_input_ids),
            max_length=bot_input_ids.shape[-1] + max_reply_tokens,
            pad_token_id=tokenizer.eos_token_id,
            no_repeat_ngram_size=3,
            do_sample=True,
            top_k=50,
            top_p=0.95,
            temperature=0.7,
        )

    # Everything after the prompt is the newly generated reply; remember it
    # so the next call sees both sides of the conversation.
    reply_ids = output_ids[:, bot_input_ids.shape[-1]:]
    turns.append(reply_ids)

    # Decode the response
    response = tokenizer.decode(reply_ids[0], skip_special_tokens=True)

    # Detect emotion (best effort: a classifier failure must not lose the reply)
    try:
        emotion_result = emotion_pipeline(user_input)[0]
        emotion = emotion_result["label"]
        score = emotion_result["score"]
    except Exception as e:
        print(f"Error detecting emotion: {e}")
        emotion = "unknown"
        score = 0.0

    return response, emotion, score

# Gradio Interface
def gradio_generate(user_input):
    """Adapt ``generate_response`` for the Gradio UI.

    Calls ``generate_response`` with its default session id and returns the
    reply text, the emotion label, and the score rendered as a string with
    four decimal places.
    """
    reply, label, confidence = generate_response(user_input)
    confidence_text = format(confidence, ".4f")
    return reply, label, confidence_text

# Create Gradio interface
# One text input feeding gradio_generate, and one read-only text box per value
# it returns (reply, emotion label, formatted score). The input component is
# built before the outputs.
_message_input = gr.Textbox(
    label="Enter your message",
    placeholder="Type your message here...",
)
_result_outputs = [
    gr.Textbox(label=caption)
    for caption in ("Chatbot Response", "Emotion Detected", "Emotion Score")
]

iface = gr.Interface(
    fn=gradio_generate,
    inputs=_message_input,
    outputs=_result_outputs,
    title="Mental Health Chatbot",
    description="A simple mental health chatbot with emotion detection",
    allow_flagging="never",
)

# Mount the Gradio app to FastAPI
# The UI is attached at the root path; the value returned by
# mount_gradio_app is rebound to ``app`` and is what gets served.
app = gr.mount_gradio_app(app, iface, path="/")


def main():
    """Serve ``app`` with uvicorn on all interfaces, port 7860."""
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)


# Run the app
if __name__ == "__main__":
    main()