safwansajad commited on
Commit
7d4edac
·
verified ·
1 Parent(s): 57aa867

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +57 -20
app.py CHANGED
@@ -11,10 +11,11 @@ emotion_pipeline = pipeline("text-classification", model="j-hartmann/emotion-eng
11
  chat_histories = {}
12
 
13
  def chatbot_response(message, session_id="default"):
 
14
  if session_id not in chat_histories:
15
  chat_histories[session_id] = []
16
 
17
- # Generate response
18
  input_ids = tokenizer.encode(message + tokenizer.eos_token, return_tensors="pt")
19
  output = model.generate(input_ids, max_length=200, pad_token_id=tokenizer.eos_token_id)
20
  response = tokenizer.decode(output[:, input_ids.shape[-1]:][0], skip_special_tokens=True)
@@ -28,16 +29,30 @@ def chatbot_response(message, session_id="default"):
28
  chat_histories[session_id].append((message, response))
29
  return response, emotion, score
30
 
 
 
 
 
 
 
 
 
 
 
 
31
  # ------------------ Web Interface ------------------
32
- with gr.Blocks() as demo:
33
  gr.Markdown("# 🤖 Mental Health Chatbot")
 
34
  with gr.Row():
35
  with gr.Column():
36
- chatbot = gr.Chatbot()
37
- msg = gr.Textbox(label="Your Message")
38
- session_id = gr.Textbox(label="Session ID", value="default")
39
- btn = gr.Button("Send")
40
- clear_btn = gr.Button("Clear History")
 
 
41
  with gr.Column():
42
  emotion_out = gr.Textbox(label="Detected Emotion")
43
  score_out = gr.Number(label="Confidence Score")
@@ -47,19 +62,41 @@ with gr.Blocks() as demo:
47
  chat_history.append((message, response))
48
  return "", chat_history, emotion, score
49
 
50
- btn.click(respond, [msg, chatbot, session_id], [msg, chatbot, emotion_out, score_out])
51
- msg.submit(respond, [msg, chatbot, session_id], [msg, chatbot, emotion_out, score_out])
52
- clear_btn.click(lambda s_id: ([], "", 0.0) if s_id in chat_histories else ([], "", 0.0),
53
- inputs=[session_id],
54
- outputs=[chatbot, emotion_out, score_out])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55
 
56
- # ------------------ API Endpoint for /api/predict ------------------
57
- predict_api = gr.Interface(
58
- fn=chatbot_response,
59
- inputs=[gr.Textbox(label="Message"), gr.Textbox(label="Session ID")],
60
- outputs=[gr.Textbox(label="Response"), gr.Textbox(label="Emotion"), gr.Number(label="Score")]
 
 
 
 
 
61
  )
62
 
63
- # ------------------ Launch for Gradio Spaces ------------------
64
- demo.launch()
65
- predict_api.launch(inline=False)
 
11
  chat_histories = {}
12
 
13
  def chatbot_response(message, session_id="default"):
14
+ """Core function that handles both chat and emotion analysis"""
15
  if session_id not in chat_histories:
16
  chat_histories[session_id] = []
17
 
18
+ # Generate chatbot response
19
  input_ids = tokenizer.encode(message + tokenizer.eos_token, return_tensors="pt")
20
  output = model.generate(input_ids, max_length=200, pad_token_id=tokenizer.eos_token_id)
21
  response = tokenizer.decode(output[:, input_ids.shape[-1]:][0], skip_special_tokens=True)
 
29
  chat_histories[session_id].append((message, response))
30
  return response, emotion, score
31
 
32
+ # ------------------ API Interface ------------------
33
+ def api_predict(message: str, session_id: str = "default"):
34
+ """Endpoint for /predict that returns JSON"""
35
+ response, emotion, score = chatbot_response(message, session_id)
36
+ return {
37
+ "response": response,
38
+ "emotion": emotion,
39
+ "score": score,
40
+ "session_id": session_id
41
+ }
42
+
43
  # ------------------ Web Interface ------------------
44
+ with gr.Blocks(title="Mental Health Chatbot") as web_interface:
45
  gr.Markdown("# 🤖 Mental Health Chatbot")
46
+
47
  with gr.Row():
48
  with gr.Column():
49
+ chatbot = gr.Chatbot(height=400)
50
+ msg = gr.Textbox(placeholder="Type your message...", label="You")
51
+ with gr.Row():
52
+ session_id = gr.Textbox(label="Session ID", value="default")
53
+ submit_btn = gr.Button("Send", variant="primary")
54
+ clear_btn = gr.Button("Clear")
55
+
56
  with gr.Column():
57
  emotion_out = gr.Textbox(label="Detected Emotion")
58
  score_out = gr.Number(label="Confidence Score")
 
62
  chat_history.append((message, response))
63
  return "", chat_history, emotion, score
64
 
65
+ submit_btn.click(
66
+ respond,
67
+ [msg, chatbot, session_id],
68
+ [msg, chatbot, emotion_out, score_out]
69
+ )
70
+ msg.submit(
71
+ respond,
72
+ [msg, chatbot, session_id],
73
+ [msg, chatbot, emotion_out, score_out]
74
+ )
75
+ clear_btn.click(
76
+ lambda s_id: ([], "", 0.0) if s_id in chat_histories else ([], "", 0.0),
77
+ [session_id],
78
+ [chatbot, emotion_out, score_out]
79
+ )
80
+
81
+ # ------------------ Mount Interfaces ------------------
82
+ app = gr.mount_gradio_app(
83
+ gr.routes.App(),
84
+ web_interface,
85
+ path="/"
86
+ )
87
 
88
+ app = gr.mount_gradio_app(
89
+ app,
90
+ gr.Interface(
91
+ fn=api_predict,
92
+ inputs=[gr.Textbox(), gr.Textbox()],
93
+ outputs=gr.JSON(),
94
+ title="API Predict",
95
+ description="Use this endpoint for programmatic access"
96
+ ),
97
+ path="/predict"
98
  )
99
 
100
+ # ------------------ Launch ------------------
101
+ if __name__ == "__main__":
102
+ app.launch(show_api=False) # We manually mounted our API