# Spaces: Sleeping  (status banner captured from the hosting web page; not part of the program)
import functools
import os

import gradio as gr
import torch
from chromadb.config import Settings
from langchain.chains import RetrievalQA
from langchain.document_loaders import PDFMinerLoader
from langchain.embeddings import SentenceTransformerEmbeddings
from langchain.llms.huggingface_pipeline import HuggingFacePipeline
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import Chroma
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, pipeline
# Chroma client configuration, constructed a single time at import and reused
# wherever the on-disk store location or client settings are needed.
CHROMA_SETTINGS = Settings(
    persist_directory="db",
    chroma_db_impl='duckdb+parquet',
    anonymized_telemetry=False,
)
def init_db_if_not_exists(pdf_path):
    """Populate the persisted Chroma store from a PDF unless it already has data.

    The store lives in ``CHROMA_SETTINGS.persist_directory``. If it already
    contains at least one stored entry, nothing is done (the store is
    initialized only once, so a later, different PDF is NOT added). Otherwise
    the PDF is loaded, split into 500-character chunks with a 50-character
    overlap, embedded with ``all-MiniLM-L6-v2`` and persisted.

    Args:
        pdf_path: Filesystem path of the PDF to ingest.
    """
    persist_dir = CHROMA_SETTINGS.persist_directory

    # Probe by content: ask the store for its stored ids. The previous probe
    # called a method the LangChain Chroma wrapper does not provide, so the
    # resulting error was swallowed and the PDF was re-ingested on every call.
    probe = Chroma(persist_directory=persist_dir, client_settings=CHROMA_SETTINGS)
    already_populated = bool(probe.get()["ids"])
    # Release the probe client before a second client writes to the same
    # directory, so only one client is alive while data is being persisted.
    del probe
    if already_populated:
        return

    documents = PDFMinerLoader(pdf_path).load()
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
    texts = text_splitter.split_documents(documents)
    embeddings = SentenceTransformerEmbeddings(model_name="all-MiniLM-L6-v2")
    db = Chroma.from_documents(texts, embeddings, persist_directory=persist_dir)
    db.persist()
# Model setup happens once at import: fetch the tokenizer and seq2seq weights
# for the checkpoint, then wrap a text2text-generation pipeline so LangChain
# chains can use it as their LLM.
checkpoint = "MBZUAI/LaMini-Flan-T5-783M"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
base_model = AutoModelForSeq2SeqLM.from_pretrained(
    checkpoint,
    device_map="auto",
    torch_dtype=torch.float32,
)
llm_pipeline = HuggingFacePipeline(
    pipeline=pipeline(
        "text2text-generation",
        model=base_model,
        tokenizer=tokenizer,
    )
)
@functools.lru_cache(maxsize=None)
def _load_embeddings(model_name):
    """Return the embeddings object for *model_name*, instantiating it only once."""
    return SentenceTransformerEmbeddings(model_name=model_name)


def process_answer(instruction):
    """Answer a question with retrieval-augmented generation over the stored PDF.

    Args:
        instruction: The question text passed to the RetrievalQA chain.

    Returns:
        The ``"result"`` entry of the chain's output.
    """
    # Only the embedding model is cached; the vector store is still reopened
    # on every call, exactly as before.
    embeddings = _load_embeddings("all-MiniLM-L6-v2")
    vectordb = Chroma(
        persist_directory=CHROMA_SETTINGS.persist_directory,
        embedding_function=embeddings,
    )
    qa = RetrievalQA.from_chain_type(
        llm=llm_pipeline,
        chain_type="stuff",
        retriever=vectordb.as_retriever(),
    )
    generated_text = qa(instruction)
    return generated_text["result"]
def chatbot(pdf_file, user_question):
    """Gradio handler: index the uploaded PDF (if any), then answer the question.

    Args:
        pdf_file: Uploaded file object whose ``.name`` is passed on as the PDF
            path, or a falsy value when nothing was uploaded.
        user_question: The question text to answer.

    Returns:
        The generated answer, or ``"An error occurred: ..."`` when either
        indexing the PDF or answering the question raises.
    """
    # Indexing sits inside the try block so a failure while reading the PDF is
    # reported the same way as a failure while answering, instead of escaping
    # to the UI as a raw exception.
    try:
        if pdf_file:
            init_db_if_not_exists(pdf_file.name)
        return process_answer(user_question)
    except Exception as e:
        return f"An error occurred: {str(e)}"
# Gradio UI: a PDF upload plus a single-line question box are passed to
# `chatbot`; its return value is shown as plain text.
iface = gr.Interface(
    fn=chatbot,
    inputs=[
        gr.inputs.File(type="file", label="Upload your PDF"),
        gr.inputs.Textbox(lines=1, label="Ask a Question"),
    ],
    outputs="text",
    title="PDF Chatbot",
    description="Upload a PDF and ask questions about its content.",
)

# Start serving the interface.
iface.launch()