# demo-updated/middleware.py
# Last change: commit d901124 by Kazel ("logging"); file size 2.94 kB.
from colpali_manager import ColpaliManager
from milvus_manager import MilvusManager
from pdf_manager import PdfManager
import hashlib
# Shared, module-level service instances used by every Middleware object:
# `Middleware.index` renders pages with pdf_manager and embeds them with
# colpali_manager; `Middleware.search` embeds queries with colpali_manager.
# Both are constructed once, when this module is imported.
# NOTE(review): whatever the constructors do (e.g. loading a model) happens at
# import time; their behavior is defined outside this file — confirm the cost.
pdf_manager = PdfManager()
colpali_manager = ColpaliManager()
class Middleware:
    """Glue between PDF rendering, ColPali embedding and Milvus storage.

    ``index`` renders a PDF to page images, embeds them with the module-level
    ``colpali_manager`` and inserts the resulting vectors through a
    ``MilvusManager``; ``search`` embeds text queries and asks Milvus for the
    closest stored pages.
    """

    def __init__(self, id: str, create_collection=True):
        """Open the Milvus database and bind the collection for ``id``.

        Args:
            id: Identifier forwarded to ``MilvusManager``; per the original
                author note, collections are created per ``id``.
            create_collection: Forwarded unchanged to ``MilvusManager``.
        """
        # Every account currently shares one persistent database file
        # ("milvus_0.db"); only the per-id collection separates them.
        # TODO: use a distinct database per account (e.g. derived from
        # hashlib.md5(id.encode()).hexdigest()[:8]) once existing data can be
        # migrated.
        hashed_id = 0
        milvus_db_name = f"milvus_{hashed_id}.db"
        # Collections are keyed by ``id`` rather than by the embedding model.
        self.milvus_manager = MilvusManager(milvus_db_name, id, create_collection)

    def index(self, pdf_path: str, id: str, max_pages: int, pages: list[int] = None):
        """Render a PDF to page images, embed them and insert them into Milvus.

        Args:
            pdf_path: Path of the PDF to index. ``None`` means there is no
                upload (a direct query against already-indexed data), and
                nothing is indexed.
            id: Identifier forwarded to ``pdf_manager.save_images``.
            max_pages: Page limit forwarded to ``pdf_manager.save_images``.
            pages: Currently unused; accepted for backward compatibility.

        Returns:
            The image paths returned by ``pdf_manager.save_images``, or
            ``None`` when ``pdf_path`` is ``None``.

        Raises:
            ValueError: if the embedder returns a different number of vectors
                than there are page images.
        """
        # The original check was ``type(pdf_path) == None``, which can never be
        # true (type() always returns a class), so a None path fell through to
        # save_images. Compare the value itself.
        if pdf_path is None:  # for direct query without any upload to db
            print("no docs")
            return None

        print(f"Indexing {pdf_path}, id: {id}, max_pages: {max_pages}")
        image_paths = pdf_manager.save_images(id, pdf_path, max_pages)
        print(f"Saved {len(image_paths)} images")

        colbert_vecs = colpali_manager.process_images(image_paths)
        # Each page image needs exactly one embedding; fail loudly rather than
        # pairing vectors with the wrong file paths.
        if len(colbert_vecs) != len(image_paths):
            raise ValueError(
                f"Expected {len(image_paths)} embeddings, got {len(colbert_vecs)}"
            )

        images_data = [
            {"colbert_vecs": vecs, "filepath": path}
            for vecs, path in zip(colbert_vecs, image_paths)
        ]
        print(f"Inserting {len(images_data)} images data to Milvus")
        self.milvus_manager.insert_images_data(images_data)
        print("Indexing completed")
        return image_paths

    def drop_collection(self):
        """Drop the current collection from Milvus.

        Returns:
            Whatever ``MilvusManager.drop_collection`` returns.
        """
        return self.milvus_manager.drop_collection()

    def search(self, search_queries: list[str], topk: int = 10):
        """Embed each text query and retrieve its top-k matches from Milvus.

        Args:
            search_queries: Text queries; each is embedded separately with
                ``colpali_manager.process_text``.
            topk: Number of results requested per query from
                ``MilvusManager.search``.

        Returns:
            A list with one entry per query, in input order: the result list
            returned by ``self.milvus_manager.search`` for that query.
        """
        n_queries = len(search_queries)
        print("\n🔍 MIDDLEWARE SEARCH INITIATED")
        print(f"📝 Queries to process: {n_queries}")
        print(f"🎯 Top-k requested: {topk}")
        print("-" * 60)

        final_res = []
        for i, query in enumerate(search_queries, 1):
            print(f"\n🔍 Processing Query {i}/{n_queries}: '{query}'")
            print("📊 Converting query to vector representation...")
            query_vec = colpali_manager.process_text([query])[0]
            # NOTE(review): len() reports the outer size of the query
            # embedding; for a multi-vector (ColBERT-style) embedding that is
            # the token count rather than the vector width — confirm.
            print(f"✅ Query vector generated (dimension: {len(query_vec)})")
            print("🔍 Executing vector search in Milvus...")
            search_res = self.milvus_manager.search(query_vec, topk=topk)
            print(f"✅ Search completed: {len(search_res)} results retrieved")
            if search_res:
                # NOTE(review): assumes each hit is a sequence whose first item
                # is a numeric score and that hits come back best-first.
                print(
                    f"📊 Score range: {search_res[0][0]:.4f} (highest) "
                    f"to {search_res[-1][0]:.4f} (lowest)"
                )
            final_res.append(search_res)

        print("\n🎉 MIDDLEWARE SEARCH COMPLETED")
        print(f"📊 Total queries processed: {n_queries}")
        print(f"📄 Total results across all queries: {sum(len(res) for res in final_res)}")
        print("=" * 60)
        return final_res