|
import gradio as gr |
|
import asyncio |
|
import websockets |
|
import json |
|
import uuid |
|
import argparse |
|
import urllib.parse |
|
from datetime import datetime |
|
import logging |
|
import sys |
|
import os |
|
import time |
|
from pathlib import Path |
|
|
|
|
|
# Route every log record at INFO level or above to standard output using a
# single "time - logger - level - message" line format.
logging.basicConfig(
    stream=sys.stdout,
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
)

# Logger shared by the whole chat node.
logger = logging.getLogger("chat-node")
|
|
|
|
|
# Live websocket connections per room: room_id -> {client_id: websocket}.
active_connections = {}

# In-memory message history per room: room_id -> list of message dicts
# (rebuilt from the JSONL files by load_room_history()).
chat_history = {}

# Last observed mtime per history-file path. check_for_new_messages() compares
# against it to spot files changed outside save_message_to_history().
file_modification_times = {}

# Connected client ids per room, used for the sector map and user counts:
# room_id -> set of client_id.
sector_users = {}

# Event loop captured by main(); not read anywhere else in the code shown.
main_event_loop = None

# Outgoing entries queued by the Gradio callbacks (send_message,
# send_clear_command) and drained by process_message_queue().
message_queue = []

# Size of the sector grid that room ids are mapped onto.
GRID_WIDTH = 10

GRID_HEIGHT = 10
|
|
|
|
|
# Directory (relative to the working directory) holding the per-room JSONL
# history files, plus a README describing it.
HISTORY_DIR = "chat_history"
README_PATH = os.path.join(HISTORY_DIR, "README.md")

# Created at import time so every reader/writer can assume it exists.
os.makedirs(HISTORY_DIR, exist_ok=True)

# Only write the README once; an existing one is left untouched.
if not os.path.exists(README_PATH):
    Path(README_PATH).write_text(
        "# Chat History\n\nThis directory contains persistent chat history files.\n"
    )
|
|
|
|
|
def get_node_name():
    """Parse the command line and return ``(node_name, port)``.

    ``--node-name`` falls back to a random ``node-<8 hex chars>`` identifier
    when it is missing or empty; ``--port`` defaults to 7860.
    """
    cli = argparse.ArgumentParser(description='Start a chat node with a specific name')
    cli.add_argument('--node-name', type=str, default=None, help='Name for this chat node')
    cli.add_argument('--port', type=int, default=7860, help='Port to run the Gradio interface on')
    options = cli.parse_args()

    name = options.node_name or f"node-{uuid.uuid4().hex[:8]}"
    return name, options.port
|
|
|
def get_room_history_file(room_id):
    """Return the path for a new history file of *room_id*.

    The name is ``<room_id>_<YYYYmmdd>_<HHMMSS>.jsonl`` inside HISTORY_DIR,
    stamped with the current local time; the file itself is not created here.
    """
    stamp = f"{datetime.now():%Y%m%d_%H%M%S}"
    return os.path.join(HISTORY_DIR, "{}_{}.jsonl".format(room_id, stamp))
|
|
|
def get_all_room_history_files(room_id, history_dir=None):
    """Return every history file that belongs to *room_id*, newest (by mtime) first.

    History files are named ``<room_id>_<YYYYmmdd>_<HHMMSS>.jsonl`` (see
    get_room_history_file). The owning room is recovered by stripping the two
    trailing timestamp fields, so room ``"a"`` does not also pick up the files
    of room ``"a_b"``. Names without that timestamp suffix fall back to the
    text before the first ``_``.

    Args:
        room_id: Room identifier as it appears in the file names.
        history_dir: Directory to scan; defaults to HISTORY_DIR.

    Returns:
        List of file paths, sorted by modification time, newest first.
    """
    directory = HISTORY_DIR if history_dir is None else history_dir
    suffix = ".jsonl"

    files = []
    for name in os.listdir(directory):
        if not name.endswith(suffix):
            continue
        stem = name[:-len(suffix)]
        parts = stem.rsplit("_", 2)
        if len(parts) == 3 and parts[1].isdigit() and parts[2].isdigit():
            owner = parts[0]
        else:
            owner = stem.split("_", 1)[0]
        if owner == room_id:
            files.append(os.path.join(directory, name))

    files.sort(key=os.path.getmtime, reverse=True)
    return files
|
|
|
def get_all_history_files(history_dir=None):
    """List every room that has history, with its most recently modified file.

    The room id is parsed from ``<room_id>_<YYYYmmdd>_<HHMMSS>.jsonl`` by
    stripping the two trailing timestamp fields, so ids containing ``_`` are
    kept intact. Names without that suffix fall back to the text before the
    first ``_``.

    Args:
        history_dir: Directory to scan; defaults to HISTORY_DIR.

    Returns:
        List of ``(room_id, newest_file_path, mtime)`` tuples, newest first.
        On any error the problem is logged and ``[]`` is returned.
    """
    directory = HISTORY_DIR if history_dir is None else history_dir
    suffix = ".jsonl"
    try:
        newest = {}
        for name in os.listdir(directory):
            if not name.endswith(suffix):
                continue
            file_path = os.path.join(directory, name)
            mod_time = os.path.getmtime(file_path)

            stem = name[:-len(suffix)]
            parts = stem.rsplit("_", 2)
            if len(parts) == 3 and parts[1].isdigit() and parts[2].isdigit():
                room_id = parts[0]
            else:
                room_id = stem.split("_", 1)[0]

            if room_id not in newest or mod_time > newest[room_id][1]:
                newest[room_id] = (file_path, mod_time)

        history_files = [(room_id, path, mtime) for room_id, (path, mtime) in newest.items()]
        history_files.sort(key=lambda entry: entry[2], reverse=True)
        return history_files
    except Exception as e:
        logger.error(f"Error in get_all_history_files: {e}")
        return []
|
|
|
def load_room_history(room_id):
    """Rebuild ``chat_history[room_id]`` from every history file of the room.

    Every file returned by get_all_room_history_files() is read line by line.
    Lines that are not valid JSON, or are valid JSON but not an object, are
    logged and skipped. A non-object would break the ``.get("timestamp")``
    sort below and abort the whole load. The combined messages are sorted by
    their ``timestamp`` string (messages without one sort first) and replace
    the room's in-memory list.

    Side effects: records the mtime of each newly seen file in
    ``file_modification_times`` and ensures ``sector_users`` has an entry for
    the room.

    Returns:
        The new list stored in ``chat_history[room_id]``.
    """
    chat_history.setdefault(room_id, [])

    history_files = get_all_room_history_files(room_id)

    # Remember mtimes so check_for_new_messages() only reacts to later changes.
    for path in history_files:
        if path not in file_modification_times:
            try:
                file_modification_times[path] = os.path.getmtime(path)
            except OSError:
                # The file vanished after being listed; the read below logs it.
                pass

    messages = []
    for history_file in history_files:
        try:
            with open(history_file, 'r', encoding='utf-8') as f:
                for line_no, line in enumerate(f, start=1):
                    line = line.strip()
                    if not line:
                        continue
                    try:
                        data = json.loads(line)
                    except json.JSONDecodeError:
                        logger.error(f"Error parsing JSON line in {history_file}")
                        continue
                    if not isinstance(data, dict):
                        logger.error(f"Skipping non-object JSON on line {line_no} of {history_file}")
                        continue
                    messages.append(data)
        except Exception as e:
            logger.error(f"Error loading history from {history_file}: {e}")

    # ISO-8601 strings sort chronologically; str() guards against odd foreign values.
    messages.sort(key=lambda msg: str(msg.get("timestamp", "")))
    chat_history[room_id] = messages

    logger.info(f"Loaded {len(messages)} messages from {len(history_files)} files for room {room_id}")

    sector_users.setdefault(room_id, set())

    return chat_history[room_id]
|
|
|
def save_message_to_history(room_id, message, max_file_size=1024 * 1024):
    """Append *message* as one JSON line to the room's current history file.

    The room's most recently modified history file is reused until it grows
    beyond *max_file_size* bytes (1 MiB by default). When it is too large, has
    disappeared, or the room has no file yet, a new timestamped file is
    started via get_room_history_file(). After writing, the file's mtime is
    stored in ``file_modification_times`` so check_for_new_messages() does not
    treat this node's own write as an external change.

    Write errors are logged, not raised.
    """
    history_files = get_all_room_history_files(room_id)
    history_file = history_files[0] if history_files else None

    if history_file is not None:
        try:
            too_big = os.path.getsize(history_file) > max_file_size
        except OSError:
            # The newest file was removed after listing (e.g. history was cleared).
            too_big = True
        if too_big:
            history_file = None

    if history_file is None:
        history_file = get_room_history_file(room_id)

    try:
        with open(history_file, 'a', encoding='utf-8') as f:
            f.write(json.dumps(message) + '\n')
        file_modification_times[history_file] = os.path.getmtime(history_file)
        logger.debug(f"Saved message to {history_file}")
    except Exception as e:
        logger.error(f"Error saving message to {history_file}: {e}")
|
|
|
def check_for_new_messages():
    """Detect history files that are new or modified and reload their rooms.

    A file counts as changed when its mtime is not yet recorded in
    ``file_modification_times`` or is newer than the recorded value. The room
    id is parsed from ``<room_id>_<YYYYmmdd>_<HHMMSS>.jsonl`` by stripping the
    two trailing timestamp fields, so ids containing ``_`` map to the right
    room. Files that disappear between listing and stat are skipped. Only
    rooms already present in ``chat_history`` are reloaded.

    Returns:
        Set of room ids that had new or modified history files.
    """
    suffix = ".jsonl"
    updated_rooms = set()

    for name in os.listdir(HISTORY_DIR):
        if not name.endswith(suffix):
            continue
        file_path = os.path.join(HISTORY_DIR, name)
        try:
            current_mtime = os.path.getmtime(file_path)
        except OSError:
            continue

        last_seen = file_modification_times.get(file_path)
        if last_seen is not None and current_mtime <= last_seen:
            continue

        stem = name[:-len(suffix)]
        parts = stem.rsplit("_", 2)
        if len(parts) == 3 and parts[1].isdigit() and parts[2].isdigit():
            updated_rooms.add(parts[0])
        else:
            updated_rooms.add(stem.split("_", 1)[0])

        file_modification_times[file_path] = current_mtime

    for room_id in updated_rooms:
        if room_id not in chat_history:
            continue
        old_history_len = len(chat_history[room_id])
        # load_room_history() replaces chat_history[room_id] with a fresh list.
        new_history_len = len(load_room_history(room_id))
        if new_history_len > old_history_len:
            logger.info(f"Found {new_history_len - old_history_len} new messages for room {room_id}")

    return updated_rooms
|
|
|
def get_sector_coordinates(room_id, width=None, height=None):
    """Map *room_id* onto a cell of the sector grid.

    Ids of the form ``"x,y"`` are taken as coordinates and clamped into the
    grid. This also covers the URL-quoted form ``"x%2Cy"`` that join_room()
    produces with ``urllib.parse.quote``. Any other id is placed using a
    CRC-32 of its UTF-8 bytes. Unlike the built-in ``hash()``, which is
    randomized per process for strings, CRC-32 gives the same cell on every
    node and across restarts.

    Args:
        room_id: Room identifier.
        width: Grid width; defaults to GRID_WIDTH.
        height: Grid height; defaults to GRID_HEIGHT.

    Returns:
        ``(x, y)`` with ``0 <= x < width`` and ``0 <= y < height``.
    """
    import zlib

    if width is None:
        width = GRID_WIDTH
    if height is None:
        height = GRID_HEIGHT

    decoded = urllib.parse.unquote(room_id)
    if ',' in decoded:
        try:
            x, y = map(int, decoded.split(','))
        except ValueError:
            # Not exactly two integers: treat it as an ordinary room name.
            pass
        else:
            return max(0, min(x, width - 1)), max(0, min(y, height - 1))

    digest = zlib.crc32(room_id.encode('utf-8'))
    return digest % width, (digest >> 8) % height
|
|
|
def generate_sector_map():
    """Render the sector grid as a Markdown code block with user counts.

    Every room in ``sector_users`` with at least one connected client adds its
    client count to the cell returned by get_sector_coordinates(). Counts are
    summed, so rooms that map to the same cell are not overwritten. Occupied
    cells show the total (``1``-``9``) or ``+`` for ten or more; empty cells
    are blank. A column ruler is printed above and below the grid, and each
    row is prefixed with its index.

    Returns:
        Markdown text: a fenced code block followed by a legend line.
    """
    counts = {}
    for room_id, users in sector_users.items():
        if users:
            cell = get_sector_coordinates(room_id)
            counts[cell] = counts.get(cell, 0) + len(users)

    grid = [[' '] * GRID_WIDTH for _ in range(GRID_HEIGHT)]
    for (x, y), count in counts.items():
        grid[y][x] = str(count) if count < 10 else '+'

    # Row labels are two characters wide ("<digit>|"), so the column ruler
    # needs two leading spaces for its digits to sit above the cells.
    ruler = '  ' + ''.join(str(i % 10) for i in range(GRID_WIDTH))
    rows = [f"{y % 10}|" + ''.join(grid[y]) + '|' for y in range(GRID_HEIGHT)]
    map_str = '\n'.join([ruler, *rows, ruler])

    return f"```\n{map_str}\n```\n\nLegend: Number indicates users in sector. '+' means 10+ users."
|
|
|
def list_available_rooms():
    """Return a Markdown overview of all rooms with history, followed by the sector map.

    Each table row shows the room id, its sector, the number of connected
    clients and the last modification time of its newest history file.
    Errors are logged and reported as text instead of being raised.
    """
    try:
        rooms = get_all_history_files()
        if not rooms:
            return "No chat rooms available yet. Create one by joining a room!"

        pieces = [
            "### Available Chat Rooms (Sectors)\n\n",
            "| Room ID | Sector | Users | Last Activity |\n",
            "|---------|--------|-------|---------------|\n",
        ]
        for room_id, _newest_file, mtime in rooms:
            sx, sy = get_sector_coordinates(room_id)
            present = len(sector_users.get(room_id, set()))
            stamp = datetime.fromtimestamp(mtime).strftime("%Y-%m-%d %H:%M:%S")
            pieces.append(f"| {room_id} | ({sx},{sy}) | {present} | {stamp} |\n")

        pieces.append("\n\n### Sector Map\n\n")
        pieces.append(generate_sector_map())
        return "".join(pieces)
    except Exception as e:
        logger.error(f"Error in list_available_rooms: {e}")
        return f"Error listing rooms: {str(e)}"
|
|
|
async def broadcast_message(message, room_id):
    """Send *message*, JSON-encoded, to every websocket client registered in *room_id*.

    Clients whose connection turns out to be closed are removed from the
    room's entry in ``active_connections``. Does nothing when the room has no
    registered clients.
    """
    clients = active_connections.get(room_id)
    if not clients:
        return

    payload = json.dumps(message)
    disconnected_clients = []

    # Iterate over a snapshot: every ``await`` lets other handlers run, and they
    # may add or remove clients of this room, which would break live iteration.
    for client_id, websocket in list(clients.items()):
        try:
            await websocket.send(payload)
        except websockets.exceptions.ConnectionClosed:
            disconnected_clients.append(client_id)

    for client_id in disconnected_clients:
        # The client's own handler may already have removed it in its ``finally``.
        clients.pop(client_id, None)
|
|
|
async def start_websocket_server(host='0.0.0.0', port=8765):
    """Bind the websocket endpoint on *host*:*port* and return the server object.

    Each incoming connection is served by websocket_handler.
    """
    ws_server = await websockets.serve(websocket_handler, host, port)
    logger.info(f"WebSocket server started on ws://{host}:{port}")
    return ws_server
|
|
|
def send_message(message, username, room_id):
    """Queue a chat message from the Gradio UI and return its display text.

    Whitespace-only messages are ignored and produce ``None``. Otherwise a
    ``"chat"`` payload is appended to ``message_queue`` for
    process_message_queue() to stamp, persist and broadcast, and the string
    ``"<username>: <message>"`` is returned.
    """
    if not message.strip():
        return None

    message_queue.append({
        "type": "chat",
        "content": message,
        "username": username,
        "room_id": room_id,
    })
    return f"{username}: {message}"
|
|
|
def send_clear_command():
    """Queue a request to wipe all chat history and return a status line for the UI.

    The actual clearing is done by process_message_queue() when it picks up
    the queued ``clear_history`` command.
    """
    message_queue.append(dict(type="command", command="clear_history", username="System"))
    return "🧹 Clearing all chat history..."
|
|
|
async def clear_all_history():
    """Delete all persisted history and empty the in-memory histories.

    Each room's list in ``chat_history`` is emptied in place and its key is
    kept, so websocket handlers that append to ``chat_history[room_id]`` keep
    working. ``sector_users`` is left alone: it records clients that are still
    connected, and wiping it made the sector map and the join/leave counts
    report zero users while people were online. Every ``.jsonl`` file in
    HISTORY_DIR is removed; removal errors are logged. A system notice is
    then broadcast to every room with active connections.

    Returns:
        A short status string.
    """
    for messages in chat_history.values():
        messages.clear()

    for name in os.listdir(HISTORY_DIR):
        if not name.endswith(".jsonl"):
            continue
        path = os.path.join(HISTORY_DIR, name)
        try:
            os.remove(path)
        except Exception as e:
            logger.error(f"Error removing file {name}: {e}")
        else:
            # Forget the deleted file so its mtime entry does not linger.
            file_modification_times.pop(path, None)

    timestamp = datetime.now().isoformat()
    for room_id in list(active_connections.keys()):
        clear_msg = {
            "type": "system",
            "content": "🧹 All chat history has been cleared by a user",
            "timestamp": timestamp,
            "sender": "system",
            "room_id": room_id,
        }
        await broadcast_message(clear_msg, room_id)

    logger.info("All chat history cleared")
    return "All chat history cleared"
|
|
|
def join_room(room_id, chat_history_output):
    """Load a room's stored history for display in the Gradio UI.

    Args:
        room_id: Room id typed by the user. Surrounding whitespace is stripped
            and the remainder is URL-quoted before use.
        chat_history_output: Current chat-history value, returned unchanged
            when *room_id* is blank.

    Returns:
        ``(status, history)``. ``status`` is ``"Joined room: <id>"`` (or an
        error hint). ``history`` is a list of display lines: a sector header,
        the sector map, then one line per stored chat or system message.
    """
    if not room_id.strip():
        return "Please enter a valid room ID", chat_history_output

    room_id = urllib.parse.quote(room_id.strip())
    history = load_room_history(room_id)
    x, y = get_sector_coordinates(room_id)

    formatted_history = [
        f"You are now in Sector ({x},{y}) - Room ID: {room_id}",
        f"Sector Map:\n{generate_sector_map()}",
    ]

    for msg in history:
        kind = msg.get("type")
        if kind == "chat":
            sender_node = f" [{msg['sender_node']}]" if "sender_node" in msg else ""
            time_str = ""
            if "timestamp" in msg:
                try:
                    dt = datetime.fromisoformat(msg["timestamp"])
                except (TypeError, ValueError):
                    # Unparseable timestamps are simply shown without a time prefix.
                    pass
                else:
                    time_str = f"[{dt.strftime('%H:%M:%S')}] "
            formatted_history.append(
                f"{time_str}{msg.get('username', 'Anonymous')}{sender_node}: {msg.get('content', '')}"
            )
        elif kind == "system":
            formatted_history.append(f"System: {msg.get('content', '')}")

    return f"Joined room: {room_id}", formatted_history
|
|
|
async def websocket_handler(websocket, path=None):
    """Serve one websocket client for the lifetime of its connection.

    The room is the first segment of the request path
    (``ws://host:port/<room>``). Any query string is ignored, and an empty
    path means ``"default"``.

    On connect the client is registered in ``active_connections`` and
    ``sector_users``. It then receives a welcome notice, the sector map and
    the room's stored history, and a join notice is broadcast and persisted.

    Each incoming frame must be a JSON object:
    - ``{"type": "command", "command": "clear_history"}`` runs clear_all_history().
    - ``{"type": "command", "command": "show_map"}`` sends the map to this client only.
    - Anything else is stamped with timestamp/sender_node/room_id, kept in
      memory (last 500 per room), persisted and broadcast.
    Invalid frames get an ``"error"`` reply.

    On disconnect the client is unregistered and a leave notice is broadcast
    and persisted.

    Args:
        websocket: Connection object supplied by ``websockets.serve``.
        path: Request path. Older ``websockets`` releases pass it as a second
            argument; newer ones call the handler with the connection only. In
            that case it is read from ``websocket.request.path``, or from the
            legacy ``websocket.path``.
    """
    client_id = str(uuid.uuid4())
    room_id = "default"

    def notice(content, kind="system"):
        # System/error payload addressed to the current room (room_id is read at call time).
        return {
            "type": kind,
            "content": content,
            "timestamp": datetime.now().isoformat(),
            "sender": "system",
            "room_id": room_id,
        }

    try:
        if path is None:
            request = getattr(websocket, "request", None)
            path = getattr(request, "path", None) or getattr(websocket, "path", None) or ""
        route = path.split("?", 1)[0]
        room_id = route.strip('/').split('/')[0] or "default"

        active_connections.setdefault(room_id, {})[client_id] = websocket
        sector_users.setdefault(room_id, set()).add(client_id)

        x, y = get_sector_coordinates(room_id)
        room_history = load_room_history(room_id)

        await websocket.send(json.dumps(notice(
            f"Welcome to room '{room_id}' (Sector {x},{y})! Connected from node '{NODE_NAME}'"
        )))
        await websocket.send(json.dumps(notice(f"Sector Map:\n{generate_sector_map()}")))

        # Snapshot: other coroutines may append to the live list while we await sends.
        for msg in list(room_history):
            await websocket.send(json.dumps(msg))

        present = len(sector_users.get(room_id, ()))
        join_msg = notice(f"User joined the room (Sector {x},{y}) - {present} users now present")
        await broadcast_message(join_msg, room_id)
        save_message_to_history(room_id, join_msg)

        logger.info(f"New client {client_id} connected to room {room_id} (Sector {x},{y})")

        async for message in websocket:
            try:
                data = json.loads(message)
            except json.JSONDecodeError:
                await websocket.send(json.dumps(notice("Invalid JSON format", kind="error")))
                continue

            if not isinstance(data, dict):
                await websocket.send(json.dumps(
                    notice("Invalid message format: expected a JSON object", kind="error")
                ))
                continue

            if data.get("type") == "command":
                command = data.get("command")
                if command == "clear_history":
                    await clear_all_history()
                    continue
                if command == "show_map":
                    await websocket.send(json.dumps(notice(f"Sector Map:\n{generate_sector_map()}")))
                    continue
                # Unknown commands fall through and are treated like ordinary messages.

            data["timestamp"] = datetime.now().isoformat()
            data["sender_node"] = NODE_NAME
            data["room_id"] = room_id

            # setdefault: the room's list may be gone if history was cleared meanwhile.
            room_messages = chat_history.setdefault(room_id, [])
            room_messages.append(data)
            if len(room_messages) > 500:
                chat_history[room_id] = room_messages[-500:]

            save_message_to_history(room_id, data)
            await broadcast_message(data, room_id)

    except websockets.exceptions.ConnectionClosed:
        logger.info(f"Client {client_id} disconnected from room {room_id}")
    finally:
        room_clients = active_connections.get(room_id)
        if room_clients is not None:
            room_clients.pop(client_id, None)

        users = sector_users.get(room_id)
        if users is not None:
            users.discard(client_id)

        x, y = get_sector_coordinates(room_id)
        remaining = len(sector_users.get(room_id, set()))
        leave_msg = notice(f"User left the room (Sector {x},{y}) - {remaining} users remaining")
        await broadcast_message(leave_msg, room_id)
        save_message_to_history(room_id, leave_msg)

        # The room may never have been registered (e.g. failure before registration).
        if room_id in active_connections and not active_connections[room_id]:
            del active_connections[room_id]
|
|
|
async def process_message_queue():
    """Background loop that drains UI-queued messages and relays other nodes' messages.

    Roughly once per second it does two things.

    1. Drains every pending entry of ``message_queue``. Draining everything
       keeps a multi-line send from trickling out one line per second.
       ``clear_history`` commands run clear_all_history(). Other entries with
       a ``room_id`` get a default timestamp and sender_node, are appended to
       the room's in-memory history, persisted and broadcast.
    2. Calls check_for_new_messages(). For each changed room with live
       connections, it broadcasts those of the room's latest five messages
       that came from another node and were not already in memory before the
       reload. This way the same messages are not re-sent on every file change.

    An error in one iteration is logged and the loop continues, instead of the
    background task dying silently.
    """

    def message_key(msg):
        # Order-independent identity of a message, used to spot already-known ones.
        return json.dumps(msg, sort_keys=True, default=str)

    while True:
        try:
            while message_queue:
                msg_data = message_queue.pop(0)

                if msg_data.get("type") == "command" and msg_data.get("command") == "clear_history":
                    await clear_all_history()
                elif "room_id" in msg_data:
                    room_id = msg_data["room_id"]
                    msg_data.setdefault("timestamp", datetime.now().isoformat())
                    msg_data.setdefault("sender_node", NODE_NAME)
                    chat_history.setdefault(room_id, []).append(msg_data)
                    save_message_to_history(room_id, msg_data)
                    await broadcast_message(msg_data, room_id)

            # Keep references to the current lists: reloading assigns new lists to
            # chat_history[room_id], so these keep the pre-reload contents.
            previous = {room_id: chat_history.get(room_id, []) for room_id in active_connections}

            updated_rooms = check_for_new_messages()

            for room_id in updated_rooms:
                if room_id not in active_connections:
                    continue
                latest = chat_history.get(room_id, [])[-5:]
                if not latest:
                    continue
                seen = {message_key(m) for m in previous.get(room_id, [])}
                for msg in latest:
                    if msg.get("sender_node") != NODE_NAME and message_key(msg) not in seen:
                        await broadcast_message(msg, room_id)
        except Exception:
            logger.exception("Error in message queue processor")

        await asyncio.sleep(1.0)
|
|
|
def create_gradio_interface():
    """Build (but do not launch) the Gradio Blocks UI for this node.

    The UI lists rooms that have history (on page load and on demand), lets
    the user join a room by id or by grid coordinates, shows that room's
    stored history as text, and queues outgoing messages through
    send_message() for process_message_queue() to persist and broadcast.

    Returns:
        The ``gr.Blocks`` instance.
    """
    with gr.Blocks(title=f"Chat Node: {NODE_NAME}") as interface:
        gr.Markdown(f"# Chat Node: {NODE_NAME}")
        gr.Markdown("Join a room by entering a room ID below or create a new one.")

        # Room overview and global clear button.
        with gr.Row():
            with gr.Column(scale=3):
                room_list = gr.Markdown(value="Loading available rooms...")
                refresh_button = gr.Button("🔄 Refresh Room List")
            with gr.Column(scale=1):
                clear_button = gr.Button("🧹 Clear All Chat History", variant="stop")

        # Join controls: free-form room id or grid coordinates.
        with gr.Row():
            with gr.Column(scale=2):
                room_id_input = gr.Textbox(label="Room ID", placeholder="Enter room ID or use x,y coordinates")
                join_button = gr.Button("Join Room")
            with gr.Column(scale=1):
                with gr.Row():
                    x_coord = gr.Number(label="X", value=0, minimum=0, maximum=GRID_WIDTH-1, step=1)
                    y_coord = gr.Number(label="Y", value=0, minimum=0, maximum=GRID_HEIGHT-1, step=1)
                grid_join_button = gr.Button("Join by Coordinates")

        chat_history_output = gr.Textbox(label="Chat History", lines=20, max_lines=20)

        with gr.Row():
            username_input = gr.Textbox(label="Username", placeholder="Enter your username", value="User")
            with gr.Column(scale=3):
                message_input = gr.Textbox(
                    label="Message",
                    placeholder="Type your message here. Press Shift+Enter for new line, Enter to send.",
                    lines=3
                )
            with gr.Column(scale=1):
                send_button = gr.Button("Send")
                map_button = gr.Button("🗺️ Show Map")

        current_room_display = gr.Textbox(label="Current Room", value="Not joined any room yet")

        refresh_button.click(
            list_available_rooms,
            inputs=[],
            outputs=[room_list]
        )

        clear_button.click(
            send_clear_command,
            inputs=[],
            outputs=[room_list]
        )

        def join_and_render(room_id, current_history):
            """Call join_room() and turn its list of display lines into Textbox text."""
            status, history = join_room(room_id, current_history)
            if isinstance(history, list):
                # A Textbox displays str(value); a raw list would show up as its Python repr.
                history = "\n".join(history)
            return status, history

        def join_by_coordinates(x, y):
            """Turn the X/Y inputs into an ``"x,y"`` room id (a cleared input counts as 0)."""
            return f"{int(x or 0)},{int(y or 0)}"

        grid_join_button.click(
            join_by_coordinates,
            inputs=[x_coord, y_coord],
            outputs=[room_id_input]
        ).then(
            join_and_render,
            inputs=[room_id_input, chat_history_output],
            outputs=[current_room_display, chat_history_output]
        )

        join_button.click(
            join_and_render,
            inputs=[room_id_input, chat_history_output],
            outputs=[current_room_display, chat_history_output]
        )

        def send_and_clear(message, username, room_id, current_history=""):
            """Queue each non-blank line of *message* and append the echoes to the history box.

            Returns ``(message_box_value, history_box_value)``. When no room is
            joined, the draft stays in the message box and a hint is shown in
            the history box. When nothing is sent, the history is left unchanged.
            """
            if not room_id.startswith("Joined room:"):
                return message, "Please join a room first"

            actual_room_id = room_id.replace("Joined room: ", "").strip()

            echoes = []
            for line in message.strip().split("\n"):
                if line.strip():
                    sent_msg = send_message(line.strip(), username, actual_room_id)
                    if sent_msg:
                        echoes.append(sent_msg + "\n")

            if not echoes:
                return message, gr.update()

            history = current_history or ""
            if history and not history.endswith("\n"):
                history += "\n"
            return "", history + "".join(echoes)

        send_button.click(
            send_and_clear,
            inputs=[message_input, username_input, current_room_display, chat_history_output],
            outputs=[message_input, chat_history_output]
        )

        def show_sector_map(room_id):
            """Show the sector map in the history box once a room has been joined."""
            if not room_id.startswith("Joined room:"):
                return "Please join a room first to view the map"
            return generate_sector_map()

        map_button.click(
            show_sector_map,
            inputs=[current_room_display],
            outputs=[chat_history_output]
        )

        def on_message_submit(message, username, room_id, current_history=""):
            """Enter in the message box behaves like the Send button."""
            return send_and_clear(message, username, room_id, current_history)

        message_input.submit(
            on_message_submit,
            inputs=[message_input, username_input, current_room_display, chat_history_output],
            outputs=[message_input, chat_history_output]
        )

        interface.load(
            list_available_rooms,
            inputs=[],
            outputs=[room_list]
        )

    return interface
|
|
|
async def main():
    """Start the chat node and serve it until the HTTP server exits.

    Steps: parse the CLI options, start the websocket server and the
    message-queue task, build the Gradio UI, and serve it with uvicorn. When
    uvicorn returns, the background task is cancelled and the websocket server
    is closed.
    """
    global NODE_NAME, main_event_loop
    NODE_NAME, port = get_node_name()

    main_event_loop = asyncio.get_running_loop()

    ws_server = await start_websocket_server()

    logger.info("Starting message queue processor")
    # Keep a reference: the event loop holds only weak references to tasks, so an
    # unreferenced task can be garbage-collected while it is still running.
    queue_task = asyncio.create_task(process_message_queue())

    interface = create_gradio_interface()

    from starlette.middleware.base import BaseHTTPMiddleware

    class NodeNameMiddleware(BaseHTTPMiddleware):
        """Rename this node when any HTTP request carries a ``?node_name=...`` query parameter."""

        # NOTE(review): any client that can reach the UI can rename the node this way,
        # which also changes which messages process_message_queue() treats as this
        # node's own. Consider restricting or removing this.
        async def dispatch(self, request, call_next):
            global NODE_NAME
            query_params = dict(request.query_params)
            if "node_name" in query_params:
                NODE_NAME = query_params["node_name"]
                logger.info(f"Node name set to {NODE_NAME} from URL parameter")
            response = await call_next(request)
            return response

    app = gr.routes.App.create_app(interface)
    app.add_middleware(NodeNameMiddleware)

    # NOTE(review): ``app`` already serves ``interface``; mounting it again at "/" looks
    # redundant but may be what registers Gradio's startup hooks, so verify before removing.
    gr.routes.mount_gradio_app(app, interface, path="/")

    import uvicorn
    config = uvicorn.Config(app, host="0.0.0.0", port=port)
    http_server = uvicorn.Server(config)

    logger.info(f"Starting Gradio interface on http://0.0.0.0:{port} with node name '{NODE_NAME}'")

    try:
        await http_server.serve()
    finally:
        queue_task.cancel()
        ws_server.close()
        await ws_server.wait_closed()
|
|
|
if __name__ == "__main__":
    # Script entry point: run main() in a new event loop until it returns.
    asyncio.run(main())