import os
import re
import glob
import shutil
import tempfile
import numpy as np
import torch
import cv2
import librosa
import spacy
import wikipedia
import yt_dlp
from collections import Counter
from io import BytesIO
from typing import TypedDict, List, Optional, Dict, Any, Literal, Tuple
from PIL import Image
from langchain.docstore.document import Document
from langchain.prompts import PromptTemplate
from langchain.vectorstores import FAISS
from langchain.embeddings import HuggingFaceEmbeddings
from langchain_community.document_loaders import WikipediaLoader
from langchain_community.retrievers import BM25Retriever  # only needed if BM25 retrieval is used
from langchain_core.messages import AnyMessage, HumanMessage, AIMessage  # only needed for message-based graphs
from langgraph.graph import StateGraph, START, END
from langgraph.prebuilt import ToolNode, tools_condition  # only needed for tool-calling graphs
from sentence_transformers import SentenceTransformer
from transformers import (
    AutoTokenizer,
    BlipProcessor,
    BlipForQuestionAnswering,
    RagRetriever,
    RagTokenizer,
    RagSequenceForGeneration,
    pipeline,
)
# Load the spaCy model if it is available; extract_entities() falls back to a
# regex-based approach when it is not.
try:
    nlp = spacy.load("en_core_web_sm")
except OSError:
    print("WARNING: spaCy model 'en_core_web_sm' not found; using regex fallback for entity extraction.")
    nlp = None
# Define file extension sets for each category
PICTURE_EXTENSIONS = {'.jpg', '.jpeg', '.png', '.gif', '.bmp', '.tiff', '.webp'}
AUDIO_EXTENSIONS = {'.mp3', '.wav', '.aac', '.flac', '.ogg', '.m4a', '.wma'}
CODE_EXTENSIONS = {'.py', '.js', '.java', '.cpp', '.c', '.cs', '.rb', '.go', '.php', '.html', '.css', '.ts'}
SPREADSHEET_EXTENSIONS = {
    '.xls', '.xlsx', '.xlsm', '.xlsb', '.xlt', '.xltx', '.xltm',
    '.ods', '.ots', '.csv', '.tsv', '.sxc', '.stc', '.dif', '.gsheet',
    '.numbers', '.numbers-tef', '.nmbtemplate', '.fods', '.123', '.wk1', '.wk2',
    '.wks', '.wku', '.wr1', '.gnumeric', '.gnm', '.xml', '.pmvx', '.pmdx',
    '.pmv', '.uos', '.txt'
}
def get_file_type(filename: str) -> str:
    """Map a file name to a coarse category based on its extension."""
    if not filename or '.' not in filename:
        return ''
    ext = filename.lower().rsplit('.', 1)[-1]
    dot_ext = f'.{ext}'
    if dot_ext in PICTURE_EXTENSIONS:
        return 'picture'
    elif dot_ext in AUDIO_EXTENSIONS:
        return 'audio'
    elif dot_ext in CODE_EXTENSIONS:
        return 'code'
    elif dot_ext in SPREADSHEET_EXTENSIONS:
        return 'spreadsheet'
    else:
        return 'unknown'
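
# Illustrative examples (hypothetical file names) of the mapping above:
#   get_file_type("photo.JPG")   -> 'picture'
#   get_file_type("speech.mp3")  -> 'audio'
#   get_file_type("data.csv")    -> 'spreadsheet'
#   get_file_type("notes")       -> ''         (no extension)
#   get_file_type("archive.zip") -> 'unknown'  (extension not in any set)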
def write_bytes_to_temp_dir(file_bytes: bytes, file_name: str) -> str:
    """
    Writes bytes to a file in the system temporary directory using the provided file_name.
    Returns the full path to the saved file.
    The file will persist until manually deleted or the OS cleans the temp directory.
    """
    temp_dir = "/tmp"  # /tmp is always writable in Hugging Face Spaces
    os.makedirs(temp_dir, exist_ok=True)
    file_path = os.path.join(temp_dir, file_name)
    with open(file_path, 'wb') as f:
        f.write(file_bytes)
    print(f"File written to: {file_path}")
    return file_path
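
# Example usage (hypothetical payload and file name):
#   saved_path = write_bytes_to_temp_dir(b"col1,col2\n1,2\n", "sample.csv")
#   # -> "/tmp/sample.csv", ready to be passed along via state['file_path']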
# 1. Define the State type
class State(TypedDict, total=False):
    question: str
    task_id: str
    input_file: Optional[bytes]
    file_type: Optional[str]
    context: List[Document]  # Using LangChain's Document class
    file_path: Optional[str]
    youtube_url: Optional[str]
    answer: Optional[str]
    frame_answers: Optional[list]
    next: Optional[str]  # Tracks the next node chosen by the router
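
# A minimal example of the state an incoming request might carry
# (the question text and task_id below are hypothetical):
#   example_state: State = {
#       "question": "How many studio albums did the band release?",
#       "task_id": "demo-001",
#       "input_file": None,
#       "file_type": None,
#       "context": [],
#   }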
# --- LLM pipeline for general questions ---
llm_pipe = pipeline(
    "text-generation",
    model="microsoft/Phi-3-mini-4k-instruct",
    device_map="auto",
    torch_dtype="auto",
    max_new_tokens=256,
    trust_remote_code=True
)
# Initialize RAG components
tokenizer = RagTokenizer.from_pretrained("facebook/rag-token-base")
retriever = RagRetriever.from_pretrained(
    "facebook/rag-token-base",
    index_name="exact",        # or "legacy" for the legacy FAISS index
    use_dummy_dataset=False,   # False downloads the full wiki_dpr index for real Wikipedia retrieval
    trust_remote_code=True,    # Trust remote code for dataset loading
    dataset_revision="main",   # Pin the dataset revision
    dataset="wiki_dpr",        # Explicitly specify the dataset name
)
rag_model = RagSequenceForGeneration.from_pretrained(
    "facebook/rag-token-base",
    retriever=retriever,
    trust_remote_code=True
)
# Speech-to-text pipeline
asr_pipe = pipeline(
    "automatic-speech-recognition",
    model="openai/whisper-small",
    device=0 if torch.cuda.is_available() else -1  # pipeline() expects a device index, not "auto"
)
# --- BLIP VQA setup ---
device = "cuda" if torch.cuda.is_available() else "cpu"
vqa_model_name = "Salesforce/blip-vqa-base"
processor_vqa = BlipProcessor.from_pretrained(vqa_model_name)
# Attempt to load model to GPU; fall back to CPU if OOM
try:
    model_vqa = BlipForQuestionAnswering.from_pretrained(vqa_model_name).to(device)
except torch.cuda.OutOfMemoryError:
    print("WARNING: Loading model to CPU due to insufficient GPU memory.")
    device = "cpu"  # Switch device to CPU
    model_vqa = BlipForQuestionAnswering.from_pretrained(vqa_model_name).to(device)
# --- Helper functions ---
def ensure_final_answer_format(answer_text: str) -> str:
    """Ensure the answer ends with the FINAL ANSWER: format"""
    # Check if the answer already contains a FINAL ANSWER section
    if "FINAL ANSWER:" in answer_text:
        # Extract everything after FINAL ANSWER:
        final_answer_part = answer_text.split("FINAL ANSWER:", 1)[1].strip()
        return f"FINAL ANSWER: {final_answer_part}"
    else:
        # If no FINAL ANSWER section exists, wrap the entire answer
        return f"FINAL ANSWER: {answer_text.strip()}"
def extract_entities(text: str) -> List[str]:
    """Extract key entities from text using spaCy if available, or a regex fallback"""
    if nlp:
        # Use spaCy for better entity extraction
        doc = nlp(text)
        entities = [ent.text for ent in doc.ents]
        keywords = [token.text for token in doc if token.pos_ in ("PROPN", "NOUN")]
        return entities if entities else keywords
    else:
        # Simple fallback: split on whitespace and drop common stopwords
        words = text.lower().split()
        stopwords = ["what", "who", "when", "where", "why", "how", "is", "are", "the", "a", "an", "of", "in", "on", "at"]
        keywords = [word for word in words if word not in stopwords and len(word) > 2]
        return keywords
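
# Illustrative behaviour (hypothetical question):
#   extract_entities("Who founded Wikipedia in 2001?")
#   With spaCy loaded this returns named entities such as ["Wikipedia", "2001"];
#   with the regex fallback it returns keyword tokens like ["founded", "wikipedia", "2001?"].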
def answer_question_on_frame(image_path, question):
    """Answer a question about a single video frame using BLIP"""
    try:
        image = Image.open(image_path).convert('RGB')
        inputs = processor_vqa(image, question, return_tensors="pt").to(device)
        out = model_vqa.generate(**inputs)
        answer = processor_vqa.decode(out[0], skip_special_tokens=True)
        return answer
    except Exception as e:
        print(f"Error processing frame {image_path}: {str(e)}")
        return "Error processing this frame"

def answer_video_question(frames_dir, question):
    """Answer a question about a video by analyzing extracted frames"""
    valid_exts = ('.jpg', '.jpeg', '.png')
    # Check if directory exists
    if not os.path.exists(frames_dir):
        return {
            "most_common_answer": "No frames found to analyze.",
            "all_answers": [],
            "answer_counts": Counter()
        }
    frame_files = [os.path.join(frames_dir, f) for f in os.listdir(frames_dir)
                   if f.lower().endswith(valid_exts)]
    # Sort frames numerically by the number embedded in the file name
    def get_frame_number(filename):
        match = re.search(r'(\d+)', os.path.basename(filename))
        return int(match.group(1)) if match else 0
    frame_files = sorted(frame_files, key=get_frame_number)
    if not frame_files:
        return {
            "most_common_answer": "No valid image frames found.",
            "all_answers": [],
            "answer_counts": Counter()
        }
    answers = []
    for frame_path in frame_files:
        try:
            ans = answer_question_on_frame(frame_path, question)
            answers.append(ans)
            print(f"Processed frame: {os.path.basename(frame_path)}, Answer: {ans}")
        except Exception as e:
            print(f"Error processing frame {frame_path}: {str(e)}")
    if not answers:
        return {
            "most_common_answer": "Could not analyze any frames successfully.",
            "all_answers": [],
            "answer_counts": Counter()
        }
    counted = Counter(answers)
    most_common_answer, freq = counted.most_common(1)[0]
    return {
        "most_common_answer": most_common_answer,
        "all_answers": answers,
        "answer_counts": counted
    }
def download_youtube_video(url, output_dir='/tmp/video/', output_filename='downloaded_video.mp4'):
    """Download a YouTube video using yt-dlp"""
    # Ensure the output directory exists
    os.makedirs(output_dir, exist_ok=True)
    # Delete all files in the output directory
    files = glob.glob(os.path.join(output_dir, '*'))
    for f in files:
        try:
            os.remove(f)
        except Exception as e:
            print(f"Error deleting {f}: {str(e)}")
    # Set output path for yt-dlp
    output_path = os.path.join(output_dir, output_filename)
    try:
        ydl_opts = {
            'format': 'bestvideo[ext=mp4]+bestaudio[ext=m4a]/best[ext=mp4]/best',
            'outtmpl': output_path,
            'quiet': True,
            'merge_output_format': 'mp4',  # Ensures merged output is mp4
            'postprocessors': [{
                'key': 'FFmpegVideoConvertor',
                'preferedformat': 'mp4',  # Recode if needed (yt-dlp uses this spelling)
            }]
        }
        with yt_dlp.YoutubeDL(ydl_opts) as ydl:
            ydl.download([url])
        return output_path
    except Exception as e:
        print(f"Error downloading YouTube video: {str(e)}")
        return None
def extract_frames(video_path, output_dir, frame_interval_seconds=10):
    """Extract frames from a video file at specified intervals"""
    # Clean output directory before extracting new frames
    if os.path.exists(output_dir):
        for filename in os.listdir(output_dir):
            file_path = os.path.join(output_dir, filename)
            try:
                if os.path.isfile(file_path) or os.path.islink(file_path):
                    os.unlink(file_path)
                elif os.path.isdir(file_path):
                    shutil.rmtree(file_path)
            except Exception as e:
                print(f'Failed to delete {file_path}. Reason: {e}')
    else:
        os.makedirs(output_dir, exist_ok=True)
    try:
        cap = cv2.VideoCapture(video_path)
        if not cap.isOpened():
            print("Error: Could not open video.")
            return False
        fps = cap.get(cv2.CAP_PROP_FPS)
        frame_interval = int(fps * frame_interval_seconds)
        if frame_interval <= 0:
            frame_interval = 1  # Fall back to every frame if the stream reports no/invalid FPS
        count = 0
        saved = 0
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            if count % frame_interval == 0:
                frame_filename = os.path.join(output_dir, f"frame_{count:06d}.jpg")
                cv2.imwrite(frame_filename, frame)
                saved += 1
            count += 1
        cap.release()
        print(f"Extracted {saved} frames.")
        return saved > 0
    except Exception as e:
        print(f"Exception during frame extraction: {e}")
        return False
def image_qa(image_path: str, question: str) -> str:
    """Answer questions about an image using the BLIP model"""
    try:
        image = Image.open(image_path).convert('RGB')
        inputs = processor_vqa(image, question, return_tensors="pt").to(device)
        out = model_vqa.generate(**inputs)
        answer = processor_vqa.decode(out[0], skip_special_tokens=True)
        return answer
    except Exception as e:
        print(f"Error in image_qa: {str(e)}")
        return f"Error processing image: {str(e)}"
# --- Node functions ---
def router(state: Dict[str, Any]) -> str:
    """Determine the next node based on question content and file type"""
    question = state.get('question', '')
    # Pattern for Wikipedia and similar sources
    wiki_pattern = r"(wikipedia\.org|wiki|encyclopedia|britannica\.com|encyclop[aæ]dia)"
    has_wiki = re.search(wiki_pattern, question, re.IGNORECASE) is not None
    # Pattern for YouTube
    yt_pattern = r"(https?://)?(www\.)?(youtube\.com|youtu\.be)/[^\s]+"
    has_youtube = re.search(yt_pattern, question) is not None
    # Check for image
    has_image = state.get('file_type') == 'picture'
    # Check for audio
    has_audio = state.get('file_type') == 'audio'
    print(f"Has Wikipedia reference: {has_wiki}")
    print(f"Has YouTube link: {has_youtube}")
    print(f"Has picture file: {has_image}")
    print(f"Has audio file: {has_audio}")
    if has_wiki:
        return "retrieve"
    elif has_youtube:
        # Store the extracted YouTube URL in the state
        url_match = re.search(r"(https?://[^\s]+)", question)
        if url_match:
            state['youtube_url'] = url_match.group(0)
        return "video"
    elif has_image:
        return "image"
    elif has_audio:
        return "audio"
    else:
        return "llm"
def node_decide(state: Dict[str, Any]) -> Dict[str, Any]:
    """Router node that decides which node to go to next"""
    print("Running node_decide")
    # Initialize context list if not present
    if 'context' not in state:
        state['context'] = []
    # Record the chosen route in the state dict
    state["next"] = router(state)
    print(f"Routing to: {state['next']}")
    return state
def node_image(state: Dict[str, Any]) -> Dict[str, Any]:
    """Process image-based questions"""
    print("Running node_image")
    try:
        # Make sure an image file path was provided and the file exists
        file_path = state.get('file_path')
        if not file_path or not os.path.exists(file_path):
            state['answer'] = ensure_final_answer_format("Image file not found.")
            return state
        # Get answer from image QA model
        answer = image_qa(file_path, state['question'])
        # Format the final answer
        state['answer'] = ensure_final_answer_format(answer)
        # Add document to state for traceability
        image_doc = Document(
            page_content=f"Image analysis result: {answer}",
            metadata={"source": "image_analysis", "file_path": file_path}
        )
        state['context'].append(image_doc)
    except Exception as e:
        error_msg = f"Error processing image: {str(e)}"
        print(error_msg)
        state['answer'] = ensure_final_answer_format(error_msg)
    return state
def node_video(state: Dict[str, Any]) -> Dict[str, Any]:
    """Process video-based questions"""
    print("Running node_video")
    youtube_url = state.get('youtube_url')
    if not youtube_url:
        state['answer'] = ensure_final_answer_format("No YouTube URL found in the question.")
        return state
    question = state['question']
    # Extract the actual question part (remove the URL)
    question_text = re.sub(r'https?://[^\s]+', '', question).strip()
    if not question_text.endswith('?'):
        question_text += '?'
    video_file = download_youtube_video(youtube_url)
    if not video_file or not os.path.exists(video_file):
        state['answer'] = ensure_final_answer_format("Failed to download the video.")
        return state
    frames_dir = "/tmp/frames"
    os.makedirs(frames_dir, exist_ok=True)
    success = extract_frames(video_path=video_file, output_dir=frames_dir, frame_interval_seconds=10)
    if not success:
        state['answer'] = ensure_final_answer_format("Failed to extract frames from the video.")
        return state
    result = answer_video_question(frames_dir, question_text)
    final_answer = result['most_common_answer']
    state['frame_answers'] = result['all_answers']
    # Create Document objects for each frame analysis
    frame_documents = []
    for i, ans in enumerate(result['all_answers']):
        doc = Document(
            page_content=f"Frame {i}: {ans}",
            metadata={"frame_number": i, "source": "video_analysis"}
        )
        frame_documents.append(doc)
    # Add documents to state
    state['context'].extend(frame_documents)
    state['answer'] = ensure_final_answer_format(final_answer)
    print(f"Video answer: {state['answer']}")
    return state
def node_audio_rag(state: Dict[str, Any]) -> Dict[str, Any]:
    """Process audio-based questions"""
    print(f"Processing audio file: {state.get('file_path')}")
    try:
        # Step 1: Transcribe audio
        audio, sr = librosa.load(state['file_path'], sr=16000)
        asr_result = asr_pipe({"raw": audio, "sampling_rate": sr})
        audio_transcript = asr_result['text']
        print(f"Audio transcript: {audio_transcript}")
        # Step 2: Store transcript in vector store
        transcript_doc = [Document(page_content=audio_transcript)]
        embeddings = HuggingFaceEmbeddings(model_name='BAAI/bge-large-en-v1.5')
        vector_db = FAISS.from_documents(transcript_doc, embedding=embeddings)
        # Step 3: Retrieve relevant docs for the user's question
        question = state['question']
        similar_docs = vector_db.similarity_search(question, k=1)
        retrieved_context = "\n".join([doc.page_content for doc in similar_docs])
        # Step 4: Generate answer
        prompt = (
            f"You are an AI assistant that answers questions about audio content.\n\n"
            f"Audio transcript: {retrieved_context}\n\n"
            f"Question: {question}\n\n"
            f"Based only on the provided audio transcript, answer the question. "
            f"If the transcript does not contain relevant information, state that clearly.\n\n"
            f"End your response with 'FINAL ANSWER: ' followed by a concise answer."
        )
        llm_response = llm_pipe(prompt)
        answer_text = llm_response[0]['generated_text']
        # Add documents to state
        state['context'].extend(transcript_doc)
        state['context'].append(Document(
            page_content=prompt,
            metadata={"source": "audio_analysis_prompt"}
        ))
        # Ensure final answer format
        state['answer'] = ensure_final_answer_format(answer_text)
    except Exception as e:
        error_msg = f"Audio processing error: {str(e)}"
        print(error_msg)
        state['answer'] = ensure_final_answer_format(error_msg)
    return state
def node_llm(state: Dict[str, Any]) -> Dict[str, Any]:
    """Process general knowledge questions with the LLM"""
    print("Running node_llm")
    question = state['question']
    # Compose a detailed prompt
    prompt = (
        "You are an AI assistant that answers questions using your general knowledge. "
        "Follow these steps:\n\n"
        "1. If the question appears to be scrambled or jumbled, first try to unscramble or reconstruct the intended meaning.\n"
        "2. Analyze the question (unscrambled if needed) and use your own knowledge to answer it.\n"
        "3. If the question can't be answered with certainty, provide your best estimate and clearly explain any assumptions.\n"
        "4. Format your answer using these rules:\n"
        "   - Numbers: Plain digits without commas/units (e.g. 1234567)\n"
        "   - Strings: Minimal words, no articles/abbreviations\n"
        "   - Lists: comma-separated values without extra formatting\n\n"
        "5. Always conclude with:\n"
        "FINAL ANSWER: [your answer] (replace bracketed text)\n\n"
        f"Current question: {question}"
    )
    # Add document to state for traceability
    query_doc = Document(
        page_content=prompt,
        metadata={"source": "llm_prompt"}
    )
    state['context'].append(query_doc)
    try:
        result = llm_pipe(prompt)
        answer_text = result[0]['generated_text']
        state['answer'] = ensure_final_answer_format(answer_text)
    except Exception as e:
        print(f"Error in LLM processing: {str(e)}")
        error_msg = f"An error occurred while processing your question: {str(e)}"
        state['answer'] = ensure_final_answer_format(error_msg)
    print(f"LLM answer: {state['answer']}")
    return state
def retrieve(state: State) -> State:
    """Retrieve relevant documents using RAG"""
    print("Running retrieve")
    question = state["question"]
    try:
        # Tokenize the question
        inputs = tokenizer(question, return_tensors="pt")
        # Get doc_ids by using the retriever directly
        question_hidden_states = rag_model.question_encoder(inputs["input_ids"])[0]
        docs_dict = retriever(
            inputs["input_ids"].numpy(),
            question_hidden_states.detach().numpy(),
            return_tensors="pt"
        )
        # Extract the retrieved passages
        all_chunks = []
        # Debug print to see what's in docs_dict
        print(f"docs_dict keys: {docs_dict.keys()}")
        # Check for different possible keys that might contain the documents
        doc_text_key = None
        for possible_key in ['retrieved_doc_text', 'doc_text', 'texts', 'documents']:
            if possible_key in docs_dict:
                doc_text_key = possible_key
                break
        if doc_text_key:
            # Access the retrieved document texts from the docs_dict
            for i in range(len(docs_dict["doc_ids"][0])):
                doc_text = docs_dict[doc_text_key][0][i]
                all_chunks.append(Document(page_content=doc_text))
            print(f"Retrieved {len(all_chunks)} documents")
        else:
            # Fallback: Try to extract document text from doc_ids
            doc_ids = docs_dict.get("doc_ids", [[]])[0]
            print(f"Retrieved doc_ids: {doc_ids}")
            # Create minimal document stubs from IDs
            for doc_id in doc_ids:
                stub_text = f"Information related to document ID: {doc_id}"
                all_chunks.append(Document(page_content=stub_text))
            print(f"Created {len(all_chunks)} document stubs from IDs")
        # Add documents to state context
        if not state.get('context'):
            state['context'] = []
        state['context'].extend(all_chunks)
    except Exception as e:
        print(f"Error in retrieval: {str(e)}")
        # Create an error document
        error_doc = Document(
            page_content=f"Error during retrieval: {str(e)}",
            metadata={"source": "retrieval_error"}
        )
        if not state.get('context'):
            state['context'] = []
        state['context'].append(error_doc)
    return state
def generate(state: State) -> State:
    """Generate an answer based on retrieved documents"""
    print("Running generate")
    try:
        # Check if context exists
        if not state.get('context') or len(state['context']) == 0:
            state['answer'] = ensure_final_answer_format("No relevant information found to answer your question.")
            return state
        # Concatenate all context documents into a single string
        docs_content = "\n\n".join(doc.page_content for doc in state["context"])
        # Format the prompt for the LLM
        prompt_str = PromptTemplate(
            input_variables=["question", "context"],
            template=(
                "You are an AI assistant that answers questions using retrieved context. "
                "Follow these steps:\n\n"
                "1. Analyze the provided context:\n{context}\n\n"
                "2. If the context contains scrambled text, first attempt to reconstruct meaningful information\n"
                "3. If the question can't be answered from context alone, combine context with general knowledge "
                "but clearly state this limitation\n"
                "4. Format your answer using these rules:\n"
                "   - Numbers: Plain digits without commas/units (e.g. 1234567)\n"
                "   - Strings: Minimal words, no articles/abbreviations\n"
                "   - Lists: comma-separated values without extra formatting\n\n"
                "5. Always conclude with:\n"
                "FINAL ANSWER: [your answer] (replace bracketed text)\n\n"
                "Current question: {question}"
            )
        ).format(question=state["question"], context=docs_content)
        # Generate answer using the LLM pipeline
        response = llm_pipe(prompt_str)
        answer_text = response[0]["generated_text"]
        # Ensure answer has the FINAL ANSWER format
        state['answer'] = ensure_final_answer_format(answer_text)
    except Exception as e:
        print(f"Error in generate node: {str(e)}")
        error_msg = f"Error generating answer: {str(e)}"
        state['answer'] = ensure_final_answer_format(error_msg)
    return state
# --- Define the edge condition function ---
def get_next_node(state: Dict[str, Any]) -> str:
    """Get the next node from the state"""
    return state["next"]

# Create the StateGraph
graph = StateGraph(State)
# Add nodes
graph.add_node("decide", node_decide)
graph.add_node("video", node_video)
graph.add_node("llm", node_llm)
graph.add_node("retrieve", retrieve)
graph.add_node("generate", generate)
graph.add_node("image", node_image)
graph.add_node("audio", node_audio_rag)
# Add edge from START to decide
graph.add_edge(START, "decide")
graph.add_edge("retrieve", "generate")
# Add conditional edges from decide to the other nodes based on the routing decision
graph.add_conditional_edges(
    "decide",
    get_next_node,
    {
        "video": "video",
        "llm": "llm",
        "retrieve": "retrieve",
        "image": "image",
        "audio": "audio"
    }
)
# Add edges from all terminal nodes to END
graph.add_edge("video", END)
graph.add_edge("llm", END)
graph.add_edge("generate", END)
graph.add_edge("image", END)
graph.add_edge("audio", END)
# Compile the graph
agent = graph.compile()
# --- Intelligent Agent Function ---
def intelligent_agent(state: State) -> str:
    """Process a question using the appropriate pipeline based on content."""
    try:
        # Ensure state has proper structure
        if not isinstance(state, dict):
            return "FINAL ANSWER: Error - input must be a valid State dictionary"
        # Make sure question exists
        if 'question' not in state:
            return "FINAL ANSWER: Error - question is required"
        # Initialize context if not present
        if 'context' not in state:
            state['context'] = []
        print(f"Processing question: {state['question']}")
        # Invoke the agent with the state
        final_state = agent.invoke(state)
        # Ensure answer has FINAL ANSWER format
        answer = final_state.get('answer', "No answer found.")
        formatted_answer = ensure_final_answer_format(answer)
        return formatted_answer
    except Exception as e:
        print(f"Error in agent execution: {str(e)}")
        return f"FINAL ANSWER: An error occurred - {str(e)}"