fix: scaling based on preprocessor instead of grid estimation and padding (#3)
Commit 2eed7787ccd86acc072c3093e511f42715701436
similarity.py CHANGED: +34 -50
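Why this matters: the old generate_heatmap guessed the patch grid from the patch count and the image aspect ratio, then padded or truncated the similarity map to fit; whenever the estimate disagreed with the real grid, every reshaped row drifted sideways. The new code reads the exact grid from the preprocessor's image_grid_thw and halves it for the 2x2 merge. A minimal sketch with assumed numbers (a 1920x1080 input and a hypothetical 78x136 raw grid; none of these values are from a real run):

import numpy as np
import torch

image_grid_thw = torch.tensor([[1, 78, 136]])  # assumed preprocessor output
width, height = 1920, 1080                     # assumed original image size
num_patches = 2652                             # 39 * 68 merged patch vectors

# Old approach: estimate the grid from count + aspect ratio, then pad.
est_w = int(np.ceil(np.sqrt(num_patches * width / height)))  # 69, not 68
est_h = int(np.ceil(num_patches / est_w))                    # 39
# 69 * 39 = 2691 cells for 2652 patches: 39 zeros of padding, and every
# reshaped row lands one column off against the true 68-wide grid.
print(est_h * est_w - num_patches)

# New approach: take the exact grid from the preprocessor (2x2 merge).
_, h, w = image_grid_thw[0].tolist()
grid_h, grid_w = h // 2, w // 2        # 39, 68
assert grid_h * grid_w == num_patches  # reshape is exact; no padding needed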
@@ -189,7 +189,7 @@ class JinaV4SimilarityMapper:
         print(f"Token map: {token_map}")
         return tokens, query_embeddings, token_map
 
-    def process_image(self, image: Union[str, bytes, Image.Image]) -> Tuple[Image.Image, torch.Tensor, Tuple[int, int]]:
+    def process_image(self, image: Union[str, bytes, Image.Image]) -> Tuple[Image.Image, torch.Tensor, Tuple[int, int], Tuple[int, int]]:
         """
         Process image to get patch embeddings in multivector format.
 
@@ -200,34 +200,34 @@ class JinaV4SimilarityMapper:
             pil_image: Original PIL image.
             patch_embeddings: Image patch embeddings [num_patches/num_vectors, embed_dim].
             size: Original image size (width, height).
+            grid_size: Patch grid dimensions (height, width) after merge.
         """
         pil_image = self._load_image(image)
-
         proc_out = self.preprocessor.process_images(images=[pil_image])
-
-
-
-
+
+        # Get the grid dimensions from preprocessor
+        image_grid_thw = proc_out["image_grid_thw"]
+        _, height, width = image_grid_thw[0].tolist()
+        # Account for 2x2 merge
+        grid_height = height // 2
+        grid_width = width // 2
+
         size = pil_image.size
         image_embeddings = self.model.encode_image(
             images=[pil_image],
             task="retrieval",
             return_multivector=True,
-            max_pixels=1024*1024,
+            max_pixels=1024*1024,
             truncate_dim=self.num_vectors
         )
-        image_embeddings = image_embeddings[0]
-
-        image_embeddings = image_embeddings[non_zero_mask]
-
-        # <|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>Describe the image.<|im_end|>\n
-        vision_start_position_from_start = 3 + 1
-        vision_end_position_from_end = 6 + 1
+        image_embeddings = image_embeddings[0]
+
         # Remove special tokens
+        vision_start_position_from_start = 5
+        vision_end_position_from_end = 6
         image_embeddings = image_embeddings[vision_start_position_from_start:-vision_end_position_from_end]
-
-
-        return pil_image, image_embeddings, size
+
+        return pil_image, image_embeddings, size, (grid_height, grid_width)
 
     def _load_image(self, image: Union[str, bytes, Image.Image]) -> Image.Image:
         """Load image from various formats (URL, path, bytes, PIL Image)."""
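The special-token trim also changed (offsets 5 and 6 instead of 3 + 1 and 6 + 1, and the non_zero_mask filtering is gone). A sanity-check sketch with illustrative shapes, not a real model run: after dropping the assumed 5 leading and 6 trailing template/special-token vectors, the remainder must equal grid_height * grid_width for the reshape in generate_heatmap to be valid.

import torch

grid_height, grid_width = 39, 68  # from image_grid_thw // 2
embed_dim = 128                   # assumed truncate_dim
# Stand-in for encode_image output: patch vectors wrapped in template tokens.
seq = torch.randn(5 + grid_height * grid_width + 6, embed_dim)
patches = seq[5:-6]
assert patches.shape[0] == grid_height * grid_width
patches.reshape(grid_height, grid_width, embed_dim)  # no padding required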
@@ -273,45 +273,37 @@ class JinaV4SimilarityMapper:
         similarity_scores = torch.cosine_similarity(token_expanded, patch_embeddings, dim=1)
         return similarity_scores
 
-    def generate_heatmap(self, image: Image.Image, similarity_map: torch.Tensor, size: Tuple[int, int]) -> str:
+    def generate_heatmap(self, image: Image.Image, similarity_map: torch.Tensor, size: Tuple[int, int], grid_size: Tuple[int, int]) -> str:
         """
         Generate a heatmap overlay on the image and return as base64.
 
         Args:
             image: Original PIL image.
-            similarity_map: Similarity scores [
+            similarity_map: Similarity scores [num_patches].
             size: Original image size (width, height).
-
-        Returns:
-            Base64-encoded PNG image with heatmap.
+            grid_size: Patch grid dimensions (height, width).
         """
-        num_patches = similarity_map.shape[0]
+        # num_patches = similarity_map.shape[0]
+        grid_height, grid_width = grid_size
+
         # Normalize to [0, 1]
         similarity_map = (similarity_map - similarity_map.min()) / (
             similarity_map.max() - similarity_map.min() + 1e-8
         )
-
-
-        aspect_ratio = width / height
-        grid_width = int(np.ceil(np.sqrt(num_patches * aspect_ratio)))
-        grid_height = int(np.ceil(num_patches / grid_width))
-        total_patches = grid_width * grid_height
-        # Ensure similarity map fits grid (padding/truncation)
-        if num_patches < total_patches:
-            padding = torch.zeros(total_patches - num_patches, device=similarity_map.device)
-            similarity_map = torch.cat([similarity_map, padding])
-        else:
-            similarity_map = similarity_map[:total_patches]
-        # Reshape to 2D grid [grid_height, grid_width]
+
+        # Reshape to 2D grid
         similarity_2d = similarity_map.reshape(grid_height, grid_width).cpu().numpy()
+
         # Create & resize heatmap
         heatmap = (self.colormap(similarity_2d) * 255).astype(np.uint8)
         heatmap = Image.fromarray(heatmap[..., :3], mode="RGB")
         heatmap = heatmap.resize(size, resample=Image.BICUBIC)
+
         # Blend with original image
         original_rgba = image.convert("RGBA")
         heatmap_rgba = heatmap.convert("RGBA")
         blended = Image.blend(original_rgba, heatmap_rgba, alpha=self.heatmap_alpha)
+
         # Encode to base64
         buffer = BytesIO()
         blended.save(buffer, format="PNG")
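For reference, the overlay path that generate_heatmap keeps unchanged, as a standalone sketch; the random similarity map, white base image, 22x35 grid, and "jet" colormap are assumptions standing in for real scores, the real photo, and self.colormap.

import base64
from io import BytesIO

import matplotlib
import numpy as np
import torch
from PIL import Image

grid_height, grid_width = 22, 35
size = (980, 616)  # (width, height) of the original image
colormap = matplotlib.colormaps["jet"]

# Normalize, then reshape on the exact grid -- no padding step anymore.
sim = torch.rand(grid_height * grid_width)
sim = (sim - sim.min()) / (sim.max() - sim.min() + 1e-8)
sim_2d = sim.reshape(grid_height, grid_width).cpu().numpy()

# Colorize the grid, resize to image resolution, blend, base64-encode.
heatmap = (colormap(sim_2d) * 255).astype(np.uint8)  # RGBA floats -> uint8
heatmap = Image.fromarray(heatmap[..., :3], mode="RGB")
heatmap = heatmap.resize(size, resample=Image.BICUBIC)

image = Image.new("RGBA", size, "white")  # stand-in for the original photo
blended = Image.blend(image, heatmap.convert("RGBA"), alpha=0.5)

buffer = BytesIO()
blended.save(buffer, format="PNG")
b64 = base64.b64encode(buffer.getvalue()).decode("utf-8")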
@@ -325,30 +317,22 @@ class JinaV4SimilarityMapper:
     ) -> Tuple[List[str], Dict[str, str]]:
         """
         Main method to generate similarity maps for all query tokens.
-
-        Args:
-            query: Input query text.
-            image: Image to analyze.
-            aggregation: How to aggregate multivector similarities.
-
-        Returns:
-            tokens: List of query tokens.
-            heatmaps: Dictionary of {token: base64_heatmap}.
         """
-
-        pil_image, patch_embeddings, size = self.process_image(image)
+        _, query_embeddings, token_map = self.process_query(query)
+        pil_image, patch_embeddings, size, grid_size = self.process_image(image)
+
         heatmaps = {}
         tokens_for_ui = []
+
         for idx, token in token_map.items():
-            print(f"Processing token: {token} (index {idx})")
             if self._should_filter_token(token):
                 continue
             tokens_for_ui.append(token)
-            token_embedding = query_embeddings[idx]
+            token_embedding = query_embeddings[idx]
             sim_map = self.compute_similarity_map(
                 token_embedding, patch_embeddings, aggregation
             )
-            heatmap_b64 = self.generate_heatmap(pil_image, sim_map, size)
+            heatmap_b64 = self.generate_heatmap(pil_image, sim_map, size, grid_size)
             heatmaps[token] = heatmap_b64
 
         return tokens_for_ui, heatmaps
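End to end, the call site is unchanged except that grid_size now flows from process_image into generate_heatmap. A hypothetical usage sketch: the entry point's name is not visible in the hunk, so generate_similarity_maps and the no-argument constructor below are assumptions, not the module's confirmed API.

import base64

mapper = JinaV4SimilarityMapper()  # assumed constructor signature
tokens, heatmaps = mapper.generate_similarity_maps(  # assumed method name
    query="a red bicycle leaning on a fence",
    image="photo.jpg",
    aggregation="mean",
)
for token in tokens:
    # Each heatmap is a base64-encoded PNG overlay, per the docstring above.
    with open(f"heatmap_{token}.png", "wb") as f:
        f.write(base64.b64decode(heatmaps[token]))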