awacke1 commited on
Commit
b1fe0dd
·
verified ·
1 Parent(s): 239a534

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +292 -0
app.py ADDED
@@ -0,0 +1,292 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import asyncio
3
+ import websockets
4
+ import json
5
+ import uuid
6
+ import argparse
7
+ import urllib.parse
8
+ from datetime import datetime
9
+ import logging
10
+ import sys
11
+
12
+ # Configure logging
13
+ logging.basicConfig(
14
+ level=logging.INFO,
15
+ format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
16
+ handlers=[logging.StreamHandler(sys.stdout)]
17
+ )
18
+ logger = logging.getLogger("chat-node")
19
+
20
+ # Dictionary to store active connections
21
+ active_connections = {}
22
+ # Dictionary to store message history for each chat room
23
+ chat_history = {}
24
+
25
+ # Get node name from URL or command line
26
+ def get_node_name():
27
+ parser = argparse.ArgumentParser(description='Start a chat node with a specific name')
28
+ parser.add_argument('--node-name', type=str, default=None, help='Name for this chat node')
29
+ parser.add_argument('--port', type=int, default=7860, help='Port to run the Gradio interface on')
30
+
31
+ args = parser.parse_args()
32
+ node_name = args.node_name
33
+ port = args.port
34
+
35
+ # If no node name specified, generate a random one
36
+ if not node_name:
37
+ node_name = f"node-{uuid.uuid4().hex[:8]}"
38
+
39
+ return node_name, port
40
+
41
+ async def websocket_handler(websocket, path):
42
+ """Handle WebSocket connections."""
43
+ try:
44
+ # Extract room_id from path if present
45
+ path_parts = path.strip('/').split('/')
46
+ room_id = path_parts[0] if path_parts else "default"
47
+
48
+ # Register the new client
49
+ client_id = str(uuid.uuid4())
50
+ if room_id not in active_connections:
51
+ active_connections[room_id] = {}
52
+ chat_history[room_id] = []
53
+
54
+ active_connections[room_id][client_id] = websocket
55
+
56
+ # Send welcome message and chat history
57
+ welcome_msg = {
58
+ "type": "system",
59
+ "content": f"Welcome to room '{room_id}'! Connected from node '{NODE_NAME}'",
60
+ "timestamp": datetime.now().isoformat(),
61
+ "sender": "system",
62
+ "room_id": room_id
63
+ }
64
+ await websocket.send(json.dumps(welcome_msg))
65
+
66
+ # Send chat history
67
+ for msg in chat_history[room_id]:
68
+ await websocket.send(json.dumps(msg))
69
+
70
+ # Broadcast join notification
71
+ join_msg = {
72
+ "type": "system",
73
+ "content": f"User joined the room",
74
+ "timestamp": datetime.now().isoformat(),
75
+ "sender": "system",
76
+ "room_id": room_id
77
+ }
78
+ await broadcast_message(join_msg, room_id)
79
+
80
+ logger.info(f"New client {client_id} connected to room {room_id}")
81
+
82
+ # Handle messages from this client
83
+ async for message in websocket:
84
+ try:
85
+ data = json.loads(message)
86
+
87
+ # Add metadata to the message
88
+ data["timestamp"] = datetime.now().isoformat()
89
+ data["sender_node"] = NODE_NAME
90
+ data["room_id"] = room_id
91
+
92
+ # Store in history
93
+ chat_history[room_id].append(data)
94
+ if len(chat_history[room_id]) > 100: # Limit history to 100 messages
95
+ chat_history[room_id] = chat_history[room_id][-100:]
96
+
97
+ # Broadcast to all clients in the room
98
+ await broadcast_message(data, room_id)
99
+
100
+ except json.JSONDecodeError:
101
+ error_msg = {
102
+ "type": "error",
103
+ "content": "Invalid JSON format",
104
+ "timestamp": datetime.now().isoformat(),
105
+ "sender": "system",
106
+ "room_id": room_id
107
+ }
108
+ await websocket.send(json.dumps(error_msg))
109
+
110
+ except websockets.exceptions.ConnectionClosed:
111
+ logger.info(f"Client {client_id} disconnected from room {room_id}")
112
+ finally:
113
+ # Remove the client when disconnected
114
+ if room_id in active_connections and client_id in active_connections[room_id]:
115
+ del active_connections[room_id][client_id]
116
+
117
+ # Broadcast leave notification
118
+ leave_msg = {
119
+ "type": "system",
120
+ "content": f"User left the room",
121
+ "timestamp": datetime.now().isoformat(),
122
+ "sender": "system",
123
+ "room_id": room_id
124
+ }
125
+ await broadcast_message(leave_msg, room_id)
126
+
127
+ # Clean up empty rooms
128
+ if not active_connections[room_id]:
129
+ del active_connections[room_id]
130
+ # Optionally, you might want to keep the chat history
131
+
132
+ async def broadcast_message(message, room_id):
133
+ """Broadcast a message to all clients in a room."""
134
+ if room_id in active_connections:
135
+ disconnected_clients = []
136
+
137
+ for client_id, websocket in active_connections[room_id].items():
138
+ try:
139
+ await websocket.send(json.dumps(message))
140
+ except websockets.exceptions.ConnectionClosed:
141
+ disconnected_clients.append(client_id)
142
+
143
+ # Clean up disconnected clients
144
+ for client_id in disconnected_clients:
145
+ del active_connections[room_id][client_id]
146
+
147
+ async def start_websocket_server(host='0.0.0.0', port=8765):
148
+ """Start the WebSocket server."""
149
+ server = await websockets.serve(websocket_handler, host, port)
150
+ logger.info(f"WebSocket server started on ws://{host}:{port}")
151
+ return server
152
+
153
+ def send_message(message, username, room_id):
154
+ """Function to send a message from the Gradio interface."""
155
+ if not message.strip():
156
+ return None
157
+
158
+ loop = asyncio.get_event_loop()
159
+ msg_data = {
160
+ "type": "chat",
161
+ "content": message,
162
+ "username": username,
163
+ "room_id": room_id
164
+ }
165
+
166
+ loop.create_task(broadcast_message(msg_data, room_id))
167
+
168
+ # Format the message for display
169
+ formatted_msg = f"{username}: {message}"
170
+ return formatted_msg
171
+
172
+ def join_room(room_id, chat_history_output):
173
+ """Join a specific chat room."""
174
+ if not room_id.strip():
175
+ return "Please enter a valid room ID", chat_history_output
176
+
177
+ # Sanitize the room ID
178
+ room_id = urllib.parse.quote(room_id.strip())
179
+
180
+ # Create the room if it doesn't exist
181
+ if room_id not in chat_history:
182
+ chat_history[room_id] = []
183
+
184
+ # Format existing messages
185
+ formatted_history = []
186
+ for msg in chat_history[room_id]:
187
+ if msg.get("type") == "chat":
188
+ formatted_history.append(f"{msg.get('username', 'Anonymous')}: {msg.get('content', '')}")
189
+ elif msg.get("type") == "system":
190
+ formatted_history.append(f"System: {msg.get('content', '')}")
191
+
192
+ return f"Joined room: {room_id}", formatted_history
193
+
194
+ def create_gradio_interface():
195
+ """Create and return the Gradio interface."""
196
+ with gr.Blocks(title=f"Chat Node: {NODE_NAME}") as interface:
197
+ gr.Markdown(f"# Chat Node: {NODE_NAME}")
198
+ gr.Markdown("Join a room by entering a room ID below or create a new one.")
199
+
200
+ with gr.Row():
201
+ room_id_input = gr.Textbox(label="Room ID", placeholder="Enter room ID")
202
+ join_button = gr.Button("Join Room")
203
+
204
+ chat_history_output = gr.Textbox(label="Chat History", lines=15, max_lines=15)
205
+
206
+ with gr.Row():
207
+ username_input = gr.Textbox(label="Username", placeholder="Enter your username", value="User")
208
+ message_input = gr.Textbox(label="Message", placeholder="Type your message here")
209
+ send_button = gr.Button("Send")
210
+
211
+ # Current room display
212
+ current_room_display = gr.Textbox(label="Current Room", value="Not joined any room yet")
213
+
214
+ # Event handlers
215
+ join_button.click(
216
+ join_room,
217
+ inputs=[room_id_input, chat_history_output],
218
+ outputs=[current_room_display, chat_history_output]
219
+ )
220
+
221
+ def send_and_clear(message, username, room_id):
222
+ if not room_id.startswith("Joined room:"):
223
+ return "Please join a room first", message
224
+
225
+ actual_room_id = room_id.replace("Joined room: ", "").strip()
226
+ formatted_msg = send_message(message, username, actual_room_id)
227
+
228
+ if formatted_msg:
229
+ return "", formatted_msg
230
+ return message, None
231
+
232
+ send_button.click(
233
+ send_and_clear,
234
+ inputs=[message_input, username_input, current_room_display],
235
+ outputs=[message_input, chat_history_output]
236
+ )
237
+
238
+ # Enter key to send message
239
+ message_input.submit(
240
+ send_and_clear,
241
+ inputs=[message_input, username_input, current_room_display],
242
+ outputs=[message_input, chat_history_output]
243
+ )
244
+
245
+ return interface
246
+
247
+ async def main():
248
+ """Main function to start the application."""
249
+ global NODE_NAME
250
+ NODE_NAME, port = get_node_name()
251
+
252
+ # Start WebSocket server
253
+ server = await start_websocket_server()
254
+
255
+ # Create and launch Gradio interface
256
+ interface = create_gradio_interface()
257
+
258
+ # Custom middleware to extract node name from URL query parameters
259
+ from starlette.middleware.base import BaseHTTPMiddleware
260
+
261
+ class NodeNameMiddleware(BaseHTTPMiddleware):
262
+ async def dispatch(self, request, call_next):
263
+ global NODE_NAME
264
+ query_params = dict(request.query_params)
265
+ if "node_name" in query_params:
266
+ NODE_NAME = query_params["node_name"]
267
+ logger.info(f"Node name set to {NODE_NAME} from URL parameter")
268
+
269
+ response = await call_next(request)
270
+ return response
271
+
272
+ # Apply middleware
273
+ app = gr.routes.App.create_app(interface)
274
+ app.add_middleware(NodeNameMiddleware)
275
+
276
+ # Launch with the modified app
277
+ gr.routes.mount_gradio_app(app, interface, path="/")
278
+
279
+ # Run the FastAPI app with uvicorn
280
+ import uvicorn
281
+ config = uvicorn.Config(app, host="0.0.0.0", port=port)
282
+ server = uvicorn.Server(config)
283
+
284
+ logger.info(f"Starting Gradio interface on http://0.0.0.0:{port} with node name '{NODE_NAME}'")
285
+
286
+ await server.serve()
287
+
288
+ # Keep the WebSocket server running
289
+ await asyncio.Future()
290
+
291
+ if __name__ == "__main__":
292
+ asyncio.run(main())