|
import gradio as gr |
|
import asyncio |
|
import websockets |
|
import json |
|
import uuid |
|
import argparse |
|
import urllib.parse |
|
from datetime import datetime |
|
import logging |
|
import sys |
|
|
|
|
|
# Configure root logging: INFO level, timestamped format, written to stdout.
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[logging.StreamHandler(stream=sys.stdout)],
    level=logging.INFO,
)

# Logger shared by everything in this module.
logger = logging.getLogger(name="chat-node")
|
|
|
|
|
# Live WebSocket clients: room id -> {client id (uuid4 string) -> websocket}.
active_connections = dict()

# Stored messages per room: room id -> list of message dicts.
chat_history = dict()
|
|
|
|
|
def get_node_name(argv=None):
    """Parse this chat node's command-line options.

    Args:
        argv: Argument list to parse, without the program name. ``None``
            (the default) parses ``sys.argv[1:]``, exactly as before.

    Returns:
        A ``(node_name, port)`` tuple.
        - ``node_name`` is ``--node-name``. When that option is missing or
          empty, it is ``"node-"`` followed by 8 random hex characters.
        - ``port`` is ``--port`` (default 7860), the port for the Gradio
          interface.
    """
    parser = argparse.ArgumentParser(description='Start a chat node with a specific name')
    parser.add_argument('--node-name', type=str, default=None, help='Name for this chat node')
    parser.add_argument('--port', type=int, default=7860, help='Port to run the Gradio interface on')

    args = parser.parse_args(argv)

    # An empty --node-name is treated like a missing one.
    node_name = args.node_name or f"node-{uuid.uuid4().hex[:8]}"
    return node_name, args.port
|
|
|
def _ws_request_path(websocket):
    """Return the request path of a ``websockets`` connection.

    Newer ``websockets`` connections expose it as ``websocket.request.path``,
    older ones as ``websocket.path``. Falls back to ``"/"`` if neither is set.
    """
    request = getattr(websocket, "request", None)
    request_path = getattr(request, "path", None)
    if request_path is None:
        request_path = getattr(websocket, "path", None)
    return request_path or "/"


def _system_message(content, room_id, msg_type="system"):
    """Build a server-originated message dict addressed to *room_id*."""
    return {
        "type": msg_type,
        "content": content,
        "timestamp": datetime.now().isoformat(),
        "sender": "system",
        "room_id": room_id,
    }


async def websocket_handler(websocket, path=None):
    """Handle WebSocket connections.

    Serves one client for the lifetime of its connection.

    - The first path segment, ignoring any query string, selects the room.
      An empty segment selects ``"default"``.
    - On connect, the client gets a welcome message plus the room's stored
      history, and the room is told that a user joined.
    - Each JSON object the client sends is stamped with ``timestamp``,
      ``sender_node`` and ``room_id``. It is then stored in the room history
      (the last 100 entries are kept) and broadcast to the room.
    - Malformed input gets an ``"error"`` reply and the connection stays open.
    - On disconnect the client is removed and the room is told that a user
      left. A room with no clients is dropped from ``active_connections``;
      its history is kept.

    Args:
        websocket: Connection object supplied by ``websockets.serve``.
        path: Request path. Older ``websockets`` releases pass it positionally.
            Newer releases call the handler with the connection only; in that
            case the path is read from the connection.
    """
    if path is None:
        path = _ws_request_path(websocket)
    # Drop any query string, then take the first segment. "".split("/") yields
    # [""], so the root path needs the explicit "or" fallback.
    room_id = path.split("?", 1)[0].strip("/").split("/")[0] or "default"
    client_id = str(uuid.uuid4())

    # setdefault keeps existing history. A room's history outlives its last
    # connection, so it must not be reset when someone reconnects.
    chat_history.setdefault(room_id, [])
    active_connections.setdefault(room_id, {})[client_id] = websocket

    try:
        welcome = _system_message(
            f"Welcome to room '{room_id}'! Connected from node '{NODE_NAME}'", room_id
        )
        await websocket.send(json.dumps(welcome))

        # Snapshot first. The list can grow while we await send(), and
        # replaying those live messages would deliver them twice.
        for msg in list(chat_history.get(room_id, [])):
            await websocket.send(json.dumps(msg))

        await broadcast_message(_system_message("User joined the room", room_id), room_id)
        logger.info(f"New client {client_id} connected to room {room_id}")

        async for message in websocket:
            try:
                data = json.loads(message)
            except (json.JSONDecodeError, UnicodeDecodeError):
                error = _system_message("Invalid JSON format", room_id, msg_type="error")
                await websocket.send(json.dumps(error))
                continue
            if not isinstance(data, dict):
                # Scalars or arrays cannot carry the server metadata below.
                error = _system_message("Message must be a JSON object", room_id, msg_type="error")
                await websocket.send(json.dumps(error))
                continue

            data["timestamp"] = datetime.now().isoformat()
            data["sender_node"] = NODE_NAME
            data["room_id"] = room_id

            history = chat_history.setdefault(room_id, [])
            history.append(data)
            del history[:-100]  # keep only the most recent 100 messages

            await broadcast_message(data, room_id)

    except websockets.exceptions.ConnectionClosed:
        logger.info(f"Client {client_id} disconnected from room {room_id}")
    finally:
        room = active_connections.get(room_id)
        if room is not None:
            # broadcast_message may already have evicted this client after a
            # failed send, so use pop with a default instead of del.
            room.pop(client_id, None)
            await broadcast_message(_system_message("User left the room", room_id), room_id)
            # Re-check after the await: another client may have joined meanwhile.
            if not active_connections.get(room_id):
                active_connections.pop(room_id, None)
|
|
|
|
|
async def broadcast_message(message, room_id):
    """Broadcast a message to all clients in a room.

    *message* is serialized to JSON once and sent to every client registered
    in ``active_connections[room_id]``. Clients whose connection turns out to
    be closed are removed from the room. Nothing happens for an unknown or
    empty room.

    Args:
        message: JSON-serializable message (a dict in every caller here).
        room_id: Key of the room in ``active_connections``.
    """
    room = active_connections.get(room_id)
    if not room:
        return

    payload = json.dumps(message)
    disconnected_clients = []

    # Iterate over a snapshot. Other handlers can add or remove clients while
    # we are suspended in send(); iterating the live dict would then raise
    # "dictionary changed size during iteration".
    for client_id, websocket in list(room.items()):
        try:
            await websocket.send(payload)
        except websockets.exceptions.ConnectionClosed:
            disconnected_clients.append(client_id)

    for client_id in disconnected_clients:
        # The client's own handler may already have removed it.
        room.pop(client_id, None)
|
|
|
async def start_websocket_server(host='0.0.0.0', port=8765):
    """Start the WebSocket server.

    Binds ``websocket_handler`` to *host*:*port* through ``websockets.serve``,
    logs the resulting ``ws://`` URL, and returns the server object that
    ``websockets.serve`` produced.
    """
    endpoint = f"ws://{host}:{port}"
    ws_server = await websockets.serve(websocket_handler, host, port)
    logger.info(f"WebSocket server started on {endpoint}")
    return ws_server
|
|
|
|
|
# Event loop that main() runs on; remains None until main() assigns it.
main_event_loop = None

# Chat messages submitted through the Gradio UI, waiting for the async queue
# processor to broadcast them.
message_queue = list()
|
|
|
def send_message(message, username, room_id):
    """Function to send a message from the Gradio interface.

    Appends a ``"chat"`` message dict to ``message_queue`` for later broadcast.

    Args:
        message: Text entered by the user.
        username: Display name attached to the message.
        room_id: Room the message is addressed to.

    Returns:
        ``"<username>: <message>"`` for local display. Returns ``None``, and
        queues nothing, when *message* is empty or whitespace-only.
    """
    if message.strip():
        # Key order matters: it is preserved when the dict is serialized.
        message_queue.append(
            {"type": "chat", "content": message, "username": username, "room_id": room_id}
        )
        return f"{username}: {message}"
    return None
|
|
|
def join_room(room_id, chat_history_output):
    """Join a specific chat room.

    Normalizes the room id typed in the UI, makes sure the room has a history
    entry, and renders that history for display.

    Args:
        room_id: Room name typed by the user; surrounding whitespace is ignored.
        chat_history_output: Current contents of the chat-history box. It is
            returned unchanged when *room_id* is blank.

    Returns:
        A ``(status, history)`` pair.
        - On success, ``status`` is ``"Joined room: <id>"``, where ``<id>`` is
          the percent-encoded room name. ``history`` is a list of display
          lines: ``"<username>: <content>"`` for chat messages (username
          defaults to ``"Anonymous"``) and ``"System: <content>"`` for system
          messages. Other message types are skipped.
        - For a blank *room_id*, the pair is
          ``("Please enter a valid room ID", chat_history_output)``.
    """
    cleaned = room_id.strip()
    if not cleaned:
        return "Please enter a valid room ID", chat_history_output

    # Encode every reserved character, including "/". WebSocket clients are
    # keyed by the first path segment only, so an unencoded "/" would put UI
    # users and WebSocket users in different rooms.
    room_id = urllib.parse.quote(cleaned, safe="")

    formatted_history = []
    for msg in chat_history.setdefault(room_id, []):
        msg_type = msg.get("type")
        if msg_type == "chat":
            formatted_history.append(f"{msg.get('username', 'Anonymous')}: {msg.get('content', '')}")
        elif msg_type == "system":
            formatted_history.append(f"System: {msg.get('content', '')}")

    return f"Joined room: {room_id}", formatted_history
|
|
|
def create_gradio_interface():
    """Create and return the Gradio interface.

    The page lets a user pick a room (``join_room``), view that room's stored
    history, and send messages (``send_message``). Sent messages are appended
    to the local history box. Nothing in this UI refreshes the history box on
    its own, so messages from other clients appear only after re-joining.

    Returns:
        The configured ``gr.Blocks`` instance. It is not launched here.
    """
    with gr.Blocks(title=f"Chat Node: {NODE_NAME}") as interface:
        gr.Markdown(f"# Chat Node: {NODE_NAME}")
        gr.Markdown("Join a room by entering a room ID below or create a new one.")

        with gr.Row():
            room_id_input = gr.Textbox(label="Room ID", placeholder="Enter room ID")
            join_button = gr.Button("Join Room")

        chat_history_output = gr.Textbox(label="Chat History", lines=15, max_lines=15)

        with gr.Row():
            username_input = gr.Textbox(label="Username", placeholder="Enter your username", value="User")
            message_input = gr.Textbox(label="Message", placeholder="Type your message here")
            send_button = gr.Button("Send")

        current_room_display = gr.Textbox(label="Current Room", value="Not joined any room yet")

        def join_and_render(room_id, history):
            # join_room returns a list of lines on success. A Textbox would
            # show that as a Python list repr, so join it into text first.
            status, lines = join_room(room_id, history)
            if isinstance(lines, list):
                lines = "\n".join(lines)
            return status, lines

        join_button.click(
            join_and_render,
            inputs=[room_id_input, chat_history_output],
            outputs=[current_room_display, chat_history_output],
        )

        def append_line(history, line):
            # Add one line to the history text instead of replacing it.
            return f"{history}\n{line}" if history else line

        def send_and_clear(message, username, room_status, history):
            # Returns (new message box value, new history box value).
            if not room_status.startswith("Joined room:"):
                # Keep the user's draft and show the error in the chat area.
                return message, append_line(history, "Please join a room first")

            actual_room_id = room_status.replace("Joined room: ", "").strip()
            formatted_msg = send_message(message, username, actual_room_id)

            if formatted_msg:
                return "", append_line(history, formatted_msg)
            # Blank message: nothing was sent, leave both boxes as they were.
            return message, history

        send_inputs = [message_input, username_input, current_room_display, chat_history_output]
        send_outputs = [message_input, chat_history_output]

        send_button.click(send_and_clear, inputs=send_inputs, outputs=send_outputs)
        message_input.submit(send_and_clear, inputs=send_inputs, outputs=send_outputs)

    return interface
|
|
|
async def process_message_queue():
    """Process messages in the queue and broadcast them.

    Runs forever on the event loop. Every 0.1 s it drains all pending entries
    of ``message_queue`` (filled by the Gradio UI), not just one. Each message
    is handled the same way the WebSocket handler treats client messages:

    - it is stamped with ``timestamp``, ``sender_node`` and ``room_id``;
    - it is recorded in the room's ``chat_history`` (the last 100 entries are
      kept);
    - it is broadcast to the room.

    A failure while broadcasting one message is logged, and the loop keeps
    running.
    """
    while True:
        # Only this coroutine pops, so checking then popping cannot race into
        # an IndexError, even while the UI thread appends concurrently.
        while message_queue:
            msg_data = message_queue.pop(0)
            room_id = msg_data["room_id"]

            msg_data["timestamp"] = datetime.now().isoformat()
            msg_data["sender_node"] = NODE_NAME

            # Mutating history here, on the event loop, keeps it consistent
            # with messages that arrive over WebSocket.
            history = chat_history.setdefault(room_id, [])
            history.append(msg_data)
            del history[:-100]  # keep only the most recent 100 messages

            try:
                await broadcast_message(msg_data, room_id)
            except Exception:
                # Don't let one bad send kill the background task silently.
                logger.exception("Failed to broadcast queued message to room %s", room_id)

        await asyncio.sleep(0.1)
|
|
|
async def main():
    """Main function to start the application.

    Steps:
    1. Parse the CLI options and set the module-level ``NODE_NAME`` and
       ``main_event_loop``.
    2. Start the WebSocket server (``start_websocket_server`` defaults) and
       the background task that forwards Gradio-queued messages.
    3. Serve the Gradio UI with uvicorn on 0.0.0.0 at the ``--port`` option.

    When the HTTP server stops, the queue task is cancelled and the WebSocket
    server is closed.
    """
    global NODE_NAME, main_event_loop

    NODE_NAME, port = get_node_name()
    main_event_loop = asyncio.get_running_loop()

    ws_server = await start_websocket_server()

    # Keep a reference to the task. The event loop holds only weak references
    # to tasks, so an unreferenced task may be garbage-collected mid-run.
    queue_task = asyncio.create_task(process_message_queue())
    logger.info("Starting message queue processor")

    interface = create_gradio_interface()

    from starlette.middleware.base import BaseHTTPMiddleware

    class NodeNameMiddleware(BaseHTTPMiddleware):
        """Rename this node from a ``node_name`` query parameter on any HTTP request.

        NOTE(review): any client that can reach the HTTP port can change the
        process-wide NODE_NAME this way. The UI title and heading were already
        rendered from the old name and will not change. Confirm this is
        intended.
        """

        async def dispatch(self, request, call_next):
            global NODE_NAME
            query_params = dict(request.query_params)
            if "node_name" in query_params:
                NODE_NAME = query_params["node_name"]
                logger.info(f"Node name set to {NODE_NAME} from URL parameter")
            return await call_next(request)

    app = gr.routes.App.create_app(interface)
    app.add_middleware(NodeNameMiddleware)

    # NOTE(review): create_app already serves the interface. This mount is kept
    # as-is; it may be what registers Gradio's startup hooks. Confirm before
    # removing.
    gr.routes.mount_gradio_app(app, interface, path="/")

    import uvicorn

    config = uvicorn.Config(app, host="0.0.0.0", port=port)
    http_server = uvicorn.Server(config)

    logger.info(f"Starting Gradio interface on http://0.0.0.0:{port} with node name '{NODE_NAME}'")

    try:
        await http_server.serve()
    finally:
        # Stop the background services started above.
        queue_task.cancel()
        try:
            await queue_task
        except asyncio.CancelledError:
            pass  # expected: the task was just cancelled
        ws_server.close()
        await ws_server.wait_closed()
|
|
|
if __name__ == "__main__":
    # Script entry point: run the async main() on a fresh event loop until it returns.
    app_coroutine = main()
    asyncio.run(app_coroutine)