"""Chat node: a WebSocket chat-room server with a small Gradio front end.

Each node runs two servers inside one asyncio event loop:

* a WebSocket server (default ``ws://0.0.0.0:8765``) where the first path
  segment of the connection URL selects the chat room, and
* a Gradio UI (served by uvicorn) that can join a room and post messages.

Room membership and the most recent messages are kept in module-level
dictionaries; nothing is persisted across restarts.
"""

import argparse
import asyncio
import json
import logging
import sys
import urllib.parse
import uuid
from datetime import datetime

import gradio as gr
import websockets

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[logging.StreamHandler(sys.stdout)]
)
logger = logging.getLogger("chat-node")

# Maximum number of messages retained per room.
MAX_HISTORY = 100

# Prefix of the "Current Room" status text produced by join_room(); the UI
# parses the room id back out of that text.
JOINED_PREFIX = "Joined room: "

# Name of this node. main() overwrites it from the command line, and the HTTP
# middleware may overwrite it again from a ``node_name`` query parameter.
NODE_NAME = "unnamed-node"

# Event loop that runs the WebSocket server; set by main(). Gradio callbacks
# that run outside this loop use it to hand work back to it.
MAIN_LOOP = None

# Strong references to fire-and-forget tasks so they are not garbage
# collected before they finish.
_background_tasks = set()

# Dictionary to store active connections: room_id -> {client_id: websocket}
active_connections = {}

# Dictionary to store message history for each chat room: room_id -> [message]
chat_history = {}


def get_node_name():
    """Parse the command line and return ``(node_name, port)``.

    ``--node-name`` names this node; when omitted a random ``node-xxxxxxxx``
    name is generated. ``--port`` is the Gradio/HTTP port (default 7860).
    """
    parser = argparse.ArgumentParser(description='Start a chat node with a specific name')
    parser.add_argument('--node-name', type=str, default=None, help='Name for this chat node')
    parser.add_argument('--port', type=int, default=7860, help='Port to run the Gradio interface on')
    args = parser.parse_args()

    # If no node name specified, generate a random one
    node_name = args.node_name or f"node-{uuid.uuid4().hex[:8]}"
    return node_name, args.port


def _system_message(content, room_id, msg_type="system"):
    """Build a server-originated message dict stamped with the current time."""
    return {
        "type": msg_type,
        "content": content,
        "timestamp": datetime.now().isoformat(),
        "sender": "system",
        "room_id": room_id,
    }


def _room_id_from_path(websocket, path):
    """Return the room id encoded in the connection path.

    ``path`` is used when the server passes it to the handler. Otherwise the
    path is read from ``websocket.request.path`` or ``websocket.path``,
    whichever the connection object provides. An empty first segment (for
    example a connection to ``/``) maps to the ``"default"`` room.
    """
    if path is None:
        request = getattr(websocket, "request", None)
        path = getattr(request, "path", None)
        if path is None:
            path = getattr(websocket, "path", "")
    # Ignore any query string so "/room?x=1" and "/room" are the same room.
    path = (path or "").split('?', 1)[0]
    first_segment = path.strip('/').split('/')[0]
    return first_segment or "default"


async def _record_and_broadcast(data, room_id):
    """Stamp ``data`` with metadata, store it in the room history, broadcast it."""
    data["timestamp"] = datetime.now().isoformat()
    data["sender_node"] = NODE_NAME
    data["room_id"] = room_id

    history = chat_history.setdefault(room_id, [])
    history.append(data)
    if len(history) > MAX_HISTORY:
        # Keep only the newest MAX_HISTORY messages.
        del history[:-MAX_HISTORY]

    await broadcast_message(data, room_id)


async def websocket_handler(websocket, path=None):
    """Handle one WebSocket connection for its whole lifetime.

    Registers the client in the room named by the connection path, replays the
    room history to it, announces the join, then relays every JSON object the
    client sends to all clients in the room. The client is unregistered and a
    leave notice is broadcast when the connection ends.

    Args:
        websocket: The connection object supplied by ``websockets``.
        path: Request path, when the server passes it as a second argument.
            It defaults to ``None`` so the handler also works when it is
            called with the connection only.
    """
    room_id = _room_id_from_path(websocket, path)
    client_id = str(uuid.uuid4())

    # Register the new client. setdefault keeps any history that already
    # exists for a room that was emptied earlier or created from the UI.
    active_connections.setdefault(room_id, {})[client_id] = websocket
    chat_history.setdefault(room_id, [])

    try:
        # Send welcome message and chat history
        welcome_msg = _system_message(
            f"Welcome to room '{room_id}'! Connected from node '{NODE_NAME}'",
            room_id,
        )
        await websocket.send(json.dumps(welcome_msg))

        # Iterate over a snapshot: the history can change while we await.
        for msg in list(chat_history[room_id]):
            await websocket.send(json.dumps(msg))

        # Broadcast join notification
        await broadcast_message(_system_message("User joined the room", room_id), room_id)
        logger.info("New client %s connected to room %s", client_id, room_id)

        # Handle messages from this client
        async for message in websocket:
            try:
                data = json.loads(message)
            except (json.JSONDecodeError, UnicodeDecodeError):
                error_msg = _system_message("Invalid JSON format", room_id, "error")
                await websocket.send(json.dumps(error_msg))
                continue

            if not isinstance(data, dict):
                # Metadata is attached by key, so only JSON objects are valid.
                error_msg = _system_message("Message must be a JSON object", room_id, "error")
                await websocket.send(json.dumps(error_msg))
                continue

            await _record_and_broadcast(data, room_id)
    except websockets.exceptions.ConnectionClosed:
        logger.info("Client %s disconnected from room %s", client_id, room_id)
    finally:
        # Remove the client when disconnected. It may already have been
        # dropped by broadcast_message(), hence the tolerant pop().
        room = active_connections.get(room_id)
        if room is not None and room.pop(client_id, None) is not None:
            # Broadcast leave notification
            await broadcast_message(_system_message("User left the room", room_id), room_id)

        # Clean up empty rooms; the chat history is deliberately kept.
        if room_id in active_connections and not active_connections[room_id]:
            del active_connections[room_id]


async def broadcast_message(message, room_id):
    """Send ``message`` (JSON-serialised) to every client in ``room_id``.

    Clients whose connection turns out to be closed are removed from the room.
    Unknown or empty rooms are ignored.
    """
    room = active_connections.get(room_id)
    if not room:
        return

    payload = json.dumps(message)
    disconnected_clients = []

    # Snapshot the members: clients may join or leave while we await send().
    for client_id, websocket in list(room.items()):
        try:
            await websocket.send(payload)
        except websockets.exceptions.ConnectionClosed:
            disconnected_clients.append(client_id)

    # Clean up disconnected clients
    for client_id in disconnected_clients:
        room.pop(client_id, None)


async def start_websocket_server(host='0.0.0.0', port=8765):
    """Start the WebSocket server and return the server object."""
    server = await websockets.serve(websocket_handler, host, port)
    logger.info("WebSocket server started on ws://%s:%s", host, port)
    return server


def _log_task_failure(future):
    """Done-callback that logs an exception raised by a background broadcast."""
    if future.cancelled():
        return
    exc = future.exception()
    if exc is not None:
        logger.error("Background broadcast failed: %r", exc)


def _submit_to_loop(msg_data, room_id):
    """Schedule storing and broadcasting ``msg_data`` on the server's event loop.

    Works both when called from inside a running loop and when called from a
    thread that has no loop, which is the case the ``MAIN_LOOP`` hand-off is
    written for.
    """
    try:
        running_loop = asyncio.get_running_loop()
    except RuntimeError:
        running_loop = None

    if running_loop is not None:
        task = running_loop.create_task(_record_and_broadcast(msg_data, room_id))
        _background_tasks.add(task)
        task.add_done_callback(_background_tasks.discard)
        task.add_done_callback(_log_task_failure)
    elif MAIN_LOOP is not None and MAIN_LOOP.is_running():
        future = asyncio.run_coroutine_threadsafe(
            _record_and_broadcast(msg_data, room_id), MAIN_LOOP
        )
        future.add_done_callback(_log_task_failure)
    else:
        logger.warning("No running event loop; message for room %s was not broadcast", room_id)


def send_message(message, username, room_id):
    """Send a chat message from the Gradio interface.

    Returns the ``"username: message"`` line to display, or ``None`` when the
    message is empty or whitespace only (nothing is sent in that case).
    """
    if not message or not message.strip():
        return None

    msg_data = {
        "type": "chat",
        "content": message,
        "username": username,
        "room_id": room_id
    }
    _submit_to_loop(msg_data, room_id)

    # Format the message for display
    return f"{username}: {message}"


def _format_for_display(msg):
    """Return the display line for a stored message, or ``None`` to skip it."""
    msg_type = msg.get("type")
    if msg_type == "chat":
        return f"{msg.get('username', 'Anonymous')}: {msg.get('content', '')}"
    if msg_type == "system":
        return f"System: {msg.get('content', '')}"
    return None


def join_room(room_id, chat_history_output):
    """Join a specific chat room from the Gradio interface.

    Args:
        room_id: Room name typed by the user.
        chat_history_output: Current text of the history box; returned
            unchanged when ``room_id`` is blank.

    Returns:
        ``(status_text, history_text)``. On success ``status_text`` is
        ``"Joined room: <id>"`` and ``history_text`` is the stored history of
        that room, one message per line.
    """
    if not room_id or not room_id.strip():
        return "Please enter a valid room ID", chat_history_output

    # Sanitize the room ID. safe="" also escapes "/", because the WebSocket
    # side uses only the first path segment as the room id.
    room_id = urllib.parse.quote(room_id.strip(), safe="")

    # Create the room if it doesn't exist, then format existing messages.
    stored = chat_history.setdefault(room_id, [])
    lines = [line for line in map(_format_for_display, list(stored)) if line is not None]

    # The history box is a Textbox, so hand it text rather than a list.
    return f"{JOINED_PREFIX}{room_id}", "\n".join(lines)


def create_gradio_interface():
    """Create and return the Gradio interface."""
    with gr.Blocks(title=f"Chat Node: {NODE_NAME}") as interface:
        gr.Markdown(f"# Chat Node: {NODE_NAME}")
        gr.Markdown("Join a room by entering a room ID below or create a new one.")

        with gr.Row():
            room_id_input = gr.Textbox(label="Room ID", placeholder="Enter room ID")
            join_button = gr.Button("Join Room")

        chat_history_output = gr.Textbox(label="Chat History", lines=15, max_lines=15)

        with gr.Row():
            username_input = gr.Textbox(label="Username", placeholder="Enter your username", value="User")
            message_input = gr.Textbox(label="Message", placeholder="Type your message here")
            send_button = gr.Button("Send")

        # Current room display
        current_room_display = gr.Textbox(label="Current Room", value="Not joined any room yet")

        # Event handlers
        join_button.click(
            join_room,
            inputs=[room_id_input, chat_history_output],
            outputs=[current_room_display, chat_history_output]
        )

        def append_line(history, line):
            """Return ``history`` with ``line`` added as a new last line."""
            return f"{history}\n{line}" if history else line

        def send_and_clear(message, username, room_id, history):
            """Send ``message`` and return ``(new_input_text, new_history_text)``."""
            history = history or ""
            if not room_id.startswith(JOINED_PREFIX):
                # Keep the draft in the input box and show the notice in the
                # history box instead of overwriting either one.
                return message, append_line(history, "System: Please join a room first")

            actual_room_id = room_id[len(JOINED_PREFIX):].strip()
            formatted_msg = send_message(message, username, actual_room_id)
            if formatted_msg is None:
                # Nothing was sent (blank message): leave both boxes untouched.
                return message, history
            return "", append_line(history, formatted_msg)

        send_inputs = [message_input, username_input, current_room_display, chat_history_output]
        send_outputs = [message_input, chat_history_output]

        send_button.click(send_and_clear, inputs=send_inputs, outputs=send_outputs)

        # Enter key to send message
        message_input.submit(send_and_clear, inputs=send_inputs, outputs=send_outputs)

    return interface


async def main():
    """Main function to start the application.

    Starts the WebSocket server, serves the Gradio UI with uvicorn until it
    shuts down, then closes the WebSocket server so the process can exit.
    """
    global NODE_NAME, MAIN_LOOP
    NODE_NAME, port = get_node_name()
    MAIN_LOOP = asyncio.get_running_loop()

    # Start WebSocket server
    ws_server = await start_websocket_server()

    try:
        # Create and launch Gradio interface
        interface = create_gradio_interface()

        # Custom middleware to extract node name from URL query parameters
        from starlette.middleware.base import BaseHTTPMiddleware

        class NodeNameMiddleware(BaseHTTPMiddleware):
            """Update NODE_NAME from a ``node_name`` query parameter, if present."""

            async def dispatch(self, request, call_next):
                global NODE_NAME
                query_params = dict(request.query_params)
                if "node_name" in query_params:
                    NODE_NAME = query_params["node_name"]
                    logger.info("Node name set to %s from URL parameter", NODE_NAME)
                return await call_next(request)

        # Apply middleware
        # NOTE(review): create_app() followed by mount_gradio_app() on the same
        # app is kept exactly as before; confirm it against the installed
        # Gradio version.
        app = gr.routes.App.create_app(interface)
        app.add_middleware(NodeNameMiddleware)

        # Launch with the modified app
        gr.routes.mount_gradio_app(app, interface, path="/")

        # Run the FastAPI app with uvicorn
        import uvicorn
        config = uvicorn.Config(app, host="0.0.0.0", port=port)
        http_server = uvicorn.Server(config)

        logger.info(
            "Starting Gradio interface on http://0.0.0.0:%s with node name '%s'",
            port,
            NODE_NAME,
        )
        # serve() returns only when uvicorn shuts down; the WebSocket server
        # keeps running in this same loop for that whole time.
        await http_server.serve()
    finally:
        # Stop accepting WebSocket connections and wait for the server to close
        # instead of blocking forever after the HTTP server has exited.
        ws_server.close()
        await ws_server.wait_closed()


if __name__ == "__main__":
    asyncio.run(main())