# modeling_binaryllm.py — Wiki-Test
# Duplicated from PhysiQuanty/Patenty-Test2-Radix-65536 (commit 4c9ccfe).
import math
from dataclasses import dataclass
from typing import Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PreTrainedModel
from transformers.modeling_outputs import CausalLMOutput
from .configuration_binaryllm import BinaryLLMConfig
try:
import flash_attn_v100_cuda
_FLASH_V100_AVAILABLE = True
except Exception:
flash_attn_v100_cuda = None
_FLASH_V100_AVAILABLE = False
class PositionalEncoding(nn.Module):
    """
    Sinusoidal positional encoding (Vaswani et al., 2017).

    The table is precomputed in fp32 as a non-persistent buffer of shape
    (1, max_len, d_model) and cast to the dtype/device of the input at each
    forward pass. Even feature columns hold sin(pos / 10000^(2i/d_model)),
    odd columns the matching cosine.
    """

    def __init__(self, d_model: int, max_len: int) -> None:
        super().__init__()
        pe = torch.zeros(max_len, d_model, dtype=torch.float32)
        position = torch.arange(0, max_len, dtype=torch.float32).unsqueeze(1)
        # exp(-2i * ln(10000) / d) == 10000^(-2i/d); use the scalar math.log
        # instead of building a throwaway tensor for the constant.
        div_term = torch.exp(
            torch.arange(0, d_model, 2, dtype=torch.float32)
            * (-math.log(10000.0) / d_model)
        )
        angles = position * div_term
        pe[:, 0::2] = torch.sin(angles)
        # For odd d_model there is one fewer cosine column than sine column;
        # the slice is a no-op for even d_model.
        pe[:, 1::2] = torch.cos(angles)[:, : d_model // 2]
        self.register_buffer("pe", pe.unsqueeze(0), persistent=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Add positional encodings to x of shape (batch, seq, d_model)."""
        t = x.size(1)
        pe = self.pe[:, :t, :].to(device=x.device, dtype=x.dtype)
        return x + pe
@dataclass
class _InnerCfg:
    """Internal hyperparameter bundle used to build TinyTransformerLM.

    Mirrors the fields the inner model needs from BinaryLLMConfig, decoupled
    from the transformers config object.
    """

    block_size: int  # maximum sequence length (size of positional / causal tables)
    embed_dim: int  # model / embedding width
    vocab_size: int  # number of token ids
    num_heads: int  # attention heads per layer
    num_layers: int  # number of transformer encoder layers
    ff_hidden_dim: int  # feed-forward hidden width
    dropout: float  # dropout probability used throughout
    layernorm_dim: Optional[int] = None  # width of the final LayerNorm; None -> embed_dim
    head_dim: Optional[int] = None  # input width of the LM head; None -> layernorm width
    attn_backend: str = "auto"  # "auto" | "v100" | "sdpa" | "eager"
class FlashSelfAttentionPortable(nn.Module):
def __init__(
self,
embed_dim: int,
num_heads: int,
dropout: float = 0.0,
causal: bool = True,
backend: str = "auto",
) -> None:
super().__init__()
if embed_dim % num_heads != 0:
raise ValueError(
f"embed_dim ({embed_dim}) doit être divisible par num_heads ({num_heads})"
)
self.embed_dim = embed_dim
self.num_heads = num_heads
self.head_dim = embed_dim // num_heads
self.dropout = float(dropout)
self.causal = bool(causal)
self.backend = str(backend)
self.softmax_scale = self.head_dim ** -0.5
self.q_proj = nn.Linear(embed_dim, embed_dim, bias=True)
self.k_proj = nn.Linear(embed_dim, embed_dim, bias=True)
self.v_proj = nn.Linear(embed_dim, embed_dim, bias=True)
self.out_proj = nn.Linear(embed_dim, embed_dim, bias=True)
def _shape_qkv(
self,
x: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.dtype]:
bsz, seqlen, _ = x.shape
residual_dtype = x.dtype
proj_dtype = self.q_proj.weight.dtype
if x.dtype != proj_dtype:
x = x.to(proj_dtype)
q = self.q_proj(x)
k = self.k_proj(x)
v = self.v_proj(x)
q = q.view(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
k = k.view(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
v = v.view(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
return q, k, v, residual_dtype
def _merge_heads(self, x: torch.Tensor) -> torch.Tensor:
bsz, nheads, seqlen, head_dim = x.shape
return x.transpose(1, 2).contiguous().view(bsz, seqlen, nheads * head_dim)
def _can_use_v100_kernel(self, q: torch.Tensor, padding_mask: Optional[torch.Tensor]) -> bool:
if not _FLASH_V100_AVAILABLE:
return False
if not q.is_cuda:
return False
if padding_mask is not None and bool(padding_mask.any().item()):
return False
cc = torch.cuda.get_device_capability(q.device)
if cc != (7, 0):
return False
hd = q.size(-1)
if hd % 2 != 0:
return False
if hd % 8 != 0:
return False
if hd > 256:
return False
return True
def _flash_attn_v100(
self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
) -> torch.Tensor:
if q.dtype != torch.float16:
q = q.to(torch.float16)
if k.dtype != torch.float16:
k = k.to(torch.float16)
if v.dtype != torch.float16:
v = v.to(torch.float16)
result = flash_attn_v100_cuda.fwd(
q,
k,
v,
None,
None,
0.0,
self.softmax_scale,
self.causal,
-1,
-1,
0.0,
False,
None,
)
out = result[0]
return out
def _sdpa_attn(
self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
padding_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
bsz, nheads, tq, _ = q.shape
tk = k.size(-2)
attn_mask = None
if padding_mask is not None:
key_mask = padding_mask[:, None, None, :].to(device=q.device, dtype=torch.bool)
key_mask = key_mask.expand(bsz, nheads, tq, tk)
attn_mask = ~key_mask
dropout_p = self.dropout if self.training else 0.0
with torch.backends.cuda.sdp_kernel(
enable_flash=True,
enable_mem_efficient=True,
enable_math=True,
):
out = F.scaled_dot_product_attention(
q,
k,
v,
attn_mask=attn_mask,
dropout_p=dropout_p,
is_causal=self.causal if attn_mask is None else False,
scale=self.softmax_scale,
)
return out
def _eager_attn(
self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
padding_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
scores = torch.matmul(q.float(), k.float().transpose(-2, -1)) * self.softmax_scale
if self.causal:
tq = q.size(-2)
tk = k.size(-2)
causal_mask = torch.triu(
torch.ones(tq, tk, device=scores.device, dtype=torch.bool),
diagonal=1,
)
scores = scores.masked_fill(causal_mask.unsqueeze(0).unsqueeze(0), float("-inf"))
if padding_mask is not None:
key_mask = padding_mask[:, None, None, :].to(device=scores.device, dtype=torch.bool)
scores = scores.masked_fill(key_mask, float("-inf"))
probs = torch.softmax(scores, dim=-1)
if self.training and self.dropout > 0.0:
probs = F.dropout(probs, p=self.dropout)
out = torch.matmul(probs, v.float())
return out.to(q.dtype)
def forward(
self,
x: torch.Tensor,
padding_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
q, k, v, residual_dtype = self._shape_qkv(x)
if padding_mask is not None:
padding_mask = padding_mask.to(device=x.device, dtype=torch.bool)
backend = self.backend
if backend == "v100":
if not self._can_use_v100_kernel(q, padding_mask):
raise RuntimeError(
"backend='v100' demandé mais indisponible "
"(flash_attn_v100_cuda absent, GPU non sm70/V100, padding présent, "
"ou head_dim incompatible)."
)
out = self._flash_attn_v100(q, k, v)
elif backend == "sdpa":
out = self._sdpa_attn(q, k, v, padding_mask=padding_mask)
elif backend == "eager":
out = self._eager_attn(q, k, v, padding_mask=padding_mask)
elif backend == "auto":
if self._can_use_v100_kernel(q, padding_mask):
out = self._flash_attn_v100(q, k, v)
else:
out = self._sdpa_attn(q, k, v, padding_mask=padding_mask)
else:
raise ValueError(f"backend d'attention non supporté: {backend}")
out = self._merge_heads(out)
out_proj_dtype = self.out_proj.weight.dtype
if out.dtype != out_proj_dtype:
out = out.to(out_proj_dtype)
out = self.out_proj(out)
if out.dtype != residual_dtype:
out = out.to(residual_dtype)
return out
class FlashTransformerEncoderLayerPortable(nn.Module):
    """
    Post-LayerNorm transformer encoder layer built around
    FlashSelfAttentionPortable (always causal).

    Layout: x = norm1(x + SA(x)); x = norm2(x + FFN(x)).
    Only batch_first=True inputs of shape (batch, seq, d_model) are supported.
    """

    def __init__(
        self,
        d_model: int,
        nhead: int,
        dim_feedforward: int,
        dropout: float = 0.1,
        activation: str = "gelu",
        batch_first: bool = True,
        attn_backend: str = "auto",
    ) -> None:
        super().__init__()
        if not batch_first:
            raise ValueError("Cette implémentation supporte batch_first=True uniquement.")
        # Submodules are registered in a fixed order (this keeps seeded
        # parameter initialisation reproducible).
        self.self_attn = FlashSelfAttentionPortable(
            embed_dim=d_model,
            num_heads=nhead,
            dropout=dropout,
            causal=True,
            backend=attn_backend,
        )
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.linear2 = nn.Linear(dim_feedforward, d_model)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        supported = {"gelu": F.gelu, "relu": F.relu}
        if activation not in supported:
            raise ValueError(f"activation non supportée: {activation}")
        self.activation = supported[activation]

    def _sa_block(
        self,
        x: torch.Tensor,
        src_key_padding_mask: Optional[torch.Tensor],
    ) -> torch.Tensor:
        # Self-attention followed by residual dropout.
        return self.dropout1(self.self_attn(x, padding_mask=src_key_padding_mask))

    def _ff_block(self, x: torch.Tensor) -> torch.Tensor:
        # Run the feed-forward network in the dtype of its weights, then cast
        # the result back to the residual dtype.
        ff_dtype = self.linear1.weight.dtype
        h = x.to(ff_dtype) if x.dtype != ff_dtype else x
        h = self.dropout2(self.linear2(self.dropout(self.activation(self.linear1(h)))))
        return h.to(x.dtype) if h.dtype != x.dtype else h

    def forward(
        self,
        src: torch.Tensor,
        src_mask: Optional[torch.Tensor] = None,
        src_key_padding_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        # src_mask is accepted for interface parity with nn.TransformerEncoderLayer
        # but unused: causality is enforced inside the attention module.
        out = self.norm1(src + self._sa_block(src, src_key_padding_mask))
        return self.norm2(out + self._ff_block(out))
class FlashTransformerEncoderPortable(nn.Module):
    """Stack of independently-parameterised FlashTransformerEncoderLayerPortable.

    Like nn.TransformerEncoder, `encoder_layer` is used only as a template:
    its hyperparameters (d_model, nhead, dim_feedforward, dropout, activation)
    are read off and fresh layers are instantiated; the template's own
    parameters are not reused.
    """

    def __init__(
        self,
        encoder_layer: FlashTransformerEncoderLayerPortable,
        num_layers: int,
        attn_backend: str = "auto",
    ) -> None:
        super().__init__()
        d_model = encoder_layer.norm1.normalized_shape[0]
        nhead = encoder_layer.self_attn.num_heads
        dim_feedforward = encoder_layer.linear1.out_features
        dropout = encoder_layer.dropout.p
        # Bugfix: honor the template's activation instead of hard-coding
        # "gelu" (a relu template used to silently produce gelu layers).
        activation = "relu" if encoder_layer.activation is F.relu else "gelu"
        self.layers = nn.ModuleList(
            [
                FlashTransformerEncoderLayerPortable(
                    d_model=d_model,
                    nhead=nhead,
                    dim_feedforward=dim_feedforward,
                    dropout=dropout,
                    activation=activation,
                    batch_first=True,
                    attn_backend=attn_backend,
                )
                for _ in range(num_layers)
            ]
        )

    def forward(
        self,
        src: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        src_key_padding_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Apply each layer in sequence; mask is forwarded (layers ignore it)."""
        x = src
        for layer in self.layers:
            x = layer(x, src_mask=mask, src_key_padding_mask=src_key_padding_mask)
        return x
class TinyTransformerLM(nn.Module):
    """Minimal decoder-only language model.

    Pipeline: token embedding -> sinusoidal positions -> causal transformer
    stack -> optional projection -> LayerNorm -> optional projection -> LM
    head. When no projections are inserted and head width equals embed_dim,
    the LM head shares its weight with the token embedding.
    """

    def __init__(self, cfg: _InnerCfg) -> None:
        super().__init__()
        self.cfg = cfg
        vocab_size = cfg.vocab_size
        # Submodule construction order mirrors attribute order below (keeps
        # seeded initialisation reproducible).
        self.tok_embed = nn.Embedding(vocab_size, cfg.embed_dim)
        self.pos_encoding = PositionalEncoding(cfg.embed_dim, cfg.block_size)
        template = FlashTransformerEncoderLayerPortable(
            d_model=cfg.embed_dim,
            nhead=cfg.num_heads,
            dim_feedforward=cfg.ff_hidden_dim,
            dropout=cfg.dropout,
            activation="gelu",
            batch_first=True,
            attn_backend=cfg.attn_backend,
        )
        self.encoder = FlashTransformerEncoderPortable(
            template,
            num_layers=cfg.num_layers,
            attn_backend=cfg.attn_backend,
        )
        ln_dim = cfg.layernorm_dim or cfg.embed_dim
        head_dim = cfg.head_dim or ln_dim
        self.pre_ln_proj: Optional[nn.Linear] = (
            nn.Linear(cfg.embed_dim, ln_dim) if ln_dim != cfg.embed_dim else None
        )
        self.ln = nn.LayerNorm(ln_dim)
        self.head_pre: Optional[nn.Linear] = (
            nn.Linear(ln_dim, head_dim) if head_dim != ln_dim else None
        )
        self.head = nn.Linear(head_dim, vocab_size, bias=False)
        # Weight tying is only valid when the head consumes raw embed_dim
        # features with no intermediate projections.
        if self.pre_ln_proj is None and self.head_pre is None and head_dim == cfg.embed_dim:
            self.head.weight = self.tok_embed.weight
        future = torch.triu(
            torch.ones(cfg.block_size, cfg.block_size, dtype=torch.bool), diagonal=1
        )
        self.register_buffer("causal_mask", future, persistent=False)

    def forward(self, tokens: torch.Tensor, padding_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Return logits of shape (batch, seq, vocab) for token ids (batch, seq).

        padding_mask: optional (batch, >=seq) bool tensor, True at PAD positions.
        """

        def _match(t: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
            # Cast t to the dtype of weight w when they differ.
            return t if t.dtype == w.dtype else t.to(w.dtype)

        h = self.pos_encoding(self.tok_embed(tokens))
        seq_len = tokens.size(1)
        attn_mask = self.causal_mask[:seq_len, :seq_len].to(device=tokens.device)
        if padding_mask is not None:
            padding_mask = padding_mask[:, :seq_len].to(device=tokens.device, dtype=torch.bool)
        h = self.encoder(h, mask=attn_mask, src_key_padding_mask=padding_mask)
        if self.pre_ln_proj is not None:
            h = self.pre_ln_proj(_match(h, self.pre_ln_proj.weight))
        h = self.ln(_match(h, self.ln.weight))
        if self.head_pre is not None:
            h = self.head_pre(_match(h, self.head_pre.weight))
        return self.head(_match(h, self.head.weight))
class BinaryLLMForCausalLM(PreTrainedModel):
    """transformers wrapper exposing TinyTransformerLM as a causal LM.

    `attention_mask` follows the HF convention (1 = real token, 0 = padding)
    and is inverted into the inner model's padding mask (True = PAD).
    """

    config_class = BinaryLLMConfig
    main_input_name = "input_ids"

    def __init__(self, config: BinaryLLMConfig):
        super().__init__(config)
        inner = _InnerCfg(
            block_size=int(config.max_position_embeddings),
            embed_dim=int(config.hidden_size),
            vocab_size=int(config.vocab_size),
            num_heads=int(config.num_attention_heads),
            num_layers=int(config.num_hidden_layers),
            ff_hidden_dim=int(config.intermediate_size),
            dropout=float(getattr(config, "dropout", 0.0)),
            layernorm_dim=None,
            head_dim=None,
            attn_backend=str(getattr(config, "attn_backend", "auto")),
        )
        self.model = TinyTransformerLM(inner)
        self.post_init()

    def get_input_embeddings(self) -> nn.Module:
        return self.model.tok_embed

    def set_input_embeddings(self, value: nn.Module) -> None:
        self.model.tok_embed = value

    def get_output_embeddings(self) -> nn.Module:
        return self.model.head

    def set_output_embeddings(self, new_embeddings: nn.Module) -> None:
        self.model.head = new_embeddings

    def prepare_inputs_for_generation(
        self,
        input_ids: torch.LongTensor,
        attention_mask: Optional[torch.Tensor] = None,
        **kwargs,
    ):
        # No KV cache: the full sequence is re-run at every generation step.
        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
        }

    def forward(
        self,
        input_ids: torch.LongTensor,
        attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> CausalLMOutput:
        """Compute logits and, when labels are given, the shifted LM loss."""
        pad_mask = None if attention_mask is None else ~attention_mask.to(torch.bool)
        logits = self.model(input_ids, padding_mask=pad_mask)
        loss = None
        if labels is not None:
            # Standard next-token objective: logits at position t predict the
            # label at t+1; positions labelled -100 are ignored.
            loss = F.cross_entropy(
                logits[:, :-1, :].reshape(-1, self.config.vocab_size),
                labels[:, 1:].reshape(-1),
                ignore_index=-100,
            )
        return CausalLMOutput(loss=loss, logits=logits)