fix: compute the patch grid from the preprocessor instead of estimating and padding it

#3
Files changed (1)
  1. similarity.py +34 -50
similarity.py CHANGED
@@ -189,7 +189,7 @@ class JinaV4SimilarityMapper:
         print(f"Token map: {token_map}")
         return tokens, query_embeddings, token_map
 
-    def process_image(self, image: Union[str, bytes, Image.Image]) -> Tuple[Image.Image, torch.Tensor, Tuple[int, int]]:
+    def process_image(self, image: Union[str, bytes, Image.Image]) -> Tuple[Image.Image, torch.Tensor, Tuple[int, int], Tuple[int, int]]:
         """
         Process image to get patch embeddings in multivector format.
 
@@ -200,34 +200,34 @@ class JinaV4SimilarityMapper:
             pil_image: Original PIL image.
             patch_embeddings: Image patch embeddings [num_patches/num_vectors, embed_dim].
             size: Original image size (width, height).
+            grid_size: Patch grid dimensions (height, width) after merge.
         """
         pil_image = self._load_image(image)
-
         proc_out = self.preprocessor.process_images(images=[pil_image])
-        for key, value in proc_out.items():
-            if isinstance(value, torch.Tensor):
-                print(f"proc out {key} shape: {value.shape}")
-
+
+        # Get the grid dimensions from preprocessor
+        image_grid_thw = proc_out["image_grid_thw"]
+        _, height, width = image_grid_thw[0].tolist()
+        # Account for 2x2 merge
+        grid_height = height // 2
+        grid_width = width // 2
+
         size = pil_image.size
         image_embeddings = self.model.encode_image(
             images=[pil_image],
             task="retrieval",
             return_multivector=True,
-            max_pixels=1024*1024,  # Max image resolution
+            max_pixels=1024*1024,
             truncate_dim=self.num_vectors
         )
-        image_embeddings = image_embeddings[0]  # [num_patches/num_vectors, embed_dim]
-        non_zero_mask = (image_embeddings.abs().sum(dim=1) > 0)
-        image_embeddings = image_embeddings[non_zero_mask]
-
-        # <|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>Describe the image.<|im_end|>\n
-        vision_start_position_from_start = 3 + 1
-        vision_end_position_from_end = 6 + 1
+        image_embeddings = image_embeddings[0]
+
         # Remove special tokens
+        vision_start_position_from_start = 5
+        vision_end_position_from_end = 6
         image_embeddings = image_embeddings[vision_start_position_from_start:-vision_end_position_from_end]
-        print(f"Image embeddings shape: {image_embeddings.shape}")
-        print(f"Image size: {size}")
-        return pil_image, image_embeddings, size
+
+        return pil_image, image_embeddings, size, (grid_height, grid_width)
 
     def _load_image(self, image: Union[str, bytes, Image.Image]) -> Image.Image:
         """Load image from various formats (URL, path, bytes, PIL Image)."""
@@ -273,45 +273,37 @@ class JinaV4SimilarityMapper:
         similarity_scores = torch.cosine_similarity(token_expanded, patch_embeddings, dim=1)
         return similarity_scores
 
-    def generate_heatmap(self, image: Image.Image, similarity_map: torch.Tensor, size: Tuple[int, int]) -> str:
+    def generate_heatmap(self, image: Image.Image, similarity_map: torch.Tensor, size: Tuple[int, int], grid_size: Tuple[int, int]) -> str:
         """
         Generate a heatmap overlay on the image and return as base64.
 
         Args:
             image: Original PIL image.
-            similarity_map: Similarity scores [num_vectors/num_patches].
+            similarity_map: Similarity scores [num_patches].
             size: Original image size (width, height).
-
-        Returns:
-            Base64-encoded PNG image with heatmap.
+            grid_size: Patch grid dimensions (height, width).
         """
-        num_patches = similarity_map.shape[0]
+        # num_patches = similarity_map.shape[0]
+        grid_height, grid_width = grid_size
+
         # Normalize to [0, 1]
         similarity_map = (similarity_map - similarity_map.min()) / (
             similarity_map.max() - similarity_map.min() + 1e-8
         )
-        # Calculate grid dimensions from image aspect ratio
-        width, height = size
-        aspect_ratio = width / height
-        grid_width = int(np.ceil(np.sqrt(num_patches * aspect_ratio)))
-        grid_height = int(np.ceil(num_patches / grid_width))
-        total_patches = grid_width * grid_height
-        # Ensure similarity map fits grid (padding/truncation)
-        if num_patches < total_patches:
-            padding = torch.zeros(total_patches - num_patches, device=similarity_map.device)
-            similarity_map = torch.cat([similarity_map, padding])
-        else:
-            similarity_map = similarity_map[:total_patches]
-        # Reshape to 2D grid [grid_height, grid_width]
+
+        # Reshape to 2D grid
         similarity_2d = similarity_map.reshape(grid_height, grid_width).cpu().numpy()
+
         # Create & resize heatmap
         heatmap = (self.colormap(similarity_2d) * 255).astype(np.uint8)
         heatmap = Image.fromarray(heatmap[..., :3], mode="RGB")
         heatmap = heatmap.resize(size, resample=Image.BICUBIC)
+
         # Blend with original image
         original_rgba = image.convert("RGBA")
         heatmap_rgba = heatmap.convert("RGBA")
         blended = Image.blend(original_rgba, heatmap_rgba, alpha=self.heatmap_alpha)
+
         # Encode to base64
         buffer = BytesIO()
         blended.save(buffer, format="PNG")
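Since `grid_size` now comes straight from the preprocessor, `num_patches == grid_height * grid_width` should hold exactly, so the old estimate-then-pad-or-truncate branch becomes unnecessary and `reshape` alone does the job. A self-contained sketch of the normalize/reshape/colorize path with made-up shapes (the real method uses `self.colormap` and blends with `self.heatmap_alpha`):

```python
import matplotlib
import numpy as np
import torch
from PIL import Image

grid_height, grid_width = 32, 44   # made-up merged patch grid
size = (1232, 896)                 # made-up original (width, height)
similarity_map = torch.rand(grid_height * grid_width)

# Normalize to [0, 1], reshape to the patch grid, colorize, and upsample.
similarity_map = (similarity_map - similarity_map.min()) / (
    similarity_map.max() - similarity_map.min() + 1e-8
)
similarity_2d = similarity_map.reshape(grid_height, grid_width).cpu().numpy()
heatmap = (matplotlib.colormaps["viridis"](similarity_2d) * 255).astype(np.uint8)
heatmap = Image.fromarray(heatmap[..., :3], mode="RGB").resize(size, Image.BICUBIC)
```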
@@ -325,30 +317,22 @@ class JinaV4SimilarityMapper:
     ) -> Tuple[List[str], Dict[str, str]]:
         """
         Main method to generate similarity maps for all query tokens.
-
-        Args:
-            query: Input query text.
-            image: Image to analyze.
-            aggregation: How to aggregate multivector similarities.
-
-        Returns:
-            tokens: List of query tokens.
-            heatmaps: Dictionary of {token: base64_heatmap}.
         """
-        tokens, query_embeddings, token_map = self.process_query(query)
-        pil_image, patch_embeddings, size = self.process_image(image)
+        _, query_embeddings, token_map = self.process_query(query)
+        pil_image, patch_embeddings, size, grid_size = self.process_image(image)
+
         heatmaps = {}
         tokens_for_ui = []
+
         for idx, token in token_map.items():
-            print(f"Processing token: {token} (index {idx})")
             if self._should_filter_token(token):
                 continue
             tokens_for_ui.append(token)
-            token_embedding = query_embeddings[idx]  # [embed_dim]
+            token_embedding = query_embeddings[idx]
             sim_map = self.compute_similarity_map(
                 token_embedding, patch_embeddings, aggregation
             )
-            heatmap_b64 = self.generate_heatmap(pil_image, sim_map, size)
+            heatmap_b64 = self.generate_heatmap(pil_image, sim_map, size, grid_size)
             heatmaps[token] = heatmap_b64
 
         return tokens_for_ui, heatmaps
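One possible follow-up, not part of this PR: the special-token offsets (5 leading and 6 trailing vectors around the image patches) are hard-coded in `process_image`, so a cheap assertion would catch any drift between the encoder's token layout and the preprocessor grid. A hypothetical helper:

```python
def check_patch_count(patch_embeddings, grid_size):
    """Assert the embeddings left after stripping special tokens fill the grid."""
    grid_height, grid_width = grid_size
    expected = grid_height * grid_width
    assert patch_embeddings.shape[0] == expected, (
        f"expected {expected} patch vectors, got {patch_embeddings.shape[0]}"
    )

# Hypothetical usage after the new process_image:
# pil_image, patch_embeddings, size, grid_size = mapper.process_image(image)
# check_patch_count(patch_embeddings, grid_size)
```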
 