da03 commited on
Commit
92199b3
·
1 Parent(s): 869b633
Files changed (3) hide show
  1. dispatcher.py +153 -83
  2. start_workers.py +14 -8
  3. worker.py +46 -19
dispatcher.py CHANGED
@@ -227,13 +227,13 @@ class SystemAnalytics:
227
  "avg_utilization_percent": avg_utilization
228
  })
229
 
230
- def log_worker_registered(self, worker_id: str, gpu_id: int, endpoint: str):
231
  """Log when a worker registers"""
232
- self._write_log(f"⚙️ WORKER REGISTERED: {worker_id} (GPU {gpu_id}) at {endpoint}")
233
 
234
- def log_worker_disconnected(self, worker_id: str, gpu_id: int):
235
  """Log when a worker disconnects"""
236
- self._write_log(f"⚙️ WORKER DISCONNECTED: {worker_id} (GPU {gpu_id})")
237
 
238
  def log_no_workers_available(self, queue_size: int):
239
  """Log critical situation when no workers are available"""
@@ -340,7 +340,7 @@ class UserSession:
340
  @dataclass
341
  class WorkerInfo:
342
  worker_id: str
343
- gpu_id: int
344
  endpoint: str
345
  is_available: bool
346
  current_session: Optional[str] = None
@@ -352,6 +352,7 @@ class SessionManager:
352
  self.workers: Dict[str, WorkerInfo] = {}
353
  self.session_queue: List[str] = []
354
  self.active_sessions: Dict[str, str] = {} # session_id -> worker_id
 
355
 
356
  # Configuration
357
  self.IDLE_TIMEOUT = 20.0 # When no queue
@@ -360,19 +361,26 @@ class SessionManager:
360
  self.QUEUE_SESSION_WARNING_TIME = 45.0 # 15 seconds before timeout
361
  self.GRACE_PERIOD = 10.0
362
 
363
- async def register_worker(self, worker_id: str, gpu_id: int, endpoint: str):
364
  """Register a new worker"""
 
 
 
 
 
 
365
  self.workers[worker_id] = WorkerInfo(
366
  worker_id=worker_id,
367
- gpu_id=gpu_id,
368
  endpoint=endpoint,
369
  is_available=True,
370
  last_ping=time.time()
371
  )
372
- logger.info(f"Registered worker {worker_id} on GPU {gpu_id} at {endpoint}")
 
373
 
374
  # Log worker registration
375
- analytics.log_worker_registered(worker_id, gpu_id, endpoint)
376
 
377
  # Log GPU status
378
  total_gpus = len(self.workers)
@@ -459,81 +467,89 @@ class SessionManager:
459
 
460
  async def process_queue(self):
461
  """Process the session queue"""
462
- # Track if we had any existing active sessions before processing
463
- had_active_sessions = len(self.active_sessions) > 0
464
-
465
- while self.session_queue:
466
- session_id = self.session_queue[0]
467
- session = self.sessions.get(session_id)
 
 
468
 
469
- if not session or session.status != SessionStatus.QUEUED:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
470
  self.session_queue.pop(0)
471
- continue
 
 
472
 
473
- worker = await self.get_available_worker()
474
- if not worker:
475
- # Log critical situation if no workers are available
476
- if len(self.workers) == 0:
477
- analytics.log_no_workers_available(len(self.session_queue))
478
- break # No available workers
479
 
480
- # Calculate wait time
481
- wait_time = time.time() - session.queue_start_time if session.queue_start_time else 0
482
- queue_position = self.session_queue.index(session_id) + 1
483
-
484
- # Assign session to worker
485
- self.session_queue.pop(0)
486
- session.status = SessionStatus.ACTIVE
487
- session.worker_id = worker.worker_id
488
- session.last_activity = time.time()
489
-
490
- # Set session time limit based on queue status AFTER processing
491
- if len(self.session_queue) > 0:
492
- session.max_session_time = self.MAX_SESSION_TIME_WITH_QUEUE
493
- session.session_limit_start_time = time.time() # Track when limit started
494
-
495
- worker.is_available = False
496
- worker.current_session = session_id
497
- self.active_sessions[session_id] = worker.worker_id
498
-
499
- logger.info(f"Assigned session {session_id} to worker {worker.worker_id}")
500
-
501
- # Log analytics
502
- if wait_time > 0:
503
- analytics.log_queue_wait(session.client_id, wait_time, queue_position)
504
- else:
505
- analytics.log_queue_bypass(session.client_id)
506
-
507
- # Log GPU status
508
- total_gpus = len(self.workers)
509
- active_gpus = len([w for w in self.workers.values() if not w.is_available])
510
- available_gpus = total_gpus - active_gpus
511
- analytics.log_gpu_status(total_gpus, active_gpus, available_gpus)
512
-
513
- # Initialize session on worker with client_id for logging
514
- try:
515
- async with aiohttp.ClientSession() as client_session:
516
- await client_session.post(f"{worker.endpoint}/init_session", json={
517
- "session_id": session_id,
518
- "client_id": session.client_id
519
- })
520
- except Exception as e:
521
- logger.error(f"Failed to initialize session on worker {worker.worker_id}: {e}")
522
 
523
- # Notify user that their session is starting
524
- await self.notify_session_start(session)
 
 
525
 
526
- # Start session monitoring
527
- asyncio.create_task(self.monitor_active_session(session_id))
528
-
529
- # After processing queue, if there are still users waiting AND we had existing active sessions,
530
- # apply time limits to those existing sessions
531
- if len(self.session_queue) > 0 and had_active_sessions:
532
- await self.apply_queue_limits_to_existing_sessions()
533
-
534
- # If queue became empty and there are active sessions with time limits, remove them
535
- elif len(self.session_queue) == 0:
536
- await self.remove_time_limits_if_queue_empty()
537
 
538
  async def notify_session_start(self, session: UserSession):
539
  """Notify user that their session is starting"""
@@ -646,6 +662,10 @@ class SessionManager:
646
  # Free up the worker
647
  if session.worker_id and session.worker_id in self.workers:
648
  worker = self.workers[session.worker_id]
 
 
 
 
649
  worker.is_available = True
650
  worker.current_session = None
651
 
@@ -669,6 +689,9 @@ class SessionManager:
669
 
670
  logger.info(f"Ended session {session_id} with status {status}")
671
 
 
 
 
672
  # Process next in queue
673
  asyncio.create_task(self.process_queue())
674
 
@@ -773,6 +796,43 @@ class SessionManager:
773
  session.user_has_interacted = True
774
  logger.info(f"User started interacting in session {session_id}")
775
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
776
  async def _forward_to_worker(self, worker: WorkerInfo, session_id: str, data: dict):
777
  """Forward input to worker asynchronously"""
778
  try:
@@ -808,7 +868,7 @@ async def register_worker(worker_info: dict):
808
  """Endpoint for workers to register themselves"""
809
  await session_manager.register_worker(
810
  worker_info["worker_id"],
811
- worker_info["gpu_id"],
812
  worker_info["endpoint"]
813
  )
814
  return {"status": "registered"}
@@ -977,6 +1037,15 @@ async def periodic_queue_update():
977
  except Exception as e:
978
  logger.error(f"Error in periodic queue update: {e}")
979
 
 
 
 
 
 
 
 
 
 
980
  # Background task to periodically log analytics summary
981
  async def periodic_analytics_summary():
982
  while True:
@@ -996,12 +1065,12 @@ async def periodic_worker_health_check():
996
 
997
  for worker_id, worker in list(session_manager.workers.items()):
998
  if current_time - worker.last_ping > 30: # 30 second timeout
999
- disconnected_workers.append((worker_id, worker.gpu_id))
1000
 
1001
- for worker_id, gpu_id in disconnected_workers:
1002
- analytics.log_worker_disconnected(worker_id, gpu_id)
1003
  del session_manager.workers[worker_id]
1004
- logger.warning(f"Removed disconnected worker {worker_id} (GPU {gpu_id})")
1005
 
1006
  if disconnected_workers:
1007
  # Log updated GPU status
@@ -1017,6 +1086,7 @@ async def periodic_worker_health_check():
1017
  async def startup_event():
1018
  # Start background tasks
1019
  asyncio.create_task(periodic_queue_update())
 
1020
  asyncio.create_task(periodic_analytics_summary())
1021
  asyncio.create_task(periodic_worker_health_check())
1022
 
 
227
  "avg_utilization_percent": avg_utilization
228
  })
229
 
230
+ def log_worker_registered(self, worker_id: str, worker_address: str, endpoint: str):
231
  """Log when a worker registers"""
232
+ self._write_log(f"⚙️ WORKER REGISTERED: {worker_id} ({worker_address}) at {endpoint}")
233
 
234
+ def log_worker_disconnected(self, worker_id: str, worker_address: str):
235
  """Log when a worker disconnects"""
236
+ self._write_log(f"⚙️ WORKER DISCONNECTED: {worker_id} ({worker_address})")
237
 
238
  def log_no_workers_available(self, queue_size: int):
239
  """Log critical situation when no workers are available"""
 
340
  @dataclass
341
  class WorkerInfo:
342
  worker_id: str
343
+ worker_address: str # e.g., "localhost:8001", "192.168.1.100:8002"
344
  endpoint: str
345
  is_available: bool
346
  current_session: Optional[str] = None
 
352
  self.workers: Dict[str, WorkerInfo] = {}
353
  self.session_queue: List[str] = []
354
  self.active_sessions: Dict[str, str] = {} # session_id -> worker_id
355
+ self._queue_lock = asyncio.Lock() # Prevent race conditions in queue processing
356
 
357
  # Configuration
358
  self.IDLE_TIMEOUT = 20.0 # When no queue
 
361
  self.QUEUE_SESSION_WARNING_TIME = 45.0 # 15 seconds before timeout
362
  self.GRACE_PERIOD = 10.0
363
 
364
+ async def register_worker(self, worker_id: str, worker_address: str, endpoint: str):
365
  """Register a new worker"""
366
+ # Check for duplicate registrations
367
+ if worker_id in self.workers:
368
+ logger.warning(f"Worker {worker_id} already registered! Overwriting previous registration.")
369
+ logger.warning(f"Previous: {self.workers[worker_id].worker_address}, {self.workers[worker_id].endpoint}")
370
+ logger.warning(f"New: {worker_address}, {endpoint}")
371
+
372
  self.workers[worker_id] = WorkerInfo(
373
  worker_id=worker_id,
374
+ worker_address=worker_address,
375
  endpoint=endpoint,
376
  is_available=True,
377
  last_ping=time.time()
378
  )
379
+ logger.info(f"Registered worker {worker_id} ({worker_address}) at {endpoint}")
380
+ logger.info(f"Total workers now: {len(self.workers)} - {[w.worker_id for w in self.workers.values()]}")
381
 
382
  # Log worker registration
383
+ analytics.log_worker_registered(worker_id, worker_address, endpoint)
384
 
385
  # Log GPU status
386
  total_gpus = len(self.workers)
 
467
 
468
  async def process_queue(self):
469
  """Process the session queue"""
470
+ async with self._queue_lock: # Prevent race conditions
471
+ # Track if we had any existing active sessions before processing
472
+ had_active_sessions = len(self.active_sessions) > 0
473
+
474
+ # Add detailed logging for debugging
475
+ logger.info(f"Processing queue: {len(self.session_queue)} waiting, {len(self.active_sessions)} active")
476
+ logger.info(f"Available workers: {[f'{w.worker_id}({w.worker_address})' for w in self.workers.values() if w.is_available]}")
477
+ logger.info(f"Busy workers: {[f'{w.worker_id}({w.worker_address})' for w in self.workers.values() if not w.is_available]}")
478
 
479
+ while self.session_queue:
480
+ session_id = self.session_queue[0]
481
+ session = self.sessions.get(session_id)
482
+
483
+ if not session or session.status != SessionStatus.QUEUED:
484
+ self.session_queue.pop(0)
485
+ continue
486
+
487
+ worker = await self.get_available_worker()
488
+ if not worker:
489
+ # Log critical situation if no workers are available
490
+ if len(self.workers) == 0:
491
+ analytics.log_no_workers_available(len(self.session_queue))
492
+ logger.info(f"No available workers for session {session_id}. Queue processing stopped.")
493
+ break # No available workers
494
+
495
+ # Calculate wait time
496
+ wait_time = time.time() - session.queue_start_time if session.queue_start_time else 0
497
+ queue_position = self.session_queue.index(session_id) + 1
498
+
499
+ # Assign session to worker
500
  self.session_queue.pop(0)
501
+ session.status = SessionStatus.ACTIVE
502
+ session.worker_id = worker.worker_id
503
+ session.last_activity = time.time()
504
 
505
+ # Set session time limit based on queue status AFTER processing
506
+ if len(self.session_queue) > 0:
507
+ session.max_session_time = self.MAX_SESSION_TIME_WITH_QUEUE
508
+ session.session_limit_start_time = time.time() # Track when limit started
 
 
509
 
510
+ worker.is_available = False
511
+ worker.current_session = session_id
512
+ self.active_sessions[session_id] = worker.worker_id
513
+
514
+ logger.info(f"Assigned session {session_id} to worker {worker.worker_id}")
515
+ logger.info(f"Active sessions now: {len(self.active_sessions)}, Available workers: {len([w for w in self.workers.values() if w.is_available])}")
516
+
517
+ # Log analytics
518
+ if wait_time > 0:
519
+ analytics.log_queue_wait(session.client_id, wait_time, queue_position)
520
+ else:
521
+ analytics.log_queue_bypass(session.client_id)
522
+
523
+ # Log GPU status
524
+ total_gpus = len(self.workers)
525
+ active_gpus = len([w for w in self.workers.values() if not w.is_available])
526
+ available_gpus = total_gpus - active_gpus
527
+ analytics.log_gpu_status(total_gpus, active_gpus, available_gpus)
528
+
529
+ # Initialize session on worker with client_id for logging
530
+ try:
531
+ async with aiohttp.ClientSession() as client_session:
532
+ await client_session.post(f"{worker.endpoint}/init_session", json={
533
+ "session_id": session_id,
534
+ "client_id": session.client_id
535
+ })
536
+ except Exception as e:
537
+ logger.error(f"Failed to initialize session on worker {worker.worker_id}: {e}")
538
+
539
+ # Notify user that their session is starting
540
+ await self.notify_session_start(session)
541
+
542
+ # Start session monitoring
543
+ asyncio.create_task(self.monitor_active_session(session_id))
 
 
 
 
 
 
 
 
544
 
545
+ # After processing queue, if there are still users waiting AND we had existing active sessions,
546
+ # apply time limits to those existing sessions
547
+ if len(self.session_queue) > 0 and had_active_sessions:
548
+ await self.apply_queue_limits_to_existing_sessions()
549
 
550
+ # If queue became empty and there are active sessions with time limits, remove them
551
+ elif len(self.session_queue) == 0:
552
+ await self.remove_time_limits_if_queue_empty()
 
 
 
 
 
 
 
 
553
 
554
  async def notify_session_start(self, session: UserSession):
555
  """Notify user that their session is starting"""
 
662
  # Free up the worker
663
  if session.worker_id and session.worker_id in self.workers:
664
  worker = self.workers[session.worker_id]
665
+ if not worker.is_available: # Only log if worker was actually busy
666
+ logger.info(f"Freeing worker {worker.worker_id} from session {session_id}")
667
+ else:
668
+ logger.warning(f"Worker {worker.worker_id} was already available when freeing from session {session_id}")
669
  worker.is_available = True
670
  worker.current_session = None
671
 
 
689
 
690
  logger.info(f"Ended session {session_id} with status {status}")
691
 
692
+ # Validate system state consistency
693
+ await self._validate_system_state()
694
+
695
  # Process next in queue
696
  asyncio.create_task(self.process_queue())
697
 
 
796
  session.user_has_interacted = True
797
  logger.info(f"User started interacting in session {session_id}")
798
 
799
+ async def _validate_system_state(self):
800
+ """Validate system state consistency for debugging"""
801
+ try:
802
+ # Count active sessions
803
+ active_sessions_count = len(self.active_sessions)
804
+ busy_workers_count = len([w for w in self.workers.values() if not w.is_available])
805
+
806
+ # Check for inconsistencies
807
+ if active_sessions_count != busy_workers_count:
808
+ logger.error(f"INCONSISTENCY: Active sessions ({active_sessions_count}) != Busy workers ({busy_workers_count})")
809
+ logger.error(f"Active sessions: {list(self.active_sessions.keys())}")
810
+ logger.error(f"Busy workers: {[w.worker_id for w in self.workers.values() if not w.is_available]}")
811
+
812
+ # Log detailed state
813
+ for session_id, worker_id in self.active_sessions.items():
814
+ session = self.sessions.get(session_id)
815
+ worker = self.workers.get(worker_id)
816
+ logger.error(f"Session {session_id}: status={session.status if session else 'MISSING'}, worker={worker_id}")
817
+ if worker:
818
+ logger.error(f"Worker {worker_id}: available={worker.is_available}, current_session={worker.current_session}")
819
+
820
+ # Check for orphaned workers
821
+ for worker in self.workers.values():
822
+ if not worker.is_available and worker.current_session not in self.active_sessions:
823
+ logger.error(f"ORPHANED WORKER: {worker.worker_id} is busy but session {worker.current_session} not in active_sessions")
824
+
825
+ # Check for sessions without workers
826
+ for session_id in self.active_sessions:
827
+ session = self.sessions.get(session_id)
828
+ if session and session.status == SessionStatus.ACTIVE:
829
+ worker = self.workers.get(session.worker_id)
830
+ if not worker or worker.is_available:
831
+ logger.error(f"ACTIVE SESSION WITHOUT WORKER: {session_id} has no busy worker assigned")
832
+
833
+ except Exception as e:
834
+ logger.error(f"Error in system state validation: {e}")
835
+
836
  async def _forward_to_worker(self, worker: WorkerInfo, session_id: str, data: dict):
837
  """Forward input to worker asynchronously"""
838
  try:
 
868
  """Endpoint for workers to register themselves"""
869
  await session_manager.register_worker(
870
  worker_info["worker_id"],
871
+ worker_info["worker_address"],
872
  worker_info["endpoint"]
873
  )
874
  return {"status": "registered"}
 
1037
  except Exception as e:
1038
  logger.error(f"Error in periodic queue update: {e}")
1039
 
1040
+ # Background task to periodically validate system state
1041
+ async def periodic_system_validation():
1042
+ while True:
1043
+ try:
1044
+ await asyncio.sleep(10) # Validate every 10 seconds
1045
+ await session_manager._validate_system_state()
1046
+ except Exception as e:
1047
+ logger.error(f"Error in periodic system validation: {e}")
1048
+
1049
  # Background task to periodically log analytics summary
1050
  async def periodic_analytics_summary():
1051
  while True:
 
1065
 
1066
  for worker_id, worker in list(session_manager.workers.items()):
1067
  if current_time - worker.last_ping > 30: # 30 second timeout
1068
+ disconnected_workers.append((worker_id, worker.worker_address))
1069
 
1070
+ for worker_id, worker_address in disconnected_workers:
1071
+ analytics.log_worker_disconnected(worker_id, worker_address)
1072
  del session_manager.workers[worker_id]
1073
+ logger.warning(f"Removed disconnected worker {worker_id} ({worker_address})")
1074
 
1075
  if disconnected_workers:
1076
  # Log updated GPU status
 
1086
  async def startup_event():
1087
  # Start background tasks
1088
  asyncio.create_task(periodic_queue_update())
1089
+ asyncio.create_task(periodic_system_validation())
1090
  asyncio.create_task(periodic_analytics_summary())
1091
  asyncio.create_task(periodic_worker_health_check())
1092
 
start_workers.py CHANGED
@@ -27,10 +27,11 @@ class WorkerManager:
27
  port = 8001 + gpu_id
28
  print(f"Starting worker for GPU {gpu_id} on port {port}...")
29
 
30
- # Start worker process
 
31
  cmd = [
32
  sys.executable, "worker.py",
33
- "--gpu-id", str(gpu_id),
34
  "--dispatcher-url", self.dispatcher_url
35
  ]
36
 
@@ -39,30 +40,35 @@ class WorkerManager:
39
  with open(log_file, 'w') as f:
40
  f.write(f"Starting worker for GPU {gpu_id}\n")
41
 
 
 
 
 
 
42
  process = subprocess.Popen(
43
  cmd,
44
  stdout=open(log_file, 'a'),
45
  stderr=subprocess.STDOUT,
46
  universal_newlines=True,
47
- bufsize=1
 
48
  )
49
 
50
  self.processes.append(process)
51
- print(f"✓ Started worker {gpu_id} (PID: {process.pid}) - Log: {log_file}")
52
 
53
  # Small delay between starts
54
  time.sleep(1)
55
 
56
  except Exception as e:
57
- print(f"✗ Failed to start worker for GPU {gpu_id}: {e}")
58
  self.cleanup()
59
  return False
60
 
61
  print(f"\n✓ All {self.num_gpus} workers started successfully!")
62
- print("Workers are running on ports:", [8001 + i for i in range(self.num_gpus)])
63
- print("Worker log files:")
64
  for i in range(self.num_gpus):
65
- print(f" GPU {i}: worker_gpu_{i}.log")
66
  return True
67
 
68
  def monitor_workers(self):
 
27
  port = 8001 + gpu_id
28
  print(f"Starting worker for GPU {gpu_id} on port {port}...")
29
 
30
+ # Start worker process with GPU isolation
31
+ worker_address = f"localhost:{port}"
32
  cmd = [
33
  sys.executable, "worker.py",
34
+ "--worker-address", worker_address,
35
  "--dispatcher-url", self.dispatcher_url
36
  ]
37
 
 
40
  with open(log_file, 'w') as f:
41
  f.write(f"Starting worker for GPU {gpu_id}\n")
42
 
43
+ # Set environment variables for GPU isolation
44
+ env = os.environ.copy()
45
+ env['CUDA_VISIBLE_DEVICES'] = str(gpu_id) # Only show this GPU to the worker
46
+ env['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID' # Consistent GPU ordering
47
+
48
  process = subprocess.Popen(
49
  cmd,
50
  stdout=open(log_file, 'a'),
51
  stderr=subprocess.STDOUT,
52
  universal_newlines=True,
53
+ bufsize=1,
54
+ env=env # Pass the modified environment
55
  )
56
 
57
  self.processes.append(process)
58
+ print(f"✓ Started worker {worker_address} (PID: {process.pid}) - Log: {log_file}")
59
 
60
  # Small delay between starts
61
  time.sleep(1)
62
 
63
  except Exception as e:
64
+ print(f"✗ Failed to start worker {worker_address}: {e}")
65
  self.cleanup()
66
  return False
67
 
68
  print(f"\n✓ All {self.num_gpus} workers started successfully!")
69
+ print("Worker addresses:")
 
70
  for i in range(self.num_gpus):
71
+ print(f" localhost:{8001 + i} - log: worker_gpu_{i}.log")
72
  return True
73
 
74
  def monitor_workers(self):
worker.py CHANGED
@@ -16,6 +16,7 @@ import concurrent.futures
16
  import aiohttp
17
  import argparse
18
  import uuid
 
19
 
20
  # Configure logging
21
  logging.basicConfig(level=logging.INFO)
@@ -26,11 +27,19 @@ torch.backends.cuda.matmul.allow_tf32 = True
26
  torch.backends.cudnn.allow_tf32 = True
27
 
28
  class GPUWorker:
29
- def __init__(self, gpu_id: int, dispatcher_url: str = "http://localhost:8000"):
30
- self.gpu_id = gpu_id
 
 
 
 
 
 
 
31
  self.dispatcher_url = dispatcher_url
32
- self.worker_id = f"worker_{gpu_id}_{uuid.uuid4().hex[:8]}"
33
- self.device = torch.device(f'cuda:{gpu_id}' if torch.cuda.is_available() else 'cpu')
 
34
  self.current_session: Optional[str] = None
35
  self.session_data: Dict[str, Any] = {}
36
 
@@ -55,11 +64,18 @@ class GPUWorker:
55
  # Load keyboard mappings
56
  self._load_keyboard_mappings()
57
 
58
- logger.info(f"GPU Worker {self.worker_id} initialized on GPU {gpu_id}")
59
 
60
  def _initialize_model(self):
61
- """Initialize the model on the specified GPU"""
62
- logger.info(f"Initializing model on GPU {self.gpu_id}")
 
 
 
 
 
 
 
63
 
64
  # Load latent stats
65
  with open('latent_stats.json', 'r') as f:
@@ -93,7 +109,7 @@ class GPUWorker:
93
  self.padding_image = torch.zeros(*self.LATENT_DIMS).unsqueeze(0).to(self.device)
94
  self.padding_image = (self.padding_image - self.DATA_NORMALIZATION['mean'].view(1, -1, 1, 1)) / self.DATA_NORMALIZATION['std'].view(1, -1, 1, 1)
95
 
96
- logger.info(f"Model initialized successfully on GPU {self.gpu_id}")
97
 
98
  def _load_keyboard_mappings(self):
99
  """Load keyboard mappings from main.py"""
@@ -142,10 +158,10 @@ class GPUWorker:
142
  async with aiohttp.ClientSession() as session:
143
  await session.post(f"{self.dispatcher_url}/register_worker", json={
144
  "worker_id": self.worker_id,
145
- "gpu_id": self.gpu_id,
146
- "endpoint": f"http://localhost:{8001 + self.gpu_id}"
147
  })
148
- logger.info(f"Successfully registered worker {self.worker_id} with dispatcher")
149
  except Exception as e:
150
  logger.error(f"Failed to register with dispatcher: {e}")
151
 
@@ -695,14 +711,15 @@ async def health_check():
695
  return {
696
  "status": "healthy",
697
  "worker_id": worker.worker_id if worker else None,
698
- "gpu_id": worker.gpu_id if worker else None,
 
699
  "current_session": worker.current_session if worker else None
700
  }
701
 
702
- async def startup_worker(gpu_id: int, dispatcher_url: str):
703
  """Initialize the worker"""
704
  global worker
705
- worker = GPUWorker(gpu_id, dispatcher_url)
706
 
707
  # Register with dispatcher
708
  await worker.register_with_dispatcher()
@@ -715,16 +732,26 @@ if __name__ == "__main__":
715
 
716
  # Parse command line arguments
717
  parser = argparse.ArgumentParser(description="GPU Worker for Neural OS")
718
- parser.add_argument("--gpu-id", type=int, required=True, help="GPU ID to use")
719
  parser.add_argument("--dispatcher-url", type=str, default="http://localhost:8000", help="Dispatcher URL")
720
  args = parser.parse_args()
721
 
722
- # Calculate port based on GPU ID
723
- port = 8001 + args.gpu_id
 
 
 
 
 
 
 
 
 
 
724
 
725
  @app.on_event("startup")
726
  async def startup_event():
727
- await startup_worker(args.gpu_id, args.dispatcher_url)
728
 
729
- logger.info(f"Starting worker on GPU {args.gpu_id}, port {port}")
730
  uvicorn.run(app, host="0.0.0.0", port=port)
 
16
  import aiohttp
17
  import argparse
18
  import uuid
19
+ import sys
20
 
21
  # Configure logging
22
  logging.basicConfig(level=logging.INFO)
 
27
  torch.backends.cudnn.allow_tf32 = True
28
 
29
  class GPUWorker:
30
+ def __init__(self, worker_address: str, dispatcher_url: str = "http://localhost:8000"):
31
+ self.worker_address = worker_address # e.g., "localhost:8001", "192.168.1.100:8002"
32
+ # Parse port from worker address
33
+ if ':' in worker_address:
34
+ self.host, port_str = worker_address.split(':')
35
+ self.port = int(port_str)
36
+ else:
37
+ raise ValueError(f"Invalid worker address format: {worker_address}. Expected format: 'host:port'")
38
+
39
  self.dispatcher_url = dispatcher_url
40
+ self.worker_id = f"worker_{worker_address.replace(':', '_')}_{uuid.uuid4().hex[:8]}"
41
+ # Always use GPU 0 since CUDA_VISIBLE_DEVICES limits visibility to one GPU
42
+ self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
43
  self.current_session: Optional[str] = None
44
  self.session_data: Dict[str, Any] = {}
45
 
 
64
  # Load keyboard mappings
65
  self._load_keyboard_mappings()
66
 
67
+ logger.info(f"GPU Worker {self.worker_id} initialized for {self.worker_address} on port {self.port}")
68
 
69
  def _initialize_model(self):
70
+ """Initialize the model on the GPU"""
71
+ logger.info(f"Initializing model for worker {self.worker_address}")
72
+
73
+ # Log CUDA environment info
74
+ logger.info(f"CUDA_VISIBLE_DEVICES: {os.environ.get('CUDA_VISIBLE_DEVICES', 'not set')}")
75
+ logger.info(f"Available CUDA devices: {torch.cuda.device_count()}")
76
+ if torch.cuda.is_available():
77
+ logger.info(f"Current CUDA device: {torch.cuda.current_device()}")
78
+ logger.info(f"Device name: {torch.cuda.get_device_name(0)}") # Always GPU 0
79
 
80
  # Load latent stats
81
  with open('latent_stats.json', 'r') as f:
 
109
  self.padding_image = torch.zeros(*self.LATENT_DIMS).unsqueeze(0).to(self.device)
110
  self.padding_image = (self.padding_image - self.DATA_NORMALIZATION['mean'].view(1, -1, 1, 1)) / self.DATA_NORMALIZATION['std'].view(1, -1, 1, 1)
111
 
112
+ logger.info(f"Model initialized successfully for worker {self.worker_address}")
113
 
114
  def _load_keyboard_mappings(self):
115
  """Load keyboard mappings from main.py"""
 
158
  async with aiohttp.ClientSession() as session:
159
  await session.post(f"{self.dispatcher_url}/register_worker", json={
160
  "worker_id": self.worker_id,
161
+ "worker_address": self.worker_address,
162
+ "endpoint": f"http://{self.worker_address}"
163
  })
164
+ logger.info(f"Successfully registered worker {self.worker_id} ({self.worker_address}) with dispatcher")
165
  except Exception as e:
166
  logger.error(f"Failed to register with dispatcher: {e}")
167
 
 
711
  return {
712
  "status": "healthy",
713
  "worker_id": worker.worker_id if worker else None,
714
+ "worker_address": worker.worker_address if worker else None,
715
+ "port": worker.port if worker else None,
716
  "current_session": worker.current_session if worker else None
717
  }
718
 
719
+ async def startup_worker(worker_address: str, dispatcher_url: str):
720
  """Initialize the worker"""
721
  global worker
722
+ worker = GPUWorker(worker_address, dispatcher_url)
723
 
724
  # Register with dispatcher
725
  await worker.register_with_dispatcher()
 
732
 
733
  # Parse command line arguments
734
  parser = argparse.ArgumentParser(description="GPU Worker for Neural OS")
735
+ parser.add_argument("--worker-address", type=str, required=True, help="Worker address (e.g., 'localhost:8001', '192.168.1.100:8002')")
736
  parser.add_argument("--dispatcher-url", type=str, default="http://localhost:8000", help="Dispatcher URL")
737
  args = parser.parse_args()
738
 
739
+ # Parse port from worker address for validation
740
+ if ':' not in args.worker_address:
741
+ print(f"Error: Invalid worker address format: {args.worker_address}")
742
+ print("Expected format: 'host:port' (e.g., 'localhost:8001')")
743
+ sys.exit(1)
744
+
745
+ try:
746
+ host, port_str = args.worker_address.split(':')
747
+ port = int(port_str)
748
+ except ValueError:
749
+ print(f"Error: Invalid port in worker address: {args.worker_address}")
750
+ sys.exit(1)
751
 
752
  @app.on_event("startup")
753
  async def startup_event():
754
+ await startup_worker(args.worker_address, args.dispatcher_url)
755
 
756
+ logger.info(f"Starting worker {args.worker_address}")
757
  uvicorn.run(app, host="0.0.0.0", port=port)