Pdf_mugger / document_chat.py
Anirudh1993's picture
Update document_chat.py
c5844af verified
raw
history blame
2.16 kB
import os
from langchain.vectorstores import Chroma
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.document_loaders import PyMuPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.chains import ConversationalRetrievalChain
from langchain.memory import ConversationBufferMemory
from langchain.llms import HuggingFaceHub
from langchain.chains import ConversationalRetrievalChain
from langchain.chains.question_answering import load_qa_chain
from langchain.llms import HuggingFaceHub
from langchain.memory import ConversationBufferMemory
# Module-level configuration

# Hugging Face Hub repo id of the LLM used to answer questions.
LLM_MODEL: str = "HuggingFaceH4/zephyr-7b-beta"

# Model name handed to HuggingFaceEmbeddings for embedding text.
SENTENCE_TRANSFORMER_MODEL: str = "sentence-transformers/all-MiniLM-L6-v2"

# Directory passed to Chroma as its persist directory.
CHROMA_DB_PATH: str = "chroma_db"
# Vector store setup
def initialize_vector_store():
    """Create the Chroma vector store used by this module.

    Returns:
        A ``Chroma`` instance configured with ``CHROMA_DB_PATH`` as its
        persist directory and a ``HuggingFaceEmbeddings`` object (built from
        ``SENTENCE_TRANSFORMER_MODEL``) as its embedding function.
    """
    embedding_model = HuggingFaceEmbeddings(model_name=SENTENCE_TRANSFORMER_MODEL)
    store = Chroma(
        persist_directory=CHROMA_DB_PATH,
        embedding_function=embedding_model,
    )
    return store


# Shared store instance; note that it is constructed when the module is
# imported, not lazily on first use.
vector_store = initialize_vector_store()
def ingest_pdf(pdf_path):
    """Load a PDF, split it into chunks, and store them in the vector database.

    Args:
        pdf_path: Location of the PDF, passed unchanged to ``PyMuPDFLoader``.

    Returns:
        None. Chunks are added to the module-level ``vector_store``, which is
        then persisted. If splitting yields no chunks, the store is left
        untouched.
    """
    pages = PyMuPDFLoader(pdf_path).load()

    # Split text into smaller chunks (1000 characters with 100 overlap).
    splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
    chunks = splitter.split_documents(pages)

    # Nothing to index: skip the write instead of sending an empty batch to
    # the store.
    # NOTE(review): presumably this happens for PDFs with no extractable text
    # (e.g. scanned images), and some vector-store backends reject empty
    # batches -- confirm against the installed Chroma version.
    if not chunks:
        return

    # Add to the existing shared store and flush it to disk.
    vector_store.add_documents(chunks)
    vector_store.persist()
def process_query_with_memory(query, chat_memory):
    """Answer a user query against the vector store, with conversational memory.

    Args:
        query: The user's question; sent to the chain under the ``"question"``
            key.
        chat_memory: Memory object handed to the chain as ``memory`` and read
            through ``load_memory_variables``. May be ``None``, in which case
            the query is answered with an empty chat history.

    Returns:
        Whatever ``ConversationalRetrievalChain.run`` returns for the query.
    """
    # NOTE(review): the LLM client and the chain are rebuilt on every call;
    # if that proves costly, consider constructing them once and reusing them.
    retriever = vector_store.as_retriever()
    llm = HuggingFaceHub(repo_id=LLM_MODEL, model_kwargs={"max_new_tokens": 500})
    conversation_chain = ConversationalRetrievalChain.from_llm(
        llm=llm,
        retriever=retriever,
        memory=chat_memory,
    )

    # Read prior turns from memory. Without a memory object there is no
    # history to read, so fall back to an empty one rather than failing on
    # ``None.load_memory_variables``.
    # NOTE(review): this reads the "chat_history" key, so the caller's memory
    # presumably needs memory_key="chat_history" -- verify against the caller.
    if chat_memory is None:
        chat_history = []
    else:
        chat_history = chat_memory.load_memory_variables({}).get("chat_history", [])

    return conversation_chain.run({"question": query, "chat_history": chat_history})