# demo / app.py
# Last commit: 37c70c9 "Update app.py" by Kazel (verified)
import gradio as gr
import tempfile
import os
import fitz # PyMuPDF
import uuid
import shutil
from pymilvus import MilvusClient
from middleware import Middleware
from rag import Rag
from pathlib import Path
import subprocess
import getpass
# importing necessary functions from dotenv library
from dotenv import load_dotenv, dotenv_values
import dotenv
import platform
import time
# Load variables from the nearest .env file into os.environ. The resolved path
# is kept in `dotenv_file` so the settings handlers can persist changes back to
# the same file with dotenv.set_key.
dotenv_file = dotenv.find_dotenv()
dotenv.load_dotenv(dotenv_file)
# Shared RAG client, created once at import time; search requests call
# rag.get_answer_from_gemini on it. NOTE(review): an older comment said this
# "kickstarts docker and ollama servers" — whether Rag() does that is not
# visible from this file; confirm in rag.py.
rag = Rag()
def generate_uuid(state):
    """Return the session's user UUID, creating and storing one on first use.

    Args:
        state: mutable per-session dict (the value held by a ``gr.State``);
            the id is cached in it under the ``"user_uuid"`` key.

    Returns:
        The UUID as a string. Repeated calls with the same *state* return the
        same string.
    """
    # .get() so a state dict created without the key does not raise KeyError.
    if state.get("user_uuid") is None:
        state["user_uuid"] = str(uuid.uuid4())
    return state["user_uuid"]
class PDFSearchApp:
    """Gradio callback handlers for indexing, searching and managing documents.

    Each uploaded document is indexed into its own vector-DB collection (named
    after the file), and its page images live under ``pages/<collection>/``.
    Settings handlers write to both ``os.environ`` and the project's ``.env``
    file (``dotenv_file``); the new values are picked up after a restart.
    """

    def __init__(self):
        # Collection names indexed during this process's lifetime -> True.
        self.indexed_docs = {}
        # Path of the most recently uploaded file, or None before any upload.
        self.current_pdf = None

    def upload_and_convert(self, state, files, max_pages):
        """Index every uploaded file into its own vector-DB collection.

        Args:
            state: per-session ``gr.State`` dict (unused; kept for the event
                signature).
            files: uploaded file objects exposing ``.name`` (a path on disk),
                or None / empty when the upload widget is cleared.
            max_pages: maximum number of pages to extract and index per file.

        Returns:
            A status string for the "Indexing Status" textbox.
        """
        if not files:
            return "No file uploaded"
        try:
            for file in files:
                pdf_path = file.name
                self.current_pdf = pdf_path
                name, _ext = os.path.splitext(os.path.basename(pdf_path))
                # The collection name is derived from the file name with spaces
                # and hyphens replaced by underscores (presumably because the
                # vector DB rejects them in collection names — TODO confirm).
                collection_name = name.replace(" ", "_").replace("-", "_")
                print(f"Uploading file: {collection_name}")
                middleware = Middleware(collection_name, create_collection=True)
                middleware.index(pdf_path, id=collection_name, max_pages=max_pages)
                self.indexed_docs[collection_name] = True
            return "Uploaded and extracted all pages"
        except Exception as e:
            return f"Error processing PDF: {str(e)}"

    def display_file_list(self):
        """Return the sorted names of the document folders under ``./pages``.

        Each sub-directory of ``pages`` holds one document's page images, and
        its name is also the collection name used by :meth:`delete`.

        Returns:
            A list of directory names. If ``pages`` is missing, unreadable or
            not a directory, the problem is printed and ``[]`` is returned, so
            the result is always safe to use as dropdown choices.
        """
        directory_path = os.path.join(os.getcwd(), "pages")
        try:
            entries = os.listdir(directory_path)
        except FileNotFoundError:
            print(f"The directory {directory_path} does not exist.")
            return []
        except PermissionError:
            print(f"Permission denied to access {directory_path}.")
            return []
        except OSError as e:
            print(str(e))
            return []
        return sorted(
            entry
            for entry in entries
            if os.path.isdir(os.path.join(directory_path, entry))
        )

    def search_documents(self, state, query, num_results=1):
        """Retrieve the best-matching page for *query* and answer it with RAG.

        Args:
            state: per-session ``gr.State`` dict (unused here).
            query: the user's question.
            num_results: currently unused; only the top-ranked page is used.

        Returns:
            A 3-tuple ``(page_path, image_path, rag_answer)`` matching the
            ``[path, images, llm_answer]`` outputs. On an empty query, no hits,
            or an error, returns ``(message, None, "--")``. The ``None`` clears
            the image component instead of handing it an invalid path.
        """
        print(f"Searching for query: {query}")
        if not query:
            print("Please enter a search query")
            return "Please enter a search query", None, "--"
        # No "documents indexed?" check: queries may run directly against
        # collections persisted by earlier sessions.
        try:
            # The collection name is a placeholder; per the original code it
            # is not used for searching.
            middleware = Middleware("test", create_collection=False)
            search_results = middleware.search([query])[0]
            if not search_results:
                return "No matching page found", None, "--"
            # Each hit is a tuple (score, doc_id, collection_name). The +1 maps
            # doc_id onto the page_<n>.png numbering of the stored images.
            top_hit = search_results[0]
            page_num = top_hit[1] + 1
            coll_num = top_hit[2]
            print(f"Retrieved page number: {page_num}")
            img_path = f"pages/{coll_num}/page_{page_num}.png"
            path = f"pages/{coll_num}/page_{page_num}"
            print(f"Retrieved image path: {img_path}")
            rag_response = rag.get_answer_from_gemini(query, [img_path])
            return path, img_path, rag_response
        except Exception as e:
            return f"Error during search: {str(e)}", None, "--"

    def delete(self, choice):
        """Delete a document's page images and drop its vector-DB collection.

        Args:
            choice: document/collection name, as listed by
                :meth:`display_file_list`.

        Returns:
            A status string for the "Deletion Status" textbox.
        """
        if not choice:
            return "Directory not found"
        # Accept only a single path component, so a crafted value such as
        # ".." or "a/../../x" cannot make rmtree escape ./pages.
        if choice in (".", "..") or os.path.basename(choice) != choice:
            return f"Invalid document name: {choice}"
        path = f"pages/{choice}"
        if not os.path.exists(path):
            return "Directory not found"
        # Open the DB before touching files, so a connection failure leaves
        # the page images in place. Always release the client afterwards.
        client = MilvusClient("./milvus_demo.db")
        try:
            shutil.rmtree(path)
            client.drop_collection(collection_name=choice)
        finally:
            client.close()
        return f"Deleted {choice}"

    def dbupdate(self, metric_type, m_num, ef_num, topk):
        """Persist vector-DB index/search settings to os.environ and .env.

        Args:
            metric_type: similarity metric name, stored under ``metrictype``.
            m_num: graph neighbour count, stored as a string under ``mnum``.
            ef_num: index-construction candidate count, stored under ``efnum``.
            topk: maximum hits per search, stored under ``topk``.

        Returns:
            A status string; the settings apply after a restart.
        """
        for key, value in (
            ("metrictype", metric_type),
            ("mnum", str(m_num)),
            ("efnum", str(ef_num)),
            ("topk", str(topk)),
        ):
            os.environ[key] = value
            # Also write to .env so the value survives a restart.
            dotenv.set_key(dotenv_file, key, value)
        return "DB Settings Updated, Restart App To Load"

    def list_downloaded_hf_models(self):
        """Return repo ids (``org/name``) of models in the local HF hub cache.

        The cache directory is resolved in the same order huggingface_hub uses:
        ``HF_HUB_CACHE``, then the legacy ``HUGGINGFACE_HUB_CACHE``, then
        ``$HF_HOME/hub``, then ``$XDG_CACHE_HOME/huggingface/hub``, then
        ``~/.cache/huggingface/hub``. ``HF_HOME`` is the root of all HF data,
        not the hub cache itself, so it must not be globbed directly.

        Returns:
            Sorted repo ids; empty if the cache directory does not exist.
        """
        hub_cache = os.getenv("HF_HUB_CACHE") or os.getenv("HUGGINGFACE_HUB_CACHE")
        if hub_cache:
            hf_cache_dir = Path(hub_cache)
        else:
            hf_home = os.getenv("HF_HOME")
            if not hf_home:
                xdg_cache = os.getenv("XDG_CACHE_HOME")
                hf_home = (
                    Path(xdg_cache) / "huggingface"
                    if xdg_cache
                    else Path.home() / ".cache" / "huggingface"
                )
            hf_cache_dir = Path(hf_home) / "hub"
        hf_cache_dir = hf_cache_dir.expanduser()
        # Model entries are named "models--<org>--<name>" or "models--<name>".
        return sorted(
            repo_dir.name.split("--", 1)[-1].replace("--", "/")
            for repo_dir in hf_cache_dir.glob("models--*")
            if repo_dir.is_dir()
        )

    def list_downloaded_ollama_models(self):
        """Return the names of Ollama models pulled into the project directory.

        Looks in ``NEW_PATH/manifests/registry.ollama.ai/library`` relative to
        the current working directory. Per the project's setup notes, the
        ``ollama pull`` command must be run from the project dir so models land
        in ``NEW_PATH``. The path is joined portably; a hard-coded backslash
        path never resolves on POSIX.

        Returns:
            Sorted model names (one sub-directory per model). If the directory
            is missing or unreadable, the problem is printed and ``[]`` is
            returned.
        """
        base_path = os.path.join("NEW_PATH", "manifests", "registry.ollama.ai", "library")
        try:
            with os.scandir(base_path) as entries:
                return sorted(entry.name for entry in entries if entry.is_dir())
        except FileNotFoundError:
            print(f"The directory {base_path} does not exist.")
        except PermissionError:
            print(f"Permission denied to access {base_path}.")
        except Exception as e:
            print(f"An error occurred: {e}")
        return []

    def model_settings(self, hfchoice, ollamachoice, flash, temp):
        """Persist AI-model settings to os.environ and the .env file.

        Args:
            hfchoice: primary visual model id, stored under ``colpali``.
            ollamachoice: secondary RAG model name, stored under ``ollama``.
            flash: ``"Enabled"`` stores ``flashattn=1``; anything else stores 0.
            temp: RAG temperature, stored as a string under ``temperature``.

        Returns:
            A status string; the settings apply after a restart.
        """
        flash_flag = "1" if flash == "Enabled" else "0"
        for key, value in (
            ("colpali", hfchoice),
            ("ollama", ollamachoice),
            ("flashattn", flash_flag),
            ("temperature", str(temp)),
        ):
            os.environ[key] = value
            # Also write to .env so the value survives a restart.
            dotenv.set_key(dotenv_file, key, value)
        return "Models Updated, Restart App To Use New Settings"
def create_ui():
    """Build the Gradio Blocks interface for the multimodal RAG demo.

    Tabs: document upload/indexing, querying, vector-DB settings and AI-model
    settings. Only the upload and query handlers are connected. The settings
    handlers are left disconnected in this hosted demo ("Settings Available
    On Local Offline Setup").

    Returns:
        The ``gr.Blocks`` app, not yet launched.
    """
    app = PDFSearchApp()
    with gr.Blocks(theme=gr.themes.Ocean(), css="""
footer a[href*="gradio.app"] {
display: none !important;
}
""") as demo:
        # Full-screen intro overlay: a particle-animated "HELLO DSTA" title and
        # a "Begin Demo" label. Pure CSS: clicking the label checks the hidden
        # checkbox, which triggers the overlay fade-out and disables its clicks.
        gr.HTML("""
<style>
#hello-overlay {
position: fixed;
z-index: 9999;
top: 0; left: 0; right: 0; bottom: 0;
width: 100vw; height: 100vh;
background: #000;
display: flex; flex-direction: column; align-items: center; justify-content: center;
transition: opacity 0.5s;
}
/* Hide the checkbox */
#hello-overlay input[type=checkbox] {
display: none;
}
/* Delayed overlay fadeout to match loading bar duration */
#hello-overlay:has(input[type=checkbox]:checked) {
animation: overlay-fadeout 0.3s forwards;
pointer-events: none;
}
@keyframes overlay-fadeout {
0% { opacity: 1; }
99% { opacity: 1; }
100% { opacity: 0; }
}
.grid-particles {
display: grid;
grid-template-columns: repeat(38, 1vw);
grid-template-rows: repeat(8, 1vw);
gap: 0.25vw;
justify-content: center;
align-items: center;
}
.particle {
width: 1.3vw;
height: 1.3vw;
opacity: 0;
border-radius: 50%;
box-shadow: none;
}
/* Letter color order: #164194, #00976f, #ec6608 */
.particle.letter-h,
.particle.letter-d {
background: linear-gradient(135deg, #164194, #164194);
}
.particle.letter-e,
.particle.letter-l1,
.particle.letter-l2,
.particle.letter-s,
.particle.letter-t {
background: linear-gradient(135deg, #00976f, #00976f);
}
.particle.letter-o,
.particle.letter-a {
background: linear-gradient(135deg, #ec6608, #ec6608);
}
/* Animation timing for particles */
.letter-h.particle { animation: fadein 6.8s linear infinite; animation-delay: 0.0s !important; }
.letter-e.particle { animation: fadein 6.8s linear infinite; animation-delay: 0.5s !important; }
.letter-l1.particle { animation: fadein 6.8s linear infinite; animation-delay: 1.0s !important; }
.letter-l2.particle { animation: fadein 6.8s linear infinite; animation-delay: 1.5s !important; }
.letter-o.particle { animation: fadein 6.8s linear infinite; animation-delay: 2.0s !important; }
.letter-d.particle { animation: fadein 6.8s linear infinite; animation-delay: 2.5s !important; }
.letter-s.particle { animation: fadein 6.8s linear infinite; animation-delay: 3.0s !important; }
.letter-t.particle { animation: fadein 6.8s linear infinite; animation-delay: 3.5s !important; }
.letter-a.particle { animation: fadein 6.8s linear infinite; animation-delay: 4.0s !important; }
@keyframes fadein {
0% { opacity: 0; transform: scale(0.2);}
11.76% { opacity: 1; transform: scale(1);}
73.53% { opacity: 1; transform: scale(1);}
100% { opacity: 0; transform: scale(0.2);}
}
label[for="enter-app-toggle"] {
margin-top:4vw;
padding:1.2rem 3rem;
font-size:2rem;
border:none;
border-radius:2.5rem;
background: linear-gradient(90deg, #164194, #00976f, #ec6608, #164194);
background-size: 300% 300%;
background-position: 0% 50%;
color: #fff;
font-weight: 700;
box-shadow: 0 8px 40px 0 #00976fcc, 0 2px 8px 0 #16419488;
cursor: pointer;
transition: background 0.3s, transform 0.2s, width 0.2s, padding 0.2s;
display: inline-block;
animation: gradient-move 3s linear infinite;
z-index: 2;
position: relative;
transform: none;
}
@keyframes gradient-move {
0% { background-position: 0% 50%; }
100% { background-position: 100% 50%; }
}
label[for="enter-app-toggle"]:hover {
transform: scale(1.08);
padding-left: 3.5rem;
padding-right: 3.5rem;
}
/* Loading bar styles */
.loading-bar-container {
width: 320px;
margin: 2vw auto 0 auto;
height: 18px;
background: #151515;
border-radius: 10px;
box-shadow: 0 2px 8px #000a inset;
overflow: hidden;
position: relative;
z-index: 3;
opacity: 1;
transition: opacity 0.4s;
display: none;
}
.loading-bar {
height: 100%;
width: 0%;
background: linear-gradient(90deg, #22ffde, #2196f3 80%);
border-radius: 10px;
transition: none;
}
/* Show loading bar and animate when button is pressed (focus/active) */
label[for="enter-app-toggle"]:focus ~ .loading-bar-container,
label[for="enter-app-toggle"]:active ~ .loading-bar-container {
display: block;
}
label[for="enter-app-toggle"]:focus ~ .loading-bar-container .loading-bar,
label[for="enter-app-toggle"]:active ~ .loading-bar-container .loading-bar {
animation: loading-bar-fill 0.3s linear forwards;
}
@keyframes loading-bar-fill {
0% { width: 0%; }
100% { width: 100%; }
}
</style>
<div id="hello-overlay">
<input type="checkbox" id="enter-app-toggle"/>
<div style="display: flex; flex-direction: column; align-items: center;">
<div class="grid-particles">
<!-- H -->
<div class="particle letter-h" style="grid-column:2;grid-row:2"></div>
<div class="particle letter-h" style="grid-column:2;grid-row:3"></div>
<div class="particle letter-h" style="grid-column:2;grid-row:4"></div>
<div class="particle letter-h" style="grid-column:2;grid-row:5"></div>
<div class="particle letter-h" style="grid-column:2;grid-row:6"></div>
<div class="particle letter-h" style="grid-column:3;grid-row:4"></div>
<div class="particle letter-h" style="grid-column:4;grid-row:2"></div>
<div class="particle letter-h" style="grid-column:4;grid-row:3"></div>
<div class="particle letter-h" style="grid-column:4;grid-row:4"></div>
<div class="particle letter-h" style="grid-column:4;grid-row:5"></div>
<div class="particle letter-h" style="grid-column:4;grid-row:6"></div>
<!-- E -->
<div class="particle letter-e" style="grid-column:6;grid-row:2"></div>
<div class="particle letter-e" style="grid-column:6;grid-row:3"></div>
<div class="particle letter-e" style="grid-column:6;grid-row:4"></div>
<div class="particle letter-e" style="grid-column:6;grid-row:5"></div>
<div class="particle letter-e" style="grid-column:6;grid-row:6"></div>
<div class="particle letter-e" style="grid-column:7;grid-row:2"></div>
<div class="particle letter-e" style="grid-column:7;grid-row:4"></div>
<div class="particle letter-e" style="grid-column:7;grid-row:6"></div>
<div class="particle letter-e" style="grid-column:8;grid-row:2"></div>
<div class="particle letter-e" style="grid-column:8;grid-row:6"></div>
<!-- L -->
<div class="particle letter-l1" style="grid-column:10;grid-row:2"></div>
<div class="particle letter-l1" style="grid-column:10;grid-row:3"></div>
<div class="particle letter-l1" style="grid-column:10;grid-row:4"></div>
<div class="particle letter-l1" style="grid-column:10;grid-row:5"></div>
<div class="particle letter-l1" style="grid-column:10;grid-row:6"></div>
<div class="particle letter-l1" style="grid-column:11;grid-row:6"></div>
<!-- L -->
<div class="particle letter-l2" style="grid-column:13;grid-row:2"></div>
<div class="particle letter-l2" style="grid-column:13;grid-row:3"></div>
<div class="particle letter-l2" style="grid-column:13;grid-row:4"></div>
<div class="particle letter-l2" style="grid-column:13;grid-row:5"></div>
<div class="particle letter-l2" style="grid-column:13;grid-row:6"></div>
<div class="particle letter-l2" style="grid-column:14;grid-row:6"></div>
<!-- O -->
<div class="particle letter-o" style="grid-column:16;grid-row:3"></div>
<div class="particle letter-o" style="grid-column:16;grid-row:4"></div>
<div class="particle letter-o" style="grid-column:16;grid-row:5"></div>
<div class="particle letter-o" style="grid-column:17;grid-row:2"></div>
<div class="particle letter-o" style="grid-column:17;grid-row:6"></div>
<div class="particle letter-o" style="grid-column:18;grid-row:3"></div>
<div class="particle letter-o" style="grid-column:18;grid-row:4"></div>
<div class="particle letter-o" style="grid-column:18;grid-row:5"></div>
<!-- D -->
<div class="particle letter-d" style="grid-column:23;grid-row:2"></div>
<div class="particle letter-d" style="grid-column:23;grid-row:3"></div>
<div class="particle letter-d" style="grid-column:23;grid-row:4"></div>
<div class="particle letter-d" style="grid-column:23;grid-row:5"></div>
<div class="particle letter-d" style="grid-column:23;grid-row:6"></div>
<div class="particle letter-d" style="grid-column:24;grid-row:2"></div>
<div class="particle letter-d" style="grid-column:24;grid-row:6"></div>
<div class="particle letter-d" style="grid-column:25;grid-row:3"></div>
<div class="particle letter-d" style="grid-column:25;grid-row:4"></div>
<div class="particle letter-d" style="grid-column:25;grid-row:5"></div>
<!-- S -->
<div class="particle letter-s" style="grid-column:27;grid-row:2"></div>
<div class="particle letter-s" style="grid-column:28;grid-row:2"></div>
<div class="particle letter-s" style="grid-column:29;grid-row:2"></div>
<div class="particle letter-s" style="grid-column:27;grid-row:3"></div>
<div class="particle letter-s" style="grid-column:27;grid-row:4"></div>
<div class="particle letter-s" style="grid-column:28;grid-row:4"></div>
<div class="particle letter-s" style="grid-column:29;grid-row:4"></div>
<div class="particle letter-s" style="grid-column:29;grid-row:5"></div>
<div class="particle letter-s" style="grid-column:27;grid-row:6"></div>
<div class="particle letter-s" style="grid-column:28;grid-row:6"></div>
<div class="particle letter-s" style="grid-column:29;grid-row:6"></div>
<!-- T -->
<div class="particle letter-t" style="grid-column:31;grid-row:2"></div>
<div class="particle letter-t" style="grid-column:32;grid-row:2"></div>
<div class="particle letter-t" style="grid-column:33;grid-row:2"></div>
<div class="particle letter-t" style="grid-column:32;grid-row:3"></div>
<div class="particle letter-t" style="grid-column:32;grid-row:4"></div>
<div class="particle letter-t" style="grid-column:32;grid-row:5"></div>
<div class="particle letter-t" style="grid-column:32;grid-row:6"></div>
<!-- A -->
<div class="particle letter-a" style="grid-column:36;grid-row:2"></div>
<div class="particle letter-a" style="grid-column:35;grid-row:3"></div>
<div class="particle letter-a" style="grid-column:37;grid-row:3"></div>
<div class="particle letter-a" style="grid-column:35;grid-row:4"></div>
<div class="particle letter-a" style="grid-column:36;grid-row:4"></div>
<div class="particle letter-a" style="grid-column:37;grid-row:4"></div>
<div class="particle letter-a" style="grid-column:35;grid-row:5"></div>
<div class="particle letter-a" style="grid-column:37;grid-row:5"></div>
<div class="particle letter-a" style="grid-column:35;grid-row:6"></div>
<div class="particle letter-a" style="grid-column:37;grid-row:6"></div>
</div>
<label for="enter-app-toggle" tabindex="0">Begin Demo</label>
<div class="loading-bar-container">
<div class="loading-bar"></div>
</div>
</div>
</div>
""")
        # Per-session state; "user_uuid" is the key generate_uuid reads and
        # fills in.
        state = gr.State(value={"user_uuid": None})
        gr.Markdown("# Collar Multimodal RAG Demo")
        gr.Markdown("Settings Available On Local Offline Setup")

        with gr.Tab("Upload Documents"):
            with gr.Column():
                max_pages_input = gr.Slider(
                    minimum=1,
                    maximum=10000,
                    value=20,
                    step=10,
                    label="Max pages to extract and index per document",
                )
                file_input = gr.Files(label="Upload PPTs/PDFs")
                file_list = gr.Textbox(label="Uploaded Files", interactive=False, value="Available on Local Setup")
                status = gr.Textbox(label="Indexing Status", interactive=False)

        with gr.Tab("Query"):
            with gr.Column():
                query_input = gr.Textbox(label="Enter query")
                search_btn = gr.Button("Query")
                llm_answer = gr.Textbox(label="RAG Response", interactive=False)
                path = gr.Textbox(label="Link To Document Page", interactive=False)
                images = gr.Image(label="Top page matching query")

        with gr.Tab("Data Settings"):  # deleting collections, vector-DB index parameters
            with gr.Column():
                # Document folders under ./pages. Only accept a real list, so an
                # error string is never split into one choice per character.
                doc_folders = app.display_file_list()
                if not isinstance(doc_folders, list):
                    doc_folders = []
                choice = gr.Dropdown(doc_folders, label="Choice")
                status1 = gr.Textbox(label="Deletion Status", interactive=False)
                delete_button = gr.Button("Delete Document From DB")
                metric_type = gr.Dropdown(choices=["IP", "L2", "COSINE"], value="IP", label="Metric Type (Mathematical function to measure similarity)")
                m_num = gr.Dropdown(
                    choices=["8", "16", "32", "64"], value="16", label="M Vectors (Maximum number of neighbors each node can connect to in the graph)")
                ef_num = gr.Slider(
                    minimum=50,
                    maximum=1000,
                    value=500,
                    step=10,
                    label="EF Construction (Number of candidate neighbors considered for connection during index construction)",
                )
                topk = gr.Slider(
                    minimum=1,
                    maximum=100,
                    value=50,
                    step=1,
                    label="Top-K (Maximum number of entities to return in a single search of a document)",
                )
                db_button = gr.Button("Update DB Settings")
                status3 = gr.Textbox(label="DB Update Status", interactive=False)

        with gr.Tab("AI Model Settings"):  # model choices, flash attention, temperature
            with gr.Column():
                # .get(): a missing .env key yields an empty selection instead
                # of a KeyError that would stop the app from starting.
                hfchoice = gr.Dropdown(app.list_downloaded_hf_models(), value=os.environ.get('colpali'), label="Primary Visual Model")
                ollamachoice = gr.Dropdown(app.list_downloaded_ollama_models() or [], value=os.environ.get('ollama'), label="Secondary Visual Retrieval-Augmented Generation (RAG) Model")
                flash = gr.Dropdown(["Enabled", "Disabled"], value="Enabled", label="Flash Attention 2.0 Acceleration")
                temp = gr.Slider(
                    minimum=0.1,
                    maximum=1,
                    value=0.8,
                    step=0.1,
                    label="RAG Temperature",
                )
                model_button = gr.Button("Update Settings")
                status2 = gr.Textbox(label="Update Status", interactive=False)

        # Event handlers
        file_input.change(
            fn=app.upload_and_convert,
            inputs=[state, file_input, max_pages_input],
            outputs=[status],
        )
        search_btn.click(
            # Queries work without an upload in this session (persisted DB).
            fn=app.search_documents,
            inputs=[state, query_input],
            outputs=[path, images, llm_answer],
        )
        # The settings handlers are intentionally not connected in the hosted
        # demo. To enable them on a local setup, uncomment:
        # delete_button.click(fn=app.delete, inputs=[choice], outputs=[status1])
        # db_button.click(fn=app.dbupdate, inputs=[metric_type, m_num, ef_num, topk], outputs=[status3])
        # model_button.click(fn=app.model_settings, inputs=[hfchoice, ollamachoice, flash, temp], outputs=[status2])
    return demo
if __name__ == "__main__":
    # Script entry point: assemble the Gradio interface, then serve it.
    # To put the app behind a login page, pass auth=(username, password)
    # to demo.launch().
    demo = create_ui()
    demo.launch()