# Standard Library import os import re import tempfile import string import glob import shutil import gc import uuid import signal from datetime import datetime from io import BytesIO from contextlib import contextmanager from langchain_huggingface import HuggingFacePipeline from typing import TypedDict, List, Optional, Dict, Any, Annotated, Literal, Union, Tuple, Set import time from collections import Counter # Third-Party Packages import cv2 import requests import wikipedia import spacy import yt_dlp import librosa from PIL import Image from bs4 import BeautifulSoup from duckduckgo_search import DDGS from sentence_transformers import SentenceTransformer from transformers import BlipProcessor, BlipForQuestionAnswering, pipeline # LangChain Ecosystem from langchain.docstore.document import Document from langchain.prompts import PromptTemplate from langchain_community.document_loaders import WikipediaLoader from langchain_huggingface import HuggingFaceEndpoint from langchain_community.retrievers import BM25Retriever from langchain.vectorstores import FAISS from langchain.embeddings import HuggingFaceEmbeddings from langchain.text_splitter import RecursiveCharacterTextSplitter from langchain.schema import Document from langchain_community.tools import DuckDuckGoSearchRun from langchain_core.messages import AnyMessage, HumanMessage, AIMessage, BaseMessage, SystemMessage, ToolMessage from langchain_core.tools import BaseTool, StructuredTool, tool, render_text_description from langchain_core.documents import Document # LangGraph from langgraph.graph import START, END, StateGraph from langgraph.prebuilt import ToolNode, tools_condition # PyTorch import torch from functools import partial from transformers import pipeline # Additional Utilities from datetime import datetime from urllib.parse import urljoin, urlparse import logging nlp = spacy.load("en_core_web_sm") logger = logging.getLogger(__name__) # --- Model Configuration --- def create_llm_pipeline(): #model_id = "meta-llama/Llama-2-13b-chat-hf" #model_id = "meta-llama/Llama-3.3-70B-Instruct" #model_id = "mistralai/Mistral-Small-24B-Base-2501" model_id = "mistralai/Mistral-7B-Instruct-v0.3" #model_id = "Qwen/Qwen2-7B-Instruct" return pipeline( "text-generation", model=model_id, device_map="auto", torch_dtype=torch.float16, max_new_tokens=1024, temperature=0.1 ) # Define file extension sets for each category PICTURE_EXTENSIONS = {'.jpg', '.jpeg', '.png', '.gif', '.bmp', '.tiff', '.webp'} AUDIO_EXTENSIONS = {'.mp3', '.wav', '.aac', '.flac', '.ogg', '.m4a', '.wma'} CODE_EXTENSIONS = {'.py', '.js', '.java', '.cpp', '.c', '.cs', '.rb', '.go', '.php', '.html', '.css', '.ts'} SPREADSHEET_EXTENSIONS = { '.xls', '.xlsx', '.xlsm', '.xlsb', '.xlt', '.xltx', '.xltm', '.ods', '.ots', '.csv', '.tsv', '.sxc', '.stc', '.dif', '.gsheet', '.numbers', '.numbers-tef', '.nmbtemplate', '.fods', '.123', '.wk1', '.wk2', '.wks', '.wku', '.wr1', '.gnumeric', '.gnm', '.xml', '.pmvx', '.pmdx', '.pmv', '.uos', '.txt' } def get_file_type(filename: str) -> str: if not filename or '.' not in filename or filename == '': return '' ext = filename.lower().rsplit('.', 1)[-1] dot_ext = f'.{ext}' if dot_ext in PICTURE_EXTENSIONS: return 'picture' elif dot_ext in AUDIO_EXTENSIONS: return 'audio' elif dot_ext in CODE_EXTENSIONS: return 'code' elif dot_ext in SPREADSHEET_EXTENSIONS: return 'spreadsheet' else: return 'unknown' def write_bytes_to_temp_dir(file_bytes: bytes, file_name: str) -> str: """ Writes bytes to a file in the system temporary directory using the provided file_name. Returns the full path to the saved file. The file will persist until manually deleted or the OS cleans the temp directory. """ temp_dir = "/tmp" # /tmp is always writable in Hugging Face Spaces os.makedirs(temp_dir, exist_ok=True) file_path = os.path.join(temp_dir, file_name) with open(file_path, 'wb') as f: f.write(file_bytes) print(f"File written to: {file_path}") return file_path def extract_final_answer(text: str) -> str: """ Returns the substring starting from the last occurrence of 'FINAL ANSWER:' (case-insensitive) to the end of the string, with any trailing punctuation removed. If not found, returns an empty string. """ marker = "FINAL ANSWER:" idx = text.lower().rfind(marker.lower()) if idx == -1: return "" result = text[idx:].strip() # Remove trailing punctuation return result.rstrip(string.punctuation + " ") class EnhancedDuckDuckGoSearchTool(BaseTool): name: str = "enhanced_search" description: str = ( "Performs a DuckDuckGo web search and retrieves actual content from the top web results. " "Input should be a search query string. " "Returns search results with extracted content from web pages, making it much more useful for answering questions. " "Use this tool when you need up-to-date information, details about current events, or when other tools do not provide sufficient or recent answers. " "Ideal for topics that require the latest news, recent developments, or information not covered in static sources." ) max_results: int = 3 max_chars_per_page: int = 3000 session: Any = None # Now it's optional and defaults to None # Use model_post_init for initialization logic in Pydantic v2+ def model_post_init(self, __context: Any) -> None: super().model_post_init(__context) # Initialize HTTP session here self.session = requests.Session() self.session.headers.update({ 'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36', 'Accept': 'text/html,application/xhtml+xml,application/xml;q=0.9,image/webp,*/*;q=0.8', 'Accept-Language': 'en-US,en;q=0.5', 'Accept-Encoding': 'gzip, deflate', 'Connection': 'keep-alive', 'Upgrade-Insecure-Requests': '1', }) def _search_duckduckgo(self, query: str) -> List[Dict]: """Perform DuckDuckGo search and return results.""" try: with DDGS() as ddgs: results = list(ddgs.text(query, max_results=self.max_results)) return results except Exception as e: logger.error(f"DuckDuckGo search failed: {e}") return [] def _extract_content_from_url(self, url: str, timeout: int = 10) -> Optional[str]: """Extract clean text content from a web page.""" try: # Skip certain file types if any(url.lower().endswith(ext) for ext in ['.pdf', '.doc', '.docx', '.xls', '.xlsx', '.ppt', '.pptx']): return "Content type not supported for extraction" response = self.session.get(url, timeout=timeout, allow_redirects=True) response.raise_for_status() # Check content type content_type = response.headers.get('content-type', '').lower() if 'text/html' not in content_type: return "Non-HTML content detected" soup = BeautifulSoup(response.content, 'html.parser') # Remove script and style elements for script in soup(["script", "style", "nav", "header", "footer", "aside", "form"]): script.decompose() # Try to find main content areas main_content = None for selector in ['main', 'article', '.content', '#content', '.post', '.entry']: main_content = soup.select_one(selector) if main_content: break if not main_content: main_content = soup.find('body') or soup # Extract text text = main_content.get_text(separator='\n', strip=True) # Clean up the text lines = [line.strip() for line in text.split('\n') if line.strip()] text = '\n'.join(lines) # Remove excessive whitespace text = re.sub(r'\n{3,}', '\n\n', text) text = re.sub(r' {2,}', ' ', text) # Truncate if too long if len(text) > self.max_chars_per_page: text = text[:self.max_chars_per_page] + "\n[Content truncated...]" return text except requests.exceptions.Timeout: return "Page loading timed out" except requests.exceptions.RequestException as e: return f"Failed to retrieve page: {str(e)}" except Exception as e: logger.error(f"Content extraction failed for {url}: {e}") return "Failed to extract content from page" def _format_search_result(self, result: Dict, content: str) -> str: """Format a single search result with its content.""" title = result.get('title', 'No title') url = result.get('href', 'No URL') snippet = result.get('body', 'No snippet') formatted = f""" ๐Ÿ” **{title}** URL: {url} Snippet: {snippet} ๐Ÿ“„ **Page Content:** {content} --- """ return formatted def run(self, query: str) -> str: """Execute the enhanced search.""" if not query or not query.strip(): return "Please provide a search query." query = query.strip() logger.info(f"Searching for: {query}") # Perform DuckDuckGo search search_results = self._search_duckduckgo(query) if not search_results: return f"No search results found for query: {query}" # Process each result and extract content enhanced_results = [] processed_count = 0 for i, result in enumerate(search_results[:self.max_results]): url = result.get('href', '') if not url: continue logger.info(f"Processing result {i+1}: {url}") # Extract content from the page content = self._extract_content_from_url(url) if content and len(content.strip()) > 50: # Only include results with substantial content formatted_result = self._format_search_result(result, content) enhanced_results.append(formatted_result) processed_count += 1 # Small delay to be respectful to servers time.sleep(0.5) if not enhanced_results: return f"Search completed but no content could be extracted from the pages for query: {query}" # Compile final response response = f"""๐Ÿ” **Enhanced Search Results for: "{query}"** Found {len(search_results)} results, successfully processed {processed_count} pages with content. {''.join(enhanced_results)} ๐Ÿ’ก **Summary:** Retrieved and processed content from {processed_count} web pages to provide comprehensive information about your search query. """ # Ensure the response isn't too long if len(response) > 8000: response = response[:8000] + "\n[Response truncated to prevent memory issues]" return response def _run(self, query: str) -> str: """Required by BaseTool interface.""" return self.run(query) # --- Agent State Definition --- class AgentState(TypedDict): messages: Annotated[List[AnyMessage], lambda x, y: x + y] done: bool = False # Default value of False question: str task_id: str input_file: Optional[bytes] file_type: Optional[str] context: List[Document] # Using LangChain's Document class file_path: Optional[str] youtube_url: Optional[str] answer: Optional[str] frame_answers: Optional[list] def fetch_page_with_tables(page_title): """ Fetches Wikipedia page content and extracts all tables as readable text. Returns a tuple: (main_text, [table_texts]) """ # Fetch the page object page = wikipedia.page(page_title) main_text = page.content # Get the HTML for table extraction html = page.html() soup = BeautifulSoup(html, 'html.parser') tables = soup.find_all('table') table_texts = [] for table in tables: rows = table.find_all('tr') table_lines = [] for row in rows: cells = row.find_all(['th', 'td']) cell_texts = [cell.get_text(strip=True) for cell in cells] if cell_texts: # Format as Markdown table row table_lines.append(" | ".join(cell_texts)) if table_lines: table_text = "\n".join(table_lines) table_texts.append(table_text) return main_text, table_texts class WikipediaSearchToolWithFAISS(BaseTool): name: str = "wikipedia_semantic_search_all_candidates_strong_entity_priority_list_retrieval" description: str = ( "Fetches content from multiple Wikipedia pages based on intelligent NLP query processing " "of various search candidates, with strong prioritization of query entities. It then performs " "entity-focused semantic search across all fetched content to find the most relevant information, " "with improved retrieval for lists like discographies. Uses spaCy for named entity " "recognition and query enhancement. Input should be a search query or topic. " "Note: Uses the current live version of Wikipedia." ) embedding_model_name: str = "all-MiniLM-L6-v2" chunk_size: int = 4000 chunk_overlap: int = 250 # Maintained moderate overlap top_k_results: int = 3 spacy_model: str = "en_core_web_sm" # Increased multiplier to fetch more candidates per semantic query variant semantic_search_candidate_multiplier: int = 1 # Was 2, increased to 3, consider 4 if still problematic def __init__(self, **kwargs): super().__init__(**kwargs) try: self._nlp = spacy.load(self.spacy_model) print(f"Loaded spaCy model: {self.spacy_model}") self._embedding_model = HuggingFaceEmbeddings(model_name=self.embedding_model_name) # Refined separators for better handling of Wikipedia lists and sections self._text_splitter = RecursiveCharacterTextSplitter( chunk_size=self.chunk_size, chunk_overlap=self.chunk_overlap, separators=[ "\n\n== ", "\n\n=== ", "\n\n==== ", # Section headers (keep with following content) "\n\n\n", "\n\n", # Multiple newlines (paragraph breaks) "\n* ", "\n- ", "\n# ", # List items "\n", ". ", "! ", "? ", # Sentence breaks after newline, common punctuation " ", "" # Word and character level ] ) except OSError as e: print(f"Error loading spaCy model '{self.spacy_model}': {e}") print("Try running: python -m spacy download en_core_web_sm") self._nlp = None self._embedding_model = None self._text_splitter = None except Exception as e: print(f"Error initializing WikipediaSearchToolWithFAISS components: {e}") self._nlp = None self._embedding_model = None self._text_splitter = None def _extract_entities_and_keywords(self, query: str) -> Tuple[List[str], List[str], str]: if not self._nlp: return [], [], query doc = self._nlp(query) main_entities = [ent.text for ent in doc.ents if ent.label_ in ["PERSON", "ORG", "GPE", "EVENT", "WORK_OF_ART"]] keywords = [token.lemma_.lower() for token in doc if token.pos_ in ["NOUN", "PROPN", "ADJ"] and not token.is_stop and not token.is_punct and len(token.text) > 2] main_entities = list(dict.fromkeys(main_entities)) keywords = list(dict.fromkeys(keywords)) processed_tokens = [token.lemma_ for token in doc if not token.is_stop and not token.is_punct and token.text.strip()] processed_query = " ".join(processed_tokens) return main_entities, keywords, processed_query def _generate_search_candidates(self, query: str, main_entities: List[str], keywords: List[str], processed_query: str) -> List[str]: candidates_set = set() entity_prefix = main_entities[0] if main_entities else None for me in main_entities: candidates_set.add(me) candidates_set.add(query) if processed_query and processed_query != query: candidates_set.add(processed_query) if entity_prefix and keywords: first_entity_lower = entity_prefix.lower() for kw in keywords[:3]: if kw not in first_entity_lower and len(kw) > 2: candidates_set.add(f"{entity_prefix} {kw}") keyword_combo_short = " ".join(k for k in keywords[:2] if k not in first_entity_lower and len(k)>2) if keyword_combo_short: candidates_set.add(f"{entity_prefix} {keyword_combo_short}") if len(main_entities) > 1: candidates_set.add(" ".join(main_entities[:2])) if keywords: keyword_combo = " ".join(keywords[:2]) if entity_prefix: candidate_to_add = f"{entity_prefix} {keyword_combo}" if not any(c.lower() == candidate_to_add.lower() for c in candidates_set): candidates_set.add(candidate_to_add) elif not main_entities: candidates_set.add(keyword_combo) ordered_candidates = [] for me in main_entities: if me not in ordered_candidates: ordered_candidates.append(me) for c in list(candidates_set): if c and c.strip() and c not in ordered_candidates: ordered_candidates.append(c) print(f"Generated {len(ordered_candidates)} search candidates for Wikipedia page lookup (entity-prioritized): {ordered_candidates}") return ordered_candidates def _smart_wikipedia_search(self, query_text: str, main_entities_from_query: List[str], keywords_from_query: List[str], processed_query_text: str) -> List[Tuple[str, str]]: candidates = self._generate_search_candidates(query_text, main_entities_from_query, keywords_from_query, processed_query_text) found_pages_data: List[Tuple[str, str]] = [] processed_page_titles: Set[str] = set() for i, candidate_query in enumerate(candidates): print(f"\nProcessing candidate {i+1}/{len(candidates)} for page: '{candidate_query}'") page_object = None final_page_title = None is_candidate_entity_focused = any(me.lower() in candidate_query.lower() for me in main_entities_from_query) if main_entities_from_query else False try: try: page_to_load = candidate_query suggest_mode = True # Default to auto_suggest=True if is_candidate_entity_focused and main_entities_from_query: try: # Attempt precise match first for entity-focused candidates temp_page = wikipedia.page(page_to_load, auto_suggest=False, redirect=True) suggest_mode = False # Flag that precise match worked except (wikipedia.exceptions.PageError, wikipedia.exceptions.DisambiguationError): print(f" - auto_suggest=False failed for entity-focused '{page_to_load}', trying with auto_suggest=True.") # Fallthrough to auto_suggest=True below if this fails if suggest_mode: # If not attempted or failed with auto_suggest=False temp_page = wikipedia.page(page_to_load, auto_suggest=True, redirect=True) final_page_title = temp_page.title if is_candidate_entity_focused and main_entities_from_query: title_matches_main_entity = any(me.lower() in final_page_title.lower() for me in main_entities_from_query) if not title_matches_main_entity: print(f" ! Page title '{final_page_title}' (from entity-focused candidate '{candidate_query}') " f"does not strongly match main query entities: {main_entities_from_query}. Skipping.") continue if final_page_title in processed_page_titles: print(f" ~ Already processed '{final_page_title}'") continue page_object = temp_page print(f" โœ“ Direct hit/suggestion for '{candidate_query}' -> '{final_page_title}'") except wikipedia.exceptions.PageError: if i < max(2, len(candidates) // 3) : # Try Wikipedia search for a smaller, more promising subset of candidates print(f" - Direct access failed for '{candidate_query}'. Trying Wikipedia search...") search_results = wikipedia.search(candidate_query, results=1) if not search_results: print(f" - No Wikipedia search results for '{candidate_query}'.") continue search_result_title = search_results[0] try: temp_page = wikipedia.page(search_result_title, auto_suggest=False, redirect=True) # Search results are usually canonical final_page_title = temp_page.title if is_candidate_entity_focused and main_entities_from_query: # Still check against original intent title_matches_main_entity = any(me.lower() in final_page_title.lower() for me in main_entities_from_query) if not title_matches_main_entity: print(f" ! Page title '{final_page_title}' (from search for '{candidate_query}' -> '{search_result_title}') " f"does not strongly match main query entities: {main_entities_from_query}. Skipping.") continue if final_page_title in processed_page_titles: print(f" ~ Already processed '{final_page_title}'") continue page_object = temp_page print(f" โœ“ Found via search '{candidate_query}' -> '{search_result_title}' -> '{final_page_title}'") except (wikipedia.exceptions.PageError, wikipedia.exceptions.DisambiguationError) as e_sr: print(f" ! Error/Disambiguation for search result '{search_result_title}': {e_sr}") else: print(f" - Direct access failed for '{candidate_query}'. Skipping further search for this lower priority candidate.") except wikipedia.exceptions.DisambiguationError as de: print(f" ! Disambiguation for '{candidate_query}'. Options: {de.options[:1]}") if de.options: option_title = de.options[0] try: temp_page = wikipedia.page(option_title, auto_suggest=False, redirect=True) final_page_title = temp_page.title if is_candidate_entity_focused and main_entities_from_query: # Check against original intent title_matches_main_entity = any(me.lower() in final_page_title.lower() for me in main_entities_from_query) if not title_matches_main_entity: print(f" ! Page title '{final_page_title}' (from disamb. of '{candidate_query}' -> '{option_title}') " f"does not strongly match main query entities: {main_entities_from_query}. Skipping.") continue if final_page_title in processed_page_titles: print(f" ~ Already processed '{final_page_title}'") continue page_object = temp_page print(f" โœ“ Resolved disambiguation '{candidate_query}' -> '{option_title}' -> '{final_page_title}'") except Exception as e_dis_opt: print(f" ! Could not load disambiguation option '{option_title}': {e_dis_opt}") if page_object and final_page_title and (final_page_title not in processed_page_titles): # Extract main text main_text = page_object.content # Extract tables using BeautifulSoup try: html = page_object.html() soup = BeautifulSoup(html, 'html.parser') tables = soup.find_all('table') table_texts = [] for table in tables: rows = table.find_all('tr') table_lines = [] for row in rows: cells = row.find_all(['th', 'td']) cell_texts = [cell.get_text(strip=True) for cell in cells] if cell_texts: table_lines.append(" | ".join(cell_texts)) if table_lines: table_text = "\n".join(table_lines) table_texts.append(table_text) except Exception as e: print(f" !! Error extracting tables for '{final_page_title}': {e}") table_texts = [] # Combine main text and all table texts as separate chunks all_text_chunks = [main_text] + table_texts for chunk in all_text_chunks: found_pages_data.append((chunk, final_page_title)) processed_page_titles.add(final_page_title) print(f" -> Added page '{final_page_title}'. Main text length: {len(main_text)} | Tables extracted: {len(table_texts)}") except Exception as e: print(f" !! Unexpected error processing candidate '{candidate_query}': {e}") if not found_pages_data: print(f"\nCould not find any new, unique, entity-validated Wikipedia pages for query '{query_text}'.") else: print(f"\nFound {len(found_pages_data)} unique, validated page(s) for processing.") return found_pages_data def _enhance_semantic_search(self, query: str, vector_store, main_entities: List[str], keywords: List[str], processed_query: str) -> List[Document]: core_query_parts = set() core_query_parts.add(query) if processed_query != query: core_query_parts.add(processed_query) if keywords: core_query_parts.add(" ".join(keywords[:2])) section_phrases_templates = [] lower_query_terms = set(query.lower().split()) | set(k.lower() for k in keywords) section_keywords_map = { "discography": ["discography", "list of studio albums", "studio album titles and years", "albums by year", "album release dates", "official albums", "complete album list", "albums published"], "biography": ["biography", "life story", "career details", "background history"], "filmography": ["filmography", "list of films", "movie appearances", "acting roles"], } for section_term_key, specific_phrases_list in section_keywords_map.items(): # Check if the key (e.g., "discography") or any of its specific phrases (e.g. "list of studio albums") # are mentioned or implied by the query terms. if section_term_key in lower_query_terms or any(phrase_part in lower_query_terms for phrase_part in section_term_key.split()): section_phrases_templates.extend(specific_phrases_list) # Also check if phrases themselves are in query terms (e.g. query "list of albums by X") for phrase in specific_phrases_list: if phrase in query.lower(): # Check against original query for direct phrase matches section_phrases_templates.extend(specific_phrases_list) # Add all related if one specific is hit break section_phrases_templates = list(dict.fromkeys(section_phrases_templates)) # Deduplicate final_search_queries = set() if main_entities: entity_prefix = main_entities[0] final_search_queries.add(entity_prefix) for part in core_query_parts: final_search_queries.add(f"{entity_prefix} {part}" if entity_prefix.lower() not in part.lower() else part) for phrase_template in section_phrases_templates: final_search_queries.add(f"{entity_prefix} {phrase_template}") if "list of" in phrase_template or "history of" in phrase_template : final_search_queries.add(f"{phrase_template} of {entity_prefix}") else: final_search_queries.update(core_query_parts) final_search_queries.update(section_phrases_templates) deduplicated_queries = list(dict.fromkeys(sq for sq in final_search_queries if sq and sq.strip())) print(f"Generated {len(deduplicated_queries)} semantic search query variants (list-retrieval focused): {deduplicated_queries}") all_results_docs: List[Document] = [] seen_content_hashes: Set[int] = set() k_to_fetch = self.top_k_results * self.semantic_search_candidate_multiplier for search_query_variant in deduplicated_queries: try: results = vector_store.similarity_search_with_score(search_query_variant, k=k_to_fetch) print(f" Semantic search variant '{search_query_variant}' (k={k_to_fetch}) -> {len(results)} raw chunk(s) with scores.") for doc, score in results: # Assuming similarity_search_with_score returns (doc, score) content_hash = hash(doc.page_content[:250]) # Slightly more for hash uniqueness if content_hash not in seen_content_hashes: seen_content_hashes.add(content_hash) doc.metadata['retrieved_by_variant'] = search_query_variant doc.metadata['retrieval_score'] = float(score) # Store score all_results_docs.append(doc) except Exception as e: print(f" Error in semantic search for variant '{search_query_variant}': {e}") # Sort all collected unique results by score (FAISS L2 distance is lower is better) all_results_docs.sort(key=lambda x: x.metadata.get('retrieval_score', float('inf'))) print(f"Collected and re-sorted {len(all_results_docs)} unique chunks from all semantic query variants.") return all_results_docs[:self.top_k_results] def _run(self, query: str) -> str: if not self._nlp or not self._embedding_model or not self._text_splitter: print("ERROR: WikipediaSearchToolWithFAISS components not initialized properly.") return "Error: Wikipedia tool components not initialized properly. Please check server logs." try: print(f"\n--- Running {self.name} for query: '{query}' ---") main_entities, keywords, processed_query = self._extract_entities_and_keywords(query) print(f"Initial NLP Analysis - Main Entities: {main_entities}, Keywords: {keywords}, Processed Query: '{processed_query}'") fetched_pages_data = self._smart_wikipedia_search(query, main_entities, keywords, processed_query) if not fetched_pages_data: return (f"Could not find any relevant, entity-validated Wikipedia pages for the query '{query}'. " f"Main entities sought: {main_entities}") all_page_titles = [title for _, title in fetched_pages_data] print(f"\nSuccessfully fetched content for {len(fetched_pages_data)} Wikipedia page(s): {', '.join(all_page_titles)}") all_documents: List[Document] = [] for page_content, page_title in fetched_pages_data: chunks = self._text_splitter.split_text(page_content) if not chunks: print(f"Warning: Could not split content from Wikipedia page '{page_title}' into chunks.") continue for i, chunk_text in enumerate(chunks): all_documents.append(Document(page_content=chunk_text, metadata={ "source_page_title": page_title, "original_query": query, "chunk_index": i # Add chunk index for potential debugging or ordering })) print(f"Split content from '{page_title}' into {len(chunks)} chunks.") if not all_documents: return (f"Could not process content into searchable chunks from the fetched Wikipedia pages " f"({', '.join(all_page_titles)}) for query '{query}'.") print(f"\nTotal document chunks from all pages: {len(all_documents)}") print("Creating FAISS index from content of all fetched pages...") try: vector_store = FAISS.from_documents(all_documents, self._embedding_model) print("FAISS index created successfully.") except Exception as e: return f"Error creating FAISS vector store: {e}" print(f"\nPerforming enhanced semantic search across all collected content...") try: relevant_docs = self._enhance_semantic_search(query, vector_store, main_entities, keywords, processed_query) except Exception as e: return f"Error during semantic search: {e}" if not relevant_docs: return (f"No relevant information found within Wikipedia page(s) '{', '.join(list(dict.fromkeys(all_page_titles)))}' " f"for your query '{query}' using entity-focused semantic search with list retrieval.") unique_sources_in_results = list(dict.fromkeys([doc.metadata.get('source_page_title', 'Unknown Source') for doc in relevant_docs])) result_header = (f"Found {len(relevant_docs)} relevant piece(s) of information from Wikipedia page(s) " f"'{', '.join(unique_sources_in_results)}' for your query '{query}':\n") nlp_summary = (f"[Original Query NLP: Main Entities: {', '.join(main_entities) if main_entities else 'None'}, " f"Keywords: {', '.join(keywords[:5]) if keywords else 'None'}]\n\n") result_details = [] for i, doc in enumerate(relevant_docs): source_info = doc.metadata.get('source_page_title', 'Unknown Source') variant_info = doc.metadata.get('retrieved_by_variant', 'N/A') score_info = doc.metadata.get('retrieval_score', 'N/A') detail = (f"Result {i+1} (source: '{source_info}', score: {score_info:.4f})\n" f"(Retrieved by: '{variant_info}')\n{doc.page_content}") result_details.append(detail) final_result = result_header + nlp_summary + "\n\n---\n\n".join(result_details) print(f"\nReturning {len(relevant_docs)} relevant chunks from {len(set(all_page_titles))} source page(s).") return final_result.strip() except Exception as e: import traceback print(f"Unexpected error in {self.name}: {traceback.format_exc()}") return f"An unexpected error occurred: {str(e)}" # Example of creating the tool instance: # wikipedia_tool_faiss = WikipediaSearchToolWithFAISS() # To use this new tool in your agent, you would replace the old # `wikipedia_tool` instance with `wikipedia_tool_faiss` in your `tools` list. # For example: # tools = [wikipedia_tool_faiss, search_tool] # Create tool instances #wikipedia_tool = WikipediaSearchTool() # --- Define Call LLM function --- # 3. Improved LLM call with memory management def call_llm_with_memory_management(state: AgentState, llm_model) -> AgentState: # Added llm_model parameter """Call LLM with memory management, context truncation, and process response.""" print("Running call_llm with memory management...") # It's crucial to work with a copy of messages for modification within this step # The final state["messages"] should reflect the full history + new response. original_messages = list(state["messages"]) messages_for_llm_processing = list(state["messages"]) # Use this for truncation logic #ipdb.set_trace() # --- Context Truncation Logic --- system_message_content = None # Check if the first message is a system message and preserve it if messages_for_llm_processing and isinstance(messages_for_llm_processing[0], SystemMessage): system_message_content = messages_for_llm_processing[0] # Process only non-system messages for truncation count regular_messages = messages_for_llm_processing[1:] else: regular_messages = messages_for_llm_processing # Truncate context if too many messages (e.g., keep system + X most recent) # Max 10 messages total (e.g. 1 system + 9 others) max_regular_messages = 9 if len(regular_messages) > max_regular_messages: print(f"๐Ÿ”„ Truncating message count: {len(messages_for_llm_processing)} -> ~{max_regular_messages + (1 if system_message_content else 0)} messages") regular_messages = regular_messages[- (max_regular_messages -1):] # Keep X-1 most recent, to add user input later # Reconstruct messages for LLM call messages_for_llm = [] if system_message_content: messages_for_llm.append(system_message_content) messages_for_llm.extend(regular_messages) # Further truncate based on character count (rough proxy for tokens) total_chars = sum(len(str(msg.content)) for msg in messages_for_llm) # Example character limit, adjust based on your model (e.g. 8k chars for ~4k tokens) char_limit = 8000 if total_chars > char_limit: print(f"๐Ÿ“ Context too long ({total_chars} chars > {char_limit}), further truncation needed") # More aggressive truncation of regular messages chars_to_remove = total_chars - char_limit temp_regular_messages = list(regular_messages) # copy while sum(len(str(m.content)) for m in temp_regular_messages) > char_limit and temp_regular_messages: if system_message_content and sum(len(str(m.content)) for m in temp_regular_messages) + len(str(system_message_content.content)) <= char_limit : break # if removing one more makes it too small with system message print(f"Removing message: {temp_regular_messages[0].type} - {temp_regular_messages[0].content[:50]}...") temp_regular_messages.pop(0) regular_messages = temp_regular_messages messages_for_llm = [] # Rebuild if system_message_content: messages_for_llm.append(system_message_content) messages_for_llm.extend(regular_messages) print(f"Context truncated to {sum(len(str(m.content)) for m in messages_for_llm)} chars.") new_state = state.copy() # Start with a copy of the input state try: if torch.cuda.is_available(): torch.cuda.empty_cache() print(f"๐Ÿงน Pre-LLM CUDA cache cleared. Memory: {torch.cuda.memory_allocated()/1024**2:.1f}MB") print(f"Invoking LLM with {len(messages_for_llm)} messages.") # This is where you call your actual LLM formatted_input = "\n".join([f"[{msg.type.upper()}] {msg.content}" for msg in messages_for_llm]) print(f"\n\nFormatted input for LLM:\n\n{formatted_input}") llm_response_object = llm_model.invoke(formatted_input) #ipdb.set_trace() # The response_object is typically a BaseMessage subclass (e.g., AIMessage) # or a string for simpler LLMs. Adapt as needed. if isinstance(llm_response_object, BaseMessage): ai_message_response = llm_response_object # It's already a message object if not ai_message_response.content: # Ensure content is not empty ai_message_response.content = "" elif hasattr(llm_response_object, 'content'): # Some models might return a custom object with a content attribute ai_message_response = AIMessage(content=str(llm_response_object.content) if llm_response_object.content is not None else "") else: # Assuming it's a string for basic LLMs ai_message_response = AIMessage(content=str(llm_response_object) if llm_response_object is not None else "") print(f"LLM Response: {ai_message_response.content[:300]}...") # Print a snippet # Append the LLM's response to the original full list of messages final_messages = original_messages + [ai_message_response] new_state["messages"] = final_messages new_state.pop("done", None) # LLM responded, so not 'done' by default except Exception as e: print(f"LLM call failed: {e}") error_message_content = f"LLM call failed with error: {str(e)}. Input consisted of {len(messages_for_llm)} messages." if "out of memory" in str(e).lower(): print("๐Ÿšจ CUDA OOM detected during LLM call! Implementing emergency cleanup...") error_message_content = f"LLM failed due to Out of Memory: {str(e)}." try: if torch.cuda.is_available(): torch.cuda.empty_cache() gc.collect() except Exception as cleanup_e: print(f"Emergency OOM cleanup failed: {cleanup_e}") # Append an error message to the original message history error_ai_message = AIMessage(content=error_message_content) final_messages_on_error = original_messages + [error_ai_message] new_state["messages"] = final_messages_on_error new_state["done"] = True # Mark as done to prevent loops on LLM failure finally: try: if torch.cuda.is_available(): torch.cuda.empty_cache() print(f"๐Ÿงน Post-LLM CUDA cache cleared. Memory: {torch.cuda.memory_allocated()/1024**2:.1f}MB") except Exception: pass # Avoid error in cleanup hiding the main error return new_state import re import uuid def parse_react_output(state: AgentState) -> AgentState: print("Running parse_react_output (Action prioritized)...") messages = state["messages"] last_message = messages[-1] new_state = state.copy() # Only process AI messages (not system/user) if not isinstance(last_message, AIMessage): return new_state content = last_message.content # Remove any system prompt/instructions (if present in content) # Assume that the actual AI output is after the last occurrence of "You are a general AI assistant" or similar system prompt marker sys_prompt_pattern = r"(You are a general AI assistant.*?)(?=\n\n|$)" content_wo_sys_prompt = re.sub(sys_prompt_pattern, '', content, flags=re.DOTALL | re.IGNORECASE).strip() # Find the last occurrence of FINAL ANSWER or Action Input final_answer_match = list(re.finditer(r"FINAL ANSWER:", content_wo_sys_prompt, re.IGNORECASE)) action_input_match = list(re.finditer(r"Action Input:", content_wo_sys_prompt, re.IGNORECASE)) # Helper: get the last match position and which it was last_marker = None last_pos = -1 if final_answer_match: last_fa = final_answer_match[-1] last_marker = 'FINAL ANSWER' last_pos = last_fa.start() if action_input_match: last_ai = action_input_match[-1] if last_ai.start() > last_pos: last_marker = 'Action Input' last_pos = last_ai.start() # If neither marker found, mark as done if not last_marker: print("No FINAL ANSWER or Action Input found in last AI output.") new_state["done"] = True return new_state # Get the substring from the last marker to the end last_section = content_wo_sys_prompt[last_pos:].strip() # 2. If FINAL ANSWER is in the last part, end the process if last_marker == 'FINAL ANSWER': # Extract the answer after FINAL ANSWER: answer = re.search(r"FINAL ANSWER:\s*(.+)", last_section, re.IGNORECASE) final_answer_text = answer.group(1).strip() if answer else "" updated_ai_message = AIMessage(content=f"FINAL ANSWER: {final_answer_text}", tool_calls=[]) new_state["messages"] = messages[:-1] + [updated_ai_message] new_state["done"] = True print(f"FINAL ANSWER found at end: '{final_answer_text}'") return new_state # 3. If Action Input is in the last part, launch tool if last_marker == 'Action Input': # Try to extract the Action and Action Input for the last occurrence action_match = list(re.finditer(r"Action:\s*([^\n]+)", last_section)) action_input_match = list(re.finditer(r"Action Input:\s*([^\n]+)", last_section)) if action_match and action_input_match: tool_name = action_match[-1].group(1).strip() tool_input_raw = action_input_match[-1].group(1).strip() print(f"ReAct: Found Action: {tool_name}, Input: '{tool_input_raw}'") # Format tool_args as in your original code (simplified here) tool_args = {"query": tool_input_raw} tool_call_id = str(uuid.uuid4()) parsed_tool_calls = [{"name": tool_name, "args": tool_args, "id": tool_call_id}] updated_ai_message = AIMessage(content=content, tool_calls=parsed_tool_calls) new_state["messages"] = messages[:-1] + [updated_ai_message] new_state.pop("done", None) print(f"AIMessage updated with tool_calls: {parsed_tool_calls}") return new_state else: print("Action Input found at end, but could not parse Action or Action Input.") new_state["done"] = True return new_state # Fallback: mark as done print("No actionable marker found at end of last AI output. Marking as done.") new_state["done"] = True return new_state def download_youtube_video(url, output_dir='/tmp/video/', output_filename='downloaded_video.mp4'): """Download a YouTube video using yt-dlp""" # Ensure the output directory exists os.makedirs(output_dir, exist_ok=True) # Delete all files in the output directory files = glob.glob(os.path.join(output_dir, '*')) for f in files: try: os.remove(f) except Exception as e: print(f"Error deleting {f}: {str(e)}") # Set output path for yt-dlp output_path = os.path.join(output_dir, output_filename) try: ydl_opts = { 'format': 'bestvideo[ext=mp4]+bestaudio[ext=m4a]/best[ext=mp4]/best', 'outtmpl': output_path, 'quiet': True, 'merge_output_format': 'mp4', # Ensures merged output is mp4 'postprocessors': [{ 'key': 'FFmpegVideoConvertor', 'preferedformat': 'mp4', # Recode if needed }] } with yt_dlp.YoutubeDL(ydl_opts) as ydl: ydl.download([url]) return output_path except Exception as e: print(f"Error downloading YouTube video: {str(e)}") return None def extract_frames(video_path, output_dir, frame_interval_seconds=10): """Extract frames from a video file at specified intervals""" # Clean output directory before extracting new frames if os.path.exists(output_dir): for filename in os.listdir(output_dir): file_path = os.path.join(output_dir, filename) try: if os.path.isfile(file_path) or os.path.islink(file_path): os.unlink(file_path) elif os.path.isdir(file_path): shutil.rmtree(file_path) except Exception as e: print(f'Failed to delete {file_path}. Reason: {e}') else: os.makedirs(output_dir, exist_ok=True) try: cap = cv2.VideoCapture(video_path) if not cap.isOpened(): print("Error: Could not open video.") return False fps = cap.get(cv2.CAP_PROP_FPS) frame_interval = int(fps * frame_interval_seconds) count = 0 saved = 0 while True: ret, frame = cap.read() if not ret: break if count % frame_interval == 0: frame_filename = os.path.join(output_dir, f"frame_{count:06d}.jpg") cv2.imwrite(frame_filename, frame) saved += 1 count += 1 cap.release() print(f"Extracted {saved} frames.") return saved > 0 except Exception as e: print(f"Exception during frame extraction: {e}") return False def answer_question_on_frame(image_path, question): """Answer a question about a single video frame using BLIP""" try: vqa_model_name = "Salesforce/blip-vqa-base" # Not used in the provided graph logic directly processor_vqa = BlipProcessor.from_pretrained(vqa_model_name) # Not used model_vqa = BlipForQuestionAnswering.from_pretrained(vqa_model_name).to('cpu') # Not used device = "cpu" image = Image.open(image_path).convert('RGB') inputs = processor_vqa(image, question, return_tensors="pt").to(device) out = model_vqa.generate(**inputs) answer = processor_vqa.decode(out[0], skip_special_tokens=True) return answer except Exception as e: print(f"Error processing frame {image_path}: {str(e)}") return "Error processing this frame" def answer_video_question(frames_dir, question): """Answer a question about a video by analyzing extracted frames""" valid_exts = ('.jpg', '.jpeg', '.png') # Check if directory exists if not os.path.exists(frames_dir): return { "most_common_answer": "No frames found to analyze.", "all_answers": [], "answer_counts": Counter() } frame_files = [os.path.join(frames_dir, f) for f in os.listdir(frames_dir) if f.lower().endswith(valid_exts)] # Sort frames properly by number def get_frame_number(filename): match = re.search(r'(\d+)', os.path.basename(filename)) return int(match.group(1)) if match else 0 frame_files = sorted(frame_files, key=get_frame_number) if not frame_files: return { "most_common_answer": "No valid image frames found.", "all_answers": [], "answer_counts": Counter() } answers = [] for frame_path in frame_files: try: ans = answer_question_on_frame(frame_path, question) answers.append(ans) print(f"Processed frame: {os.path.basename(frame_path)}, Answer: {ans}") except Exception as e: print(f"Error processing frame {frame_path}: {str(e)}") if not answers: return { "most_common_answer": "Could not analyze any frames successfully.", "all_answers": [], "answer_counts": Counter() } counted = Counter(answers) most_common_answer, freq = counted.most_common(1)[0] return { "most_common_answer": most_common_answer, "all_answers": answers, "answer_counts": counted } class YoutubeScreenshotQA(BaseTool): name: str = "youtube_screenshot_qa" description: str = ( "Downloads a YouTube video, extracts screenshots at intervals, " "and answers a question about the video based on the screenshots. " "Input should be a dict with keys: 'youtube_url' and 'question'." "Example input: {'youtube_url': 'https://www.youtube.com/watch?v=L1vXCYZAYYM', 'question': 'What is the highest number of bird species on camera simultaneously?'}" ) frame_interval_seconds: int = 10 # Can be parameterized if needed def _run(self, input_data: Dict[str, Any]) -> str: youtube_url = input_data.get("youtube_url") question = input_data.get("question") if not youtube_url or not question: return "Error: Input must include 'youtube_url' and 'question'." # Step 1: Download the video video_dir = '/tmp/video/' video_filename = 'downloaded_video.mp4' print(f"Downloading YouTube video from {youtube_url}...") video_path = download_youtube_video(youtube_url, output_dir=video_dir, output_filename=video_filename) if not video_path or not os.path.exists(video_path): return "Error: Failed to download the YouTube video." # Step 2: Extract frames frames_dir = '/tmp/video_frames/' print(f"Extracting frames from {video_path} every {self.frame_interval_seconds} seconds...") success = extract_frames(video_path, frames_dir, frame_interval_seconds=self.frame_interval_seconds) if not success: return "Error: Failed to extract frames from the video." # Step 3: Analyze frames and answer question print(f"Answering question about the video frames...") answer_result = answer_video_question(frames_dir, question) if not answer_result or not answer_result.get("most_common_answer"): return "Error: Could not analyze video frames to answer the question." # Format the result most_common = answer_result["most_common_answer"] all_answers = answer_result["all_answers"] counts = answer_result["answer_counts"] result = ( f"Most common answer: {most_common}\n" f"All answers: {all_answers}\n" f"Answer counts: {dict(counts)}" ) return result def tools_condition_with_logging(state: AgentState): """ Custom tools condition function that checks if the last message contains tool calls in the Thought/Action/Action Input format and logs the transition decision. Args: state (AgentState): The current state containing messages Returns: str: "tools" if tool calls are present, "__end__" otherwise """ import re # Ensure we have messages in the state if not state.get("messages") or len(state["messages"]) == 0: print("โŒ No messages found in state, ending conversation") return "__end__" # Get the last message last_message = state["messages"][-1] # Get message content content = "" if hasattr(last_message, 'content'): content = str(last_message.content) elif isinstance(last_message, dict) and 'content' in last_message: content = str(last_message['content']) else: print("โŒ No content found in last message, ending conversation") return "__end__" print(f"๐Ÿ” Analyzing message content: {content[:200]}...") # Check for Thought/Action/Action Input format has_tool_calls = False # Pattern to match the format: # Thought: # Action: # Action Input: thought_action_pattern = re.compile( r'Thought:\s*(.*?)\n\s*Action:\s*(.*?)\n\s*Action Input:\s*(.*?)(?:\n|$)', re.DOTALL | re.IGNORECASE ) # Also check for just Action/Action Input without Thought action_only_pattern = re.compile( r'Action:\s*(.*?)\n\s*Action Input:\s*(.*?)(?:\n|$)', re.DOTALL | re.IGNORECASE ) # Look for the complete format first match = thought_action_pattern.search(content) if not match: # Try the action-only format match = action_only_pattern.search(content) if match: thought = "No thought provided" action = match.group(1).strip() action_input = match.group(2).strip() else: action = None action_input = None thought = None else: thought = match.group(1).strip() action = match.group(2).strip() action_input = match.group(3).strip() if match and action: has_tool_calls = True print(f"๐Ÿ”ง Found tool call format:") print(f" Thought: {thought}") print(f" Action: {action}") print(f" Action Input: {action_input}") # Map common tool names to your actual tools tool_mappings = { 'wikipedia_semantic_search': 'wikipedia_tool', 'wikipedia': 'wikipedia_tool', 'search': 'search_tool', 'duckduckgo_search': 'search_tool', 'web_search': 'search_tool', 'youtube_screenshot_qa_tool': 'youtube_tool', 'youtube': 'youtube_tool', } # Normalize the action name normalized_action = action.lower().strip() # Store the parsed tool call information in the state for the tools node to use if 'parsed_tool_calls' not in state: state['parsed_tool_calls'] = [] tool_call_info = { 'thought': thought, 'action': action, 'action_input': action_input, 'normalized_action': normalized_action, 'tool_mapping': tool_mappings.get(normalized_action, normalized_action) } state['parsed_tool_calls'].append(tool_call_info) print(f"๐Ÿš€ Added tool call to state: {tool_call_info}") # Don't execute tools here - let call_tool handle execution # Just store the parsed information for call_tool to use # Also check for standalone tool mentions (fallback) if not has_tool_calls: # Check for tool names mentioned in content tool_keywords = [ 'wikipedia_semantic_search', 'wikipedia', 'search', 'duckduckgo', 'youtube_screenshot_qa_tool', 'youtube', 'web search' ] content_lower = content.lower() for keyword in tool_keywords: if keyword in content_lower: print(f"๐Ÿ”ง Found tool keyword '{keyword}' in content (fallback detection)") has_tool_calls = True break if has_tool_calls: print("๐Ÿ”ง Tool calls detected, transitioning to tools...") return "tools" else: print("โœ… No tool calls found, ending conversation") return "__end__" # 2. Improved call_tool with memory management def call_tool_with_memory_management(state: AgentState) -> AgentState: """Process tool calls with memory management.""" print("Running call_tool with memory management...") # Clear CUDA cache before processing try: import torch if torch.cuda.is_available(): torch.cuda.empty_cache() print(f"๐Ÿงน Cleared CUDA cache. Memory: {torch.cuda.memory_allocated()/1024**2:.1f}MB") except: pass # Check if we have parsed tool calls from the condition function if 'parsed_tool_calls' in state and state['parsed_tool_calls']: return execute_parsed_tool_calls(state) # Fallback to original OpenAI-style tool calls handling messages = state["messages"] last_message = messages[-1] if not hasattr(last_message, "tool_calls") or not last_message.tool_calls: print("No tool calls found in last message") return state # Copy the messages to avoid mutating the original list new_messages = list(messages) print(f"Processing {len(last_message.tool_calls)} tool calls") for i, tool_call in enumerate(last_message.tool_calls): print(f"Processing tool call {i+1}: {tool_call['name'] if isinstance(tool_call, dict) else tool_call.name}") # Handle both dict and object-style tool calls if isinstance(tool_call, dict): tool_name = tool_call.get("name", "") args = tool_call.get("args", {}) tool_call_id = tool_call.get("id", str(uuid.uuid4())) else: tool_name = getattr(tool_call, "name", "") args = getattr(tool_call, "args", {}) tool_call_id = getattr(tool_call, "id", str(uuid.uuid4())) # Find the matching tool selected_tool = None for tool in tools: if tool.name.lower() == tool_name.lower(): selected_tool = tool break if not selected_tool: tool_result = f"Error: Tool '{tool_name}' not found. Available tools: {', '.join(t.name for t in tools)}" print(f"Tool not found: {tool_name}") else: try: # Extract query if isinstance(args, dict) and "query" in args: query = args["query"] else: query = str(args) if args else "" print(f"Executing {tool_name} with query: {query[:100]}...") tool_result = selected_tool.run(query) # Aggressive truncation to prevent memory issues max_length = 3000 if "wikipedia" in tool_name.lower() else 2000 if len(tool_result) > max_length: tool_result = tool_result[:max_length] + f"... [Result truncated from {len(tool_result)} to {max_length} chars to prevent memory issues]" print(f"๐Ÿ“„ Truncated result to {max_length} characters") print(f"Tool result length: {len(tool_result)} characters") except Exception as e: tool_result = f"Error executing tool '{tool_name}': {str(e)}" print(f"Tool execution error: {e}") # Create tool message tool_message = ToolMessage( content=tool_result, name=tool_name, tool_call_id=tool_call_id ) new_messages.append(tool_message) print(f"Added tool message for {tool_name}") # Update the state new_state = state.copy() new_state["messages"] = new_messages # Clear CUDA cache after processing try: import torch if torch.cuda.is_available(): torch.cuda.empty_cache() except: pass return new_state def execute_parsed_tool_calls(state: AgentState): """ Execute tool calls that were parsed from the Thought/Action/Action Input format. This is called by call_tool when parsed_tool_calls are present in state. Args: state (AgentState): The current state containing parsed tool calls Returns: AgentState: Updated state with tool results """ # Use the same tools list that's available globally # Map tool names to the actual tool instances tool_name_mappings = { 'wikipedia_semantic_search': 'wikipedia_tool', 'wikipedia': 'wikipedia_tool', 'search': 'enhanced_search', # Updated mapping 'duckduckgo_search': 'enhanced_search', # Updated mapping 'web_search': 'enhanced_search', # Updated mapping 'enhanced_search': 'enhanced_search', # Direct mapping 'youtube_screenshot_qa_tool': 'youtube_tool', 'youtube': 'youtube_tool', } # Create a lookup by tool names for your existing tools list tools_by_name = {} for tool in tools: tools_by_name[tool.name.lower()] = tool # Also map by class name for flexibility class_name = tool.__class__.__name__.lower() if 'wikipedia' in class_name: tools_by_name['wikipedia_tool'] = tool elif 'search' in class_name or 'duck' in class_name: tools_by_name['search_tool'] = tool elif 'youtube' in class_name: tools_by_name['youtube_tool'] = tool # Copy messages to avoid mutation during iteration new_messages = list(state["messages"]) for tool_call in state['parsed_tool_calls']: action = tool_call['action'] action_input = tool_call['action_input'] thought = tool_call['thought'] normalized_action = tool_call['normalized_action'] print(f"๐Ÿš€ Executing tool: {action} with input: {action_input}") # Find the tool instance tool_instance = None # Try direct name match first if normalized_action in tools_by_name: tool_instance = tools_by_name[normalized_action] # Try mapped name elif normalized_action in tool_name_mappings: mapped_name = tool_name_mappings[normalized_action] if mapped_name in tools_by_name: tool_instance = tools_by_name[mapped_name] if tool_instance: try: result = tool_instance.run(action_input) if len(result) > 6000: result = result[:6000] + "... [Result truncated due to length]" # Create observation message in the format your agent expects from langchain_core.messages import AIMessage observation = f"Observation: {result}" observation_message = AIMessage(content=observation) new_messages.append(observation_message) print(f"โœ… Tool '{action}' executed successfully") except Exception as e: print(f"โŒ Error executing tool '{action}': {e}") from langchain_core.messages import AIMessage error_msg = f"Observation: Error executing '{action}': {str(e)}" error_message = AIMessage(content=error_msg) new_messages.append(error_message) else: print(f"โŒ Tool '{action}' not found in available tools") available_tool_names = list(tools_by_name.keys()) from langchain_core.messages import AIMessage error_msg = f"Observation: Tool '{action}' not found. Available tools: {', '.join(available_tool_names)}" error_message = AIMessage(content=error_msg) new_messages.append(error_message) # Update state with new messages and clear parsed tool calls state["messages"] = new_messages state['parsed_tool_calls'] = [] return state # 1. Add loop detection to your AgentState def should_continue(state: AgentState) -> str: """Determine if the agent should continue or end.""" print("Running should_continue....") messages = state["messages"] #ipdb.set_trace() # Check if we're done if state.get("done", False): return "end" # Prevent infinite loops - limit tool calls tool_call_count = sum(1 for msg in messages if hasattr(msg, 'tool_calls') and msg.tool_calls) if tool_call_count >= 3: # Max 3 tool calls per conversation print(f"โš ๏ธ Stopping: Too many tool calls ({tool_call_count})") return "end" # Check for repeated tool calls with same query recent_tool_calls = [] for msg in messages[-6:]: # Check last 6 messages if hasattr(msg, 'tool_calls') and msg.tool_calls: for tool_call in msg.tool_calls: if isinstance(tool_call, dict): recent_tool_calls.append((tool_call.get('name'), str(tool_call.get('args', {})))) if len(recent_tool_calls) >= 2 and recent_tool_calls[-1] == recent_tool_calls[-2]: print("โš ๏ธ Stopping: Repeated tool call detected") return "end" # Check message count to prevent runaway conversations if len(messages) > 15: print(f"โš ๏ธ Stopping: Too many messages ({len(messages)})") return "end" return "continue" def route_after_parse_react(state: AgentState) -> str: """Determines the next step after parsing LLM output, prioritizing end state.""" if state.get("done", False): # Check if parse_react_output decided we are done return "end_processing" # Original logic: check for tool calls in the last message # Ensure messages list and last message exist before checking tool_calls messages = state.get("messages", []) if messages: last_message = messages[-1] if hasattr(last_message, 'tool_calls') and last_message.tool_calls: return "call_tool" return "call_llm" #wikipedia_tool = WikipediaSearchToolWithFAISS() #search_tool = DuckDuckGoSearchRun() #youtube_screenshot_qa_tool = YoutubeScreenshotQA() # Combine all tools #tools = [wikipedia_tool, search_tool, youtube_screenshot_qa_tool] # Update your tools list to use the global instances # # --- Graph Construction --- # --- Graph Construction --- def create_memory_safe_workflow(): """Create a workflow with memory management and loop prevention.""" # These models are initialized here but might be better managed if they need to be released/reinitialized # like you attempt in run_agent. Consider passing them or managing their lifecycle carefully. hf_pipe = create_llm_pipeline() llm = HuggingFacePipeline(pipeline=hf_pipe) # vqa_model_name = "Salesforce/blip-vqa-base" # Not used in the provided graph logic directly # processor_vqa = BlipProcessor.from_pretrained(vqa_model_name) # Not used # model_vqa = BlipForQuestionAnswering.from_pretrained(vqa_model_name).to('cpu') # Not used workflow = StateGraph(AgentState) # Bind the llm_model to the call_llm_with_memory_management function bound_call_llm = partial(call_llm_with_memory_management, llm_model=llm) # Add nodes with memory-safe versions workflow.add_node("call_llm", bound_call_llm) # Use the bound version here workflow.add_node("parse_react_output", parse_react_output) workflow.add_node("call_tool", call_tool_with_memory_management) # Ensure this doesn't also need llm if it calls back directly # Set entry point workflow.set_entry_point("call_llm") # Add conditional edges workflow.add_conditional_edges( "call_llm", should_continue, { "continue": "parse_react_output", "end": END } ) workflow.add_conditional_edges( "parse_react_output", route_after_parse_react, { "call_tool": "call_tool", "call_llm": "call_llm", "end_processing": END } ) workflow.add_edge("call_tool", "call_llm") return workflow.compile() # --- Run the Agent --- def run_agent(myagent, state: AgentState): """ Initialize agent with proper system message and formatted query. """ #global llm, hf_pipe, model_vqa, processor_vqa global WIKIPEDIA_TOOL, SEARCH_TOOL, YOUTUBE_TOOL, tools #ipdb.set_trace() # At the module level, create instances once WIKIPEDIA_TOOL = WikipediaSearchToolWithFAISS() SEARCH_TOOL = EnhancedDuckDuckGoSearchTool(max_results=3, max_chars_per_page=3000) YOUTUBE_TOOL = YoutubeScreenshotQA() tools = [WIKIPEDIA_TOOL, SEARCH_TOOL, YOUTUBE_TOOL] # Create a fresh system message each time formatted_tools_description = render_text_description(tools) current_date_str = datetime.now().strftime("%Y-%m-%d") system_content = f"""You are a general AI assistant. with access to these tools: {formatted_tools_description} If you need the most current information as of 2025, use enhanced_search If you need to do in-depth research, use wikipedia_semantic_search_all_candidates_strong_entity_priority_list_retrieval If you can answer the question confidently, do so directly. If the question seems like gibberish (not English), try flipping the entire question and re-read the question. If you need more information, use a tool. (Think through the problem step by step) When using a tool, follow this format: Thought: Action: Action Input: Only use the tools listed above for the Action: step. Do not invent new tool names or actions. If you need to reason, do so in the Thought: step. After using a tool, process its output in your Thought: step, not as an Action. Report your thoughts, and finish your answer with the following template: FINAL ANSWER: [YOUR FINAL ANSWER]. YOUR FINAL ANSWER should be a number OR as few words as possible OR a comma separated list of numbers and/or strings. If you are asked for a number, don't use comma to write your number neither use units such as $ or percent sign unless specified otherwise. If you are asked for a string, don't use articles, neither abbreviations (e.g. for cities), and write the digits in plain text unless specified otherwise. If you are asked for a comma separated list, apply the above rules depending of whether the element to be put in the list is a number or a string Do not provide disclaimers. Do not provide supporting details. """ # Get user question from AgentState query = state['question'] # Pattern for YouTube yt_pattern = r"(https?://)?(www\.)?(youtube\.com|youtu\.be)/[^\s]+" has_youtube = re.search(yt_pattern, query) is not None if has_youtube: # Store the extracted YouTube URL in the state url_match = re.search(r"(https?://[^\s]+)", query) if url_match: state['youtube_url'] = url_match.group(0) # Format the user query to guide the model better formatted_query = f"""{query}""" # Initialize agent state with proper message types system_message = SystemMessage(content=system_content) human_message = HumanMessage(content=formatted_query) # Initialize state with properly typed messages and done=False # state = {"messages": [system_message, human_message], "done": False} state['messages'] = [system_message, human_message] state["done"] = False # Use the new method to run the graph result = myagent.invoke(state) # Check if FINAL ANSWER was given (i.e., workflow ended) if result.get("done"): #del llm #del hf_pipe #del model_vqa #del processor_vqa torch.cuda.empty_cache() torch.cuda.ipc_collect() gc.collect() print("Released GPU memory after FINAL ANSWER.") # Re-initialize for the next run #hf_pipe = create_llm_pipeline() #llm = HuggingFacePipeline(pipeline=hf_pipe) #print("Re-initilized llm...") # Extract and return just the messages for cleaner output return result["messages"]