Commit 82de5c7 · Samuel Thomas committed · 1 parent: da40168

changes to model

Files changed (2):
  1. app.py +1 -1
  2. tools.py +314 -200
app.py CHANGED
@@ -100,7 +100,7 @@ def run_and_submit_all( profile: gr.OAuthProfile | None):
             task_id = hf_questions[r]['task_id']
             question_text = hf_questions[r]['question']
             submitted_answer = intelligent_agent(s)
-            answers_payload.append({"task_id": task_id, "submitted_answer": submitted_answer})
+            answers_payload.append({"task_id": task_id, "model_answer": submitted_answer})
             results_log.append({"Task ID": task_id, "Question": question_text, "Submitted Answer": submitted_answer})
         except:
             print(f"Error running agent on task {task_id}: {e}")
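For context, the unchanged lines around this change use a bare except: while the error message interpolates e, and the agent is called with s. A minimal sketch of how this loop body presumably reads with the exception bound explicitly; the loop header and the construction of the per-task state s are assumptions, not part of this diff:

for r in range(len(hf_questions)):
    try:
        task_id = hf_questions[r]['task_id']
        question_text = hf_questions[r]['question']
        s = {"question": question_text, "task_id": task_id}  # assumed per-task state dict
        submitted_answer = intelligent_agent(s)
        answers_payload.append({"task_id": task_id, "model_answer": submitted_answer})
        results_log.append({"Task ID": task_id, "Question": question_text, "Submitted Answer": submitted_answer})
    except Exception as e:  # bind the exception so the message below can reference it
        print(f"Error running agent on task {task_id}: {e}")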
tools.py CHANGED
(old version of the changed regions; removed lines are marked with -)
@@ -22,6 +22,7 @@ from langchain.schema import Document
 from transformers import BlipProcessor, BlipForQuestionAnswering, pipeline
 from io import BytesIO
 from sentence_transformers import SentenceTransformer


 import os
@@ -84,8 +85,8 @@ def write_bytes_to_temp_dir(file_bytes: bytes, file_name: str) -> str:
 class State(TypedDict, total=False):
     question: str
     task_id: str
-    input_file: bytes
-    file_type: str
     context: List[Document]  # Using LangChain's Document class
     file_path: Optional[str]
     youtube_url: Optional[str]
@@ -94,31 +95,33 @@ class State(TypedDict, total=False):
     next: Optional[str]  # Added to track the next node

 # --- LLM pipeline for general questions ---
-llm_pipe = pipeline("text-generation",
-                    #model="meta-llama/Llama-3.3-70B-Instruct",
-                    #model="meta-llama/Meta-Llama-3-8B-Instruct",
-                    #model="Qwen/Qwen2-7B-Instruct",
-                    #model="microsoft/Phi-4-reasoning",
-                    model="microsoft/Phi-3-mini-4k-instruct",
-                    device_map="auto",
-                    #device_map={ "": 0 },  # "" means the whole model
-                    #max_memory={0: "10GiB"},
-                    torch_dtype="auto",
-                    max_new_tokens=256)

 # Speech-to-text pipeline
 asr_pipe = pipeline(
     "automatic-speech-recognition",
     model="openai/whisper-small",
-    device=-1
-    #device_map={"", 0},
-    #max_memory = {0: "4.5GiB"},
-    #device_map="auto"
 )

-# --- Your BLIP VQA setup ---
-#device = "cuda" if torch.cuda.is_available() else "cpu"
-device = "cpu"
 vqa_model_name = "Salesforce/blip-vqa-base"
 processor_vqa = BlipProcessor.from_pretrained(vqa_model_name)
@@ -130,18 +133,47 @@ except torch.cuda.OutOfMemoryError:
     device = "cpu"  # Switch device to CPU
     model_vqa = BlipForQuestionAnswering.from_pretrained(vqa_model_name).to(device)

-# --- Helper: Answer question on a single frame ---
 def answer_question_on_frame(image_path, question):
-    # Fixed: Properly use the PIL Image module
-    image = Image.open(image_path).convert('RGB')
-    inputs = processor_vqa(image, question, return_tensors="pt").to(device)
-    out = model_vqa.generate(**inputs)
-    answer = processor_vqa.decode(out[0], skip_special_tokens=True)
-    return answer
-
-# --- Helper: Answer question about the whole video ---
 def answer_video_question(frames_dir, question):
     valid_exts = ('.jpg', '.jpeg', '.png')

     # Check if directory exists
@@ -193,8 +225,8 @@ def answer_video_question(frames_dir, question):
         "answer_counts": counted
     }

-
-def download_youtube_video(url, output_dir='tmp/content/video/', output_filename='downloaded_video.mp4'):
     # Ensure the output directory exists
     os.makedirs(output_dir, exist_ok=True)
@@ -209,25 +241,27 @@ def download_youtube_video(url, output_dir='tmp/content/video/', output_filename='downloaded_video.mp4'):
     # Set output path for yt-dlp
     output_path = os.path.join(output_dir, output_filename)

-    ydl_opts = {
-        'format': 'bestvideo[ext=mp4]+bestaudio[ext=m4a]/best[ext=mp4]/best',
-        'outtmpl': output_path,
-        'quiet': True,
-        'merge_output_format': 'mp4',  # Ensures merged output is mp4
-        'postprocessors': [{
-            'key': 'FFmpegVideoConvertor',
-            'preferedformat': 'mp4',  # Recode if needed
-        }]
-    }
-    with yt_dlp.YoutubeDL(ydl_opts) as ydl:
-        ydl.download([url])
-    return output_path
-
-

-# --- Helper: Extract frames from video ---
 def extract_frames(video_path, output_dir, frame_interval_seconds=10):
-    # --- Clean output directory before extracting new frames ---
     if os.path.exists(output_dir):
         for filename in os.listdir(output_dir):
             file_path = os.path.join(output_dir, filename)
@@ -266,33 +300,23 @@ def extract_frames(video_path, output_dir, frame_interval_seconds=10):
         print(f"Exception during frame extraction: {e}")
         return False

-def image_qa(image_path: str, question: str, model_name: str = vqa_model_name) -> str:
-    """
-    Answers questions about images using Hugging Face's VQA pipeline.
-
-    Args:
-        image_path: Path to local image file or URL
-        question: Natural language question about the image
-        model_name: Pretrained VQA model (default: good general-purpose model)
-
-    Returns:
-        str: The model's best answer
-    """
-    # Create VQA pipeline with specified model
-    vqa_pipeline = pipeline("visual-question-answering", model=model_name)
-
-    # Get predictions (automatically handles local files/URLs)
-    results = vqa_pipeline(image=image_path, question=question, top_k=1)
-
-    # Return top answer
-    return results[0]['answer']
-

 def router(state: Dict[str, Any]) -> str:
-    """Determine the next node based on whether the question contains a YouTube URL or references Wikipedia."""
     question = state.get('question', '')

-
     # Pattern for Wikipedia and similar sources
     wiki_pattern = r"(wikipedia\.org|wiki|encyclopedia|britannica\.com|encyclop[a|æ]dia)"
     has_wiki = re.search(wiki_pattern, question, re.IGNORECASE) is not None
@@ -327,30 +351,52 @@ def router(state: Dict[str, Any]) -> str:
     else:
         return "llm"

-
-# --- Node Implementation ---
-def node_image(state: Dict[str, Any]) -> Dict[str, Any]:
-    """Router node that decides which node to go to next."""
-    print("Running node_image")
-    # Add the next state to the state dict
-    img = Image.open(state['file_path'])
-    state['answer'] = image_qa(state['file_path'], state['question'])
-    return state
-
-
 def node_decide(state: Dict[str, Any]) -> Dict[str, Any]:
-    """Router node that decides which node to go to next."""
     print("Running node_decide")
     # Add the next state to the state dict
     state["next"] = router(state)
     print(f"Routing to: {state['next']}")
     return state

 def node_video(state: Dict[str, Any]) -> Dict[str, Any]:
     print("Running node_video")
     youtube_url = state.get('youtube_url')
     if not youtube_url:
-        state['answer'] = "No YouTube URL found in the question."
         return state

     question = state['question']
@@ -361,7 +407,7 @@ def node_video(state: Dict[str, Any]) -> Dict[str, Any]:

     video_file = download_youtube_video(youtube_url)
     if not video_file or not os.path.exists(video_file):
-        state['answer'] = "Failed to download the video."
         return state

     frames_dir = "/tmp/frames"
@@ -369,11 +415,11 @@ def node_video(state: Dict[str, Any]) -> Dict[str, Any]:

     success = extract_frames(video_path=video_file, output_dir=frames_dir, frame_interval_seconds=10)
     if not success:
-        state['answer'] = "Failed to extract frames from the video."
         return state

     result = answer_video_question(frames_dir, question_text)
-    state['answer'] = result['most_common_answer']
     state['frame_answers'] = result['all_answers']

     # Create Document objects for each frame analysis
@@ -385,15 +431,15 @@ def node_video(state: Dict[str, Any]) -> Dict[str, Any]:
         )
         frame_documents.append(doc)

-    # Add documents to state if not already present
-    if 'context' not in state:
-        state['context'] = []
     state['context'].extend(frame_documents)

     print(f"Video answer: {state['answer']}")
     return state

 def node_audio_rag(state: Dict[str, Any]) -> Dict[str, Any]:
     print(f"Processing audio file: {state['file_path']}")

     try:
@@ -403,52 +449,65 @@ def node_audio_rag(state: Dict[str, Any]) -> Dict[str, Any]:
         audio_transcript = asr_result['text']
         print(f"Audio transcript: {audio_transcript}")

-        # Step 2: Store ONLY the transcript in the vector store
         transcript_doc = [Document(page_content=audio_transcript)]
         embeddings = HuggingFaceEmbeddings(model_name='BAAI/bge-large-en-v1.5')
         vector_db = FAISS.from_documents(transcript_doc, embedding=embeddings)

         # Step 3: Retrieve relevant docs for the user's question
         question = state['question']
-        similar_docs = vector_db.similarity_search(question, k=1)  # Only one doc in store
         retrieved_context = "\n".join([doc.page_content for doc in similar_docs])

-        # Step 4: Augment prompt and generate answer
         prompt = (
-            f"Use the following context to answer the question.\n"
-            f"Context:\n{retrieved_context}\n\n"
-            f"Question: {question}\nAnswer:"
         )

         llm_response = llm_pipe(prompt)
-        state['answer'] = llm_response[0]['generated_text']

     except Exception as e:
         error_msg = f"Audio processing error: {str(e)}"
         print(error_msg)
-        state['answer'] = error_msg

     return state

 def node_llm(state: Dict[str, Any]) -> Dict[str, Any]:
     print("Running node_llm")
     question = state['question']

-    # Optionally add context from state (e.g., Wikipedia/Wikidata content)
-    context_text = ""
-    if 'article_content' in state and state['article_content']:
-        context_text = f"\n\nBackground Information:\n{state['article_content']}\n"
-    elif 'context' in state and state['context']:
-        context_text = "\n\n".join([doc.page_content for doc in state['context']])
-
     # Compose a detailed prompt
     prompt = (
-        "You are an expert researcher. Answer the user's question as accurately as possible. "
-        "If the text appears to be scrambled, try to unscramble the text for the user"
-        "If the information is incomplete or ambiguous, provide your best estimate based on the available evidence, and clearly explain any assumptions or reasoning you use. "
-        "If the answer requires multiple steps or deeper analysis, break down the question into sub-questions and answer them step by step, citing the relevant context for each step.\n\n"
-        f"Question: {question}"
-        f"{context_text}\n"
-        "Answer:"
     )

     # Add document to state for traceability
@@ -456,102 +515,138 @@ def node_llm(state: Dict[str, Any]) -> Dict[str, Any]:
         page_content=prompt,
         metadata={"source": "llm_prompt"}
     )
-    if 'context' not in state:
-        state['context'] = []
     state['context'].append(query_doc)

     try:
         result = llm_pipe(prompt)
-        state['answer'] = result[0]['generated_text']
     except Exception as e:
         print(f"Error in LLM processing: {str(e)}")
-        state['answer'] = f"An error occurred while processing your question: {str(e)}"

     print(f"LLM answer: {state['answer']}")
     return state

 # --- Define the edge condition function ---
 def get_next_node(state: Dict[str, Any]) -> str:
-    """Get the next node from the state."""
     return state["next"]

-
-# 2. Improved Wikipedia Retrieval Node
-def extract_keywords(question: str) -> List[str]:
-    doc = nlp(question)
-    keywords = [token.text for token in doc if token.pos_ in ("PROPN", "NOUN")]  # Extract proper nouns and nouns
-    return keywords
-
-def extract_entities(question: str) -> List[str]:
-    doc = nlp(question)
-    entities = [ent.text for ent in doc.ents]
-    return entities if entities else [token.text for token in doc if token.pos_ in ("PROPN", "NOUN")]
-
-
-def retrieve(state: State) -> dict:
-    keywords = extract_entities(state["question"])
-    query = " ".join(keywords)
-    search_results = wikipedia.search(query)
-    selected_page = search_results[0] if search_results else None
-
-    if selected_page:
-        loader = WikipediaLoader(
-            query=selected_page,
-            lang="en",
-            load_max_docs=1,
-            doc_content_chars_max=100000,
-            load_all_available_meta=True
-        )
-        docs = loader.load()
-        # Chunk the article for finer retrieval
-        from langchain.text_splitter import RecursiveCharacterTextSplitter
-        splitter = RecursiveCharacterTextSplitter(chunk_size=1500, chunk_overlap=200)
-        all_chunks = []
-        for doc in docs:
-            chunks = splitter.split_text(doc.page_content)
-            all_chunks.extend([Document(page_content=chunk) for chunk in chunks])
-        # Optionally: re-rank or filter chunks here
-        return {"context": all_chunks}
-    else:
-        return {"context": []}
-
-# 3. Prompt Template for General QA
-prompt = PromptTemplate(
-    input_variables=["question", "context"],
-    template=(
-        "You are an expert researcher. Given the following context from Wikipedia, answer the user's question as accurately as possible. "
-        "If the text appears to be scrambled, try to unscramble the text for the user"
-        "If the information is incomplete or ambiguous, provide your best estimate based on the available evidence, and clearly explain any assumptions or reasoning you use. "
-        "If the answer requires multiple steps or deeper analysis, break down the question into sub-questions and answer them step by step, citing the relevant context for each step."
-        "Context:\n{context}\n\n"
-        "Question: {question}\n\n"
-        "Best Estimate Answer:"
-    )
-)
-
-"""
-def generate(state: State) -> dict:
-    # Concatenate all context documents into a single string
-    docs_content = "\n\n".join(doc.page_content for doc in state["context"])
-    # Format the prompt for the LLM
-    prompt_str = prompt.format(question=state["question"], context=docs_content)
-    # Generate answer
-    response = llm.invoke(prompt_str)
-    return {"answer": response}
-"""
-
-def generate(state: dict) -> dict:
-    # Concatenate all context documents into a single string
-    docs_content = "\n\n".join(doc.page_content for doc in state["context"])
-    # Format the prompt for the LLM
-    prompt_str = prompt.format(question=state["question"], context=docs_content)
-    # Generate answer using Hugging Face pipeline
-    response = llm_pipe(prompt_str)
-    # Extract generated text
-    answer = response[0]["generated_text"]
-    return {"answer": answer}
-
 # Create the StateGraph
 graph = StateGraph(State)
@@ -568,7 +663,7 @@ graph.add_node("audio", node_audio_rag)
 graph.add_edge(START, "decide")
 graph.add_edge("retrieve", "generate")

-# Add conditional edges from decide to video or llm based on question
 graph.add_conditional_edges(
     "decide",
     get_next_node,
@@ -581,7 +676,7 @@ graph.add_conditional_edges(
     }
 )

-# Add edges from video and llm to END to terminate the graph
 graph.add_edge("video", END)
 graph.add_edge("llm", END)
 graph.add_edge("generate", END)
@@ -591,14 +686,33 @@ graph.add_edge("audio", END)
 # Compile the graph
 agent = graph.compile()

-# --- Usage Example ---
 def intelligent_agent(state: State) -> str:
     """Process a question using the appropriate pipeline based on content."""
-    #state = State(question= question)
     try:
         final_state = agent.invoke(state)
-        return final_state.get('answer', "No answer found.")
     except Exception as e:
         print(f"Error in agent execution: {str(e)}")
-        return f"An error occurred: {str(e)}"
-
(new version of the changed regions; added lines are marked with +)

 from transformers import BlipProcessor, BlipForQuestionAnswering, pipeline
 from io import BytesIO
 from sentence_transformers import SentenceTransformer
+from transformers import RagRetriever, RagTokenizer, RagSequenceForGeneration


 import os
 class State(TypedDict, total=False):
     question: str
     task_id: str
+    input_file: Optional[bytes]
+    file_type: Optional[str]
     context: List[Document]  # Using LangChain's Document class
     file_path: Optional[str]
     youtube_url: Optional[str]
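Because State is declared with total=False, every key is optional and callers can build a partial state with only the fields they have. A small illustrative sketch (the values are hypothetical, not from the commit):

# Hypothetical partial states; total=False means missing keys are allowed.
text_only: State = {"question": "What is 2 + 2?", "task_id": "demo-1"}
with_file: State = {
    "question": "What animal appears in this picture?",
    "task_id": "demo-2",
    "file_path": "/tmp/example.png",  # hypothetical path
    "file_type": "image",
}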
 
     next: Optional[str]  # Added to track the next node

 # --- LLM pipeline for general questions ---
+llm_pipe = pipeline(
+    "text-generation",
+    model="microsoft/Phi-3-mini-4k-instruct",
+    device_map=0,
+    torch_dtype="auto",
+    max_new_tokens=256
+)
+
+# Initialize RAG components
+tokenizer = RagTokenizer.from_pretrained("facebook/rag-token-base")
+retriever = RagRetriever.from_pretrained(
+    "facebook/rag-token-base",
+    index_name="exact",  # or "legacy" for legacy FAISS index
+    use_dummy_dataset=False,  # set to False and download the full index for real Wikipedia retrieval
+    trust_remote_code=True
+)
+rag_model = RagSequenceForGeneration.from_pretrained("facebook/rag-token-base", retriever=retriever)
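With use_dummy_dataset=False and index_name="exact" the retriever pulls down the full wiki_dpr index, which is very large; for a quick local check the dummy index is usually enough. A minimal smoke test of these RAG pieces under that assumption (the variable names and sample question are illustrative, not from the commit):

# Smoke test with the small dummy wiki_dpr index instead of the full download.
from transformers import RagRetriever, RagSequenceForGeneration, RagTokenizer

tok = RagTokenizer.from_pretrained("facebook/rag-token-base")
ret = RagRetriever.from_pretrained(
    "facebook/rag-token-base", index_name="exact", use_dummy_dataset=True
)
model = RagSequenceForGeneration.from_pretrained("facebook/rag-token-base", retriever=ret)

inputs = tok("who holds the record in 100m freestyle", return_tensors="pt")
generated = model.generate(input_ids=inputs["input_ids"])
print(tok.batch_decode(generated, skip_special_tokens=True)[0])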
 
 # Speech-to-text pipeline
 asr_pipe = pipeline(
     "automatic-speech-recognition",
     model="openai/whisper-small",
+    device=0
 )

+# --- BLIP VQA setup ---
+device = "cuda" if torch.cuda.is_available() else "cpu"
 vqa_model_name = "Salesforce/blip-vqa-base"
 processor_vqa = BlipProcessor.from_pretrained(vqa_model_name)
     device = "cpu"  # Switch device to CPU
     model_vqa = BlipForQuestionAnswering.from_pretrained(vqa_model_name).to(device)

+# --- Helper functions ---
+def ensure_final_answer_format(answer_text: str) -> str:
+    """Ensure the answer ends with FINAL ANSWER: format"""
+    # Check if the answer already contains a FINAL ANSWER section
+    if "FINAL ANSWER:" in answer_text:
+        # Extract everything after FINAL ANSWER:
+        final_answer_part = answer_text.split("FINAL ANSWER:", 1)[1].strip()
+        return f"FINAL ANSWER: {final_answer_part}"
+    else:
+        # If no FINAL ANSWER section exists, wrap the entire answer
+        return f"FINAL ANSWER: {answer_text.strip()}"
+
+def extract_entities(text: str) -> List[str]:
+    """Extract key entities from text using spaCy if available, or regex fallback"""
+    if nlp:
+        # Using spaCy for better entity extraction
+        doc = nlp(text)
+        entities = [ent.text for ent in doc.ents]
+        keywords = [token.text for token in doc if token.pos_ in ("PROPN", "NOUN")]
+        return entities if entities else keywords
+    else:
+        # Simple fallback using regex to extract potential keywords
+        words = text.lower().split()
+        stopwords = ["what", "who", "when", "where", "why", "how", "is", "are", "the", "a", "an", "of", "in", "on", "at"]
+        keywords = [word for word in words if word not in stopwords and len(word) > 2]
+        return keywords
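Every node funnels its output through ensure_final_answer_format; a couple of hypothetical inputs show both branches, plus the entity helper:

# Example strings are hypothetical.
print(ensure_final_answer_format("Reasoning... FINAL ANSWER: Paris"))
# FINAL ANSWER: Paris
print(ensure_final_answer_format("Paris"))
# FINAL ANSWER: Paris
print(extract_entities("Who painted the Mona Lisa?"))
# e.g. ['Mona Lisa'] with spaCy loaded, otherwise the stopword-filtered keyword list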
 
 
 def answer_question_on_frame(image_path, question):
+    """Answer a question about a single video frame using BLIP"""
+    try:
+        image = Image.open(image_path).convert('RGB')
+        inputs = processor_vqa(image, question, return_tensors="pt").to(device)
+        out = model_vqa.generate(**inputs)
+        answer = processor_vqa.decode(out[0], skip_special_tokens=True)
+        return answer
+    except Exception as e:
+        print(f"Error processing frame {image_path}: {str(e)}")
+        return "Error processing this frame"
+
 def answer_video_question(frames_dir, question):
+    """Answer a question about a video by analyzing extracted frames"""
     valid_exts = ('.jpg', '.jpeg', '.png')

     # Check if directory exists
         "answer_counts": counted
     }

+def download_youtube_video(url, output_dir='/tmp/video/', output_filename='downloaded_video.mp4'):
+    """Download a YouTube video using yt-dlp"""
     # Ensure the output directory exists
     os.makedirs(output_dir, exist_ok=True)

     # Set output path for yt-dlp
     output_path = os.path.join(output_dir, output_filename)

+    try:
+        ydl_opts = {
+            'format': 'bestvideo[ext=mp4]+bestaudio[ext=m4a]/best[ext=mp4]/best',
+            'outtmpl': output_path,
+            'quiet': True,
+            'merge_output_format': 'mp4',  # Ensures merged output is mp4
+            'postprocessors': [{
+                'key': 'FFmpegVideoConvertor',
+                'preferedformat': 'mp4',  # Recode if needed
+            }]
+        }
+        with yt_dlp.YoutubeDL(ydl_opts) as ydl:
+            ydl.download([url])
+        return output_path
+    except Exception as e:
+        print(f"Error downloading YouTube video: {str(e)}")
+        return None
 def extract_frames(video_path, output_dir, frame_interval_seconds=10):
+    """Extract frames from a video file at specified intervals"""
+    # Clean output directory before extracting new frames
     if os.path.exists(output_dir):
         for filename in os.listdir(output_dir):
             file_path = os.path.join(output_dir, filename)

         print(f"Exception during frame extraction: {e}")
         return False
+ def image_qa(image_path: str, question: str) -> str:
304
+ """Answer questions about an image using the BLIP model"""
305
+ try:
306
+ image = Image.open(image_path).convert('RGB')
307
+ inputs = processor_vqa(image, question, return_tensors="pt").to(device)
308
+ out = model_vqa.generate(**inputs)
309
+ answer = processor_vqa.decode(out[0], skip_special_tokens=True)
310
+ return answer
311
+ except Exception as e:
312
+ print(f"Error in image_qa: {str(e)}")
313
+ return f"Error processing image: {str(e)}"
 
 
 
 
 
 
 
 
 
 
314
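A one-line usage sketch of the BLIP-backed helper (the path and question are hypothetical):

# Returns a short BLIP answer such as "yellow", or an error string on failure.
print(image_qa("/tmp/example.png", "What color is the car?"))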
 
+# --- Node functions ---
 def router(state: Dict[str, Any]) -> str:
+    """Determine the next node based on question content and file type"""
     question = state.get('question', '')

     # Pattern for Wikipedia and similar sources
     wiki_pattern = r"(wikipedia\.org|wiki|encyclopedia|britannica\.com|encyclop[a|æ]dia)"
     has_wiki = re.search(wiki_pattern, question, re.IGNORECASE) is not None

     else:
         return "llm"

 def node_decide(state: Dict[str, Any]) -> Dict[str, Any]:
+    """Router node that decides which node to go to next"""
     print("Running node_decide")
+    # Initialize context list if not present
+    if 'context' not in state:
+        state['context'] = []
     # Add the next state to the state dict
     state["next"] = router(state)
     print(f"Routing to: {state['next']}")
     return state
+def node_image(state: Dict[str, Any]) -> Dict[str, Any]:
+    """Process image-based questions"""
+    print("Running node_image")
+    try:
+        # Make sure the image file exists
+        if not os.path.exists(state['file_path']):
+            state['answer'] = ensure_final_answer_format("Image file not found.")
+            return state
+
+        # Get answer from image QA model
+        answer = image_qa(state['file_path'], state['question'])
+
+        # Format the final answer
+        state['answer'] = ensure_final_answer_format(answer)
+
+        # Add document to state for traceability
+        image_doc = Document(
+            page_content=f"Image analysis result: {answer}",
+            metadata={"source": "image_analysis", "file_path": state['file_path']}
+        )
+        state['context'].append(image_doc)
+
+    except Exception as e:
+        error_msg = f"Error processing image: {str(e)}"
+        print(error_msg)
+        state['answer'] = ensure_final_answer_format(error_msg)
+
+    return state
+
 def node_video(state: Dict[str, Any]) -> Dict[str, Any]:
+    """Process video-based questions"""
     print("Running node_video")
     youtube_url = state.get('youtube_url')
     if not youtube_url:
+        state['answer'] = ensure_final_answer_format("No YouTube URL found in the question.")
         return state

     question = state['question']

     video_file = download_youtube_video(youtube_url)
     if not video_file or not os.path.exists(video_file):
+        state['answer'] = ensure_final_answer_format("Failed to download the video.")
         return state

     frames_dir = "/tmp/frames"

     success = extract_frames(video_path=video_file, output_dir=frames_dir, frame_interval_seconds=10)
     if not success:
+        state['answer'] = ensure_final_answer_format("Failed to extract frames from the video.")
         return state

     result = answer_video_question(frames_dir, question_text)
+    final_answer = result['most_common_answer']
     state['frame_answers'] = result['all_answers']

     # Create Document objects for each frame analysis

         )
         frame_documents.append(doc)

+    # Add documents to state
     state['context'].extend(frame_documents)
+    state['answer'] = ensure_final_answer_format(final_answer)

     print(f"Video answer: {state['answer']}")
     return state
 def node_audio_rag(state: Dict[str, Any]) -> Dict[str, Any]:
+    """Process audio-based questions"""
     print(f"Processing audio file: {state['file_path']}")

     try:
         audio_transcript = asr_result['text']
         print(f"Audio transcript: {audio_transcript}")

+        # Step 2: Store transcript in vector store
         transcript_doc = [Document(page_content=audio_transcript)]
         embeddings = HuggingFaceEmbeddings(model_name='BAAI/bge-large-en-v1.5')
         vector_db = FAISS.from_documents(transcript_doc, embedding=embeddings)

         # Step 3: Retrieve relevant docs for the user's question
         question = state['question']
+        similar_docs = vector_db.similarity_search(question, k=1)
         retrieved_context = "\n".join([doc.page_content for doc in similar_docs])

+        # Step 4: Generate answer
         prompt = (
+            f"You are an AI assistant that answers questions about audio content.\n\n"
+            f"Audio transcript: {retrieved_context}\n\n"
+            f"Question: {question}\n\n"
+            f"Based only on the provided audio transcript, answer the question. "
+            f"If the transcript does not contain relevant information, state that clearly.\n\n"
+            f"End your response with 'FINAL ANSWER: ' followed by a concise answer."
         )
+
         llm_response = llm_pipe(prompt)
+        answer_text = llm_response[0]['generated_text']
+
+        # Add documents to state
+        state['context'].extend(transcript_doc)
+        state['context'].append(Document(
+            page_content=prompt,
+            metadata={"source": "audio_analysis_prompt"}
+        ))
+
+        # Ensure final answer format
+        state['answer'] = ensure_final_answer_format(answer_text)

     except Exception as e:
         error_msg = f"Audio processing error: {str(e)}"
         print(error_msg)
+        state['answer'] = ensure_final_answer_format(error_msg)

     return state
 def node_llm(state: Dict[str, Any]) -> Dict[str, Any]:
+    """Process general knowledge questions with LLM"""
     print("Running node_llm")
     question = state['question']

     # Compose a detailed prompt
     prompt = (
+        "You are an AI assistant that answers questions using your general knowledge. "
+        "Follow these steps:\n\n"
+        "1. If the question appears to be scrambled or jumbled, first try to unscramble or reconstruct the intended meaning.\n"
+        "2. Analyze the question (unscrambled if needed) and use your own knowledge to answer it.\n"
+        "3. If the question can't be answered with certainty, provide your best estimate and clearly explain any assumptions.\n"
+        "4. Format your answer using these rules:\n"
+        "   - Numbers: Plain digits without commas/units (e.g. 1234567)\n"
+        "   - Strings: Minimal words, no articles/abbreviations\n"
+        "   - Lists: comma-separated values without extra formatting\n\n"
+        "5. Always conclude with:\n"
+        "FINAL ANSWER: [your answer] (replace bracketed text)\n\n"
+        f"Current question: {question}"
     )

     # Add document to state for traceability

         page_content=prompt,
         metadata={"source": "llm_prompt"}
     )
     state['context'].append(query_doc)

     try:
         result = llm_pipe(prompt)
+        answer_text = result[0]['generated_text']
+        state['answer'] = ensure_final_answer_format(answer_text)
     except Exception as e:
         print(f"Error in LLM processing: {str(e)}")
+        error_msg = f"An error occurred while processing your question: {str(e)}"
+        state['answer'] = ensure_final_answer_format(error_msg)

     print(f"LLM answer: {state['answer']}")
     return state
+def retrieve(state: State) -> State:
+    """Retrieve relevant documents using RAG"""
+    print("Running retrieve")
+    question = state["question"]
+
+    try:
+        # Tokenize the question
+        inputs = tokenizer(question, return_tensors="pt")
+
+        # Get doc_ids by using the retriever directly
+        question_hidden_states = rag_model.question_encoder(inputs["input_ids"])[0]
+        docs_dict = retriever(
+            inputs["input_ids"].numpy(),
+            question_hidden_states.detach().numpy(),
+            return_tensors="pt"
+        )
+
+        # Extract the retrieved passages
+        all_chunks = []
+
+        # Debug print to see what's in docs_dict
+        print(f"docs_dict keys: {docs_dict.keys()}")
+
+        # Check for different possible keys that might contain the documents
+        doc_text_key = None
+        for possible_key in ['retrieved_doc_text', 'doc_text', 'texts', 'documents']:
+            if possible_key in docs_dict:
+                doc_text_key = possible_key
+                break
+
+        if doc_text_key:
+            # Access the retrieved document texts from the docs_dict
+            for i in range(len(docs_dict["doc_ids"][0])):
+                doc_text = docs_dict[doc_text_key][0][i]
+                all_chunks.append(Document(page_content=doc_text))
+
+            print(f"Retrieved {len(all_chunks)} documents")
+        else:
+            # Fallback: Try to extract document text from doc_ids
+            doc_ids = docs_dict.get("doc_ids", [[]])[0]
+            print(f"Retrieved doc_ids: {doc_ids}")
+
+            # Create minimal document stubs from IDs
+            for doc_id in doc_ids:
+                stub_text = f"Information related to document ID: {doc_id}"
+                all_chunks.append(Document(page_content=stub_text))
+
+            print(f"Created {len(all_chunks)} document stubs from IDs")
+
+        # Add documents to state context
+        if not state.get('context'):
+            state['context'] = []
+        state['context'].extend(all_chunks)
+
+    except Exception as e:
+        print(f"Error in retrieval: {str(e)}")
+        # Create an error document
+        error_doc = Document(
+            page_content=f"Error during retrieval: {str(e)}",
+            metadata={"source": "retrieval_error"}
+        )
+        if not state.get('context'):
+            state['context'] = []
+        state['context'].append(error_doc)
+
+    return state
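A note on the key probing above: in recent transformers releases the retriever call returns a BatchEncoding whose fields are context_input_ids, context_attention_mask, retrieved_doc_embeds and doc_ids, so none of the probed text keys are present and the code falls through to the doc-id stub branch. A hedged sketch (an assumption, not part of this commit) of recovering passage text from the retriever's index instead:

# Assumes docs_dict and retriever as defined above; get_doc_dicts looks the
# passages up in the wiki_dpr dataset backing the index.
doc_ids = docs_dict["doc_ids"]                      # shape (batch, n_docs)
doc_dicts = retriever.index.get_doc_dicts(doc_ids)  # one dict per question
passages = [Document(page_content=text) for text in doc_dicts[0]["text"]]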
 
+def generate(state: State) -> State:
+    """Generate an answer based on retrieved documents"""
+    print("Running generate")
+
+    try:
+        # Check if context exists
+        if not state.get('context') or len(state['context']) == 0:
+            state['answer'] = ensure_final_answer_format("No relevant information found to answer your question.")
+            return state
+
+        # Concatenate all context documents into a single string
+        docs_content = "\n\n".join(doc.page_content for doc in state["context"])
+
+        # Format the prompt for the LLM
+        prompt_str = PromptTemplate(
+            input_variables=["question", "context"],
+            template=(
+                "You are an AI assistant that answers questions using retrieved context. "
+                "Follow these steps:\n\n"
+                "1. Analyze the provided context:\n{context}\n\n"
+                "2. If the context contains scrambled text, first attempt to reconstruct meaningful information\n"
+                "3. If the question can't be answered from context alone, combine context with general knowledge "
+                "but clearly state this limitation\n"
+                "4. Format your answer using these rules:\n"
+                "   - Numbers: Plain digits without commas/units (e.g. 1234567)\n"
+                "   - Strings: Minimal words, no articles/abbreviations\n"
+                "   - Lists: comma-separated values without extra formatting\n\n"
+                "5. Always conclude with:\n"
+                "FINAL ANSWER: [your answer] (replace bracketed text)\n\n"
+                "Current question: {question}"
+            )
+        ).format(question=state["question"], context=docs_content)
+
+        # Generate answer using the LLM pipeline
+        response = llm_pipe(prompt_str)
+        answer_text = response[0]["generated_text"]
+
+        # Ensure answer has the FINAL ANSWER format
+        state['answer'] = ensure_final_answer_format(answer_text)
+
+    except Exception as e:
+        print(f"Error in generate node: {str(e)}")
+        error_msg = f"Error generating answer: {str(e)}"
+        state['answer'] = ensure_final_answer_format(error_msg)
+
+    return state
 # --- Define the edge condition function ---
 def get_next_node(state: Dict[str, Any]) -> str:
+    """Get the next node from the state"""
     return state["next"]
 # Create the StateGraph
 graph = StateGraph(State)

 graph.add_edge(START, "decide")
 graph.add_edge("retrieve", "generate")

+# Add conditional edges from decide to other nodes based on question
 graph.add_conditional_edges(
     "decide",
     get_next_node,

     }
 )

+# Add edges from all terminal nodes to END
 graph.add_edge("video", END)
 graph.add_edge("llm", END)
 graph.add_edge("generate", END)

 # Compile the graph
 agent = graph.compile()
+# --- Intelligent Agent Function ---
 def intelligent_agent(state: State) -> str:
     """Process a question using the appropriate pipeline based on content."""
     try:
+        # Ensure state has proper structure
+        if not isinstance(state, dict):
+            return "FINAL ANSWER: Error - input must be a valid State dictionary"
+
+        # Make sure question exists
+        if 'question' not in state:
+            return "FINAL ANSWER: Error - question is required"
+
+        # Initialize context if not present
+        if 'context' not in state:
+            state['context'] = []
+
+        print(f"Processing question: {state['question']}")
+
+        # Invoke the agent with the state
         final_state = agent.invoke(state)
+
+        # Ensure answer has FINAL ANSWER format
+        answer = final_state.get('answer', "No answer found.")
+        formatted_answer = ensure_final_answer_format(answer)
+
+        return formatted_answer
+
     except Exception as e:
         print(f"Error in agent execution: {str(e)}")
+        return f"FINAL ANSWER: An error occurred - {str(e)}"
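End to end, app.py builds one state per task and passes it here. A hypothetical invocation exercising the plain-LLM path (the question text is made up; actual model output will vary):

demo_state: State = {"question": "What is the capital of France?", "task_id": "demo-3"}
print(intelligent_agent(demo_state))
# FINAL ANSWER: Paris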