Spaces:
Runtime error
Runtime error
| %%writefile app.py | |
| import gradio as gr | |
| from unstructured.partition.pdf import partition_pdf | |
| import pymupdf | |
| from PIL import Image | |
| import numpy as np | |
| import io | |
| import pandas as pd | |
| from langchain.text_splitter import RecursiveCharacterTextSplitter | |
| import gc | |
| import torch | |
| import chromadb | |
| from chromadb.utils.embedding_functions import OpenCLIPEmbeddingFunction | |
| from chromadb.utils.data_loaders import ImageLoader | |
| from sentence_transformers import SentenceTransformer | |
| from chromadb.utils import embedding_functions | |
| from transformers import LlavaNextProcessor, LlavaNextForConditionalGeneration | |
| import base64 | |
| from langchain_community.llms import HuggingFaceEndpoint | |
| from langchain import PromptTemplate | |
| import spaces | |
# Vision-language model used to caption the images pulled out of the PDFs.
# It is only loaded when a CUDA device is present; without one, `processor`
# and `vision_model` are left undefined.
_LLAVA_CHECKPOINT = "llava-hf/llava-v1.6-mistral-7b-hf"

if torch.cuda.is_available():
    processor = LlavaNextProcessor.from_pretrained(_LLAVA_CHECKPOINT)
    vision_model = LlavaNextForConditionalGeneration.from_pretrained(
        _LLAVA_CHECKPOINT,
        load_in_4bit=True,
        low_cpu_mem_usage=True,
        torch_dtype=torch.float16,
    )
def image_to_bytes(image):
    """Serialize an image as PNG and return it as base64 text.

    Despite the name, the result is a ``str`` (base64, decoded as UTF-8),
    not a ``bytes`` object.

    Args:
        image: Object exposing ``save(fp, format=...)``, e.g. a PIL image.

    Returns:
        str: Base64 encoding of the PNG data written by ``image.save``.
    """
    with io.BytesIO() as buffer:
        image.save(buffer, format="PNG")
        png_payload = buffer.getvalue()
    encoded = base64.b64encode(png_payload)
    return encoded.decode("utf-8")
def get_image_descriptions(images):
    """Generate a one-sentence caption for each image with the LLaVA model.

    Args:
        images: Iterable of images accepted by the LLaVA processor
            (PIL images elsewhere in this file).

    Returns:
        list[str]: One caption per input image, in input order. Only the
        newly generated text is returned; the instruction prompt is not
        included in the caption.

    Raises:
        gr.Error: If there are images to describe but no CUDA device is
            available (the model is only loaded when CUDA is present).
    """
    images = list(images)
    if not images:
        # Nothing to caption: do not touch the GPU or the model at all.
        return []
    if not torch.cuda.is_available():
        raise gr.Error("A CUDA GPU is required to generate image descriptions")

    torch.cuda.empty_cache()
    gc.collect()

    prompt = "[INST] <image>\nDescribe the image in a sentence [/INST]"
    descriptions = []
    for img in images:
        # Keyword arguments keep the call independent of the processor's
        # positional parameter order.
        inputs = processor(text=prompt, images=img, return_tensors="pt").to("cuda:0")
        output = vision_model.generate(**inputs, max_new_tokens=100)
        # generate() returns the prompt tokens followed by the continuation;
        # slice the prompt off so the caption does not start with the
        # instruction text.
        prompt_length = inputs["input_ids"].shape[1]
        caption = processor.decode(output[0][prompt_length:], skip_special_tokens=True)
        descriptions.append(caption.strip())
    return descriptions
# Custom stylesheet passed to the Gradio Blocks app.
CSS = "\n#table_col {background-color: rgb(33, 41, 54);}\n"
def extract_pdfs(docs, doc_collection):
    """Register the uploaded PDFs and switch the UI to the next tab.

    Args:
        docs: File paths coming from the upload widget; ``None`` or empty
            when nothing was uploaded.
        doc_collection: Previously stored list of paths. It is replaced by
            the new upload, not appended to.

    Returns:
        tuple: ``(paths, tabs_update, table)`` where ``paths`` is the new
        list of document paths, ``tabs_update`` selects tab ``1`` and
        ``table`` is a one-column ``Filename`` DataFrame.

    Raises:
        gr.Error: If no document was uploaded (previously this surfaced as
            a ``TypeError`` from iterating ``None``).
    """
    if not docs:
        raise gr.Error("Please upload at least one PDF document")

    doc_collection = list(docs)
    filenames = [path.split("/")[-1] for path in doc_collection]
    return (
        doc_collection,
        gr.Tabs(selected=1),
        pd.DataFrame(filenames, columns=["Filename"]),
    )
def extract_images(docs):
    """Extract every embedded image from the given PDF files.

    Args:
        docs: Iterable of PDF file paths.

    Returns:
        list: PIL images, one per image reference found, ordered by
        document, then page, then position in the page's image list.
    """
    images = []
    for doc_path in docs:
        # The context manager closes the document even if a page fails.
        with pymupdf.open(doc_path) as doc:
            for page in doc:
                for img in page.get_images():
                    xref = img[0]  # first item is the XREF of the image
                    pix = pymupdf.Pixmap(doc, xref)
                    if pix.n - pix.alpha > 3:  # CMYK: convert to RGB first
                        pix = pymupdf.Pixmap(pymupdf.csRGB, pix)
                    if pix.alpha:
                        # JPEG has no alpha channel; drop it before encoding.
                        pix = pymupdf.Pixmap(pix, 0)
                    jpeg_bytes = pix.pil_tobytes("JPEG")
                    images.append(Image.open(io.BytesIO(jpeg_bytes)))
    return images
| # def get_vectordb(text, images, tables): | |
def get_vectordb(text, images):
    """Build the Chroma collections holding text chunks and image captions.

    Any existing ``text_db`` / ``image_db`` collection is deleted first, so
    each call starts from empty collections.

    Args:
        text: Full extracted document text; split into 500-character chunks
            (overlap 10) before insertion.
        images: List of images. Each is captioned with
            ``get_image_descriptions`` and stored with its base64 PNG
            encoding under the ``"image"`` metadata key.

    Returns:
        The Chroma client that owns both collections.
    """
    client = chromadb.EphemeralClient()
    loader = ImageLoader()
    sentence_transformer_ef = embedding_functions.SentenceTransformerEmbeddingFunction(
        model_name="multi-qa-mpnet-base-dot-v1"
    )

    # Drop collections left over from a previous run.
    existing_names = {collection.name for collection in client.list_collections()}
    for stale_name in ("text_db", "image_db"):
        if stale_name in existing_names:
            client.delete_collection(stale_name)

    text_collection = client.get_or_create_collection(
        name="text_db",
        embedding_function=sentence_transformer_ef,
        data_loader=loader,
    )
    image_collection = client.get_or_create_collection(
        name="image_db",
        embedding_function=sentence_transformer_ef,
        data_loader=loader,
        metadata={"hnsw:space": "cosine"},
    )

    if images:
        image_descriptions = get_image_descriptions(images)
        # One metadata dict per image, so ids, documents and metadatas all
        # have the same length.
        image_metadatas = [{"image": image_to_bytes(img)} for img in images]
        image_collection.add(
            ids=[str(i) for i in range(len(images))],
            documents=image_descriptions,
            metadatas=image_metadatas,
        )

    if text:
        splitter = RecursiveCharacterTextSplitter(
            chunk_size=500,
            chunk_overlap=10,
        )
        doc_texts = [chunk.page_content for chunk in splitter.create_documents([text])]
        if doc_texts:  # nothing to insert if the splitter yields no chunks
            text_collection.add(
                ids=[str(i) for i in range(len(doc_texts))],
                documents=doc_texts,
            )
    return client
def extract_data_from_pdfs(docs, session, progress=gr.Progress()):
    """Extract images and text from the PDFs and index them in the vector DB.

    Args:
        docs: List of PDF file paths to process.
        session: Per-session state dict; ``session["processed"]`` is set to
            ``True`` on success.
        progress: Gradio progress tracker used to report the current step.

    Returns:
        tuple: ``(vectordb, session, column_update, text_preview,
        image_preview, status_html)`` where ``text_preview`` is the first
        2000 characters of the extracted text followed by ``"..."`` and
        ``image_preview`` holds at most the first two extracted images.

    Raises:
        gr.Error: If ``docs`` is empty.
    """
    if not docs:
        raise gr.Error("No documents to process")

    progress(0, "Extracting Images")
    images = extract_images(docs)

    progress(0.25, "Extracting Text")
    all_elements = []
    for doc in docs:
        all_elements.extend(
            partition_pdf(
                filename=doc,
                strategy="hi_res",
                infer_table_structure=True,
                model_name="yolox",
            )
        )

    # Tables and figure captions are left out of the text. List items and
    # titles are set off on their own lines; everything else is appended
    # as-is.
    text_parts = []
    for element in all_elements:
        meta = element.to_dict()
        element_type = meta["type"].lower()
        if element_type in ("table", "figurecaption"):
            continue
        if element_type in ("listitem", "title"):
            text_parts.append("\n\n" + meta["text"] + "\n")
        else:
            text_parts.append(meta["text"])
    all_text = "".join(text_parts)

    # get_vectordb() captions the images itself, so they are not captioned
    # here as well; doing both would run the vision model twice per image.
    progress(0.5, "Generating image descriptions and inserting data into vector database")
    vectordb = get_vectordb(all_text, images)

    progress(1, "Completed")
    session["processed"] = True
    return (
        vectordb,
        session,
        # The output component this updates is declared as a gr.Column.
        gr.Column(visible=True),
        all_text[:2000] + "...",
        images[:2],
        "<h1 style='text-align: center'>Completed</h1>",
    )
# Module-level embedding function used by `conversation` when it looks up
# the text and image collections.
_QUERY_EMBEDDING_MODEL = "multi-qa-mpnet-base-dot-v1"
sentence_transformer_ef = embedding_functions.SentenceTransformerEmbeddingFunction(
    model_name=_QUERY_EMBEDDING_MODEL,
)
def conversation(vectordb_client, msg, num_context, img_context, history):
    """Answer a question from retrieved text chunks and image captions.

    Args:
        vectordb_client: Chroma client holding the ``text_db`` and
            ``image_db`` collections, or ``None`` if nothing was indexed.
        msg: The user's question.
        num_context: Number of text chunks to retrieve.
        img_context: Number of images to retrieve.
        history: Chat history as a list of ``(question, answer)`` pairs.

    Returns:
        tuple: ``(history, text_chunks, images)`` where ``history`` has the
        new exchange appended, ``text_chunks`` are the retrieved text
        chunks and ``images`` are the retrieved images decoded to PIL.

    Raises:
        gr.Error: If no vector database has been built yet.
    """
    if vectordb_client is None:
        raise gr.Error("Please extract data first")

    text_collection = vectordb_client.get_collection(
        "text_db", embedding_function=sentence_transformer_ef
    )
    image_collection = vectordb_client.get_collection(
        "image_db", embedding_function=sentence_transformer_ef
    )

    # Chroma expects integer result counts; cast the slider values defensively.
    results = text_collection.query(
        query_texts=[msg], include=["documents"], n_results=int(num_context)
    )["documents"][0]
    similar_images = image_collection.query(
        query_texts=[msg],
        include=["metadatas", "distances", "documents"],
        n_results=int(img_context),
    )

    # Images are stored base64-encoded under the "image" metadata key.
    retrieved_images = [
        Image.open(io.BytesIO(base64.b64decode(meta["image"])))
        for meta in similar_images["metadatas"][0]
    ]
    if retrieved_images:
        img_desc = "\n".join(similar_images["documents"][0])
    else:
        img_desc = "No Images Are Provided"

    template = (
        "\n"
        "Context:\n"
        "{context}\n"
        "Included Images:\n"
        "{images}\n"
        "Question:\n"
        "{question}\n"
        "Answer:\n"
    )
    # Every placeholder used in the template is declared, including "images".
    prompt = PromptTemplate(
        template=template,
        input_variables=["context", "images", "question"],
    )
    context = "\n\n".join(results)
    response = llm.invoke(
        prompt.format(context=context, question=msg, images=img_desc)
    )
    return (history or []) + [(msg, response)], results, retrieved_images
def check_validity_and_llm(session_states):
    """Move to the chat tab once the documents have been processed.

    Args:
        session_states: Per-session state dict; its ``"processed"`` entry
            marks that data extraction has completed.

    Returns:
        A ``gr.Tabs`` update selecting tab ``2``.

    Raises:
        gr.Error: If ``"processed"`` is missing or falsy.
    """
    if not session_states.get("processed", False):
        raise gr.Error("Please extract data first")
    return gr.Tabs(selected=2)
def get_stats(vectordb):
    """Return a short textual summary for ``vectordb`` plus two empty strings.

    NOTE(review): this looks like unfinished placeholder code (the second
    summary line is the literal "HIII") and no caller is visible in this
    file. Confirm before relying on it.

    Args:
        vectordb: Object exposing a ``get()`` method whose result supports
            ``len()``.

    Returns:
        tuple[str, str, str]: The summary text followed by two empty strings.
    """
    entries = vectordb.get()
    summary_lines = ["Chunks: " + str(len(entries)), "HIII"]
    summary = "\n".join(summary_lines)
    return summary, "", ""
# Text-generation endpoint used by `conversation` to answer questions.
llm = HuggingFaceEndpoint(
    repo_id="mistralai/Mixtral-8x7B-Instruct-v0.1",
    temperature=0.4,
    max_new_tokens=800,
)

# UI definition: three tabs (upload -> extract -> chat) plus event wiring.
# NOTE(review): the source this was recovered from had lost its indentation,
# so the nesting of rows and columns below is reconstructed. Confirm the
# layout against the running app.
with gr.Blocks(css=CSS) as demo:
    # Per-session state shared between the event handlers.
    vectordb = gr.State()
    doc_collection = gr.State(value=[])
    session_states = gr.State(value={})
    references = gr.State(value=[])

    gr.Markdown(
        """<h2><center>Multimodal PDF Chatbot</center></h2>
<h3><center><b>Interact With Your PDF Documents</b></center></h3>"""
    )
    gr.Markdown(
        """<center><h3><b>Note: </b> This application leverages advanced Retrieval-Augmented Generation (RAG) techniques to provide context-aware responses from your PDF documents</h3></center><br>
<center>Utilizing multimodal capabilities, this chatbot can interpret and answer queries based on both textual and visual information within your PDFs.</center>"""
    )
    gr.Markdown(
        """
<center><b>Warning: </b> Extracting text and images from your document and generating embeddings may take some time due to the use of OCR and multimodal LLMs for image description</center>
"""
    )

    with gr.Tabs() as tabs:
        with gr.TabItem("Upload PDFs", id=0) as pdf_tab:
            with gr.Row():
                with gr.Column():
                    documents = gr.File(
                        file_count="multiple",
                        # Extensions need the leading dot to be treated as such.
                        file_types=[".pdf"],
                        interactive=True,
                        label="Upload your PDF file/s",
                    )
                    pdf_btn = gr.Button(value="Next", elem_id="button1")

        with gr.TabItem("Extract Data", id=1) as preprocess:
            with gr.Row():
                with gr.Column():
                    back_p1 = gr.Button(value="Back")
                with gr.Column():
                    embed = gr.Button(value="Extract Data")
                with gr.Column():
                    next_p1 = gr.Button(value="Next")
            with gr.Row() as row:
                with gr.Column():
                    selected = gr.Dataframe(
                        interactive=False,
                        col_count=(1, "fixed"),
                        headers=["Selected Files"],
                    )
                with gr.Column(variant="panel"):
                    prog = gr.HTML(
                        value="<h1 style='text-align: center'>Click the 'Extract' button to extract data from PDFs</h1>"
                    )
            with gr.Accordion("See Parts of Extracted Data", open=False):
                with gr.Column(visible=True) as sample_data:
                    with gr.Row():
                        with gr.Column():
                            ext_text = gr.Textbox(
                                label="Sample Extracted Text", lines=15
                            )
                        with gr.Column():
                            images = gr.Gallery(
                                label="Sample Extracted Images", columns=1, rows=2
                            )

        with gr.TabItem("Chat", id=2) as chat_tab:
            with gr.Column():
                choice = gr.Radio(
                    ["chromaDB"],
                    value="chromaDB",
                    label="Vector Database",
                    interactive=True,
                )
                num_context = gr.Slider(
                    label="Number of text context elements",
                    minimum=1,
                    maximum=20,
                    step=1,
                    interactive=True,
                    value=3,
                )
                img_context = gr.Slider(
                    label="Number of image context elements",
                    minimum=1,
                    maximum=10,
                    step=1,
                    interactive=True,
                    value=2,
                )
            with gr.Row():
                with gr.Column():
                    # The first positional parameter of Gallery is `value`,
                    # so the caption must be passed as `label`.
                    ret_images = gr.Gallery(label="Similar Images", columns=1, rows=2)
                with gr.Column():
                    chatbot = gr.Chatbot(height=400)
            with gr.Accordion("Text References", open=False):
                # NOTE(review): gen_refs is defined but never called or
                # registered in the visible code, so this accordion stays
                # empty. Presumably it was meant to render `references`;
                # confirm the intended wiring.
                def gen_refs(references):
                    """Create one read-out textbox per retrieved reference."""
                    for index, reference in enumerate(references, start=1):
                        gr.Textbox(label=f"Reference-{index}", value=reference, lines=3)

            with gr.Row():
                msg = gr.Textbox(
                    placeholder="Type your question here (e.g. 'What is this document about?')",
                    interactive=True,
                    container=True,
                )
            with gr.Row():
                submit_btn = gr.Button("Submit message")
                clear_btn = gr.ClearButton([msg, chatbot], value="Clear conversation")

    # Event wiring.
    pdf_btn.click(
        fn=extract_pdfs,
        inputs=[documents, doc_collection],
        outputs=[doc_collection, tabs, selected],
    )
    embed.click(
        extract_data_from_pdfs,
        inputs=[doc_collection, session_states],
        outputs=[
            vectordb,
            session_states,
            sample_data,
            ext_text,
            images,
            prog,
        ],
    )
    submit_btn.click(
        conversation,
        [vectordb, msg, num_context, img_context, chatbot],
        [chatbot, references, ret_images],
    )
    back_p1.click(lambda: gr.Tabs(selected=0), None, tabs)
    next_p1.click(check_validity_and_llm, session_states, tabs)

if __name__ == "__main__":
    demo.launch()