da03 committed · Commit 92199b3 · 1 parent: 869b633

Files changed:
- dispatcher.py +153 -83
- start_workers.py +14 -8
- worker.py +46 -19
dispatcher.py
CHANGED
@@ -227,13 +227,13 @@ class SystemAnalytics:
             "avg_utilization_percent": avg_utilization
         })
 
-    def log_worker_registered(self, worker_id: str,
+    def log_worker_registered(self, worker_id: str, worker_address: str, endpoint: str):
         """Log when a worker registers"""
-        self._write_log(f"⚙️ WORKER REGISTERED: {worker_id} (
+        self._write_log(f"⚙️ WORKER REGISTERED: {worker_id} ({worker_address}) at {endpoint}")
 
-    def log_worker_disconnected(self, worker_id: str,
+    def log_worker_disconnected(self, worker_id: str, worker_address: str):
         """Log when a worker disconnects"""
-        self._write_log(f"⚙️ WORKER DISCONNECTED: {worker_id} (
+        self._write_log(f"⚙️ WORKER DISCONNECTED: {worker_id} ({worker_address})")
 
     def log_no_workers_available(self, queue_size: int):
         """Log critical situation when no workers are available"""
@@ -340,7 +340,7 @@ class UserSession:
 @dataclass
 class WorkerInfo:
     worker_id: str
+    worker_address: str  # e.g., "localhost:8001", "192.168.1.100:8002"
     endpoint: str
     is_available: bool
     current_session: Optional[str] = None
@@ -352,6 +352,7 @@ class SessionManager:
         self.workers: Dict[str, WorkerInfo] = {}
         self.session_queue: List[str] = []
         self.active_sessions: Dict[str, str] = {}  # session_id -> worker_id
+        self._queue_lock = asyncio.Lock()  # Prevent race conditions in queue processing
 
         # Configuration
         self.IDLE_TIMEOUT = 20.0  # When no queue
@@ -360,19 +361,26 @@ class SessionManager:
         self.QUEUE_SESSION_WARNING_TIME = 45.0  # 15 seconds before timeout
         self.GRACE_PERIOD = 10.0
 
-    async def register_worker(self, worker_id: str,
+    async def register_worker(self, worker_id: str, worker_address: str, endpoint: str):
         """Register a new worker"""
+        # Check for duplicate registrations
+        if worker_id in self.workers:
+            logger.warning(f"Worker {worker_id} already registered! Overwriting previous registration.")
+            logger.warning(f"Previous: {self.workers[worker_id].worker_address}, {self.workers[worker_id].endpoint}")
+            logger.warning(f"New: {worker_address}, {endpoint}")
+
         self.workers[worker_id] = WorkerInfo(
             worker_id=worker_id,
+            worker_address=worker_address,
             endpoint=endpoint,
             is_available=True,
             last_ping=time.time()
         )
-        logger.info(f"Registered worker {worker_id}
+        logger.info(f"Registered worker {worker_id} ({worker_address}) at {endpoint}")
+        logger.info(f"Total workers now: {len(self.workers)} - {[w.worker_id for w in self.workers.values()]}")
 
         # Log worker registration
-        analytics.log_worker_registered(worker_id,
+        analytics.log_worker_registered(worker_id, worker_address, endpoint)
 
         # Log GPU status
         total_gpus = len(self.workers)
@@ -459,81 +467,89 @@ class SessionManager:
 
     async def process_queue(self):
         """Process the session queue"""
+        async with self._queue_lock:  # Prevent race conditions
+            # Track if we had any existing active sessions before processing
+            had_active_sessions = len(self.active_sessions) > 0
+
+            # Add detailed logging for debugging
+            logger.info(f"Processing queue: {len(self.session_queue)} waiting, {len(self.active_sessions)} active")
+            logger.info(f"Available workers: {[f'{w.worker_id}({w.worker_address})' for w in self.workers.values() if w.is_available]}")
+            logger.info(f"Busy workers: {[f'{w.worker_id}({w.worker_address})' for w in self.workers.values() if not w.is_available]}")
 
+            while self.session_queue:
+                session_id = self.session_queue[0]
+                session = self.sessions.get(session_id)
+
+                if not session or session.status != SessionStatus.QUEUED:
+                    self.session_queue.pop(0)
+                    continue
+
+                worker = await self.get_available_worker()
+                if not worker:
+                    # Log critical situation if no workers are available
+                    if len(self.workers) == 0:
+                        analytics.log_no_workers_available(len(self.session_queue))
+                    logger.info(f"No available workers for session {session_id}. Queue processing stopped.")
+                    break  # No available workers
+
+                # Calculate wait time
+                wait_time = time.time() - session.queue_start_time if session.queue_start_time else 0
+                queue_position = self.session_queue.index(session_id) + 1
+
+                # Assign session to worker
                 self.session_queue.pop(0)
+                session.status = SessionStatus.ACTIVE
+                session.worker_id = worker.worker_id
+                session.last_activity = time.time()
 
-            analytics.log_no_workers_available(len(self.session_queue))
-            break  # No available workers
+                # Set session time limit based on queue status AFTER processing
+                if len(self.session_queue) > 0:
+                    session.max_session_time = self.MAX_SESSION_TIME_WITH_QUEUE
+                    session.session_limit_start_time = time.time()  # Track when limit started
 
-        try:
-            async with aiohttp.ClientSession() as client_session:
-                await client_session.post(f"{worker.endpoint}/init_session", json={
-                    "session_id": session_id,
-                    "client_id": session.client_id
-                })
-        except Exception as e:
-            logger.error(f"Failed to initialize session on worker {worker.worker_id}: {e}")
+                worker.is_available = False
+                worker.current_session = session_id
+                self.active_sessions[session_id] = worker.worker_id
+
+                logger.info(f"Assigned session {session_id} to worker {worker.worker_id}")
+                logger.info(f"Active sessions now: {len(self.active_sessions)}, Available workers: {len([w for w in self.workers.values() if w.is_available])}")
+
+                # Log analytics
+                if wait_time > 0:
+                    analytics.log_queue_wait(session.client_id, wait_time, queue_position)
+                else:
+                    analytics.log_queue_bypass(session.client_id)
+
+                # Log GPU status
+                total_gpus = len(self.workers)
+                active_gpus = len([w for w in self.workers.values() if not w.is_available])
+                available_gpus = total_gpus - active_gpus
+                analytics.log_gpu_status(total_gpus, active_gpus, available_gpus)
+
+                # Initialize session on worker with client_id for logging
+                try:
+                    async with aiohttp.ClientSession() as client_session:
+                        await client_session.post(f"{worker.endpoint}/init_session", json={
+                            "session_id": session_id,
+                            "client_id": session.client_id
+                        })
+                except Exception as e:
+                    logger.error(f"Failed to initialize session on worker {worker.worker_id}: {e}")
+
+                # Notify user that their session is starting
+                await self.notify_session_start(session)
+
+                # Start session monitoring
+                asyncio.create_task(self.monitor_active_session(session_id))
 
-        #
+            # After processing queue, if there are still users waiting AND we had existing active sessions,
+            # apply time limits to those existing sessions
+            if len(self.session_queue) > 0 and had_active_sessions:
+                await self.apply_queue_limits_to_existing_sessions()
 
-        #
-        # After processing queue, if there are still users waiting AND we had existing active sessions,
-        # apply time limits to those existing sessions
-        if len(self.session_queue) > 0 and had_active_sessions:
-            await self.apply_queue_limits_to_existing_sessions()
-        # If queue became empty and there are active sessions with time limits, remove them
-        elif len(self.session_queue) == 0:
-            await self.remove_time_limits_if_queue_empty()
+            # If queue became empty and there are active sessions with time limits, remove them
+            elif len(self.session_queue) == 0:
+                await self.remove_time_limits_if_queue_empty()
 
     async def notify_session_start(self, session: UserSession):
         """Notify user that their session is starting"""
@@ -646,6 +662,10 @@ class SessionManager:
         # Free up the worker
         if session.worker_id and session.worker_id in self.workers:
             worker = self.workers[session.worker_id]
+            if not worker.is_available:  # Only log if worker was actually busy
+                logger.info(f"Freeing worker {worker.worker_id} from session {session_id}")
+            else:
+                logger.warning(f"Worker {worker.worker_id} was already available when freeing from session {session_id}")
             worker.is_available = True
             worker.current_session = None
 
@@ -669,6 +689,9 @@ class SessionManager:
 
         logger.info(f"Ended session {session_id} with status {status}")
 
+        # Validate system state consistency
+        await self._validate_system_state()
+
         # Process next in queue
         asyncio.create_task(self.process_queue())
 
@@ -773,6 +796,43 @@ class SessionManager:
         session.user_has_interacted = True
         logger.info(f"User started interacting in session {session_id}")
 
+    async def _validate_system_state(self):
+        """Validate system state consistency for debugging"""
+        try:
+            # Count active sessions
+            active_sessions_count = len(self.active_sessions)
+            busy_workers_count = len([w for w in self.workers.values() if not w.is_available])
+
+            # Check for inconsistencies
+            if active_sessions_count != busy_workers_count:
+                logger.error(f"INCONSISTENCY: Active sessions ({active_sessions_count}) != Busy workers ({busy_workers_count})")
+                logger.error(f"Active sessions: {list(self.active_sessions.keys())}")
+                logger.error(f"Busy workers: {[w.worker_id for w in self.workers.values() if not w.is_available]}")
+
+                # Log detailed state
+                for session_id, worker_id in self.active_sessions.items():
+                    session = self.sessions.get(session_id)
+                    worker = self.workers.get(worker_id)
+                    logger.error(f"Session {session_id}: status={session.status if session else 'MISSING'}, worker={worker_id}")
+                    if worker:
+                        logger.error(f"Worker {worker_id}: available={worker.is_available}, current_session={worker.current_session}")
+
+            # Check for orphaned workers
+            for worker in self.workers.values():
+                if not worker.is_available and worker.current_session not in self.active_sessions:
+                    logger.error(f"ORPHANED WORKER: {worker.worker_id} is busy but session {worker.current_session} not in active_sessions")
+
+            # Check for sessions without workers
+            for session_id in self.active_sessions:
+                session = self.sessions.get(session_id)
+                if session and session.status == SessionStatus.ACTIVE:
+                    worker = self.workers.get(session.worker_id)
+                    if not worker or worker.is_available:
+                        logger.error(f"ACTIVE SESSION WITHOUT WORKER: {session_id} has no busy worker assigned")
+
+        except Exception as e:
+            logger.error(f"Error in system state validation: {e}")
+
     async def _forward_to_worker(self, worker: WorkerInfo, session_id: str, data: dict):
         """Forward input to worker asynchronously"""
         try:
@@ -808,7 +868,7 @@ async def register_worker(worker_info: dict):
     """Endpoint for workers to register themselves"""
     await session_manager.register_worker(
         worker_info["worker_id"],
-        worker_info["
+        worker_info["worker_address"],
         worker_info["endpoint"]
     )
     return {"status": "registered"}
@@ -977,6 +1037,15 @@ async def periodic_queue_update():
         except Exception as e:
             logger.error(f"Error in periodic queue update: {e}")
 
+# Background task to periodically validate system state
+async def periodic_system_validation():
+    while True:
+        try:
+            await asyncio.sleep(10)  # Validate every 10 seconds
+            await session_manager._validate_system_state()
+        except Exception as e:
+            logger.error(f"Error in periodic system validation: {e}")
+
 # Background task to periodically log analytics summary
 async def periodic_analytics_summary():
     while True:
@@ -996,12 +1065,12 @@ async def periodic_worker_health_check():
 
         for worker_id, worker in list(session_manager.workers.items()):
             if current_time - worker.last_ping > 30:  # 30 second timeout
-                disconnected_workers.append((worker_id, worker.
+                disconnected_workers.append((worker_id, worker.worker_address))
 
-    for worker_id,
-        analytics.log_worker_disconnected(worker_id,
+        for worker_id, worker_address in disconnected_workers:
+            analytics.log_worker_disconnected(worker_id, worker_address)
             del session_manager.workers[worker_id]
-        logger.warning(f"Removed disconnected worker {worker_id} (
+            logger.warning(f"Removed disconnected worker {worker_id} ({worker_address})")
 
         if disconnected_workers:
             # Log updated GPU status
@@ -1017,6 +1086,7 @@ async def periodic_worker_health_check():
 async def startup_event():
     # Start background tasks
     asyncio.create_task(periodic_queue_update())
+    asyncio.create_task(periodic_system_validation())
     asyncio.create_task(periodic_analytics_summary())
    asyncio.create_task(periodic_worker_health_check())
 
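Note: the core dispatcher change is that process_queue now runs its whole dequeue-and-assign pass under an asyncio.Lock. A minimal, self-contained sketch of that pattern (toy queue and worker lists, not the dispatcher's real classes) shows what the lock buys: without it, two concurrent passes could read the same queue head and hand one session to two workers.

import asyncio

# Toy versions of the shared state that process_queue guards.
queue = ["session_a", "session_b"]
available_workers = ["worker_1", "worker_2"]
assignments = {}
queue_lock = asyncio.Lock()

async def process_queue():
    async with queue_lock:  # one dispatcher pass at a time
        while queue and available_workers:
            session_id = queue.pop(0)            # safe: no other pass can interleave here
            worker_id = available_workers.pop(0)
            assignments[session_id] = worker_id

async def main():
    # Two concurrent passes; the lock serializes them, so each session
    # is assigned to exactly one worker.
    await asyncio.gather(process_queue(), process_queue())
    print(assignments)

if __name__ == "__main__":
    asyncio.run(main())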
start_workers.py
CHANGED
@@ -27,10 +27,11 @@ class WorkerManager:
             port = 8001 + gpu_id
             print(f"Starting worker for GPU {gpu_id} on port {port}...")
 
-            # Start worker process
+            # Start worker process with GPU isolation
+            worker_address = f"localhost:{port}"
             cmd = [
                 sys.executable, "worker.py",
-                "--
+                "--worker-address", worker_address,
                 "--dispatcher-url", self.dispatcher_url
             ]
 
@@ -39,30 +40,35 @@ class WorkerManager:
                 with open(log_file, 'w') as f:
                     f.write(f"Starting worker for GPU {gpu_id}\n")
 
+                # Set environment variables for GPU isolation
+                env = os.environ.copy()
+                env['CUDA_VISIBLE_DEVICES'] = str(gpu_id)  # Only show this GPU to the worker
+                env['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'  # Consistent GPU ordering
+
                 process = subprocess.Popen(
                     cmd,
                     stdout=open(log_file, 'a'),
                     stderr=subprocess.STDOUT,
                     universal_newlines=True,
-                    bufsize=1
+                    bufsize=1,
+                    env=env  # Pass the modified environment
                 )
 
                 self.processes.append(process)
-                print(f"✓ Started worker {
+                print(f"✓ Started worker {worker_address} (PID: {process.pid}) - Log: {log_file}")
 
                 # Small delay between starts
                 time.sleep(1)
 
             except Exception as e:
-                print(f"✗ Failed to start worker
+                print(f"✗ Failed to start worker {worker_address}: {e}")
                 self.cleanup()
                 return False
 
         print(f"\n✓ All {self.num_gpus} workers started successfully!")
-        print("
-        print("Worker log files:")
+        print("Worker addresses:")
         for i in range(self.num_gpus):
-            print(f"
+            print(f"  localhost:{8001 + i} - log: worker_gpu_{i}.log")
         return True
 
     def monitor_workers(self):
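Note: start_workers.py now isolates each worker to a single GPU through the child process environment rather than a device index flag, so every worker addresses its GPU as cuda:0. A condensed sketch of the launch pattern, assuming a worker.py in the working directory and the 8001 + gpu_id port scheme used above:

import os
import subprocess
import sys

def launch_worker(gpu_id: int, dispatcher_url: str = "http://localhost:8000") -> subprocess.Popen:
    port = 8001 + gpu_id
    worker_address = f"localhost:{port}"

    env = os.environ.copy()
    env["CUDA_VISIBLE_DEVICES"] = str(gpu_id)  # the child sees only this GPU
    env["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"    # stable physical device numbering

    cmd = [
        sys.executable, "worker.py",
        "--worker-address", worker_address,
        "--dispatcher-url", dispatcher_url,
    ]
    return subprocess.Popen(cmd, env=env)

if __name__ == "__main__":
    # Spawn two isolated workers on ports 8001 and 8002.
    processes = [launch_worker(gpu_id) for gpu_id in range(2)]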
worker.py
CHANGED
@@ -16,6 +16,7 @@ import concurrent.futures
 import aiohttp
 import argparse
 import uuid
+import sys
 
 # Configure logging
 logging.basicConfig(level=logging.INFO)
@@ -26,11 +27,19 @@ torch.backends.cuda.matmul.allow_tf32 = True
 torch.backends.cudnn.allow_tf32 = True
 
 class GPUWorker:
-    def __init__(self,
-        self.
+    def __init__(self, worker_address: str, dispatcher_url: str = "http://localhost:8000"):
+        self.worker_address = worker_address  # e.g., "localhost:8001", "192.168.1.100:8002"
+        # Parse port from worker address
+        if ':' in worker_address:
+            self.host, port_str = worker_address.split(':')
+            self.port = int(port_str)
+        else:
+            raise ValueError(f"Invalid worker address format: {worker_address}. Expected format: 'host:port'")
+
         self.dispatcher_url = dispatcher_url
-        self.worker_id = f"worker_{
+        self.worker_id = f"worker_{worker_address.replace(':', '_')}_{uuid.uuid4().hex[:8]}"
+        # Always use GPU 0 since CUDA_VISIBLE_DEVICES limits visibility to one GPU
+        self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
         self.current_session: Optional[str] = None
         self.session_data: Dict[str, Any] = {}
 
@@ -55,11 +64,18 @@ class GPUWorker:
         # Load keyboard mappings
         self._load_keyboard_mappings()
 
-        logger.info(f"GPU Worker {self.worker_id} initialized on
+        logger.info(f"GPU Worker {self.worker_id} initialized for {self.worker_address} on port {self.port}")
 
     def _initialize_model(self):
-        """Initialize the model on the
-        logger.info(f"Initializing model
+        """Initialize the model on the GPU"""
+        logger.info(f"Initializing model for worker {self.worker_address}")
+
+        # Log CUDA environment info
+        logger.info(f"CUDA_VISIBLE_DEVICES: {os.environ.get('CUDA_VISIBLE_DEVICES', 'not set')}")
+        logger.info(f"Available CUDA devices: {torch.cuda.device_count()}")
+        if torch.cuda.is_available():
+            logger.info(f"Current CUDA device: {torch.cuda.current_device()}")
+            logger.info(f"Device name: {torch.cuda.get_device_name(0)}")  # Always GPU 0
 
         # Load latent stats
         with open('latent_stats.json', 'r') as f:
@@ -93,7 +109,7 @@ class GPUWorker:
         self.padding_image = torch.zeros(*self.LATENT_DIMS).unsqueeze(0).to(self.device)
         self.padding_image = (self.padding_image - self.DATA_NORMALIZATION['mean'].view(1, -1, 1, 1)) / self.DATA_NORMALIZATION['std'].view(1, -1, 1, 1)
 
-        logger.info(f"Model initialized successfully
+        logger.info(f"Model initialized successfully for worker {self.worker_address}")
 
     def _load_keyboard_mappings(self):
         """Load keyboard mappings from main.py"""
@@ -142,10 +158,10 @@ class GPUWorker:
             async with aiohttp.ClientSession() as session:
                 await session.post(f"{self.dispatcher_url}/register_worker", json={
                     "worker_id": self.worker_id,
-                    "
-                    "endpoint": f"http://
+                    "worker_address": self.worker_address,
+                    "endpoint": f"http://{self.worker_address}"
                 })
-            logger.info(f"Successfully registered worker {self.worker_id} with dispatcher")
+            logger.info(f"Successfully registered worker {self.worker_id} ({self.worker_address}) with dispatcher")
         except Exception as e:
             logger.error(f"Failed to register with dispatcher: {e}")
 
@@ -695,14 +711,15 @@ async def health_check():
     return {
         "status": "healthy",
         "worker_id": worker.worker_id if worker else None,
-        "
+        "worker_address": worker.worker_address if worker else None,
+        "port": worker.port if worker else None,
         "current_session": worker.current_session if worker else None
     }
 
-async def startup_worker(
+async def startup_worker(worker_address: str, dispatcher_url: str):
     """Initialize the worker"""
     global worker
-    worker = GPUWorker(
+    worker = GPUWorker(worker_address, dispatcher_url)
 
     # Register with dispatcher
     await worker.register_with_dispatcher()
@@ -715,16 +732,26 @@ if __name__ == "__main__":
 
     # Parse command line arguments
     parser = argparse.ArgumentParser(description="GPU Worker for Neural OS")
-    parser.add_argument("--
+    parser.add_argument("--worker-address", type=str, required=True, help="Worker address (e.g., 'localhost:8001', '192.168.1.100:8002')")
     parser.add_argument("--dispatcher-url", type=str, default="http://localhost:8000", help="Dispatcher URL")
     args = parser.parse_args()
 
-    #
+    # Parse port from worker address for validation
+    if ':' not in args.worker_address:
+        print(f"Error: Invalid worker address format: {args.worker_address}")
+        print("Expected format: 'host:port' (e.g., 'localhost:8001')")
+        sys.exit(1)
+
+    try:
+        host, port_str = args.worker_address.split(':')
+        port = int(port_str)
+    except ValueError:
+        print(f"Error: Invalid port in worker address: {args.worker_address}")
+        sys.exit(1)
 
     @app.on_event("startup")
     async def startup_event():
-        await startup_worker(args.
+        await startup_worker(args.worker_address, args.dispatcher_url)
 
-    logger.info(f"Starting worker
+    logger.info(f"Starting worker {args.worker_address}")
     uvicorn.run(app, host="0.0.0.0", port=port)
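Note: the worker now identifies itself to the dispatcher by worker_address rather than a GPU id, and the registration payload gained a worker_address field. A minimal sketch of that handshake (illustrative address and worker id; requires aiohttp and a dispatcher running at the given URL):

import asyncio
import aiohttp

async def register(worker_address: str = "localhost:8001",
                   dispatcher_url: str = "http://localhost:8000") -> dict:
    # Mirrors the payload worker.py sends in register_with_dispatcher().
    payload = {
        "worker_id": f"worker_{worker_address.replace(':', '_')}",  # real ids also append a uuid suffix
        "worker_address": worker_address,
        "endpoint": f"http://{worker_address}",
    }
    async with aiohttp.ClientSession() as session:
        async with session.post(f"{dispatcher_url}/register_worker", json=payload) as resp:
            return await resp.json()  # dispatcher responds {"status": "registered"}

if __name__ == "__main__":
    print(asyncio.run(register()))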