# NOTE(review): the three lines that originally appeared here ("Spaces:" /
# "Sleeping" / "Sleeping") are Hugging Face Spaces status-banner residue from
# a copy-paste, not part of the program; kept only as this comment so the
# module stays importable.
import os | |
import shutil | |
import tempfile | |
from langchain_community.document_loaders import PyPDFLoader | |
from langchain_text_splitters import RecursiveCharacterTextSplitter | |
from langchain_ollama import OllamaEmbeddings | |
from langchain_community.vectorstores import Chroma | |
from langchain.chains import RetrievalQA | |
from langchain_community.llms import Ollama | |
from book_title_extractor import BookTitleExtractor | |
from duplicate_detector import DuplicateDetector | |
from langchain_core.callbacks.base import BaseCallbackHandler | |
from langchain_community.chat_models import ChatOllama | |
class StreamingHanlder(BaseCallbackHandler):
    """LangChain callback handler that collects streamed LLM tokens.

    Every token emitted by the model is appended to ``buffer`` and, when
    ``token_callback`` is set, also forwarded to it one at a time (used by
    ``RagEngine.ask_question_stream`` to feed tokens into a queue).
    """

    def __init__(self):
        # All tokens produced during the current LLM call, in order.
        self.buffer = []
        # Optional per-token hook: callable(token) -> None, or None.
        self.token_callback = None

    def on_llm_new_token(self, token: str, **kwargs):
        """Record each new token and forward it to the hook if one is set."""
        self.buffer.append(token)
        if self.token_callback:
            self.token_callback(token)


# Correctly-spelled alias: the original class name carries a typo ("Hanlder"),
# which is preserved above for backward compatibility with existing callers.
StreamingHandler = StreamingHanlder
class RagEngine:
    """Retrieval-augmented QA over uploaded PDF books.

    PDFs are split into overlapping chunks, embedded with an Ollama embedding
    model, and stored in a persistent Chroma vector store; questions are
    answered by a streaming ChatOllama model through a RetrievalQA chain.
    """

    def __init__(self, embed_model="nomic-embed-text", llm_model="qwen:1.8b", temp_dir="chroma_temp"):
        """Set up the models and persistence directory, and restore any
        previously persisted vector store / QA chain.

        Args:
            embed_model: Ollama embedding model name.
            llm_model: Ollama chat model name.
            temp_dir: Directory where Chroma persists its index.
        """
        self.embed_model = embed_model
        self.llm_model = llm_model
        self.embedding = OllamaEmbeddings(model=self.embed_model)
        self.vectorstore = None
        self.qa_chain = None
        # Most recent chain result, read by ask_question_stream after the
        # token stream ends. Pre-initialised so a timed-out stream can never
        # hit an AttributeError when it looks up source documents (bug fix:
        # the original only assigned this inside the worker thread).
        self._latest_result = {}
        self.handler = StreamingHanlder()
        self.llm = ChatOllama(model=self.llm_model, streaming=True, callbacks=[self.handler])
        self.temp_dir = temp_dir
        os.makedirs(self.temp_dir, exist_ok=True)
        self.title_extractor = BookTitleExtractor(llm=self.llm)
        self.duplicate_detector = DuplicateDetector()
        # A chroma.sqlite3 file marks a previously persisted index; reload it
        # so earlier uploads survive a process restart.
        if os.path.exists(os.path.join(self.temp_dir, "chroma.sqlite3")):
            print("π Loading existing Chroma vectorstore...")
            self.vectorstore = Chroma(
                persist_directory=self.temp_dir,
                embedding_function=self.embedding
            )
            self.qa_chain = RetrievalQA.from_chain_type(
                llm=self.llm,
                retriever=self.vectorstore.as_retriever(),
                return_source_documents=True
            )
            print("Vectorstore and QA chain restored.")

    def clear_temp(self):
        """Wipe and recreate the Chroma persistence directory."""
        shutil.rmtree(self.temp_dir, ignore_errors=True)
        os.makedirs(self.temp_dir, exist_ok=True)

    def index_pdf(self, pdf_path):
        """Load ``pdf_path``, tag each page with the extracted book title,
        split into overlapping chunks, and add them to the vector store.

        Raises:
            ValueError: if the PDF is a duplicate of an already-indexed book,
                or if it contains no extractable text.
        """
        if self.duplicate_detector.is_duplicate(pdf_path):
            # Bug fix: the original had an unreachable `return` statement
            # after this raise; the dead code is removed.
            raise ValueError(f"duplicate book detected, skipping index of: {pdf_path}")
        self.duplicate_detector.store_fingerprints(pdf_path)
        self.clear_temp()
        loader = PyPDFLoader(pdf_path)
        documents = loader.load()
        # Use the extracted book title (not the file name) as the source
        # label shown alongside answers.
        title = self.title_extractor.extract_book_title_from_documents(documents, max_docs=10)
        for doc in documents:
            doc.metadata["source"] = title
        # Drop pages with no usable text (e.g. scanned images without OCR).
        documents = [doc for doc in documents if doc.page_content.strip()]
        if not documents:
            raise ValueError("No Reasonable text in uploaded pdf")
        splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=500)
        chunks = splitter.split_documents(documents)
        if self.vectorstore is None:
            self.vectorstore = Chroma.from_documents(
                documents=chunks,
                embedding=self.embedding,
                persist_directory=self.temp_dir
            )
            self.vectorstore.persist()
        else:
            self.vectorstore.add_documents(chunks)
        # Rebuild the chain so retrieval runs over the updated store.
        self.qa_chain = RetrievalQA.from_chain_type(
            llm=self.llm,
            retriever=self.vectorstore.as_retriever(),
            return_source_documents=True
        )

    def ask_question(self, question):
        """Answer ``question`` against the indexed documents (blocking).

        Returns:
            dict with "answer" and de-duplicated "sources", or a plain
            string prompt when nothing has been indexed yet.
        """
        print(question)
        if not self.qa_chain:
            return "please upload and index pdf document first"
        # Consistency fix: use .invoke() as ask_question_stream does; calling
        # the chain object directly is the deprecated __call__ form.
        result = self.qa_chain.invoke({"query": question})
        answer = result["result"]
        sources = [
            doc.metadata.get("source", "Unknown")
            for doc in result["source_documents"]
        ]
        print(answer)
        return {
            "answer": answer,
            "sources": list(set(sources))  # Remove duplicates
        }

    def ask_question_stream(self, question: str):
        """Yield answer tokens as the LLM produces them, then the sources.

        The chain runs in a background thread; tokens arrive through the
        streaming handler's callback and are relayed here via a queue. A
        30-second silence aborts the stream.
        """
        if not self.qa_chain:
            yield "β Please upload and index a PDF document first."
            return
        from queue import Queue, Empty
        import threading
        q = Queue()

        def token_callback(token):
            q.put(token)

        self.handler.buffer = []
        self.handler.token_callback = token_callback
        # Bug fix: reset before starting so a timed-out stream never reports
        # sources left over from a previous question.
        self._latest_result = {}

        def run():
            result = self.qa_chain.invoke({"query": question})
            print(result)
            self._latest_result = result
            q.put(None)  # sentinel: stream finished

        # daemon=True (bug fix): a hung LLM call must not block interpreter
        # shutdown via a lingering non-daemon thread.
        threading.Thread(target=run, daemon=True).start()
        print("Threading started", flush=True)
        while True:
            try:
                token = q.get(timeout=30)
                if token is None:
                    print("Stream finished", flush=True)
                    break
                yield token
            except Empty:
                print("Timed out waiting for token", flush=True)
                break
        sources = []
        for doc in self._latest_result.get("source_documents", []):
            sources.append(doc.metadata.get("source", "Unknown"))
        if sources:
            yield "\n\nπ **Sources:**\n"
            for i, src in enumerate(set(sources)):
                yield f"[{i+1}] {src}\n"