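"""PDF Chatbot: a Gradio app that indexes an uploaded PDF into a persisted
Chroma vector store with SentenceTransformer embeddings, then answers
questions about it through a LaMini-Flan-T5 RetrievalQA chain."""
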
import gradio as gr
import os
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
from langchain.embeddings import SentenceTransformerEmbeddings
from langchain.vectorstores import Chroma
from langchain.llms.huggingface_pipeline import HuggingFacePipeline
from langchain.chains import RetrievalQA
from langchain.document_loaders import PDFMinerLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from chromadb.config import Settings

# Chroma settings, created once and shared. The duckdb+parquet backend is the
# pre-0.4 chromadb API, so this script assumes chromadb < 0.4 (matching the
# classic LangChain import paths above).
CHROMA_SETTINGS = Settings(
    chroma_db_impl='duckdb+parquet',
    persist_directory="db",
    anonymized_telemetry=False
)

# Build the Chroma index for the uploaded PDF unless a persisted collection
# already holds documents. (LangChain's Chroma wrapper has no get_collection()
# method; db.get() returns the stored records, so empty "ids" means no index yet.)
def init_db_if_not_exists(pdf_path):
    db = Chroma(persist_directory=CHROMA_SETTINGS.persist_directory, client_settings=CHROMA_SETTINGS)
    if db.get()["ids"]:  # a non-empty collection was persisted earlier; nothing to rebuild
        return
    # No collection yet: load the PDF, split it into overlapping chunks, embed, and persist.
    loader = PDFMinerLoader(pdf_path)
    documents = loader.load()
    # 500-character chunks with 50 characters of overlap balance retrieval precision and context.
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
    texts = text_splitter.split_documents(documents)
    embeddings = SentenceTransformerEmbeddings(model_name="all-MiniLM-L6-v2")
    db = Chroma.from_documents(texts, embeddings, persist_directory=CHROMA_SETTINGS.persist_directory)
    db.persist()

# Load the model and build the generation pipeline once at startup so every
# request reuses the same weights. device_map="auto" places the model on a GPU
# when one is available (this requires the `accelerate` package).
checkpoint = "MBZUAI/LaMini-Flan-T5-783M"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
base_model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint, device_map="auto", torch_dtype=torch.float32)
llm_pipeline = HuggingFacePipeline(pipeline=pipeline("text2text-generation", model=base_model, tokenizer=tokenizer))

def process_answer(instruction):
    # Reopen the persisted store on each call so freshly indexed PDFs are picked up.
    embeddings = SentenceTransformerEmbeddings(model_name="all-MiniLM-L6-v2")
    vectordb = Chroma(persist_directory=CHROMA_SETTINGS.persist_directory, embedding_function=embeddings)
    retriever = vectordb.as_retriever()
    # "stuff" simply concatenates the retrieved chunks into the prompt context.
    qa = RetrievalQA.from_chain_type(llm=llm_pipeline, chain_type="stuff", retriever=retriever)
    generated_text = qa(instruction)
    return generated_text["result"]

def chatbot(pdf_file, user_question):
    if pdf_file:  # pdf_file is a filepath string; indexing is a no-op once a collection is persisted
        init_db_if_not_exists(pdf_file)
    try:
        answer = process_answer(user_question)
        return answer
    except Exception as e:
        return f"An error occurred: {str(e)}"

# Create the Gradio interface (the gr.inputs.* namespace was removed in Gradio 4;
# use the top-level components instead).
iface = gr.Interface(
    fn=chatbot,
    inputs=[
        gr.File(type="filepath", label="Upload your PDF"),
        gr.Textbox(lines=1, label="Ask a Question"),
    ],
    outputs="text",
    title="PDF Chatbot",
    description="Upload a PDF and ask questions about its content.",
)

# Run the Gradio interface
iface.launch()
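
# Launched this way, Gradio serves locally (http://127.0.0.1:7860 by default);
# pass share=True to iface.launch() to get a temporary public link.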