da03 commited on
Commit
ef88fd2
·
1 Parent(s): 88a7d94
Files changed (4) hide show
  1. dispatcher.py +441 -0
  2. start_workers.py +152 -0
  3. ttt.py +8 -0
  4. worker.py +635 -0
dispatcher.py ADDED
@@ -0,0 +1,441 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, WebSocket, WebSocketDisconnect, HTTPException
2
+ from fastapi.responses import HTMLResponse
3
+ from fastapi.staticfiles import StaticFiles
4
+ from typing import List, Dict, Any, Optional
5
+ import asyncio
6
+ import json
7
+ import time
8
+ import os
9
+ from dataclasses import dataclass, asdict
10
+ from enum import Enum
11
+ import uuid
12
+ import aiohttp
13
+ import logging
14
+
15
+ # Configure logging
16
+ logging.basicConfig(level=logging.INFO)
17
+ logger = logging.getLogger(__name__)
18
+
19
class SessionStatus(Enum):
    """Lifecycle states a user session moves through."""

    QUEUED = "queued"        # waiting for a free GPU worker
    ACTIVE = "active"        # assigned to a worker and running
    COMPLETED = "completed"  # ended normally (disconnect or cleanup)
    TIMEOUT = "timeout"      # ended by an idle or session time limit
24
+
25
@dataclass
class UserSession:
    """State tracked for one connected browser client."""

    session_id: str       # UUID identifying this session
    client_id: str        # human-readable id (epoch seconds + uuid prefix)
    websocket: WebSocket  # live connection back to the client
    created_at: float     # epoch seconds when the connection was accepted
    status: SessionStatus # current lifecycle state
    worker_id: Optional[str] = None        # assigned GPU worker, once active
    last_activity: Optional[float] = None  # epoch seconds of last client message
    max_session_time: Optional[float] = None  # hard time cap, set when a queue exists
    user_has_interacted: bool = False      # flips true on the first client message
36
+
37
@dataclass
class WorkerInfo:
    """Registration record for one GPU worker process."""

    worker_id: str   # unique id reported by the worker at registration
    gpu_id: int      # CUDA device index the worker runs on
    endpoint: str    # HTTP base URL where the worker listens
    is_available: bool  # True when the worker can accept a new session
    current_session: Optional[str] = None  # session id currently being served
    last_ping: float = 0  # epoch seconds of the worker's last heartbeat
45
+
46
class SessionManager:
    """Owns all user sessions, the waiting queue, and GPU worker assignment.

    Sessions enter via `add_session_to_queue`, get paired with a free worker in
    `process_queue`, and are watched for time/idle limits by a per-session
    `monitor_active_session` task until `end_session` releases the worker.
    """

    def __init__(self):
        self.sessions: Dict[str, UserSession] = {}      # session_id -> session
        self.workers: Dict[str, WorkerInfo] = {}        # worker_id -> worker
        self.session_queue: List[str] = []              # FIFO of queued session ids
        self.active_sessions: Dict[str, str] = {}  # session_id -> worker_id

        # Configuration
        self.IDLE_TIMEOUT = 20.0  # When no queue
        self.QUEUE_WARNING_TIME = 10.0
        self.MAX_SESSION_TIME_WITH_QUEUE = 60.0  # When there's a queue
        self.QUEUE_SESSION_WARNING_TIME = 45.0  # 15 seconds before timeout
        self.GRACE_PERIOD = 10.0

    async def register_worker(self, worker_id: str, gpu_id: int, endpoint: str):
        """Register a new worker (or refresh an existing id) as available."""
        self.workers[worker_id] = WorkerInfo(
            worker_id=worker_id,
            gpu_id=gpu_id,
            endpoint=endpoint,
            is_available=True,
            last_ping=time.time()
        )
        logger.info(f"Registered worker {worker_id} on GPU {gpu_id} at {endpoint}")

    async def get_available_worker(self) -> Optional[WorkerInfo]:
        """Return the first available worker that pinged within 30s, else None."""
        for worker in self.workers.values():
            if worker.is_available and time.time() - worker.last_ping < 30:  # Worker ping timeout
                return worker
        return None

    async def add_session_to_queue(self, session: UserSession):
        """Track a new session and append it to the FIFO queue."""
        self.sessions[session.session_id] = session
        self.session_queue.append(session.session_id)
        session.status = SessionStatus.QUEUED
        logger.info(f"Added session {session.session_id} to queue. Queue size: {len(self.session_queue)}")

    async def process_queue(self):
        """Pop queued sessions onto free workers until queue or workers run out."""
        while self.session_queue:
            session_id = self.session_queue[0]
            session = self.sessions.get(session_id)

            # Drop stale entries (disconnected / already-ended sessions)
            if not session or session.status != SessionStatus.QUEUED:
                self.session_queue.pop(0)
                continue

            worker = await self.get_available_worker()
            if not worker:
                break  # No available workers

            # Assign session to worker
            self.session_queue.pop(0)
            session.status = SessionStatus.ACTIVE
            session.worker_id = worker.worker_id
            session.last_activity = time.time()

            # Set session time limit based on queue status: a limit is only
            # imposed when other users are still waiting behind this one.
            if len(self.session_queue) > 0:
                session.max_session_time = self.MAX_SESSION_TIME_WITH_QUEUE

            worker.is_available = False
            worker.current_session = session_id
            self.active_sessions[session_id] = worker.worker_id

            logger.info(f"Assigned session {session_id} to worker {worker.worker_id}")

            # Notify user that their session is starting
            await self.notify_session_start(session)

            # Start session monitoring (runs until the session leaves ACTIVE)
            asyncio.create_task(self.monitor_active_session(session_id))

    async def notify_session_start(self, session: UserSession):
        """Tell the client over its WebSocket that the session is now active."""
        try:
            await session.websocket.send_json({
                "type": "session_start",
                "worker_id": session.worker_id,
                "max_session_time": session.max_session_time
            })
        except Exception as e:
            logger.error(f"Failed to notify session start for {session.session_id}: {e}")

    async def monitor_active_session(self, session_id: str):
        """Watch one ACTIVE session: send warnings and enforce time/idle limits.

        Polls once per second. Any exception (e.g. a send on a dead socket)
        ends the session as COMPLETED.
        """
        session = self.sessions.get(session_id)
        if not session:
            return

        try:
            while session.status == SessionStatus.ACTIVE:
                current_time = time.time()

                # Check if session has exceeded time limit.
                # NOTE(review): elapsed is measured from last_activity, which
                # handle_user_activity resets on every client message — so this
                # acts as an activity-extended limit rather than a cap from
                # session start; confirm that is the intended behavior.
                if session.max_session_time:
                    elapsed = current_time - session.last_activity if session.last_activity else 0
                    remaining = session.max_session_time - elapsed

                    # Send warning at 15 seconds before timeout
                    if remaining <= 15 and remaining > 10:
                        await session.websocket.send_json({
                            "type": "session_warning",
                            "time_remaining": remaining,
                            "queue_size": len(self.session_queue)
                        })

                    # Grace period handling
                    elif remaining <= 10 and remaining > 0:
                        # Check if queue is empty - if so, extend session
                        if len(self.session_queue) == 0:
                            session.max_session_time = None  # Remove time limit
                            await session.websocket.send_json({
                                "type": "time_limit_removed",
                                "reason": "queue_empty"
                            })
                        else:
                            await session.websocket.send_json({
                                "type": "grace_period",
                                "time_remaining": remaining,
                                "queue_size": len(self.session_queue)
                            })

                    # Timeout
                    elif remaining <= 0:
                        await self.end_session(session_id, SessionStatus.TIMEOUT)
                        return

                # Check idle timeout when no queue
                elif not session.max_session_time and session.last_activity:
                    idle_time = current_time - session.last_activity
                    if idle_time >= self.IDLE_TIMEOUT:
                        await self.end_session(session_id, SessionStatus.TIMEOUT)
                        return
                    elif idle_time >= self.QUEUE_WARNING_TIME:
                        await session.websocket.send_json({
                            "type": "idle_warning",
                            "time_remaining": self.IDLE_TIMEOUT - idle_time
                        })

                await asyncio.sleep(1)  # Check every second

        except Exception as e:
            logger.error(f"Error monitoring session {session_id}: {e}")
            await self.end_session(session_id, SessionStatus.COMPLETED)

    async def end_session(self, session_id: str, status: SessionStatus):
        """End a session, free its worker, tell the worker to clean up, and
        kick the queue so the next waiting user gets assigned."""
        session = self.sessions.get(session_id)
        if not session:
            return

        session.status = status

        # Free up the worker
        if session.worker_id and session.worker_id in self.workers:
            worker = self.workers[session.worker_id]
            worker.is_available = True
            worker.current_session = None

            # Notify worker to clean up (best-effort; failure only logged)
            try:
                async with aiohttp.ClientSession() as client_session:
                    await client_session.post(f"{worker.endpoint}/end_session",
                                              json={"session_id": session_id})
            except Exception as e:
                logger.error(f"Failed to notify worker {worker.worker_id} of session end: {e}")

        # Remove from active sessions
        if session_id in self.active_sessions:
            del self.active_sessions[session_id]

        logger.info(f"Ended session {session_id} with status {status}")

        # Process next in queue
        asyncio.create_task(self.process_queue())

    async def update_queue_info(self):
        """Send each queued user their position and a rough wait estimate."""
        for i, session_id in enumerate(self.session_queue):
            session = self.sessions.get(session_id)
            if session and session.status == SessionStatus.QUEUED:
                try:
                    # Calculate estimated wait time: assume each session ahead
                    # takes the queued-session cap, spread across all workers.
                    active_sessions_count = len(self.active_sessions)
                    avg_session_time = self.MAX_SESSION_TIME_WITH_QUEUE if active_sessions_count > 0 else 30.0
                    estimated_wait = (i + 1) * avg_session_time / max(len(self.workers), 1)

                    await session.websocket.send_json({
                        "type": "queue_update",
                        "position": i + 1,
                        "total_waiting": len(self.session_queue),
                        "estimated_wait_minutes": estimated_wait / 60,
                        "active_sessions": active_sessions_count
                    })
                except Exception as e:
                    logger.error(f"Failed to send queue update to session {session_id}: {e}")

    async def handle_user_activity(self, session_id: str):
        """Refresh a session's last_activity timestamp (extends its timers)."""
        session = self.sessions.get(session_id)
        if session:
            session.last_activity = time.time()
            if not session.user_has_interacted:
                session.user_has_interacted = True
                logger.info(f"User started interacting in session {session_id}")

    async def _forward_to_worker(self, worker: WorkerInfo, session_id: str, data: dict):
        """Fire-and-forget POST of a client input to the assigned worker.

        Results come back asynchronously via the /worker_result endpoint, so
        the response body here is ignored; only non-200 statuses are logged.
        """
        try:
            async with aiohttp.ClientSession() as client_session:
                async with client_session.post(
                    f"{worker.endpoint}/process_input",
                    json={
                        "session_id": session_id,
                        "data": data
                    }
                ) as response:
                    if response.status != 200:
                        logger.error(f"Worker returned status {response.status}")
                        # Optionally handle worker errors here
        except Exception as e:
            logger.error(f"Error forwarding to worker {worker.worker_id}: {e}")
271
+
272
# Global session manager shared by all HTTP/WebSocket handlers below
session_manager = SessionManager()

app = FastAPI()
# Front-end assets served at /static (index.html references them there)
app.mount("/static", StaticFiles(directory="static"), name="static")
277
+
278
@app.get("/")
async def get():
    """Serve the demo front-end page.

    Returns the contents of static/index.html as an HTML response.
    """
    # Use a context manager so the file handle is closed deterministically
    # (the original `open(...).read()` leaked the handle until GC).
    with open("static/index.html") as f:
        return HTMLResponse(f.read())
281
+
282
@app.post("/register_worker")
async def register_worker(worker_info: dict):
    """Endpoint for workers to register themselves"""
    # Unpack the registration payload, then delegate to the session manager.
    wid = worker_info["worker_id"]
    gpu = worker_info["gpu_id"]
    url = worker_info["endpoint"]
    await session_manager.register_worker(wid, gpu, url)
    return {"status": "registered"}
291
+
292
@app.post("/worker_ping")
async def worker_ping(worker_info: dict):
    """Endpoint for workers to ping their availability"""
    # Refresh the heartbeat timestamp and availability for a known worker;
    # pings from unregistered workers are silently accepted.
    worker = session_manager.workers.get(worker_info["worker_id"])
    if worker is not None:
        worker.last_ping = time.time()
        worker.is_available = worker_info.get("is_available", True)
    return {"status": "ok"}
300
+
301
@app.post("/worker_result")
async def worker_result(result_data: dict):
    """Endpoint for workers to send back processing results.

    Expects ``{"session_id": ..., "result": ...}`` and relays the result
    payload verbatim over the matching session's WebSocket.

    Raises:
        HTTPException: 400 when session_id or result is missing.
    """
    session_id = result_data.get("session_id")
    result = result_data.get("result")

    # `result is None` (not a truthiness test) so an empty-but-present
    # payload such as {} is still forwarded rather than rejected.
    if not session_id or result is None:
        raise HTTPException(status_code=400, detail="Missing session_id or result")

    # Find the session and send result to the WebSocket
    session = session_manager.sessions.get(session_id)
    if session and session.status == SessionStatus.ACTIVE:
        try:
            await session.websocket.send_json(result)
            logger.info(f"Sent result to session {session_id}")
        except Exception as e:
            logger.error(f"Failed to send result to session {session_id}: {e}")
    else:
        logger.warning(f"Could not find active session {session_id} for result")

    return {"status": "ok"}
323
+
324
@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    """Main client connection: queue the session, then pump messages.

    Lifecycle: accept -> enqueue -> (maybe) activate immediately -> loop
    forwarding inputs to the assigned worker until disconnect, then clean up.
    """
    await websocket.accept()

    # Create session
    session_id = str(uuid.uuid4())
    client_id = f"{int(time.time())}_{session_id[:8]}"

    session = UserSession(
        session_id=session_id,
        client_id=client_id,
        websocket=websocket,
        created_at=time.time(),
        status=SessionStatus.QUEUED
    )

    logger.info(f"New WebSocket connection: {client_id}")

    try:
        # Add to queue
        await session_manager.add_session_to_queue(session)

        # Try to process queue immediately (may activate this session)
        await session_manager.process_queue()

        # Send initial queue status only if we are still waiting
        if session.status == SessionStatus.QUEUED:
            await session_manager.update_queue_info()

        # Main message loop
        while True:
            try:
                data = await websocket.receive_json()

                # Update activity (resets idle / session timers)
                await session_manager.handle_user_activity(session_id)

                # Handle different message types
                if data.get("type") == "heartbeat":
                    await websocket.send_json({"type": "heartbeat_response"})
                    continue

                # If session is active, forward to worker
                if session.status == SessionStatus.ACTIVE and session.worker_id:
                    worker = session_manager.workers.get(session.worker_id)
                    if worker:
                        try:
                            # Forward message to worker (don't wait for response for regular inputs)
                            # The worker will send results back asynchronously via /worker_result
                            asyncio.create_task(session_manager._forward_to_worker(worker, session_id, data))
                        except Exception as e:
                            logger.error(f"Error forwarding to worker: {e}")

                # Handle control messages (these need synchronous responses)
                # NOTE(review): this elif is only reached when the session is
                # NOT active, because the active branch above matches all
                # message types first — confirm control messages are meant to
                # go through the fire-and-forget path while active.
                elif data.get("type") in ["reset", "update_sampling_steps", "update_use_rnn", "get_settings"]:
                    if session.status == SessionStatus.ACTIVE and session.worker_id:
                        worker = session_manager.workers.get(session.worker_id)
                        if worker:
                            try:
                                async with aiohttp.ClientSession() as client_session:
                                    async with client_session.post(
                                        f"{worker.endpoint}/process_input",
                                        json={
                                            "session_id": session_id,
                                            "data": data
                                        }
                                    ) as response:
                                        if response.status == 200:
                                            result = await response.json()
                                            await websocket.send_json(result)
                                        else:
                                            logger.error(f"Worker returned status {response.status}")
                            except Exception as e:
                                logger.error(f"Error forwarding control message: {e}")
                    else:
                        # Send appropriate response for queued users
                        await websocket.send_json({
                            "type": "error",
                            "message": "Session not active yet. Please wait in queue."
                        })

            except asyncio.TimeoutError:
                logger.info("WebSocket connection timed out")
                break
            except WebSocketDisconnect:
                logger.info(f"WebSocket disconnected: {client_id}")
                break

    except Exception as e:
        logger.error(f"Error in WebSocket connection {client_id}: {e}")
        import traceback
        traceback.print_exc()

    finally:
        # Clean up session: release the worker, then drop our record
        if session_id in session_manager.sessions:
            await session_manager.end_session(session_id, SessionStatus.COMPLETED)
            del session_manager.sessions[session_id]

        logger.info(f"WebSocket connection closed: {client_id}")
424
+
425
+ # Background task to periodically update queue info
426
+ async def periodic_queue_update():
427
+ while True:
428
+ try:
429
+ await session_manager.update_queue_info()
430
+ await asyncio.sleep(5) # Update every 5 seconds
431
+ except Exception as e:
432
+ logger.error(f"Error in periodic queue update: {e}")
433
+
434
@app.on_event("startup")
async def startup_event():
    """Launch background maintenance tasks once the app starts."""
    # Periodic queue-position broadcasts to waiting clients.
    asyncio.create_task(periodic_queue_update())
438
+
439
if __name__ == "__main__":
    # Run the dispatcher standalone on all interfaces, port 8000.
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)
start_workers.py ADDED
@@ -0,0 +1,152 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Script to start multiple GPU workers for the neural OS demo.
4
+ Usage: python start_workers.py --num-gpus 4
5
+ """
6
+
7
+ import argparse
8
+ import subprocess
9
+ import time
10
+ import sys
11
+ import signal
12
+ import os
13
+ from typing import List
14
+
15
class WorkerManager:
    """Spawns, monitors, and shuts down one worker.py subprocess per GPU."""

    def __init__(self, num_gpus: int, dispatcher_url: str = "http://localhost:8000"):
        self.num_gpus = num_gpus
        self.dispatcher_url = dispatcher_url
        self.processes: List[subprocess.Popen] = []

    def start_workers(self):
        """Start all worker processes.

        Returns:
            bool: True if every worker launched; False (after cleanup) on failure.
        """
        print(f"Starting {self.num_gpus} GPU workers...")

        for gpu_id in range(self.num_gpus):
            try:
                # Ports are assigned deterministically: GPU i -> 8001 + i
                port = 8001 + gpu_id
                print(f"Starting worker for GPU {gpu_id} on port {port}...")

                # Start worker process
                cmd = [
                    sys.executable, "worker.py",
                    "--gpu-id", str(gpu_id),
                    "--dispatcher-url", self.dispatcher_url
                ]

                process = subprocess.Popen(
                    cmd,
                    stdout=subprocess.PIPE,
                    stderr=subprocess.STDOUT,
                    universal_newlines=True,
                    bufsize=1  # line-buffered so monitor_workers can stream output
                )

                self.processes.append(process)
                print(f"✓ Started worker {gpu_id} (PID: {process.pid})")

                # Small delay between starts
                time.sleep(1)

            except Exception as e:
                print(f"✗ Failed to start worker for GPU {gpu_id}: {e}")
                self.cleanup()
                return False

        print(f"\n✓ All {self.num_gpus} workers started successfully!")
        print("Workers are running on ports:", [8001 + i for i in range(self.num_gpus)])
        return True

    def monitor_workers(self):
        """Monitor worker processes and print their output until Ctrl+C.

        NOTE(review): readline() can block on a quiet worker, delaying output
        from the others — acceptable for a dev utility; confirm if not.
        """
        print("\nMonitoring workers (Ctrl+C to stop)...")
        print("-" * 50)

        try:
            while True:
                # Check if any process has died
                for i, process in enumerate(self.processes):
                    if process.poll() is not None:
                        print(f"⚠️ Worker {i} (PID: {process.pid}) has died!")
                        # Optionally restart it

                # Print output from processes
                for i, process in enumerate(self.processes):
                    if process.stdout and process.stdout.readable():
                        try:
                            line = process.stdout.readline()
                            if line:
                                print(f"[GPU {i}] {line.strip()}")
                        # Narrowed from a bare `except:` which also swallowed
                        # KeyboardInterrupt/SystemExit during a read.
                        except (OSError, ValueError):
                            pass

                time.sleep(0.1)

        except KeyboardInterrupt:
            print("\n\nReceived interrupt signal, shutting down workers...")
            self.cleanup()

    def cleanup(self):
        """Terminate all workers, escalating to SIGKILL after a 5s grace period."""
        print("Stopping all workers...")

        for i, process in enumerate(self.processes):
            if process.poll() is None:  # Process is still running
                print(f"Stopping worker {i} (PID: {process.pid})...")
                try:
                    process.terminate()
                    # Wait for graceful shutdown
                    process.wait(timeout=5)
                    print(f"✓ Worker {i} stopped gracefully")
                except subprocess.TimeoutExpired:
                    print(f"⚠️ Force killing worker {i}...")
                    process.kill()
                    process.wait()
                except Exception as e:
                    print(f"Error stopping worker {i}: {e}")

        print("✓ All workers stopped")
+
110
def main():
    """CLI entry point: parse args, spawn workers, optionally monitor them."""
    parser = argparse.ArgumentParser(description="Start multiple GPU workers")
    parser.add_argument("--num-gpus", type=int, required=True,
                        help="Number of GPU workers to start")
    parser.add_argument("--dispatcher-url", type=str, default="http://localhost:8000",
                        help="URL of the dispatcher service")
    parser.add_argument("--no-monitor", action="store_true",
                        help="Start workers but don't monitor them")
    args = parser.parse_args()

    # Guard clauses: sane GPU count and worker script present.
    if args.num_gpus < 1:
        print("Error: Number of GPUs must be at least 1")
        sys.exit(1)
    if not os.path.exists("worker.py"):
        print("Error: worker.py not found in current directory")
        sys.exit(1)

    manager = WorkerManager(args.num_gpus, args.dispatcher_url)

    # Install handlers so SIGINT/SIGTERM tear the workers down cleanly.
    def signal_handler(sig, frame):
        print(f"\nReceived signal {sig}, shutting down...")
        manager.cleanup()
        sys.exit(0)

    signal.signal(signal.SIGINT, signal_handler)
    signal.signal(signal.SIGTERM, signal_handler)

    if not manager.start_workers():
        sys.exit(1)

    if not args.no_monitor:
        manager.monitor_workers()
    else:
        print("Workers started. Use 'ps aux | grep worker.py' to check status.")
        print("To stop workers, use: pkill -f 'python.*worker.py'")
150
+
151
if __name__ == "__main__":
    # Script entry point.
    main()
ttt.py ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
# Manual test helper: after a short delay (so you can focus another window),
# send two synthetic space key presses via pyautogui.
import pyautogui
import time

time.sleep(3)  # Gives you 3 seconds to switch to another window (e.g., a text editor)

pyautogui.press(' ')  # Sends a space
pyautogui.press('space')  # Sends another space
worker.py ADDED
@@ -0,0 +1,635 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, HTTPException
2
+ from typing import List, Tuple, Dict, Any, Optional
3
+ import numpy as np
4
+ from PIL import Image, ImageDraw
5
+ import base64
6
+ import io
7
+ import json
8
+ import asyncio
9
+ import time
10
+ import torch
11
+ import os
12
+ import logging
13
+ from utils import initialize_model, sample_frame
14
+ from ldm.models.diffusion.ddpm import LatentDiffusion, DDIMSampler
15
+ import concurrent.futures
16
+ import aiohttp
17
+ import argparse
18
+ import uuid
19
+
20
+ # Configure logging
21
+ logging.basicConfig(level=logging.INFO)
22
+ logger = logging.getLogger(__name__)
23
+
24
# GPU settings: allow TF32 on matmul and cuDNN kernels — faster on Ampere+
# GPUs at slightly reduced float32 precision.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
27
+
28
+ class GPUWorker:
29
    def __init__(self, gpu_id: int, dispatcher_url: str = "http://localhost:8000"):
        """Create a worker bound to one GPU and load the diffusion model.

        Heavy constructor: loads model weights onto the device before returning.
        """
        self.gpu_id = gpu_id
        self.dispatcher_url = dispatcher_url
        # Unique id per process so multiple workers never collide
        self.worker_id = f"worker_{gpu_id}_{uuid.uuid4().hex[:8]}"
        self.device = torch.device(f'cuda:{gpu_id}' if torch.cuda.is_available() else 'cpu')
        self.current_session: Optional[str] = None  # session being served, if any
        self.session_data: Dict[str, Any] = {}

        # Model configuration from main.py
        self.DEBUG_MODE = False
        self.DEBUG_MODE_2 = False
        self.NUM_MAX_FRAMES = 1
        self.TIMESTEPS = 1000
        self.SCREEN_WIDTH = 512
        self.SCREEN_HEIGHT = 384
        self.NUM_SAMPLING_STEPS = 32
        self.USE_RNN = False

        self.MODEL_NAME = "yuntian-deng/computer-model-s-newnewd-freezernn-origunet-nospatial-online-x0-joint-onlineonly-222222k7-06k"

        # Initialize model
        self._initialize_model()

        # Thread executor for heavy computation — single worker so only one
        # GPU job runs at a time
        self.thread_executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)

        # Load keyboard mappings
        self._load_keyboard_mappings()

        logger.info(f"GPU Worker {self.worker_id} initialized on GPU {gpu_id}")
59
+
60
    def _initialize_model(self):
        """Initialize the model on the specified GPU.

        Loads latent normalization stats from latent_stats.json, picks a model
        config based on markers in MODEL_NAME, and precomputes a normalized
        zero "padding" latent used for t=0 frames.
        """
        logger.info(f"Initializing model on GPU {self.gpu_id}")

        # Load latent stats
        with open('latent_stats.json', 'r') as f:
            latent_stats = json.load(f)
        # Per-channel mean/std moved to the worker's device
        self.DATA_NORMALIZATION = {
            'mean': torch.tensor(latent_stats['mean']).to(self.device),
            'std': torch.tensor(latent_stats['std']).to(self.device)
        }
        # 16 latent channels at 1/8 of screen resolution
        self.LATENT_DIMS = (16, self.SCREEN_HEIGHT // 8, self.SCREEN_WIDTH // 8)

        # Initialize model based on model name: the name encodes which yaml
        # config (and timestep count) the checkpoint expects.
        if 'origunet' in self.MODEL_NAME:
            if 'x0' in self.MODEL_NAME:
                if 'ddpm32' in self.MODEL_NAME:
                    self.TIMESTEPS = 32
                    self.model = initialize_model("config_final_model_origunet_nospatial_x0_ddpm32.yaml", self.MODEL_NAME)
                else:
                    self.model = initialize_model("config_final_model_origunet_nospatial_x0.yaml", self.MODEL_NAME)
            else:
                if 'ddpm32' in self.MODEL_NAME:
                    self.TIMESTEPS = 32
                    self.model = initialize_model("config_final_model_origunet_nospatial_ddpm32.yaml", self.MODEL_NAME)
                else:
                    self.model = initialize_model("config_final_model_origunet_nospatial.yaml", self.MODEL_NAME)
        else:
            self.model = initialize_model("config_final_model.yaml", self.MODEL_NAME)

        self.model = self.model.to(self.device)

        # Create padding image: an all-zero latent, normalized with the same
        # stats as real latents
        self.padding_image = torch.zeros(*self.LATENT_DIMS).unsqueeze(0).to(self.device)
        self.padding_image = (self.padding_image - self.DATA_NORMALIZATION['mean'].view(1, -1, 1, 1)) / self.DATA_NORMALIZATION['std'].view(1, -1, 1, 1)

        logger.info(f"Model initialized successfully on GPU {self.gpu_id}")
97
+
98
    def _load_keyboard_mappings(self):
        """Load keyboard mappings from main.py.

        Builds the key vocabulary the model was trained on: KEYS minus
        INVALID_KEYS gives VALID_KEYS, indexed via stoi/itos. KEYMAPPING
        translates browser key names to pyautogui-style names.
        """
        self.KEYS = ['\t', '\n', '\r', ' ', '!', '"', '#', '$', '%', '&', "'", '(',
            ')', '*', '+', ',', '-', '.', '/', '0', '1', '2', '3', '4', '5', '6', '7',
            '8', '9', ':', ';', '<', '=', '>', '?', '@', '[', '\\', ']', '^', '_', '`',
            'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o',
            'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', '{', '|', '}', '~',
            'accept', 'add', 'alt', 'altleft', 'altright', 'apps', 'backspace',
            'browserback', 'browserfavorites', 'browserforward', 'browserhome',
            'browserrefresh', 'browsersearch', 'browserstop', 'capslock', 'clear',
            'convert', 'ctrl', 'ctrlleft', 'ctrlright', 'decimal', 'del', 'delete',
            'divide', 'down', 'end', 'enter', 'esc', 'escape', 'execute', 'f1', 'f10',
            'f11', 'f12', 'f13', 'f14', 'f15', 'f16', 'f17', 'f18', 'f19', 'f2', 'f20',
            'f21', 'f22', 'f23', 'f24', 'f3', 'f4', 'f5', 'f6', 'f7', 'f8', 'f9',
            'final', 'fn', 'hanguel', 'hangul', 'hanja', 'help', 'home', 'insert', 'junja',
            'kana', 'kanji', 'launchapp1', 'launchapp2', 'launchmail',
            'launchmediaselect', 'left', 'modechange', 'multiply', 'nexttrack',
            'nonconvert', 'num0', 'num1', 'num2', 'num3', 'num4', 'num5', 'num6',
            'num7', 'num8', 'num9', 'numlock', 'pagedown', 'pageup', 'pause', 'pgdn',
            'pgup', 'playpause', 'prevtrack', 'print', 'printscreen', 'prntscrn',
            'prtsc', 'prtscr', 'return', 'right', 'scrolllock', 'select', 'separator',
            'shift', 'shiftleft', 'shiftright', 'sleep', 'space', 'stop', 'subtract', 'tab',
            'up', 'volumedown', 'volumemute', 'volumeup', 'win', 'winleft', 'winright', 'yen',
            'command', 'option', 'optionleft', 'optionright']

        # Browser/JS key names -> canonical names used in KEYS
        self.KEYMAPPING = {
            'arrowup': 'up',
            'arrowdown': 'down',
            'arrowleft': 'left',
            'arrowright': 'right',
            'meta': 'command',
            'contextmenu': 'apps',
            'control': 'ctrl',
        }

        # Keys excluded from the model's vocabulary
        self.INVALID_KEYS = ['f13', 'f14', 'f15', 'f16', 'f17', 'f18', 'f19', 'f20',
                             'f21', 'f22', 'f23', 'f24', 'select', 'separator', 'execute']
        self.VALID_KEYS = [key for key in self.KEYS if key not in self.INVALID_KEYS]
        # index -> key and key -> index lookups for the key_events vector
        self.itos = self.VALID_KEYS
        self.stoi = {key: i for i, key in enumerate(self.itos)}
138
+
139
+ async def register_with_dispatcher(self):
140
+ """Register this worker with the dispatcher"""
141
+ try:
142
+ async with aiohttp.ClientSession() as session:
143
+ await session.post(f"{self.dispatcher_url}/register_worker", json={
144
+ "worker_id": self.worker_id,
145
+ "gpu_id": self.gpu_id,
146
+ "endpoint": f"http://localhost:{8001 + self.gpu_id}"
147
+ })
148
+ logger.info(f"Successfully registered worker {self.worker_id} with dispatcher")
149
+ except Exception as e:
150
+ logger.error(f"Failed to register with dispatcher: {e}")
151
+
152
+ async def ping_dispatcher(self):
153
+ """Periodically ping the dispatcher to maintain connection"""
154
+ while True:
155
+ try:
156
+ async with aiohttp.ClientSession() as session:
157
+ await session.post(f"{self.dispatcher_url}/worker_ping", json={
158
+ "worker_id": self.worker_id,
159
+ "is_available": self.current_session is None
160
+ })
161
+ await asyncio.sleep(10) # Ping every 10 seconds
162
+ except Exception as e:
163
+ logger.error(f"Failed to ping dispatcher: {e}")
164
+ await asyncio.sleep(5) # Retry after 5 seconds on error
165
+
166
    def prepare_model_inputs(
        self,
        previous_frame: torch.Tensor,
        hidden_states: Any,
        x: int,
        y: int,
        right_click: bool,
        left_click: bool,
        keys_down: List[str],
        time_step: int
    ) -> Dict[str, torch.Tensor]:
        """Prepare inputs for the model (from main.py).

        Builds the per-step conditioning dict: previous latent frame, mouse
        position/buttons, a multi-hot key_events vector, and (optionally) the
        RNN hidden states. time_step == 0 marks the padding frame.
        """
        # Clamp coordinates to valid ranges; None becomes 0
        x = min(max(0, x), self.SCREEN_WIDTH - 1) if x is not None else 0
        y = min(max(0, y), self.SCREEN_HEIGHT - 1) if y is not None else 0

        if self.DEBUG_MODE:
            logger.info('DEBUG MODE, SETTING TIME STEP TO 0')
            time_step = 0
        if self.DEBUG_MODE_2:
            if time_step > self.NUM_MAX_FRAMES-1:
                logger.info('DEBUG MODE_2, SETTING TIME STEP TO 0')
                time_step = 0

        inputs = {
            'image_features': previous_frame.to(self.device),
            'is_padding': torch.BoolTensor([time_step == 0]).to(self.device),
            'x': torch.LongTensor([x]).unsqueeze(0).to(self.device),
            'y': torch.LongTensor([y]).unsqueeze(0).to(self.device),
            'is_leftclick': torch.BoolTensor([left_click]).unsqueeze(0).to(self.device),
            'is_rightclick': torch.BoolTensor([right_click]).unsqueeze(0).to(self.device),
            # Multi-hot vector over the valid-key vocabulary
            'key_events': torch.zeros(len(self.itos), dtype=torch.long).to(self.device)
        }

        # Mark each currently held key, translating browser names first
        for key in keys_down:
            key = key.lower()
            if key in self.KEYMAPPING:
                key = self.KEYMAPPING[key]
            if key in self.stoi:
                inputs['key_events'][self.stoi[key]] = 1
            else:
                logger.warning(f'Key {key} not found in stoi')

        if hidden_states is not None:
            inputs['hidden_states'] = hidden_states

        if self.DEBUG_MODE:
            logger.info('DEBUG MODE, REMOVING INPUTS')
            if 'hidden_states' in inputs:
                del inputs['hidden_states']

        if self.DEBUG_MODE_2:
            if time_step > self.NUM_MAX_FRAMES-1:
                logger.info('DEBUG MODE_2, REMOVING HIDDEN STATES')
                if 'hidden_states' in inputs:
                    del inputs['hidden_states']

        logger.info(f'Time step: {time_step}')
        return inputs
225
+
226
@torch.no_grad()
async def process_frame(
    self,
    inputs: Dict[str, torch.Tensor],
    use_rnn: bool = False,
    num_sampling_steps: int = 32
) -> Tuple[torch.Tensor, np.ndarray, Any, Dict[str, float]]:
    """Process a single frame through the model.

    The heavy computation runs in ``self.thread_executor`` so the event
    loop stays responsive.

    Bug fix: PyTorch's grad mode is thread-local, so the
    ``@torch.no_grad()`` decorator on this coroutine does NOT cover the
    executor thread — the original code ran ``_process_frame_sync`` with
    gradients enabled, wasting memory on autograd graphs. The worker
    callable below re-enters ``no_grad`` inside the executor thread.

    Returns:
        (sample_latent, sample_img, hidden_states, timing) as produced
        by ``_process_frame_sync``.
    """
    loop = asyncio.get_running_loop()

    def _run_without_grad():
        # Grad mode must be disabled in THIS thread, not the caller's.
        with torch.no_grad():
            return self._process_frame_sync(inputs, use_rnn, num_sampling_steps)

    return await loop.run_in_executor(self.thread_executor, _run_without_grad)
240
+
241
def _process_frame_sync(self, inputs, use_rnn, num_sampling_steps):
    """Synchronous version of process_frame that runs in a thread.

    Pipeline: temporal-encoder step -> latent sampling (RNN shortcut,
    full diffusion loop, single step, or DDIM) -> first-stage decode ->
    uint8 RGB image.

    Returns:
        (sample_latent, sample_img, hidden_states, timing) where timing
        maps stage name -> elapsed seconds.
    """
    timing = {}

    # Temporal encoding: one recurrent step over the prepared inputs.
    start = time.perf_counter()
    output_from_rnn, hidden_states = self.model.temporal_encoder.forward_step(inputs)
    timing['temporal_encoder'] = time.perf_counter() - start

    # UNet sampling
    start = time.perf_counter()
    logger.info(f"model.clip_denoised: {self.model.clip_denoised}")
    # NOTE(review): mutates shared model state on every frame; presumably
    # idempotent after the first call — confirm no other code re-enables it.
    self.model.clip_denoised = False
    logger.info(f"USE_RNN: {use_rnn}, NUM_SAMPLING_STEPS: {num_sampling_steps}")

    if use_rnn:
        # Skip diffusion: take the RNN output directly. Assumes the first
        # 16 channels hold the latent — TODO confirm against model config.
        sample_latent = output_from_rnn[:, :16]
    else:
        if num_sampling_steps >= self.TIMESTEPS:
            # Requested steps cover the full schedule: run the full
            # ancestral sampling loop.
            sample_latent = self.model.p_sample_loop(
                cond={'c_concat': output_from_rnn},
                shape=[1, *self.LATENT_DIMS],
                return_intermediates=False,
                verbose=True
            )
        else:
            if num_sampling_steps == 1:
                # Single denoising step from pure noise at the last timestep.
                x = torch.randn([1, *self.LATENT_DIMS], device=self.device)
                t = torch.full((1,), self.TIMESTEPS-1, device=self.device, dtype=torch.long)
                sample_latent = self.model.apply_model(x, t, {'c_concat': output_from_rnn})
            else:
                # Reduced-step sampling via DDIM.
                sampler = DDIMSampler(self.model)
                sample_latent, _ = sampler.sample(
                    S=num_sampling_steps,
                    conditioning={'c_concat': output_from_rnn},
                    batch_size=1,
                    shape=self.LATENT_DIMS,
                    verbose=False
                )
    timing['unet'] = time.perf_counter() - start

    # Decoding: un-normalize the latent, then decode with the first stage.
    start = time.perf_counter()
    sample = sample_latent * self.DATA_NORMALIZATION['std'].view(1, -1, 1, 1) + self.DATA_NORMALIZATION['mean'].view(1, -1, 1, 1)
    sample = self.model.decode_first_stage(sample)
    sample = sample.squeeze(0).clamp(-1, 1)
    timing['decode'] = time.perf_counter() - start

    # Convert [-1, 1] CHW float to HWC uint8 (first 3 channels as RGB).
    sample_img = ((sample[:3].transpose(0,1).transpose(1,2).cpu().float().numpy() + 1) * 127.5).astype(np.uint8)

    timing['total'] = sum(timing.values())

    return sample_latent, sample_img, hidden_states, timing
295
+
296
def initialize_session(self, session_id: str):
    """Create fresh per-session state and start its queue-processor task."""
    self.current_session = session_id
    fresh_state = {
        'previous_frame': self.padding_image,  # placeholder until first generated frame
        'hidden_states': None,                 # recurrent state, filled after first step
        'keys_down': set(),
        'frame_num': -1,
        'client_settings': {
            'use_rnn': self.USE_RNN,
            'sampling_steps': self.NUM_SAMPLING_STEPS
        },
        'input_queue': asyncio.Queue(),
        'is_processing': False
    }
    self.session_data[session_id] = fresh_state
    logger.info(f"Initialized session {session_id}")

    # Kick off the per-session consumer; it exits when the session ends.
    asyncio.create_task(self._process_session_queue(session_id))
315
+
316
def end_session(self, session_id: str):
    """Tear down a session: drain its queue and drop its state."""
    session = self.session_data.get(session_id)
    if session is not None:
        # Discard anything still queued so no stale work survives.
        queue = session['input_queue']
        while True:
            try:
                queue.get_nowait()
                queue.task_done()
            except asyncio.QueueEmpty:
                break
        del self.session_data[session_id]
    if self.current_session == session_id:
        self.current_session = None
    logger.info(f"Ended session {session_id}")
331
+
332
async def _process_session_queue(self, session_id: str):
    """Process the input queue for a specific session with interesting input filtering.

    Long-lived consumer task started by initialize_session: polls the
    session's queue (10 ms tick) and, when input is waiting and no other
    processing is in flight, hands off to _process_next_input. Exits
    once the session is removed from session_data.
    """
    while session_id in self.session_data:
        try:
            session = self.session_data[session_id]
            input_queue = session['input_queue']

            # Wait for input to be available
            if input_queue.empty():
                await asyncio.sleep(0.01)  # Small delay to prevent busy waiting
                continue

            # If already processing, skip this tick
            if session['is_processing']:
                await asyncio.sleep(0.01)
                continue

            # Set processing flag so a concurrent tick cannot double-process
            session['is_processing'] = True

            try:
                # Process queue with interesting input filtering
                await self._process_next_input(session_id)
            finally:
                # Always clear the flag, even if processing raised.
                session['is_processing'] = False

        except Exception as e:
            logger.error(f"Error in session queue processing for {session_id}: {e}")
            import traceback
            traceback.print_exc()
            await asyncio.sleep(1)  # Prevent tight error loop

    logger.info(f"Session queue processor ended for {session_id}")
365
+
366
async def _process_next_input(self, session_id: str):
    """Process next input with interesting input filtering (from main.py logic).

    Drains the session queue looking for the first "interesting" event
    (a click, or any key press/release). An interesting event is
    processed immediately and earlier plain mouse moves are dropped; if
    none is found, only the newest mouse move is processed. This keeps
    latency low when the client sends inputs faster than frames can be
    generated.
    """
    session = self.session_data[session_id]
    input_queue = session['input_queue']

    if input_queue.empty():
        return

    queue_size = input_queue.qsize()
    logger.info(f"Processing next input for session {session_id}. Queue size: {queue_size}")

    try:
        # Track how many uninteresting events we pass over.
        skipped = 0
        latest_input = None

        # Process the queue one item at a time
        while not input_queue.empty():
            current_input = await input_queue.get()
            input_queue.task_done()

            # Always update the latest input
            latest_input = current_input

            # Interesting = click, or a non-empty key press/release list.
            is_interesting = (current_input.get("is_left_click") or
                            current_input.get("is_right_click") or
                            (current_input.get("keys_down") and len(current_input.get("keys_down")) > 0) or
                            (current_input.get("keys_up") and len(current_input.get("keys_up")) > 0))

            # Process immediately if interesting
            if is_interesting:
                logger.info(f"Found interesting input for session {session_id} (skipped {skipped} events)")
                await self._process_single_input(session_id, current_input)
                return

            # Otherwise, continue to the next item
            skipped += 1

            # Queue exhausted with no interesting input: process the most
            # recent movement so the cursor still tracks.
            if input_queue.empty():
                logger.info(f"No interesting inputs for session {session_id}, processing latest movement (skipped {skipped-1} events)")
                await self._process_single_input(session_id, latest_input)
                return

    except Exception as e:
        logger.error(f"Error in _process_next_input for session {session_id}: {e}")
        import traceback
        traceback.print_exc()
415
+
416
async def process_input(self, session_id: str, data: dict) -> dict:
    """Route one client message for a session.

    Control messages (reset, settings, heartbeat) are answered inline;
    every other message is queued for asynchronous frame processing and
    acknowledged immediately — the rendered frame is delivered later via
    the dispatcher callback.
    """
    if session_id not in self.session_data:
        self.initialize_session(session_id)

    session = self.session_data[session_id]
    msg_type = data.get("type")

    if msg_type == "reset":
        logger.info(f"Received reset command for session {session_id}")
        # Drop every queued input, then restore the initial session state.
        queue = session['input_queue']
        while True:
            try:
                queue.get_nowait()
                queue.task_done()
            except asyncio.QueueEmpty:
                break
        session['previous_frame'] = self.padding_image
        session['hidden_states'] = None
        session['keys_down'] = set()
        session['frame_num'] = -1
        return {"type": "reset_confirmed"}

    if msg_type == "update_sampling_steps":
        steps = data.get("steps", 32)
        if steps < 1:
            return {"type": "error", "message": "Invalid sampling steps value"}
        session['client_settings']['sampling_steps'] = steps
        logger.info(f"Updated sampling steps to {steps} for session {session_id}")
        return {"type": "steps_updated", "steps": steps}

    if msg_type == "update_use_rnn":
        use_rnn = data.get("use_rnn", False)
        session['client_settings']['use_rnn'] = use_rnn
        logger.info(f"Updated USE_RNN to {use_rnn} for session {session_id}")
        return {"type": "rnn_updated", "use_rnn": use_rnn}

    if msg_type == "get_settings":
        settings = session['client_settings']
        return {
            "type": "settings",
            "sampling_steps": settings['sampling_steps'],
            "use_rnn": settings['use_rnn']
        }

    if msg_type == "heartbeat":
        return {"type": "heartbeat_response"}

    # Regular input: enqueue and acknowledge right away.
    await session['input_queue'].put(data)
    queue_size = session['input_queue'].qsize()
    logger.info(f"Added input to queue for session {session_id}. Queue size: {queue_size}")
    return {"type": "queued", "queue_size": queue_size}
471
+
472
async def _process_single_input(self, session_id: str, data: dict):
    """Process a single input for a session (the actual processing logic).

    Runs one full model step for the given input event, updates the
    session's latent/hidden state, and posts the rendered frame (base64
    PNG) back to the dispatcher. Errors are reported to the dispatcher
    instead of being raised to the caller.
    """
    session = self.session_data[session_id]

    # Process regular input
    try:
        session['frame_num'] += 1

        # Extract input data, clamping coordinates to the screen.
        x = max(0, min(data.get("x", 0), self.SCREEN_WIDTH - 1))
        y = max(0, min(data.get("y", 0), self.SCREEN_HEIGHT - 1))
        is_left_click = data.get("is_left_click", False)
        is_right_click = data.get("is_right_click", False)
        keys_down_list = data.get("keys_down", [])
        keys_up_list = data.get("keys_up", [])

        # Update the set of currently-held keys from press events...
        for key in keys_down_list:
            key = key.lower()
            if key in self.KEYMAPPING:
                key = self.KEYMAPPING[key]
            session['keys_down'].add(key)

        # ...and release events (discard is a no-op for unknown keys).
        for key in keys_up_list:
            key = key.lower()
            if key in self.KEYMAPPING:
                key = self.KEYMAPPING[key]
            session['keys_down'].discard(key)

        # Handle debug modes: drop history to force stateless generation.
        if self.DEBUG_MODE:
            logger.info("DEBUG MODE, REMOVING HIDDEN STATES")
            session['previous_frame'] = self.padding_image

        if self.DEBUG_MODE_2:
            if session['frame_num'] > self.NUM_MAX_FRAMES-1:
                logger.info("DEBUG MODE_2, REMOVING HIDDEN STATES")
                session['previous_frame'] = self.padding_image
                session['frame_num'] = 0

        # Prepare model inputs
        inputs = self.prepare_model_inputs(
            session['previous_frame'],
            session['hidden_states'],
            x, y, is_right_click, is_left_click,
            list(session['keys_down']),
            session['frame_num']
        )

        # Process frame with this session's current client settings.
        logger.info(f"Processing frame {session['frame_num']} for session {session_id}")
        sample_latent, sample_img, hidden_states, timing_info = await self.process_frame(
            inputs,
            use_rnn=session['client_settings']['use_rnn'],
            num_sampling_steps=session['client_settings']['sampling_steps']
        )

        # Update session state so the next frame continues from this one.
        session['previous_frame'] = sample_latent
        session['hidden_states'] = hidden_states

        # Convert image to a base64-encoded PNG for the client.
        img = Image.fromarray(sample_img)
        buffered = io.BytesIO()
        img.save(buffered, format="PNG")
        img_str = base64.b64encode(buffered.getvalue()).decode()

        # Log timing
        logger.info(f"Frame {session['frame_num']} processed in {timing_info['total']:.4f}s (FPS: {1.0/timing_info['total']:.2f})")

        # Send result back to dispatcher
        await self._send_result_to_dispatcher(session_id, {"image": img_str})

    except Exception as e:
        logger.error(f"Error processing input for session {session_id}: {e}")
        import traceback
        traceback.print_exc()
        await self._send_result_to_dispatcher(session_id, {"type": "error", "message": str(e)})
550
+
551
async def _send_result_to_dispatcher(self, session_id: str, result: dict):
    """POST a processing result for *session_id* back to the dispatcher.

    Delivery is best-effort: failures are logged and swallowed so a
    dispatcher hiccup never crashes the processing pipeline.
    """
    payload = {
        "session_id": session_id,
        "worker_id": self.worker_id,
        "result": result,
    }
    try:
        async with aiohttp.ClientSession() as client_session:
            await client_session.post(f"{self.dispatcher_url}/worker_result", json=payload)
    except Exception as e:
        logger.error(f"Failed to send result to dispatcher: {e}")
562
+
563
# FastAPI app for the worker.
# Exposes /process_input, /end_session, and /health for the dispatcher.
app = FastAPI()

# Global worker instance.
# Populated by startup_worker() at application startup; None until then,
# which the endpoints below treat as "not yet initialized".
worker: Optional[GPUWorker] = None
568
+
569
@app.post("/process_input")
async def process_input_endpoint(request: dict):
    """Forward an input payload from the dispatcher to the worker."""
    if not worker:
        raise HTTPException(status_code=500, detail="Worker not initialized")

    session_id = request.get("session_id")
    data = request.get("data")

    # Both fields are mandatory; reject the request otherwise.
    if session_id and data:
        return await worker.process_input(session_id, data)
    raise HTTPException(status_code=400, detail="Missing session_id or data")
583
+
584
@app.post("/end_session")
async def end_session_endpoint(request: dict):
    """Tear down the named session on this worker."""
    if not worker:
        raise HTTPException(status_code=500, detail="Worker not initialized")

    sid = request.get("session_id")
    if sid:
        worker.end_session(sid)
        return {"status": "session_ended"}
    raise HTTPException(status_code=400, detail="Missing session_id")
596
+
597
@app.get("/health")
async def health_check():
    """Health check endpoint: reports worker identity and current session."""
    w = worker
    if w:
        return {
            "status": "healthy",
            "worker_id": w.worker_id,
            "gpu_id": w.gpu_id,
            "current_session": w.current_session,
        }
    # Worker not initialized yet: still report healthy, with null identity.
    return {
        "status": "healthy",
        "worker_id": None,
        "gpu_id": None,
        "current_session": None,
    }
606
+
607
async def startup_worker(gpu_id: int, dispatcher_url: str):
    """Initialize the worker.

    Creates the global GPUWorker bound to *gpu_id*, registers it with
    the dispatcher at *dispatcher_url*, and starts the background ping
    loop that keeps the registration alive.
    """
    global worker
    worker = GPUWorker(gpu_id, dispatcher_url)

    # Register with dispatcher
    await worker.register_with_dispatcher()

    # Start ping task (runs for the lifetime of the process; the task
    # handle is intentionally not kept).
    asyncio.create_task(worker.ping_dispatcher())
617
+
618
if __name__ == "__main__":
    import uvicorn

    # Parse command line arguments
    parser = argparse.ArgumentParser(description="GPU Worker for Neural OS")
    parser.add_argument("--gpu-id", type=int, required=True, help="GPU ID to use")
    parser.add_argument("--dispatcher-url", type=str, default="http://localhost:8000", help="Dispatcher URL")
    args = parser.parse_args()

    # One port per GPU so multiple workers can run on the same host.
    port = 8001 + args.gpu_id

    # Registered here (not at module level) so the handler can close over
    # the parsed args. NOTE(review): @app.on_event is deprecated in newer
    # FastAPI in favor of lifespan handlers — confirm installed version
    # before migrating.
    @app.on_event("startup")
    async def startup_event():
        await startup_worker(args.gpu_id, args.dispatcher_url)

    logger.info(f"Starting worker on GPU {args.gpu_id}, port {port}")
    uvicorn.run(app, host="0.0.0.0", port=port)