# Custom_RAG/rag_engine.py
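"""RAG engine: indexes uploaded PDFs into a Chroma vector store and streams
answers from a local Qwen chat model over the retrieved context."""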
import os
import shutil
from queue import Queue, Empty
from threading import Thread

from langchain_community.document_loaders import PyPDFLoader
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import Chroma
from langchain_community.embeddings import HuggingFaceEmbeddings
from transformers import TextIteratorStreamer, AutoTokenizer, AutoModelForCausalLM

from book_title_extractor import BookTitleExtractor
from duplicate_detector import DuplicateDetector


class StreamingHandler:
    """Collects tokens from the LLM and forwards each one to an optional callback."""

    def __init__(self):
        self.buffer = []
        self.token_callback = None

    def on_llm_new_token(self, token: str, **kwargs):
        self.buffer.append(token)
        if self.token_callback:
            self.token_callback(token)


class RagEngine:
    def __init__(self, persist_dir="chroma_store", embed_model="nomic-embed-text",
                 llm_model="qwen:1.8b", temp_dir="chroma_temp"):
        self.temp_dir = temp_dir
        os.makedirs(self.temp_dir, exist_ok=True)
        self.duplicate_detector = DuplicateDetector()
        self.title_extractor = BookTitleExtractor()
        self.embedding = HuggingFaceEmbeddings(
            model_name="sentence-transformers/all-MiniLM-L6-v2"
        )
        self.vectorstore = None
        self.retriever = None
        # index_pdf() persists into temp_dir, so reload from there rather than
        # from the (currently unused) persist_dir argument.
        self.persist_dir = self.temp_dir
        self._load_vectorstore()
        self.qa_chain = None
        self.handler = StreamingHandler()
        self._latest_result = None
        self.model_id = "Qwen/Qwen-1_8B-Chat"
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_id, trust_remote_code=True)
        self.model = AutoModelForCausalLM.from_pretrained(
            self.model_id,
            trust_remote_code=True,
            device_map="auto",
            torch_dtype="auto",
        )
        self.model.eval()

    def _load_vectorstore(self):
        # Reuse an existing Chroma store if one was persisted on a previous run.
        if os.path.exists(os.path.join(self.persist_dir, "chroma.sqlite3")):
            self.vectorstore = Chroma(
                persist_directory=self.persist_dir,
                embedding_function=self.embedding,
            )
            self.retriever = self.vectorstore.as_retriever()

    def clear_temp(self):
        """Wipe and recreate the temporary Chroma directory."""
        shutil.rmtree(self.temp_dir, ignore_errors=True)
        os.makedirs(self.temp_dir, exist_ok=True)

    def index_pdf(self, pdf_path):
        if self.duplicate_detector.is_duplicate(pdf_path):
            raise ValueError(f"Duplicate book detected, skipping index of: {pdf_path}")
        self.duplicate_detector.store_fingerprints(pdf_path)
        loader = PyPDFLoader(pdf_path)
        documents = loader.load()
        # Tag every page with the extracted book title so answers can cite it.
        title = self.title_extractor.extract_book_title_from_documents(documents, max_docs=10)
        for doc in documents:
            doc.metadata["source"] = title
        documents = [doc for doc in documents if doc.page_content.strip()]
        if not documents:
            raise ValueError("No readable text in uploaded PDF")
        splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=500)
        chunks = splitter.split_documents(documents)
        if self.vectorstore is None:
            # Start from a clean directory only when building a fresh store;
            # clearing it while a store is open would drop previously indexed books.
            self.clear_temp()
            self.vectorstore = Chroma.from_documents(
                documents=chunks,
                embedding=self.embedding,
                persist_directory=self.temp_dir,
            )
        else:
            self.vectorstore.add_documents(chunks)
        self.vectorstore.persist()
        self.retriever = self.vectorstore.as_retriever()

    def stream_answer(self, question):
        if not self.retriever:
            yield "data: ❗ Please upload and index a PDF first.\n\n"
            return
        docs = self.retriever.get_relevant_documents(question)
        if not docs:
            yield "data: ❗ No relevant documents found.\n\n"
            return
        sources = []
        for doc in docs:
            title = doc.metadata.get("source", "Unknown Title")
            page = doc.metadata.get("page", "Unknown Page")
            sources.append(f"{title} - Page {page}")
        context = "\n\n".join(doc.page_content for doc in docs[:3])
        # Qwen chat models expect the ChatML <|im_start|>/<|im_end|> format.
        system_prompt = "You are a helpful assistant that only replies in English."
        user_prompt = f"Context:\n{context}\n\nQuestion: {question}"
        prompt = (
            f"<|im_start|>system\n{system_prompt}<|im_end|>\n"
            f"<|im_start|>user\n{user_prompt}<|im_end|>\n"
            "<|im_start|>assistant\n"
        )
        print(prompt)
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
        print("πŸ”’ Prompt token length:", inputs["input_ids"].shape[-1])
        streamer = TextIteratorStreamer(
            tokenizer=self.tokenizer,
            skip_prompt=True,
            skip_special_tokens=True,
        )
        generation_args = {
            "input_ids": inputs["input_ids"],
            "attention_mask": inputs["attention_mask"],
            "max_new_tokens": 512,
            "streamer": streamer,
            # Greedy decoding; temperature/top_p only apply when do_sample=True.
            "do_sample": False,
        }
        # generate() blocks, so run it on a worker thread and consume the streamer here.
        thread = Thread(target=self.model.generate, kwargs=generation_args)
        thread.start()
        collected_tokens = []
        for token in streamer:
            if token.strip():  # Filter out pure-whitespace chunks
                collected_tokens.append(token)
                yield f"{token} "
        if sources:
            sources_text = "\n\nπŸ“š **Sources:**\n" + "\n".join(set(sources))
            for line in sources_text.splitlines():
                if line.strip():
                    yield f"{line} \n"
        yield "\n\n"

    def ask_question(self, question):
        print(question)
        if not self.qa_chain:
            return "Please upload and index a PDF document first"
        result = self.qa_chain.invoke({"query": question})
        answer = result["result"]
        sources = []
        for doc in result["source_documents"]:
            source = doc.metadata.get("source", "Unknown")
            sources.append(source)
        print(answer)
        return {
            "answer": answer,
            "sources": list(set(sources)),  # Remove duplicates
        }

    def ask_question_stream(self, question: str):
        if not self.qa_chain:
            yield "❗ Please upload and index a PDF document first."
            return
        q = Queue()

        def token_callback(token):
            q.put(token)

        self.handler.buffer = []
        self.handler.token_callback = token_callback

        def run():
            result = self.qa_chain.invoke({"query": question})
            print(result)
            self._latest_result = result
            q.put(None)  # Sentinel: generation finished

        Thread(target=run).start()
        print("Threading started", flush=True)
        while True:
            try:
                token = q.get(timeout=30)
                if token is None:
                    print("Stream finished", flush=True)
                    break
                yield token
            except Empty:
                print("Timed out waiting for token", flush=True)
                break
        sources = []
        for doc in (self._latest_result or {}).get("source_documents", []):
            source = doc.metadata.get("source", "Unknown")
            sources.append(source)
        if sources:
            yield "\n\nπŸ“š **Sources:**\n"
            for i, src in enumerate(set(sources)):
                yield f"[{i + 1}] {src}\n"