File size: 4,548 Bytes
e53cbe6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
import base64
import io
import os
from typing import Any, Dict, List
from urllib.parse import urlparse
from urllib.request import urlopen

import torch
from PIL import Image
from transformers import AutoModel, AutoProcessor


class EndpointHandler:
    """Inference Endpoints handler for a CLIP-style dual encoder.

    Accepts text and/or image inputs and returns L2-comparable embeddings;
    when both modalities are present it also returns image-text similarity
    scores (and optionally softmax probabilities / scaled logits).
    """

    # Network timeout (seconds) for fetching images by URL; without it a
    # slow or unresponsive host would hang the whole endpoint worker.
    _URL_TIMEOUT = 10.0

    def __init__(self, path: str = ""):
        """Load the model and processor from `path` and move to GPU if available."""
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = AutoModel.from_pretrained(path, trust_remote_code=True)
        self.processor = AutoProcessor.from_pretrained(path, trust_remote_code=True)
        self.model.to(self.device)
        self.model.eval()

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Handle one inference request.

        Expected payload shape::

            {"inputs": {"text"/"texts": ..., "image"/"images": ...},
             "parameters": {"return_probs": bool, "return_logits": bool}}

        Returns:
            Dict with any of ``text_embedding``, ``image_embedding``,
            ``scores``, ``probs``, ``logits_per_image`` (all plain lists).

        Raises:
            ValueError: if neither text nor image inputs are supplied.
        """
        # Use .get (not .pop) so the caller's request dict is not mutated.
        payload = data.get("inputs", data)
        parameters = data.get("parameters") or {}

        texts = self._coerce_texts(payload)
        images = self._coerce_images(payload)
        if not texts and not images:
            raise ValueError(
                "Expected `inputs` to include `text`/`texts` and/or `image`/`images`."
            )

        result: Dict[str, Any] = {}
        with torch.no_grad():
            text_embeds = None
            image_embeds = None

            if texts:
                text_inputs = self.processor(text=texts, return_tensors="pt")
                text_inputs = self._move_to_device(text_inputs)
                text_embeds = self.model(**text_inputs).text_embeds
                result["text_embedding"] = text_embeds.cpu().tolist()

            if images:
                image_inputs = self.processor(images=images, return_tensors="pt")
                image_inputs = self._move_to_device(image_inputs)
                image_embeds = self.model(**image_inputs).image_embeds
                result["image_embedding"] = image_embeds.cpu().tolist()

            if text_embeds is not None and image_embeds is not None:
                scores = image_embeds @ text_embeds.t()
                result["scores"] = scores.cpu().tolist()
                if parameters.get("return_probs", True):
                    result["probs"] = scores.softmax(dim=-1).cpu().tolist()
                if parameters.get("return_logits", False):
                    # Standard CLIP models expose `logit_scale` directly on the
                    # model; some trust_remote_code wrappers nest it one level
                    # down — check both so neither layout raises AttributeError.
                    scale_owner = (
                        self.model
                        if hasattr(self.model, "logit_scale")
                        else self.model.model
                    )
                    logit_scale = scale_owner.logit_scale.exp()
                    # Reuse `scores` instead of recomputing the matmul.
                    result["logits_per_image"] = (logit_scale * scores).cpu().tolist()

        return result

    def _move_to_device(self, batch: Dict[str, Any]) -> Dict[str, Any]:
        """Move every tensor-like value in `batch` to the handler's device."""
        moved = {}
        for key, value in batch.items():
            # Non-tensor entries (e.g. lists of lengths) are passed through.
            moved[key] = value.to(self.device) if hasattr(value, "to") else value
        return moved

    def _coerce_texts(self, payload: Any) -> List[str]:
        """Normalize the request payload into a list of strings (possibly empty)."""
        if isinstance(payload, str):
            # Bare-string payload: treat it as a single text query.
            return [payload]
        if not isinstance(payload, dict):
            return []

        texts = payload.get("text", payload.get("texts"))
        if texts is None:
            return []
        if isinstance(texts, str):
            return [texts]
        return [str(item) for item in texts]

    def _coerce_images(self, payload: Any) -> List[Image.Image]:
        """Normalize the request payload into a list of PIL images (possibly empty)."""
        if not isinstance(payload, dict):
            return []

        images = payload.get("image", payload.get("images"))
        if images is None:
            return []
        if not isinstance(images, (list, tuple)):
            images = [images]
        return [self._load_image(item) for item in images]

    def _load_image(self, value: Any) -> Image.Image:
        """Load one image from a PIL image, dict, path, URL, data URI, or base64 string.

        Raises:
            TypeError: if `value` is not a supported container type.
            ValueError: if a string value cannot be decoded as an image source.
        """
        if isinstance(value, Image.Image):
            return value.convert("RGB")

        if isinstance(value, dict):
            # Unwrap common client envelope shapes, first matching key wins.
            for key in ("data", "image", "url", "path"):
                if key in value:
                    value = value[key]
                    break

        if not isinstance(value, str):
            raise TypeError(f"Unsupported image input type: {type(value)!r}")

        if os.path.exists(value):
            return Image.open(value).convert("RGB")

        parsed = urlparse(value)
        if parsed.scheme in ("http", "https"):
            # Bounded timeout so a slow remote host cannot hang the endpoint.
            with urlopen(value, timeout=self._URL_TIMEOUT) as response:
                return Image.open(io.BytesIO(response.read())).convert("RGB")

        if value.startswith("data:image/"):
            _, encoded = value.split(",", 1)
            return Image.open(io.BytesIO(base64.b64decode(encoded))).convert("RGB")

        # Last resort: treat the string as raw base64. validate=True rejects
        # non-base64 junk up front instead of producing garbage bytes.
        try:
            return Image.open(io.BytesIO(base64.b64decode(value, validate=True))).convert("RGB")
        except Exception as exc:
            raise ValueError("Unsupported image string. Use URL, local path, or base64.") from exc