| | """ |
| | Tiny Transformer with modern components: |
| | - RoPE (Rotary Position Embeddings) |
| | - RMSNorm |
| | - SwiGLU activation |
| | - Weight tying |
| | """ |
| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| | import math |
| |
|
class RMSNorm(nn.Module):
    """Root-mean-square normalization with a learnable per-channel gain.

    Computes ``x / sqrt(mean(x ** 2, dim=-1) + eps) * weight``. Unlike
    ``nn.LayerNorm`` there is no mean subtraction and no bias term.

    Args:
        dim: Size of the input's last dimension, i.e. the number of gain
            parameters.
        eps: Added to the mean square before the reciprocal square root to
            avoid division by zero.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        """Normalize ``x`` over its last dimension and scale by ``weight``."""
        # Reduce in at least float32: in float16, x ** 2 overflows to inf once
        # |x| exceeds ~256, and rsqrt(inf) == 0 would zero the whole output.
        # float32/float64 inputs keep their own precision (promote_types).
        compute_dtype = torch.promote_types(x.dtype, torch.float32)
        x_up = x.to(compute_dtype)
        norm = torch.rsqrt(x_up.pow(2).mean(-1, keepdim=True) + self.eps)
        return (x_up * norm).type_as(x) * self.weight
| |
|
| |
|
class RotaryEmbedding(nn.Module):
    """Produce the cos/sin tables used for rotary position embeddings (RoPE).

    Args:
        dim: Feature size the rotation is applied to. Expected to be even:
            ``dim // 2`` frequencies are generated and each is duplicated to
            fill ``dim`` columns.
        max_seq_len: Stored on the module but not used by ``forward``; tables
            are built for whatever ``seq_len`` is requested.
        base: Base of the geometric frequency progression.
    """

    def __init__(self, dim: int, max_seq_len: int = 512, base: int = 10000):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        # Kept as a persistent buffer so existing state_dicts still load.
        self.register_buffer("inv_freq", inv_freq)
        self.max_seq_len = max_seq_len

    def forward(self, x, seq_len: int):
        """Return ``(cos, sin)``, each of shape ``(seq_len, 2 * len(inv_freq))``.

        ``x`` is only consulted for its device and dtype; the returned tables
        are cast to ``x.dtype``.
        """
        # Compute angles in at least float32: after module.half(), inv_freq is
        # float16 and integer positions above 2048 are not exact in float16.
        compute_dtype = torch.promote_types(self.inv_freq.dtype, torch.float32)
        positions = torch.arange(seq_len, device=x.device, dtype=compute_dtype)
        freqs = torch.outer(positions, self.inv_freq.to(compute_dtype))
        emb = torch.cat((freqs, freqs), dim=-1)
        # Match the activations' dtype so multiplying half-precision q/k by
        # these tables does not silently promote them to float32.
        return emb.cos().to(x.dtype), emb.sin().to(x.dtype)
| |
|
| |
|
def rotate_half(x):
    """Rotate feature halves for RoPE: ``[a, b] -> [-b, a]`` on the last axis.

    The last dimension is split with ``chunk(2)`` into a leading half ``a``
    and a trailing half ``b``; the result is ``-b`` followed by ``a``.
    """
    front, back = x.chunk(2, dim=-1)
    return torch.cat([back.neg(), front], dim=-1)
| |
|
| |
|
def apply_rotary_pos_emb(q, k, cos, sin):
    """Rotate ``q`` and ``k`` by the RoPE angles encoded in ``cos``/``sin``.

    Two leading singleton axes are prepended to ``cos`` and ``sin`` before they
    broadcast against q/k. This presumably expects ``(seq_len, head_dim)``
    tables and ``(batch, heads, seq_len, head_dim)`` q/k, as produced in
    ``Attention`` — confirm for other callers.

    Returns:
        Tuple ``(q_embed, k_embed)``: each input times ``cos`` plus its
        ``rotate_half`` times ``sin``.
    """
    cos_b = cos[None, None, ...]
    sin_b = sin[None, None, ...]

    def _rotate(t):
        return (t * cos_b) + (rotate_half(t) * sin_b)

    return _rotate(q), _rotate(k)
| |
|
| |
|
class SwiGLU(nn.Module):
    """Gated feed-forward network: ``w2(silu(w1(x)) * w3(x))``.

    Args:
        hidden_size: Input and output feature size.
        intermediate_size: Width of the gated inner projection.
    """

    def __init__(self, hidden_size: int, intermediate_size: int):
        super().__init__()
        # Names and registration order (w1, w2, w3) fix the state_dict keys
        # and the parameter iteration order, so they are kept as-is.
        self.w1 = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.w2 = nn.Linear(intermediate_size, hidden_size, bias=False)
        self.w3 = nn.Linear(hidden_size, intermediate_size, bias=False)

    def forward(self, x):
        gated = F.silu(self.w1(x))
        linear = self.w3(x)
        return self.w2(gated * linear)
| |
|
| |
|
class Attention(nn.Module):
    """Multi-head self-attention with rotary position embeddings on q and k.

    Args:
        hidden_size: Model width; must be divisible by ``num_heads``.
        num_heads: Number of attention heads. The per-head size
            ``hidden_size // num_heads`` must be even because RoPE rotates
            features in two halves.
        dropout: Dropout rate applied to the attention probabilities.

    Raises:
        ValueError: If the head configuration is invalid.
    """

    def __init__(self, hidden_size: int, num_heads: int, dropout: float = 0.0):
        super().__init__()
        # Fail fast with a clear message. Otherwise a bad split only surfaces
        # later as an opaque view()/shape error inside forward().
        if num_heads <= 0 or hidden_size % num_heads != 0:
            raise ValueError(
                f"hidden_size ({hidden_size}) must be divisible by a positive "
                f"num_heads ({num_heads})"
            )
        head_dim = hidden_size // num_heads
        if head_dim % 2 != 0:
            raise ValueError(
                f"per-head size {head_dim} must be even for rotary embeddings"
            )

        self.num_heads = num_heads
        self.head_dim = head_dim

        self.q_proj = nn.Linear(hidden_size, hidden_size, bias=False)
        self.k_proj = nn.Linear(hidden_size, hidden_size, bias=False)
        self.v_proj = nn.Linear(hidden_size, hidden_size, bias=False)
        self.o_proj = nn.Linear(hidden_size, hidden_size, bias=False)

        self.rotary = RotaryEmbedding(self.head_dim)
        self.dropout = nn.Dropout(dropout)

    def _split_heads(self, t, B, T):
        """Reshape ``(B, T, hidden)`` into ``(B, num_heads, T, head_dim)``."""
        return t.view(B, T, self.num_heads, self.head_dim).transpose(1, 2)

    def forward(self, x, mask=None):
        """Attend over the sequence in ``x``.

        Args:
            x: Tensor of shape ``(B, T, hidden_size)``.
            mask: Optional tensor broadcastable to ``(B, num_heads, T, T)``.
                Positions where ``mask == 0`` receive zero attention weight.

        Returns:
            Tensor of shape ``(B, T, hidden_size)``.
        """
        B, T, C = x.shape

        q = self._split_heads(self.q_proj(x), B, T)
        k = self._split_heads(self.k_proj(x), B, T)
        v = self._split_heads(self.v_proj(x), B, T)

        cos, sin = self.rotary(x, T)
        # Cast the tables to q's dtype. Otherwise half-precision q/k get
        # promoted to float32 and the attn @ v matmul below sees mixed dtypes.
        cos, sin = cos.to(q.dtype), sin.to(q.dtype)
        q, k = apply_rotary_pos_emb(q, k, cos, sin)

        scale = 1.0 / math.sqrt(self.head_dim)
        attn = torch.matmul(q, k.transpose(-2, -1)) * scale

        if mask is not None:
            # -inf logits become exactly zero probability after softmax.
            attn = attn.masked_fill(mask == 0, float('-inf'))

        attn = F.softmax(attn, dim=-1)
        attn = self.dropout(attn)

        out = torch.matmul(attn, v)
        out = out.transpose(1, 2).contiguous().view(B, T, C)
        return self.o_proj(out)
| |
|
| |
|
class TransformerBlock(nn.Module):
    """Pre-norm decoder layer: RMSNorm -> attention and RMSNorm -> SwiGLU.

    Each sub-layer's output is added back to its input (residual connection).

    Args:
        hidden_size: Model width.
        num_heads: Attention heads.
        intermediate_size: SwiGLU inner width.
        dropout: Attention-probability dropout rate.
    """

    def __init__(self, hidden_size: int, num_heads: int, intermediate_size: int, dropout: float = 0.0):
        super().__init__()
        # Registration order is kept: it fixes state_dict keys and init order.
        self.norm1 = RMSNorm(hidden_size)
        self.attn = Attention(hidden_size, num_heads, dropout)
        self.norm2 = RMSNorm(hidden_size)
        self.ffn = SwiGLU(hidden_size, intermediate_size)

    def forward(self, x, mask=None):
        hidden = x + self.attn(self.norm1(x), mask)
        ffn_out = self.ffn(self.norm2(hidden))
        return hidden + ffn_out
| |
|
| |
|
class TinyLLM(nn.Module):
    """Decoder-only language model built from pre-norm ``TransformerBlock``s.

    Args:
        vocab_size: Number of token ids and output logits.
        hidden_size: Model width.
        num_layers: Number of transformer blocks.
        num_heads: Attention heads per block.
        intermediate_size: SwiGLU inner width.
        max_position_embeddings: Longest sequence ``forward`` accepts; sizes
            the precomputed causal mask.
        dropout: Attention-probability dropout rate.
        tie_weights: If True, ``lm_head`` shares its weight tensor with
            ``embed_tokens``.
    """

    def __init__(
        self,
        vocab_size: int = 32000,
        hidden_size: int = 512,
        num_layers: int = 12,
        num_heads: int = 8,
        intermediate_size: int = 1408,
        max_position_embeddings: int = 512,
        dropout: float = 0.0,
        tie_weights: bool = True,
    ):
        super().__init__()
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size

        self.embed_tokens = nn.Embedding(vocab_size, hidden_size)
        self.layers = nn.ModuleList([
            TransformerBlock(hidden_size, num_heads, intermediate_size, dropout)
            for _ in range(num_layers)
        ])
        self.norm = RMSNorm(hidden_size)
        self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)

        if tie_weights:
            self.lm_head.weight = self.embed_tokens.weight

        # Lower-triangular ones: query position i may attend to keys 0..i.
        self.register_buffer(
            "causal_mask",
            torch.tril(torch.ones(max_position_embeddings, max_position_embeddings))
        )

        self._init_weights()

    def _init_weights(self):
        """Draw every Linear and Embedding weight from N(0, 0.02**2).

        RMSNorm gains keep their all-ones initialization. With tied weights
        the shared tensor is visited twice (as lm_head and as embed_tokens);
        the second draw simply overwrites the first.
        """
        for module in self.modules():
            if isinstance(module, nn.Linear):
                torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            elif isinstance(module, nn.Embedding):
                torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, input_ids, labels=None):
        """Run the model on a batch of token ids.

        Args:
            input_ids: Tensor of token ids with shape ``(batch, seq_len)``;
                ``seq_len`` must not exceed ``max_position_embeddings``.
            labels: Optional tensor of the same shape. Targets are shifted
                internally (logits at position t are scored against
                ``labels[..., t + 1]``); entries equal to -100 are ignored.

        Returns:
            Dict with ``"loss"`` (scalar tensor, or None when no labels are
            given) and ``"logits"`` of shape ``(batch, seq_len, vocab_size)``.

        Raises:
            ValueError: If ``seq_len`` exceeds the causal-mask size.
        """
        _, T = input_ids.shape
        max_len = self.causal_mask.size(0)
        if T > max_len:
            # Slicing would silently yield a smaller mask and fail later with
            # an opaque broadcast error inside attention.
            raise ValueError(
                f"sequence length {T} exceeds max_position_embeddings ({max_len})"
            )

        x = self.embed_tokens(input_ids)
        mask = self.causal_mask[:T, :T]

        for layer in self.layers:
            x = layer(x, mask)

        x = self.norm(x)
        logits = self.lm_head(x)

        loss = None
        if labels is not None:
            # Next-token prediction: drop the last logit and the first label.
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss = F.cross_entropy(
                shift_logits.view(-1, self.vocab_size),
                shift_labels.view(-1),
                ignore_index=-100
            )

        return {"loss": loss, "logits": logits}

    def count_parameters(self):
        """Return the total number of parameter elements.

        ``self.parameters()`` yields a shared (tied) parameter only once, so
        tied weights are not double-counted.
        """
        return sum(p.numel() for p in self.parameters())
| |
|
| |
|
if __name__ == "__main__":
    # Smoke test: build the default model, run one forward pass with a
    # next-token loss on random ids, and report size, loss and logits shape.
    model = TinyLLM()
    params_millions = model.count_parameters() / 1e6
    print(f"Parameters: {params_millions:.2f}M")

    tokens = torch.randint(0, 32000, (2, 128))
    result = model(tokens, labels=tokens)
    loss_value = result['loss'].item()
    print(f"Loss: {loss_value:.4f}")
    print(f"Logits shape: {result['logits'].shape}")
| |
|