from typing import List, Optional, Union, Any, Dict
from PIL import Image
import torch
from transformers.image_processing_base import BatchFeature
from transformers.image_processing_utils_fast import BaseImageProcessorFast, divide_to_patches
from transformers.image_utils import (make_list_of_images, get_image_size,
                                      get_image_type, ImageInput, ImageType, ChannelDimension)
from transformers.utils import TensorType
import torchvision.transforms as T


class NemotronNanoVLV2ImageProcessor(BaseImageProcessorFast):
    model_input_names = ["pixel_values"]

    def __init__(self, image_size=512, max_num_tiles=12, use_thumbnail=True, norm_mean=None, norm_std=None,
                 do_rescale=True, patch_size=16, downsample_ratio=0.5, **kwargs):
        super().__init__(**kwargs)
        self.image_size = image_size
        self.max_num_tiles = max_num_tiles
        self.use_thumbnail = use_thumbnail
        # No default normalization statistics: _preprocess uses norm_mean/norm_std unconditionally,
        # so they must be supplied (typically from the preprocessor config).
        self.norm_mean = norm_mean
        self.norm_std = norm_std
        self.do_rescale = do_rescale
        # Vision tokens per tile after patchification and pixel-shuffle downsampling,
        # e.g. (512 // 16) ** 2 * 0.5 ** 2 = 256 with the defaults.
        self.num_image_token = int((image_size // patch_size) ** 2 * (downsample_ratio ** 2))

    def _process_image(
        self,
        image: ImageInput,
        **kwargs,
    ) -> torch.Tensor:
        image_type = get_image_type(image)
        if image_type == ImageType.PIL:
            if image.mode != 'RGB':
                image = image.convert('RGB')
            # ToTensor converts the HWC uint8 PIL image to a CHW float tensor scaled to [0, 1].
            image = T.ToTensor()(image)
        return image

    def _preprocess(
        self,
        images: List[torch.Tensor],
        image_size: Optional[int] = None,
        max_num_tiles: Optional[int] = None,
        use_thumbnail: Optional[bool] = None,
        do_rescale: Optional[bool] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        **kwargs,
    ) -> BatchFeature:
        image_size = image_size if image_size is not None else self.image_size
        max_num_tiles = max_num_tiles if max_num_tiles is not None else self.max_num_tiles
        use_thumbnail = use_thumbnail if use_thumbnail is not None else self.use_thumbnail
        do_rescale = do_rescale if do_rescale is not None else self.do_rescale

        images = make_list_of_images(images)

        # Tile every image dynamically and keep track of how many tiles each image produced.
        all_patches = []
        num_patches = []
        for image in images:
            patches = dynamic_preprocess(image, image_size, max_num_tiles, use_thumbnail)
            all_patches.extend(patches)
            num_patches.append(len(patches))

        # Stack all tiles into one (total_tiles, 3, image_size, image_size) batch and normalize.
        # Rescaling to [0, 1] already happened in _process_image for PIL inputs.
        pixel_values = torch.stack(all_patches, dim=0)
        norm_mean = torch.Tensor(self.norm_mean).view(1, 3, 1, 1)
        norm_std = torch.Tensor(self.norm_std).view(1, 3, 1, 1)
        pixel_values = (pixel_values - norm_mean) / norm_std

        return BatchFeature(data={"pixel_values": pixel_values, "num_patches": num_patches}, tensor_type=return_tensors)


def get_internvl_target_ratios(
    min_num: int,
    max_num: int,
) -> list[tuple[int, int]]:
    target_ratios = {(i, j)
                     for n in range(min_num, max_num + 1)
                     for i in range(1, n + 1)
                     for j in range(1, n + 1) if min_num <= i * j <= max_num}
    return sorted(target_ratios, key=lambda x: x[0] * x[1])
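
# Illustrative only (not part of the original logic): the (columns, rows) grids considered for up to
# 6 tiles. Ratios with the same tile count are grouped below; their relative order within a group is
# whatever set iteration yields, since the sort key is only the product.
# get_internvl_target_ratios(1, 6) ->
#   [(1, 1),                          # 1 tile
#    (1, 2), (2, 1),                  # 2 tiles
#    (1, 3), (3, 1),                  # 3 tiles
#    (1, 4), (2, 2), (4, 1),          # 4 tiles
#    (1, 5), (5, 1),                  # 5 tiles
#    (1, 6), (2, 3), (3, 2), (6, 1)]  # 6 tiles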


# From https://github.com/OpenGVLab/InternVL/blob/c62fa4f7c850165d7386bdc48ac6bc5a6fab0864/internvl_chat/internvl/train/dataset.py#L685
# Copyright (c) 2023 OpenGVLab.
def find_closest_aspect_ratio(
    aspect_ratio: float,
    target_ratios: list[tuple[int, int]],
    width: int,
    height: int,
    image_size: int,
) -> tuple[int, int]:
    best_ratio_diff = float("inf")
    best_ratio = (1, 1)
    area = width * height
    for ratio in target_ratios:
        target_aspect_ratio = ratio[0] / ratio[1]
        ratio_diff = abs(aspect_ratio - target_aspect_ratio)
        if ratio_diff < best_ratio_diff:
            best_ratio_diff = ratio_diff
            best_ratio = ratio
        elif ratio_diff == best_ratio_diff:
            # On a tie, prefer the larger grid only if the image has more than half the pixels
            # needed to fill it, so small images are not over-tiled.
            if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
                best_ratio = ratio
    return best_ratio


def calculate_targets(
    orig_width: int,
    orig_height: int,
    target_ratios: list[tuple[int, int]],
    image_size: int,
) -> tuple[int, int, int]:
    aspect_ratio = orig_width / orig_height

    # find the closest aspect ratio to the target
    target_aspect_ratio = find_closest_aspect_ratio(
        aspect_ratio,
        target_ratios,
        width=orig_width,
        height=orig_height,
        image_size=image_size,
    )

    # calculate the target width and height
    target_width = image_size * target_aspect_ratio[0]
    target_height = image_size * target_aspect_ratio[1]
    blocks = target_aspect_ratio[0] * target_aspect_ratio[1]

    return blocks, target_width, target_height
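
# Worked example (illustrative numbers, not from the original file): a portrait image of
# width 800 and height 1280 has aspect_ratio = 0.625. With image_size=512 and the ratios from
# get_internvl_target_ratios(1, 12), the closest grid is (2, 3), so calculate_targets returns
# blocks=6, target_width=1024, target_height=1536, and dynamic_preprocess below cuts the resized
# image into six 512x512 tiles (plus one thumbnail tile when use_thumbnail=True).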


def dynamic_preprocess(image, image_size=512, max_num_tiles=12, use_thumbnail=True):
    orig_height, orig_width = get_image_size(image, channel_dim=ChannelDimension.FIRST)

    target_ratios = get_internvl_target_ratios(1, max_num_tiles)
    blocks, target_width, target_height = calculate_targets(
        orig_width,
        orig_height,
        target_ratios,
        image_size,
    )

    # resize the image to the chosen tiling grid, then cut it into image_size x image_size tiles
    resized_img = T.Resize((target_height, target_width), interpolation=T.InterpolationMode.BICUBIC)(image)
    patches = divide_to_patches(resized_img, image_size)
    assert len(patches) == blocks

    # optionally append a low-resolution view of the whole image as an extra tile
    if use_thumbnail and len(patches) != 1:
        thumbnail_img = T.Resize((image_size, image_size), interpolation=T.InterpolationMode.BICUBIC)(image)
        patches.append(thumbnail_img)
    return patches
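

# Minimal usage sketch (illustrative only). The normalization statistics below are the common
# ImageNet mean/std and are an assumption for this example; in practice they come from the
# model's preprocessor config.
#
#     from PIL import Image
#
#     processor = NemotronNanoVLV2ImageProcessor(
#         image_size=512,
#         max_num_tiles=12,
#         use_thumbnail=True,
#         norm_mean=[0.485, 0.456, 0.406],  # assumed, not defined in this file
#         norm_std=[0.229, 0.224, 0.225],   # assumed, not defined in this file
#     )
#     batch = processor(images=Image.open("example.jpg"), return_tensors="pt")
#     # batch["pixel_values"]: (total_tiles, 3, 512, 512) float tensor
#     # batch["num_patches"]: number of tiles (including the thumbnail) per input image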