# Author: Samuel Thomas
# Commit: 9c250a6 ("use cpu")
# (Header captured from a web file viewer: raw | history | blame | 135 kB)
# Standard Library
import os
import re
import tempfile
import string
import glob
import shutil
import gc
import uuid
import signal
from datetime import datetime
from io import BytesIO
from contextlib import contextmanager
from langchain_huggingface import HuggingFacePipeline
from typing import TypedDict, List, Optional, Dict, Any, Annotated, Literal, Union, Tuple, Set
import time
from collections import Counter
from pydantic import Field
import hashlib
import json
import numpy as np
import ast
from concurrent.futures import ThreadPoolExecutor, as_completed
from collections import Counter, defaultdict
# Third-Party Packages
import cv2
import requests
import wikipedia
import spacy
import yt_dlp
import librosa
from PIL import Image
from bs4 import BeautifulSoup
from duckduckgo_search import DDGS
from sentence_transformers import SentenceTransformer
from transformers import BlipProcessor, BlipForQuestionAnswering, pipeline, AutoTokenizer
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.cluster import KMeans
from sklearn.preprocessing import StandardScaler
import speech_recognition as sr
from pydub import AudioSegment
from pydub.silence import split_on_silence
import nltk
from nltk.corpus import words
# LangChain Ecosystem
from langchain.docstore.document import Document
from langchain.prompts import PromptTemplate
from langchain_community.document_loaders import WikipediaLoader
from langchain_huggingface import HuggingFaceEndpoint
from langchain_community.retrievers import BM25Retriever
from langchain.vectorstores import FAISS
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.schema import Document
from langchain_community.tools import DuckDuckGoSearchRun
from langchain_core.messages import AnyMessage, HumanMessage, AIMessage, BaseMessage, SystemMessage, ToolMessage
from langchain_core.tools import BaseTool, StructuredTool, tool, render_text_description
from langchain_core.documents import Document
# LangGraph
from langgraph.graph import START, END, StateGraph
from langgraph.prebuilt import ToolNode, tools_condition
# PyTorch
import torch
from functools import partial
from transformers import pipeline
# Additional Utilities
from datetime import datetime
from urllib.parse import urljoin, urlparse
import logging
# Module-level logger used by the tools defined below.
logger = logging.getLogger(__name__)

# Shared spaCy English pipeline.
nlp = spacy.load("en_core_web_sm")

# Fetch the NLTK "words" corpus if it is not already present, then keep the
# vocabulary cached as a set.
nltk.download('words', quiet=True)
english_words = {*words.words()}
# --- Model Configuration ---
def create_llm_pipeline(
    model_id: str = "mistralai/Mistral-7B-Instruct-v0.3",
    *,
    device_map: str = "cpu",
    torch_dtype=torch.float16,
    max_new_tokens: int = 1024,
    temperature: float = 0.1,
):
    """Build the Hugging Face ``text-generation`` pipeline used as the agent LLM.

    All arguments default to the values this function previously hard-coded,
    so existing zero-argument callers get the same pipeline.

    Other checkpoints tried with this code: "meta-llama/Llama-2-13b-chat-hf",
    "meta-llama/Llama-3.3-70B-Instruct", "mistralai/Mistral-Small-24B-Base-2501",
    "Qwen/Qwen2-7B-Instruct".

    Args:
        model_id: Hugging Face Hub checkpoint for both tokenizer and model.
        device_map: Passed to ``pipeline(device_map=...)``.
        torch_dtype: Weight dtype passed to ``pipeline``.
            NOTE(review): float16 on CPU is slow (and unsupported for some ops
            in older PyTorch builds); consider torch.float32 / torch.bfloat16
            when device_map="cpu" — confirm memory budget first.
        max_new_tokens: Generation length cap.
        temperature: Sampling temperature.
            NOTE(review): transformers only applies temperature when
            do_sample=True; as configured, decoding is likely greedy.

    Returns:
        The transformers text-generation pipeline.
    """
    # Load the fast (Rust-backed) tokenizer explicitly.
    tokenizer = AutoTokenizer.from_pretrained(
        model_id,
        use_fast=True,
        add_prefix_space=True,  # original note: only if actually needed for this model
    )
    return pipeline(
        "text-generation",
        model=model_id,
        tokenizer=tokenizer,
        device_map=device_map,
        torch_dtype=torch_dtype,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
    )
# Define file extension sets for each category
PICTURE_EXTENSIONS = {'.jpg', '.jpeg', '.png', '.gif', '.bmp', '.tiff', '.webp'}
AUDIO_EXTENSIONS = {'.mp3', '.wav', '.aac', '.flac', '.ogg', '.m4a', '.wma'}
CODE_EXTENSIONS = {'.py', '.js', '.java', '.cpp', '.c', '.cs', '.rb', '.go', '.php', '.html', '.css', '.ts'}
SPREADSHEET_EXTENSIONS = {
    '.xls', '.xlsx', '.xlsm', '.xlsb', '.xlt', '.xltx', '.xltm',
    '.ods', '.ots', '.csv', '.tsv', '.sxc', '.stc', '.dif', '.gsheet',
    '.numbers', '.numbers-tef', '.nmbtemplate', '.fods', '.123', '.wk1', '.wk2',
    '.wks', '.wku', '.wr1', '.gnumeric', '.gnm', '.xml', '.pmvx', '.pmdx',
    '.pmv', '.uos', '.txt'
}

# Category lookup table, checked in order; the first set containing the
# extension wins.
_FILE_CATEGORIES = (
    ('picture', PICTURE_EXTENSIONS),
    ('audio', AUDIO_EXTENSIONS),
    ('code', CODE_EXTENSIONS),
    ('spreadsheet', SPREADSHEET_EXTENSIONS),
)


def get_file_type(filename: str) -> str:
    """Classify a file name by its extension (case-insensitive).

    Returns:
        'picture', 'audio', 'code' or 'spreadsheet' for known extensions;
        'unknown' when the name has an extension not in any category;
        '' when *filename* is empty or its final path component has no dot.
    """
    if not filename:
        return ''
    # Only the last path component can carry the extension: dots in directory
    # names (e.g. "/tmp/v1.2/readme") must not be mistaken for one.
    _, dot, ext = os.path.basename(filename).rpartition('.')
    if not dot:
        return ''
    dot_ext = f'.{ext.lower()}'
    for category, extensions in _FILE_CATEGORIES:
        if dot_ext in extensions:
            return category
    return 'unknown'
def write_bytes_to_temp_dir(file_bytes: bytes, file_name: str, temp_dir: str = "/tmp") -> str:
    """Write *file_bytes* to ``temp_dir/<basename of file_name>`` and return the path.

    The file persists until deleted manually or by the OS temp-dir cleanup.

    Args:
        file_bytes: Raw content to write (overwrites an existing file).
        file_name: Desired file name. Any directory part is discarded, so names
            such as "../x" or "/etc/x" cannot escape *temp_dir*.
        temp_dir: Target directory, created if missing. Defaults to "/tmp"
            (always writable in Hugging Face Spaces).

    Returns:
        Full path of the written file.

    Raises:
        OSError: if the directory cannot be created or the file written (e.g.
            IsADirectoryError when *file_name* has no usable final component).
    """
    os.makedirs(temp_dir, exist_ok=True)
    # os.path.join drops temp_dir entirely for absolute names, and ".." parts
    # would climb out of it; keep only the final path component.
    file_path = os.path.join(temp_dir, os.path.basename(file_name))
    with open(file_path, 'wb') as f:
        f.write(file_bytes)
    print(f"File written to: {file_path}")
    return file_path
def extract_final_answer(text: str) -> str:
    """Return the text from the last 'FINAL ANSWER:' marker to the end.

    The marker is matched case-insensitively and is kept in the result.
    Surrounding whitespace is trimmed, then any trailing run of ASCII
    punctuation/spaces is removed. Returns '' when the marker is absent.

    NOTE(review): the trailing strip also removes closing brackets/quotes
    (e.g. "f(x)" -> "f(x"); confirm this is acceptable for answer grading.
    """
    needle = "FINAL ANSWER:".lower()
    head, found, _ = text.lower().rpartition(needle)
    if not found:
        return ""
    trailing_junk = string.punctuation + " "
    return text[len(head):].strip().rstrip(trailing_junk)
# URL suffixes that are never fetched because they are not HTML pages.
_NON_HTML_EXTENSIONS = ('.pdf', '.doc', '.docx', '.xls', '.xlsx', '.ppt', '.pptx')


class EnhancedDuckDuckGoSearchTool(BaseTool):
    """LangChain tool: DuckDuckGo text search plus page-content extraction.

    For each of the top ``max_results`` hits the page is downloaded, stripped
    of boilerplate elements (scripts, navigation, forms, ...) and truncated to
    ``max_chars_per_page`` characters, so the agent sees page text rather than
    only search snippets. Invoke it through the inherited ``run``/``invoke``.
    """

    name: str = "enhanced_search"
    description: str = (
        "Performs a DuckDuckGo web search and retrieves actual content from the top web results. "
        "Input should be a search query string. "
        "Returns search results with extracted content from web pages, making it much more useful for answering questions. "
        "Use this tool when you need up-to-date information, details about current events, or when other tools do not provide sufficient or recent answers. "
        "Ideal for topics that require the latest news, recent developments, or information not covered in static sources."
    )
    max_results: int = 3
    max_chars_per_page: int = 3000
    # Cap on the length of the whole formatted response returned to the agent.
    max_response_chars: int = 8000
    # Pause between consecutive page fetches, to be respectful to servers.
    request_delay_seconds: float = 0.5
    # requests.Session created in model_post_init; declared as a field
    # (default None) so pydantic accepts the assignment.
    session: Any = None

    def model_post_init(self, __context: Any) -> None:
        """Pydantic v2 post-init hook: create the shared HTTP session."""
        super().model_post_init(__context)
        self.session = requests.Session()
        self.session.headers.update({
            'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36',
            'Accept': 'text/html,application/xhtml+xml,application/xml;q=0.9,image/webp,*/*;q=0.8',
            'Accept-Language': 'en-US,en;q=0.5',
            'Accept-Encoding': 'gzip, deflate',
            'Connection': 'keep-alive',
            'Upgrade-Insecure-Requests': '1',
        })

    def _search_duckduckgo(self, query: str) -> List[Dict]:
        """Return up to ``max_results`` raw DDGS text hits ([] on any failure)."""
        try:
            with DDGS() as ddgs:
                return list(ddgs.text(query, max_results=self.max_results))
        except Exception as e:
            logger.error("DuckDuckGo search failed: %s", e)
            return []

    def _html_to_text(self, html: bytes) -> str:
        """Strip boilerplate from an HTML document and return its main text."""
        soup = BeautifulSoup(html, 'html.parser')
        for element in soup(["script", "style", "nav", "header", "footer", "aside", "form"]):
            element.decompose()
        # Prefer a dedicated content container; fall back to <body>, then the
        # whole document.
        main_content = None
        for selector in ('main', 'article', '.content', '#content', '.post', '.entry'):
            main_content = soup.select_one(selector)
            if main_content:
                break
        if not main_content:
            main_content = soup.find('body') or soup
        text = main_content.get_text(separator='\n', strip=True)
        # Drop blank lines (this also leaves no runs of newlines to collapse).
        text = '\n'.join(line.strip() for line in text.split('\n') if line.strip())
        text = re.sub(r' {2,}', ' ', text)
        if len(text) > self.max_chars_per_page:
            text = text[:self.max_chars_per_page] + "\n[Content truncated...]"
        return text

    def _extract_content_from_url(self, url: str, timeout: int = 10) -> Optional[str]:
        """Fetch *url* and return its cleaned text, or None if it is unusable.

        None is returned for document-type URLs, non-HTML responses, timeouts,
        HTTP/network errors and parse failures. Error descriptions are logged
        rather than returned, so they are never mistaken for page content.
        """
        if url.lower().endswith(_NON_HTML_EXTENSIONS):
            logger.info("Skipping non-HTML document URL: %s", url)
            return None
        try:
            response = self.session.get(url, timeout=timeout, allow_redirects=True)
            response.raise_for_status()
            content_type = response.headers.get('content-type', '').lower()
            if 'text/html' not in content_type and 'application/xhtml+xml' not in content_type:
                logger.info("Non-HTML content (%s) at %s", content_type, url)
                return None
            return self._html_to_text(response.content)
        except requests.exceptions.Timeout:
            logger.warning("Page loading timed out: %s", url)
        except requests.exceptions.RequestException as e:
            logger.warning("Failed to retrieve page %s: %s", url, e)
        except Exception as e:
            logger.error("Content extraction failed for %s: %s", url, e)
        return None

    def _format_search_result(self, result: Dict, content: str) -> str:
        """Render one hit (title, URL, snippet, extracted text) as a text section."""
        title = result.get('title', 'No title')
        url = result.get('href', 'No URL')
        snippet = result.get('body', 'No snippet')
        return (
            f"\n🔍 **{title}**\n"
            f"URL: {url}\n"
            f"Snippet: {snippet}\n"
            f"📄 **Page Content:**\n"
            f"{content}\n"
            f"---\n"
        )

    def _run(self, query: str) -> str:
        """Search DuckDuckGo for *query*; return formatted hits with page text.

        This is the BaseTool hook. The inherited ``run``/``invoke`` call it and
        also handle LangChain callbacks, config and tool-call ids.
        """
        if not query or not query.strip():
            return "Please provide a search query."
        query = query.strip()
        logger.info("Searching for: %s", query)

        search_results = self._search_duckduckgo(query)
        if not search_results:
            return f"No search results found for query: {query}"

        enhanced_results = []
        fetched_before = False
        for i, result in enumerate(search_results[:self.max_results]):
            url = result.get('href', '')
            if not url:
                continue
            if fetched_before:
                # Space out requests; no pointless wait after the final fetch.
                time.sleep(self.request_delay_seconds)
            fetched_before = True
            logger.info("Processing result %d: %s", i + 1, url)
            content = self._extract_content_from_url(url)
            # Only include results with substantial content.
            if content and len(content.strip()) > 50:
                enhanced_results.append(self._format_search_result(result, content))

        if not enhanced_results:
            return f"Search completed but no content could be extracted from the pages for query: {query}"

        processed_count = len(enhanced_results)
        response = (
            f'🔍 **Enhanced Search Results for: "{query}"**\n'
            f"Found {len(search_results)} results, successfully processed {processed_count} pages with content.\n"
            + "".join(enhanced_results)
            + f"\n💡 **Summary:** Retrieved and processed content from {processed_count} web pages "
            "to provide comprehensive information about your search query.\n"
        )
        # Keep the response bounded so it fits comfortably in the LLM context.
        if len(response) > self.max_response_chars:
            response = response[:self.max_response_chars] + "\n[Response truncated to prevent memory issues]"
        return response
# --- Agent State Definition ---
class AgentState(TypedDict):
    """State dictionary passed between the agent's LangGraph nodes."""
    # Conversation history. LangGraph uses the Annotated callable as the
    # reducer, so messages returned by a node are concatenated (x + y) onto
    # the existing list rather than replacing it.
    messages: Annotated[List[AnyMessage], lambda x, y: x + y]
    # NOTE(review): TypedDict does not apply class-level defaults to instances,
    # so this "= False" does not put a "done" key into the state dict; read it
    # with state.get("done", False). Type checkers also reject values here.
    done: bool = False
    question: str
    task_id: str
    # Raw bytes of an input file, if the task has one (presumably an attachment — confirm).
    input_file: Optional[bytes]
    # presumably a category string as returned by get_file_type() — confirm
    file_type: Optional[str]
    context: List[Document]  # Using LangChain's Document class
    file_path: Optional[str]
    youtube_url: Optional[str]
    answer: Optional[str]
    # Per-frame answers; element type not visible here (TODO confirm).
    frame_answers: Optional[list]
def _html_tables_to_text(html: str) -> List[str]:
    """Flatten every <table> in *html* into plain text, one row per line.

    Cells are joined with " | ". This is pipe-separated text, not a full
    Markdown table: there are no outer pipes and no header separator row.
    Rows without cells are skipped, and tables with no rows are dropped.
    """
    soup = BeautifulSoup(html, 'html.parser')
    table_texts = []
    for table in soup.find_all('table'):
        table_lines = []
        for row in table.find_all('tr'):
            cells = row.find_all(['th', 'td'])
            if cells:
                table_lines.append(" | ".join(cell.get_text(strip=True) for cell in cells))
        if table_lines:
            table_texts.append("\n".join(table_lines))
    return table_texts


def fetch_page_with_tables(page_title, auto_suggest: bool = True):
    """Fetch a Wikipedia page's plain text and its tables flattened to text.

    Args:
        page_title: Title passed to ``wikipedia.page``.
        auto_suggest: Forwarded to ``wikipedia.page`` (default True, as
            before). Pass False to load the exact title, because
            auto-suggestion can resolve to a different page.

    Returns:
        (main_text, [table_texts]) — the page content and one string per table.
    """
    page = wikipedia.page(page_title, auto_suggest=auto_suggest)
    main_text = page.content
    return main_text, _html_tables_to_text(page.html())
class WikipediaSearchToolWithFAISS(BaseTool):
    """LangChain tool: entity-aware Wikipedia retrieval over a per-query FAISS index.

    For each query:
      1. spaCy extracts named entities, keywords and a lemmatised query.
      2. Page-title candidates are generated (entities first) and resolved.
         Pages from entity-focused candidates are rejected unless their title
         contains a query entity.
      3. Page text and flattened tables are chunked and embedded into a FAISS
         index built only for this query.
      4. Several query variants (including list/section phrasings such as
         "discography") are searched. Each unique chunk is ranked by its best
         FAISS distance, and the top ``top_k_results`` are returned as text.
    """

    name: str = "wikipedia_semantic_search_all_candidates_strong_entity_priority_list_retrieval"
    description: str = (
        "Fetches content from multiple Wikipedia pages based on intelligent NLP query processing "
        "of various search candidates, with strong prioritization of query entities. It then performs "
        "entity-focused semantic search across all fetched content to find the most relevant information, "
        "with improved retrieval for lists like discographies. Uses spaCy for named entity "
        "recognition and query enhancement. Input should be a search query or topic. "
        "Note: Uses the current live version of Wikipedia."
    )
    embedding_model_name: str = "all-MiniLM-L6-v2"
    chunk_size: int = 4000
    chunk_overlap: int = 250  # Maintained moderate overlap
    top_k_results: int = 3
    spacy_model: str = "en_core_web_sm"
    # Chunks fetched per semantic query variant = top_k_results * this
    # multiplier; raise it if list-style answers are being missed.
    semantic_search_candidate_multiplier: int = 1

    def __init__(self, **kwargs):
        """Load spaCy, the embedding model and the text splitter.

        On failure the components are set to None and ``_run`` reports an error.
        """
        super().__init__(**kwargs)
        try:
            self._nlp = spacy.load(self.spacy_model)
            print(f"Loaded spaCy model: {self.spacy_model}")
            self._embedding_model = HuggingFaceEmbeddings(model_name=self.embedding_model_name)
            # Separators tuned for Wikipedia sections and lists.
            self._text_splitter = RecursiveCharacterTextSplitter(
                chunk_size=self.chunk_size,
                chunk_overlap=self.chunk_overlap,
                separators=[
                    "\n\n== ", "\n\n=== ", "\n\n==== ",  # Section headers (keep with following content)
                    "\n\n\n", "\n\n",  # Multiple newlines (paragraph breaks)
                    "\n* ", "\n- ", "\n# ",  # List items
                    "\n", ". ", "! ", "? ",  # Line and sentence breaks
                    " ", "",  # Word and character level
                ],
            )
        except OSError as e:
            print(f"Error loading spaCy model '{self.spacy_model}': {e}")
            print("Try running: python -m spacy download en_core_web_sm")
            self._nlp = None
            self._embedding_model = None
            self._text_splitter = None
        except Exception as e:
            print(f"Error initializing WikipediaSearchToolWithFAISS components: {e}")
            self._nlp = None
            self._embedding_model = None
            self._text_splitter = None

    def _extract_entities_and_keywords(self, query: str) -> Tuple[List[str], List[str], str]:
        """Return (main_entities, keywords, processed_query) for *query*.

        main_entities: unique PERSON/ORG/GPE/EVENT/WORK_OF_ART entity texts.
        keywords: unique lower-cased lemmas of non-stopword nouns, proper nouns
            and adjectives longer than two characters.
        processed_query: the query's non-stopword, non-punctuation lemmas.
        Without a spaCy model, returns ([], [], query).
        """
        if not self._nlp:
            return [], [], query
        doc = self._nlp(query)
        main_entities = [ent.text for ent in doc.ents if ent.label_ in ["PERSON", "ORG", "GPE", "EVENT", "WORK_OF_ART"]]
        keywords = [token.lemma_.lower() for token in doc
                    if token.pos_ in ["NOUN", "PROPN", "ADJ"] and not token.is_stop
                    and not token.is_punct and len(token.text) > 2]
        main_entities = list(dict.fromkeys(main_entities))
        keywords = list(dict.fromkeys(keywords))
        processed_tokens = [token.lemma_ for token in doc if not token.is_stop and not token.is_punct and token.text.strip()]
        processed_query = " ".join(processed_tokens)
        return main_entities, keywords, processed_query

    def _generate_search_candidates(self, query: str, main_entities: List[str], keywords: List[str], processed_query: str) -> List[str]:
        """Build an ordered list of Wikipedia title candidates, entities first."""
        # An insertion-ordered dict (not a set) keeps the candidate order
        # deterministic: str hashing is randomised per process, and the order
        # decides which candidates get the wikipedia.search fallback.
        candidates: Dict[str, None] = {}
        entity_prefix = main_entities[0] if main_entities else None
        for me in main_entities:
            candidates[me] = None
        candidates[query] = None
        if processed_query and processed_query != query:
            candidates[processed_query] = None
        if entity_prefix and keywords:
            first_entity_lower = entity_prefix.lower()
            for kw in keywords[:3]:
                if kw not in first_entity_lower and len(kw) > 2:
                    candidates[f"{entity_prefix} {kw}"] = None
            keyword_combo_short = " ".join(k for k in keywords[:2] if k not in first_entity_lower and len(k) > 2)
            if keyword_combo_short:
                candidates[f"{entity_prefix} {keyword_combo_short}"] = None
        if len(main_entities) > 1:
            candidates[" ".join(main_entities[:2])] = None
        if keywords:
            keyword_combo = " ".join(keywords[:2])
            if entity_prefix:
                candidate_to_add = f"{entity_prefix} {keyword_combo}"
                if not any(c.lower() == candidate_to_add.lower() for c in candidates):
                    candidates[candidate_to_add] = None
            elif not main_entities:
                candidates[keyword_combo] = None
        ordered_candidates = list(dict.fromkeys(main_entities))
        for c in candidates:
            if c and c.strip() and c not in ordered_candidates:
                ordered_candidates.append(c)
        print(f"Generated {len(ordered_candidates)} search candidates for Wikipedia page lookup (entity-prioritized): {ordered_candidates}")
        return ordered_candidates

    def _accept_page(self, page: Any, origin: str, entity_focused: bool, main_entities: List[str], processed_titles: Set[str]) -> bool:
        """Return True if *page* is new and, for entity-focused candidates, its title contains a query entity."""
        title = page.title
        if entity_focused and main_entities:
            if not any(me.lower() in title.lower() for me in main_entities):
                print(f" ! Page title '{title}' ({origin}) "
                      f"does not strongly match main query entities: {main_entities}. Skipping.")
                return False
        if title in processed_titles:
            print(f" ~ Already processed '{title}'")
            return False
        return True

    def _resolve_candidate_page(self, candidate_query: str, allow_search_fallback: bool, entity_focused: bool,
                                main_entities: List[str], processed_titles: Set[str]) -> Optional[Any]:
        """Resolve one candidate string to a new, validated wikipedia page, or None.

        Tries a direct lookup first: the exact title for entity-focused
        candidates, then auto-suggest. On PageError it optionally falls back to
        ``wikipedia.search``. On DisambiguationError it takes the first listed
        option. Other exceptions propagate to the caller.
        """
        try:
            temp_page = None
            if entity_focused:
                try:  # Attempt precise match first for entity-focused candidates
                    temp_page = wikipedia.page(candidate_query, auto_suggest=False, redirect=True)
                except (wikipedia.exceptions.PageError, wikipedia.exceptions.DisambiguationError):
                    print(f" - auto_suggest=False failed for entity-focused '{candidate_query}', trying with auto_suggest=True.")
            if temp_page is None:
                temp_page = wikipedia.page(candidate_query, auto_suggest=True, redirect=True)
            if not self._accept_page(temp_page, f"from entity-focused candidate '{candidate_query}'",
                                     entity_focused, main_entities, processed_titles):
                return None
            print(f" ✓ Direct hit/suggestion for '{candidate_query}' -> '{temp_page.title}'")
            return temp_page
        except wikipedia.exceptions.PageError:
            if not allow_search_fallback:
                print(f" - Direct access failed for '{candidate_query}'. Skipping further search for this lower priority candidate.")
                return None
            print(f" - Direct access failed for '{candidate_query}'. Trying Wikipedia search...")
            search_results = wikipedia.search(candidate_query, results=1)
            if not search_results:
                print(f" - No Wikipedia search results for '{candidate_query}'.")
                return None
            search_result_title = search_results[0]
            try:
                # Search results are usually canonical titles.
                temp_page = wikipedia.page(search_result_title, auto_suggest=False, redirect=True)
            except (wikipedia.exceptions.PageError, wikipedia.exceptions.DisambiguationError) as e_sr:
                print(f" ! Error/Disambiguation for search result '{search_result_title}': {e_sr}")
                return None
            if not self._accept_page(temp_page, f"from search for '{candidate_query}' -> '{search_result_title}'",
                                     entity_focused, main_entities, processed_titles):
                return None
            print(f" ✓ Found via search '{candidate_query}' -> '{search_result_title}' -> '{temp_page.title}'")
            return temp_page
        except wikipedia.exceptions.DisambiguationError as de:
            print(f" ! Disambiguation for '{candidate_query}'. Options: {de.options[:1]}")
            if not de.options:
                return None
            option_title = de.options[0]
            try:
                temp_page = wikipedia.page(option_title, auto_suggest=False, redirect=True)
                if not self._accept_page(temp_page, f"from disamb. of '{candidate_query}' -> '{option_title}'",
                                         entity_focused, main_entities, processed_titles):
                    return None
                print(f" ✓ Resolved disambiguation '{candidate_query}' -> '{option_title}' -> '{temp_page.title}'")
                return temp_page
            except Exception as e_dis_opt:
                print(f" ! Could not load disambiguation option '{option_title}': {e_dis_opt}")
                return None

    def _extract_table_texts(self, page_object: Any, page_title: str) -> List[str]:
        """Flatten every HTML table on the page into " | "-separated rows.

        Returns one string per table that has rows; [] if the HTML cannot be
        fetched or parsed.
        """
        try:
            soup = BeautifulSoup(page_object.html(), 'html.parser')
            table_texts = []
            for table in soup.find_all('table'):
                table_lines = []
                for row in table.find_all('tr'):
                    cells = row.find_all(['th', 'td'])
                    if cells:
                        table_lines.append(" | ".join(cell.get_text(strip=True) for cell in cells))
                if table_lines:
                    table_texts.append("\n".join(table_lines))
            return table_texts
        except Exception as e:
            print(f" !! Error extracting tables for '{page_title}': {e}")
            return []

    def _smart_wikipedia_search(self, query_text: str, main_entities_from_query: List[str], keywords_from_query: List[str], processed_query_text: str) -> List[Tuple[str, str]]:
        """Fetch validated, de-duplicated pages for the query's candidates.

        Returns (text, page_title) pairs: one for each page's main text plus one
        for each of its flattened tables.
        """
        candidates = self._generate_search_candidates(query_text, main_entities_from_query, keywords_from_query, processed_query_text)
        found_pages_data: List[Tuple[str, str]] = []
        processed_page_titles: Set[str] = set()
        # Only the more promising leading candidates get the wikipedia.search fallback.
        search_fallback_limit = max(2, len(candidates) // 3)
        for i, candidate_query in enumerate(candidates):
            print(f"\nProcessing candidate {i+1}/{len(candidates)} for page: '{candidate_query}'")
            is_candidate_entity_focused = (
                any(me.lower() in candidate_query.lower() for me in main_entities_from_query)
                if main_entities_from_query else False
            )
            try:
                page_object = self._resolve_candidate_page(
                    candidate_query,
                    allow_search_fallback=i < search_fallback_limit,
                    entity_focused=is_candidate_entity_focused,
                    main_entities=main_entities_from_query,
                    processed_titles=processed_page_titles,
                )
                if page_object is None:
                    continue
                final_page_title = page_object.title
                if not final_page_title or final_page_title in processed_page_titles:
                    continue
                main_text = page_object.content
                table_texts = self._extract_table_texts(page_object, final_page_title)
                # Main text and each table become separate entries (chunked later).
                for chunk in [main_text] + table_texts:
                    found_pages_data.append((chunk, final_page_title))
                processed_page_titles.add(final_page_title)
                print(f" -> Added page '{final_page_title}'. Main text length: {len(main_text)} | Tables extracted: {len(table_texts)}")
            except Exception as e:
                print(f" !! Unexpected error processing candidate '{candidate_query}': {e}")
        if not found_pages_data:
            print(f"\nCould not find any new, unique, entity-validated Wikipedia pages for query '{query_text}'.")
        else:
            print(f"\nFound {len(found_pages_data)} unique, validated page(s) for processing.")
        return found_pages_data

    def _enhance_semantic_search(self, query: str, vector_store, main_entities: List[str], keywords: List[str], processed_query: str) -> List[Document]:
        """Run several query variants against *vector_store*; return the best chunks.

        Each unique chunk (keyed by its first 250 characters) keeps its lowest
        FAISS distance across all variants. The ``top_k_results`` best chunks
        are returned with 'retrieved_by_variant' and 'retrieval_score' metadata.
        """
        # Ordered dicts instead of sets keep variant order deterministic.
        core_query_parts: Dict[str, None] = {query: None}
        if processed_query != query:
            core_query_parts[processed_query] = None
        if keywords:
            core_query_parts[" ".join(keywords[:2])] = None

        lower_query = query.lower()
        lower_query_terms = set(lower_query.split()) | {k.lower() for k in keywords}
        section_keywords_map = {
            "discography": ["discography", "list of studio albums", "studio album titles and years", "albums by year", "album release dates", "official albums", "complete album list", "albums published"],
            "biography": ["biography", "life story", "career details", "background history"],
            "filmography": ["filmography", "list of films", "movie appearances", "acting roles"],
        }
        section_phrases_templates: List[str] = []
        for section_term_key, specific_phrases_list in section_keywords_map.items():
            # Add a section's phrasings when its key word is a query term or
            # keyword, or when one of its phrases appears verbatim in the query.
            if section_term_key in lower_query_terms or any(phrase in lower_query for phrase in specific_phrases_list):
                section_phrases_templates.extend(specific_phrases_list)
        section_phrases_templates = list(dict.fromkeys(section_phrases_templates))

        final_search_queries: Dict[str, None] = {}
        if main_entities:
            entity_prefix = main_entities[0]
            final_search_queries[entity_prefix] = None
            for part in core_query_parts:
                final_search_queries[part if entity_prefix.lower() in part.lower() else f"{entity_prefix} {part}"] = None
            for phrase_template in section_phrases_templates:
                final_search_queries[f"{entity_prefix} {phrase_template}"] = None
                if "list of" in phrase_template or "history of" in phrase_template:
                    final_search_queries[f"{phrase_template} of {entity_prefix}"] = None
        else:
            final_search_queries.update(dict.fromkeys(core_query_parts))
            final_search_queries.update(dict.fromkeys(section_phrases_templates))
        deduplicated_queries = [sq for sq in final_search_queries if sq and sq.strip()]
        print(f"Generated {len(deduplicated_queries)} semantic search query variants (list-retrieval focused): {deduplicated_queries}")

        k_to_fetch = self.top_k_results * self.semantic_search_candidate_multiplier
        # content-prefix hash -> (best distance, document, variant that found it)
        best_hits: Dict[int, Tuple[float, Document, str]] = {}
        for search_query_variant in deduplicated_queries:
            try:
                results = vector_store.similarity_search_with_score(search_query_variant, k=k_to_fetch)
                print(f" Semantic search variant '{search_query_variant}' (k={k_to_fetch}) -> {len(results)} raw chunk(s) with scores.")
                for doc, score in results:
                    content_hash = hash(doc.page_content[:250])
                    score = float(score)
                    # FAISS L2 distance: lower is better, so keep the minimum.
                    if content_hash not in best_hits or score < best_hits[content_hash][0]:
                        best_hits[content_hash] = (score, doc, search_query_variant)
            except Exception as e:
                print(f" Error in semantic search for variant '{search_query_variant}': {e}")

        ranked = sorted(best_hits.values(), key=lambda hit: hit[0])
        print(f"Collected and re-sorted {len(ranked)} unique chunks from all semantic query variants.")
        top_docs: List[Document] = []
        for score, doc, variant in ranked[:self.top_k_results]:
            doc.metadata['retrieved_by_variant'] = variant
            doc.metadata['retrieval_score'] = score
            top_docs.append(doc)
        return top_docs

    def _run(self, query: Optional[str] = None, search_query: Optional[str] = None, **kwargs) -> str:
        """Answer *query* from Wikipedia; returns a formatted text report or an error string.

        The query may also arrive as ``search_query``, ``q`` or ``search_term``.
        """
        if not self._nlp or not self._embedding_model or not self._text_splitter:
            print("ERROR: WikipediaSearchToolWithFAISS components not initialized properly.")
            return "Error: Wikipedia tool components not initialized properly. Please check server logs."
        if not query:
            query = search_query or kwargs.get('q') or kwargs.get('search_term')
        if not query or not str(query).strip():
            return "Error: No search query provided to the Wikipedia tool."
        try:
            print(f"\n--- Running {self.name} for query: '{query}' ---")
            main_entities, keywords, processed_query = self._extract_entities_and_keywords(query)
            print(f"Initial NLP Analysis - Main Entities: {main_entities}, Keywords: {keywords}, Processed Query: '{processed_query}'")
            fetched_pages_data = self._smart_wikipedia_search(query, main_entities, keywords, processed_query)
            if not fetched_pages_data:
                return (f"Could not find any relevant, entity-validated Wikipedia pages for the query '{query}'. "
                        f"Main entities sought: {main_entities}")
            all_page_titles = [title for _, title in fetched_pages_data]
            print(f"\nSuccessfully fetched content for {len(fetched_pages_data)} Wikipedia page(s): {', '.join(all_page_titles)}")

            all_documents: List[Document] = []
            for page_content, page_title in fetched_pages_data:
                chunks = self._text_splitter.split_text(page_content)
                if not chunks:
                    print(f"Warning: Could not split content from Wikipedia page '{page_title}' into chunks.")
                    continue
                for i, chunk_text in enumerate(chunks):
                    all_documents.append(Document(page_content=chunk_text, metadata={
                        "source_page_title": page_title,
                        "original_query": query,
                        "chunk_index": i,  # Kept for debugging / ordering
                    }))
                print(f"Split content from '{page_title}' into {len(chunks)} chunks.")
            if not all_documents:
                return (f"Could not process content into searchable chunks from the fetched Wikipedia pages "
                        f"({', '.join(all_page_titles)}) for query '{query}'.")

            print(f"\nTotal document chunks from all pages: {len(all_documents)}")
            print("Creating FAISS index from content of all fetched pages...")
            try:
                vector_store = FAISS.from_documents(all_documents, self._embedding_model)
                print("FAISS index created successfully.")
            except Exception as e:
                return f"Error creating FAISS vector store: {e}"

            print("\nPerforming enhanced semantic search across all collected content...")
            try:
                relevant_docs = self._enhance_semantic_search(query, vector_store, main_entities, keywords, processed_query)
            except Exception as e:
                return f"Error during semantic search: {e}"
            if not relevant_docs:
                return (f"No relevant information found within Wikipedia page(s) '{', '.join(list(dict.fromkeys(all_page_titles)))}' "
                        f"for your query '{query}' using entity-focused semantic search with list retrieval.")

            unique_sources_in_results = list(dict.fromkeys([doc.metadata.get('source_page_title', 'Unknown Source') for doc in relevant_docs]))
            result_header = (f"Found {len(relevant_docs)} relevant piece(s) of information from Wikipedia page(s) "
                             f"'{', '.join(unique_sources_in_results)}' for your query '{query}':\n")
            nlp_summary = (f"[Original Query NLP: Main Entities: {', '.join(main_entities) if main_entities else 'None'}, "
                           f"Keywords: {', '.join(keywords[:5]) if keywords else 'None'}]\n\n")
            result_details = []
            for i, doc in enumerate(relevant_docs):
                source_info = doc.metadata.get('source_page_title', 'Unknown Source')
                variant_info = doc.metadata.get('retrieved_by_variant', 'N/A')
                score_info = doc.metadata.get('retrieval_score')
                # Guard the float format: a missing score previously raised ValueError.
                score_text = f"{score_info:.4f}" if isinstance(score_info, (int, float)) else "N/A"
                detail = (f"Result {i+1} (source: '{source_info}', score: {score_text})\n"
                          f"(Retrieved by: '{variant_info}')\n{doc.page_content}")
                result_details.append(detail)
            final_result = result_header + nlp_summary + "\n\n---\n\n".join(result_details)
            print(f"\nReturning {len(relevant_docs)} relevant chunks from {len(set(all_page_titles))} source page(s).")
            return final_result.strip()
        except Exception as e:
            import traceback
            print(f"Unexpected error in {self.name}: {traceback.format_exc()}")
            return f"An unexpected error occurred: {str(e)}"
class EnhancedYoutubeScreenshotQA(BaseTool):
name: str = "enhanced_youtube_screenshot_qa"
description: str = (
"Downloads a YouTube video, intelligently extracts screenshots, "
"and answers questions using advanced visual QA with semantic analysis. "
"Use this tool for questions about the VIDEO or IMAGES in the video,"
"Input should be a dict with keys: 'youtube_url', 'question', and optional parameters. "
#"Optional parameters: 'frame_interval_seconds' (default: 10), 'max_frames' (default: 50), "
#"'use_scene_detection' (default: True), 'parallel_processing' (default: True). "
"Example: {'youtube_url': 'https://youtube.com/watch?v=xyz', 'question': 'What animals are visible?'}"
)
# Define Pydantic fields for the attributes we need to set
device: Any = Field(default=None, exclude=True)
processor_vqa: Any = Field(default=None, exclude=True)
model_vqa: Any = Field(default=None, exclude=True)
class Config:
# Allow arbitrary types (needed for torch.device, model objects)
arbitrary_types_allowed = True
# Allow extra fields to be set
extra = "allow"
def __init__(self, **kwargs):
    """Construct the tool, load the BLIP VQA model, then create working directories.

    The model is loaded first, so if loading raises, no directories are created.
    """
    super().__init__(**kwargs)
    self._initialize_model()
    working_dirs = ('/tmp/youtube_qa_cache/', '/tmp/video/', '/tmp/video_frames/')
    for path in working_dirs:
        os.makedirs(path, exist_ok=True)
def _get_config(self, key: str, default_value=None, input_data: Dict[str, Any] = None):
"""Get configuration value with fallback to defaults"""
defaults = {
'frame_interval_seconds': 10,
'max_frames': 50,
'use_scene_detection': True,
'resize_frames': True,
'parallel_processing': True,
'cache_enabled': True,
'quality_threshold': 30.0,
'semantic_similarity_threshold': 0.8
}
if input_data and key in input_data:
return input_data[key]
return defaults.get(key, default_value)
def _initialize_model(self):
"""Initialize BLIP model for VQA with error handling"""
try:
#self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
self.device = torch.device("cpu")
print(f"Using device: {self.device}")
self.processor_vqa = BlipProcessor.from_pretrained("Salesforce/blip-vqa-base")
self.model_vqa = BlipForQuestionAnswering.from_pretrained(
"Salesforce/blip-vqa-base"
).to(self.device)
print("BLIP VQA model loaded successfully")
except Exception as e:
print(f"Error initializing VQA model: {str(e)}")
raise
def _get_video_hash(self, url: str) -> str:
"""Generate hash for video URL for caching"""
return hashlib.md5(url.encode()).hexdigest()
def _get_cache_path(self, video_hash: str, cache_type: str) -> str:
"""Get cache file path"""
cache_dir = '/tmp/youtube_qa_cache/'
return os.path.join(cache_dir, f"{video_hash}_{cache_type}")
def _load_from_cache(self, cache_path: str, cache_enabled: bool = True) -> Optional[Any]:
"""Load data from cache"""
if not cache_enabled or not os.path.exists(cache_path):
return None
try:
with open(cache_path, 'r') as f:
return json.load(f)
except Exception as e:
print(f"Error loading cache: {str(e)}")
return None
def _save_to_cache(self, cache_path: str, data: Any, cache_enabled: bool = True):
"""Save data to cache"""
if not cache_enabled:
return
try:
with open(cache_path, 'w') as f:
json.dump(data, f)
except Exception as e:
print(f"Error saving cache: {str(e)}")
def download_youtube_video(self, url: str, video_hash: str, cache_enabled: bool = True) -> Optional[str]:
"""Enhanced YouTube video download with caching"""
video_dir = '/tmp/video/'
output_filename = f'{video_hash}.mp4'
output_path = os.path.join(video_dir, output_filename)
# Check cache
if cache_enabled and os.path.exists(output_path):
print(f"Using cached video: {output_path}")
return output_path
# Clean directory
video_dir = '/tmp/video/'
self._clean_directory(video_dir)
try:
ydl_opts = {
'format': 'bestvideo[height<=720][ext=mp4]+bestaudio[ext=m4a]/best[height<=720][ext=mp4]/best',
'outtmpl': output_path,
'quiet': True,
'merge_output_format': 'mp4',
'postprocessors': [{
'key': 'FFmpegVideoConvertor',
'preferedformat': 'mp4',
}]
}
with yt_dlp.YoutubeDL(ydl_opts) as ydl:
ydl.download([url])
if os.path.exists(output_path):
print(f"Video downloaded successfully: {output_path}")
return output_path
else:
print("Download completed but file not found")
return None
except Exception as e:
print(f"Error downloading YouTube video: {str(e)}")
return None
def _clean_directory(self, directory: str):
"""Clean directory contents"""
if os.path.exists(directory):
for filename in os.listdir(directory):
file_path = os.path.join(directory, filename)
try:
if os.path.isfile(file_path) or os.path.islink(file_path):
os.unlink(file_path)
elif os.path.isdir(file_path):
shutil.rmtree(file_path)
except Exception as e:
print(f'Failed to delete {file_path}. Reason: {e}')
def _assess_frame_quality(self, frame: np.ndarray) -> float:
"""Assess frame quality using Laplacian variance (blur detection)"""
try:
gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
return cv2.Laplacian(gray, cv2.CV_64F).var()
except Exception:
return 0.0
def _detect_scene_changes(self, video_path: str, threshold: float = 30.0) -> List[int]:
"""Detect scene changes in video"""
scene_frames = []
try:
cap = cv2.VideoCapture(video_path)
if not cap.isOpened():
return []
prev_frame = None
frame_count = 0
while True:
ret, frame = cap.read()
if not ret:
break
if prev_frame is not None:
# Calculate histogram difference
hist1 = cv2.calcHist([prev_frame], [0, 1, 2], None, [8, 8, 8], [0, 256, 0, 256, 0, 256])
hist2 = cv2.calcHist([frame], [0, 1, 2], None, [8, 8, 8], [0, 256, 0, 256, 0, 256])
diff = cv2.compareHist(hist1, hist2, cv2.HISTCMP_CHISQR)
if diff > threshold:
scene_frames.append(frame_count)
prev_frame = frame.copy()
frame_count += 1
cap.release()
return scene_frames
except Exception as e:
print(f"Error in scene detection: {str(e)}")
return []
def smart_extract_frames(self, video_path: str, video_hash: str, input_data: Dict[str, Any] = None) -> List[str]:
"""Intelligently extract frames with quality filtering and scene detection"""
cache_enabled = self._get_config('cache_enabled', True, input_data)
cache_path = self._get_cache_path(video_hash, "frames_info.json")
cached_info = self._load_from_cache(cache_path, cache_enabled)
if cached_info:
# Verify cached frames still exist
existing_frames = [f for f in cached_info['frame_paths'] if os.path.exists(f)]
if len(existing_frames) == len(cached_info['frame_paths']):
print(f"Using {len(existing_frames)} cached frames")
return existing_frames
# Clean frames directory
frames_dir = '/tmp/video_frames/'
self._clean_directory(frames_dir)
try:
cap = cv2.VideoCapture(video_path)
if not cap.isOpened():
print("Error: Could not open video.")
return []
fps = cap.get(cv2.CAP_PROP_FPS)
total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
frame_interval_seconds = self._get_config('frame_interval_seconds', 10, input_data)
frame_interval = max(1, int(fps * frame_interval_seconds))
print(f"Video info: {total_frames} frames, {fps:.2f} fps")
# Get scene change frames if enabled
scene_frames = set()
use_scene_detection = self._get_config('use_scene_detection', True, input_data)
if use_scene_detection:
scene_frames = set(self._detect_scene_changes(video_path))
print(f"Detected {len(scene_frames)} scene changes")
extracted_frames = []
frame_count = 0
saved_count = 0
max_frames = self._get_config('max_frames', 50, input_data)
while True:
ret, frame = cap.read()
if not ret or saved_count >= max_frames:
break
# Check if we should extract this frame
should_extract = (
frame_count % frame_interval == 0 or
frame_count in scene_frames
)
if should_extract:
# Assess frame quality
quality = self._assess_frame_quality(frame)
quality_threshold = self._get_config('quality_threshold', 30.0, input_data)
if quality >= quality_threshold:
# Resize frame if enabled
resize_frames = self._get_config('resize_frames', True, input_data)
if resize_frames:
height, width = frame.shape[:2]
if width > 800:
scale = 800 / width
new_width = 800
new_height = int(height * scale)
frame = cv2.resize(frame, (new_width, new_height))
frame_filename = os.path.join(
frames_dir,
f"frame_{frame_count:06d}_q{quality:.1f}.jpg"
)
if cv2.imwrite(frame_filename, frame):
extracted_frames.append(frame_filename)
saved_count += 1
print(f"Extracted frame {saved_count}/{max_frames} "
f"(quality: {quality:.1f})")
frame_count += 1
cap.release()
# Cache frame information
frame_info = {
'frame_paths': extracted_frames,
'extraction_time': time.time(),
'total_frames_processed': frame_count,
'frames_extracted': len(extracted_frames)
}
self._save_to_cache(cache_path, frame_info, cache_enabled)
print(f"Successfully extracted {len(extracted_frames)} high-quality frames")
return extracted_frames
except Exception as e:
print(f"Exception during frame extraction: {e}")
return []
def _answer_question_on_frame(self, frame_path: str, question: str) -> Tuple[str, float]:
"""Answer question on single frame with confidence scoring"""
try:
image = Image.open(frame_path).convert('RGB')
inputs = self.processor_vqa(image, question, return_tensors="pt").to(self.device)
with torch.no_grad():
outputs = self.model_vqa.generate(**inputs, output_scores=True, return_dict_in_generate=True)
answer = self.processor_vqa.decode(outputs.sequences[0], skip_special_tokens=True)
# Calculate confidence (simplified - you might want to use actual model confidence)
confidence = 1.0 # Placeholder - BLIP doesn't directly provide confidence
return answer, confidence
except Exception as e:
print(f"Error processing frame {frame_path}: {str(e)}")
return "Error processing this frame", 0.0
def _process_frames_parallel(self, frame_files: List[str], question: str, input_data: Dict[str, Any] = None) -> List[Tuple[str, str, float]]:
"""Process frames in parallel"""
results = []
parallel_processing = self._get_config('parallel_processing', True, input_data)
if parallel_processing:
with ThreadPoolExecutor(max_workers=min(4, len(frame_files))) as executor:
future_to_frame = {
executor.submit(self._answer_question_on_frame, frame_path, question): frame_path
for frame_path in frame_files
}
for future in as_completed(future_to_frame):
frame_path = future_to_frame[future]
try:
answer, confidence = future.result()
results.append((frame_path, answer, confidence))
print(f"Processed {os.path.basename(frame_path)}: {answer} (conf: {confidence:.2f})")
except Exception as e:
print(f"Error processing {frame_path}: {str(e)}")
results.append((frame_path, "Error", 0.0))
else:
for frame_path in frame_files:
answer, confidence = self._answer_question_on_frame(frame_path, question)
results.append((frame_path, answer, confidence))
print(f"Processed {os.path.basename(frame_path)}: {answer} (conf: {confidence:.2f})")
return results
def _cluster_similar_answers(self, answers: List[str], input_data: Dict[str, Any] = None) -> Dict[str, List[str]]:
"""Cluster semantically similar answers"""
if len(answers) <= 1:
return {answers[0]: answers} if answers else {}
try:
# First try with standard TF-IDF settings
vectorizer = TfidfVectorizer(
stop_words='english',
lowercase=True,
min_df=1, # Include words that appear in at least 1 document
max_df=1.0 # Include words that appear in up to 100% of documents
)
tfidf_matrix = vectorizer.fit_transform(answers)
# Check if we have any features after TF-IDF
if tfidf_matrix.shape[1] == 0:
raise ValueError("No features after TF-IDF processing")
# Calculate cosine similarity
similarity_matrix = cosine_similarity(tfidf_matrix)
# Cluster similar answers
clusters = defaultdict(list)
used = set()
semantic_similarity_threshold = self._get_config('semantic_similarity_threshold', 0.8, input_data)
for i, answer in enumerate(answers):
if i in used:
continue
cluster_key = answer
clusters[cluster_key].append(answer)
used.add(i)
# Find similar answers
for j in range(i + 1, len(answers)):
if j not in used and similarity_matrix[i][j] >= semantic_similarity_threshold:
clusters[cluster_key].append(answers[j])
used.add(j)
return dict(clusters)
except (ValueError, Exception) as e:
print(f"Error in semantic clustering: {str(e)}")
# Fallback 1: Try without stop words filtering
try:
print("Attempting clustering without stop word filtering...")
vectorizer_no_stop = TfidfVectorizer(
lowercase=True,
min_df=1,
token_pattern=r'\b\w+\b' # Match any word
)
tfidf_matrix = vectorizer_no_stop.fit_transform(answers)
if tfidf_matrix.shape[1] > 0:
similarity_matrix = cosine_similarity(tfidf_matrix)
clusters = defaultdict(list)
used = set()
semantic_similarity_threshold = self._get_config('semantic_similarity_threshold', 0.8, input_data)
for i, answer in enumerate(answers):
if i in used:
continue
cluster_key = answer
clusters[cluster_key].append(answer)
used.add(i)
for j in range(i + 1, len(answers)):
if j not in used and similarity_matrix[i][j] >= semantic_similarity_threshold:
clusters[cluster_key].append(answers[j])
used.add(j)
return dict(clusters)
except Exception as e2:
print(f"Fallback clustering also failed: {str(e2)}")
# Fallback 2: Simple string-based clustering
print("Using simple string-based clustering...")
return self._simple_string_cluster(answers)
def _simple_string_cluster(self, answers: List[str]) -> Dict[str, List[str]]:
"""Simple string-based clustering fallback"""
clusters = defaultdict(list)
# Normalize answers for comparison
normalized_answers = {}
for answer in answers:
normalized = answer.lower().strip()
normalized_answers[answer] = normalized
used = set()
for i, answer in enumerate(answers):
if answer in used:
continue
cluster_key = answer
clusters[cluster_key].append(answer)
used.add(answer)
# Find similar answers using simple string similarity
for j, other_answer in enumerate(answers[i+1:], i+1):
if other_answer in used:
continue
# Check for exact match after normalization
if normalized_answers[answer] == normalized_answers[other_answer]:
clusters[cluster_key].append(other_answer)
used.add(other_answer)
# Alternatively, check if one string contains the other
elif (normalized_answers[answer] in normalized_answers[other_answer] or
normalized_answers[other_answer] in normalized_answers[answer]):
clusters[cluster_key].append(other_answer)
used.add(other_answer)
return dict(clusters)
def _analyze_temporal_patterns(self, results: List[Tuple[str, str, float]]) -> Dict[str, Any]:
"""Analyze temporal patterns in answers"""
try:
# Sort by frame number
def get_frame_number(frame_path):
match = re.search(r'frame_(\d+)', os.path.basename(frame_path))
return int(match.group(1)) if match else 0
sorted_results = sorted(results, key=lambda x: get_frame_number(x[0]))
# Analyze answer changes over time
answers_timeline = [result[1] for result in sorted_results]
changes = []
for i in range(1, len(answers_timeline)):
if answers_timeline[i] != answers_timeline[i-1]:
changes.append({
'frame_index': i,
'from_answer': answers_timeline[i-1],
'to_answer': answers_timeline[i]
})
return {
'total_changes': len(changes),
'change_points': changes,
'stability_ratio': 1 - (len(changes) / max(1, len(answers_timeline) - 1)),
'answers_timeline': answers_timeline
}
except Exception as e:
print(f"Error in temporal analysis: {str(e)}")
return {'error': str(e)}
def analyze_video_question(self, frame_files: List[str], question: str, input_data: Dict[str, Any] = None) -> Dict[str, Any]:
"""Comprehensive video question analysis"""
if not frame_files:
return {
"final_answer": "No frames available for analysis.",
"confidence": 0.0,
"frame_count": 0,
"error": "No valid frames found"
}
# Process all frames
print(f"Processing {len(frame_files)} frames...")
results = self._process_frames_parallel(frame_files, question, input_data)
if not results:
return {
"final_answer": "Could not analyze any frames successfully.",
"confidence": 0.0,
"frame_count": 0,
"error": "Frame processing failed"
}
# Extract answers and confidences
answers = [result[1] for result in results if result[1] != "Error"]
confidences = [result[2] for result in results if result[1] != "Error"]
# Calculate statistical summary on numeric answers
numeric_answers = []
for answer in answers:
try:
# Try to convert answer to float
numeric_value = float(answer)
numeric_answers.append(numeric_value)
except (ValueError, TypeError):
# Skip non-numeric answers
pass
if numeric_answers:
stats = {
"minimum": float(np.min(numeric_answers)),
"maximum": float(np.max(numeric_answers)),
"range": float(np.max(numeric_answers) - np.min(numeric_answers)),
"mean": float(np.mean(numeric_answers)),
"median": float(np.median(numeric_answers)),
"count": len(numeric_answers),
"data_type": "answers"
}
elif confidences:
# Fallback to confidence statistics if no numeric answers
stats = {
"minimum": float(np.min(confidences)),
"maximum": float(np.max(confidences)),
"range": float(np.max(confidences) - np.min(confidences)),
"mean": float(np.mean(confidences)),
"median": float(np.median(confidences)),
"count": len(confidences),
"data_type": "confidences"
}
else:
stats = {
"minimum": 0.0,
"maximum": 0.0,
"range": 0.0,
"mean": 0.0,
"median": 0.0,
"count": 0,
"data_type": "none",
"note": "No numeric results available for statistical summary"
}
if not answers:
return {
"final_answer": "All frame processing failed.",
"confidence": 0.0,
"frame_count": len(frame_files),
"error": "No successful frame analysis"
}
# Cluster similar answers
answer_clusters = self._cluster_similar_answers(answers, input_data)
# Find most common cluster
largest_cluster = max(answer_clusters.items(), key=lambda x: len(x[1]))
most_common_answer = largest_cluster[0]
cluster_size = len(largest_cluster[1])
# Calculate weighted confidence
answer_counts = Counter(answers)
total_answers = len(answers)
frequency_confidence = answer_counts[most_common_answer] / total_answers
avg_confidence = np.mean(confidences) if confidences else 0.0
final_confidence = (frequency_confidence * 0.7) + (avg_confidence * 0.3)
# Temporal analysis
temporal_analysis = self._analyze_temporal_patterns(results)
return {
"final_answer": most_common_answer,
"confidence": final_confidence,
"frame_count": len(frame_files),
"successful_analyses": len(answers),
"answer_distribution": dict(answer_counts),
"semantic_clusters": {k: len(v) for k, v in answer_clusters.items()},
"temporal_analysis": temporal_analysis,
"average_model_confidence": avg_confidence,
"frequency_confidence": frequency_confidence,
"statistical_summary": stats
}
#def _run(self, query: Dict[str, Any]) -> str:
def _run(self, youtube_url, question, **kwargs) -> str:
"""Enhanced main execution method"""
#ipdb.set_trace()
#input_data = query
#youtube_url = input_data.get("youtube_url")
#question = input_data.get("question")
input_data = {
'youtube_url': youtube_url,
'question': question
}
if not youtube_url or not question:
return "Error: Input must include 'youtube_url' and 'question'."
try:
# Generate video hash for caching
video_hash = self._get_video_hash(youtube_url)
# Step 1: Download video
print(f"Downloading YouTube video from {youtube_url}...")
cache_enabled = self._get_config('cache_enabled', True, input_data)
video_path = self.download_youtube_video(youtube_url, video_hash, cache_enabled)
if not video_path or not os.path.exists(video_path):
return "Error: Failed to download the YouTube video."
# Step 2: Smart frame extraction
print(f"Extracting frames with smart selection...")
frame_files = self.smart_extract_frames(video_path, video_hash, input_data)
if not frame_files:
return "Error: Failed to extract frames from the video."
# Step 3: Comprehensive analysis
print(f"Analyzing {len(frame_files)} frames for question: '{question}'")
analysis_result = self.analyze_video_question(frame_files, question, input_data)
if analysis_result.get("error"):
return f"Error: {analysis_result['error']}"
# Format comprehensive result - Fixed the reference to stats
result = f"""
📊 **ANALYSIS SUMMARY**:
• Confidence Score: {analysis_result['confidence']:.2%}
• Frames Analyzed: {analysis_result['successful_analyses']}/{analysis_result['frame_count']}
• Answer Consistency: {analysis_result['temporal_analysis'].get('stability_ratio', 0):.2%}
📈 **ANSWER DISTRIBUTION**:
{chr(10).join([f"• {answer}: {count} frames" for answer, count in analysis_result['answer_distribution'].items()])}
🔍 **SEMANTIC CLUSTERS**:
{chr(10).join([f"• '{cluster}': {count} similar answers" for cluster, count in analysis_result['semantic_clusters'].items()])}
⏱️ **TEMPORAL ANALYSIS**:
• Answer Changes: {analysis_result['temporal_analysis'].get('total_changes', 0)}
• Stability: {analysis_result['temporal_analysis'].get('stability_ratio', 0):.2%}
📊 **STATISTICAL SUMMARY**:
• Minimum: {analysis_result['statistical_summary']['minimum']:.2f}
• Maximum: {analysis_result['statistical_summary']['maximum']:.2f}
• Mean: {analysis_result['statistical_summary']['mean']:.2f}
• Median: {analysis_result['statistical_summary']['median']:.2f}
• Range: {analysis_result['statistical_summary']['range']:.2f}
🎯 **CONFIDENCE BREAKDOWN**:
• Frequency-based: {analysis_result['frequency_confidence']:.2%}
• Model-based: {analysis_result['average_model_confidence']:.2%}
• Combined: {analysis_result['confidence']:.2%}
""".strip()
return result
except Exception as e:
return f"Error during video analysis: {str(e)}"
# Initialize the enhanced tool
def create_enhanced_youtube_qa_tool(**kwargs):
    """Build and return a new ``EnhancedYoutubeScreenshotQA`` tool.

    Every keyword argument is forwarded unchanged to the tool's constructor.
    """
    tool = EnhancedYoutubeScreenshotQA(**kwargs)
    return tool
# Example of creating the tool instance:
# wikipedia_tool_faiss = WikipediaSearchToolWithFAISS()
# To use this new tool in your agent, you would replace the old
# `wikipedia_tool` instance with `wikipedia_tool_faiss` in your `tools` list.
# For example:
# tools = [wikipedia_tool_faiss, search_tool]
# Create tool instances
#wikipedia_tool = WikipediaSearchTool()
# --- Define Call LLM function ---
# 3. Improved LLM call with memory management
class YouTubeTranscriptExtractor(BaseTool):
name: str = "youtube_transcript_extractor"
description: str = (
"Downloads a YouTube video and extracts the complete audio transcript using speech recognition with speaker identification. "
"Use this tool when you need the AUDIO or DIALOGUE or sound from a YouTube video with speaker tags,"
"Input should be a dict with keys: 'youtube_url' and optional parameters. "
"Optional parameters: 'language' (default: 'en-US'), 'chunk_length_ms' (default: 30000), "
"'silence_thresh' (default: -40), 'use_enhanced_model' (default: True), 'audio_quality' (default: 'best'), "
"'enable_speaker_id' (default: True), 'max_speakers' (default: 5), 'speaker_min_duration' (default: 2.0). "
"Example: {'youtube_url': 'https://youtube.com/watch?v=xyz', 'language': 'en-US', 'enable_speaker_id': True}"
)
# Define Pydantic fields for the attributes we need to set
recognizer: Any = Field(default=None, exclude=True)
class Config:
# Allow arbitrary types
arbitrary_types_allowed = True
# Allow extra fields to be set
extra = "allow"
def __init__(self, **kwargs):
    """Initialise the tool: create the speech recognizer and the working directories.

    Directories: '/tmp/youtube_transcript_cache/' (JSON caches), '/tmp/audio/'
    (downloaded audio) and '/tmp/audio_chunks/' (split chunks).
    """
    super().__init__(**kwargs)
    self.recognizer = sr.Recognizer()
    working_dirs = (
        '/tmp/youtube_transcript_cache/',
        '/tmp/audio/',
        '/tmp/audio_chunks/',
    )
    for working_dir in working_dirs:
        os.makedirs(working_dir, exist_ok=True)
def _get_config(self, key: str, default_value=None, input_data: Dict[str, Any] = None):
"""Get configuration value with fallback to defaults"""
defaults = {
'language': 'en-US',
'chunk_length_ms': 30000, # 30 seconds
'silence_thresh': -40, # dB
'use_enhanced_model': True,
'audio_quality': 'best',
'cache_enabled': True,
'parallel_processing': True,
'overlap_ms': 1000, # 1 second overlap between chunks
'min_silence_len': 500, # minimum silence length to split on
'energy_threshold': 4000, # recognizer energy threshold
'pause_threshold': 0.8, # recognizer pause threshold
'enable_speaker_id': True, # enable speaker identification
'max_speakers': 5, # maximum number of speakers to identify
'speaker_min_duration': 2.0, # minimum duration (seconds) for speaker segment
'speaker_confidence_threshold': 0.6, # confidence threshold for speaker assignment
'voice_activity_threshold': 0.01 # threshold for voice activity detection
}
if input_data and key in input_data:
return input_data[key]
return defaults.get(key, default_value)
def _get_video_hash(self, url: str) -> str:
"""Generate hash for video URL for caching"""
return hashlib.md5(url.encode()).hexdigest()
def _get_cache_path(self, video_hash: str, cache_type: str) -> str:
"""Get cache file path"""
cache_dir = '/tmp/youtube_transcript_cache/'
return os.path.join(cache_dir, f"{video_hash}_{cache_type}")
def _load_from_cache(self, cache_path: str, cache_enabled: bool = True) -> Optional[Any]:
"""Load data from cache"""
if not cache_enabled or not os.path.exists(cache_path):
return None
try:
with open(cache_path, 'r', encoding='utf-8') as f:
return json.load(f)
except Exception as e:
print(f"Error loading cache: {str(e)}")
return None
def _save_to_cache(self, cache_path: str, data: Any, cache_enabled: bool = True):
"""Save data to cache"""
if not cache_enabled:
return
try:
with open(cache_path, 'w', encoding='utf-8') as f:
json.dump(data, f, ensure_ascii=False, indent=2)
except Exception as e:
print(f"Error saving cache: {str(e)}")
def download_youtube_audio(self, url: str, video_hash: str, input_data: Dict[str, Any] = None) -> Optional[str]:
    """Download the audio of ``url`` and return the path of '/tmp/audio/<video_hash>.wav'.

    With caching enabled, an existing WAV at that path is reused. Otherwise
    '/tmp/audio/' is emptied, so only the most recent video's audio stays
    cached. The yt-dlp download is converted to WAV with pydub.

    Returns None on any failure. In particular, a file that cannot be
    converted is no longer renamed to '.wav'. That renamed file used to be
    served from the cache on every later call and then failed in
    ``AudioSegment.from_wav``.
    """
    audio_dir = '/tmp/audio/'
    audio_quality = self._get_config('audio_quality', 'best', input_data)
    output_path = os.path.join(audio_dir, f'{video_hash}.wav')

    cache_enabled = self._get_config('cache_enabled', True, input_data)
    if cache_enabled and os.path.exists(output_path):
        print(f"Using cached audio: {output_path}")
        return output_path

    self._clean_directory(audio_dir)

    try:
        temp_template = os.path.join(audio_dir, f'{video_hash}_temp.%(ext)s')
        ydl_opts = {
            'format': 'bestaudio/best' if audio_quality == 'best' else 'worstaudio/worst',
            'outtmpl': temp_template,
            'quiet': True,
            # NOTE(review): these look like CLI-style names. The WAV conversion
            # below is done explicitly either way.
            'extractaudio': True,
            'audioformat': 'wav',
            'audioquality': '192K' if audio_quality == 'best' else '64K',
        }
        with yt_dlp.YoutubeDL(ydl_opts) as ydl:
            ydl.download([url])

        # Ignore yt-dlp partial/bookkeeping files that can sit next to the real download.
        temp_files = sorted(
            path for path in glob.glob(os.path.join(audio_dir, f'{video_hash}_temp.*'))
            if not path.endswith(('.part', '.ytdl'))
        )
        if not temp_files:
            print("No temporary audio file found")
            return None
        temp_file = temp_files[0]

        if temp_file.endswith('.wav'):
            os.replace(temp_file, output_path)
        else:
            try:
                # export() returns the open output file; close it right away.
                AudioSegment.from_file(temp_file).export(output_path, format="wav").close()
            except Exception as e:
                print(f"Error converting audio: {str(e)}")
                # Do not leave a half-written or non-WAV file where the cache would find it.
                if os.path.exists(output_path):
                    os.remove(output_path)
                return None
            finally:
                if os.path.exists(temp_file):
                    os.remove(temp_file)

        if os.path.exists(output_path):
            print(f"Audio extracted successfully: {output_path}")
            return output_path
        print("Audio extraction completed but file not found")
        return None
    except Exception as e:
        print(f"Error downloading YouTube audio: {str(e)}")
        return None
def _clean_directory(self, directory: str):
"""Clean directory contents"""
if os.path.exists(directory):
for filename in os.listdir(directory):
file_path = os.path.join(directory, filename)
try:
if os.path.isfile(file_path) or os.path.islink(file_path):
os.unlink(file_path)
elif os.path.isdir(file_path):
shutil.rmtree(file_path)
except Exception as e:
print(f'Failed to delete {file_path}. Reason: {e}')
def _extract_voice_features(self, audio_path: str) -> Optional[np.ndarray]:
    """Return a 16-element voice-feature vector for ``audio_path``, or None on failure.

    Layout: the 13 time-averaged MFCCs, followed by the time-averaged
    spectral centroid, spectral rolloff and zero-crossing rate.

    Previously the last three features were averaged with ``np.mean`` and no
    axis, which yields 0-d scalars. ``np.concatenate`` rejects those, so this
    method always returned None and speaker identification never had any
    features.
    """
    try:
        # ``sample_rate`` instead of ``sr`` so the speech_recognition alias is not shadowed.
        samples, sample_rate = librosa.load(audio_path, sr=None)

        mfccs = librosa.feature.mfcc(y=samples, sr=sample_rate, n_mfcc=13)
        spectral_centroids = librosa.feature.spectral_centroid(y=samples, sr=sample_rate)
        spectral_rolloff = librosa.feature.spectral_rolloff(y=samples, sr=sample_rate)
        zero_crossing_rate = librosa.feature.zero_crossing_rate(samples)

        # Each feature matrix is (n_features, n_frames); averaging over axis=1
        # keeps a 1-D vector that np.concatenate accepts.
        return np.concatenate([
            np.mean(mfccs, axis=1),
            np.mean(spectral_centroids, axis=1),
            np.mean(spectral_rolloff, axis=1),
            np.mean(zero_crossing_rate, axis=1),
        ])
    except Exception as e:
        print(f"Error extracting voice features from {audio_path}: {str(e)}")
        return None
def _detect_voice_activity(self, audio_path: str, input_data: Dict[str, Any] = None) -> List[Tuple[float, float]]:
    """Return ``(start_s, end_s)`` spans whose short-time energy is above a relative threshold.

    Energy is the sum of squared samples over 25 ms windows with a 10 ms hop.
    A window counts as voice when its energy exceeds ``voice_activity_threshold``
    times the loudest window's energy. Completely silent audio yields [].

    If the analysis fails, including audio shorter than one window, the whole
    file is returned as a single span.
    """
    try:
        # ``sample_rate`` instead of ``sr`` so the speech_recognition alias is not shadowed.
        samples, sample_rate = librosa.load(audio_path, sr=None)
        frame_length = int(0.025 * sample_rate)  # 25 ms windows
        hop_length = int(0.010 * sample_rate)    # 10 ms hop
        if hop_length <= 0 or len(samples) <= frame_length:
            raise ValueError("audio too short for voice activity detection")

        # Same window starts as the former per-window loop; the energies come from
        # a prefix sum of squares in one vectorised pass.
        starts = np.arange(0, len(samples) - frame_length, hop_length)
        squared_prefix = np.concatenate(([0.0], np.cumsum(np.square(samples, dtype=np.float64))))
        energy = squared_prefix[starts + frame_length] - squared_prefix[starts]

        threshold = self._get_config('voice_activity_threshold', 0.01, input_data)
        voice_frames = energy > (np.max(energy) * threshold)

        voice_segments = []
        in_voice = False
        start_time = 0
        for i, is_voice in enumerate(voice_frames):
            time_sec = i * hop_length / sample_rate
            if is_voice and not in_voice:
                start_time = time_sec
                in_voice = True
            elif not is_voice and in_voice:
                voice_segments.append((start_time, time_sec))
                in_voice = False
        # Close a segment that runs to the end of the audio.
        if in_voice:
            voice_segments.append((start_time, len(samples) / sample_rate))
        return voice_segments
    except Exception as e:
        print(f"Error in voice activity detection: {str(e)}")
        try:
            duration = librosa.get_duration(path=audio_path)
        except TypeError:
            # librosa < 0.10 names this keyword ``filename``.
            duration = librosa.get_duration(filename=audio_path)
        return [(0, duration)]
def _split_audio_intelligent(self, audio_path: str, input_data: Dict[str, Any] = None) -> List[Dict[str, Any]]:
    """Split a WAV file into chunk files in '/tmp/audio_chunks/' and describe them.

    The audio is split on silence first. If that yields nothing, or any chunk
    is longer than twice ``chunk_length_ms``, it falls back to fixed windows
    of ``chunk_length_ms`` that advance by ``chunk_length_ms - overlap_ms``.
    Chunks of 1 s or less are dropped.

    Each returned dict has ``filename``, ``index``, ``start_time``,
    ``duration`` and ``end_time`` (seconds). Start offsets are recorded when
    a chunk is created. Previously they were re-derived by summing earlier
    chunk lengths, which drifted by ``overlap_ms`` per overlapping window and
    cost O(n²).

    On error, the original file is returned as a single chunk.
    """
    chunks_dir = '/tmp/audio_chunks/'
    self._clean_directory(chunks_dir)
    audio = None
    try:
        audio = AudioSegment.from_wav(audio_path)

        chunk_length_ms = self._get_config('chunk_length_ms', 30000, input_data)
        silence_thresh = self._get_config('silence_thresh', -40, input_data)
        min_silence_len = self._get_config('min_silence_len', 500, input_data)
        overlap_ms = self._get_config('overlap_ms', 1000, input_data)

        # (start_ms, segment) pairs.
        timed_chunks = []
        silence_chunks = split_on_silence(
            audio,
            min_silence_len=min_silence_len,
            silence_thresh=silence_thresh,
            keep_silence=True
        )
        if silence_chunks and all(len(chunk) <= chunk_length_ms * 2 for chunk in silence_chunks):
            # As before, consecutive silence-split chunks are treated as contiguous
            # (keep_silence=True retains the silence), so a running total gives each start.
            # NOTE(review): confirm against pydub's handling of shared silence.
            offset_ms = 0
            for chunk in silence_chunks:
                timed_chunks.append((offset_ms, chunk))
                offset_ms += len(chunk)
        else:
            print("Using time-based splitting...")
            step_ms = chunk_length_ms - overlap_ms
            if step_ms <= 0:
                raise ValueError("chunk_length_ms must be greater than overlap_ms")
            for start_ms in range(0, len(audio), step_ms):
                chunk = audio[start_ms:start_ms + chunk_length_ms]
                if len(chunk) > 1000:  # Only keep chunks longer than 1 second
                    timed_chunks.append((start_ms, chunk))

        chunk_data = []
        for i, (start_ms, chunk) in enumerate(timed_chunks):
            if len(chunk) < 1000:  # Skip very short chunks
                continue
            chunk_filename = os.path.join(chunks_dir, f"chunk_{i:04d}.wav")
            # export() returns the open output file; close it right away.
            chunk.export(chunk_filename, format="wav").close()
            start_time = start_ms / 1000.0
            duration = len(chunk) / 1000.0
            chunk_data.append({
                'filename': chunk_filename,
                'index': i,
                'start_time': start_time,
                'duration': duration,
                'end_time': start_time + duration
            })

        print(f"Split audio into {len(chunk_data)} chunks")
        return chunk_data
    except Exception as e:
        print(f"Error splitting audio: {str(e)}")
        # Fallback: the whole file as one chunk. Reuse the decoded audio if loading succeeded.
        if audio is None:
            audio = AudioSegment.from_wav(audio_path)
        total_seconds = len(audio) / 1000.0
        return [{
            'filename': audio_path,
            'index': 0,
            'start_time': 0,
            'duration': total_seconds,
            'end_time': total_seconds
        }]
def _transcribe_audio_chunk(self, chunk_info: Dict[str, Any], input_data: Dict[str, Any] = None) -> Dict[str, Any]:
    """Transcribe one chunk with Google Speech Recognition.

    Two attempts are made: first with the configured ``language``
    (confidence 1.0, method 'google'), then with the recognizer's default
    language (confidence 0.8, method 'google_auto').

    Every result carries the chunk's name, timing and index. Successful
    results include ``voice_features`` when speaker ID is enabled. Failures
    return text '[INAUDIBLE]' (nothing recognised), '[RECOGNITION_ERROR]'
    (service/request error in either attempt) or '[ERROR]' (anything else).
    """
    chunk_path = chunk_info['filename']
    enable_speaker_id = self._get_config('enable_speaker_id', True, input_data)

    def build_result(text: str, confidence: float, method: str, **extra: Any) -> Dict[str, Any]:
        result = {
            'text': text,
            'confidence': confidence,
            'method': method,
            'chunk': os.path.basename(chunk_path),
            'start_time': chunk_info.get('start_time', 0),
            'end_time': chunk_info.get('end_time', 0),
            'duration': chunk_info.get('duration', 0),
            'index': chunk_info.get('index', 0),
        }
        result.update(extra)
        return result

    def with_voice_features(result: Dict[str, Any]) -> Dict[str, Any]:
        if enable_speaker_id:
            features = self._extract_voice_features(chunk_path)
            result['voice_features'] = features.tolist() if features is not None else None
        return result

    try:
        language = self._get_config('language', 'en-US', input_data)
        self.recognizer.energy_threshold = self._get_config('energy_threshold', 4000, input_data)
        self.recognizer.pause_threshold = self._get_config('pause_threshold', 0.8, input_data)

        with sr.AudioFile(chunk_path) as source:
            # No adjust_for_ambient_noise(). On a file source it reads, and so discards,
            # the first 0.5 s before record(). The energy threshold it tunes is only
            # used by listen(), not by record().
            audio_data = self.recognizer.record(source)

        try:
            text = self.recognizer.recognize_google(audio_data, language=language)
            return with_voice_features(build_result(text, 1.0, 'google'))
        except sr.UnknownValueError:
            pass  # fall through to the default-language attempt

        try:
            # Lower confidence: the speech may not be in the configured language.
            text = self.recognizer.recognize_google(audio_data)
            return with_voice_features(build_result(text, 0.8, 'google_auto'))
        except sr.UnknownValueError:
            return build_result('[INAUDIBLE]', 0.0, 'failed', voice_features=None)
    except sr.RequestError as e:
        print(f"Google Speech Recognition error: {e}")
        return build_result('[RECOGNITION_ERROR]', 0.0, 'error', error=str(e), voice_features=None)
    except Exception as e:
        print(f"Error transcribing chunk {chunk_path}: {str(e)}")
        return build_result('[ERROR]', 0.0, 'error', error=str(e), voice_features=None)
def _identify_speakers(self, transcript_results: List[Dict[str, Any]], input_data: Dict[str, Any] = None) -> List[Dict[str, Any]]:
    """Label transcript segments with speaker ids by clustering voice features.

    A segment is "valid" when it has ``voice_features`` and its text is not a
    failure marker. If speaker identification is disabled, fewer than two
    valid segments exist, or anything raises, every segment is labelled
    ``SPEAKER_1`` with confidence 1.0.

    Otherwise the valid feature vectors are standardized and clustered with
    KMeans. The cluster count is the k in ``2..min(max_speakers, n - 1)`` with
    the best silhouette score (2 when no k can be scored), never more than
    ``max_speakers`` and at most 2 when there are fewer than 10 segments.

    Per-segment confidence is the margin ``1 - d_own / d_nearest_other``
    (distance to the assigned centroid vs. the nearest other centroid),
    clipped to [0.1, 1.0]; it is 1.0 when there is a single cluster.
    Segments below ``speaker_confidence_threshold`` (default 0.6), invalid
    segments, and speakers whose summed ``duration`` is below
    ``speaker_min_duration`` (default 2.0) are labelled ``SPEAKER_UNKNOWN``.

    Args:
        transcript_results: Per-chunk result dicts (keys read: ``text``,
            ``voice_features``, ``duration``).
        input_data: Per-call overrides looked up through ``self._get_config``.

    Returns:
        The same list, with ``speaker_id`` and ``speaker_confidence`` set on
        every dict in place.
    """
    failure_markers = ('[INAUDIBLE]', '[RECOGNITION_ERROR]', '[ERROR]', '[PROCESSING_ERROR]')

    def _single_speaker() -> List[Dict[str, Any]]:
        # Default labelling used whenever clustering is off, impossible or failed.
        for result in transcript_results:
            result['speaker_id'] = 'SPEAKER_1'
            result['speaker_confidence'] = 1.0
        return transcript_results

    def _is_valid(result: Dict[str, Any]) -> bool:
        return result.get('voice_features') is not None and result['text'] not in failure_markers

    if not self._get_config('enable_speaker_id', True, input_data):
        return _single_speaker()

    try:
        valid_results = [result for result in transcript_results if _is_valid(result)]
        n_samples = len(valid_results)
        if n_samples < 2:
            # Not enough data for clustering
            return _single_speaker()

        features = StandardScaler().fit_transform(
            np.array([result['voice_features'] for result in valid_results])
        )
        max_speakers = min(self._get_config('max_speakers', 5, input_data), n_samples)

        from sklearn.metrics import silhouette_score  # only needed on this path

        # Pick k by silhouette score, not inertia: inertia never increases with
        # k, so "lowest inertia" always selects the largest k tried.
        best_k = max(1, min(2, max_speakers))
        best_score = float('-inf')
        for k in range(2, min(max_speakers, n_samples - 1) + 1):
            try:
                labels = KMeans(n_clusters=k, random_state=42, n_init=10).fit_predict(features)
                score = silhouette_score(features, labels)
            except ValueError:
                # silhouette_score rejects labelings with a single cluster
                # (e.g. identical feature vectors); skip this k.
                continue
            if score > best_score:
                best_score, best_k = score, k

        if n_samples < 10:
            # Don't use too many clusters for short audio
            best_k = min(best_k, 2)

        kmeans = KMeans(n_clusters=best_k, random_state=42, n_init=10)
        speaker_labels = kmeans.fit_predict(features)
        distances = kmeans.transform(features)  # distance of each point to every centroid

        confidences = []
        for row, label in zip(distances, speaker_labels):
            others = np.delete(row, label)
            if others.size == 0:
                confidences.append(1.0)
            else:
                # 0 when equidistant to two centroids, approaching 1 when the
                # point sits on its own centroid.
                margin = 1.0 - row[label] / (np.min(others) + 1e-6)
                confidences.append(float(max(0.1, min(1.0, margin))))

        conf_threshold = self._get_config('speaker_confidence_threshold', 0.6, input_data)
        speaker_duration: Dict[str, float] = {}
        assignments = iter(zip(speaker_labels, confidences))
        for result in transcript_results:
            if not _is_valid(result):
                result['speaker_id'] = 'SPEAKER_UNKNOWN'
                result['speaker_confidence'] = 0.0
                continue
            label, confidence = next(assignments)
            speaker_id = f'SPEAKER_{label + 1}' if confidence >= conf_threshold else 'SPEAKER_UNKNOWN'
            result['speaker_id'] = speaker_id
            result['speaker_confidence'] = confidence
            speaker_duration[speaker_id] = speaker_duration.get(speaker_id, 0) + result['duration']

        # Merge speakers with too little total duration into SPEAKER_UNKNOWN
        min_duration = self._get_config('speaker_min_duration', 2.0, input_data)
        speakers_to_merge = {
            speaker for speaker, duration in speaker_duration.items()
            if duration < min_duration and speaker != 'SPEAKER_UNKNOWN'
        }
        for result in transcript_results:
            if result['speaker_id'] in speakers_to_merge:
                result['speaker_id'] = 'SPEAKER_UNKNOWN'
                result['speaker_confidence'] = 0.3

        identified = {result['speaker_id'] for result in transcript_results} - {'SPEAKER_UNKNOWN'}
        print(f"Identified {len(identified)} speakers based on voice characteristics")
        return transcript_results
    except Exception as e:
        # Best-effort feature: fall back to a single speaker rather than failing.
        print(f"Error in speaker identification: {str(e)}")
        return _single_speaker()
def _transcribe_chunks_parallel(self, chunk_data: List[Dict[str, Any]], input_data: Dict[str, Any] = None) -> List[Dict[str, Any]]:
    """Transcribe every audio chunk, optionally on a small thread pool.

    Args:
        chunk_data: Chunk descriptors; each is passed unchanged to
            ``self._transcribe_audio_chunk``. For error reporting the keys
            ``filename``, ``start_time``, ``end_time``, ``duration`` and
            ``index`` are read when present.
        input_data: Per-call overrides looked up through ``self._get_config``;
            ``parallel_processing`` (default True) selects the threaded path.

    Returns:
        One result dict per chunk, sorted by ``index``. A chunk whose
        transcription raises is represented by a ``[PROCESSING_ERROR]``
        placeholder instead of aborting the batch, in both the parallel and the
        sequential path. An empty ``chunk_data`` yields ``[]``.
    """
    if not chunk_data:
        # ThreadPoolExecutor(max_workers=0) raises ValueError; nothing to do anyway.
        return []

    def _error_result(chunk_info: Dict[str, Any], exc: Exception) -> Dict[str, Any]:
        # Placeholder with the same keys as a normal transcription result.
        print(f"Error processing {chunk_info.get('filename', '')}: {str(exc)}")
        return {
            'text': '[PROCESSING_ERROR]',
            'confidence': 0.0,
            'method': 'error',
            'chunk': os.path.basename(chunk_info.get('filename', '')),
            'start_time': chunk_info.get('start_time', 0),
            'end_time': chunk_info.get('end_time', 0),
            'duration': chunk_info.get('duration', 0),
            'index': chunk_info.get('index', 0),
            'error': str(exc),
            'voice_features': None
        }

    def _log_result(result: Dict[str, Any]) -> None:
        text = str(result.get('text', ''))
        preview = f"{text[:50]}..." if len(text) > 50 else text
        print(f"Transcribed {result.get('chunk', '')}: {preview}")

    results = []
    if self._get_config('parallel_processing', True, input_data):
        # Use fewer workers for speech recognition to avoid API limits
        max_workers = min(3, len(chunk_data))
        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            future_to_chunk = {
                executor.submit(self._transcribe_audio_chunk, chunk_info, input_data): chunk_info
                for chunk_info in chunk_data
            }
            for future in as_completed(future_to_chunk):
                chunk_info = future_to_chunk[future]
                # Only future.result() is guarded, so one chunk yields exactly one entry.
                try:
                    result = future.result()
                except Exception as exc:
                    result = _error_result(chunk_info, exc)
                else:
                    _log_result(result)
                results.append(result)
    else:
        for chunk_info in chunk_data:
            try:
                result = self._transcribe_audio_chunk(chunk_info, input_data)
            except Exception as exc:
                result = _error_result(chunk_info, exc)
            else:
                _log_result(result)
            results.append(result)

    # as_completed yields in completion order; restore chunk order.
    results.sort(key=lambda result: result.get('index', 0))
    return results
def _post_process_transcript(self, transcript_results: List[Dict[str, Any]], input_data: Dict[str, Any] = None) -> Dict[str, Any]:
    """Combine per-chunk results into transcripts plus summary statistics.

    When speaker identification is enabled (``enable_speaker_id``, default
    True), ``self._identify_speakers`` tags every result first, and runs of
    consecutive chunks from the same speaker become one ``[SPEAKER_X]: text``
    line of ``speaker_tagged_transcript`` (lines joined with newlines).
    Otherwise ``speaker_tagged_transcript`` equals ``full_transcript``.

    Chunks whose text is a failure marker (``[INAUDIBLE]``,
    ``[RECOGNITION_ERROR]``, ``[ERROR]``, ``[PROCESSING_ERROR]``) are excluded
    from the text, the word/confidence figures and the speaker statistics.

    Returns:
        Dict with both transcripts, word/character counts, chunk and success
        counts, mean recognition confidence over successful chunks, per-method
        counts, a duration estimate at 150 words per minute, per-speaker
        statistics and the (possibly speaker-tagged) per-chunk results.
    """
    failure_markers = ('[INAUDIBLE]', '[RECOGNITION_ERROR]', '[ERROR]', '[PROCESSING_ERROR]')
    enable_speaker_id = self._get_config('enable_speaker_id', True, input_data)
    if enable_speaker_id:
        transcript_results = self._identify_speakers(transcript_results, input_data)

    full_text_parts = []
    speaker_turns = []  # (speaker_id, [texts]) for each run of consecutive chunks
    successful_chunks = 0
    total_confidence = 0.0
    method_counts = {}
    speaker_stats = {}

    for result in transcript_results:
        text = result['text']
        speaker = result.get('speaker_id', 'SPEAKER_1')
        if text not in failure_markers:
            full_text_parts.append(text)
            successful_chunks += 1
            total_confidence += result['confidence']

            if speaker_turns and speaker_turns[-1][0] == speaker:
                speaker_turns[-1][1].append(text)
            else:
                speaker_turns.append((speaker, [text]))

            stats = speaker_stats.setdefault(speaker, {
                'duration': 0,
                'word_count': 0,
                'segments': 0,
                # Speaker confidence of the speaker's first successful segment.
                'confidence': result.get('speaker_confidence', 1.0)
            })
            stats['duration'] += result.get('duration', 0)
            stats['word_count'] += len(text.split())
            stats['segments'] += 1

        method = result['method']
        method_counts[method] = method_counts.get(method, 0) + 1

    combined_text = ' '.join(full_text_parts)
    if enable_speaker_id:
        speaker_formatted_text = '\n'.join(
            f"[{speaker}]: {' '.join(texts)}" for speaker, texts in speaker_turns
        )
    else:
        speaker_formatted_text = combined_text

    word_count = len(combined_text.split())
    avg_confidence = total_confidence / max(1, successful_chunks)
    success_rate = successful_chunks / len(transcript_results) if transcript_results else 0
    # Estimate speaking duration (rough approximation: 150 words per minute)
    estimated_duration_minutes = word_count / 150 if word_count > 0 else 0

    return {
        'full_transcript': combined_text,
        'speaker_tagged_transcript': speaker_formatted_text,
        'word_count': word_count,
        'character_count': len(combined_text),
        'chunk_count': len(transcript_results),
        'successful_chunks': successful_chunks,
        'success_rate': success_rate,
        'average_confidence': avg_confidence,
        'method_distribution': method_counts,
        'estimated_duration_minutes': estimated_duration_minutes,
        'speaker_identification_enabled': enable_speaker_id,
        'speaker_statistics': speaker_stats,
        'total_speakers': len([s for s in speaker_stats if s != 'SPEAKER_UNKNOWN']),
        'detailed_results': transcript_results
    }
def extract_transcript(self, audio_path: str, video_hash: str, input_data: Dict[str, Any] = None) -> Dict[str, Any]:
    """Produce (or load from cache) the complete transcript of an audio file.

    Pipeline: split ``audio_path`` into chunks, transcribe them, then
    post-process (speaker tagging and statistics). The cache file name depends
    on whether speaker identification is enabled.

    Args:
        audio_path: Path of the audio file to transcribe.
        video_hash: Key used to derive the cache path.
        input_data: Per-call overrides looked up through ``self._get_config``.

    Returns:
        The post-processed transcript dict with ``extraction_timestamp`` and
        ``extraction_date`` added, or a dict containing ``error`` (with empty
        transcripts and ``success_rate`` 0.0) if splitting yields nothing or a
        step raises. A result in which no chunk was transcribed successfully is
        returned but not written to the cache, so a transient recognition
        failure is retried next time instead of being served from cache.
    """
    def _failure(message: str) -> Dict[str, Any]:
        return {
            'error': message,
            'full_transcript': '',
            'speaker_tagged_transcript': '',
            'success_rate': 0.0
        }

    cache_enabled = self._get_config('cache_enabled', True, input_data)
    enable_speaker_id = self._get_config('enable_speaker_id', True, input_data)
    cache_suffix = "transcript_with_speakers.json" if enable_speaker_id else "transcript.json"
    cache_path = self._get_cache_path(video_hash, cache_suffix)

    cached_transcript = self._load_from_cache(cache_path, cache_enabled)
    if cached_transcript:
        print("Using cached transcript")
        return cached_transcript

    try:
        print("Splitting audio into chunks...")
        chunk_data = self._split_audio_intelligent(audio_path, input_data)
        if not chunk_data:
            return _failure('Failed to split audio into chunks')

        print(f"Transcribing {len(chunk_data)} audio chunks...")
        transcript_results = self._transcribe_chunks_parallel(chunk_data, input_data)

        print("Post-processing transcript and identifying speakers...")
        final_result = self._post_process_transcript(transcript_results, input_data)
        final_result['extraction_timestamp'] = time.time()
        final_result['extraction_date'] = time.strftime('%Y-%m-%d %H:%M:%S')

        if final_result.get('successful_chunks', 0) > 0:
            self._save_to_cache(cache_path, final_result, cache_enabled)
        else:
            print("No audio chunk was transcribed successfully; not caching this result")
        return final_result
    except Exception as e:
        print(f"Error during transcript extraction: {str(e)}")
        return _failure(str(e))
def _run(self, youtube_url: str, **kwargs) -> str:
    """Download a YouTube video's audio and return its transcript as text.

    Args:
        youtube_url: Video URL; an empty value is rejected.
        **kwargs: Configuration overrides, forwarded (together with
            ``youtube_url``) as ``input_data`` to the download and
            transcription steps.

    Returns:
        ``"TRANSCRIPT: <text>"`` on success, otherwise a string starting with
        ``"Error"``. Exceptions are converted to error strings, not raised.
    """
    if not youtube_url:
        return "Error: youtube_url is required."
    input_data = {
        'youtube_url': youtube_url,
        **kwargs
    }
    try:
        # Generate video hash for caching
        video_hash = self._get_video_hash(youtube_url)

        print(f"Downloading YouTube audio from {youtube_url}...")
        audio_path = self.download_youtube_audio(youtube_url, video_hash, input_data)
        if not audio_path or not os.path.exists(audio_path):
            return "Error: Failed to download the YouTube audio."

        print("Extracting audio transcript...")
        transcript_result = self.extract_transcript(audio_path, video_hash, input_data)
        if transcript_result.get("error"):
            return f"Error: {transcript_result['error']}"

        # Treat a missing/None transcript as empty instead of failing on len(None).
        main_transcript = transcript_result.get('full_transcript') or ''
        preview = f"{main_transcript[:50]}..." if len(main_transcript) > 50 else main_transcript
        print(f"Transcript extracted: {preview}")
        return "TRANSCRIPT: " + main_transcript
    except Exception as e:
        return f"Error during transcript extraction: {str(e)}"
# Factory helper that builds the YouTube transcript extraction tool.
def create_youtube_transcript_tool(**kwargs):
    """Build a YouTubeTranscriptExtractor.

    Every keyword argument is forwarded verbatim to the constructor.

    Returns:
        The newly created YouTubeTranscriptExtractor instance.
    """
    extractor = YouTubeTranscriptExtractor(**kwargs)
    return extractor
# --- Model Configuration ---
def create_llm_pipeline(model_id: str = "mistralai/Mistral-7B-Instruct-v0.3", device_map: str = "cpu"):
    """Build the Hugging Face text-generation pipeline used as the agent's LLM.

    Args:
        model_id: Hugging Face Hub id of the model to load.
        device_map: Forwarded to ``transformers.pipeline``; defaults to ``"cpu"``.

    Returns:
        A ``transformers`` text-generation pipeline.
    """
    import torch  # local import keeps this function self-contained

    # Half precision (float16) matmul is unsupported or very slow on many CPU
    # PyTorch builds; bfloat16 keeps the same memory footprint and is
    # supported on CPU. Accelerators keep float16.
    torch_dtype = torch.bfloat16 if device_map == "cpu" else torch.float16
    return pipeline(
        "text-generation",
        model=model_id,
        device_map=device_map,
        torch_dtype=torch_dtype,
        max_new_tokens=1024,
        # NOTE: transformers applies temperature/top_k/top_p only when sampling
        # is enabled (do_sample=True); with greedy decoding they are ignored.
        temperature=0.3,
        top_k=50,
        top_p=0.95
    )
# Optional spaCy pipeline handle, deliberately left as None. Per the original
# note, extract_entities then uses its regex fallback (that function is outside
# this view — confirm there).
nlp = None
# --- Agent State Definition ---
class AgentState(TypedDict):
    """State dict passed between the agent's node functions.

    The Annotated reducer on ``messages`` suggests a LangGraph-style graph
    (confirm against the graph wiring).

    NOTE(review): TypedDict does not apply default values at runtime. A dict
    of this type has no ``done`` key until one is set; ``done: bool = False``
    below only creates a class attribute. Code in this module reads the key
    defensively (``state.pop("done", None)``).
    """
    # Conversation history; the reducer concatenates updates (x + y).
    # NOTE(review): call_llm_with_memory_management and parse_react_output
    # return the FULL history (plus/with the new message). If this reducer is
    # applied to their output, history would be duplicated — verify.
    messages: Annotated[List[AnyMessage], lambda x, y: x + y]
    # Termination flag: set True on a FINAL ANSWER, on LLM failure, or when
    # there are no messages; popped when the loop should continue.
    done: bool = False  # NOT applied to instances at runtime (see docstring)
    question: str  # presumably the user's question — confirm with caller
    task_id: str  # presumably an external task identifier — confirm
    input_file: Optional[bytes]  # presumably raw bytes of an attached file — confirm
    file_type: Optional[str]
    context: List[Document]  # Using LangChain's Document class
    file_path: Optional[str]  # forwarded to tools as their 'file_path' argument
    youtube_url: Optional[str]
    answer: Optional[str]  # set by parse_react_output to the FINAL ANSWER text
    frame_answers: Optional[list]
    # NOTE(review): call_tool_with_memory_management / execute_parsed_tool_calls
    # also read a 'parsed_tool_calls' key that is not declared here.
# --- Define Call LLM function ---
# 3. Improved LLM call with memory management
def call_llm_with_memory_management(state: AgentState, llm_model) -> AgentState:
    """Run one LLM turn over a truncated copy of the conversation.

    A leading ``SystemMessage`` is always kept. Of the remaining messages only
    the last 6 are kept, and the oldest of those are dropped further until the
    total content length fits in 20000 characters. The kept messages are
    rendered as a plain-text ``Human:`` / ``Assistant:`` / ``Tool Result:``
    prompt ending in ``Assistant:``. If one of the last two messages is a tool
    result, an instruction to answer now (with format reminders) is added.
    The reply from ``llm_model.invoke`` is passed through
    ``clean_llm_response``.

    ``state["messages"]`` is never mutated: truncation works on a copy, and
    the new ``AIMessage`` is appended to the full, untruncated history.

    Args:
        state: Current agent state; ``state["messages"]`` must be a list.
        llm_model: Object whose ``invoke(str)`` returns a ``BaseMessage``, an
            object with ``.content``, or any value convertible with ``str``.

    Returns:
        A shallow copy of ``state`` whose ``messages`` is the original history
        plus the reply, with ``done`` removed. If anything raises, an error
        ``AIMessage`` is appended instead and ``done`` is set to True.

    NOTE(review): ``AgentState.messages`` declares a concatenating reducer;
    returning the full history here would duplicate it if that reducer is
    applied to this function's output — verify against the graph wiring.
    """
    max_regular_messages = 6
    char_limit = 20000

    def _empty_cuda_cache() -> None:
        # Best-effort GPU memory release; no-op without torch or CUDA.
        try:
            import torch
        except ImportError:
            return
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    def _content_len(message) -> int:
        return len(str(message.content))

    print("Running call_llm with memory management...")
    original_messages = state["messages"]

    # Keep a leading system prompt out of the truncation window. Copy the rest
    # so the in-place pop() below can never reach state["messages"].
    if original_messages and isinstance(original_messages[0], SystemMessage):
        system_message = original_messages[0]
        regular_messages = list(original_messages[1:])
    else:
        system_message = None
        regular_messages = list(original_messages)

    if len(regular_messages) > max_regular_messages:
        print(f"🔄 Truncating to {max_regular_messages} recent messages")
        regular_messages = regular_messages[-max_regular_messages:]

    system_chars = _content_len(system_message) if system_message is not None else 0
    regular_chars = sum(_content_len(msg) for msg in regular_messages)
    total_chars = system_chars + regular_chars
    if total_chars > char_limit:
        print(f"📏 Context too long ({total_chars} chars) - further truncation")
        # Drop the oldest messages until the non-system part fits the budget.
        while regular_messages and regular_chars > char_limit - system_chars:
            regular_chars -= _content_len(regular_messages.pop(0))

    messages_for_llm = ([system_message] if system_message is not None else []) + regular_messages

    new_state = state.copy()
    try:
        _empty_cuda_cache()
        print(f"🤖 Calling LLM with {len(messages_for_llm)} messages")

        prompt_parts = []
        if system_message is not None:
            prompt_parts.append(f"{system_message.content}\n\n")
        for msg in regular_messages:
            if isinstance(msg, HumanMessage):
                prompt_parts.append(f"Human: {msg.content}\n\n")
            elif isinstance(msg, AIMessage):
                prompt_parts.append(f"Assistant: {msg.content}\n\n")
            elif isinstance(msg, ToolMessage):
                prompt_parts.append(f"Tool Result: {msg.content}\n\n")
        # Explicit instruction for an immediate final answer after recent tool results
        if any(isinstance(msg, ToolMessage) for msg in regular_messages[-2:]):
            prompt_parts.append("Based on the tool results above, provide your FINAL ANSWER now.\n\n")
            prompt_parts.append("REMINDER ON ANSWER FORMAT: \n")
            prompt_parts.append("- Numbers: no commas, no units unless specified\n")
            prompt_parts.append("- Strings: no articles, no abbreviations, digits in plain text\n")
            prompt_parts.append("- Lists: comma-separated following above rules\n")
            # Separator so the next turn marker starts on its own line.
            prompt_parts.append("- Be extremely brief and concise\n\n")
        prompt_parts.append("Assistant:")
        formatted_input = "".join(prompt_parts)

        print(f"Input preview: {formatted_input[:300]}...")
        llm_response_object = llm_model.invoke(formatted_input)

        if isinstance(llm_response_object, BaseMessage):
            raw_content = llm_response_object.content
        elif hasattr(llm_response_object, 'content'):
            raw_content = str(llm_response_object.content)
        else:
            raw_content = str(llm_response_object)

        # Strip echoed prompt text and hallucinated follow-up turns
        cleaned_content = clean_llm_response(raw_content)
        print(f"🔍 LLM Response preview: {cleaned_content[:200]}...")

        new_state["messages"] = original_messages + [AIMessage(content=cleaned_content)]
        new_state.pop("done", None)
    except Exception as e:
        print(f"❌ LLM call failed: {e}")
        error_message = AIMessage(content=f"Error: LLM call failed - {str(e)}")
        new_state["messages"] = original_messages + [error_message]
        new_state["done"] = True
    finally:
        _empty_cuda_cache()
    return new_state
def clean_llm_response(response_text: str) -> str:
    """Trim a raw LLM completion down to the assistant's own turn.

    Steps:
    1. If the text contains ``Assistant:`` (the completion may echo the
       prompt), keep only what follows its LAST occurrence, stripped.
    2. If a ReAct block ``Thought: ... Action: ... Action Input: ...`` is
       present, return just that block. The action input ends at the next
       ``Thought:``, ``Action:``, ``Observation:``, ``Human:`` or
       ``FINAL ANSWER:`` marker (or the end of the text). This keeps a
       hallucinated observation or follow-up turn out of the tool input.
       Trailing instruction-like text is also removed from it.
    3. Otherwise, if a non-template ``FINAL ANSWER:`` line exists, return the
       text up to and including that line.
    4. Otherwise, discard a template FINAL ANSWER and everything after it.
       Then keep lines until one looks like an echoed prompt marker, echoed
       instructions, or a hallucinated user question. A question is any line
       after the first containing "what is/what are/how many/can you tell
       me", not starting with a ReAct keyword and not following a line with
       'Thought:', 'Action:' or 'need to'.

    Args:
        response_text: Raw model output; falsy values are returned unchanged.

    Returns:
        The cleaned response text.
    """
    if not response_text:
        return response_text

    print(f"Initial response: {response_text[:200]}...")

    template_phrases = (
        '[concise answer only]',
        '[concise answer - number/word/list only]',
        '[brief answer]',
        '[your answer here]',
        'concise answer only',
        'brief answer',
        'your answer here',
    )

    # Only parse the last assistant turn so examples/instructions in the
    # echoed prompt are not mistaken for model output.
    assistant_marker = "Assistant:"
    last_marker_idx = response_text.rfind(assistant_marker)
    if last_marker_idx != -1:
        text_to_process = response_text[last_marker_idx + len(assistant_marker):].strip()
        print(f"ℹ️ Parsing content after last 'Assistant:': {text_to_process[:200]}...")
    else:
        text_to_process = response_text
        print("ℹ️ No 'Assistant:' marker found. Processing entire input as is.")

    react_pattern = (
        r'Thought:\s*(.*?)\s*Action:\s*([^\n\r]+)\s*Action Input:\s*(.*?)'
        r'(?=\s*(?:Thought:|Action:|Observation:|Human:|FINAL ANSWER:|$))'
    )
    react_match = re.search(react_pattern, text_to_process, re.DOTALL | re.IGNORECASE)
    if react_match:
        thought_text = react_match.group(1).strip()
        action_name = react_match.group(2).strip()
        # Remove any trailing content that looks like instructions
        action_input_clean = re.sub(
            r'\s*(When you have|FINAL ANSWER|ANSWER FORMAT|IMPORTANT:).*$',
            '',
            react_match.group(3).strip(),
            flags=re.DOTALL | re.IGNORECASE,
        ).strip()
        react_sequence = f"Thought: {thought_text}\nAction: {action_name}\nAction Input: {action_input_clean}"
        print(f"🔧 Found ReAct pattern - Action: {action_name}, Input: {action_input_clean[:100]}...")
        # A FINAL ANSWER after a tool call is hallucinated; it is dropped.
        if re.search(r'FINAL ANSWER:', text_to_process[react_match.end():], re.IGNORECASE):
            print("🚫 Removed hallucinated FINAL ANSWER after tool call")
        return react_sequence

    final_answer_match = re.search(r"FINAL ANSWER:\s*(.+?)(?=\n|$)", text_to_process, re.IGNORECASE)
    if final_answer_match:
        answer_content = final_answer_match.group(1).strip()
        if any(phrase in answer_content.lower() for phrase in template_phrases):
            print(f"🚫 Ignoring template FINAL ANSWER: {answer_content}")
            # Drop the template answer and fall through to generic cleanup.
            text_to_process = text_to_process[:final_answer_match.start()].strip()
        else:
            remaining_after_final_answer = text_to_process[final_answer_match.end():].strip()
            if remaining_after_final_answer:
                print(f"🚫 Removed content after FINAL ANSWER: {remaining_after_final_answer[:100]}...")
            return text_to_process[:final_answer_match.end()].strip()

    lines = text_to_process.split('\n')
    cleaned_lines = []
    for i, line in enumerate(lines):
        # Stop if the model repeats the prompt's role markers
        if re.search(r'\[SYSTEM\]|\[HUMAN\]|\[ASSISTANT\]|\[TOOL\]', line, re.IGNORECASE):
            print(f"🚫 Stopped at repeated system format: {line}")
            break
        # Stop if the model repeats format instructions
        if re.search(r'CRITICAL INSTRUCTIONS|FORMAT for tool use|ANSWER FORMAT', line, re.IGNORECASE):
            print(f"🚫 Stopped at repeated instructions: {line}")
            break
        # Stop if the model starts role-playing the user with a new question
        looks_like_question = (
            re.search(r'(what are|what is|how many|can you tell me)', line, re.IGNORECASE)
            and not line.strip().startswith(('Thought:', 'Action:', 'Action Input:'))
        )
        if (looks_like_question and i > 0
                and not any(keyword in lines[i - 1] for keyword in ('Thought:', 'Action:', 'need to'))):
            print(f"🚫 Stopped at hallucinated question: {line}")
            break
        cleaned_lines.append(line)

    cleaned = '\n'.join(cleaned_lines).strip()
    print(f"Final cleaned response (fallback): {cleaned[:200]}...")
    return cleaned
def parse_react_output(state: AgentState) -> AgentState:
    """Interpret the latest AIMessage as a final answer, a tool call, or plain text.

    Priority:
    1. A non-template ``FINAL ANSWER: <text>`` line. The last message is
       replaced by ``FINAL ANSWER: <text>``, ``answer`` is set and ``done``
       becomes True.
    2. ``Action: <tool>`` together with ``Action Input: <input>``. The last
       message is replaced by a copy carrying one structured tool call. Input
       wrapped in ``{...}`` or ``[...]`` is parsed with ``ast.literal_eval``
       and used as the args only if it is a dict. Anything else, including
       unparsable text, becomes ``{"query": <raw input>}``. An action named
       ``None`` is treated as plain text.
    3. Anything else: the message is re-emitted without tool calls.
    In cases 2 and 3 ``done`` is removed from the state.

    Args:
        state: Current agent state.

    Returns:
        A shallow copy of ``state``. An empty message list yields
        ``done=True``; a last message that is not an AIMessage leaves the
        copy unchanged.
    """
    print("Running parse_react_output...")
    messages = state.get("messages", [])
    new_state = state.copy()
    if not messages:
        print("No messages in state.")
        new_state["done"] = True
        return new_state

    # DEBUG
    print(f"parse_react_output: Entry message count: {len(messages)}")
    last_message = messages[-1]
    if hasattr(last_message, 'tool_calls'):
        print(f"parse_react_output: Number of tool calls in last AIMessage: {len(last_message.tool_calls)}")

    if not isinstance(last_message, AIMessage):
        print("Last message is not an AIMessage instance.")
        return new_state

    content = last_message.content if isinstance(last_message.content, str) else str(last_message.content)

    def _replace_last(tool_calls, text=content):
        # Swap the last message for a rebuilt AIMessage.
        new_state["messages"] = messages[:-1] + [AIMessage(content=text, tool_calls=tool_calls)]

    template_phrases = (
        '[concise answer only]',
        '[concise answer - number/word/list only]',
        '[brief answer]',
        '[your answer here]',
        'concise answer only',
        'brief answer',
        'your answer here',
    )

    # FINAL ANSWER takes absolute priority; capture just the answer line.
    final_answer_match = re.search(r"FINAL ANSWER:\s*([^\n\r]+)", content, re.IGNORECASE)
    if final_answer_match:
        final_answer_text = final_answer_match.group(1).strip()
        if any(phrase in final_answer_text.lower() for phrase in template_phrases):
            # Template placeholder, not a real answer: keep looking for actions.
            print(f"🚫 Ignoring template FINAL ANSWER: '{final_answer_text}'")
        else:
            print(f"🎯 FINAL ANSWER found: '{final_answer_text}' - ENDING")
            new_state["answer"] = final_answer_text
            _replace_last([], f"FINAL ANSWER: {final_answer_text}")
            new_state["done"] = True
            return new_state

    action_match = re.search(r"Action:\s*([^\n]+)", content, re.IGNORECASE)
    action_input_match = re.search(r"Action Input:\s*(.+)", content, re.IGNORECASE | re.DOTALL)
    if action_match and action_input_match:
        tool_name = action_match.group(1).strip()
        tool_input_raw = action_input_match.group(1).strip()

        if tool_name.lower() == "none":
            print("Action is 'None' - treating as regular response")
            _replace_last([])
            new_state.pop("done", None)
            return new_state

        print(f"🔧 Tool call: {tool_name} with input: {tool_input_raw[:100]}...")

        tool_args = {"query": tool_input_raw}
        looks_like_literal = (
            (tool_input_raw.startswith('{') and tool_input_raw.endswith('}'))
            or (tool_input_raw.startswith('[') and tool_input_raw.endswith(']'))
        )
        if looks_like_literal:
            try:
                parsed_args = ast.literal_eval(tool_input_raw)
            except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
                # literal_eval raises all of these for malformed input, e.g.
                # TypeError for "{[1]: 2}" (unhashable key).
                parsed_args = None
            if isinstance(parsed_args, dict):
                tool_args = parsed_args

        _replace_last([{"name": tool_name, "args": tool_args, "id": str(uuid.uuid4())}])
        new_state.pop("done", None)
        return new_state

    print("No actionable content found - continuing conversation")
    _replace_last([])
    new_state.pop("done", None)
    # DEBUG
    print(f"parse_react_output: Exit message count: {len(new_state['messages'])}")
    return new_state
# 4. Improved call_tool_with_memory_management to prevent duplicate processing
def call_tool_with_memory_management(state: AgentState) -> AgentState:
    """Run the tool calls attached to the last message and append their results.

    If ``state`` has a non-empty ``parsed_tool_calls`` entry, execution is
    delegated to ``execute_parsed_tool_calls``. Otherwise each tool call on
    the last message (dict-style or attribute-style) is matched
    case-insensitively against the module-level ``tools`` list. It is then
    run with its arguments plus ``file_path`` from the state.

    Each call yields exactly one ``ToolMessage``; unknown tools and tool
    exceptions become error text instead of raising. Results longer than 2000
    characters (3000 for tool names containing "wikipedia") are truncated.
    The last message is flagged with ``_processed_tool_calls`` so a repeated
    call on the same message is a no-op. The CUDA cache is emptied before
    and after processing when torch and CUDA are available.

    Args:
        state: Current agent state.

    Returns:
        A shallow copy of ``state`` with the tool messages appended, or
        ``state`` itself when there is nothing to do.
    """
    def _clear_cuda_cache(stage: str) -> None:
        # stage is "" or " post-processing"; it only changes the log text.
        try:
            import torch
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
                print(f"🧹 Cleared CUDA cache{stage}. Memory: {torch.cuda.memory_allocated()/1024**2:.1f}MB")
        except ImportError:
            pass
        except Exception as e:
            print(f"Error clearing CUDA cache{stage}: {e}")

    print("Running call_tool with memory management...")
    _clear_cuda_cache("")

    # Tool calls parsed upstream take precedence over message tool_calls.
    if state.get('parsed_tool_calls'):
        print("Executing parsed tool calls...")
        return execute_parsed_tool_calls(state)

    messages = state.get("messages", [])
    if not messages:
        print("No messages found in state.")
        return state

    last_message = messages[-1]
    if not getattr(last_message, "tool_calls", None):
        print("No tool calls found in last message")
        return state

    # Avoid processing the same tool calls multiple times
    if hasattr(last_message, '_processed_tool_calls'):
        print("Tool calls already processed, skipping...")
        return state

    new_messages = list(messages)  # never mutate the state's list
    print(f"Processing {len(last_message.tool_calls)} tool calls from last message")
    file_path_to_pass = state.get('file_path')

    for i, tool_call_item in enumerate(last_message.tool_calls):
        if isinstance(tool_call_item, dict):
            tool_name = tool_call_item.get("name")
            raw_args = tool_call_item.get("args")
            tool_call_id = tool_call_item.get("id")
        elif hasattr(tool_call_item, "name") and hasattr(tool_call_item, "id"):
            tool_name = getattr(tool_call_item, "name", None)
            raw_args = getattr(tool_call_item, "args", None)
            tool_call_id = getattr(tool_call_item, "id", None)
        else:
            print(f"Skipping malformed tool call item: {tool_call_item}")
            continue
        # A present-but-None name/id would break .lower() and ToolMessage
        # (which needs a string tool_call_id); substitute safe values.
        tool_name = tool_name or ""
        tool_call_id = tool_call_id or str(uuid.uuid4())

        print(f"Processing tool call {i+1}: {tool_name}")
        selected_tool = next((t for t in tools if t.name.lower() == tool_name.lower()), None)

        if selected_tool is None:
            tool_result = f"Error: Tool '{tool_name}' not found. Available tools: {', '.join(t.name for t in tools)}"
            print(f"Tool not found: {tool_name}")
        else:
            try:
                if isinstance(raw_args, dict):
                    tool_run_input_dict = dict(raw_args)
                elif raw_args is not None:
                    tool_run_input_dict = {"query": str(raw_args)}
                else:
                    tool_run_input_dict = {}
                tool_run_input_dict['file_path'] = file_path_to_pass

                print(f"Executing {tool_name} with args: {tool_run_input_dict} ...")
                tool_result = selected_tool.run(tool_run_input_dict)

                # Aggressive truncation to prevent memory issues
                if not isinstance(tool_result, str):
                    tool_result = str(tool_result)
                max_length = 3000 if "wikipedia" in tool_name.lower() else 2000
                if len(tool_result) > max_length:
                    original_length = len(tool_result)
                    tool_result = tool_result[:max_length] + f"... [Result truncated from {original_length} to {max_length} chars to prevent memory issues]"
                    print(f"📄 Truncated result to {max_length} characters")
                print(f"Tool result length: {len(tool_result)} characters")
            except Exception as e:
                tool_result = f"Error executing tool '{tool_name}': {str(e)}"
                print(f"Tool execution error: {e}")

        # Exactly one ToolMessage per tool call
        new_messages.append(ToolMessage(content=tool_result, name=tool_name, tool_call_id=tool_call_id))
        print(f"Added tool message for {tool_name}")

    # Mark the last message as processed to prevent re-processing
    if hasattr(last_message, '__dict__'):
        last_message._processed_tool_calls = True

    new_state = state.copy()
    new_state["messages"] = new_messages
    _clear_cuda_cache(" post-processing")
    return new_state
# 3. Enhanced execute_parsed_tool_calls to prevent duplicate observations
def _tool_run_accepts_file_path(func):
    """Return True if ``func``'s signature declares a parameter named ``file_path``."""
    import inspect

    try:
        return 'file_path' in inspect.signature(func).parameters
    except (TypeError, ValueError):
        # Some callables (builtins, C extensions) expose no introspectable signature.
        return False


def _parsed_call_observation(content, tool_name):
    """Wrap ``content`` in a ToolMessage for ``tool_name`` with a fresh random tool_call_id."""
    return ToolMessage(
        content=content,
        name=tool_name,
        tool_call_id=str(uuid.uuid4())
    )


def execute_parsed_tool_calls(state: AgentState):
    """
    Execute tool calls that were parsed from the Thought/Action/Action Input format.

    This is called by call_tool when parsed_tool_calls are present in state.
    Each parsed call is executed exactly once and yields exactly one ToolMessage
    whose content starts with "Observation: ". A tool that cannot be found, or
    that raises, is reported as an observation instead of propagating the error.

    Args:
        state: Agent state. Reads ``messages``, ``parsed_tool_calls`` (a list of
            dicts with ``action``, ``action_input`` and, optionally,
            ``normalized_action``) and, optionally, ``file_path``.

    Returns:
        A shallow copy of ``state`` with the observations appended to
        ``messages`` and ``parsed_tool_calls`` cleared so the same calls are not
        executed again.
    """
    # Aliases the model may write in "Action:" lines, mapped to registered tool names.
    tool_name_mappings = {
        'wikipedia_semantic_search': 'wikipedia_tool',
        'wikipedia': 'wikipedia_tool',
        'search': 'enhanced_search',
        'duckduckgo_search': 'enhanced_search',
        'web_search': 'enhanced_search',
        'enhanced_search': 'enhanced_search',
        'youtube_screenshot_qa_tool': 'youtube_tool',
        'youtube': 'youtube_tool',
        'youtube_transcript_extractor': 'youtube_transcript_extractor',
        'youtube_audio_tool': 'youtube_transcript_extractor'
    }
    max_result_chars = 6000

    # Lookup by lower-cased tool name; `tools` is the module-level list that run_agent populates.
    tools_by_name = {tool.name.lower(): tool for tool in tools}

    # Copy messages so the caller's list is not mutated.
    new_messages = list(state["messages"])

    # Process each tool call ONCE.
    for tool_call in state.get('parsed_tool_calls') or []:
        action = tool_call['action']
        action_input = tool_call['action_input']
        # Fall back to a lower-cased action name, matching how tools_by_name is keyed.
        normalized_action = tool_call.get('normalized_action') or str(action).strip().lower()
        print(f"🚀 Executing tool: {action} with input: {action_input}")

        # Exact registered name first, then the alias table.
        tool_instance = tools_by_name.get(normalized_action)
        if tool_instance is None and normalized_action in tool_name_mappings:
            tool_instance = tools_by_name.get(tool_name_mappings[normalized_action])

        if tool_instance is None:
            print(f"❌ Tool '{action}' not found in available tools")
            available_tool_names = list(tools_by_name.keys())
            new_messages.append(_parsed_call_observation(
                f"Observation: Tool '{action}' not found. Available tools: {', '.join(available_tool_names)}",
                action
            ))
            continue

        try:
            run = getattr(tool_instance, 'run', None)
            if run is None:
                result = str(tool_instance)
            elif _tool_run_accepts_file_path(run):
                # Only tools that declare file_path get the task's attached file.
                result = run(action_input, file_path=state.get('file_path'))
            else:
                result = run(action_input)

            # Tools may return non-string objects; normalize before measuring/slicing.
            if not isinstance(result, str):
                result = str(result)
            if len(result) > max_result_chars:
                result = result[:max_result_chars] + "... [Result truncated due to length]"

            # Create a SINGLE observation message.
            new_messages.append(_parsed_call_observation(f"Observation: {result}", action))
            print(f"✅ Tool '{action}' executed successfully")
        except Exception as e:
            # Best effort: surface the failure to the model as an observation.
            print(f"❌ Error executing tool '{action}': {e}")
            new_messages.append(_parsed_call_observation(
                f"Observation: Error executing '{action}': {str(e)}",
                action
            ))

    # Update state with new messages and clear parsed tool calls.
    new_state = state.copy()
    new_state["messages"] = new_messages
    new_state['parsed_tool_calls'] = []  # Clear to prevent re-execution
    return new_state
# Loop detection: stop the graph once the run is done or recent messages start repeating
def should_continue(state: AgentState) -> str:
    """
    Decide whether the graph moves on from call_llm to output parsing.

    Returns "end" when the state's ``done`` flag is set, or when the six most
    recent messages look like a loop: fewer than three distinct values among
    the first 100 characters of their contents. Otherwise returns "continue".
    """
    print("Running should_continue...")

    if state.get("done", False):
        print("✅ Done flag is True - ending")
        return "end"

    history = state["messages"]

    # A hard cap on the message count was tried and is currently disabled;
    # only the repetition heuristic below can stop the run early.
    if len(history) >= 6:
        distinct_prefixes = {
            str(m.content)[:100]
            for m in history[-6:]
            if hasattr(m, 'content')
        }
        if len(distinct_prefixes) < 3:
            print("🔄 Detected repetitive pattern - ending")
            return "end"

    print(f"📊 Continuing... ({len(history)} messages so far)")
    return "continue"
def route_after_parse_react(state: AgentState) -> str:
    """
    Choose the node that follows parse_react_output.

    Returns:
        "end_processing" if parse_react_output set the ``done`` flag,
        "call_tool" if the newest message carries non-empty ``tool_calls``,
        "call_llm" otherwise (including when there are no messages).
    """
    if state.get("done", False):
        return "end_processing"

    history = state.get("messages", [])
    newest = history[-1] if history else None
    # getattr with a default covers both "no messages" and messages without tool_calls.
    pending_calls = getattr(newest, 'tool_calls', None)
    return "call_tool" if pending_calls else "call_llm"
# --- Graph Construction ---
def create_memory_safe_workflow():
    """
    Build and compile the agent's state graph.

    Flow: call_llm -> (should_continue) -> parse_react_output or END;
    parse_react_output -> (route_after_parse_react) -> call_tool, call_llm or END;
    call_tool -> call_llm. A HuggingFacePipeline LLM is created on every call
    and bound into the call_llm node as ``llm_model``.
    """
    # NOTE(review): the LLM pipeline is created per workflow and never released
    # here; if run_agent needs to free/recreate it, its lifecycle should be
    # managed explicitly (e.g. passed in) — confirm intended ownership.
    llm = HuggingFacePipeline(pipeline=create_llm_pipeline())

    graph = StateGraph(AgentState)
    graph.add_node("call_llm", partial(call_llm_with_memory_management, llm_model=llm))
    graph.add_node("parse_react_output", parse_react_output)
    graph.add_node("call_tool", call_tool_with_memory_management)
    graph.set_entry_point("call_llm")

    routes_after_llm = {
        "continue": "parse_react_output",
        "end": END,
    }
    routes_after_parse = {
        "call_tool": "call_tool",
        "call_llm": "call_llm",
        "end_processing": END,
    }
    graph.add_conditional_edges("call_llm", should_continue, routes_after_llm)
    graph.add_conditional_edges("parse_react_output", route_after_parse_react, routes_after_parse)
    # Tool observations always go back to the LLM.
    graph.add_edge("call_tool", "call_llm")

    return graph.compile()
# Translation table that deletes ASCII punctuation; built once instead of per call.
_PUNCTUATION_DELETE_TABLE = str.maketrans('', '', string.punctuation)


def count_english_words(text, vocabulary=None):
    """
    Count the tokens of ``text`` that occur in an English vocabulary.

    ASCII punctuation (``string.punctuation``) is deleted, the text is
    lower-cased and split on whitespace, and every token found in the
    vocabulary is counted (repeated words count each time).

    Args:
        text: Text to score. ``None`` or an empty string scores 0.
        vocabulary: Container supporting ``in`` checks. Because tokens are
            lower-cased, only lower-case entries can match. Defaults to the
            module-level ``english_words`` collection.

    Returns:
        Number of tokens found in the vocabulary.
    """
    if not text:
        return 0
    if vocabulary is None:
        # NOTE(review): if english_words is a list rather than a set, each
        # membership test is O(len(english_words)) — confirm its type.
        vocabulary = english_words
    tokens = text.translate(_PUNCTUATION_DELETE_TABLE).lower().split()
    return sum(1 for token in tokens if token in vocabulary)
def fix_backwards_text(text):
    """
    Undo character-level reversal of ``text`` when that makes it read as English.

    Both ``text`` and ``text[::-1]`` are scored with count_english_words; the
    reversed string is returned only if it scores strictly higher, so a tie
    leaves the input unchanged.
    """
    mirrored = text[::-1]
    forward_score = count_english_words(text)
    backward_score = count_english_words(mirrored)
    return mirrored if backward_score > forward_score else text
# --- Run the Agent ---
# Enhanced system prompt for better behavior
def run_agent(agent, state: AgentState):
    """
    Run the compiled agent graph on ``state['question']`` and return its messages.

    Steps:
      1. Create the tool instances and publish them through the module-level
         globals (WIKIPEDIA_TOOL, SEARCH_TOOL, YOUTUBE_TOOL, YOUTUBE_AUDIO_TOOL,
         tools); the tool-execution code looks tools up in ``tools``.
      2. Build the system prompt from the rendered tool descriptions and today's date.
      3. Undo reversed questions via fix_backwards_text and, if the question
         contains a YouTube link, store it in ``state['youtube_url']``.
      4. Seed ``state['messages']`` with the system and human messages, reset
         ``state['done']``, and invoke ``agent``.

    Note: ``state`` is modified in place (``messages``, ``done`` and possibly
    ``youtube_url``) before it is passed to ``agent.invoke``.

    Args:
        agent: Compiled graph exposing ``invoke(state)`` that returns a state
            dict (e.g. the result of create_memory_safe_workflow()).
        state: Agent state; must contain ``question``.

    Returns:
        The ``messages`` list of the final state returned by ``agent.invoke``.
    """
    global WIKIPEDIA_TOOL, SEARCH_TOOL, YOUTUBE_TOOL, YOUTUBE_AUDIO_TOOL, tools
    # NOTE(review): tools are rebuilt on every call; if construction is
    # expensive (e.g. model/index loading) consider reusing them.
    WIKIPEDIA_TOOL = WikipediaSearchToolWithFAISS()
    SEARCH_TOOL = EnhancedDuckDuckGoSearchTool(max_results=3, max_chars_per_page=3000)
    YOUTUBE_TOOL = EnhancedYoutubeScreenshotQA()
    YOUTUBE_AUDIO_TOOL = YouTubeTranscriptExtractor()
    tools = [WIKIPEDIA_TOOL, SEARCH_TOOL, YOUTUBE_TOOL, YOUTUBE_AUDIO_TOOL]
    formatted_tools_description = render_text_description(tools)
    current_date_str = datetime.now().strftime("%Y-%m-%d")

    # System prompt with strict boundaries so the model answers one question and stops.
    system_content = f"""You are an AI assistant with access to these tools:
{formatted_tools_description}
CRITICAL INSTRUCTIONS:
1. Answer ONLY the question asked by the human
2. Do NOT generate additional questions or continue conversations
3. Use tools ONLY when you need specific information you don't know
4. After using a tool, provide your FINAL ANSWER immediately
5. STOP after giving your FINAL ANSWER - do not continue
FORMAT for tool use:
Thought: <brief reasoning>
Action: <exact_tool_name>
Action Input: <tool_input>
When you have the answer, immediately provide:
FINAL ANSWER: [concise answer only]
ANSWER FORMAT:
- Numbers: no commas, no units unless specified
- Strings: no articles, no abbreviations, digits in plain text
- Lists: comma-separated following above rules
- Be extremely brief and concise
- Do not provide additional context or explanations
- Do not provide parentheticals
IMPORTANT: You are responding to ONE question only. Do not ask follow-up questions or generate additional dialogue.
Current date: {current_date_str}
"""
    query = fix_backwards_text(state['question'])

    # Store the YouTube link itself (not merely the first URL in the question).
    youtube_url = _extract_youtube_url(query)
    if youtube_url:
        state['youtube_url'] = youtube_url

    state['messages'] = [SystemMessage(content=system_content), HumanMessage(content=query)]
    state["done"] = False

    result = agent.invoke(state)

    # The graph can also end via should_continue without setting `done`, so
    # clean up unconditionally. Only Python garbage collection runs here.
    gc.collect()
    print("🧹 Ran garbage collection after agent completion")
    return result["messages"]


def _extract_youtube_url(text):
    """
    Return the first YouTube link in ``text``, or None if there is none.

    Trailing sentence punctuation (e.g. the comma in "In the video
    https://www.youtube.com/watch?v=ID, what ...") is stripped, and
    scheme-less links such as "youtube.com/watch?v=ID" get "https://" prepended.
    """
    match = re.search(r"(?:https?://)?(?:[\w-]+\.)?(?:youtube\.com|youtu\.be)/\S+", text)
    if not match:
        return None
    url = match.group(0).rstrip(".,;:!?)]}>'\"")
    if not url.lower().startswith(("http://", "https://")):
        url = "https://" + url
    return url