from typing import List, Optional, Union

import numpy as np
import torch
from PIL import Image
from transformers.feature_extraction_utils import BatchFeature, FeatureExtractionMixin
from transformers.image_utils import ImageFeatureExtractionMixin


class M2EncoderImageProcessor(FeatureExtractionMixin, ImageFeatureExtractionMixin):
    """Turn images into square, channel-first float arrays under ``pixel_values``.

    Each input is converted to RGB, resized to ``size`` x ``size`` (the aspect
    ratio is not preserved), scaled from 0-255 to 0-1 as float32 and transposed
    to channel-first order. No mean/std normalization is applied here.
    """

    model_input_names = ["pixel_values"]

    # Keys looked up, in priority order, when ``size`` is given as a dict.
    _SIZE_KEYS = ("height", "width", "shortest_edge")

    def __init__(
        self,
        size: Union[int, dict] = 224,
        resample: int = Image.BICUBIC,
        **kwargs,
    ):
        """Store the resize configuration.

        Args:
            size: Side length of the square output. A dict is also accepted;
                the first truthy value among ``height``, ``width`` and
                ``shortest_edge`` is used as the side length.
            resample: PIL resampling filter passed to ``Image.resize``.
            **kwargs: Forwarded unchanged to the parent initializer.

        Raises:
            ValueError: If ``size`` is a dict with none of the supported keys.
        """
        super().__init__(**kwargs)
        if isinstance(size, dict):
            side = next(
                (size[key] for key in self._SIZE_KEYS if size.get(key)),
                None,
            )
            if side is None:
                # Previously this path reached int(None) and surfaced as an
                # unhelpful TypeError.
                raise ValueError(
                    "`size` dict must define 'height', 'width' or "
                    f"'shortest_edge'; got {size!r}"
                )
            size = int(side)
        self.size = size
        self.resample = resample

    def _to_pixel_array(self, image) -> np.ndarray:
        """Convert one image to a channel-first float32 array scaled to 0-1."""
        if not isinstance(image, Image.Image):
            # Anything that is not already a PIL image goes through NumPy.
            image = Image.fromarray(np.asarray(image))
        image = image.convert("RGB")
        image = image.resize((self.size, self.size), resample=self.resample)
        array = np.asarray(image, dtype=np.float32) / 255.0
        # The RGB image yields a height x width x channel array; move the
        # channel axis to the front.
        return np.transpose(array, (2, 0, 1))

    def __call__(
        self,
        images,
        return_tensors: Optional[Union[str, torch.Tensor]] = None,
        **kwargs,
    ) -> BatchFeature:
        """Preprocess one image or a list/tuple of images.

        Args:
            images: A single image or a list/tuple of images. Items that are
                not PIL images are converted with ``np.asarray`` followed by
                ``Image.fromarray``.
            return_tensors: Passed to ``BatchFeature`` as ``tensor_type``.
            **kwargs: Accepted for call compatibility; not used here.

        Returns:
            A ``BatchFeature`` whose ``pixel_values`` entry holds one
            ``(3, size, size)`` float32 array per input image.
        """
        if not isinstance(images, (list, tuple)):
            images = [images]

        pixel_values: List[np.ndarray] = [
            self._to_pixel_array(image) for image in images
        ]
        return BatchFeature(
            data={"pixel_values": pixel_values},
            tensor_type=return_tensors,
        )