import base64
import io
import os
from typing import Any, Dict, List
from urllib.parse import urlparse
from urllib.request import urlopen

import torch
from PIL import Image
from transformers import AutoModel, AutoProcessor


class EndpointHandler:
    """Callable handler that embeds texts and/or images with a pretrained model.

    The model and processor are loaded once at construction time from ``path``
    and reused for every call. Calling the instance with a request dict returns
    the embeddings and, when both modalities are present, image-by-text
    similarity scores.
    """

    # Upper bound (seconds) on remote image downloads so that a stalled host
    # cannot block the worker forever. Override on a subclass or instance.
    URL_TIMEOUT_SECONDS = 30.0

    def __init__(self, path: str = ""):
        """Load the model and processor from ``path`` and move the model to the device.

        Args:
            path: Location passed to ``AutoModel.from_pretrained`` and
                ``AutoProcessor.from_pretrained``.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # NOTE(review): trust_remote_code=True executes Python code shipped with
        # the checkpoint; only point `path` at repositories you trust.
        self.model = AutoModel.from_pretrained(path, trust_remote_code=True)
        self.processor = AutoProcessor.from_pretrained(path, trust_remote_code=True)
        self.model.to(self.device)
        self.model.eval()

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Embed the texts and/or images found in ``data``.

        Args:
            data: Request dict. ``data["inputs"]`` (or ``data`` itself when that
                key is absent) is either a plain string (one text) or a dict
                with ``text``/``texts`` and/or ``image``/``images``.
                ``data["parameters"]`` may carry ``return_probs`` (default
                ``True``) and ``return_logits`` (default ``False``).
                Both keys are popped from ``data``.

        Returns:
            A dict with, depending on the inputs: ``text_embedding``,
            ``image_embedding``, and - when both are present - ``scores``
            (image-by-text matrix), optionally ``probs`` and
            ``logits_per_image``. All values are nested Python lists.

        Raises:
            ValueError: If neither texts nor images could be extracted.
        """
        # When "inputs" is missing the whole request dict is the payload.
        payload = data.pop("inputs", data)
        # `or {}` also covers an explicit `"parameters": null`.
        parameters = data.pop("parameters", {}) or {}

        texts = self._coerce_texts(payload)
        images = self._coerce_images(payload)
        if not texts and not images:
            raise ValueError(
                "Expected `inputs` to include `text`/`texts` and/or `image`/`images`."
            )

        result: Dict[str, Any] = {}
        with torch.no_grad():
            text_embeds = None
            image_embeds = None

            if texts:
                text_inputs = self.processor(text=texts, return_tensors="pt")
                text_inputs = self._move_to_device(text_inputs)
                text_embeds = self.model(**text_inputs).text_embeds
                result["text_embedding"] = text_embeds.cpu().tolist()

            if images:
                image_inputs = self.processor(images=images, return_tensors="pt")
                image_inputs = self._move_to_device(image_inputs)
                image_embeds = self.model(**image_inputs).image_embeds
                result["image_embedding"] = image_embeds.cpu().tolist()

            if text_embeds is not None and image_embeds is not None:
                scores = image_embeds @ text_embeds.t()
                result["scores"] = scores.cpu().tolist()
                if parameters.get("return_probs", True):
                    # NOTE(review): softmax is taken over the unscaled scores,
                    # not over the logit_scale-multiplied logits below -
                    # confirm this is the intended definition of `probs`.
                    result["probs"] = scores.softmax(dim=-1).cpu().tolist()
                if parameters.get("return_logits", False):
                    logit_scale = self.model.model.logit_scale.exp()
                    result["logits_per_image"] = (
                        (logit_scale * image_embeds @ text_embeds.t()).cpu().tolist()
                    )

        return result

    def _move_to_device(self, batch: Dict[str, Any]) -> Dict[str, Any]:
        """Return a plain dict with every value that has ``.to`` moved to ``self.device``."""
        return {
            key: value.to(self.device) if hasattr(value, "to") else value
            for key, value in batch.items()
        }

    def _coerce_texts(self, payload: Any) -> List[str]:
        """Extract a list of text strings from ``payload``.

        A bare string is one text. A dict is searched for ``text`` and then
        ``texts``; any other payload type yields an empty list.
        """
        if isinstance(payload, str):
            return [payload]
        if not isinstance(payload, dict):
            return []

        texts = payload.get("text")
        if texts is None:
            # Fall through on an explicit null so `{"text": null, "texts": [...]}`
            # still uses the plural key.
            texts = payload.get("texts")
        if texts is None:
            return []
        if isinstance(texts, str):
            return [texts]
        return [str(item) for item in texts]

    def _coerce_images(self, payload: Any) -> List[Image.Image]:
        """Extract and load the images referenced by a dict ``payload``.

        Looks at ``image`` and then ``images``; a single value is treated as a
        one-element list. Non-dict payloads yield an empty list.
        """
        if not isinstance(payload, dict):
            return []

        images = payload.get("image")
        if images is None:
            # Same null fall-through as for texts.
            images = payload.get("images")
        if images is None:
            return []
        if not isinstance(images, (list, tuple)):
            images = [images]
        return [self._load_image(item) for item in images]

    @staticmethod
    def _open_rgb(source: Any) -> Image.Image:
        """Open ``source`` (path or file object) and return an RGB copy.

        The source image is closed on exit, so an image opened from a path does
        not keep its file handle alive.
        """
        with Image.open(source) as image:
            return image.convert("RGB")

    def _load_image(self, value: Any) -> Image.Image:
        """Turn one image reference into an RGB ``PIL.Image``.

        Accepted forms: a ``PIL.Image``; a string that is an existing local
        path, an http(s) URL, a ``data:image/...`` URI, or raw base64; or a
        dict wrapping any of those under ``data``, ``image``, ``url`` or
        ``path`` (first match in that order).

        Raises:
            TypeError: If the (unwrapped) value is neither an image nor a string.
            ValueError: If a string cannot be interpreted as an image.
        """
        # Unwrap first so a dict may carry a PIL image as well as a string.
        if isinstance(value, dict):
            for key in ("data", "image", "url", "path"):
                if key in value:
                    value = value[key]
                    break

        if isinstance(value, Image.Image):
            return value.convert("RGB")
        if not isinstance(value, str):
            raise TypeError(f"Unsupported image input type: {type(value)!r}")

        # SECURITY NOTE(review): `value` may come straight from a client. The
        # local-path branch lets a caller read any image file this process can
        # access, and the URL branch fetches arbitrary hosts (SSRF). Restrict
        # or disable these branches if the endpoint is exposed to untrusted
        # users.
        if os.path.exists(value):
            return self._open_rgb(value)

        if urlparse(value).scheme in ("http", "https"):
            with urlopen(value, timeout=self.URL_TIMEOUT_SECONDS) as response:
                return self._open_rgb(io.BytesIO(response.read()))

        if value.startswith("data:image/"):
            _, separator, encoded = value.partition(",")
            if not separator:
                raise ValueError(
                    "Malformed data URI: missing ',' before the base64 payload."
                )
            return self._open_rgb(io.BytesIO(base64.b64decode(encoded)))

        # Last resort: treat the string as bare base64-encoded image bytes.
        try:
            return self._open_rgb(io.BytesIO(base64.b64decode(value)))
        except Exception as exc:
            raise ValueError(
                "Unsupported image string. Use URL, local path, or base64."
            ) from exc