safwansajad commited on
Commit
96ae8e4
·
verified ·
1 Parent(s): a02f6d9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +46 -82
app.py CHANGED
@@ -1,5 +1,7 @@
1
  import gradio as gr
2
  from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer
 
 
3
 
4
  # Load chatbot model
5
  print("Loading DialoGPT model...")
@@ -35,7 +37,7 @@ def chatbot_response(message, history=None, session_id="default"):
35
 
36
  return response, emotion, score, chat_histories[session_id]
37
 
38
- def api_chatbot_response(message, session_id="default"):
39
  """API endpoint version that returns a structured response"""
40
  response, emotion, score, _ = chatbot_response(message, None, session_id)
41
 
@@ -46,102 +48,64 @@ def api_chatbot_response(message, session_id="default"):
46
  "session_id": session_id
47
  }
48
 
49
- def get_chat_history(session_id="default"):
50
  """Get chat history for a specific session"""
51
  if session_id in chat_histories:
52
  return chat_histories[session_id]
53
  return []
54
 
55
- def clear_history(session_id="default"):
56
  """Clear chat history for a specific session"""
57
  if session_id in chat_histories:
58
  chat_histories[session_id] = []
59
- return f"History cleared for session {session_id}"
60
- return f"Session {session_id} not found"
61
 
62
- # Define UI interface
63
- with gr.Blocks(title="Mental Health Chatbot") as ui_interface:
64
- gr.Markdown("# 🧠 Mental Health Chatbot")
65
-
66
- with gr.Row():
67
- with gr.Column(scale=3):
68
- chatbot = gr.Chatbot(height=400, label="Conversation")
69
-
70
- with gr.Row():
71
- message = gr.Textbox(placeholder="Type your message here...", label="You", show_label=False)
72
- submit_btn = gr.Button("Send")
73
-
74
- with gr.Row():
75
- session_id = gr.Textbox(value="default", label="Session ID")
76
- clear_btn = gr.Button("Clear Chat")
77
-
78
- with gr.Column(scale=1):
79
- emotion_label = gr.Textbox(label="Emotion Detected")
80
- emotion_score = gr.Number(label="Confidence Score")
81
-
82
- # Set up event handlers
83
- def respond(message, chat_history, session_id):
84
- response, emotion, score, _ = chatbot_response(message, chat_history, session_id)
85
- chat_history.append((message, response))
86
- return "", chat_history, emotion, score
87
 
88
- submit_btn.click(
89
- respond,
90
- [message, chatbot, session_id],
91
- [message, chatbot, emotion_label, emotion_score]
92
- )
 
93
 
94
- message.submit(
95
- respond,
96
- [message, chatbot, session_id],
97
- [message, chatbot, emotion_label, emotion_score]
98
- )
 
99
 
100
- clear_btn.click(
101
- lambda s: ([], clear_history(s), "", 0),
102
- [session_id],
103
- [chatbot, emotion_label, emotion_score]
104
- )
 
105
 
106
- # Define API interface
107
- api_interface = gr.Interface(
108
- fn=api_chatbot_response,
109
- inputs=[
110
- gr.Textbox(label="Message"),
111
- gr.Textbox(label="Session ID", value="default")
112
- ],
113
- outputs=gr.JSON(label="Response"),
114
- title="Mental Health Chatbot API",
115
- description="Send a message to get chatbot response with emotion analysis",
116
- examples=[
117
- ["I'm feeling sad today", "user1"],
118
- ["I'm so excited about my new job!", "user2"],
119
- ["I'm worried about my exam tomorrow", "user3"]
120
- ]
121
- )
122
 
123
- history_api = gr.Interface(
124
- fn=get_chat_history,
125
- inputs=gr.Textbox(label="Session ID", value="default"),
126
- outputs=gr.JSON(label="Chat History"),
127
- title="Chat History API",
128
- description="Get chat history for a specific session"
129
- )
130
 
131
- clear_api = gr.Interface(
132
- fn=clear_history,
133
- inputs=gr.Textbox(label="Session ID", value="default"),
134
- outputs=gr.Textbox(label="Result"),
135
- title="Clear History API",
136
- description="Clear chat history for a specific session"
137
- )
138
 
139
- # Combine all interfaces
140
- demo = gr.TabbedInterface(
141
- [ui_interface, api_interface, history_api, clear_api],
142
- ["Chat UI", "Chat API", "History API", "Clear API"]
143
- )
144
 
145
- # Launch the app
146
  if __name__ == "__main__":
147
- demo.launch()
 
 
1
  import gradio as gr
2
  from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer
3
+ from fastapi import FastAPI
4
+ import json
5
 
6
  # Load chatbot model
7
  print("Loading DialoGPT model...")
 
37
 
38
  return response, emotion, score, chat_histories[session_id]
39
 
40
+ def api_chatbot_response(message: str, session_id: str = "default"):
41
  """API endpoint version that returns a structured response"""
42
  response, emotion, score, _ = chatbot_response(message, None, session_id)
43
 
 
48
  "session_id": session_id
49
  }
50
 
51
+ def get_chat_history(session_id: str = "default"):
52
  """Get chat history for a specific session"""
53
  if session_id in chat_histories:
54
  return chat_histories[session_id]
55
  return []
56
 
57
+ def clear_history(session_id: str = "default"):
58
  """Clear chat history for a specific session"""
59
  if session_id in chat_histories:
60
  chat_histories[session_id] = []
61
+ return {"status": "success", "message": f"History cleared for session {session_id}"}
62
+ return {"status": "error", "message": f"Session {session_id} not found"}
63
 
64
+ # Create FastAPI app
65
+ app = FastAPI()
66
+
67
+ # Create Gradio app
68
+ with gr.Blocks() as gradio_interface:
69
+ gr.Markdown("# API Documentation")
70
+ gr.Markdown("Use these endpoints in your React Native app:")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
71
 
72
+ with gr.Accordion("Chat Endpoint"):
73
+ gr.Markdown("""
74
+ **POST /chat**
75
+ - Parameters: `message`, `session_id` (optional)
76
+ - Returns: JSON with bot_response, emotion, emotion_score, session_id
77
+ """)
78
 
79
+ with gr.Accordion("History Endpoint"):
80
+ gr.Markdown("""
81
+ **GET /history**
82
+ - Parameters: `session_id` (optional)
83
+ - Returns: JSON array of message/response pairs
84
+ """)
85
 
86
+ with gr.Accordion("Clear History Endpoint"):
87
+ gr.Markdown("""
88
+ **POST /clear_history**
89
+ - Parameters: `session_id` (optional)
90
+ - Returns: JSON with status and message
91
+ """)
92
 
93
+ # Mount Gradio interface
94
+ gradio_app = gr.mount_gradio_app(app, gradio_interface, path="/")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
95
 
96
+ # Add API endpoints
97
+ @app.post("/chat")
98
+ async def chat_endpoint(message: str, session_id: str = "default"):
99
+ return api_chatbot_response(message, session_id)
 
 
 
100
 
101
+ @app.get("/history")
102
+ async def history_endpoint(session_id: str = "default"):
103
+ return get_chat_history(session_id)
 
 
 
 
104
 
105
+ @app.post("/clear_history")
106
+ async def clear_endpoint(session_id: str = "default"):
107
+ return clear_history(session_id)
 
 
108
 
 
109
  if __name__ == "__main__":
110
+ import uvicorn
111
+ uvicorn.run(app, host="0.0.0.0", port=7860)