fix: scaling based on preprocessor instead of grid estimation and padding
#3 by ajinauser · opened
similarity.py CHANGED (+34 −50)
@@ -189,7 +189,7 @@ class JinaV4SimilarityMapper:
         print(f"Token map: {token_map}")
         return tokens, query_embeddings, token_map
 
-    def process_image(self, image: Union[str, bytes, Image.Image]) -> Tuple[Image.Image, torch.Tensor, Tuple[int, int]]:
+    def process_image(self, image: Union[str, bytes, Image.Image]) -> Tuple[Image.Image, torch.Tensor, Tuple[int, int], Tuple[int, int]]:
         """
         Process image to get patch embeddings in multivector format.
 
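A side note on the new signature: `process_image` now returns a 4-tuple, which is easy to mis-unpack at call sites. A minimal sketch of a `NamedTuple` alternative, using only the fields already in the diff (`ProcessedImage` is a hypothetical name, not part of this PR):

```python
from typing import NamedTuple, Tuple

import torch
from PIL import Image

class ProcessedImage(NamedTuple):
    """Hypothetical typed container for process_image's return values."""
    pil_image: Image.Image
    patch_embeddings: torch.Tensor   # [num_patches, embed_dim]
    size: Tuple[int, int]            # original (width, height)
    grid_size: Tuple[int, int]       # merged patch grid (height, width)
```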
@@ -200,34 +200,34 @@ class JinaV4SimilarityMapper:
             pil_image: Original PIL image.
             patch_embeddings: Image patch embeddings [num_patches/num_vectors, embed_dim].
             size: Original image size (width, height).
+            grid_size: Patch grid dimensions (height, width) after merge.
         """
         pil_image = self._load_image(image)
-
         proc_out = self.preprocessor.process_images(images=[pil_image])
-
-
-
-
+
+        # Get the grid dimensions from preprocessor
+        image_grid_thw = proc_out["image_grid_thw"]
+        _, height, width = image_grid_thw[0].tolist()
+        # Account for 2x2 merge
+        grid_height = height // 2
+        grid_width = width // 2
+
         size = pil_image.size
         image_embeddings = self.model.encode_image(
             images=[pil_image],
             task="retrieval",
             return_multivector=True,
-            max_pixels=1024*1024,
+            max_pixels=1024*1024,
             truncate_dim=self.num_vectors
         )
-        image_embeddings = image_embeddings[0]
-
-        image_embeddings = image_embeddings[non_zero_mask]
-
-        # <|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>Describe the image.<|im_end|>\n
-        vision_start_position_from_start = 3 + 1
-        vision_end_position_from_end = 6 + 1
+        image_embeddings = image_embeddings[0]
+
         # Remove special tokens
+        vision_start_position_from_start = 5
+        vision_end_position_from_end = 6
         image_embeddings = image_embeddings[vision_start_position_from_start:-vision_end_position_from_end]
-
-
-        return pil_image, image_embeddings, size
+
+        return pil_image, image_embeddings, size, (grid_height, grid_width)
 
     def _load_image(self, image: Union[str, bytes, Image.Image]) -> Image.Image:
         """Load image from various formats (URL, path, bytes, PIL Image)."""
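This hunk is the core of the fix: the patch grid now comes from the preprocessor's `image_grid_thw` instead of being estimated later. A minimal sketch of the arithmetic, assuming a Qwen2-VL-style preprocessor with 14-pixel patches merged 2×2 into one output vector (the patch and merge sizes are assumptions, and the numbers are illustrative):

```python
import torch

def merged_grid(image_grid_thw: torch.Tensor, merge_size: int = 2):
    """Merged patch grid (rows, cols) for the first image in a batch."""
    _, h, w = image_grid_thw[0].tolist()  # t (frames) is 1 for still images
    return h // merge_size, w // merge_size

# An image resized to 1008x756 px (both multiples of 14 * 2 = 28) gives a
# raw patch grid of 54 rows x 72 cols, i.e. image_grid_thw = [[1, 54, 72]].
print(merged_grid(torch.tensor([[1, 54, 72]])))  # (27, 36) -> 972 vectors
```

This also explains the `// 2` in the diff: `image_grid_thw` counts raw patches, while the multivector output carries one embedding per 2×2-merged block.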
@@ -273,45 +273,37 @@ class JinaV4SimilarityMapper:
         similarity_scores = torch.cosine_similarity(token_expanded, patch_embeddings, dim=1)
         return similarity_scores
 
-    def generate_heatmap(self, image: Image.Image, similarity_map: torch.Tensor, size: Tuple[int, int]) -> str:
+    def generate_heatmap(self, image: Image.Image, similarity_map: torch.Tensor, size: Tuple[int, int], grid_size: Tuple[int, int]) -> str:
         """
         Generate a heatmap overlay on the image and return as base64.
 
         Args:
             image: Original PIL image.
-            similarity_map: Similarity scores [num_patches].
+            similarity_map: Similarity scores [num_patches].
             size: Original image size (width, height).
-
-        Returns:
-            Base64-encoded PNG image with heatmap.
+            grid_size: Patch grid dimensions (height, width).
         """
-        num_patches = similarity_map.shape[0]
+        # num_patches = similarity_map.shape[0]
+        grid_height, grid_width = grid_size
+
         # Normalize to [0, 1]
         similarity_map = (similarity_map - similarity_map.min()) / (
             similarity_map.max() - similarity_map.min() + 1e-8
         )
-
-
-        aspect_ratio = width / height
-        grid_width = int(np.ceil(np.sqrt(num_patches * aspect_ratio)))
-        grid_height = int(np.ceil(num_patches / grid_width))
-        total_patches = grid_width * grid_height
-        # Ensure similarity map fits grid (padding/truncation)
-        if num_patches < total_patches:
-            padding = torch.zeros(total_patches - num_patches, device=similarity_map.device)
-            similarity_map = torch.cat([similarity_map, padding])
-        else:
-            similarity_map = similarity_map[:total_patches]
-        # Reshape to 2D grid [grid_height, grid_width]
+
+        # Reshape to 2D grid
         similarity_2d = similarity_map.reshape(grid_height, grid_width).cpu().numpy()
+
         # Create & resize heatmap
         heatmap = (self.colormap(similarity_2d) * 255).astype(np.uint8)
         heatmap = Image.fromarray(heatmap[..., :3], mode="RGB")
         heatmap = heatmap.resize(size, resample=Image.BICUBIC)
+
         # Blend with original image
         original_rgba = image.convert("RGBA")
         heatmap_rgba = heatmap.convert("RGBA")
         blended = Image.blend(original_rgba, heatmap_rgba, alpha=self.heatmap_alpha)
+
         # Encode to base64
         buffer = BytesIO()
         blended.save(buffer, format="PNG")
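For context on why the removed estimate-and-pad branch was fragile: it inferred the grid from the vector count and the original aspect ratio, then padded or truncated to fit. A toy sketch of the failure mode, with assumed numbers (not taken from the repo):

```python
import numpy as np

true_grid = (17, 23)        # merged grid the preprocessor actually produced
num_patches = 380           # vectors left after masking/truncation
aspect_ratio = 640 / 480    # original image, not the resized grid

grid_w = int(np.ceil(np.sqrt(num_patches * aspect_ratio)))  # 23
grid_h = int(np.ceil(num_patches / grid_w))                 # 17
print(grid_h * grid_w - num_patches)  # 11 cells zero-padded at the end
```

Appending zeros blanks the tail cells, and if vectors were dropped from the middle (as with the old `non_zero_mask`), every later patch shifts into the wrong grid cell. Reading the grid straight from `image_grid_thw` removes both problems.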
@@ -325,30 +317,22 @@ class JinaV4SimilarityMapper:
     ) -> Tuple[List[str], Dict[str, str]]:
         """
         Main method to generate similarity maps for all query tokens.
-
-        Args:
-            query: Input query text.
-            image: Image to analyze.
-            aggregation: How to aggregate multivector similarities.
-
-        Returns:
-            tokens: List of query tokens.
-            heatmaps: Dictionary of {token: base64_heatmap}.
         """
-
-        pil_image, patch_embeddings, size = self.process_image(image)
+        _, query_embeddings, token_map = self.process_query(query)
+        pil_image, patch_embeddings, size, grid_size = self.process_image(image)
+
         heatmaps = {}
         tokens_for_ui = []
+
         for idx, token in token_map.items():
-            print(f"Processing token: {token} (index {idx})")
             if self._should_filter_token(token):
                 continue
             tokens_for_ui.append(token)
-            token_embedding = query_embeddings[idx]
+            token_embedding = query_embeddings[idx]
            sim_map = self.compute_similarity_map(
                 token_embedding, patch_embeddings, aggregation
             )
-            heatmap_b64 = self.generate_heatmap(pil_image, sim_map, size)
+            heatmap_b64 = self.generate_heatmap(pil_image, sim_map, size, grid_size)
             heatmaps[token] = heatmap_b64
 
         return tokens_for_ui, heatmaps
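A hypothetical end-to-end call after this change. The entry-point name `generate_similarity_maps`, the no-argument constructor, and the `aggregation` value are assumptions based on the docstrings, since the method's `def` line sits above this hunk:

```python
from similarity import JinaV4SimilarityMapper

mapper = JinaV4SimilarityMapper()  # assumed default constructor
tokens, heatmaps = mapper.generate_similarity_maps(  # hypothetical name
    query="a red bicycle leaning against a wall",
    image="photo.jpg",
    aggregation="max",  # assumed aggregation mode
)
for token in tokens:
    png_b64 = heatmaps[token]  # base64-encoded PNG overlay per token
```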