# Standard Library
import os
import re
import tempfile
import string
import glob
import shutil
import gc
import uuid
import signal
import time
import logging
from collections import Counter
from contextlib import contextmanager
from datetime import datetime
from functools import partial
from io import BytesIO
from typing import TypedDict, List, Optional, Dict, Any, Annotated, Literal, Union, Tuple, Set
from urllib.parse import urljoin, urlparse
# Third-Party Packages
import cv2
import requests
import wikipedia
import spacy
import yt_dlp
import librosa
from PIL import Image
from bs4 import BeautifulSoup
from duckduckgo_search import DDGS
from sentence_transformers import SentenceTransformer
from transformers import BlipProcessor, BlipForQuestionAnswering, pipeline
# LangChain Ecosystem
from langchain.prompts import PromptTemplate
from langchain.vectorstores import FAISS
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import WikipediaLoader
from langchain_community.retrievers import BM25Retriever
from langchain_community.tools import DuckDuckGoSearchRun
from langchain_core.documents import Document
from langchain_core.messages import AnyMessage, HumanMessage, AIMessage, BaseMessage, SystemMessage, ToolMessage
from langchain_core.tools import BaseTool, StructuredTool, tool, render_text_description
from langchain_huggingface import HuggingFacePipeline, HuggingFaceEndpoint
# LangGraph
from langgraph.graph import START, END, StateGraph
from langgraph.prebuilt import ToolNode, tools_condition
# PyTorch
import torch

nlp = spacy.load("en_core_web_sm")
logger = logging.getLogger(__name__)
# --- Model Configuration ---
def create_llm_pipeline():
    #model_id = "meta-llama/Llama-2-13b-chat-hf"
    #model_id = "meta-llama/Llama-3.3-70B-Instruct"
    #model_id = "mistralai/Mistral-Small-24B-Base-2501"
    model_id = "mistralai/Mistral-7B-Instruct-v0.3"
    #model_id = "Qwen/Qwen2-7B-Instruct"
    return pipeline(
        "text-generation",
        model=model_id,
        device_map="auto",
        torch_dtype=torch.float16,
        max_new_tokens=1024,
        temperature=0.1
    )
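# Optional (hedged sketch): the raw transformers pipeline can be wrapped in
# LangChain's HuggingFacePipeline so it behaves like any other LangChain LLM.
# The variable name `llm` is illustrative only and not defined elsewhere here.
# llm = HuggingFacePipeline(pipeline=create_llm_pipeline())
# print(llm.invoke("Briefly explain what FAISS is."))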
# Define file extension sets for each category
PICTURE_EXTENSIONS = {'.jpg', '.jpeg', '.png', '.gif', '.bmp', '.tiff', '.webp'}
AUDIO_EXTENSIONS = {'.mp3', '.wav', '.aac', '.flac', '.ogg', '.m4a', '.wma'}
CODE_EXTENSIONS = {'.py', '.js', '.java', '.cpp', '.c', '.cs', '.rb', '.go', '.php', '.html', '.css', '.ts'}
SPREADSHEET_EXTENSIONS = {
    '.xls', '.xlsx', '.xlsm', '.xlsb', '.xlt', '.xltx', '.xltm',
    '.ods', '.ots', '.csv', '.tsv', '.sxc', '.stc', '.dif', '.gsheet',
    '.numbers', '.numbers-tef', '.nmbtemplate', '.fods', '.123', '.wk1', '.wk2',
    '.wks', '.wku', '.wr1', '.gnumeric', '.gnm', '.xml', '.pmvx', '.pmdx',
    '.pmv', '.uos', '.txt'
}

def get_file_type(filename: str) -> str:
    if not filename or '.' not in filename:
        return ''
    ext = filename.lower().rsplit('.', 1)[-1]
    dot_ext = f'.{ext}'
    if dot_ext in PICTURE_EXTENSIONS:
        return 'picture'
    elif dot_ext in AUDIO_EXTENSIONS:
        return 'audio'
    elif dot_ext in CODE_EXTENSIONS:
        return 'code'
    elif dot_ext in SPREADSHEET_EXTENSIONS:
        return 'spreadsheet'
    else:
        return 'unknown'
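# Example usage (illustrative only; results follow the extension sets above):
# get_file_type("chart.png")       -> 'picture'
# get_file_type("notes.xlsx")      -> 'spreadsheet'
# get_file_type("archive.tar.gz")  -> 'unknown'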
def write_bytes_to_temp_dir(file_bytes: bytes, file_name: str) -> str:
    """
    Writes bytes to a file in the system temporary directory using the provided file_name.
    Returns the full path to the saved file.
    The file will persist until manually deleted or the OS cleans the temp directory.
    """
    temp_dir = "/tmp"  # /tmp is always writable in Hugging Face Spaces
    os.makedirs(temp_dir, exist_ok=True)
    file_path = os.path.join(temp_dir, file_name)
    with open(file_path, 'wb') as f:
        f.write(file_bytes)
    print(f"File written to: {file_path}")
    return file_path
def extract_final_answer(text: str) -> str:
    """
    Returns the substring starting from the last occurrence of 'FINAL ANSWER:' (case-insensitive)
    to the end of the string, with any trailing punctuation removed.
    If not found, returns an empty string.
    """
    marker = "FINAL ANSWER:"
    idx = text.lower().rfind(marker.lower())
    if idx == -1:
        return ""
    result = text[idx:].strip()
    # Remove trailing punctuation and whitespace
    return result.rstrip(string.punctuation + " ")
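# Example usage (illustrative only; note the marker itself is kept in the result):
# extract_final_answer("Thought: ...\nFINAL ANSWER: Paris.")  -> "FINAL ANSWER: Paris"
# extract_final_answer("no marker here")                      -> ""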
class EnhancedDuckDuckGoSearchTool(BaseTool): | |
name: str = "enhanced_search" | |
description: str = ( | |
"Performs a DuckDuckGo web search and retrieves actual content from the top web results. " | |
"Input should be a search query string. " | |
"Returns search results with extracted content from web pages, making it much more useful for answering questions. " | |
"Use this tool when you need up-to-date information, details about current events, or when other tools do not provide sufficient or recent answers. " | |
"Ideal for topics that require the latest news, recent developments, or information not covered in static sources." | |
) | |
max_results: int = 3 | |
max_chars_per_page: int = 3000 | |
session: Any = None  # Optional HTTP session; initialized lazily in model_post_init
# Use model_post_init for initialization logic in Pydantic v2+ | |
def model_post_init(self, __context: Any) -> None: | |
super().model_post_init(__context) | |
# Initialize HTTP session here | |
self.session = requests.Session() | |
self.session.headers.update({ | |
'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36', | |
'Accept': 'text/html,application/xhtml+xml,application/xml;q=0.9,image/webp,*/*;q=0.8', | |
'Accept-Language': 'en-US,en;q=0.5', | |
'Accept-Encoding': 'gzip, deflate', | |
'Connection': 'keep-alive', | |
'Upgrade-Insecure-Requests': '1', | |
}) | |
def _search_duckduckgo(self, query: str) -> List[Dict]: | |
"""Perform DuckDuckGo search and return results.""" | |
try: | |
with DDGS() as ddgs: | |
results = list(ddgs.text(query, max_results=self.max_results)) | |
return results | |
except Exception as e: | |
logger.error(f"DuckDuckGo search failed: {e}") | |
return [] | |
def _extract_content_from_url(self, url: str, timeout: int = 10) -> Optional[str]: | |
"""Extract clean text content from a web page.""" | |
try: | |
# Skip certain file types | |
if any(url.lower().endswith(ext) for ext in ['.pdf', '.doc', '.docx', '.xls', '.xlsx', '.ppt', '.pptx']): | |
return "Content type not supported for extraction" | |
response = self.session.get(url, timeout=timeout, allow_redirects=True) | |
response.raise_for_status() | |
# Check content type | |
content_type = response.headers.get('content-type', '').lower() | |
if 'text/html' not in content_type: | |
return "Non-HTML content detected" | |
soup = BeautifulSoup(response.content, 'html.parser') | |
# Remove script and style elements | |
for script in soup(["script", "style", "nav", "header", "footer", "aside", "form"]): | |
script.decompose() | |
# Try to find main content areas | |
main_content = None | |
for selector in ['main', 'article', '.content', '#content', '.post', '.entry']: | |
main_content = soup.select_one(selector) | |
if main_content: | |
break | |
if not main_content: | |
main_content = soup.find('body') or soup | |
# Extract text | |
text = main_content.get_text(separator='\n', strip=True) | |
# Clean up the text | |
lines = [line.strip() for line in text.split('\n') if line.strip()] | |
text = '\n'.join(lines) | |
# Remove excessive whitespace | |
text = re.sub(r'\n{3,}', '\n\n', text) | |
text = re.sub(r' {2,}', ' ', text) | |
# Truncate if too long | |
if len(text) > self.max_chars_per_page: | |
text = text[:self.max_chars_per_page] + "\n[Content truncated...]" | |
return text | |
except requests.exceptions.Timeout: | |
return "Page loading timed out" | |
except requests.exceptions.RequestException as e: | |
return f"Failed to retrieve page: {str(e)}" | |
except Exception as e: | |
logger.error(f"Content extraction failed for {url}: {e}") | |
return "Failed to extract content from page" | |
def _format_search_result(self, result: Dict, content: str) -> str: | |
"""Format a single search result with its content.""" | |
title = result.get('title', 'No title') | |
url = result.get('href', 'No URL') | |
snippet = result.get('body', 'No snippet') | |
formatted = f""" | |
🔍 **{title}** | |
URL: {url} | |
Snippet: {snippet} | |
📄 **Page Content:** | |
{content} | |
--- | |
""" | |
return formatted | |
def run(self, query: str) -> str: | |
"""Execute the enhanced search.""" | |
if not query or not query.strip(): | |
return "Please provide a search query." | |
query = query.strip() | |
logger.info(f"Searching for: {query}") | |
# Perform DuckDuckGo search | |
search_results = self._search_duckduckgo(query) | |
if not search_results: | |
return f"No search results found for query: {query}" | |
# Process each result and extract content | |
enhanced_results = [] | |
processed_count = 0 | |
for i, result in enumerate(search_results[:self.max_results]): | |
url = result.get('href', '') | |
if not url: | |
continue | |
logger.info(f"Processing result {i+1}: {url}") | |
# Extract content from the page | |
content = self._extract_content_from_url(url) | |
if content and len(content.strip()) > 50: # Only include results with substantial content | |
formatted_result = self._format_search_result(result, content) | |
enhanced_results.append(formatted_result) | |
processed_count += 1 | |
# Small delay to be respectful to servers | |
time.sleep(0.5) | |
if not enhanced_results: | |
return f"Search completed but no content could be extracted from the pages for query: {query}" | |
# Compile final response | |
response = f"""🔍 **Enhanced Search Results for: "{query}"** | |
Found {len(search_results)} results, successfully processed {processed_count} pages with content. | |
{''.join(enhanced_results)} | |
💡 **Summary:** Retrieved and processed content from {processed_count} web pages to provide comprehensive information about your search query. | |
""" | |
# Ensure the response isn't too long | |
if len(response) > 8000: | |
response = response[:8000] + "\n[Response truncated to prevent memory issues]" | |
return response | |
def _run(self, query: str) -> str: | |
"""Required by BaseTool interface.""" | |
return self.run(query) | |
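# Example of creating the tool instance (illustrative; mirrors the pattern used
# for the Wikipedia tool further below):
# search_tool = EnhancedDuckDuckGoSearchTool()
# print(search_tool.run("latest stable Python release"))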
# --- Agent State Definition ---
class AgentState(TypedDict):
    messages: Annotated[List[AnyMessage], lambda x, y: x + y]
    done: bool  # TypedDict fields cannot carry defaults; callers read this via state.get("done", False)
    question: str
    task_id: str
    input_file: Optional[bytes]
    file_type: Optional[str]
    context: List[Document]  # Using LangChain's Document class
    file_path: Optional[str]
    youtube_url: Optional[str]
    answer: Optional[str]
    frame_answers: Optional[list]
def fetch_page_with_tables(page_title):
    """
    Fetches Wikipedia page content and extracts all tables as readable text.
    Returns a tuple: (main_text, [table_texts])
    """
    # Fetch the page object
    page = wikipedia.page(page_title)
    main_text = page.content
    # Get the HTML for table extraction
    html = page.html()
    soup = BeautifulSoup(html, 'html.parser')
    tables = soup.find_all('table')
    table_texts = []
    for table in tables:
        rows = table.find_all('tr')
        table_lines = []
        for row in rows:
            cells = row.find_all(['th', 'td'])
            cell_texts = [cell.get_text(strip=True) for cell in cells]
            if cell_texts:
                # Format as a Markdown-style table row
                table_lines.append(" | ".join(cell_texts))
        if table_lines:
            table_text = "\n".join(table_lines)
            table_texts.append(table_text)
    return main_text, table_texts
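# Example usage (illustrative only; requires network access to Wikipedia):
# main_text, tables = fetch_page_with_tables("Python (programming language)")
# print(main_text[:200])
# print(f"Extracted {len(tables)} tables")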
class WikipediaSearchToolWithFAISS(BaseTool): | |
name: str = "wikipedia_semantic_search_all_candidates_strong_entity_priority_list_retrieval" | |
description: str = ( | |
"Fetches content from multiple Wikipedia pages based on intelligent NLP query processing " | |
"of various search candidates, with strong prioritization of query entities. It then performs " | |
"entity-focused semantic search across all fetched content to find the most relevant information, " | |
"with improved retrieval for lists like discographies. Uses spaCy for named entity " | |
"recognition and query enhancement. Input should be a search query or topic. " | |
"Note: Uses the current live version of Wikipedia." | |
) | |
embedding_model_name: str = "all-MiniLM-L6-v2" | |
chunk_size: int = 4000 | |
chunk_overlap: int = 250 # Maintained moderate overlap | |
top_k_results: int = 3 | |
spacy_model: str = "en_core_web_sm" | |
# Multiplier on top_k_results controlling how many candidates to fetch per semantic query variant
semantic_search_candidate_multiplier: int = 1  # Increase (e.g. to 2 or 3) if relevant chunks are being missed
def __init__(self, **kwargs): | |
super().__init__(**kwargs) | |
try: | |
self._nlp = spacy.load(self.spacy_model) | |
print(f"Loaded spaCy model: {self.spacy_model}") | |
self._embedding_model = HuggingFaceEmbeddings(model_name=self.embedding_model_name) | |
# Refined separators for better handling of Wikipedia lists and sections | |
self._text_splitter = RecursiveCharacterTextSplitter( | |
chunk_size=self.chunk_size, | |
chunk_overlap=self.chunk_overlap, | |
separators=[ | |
"\n\n== ", "\n\n=== ", "\n\n==== ", # Section headers (keep with following content) | |
"\n\n\n", "\n\n", # Multiple newlines (paragraph breaks) | |
"\n* ", "\n- ", "\n# ", # List items | |
"\n", ". ", "! ", "? ", # Sentence breaks after newline, common punctuation | |
" ", "" # Word and character level | |
] | |
) | |
except OSError as e: | |
print(f"Error loading spaCy model '{self.spacy_model}': {e}") | |
print("Try running: python -m spacy download en_core_web_sm") | |
self._nlp = None | |
self._embedding_model = None | |
self._text_splitter = None | |
except Exception as e: | |
print(f"Error initializing WikipediaSearchToolWithFAISS components: {e}") | |
self._nlp = None | |
self._embedding_model = None | |
self._text_splitter = None | |
def _extract_entities_and_keywords(self, query: str) -> Tuple[List[str], List[str], str]: | |
if not self._nlp: | |
return [], [], query | |
doc = self._nlp(query) | |
main_entities = [ent.text for ent in doc.ents if ent.label_ in ["PERSON", "ORG", "GPE", "EVENT", "WORK_OF_ART"]] | |
keywords = [token.lemma_.lower() for token in doc if token.pos_ in ["NOUN", "PROPN", "ADJ"] and not token.is_stop and not token.is_punct and len(token.text) > 2] | |
main_entities = list(dict.fromkeys(main_entities)) | |
keywords = list(dict.fromkeys(keywords)) | |
processed_tokens = [token.lemma_ for token in doc if not token.is_stop and not token.is_punct and token.text.strip()] | |
processed_query = " ".join(processed_tokens) | |
return main_entities, keywords, processed_query | |
def _generate_search_candidates(self, query: str, main_entities: List[str], keywords: List[str], processed_query: str) -> List[str]: | |
candidates_set = set() | |
entity_prefix = main_entities[0] if main_entities else None | |
for me in main_entities: | |
candidates_set.add(me) | |
candidates_set.add(query) | |
if processed_query and processed_query != query: | |
candidates_set.add(processed_query) | |
if entity_prefix and keywords: | |
first_entity_lower = entity_prefix.lower() | |
for kw in keywords[:3]: | |
if kw not in first_entity_lower and len(kw) > 2: | |
candidates_set.add(f"{entity_prefix} {kw}") | |
keyword_combo_short = " ".join(k for k in keywords[:2] if k not in first_entity_lower and len(k)>2) | |
if keyword_combo_short: candidates_set.add(f"{entity_prefix} {keyword_combo_short}") | |
if len(main_entities) > 1: | |
candidates_set.add(" ".join(main_entities[:2])) | |
if keywords: | |
keyword_combo = " ".join(keywords[:2]) | |
if entity_prefix: | |
candidate_to_add = f"{entity_prefix} {keyword_combo}" | |
if not any(c.lower() == candidate_to_add.lower() for c in candidates_set): | |
candidates_set.add(candidate_to_add) | |
elif not main_entities: | |
candidates_set.add(keyword_combo) | |
ordered_candidates = [] | |
for me in main_entities: | |
if me not in ordered_candidates: ordered_candidates.append(me) | |
for c in list(candidates_set): | |
if c and c.strip() and c not in ordered_candidates: ordered_candidates.append(c) | |
print(f"Generated {len(ordered_candidates)} search candidates for Wikipedia page lookup (entity-prioritized): {ordered_candidates}") | |
return ordered_candidates | |
def _smart_wikipedia_search(self, query_text: str, main_entities_from_query: List[str], keywords_from_query: List[str], processed_query_text: str) -> List[Tuple[str, str]]: | |
candidates = self._generate_search_candidates(query_text, main_entities_from_query, keywords_from_query, processed_query_text) | |
found_pages_data: List[Tuple[str, str]] = [] | |
processed_page_titles: Set[str] = set() | |
for i, candidate_query in enumerate(candidates): | |
print(f"\nProcessing candidate {i+1}/{len(candidates)} for page: '{candidate_query}'") | |
page_object = None | |
final_page_title = None | |
is_candidate_entity_focused = any(me.lower() in candidate_query.lower() for me in main_entities_from_query) if main_entities_from_query else False | |
try: | |
try: | |
page_to_load = candidate_query | |
suggest_mode = True # Default to auto_suggest=True | |
if is_candidate_entity_focused and main_entities_from_query: | |
try: # Attempt precise match first for entity-focused candidates | |
temp_page = wikipedia.page(page_to_load, auto_suggest=False, redirect=True) | |
suggest_mode = False # Flag that precise match worked | |
except (wikipedia.exceptions.PageError, wikipedia.exceptions.DisambiguationError): | |
print(f" - auto_suggest=False failed for entity-focused '{page_to_load}', trying with auto_suggest=True.") | |
# Fallthrough to auto_suggest=True below if this fails | |
if suggest_mode: # If not attempted or failed with auto_suggest=False | |
temp_page = wikipedia.page(page_to_load, auto_suggest=True, redirect=True) | |
final_page_title = temp_page.title | |
if is_candidate_entity_focused and main_entities_from_query: | |
title_matches_main_entity = any(me.lower() in final_page_title.lower() for me in main_entities_from_query) | |
if not title_matches_main_entity: | |
print(f" ! Page title '{final_page_title}' (from entity-focused candidate '{candidate_query}') " | |
f"does not strongly match main query entities: {main_entities_from_query}. Skipping.") | |
continue | |
if final_page_title in processed_page_titles: | |
print(f" ~ Already processed '{final_page_title}'") | |
continue | |
page_object = temp_page | |
print(f" ✓ Direct hit/suggestion for '{candidate_query}' -> '{final_page_title}'") | |
except wikipedia.exceptions.PageError: | |
if i < max(2, len(candidates) // 3) : # Try Wikipedia search for a smaller, more promising subset of candidates | |
print(f" - Direct access failed for '{candidate_query}'. Trying Wikipedia search...") | |
search_results = wikipedia.search(candidate_query, results=1) | |
if not search_results: | |
print(f" - No Wikipedia search results for '{candidate_query}'.") | |
continue | |
search_result_title = search_results[0] | |
try: | |
temp_page = wikipedia.page(search_result_title, auto_suggest=False, redirect=True) # Search results are usually canonical | |
final_page_title = temp_page.title | |
if is_candidate_entity_focused and main_entities_from_query: # Still check against original intent | |
title_matches_main_entity = any(me.lower() in final_page_title.lower() for me in main_entities_from_query) | |
if not title_matches_main_entity: | |
print(f" ! Page title '{final_page_title}' (from search for '{candidate_query}' -> '{search_result_title}') " | |
f"does not strongly match main query entities: {main_entities_from_query}. Skipping.") | |
continue | |
if final_page_title in processed_page_titles: | |
print(f" ~ Already processed '{final_page_title}'") | |
continue | |
page_object = temp_page | |
print(f" ✓ Found via search '{candidate_query}' -> '{search_result_title}' -> '{final_page_title}'") | |
except (wikipedia.exceptions.PageError, wikipedia.exceptions.DisambiguationError) as e_sr: | |
print(f" ! Error/Disambiguation for search result '{search_result_title}': {e_sr}") | |
else: | |
print(f" - Direct access failed for '{candidate_query}'. Skipping further search for this lower priority candidate.") | |
except wikipedia.exceptions.DisambiguationError as de: | |
print(f" ! Disambiguation for '{candidate_query}'. Options: {de.options[:1]}") | |
if de.options: | |
option_title = de.options[0] | |
try: | |
temp_page = wikipedia.page(option_title, auto_suggest=False, redirect=True) | |
final_page_title = temp_page.title | |
if is_candidate_entity_focused and main_entities_from_query: # Check against original intent | |
title_matches_main_entity = any(me.lower() in final_page_title.lower() for me in main_entities_from_query) | |
if not title_matches_main_entity: | |
print(f" ! Page title '{final_page_title}' (from disamb. of '{candidate_query}' -> '{option_title}') " | |
f"does not strongly match main query entities: {main_entities_from_query}. Skipping.") | |
continue | |
if final_page_title in processed_page_titles: | |
print(f" ~ Already processed '{final_page_title}'") | |
continue | |
page_object = temp_page | |
print(f" ✓ Resolved disambiguation '{candidate_query}' -> '{option_title}' -> '{final_page_title}'") | |
except Exception as e_dis_opt: | |
print(f" ! Could not load disambiguation option '{option_title}': {e_dis_opt}") | |
if page_object and final_page_title and (final_page_title not in processed_page_titles): | |
# Extract main text | |
main_text = page_object.content | |
# Extract tables using BeautifulSoup | |
try: | |
html = page_object.html() | |
soup = BeautifulSoup(html, 'html.parser') | |
tables = soup.find_all('table') | |
table_texts = [] | |
for table in tables: | |
rows = table.find_all('tr') | |
table_lines = [] | |
for row in rows: | |
cells = row.find_all(['th', 'td']) | |
cell_texts = [cell.get_text(strip=True) for cell in cells] | |
if cell_texts: | |
table_lines.append(" | ".join(cell_texts)) | |
if table_lines: | |
table_text = "\n".join(table_lines) | |
table_texts.append(table_text) | |
except Exception as e: | |
print(f" !! Error extracting tables for '{final_page_title}': {e}") | |
table_texts = [] | |
# Combine main text and all table texts as separate chunks | |
all_text_chunks = [main_text] + table_texts | |
for chunk in all_text_chunks: | |
found_pages_data.append((chunk, final_page_title)) | |
processed_page_titles.add(final_page_title) | |
print(f" -> Added page '{final_page_title}'. Main text length: {len(main_text)} | Tables extracted: {len(table_texts)}") | |
except Exception as e: | |
print(f" !! Unexpected error processing candidate '{candidate_query}': {e}") | |
if not found_pages_data: print(f"\nCould not find any new, unique, entity-validated Wikipedia pages for query '{query_text}'.") | |
else: print(f"\nFound {len(found_pages_data)} unique, validated page(s) for processing.") | |
return found_pages_data | |
def _enhance_semantic_search(self, query: str, vector_store, main_entities: List[str], keywords: List[str], processed_query: str) -> List[Document]: | |
core_query_parts = set() | |
core_query_parts.add(query) | |
if processed_query != query: core_query_parts.add(processed_query) | |
if keywords: core_query_parts.add(" ".join(keywords[:2])) | |
section_phrases_templates = [] | |
lower_query_terms = set(query.lower().split()) | set(k.lower() for k in keywords) | |
section_keywords_map = { | |
"discography": ["discography", "list of studio albums", "studio album titles and years", "albums by year", "album release dates", "official albums", "complete album list", "albums published"], | |
"biography": ["biography", "life story", "career details", "background history"], | |
"filmography": ["filmography", "list of films", "movie appearances", "acting roles"], | |
} | |
for section_term_key, specific_phrases_list in section_keywords_map.items(): | |
# Check if the key (e.g., "discography") or any of its specific phrases (e.g. "list of studio albums") | |
# are mentioned or implied by the query terms. | |
if section_term_key in lower_query_terms or any(phrase_part in lower_query_terms for phrase_part in section_term_key.split()): | |
section_phrases_templates.extend(specific_phrases_list) | |
# Also check if phrases themselves are in query terms (e.g. query "list of albums by X") | |
for phrase in specific_phrases_list: | |
if phrase in query.lower(): # Check against original query for direct phrase matches | |
section_phrases_templates.extend(specific_phrases_list) # Add all related if one specific is hit | |
break | |
section_phrases_templates = list(dict.fromkeys(section_phrases_templates)) # Deduplicate | |
final_search_queries = set() | |
if main_entities: | |
entity_prefix = main_entities[0] | |
final_search_queries.add(entity_prefix) | |
for part in core_query_parts: | |
final_search_queries.add(f"{entity_prefix} {part}" if entity_prefix.lower() not in part.lower() else part) | |
for phrase_template in section_phrases_templates: | |
final_search_queries.add(f"{entity_prefix} {phrase_template}") | |
if "list of" in phrase_template or "history of" in phrase_template : | |
final_search_queries.add(f"{phrase_template} of {entity_prefix}") | |
else: | |
final_search_queries.update(core_query_parts) | |
final_search_queries.update(section_phrases_templates) | |
deduplicated_queries = list(dict.fromkeys(sq for sq in final_search_queries if sq and sq.strip())) | |
print(f"Generated {len(deduplicated_queries)} semantic search query variants (list-retrieval focused): {deduplicated_queries}") | |
all_results_docs: List[Document] = [] | |
seen_content_hashes: Set[int] = set() | |
k_to_fetch = self.top_k_results * self.semantic_search_candidate_multiplier | |
for search_query_variant in deduplicated_queries: | |
try: | |
results = vector_store.similarity_search_with_score(search_query_variant, k=k_to_fetch) | |
print(f" Semantic search variant '{search_query_variant}' (k={k_to_fetch}) -> {len(results)} raw chunk(s) with scores.") | |
for doc, score in results: # Assuming similarity_search_with_score returns (doc, score) | |
content_hash = hash(doc.page_content[:250]) # Slightly more for hash uniqueness | |
if content_hash not in seen_content_hashes: | |
seen_content_hashes.add(content_hash) | |
doc.metadata['retrieved_by_variant'] = search_query_variant | |
doc.metadata['retrieval_score'] = float(score) # Store score | |
all_results_docs.append(doc) | |
except Exception as e: | |
print(f" Error in semantic search for variant '{search_query_variant}': {e}") | |
# Sort all collected unique results by score (FAISS L2 distance is lower is better) | |
all_results_docs.sort(key=lambda x: x.metadata.get('retrieval_score', float('inf'))) | |
print(f"Collected and re-sorted {len(all_results_docs)} unique chunks from all semantic query variants.") | |
return all_results_docs[:self.top_k_results] | |
def _run(self, query: str) -> str: | |
if not self._nlp or not self._embedding_model or not self._text_splitter: | |
print("ERROR: WikipediaSearchToolWithFAISS components not initialized properly.") | |
return "Error: Wikipedia tool components not initialized properly. Please check server logs." | |
try: | |
print(f"\n--- Running {self.name} for query: '{query}' ---") | |
main_entities, keywords, processed_query = self._extract_entities_and_keywords(query) | |
print(f"Initial NLP Analysis - Main Entities: {main_entities}, Keywords: {keywords}, Processed Query: '{processed_query}'") | |
fetched_pages_data = self._smart_wikipedia_search(query, main_entities, keywords, processed_query) | |
if not fetched_pages_data: | |
return (f"Could not find any relevant, entity-validated Wikipedia pages for the query '{query}'. " | |
f"Main entities sought: {main_entities}") | |
all_page_titles = [title for _, title in fetched_pages_data] | |
print(f"\nSuccessfully fetched content for {len(fetched_pages_data)} Wikipedia page(s): {', '.join(all_page_titles)}") | |
all_documents: List[Document] = [] | |
for page_content, page_title in fetched_pages_data: | |
chunks = self._text_splitter.split_text(page_content) | |
if not chunks: | |
print(f"Warning: Could not split content from Wikipedia page '{page_title}' into chunks.") | |
continue | |
for i, chunk_text in enumerate(chunks): | |
all_documents.append(Document(page_content=chunk_text, metadata={ | |
"source_page_title": page_title, | |
"original_query": query, | |
"chunk_index": i # Add chunk index for potential debugging or ordering | |
})) | |
print(f"Split content from '{page_title}' into {len(chunks)} chunks.") | |
if not all_documents: | |
return (f"Could not process content into searchable chunks from the fetched Wikipedia pages " | |
f"({', '.join(all_page_titles)}) for query '{query}'.") | |
print(f"\nTotal document chunks from all pages: {len(all_documents)}") | |
print("Creating FAISS index from content of all fetched pages...") | |
try: | |
vector_store = FAISS.from_documents(all_documents, self._embedding_model) | |
print("FAISS index created successfully.") | |
except Exception as e: | |
return f"Error creating FAISS vector store: {e}" | |
print(f"\nPerforming enhanced semantic search across all collected content...") | |
try: | |
relevant_docs = self._enhance_semantic_search(query, vector_store, main_entities, keywords, processed_query) | |
except Exception as e: | |
return f"Error during semantic search: {e}" | |
if not relevant_docs: | |
return (f"No relevant information found within Wikipedia page(s) '{', '.join(list(dict.fromkeys(all_page_titles)))}' " | |
f"for your query '{query}' using entity-focused semantic search with list retrieval.") | |
unique_sources_in_results = list(dict.fromkeys([doc.metadata.get('source_page_title', 'Unknown Source') for doc in relevant_docs])) | |
result_header = (f"Found {len(relevant_docs)} relevant piece(s) of information from Wikipedia page(s) " | |
f"'{', '.join(unique_sources_in_results)}' for your query '{query}':\n") | |
nlp_summary = (f"[Original Query NLP: Main Entities: {', '.join(main_entities) if main_entities else 'None'}, " | |
f"Keywords: {', '.join(keywords[:5]) if keywords else 'None'}]\n\n") | |
result_details = [] | |
for i, doc in enumerate(relevant_docs): | |
source_info = doc.metadata.get('source_page_title', 'Unknown Source') | |
variant_info = doc.metadata.get('retrieved_by_variant', 'N/A') | |
score_info = doc.metadata.get('retrieval_score', 'N/A') | |
detail = (f"Result {i+1} (source: '{source_info}', score: {score_info:.4f})\n" | |
f"(Retrieved by: '{variant_info}')\n{doc.page_content}") | |
result_details.append(detail) | |
final_result = result_header + nlp_summary + "\n\n---\n\n".join(result_details) | |
print(f"\nReturning {len(relevant_docs)} relevant chunks from {len(set(all_page_titles))} source page(s).") | |
return final_result.strip() | |
except Exception as e: | |
import traceback | |
print(f"Unexpected error in {self.name}: {traceback.format_exc()}") | |
return f"An unexpected error occurred: {str(e)}" | |
# Example of creating the tool instance: | |
# wikipedia_tool_faiss = WikipediaSearchToolWithFAISS() | |
# To use this new tool in your agent, you would replace the old | |
# `wikipedia_tool` instance with `wikipedia_tool_faiss` in your `tools` list. | |
# For example: | |
# tools = [wikipedia_tool_faiss, search_tool] | |
# Create tool instances | |
#wikipedia_tool = WikipediaSearchTool() | |
# --- Define Call LLM function --- | |
# 3. Improved LLM call with memory management | |
def call_llm_with_memory_management(state: AgentState, llm_model) -> AgentState: # Added llm_model parameter | |
"""Call LLM with memory management, context truncation, and process response.""" | |
print("Running call_llm with memory management...") | |
# It's crucial to work with a copy of messages for modification within this step | |
# The final state["messages"] should reflect the full history + new response. | |
original_messages = list(state["messages"]) | |
messages_for_llm_processing = list(state["messages"]) # Use this for truncation logic | |
#ipdb.set_trace() | |
# --- Context Truncation Logic --- | |
system_message_content = None | |
# Check if the first message is a system message and preserve it | |
if messages_for_llm_processing and isinstance(messages_for_llm_processing[0], SystemMessage): | |
system_message_content = messages_for_llm_processing[0] | |
# Process only non-system messages for truncation count | |
regular_messages = messages_for_llm_processing[1:] | |
else: | |
regular_messages = messages_for_llm_processing | |
# Truncate context if too many messages (e.g., keep system + X most recent) | |
# Max 10 messages total (e.g. 1 system + 9 others) | |
max_regular_messages = 9 | |
if len(regular_messages) > max_regular_messages: | |
print(f"🔄 Truncating message count: {len(messages_for_llm_processing)} -> ~{max_regular_messages + (1 if system_message_content else 0)} messages") | |
regular_messages = regular_messages[-(max_regular_messages - 1):]  # Keep the most recent messages, leaving one slot of headroom under the cap
# Reconstruct messages for LLM call | |
messages_for_llm = [] | |
if system_message_content: | |
messages_for_llm.append(system_message_content) | |
messages_for_llm.extend(regular_messages) | |
# Further truncate based on character count (rough proxy for tokens) | |
total_chars = sum(len(str(msg.content)) for msg in messages_for_llm) | |
# Example character limit, adjust based on your model (e.g. 8k chars for ~4k tokens) | |
char_limit = 8000 | |
if total_chars > char_limit: | |
print(f"📏 Context too long ({total_chars} chars > {char_limit}), further truncation needed") | |
# More aggressive truncation of regular messages | |
chars_to_remove = total_chars - char_limit | |
temp_regular_messages = list(regular_messages) # copy | |
while sum(len(str(m.content)) for m in temp_regular_messages) > char_limit and temp_regular_messages: | |
if system_message_content and sum(len(str(m.content)) for m in temp_regular_messages) + len(str(system_message_content.content)) <= char_limit : | |
break # if removing one more makes it too small with system message | |
print(f"Removing message: {temp_regular_messages[0].type} - {temp_regular_messages[0].content[:50]}...") | |
temp_regular_messages.pop(0) | |
regular_messages = temp_regular_messages | |
messages_for_llm = [] # Rebuild | |
if system_message_content: | |
messages_for_llm.append(system_message_content) | |
messages_for_llm.extend(regular_messages) | |
print(f"Context truncated to {sum(len(str(m.content)) for m in messages_for_llm)} chars.") | |
new_state = state.copy() # Start with a copy of the input state | |
try: | |
if torch.cuda.is_available(): | |
torch.cuda.empty_cache() | |
print(f"🧹 Pre-LLM CUDA cache cleared. Memory: {torch.cuda.memory_allocated()/1024**2:.1f}MB") | |
print(f"Invoking LLM with {len(messages_for_llm)} messages.") | |
# This is where you call your actual LLM | |
formatted_input = "\n".join([f"[{msg.type.upper()}] {msg.content}" for msg in messages_for_llm]) | |
print(f"\n\nFormatted input for LLM:\n\n{formatted_input}") | |
llm_response_object = llm_model.invoke(formatted_input) | |
#ipdb.set_trace() | |
# The response_object is typically a BaseMessage subclass (e.g., AIMessage) | |
# or a string for simpler LLMs. Adapt as needed. | |
if isinstance(llm_response_object, BaseMessage): | |
ai_message_response = llm_response_object # It's already a message object | |
if not ai_message_response.content: # Ensure content is not empty | |
ai_message_response.content = "" | |
elif hasattr(llm_response_object, 'content'): # Some models might return a custom object with a content attribute | |
ai_message_response = AIMessage(content=str(llm_response_object.content) if llm_response_object.content is not None else "") | |
else: # Assuming it's a string for basic LLMs | |
ai_message_response = AIMessage(content=str(llm_response_object) if llm_response_object is not None else "") | |
print(f"LLM Response: {ai_message_response.content[:300]}...") # Print a snippet | |
# Append the LLM's response to the original full list of messages | |
final_messages = original_messages + [ai_message_response] | |
new_state["messages"] = final_messages | |
new_state.pop("done", None) # LLM responded, so not 'done' by default | |
except Exception as e: | |
print(f"LLM call failed: {e}") | |
error_message_content = f"LLM call failed with error: {str(e)}. Input consisted of {len(messages_for_llm)} messages." | |
if "out of memory" in str(e).lower(): | |
print("🚨 CUDA OOM detected during LLM call! Implementing emergency cleanup...") | |
error_message_content = f"LLM failed due to Out of Memory: {str(e)}." | |
try: | |
if torch.cuda.is_available(): | |
torch.cuda.empty_cache() | |
gc.collect() | |
except Exception as cleanup_e: | |
print(f"Emergency OOM cleanup failed: {cleanup_e}") | |
# Append an error message to the original message history | |
error_ai_message = AIMessage(content=error_message_content) | |
final_messages_on_error = original_messages + [error_ai_message] | |
new_state["messages"] = final_messages_on_error | |
new_state["done"] = True # Mark as done to prevent loops on LLM failure | |
finally: | |
try: | |
if torch.cuda.is_available(): | |
torch.cuda.empty_cache() | |
print(f"🧹 Post-LLM CUDA cache cleared. Memory: {torch.cuda.memory_allocated()/1024**2:.1f}MB") | |
except Exception: | |
pass # Avoid error in cleanup hiding the main error | |
return new_state | |
def parse_react_output(state: AgentState) -> AgentState: | |
print("Running parse_react_output (Action prioritized)...") | |
messages = state["messages"] | |
last_message = messages[-1] | |
new_state = state.copy() | |
# Only process AI messages (not system/user) | |
if not isinstance(last_message, AIMessage): | |
return new_state | |
content = last_message.content | |
# Remove any system prompt/instructions (if present in content) | |
# Assume that the actual AI output is after the last occurrence of "You are a general AI assistant" or similar system prompt marker | |
sys_prompt_pattern = r"(You are a general AI assistant.*?)(?=\n\n|$)" | |
content_wo_sys_prompt = re.sub(sys_prompt_pattern, '', content, flags=re.DOTALL | re.IGNORECASE).strip() | |
# Find the last occurrence of FINAL ANSWER or Action Input | |
final_answer_match = list(re.finditer(r"FINAL ANSWER:", content_wo_sys_prompt, re.IGNORECASE)) | |
action_input_match = list(re.finditer(r"Action Input:", content_wo_sys_prompt, re.IGNORECASE)) | |
# Helper: get the last match position and which it was | |
last_marker = None | |
last_pos = -1 | |
if final_answer_match: | |
last_fa = final_answer_match[-1] | |
last_marker = 'FINAL ANSWER' | |
last_pos = last_fa.start() | |
if action_input_match: | |
last_ai = action_input_match[-1] | |
if last_ai.start() > last_pos: | |
last_marker = 'Action Input' | |
last_pos = last_ai.start() | |
# If neither marker found, mark as done | |
if not last_marker: | |
print("No FINAL ANSWER or Action Input found in last AI output.") | |
new_state["done"] = True | |
return new_state | |
# Get the substring from the last marker to the end | |
last_section = content_wo_sys_prompt[last_pos:].strip() | |
# 2. If FINAL ANSWER is in the last part, end the process | |
if last_marker == 'FINAL ANSWER': | |
# Extract the answer after FINAL ANSWER: | |
answer = re.search(r"FINAL ANSWER:\s*(.+)", last_section, re.IGNORECASE) | |
final_answer_text = answer.group(1).strip() if answer else "" | |
updated_ai_message = AIMessage(content=f"FINAL ANSWER: {final_answer_text}", tool_calls=[]) | |
new_state["messages"] = messages[:-1] + [updated_ai_message] | |
new_state["done"] = True | |
print(f"FINAL ANSWER found at end: '{final_answer_text}'") | |
return new_state | |
# 3. If Action Input is in the last part, launch tool | |
if last_marker == 'Action Input': | |
# Try to extract the Action and Action Input for the last occurrence | |
action_match = list(re.finditer(r"Action:\s*([^\n]+)", last_section)) | |
action_input_match = list(re.finditer(r"Action Input:\s*([^\n]+)", last_section)) | |
if action_match and action_input_match: | |
tool_name = action_match[-1].group(1).strip() | |
tool_input_raw = action_input_match[-1].group(1).strip() | |
print(f"ReAct: Found Action: {tool_name}, Input: '{tool_input_raw}'") | |
# Format tool_args as in your original code (simplified here) | |
tool_args = {"query": tool_input_raw} | |
tool_call_id = str(uuid.uuid4()) | |
parsed_tool_calls = [{"name": tool_name, "args": tool_args, "id": tool_call_id}] | |
updated_ai_message = AIMessage(content=content, tool_calls=parsed_tool_calls) | |
new_state["messages"] = messages[:-1] + [updated_ai_message] | |
new_state.pop("done", None) | |
print(f"AIMessage updated with tool_calls: {parsed_tool_calls}") | |
return new_state | |
else: | |
print("Action Input found at end, but could not parse Action or Action Input.") | |
new_state["done"] = True | |
return new_state | |
# Fallback: mark as done | |
print("No actionable marker found at end of last AI output. Marking as done.") | |
new_state["done"] = True | |
return new_state | |
def download_youtube_video(url, output_dir='/tmp/video/', output_filename='downloaded_video.mp4'): | |
"""Download a YouTube video using yt-dlp""" | |
# Ensure the output directory exists | |
os.makedirs(output_dir, exist_ok=True) | |
# Delete all files in the output directory | |
files = glob.glob(os.path.join(output_dir, '*')) | |
for f in files: | |
try: | |
os.remove(f) | |
except Exception as e: | |
print(f"Error deleting {f}: {str(e)}") | |
# Set output path for yt-dlp | |
output_path = os.path.join(output_dir, output_filename) | |
try: | |
ydl_opts = { | |
'format': 'bestvideo[ext=mp4]+bestaudio[ext=m4a]/best[ext=mp4]/best', | |
'outtmpl': output_path, | |
'quiet': True, | |
'merge_output_format': 'mp4', # Ensures merged output is mp4 | |
'postprocessors': [{ | |
'key': 'FFmpegVideoConvertor', | |
'preferedformat': 'mp4', # Recode if needed | |
}] | |
} | |
with yt_dlp.YoutubeDL(ydl_opts) as ydl: | |
ydl.download([url]) | |
return output_path | |
except Exception as e: | |
print(f"Error downloading YouTube video: {str(e)}") | |
return None | |
def extract_frames(video_path, output_dir, frame_interval_seconds=10): | |
"""Extract frames from a video file at specified intervals""" | |
# Clean output directory before extracting new frames | |
if os.path.exists(output_dir): | |
for filename in os.listdir(output_dir): | |
file_path = os.path.join(output_dir, filename) | |
try: | |
if os.path.isfile(file_path) or os.path.islink(file_path): | |
os.unlink(file_path) | |
elif os.path.isdir(file_path): | |
shutil.rmtree(file_path) | |
except Exception as e: | |
print(f'Failed to delete {file_path}. Reason: {e}') | |
else: | |
os.makedirs(output_dir, exist_ok=True) | |
try: | |
cap = cv2.VideoCapture(video_path) | |
if not cap.isOpened(): | |
print("Error: Could not open video.") | |
return False | |
fps = cap.get(cv2.CAP_PROP_FPS) | |
frame_interval = int(fps * frame_interval_seconds) | |
count = 0 | |
saved = 0 | |
while True: | |
ret, frame = cap.read() | |
if not ret: | |
break | |
if count % frame_interval == 0: | |
frame_filename = os.path.join(output_dir, f"frame_{count:06d}.jpg") | |
cv2.imwrite(frame_filename, frame) | |
saved += 1 | |
count += 1 | |
cap.release() | |
print(f"Extracted {saved} frames.") | |
return saved > 0 | |
except Exception as e: | |
print(f"Exception during frame extraction: {e}") | |
return False | |
def answer_question_on_frame(image_path, question):
    """Answer a question about a single video frame using BLIP visual question answering."""
    try:
        # Note: the BLIP processor and model are reloaded on every call; caching them
        # (see the sketch below) avoids repeating this work for each frame.
        vqa_model_name = "Salesforce/blip-vqa-base"
        processor_vqa = BlipProcessor.from_pretrained(vqa_model_name)
        model_vqa = BlipForQuestionAnswering.from_pretrained(vqa_model_name).to('cpu')
        device = "cpu"
        image = Image.open(image_path).convert('RGB')
        inputs = processor_vqa(image, question, return_tensors="pt").to(device)
        out = model_vqa.generate(**inputs)
        answer = processor_vqa.decode(out[0], skip_special_tokens=True)
        return answer
    except Exception as e:
        print(f"Error processing frame {image_path}: {str(e)}")
        return "Error processing this frame"
def answer_video_question(frames_dir, question): | |
"""Answer a question about a video by analyzing extracted frames""" | |
valid_exts = ('.jpg', '.jpeg', '.png') | |
# Check if directory exists | |
if not os.path.exists(frames_dir): | |
return { | |
"most_common_answer": "No frames found to analyze.", | |
"all_answers": [], | |
"answer_counts": Counter() | |
} | |
frame_files = [os.path.join(frames_dir, f) for f in os.listdir(frames_dir) | |
if f.lower().endswith(valid_exts)] | |
# Sort frames properly by number | |
def get_frame_number(filename): | |
match = re.search(r'(\d+)', os.path.basename(filename)) | |
return int(match.group(1)) if match else 0 | |
frame_files = sorted(frame_files, key=get_frame_number) | |
if not frame_files: | |
return { | |
"most_common_answer": "No valid image frames found.", | |
"all_answers": [], | |
"answer_counts": Counter() | |
} | |
answers = [] | |
for frame_path in frame_files: | |
try: | |
ans = answer_question_on_frame(frame_path, question) | |
answers.append(ans) | |
print(f"Processed frame: {os.path.basename(frame_path)}, Answer: {ans}") | |
except Exception as e: | |
print(f"Error processing frame {frame_path}: {str(e)}") | |
if not answers: | |
return { | |
"most_common_answer": "Could not analyze any frames successfully.", | |
"all_answers": [], | |
"answer_counts": Counter() | |
} | |
counted = Counter(answers) | |
most_common_answer, freq = counted.most_common(1)[0] | |
return { | |
"most_common_answer": most_common_answer, | |
"all_answers": answers, | |
"answer_counts": counted | |
} | |
class YoutubeScreenshotQA(BaseTool): | |
name: str = "youtube_screenshot_qa" | |
description: str = ( | |
"Downloads a YouTube video, extracts screenshots at intervals, " | |
"and answers a question about the video based on the screenshots. " | |
"Input should be a dict with keys: 'youtube_url' and 'question'." | |
"Example input: {'youtube_url': 'https://www.youtube.com/watch?v=L1vXCYZAYYM', 'question': 'What is the highest number of bird species on camera simultaneously?'}" | |
) | |
frame_interval_seconds: int = 10 # Can be parameterized if needed | |
def _run(self, input_data: Dict[str, Any]) -> str: | |
youtube_url = input_data.get("youtube_url") | |
question = input_data.get("question") | |
if not youtube_url or not question: | |
return "Error: Input must include 'youtube_url' and 'question'." | |
# Step 1: Download the video | |
video_dir = '/tmp/video/' | |
video_filename = 'downloaded_video.mp4' | |
print(f"Downloading YouTube video from {youtube_url}...") | |
video_path = download_youtube_video(youtube_url, output_dir=video_dir, output_filename=video_filename) | |
if not video_path or not os.path.exists(video_path): | |
return "Error: Failed to download the YouTube video." | |
# Step 2: Extract frames | |
frames_dir = '/tmp/video_frames/' | |
print(f"Extracting frames from {video_path} every {self.frame_interval_seconds} seconds...") | |
success = extract_frames(video_path, frames_dir, frame_interval_seconds=self.frame_interval_seconds) | |
if not success: | |
return "Error: Failed to extract frames from the video." | |
# Step 3: Analyze frames and answer question | |
print(f"Answering question about the video frames...") | |
answer_result = answer_video_question(frames_dir, question) | |
if not answer_result or not answer_result.get("most_common_answer"): | |
return "Error: Could not analyze video frames to answer the question." | |
# Format the result | |
most_common = answer_result["most_common_answer"] | |
all_answers = answer_result["all_answers"] | |
counts = answer_result["answer_counts"] | |
result = ( | |
f"Most common answer: {most_common}\n" | |
f"All answers: {all_answers}\n" | |
f"Answer counts: {dict(counts)}" | |
) | |
return result | |
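# Example of creating the tool instance (illustrative; _run is called directly
# because this tool expects a single dict argument):
# youtube_tool = YoutubeScreenshotQA()
# print(youtube_tool._run({
#     "youtube_url": "https://www.youtube.com/watch?v=L1vXCYZAYYM",
#     "question": "What is the highest number of bird species on camera simultaneously?",
# }))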
def tools_condition_with_logging(state: AgentState): | |
""" | |
Custom tools condition function that checks if the last message contains tool calls | |
in the Thought/Action/Action Input format and logs the transition decision. | |
Args: | |
state (AgentState): The current state containing messages | |
Returns: | |
str: "tools" if tool calls are present, "__end__" otherwise | |
""" | |
import re | |
# Ensure we have messages in the state | |
if not state.get("messages") or len(state["messages"]) == 0: | |
print("❌ No messages found in state, ending conversation") | |
return "__end__" | |
# Get the last message | |
last_message = state["messages"][-1] | |
# Get message content | |
content = "" | |
if hasattr(last_message, 'content'): | |
content = str(last_message.content) | |
elif isinstance(last_message, dict) and 'content' in last_message: | |
content = str(last_message['content']) | |
else: | |
print("❌ No content found in last message, ending conversation") | |
return "__end__" | |
print(f"🔍 Analyzing message content: {content[:200]}...") | |
# Check for Thought/Action/Action Input format | |
has_tool_calls = False | |
# Pattern to match the format: | |
# Thought: <thought> | |
# Action: <tool_name> | |
# Action Input: <input> | |
thought_action_pattern = re.compile( | |
r'Thought:\s*(.*?)\n\s*Action:\s*(.*?)\n\s*Action Input:\s*(.*?)(?:\n|$)', | |
re.DOTALL | re.IGNORECASE | |
) | |
# Also check for just Action/Action Input without Thought | |
action_only_pattern = re.compile( | |
r'Action:\s*(.*?)\n\s*Action Input:\s*(.*?)(?:\n|$)', | |
re.DOTALL | re.IGNORECASE | |
) | |
# Look for the complete format first | |
match = thought_action_pattern.search(content) | |
if not match: | |
# Try the action-only format | |
match = action_only_pattern.search(content) | |
if match: | |
thought = "No thought provided" | |
action = match.group(1).strip() | |
action_input = match.group(2).strip() | |
else: | |
action = None | |
action_input = None | |
thought = None | |
else: | |
thought = match.group(1).strip() | |
action = match.group(2).strip() | |
action_input = match.group(3).strip() | |
if match and action: | |
has_tool_calls = True | |
print(f"🔧 Found tool call format:") | |
print(f" Thought: {thought}") | |
print(f" Action: {action}") | |
print(f" Action Input: {action_input}") | |
# Map common tool names to your actual tools | |
tool_mappings = { | |
'wikipedia_semantic_search': 'wikipedia_tool', | |
'wikipedia': 'wikipedia_tool', | |
'search': 'search_tool', | |
'duckduckgo_search': 'search_tool', | |
'web_search': 'search_tool', | |
'youtube_screenshot_qa_tool': 'youtube_tool', | |
'youtube': 'youtube_tool', | |
} | |
# Normalize the action name | |
normalized_action = action.lower().strip() | |
# Store the parsed tool call information in the state for the tools node to use | |
if 'parsed_tool_calls' not in state: | |
state['parsed_tool_calls'] = [] | |
tool_call_info = { | |
'thought': thought, | |
'action': action, | |
'action_input': action_input, | |
'normalized_action': normalized_action, | |
'tool_mapping': tool_mappings.get(normalized_action, normalized_action) | |
} | |
state['parsed_tool_calls'].append(tool_call_info) | |
print(f"🚀 Added tool call to state: {tool_call_info}") | |
# Don't execute tools here - let call_tool handle execution | |
# Just store the parsed information for call_tool to use | |
# Also check for standalone tool mentions (fallback) | |
if not has_tool_calls: | |
# Check for tool names mentioned in content | |
tool_keywords = [ | |
'wikipedia_semantic_search', 'wikipedia', 'search', 'duckduckgo', | |
'youtube_screenshot_qa_tool', 'youtube', 'web search' | |
] | |
content_lower = content.lower() | |
for keyword in tool_keywords: | |
if keyword in content_lower: | |
print(f"🔧 Found tool keyword '{keyword}' in content (fallback detection)") | |
has_tool_calls = True | |
break | |
if has_tool_calls: | |
print("🔧 Tool calls detected, transitioning to tools...") | |
return "tools" | |
else: | |
print("✅ No tool calls found, ending conversation") | |
return "__end__" | |
# 2. Improved call_tool with memory management | |
def call_tool_with_memory_management(state: AgentState) -> AgentState: | |
"""Process tool calls with memory management.""" | |
print("Running call_tool with memory management...") | |
# Clear CUDA cache before processing | |
try: | |
import torch | |
if torch.cuda.is_available(): | |
torch.cuda.empty_cache() | |
print(f"🧹 Cleared CUDA cache. Memory: {torch.cuda.memory_allocated()/1024**2:.1f}MB") | |
except: | |
pass | |
# Check if we have parsed tool calls from the condition function | |
if 'parsed_tool_calls' in state and state['parsed_tool_calls']: | |
return execute_parsed_tool_calls(state) | |
# Fallback to original OpenAI-style tool calls handling | |
messages = state["messages"] | |
last_message = messages[-1] | |
if not hasattr(last_message, "tool_calls") or not last_message.tool_calls: | |
print("No tool calls found in last message") | |
return state | |
# Copy the messages to avoid mutating the original list | |
new_messages = list(messages) | |
print(f"Processing {len(last_message.tool_calls)} tool calls") | |
for i, tool_call in enumerate(last_message.tool_calls): | |
print(f"Processing tool call {i+1}: {tool_call['name'] if isinstance(tool_call, dict) else tool_call.name}") | |
# Handle both dict and object-style tool calls | |
if isinstance(tool_call, dict): | |
tool_name = tool_call.get("name", "") | |
args = tool_call.get("args", {}) | |
tool_call_id = tool_call.get("id", str(uuid.uuid4())) | |
else: | |
tool_name = getattr(tool_call, "name", "") | |
args = getattr(tool_call, "args", {}) | |
tool_call_id = getattr(tool_call, "id", str(uuid.uuid4())) | |
# Find the matching tool | |
selected_tool = None | |
for tool in tools: | |
if tool.name.lower() == tool_name.lower(): | |
selected_tool = tool | |
break | |
if not selected_tool: | |
tool_result = f"Error: Tool '{tool_name}' not found. Available tools: {', '.join(t.name for t in tools)}" | |
print(f"Tool not found: {tool_name}") | |
else: | |
try: | |
# Extract query | |
if isinstance(args, dict) and "query" in args: | |
query = args["query"] | |
else: | |
query = str(args) if args else "" | |
print(f"Executing {tool_name} with query: {query[:100]}...") | |
tool_result = selected_tool.run(query) | |
# Aggressive truncation to prevent memory issues | |
max_length = 3000 if "wikipedia" in tool_name.lower() else 2000 | |
if len(tool_result) > max_length: | |
tool_result = tool_result[:max_length] + f"... [Result truncated from {len(tool_result)} to {max_length} chars to prevent memory issues]" | |
print(f"📄 Truncated result to {max_length} characters") | |
print(f"Tool result length: {len(tool_result)} characters") | |
except Exception as e: | |
tool_result = f"Error executing tool '{tool_name}': {str(e)}" | |
print(f"Tool execution error: {e}") | |
# Create tool message | |
tool_message = ToolMessage( | |
content=tool_result, | |
name=tool_name, | |
tool_call_id=tool_call_id | |
) | |
new_messages.append(tool_message) | |
print(f"Added tool message for {tool_name}") | |
# Update the state | |
new_state = state.copy() | |
new_state["messages"] = new_messages | |
# Clear CUDA cache after processing | |
try: | |
import torch | |
if torch.cuda.is_available(): | |
torch.cuda.empty_cache() | |
except: | |
pass | |
return new_state | |
def execute_parsed_tool_calls(state: AgentState):
    """
    Execute tool calls that were parsed from the Thought/Action/Action Input format.
    Called by call_tool when parsed_tool_calls are present in the state.

    Args:
        state (AgentState): The current state containing parsed tool calls.

    Returns:
        AgentState: Updated state with tool results appended as observation messages.
    """
    # Map the action names the LLM may emit onto the tool instances in the global
    # `tools` list (the instances are created in run_agent).
    tool_name_mappings = {
        'wikipedia_semantic_search': 'wikipedia_tool',
        'wikipedia': 'wikipedia_tool',
        'search': 'enhanced_search',
        'duckduckgo_search': 'enhanced_search',
        'web_search': 'enhanced_search',
        'enhanced_search': 'enhanced_search',
        'youtube_screenshot_qa_tool': 'youtube_tool',
        'youtube': 'youtube_tool',
    }
    # Build a lookup keyed by each tool's declared name ...
    tools_by_name = {}
    for tool_instance in tools:
        tools_by_name[tool_instance.name.lower()] = tool_instance
        # ... and also by a class-name-derived alias for flexibility
        class_name = tool_instance.__class__.__name__.lower()
        if 'wikipedia' in class_name:
            tools_by_name['wikipedia_tool'] = tool_instance
        elif 'search' in class_name or 'duck' in class_name:
            tools_by_name['search_tool'] = tool_instance
        elif 'youtube' in class_name:
            tools_by_name['youtube_tool'] = tool_instance
    # Copy messages to avoid mutation during iteration
    new_messages = list(state["messages"])
    for tool_call in state['parsed_tool_calls']:
        action = tool_call['action']
        action_input = tool_call['action_input']
        thought = tool_call['thought']  # retained for debugging/traceability
        normalized_action = tool_call['normalized_action']
        print(f"🚀 Executing tool: {action} with input: {action_input}")
        # Find the tool instance: try a direct name match first, then the mapped alias
        tool_instance = None
        if normalized_action in tools_by_name:
            tool_instance = tools_by_name[normalized_action]
        elif normalized_action in tool_name_mappings:
            mapped_name = tool_name_mappings[normalized_action]
            if mapped_name in tools_by_name:
                tool_instance = tools_by_name[mapped_name]
        if tool_instance:
            try:
                result = tool_instance.run(action_input)
                if len(result) > 6000:
                    result = result[:6000] + "... [Result truncated due to length]"
                # Feed the result back as an Observation message in the ReAct format
                observation_message = AIMessage(content=f"Observation: {result}")
                new_messages.append(observation_message)
                print(f"✅ Tool '{action}' executed successfully")
            except Exception as e:
                print(f"❌ Error executing tool '{action}': {e}")
                error_message = AIMessage(content=f"Observation: Error executing '{action}': {str(e)}")
                new_messages.append(error_message)
        else:
            print(f"❌ Tool '{action}' not found in available tools")
            available_tool_names = list(tools_by_name.keys())
            error_message = AIMessage(content=f"Observation: Tool '{action}' not found. Available tools: {', '.join(available_tool_names)}")
            new_messages.append(error_message)
    # Update state with the new messages and clear the parsed tool calls
    state["messages"] = new_messages
    state['parsed_tool_calls'] = []
    return state
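# Illustrative sketch (comments only, not executed): entries in
# state['parsed_tool_calls'] are assumed to carry the fields consumed above, e.g.
#   {"thought": "I should look this up.",
#    "action": "Enhanced_Search",
#    "action_input": "highest mountain in Japan",
#    "normalized_action": "enhanced_search"}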
# Loop detection: decide whether the agent should keep going or stop
def should_continue(state: AgentState) -> str:
    """Determine whether the agent should continue or end."""
    print("Running should_continue....")
    messages = state["messages"]
    # Check whether we are already done
    if state.get("done", False):
        return "end"
    # Prevent infinite loops by limiting the number of tool calls
    tool_call_count = sum(1 for msg in messages if hasattr(msg, 'tool_calls') and msg.tool_calls)
    if tool_call_count >= 3:  # At most 3 tool calls per conversation
        print(f"⚠️ Stopping: Too many tool calls ({tool_call_count})")
        return "end"
    # Check for repeated tool calls with the same query
    recent_tool_calls = []
    for msg in messages[-6:]:  # Only inspect the last 6 messages
        if hasattr(msg, 'tool_calls') and msg.tool_calls:
            for tool_call in msg.tool_calls:
                if isinstance(tool_call, dict):
                    recent_tool_calls.append((tool_call.get('name'), str(tool_call.get('args', {}))))
    if len(recent_tool_calls) >= 2 and recent_tool_calls[-1] == recent_tool_calls[-2]:
        print("⚠️ Stopping: Repeated tool call detected")
        return "end"
    # Cap the message count to prevent runaway conversations
    if len(messages) > 15:
        print(f"⚠️ Stopping: Too many messages ({len(messages)})")
        return "end"
    return "continue"
def route_after_parse_react(state: AgentState) -> str:
    """Determines the next step after parsing LLM output, prioritizing the end state."""
    if state.get("done", False):  # parse_react_output decided we are done
        return "end_processing"
    # Otherwise, check for tool calls in the last message.
    # Ensure the messages list and last message exist before checking tool_calls.
    messages = state.get("messages", [])
    if messages:
        last_message = messages[-1]
        if hasattr(last_message, 'tool_calls') and last_message.tool_calls:
            return "call_tool"
    return "call_llm"
# The tool instances used to be created at module level (wikipedia_tool, search_tool,
# youtube_screenshot_qa_tool); they are now created as the global instances
# WIKIPEDIA_TOOL, SEARCH_TOOL and YOUTUBE_TOOL inside run_agent and combined into the
# `tools` list there.
# --- Graph Construction ---
def create_memory_safe_workflow():
    """Create a workflow with memory management and loop prevention."""
    # The LLM is initialized here; if it needs to be released and re-created between
    # runs (as attempted in run_agent), consider passing it in or managing its
    # lifecycle explicitly.
    hf_pipe = create_llm_pipeline()
    llm = HuggingFacePipeline(pipeline=hf_pipe)
    # vqa_model_name = "Salesforce/blip-vqa-base"  # Not used by the graph logic
    # processor_vqa = BlipProcessor.from_pretrained(vqa_model_name)  # Not used
    # model_vqa = BlipForQuestionAnswering.from_pretrained(vqa_model_name).to('cpu')  # Not used
    workflow = StateGraph(AgentState)
    # Bind the LLM to the memory-managed LLM-call function
    bound_call_llm = partial(call_llm_with_memory_management, llm_model=llm)
    # Add nodes (memory-safe versions)
    workflow.add_node("call_llm", bound_call_llm)
    workflow.add_node("parse_react_output", parse_react_output)
    workflow.add_node("call_tool", call_tool_with_memory_management)  # runs tools only; no llm needed
    # Set the entry point
    workflow.set_entry_point("call_llm")
    # After the LLM call, either keep going or end
    workflow.add_conditional_edges(
        "call_llm",
        should_continue,
        {
            "continue": "parse_react_output",
            "end": END
        }
    )
    # After parsing the ReAct output, run a tool, call the LLM again, or finish
    workflow.add_conditional_edges(
        "parse_react_output",
        route_after_parse_react,
        {
            "call_tool": "call_tool",
            "call_llm": "call_llm",
            "end_processing": END
        }
    )
    # Tool results always go back to the LLM
    workflow.add_edge("call_tool", "call_llm")
    return workflow.compile()
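# Resulting control flow: call_llm -> (should_continue) -> parse_react_output ->
# (route_after_parse_react) -> call_tool -> call_llm ... repeating until "done" is
# set by the parser or one of the loop-prevention limits in should_continue fires.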
# --- Run the Agent ---
def run_agent(myagent, state: AgentState):
    """
    Initialize the agent with a proper system message and a formatted query, then run it.
    """
    global WIKIPEDIA_TOOL, SEARCH_TOOL, YOUTUBE_TOOL, tools
    # Create the tool instances (once per run) and expose them as globals
    WIKIPEDIA_TOOL = WikipediaSearchToolWithFAISS()
    SEARCH_TOOL = EnhancedDuckDuckGoSearchTool(max_results=3, max_chars_per_page=3000)
    YOUTUBE_TOOL = YoutubeScreenshotQA()
    tools = [WIKIPEDIA_TOOL, SEARCH_TOOL, YOUTUBE_TOOL]
    # Build a fresh system message each time
    formatted_tools_description = render_text_description(tools)
    current_date_str = datetime.now().strftime("%Y-%m-%d")  # not currently interpolated into the prompt
    system_content = f"""You are a general AI assistant with access to these tools:
{formatted_tools_description}
If you need the most current information as of 2025, use enhanced_search.
If you need to do in-depth research, use wikipedia_semantic_search_all_candidates_strong_entity_priority_list_retrieval.
If you can answer the question confidently, do so directly.
If the question seems like gibberish (not English), try reversing the entire question and re-reading it.
If you need more information, use a tool.
(Think through the problem step by step.)
When using a tool, follow this format:
Thought: <your thought>
Action: <tool_name>
Action Input: <tool_input>
Only use the tools listed above for the Action: step. Do not invent new tool names or actions. If you need to reason, do so in the Thought: step. After using a tool, process its output in your Thought: step, not as an Action.
Report your thoughts, and finish your answer with the following template: FINAL ANSWER: [YOUR FINAL ANSWER].
YOUR FINAL ANSWER should be a number OR as few words as possible OR a comma-separated list of numbers and/or strings.
If you are asked for a number, don't use commas to write your number nor units such as $ or percent signs unless specified otherwise.
If you are asked for a string, don't use articles or abbreviations (e.g. for cities), and write digits in plain text unless specified otherwise.
If you are asked for a comma-separated list, apply the above rules depending on whether each element of the list is a number or a string.
Do not provide disclaimers.
Do not provide supporting details.
"""
    # Get the user question from the AgentState
    query = state['question']
    # Detect YouTube links so the URL can be carried along in the state
    yt_pattern = r"(https?://)?(www\.)?(youtube\.com|youtu\.be)/[^\s]+"
    has_youtube = re.search(yt_pattern, query) is not None
    if has_youtube:
        # Store the extracted YouTube URL in the state
        url_match = re.search(r"(https?://[^\s]+)", query)
        if url_match:
            state['youtube_url'] = url_match.group(0)
    # Pass the user query through unchanged (kept as a separate variable in case
    # extra guidance needs to be prepended later)
    formatted_query = query
    # Initialize the agent state with properly typed messages and done=False
    system_message = SystemMessage(content=system_content)
    human_message = HumanMessage(content=formatted_query)
    state['messages'] = [system_message, human_message]
    state["done"] = False
    # Run the graph
    result = myagent.invoke(state)
    # If a FINAL ANSWER was given (i.e., the workflow ended), release GPU memory.
    # (Deleting and re-creating the llm/hf_pipe and VQA models here was tried
    # previously; they are currently kept alive between runs.)
    if result.get("done"):
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()
        gc.collect()
        print("Released GPU memory after FINAL ANSWER.")
    # Extract and return just the messages for cleaner output
    return result["messages"]