# Copyright (c) Meta Platforms, Inc. and affiliates.

import logging
from typing import List, Optional, Tuple, Union

import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.nn.attention.flex_attention import BlockMask
from xformers.ops import AttentionBias

from bytelatent.base_transformer import (
    BaseTransformer,
    BaseTransformerArgs,
    flex_attention_comp,
    repeat_kv,
)
from bytelatent.model.utils import create_causal_mask

logger = logging.getLogger()

try:
    from apex.normalization.fused_layer_norm import FusedRMSNorm

    RMSNorm = FusedRMSNorm
except (ImportError, ModuleNotFoundError):
    logging.debug("Apex not found. Using nn.RMSNorm")
    RMSNorm = nn.RMSNorm


class CrossAttention(nn.Module):
    """
    Cross-attention block that lets decoder states attend to encoder states.
    RoPE is not supported.
    """

    def __init__(
        self,
        dim: int,
        head_dim: int,
        n_heads: int,
        n_kv_heads: int,
        norm_eps: float,
    ):
        super().__init__()
        self.dim = dim
        self.head_dim = head_dim
        self.n_heads = n_heads
        self.n_kv_heads = n_kv_heads
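        # Grouped-query attention: each of the n_kv_heads key/value heads is
        # shared by n_heads // n_kv_heads query heads.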
        self.heads_per_group = self.n_heads // self.n_kv_heads

        self.cross_attn_norm_q = nn.RMSNorm(dim, eps=norm_eps)
        self.cross_attn_norm_kv = RMSNorm(dim, eps=norm_eps)
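
        # Queries are projected from x (the attending stream); keys and values
        # from kv (the attended stream). All projections are bias-free.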
        self.wq = nn.Linear(
            dim,
            n_heads * head_dim,
            bias=False,
        )
        self.wk = nn.Linear(
            dim,
            n_kv_heads * head_dim,
            bias=False,
        )
        self.wv = nn.Linear(
            dim,
            n_kv_heads * head_dim,
            bias=False,
        )
        self.wo = nn.Linear(
            n_heads * head_dim,
            dim,
            bias=False,
        )

    def forward(
        self,
        x: torch.Tensor,
        kv: torch.Tensor,
        mask: Optional[Union[BlockMask, AttentionBias, str]] = None,
    ) -> torch.Tensor:
        # B S D
        bsz, seq_len, _ = x.shape
        _, slen_kv, _ = kv.shape
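
        # Pre-norm: normalize the query and key/value streams before projecting.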
        x_norm = self.cross_attn_norm_q(x)
        kv = self.cross_attn_norm_kv(kv)

        xq = self.wq(x_norm)
        xk = self.wk(kv)
        xv = self.wv(kv)

        output_shape = xq.shape
        # B S D -> B S H D
        xq = xq.view(bsz, seq_len, self.n_heads, self.head_dim)
        xk = xk.view(bsz, slen_kv, self.n_kv_heads, self.head_dim)
        xv = xv.view(bsz, slen_kv, self.n_kv_heads, self.head_dim)
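
        # Repeat each KV head so the number of K/V heads matches n_heads (GQA).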
        xk = repeat_kv(xk, self.heads_per_group, dim=2)
        xv = repeat_kv(xv, self.heads_per_group, dim=2)

        assert mask is None or isinstance(mask, BlockMask)
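        # flex_attention operates on (B, H, S, D); only a BlockMask (or None) is accepted here.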
        xq, xk, xv = map(lambda e: e.transpose(1, 2), (xq, xk, xv))
        output = flex_attention_comp(xq, xk, xv, block_mask=mask)
        output = output.transpose(1, 2).contiguous()  # B H S D -> B S H D

        output = self.wo(output.reshape(output_shape))
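
        # Residual connection around the cross-attention block.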
        return x + output

    def init_weights(self, base_std: float, factor: float = 1.0):
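        # Note: `or` binds looser than `/`, so `factor` only scales the default
        # std (dim ** -0.5); an explicitly passed base_std is used as-is.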
        std = base_std or (self.dim ** (-0.5)) / factor
        nn.init.trunc_normal_(
            self.wq.weight,
            mean=0.0,
            std=std,
            a=-3 * std,
            b=3 * std,
        )
        nn.init.trunc_normal_(
            self.wk.weight,
            mean=0.0,
            std=std,
            a=-3 * std,
            b=3 * std,
        )
        nn.init.trunc_normal_(
            self.wv.weight,
            mean=0.0,
            std=std,
            a=-3 * std,
            b=3 * std,
        )
        nn.init.trunc_normal_(
            self.wo.weight,
            mean=0.0,
            std=std,
            a=-3 * std,
            b=3 * std,
        )
        self.cross_attn_norm_q.reset_parameters()
        self.cross_attn_norm_kv.reset_parameters()
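

# Minimal usage sketch for CrossAttention (illustrative, not part of the original
# module). The sizes below are assumptions; mask=None runs dense attention and
# requires flex_attention to be usable on the current device.
#
#     cross_attn = CrossAttention(dim=512, head_dim=64, n_heads=8, n_kv_heads=2, norm_eps=1e-5)
#     cross_attn.init_weights(base_std=0.02)
#     x = torch.randn(2, 16, 512)    # attending (decoder) states: (batch, seq_len, dim)
#     kv = torch.randn(2, 32, 512)   # attended (encoder) states: (batch, kv_len, dim)
#     out = cross_attn(x, kv, mask=None)  # -> (2, 16, 512), includes the residual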


class GlobalTransformer(BaseTransformer):
    def __init__(self, args: BaseTransformerArgs):
        super().__init__(args)
        self.dropout = args.dropout
        self.eos_id = args.eos_id
        self.dim_token_emb = args.dim_token_emb

        self.token_embedding_projection = None
        if args.dim_token_emb is not None and args.dim_token_emb != self.dim:
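            # Project incoming token embeddings (width dim_token_emb) to the
            # global model width before the transformer layers.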
            self.token_embedding_projection = nn.Linear(
                args.dim_token_emb,
                args.dim,
                bias=False,
            )

    def forward(
        self,
        tokens: torch.Tensor,
        tok_idx: Optional[torch.Tensor] = None,
        embeds: Optional[torch.Tensor] = None,
        mask: Optional[Union[BlockMask, AttentionBias, torch.Tensor, str]] = None,
        cache: Optional[List[Tuple[torch.Tensor, torch.Tensor, int]]] = None,
    ):
| """ | |
| Similar to BaseTransformer.forward, but with an additional embeds argument | |
| and projection to the token space. | |
| """ | |
        bs, seqlen = tokens.shape

        h = embeds

        mask = (
            mask
            if mask is not None
            else create_causal_mask(
                seqlen,
                self.attn_impl,
                self.attn_bias_type,
                tokens=tokens,
                eos_id=self.eos_id,
            )
        )
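
        # Project to the model width if the incoming embedding width differs.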
        if self.token_embedding_projection is not None and h.shape[-1] != self.dim:
            h = self.token_embedding_projection(h)

        h = F.dropout(h, p=self.dropout, training=self.training)

        h = super().forward(h, tok_idx=tok_idx, mask=mask, attn_impl=self.attn_impl)
        return h, cache

    def init_weights(self):
        super().init_weights()
        if self.token_embedding_projection is not None:
            std = self.dim_token_emb ** (-0.5)
            nn.init.trunc_normal_(
                self.token_embedding_projection.weight,
                mean=0.0,
                std=std,
                a=-3 * std,
                b=3 * std,
            )
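

# Minimal usage sketch for GlobalTransformer (illustrative, not part of the original
# module). The full BaseTransformerArgs schema lives in bytelatent.base_transformer;
# only the fields read in this file (dim, dropout, eos_id, dim_token_emb, attn_impl,
# attn_bias_type) are assumed here. Given per-patch embeddings `embeds` of shape
# (batch, num_patches, dim_token_emb) produced by a local encoder:
#
#     h, _ = global_model(tokens, embeds=embeds)
#
# projects `embeds` to `dim` if their widths differ, applies dropout, and runs the
# shared BaseTransformer stack with a causal mask built from `tokens` and `eos_id`.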