| | import math |
| | from dataclasses import dataclass |
| | from typing import Optional |
| |
|
| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| |
|
| | from transformers import PreTrainedModel |
| | from transformers.modeling_outputs import CausalLMOutput |
| |
|
| | from .configuration_binaryllm import BinaryLLMConfig |
| |
|
| | try: |
| | import flash_attn_v100_cuda |
| | _FLASH_V100_AVAILABLE = True |
| | except Exception: |
| | flash_attn_v100_cuda = None |
| | _FLASH_V100_AVAILABLE = False |
| |
|
| |
|
class PositionalEncoding(nn.Module):
    """Fixed sinusoidal positional encoding added to a batch of embeddings.

    The table is built once in float32 and registered as a non-persistent
    buffer (it is not written to the state dict). Each forward pass slices it
    to the input length and casts it to the input's device and dtype.

    Args:
        d_model: Width of the embeddings. Odd widths are supported.
        max_len: Number of positions to precompute.
    """

    def __init__(self, d_model: int, max_len: int) -> None:
        super().__init__()
        pe = torch.zeros(max_len, d_model, dtype=torch.float32)
        position = torch.arange(0, max_len, dtype=torch.float32).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2, dtype=torch.float32)
            * (-torch.log(torch.tensor(10000.0)) / d_model)
        )
        angles = position * div_term
        pe[:, 0::2] = torch.sin(angles)
        # When d_model is odd there is one fewer cosine column than sine
        # columns; slicing keeps the assignment shape-correct and is a no-op
        # for even widths.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        self.register_buffer("pe", pe.unsqueeze(0), persistent=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` plus the encoding of its first ``x.size(1)`` positions.

        Args:
            x: Tensor whose dimension 1 is the sequence axis and whose last
                dimension is ``d_model``.

        Raises:
            ValueError: If the sequence is longer than the precomputed table.
        """
        t = x.size(1)
        max_len = self.pe.size(1)
        if t > max_len:
            # Fail with an explicit message instead of an opaque broadcast
            # error (or a silent broadcast when the table has one row).
            raise ValueError(
                f"sequence length {t} exceeds the maximum supported length {max_len}"
            )
        pe = self.pe[:, :t, :].to(device=x.device, dtype=x.dtype)
        return x + pe
| |
|
| |
|
| | @dataclass |
| | class _InnerCfg: |
| | block_size: int |
| | embed_dim: int |
| | vocab_size: int |
| | num_heads: int |
| | num_layers: int |
| | ff_hidden_dim: int |
| | dropout: float |
| | layernorm_dim: Optional[int] = None |
| | head_dim: Optional[int] = None |
| | attn_backend: str = "auto" |
| |
|
| |
|
| | class FlashSelfAttentionPortable(nn.Module): |
| | def __init__( |
| | self, |
| | embed_dim: int, |
| | num_heads: int, |
| | dropout: float = 0.0, |
| | causal: bool = True, |
| | backend: str = "auto", |
| | ) -> None: |
| | super().__init__() |
| |
|
| | if embed_dim % num_heads != 0: |
| | raise ValueError( |
| | f"embed_dim ({embed_dim}) doit être divisible par num_heads ({num_heads})" |
| | ) |
| |
|
| | self.embed_dim = embed_dim |
| | self.num_heads = num_heads |
| | self.head_dim = embed_dim // num_heads |
| | self.dropout = float(dropout) |
| | self.causal = bool(causal) |
| | self.backend = str(backend) |
| | self.softmax_scale = self.head_dim ** -0.5 |
| |
|
| | self.q_proj = nn.Linear(embed_dim, embed_dim, bias=True) |
| | self.k_proj = nn.Linear(embed_dim, embed_dim, bias=True) |
| | self.v_proj = nn.Linear(embed_dim, embed_dim, bias=True) |
| | self.out_proj = nn.Linear(embed_dim, embed_dim, bias=True) |
| |
|
| | def _shape_qkv( |
| | self, |
| | x: torch.Tensor, |
| | ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.dtype]: |
| | bsz, seqlen, _ = x.shape |
| | residual_dtype = x.dtype |
| |
|
| | proj_dtype = self.q_proj.weight.dtype |
| | if x.dtype != proj_dtype: |
| | x = x.to(proj_dtype) |
| |
|
| | q = self.q_proj(x) |
| | k = self.k_proj(x) |
| | v = self.v_proj(x) |
| |
|
| | q = q.view(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2).contiguous() |
| | k = k.view(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2).contiguous() |
| | v = v.view(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2).contiguous() |
| |
|
| | return q, k, v, residual_dtype |
| |
|
| | def _merge_heads(self, x: torch.Tensor) -> torch.Tensor: |
| | bsz, nheads, seqlen, head_dim = x.shape |
| | return x.transpose(1, 2).contiguous().view(bsz, seqlen, nheads * head_dim) |
| |
|
| | def _can_use_v100_kernel(self, q: torch.Tensor, padding_mask: Optional[torch.Tensor]) -> bool: |
| | if not _FLASH_V100_AVAILABLE: |
| | return False |
| |
|
| | if not q.is_cuda: |
| | return False |
| |
|
| | if padding_mask is not None and bool(padding_mask.any().item()): |
| | return False |
| |
|
| | cc = torch.cuda.get_device_capability(q.device) |
| | if cc != (7, 0): |
| | return False |
| |
|
| | hd = q.size(-1) |
| | if hd % 2 != 0: |
| | return False |
| | if hd % 8 != 0: |
| | return False |
| | if hd > 256: |
| | return False |
| |
|
| | return True |
| |
|
| | def _flash_attn_v100( |
| | self, |
| | q: torch.Tensor, |
| | k: torch.Tensor, |
| | v: torch.Tensor, |
| | ) -> torch.Tensor: |
| | if q.dtype != torch.float16: |
| | q = q.to(torch.float16) |
| | if k.dtype != torch.float16: |
| | k = k.to(torch.float16) |
| | if v.dtype != torch.float16: |
| | v = v.to(torch.float16) |
| |
|
| | result = flash_attn_v100_cuda.fwd( |
| | q, |
| | k, |
| | v, |
| | None, |
| | None, |
| | 0.0, |
| | self.softmax_scale, |
| | self.causal, |
| | -1, |
| | -1, |
| | 0.0, |
| | False, |
| | None, |
| | ) |
| |
|
| | out = result[0] |
| | return out |
| |
|
| | def _sdpa_attn( |
| | self, |
| | q: torch.Tensor, |
| | k: torch.Tensor, |
| | v: torch.Tensor, |
| | padding_mask: Optional[torch.Tensor] = None, |
| | ) -> torch.Tensor: |
| | bsz, nheads, tq, _ = q.shape |
| | tk = k.size(-2) |
| |
|
| | attn_mask = None |
| | if padding_mask is not None: |
| | key_mask = padding_mask[:, None, None, :].to(device=q.device, dtype=torch.bool) |
| | key_mask = key_mask.expand(bsz, nheads, tq, tk) |
| | attn_mask = ~key_mask |
| |
|
| | dropout_p = self.dropout if self.training else 0.0 |
| |
|
| | with torch.backends.cuda.sdp_kernel( |
| | enable_flash=True, |
| | enable_mem_efficient=True, |
| | enable_math=True, |
| | ): |
| | out = F.scaled_dot_product_attention( |
| | q, |
| | k, |
| | v, |
| | attn_mask=attn_mask, |
| | dropout_p=dropout_p, |
| | is_causal=self.causal if attn_mask is None else False, |
| | scale=self.softmax_scale, |
| | ) |
| |
|
| | return out |
| |
|
| | def _eager_attn( |
| | self, |
| | q: torch.Tensor, |
| | k: torch.Tensor, |
| | v: torch.Tensor, |
| | padding_mask: Optional[torch.Tensor] = None, |
| | ) -> torch.Tensor: |
| | scores = torch.matmul(q.float(), k.float().transpose(-2, -1)) * self.softmax_scale |
| |
|
| | if self.causal: |
| | tq = q.size(-2) |
| | tk = k.size(-2) |
| | causal_mask = torch.triu( |
| | torch.ones(tq, tk, device=scores.device, dtype=torch.bool), |
| | diagonal=1, |
| | ) |
| | scores = scores.masked_fill(causal_mask.unsqueeze(0).unsqueeze(0), float("-inf")) |
| |
|
| | if padding_mask is not None: |
| | key_mask = padding_mask[:, None, None, :].to(device=scores.device, dtype=torch.bool) |
| | scores = scores.masked_fill(key_mask, float("-inf")) |
| |
|
| | probs = torch.softmax(scores, dim=-1) |
| |
|
| | if self.training and self.dropout > 0.0: |
| | probs = F.dropout(probs, p=self.dropout) |
| |
|
| | out = torch.matmul(probs, v.float()) |
| | return out.to(q.dtype) |
| |
|
| | def forward( |
| | self, |
| | x: torch.Tensor, |
| | padding_mask: Optional[torch.Tensor] = None, |
| | ) -> torch.Tensor: |
| | q, k, v, residual_dtype = self._shape_qkv(x) |
| |
|
| | if padding_mask is not None: |
| | padding_mask = padding_mask.to(device=x.device, dtype=torch.bool) |
| |
|
| | backend = self.backend |
| |
|
| | if backend == "v100": |
| | if not self._can_use_v100_kernel(q, padding_mask): |
| | raise RuntimeError( |
| | "backend='v100' demandé mais indisponible " |
| | "(flash_attn_v100_cuda absent, GPU non sm70/V100, padding présent, " |
| | "ou head_dim incompatible)." |
| | ) |
| | out = self._flash_attn_v100(q, k, v) |
| |
|
| | elif backend == "sdpa": |
| | out = self._sdpa_attn(q, k, v, padding_mask=padding_mask) |
| |
|
| | elif backend == "eager": |
| | out = self._eager_attn(q, k, v, padding_mask=padding_mask) |
| |
|
| | elif backend == "auto": |
| | if self._can_use_v100_kernel(q, padding_mask): |
| | out = self._flash_attn_v100(q, k, v) |
| | else: |
| | out = self._sdpa_attn(q, k, v, padding_mask=padding_mask) |
| |
|
| | else: |
| | raise ValueError(f"backend d'attention non supporté: {backend}") |
| |
|
| | out = self._merge_heads(out) |
| |
|
| | out_proj_dtype = self.out_proj.weight.dtype |
| | if out.dtype != out_proj_dtype: |
| | out = out.to(out_proj_dtype) |
| |
|
| | out = self.out_proj(out) |
| |
|
| | if out.dtype != residual_dtype: |
| | out = out.to(residual_dtype) |
| |
|
| | return out |
| |
|
| |
|
class FlashTransformerEncoderLayerPortable(nn.Module):
    """One post-LayerNorm transformer block: causal self-attention, then MLP.

    Each sub-block's output is added to its input and the sum is normalised
    (``norm(x + sublayer(x))``).

    Args:
        d_model: Width of the residual stream.
        nhead: Number of attention heads.
        dim_feedforward: Hidden width of the feed-forward block.
        dropout: Probability used by the attention module and by the three
            dropout layers of this block.
        activation: ``"gelu"`` or ``"relu"``.
        batch_first: Must be True; kept for signature compatibility.
        attn_backend: Backend name forwarded to the attention module.

    Raises:
        ValueError: If ``batch_first`` is False or ``activation`` is unknown.
    """

    def __init__(
        self,
        d_model: int,
        nhead: int,
        dim_feedforward: int,
        dropout: float = 0.1,
        activation: str = "gelu",
        batch_first: bool = True,
        attn_backend: str = "auto",
    ) -> None:
        super().__init__()

        if not batch_first:
            raise ValueError("Cette implémentation supporte batch_first=True uniquement.")

        # Attention sub-block (always causal in this layer).
        self.self_attn = FlashSelfAttentionPortable(
            embed_dim=d_model,
            num_heads=nhead,
            dropout=dropout,
            causal=True,
            backend=attn_backend,
        )

        # Feed-forward sub-block.
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        # One LayerNorm after each residual addition.
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

        # `dropout` sits inside the MLP, `dropout1` after attention and
        # `dropout2` after the MLP output projection.
        self.dropout = nn.Dropout(dropout)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

        for known_name, fn in (("gelu", F.gelu), ("relu", F.relu)):
            if activation == known_name:
                self.activation = fn
                break
        else:
            raise ValueError(f"activation non supportée: {activation}")

    def _sa_block(
        self,
        x: torch.Tensor,
        src_key_padding_mask: Optional[torch.Tensor],
    ) -> torch.Tensor:
        """Self-attention followed by its dropout."""
        return self.dropout1(self.self_attn(x, padding_mask=src_key_padding_mask))

    def _ff_block(self, x: torch.Tensor) -> torch.Tensor:
        """Feed-forward sub-block, computed in the dtype of ``linear1``'s weights.

        The result is cast back to the dtype ``x`` had on entry.
        """
        target_dtype = self.linear1.weight.dtype
        hidden = x.to(target_dtype) if x.dtype != target_dtype else x

        hidden = self.dropout(self.activation(self.linear1(hidden)))
        hidden = self.dropout2(self.linear2(hidden))

        return hidden if hidden.dtype == x.dtype else hidden.to(x.dtype)

    def forward(
        self,
        src: torch.Tensor,
        src_mask: Optional[torch.Tensor] = None,
        src_key_padding_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run the block on ``src``.

        ``src_mask`` is accepted for signature compatibility but is not read
        here; ``src_key_padding_mask`` is forwarded to the attention module.
        """
        attended = self.norm1(src + self._sa_block(src, src_key_padding_mask))
        return self.norm2(attended + self._ff_block(attended))
| |
|
| |
|
class FlashTransformerEncoderPortable(nn.Module):
    """Stack of ``num_layers`` freshly built encoder layers.

    ``encoder_layer`` serves only as a template: its hyper-parameters (model
    width, head count, feed-forward width, dropout probability and activation)
    are read and new, independently initialised layers are created from them.
    The template itself is not added to the stack.

    Args:
        encoder_layer: Template layer to read hyper-parameters from.
        num_layers: Number of layers to build.
        attn_backend: Attention backend name given to every built layer.
    """

    def __init__(
        self,
        encoder_layer: FlashTransformerEncoderLayerPortable,
        num_layers: int,
        attn_backend: str = "auto",
    ) -> None:
        super().__init__()

        d_model = encoder_layer.norm1.normalized_shape[0]
        nhead = encoder_layer.self_attn.num_heads
        dim_feedforward = encoder_layer.linear1.out_features
        dropout = encoder_layer.dropout.p
        # Carry the template's activation over instead of always using GELU,
        # so a ReLU template yields ReLU layers. Anything else (including a
        # template without the attribute) keeps the previous GELU behaviour.
        template_activation = getattr(encoder_layer, "activation", None)
        activation = "relu" if template_activation is F.relu else "gelu"

        self.layers = nn.ModuleList(
            [
                FlashTransformerEncoderLayerPortable(
                    d_model=d_model,
                    nhead=nhead,
                    dim_feedforward=dim_feedforward,
                    dropout=dropout,
                    activation=activation,
                    batch_first=True,
                    attn_backend=attn_backend,
                )
                for _ in range(num_layers)
            ]
        )

    def forward(
        self,
        src: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        src_key_padding_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Feed ``src`` through every layer in order and return the result.

        ``mask`` and ``src_key_padding_mask`` are passed unchanged to each
        layer as ``src_mask`` and ``src_key_padding_mask``.
        """
        x = src
        for layer in self.layers:
            x = layer(x, src_mask=mask, src_key_padding_mask=src_key_padding_mask)
        return x
| |
|
| |
|
class TinyTransformerLM(nn.Module):
    """Small causal language model: embeddings, encoder stack, norm, LM head.

    Pipeline: token embedding -> sinusoidal positions -> encoder stack ->
    optional projection to the LayerNorm width -> LayerNorm -> optional
    projection to the head width -> vocabulary projection (no bias).

    When neither optional projection exists and the head width equals the
    embedding width, the output projection shares its weight with the token
    embedding.

    Args:
        cfg: Hyper-parameters for the model.
    """

    def __init__(self, cfg: _InnerCfg) -> None:
        super().__init__()
        self.cfg = cfg

        vocab_size = cfg.vocab_size
        self.tok_embed = nn.Embedding(vocab_size, cfg.embed_dim)
        self.pos_encoding = PositionalEncoding(cfg.embed_dim, cfg.block_size)

        # The template layer only supplies hyper-parameters to the stack.
        template_layer = FlashTransformerEncoderLayerPortable(
            d_model=cfg.embed_dim,
            nhead=cfg.num_heads,
            dim_feedforward=cfg.ff_hidden_dim,
            dropout=cfg.dropout,
            activation="gelu",
            batch_first=True,
            attn_backend=cfg.attn_backend,
        )
        self.encoder = FlashTransformerEncoderPortable(
            template_layer,
            num_layers=cfg.num_layers,
            attn_backend=cfg.attn_backend,
        )

        # Falsy overrides fall back to the previous stage's width.
        ln_dim = cfg.layernorm_dim or cfg.embed_dim
        head_dim = cfg.head_dim or ln_dim

        self.pre_ln_proj: Optional[nn.Linear] = (
            nn.Linear(cfg.embed_dim, ln_dim) if ln_dim != cfg.embed_dim else None
        )

        self.ln = nn.LayerNorm(ln_dim)

        self.head_pre: Optional[nn.Linear] = (
            nn.Linear(ln_dim, head_dim) if head_dim != ln_dim else None
        )

        self.head = nn.Linear(head_dim, vocab_size, bias=False)

        # Tie input and output embeddings when the widths line up.
        has_adapters = self.pre_ln_proj is not None or self.head_pre is not None
        if not has_adapters and head_dim == cfg.embed_dim:
            self.head.weight = self.tok_embed.weight

        # True strictly above the diagonal; kept out of the state dict.
        upper = torch.triu(
            torch.ones(cfg.block_size, cfg.block_size, dtype=torch.bool),
            diagonal=1,
        )
        self.register_buffer("causal_mask", upper, persistent=False)

    @staticmethod
    def _as_dtype_of(x: torch.Tensor, module: nn.Module) -> torch.Tensor:
        """Return ``x`` cast to the dtype of ``module.weight`` (no-op if equal)."""
        target = module.weight.dtype
        return x if x.dtype == target else x.to(target)

    def forward(self, tokens: torch.Tensor, padding_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Compute next-token logits.

        Args:
            tokens: Integer token ids; dimension 1 is the sequence axis.
            padding_mask: Optional mask, cropped to the sequence length and
                cast to bool before being handed to the encoder as
                ``src_key_padding_mask``.

        Returns:
            The output of the vocabulary projection for every position.
        """
        hidden = self.pos_encoding(self.tok_embed(tokens))

        seq_len = tokens.size(1)
        device = tokens.device
        # Forwarded as `mask`; the encoder layers in this file accept it but
        # do not read it (causality is applied inside the attention module).
        attn_mask = self.causal_mask[:seq_len, :seq_len].to(device=device)

        if padding_mask is not None:
            padding_mask = padding_mask[:, :seq_len].to(device=device, dtype=torch.bool)

        hidden = self.encoder(hidden, mask=attn_mask, src_key_padding_mask=padding_mask)

        # Each stage runs in the dtype of its own weights.
        if self.pre_ln_proj is not None:
            hidden = self.pre_ln_proj(self._as_dtype_of(hidden, self.pre_ln_proj))

        hidden = self.ln(self._as_dtype_of(hidden, self.ln))

        if self.head_pre is not None:
            hidden = self.head_pre(self._as_dtype_of(hidden, self.head_pre))

        return self.head(self._as_dtype_of(hidden, self.head))
| |
|
| |
|
class BinaryLLMForCausalLM(PreTrainedModel):
    """Hugging Face ``PreTrainedModel`` wrapper around ``TinyTransformerLM``.

    Translates a ``BinaryLLMConfig`` into the inner model's configuration and
    exposes the embedding accessors, a generation hook and a ``forward`` that
    optionally computes the shifted next-token cross-entropy loss.
    """

    config_class = BinaryLLMConfig
    main_input_name = "input_ids"

    def __init__(self, config: BinaryLLMConfig):
        super().__init__(config)

        # `attn_backend` and `dropout` are optional on the config.
        self.model = TinyTransformerLM(
            _InnerCfg(
                attn_backend=str(getattr(config, "attn_backend", "auto")),
                block_size=int(config.max_position_embeddings),
                embed_dim=int(config.hidden_size),
                vocab_size=int(config.vocab_size),
                num_heads=int(config.num_attention_heads),
                num_layers=int(config.num_hidden_layers),
                ff_hidden_dim=int(config.intermediate_size),
                dropout=float(getattr(config, "dropout", 0.0)),
                layernorm_dim=None,
                head_dim=None,
            )
        )

        self.post_init()

    def get_input_embeddings(self) -> nn.Module:
        """Return the token embedding module."""
        return self.model.tok_embed

    def set_input_embeddings(self, value: nn.Module) -> None:
        """Replace the token embedding module."""
        self.model.tok_embed = value

    def get_output_embeddings(self) -> nn.Module:
        """Return the vocabulary projection module."""
        return self.model.head

    def set_output_embeddings(self, new_embeddings: nn.Module) -> None:
        """Replace the vocabulary projection module."""
        self.model.head = new_embeddings

    def prepare_inputs_for_generation(
        self,
        input_ids: torch.LongTensor,
        attention_mask: Optional[torch.Tensor] = None,
        **kwargs,
    ):
        """Return the ``forward`` inputs for one generation step.

        Only ``input_ids`` and ``attention_mask`` are passed on; any other
        keyword argument is dropped.
        """
        return dict(input_ids=input_ids, attention_mask=attention_mask)

    def forward(
        self,
        input_ids: torch.LongTensor,
        attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> CausalLMOutput:
        """Compute logits and, when ``labels`` is given, the LM loss.

        Args:
            input_ids: Token ids.
            attention_mask: Optional mask that is truthy for real tokens; its
                boolean negation is handed to the inner model as the padding
                mask.
            labels: Optional targets. Logits at position ``i`` are scored
                against ``labels[:, i + 1]``; entries equal to -100 are
                ignored.
            **kwargs: Accepted and ignored.

        Returns:
            ``CausalLMOutput`` with ``logits`` and ``loss`` (``None`` when no
            labels are supplied).
        """
        padding_mask = None if attention_mask is None else ~attention_mask.to(torch.bool)

        logits = self.model(input_ids, padding_mask=padding_mask)

        if labels is None:
            return CausalLMOutput(loss=None, logits=logits)

        # Shift so that each position predicts the following token.
        predictions = logits[:, :-1, :].contiguous()
        targets = labels[:, 1:].contiguous()
        loss = F.cross_entropy(
            predictions.view(-1, self.config.vocab_size),
            targets.view(-1),
            ignore_index=-100,
        )
        return CausalLMOutput(loss=loss, logits=logits)
| |
|