Update app.py
Browse files
app.py
CHANGED
@@ -53,16 +53,18 @@ def get_node_name():
|
|
53 |
parser = argparse.ArgumentParser(description='Start a chat node with a specific name')
|
54 |
parser.add_argument('--node-name', type=str, default=None, help='Name for this chat node')
|
55 |
parser.add_argument('--port', type=int, default=7860, help='Port to run the Gradio interface on')
|
|
|
56 |
|
57 |
args = parser.parse_args()
|
58 |
node_name = args.node_name
|
59 |
port = args.port
|
|
|
60 |
|
61 |
# If no node name specified, generate a random one
|
62 |
if not node_name:
|
63 |
node_name = f"node-{uuid.uuid4().hex[:8]}"
|
64 |
|
65 |
-
return node_name, port
|
66 |
|
67 |
def get_room_history_file(room_id):
|
68 |
"""Get the filename for a room's history."""
|
@@ -300,9 +302,18 @@ async def broadcast_message(message, room_id):
|
|
300 |
|
301 |
async def start_websocket_server(host='0.0.0.0', port=8765):
|
302 |
"""Start the WebSocket server."""
|
303 |
-
|
304 |
-
|
305 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
306 |
|
307 |
def send_message(message, username, room_id):
|
308 |
"""Function to send a message from the Gradio interface."""
|
@@ -741,51 +752,73 @@ def create_gradio_interface():
|
|
741 |
async def main():
|
742 |
"""Main function to start the application."""
|
743 |
global NODE_NAME, main_event_loop
|
744 |
-
NODE_NAME, port = get_node_name()
|
745 |
|
746 |
# Store the main event loop for later use
|
747 |
main_event_loop = asyncio.get_running_loop()
|
748 |
|
749 |
# Start WebSocket server
|
750 |
-
|
751 |
-
|
752 |
-
|
753 |
-
|
754 |
-
|
755 |
-
|
756 |
-
|
757 |
-
|
758 |
-
|
759 |
-
|
760 |
-
|
761 |
-
|
762 |
-
|
763 |
-
|
764 |
-
|
765 |
-
|
766 |
-
|
767 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
768 |
|
769 |
-
|
770 |
-
|
771 |
-
|
772 |
-
# Apply middleware
|
773 |
-
app = gr.routes.App.create_app(interface)
|
774 |
-
app.add_middleware(NodeNameMiddleware)
|
775 |
-
|
776 |
-
# Launch with the modified app
|
777 |
-
gr.routes.mount_gradio_app(app, interface, path="/")
|
778 |
-
|
779 |
-
# Run the FastAPI app with uvicorn
|
780 |
-
import uvicorn
|
781 |
-
config = uvicorn.Config(app, host="0.0.0.0", port=port)
|
782 |
-
server = uvicorn.Server(config)
|
783 |
-
|
784 |
-
logger.info(f"Starting Gradio interface on http://0.0.0.0:{port} with node name '{NODE_NAME}'")
|
785 |
-
logger.info("Starting message queue processor")
|
786 |
-
|
787 |
-
# Start the server
|
788 |
-
await server.serve()
|
789 |
|
790 |
if __name__ == "__main__":
|
791 |
asyncio.run(main())
|
|
|
53 |
parser = argparse.ArgumentParser(description='Start a chat node with a specific name')
|
54 |
parser.add_argument('--node-name', type=str, default=None, help='Name for this chat node')
|
55 |
parser.add_argument('--port', type=int, default=7860, help='Port to run the Gradio interface on')
|
56 |
+
parser.add_argument('--ws-port', type=int, default=8765, help='Port to run the WebSocket server on')
|
57 |
|
58 |
args = parser.parse_args()
|
59 |
node_name = args.node_name
|
60 |
port = args.port
|
61 |
+
ws_port = args.ws_port
|
62 |
|
63 |
# If no node name specified, generate a random one
|
64 |
if not node_name:
|
65 |
node_name = f"node-{uuid.uuid4().hex[:8]}"
|
66 |
|
67 |
+
return node_name, port, ws_port
|
68 |
|
69 |
def get_room_history_file(room_id):
|
70 |
"""Get the filename for a room's history."""
|
|
|
302 |
|
303 |
async def start_websocket_server(host='0.0.0.0', port=8765):
|
304 |
"""Start the WebSocket server."""
|
305 |
+
try:
|
306 |
+
server = await websockets.serve(websocket_handler, host, port)
|
307 |
+
logger.info(f"WebSocket server started on ws://{host}:{port}")
|
308 |
+
return server
|
309 |
+
except OSError as e:
|
310 |
+
if e.errno == 98: # Address already in use
|
311 |
+
logger.warning(f"Port {port} already in use, trying port {port+1}")
|
312 |
+
# Try a different port
|
313 |
+
return await start_websocket_server(host, port+1)
|
314 |
+
else:
|
315 |
+
# If it's a different error, re-raise it
|
316 |
+
raise
|
317 |
|
318 |
def send_message(message, username, room_id):
|
319 |
"""Function to send a message from the Gradio interface."""
|
|
|
752 |
async def main():
|
753 |
"""Main function to start the application."""
|
754 |
global NODE_NAME, main_event_loop
|
755 |
+
NODE_NAME, port, ws_port = get_node_name()
|
756 |
|
757 |
# Store the main event loop for later use
|
758 |
main_event_loop = asyncio.get_running_loop()
|
759 |
|
760 |
# Start WebSocket server
|
761 |
+
try:
|
762 |
+
server = await start_websocket_server(port=ws_port)
|
763 |
+
|
764 |
+
# Start message queue processor
|
765 |
+
asyncio.create_task(process_message_queue())
|
766 |
+
|
767 |
+
# Create and launch Gradio interface
|
768 |
+
interface = create_gradio_interface()
|
769 |
+
|
770 |
+
# Custom middleware to extract node name from URL query parameters
|
771 |
+
from starlette.middleware.base import BaseHTTPMiddleware
|
772 |
+
|
773 |
+
class NodeNameMiddleware(BaseHTTPMiddleware):
|
774 |
+
async def dispatch(self, request, call_next):
|
775 |
+
global NODE_NAME
|
776 |
+
query_params = dict(request.query_params)
|
777 |
+
if "node_name" in query_params:
|
778 |
+
NODE_NAME = query_params["node_name"]
|
779 |
+
logger.info(f"Node name set to {NODE_NAME} from URL parameter")
|
780 |
+
|
781 |
+
response = await call_next(request)
|
782 |
+
return response
|
783 |
+
|
784 |
+
# Apply middleware
|
785 |
+
app = gr.routes.App.create_app(interface)
|
786 |
+
app.add_middleware(NodeNameMiddleware)
|
787 |
+
|
788 |
+
# Launch with the modified app
|
789 |
+
gr.routes.mount_gradio_app(app, interface, path="/")
|
790 |
+
|
791 |
+
# Run the FastAPI app with uvicorn
|
792 |
+
import uvicorn
|
793 |
+
config = uvicorn.Config(app, host="0.0.0.0", port=port)
|
794 |
+
|
795 |
+
# Try to create the server with retries for port conflicts
|
796 |
+
server_started = False
|
797 |
+
max_retries = 5
|
798 |
+
current_port = port
|
799 |
+
|
800 |
+
for attempt in range(max_retries):
|
801 |
+
try:
|
802 |
+
config = uvicorn.Config(app, host="0.0.0.0", port=current_port)
|
803 |
+
server = uvicorn.Server(config)
|
804 |
+
logger.info(f"Starting Gradio interface on http://0.0.0.0:{current_port} with node name '{NODE_NAME}'")
|
805 |
+
logger.info("Starting message queue processor")
|
806 |
+
await server.serve()
|
807 |
+
server_started = True
|
808 |
+
break
|
809 |
+
except OSError as e:
|
810 |
+
if e.errno == 98: # Address already in use
|
811 |
+
current_port += 1
|
812 |
+
logger.warning(f"Port {current_port-1} already in use, trying port {current_port}")
|
813 |
+
else:
|
814 |
+
raise
|
815 |
+
|
816 |
+
if not server_started:
|
817 |
+
logger.error(f"Failed to start server after {max_retries} attempts")
|
818 |
|
819 |
+
except Exception as e:
|
820 |
+
logger.error(f"Error in main: {e}")
|
821 |
+
raise
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
822 |
|
823 |
if __name__ == "__main__":
|
824 |
asyncio.run(main())
|