import base64
import logging
import os
import re
from io import BytesIO
from typing import Any, Dict, List, Tuple, Union

import matplotlib.pyplot as plt
import numpy as np
import requests
import torch
from PIL import Image
from transformers import AutoModel, AutoProcessor

# Width (in pixels) that input images are resized to before encoding.
IMG_SIZE = 1024


class JinaEmbeddingsClient:
    """
    Minimal wrapper for https://api.jina.ai/v1/embeddings
    """

    API_URL = "https://api.jina.ai/v1/embeddings"

    def __init__(
        self,
        model: str = "jina-embeddings-v4",
        return_multivector: bool = True,
        task: str = "retrieval.query",
        timeout: int = 30,
    ) -> None:
        self.headers = {
            "Content-Type": "application/json",
            # Placeholder; call set_api_key() before issuing requests.
            "Authorization": "Bearer Not Set",
        }
        self.base_payload = {
            "model": model,
            "return_multivector": return_multivector,
            "task": task,
        }
        self.timeout = timeout

    def encode_text(self, texts: List[str], **kwargs) -> List[torch.Tensor]:
        """
        Encode a batch of texts. Extra keyword arguments are accepted for
        call-compatibility with the local model interface and ignored.
        """
        payload = [{"text": t} for t in texts]
        res = self._post(payload)
        return self._as_tensors(res["data"])

    def encode_image(self, images: List[Union[str, bytes, Image.Image]], **kwargs) -> List[torch.Tensor]:
        """
        Encode a batch of images given as
          • URLs (str) – https://…/image.png
          • base64 strings (str) – iVBORw0…
          • raw bytes – b'\\xff\\xd8…' (base64-encoded automatically)
          • PIL Image.Image instances (converted to base64 PNG)
        """
        def pil_image_to_base64_str(img: Image.Image) -> str:
            buffered = BytesIO()
            img.save(buffered, format="PNG")
            return base64.b64encode(buffered.getvalue()).decode()

        processed = []
        for img in images:
            if isinstance(img, bytes):
                img = base64.b64encode(img).decode()
            elif hasattr(img, "save"):  # PIL image
                img = pil_image_to_base64_str(img)
            processed.append({"image": img})

        res = self._post(processed)
        return self._as_tensors(res["data"])

    def _post(self, input_batch: List[Dict[str, str]]) -> Dict[str, Any]:
        payload = {**self.base_payload, "input": input_batch}
        resp = requests.post(
            self.API_URL, headers=self.headers, json=payload, timeout=self.timeout
        )
        resp.raise_for_status()
        return resp.json()

    def set_api_key(self, api_key: str) -> None:
        """
        Set the API key for authentication.
        """
        if not api_key:
            raise ValueError("API key must not be empty.")
        self.headers["Authorization"] = f"Bearer {api_key}"

    @staticmethod
    def _as_tensors(data: List[Dict[str, Any]]) -> List[torch.Tensor]:
        """
        Convert the `"data"` array of the API response into a list
        of `torch.Tensor`s (one tensor per text / image you sent).

        Each tensor’s shape is (n_vectors, dim). When you set
        `return_multivector=False` you’ll just get shape (1, dim).
        """
        tensors: List[torch.Tensor] = []
        for item in data:
            tensors.append(torch.tensor(item["embeddings"], dtype=torch.float32))
        return tensors
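

# A minimal usage sketch for JinaEmbeddingsClient. The environment-variable
# name JINA_API_KEY is an assumption; use whatever holds your key:
#
#     client = JinaEmbeddingsClient()
#     client.set_api_key(os.environ["JINA_API_KEY"])
#     [vectors] = client.encode_text(["a red bicycle leaning against a wall"])
#     print(vectors.shape)  # (n_vectors, dim), since return_multivector=True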


class JinaV4SimilarityMapper:
    """
    Generates interactive similarity maps between query tokens and images using Jina Embeddings v4.
    Enables visualizing which parts of an image correspond to specific words in the query.
    """

    def __init__(
        self,
        model_name: str = "jinaai/jina-embeddings-v4",
        device: str = "cuda" if torch.cuda.is_available() else "cpu",
        heatmap_alpha: float = 0.6,
        num_vectors: int = 128,
        client_type: str = "local",
    ):
        """
        Initialize the mapper with Jina Embeddings v4.

        Args:
            model_name: Model name from the Hugging Face hub.
            device: Compute device (GPU recommended for performance).
            heatmap_alpha: Transparency of the similarity heatmap overlay.
            num_vectors: Dimension the multivector embeddings are truncated to
                (passed to the encoders as `truncate_dim`).
            client_type: "local" to run the model in-process, "web" to call
                the hosted embeddings API through JinaEmbeddingsClient.
        """
        self.model_name = model_name
        self.device = device
        self.logger = logging.getLogger("JinaV4SimMapper")
        self.logger.info(f"Initializing model on {device}")
        if client_type not in ("local", "web"):
            raise ValueError("client_type must be 'local' or 'web'")
        if client_type == "local":
            self.model = AutoModel.from_pretrained(
                self.model_name,
                trust_remote_code=True,
                torch_dtype=torch.float16 if device == "cuda" else torch.float32,
            ).to(device)
            self.model.eval()
        else:
            self.model = JinaEmbeddingsClient()
        # The processor is needed in both modes: it tokenizes queries and
        # exposes the image patch grid used to shape the similarity map.
        self.preprocessor = AutoProcessor.from_pretrained(
            self.model_name,
            trust_remote_code=True,
        )
        self.heatmap_alpha = heatmap_alpha
        self.num_vectors = num_vectors
        # plt.cm.get_cmap was removed in Matplotlib 3.9; plt.get_cmap still works.
        self.colormap = plt.get_cmap("jet")

    def process_query(self, query: str) -> Tuple[List[str], torch.Tensor, Dict[int, str]]:
        """
        Process a query to get tokens, multivector embeddings, and a token-index map.

        Args:
            query: Input query text.

        Returns:
            tokens: List of query tokens.
            embeddings: Multivector embeddings, one vector per token [num_tokens, embed_dim].
            token_map: Mapping from token index to token string.
        """
        query_embeddings = self.model.encode_text(
            texts=[query],
            task="retrieval",
            prompt_name="query",
            return_multivector=True,
            truncate_dim=self.num_vectors,
        )
        query_embeddings = query_embeddings[0]
        self.logger.debug(f"Query embeddings shape: {query_embeddings.shape}")
        preprocessor_results = self.preprocessor.process_texts(
            texts=[query],
            prefix="Query",
        )
        input_ids = preprocessor_results["input_ids"]
        tokens = input_ids[0].tolist()
        tokens = self.preprocessor.tokenizer.convert_ids_to_tokens(tokens)
        self.logger.debug(f"Tokens: {tokens}")
        # Drop the first two tokens, which belong to the "Query" prefix the
        # processor prepends, so the token list aligns with the embeddings.
        tokens = tokens[2:]
        query_embeddings = query_embeddings[2:]
        num_tokens = query_embeddings.shape[0]
        assert len(tokens) == num_tokens
        # Strip the BPE word-boundary marker for display.
        tokens = [tok.replace("Ġ", "") for tok in tokens]
        token_map = {i: tok for i, tok in enumerate(tokens)}
        self.logger.debug(f"Token map: {token_map}")
        return tokens, query_embeddings, token_map
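
    # Illustrative example (the exact split depends on the tokenizer): for the
    # query "red bicycle", process_query returns something along the lines of
    #
    #     tokens    == ["red", "bicycle"]
    #     token_map == {0: "red", 1: "bicycle"}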

    def process_image(self, image: Union[str, bytes, Image.Image]) -> Tuple[Image.Image, torch.Tensor, Tuple[int, int], Tuple[int, int]]:
        """
        Process an image to get patch embeddings in multivector format.

        Args:
            image: Image path, URL, bytes, or PIL Image.

        Returns:
            pil_image: The loaded (and resized) PIL image.
            patch_embeddings: Image patch embeddings [num_patches, embed_dim].
            size: Image size (width, height).
            grid_size: Patch grid dimensions (height, width) after merge.
        """
        pil_image = self._load_image(image)
        proc_out = self.preprocessor.process_images(images=[pil_image])

        # image_grid_thw holds the (temporal, height, width) patch counts.
        image_grid_thw = proc_out["image_grid_thw"]
        _, height, width = image_grid_thw[0].tolist()

        # The vision tower merges patches in 2x2 blocks, so the effective grid
        # is half the raw patch grid in each spatial dimension.
        grid_height = height // 2
        grid_width = width // 2

        size = pil_image.size
        image_embeddings = self.model.encode_image(
            images=[pil_image],
            task="retrieval",
            return_multivector=True,
            max_pixels=1024 * 1024,
            truncate_dim=self.num_vectors,
        )
        image_embeddings = image_embeddings[0]

        # Strip the special vision-delimiter vectors surrounding the patch
        # embeddings: 4 at the start and 7 at the end of the sequence.
        vision_start_position_from_start = 4
        vision_end_position_from_end = 7
        image_embeddings = image_embeddings[vision_start_position_from_start:-vision_end_position_from_end]

        return pil_image, image_embeddings, size, (grid_height, grid_width)

    def _load_image(self, image: Union[str, bytes, Image.Image]) -> Image.Image:
        """Load an image from various formats (URL, path, bytes, PIL Image) and resize it to IMG_SIZE wide."""
        if isinstance(image, Image.Image):
            pil_image = image.convert("RGB")
        elif isinstance(image, str):
            if image.startswith(("http://", "https://")):
                response = requests.get(image, timeout=30)
                response.raise_for_status()
                pil_image = Image.open(BytesIO(response.content)).convert("RGB")
            else:
                pil_image = Image.open(image).convert("RGB")
        elif isinstance(image, bytes):
            pil_image = Image.open(BytesIO(image)).convert("RGB")
        else:
            raise ValueError(f"Unsupported image format: {type(image)}")

        # Resize to a fixed width of IMG_SIZE pixels, preserving aspect ratio.
        original_width, original_height = pil_image.size
        aspect_ratio = original_height / original_width
        new_height = int(IMG_SIZE * aspect_ratio)
        pil_image = pil_image.resize((IMG_SIZE, new_height), Image.Resampling.LANCZOS)
        return pil_image

    def compute_similarity_map(
        self,
        token_embedding: torch.Tensor,
        patch_embeddings: torch.Tensor,
        aggregation: str = "mean",
    ) -> torch.Tensor:
        """
        Compute cosine similarity between a query token and every image patch.

        Args:
            token_embedding: Single token embedding [embed_dim].
            patch_embeddings: Image patch multivectors [num_patches, embed_dim].
            aggregation: Reserved for alternative aggregation schemes; currently unused.

        Returns:
            Similarity scores [num_patches].
        """
        num_patches = patch_embeddings.shape[0]
        token_expanded = token_embedding.expand(num_patches, -1)
        similarity_scores = torch.cosine_similarity(token_expanded, patch_embeddings, dim=1)
        return similarity_scores
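
    # Shape sanity check: a standalone sketch with random tensors (no model
    # needed, since compute_similarity_map reads no instance state):
    #
    #     m = JinaV4SimilarityMapper.__new__(JinaV4SimilarityMapper)
    #     scores = m.compute_similarity_map(torch.randn(128), torch.randn(600, 128))
    #     assert scores.shape == (600,)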

    def generate_heatmap(self, image: Image.Image, similarity_map: torch.Tensor, size: Tuple[int, int], grid_size: Tuple[int, int]) -> str:
        """
        Generate a heatmap overlay on the image and return it as a base64-encoded PNG.

        Args:
            image: Original PIL image.
            similarity_map: Similarity scores [num_patches].
            size: Original image size (width, height).
            grid_size: Patch grid dimensions (height, width).
        """
        grid_height, grid_width = grid_size

        # Min-max normalize the scores to [0, 1] for colormapping.
        similarity_map = (similarity_map - similarity_map.min()) / (
            similarity_map.max() - similarity_map.min() + 1e-8
        )

        # Reshape the flat patch scores back into the 2D patch grid.
        similarity_2d = similarity_map.reshape(grid_height, grid_width).cpu().numpy()

        # Colorize, then upsample the grid to the full image resolution.
        heatmap = (self.colormap(similarity_2d) * 255).astype(np.uint8)
        heatmap = Image.fromarray(heatmap[..., :3], mode="RGB")
        heatmap = heatmap.resize(size, resample=Image.Resampling.BICUBIC)

        # Blend the heatmap over the original image.
        original_rgba = image.convert("RGBA")
        heatmap_rgba = heatmap.convert("RGBA")
        blended = Image.blend(original_rgba, heatmap_rgba, alpha=self.heatmap_alpha)

        buffer = BytesIO()
        blended.save(buffer, format="PNG")
        return base64.b64encode(buffer.getvalue()).decode("utf-8")

    def get_token_similarity_maps(
        self,
        query: str,
        image: Union[str, bytes, Image.Image],
        aggregation: str = "mean",
    ) -> Tuple[List[str], Dict[str, str]]:
        """
        Main method: generate a similarity map for every displayable query token.

        Returns:
            tokens_for_ui: Tokens that survived filtering, in query order.
            heatmaps: Mapping from token to base64-encoded heatmap PNG.
        """
        _, query_embeddings, token_map = self.process_query(query)
        pil_image, patch_embeddings, size, grid_size = self.process_image(image)

        heatmaps = {}
        tokens_for_ui = []

        for idx, token in token_map.items():
            if self._should_filter_token(token):
                continue
            tokens_for_ui.append(token)
            token_embedding = query_embeddings[idx]
            sim_map = self.compute_similarity_map(
                token_embedding, patch_embeddings, aggregation
            )
            heatmap_b64 = self.generate_heatmap(pil_image, sim_map, size, grid_size)
            heatmaps[token] = heatmap_b64

        return tokens_for_ui, heatmaps

    def _should_filter_token(self, token: str) -> bool:
        """Filter out irrelevant tokens: whitespace, pure punctuation, and special markers like <...>."""
        if token.strip() == "" or re.match(r'^\s*$|^[^\w\s]+$|^<.*>$', token):
            return True
        return False
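

# A minimal end-to-end sketch. The query, image URL, and output filenames are
# illustrative placeholders; JINA_API_KEY is an assumed environment variable.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    mapper = JinaV4SimilarityMapper(client_type="web")
    mapper.model.set_api_key(os.environ["JINA_API_KEY"])

    tokens, heatmaps = mapper.get_token_similarity_maps(
        query="a red bicycle leaning against a wall",
        image="https://example.com/bicycle.jpg",
    )
    for idx, token in enumerate(tokens):
        out_path = f"heatmap_{idx}_{token}.png"
        with open(out_path, "wb") as f:
            f.write(base64.b64decode(heatmaps[token]))
        print(f"wrote {out_path}")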
|