| import base64 |
| import io |
| import os |
| from typing import Any, Dict, List |
| from urllib.parse import urlparse |
| from urllib.request import urlopen |
|
|
| import torch |
| from PIL import Image |
| from transformers import AutoModel, AutoProcessor |
|
|
|
|
class EndpointHandler:
    """Inference-endpoint handler for a CLIP-style text/image dual encoder.

    The request payload may contain `text`/`texts` and/or `image`/`images`.
    Text and image inputs are embedded separately; when both are present,
    pairwise image-text similarity scores (and optionally softmax
    probabilities / logit-scaled scores) are returned as well.
    """

    def __init__(self, path: str = ""):
        """Load the model and processor from `path`; move the model to GPU if available."""
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = AutoModel.from_pretrained(path, trust_remote_code=True)
        self.processor = AutoProcessor.from_pretrained(path, trust_remote_code=True)
        self.model.to(self.device)
        self.model.eval()

    def __call__(self, data: "Dict[str, Any]") -> "Dict[str, Any]":
        """Handle one inference request.

        Args:
            data: request dict. Inputs are read from `data["inputs"]` (or from
                `data` itself when that key is absent); options from
                `data["parameters"]` (`return_probs`, `return_logits`).

        Returns:
            Dict with any of `text_embedding`, `image_embedding`, `scores`,
            `probs`, `logits_per_image`, depending on inputs and parameters.

        Raises:
            ValueError: if neither text nor image inputs are provided.
        """
        # Use .get (not .pop) so the caller's dict is never mutated. When
        # `inputs` is absent, the whole payload is treated as the inputs dict,
        # per the HF inference-endpoint convention.
        payload = data.get("inputs", data)
        parameters = data.get("parameters") or {}

        texts = self._coerce_texts(payload)
        images = self._coerce_images(payload)
        if not texts and not images:
            raise ValueError(
                "Expected `inputs` to include `text`/`texts` and/or `image`/`images`."
            )

        result: Dict[str, Any] = {}
        with torch.no_grad():
            text_embeds = None
            image_embeds = None

            if texts:
                # padding=True is required to batch texts of unequal length;
                # without it the tokenizer raises on multi-text requests.
                text_inputs = self.processor(
                    text=texts, padding=True, return_tensors="pt"
                )
                text_inputs = self._move_to_device(text_inputs)
                text_embeds = self.model(**text_inputs).text_embeds
                result["text_embedding"] = text_embeds.cpu().tolist()

            if images:
                image_inputs = self.processor(images=images, return_tensors="pt")
                image_inputs = self._move_to_device(image_inputs)
                image_embeds = self.model(**image_inputs).image_embeds
                result["image_embedding"] = image_embeds.cpu().tolist()

            if text_embeds is not None and image_embeds is not None:
                # NOTE(review): raw dot-product scores; whether these are true
                # cosine similarities depends on the model normalizing its
                # embeddings — confirm for the checkpoint being served.
                scores = image_embeds @ text_embeds.t()
                result["scores"] = scores.cpu().tolist()
                if parameters.get("return_probs", True):
                    result["probs"] = scores.softmax(dim=-1).cpu().tolist()
                if parameters.get("return_logits", False):
                    # Standard CLIP models expose `logit_scale` on the top-level
                    # module; fall back to the nested `.model` path used by some
                    # trust_remote_code wrappers (the original hard-coded the
                    # nested path, which breaks on vanilla CLIP).
                    owner = (
                        self.model
                        if hasattr(self.model, "logit_scale")
                        else self.model.model
                    )
                    logit_scale = owner.logit_scale.exp()
                    result["logits_per_image"] = (
                        (logit_scale * image_embeds @ text_embeds.t()).cpu().tolist()
                    )

        return result

    def _move_to_device(self, batch: "Dict[str, Any]") -> "Dict[str, Any]":
        """Return `batch` with every tensor-like value moved to the handler's device."""
        return {
            key: value.to(self.device) if hasattr(value, "to") else value
            for key, value in batch.items()
        }

    def _coerce_texts(self, payload: Any) -> "List[str]":
        """Normalize the payload's text input(s) into a list of strings.

        Accepts a bare string, or a dict carrying `text`/`texts`; any non-str
        items in a list are stringified. Returns [] when no text is present.
        """
        if isinstance(payload, str):
            return [payload]
        if not isinstance(payload, dict):
            return []

        texts = payload.get("text")
        if texts is None:
            # Fall back to the plural key; checking separately means an
            # explicit `"text": None` no longer shadows a valid `texts`.
            texts = payload.get("texts")
        if texts is None:
            return []
        if isinstance(texts, str):
            return [texts]
        return [str(item) for item in texts]

    def _coerce_images(self, payload: Any) -> "List[Image.Image]":
        """Normalize the payload's image input(s) into a list of RGB PIL images.

        Accepts a dict carrying `image`/`images` with one item or a list of
        items; each item may be anything `_load_image` understands.
        """
        if not isinstance(payload, dict):
            return []

        images = payload.get("image")
        if images is None:
            # Same None-safe fallback as for texts.
            images = payload.get("images")
        if images is None:
            return []
        if not isinstance(images, (list, tuple)):
            images = [images]
        return [self._load_image(item) for item in images]

    def _load_image(self, value: Any) -> "Image.Image":
        """Load one image — PIL image, dict wrapper, data URI, local path,
        http(s) URL, or raw base64 string — converted to RGB.

        Raises:
            TypeError: for unsupported input types.
            ValueError: for strings that are not a URL, path, or base64 data.
        """
        if isinstance(value, Image.Image):
            return value.convert("RGB")

        if isinstance(value, dict):
            # Unwrap common envelope keys; first match wins.
            for key in ("data", "image", "url", "path"):
                if key in value:
                    value = value[key]
                    break

        if not isinstance(value, str):
            raise TypeError(f"Unsupported image input type: {type(value)!r}")

        # Check the data-URI prefix before the filesystem: a data URI is never
        # a valid path and can be arbitrarily long.
        if value.startswith("data:image/"):
            _, encoded = value.split(",", 1)
            return Image.open(io.BytesIO(base64.b64decode(encoded))).convert("RGB")

        if os.path.exists(value):
            return Image.open(value).convert("RGB")

        parsed = urlparse(value)
        if parsed.scheme in ("http", "https"):
            # Bounded timeout so a slow remote host cannot hang the endpoint.
            with urlopen(value, timeout=10) as response:
                return Image.open(io.BytesIO(response.read())).convert("RGB")

        # Last resort: treat the string as raw base64. binascii.Error is a
        # ValueError subclass; UnidentifiedImageError is an OSError subclass.
        try:
            return Image.open(io.BytesIO(base64.b64decode(value))).convert("RGB")
        except (ValueError, OSError) as exc:
            raise ValueError("Unsupported image string. Use URL, local path, or base64.") from exc
|
|