import io
from typing import Any, Dict, List

import requests
from PIL import Image
from transformers import pipeline
class EndpointHandler:
    """Inference handler that answers a text prompt about an image fetched from a URL.

    The handler wraps a ``transformers`` ``image-to-text`` pipeline. Each request
    supplies an image URL and a prompt; the image is downloaded, the prompt is
    wrapped in the ``user<image>\\n...\\nassistant:`` template, and the generated
    text is returned with the echoed prompt removed.
    """

    # Generation budget used when the request does not specify ``max_new_tokens``.
    DEFAULT_MAX_NEW_TOKENS = 1000
    # Seconds to wait on the image host; without a timeout ``requests`` may block forever.
    REQUEST_TIMEOUT = 30

    def __init__(self, path=""):
        """Load the ``image-to-text`` pipeline from the model located at ``path``."""
        self.pipe = pipeline("image-to-text", model=path)

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Run one image + prompt request through the pipeline.

        Args:
            data: Request payload. The parameters are read from ``data["inputs"]``
                when that key is present, otherwise from ``data`` itself.
                Recognised fields:

                * ``url`` (str, required): HTTP(S) location of the image.
                * ``prompt`` (str, optional): question or instruction about the
                  image; a missing prompt is treated as an empty string.
                * ``max_new_tokens`` (int, optional): generation budget,
                  defaults to ``DEFAULT_MAX_NEW_TOKENS``.

        Returns:
            The first result dict produced by the pipeline, with the echoed
            prompt removed from its ``generated_text`` entry.

        Raises:
            ValueError: if the inputs are not a dict or ``url`` is missing/empty.
            requests.RequestException: if the image cannot be downloaded
                (connection error, timeout, or non-2xx status).
        """
        # ``get`` instead of ``pop`` so the caller's payload is left untouched.
        inputs = data.get('inputs', data)
        if not isinstance(inputs, dict):
            raise ValueError("'inputs' must be an object with 'url' and 'prompt' fields")

        url = inputs.get('url')
        if not isinstance(url, str) or not url:
            raise ValueError("'url' is required and must be a non-empty string")

        user_prompt = inputs.get('prompt')
        if user_prompt is None:
            # Avoid interpolating the literal text "None" into the template.
            user_prompt = ''
        max_new_tokens = inputs.get('max_new_tokens', self.DEFAULT_MAX_NEW_TOKENS)

        image = self._fetch_image(url)
        prompt = f'user<image>\n{user_prompt}\nassistant:'
        results = self.pipe(image, prompt=prompt, generate_kwargs={"max_new_tokens": max_new_tokens})
        result = results[0]

        # The output is expected to repeat the prompt (without the image
        # placeholder) followed by a space; drop that echo so only the answer
        # remains.
        echoed_prompt = prompt.replace('<image>', '') + ' '
        result['generated_text'] = result['generated_text'].replace(echoed_prompt, '')
        return result

    def _fetch_image(self, url):
        """Download ``url`` and return it as a ``PIL.Image``; raise on HTTP failure."""
        # NOTE(review): ``url`` comes straight from the request payload, so this
        # fetches arbitrary hosts; restrict it if the endpoint is exposed to
        # untrusted callers.
        response = requests.get(url, timeout=self.REQUEST_TIMEOUT)
        # Fail on 4xx/5xx instead of handing an error page to the image decoder.
        response.raise_for_status()
        # ``content`` is fully read and content-decoded, unlike the raw stream.
        return Image.open(io.BytesIO(response.content))