# YatNMN-Softplus + scalar_bias, d=12 Chinchilla (261M) – PyTorch / HuggingFace Transformers
A 261M-parameter nanochat-architecture GPT with the YatNMN-Softplus MLP using scalar_bias=True: a single shared (1,) bias across all neurons (matching the clean theoretical formulation of YatNMN). Trained in JAX/Flax on TPU v6e-8 to the Chinchilla-optimal token budget on C4, then ported to PyTorch for easy inference via the HuggingFace transformers API.
This is the scalar_bias ablation in the d=12 / 261M / Chinchilla series:
| MLP variant | Final smooth loss | vs GELU |
|---|---|---|
| YatNMN-Softplus (per-neuron bias) | 2.98 | −0.13 |
| YatNMN-Softplus + scalar_bias (this model) | 3.06 | −0.05 |
| GELU | 3.11 | baseline |
Weights are bit-exact with the Flax checkpoint (mlnomad/yatnmn-softplus-sb-d12-chinchilla-261M); parity was validated at max |Δ logits| = 1.6e-5 on CPU/fp32.
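The parity figure above is an elementwise comparison of fp32 logits. A minimal sketch of that metric, using synthetic tensors rather than the actual checkpoint outputs:

```python
import torch

def max_logit_delta(logits_a: torch.Tensor, logits_b: torch.Tensor) -> float:
    """Largest absolute elementwise difference between two logit tensors."""
    return (logits_a.float() - logits_b.float()).abs().max().item()

# Synthetic stand-ins for the Flax and PyTorch forward-pass outputs.
ref = torch.randn(1, 16, 32768)
port = ref + 1e-5 * torch.randn_like(ref)
print(f"max |delta logits| = {max_logit_delta(ref, port):.1e}")
```

In the real check, both models were run on the same token batch in fp32 on CPU before comparing.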
## Quick start
```bash
pip install torch transformers safetensors
```
```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained(
    "mlnomad/yatnmn-softplus-sb-d12-chinchilla-261M-pytorch",
    trust_remote_code=True,
    dtype=torch.float32,
).eval()
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")

prompt = "The meaning of life is"
ids = tokenizer(prompt, return_tensors="pt").input_ids
with torch.no_grad():
    out = model.generate(
        ids, max_new_tokens=50,
        do_sample=True, temperature=0.8, top_p=0.9,
        use_cache=True, pad_token_id=tokenizer.eos_token_id or 0,
    )
print(tokenizer.decode(out[0], skip_special_tokens=True))
```
## YatNMN-Softplus + scalar_bias
Each MLP block uses the YatNMN nonlinearity from `nmn>=0.2.29`:
```
y = α · (x·W + softplus(b))² / (‖x − W‖² + softplus(ε))
```

- `b`, shape `(1,)`: a single shared bias across all `4·n_embd = 3072` neurons (`scalar_bias=True`)
- `ε`, shape `(1,)`: a single learnable epsilon, kept positive via softplus
- `α`, shape `(1,)`: a single learnable scalar, applied as a final gain
- bias and epsilon are passed through softplus to remain strictly positive
- followed by `c_proj` (Linear → 768) on top of YatNMN's output
The shared scalar bias is theoretically motivated: YatNMN's score (x·W + b)² in the numerator and ‖x − W‖² in the denominator are both per-neuron quantities, and a scalar bias preserves the per-neuron geometric interpretation cleanly, whereas a per-neuron bias (one bias per hidden neuron) adds extra parameters that break the symmetry. Empirically the per-neuron variant beats this scalar one by 0.08 nats at d=12 (those extra parameters do help in practice), but scalar_bias=True is closer to the YatNMN definition.
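As a concrete rendering of the formula above, here is a minimal PyTorch sketch of the MLP block. This is a hypothetical reimplementation from the description in this card (the repo's actual module lives in yatnmn_gpt.py); the initialization is illustrative, and the per-neuron ‖x − w_j‖² is computed via the expansion ‖x‖² − 2·x·w_j + ‖w_j‖².

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class YatNMNSoftplusMLP(nn.Module):
    """Sketch of the YatNMN-Softplus MLP with scalar_bias=True."""

    def __init__(self, n_embd: int = 768, expansion: int = 4):
        super().__init__()
        hidden = expansion * n_embd                   # 3072 neurons at n_embd=768
        self.W = nn.Parameter(torch.randn(n_embd, hidden) * n_embd ** -0.5)
        self.b = nn.Parameter(torch.zeros(1))         # single shared bias (scalar_bias=True)
        self.eps = nn.Parameter(torch.zeros(1))       # single learnable epsilon
        self.alpha = nn.Parameter(torch.ones(1))      # single learnable final gain
        self.c_proj = nn.Linear(hidden, n_embd, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        score = x @ self.W + F.softplus(self.b)       # (..., hidden)
        # Per-neuron squared distance ||x - w_j||^2, expanded for efficiency.
        dist2 = (x * x).sum(-1, keepdim=True) - 2 * (x @ self.W) \
                + (self.W * self.W).sum(0)
        y = self.alpha * score.pow(2) / (dist2 + F.softplus(self.eps))
        return self.c_proj(y)
```

Note that softplus(0) = log 2 ≈ 0.69, so with the zero initialization above the denominator stays safely positive from step one.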
## Model details
| | |
|---|---|
| Parameters | 261,096,374 |
| Architecture | Nanochat-style GPT with YatNMN-Softplus + scalar_bias MLP (ported from JAX/Flax NNX) |
| Config | d=12, n_embd=768, n_head=12, n_kv_head=12, seq_len=1024, tied embeddings, SSSL sliding window |
| Training data | allenai/c4 (English split), 5.22 B tokens (Chinchilla 20×) |
| Tokenizer | mistralai/Mistral-7B-v0.1 (vocab 32,768) |
| Optimizer | plain AdamW, peak LR 0.03, warmup-cosine schedule |
| Hardware | TPU v6e-8 (TRC), europe-west4-a |
| Final loss (smooth) | 3.06 |
## Architecture features
Full nanochat stack, faithfully ported to PyTorch:
- YatNMN-Softplus + scalar_bias MLP (shared `(1,)` bias, softplus-positive, learnable α and ε)
- RoPE (base 100,000), split-half layout
- MHA (n_head = n_kv_head = 12; the code supports GQA via n_kv_head < n_head, but all d=12 models use full MHA)
- QK-norm with 1.2× scaling (after RoPE)
- Parameterless RMSNorm (no learnable gain) post-embedding and per block
- Sliding-window attention with `"SSSL"` pattern
- Tied embeddings (lm_head = wte.T)
- Value embeddings on alternating layers (ResFormer-style)
- Per-layer learnable residual scalars (`resid_lambdas`, `x0_lambdas`)
- Smear: a learnable gate on the first 24 dims of the token embedding mixes in the previous token
- Backout: the mid-layer residual is subtracted from late layers
- Logit soft-cap: `15 · tanh(logits / 15)`
- No biases in any Linear
## KV cache
The `YatGPTForCausalLM` class implements a smear-aware KV cache for fast autoregressive generation. Pass `use_cache=True` (the default for `.generate()`).
## Files in this repo
```
.
├── config.json                  # HF config with auto_map → the classes below
├── generation_config.json
├── model.safetensors            # ~1.04 GB, fp32 weights + persistent RoPE buffers
├── yatnmn_gpt.py                # pure PyTorch Yat_GPT module + YatNMN layer
├── torch_gpt.py                 # shared building blocks (RMSNorm, RoPE, attention)
├── configuration_yatnmn_gpt.py  # PretrainedConfig subclass
├── modeling_yatnmn_gpt.py       # PreTrainedModel + GenerationMixin wrapper with KV cache
└── README.md
```
## Related
- `mlnomad/yatnmn-softplus-sb-d12-chinchilla-261M` – original JAX/Flax Orbax checkpoint (model + AdamW optimizer state, resumable)
- `mlnomad/yatnmn-softplus-d12-chinchilla-261M-pytorch` – per-neuron bias variant (better loss: 2.98 vs 3.06)
- `mlnomad/gelu-d12-chinchilla-261M-pytorch` – GELU baseline at identical compute, smooth loss 3.11
- `flaxchat` – JAX/Flax training harness
- `nmn` – the YatNMN layer (used at training time; not required for inference here, since the nonlinearity is reimplemented in pure PyTorch)
## Wikitext-103 evaluation
| Metric | Value |
|---|---|
| Wikitext-103 test loss | 3.677 |
| Wikitext-103 test PPL | 39.53 |
Evaluated on ~330K tokens from the wikitext-103 test set. The model was trained on C4 only, so this is a zero-shot transfer metric.
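The perplexity row is simply the exponential of the mean per-token cross-entropy in nats, so the two numbers in the table are consistent:

```python
import math

loss = 3.677          # mean token cross-entropy (nats) on the wikitext-103 test set
ppl = math.exp(loss)  # perplexity = exp(mean loss)
print(round(ppl, 2))  # 39.53
```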
## License
Apache 2.0.