awacke1 commited on
Commit
da40c9a
·
verified ·
1 Parent(s): eb41de1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +76 -43
app.py CHANGED
@@ -53,16 +53,18 @@ def get_node_name():
53
  parser = argparse.ArgumentParser(description='Start a chat node with a specific name')
54
  parser.add_argument('--node-name', type=str, default=None, help='Name for this chat node')
55
  parser.add_argument('--port', type=int, default=7860, help='Port to run the Gradio interface on')
 
56
 
57
  args = parser.parse_args()
58
  node_name = args.node_name
59
  port = args.port
 
60
 
61
  # If no node name specified, generate a random one
62
  if not node_name:
63
  node_name = f"node-{uuid.uuid4().hex[:8]}"
64
 
65
- return node_name, port
66
 
67
  def get_room_history_file(room_id):
68
  """Get the filename for a room's history."""
@@ -300,9 +302,18 @@ async def broadcast_message(message, room_id):
300
 
301
  async def start_websocket_server(host='0.0.0.0', port=8765):
302
  """Start the WebSocket server."""
303
- server = await websockets.serve(websocket_handler, host, port)
304
- logger.info(f"WebSocket server started on ws://{host}:{port}")
305
- return server
 
 
 
 
 
 
 
 
 
306
 
307
  def send_message(message, username, room_id):
308
  """Function to send a message from the Gradio interface."""
@@ -741,51 +752,73 @@ def create_gradio_interface():
741
  async def main():
742
  """Main function to start the application."""
743
  global NODE_NAME, main_event_loop
744
- NODE_NAME, port = get_node_name()
745
 
746
  # Store the main event loop for later use
747
  main_event_loop = asyncio.get_running_loop()
748
 
749
  # Start WebSocket server
750
- server = await start_websocket_server()
751
-
752
- # Start message queue processor
753
- asyncio.create_task(process_message_queue())
754
-
755
- # Create and launch Gradio interface
756
- interface = create_gradio_interface()
757
-
758
- # Custom middleware to extract node name from URL query parameters
759
- from starlette.middleware.base import BaseHTTPMiddleware
760
-
761
- class NodeNameMiddleware(BaseHTTPMiddleware):
762
- async def dispatch(self, request, call_next):
763
- global NODE_NAME
764
- query_params = dict(request.query_params)
765
- if "node_name" in query_params:
766
- NODE_NAME = query_params["node_name"]
767
- logger.info(f"Node name set to {NODE_NAME} from URL parameter")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
768
 
769
- response = await call_next(request)
770
- return response
771
-
772
- # Apply middleware
773
- app = gr.routes.App.create_app(interface)
774
- app.add_middleware(NodeNameMiddleware)
775
-
776
- # Launch with the modified app
777
- gr.routes.mount_gradio_app(app, interface, path="/")
778
-
779
- # Run the FastAPI app with uvicorn
780
- import uvicorn
781
- config = uvicorn.Config(app, host="0.0.0.0", port=port)
782
- server = uvicorn.Server(config)
783
-
784
- logger.info(f"Starting Gradio interface on http://0.0.0.0:{port} with node name '{NODE_NAME}'")
785
- logger.info("Starting message queue processor")
786
-
787
- # Start the server
788
- await server.serve()
789
 
790
  if __name__ == "__main__":
791
  asyncio.run(main())
 
53
  parser = argparse.ArgumentParser(description='Start a chat node with a specific name')
54
  parser.add_argument('--node-name', type=str, default=None, help='Name for this chat node')
55
  parser.add_argument('--port', type=int, default=7860, help='Port to run the Gradio interface on')
56
+ parser.add_argument('--ws-port', type=int, default=8765, help='Port to run the WebSocket server on')
57
 
58
  args = parser.parse_args()
59
  node_name = args.node_name
60
  port = args.port
61
+ ws_port = args.ws_port
62
 
63
  # If no node name specified, generate a random one
64
  if not node_name:
65
  node_name = f"node-{uuid.uuid4().hex[:8]}"
66
 
67
+ return node_name, port, ws_port
68
 
69
  def get_room_history_file(room_id):
70
  """Get the filename for a room's history."""
 
302
 
303
  async def start_websocket_server(host='0.0.0.0', port=8765):
304
  """Start the WebSocket server."""
305
+ try:
306
+ server = await websockets.serve(websocket_handler, host, port)
307
+ logger.info(f"WebSocket server started on ws://{host}:{port}")
308
+ return server
309
+ except OSError as e:
310
+ if e.errno == 98: # Address already in use
311
+ logger.warning(f"Port {port} already in use, trying port {port+1}")
312
+ # Try a different port
313
+ return await start_websocket_server(host, port+1)
314
+ else:
315
+ # If it's a different error, re-raise it
316
+ raise
317
 
318
  def send_message(message, username, room_id):
319
  """Function to send a message from the Gradio interface."""
 
752
  async def main():
753
  """Main function to start the application."""
754
  global NODE_NAME, main_event_loop
755
+ NODE_NAME, port, ws_port = get_node_name()
756
 
757
  # Store the main event loop for later use
758
  main_event_loop = asyncio.get_running_loop()
759
 
760
  # Start WebSocket server
761
+ try:
762
+ server = await start_websocket_server(port=ws_port)
763
+
764
+ # Start message queue processor
765
+ asyncio.create_task(process_message_queue())
766
+
767
+ # Create and launch Gradio interface
768
+ interface = create_gradio_interface()
769
+
770
+ # Custom middleware to extract node name from URL query parameters
771
+ from starlette.middleware.base import BaseHTTPMiddleware
772
+
773
+ class NodeNameMiddleware(BaseHTTPMiddleware):
774
+ async def dispatch(self, request, call_next):
775
+ global NODE_NAME
776
+ query_params = dict(request.query_params)
777
+ if "node_name" in query_params:
778
+ NODE_NAME = query_params["node_name"]
779
+ logger.info(f"Node name set to {NODE_NAME} from URL parameter")
780
+
781
+ response = await call_next(request)
782
+ return response
783
+
784
+ # Apply middleware
785
+ app = gr.routes.App.create_app(interface)
786
+ app.add_middleware(NodeNameMiddleware)
787
+
788
+ # Launch with the modified app
789
+ gr.routes.mount_gradio_app(app, interface, path="/")
790
+
791
+ # Run the FastAPI app with uvicorn
792
+ import uvicorn
793
+ config = uvicorn.Config(app, host="0.0.0.0", port=port)
794
+
795
+ # Try to create the server with retries for port conflicts
796
+ server_started = False
797
+ max_retries = 5
798
+ current_port = port
799
+
800
+ for attempt in range(max_retries):
801
+ try:
802
+ config = uvicorn.Config(app, host="0.0.0.0", port=current_port)
803
+ server = uvicorn.Server(config)
804
+ logger.info(f"Starting Gradio interface on http://0.0.0.0:{current_port} with node name '{NODE_NAME}'")
805
+ logger.info("Starting message queue processor")
806
+ await server.serve()
807
+ server_started = True
808
+ break
809
+ except OSError as e:
810
+ if e.errno == 98: # Address already in use
811
+ current_port += 1
812
+ logger.warning(f"Port {current_port-1} already in use, trying port {current_port}")
813
+ else:
814
+ raise
815
+
816
+ if not server_started:
817
+ logger.error(f"Failed to start server after {max_retries} attempts")
818
 
819
+ except Exception as e:
820
+ logger.error(f"Error in main: {e}")
821
+ raise
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
822
 
823
  if __name__ == "__main__":
824
  asyncio.run(main())