PDFgpt / app.py
import gradio as gr
import os
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
from langchain.embeddings import SentenceTransformerEmbeddings
from langchain.vectorstores import Chroma
from langchain.llms.huggingface_pipeline import HuggingFacePipeline
from langchain.chains import RetrievalQA
from langchain.document_loaders import PDFMinerLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from chromadb.config import Settings
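
# NOTE: these imports target the legacy langchain (<0.1) and chromadb (<0.4) APIs;
# on newer releases the same classes live in langchain_community, and the
# duckdb+parquet backend setting below is no longer accepted. Assumed additional
# dependencies: torch, transformers, accelerate, sentence-transformers, pdfminer.six.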
# Initialize Chroma settings once
CHROMA_SETTINGS = Settings(
    chroma_db_impl='duckdb+parquet',
    persist_directory="db",
    anonymized_telemetry=False,
)
# Build the Chroma index on first upload (the database is persisted and reused afterwards)
def init_db_if_not_exists(pdf_path):
    # A populated persist directory means the index was already built; checking the
    # directory avoids probing for a collection, which the langchain Chroma wrapper
    # does not expose as a no-argument get_collection().
    if os.path.exists(CHROMA_SETTINGS.persist_directory) and os.listdir(CHROMA_SETTINGS.persist_directory):
        return
    # Load the PDF, split it into overlapping chunks, embed them, and persist the index
    loader = PDFMinerLoader(pdf_path)
    documents = loader.load()
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
    texts = text_splitter.split_documents(documents)
    embeddings = SentenceTransformerEmbeddings(model_name="all-MiniLM-L6-v2")
    db = Chroma.from_documents(texts, embeddings, persist_directory=CHROMA_SETTINGS.persist_directory)
    db.persist()
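
# Background: LaMini-Flan-T5-783M is an instruction-tuned Flan-T5 variant
# (~783M parameters) released by MBZUAI; loading it in torch.float32 keeps it
# CPU-friendly on hosts without a GPU.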
# Load the model and create the generation pipeline once at startup
checkpoint = "MBZUAI/LaMini-Flan-T5-783M"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
base_model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint, device_map="auto", torch_dtype=torch.float32)
# max_length=512 is an assumed generation budget; the pipeline default (~20 tokens) truncates answers
pipe = pipeline("text2text-generation", model=base_model, tokenizer=tokenizer, max_length=512)
llm_pipeline = HuggingFacePipeline(pipeline=pipe)
def process_answer(instruction):
    # Re-open the persisted vector store with the same embedding model used to build it
    embeddings = SentenceTransformerEmbeddings(model_name="all-MiniLM-L6-v2")
    vectordb = Chroma(persist_directory=CHROMA_SETTINGS.persist_directory, embedding_function=embeddings)
    retriever = vectordb.as_retriever()
    # "stuff" concatenates the retrieved chunks into a single prompt for the LLM
    qa = RetrievalQA.from_chain_type(llm=llm_pipeline, chain_type="stuff", retriever=retriever)
    generated_text = qa(instruction)
    return generated_text["result"]
def chatbot(pdf_file, user_question):
    # pdf_file is a filepath string (see gr.File(type="filepath") below)
    if pdf_file:  # Only (re)index when a PDF has been uploaded
        init_db_if_not_exists(pdf_file)
    try:
        return process_answer(user_question)
    except Exception as e:
        return f"An error occurred: {e}"
# Create the Gradio interface (the gr.inputs.* namespace was removed in Gradio 4)
iface = gr.Interface(
    fn=chatbot,
    inputs=[
        gr.File(type="filepath", label="Upload your PDF"),
        gr.Textbox(lines=1, label="Ask a Question"),
    ],
    outputs="text",
    title="PDF Chatbot",
    description="Upload a PDF and ask questions about its content.",
)
# Run the Gradio interface
if __name__ == "__main__":
    iface.launch()
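
# Assumed local usage: run `python app.py`, then open the URL Gradio prints
# (typically http://127.0.0.1:7860), upload a PDF, and submit a question.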