#!/usr/bin/env python3
import os
import asyncio
from pathlib import Path
from typing import cast

import numpy as np
import torch
from torch.utils.data import DataLoader
from PIL import Image
from dotenv import load_dotenv

from colpali_engine.models import ColPali, ColPaliProcessor
from colpali_engine.utils.torch_utils import get_torch_device
from vespa.application import Vespa
from vespa.io import VespaQueryResponse

# Maximum number of per-token query tensors sent to Vespa; longer queries are
# truncated before the nearestNeighbor expansion below.
MAX_QUERY_TERMS = 64
SAVEDIR = Path(__file__).parent / "output" / "images"

load_dotenv()

def process_queries(processor, queries, image):
    """Tokenize a batch of text queries for ColPali.

    ColPali consumes image+text pairs, so every text query is paired with the
    same placeholder image (the blank canvas created in main()).
    """
    inputs = processor(
        images=[image] * len(queries), text=queries, return_tensors="pt", padding=True
    )
    return inputs

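# Roughly what the processor returns for a batch (keys and shapes are
# indicative and may vary across colpali_engine/transformers versions):
#   inputs["input_ids"]      -> LongTensor  [batch, seq_len]
#   inputs["attention_mask"] -> LongTensor  [batch, seq_len]
#   inputs["pixel_values"]   -> FloatTensor [batch, 3, 448, 448]
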
def display_query_results(query, response, hits=5):
    query_time = response.json.get("timing", {}).get("searchtime", -1)
    query_time = round(query_time, 2)
    count = response.json.get("root", {}).get("fields", {}).get("totalCount", 0)
    result_text = f"Query text: '{query}', query time {query_time}s, count={count}, top results:\n"
    for i, hit in enumerate(response.hits[:hits]):
        title = hit["fields"]["title"]
        url = hit["fields"]["url"]
        page = hit["fields"]["page_number"]
        image = hit["fields"]["image"]  # base64-encoded image of the PDF page
        _id = hit["id"]
        score = hit["relevance"]
        result_text += f"\nPDF Result {i + 1}\n"
        result_text += f"Title: {title}, page {page + 1} with score {score:.2f}\n"
        result_text += f"URL: {url}\n"
        result_text += f"ID: {_id}\n"
        # Optionally, save or display the image (requires `import base64` and
        # creating SAVEDIR first, e.g. SAVEDIR.mkdir(parents=True, exist_ok=True)):
        # img_data = base64.b64decode(image)
        # img_path = SAVEDIR / f"{title}.png"
        # with open(img_path, "wb") as f:
        #     f.write(img_data)
    print(result_text)

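# For reference, each entry in response.hits has the shape implied by the
# field accesses above (values abbreviated):
#   {"id": ..., "relevance": ..., "fields": {"title": ..., "url": ...,
#    "page_number": ..., "image": ...}}
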
async def query_vespa_default(app, queries, qs):
    """Query Vespa with text matching (userInput) and the default rank profile."""
    async with app.asyncio(connections=1, total_timeout=120) as session:
        for idx, query in enumerate(queries):
            # Map token index -> float embedding vector for the query tensor.
            query_embedding = {k: v.tolist() for k, v in enumerate(qs[idx])}
            response: VespaQueryResponse = await session.query(
                yql="select documentid,title,url,image,page_number from pdf_page where userInput(@userQuery)",
                ranking="default",
                userQuery=query,
                timeout=120,
                hits=3,
                body={"input.query(qt)": query_embedding, "presentation.timing": True},
            )
            assert response.is_successful()
            display_query_results(query, response)

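# Illustrative shape of the query tensor sent above: one entry per query
# token, each a list of floats (ColPali v1.2 emits 128-dim token vectors;
# the values below are made up):
#   {0: [0.018, -0.112, ...], 1: [...], ..., n_tokens - 1: [...]}
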
async def query_vespa_nearest_neighbor(app, queries, qs):
    """Query Vespa using nearestNeighbor operators over binarized token embeddings."""
    # targetHits per query tensor is a hyperparameter that trades speed for accuracy.
    target_hits_per_query_tensor = 20
    async with app.asyncio(connections=1, total_timeout=180) as session:
        for idx, query in enumerate(queries):
            float_query_embedding = {k: v.tolist() for k, v in enumerate(qs[idx])}
            binary_query_embeddings = dict()
            for k, v in float_query_embedding.items():
                # Binarize each token vector (positive -> 1, else 0), pack
                # 8 bits per byte, and reinterpret as int8 for Vespa.
                binary_vector = (
                    np.packbits(np.where(np.array(v) > 0, 1, 0))
                    .astype(np.int8)
                    .tolist()
                )
                binary_query_embeddings[k] = binary_vector
                if len(binary_query_embeddings) >= MAX_QUERY_TERMS:
                    print(
                        f"Warning: Query has more than {MAX_QUERY_TERMS} terms. Truncating."
                    )
                    break

            # The mixed tensors used in MaxSim calculations;
            # we use both binary and float representations.
            query_tensors = {
                "input.query(qtb)": binary_query_embeddings,
                "input.query(qt)": float_query_embedding,
            }
            # One query tensor per token for the nearestNeighbor operators.
            for i in range(len(binary_query_embeddings)):
                query_tensors[f"input.query(rq{i})"] = binary_query_embeddings[i]
            nn = [
                f"({{targetHits:{target_hits_per_query_tensor}}}nearestNeighbor(embedding,rq{i}))"
                for i in range(len(binary_query_embeddings))
            ]
            # Combine the nearestNeighbor operators with OR, e.g.
            # "({targetHits:20}nearestNeighbor(embedding,rq0)) OR ({targetHits:20}nearestNeighbor(embedding,rq1)) OR ..."
            nn = " OR ".join(nn)
            response: VespaQueryResponse = await session.query(
                body={
                    **query_tensors,
                    "presentation.timing": True,
                    "yql": f"select documentid, title, url, image, page_number from pdf_page where {nn}",
                    "ranking.profile": "retrieval-and-rerank",
                    "timeout": 120,
                    "hits": 3,
                },
            )
            assert response.is_successful(), response.json
            display_query_results(query, response)

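# A tiny worked example of the binarization above (illustrative only):
#   np.where(np.array([0.3, -0.2, 0.5, 0.1, -0.9, -0.4, -0.7, 0.2]) > 0, 1, 0)
#     -> [1, 0, 1, 1, 0, 0, 0, 1]
#   np.packbits([1, 0, 1, 1, 0, 0, 0, 1])      -> [177]  (uint8, 0b10110001)
#   np.packbits([...]).astype(np.int8)         -> [-79]  (same bits as int8)
# A 128-dim float token vector thus packs into 16 int8 values, a 32x size
# reduction versus float32 at the cost of keeping only each dimension's sign.
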
def main():
    vespa_app_url = os.environ.get(
        "VESPA_APP_URL"
    )  # Ensure this is set to your Vespa app URL
    vespa_cloud_secret_token = os.environ.get("VESPA_CLOUD_SECRET_TOKEN")
    if not vespa_app_url or not vespa_cloud_secret_token:
        raise ValueError(
            "Please set the VESPA_APP_URL and VESPA_CLOUD_SECRET_TOKEN environment variables"
        )
    # Instantiate the Vespa connection and verify the application is reachable.
    app = Vespa(url=vespa_app_url, vespa_cloud_secret_token=vespa_cloud_secret_token)
    status_resp = app.get_application_status()
    if status_resp.status_code != 200:
        print(f"Failed to connect to Vespa at {vespa_app_url}")
        return
    print(f"Connected to Vespa at {vespa_app_url}")

    # Load the ColPali model and its processor.
    device = get_torch_device("auto")
    print(f"Using device: {device}")
    model_name = "vidore/colpali-v1.2"
    processor_name = "google/paligemma-3b-mix-448"
    model = cast(
        ColPali,
        ColPali.from_pretrained(
            model_name,
            torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
            device_map=device,
        ),
    ).eval()
    processor = cast(ColPaliProcessor, ColPaliProcessor.from_pretrained(processor_name))

    # Create a blank placeholder image; text-only queries are paired with it
    # because ColPali processes image+text pairs.
    dummy_image = Image.new("RGB", (448, 448), (255, 255, 255))

    # Define queries
    queries = [
        "Percentage of non-fresh water as source?",
        "Policies related to nature risk?",
        "How much of produced water is recycled?",
    ]

    # Obtain per-token query embeddings.
    dataloader = DataLoader(
        queries,
        batch_size=1,
        shuffle=False,
        collate_fn=lambda x: process_queries(processor, x, dummy_image),
    )
    qs = []
    for batch_query in dataloader:
        with torch.no_grad():
            batch_query = {k: v.to(model.device) for k, v in batch_query.items()}
            embeddings_query = model(**batch_query)
            qs.extend(list(torch.unbind(embeddings_query.to("cpu"))))

    # Perform queries using the default rank profile.
    print("Performing queries using default rank profile:")
    asyncio.run(query_vespa_default(app, queries, qs))

    # Perform queries using nearestNeighbor retrieval.
    print("Performing queries using nearestNeighbor:")
    asyncio.run(query_vespa_nearest_neighbor(app, queries, qs))


if __name__ == "__main__":
    main()
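
# Example usage (a minimal sketch; the .env contents and script filename are
# illustrative, not prescribed by this script):
#
#   # .env -- picked up by load_dotenv() above
#   VESPA_APP_URL=https://<your-endpoint>.vespa-app.cloud
#   VESPA_CLOUD_SECRET_TOKEN=<your-token>
#
#   $ python3 query_pdf_pages.py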