""" Cold Start Optimization for DittoTalkingHead Reduces model loading time and I/O overhead """ import os import shutil import time from pathlib import Path from typing import Dict, Any, Optional import pickle import torch class ColdStartOptimizer: """ Optimizes cold start time by using persistent storage and efficient loading """ def __init__(self, persistent_dir: str = "/tmp/persistent_model_cache"): """ Initialize cold start optimizer Args: persistent_dir: Directory for persistent storage (survives restarts) """ self.persistent_dir = Path(persistent_dir) self.persistent_dir.mkdir(parents=True, exist_ok=True) # Hugging Face Spaces persistent paths self.hf_persistent_paths = [ "/data", # Primary persistent storage "/tmp/persistent", # Fallback ] # Model cache settings self.model_cache = {} self.load_times = {} def get_persistent_path(self) -> Path: """ Get the best available persistent path Returns: Path to persistent storage """ # Check Hugging Face Spaces persistent directories for path in self.hf_persistent_paths: if os.path.exists(path) and os.access(path, os.W_OK): return Path(path) / "model_cache" # Fallback to configured directory return self.persistent_dir def setup_persistent_model_cache(self, source_dir: str) -> bool: """ Set up persistent model cache Args: source_dir: Source directory containing models Returns: True if successful """ persistent_path = self.get_persistent_path() persistent_path.mkdir(parents=True, exist_ok=True) source_path = Path(source_dir) if not source_path.exists(): print(f"Source directory {source_dir} not found") return False # Copy models to persistent storage if not already there model_files = list(source_path.glob("**/*.pth")) + \ list(source_path.glob("**/*.pkl")) + \ list(source_path.glob("**/*.onnx")) + \ list(source_path.glob("**/*.trt")) copied = 0 for model_file in model_files: relative_path = model_file.relative_to(source_path) target_path = persistent_path / relative_path if not target_path.exists(): target_path.parent.mkdir(parents=True, exist_ok=True) shutil.copy2(model_file, target_path) copied += 1 print(f"Copied {relative_path} to persistent storage") print(f"Persistent cache setup complete. 
Copied {copied} new files.") return True def load_model_cached( self, model_path: str, load_func: callable, cache_key: Optional[str] = None ) -> Any: """ Load model with caching Args: model_path: Path to model file load_func: Function to load the model cache_key: Optional cache key (defaults to model_path) Returns: Loaded model """ cache_key = cache_key or model_path # Check in-memory cache first if cache_key in self.model_cache: print(f"✅ Loaded {cache_key} from memory cache") return self.model_cache[cache_key] # Check persistent storage persistent_path = self.get_persistent_path() model_name = Path(model_path).name persistent_model_path = persistent_path / model_name start_time = time.time() if persistent_model_path.exists(): # Load from persistent storage print(f"Loading {model_name} from persistent storage...") model = load_func(str(persistent_model_path)) else: # Load from original path print(f"Loading {model_name} from original location...") model = load_func(model_path) # Try to copy to persistent storage try: shutil.copy2(model_path, persistent_model_path) print(f"Cached {model_name} to persistent storage") except Exception as e: print(f"Warning: Could not cache to persistent storage: {e}") load_time = time.time() - start_time self.load_times[cache_key] = load_time # Cache in memory self.model_cache[cache_key] = model print(f"✅ Loaded {cache_key} in {load_time:.2f}s") return model def preload_models(self, model_configs: Dict[str, Dict[str, Any]]): """ Preload multiple models in parallel Args: model_configs: Dictionary of model configurations { 'model_name': { 'path': 'path/to/model', 'load_func': callable, 'priority': int (0-10) } } """ # Sort by priority sorted_models = sorted( model_configs.items(), key=lambda x: x[1].get('priority', 5), reverse=True ) for model_name, config in sorted_models: try: self.load_model_cached( config['path'], config['load_func'], cache_key=model_name ) except Exception as e: print(f"Error preloading {model_name}: {e}") def optimize_gradio_settings(self) -> Dict[str, Any]: """ Get optimized Gradio settings for faster response Returns: Gradio launch parameters """ return { 'max_threads': 40, # Increase parallel processing 'show_error': True, 'server_name': '0.0.0.0', 'server_port': 7860, 'share': False, # Disable share link for faster startup } def get_optimization_stats(self) -> Dict[str, Any]: """ Get cold start optimization statistics Returns: Optimization statistics """ persistent_path = self.get_persistent_path() # Count cached files cached_files = 0 total_size = 0 if persistent_path.exists(): for file in persistent_path.rglob("*"): if file.is_file(): cached_files += 1 total_size += file.stat().st_size return { 'persistent_path': str(persistent_path), 'cached_models': len(self.model_cache), 'cached_files': cached_files, 'total_cache_size_mb': total_size / (1024 * 1024), 'load_times': self.load_times, 'average_load_time': sum(self.load_times.values()) / len(self.load_times) if self.load_times else 0 } def clear_memory_cache(self): """Clear in-memory model cache""" self.model_cache.clear() if torch.cuda.is_available(): torch.cuda.empty_cache() print("Memory cache cleared") def setup_streaming_response(self) -> Dict[str, Any]: """ Set up configuration for streaming responses Returns: Streaming configuration """ return { 'stream_output': True, 'buffer_size': 8192, # 8KB buffer 'chunk_size': 1024, # 1KB chunks 'enable_compression': True, 'compression_level': 6 # Balanced compression }
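

# ---------------------------------------------------------------------------
# Minimal usage sketch. The "./checkpoints" directory, the "decoder.pth" file,
# and torch.load as the load_func are illustrative assumptions, not part of
# the DittoTalkingHead pipeline; adapt them to your own model layout.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    optimizer = ColdStartOptimizer()

    # Mirror local model files into persistent storage once, up front
    optimizer.setup_persistent_model_cache("./checkpoints")

    # Load one model through the cache; repeated calls hit the in-memory
    # cache, and later cold starts read from persistent storage instead of
    # the original location. ONNX/TensorRT models would use their own loaders.
    model = optimizer.load_model_cached(
        "./checkpoints/decoder.pth",
        load_func=lambda p: torch.load(p, map_location="cpu"),
        cache_key="decoder",
    )

    # Inspect cache contents and per-model load times
    stats = optimizer.get_optimization_stats()
    print(f"Cached files: {stats['cached_files']}, "
          f"avg load time: {stats['average_load_time']:.2f}s")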