| import tensorflow as tf |
| from tensorflow.keras.layers import Dense,Conv2d,BatchNormalization,LayerNormalization,MultiHeadAttention |
| from tensorflow.keras.layers import ZeroPadding2D,AveragePooling2D,Identity |
| from tensorflow.keras import Model |
| import numpy as np |
| from typing import Tuple, Union |
|
|
|
|
class Bottleneck(tf.keras.layers.Layer):
    """Anti-aliased ResNet bottleneck block (CLIP variant).

    Strided downsampling is done with average pooling rather than strided
    convolution, both on the main path and on the residual shortcut.
    """

    expansion = 4

    def __init__(self, inplanes, planes, stride=1):
        super(Bottleneck, self).__init__()

        # 1x1 channel reduction
        self.conv1 = Conv2d(planes, 1, use_bias=False)
        self.bn1 = BatchNormalization()
        self.relu1 = tf.nn.relu

        # 3x3 spatial conv; padding applied explicitly beforehand
        self.zeropadding2d = ZeroPadding2D(padding=1)
        self.conv2 = Conv2d(planes, 3, use_bias=False)
        self.bn2 = BatchNormalization()
        self.relu2 = tf.nn.relu

        # average pool stands in for a strided conv (anti-aliasing)
        self.avgpool = Identity() if stride <= 1 else AveragePooling2D(stride, stride, 'VALID')

        # 1x1 channel expansion
        self.conv3 = Conv2d(planes * self.expansion, 1, use_bias=False)
        self.bn3 = BatchNormalization()
        self.relu3 = tf.nn.relu

        self.stride = stride
        self.downsample = None
        if stride > 1 or inplanes != planes * Bottleneck.expansion:
            # shortcut path: avgpool -> 1x1 conv -> BN
            shortcut = tf.keras.Sequential()
            shortcut.add(AveragePooling2D(stride, stride, 'VALID'))
            shortcut.add(Conv2d(planes * self.expansion, 1, strides=1, use_bias=False))
            shortcut.add(BatchNormalization())
            self.downsample = shortcut

    def __call__(self, x):
        identity = x

        out = self.relu1(self.bn1(self.conv1(x)))
        out = self.zeropadding2d(out)
        out = self.relu2(self.bn2(self.conv2(out)))
        out = self.avgpool(out)
        out = self.bn3(self.conv3(out))

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        return self.relu3(out)
|
|
|
|
class AttentionPool2d(tf.keras.layers.Layer):
    """QKV attention pooling over a 2D feature map.

    Flattens the spatial grid, prepends the mean over positions as a query
    token, adds a learned positional embedding, and returns one pooled
    vector per image via single-query multi-head attention.
    """

    def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
        # Fixes vs. previous revision: super().__init__() was never called,
        # and self.spacial_dim / self.embed_dim were read before being
        # assigned, so construction crashed with AttributeError.
        super(AttentionPool2d, self).__init__()
        self.spacial_dim = spacial_dim
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.positional_embedding = self.add_weight(
            name='positional_embedding',
            shape=[spacial_dim ** 2 + 1, embed_dim],  # +1 for the mean token
            initializer=tf.keras.initializers.RandomNormal(mean=0., stddev=1. / embed_dim ** 0.5),
            trainable=True
        )
        self.k_proj = Dense(embed_dim)
        self.q_proj = Dense(embed_dim)
        self.v_proj = Dense(embed_dim)
        self.c_proj = Dense(output_dim or embed_dim)

    def __call__(self, x):
        # x: (batch, height, width, channels) channels-last feature map.
        shape = x.shape
        batch_size, height, width, channels = shape[0], shape[1], shape[2], shape[3]
        # (B, H, W, C) -> (H*W, B, C): sequence-first layout used below.
        x = tf.transpose(tf.reshape(x, (batch_size, height * width, channels)), (1, 0, 2))
        # Prepend the mean over all positions; it acts as the pooling query.
        x = tf.concat([tf.reduce_mean(x, axis=0, keepdims=True), x], axis=0)
        x = x + tf.cast(self.positional_embedding[:, None, :], x.dtype)
        tgt_len, bsz, embed_dim = x.shape
        query = self.q_proj(x[:1])  # only the mean token queries
        key = self.k_proj(x)
        value = self.v_proj(x)
        head_dim = embed_dim // self.num_heads
        query = tf.reshape(query, [bsz, 1, self.num_heads, -1])
        query = tf.transpose(query, [0, 2, 1, 3])
        # Scaled dot-product attention divides by sqrt(d_k) of a single
        # head, not of the full embedding (previous revision used embed_dim).
        query = tf.multiply(query, 1.0 / tf.math.sqrt(float(head_dim)))
        key = tf.reshape(key, [bsz, tgt_len, self.num_heads, -1])
        key = tf.transpose(key, [0, 2, 3, 1])
        value = tf.reshape(value, [bsz, tgt_len, self.num_heads, -1])
        value = tf.transpose(value, [0, 2, 1, 3])
        w = tf.nn.softmax(tf.matmul(query, key))          # (B, heads, 1, T)
        wv = tf.reshape(tf.transpose(tf.matmul(w, value), [0, 2, 1, 3]), [1, bsz, -1])
        x = self.c_proj(wv)
        return tf.squeeze(x, 0)  # (batch, output_dim or embed_dim)
|
|
|
|
class ModifiedResNet:
    """
    A ResNet class that is similar to torchvision's but contains the following changes:
    - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
    - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
    - The final pooling layer is a QKV attention instead of an average pool
    """

    def __init__(self, layers, output_dim, heads, input_resolution=224, width=64):
        self.output_dim = output_dim
        self.input_resolution = input_resolution

        # three-conv stem; explicit zero padding is shared by all three convs
        self.zeropadding2d = ZeroPadding2D(padding=1)
        self.conv1 = Conv2d(width // 2, kernel_size=3, strides=2, use_bias=False)
        self.bn1 = BatchNormalization()
        self.relu1 = tf.nn.relu
        self.conv2 = Conv2d(width // 2, kernel_size=3, use_bias=False)
        self.bn2 = BatchNormalization()
        self.relu2 = tf.nn.relu
        self.conv3 = Conv2d(width, kernel_size=3, use_bias=False)
        self.bn3 = BatchNormalization()
        self.relu3 = tf.nn.relu
        self.avgpool = AveragePooling2D(2, 2, 'VALID')

        # four bottleneck stages; _inplanes tracks the running channel count
        self._inplanes = width
        self.layer1 = self._make_layer(width, layers[0])
        self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
        self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
        self.layer4 = self._make_layer(width * 8, layers[3], stride=2)

        embed_dim = width * 32  # ResNet feature dimension after stage 4
        self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim)

    def _make_layer(self, planes, blocks, stride=1):
        # First block carries the stride; the rest run at stride 1.
        stage = tf.keras.Sequential()
        stage.add(Bottleneck(self._inplanes, planes, stride))

        self._inplanes = planes * Bottleneck.expansion
        for _ in range(blocks - 1):
            stage.add(Bottleneck(self._inplanes, planes))

        return stage

    def __call__(self, x):
        # Stem: three pad->conv->BN->ReLU triples, then average pooling.
        for conv, bn, relu in ((self.conv1, self.bn1, self.relu1),
                               (self.conv2, self.bn2, self.relu2),
                               (self.conv3, self.bn3, self.relu3)):
            x = relu(bn(conv(self.zeropadding2d(x))))
        x = self.avgpool(x)

        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            x = stage(x)

        return self.attnpool(x)
|
|
|
|
class LayerNorm:
    """LayerNorm that runs in fp32 to handle fp16 inputs safely.

    Mirrors torch.nn.LayerNorm, hence epsilon=1e-5 explicitly: Keras'
    LayerNormalization defaults to epsilon=1e-3, which would not match the
    torch reference this class says it stands in for.
    """

    def __init__(self, input_size):
        # input_size is kept for API symmetry with torch's LayerNorm;
        # Keras infers the normalized-axis size from the input shape.
        self.layer_norm = LayerNormalization(epsilon=1e-5)

    def __call__(self, x):
        # Compute in fp32, cast back to the caller's dtype.
        orig_type = x.dtype
        ret = self.layer_norm(tf.cast(x, tf.float32))
        return tf.cast(ret, orig_type)
|
|
|
|
class QuickGELU(tf.keras.layers.Layer):
    """Sigmoid-based GELU approximation: x * sigmoid(1.702 * x)."""

    def __init__(self):
        super(QuickGELU, self).__init__()

    def __call__(self, x):
        gate = tf.nn.sigmoid(1.702 * x)
        return x * gate
|
|
|
|
class ResidualAttentionBlock(tf.keras.layers.Layer):
    """Pre-norm transformer block: x + MHA(LN(x)), then x + MLP(LN(x))."""

    def __init__(self, d_model: int, n_head: int, attn_mask = None):
        super(ResidualAttentionBlock, self).__init__()
        # key_dim is the PER-HEAD width in Keras. Passing d_model (as the
        # previous revision did) would give every head the full model width
        # (n_head * d_model projection parameters) instead of splitting
        # d_model across heads.
        self.attn = MultiHeadAttention(n_head, d_model // n_head)
        self.ln_1 = LayerNorm(d_model)
        self.mlp = tf.keras.Sequential()
        self.mlp.add(Dense(d_model * 4))
        self.mlp.add(QuickGELU())
        self.mlp.add(Dense(d_model))
        self.ln_2 = LayerNorm(d_model)
        self.attn_mask = attn_mask

    def attention(self, x):
        # Cast into a local instead of rebinding self.attn_mask every call.
        mask = tf.cast(self.attn_mask, x.dtype) if self.attn_mask is not None else None
        # Keras MultiHeadAttention returns the output tensor directly (not
        # an (output, weights) tuple as torch does), so no [0] indexing: the
        # previous [0] silently kept only the first batch element.
        return self.attn(x, x, attention_mask=mask)

    def __call__(self, x):
        x = x + self.attention(self.ln_1(x))
        x = x + self.mlp(self.ln_2(x))
        return x
|
|
|
|
class Transformer:
    """A stack of `layers` ResidualAttentionBlocks, each `width` wide."""

    def __init__(self, width: int, layers: int, heads: int, attn_mask = None):
        self.width = width
        self.layers = layers
        blocks = [ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)]
        self.resblocks = tf.keras.Sequential(blocks)

    def __call__(self, x):
        return self.resblocks(x)
|
|
|
|
class VisionTransformer(tf.keras.layers.Layer):
    """ViT image encoder: patchify -> [CLS] + positional embedding -> transformer -> projection.

    Fixes vs. previous revision: super().__init__() was never called;
    self.width / self.scale / self.patch_size were read before assignment;
    a RandomNormal initializer object was multiplied by a float (TypeError);
    tf.random.normal was called with two positional dims instead of a shape
    list; and the patch flattening reshaped the wrong axes for channels-last
    data.
    """

    def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int):
        super(VisionTransformer, self).__init__()
        self.input_resolution = input_resolution
        self.patch_size = patch_size
        self.width = width
        self.output_dim = output_dim
        self.conv1 = Conv2d(width, kernel_size=patch_size, strides=patch_size, use_bias=False)

        scale = width ** -0.5
        # Fold the scale factor into the stddev of the initializer.
        self.class_embedding = self.add_weight(
            name='class_embedding',
            shape=[width],
            initializer=tf.keras.initializers.RandomNormal(mean=0., stddev=scale),
            trainable=True
        )
        self.positional_embedding = self.add_weight(
            name='positional_embedding',
            shape=[(input_resolution // patch_size) ** 2 + 1, width],  # +1 for [CLS]
            initializer=tf.keras.initializers.RandomNormal(mean=0., stddev=scale),
            trainable=True
        )
        self.ln_pre = LayerNorm(width)

        self.transformer = Transformer(width, layers, heads)

        self.ln_post = LayerNorm(width)
        self.proj = tf.Variable(scale * tf.random.normal([width, output_dim]))

    def __call__(self, x, train_flag=True):
        x = self.conv1(x)  # (B, grid, grid, width) -- channels-last
        # Flatten the spatial grid into a token sequence: (B, grid*grid, width).
        x = tf.reshape(x, [x.shape[0], -1, x.shape[-1]])
        # Prepend a learned [CLS] token broadcast across the batch.
        cls = tf.cast(self.class_embedding, x.dtype) + tf.zeros([x.shape[0], 1, x.shape[-1]], dtype=x.dtype)
        x = tf.concat([cls, x], axis=1)
        x = x + tf.cast(self.positional_embedding, x.dtype)
        x = self.ln_pre(x)

        # Keras MultiHeadAttention is batch-first; no seq-first transposes.
        x = self.transformer(x)

        x = self.ln_post(x[:, 0, :])  # features at the [CLS] position

        if self.proj is not None:
            x = tf.matmul(x, self.proj)

        return x
|
|
|
|
class CLIP(Model):
    """CLIP: paired image and text encoders trained into a shared embedding space.

    `vision_layers` selects the visual tower: a tuple/list of four stage
    depths builds the ModifiedResNet, an int builds the VisionTransformer.
    """

    def __init__(self,
                 embed_dim: int,
                 # vision
                 image_resolution: int,
                 vision_layers: Union[Tuple[int, int, int, int], int],
                 vision_width: int,
                 vision_patch_size: int,
                 # text
                 context_length: int,
                 vocab_size: int,
                 transformer_width: int,
                 transformer_heads: int,
                 transformer_layers: int
                 ):
        super(CLIP, self).__init__()

        self.context_length = context_length

        if isinstance(vision_layers, (tuple, list)):
            vision_heads = vision_width * 32 // 64
            self.visual = ModifiedResNet(
                layers=vision_layers,
                output_dim=embed_dim,
                heads=vision_heads,
                input_resolution=image_resolution,
                width=vision_width
            )
        else:
            vision_heads = vision_width // 64
            self.visual = VisionTransformer(
                input_resolution=image_resolution,
                patch_size=vision_patch_size,
                width=vision_width,
                layers=vision_layers,
                heads=vision_heads,
                output_dim=embed_dim
            )

        self.transformer = Transformer(
            width=transformer_width,
            layers=transformer_layers,
            heads=transformer_heads,
            attn_mask=self.build_attention_mask()
        )

        self.vocab_size = vocab_size
        self.token_embedding = self.add_weight(
            name='token_embedding',
            shape=(vocab_size, transformer_width),
            initializer=tf.keras.initializers.RandomNormal(stddev=0.02),
            trainable=True
        )
        self.positional_embedding = self.add_weight(
            name='positional_embedding',
            shape=(self.context_length, transformer_width),
            initializer=tf.keras.initializers.RandomNormal(stddev=0.01),
            trainable=True
        )
        self.ln_final = LayerNorm(transformer_width)

        self.text_projection = self.add_weight(
            name='text_projection',
            shape=(transformer_width, embed_dim),
            initializer=tf.keras.initializers.RandomNormal(stddev=transformer_width ** -0.5),
            trainable=True
        )
        # Learnable temperature, initialized to ln(1/0.07) as in CLIP.
        self.logit_scale = self.add_weight(
            name='logit_scale',
            shape=[],
            initializer=tf.keras.initializers.Constant(np.log(1 / 0.07)),
            trainable=True
        )

    def build_attention_mask(self):
        # Keras MultiHeadAttention treats nonzero mask entries as "attend".
        # Build a causal mask: lower triangle (diagonal included) of ones.
        # The previous upper-triangular * -1e9 version let every token
        # attend to the FUTURE and blocked attention to itself/the past.
        mask = tf.ones((self.context_length, self.context_length))
        return tf.linalg.band_part(mask, -1, 0)

    def encode_image(self, image):
        """Embed a batch of images with the visual tower."""
        return self.visual(image)

    def encode_text(self, text):
        """Embed a batch of token-id sequences with the text transformer."""
        x = tf.gather(self.token_embedding, text)  # (batch, ctx_len, width)

        x = x + self.positional_embedding
        # Keras MultiHeadAttention is batch-first; the previous seq-first
        # transposes made attention mix across the batch dimension.
        x = self.transformer(x)
        x = self.ln_final(x)

        # Take the features at the highest token id in each sequence
        # (assumed to be the end-of-text token -- TODO confirm the
        # tokenizer assigns EOT the largest id), then project into the
        # joint embedding space.
        eot = tf.argmax(text, axis=-1, output_type='int32')
        idx = tf.stack([tf.range(x.shape[0], dtype='int32'), eot], axis=1)
        x = tf.matmul(tf.gather_nd(x, idx), self.text_projection)

        return x

    def __call__(self, image, text):
        image_features = self.encode_image(image)
        text_features = self.encode_text(text)

        # L2-normalize so the dot product below is a cosine similarity.
        image_features = image_features / tf.norm(image_features, axis=1, keepdims=True)
        text_features = text_features / tf.norm(text_features, axis=1, keepdims=True)

        # Cosine similarity scaled by the learned temperature.
        logit_scale = tf.math.exp(self.logit_scale)
        logits_per_image = tf.matmul(logit_scale * image_features, tf.transpose(text_features))
        logits_per_text = tf.transpose(logits_per_image)

        # shape: (batch_images, batch_texts) and its transpose
        return logits_per_image, logits_per_text