"""
Avatar Cache System for DittoTalkingHead
Implements image pre-upload and embedding caching
"""

import os
import pickle
import hashlib
import time
from typing import Optional, Dict, Any, Tuple, Callable
from datetime import datetime, timedelta
import json
from pathlib import Path


class AvatarCache:
    """
    Avatar embedding cache system
    Stores pre-computed image embeddings for faster video generation

    Each entry is a pickle file named ``<token>.pkl`` inside ``cache_dir``;
    expiry bookkeeping for all entries lives in ``metadata.json`` in the same
    directory and is mirrored in memory as ``self.metadata``.
    """

    def __init__(self, cache_dir: str = "/tmp/avatar_cache", ttl_days: int = 14):
        """
        Initialize avatar cache

        Creates the cache directory if needed, loads the metadata file and
        immediately drops entries that have already expired.

        Args:
            cache_dir: Directory to store cache files
            ttl_days: Time to live for cache entries in days
        """
        self.cache_dir = Path(cache_dir)
        self.cache_dir.mkdir(parents=True, exist_ok=True)
        self.ttl_seconds = ttl_days * 24 * 60 * 60
        self.metadata_file = self.cache_dir / "metadata.json"

        # Load existing metadata
        self.metadata = self._load_metadata()

        # Clean expired entries on initialization
        self._cleanup_expired()

    @staticmethod
    def _tmp_path(target: Path) -> Path:
        """Return the sibling temp path used to stage a write to *target*."""
        return target.with_name(target.name + ".tmp")

    @staticmethod
    def _remove_file(path: Path) -> None:
        """Delete *path*, ignoring the case where it is already gone."""
        try:
            path.unlink()
        except FileNotFoundError:
            pass

    @staticmethod
    def _is_expired(info: Any, now: float) -> bool:
        """
        Tell whether a metadata entry is expired at time *now*.

        Entries that are not dicts or lack a numeric ``expires_at`` are
        reported as expired so that a damaged metadata file gets cleaned up
        instead of raising ``KeyError`` later on.
        """
        expires_at = info.get('expires_at') if isinstance(info, dict) else None
        if not isinstance(expires_at, (int, float)):
            return True
        return now > expires_at

    def _load_metadata(self) -> Dict[str, Any]:
        """
        Load cache metadata

        Returns:
            The parsed metadata mapping, or an empty dict when the file is
            missing, unreadable, not valid JSON, or not a JSON object.
        """
        if not self.metadata_file.exists():
            return {}
        try:
            with open(self.metadata_file, 'r', encoding='utf-8') as f:
                data = json.load(f)
        except (OSError, ValueError):
            # ValueError covers json.JSONDecodeError and UnicodeDecodeError;
            # an unreadable metadata file is treated as an empty cache.
            return {}
        # Anything but a JSON object would break the .items() iteration below.
        return data if isinstance(data, dict) else {}

    def _save_metadata(self):
        """
        Save cache metadata

        The JSON is written to a temp file and moved into place with
        ``os.replace`` so a failed write never truncates the existing file.
        """
        tmp_file = self._tmp_path(self.metadata_file)
        try:
            with open(tmp_file, 'w', encoding='utf-8') as f:
                json.dump(self.metadata, f, indent=2)
            os.replace(tmp_file, self.metadata_file)
        except Exception:
            self._remove_file(tmp_file)
            raise

    def _cleanup_expired(self):
        """Remove expired cache entries (metadata and their pickle files)"""
        current_time = time.time()
        expired_tokens = [
            token
            for token, info in self.metadata.items()
            if self._is_expired(info, current_time)
        ]

        for token in expired_tokens:
            self._remove_file(self.cache_dir / f"{token}.pkl")
            del self.metadata[token]

        if expired_tokens:
            self._save_metadata()
            print(f"Cleaned up {len(expired_tokens)} expired cache entries")

    def generate_token(self, img_bytes: bytes) -> str:
        """
        Generate unique token for image

        The token is a content hash, so identical image bytes always map to
        the same cache entry.

        Args:
            img_bytes: Image data as bytes

        Returns:
            SHA-1 hash token (hex digest)
        """
        return hashlib.sha1(img_bytes).hexdigest()

    def store_embedding(
        self,
        img_bytes: bytes,
        embedding: Any,
        additional_info: Optional[Dict[str, Any]] = None
    ) -> Tuple[str, datetime]:
        """
        Store image embedding in cache

        Args:
            img_bytes: Image data as bytes
            embedding: Pre-computed embedding (latent vector)
            additional_info: Additional metadata to store

        Returns:
            Tuple of (token, expiration_date); the date is a naive local-time
            ``datetime`` built with ``datetime.fromtimestamp``.

        Raises:
            Exception: Whatever ``pickle.dump`` or the file system raises;
                in that case any previously stored entry is left untouched.
        """
        token = self.generate_token(img_bytes)
        cache_file = self.cache_dir / f"{token}.pkl"

        # Take a single timestamp so created_at / expires_at agree everywhere.
        now = time.time()
        expires_at = now + self.ttl_seconds
        expiration_date = datetime.fromtimestamp(expires_at)

        # Save embedding
        cache_data = {
            'embedding': embedding,
            'created_at': now,
            'expires_at': expires_at,
            'additional_info': additional_info or {}
        }

        # Stage the pickle in a temp file: if serialization fails midway an
        # existing valid entry for the same token must not be clobbered.
        tmp_file = self._tmp_path(cache_file)
        try:
            with open(tmp_file, 'wb') as f:
                pickle.dump(cache_data, f)
            os.replace(tmp_file, cache_file)
        except Exception:
            self._remove_file(tmp_file)
            raise

        # Update metadata
        self.metadata[token] = {
            'expires_at': expires_at,
            'created_at': now,
            'file_size': os.path.getsize(cache_file)
        }
        self._save_metadata()

        return token, expiration_date

    def load_embedding(self, token: str) -> Optional[Any]:
        """
        Load embedding from cache

        Args:
            token: Avatar token

        Returns:
            Embedding if found and valid, None otherwise (unknown token,
            expired entry, missing file, or unreadable pickle).
        """
        # Check if token exists and not expired
        info = self.metadata.get(token)
        if info is None:
            return None

        if self._is_expired(info, time.time()):
            # Token expired
            self._cleanup_expired()
            return None

        # Load from file
        cache_file = self.cache_dir / f"{token}.pkl"
        if not cache_file.exists():
            # File missing, clean up metadata
            del self.metadata[token]
            self._save_metadata()
            return None

        try:
            # SECURITY NOTE: pickle.load executes arbitrary code from the
            # file. This is only safe while cache_dir is writable solely by
            # trusted code; never point the cache at untrusted files.
            with open(cache_file, 'rb') as f:
                cache_data = pickle.load(f)
            return cache_data['embedding']
        except Exception as e:
            # Deliberate best-effort: a corrupt entry reads as a cache miss.
            print(f"Error loading cache for token {token}: {e}")
            return None

    def get_cache_info(self) -> Dict[str, Any]:
        """
        Get cache statistics

        Returns:
            Cache information: directory, number of non-expired entries,
            total number of tracked entries, size of the non-expired entries
            in MB (from recorded ``file_size``), and the TTL in days.
        """
        now = time.time()  # one timestamp for a consistent snapshot
        active = [
            info for info in self.metadata.values()
            if not self._is_expired(info, now)
        ]
        total_size = sum(info.get('file_size', 0) for info in active)

        return {
            'cache_dir': str(self.cache_dir),
            'active_entries': len(active),
            'total_entries': len(self.metadata),
            'total_size_mb': total_size / (1024 * 1024),
            'ttl_days': self.ttl_seconds / (24 * 60 * 60)
        }

    def clear_cache(self):
        """Clear all cache entries"""
        # Also sweep staging files a crashed writer may have left behind.
        for pattern in ("*.pkl", "*.pkl.tmp"):
            for file in self.cache_dir.glob(pattern):
                self._remove_file(file)

        self.metadata = {}
        self._save_metadata()
        print("Avatar cache cleared")


class AvatarTokenManager:
    """
    Manages avatar tokens and their lifecycle
    """

    def __init__(self, cache: AvatarCache):
        """
        Initialize token manager

        Args:
            cache: Avatar cache instance
        """
        self.cache = cache

    def prepare_avatar(
        self,
        image_data: bytes,
        appearance_encoder_func: Callable[..., Any],
        **encoder_kwargs
    ) -> Dict[str, Any]:
        """
        Prepare avatar by pre-computing embedding

        If an embedding for the same image bytes is already cached, the
        encoder is not called and the existing token is returned.

        Args:
            image_data: Image data as bytes
            appearance_encoder_func: Function to encode appearance; called as
                ``appearance_encoder_func(image_data, **encoder_kwargs)``
            **encoder_kwargs: Additional arguments for encoder

        Returns:
            Response with avatar token and expiration: a dict with keys
            ``avatar_token``, ``expires`` (ISO-8601 string) and ``cached``.

        Raises:
            RuntimeError: If encoding or storing the embedding fails; the
                original exception is attached as ``__cause__``.
        """
        # Check if already cached
        token = self.cache.generate_token(image_data)
        existing_embedding = self.cache.load_embedding(token)

        if existing_embedding is not None:
            # Already cached, return existing token
            metadata = self.cache.metadata.get(token, {})
            expires_at = datetime.fromtimestamp(metadata.get('expires_at', 0))

            return {
                'avatar_token': token,
                'expires': expires_at.isoformat(),
                'cached': True
            }

        # Compute new embedding
        try:
            embedding = appearance_encoder_func(image_data, **encoder_kwargs)

            # Store in cache
            token, expiration = self.cache.store_embedding(
                image_data,
                embedding,
                additional_info={'encoder_kwargs': encoder_kwargs}
            )

            return {
                'avatar_token': token,
                'expires': expiration.isoformat(),
                'cached': False
            }

        except Exception as e:
            # Chain explicitly so callers can inspect the root cause.
            raise RuntimeError(f"Failed to prepare avatar: {str(e)}") from e

    def validate_token(self, token: str) -> bool:
        """
        Validate if token is valid and not expired

        Validation performs a full cache load, so a token whose pickle file
        is missing or unreadable is reported as invalid.

        Args:
            token: Avatar token to validate

        Returns:
            True if valid, False otherwise
        """
        return self.cache.load_embedding(token) is not None

    def get_token_info(self, token: str) -> Optional[Dict[str, Any]]:
        """
        Get information about a token

        Args:
            token: Avatar token

        Returns:
            Token information if found, None otherwise. Timestamps are
            ISO-8601 strings in naive local time; ``valid`` only reflects
            the expiry time, not whether the cache file is readable.
        """
        info = self.cache.metadata.get(token)
        if info is None:
            return None

        current_time = time.time()

        return {
            'token': token,
            'valid': current_time <= info['expires_at'],
            'created_at': datetime.fromtimestamp(info['created_at']).isoformat(),
            'expires_at': datetime.fromtimestamp(info['expires_at']).isoformat(),
            'file_size_kb': info.get('file_size', 0) / 1024
        }