Spaces:
Runtime error
Runtime error
da03
committed on
Commit
·
8b76adf
1
Parent(s):
4036b4b
- dispatcher.py +282 -4
- main.py +10 -2
- start_system.sh +3 -0
- static/index.html +30 -3
- worker.py +103 -8
dispatcher.py
CHANGED
@@ -11,11 +11,157 @@ from enum import Enum
|
|
11 |
import uuid
|
12 |
import aiohttp
|
13 |
import logging
|
|
|
|
|
14 |
|
15 |
# Configure logging
|
16 |
logging.basicConfig(level=logging.INFO)
|
17 |
logger = logging.getLogger(__name__)
|
18 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
19 |
class SessionStatus(Enum):
|
20 |
QUEUED = "queued"
|
21 |
ACTIVE = "active"
|
@@ -33,6 +179,9 @@ class UserSession:
|
|
33 |
last_activity: Optional[float] = None
|
34 |
max_session_time: Optional[float] = None
|
35 |
user_has_interacted: bool = False
|
|
|
|
|
|
|
36 |
|
37 |
@dataclass
|
38 |
class WorkerInfo:
|
@@ -67,6 +216,15 @@ class SessionManager:
|
|
67 |
last_ping=time.time()
|
68 |
)
|
69 |
logger.info(f"Registered worker {worker_id} on GPU {gpu_id} at {endpoint}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
70 |
|
71 |
async def get_available_worker(self) -> Optional[WorkerInfo]:
|
72 |
"""Get an available worker"""
|
@@ -80,6 +238,7 @@ class SessionManager:
|
|
80 |
self.sessions[session.session_id] = session
|
81 |
self.session_queue.append(session.session_id)
|
82 |
session.status = SessionStatus.QUEUED
|
|
|
83 |
logger.info(f"Added session {session.session_id} to queue. Queue size: {len(self.session_queue)}")
|
84 |
|
85 |
async def process_queue(self):
|
@@ -94,8 +253,15 @@ class SessionManager:
|
|
94 |
|
95 |
worker = await self.get_available_worker()
|
96 |
if not worker:
|
|
|
|
|
|
|
97 |
break # No available workers
|
98 |
|
|
|
|
|
|
|
|
|
99 |
# Assign session to worker
|
100 |
self.session_queue.pop(0)
|
101 |
session.status = SessionStatus.ACTIVE
|
@@ -112,6 +278,28 @@ class SessionManager:
|
|
112 |
|
113 |
logger.info(f"Assigned session {session_id} to worker {worker.worker_id}")
|
114 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
115 |
# Notify user that their session is starting
|
116 |
await self.notify_session_start(session)
|
117 |
|
@@ -199,12 +387,25 @@ class SessionManager:
|
|
199 |
|
200 |
session.status = status
|
201 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
202 |
# Free up the worker
|
203 |
if session.worker_id and session.worker_id in self.workers:
|
204 |
worker = self.workers[session.worker_id]
|
205 |
worker.is_available = True
|
206 |
worker.current_session = None
|
207 |
|
|
|
|
|
|
|
|
|
|
|
|
|
208 |
# Notify worker to clean up
|
209 |
try:
|
210 |
async with aiohttp.ClientSession() as client_session:
|
@@ -241,6 +442,11 @@ class SessionManager:
|
|
241 |
})
|
242 |
except Exception as e:
|
243 |
logger.error(f"Failed to send queue update to session {session_id}: {e}")
|
|
|
|
|
|
|
|
|
|
|
244 |
|
245 |
def _calculate_dynamic_wait_time(self, position_in_queue: int) -> float:
|
246 |
"""Calculate dynamic estimated wait time based on current session progress"""
|
@@ -308,6 +514,7 @@ class SessionManager:
|
|
308 |
session = self.sessions.get(session_id)
|
309 |
if session:
|
310 |
session.last_activity = time.time()
|
|
|
311 |
if not session.user_has_interacted:
|
312 |
session.user_has_interacted = True
|
313 |
logger.info(f"User started interacting in session {session_id}")
|
@@ -335,6 +542,9 @@ session_manager = SessionManager()
|
|
335 |
app = FastAPI()
|
336 |
app.mount("/static", StaticFiles(directory="static"), name="static")
|
337 |
|
|
|
|
|
|
|
338 |
@app.get("/")
|
339 |
async def get():
|
340 |
return HTMLResponse(open("static/index.html").read())
|
@@ -383,21 +593,39 @@ async def worker_result(result_data: dict):
|
|
383 |
|
384 |
@app.websocket("/ws")
|
385 |
async def websocket_endpoint(websocket: WebSocket):
|
|
|
386 |
await websocket.accept()
|
387 |
|
388 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
389 |
session_id = str(uuid.uuid4())
|
390 |
-
client_id = f"{int(time.time())}_{
|
391 |
|
392 |
session = UserSession(
|
393 |
session_id=session_id,
|
394 |
client_id=client_id,
|
395 |
websocket=websocket,
|
396 |
created_at=time.time(),
|
397 |
-
status=SessionStatus.QUEUED
|
|
|
398 |
)
|
399 |
|
400 |
-
logger.info(f"New WebSocket connection: {client_id}")
|
|
|
|
|
|
|
401 |
|
402 |
try:
|
403 |
# Add to queue
|
@@ -492,10 +720,60 @@ async def periodic_queue_update():
|
|
492 |
except Exception as e:
|
493 |
logger.error(f"Error in periodic queue update: {e}")
|
494 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
495 |
@app.on_event("startup")
|
496 |
async def startup_event():
|
497 |
# Start background tasks
|
498 |
asyncio.create_task(periodic_queue_update())
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
499 |
|
500 |
if __name__ == "__main__":
|
501 |
import uvicorn
|
|
|
11 |
import uuid
|
12 |
import aiohttp
|
13 |
import logging
|
14 |
+
from collections import defaultdict, deque
|
15 |
+
from datetime import datetime
|
16 |
|
17 |
# Configure logging
|
18 |
logging.basicConfig(level=logging.INFO)
|
19 |
logger = logging.getLogger(__name__)
|
20 |
|
21 |
+
# Analytics and monitoring
|
22 |
+
class SystemAnalytics:
    """Collect dispatcher usage statistics and mirror them to a log file.

    Each ``log_*`` method updates the in-memory counters / rolling sample
    windows and writes human-readable lines through ``_write_log``, which
    prints them and appends them to ``self.log_file``.
    """

    def __init__(self):
        self.start_time = time.time()
        self.total_connections = 0
        self.active_connections = 0
        self.total_interactions = 0
        self.ip_addresses = defaultdict(int)  # IP -> connection count
        self.session_durations = deque(maxlen=100)  # Last 100 session durations
        self.waiting_times = deque(maxlen=100)  # Last 100 waiting times
        self.users_bypassed_queue = 0  # Users who got GPU immediately
        self.users_waited_in_queue = 0  # Users who had to wait
        self.gpu_utilization_samples = deque(maxlen=100)  # GPU utilization over time
        self.queue_size_samples = deque(maxlen=100)  # Queue size over time
        self.log_file = None
        self._init_log_file()

    @staticmethod
    def _mean(samples):
        """Return the arithmetic mean of *samples*, or 0 when there are none."""
        return sum(samples) / len(samples) if samples else 0

    @staticmethod
    def _percent(part, whole):
        """Return *part* as a percentage of *whole*, or 0 when *whole* is not positive."""
        return (part / whole) * 100 if whole > 0 else 0

    def _init_log_file(self):
        """Pick a timestamped log file name and write the header banner to it."""
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        self.log_file = f"system_analytics_{timestamp}.log"
        self._write_log("=" * 80)
        self._write_log("NEURAL OS MULTI-GPU SYSTEM ANALYTICS")
        self._write_log("=" * 80)
        self._write_log(f"System started at: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
        self._write_log("")

    def _write_log(self, message):
        """Print *message* with an ``[HH:MM:SS]`` prefix and append it to the log file."""
        timestamp = datetime.now().strftime("%H:%M:%S")
        log_message = f"[{timestamp}] {message}"
        print(log_message)
        # The messages contain emoji, so the encoding is pinned to UTF-8;
        # relying on the locale default can raise UnicodeEncodeError.
        with open(self.log_file, "a", encoding="utf-8") as f:
            f.write(log_message + "\n")

    def log_new_connection(self, client_id: str, ip: str):
        """Count a new connection from *ip* and log the running totals."""
        self.total_connections += 1
        self.active_connections += 1
        self.ip_addresses[ip] += 1

        unique_ips = len(self.ip_addresses)
        self._write_log(f"🔗 NEW CONNECTION: {client_id} from {ip}")
        self._write_log(f" 📊 Total connections: {self.total_connections} | Active: {self.active_connections} | Unique IPs: {unique_ips}")

    def log_connection_closed(self, client_id: str, duration: float, interactions: int, reason: str = "normal"):
        """Record a closed connection: its duration, interaction count and close reason."""
        self.active_connections -= 1
        self.total_interactions += interactions
        self.session_durations.append(duration)

        avg_duration = self._mean(self.session_durations)

        self._write_log(f"🚪 CONNECTION CLOSED: {client_id}")
        self._write_log(f" ⏱️ Duration: {duration:.1f}s | Interactions: {interactions} | Reason: {reason}")
        self._write_log(f" 📊 Active connections: {self.active_connections} | Avg session duration: {avg_duration:.1f}s")

    def log_queue_bypass(self, client_id: str):
        """Record that a user got a GPU immediately, without waiting in the queue."""
        self.users_bypassed_queue += 1
        bypass_rate = self._percent(self.users_bypassed_queue, self.total_connections)
        self._write_log(f"⚡ QUEUE BYPASS: {client_id} got GPU immediately")
        self._write_log(f" 📊 Bypass rate: {bypass_rate:.1f}% ({self.users_bypassed_queue}/{self.total_connections})")

    def log_queue_wait(self, client_id: str, wait_time: float, queue_position: int):
        """Record that a user waited *wait_time* seconds in the queue before being served."""
        self.users_waited_in_queue += 1
        self.waiting_times.append(wait_time)

        avg_wait = self._mean(self.waiting_times)
        wait_rate = self._percent(self.users_waited_in_queue, self.total_connections)

        self._write_log(f"⏳ QUEUE WAIT: {client_id} waited {wait_time:.1f}s (was #{queue_position})")
        self._write_log(f" 📊 Wait rate: {wait_rate:.1f}% | Avg wait time: {avg_wait:.1f}s")

    def log_gpu_status(self, total_gpus: int, active_gpus: int, available_gpus: int):
        """Sample the current GPU utilization (active / total) and log it."""
        utilization = self._percent(active_gpus, total_gpus)
        self.gpu_utilization_samples.append(utilization)

        avg_utilization = self._mean(self.gpu_utilization_samples)

        self._write_log(f"🖥️ GPU STATUS: {active_gpus}/{total_gpus} in use ({utilization:.1f}% utilization)")
        self._write_log(f" 📊 Available: {available_gpus} | Avg utilization: {avg_utilization:.1f}%")

    def log_worker_registered(self, worker_id: str, gpu_id: int, endpoint: str):
        """Log that a worker registered."""
        self._write_log(f"⚙️ WORKER REGISTERED: {worker_id} (GPU {gpu_id}) at {endpoint}")

    def log_worker_disconnected(self, worker_id: str, gpu_id: int):
        """Log that a worker disconnected."""
        self._write_log(f"⚙️ WORKER DISCONNECTED: {worker_id} (GPU {gpu_id})")

    def log_no_workers_available(self, queue_size: int):
        """Log the critical situation where no workers are available."""
        self._write_log(f"⚠️ CRITICAL: No GPU workers available! {queue_size} users waiting")
        self._write_log(" Please check worker processes and GPU availability")

    def log_queue_status(self, queue_size: int, estimated_wait: float):
        """Sample the queue size; the status lines are written only for a non-empty queue."""
        self.queue_size_samples.append(queue_size)

        if queue_size > 0:
            avg_queue_size = self._mean(self.queue_size_samples)
            self._write_log(f"📝 QUEUE STATUS: {queue_size} users waiting | Est. wait: {estimated_wait:.1f}s")
            self._write_log(f" 📊 Avg queue size: {avg_queue_size:.1f}")

    def log_periodic_summary(self):
        """Write a summary of all counters, rolling averages and the ten busiest IPs."""
        uptime_hours = (time.time() - self.start_time) / 3600

        unique_ips = len(self.ip_addresses)
        avg_duration = self._mean(self.session_durations)
        avg_wait = self._mean(self.waiting_times)
        avg_utilization = self._mean(self.gpu_utilization_samples)
        avg_queue_size = self._mean(self.queue_size_samples)
        bypass_rate = self._percent(self.users_bypassed_queue, self.total_connections)

        self._write_log("")
        self._write_log("=" * 60)
        self._write_log("📊 SYSTEM SUMMARY")
        self._write_log("=" * 60)
        self._write_log(f"⏱️ Uptime: {uptime_hours:.1f} hours")
        self._write_log(f"🔗 Connections: {self.total_connections} total | {self.active_connections} active | {unique_ips} unique IPs")
        self._write_log(f"💬 Total interactions: {self.total_interactions}")
        self._write_log(f"⚡ Queue bypass rate: {bypass_rate:.1f}% ({self.users_bypassed_queue}/{self.total_connections})")
        self._write_log(f"⏳ Avg waiting time: {avg_wait:.1f}s")
        self._write_log(f"📝 Avg queue size: {avg_queue_size:.1f}")
        self._write_log(f"🖥️ Avg GPU utilization: {avg_utilization:.1f}%")
        self._write_log(f"⏱️ Avg session duration: {avg_duration:.1f}s")
        self._write_log("")
        self._write_log("🌍 TOP IP ADDRESSES:")
        # Highest connection count first; only the top ten are listed.
        for ip, count in sorted(self.ip_addresses.items(), key=lambda x: x[1], reverse=True)[:10]:
            self._write_log(f" {ip}: {count} connections")
        self._write_log("=" * 60)
        self._write_log("")
|
161 |
+
|
162 |
+
# Initialize analytics
|
163 |
+
analytics = SystemAnalytics()
|
164 |
+
|
165 |
class SessionStatus(Enum):
|
166 |
QUEUED = "queued"
|
167 |
ACTIVE = "active"
|
|
|
179 |
last_activity: Optional[float] = None
|
180 |
max_session_time: Optional[float] = None
|
181 |
user_has_interacted: bool = False
|
182 |
+
ip_address: Optional[str] = None
|
183 |
+
interaction_count: int = 0
|
184 |
+
queue_start_time: Optional[float] = None
|
185 |
|
186 |
@dataclass
|
187 |
class WorkerInfo:
|
|
|
216 |
last_ping=time.time()
|
217 |
)
|
218 |
logger.info(f"Registered worker {worker_id} on GPU {gpu_id} at {endpoint}")
|
219 |
+
|
220 |
+
# Log worker registration
|
221 |
+
analytics.log_worker_registered(worker_id, gpu_id, endpoint)
|
222 |
+
|
223 |
+
# Log GPU status
|
224 |
+
total_gpus = len(self.workers)
|
225 |
+
active_gpus = len([w for w in self.workers.values() if not w.is_available])
|
226 |
+
available_gpus = total_gpus - active_gpus
|
227 |
+
analytics.log_gpu_status(total_gpus, active_gpus, available_gpus)
|
228 |
|
229 |
async def get_available_worker(self) -> Optional[WorkerInfo]:
|
230 |
"""Get an available worker"""
|
|
|
238 |
self.sessions[session.session_id] = session
|
239 |
self.session_queue.append(session.session_id)
|
240 |
session.status = SessionStatus.QUEUED
|
241 |
+
session.queue_start_time = time.time()
|
242 |
logger.info(f"Added session {session.session_id} to queue. Queue size: {len(self.session_queue)}")
|
243 |
|
244 |
async def process_queue(self):
|
|
|
253 |
|
254 |
worker = await self.get_available_worker()
|
255 |
if not worker:
|
256 |
+
# Log critical situation if no workers are available
|
257 |
+
if len(self.workers) == 0:
|
258 |
+
analytics.log_no_workers_available(len(self.session_queue))
|
259 |
break # No available workers
|
260 |
|
261 |
+
# Calculate wait time
|
262 |
+
wait_time = time.time() - session.queue_start_time if session.queue_start_time else 0
|
263 |
+
queue_position = self.session_queue.index(session_id) + 1
|
264 |
+
|
265 |
# Assign session to worker
|
266 |
self.session_queue.pop(0)
|
267 |
session.status = SessionStatus.ACTIVE
|
|
|
278 |
|
279 |
logger.info(f"Assigned session {session_id} to worker {worker.worker_id}")
|
280 |
|
281 |
+
# Log analytics
|
282 |
+
if wait_time > 0:
|
283 |
+
analytics.log_queue_wait(session.client_id, wait_time, queue_position)
|
284 |
+
else:
|
285 |
+
analytics.log_queue_bypass(session.client_id)
|
286 |
+
|
287 |
+
# Log GPU status
|
288 |
+
total_gpus = len(self.workers)
|
289 |
+
active_gpus = len([w for w in self.workers.values() if not w.is_available])
|
290 |
+
available_gpus = total_gpus - active_gpus
|
291 |
+
analytics.log_gpu_status(total_gpus, active_gpus, available_gpus)
|
292 |
+
|
293 |
+
# Initialize session on worker with client_id for logging
|
294 |
+
try:
|
295 |
+
async with aiohttp.ClientSession() as client_session:
|
296 |
+
await client_session.post(f"{worker.endpoint}/init_session", json={
|
297 |
+
"session_id": session_id,
|
298 |
+
"client_id": session.client_id
|
299 |
+
})
|
300 |
+
except Exception as e:
|
301 |
+
logger.error(f"Failed to initialize session on worker {worker.worker_id}: {e}")
|
302 |
+
|
303 |
# Notify user that their session is starting
|
304 |
await self.notify_session_start(session)
|
305 |
|
|
|
387 |
|
388 |
session.status = status
|
389 |
|
390 |
+
# Calculate session duration
|
391 |
+
duration = time.time() - session.created_at
|
392 |
+
|
393 |
+
# Log analytics
|
394 |
+
reason = "timeout" if status == SessionStatus.TIMEOUT else "normal"
|
395 |
+
analytics.log_connection_closed(session.client_id, duration, session.interaction_count, reason)
|
396 |
+
|
397 |
# Free up the worker
|
398 |
if session.worker_id and session.worker_id in self.workers:
|
399 |
worker = self.workers[session.worker_id]
|
400 |
worker.is_available = True
|
401 |
worker.current_session = None
|
402 |
|
403 |
+
# Log GPU status
|
404 |
+
total_gpus = len(self.workers)
|
405 |
+
active_gpus = len([w for w in self.workers.values() if not w.is_available])
|
406 |
+
available_gpus = total_gpus - active_gpus
|
407 |
+
analytics.log_gpu_status(total_gpus, active_gpus, available_gpus)
|
408 |
+
|
409 |
# Notify worker to clean up
|
410 |
try:
|
411 |
async with aiohttp.ClientSession() as client_session:
|
|
|
442 |
})
|
443 |
except Exception as e:
|
444 |
logger.error(f"Failed to send queue update to session {session_id}: {e}")
|
445 |
+
|
446 |
+
# Log queue status if there's a queue
|
447 |
+
if self.session_queue:
|
448 |
+
estimated_wait = self._calculate_dynamic_wait_time(1)
|
449 |
+
analytics.log_queue_status(len(self.session_queue), estimated_wait)
|
450 |
|
451 |
def _calculate_dynamic_wait_time(self, position_in_queue: int) -> float:
|
452 |
"""Calculate dynamic estimated wait time based on current session progress"""
|
|
|
514 |
session = self.sessions.get(session_id)
|
515 |
if session:
|
516 |
session.last_activity = time.time()
|
517 |
+
session.interaction_count += 1
|
518 |
if not session.user_has_interacted:
|
519 |
session.user_has_interacted = True
|
520 |
logger.info(f"User started interacting in session {session_id}")
|
|
|
542 |
app = FastAPI()
|
543 |
app.mount("/static", StaticFiles(directory="static"), name="static")
|
544 |
|
545 |
+
# Global connection counter like in main.py
|
546 |
+
connection_counter = 0
|
547 |
+
|
548 |
@app.get("/")
|
549 |
async def get():
|
550 |
return HTMLResponse(open("static/index.html").read())
|
|
|
593 |
|
594 |
@app.websocket("/ws")
|
595 |
async def websocket_endpoint(websocket: WebSocket):
|
596 |
+
global connection_counter
|
597 |
await websocket.accept()
|
598 |
|
599 |
+
# Extract client IP address
|
600 |
+
client_ip = "unknown"
|
601 |
+
if websocket.client and hasattr(websocket.client, 'host'):
|
602 |
+
client_ip = websocket.client.host
|
603 |
+
elif hasattr(websocket, 'headers'):
|
604 |
+
# Try to get real IP from headers (for proxy setups)
|
605 |
+
client_ip = websocket.headers.get('x-forwarded-for',
|
606 |
+
websocket.headers.get('x-real-ip',
|
607 |
+
websocket.headers.get('cf-connecting-ip', 'unknown')))
|
608 |
+
if ',' in client_ip:
|
609 |
+
client_ip = client_ip.split(',')[0].strip()
|
610 |
+
|
611 |
+
# Create session with connection counter like in main.py
|
612 |
+
connection_counter += 1
|
613 |
session_id = str(uuid.uuid4())
|
614 |
+
client_id = f"{int(time.time())}_{connection_counter}"
|
615 |
|
616 |
session = UserSession(
|
617 |
session_id=session_id,
|
618 |
client_id=client_id,
|
619 |
websocket=websocket,
|
620 |
created_at=time.time(),
|
621 |
+
status=SessionStatus.QUEUED,
|
622 |
+
ip_address=client_ip
|
623 |
)
|
624 |
|
625 |
+
logger.info(f"New WebSocket connection: {client_id} from {client_ip}")
|
626 |
+
|
627 |
+
# Log new connection analytics
|
628 |
+
analytics.log_new_connection(client_id, client_ip)
|
629 |
|
630 |
try:
|
631 |
# Add to queue
|
|
|
720 |
except Exception as e:
|
721 |
logger.error(f"Error in periodic queue update: {e}")
|
722 |
|
723 |
+
# Background task to periodically log analytics summary
|
724 |
+
async def periodic_analytics_summary():
    """Background loop that writes the analytics summary every five minutes.

    An exception raised during an iteration is logged and the loop carries on
    with the next one.
    """
    summary_interval_seconds = 5 * 60
    while True:
        try:
            await asyncio.sleep(summary_interval_seconds)
            analytics.log_periodic_summary()
        except Exception as exc:
            logger.error(f"Error in periodic analytics summary: {exc}")
|
731 |
+
|
732 |
+
# Background task to check worker health
|
733 |
+
async def periodic_worker_health_check():
    """Background loop that drops workers whose last ping has gone stale.

    Once a minute, every worker in ``session_manager.workers`` whose
    ``last_ping`` is more than 30 seconds old is reported to analytics and
    removed. If anything was removed, the refreshed GPU status is logged too.
    An exception raised during an iteration is logged and the loop continues.
    """
    check_interval_seconds = 60
    ping_timeout_seconds = 30
    while True:
        try:
            await asyncio.sleep(check_interval_seconds)
            now = time.time()

            # Snapshot the registry first so entries can be deleted below.
            stale = [
                (wid, info.gpu_id)
                for wid, info in list(session_manager.workers.items())
                if now - info.last_ping > ping_timeout_seconds
            ]

            for wid, gpu in stale:
                analytics.log_worker_disconnected(wid, gpu)
                del session_manager.workers[wid]
                logger.warning(f"Removed disconnected worker {wid} (GPU {gpu})")

            if not stale:
                continue

            # Something was removed: log the updated GPU status.
            registered = session_manager.workers
            total_gpus = len(registered)
            busy_gpus = sum(1 for info in registered.values() if not info.is_available)
            analytics.log_gpu_status(total_gpus, busy_gpus, total_gpus - busy_gpus)

        except Exception as exc:
            logger.error(f"Error in periodic worker health check: {exc}")
|
758 |
+
|
759 |
# Strong references to the long-running background tasks. The event loop keeps
# only weak references to tasks, so a task nobody else references can be
# garbage-collected before it finishes.
_background_tasks = set()


@app.on_event("startup")
async def startup_event():
    """Start the dispatcher's background loops and log that the system is ready."""
    for background_job in (
        periodic_queue_update,
        periodic_analytics_summary,
        periodic_worker_health_check,
    ):
        task = asyncio.create_task(background_job())
        _background_tasks.add(task)
        # Drop the reference once the task ends so the set does not grow.
        task.add_done_callback(_background_tasks.discard)

    # Log initial system status
    analytics._write_log("🚀 System initialized and ready to accept connections")
    analytics._write_log(" Waiting for GPU workers to register...")
|
769 |
+
|
770 |
+
@app.on_event("shutdown")
async def shutdown_event():
    """Write the shutdown banner, a final analytics summary, and a closing line."""
    for banner_line in ("", "🛑 System shutting down..."):
        analytics._write_log(banner_line)
    analytics.log_periodic_summary()
    analytics._write_log("System shutdown complete.")
|
777 |
|
778 |
if __name__ == "__main__":
|
779 |
import uvicorn
|
main.py
CHANGED
@@ -526,7 +526,9 @@ async def websocket_endpoint(websocket: WebSocket):
|
|
526 |
if not user_has_interacted:
|
527 |
user_has_interacted = True
|
528 |
print(f"[{time.perf_counter():.3f}] User has started interacting with canvas for client {client_id}")
|
529 |
-
|
|
|
|
|
530 |
|
531 |
# Update the set based on the received data
|
532 |
for key in keys_down_list:
|
@@ -649,7 +651,9 @@ async def websocket_endpoint(websocket: WebSocket):
|
|
649 |
is_interesting = (current_input.get("is_left_click") or
|
650 |
current_input.get("is_right_click") or
|
651 |
(current_input.get("keys_down") and len(current_input.get("keys_down")) > 0) or
|
652 |
-
(current_input.get("keys_up") and len(current_input.get("keys_up")) > 0)
|
|
|
|
|
653 |
|
654 |
# Process immediately if interesting
|
655 |
if is_interesting:
|
@@ -802,6 +806,8 @@ def log_interaction(client_id, data, generated_frame=None, is_end_of_session=Fal
|
|
802 |
"is_right_click": data.get("is_right_click"),
|
803 |
"keys_down": data.get("keys_down", []),
|
804 |
"keys_up": data.get("keys_up", []),
|
|
|
|
|
805 |
"is_auto_input": data.get("is_auto_input", False)
|
806 |
}
|
807 |
else:
|
@@ -809,6 +815,8 @@ def log_interaction(client_id, data, generated_frame=None, is_end_of_session=Fal
|
|
809 |
log_entry["inputs"] = None
|
810 |
|
811 |
# Save to a file (one file per session)
|
|
|
|
|
812 |
session_file = f"interaction_logs/session_{client_id}.jsonl"
|
813 |
with open(session_file, "a") as f:
|
814 |
f.write(json.dumps(log_entry) + "\n")
|
|
|
526 |
if not user_has_interacted:
|
527 |
user_has_interacted = True
|
528 |
print(f"[{time.perf_counter():.3f}] User has started interacting with canvas for client {client_id}")
|
529 |
+
wheel_delta_x = data.get("wheel_delta_x", 0)
|
530 |
+
wheel_delta_y = data.get("wheel_delta_y", 0)
|
531 |
+
print(f'[{time.perf_counter():.3f}] Processing: x: {x}, y: {y}, is_left_click: {is_left_click}, is_right_click: {is_right_click}, keys_down_list: {keys_down_list}, keys_up_list: {keys_up_list}, wheel: ({wheel_delta_x},{wheel_delta_y}), time_since_activity: {time.perf_counter() - last_user_activity_time:.3f}')
|
532 |
|
533 |
# Update the set based on the received data
|
534 |
for key in keys_down_list:
|
|
|
651 |
is_interesting = (current_input.get("is_left_click") or
|
652 |
current_input.get("is_right_click") or
|
653 |
(current_input.get("keys_down") and len(current_input.get("keys_down")) > 0) or
|
654 |
+
(current_input.get("keys_up") and len(current_input.get("keys_up")) > 0) or
|
655 |
+
current_input.get("wheel_delta_x", 0) != 0 or
|
656 |
+
current_input.get("wheel_delta_y", 0) != 0)
|
657 |
|
658 |
# Process immediately if interesting
|
659 |
if is_interesting:
|
|
|
806 |
"is_right_click": data.get("is_right_click"),
|
807 |
"keys_down": data.get("keys_down", []),
|
808 |
"keys_up": data.get("keys_up", []),
|
809 |
+
"wheel_delta_x": data.get("wheel_delta_x", 0),
|
810 |
+
"wheel_delta_y": data.get("wheel_delta_y", 0),
|
811 |
"is_auto_input": data.get("is_auto_input", False)
|
812 |
}
|
813 |
else:
|
|
|
815 |
log_entry["inputs"] = None
|
816 |
|
817 |
# Save to a file (one file per session)
|
818 |
+
if not os.path.exists("interaction_logs"):
|
819 |
+
os.makedirs("interaction_logs", exist_ok=True)
|
820 |
session_file = f"interaction_logs/session_{client_id}.jsonl"
|
821 |
with open(session_file, "a") as f:
|
822 |
f.write(json.dumps(log_entry) + "\n")
|
start_system.sh
CHANGED
@@ -61,6 +61,7 @@ echo "========================================"
|
|
61 |
echo "📊 Number of GPUs: $NUM_GPUS"
|
62 |
echo "🌐 Dispatcher port: $DISPATCHER_PORT"
|
63 |
echo "💻 Worker ports: $(seq -s', ' 8001 $((8000 + NUM_GPUS)))"
|
|
|
64 |
echo ""
|
65 |
|
66 |
# Check if required files exist
|
@@ -130,12 +131,14 @@ for ((i=0; i<NUM_GPUS; i++)); do
|
|
130 |
done
|
131 |
echo ""
|
132 |
echo "📋 Log files:"
|
|
|
133 |
echo " Dispatcher: dispatcher.log"
|
134 |
echo " Workers summary: workers.log"
|
135 |
for ((i=0; i<NUM_GPUS; i++)); do
|
136 |
echo " GPU $i worker: worker_gpu_$i.log"
|
137 |
done
|
138 |
echo ""
|
|
|
139 |
echo "Press Ctrl+C to stop the system"
|
140 |
echo "================================"
|
141 |
|
|
|
61 |
echo "📊 Number of GPUs: $NUM_GPUS"
|
62 |
echo "🌐 Dispatcher port: $DISPATCHER_PORT"
|
63 |
echo "💻 Worker ports: $(seq -s', ' 8001 $((8000 + NUM_GPUS)))"
|
64 |
+
echo "📈 Analytics logging: system_analytics_$(date +%Y%m%d_%H%M%S).log"
|
65 |
echo ""
|
66 |
|
67 |
# Check if required files exist
|
|
|
131 |
done
|
132 |
echo ""
|
133 |
echo "📋 Log files:"
|
134 |
+
echo " System analytics: system_analytics_*.log (real-time monitoring)"
|
135 |
echo " Dispatcher: dispatcher.log"
|
136 |
echo " Workers summary: workers.log"
|
137 |
for ((i=0; i<NUM_GPUS; i++)); do
|
138 |
echo " GPU $i worker: worker_gpu_$i.log"
|
139 |
done
|
140 |
echo ""
|
141 |
+
echo "💡 Monitor system in real-time: tail -f system_analytics_*.log"
|
142 |
echo "Press Ctrl+C to stop the system"
|
143 |
echo "================================"
|
144 |
|
static/index.html
CHANGED
@@ -414,6 +414,8 @@
|
|
414 |
"is_right_click": false,
|
415 |
"keys_down": [],
|
416 |
"keys_up": [],
|
|
|
|
|
417 |
"is_auto_input": true // Flag to identify auto-generated inputs
|
418 |
}));
|
419 |
lastAutoInputTime = currentTime;
|
@@ -531,7 +533,9 @@
|
|
531 |
"is_left_click": false,
|
532 |
"is_right_click": false,
|
533 |
"keys_down": [],
|
534 |
-
"keys_up": []
|
|
|
|
|
535 |
}));
|
536 |
updateLastUserInputTime(); // Update for auto-input mechanism
|
537 |
} catch (error) {
|
@@ -541,9 +545,9 @@
|
|
541 |
stopTimeoutCountdown();
|
542 |
}
|
543 |
|
544 |
-
function sendInputState(x, y, isLeftClick = false, isRightClick = false, keysDownArr = [], keysUpArr = []) {
|
545 |
const currentTime = Date.now();
|
546 |
-
if (isConnected && socket.readyState === WebSocket.OPEN && (isLeftClick || isRightClick || keysDownArr.length > 0 || keysUpArr.length > 0 || !lastSentPosition || currentTime - lastSentTime >= SEND_INTERVAL)) {
|
547 |
try {
|
548 |
socket.send(JSON.stringify({
|
549 |
"x": x,
|
@@ -552,6 +556,8 @@
|
|
552 |
"is_right_click": isRightClick,
|
553 |
"keys_down": keysDownArr,
|
554 |
"keys_up": keysUpArr,
|
|
|
|
|
555 |
}));
|
556 |
lastSentPosition = { x, y };
|
557 |
lastSentTime = currentTime;
|
@@ -638,6 +644,27 @@
|
|
638 |
sendInputState(x, y, false, true);
|
639 |
});
|
640 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
641 |
// Track keyboard events
|
642 |
const TROUBLESOME = new Set([
|
643 |
"Tab", // focus change
|
|
|
414 |
"is_right_click": false,
|
415 |
"keys_down": [],
|
416 |
"keys_up": [],
|
417 |
+
"wheel_delta_x": 0,
|
418 |
+
"wheel_delta_y": 0,
|
419 |
"is_auto_input": true // Flag to identify auto-generated inputs
|
420 |
}));
|
421 |
lastAutoInputTime = currentTime;
|
|
|
533 |
"is_left_click": false,
|
534 |
"is_right_click": false,
|
535 |
"keys_down": [],
|
536 |
+
"keys_up": [],
|
537 |
+
"wheel_delta_x": 0,
|
538 |
+
"wheel_delta_y": 0
|
539 |
}));
|
540 |
updateLastUserInputTime(); // Update for auto-input mechanism
|
541 |
} catch (error) {
|
|
|
545 |
stopTimeoutCountdown();
|
546 |
}
|
547 |
|
548 |
+
function sendInputState(x, y, isLeftClick = false, isRightClick = false, keysDownArr = [], keysUpArr = [], wheelDeltaX = 0, wheelDeltaY = 0) {
|
549 |
const currentTime = Date.now();
|
550 |
+
if (isConnected && socket.readyState === WebSocket.OPEN && (isLeftClick || isRightClick || keysDownArr.length > 0 || keysUpArr.length > 0 || wheelDeltaX !== 0 || wheelDeltaY !== 0 || !lastSentPosition || currentTime - lastSentTime >= SEND_INTERVAL)) {
|
551 |
try {
|
552 |
socket.send(JSON.stringify({
|
553 |
"x": x,
|
|
|
556 |
"is_right_click": isRightClick,
|
557 |
"keys_down": keysDownArr,
|
558 |
"keys_up": keysUpArr,
|
559 |
+
"wheel_delta_x": wheelDeltaX,
|
560 |
+
"wheel_delta_y": wheelDeltaY,
|
561 |
}));
|
562 |
lastSentPosition = { x, y };
|
563 |
lastSentTime = currentTime;
|
|
|
644 |
sendInputState(x, y, false, true);
|
645 |
});
|
646 |
|
647 |
+
// Handle mouse wheel events
|
648 |
+
canvas.addEventListener("wheel", function (event) {
|
649 |
+
event.preventDefault(); // Prevent page scrolling
|
650 |
+
if (!isConnected || isProcessing) return;
|
651 |
+
|
652 |
+
let rect = canvas.getBoundingClientRect();
|
653 |
+
let x = event.clientX - rect.left;
|
654 |
+
let y = event.clientY - rect.top;
|
655 |
+
|
656 |
+
// Normalize wheel delta values (different browsers handle this differently)
|
657 |
+
let deltaX = event.deltaX;
|
658 |
+
let deltaY = event.deltaY;
|
659 |
+
|
660 |
+
// Clamp values to reasonable range
|
661 |
+
//deltaX = Math.max(-10, Math.min(10, deltaX));
|
662 |
+
//deltaY = Math.max(-10, Math.min(10, deltaY));
|
663 |
+
|
664 |
+
console.log(`Wheel event: deltaX=${deltaX}, deltaY=${deltaY} at (${x}, ${y})`);
|
665 |
+
sendInputState(x, y, false, false, [], [], deltaX, deltaY);
|
666 |
+
});
|
667 |
+
|
668 |
// Track keyboard events
|
669 |
const TROUBLESOME = new Set([
|
670 |
"Tab", // focus change
|
worker.py
CHANGED
@@ -293,9 +293,17 @@ class GPUWorker:
|
|
293 |
|
294 |
return sample_latent, sample_img, hidden_states, timing
|
295 |
|
296 |
-
def initialize_session(self, session_id: str):
|
297 |
"""Initialize a new session"""
|
298 |
self.current_session = session_id
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
299 |
self.session_data[session_id] = {
|
300 |
'previous_frame': self.padding_image,
|
301 |
'hidden_states': None,
|
@@ -306,9 +314,10 @@ class GPUWorker:
|
|
306 |
'sampling_steps': self.NUM_SAMPLING_STEPS
|
307 |
},
|
308 |
'input_queue': asyncio.Queue(),
|
309 |
-
'is_processing': False
|
|
|
310 |
}
|
311 |
-
logger.info(f"Initialized session {session_id}")
|
312 |
|
313 |
# Start processing task for this session
|
314 |
asyncio.create_task(self._process_session_queue(session_id))
|
@@ -316,8 +325,12 @@ class GPUWorker:
|
|
316 |
def end_session(self, session_id: str):
|
317 |
"""End a session and clean up"""
|
318 |
if session_id in self.session_data:
|
319 |
-
#
|
320 |
session = self.session_data[session_id]
|
|
|
|
|
|
|
|
|
321 |
while not session['input_queue'].empty():
|
322 |
try:
|
323 |
session['input_queue'].get_nowait()
|
@@ -391,7 +404,9 @@ class GPUWorker:
|
|
391 |
is_interesting = (current_input.get("is_left_click") or
|
392 |
current_input.get("is_right_click") or
|
393 |
(current_input.get("keys_down") and len(current_input.get("keys_down")) > 0) or
|
394 |
-
(current_input.get("keys_up") and len(current_input.get("keys_up")) > 0)
|
|
|
|
|
395 |
|
396 |
# Process immediately if interesting
|
397 |
if is_interesting:
|
@@ -416,13 +431,17 @@ class GPUWorker:
|
|
416 |
async def process_input(self, session_id: str, data: dict) -> dict:
|
417 |
"""Process input for a session - adds to queue or handles control messages"""
|
418 |
if session_id not in self.session_data:
|
419 |
-
self.initialize_session(session_id)
|
420 |
|
421 |
session = self.session_data[session_id]
|
422 |
|
423 |
# Handle control messages immediately (don't queue these)
|
424 |
if data.get("type") == "reset":
|
425 |
logger.info(f"Received reset command for session {session_id}")
|
|
|
|
|
|
|
|
|
426 |
# Clear the queue
|
427 |
while not session['input_queue'].empty():
|
428 |
try:
|
@@ -484,6 +503,8 @@ class GPUWorker:
|
|
484 |
is_right_click = data.get("is_right_click", False)
|
485 |
keys_down_list = data.get("keys_down", [])
|
486 |
keys_up_list = data.get("keys_up", [])
|
|
|
|
|
487 |
|
488 |
# Update keys_down set
|
489 |
for key in keys_down_list:
|
@@ -518,8 +539,13 @@ class GPUWorker:
|
|
518 |
session['frame_num']
|
519 |
)
|
520 |
|
|
|
|
|
|
|
|
|
|
|
|
|
521 |
# Process frame
|
522 |
-
logger.info(f"Processing frame {session['frame_num']} for session {session_id}")
|
523 |
sample_latent, sample_img, hidden_states, timing_info = await self.process_frame(
|
524 |
inputs,
|
525 |
use_rnn=session['client_settings']['use_rnn'],
|
@@ -539,6 +565,10 @@ class GPUWorker:
|
|
539 |
# Log timing
|
540 |
logger.info(f"Frame {session['frame_num']} processed in {timing_info['total']:.4f}s (FPS: {1.0/timing_info['total']:.2f})")
|
541 |
|
|
|
|
|
|
|
|
|
542 |
# Send result back to dispatcher
|
543 |
await self._send_result_to_dispatcher(session_id, {"image": img_str})
|
544 |
|
@@ -566,6 +596,55 @@ app = FastAPI()
|
|
566 |
# Global worker instance
|
567 |
worker: Optional[GPUWorker] = None
|
568 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
569 |
@app.post("/process_input")
|
570 |
async def process_input_endpoint(request: dict):
|
571 |
"""Process input from dispatcher"""
|
@@ -581,13 +660,29 @@ async def process_input_endpoint(request: dict):
|
|
581 |
result = await worker.process_input(session_id, data)
|
582 |
return result
|
583 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
584 |
@app.post("/end_session")
|
585 |
async def end_session_endpoint(request: dict):
|
586 |
-
"""End
|
587 |
if not worker:
|
588 |
raise HTTPException(status_code=500, detail="Worker not initialized")
|
589 |
|
590 |
session_id = request.get("session_id")
|
|
|
591 |
if not session_id:
|
592 |
raise HTTPException(status_code=400, detail="Missing session_id")
|
593 |
|
|
|
293 |
|
294 |
return sample_latent, sample_img, hidden_states, timing
|
295 |
|
296 |
+
def initialize_session(self, session_id: str, client_id: str = None):
|
297 |
"""Initialize a new session"""
|
298 |
self.current_session = session_id
|
299 |
+
# Use client_id from dispatcher if provided, otherwise create one
|
300 |
+
if client_id:
|
301 |
+
log_session_id = client_id
|
302 |
+
else:
|
303 |
+
# Fallback: create a time-prefixed session identifier for logging
|
304 |
+
session_start_time = int(time.time())
|
305 |
+
log_session_id = f"{session_start_time}_{session_id}"
|
306 |
+
|
307 |
self.session_data[session_id] = {
|
308 |
'previous_frame': self.padding_image,
|
309 |
'hidden_states': None,
|
|
|
314 |
'sampling_steps': self.NUM_SAMPLING_STEPS
|
315 |
},
|
316 |
'input_queue': asyncio.Queue(),
|
317 |
+
'is_processing': False,
|
318 |
+
'log_session_id': log_session_id # Store the time-prefixed ID for logging
|
319 |
}
|
320 |
+
logger.info(f"Initialized session {session_id} with log ID {log_session_id}")
|
321 |
|
322 |
# Start processing task for this session
|
323 |
asyncio.create_task(self._process_session_queue(session_id))
|
|
|
325 |
def end_session(self, session_id: str):
|
326 |
"""End a session and clean up"""
|
327 |
if session_id in self.session_data:
|
328 |
+
# Log session end using the stored log_session_id
|
329 |
session = self.session_data[session_id]
|
330 |
+
log_session_id = session.get('log_session_id', session_id) # Fallback to session_id if not found
|
331 |
+
log_interaction(log_session_id, {}, is_end_of_session=True)
|
332 |
+
|
333 |
+
# Clear any remaining items in the queue
|
334 |
while not session['input_queue'].empty():
|
335 |
try:
|
336 |
session['input_queue'].get_nowait()
|
|
|
404 |
is_interesting = (current_input.get("is_left_click") or
|
405 |
current_input.get("is_right_click") or
|
406 |
(current_input.get("keys_down") and len(current_input.get("keys_down")) > 0) or
|
407 |
+
(current_input.get("keys_up") and len(current_input.get("keys_up")) > 0) or
|
408 |
+
current_input.get("wheel_delta_x", 0) != 0 or
|
409 |
+
current_input.get("wheel_delta_y", 0) != 0)
|
410 |
|
411 |
# Process immediately if interesting
|
412 |
if is_interesting:
|
|
|
431 |
async def process_input(self, session_id: str, data: dict) -> dict:
|
432 |
"""Process input for a session - adds to queue or handles control messages"""
|
433 |
if session_id not in self.session_data:
|
434 |
+
self.initialize_session(session_id) # Fallback initialization without client_id
|
435 |
|
436 |
session = self.session_data[session_id]
|
437 |
|
438 |
# Handle control messages immediately (don't queue these)
|
439 |
if data.get("type") == "reset":
|
440 |
logger.info(f"Received reset command for session {session_id}")
|
441 |
+
# Log the reset action using the stored log_session_id
|
442 |
+
log_session_id = session.get('log_session_id', session_id) # Fallback to session_id if not found
|
443 |
+
log_interaction(log_session_id, data, is_reset=True)
|
444 |
+
|
445 |
# Clear the queue
|
446 |
while not session['input_queue'].empty():
|
447 |
try:
|
|
|
503 |
is_right_click = data.get("is_right_click", False)
|
504 |
keys_down_list = data.get("keys_down", [])
|
505 |
keys_up_list = data.get("keys_up", [])
|
506 |
+
wheel_delta_x = data.get("wheel_delta_x", 0)
|
507 |
+
wheel_delta_y = data.get("wheel_delta_y", 0)
|
508 |
|
509 |
# Update keys_down set
|
510 |
for key in keys_down_list:
|
|
|
539 |
session['frame_num']
|
540 |
)
|
541 |
|
542 |
+
# Log the input data being processed
|
543 |
+
logger.info(f"Processing frame {session['frame_num']} for session {session_id}: "
|
544 |
+
f"pos=({x},{y}), clicks=(L:{is_left_click},R:{is_right_click}), "
|
545 |
+
f"keys_down={keys_down_list}, keys_up={keys_up_list}, "
|
546 |
+
f"wheel=({wheel_delta_x},{wheel_delta_y})")
|
547 |
+
|
548 |
# Process frame
|
|
|
549 |
sample_latent, sample_img, hidden_states, timing_info = await self.process_frame(
|
550 |
inputs,
|
551 |
use_rnn=session['client_settings']['use_rnn'],
|
|
|
565 |
# Log timing
|
566 |
logger.info(f"Frame {session['frame_num']} processed in {timing_info['total']:.4f}s (FPS: {1.0/timing_info['total']:.2f})")
|
567 |
|
568 |
+
# Log the interaction using the stored log_session_id
|
569 |
+
log_session_id = session.get('log_session_id', session_id) # Fallback to session_id if not found
|
570 |
+
log_interaction(log_session_id, data, generated_frame=sample_img)
|
571 |
+
|
572 |
# Send result back to dispatcher
|
573 |
await self._send_result_to_dispatcher(session_id, {"image": img_str})
|
574 |
|
|
|
596 |
# Global worker instance
|
597 |
worker: Optional[GPUWorker] = None
|
598 |
|
599 |
+
def log_interaction(log_session_id, data, generated_frame=None, is_end_of_session=False, is_reset=False):
|
600 |
+
"""Log user interaction and optionally the generated frame."""
|
601 |
+
timestamp = time.time()
|
602 |
+
|
603 |
+
# Create directory structure if it doesn't exist
|
604 |
+
os.makedirs("interaction_logs", exist_ok=True)
|
605 |
+
|
606 |
+
# Structure the log entry
|
607 |
+
log_entry = {
|
608 |
+
"timestamp": timestamp,
|
609 |
+
"session_id": log_session_id, # Use the time-prefixed session ID
|
610 |
+
"is_eos": is_end_of_session,
|
611 |
+
"is_reset": is_reset
|
612 |
+
}
|
613 |
+
|
614 |
+
# Include type if present (for reset, etc.)
|
615 |
+
if data.get("type"):
|
616 |
+
log_entry["type"] = data.get("type")
|
617 |
+
|
618 |
+
# Only include input data if this isn't just a control message
|
619 |
+
if not is_end_of_session and not is_reset:
|
620 |
+
log_entry["inputs"] = {
|
621 |
+
"x": data.get("x"),
|
622 |
+
"y": data.get("y"),
|
623 |
+
"is_left_click": data.get("is_left_click"),
|
624 |
+
"is_right_click": data.get("is_right_click"),
|
625 |
+
"keys_down": data.get("keys_down", []),
|
626 |
+
"keys_up": data.get("keys_up", []),
|
627 |
+
"wheel_delta_x": data.get("wheel_delta_x", 0),
|
628 |
+
"wheel_delta_y": data.get("wheel_delta_y", 0),
|
629 |
+
"is_auto_input": data.get("is_auto_input", False)
|
630 |
+
}
|
631 |
+
else:
|
632 |
+
# For EOS/reset records, just include minimal info
|
633 |
+
log_entry["inputs"] = None
|
634 |
+
|
635 |
+
# Use the time-prefixed session ID for the filename (already includes timestamp)
|
636 |
+
session_file = f"interaction_logs/session_{log_session_id}.jsonl"
|
637 |
+
with open(session_file, "a") as f:
|
638 |
+
f.write(json.dumps(log_entry) + "\n")
|
639 |
+
|
640 |
+
# Optionally save the frame if provided
|
641 |
+
if generated_frame is not None and not is_end_of_session and not is_reset:
|
642 |
+
frame_dir = f"interaction_logs/frames_{log_session_id}"
|
643 |
+
os.makedirs(frame_dir, exist_ok=True)
|
644 |
+
frame_file = f"{frame_dir}/{timestamp:.6f}.png"
|
645 |
+
# Save the frame as PNG
|
646 |
+
Image.fromarray(generated_frame).save(frame_file)
|
647 |
+
|
648 |
@app.post("/process_input")
|
649 |
async def process_input_endpoint(request: dict):
|
650 |
"""Process input from dispatcher"""
|
|
|
660 |
result = await worker.process_input(session_id, data)
|
661 |
return result
|
662 |
|
663 |
+
@app.post("/init_session")
|
664 |
+
async def init_session_endpoint(request: dict):
|
665 |
+
"""Initialize session from dispatcher with client_id"""
|
666 |
+
if not worker:
|
667 |
+
raise HTTPException(status_code=500, detail="Worker not initialized")
|
668 |
+
|
669 |
+
session_id = request.get("session_id")
|
670 |
+
client_id = request.get("client_id")
|
671 |
+
|
672 |
+
if not session_id:
|
673 |
+
raise HTTPException(status_code=400, detail="Missing session_id")
|
674 |
+
|
675 |
+
worker.initialize_session(session_id, client_id)
|
676 |
+
return {"status": "session_initialized"}
|
677 |
+
|
678 |
@app.post("/end_session")
|
679 |
async def end_session_endpoint(request: dict):
|
680 |
+
"""End session from dispatcher"""
|
681 |
if not worker:
|
682 |
raise HTTPException(status_code=500, detail="Worker not initialized")
|
683 |
|
684 |
session_id = request.get("session_id")
|
685 |
+
|
686 |
if not session_id:
|
687 |
raise HTTPException(status_code=400, detail="Missing session_id")
|
688 |
|