Spaces:
Runtime error
Runtime error
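"""
Avatar cache system for DittoTalkingHead.

Implements image pre-upload support: appearance embeddings are pre-computed
once per image and cached on disk so later video-generation requests can
reuse them instead of re-encoding the image.
"""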
import hashlib
import json
import os
import pickle
import time
from datetime import datetime, timedelta
from pathlib import Path
from typing import Any, Callable, Dict, Optional, Tuple
class AvatarCache:
    """
    Avatar embedding cache system.

    Stores pre-computed image embeddings on disk for faster video generation.
    Each entry is keyed by the SHA-1 hex digest of the raw image bytes and is
    stored as ``<cache_dir>/<token>.pkl``. A sidecar ``metadata.json`` records
    ``created_at``, ``expires_at`` and ``file_size`` for every token, so expiry
    checks and statistics never need to unpickle an embedding.

    NOTE(review): entries are deserialized with ``pickle``, which can execute
    arbitrary code. Only point ``cache_dir`` at a directory that untrusted
    users cannot write to; the ``/tmp`` default may be shared on some hosts.
    """

    def __init__(self, cache_dir: str = "/tmp/avatar_cache", ttl_days: int = 14):
        """
        Initialize avatar cache.

        Args:
            cache_dir: Directory to store cache files (created if missing).
            ttl_days: Time to live for cache entries in days.
        """
        self.cache_dir = Path(cache_dir)
        self.cache_dir.mkdir(parents=True, exist_ok=True)
        self.ttl_seconds = ttl_days * 24 * 60 * 60
        self.metadata_file = self.cache_dir / "metadata.json"
        # Load existing metadata, then drop entries that expired (or are
        # malformed) while the process was not running.
        self.metadata = self._load_metadata()
        self._cleanup_expired()

    def _cache_path(self, token: str) -> Path:
        """Return the pickle file path used to store ``token``."""
        return self.cache_dir / f"{token}.pkl"

    @staticmethod
    def _expiry_of(info: Any) -> float:
        """Return an entry's ``expires_at`` timestamp, or 0.0 if absent or not numeric."""
        if isinstance(info, dict):
            value = info.get('expires_at')
            if isinstance(value, (int, float)):
                return float(value)
        return 0.0

    def _load_metadata(self) -> Dict[str, Any]:
        """
        Load cache metadata.

        Returns:
            The token -> info mapping read from ``metadata.json``, or an empty
            dict if the file is missing, unreadable, not valid JSON, or does
            not contain a JSON object.
        """
        if not self.metadata_file.exists():
            return {}
        try:
            with open(self.metadata_file, 'r', encoding='utf-8') as f:
                data = json.load(f)
        except (OSError, ValueError):
            # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
            # A corrupt index means starting empty, not failing construction.
            return {}
        # Anything other than a JSON object cannot serve as token metadata.
        return data if isinstance(data, dict) else {}

    def _save_metadata(self):
        """
        Save cache metadata.

        The JSON is written to a temporary sibling file and moved into place
        with ``os.replace``, so an interrupted write cannot leave a truncated
        ``metadata.json`` (which would silently empty the whole cache index).
        """
        tmp_file = self.metadata_file.with_name(self.metadata_file.name + ".tmp")
        with open(tmp_file, 'w', encoding='utf-8') as f:
            json.dump(self.metadata, f, indent=2)
        os.replace(tmp_file, self.metadata_file)

    def _cleanup_expired(self):
        """
        Remove expired cache entries (pickle files and metadata).

        Entries without a numeric ``expires_at`` are treated as expired, so a
        hand-edited or partially written metadata file cannot crash callers.
        """
        now = time.time()
        expired_tokens = [
            token for token, info in self.metadata.items()
            if now > self._expiry_of(info)
        ]
        for token in expired_tokens:
            del self.metadata[token]
            cache_file = self._cache_path(token)
            if cache_file.exists():
                cache_file.unlink()
        if expired_tokens:
            self._save_metadata()
            print(f"Cleaned up {len(expired_tokens)} expired cache entries")

    def generate_token(self, img_bytes: bytes) -> str:
        """
        Generate unique token for image.

        Args:
            img_bytes: Image data as bytes.

        Returns:
            SHA-1 hex digest of ``img_bytes`` (identical images share a token).
        """
        return hashlib.sha1(img_bytes).hexdigest()

    def store_embedding(
        self,
        img_bytes: bytes,
        embedding: Any,
        additional_info: Optional[Dict[str, Any]] = None
    ) -> Tuple[str, datetime]:
        """
        Store image embedding in cache.

        An existing entry for the same image is overwritten and its TTL reset.

        Args:
            img_bytes: Image data as bytes.
            embedding: Pre-computed embedding (latent vector); must be picklable.
            additional_info: Additional metadata to store alongside the embedding.

        Returns:
            Tuple of (token, expiration_date), where expiration_date is a naive
            local-time datetime.
        """
        token = self.generate_token(img_bytes)
        cache_file = self._cache_path(token)
        # One timestamp for both the payload and the metadata so they agree.
        created_at = time.time()
        expires_at = created_at + self.ttl_seconds
        expiration_date = datetime.fromtimestamp(expires_at)

        cache_data = {
            'embedding': embedding,
            'created_at': created_at,
            'expires_at': expires_at,
            'additional_info': additional_info or {},
        }
        # Write-then-rename so a reader never observes a half-written pickle.
        tmp_file = cache_file.with_name(cache_file.name + ".tmp")
        try:
            with open(tmp_file, 'wb') as f:
                pickle.dump(cache_data, f)
            os.replace(tmp_file, cache_file)
        finally:
            # Only still present if pickling or the rename failed.
            if tmp_file.exists():
                tmp_file.unlink()

        self.metadata[token] = {
            'expires_at': expires_at,
            'created_at': created_at,
            'file_size': cache_file.stat().st_size,
        }
        self._save_metadata()
        return token, expiration_date

    def load_embedding(self, token: str) -> Optional[Any]:
        """
        Load embedding from cache.

        Args:
            token: Avatar token.

        Returns:
            The stored embedding if the token is known, unexpired, and its
            file can be read; None otherwise. A missing file removes the token
            from the metadata; an expired token triggers a full cleanup.
        """
        info = self.metadata.get(token)
        if info is None:
            return None
        if time.time() > self._expiry_of(info):
            self._cleanup_expired()
            return None

        cache_file = self._cache_path(token)
        if not cache_file.exists():
            # File vanished underneath us: forget the stale metadata entry.
            del self.metadata[token]
            self._save_metadata()
            return None

        try:
            # SECURITY: pickle.load runs arbitrary code from the file; see the
            # class docstring about keeping cache_dir non-world-writable.
            with open(cache_file, 'rb') as f:
                cache_data = pickle.load(f)
            return cache_data['embedding']
        except Exception as e:
            # Best-effort read: a corrupt entry is reported as a cache miss.
            print(f"Error loading cache for token {token}: {e}")
            return None

    def get_cache_info(self) -> Dict[str, Any]:
        """
        Get cache statistics.

        Returns:
            Dict with ``cache_dir``, ``active_entries`` (unexpired),
            ``total_entries`` (all metadata entries), ``total_size_mb`` (summed
            file size of the unexpired entries only) and ``ttl_days``.
        """
        now = time.time()
        total_size = 0
        active_entries = 0
        for info in self.metadata.values():
            if now <= self._expiry_of(info):
                active_entries += 1
                total_size += info.get('file_size', 0)
        return {
            'cache_dir': str(self.cache_dir),
            'active_entries': active_entries,
            'total_entries': len(self.metadata),
            'total_size_mb': total_size / (1024 * 1024),
            'ttl_days': self.ttl_seconds / (24 * 60 * 60),
        }

    def clear_cache(self):
        """Clear all cache entries: delete every ``*.pkl`` file and reset metadata."""
        for file in self.cache_dir.glob("*.pkl"):
            file.unlink()
        self.metadata = {}
        self._save_metadata()
        print("Avatar cache cleared")
class AvatarTokenManager:
    """
    Manages avatar tokens and their lifecycle on top of an AvatarCache.

    Tokens come from ``AvatarCache.generate_token`` (a hash of the image
    bytes), so preparing the same image twice yields the same token and the
    second call reuses the cached embedding.
    """

    def __init__(self, cache: AvatarCache):
        """
        Initialize token manager.

        Args:
            cache: Avatar cache instance used to store and look up embeddings.
        """
        self.cache = cache

    def prepare_avatar(
        self,
        image_data: bytes,
        appearance_encoder_func: Callable[..., Any],
        **encoder_kwargs
    ) -> Dict[str, Any]:
        """
        Prepare avatar by pre-computing (or reusing) its embedding.

        Args:
            image_data: Image data as bytes.
            appearance_encoder_func: Called as
                ``appearance_encoder_func(image_data, **encoder_kwargs)`` to
                produce the embedding on a cache miss.
            **encoder_kwargs: Additional arguments for the encoder; also saved
                in the cache entry's ``additional_info``.

        Returns:
            Dict with ``avatar_token``, ``expires`` (ISO-8601 naive local
            time) and ``cached`` (True if an existing embedding was reused).

        Raises:
            RuntimeError: If encoding or storing the embedding fails; the
                underlying exception is chained as ``__cause__``.
        """
        token = self.cache.generate_token(image_data)
        if self.cache.load_embedding(token) is not None:
            # Cache hit: report the existing entry's expiry instead of re-encoding.
            metadata = self.cache.metadata.get(token, {})
            expires_at = datetime.fromtimestamp(metadata.get('expires_at', 0))
            return {
                'avatar_token': token,
                'expires': expires_at.isoformat(),
                'cached': True,
            }

        try:
            embedding = appearance_encoder_func(image_data, **encoder_kwargs)
            token, expiration = self.cache.store_embedding(
                image_data,
                embedding,
                additional_info={'encoder_kwargs': encoder_kwargs},
            )
        except Exception as e:
            # Chain the cause so the encoder/storage traceback is not lost.
            raise RuntimeError(f"Failed to prepare avatar: {str(e)}") from e

        return {
            'avatar_token': token,
            'expires': expiration.isoformat(),
            'cached': False,
        }

    def validate_token(self, token: str) -> bool:
        """
        Validate if token is valid and not expired.

        Args:
            token: Avatar token to validate.

        Returns:
            True if the cache can currently return an embedding for the token.
        """
        return self.cache.load_embedding(token) is not None

    def get_token_info(self, token: str) -> Optional[Dict[str, Any]]:
        """
        Get information about a token from the cache metadata.

        Unlike ``validate_token`` this does not read the embedding file, so an
        entry may be reported ``valid`` even if its file is missing.

        Args:
            token: Avatar token.

        Returns:
            Dict with ``token``, ``valid``, ``created_at``/``expires_at``
            (ISO-8601 naive local time) and ``file_size_kb``, or None if the
            token is unknown.
        """
        info = self.cache.metadata.get(token)
        if info is None:
            return None
        return {
            'token': token,
            'valid': time.time() <= info['expires_at'],
            'created_at': datetime.fromtimestamp(info['created_at']).isoformat(),
            'expires_at': datetime.fromtimestamp(info['expires_at']).isoformat(),
            'file_size_kb': info.get('file_size', 0) / 1024,
        }