Samuel Thomas committed
Commit 11398e5 · 1 Parent(s): ebdd994

multiple revisions

Files changed (3):
  1. app.py +18 -15
  2. requirements.txt +1 -0
  3. tools.py +1472 -519
app.py CHANGED
@@ -4,7 +4,7 @@ import requests
 import inspect
 import pandas as pd
 import traceback
-from tools import intelligent_agent, get_file_type, write_bytes_to_temp_dir, State
+from tools import create_memory_safe_workflow, get_file_type, write_bytes_to_temp_dir, AgentState, extract_final_answer, run_agent
 
 # (Keep Constants as is)
 # --- Constants ---
@@ -87,24 +87,27 @@ def run_and_submit_all( profile: gr.OAuthProfile | None):
         tb_str = traceback.format_exc()
         print(f"Error creating new states: {tb_str}")
         return f"Error creating new states: {tb_str}", None
 
-    # 3. Setup states for questions and run agent
+    agent = create_memory_safe_workflow()
+
+    # Setup states for questions and run agent
     answers_payload = []
     results_log = []
     for r in range(len(hf_questions)):
-        s = State(question=hf_questions[r]['question'],
-                  input_file=hf_questions[r]['input_file'],
-                  file_type=hf_questions[r]['file_type'],
-                  file_path=hf_questions[r]['file_path'])
+        s = AgentState(question=hf_questions[r]['question'],
+                       input_file=hf_questions[r]['input_file'],
+                       file_type=hf_questions[r]['file_type'],
+                       file_path=hf_questions[r]['file_path'])
         try:
             task_id = hf_questions[r]['task_id']
             question_text = hf_questions[r]['question']
-            submitted_answer = intelligent_agent(s)
+            full_answer = run_agent(agent, s)
+            submitted_answer = extract_final_answer(full_answer[-1].content)
             answers_payload.append({"task_id": task_id, "model_answer": submitted_answer})
             results_log.append({"Task ID": task_id, "Question": question_text, "Submitted Answer": submitted_answer})
-        except:
+        except Exception as e:
             print(f"Error running agent on task {task_id}: {e}")
             results_log.append({"Task ID": task_id, "Question": question_text, "Submitted Answer": f"AGENT ERROR: {e}"})
 
     if not answers_payload:
         print("Agent did not produce any answers to submit.")
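The revised loop builds one AgentState per question, runs it through the compiled workflow, and reduces the raw output to the graded FINAL ANSWER line. A minimal sketch of that round trip (illustrative only; that run_agent returns the accumulated message list is inferred from the full_answer[-1].content access above, not shown elsewhere in this commit):

    # Sketch, not part of the commit: mirrors the submission loop for one question.
    agent = create_memory_safe_workflow()
    s = AgentState(question="What is 2 + 2?", input_file=None,
                   file_type=None, file_path=None)
    messages = run_agent(agent, s)                           # assumed: list of messages
    submitted = extract_final_answer(messages[-1].content)   # e.g. "FINAL ANSWER: 4"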
requirements.txt CHANGED
@@ -19,3 +19,4 @@ accelerate
 en_core_web_sm @ https://huggingface.co/spacy/en_core_web_sm/resolve/main/en_core_web_sm-any-py3-none-any.whl
 transformers==4.40.0
 datasets==2.19.0
+beautifulsoup4
tools.py CHANGED
@@ -1,45 +1,85 @@
-import numpy as np
-import spacy
+# Standard Library
+import os
+import re
 import tempfile
+import string
 import glob
-import yt_dlp
 import shutil
+import gc
+import uuid
+import signal
+from datetime import datetime
+from io import BytesIO
+from contextlib import contextmanager
+from langchain_huggingface import HuggingFacePipeline
+from typing import TypedDict, List, Optional, Dict, Any, Annotated, Literal, Union, Tuple, Set
+import time
+from collections import Counter
+
+# Third-Party Packages
 import cv2
-import librosa
+import requests
 import wikipedia
+import spacy
+import yt_dlp
+import librosa
+from PIL import Image
+from bs4 import BeautifulSoup
+from duckduckgo_search import DDGS
+from sentence_transformers import SentenceTransformer
+from transformers import BlipProcessor, BlipForQuestionAnswering, pipeline
 
-from typing import TypedDict, List, Optional, Dict, Any
+# LangChain Ecosystem
 from langchain.docstore.document import Document
 from langchain.prompts import PromptTemplate
 from langchain_community.document_loaders import WikipediaLoader
-from langgraph.graph import START, END, StateGraph
-from langchain_core.messages import AnyMessage, HumanMessage, AIMessage  # If you are using it
-from langchain_community.retrievers import BM25Retriever  # If you are using it
-from langgraph.prebuilt import ToolNode, tools_condition  # If you are using it
+from langchain_huggingface import HuggingFaceEndpoint
+from langchain_community.retrievers import BM25Retriever
 from langchain.vectorstores import FAISS
 from langchain.embeddings import HuggingFaceEmbeddings
+from langchain.text_splitter import RecursiveCharacterTextSplitter
 from langchain.schema import Document
-from transformers import BlipProcessor, BlipForQuestionAnswering, pipeline
-from io import BytesIO
-from sentence_transformers import SentenceTransformer
-from transformers import RagRetriever, RagTokenizer, RagSequenceForGeneration
-from transformers import AutoTokenizer, AutoModelWithLMHead
+from langchain_community.tools import DuckDuckGoSearchRun
+from langchain_core.messages import AnyMessage, HumanMessage, AIMessage, BaseMessage, SystemMessage, ToolMessage
+from langchain_core.tools import BaseTool, StructuredTool, tool, render_text_description
+from langchain_core.documents import Document
 
-import os
-import re
-from PIL import Image  # This is correctly imported, but was being used incorrectly
-import numpy as np
-from collections import Counter
+# LangGraph
+from langgraph.graph import START, END, StateGraph
+from langgraph.prebuilt import ToolNode, tools_condition
+
+# PyTorch
 import torch
-from transformers import BlipProcessor, BlipForQuestionAnswering, pipeline
-from typing import TypedDict, List, Optional, Dict, Any, Literal, Tuple
-from langgraph.graph import StateGraph, START, END
-from langchain.docstore.document import Document
+from functools import partial
+from transformers import pipeline
+
+# Additional Utilities
+from datetime import datetime
+
+from urllib.parse import urljoin, urlparse
+import logging
 
 nlp = spacy.load("en_core_web_sm")
 
+logger = logging.getLogger(__name__)
+
+# --- Model Configuration ---
+def create_llm_pipeline():
+    #model_id = "meta-llama/Llama-2-13b-chat-hf"
+    #model_id = "meta-llama/Llama-3.3-70B-Instruct"
+    #model_id = "mistralai/Mistral-Small-24B-Base-2501"
+    model_id = "mistralai/Mistral-7B-Instruct-v0.3"
+    #model_id = "Qwen/Qwen2-7B-Instruct"
+    return pipeline(
+        "text-generation",
+        model=model_id,
+        device_map="auto",
+        torch_dtype=torch.float16,
+        max_new_tokens=1024,
+        temperature=0.1
+    )
+
 # Define file extension sets for each category
 PICTURE_EXTENSIONS = {'.jpg', '.jpeg', '.png', '.gif', '.bmp', '.tiff', '.webp'}
 AUDIO_EXTENSIONS = {'.mp3', '.wav', '.aac', '.flac', '.ogg', '.m4a', '.wma'}
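The commented-out model_id lines record the models tried before settling on Mistral-7B-Instruct-v0.3. How the returned object behaves follows standard transformers text-generation semantics rather than anything specific to this commit; a sketch:

    # Sketch: text-generation pipelines return a list of dicts.
    llm_pipe = create_llm_pipeline()
    result = llm_pipe("Briefly explain what FAISS is.")
    print(result[0]["generated_text"])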
@@ -82,8 +122,196 @@ def write_bytes_to_temp_dir(file_bytes: bytes, file_name: str) -> str:
     print(f"File written to: {file_path}")
     return file_path
 
-# 1. Define the State type
-class State(TypedDict, total=False):
+
+def extract_final_answer(text: str) -> str:
+    """
+    Returns the substring starting from the last occurrence of 'FINAL ANSWER:' (case-insensitive)
+    to the end of the string, with any trailing punctuation removed.
+    If not found, returns an empty string.
+    """
+    marker = "FINAL ANSWER:"
+    idx = text.lower().rfind(marker.lower())
+    if idx == -1:
+        return ""
+    result = text[idx:].strip()
+    # Remove trailing punctuation
+    return result.rstrip(string.punctuation + " ")
+
+
+class EnhancedDuckDuckGoSearchTool(BaseTool):
+    name: str = "enhanced_search"
+    description: str = (
+        "Performs a DuckDuckGo web search and retrieves actual content from the top web results. "
+        "Input should be a search query string. "
+        "Returns search results with extracted content from web pages, making it much more useful for answering questions. "
+        "Use this tool when you need up-to-date information, details about current events, or when other tools do not provide sufficient or recent answers. "
+        "Ideal for topics that require the latest news, recent developments, or information not covered in static sources."
+    )
+    max_results: int = 3
+    max_chars_per_page: int = 3000
+    session: Any = None  # Now it's optional and defaults to None
+
+    # Use model_post_init for initialization logic in Pydantic v2+
+    def model_post_init(self, __context: Any) -> None:
+        super().model_post_init(__context)
+        # Initialize HTTP session here
+        self.session = requests.Session()
+        self.session.headers.update({
+            'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36',
+            'Accept': 'text/html,application/xhtml+xml,application/xml;q=0.9,image/webp,*/*;q=0.8',
+            'Accept-Language': 'en-US,en;q=0.5',
+            'Accept-Encoding': 'gzip, deflate',
+            'Connection': 'keep-alive',
+            'Upgrade-Insecure-Requests': '1',
+        })
+
+    def _search_duckduckgo(self, query: str) -> List[Dict]:
+        """Perform DuckDuckGo search and return results."""
+        try:
+            with DDGS() as ddgs:
+                results = list(ddgs.text(query, max_results=self.max_results))
+            return results
+        except Exception as e:
+            logger.error(f"DuckDuckGo search failed: {e}")
+            return []
+
+    def _extract_content_from_url(self, url: str, timeout: int = 10) -> Optional[str]:
+        """Extract clean text content from a web page."""
+        try:
+            # Skip certain file types
+            if any(url.lower().endswith(ext) for ext in ['.pdf', '.doc', '.docx', '.xls', '.xlsx', '.ppt', '.pptx']):
+                return "Content type not supported for extraction"
+
+            response = self.session.get(url, timeout=timeout, allow_redirects=True)
+            response.raise_for_status()
+
+            # Check content type
+            content_type = response.headers.get('content-type', '').lower()
+            if 'text/html' not in content_type:
+                return "Non-HTML content detected"
+
+            soup = BeautifulSoup(response.content, 'html.parser')
+
+            # Remove script and style elements
+            for script in soup(["script", "style", "nav", "header", "footer", "aside", "form"]):
+                script.decompose()
+
+            # Try to find main content areas
+            main_content = None
+            for selector in ['main', 'article', '.content', '#content', '.post', '.entry']:
+                main_content = soup.select_one(selector)
+                if main_content:
+                    break
+
+            if not main_content:
+                main_content = soup.find('body') or soup
+
+            # Extract text
+            text = main_content.get_text(separator='\n', strip=True)
+
+            # Clean up the text
+            lines = [line.strip() for line in text.split('\n') if line.strip()]
+            text = '\n'.join(lines)
+
+            # Remove excessive whitespace
+            text = re.sub(r'\n{3,}', '\n\n', text)
+            text = re.sub(r' {2,}', ' ', text)
+
+            # Truncate if too long
+            if len(text) > self.max_chars_per_page:
+                text = text[:self.max_chars_per_page] + "\n[Content truncated...]"
+
+            return text
+
+        except requests.exceptions.Timeout:
+            return "Page loading timed out"
+        except requests.exceptions.RequestException as e:
+            return f"Failed to retrieve page: {str(e)}"
+        except Exception as e:
+            logger.error(f"Content extraction failed for {url}: {e}")
+            return "Failed to extract content from page"
+
+    def _format_search_result(self, result: Dict, content: str) -> str:
+        """Format a single search result with its content."""
+        title = result.get('title', 'No title')
+        url = result.get('href', 'No URL')
+        snippet = result.get('body', 'No snippet')
+
+        formatted = f"""
+🔍 **{title}**
+URL: {url}
+Snippet: {snippet}
+
+📄 **Page Content:**
+{content}
+---
+"""
+        return formatted
+
+    def run(self, query: str) -> str:
+        """Execute the enhanced search."""
+        if not query or not query.strip():
+            return "Please provide a search query."
+
+        query = query.strip()
+        logger.info(f"Searching for: {query}")
+
+        # Perform DuckDuckGo search
+        search_results = self._search_duckduckgo(query)
+
+        if not search_results:
+            return f"No search results found for query: {query}"
+
+        # Process each result and extract content
+        enhanced_results = []
+        processed_count = 0
+
+        for i, result in enumerate(search_results[:self.max_results]):
+            url = result.get('href', '')
+            if not url:
+                continue
+
+            logger.info(f"Processing result {i+1}: {url}")
+
+            # Extract content from the page
+            content = self._extract_content_from_url(url)
+
+            if content and len(content.strip()) > 50:  # Only include results with substantial content
+                formatted_result = self._format_search_result(result, content)
+                enhanced_results.append(formatted_result)
+                processed_count += 1
+
+            # Small delay to be respectful to servers
+            time.sleep(0.5)
+
+        if not enhanced_results:
+            return f"Search completed but no content could be extracted from the pages for query: {query}"
+
+        # Compile final response
+        response = f"""🔍 **Enhanced Search Results for: "{query}"**
+Found {len(search_results)} results, successfully processed {processed_count} pages with content.
+
+{''.join(enhanced_results)}
+
+💡 **Summary:** Retrieved and processed content from {processed_count} web pages to provide comprehensive information about your search query.
+"""
+
+        # Ensure the response isn't too long
+        if len(response) > 8000:
+            response = response[:8000] + "\n[Response truncated to prevent memory issues]"
+
+        return response
+
+    def _run(self, query: str) -> str:
+        """Required by BaseTool interface."""
+        return self.run(query)
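Because EnhancedDuckDuckGoSearchTool subclasses LangChain's BaseTool, it can be called directly or placed in an agent's tool list. A hedged usage sketch (only the class above is from the commit; the query string is made up):

    search_tool = EnhancedDuckDuckGoSearchTool()   # HTTP session is built in model_post_init
    print(search_tool.run("latest LangGraph release"))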
+
+# --- Agent State Definition ---
+class AgentState(TypedDict):
+    messages: Annotated[List[AnyMessage], lambda x, y: x + y]
+    done: bool = False  # Default value of False
     question: str
     task_id: str
     input_file: Optional[bytes]
@@ -93,144 +321,628 @@ class State(TypedDict, total=False):
     youtube_url: Optional[str]
     answer: Optional[str]
     frame_answers: Optional[list]
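The Annotated reducer on messages is what makes this state LangGraph-friendly: when a node returns a partial update, the old and new message lists are concatenated by lambda x, y: x + y instead of the new value overwriting the old. Conceptually:

    # Conceptual illustration, not code from the commit.
    update_a = {"messages": [HumanMessage(content="question")]}
    update_b = {"messages": [AIMessage(content="answer")]}
    merged = update_a["messages"] + update_b["messages"]  # both messages survive the merge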
-    next: Optional[str]  # Added to track the next node
-
-# --- LLM pipeline for general questions ---
-llm_pipe = pipeline(
-    "text-generation",
-    model="microsoft/Phi-3-mini-4k-instruct",
-    device_map="auto",
-    torch_dtype="auto",
-    max_new_tokens=256,
-    trust_remote_code=True
-)
-
-# Initialize RAG components
-tokenizer = AutoTokenizer.from_pretrained("facebook/rag-token-base", trust_remote_code=True)
-retriever = RagRetriever.from_pretrained(
-    "facebook/rag-token-base",
-    index_name="exact",  # or "legacy" for legacy FAISS index
-    use_dummy_dataset=False,  # set to False and download the full index for real Wikipedia retrieval
-    trust_remote_code=True,  # Trust remote code for dataset loading
-    dataset_revision="main",  # Specify a fixed revision
-    dataset="wiki_dpr",  # Explicitly specify dataset name
-)
-rag_model = RagSequenceForGeneration.from_pretrained(
-    "facebook/rag-token-base",
-    retriever=retriever,
-    trust_remote_code=True
-)
-# Speech-to-text pipeline
-asr_pipe = pipeline(
-    "automatic-speech-recognition",
-    model="openai/whisper-small",
-    device="auto"
-)
-
-# --- BLIP VQA setup ---
-device = "cuda" if torch.cuda.is_available() else "cpu"
-vqa_model_name = "Salesforce/blip-vqa-base"
-processor_vqa = BlipProcessor.from_pretrained(vqa_model_name)
-
-# Attempt to load model to GPU; fall back to CPU if OOM
-try:
-    model_vqa = BlipForQuestionAnswering.from_pretrained(vqa_model_name).to(device)
-except torch.cuda.OutOfMemoryError:
-    print("WARNING: Loading model to CPU due to insufficient GPU memory.")
-    device = "cpu"  # Switch device to CPU
-    model_vqa = BlipForQuestionAnswering.from_pretrained(vqa_model_name).to(device)
-
-# --- Helper functions ---
-def ensure_final_answer_format(answer_text: str) -> str:
-    """Ensure the answer ends with FINAL ANSWER: format"""
-    # Check if the answer already contains a FINAL ANSWER section
-    if "FINAL ANSWER:" in answer_text:
-        # Extract everything after FINAL ANSWER:
-        final_answer_part = answer_text.split("FINAL ANSWER:", 1)[1].strip()
-        return f"FINAL ANSWER: {final_answer_part}"
-    else:
-        # If no FINAL ANSWER section exists, wrap the entire answer
-        return f"FINAL ANSWER: {answer_text.strip()}"
-
-def extract_entities(text: str) -> List[str]:
-    """Extract key entities from text using spaCy if available, or regex fallback"""
-    if nlp:
-        # Using spaCy for better entity extraction
-        doc = nlp(text)
-        entities = [ent.text for ent in doc.ents]
-        keywords = [token.text for token in doc if token.pos_ in ("PROPN", "NOUN")]
-        return entities if entities else keywords
-    else:
-        # Simple fallback using regex to extract potential keywords
-        words = text.lower().split()
-        stopwords = ["what", "who", "when", "where", "why", "how", "is", "are", "the", "a", "an", "of", "in", "on", "at"]
-        keywords = [word for word in words if word not in stopwords and len(word) > 2]
-        return keywords
-
-def answer_question_on_frame(image_path, question):
-    """Answer a question about a single video frame using BLIP"""
-    try:
-        image = Image.open(image_path).convert('RGB')
-        inputs = processor_vqa(image, question, return_tensors="pt").to(device)
-        out = model_vqa.generate(**inputs)
-        answer = processor_vqa.decode(out[0], skip_special_tokens=True)
-        return answer
-    except Exception as e:
-        print(f"Error processing frame {image_path}: {str(e)}")
-        return "Error processing this frame"
-
-def answer_video_question(frames_dir, question):
-    """Answer a question about a video by analyzing extracted frames"""
-    valid_exts = ('.jpg', '.jpeg', '.png')
-
-    # Check if directory exists
-    if not os.path.exists(frames_dir):
-        return {
-            "most_common_answer": "No frames found to analyze.",
-            "all_answers": [],
-            "answer_counts": Counter()
-        }
-
-    frame_files = [os.path.join(frames_dir, f) for f in os.listdir(frames_dir)
-                   if f.lower().endswith(valid_exts)]
-
-    # Sort frames properly by number
-    def get_frame_number(filename):
-        match = re.search(r'(\d+)', os.path.basename(filename))
-        return int(match.group(1)) if match else 0
-
-    frame_files = sorted(frame_files, key=get_frame_number)
-
-    if not frame_files:
-        return {
-            "most_common_answer": "No valid image frames found.",
-            "all_answers": [],
-            "answer_counts": Counter()
-        }
-
-    answers = []
-    for frame_path in frame_files:
-        try:
-            ans = answer_question_on_frame(frame_path, question)
-            answers.append(ans)
-            print(f"Processed frame: {os.path.basename(frame_path)}, Answer: {ans}")
-        except Exception as e:
-            print(f"Error processing frame {frame_path}: {str(e)}")
-
-    if not answers:
-        return {
-            "most_common_answer": "Could not analyze any frames successfully.",
-            "all_answers": [],
-            "answer_counts": Counter()
-        }
-
-    counted = Counter(answers)
-    most_common_answer, freq = counted.most_common(1)[0]
-    return {
-        "most_common_answer": most_common_answer,
-        "all_answers": answers,
-        "answer_counts": counted
-    }
+
+def fetch_page_with_tables(page_title):
+    """
+    Fetches Wikipedia page content and extracts all tables as readable text.
+    Returns a tuple: (main_text, [table_texts])
+    """
+    # Fetch the page object
+    page = wikipedia.page(page_title)
+    main_text = page.content
+
+    # Get the HTML for table extraction
+    html = page.html()
+    soup = BeautifulSoup(html, 'html.parser')
+    tables = soup.find_all('table')
+
+    table_texts = []
+    for table in tables:
+        rows = table.find_all('tr')
+        table_lines = []
+        for row in rows:
+            cells = row.find_all(['th', 'td'])
+            cell_texts = [cell.get_text(strip=True) for cell in cells]
+            if cell_texts:
+                # Format as Markdown table row
+                table_lines.append(" | ".join(cell_texts))
+        if table_lines:
+            table_text = "\n".join(table_lines)
+            table_texts.append(table_text)
+
+    return main_text, table_texts
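fetch_page_with_tables pairs the plain-text article body with a pipe-delimited rendering of every HTML table, so tabular facts (discographies, filmographies) survive later chunking. A hypothetical call, with made-up page title and output:

    main_text, tables = fetch_page_with_tables("Mercedes Sosa")
    print(tables[0].splitlines()[0])  # first header row, e.g. "Year | Album | Label"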
+
+class WikipediaSearchToolWithFAISS(BaseTool):
+    name: str = "wikipedia_semantic_search_all_candidates_strong_entity_priority_list_retrieval"
+    description: str = (
+        "Fetches content from multiple Wikipedia pages based on intelligent NLP query processing "
+        "of various search candidates, with strong prioritization of query entities. It then performs "
+        "entity-focused semantic search across all fetched content to find the most relevant information, "
+        "with improved retrieval for lists like discographies. Uses spaCy for named entity "
+        "recognition and query enhancement. Input should be a search query or topic. "
+        "Note: Uses the current live version of Wikipedia."
+    )
+    embedding_model_name: str = "all-MiniLM-L6-v2"
+    chunk_size: int = 4000
+    chunk_overlap: int = 250  # Maintained moderate overlap
+    top_k_results: int = 3
+    spacy_model: str = "en_core_web_sm"
+    # Increased multiplier to fetch more candidates per semantic query variant
+    semantic_search_candidate_multiplier: int = 1  # Was 2, increased to 3, consider 4 if still problematic
+
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+        try:
+            self._nlp = spacy.load(self.spacy_model)
+            print(f"Loaded spaCy model: {self.spacy_model}")
+            self._embedding_model = HuggingFaceEmbeddings(model_name=self.embedding_model_name)
+            # Refined separators for better handling of Wikipedia lists and sections
+            self._text_splitter = RecursiveCharacterTextSplitter(
+                chunk_size=self.chunk_size,
+                chunk_overlap=self.chunk_overlap,
+                separators=[
+                    "\n\n== ", "\n\n=== ", "\n\n==== ",  # Section headers (keep with following content)
+                    "\n\n\n", "\n\n",  # Multiple newlines (paragraph breaks)
+                    "\n* ", "\n- ", "\n# ",  # List items
+                    "\n", ". ", "! ", "? ",  # Sentence breaks after newline, common punctuation
+                    " ", ""  # Word and character level
+                ]
+            )
+        except OSError as e:
+            print(f"Error loading spaCy model '{self.spacy_model}': {e}")
+            print("Try running: python -m spacy download en_core_web_sm")
+            self._nlp = None
+            self._embedding_model = None
+            self._text_splitter = None
+        except Exception as e:
+            print(f"Error initializing WikipediaSearchToolWithFAISS components: {e}")
+            self._nlp = None
+            self._embedding_model = None
+            self._text_splitter = None
+
+    def _extract_entities_and_keywords(self, query: str) -> Tuple[List[str], List[str], str]:
+        if not self._nlp:
+            return [], [], query
+        doc = self._nlp(query)
+        main_entities = [ent.text for ent in doc.ents if ent.label_ in ["PERSON", "ORG", "GPE", "EVENT", "WORK_OF_ART"]]
+        keywords = [token.lemma_.lower() for token in doc if token.pos_ in ["NOUN", "PROPN", "ADJ"] and not token.is_stop and not token.is_punct and len(token.text) > 2]
+        main_entities = list(dict.fromkeys(main_entities))
+        keywords = list(dict.fromkeys(keywords))
+        processed_tokens = [token.lemma_ for token in doc if not token.is_stop and not token.is_punct and token.text.strip()]
+        processed_query = " ".join(processed_tokens)
+        return main_entities, keywords, processed_query
+
+    def _generate_search_candidates(self, query: str, main_entities: List[str], keywords: List[str], processed_query: str) -> List[str]:
+        candidates_set = set()
+        entity_prefix = main_entities[0] if main_entities else None
+
+        for me in main_entities:
+            candidates_set.add(me)
+        candidates_set.add(query)
+        if processed_query and processed_query != query:
+            candidates_set.add(processed_query)
+
+        if entity_prefix and keywords:
+            first_entity_lower = entity_prefix.lower()
+            for kw in keywords[:3]:
+                if kw not in first_entity_lower and len(kw) > 2:
+                    candidates_set.add(f"{entity_prefix} {kw}")
+            keyword_combo_short = " ".join(k for k in keywords[:2] if k not in first_entity_lower and len(k) > 2)
+            if keyword_combo_short: candidates_set.add(f"{entity_prefix} {keyword_combo_short}")
+
+        if len(main_entities) > 1:
+            candidates_set.add(" ".join(main_entities[:2]))
+
+        if keywords:
+            keyword_combo = " ".join(keywords[:2])
+            if entity_prefix:
+                candidate_to_add = f"{entity_prefix} {keyword_combo}"
+                if not any(c.lower() == candidate_to_add.lower() for c in candidates_set):
+                    candidates_set.add(candidate_to_add)
+            elif not main_entities:
+                candidates_set.add(keyword_combo)
+
+        ordered_candidates = []
+        for me in main_entities:
+            if me not in ordered_candidates: ordered_candidates.append(me)
+        for c in list(candidates_set):
+            if c and c.strip() and c not in ordered_candidates: ordered_candidates.append(c)
+
+        print(f"Generated {len(ordered_candidates)} search candidates for Wikipedia page lookup (entity-prioritized): {ordered_candidates}")
+        return ordered_candidates
+
+    def _smart_wikipedia_search(self, query_text: str, main_entities_from_query: List[str], keywords_from_query: List[str], processed_query_text: str) -> List[Tuple[str, str]]:
+        candidates = self._generate_search_candidates(query_text, main_entities_from_query, keywords_from_query, processed_query_text)
+        found_pages_data: List[Tuple[str, str]] = []
+        processed_page_titles: Set[str] = set()
+
+        for i, candidate_query in enumerate(candidates):
+            print(f"\nProcessing candidate {i+1}/{len(candidates)} for page: '{candidate_query}'")
+            page_object = None
+            final_page_title = None
+            is_candidate_entity_focused = any(me.lower() in candidate_query.lower() for me in main_entities_from_query) if main_entities_from_query else False
+
+            try:
+                try:
+                    page_to_load = candidate_query
+                    suggest_mode = True  # Default to auto_suggest=True
+                    if is_candidate_entity_focused and main_entities_from_query:
+                        try:  # Attempt precise match first for entity-focused candidates
+                            temp_page = wikipedia.page(page_to_load, auto_suggest=False, redirect=True)
+                            suggest_mode = False  # Flag that precise match worked
+                        except (wikipedia.exceptions.PageError, wikipedia.exceptions.DisambiguationError):
+                            print(f" - auto_suggest=False failed for entity-focused '{page_to_load}', trying with auto_suggest=True.")
+                            # Fallthrough to auto_suggest=True below if this fails
+
+                    if suggest_mode:  # If not attempted or failed with auto_suggest=False
+                        temp_page = wikipedia.page(page_to_load, auto_suggest=True, redirect=True)
+
+                    final_page_title = temp_page.title
+
+                    if is_candidate_entity_focused and main_entities_from_query:
+                        title_matches_main_entity = any(me.lower() in final_page_title.lower() for me in main_entities_from_query)
+                        if not title_matches_main_entity:
+                            print(f" ! Page title '{final_page_title}' (from entity-focused candidate '{candidate_query}') "
+                                  f"does not strongly match main query entities: {main_entities_from_query}. Skipping.")
+                            continue
+                    if final_page_title in processed_page_titles:
+                        print(f" ~ Already processed '{final_page_title}'")
+                        continue
+                    page_object = temp_page
+                    print(f" ✓ Direct hit/suggestion for '{candidate_query}' -> '{final_page_title}'")
+
+                except wikipedia.exceptions.PageError:
+                    if i < max(2, len(candidates) // 3):  # Try Wikipedia search for a smaller, more promising subset of candidates
+                        print(f" - Direct access failed for '{candidate_query}'. Trying Wikipedia search...")
+                        search_results = wikipedia.search(candidate_query, results=1)
+                        if not search_results:
+                            print(f" - No Wikipedia search results for '{candidate_query}'.")
+                            continue
+                        search_result_title = search_results[0]
+                        try:
+                            temp_page = wikipedia.page(search_result_title, auto_suggest=False, redirect=True)  # Search results are usually canonical
+                            final_page_title = temp_page.title
+                            if is_candidate_entity_focused and main_entities_from_query:  # Still check against original intent
+                                title_matches_main_entity = any(me.lower() in final_page_title.lower() for me in main_entities_from_query)
+                                if not title_matches_main_entity:
+                                    print(f" ! Page title '{final_page_title}' (from search for '{candidate_query}' -> '{search_result_title}') "
+                                          f"does not strongly match main query entities: {main_entities_from_query}. Skipping.")
+                                    continue
+                            if final_page_title in processed_page_titles:
+                                print(f" ~ Already processed '{final_page_title}'")
+                                continue
+                            page_object = temp_page
+                            print(f" ✓ Found via search '{candidate_query}' -> '{search_result_title}' -> '{final_page_title}'")
+                        except (wikipedia.exceptions.PageError, wikipedia.exceptions.DisambiguationError) as e_sr:
+                            print(f" ! Error/Disambiguation for search result '{search_result_title}': {e_sr}")
+                    else:
+                        print(f" - Direct access failed for '{candidate_query}'. Skipping further search for this lower priority candidate.")
+                except wikipedia.exceptions.DisambiguationError as de:
+                    print(f" ! Disambiguation for '{candidate_query}'. Options: {de.options[:1]}")
+                    if de.options:
+                        option_title = de.options[0]
+                        try:
+                            temp_page = wikipedia.page(option_title, auto_suggest=False, redirect=True)
+                            final_page_title = temp_page.title
+                            if is_candidate_entity_focused and main_entities_from_query:  # Check against original intent
+                                title_matches_main_entity = any(me.lower() in final_page_title.lower() for me in main_entities_from_query)
+                                if not title_matches_main_entity:
+                                    print(f" ! Page title '{final_page_title}' (from disamb. of '{candidate_query}' -> '{option_title}') "
+                                          f"does not strongly match main query entities: {main_entities_from_query}. Skipping.")
+                                    continue
+                            if final_page_title in processed_page_titles:
+                                print(f" ~ Already processed '{final_page_title}'")
+                                continue
+                            page_object = temp_page
+                            print(f" ✓ Resolved disambiguation '{candidate_query}' -> '{option_title}' -> '{final_page_title}'")
+                        except Exception as e_dis_opt:
+                            print(f" ! Could not load disambiguation option '{option_title}': {e_dis_opt}")
+
+                if page_object and final_page_title and (final_page_title not in processed_page_titles):
+                    # Extract main text
+                    main_text = page_object.content
+
+                    # Extract tables using BeautifulSoup
+                    try:
+                        html = page_object.html()
+                        soup = BeautifulSoup(html, 'html.parser')
+                        tables = soup.find_all('table')
+                        table_texts = []
+                        for table in tables:
+                            rows = table.find_all('tr')
+                            table_lines = []
+                            for row in rows:
+                                cells = row.find_all(['th', 'td'])
+                                cell_texts = [cell.get_text(strip=True) for cell in cells]
+                                if cell_texts:
+                                    table_lines.append(" | ".join(cell_texts))
+                            if table_lines:
+                                table_text = "\n".join(table_lines)
+                                table_texts.append(table_text)
+                    except Exception as e:
+                        print(f" !! Error extracting tables for '{final_page_title}': {e}")
+                        table_texts = []
+
+                    # Combine main text and all table texts as separate chunks
+                    all_text_chunks = [main_text] + table_texts
+
+                    for chunk in all_text_chunks:
+                        found_pages_data.append((chunk, final_page_title))
+                    processed_page_titles.add(final_page_title)
+                    print(f" -> Added page '{final_page_title}'. Main text length: {len(main_text)} | Tables extracted: {len(table_texts)}")
+            except Exception as e:
+                print(f" !! Unexpected error processing candidate '{candidate_query}': {e}")
+
+        if not found_pages_data: print(f"\nCould not find any new, unique, entity-validated Wikipedia pages for query '{query_text}'.")
+        else: print(f"\nFound {len(found_pages_data)} unique, validated page(s) for processing.")
+        return found_pages_data
+
+    def _enhance_semantic_search(self, query: str, vector_store, main_entities: List[str], keywords: List[str], processed_query: str) -> List[Document]:
+        core_query_parts = set()
+        core_query_parts.add(query)
+        if processed_query != query: core_query_parts.add(processed_query)
+        if keywords: core_query_parts.add(" ".join(keywords[:2]))
+
+        section_phrases_templates = []
+        lower_query_terms = set(query.lower().split()) | set(k.lower() for k in keywords)
+
+        section_keywords_map = {
+            "discography": ["discography", "list of studio albums", "studio album titles and years", "albums by year", "album release dates", "official albums", "complete album list", "albums published"],
+            "biography": ["biography", "life story", "career details", "background history"],
+            "filmography": ["filmography", "list of films", "movie appearances", "acting roles"],
+        }
+        for section_term_key, specific_phrases_list in section_keywords_map.items():
+            # Check if the key (e.g., "discography") or any of its specific phrases (e.g. "list of studio albums")
+            # are mentioned or implied by the query terms.
+            if section_term_key in lower_query_terms or any(phrase_part in lower_query_terms for phrase_part in section_term_key.split()):
+                section_phrases_templates.extend(specific_phrases_list)
+            # Also check if phrases themselves are in query terms (e.g. query "list of albums by X")
+            for phrase in specific_phrases_list:
+                if phrase in query.lower():  # Check against original query for direct phrase matches
+                    section_phrases_templates.extend(specific_phrases_list)  # Add all related if one specific is hit
+                    break
+        section_phrases_templates = list(dict.fromkeys(section_phrases_templates))  # Deduplicate
+
+        final_search_queries = set()
+        if main_entities:
+            entity_prefix = main_entities[0]
+            final_search_queries.add(entity_prefix)
+            for part in core_query_parts:
+                final_search_queries.add(f"{entity_prefix} {part}" if entity_prefix.lower() not in part.lower() else part)
+            for phrase_template in section_phrases_templates:
+                final_search_queries.add(f"{entity_prefix} {phrase_template}")
+                if "list of" in phrase_template or "history of" in phrase_template:
+                    final_search_queries.add(f"{phrase_template} of {entity_prefix}")
+        else:
+            final_search_queries.update(core_query_parts)
+            final_search_queries.update(section_phrases_templates)
+
+        deduplicated_queries = list(dict.fromkeys(sq for sq in final_search_queries if sq and sq.strip()))
+        print(f"Generated {len(deduplicated_queries)} semantic search query variants (list-retrieval focused): {deduplicated_queries}")
+
+        all_results_docs: List[Document] = []
+        seen_content_hashes: Set[int] = set()
+        k_to_fetch = self.top_k_results * self.semantic_search_candidate_multiplier
+
+        for search_query_variant in deduplicated_queries:
+            try:
+                results = vector_store.similarity_search_with_score(search_query_variant, k=k_to_fetch)
+                print(f" Semantic search variant '{search_query_variant}' (k={k_to_fetch}) -> {len(results)} raw chunk(s) with scores.")
+                for doc, score in results:  # Assuming similarity_search_with_score returns (doc, score)
+                    content_hash = hash(doc.page_content[:250])  # Slightly more for hash uniqueness
+                    if content_hash not in seen_content_hashes:
+                        seen_content_hashes.add(content_hash)
+                        doc.metadata['retrieved_by_variant'] = search_query_variant
+                        doc.metadata['retrieval_score'] = float(score)  # Store score
+                        all_results_docs.append(doc)
+            except Exception as e:
+                print(f" Error in semantic search for variant '{search_query_variant}': {e}")
+
+        # Sort all collected unique results by score (FAISS L2 distance is lower is better)
+        all_results_docs.sort(key=lambda x: x.metadata.get('retrieval_score', float('inf')))
+        print(f"Collected and re-sorted {len(all_results_docs)} unique chunks from all semantic query variants.")
+
+        return all_results_docs[:self.top_k_results]
+
+    def _run(self, query: str) -> str:
+        if not self._nlp or not self._embedding_model or not self._text_splitter:
+            print("ERROR: WikipediaSearchToolWithFAISS components not initialized properly.")
+            return "Error: Wikipedia tool components not initialized properly. Please check server logs."
+
+        try:
+            print(f"\n--- Running {self.name} for query: '{query}' ---")
+            main_entities, keywords, processed_query = self._extract_entities_and_keywords(query)
+            print(f"Initial NLP Analysis - Main Entities: {main_entities}, Keywords: {keywords}, Processed Query: '{processed_query}'")
+
+            fetched_pages_data = self._smart_wikipedia_search(query, main_entities, keywords, processed_query)
+
+            if not fetched_pages_data:
+                return (f"Could not find any relevant, entity-validated Wikipedia pages for the query '{query}'. "
+                        f"Main entities sought: {main_entities}")
+
+            all_page_titles = [title for _, title in fetched_pages_data]
+            print(f"\nSuccessfully fetched content for {len(fetched_pages_data)} Wikipedia page(s): {', '.join(all_page_titles)}")
+
+            all_documents: List[Document] = []
+            for page_content, page_title in fetched_pages_data:
+                chunks = self._text_splitter.split_text(page_content)
+                if not chunks:
+                    print(f"Warning: Could not split content from Wikipedia page '{page_title}' into chunks.")
+                    continue
+                for i, chunk_text in enumerate(chunks):
+                    all_documents.append(Document(page_content=chunk_text, metadata={
+                        "source_page_title": page_title,
+                        "original_query": query,
+                        "chunk_index": i  # Add chunk index for potential debugging or ordering
+                    }))
+                print(f"Split content from '{page_title}' into {len(chunks)} chunks.")
+
+            if not all_documents:
+                return (f"Could not process content into searchable chunks from the fetched Wikipedia pages "
+                        f"({', '.join(all_page_titles)}) for query '{query}'.")
+
+            print(f"\nTotal document chunks from all pages: {len(all_documents)}")
+
+            print("Creating FAISS index from content of all fetched pages...")
+            try:
+                vector_store = FAISS.from_documents(all_documents, self._embedding_model)
+                print("FAISS index created successfully.")
+            except Exception as e:
+                return f"Error creating FAISS vector store: {e}"
+
+            print(f"\nPerforming enhanced semantic search across all collected content...")
+            try:
+                relevant_docs = self._enhance_semantic_search(query, vector_store, main_entities, keywords, processed_query)
+            except Exception as e:
+                return f"Error during semantic search: {e}"
+
+            if not relevant_docs:
+                return (f"No relevant information found within Wikipedia page(s) '{', '.join(list(dict.fromkeys(all_page_titles)))}' "
+                        f"for your query '{query}' using entity-focused semantic search with list retrieval.")
+
+            unique_sources_in_results = list(dict.fromkeys([doc.metadata.get('source_page_title', 'Unknown Source') for doc in relevant_docs]))
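The WikipediaSearchToolWithFAISS listing continues beyond what is shown here (the diff reports +1472 lines in total, and _run is cut off mid-body). Assuming the unshown remainder completes the BaseTool contract the class declares, invocation would look like:

    # Assumption: the unshown tail of _run formats and returns the relevant chunks.
    wiki_tool = WikipediaSearchToolWithFAISS()
    print(wiki_tool._run("Mercedes Sosa studio albums between 2000 and 2009"))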
 
 def download_youtube_video(url, output_dir='/tmp/video/', output_filename='downloaded_video.mp4'):
     """Download a YouTube video using yt-dlp"""
@@ -307,419 +1019,660 @@ def extract_frames(video_path, output_dir, frame_interval_seconds=10):
         print(f"Exception during frame extraction: {e}")
         return False
 
-def image_qa(image_path: str, question: str) -> str:
-    """Answer questions about an image using the BLIP model"""
     try:
         image = Image.open(image_path).convert('RGB')
         inputs = processor_vqa(image, question, return_tensors="pt").to(device)
         out = model_vqa.generate(**inputs)
         answer = processor_vqa.decode(out[0], skip_special_tokens=True)
         return answer
     except Exception as e:
-        print(f"Error in image_qa: {str(e)}")
-        return f"Error processing image: {str(e)}"
-
-# --- Node functions ---
-def router(state: Dict[str, Any]) -> str:
-    """Determine the next node based on question content and file type"""
-    question = state.get('question', '')
-
-    # Pattern for Wikipedia and similar sources
-    wiki_pattern = r"(wikipedia\.org|wiki|encyclopedia|britannica\.com|encyclop[a|æ]dia)"
-    has_wiki = re.search(wiki_pattern, question, re.IGNORECASE) is not None
-
-    # Pattern for YouTube
-    yt_pattern = r"(https?://)?(www\.)?(youtube\.com|youtu\.be)/[^\s]+"
-    has_youtube = re.search(yt_pattern, question) is not None
-
-    # Check for image
-    has_image = state.get('file_type') == 'picture'
 
-    # Check for audio
-    has_audio = state.get('file_type') == 'audio'
 
-    print(f"Has Wikipedia reference: {has_wiki}")
-    print(f"Has YouTube link: {has_youtube}")
-    print(f"Has picture file: {has_image}")
-    print(f"Has audio file: {has_audio}")
 
-    if has_wiki:
-        return "retrieve"
-    elif has_youtube:
-        # Store the extracted YouTube URL in the state
-        url_match = re.search(r"(https?://[^\s]+)", question)
-        if url_match:
-            state['youtube_url'] = url_match.group(0)
-        return "video"
-    elif has_image:
-        return "image"
-    elif has_audio:
-        return "audio"
-    else:
-        return "llm"
-
-def node_decide(state: Dict[str, Any]) -> Dict[str, Any]:
-    """Router node that decides which node to go to next"""
-    print("Running node_decide")
-    # Initialize context list if not present
-    if 'context' not in state:
-        state['context'] = []
-    # Add the next state to the state dict
-    state["next"] = router(state)
-    print(f"Routing to: {state['next']}")
-    return state
 
-def node_image(state: Dict[str, Any]) -> Dict[str, Any]:
-    """Process image-based questions"""
-    print("Running node_image")
-    try:
-        # Make sure the image file exists
-        if not os.path.exists(state['file_path']):
-            state['answer'] = ensure_final_answer_format("Image file not found.")
-            return state
-
-        # Get answer from image QA model
-        answer = image_qa(state['file_path'], state['question'])
-
-        # Format the final answer
-        state['answer'] = ensure_final_answer_format(answer)
-
-        # Add document to state for traceability
-        image_doc = Document(
-            page_content=f"Image analysis result: {answer}",
-            metadata={"source": "image_analysis", "file_path": state['file_path']}
-        )
-        state['context'].append(image_doc)
-
-    except Exception as e:
-        error_msg = f"Error processing image: {str(e)}"
-        print(error_msg)
-        state['answer'] = ensure_final_answer_format(error_msg)
-
-    return state
 
-def node_video(state: Dict[str, Any]) -> Dict[str, Any]:
-    """Process video-based questions"""
-    print("Running node_video")
-    youtube_url = state.get('youtube_url')
-    if not youtube_url:
-        state['answer'] = ensure_final_answer_format("No YouTube URL found in the question.")
-        return state
 
-    question = state['question']
-    # Extract the actual question part (remove the URL)
-    question_text = re.sub(r'https?://[^\s]+', '', question).strip()
-    if not question_text.endswith('?'):
-        question_text += '?'
 
-    video_file = download_youtube_video(youtube_url)
-    if not video_file or not os.path.exists(video_file):
-        state['answer'] = ensure_final_answer_format("Failed to download the video.")
-        return state
 
-    frames_dir = "/tmp/frames"
-    os.makedirs(frames_dir, exist_ok=True)
 
-    success = extract_frames(video_path=video_file, output_dir=frames_dir, frame_interval_seconds=10)
-    if not success:
-        state['answer'] = ensure_final_answer_format("Failed to extract frames from the video.")
-        return state
 
-    result = answer_video_question(frames_dir, question_text)
-    final_answer = result['most_common_answer']
-    state['frame_answers'] = result['all_answers']
 
-    # Create Document objects for each frame analysis
-    frame_documents = []
-    for i, ans in enumerate(result['all_answers']):
-        doc = Document(
-            page_content=f"Frame {i}: {ans}",
-            metadata={"frame_number": i, "source": "video_analysis"}
         )
-        frame_documents.append(doc)
-
-    # Add documents to state
-    state['context'].extend(frame_documents)
-    state['answer'] = ensure_final_answer_format(final_answer)
-
-    print(f"Video answer: {state['answer']}")
-    return state
-
-def node_audio_rag(state: Dict[str, Any]) -> Dict[str, Any]:
-    """Process audio-based questions"""
-    print(f"Processing audio file: {state['file_path']}")
 
-    try:
-        # Step 1: Transcribe audio
-        audio, sr = librosa.load(state['file_path'], sr=16000)
-        asr_result = asr_pipe({"raw": audio, "sampling_rate": sr})
-        audio_transcript = asr_result['text']
-        print(f"Audio transcript: {audio_transcript}")
-
-        # Step 2: Store transcript in vector store
-        transcript_doc = [Document(page_content=audio_transcript)]
-        embeddings = HuggingFaceEmbeddings(model_name='BAAI/bge-large-en-v1.5')
-        vector_db = FAISS.from_documents(transcript_doc, embedding=embeddings)
-
-        # Step 3: Retrieve relevant docs for the user's question
-        question = state['question']
-        similar_docs = vector_db.similarity_search(question, k=1)
-        retrieved_context = "\n".join([doc.page_content for doc in similar_docs])
-
-        # Step 4: Generate answer
-        prompt = (
-            f"You are an AI assistant that answers questions about audio content.\n\n"
-            f"Audio transcript: {retrieved_context}\n\n"
-            f"Question: {question}\n\n"
-            f"Based only on the provided audio transcript, answer the question. "
-            f"If the transcript does not contain relevant information, state that clearly.\n\n"
-            f"End your response with 'FINAL ANSWER: ' followed by a concise answer."
-        )
-
-        llm_response = llm_pipe(prompt)
-        answer_text = llm_response[0]['generated_text']
-
-        # Add documents to state
-        state['context'].extend(transcript_doc)
-        state['context'].append(Document(
-            page_content=prompt,
-            metadata={"source": "audio_analysis_prompt"}
-        ))
 
-        # Ensure final answer format
-        state['answer'] = ensure_final_answer_format(answer_text)
-
-    except Exception as e:
-        error_msg = f"Audio processing error: {str(e)}"
-        print(error_msg)
-        state['answer'] = ensure_final_answer_format(error_msg)
-
-    return state
 
-def node_llm(state: Dict[str, Any]) -> Dict[str, Any]:
-    """Process general knowledge questions with LLM"""
-    print("Running node_llm")
-    question = state['question']
-
-    # Compose a detailed prompt
-    prompt = (
-        "You are an AI assistant that answers questions using your general knowledge. "
-        "Follow these steps:\n\n"
-        "1. If the question appears to be scrambled or jumbled, first try to unscramble or reconstruct the intended meaning.\n"
-        "2. Analyze the question (unscrambled if needed) and use your own knowledge to answer it.\n"
-        "3. If the question can't be answered with certainty, provide your best estimate and clearly explain any assumptions.\n"
-        "4. Format your answer using these rules:\n"
-        "   - Numbers: Plain digits without commas/units (e.g. 1234567)\n"
-        "   - Strings: Minimal words, no articles/abbreviations\n"
-        "   - Lists: comma-separated values without extra formatting\n\n"
-        "5. Always conclude with:\n"
-        "FINAL ANSWER: [your answer] (replace bracketed text)\n\n"
-        f"Current question: {question}"
     )
-
-    # Add document to state for traceability
-    query_doc = Document(
-        page_content=prompt,
-        metadata={"source": "llm_prompt"}
    )
-    state['context'].append(query_doc)
 
-    try:
-        result = llm_pipe(prompt)
-        answer_text = result[0]['generated_text']
-        state['answer'] = ensure_final_answer_format(answer_text)
-    except Exception as e:
-        print(f"Error in LLM processing: {str(e)}")
-        error_msg = f"An error occurred while processing your question: {str(e)}"
-        state['answer'] = ensure_final_answer_format(error_msg)
 
-    print(f"LLM answer: {state['answer']}")
-    return state
-def retrieve(state: State) -> State:
-    """Retrieve relevant documents using RAG"""
-    print("Running retrieve")
-    question = state["question"]
 
     try:
-        # Tokenize the question
-        inputs = tokenizer(question, return_tensors="pt")
-
-        # Get doc_ids by using the retriever directly
-        question_hidden_states = rag_model.question_encoder(inputs["input_ids"])[0]
-        docs_dict = retriever(
-            inputs["input_ids"].numpy(),
-            question_hidden_states.detach().numpy(),
-            return_tensors="pt"
-        )
-
-        # Extract the retrieved passages
-        all_chunks = []
 
-        # Debug print to see what's in docs_dict
-        print(f"docs_dict keys: {docs_dict.keys()}")
 
-        # Check for different possible keys that might contain the documents
-        doc_text_key = None
-        for possible_key in ['retrieved_doc_text', 'doc_text', 'texts', 'documents']:
-            if possible_key in docs_dict:
-                doc_text_key = possible_key
                 break
 
-        if doc_text_key:
-            # Access the retrieved document texts from the docs_dict
-            for i in range(len(docs_dict["doc_ids"][0])):
-                doc_text = docs_dict[doc_text_key][0][i]
-                all_chunks.append(Document(page_content=doc_text))
-
-            print(f"Retrieved {len(all_chunks)} documents")
         else:
-            # Fallback: Try to extract document text from doc_ids
-            doc_ids = docs_dict.get("doc_ids", [[]])[0]
-            print(f"Retrieved doc_ids: {doc_ids}")
-
-            # Create minimal document stubs from IDs
-            for doc_id in doc_ids:
-                stub_text = f"Information related to document ID: {doc_id}"
-                all_chunks.append(Document(page_content=stub_text))
-
-            print(f"Created {len(all_chunks)} document stubs from IDs")
-
-        # Add documents to state context
-        if not state.get('context'):
-            state['context'] = []
-        state['context'].extend(all_chunks)
 
-    except Exception as e:
-        print(f"Error in retrieval: {str(e)}")
-        # Create an error document
-        error_doc = Document(
-            page_content=f"Error during retrieval: {str(e)}",
-            metadata={"source": "retrieval_error"}
         )
-        if not state.get('context'):
-            state['context'] = []
-        state['context'].append(error_doc)
 
-    return state
-
-def generate(state: State) -> State:
-    """Generate an answer based on retrieved documents"""
-    print("Running generate")
 
     try:
-        # Check if context exists
-        if not state.get('context') or len(state['context']) == 0:
-            state['answer'] = ensure_final_answer_format("No relevant information found to answer your question.")
-            return state
-
-        # Concatenate all context documents into a single string
-        docs_content = "\n\n".join(doc.page_content for doc in state["context"])
-
-        # Format the prompt for the LLM
-        prompt_str = PromptTemplate(
-            input_variables=["question", "context"],
-            template=(
-                "You are an AI assistant that answers questions using retrieved context. "
-                "Follow these steps:\n\n"
-                "1. Analyze the provided context:\n{context}\n\n"
-                "2. If the context contains scrambled text, first attempt to reconstruct meaningful information\n"
-                "3. If the question can't be answered from context alone, combine context with general knowledge "
-                "but clearly state this limitation\n"
-                "4. Format your answer using these rules:\n"
-                "   - Numbers: Plain digits without commas/units (e.g. 1234567)\n"
-                "   - Strings: Minimal words, no articles/abbreviations\n"
-                "   - Lists: comma-separated values without extra formatting\n\n"
-                "5. Always conclude with:\n"
-                "FINAL ANSWER: [your answer] (replace bracketed text)\n\n"
-                "Current question: {question}"
-            )
-        ).format(question=state["question"], context=docs_content)
 
-        # Generate answer using the LLM pipeline
-        response = llm_pipe(prompt_str)
-        answer_text = response[0]["generated_text"]
 
-        # Ensure answer has the FINAL ANSWER format
-        state['answer'] = ensure_final_answer_format(answer_text)
 
-    except Exception as e:
-        print(f"Error in generate node: {str(e)}")
-        error_msg = f"Error generating answer: {str(e)}"
-        state['answer'] = ensure_final_answer_format(error_msg)
 
     return state
 
-# --- Define the edge condition function ---
-def get_next_node(state: Dict[str, Any]) -> str:
-    """Get the next node from the state"""
-    return state["next"]
-
-# Create the StateGraph
-graph = StateGraph(State)
-
-# Add nodes
-graph.add_node("decide", node_decide)
-graph.add_node("video", node_video)
-graph.add_node("llm", node_llm)
-graph.add_node("retrieve", retrieve)
-graph.add_node("generate", generate)
-graph.add_node("image", node_image)
-graph.add_node("audio", node_audio_rag)
-
-# Add edge from START to decide
-graph.add_edge(START, "decide")
-graph.add_edge("retrieve", "generate")
-
-# Add conditional edges from decide to other nodes based on question
-graph.add_conditional_edges(
-    "decide",
-    get_next_node,
-    {
-        "video": "video",
-        "llm": "llm",
-        "retrieve": "retrieve",
-        "image": "image",
-        "audio": "audio"
-    }
-)
 
-# Add edges from all terminal nodes to END
-graph.add_edge("video", END)
-graph.add_edge("llm", END)
-graph.add_edge("generate", END)
-graph.add_edge("image", END)
-graph.add_edge("audio", END)
 
-# Compile the graph
-agent = graph.compile()
 
-# --- Intelligent Agent Function ---
-def intelligent_agent(state: State) -> str:
-    """Process a question using the appropriate pipeline based on content."""
-    try:
-        # Ensure state has proper structure
-        if not isinstance(state, dict):
-            return "FINAL ANSWER: Error - input must be a valid State dictionary"
-
-        # Make sure question exists
-        if 'question' not in state:
-            return "FINAL ANSWER: Error - question is required"
-
-        # Initialize context if not present
-        if 'context' not in state:
-            state['context'] = []
-
-        print(f"Processing question: {state['question']}")
-
-        # Invoke the agent with the state
-        final_state = agent.invoke(state)
-
-        # Ensure answer has FINAL ANSWER format
-        answer = final_state.get('answer', "No answer found.")
-        formatted_answer = ensure_final_answer_format(answer)
-
-        return formatted_answer
 
-    except Exception as e:
-        print(f"Error in agent execution: {str(e)}")
-        return f"FINAL ANSWER: An error occurred - {str(e)}"
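All of the removed hand-wired routing above is superseded by create_memory_safe_workflow, which app.py now calls but whose definition is not visible in the shown portion of tools.py. Purely as an assumption built from the StateGraph, ToolNode, and tools_condition imports, the replacement presumably follows the standard LangGraph tool-calling loop:

    # Hypothetical reconstruction, NOT the committed create_memory_safe_workflow.
    def assistant_node(state: AgentState):
        # Stub: a real node would call the LLM with state["messages"].
        return {"messages": [AIMessage(content="FINAL ANSWER: example")]}

    def create_memory_safe_workflow():
        graph = StateGraph(AgentState)
        graph.add_node("assistant", assistant_node)
        graph.add_node("tools", ToolNode([EnhancedDuckDuckGoSearchTool(),
                                          WikipediaSearchToolWithFAISS()]))
        graph.add_edge(START, "assistant")
        graph.add_conditional_edges("assistant", tools_condition)  # route to "tools" or END
        graph.add_edge("tools", "assistant")
        return graph.compile()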
1
+ # Standard Library
2
+ import os
3
+ import re
4
  import tempfile
5
+ import string
6
  import glob
 
7
  import shutil
8
+ import gc
9
+ import uuid
10
+ import signal
11
+ from datetime import datetime
12
+ from io import BytesIO
13
+ from contextlib import contextmanager
14
+ from langchain_huggingface import HuggingFacePipeline
15
+ from typing import TypedDict, List, Optional, Dict, Any, Annotated, Literal, Union, Tuple, Set
16
+ import time
17
+ from collections import Counter
18
+
19
+ # Third-Party Packages
20
  import cv2
21
+ import requests
22
  import wikipedia
23
+ import spacy
24
+ import yt_dlp
25
+ import librosa
26
+ from PIL import Image
27
+ from bs4 import BeautifulSoup
28
+ from duckduckgo_search import DDGS
29
+ from sentence_transformers import SentenceTransformer
30
+ from transformers import BlipProcessor, BlipForQuestionAnswering, pipeline
31
 
32
+ # LangChain Ecosystem
33
  from langchain.docstore.document import Document
34
  from langchain.prompts import PromptTemplate
35
  from langchain_community.document_loaders import WikipediaLoader
36
+ from langchain_huggingface import HuggingFaceEndpoint
37
+ from langchain_community.retrievers import BM25Retriever
 
 
38
  from langchain.vectorstores import FAISS
39
  from langchain.embeddings import HuggingFaceEmbeddings
40
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
41
  from langchain.schema import Document
42
+ from langchain_community.tools import DuckDuckGoSearchRun
43
+ from langchain_core.messages import AnyMessage, HumanMessage, AIMessage, BaseMessage, SystemMessage, ToolMessage
44
+ from langchain_core.tools import BaseTool, StructuredTool, tool, render_text_description
45
+ from langchain_core.documents import Document
 
46
 
47
+ # LangGraph
48
+ from langgraph.graph import START, END, StateGraph
49
+ from langgraph.prebuilt import ToolNode, tools_condition
50
 
51
+ # PyTorch
 
 
 
 
52
  import torch
53
+ from functools import partial
54
+ from transformers import pipeline
55
+
56
+ # Additional Utilities
57
+ from datetime import datetime
58
+
59
+ from urllib.parse import urljoin, urlparse
60
+ import logging
61
 
62
 
63
  nlp = spacy.load("en_core_web_sm")
64
 
65
+ logger = logging.getLogger(__name__)
66
+
67
+ # --- Model Configuration ---
68
+ def create_llm_pipeline():
69
+ #model_id = "meta-llama/Llama-2-13b-chat-hf"
70
+ #model_id = "meta-llama/Llama-3.3-70B-Instruct"
71
+ #model_id = "mistralai/Mistral-Small-24B-Base-2501"
72
+ model_id = "mistralai/Mistral-7B-Instruct-v0.3"
73
+ #model_id = "Qwen/Qwen2-7B-Instruct"
74
+ return pipeline(
75
+ "text-generation",
76
+ model=model_id,
77
+ device_map="auto",
78
+ torch_dtype=torch.float16,
79
+ max_new_tokens=1024,
80
+ temperature=0.1
81
+ )
82
+
83
  # Define file extension sets for each category
84
  PICTURE_EXTENSIONS = {'.jpg', '.jpeg', '.png', '.gif', '.bmp', '.tiff', '.webp'}
85
  AUDIO_EXTENSIONS = {'.mp3', '.wav', '.aac', '.flac', '.ogg', '.m4a', '.wma'}
 
122
  print(f"File written to: {file_path}")
123
  return file_path
124
 
125
+
126
+ def extract_final_answer(text: str) -> str:
127
+ """
128
+ Returns the substring starting from the last occurrence of 'FINAL ANSWER:' (case-insensitive)
129
+ to the end of the string, with any trailing punctuation removed.
130
+ If not found, returns an empty string.
131
+ """
132
+ marker = "FINAL ANSWER:"
133
+ idx = text.lower().rfind(marker.lower())
134
+ if idx == -1:
135
+ return ""
136
+ result = text[idx:].strip()
137
+ # Remove trailing punctuation
138
+ return result.rstrip(string.punctuation + " ")
139
+
140
+
141
+ class EnhancedDuckDuckGoSearchTool(BaseTool):
142
+ name: str = "enhanced_search"
143
+ description: str = (
144
+ "Performs a DuckDuckGo web search and retrieves actual content from the top web results. "
145
+ "Input should be a search query string. "
146
+ "Returns search results with extracted content from web pages, making it much more useful for answering questions. "
147
+ "Use this tool when you need up-to-date information, details about current events, or when other tools do not provide sufficient or recent answers. "
148
+ "Ideal for topics that require the latest news, recent developments, or information not covered in static sources."
149
+ )
150
+ max_results: int = 3
151
+ max_chars_per_page: int = 3000
152
+ session: Any = None # Now it's optional and defaults to None
153
+
154
+
155
+ # Use model_post_init for initialization logic in Pydantic v2+
156
+ def model_post_init(self, __context: Any) -> None:
157
+ super().model_post_init(__context)
158
+ # Initialize HTTP session here
159
+ self.session = requests.Session()
160
+ self.session.headers.update({
161
+ 'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36',
162
+ 'Accept': 'text/html,application/xhtml+xml,application/xml;q=0.9,image/webp,*/*;q=0.8',
163
+ 'Accept-Language': 'en-US,en;q=0.5',
164
+ 'Accept-Encoding': 'gzip, deflate',
165
+ 'Connection': 'keep-alive',
166
+ 'Upgrade-Insecure-Requests': '1',
167
+ })
168
+
169
+ def _search_duckduckgo(self, query: str) -> List[Dict]:
170
+ """Perform DuckDuckGo search and return results."""
171
+ try:
172
+ with DDGS() as ddgs:
173
+ results = list(ddgs.text(query, max_results=self.max_results))
174
+ return results
175
+ except Exception as e:
176
+ logger.error(f"DuckDuckGo search failed: {e}")
177
+ return []
178
+
179
+ def _extract_content_from_url(self, url: str, timeout: int = 10) -> Optional[str]:
180
+ """Extract clean text content from a web page."""
181
+ try:
182
+ # Skip certain file types
183
+ if any(url.lower().endswith(ext) for ext in ['.pdf', '.doc', '.docx', '.xls', '.xlsx', '.ppt', '.pptx']):
184
+ return "Content type not supported for extraction"
185
+
186
+ response = self.session.get(url, timeout=timeout, allow_redirects=True)
187
+ response.raise_for_status()
188
+
189
+ # Check content type
190
+ content_type = response.headers.get('content-type', '').lower()
191
+ if 'text/html' not in content_type:
192
+ return "Non-HTML content detected"
193
+
194
+ soup = BeautifulSoup(response.content, 'html.parser')
195
+
196
+ # Remove script and style elements
197
+ for script in soup(["script", "style", "nav", "header", "footer", "aside", "form"]):
198
+ script.decompose()
199
+
200
+ # Try to find main content areas
201
+ main_content = None
202
+ for selector in ['main', 'article', '.content', '#content', '.post', '.entry']:
203
+ main_content = soup.select_one(selector)
204
+ if main_content:
205
+ break
206
+
207
+ if not main_content:
208
+ main_content = soup.find('body') or soup
209
+
210
+ # Extract text
211
+ text = main_content.get_text(separator='\n', strip=True)
212
+
213
+ # Clean up the text
214
+ lines = [line.strip() for line in text.split('\n') if line.strip()]
215
+ text = '\n'.join(lines)
216
+
217
+ # Remove excessive whitespace
218
+ text = re.sub(r'\n{3,}', '\n\n', text)
219
+ text = re.sub(r' {2,}', ' ', text)
220
+
221
+ # Truncate if too long
222
+ if len(text) > self.max_chars_per_page:
223
+ text = text[:self.max_chars_per_page] + "\n[Content truncated...]"
224
+
225
+ return text
226
+
227
+ except requests.exceptions.Timeout:
228
+ return "Page loading timed out"
229
+ except requests.exceptions.RequestException as e:
230
+ return f"Failed to retrieve page: {str(e)}"
231
+ except Exception as e:
232
+ logger.error(f"Content extraction failed for {url}: {e}")
233
+ return "Failed to extract content from page"
234
+
235
+ def _format_search_result(self, result: Dict, content: str) -> str:
236
+ """Format a single search result with its content."""
237
+ title = result.get('title', 'No title')
238
+ url = result.get('href', 'No URL')
239
+ snippet = result.get('body', 'No snippet')
240
+
241
+ formatted = f"""
242
+ 🔍 **{title}**
243
+ URL: {url}
244
+ Snippet: {snippet}
245
+
246
+ 📄 **Page Content:**
247
+ {content}
248
+ ---
249
+ """
250
+ return formatted
251
+
252
+ def run(self, query: str) -> str:
253
+
254
+ """Execute the enhanced search."""
255
+ if not query or not query.strip():
256
+ return "Please provide a search query."
257
+
258
+ query = query.strip()
259
+ logger.info(f"Searching for: {query}")
260
+
261
+ # Perform DuckDuckGo search
262
+ search_results = self._search_duckduckgo(query)
263
+
264
+ if not search_results:
265
+ return f"No search results found for query: {query}"
266
+
267
+ # Process each result and extract content
268
+ enhanced_results = []
269
+ processed_count = 0
270
+
271
+ for i, result in enumerate(search_results[:self.max_results]):
272
+ url = result.get('href', '')
273
+ if not url:
274
+ continue
275
+
276
+ logger.info(f"Processing result {i+1}: {url}")
277
+
278
+ # Extract content from the page
279
+ content = self._extract_content_from_url(url)
280
+
281
+ if content and len(content.strip()) > 50: # Only include results with substantial content
282
+ formatted_result = self._format_search_result(result, content)
283
+ enhanced_results.append(formatted_result)
284
+ processed_count += 1
285
+
286
+ # Small delay to be respectful to servers
287
+ time.sleep(0.5)
288
+
289
+ if not enhanced_results:
290
+ return f"Search completed but no content could be extracted from the pages for query: {query}"
291
+
292
+ # Compile final response
293
+ response = f"""🔍 **Enhanced Search Results for: "{query}"**
294
+ Found {len(search_results)} results, successfully processed {processed_count} pages with content.
295
+
296
+ {''.join(enhanced_results)}
297
+
298
+ 💡 **Summary:** Retrieved and processed content from {processed_count} web pages to provide comprehensive information about your search query.
299
+ """
300
+
301
+ # Ensure the response isn't too long
302
+ if len(response) > 8000:
303
+ response = response[:8000] + "\n[Response truncated to prevent memory issues]"
304
+
305
+ return response
306
+
307
+ def _run(self, query: str) -> str:
308
+ """Required by BaseTool interface."""
309
+ return self.run(query)
310
+
311
+ # --- Agent State Definition ---
312
+ class AgentState(TypedDict):
313
+ messages: Annotated[List[AnyMessage], lambda x, y: x + y]
314
+ done: bool = False # Default value of False
315
  question: str
316
  task_id: str
317
  input_file: Optional[bytes]
 
321
  youtube_url: Optional[str]
322
  answer: Optional[str]
323
  frame_answers: Optional[list]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
324
 
 
 
 
 
 
 
 
 
 
 
 
325
 
326
+ def fetch_page_with_tables(page_title):
327
+ """
328
+ Fetches Wikipedia page content and extracts all tables as readable text.
329
+ Returns a tuple: (main_text, [table_texts])
330
+ """
331
+ # Fetch the page object
332
+ page = wikipedia.page(page_title)
333
+ main_text = page.content
334
+
335
+ # Get the HTML for table extraction
336
+ html = page.html()
337
+ soup = BeautifulSoup(html, 'html.parser')
338
+ tables = soup.find_all('table')
339
+
340
+ table_texts = []
341
+ for table in tables:
342
+ rows = table.find_all('tr')
343
+ table_lines = []
344
+ for row in rows:
345
+ cells = row.find_all(['th', 'td'])
346
+ cell_texts = [cell.get_text(strip=True) for cell in cells]
347
+ if cell_texts:
348
+ # Format as Markdown table row
349
+ table_lines.append(" | ".join(cell_texts))
350
+ if table_lines:
351
+ table_text = "\n".join(table_lines)
352
+ table_texts.append(table_text)
353
+
354
+ return main_text, table_texts
355
+
356
+ class WikipediaSearchToolWithFAISS(BaseTool):
357
+ name: str = "wikipedia_semantic_search_all_candidates_strong_entity_priority_list_retrieval"
358
+ description: str = (
359
+ "Fetches content from multiple Wikipedia pages based on intelligent NLP query processing "
360
+ "of various search candidates, with strong prioritization of query entities. It then performs "
361
+ "entity-focused semantic search across all fetched content to find the most relevant information, "
362
+ "with improved retrieval for lists like discographies. Uses spaCy for named entity "
363
+ "recognition and query enhancement. Input should be a search query or topic. "
364
+ "Note: Uses the current live version of Wikipedia."
365
+ )
366
+ embedding_model_name: str = "all-MiniLM-L6-v2"
367
+ chunk_size: int = 4000
368
+ chunk_overlap: int = 250 # Maintained moderate overlap
369
+ top_k_results: int = 3
370
+ spacy_model: str = "en_core_web_sm"
371
+ # Increased multiplier to fetch more candidates per semantic query variant
372
+ semantic_search_candidate_multiplier: int = 1 # Was 2, increased to 3, consider 4 if still problematic
373
+
374
+ def __init__(self, **kwargs):
375
+ super().__init__(**kwargs)
376
+ try:
377
+ self._nlp = spacy.load(self.spacy_model)
378
+ print(f"Loaded spaCy model: {self.spacy_model}")
379
+ self._embedding_model = HuggingFaceEmbeddings(model_name=self.embedding_model_name)
380
+ # Refined separators for better handling of Wikipedia lists and sections
381
+ self._text_splitter = RecursiveCharacterTextSplitter(
382
+ chunk_size=self.chunk_size,
383
+ chunk_overlap=self.chunk_overlap,
384
+ separators=[
385
+ "\n\n== ", "\n\n=== ", "\n\n==== ", # Section headers (keep with following content)
386
+ "\n\n\n", "\n\n", # Multiple newlines (paragraph breaks)
387
+ "\n* ", "\n- ", "\n# ", # List items
388
+ "\n", ". ", "! ", "? ", # Sentence breaks after newline, common punctuation
389
+ " ", "" # Word and character level
390
+ ]
391
+ )
392
+ except OSError as e:
393
+ print(f"Error loading spaCy model '{self.spacy_model}': {e}")
394
+ print("Try running: python -m spacy download en_core_web_sm")
395
+ self._nlp = None
396
+ self._embedding_model = None
397
+ self._text_splitter = None
398
+ except Exception as e:
399
+ print(f"Error initializing WikipediaSearchToolWithFAISS components: {e}")
400
+ self._nlp = None
401
+ self._embedding_model = None
402
+ self._text_splitter = None
403
+
404
+ def _extract_entities_and_keywords(self, query: str) -> Tuple[List[str], List[str], str]:
405
+ if not self._nlp:
406
+ return [], [], query
407
+ doc = self._nlp(query)
408
+ main_entities = [ent.text for ent in doc.ents if ent.label_ in ["PERSON", "ORG", "GPE", "EVENT", "WORK_OF_ART"]]
409
+ keywords = [token.lemma_.lower() for token in doc if token.pos_ in ["NOUN", "PROPN", "ADJ"] and not token.is_stop and not token.is_punct and len(token.text) > 2]
410
+ main_entities = list(dict.fromkeys(main_entities))
411
+ keywords = list(dict.fromkeys(keywords))
412
+ processed_tokens = [token.lemma_ for token in doc if not token.is_stop and not token.is_punct and token.text.strip()]
413
+ processed_query = " ".join(processed_tokens)
414
+ return main_entities, keywords, processed_query
415
+
416
+ def _generate_search_candidates(self, query: str, main_entities: List[str], keywords: List[str], processed_query: str) -> List[str]:
417
+ candidates_set = set()
418
+ entity_prefix = main_entities[0] if main_entities else None
419
+
420
+ for me in main_entities:
421
+ candidates_set.add(me)
422
+ candidates_set.add(query)
423
+ if processed_query and processed_query != query:
424
+ candidates_set.add(processed_query)
425
+
426
+ if entity_prefix and keywords:
427
+ first_entity_lower = entity_prefix.lower()
428
+ for kw in keywords[:3]:
429
+ if kw not in first_entity_lower and len(kw) > 2:
430
+ candidates_set.add(f"{entity_prefix} {kw}")
431
+ keyword_combo_short = " ".join(k for k in keywords[:2] if k not in first_entity_lower and len(k)>2)
432
+ if keyword_combo_short: candidates_set.add(f"{entity_prefix} {keyword_combo_short}")
433
+
434
+ if len(main_entities) > 1:
435
+ candidates_set.add(" ".join(main_entities[:2]))
436
+
437
+ if keywords:
438
+ keyword_combo = " ".join(keywords[:2])
439
+ if entity_prefix:
440
+ candidate_to_add = f"{entity_prefix} {keyword_combo}"
441
+ if not any(c.lower() == candidate_to_add.lower() for c in candidates_set):
442
+ candidates_set.add(candidate_to_add)
443
+ elif not main_entities:
444
+ candidates_set.add(keyword_combo)
445
+
446
+ ordered_candidates = []
447
+ for me in main_entities:
448
+ if me not in ordered_candidates: ordered_candidates.append(me)
449
+ for c in list(candidates_set):
450
+ if c and c.strip() and c not in ordered_candidates: ordered_candidates.append(c)
451
+
452
+ print(f"Generated {len(ordered_candidates)} search candidates for Wikipedia page lookup (entity-prioritized): {ordered_candidates}")
453
+ return ordered_candidates
454
 
455
+ def _smart_wikipedia_search(self, query_text: str, main_entities_from_query: List[str], keywords_from_query: List[str], processed_query_text: str) -> List[Tuple[str, str]]:
456
+ candidates = self._generate_search_candidates(query_text, main_entities_from_query, keywords_from_query, processed_query_text)
457
+ found_pages_data: List[Tuple[str, str]] = []
458
+ processed_page_titles: Set[str] = set()
459
+
460
+ for i, candidate_query in enumerate(candidates):
461
+ print(f"\nProcessing candidate {i+1}/{len(candidates)} for page: '{candidate_query}'")
462
+ page_object = None
463
+ final_page_title = None
464
+ is_candidate_entity_focused = any(me.lower() in candidate_query.lower() for me in main_entities_from_query) if main_entities_from_query else False
465
+
466
+ try:
467
+ try:
468
+ page_to_load = candidate_query
469
+ suggest_mode = True # Default to auto_suggest=True
470
+ if is_candidate_entity_focused and main_entities_from_query:
471
+ try: # Attempt precise match first for entity-focused candidates
472
+ temp_page = wikipedia.page(page_to_load, auto_suggest=False, redirect=True)
473
+ suggest_mode = False # Flag that precise match worked
474
+ except (wikipedia.exceptions.PageError, wikipedia.exceptions.DisambiguationError):
475
+ print(f" - auto_suggest=False failed for entity-focused '{page_to_load}', trying with auto_suggest=True.")
476
+ # Fallthrough to auto_suggest=True below if this fails
477
+
478
+ if suggest_mode: # If not attempted or failed with auto_suggest=False
479
+ temp_page = wikipedia.page(page_to_load, auto_suggest=True, redirect=True)
480
+
481
+ final_page_title = temp_page.title
482
+
483
+ if is_candidate_entity_focused and main_entities_from_query:
484
+ title_matches_main_entity = any(me.lower() in final_page_title.lower() for me in main_entities_from_query)
485
+ if not title_matches_main_entity:
486
+ print(f" ! Page title '{final_page_title}' (from entity-focused candidate '{candidate_query}') "
487
+ f"does not strongly match main query entities: {main_entities_from_query}. Skipping.")
488
+ continue
489
+ if final_page_title in processed_page_titles:
490
+ print(f" ~ Already processed '{final_page_title}'")
491
+ continue
492
+ page_object = temp_page
493
+ print(f" ✓ Direct hit/suggestion for '{candidate_query}' -> '{final_page_title}'")
494
+
495
+ except wikipedia.exceptions.PageError:
496
+ if i < max(2, len(candidates) // 3) : # Try Wikipedia search for a smaller, more promising subset of candidates
497
+ print(f" - Direct access failed for '{candidate_query}'. Trying Wikipedia search...")
498
+ search_results = wikipedia.search(candidate_query, results=1)
499
+ if not search_results:
500
+ print(f" - No Wikipedia search results for '{candidate_query}'.")
501
+ continue
502
+ search_result_title = search_results[0]
503
+ try:
504
+ temp_page = wikipedia.page(search_result_title, auto_suggest=False, redirect=True) # Search results are usually canonical
505
+ final_page_title = temp_page.title
506
+ if is_candidate_entity_focused and main_entities_from_query: # Still check against original intent
507
+ title_matches_main_entity = any(me.lower() in final_page_title.lower() for me in main_entities_from_query)
508
+ if not title_matches_main_entity:
509
+ print(f" ! Page title '{final_page_title}' (from search for '{candidate_query}' -> '{search_result_title}') "
510
+ f"does not strongly match main query entities: {main_entities_from_query}. Skipping.")
511
+ continue
512
+ if final_page_title in processed_page_titles:
513
+ print(f" ~ Already processed '{final_page_title}'")
514
+ continue
515
+ page_object = temp_page
516
+ print(f" ✓ Found via search '{candidate_query}' -> '{search_result_title}' -> '{final_page_title}'")
517
+ except (wikipedia.exceptions.PageError, wikipedia.exceptions.DisambiguationError) as e_sr:
518
+ print(f" ! Error/Disambiguation for search result '{search_result_title}': {e_sr}")
519
+ else:
520
+ print(f" - Direct access failed for '{candidate_query}'. Skipping further search for this lower priority candidate.")
521
+ except wikipedia.exceptions.DisambiguationError as de:
522
+ print(f" ! Disambiguation for '{candidate_query}'. Options: {de.options[:1]}")
523
+ if de.options:
524
+ option_title = de.options[0]
525
+ try:
526
+ temp_page = wikipedia.page(option_title, auto_suggest=False, redirect=True)
527
+ final_page_title = temp_page.title
528
+ if is_candidate_entity_focused and main_entities_from_query: # Check against original intent
529
+ title_matches_main_entity = any(me.lower() in final_page_title.lower() for me in main_entities_from_query)
530
+ if not title_matches_main_entity:
531
+ print(f" ! Page title '{final_page_title}' (from disamb. of '{candidate_query}' -> '{option_title}') "
532
+ f"does not strongly match main query entities: {main_entities_from_query}. Skipping.")
533
+ continue
534
+ if final_page_title in processed_page_titles:
535
+ print(f" ~ Already processed '{final_page_title}'")
536
+ continue
537
+ page_object = temp_page
538
+ print(f" ✓ Resolved disambiguation '{candidate_query}' -> '{option_title}' -> '{final_page_title}'")
539
+ except Exception as e_dis_opt:
540
+ print(f" ! Could not load disambiguation option '{option_title}': {e_dis_opt}")
541
+
542
+ if page_object and final_page_title and (final_page_title not in processed_page_titles):
543
+ # Extract main text
544
+ main_text = page_object.content
545
+
546
+ # Extract tables using BeautifulSoup
547
+ try:
548
+ html = page_object.html()
549
+ soup = BeautifulSoup(html, 'html.parser')
550
+ tables = soup.find_all('table')
551
+ table_texts = []
552
+ for table in tables:
553
+ rows = table.find_all('tr')
554
+ table_lines = []
555
+ for row in rows:
556
+ cells = row.find_all(['th', 'td'])
557
+ cell_texts = [cell.get_text(strip=True) for cell in cells]
558
+ if cell_texts:
559
+ table_lines.append(" | ".join(cell_texts))
560
+ if table_lines:
561
+ table_text = "\n".join(table_lines)
562
+ table_texts.append(table_text)
563
+ except Exception as e:
564
+ print(f" !! Error extracting tables for '{final_page_title}': {e}")
565
+ table_texts = []
566
+
567
+ # Combine main text and all table texts as separate chunks
568
+ all_text_chunks = [main_text] + table_texts
569
+
570
+ for chunk in all_text_chunks:
571
+ found_pages_data.append((chunk, final_page_title))
572
+ processed_page_titles.add(final_page_title)
573
+ print(f" -> Added page '{final_page_title}'. Main text length: {len(main_text)} | Tables extracted: {len(table_texts)}")
574
+ except Exception as e:
575
+ print(f" !! Unexpected error processing candidate '{candidate_query}': {e}")
576
+
577
+ if not found_pages_data: print(f"\nCould not find any new, unique, entity-validated Wikipedia pages for query '{query_text}'.")
578
+ else: print(f"\nFound {len(found_pages_data)} unique, validated page(s) for processing.")
579
+ return found_pages_data
580
+
581
+ def _enhance_semantic_search(self, query: str, vector_store, main_entities: List[str], keywords: List[str], processed_query: str) -> List[Document]:
582
+ core_query_parts = set()
583
+ core_query_parts.add(query)
584
+ if processed_query != query: core_query_parts.add(processed_query)
585
+ if keywords: core_query_parts.add(" ".join(keywords[:2]))
586
+
587
+ section_phrases_templates = []
588
+ lower_query_terms = set(query.lower().split()) | set(k.lower() for k in keywords)
589
+
590
+ section_keywords_map = {
591
+ "discography": ["discography", "list of studio albums", "studio album titles and years", "albums by year", "album release dates", "official albums", "complete album list", "albums published"],
592
+ "biography": ["biography", "life story", "career details", "background history"],
593
+ "filmography": ["filmography", "list of films", "movie appearances", "acting roles"],
594
  }
595
+ for section_term_key, specific_phrases_list in section_keywords_map.items():
596
+ # Check if the key (e.g., "discography") or any of its specific phrases (e.g. "list of studio albums")
597
+ # are mentioned or implied by the query terms.
598
+ if section_term_key in lower_query_terms or any(phrase_part in lower_query_terms for phrase_part in section_term_key.split()):
599
+ section_phrases_templates.extend(specific_phrases_list)
600
+ # Also check if phrases themselves are in query terms (e.g. query "list of albums by X")
601
+ for phrase in specific_phrases_list:
602
+ if phrase in query.lower(): # Check against original query for direct phrase matches
603
+ section_phrases_templates.extend(specific_phrases_list) # Add all related if one specific is hit
604
+ break
605
+ section_phrases_templates = list(dict.fromkeys(section_phrases_templates)) # Deduplicate
606
+
607
+ final_search_queries = set()
608
+ if main_entities:
609
+ entity_prefix = main_entities[0]
610
+ final_search_queries.add(entity_prefix)
611
+ for part in core_query_parts:
612
+ final_search_queries.add(f"{entity_prefix} {part}" if entity_prefix.lower() not in part.lower() else part)
613
+ for phrase_template in section_phrases_templates:
614
+ final_search_queries.add(f"{entity_prefix} {phrase_template}")
615
+ if "list of" in phrase_template or "history of" in phrase_template :
616
+ final_search_queries.add(f"{phrase_template} of {entity_prefix}")
617
+ else:
618
+ final_search_queries.update(core_query_parts)
619
+ final_search_queries.update(section_phrases_templates)
620
+
621
+ deduplicated_queries = list(dict.fromkeys(sq for sq in final_search_queries if sq and sq.strip()))
622
+ print(f"Generated {len(deduplicated_queries)} semantic search query variants (list-retrieval focused): {deduplicated_queries}")
623
+
624
+ all_results_docs: List[Document] = []
625
+ seen_content_hashes: Set[int] = set()
626
+ k_to_fetch = self.top_k_results * self.semantic_search_candidate_multiplier
627
+
628
+ for search_query_variant in deduplicated_queries:
629
+ try:
630
+ results = vector_store.similarity_search_with_score(search_query_variant, k=k_to_fetch)
631
+ print(f" Semantic search variant '{search_query_variant}' (k={k_to_fetch}) -> {len(results)} raw chunk(s) with scores.")
632
+ for doc, score in results: # Assuming similarity_search_with_score returns (doc, score)
633
+ content_hash = hash(doc.page_content[:250]) # Slightly more for hash uniqueness
634
+ if content_hash not in seen_content_hashes:
635
+ seen_content_hashes.add(content_hash)
636
+ doc.metadata['retrieved_by_variant'] = search_query_variant
637
+ doc.metadata['retrieval_score'] = float(score) # Store score
638
+ all_results_docs.append(doc)
639
+ except Exception as e:
640
+ print(f" Error in semantic search for variant '{search_query_variant}': {e}")
641
+
642
+ # Sort all collected unique results by score (FAISS L2 distance is lower is better)
643
+ all_results_docs.sort(key=lambda x: x.metadata.get('retrieval_score', float('inf')))
644
+ print(f"Collected and re-sorted {len(all_results_docs)} unique chunks from all semantic query variants.")
645
+
646
+ return all_results_docs[:self.top_k_results]
647
 
648
+ def _run(self, query: str) -> str:
649
+ if not self._nlp or not self._embedding_model or not self._text_splitter:
650
+ print("ERROR: WikipediaSearchToolWithFAISS components not initialized properly.")
651
+ return "Error: Wikipedia tool components not initialized properly. Please check server logs."
652
 
653
+ try:
654
+ print(f"\n--- Running {self.name} for query: '{query}' ---")
655
+ main_entities, keywords, processed_query = self._extract_entities_and_keywords(query)
656
+ print(f"Initial NLP Analysis - Main Entities: {main_entities}, Keywords: {keywords}, Processed Query: '{processed_query}'")
657
+
658
+ fetched_pages_data = self._smart_wikipedia_search(query, main_entities, keywords, processed_query)
659
+
660
+ if not fetched_pages_data:
661
+ return (f"Could not find any relevant, entity-validated Wikipedia pages for the query '{query}'. "
662
+ f"Main entities sought: {main_entities}")
663
+
664
+ all_page_titles = [title for _, title in fetched_pages_data]
665
+ print(f"\nSuccessfully fetched content for {len(fetched_pages_data)} Wikipedia page(s): {', '.join(all_page_titles)}")
666
+
667
+ all_documents: List[Document] = []
668
+ for page_content, page_title in fetched_pages_data:
669
+ chunks = self._text_splitter.split_text(page_content)
670
+ if not chunks:
671
+ print(f"Warning: Could not split content from Wikipedia page '{page_title}' into chunks.")
672
+ continue
673
+ for i, chunk_text in enumerate(chunks):
674
+ all_documents.append(Document(page_content=chunk_text, metadata={
675
+ "source_page_title": page_title,
676
+ "original_query": query,
677
+ "chunk_index": i # Add chunk index for potential debugging or ordering
678
+ }))
679
+ print(f"Split content from '{page_title}' into {len(chunks)} chunks.")
680
+
681
+ if not all_documents:
682
+ return (f"Could not process content into searchable chunks from the fetched Wikipedia pages "
683
+ f"({', '.join(all_page_titles)}) for query '{query}'.")
684
+
685
+ print(f"\nTotal document chunks from all pages: {len(all_documents)}")
686
 
687
+ print("Creating FAISS index from content of all fetched pages...")
688
+ try:
689
+ vector_store = FAISS.from_documents(all_documents, self._embedding_model)
690
+ print("FAISS index created successfully.")
691
+ except Exception as e:
692
+ return f"Error creating FAISS vector store: {e}"
693
 
694
+ print(f"\nPerforming enhanced semantic search across all collected content...")
695
+ try:
696
+ relevant_docs = self._enhance_semantic_search(query, vector_store, main_entities, keywords, processed_query)
697
+ except Exception as e:
698
+ return f"Error during semantic search: {e}"
699
+
700
+ if not relevant_docs:
701
+ return (f"No relevant information found within Wikipedia page(s) '{', '.join(list(dict.fromkeys(all_page_titles)))}' "
702
+ f"for your query '{query}' using entity-focused semantic search with list retrieval.")
703
+
704
+ unique_sources_in_results = list(dict.fromkeys([doc.metadata.get('source_page_title', 'Unknown Source') for doc in relevant_docs]))
705
+ result_header = (f"Found {len(relevant_docs)} relevant piece(s) of information from Wikipedia page(s) "
706
+ f"'{', '.join(unique_sources_in_results)}' for your query '{query}':\n")
707
+ nlp_summary = (f"[Original Query NLP: Main Entities: {', '.join(main_entities) if main_entities else 'None'}, "
708
+ f"Keywords: {', '.join(keywords[:5]) if keywords else 'None'}]\n\n")
709
+ result_details = []
710
+ for i, doc in enumerate(relevant_docs):
711
+ source_info = doc.metadata.get('source_page_title', 'Unknown Source')
712
+ variant_info = doc.metadata.get('retrieved_by_variant', 'N/A')
713
+ score_info = doc.metadata.get('retrieval_score')
714
+ # Guard the format spec: the score may be missing, and formatting 'N/A' with :.4f would raise ValueError
+ score_str = f"{score_info:.4f}" if isinstance(score_info, float) else "N/A"
+ detail = (f"Result {i+1} (source: '{source_info}', score: {score_str})\n"
715
+ f"(Retrieved by: '{variant_info}')\n{doc.page_content}")
716
+ result_details.append(detail)
717
+
718
+ final_result = result_header + nlp_summary + "\n\n---\n\n".join(result_details)
719
+ print(f"\nReturning {len(relevant_docs)} relevant chunks from {len(set(all_page_titles))} source page(s).")
720
+ return final_result.strip()
721
722
  except Exception as e:
723
+ import traceback
724
+ print(f"Unexpected error in {self.name}: {traceback.format_exc()}")
725
+ return f"An unexpected error occurred: {str(e)}"
726
+
727
+
728
+ # Example of creating the tool instance:
729
+ # wikipedia_tool_faiss = WikipediaSearchToolWithFAISS()
730
+
731
+ # To use this new tool in your agent, you would replace the old
732
+ # `wikipedia_tool` instance with `wikipedia_tool_faiss` in your `tools` list.
733
+ # For example:
734
+ # tools = [wikipedia_tool_faiss, search_tool]
735
+ # Create tool instances
736
+ #wikipedia_tool = WikipediaSearchTool()
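+ # A minimal, self-contained sketch (assumed embedding model name and made-up texts)
+ # of the multi-variant retrieval pattern used by _enhance_semantic_search above:
+ # several query phrasings are run against one FAISS index, hits are de-duplicated
+ # by a content-prefix hash, then re-sorted by L2 distance (lower is better).
+ # Assumes HuggingFaceEmbeddings is imported at module top alongside FAISS and Document.
+ def demo_multi_variant_search():
+     docs = [Document(page_content=t) for t in [
+         "Discography: Debut (1993), Post (1995), Homogenic (1997).",
+         "Early life: born in Reykjavik, began performing as a child.",
+     ]]
+     emb = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")  # assumed model
+     store = FAISS.from_documents(docs, emb)
+     seen, ranked = set(), []
+     for variant in ["discography", "list of studio albums"]:
+         for doc, score in store.similarity_search_with_score(variant, k=2):
+             h = hash(doc.page_content[:250])
+             if h not in seen:
+                 seen.add(h)
+                 doc.metadata["retrieval_score"] = float(score)
+                 ranked.append(doc)
+     ranked.sort(key=lambda d: d.metadata["retrieval_score"])  # lower L2 distance first
+     return ranked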
737
+
738
+ # --- Define Call LLM function ---
739
+
740
+ # 3. Improved LLM call with memory management
741
+
742
+ def call_llm_with_memory_management(state: AgentState, llm_model) -> AgentState: # Added llm_model parameter
743
+ """Call LLM with memory management, context truncation, and process response."""
744
+ print("Running call_llm with memory management...")
745
+
746
+ # It's crucial to work with a copy of messages for modification within this step
747
+ # The final state["messages"] should reflect the full history + new response.
748
+ original_messages = list(state["messages"])
749
+ messages_for_llm_processing = list(state["messages"]) # Use this for truncation logic
750
+
751
752
+
753
+ # --- Context Truncation Logic ---
754
+ system_message_content = None
755
+ # Check if the first message is a system message and preserve it
756
+ if messages_for_llm_processing and isinstance(messages_for_llm_processing[0], SystemMessage):
757
+ system_message_content = messages_for_llm_processing[0]
758
+ # Process only non-system messages for truncation count
759
+ regular_messages = messages_for_llm_processing[1:]
760
+ else:
761
+ regular_messages = messages_for_llm_processing
762
+
763
+ # Truncate context if too many messages (e.g., keep system + X most recent)
764
+ # Max 10 messages total (e.g. 1 system + 9 others)
765
+ max_regular_messages = 9
766
+ if len(regular_messages) > max_regular_messages:
767
+ print(f"🔄 Truncating message count: {len(messages_for_llm_processing)} -> ~{max_regular_messages + (1 if system_message_content else 0)} messages")
768
+ regular_messages = regular_messages[-max_regular_messages:] # Keep the most recent messages within the cap
769
+
770
+ # Reconstruct messages for LLM call
771
+ messages_for_llm = []
772
+ if system_message_content:
773
+ messages_for_llm.append(system_message_content)
774
+ messages_for_llm.extend(regular_messages)
775
+
776
+ # Further truncate based on character count (rough proxy for tokens)
777
+ total_chars = sum(len(str(msg.content)) for msg in messages_for_llm)
778
+ # Example character limit, adjust based on your model (e.g. 8k chars for ~4k tokens)
779
+ char_limit = 8000
780
+ if total_chars > char_limit:
781
+ print(f"📏 Context too long ({total_chars} chars > {char_limit}), further truncation needed")
782
+ # More aggressive truncation of regular messages
783
+ temp_regular_messages = list(regular_messages) # copy
784
+ system_chars = len(str(system_message_content.content)) if system_message_content else 0
785
+ # Drop the oldest messages until the context (system message included) fits the budget
786
+ while temp_regular_messages and system_chars + sum(len(str(m.content)) for m in temp_regular_messages) > char_limit:
787
+ print(f"Removing message: {temp_regular_messages[0].type} - {temp_regular_messages[0].content[:50]}...")
788
+ temp_regular_messages.pop(0)
790
+
791
+ regular_messages = temp_regular_messages
792
+ messages_for_llm = [] # Rebuild
793
+ if system_message_content:
794
+ messages_for_llm.append(system_message_content)
795
+ messages_for_llm.extend(regular_messages)
796
+ print(f"Context truncated to {sum(len(str(m.content)) for m in messages_for_llm)} chars.")
797
+
798
+ new_state = state.copy() # Start with a copy of the input state
799
+
800
+ try:
801
+ if torch.cuda.is_available():
802
+ torch.cuda.empty_cache()
803
+ print(f"🧹 Pre-LLM CUDA cache cleared. Memory: {torch.cuda.memory_allocated()/1024**2:.1f}MB")
804
+
805
+ print(f"Invoking LLM with {len(messages_for_llm)} messages.")
806
+ # This is where you call your actual LLM
807
+ formatted_input = "\n".join([f"[{msg.type.upper()}] {msg.content}" for msg in messages_for_llm])
808
+ print(f"\n\nFormatted input for LLM:\n\n{formatted_input}")
809
+
810
+ llm_response_object = llm_model.invoke(formatted_input)
811
+
812
813
+
814
+ # The response_object is typically a BaseMessage subclass (e.g., AIMessage)
815
+ # or a string for simpler LLMs. Adapt as needed.
816
+ if isinstance(llm_response_object, BaseMessage):
817
+ ai_message_response = llm_response_object # It's already a message object
818
+ if not ai_message_response.content: # Normalize None/empty content to an empty string
819
+ ai_message_response.content = ""
820
+ elif hasattr(llm_response_object, 'content'): # Some models might return a custom object with a content attribute
821
+ ai_message_response = AIMessage(content=str(llm_response_object.content) if llm_response_object.content is not None else "")
822
+ else: # Assuming it's a string for basic LLMs
823
+ ai_message_response = AIMessage(content=str(llm_response_object) if llm_response_object is not None else "")
824
+
825
+ print(f"LLM Response: {ai_message_response.content[:300]}...") # Print a snippet
826
+
827
+ # Append the LLM's response to the original full list of messages
828
+ final_messages = original_messages + [ai_message_response]
829
+ new_state["messages"] = final_messages
830
+ new_state.pop("done", None) # LLM responded, so not 'done' by default
831
+
832
+ except Exception as e:
833
+ print(f"LLM call failed: {e}")
834
+ error_message_content = f"LLM call failed with error: {str(e)}. Input consisted of {len(messages_for_llm)} messages."
835
+
836
+ if "out of memory" in str(e).lower():
837
+ print("🚨 CUDA OOM detected during LLM call! Implementing emergency cleanup...")
838
+ error_message_content = f"LLM failed due to Out of Memory: {str(e)}."
839
+ try:
840
+ if torch.cuda.is_available():
841
+ torch.cuda.empty_cache()
842
+ gc.collect()
843
+ except Exception as cleanup_e:
844
+ print(f"Emergency OOM cleanup failed: {cleanup_e}")
845
+
846
+ # Append an error message to the original message history
847
+ error_ai_message = AIMessage(content=error_message_content)
848
+ final_messages_on_error = original_messages + [error_ai_message]
849
+ new_state["messages"] = final_messages_on_error
850
+ new_state["done"] = True # Mark as done to prevent loops on LLM failure
851
+ finally:
852
+ try:
853
+ if torch.cuda.is_available():
854
+ torch.cuda.empty_cache()
855
+ print(f"🧹 Post-LLM CUDA cache cleared. Memory: {torch.cuda.memory_allocated()/1024**2:.1f}MB")
856
+ except Exception:
857
+ pass # Avoid error in cleanup hiding the main error
858
+
859
+ return new_state
860
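+ # A standalone restatement of the truncation policy above, for reference: keep an
+ # optional leading SystemMessage and drop the oldest remaining messages until the
+ # history fits a character budget (~4 chars per token is the rough assumption).
+ def truncate_to_char_budget(messages, char_limit=8000):
+     head = [messages[0]] if messages and isinstance(messages[0], SystemMessage) else []
+     tail = list(messages[len(head):])
+     while tail and sum(len(str(m.content)) for m in head + tail) > char_limit:
+         tail.pop(0)  # drop the oldest non-system message first
+     return head + tail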
+ import re
861
+ import uuid
862
+
863
+ def parse_react_output(state: AgentState) -> AgentState:
864
+ print("Running parse_react_output (Action prioritized)...")
865
+ messages = state["messages"]
866
+ last_message = messages[-1]
867
+ new_state = state.copy()
868
+
869
+ # Only process AI messages (not system/user)
870
+ if not isinstance(last_message, AIMessage):
871
+ return new_state
872
+
873
+ content = last_message.content
874
+
875
+ # Remove any system prompt/instructions (if present in content)
876
+ # Strip any echoed copies of the system prompt so the markers below are matched only in the model's own output
877
+ sys_prompt_pattern = r"(You are a general AI assistant.*?)(?=\n\n|$)"
878
+ content_wo_sys_prompt = re.sub(sys_prompt_pattern, '', content, flags=re.DOTALL | re.IGNORECASE).strip()
879
+
880
+ # Find the last occurrence of FINAL ANSWER or Action Input
881
+ final_answer_match = list(re.finditer(r"FINAL ANSWER:", content_wo_sys_prompt, re.IGNORECASE))
882
+ action_input_match = list(re.finditer(r"Action Input:", content_wo_sys_prompt, re.IGNORECASE))
883
+
884
+ # Helper: get the last match position and which it was
885
+ last_marker = None
886
+ last_pos = -1
887
+ if final_answer_match:
888
+ last_fa = final_answer_match[-1]
889
+ last_marker = 'FINAL ANSWER'
890
+ last_pos = last_fa.start()
891
+ if action_input_match:
892
+ last_ai = action_input_match[-1]
893
+ if last_ai.start() > last_pos:
894
+ last_marker = 'Action Input'
895
+ last_pos = last_ai.start()
896
+
897
+ # If neither marker found, mark as done
898
+ if not last_marker:
899
+ print("No FINAL ANSWER or Action Input found in last AI output.")
900
+ new_state["done"] = True
901
+ return new_state
902
+
903
+ # Get the substring from the last marker to the end
904
+ last_section = content_wo_sys_prompt[last_pos:].strip()
905
+
906
+ # 2. If FINAL ANSWER is in the last part, end the process
907
+ if last_marker == 'FINAL ANSWER':
908
+ # Extract the answer after FINAL ANSWER:
909
+ answer = re.search(r"FINAL ANSWER:\s*(.+)", last_section, re.IGNORECASE)
910
+ final_answer_text = answer.group(1).strip() if answer else ""
911
+ updated_ai_message = AIMessage(content=f"FINAL ANSWER: {final_answer_text}", tool_calls=[])
912
+ new_state["messages"] = messages[:-1] + [updated_ai_message]
913
+ new_state["done"] = True
914
+ print(f"FINAL ANSWER found at end: '{final_answer_text}'")
915
+ return new_state
916
+
917
+ # 3. If Action Input is in the last part, launch tool
918
+ if last_marker == 'Action Input':
919
+ # Try to extract the Action and Action Input for the last occurrence
920
+ action_match = list(re.finditer(r"Action:\s*([^\n]+)", last_section))
921
+ action_input_match = list(re.finditer(r"Action Input:\s*([^\n]+)", last_section))
922
+ if action_match and action_input_match:
923
+ tool_name = action_match[-1].group(1).strip()
924
+ tool_input_raw = action_input_match[-1].group(1).strip()
925
+ print(f"ReAct: Found Action: {tool_name}, Input: '{tool_input_raw}'")
926
+ # Format tool_args as in your original code (simplified here)
927
+ tool_args = {"query": tool_input_raw}
928
+ tool_call_id = str(uuid.uuid4())
929
+ parsed_tool_calls = [{"name": tool_name, "args": tool_args, "id": tool_call_id}]
930
+ updated_ai_message = AIMessage(content=content, tool_calls=parsed_tool_calls)
931
+ new_state["messages"] = messages[:-1] + [updated_ai_message]
932
+ new_state.pop("done", None)
933
+ print(f"AIMessage updated with tool_calls: {parsed_tool_calls}")
934
+ return new_state
935
+ else:
936
+ print("Action Input found at end, but could not parse Action or Action Input.")
937
+ new_state["done"] = True
938
+ return new_state
939
+
940
+ # Fallback: mark as done
941
+ print("No actionable marker found at end of last AI output. Marking as done.")
942
+ new_state["done"] = True
943
+ return new_state
944
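+ # A tiny demo of the precedence rule parse_react_output implements: when the model
+ # emits both markers, whichever occurs *last* in the output wins. Sample text is
+ # illustrative only.
+ def _demo_last_marker(content: str) -> str:
+     fa = [m.start() for m in re.finditer(r"FINAL ANSWER:", content, re.IGNORECASE)]
+     ai = [m.start() for m in re.finditer(r"Action Input:", content, re.IGNORECASE)]
+     if not fa and not ai:
+         return "none"
+     if fa and (not ai or fa[-1] > ai[-1]):
+         return "FINAL ANSWER"
+     return "Action Input"
+ # _demo_last_marker("Action Input: x\n...\nFINAL ANSWER: 42") -> "FINAL ANSWER"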
945
946
 
947
  def download_youtube_video(url, output_dir='/tmp/video/', output_filename='downloaded_video.mp4'):
948
  """Download a YouTube video using yt-dlp"""
 
1019
  print(f"Exception during frame extraction: {e}")
1020
  return False
1021
 
1022
+ def answer_question_on_frame(image_path, question):
1023
+ """Answer a question about a single video frame using BLIP"""
1024
  try:
1025
+ # Load the BLIP VQA model on CPU; loading per call keeps the function self-contained,
+ # but is slow when many frames are processed. Consider loading once at module level and reusing.
+ vqa_model_name = "Salesforce/blip-vqa-base"
1026
+ processor_vqa = BlipProcessor.from_pretrained(vqa_model_name)
1027
+ model_vqa = BlipForQuestionAnswering.from_pretrained(vqa_model_name).to('cpu')
1028
+ device = "cpu"
1029
+
1030
  image = Image.open(image_path).convert('RGB')
1031
  inputs = processor_vqa(image, question, return_tensors="pt").to(device)
1032
  out = model_vqa.generate(**inputs)
1033
  answer = processor_vqa.decode(out[0], skip_special_tokens=True)
1034
  return answer
1035
  except Exception as e:
1036
+ print(f"Error processing frame {image_path}: {str(e)}")
1037
+ return "Error processing this frame"
1038
 
1039
+ def answer_video_question(frames_dir, question):
1040
+ """Answer a question about a video by analyzing extracted frames"""
1041
+ valid_exts = ('.jpg', '.jpeg', '.png')
1042
 
1043
+ # Check if directory exists
1044
+ if not os.path.exists(frames_dir):
1045
+ return {
1046
+ "most_common_answer": "No frames found to analyze.",
1047
+ "all_answers": [],
1048
+ "answer_counts": Counter()
1049
+ }
1050
 
1051
+ frame_files = [os.path.join(frames_dir, f) for f in os.listdir(frames_dir)
1052
+ if f.lower().endswith(valid_exts)]
1053
 
1054
+ # Sort frames properly by number
1055
+ def get_frame_number(filename):
1056
+ match = re.search(r'(\d+)', os.path.basename(filename))
1057
+ return int(match.group(1)) if match else 0
1058
 
1059
+ frame_files = sorted(frame_files, key=get_frame_number)
1060
 
1061
+ if not frame_files:
1062
+ return {
1063
+ "most_common_answer": "No valid image frames found.",
1064
+ "all_answers": [],
1065
+ "answer_counts": Counter()
1066
+ }
1067
 
1068
+ answers = []
1069
+ for frame_path in frame_files:
1070
+ try:
1071
+ ans = answer_question_on_frame(frame_path, question)
1072
+ answers.append(ans)
1073
+ print(f"Processed frame: {os.path.basename(frame_path)}, Answer: {ans}")
1074
+ except Exception as e:
1075
+ print(f"Error processing frame {frame_path}: {str(e)}")
1076
 
1077
+ if not answers:
1078
+ return {
1079
+ "most_common_answer": "Could not analyze any frames successfully.",
1080
+ "all_answers": [],
1081
+ "answer_counts": Counter()
1082
+ }
1083
 
1084
+ counted = Counter(answers)
1085
+ most_common_answer, freq = counted.most_common(1)[0]
1086
+ return {
1087
+ "most_common_answer": most_common_answer,
1088
+ "all_answers": answers,
1089
+ "answer_counts": counted
1090
+ }
1091
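+ # The aggregation above is a per-frame majority vote; the same idea in isolation,
+ # on made-up answers:
+ def demo_majority_vote():
+     votes = Counter(["3", "3", "2", "3"])
+     answer, freq = votes.most_common(1)[0]  # -> ("3", 3)
+     return answer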
 
 
 
 
1092
 
1093
+ class YoutubeScreenshotQA(BaseTool):
1094
+ name: str = "youtube_screenshot_qa"
1095
+ description: str = (
1096
+ "Downloads a YouTube video, extracts screenshots at intervals, "
1097
+ "and answers a question about the video based on the screenshots. "
1098
+ "Input should be a dict with keys: 'youtube_url' and 'question'."
1099
+ "Example input: {'youtube_url': 'https://www.youtube.com/watch?v=L1vXCYZAYYM', 'question': 'What is the highest number of bird species on camera simultaneously?'}"
1100
+ )
1101
+ frame_interval_seconds: int = 10 # Can be parameterized if needed
1102
+
1103
+ def _run(self, input_data: Dict[str, Any]) -> str:
1104
+ youtube_url = input_data.get("youtube_url")
1105
+ question = input_data.get("question")
1106
+
1107
+ if not youtube_url or not question:
1108
+ return "Error: Input must include 'youtube_url' and 'question'."
1109
+
1110
+ # Step 1: Download the video
1111
+ video_dir = '/tmp/video/'
1112
+ video_filename = 'downloaded_video.mp4'
1113
+ print(f"Downloading YouTube video from {youtube_url}...")
1114
+ video_path = download_youtube_video(youtube_url, output_dir=video_dir, output_filename=video_filename)
1115
+ if not video_path or not os.path.exists(video_path):
1116
+ return "Error: Failed to download the YouTube video."
1117
+
1118
+ # Step 2: Extract frames
1119
+ frames_dir = '/tmp/video_frames/'
1120
+ print(f"Extracting frames from {video_path} every {self.frame_interval_seconds} seconds...")
1121
+ success = extract_frames(video_path, frames_dir, frame_interval_seconds=self.frame_interval_seconds)
1122
+ if not success:
1123
+ return "Error: Failed to extract frames from the video."
1124
+
1125
+ # Step 3: Analyze frames and answer question
1126
+ print(f"Answering question about the video frames...")
1127
+ answer_result = answer_video_question(frames_dir, question)
1128
+ if not answer_result or not answer_result.get("most_common_answer"):
1129
+ return "Error: Could not analyze video frames to answer the question."
1130
+
1131
+ # Format the result
1132
+ most_common = answer_result["most_common_answer"]
1133
+ all_answers = answer_result["all_answers"]
1134
+ counts = answer_result["answer_counts"]
1135
+
1136
+ result = (
1137
+ f"Most common answer: {most_common}\n"
1138
+ f"All answers: {all_answers}\n"
1139
+ f"Answer counts: {dict(counts)}"
1140
  )
1141
+ return result
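+ # Example invocation, using the URL and question from the tool's own description.
+ # Left commented out because it downloads a video and loads BLIP:
+ # qa_tool = YoutubeScreenshotQA()
+ # print(qa_tool._run({
+ #     "youtube_url": "https://www.youtube.com/watch?v=L1vXCYZAYYM",
+ #     "question": "What is the highest number of bird species on camera simultaneously?",
+ # }))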
1142
 
1143
+ def tools_condition_with_logging(state: AgentState):
1144
+ """
1145
+ Custom tools condition function that checks if the last message contains tool calls
1146
+ in the Thought/Action/Action Input format and logs the transition decision.
1147
+
1148
+ Args:
1149
+ state (AgentState): The current state containing messages
1150
 
1151
+ Returns:
1152
+ str: "tools" if tool calls are present, "__end__" otherwise
1153
+ """
1154
+
1155
+ # (re is already imported at module level)
 
 
 
 
1156
 
1157
+ # Ensure we have messages in the state
1158
+ if not state.get("messages") or len(state["messages"]) == 0:
1159
+ print(" No messages found in state, ending conversation")
1160
+ return "__end__"
1161
+
1162
+ # Get the last message
1163
+ last_message = state["messages"][-1]
1164
+
1165
+ # Get message content
1166
+ content = ""
1167
+ if hasattr(last_message, 'content'):
1168
+ content = str(last_message.content)
1169
+ elif isinstance(last_message, dict) and 'content' in last_message:
1170
+ content = str(last_message['content'])
1171
+ else:
1172
+ print(" No content found in last message, ending conversation")
1173
+ return "__end__"
1174
+
1175
+ print(f"🔍 Analyzing message content: {content[:200]}...")
1176
+
1177
+ # Check for Thought/Action/Action Input format
1178
+ has_tool_calls = False
1179
+
1180
+ # Pattern to match the format:
1181
+ # Thought: <thought>
1182
+ # Action: <tool_name>
1183
+ # Action Input: <input>
1184
+ thought_action_pattern = re.compile(
1185
+ r'Thought:\s*(.*?)\n\s*Action:\s*(.*?)\n\s*Action Input:\s*(.*?)(?:\n|$)',
1186
+ re.DOTALL | re.IGNORECASE
1187
  )
1188
+
1189
+ # Also check for just Action/Action Input without Thought
1190
+ action_only_pattern = re.compile(
1191
+ r'Action:\s*(.*?)\n\s*Action Input:\s*(.*?)(?:\n|$)',
1192
+ re.DOTALL | re.IGNORECASE
1193
  )
1194
+
1195
+ # Look for the complete format first
1196
+ match = thought_action_pattern.search(content)
1197
+ if not match:
1198
+ # Try the action-only format
1199
+ match = action_only_pattern.search(content)
1200
+ if match:
1201
+ thought = "No thought provided"
1202
+ action = match.group(1).strip()
1203
+ action_input = match.group(2).strip()
1204
+ else:
1205
+ action = None
1206
+ action_input = None
1207
+ thought = None
1208
+ else:
1209
+ thought = match.group(1).strip()
1210
+ action = match.group(2).strip()
1211
+ action_input = match.group(3).strip()
1212
+
1213
+ if match and action:
1214
+ has_tool_calls = True
1215
+ print(f"🔧 Found tool call format:")
1216
+ print(f" Thought: {thought}")
1217
+ print(f" Action: {action}")
1218
+ print(f" Action Input: {action_input}")
1219
+
1220
+ # Map common tool names to your actual tools
1221
+ tool_mappings = {
1222
+ 'wikipedia_semantic_search': 'wikipedia_tool',
1223
+ 'wikipedia': 'wikipedia_tool',
1224
+ 'search': 'search_tool',
1225
+ 'duckduckgo_search': 'search_tool',
1226
+ 'web_search': 'search_tool',
1227
+ 'youtube_screenshot_qa_tool': 'youtube_tool',
1228
+ 'youtube': 'youtube_tool',
1229
+ }
1230
+
1231
+ # Normalize the action name
1232
+ normalized_action = action.lower().strip()
1233
+
1234
+ # Store the parsed tool call information in the state for the tools node to use
1235
+ if 'parsed_tool_calls' not in state:
1236
+ state['parsed_tool_calls'] = []
1237
+
1238
+ tool_call_info = {
1239
+ 'thought': thought,
1240
+ 'action': action,
1241
+ 'action_input': action_input,
1242
+ 'normalized_action': normalized_action,
1243
+ 'tool_mapping': tool_mappings.get(normalized_action, normalized_action)
1244
+ }
1245
+
1246
+ state['parsed_tool_calls'].append(tool_call_info)
1247
+ print(f"🚀 Added tool call to state: {tool_call_info}")
1248
+
1249
+ # Don't execute tools here - let call_tool handle execution
1250
+ # Just store the parsed information for call_tool to use
1251
+
1252
+ # Also check for standalone tool mentions (fallback)
1253
+ if not has_tool_calls:
1254
+ # Check for tool names mentioned in content
1255
+ tool_keywords = [
1256
+ 'wikipedia_semantic_search', 'wikipedia', 'search', 'duckduckgo',
1257
+ 'youtube_screenshot_qa_tool', 'youtube', 'web search'
1258
+ ]
1259
+
1260
+ content_lower = content.lower()
1261
+ for keyword in tool_keywords:
1262
+ if keyword in content_lower:
1263
+ print(f"🔧 Found tool keyword '{keyword}' in content (fallback detection)")
1264
+ has_tool_calls = True
1265
+ break
1266
+
1267
+ if has_tool_calls:
1268
+ print("🔧 Tool calls detected, transitioning to tools...")
1269
+ return "tools"
1270
+ else:
1271
+ print("✅ No tool calls found, ending conversation")
1272
+ return "__end__"
1273
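+ # A compact demo of the ReAct pattern this router matches, on an illustrative message:
+ def _demo_parse_react(text: str):
+     m = re.search(
+         r'Thought:\s*(.*?)\n\s*Action:\s*(.*?)\n\s*Action Input:\s*(.*?)(?:\n|$)',
+         text, re.DOTALL | re.IGNORECASE)
+     return (m.group(1).strip(), m.group(2).strip(), m.group(3).strip()) if m else None
+ # _demo_parse_react("Thought: need facts\nAction: wikipedia\nAction Input: Paris")
+ # -> ("need facts", "wikipedia", "Paris")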
1274
 
1275
+ # 2. Improved call_tool with memory management
1276
+ def call_tool_with_memory_management(state: AgentState) -> AgentState:
1277
+ """Process tool calls with memory management."""
1278
+ print("Running call_tool with memory management...")
 
 
1279
 
1280
+ # Clear CUDA cache before processing
1281
  try:
1282
+ import torch
1283
+ if torch.cuda.is_available():
1284
+ torch.cuda.empty_cache()
1285
+ print(f"🧹 Cleared CUDA cache. Memory: {torch.cuda.memory_allocated()/1024**2:.1f}MB")
1286
+ except Exception:
1287
+ pass
1288
+
1289
+ # Check if we have parsed tool calls from the condition function
1290
+ if 'parsed_tool_calls' in state and state['parsed_tool_calls']:
1291
+ return execute_parsed_tool_calls(state)
1292
+
1293
+ # Fallback to original OpenAI-style tool calls handling
1294
+ messages = state["messages"]
1295
+ last_message = messages[-1]
1296
+
1297
+ if not hasattr(last_message, "tool_calls") or not last_message.tool_calls:
1298
+ print("No tool calls found in last message")
1299
+ return state
1300
+
1301
+ # Copy the messages to avoid mutating the original list
1302
+ new_messages = list(messages)
1303
+
1304
+ print(f"Processing {len(last_message.tool_calls)} tool calls")
1305
+
1306
+ for i, tool_call in enumerate(last_message.tool_calls):
1307
+ print(f"Processing tool call {i+1}: {tool_call['name'] if isinstance(tool_call, dict) else tool_call.name}")
1308
 
1309
+ # Handle both dict and object-style tool calls
1310
+ if isinstance(tool_call, dict):
1311
+ tool_name = tool_call.get("name", "")
1312
+ args = tool_call.get("args", {})
1313
+ tool_call_id = tool_call.get("id", str(uuid.uuid4()))
1314
+ else:
1315
+ tool_name = getattr(tool_call, "name", "")
1316
+ args = getattr(tool_call, "args", {})
1317
+ tool_call_id = getattr(tool_call, "id", str(uuid.uuid4()))
1318
 
1319
+ # Find the matching tool
1320
+ selected_tool = None
1321
+ for tool in tools:
1322
+ if tool.name.lower() == tool_name.lower():
1323
+ selected_tool = tool
1324
  break
1325
 
1326
+ if not selected_tool:
1327
+ tool_result = f"Error: Tool '{tool_name}' not found. Available tools: {', '.join(t.name for t in tools)}"
1328
+ print(f"Tool not found: {tool_name}")
1329
  else:
1330
+ try:
1331
+ # Extract query
1332
+ if isinstance(args, dict) and "query" in args:
1333
+ query = args["query"]
1334
+ else:
1335
+ query = str(args) if args else ""
1336
+
1337
+ print(f"Executing {tool_name} with query: {query[:100]}...")
1338
+ tool_result = selected_tool.run(query)
1339
+
1340
+ # Aggressive truncation to prevent memory issues
1341
+ max_length = 3000 if "wikipedia" in tool_name.lower() else 2000
1342
+ if len(tool_result) > max_length:
1343
+ tool_result = tool_result[:max_length] + f"... [Result truncated from {len(tool_result)} to {max_length} chars to prevent memory issues]"
1344
+ print(f"📄 Truncated result to {max_length} characters")
1345
+
1346
+ print(f"Tool result length: {len(tool_result)} characters")
1347
+
1348
+ except Exception as e:
1349
+ tool_result = f"Error executing tool '{tool_name}': {str(e)}"
1350
+ print(f"Tool execution error: {e}")
1351
 
1352
+ # Create tool message
1353
+ tool_message = ToolMessage(
1354
+ content=tool_result,
1355
+ name=tool_name,
1356
+ tool_call_id=tool_call_id
 
1357
  )
1358
+ new_messages.append(tool_message)
1359
+ print(f"Added tool message for {tool_name}")
 
1360
 
1361
+ # Update the state
1362
+ new_state = state.copy()
1363
+ new_state["messages"] = new_messages
 
 
1364
 
1365
+ # Clear CUDA cache after processing
1366
  try:
1367
+ import torch
1368
+ if torch.cuda.is_available():
1369
+ torch.cuda.empty_cache()
1370
+ except Exception:
1371
+ pass
1372
+
1373
+ return new_state
1374
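+ # The dict-vs-object handling above, factored out as a sketch: LangChain tool calls
+ # can arrive as plain dicts or as objects with attributes, so both shapes are
+ # normalized to a (name, args, id) tuple.
+ def _normalize_tool_call(tool_call):
+     if isinstance(tool_call, dict):
+         return (tool_call.get("name", ""), tool_call.get("args", {}),
+                 tool_call.get("id", str(uuid.uuid4())))
+     return (getattr(tool_call, "name", ""), getattr(tool_call, "args", {}),
+             getattr(tool_call, "id", str(uuid.uuid4())))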
+
1375
+
1376
+ def execute_parsed_tool_calls(state: AgentState):
1377
+ """
1378
+ Execute tool calls that were parsed from the Thought/Action/Action Input format.
1379
+ This is called by call_tool when parsed_tool_calls are present in state.
1380
+
1381
+ Args:
1382
+ state (AgentState): The current state containing parsed tool calls
1383
 
1384
+ Returns:
1385
+ AgentState: Updated state with tool results
1386
+ """
1387
+
1388
+ # Use the same tools list that's available globally
1389
+ # Map tool names to the actual tool instances
1390
+ tool_name_mappings = {
1391
+ 'wikipedia_semantic_search': 'wikipedia_tool',
1392
+ 'wikipedia': 'wikipedia_tool',
1393
+ 'search': 'enhanced_search', # Updated mapping
1394
+ 'duckduckgo_search': 'enhanced_search', # Updated mapping
1395
+ 'web_search': 'enhanced_search', # Updated mapping
1396
+ 'enhanced_search': 'enhanced_search', # Direct mapping
1397
+ 'youtube_screenshot_qa_tool': 'youtube_tool',
1398
+ 'youtube': 'youtube_tool',
1399
+ }
1400
+
1401
+
1402
+ # Create a lookup by tool names for your existing tools list
1403
+ tools_by_name = {}
1404
+ for tool in tools:
1405
+ tools_by_name[tool.name.lower()] = tool
1406
+ # Also map by class name for flexibility
1407
+ class_name = tool.__class__.__name__.lower()
1408
+ if 'wikipedia' in class_name:
1409
+ tools_by_name['wikipedia_tool'] = tool
1410
+ elif 'search' in class_name or 'duck' in class_name:
1411
+ tools_by_name['search_tool'] = tool
1412
+ elif 'youtube' in class_name:
1413
+ tools_by_name['youtube_tool'] = tool
1414
+
1415
+ # Copy messages to avoid mutation during iteration
1416
+ new_messages = list(state["messages"])
1417
+
1418
+ for tool_call in state['parsed_tool_calls']:
1419
+ action = tool_call['action']
1420
+ action_input = tool_call['action_input']
1421
+ thought = tool_call['thought']
1422
+ normalized_action = tool_call['normalized_action']
1423
 
1424
+ print(f"🚀 Executing tool: {action} with input: {action_input}")
 
1425
 
1426
+ # Find the tool instance
1427
+ tool_instance = None
 
 
1428
 
1429
+ # Try direct name match first
1430
+ if normalized_action in tools_by_name:
1431
+ tool_instance = tools_by_name[normalized_action]
1432
+ # Try mapped name
1433
+ elif normalized_action in tool_name_mappings:
1434
+ mapped_name = tool_name_mappings[normalized_action]
1435
+ if mapped_name in tools_by_name:
1436
+ tool_instance = tools_by_name[mapped_name]
1437
+
1438
+ if tool_instance:
1439
+ try:
1440
+ result = tool_instance.run(action_input)
1441
+ if len(result) > 6000:
1442
+ result = result[:6000] + "... [Result truncated due to length]"
1443
+
1444
+ # Create observation message in the format your agent expects
1445
+ # (AIMessage is already imported at module level)
1446
+ observation = f"Observation: {result}"
1447
+ observation_message = AIMessage(content=observation)
1448
+ new_messages.append(observation_message)
1449
+
1450
+ print(f"✅ Tool '{action}' executed successfully")
1451
+
1452
+ except Exception as e:
1453
+ print(f"❌ Error executing tool '{action}': {e}")
1454
1455
+ error_msg = f"Observation: Error executing '{action}': {str(e)}"
1456
+ error_message = AIMessage(content=error_msg)
1457
+ new_messages.append(error_message)
1458
+ else:
1459
+ print(f"❌ Tool '{action}' not found in available tools")
1460
+ available_tool_names = list(tools_by_name.keys())
1461
1462
+ error_msg = f"Observation: Tool '{action}' not found. Available tools: {', '.join(available_tool_names)}"
1463
+ error_message = AIMessage(content=error_msg)
1464
+ new_messages.append(error_message)
1465
+
1466
+ # Update state with new messages and clear parsed tool calls
1467
+ state["messages"] = new_messages
1468
+ state['parsed_tool_calls'] = []
1469
+
1470
  return state
1471
 
1472
+ # 1. Add loop detection to your AgentState
1473
+ def should_continue(state: AgentState) -> str:
1474
+ """Determine if the agent should continue or end."""
1475
+ print("Running should_continue....")
1476
+ messages = state["messages"]
1477
+
1478
1479
 
1480
+ # Check if we're done
1481
+ if state.get("done", False):
1482
+ return "end"
1483
+
1484
+ # Prevent infinite loops - limit tool calls
1485
+ tool_call_count = sum(1 for msg in messages if hasattr(msg, 'tool_calls') and msg.tool_calls)
1486
+ if tool_call_count >= 3: # Max 3 tool calls per conversation
1487
+ print(f"⚠️ Stopping: Too many tool calls ({tool_call_count})")
1488
+ return "end"
1489
+
1490
+ # Check for repeated tool calls with same query
1491
+ recent_tool_calls = []
1492
+ for msg in messages[-6:]: # Check last 6 messages
1493
+ if hasattr(msg, 'tool_calls') and msg.tool_calls:
1494
+ for tool_call in msg.tool_calls:
1495
+ if isinstance(tool_call, dict):
1496
+ recent_tool_calls.append((tool_call.get('name'), str(tool_call.get('args', {}))))
1497
+
1498
+ if len(recent_tool_calls) >= 2 and recent_tool_calls[-1] == recent_tool_calls[-2]:
1499
+ print("⚠️ Stopping: Repeated tool call detected")
1500
+ return "end"
1501
+
1502
+ # Check message count to prevent runaway conversations
1503
+ if len(messages) > 15:
1504
+ print(f"⚠️ Stopping: Too many messages ({len(messages)})")
1505
+ return "end"
1506
+
1507
+ return "continue"
1508
+
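+ # Sketch of the repeat-detection rule used above: two identical consecutive
+ # (tool name, args) pairs are treated as a loop and stop the run.
+ def _is_repeated_call(recent_tool_calls) -> bool:
+     return len(recent_tool_calls) >= 2 and recent_tool_calls[-1] == recent_tool_calls[-2]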
1509
+ def route_after_parse_react(state: AgentState) -> str:
1510
+ """Determines the next step after parsing LLM output, prioritizing end state."""
1511
+ if state.get("done", False): # Check if parse_react_output decided we are done
1512
+ return "end_processing"
1513
+
1514
+ # Original logic: check for tool calls in the last message
1515
+ # Ensure messages list and last message exist before checking tool_calls
1516
+ messages = state.get("messages", [])
1517
+ if messages:
1518
+ last_message = messages[-1]
1519
+ if hasattr(last_message, 'tool_calls') and last_message.tool_calls:
1520
+ return "call_tool"
1521
+ return "call_llm"
1522
+
1523
+ #wikipedia_tool = WikipediaSearchToolWithFAISS()
1524
+ #search_tool = DuckDuckGoSearchRun()
1525
+ #youtube_screenshot_qa_tool = YoutubeScreenshotQA()
1526
+
1527
+ # Combine all tools
1528
+ #tools = [wikipedia_tool, search_tool, youtube_screenshot_qa_tool]
1529
+
1530
+ # Update your tools list to use the global instances
1531
+ #
1532
+
1533
+ # --- Graph Construction ---
1534
1535
+ def create_memory_safe_workflow():
1536
+ """Create a workflow with memory management and loop prevention."""
1537
+ # These models are initialized here but might be better managed if they need to be released/reinitialized
1538
+ # like you attempt in run_agent. Consider passing them or managing their lifecycle carefully.
1539
+ hf_pipe = create_llm_pipeline()
1540
+ llm = HuggingFacePipeline(pipeline=hf_pipe)
1541
+ # vqa_model_name = "Salesforce/blip-vqa-base" # Not used in the provided graph logic directly
1542
+ # processor_vqa = BlipProcessor.from_pretrained(vqa_model_name) # Not used
1543
+ # model_vqa = BlipForQuestionAnswering.from_pretrained(vqa_model_name).to('cpu') # Not used
1544
+
1545
+ workflow = StateGraph(AgentState)
1546
+
1547
+ # Bind the llm_model to the call_llm_with_memory_management function
1548
+ bound_call_llm = partial(call_llm_with_memory_management, llm_model=llm)
1549
+
1550
+ # Add nodes with memory-safe versions
1551
+ workflow.add_node("call_llm", bound_call_llm) # Use the bound version here
1552
+ workflow.add_node("parse_react_output", parse_react_output)
1553
+ workflow.add_node("call_tool", call_tool_with_memory_management) # Ensure this doesn't also need llm if it calls back directly
1554
+
1555
+ # Set entry point
1556
+ workflow.set_entry_point("call_llm")
1557
+
1558
+ # Add conditional edges
1559
+ workflow.add_conditional_edges(
1560
+ "call_llm",
1561
+ should_continue,
1562
+ {
1563
+ "continue": "parse_react_output",
1564
+ "end": END
1565
+ }
1566
+ )
1567
 
1568
+ workflow.add_conditional_edges(
1569
+ "parse_react_output",
1570
+ route_after_parse_react,
1571
+ {
1572
+ "call_tool": "call_tool",
1573
+ "call_llm": "call_llm",
1574
+ "end_processing": END
1575
+ }
1576
+ )
1577
 
1578
+ workflow.add_edge("call_tool", "call_llm")
1579
+
1580
+ return workflow.compile()
1581
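+ # Resulting control flow of the compiled graph, for reference:
+ #   call_llm --(should_continue)--> parse_react_output | END
+ #   parse_react_output --(route_after_parse_react)--> call_tool | call_llm | END
+ #   call_tool --> call_llm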
+
1582
+ # --- Run the Agent ---
1583
+ def run_agent(myagent, state: AgentState):
1584
+ """
1585
+ Initialize agent with proper system message and formatted query.
1586
+ """
1587
+ #global llm, hf_pipe, model_vqa, processor_vqa
1588
+ global WIKIPEDIA_TOOL, SEARCH_TOOL, YOUTUBE_TOOL, tools
1589
+
1590
1591
+
1592
+ # Create the tool instances for this run and expose them via the globals above
1593
+ WIKIPEDIA_TOOL = WikipediaSearchToolWithFAISS()
1594
+ SEARCH_TOOL = EnhancedDuckDuckGoSearchTool(max_results=3, max_chars_per_page=3000)
1595
+ YOUTUBE_TOOL = YoutubeScreenshotQA()
1596
+
1597
+ tools = [WIKIPEDIA_TOOL, SEARCH_TOOL, YOUTUBE_TOOL]
1598
+
1599
+ # Create a fresh system message each time
1600
+ formatted_tools_description = render_text_description(tools)
1601
+ current_date_str = datetime.now().strftime("%Y-%m-%d")
1602
+
1603
+ system_content = f"""You are a general AI assistant. with access to these tools:
1604
+
1605
+ {formatted_tools_description}
1606
+
1607
+ If you need the most current information as of 2025, use enhanced_search
1608
+ If you need to do in-depth research, use wikipedia_semantic_search_all_candidates_strong_entity_priority_list_retrieval
1609
+ If you can answer the question confidently, do so directly.
1610
+ If the question seems like gibberish (not English), try flipping the entire question and re-reading it.
1611
+ If you need more information, use a tool.
1612
+ (Think through the problem step by step)
1613
+
1614
+ When using a tool, follow this format:
1615
+ Thought: <your thought>
1616
+ Action: <tool_name>
1617
+ Action Input: <tool_input>
1618
+
1619
+ Only use the tools listed above for the Action: step. Do not invent new tool names or actions. If you need to reason, do so in the Thought: step. After using a tool, process its output in your Thought: step, not as an Action.
1620
+
1621
+ Report your thoughts, and finish your answer with the following template: FINAL ANSWER: [YOUR FINAL ANSWER].
1622
+ YOUR FINAL ANSWER should be a number OR as few words as possible OR a comma separated list of numbers and/or strings.
1623
+ If you are asked for a number, don't write it with commas or units such as $ or percent signs unless specified otherwise.
1624
+ If you are asked for a string, don't use articles or abbreviations (e.g. for cities), and write digits in plain text unless specified otherwise.
1625
+ If you are asked for a comma separated list, apply the above rules to each element depending on whether it is a number or a string.
1626
+ Do not provide disclaimers.
1627
+ Do not provide supporting details.
1628
+
1629
+ """
1630
+
1631
+ # Get user question from AgentState
1632
+ query = state['question']
1633
+
1634
+ # Pattern for YouTube
1635
+ yt_pattern = r"(https?://)?(www\.)?(youtube\.com|youtu\.be)/[^\s]+"
1636
+ has_youtube = re.search(yt_pattern, query) is not None
1637
+
1638
+ if has_youtube:
1639
+ # Store the extracted YouTube URL in the state
1640
+ url_match = re.search(r"(https?://[^\s]+)", query)
1641
+ if url_match:
1642
+ state['youtube_url'] = url_match.group(0)
1643
 
1644
+ # Pass the user query through unchanged
1645
+ formatted_query = query
1646
+
1647
+ # Initialize agent state with proper message types
1648
+ system_message = SystemMessage(content=system_content)
1649
+ human_message = HumanMessage(content=formatted_query)
1650
+
1651
+ # Initialize state with properly typed messages and done=False
1652
+ # state = {"messages": [system_message, human_message], "done": False}
1653
+ state['messages'] = [system_message, human_message]
1654
+ state["done"] = False
1655
+
1656
+ # Use the new method to run the graph
1657
+ result = myagent.invoke(state)
1658
+
1659
+ # Check if FINAL ANSWER was given (i.e., workflow ended)
1660
+ if result.get("done"):
1661
+ #del llm
1662
+ #del hf_pipe
1663
+ #del model_vqa
1664
+ #del processor_vqa
1665
+ torch.cuda.empty_cache()
1666
+ torch.cuda.ipc_collect()
1667
+ gc.collect()
1668
+ print("Released GPU memory after FINAL ANSWER.")
1669
+ # Re-initialize for the next run
1670
+ #hf_pipe = create_llm_pipeline()
1671
+ #llm = HuggingFacePipeline(pipeline=hf_pipe)
1672
+ #print("Re-initilized llm...")
1673
+
1674
+ # Extract and return just the messages for cleaner output
1675
+ return result["messages"]
1676
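+ # End-to-end usage sketch (question text is illustrative; the state only needs
+ # the 'question' key here, since run_agent fills in the messages itself):
+ # agent = create_memory_safe_workflow()
+ # messages = run_agent(agent, {"question": "What is the capital of France?"})
+ # print(messages[-1].content)  # "FINAL ANSWER: ..." on success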
+
1677
+
1678
+