""" | |
Cold Start Optimization for DittoTalkingHead | |
Reduces model loading time and I/O overhead | |
""" | |
import os | |
import shutil | |
import time | |
from pathlib import Path | |
from typing import Dict, Any, Optional | |
import pickle | |
import torch | |


class ColdStartOptimizer:
    """
    Optimizes cold start time by using persistent storage and efficient loading.
    """

    def __init__(self, persistent_dir: str = "/tmp/persistent_model_cache"):
        """
        Initialize the cold start optimizer.

        Args:
            persistent_dir: Directory for persistent storage (survives restarts)
        """
        self.persistent_dir = Path(persistent_dir)
        self.persistent_dir.mkdir(parents=True, exist_ok=True)

        # Hugging Face Spaces persistent paths
        self.hf_persistent_paths = [
            "/data",            # Primary persistent storage
            "/tmp/persistent",  # Fallback
        ]

        # In-memory model cache and per-model load-time bookkeeping
        self.model_cache = {}
        self.load_times = {}

    def get_persistent_path(self) -> Path:
        """
        Get the best available persistent path.

        Returns:
            Path to persistent storage
        """
        # Prefer Hugging Face Spaces persistent directories when writable
        for path in self.hf_persistent_paths:
            if os.path.exists(path) and os.access(path, os.W_OK):
                return Path(path) / "model_cache"

        # Fall back to the configured directory
        return self.persistent_dir

    def setup_persistent_model_cache(self, source_dir: str) -> bool:
        """
        Set up the persistent model cache.

        Args:
            source_dir: Source directory containing models

        Returns:
            True if successful
        """
        persistent_path = self.get_persistent_path()
        persistent_path.mkdir(parents=True, exist_ok=True)

        source_path = Path(source_dir)
        if not source_path.exists():
            print(f"Source directory {source_dir} not found")
            return False

        # Copy models to persistent storage if they are not already there
        model_files = (
            list(source_path.glob("**/*.pth"))
            + list(source_path.glob("**/*.pkl"))
            + list(source_path.glob("**/*.onnx"))
            + list(source_path.glob("**/*.trt"))
        )

        copied = 0
        for model_file in model_files:
            relative_path = model_file.relative_to(source_path)
            target_path = persistent_path / relative_path
            if not target_path.exists():
                target_path.parent.mkdir(parents=True, exist_ok=True)
                shutil.copy2(model_file, target_path)
                copied += 1
                print(f"Copied {relative_path} to persistent storage")

        print(f"Persistent cache setup complete. Copied {copied} new files.")
        return True

    def load_model_cached(
        self,
        model_path: str,
        load_func: Callable[[str], Any],
        cache_key: Optional[str] = None,
    ) -> Any:
        """
        Load a model, checking the in-memory and persistent caches first.

        Args:
            model_path: Path to model file
            load_func: Function that takes a file path and returns the model
            cache_key: Optional cache key (defaults to model_path)

        Returns:
            Loaded model
        """
        cache_key = cache_key or model_path

        # Check the in-memory cache first
        if cache_key in self.model_cache:
            print(f"✅ Loaded {cache_key} from memory cache")
            return self.model_cache[cache_key]

        # Check persistent storage. Note that only the file name is used as
        # the on-disk key, so identically named files from different source
        # directories will collide.
        persistent_path = self.get_persistent_path()
        model_name = Path(model_path).name
        persistent_model_path = persistent_path / model_name

        start_time = time.time()

        if persistent_model_path.exists():
            # Load from persistent storage
            print(f"Loading {model_name} from persistent storage...")
            model = load_func(str(persistent_model_path))
        else:
            # Load from the original path
            print(f"Loading {model_name} from original location...")
            model = load_func(model_path)

            # Try to copy to persistent storage for the next cold start
            try:
                persistent_path.mkdir(parents=True, exist_ok=True)
                shutil.copy2(model_path, persistent_model_path)
                print(f"Cached {model_name} to persistent storage")
            except Exception as e:
                print(f"Warning: Could not cache to persistent storage: {e}")

        load_time = time.time() - start_time
        self.load_times[cache_key] = load_time

        # Cache in memory
        self.model_cache[cache_key] = model

        print(f"✅ Loaded {cache_key} in {load_time:.2f}s")
        return model
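
    # Example usage (a minimal sketch; the checkpoint path and the torch.load
    # wrapper are illustrative assumptions, not part of this repository's API):
    #
    #   optimizer = ColdStartOptimizer()
    #   decoder = optimizer.load_model_cached(
    #       "./checkpoints/decoder.pth",
    #       lambda p: torch.load(p, map_location="cpu"),
    #       cache_key="decoder",
    #   )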

    def preload_models(self, model_configs: Dict[str, Dict[str, Any]]):
        """
        Preload multiple models sequentially, highest priority first.

        Args:
            model_configs: Dictionary of model configurations:
                {
                    'model_name': {
                        'path': 'path/to/model',
                        'load_func': callable,
                        'priority': int (0-10)
                    }
                }
        """
        # Sort by priority (descending) so critical models load first
        sorted_models = sorted(
            model_configs.items(),
            key=lambda x: x[1].get('priority', 5),
            reverse=True,
        )

        for model_name, config in sorted_models:
            try:
                self.load_model_cached(
                    config['path'],
                    config['load_func'],
                    cache_key=model_name,
                )
            except Exception as e:
                print(f"Error preloading {model_name}: {e}")

    def optimize_gradio_settings(self) -> Dict[str, Any]:
        """
        Get optimized Gradio launch settings for faster responses.

        Returns:
            Gradio launch parameters
        """
        return {
            'max_threads': 40,        # More threads for parallel requests
            'show_error': True,
            'server_name': '0.0.0.0',
            'server_port': 7860,
            'share': False,           # Disable share link for faster startup
        }
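
    # The returned dict holds plain keyword arguments for Gradio's launch(),
    # so a caller can do, e.g. (assuming `demo` is a gr.Blocks app):
    #
    #   demo.launch(**optimizer.optimize_gradio_settings())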

    def get_optimization_stats(self) -> Dict[str, Any]:
        """
        Get cold start optimization statistics.

        Returns:
            Optimization statistics
        """
        persistent_path = self.get_persistent_path()

        # Count cached files and their total size on disk
        cached_files = 0
        total_size = 0
        if persistent_path.exists():
            for file in persistent_path.rglob("*"):
                if file.is_file():
                    cached_files += 1
                    total_size += file.stat().st_size

        return {
            'persistent_path': str(persistent_path),
            'cached_models': len(self.model_cache),
            'cached_files': cached_files,
            'total_cache_size_mb': total_size / (1024 * 1024),
            'load_times': self.load_times,
            'average_load_time': (
                sum(self.load_times.values()) / len(self.load_times)
                if self.load_times else 0
            ),
        }

    def clear_memory_cache(self):
        """Clear the in-memory model cache and free cached GPU memory."""
        self.model_cache.clear()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        print("Memory cache cleared")

    def setup_streaming_response(self) -> Dict[str, Any]:
        """
        Set up configuration for streaming responses.

        Returns:
            Streaming configuration
        """
        return {
            'stream_output': True,
            'buffer_size': 8192,        # 8 KB buffer
            'chunk_size': 1024,         # 1 KB chunks
            'enable_compression': True,
            'compression_level': 6,     # Balanced compression
        }
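

# Minimal end-to-end sketch (the checkpoint directory and the torch.load
# wrapper are illustrative assumptions, not DittoTalkingHead's actual layout):
if __name__ == "__main__":
    optimizer = ColdStartOptimizer()

    # Mirror local checkpoints into persistent storage once per cold start
    optimizer.setup_persistent_model_cache("./checkpoints")

    # Load a model through the two-level (memory + disk) cache, if present
    ckpt = Path("./checkpoints/decoder.pth")
    if ckpt.exists():
        optimizer.load_model_cached(
            str(ckpt),
            lambda p: torch.load(p, map_location="cpu"),
            cache_key="decoder",
        )

    # Report what the optimizer has cached so far
    stats = optimizer.get_optimization_stats()
    print(f"Cached {stats['cached_files']} files "
          f"({stats['total_cache_size_mb']:.1f} MB) in "
          f"{stats['persistent_path']}")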