"""
Inference Cache System for DittoTalkingHead

Caches video generation results for faster repeated processing.
"""

import hashlib
import json
import os
import pickle
import time
from pathlib import Path
from typing import Optional, Dict, Any, Tuple, Union, Callable
from functools import lru_cache
import shutil
from datetime import datetime, timedelta


class InferenceCache:
    """
    Cache system for video generation results.

    Supports both memory and file-based caching. Videos are copied into
    ``cache_dir`` and described by a JSON metadata file; a bounded
    in-memory dict maps cache keys to the cached video paths.
    """

    # Fields every metadata entry must carry. Entries missing any of them
    # are dropped on load rather than crashing later lookups with KeyError.
    _REQUIRED_ENTRY_KEYS = ('filename', 'expires_at', 'last_access', 'size_bytes')

    def __init__(
        self,
        cache_dir: str = "/tmp/inference_cache",
        memory_cache_size: int = 100,
        file_cache_size_gb: float = 10.0,
        ttl_hours: int = 24
    ):
        """
        Initialize inference cache

        Args:
            cache_dir: Directory for file-based cache (created if missing)
            memory_cache_size: Maximum number of items in memory cache;
                a value <= 0 disables the memory cache
            file_cache_size_gb: Maximum size of file cache in GB
            ttl_hours: Time to live for cache entries in hours
        """
        self.cache_dir = Path(cache_dir)
        self.cache_dir.mkdir(parents=True, exist_ok=True)

        self.memory_cache_size = memory_cache_size
        self.file_cache_size_bytes = int(file_cache_size_gb * 1024 * 1024 * 1024)
        self.ttl_seconds = ttl_hours * 3600

        # Metadata file for managing cache
        self.metadata_file = self.cache_dir / "cache_metadata.json"
        self.metadata = self._load_metadata()

        # In-memory cache: cache_key -> video path, plus last-access times
        # used for LRU eviction. Must exist before _cleanup_expired runs,
        # because _remove_cache_entry touches both dicts.
        self._memory_cache = {}
        self._access_times = {}

        # Clean up expired entries on initialization
        self._cleanup_expired()

    def _load_metadata(self) -> Dict[str, Any]:
        """
        Load cache metadata from disk.

        Returns:
            The metadata dict, or an empty dict when the file is missing,
            unreadable, not valid JSON, or not a JSON object. Entries that
            are not dicts or lack required fields are discarded.
        """
        if not self.metadata_file.exists():
            return {}
        try:
            with open(self.metadata_file, 'r', encoding='utf-8') as f:
                data = json.load(f)
        except (OSError, ValueError):
            # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
            return {}
        if not isinstance(data, dict):
            return {}
        return {
            key: entry
            for key, entry in data.items()
            if isinstance(entry, dict)
            and all(field in entry for field in self._REQUIRED_ENTRY_KEYS)
        }

    def _save_metadata(self):
        """
        Save cache metadata atomically.

        The JSON is written to a sibling temporary file which then replaces
        the real file via os.replace, so a failure part-way through
        serialization never leaves a truncated metadata file behind.
        """
        tmp_file = self.metadata_file.with_name(self.metadata_file.name + '.tmp')
        try:
            with open(tmp_file, 'w', encoding='utf-8') as f:
                json.dump(self.metadata, f, indent=2)
            os.replace(tmp_file, self.metadata_file)
        except Exception:
            if tmp_file.exists():
                tmp_file.unlink()
            raise

    @staticmethod
    def _hash_file(path: str, chunk_size: int = 1024 * 1024) -> str:
        """Return the SHA-256 hex digest of a file, reading it in chunks to bound memory use."""
        digest = hashlib.sha256()
        with open(path, 'rb') as f:
            for chunk in iter(lambda: f.read(chunk_size), b''):
                digest.update(chunk)
        return digest.hexdigest()

    def generate_cache_key(
        self,
        audio_path: str,
        image_path: str,
        **kwargs
    ) -> str:
        """
        Generate unique cache key based on input parameters

        Args:
            audio_path: Path to audio file
            image_path: Path to image file
            **kwargs: Additional parameters; only ``resolution`` (default
                '320x320'), ``steps`` (default 25) and ``seed`` (default
                None) are included in the key

        Returns:
            SHA-256 hash as cache key
        """
        # Hash file contents (chunked; the digest equals hashing the whole file)
        key_data = {
            'audio': self._hash_file(audio_path),
            'image': self._hash_file(image_path),
            'resolution': kwargs.get('resolution', '320x320'),
            'steps': kwargs.get('steps', 25),
            'seed': kwargs.get('seed', None)
        }

        # Generate final key
        key_str = json.dumps(key_data, sort_keys=True)
        return hashlib.sha256(key_str.encode()).hexdigest()

    def get_from_memory(self, cache_key: str) -> Optional[str]:
        """
        Get video path from memory cache

        A memory hit is only returned while its metadata entry is unexpired
        and the cached file still exists; otherwise the entry is removed.

        Args:
            cache_key: Cache key

        Returns:
            Video file path if found, None otherwise
        """
        video_path = self._memory_cache.get(cache_key)
        if video_path is None:
            return None

        entry = self.metadata.get(cache_key)
        if (
            entry is None
            or time.time() > entry['expires_at']
            or not Path(video_path).exists()
        ):
            self._remove_cache_entry(cache_key)
            return None

        now = time.time()
        self._access_times[cache_key] = now
        # Keep the file-level LRU in sync so entries served from memory are
        # not treated as "oldest" by _check_cache_size (persisted on next save).
        entry['last_access'] = now
        return video_path

    def get_from_file(self, cache_key: str) -> Optional[str]:
        """
        Get video path from file cache

        Args:
            cache_key: Cache key

        Returns:
            Video file path if found, None otherwise
        """
        entry = self.metadata.get(cache_key)
        if entry is None:
            return None

        # Check expiration
        if time.time() > entry['expires_at']:
            self._remove_cache_entry(cache_key)
            return None

        # Check if file exists
        video_path = self.cache_dir / entry['filename']
        if not video_path.exists():
            self._remove_cache_entry(cache_key)
            return None

        # Update access time
        entry['last_access'] = time.time()
        self._save_metadata()

        # Add to memory cache
        self._add_to_memory_cache(cache_key, str(video_path))

        return str(video_path)

    def get(self, cache_key: str) -> Optional[str]:
        """
        Get video from cache (memory first, then file)

        Args:
            cache_key: Cache key

        Returns:
            Video file path if found, None otherwise
        """
        # Try memory cache first
        result = self.get_from_memory(cache_key)
        if result is not None:
            return result

        # Try file cache
        return self.get_from_file(cache_key)

    def put(
        self,
        cache_key: str,
        video_path: str,
        **metadata
    ) -> bool:
        """
        Store video in cache

        Args:
            cache_key: Cache key
            video_path: Path to generated video
            **metadata: Additional metadata to store; must be
                JSON-serializable

        Returns:
            True if stored successfully; False on error, or when the video
            alone exceeds the file cache size limit and is evicted at once
        """
        try:
            # Reject non-JSON metadata before touching disk; storing it
            # would make every later metadata save fail.
            json.dumps(metadata)

            # Copy video to cache directory
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            cache_filename = f"{cache_key[:8]}_{timestamp}.mp4"
            cache_video_path = self.cache_dir / cache_filename
            shutil.copy2(video_path, cache_video_path)

            # When re-storing a key, delete the previous copy so it is not
            # orphaned (unless it had the same name and was just overwritten).
            previous = self.metadata.get(cache_key)
            if previous is not None and previous.get('filename') != cache_filename:
                old_file = self.cache_dir / previous['filename']
                if old_file.exists():
                    old_file.unlink()

            # Store metadata
            now = time.time()
            self.metadata[cache_key] = {
                'filename': cache_filename,
                'created_at': now,
                'expires_at': now + self.ttl_seconds,
                'last_access': now,
                'size_bytes': os.path.getsize(cache_video_path),
                'metadata': metadata
            }

            # Check cache size and clean if needed
            self._check_cache_size()

            # Save metadata
            self._save_metadata()

            if cache_key not in self.metadata:
                # The new video by itself exceeds the size limit and was evicted.
                return False

            # Add to memory cache
            self._add_to_memory_cache(cache_key, str(cache_video_path))

            return True

        except Exception as e:
            print(f"Error storing cache: {e}")
            return False

    def _add_to_memory_cache(self, cache_key: str, video_path: str):
        """Add item to memory cache with LRU eviction (no-op if the memory cache size is <= 0)."""
        if self.memory_cache_size <= 0:
            return

        # Only evict when inserting a new key; refreshing an existing key
        # does not grow the cache.
        if cache_key not in self._memory_cache:
            while len(self._memory_cache) >= self.memory_cache_size:
                # Find least recently used
                lru_key = min(self._access_times, key=self._access_times.get)
                del self._memory_cache[lru_key]
                del self._access_times[lru_key]

        self._memory_cache[cache_key] = video_path
        self._access_times[cache_key] = time.time()

    def _check_cache_size(self):
        """Evict least-recently-accessed entries until the file cache fits its size limit."""
        total_size = sum(
            entry['size_bytes'] for entry in self.metadata.values()
        )

        if total_size > self.file_cache_size_bytes:
            # Remove oldest entries until under limit
            sorted_entries = sorted(
                self.metadata.items(),
                key=lambda x: x[1]['last_access']
            )

            while total_size > self.file_cache_size_bytes and sorted_entries:
                key_to_remove, entry = sorted_entries.pop(0)
                total_size -= entry['size_bytes']
                self._remove_cache_entry(key_to_remove)

    def _cleanup_expired(self):
        """Remove expired cache entries and persist the updated metadata."""
        current_time = time.time()
        expired_keys = [
            key for key, entry in self.metadata.items()
            if current_time > entry['expires_at']
        ]

        for key in expired_keys:
            self._remove_cache_entry(key)

        if expired_keys:
            self._save_metadata()
            print(f"Cleaned up {len(expired_keys)} expired cache entries")

    def _remove_cache_entry(self, cache_key: str):
        """Remove a cache entry: its video file, its metadata, and any memory-cache slot."""
        entry = self.metadata.pop(cache_key, None)
        if entry is not None:
            # Remove file
            video_file = self.cache_dir / entry['filename']
            if video_file.exists():
                video_file.unlink()

        # Remove from memory cache
        self._memory_cache.pop(cache_key, None)
        self._access_times.pop(cache_key, None)

    def clear_cache(self):
        """Clear all cache entries"""
        # Remove all video files
        for file in self.cache_dir.glob("*.mp4"):
            file.unlink()

        # Clear metadata
        self.metadata = {}
        self._save_metadata()

        # Clear memory cache
        self._memory_cache.clear()
        self._access_times.clear()

        print("Inference cache cleared")

    def get_cache_stats(self) -> Dict[str, Any]:
        """Get cache statistics (entry counts, total size in MB, limits, TTL, directory)."""
        total_size = sum(
            entry['size_bytes'] for entry in self.metadata.values()
        )

        memory_entries = len(self._memory_cache)
        file_entries = len(self.metadata)

        return {
            'memory_cache_entries': memory_entries,
            'file_cache_entries': file_entries,
            'total_cache_size_mb': total_size / (1024 * 1024),
            'cache_size_limit_gb': self.file_cache_size_bytes / (1024 * 1024 * 1024),
            'ttl_hours': self.ttl_seconds / 3600,
            'cache_directory': str(self.cache_dir)
        }


class CachedInference:
    """
    Wrapper for cached inference execution
    """

    def __init__(self, cache: InferenceCache):
        """
        Initialize cached inference

        Args:
            cache: InferenceCache instance
        """
        self.cache = cache

    def process_with_cache(
        self,
        inference_func: Callable[..., Any],
        audio_path: str,
        image_path: str,
        output_path: str,
        **kwargs
    ) -> Tuple[str, bool, float]:
        """
        Process with caching

        Args:
            inference_func: Function called as
                ``inference_func(audio_path, image_path, output_path, **kwargs)``
                that writes the video to ``output_path``
            audio_path: Path to audio file
            image_path: Path to image file
            output_path: Desired output path
            **kwargs: Additional parameters (forwarded to inference_func and
                stored as cache metadata)

        Returns:
            Tuple of (output_path, cache_hit, process_time)
        """
        start_time = time.time()

        # Generate cache key
        cache_key = self.cache.generate_cache_key(
            audio_path, image_path, **kwargs
        )

        # Check cache
        cached_video = self.cache.get(cache_key)
        if cached_video:
            try:
                # Cache hit - copy to output path
                shutil.copy2(cached_video, output_path)
            except FileNotFoundError:
                # The cached file vanished after lookup; regenerate instead.
                pass
            else:
                process_time = time.time() - start_time
                print(f"✅ Cache hit! Retrieved in {process_time:.2f}s")
                return output_path, True, process_time

        # Cache miss - generate video
        print("Cache miss - generating video...")
        inference_func(audio_path, image_path, output_path, **kwargs)

        # Store in cache
        if os.path.exists(output_path):
            self.cache.put(cache_key, output_path, **kwargs)

        process_time = time.time() - start_time
        return output_path, False, process_time