""" Parallel Inference Integration for DittoTalkingHead Integrates parallel processing into the inference pipeline """ import asyncio import time from typing import Dict, Any, Tuple, Optional import numpy as np import torch from pathlib import Path from .parallel_processing import ParallelProcessor, PipelineProcessor class ParallelInference: """ Parallel inference wrapper for DittoTalkingHead """ def __init__(self, sdk, parallel_processor: Optional[ParallelProcessor] = None): """ Initialize parallel inference Args: sdk: StreamSDK instance parallel_processor: ParallelProcessor instance """ self.sdk = sdk self.parallel_processor = parallel_processor or ParallelProcessor(num_threads=4) # Setup pipeline stages self.pipeline_stages = { 'load': self._load_files, 'preprocess': self._preprocess, 'inference': self._inference, 'postprocess': self._postprocess } def _load_files(self, paths: Dict[str, str]) -> Dict[str, Any]: """Load audio and image files""" audio_path = paths['audio'] image_path = paths['image'] # Parallel loading audio_data, image_data = self.parallel_processor.preprocess_parallel_sync( audio_path, image_path ) return { 'audio_data': audio_data, 'image_data': image_data, 'paths': paths } def _preprocess(self, data: Dict[str, Any]) -> Dict[str, Any]: """Preprocess loaded data""" # Extract audio features audio = data['audio_data']['audio'] sr = data['audio_data']['sample_rate'] # Prepare for SDK import librosa import math # Calculate number of frames num_frames = math.ceil(len(audio) / 16000 * 25) # Prepare image image = data['image_data']['image'] return { 'audio': audio, 'image': image, 'num_frames': num_frames, 'paths': data['paths'] } def _inference(self, data: Dict[str, Any]) -> Dict[str, Any]: """Run inference""" # This would integrate with the actual SDK inference # For now, placeholder return { 'result': 'inference_result', 'paths': data['paths'] } def _postprocess(self, data: Dict[str, Any]) -> Dict[str, Any]: """Postprocess results""" return data async def process_parallel_async( self, audio_path: str, image_path: str, output_path: str, **kwargs ) -> Tuple[str, float]: """ Process with full parallelization (async) Args: audio_path: Path to audio file image_path: Path to image file output_path: Output video path **kwargs: Additional parameters Returns: Tuple of (output_path, process_time) """ start_time = time.time() # Parallel preprocessing audio_data, image_data = await self.parallel_processor.preprocess_parallel_async( audio_path, image_path, kwargs.get('target_size', 320) ) # Run inference (simplified for integration) # In real implementation, this would call SDK methods process_time = time.time() - start_time return output_path, process_time def process_parallel_sync( self, audio_path: str, image_path: str, output_path: str, **kwargs ) -> Tuple[str, float]: """ Process with parallelization (sync) Args: audio_path: Path to audio file image_path: Path to image file output_path: Output video path **kwargs: Additional parameters Returns: Tuple of (output_path, process_time) """ start_time = time.time() try: # Parallel preprocessing print("🔄 Starting parallel preprocessing...") preprocess_start = time.time() audio_data, image_data = self.parallel_processor.preprocess_parallel_sync( audio_path, image_path, kwargs.get('target_size', 320) ) preprocess_time = time.time() - preprocess_start print(f"✅ Parallel preprocessing completed in {preprocess_time:.2f}s") # Run actual SDK inference # This integrates with the existing SDK from inference import run, seed_everything seed_everything(kwargs.get('seed', 1024)) inference_start = time.time() run(self.sdk, audio_path, image_path, output_path, more_kwargs=kwargs.get('more_kwargs', {})) inference_time = time.time() - inference_start print(f"✅ Inference completed in {inference_time:.2f}s") total_time = time.time() - start_time # Performance breakdown print(f""" 🎯 Performance Breakdown: - Preprocessing (parallel): {preprocess_time:.2f}s - Inference: {inference_time:.2f}s - Total: {total_time:.2f}s """) return output_path, total_time except Exception as e: print(f"❌ Error in parallel processing: {e}") raise def get_performance_stats(self) -> Dict[str, Any]: """Get performance statistics""" return { 'num_threads': self.parallel_processor.num_threads, 'num_processes': self.parallel_processor.num_processes, 'cuda_streams_enabled': self.parallel_processor.use_cuda_streams } class OptimizedInferenceWrapper: """ Wrapper that combines all optimizations """ def __init__( self, sdk, use_parallel: bool = True, use_cache: bool = True, use_gpu_opt: bool = True ): """ Initialize optimized inference wrapper Args: sdk: StreamSDK instance use_parallel: Enable parallel processing use_cache: Enable caching use_gpu_opt: Enable GPU optimizations """ self.sdk = sdk self.use_parallel = use_parallel self.use_cache = use_cache self.use_gpu_opt = use_gpu_opt # Initialize components if use_parallel: self.parallel_processor = ParallelProcessor(num_threads=4) self.parallel_inference = ParallelInference(sdk, self.parallel_processor) else: self.parallel_processor = None self.parallel_inference = None def process( self, audio_path: str, image_path: str, output_path: str, **kwargs ) -> Tuple[str, float, Dict[str, Any]]: """ Process with all optimizations Returns: Tuple of (output_path, process_time, stats) """ stats = { 'parallel_enabled': self.use_parallel, 'cache_enabled': self.use_cache, 'gpu_opt_enabled': self.use_gpu_opt } if self.use_parallel and self.parallel_inference: output_path, process_time = self.parallel_inference.process_parallel_sync( audio_path, image_path, output_path, **kwargs ) stats['preprocessing'] = 'parallel' else: # Fallback to sequential from inference import run, seed_everything start_time = time.time() seed_everything(kwargs.get('seed', 1024)) run(self.sdk, audio_path, image_path, output_path, more_kwargs=kwargs.get('more_kwargs', {})) process_time = time.time() - start_time stats['preprocessing'] = 'sequential' stats['process_time'] = process_time return output_path, process_time, stats def shutdown(self): """Cleanup resources""" if self.parallel_processor: self.parallel_processor.shutdown()