import os import json from datetime import datetime import asyncio import aiohttp from typing import Dict, List, Optional from fastapi import FastAPI, HTTPException from pydantic import BaseModel, HttpUrl import uvicorn from git_clone import clone_repository # ===== CONFIG ===== class Settings: # Server URLs and Ports CONTROLLER_HOST = "0.0.0.0" # Listen on all interfaces CONTROLLER_PORT = 8000 # This should be the actual IP or hostname where controller is accessible CONTROLLER_BASE_URL = os.getenv("CONTROLLER_BASE_URL", "http://192.168.1.100:8000") # List of tensor server URLs - should be actual IP addresses or hostnames TENSOR_SERVER_URLS = os.getenv("TENSOR_SERVER_URLS", "").split(",") or [ "https://fred808-ilob.hf.space", # Example IP for tensor server 1 "https://fred808-tserv.hf.space", # Example IP for tensor server 2 "https://fred808-tserve2.hf.space" # Example IP for tensor server 3 ] # Aggregator settings - should be actual IP or hostname AGGREGATOR_URL = os.getenv("AGGREGATOR_URL", "http://192.168.1.104:8002") # Model settings MODEL_REPO = "https://huggingface.co/microsoft/Florence-2-large" # Server settings TENSOR_SERVER_TIMEOUT = 30 # seconds MAX_ERROR_THRESHOLD = 5 # maximum number of errors SERVER_TIMEOUT = 60 # seconds before marking as error MONITORING_INTERVAL = 15 # seconds between health checks # Dynamic distribution settings @classmethod def get_optimal_chunk_size(cls, total_params: int, num_servers: int) -> int: """Calculate optimal chunk size based on number of servers""" # Aim for 2-3 chunks per server for better parallelism target_chunks = num_servers * 2 return max(1, total_params // target_chunks) @classmethod def get_min_servers_required(cls) -> int: """Dynamically calculate minimum servers needed based on registered servers""" return max(2, len(cls.TENSOR_SERVER_URLS) // 3) # At least 1/3 of registered servers @classmethod def get_min_replica_count(cls, num_servers: int) -> int: """Calculate minimum replicas based on server count""" return max(2, num_servers // 4) # At least 25% of servers should have each chunk # Tokenizer settings MAX_SEQUENCE_LENGTH = 2048 VOCAB_SIZE = 50257 @classmethod def from_env(cls): """Load settings from environment variables""" cls.CONTROLLER_HOST = os.getenv("CONTROLLER_HOST", cls.CONTROLLER_HOST) cls.CONTROLLER_PORT = int(os.getenv("CONTROLLER_PORT", cls.CONTROLLER_PORT)) cls.CONTROLLER_BASE_URL = os.getenv("CONTROLLER_BASE_URL", cls.CONTROLLER_BASE_URL) # Load tensor server URLs from environment tensor_urls = os.getenv("TENSOR_SERVER_URLS") if tensor_urls: cls.TENSOR_SERVER_URLS = tensor_urls.split(",") cls.AGGREGATOR_HOST = os.getenv("AGGREGATOR_HOST", cls.AGGREGATOR_HOST) cls.AGGREGATOR_PORT = int(os.getenv("AGGREGATOR_PORT", cls.AGGREGATOR_PORT)) cls.AGGREGATOR_URL = os.getenv("AGGREGATOR_URL", f"http://{cls.AGGREGATOR_HOST}:{cls.AGGREGATOR_PORT}") return cls # ===== State Models ===== class ServerMetrics(BaseModel): """Metrics for tensor server performance and load""" cpu_usage: float = 0.0 memory_usage: float = 0.0 gpu_usage: Optional[float] = None active_requests: int = 0 total_requests: int = 0 average_response_time: float = 0.0 last_error: Optional[str] = None error_count: int = 0 class TensorServer(BaseModel): """Represents a registered tensor server""" url: HttpUrl status: str = "initializing" # initializing, ready, busy, error, degraded last_heartbeat: datetime = datetime.now() model_chunks: List[int] = [] # List of chunk IDs assigned to this server metrics: ServerMetrics = ServerMetrics() version: str = "1.0.0" capabilities: Dict[str, bool] = { "gpu_available": False, "quantization_support": False, "tensor_parallelism": False } class ModelChunk(BaseModel): """Represents a chunk of the model to be sent to a tensor server""" chunk_id: int files: List[str] # files included in this chunk config: Dict # configuration for this chunk size_bytes: int = 0 server_assignments: List[str] = [] # URLs of servers holding this chunk status: str = "unassigned" # unassigned, assigned, loaded, error metrics: Dict[str, float] = { "load_time": 0.0, "memory_usage": 0.0, "average_inference_time": 0.0 } # ===== FastAPI App ===== app = FastAPI( title="Florence-2 Model Controller", description="Controls model distribution across tensor servers", version="1.0.0" ) # ===== Global State ===== class ControllerState: def __init__(self): self.model_files: Dict[str, str] = {} # Mapping of filename to file path self.model_config: Dict = {} # Model configuration self.tensor_servers: Dict[str, TensorServer] = {} self.model_chunks: Dict[int, ModelChunk] = {} self.is_model_loaded = False self.operation_results: Dict[str, Dict] = {} # Track operation results from tensor servers self.pending_operations: Dict[str, asyncio.Task] = {} # Track ongoing operations state = ControllerState() # ===== Helper Functions ===== async def split_model_weights(): """Split model weights into chunks based on available servers""" try: import torch # Load the full model weights model_file = next(f for f in state.model_files.values() if f.endswith('.safetensors') or f.endswith('.bin')) weights = torch.load(model_file, map_location='cpu') # Calculate chunks based on number of servers total_params = sum(p.numel() for p in weights.values()) num_servers = len(state.tensor_servers) or len(Settings.TENSOR_SERVER_URLS) params_per_chunk = Settings.get_optimal_chunk_size(total_params, num_servers) print(f"[INFO] Total parameters: {total_params:,}") print(f"[INFO] Available servers: {num_servers}") print(f"[INFO] Parameters per chunk: {params_per_chunk:,}") current_chunk = [] current_size = 0 chunk_id = 0 for key, tensor in weights.items(): tensor_size = tensor.numel() if current_size + tensor_size > params_per_chunk and current_chunk: # Save current chunk chunk_path = os.path.join(state.model_path, f"chunk_{chunk_id}.safetensors") torch.save({k: weights[k] for k in current_chunk}, chunk_path) # Create chunk metadata state.model_chunks[chunk_id] = ModelChunk( chunk_id=chunk_id, files=[f"chunk_{chunk_id}.safetensors"], config={ "weight_keys": current_chunk, "input_size": weights[current_chunk[0]].size(1), "output_size": weights[current_chunk[-1]].size(0) } ) # Reset for next chunk current_chunk = [] current_size = 0 chunk_id += 1 current_chunk.append(key) current_size += tensor_size # Save last chunk if not empty if current_chunk: chunk_path = os.path.join(state.model_path, f"chunk_{chunk_id}.safetensors") torch.save({k: weights[k] for k in current_chunk}, chunk_path) state.model_chunks[chunk_id] = ModelChunk( chunk_id=chunk_id, files=[f"chunk_{chunk_id}.safetensors"], config={ "weight_keys": current_chunk, "input_size": weights[current_chunk[0]].size(1), "output_size": weights[current_chunk[-1]].size(0) } ) print(f"[INFO] Split model into {len(state.model_chunks)} chunks") return True except Exception as e: print(f"[ERROR] Failed to split model weights: {str(e)}") return False async def distribute_model_chunks(): """Distribute model chunks across available tensor servers""" try: available_servers = [ server for server in state.tensor_servers.values() if server.status in ["ready", "busy"] and server.metrics.error_count < Settings.MAX_ERROR_THRESHOLD ] min_required = Settings.get_min_servers_required() if len(available_servers) < min_required: raise Exception(f"Not enough healthy servers. Need {min_required}, got {len(available_servers)}") # Create or update weight chunks based on current server count if not state.model_chunks or len(state.model_chunks) > len(available_servers) * 3: if not await split_model_weights(): raise Exception("Failed to split model weights") # Prepare for parallel distribution tasks = [] min_replicas = Settings.get_min_replica_count(len(available_servers)) chunks_per_server = len(state.model_chunks) / len(available_servers) print(f"[INFO] Distributing chunks with min {min_replicas} replicas per chunk") print(f"[INFO] Target chunks per server: {chunks_per_server:.1f}") # Distribute chunks for chunk_id, chunk in state.model_chunks.items(): # Calculate optimal number of replicas based on chunk size and server capacity target_replicas = max(min_replicas, int(chunks_per_server * len(available_servers) / len(state.model_chunks))) current_assignments = set(chunk.server_assignments) current_healthy = [url for url in current_assignments if state.tensor_servers[url].status in ["ready", "busy"]] # Remove unhealthy assignments chunk.server_assignments = current_healthy # Add new assignments if needed while len(chunk.server_assignments) < target_replicas: # Find least loaded eligible server eligible_servers = [ server for server in available_servers if str(server.url) not in chunk.server_assignments and len(server.model_chunks) < (len(state.model_chunks) / len(available_servers) * 1.5) ] if not eligible_servers: break # Sort by load and error count eligible_servers.sort(key=lambda s: ( len(s.model_chunks), s.metrics.error_count, s.metrics.cpu_usage )) # Assign to best server best_server = eligible_servers[0] chunk.server_assignments.append(str(best_server.url)) best_server.model_chunks.append(chunk_id) print(f"[INFO] Assigned chunk {chunk_id} to server {best_server.url}") return True except Exception as e: print(f"[ERROR] Failed to distribute model chunks: {str(e)}") return False async def monitor_tensor_servers(): """Periodically check health and update metrics of all tensor servers""" while True: for server_url, server in state.tensor_servers.items(): try: # Check basic health is_healthy = await check_tensor_server_health(server_url) if not is_healthy: server.status = "error" server.metrics.error_count += 1 print(f"[WARN] Server {server_url} is unhealthy") continue # Get detailed metrics async with aiohttp.ClientSession() as session: async with session.get(f"{server_url}/metrics", timeout=Settings.TENSOR_SERVER_TIMEOUT) as response: if response.status == 200: metrics = await response.json() server.metrics = ServerMetrics(**metrics) # Update server status based on metrics if server.metrics.error_count > Settings.MAX_ERROR_THRESHOLD: server.status = "degraded" elif server.metrics.cpu_usage > 90 or server.metrics.memory_usage > 90: server.status = "busy" else: server.status = "ready" server.last_heartbeat = datetime.now() except Exception as e: print(f"[ERROR] Failed to monitor server {server_url}: {str(e)}") server.status = "error" server.metrics.last_error = str(e) server.metrics.error_count += 1 # Check for servers that haven't responded in a while current_time = datetime.now() for server_url, server in state.tensor_servers.items(): if (current_time - server.last_heartbeat).seconds > Settings.SERVER_TIMEOUT: print(f"[WARN] Server {server_url} hasn't responded in {Settings.SERVER_TIMEOUT} seconds") server.status = "error" await asyncio.sleep(Settings.MONITORING_INTERVAL) def get_next_model_version(base_dir: str, model_name: str) -> int: """Get the next available version number for the model""" existing_versions = [] model_base_dir = os.path.join(base_dir, model_name) if os.path.exists(model_base_dir): for d in os.listdir(model_base_dir): if d.startswith('v') and d[1:].isdigit(): existing_versions.append(int(d[1:])) return max(existing_versions + [0]) + 1 def check_existing_model(model_path: str) -> bool: """Check if a model exists and has required files""" if not os.path.exists(model_path): return False # Check for essential files required_files = ['config.json'] model_files = os.listdir(model_path) # Check for any weight files has_weights = any(f.endswith(('.bin', '.safetensors')) for f in model_files) return all(f in model_files for f in required_files) and has_weights async def download_model_files(): """Downloads the model files using git clone from Hugging Face repository""" try: print(f"[INFO] Processing model from {Settings.MODEL_REPO}...") # Create models directory models_dir = os.path.join(os.getcwd(), "models") os.makedirs(models_dir, exist_ok=True) print(f"[INFO] Models directory: {models_dir}") # Get the model name from the repository URL model_name = Settings.MODEL_REPO.split('/')[-1] # Create versioned model directory version = get_next_model_version(models_dir, model_name) model_base_dir = os.path.join(models_dir, model_name) model_version_dir = os.path.join(model_base_dir, f"v{version}") # Check if previous version exists and is valid if version > 1: prev_version_dir = os.path.join(model_base_dir, f"v{version-1}") if check_existing_model(prev_version_dir): print(f"[INFO] Using existing model from {prev_version_dir}") model_path = prev_version_dir state.is_model_loaded = True else: # Clone new version if previous is invalid or incomplete os.makedirs(model_version_dir, exist_ok=True) success = clone_repository(Settings.MODEL_REPO, model_version_dir) if not success: raise Exception("Failed to clone repository") model_path = model_version_dir print(f"[INFO] Successfully cloned model to {model_path}") else: # First time download os.makedirs(model_version_dir, exist_ok=True) success = clone_repository(Settings.MODEL_REPO, model_version_dir) if not success: raise Exception("Failed to clone repository") model_path = model_version_dir print(f"[INFO] Successfully cloned model to {model_path}") # Load and parse the config config_path = os.path.join(model_path, "config.json") if os.path.exists(config_path): with open(config_path, 'r') as f: state.model_config = json.load(f) print("[INFO] Loaded model configuration") print(f"[INFO] Model type: {state.model_config.get('model_type', 'unknown')}") print(f"[INFO] Architecture: {state.model_config.get('architectures', ['unknown'])[0]}") else: print("[WARN] No config.json found in model directory") # Scan for model files print("[INFO] Scanning for model files...") for root, _, files in os.walk(model_path): for file in files: if file.endswith(('.bin', '.json', '.safetensors')): file_path = os.path.join(root, file) state.model_files[file] = file_path print(f"[INFO] Found model file: {file}") if state.model_files: state.is_model_loaded = True print(f"[INFO] Model files found successfully! Total files: {len(state.model_files)}") print(f"[INFO] Model location: {model_path}") return True else: raise ValueError("No model files were found in the repository") except Exception as e: print(f"[ERROR] Failed to process model files: {e}") state.is_model_loaded = False raise async def check_tensor_server_health(url: HttpUrl) -> bool: """Checks if a tensor server is healthy""" try: async with aiohttp.ClientSession() as session: async with session.get(f"{url}/health", timeout=Settings.TENSOR_SERVER_TIMEOUT) as response: return response.status == 200 except: return False # ===== API Endpoints ===== async def execute_tensor_operation(operation_id: str, server_url: HttpUrl, operation: str, data: Dict): """Execute an operation on a tensor server and wait for results""" try: async with aiohttp.ClientSession() as session: # Start the operation async with session.post( f"{server_url}/{operation}", json=data, timeout=Settings.TENSOR_SERVER_TIMEOUT ) as response: if response.status != 200: error_msg = await response.text() raise HTTPException( status_code=response.status, detail=f"Operation failed on server {server_url}: {error_msg}" ) initial_response = await response.json() if initial_response.get("status") == "completed": # Operation completed immediately state.operation_results[operation_id] = initial_response return initial_response # Operation is async, poll for results while True: await asyncio.sleep(1) # Poll interval async with session.get( f"{server_url}/operation/{initial_response['operation_id']}", timeout=Settings.TENSOR_SERVER_TIMEOUT ) as status_response: if status_response.status != 200: raise HTTPException( status_code=status_response.status, detail=f"Failed to get operation status from {server_url}" ) status_data = await status_response.json() if status_data["status"] in ["completed", "failed"]: state.operation_results[operation_id] = status_data if status_data["status"] == "failed": raise HTTPException( status_code=500, detail=f"Operation failed on server {server_url}: {status_data.get('error')}" ) return status_data except asyncio.TimeoutError: raise HTTPException( status_code=504, detail=f"Operation timed out on server {server_url}" ) except Exception as e: raise HTTPException( status_code=500, detail=f"Error executing operation on {server_url}: {str(e)}" ) @app.post("/execute/{operation}") async def execute_operation(operation: str, data: Dict): """Execute an operation across tensor servers and collect results""" operation_id = f"{operation}_{datetime.now().strftime('%Y%m%d_%H%M%S')}_{len(state.operation_results)}" # Get available servers with required chunks available_servers = [ server for server in state.tensor_servers.values() if server.status in ["ready", "busy"] and server.metrics.error_count < Settings.MAX_ERROR_THRESHOLD ] if not available_servers: raise HTTPException( status_code=503, detail="No available tensor servers" ) # Start operations on all relevant servers in parallel tasks = [] for server in available_servers: if operation in ["compute", "forward"]: # For compute operations, only use servers with required chunks required_chunks = data.get("required_chunks", []) if not all(chunk_id in server.model_chunks for chunk_id in required_chunks): continue task = asyncio.create_task( execute_tensor_operation( f"{operation_id}_{server.url}", server.url, operation, data ) ) tasks.append(task) state.pending_operations[f"{operation_id}_{server.url}"] = task if not tasks: raise HTTPException( status_code=400, detail="No servers available with required model chunks" ) try: # Wait for all operations to complete results = await asyncio.gather(*tasks) # Process and aggregate results aggregated_result = { "operation_id": operation_id, "status": "completed", "server_results": results, "timestamp": datetime.now().isoformat() } # Clean up for task_id in list(state.pending_operations.keys()): if task_id.startswith(operation_id): del state.pending_operations[task_id] return aggregated_result except Exception as e: # Cancel any remaining tasks for task in tasks: if not task.done(): task.cancel() # Clean up for task_id in list(state.pending_operations.keys()): if task_id.startswith(operation_id): del state.pending_operations[task_id] raise HTTPException( status_code=500, detail=f"Operation failed: {str(e)}" ) @app.get("/operation/{operation_id}") async def get_operation_status(operation_id: str): """Get the status of an operation""" # Check completed operations results = { k: v for k, v in state.operation_results.items() if k.startswith(operation_id) } if results: return { "operation_id": operation_id, "status": "completed", "results": results } # Check pending operations pending = { k: "running" for k in state.pending_operations.keys() if k.startswith(operation_id) } if pending: return { "operation_id": operation_id, "status": "running", "pending_servers": list(pending.keys()) } raise HTTPException( status_code=404, detail=f"Operation {operation_id} not found" ) @app.get("/") async def root(): """Health check endpoint""" return { "status": "running", "model_loaded": state.is_model_loaded, "registered_servers": len(state.tensor_servers), "downloaded_files": len(state.model_files), "config_loaded": bool(state.model_config) } @app.get("/health") async def health_check(): """Detailed health check""" return { "status": "healthy", "model_loaded": state.is_model_loaded, "registered_servers": len(state.tensor_servers), "downloaded_files": list(state.model_files.keys()), "config_loaded": bool(state.model_config), "model_type": state.model_config.get("model_type", "unknown") } @app.post("/register_tensor_server") async def register_tensor_server(server_url: HttpUrl): """Register a new tensor server""" if not await check_tensor_server_health(server_url): raise HTTPException(status_code=400, detail="Tensor server is not healthy") state.tensor_servers[str(server_url)] = TensorServer(url=server_url) print(f"[INFO] Registered new tensor server at {server_url}") return { "status": "registered", "registered_servers": len(state.tensor_servers), "server_id": str(server_url) } @app.delete("/unregister_tensor_server") async def unregister_tensor_server(server_url: HttpUrl): """Unregister a tensor server""" if str(server_url) in state.tensor_servers: # Remove server assignments from chunks for chunk in state.model_chunks.values(): if str(server_url) in chunk.server_assignments: chunk.server_assignments.remove(str(server_url)) del state.tensor_servers[str(server_url)] print(f"[INFO] Unregistered tensor server at {server_url}") # Trigger redistribution of chunks await distribute_model_chunks() return {"status": "unregistered"} raise HTTPException(status_code=404, detail="Server not found") @app.get("/server/{server_url}/chunks") async def get_server_chunks(server_url: HttpUrl): """Get the chunks assigned to a specific server""" if str(server_url) not in state.tensor_servers: raise HTTPException(status_code=404, detail="Server not found") server = state.tensor_servers[str(server_url)] assigned_chunks = [ state.model_chunks[chunk_id] for chunk_id in server.model_chunks ] return { "server_status": server.status, "assigned_chunks": assigned_chunks, "metrics": server.metrics.dict() } @app.post("/redistribute") async def redistribute_chunks(): """Manually trigger redistribution of model chunks""" success = await distribute_model_chunks() if not success: raise HTTPException(status_code=500, detail="Failed to redistribute chunks") return { "status": "redistributed", "chunk_assignments": { chunk_id: chunk.server_assignments for chunk_id, chunk in state.model_chunks.items() } } @app.get("/chunks/{chunk_id}/status") async def get_chunk_status(chunk_id: int): """Get the status and assignments of a specific chunk""" if chunk_id not in state.model_chunks: raise HTTPException(status_code=404, detail="Chunk not found") chunk = state.model_chunks[chunk_id] return { "chunk_id": chunk_id, "status": chunk.status, "server_assignments": chunk.server_assignments, "metrics": chunk.metrics } @app.post("/initialize") async def initialize_system(): """Download model files and prepare for distribution""" await download_model_files() # Verify downloaded files files_status = {} total_size = 0 for filename, filepath in state.model_files.items(): exists = os.path.exists(filepath) if exists: size = os.path.getsize(filepath) total_size += size files_status[filename] = {"exists": exists, "size_bytes": size} else: files_status[filename] = {"exists": exists, "size_bytes": 0} return { "status": "initialized", "model_loaded": state.is_model_loaded, "files_status": files_status, "total_size_bytes": total_size, "config_loaded": bool(state.model_config), "model_type": state.model_config.get("model_type", "unknown"), "architecture": state.model_config.get("architectures", ["unknown"])[0] } # ===== Main Execution ===== @app.on_event("startup") async def startup_event(): """Initialize the server and start background tasks""" print("[INFO] Initializing system...") await initialize_system() print("[INFO] Model initialization complete") # Start monitoring task asyncio.create_task(monitor_tensor_servers()) print("[INFO] Server monitoring started") if __name__ == "__main__": port = int(os.getenv("PORT", 8000)) print(f"[INFO] Starting controller server on port {port}") print(f"[INFO] API Documentation available at http://localhost:{port}/docs") uvicorn.run( "controller_server_new:app", host="0.0.0.0", port=port, reload=False )