import os
import shutil
from queue import Queue, Empty
from threading import Thread

from langchain_community.document_loaders import PyPDFLoader
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import Chroma
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain.chains import RetrievalQA
from langchain_community.llms import HuggingFacePipeline
from langchain_core.callbacks import BaseCallbackHandler
from transformers import TextIteratorStreamer, AutoTokenizer, AutoModelForCausalLM, pipeline

from book_title_extractor import BookTitleExtractor
from duplicate_detector import DuplicateDetector
class StreamingHandler(BaseCallbackHandler):
    """Collects LLM tokens and forwards each one to an optional callback."""

    def __init__(self):
        self.buffer = []
        self.token_callback = None

    def on_llm_new_token(self, token: str, **kwargs):
        self.buffer.append(token)
        if self.token_callback:
            self.token_callback(token)
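
# For ask_question_stream() to stream tokens, a StreamingHandler has to be
# registered as a LangChain callback on the LLM; RagEngine.build_qa_chain()
# below sketches one possible (assumed) wiring.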
class RagEngine:
    def __init__(self, persist_dir="chroma_store", embed_model="nomic-embed-text",
                 llm_model="qwen:1.8b", temp_dir="chroma_temp"):
        # embed_model and llm_model are accepted but currently unused: the
        # embedding model and the Qwen chat model below are hard-coded.
        self.temp_dir = temp_dir
        os.makedirs(self.temp_dir, exist_ok=True)
        self.duplicate_detector = DuplicateDetector()
        self.title_extractor = BookTitleExtractor()
        self.embedding = HuggingFaceEmbeddings(
            model_name="sentence-transformers/all-MiniLM-L6-v2"
        )
        self.vectorstore = None
        self.retriever = None
        # Indexing persists into temp_dir, so the store is reloaded from there;
        # the persist_dir argument is currently unused.
        self.persist_dir = temp_dir
        self._load_vectorstore()
        self.model_id = "Qwen/Qwen-1_8B-Chat"
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_id, trust_remote_code=True)
        self.model = AutoModelForCausalLM.from_pretrained(
            self.model_id,
            trust_remote_code=True,
            device_map="auto",
            torch_dtype="auto",
        )
        self.model.eval()
        # Pieces used by ask_question()/ask_question_stream(); initialising them
        # here avoids AttributeError before build_qa_chain() runs.
        self.qa_chain = None
        self.handler = StreamingHandler()
        self._latest_result = {}

    def _load_vectorstore(self):
        # Reopen a previously persisted Chroma store, if one exists on disk.
        if os.path.exists(os.path.join(self.persist_dir, "chroma.sqlite3")):
            self.vectorstore = Chroma(
                persist_directory=self.persist_dir,
                embedding_function=self.embedding,
            )
            self.retriever = self.vectorstore.as_retriever()
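
    # Hedged sketch: ask_question() and ask_question_stream() below expect
    # self.qa_chain, which is never built elsewhere. This helper shows one
    # possible wiring using the imported RetrievalQA and HuggingFacePipeline;
    # the pipeline settings are assumptions. Depending on the installed
    # LangChain version, HuggingFacePipeline may not emit per-token callbacks,
    # in which case ask_question_stream() gets the answer in one piece.
    def build_qa_chain(self):
        if self.retriever is None:
            raise ValueError("Index a PDF before building the QA chain")
        hf_pipe = pipeline(
            "text-generation",
            model=self.model,
            tokenizer=self.tokenizer,
            max_new_tokens=512,
            return_full_text=False,  # return only newly generated text
        )
        llm = HuggingFacePipeline(pipeline=hf_pipe, callbacks=[self.handler])
        self.qa_chain = RetrievalQA.from_chain_type(
            llm=llm,
            retriever=self.retriever,
            return_source_documents=True,
        )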

    def clear_temp(self):
        shutil.rmtree(self.temp_dir, ignore_errors=True)
        os.makedirs(self.temp_dir, exist_ok=True)

    def index_pdf(self, pdf_path):
        if self.duplicate_detector.is_duplicate(pdf_path):
            raise ValueError(f"Duplicate book detected, skipping index of: {pdf_path}")
        self.duplicate_detector.store_fingerprints(pdf_path)
        self.clear_temp()
        loader = PyPDFLoader(pdf_path)
        documents = loader.load()
        title = self.title_extractor.extract_book_title_from_documents(documents, max_docs=10)
        for doc in documents:
            doc.metadata["source"] = title
        # Drop pages with no extractable text.
        documents = [doc for doc in documents if doc.page_content.strip()]
        if not documents:
            raise ValueError("No readable text in uploaded PDF")
        splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=500)
        chunks = splitter.split_documents(documents)
        if self.vectorstore is None:
            self.vectorstore = Chroma.from_documents(
                documents=chunks,
                embedding=self.embedding,
                persist_directory=self.temp_dir,
            )
        else:
            self.vectorstore.add_documents(chunks)
        self.vectorstore.persist()
        self.retriever = self.vectorstore.as_retriever()

    def stream_answer(self, question):
        if not self.retriever:
            yield "data: ❌ Please upload and index a PDF first.\n\n"
            return
        docs = self.retriever.get_relevant_documents(question)
        if not docs:
            yield "data: ❌ No relevant documents found.\n\n"
            return
        sources = []
        for doc in docs:
            title = doc.metadata.get("source", "Unknown Title")
            page = doc.metadata.get("page", "Unknown Page")
            sources.append(f"{title} - Page {page}")
        context = "\n\n".join(doc.page_content for doc in docs[:3])
        # Qwen-1.8B-Chat expects the ChatML prompt format.
        prompt = (
            "<|im_start|>system\nYou are a helpful assistant that only replies in English.<|im_end|>\n"
            f"<|im_start|>user\nContext:\n{context}\n\nQuestion: {question}<|im_end|>\n"
            "<|im_start|>assistant\n"
        )
        print(prompt)
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
        print("Prompt token length:", inputs["input_ids"].shape[-1])
        streamer = TextIteratorStreamer(
            tokenizer=self.tokenizer,
            skip_prompt=True,
            skip_special_tokens=True,
        )
        generation_args = {
            "input_ids": inputs["input_ids"],
            "attention_mask": inputs["attention_mask"],
            "max_new_tokens": 512,
            "streamer": streamer,
            "do_sample": False,  # greedy decoding; temperature/top_p would be ignored
        }
        # Generate in a background thread so tokens can be yielded as they arrive.
        thread = Thread(target=self.model.generate, kwargs=generation_args)
        thread.start()
        collected_tokens = []
        for token in streamer:
            if token.strip():  # skip whitespace-only tokens
                collected_tokens.append(token)
                yield f"{token} "
        thread.join()
        if sources:
            sources_text = "\n\n📚 **Sources:**\n" + "\n".join(set(sources))
            for line in sources_text.splitlines():
                if line.strip():
                    yield f"{line} \n"
        yield "\n\n"

    def ask_question(self, question):
        print(question)
        if not self.qa_chain:
            return "Please upload and index a PDF document first"
        result = self.qa_chain.invoke({"query": question})
        answer = result["result"]
        sources = []
        for doc in result["source_documents"]:
            sources.append(doc.metadata.get("source", "Unknown"))
        print(answer)
        return {
            "answer": answer,
            "sources": list(set(sources)),  # remove duplicates
        }

    def ask_question_stream(self, question: str):
        if not self.qa_chain:
            yield "❌ Please upload and index a PDF document first."
            return
        q = Queue()

        def token_callback(token):
            q.put(token)

        self.handler.buffer = []
        self.handler.token_callback = token_callback

        def run():
            result = self.qa_chain.invoke({"query": question})
            print(result)
            self._latest_result = result
            q.put(None)  # sentinel: generation finished

        Thread(target=run).start()
print("Threading started", flush=True) | |
while True: | |
try: | |
token = q.get(timeout=30) | |
if token is None: | |
print("Stream finished", flush=True) | |
break | |
yield token | |
except Empty: | |
print("Timed out waiting for token", flush=True) | |
break | |
sources = [] | |
for doc in self._latest_result.get("source_documents",[] ): | |
source = doc.metadata.get("source", "Unknown") | |
sources.append(source) | |
if sources: | |
yield "\n\nπ **Sources:**\n" | |
for i, src in enumerate(set(sources)): | |
yield f"[{i+1}] {src}\n" |