File size: 3,474 Bytes
fa4258c
43bfd5c
b746dc4
b3c1245
56226b9
 
a0428bd
43bfd5c
 
fa4258c
43bfd5c
 
b3c1245
7d4edac
43bfd5c
 
57aa867
7d4edac
fa4258c
43bfd5c
 
57aa867
43bfd5c
fa4258c
 
 
57aa867
b3c1245
fa4258c
b3c1245
 
7d4edac
 
 
 
 
 
 
 
 
 
 
57aa867
7d4edac
b3c1245
7d4edac
b3c1245
 
7d4edac
 
 
 
 
 
 
b3c1245
 
 
 
 
 
 
 
 
7d4edac
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
57aa867
7d4edac
 
 
 
 
 
 
 
 
 
57aa867
b3c1245
7d4edac
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
import gradio as gr
from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer

# ------------------ Model setup ------------------
# DialoGPT produces the chat replies; a text-classification pipeline labels
# the emotion of each user message.
chatbot_model = "microsoft/DialoGPT-medium"
tokenizer, model = (
    AutoTokenizer.from_pretrained(chatbot_model),
    AutoModelForCausalLM.from_pretrained(chatbot_model),
)
emotion_pipeline = pipeline(
    task="text-classification",
    model="j-hartmann/emotion-english-distilroberta-base",
)

# Per-session transcripts: session id -> list of (message, response) tuples.
chat_histories: dict = {}

def chatbot_response(message, session_id="default", max_new_tokens=200):
    """Generate a chatbot reply and classify the emotion of *message*.

    The reply comes from the module-level DialoGPT ``model``/``tokenizer``;
    the emotion comes from ``emotion_pipeline``. The (message, response) pair
    is appended to ``chat_histories[session_id]``. The stored history is not
    fed back into generation, so each reply is conditioned only on the
    current message.

    Args:
        message: The user's text.
        session_id: Key under which this turn is recorded in ``chat_histories``.
        max_new_tokens: Upper bound on the number of tokens generated for the
            reply. It does not count the prompt.

    Returns:
        A ``(response, emotion, score)`` tuple: the decoded reply text, the top
        emotion label, and its score as a Python float.
    """
    history = chat_histories.setdefault(session_id, [])

    # Generate chatbot response. Calling the tokenizer (rather than .encode)
    # also yields an attention mask. generate() cannot infer one reliably
    # here, because pad_token_id is set to the EOS id.
    encoded = tokenizer(message + tokenizer.eos_token, return_tensors="pt")
    input_ids = encoded["input_ids"]
    attention_mask = encoded["attention_mask"]

    # Keep prompt + reply inside the model's context window. Slicing from the
    # left preserves the trailing EOS that marks the end of the user's turn.
    context_size = getattr(model.config, "n_positions", None)
    if context_size is not None and input_ids.shape[-1] + max_new_tokens > context_size:
        keep = max(1, context_size - max_new_tokens)
        input_ids = input_ids[:, -keep:]
        attention_mask = attention_mask[:, -keep:]

    # max_new_tokens (unlike max_length) excludes the prompt, so a long
    # message still leaves room for a reply.
    output = model.generate(
        input_ids,
        attention_mask=attention_mask,
        max_new_tokens=max_new_tokens,
        pad_token_id=tokenizer.eos_token_id,
    )
    # Decode only the newly generated tokens, not the echoed prompt.
    response = tokenizer.decode(output[0, input_ids.shape[-1]:], skip_special_tokens=True)

    # Detect emotion. truncation=True clips over-long messages to the
    # classifier's maximum input length instead of failing.
    top = emotion_pipeline(message, truncation=True)[0]
    emotion = top["label"]
    score = float(top["score"])

    # Store history
    history.append((message, response))
    return response, emotion, score

# ------------------ API Interface ------------------
def api_predict(message: str, session_id: str = "default"):
    """Run one chatbot turn and package the result for the JSON API.

    Returns a dict with keys ``response``, ``emotion``, ``score`` and
    ``session_id``. The ``session_id`` value is echoed back from the request.
    """
    reply, label, confidence = chatbot_response(message, session_id)
    payload = dict(
        response=reply,
        emotion=label,
        score=confidence,
        session_id=session_id,
    )
    return payload

# ------------------ Web Interface ------------------
with gr.Blocks(title="Mental Health Chatbot") as web_interface:
    gr.Markdown("# 🤖 Mental Health Chatbot")

    with gr.Row():
        with gr.Column():
            chatbot = gr.Chatbot(height=400)
            msg = gr.Textbox(placeholder="Type your message...", label="You")
            with gr.Row():
                session_id = gr.Textbox(label="Session ID", value="default")
                submit_btn = gr.Button("Send", variant="primary")
                clear_btn = gr.Button("Clear")

        with gr.Column():
            emotion_out = gr.Textbox(label="Detected Emotion")
            score_out = gr.Number(label="Confidence Score")

    def respond(message, chat_history, sid):
        """Handle one user turn from the web UI.

        Returns values for (msg, chatbot, emotion_out, score_out):
        - an empty string, which clears the input box;
        - the updated transcript;
        - the detected emotion label;
        - its confidence score.
        """
        # The Chatbot value may arrive as None; work on a fresh list either way.
        chat_history = list(chat_history or [])
        # Ignore blank submissions instead of running both models on nothing.
        # gr.update() with no arguments leaves the emotion outputs unchanged.
        if not message or not message.strip():
            return "", chat_history, gr.update(), gr.update()
        response, emotion, score = chatbot_response(message, sid)
        chat_history.append((message, response))
        return "", chat_history, emotion, score

    def clear_session(sid):
        """Reset the transcript and emotion outputs, and drop the stored history for *sid*."""
        # Also forget the server-side transcript, so "Clear" actually resets the
        # session instead of leaving chat_histories[sid] growing.
        chat_histories.pop(sid, None)
        return [], "", 0.0

    respond_inputs = [msg, chatbot, session_id]
    respond_outputs = [msg, chatbot, emotion_out, score_out]
    submit_btn.click(respond, respond_inputs, respond_outputs)
    msg.submit(respond, respond_inputs, respond_outputs)
    clear_btn.click(
        clear_session,
        [session_id],
        [chatbot, emotion_out, score_out]
    )

# ------------------ Mount Interfaces ------------------
# Programmatic endpoint: (message, session id) -> JSON from api_predict.
api_interface = gr.Interface(
    fn=api_predict,
    inputs=[
        gr.Textbox(label="Message"),
        # Default mirrors api_predict's own session_id default.
        gr.Textbox(label="Session ID", value="default"),
    ],
    outputs=gr.JSON(),
    title="API Predict",
    description="Use this endpoint for programmatic access"
)

# Mount the API first. Starlette tries routes in registration order, and a
# mount at "/" matches every path. Anything registered after it (such as
# "/predict") would therefore be unreachable.
app = gr.mount_gradio_app(gr.routes.App(), api_interface, path="/predict")
app = gr.mount_gradio_app(app, web_interface, path="/")

# ------------------ Launch ------------------
if __name__ == "__main__":
    # `app` is the FastAPI application returned by gr.mount_gradio_app, which
    # has no .launch() method. Serve it with uvicorn, the ASGI server that
    # gradio itself depends on, so this adds no new install requirement.
    # Host and port match Gradio's usual launch() defaults.
    import uvicorn

    uvicorn.run(app, host="127.0.0.1", port=7860)