Spaces:
Sleeping
Sleeping
# Standard Library | |
import os | |
import re | |
import tempfile | |
import string | |
import glob | |
import shutil | |
import gc | |
import uuid | |
import signal | |
from datetime import datetime | |
from io import BytesIO | |
from contextlib import contextmanager | |
from langchain_huggingface import HuggingFacePipeline | |
from typing import TypedDict, List, Optional, Dict, Any, Annotated, Literal, Union, Tuple, Set | |
import time | |
from collections import Counter | |
from pydantic import Field | |
import hashlib | |
import json | |
import numpy as np | |
import ast | |
from concurrent.futures import ThreadPoolExecutor, as_completed | |
from collections import Counter, defaultdict | |
# Third-Party Packages | |
import cv2 | |
import requests | |
import wikipedia | |
import spacy | |
import yt_dlp | |
import librosa | |
from PIL import Image | |
from bs4 import BeautifulSoup | |
from duckduckgo_search import DDGS | |
from sentence_transformers import SentenceTransformer | |
from transformers import BlipProcessor, BlipForQuestionAnswering, pipeline, AutoTokenizer | |
from sklearn.feature_extraction.text import TfidfVectorizer | |
from sklearn.metrics.pairwise import cosine_similarity | |
from sklearn.cluster import KMeans | |
from sklearn.preprocessing import StandardScaler | |
import speech_recognition as sr | |
from pydub import AudioSegment | |
from pydub.silence import split_on_silence | |
import nltk | |
from nltk.corpus import words | |
# LangChain Ecosystem | |
from langchain.docstore.document import Document | |
from langchain.prompts import PromptTemplate | |
from langchain_community.document_loaders import WikipediaLoader | |
from langchain_huggingface import HuggingFaceEndpoint | |
from langchain_community.retrievers import BM25Retriever | |
from langchain.vectorstores import FAISS | |
from langchain.embeddings import HuggingFaceEmbeddings | |
from langchain.text_splitter import RecursiveCharacterTextSplitter | |
from langchain.schema import Document | |
from langchain_community.tools import DuckDuckGoSearchRun | |
from langchain_core.messages import AnyMessage, HumanMessage, AIMessage, BaseMessage, SystemMessage, ToolMessage | |
from langchain_core.tools import BaseTool, StructuredTool, tool, render_text_description | |
from langchain_core.documents import Document | |
# LangGraph | |
from langgraph.graph import START, END, StateGraph | |
from langgraph.prebuilt import ToolNode, tools_condition | |
# PyTorch | |
import torch | |
from functools import partial | |
from transformers import pipeline | |
# Additional Utilities | |
from datetime import datetime | |
from urllib.parse import urljoin, urlparse | |
import logging | |
nlp = spacy.load("en_core_web_sm") | |
# Ensure the word list is downloaded | |
nltk.download('words', quiet=True) | |
english_words = set(words.words()) | |
logger = logging.getLogger(__name__) | |
# --- Model Configuration --- | |
def create_llm_pipeline(): | |
#model_id = "meta-llama/Llama-2-13b-chat-hf" | |
#model_id = "meta-llama/Llama-3.3-70B-Instruct" | |
#model_id = "mistralai/Mistral-Small-24B-Base-2501" | |
model_id = "mistralai/Mistral-7B-Instruct-v0.3" | |
#model_id = "Qwen/Qwen2-7B-Instruct" | |
# Load tokenizer explicitly with fast version | |
tokenizer = AutoTokenizer.from_pretrained( | |
model_id, | |
use_fast=True, # Force fast tokenizer | |
add_prefix_space=True # Only if actually needed | |
) | |
return pipeline( | |
"text-generation", | |
model=model_id, | |
tokenizer = tokenizer, | |
device_map="cpu", | |
torch_dtype=torch.float16, | |
max_new_tokens=1024, | |
temperature=0.1 | |
) | |
# Define file extension sets for each category | |
PICTURE_EXTENSIONS = {'.jpg', '.jpeg', '.png', '.gif', '.bmp', '.tiff', '.webp'} | |
AUDIO_EXTENSIONS = {'.mp3', '.wav', '.aac', '.flac', '.ogg', '.m4a', '.wma'} | |
CODE_EXTENSIONS = {'.py', '.js', '.java', '.cpp', '.c', '.cs', '.rb', '.go', '.php', '.html', '.css', '.ts'} | |
SPREADSHEET_EXTENSIONS = { | |
'.xls', '.xlsx', '.xlsm', '.xlsb', '.xlt', '.xltx', '.xltm', | |
'.ods', '.ots', '.csv', '.tsv', '.sxc', '.stc', '.dif', '.gsheet', | |
'.numbers', '.numbers-tef', '.nmbtemplate', '.fods', '.123', '.wk1', '.wk2', | |
'.wks', '.wku', '.wr1', '.gnumeric', '.gnm', '.xml', '.pmvx', '.pmdx', | |
'.pmv', '.uos', '.txt' | |
} | |
def get_file_type(filename: str) -> str: | |
if not filename or '.' not in filename or filename == '': | |
return '' | |
ext = filename.lower().rsplit('.', 1)[-1] | |
dot_ext = f'.{ext}' | |
if dot_ext in PICTURE_EXTENSIONS: | |
return 'picture' | |
elif dot_ext in AUDIO_EXTENSIONS: | |
return 'audio' | |
elif dot_ext in CODE_EXTENSIONS: | |
return 'code' | |
elif dot_ext in SPREADSHEET_EXTENSIONS: | |
return 'spreadsheet' | |
else: | |
return 'unknown' | |
def write_bytes_to_temp_dir(file_bytes: bytes, file_name: str) -> str: | |
""" | |
Writes bytes to a file in the system temporary directory using the provided file_name. | |
Returns the full path to the saved file. | |
The file will persist until manually deleted or the OS cleans the temp directory. | |
""" | |
temp_dir = "/tmp" # /tmp is always writable in Hugging Face Spaces | |
os.makedirs(temp_dir, exist_ok=True) | |
file_path = os.path.join(temp_dir, file_name) | |
with open(file_path, 'wb') as f: | |
f.write(file_bytes) | |
print(f"File written to: {file_path}") | |
return file_path | |
def extract_final_answer(text: str) -> str: | |
""" | |
Returns the substring starting from the last occurrence of 'FINAL ANSWER:' (case-insensitive) | |
to the end of the string, with any trailing punctuation removed. | |
If not found, returns an empty string. | |
""" | |
marker = "FINAL ANSWER:" | |
idx = text.lower().rfind(marker.lower()) | |
if idx == -1: | |
return "" | |
result = text[idx:].strip() | |
# Remove trailing punctuation | |
return result.rstrip(string.punctuation + " ") | |
class EnhancedDuckDuckGoSearchTool(BaseTool): | |
name: str = "enhanced_search" | |
description: str = ( | |
"Performs a DuckDuckGo web search and retrieves actual content from the top web results. " | |
"Input should be a search query string. " | |
"Returns search results with extracted content from web pages, making it much more useful for answering questions. " | |
"Use this tool when you need up-to-date information, details about current events, or when other tools do not provide sufficient or recent answers. " | |
"Ideal for topics that require the latest news, recent developments, or information not covered in static sources." | |
) | |
max_results: int = 3 | |
max_chars_per_page: int = 3000 | |
session: Any = None # Now it's optional and defaults to None | |
# Use model_post_init for initialization logic in Pydantic v2+ | |
def model_post_init(self, __context: Any) -> None: | |
super().model_post_init(__context) | |
# Initialize HTTP session here | |
self.session = requests.Session() | |
self.session.headers.update({ | |
'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36', | |
'Accept': 'text/html,application/xhtml+xml,application/xml;q=0.9,image/webp,*/*;q=0.8', | |
'Accept-Language': 'en-US,en;q=0.5', | |
'Accept-Encoding': 'gzip, deflate', | |
'Connection': 'keep-alive', | |
'Upgrade-Insecure-Requests': '1', | |
}) | |
def _search_duckduckgo(self, query: str) -> List[Dict]: | |
"""Perform DuckDuckGo search and return results.""" | |
try: | |
with DDGS() as ddgs: | |
results = list(ddgs.text(query, max_results=self.max_results)) | |
return results | |
except Exception as e: | |
logger.error(f"DuckDuckGo search failed: {e}") | |
return [] | |
def _extract_content_from_url(self, url: str, timeout: int = 10) -> Optional[str]: | |
"""Extract clean text content from a web page.""" | |
try: | |
# Skip certain file types | |
if any(url.lower().endswith(ext) for ext in ['.pdf', '.doc', '.docx', '.xls', '.xlsx', '.ppt', '.pptx']): | |
return "Content type not supported for extraction" | |
response = self.session.get(url, timeout=timeout, allow_redirects=True) | |
response.raise_for_status() | |
# Check content type | |
content_type = response.headers.get('content-type', '').lower() | |
if 'text/html' not in content_type: | |
return "Non-HTML content detected" | |
soup = BeautifulSoup(response.content, 'html.parser') | |
# Remove script and style elements | |
for script in soup(["script", "style", "nav", "header", "footer", "aside", "form"]): | |
script.decompose() | |
# Try to find main content areas | |
main_content = None | |
for selector in ['main', 'article', '.content', '#content', '.post', '.entry']: | |
main_content = soup.select_one(selector) | |
if main_content: | |
break | |
if not main_content: | |
main_content = soup.find('body') or soup | |
# Extract text | |
text = main_content.get_text(separator='\n', strip=True) | |
# Clean up the text | |
lines = [line.strip() for line in text.split('\n') if line.strip()] | |
text = '\n'.join(lines) | |
# Remove excessive whitespace | |
text = re.sub(r'\n{3,}', '\n\n', text) | |
text = re.sub(r' {2,}', ' ', text) | |
# Truncate if too long | |
if len(text) > self.max_chars_per_page: | |
text = text[:self.max_chars_per_page] + "\n[Content truncated...]" | |
return text | |
except requests.exceptions.Timeout: | |
return "Page loading timed out" | |
except requests.exceptions.RequestException as e: | |
return f"Failed to retrieve page: {str(e)}" | |
except Exception as e: | |
logger.error(f"Content extraction failed for {url}: {e}") | |
return "Failed to extract content from page" | |
def _format_search_result(self, result: Dict, content: str) -> str: | |
"""Format a single search result with its content.""" | |
title = result.get('title', 'No title') | |
url = result.get('href', 'No URL') | |
snippet = result.get('body', 'No snippet') | |
formatted = f""" | |
🔍 **{title}** | |
URL: {url} | |
Snippet: {snippet} | |
📄 **Page Content:** | |
{content} | |
--- | |
""" | |
return formatted | |
def run(self, query: str) -> str: | |
"""Execute the enhanced search.""" | |
if not query or not query.strip(): | |
return "Please provide a search query." | |
query = query.strip() | |
logger.info(f"Searching for: {query}") | |
# Perform DuckDuckGo search | |
search_results = self._search_duckduckgo(query) | |
if not search_results: | |
return f"No search results found for query: {query}" | |
# Process each result and extract content | |
enhanced_results = [] | |
processed_count = 0 | |
for i, result in enumerate(search_results[:self.max_results]): | |
url = result.get('href', '') | |
if not url: | |
continue | |
logger.info(f"Processing result {i+1}: {url}") | |
# Extract content from the page | |
content = self._extract_content_from_url(url) | |
if content and len(content.strip()) > 50: # Only include results with substantial content | |
formatted_result = self._format_search_result(result, content) | |
enhanced_results.append(formatted_result) | |
processed_count += 1 | |
# Small delay to be respectful to servers | |
time.sleep(0.5) | |
if not enhanced_results: | |
return f"Search completed but no content could be extracted from the pages for query: {query}" | |
# Compile final response | |
response = f"""🔍 **Enhanced Search Results for: "{query}"** | |
Found {len(search_results)} results, successfully processed {processed_count} pages with content. | |
{''.join(enhanced_results)} | |
💡 **Summary:** Retrieved and processed content from {processed_count} web pages to provide comprehensive information about your search query. | |
""" | |
# Ensure the response isn't too long | |
if len(response) > 8000: | |
response = response[:8000] + "\n[Response truncated to prevent memory issues]" | |
return response | |
def _run(self, query: str) -> str: | |
"""Required by BaseTool interface.""" | |
return self.run(query) | |
# --- Agent State Definition --- | |
class AgentState(TypedDict): | |
messages: Annotated[List[AnyMessage], lambda x, y: x + y] | |
done: bool = False # Default value of False | |
question: str | |
task_id: str | |
input_file: Optional[bytes] | |
file_type: Optional[str] | |
context: List[Document] # Using LangChain's Document class | |
file_path: Optional[str] | |
youtube_url: Optional[str] | |
answer: Optional[str] | |
frame_answers: Optional[list] | |
def fetch_page_with_tables(page_title): | |
""" | |
Fetches Wikipedia page content and extracts all tables as readable text. | |
Returns a tuple: (main_text, [table_texts]) | |
""" | |
# Fetch the page object | |
page = wikipedia.page(page_title) | |
main_text = page.content | |
# Get the HTML for table extraction | |
html = page.html() | |
soup = BeautifulSoup(html, 'html.parser') | |
tables = soup.find_all('table') | |
table_texts = [] | |
for table in tables: | |
rows = table.find_all('tr') | |
table_lines = [] | |
for row in rows: | |
cells = row.find_all(['th', 'td']) | |
cell_texts = [cell.get_text(strip=True) for cell in cells] | |
if cell_texts: | |
# Format as Markdown table row | |
table_lines.append(" | ".join(cell_texts)) | |
if table_lines: | |
table_text = "\n".join(table_lines) | |
table_texts.append(table_text) | |
return main_text, table_texts | |
class WikipediaSearchToolWithFAISS(BaseTool): | |
name: str = "wikipedia_semantic_search_all_candidates_strong_entity_priority_list_retrieval" | |
description: str = ( | |
"Fetches content from multiple Wikipedia pages based on intelligent NLP query processing " | |
"of various search candidates, with strong prioritization of query entities. It then performs " | |
"entity-focused semantic search across all fetched content to find the most relevant information, " | |
"with improved retrieval for lists like discographies. Uses spaCy for named entity " | |
"recognition and query enhancement. Input should be a search query or topic. " | |
"Note: Uses the current live version of Wikipedia." | |
) | |
embedding_model_name: str = "all-MiniLM-L6-v2" | |
chunk_size: int = 4000 | |
chunk_overlap: int = 250 # Maintained moderate overlap | |
top_k_results: int = 3 | |
spacy_model: str = "en_core_web_sm" | |
# Increased multiplier to fetch more candidates per semantic query variant | |
semantic_search_candidate_multiplier: int = 1 # Was 2, increased to 3, consider 4 if still problematic | |
def __init__(self, **kwargs): | |
super().__init__(**kwargs) | |
try: | |
self._nlp = spacy.load(self.spacy_model) | |
print(f"Loaded spaCy model: {self.spacy_model}") | |
self._embedding_model = HuggingFaceEmbeddings(model_name=self.embedding_model_name) | |
# Refined separators for better handling of Wikipedia lists and sections | |
self._text_splitter = RecursiveCharacterTextSplitter( | |
chunk_size=self.chunk_size, | |
chunk_overlap=self.chunk_overlap, | |
separators=[ | |
"\n\n== ", "\n\n=== ", "\n\n==== ", # Section headers (keep with following content) | |
"\n\n\n", "\n\n", # Multiple newlines (paragraph breaks) | |
"\n* ", "\n- ", "\n# ", # List items | |
"\n", ". ", "! ", "? ", # Sentence breaks after newline, common punctuation | |
" ", "" # Word and character level | |
] | |
) | |
except OSError as e: | |
print(f"Error loading spaCy model '{self.spacy_model}': {e}") | |
print("Try running: python -m spacy download en_core_web_sm") | |
self._nlp = None | |
self._embedding_model = None | |
self._text_splitter = None | |
except Exception as e: | |
print(f"Error initializing WikipediaSearchToolWithFAISS components: {e}") | |
self._nlp = None | |
self._embedding_model = None | |
self._text_splitter = None | |
def _extract_entities_and_keywords(self, query: str) -> Tuple[List[str], List[str], str]: | |
if not self._nlp: | |
return [], [], query | |
doc = self._nlp(query) | |
main_entities = [ent.text for ent in doc.ents if ent.label_ in ["PERSON", "ORG", "GPE", "EVENT", "WORK_OF_ART"]] | |
keywords = [token.lemma_.lower() for token in doc if token.pos_ in ["NOUN", "PROPN", "ADJ"] and not token.is_stop and not token.is_punct and len(token.text) > 2] | |
main_entities = list(dict.fromkeys(main_entities)) | |
keywords = list(dict.fromkeys(keywords)) | |
processed_tokens = [token.lemma_ for token in doc if not token.is_stop and not token.is_punct and token.text.strip()] | |
processed_query = " ".join(processed_tokens) | |
return main_entities, keywords, processed_query | |
def _generate_search_candidates(self, query: str, main_entities: List[str], keywords: List[str], processed_query: str) -> List[str]: | |
candidates_set = set() | |
entity_prefix = main_entities[0] if main_entities else None | |
for me in main_entities: | |
candidates_set.add(me) | |
candidates_set.add(query) | |
if processed_query and processed_query != query: | |
candidates_set.add(processed_query) | |
if entity_prefix and keywords: | |
first_entity_lower = entity_prefix.lower() | |
for kw in keywords[:3]: | |
if kw not in first_entity_lower and len(kw) > 2: | |
candidates_set.add(f"{entity_prefix} {kw}") | |
keyword_combo_short = " ".join(k for k in keywords[:2] if k not in first_entity_lower and len(k)>2) | |
if keyword_combo_short: candidates_set.add(f"{entity_prefix} {keyword_combo_short}") | |
if len(main_entities) > 1: | |
candidates_set.add(" ".join(main_entities[:2])) | |
if keywords: | |
keyword_combo = " ".join(keywords[:2]) | |
if entity_prefix: | |
candidate_to_add = f"{entity_prefix} {keyword_combo}" | |
if not any(c.lower() == candidate_to_add.lower() for c in candidates_set): | |
candidates_set.add(candidate_to_add) | |
elif not main_entities: | |
candidates_set.add(keyword_combo) | |
ordered_candidates = [] | |
for me in main_entities: | |
if me not in ordered_candidates: ordered_candidates.append(me) | |
for c in list(candidates_set): | |
if c and c.strip() and c not in ordered_candidates: ordered_candidates.append(c) | |
print(f"Generated {len(ordered_candidates)} search candidates for Wikipedia page lookup (entity-prioritized): {ordered_candidates}") | |
return ordered_candidates | |
def _smart_wikipedia_search(self, query_text: str, main_entities_from_query: List[str], keywords_from_query: List[str], processed_query_text: str) -> List[Tuple[str, str]]: | |
candidates = self._generate_search_candidates(query_text, main_entities_from_query, keywords_from_query, processed_query_text) | |
found_pages_data: List[Tuple[str, str]] = [] | |
processed_page_titles: Set[str] = set() | |
for i, candidate_query in enumerate(candidates): | |
print(f"\nProcessing candidate {i+1}/{len(candidates)} for page: '{candidate_query}'") | |
page_object = None | |
final_page_title = None | |
is_candidate_entity_focused = any(me.lower() in candidate_query.lower() for me in main_entities_from_query) if main_entities_from_query else False | |
try: | |
try: | |
page_to_load = candidate_query | |
suggest_mode = True # Default to auto_suggest=True | |
if is_candidate_entity_focused and main_entities_from_query: | |
try: # Attempt precise match first for entity-focused candidates | |
temp_page = wikipedia.page(page_to_load, auto_suggest=False, redirect=True) | |
suggest_mode = False # Flag that precise match worked | |
except (wikipedia.exceptions.PageError, wikipedia.exceptions.DisambiguationError): | |
print(f" - auto_suggest=False failed for entity-focused '{page_to_load}', trying with auto_suggest=True.") | |
# Fallthrough to auto_suggest=True below if this fails | |
if suggest_mode: # If not attempted or failed with auto_suggest=False | |
temp_page = wikipedia.page(page_to_load, auto_suggest=True, redirect=True) | |
final_page_title = temp_page.title | |
if is_candidate_entity_focused and main_entities_from_query: | |
title_matches_main_entity = any(me.lower() in final_page_title.lower() for me in main_entities_from_query) | |
if not title_matches_main_entity: | |
print(f" ! Page title '{final_page_title}' (from entity-focused candidate '{candidate_query}') " | |
f"does not strongly match main query entities: {main_entities_from_query}. Skipping.") | |
continue | |
if final_page_title in processed_page_titles: | |
print(f" ~ Already processed '{final_page_title}'") | |
continue | |
page_object = temp_page | |
print(f" ✓ Direct hit/suggestion for '{candidate_query}' -> '{final_page_title}'") | |
except wikipedia.exceptions.PageError: | |
if i < max(2, len(candidates) // 3) : # Try Wikipedia search for a smaller, more promising subset of candidates | |
print(f" - Direct access failed for '{candidate_query}'. Trying Wikipedia search...") | |
search_results = wikipedia.search(candidate_query, results=1) | |
if not search_results: | |
print(f" - No Wikipedia search results for '{candidate_query}'.") | |
continue | |
search_result_title = search_results[0] | |
try: | |
temp_page = wikipedia.page(search_result_title, auto_suggest=False, redirect=True) # Search results are usually canonical | |
final_page_title = temp_page.title | |
if is_candidate_entity_focused and main_entities_from_query: # Still check against original intent | |
title_matches_main_entity = any(me.lower() in final_page_title.lower() for me in main_entities_from_query) | |
if not title_matches_main_entity: | |
print(f" ! Page title '{final_page_title}' (from search for '{candidate_query}' -> '{search_result_title}') " | |
f"does not strongly match main query entities: {main_entities_from_query}. Skipping.") | |
continue | |
if final_page_title in processed_page_titles: | |
print(f" ~ Already processed '{final_page_title}'") | |
continue | |
page_object = temp_page | |
print(f" ✓ Found via search '{candidate_query}' -> '{search_result_title}' -> '{final_page_title}'") | |
except (wikipedia.exceptions.PageError, wikipedia.exceptions.DisambiguationError) as e_sr: | |
print(f" ! Error/Disambiguation for search result '{search_result_title}': {e_sr}") | |
else: | |
print(f" - Direct access failed for '{candidate_query}'. Skipping further search for this lower priority candidate.") | |
except wikipedia.exceptions.DisambiguationError as de: | |
print(f" ! Disambiguation for '{candidate_query}'. Options: {de.options[:1]}") | |
if de.options: | |
option_title = de.options[0] | |
try: | |
temp_page = wikipedia.page(option_title, auto_suggest=False, redirect=True) | |
final_page_title = temp_page.title | |
if is_candidate_entity_focused and main_entities_from_query: # Check against original intent | |
title_matches_main_entity = any(me.lower() in final_page_title.lower() for me in main_entities_from_query) | |
if not title_matches_main_entity: | |
print(f" ! Page title '{final_page_title}' (from disamb. of '{candidate_query}' -> '{option_title}') " | |
f"does not strongly match main query entities: {main_entities_from_query}. Skipping.") | |
continue | |
if final_page_title in processed_page_titles: | |
print(f" ~ Already processed '{final_page_title}'") | |
continue | |
page_object = temp_page | |
print(f" ✓ Resolved disambiguation '{candidate_query}' -> '{option_title}' -> '{final_page_title}'") | |
except Exception as e_dis_opt: | |
print(f" ! Could not load disambiguation option '{option_title}': {e_dis_opt}") | |
if page_object and final_page_title and (final_page_title not in processed_page_titles): | |
# Extract main text | |
main_text = page_object.content | |
# Extract tables using BeautifulSoup | |
try: | |
html = page_object.html() | |
soup = BeautifulSoup(html, 'html.parser') | |
tables = soup.find_all('table') | |
table_texts = [] | |
for table in tables: | |
rows = table.find_all('tr') | |
table_lines = [] | |
for row in rows: | |
cells = row.find_all(['th', 'td']) | |
cell_texts = [cell.get_text(strip=True) for cell in cells] | |
if cell_texts: | |
table_lines.append(" | ".join(cell_texts)) | |
if table_lines: | |
table_text = "\n".join(table_lines) | |
table_texts.append(table_text) | |
except Exception as e: | |
print(f" !! Error extracting tables for '{final_page_title}': {e}") | |
table_texts = [] | |
# Combine main text and all table texts as separate chunks | |
all_text_chunks = [main_text] + table_texts | |
for chunk in all_text_chunks: | |
found_pages_data.append((chunk, final_page_title)) | |
processed_page_titles.add(final_page_title) | |
print(f" -> Added page '{final_page_title}'. Main text length: {len(main_text)} | Tables extracted: {len(table_texts)}") | |
except Exception as e: | |
print(f" !! Unexpected error processing candidate '{candidate_query}': {e}") | |
if not found_pages_data: print(f"\nCould not find any new, unique, entity-validated Wikipedia pages for query '{query_text}'.") | |
else: print(f"\nFound {len(found_pages_data)} unique, validated page(s) for processing.") | |
return found_pages_data | |
def _enhance_semantic_search(self, query: str, vector_store, main_entities: List[str], keywords: List[str], processed_query: str) -> List[Document]: | |
core_query_parts = set() | |
core_query_parts.add(query) | |
if processed_query != query: core_query_parts.add(processed_query) | |
if keywords: core_query_parts.add(" ".join(keywords[:2])) | |
section_phrases_templates = [] | |
lower_query_terms = set(query.lower().split()) | set(k.lower() for k in keywords) | |
section_keywords_map = { | |
"discography": ["discography", "list of studio albums", "studio album titles and years", "albums by year", "album release dates", "official albums", "complete album list", "albums published"], | |
"biography": ["biography", "life story", "career details", "background history"], | |
"filmography": ["filmography", "list of films", "movie appearances", "acting roles"], | |
} | |
for section_term_key, specific_phrases_list in section_keywords_map.items(): | |
# Check if the key (e.g., "discography") or any of its specific phrases (e.g. "list of studio albums") | |
# are mentioned or implied by the query terms. | |
if section_term_key in lower_query_terms or any(phrase_part in lower_query_terms for phrase_part in section_term_key.split()): | |
section_phrases_templates.extend(specific_phrases_list) | |
# Also check if phrases themselves are in query terms (e.g. query "list of albums by X") | |
for phrase in specific_phrases_list: | |
if phrase in query.lower(): # Check against original query for direct phrase matches | |
section_phrases_templates.extend(specific_phrases_list) # Add all related if one specific is hit | |
break | |
section_phrases_templates = list(dict.fromkeys(section_phrases_templates)) # Deduplicate | |
final_search_queries = set() | |
if main_entities: | |
entity_prefix = main_entities[0] | |
final_search_queries.add(entity_prefix) | |
for part in core_query_parts: | |
final_search_queries.add(f"{entity_prefix} {part}" if entity_prefix.lower() not in part.lower() else part) | |
for phrase_template in section_phrases_templates: | |
final_search_queries.add(f"{entity_prefix} {phrase_template}") | |
if "list of" in phrase_template or "history of" in phrase_template : | |
final_search_queries.add(f"{phrase_template} of {entity_prefix}") | |
else: | |
final_search_queries.update(core_query_parts) | |
final_search_queries.update(section_phrases_templates) | |
deduplicated_queries = list(dict.fromkeys(sq for sq in final_search_queries if sq and sq.strip())) | |
print(f"Generated {len(deduplicated_queries)} semantic search query variants (list-retrieval focused): {deduplicated_queries}") | |
all_results_docs: List[Document] = [] | |
seen_content_hashes: Set[int] = set() | |
k_to_fetch = self.top_k_results * self.semantic_search_candidate_multiplier | |
for search_query_variant in deduplicated_queries: | |
try: | |
results = vector_store.similarity_search_with_score(search_query_variant, k=k_to_fetch) | |
print(f" Semantic search variant '{search_query_variant}' (k={k_to_fetch}) -> {len(results)} raw chunk(s) with scores.") | |
for doc, score in results: # Assuming similarity_search_with_score returns (doc, score) | |
content_hash = hash(doc.page_content[:250]) # Slightly more for hash uniqueness | |
if content_hash not in seen_content_hashes: | |
seen_content_hashes.add(content_hash) | |
doc.metadata['retrieved_by_variant'] = search_query_variant | |
doc.metadata['retrieval_score'] = float(score) # Store score | |
all_results_docs.append(doc) | |
except Exception as e: | |
print(f" Error in semantic search for variant '{search_query_variant}': {e}") | |
# Sort all collected unique results by score (FAISS L2 distance is lower is better) | |
all_results_docs.sort(key=lambda x: x.metadata.get('retrieval_score', float('inf'))) | |
print(f"Collected and re-sorted {len(all_results_docs)} unique chunks from all semantic query variants.") | |
return all_results_docs[:self.top_k_results] | |
def _run(self, query: str = None, search_query: str = None, **kwargs) -> str: | |
if not self._nlp or not self._embedding_model or not self._text_splitter: | |
print("ERROR: WikipediaSearchToolWithFAISS components not initialized properly.") | |
return "Error: Wikipedia tool components not initialized properly. Please check server logs." | |
if not query: | |
query = search_query or kwargs.get('q') or kwargs.get('search_term') | |
try: | |
print(f"\n--- Running {self.name} for query: '{query}' ---") | |
main_entities, keywords, processed_query = self._extract_entities_and_keywords(query) | |
print(f"Initial NLP Analysis - Main Entities: {main_entities}, Keywords: {keywords}, Processed Query: '{processed_query}'") | |
fetched_pages_data = self._smart_wikipedia_search(query, main_entities, keywords, processed_query) | |
if not fetched_pages_data: | |
return (f"Could not find any relevant, entity-validated Wikipedia pages for the query '{query}'. " | |
f"Main entities sought: {main_entities}") | |
all_page_titles = [title for _, title in fetched_pages_data] | |
print(f"\nSuccessfully fetched content for {len(fetched_pages_data)} Wikipedia page(s): {', '.join(all_page_titles)}") | |
all_documents: List[Document] = [] | |
for page_content, page_title in fetched_pages_data: | |
chunks = self._text_splitter.split_text(page_content) | |
if not chunks: | |
print(f"Warning: Could not split content from Wikipedia page '{page_title}' into chunks.") | |
continue | |
for i, chunk_text in enumerate(chunks): | |
all_documents.append(Document(page_content=chunk_text, metadata={ | |
"source_page_title": page_title, | |
"original_query": query, | |
"chunk_index": i # Add chunk index for potential debugging or ordering | |
})) | |
print(f"Split content from '{page_title}' into {len(chunks)} chunks.") | |
if not all_documents: | |
return (f"Could not process content into searchable chunks from the fetched Wikipedia pages " | |
f"({', '.join(all_page_titles)}) for query '{query}'.") | |
print(f"\nTotal document chunks from all pages: {len(all_documents)}") | |
print("Creating FAISS index from content of all fetched pages...") | |
try: | |
vector_store = FAISS.from_documents(all_documents, self._embedding_model) | |
print("FAISS index created successfully.") | |
except Exception as e: | |
return f"Error creating FAISS vector store: {e}" | |
print(f"\nPerforming enhanced semantic search across all collected content...") | |
try: | |
relevant_docs = self._enhance_semantic_search(query, vector_store, main_entities, keywords, processed_query) | |
except Exception as e: | |
return f"Error during semantic search: {e}" | |
if not relevant_docs: | |
return (f"No relevant information found within Wikipedia page(s) '{', '.join(list(dict.fromkeys(all_page_titles)))}' " | |
f"for your query '{query}' using entity-focused semantic search with list retrieval.") | |
unique_sources_in_results = list(dict.fromkeys([doc.metadata.get('source_page_title', 'Unknown Source') for doc in relevant_docs])) | |
result_header = (f"Found {len(relevant_docs)} relevant piece(s) of information from Wikipedia page(s) " | |
f"'{', '.join(unique_sources_in_results)}' for your query '{query}':\n") | |
nlp_summary = (f"[Original Query NLP: Main Entities: {', '.join(main_entities) if main_entities else 'None'}, " | |
f"Keywords: {', '.join(keywords[:5]) if keywords else 'None'}]\n\n") | |
result_details = [] | |
for i, doc in enumerate(relevant_docs): | |
source_info = doc.metadata.get('source_page_title', 'Unknown Source') | |
variant_info = doc.metadata.get('retrieved_by_variant', 'N/A') | |
score_info = doc.metadata.get('retrieval_score', 'N/A') | |
detail = (f"Result {i+1} (source: '{source_info}', score: {score_info:.4f})\n" | |
f"(Retrieved by: '{variant_info}')\n{doc.page_content}") | |
result_details.append(detail) | |
final_result = result_header + nlp_summary + "\n\n---\n\n".join(result_details) | |
print(f"\nReturning {len(relevant_docs)} relevant chunks from {len(set(all_page_titles))} source page(s).") | |
return final_result.strip() | |
except Exception as e: | |
import traceback | |
print(f"Unexpected error in {self.name}: {traceback.format_exc()}") | |
return f"An unexpected error occurred: {str(e)}" | |
class EnhancedYoutubeScreenshotQA(BaseTool): | |
name: str = "enhanced_youtube_screenshot_qa" | |
description: str = ( | |
"Downloads a YouTube video, intelligently extracts screenshots, " | |
"and answers questions using advanced visual QA with semantic analysis. " | |
"Use this tool for questions about the VIDEO or IMAGES in the video," | |
"Input should be a dict with keys: 'youtube_url', 'question', and optional parameters. " | |
#"Optional parameters: 'frame_interval_seconds' (default: 10), 'max_frames' (default: 50), " | |
#"'use_scene_detection' (default: True), 'parallel_processing' (default: True). " | |
"Example: {'youtube_url': 'https://youtube.com/watch?v=xyz', 'question': 'What animals are visible?'}" | |
) | |
# Define Pydantic fields for the attributes we need to set | |
device: Any = Field(default=None, exclude=True) | |
processor_vqa: Any = Field(default=None, exclude=True) | |
model_vqa: Any = Field(default=None, exclude=True) | |
class Config: | |
# Allow arbitrary types (needed for torch.device, model objects) | |
arbitrary_types_allowed = True | |
# Allow extra fields to be set | |
extra = "allow" | |
def __init__(self, **kwargs): | |
super().__init__(**kwargs) | |
# Initialize directories | |
cache_dir = '/tmp/youtube_qa_cache/' | |
video_dir = '/tmp/video/' | |
frames_dir = '/tmp/video_frames/' | |
# Initialize model and device | |
self._initialize_model() | |
# Create directories | |
for dir_path in [cache_dir, video_dir, frames_dir]: | |
os.makedirs(dir_path, exist_ok=True) | |
def _get_config(self, key: str, default_value=None, input_data: Dict[str, Any] = None): | |
"""Get configuration value with fallback to defaults""" | |
defaults = { | |
'frame_interval_seconds': 10, | |
'max_frames': 50, | |
'use_scene_detection': True, | |
'resize_frames': True, | |
'parallel_processing': True, | |
'cache_enabled': True, | |
'quality_threshold': 30.0, | |
'semantic_similarity_threshold': 0.8 | |
} | |
if input_data and key in input_data: | |
return input_data[key] | |
return defaults.get(key, default_value) | |
def _initialize_model(self): | |
"""Initialize BLIP model for VQA with error handling""" | |
try: | |
#self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | |
self.device = torch.device("cpu") | |
print(f"Using device: {self.device}") | |
self.processor_vqa = BlipProcessor.from_pretrained("Salesforce/blip-vqa-base") | |
self.model_vqa = BlipForQuestionAnswering.from_pretrained( | |
"Salesforce/blip-vqa-base" | |
).to(self.device) | |
print("BLIP VQA model loaded successfully") | |
except Exception as e: | |
print(f"Error initializing VQA model: {str(e)}") | |
raise | |
def _get_video_hash(self, url: str) -> str: | |
"""Generate hash for video URL for caching""" | |
return hashlib.md5(url.encode()).hexdigest() | |
def _get_cache_path(self, video_hash: str, cache_type: str) -> str: | |
"""Get cache file path""" | |
cache_dir = '/tmp/youtube_qa_cache/' | |
return os.path.join(cache_dir, f"{video_hash}_{cache_type}") | |
def _load_from_cache(self, cache_path: str, cache_enabled: bool = True) -> Optional[Any]: | |
"""Load data from cache""" | |
if not cache_enabled or not os.path.exists(cache_path): | |
return None | |
try: | |
with open(cache_path, 'r') as f: | |
return json.load(f) | |
except Exception as e: | |
print(f"Error loading cache: {str(e)}") | |
return None | |
def _save_to_cache(self, cache_path: str, data: Any, cache_enabled: bool = True): | |
"""Save data to cache""" | |
if not cache_enabled: | |
return | |
try: | |
with open(cache_path, 'w') as f: | |
json.dump(data, f) | |
except Exception as e: | |
print(f"Error saving cache: {str(e)}") | |
def download_youtube_video(self, url: str, video_hash: str, cache_enabled: bool = True) -> Optional[str]: | |
"""Enhanced YouTube video download with caching""" | |
video_dir = '/tmp/video/' | |
output_filename = f'{video_hash}.mp4' | |
output_path = os.path.join(video_dir, output_filename) | |
# Check cache | |
if cache_enabled and os.path.exists(output_path): | |
print(f"Using cached video: {output_path}") | |
return output_path | |
# Clean directory | |
video_dir = '/tmp/video/' | |
self._clean_directory(video_dir) | |
try: | |
ydl_opts = { | |
'format': 'bestvideo[height<=720][ext=mp4]+bestaudio[ext=m4a]/best[height<=720][ext=mp4]/best', | |
'outtmpl': output_path, | |
'quiet': True, | |
'merge_output_format': 'mp4', | |
'postprocessors': [{ | |
'key': 'FFmpegVideoConvertor', | |
'preferedformat': 'mp4', | |
}] | |
} | |
with yt_dlp.YoutubeDL(ydl_opts) as ydl: | |
ydl.download([url]) | |
if os.path.exists(output_path): | |
print(f"Video downloaded successfully: {output_path}") | |
return output_path | |
else: | |
print("Download completed but file not found") | |
return None | |
except Exception as e: | |
print(f"Error downloading YouTube video: {str(e)}") | |
return None | |
def _clean_directory(self, directory: str): | |
"""Clean directory contents""" | |
if os.path.exists(directory): | |
for filename in os.listdir(directory): | |
file_path = os.path.join(directory, filename) | |
try: | |
if os.path.isfile(file_path) or os.path.islink(file_path): | |
os.unlink(file_path) | |
elif os.path.isdir(file_path): | |
shutil.rmtree(file_path) | |
except Exception as e: | |
print(f'Failed to delete {file_path}. Reason: {e}') | |
def _assess_frame_quality(self, frame: np.ndarray) -> float: | |
"""Assess frame quality using Laplacian variance (blur detection)""" | |
try: | |
gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY) | |
return cv2.Laplacian(gray, cv2.CV_64F).var() | |
except Exception: | |
return 0.0 | |
def _detect_scene_changes(self, video_path: str, threshold: float = 30.0) -> List[int]: | |
"""Detect scene changes in video""" | |
scene_frames = [] | |
try: | |
cap = cv2.VideoCapture(video_path) | |
if not cap.isOpened(): | |
return [] | |
prev_frame = None | |
frame_count = 0 | |
while True: | |
ret, frame = cap.read() | |
if not ret: | |
break | |
if prev_frame is not None: | |
# Calculate histogram difference | |
hist1 = cv2.calcHist([prev_frame], [0, 1, 2], None, [8, 8, 8], [0, 256, 0, 256, 0, 256]) | |
hist2 = cv2.calcHist([frame], [0, 1, 2], None, [8, 8, 8], [0, 256, 0, 256, 0, 256]) | |
diff = cv2.compareHist(hist1, hist2, cv2.HISTCMP_CHISQR) | |
if diff > threshold: | |
scene_frames.append(frame_count) | |
prev_frame = frame.copy() | |
frame_count += 1 | |
cap.release() | |
return scene_frames | |
except Exception as e: | |
print(f"Error in scene detection: {str(e)}") | |
return [] | |
def smart_extract_frames(self, video_path: str, video_hash: str, input_data: Dict[str, Any] = None) -> List[str]: | |
"""Intelligently extract frames with quality filtering and scene detection""" | |
cache_enabled = self._get_config('cache_enabled', True, input_data) | |
cache_path = self._get_cache_path(video_hash, "frames_info.json") | |
cached_info = self._load_from_cache(cache_path, cache_enabled) | |
if cached_info: | |
# Verify cached frames still exist | |
existing_frames = [f for f in cached_info['frame_paths'] if os.path.exists(f)] | |
if len(existing_frames) == len(cached_info['frame_paths']): | |
print(f"Using {len(existing_frames)} cached frames") | |
return existing_frames | |
# Clean frames directory | |
frames_dir = '/tmp/video_frames/' | |
self._clean_directory(frames_dir) | |
try: | |
cap = cv2.VideoCapture(video_path) | |
if not cap.isOpened(): | |
print("Error: Could not open video.") | |
return [] | |
fps = cap.get(cv2.CAP_PROP_FPS) | |
total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) | |
frame_interval_seconds = self._get_config('frame_interval_seconds', 10, input_data) | |
frame_interval = max(1, int(fps * frame_interval_seconds)) | |
print(f"Video info: {total_frames} frames, {fps:.2f} fps") | |
# Get scene change frames if enabled | |
scene_frames = set() | |
use_scene_detection = self._get_config('use_scene_detection', True, input_data) | |
if use_scene_detection: | |
scene_frames = set(self._detect_scene_changes(video_path)) | |
print(f"Detected {len(scene_frames)} scene changes") | |
extracted_frames = [] | |
frame_count = 0 | |
saved_count = 0 | |
max_frames = self._get_config('max_frames', 50, input_data) | |
while True: | |
ret, frame = cap.read() | |
if not ret or saved_count >= max_frames: | |
break | |
# Check if we should extract this frame | |
should_extract = ( | |
frame_count % frame_interval == 0 or | |
frame_count in scene_frames | |
) | |
if should_extract: | |
# Assess frame quality | |
quality = self._assess_frame_quality(frame) | |
quality_threshold = self._get_config('quality_threshold', 30.0, input_data) | |
if quality >= quality_threshold: | |
# Resize frame if enabled | |
resize_frames = self._get_config('resize_frames', True, input_data) | |
if resize_frames: | |
height, width = frame.shape[:2] | |
if width > 800: | |
scale = 800 / width | |
new_width = 800 | |
new_height = int(height * scale) | |
frame = cv2.resize(frame, (new_width, new_height)) | |
frame_filename = os.path.join( | |
frames_dir, | |
f"frame_{frame_count:06d}_q{quality:.1f}.jpg" | |
) | |
if cv2.imwrite(frame_filename, frame): | |
extracted_frames.append(frame_filename) | |
saved_count += 1 | |
print(f"Extracted frame {saved_count}/{max_frames} " | |
f"(quality: {quality:.1f})") | |
frame_count += 1 | |
cap.release() | |
# Cache frame information | |
frame_info = { | |
'frame_paths': extracted_frames, | |
'extraction_time': time.time(), | |
'total_frames_processed': frame_count, | |
'frames_extracted': len(extracted_frames) | |
} | |
self._save_to_cache(cache_path, frame_info, cache_enabled) | |
print(f"Successfully extracted {len(extracted_frames)} high-quality frames") | |
return extracted_frames | |
except Exception as e: | |
print(f"Exception during frame extraction: {e}") | |
return [] | |
def _answer_question_on_frame(self, frame_path: str, question: str) -> Tuple[str, float]: | |
"""Answer question on single frame with confidence scoring""" | |
try: | |
image = Image.open(frame_path).convert('RGB') | |
inputs = self.processor_vqa(image, question, return_tensors="pt").to(self.device) | |
with torch.no_grad(): | |
outputs = self.model_vqa.generate(**inputs, output_scores=True, return_dict_in_generate=True) | |
answer = self.processor_vqa.decode(outputs.sequences[0], skip_special_tokens=True) | |
# Calculate confidence (simplified - you might want to use actual model confidence) | |
confidence = 1.0 # Placeholder - BLIP doesn't directly provide confidence | |
return answer, confidence | |
except Exception as e: | |
print(f"Error processing frame {frame_path}: {str(e)}") | |
return "Error processing this frame", 0.0 | |
def _process_frames_parallel(self, frame_files: List[str], question: str, input_data: Dict[str, Any] = None) -> List[Tuple[str, str, float]]: | |
"""Process frames in parallel""" | |
results = [] | |
parallel_processing = self._get_config('parallel_processing', True, input_data) | |
if parallel_processing: | |
with ThreadPoolExecutor(max_workers=min(4, len(frame_files))) as executor: | |
future_to_frame = { | |
executor.submit(self._answer_question_on_frame, frame_path, question): frame_path | |
for frame_path in frame_files | |
} | |
for future in as_completed(future_to_frame): | |
frame_path = future_to_frame[future] | |
try: | |
answer, confidence = future.result() | |
results.append((frame_path, answer, confidence)) | |
print(f"Processed {os.path.basename(frame_path)}: {answer} (conf: {confidence:.2f})") | |
except Exception as e: | |
print(f"Error processing {frame_path}: {str(e)}") | |
results.append((frame_path, "Error", 0.0)) | |
else: | |
for frame_path in frame_files: | |
answer, confidence = self._answer_question_on_frame(frame_path, question) | |
results.append((frame_path, answer, confidence)) | |
print(f"Processed {os.path.basename(frame_path)}: {answer} (conf: {confidence:.2f})") | |
return results | |
def _cluster_similar_answers(self, answers: List[str], input_data: Dict[str, Any] = None) -> Dict[str, List[str]]: | |
"""Cluster semantically similar answers""" | |
if len(answers) <= 1: | |
return {answers[0]: answers} if answers else {} | |
try: | |
# First try with standard TF-IDF settings | |
vectorizer = TfidfVectorizer( | |
stop_words='english', | |
lowercase=True, | |
min_df=1, # Include words that appear in at least 1 document | |
max_df=1.0 # Include words that appear in up to 100% of documents | |
) | |
tfidf_matrix = vectorizer.fit_transform(answers) | |
# Check if we have any features after TF-IDF | |
if tfidf_matrix.shape[1] == 0: | |
raise ValueError("No features after TF-IDF processing") | |
# Calculate cosine similarity | |
similarity_matrix = cosine_similarity(tfidf_matrix) | |
# Cluster similar answers | |
clusters = defaultdict(list) | |
used = set() | |
semantic_similarity_threshold = self._get_config('semantic_similarity_threshold', 0.8, input_data) | |
for i, answer in enumerate(answers): | |
if i in used: | |
continue | |
cluster_key = answer | |
clusters[cluster_key].append(answer) | |
used.add(i) | |
# Find similar answers | |
for j in range(i + 1, len(answers)): | |
if j not in used and similarity_matrix[i][j] >= semantic_similarity_threshold: | |
clusters[cluster_key].append(answers[j]) | |
used.add(j) | |
return dict(clusters) | |
except (ValueError, Exception) as e: | |
print(f"Error in semantic clustering: {str(e)}") | |
# Fallback 1: Try without stop words filtering | |
try: | |
print("Attempting clustering without stop word filtering...") | |
vectorizer_no_stop = TfidfVectorizer( | |
lowercase=True, | |
min_df=1, | |
token_pattern=r'\b\w+\b' # Match any word | |
) | |
tfidf_matrix = vectorizer_no_stop.fit_transform(answers) | |
if tfidf_matrix.shape[1] > 0: | |
similarity_matrix = cosine_similarity(tfidf_matrix) | |
clusters = defaultdict(list) | |
used = set() | |
semantic_similarity_threshold = self._get_config('semantic_similarity_threshold', 0.8, input_data) | |
for i, answer in enumerate(answers): | |
if i in used: | |
continue | |
cluster_key = answer | |
clusters[cluster_key].append(answer) | |
used.add(i) | |
for j in range(i + 1, len(answers)): | |
if j not in used and similarity_matrix[i][j] >= semantic_similarity_threshold: | |
clusters[cluster_key].append(answers[j]) | |
used.add(j) | |
return dict(clusters) | |
except Exception as e2: | |
print(f"Fallback clustering also failed: {str(e2)}") | |
# Fallback 2: Simple string-based clustering | |
print("Using simple string-based clustering...") | |
return self._simple_string_cluster(answers) | |
def _simple_string_cluster(self, answers: List[str]) -> Dict[str, List[str]]: | |
"""Simple string-based clustering fallback""" | |
clusters = defaultdict(list) | |
# Normalize answers for comparison | |
normalized_answers = {} | |
for answer in answers: | |
normalized = answer.lower().strip() | |
normalized_answers[answer] = normalized | |
used = set() | |
for i, answer in enumerate(answers): | |
if answer in used: | |
continue | |
cluster_key = answer | |
clusters[cluster_key].append(answer) | |
used.add(answer) | |
# Find similar answers using simple string similarity | |
for j, other_answer in enumerate(answers[i+1:], i+1): | |
if other_answer in used: | |
continue | |
# Check for exact match after normalization | |
if normalized_answers[answer] == normalized_answers[other_answer]: | |
clusters[cluster_key].append(other_answer) | |
used.add(other_answer) | |
# Alternatively, check if one string contains the other | |
elif (normalized_answers[answer] in normalized_answers[other_answer] or | |
normalized_answers[other_answer] in normalized_answers[answer]): | |
clusters[cluster_key].append(other_answer) | |
used.add(other_answer) | |
return dict(clusters) | |
def _analyze_temporal_patterns(self, results: List[Tuple[str, str, float]]) -> Dict[str, Any]: | |
"""Analyze temporal patterns in answers""" | |
try: | |
# Sort by frame number | |
def get_frame_number(frame_path): | |
match = re.search(r'frame_(\d+)', os.path.basename(frame_path)) | |
return int(match.group(1)) if match else 0 | |
sorted_results = sorted(results, key=lambda x: get_frame_number(x[0])) | |
# Analyze answer changes over time | |
answers_timeline = [result[1] for result in sorted_results] | |
changes = [] | |
for i in range(1, len(answers_timeline)): | |
if answers_timeline[i] != answers_timeline[i-1]: | |
changes.append({ | |
'frame_index': i, | |
'from_answer': answers_timeline[i-1], | |
'to_answer': answers_timeline[i] | |
}) | |
return { | |
'total_changes': len(changes), | |
'change_points': changes, | |
'stability_ratio': 1 - (len(changes) / max(1, len(answers_timeline) - 1)), | |
'answers_timeline': answers_timeline | |
} | |
except Exception as e: | |
print(f"Error in temporal analysis: {str(e)}") | |
return {'error': str(e)} | |
def analyze_video_question(self, frame_files: List[str], question: str, input_data: Dict[str, Any] = None) -> Dict[str, Any]: | |
"""Comprehensive video question analysis""" | |
if not frame_files: | |
return { | |
"final_answer": "No frames available for analysis.", | |
"confidence": 0.0, | |
"frame_count": 0, | |
"error": "No valid frames found" | |
} | |
# Process all frames | |
print(f"Processing {len(frame_files)} frames...") | |
results = self._process_frames_parallel(frame_files, question, input_data) | |
if not results: | |
return { | |
"final_answer": "Could not analyze any frames successfully.", | |
"confidence": 0.0, | |
"frame_count": 0, | |
"error": "Frame processing failed" | |
} | |
# Extract answers and confidences | |
answers = [result[1] for result in results if result[1] != "Error"] | |
confidences = [result[2] for result in results if result[1] != "Error"] | |
# Calculate statistical summary on numeric answers | |
numeric_answers = [] | |
for answer in answers: | |
try: | |
# Try to convert answer to float | |
numeric_value = float(answer) | |
numeric_answers.append(numeric_value) | |
except (ValueError, TypeError): | |
# Skip non-numeric answers | |
pass | |
if numeric_answers: | |
stats = { | |
"minimum": float(np.min(numeric_answers)), | |
"maximum": float(np.max(numeric_answers)), | |
"range": float(np.max(numeric_answers) - np.min(numeric_answers)), | |
"mean": float(np.mean(numeric_answers)), | |
"median": float(np.median(numeric_answers)), | |
"count": len(numeric_answers), | |
"data_type": "answers" | |
} | |
elif confidences: | |
# Fallback to confidence statistics if no numeric answers | |
stats = { | |
"minimum": float(np.min(confidences)), | |
"maximum": float(np.max(confidences)), | |
"range": float(np.max(confidences) - np.min(confidences)), | |
"mean": float(np.mean(confidences)), | |
"median": float(np.median(confidences)), | |
"count": len(confidences), | |
"data_type": "confidences" | |
} | |
else: | |
stats = { | |
"minimum": 0.0, | |
"maximum": 0.0, | |
"range": 0.0, | |
"mean": 0.0, | |
"median": 0.0, | |
"count": 0, | |
"data_type": "none", | |
"note": "No numeric results available for statistical summary" | |
} | |
if not answers: | |
return { | |
"final_answer": "All frame processing failed.", | |
"confidence": 0.0, | |
"frame_count": len(frame_files), | |
"error": "No successful frame analysis" | |
} | |
# Cluster similar answers | |
answer_clusters = self._cluster_similar_answers(answers, input_data) | |
# Find most common cluster | |
largest_cluster = max(answer_clusters.items(), key=lambda x: len(x[1])) | |
most_common_answer = largest_cluster[0] | |
cluster_size = len(largest_cluster[1]) | |
# Calculate weighted confidence | |
answer_counts = Counter(answers) | |
total_answers = len(answers) | |
frequency_confidence = answer_counts[most_common_answer] / total_answers | |
avg_confidence = np.mean(confidences) if confidences else 0.0 | |
final_confidence = (frequency_confidence * 0.7) + (avg_confidence * 0.3) | |
# Temporal analysis | |
temporal_analysis = self._analyze_temporal_patterns(results) | |
return { | |
"final_answer": most_common_answer, | |
"confidence": final_confidence, | |
"frame_count": len(frame_files), | |
"successful_analyses": len(answers), | |
"answer_distribution": dict(answer_counts), | |
"semantic_clusters": {k: len(v) for k, v in answer_clusters.items()}, | |
"temporal_analysis": temporal_analysis, | |
"average_model_confidence": avg_confidence, | |
"frequency_confidence": frequency_confidence, | |
"statistical_summary": stats | |
} | |
#def _run(self, query: Dict[str, Any]) -> str: | |
def _run(self, youtube_url, question, **kwargs) -> str: | |
"""Enhanced main execution method""" | |
#ipdb.set_trace() | |
#input_data = query | |
#youtube_url = input_data.get("youtube_url") | |
#question = input_data.get("question") | |
input_data = { | |
'youtube_url': youtube_url, | |
'question': question | |
} | |
if not youtube_url or not question: | |
return "Error: Input must include 'youtube_url' and 'question'." | |
try: | |
# Generate video hash for caching | |
video_hash = self._get_video_hash(youtube_url) | |
# Step 1: Download video | |
print(f"Downloading YouTube video from {youtube_url}...") | |
cache_enabled = self._get_config('cache_enabled', True, input_data) | |
video_path = self.download_youtube_video(youtube_url, video_hash, cache_enabled) | |
if not video_path or not os.path.exists(video_path): | |
return "Error: Failed to download the YouTube video." | |
# Step 2: Smart frame extraction | |
print(f"Extracting frames with smart selection...") | |
frame_files = self.smart_extract_frames(video_path, video_hash, input_data) | |
if not frame_files: | |
return "Error: Failed to extract frames from the video." | |
# Step 3: Comprehensive analysis | |
print(f"Analyzing {len(frame_files)} frames for question: '{question}'") | |
analysis_result = self.analyze_video_question(frame_files, question, input_data) | |
if analysis_result.get("error"): | |
return f"Error: {analysis_result['error']}" | |
# Format comprehensive result - Fixed the reference to stats | |
result = f""" | |
📊 **ANALYSIS SUMMARY**: | |
• Confidence Score: {analysis_result['confidence']:.2%} | |
• Frames Analyzed: {analysis_result['successful_analyses']}/{analysis_result['frame_count']} | |
• Answer Consistency: {analysis_result['temporal_analysis'].get('stability_ratio', 0):.2%} | |
📈 **ANSWER DISTRIBUTION**: | |
{chr(10).join([f"• {answer}: {count} frames" for answer, count in analysis_result['answer_distribution'].items()])} | |
🔍 **SEMANTIC CLUSTERS**: | |
{chr(10).join([f"• '{cluster}': {count} similar answers" for cluster, count in analysis_result['semantic_clusters'].items()])} | |
⏱️ **TEMPORAL ANALYSIS**: | |
• Answer Changes: {analysis_result['temporal_analysis'].get('total_changes', 0)} | |
• Stability: {analysis_result['temporal_analysis'].get('stability_ratio', 0):.2%} | |
📊 **STATISTICAL SUMMARY**: | |
• Minimum: {analysis_result['statistical_summary']['minimum']:.2f} | |
• Maximum: {analysis_result['statistical_summary']['maximum']:.2f} | |
• Mean: {analysis_result['statistical_summary']['mean']:.2f} | |
• Median: {analysis_result['statistical_summary']['median']:.2f} | |
• Range: {analysis_result['statistical_summary']['range']:.2f} | |
🎯 **CONFIDENCE BREAKDOWN**: | |
• Frequency-based: {analysis_result['frequency_confidence']:.2%} | |
• Model-based: {analysis_result['average_model_confidence']:.2%} | |
• Combined: {analysis_result['confidence']:.2%} | |
""".strip() | |
return result | |
except Exception as e: | |
return f"Error during video analysis: {str(e)}" | |
# Initialize the enhanced tool | |
def create_enhanced_youtube_qa_tool(**kwargs): | |
"""Factory function to create the enhanced tool with custom parameters""" | |
return EnhancedYoutubeScreenshotQA(**kwargs) | |
# Example of creating the tool instance: | |
# wikipedia_tool_faiss = WikipediaSearchToolWithFAISS() | |
# To use this new tool in your agent, you would replace the old | |
# `wikipedia_tool` instance with `wikipedia_tool_faiss` in your `tools` list. | |
# For example: | |
# tools = [wikipedia_tool_faiss, search_tool] | |
# Create tool instances | |
#wikipedia_tool = WikipediaSearchTool() | |
# --- Define Call LLM function --- | |
# 3. Improved LLM call with memory management | |
class YouTubeTranscriptExtractor(BaseTool): | |
name: str = "youtube_transcript_extractor" | |
description: str = ( | |
"Downloads a YouTube video and extracts the complete audio transcript using speech recognition with speaker identification. " | |
"Use this tool when you need the AUDIO or DIALOGUE or sound from a YouTube video with speaker tags," | |
"Input should be a dict with keys: 'youtube_url' and optional parameters. " | |
"Optional parameters: 'language' (default: 'en-US'), 'chunk_length_ms' (default: 30000), " | |
"'silence_thresh' (default: -40), 'use_enhanced_model' (default: True), 'audio_quality' (default: 'best'), " | |
"'enable_speaker_id' (default: True), 'max_speakers' (default: 5), 'speaker_min_duration' (default: 2.0). " | |
"Example: {'youtube_url': 'https://youtube.com/watch?v=xyz', 'language': 'en-US', 'enable_speaker_id': True}" | |
) | |
# Define Pydantic fields for the attributes we need to set | |
recognizer: Any = Field(default=None, exclude=True) | |
class Config: | |
# Allow arbitrary types | |
arbitrary_types_allowed = True | |
# Allow extra fields to be set | |
extra = "allow" | |
def __init__(self, **kwargs): | |
super().__init__(**kwargs) | |
# Initialize directories | |
cache_dir = '/tmp/youtube_transcript_cache/' | |
audio_dir = '/tmp/audio/' | |
chunks_dir = '/tmp/audio_chunks/' | |
# Initialize speech recognizer | |
self.recognizer = sr.Recognizer() | |
# Create directories | |
for dir_path in [cache_dir, audio_dir, chunks_dir]: | |
os.makedirs(dir_path, exist_ok=True) | |
def _get_config(self, key: str, default_value=None, input_data: Dict[str, Any] = None): | |
"""Get configuration value with fallback to defaults""" | |
defaults = { | |
'language': 'en-US', | |
'chunk_length_ms': 30000, # 30 seconds | |
'silence_thresh': -40, # dB | |
'use_enhanced_model': True, | |
'audio_quality': 'best', | |
'cache_enabled': True, | |
'parallel_processing': True, | |
'overlap_ms': 1000, # 1 second overlap between chunks | |
'min_silence_len': 500, # minimum silence length to split on | |
'energy_threshold': 4000, # recognizer energy threshold | |
'pause_threshold': 0.8, # recognizer pause threshold | |
'enable_speaker_id': True, # enable speaker identification | |
'max_speakers': 5, # maximum number of speakers to identify | |
'speaker_min_duration': 2.0, # minimum duration (seconds) for speaker segment | |
'speaker_confidence_threshold': 0.6, # confidence threshold for speaker assignment | |
'voice_activity_threshold': 0.01 # threshold for voice activity detection | |
} | |
if input_data and key in input_data: | |
return input_data[key] | |
return defaults.get(key, default_value) | |
def _get_video_hash(self, url: str) -> str: | |
"""Generate hash for video URL for caching""" | |
return hashlib.md5(url.encode()).hexdigest() | |
def _get_cache_path(self, video_hash: str, cache_type: str) -> str: | |
"""Get cache file path""" | |
cache_dir = '/tmp/youtube_transcript_cache/' | |
return os.path.join(cache_dir, f"{video_hash}_{cache_type}") | |
def _load_from_cache(self, cache_path: str, cache_enabled: bool = True) -> Optional[Any]: | |
"""Load data from cache""" | |
if not cache_enabled or not os.path.exists(cache_path): | |
return None | |
try: | |
with open(cache_path, 'r', encoding='utf-8') as f: | |
return json.load(f) | |
except Exception as e: | |
print(f"Error loading cache: {str(e)}") | |
return None | |
def _save_to_cache(self, cache_path: str, data: Any, cache_enabled: bool = True): | |
"""Save data to cache""" | |
if not cache_enabled: | |
return | |
try: | |
with open(cache_path, 'w', encoding='utf-8') as f: | |
json.dump(data, f, ensure_ascii=False, indent=2) | |
except Exception as e: | |
print(f"Error saving cache: {str(e)}") | |
def download_youtube_audio(self, url: str, video_hash: str, input_data: Dict[str, Any] = None) -> Optional[str]: | |
"""Download YouTube video as audio file""" | |
audio_dir = '/tmp/audio/' | |
audio_quality = self._get_config('audio_quality', 'best', input_data) | |
output_filename = f'{video_hash}.wav' | |
output_path = os.path.join(audio_dir, output_filename) | |
# Check cache | |
cache_enabled = self._get_config('cache_enabled', True, input_data) | |
if cache_enabled and os.path.exists(output_path): | |
print(f"Using cached audio: {output_path}") | |
return output_path | |
# Clean directory | |
self._clean_directory(audio_dir) | |
try: | |
# First download as mp4/webm | |
temp_video_path = os.path.join(audio_dir, f'{video_hash}_temp.%(ext)s') | |
ydl_opts = { | |
'format': 'bestaudio/best' if audio_quality == 'best' else 'worstaudio/worst', | |
'outtmpl': temp_video_path, | |
'quiet': True, | |
'extractaudio': True, | |
'audioformat': 'wav', | |
'audioquality': '192K' if audio_quality == 'best' else '64K', | |
} | |
with yt_dlp.YoutubeDL(ydl_opts) as ydl: | |
ydl.download([url]) | |
# Find the downloaded file | |
temp_files = glob.glob(os.path.join(audio_dir, f'{video_hash}_temp.*')) | |
if not temp_files: | |
print("No temporary audio file found") | |
return None | |
temp_file = temp_files[0] | |
# Convert to WAV if not already | |
if not temp_file.endswith('.wav'): | |
try: | |
audio = AudioSegment.from_file(temp_file) | |
audio.export(output_path, format="wav") | |
os.remove(temp_file) # Clean up temp file | |
except Exception as e: | |
print(f"Error converting audio: {str(e)}") | |
# Try to rename if it's already the right format | |
if os.path.exists(temp_file): | |
os.rename(temp_file, output_path) | |
else: | |
os.rename(temp_file, output_path) | |
if os.path.exists(output_path): | |
print(f"Audio extracted successfully: {output_path}") | |
return output_path | |
else: | |
print("Audio extraction completed but file not found") | |
return None | |
except Exception as e: | |
print(f"Error downloading YouTube audio: {str(e)}") | |
return None | |
def _clean_directory(self, directory: str): | |
"""Clean directory contents""" | |
if os.path.exists(directory): | |
for filename in os.listdir(directory): | |
file_path = os.path.join(directory, filename) | |
try: | |
if os.path.isfile(file_path) or os.path.islink(file_path): | |
os.unlink(file_path) | |
elif os.path.isdir(file_path): | |
shutil.rmtree(file_path) | |
except Exception as e: | |
print(f'Failed to delete {file_path}. Reason: {e}') | |
def _extract_voice_features(self, audio_path: str) -> Optional[np.ndarray]: | |
"""Extract voice features for speaker identification using librosa""" | |
try: | |
# Load audio with librosa | |
y, sr = librosa.load(audio_path, sr=None) | |
# Extract MFCC features (commonly used for speaker identification) | |
mfccs = librosa.feature.mfcc(y=y, sr=sr, n_mfcc=13) | |
# Extract additional features | |
spectral_centroids = librosa.feature.spectral_centroid(y=y, sr=sr) | |
spectral_rolloff = librosa.feature.spectral_rolloff(y=y, sr=sr) | |
zero_crossing_rate = librosa.feature.zero_crossing_rate(y) | |
# Combine features and take mean across time | |
features = np.concatenate([ | |
np.mean(mfccs, axis=1), | |
np.mean(spectral_centroids), | |
np.mean(spectral_rolloff), | |
np.mean(zero_crossing_rate) | |
]) | |
return features | |
except Exception as e: | |
print(f"Error extracting voice features from {audio_path}: {str(e)}") | |
return None | |
def _detect_voice_activity(self, audio_path: str, input_data: Dict[str, Any] = None) -> List[Tuple[float, float]]: | |
"""Detect voice activity in audio chunk""" | |
try: | |
y, sr = librosa.load(audio_path, sr=None) | |
# Simple voice activity detection based on energy | |
frame_length = int(0.025 * sr) # 25ms frames | |
hop_length = int(0.010 * sr) # 10ms hop | |
# Calculate short-time energy | |
energy = [] | |
for i in range(0, len(y) - frame_length, hop_length): | |
frame = y[i:i + frame_length] | |
energy.append(np.sum(frame ** 2)) | |
energy = np.array(energy) | |
threshold = self._get_config('voice_activity_threshold', 0.01, input_data) | |
# Find voice segments | |
voice_frames = energy > (np.max(energy) * threshold) | |
# Convert frame indices to time segments | |
voice_segments = [] | |
in_voice = False | |
start_time = 0 | |
for i, is_voice in enumerate(voice_frames): | |
time_sec = i * hop_length / sr | |
if is_voice and not in_voice: | |
start_time = time_sec | |
in_voice = True | |
elif not is_voice and in_voice: | |
voice_segments.append((start_time, time_sec)) | |
in_voice = False | |
# Close last segment if needed | |
if in_voice: | |
voice_segments.append((start_time, len(y) / sr)) | |
return voice_segments | |
except Exception as e: | |
print(f"Error in voice activity detection: {str(e)}") | |
return [(0, librosa.get_duration(filename=audio_path))] | |
def _split_audio_intelligent(self, audio_path: str, input_data: Dict[str, Any] = None) -> List[Dict[str, Any]]: | |
"""Split audio into chunks intelligently based on silence and voice activity""" | |
chunks_dir = '/tmp/audio_chunks/' | |
self._clean_directory(chunks_dir) | |
try: | |
# Load audio | |
audio = AudioSegment.from_wav(audio_path) | |
# Get configuration | |
chunk_length_ms = self._get_config('chunk_length_ms', 30000, input_data) | |
silence_thresh = self._get_config('silence_thresh', -40, input_data) | |
min_silence_len = self._get_config('min_silence_len', 500, input_data) | |
overlap_ms = self._get_config('overlap_ms', 1000, input_data) | |
# First try to split on silence | |
chunks = split_on_silence( | |
audio, | |
min_silence_len=min_silence_len, | |
silence_thresh=silence_thresh, | |
keep_silence=True | |
) | |
# If no silence-based splits or chunks too large, split by time | |
if not chunks or any(len(chunk) > chunk_length_ms * 2 for chunk in chunks): | |
print("Using time-based splitting...") | |
chunks = [] | |
for i in range(0, len(audio), chunk_length_ms - overlap_ms): | |
chunk = audio[i:i + chunk_length_ms] | |
if len(chunk) > 1000: # Only add chunks longer than 1 second | |
chunks.append(chunk) | |
# Save chunks and create metadata | |
chunk_data = [] | |
for i, chunk in enumerate(chunks): | |
if len(chunk) < 1000: # Skip very short chunks | |
continue | |
chunk_filename = os.path.join(chunks_dir, f"chunk_{i:04d}.wav") | |
chunk.export(chunk_filename, format="wav") | |
# Calculate timing information | |
start_time = sum(len(chunks[j]) for j in range(i)) / 1000.0 # in seconds | |
duration = len(chunk) / 1000.0 # in seconds | |
chunk_info = { | |
'filename': chunk_filename, | |
'index': i, | |
'start_time': start_time, | |
'duration': duration, | |
'end_time': start_time + duration | |
} | |
chunk_data.append(chunk_info) | |
print(f"Split audio into {len(chunk_data)} chunks") | |
return chunk_data | |
except Exception as e: | |
print(f"Error splitting audio: {str(e)}") | |
# Fallback: return original file | |
return [{ | |
'filename': audio_path, | |
'index': 0, | |
'start_time': 0, | |
'duration': len(AudioSegment.from_wav(audio_path)) / 1000.0, | |
'end_time': len(AudioSegment.from_wav(audio_path)) / 1000.0 | |
}] | |
def _transcribe_audio_chunk(self, chunk_info: Dict[str, Any], input_data: Dict[str, Any] = None) -> Dict[str, Any]: | |
"""Transcribe a single audio chunk""" | |
chunk_path = chunk_info['filename'] | |
try: | |
language = self._get_config('language', 'en-US', input_data) | |
# Configure recognizer | |
self.recognizer.energy_threshold = self._get_config('energy_threshold', 4000, input_data) | |
self.recognizer.pause_threshold = self._get_config('pause_threshold', 0.8, input_data) | |
with sr.AudioFile(chunk_path) as source: | |
# Adjust for ambient noise | |
self.recognizer.adjust_for_ambient_noise(source, duration=0.5) | |
audio_data = self.recognizer.record(source) | |
# Try Google Speech Recognition first (most accurate) | |
try: | |
text = self.recognizer.recognize_google(audio_data, language=language) | |
result = { | |
'text': text, | |
'confidence': 1.0, # Google doesn't provide confidence | |
'method': 'google', | |
'chunk': os.path.basename(chunk_path), | |
'start_time': chunk_info['start_time'], | |
'end_time': chunk_info['end_time'], | |
'duration': chunk_info['duration'], | |
'index': chunk_info['index'] | |
} | |
# Extract voice features if speaker ID is enabled | |
if self._get_config('enable_speaker_id', True, input_data): | |
features = self._extract_voice_features(chunk_path) | |
result['voice_features'] = features.tolist() if features is not None else None | |
return result | |
except sr.UnknownValueError: | |
# Try alternative recognition methods | |
try: | |
# Try with alternative language detection | |
text = self.recognizer.recognize_google(audio_data) | |
result = { | |
'text': text, | |
'confidence': 0.8, # Lower confidence for language mismatch | |
'method': 'google_auto', | |
'chunk': os.path.basename(chunk_path), | |
'start_time': chunk_info['start_time'], | |
'end_time': chunk_info['end_time'], | |
'duration': chunk_info['duration'], | |
'index': chunk_info['index'] | |
} | |
if self._get_config('enable_speaker_id', True, input_data): | |
features = self._extract_voice_features(chunk_path) | |
result['voice_features'] = features.tolist() if features is not None else None | |
return result | |
except sr.UnknownValueError: | |
return { | |
'text': '[INAUDIBLE]', | |
'confidence': 0.0, | |
'method': 'failed', | |
'chunk': os.path.basename(chunk_path), | |
'start_time': chunk_info['start_time'], | |
'end_time': chunk_info['end_time'], | |
'duration': chunk_info['duration'], | |
'index': chunk_info['index'], | |
'voice_features': None | |
} | |
except sr.RequestError as e: | |
print(f"Google Speech Recognition error: {e}") | |
return { | |
'text': '[RECOGNITION_ERROR]', | |
'confidence': 0.0, | |
'method': 'error', | |
'chunk': os.path.basename(chunk_path), | |
'start_time': chunk_info['start_time'], | |
'end_time': chunk_info['end_time'], | |
'duration': chunk_info['duration'], | |
'index': chunk_info['index'], | |
'error': str(e), | |
'voice_features': None | |
} | |
except Exception as e: | |
print(f"Error transcribing chunk {chunk_path}: {str(e)}") | |
return { | |
'text': '[ERROR]', | |
'confidence': 0.0, | |
'method': 'error', | |
'chunk': os.path.basename(chunk_path), | |
'start_time': chunk_info.get('start_time', 0), | |
'end_time': chunk_info.get('end_time', 0), | |
'duration': chunk_info.get('duration', 0), | |
'index': chunk_info.get('index', 0), | |
'error': str(e), | |
'voice_features': None | |
} | |
def _identify_speakers(self, transcript_results: List[Dict[str, Any]], input_data: Dict[str, Any] = None) -> List[Dict[str, Any]]: | |
"""Identify speakers using voice features clustering""" | |
enable_speaker_id = self._get_config('enable_speaker_id', True, input_data) | |
if not enable_speaker_id: | |
# Add default speaker tags | |
for result in transcript_results: | |
result['speaker_id'] = 'SPEAKER_1' | |
result['speaker_confidence'] = 1.0 | |
return transcript_results | |
try: | |
# Filter results with valid voice features and text | |
valid_results = [] | |
features_list = [] | |
for result in transcript_results: | |
if (result.get('voice_features') is not None and | |
result['text'] not in ['[INAUDIBLE]', '[RECOGNITION_ERROR]', '[ERROR]', '[PROCESSING_ERROR]']): | |
valid_results.append(result) | |
features_list.append(result['voice_features']) | |
if len(features_list) < 2: | |
# Not enough data for clustering | |
for result in transcript_results: | |
result['speaker_id'] = 'SPEAKER_1' | |
result['speaker_confidence'] = 1.0 | |
return transcript_results | |
# Normalize features | |
features_array = np.array(features_list) | |
scaler = StandardScaler() | |
normalized_features = scaler.fit_transform(features_array) | |
# Determine optimal number of speakers | |
max_speakers = min(self._get_config('max_speakers', 5, input_data), len(features_list)) | |
# Use elbow method to find optimal clusters (simplified) | |
best_k = 1 | |
if len(features_list) > 1: | |
best_score = float('inf') | |
for k in range(1, min(max_speakers + 1, len(features_list) + 1)): | |
try: | |
kmeans = KMeans(n_clusters=k, random_state=42, n_init=10) | |
labels = kmeans.fit_predict(normalized_features) | |
if k > 1: | |
score = kmeans.inertia_ | |
if score < best_score: | |
best_score = score | |
best_k = k | |
except: | |
continue | |
# Don't use too many clusters for short audio | |
if len(features_list) < 10: | |
best_k = min(best_k, 2) | |
# Perform final clustering | |
kmeans = KMeans(n_clusters=best_k, random_state=42, n_init=10) | |
speaker_labels = kmeans.fit_predict(normalized_features) | |
# Calculate speaker assignment confidence | |
distances = kmeans.transform(normalized_features) | |
confidences = [] | |
for i, label in enumerate(speaker_labels): | |
# Confidence based on distance to assigned cluster vs. nearest other cluster | |
dist_to_assigned = distances[i][label] | |
other_distances = np.delete(distances[i], label) | |
if len(other_distances) > 0: | |
dist_to_nearest_other = np.min(other_distances) | |
confidence = max(0.1, min(1.0, dist_to_nearest_other / (dist_to_assigned + 1e-6))) | |
else: | |
confidence = 1.0 | |
confidences.append(confidence) | |
# Assign speaker IDs back to results | |
valid_idx = 0 | |
speaker_duration = {} # Track duration per speaker | |
for result in transcript_results: | |
if (result.get('voice_features') is not None and | |
result['text'] not in ['[INAUDIBLE]', '[RECOGNITION_ERROR]', '[ERROR]', '[PROCESSING_ERROR]']): | |
speaker_label = speaker_labels[valid_idx] | |
confidence = confidences[valid_idx] | |
# Filter by confidence threshold | |
conf_threshold = self._get_config('speaker_confidence_threshold', 0.6, input_data) | |
if confidence < conf_threshold: | |
speaker_id = 'SPEAKER_UNKNOWN' | |
else: | |
speaker_id = f'SPEAKER_{speaker_label + 1}' | |
result['speaker_id'] = speaker_id | |
result['speaker_confidence'] = confidence | |
# Track speaker duration | |
if speaker_id in speaker_duration: | |
speaker_duration[speaker_id] += result['duration'] | |
else: | |
speaker_duration[speaker_id] = result['duration'] | |
valid_idx += 1 | |
else: | |
# Handle invalid results | |
result['speaker_id'] = 'SPEAKER_UNKNOWN' | |
result['speaker_confidence'] = 0.0 | |
# Filter out speakers with insufficient duration | |
min_duration = self._get_config('speaker_min_duration', 2.0, input_data) | |
speakers_to_merge = [s for s, d in speaker_duration.items() if d < min_duration and s != 'SPEAKER_UNKNOWN'] | |
# Merge low-duration speakers into SPEAKER_UNKNOWN | |
for result in transcript_results: | |
if result['speaker_id'] in speakers_to_merge: | |
result['speaker_id'] = 'SPEAKER_UNKNOWN' | |
result['speaker_confidence'] = 0.3 | |
print(f"Identified {best_k} speakers based on voice characteristics") | |
return transcript_results | |
except Exception as e: | |
print(f"Error in speaker identification: {str(e)}") | |
# Fallback: assign all to single speaker | |
for result in transcript_results: | |
result['speaker_id'] = 'SPEAKER_1' | |
result['speaker_confidence'] = 1.0 | |
return transcript_results | |
def _transcribe_chunks_parallel(self, chunk_data: List[Dict[str, Any]], input_data: Dict[str, Any] = None) -> List[Dict[str, Any]]: | |
"""Transcribe audio chunks in parallel""" | |
results = [] | |
parallel_processing = self._get_config('parallel_processing', True, input_data) | |
if parallel_processing: | |
# Use fewer workers for speech recognition to avoid API limits | |
max_workers = min(3, len(chunk_data)) | |
with ThreadPoolExecutor(max_workers=max_workers) as executor: | |
future_to_chunk = { | |
executor.submit(self._transcribe_audio_chunk, chunk_info, input_data): chunk_info | |
for chunk_info in chunk_data | |
} | |
for future in as_completed(future_to_chunk): | |
chunk_info = future_to_chunk[future] | |
try: | |
result = future.result() | |
results.append(result) | |
print(f"Transcribed {result['chunk']}: {result['text'][:50]}..." if len(result['text']) > 50 else f"Transcribed {result['chunk']}: {result['text']}") | |
except Exception as e: | |
print(f"Error processing {chunk_info['filename']}: {str(e)}") | |
results.append({ | |
'text': '[PROCESSING_ERROR]', | |
'confidence': 0.0, | |
'method': 'error', | |
'chunk': os.path.basename(chunk_info['filename']), | |
'start_time': chunk_info.get('start_time', 0), | |
'end_time': chunk_info.get('end_time', 0), | |
'duration': chunk_info.get('duration', 0), | |
'index': chunk_info.get('index', 0), | |
'error': str(e), | |
'voice_features': None | |
}) | |
else: | |
for chunk_info in chunk_data: | |
result = self._transcribe_audio_chunk(chunk_info, input_data) | |
results.append(result) | |
print(f"Transcribed {result['chunk']}: {result['text'][:50]}..." if len(result['text']) > 50 else f"Transcribed {result['chunk']}: {result['text']}") | |
# Sort results by chunk index to maintain order | |
results.sort(key=lambda x: x['index']) | |
return results | |
def _post_process_transcript(self, transcript_results: List[Dict[str, Any]], input_data: Dict[str, Any] = None) -> Dict[str, Any]: | |
"""Post-process and analyze transcript results with speaker information""" | |
enable_speaker_id = self._get_config('enable_speaker_id', True, input_data) | |
# Identify speakers if enabled | |
if enable_speaker_id: | |
transcript_results = self._identify_speakers(transcript_results, input_data) | |
# Combine text with speaker tags | |
full_text_parts = [] | |
speaker_tagged_text = [] | |
successful_chunks = 0 | |
total_confidence = 0.0 | |
method_counts = {} | |
speaker_stats = {} | |
current_speaker = None | |
current_speaker_text = [] | |
for result in transcript_results: | |
text = result['text'] | |
speaker = result.get('speaker_id', 'SPEAKER_1') | |
start_time = result.get('start_time', 0) | |
if text not in ['[INAUDIBLE]', '[RECOGNITION_ERROR]', '[ERROR]', '[PROCESSING_ERROR]']: | |
full_text_parts.append(text) | |
successful_chunks += 1 | |
total_confidence += result['confidence'] | |
# Handle speaker transitions | |
if enable_speaker_id: | |
if current_speaker != speaker: | |
# Save previous speaker's text | |
if current_speaker and current_speaker_text: | |
combined_text = ' '.join(current_speaker_text) | |
speaker_tagged_text.append(f"[{current_speaker}]: {combined_text}") | |
# Start new speaker | |
current_speaker = speaker | |
current_speaker_text = [text] | |
else: | |
# Continue with same speaker | |
current_speaker_text.append(text) | |
else: | |
speaker_tagged_text.append(text) | |
# Update speaker statistics | |
if speaker in speaker_stats: | |
speaker_stats[speaker]['duration'] += result.get('duration', 0) | |
speaker_stats[speaker]['word_count'] += len(text.split()) | |
speaker_stats[speaker]['segments'] += 1 | |
else: | |
speaker_stats[speaker] = { | |
'duration': result.get('duration', 0), | |
'word_count': len(text.split()), | |
'segments': 1, | |
'confidence': result.get('speaker_confidence', 1.0) | |
} | |
method = result['method'] | |
method_counts[method] = method_counts.get(method, 0) + 1 | |
# Add final speaker text | |
if enable_speaker_id and current_speaker and current_speaker_text: | |
combined_text = ' '.join(current_speaker_text) | |
speaker_tagged_text.append(f"[{current_speaker}]: {combined_text}") | |
# Combine texts | |
combined_text = ' '.join(full_text_parts) | |
speaker_formatted_text = combined_text | |
# Calculate statistics | |
word_count = len(combined_text.split()) if combined_text else 0 | |
char_count = len(combined_text) | |
avg_confidence = total_confidence / max(1, successful_chunks) | |
success_rate = successful_chunks / len(transcript_results) if transcript_results else 0 | |
# Estimate speaking duration (rough approximation: 150 words per minute) | |
estimated_duration_minutes = word_count / 150 if word_count > 0 else 0 | |
return { | |
'full_transcript': combined_text, | |
'speaker_tagged_transcript': speaker_formatted_text, | |
'word_count': word_count, | |
'character_count': char_count, | |
'chunk_count': len(transcript_results), | |
'successful_chunks': successful_chunks, | |
'success_rate': success_rate, | |
'average_confidence': avg_confidence, | |
'method_distribution': method_counts, | |
'estimated_duration_minutes': estimated_duration_minutes, | |
'speaker_identification_enabled': enable_speaker_id, | |
'speaker_statistics': speaker_stats, | |
'total_speakers': len([s for s in speaker_stats.keys() if s != 'SPEAKER_UNKNOWN']), | |
'detailed_results': transcript_results | |
} | |
def extract_transcript(self, audio_path: str, video_hash: str, input_data: Dict[str, Any] = None) -> Dict[str, Any]: | |
"""Extract complete transcript from audio file""" | |
cache_enabled = self._get_config('cache_enabled', True, input_data) | |
enable_speaker_id = self._get_config('enable_speaker_id', True, input_data) | |
cache_suffix = "transcript_with_speakers.json" if enable_speaker_id else "transcript.json" | |
cache_path = self._get_cache_path(video_hash, cache_suffix) | |
# Check cache | |
cached_transcript = self._load_from_cache(cache_path, cache_enabled) | |
if cached_transcript: | |
print("Using cached transcript") | |
return cached_transcript | |
try: | |
# Step 1: Split audio into manageable chunks | |
print("Splitting audio into chunks...") | |
chunk_data = self._split_audio_intelligent(audio_path, input_data) | |
if not chunk_data: | |
return { | |
'error': 'Failed to split audio into chunks', | |
'full_transcript': '', | |
'speaker_tagged_transcript': '', | |
'success_rate': 0.0 | |
} | |
# Step 2: Transcribe all chunks | |
print(f"Transcribing {len(chunk_data)} audio chunks...") | |
transcript_results = self._transcribe_chunks_parallel(chunk_data, input_data) | |
# Step 3: Post-process and combine results | |
print("Post-processing transcript and identifying speakers...") | |
final_result = self._post_process_transcript(transcript_results, input_data) | |
# Add timestamp | |
final_result['extraction_timestamp'] = time.time() | |
final_result['extraction_date'] = time.strftime('%Y-%m-%d %H:%M:%S') | |
# Cache results | |
self._save_to_cache(cache_path, final_result, cache_enabled) | |
return final_result | |
except Exception as e: | |
print(f"Error during transcript extraction: {str(e)}") | |
return { | |
'error': str(e), | |
'full_transcript': '', | |
'speaker_tagged_transcript': '', | |
'success_rate': 0.0 | |
} | |
def _run(self, youtube_url: str, **kwargs) -> str: | |
"""Main execution method""" | |
input_data = { | |
'youtube_url': youtube_url, | |
**kwargs | |
} | |
if not youtube_url: | |
return "Error: youtube_url is required." | |
try: | |
# Generate video hash for caching | |
video_hash = self._get_video_hash(youtube_url) | |
# Step 1: Download audio | |
print(f"Downloading YouTube audio from {youtube_url}...") | |
audio_path = self.download_youtube_audio(youtube_url, video_hash, input_data) | |
if not audio_path or not os.path.exists(audio_path): | |
return "Error: Failed to download the YouTube audio." | |
# Step 2: Extract transcript | |
print("Extracting audio transcript...") | |
transcript_result = self.extract_transcript(audio_path, video_hash, input_data) | |
if transcript_result.get("error"): | |
return f"Error: {transcript_result['error']}" | |
# Choose the appropriate transcript | |
main_transcript = transcript_result.get('full_transcript') | |
#ipdb.set_trace() | |
print(f"Transcript extracted: {main_transcript[:50]}..." if len(main_transcript) > 50 else f"Transcript extracted: {main_transcript}") | |
return "TRANSCRIPT: " + main_transcript | |
except Exception as e: | |
return f"Error during transcript extraction: {str(e)}" | |
# Factory function to create the tool | |
def create_youtube_transcript_tool(**kwargs): | |
"""Factory function to create the transcript extraction tool with custom parameters""" | |
return YouTubeTranscriptExtractor(**kwargs) | |
# --- Model Configuration --- | |
def create_llm_pipeline(): | |
#model_id = "meta-llama/Llama-2-13b-chat-hf" | |
#model_id = "meta-llama/Llama-3.3-70B-Instruct" | |
#model_id = "mistralai/Mistral-Small-24B-Base-2501" | |
model_id = "mistralai/Mistral-7B-Instruct-v0.3" | |
#model_id = "Meta-Llama/Llama-2-7b-chat-hf" | |
#model_id = "NousResearch/Nous-Hermes-2-Mistral-7B-DPO" | |
#model_id = "TheBloke/Mistral-7B-Instruct-v0.1-GGUF" | |
#model_id = "mistralai/Mistral-7B-Instruct-v0.2" | |
#model_id = "Qwen/Qwen2-7B-Instruct" | |
#model_id = "GSAI-ML/LLaDA-8B-Instruct" | |
return pipeline( | |
"text-generation", | |
model=model_id, | |
device_map="cpu", | |
torch_dtype=torch.float16, | |
max_new_tokens=1024, | |
temperature=0.3, | |
top_k=50, | |
top_p=0.95 | |
) | |
nlp = None # Set to None if not using spaCy, so the regex fallback is used in extract_entities | |
# --- Agent State Definition --- | |
class AgentState(TypedDict): | |
messages: Annotated[List[AnyMessage], lambda x, y: x + y] | |
done: bool = False # Default value of False | |
question: str | |
task_id: str | |
input_file: Optional[bytes] | |
file_type: Optional[str] | |
context: List[Document] # Using LangChain's Document class | |
file_path: Optional[str] | |
youtube_url: Optional[str] | |
answer: Optional[str] | |
frame_answers: Optional[list] | |
# --- Define Call LLM function --- | |
# 3. Improved LLM call with memory management | |
def call_llm_with_memory_management(state: AgentState, llm_model) -> AgentState: | |
"""Enhanced LLM call with better prompt engineering and hallucination prevention.""" | |
print("Running call_llm with memory management...") | |
#ipdb.set_trace() | |
original_messages = messages_for_llm = state["messages"] | |
# Context management - be more aggressive about truncation | |
system_message_content = None | |
if messages_for_llm and isinstance(messages_for_llm[0], SystemMessage): | |
system_message_content = messages_for_llm[0] | |
regular_messages = messages_for_llm[1:] | |
else: | |
regular_messages = messages_for_llm | |
# Keep only the most recent messages (more aggressive) | |
max_regular_messages = 6 # Reduced from 9 | |
if len(regular_messages) > max_regular_messages: | |
print(f"🔄 Truncating to {max_regular_messages} recent messages") | |
regular_messages = regular_messages[-max_regular_messages:] | |
# Reconstruct for LLM | |
messages_for_llm = [] | |
if system_message_content: | |
messages_for_llm.append(system_message_content) | |
messages_for_llm.extend(regular_messages) | |
# Character limit check | |
total_chars = sum(len(str(msg.content)) for msg in messages_for_llm) | |
char_limit = 20000 | |
if total_chars > char_limit: | |
print(f"📏 Context too long ({total_chars} chars) - further truncation") | |
while regular_messages and sum(len(str(m.content)) for m in regular_messages) > char_limit - (len(str(system_message_content.content)) if system_message_content else 0): | |
regular_messages.pop(0) | |
messages_for_llm = [] | |
if system_message_content: | |
messages_for_llm.append(system_message_content) | |
messages_for_llm.extend(regular_messages) | |
new_state = state.copy() | |
try: | |
if torch.cuda.is_available(): | |
torch.cuda.empty_cache() | |
print(f"🤖 Calling LLM with {len(messages_for_llm)} messages") | |
# Convert to simple string format that the model can understand | |
if len(messages_for_llm) == 2 and isinstance(messages_for_llm[0], SystemMessage) and isinstance(messages_for_llm[1], HumanMessage): | |
# Initial query - use simple format | |
system_content = messages_for_llm[0].content | |
human_content = messages_for_llm[1].content | |
formatted_input = f"{system_content}\n\nHuman: {human_content}\n\nAssistant:" | |
else: | |
# Ongoing conversation - build context | |
formatted_input = "" | |
# Add system message if present | |
if system_message_content: | |
formatted_input += f"{system_message_content.content}\n\n" | |
# Add conversation messages | |
for msg in regular_messages: | |
if isinstance(msg, HumanMessage): | |
formatted_input += f"Human: {msg.content}\n\n" | |
elif isinstance(msg, AIMessage): | |
formatted_input += f"Assistant: {msg.content}\n\n" | |
elif isinstance(msg, ToolMessage): | |
formatted_input += f"Tool Result: {msg.content}\n\n" | |
# Add explicit instruction for immediate final answer if we have recent tool results | |
if any(isinstance(msg, ToolMessage) for msg in regular_messages[-2:]): | |
formatted_input += "Based on the tool results above, provide your FINAL ANSWER now.\n\n" | |
formatted_input += "REMINDER ON ANSWER FORMAT: \n" | |
formatted_input += "- Numbers: no commas, no units unless specified\n" | |
formatted_input += "- Strings: no articles, no abbreviations, digits in plain text\n" | |
formatted_input += "- Lists: comma-separated following above rules\n" | |
formatted_input += "- Be extremely brief and concise" | |
formatted_input += "Assistant:" | |
print(f"Input preview: {formatted_input[:300]}...") | |
llm_response_object = llm_model.invoke(formatted_input) | |
# Process response and clean up hallucinated content | |
if isinstance(llm_response_object, BaseMessage): | |
raw_content = llm_response_object.content | |
elif hasattr(llm_response_object, 'content'): | |
raw_content = str(llm_response_object.content) | |
else: | |
raw_content = str(llm_response_object) | |
# Clean up the response to prevent hallucinated follow-up questions | |
cleaned_content = clean_llm_response(raw_content) | |
ai_message_response = AIMessage(content=cleaned_content) | |
print(f"🔍 LLM Response preview: {cleaned_content[:200]}...") | |
final_messages = original_messages + [ai_message_response] | |
new_state["messages"] = final_messages | |
new_state.pop("done", None) | |
except Exception as e: | |
print(f"❌ LLM call failed: {e}") | |
error_message = AIMessage(content=f"Error: LLM call failed - {str(e)}") | |
new_state["messages"] = original_messages + [error_message] | |
new_state["done"] = True | |
finally: | |
if torch.cuda.is_available(): | |
torch.cuda.empty_cache() | |
return new_state | |
def clean_llm_response(response_text: str) -> str: | |
""" | |
Clean LLM response to prevent hallucinated follow-up questions and conversations. | |
Specifically handles ReAct format: Thought: -> Action: -> Action Input: | |
""" | |
if not response_text: | |
return response_text | |
print(f"Initial response: {response_text[:200]}...") | |
# --- START MODIFICATION --- | |
# Isolate the text generated by the assistant in the last turn. | |
# This prevents parsing examples or instructions from the preamble. | |
assistant_marker = "Assistant:" | |
last_marker_idx = response_text.rfind(assistant_marker) | |
text_to_process = response_text # Default to full text if marker not found | |
if last_marker_idx != -1: | |
# If "Assistant:" is found, process only the text after the last occurrence. | |
text_to_process = response_text[last_marker_idx + len(assistant_marker):].strip() | |
print(f"ℹ️ Parsing content after last 'Assistant:': {text_to_process[:200]}...") | |
else: | |
# If "Assistant:" is not found, process the whole input. | |
# This might occur if the input is already just the assistant's direct response | |
# or if the prompt structure is different. | |
print(f"ℹ️ No 'Assistant:' marker found. Processing entire input as is.") | |
# --- END MODIFICATION --- | |
# Now, all subsequent operations use 'text_to_process' | |
# Try to find a complete ReAct pattern in the assistant's actual output | |
react_pattern = r'Thought:\s*(.*?)\s*Action:\s*([^\n\r]+)\s*Action Input:\s*(.*?)(?=\s*(?:Thought:|Action:|FINAL ANSWER:|$))' | |
# Apply search to 'text_to_process' | |
react_match = re.search(react_pattern, text_to_process, re.DOTALL | re.IGNORECASE) | |
if react_match: | |
thought_text = react_match.group(1).strip() | |
action_name = react_match.group(2).strip() | |
action_input = react_match.group(3).strip() | |
# Clean up the action input - remove any trailing content that looks like instructions | |
action_input_clean = re.sub(r'\s*(When you have|FINAL ANSWER|ANSWER FORMAT|IMPORTANT:).*$', '', action_input, flags=re.DOTALL | re.IGNORECASE) | |
action_input_clean = action_input_clean.strip() | |
react_sequence = f"Thought: {thought_text}\nAction: {action_name}\nAction Input: {action_input_clean}" | |
print(f"🔧 Found ReAct pattern - Action: {action_name}, Input: {action_input_clean[:100]}...") | |
# Check if there's a FINAL ANSWER after the action input (this would be hallucination) | |
# Check in the remaining part of 'text_to_process' | |
remaining_text_in_process = text_to_process[react_match.end():] | |
final_answer_after = re.search(r'FINAL ANSWER:', remaining_text_in_process, re.IGNORECASE) | |
if final_answer_after: | |
print(f"🚫 Removed hallucinated FINAL ANSWER after tool call") | |
return react_sequence | |
# If no ReAct pattern in 'text_to_process', check for standalone FINAL ANSWER | |
# This variable will hold the text being processed for FINAL ANSWER and then for fallback. | |
current_text_for_processing = text_to_process | |
final_answer_match = re.search(r"FINAL ANSWER:\s*(.+?)(?=\n|$)", current_text_for_processing, re.IGNORECASE) | |
if final_answer_match: | |
answer_content = final_answer_match.group(1).strip() | |
template_phrases = [ | |
'[concise answer only]', | |
'[concise answer - number/word/list only]', | |
'[brief answer]', | |
'[your answer here]', | |
'concise answer only', | |
'brief answer', | |
'your answer here' | |
] | |
if any(phrase.lower() in answer_content.lower() for phrase in template_phrases): | |
print(f"🚫 Ignoring template FINAL ANSWER: {answer_content}") | |
# Remove the template FINAL ANSWER and continue cleaning on the remainder of 'current_text_for_processing' | |
current_text_for_processing = current_text_for_processing[:final_answer_match.start()].strip() | |
# Fall through to the general cleanup section below | |
else: | |
# Keep everything from the start of 'current_text_for_processing' up to and including the real FINAL ANSWER line only | |
cleaned_output = current_text_for_processing[:final_answer_match.end()] | |
# Check if there's additional content after FINAL ANSWER in 'current_text_for_processing' | |
remaining_after_final_answer = current_text_for_processing[final_answer_match.end():].strip() | |
if remaining_after_final_answer: | |
print(f"🚫 Removed content after FINAL ANSWER: {remaining_after_final_answer[:100]}...") | |
return cleaned_output.strip() | |
# If no ReAct, or FINAL ANSWER was a template or not found, apply fallback cleaning to 'current_text_for_processing' | |
lines = current_text_for_processing.split('\n') | |
cleaned_lines = [] | |
for i, line in enumerate(lines): | |
# Stop if we see the model repeating system instructions | |
if re.search(r'\[SYSTEM\]|\[HUMAN\]|\[ASSISTANT\]|\[TOOL\]', line, re.IGNORECASE): | |
print(f"🚫 Stopped at repeated system format: {line}") | |
break | |
# Stop if we see the model generating format instructions | |
if re.search(r'CRITICAL INSTRUCTIONS|FORMAT for tool use|ANSWER FORMAT', line, re.IGNORECASE): | |
print(f"🚫 Stopped at repeated instructions: {line}") | |
break | |
# Stop if we see the model role-playing as a human asking questions | |
if re.search(r'(what are|what is|how many|can you tell me)', line, re.IGNORECASE) and not line.strip().startswith(('Thought:', 'Action:', 'Action Input:')): | |
# Make sure this isn't part of a legitimate thought process | |
if i > 0 and not any(keyword in lines[i-1] for keyword in ['Thought:', 'Action:', 'need to']): | |
print(f"🚫 Stopped at hallucinated question: {line}") | |
break | |
cleaned_lines.append(line) | |
cleaned = '\n'.join(cleaned_lines).strip() | |
print(f"Final cleaned response (fallback): {cleaned[:200]}...") | |
return cleaned | |
def parse_react_output(state: AgentState) -> AgentState: | |
""" | |
Enhanced parsing with better FINAL ANSWER detection and flow control. | |
""" | |
print("Running parse_react_output...") | |
#ipdb.set_trace() | |
messages = state.get("messages", []) | |
if not messages: | |
print("No messages in state.") | |
new_state = state.copy() | |
new_state["done"] = True | |
return new_state | |
# DEBUG | |
print(f"parse_react_output: Entry message count: {len(messages)}") | |
if messages and hasattr(messages[-1], 'tool_calls'): | |
print(f"parse_react_output: Number of tool calls in last AIMessage: {len(messages[-1].tool_calls)}") | |
last_message = messages[-1] | |
new_state = state.copy() | |
if not isinstance(last_message, AIMessage): | |
print("Last message is not an AIMessage instance.") | |
return new_state | |
content = last_message.content | |
if not isinstance(content, str): | |
content = str(content) | |
# Look for FINAL ANSWER first - this should take absolute priority | |
# Use a more precise regex to capture just the answer line | |
final_answer_match = re.search(r"FINAL ANSWER:\s*([^\n\r]+)", content, re.IGNORECASE) | |
if final_answer_match: | |
final_answer_text = final_answer_match.group(1).strip() | |
# Check if this is template text (not a real answer) | |
template_phrases = [ | |
'[concise answer only]', | |
'[concise answer - number/word/list only]', | |
'[brief answer]', | |
'[your answer here]', | |
'concise answer only', | |
'brief answer', | |
'your answer here' | |
] | |
# If it's template text, don't treat it as a final answer | |
if any(phrase.lower() in final_answer_text.lower() for phrase in template_phrases): | |
print(f"🚫 Ignoring template FINAL ANSWER: '{final_answer_text}'") | |
# Continue processing as if no final answer was found | |
else: | |
print(f"🎯 FINAL ANSWER found: '{final_answer_text}' - ENDING") | |
# Store the answer in state for easy access | |
new_state["answer"] = final_answer_text | |
# Clean up the message content to just show the final answer | |
clean_content = f"FINAL ANSWER: {final_answer_text}" | |
updated_ai_message = AIMessage(content=clean_content, tool_calls=[]) | |
new_state["messages"] = messages[:-1] + [updated_ai_message] | |
new_state["done"] = True | |
return new_state | |
# If no FINAL ANSWER, look for tool calls | |
action_match = re.search(r"Action:\s*([^\n]+)", content, re.IGNORECASE) | |
action_input_match = re.search(r"Action Input:\s*(.+)", content, re.IGNORECASE | re.DOTALL) | |
if action_match and action_input_match: | |
tool_name = action_match.group(1).strip() | |
tool_input_raw = action_input_match.group(1).strip() | |
if tool_name.lower() == "none": | |
print("Action is 'None' - treating as regular response") | |
updated_ai_message = AIMessage(content=content, tool_calls=[]) | |
new_state["messages"] = messages[:-1] + [updated_ai_message] | |
new_state.pop("done", None) | |
return new_state | |
print(f"🔧 Tool call: {tool_name} with input: {tool_input_raw[:100]}...") | |
# Parse tool arguments | |
tool_args = {} | |
try: | |
trimmed_input = tool_input_raw.strip() | |
if (trimmed_input.startswith('{') and trimmed_input.endswith('}')) or \ | |
(trimmed_input.startswith('[') and trimmed_input.endswith(']')): | |
tool_args = ast.literal_eval(trimmed_input) | |
if not isinstance(tool_args, dict): | |
tool_args = {"query": tool_input_raw} | |
else: | |
tool_args = {"query": tool_input_raw} | |
except (ValueError, SyntaxError): | |
tool_args = {"query": tool_input_raw} | |
tool_call_id = str(uuid.uuid4()) | |
parsed_tool_calls = [{"name": tool_name, "args": tool_args, "id": tool_call_id}] | |
updated_ai_message = AIMessage(content=content, tool_calls=parsed_tool_calls) | |
new_state["messages"] = messages[:-1] + [updated_ai_message] | |
new_state.pop("done", None) | |
return new_state | |
# No tool call or final answer - treat as regular response | |
print("No actionable content found - continuing conversation") | |
updated_ai_message = AIMessage(content=content, tool_calls=[]) | |
new_state["messages"] = messages[:-1] + [updated_ai_message] | |
new_state.pop("done", None) | |
# DEBUG | |
print(f"parse_react_output: Exit message count: {len(new_state['messages'])}") | |
return new_state | |
# 4. Improved call_tool_with_memory_management to prevent duplicate processing | |
def call_tool_with_memory_management(state: AgentState) -> AgentState: | |
"""Process tool calls with memory management, avoiding duplicates.""" | |
print("Running call_tool with memory management...") | |
# Clear CUDA cache before processing | |
try: | |
import torch | |
if torch.cuda.is_available(): | |
torch.cuda.empty_cache() | |
print(f"🧹 Cleared CUDA cache. Memory: {torch.cuda.memory_allocated()/1024**2:.1f}MB") | |
except ImportError: | |
pass | |
except Exception as e: | |
print(f"Error clearing CUDA cache: {e}") | |
# Check if we have parsed tool calls from the condition function | |
if 'parsed_tool_calls' in state and state.get('parsed_tool_calls'): | |
print("Executing parsed tool calls...") | |
return execute_parsed_tool_calls(state) | |
# Fallback to original OpenAI-style tool calls handling | |
messages = state.get("messages", []) | |
if not messages: | |
print("No messages found in state.") | |
return state | |
last_message = messages[-1] | |
if not hasattr(last_message, "tool_calls") or not last_message.tool_calls: | |
print("No tool calls found in last message") | |
return state | |
# Avoid processing the same tool calls multiple times | |
if hasattr(last_message, '_processed_tool_calls'): | |
print("Tool calls already processed, skipping...") | |
return state | |
# Copy the messages to avoid mutating the original list | |
new_messages = list(messages) | |
print(f"Processing {len(last_message.tool_calls)} tool calls from last message") | |
# Get file_path from state to pass to tools | |
file_path_to_pass = state.get('file_path') | |
for i, tool_call_item in enumerate(last_message.tool_calls): | |
# Handle both dict and object-style tool calls | |
if isinstance(tool_call_item, dict): | |
tool_name = tool_call_item.get("name", "") | |
raw_args = tool_call_item.get("args") | |
tool_call_id = tool_call_item.get("id", str(uuid.uuid4())) | |
elif hasattr(tool_call_item, "name") and hasattr(tool_call_item, "id"): | |
tool_name = getattr(tool_call_item, "name", "") | |
raw_args = getattr(tool_call_item, "args", None) | |
tool_call_id = getattr(tool_call_item, "id", str(uuid.uuid4())) | |
else: | |
print(f"Skipping malformed tool call item: {tool_call_item}") | |
continue | |
print(f"Processing tool call {i+1}: {tool_name}") | |
# Find the matching tool | |
selected_tool = None | |
for tool_instance in tools: | |
if tool_instance.name.lower() == tool_name.lower(): | |
selected_tool = tool_instance | |
break | |
if not selected_tool: | |
tool_result = f"Error: Tool '{tool_name}' not found. Available tools: {', '.join(t.name for t in tools)}" | |
print(f"Tool not found: {tool_name}") | |
else: | |
try: | |
# Prepare the arguments for the tool.run() method | |
tool_run_input_dict = {} | |
if isinstance(raw_args, dict): | |
tool_run_input_dict = raw_args.copy() | |
elif raw_args is not None: | |
tool_run_input_dict["query"] = str(raw_args) | |
# Add file_path to the dictionary for the tool | |
tool_run_input_dict['file_path'] = file_path_to_pass | |
print(f"Executing {tool_name} with args: {tool_run_input_dict} ...") | |
tool_result = selected_tool.run(tool_run_input_dict) | |
#ipdb.set_trace() | |
# Aggressive truncation to prevent memory issues | |
if not isinstance(tool_result, str): | |
tool_result = str(tool_result) | |
max_length = 3000 if "wikipedia" in tool_name.lower() else 2000 | |
if len(tool_result) > max_length: | |
original_length = len(tool_result) | |
tool_result = tool_result[:max_length] + f"... [Result truncated from {original_length} to {max_length} chars to prevent memory issues]" | |
print(f"📄 Truncated result to {max_length} characters") | |
print(f"Tool result length: {len(tool_result)} characters") | |
except Exception as e: | |
tool_result = f"Error executing tool '{tool_name}': {str(e)}" | |
print(f"Tool execution error: {e}") | |
# Create tool message - ONLY ONE PER TOOL CALL | |
tool_message = ToolMessage( | |
content=tool_result, | |
name=tool_name, | |
tool_call_id=tool_call_id | |
) | |
new_messages.append(tool_message) | |
print(f"Added tool message for {tool_name}") | |
# Mark the last message as processed to prevent re-processing | |
if hasattr(last_message, '__dict__'): | |
last_message._processed_tool_calls = True | |
# Update the state | |
new_state = state.copy() | |
new_state["messages"] = new_messages | |
# Clear CUDA cache after processing | |
try: | |
import torch | |
if torch.cuda.is_available(): | |
torch.cuda.empty_cache() | |
print(f"🧹 Cleared CUDA cache post-processing. Memory: {torch.cuda.memory_allocated()/1024**2:.1f}MB") | |
except ImportError: | |
pass | |
except Exception as e: | |
print(f"Error clearing CUDA cache post-processing: {e}") | |
return new_state | |
# 3. Enhanced execute_parsed_tool_calls to prevent duplicate observations | |
def execute_parsed_tool_calls(state: AgentState): | |
""" | |
Execute tool calls that were parsed from the Thought/Action/Action Input format. | |
This is called by call_tool when parsed_tool_calls are present in state. | |
""" | |
# Tool name mappings | |
tool_name_mappings = { | |
'wikipedia_semantic_search': 'wikipedia_tool', | |
'wikipedia': 'wikipedia_tool', | |
'search': 'enhanced_search', | |
'duckduckgo_search': 'enhanced_search', | |
'web_search': 'enhanced_search', | |
'enhanced_search': 'enhanced_search', | |
'youtube_screenshot_qa_tool': 'youtube_tool', | |
'youtube': 'youtube_tool', | |
'youtube_transcript_extractor': 'youtube_transcript_extractor', | |
'youtube_audio_tool': 'youtube_transcript_extractor' | |
} | |
# Create a lookup by tool names | |
tools_by_name = {} | |
for tool in tools: | |
tools_by_name[tool.name.lower()] = tool | |
# Copy messages to avoid mutation during iteration | |
new_messages = list(state["messages"]) | |
# Process each tool call ONCE | |
for tool_call in state['parsed_tool_calls']: | |
action = tool_call['action'] | |
action_input = tool_call['action_input'] | |
normalized_action = tool_call['normalized_action'] | |
print(f"🚀 Executing tool: {action} with input: {action_input}") | |
# Find the tool instance | |
tool_instance = None | |
if normalized_action in tools_by_name: | |
tool_instance = tools_by_name[normalized_action] | |
elif normalized_action in tool_name_mappings: | |
mapped_name = tool_name_mappings[normalized_action] | |
if mapped_name in tools_by_name: | |
tool_instance = tools_by_name[mapped_name] | |
if tool_instance: | |
try: | |
# Pass file_path if the tool expects it | |
if hasattr(tool_instance, 'run'): | |
if 'file_path' in tool_instance.run.__code__.co_varnames: | |
result = tool_instance.run(action_input, file_path=state.get('file_path')) | |
else: | |
result = tool_instance.run(action_input) | |
else: | |
result = str(tool_instance) | |
# Truncate long results | |
if len(result) > 6000: | |
result = result[:6000] + "... [Result truncated due to length]" | |
# Create a SINGLE observation message | |
from langchain_core.messages import ToolMessage | |
tool_message = ToolMessage( | |
content=f"Observation: {result}", | |
name=action, | |
tool_call_id=str(uuid.uuid4()) | |
) | |
new_messages.append(tool_message) | |
print(f"✅ Tool '{action}' executed successfully") | |
except Exception as e: | |
print(f"❌ Error executing tool '{action}': {e}") | |
from langchain_core.messages import ToolMessage | |
error_message = ToolMessage( | |
content=f"Observation: Error executing '{action}': {str(e)}", | |
name=action, | |
tool_call_id=str(uuid.uuid4()) | |
) | |
new_messages.append(error_message) | |
else: | |
print(f"❌ Tool '{action}' not found in available tools") | |
available_tool_names = list(tools_by_name.keys()) | |
from langchain_core.messages import ToolMessage | |
error_message = ToolMessage( | |
content=f"Observation: Tool '{action}' not found. Available tools: {', '.join(available_tool_names)}", | |
name=action, | |
tool_call_id=str(uuid.uuid4()) | |
) | |
new_messages.append(error_message) | |
# Update state with new messages and clear parsed tool calls | |
new_state = state.copy() | |
new_state["messages"] = new_messages | |
new_state['parsed_tool_calls'] = [] # Clear to prevent re-execution | |
return new_state | |
# 1. Add loop detection to your AgentState | |
def should_continue(state: AgentState) -> str: | |
"""Enhanced continuation logic with better limits.""" | |
print("Running should_continue...") | |
# Check done flag first | |
if state.get("done", False): | |
print("✅ Done flag is True - ending") | |
return "end" | |
messages = state["messages"] | |
# More aggressive message limit | |
#if len(messages) > 20: # Reduced from 15 | |
# print(f"⚠️ Message limit reached ({len(messages)}/20) - forcing end") | |
# return "end" | |
# Check for repeated patterns (stuck in loop) | |
if len(messages) >= 6: | |
recent_contents = [str(msg.content)[:100] for msg in messages[-6:] if hasattr(msg, 'content')] | |
if len(set(recent_contents)) < 3: # Too much repetition | |
print("🔄 Detected repetitive pattern - ending") | |
return "end" | |
print(f"📊 Continuing... ({len(messages)} messages so far)") | |
return "continue" | |
def route_after_parse_react(state: AgentState) -> str: | |
"""Determines the next step after parsing LLM output, prioritizing end state.""" | |
if state.get("done", False): # Check if parse_react_output decided we are done | |
return "end_processing" | |
# Original logic: check for tool calls in the last message | |
# Ensure messages list and last message exist before checking tool_calls | |
messages = state.get("messages", []) | |
if messages: | |
last_message = messages[-1] | |
if hasattr(last_message, 'tool_calls') and last_message.tool_calls: | |
return "call_tool" | |
return "call_llm" | |
# --- Graph Construction --- | |
# --- Graph Construction --- | |
def create_memory_safe_workflow(): | |
"""Create a workflow with memory management and loop prevention.""" | |
# These models are initialized here but might be better managed if they need to be released/reinitialized | |
# like you attempt in run_agent. Consider passing them or managing their lifecycle carefully. | |
hf_pipe = create_llm_pipeline() | |
llm = HuggingFacePipeline(pipeline=hf_pipe) | |
# vqa_model_name = "Salesforce/blip-vqa-base" # Not used in the provided graph logic directly | |
# processor_vqa = BlipProcessor.from_pretrained(vqa_model_name) # Not used | |
# model_vqa = BlipForQuestionAnswering.from_pretrained(vqa_model_name).to('cpu') # Not used | |
workflow = StateGraph(AgentState) | |
# Bind the llm_model to the call_llm_with_memory_management function | |
bound_call_llm = partial(call_llm_with_memory_management, llm_model=llm) | |
# Add nodes with memory-safe versions | |
workflow.add_node("call_llm", bound_call_llm) # Use the bound version here | |
workflow.add_node("parse_react_output", parse_react_output) | |
workflow.add_node("call_tool", call_tool_with_memory_management) # Ensure this doesn't also need llm if it calls back directly | |
# Set entry point | |
workflow.set_entry_point("call_llm") | |
# Add conditional edges | |
workflow.add_conditional_edges( | |
"call_llm", | |
should_continue, | |
{ | |
"continue": "parse_react_output", | |
"end": END | |
} | |
) | |
workflow.add_conditional_edges( | |
"parse_react_output", | |
route_after_parse_react, | |
{ | |
"call_tool": "call_tool", | |
"call_llm": "call_llm", | |
"end_processing": END | |
} | |
) | |
workflow.add_edge("call_tool", "call_llm") | |
return workflow.compile() | |
def count_english_words(text): | |
# Remove punctuation, lowercase, split into words | |
table = str.maketrans('', '', string.punctuation) | |
words_in_text = text.translate(table).lower().split() | |
return sum(1 for word in words_in_text if word in english_words) | |
def fix_backwards_text(text): | |
reversed_text = text[::-1] | |
original_count = count_english_words(text) | |
reversed_count = count_english_words(reversed_text) | |
if reversed_count > original_count: | |
return reversed_text | |
else: | |
return text | |
# --- Run the Agent --- | |
# Enhanced system prompt for better behavior | |
def run_agent(agent, state: AgentState): | |
"""Enhanced agent initialization with better prompt and hallucination prevention.""" | |
global WIKIPEDIA_TOOL, SEARCH_TOOL, YOUTUBE_TOOL, YOUTUBE_AUDIO_TOOL, tools | |
# Initialize tools | |
WIKIPEDIA_TOOL = WikipediaSearchToolWithFAISS() | |
SEARCH_TOOL = EnhancedDuckDuckGoSearchTool(max_results=3, max_chars_per_page=3000) | |
YOUTUBE_TOOL = EnhancedYoutubeScreenshotQA() | |
YOUTUBE_AUDIO_TOOL = YouTubeTranscriptExtractor() | |
tools = [WIKIPEDIA_TOOL, SEARCH_TOOL, YOUTUBE_TOOL, YOUTUBE_AUDIO_TOOL] | |
formatted_tools_description = render_text_description(tools) | |
current_date_str = datetime.now().strftime("%Y-%m-%d") | |
# Enhanced system prompt with stricter boundaries | |
system_content = f"""You are an AI assistant with access to these tools: | |
{formatted_tools_description} | |
CRITICAL INSTRUCTIONS: | |
1. Answer ONLY the question asked by the human | |
2. Do NOT generate additional questions or continue conversations | |
3. Use tools ONLY when you need specific information you don't know | |
4. After using a tool, provide your FINAL ANSWER immediately | |
5. STOP after giving your FINAL ANSWER - do not continue | |
FORMAT for tool use: | |
Thought: <brief reasoning> | |
Action: <exact_tool_name> | |
Action Input: <tool_input> | |
When you have the answer, immediately provide: | |
FINAL ANSWER: [concise answer only] | |
ANSWER FORMAT: | |
- Numbers: no commas, no units unless specified | |
- Strings: no articles, no abbreviations, digits in plain text | |
- Lists: comma-separated following above rules | |
- Be extremely brief and concise | |
- Do not provide additional context or explanations | |
- Do not provide parentheticals | |
IMPORTANT: You are responding to ONE question only. Do not ask follow-up questions or generate additional dialogue. | |
Current date: {current_date_str} | |
""" | |
query = fix_backwards_text(state['question']) | |
# Check for YouTube URLs | |
yt_pattern = r"(https?://)?(www\.)?(youtube\.com|youtu\.be)/[^\s]+" | |
if re.search(yt_pattern, query): | |
url_match = re.search(r"(https?://[^\s]+)", query) | |
if url_match: | |
state['youtube_url'] = url_match.group(0) | |
# Initialize messages | |
system_message = SystemMessage(content=system_content) | |
human_message = HumanMessage(content=query) | |
state['messages'] = [system_message, human_message] | |
state["done"] = False | |
# Run the agent | |
result = agent.invoke(state) | |
# Cleanup | |
if result.get("done"): | |
#torch.cuda.empty_cache() | |
#torch.cuda.ipc_collect() | |
gc.collect() | |
print("🧹 Released GPU memory after completion") | |
return result["messages"] | |