# Custom_RAG/rag_engine.py
import os
import shutil
import threading
from queue import Queue, Empty

from langchain_community.document_loaders import PyPDFLoader
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_ollama import OllamaEmbeddings
from langchain_community.vectorstores import Chroma
from langchain.chains import RetrievalQA
from langchain_core.callbacks.base import BaseCallbackHandler
from langchain_community.chat_models import ChatOllama

from book_title_extractor import BookTitleExtractor
from duplicate_detector import DuplicateDetector


class StreamingHandler(BaseCallbackHandler):
    """Collects LLM tokens as they arrive and forwards them to an optional callback."""

    def __init__(self):
        self.buffer = []
        self.token_callback = None

    def on_llm_new_token(self, token: str, **kwargs):
        self.buffer.append(token)
        if self.token_callback:
            self.token_callback(token)


class RagEngine:
    def __init__(self, embed_model="nomic-embed-text", llm_model="qwen:1.8b", temp_dir="chroma_temp"):
        self.embed_model = embed_model
        self.llm_model = llm_model
        self.embedding = OllamaEmbeddings(model=self.embed_model)
        self.vectorstore = None
        self.qa_chain = None
        self.handler = StreamingHandler()
        self.llm = ChatOllama(model=self.llm_model, streaming=True, callbacks=[self.handler])
        self.temp_dir = temp_dir
        os.makedirs(self.temp_dir, exist_ok=True)
        self.title_extractor = BookTitleExtractor(llm=self.llm)
        self.duplicate_detector = DuplicateDetector()
        self._latest_result = {}

        # Restore a previously persisted Chroma index if one exists on disk.
        if os.path.exists(os.path.join(self.temp_dir, "chroma.sqlite3")):
            print("🔄 Loading existing Chroma vectorstore...")
            self.vectorstore = Chroma(
                persist_directory=self.temp_dir,
                embedding_function=self.embedding
            )
            self.qa_chain = RetrievalQA.from_chain_type(
                llm=self.llm,
                retriever=self.vectorstore.as_retriever(),
                return_source_documents=True
            )
            print("Vectorstore and QA chain restored.")

    def clear_temp(self):
        shutil.rmtree(self.temp_dir, ignore_errors=True)
        os.makedirs(self.temp_dir, exist_ok=True)

    def index_pdf(self, pdf_path):
        if self.duplicate_detector.is_duplicate(pdf_path):
            raise ValueError(f"Duplicate book detected, skipping index of: {pdf_path}")
        self.duplicate_detector.store_fingerprints(pdf_path)
        # Only wipe the persist directory when no store is attached yet;
        # clearing it under an open Chroma instance would corrupt the index.
        if self.vectorstore is None:
            self.clear_temp()
        loader = PyPDFLoader(pdf_path)
        documents = loader.load()
        title = self.title_extractor.extract_book_title_from_documents(documents, max_docs=10)
        for doc in documents:
            doc.metadata["source"] = title
        documents = [doc for doc in documents if doc.page_content.strip()]
        if not documents:
            raise ValueError("No readable text in the uploaded PDF")
        splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=500)
        chunks = splitter.split_documents(documents)
        if self.vectorstore is None:
            self.vectorstore = Chroma.from_documents(
                documents=chunks,
                embedding=self.embedding,
                persist_directory=self.temp_dir
            )
            self.vectorstore.persist()
        else:
            self.vectorstore.add_documents(chunks)
        self.qa_chain = RetrievalQA.from_chain_type(
            llm=self.llm,
            retriever=self.vectorstore.as_retriever(),
            return_source_documents=True
        )

    def ask_question(self, question):
        if not self.qa_chain:
            return {"answer": "Please upload and index a PDF document first.", "sources": []}
        result = self.qa_chain.invoke({"query": question})
        answer = result["result"]
        sources = []
        for doc in result["source_documents"]:
            source = doc.metadata.get("source", "Unknown")
            sources.append(source)
        return {
            "answer": answer,
            "sources": list(set(sources))  # Remove duplicates
        }

    def ask_question_stream(self, question: str):
        if not self.qa_chain:
            yield "❗ Please upload and index a PDF document first."
            return
        q = Queue()

        def token_callback(token):
            q.put(token)

        self.handler.buffer = []
        self.handler.token_callback = token_callback
        self._latest_result = {}

        def run():
            result = self.qa_chain.invoke({"query": question})
            self._latest_result = result
            q.put(None)  # Sentinel: signals the end of the stream.

        # Daemon thread so a hung generation cannot block interpreter exit.
        threading.Thread(target=run, daemon=True).start()
        while True:
            try:
                token = q.get(timeout=30)
                if token is None:
                    break
                yield token
            except Empty:
                print("Timed out waiting for token", flush=True)
                break
        sources = []
        for doc in self._latest_result.get("source_documents", []):
            source = doc.metadata.get("source", "Unknown")
            sources.append(source)
        if sources:
            yield "\n\n📚 **Sources:**\n"
            for i, src in enumerate(set(sources)):
                yield f"[{i+1}] {src}\n"