import os import re import json import glob from typing import List, Optional, Dict, Tuple import uuid import asyncio import time from concurrent.futures import ThreadPoolExecutor from functools import lru_cache import aiofiles from mllm_tools.utils import _prepare_text_inputs from src.utils.utils import extract_xml from task_generator import ( get_prompt_scene_plan, get_prompt_scene_vision_storyboard, get_prompt_scene_technical_implementation, get_prompt_scene_animation_narration, get_prompt_context_learning_scene_plan, get_prompt_context_learning_vision_storyboard, get_prompt_context_learning_technical_implementation, get_prompt_context_learning_animation_narration, get_prompt_context_learning_code ) from src.rag.rag_integration import RAGIntegration class EnhancedVideoPlanner: """Enhanced video planner with improved parallelization and performance.""" def __init__(self, planner_model, helper_model=None, output_dir="output", print_response=False, use_context_learning=False, context_learning_path="data/context_learning", use_rag=False, session_id=None, chroma_db_path="data/rag/chroma_db", manim_docs_path="data/rag/manim_docs", embedding_model="text-embedding-ada-002", use_langfuse=True, max_scene_concurrency=5, max_step_concurrency=3, enable_caching=True): self.planner_model = planner_model self.helper_model = helper_model if helper_model is not None else planner_model self.output_dir = output_dir self.print_response = print_response self.use_context_learning = use_context_learning self.context_learning_path = context_learning_path self.use_rag = use_rag self.session_id = session_id self.enable_caching = enable_caching # Enhanced concurrency control self.max_scene_concurrency = max_scene_concurrency self.max_step_concurrency = max_step_concurrency self.scene_semaphore = asyncio.Semaphore(max_scene_concurrency) self.step_semaphore = asyncio.Semaphore(max_step_concurrency) # Thread pool for I/O operations self.thread_pool = ThreadPoolExecutor(max_workers=4) # Cache for prompts and examples self._context_cache = {} self._prompt_cache = {} # Initialize context examples with caching self._initialize_context_examples() # Initialize RAG with enhanced settings self.rag_integration = None self.relevant_plugins = [] if use_rag: self.rag_integration = RAGIntegration( helper_model=helper_model, output_dir=output_dir, chroma_db_path=chroma_db_path, manim_docs_path=manim_docs_path, embedding_model=embedding_model, use_langfuse=use_langfuse, session_id=session_id ) def _initialize_context_examples(self): """Initialize and cache context examples for faster access.""" example_types = [ 'scene_plan', 'scene_vision_storyboard', 'technical_implementation', 'scene_animation_narration', 'code' ] if self.use_context_learning: for example_type in example_types: self._context_cache[example_type] = self._load_context_examples(example_type) else: for example_type in example_types: self._context_cache[example_type] = None @lru_cache(maxsize=128) def _get_cached_prompt(self, prompt_type: str, *args) -> str: """Get cached prompt to avoid regeneration.""" prompt_generators = { 'scene_plan': get_prompt_scene_plan, 'scene_vision_storyboard': get_prompt_scene_vision_storyboard, 'scene_technical_implementation': get_prompt_scene_technical_implementation, 'scene_animation_narration': get_prompt_scene_animation_narration } generator = prompt_generators.get(prompt_type) if generator: return generator(*args) return "" async def _async_file_write(self, file_path: str, content: str): """Asynchronous file writing for better performance.""" async with aiofiles.open(file_path, 'w', encoding='utf-8') as f: await f.write(content) async def _async_file_read(self, file_path: str) -> str: """Asynchronous file reading.""" try: async with aiofiles.open(file_path, 'r', encoding='utf-8') as f: return await f.read() except FileNotFoundError: return None async def _ensure_directories(self, *paths): """Asynchronously ensure directories exist.""" loop = asyncio.get_event_loop() for path in paths: await loop.run_in_executor(self.thread_pool, lambda p: os.makedirs(p, exist_ok=True), path) def _load_context_examples(self, example_type: str) -> str: """Load context learning examples with improved performance.""" if example_type in self._context_cache: return self._context_cache[example_type] examples = [] file_patterns = { 'scene_plan': '*_scene_plan.txt', 'scene_vision_storyboard': '*_scene_vision_storyboard.txt', 'technical_implementation': '*_technical_implementation.txt', 'scene_animation_narration': '*_scene_animation_narration.txt', 'code': '*.py' } pattern = file_patterns.get(example_type) if not pattern: return None # Use glob for faster file discovery search_pattern = os.path.join(self.context_learning_path, "**", pattern) for example_file in glob.glob(search_pattern, recursive=True): try: with open(example_file, 'r', encoding='utf-8') as f: content = f.read() examples.append(f"# Example from {os.path.basename(example_file)}\n{content}\n") except Exception as e: print(f"Warning: Could not load example {example_file}: {e}") if examples: formatted_examples = self._format_examples(example_type, examples) self._context_cache[example_type] = formatted_examples return formatted_examples return None def _format_examples(self, example_type: str, examples: List[str]) -> str: """Format examples using the appropriate template.""" templates = { 'scene_plan': get_prompt_context_learning_scene_plan, 'scene_vision_storyboard': get_prompt_context_learning_vision_storyboard, 'technical_implementation': get_prompt_context_learning_technical_implementation, 'scene_animation_narration': get_prompt_context_learning_animation_narration, 'code': get_prompt_context_learning_code } template = templates.get(example_type) if template: return template(examples="\n".join(examples)) return None async def generate_scene_outline(self, topic: str, description: str, session_id: str) -> str: """Enhanced scene outline generation with async I/O.""" start_time = time.time() # Detect relevant plugins upfront if RAG is enabled if self.use_rag and self.rag_integration: plugin_detection_task = asyncio.create_task( self._detect_plugins_async(topic, description) ) # Prepare prompt with cached examples prompt = self._get_cached_prompt('scene_plan', topic, description) if self.use_context_learning and self._context_cache.get('scene_plan'): prompt += f"\n\nHere are some example scene plans for reference:\n{self._context_cache['scene_plan']}" # Wait for plugin detection if enabled if self.use_rag and self.rag_integration: self.relevant_plugins = await plugin_detection_task print(f"✅ Detected relevant plugins: {self.relevant_plugins}") # Generate plan using planner model response_text = self.planner_model( _prepare_text_inputs(prompt), metadata={ "generation_name": "scene_outline", "tags": [topic, "scene-outline"], "session_id": session_id } ) # Extract scene outline with improved error handling scene_outline = self._extract_scene_outline_robust(response_text) # Async file operations file_prefix = re.sub(r'[^a-z0-9_]+', '_', topic.lower()) output_dir = os.path.join(self.output_dir, file_prefix) await self._ensure_directories(output_dir) file_path = os.path.join(output_dir, f"{file_prefix}_scene_outline.txt") await self._async_file_write(file_path, scene_outline) elapsed_time = time.time() - start_time print(f"Scene outline generated in {elapsed_time:.2f}s - saved to {file_prefix}_scene_outline.txt") return scene_outline async def _detect_plugins_async(self, topic: str, description: str) -> List[str]: """Asynchronously detect relevant plugins.""" loop = asyncio.get_event_loop() return await loop.run_in_executor( self.thread_pool, lambda: self.rag_integration.detect_relevant_plugins(topic, description) or [] ) async def _generate_scene_step_parallel(self, step_name: str, prompt_func, scene_trace_id: str, topic: str, scene_number: int, session_id: str, output_path: str, *args) -> Tuple[str, str]: """Generate a single scene step with async operations.""" async with self.step_semaphore: # Control step-level concurrency # Check cache first if enabled if self.enable_caching: cached_content = await self._async_file_read(output_path) if cached_content: print(f"Using cached {step_name} for scene {scene_number}") return cached_content, output_path print(f"🚀 Generating {step_name} for scene {scene_number}") start_time = time.time() # Generate prompt prompt = prompt_func(*args) # Add context examples if available example_type = step_name.replace('_plan', '').replace('scene_', '') if self._context_cache.get(example_type): prompt += f"\n\nHere are some example {step_name}s:\n{self._context_cache[example_type]}" # Add RAG context if enabled if self.use_rag and self.rag_integration: rag_queries = await self._generate_rag_queries_async( step_name, args, scene_trace_id, topic, scene_number, session_id ) if rag_queries: retrieved_docs = self.rag_integration.get_relevant_docs( rag_queries=rag_queries, scene_trace_id=scene_trace_id, topic=topic, scene_number=scene_number ) prompt += f"\n\n{retrieved_docs}" # Generate content response = self.planner_model( _prepare_text_inputs(prompt), metadata={ "generation_name": step_name, "trace_id": scene_trace_id, "tags": [topic, f"scene{scene_number}"], "session_id": session_id } ) # Extract content using step-specific patterns extraction_patterns = { 'scene_vision_storyboard': r'(.*?)', 'scene_technical_implementation': r'(.*?)', 'scene_animation_narration': r'(.*?)' } pattern = extraction_patterns.get(step_name) if pattern: match = re.search(pattern, response, re.DOTALL) content = match.group(1) if match else response else: content = response # Async file save await self._async_file_write(output_path, content) elapsed_time = time.time() - start_time print(f"{step_name} for scene {scene_number} completed in {elapsed_time:.2f}s") return content, output_path async def _generate_rag_queries_async(self, step_name: str, args: tuple, scene_trace_id: str, topic: str, scene_number: int, session_id: str) -> List[Dict]: """Generate RAG queries asynchronously based on step type.""" query_generators = { 'scene_vision_storyboard': self.rag_integration._generate_rag_queries_storyboard, 'scene_technical_implementation': self.rag_integration._generate_rag_queries_technical, 'scene_animation_narration': self.rag_integration._generate_rag_queries_narration } generator = query_generators.get(step_name) if not generator: return [] # Map args to appropriate parameters based on step if step_name == 'scene_vision_storyboard': scene_plan = args[3] if len(args) > 3 else "" return generator( scene_plan=scene_plan, scene_trace_id=scene_trace_id, topic=topic, scene_number=scene_number, session_id=session_id, relevant_plugins=self.relevant_plugins ) elif step_name == 'scene_technical_implementation': storyboard = args[4] if len(args) > 4 else "" return generator( storyboard=storyboard, scene_trace_id=scene_trace_id, topic=topic, scene_number=scene_number, session_id=session_id, relevant_plugins=self.relevant_plugins ) elif step_name == 'scene_animation_narration': storyboard = args[4] if len(args) > 4 else "" return generator( storyboard=storyboard, scene_trace_id=scene_trace_id, topic=topic, scene_number=scene_number, session_id=session_id, relevant_plugins=self.relevant_plugins ) return [] async def _generate_scene_implementation_single_enhanced(self, topic: str, description: str, scene_outline_i: str, scene_number: int, file_prefix: str, session_id: str, scene_trace_id: str) -> str: """Enhanced single scene implementation with parallel steps.""" start_time = time.time() print(f"Starting scene {scene_number} implementation (parallel processing)") # Setup directories scene_dir = os.path.join(self.output_dir, file_prefix, f"scene{scene_number}") subplan_dir = os.path.join(scene_dir, "subplans") await self._ensure_directories(scene_dir, subplan_dir) # Save scene trace ID trace_id_file = os.path.join(subplan_dir, "scene_trace_id.txt") await self._async_file_write(trace_id_file, scene_trace_id) # Define all steps with their configurations steps_config = [ { 'name': 'scene_vision_storyboard', 'prompt_func': get_prompt_scene_vision_storyboard, 'args': (scene_number, topic, description, scene_outline_i, self.relevant_plugins), 'output_path': os.path.join(subplan_dir, f"{file_prefix}_scene{scene_number}_vision_storyboard_plan.txt") } ] # Execute Step 1: Vision Storyboard (sequential dependency) vision_storyboard_content, _ = await self._generate_scene_step_parallel( steps_config[0]['name'], steps_config[0]['prompt_func'], scene_trace_id, topic, scene_number, session_id, steps_config[0]['output_path'], *steps_config[0]['args'] ) # Prepare Step 2 and 3 for parallel execution (both depend on Step 1) remaining_steps = [ { 'name': 'scene_technical_implementation', 'prompt_func': get_prompt_scene_technical_implementation, 'args': (scene_number, topic, description, scene_outline_i, vision_storyboard_content, self.relevant_plugins), 'output_path': os.path.join(subplan_dir, f"{file_prefix}_scene{scene_number}_technical_implementation_plan.txt") }, { 'name': 'scene_animation_narration', 'prompt_func': get_prompt_scene_animation_narration, 'args': (scene_number, topic, description, scene_outline_i, vision_storyboard_content, None, self.relevant_plugins), 'output_path': os.path.join(subplan_dir, f"{file_prefix}_scene{scene_number}_animation_narration_plan.txt") } ] # Execute Steps 2 and 3 in parallel parallel_tasks = [] for step_config in remaining_steps: task = asyncio.create_task( self._generate_scene_step_parallel( step_config['name'], step_config['prompt_func'], scene_trace_id, topic, scene_number, session_id, step_config['output_path'], *step_config['args'] ) ) parallel_tasks.append(task) # Wait for parallel tasks to complete parallel_results = await asyncio.gather(*parallel_tasks) technical_implementation_content = parallel_results[0][0] animation_narration_content = parallel_results[1][0] # Update animation narration args with technical implementation and regenerate if needed if technical_implementation_content: updated_animation_args = ( scene_number, topic, description, scene_outline_i, vision_storyboard_content, technical_implementation_content, self.relevant_plugins ) animation_narration_content, _ = await self._generate_scene_step_parallel( 'scene_animation_narration', get_prompt_scene_animation_narration, scene_trace_id, topic, scene_number, session_id, remaining_steps[1]['output_path'], *updated_animation_args ) # Combine all implementation plans implementation_plan = ( f"{vision_storyboard_content}\n\n" f"{technical_implementation_content}\n\n" f"{animation_narration_content}\n\n" ) # Ensure scene directory exists (just to be extra safe) scene_dir = os.path.join(self.output_dir, file_prefix, f"scene{scene_number}") await self._ensure_directories(scene_dir) # Save combined implementation plan combined_plan_path = os.path.join(scene_dir, f"{file_prefix}_scene{scene_number}_implementation_plan.txt") combined_content = f"# Scene {scene_number} Implementation Plan\n\n{implementation_plan}" try: await self._async_file_write(combined_plan_path, combined_content) print(f"✅ Saved implementation plan for scene {scene_number} to: {combined_plan_path}") except Exception as e: print(f"❌ Error saving implementation plan for scene {scene_number}: {e}") raise elapsed_time = time.time() - start_time print(f"Scene {scene_number} implementation completed in {elapsed_time:.2f}s") return implementation_plan async def generate_scene_implementation_concurrently_enhanced(self, topic: str, description: str, plan: str, session_id: str) -> List[str]: """Enhanced concurrent scene implementation with better performance.""" start_time = time.time() # Extract scene information scene_outline = extract_xml(plan) scene_number = len(re.findall(r'[^<]', scene_outline)) file_prefix = re.sub(r'[^a-z0-9_]+', '_', topic.lower()) print(f"Starting implementation generation for {scene_number} scenes with max concurrency: {self.max_scene_concurrency}") async def generate_single_scene_implementation(i): async with self.scene_semaphore: # Control scene-level concurrency scene_regex = r'(.*?)'.format(i) scene_match = re.search( scene_regex, scene_outline, re.DOTALL ) if not scene_match: print(f"❌ Error: Could not find scene {i} in scene outline. Regex pattern: {scene_regex}") raise ValueError(f"Scene {i} not found in scene outline") scene_outline_i = scene_match.group(1) scene_trace_id = str(uuid.uuid4()) return await self._generate_scene_implementation_single_enhanced( topic, description, scene_outline_i, i, file_prefix, session_id, scene_trace_id ) # Create tasks for all scenes tasks = [generate_single_scene_implementation(i + 1) for i in range(scene_number)] # Execute with progress tracking print(f"Executing {len(tasks)} scene implementation tasks...") try: all_scene_implementation_plans = await asyncio.gather(*tasks, return_exceptions=True) # Handle any exceptions successful_plans = [] error_count = 0 for i, result in enumerate(all_scene_implementation_plans): if isinstance(result, Exception): print(f"❌ Error in scene {i+1}: {result}") error_message = f"# Scene {i+1} - Error: {result}" successful_plans.append(error_message) # Write error to file to maintain file structure even on failure scene_dir = os.path.join(self.output_dir, file_prefix, f"scene{i+1}") os.makedirs(scene_dir, exist_ok=True) error_file_path = os.path.join(scene_dir, f"{file_prefix}_scene{i+1}_implementation_plan.txt") try: with open(error_file_path, 'w') as f: f.write(error_message) except Exception as e: print(f"❌ Failed to write error file for scene {i+1}: {e}") error_count += 1 else: successful_plans.append(result) print(f"✅ Successfully generated implementation plan for scene {i+1}") total_time = time.time() - start_time print(f"All scene implementations completed in {total_time:.2f}s") print(f" Average time per scene: {total_time/len(tasks):.2f}s") print(f" Success rate: {len(tasks) - error_count}/{len(tasks)} scenes ({(len(tasks) - error_count) / len(tasks) * 100:.1f}%)") if error_count > 0: print(f"⚠️ Warning: {error_count} scenes had errors during implementation plan generation") except Exception as e: print(f"❌ Fatal error during scene implementation tasks: {e}") raise return successful_plans async def __aenter__(self): """Async context manager entry.""" return self async def __aexit__(self, exc_type, exc_val, exc_tb): """Async context manager exit - cleanup resources.""" self.thread_pool.shutdown(wait=True) # Legacy method compatibility async def generate_scene_implementation_concurrently(self, topic: str, description: str, plan: str, session_id: str, scene_semaphore=None) -> List[str]: """Legacy compatibility method - redirects to enhanced version.""" if scene_semaphore: self.scene_semaphore = scene_semaphore return await self.generate_scene_implementation_concurrently_enhanced( topic, description, plan, session_id ) def _extract_scene_outline_robust(self, response_text: str) -> str: """ Robust extraction of scene outline that handles various XML format issues. This method addresses common problems: 1. XML wrapped in markdown code blocks 2. Missing closing tags 3. Malformed XML structure 4. Extra text before/after XML """ import re # First try: Look for XML wrapped in markdown code blocks markdown_xml_pattern = r'```xml\s*\n(.*?)\s*\n```' markdown_match = re.search(markdown_xml_pattern, response_text, re.DOTALL) if markdown_match: xml_content = markdown_match.group(1) return self._validate_and_fix_xml(xml_content) # Second try: Look for direct XML tags direct_xml_pattern = r'(.*?)' direct_match = re.search(direct_xml_pattern, response_text, re.DOTALL) if direct_match: xml_content = direct_match.group(1) return self._validate_and_fix_xml(xml_content) # Third try: Look for incomplete XML and attempt to fix incomplete_pattern = r'(.*?)(?:|$)' incomplete_match = re.search(incomplete_pattern, response_text, re.DOTALL) if incomplete_match: xml_content = incomplete_match.group(1) # Add missing closing tag if needed full_xml = f"{xml_content}" return self._validate_and_fix_xml(full_xml) # If no XML structure found, return the entire response but warn print("⚠️ Warning: No valid XML structure found in LLM response. Using full response.") print("Response preview:", response_text[:200] + "..." if len(response_text) > 200 else response_text) return response_text def _validate_and_fix_xml(self, xml_content: str) -> str: """ Validate and fix common XML issues in scene outlines. """ import re # Check for unclosed scene tags scene_pattern = r'' scene_matches = re.findall(scene_pattern, xml_content) fixed_content = xml_content for scene_num in scene_matches: # Check if this scene has a proper closing tag open_tag = f"" close_tag = f"" # Find the position of this scene's opening tag open_pos = fixed_content.find(open_tag) if open_pos == -1: continue # Find the next scene's opening tag (if any) next_scene_pattern = f"" next_scene_pos = fixed_content.find(next_scene_pattern, open_pos) # Check if there's a closing tag before the next scene close_pos = fixed_content.find(close_tag, open_pos) if close_pos == -1 or (next_scene_pos != -1 and close_pos > next_scene_pos): # Missing or misplaced closing tag if next_scene_pos != -1: # Insert closing tag before next scene insert_pos = next_scene_pos while insert_pos > 0 and fixed_content[insert_pos - 1] in ' \n\t': insert_pos -= 1 fixed_content = (fixed_content[:insert_pos] + f"\n {close_tag}\n\n " + fixed_content[insert_pos:]) else: # Insert closing tag at the end end_outline_pos = fixed_content.find("") if end_outline_pos != -1: fixed_content = (fixed_content[:end_outline_pos] + f"\n {close_tag}\n" + fixed_content[end_outline_pos:]) else: fixed_content += f"\n {close_tag}" print(f"🔧 Fixed missing closing tag for SCENE_{scene_num}") # Ensure proper SCENE_OUTLINE structure if not fixed_content.strip().startswith(""): fixed_content = f"\n{fixed_content}" if not fixed_content.strip().endswith(""): fixed_content = f"{fixed_content}\n" return fixed_content # Update class alias for backward compatibility VideoPlanner = EnhancedVideoPlanner