diff --git "a/tools.py" "b/tools.py" --- "a/tools.py" +++ "b/tools.py" @@ -15,6 +15,13 @@ from langchain_huggingface import HuggingFacePipeline from typing import TypedDict, List, Optional, Dict, Any, Annotated, Literal, Union, Tuple, Set import time from collections import Counter +from pydantic import Field +import hashlib +import json +import numpy as np +import ast +from concurrent.futures import ThreadPoolExecutor, as_completed +from collections import Counter, defaultdict # Third-Party Packages import cv2 @@ -28,6 +35,15 @@ from bs4 import BeautifulSoup from duckduckgo_search import DDGS from sentence_transformers import SentenceTransformer from transformers import BlipProcessor, BlipForQuestionAnswering, pipeline, AutoTokenizer +from sklearn.feature_extraction.text import TfidfVectorizer +from sklearn.metrics.pairwise import cosine_similarity +from sklearn.cluster import KMeans +from sklearn.preprocessing import StandardScaler +import speech_recognition as sr +from pydub import AudioSegment +from pydub.silence import split_on_silence +import nltk +from nltk.corpus import words # LangChain Ecosystem from langchain.docstore.document import Document @@ -59,9 +75,12 @@ from datetime import datetime from urllib.parse import urljoin, urlparse import logging - nlp = spacy.load("en_core_web_sm") +# Ensure the word list is downloaded +nltk.download('words', quiet=True) +english_words = set(words.words()) + logger = logging.getLogger(__name__) # --- Model Configuration --- @@ -331,7 +350,6 @@ class AgentState(TypedDict): answer: Optional[str] frame_answers: Optional[list] - def fetch_page_with_tables(page_title): """ Fetches Wikipedia page content and extracts all tables as readable text. @@ -431,7 +449,7 @@ class WikipediaSearchToolWithFAISS(BaseTool): candidates_set.add(query) if processed_query and processed_query != query: candidates_set.add(processed_query) - + if entity_prefix and keywords: first_entity_lower = entity_prefix.lower() for kw in keywords[:3]: @@ -439,10 +457,10 @@ class WikipediaSearchToolWithFAISS(BaseTool): candidates_set.add(f"{entity_prefix} {kw}") keyword_combo_short = " ".join(k for k in keywords[:2] if k not in first_entity_lower and len(k)>2) if keyword_combo_short: candidates_set.add(f"{entity_prefix} {keyword_combo_short}") - + if len(main_entities) > 1: candidates_set.add(" ".join(main_entities[:2])) - + if keywords: keyword_combo = " ".join(keywords[:2]) if entity_prefix: @@ -451,13 +469,13 @@ class WikipediaSearchToolWithFAISS(BaseTool): candidates_set.add(candidate_to_add) elif not main_entities: candidates_set.add(keyword_combo) - + ordered_candidates = [] for me in main_entities: if me not in ordered_candidates: ordered_candidates.append(me) for c in list(candidates_set): if c and c.strip() and c not in ordered_candidates: ordered_candidates.append(c) - + print(f"Generated {len(ordered_candidates)} search candidates for Wikipedia page lookup (entity-prioritized): {ordered_candidates}") return ordered_candidates @@ -483,28 +501,28 @@ class WikipediaSearchToolWithFAISS(BaseTool): except (wikipedia.exceptions.PageError, wikipedia.exceptions.DisambiguationError): print(f" - auto_suggest=False failed for entity-focused '{page_to_load}', trying with auto_suggest=True.") # Fallthrough to auto_suggest=True below if this fails - + if suggest_mode: # If not attempted or failed with auto_suggest=False temp_page = wikipedia.page(page_to_load, auto_suggest=True, redirect=True) - + final_page_title = temp_page.title - + if is_candidate_entity_focused and 
main_entities_from_query: title_matches_main_entity = any(me.lower() in final_page_title.lower() for me in main_entities_from_query) if not title_matches_main_entity: print(f" ! Page title '{final_page_title}' (from entity-focused candidate '{candidate_query}') " f"does not strongly match main query entities: {main_entities_from_query}. Skipping.") - continue + continue if final_page_title in processed_page_titles: print(f" ~ Already processed '{final_page_title}'") continue page_object = temp_page print(f" โœ“ Direct hit/suggestion for '{candidate_query}' -> '{final_page_title}'") - + except wikipedia.exceptions.PageError: if i < max(2, len(candidates) // 3) : # Try Wikipedia search for a smaller, more promising subset of candidates print(f" - Direct access failed for '{candidate_query}'. Trying Wikipedia search...") - search_results = wikipedia.search(candidate_query, results=1) + search_results = wikipedia.search(candidate_query, results=1) if not search_results: print(f" - No Wikipedia search results for '{candidate_query}'.") continue @@ -520,7 +538,7 @@ class WikipediaSearchToolWithFAISS(BaseTool): continue if final_page_title in processed_page_titles: print(f" ~ Already processed '{final_page_title}'") - continue + continue page_object = temp_page print(f" โœ“ Found via search '{candidate_query}' -> '{search_result_title}' -> '{final_page_title}'") except (wikipedia.exceptions.PageError, wikipedia.exceptions.DisambiguationError) as e_sr: @@ -547,7 +565,7 @@ class WikipediaSearchToolWithFAISS(BaseTool): print(f" โœ“ Resolved disambiguation '{candidate_query}' -> '{option_title}' -> '{final_page_title}'") except Exception as e_dis_opt: print(f" ! Could not load disambiguation option '{option_title}': {e_dis_opt}") - + if page_object and final_page_title and (final_page_title not in processed_page_titles): # Extract main text main_text = page_object.content @@ -582,7 +600,7 @@ class WikipediaSearchToolWithFAISS(BaseTool): print(f" -> Added page '{final_page_title}'. Main text length: {len(main_text)} | Tables extracted: {len(table_texts)}") except Exception as e: print(f" !! 
Unexpected error processing candidate '{candidate_query}': {e}") - + if not found_pages_data: print(f"\nCould not find any new, unique, entity-validated Wikipedia pages for query '{query_text}'.") else: print(f"\nFound {len(found_pages_data)} unique, validated page(s) for processing.") return found_pages_data @@ -595,7 +613,7 @@ class WikipediaSearchToolWithFAISS(BaseTool): section_phrases_templates = [] lower_query_terms = set(query.lower().split()) | set(k.lower() for k in keywords) - + section_keywords_map = { "discography": ["discography", "list of studio albums", "studio album titles and years", "albums by year", "album release dates", "official albums", "complete album list", "albums published"], "biography": ["biography", "life story", "career details", "background history"], @@ -610,33 +628,33 @@ class WikipediaSearchToolWithFAISS(BaseTool): for phrase in specific_phrases_list: if phrase in query.lower(): # Check against original query for direct phrase matches section_phrases_templates.extend(specific_phrases_list) # Add all related if one specific is hit - break + break section_phrases_templates = list(dict.fromkeys(section_phrases_templates)) # Deduplicate final_search_queries = set() if main_entities: entity_prefix = main_entities[0] - final_search_queries.add(entity_prefix) + final_search_queries.add(entity_prefix) for part in core_query_parts: final_search_queries.add(f"{entity_prefix} {part}" if entity_prefix.lower() not in part.lower() else part) for phrase_template in section_phrases_templates: final_search_queries.add(f"{entity_prefix} {phrase_template}") if "list of" in phrase_template or "history of" in phrase_template : final_search_queries.add(f"{phrase_template} of {entity_prefix}") - else: + else: final_search_queries.update(core_query_parts) final_search_queries.update(section_phrases_templates) deduplicated_queries = list(dict.fromkeys(sq for sq in final_search_queries if sq and sq.strip())) print(f"Generated {len(deduplicated_queries)} semantic search query variants (list-retrieval focused): {deduplicated_queries}") - + all_results_docs: List[Document] = [] seen_content_hashes: Set[int] = set() k_to_fetch = self.top_k_results * self.semantic_search_candidate_multiplier - + for search_query_variant in deduplicated_queries: try: - results = vector_store.similarity_search_with_score(search_query_variant, k=k_to_fetch) + results = vector_store.similarity_search_with_score(search_query_variant, k=k_to_fetch) print(f" Semantic search variant '{search_query_variant}' (k={k_to_fetch}) -> {len(results)} raw chunk(s) with scores.") for doc, score in results: # Assuming similarity_search_with_score returns (doc, score) content_hash = hash(doc.page_content[:250]) # Slightly more for hash uniqueness @@ -647,23 +665,26 @@ class WikipediaSearchToolWithFAISS(BaseTool): all_results_docs.append(doc) except Exception as e: print(f" Error in semantic search for variant '{search_query_variant}': {e}") - + # Sort all collected unique results by score (FAISS L2 distance is lower is better) all_results_docs.sort(key=lambda x: x.metadata.get('retrieval_score', float('inf'))) print(f"Collected and re-sorted {len(all_results_docs)} unique chunks from all semantic query variants.") - + return all_results_docs[:self.top_k_results] - def _run(self, query: str) -> str: + def _run(self, query: str = None, search_query: str = None, **kwargs) -> str: if not self._nlp or not self._embedding_model or not self._text_splitter: print("ERROR: WikipediaSearchToolWithFAISS components not initialized 
properly.") return "Error: Wikipedia tool components not initialized properly. Please check server logs." + if not query: + query = search_query or kwargs.get('q') or kwargs.get('search_term') + try: print(f"\n--- Running {self.name} for query: '{query}' ---") main_entities, keywords, processed_query = self._extract_entities_and_keywords(query) print(f"Initial NLP Analysis - Main Entities: {main_entities}, Keywords: {keywords}, Processed Query: '{processed_query}'") - + fetched_pages_data = self._smart_wikipedia_search(query, main_entities, keywords, processed_query) if not fetched_pages_data: @@ -681,7 +702,7 @@ class WikipediaSearchToolWithFAISS(BaseTool): continue for i, chunk_text in enumerate(chunks): all_documents.append(Document(page_content=chunk_text, metadata={ - "source_page_title": page_title, + "source_page_title": page_title, "original_query": query, "chunk_index": i # Add chunk index for potential debugging or ordering })) @@ -690,7 +711,7 @@ class WikipediaSearchToolWithFAISS(BaseTool): if not all_documents: return (f"Could not process content into searchable chunks from the fetched Wikipedia pages " f"({', '.join(all_page_titles)}) for query '{query}'.") - + print(f"\nTotal document chunks from all pages: {len(all_documents)}") print("Creating FAISS index from content of all fetched pages...") @@ -709,7 +730,7 @@ class WikipediaSearchToolWithFAISS(BaseTool): if not relevant_docs: return (f"No relevant information found within Wikipedia page(s) '{', '.join(list(dict.fromkeys(all_page_titles)))}' " f"for your query '{query}' using entity-focused semantic search with list retrieval.") - + unique_sources_in_results = list(dict.fromkeys([doc.metadata.get('source_page_title', 'Unknown Source') for doc in relevant_docs])) result_header = (f"Found {len(relevant_docs)} relevant piece(s) of information from Wikipedia page(s) " f"'{', '.join(unique_sources_in_results)}' for your query '{query}':\n") @@ -723,7 +744,7 @@ class WikipediaSearchToolWithFAISS(BaseTool): detail = (f"Result {i+1} (source: '{source_info}', score: {score_info:.4f})\n" f"(Retrieved by: '{variant_info}')\n{doc.page_content}") result_details.append(detail) - + final_result = result_header + nlp_summary + "\n\n---\n\n".join(result_details) print(f"\nReturning {len(relevant_docs)} relevant chunks from {len(set(all_page_titles))} source page(s).") return final_result.strip() @@ -734,632 +755,2005 @@ class WikipediaSearchToolWithFAISS(BaseTool): return f"An unexpected error occurred: {str(e)}" -# Example of creating the tool instance: -# wikipedia_tool_faiss = WikipediaSearchToolWithFAISS() +class EnhancedYoutubeScreenshotQA(BaseTool): + name: str = "enhanced_youtube_screenshot_qa" + description: str = ( + "Downloads a YouTube video, intelligently extracts screenshots, " + "and answers questions using advanced visual QA with semantic analysis. " + "Use this tool for questions about the VIDEO or IMAGES in the video," + "Input should be a dict with keys: 'youtube_url', 'question', and optional parameters. " + #"Optional parameters: 'frame_interval_seconds' (default: 10), 'max_frames' (default: 50), " + #"'use_scene_detection' (default: True), 'parallel_processing' (default: True). " + "Example: {'youtube_url': 'https://youtube.com/watch?v=xyz', 'question': 'What animals are visible?'}" + ) -# To use this new tool in your agent, you would replace the old -# `wikipedia_tool` instance with `wikipedia_tool_faiss` in your `tools` list. 
-# For example: -# tools = [wikipedia_tool_faiss, search_tool] -# Create tool instances -#wikipedia_tool = WikipediaSearchTool() + # Define Pydantic fields for the attributes we need to set + device: Any = Field(default=None, exclude=True) + processor_vqa: Any = Field(default=None, exclude=True) + model_vqa: Any = Field(default=None, exclude=True) -# --- Define Call LLM function --- + class Config: + # Allow arbitrary types (needed for torch.device, model objects) + arbitrary_types_allowed = True + # Allow extra fields to be set + extra = "allow" -# 3. Improved LLM call with memory management + def __init__(self, **kwargs): + super().__init__(**kwargs) -def call_llm_with_memory_management(state: AgentState, llm_model) -> AgentState: # Added llm_model parameter - """Call LLM with memory management, context truncation, and process response.""" - print("Running call_llm with memory management...") + # Initialize directories + cache_dir = '/tmp/youtube_qa_cache/' + video_dir = '/tmp/video/' + frames_dir = '/tmp/video_frames/' - # It's crucial to work with a copy of messages for modification within this step - # The final state["messages"] should reflect the full history + new response. - original_messages = list(state["messages"]) - messages_for_llm_processing = list(state["messages"]) # Use this for truncation logic + # Initialize model and device + self._initialize_model() + + # Create directories + for dir_path in [cache_dir, video_dir, frames_dir]: + os.makedirs(dir_path, exist_ok=True) + + def _get_config(self, key: str, default_value=None, input_data: Dict[str, Any] = None): + """Get configuration value with fallback to defaults""" + defaults = { + 'frame_interval_seconds': 10, + 'max_frames': 50, + 'use_scene_detection': True, + 'resize_frames': True, + 'parallel_processing': True, + 'cache_enabled': True, + 'quality_threshold': 30.0, + 'semantic_similarity_threshold': 0.8 + } - #ipdb.set_trace() + if input_data and key in input_data: + return input_data[key] + return defaults.get(key, default_value) - # --- Context Truncation Logic --- - system_message_content = None - # Check if the first message is a system message and preserve it - if messages_for_llm_processing and isinstance(messages_for_llm_processing[0], SystemMessage): - system_message_content = messages_for_llm_processing[0] - # Process only non-system messages for truncation count - regular_messages = messages_for_llm_processing[1:] - else: - regular_messages = messages_for_llm_processing + def _initialize_model(self): + """Initialize BLIP model for VQA with error handling""" + try: + #self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + self.device = torch.device("cpu") + print(f"Using device: {self.device}") - # Truncate context if too many messages (e.g., keep system + X most recent) - # Max 10 messages total (e.g. 
1 system + 9 others) - max_regular_messages = 9 - if len(regular_messages) > max_regular_messages: - print(f"๐Ÿ”„ Truncating message count: {len(messages_for_llm_processing)} -> ~{max_regular_messages + (1 if system_message_content else 0)} messages") - regular_messages = regular_messages[- (max_regular_messages -1):] # Keep X-1 most recent, to add user input later + self.processor_vqa = BlipProcessor.from_pretrained("Salesforce/blip-vqa-base") + self.model_vqa = BlipForQuestionAnswering.from_pretrained( + "Salesforce/blip-vqa-base" + ).to(self.device) - # Reconstruct messages for LLM call - messages_for_llm = [] - if system_message_content: - messages_for_llm.append(system_message_content) - messages_for_llm.extend(regular_messages) + print("BLIP VQA model loaded successfully") + except Exception as e: + print(f"Error initializing VQA model: {str(e)}") + raise + + def _get_video_hash(self, url: str) -> str: + """Generate hash for video URL for caching""" + return hashlib.md5(url.encode()).hexdigest() + + def _get_cache_path(self, video_hash: str, cache_type: str) -> str: + """Get cache file path""" + cache_dir = '/tmp/youtube_qa_cache/' + return os.path.join(cache_dir, f"{video_hash}_{cache_type}") + + def _load_from_cache(self, cache_path: str, cache_enabled: bool = True) -> Optional[Any]: + """Load data from cache""" + if not cache_enabled or not os.path.exists(cache_path): + return None + try: + with open(cache_path, 'r') as f: + return json.load(f) + except Exception as e: + print(f"Error loading cache: {str(e)}") + return None - # Further truncate based on character count (rough proxy for tokens) - total_chars = sum(len(str(msg.content)) for msg in messages_for_llm) - # Example character limit, adjust based on your model (e.g. 8k chars for ~4k tokens) - char_limit = 8000 - if total_chars > char_limit: - print(f"๐Ÿ“ Context too long ({total_chars} chars > {char_limit}), further truncation needed") - # More aggressive truncation of regular messages - chars_to_remove = total_chars - char_limit - temp_regular_messages = list(regular_messages) # copy - while sum(len(str(m.content)) for m in temp_regular_messages) > char_limit and temp_regular_messages: - if system_message_content and sum(len(str(m.content)) for m in temp_regular_messages) + len(str(system_message_content.content)) <= char_limit : - break # if removing one more makes it too small with system message - print(f"Removing message: {temp_regular_messages[0].type} - {temp_regular_messages[0].content[:50]}...") - temp_regular_messages.pop(0) - - regular_messages = temp_regular_messages - messages_for_llm = [] # Rebuild - if system_message_content: - messages_for_llm.append(system_message_content) - messages_for_llm.extend(regular_messages) - print(f"Context truncated to {sum(len(str(m.content)) for m in messages_for_llm)} chars.") + def _save_to_cache(self, cache_path: str, data: Any, cache_enabled: bool = True): + """Save data to cache""" + if not cache_enabled: + return + try: + with open(cache_path, 'w') as f: + json.dump(data, f) + except Exception as e: + print(f"Error saving cache: {str(e)}") - new_state = state.copy() # Start with a copy of the input state + def download_youtube_video(self, url: str, video_hash: str, cache_enabled: bool = True) -> Optional[str]: + """Enhanced YouTube video download with caching""" + video_dir = '/tmp/video/' + output_filename = f'{video_hash}.mp4' + output_path = os.path.join(video_dir, output_filename) - try: - #if torch.cuda.is_available(): - # torch.cuda.empty_cache() - # print(f"๐Ÿงน 
Pre-LLM CUDA cache cleared. Memory: {torch.cuda.memory_allocated()/1024**2:.1f}MB") + # Check cache + if cache_enabled and os.path.exists(output_path): + print(f"Using cached video: {output_path}") + return output_path - print(f"Invoking LLM with {len(messages_for_llm)} messages.") - # This is where you call your actual LLM - formatted_input = "\n".join([f"[{msg.type.upper()}] {msg.content}" for msg in messages_for_llm]) - print(f"\n\nFormatted input for LLM:\n\n{formatted_input}") + # Clean directory + video_dir = '/tmp/video/' + self._clean_directory(video_dir) - llm_response_object = llm_model.invoke(formatted_input) + try: + ydl_opts = { + 'format': 'bestvideo[height<=720][ext=mp4]+bestaudio[ext=m4a]/best[height<=720][ext=mp4]/best', + 'outtmpl': output_path, + 'quiet': True, + 'merge_output_format': 'mp4', + 'postprocessors': [{ + 'key': 'FFmpegVideoConvertor', + 'preferedformat': 'mp4', + }] + } + + with yt_dlp.YoutubeDL(ydl_opts) as ydl: + ydl.download([url]) + + if os.path.exists(output_path): + print(f"Video downloaded successfully: {output_path}") + return output_path + else: + print("Download completed but file not found") + return None - #ipdb.set_trace() - - # The response_object is typically a BaseMessage subclass (e.g., AIMessage) - # or a string for simpler LLMs. Adapt as needed. - if isinstance(llm_response_object, BaseMessage): - ai_message_response = llm_response_object # It's already a message object - if not ai_message_response.content: # Ensure content is not empty - ai_message_response.content = "" - elif hasattr(llm_response_object, 'content'): # Some models might return a custom object with a content attribute - ai_message_response = AIMessage(content=str(llm_response_object.content) if llm_response_object.content is not None else "") - else: # Assuming it's a string for basic LLMs - ai_message_response = AIMessage(content=str(llm_response_object) if llm_response_object is not None else "") + except Exception as e: + print(f"Error downloading YouTube video: {str(e)}") + return None + + def _clean_directory(self, directory: str): + """Clean directory contents""" + if os.path.exists(directory): + for filename in os.listdir(directory): + file_path = os.path.join(directory, filename) + try: + if os.path.isfile(file_path) or os.path.islink(file_path): + os.unlink(file_path) + elif os.path.isdir(file_path): + shutil.rmtree(file_path) + except Exception as e: + print(f'Failed to delete {file_path}. Reason: {e}') + + def _assess_frame_quality(self, frame: np.ndarray) -> float: + """Assess frame quality using Laplacian variance (blur detection)""" + try: + gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY) + return cv2.Laplacian(gray, cv2.CV_64F).var() + except Exception: + return 0.0 - print(f"LLM Response: {ai_message_response.content[:300]}...") # Print a snippet + def _detect_scene_changes(self, video_path: str, threshold: float = 30.0) -> List[int]: + """Detect scene changes in video""" + scene_frames = [] + try: + cap = cv2.VideoCapture(video_path) + if not cap.isOpened(): + return [] - # Append the LLM's response to the original full list of messages - final_messages = original_messages + [ai_message_response] - new_state["messages"] = final_messages - new_state.pop("done", None) # LLM responded, so not 'done' by default + prev_frame = None + frame_count = 0 - except Exception as e: - print(f"LLM call failed: {e}") - error_message_content = f"LLM call failed with error: {str(e)}. Input consisted of {len(messages_for_llm)} messages." 
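# --- Illustrative sketch (not part of the diff above): the md5-keyed cache pattern ---
# The enhanced YouTube tools key every cached artifact (downloaded media, frame lists,
# transcripts) on an md5 hash of the video URL. The helpers below are a minimal,
# standalone rendering of that pattern for reference only; the function names and the
# module-level CACHE_DIR constant are illustrative and do not exist in tools.py.
import hashlib
import json
import os
from typing import Any, Optional

CACHE_DIR = "/tmp/youtube_qa_cache/"  # same directory the tool creates in __init__

def cache_key(url: str) -> str:
    """Stable cache key for one video URL."""
    return hashlib.md5(url.encode()).hexdigest()

def cached_json(url: str, kind: str) -> Optional[Any]:
    """Return a cached JSON artifact (e.g. 'frames_info.json') if present, else None."""
    path = os.path.join(CACHE_DIR, f"{cache_key(url)}_{kind}")
    if not os.path.exists(path):
        return None
    with open(path, "r") as f:
        return json.load(f)

def store_json(url: str, kind: str, data: Any) -> None:
    """Persist a JSON artifact under the URL's hash so later runs can skip the work."""
    os.makedirs(CACHE_DIR, exist_ok=True)
    with open(os.path.join(CACHE_DIR, f"{cache_key(url)}_{kind}"), "w") as f:
        json.dump(data, f)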
+ while True: + ret, frame = cap.read() + if not ret: + break - if "out of memory" in str(e).lower(): - print("๐Ÿšจ CUDA OOM detected during LLM call! Implementing emergency cleanup...") - error_message_content = f"LLM failed due to Out of Memory: {str(e)}." - try: - #if torch.cuda.is_available(): - # torch.cuda.empty_cache() - gc.collect() - except Exception as cleanup_e: - print(f"Emergency OOM cleanup failed: {cleanup_e}") - - # Append an error message to the original message history - error_ai_message = AIMessage(content=error_message_content) - final_messages_on_error = original_messages + [error_ai_message] - new_state["messages"] = final_messages_on_error - new_state["done"] = True # Mark as done to prevent loops on LLM failure - finally: - try: - pass - #if torch.cuda.is_available(): - # torch.cuda.empty_cache() - # print(f"๐Ÿงน Post-LLM CUDA cache cleared. Memory: {torch.cuda.memory_allocated()/1024**2:.1f}MB") - except Exception: - pass # Avoid error in cleanup hiding the main error + if prev_frame is not None: + # Calculate histogram difference + hist1 = cv2.calcHist([prev_frame], [0, 1, 2], None, [8, 8, 8], [0, 256, 0, 256, 0, 256]) + hist2 = cv2.calcHist([frame], [0, 1, 2], None, [8, 8, 8], [0, 256, 0, 256, 0, 256]) + diff = cv2.compareHist(hist1, hist2, cv2.HISTCMP_CHISQR) - return new_state -import re -import uuid + if diff > threshold: + scene_frames.append(frame_count) -def parse_react_output(state: AgentState) -> AgentState: - print("Running parse_react_output (Action prioritized)...") - messages = state["messages"] - last_message = messages[-1] - new_state = state.copy() + prev_frame = frame.copy() + frame_count += 1 - # Only process AI messages (not system/user) - if not isinstance(last_message, AIMessage): - return new_state + cap.release() + return scene_frames - content = last_message.content + except Exception as e: + print(f"Error in scene detection: {str(e)}") + return [] - # Remove any system prompt/instructions (if present in content) - # Assume that the actual AI output is after the last occurrence of "You are a general AI assistant" or similar system prompt marker - sys_prompt_pattern = r"(You are a general AI assistant.*?)(?=\n\n|$)" - content_wo_sys_prompt = re.sub(sys_prompt_pattern, '', content, flags=re.DOTALL | re.IGNORECASE).strip() + def smart_extract_frames(self, video_path: str, video_hash: str, input_data: Dict[str, Any] = None) -> List[str]: + """Intelligently extract frames with quality filtering and scene detection""" + cache_enabled = self._get_config('cache_enabled', True, input_data) + cache_path = self._get_cache_path(video_hash, "frames_info.json") + cached_info = self._load_from_cache(cache_path, cache_enabled) - # Find the last occurrence of FINAL ANSWER or Action Input - final_answer_match = list(re.finditer(r"FINAL ANSWER:", content_wo_sys_prompt, re.IGNORECASE)) - action_input_match = list(re.finditer(r"Action Input:", content_wo_sys_prompt, re.IGNORECASE)) + if cached_info: + # Verify cached frames still exist + existing_frames = [f for f in cached_info['frame_paths'] if os.path.exists(f)] + if len(existing_frames) == len(cached_info['frame_paths']): + print(f"Using {len(existing_frames)} cached frames") + return existing_frames - # Helper: get the last match position and which it was - last_marker = None - last_pos = -1 - if final_answer_match: - last_fa = final_answer_match[-1] - last_marker = 'FINAL ANSWER' - last_pos = last_fa.start() - if action_input_match: - last_ai = action_input_match[-1] - if last_ai.start() > last_pos: - 
last_marker = 'Action Input' - last_pos = last_ai.start() - - # If neither marker found, mark as done - if not last_marker: - print("No FINAL ANSWER or Action Input found in last AI output.") - new_state["done"] = True - return new_state + # Clean frames directory + frames_dir = '/tmp/video_frames/' + self._clean_directory(frames_dir) - # Get the substring from the last marker to the end - last_section = content_wo_sys_prompt[last_pos:].strip() + try: + cap = cv2.VideoCapture(video_path) + if not cap.isOpened(): + print("Error: Could not open video.") + return [] + + fps = cap.get(cv2.CAP_PROP_FPS) + total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) + frame_interval_seconds = self._get_config('frame_interval_seconds', 10, input_data) + frame_interval = max(1, int(fps * frame_interval_seconds)) + + print(f"Video info: {total_frames} frames, {fps:.2f} fps") + + # Get scene change frames if enabled + scene_frames = set() + use_scene_detection = self._get_config('use_scene_detection', True, input_data) + if use_scene_detection: + scene_frames = set(self._detect_scene_changes(video_path)) + print(f"Detected {len(scene_frames)} scene changes") + + extracted_frames = [] + frame_count = 0 + saved_count = 0 + max_frames = self._get_config('max_frames', 50, input_data) + + while True: + ret, frame = cap.read() + if not ret or saved_count >= max_frames: + break - # 2. If FINAL ANSWER is in the last part, end the process - if last_marker == 'FINAL ANSWER': - # Extract the answer after FINAL ANSWER: - answer = re.search(r"FINAL ANSWER:\s*(.+)", last_section, re.IGNORECASE) - final_answer_text = answer.group(1).strip() if answer else "" - updated_ai_message = AIMessage(content=f"FINAL ANSWER: {final_answer_text}", tool_calls=[]) - new_state["messages"] = messages[:-1] + [updated_ai_message] - new_state["done"] = True - print(f"FINAL ANSWER found at end: '{final_answer_text}'") - return new_state + # Check if we should extract this frame + should_extract = ( + frame_count % frame_interval == 0 or + frame_count in scene_frames + ) + + if should_extract: + # Assess frame quality + quality = self._assess_frame_quality(frame) + quality_threshold = self._get_config('quality_threshold', 30.0, input_data) + + if quality >= quality_threshold: + # Resize frame if enabled + resize_frames = self._get_config('resize_frames', True, input_data) + if resize_frames: + height, width = frame.shape[:2] + if width > 800: + scale = 800 / width + new_width = 800 + new_height = int(height * scale) + frame = cv2.resize(frame, (new_width, new_height)) + + frame_filename = os.path.join( + frames_dir, + f"frame_{frame_count:06d}_q{quality:.1f}.jpg" + ) + + if cv2.imwrite(frame_filename, frame): + extracted_frames.append(frame_filename) + saved_count += 1 + print(f"Extracted frame {saved_count}/{max_frames} " + f"(quality: {quality:.1f})") + + frame_count += 1 + + cap.release() + + # Cache frame information + frame_info = { + 'frame_paths': extracted_frames, + 'extraction_time': time.time(), + 'total_frames_processed': frame_count, + 'frames_extracted': len(extracted_frames) + } + self._save_to_cache(cache_path, frame_info, cache_enabled) + + print(f"Successfully extracted {len(extracted_frames)} high-quality frames") + return extracted_frames - # 3. 
If Action Input is in the last part, launch tool - if last_marker == 'Action Input': - # Try to extract the Action and Action Input for the last occurrence - action_match = list(re.finditer(r"Action:\s*([^\n]+)", last_section)) - action_input_match = list(re.finditer(r"Action Input:\s*([^\n]+)", last_section)) - if action_match and action_input_match: - tool_name = action_match[-1].group(1).strip() - tool_input_raw = action_input_match[-1].group(1).strip() - print(f"ReAct: Found Action: {tool_name}, Input: '{tool_input_raw}'") - # Format tool_args as in your original code (simplified here) - tool_args = {"query": tool_input_raw} - tool_call_id = str(uuid.uuid4()) - parsed_tool_calls = [{"name": tool_name, "args": tool_args, "id": tool_call_id}] - updated_ai_message = AIMessage(content=content, tool_calls=parsed_tool_calls) - new_state["messages"] = messages[:-1] + [updated_ai_message] - new_state.pop("done", None) - print(f"AIMessage updated with tool_calls: {parsed_tool_calls}") - return new_state - else: - print("Action Input found at end, but could not parse Action or Action Input.") - new_state["done"] = True - return new_state + except Exception as e: + print(f"Exception during frame extraction: {e}") + return [] - # Fallback: mark as done - print("No actionable marker found at end of last AI output. Marking as done.") - new_state["done"] = True - return new_state + def _answer_question_on_frame(self, frame_path: str, question: str) -> Tuple[str, float]: + """Answer question on single frame with confidence scoring""" + try: + image = Image.open(frame_path).convert('RGB') + inputs = self.processor_vqa(image, question, return_tensors="pt").to(self.device) + with torch.no_grad(): + outputs = self.model_vqa.generate(**inputs, output_scores=True, return_dict_in_generate=True) + answer = self.processor_vqa.decode(outputs.sequences[0], skip_special_tokens=True) -def download_youtube_video(url, output_dir='/tmp/video/', output_filename='downloaded_video.mp4'): - """Download a YouTube video using yt-dlp""" - # Ensure the output directory exists - os.makedirs(output_dir, exist_ok=True) + # Calculate confidence (simplified - you might want to use actual model confidence) + confidence = 1.0 # Placeholder - BLIP doesn't directly provide confidence + + return answer, confidence - # Delete all files in the output directory - files = glob.glob(os.path.join(output_dir, '*')) - for f in files: - try: - os.remove(f) except Exception as e: - print(f"Error deleting {f}: {str(e)}") + print(f"Error processing frame {frame_path}: {str(e)}") + return "Error processing this frame", 0.0 + + def _process_frames_parallel(self, frame_files: List[str], question: str, input_data: Dict[str, Any] = None) -> List[Tuple[str, str, float]]: + """Process frames in parallel""" + results = [] + parallel_processing = self._get_config('parallel_processing', True, input_data) + + if parallel_processing: + with ThreadPoolExecutor(max_workers=min(4, len(frame_files))) as executor: + future_to_frame = { + executor.submit(self._answer_question_on_frame, frame_path, question): frame_path + for frame_path in frame_files + } + + for future in as_completed(future_to_frame): + frame_path = future_to_frame[future] + try: + answer, confidence = future.result() + results.append((frame_path, answer, confidence)) + print(f"Processed {os.path.basename(frame_path)}: {answer} (conf: {confidence:.2f})") + except Exception as e: + print(f"Error processing {frame_path}: {str(e)}") + results.append((frame_path, "Error", 0.0)) + else: + for 
frame_path in frame_files: + answer, confidence = self._answer_question_on_frame(frame_path, question) + results.append((frame_path, answer, confidence)) + print(f"Processed {os.path.basename(frame_path)}: {answer} (conf: {confidence:.2f})") - # Set output path for yt-dlp - output_path = os.path.join(output_dir, output_filename) + return results - try: - ydl_opts = { - 'format': 'bestvideo[ext=mp4]+bestaudio[ext=m4a]/best[ext=mp4]/best', - 'outtmpl': output_path, - 'quiet': True, - 'merge_output_format': 'mp4', # Ensures merged output is mp4 - 'postprocessors': [{ - 'key': 'FFmpegVideoConvertor', - 'preferedformat': 'mp4', # Recode if needed - }] - } - with yt_dlp.YoutubeDL(ydl_opts) as ydl: - ydl.download([url]) - return output_path - except Exception as e: - print(f"Error downloading YouTube video: {str(e)}") - return None - -def extract_frames(video_path, output_dir, frame_interval_seconds=10): - """Extract frames from a video file at specified intervals""" - # Clean output directory before extracting new frames - if os.path.exists(output_dir): - for filename in os.listdir(output_dir): - file_path = os.path.join(output_dir, filename) + def _cluster_similar_answers(self, answers: List[str], input_data: Dict[str, Any] = None) -> Dict[str, List[str]]: + """Cluster semantically similar answers""" + if len(answers) <= 1: + return {answers[0]: answers} if answers else {} + + try: + # First try with standard TF-IDF settings + vectorizer = TfidfVectorizer( + stop_words='english', + lowercase=True, + min_df=1, # Include words that appear in at least 1 document + max_df=1.0 # Include words that appear in up to 100% of documents + ) + tfidf_matrix = vectorizer.fit_transform(answers) + + # Check if we have any features after TF-IDF + if tfidf_matrix.shape[1] == 0: + raise ValueError("No features after TF-IDF processing") + + # Calculate cosine similarity + similarity_matrix = cosine_similarity(tfidf_matrix) + + # Cluster similar answers + clusters = defaultdict(list) + used = set() + semantic_similarity_threshold = self._get_config('semantic_similarity_threshold', 0.8, input_data) + + for i, answer in enumerate(answers): + if i in used: + continue + + cluster_key = answer + clusters[cluster_key].append(answer) + used.add(i) + + # Find similar answers + for j in range(i + 1, len(answers)): + if j not in used and similarity_matrix[i][j] >= semantic_similarity_threshold: + clusters[cluster_key].append(answers[j]) + used.add(j) + + return dict(clusters) + + except (ValueError, Exception) as e: + print(f"Error in semantic clustering: {str(e)}") + + # Fallback 1: Try without stop words filtering try: - if os.path.isfile(file_path) or os.path.islink(file_path): - os.unlink(file_path) - elif os.path.isdir(file_path): - shutil.rmtree(file_path) - except Exception as e: - print(f'Failed to delete {file_path}. 
Reason: {e}') - else: - os.makedirs(output_dir, exist_ok=True) + print("Attempting clustering without stop word filtering...") + vectorizer_no_stop = TfidfVectorizer( + lowercase=True, + min_df=1, + token_pattern=r'\b\w+\b' # Match any word + ) + tfidf_matrix = vectorizer_no_stop.fit_transform(answers) + + if tfidf_matrix.shape[1] > 0: + similarity_matrix = cosine_similarity(tfidf_matrix) + + clusters = defaultdict(list) + used = set() + semantic_similarity_threshold = self._get_config('semantic_similarity_threshold', 0.8, input_data) + + for i, answer in enumerate(answers): + if i in used: + continue - try: - cap = cv2.VideoCapture(video_path) - if not cap.isOpened(): - print("Error: Could not open video.") - return False - fps = cap.get(cv2.CAP_PROP_FPS) - frame_interval = int(fps * frame_interval_seconds) - count = 0 - saved = 0 - while True: - ret, frame = cap.read() - if not ret: - break - if count % frame_interval == 0: - frame_filename = os.path.join(output_dir, f"frame_{count:06d}.jpg") - cv2.imwrite(frame_filename, frame) - saved += 1 - count += 1 - cap.release() - print(f"Extracted {saved} frames.") - return saved > 0 - except Exception as e: - print(f"Exception during frame extraction: {e}") - return False + cluster_key = answer + clusters[cluster_key].append(answer) + used.add(i) -def answer_question_on_frame(image_path, question): - """Answer a question about a single video frame using BLIP""" - try: - vqa_model_name = "Salesforce/blip-vqa-base" # Not used in the provided graph logic directly - processor_vqa = BlipProcessor.from_pretrained(vqa_model_name) # Not used - model_vqa = BlipForQuestionAnswering.from_pretrained(vqa_model_name).to('cpu') # Not used - device = "cpu" - - image = Image.open(image_path).convert('RGB') - inputs = processor_vqa(image, question, return_tensors="pt").to(device) - out = model_vqa.generate(**inputs) - answer = processor_vqa.decode(out[0], skip_special_tokens=True) - return answer - except Exception as e: - print(f"Error processing frame {image_path}: {str(e)}") - return "Error processing this frame" + for j in range(i + 1, len(answers)): + if j not in used and similarity_matrix[i][j] >= semantic_similarity_threshold: + clusters[cluster_key].append(answers[j]) + used.add(j) -def answer_video_question(frames_dir, question): - """Answer a question about a video by analyzing extracted frames""" - valid_exts = ('.jpg', '.jpeg', '.png') + return dict(clusters) - # Check if directory exists - if not os.path.exists(frames_dir): - return { - "most_common_answer": "No frames found to analyze.", - "all_answers": [], - "answer_counts": Counter() - } + except Exception as e2: + print(f"Fallback clustering also failed: {str(e2)}") - frame_files = [os.path.join(frames_dir, f) for f in os.listdir(frames_dir) - if f.lower().endswith(valid_exts)] + # Fallback 2: Simple string-based clustering + print("Using simple string-based clustering...") + return self._simple_string_cluster(answers) - # Sort frames properly by number - def get_frame_number(filename): - match = re.search(r'(\d+)', os.path.basename(filename)) - return int(match.group(1)) if match else 0 + def _simple_string_cluster(self, answers: List[str]) -> Dict[str, List[str]]: + """Simple string-based clustering fallback""" + clusters = defaultdict(list) - frame_files = sorted(frame_files, key=get_frame_number) + # Normalize answers for comparison + normalized_answers = {} + for answer in answers: + normalized = answer.lower().strip() + normalized_answers[answer] = normalized - if not frame_files: - 
return { - "most_common_answer": "No valid image frames found.", - "all_answers": [], - "answer_counts": Counter() - } + used = set() + + for i, answer in enumerate(answers): + if answer in used: + continue + + cluster_key = answer + clusters[cluster_key].append(answer) + used.add(answer) + + # Find similar answers using simple string similarity + for j, other_answer in enumerate(answers[i+1:], i+1): + if other_answer in used: + continue - answers = [] - for frame_path in frame_files: + # Check for exact match after normalization + if normalized_answers[answer] == normalized_answers[other_answer]: + clusters[cluster_key].append(other_answer) + used.add(other_answer) + # Alternatively, check if one string contains the other + elif (normalized_answers[answer] in normalized_answers[other_answer] or + normalized_answers[other_answer] in normalized_answers[answer]): + clusters[cluster_key].append(other_answer) + used.add(other_answer) + + return dict(clusters) + + def _analyze_temporal_patterns(self, results: List[Tuple[str, str, float]]) -> Dict[str, Any]: + """Analyze temporal patterns in answers""" try: - ans = answer_question_on_frame(frame_path, question) - answers.append(ans) - print(f"Processed frame: {os.path.basename(frame_path)}, Answer: {ans}") + # Sort by frame number + def get_frame_number(frame_path): + match = re.search(r'frame_(\d+)', os.path.basename(frame_path)) + return int(match.group(1)) if match else 0 + + sorted_results = sorted(results, key=lambda x: get_frame_number(x[0])) + + # Analyze answer changes over time + answers_timeline = [result[1] for result in sorted_results] + changes = [] + + for i in range(1, len(answers_timeline)): + if answers_timeline[i] != answers_timeline[i-1]: + changes.append({ + 'frame_index': i, + 'from_answer': answers_timeline[i-1], + 'to_answer': answers_timeline[i] + }) + + return { + 'total_changes': len(changes), + 'change_points': changes, + 'stability_ratio': 1 - (len(changes) / max(1, len(answers_timeline) - 1)), + 'answers_timeline': answers_timeline + } + except Exception as e: - print(f"Error processing frame {frame_path}: {str(e)}") + print(f"Error in temporal analysis: {str(e)}") + return {'error': str(e)} + + def analyze_video_question(self, frame_files: List[str], question: str, input_data: Dict[str, Any] = None) -> Dict[str, Any]: + """Comprehensive video question analysis""" + if not frame_files: + return { + "final_answer": "No frames available for analysis.", + "confidence": 0.0, + "frame_count": 0, + "error": "No valid frames found" + } + + # Process all frames + print(f"Processing {len(frame_files)} frames...") + results = self._process_frames_parallel(frame_files, question, input_data) + + if not results: + return { + "final_answer": "Could not analyze any frames successfully.", + "confidence": 0.0, + "frame_count": 0, + "error": "Frame processing failed" + } + + # Extract answers and confidences + answers = [result[1] for result in results if result[1] != "Error"] + confidences = [result[2] for result in results if result[1] != "Error"] + + # Calculate statistical summary on numeric answers + numeric_answers = [] + for answer in answers: + try: + # Try to convert answer to float + numeric_value = float(answer) + numeric_answers.append(numeric_value) + except (ValueError, TypeError): + # Skip non-numeric answers + pass + + if numeric_answers: + stats = { + "minimum": float(np.min(numeric_answers)), + "maximum": float(np.max(numeric_answers)), + "range": float(np.max(numeric_answers) - np.min(numeric_answers)), + "mean": 
float(np.mean(numeric_answers)), + "median": float(np.median(numeric_answers)), + "count": len(numeric_answers), + "data_type": "answers" + } + elif confidences: + # Fallback to confidence statistics if no numeric answers + stats = { + "minimum": float(np.min(confidences)), + "maximum": float(np.max(confidences)), + "range": float(np.max(confidences) - np.min(confidences)), + "mean": float(np.mean(confidences)), + "median": float(np.median(confidences)), + "count": len(confidences), + "data_type": "confidences" + } + else: + stats = { + "minimum": 0.0, + "maximum": 0.0, + "range": 0.0, + "mean": 0.0, + "median": 0.0, + "count": 0, + "data_type": "none", + "note": "No numeric results available for statistical summary" + } + + + if not answers: + return { + "final_answer": "All frame processing failed.", + "confidence": 0.0, + "frame_count": len(frame_files), + "error": "No successful frame analysis" + } + + # Cluster similar answers + answer_clusters = self._cluster_similar_answers(answers, input_data) + + # Find most common cluster + largest_cluster = max(answer_clusters.items(), key=lambda x: len(x[1])) + most_common_answer = largest_cluster[0] + cluster_size = len(largest_cluster[1]) + + # Calculate weighted confidence + answer_counts = Counter(answers) + total_answers = len(answers) + frequency_confidence = answer_counts[most_common_answer] / total_answers + avg_confidence = np.mean(confidences) if confidences else 0.0 + + final_confidence = (frequency_confidence * 0.7) + (avg_confidence * 0.3) + + # Temporal analysis + temporal_analysis = self._analyze_temporal_patterns(results) - if not answers: return { - "most_common_answer": "Could not analyze any frames successfully.", - "all_answers": [], - "answer_counts": Counter() + "final_answer": most_common_answer, + "confidence": final_confidence, + "frame_count": len(frame_files), + "successful_analyses": len(answers), + "answer_distribution": dict(answer_counts), + "semantic_clusters": {k: len(v) for k, v in answer_clusters.items()}, + "temporal_analysis": temporal_analysis, + "average_model_confidence": avg_confidence, + "frequency_confidence": frequency_confidence, + "statistical_summary": stats } - counted = Counter(answers) - most_common_answer, freq = counted.most_common(1)[0] - return { - "most_common_answer": most_common_answer, - "all_answers": answers, - "answer_counts": counted - } - - -class YoutubeScreenshotQA(BaseTool): - name: str = "youtube_screenshot_qa" - description: str = ( - "Downloads a YouTube video, extracts screenshots at intervals, " - "and answers a question about the video based on the screenshots. " - "Input should be a dict with keys: 'youtube_url' and 'question'." - "Example input: {'youtube_url': 'https://www.youtube.com/watch?v=L1vXCYZAYYM', 'question': 'What is the highest number of bird species on camera simultaneously?'}" - ) - frame_interval_seconds: int = 10 # Can be parameterized if needed + #def _run(self, query: Dict[str, Any]) -> str: + def _run(self, youtube_url, question, **kwargs) -> str: + """Enhanced main execution method""" + #ipdb.set_trace() - def _run(self, input_data: Dict[str, Any]) -> str: - youtube_url = input_data.get("youtube_url") - question = input_data.get("question") + #input_data = query + #youtube_url = input_data.get("youtube_url") + #question = input_data.get("question") + input_data = { + 'youtube_url': youtube_url, + 'question': question + } if not youtube_url or not question: return "Error: Input must include 'youtube_url' and 'question'." 
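# --- Illustrative sketch (not part of the diff above): the 70/30 confidence blend ---
# analyze_video_question combines a frequency-based confidence (share of frames that
# agree with the winning answer) with the mean per-frame model confidence, weighted
# 0.7 / 0.3. The toy numbers below are made up purely to show the arithmetic, and the
# semantic clustering step is simplified here to a plain majority vote.
from collections import Counter
import numpy as np

frame_answers = ["3", "3", "3", "2", "3"]       # hypothetical per-frame VQA answers
frame_confidences = [1.0, 1.0, 1.0, 1.0, 1.0]   # BLIP exposes no native confidence, so 1.0 placeholders

counts = Counter(frame_answers)
winner, winner_votes = counts.most_common(1)[0]            # ("3", 4)
frequency_confidence = winner_votes / len(frame_answers)   # 0.8
avg_confidence = float(np.mean(frame_confidences))         # 1.0

final_confidence = frequency_confidence * 0.7 + avg_confidence * 0.3  # 0.86
print(f"winner={winner}, confidence={final_confidence:.2f}")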
- # Step 1: Download the video - video_dir = '/tmp/video/' - video_filename = 'downloaded_video.mp4' - print(f"Downloading YouTube video from {youtube_url}...") - video_path = download_youtube_video(youtube_url, output_dir=video_dir, output_filename=video_filename) - if not video_path or not os.path.exists(video_path): - return "Error: Failed to download the YouTube video." + try: + # Generate video hash for caching + video_hash = self._get_video_hash(youtube_url) + + # Step 1: Download video + print(f"Downloading YouTube video from {youtube_url}...") + cache_enabled = self._get_config('cache_enabled', True, input_data) + video_path = self.download_youtube_video(youtube_url, video_hash, cache_enabled) + if not video_path or not os.path.exists(video_path): + return "Error: Failed to download the YouTube video." + + # Step 2: Smart frame extraction + print(f"Extracting frames with smart selection...") + frame_files = self.smart_extract_frames(video_path, video_hash, input_data) + if not frame_files: + return "Error: Failed to extract frames from the video." + + # Step 3: Comprehensive analysis + print(f"Analyzing {len(frame_files)} frames for question: '{question}'") + analysis_result = self.analyze_video_question(frame_files, question, input_data) + + if analysis_result.get("error"): + return f"Error: {analysis_result['error']}" + + # Format comprehensive result - Fixed the reference to stats + result = f""" +๐Ÿ“Š **ANALYSIS SUMMARY**: +โ€ข Confidence Score: {analysis_result['confidence']:.2%} +โ€ข Frames Analyzed: {analysis_result['successful_analyses']}/{analysis_result['frame_count']} +โ€ข Answer Consistency: {analysis_result['temporal_analysis'].get('stability_ratio', 0):.2%} + +๐Ÿ“ˆ **ANSWER DISTRIBUTION**: +{chr(10).join([f"โ€ข {answer}: {count} frames" for answer, count in analysis_result['answer_distribution'].items()])} + +๐Ÿ” **SEMANTIC CLUSTERS**: +{chr(10).join([f"โ€ข '{cluster}': {count} similar answers" for cluster, count in analysis_result['semantic_clusters'].items()])} + +โฑ๏ธ **TEMPORAL ANALYSIS**: +โ€ข Answer Changes: {analysis_result['temporal_analysis'].get('total_changes', 0)} +โ€ข Stability: {analysis_result['temporal_analysis'].get('stability_ratio', 0):.2%} + +๐Ÿ“Š **STATISTICAL SUMMARY**: +โ€ข Minimum: {analysis_result['statistical_summary']['minimum']:.2f} +โ€ข Maximum: {analysis_result['statistical_summary']['maximum']:.2f} +โ€ข Mean: {analysis_result['statistical_summary']['mean']:.2f} +โ€ข Median: {analysis_result['statistical_summary']['median']:.2f} +โ€ข Range: {analysis_result['statistical_summary']['range']:.2f} + +๐ŸŽฏ **CONFIDENCE BREAKDOWN**: +โ€ข Frequency-based: {analysis_result['frequency_confidence']:.2%} +โ€ข Model-based: {analysis_result['average_model_confidence']:.2%} +โ€ข Combined: {analysis_result['confidence']:.2%} + """.strip() + + return result - # Step 2: Extract frames - frames_dir = '/tmp/video_frames/' - print(f"Extracting frames from {video_path} every {self.frame_interval_seconds} seconds...") - success = extract_frames(video_path, frames_dir, frame_interval_seconds=self.frame_interval_seconds) - if not success: - return "Error: Failed to extract frames from the video." - - # Step 3: Analyze frames and answer question - print(f"Answering question about the video frames...") - answer_result = answer_video_question(frames_dir, question) - if not answer_result or not answer_result.get("most_common_answer"): - return "Error: Could not analyze video frames to answer the question." 
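# --- Illustrative sketch (not part of the diff above): smoke-testing the enhanced tool ---
# A hedged usage example of how EnhancedYoutubeScreenshotQA might be exercised directly,
# outside the agent graph; the URL and question are taken from the old tool's docstring
# example. Note that instantiation loads the BLIP VQA model, so this is a heavyweight
# call; inside the agent the tool is invoked through LangChain's tool-calling machinery.
if __name__ == "__main__":
    qa_tool = EnhancedYoutubeScreenshotQA()
    report = qa_tool._run(
        youtube_url="https://www.youtube.com/watch?v=L1vXCYZAYYM",
        question="What is the highest number of bird species on camera simultaneously?",
    )
    print(report)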
- - # Format the result - most_common = answer_result["most_common_answer"] - all_answers = answer_result["all_answers"] - counts = answer_result["answer_counts"] - - result = ( - f"Most common answer: {most_common}\n" - f"All answers: {all_answers}\n" - f"Answer counts: {dict(counts)}" - ) - return result + except Exception as e: + return f"Error during video analysis: {str(e)}" -def tools_condition_with_logging(state: AgentState): - """ - Custom tools condition function that checks if the last message contains tool calls - in the Thought/Action/Action Input format and logs the transition decision. - - Args: - state (AgentState): The current state containing messages - - Returns: - str: "tools" if tool calls are present, "__end__" otherwise - """ - - import re - # Ensure we have messages in the state - if not state.get("messages") or len(state["messages"]) == 0: - print("โŒ No messages found in state, ending conversation") - return "__end__" - - # Get the last message - last_message = state["messages"][-1] - - # Get message content - content = "" - if hasattr(last_message, 'content'): - content = str(last_message.content) - elif isinstance(last_message, dict) and 'content' in last_message: - content = str(last_message['content']) - else: - print("โŒ No content found in last message, ending conversation") - return "__end__" - - print(f"๐Ÿ” Analyzing message content: {content[:200]}...") - - # Check for Thought/Action/Action Input format - has_tool_calls = False - - # Pattern to match the format: - # Thought: - # Action: - # Action Input: - thought_action_pattern = re.compile( - r'Thought:\s*(.*?)\n\s*Action:\s*(.*?)\n\s*Action Input:\s*(.*?)(?:\n|$)', - re.DOTALL | re.IGNORECASE - ) - - # Also check for just Action/Action Input without Thought - action_only_pattern = re.compile( - r'Action:\s*(.*?)\n\s*Action Input:\s*(.*?)(?:\n|$)', - re.DOTALL | re.IGNORECASE +# Initialize the enhanced tool +def create_enhanced_youtube_qa_tool(**kwargs): + """Factory function to create the enhanced tool with custom parameters""" + return EnhancedYoutubeScreenshotQA(**kwargs) +# Example of creating the tool instance: +# wikipedia_tool_faiss = WikipediaSearchToolWithFAISS() + +# To use this new tool in your agent, you would replace the old +# `wikipedia_tool` instance with `wikipedia_tool_faiss` in your `tools` list. +# For example: +# tools = [wikipedia_tool_faiss, search_tool] +# Create tool instances +#wikipedia_tool = WikipediaSearchTool() + +# --- Define Call LLM function --- + +# 3. Improved LLM call with memory management + + +class YouTubeTranscriptExtractor(BaseTool): + name: str = "youtube_transcript_extractor" + description: str = ( + "Downloads a YouTube video and extracts the complete audio transcript using speech recognition with speaker identification. " + "Use this tool when you need the AUDIO or DIALOGUE or sound from a YouTube video with speaker tags," + "Input should be a dict with keys: 'youtube_url' and optional parameters. " + "Optional parameters: 'language' (default: 'en-US'), 'chunk_length_ms' (default: 30000), " + "'silence_thresh' (default: -40), 'use_enhanced_model' (default: True), 'audio_quality' (default: 'best'), " + "'enable_speaker_id' (default: True), 'max_speakers' (default: 5), 'speaker_min_duration' (default: 2.0). 
" + "Example: {'youtube_url': 'https://youtube.com/watch?v=xyz', 'language': 'en-US', 'enable_speaker_id': True}" ) - - # Look for the complete format first - match = thought_action_pattern.search(content) - if not match: - # Try the action-only format - match = action_only_pattern.search(content) - if match: - thought = "No thought provided" - action = match.group(1).strip() - action_input = match.group(2).strip() - else: - action = None - action_input = None - thought = None - else: - thought = match.group(1).strip() - action = match.group(2).strip() - action_input = match.group(3).strip() - - if match and action: - has_tool_calls = True - print(f"๐Ÿ”ง Found tool call format:") - print(f" Thought: {thought}") - print(f" Action: {action}") - print(f" Action Input: {action_input}") - - # Map common tool names to your actual tools - tool_mappings = { - 'wikipedia_semantic_search': 'wikipedia_tool', - 'wikipedia': 'wikipedia_tool', - 'search': 'search_tool', - 'duckduckgo_search': 'search_tool', - 'web_search': 'search_tool', - 'youtube_screenshot_qa_tool': 'youtube_tool', - 'youtube': 'youtube_tool', - } - - # Normalize the action name - normalized_action = action.lower().strip() - - # Store the parsed tool call information in the state for the tools node to use - if 'parsed_tool_calls' not in state: - state['parsed_tool_calls'] = [] - - tool_call_info = { - 'thought': thought, - 'action': action, - 'action_input': action_input, - 'normalized_action': normalized_action, - 'tool_mapping': tool_mappings.get(normalized_action, normalized_action) + + # Define Pydantic fields for the attributes we need to set + recognizer: Any = Field(default=None, exclude=True) + + class Config: + # Allow arbitrary types + arbitrary_types_allowed = True + # Allow extra fields to be set + extra = "allow" + + def __init__(self, **kwargs): + super().__init__(**kwargs) + + # Initialize directories + cache_dir = '/tmp/youtube_transcript_cache/' + audio_dir = '/tmp/audio/' + chunks_dir = '/tmp/audio_chunks/' + + # Initialize speech recognizer + self.recognizer = sr.Recognizer() + + # Create directories + for dir_path in [cache_dir, audio_dir, chunks_dir]: + os.makedirs(dir_path, exist_ok=True) + + def _get_config(self, key: str, default_value=None, input_data: Dict[str, Any] = None): + """Get configuration value with fallback to defaults""" + defaults = { + 'language': 'en-US', + 'chunk_length_ms': 30000, # 30 seconds + 'silence_thresh': -40, # dB + 'use_enhanced_model': True, + 'audio_quality': 'best', + 'cache_enabled': True, + 'parallel_processing': True, + 'overlap_ms': 1000, # 1 second overlap between chunks + 'min_silence_len': 500, # minimum silence length to split on + 'energy_threshold': 4000, # recognizer energy threshold + 'pause_threshold': 0.8, # recognizer pause threshold + 'enable_speaker_id': True, # enable speaker identification + 'max_speakers': 5, # maximum number of speakers to identify + 'speaker_min_duration': 2.0, # minimum duration (seconds) for speaker segment + 'speaker_confidence_threshold': 0.6, # confidence threshold for speaker assignment + 'voice_activity_threshold': 0.01 # threshold for voice activity detection } - - state['parsed_tool_calls'].append(tool_call_info) - print(f"๐Ÿš€ Added tool call to state: {tool_call_info}") - - # Don't execute tools here - let call_tool handle execution - # Just store the parsed information for call_tool to use - - # Also check for standalone tool mentions (fallback) - if not has_tool_calls: - # Check for tool names mentioned in content - 
tool_keywords = [ - 'wikipedia_semantic_search', 'wikipedia', 'search', 'duckduckgo', - 'youtube_screenshot_qa_tool', 'youtube', 'web search' - ] - - content_lower = content.lower() - for keyword in tool_keywords: - if keyword in content_lower: - print(f"๐Ÿ”ง Found tool keyword '{keyword}' in content (fallback detection)") - has_tool_calls = True - break - - if has_tool_calls: - print("๐Ÿ”ง Tool calls detected, transitioning to tools...") - return "tools" - else: - print("โœ… No tool calls found, ending conversation") - return "__end__" + if input_data and key in input_data: + return input_data[key] + return defaults.get(key, default_value) + + def _get_video_hash(self, url: str) -> str: + """Generate hash for video URL for caching""" + return hashlib.md5(url.encode()).hexdigest() + + def _get_cache_path(self, video_hash: str, cache_type: str) -> str: + """Get cache file path""" + cache_dir = '/tmp/youtube_transcript_cache/' + return os.path.join(cache_dir, f"{video_hash}_{cache_type}") + + def _load_from_cache(self, cache_path: str, cache_enabled: bool = True) -> Optional[Any]: + """Load data from cache""" + if not cache_enabled or not os.path.exists(cache_path): + return None + try: + with open(cache_path, 'r', encoding='utf-8') as f: + return json.load(f) + except Exception as e: + print(f"Error loading cache: {str(e)}") + return None + + def _save_to_cache(self, cache_path: str, data: Any, cache_enabled: bool = True): + """Save data to cache""" + if not cache_enabled: + return + try: + with open(cache_path, 'w', encoding='utf-8') as f: + json.dump(data, f, ensure_ascii=False, indent=2) + except Exception as e: + print(f"Error saving cache: {str(e)}") + + def download_youtube_audio(self, url: str, video_hash: str, input_data: Dict[str, Any] = None) -> Optional[str]: + """Download YouTube video as audio file""" + audio_dir = '/tmp/audio/' + audio_quality = self._get_config('audio_quality', 'best', input_data) + output_filename = f'{video_hash}.wav' + output_path = os.path.join(audio_dir, output_filename) + + # Check cache + cache_enabled = self._get_config('cache_enabled', True, input_data) + if cache_enabled and os.path.exists(output_path): + print(f"Using cached audio: {output_path}") + return output_path + + # Clean directory + self._clean_directory(audio_dir) + + try: + # First download as mp4/webm + temp_video_path = os.path.join(audio_dir, f'{video_hash}_temp.%(ext)s') + + ydl_opts = { + 'format': 'bestaudio/best' if audio_quality == 'best' else 'worstaudio/worst', + 'outtmpl': temp_video_path, + 'quiet': True, + 'extractaudio': True, + 'audioformat': 'wav', + 'audioquality': '192K' if audio_quality == 'best' else '64K', + } + + with yt_dlp.YoutubeDL(ydl_opts) as ydl: + ydl.download([url]) + + # Find the downloaded file + temp_files = glob.glob(os.path.join(audio_dir, f'{video_hash}_temp.*')) + if not temp_files: + print("No temporary audio file found") + return None + + temp_file = temp_files[0] + + # Convert to WAV if not already + if not temp_file.endswith('.wav'): + try: + audio = AudioSegment.from_file(temp_file) + audio.export(output_path, format="wav") + os.remove(temp_file) # Clean up temp file + except Exception as e: + print(f"Error converting audio: {str(e)}") + # Try to rename if it's already the right format + if os.path.exists(temp_file): + os.rename(temp_file, output_path) + else: + os.rename(temp_file, output_path) + + if os.path.exists(output_path): + print(f"Audio extracted successfully: {output_path}") + return output_path + else: + print("Audio extraction 
completed but file not found")
+                return None
+
+        except Exception as e:
+            print(f"Error downloading YouTube audio: {str(e)}")
+            return None
+
+    def _clean_directory(self, directory: str):
+        """Clean directory contents"""
+        if os.path.exists(directory):
+            for filename in os.listdir(directory):
+                file_path = os.path.join(directory, filename)
+                try:
+                    if os.path.isfile(file_path) or os.path.islink(file_path):
+                        os.unlink(file_path)
+                    elif os.path.isdir(file_path):
+                        shutil.rmtree(file_path)
+                except Exception as e:
+                    print(f'Failed to delete {file_path}. Reason: {e}')
+
+    def _extract_voice_features(self, audio_path: str) -> Optional[np.ndarray]:
+        """Extract voice features for speaker identification using librosa"""
+        try:
+            # Load audio with librosa
+            y, sr = librosa.load(audio_path, sr=None)
+
+            # Extract MFCC features (commonly used for speaker identification)
+            mfccs = librosa.feature.mfcc(y=y, sr=sr, n_mfcc=13)
+
+            # Extract additional features
+            spectral_centroids = librosa.feature.spectral_centroid(y=y, sr=sr)
+            spectral_rolloff = librosa.feature.spectral_rolloff(y=y, sr=sr)
+            zero_crossing_rate = librosa.feature.zero_crossing_rate(y)
+
+            # Combine features and take mean across time.
+            # np.hstack is used here (rather than np.concatenate) because the last
+            # three means are scalars, which np.concatenate cannot join with the
+            # 1-D MFCC mean vector.
+            features = np.hstack([
+                np.mean(mfccs, axis=1),
+                np.mean(spectral_centroids),
+                np.mean(spectral_rolloff),
+                np.mean(zero_crossing_rate)
+            ])
+
+            return features
+
+        except Exception as e:
+            print(f"Error extracting voice features from {audio_path}: {str(e)}")
+            return None
+
+    def _detect_voice_activity(self, audio_path: str, input_data: Dict[str, Any] = None) -> List[Tuple[float, float]]:
+        """Detect voice activity in audio chunk"""
+        try:
+            y, sr = librosa.load(audio_path, sr=None)
+
+            # Simple voice activity detection based on energy
+            frame_length = int(0.025 * sr)  # 25ms frames
+            hop_length = int(0.010 * sr)  # 10ms hop
+
+            # Calculate short-time energy
+            energy = []
+            for i in range(0, len(y) - frame_length, hop_length):
+                frame = y[i:i + frame_length]
+                energy.append(np.sum(frame ** 2))
+
+            energy = np.array(energy)
+            threshold = self._get_config('voice_activity_threshold', 0.01, input_data)
+
+            # Find voice segments
+            voice_frames = energy > (np.max(energy) * threshold)
+
+            # Convert frame indices to time segments
+            voice_segments = []
+            in_voice = False
+            start_time = 0
+
+            for i, is_voice in enumerate(voice_frames):
+                time_sec = i * hop_length / sr
+                if is_voice and not in_voice:
+                    start_time = time_sec
+                    in_voice = True
+                elif not is_voice and in_voice:
+                    voice_segments.append((start_time, time_sec))
+                    in_voice = False
+
+            # Close last segment if needed
+            if in_voice:
+                voice_segments.append((start_time, len(y) / sr))
+
+            return voice_segments
+
+        except Exception as e:
+            print(f"Error in voice activity detection: {str(e)}")
+            return [(0, librosa.get_duration(filename=audio_path))]
+
+    def _split_audio_intelligent(self, audio_path: str, input_data: Dict[str, Any] = None) -> List[Dict[str, Any]]:
+        """Split audio into chunks intelligently based on silence and voice activity"""
+        chunks_dir = '/tmp/audio_chunks/'
+        self._clean_directory(chunks_dir)
+
+        try:
+            # Load audio
+            audio = AudioSegment.from_wav(audio_path)
+
+            # Get configuration
+            chunk_length_ms = self._get_config('chunk_length_ms', 30000, input_data)
+            silence_thresh = self._get_config('silence_thresh', -40, input_data)
+            min_silence_len = self._get_config('min_silence_len', 500, input_data)
+            overlap_ms = self._get_config('overlap_ms', 1000, input_data)
+
+            # First try to split on silence
+            chunks = split_on_silence(
+                audio,
+ min_silence_len=min_silence_len, + silence_thresh=silence_thresh, + keep_silence=True + ) + + # If no silence-based splits or chunks too large, split by time + if not chunks or any(len(chunk) > chunk_length_ms * 2 for chunk in chunks): + print("Using time-based splitting...") + chunks = [] + for i in range(0, len(audio), chunk_length_ms - overlap_ms): + chunk = audio[i:i + chunk_length_ms] + if len(chunk) > 1000: # Only add chunks longer than 1 second + chunks.append(chunk) + + # Save chunks and create metadata + chunk_data = [] + for i, chunk in enumerate(chunks): + if len(chunk) < 1000: # Skip very short chunks + continue + + chunk_filename = os.path.join(chunks_dir, f"chunk_{i:04d}.wav") + chunk.export(chunk_filename, format="wav") + + # Calculate timing information + start_time = sum(len(chunks[j]) for j in range(i)) / 1000.0 # in seconds + duration = len(chunk) / 1000.0 # in seconds + + chunk_info = { + 'filename': chunk_filename, + 'index': i, + 'start_time': start_time, + 'duration': duration, + 'end_time': start_time + duration + } + + chunk_data.append(chunk_info) + + print(f"Split audio into {len(chunk_data)} chunks") + return chunk_data + + except Exception as e: + print(f"Error splitting audio: {str(e)}") + # Fallback: return original file + return [{ + 'filename': audio_path, + 'index': 0, + 'start_time': 0, + 'duration': len(AudioSegment.from_wav(audio_path)) / 1000.0, + 'end_time': len(AudioSegment.from_wav(audio_path)) / 1000.0 + }] + + def _transcribe_audio_chunk(self, chunk_info: Dict[str, Any], input_data: Dict[str, Any] = None) -> Dict[str, Any]: + """Transcribe a single audio chunk""" + chunk_path = chunk_info['filename'] + try: + language = self._get_config('language', 'en-US', input_data) + + # Configure recognizer + self.recognizer.energy_threshold = self._get_config('energy_threshold', 4000, input_data) + self.recognizer.pause_threshold = self._get_config('pause_threshold', 0.8, input_data) + + with sr.AudioFile(chunk_path) as source: + # Adjust for ambient noise + self.recognizer.adjust_for_ambient_noise(source, duration=0.5) + audio_data = self.recognizer.record(source) + + # Try Google Speech Recognition first (most accurate) + try: + text = self.recognizer.recognize_google(audio_data, language=language) + result = { + 'text': text, + 'confidence': 1.0, # Google doesn't provide confidence + 'method': 'google', + 'chunk': os.path.basename(chunk_path), + 'start_time': chunk_info['start_time'], + 'end_time': chunk_info['end_time'], + 'duration': chunk_info['duration'], + 'index': chunk_info['index'] + } + + # Extract voice features if speaker ID is enabled + if self._get_config('enable_speaker_id', True, input_data): + features = self._extract_voice_features(chunk_path) + result['voice_features'] = features.tolist() if features is not None else None + + return result + + except sr.UnknownValueError: + # Try alternative recognition methods + try: + # Try with alternative language detection + text = self.recognizer.recognize_google(audio_data) + result = { + 'text': text, + 'confidence': 0.8, # Lower confidence for language mismatch + 'method': 'google_auto', + 'chunk': os.path.basename(chunk_path), + 'start_time': chunk_info['start_time'], + 'end_time': chunk_info['end_time'], + 'duration': chunk_info['duration'], + 'index': chunk_info['index'] + } + + if self._get_config('enable_speaker_id', True, input_data): + features = self._extract_voice_features(chunk_path) + result['voice_features'] = features.tolist() if features is not None else None + + return result + 
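+                # If the language-agnostic retry above also fails, mark the chunk as
+                # [INAUDIBLE]; _post_process_transcript skips these placeholder texts.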
+ except sr.UnknownValueError: + return { + 'text': '[INAUDIBLE]', + 'confidence': 0.0, + 'method': 'failed', + 'chunk': os.path.basename(chunk_path), + 'start_time': chunk_info['start_time'], + 'end_time': chunk_info['end_time'], + 'duration': chunk_info['duration'], + 'index': chunk_info['index'], + 'voice_features': None + } + except sr.RequestError as e: + print(f"Google Speech Recognition error: {e}") + return { + 'text': '[RECOGNITION_ERROR]', + 'confidence': 0.0, + 'method': 'error', + 'chunk': os.path.basename(chunk_path), + 'start_time': chunk_info['start_time'], + 'end_time': chunk_info['end_time'], + 'duration': chunk_info['duration'], + 'index': chunk_info['index'], + 'error': str(e), + 'voice_features': None + } + + except Exception as e: + print(f"Error transcribing chunk {chunk_path}: {str(e)}") + return { + 'text': '[ERROR]', + 'confidence': 0.0, + 'method': 'error', + 'chunk': os.path.basename(chunk_path), + 'start_time': chunk_info.get('start_time', 0), + 'end_time': chunk_info.get('end_time', 0), + 'duration': chunk_info.get('duration', 0), + 'index': chunk_info.get('index', 0), + 'error': str(e), + 'voice_features': None + } + + def _identify_speakers(self, transcript_results: List[Dict[str, Any]], input_data: Dict[str, Any] = None) -> List[Dict[str, Any]]: + """Identify speakers using voice features clustering""" + enable_speaker_id = self._get_config('enable_speaker_id', True, input_data) + if not enable_speaker_id: + # Add default speaker tags + for result in transcript_results: + result['speaker_id'] = 'SPEAKER_1' + result['speaker_confidence'] = 1.0 + return transcript_results + + try: + # Filter results with valid voice features and text + valid_results = [] + features_list = [] + + for result in transcript_results: + if (result.get('voice_features') is not None and + result['text'] not in ['[INAUDIBLE]', '[RECOGNITION_ERROR]', '[ERROR]', '[PROCESSING_ERROR]']): + valid_results.append(result) + features_list.append(result['voice_features']) + + if len(features_list) < 2: + # Not enough data for clustering + for result in transcript_results: + result['speaker_id'] = 'SPEAKER_1' + result['speaker_confidence'] = 1.0 + return transcript_results + + # Normalize features + features_array = np.array(features_list) + scaler = StandardScaler() + normalized_features = scaler.fit_transform(features_array) + + # Determine optimal number of speakers + max_speakers = min(self._get_config('max_speakers', 5, input_data), len(features_list)) + + # Use elbow method to find optimal clusters (simplified) + best_k = 1 + if len(features_list) > 1: + best_score = float('inf') + for k in range(1, min(max_speakers + 1, len(features_list) + 1)): + try: + kmeans = KMeans(n_clusters=k, random_state=42, n_init=10) + labels = kmeans.fit_predict(normalized_features) + if k > 1: + score = kmeans.inertia_ + if score < best_score: + best_score = score + best_k = k + except: + continue + + # Don't use too many clusters for short audio + if len(features_list) < 10: + best_k = min(best_k, 2) + + # Perform final clustering + kmeans = KMeans(n_clusters=best_k, random_state=42, n_init=10) + speaker_labels = kmeans.fit_predict(normalized_features) + + # Calculate speaker assignment confidence + distances = kmeans.transform(normalized_features) + confidences = [] + for i, label in enumerate(speaker_labels): + # Confidence based on distance to assigned cluster vs. 
nearest other cluster + dist_to_assigned = distances[i][label] + other_distances = np.delete(distances[i], label) + if len(other_distances) > 0: + dist_to_nearest_other = np.min(other_distances) + confidence = max(0.1, min(1.0, dist_to_nearest_other / (dist_to_assigned + 1e-6))) + else: + confidence = 1.0 + confidences.append(confidence) + + # Assign speaker IDs back to results + valid_idx = 0 + speaker_duration = {} # Track duration per speaker + + for result in transcript_results: + if (result.get('voice_features') is not None and + result['text'] not in ['[INAUDIBLE]', '[RECOGNITION_ERROR]', '[ERROR]', '[PROCESSING_ERROR]']): + + speaker_label = speaker_labels[valid_idx] + confidence = confidences[valid_idx] + + # Filter by confidence threshold + conf_threshold = self._get_config('speaker_confidence_threshold', 0.6, input_data) + if confidence < conf_threshold: + speaker_id = 'SPEAKER_UNKNOWN' + else: + speaker_id = f'SPEAKER_{speaker_label + 1}' + + result['speaker_id'] = speaker_id + result['speaker_confidence'] = confidence + + # Track speaker duration + if speaker_id in speaker_duration: + speaker_duration[speaker_id] += result['duration'] + else: + speaker_duration[speaker_id] = result['duration'] + + valid_idx += 1 + else: + # Handle invalid results + result['speaker_id'] = 'SPEAKER_UNKNOWN' + result['speaker_confidence'] = 0.0 + + # Filter out speakers with insufficient duration + min_duration = self._get_config('speaker_min_duration', 2.0, input_data) + speakers_to_merge = [s for s, d in speaker_duration.items() if d < min_duration and s != 'SPEAKER_UNKNOWN'] + + # Merge low-duration speakers into SPEAKER_UNKNOWN + for result in transcript_results: + if result['speaker_id'] in speakers_to_merge: + result['speaker_id'] = 'SPEAKER_UNKNOWN' + result['speaker_confidence'] = 0.3 + + print(f"Identified {best_k} speakers based on voice characteristics") + return transcript_results + + except Exception as e: + print(f"Error in speaker identification: {str(e)}") + # Fallback: assign all to single speaker + for result in transcript_results: + result['speaker_id'] = 'SPEAKER_1' + result['speaker_confidence'] = 1.0 + return transcript_results + + def _transcribe_chunks_parallel(self, chunk_data: List[Dict[str, Any]], input_data: Dict[str, Any] = None) -> List[Dict[str, Any]]: + """Transcribe audio chunks in parallel""" + results = [] + parallel_processing = self._get_config('parallel_processing', True, input_data) + + if parallel_processing: + # Use fewer workers for speech recognition to avoid API limits + max_workers = min(3, len(chunk_data)) + with ThreadPoolExecutor(max_workers=max_workers) as executor: + future_to_chunk = { + executor.submit(self._transcribe_audio_chunk, chunk_info, input_data): chunk_info + for chunk_info in chunk_data + } + + for future in as_completed(future_to_chunk): + chunk_info = future_to_chunk[future] + try: + result = future.result() + results.append(result) + print(f"Transcribed {result['chunk']}: {result['text'][:50]}..." 
if len(result['text']) > 50 else f"Transcribed {result['chunk']}: {result['text']}") + except Exception as e: + print(f"Error processing {chunk_info['filename']}: {str(e)}") + results.append({ + 'text': '[PROCESSING_ERROR]', + 'confidence': 0.0, + 'method': 'error', + 'chunk': os.path.basename(chunk_info['filename']), + 'start_time': chunk_info.get('start_time', 0), + 'end_time': chunk_info.get('end_time', 0), + 'duration': chunk_info.get('duration', 0), + 'index': chunk_info.get('index', 0), + 'error': str(e), + 'voice_features': None + }) + else: + for chunk_info in chunk_data: + result = self._transcribe_audio_chunk(chunk_info, input_data) + results.append(result) + print(f"Transcribed {result['chunk']}: {result['text'][:50]}..." if len(result['text']) > 50 else f"Transcribed {result['chunk']}: {result['text']}") + + # Sort results by chunk index to maintain order + results.sort(key=lambda x: x['index']) + return results + + def _post_process_transcript(self, transcript_results: List[Dict[str, Any]], input_data: Dict[str, Any] = None) -> Dict[str, Any]: + """Post-process and analyze transcript results with speaker information""" + enable_speaker_id = self._get_config('enable_speaker_id', True, input_data) + + # Identify speakers if enabled + if enable_speaker_id: + transcript_results = self._identify_speakers(transcript_results, input_data) + + # Combine text with speaker tags + full_text_parts = [] + speaker_tagged_text = [] + successful_chunks = 0 + total_confidence = 0.0 + method_counts = {} + speaker_stats = {} + + current_speaker = None + current_speaker_text = [] + + for result in transcript_results: + text = result['text'] + speaker = result.get('speaker_id', 'SPEAKER_1') + start_time = result.get('start_time', 0) + + if text not in ['[INAUDIBLE]', '[RECOGNITION_ERROR]', '[ERROR]', '[PROCESSING_ERROR]']: + full_text_parts.append(text) + successful_chunks += 1 + total_confidence += result['confidence'] + + # Handle speaker transitions + if enable_speaker_id: + if current_speaker != speaker: + # Save previous speaker's text + if current_speaker and current_speaker_text: + combined_text = ' '.join(current_speaker_text) + speaker_tagged_text.append(f"[{current_speaker}]: {combined_text}") + + # Start new speaker + current_speaker = speaker + current_speaker_text = [text] + else: + # Continue with same speaker + current_speaker_text.append(text) + else: + speaker_tagged_text.append(text) + + # Update speaker statistics + if speaker in speaker_stats: + speaker_stats[speaker]['duration'] += result.get('duration', 0) + speaker_stats[speaker]['word_count'] += len(text.split()) + speaker_stats[speaker]['segments'] += 1 + else: + speaker_stats[speaker] = { + 'duration': result.get('duration', 0), + 'word_count': len(text.split()), + 'segments': 1, + 'confidence': result.get('speaker_confidence', 1.0) + } + + method = result['method'] + method_counts[method] = method_counts.get(method, 0) + 1 + + # Add final speaker text + if enable_speaker_id and current_speaker and current_speaker_text: + combined_text = ' '.join(current_speaker_text) + speaker_tagged_text.append(f"[{current_speaker}]: {combined_text}") + + # Combine texts + combined_text = ' '.join(full_text_parts) + speaker_formatted_text = combined_text + + # Calculate statistics + word_count = len(combined_text.split()) if combined_text else 0 + char_count = len(combined_text) + avg_confidence = total_confidence / max(1, successful_chunks) + success_rate = successful_chunks / len(transcript_results) if transcript_results else 0 + + # 
Estimate speaking duration (rough approximation: 150 words per minute) + estimated_duration_minutes = word_count / 150 if word_count > 0 else 0 + + return { + 'full_transcript': combined_text, + 'speaker_tagged_transcript': speaker_formatted_text, + 'word_count': word_count, + 'character_count': char_count, + 'chunk_count': len(transcript_results), + 'successful_chunks': successful_chunks, + 'success_rate': success_rate, + 'average_confidence': avg_confidence, + 'method_distribution': method_counts, + 'estimated_duration_minutes': estimated_duration_minutes, + 'speaker_identification_enabled': enable_speaker_id, + 'speaker_statistics': speaker_stats, + 'total_speakers': len([s for s in speaker_stats.keys() if s != 'SPEAKER_UNKNOWN']), + 'detailed_results': transcript_results + } + + def extract_transcript(self, audio_path: str, video_hash: str, input_data: Dict[str, Any] = None) -> Dict[str, Any]: + """Extract complete transcript from audio file""" + cache_enabled = self._get_config('cache_enabled', True, input_data) + enable_speaker_id = self._get_config('enable_speaker_id', True, input_data) + cache_suffix = "transcript_with_speakers.json" if enable_speaker_id else "transcript.json" + cache_path = self._get_cache_path(video_hash, cache_suffix) + + # Check cache + cached_transcript = self._load_from_cache(cache_path, cache_enabled) + if cached_transcript: + print("Using cached transcript") + return cached_transcript + + try: + # Step 1: Split audio into manageable chunks + print("Splitting audio into chunks...") + chunk_data = self._split_audio_intelligent(audio_path, input_data) + + if not chunk_data: + return { + 'error': 'Failed to split audio into chunks', + 'full_transcript': '', + 'speaker_tagged_transcript': '', + 'success_rate': 0.0 + } + + # Step 2: Transcribe all chunks + print(f"Transcribing {len(chunk_data)} audio chunks...") + transcript_results = self._transcribe_chunks_parallel(chunk_data, input_data) + + # Step 3: Post-process and combine results + print("Post-processing transcript and identifying speakers...") + final_result = self._post_process_transcript(transcript_results, input_data) + + # Add timestamp + final_result['extraction_timestamp'] = time.time() + final_result['extraction_date'] = time.strftime('%Y-%m-%d %H:%M:%S') + + # Cache results + self._save_to_cache(cache_path, final_result, cache_enabled) + + return final_result + + except Exception as e: + print(f"Error during transcript extraction: {str(e)}") + return { + 'error': str(e), + 'full_transcript': '', + 'speaker_tagged_transcript': '', + 'success_rate': 0.0 + } + + def _run(self, youtube_url: str, **kwargs) -> str: + """Main execution method""" + input_data = { + 'youtube_url': youtube_url, + **kwargs + } + + if not youtube_url: + return "Error: youtube_url is required." + + try: + # Generate video hash for caching + video_hash = self._get_video_hash(youtube_url) + + # Step 1: Download audio + print(f"Downloading YouTube audio from {youtube_url}...") + audio_path = self.download_youtube_audio(youtube_url, video_hash, input_data) + if not audio_path or not os.path.exists(audio_path): + return "Error: Failed to download the YouTube audio." 
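+            # Note: download_youtube_audio caches the extracted WAV at
+            # /tmp/audio/<md5(youtube_url)>.wav, so repeated runs on the same video
+            # reuse the file as long as 'cache_enabled' is left at its default of True.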
+ + # Step 2: Extract transcript + print("Extracting audio transcript...") + transcript_result = self.extract_transcript(audio_path, video_hash, input_data) + + if transcript_result.get("error"): + return f"Error: {transcript_result['error']}" + + # Choose the appropriate transcript + main_transcript = transcript_result.get('full_transcript') + + #ipdb.set_trace() + print(f"Transcript extracted: {main_transcript[:50]}..." if len(main_transcript) > 50 else f"Transcript extracted: {main_transcript}") + + return "TRANSCRIPT: " + main_transcript + + except Exception as e: + return f"Error during transcript extraction: {str(e)}" + + +# Factory function to create the tool +def create_youtube_transcript_tool(**kwargs): + """Factory function to create the transcript extraction tool with custom parameters""" + return YouTubeTranscriptExtractor(**kwargs) + + + +# --- Model Configuration --- +def create_llm_pipeline(): + #model_id = "meta-llama/Llama-2-13b-chat-hf" + #model_id = "meta-llama/Llama-3.3-70B-Instruct" + #model_id = "mistralai/Mistral-Small-24B-Base-2501" + model_id = "mistralai/Mistral-7B-Instruct-v0.3" + #model_id = "Meta-Llama/Llama-2-7b-chat-hf" + #model_id = "NousResearch/Nous-Hermes-2-Mistral-7B-DPO" + #model_id = "TheBloke/Mistral-7B-Instruct-v0.1-GGUF" + #model_id = "mistralai/Mistral-7B-Instruct-v0.2" + #model_id = "Qwen/Qwen2-7B-Instruct" + #model_id = "GSAI-ML/LLaDA-8B-Instruct" + return pipeline( + "text-generation", + model=model_id, + device_map="auto", + torch_dtype=torch.float16, + max_new_tokens=1024, + temperature=0.3, + top_k=50, + top_p=0.95 + ) + + +nlp = None # Set to None if not using spaCy, so the regex fallback is used in extract_entities + +# --- Agent State Definition --- +class AgentState(TypedDict): + messages: Annotated[List[AnyMessage], lambda x, y: x + y] + done: bool = False # Default value of False + question: str + task_id: str + input_file: Optional[bytes] + file_type: Optional[str] + context: List[Document] # Using LangChain's Document class + file_path: Optional[str] + youtube_url: Optional[str] + answer: Optional[str] + frame_answers: Optional[list] + + + +# --- Define Call LLM function --- + +# 3. 
Improved LLM call with memory management + + +def call_llm_with_memory_management(state: AgentState, llm_model) -> AgentState: + """Enhanced LLM call with better prompt engineering and hallucination prevention.""" + print("Running call_llm with memory management...") + + #ipdb.set_trace() + + original_messages = messages_for_llm = state["messages"] + + # Context management - be more aggressive about truncation + system_message_content = None + if messages_for_llm and isinstance(messages_for_llm[0], SystemMessage): + system_message_content = messages_for_llm[0] + regular_messages = messages_for_llm[1:] + else: + regular_messages = messages_for_llm + + # Keep only the most recent messages (more aggressive) + max_regular_messages = 6 # Reduced from 9 + if len(regular_messages) > max_regular_messages: + print(f"๐Ÿ”„ Truncating to {max_regular_messages} recent messages") + regular_messages = regular_messages[-max_regular_messages:] + + # Reconstruct for LLM + messages_for_llm = [] + if system_message_content: + messages_for_llm.append(system_message_content) + messages_for_llm.extend(regular_messages) + + # Character limit check + total_chars = sum(len(str(msg.content)) for msg in messages_for_llm) + char_limit = 20000 + + if total_chars > char_limit: + print(f"๐Ÿ“ Context too long ({total_chars} chars) - further truncation") + while regular_messages and sum(len(str(m.content)) for m in regular_messages) > char_limit - (len(str(system_message_content.content)) if system_message_content else 0): + regular_messages.pop(0) + + messages_for_llm = [] + if system_message_content: + messages_for_llm.append(system_message_content) + messages_for_llm.extend(regular_messages) + + new_state = state.copy() + + try: + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + print(f"๐Ÿค– Calling LLM with {len(messages_for_llm)} messages") + + # Convert to simple string format that the model can understand + if len(messages_for_llm) == 2 and isinstance(messages_for_llm[0], SystemMessage) and isinstance(messages_for_llm[1], HumanMessage): + # Initial query - use simple format + system_content = messages_for_llm[0].content + human_content = messages_for_llm[1].content + + formatted_input = f"{system_content}\n\nHuman: {human_content}\n\nAssistant:" + else: + # Ongoing conversation - build context + formatted_input = "" + + # Add system message if present + if system_message_content: + formatted_input += f"{system_message_content.content}\n\n" + + # Add conversation messages + for msg in regular_messages: + if isinstance(msg, HumanMessage): + formatted_input += f"Human: {msg.content}\n\n" + elif isinstance(msg, AIMessage): + formatted_input += f"Assistant: {msg.content}\n\n" + elif isinstance(msg, ToolMessage): + formatted_input += f"Tool Result: {msg.content}\n\n" + + # Add explicit instruction for immediate final answer if we have recent tool results + if any(isinstance(msg, ToolMessage) for msg in regular_messages[-2:]): + formatted_input += "Based on the tool results above, provide your FINAL ANSWER now.\n\n" + formatted_input += "REMINDER ON ANSWER FORMAT: \n" + formatted_input += "- Numbers: no commas, no units unless specified\n" + formatted_input += "- Strings: no articles, no abbreviations, digits in plain text\n" + formatted_input += "- Lists: comma-separated following above rules\n" + formatted_input += "- Be extremely brief and concise" + + formatted_input += "Assistant:" + + print(f"Input preview: {formatted_input[:300]}...") + + llm_response_object = llm_model.invoke(formatted_input) + + # 
Process response and clean up hallucinated content + if isinstance(llm_response_object, BaseMessage): + raw_content = llm_response_object.content + elif hasattr(llm_response_object, 'content'): + raw_content = str(llm_response_object.content) + else: + raw_content = str(llm_response_object) + + # Clean up the response to prevent hallucinated follow-up questions + cleaned_content = clean_llm_response(raw_content) + ai_message_response = AIMessage(content=cleaned_content) + + print(f"๐Ÿ” LLM Response preview: {cleaned_content[:200]}...") + + final_messages = original_messages + [ai_message_response] + new_state["messages"] = final_messages + new_state.pop("done", None) + + except Exception as e: + print(f"โŒ LLM call failed: {e}") + error_message = AIMessage(content=f"Error: LLM call failed - {str(e)}") + new_state["messages"] = original_messages + [error_message] + new_state["done"] = True + + finally: + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + return new_state + + +def clean_llm_response(response_text: str) -> str: + """ + Clean LLM response to prevent hallucinated follow-up questions and conversations. + Specifically handles ReAct format: Thought: -> Action: -> Action Input: + """ + if not response_text: + return response_text + + print(f"Initial response: {response_text[:200]}...") + + # --- START MODIFICATION --- + # Isolate the text generated by the assistant in the last turn. + # This prevents parsing examples or instructions from the preamble. + assistant_marker = "Assistant:" + last_marker_idx = response_text.rfind(assistant_marker) + + text_to_process = response_text # Default to full text if marker not found + if last_marker_idx != -1: + # If "Assistant:" is found, process only the text after the last occurrence. + text_to_process = response_text[last_marker_idx + len(assistant_marker):].strip() + print(f"โ„น๏ธ Parsing content after last 'Assistant:': {text_to_process[:200]}...") + else: + # If "Assistant:" is not found, process the whole input. + # This might occur if the input is already just the assistant's direct response + # or if the prompt structure is different. + print(f"โ„น๏ธ No 'Assistant:' marker found. 
Processing entire input as is.") + # --- END MODIFICATION --- + + # Now, all subsequent operations use 'text_to_process' + + # Try to find a complete ReAct pattern in the assistant's actual output + react_pattern = r'Thought:\s*(.*?)\s*Action:\s*([^\n\r]+)\s*Action Input:\s*(.*?)(?=\s*(?:Thought:|Action:|FINAL ANSWER:|$))' + # Apply search to 'text_to_process' + react_match = re.search(react_pattern, text_to_process, re.DOTALL | re.IGNORECASE) + + if react_match: + thought_text = react_match.group(1).strip() + action_name = react_match.group(2).strip() + action_input = react_match.group(3).strip() + + # Clean up the action input - remove any trailing content that looks like instructions + action_input_clean = re.sub(r'\s*(When you have|FINAL ANSWER|ANSWER FORMAT|IMPORTANT:).*$', '', action_input, flags=re.DOTALL | re.IGNORECASE) + action_input_clean = action_input_clean.strip() + + react_sequence = f"Thought: {thought_text}\nAction: {action_name}\nAction Input: {action_input_clean}" + + print(f"๐Ÿ”ง Found ReAct pattern - Action: {action_name}, Input: {action_input_clean[:100]}...") + + # Check if there's a FINAL ANSWER after the action input (this would be hallucination) + # Check in the remaining part of 'text_to_process' + remaining_text_in_process = text_to_process[react_match.end():] + final_answer_after = re.search(r'FINAL ANSWER:', remaining_text_in_process, re.IGNORECASE) + if final_answer_after: + print(f"๐Ÿšซ Removed hallucinated FINAL ANSWER after tool call") + + return react_sequence + + # If no ReAct pattern in 'text_to_process', check for standalone FINAL ANSWER + # This variable will hold the text being processed for FINAL ANSWER and then for fallback. + current_text_for_processing = text_to_process + final_answer_match = re.search(r"FINAL ANSWER:\s*(.+?)(?=\n|$)", current_text_for_processing, re.IGNORECASE) + if final_answer_match: + answer_content = final_answer_match.group(1).strip() + + template_phrases = [ + '[concise answer only]', + '[concise answer - number/word/list only]', + '[brief answer]', + '[your answer here]', + 'concise answer only', + 'brief answer', + 'your answer here' + ] + + if any(phrase.lower() in answer_content.lower() for phrase in template_phrases): + print(f"๐Ÿšซ Ignoring template FINAL ANSWER: {answer_content}") + # Remove the template FINAL ANSWER and continue cleaning on the remainder of 'current_text_for_processing' + current_text_for_processing = current_text_for_processing[:final_answer_match.start()].strip() + # Fall through to the general cleanup section below + else: + # Keep everything from the start of 'current_text_for_processing' up to and including the real FINAL ANSWER line only + cleaned_output = current_text_for_processing[:final_answer_match.end()] + + # Check if there's additional content after FINAL ANSWER in 'current_text_for_processing' + remaining_after_final_answer = current_text_for_processing[final_answer_match.end():].strip() + if remaining_after_final_answer: + print(f"๐Ÿšซ Removed content after FINAL ANSWER: {remaining_after_final_answer[:100]}...") + + return cleaned_output.strip() + + # If no ReAct, or FINAL ANSWER was a template or not found, apply fallback cleaning to 'current_text_for_processing' + lines = current_text_for_processing.split('\n') + cleaned_lines = [] + + for i, line in enumerate(lines): + # Stop if we see the model repeating system instructions + if re.search(r'\[SYSTEM\]|\[HUMAN\]|\[ASSISTANT\]|\[TOOL\]', line, re.IGNORECASE): + print(f"๐Ÿšซ Stopped at repeated system format: {line}") + break + + # 
Stop if we see the model generating format instructions + if re.search(r'CRITICAL INSTRUCTIONS|FORMAT for tool use|ANSWER FORMAT', line, re.IGNORECASE): + print(f"๐Ÿšซ Stopped at repeated instructions: {line}") + break + + # Stop if we see the model role-playing as a human asking questions + if re.search(r'(what are|what is|how many|can you tell me)', line, re.IGNORECASE) and not line.strip().startswith(('Thought:', 'Action:', 'Action Input:')): + # Make sure this isn't part of a legitimate thought process + if i > 0 and not any(keyword in lines[i-1] for keyword in ['Thought:', 'Action:', 'need to']): + print(f"๐Ÿšซ Stopped at hallucinated question: {line}") + break + + cleaned_lines.append(line) + + cleaned = '\n'.join(cleaned_lines).strip() + print(f"Final cleaned response (fallback): {cleaned[:200]}...") + + return cleaned + +def parse_react_output(state: AgentState) -> AgentState: + """ + Enhanced parsing with better FINAL ANSWER detection and flow control. + """ + print("Running parse_react_output...") + + #ipdb.set_trace() + + messages = state.get("messages", []) + if not messages: + print("No messages in state.") + new_state = state.copy() + new_state["done"] = True + return new_state + + + # DEBUG + print(f"parse_react_output: Entry message count: {len(messages)}") + if messages and hasattr(messages[-1], 'tool_calls'): + print(f"parse_react_output: Number of tool calls in last AIMessage: {len(messages[-1].tool_calls)}") + + last_message = messages[-1] + new_state = state.copy() + + if not isinstance(last_message, AIMessage): + print("Last message is not an AIMessage instance.") + return new_state + + content = last_message.content + if not isinstance(content, str): + content = str(content) + + # Look for FINAL ANSWER first - this should take absolute priority + # Use a more precise regex to capture just the answer line + final_answer_match = re.search(r"FINAL ANSWER:\s*([^\n\r]+)", content, re.IGNORECASE) + if final_answer_match: + final_answer_text = final_answer_match.group(1).strip() + + # Check if this is template text (not a real answer) + template_phrases = [ + '[concise answer only]', + '[concise answer - number/word/list only]', + '[brief answer]', + '[your answer here]', + 'concise answer only', + 'brief answer', + 'your answer here' + ] + + # If it's template text, don't treat it as a final answer + if any(phrase.lower() in final_answer_text.lower() for phrase in template_phrases): + print(f"๐Ÿšซ Ignoring template FINAL ANSWER: '{final_answer_text}'") + # Continue processing as if no final answer was found + else: + print(f"๐ŸŽฏ FINAL ANSWER found: '{final_answer_text}' - ENDING") + + # Store the answer in state for easy access + new_state["answer"] = final_answer_text + + # Clean up the message content to just show the final answer + clean_content = f"FINAL ANSWER: {final_answer_text}" + updated_ai_message = AIMessage(content=clean_content, tool_calls=[]) + new_state["messages"] = messages[:-1] + [updated_ai_message] + new_state["done"] = True + return new_state + + # If no FINAL ANSWER, look for tool calls + action_match = re.search(r"Action:\s*([^\n]+)", content, re.IGNORECASE) + action_input_match = re.search(r"Action Input:\s*(.+)", content, re.IGNORECASE | re.DOTALL) + + if action_match and action_input_match: + tool_name = action_match.group(1).strip() + tool_input_raw = action_input_match.group(1).strip() + + if tool_name.lower() == "none": + print("Action is 'None' - treating as regular response") + updated_ai_message = AIMessage(content=content, tool_calls=[]) + 
new_state["messages"] = messages[:-1] + [updated_ai_message] + new_state.pop("done", None) + return new_state + + print(f"๐Ÿ”ง Tool call: {tool_name} with input: {tool_input_raw[:100]}...") + + # Parse tool arguments + tool_args = {} + try: + trimmed_input = tool_input_raw.strip() + if (trimmed_input.startswith('{') and trimmed_input.endswith('}')) or \ + (trimmed_input.startswith('[') and trimmed_input.endswith(']')): + tool_args = ast.literal_eval(trimmed_input) + if not isinstance(tool_args, dict): + tool_args = {"query": tool_input_raw} + else: + tool_args = {"query": tool_input_raw} + except (ValueError, SyntaxError): + tool_args = {"query": tool_input_raw} + + tool_call_id = str(uuid.uuid4()) + parsed_tool_calls = [{"name": tool_name, "args": tool_args, "id": tool_call_id}] + + updated_ai_message = AIMessage(content=content, tool_calls=parsed_tool_calls) + new_state["messages"] = messages[:-1] + [updated_ai_message] + new_state.pop("done", None) + return new_state + + # No tool call or final answer - treat as regular response + print("No actionable content found - continuing conversation") + updated_ai_message = AIMessage(content=content, tool_calls=[]) + new_state["messages"] = messages[:-1] + [updated_ai_message] + new_state.pop("done", None) + + + # DEBUG + print(f"parse_react_output: Exit message count: {len(new_state['messages'])}") + + + return new_state + +# 4. Improved call_tool_with_memory_management to prevent duplicate processing +def call_tool_with_memory_management(state: AgentState) -> AgentState: + """Process tool calls with memory management, avoiding duplicates.""" + print("Running call_tool with memory management...") + + # Clear CUDA cache before processing + try: + import torch + if torch.cuda.is_available(): + torch.cuda.empty_cache() + print(f"๐Ÿงน Cleared CUDA cache. Memory: {torch.cuda.memory_allocated()/1024**2:.1f}MB") + except ImportError: + pass + except Exception as e: + print(f"Error clearing CUDA cache: {e}") + + # Check if we have parsed tool calls from the condition function + if 'parsed_tool_calls' in state and state.get('parsed_tool_calls'): + print("Executing parsed tool calls...") + return execute_parsed_tool_calls(state) -# 2. Improved call_tool with memory management -def call_tool_with_memory_management(state: AgentState) -> AgentState: - """Process tool calls with memory management.""" - print("Running call_tool with memory management...") - - # Clear CUDA cache before processing - try: - import torch - #if torch.cuda.is_available(): - # torch.cuda.empty_cache() - # print(f"๐Ÿงน Cleared CUDA cache. 
Memory: {torch.cuda.memory_allocated()/1024**2:.1f}MB") - except: - pass - - # Check if we have parsed tool calls from the condition function - if 'parsed_tool_calls' in state and state['parsed_tool_calls']: - return execute_parsed_tool_calls(state) - # Fallback to original OpenAI-style tool calls handling - messages = state["messages"] + messages = state.get("messages", []) + if not messages: + print("No messages found in state.") + return state + last_message = messages[-1] - + if not hasattr(last_message, "tool_calls") or not last_message.tool_calls: print("No tool calls found in last message") return state - + + # Avoid processing the same tool calls multiple times + if hasattr(last_message, '_processed_tool_calls'): + print("Tool calls already processed, skipping...") + return state + # Copy the messages to avoid mutating the original list new_messages = list(messages) - - print(f"Processing {len(last_message.tool_calls)} tool calls") - - for i, tool_call in enumerate(last_message.tool_calls): - print(f"Processing tool call {i+1}: {tool_call['name'] if isinstance(tool_call, dict) else tool_call.name}") - + print(f"Processing {len(last_message.tool_calls)} tool calls from last message") + + # Get file_path from state to pass to tools + file_path_to_pass = state.get('file_path') + + for i, tool_call_item in enumerate(last_message.tool_calls): # Handle both dict and object-style tool calls - if isinstance(tool_call, dict): - tool_name = tool_call.get("name", "") - args = tool_call.get("args", {}) - tool_call_id = tool_call.get("id", str(uuid.uuid4())) + if isinstance(tool_call_item, dict): + tool_name = tool_call_item.get("name", "") + raw_args = tool_call_item.get("args") + tool_call_id = tool_call_item.get("id", str(uuid.uuid4())) + elif hasattr(tool_call_item, "name") and hasattr(tool_call_item, "id"): + tool_name = getattr(tool_call_item, "name", "") + raw_args = getattr(tool_call_item, "args", None) + tool_call_id = getattr(tool_call_item, "id", str(uuid.uuid4())) else: - tool_name = getattr(tool_call, "name", "") - args = getattr(tool_call, "args", {}) - tool_call_id = getattr(tool_call, "id", str(uuid.uuid4())) - + print(f"Skipping malformed tool call item: {tool_call_item}") + continue + + print(f"Processing tool call {i+1}: {tool_name}") + # Find the matching tool selected_tool = None - for tool in tools: - if tool.name.lower() == tool_name.lower(): - selected_tool = tool + for tool_instance in tools: + if tool_instance.name.lower() == tool_name.lower(): + selected_tool = tool_instance break - + if not selected_tool: tool_result = f"Error: Tool '{tool_name}' not found. 
Available tools: {', '.join(t.name for t in tools)}" print(f"Tool not found: {tool_name}") else: try: - # Extract query - if isinstance(args, dict) and "query" in args: - query = args["query"] - else: - query = str(args) if args else "" - - print(f"Executing {tool_name} with query: {query[:100]}...") - tool_result = selected_tool.run(query) - + # Prepare the arguments for the tool.run() method + tool_run_input_dict = {} + + if isinstance(raw_args, dict): + tool_run_input_dict = raw_args.copy() + elif raw_args is not None: + tool_run_input_dict["query"] = str(raw_args) + + # Add file_path to the dictionary for the tool + tool_run_input_dict['file_path'] = file_path_to_pass + + print(f"Executing {tool_name} with args: {tool_run_input_dict} ...") + tool_result = selected_tool.run(tool_run_input_dict) + + + #ipdb.set_trace() + + # Aggressive truncation to prevent memory issues + if not isinstance(tool_result, str): + tool_result = str(tool_result) + max_length = 3000 if "wikipedia" in tool_name.lower() else 2000 if len(tool_result) > max_length: - tool_result = tool_result[:max_length] + f"... [Result truncated from {len(tool_result)} to {max_length} chars to prevent memory issues]" + original_length = len(tool_result) + tool_result = tool_result[:max_length] + f"... [Result truncated from {original_length} to {max_length} chars to prevent memory issues]" print(f"๐Ÿ“„ Truncated result to {max_length} characters") - + print(f"Tool result length: {len(tool_result)} characters") - + except Exception as e: tool_result = f"Error executing tool '{tool_name}': {str(e)}" print(f"Tool execution error: {e}") - - # Create tool message + + # Create tool message - ONLY ONE PER TOOL CALL tool_message = ToolMessage( content=tool_result, name=tool_name, @@ -1367,155 +2761,156 @@ def call_tool_with_memory_management(state: AgentState) -> AgentState: ) new_messages.append(tool_message) print(f"Added tool message for {tool_name}") - + + # Mark the last message as processed to prevent re-processing + if hasattr(last_message, '__dict__'): + last_message._processed_tool_calls = True + # Update the state new_state = state.copy() new_state["messages"] = new_messages - + # Clear CUDA cache after processing try: import torch - #if torch.cuda.is_available(): - # torch.cuda.empty_cache() - except: + if torch.cuda.is_available(): + torch.cuda.empty_cache() + print(f"๐Ÿงน Cleared CUDA cache post-processing. Memory: {torch.cuda.memory_allocated()/1024**2:.1f}MB") + except ImportError: pass - - return new_state + except Exception as e: + print(f"Error clearing CUDA cache post-processing: {e}") + return new_state +# 3. Enhanced execute_parsed_tool_calls to prevent duplicate observations def execute_parsed_tool_calls(state: AgentState): """ Execute tool calls that were parsed from the Thought/Action/Action Input format. This is called by call_tool when parsed_tool_calls are present in state. 
- - Args: - state (AgentState): The current state containing parsed tool calls - - Returns: - AgentState: Updated state with tool results """ - - # Use the same tools list that's available globally - # Map tool names to the actual tool instances + + # Tool name mappings tool_name_mappings = { 'wikipedia_semantic_search': 'wikipedia_tool', - 'wikipedia': 'wikipedia_tool', - 'search': 'enhanced_search', # Updated mapping - 'duckduckgo_search': 'enhanced_search', # Updated mapping - 'web_search': 'enhanced_search', # Updated mapping - 'enhanced_search': 'enhanced_search', # Direct mapping + 'wikipedia': 'wikipedia_tool', + 'search': 'enhanced_search', + 'duckduckgo_search': 'enhanced_search', + 'web_search': 'enhanced_search', + 'enhanced_search': 'enhanced_search', 'youtube_screenshot_qa_tool': 'youtube_tool', 'youtube': 'youtube_tool', + 'youtube_transcript_extractor': 'youtube_transcript_extractor', + 'youtube_audio_tool': 'youtube_transcript_extractor' } - - - # Create a lookup by tool names for your existing tools list + + # Create a lookup by tool names tools_by_name = {} for tool in tools: tools_by_name[tool.name.lower()] = tool - # Also map by class name for flexibility - class_name = tool.__class__.__name__.lower() - if 'wikipedia' in class_name: - tools_by_name['wikipedia_tool'] = tool - elif 'search' in class_name or 'duck' in class_name: - tools_by_name['search_tool'] = tool - elif 'youtube' in class_name: - tools_by_name['youtube_tool'] = tool - + # Copy messages to avoid mutation during iteration new_messages = list(state["messages"]) - + + # Process each tool call ONCE for tool_call in state['parsed_tool_calls']: action = tool_call['action'] action_input = tool_call['action_input'] - thought = tool_call['thought'] normalized_action = tool_call['normalized_action'] - + print(f"๐Ÿš€ Executing tool: {action} with input: {action_input}") - + # Find the tool instance tool_instance = None - - # Try direct name match first if normalized_action in tools_by_name: tool_instance = tools_by_name[normalized_action] - # Try mapped name elif normalized_action in tool_name_mappings: mapped_name = tool_name_mappings[normalized_action] if mapped_name in tools_by_name: tool_instance = tools_by_name[mapped_name] - + if tool_instance: try: - result = tool_instance.run(action_input) + # Pass file_path if the tool expects it + if hasattr(tool_instance, 'run'): + if 'file_path' in tool_instance.run.__code__.co_varnames: + result = tool_instance.run(action_input, file_path=state.get('file_path')) + else: + result = tool_instance.run(action_input) + else: + result = str(tool_instance) + + # Truncate long results if len(result) > 6000: result = result[:6000] + "... 
[Result truncated due to length]" - - # Create observation message in the format your agent expects - from langchain_core.messages import AIMessage - observation = f"Observation: {result}" - observation_message = AIMessage(content=observation) - new_messages.append(observation_message) - + + # Create a SINGLE observation message + from langchain_core.messages import ToolMessage + tool_message = ToolMessage( + content=f"Observation: {result}", + name=action, + tool_call_id=str(uuid.uuid4()) + ) + new_messages.append(tool_message) print(f"โœ… Tool '{action}' executed successfully") - + except Exception as e: print(f"โŒ Error executing tool '{action}': {e}") - from langchain_core.messages import AIMessage - error_msg = f"Observation: Error executing '{action}': {str(e)}" - error_message = AIMessage(content=error_msg) + from langchain_core.messages import ToolMessage + error_message = ToolMessage( + content=f"Observation: Error executing '{action}': {str(e)}", + name=action, + tool_call_id=str(uuid.uuid4()) + ) new_messages.append(error_message) else: print(f"โŒ Tool '{action}' not found in available tools") available_tool_names = list(tools_by_name.keys()) - from langchain_core.messages import AIMessage - error_msg = f"Observation: Tool '{action}' not found. Available tools: {', '.join(available_tool_names)}" - error_message = AIMessage(content=error_msg) + from langchain_core.messages import ToolMessage + error_message = ToolMessage( + content=f"Observation: Tool '{action}' not found. Available tools: {', '.join(available_tool_names)}", + name=action, + tool_call_id=str(uuid.uuid4()) + ) new_messages.append(error_message) - + # Update state with new messages and clear parsed tool calls - state["messages"] = new_messages - state['parsed_tool_calls'] = [] - - return state + new_state = state.copy() + new_state["messages"] = new_messages + new_state['parsed_tool_calls'] = [] # Clear to prevent re-execution + + return new_state + # 1. 
Add loop detection to your AgentState + def should_continue(state: AgentState) -> str: - """Determine if the agent should continue or end.""" - print("Running should_continue....") - messages = state["messages"] - - #ipdb.set_trace() + """Enhanced continuation logic with better limits.""" + print("Running should_continue...") - # Check if we're done + # Check done flag first if state.get("done", False): + print("โœ… Done flag is True - ending") return "end" - - # Prevent infinite loops - limit tool calls - tool_call_count = sum(1 for msg in messages if hasattr(msg, 'tool_calls') and msg.tool_calls) - if tool_call_count >= 3: # Max 3 tool calls per conversation - print(f"โš ๏ธ Stopping: Too many tool calls ({tool_call_count})") - return "end" - - # Check for repeated tool calls with same query - recent_tool_calls = [] - for msg in messages[-6:]: # Check last 6 messages - if hasattr(msg, 'tool_calls') and msg.tool_calls: - for tool_call in msg.tool_calls: - if isinstance(tool_call, dict): - recent_tool_calls.append((tool_call.get('name'), str(tool_call.get('args', {})))) - - if len(recent_tool_calls) >= 2 and recent_tool_calls[-1] == recent_tool_calls[-2]: - print("โš ๏ธ Stopping: Repeated tool call detected") - return "end" - - # Check message count to prevent runaway conversations - if len(messages) > 15: - print(f"โš ๏ธ Stopping: Too many messages ({len(messages)})") - return "end" - + + messages = state["messages"] + + # More aggressive message limit + #if len(messages) > 20: # Reduced from 15 + # print(f"โš ๏ธ Message limit reached ({len(messages)}/20) - forcing end") + # return "end" + + # Check for repeated patterns (stuck in loop) + if len(messages) >= 6: + recent_contents = [str(msg.content)[:100] for msg in messages[-6:] if hasattr(msg, 'content')] + if len(set(recent_contents)) < 3: # Too much repetition + print("๐Ÿ”„ Detected repetitive pattern - ending") + return "end" + + print(f"๐Ÿ“Š Continuing... ({len(messages)} messages so far)") return "continue" + def route_after_parse_react(state: AgentState) -> str: """Determines the next step after parsing LLM output, prioritizing end state.""" if state.get("done", False): # Check if parse_react_output decided we are done @@ -1530,16 +2925,6 @@ def route_after_parse_react(state: AgentState) -> str: return "call_tool" return "call_llm" -#wikipedia_tool = WikipediaSearchToolWithFAISS() -#search_tool = DuckDuckGoSearchRun() -#youtube_screenshot_qa_tool = YoutubeScreenshotQA() - -# Combine all tools -#tools = [wikipedia_tool, search_tool, youtube_screenshot_qa_tool] - -# Update your tools list to use the global instances -# - # --- Graph Construction --- # --- Graph Construction --- def create_memory_safe_workflow(): @@ -1589,100 +2974,97 @@ def create_memory_safe_workflow(): return workflow.compile() + +def count_english_words(text): + # Remove punctuation, lowercase, split into words + table = str.maketrans('', '', string.punctuation) + words_in_text = text.translate(table).lower().split() + return sum(1 for word in words_in_text if word in english_words) + +def fix_backwards_text(text): + reversed_text = text[::-1] + original_count = count_english_words(text) + reversed_count = count_english_words(reversed_text) + if reversed_count > original_count: + return reversed_text + else: + return text + + # --- Run the Agent --- -def run_agent(myagent, state: AgentState): - """ - Initialize agent with proper system message and formatted query. 
- """ - #global llm, hf_pipe, model_vqa, processor_vqa - global WIKIPEDIA_TOOL, SEARCH_TOOL, YOUTUBE_TOOL, tools +# Enhanced system prompt for better behavior - #ipdb.set_trace() +def run_agent(agent, state: AgentState): + """Enhanced agent initialization with better prompt and hallucination prevention.""" + global WIKIPEDIA_TOOL, SEARCH_TOOL, YOUTUBE_TOOL, YOUTUBE_AUDIO_TOOL, tools - # At the module level, create instances once + # Initialize tools WIKIPEDIA_TOOL = WikipediaSearchToolWithFAISS() SEARCH_TOOL = EnhancedDuckDuckGoSearchTool(max_results=3, max_chars_per_page=3000) - YOUTUBE_TOOL = YoutubeScreenshotQA() - - tools = [WIKIPEDIA_TOOL, SEARCH_TOOL, YOUTUBE_TOOL] + YOUTUBE_TOOL = EnhancedYoutubeScreenshotQA() + YOUTUBE_AUDIO_TOOL = YouTubeTranscriptExtractor() + tools = [WIKIPEDIA_TOOL, SEARCH_TOOL, YOUTUBE_TOOL, YOUTUBE_AUDIO_TOOL] - # Create a fresh system message each time formatted_tools_description = render_text_description(tools) current_date_str = datetime.now().strftime("%Y-%m-%d") - - system_content = f"""You are a general AI assistant. with access to these tools: - {formatted_tools_description} + # Enhanced system prompt with stricter boundaries + system_content = f"""You are an AI assistant with access to these tools: - If you need the most current information as of 2025, use enhanced_search - If you need to do in-depth research, use wikipedia_semantic_search_all_candidates_strong_entity_priority_list_retrieval - If you can answer the question confidently, do so directly. - If the question seems like gibberish (not English), try flipping the entire question and re-read the question. - If you need more information, use a tool. - (Think through the problem step by step) +{formatted_tools_description} - When using a tool, follow this format: - Thought: - Action: - Action Input: +CRITICAL INSTRUCTIONS: +1. Answer ONLY the question asked by the human +2. Do NOT generate additional questions or continue conversations +3. Use tools ONLY when you need specific information you don't know +4. After using a tool, provide your FINAL ANSWER immediately +5. STOP after giving your FINAL ANSWER - do not continue - Only use the tools listed above for the Action: step. Do not invent new tool names or actions. If you need to reason, do so in the Thought: step. After using a tool, process its output in your Thought: step, not as an Action. +FORMAT for tool use: +Thought: +Action: +Action Input: - Report your thoughts, and finish your answer with the following template: FINAL ANSWER: [YOUR FINAL ANSWER]. - YOUR FINAL ANSWER should be a number OR as few words as possible OR a comma separated list of numbers and/or strings. - If you are asked for a number, don't use comma to write your number neither use units such as $ or percent sign unless specified otherwise. - If you are asked for a string, don't use articles, neither abbreviations (e.g. for cities), and write the digits in plain text unless specified otherwise. - If you are asked for a comma separated list, apply the above rules depending of whether the element to be put in the list is a number or a string - Do not provide disclaimers. - Do not provide supporting details. 
+When you have the answer, immediately provide: +FINAL ANSWER: [concise answer only] - """ +ANSWER FORMAT: +- Numbers: no commas, no units unless specified +- Strings: no articles, no abbreviations, digits in plain text +- Lists: comma-separated following above rules +- Be extremely brief and concise +- Do not provide additional context or explanations +- Do not provide parentheticals - # Get user question from AgentState - query = state['question'] +IMPORTANT: You are responding to ONE question only. Do not ask follow-up questions or generate additional dialogue. - # Pattern for YouTube - yt_pattern = r"(https?://)?(www\.)?(youtube\.com|youtu\.be)/[^\s]+" - has_youtube = re.search(yt_pattern, query) is not None +Current date: {current_date_str} +""" - if has_youtube: - # Store the extracted YouTube URL in the state + query = fix_backwards_text(state['question']) + + # Check for YouTube URLs + yt_pattern = r"(https?://)?(www\.)?(youtube\.com|youtu\.be)/[^\s]+" + if re.search(yt_pattern, query): url_match = re.search(r"(https?://[^\s]+)", query) if url_match: state['youtube_url'] = url_match.group(0) - - # Format the user query to guide the model better - formatted_query = f"""{query}""" - - # Initialize agent state with proper message types + + # Initialize messages system_message = SystemMessage(content=system_content) - human_message = HumanMessage(content=formatted_query) - - # Initialize state with properly typed messages and done=False - # state = {"messages": [system_message, human_message], "done": False} + human_message = HumanMessage(content=query) + state['messages'] = [system_message, human_message] state["done"] = False - # Use the new method to run the graph - result = myagent.invoke(state) + # Run the agent + result = agent.invoke(state) - # Check if FINAL ANSWER was given (i.e., workflow ended) + # Cleanup if result.get("done"): - #del llm - #del hf_pipe - #del model_vqa - #del processor_vqa - #torch.cuda.empty_cache() - #torch.cuda.ipc_collect() + torch.cuda.empty_cache() + torch.cuda.ipc_collect() gc.collect() - print("Released GPU memory after FINAL ANSWER.") - # Re-initialize for the next run - #hf_pipe = create_llm_pipeline() - #llm = HuggingFacePipeline(pipeline=hf_pipe) - #print("Re-initilized llm...") - - # Extract and return just the messages for cleaner output - return result["messages"] - - + print("๐Ÿงน Released GPU memory after completion") + return result["messages"]
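+
+
+# --- Example Usage ---
+# Minimal driver sketch: build the graph once with create_memory_safe_workflow(),
+# then call run_agent() with an AgentState dict. The question and task_id below
+# are placeholders; in real use they come from the caller.
+if __name__ == "__main__":
+    agent_graph = create_memory_safe_workflow()
+    initial_state: AgentState = {
+        "messages": [],
+        "done": False,
+        "question": "What is the capital of France?",
+        "task_id": "demo-task-001",
+        "input_file": None,
+        "file_type": None,
+        "context": [],
+        "file_path": None,
+        "youtube_url": None,
+        "answer": None,
+        "frame_answers": None,
+    }
+    final_messages = run_agent(agent_graph, initial_state)
+    # Print a short preview of each message rather than full contents to keep logs readable.
+    for msg in final_messages:
+        print(f"{type(msg).__name__}: {str(msg.content)[:200]}")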