| from typing import List, Optional, Union |
|
|
| import numpy as np |
| import torch |
| from PIL import Image |
| from transformers.feature_extraction_utils import BatchFeature, FeatureExtractionMixin |
| from transformers.image_utils import ImageFeatureExtractionMixin |
|
|
|
|
class M2EncoderImageProcessor(FeatureExtractionMixin, ImageFeatureExtractionMixin):
    """Image processor for the M2Encoder model.

    Converts images (PIL images or array-likes) into square RGB,
    channel-first ``float32`` arrays scaled to ``[0, 1]``, returned as the
    ``pixel_values`` entry of a :class:`~transformers.BatchFeature`.

    NOTE(review): no mean/std normalization is applied here — presumably the
    model normalizes internally; confirm against the model's forward pass.
    """

    model_input_names = ["pixel_values"]

    def __init__(self, size: int = 224, resample: int = Image.BICUBIC, **kwargs):
        """
        Args:
            size: Target side length of the square output. May also be a
                Hugging Face-style size dict containing ``shortest_edge`` or
                ``height``/``width``.
            resample: PIL resampling filter used when resizing.
            **kwargs: Forwarded to the mixin base classes.

        Raises:
            ValueError: If ``size`` is a dict with none of the recognized keys.
        """
        super().__init__(**kwargs)
        if isinstance(size, dict):
            # Accept the common HF size-dict shapes. The original code only
            # looked at height/width and crashed with an opaque TypeError on
            # {"shortest_edge": N}; fail loudly on unrecognized dicts instead.
            edge = (
                size.get("shortest_edge")
                or size.get("height")
                or size.get("width")
            )
            if edge is None:
                raise ValueError(
                    f"size dict must contain 'shortest_edge', 'height' or 'width', got {size!r}"
                )
            size = int(edge)
        self.size = size
        self.resample = resample

    def __call__(
        self,
        images,
        return_tensors: Optional[str] = None,
        **kwargs,
    ) -> BatchFeature:
        """Preprocess one image or a list of images.

        Args:
            images: A single image or a list/tuple of images. Each item is a
                PIL image or an array-like convertible via ``Image.fromarray``
                (assumes HWC uint8 layout for arrays — TODO confirm callers).
            return_tensors: Tensor-type selector passed through to
                ``BatchFeature`` (e.g. ``"pt"``, ``"np"``), or ``None`` to
                keep plain lists.

        Returns:
            A ``BatchFeature`` whose ``pixel_values`` holds one
            ``(3, size, size)`` float32 array in ``[0, 1]`` per input image.
        """
        if not isinstance(images, (list, tuple)):
            images = [images]

        pixel_values: List[np.ndarray] = []
        for image in images:
            if not isinstance(image, Image.Image):
                image = Image.fromarray(np.asarray(image))
            image = image.convert("RGB")
            image = image.resize((self.size, self.size), resample=self.resample)
            # Scale to [0, 1] and move channels first: HWC -> CHW.
            array = np.asarray(image, dtype=np.float32) / 255.0
            array = np.transpose(array, (2, 0, 1))
            pixel_values.append(array)

        return BatchFeature(
            data={"pixel_values": pixel_values},
            tensor_type=return_tensors,
        )
|
|