feat-remove-paddings-0623 (#22)
- fix: remove the padding tokens when a list of multivectors is returned (ef1876f5b9dbe290d7a58ff16fc37367217d32c5)
- fix: fix the bug when return_numpy is false (77d5a29ef8ef2396f56106d8ed882e322b7dc9be)
- fix: fix the bug when return_numpy is false (205b18f42bb9bd3ee57bab31cdd1b3116b6d762b)
- fix: fix the bug when return_numpy is false (6bb8cf2b7575b11b19c8f9780efda6f0b1a61708)
- fix: fix the bug (3ad717f7eaca1d26701063a16cfbe5f40ebaf551)
modeling_jina_embeddings_v4.py
CHANGED
@@ -127,11 +127,13 @@ class JinaEmbeddingsV4ModelOutput:
         vlm_last_hidden_states (torch.Tensor, optional): Last hidden states of the VLM.
         single_vec_emb (torch.Tensor, optional): Single-vector embeddings.
         multi_vec_emb (torch.Tensor, optional): Multi-vector embeddings.
+        attention_mask (torch.Tensor, optional): Attention mask.
     """
 
     vlm_last_hidden_states: Optional[torch.Tensor] = None
     single_vec_emb: Optional[torch.Tensor] = None
     multi_vec_emb: Optional[torch.Tensor] = None
+    attention_mask: Optional[torch.Tensor] = None
 
 
 class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
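For context on the new field: attention_mask records, for every sequence in the batch, which token positions hold real tokens (1) and which are padding added so the batch can be stacked (0). A tiny illustration with toy values, not part of the diff:

import torch

# Toy mask for a batch of two sequences padded to length 4:
# 1 marks a real token, 0 marks padding appended by the processor.
attention_mask = torch.tensor([[1, 1, 1, 0],
                               [1, 1, 0, 0]])
print(attention_mask.sum(dim=1))  # tensor([3, 2]) -> valid tokens per sequence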
@@ -312,6 +314,7 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
             ),
             single_vec_emb=single_vec_emb,
             multi_vec_emb=multi_vec_emb,
+            attention_mask=attention_mask,
         )
 
     def _process_batches(
@@ -331,6 +334,8 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
             shuffle=False,
             collate_fn=processor_fn,
         )
+        if return_multivector and len(data) > 1:
+            assert not return_numpy, "`return_numpy` is not supported when `return_multivector=True` and more than one data is encoded"
         results = []
         self.eval()
         for batch in tqdm(dataloader, desc=desc):
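The new guard exists because unpadded multi-vector embeddings are ragged: each input keeps a different number of token vectors, and ragged tensors cannot be packed into one rectangular numpy array. A small sketch of the failure the assert prevents (toy shapes, not the model's tensors):

import numpy as np
import torch

a = torch.randn(5, 128)  # multi-vector embedding with 5 valid tokens
b = torch.randn(9, 128)  # multi-vector embedding with 9 valid tokens
try:
    np.stack([a.numpy(), b.numpy()])  # no rectangular array fits both shapes
except ValueError as err:
    print("cannot stack ragged multivectors:", err)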
@@ -340,17 +345,23 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
                     device_type=torch.device(self.device).type, dtype=torch.bfloat16
                 ):
                     embeddings = self(**batch, task_label=task_label)
+                    attention_mask = embeddings.attention_mask
                     if not return_multivector:
                         embeddings = embeddings.single_vec_emb
                         if truncate_dim is not None:
                             embeddings = embeddings[:, :truncate_dim]
                     else:
                         embeddings = embeddings.multi_vec_emb
-
-
-
-
-
+                    if return_multivector and not return_numpy:
+                        valid_tokens = attention_mask.bool()
+                        embeddings = [emb[mask] for emb, mask in zip(embeddings, valid_tokens)]
+                        results.append(embeddings)
+                    else:
+                        results.append(
+                            embeddings.cpu()
+                            if return_numpy
+                            else list(torch.unbind(embeddings))
+                        )
         if return_numpy:
             return np.concatenate([result.numpy() for result in results], axis=0)
         return [item for sublist in results for item in sublist]
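This is the core of the fix: each sample's multi-vector embedding is filtered with its own row of the attention mask, so padding vectors are dropped and every element of the returned list keeps only its real tokens. A standalone sketch of the same indexing pattern on toy tensors:

import torch

embeddings = torch.randn(2, 4, 3)  # (batch, padded_seq_len, dim)
attention_mask = torch.tensor([[1, 1, 1, 0],
                               [1, 1, 0, 0]])
valid_tokens = attention_mask.bool()

# Boolean indexing keeps only the rows where the mask is True,
# so each sample ends up with its own number of token vectors.
unpadded = [emb[mask] for emb, mask in zip(embeddings, valid_tokens)]
print([t.shape for t in unpadded])  # [torch.Size([3, 3]), torch.Size([2, 3])]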
@@ -436,6 +447,12 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
         )
 
         return_list = isinstance(texts, list)
+
+        # If return_multivector is True and encoding multiple texts, ignore return_numpy
+        if return_multivector and return_list and len(texts) > 1:
+            if return_numpy:
+                print("Warning: `return_numpy` is ignored when `return_multivector=True` and `len(texts) > 1`")
+            return_numpy = False
 
         if isinstance(texts, str):
             texts = [texts]
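Effect on encode_text: when several texts are encoded with return_multivector=True, return_numpy is now ignored (with the warning above) and a list of variable-length tensors is returned. A hypothetical usage sketch; the model id and the task argument are assumptions based on the repository's usual usage, not on this diff:

from transformers import AutoModel

model = AutoModel.from_pretrained("jinaai/jina-embeddings-v4", trust_remote_code=True)

embs = model.encode_text(
    texts=["short query", "a noticeably longer query about something else"],
    task="retrieval",          # assumed task name
    return_multivector=True,   # per-token embeddings
    return_numpy=True,         # ignored here; a warning is printed instead
)
# One tensor per text, each with its own number of valid tokens (padding removed).
print([e.shape for e in embs])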
@@ -484,7 +501,7 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
             images: image(s) to encode, can be PIL Image(s), URL(s), or local file path(s)
             batch_size: Number of images to process at once
             return_multivector: Whether to return multi-vector embeddings instead of single-vector embeddings
-            return_numpy: Whether to return numpy arrays instead of torch tensors
+            return_numpy: Whether to return numpy arrays instead of torch tensors. If `return_multivector` is `True` and more than one image is encoded, this parameter is ignored.
             truncate_dim: Dimension to truncate embeddings to (128, 256, 512, or 1024)
             max_pixels: Maximum number of pixels to process per image
 
@@ -501,6 +518,12 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
 
         return_list = isinstance(images, list)
 
+        # If return_multivector is True and encoding multiple images, ignore return_numpy
+        if return_multivector and return_list and len(images) > 1:
+            if return_numpy:
+                print("Warning: `return_numpy` is ignored when `return_multivector=True` and `len(images) > 1`")
+            return_numpy = False
+
         # Convert single image to list
         if isinstance(images, (str, Image.Image)):
             images = [images]
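The encode_image guard mirrors the text one and only triggers for more than one image; a single image with return_multivector=True still honors return_numpy. A hedged sketch under the same assumptions as the previous example (model id, task name, and the placeholder URLs are illustrative):

from transformers import AutoModel

model = AutoModel.from_pretrained("jinaai/jina-embeddings-v4", trust_remote_code=True)

# Single image: the guard does not fire, so return_numpy is respected.
single = model.encode_image(
    images="https://example.com/page.png",
    task="retrieval",
    return_multivector=True,
    return_numpy=True,
)

# Several images: return_numpy is ignored with a warning and a list of
# variable-length torch tensors is returned instead.
several = model.encode_image(
    images=["https://example.com/a.png", "https://example.com/b.png"],
    task="retrieval",
    return_multivector=True,
    return_numpy=True,
)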