fix: fix the bug when return_numpy is false
modeling_jina_embeddings_v4.py
CHANGED
@@ -334,6 +334,8 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
             shuffle=False,
             collate_fn=processor_fn,
         )
+        if return_multivector and len(data) > 1:
+            assert not return_numpy, "`return_numpy` is not supported when `return_multivector=True` and more than one data is encoded"
         results = []
         self.eval()
         for batch in tqdm(dataloader, desc=desc):
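The new guard encodes an invariant rather than new behavior: multi-vector output is one (n_tokens, dim) tensor per input, and n_tokens varies per input, so more than one of them cannot be packed into a single dense numpy array. A minimal sketch of the same check outside the model (the function name is ours, not part of the file):

def check_return_flags(num_items: int, return_multivector: bool, return_numpy: bool) -> None:
    # Mirrors the guard added above: ragged multi-vector outputs cannot be
    # stacked into one dense numpy array, so fail fast on the combination.
    if return_multivector and num_items > 1:
        assert not return_numpy, (
            "`return_numpy` is not supported when `return_multivector=True` "
            "and more than one data is encoded"
        )

check_return_flags(1, return_multivector=True, return_numpy=True)   # ok: one item is never ragged
check_return_flags(4, return_multivector=True, return_numpy=False)  # ok: list of torch tensors
# check_return_flags(4, return_multivector=True, return_numpy=True) # AssertionError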
@@ -344,23 +346,18 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
                 ):
                     embeddings = self(**batch, task_label=task_label)
                     attention_mask = embeddings.attention_mask
-                    if not (return_multivector and return_numpy):
+                    if not return_multivector:
                         embeddings = embeddings.single_vec_emb
                         if truncate_dim is not None:
                             embeddings = embeddings[:, :truncate_dim]
                     else:
                         embeddings = embeddings.multi_vec_emb
-                    if return_multivector and not return_numpy:
-                        # Get valid token mask from attention_mask
+                    if return_multivector:
                         valid_tokens = attention_mask.bool()
-                        # Remove padding by selecting only valid tokens for each sequence
                         embeddings = [emb[mask] for emb, mask in zip(embeddings, valid_tokens)]
-                        # Stack back into tensor with variable sequence lengths
                         results.append(embeddings)
                     else:
                         results.append(
-                            # If return_numpy is True, move embeddings to CPU for numpy conversion
-                            # Otherwise, unbind the tensor into a list of individual tensors along dim=0
                             embeddings.cpu()
                             if return_numpy
                             else list(torch.unbind(embeddings))
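This hunk is the actual fix. The deleted first condition appears to have routed `return_multivector=True, return_numpy=False` into the single-vector branch, after which the mask indexing below it was applied to 1-D vectors and failed; that is the bug named in the commit title. With the corrected conditions, a multi-vector request always takes the `multi_vec_emb` branch and padding rows are dropped per input using the attention mask. A self-contained sketch of that masking step with dummy tensors (shapes are made up; in the file, the embeddings and mask come from the forward pass):

import torch

batch_size, seq_len, dim = 2, 5, 4
multi_vec_emb = torch.randn(batch_size, seq_len, dim)
attention_mask = torch.tensor([[1, 1, 1, 1, 1],
                               [1, 1, 1, 0, 0]])  # second input has 2 padding tokens

valid_tokens = attention_mask.bool()
# Same comprehension as in the diff: one (n_valid_tokens, dim) tensor per input.
per_item = [emb[mask] for emb, mask in zip(multi_vec_emb, valid_tokens)]
print([tuple(t.shape) for t in per_item])  # [(5, 4), (3, 4)]

Because the per-input lengths differ, the result stays a list of tensors instead of one stacked array, which is exactly why the numpy path is fenced off upstream.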
@@ -450,6 +447,12 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
         )
 
         return_list = isinstance(texts, list)
+
+        # If return_multivector is True and encoding multiple texts, ignore return_numpy
+        if return_multivector and return_list and len(texts) > 1:
+            if return_numpy:
+                print("Warning: `return_numpy` is ignored when `return_multivector=True` and `len(texts) > 1`")
+            return_numpy = False
 
         if isinstance(texts, str):
             texts = [texts]
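For callers, the change is a warned fallback rather than a hard error. A hypothetical invocation of the patched encode_text (the `task` argument is assumed from the `task_label` plumbing above and the model's documented usage; it is not shown in this diff):

from transformers import AutoModel

model = AutoModel.from_pretrained("jinaai/jina-embeddings-v4", trust_remote_code=True)

embs = model.encode_text(
    texts=["first document", "second document"],
    task="retrieval",          # assumed task name, not part of this diff
    return_multivector=True,
    return_numpy=True,         # >1 text: prints the warning, falls back to torch
)
# embs is a list with one (n_tokens, dim) torch.Tensor per text.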
@@ -498,7 +501,7 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
             images: image(s) to encode, can be PIL Image(s), URL(s), or local file path(s)
             batch_size: Number of images to process at once
             return_multivector: Whether to return multi-vector embeddings instead of single-vector embeddings
-            return_numpy: Whether to return numpy arrays instead of torch tensors
+            return_numpy: Whether to return numpy arrays instead of torch tensors. If `return_multivector` is `True` and more than one image is encoded, this parameter is ignored.
             truncate_dim: Dimension to truncate embeddings to (128, 256, 512, or 1024)
             max_pixels: Maximum number of pixels to process per image
 
@@ -515,6 +518,12 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
 
         return_list = isinstance(images, list)
 
+        # If return_multivector is True and encoding multiple images, ignore return_numpy
+        if return_multivector and return_list and len(images) > 1:
+            if return_numpy:
+                print("Warning: `return_numpy` is ignored when `return_multivector=True` and `len(images) > 1`")
+            return_numpy = False
+
         # Convert single image to list
         if isinstance(images, (str, Image.Image)):
             images = [images]
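encode_image gets the identical fallback, and the single-image case still honors `return_numpy`, since one item cannot be ragged. A hypothetical call mirroring the text sketch above (the URL is a placeholder and `task` remains an assumption):

from transformers import AutoModel

model = AutoModel.from_pretrained("jinaai/jina-embeddings-v4", trust_remote_code=True)

emb = model.encode_image(
    images=["https://example.com/page.png"],  # placeholder URL
    task="retrieval",            # assumed, as above
    return_multivector=True,
    return_numpy=True,           # single image: no warning, numpy is honored
)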