| | from dataclasses import dataclass, field |
| | from einops import rearrange, repeat |
| | import math |
| | import torch |
| | from torch.amp.autocast_mode import autocast |
| | import torch.nn as nn |
| | from transformers.activations import ACT2FN |
| | from typing import cast |
| |
|
| | |
# flash_attn is an optional accelerator. When it is not installed, every
# symbol imported from it is bound to None so that the rest of the module can
# test `X is not None` and fall back to the pure-PyTorch implementations.
try:
    from flash_attn.bert_padding import pad_input, unpad_input
    from flash_attn.layers.rotary import RotaryEmbedding as FlashRotaryEmbedding
    from flash_attn.modules.mha import FlashCrossAttention, FlashSelfAttention
    from flash_attn.ops.fused_dense import FusedDense
except ImportError:
    print("flash_attn not found, using default implementations")
    pad_input = unpad_input = None
    # NOTE: every name imported above must be listed here; a missing (or
    # misspelled) name would surface later as a NameError instead of a
    # clean fallback.
    FlashRotaryEmbedding = FlashCrossAttention = FlashSelfAttention = FusedDense = None
| |
|
| |
|
| | class RotaryEmbedding(nn.Module): |
| | """Rotary positional embedding (RoPE). See https://www.youtube.com/watch?v=C6rV8BsrrCc""" |
| |
|
| | def __init__( |
| | self, |
| | d_rotary: int, |
| | rotary_base: float = 10000.0, |
| | initial_cos_sin_cache_len: int = 2048, |
| | device: torch.device | None = None, |
| | ) -> None: |
| | super().__init__() |
| | self.d_rotary = d_rotary |
| | self.rotary_base = rotary_base |
| | self.device = device |
| | self.dtype = torch.float32 |
| | self._update_cos_sin_cache(seqlen=initial_cos_sin_cache_len) |
| |
|
| | def _update_cos_sin_cache( |
| | self, |
| | seqlen: int, |
| | device: str | None = None, |
| | dtype: torch.dtype | None = None, |
| | ) -> None: |
| | |
| | self._max_seqlen = seqlen |
| |
|
| | |
| | m = torch.arange( |
| | seqlen, |
| | device=device, |
| | dtype=torch.float32, |
| | ) |
| | theta_i = 1.0 / ( |
| | self.rotary_base ** ( |
| | torch.arange( |
| | start=0, |
| | end=self.d_rotary, |
| | step=2, |
| | device=device, |
| | dtype=torch.float32, |
| | ) / self.d_rotary |
| | ) |
| | ) |
| | |
| | |
| | m_theta_i = torch.outer(m, theta_i) |
| | self._cos_cached = torch.cos(m_theta_i).to(dtype) |
| | self._sin_cached = torch.sin(m_theta_i).to(dtype) |
| |
|
| | |
| | """ |
| | if scale_base is not None: |
| | scale = ( |
| | torch.arange( |
| | start=0, |
| | end=self.d_rotary, |
| | step=2, |
| | device=self.device, |
| | dtype=torch.float32, |
| | ) + 0.4 * self.d_rotary |
| | ) / (1.4 * self.d_rotary) |
| | power = ( |
| | torch.arange(seqlen, dtype=scale.dtype, device=scale.device) - seqlen // 2 |
| | ) / scale_base |
| | scale = scale.to(device=power.device) ** rearrange(power, "s -> s 1") |
| | self._cos_cached = (torch.cos(m_theta_i) * scale).to(dtype) |
| | self._sin_cached = (torch.sin(m_theta_i) * scale).to(dtype) |
| | """ |
| |
|
| | def _apply_rotary_emb_qkv( |
| | self, |
| | x: torch.FloatTensor, |
| | cos: torch.FloatTensor, |
| | sin: torch.FloatTensor, |
| | ) -> torch.FloatTensor: |
| | seqlen = x.shape[1] |
| | x_to_rotate = x[..., :self.d_rotary] |
| | x_to_keep_unrotated = x[..., self.d_rotary:] |
| | x1, x2 = x_to_rotate.chunk(2, dim=-1) |
| | broadcast_rearrange = "s d -> s 1 d" if x1.ndim == 4 else "s d -> s 1 1 d" |
| | c, s = rearrange(cos[:seqlen], broadcast_rearrange), rearrange(sin[:seqlen], broadcast_rearrange) |
| | x1, x2, c, s = [t.to(dtype=torch.float32) for t in [x1, x2, c, s]] |
| | x_rotated = cast( |
| | torch.FloatTensor, |
| | torch.cat([x1 * c - x2 * s, x1 * s + x2 * c], dim=-1).to(x.dtype) |
| | ) |
| | return torch.cat([x_rotated, x_to_keep_unrotated], axis=-1) |
| |
|
| | def forward( |
| | self, |
| | x: torch.FloatTensor, |
| | seqlen_offset: int = 0, |
| | ) -> torch.FloatTensor: |
| | if ( |
| | not self._max_seqlen |
| | or self._max_seqlen < x.shape[1] + seqlen_offset |
| | or self._cos_cached.device != x.device |
| | or self._cos_cached.dtype != x.dtype |
| | or (self.training and self._cos_cached.is_inference()) |
| | ): |
| | self._update_cos_sin_cache(seqlen=x.shape[1] + seqlen_offset, device=x.device, dtype=x.dtype) |
| | return self._apply_rotary_emb_qkv( |
| | x, |
| | cast(torch.FloatTensor, self._cos_cached[seqlen_offset:]), |
| | cast(torch.FloatTensor, self._sin_cached[seqlen_offset:]), |
| | ) |
| |
|
| |
|
| | class SelfAttention(nn.Module): |
| | def __init__( |
| | self, |
| | qk_scale: float | None = None, |
| | attention_dropout: float = 0.0, |
| | ) -> None: |
| | super().__init__() |
| | self.qk_scale = qk_scale |
| | self.dropout = nn.Dropout(attention_dropout) |
| |
|
| | |
| | @autocast("cpu", enabled=False) |
| | @autocast("cuda", enabled=False) |
| | def forward( |
| | self, |
| | qkv: torch.FloatTensor, |
| | causal: bool = True, |
| | key_padding_mask: torch.BoolTensor | None = None, |
| | ) -> torch.FloatTensor: |
| | batch_size, seqlen = qkv.shape[0], qkv.shape[1] |
| | q, k, v = qkv.unbind(dim=2) |
| | q = q.to(torch.float32) |
| | k = k.to(torch.float32) |
| | qk_scale = self.qk_scale or 1.0 / math.sqrt(q.shape[-1]) |
| |
|
| | scores = torch.einsum("bthd,bshd->bhts", q, k * qk_scale) |
| |
|
| | if key_padding_mask: |
| | padding_mask = torch.full((batch_size, seqlen), -10000.0, dtype=scores.dtype, device=scores.device) |
| | padding_mask.masked_fill_(key_padding_mask, 0.0) |
| | scores = scores + rearrange(padding_mask, "b s -> b 1 1 s") |
| |
|
| | if causal: |
| | causal_mask = torch.triu(torch.full((seqlen, seqlen), -10000.0, device=scores.device), 1) |
| | scores = scores + causal_mask.to(dtype=scores.dtype) |
| |
|
| | attention = torch.softmax(scores, dim=-1).to(v.dtype) |
| | attention = self.dropout(attention) |
| |
|
| | output = torch.einsum("bhts,bshd->bthd", attention, v) |
| | return cast(torch.FloatTensor, output) |
| |
|
| |
|
| | class CrossAttention(nn.Module): |
| | def __init__( |
| | self, |
| | qk_scale: float | None = None, |
| | attention_dropout: float = 0.0, |
| | ) -> None: |
| | super().__init__() |
| | self.qk_scale = qk_scale |
| | self.dropout = nn.Dropout(attention_dropout) |
| |
|
| | |
| | @autocast("cpu", enabled=False) |
| | @autocast("cuda", enabled=False) |
| | def forward( |
| | self, |
| | q: torch.FloatTensor, |
| | kv: torch.FloatTensor, |
| | causal: bool = True, |
| | key_padding_mask: torch.BoolTensor | None = None, |
| | ) -> torch.FloatTensor: |
| | batch_size, seqlen_q = q.shape[0], q.shape[1] |
| | seqlen_k = kv.shape[1] |
| | if kv.shape[3] != q.shape[2]: |
| | kv = repeat(kv, "... hkv d -> ... (hkv g) d", g=q.shape[2] // kv.shape[3]) |
| | k, v = kv.unbind(dim=2) |
| | q = cast(torch.FloatTensor, q.to(torch.float32)) |
| | k = k.to(torch.float32) |
| | qk_scale = self.qk_scale or 1.0 / math.sqrt(q.shape[-1]) |
| |
|
| | scores = torch.einsum("bthd,bshd->bhts", q, k * qk_scale) |
| |
|
| | if key_padding_mask: |
| | padding_mask = torch.full((batch_size, seqlen_k), -10000.0, dtype=scores.dtype, device=scores.device) |
| | padding_mask.masked_fill_(key_padding_mask, 0.0) |
| | scores = scores + rearrange(padding_mask, "b s -> b 1 1 s") |
| |
|
| | if causal: |
| | rows = rearrange(torch.arange(seqlen_q, device=q.device, dtype=torch.long), "s -> s 1") |
| | cols = torch.arange(seqlen_k, device=k.device, dtype=torch.long) |
| | causal_mask = cols > rows + seqlen_k - seqlen_q |
| | scores = scores.masked_fill(causal_mask, -10000.0) |
| |
|
| | attention = torch.softmax(scores, dim=-1).to(v.dtype) |
| | attention = self.dropout(attention) |
| |
|
| | output = torch.einsum("bhts,bshd->bthd", attention, v) |
| | return cast(torch.FloatTensor, output) |
| |
|
| |
|
class MLP(nn.Module):
    """Two-layer feed-forward network with a 4x hidden expansion."""

    def __init__(
        self,
        d_embedding: int,
        act_fn: str = "gelu_new",
    ) -> None:
        """
        Args:
            d_embedding: Input and output feature size.
            act_fn: Key into ``ACT2FN`` selecting the hidden activation.
        """
        super().__init__()
        hidden_size = 4 * d_embedding
        # Up-projection, activation, down-projection (created in this order).
        self.fc1 = nn.Linear(d_embedding, hidden_size)
        self.act = ACT2FN[act_fn]
        self.fc2 = nn.Linear(hidden_size, d_embedding)

    def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
        """Return ``fc2(act(fc1(x)))``."""
        return self.fc2(self.act(self.fc1(x)))
| |
|
| |
|
| | @dataclass |
| | class KVCache: |
| | """Options for model to calculate and store context during inference.""" |
| | max_seqlen: int |
| | max_batch_size: int |
| | seqlen_offset: int |
| | batch_size_offset: int |
| | kv_block_map: dict[int, torch.Tensor] = field(default_factory=dict) |
| | lengths_per_sample: torch.Tensor | None = None |
| |
|
| |
|
class MHA(nn.Module):
    """Multi-head attention block.

    Projects the input to packed q/k/v, applies rotary embeddings to q and k,
    runs either full self-attention (no cache) or attention against cached
    keys/values (with a ``KVCache``), and projects the result back to
    ``d_embedding``. flash_attn kernels are used when requested AND installed;
    otherwise the pure-PyTorch implementations in this file are used.
    """

    def __init__(
        self,
        d_embedding: int,
        n_attn_heads: int,
        block_n: int,
        initial_cos_sin_cache_len: int,
        attn_pdrop: float,
        use_flash_rotary: bool,
        use_flash_attn: bool,
        use_fused_dense: bool,
        checkpointing: bool,
        d_rotary: int = 32,
    ) -> None:
        """
        Args:
            d_embedding: Model width.
            n_attn_heads: Number of attention heads.
            block_n: Index of this block; used as the key into
                ``KVCache.kv_block_map``.
            initial_cos_sin_cache_len: Positions to precompute for the
                built-in rotary embedding.
            attn_pdrop: Dropout probability on attention weights.
            use_flash_rotary: Prefer flash_attn's rotary embedding.
            use_flash_attn: Prefer flash_attn's attention kernels.
            use_fused_dense: Prefer flash_attn's fused linear layer.
            checkpointing: Wrap the inner attention in activation
                checkpointing.
            d_rotary: Number of head features that receive the rotary
                embedding.
        """
        super().__init__()

        # Rotary embedding. The two implementations take different
        # constructor arguments: flash_attn's takes the rotary width as its
        # first argument and has no initial-cache-length option.
        # NOTE(review): the call sites below pass a (b, s, 2, h, d) q/k slice;
        # confirm the installed flash_attn version accepts that layout.
        if use_flash_rotary and FlashRotaryEmbedding is not None:
            self.rotary_emb = FlashRotaryEmbedding(d_rotary)
        else:
            self.rotary_emb = RotaryEmbedding(
                d_rotary=d_rotary,
                initial_cos_sin_cache_len=initial_cos_sin_cache_len,
            )

        # Self-attention.
        self_attn_cls = (
            FlashSelfAttention
            if use_flash_attn and FlashSelfAttention is not None
            else SelfAttention
        )
        self.inner_self_attn = self_attn_cls(attention_dropout=attn_pdrop)

        # Cross-attention (used when decoding against a KV cache).
        cross_attn_cls = (
            FlashCrossAttention
            if use_flash_attn and FlashCrossAttention is not None
            else CrossAttention
        )
        self.inner_cross_attn = cross_attn_cls(attention_dropout=attn_pdrop)

        # Projections.
        self.n_attn_heads = n_attn_heads
        self.d_head = d_embedding // n_attn_heads
        linear_cls = (
            FusedDense
            if use_fused_dense and FusedDense is not None
            else nn.Linear
        )
        self.Wqkv = linear_cls(
            d_embedding,
            self.d_head * (3 * self.n_attn_heads),
        )
        self.fc_out = linear_cls(d_embedding, d_embedding)

        # Settings.
        self.using_flash_attn = self_attn_cls is FlashSelfAttention
        self.block_n = block_n
        self.checkpointing = checkpointing

    def _run_inner_attn(self, attn_fn, *args, **kwargs):
        """Call ``attn_fn``, through activation checkpointing when enabled."""
        if not self.checkpointing:
            return attn_fn(*args, **kwargs)
        from torch.utils.checkpoint import checkpoint

        # Keyword arguments are only forwarded to the wrapped function by the
        # non-reentrant implementation; the reentrant one rejects them.
        return checkpoint(attn_fn, *args, use_reentrant=False, **kwargs)

    def _forward_self_attn(
        self,
        qkv: torch.FloatTensor,
        key_padding_mask: torch.BoolTensor | None,
    ) -> torch.FloatTensor:
        """Causal self-attention over the whole sequence (no KV cache).

        Args:
            qkv: Tensor indexed as ``(batch, seq, 3, heads, head_dim)``.
            key_padding_mask: Optional ``(batch, seq)`` boolean mask; True
                marks real tokens.

        Returns:
            Tensor of shape ``(batch, seq, heads, head_dim)``.
        """
        # Rotate q and k only. v is sliced with 2: (not indexed with 2) so it
        # keeps the q/k/v axis and can be concatenated back along dim 2.
        qkv = cast(
            torch.FloatTensor,
            torch.cat(
                [self.rotary_emb(qkv[:, :, :2, :, :]), qkv[:, :, 2:, :, :]],
                dim=2,
            ),
        )

        if self.using_flash_attn and unpad_input and pad_input:
            batch_size, seqlen = qkv.shape[0], qkv.shape[1]
            cu_seqlens = max_seqlen = indices = None

            # Drop padded tokens before the variable-length kernel. `*_`
            # tolerates flash_attn versions that return extra values.
            if key_padding_mask is not None:
                qkv, indices, cu_seqlens, max_seqlen, *_ = unpad_input(qkv, key_padding_mask)

            # causal=True keeps this path consistent with the pure-PyTorch
            # fallback, whose forward defaults to causal masking.
            attn_output = self._run_inner_attn(
                self.inner_self_attn,
                qkv,
                causal=True,
                cu_seqlens=cu_seqlens,
                max_seqlen=max_seqlen,
            )

            # Scatter the packed output back to (batch, seq, ...).
            if key_padding_mask is not None:
                return pad_input(attn_output, indices, batch_size, seqlen)
            return attn_output

        return self._run_inner_attn(
            self.inner_self_attn,
            qkv,
            causal=True,
            key_padding_mask=key_padding_mask,
        )

    def _update_kv_cache(
        self,
        kv: torch.FloatTensor,
        kv_cache: KVCache,
        block_n: int,
    ) -> torch.FloatTensor:
        """Write ``kv`` into the cache and return all cached keys/values so far.

        Args:
            kv: New keys/values, ``(batch, seq, 2, heads, head_dim)``.
            kv_cache: Cache state; ``kv_block_map[block_n]`` is allocated on
                first use and updated in place.
            block_n: Index of the block whose cache is updated.

        Returns:
            View of the cache covering the current batch rows and sequence
            positions ``[0, seqlen_offset + seq)``.
        """
        if block_n not in kv_cache.kv_block_map:
            kv_cache.kv_block_map[block_n] = torch.empty(
                kv_cache.max_batch_size,
                kv_cache.max_seqlen,
                2,
                kv.shape[-2],
                kv.shape[-1],
                dtype=kv.dtype,
                device=kv.device,
            )

        batch_start = kv_cache.batch_size_offset
        batch_end = batch_start + kv.shape[0]
        sequence_start = kv_cache.seqlen_offset
        sequence_end = sequence_start + kv.shape[1]

        # Grow the cache only when the write would not fit, and size the
        # extension from the cache itself so it does not depend on the
        # incoming batch size.
        cache = kv_cache.kv_block_map[block_n]
        if sequence_end > cache.shape[1]:
            extension = cache.new_empty(
                cache.shape[0],
                sequence_end - cache.shape[1],
                *cache.shape[2:],
            )
            cache = torch.cat((cache, extension), dim=1)
            kv_cache.kv_block_map[block_n] = cache

        cache[batch_start:batch_end, sequence_start:sequence_end, ...] = kv
        return cast(torch.FloatTensor, cache[batch_start:batch_end, :sequence_end, ...])

    def _forward_cross_attn(
        self,
        qkv: torch.FloatTensor,
        kv_cache: KVCache,
        key_padding_mask: torch.BoolTensor | None,
    ) -> torch.FloatTensor:
        """Attend the new queries to all cached keys/values.

        Args:
            qkv: Tensor indexed as ``(batch, seq_q, 3, heads, head_dim)``.
            kv_cache: Cache state; updated with the new keys/values.
            key_padding_mask: Optional boolean mask over the cached key
                positions; True marks real tokens.

        Returns:
            Tensor of shape ``(batch, seq_q, heads, head_dim)``.
        """
        seqlen_offset = kv_cache.seqlen_offset

        # Rotate q and k at their absolute positions; v is left untouched.
        qk = self.rotary_emb(qkv[:, :, :2, :, :], seqlen_offset=seqlen_offset)
        q = qk[:, :, 0, :, :]
        v = cast(torch.FloatTensor, qkv[:, :, 2, :, :])
        kv = torch.stack([qk[:, :, 1, :, :], v], dim=2)
        kv = self._update_kv_cache(kv, kv_cache, self.block_n)

        # Causal masking is only requested for the first (prompt) pass.
        causal = seqlen_offset == 0

        if self.using_flash_attn and unpad_input and pad_input:
            batch_size, seqlen_q = q.shape[0], q.shape[1]
            seqlen_k = kv.shape[1]
            cu_seqlens_q = cu_seqlens_k = None
            max_seqlen_q = max_seqlen_k = None
            indices_q = None

            if key_padding_mask is not None:
                kv, _, cu_seqlens_k, max_seqlen_k, *_ = unpad_input(kv, key_padding_mask)

                # Build the mask that matches the query window.
                if seqlen_q == 1:
                    # Single-token decoding: every query is a real token.
                    query_mask = torch.ones(batch_size, 1, dtype=torch.bool, device=q.device)
                elif seqlen_q != seqlen_k:
                    # Queries correspond to the last seqlen_q key positions.
                    query_mask = key_padding_mask[:, -seqlen_q:]
                else:
                    query_mask = key_padding_mask

                q, indices_q, cu_seqlens_q, max_seqlen_q, *_ = unpad_input(q, query_mask)

            attn_output = self._run_inner_attn(
                self.inner_cross_attn,
                q,
                kv,
                causal=causal,
                cu_seqlens=cu_seqlens_q,
                max_seqlen=max_seqlen_q,
                cu_seqlens_k=cu_seqlens_k,
                max_seqlen_k=max_seqlen_k,
            )

            if key_padding_mask is not None:
                # Pad back to the original query length: indices_q index into
                # a (batch * seqlen_q) layout, which max_seqlen_q may undershoot.
                return pad_input(attn_output, indices_q, batch_size, seqlen_q)
            return attn_output

        return self._run_inner_attn(
            self.inner_cross_attn,
            q,
            kv,
            key_padding_mask=key_padding_mask,
            causal=causal,
        )

    def forward(
        self,
        x: torch.FloatTensor,
        kv_cache: KVCache | None = None,
        key_padding_mask: torch.LongTensor | torch.BoolTensor | None = None,
    ) -> torch.FloatTensor:
        """Run multi-head attention on ``x``.

        Args:
            x: Input whose last dim is ``d_embedding``.
            kv_cache: When given, attend against (and update) the cache;
                otherwise run self-attention over ``x`` alone.
            key_padding_mask: Optional mask, converted to bool; nonzero/True
                marks real tokens.

        Returns:
            Tensor with the same leading dims as ``x`` and last dim
            ``d_embedding``.
        """
        if key_padding_mask is not None:
            key_padding_mask = cast(torch.BoolTensor, key_padding_mask.bool())

        # Project, then split the last dim into (q/k/v, heads, head_dim).
        qkv = self.Wqkv(x)
        qkv = rearrange(qkv, "... (three h d) -> ... three h d", three=3, d=self.d_head)
        if kv_cache is None:
            attn_output = self._forward_self_attn(qkv, key_padding_mask)
        else:
            attn_output = self._forward_cross_attn(qkv, kv_cache, key_padding_mask)

        # Merge heads back into one feature axis and project out.
        output = rearrange(attn_output, "... h d -> ... (h d)")
        return self.fc_out(output)
| |
|
| |
|
class ParallelAttentionBlock(nn.Module):
    """Calculates attention and MLP in parallel.

    Both branches read the same layer-normed input; their sum goes through
    dropout and is added to the un-normalized input as a residual.
    """

    def __init__(
        self,
        resid_pdrop: float,
        layer_norm_epsilon: float,
        d_embedding: int,
        n_attn_heads: int,
        block_n: int,
        initial_cos_sin_cache_len: int,
        attn_pdrop: float,
        use_flash_rotary: bool = True,
        use_flash_attn: bool = True,
        use_fused_dense: bool = True,
        checkpointing: bool = False,
    ) -> None:
        """
        Args:
            resid_pdrop: Dropout probability on the summed branch outputs.
            layer_norm_epsilon: Epsilon of the input LayerNorm.
            d_embedding: Model width.
            n_attn_heads: Number of attention heads.
            block_n: Index of this block, forwarded to the attention module.
            initial_cos_sin_cache_len: Forwarded to the attention module.
            attn_pdrop: Forwarded to the attention module.
            use_flash_rotary: Forwarded to the attention module.
            use_flash_attn: Forwarded to the attention module.
            use_fused_dense: Forwarded to the attention module.
            checkpointing: Forwarded to the attention module.
        """
        super().__init__()
        self.layer_norm = nn.LayerNorm(d_embedding, eps=layer_norm_epsilon)
        self.block_n = block_n

        attention_config = dict(
            d_embedding=d_embedding,
            n_attn_heads=n_attn_heads,
            block_n=block_n,
            initial_cos_sin_cache_len=initial_cos_sin_cache_len,
            attn_pdrop=attn_pdrop,
            use_flash_rotary=use_flash_rotary,
            use_flash_attn=use_flash_attn,
            use_fused_dense=use_fused_dense,
            checkpointing=checkpointing,
        )
        self.multi_head_attention = MHA(**attention_config)
        self.mlp = MLP(d_embedding)
        self.dropout = nn.Dropout(resid_pdrop)

    def forward(
        self,
        x: torch.FloatTensor,
        kv_cache: KVCache | None = None,
        key_padding_mask: torch.BoolTensor | None = None,
    ) -> torch.FloatTensor:
        """Return ``dropout(attention(ln(x)) + mlp(ln(x))) + x``."""
        normed = self.layer_norm(x)
        attention_branch = self.multi_head_attention(
            normed,
            kv_cache=kv_cache,
            key_padding_mask=key_padding_mask,
        )
        mlp_branch = self.mlp(normed)
        return self.dropout(attention_branch + mlp_branch) + x
| |
|