import os
import re
import glob
import shutil
import tempfile
from collections import Counter
from io import BytesIO
from typing import TypedDict, List, Optional, Dict, Any, Literal, Tuple

import numpy as np
import cv2
import librosa
import spacy
import torch
import wikipedia
import yt_dlp
from PIL import Image
from sentence_transformers import SentenceTransformer
from transformers import BlipProcessor, BlipForQuestionAnswering, pipeline

from langchain.docstore.document import Document
from langchain.prompts import PromptTemplate
from langchain.vectorstores import FAISS
from langchain.embeddings import HuggingFaceEmbeddings
from langchain_community.document_loaders import WikipediaLoader
from langchain_community.retrievers import BM25Retriever  # If you are using it
from langchain_core.messages import AnyMessage, HumanMessage, AIMessage  # If you are using it
from langgraph.graph import StateGraph, START, END
from langgraph.prebuilt import ToolNode, tools_condition  # If you are using it

nlp = spacy.load("en_core_web_sm")
# Define file extension sets for each category
PICTURE_EXTENSIONS = {'.jpg', '.jpeg', '.png', '.gif', '.bmp', '.tiff', '.webp'}
AUDIO_EXTENSIONS = {'.mp3', '.wav', '.aac', '.flac', '.ogg', '.m4a', '.wma'}
CODE_EXTENSIONS = {'.py', '.js', '.java', '.cpp', '.c', '.cs', '.rb', '.go', '.php', '.html', '.css', '.ts'}
SPREADSHEET_EXTENSIONS = {
    '.xls', '.xlsx', '.xlsm', '.xlsb', '.xlt', '.xltx', '.xltm',
    '.ods', '.ots', '.csv', '.tsv', '.sxc', '.stc', '.dif', '.gsheet',
    '.numbers', '.numbers-tef', '.nmbtemplate', '.fods', '.123', '.wk1', '.wk2',
    '.wks', '.wku', '.wr1', '.gnumeric', '.gnm', '.xml', '.pmvx', '.pmdx',
    '.pmv', '.uos', '.txt'
}
def get_file_type(filename: str) -> str:
    """Classify a file by its extension: picture, audio, code, spreadsheet, or unknown."""
    if not filename or '.' not in filename:
        return ''
    ext = filename.lower().rsplit('.', 1)[-1]
    dot_ext = f'.{ext}'
    if dot_ext in PICTURE_EXTENSIONS:
        return 'picture'
    elif dot_ext in AUDIO_EXTENSIONS:
        return 'audio'
    elif dot_ext in CODE_EXTENSIONS:
        return 'code'
    elif dot_ext in SPREADSHEET_EXTENSIONS:
        return 'spreadsheet'
    else:
        return 'unknown'
def write_bytes_to_temp_dir(file_bytes: bytes, file_name: str) -> str:
    """
    Writes bytes to a file in the system temporary directory using the provided file_name.
    Returns the full path to the saved file.
    The file will persist until manually deleted or the OS cleans the temp directory.
    """
    temp_dir = tempfile.gettempdir()
    file_path = os.path.join(temp_dir, file_name)
    with open(file_path, 'wb') as f:
        f.write(file_bytes)
    print(f"File written to: {file_path}")
    return file_path
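
# Illustrative sketch only (not part of the graph flow): how a task attachment might be
# staged and classified before building the agent State. The file name "example.mp3"
# is a hypothetical placeholder.
def _example_stage_file() -> str:
    with open("example.mp3", "rb") as f:  # hypothetical local file
        file_bytes = f.read()
    staged_path = write_bytes_to_temp_dir(file_bytes, "example.mp3")
    print(get_file_type("example.mp3"))   # expected: 'audio'
    return staged_path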
# 1. Define the State type
class State(TypedDict, total=False):
    question: str
    task_id: str
    input_file: bytes
    file_type: str
    context: List[Document]  # Using LangChain's Document class
    file_path: Optional[str]
    youtube_url: Optional[str]
    answer: Optional[str]
    frame_answers: Optional[list]
    next: Optional[str]  # Added to track the next node
# --- LLM pipeline for general questions ---
llm_pipe = pipeline(
    "text-generation",
    #model="meta-llama/Llama-3.3-70B-Instruct",
    #model="meta-llama/Meta-Llama-3-8B-Instruct",
    #model="Qwen/Qwen2-7B-Instruct",
    #model="microsoft/Phi-4-reasoning",
    model="microsoft/Phi-3-mini-4k-instruct",
    device_map="auto",
    #device_map={"": 0},  # "" means the whole model
    #max_memory={0: "10GiB"},
    torch_dtype="auto",
    max_new_tokens=256,
)

# Speech-to-text pipeline
asr_pipe = pipeline(
    "automatic-speech-recognition",
    model="openai/whisper-small",
    device=-1,
    #device_map={"": 0},
    #max_memory={0: "4.5GiB"},
    #device_map="auto"
)
# --- BLIP VQA setup ---
#device = "cuda" if torch.cuda.is_available() else "cpu"
device = "cpu"
vqa_model_name = "Salesforce/blip-vqa-base"
processor_vqa = BlipProcessor.from_pretrained(vqa_model_name)
# Attempt to load the model on the selected device; fall back to CPU if the GPU runs out of memory
try:
    model_vqa = BlipForQuestionAnswering.from_pretrained(vqa_model_name).to(device)
except torch.cuda.OutOfMemoryError:
    print("WARNING: Loading model to CPU due to insufficient GPU memory.")
    device = "cpu"  # Switch device to CPU
    model_vqa = BlipForQuestionAnswering.from_pretrained(vqa_model_name).to(device)
# --- Helper: Answer question on a single frame ---
def answer_question_on_frame(image_path, question):
    """Run BLIP VQA on a single image file and return the decoded answer."""
    image = Image.open(image_path).convert('RGB')
    inputs = processor_vqa(image, question, return_tensors="pt").to(device)
    out = model_vqa.generate(**inputs)
    answer = processor_vqa.decode(out[0], skip_special_tokens=True)
    return answer
# --- Helper: Answer question about the whole video ---
def answer_video_question(frames_dir, question):
    """Answer the question for every extracted frame and return the most common answer."""
    valid_exts = ('.jpg', '.jpeg', '.png')
    # Check if the frames directory exists
    if not os.path.exists(frames_dir):
        return {
            "most_common_answer": "No frames found to analyze.",
            "all_answers": [],
            "answer_counts": Counter()
        }
    frame_files = [os.path.join(frames_dir, f) for f in os.listdir(frames_dir)
                   if f.lower().endswith(valid_exts)]

    # Sort frames numerically by the frame number embedded in the file name
    def get_frame_number(filename):
        match = re.search(r'(\d+)', os.path.basename(filename))
        return int(match.group(1)) if match else 0

    frame_files = sorted(frame_files, key=get_frame_number)
    if not frame_files:
        return {
            "most_common_answer": "No valid image frames found.",
            "all_answers": [],
            "answer_counts": Counter()
        }
    answers = []
    for frame_path in frame_files:
        try:
            ans = answer_question_on_frame(frame_path, question)
            answers.append(ans)
            print(f"Processed frame: {os.path.basename(frame_path)}, Answer: {ans}")
        except Exception as e:
            print(f"Error processing frame {frame_path}: {str(e)}")
    if not answers:
        return {
            "most_common_answer": "Could not analyze any frames successfully.",
            "all_answers": [],
            "answer_counts": Counter()
        }
    counted = Counter(answers)
    most_common_answer, freq = counted.most_common(1)[0]
    return {
        "most_common_answer": most_common_answer,
        "all_answers": answers,
        "answer_counts": counted
    }
def download_youtube_video(url, output_dir='/content/video/', output_filename='downloaded_video.mp4'):
    """Download a YouTube video as MP4 into output_dir, clearing any previous downloads first."""
    # Ensure the output directory exists
    os.makedirs(output_dir, exist_ok=True)
    # Delete all files in the output directory
    files = glob.glob(os.path.join(output_dir, '*'))
    for f in files:
        try:
            os.remove(f)
        except Exception as e:
            print(f"Error deleting {f}: {str(e)}")
    # Set output path for yt-dlp
    output_path = os.path.join(output_dir, output_filename)
    ydl_opts = {
        'format': 'bestvideo[ext=mp4]+bestaudio[ext=m4a]/best[ext=mp4]/best',
        'outtmpl': output_path,
        'quiet': True,
        'merge_output_format': 'mp4',  # Ensures merged output is mp4
        'postprocessors': [{
            'key': 'FFmpegVideoConvertor',
            'preferedformat': 'mp4',  # Recode if needed (yt-dlp spells this option key with one 'r')
        }]
    }
    with yt_dlp.YoutubeDL(ydl_opts) as ydl:
        ydl.download([url])
    return output_path
# --- Helper: Extract frames from video ---
def extract_frames(video_path, output_dir, frame_interval_seconds=10):
    # --- Clean output directory before extracting new frames ---
    if os.path.exists(output_dir):
        for filename in os.listdir(output_dir):
            file_path = os.path.join(output_dir, filename)
            try:
                if os.path.isfile(file_path) or os.path.islink(file_path):
                    os.unlink(file_path)
                elif os.path.isdir(file_path):
                    shutil.rmtree(file_path)
            except Exception as e:
                print(f'Failed to delete {file_path}. Reason: {e}')
    else:
        os.makedirs(output_dir, exist_ok=True)
    try:
        cap = cv2.VideoCapture(video_path)
        if not cap.isOpened():
            print("Error: Could not open video.")
            return False
        fps = cap.get(cv2.CAP_PROP_FPS)
        # Guard against streams that report fps == 0, which would break the modulo below
        frame_interval = max(1, int(fps * frame_interval_seconds))
        count = 0
        saved = 0
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            if count % frame_interval == 0:
                frame_filename = os.path.join(output_dir, f"frame_{count:06d}.jpg")
                cv2.imwrite(frame_filename, frame)
                saved += 1
            count += 1
        cap.release()
        print(f"Extracted {saved} frames.")
        return saved > 0
    except Exception as e:
        print(f"Exception during frame extraction: {e}")
        return False
def image_qa(image_path: str, question: str, model_name: str = vqa_model_name) -> str:
    """
    Answers questions about images using Hugging Face's VQA pipeline.
    Args:
        image_path: Path to a local image file or URL
        question: Natural language question about the image
        model_name: Pretrained VQA model (default: a good general-purpose model)
    Returns:
        str: The model's best answer
    """
    # Create a VQA pipeline with the specified model (rebuilt on every call)
    vqa_pipeline = pipeline("visual-question-answering", model=model_name)
    # Get predictions (the pipeline handles local files and URLs)
    results = vqa_pipeline(image=image_path, question=question, top_k=1)
    # Return the top answer
    return results[0]['answer']
def router(state: Dict[str, Any]) -> str:
    """Determine the next node based on whether the question contains a YouTube URL, references Wikipedia, or has an attached image or audio file."""
    question = state.get('question', '')
    # Pattern for Wikipedia and similar sources
    wiki_pattern = r"(wikipedia\.org|wiki|encyclopedia|britannica\.com|encyclop[aæ]dia)"
    has_wiki = re.search(wiki_pattern, question, re.IGNORECASE) is not None
    # Pattern for YouTube
    yt_pattern = r"(https?://)?(www\.)?(youtube\.com|youtu\.be)/[^\s]+"
    has_youtube = re.search(yt_pattern, question) is not None
    # Check for image
    has_image = state.get('file_type') == 'picture'
    # Check for audio
    has_audio = state.get('file_type') == 'audio'
    print(f"Has Wikipedia reference: {has_wiki}")
    print(f"Has YouTube link: {has_youtube}")
    print(f"Has picture file: {has_image}")
    print(f"Has audio file: {has_audio}")
    if has_wiki:
        return "retrieve"
    elif has_youtube:
        # Store the extracted YouTube URL in the state
        url_match = re.search(r"(https?://[^\s]+)", question)
        if url_match:
            state['youtube_url'] = url_match.group(0)
        return "video"
    elif has_image:
        return "image"
    elif has_audio:
        return "audio"
    else:
        return "llm"
# --- Node Implementation ---
def node_image(state: Dict[str, Any]) -> Dict[str, Any]:
    """Image node: answer the question about the attached picture with the VQA pipeline."""
    print("Running node_image")
    state['answer'] = image_qa(state['file_path'], state['question'])
    return state

def node_decide(state: Dict[str, Any]) -> Dict[str, Any]:
    """Router node that decides which node to go to next."""
    print("Running node_decide")
    # Store the routing decision in the state dict
    state["next"] = router(state)
    print(f"Routing to: {state['next']}")
    return state
def node_video(state: Dict[str, Any]) -> Dict[str, Any]:
    print("Running node_video")
    youtube_url = state.get('youtube_url')
    if not youtube_url:
        state['answer'] = "No YouTube URL found in the question."
        return state
    question = state['question']
    # Extract the actual question part (remove the URL)
    question_text = re.sub(r'https?://[^\s]+', '', question).strip()
    if not question_text.endswith('?'):
        question_text += '?'
    video_file = download_youtube_video(youtube_url)
    if not video_file or not os.path.exists(video_file):
        state['answer'] = "Failed to download the video."
        return state
    frames_dir = "/tmp/frames"
    os.makedirs(frames_dir, exist_ok=True)
    success = extract_frames(video_path=video_file, output_dir=frames_dir, frame_interval_seconds=10)
    if not success:
        state['answer'] = "Failed to extract frames from the video."
        return state
    result = answer_video_question(frames_dir, question_text)
    state['answer'] = result['most_common_answer']
    state['frame_answers'] = result['all_answers']
    # Create Document objects for each frame analysis
    frame_documents = []
    for i, ans in enumerate(result['all_answers']):
        doc = Document(
            page_content=f"Frame {i}: {ans}",
            metadata={"frame_number": i, "source": "video_analysis"}
        )
        frame_documents.append(doc)
    # Add documents to state if not already present
    if 'context' not in state:
        state['context'] = []
    state['context'].extend(frame_documents)
    print(f"Video answer: {state['answer']}")
    return state
def node_audio_rag(state: Dict[str, Any]) -> Dict[str, Any]:
    print(f"Processing audio file: {state['file_path']}")
    try:
        # Step 1: Transcribe audio
        audio, sr = librosa.load(state['file_path'], sr=16000)
        asr_result = asr_pipe({"raw": audio, "sampling_rate": sr})
        audio_transcript = asr_result['text']
        print(f"Audio transcript: {audio_transcript}")
        # Step 2: Store ONLY the transcript in the vector store
        transcript_doc = [Document(page_content=audio_transcript)]
        embeddings = HuggingFaceEmbeddings(model_name='BAAI/bge-large-en-v1.5')
        vector_db = FAISS.from_documents(transcript_doc, embedding=embeddings)
        # Step 3: Retrieve relevant docs for the user's question
        question = state['question']
        similar_docs = vector_db.similarity_search(question, k=1)  # Only one doc in store
        retrieved_context = "\n".join([doc.page_content for doc in similar_docs])
        # Step 4: Augment prompt and generate answer
        audio_prompt = (
            f"Use the following context to answer the question.\n"
            f"Context:\n{retrieved_context}\n\n"
            f"Question: {question}\nAnswer:"
        )
        # return_full_text=False keeps only the newly generated answer, not the echoed prompt
        llm_response = llm_pipe(audio_prompt, return_full_text=False)
        state['answer'] = llm_response[0]['generated_text']
    except Exception as e:
        error_msg = f"Audio processing error: {str(e)}"
        print(error_msg)
        state['answer'] = error_msg
    return state
def node_llm(state: Dict[str, Any]) -> Dict[str, Any]:
    print("Running node_llm")
    question = state['question']
    # Optionally add context from state (e.g., Wikipedia/Wikidata content)
    context_text = ""
    if 'article_content' in state and state['article_content']:
        context_text = f"\n\nBackground Information:\n{state['article_content']}\n"
    elif 'context' in state and state['context']:
        context_text = "\n\n".join([doc.page_content for doc in state['context']])
    # Compose a detailed prompt
    llm_prompt = (
        "You are an expert researcher. Answer the user's question as accurately as possible. "
        "If the text appears to be scrambled, try to unscramble the text for the user. "
        "If the information is incomplete or ambiguous, provide your best estimate based on the available evidence, and clearly explain any assumptions or reasoning you use. "
        "If the answer requires multiple steps or deeper analysis, break down the question into sub-questions and answer them step by step, citing the relevant context for each step.\n\n"
        f"Question: {question}\n"
        f"{context_text}\n"
        "Answer:"
    )
    # Add document to state for traceability
    query_doc = Document(
        page_content=llm_prompt,
        metadata={"source": "llm_prompt"}
    )
    if 'context' not in state:
        state['context'] = []
    state['context'].append(query_doc)
    try:
        # return_full_text=False keeps only the newly generated answer, not the echoed prompt
        result = llm_pipe(llm_prompt, return_full_text=False)
        state['answer'] = result[0]['generated_text']
    except Exception as e:
        print(f"Error in LLM processing: {str(e)}")
        state['answer'] = f"An error occurred while processing your question: {str(e)}"
    print(f"LLM answer: {state['answer']}")
    return state
# --- Define the edge condition function ---
def get_next_node(state: Dict[str, Any]) -> str:
    """Get the next node from the state."""
    return state["next"]

# 2. Improved Wikipedia Retrieval Node
def extract_keywords(question: str) -> List[str]:
    doc = nlp(question)
    keywords = [token.text for token in doc if token.pos_ in ("PROPN", "NOUN")]  # Extract proper nouns and nouns
    return keywords

def extract_entities(question: str) -> List[str]:
    doc = nlp(question)
    entities = [ent.text for ent in doc.ents]
    return entities if entities else [token.text for token in doc if token.pos_ in ("PROPN", "NOUN")]

def retrieve(state: State) -> dict:
    keywords = extract_entities(state["question"])
    query = " ".join(keywords)
    search_results = wikipedia.search(query)
    selected_page = search_results[0] if search_results else None
    if selected_page:
        loader = WikipediaLoader(
            query=selected_page,
            lang="en",
            load_max_docs=1,
            doc_content_chars_max=100000,
            load_all_available_meta=True
        )
        docs = loader.load()
        # Chunk the article for finer retrieval
        from langchain.text_splitter import RecursiveCharacterTextSplitter
        splitter = RecursiveCharacterTextSplitter(chunk_size=1500, chunk_overlap=200)
        all_chunks = []
        for doc in docs:
            chunks = splitter.split_text(doc.page_content)
            all_chunks.extend([Document(page_content=chunk) for chunk in chunks])
        # Optionally: re-rank or filter chunks here
        return {"context": all_chunks}
    else:
        return {"context": []}
# 3. Prompt Template for General QA
prompt = PromptTemplate(
    input_variables=["question", "context"],
    template=(
        "You are an expert researcher. Given the following context from Wikipedia, answer the user's question as accurately as possible. "
        "If the text appears to be scrambled, try to unscramble the text for the user. "
        "If the information is incomplete or ambiguous, provide your best estimate based on the available evidence, and clearly explain any assumptions or reasoning you use. "
        "If the answer requires multiple steps or deeper analysis, break down the question into sub-questions and answer them step by step, citing the relevant context for each step.\n\n"
        "Context:\n{context}\n\n"
        "Question: {question}\n\n"
        "Best Estimate Answer:"
    )
)
""" | |
def generate(state: State) -> dict: | |
# Concatenate all context documents into a single string | |
docs_content = "\n\n".join(doc.page_content for doc in state["context"]) | |
# Format the prompt for the LLM | |
prompt_str = prompt.format(question=state["question"], context=docs_content) | |
# Generate answer | |
response = llm.invoke(prompt_str) | |
return {"answer": response} | |
""" | |
def generate(state: dict) -> dict:
    # Concatenate all context documents into a single string
    docs_content = "\n\n".join(doc.page_content for doc in state["context"])
    # Format the prompt for the LLM
    prompt_str = prompt.format(question=state["question"], context=docs_content)
    # Generate answer using the Hugging Face pipeline; return_full_text=False keeps only the completion
    response = llm_pipe(prompt_str, return_full_text=False)
    # Extract generated text
    answer = response[0]["generated_text"]
    return {"answer": answer}
# Create the StateGraph
graph = StateGraph(State)
# Add nodes
graph.add_node("decide", node_decide)
graph.add_node("video", node_video)
graph.add_node("llm", node_llm)
graph.add_node("retrieve", retrieve)
graph.add_node("generate", generate)
graph.add_node("image", node_image)
graph.add_node("audio", node_audio_rag)
# Add edge from START to decide
graph.add_edge(START, "decide")
graph.add_edge("retrieve", "generate")
# Add conditional edges from decide to the handler chosen by the router
graph.add_conditional_edges(
    "decide",
    get_next_node,
    {
        "video": "video",
        "llm": "llm",
        "retrieve": "retrieve",
        "image": "image",
        "audio": "audio"
    }
)
# Add edges from the handler nodes to END to terminate the graph
graph.add_edge("video", END)
graph.add_edge("llm", END)
graph.add_edge("generate", END)
graph.add_edge("image", END)
graph.add_edge("audio", END)
# Compile the graph
agent = graph.compile()
# --- Usage Example ---
def intelligent_agent(state: State) -> str:
    """Process a question state using the appropriate pipeline based on its content."""
    #state = State(question=question)
    try:
        final_state = agent.invoke(state)
        return final_state.get('answer', "No answer found.")
    except Exception as e:
        print(f"Error in agent execution: {str(e)}")
        return f"An error occurred: {str(e)}"