import warnings

import torch
from torch import nn
import torchvision.transforms as transforms
from torchvision.transforms import InterpolationMode
from transformers import PreTrainedModel
from transformers.utils import logging

warnings.filterwarnings("ignore")  # silence import-time warnings from model code

from .config import InternVideo2Config
from .internvideo2_clip_vision import InternVideo2
from .mobile_clip import TextTransformer, ClipTokenizer

logger = logging.get_logger(__name__)
|
|
class InternVideo2_CLIP_small(PreTrainedModel):
    config_class = InternVideo2Config
|
|
    def __init__(self, config, tokenizer=None, is_pretrain=True):
        super().__init__(config)
        self.config = config
        self.tokenizer = tokenizer
        self.is_pretrain = is_pretrain
        logger.debug(config)
        if tokenizer is None:
            self.tokenizer = ClipTokenizer(self.config.model.text_encoder)

        self.vision_encoder = self.build_vision_encoder()
|
|
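        # Align pooled vision features with the text embedding space:
        # LayerNorm followed by a linear projection clip_embed_dim -> align_dim.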
        self.vision_align = nn.Sequential(
            nn.LayerNorm(self.config.model.vision_encoder.clip_embed_dim),
            nn.Linear(
                self.config.model.vision_encoder.clip_embed_dim,
                self.config.model.vision_encoder.align_dim,
            ),
        )
        self.text_encoder = self.build_text_encoder(
            cfg=self.config.model.text_encoder["text_cfg"],
            projection_dim=self.config.model.text_encoder["embed_dim"],
        )

        # Learnable temperature for the contrastive loss, clamped from below by temp_min.
        self.temp = nn.parameter.Parameter(torch.ones([]) * config.model.temp)
        self.temp_min = config.model.temp_min
|
|
        if self.config.model.freeze_vision:
            for name, p in self.vision_encoder.named_parameters():
                if self.config.model.open_vision_clip_projector and name.startswith("clip_projector"):
                    logger.info(f"Unfreeze {name}")
                else:
                    logger.info(f"Freeze {name}")
                    p.requires_grad = False
        if self.config.model.freeze_text:
            for name, p in self.text_encoder.named_parameters():
                if self.config.model.open_text_projection and name.startswith("projection_layer"):
                    logger.info(f"Unfreeze {name}")
                else:
                    logger.info(f"Freeze {name}")
                    p.requires_grad = False
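
        # Preprocessing for raw uint8 frames: resize to the model's input
        # resolution, rescale to [0, 1], then normalize with ImageNet mean/std.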
        img_size = self.config.model.vision_encoder.img_size
        self.transform = transforms.Compose(
            [
                transforms.Resize(
                    (img_size, img_size),
                    interpolation=InterpolationMode.BICUBIC,
                ),
                transforms.Lambda(lambda x: x.float().div(255.0)),
                transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
            ]
        )
|
|
|
|
    @torch.no_grad()
    def clip_contrastive_temperature(self):
        """Clamp the learnable temperature from below; only used during pre-training."""
        self.temp.clamp_(min=self.temp_min)
|
|
    def encode_vision(self, image, test=False):
        """Encode images / videos as features.

        Args:
            image (torch.Tensor): The input images / videos. Shape: [B,T,C,H,W].
            test (bool): Whether testing.

        Returns:
            torch.Tensor: vision_embeds, the aligned vision features. Shape: [B,C].
        """
        T = image.shape[1]
        use_image = (T == 1)  # single-frame inputs are treated as images
        image = image.permute(0, 2, 1, 3, 4)  # [B,T,C,H,W] -> [B,C,T,H,W]

        vision_embeds = self.vision_encoder(image, use_image=use_image)
        vision_embeds = self.vision_align(vision_embeds)
        return vision_embeds
|
|
    def encode_text(self, text):
        """Encode text.

        Args:
            text (dict): The output of huggingface's `PreTrainedTokenizer`. Contains keys:
                - input_ids (torch.Tensor): Token ids to be fed to the model. Shape: [B,L].
                - attention_mask (torch.Tensor): Mask indicating padded tokens (0 = padded). Shape: [B,L].
                - other keys: see https://huggingface.co/docs/transformers/v4.21.2/en/main_classes/tokenizer#transformers.PreTrainedTokenizer.__call__

        Returns:
            torch.Tensor: text_embeds, the projected text features. Shape: [B,C].
        """
        text_embeds = self.text_encoder(text)
        return text_embeds
|
|
    def build_vision_encoder(self):
        """Build the vision encoder.

        Returns:
            nn.Module: The InternVideo2 vision encoder.
        """
        vision_encoder = InternVideo2(
            in_chans=self.config.model.vision_encoder.in_chans,
            patch_size=self.config.model.vision_encoder.patch_size,
            img_size=self.config.model.vision_encoder.img_size,
            qkv_bias=self.config.model.vision_encoder.qkv_bias,
            drop_path_rate=self.config.model.vision_encoder.drop_path_rate,
            head_drop_path_rate=self.config.model.vision_encoder.head_drop_path_rate,
            embed_dim=self.config.model.vision_encoder.embed_dim,
            num_heads=self.config.model.vision_encoder.num_heads,
            mlp_ratio=self.config.model.vision_encoder.mlp_ratio,
            init_values=self.config.model.vision_encoder.init_values,
            qk_normalization=self.config.model.vision_encoder.qk_normalization,
            depth=self.config.model.vision_encoder.depth,
            use_flash_attn=self.config.model.vision_encoder.use_flash_attn,
            use_fused_rmsnorm=self.config.model.vision_encoder.use_fused_rmsnorm,
            use_fused_mlp=self.config.model.vision_encoder.use_fused_mlp,
            fused_mlp_heuristic=self.config.model.vision_encoder.fused_mlp_heuristic,
            attn_pool_num_heads=self.config.model.vision_encoder.attn_pool_num_heads,
            clip_embed_dim=self.config.model.vision_encoder.clip_embed_dim,
            layerscale_no_force_fp32=self.config.model.vision_encoder.layerscale_no_force_fp32,
            num_frames=self.config.model.vision_encoder.num_frames,
            tubelet_size=self.config.model.vision_encoder.tubelet_size,
            sep_pos_embed=self.config.model.vision_encoder.sep_pos_embed,
            use_checkpoint=self.config.model.vision_encoder.use_checkpoint,
            checkpoint_num=self.config.model.vision_encoder.checkpoint_num,
        )
        return vision_encoder
|
|
    def build_text_encoder(self, cfg, projection_dim):
        """Build the text encoder (and possibly a video-to-text multimodal fusion encoder).

        Returns:
            nn.Module: The text encoder.
        """
        text_encoder = TextTransformer(cfg, projection_dim)
        return text_encoder
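
# Hedged sketch (not part of the original module): score a batch of videos
# against a batch of texts by cosine similarity, divided by the learned
# temperature, following the common CLIP convention. `videos` is a
# [B,T,C,H,W] tensor; `text_input` is whatever the tokenizer produced
# for `encode_text`.
@torch.no_grad()
def video_text_similarity(model, videos, text_input):
    v = model.encode_vision(videos)       # [num_videos, C]
    t = model.encode_text(text_input)     # [num_texts, C]
    v = v / v.norm(dim=-1, keepdim=True)  # L2-normalize video features
    t = t / t.norm(dim=-1, keepdim=True)  # L2-normalize text features
    return v @ t.T / model.temp           # [num_videos, num_texts] logits
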
if __name__ == "__main__":
    model_config = InternVideo2Config()
    model = InternVideo2_CLIP_small(model_config)
    # encode_vision expects videos shaped [B, T, C, H, W]
    x = torch.randn(2, 8, 3, 224, 224).to(model_config.device)
    output = model.encode_vision(x)
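
    # Hedged continuation (assumes the tokenizer is callable on a list of raw
    # strings and returns the input expected by encode_text):
    texts = ["a person riding a bike", "a dog playing with a ball"]
    text_input = model.tokenizer(texts)
    sim = video_text_similarity(model, x, text_input)
    print(sim.softmax(dim=-1))  # per-video distribution over the two captions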