safwansajad commited on
Commit
fa4258c
·
verified ·
1 Parent(s): 43bfd5c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +114 -67
app.py CHANGED
@@ -1,22 +1,5 @@
1
- from fastapi import FastAPI, Request
2
- from fastapi.middleware.cors import CORSMiddleware
3
  from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer
4
- from pydantic import BaseModel
5
- import uuid
6
-
7
- # Initialize FastAPI app
8
- app = FastAPI(title="Mental Health Chatbot API",
9
- description="API for mental health chatbot with emotion detection",
10
- version="1.0.0")
11
-
12
- # Add CORS middleware
13
- app.add_middleware(
14
- CORSMiddleware,
15
- allow_origins=["*"], # In production, restrict this to your app's domain
16
- allow_credentials=True,
17
- allow_methods=["*"],
18
- allow_headers=["*"],
19
- )
20
 
21
  # Load chatbot model
22
  print("Loading DialoGPT model...")
@@ -28,73 +11,137 @@ model = AutoModelForCausalLM.from_pretrained(chatbot_model)
28
  print("Loading emotion detection model...")
29
  emotion_pipeline = pipeline("text-classification", model="j-hartmann/emotion-english-distilroberta-base")
30
 
31
- # Store chat histories by session ID
32
  chat_histories = {}
33
 
34
- # Request models
35
- class ChatRequest(BaseModel):
36
- message: str
37
- session_id: str = None
38
-
39
- class ChatResponse(BaseModel):
40
- response: str
41
- emotion: str
42
- emotion_score: float
43
- session_id: str
44
-
45
- @app.get("/")
46
- async def root():
47
- return {"message": "Mental Health Chatbot API is running. Use /api/chat to interact with the chatbot."}
48
-
49
- @app.post("/api/chat", response_model=ChatResponse)
50
- async def chat(request: ChatRequest):
51
- # Create a new session ID if not provided
52
- session_id = request.session_id if request.session_id else str(uuid.uuid4())
53
-
54
- # Initialize chat history for new sessions
55
  if session_id not in chat_histories:
56
  chat_histories[session_id] = []
57
 
58
- user_input = request.message
59
-
60
  # Generate chatbot response
61
- input_ids = tokenizer.encode(user_input + tokenizer.eos_token, return_tensors="pt")
62
  output = model.generate(input_ids, max_length=200, pad_token_id=tokenizer.eos_token_id)
63
  response = tokenizer.decode(output[:, input_ids.shape[-1]:][0], skip_special_tokens=True)
64
 
65
  # Detect emotion
66
- emotion_result = emotion_pipeline(user_input)[0]
67
- emotion = emotion_result["label"]
68
- score = emotion_result["score"]
 
 
 
69
 
70
- # Store chat history
71
- chat_histories[session_id].append({"role": "user", "content": user_input})
72
- chat_histories[session_id].append({"role": "bot", "content": response})
 
 
73
 
74
- # Return the response
75
  return {
76
- "response": response,
77
- "emotion": emotion,
78
- "emotion_score": float(score),
79
  "session_id": session_id
80
  }
81
 
82
- @app.get("/api/history/{session_id}")
83
- async def get_chat_history(session_id: str):
84
- if session_id not in chat_histories:
85
- return {"error": "Session not found"}
86
-
87
- return {"session_id": session_id, "history": chat_histories[session_id]}
88
 
89
- @app.delete("/api/history/{session_id}")
90
- async def clear_chat_history(session_id: str):
91
  if session_id in chat_histories:
92
- chat_histories.pop(session_id)
93
- return {"message": f"Chat history for session {session_id} cleared"}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
94
 
95
- return {"error": "Session not found"}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
96
 
97
- # Launch the API server
98
  if __name__ == "__main__":
99
- import uvicorn
100
- uvicorn.run(app, host="0.0.0.0", port=7860)
 
1
+ import gradio as gr
 
2
  from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
 
4
  # Load chatbot model
5
  print("Loading DialoGPT model...")
 
11
  print("Loading emotion detection model...")
12
  emotion_pipeline = pipeline("text-classification", model="j-hartmann/emotion-english-distilroberta-base")
13
 
14
+ # Store chat histories
15
  chat_histories = {}
16
 
17
+ def chatbot_response(message, history=None, session_id="default"):
18
+ """Generate a chatbot response and detect emotion from user message"""
19
+ # Initialize session if it doesn't exist
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
  if session_id not in chat_histories:
21
  chat_histories[session_id] = []
22
 
 
 
23
  # Generate chatbot response
24
+ input_ids = tokenizer.encode(message + tokenizer.eos_token, return_tensors="pt")
25
  output = model.generate(input_ids, max_length=200, pad_token_id=tokenizer.eos_token_id)
26
  response = tokenizer.decode(output[:, input_ids.shape[-1]:][0], skip_special_tokens=True)
27
 
28
  # Detect emotion
29
+ emotion_result = emotion_pipeline(message)
30
+ emotion = emotion_result[0]["label"]
31
+ score = float(emotion_result[0]["score"])
32
+
33
+ # Store in chat history
34
+ chat_histories[session_id].append((message, response))
35
 
36
+ return response, emotion, score, chat_histories[session_id]
37
+
38
+ def api_chatbot_response(message, session_id="default"):
39
+ """API endpoint version that returns a structured response"""
40
+ response, emotion, score, _ = chatbot_response(message, None, session_id)
41
 
 
42
  return {
43
+ "bot_response": response,
44
+ "emotion": emotion,
45
+ "emotion_score": score,
46
  "session_id": session_id
47
  }
48
 
49
+ def get_chat_history(session_id="default"):
50
+ """Get chat history for a specific session"""
51
+ if session_id in chat_histories:
52
+ return chat_histories[session_id]
53
+ return []
 
54
 
55
+ def clear_history(session_id="default"):
56
+ """Clear chat history for a specific session"""
57
  if session_id in chat_histories:
58
+ chat_histories[session_id] = []
59
+ return f"History cleared for session {session_id}"
60
+ return f"Session {session_id} not found"
61
+
62
+ # Define UI interface
63
+ with gr.Blocks(title="Mental Health Chatbot") as ui_interface:
64
+ gr.Markdown("# 🧠 Mental Health Chatbot")
65
+
66
+ with gr.Row():
67
+ with gr.Column(scale=3):
68
+ chatbot = gr.Chatbot(height=400, label="Conversation")
69
+
70
+ with gr.Row():
71
+ message = gr.Textbox(placeholder="Type your message here...", label="You", show_label=False)
72
+ submit_btn = gr.Button("Send")
73
+
74
+ with gr.Row():
75
+ session_id = gr.Textbox(value="default", label="Session ID")
76
+ clear_btn = gr.Button("Clear Chat")
77
+
78
+ with gr.Column(scale=1):
79
+ emotion_label = gr.Textbox(label="Emotion Detected")
80
+ emotion_score = gr.Number(label="Confidence Score")
81
+
82
+ # Set up event handlers
83
+ def respond(message, chat_history, session_id):
84
+ response, emotion, score, _ = chatbot_response(message, chat_history, session_id)
85
+ chat_history.append((message, response))
86
+ return "", chat_history, emotion, score
87
+
88
+ submit_btn.click(
89
+ respond,
90
+ [message, chatbot, session_id],
91
+ [message, chatbot, emotion_label, emotion_score]
92
+ )
93
 
94
+ message.submit(
95
+ respond,
96
+ [message, chatbot, session_id],
97
+ [message, chatbot, emotion_label, emotion_score]
98
+ )
99
+
100
+ clear_btn.click(
101
+ lambda s: ([], clear_history(s), "", 0),
102
+ [session_id],
103
+ [chatbot, emotion_label, emotion_score]
104
+ )
105
+
106
+ # Define API interface
107
+ api_interface = gr.Interface(
108
+ fn=api_chatbot_response,
109
+ inputs=[
110
+ gr.Textbox(label="Message"),
111
+ gr.Textbox(label="Session ID", value="default")
112
+ ],
113
+ outputs=gr.JSON(label="Response"),
114
+ title="Mental Health Chatbot API",
115
+ description="Send a message to get chatbot response with emotion analysis",
116
+ examples=[
117
+ ["I'm feeling sad today", "user1"],
118
+ ["I'm so excited about my new job!", "user2"],
119
+ ["I'm worried about my exam tomorrow", "user3"]
120
+ ]
121
+ )
122
+
123
+ history_api = gr.Interface(
124
+ fn=get_chat_history,
125
+ inputs=gr.Textbox(label="Session ID", value="default"),
126
+ outputs=gr.JSON(label="Chat History"),
127
+ title="Chat History API",
128
+ description="Get chat history for a specific session"
129
+ )
130
+
131
+ clear_api = gr.Interface(
132
+ fn=clear_history,
133
+ inputs=gr.Textbox(label="Session ID", value="default"),
134
+ outputs=gr.Textbox(label="Result"),
135
+ title="Clear History API",
136
+ description="Clear chat history for a specific session"
137
+ )
138
+
139
+ # Combine all interfaces
140
+ demo = gr.TabbedInterface(
141
+ [ui_interface, api_interface, history_api, clear_api],
142
+ ["Chat UI", "Chat API", "History API", "Clear API"]
143
+ )
144
 
145
+ # Launch the app
146
  if __name__ == "__main__":
147
+ demo.launch()