Samuel Thomas committed on
Commit 38df4e4 · 1 Parent(s): 6f21ce8

reading from api

Files changed (2)
  1. app.py +25 -5
  2. tools.py +601 -0
app.py CHANGED
@@ -3,6 +3,7 @@ import gradio as gr
 import requests
 import inspect
 import pandas as pd
+from tools import intelligent_agent, get_file_type, write_bytes_to_temp_dir
 
 # (Keep Constants as is)
 # --- Constants ---
@@ -53,11 +54,11 @@ def run_and_submit_all( profile: gr.OAuthProfile | None):
     try:
         response = requests.get(questions_url, timeout=15)
         response.raise_for_status()
-        questions_data = response.json()
-        if not questions_data:
+        hf_questions = response.json()
+        if not hf_questions:
             print("Fetched questions list is empty.")
             return "Fetched questions list is empty or invalid format.", None
-        print(f"Fetched {len(questions_data)} questions.")
+        print(f"Fetched {len(hf_questions)} questions.")
     except requests.exceptions.RequestException as e:
         print(f"Error fetching questions: {e}")
         return f"Error fetching questions: {e}", None
@@ -68,7 +69,26 @@ def run_and_submit_all( profile: gr.OAuthProfile | None):
     except Exception as e:
         print(f"An unexpected error occurred fetching questions: {e}")
         return f"An unexpected error occurred fetching questions: {e}", None
-
+
+    # 3. Create states
+    for item in hf_questions:
+        file_name = item.get('file_name', '')
+        if file_name == '':
+            item['input_file'] = None
+            item['file_type'] = None
+            item['file_path'] = None
+        else:
+            # Call the API to retrieve the file; adjust params as needed
+            task_id = item['task_id']
+            api_response = requests.get(f"{api_url}/{task_id}")
+            if api_response.status_code == 200:
+                item['input_file'] = api_response.content  # Store file as bytes
+                item['file_type'] = get_file_type(file_name)
+                item['file_path'] = write_bytes_to_temp_dir(item['input_file'], file_name)
+            else:
+                item['input_file'] = None  # Or handle error as needed
+
+    """
     # 3. Run your Agent
     results_log = []
     answers_payload = []
@@ -138,7 +158,7 @@ def run_and_submit_all( profile: gr.OAuthProfile | None):
     print(status_message)
     results_df = pd.DataFrame(results_log)
     return status_message, results_df
-
+    """
 
 # --- Build Gradio Interface using Blocks ---
 with gr.Blocks() as demo:
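With the original agent loop commented out, a minimal sketch of the likely follow-up wiring is shown below for context; it is not part of this commit. It assumes each item carries a 'question' field from the API in addition to the keys set in the loop above, that intelligent_agent from tools.py accepts the resulting dict as its State, and that 'submitted_answer' mirrors the payload key used by the commented-out submission code.

# Hypothetical follow-up (illustrative only): run the agent over the enriched question dicts.
def run_agent_over_questions(hf_questions):
    answers_payload = []
    for item in hf_questions:
        state = {
            "question": item.get("question", ""),   # question text, assumed present in the API payload
            "task_id": item.get("task_id", ""),
            "input_file": item.get("input_file"),
            "file_type": item.get("file_type"),
            "file_path": item.get("file_path"),
        }
        answer = intelligent_agent(state)            # routed to image/audio/video/wiki/llm in tools.py
        answers_payload.append({"task_id": state["task_id"], "submitted_answer": answer})
    return answers_payload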
tools.py CHANGED
@@ -0,0 +1,601 @@
+import numpy as np
+import spacy
+import tempfile
+import glob
+import yt_dlp
+import shutil
+import cv2
+import librosa
+import wikipedia
+
+from typing import TypedDict, List, Optional, Dict, Any
+from langchain.docstore.document import Document
+from langchain.prompts import PromptTemplate
+from langchain_community.document_loaders import WikipediaLoader
+from langgraph.graph import START, END, StateGraph
+from langchain_core.messages import AnyMessage, HumanMessage, AIMessage  # If you are using it
+from langchain_community.retrievers import BM25Retriever  # If you are using it
+from langgraph.prebuilt import ToolNode, tools_condition  # If you are using it
+from langchain.vectorstores import FAISS
+from langchain.embeddings import HuggingFaceEmbeddings
+from langchain.schema import Document
+from transformers import BlipProcessor, BlipForQuestionAnswering, pipeline
+from io import BytesIO
+from sentence_transformers import SentenceTransformer
+
+nlp = spacy.load("en_core_web_sm")
+
+# Define file extension sets for each category
+PICTURE_EXTENSIONS = {'.jpg', '.jpeg', '.png', '.gif', '.bmp', '.tiff', '.webp'}
+AUDIO_EXTENSIONS = {'.mp3', '.wav', '.aac', '.flac', '.ogg', '.m4a', '.wma'}
+CODE_EXTENSIONS = {'.py', '.js', '.java', '.cpp', '.c', '.cs', '.rb', '.go', '.php', '.html', '.css', '.ts'}
+SPREADSHEET_EXTENSIONS = {
+    '.xls', '.xlsx', '.xlsm', '.xlsb', '.xlt', '.xltx', '.xltm',
+    '.ods', '.ots', '.csv', '.tsv', '.sxc', '.stc', '.dif', '.gsheet',
+    '.numbers', '.numbers-tef', '.nmbtemplate', '.fods', '.123', '.wk1', '.wk2',
+    '.wks', '.wku', '.wr1', '.gnumeric', '.gnm', '.xml', '.pmvx', '.pmdx',
+    '.pmv', '.uos', '.txt'
+}
+
+def get_file_type(filename: str) -> str:
+    if not filename or '.' not in filename or filename == '':
+        return ''
+    ext = filename.lower().rsplit('.', 1)[-1]
+    dot_ext = f'.{ext}'
+    if dot_ext in PICTURE_EXTENSIONS:
+        return 'picture'
+    elif dot_ext in AUDIO_EXTENSIONS:
+        return 'audio'
+    elif dot_ext in CODE_EXTENSIONS:
+        return 'code'
+    elif dot_ext in SPREADSHEET_EXTENSIONS:
+        return 'spreadsheet'
+    else:
+        return 'unknown'
+
+def write_bytes_to_temp_dir(file_bytes: bytes, file_name: str) -> str:
+    """
+    Writes bytes to a file in the system temporary directory using the provided file_name.
+    Returns the full path to the saved file.
+    The file will persist until manually deleted or the OS cleans the temp directory.
+    """
+    temp_dir = tempfile.gettempdir()
+    file_path = os.path.join(temp_dir, file_name)
+    with open(file_path, 'wb') as f:
+        f.write(file_bytes)
+    print(f"File written to: {file_path}")
+    return file_path
+
+import os
+import re
+from PIL import Image  # This is correctly imported, but was being used incorrectly
+import numpy as np
+from collections import Counter
+import torch
+from transformers import BlipProcessor, BlipForQuestionAnswering, pipeline
+from typing import TypedDict, List, Optional, Dict, Any, Literal, Tuple
+from langgraph.graph import StateGraph, START, END
+from langchain.docstore.document import Document
+
+# 1. Define the State type
+class State(TypedDict, total=False):
+    question: str
+    task_id: str
+    input_file: bytes
+    file_type: str
+    context: List[Document]  # Using LangChain's Document class
+    file_path: Optional[str]
+    youtube_url: Optional[str]
+    answer: Optional[str]
+    frame_answers: Optional[list]
+    next: Optional[str]  # Added to track the next node
+
+# --- LLM pipeline for general questions ---
+llm_pipe = pipeline("text-generation",
+                    #model="meta-llama/Llama-3.3-70B-Instruct",
+                    #model="meta-llama/Meta-Llama-3-8B-Instruct",
+                    #model="Qwen/Qwen2-7B-Instruct",
+                    #model="microsoft/Phi-4-reasoning",
+                    model="microsoft/Phi-3-mini-4k-instruct",
+                    device_map="auto",
+                    #device_map={ "": 0 },  # "" means the whole model
+                    #max_memory={0: "10GiB"},
+                    torch_dtype="auto",
+                    max_new_tokens=256)
+
+# Speech-to-text pipeline
+asr_pipe = pipeline(
+    "automatic-speech-recognition",
+    model="openai/whisper-small",
+    device=-1
+    #device_map={"", 0},
+    #max_memory = {0: "4.5GiB"},
+    #device_map="auto"
+)
+
+# --- Your BLIP VQA setup ---
+#device = "cuda" if torch.cuda.is_available() else "cpu"
+device = "cpu"
+vqa_model_name = "Salesforce/blip-vqa-base"
+processor_vqa = BlipProcessor.from_pretrained(vqa_model_name)
+
+# Attempt to load model to GPU; fall back to CPU if OOM
+try:
+    model_vqa = BlipForQuestionAnswering.from_pretrained(vqa_model_name).to(device)
+except torch.cuda.OutOfMemoryError:
+    print("WARNING: Loading model to CPU due to insufficient GPU memory.")
+    device = "cpu"  # Switch device to CPU
+    model_vqa = BlipForQuestionAnswering.from_pretrained(vqa_model_name).to(device)
+
+
+# --- Helper: Answer question on a single frame ---
+def answer_question_on_frame(image_path, question):
+    # Fixed: Properly use the PIL Image module
+    image = Image.open(image_path).convert('RGB')
+    inputs = processor_vqa(image, question, return_tensors="pt").to(device)
+    out = model_vqa.generate(**inputs)
+    answer = processor_vqa.decode(out[0], skip_special_tokens=True)
+    return answer
+
+# --- Helper: Answer question about the whole video ---
+def answer_video_question(frames_dir, question):
+    valid_exts = ('.jpg', '.jpeg', '.png')
+
+    # Check if directory exists
+    if not os.path.exists(frames_dir):
+        return {
+            "most_common_answer": "No frames found to analyze.",
+            "all_answers": [],
+            "answer_counts": Counter()
+        }
+
+    frame_files = [os.path.join(frames_dir, f) for f in os.listdir(frames_dir)
+                   if f.lower().endswith(valid_exts)]
+
+    # Sort frames properly by number
+    def get_frame_number(filename):
+        match = re.search(r'(\d+)', os.path.basename(filename))
+        return int(match.group(1)) if match else 0
+
+    frame_files = sorted(frame_files, key=get_frame_number)
+
+    if not frame_files:
+        return {
+            "most_common_answer": "No valid image frames found.",
+            "all_answers": [],
+            "answer_counts": Counter()
+        }
+
+    answers = []
+    for frame_path in frame_files:
+        try:
+            ans = answer_question_on_frame(frame_path, question)
+            answers.append(ans)
+            print(f"Processed frame: {os.path.basename(frame_path)}, Answer: {ans}")
+        except Exception as e:
+            print(f"Error processing frame {frame_path}: {str(e)}")
+
+    if not answers:
+        return {
+            "most_common_answer": "Could not analyze any frames successfully.",
+            "all_answers": [],
+            "answer_counts": Counter()
+        }
+
+    counted = Counter(answers)
+    most_common_answer, freq = counted.most_common(1)[0]
+    return {
+        "most_common_answer": most_common_answer,
+        "all_answers": answers,
+        "answer_counts": counted
+    }
+
+
+def download_youtube_video(url, output_dir='/content/video/', output_filename='downloaded_video.mp4'):
+    # Ensure the output directory exists
+    os.makedirs(output_dir, exist_ok=True)
+
+    # Delete all files in the output directory
+    files = glob.glob(os.path.join(output_dir, '*'))
+    for f in files:
+        try:
+            os.remove(f)
+        except Exception as e:
+            print(f"Error deleting {f}: {str(e)}")
+
+    # Set output path for yt-dlp
+    output_path = os.path.join(output_dir, output_filename)
+
+    ydl_opts = {
+        'format': 'bestvideo[ext=mp4]+bestaudio[ext=m4a]/best[ext=mp4]/best',
+        'outtmpl': output_path,
+        'quiet': True,
+        'merge_output_format': 'mp4',  # Ensures merged output is mp4
+        'postprocessors': [{
+            'key': 'FFmpegVideoConvertor',
+            'preferedformat': 'mp4',  # Recode if needed
+        }]
+    }
+    with yt_dlp.YoutubeDL(ydl_opts) as ydl:
+        ydl.download([url])
+    return output_path
+
+
+
+# --- Helper: Extract frames from video ---
+def extract_frames(video_path, output_dir, frame_interval_seconds=10):
+    # --- Clean output directory before extracting new frames ---
+    if os.path.exists(output_dir):
+        for filename in os.listdir(output_dir):
+            file_path = os.path.join(output_dir, filename)
+            try:
+                if os.path.isfile(file_path) or os.path.islink(file_path):
+                    os.unlink(file_path)
+                elif os.path.isdir(file_path):
+                    shutil.rmtree(file_path)
+            except Exception as e:
+                print(f'Failed to delete {file_path}. Reason: {e}')
+    else:
+        os.makedirs(output_dir, exist_ok=True)
+
+    try:
+        cap = cv2.VideoCapture(video_path)
+        if not cap.isOpened():
+            print("Error: Could not open video.")
+            return False
+        fps = cap.get(cv2.CAP_PROP_FPS)
+        frame_interval = int(fps * frame_interval_seconds)
+        count = 0
+        saved = 0
+        while True:
+            ret, frame = cap.read()
+            if not ret:
+                break
+            if count % frame_interval == 0:
+                frame_filename = os.path.join(output_dir, f"frame_{count:06d}.jpg")
+                cv2.imwrite(frame_filename, frame)
+                saved += 1
+            count += 1
+        cap.release()
+        print(f"Extracted {saved} frames.")
+        return saved > 0
+    except Exception as e:
+        print(f"Exception during frame extraction: {e}")
+        return False
+
+def image_qa(image_path: str, question: str, model_name: str = vqa_model_name) -> str:
+    """
+    Answers questions about images using Hugging Face's VQA pipeline.
+
+    Args:
+        image_path: Path to local image file or URL
+        question: Natural language question about the image
+        model_name: Pretrained VQA model (default: good general-purpose model)
+
+    Returns:
+        str: The model's best answer
+    """
+    # Create VQA pipeline with specified model
+    vqa_pipeline = pipeline("visual-question-answering", model=model_name)
+
+    # Get predictions (automatically handles local files/URLs)
+    results = vqa_pipeline(image=image_path, question=question, top_k=1)
+
+    # Return top answer
+    return results[0]['answer']
+
+
+def router(state: Dict[str, Any]) -> str:
+    """Determine the next node based on whether the question contains a YouTube URL or references Wikipedia."""
+    question = state.get('question', '')
+
+    # Pattern for Wikipedia and similar sources
+    wiki_pattern = r"(wikipedia\.org|wiki|encyclopedia|britannica\.com|encyclop[aæ]dia)"
+    has_wiki = re.search(wiki_pattern, question, re.IGNORECASE) is not None
+
+    # Pattern for YouTube
+    yt_pattern = r"(https?://)?(www\.)?(youtube\.com|youtu\.be)/[^\s]+"
+    has_youtube = re.search(yt_pattern, question) is not None
+
+    # Check for image
+    has_image = state.get('file_type') == 'picture'
+
+    # Check for audio
+    has_audio = state.get('file_type') == 'audio'
+
+    print(f"Has Wikipedia reference: {has_wiki}")
+    print(f"Has YouTube link: {has_youtube}")
+    print(f"Has picture file: {has_image}")
+    print(f"Has audio file: {has_audio}")
+
+    if has_wiki:
+        return "retrieve"
+    elif has_youtube:
+        # Store the extracted YouTube URL in the state
+        url_match = re.search(r"(https?://[^\s]+)", question)
+        if url_match:
+            state['youtube_url'] = url_match.group(0)
+        return "video"
+    elif has_image:
+        return "image"
+    elif has_audio:
+        return "audio"
+    else:
+        return "llm"
+
+
+# --- Node Implementation ---
+def node_image(state: Dict[str, Any]) -> Dict[str, Any]:
+    """Image node: answers the question about the attached picture via VQA."""
+    print("Running node_image")
+    state['answer'] = image_qa(state['file_path'], state['question'])
+    return state
+
+
+def node_decide(state: Dict[str, Any]) -> Dict[str, Any]:
+    """Router node that decides which node to go to next."""
+    print("Running node_decide")
+    # Add the next node to the state dict
+    state["next"] = router(state)
+    print(f"Routing to: {state['next']}")
+    return state
+
+def node_video(state: Dict[str, Any]) -> Dict[str, Any]:
+    print("Running node_video")
+    youtube_url = state.get('youtube_url')
+    if not youtube_url:
+        state['answer'] = "No YouTube URL found in the question."
+        return state
+
+    question = state['question']
+    # Extract the actual question part (remove the URL)
+    question_text = re.sub(r'https?://[^\s]+', '', question).strip()
+    if not question_text.endswith('?'):
+        question_text += '?'
+
+    video_file = download_youtube_video(youtube_url)
+    if not video_file or not os.path.exists(video_file):
+        state['answer'] = "Failed to download the video."
+        return state
+
+    frames_dir = "/tmp/frames"
+    os.makedirs(frames_dir, exist_ok=True)
+
+    success = extract_frames(video_path=video_file, output_dir=frames_dir, frame_interval_seconds=10)
+    if not success:
+        state['answer'] = "Failed to extract frames from the video."
+        return state
+
+    result = answer_video_question(frames_dir, question_text)
+    state['answer'] = result['most_common_answer']
+    state['frame_answers'] = result['all_answers']
+
+    # Create Document objects for each frame analysis
+    frame_documents = []
+    for i, ans in enumerate(result['all_answers']):
+        doc = Document(
+            page_content=f"Frame {i}: {ans}",
+            metadata={"frame_number": i, "source": "video_analysis"}
+        )
+        frame_documents.append(doc)
+
+    # Add documents to state if not already present
+    if 'context' not in state:
+        state['context'] = []
+    state['context'].extend(frame_documents)
+
+    print(f"Video answer: {state['answer']}")
+    return state
+
+def node_audio_rag(state: Dict[str, Any]) -> Dict[str, Any]:
+    print(f"Processing audio file: {state['file_path']}")
+
+    try:
+        # Step 1: Transcribe audio
+        audio, sr = librosa.load(state['file_path'], sr=16000)
+        asr_result = asr_pipe({"raw": audio, "sampling_rate": sr})
+        audio_transcript = asr_result['text']
+        print(f"Audio transcript: {audio_transcript}")
+
+        # Step 2: Store ONLY the transcript in the vector store
+        transcript_doc = [Document(page_content=audio_transcript)]
+        embeddings = HuggingFaceEmbeddings(model_name='BAAI/bge-large-en-v1.5')
+        vector_db = FAISS.from_documents(transcript_doc, embedding=embeddings)
+
+        # Step 3: Retrieve relevant docs for the user's question
+        question = state['question']
+        similar_docs = vector_db.similarity_search(question, k=1)  # Only one doc in store
+        retrieved_context = "\n".join([doc.page_content for doc in similar_docs])
+
+        # Step 4: Augment prompt and generate answer
+        prompt = (
+            f"Use the following context to answer the question.\n"
+            f"Context:\n{retrieved_context}\n\n"
+            f"Question: {question}\nAnswer:"
+        )
+        llm_response = llm_pipe(prompt)
+        state['answer'] = llm_response[0]['generated_text']
+
+    except Exception as e:
+        error_msg = f"Audio processing error: {str(e)}"
+        print(error_msg)
+        state['answer'] = error_msg
+
+    return state
+
+def node_llm(state: Dict[str, Any]) -> Dict[str, Any]:
+    print("Running node_llm")
+    question = state['question']
+
+    # Optionally add context from state (e.g., Wikipedia/Wikidata content)
+    context_text = ""
+    if 'article_content' in state and state['article_content']:
+        context_text = f"\n\nBackground Information:\n{state['article_content']}\n"
+    elif 'context' in state and state['context']:
+        context_text = "\n\n".join([doc.page_content for doc in state['context']])
+
+    # Compose a detailed prompt
+    prompt = (
+        "You are an expert researcher. Answer the user's question as accurately as possible. "
+        "If the text appears to be scrambled, try to unscramble the text for the user. "
+        "If the information is incomplete or ambiguous, provide your best estimate based on the available evidence, and clearly explain any assumptions or reasoning you use. "
+        "If the answer requires multiple steps or deeper analysis, break down the question into sub-questions and answer them step by step, citing the relevant context for each step.\n\n"
+        f"Question: {question}"
+        f"{context_text}\n"
+        "Answer:"
+    )
+
+    # Add document to state for traceability
+    query_doc = Document(
+        page_content=prompt,
+        metadata={"source": "llm_prompt"}
+    )
+    if 'context' not in state:
+        state['context'] = []
+    state['context'].append(query_doc)
+
+    try:
+        result = llm_pipe(prompt)
+        state['answer'] = result[0]['generated_text']
+    except Exception as e:
+        print(f"Error in LLM processing: {str(e)}")
+        state['answer'] = f"An error occurred while processing your question: {str(e)}"
+
+    print(f"LLM answer: {state['answer']}")
+    return state
+
+
+# --- Define the edge condition function ---
+def get_next_node(state: Dict[str, Any]) -> str:
+    """Get the next node from the state."""
+    return state["next"]
+
+
+# 2. Improved Wikipedia Retrieval Node
+def extract_keywords(question: str) -> List[str]:
+    doc = nlp(question)
+    keywords = [token.text for token in doc if token.pos_ in ("PROPN", "NOUN")]  # Extract proper nouns and nouns
+    return keywords
+
+def extract_entities(question: str) -> List[str]:
+    doc = nlp(question)
+    entities = [ent.text for ent in doc.ents]
+    return entities if entities else [token.text for token in doc if token.pos_ in ("PROPN", "NOUN")]
+
+
+def retrieve(state: State) -> dict:
+    keywords = extract_entities(state["question"])
+    query = " ".join(keywords)
+    search_results = wikipedia.search(query)
+    selected_page = search_results[0] if search_results else None
+
+    if selected_page:
+        loader = WikipediaLoader(
+            query=selected_page,
+            lang="en",
+            load_max_docs=1,
+            doc_content_chars_max=100000,
+            load_all_available_meta=True
+        )
+        docs = loader.load()
+        # Chunk the article for finer retrieval
+        from langchain.text_splitter import RecursiveCharacterTextSplitter
+        splitter = RecursiveCharacterTextSplitter(chunk_size=1500, chunk_overlap=200)
+        all_chunks = []
+        for doc in docs:
+            chunks = splitter.split_text(doc.page_content)
+            all_chunks.extend([Document(page_content=chunk) for chunk in chunks])
+        # Optionally: re-rank or filter chunks here
+        return {"context": all_chunks}
+    else:
+        return {"context": []}
+
+# 3. Prompt Template for General QA
+prompt = PromptTemplate(
+    input_variables=["question", "context"],
+    template=(
+        "You are an expert researcher. Given the following context from Wikipedia, answer the user's question as accurately as possible. "
+        "If the text appears to be scrambled, try to unscramble the text for the user. "
+        "If the information is incomplete or ambiguous, provide your best estimate based on the available evidence, and clearly explain any assumptions or reasoning you use. "
+        "If the answer requires multiple steps or deeper analysis, break down the question into sub-questions and answer them step by step, citing the relevant context for each step.\n\n"
+        "Context:\n{context}\n\n"
+        "Question: {question}\n\n"
+        "Best Estimate Answer:"
+    )
+)
+
+"""
+def generate(state: State) -> dict:
+    # Concatenate all context documents into a single string
+    docs_content = "\n\n".join(doc.page_content for doc in state["context"])
+    # Format the prompt for the LLM
+    prompt_str = prompt.format(question=state["question"], context=docs_content)
+    # Generate answer
+    response = llm.invoke(prompt_str)
+    return {"answer": response}
+"""
+
+def generate(state: dict) -> dict:
+    # Concatenate all context documents into a single string
+    docs_content = "\n\n".join(doc.page_content for doc in state["context"])
+    # Format the prompt for the LLM
+    prompt_str = prompt.format(question=state["question"], context=docs_content)
+    # Generate answer using Hugging Face pipeline
+    response = llm_pipe(prompt_str)
+    # Extract generated text
+    answer = response[0]["generated_text"]
+    return {"answer": answer}
+
+# Create the StateGraph
+graph = StateGraph(State)
+
+# Add nodes
+graph.add_node("decide", node_decide)
+graph.add_node("video", node_video)
+graph.add_node("llm", node_llm)
+graph.add_node("retrieve", retrieve)
+graph.add_node("generate", generate)
+graph.add_node("image", node_image)
+graph.add_node("audio", node_audio_rag)
+
+# Add edge from START to decide
+graph.add_edge(START, "decide")
+graph.add_edge("retrieve", "generate")
+
+# Add conditional edges from decide to video or llm based on question
+graph.add_conditional_edges(
+    "decide",
+    get_next_node,
+    {
+        "video": "video",
+        "llm": "llm",
+        "retrieve": "retrieve",
+        "image": "image",
+        "audio": "audio"
+    }
+)
+
+# Add edges from video and llm to END to terminate the graph
+graph.add_edge("video", END)
+graph.add_edge("llm", END)
+graph.add_edge("generate", END)
+graph.add_edge("image", END)
+graph.add_edge("audio", END)
+
+# Compile the graph
+agent = graph.compile()
+
+# --- Usage Example ---
+def intelligent_agent(state: State) -> str:
+    """Process a question using the appropriate pipeline based on content."""
+    #state = State(question= question)
+    try:
+        final_state = agent.invoke(state)
+        return final_state.get('answer', "No answer found.")
+    except Exception as e:
+        print(f"Error in agent execution: {str(e)}")
+        return f"An error occurred: {str(e)}"
+
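A minimal smoke-test sketch (illustrative, not part of the committed file): it only sets keys declared in State above, the question strings are made up, and the audio path is a placeholder, so that branch simply exercises the error handling in node_audio_rag.

if __name__ == "__main__":
    # Plain text question: router() finds no URL, wiki hint, or file, so it goes to the "llm" node.
    text_state: State = {"question": "Who wrote The Old Man and the Sea?", "task_id": "demo-1"}
    print(intelligent_agent(text_state))

    # Audio question: file_type "audio" routes to node_audio_rag; the placeholder path does not
    # exist, so the node's exception handling returns an error message instead of an answer.
    audio_state: State = {
        "question": "What is discussed in this recording?",
        "task_id": "demo-2",
        "file_type": "audio",
        "file_path": "/tmp/example.mp3",  # placeholder, not a real file
    }
    print(intelligent_agent(audio_state))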