safwansajad commited on
Commit
b3c1245
·
verified ·
1 Parent(s): f1bc48e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +33 -84
app.py CHANGED
@@ -3,26 +3,20 @@ from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer
3
  from fastapi import FastAPI
4
  import json
5
 
6
- # Load chatbot model
7
- print("Loading DialoGPT model...")
8
  chatbot_model = "microsoft/DialoGPT-medium"
9
  tokenizer = AutoTokenizer.from_pretrained(chatbot_model)
10
  model = AutoModelForCausalLM.from_pretrained(chatbot_model)
11
-
12
- # Load emotion detection model
13
- print("Loading emotion detection model...")
14
  emotion_pipeline = pipeline("text-classification", model="j-hartmann/emotion-english-distilroberta-base")
15
 
16
  # Store chat histories
17
  chat_histories = {}
18
 
19
- def chatbot_response(message, history=None, session_id="default"):
20
- """Generate a chatbot response and detect emotion from user message"""
21
- # Initialize session if it doesn't exist
22
  if session_id not in chat_histories:
23
  chat_histories[session_id] = []
24
 
25
- # Generate chatbot response
26
  input_ids = tokenizer.encode(message + tokenizer.eos_token, return_tensors="pt")
27
  output = model.generate(input_ids, max_length=200, pad_token_id=tokenizer.eos_token_id)
28
  response = tokenizer.decode(output[:, input_ids.shape[-1]:][0], skip_special_tokens=True)
@@ -32,80 +26,35 @@ def chatbot_response(message, history=None, session_id="default"):
32
  emotion = emotion_result[0]["label"]
33
  score = float(emotion_result[0]["score"])
34
 
35
- # Store in chat history
36
  chat_histories[session_id].append((message, response))
37
-
38
- return response, emotion, score, chat_histories[session_id]
39
-
40
- def api_chatbot_response(message: str, session_id: str = "default"):
41
- """API endpoint version that returns a structured response"""
42
- response, emotion, score, _ = chatbot_response(message, None, session_id)
43
-
44
- return {
45
- "bot_response": response,
46
- "emotion": emotion,
47
- "emotion_score": score,
48
- "session_id": session_id
49
- }
50
-
51
- def get_chat_history(session_id: str = "default"):
52
- """Get chat history for a specific session"""
53
- if session_id in chat_histories:
54
- return chat_histories[session_id]
55
- return []
56
-
57
- def clear_history(session_id: str = "default"):
58
- """Clear chat history for a specific session"""
59
- if session_id in chat_histories:
60
- chat_histories[session_id] = []
61
- return {"status": "success", "message": f"History cleared for session {session_id}"}
62
- return {"status": "error", "message": f"Session {session_id} not found"}
63
-
64
- # Create FastAPI app
65
- app = FastAPI()
66
-
67
- # Create Gradio app
68
- with gr.Blocks() as gradio_interface:
69
- gr.Markdown("# API Documentation")
70
- gr.Markdown("Use these endpoints in your React Native app:")
71
-
72
- with gr.Accordion("Chat Endpoint"):
73
- gr.Markdown("""
74
- **POST /chat**
75
- - Parameters: `message`, `session_id` (optional)
76
- - Returns: JSON with bot_response, emotion, emotion_score, session_id
77
- """)
78
-
79
- with gr.Accordion("History Endpoint"):
80
- gr.Markdown("""
81
- **GET /history**
82
- - Parameters: `session_id` (optional)
83
- - Returns: JSON array of message/response pairs
84
- """)
85
-
86
- with gr.Accordion("Clear History Endpoint"):
87
- gr.Markdown("""
88
- **POST /clear_history**
89
- - Parameters: `session_id` (optional)
90
- - Returns: JSON with status and message
91
- """)
92
-
93
- # Mount Gradio interface
94
- gradio_app = gr.mount_gradio_app(app, gradio_interface, path="/")
95
-
96
- # Add API endpoints
97
- @app.post("/chat")
98
- async def chat_endpoint(message: str, session_id: str = "default"):
99
- return api_chatbot_response(message, session_id)
100
-
101
- @app.get("/history")
102
- async def history_endpoint(session_id: str = "default"):
103
- return get_chat_history(session_id)
104
-
105
- @app.post("/clear_history")
106
- async def clear_endpoint(session_id: str = "default"):
107
- return clear_history(session_id)
108
-
109
  if __name__ == "__main__":
110
- import uvicorn
111
- uvicorn.run(app, host="0.0.0.0", port=7860)
 
3
  from fastapi import FastAPI
4
  import json
5
 
6
+ # Load models
 
7
  chatbot_model = "microsoft/DialoGPT-medium"
8
  tokenizer = AutoTokenizer.from_pretrained(chatbot_model)
9
  model = AutoModelForCausalLM.from_pretrained(chatbot_model)
 
 
 
10
  emotion_pipeline = pipeline("text-classification", model="j-hartmann/emotion-english-distilroberta-base")
11
 
12
  # Store chat histories
13
  chat_histories = {}
14
 
15
+ def chatbot_response(message, session_id="default"):
 
 
16
  if session_id not in chat_histories:
17
  chat_histories[session_id] = []
18
 
19
+ # Generate response
20
  input_ids = tokenizer.encode(message + tokenizer.eos_token, return_tensors="pt")
21
  output = model.generate(input_ids, max_length=200, pad_token_id=tokenizer.eos_token_id)
22
  response = tokenizer.decode(output[:, input_ids.shape[-1]:][0], skip_special_tokens=True)
 
26
  emotion = emotion_result[0]["label"]
27
  score = float(emotion_result[0]["score"])
28
 
29
+ # Store history
30
  chat_histories[session_id].append((message, response))
31
+ return response, emotion, score
32
+
33
+ # Gradio Interface (Primary for Spaces)
34
+ with gr.Blocks() as demo:
35
+ gr.Markdown("# 🤖 Mental Health Chatbot")
36
+ with gr.Row():
37
+ with gr.Column():
38
+ chatbot = gr.Chatbot()
39
+ msg = gr.Textbox(label="Your Message")
40
+ session_id = gr.Textbox(label="Session ID", value="default")
41
+ btn = gr.Button("Send")
42
+ clear_btn = gr.Button("Clear History")
43
+ with gr.Column():
44
+ emotion_out = gr.Textbox(label="Detected Emotion")
45
+ score_out = gr.Number(label="Confidence Score")
46
+
47
+ def respond(message, chat_history, session_id):
48
+ response, emotion, score = chatbot_response(message, session_id)
49
+ chat_history.append((message, response))
50
+ return "", chat_history, emotion, score
51
+
52
+ btn.click(respond, [msg, chatbot, session_id], [msg, chatbot, emotion_out, score_out])
53
+ msg.submit(respond, [msg, chatbot, session_id], [msg, chatbot, emotion_out, score_out])
54
+ clear_btn.click(lambda s_id: ([], "", 0.0) if s_id in chat_histories else ([], "", 0.0),
55
+ inputs=[session_id],
56
+ outputs=[chatbot, emotion_out, score_out])
57
+
58
+ # For Hugging Face Spaces, Gradio must be the main interface
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59
  if __name__ == "__main__":
60
+ demo.launch()