safwansajad committed
Commit 57aa867 · verified · 1 Parent(s): 743eff1

Update app.py

Files changed (1):
  1. app.py +17 -11
app.py CHANGED
@@ -1,6 +1,5 @@
 import gradio as gr
 from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer
-import json
 
 # Load models
 chatbot_model = "microsoft/DialoGPT-medium"
@@ -14,22 +13,22 @@ chat_histories = {}
 def chatbot_response(message, session_id="default"):
     if session_id not in chat_histories:
         chat_histories[session_id] = []
-    
+
     # Generate response
     input_ids = tokenizer.encode(message + tokenizer.eos_token, return_tensors="pt")
     output = model.generate(input_ids, max_length=200, pad_token_id=tokenizer.eos_token_id)
     response = tokenizer.decode(output[:, input_ids.shape[-1]:][0], skip_special_tokens=True)
-    
+
     # Detect emotion
     emotion_result = emotion_pipeline(message)
     emotion = emotion_result[0]["label"]
     score = float(emotion_result[0]["score"])
-    
+
     # Store history
     chat_histories[session_id].append((message, response))
     return response, emotion, score
 
-# Gradio Interface (Primary for Spaces)
+# ------------------ Web Interface ------------------
 with gr.Blocks() as demo:
     gr.Markdown("# 🤖 Mental Health Chatbot")
     with gr.Row():
@@ -50,10 +49,17 @@ with gr.Blocks() as demo:
 
     btn.click(respond, [msg, chatbot, session_id], [msg, chatbot, emotion_out, score_out])
     msg.submit(respond, [msg, chatbot, session_id], [msg, chatbot, emotion_out, score_out])
-    clear_btn.click(lambda s_id: ([], "", 0.0) if s_id in chat_histories else ([], "", 0.0),
-                    inputs=[session_id],
-                    outputs=[chatbot, emotion_out, score_out])
+    clear_btn.click(lambda s_id: ([], "", 0.0) if s_id in chat_histories else ([], "", 0.0),
+                    inputs=[session_id],
+                    outputs=[chatbot, emotion_out, score_out])
+
+# ------------------ API Endpoint for /api/predict ------------------
+predict_api = gr.Interface(
+    fn=chatbot_response,
+    inputs=[gr.Textbox(label="Message"), gr.Textbox(label="Session ID")],
+    outputs=[gr.Textbox(label="Response"), gr.Textbox(label="Emotion"), gr.Number(label="Score")]
+)
 
-# For Hugging Face Spaces, Gradio must be the main interface
-if __name__ == "__main__":
-    demo.launch()
+# ------------------ Launch for Gradio Spaces ------------------
+demo.launch()
+predict_api.launch(inline=False)
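
With this change the Space exposes the chatbot both through the Blocks UI and through a gr.Interface endpoint. A minimal sketch of how a client could call that endpoint with gradio_client, assuming the Space ID below (hypothetical) and that the Interface's default /predict route is the one actually being served:

from gradio_client import Client

# Hypothetical Space ID -- substitute the real owner/space name.
client = Client("safwansajad/mental-health-chatbot")

# Positional args map to the Interface inputs: Message, Session ID.
response, emotion, score = client.predict(
    "I have been feeling stressed lately",
    "default",
    api_name="/predict",
)
print(response, emotion, score)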