# NOTE: hosting-page banner captured together with this file: "Spaces: Runtime error".
from fastapi import FastAPI, WebSocket, WebSocketDisconnect, HTTPException | |
from fastapi.responses import HTMLResponse | |
from fastapi.staticfiles import StaticFiles | |
from typing import List, Dict, Any, Optional | |
import asyncio | |
import json | |
import time | |
import os | |
from dataclasses import dataclass, asdict | |
from enum import Enum | |
import uuid | |
import aiohttp | |
import logging | |
# Configure logging: INFO level on the root logger, plus a module-level logger
# that every component in this file writes to.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class SessionStatus(Enum):
    """Lifecycle state of a user session."""

    QUEUED = "queued"        # waiting in the queue for a worker
    ACTIVE = "active"        # assigned to a worker
    COMPLETED = "completed"  # ended normally (e.g. client disconnected)
    TIMEOUT = "timeout"      # ended by the session monitor (time limit or idle)
# The @dataclass decorator is required: instances are built with keyword
# arguments (see websocket_endpoint), which a plain annotated class rejects.
@dataclass
class UserSession:
    """State tracked for one connected WebSocket client."""

    # Unique identifier of the session (a UUID4 string when created by the
    # WebSocket endpoint).
    session_id: str
    # Human-readable identifier used in log messages.
    client_id: str
    # Connection used to push notifications and results to the client.
    websocket: WebSocket
    # time.time() at which the session object was created.
    created_at: float
    # Current lifecycle state.
    status: SessionStatus
    # Worker serving this session; None while the session is queued.
    worker_id: Optional[str] = None
    # time.time() of the assignment to a worker or of the latest client message.
    last_activity: Optional[float] = None
    # Session time limit in seconds; None means only the idle timeout applies.
    max_session_time: Optional[float] = None
    # Set to True once the client has sent its first message.
    user_has_interacted: bool = False
# The @dataclass decorator is required: instances are built with keyword
# arguments (see SessionManager.register_worker).
@dataclass
class WorkerInfo:
    """Dispatcher-side record of one registered worker."""

    # Identifier the worker registered itself with.
    worker_id: str
    # GPU index reported by the worker at registration.
    gpu_id: int
    # Base URL of the worker's HTTP API.
    endpoint: str
    # True when the worker may be assigned a new session.
    is_available: bool
    # Session currently assigned to this worker, if any.
    current_session: Optional[str] = None
    # time.time() of the most recent registration or ping.
    last_ping: float = 0
class SessionManager:
    """Coordinate user sessions, the waiting queue, and the registered workers.

    Sessions enter through ``add_session_to_queue``; ``process_queue`` hands
    the head of the queue to the first available worker and starts a monitor
    task that enforces the time limit / idle timeout for that session.
    """

    def __init__(self):
        self.sessions: Dict[str, UserSession] = {}
        self.workers: Dict[str, WorkerInfo] = {}
        self.session_queue: List[str] = []
        self.active_sessions: Dict[str, str] = {}  # session_id -> worker_id
        # The event loop keeps only weak references to tasks, so fire-and-forget
        # tasks are parked here until they finish.
        self._background_tasks: set = set()

        # Configuration (all values in seconds)
        self.IDLE_TIMEOUT = 20.0  # When no queue
        self.QUEUE_WARNING_TIME = 10.0
        self.MAX_SESSION_TIME_WITH_QUEUE = 60.0  # When there's a queue
        self.QUEUE_SESSION_WARNING_TIME = 45.0  # 15 seconds before timeout
        self.GRACE_PERIOD = 10.0
        self.WORKER_PING_TIMEOUT = 30.0  # Worker is ignored after this long without a ping

    def _spawn(self, coro):
        """Schedule *coro* as a task and keep a strong reference until it is done."""
        task = asyncio.create_task(coro)
        self._background_tasks.add(task)
        task.add_done_callback(self._background_tasks.discard)
        return task

    async def register_worker(self, worker_id: str, gpu_id: int, endpoint: str):
        """Register a worker (replacing any previous record) and mark it available."""
        self.workers[worker_id] = WorkerInfo(
            worker_id=worker_id,
            gpu_id=gpu_id,
            endpoint=endpoint,
            is_available=True,
            last_ping=time.time(),
        )
        logger.info(f"Registered worker {worker_id} on GPU {gpu_id} at {endpoint}")

    async def get_available_worker(self) -> Optional[WorkerInfo]:
        """Return the first available worker with a recent ping, or None."""
        now = time.time()
        for worker in self.workers.values():
            if worker.is_available and now - worker.last_ping < self.WORKER_PING_TIMEOUT:
                return worker
        return None

    async def add_session_to_queue(self, session: UserSession):
        """Record *session* and append it to the end of the waiting queue."""
        self.sessions[session.session_id] = session
        self.session_queue.append(session.session_id)
        session.status = SessionStatus.QUEUED
        logger.info(f"Added session {session.session_id} to queue. Queue size: {len(self.session_queue)}")

    async def process_queue(self):
        """Assign queued sessions to workers until the queue or the workers run out."""
        while self.session_queue:
            session_id = self.session_queue[0]
            session = self.sessions.get(session_id)

            # Drop entries whose session is gone or no longer waiting.
            if not session or session.status != SessionStatus.QUEUED:
                self.session_queue.pop(0)
                continue

            worker = await self.get_available_worker()
            if not worker:
                break  # No available workers

            # Assign session to worker
            self.session_queue.pop(0)
            session.status = SessionStatus.ACTIVE
            session.worker_id = worker.worker_id
            session.last_activity = time.time()

            # Only limit the session when somebody else is still waiting.
            if self.session_queue:
                session.max_session_time = self.MAX_SESSION_TIME_WITH_QUEUE

            worker.is_available = False
            worker.current_session = session_id
            self.active_sessions[session_id] = worker.worker_id
            logger.info(f"Assigned session {session_id} to worker {worker.worker_id}")

            # Notify user that their session is starting
            await self.notify_session_start(session)

            # Start session monitoring
            self._spawn(self.monitor_active_session(session_id))

    async def notify_session_start(self, session: UserSession):
        """Tell the client its session is starting; failures are logged, not raised."""
        try:
            await session.websocket.send_json({
                "type": "session_start",
                "worker_id": session.worker_id,
                "max_session_time": session.max_session_time,
            })
        except Exception as e:
            logger.error(f"Failed to notify session start for {session.session_id}: {e}")

    async def monitor_active_session(self, session_id: str):
        """Poll an active session once per second and end it on timeout.

        With a time limit set, the client receives a warning, then grace-period
        messages, then the session is ended with ``TIMEOUT``; if the queue is
        empty during the grace period the limit is removed instead. Without a
        limit, the idle timeout applies. Any error (typically a failed
        WebSocket send) ends the session with ``COMPLETED``.
        """
        session = self.sessions.get(session_id)
        if not session:
            return

        # Seconds before the limit at which warnings start (15 with the defaults).
        warning_window = self.MAX_SESSION_TIME_WITH_QUEUE - self.QUEUE_SESSION_WARNING_TIME

        try:
            while session.status == SessionStatus.ACTIVE:
                current_time = time.time()

                if session.max_session_time:
                    # NOTE(review): elapsed time is measured from last_activity,
                    # which every client message resets, so this limit behaves
                    # like an idle limit. Confirm whether it should be measured
                    # from the start of the session instead.
                    elapsed = current_time - session.last_activity if session.last_activity else 0
                    remaining = session.max_session_time - elapsed

                    if self.GRACE_PERIOD < remaining <= warning_window:
                        await session.websocket.send_json({
                            "type": "session_warning",
                            "time_remaining": remaining,
                            "queue_size": len(self.session_queue),
                        })
                    elif 0 < remaining <= self.GRACE_PERIOD:
                        if not self.session_queue:
                            # Nobody is waiting any more: lift the limit.
                            session.max_session_time = None
                            await session.websocket.send_json({
                                "type": "time_limit_removed",
                                "reason": "queue_empty",
                            })
                        else:
                            await session.websocket.send_json({
                                "type": "grace_period",
                                "time_remaining": remaining,
                                "queue_size": len(self.session_queue),
                            })
                    elif remaining <= 0:
                        await self.end_session(session_id, SessionStatus.TIMEOUT)
                        return
                elif session.last_activity:
                    # No time limit: enforce the idle timeout instead.
                    idle_time = current_time - session.last_activity
                    if idle_time >= self.IDLE_TIMEOUT:
                        await self.end_session(session_id, SessionStatus.TIMEOUT)
                        return
                    if idle_time >= self.QUEUE_WARNING_TIME:
                        await session.websocket.send_json({
                            "type": "idle_warning",
                            "time_remaining": self.IDLE_TIMEOUT - idle_time,
                        })

                await asyncio.sleep(1)  # Check every second
        except Exception as e:
            logger.error(f"Error monitoring session {session_id}: {e}")
            await self.end_session(session_id, SessionStatus.COMPLETED)

    async def end_session(self, session_id: str, status: SessionStatus):
        """End a session, release its worker, and schedule queue processing.

        Safe to call more than once for the same session: only the first call
        has any effect. This matters because a session ended by the monitor is
        ended again when its WebSocket later closes, and by then the worker may
        already be serving someone else.
        """
        session = self.sessions.get(session_id)
        if not session:
            return
        if session.status in (SessionStatus.COMPLETED, SessionStatus.TIMEOUT):
            return  # Already ended; do not touch the worker a second time.

        session.status = status

        # A session that ends while waiting must not keep its queue slot.
        if session_id in self.session_queue:
            self.session_queue.remove(session_id)

        if session.worker_id and session.worker_id in self.workers:
            worker = self.workers[session.worker_id]
            # Release the worker only if it is still assigned to this session.
            if worker.current_session == session_id:
                worker.is_available = True
                worker.current_session = None

            # Notify worker to clean up
            try:
                async with aiohttp.ClientSession() as client_session:
                    await client_session.post(
                        f"{worker.endpoint}/end_session",
                        json={"session_id": session_id},
                    )
            except Exception as e:
                logger.error(f"Failed to notify worker {worker.worker_id} of session end: {e}")

        # Remove from active sessions
        self.active_sessions.pop(session_id, None)
        logger.info(f"Ended session {session_id} with status {status}")

        # Process next in queue
        self._spawn(self.process_queue())

    async def update_queue_info(self):
        """Send position and estimated wait time to every queued client."""
        # Iterate over a snapshot: the queue can change while a send is awaited.
        waiting = list(self.session_queue)
        for position, session_id in enumerate(waiting, start=1):
            session = self.sessions.get(session_id)
            if not session or session.status != SessionStatus.QUEUED:
                continue
            try:
                # Calculate estimated wait time
                active_sessions_count = len(self.active_sessions)
                avg_session_time = self.MAX_SESSION_TIME_WITH_QUEUE if active_sessions_count > 0 else 30.0
                estimated_wait = position * avg_session_time / max(len(self.workers), 1)
                await session.websocket.send_json({
                    "type": "queue_update",
                    "position": position,
                    "total_waiting": len(self.session_queue),
                    "estimated_wait_minutes": estimated_wait / 60,
                    "active_sessions": active_sessions_count,
                })
            except Exception as e:
                logger.error(f"Failed to send queue update to session {session_id}: {e}")

    async def handle_user_activity(self, session_id: str):
        """Refresh the session's activity timestamp and note the first interaction."""
        session = self.sessions.get(session_id)
        if not session:
            return
        session.last_activity = time.time()
        if not session.user_has_interacted:
            session.user_has_interacted = True
            logger.info(f"User started interacting in session {session_id}")

    async def _forward_to_worker(self, worker: WorkerInfo, session_id: str, data: dict):
        """POST one client message to the worker; errors are logged, not raised."""
        try:
            async with aiohttp.ClientSession() as client_session:
                async with client_session.post(
                    f"{worker.endpoint}/process_input",
                    json={"session_id": session_id, "data": data},
                ) as response:
                    if response.status != 200:
                        logger.error(f"Worker returned status {response.status}")
                        # Optionally handle worker errors here
        except Exception as e:
            logger.error(f"Error forwarding to worker {worker.worker_id}: {e}")
# Global session manager shared by every endpoint below
session_manager = SessionManager()

app = FastAPI()
# Serve the front-end assets from the local "static" directory under /static.
app.mount("/static", StaticFiles(directory="static"), name="static")
# NOTE(review): the route decorator was missing from this copy of the file;
# "/" is the conventional path for the index page — confirm.
@app.get("/")
async def get():
    """Return the contents of ``static/index.html`` as an HTML response."""
    # The context manager closes the file handle the original leaked.
    with open("static/index.html", encoding="utf-8") as page:
        return HTMLResponse(page.read())
# NOTE(review): the route decorator was missing from this copy of the file;
# the path is inferred from the function name — confirm against the workers.
@app.post("/register_worker")
async def register_worker(worker_info: dict):
    """Endpoint for workers to register themselves.

    Expects ``worker_id``, ``gpu_id`` and ``endpoint`` in the JSON body.

    Raises:
        HTTPException: 400 when a required field is missing.
    """
    missing = [key for key in ("worker_id", "gpu_id", "endpoint") if key not in worker_info]
    if missing:
        raise HTTPException(status_code=400, detail=f"Missing field(s): {', '.join(missing)}")

    await session_manager.register_worker(
        worker_info["worker_id"],
        worker_info["gpu_id"],
        worker_info["endpoint"],
    )
    # Sessions that queued up while no worker was available can start now.
    await session_manager.process_queue()
    return {"status": "registered"}
# NOTE(review): the route decorator was missing from this copy of the file;
# the path is inferred from the function name — confirm against the workers.
@app.post("/worker_ping")
async def worker_ping(worker_info: dict):
    """Endpoint for workers to ping their availability.

    Refreshes the worker's ``last_ping`` and availability. Pings from unknown
    workers are ignored.

    Raises:
        HTTPException: 400 when ``worker_id`` is missing.
    """
    if "worker_id" not in worker_info:
        raise HTTPException(status_code=400, detail="Missing worker_id")

    worker = session_manager.workers.get(worker_info["worker_id"])
    if worker is not None:
        worker.last_ping = time.time()
        # A ping must never free a worker that still has a session assigned:
        # the flag defaults to True when the worker omits it.
        reported_available = bool(worker_info.get("is_available", True))
        worker.is_available = reported_available and worker.current_session is None
        # A worker that just became usable again can serve waiting sessions.
        if worker.is_available and session_manager.session_queue:
            await session_manager.process_queue()
    return {"status": "ok"}
# The route decorator was missing from this copy of the file; the path matches
# the "/worker_result" reference in the WebSocket handler's comments.
@app.post("/worker_result")
async def worker_result(result_data: dict):
    """Endpoint for workers to send back processing results.

    Relays ``result`` to the WebSocket of the session named by ``session_id``
    if that session is active; otherwise the result is dropped with a warning.

    Raises:
        HTTPException: 400 when ``session_id`` or ``result`` is missing or empty.
    """
    session_id = result_data.get("session_id")
    result = result_data.get("result")
    if not session_id or not result:
        raise HTTPException(status_code=400, detail="Missing session_id or result")

    # Find the session and send result to the WebSocket
    session = session_manager.sessions.get(session_id)
    if session and session.status == SessionStatus.ACTIVE:
        try:
            await session.websocket.send_json(result)
            logger.info(f"Sent result to session {session_id}")
        except Exception as e:
            logger.error(f"Failed to send result to session {session_id}: {e}")
    else:
        logger.warning(f"Could not find active session {session_id} for result")
    return {"status": "ok"}
# Message types whose worker response is relayed straight back to the client.
_CONTROL_MESSAGE_TYPES = frozenset(
    {"reset", "update_sampling_steps", "update_use_rnn", "get_settings"}
)

# Strong references to in-flight forwarding tasks; the event loop itself only
# keeps weak references, so unreferenced tasks could be garbage-collected.
_pending_forwards: set = set()


async def _relay_control_message(websocket: WebSocket, worker, session_id: str, data: dict):
    """POST a control message to the worker and send its JSON reply to the client."""
    try:
        async with aiohttp.ClientSession() as client_session:
            async with client_session.post(
                f"{worker.endpoint}/process_input",
                json={"session_id": session_id, "data": data},
            ) as response:
                if response.status == 200:
                    result = await response.json()
                    await websocket.send_json(result)
                else:
                    logger.error(f"Worker returned status {response.status}")
    except Exception as e:
        logger.error(f"Error forwarding control message: {e}")


# NOTE(review): the route decorator was missing from this copy of the file;
# "/ws" is inferred from FastAPI convention — confirm against the front end.
@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    """Serve one client connection: queue it, then relay its messages.

    Heartbeats are answered directly. While the session is active, control
    messages are sent to the worker synchronously and the reply is relayed;
    all other messages are forwarded without waiting (results come back via
    /worker_result). The session is always ended when the connection closes.
    """
    await websocket.accept()

    # Create session
    session_id = str(uuid.uuid4())
    client_id = f"{int(time.time())}_{session_id[:8]}"
    session = UserSession(
        session_id=session_id,
        client_id=client_id,
        websocket=websocket,
        created_at=time.time(),
        status=SessionStatus.QUEUED,
    )
    logger.info(f"New WebSocket connection: {client_id}")

    try:
        # Add to queue
        await session_manager.add_session_to_queue(session)
        # Try to process queue immediately
        await session_manager.process_queue()
        # Send initial queue status
        if session.status == SessionStatus.QUEUED:
            await session_manager.update_queue_info()

        # Main message loop
        while True:
            try:
                data = await websocket.receive_json()

                # NOTE(review): heartbeats also refresh the activity timestamp;
                # confirm that is intended, since it postpones the idle timeout.
                await session_manager.handle_user_activity(session_id)

                message_type = data.get("type")
                if message_type == "heartbeat":
                    await websocket.send_json({"type": "heartbeat_response"})
                    continue

                is_control = message_type in _CONTROL_MESSAGE_TYPES
                if session.status == SessionStatus.ACTIVE and session.worker_id:
                    worker = session_manager.workers.get(session.worker_id)
                    if worker is None:
                        continue
                    if is_control:
                        # Control messages need a synchronous response. The
                        # original tested for them only after the generic
                        # "active" branch, so this path was unreachable.
                        await _relay_control_message(websocket, worker, session_id, data)
                    else:
                        # Regular input: fire and forget. The worker sends
                        # results back asynchronously via /worker_result.
                        task = asyncio.create_task(
                            session_manager._forward_to_worker(worker, session_id, data)
                        )
                        _pending_forwards.add(task)
                        task.add_done_callback(_pending_forwards.discard)
                elif is_control:
                    # NOTE(review): indentation was lost in this copy; this error
                    # reply is attached to control messages only, other input from
                    # queued users is dropped silently — confirm.
                    await websocket.send_json({
                        "type": "error",
                        "message": "Session not active yet. Please wait in queue.",
                    })
            except asyncio.TimeoutError:
                logger.info("WebSocket connection timed out")
                break
    except WebSocketDisconnect:
        logger.info(f"WebSocket disconnected: {client_id}")
    except Exception as e:
        # logger.exception records the traceback through the logging setup.
        logger.exception(f"Error in WebSocket connection {client_id}: {e}")
    finally:
        # Clean up session
        if session_id in session_manager.sessions:
            await session_manager.end_session(session_id, SessionStatus.COMPLETED)
            # pop() tolerates the entry disappearing during the await above.
            session_manager.sessions.pop(session_id, None)
        logger.info(f"WebSocket connection closed: {client_id}")
# Background task to periodically update queue info
async def periodic_queue_update():
    """Push queue information to waiting clients every 5 seconds, forever."""
    while True:
        try:
            await session_manager.update_queue_info()
        except Exception as e:
            logger.error(f"Error in periodic queue update: {e}")
        # Sleep outside the try block so a failing update cannot turn this
        # loop into a busy spin.
        await asyncio.sleep(5)  # Update every 5 seconds
# Handle of the periodic queue-update task. Holding it here keeps a strong
# reference so the task is not garbage-collected while running.
_queue_update_task = None


# NOTE(review): the decorator was missing from this copy of the file; the
# startup hook is inferred from the function name — confirm.
@app.on_event("startup")
async def startup_event():
    """Start the background tasks once the event loop is running."""
    global _queue_update_task
    _queue_update_task = asyncio.create_task(periodic_queue_update())
if __name__ == "__main__":
    # Run the dispatcher directly, listening on all interfaces on port 8000.
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)