import math

import torch
import torch.nn as nn
import torch.nn.functional as F


def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
    # One rotation frequency per pair of dimensions, evaluated at every
    # position 0 .. end - 1 and stored as unit-length complex numbers.
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
    t = torch.arange(end, device=freqs.device)
    freqs = torch.outer(t, freqs)
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64
    return freqs_cis
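
# For example (illustrative), precompute_freqs_cis(8, 4) returns a complex64
# tensor of shape (end, dim // 2) = (4, 4); its first row is all 1 + 0j,
# since position 0 corresponds to a rotation angle of zero.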


def reshape_for_broadcast(freqs_cis, x):
    # x is the real-valued (batch, num_heads, seq_len, head_size) tensor;
    # reshape freqs_cis to (1, 1, seq_len, head_size // 2) so it broadcasts
    # over the batch and head dimensions.
    batch_size, num_heads, seq_len, head_size = x.shape
    freqs_cis = freqs_cis[:seq_len]
    shape = [1, 1, seq_len, head_size // 2]
    return freqs_cis.view(*shape)


def apply_rope(x, freqs_cis):
    # Pair adjacent values in the last dimension into complex numbers, rotate
    # them by the position-dependent angles in freqs_cis, then flatten back.
    x_ = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
    freqs_cis = reshape_for_broadcast(freqs_cis, x)
    x_out = torch.view_as_real(x_ * freqs_cis).flatten(3)
    return x_out.type_as(x)
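
# Shape walk-through (illustrative, using the default num_heads=4 and
# head_size=8 from the Transformer below): xq is (B, 4, T, 8); view_as_complex
# pairs the last dimension into (B, 4, T, 4) complex values, freqs_cis[:T] is
# reshaped to (1, 1, T, 4), and the element-wise product rotates each
# query/key pair by a position-dependent angle before flattening back to
# (B, 4, T, 8).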


class RMSNorm(torch.nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        output = self._norm(x.float()).type_as(x)
        return output * self.weight


class Attention(nn.Module):
    """
    Multi-head self-attention with RoPE.
    """

    def __init__(self, num_heads, head_size, num_embed):
        super().__init__()
        self.num_heads = num_heads
        self.head_size = head_size
        self.wq = nn.Linear(num_embed, num_heads * head_size, bias=False)
        self.wk = nn.Linear(num_embed, num_heads * head_size, bias=False)
        self.wv = nn.Linear(num_embed, num_heads * head_size, bias=False)
        self.wo = nn.Linear(num_heads * head_size, num_embed, bias=False)

    def forward(self, x, freqs_cis):
        B, T, C = x.shape
        # causal mask: each position may only attend to itself and earlier positions
        mask = torch.triu(torch.full((T, T), float("-inf"), device=x.device), diagonal=1)
        xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
        xq = xq.view(B, T, self.num_heads, self.head_size)
        xk = xk.view(B, T, self.num_heads, self.head_size)
        xv = xv.view(B, T, self.num_heads, self.head_size)
        xq = xq.transpose(1, 2)  # (B, num_heads, T, head_size)
        xk = xk.transpose(1, 2)
        xv = xv.transpose(1, 2)
        # rotate queries and keys with the rotary position embedding
        xq = apply_rope(xq, freqs_cis)
        xk = apply_rope(xk, freqs_cis)
        attn_weights = torch.matmul(xq, xk.transpose(2, 3)) / math.sqrt(self.head_size)
        attn_weights += mask
        attn_weights = F.softmax(attn_weights.float(), dim=-1).type_as(xq)
        output = torch.matmul(attn_weights, xv)
        output = output.transpose(1, 2).contiguous().view(B, T, C)
        return self.wo(output)
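
# Note: in PyTorch 2.0+ the matmul / mask / softmax / matmul sequence above is
# essentially what F.scaled_dot_product_attention(xq, xk, xv, is_causal=True)
# computes (modulo the explicit float32 softmax), so that fused call is a
# common drop-in replacement.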


class MLP(nn.Module):
    def __init__(self, num_embed, dropout):
        super().__init__()
        self.num_embed = num_embed
        # hidden width is roughly 2 * num_embed (scale down by 2/3, then by 3)
        hidden_dim = 3 * int(num_embed * 2 / 3)
        self.linear1 = nn.Linear(num_embed, hidden_dim)
        self.linear2 = nn.Linear(hidden_dim, num_embed)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        x = self.linear1(x)
        x = F.silu(x)
        x = self.linear2(x)
        x = self.dropout(x)
        return x


class TransformerBlock(nn.Module):
    """
    Groups the multi-head attention and feed-forward modules
    so the block can be stacked repeatedly in the Transformer.
    """

    def __init__(self, num_heads, num_embed, dropout):
        super().__init__()
        self.num_heads = num_heads
        self.num_embed = num_embed
        head_size = num_embed // num_heads
        self.sa = Attention(
            num_heads=num_heads,
            head_size=head_size,
            num_embed=num_embed,
        )
        self.ffwd = MLP(num_embed=num_embed, dropout=dropout)
        # add the layer normalization
        self.ln1 = RMSNorm(num_embed)
        self.ln2 = RMSNorm(num_embed)

    def forward(self, x, freqs_cis):
        # "x +" is the skip (residual) connection, which helps optimization.
        # Normalization is applied before self-attention and the feed-forward
        # network (pre-norm, a reshuffle of the original paper's post-norm).
        x = x + self.sa(self.ln1(x), freqs_cis)
        x = x + self.ffwd(self.ln2(x))
        return x


class Transformer(nn.Module):
    def __init__(self, **kwargs):
        super().__init__()
        self.vocab_size = kwargs.get("vocab_size", 100)
        self.num_embed = kwargs.get("num_embed", 32)
        self.num_heads = kwargs.get("num_heads", 4)
        self.num_layers = kwargs.get("num_layers", 4)
        self.max_seq_len = kwargs.get("max_seq_len", 1024)
        self.dropout = kwargs.get("dropout", 0.2)
        # a simple lookup table that maps each token id to its embedding vector
        # see more: https://pytorch.org/docs/stable/generated/torch.nn.Embedding.html
        self.token_embedding_table = nn.Embedding(self.vocab_size, self.num_embed)
        # no learned position-embedding table: positions are encoded with RoPE
        self.blocks = nn.ModuleList([
            TransformerBlock(
                num_heads=self.num_heads,
                num_embed=self.num_embed,
                dropout=self.dropout,
            )
            for _ in range(self.num_layers)
        ])
        # we add the layer norm before the final Linear layer
        self.lm_head = nn.Linear(self.num_embed, self.vocab_size)
        self.norm = RMSNorm(self.num_embed)
        self.freqs_cis = precompute_freqs_cis(
            self.num_embed // self.num_heads,
            self.max_seq_len * 2,
            500000,
        )

    def forward(self, idx, targets=None):
        B, T = idx.shape
        # idx and targets are (B, T) tensors of integers;
        # the token embedding is (B, T, C) with C = num_embed
        x = self.token_embedding_table(idx)
        # slice the precomputed frequencies to the current sequence length and
        # keep them on the same device as the input
        freqs_cis = self.freqs_cis[:T].to(idx.device)
        # run the stack of transformer blocks
        for block in self.blocks:
            x = block(x, freqs_cis)
        x = self.norm(x)
        # (B, T, vocab_size)
        logits = self.lm_head(x)
        # compute the loss
        if targets is not None:
            # cross_entropy expects (batch, num_classes), so flatten the
            # (B, T, vocab_size) logits to (B * T, vocab_size)
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
        else:
            loss = None
        return logits, loss

    def generate(self, idx: torch.Tensor, max_new_tokens: int, temperature: float = 0.7, top_p: float = 0.9):
        for _ in range(max_new_tokens):
            # crop the context to the last max_seq_len tokens
            idx_crop = idx[:, -self.max_seq_len:]
            logits, _ = self.forward(idx_crop)
            # keep only the logits for the last position
            logits = logits[:, -1, :]
            if temperature > 0:
                # temperature-scaled nucleus (top-p) sampling
                probs = F.softmax(logits / temperature, dim=-1)
                idx_next = self.sample_top_p(probs, top_p)
            else:
                # temperature == 0: fall back to greedy decoding
                idx_next = torch.argmax(logits, dim=-1, keepdim=True)
            idx = torch.cat((idx, idx_next), dim=1)  # (B, T+1)
        # return the first sequence in the batch
        return idx[0]

    def sample_top_p(self, probs: torch.Tensor, top_p: float) -> torch.Tensor:
        sorted_probs, sorted_indices = torch.sort(probs, descending=True, dim=-1)
        cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
        # keep a token if the cumulative probability of the tokens ranked before
        # it is still within top_p; the most likely token is always kept
        top_p_mask = cumulative_probs <= top_p
        top_p_mask[..., 1:] = top_p_mask[..., :-1].clone()
        top_p_mask[..., 0] = True
        filtered_probs = sorted_probs * top_p_mask
        # renormalize the filtered probabilities and sample from them
        filtered_probs /= filtered_probs.sum(dim=-1, keepdim=True)
        next_token = torch.multinomial(filtered_probs, num_samples=1)
        # map back from sorted order to original vocabulary indices
        return torch.gather(sorted_indices, -1, next_token)
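

if __name__ == "__main__":
    # Minimal usage sketch: build a tiny model with the default hyperparameters,
    # run one forward pass on random token ids, and generate a few tokens.
    # The batch size, sequence length, and prompt length are arbitrary choices.
    torch.manual_seed(0)
    model = Transformer(vocab_size=100, num_embed=32, num_heads=4, num_layers=4)
    idx = torch.randint(0, 100, (2, 16))       # (B, T) batch of token ids
    targets = torch.randint(0, 100, (2, 16))   # next-token targets, same shape
    logits, loss = model(idx, targets)
    print(logits.shape, loss.item())           # torch.Size([2, 16, 100]) and a scalar loss
    prompt = torch.randint(0, 100, (1, 4))
    out = model.generate(prompt, max_new_tokens=8)
    print(out.shape)                           # torch.Size([12]): 4 prompt + 8 generated tokens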