da03 committed on
Commit
d4eaeb4
·
1 Parent(s): 4bedca7
Files changed (3) hide show
  1. dispatcher.py +61 -3
  2. start_system.sh +3 -2
  3. static/index.html +11 -0
dispatcher.py CHANGED
@@ -864,10 +864,59 @@ class SessionManager:
864
  except Exception as e:
865
  logger.error(f"Error in system state validation: {e}")
866
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
867
  async def _forward_to_worker(self, worker: WorkerInfo, session_id: str, data: dict):
868
  """Forward input to worker asynchronously"""
869
  try:
870
- async with aiohttp.ClientSession() as client_session:
871
  async with client_session.post(
872
  f"{worker.endpoint}/process_input",
873
  json={
@@ -876,10 +925,15 @@ class SessionManager:
876
  }
877
  ) as response:
878
  if response.status != 200:
879
- logger.error(f"Worker returned status {response.status}")
880
- # Optionally handle worker errors here
 
 
 
881
  except Exception as e:
882
  logger.error(f"Error forwarding to worker {worker.worker_id}: {e}")
 
 
883
 
884
  # Global session manager
885
  session_manager = SessionManager()
@@ -1110,6 +1164,10 @@ async def periodic_worker_health_check():
1110
 
1111
  for worker_id, worker_address in disconnected_workers:
1112
  analytics.log_worker_disconnected(worker_id, worker_address)
 
 
 
 
1113
  del session_manager.workers[worker_id]
1114
  logger.warning(f"Removed disconnected worker {worker_id} ({worker_address})")
1115
 
 
864
  except Exception as e:
865
  logger.error(f"Error in system state validation: {e}")
866
 
867
+ async def _handle_worker_failure(self, failed_worker_id: str):
868
+ """Handle sessions when a worker fails - end sessions and put users back in queue"""
869
+ logger.warning(f"Handling failure of worker {failed_worker_id}")
870
+
871
+ # Find all sessions assigned to this worker
872
+ failed_sessions = []
873
+ for session_id, worker_id in list(self.active_sessions.items()):
874
+ if worker_id == failed_worker_id:
875
+ failed_sessions.append(session_id)
876
+
877
+ logger.warning(f"Found {len(failed_sessions)} sessions on failed worker {failed_worker_id}")
878
+
879
+ for session_id in failed_sessions:
880
+ session = self.sessions.get(session_id)
881
+ if session:
882
+ logger.info(f"Recovering session {session_id} from failed worker")
883
+
884
+ # Notify user about the worker failure
885
+ try:
886
+ await session.websocket.send_json({
887
+ "type": "worker_failure",
888
+ "message": "GPU worker failed. Reconnecting you to a healthy worker..."
889
+ })
890
+ except Exception as e:
891
+ logger.error(f"Failed to notify session {session_id} about worker failure: {e}")
892
+
893
+ # Remove from active sessions
894
+ if session_id in self.active_sessions:
895
+ del self.active_sessions[session_id]
896
+
897
+ # Reset session state and put back in queue
898
+ session.status = SessionStatus.QUEUED
899
+ session.worker_id = None
900
+ session.queue_start_time = time.time()
901
+ session.max_session_time = None # Reset time limits
902
+ session.session_limit_start_time = None
903
+ session.session_warning_sent = False
904
+ session.idle_warning_sent = False
905
+
906
+ # Add back to front of queue (they were already active)
907
+ if session_id not in self.session_queue:
908
+ self.session_queue.insert(0, session_id)
909
+ logger.info(f"Added session {session_id} to front of queue for recovery")
910
+
911
+ # Process queue to reassign recovered sessions to healthy workers
912
+ if failed_sessions:
913
+ logger.info(f"Processing queue to reassign {len(failed_sessions)} recovered sessions")
914
+ await self.process_queue()
915
+
916
  async def _forward_to_worker(self, worker: WorkerInfo, session_id: str, data: dict):
917
  """Forward input to worker asynchronously"""
918
  try:
919
+ async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=10)) as client_session:
920
  async with client_session.post(
921
  f"{worker.endpoint}/process_input",
922
  json={
 
925
  }
926
  ) as response:
927
  if response.status != 200:
928
+ logger.error(f"Worker {worker.worker_id} returned status {response.status}")
929
+ except asyncio.TimeoutError:
930
+ logger.error(f"Worker {worker.worker_id} timeout - may be unresponsive")
931
+ # Mark worker as potentially dead for faster detection
932
+ worker.last_ping = 0 # This will cause it to be removed on next health check
933
  except Exception as e:
934
  logger.error(f"Error forwarding to worker {worker.worker_id}: {e}")
935
+ # Mark worker as potentially dead for faster detection
936
+ worker.last_ping = 0
937
 
938
  # Global session manager
939
  session_manager = SessionManager()
 
1164
 
1165
  for worker_id, worker_address in disconnected_workers:
1166
  analytics.log_worker_disconnected(worker_id, worker_address)
1167
+
1168
+ # Handle any active sessions on this dead worker
1169
+ await session_manager._handle_worker_failure(worker_id)
1170
+
1171
  del session_manager.workers[worker_id]
1172
  logger.warning(f"Removed disconnected worker {worker_id} ({worker_address})")
1173
 
start_system.sh CHANGED
@@ -168,8 +168,9 @@ while true; do
168
  CURRENT_WORKERS=$(ps aux | grep -c "python.*worker.py.*--worker-address" || echo "0")
169
  if [ "$CURRENT_WORKERS" -lt "$NUM_GPUS" ]; then
170
  echo "⚠️ Some workers died unexpectedly. Expected $NUM_GPUS, found $CURRENT_WORKERS"
171
- cleanup
172
- exit 1
 
173
  fi
174
 
175
  sleep 5
 
168
  CURRENT_WORKERS=$(ps aux | grep -c "python.*worker.py.*--worker-address" || echo "0")
169
  if [ "$CURRENT_WORKERS" -lt "$NUM_GPUS" ]; then
170
  echo "⚠️ Some workers died unexpectedly. Expected $NUM_GPUS, found $CURRENT_WORKERS"
171
+ echo "🔄 System will continue operating with reduced capacity"
172
+ echo "💡 Check worker logs for error details"
173
+ # Don't exit - keep system running with remaining workers
174
  fi
175
 
176
  sleep 5
static/index.html CHANGED
@@ -333,6 +333,17 @@
333
  console.log(`Queue limit applied, ${data.time_remaining} seconds remaining`);
334
  setTimeoutMessage(`⏰ Other users waiting. Time remaining: <span id="timeoutCountdown">${Math.ceil(data.time_remaining)}</span> seconds.`);
335
  startTimeoutCountdown(Math.ceil(data.time_remaining), true); // true = hide stay connected button
 
 
 
 
 
 
 
 
 
 
 
336
  }
337
  };
338
  }
 
333
  console.log(`Queue limit applied, ${data.time_remaining} seconds remaining`);
334
  setTimeoutMessage(`⏰ Other users waiting. Time remaining: <span id="timeoutCountdown">${Math.ceil(data.time_remaining)}</span> seconds.`);
335
  startTimeoutCountdown(Math.ceil(data.time_remaining), true); // true = hide stay connected button
336
+ } else if (data.type === "worker_failure") {
337
+ console.log("Worker failure detected, reconnecting...");
338
+ showConnectionStatus("🔄 GPU worker failed. Reconnecting to healthy worker...");
339
+ // Clear the canvas to show we're reconnecting
340
+ ctx.clearRect(0, 0, canvas.width, canvas.height);
341
+ ctx.fillStyle = '#f0f0f0';
342
+ ctx.fillRect(0, 0, canvas.width, canvas.height);
343
+ ctx.fillStyle = '#666';
344
+ ctx.font = '20px Arial';
345
+ ctx.textAlign = 'center';
346
+ ctx.fillText('🔄 Reconnecting to healthy GPU...', canvas.width/2, canvas.height/2);
347
  }
348
  };
349
  }