Spaces:
Runtime error
Runtime error
File size: 7,970 Bytes
b27232b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 |
"""
Cold Start Optimization for DittoTalkingHead
Reduces model loading time and I/O overhead
"""
import os
import shutil
import time
from pathlib import Path
from typing import Dict, Any, Optional
import pickle
import torch
class ColdStartOptimizer:
"""
Optimizes cold start time by using persistent storage and efficient loading
"""
def __init__(self, persistent_dir: str = "/tmp/persistent_model_cache"):
"""
Initialize cold start optimizer
Args:
persistent_dir: Directory for persistent storage (survives restarts)
"""
self.persistent_dir = Path(persistent_dir)
self.persistent_dir.mkdir(parents=True, exist_ok=True)
# Hugging Face Spaces persistent paths
self.hf_persistent_paths = [
"/data", # Primary persistent storage
"/tmp/persistent", # Fallback
]
# Model cache settings
self.model_cache = {}
self.load_times = {}
def get_persistent_path(self) -> Path:
"""
Get the best available persistent path
Returns:
Path to persistent storage
"""
# Check Hugging Face Spaces persistent directories
for path in self.hf_persistent_paths:
if os.path.exists(path) and os.access(path, os.W_OK):
return Path(path) / "model_cache"
# Fallback to configured directory
return self.persistent_dir
def setup_persistent_model_cache(self, source_dir: str) -> bool:
"""
Set up persistent model cache
Args:
source_dir: Source directory containing models
Returns:
True if successful
"""
persistent_path = self.get_persistent_path()
persistent_path.mkdir(parents=True, exist_ok=True)
source_path = Path(source_dir)
if not source_path.exists():
print(f"Source directory {source_dir} not found")
return False
# Copy models to persistent storage if not already there
model_files = list(source_path.glob("**/*.pth")) + \
list(source_path.glob("**/*.pkl")) + \
list(source_path.glob("**/*.onnx")) + \
list(source_path.glob("**/*.trt"))
copied = 0
for model_file in model_files:
relative_path = model_file.relative_to(source_path)
target_path = persistent_path / relative_path
if not target_path.exists():
target_path.parent.mkdir(parents=True, exist_ok=True)
shutil.copy2(model_file, target_path)
copied += 1
print(f"Copied {relative_path} to persistent storage")
print(f"Persistent cache setup complete. Copied {copied} new files.")
return True
def load_model_cached(
self,
model_path: str,
load_func: callable,
cache_key: Optional[str] = None
) -> Any:
"""
Load model with caching
Args:
model_path: Path to model file
load_func: Function to load the model
cache_key: Optional cache key (defaults to model_path)
Returns:
Loaded model
"""
cache_key = cache_key or model_path
# Check in-memory cache first
if cache_key in self.model_cache:
print(f"✅ Loaded {cache_key} from memory cache")
return self.model_cache[cache_key]
# Check persistent storage
persistent_path = self.get_persistent_path()
model_name = Path(model_path).name
persistent_model_path = persistent_path / model_name
start_time = time.time()
if persistent_model_path.exists():
# Load from persistent storage
print(f"Loading {model_name} from persistent storage...")
model = load_func(str(persistent_model_path))
else:
# Load from original path
print(f"Loading {model_name} from original location...")
model = load_func(model_path)
# Try to copy to persistent storage
try:
shutil.copy2(model_path, persistent_model_path)
print(f"Cached {model_name} to persistent storage")
except Exception as e:
print(f"Warning: Could not cache to persistent storage: {e}")
load_time = time.time() - start_time
self.load_times[cache_key] = load_time
# Cache in memory
self.model_cache[cache_key] = model
print(f"✅ Loaded {cache_key} in {load_time:.2f}s")
return model
def preload_models(self, model_configs: Dict[str, Dict[str, Any]]):
"""
Preload multiple models in parallel
Args:
model_configs: Dictionary of model configurations
{
'model_name': {
'path': 'path/to/model',
'load_func': callable,
'priority': int (0-10)
}
}
"""
# Sort by priority
sorted_models = sorted(
model_configs.items(),
key=lambda x: x[1].get('priority', 5),
reverse=True
)
for model_name, config in sorted_models:
try:
self.load_model_cached(
config['path'],
config['load_func'],
cache_key=model_name
)
except Exception as e:
print(f"Error preloading {model_name}: {e}")
def optimize_gradio_settings(self) -> Dict[str, Any]:
"""
Get optimized Gradio settings for faster response
Returns:
Gradio launch parameters
"""
return {
'max_threads': 40, # Increase parallel processing
'show_error': True,
'server_name': '0.0.0.0',
'server_port': 7860,
'share': False, # Disable share link for faster startup
}
def get_optimization_stats(self) -> Dict[str, Any]:
"""
Get cold start optimization statistics
Returns:
Optimization statistics
"""
persistent_path = self.get_persistent_path()
# Count cached files
cached_files = 0
total_size = 0
if persistent_path.exists():
for file in persistent_path.rglob("*"):
if file.is_file():
cached_files += 1
total_size += file.stat().st_size
return {
'persistent_path': str(persistent_path),
'cached_models': len(self.model_cache),
'cached_files': cached_files,
'total_cache_size_mb': total_size / (1024 * 1024),
'load_times': self.load_times,
'average_load_time': sum(self.load_times.values()) / len(self.load_times) if self.load_times else 0
}
def clear_memory_cache(self):
"""Clear in-memory model cache"""
self.model_cache.clear()
if torch.cuda.is_available():
torch.cuda.empty_cache()
print("Memory cache cleared")
def setup_streaming_response(self) -> Dict[str, Any]:
"""
Set up configuration for streaming responses
Returns:
Streaming configuration
"""
return {
'stream_output': True,
'buffer_size': 8192, # 8KB buffer
'chunk_size': 1024, # 1KB chunks
'enable_compression': True,
'compression_level': 6 # Balanced compression
} |