YatNMN-Softplus + scalar_bias d=12 Chinchilla (261M) β€” PyTorch / HuggingFace Transformers

A 261M-parameter nanochat-architecture GPT with the YatNMN-Softplus MLP using scalar_bias=True β€” a single shared (1,) bias across all neurons (matches the clean theoretical formulation of YatNMN). Trained in JAX/Flax on TPU v6e-8 to Chinchilla-optimal token budget on C4, then ported to PyTorch for easy inference via the HuggingFace transformers API.

This is the scalar_bias ablation in the d=12 / 261M / Chinchilla series:

| MLP variant | Final smooth loss | vs GELU |
|---|---|---|
| YatNMN-Softplus (per-neuron bias) | 2.98 | −0.13 |
| YatNMN-Softplus + scalar_bias (this model) | 3.06 | −0.05 |
| GELU | 3.11 | baseline |

Weights are bit-exact with the Flax checkpoint (mlnomad/yatnmn-softplus-sb-d12-chinchilla-261M) β€” parity validated at max |Ξ” logits| = 1.6e-5 on CPU/fp32.

Quick start

```bash
pip install torch transformers safetensors
```

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained(
    "mlnomad/yatnmn-softplus-sb-d12-chinchilla-261M-pytorch",
    trust_remote_code=True,
    dtype=torch.float32,
).eval()

tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")

prompt = "The meaning of life is"
ids = tokenizer(prompt, return_tensors="pt").input_ids
with torch.no_grad():
    out = model.generate(
        ids, max_new_tokens=50,
        do_sample=True, temperature=0.8, top_p=0.9,
        use_cache=True, pad_token_id=tokenizer.eos_token_id or 0,
    )
print(tokenizer.decode(out[0], skip_special_tokens=True))
```

YatNMN-Softplus + scalar_bias

Each MLP block uses the YatNMN nonlinearity from nmn>=0.2.29:

y = Ξ± Β· (x Β· W + softplus(b))Β² / (||x βˆ’ W||Β² + softplus(Ξ΅))
  • b shape (1,) β€” single shared bias across all 4Β·n_embd = 3072 neurons (scalar_bias=True)
  • Ξ΅ shape (1,) β€” single learnable epsilon, kept positive via softplus
  • Ξ± shape (1,) β€” single learnable scalar, applied as a final gain
  • bias and epsilon are passed through softplus to remain strictly positive
  • followed by c_proj (Linear β†’ 768) on top of YatNMN's output

The shared scalar bias is theoretically motivated: in YatNMN, the score (x·W + b)² in the numerator and the distance ||x − W||² in the denominator are both per-neuron quantities, and a single scalar bias preserves that per-neuron geometric interpretation cleanly, whereas a per-neuron bias of shape (3072,) adds extra parameters that break the symmetry. Empirically, the per-neuron variant beats this scalar one by 0.08 nats at d=12 (the extra parameters do help in practice), but scalar_bias=True is closer to the original YatNMN definition.
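A minimal PyTorch sketch of this formulation may make the shapes concrete. This is illustrative only (the trained model uses the layer from the `nmn` package, followed by c_proj; the class and attribute names here are hypothetical):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class YatNMNSoftplusScalarBias(nn.Module):
    """Illustrative sketch of YatNMN-Softplus with scalar_bias=True."""

    def __init__(self, d_in: int, d_out: int):
        super().__init__()
        self.W = nn.Parameter(torch.randn(d_in, d_out) / d_in ** 0.5)
        self.b = nn.Parameter(torch.zeros(1))      # single shared (1,) bias
        self.eps = nn.Parameter(torch.zeros(1))    # single shared (1,) epsilon
        self.alpha = nn.Parameter(torch.ones(1))   # single shared (1,) output gain

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        score = x @ self.W                                  # x·w_j for every neuron j
        num = (score + F.softplus(self.b)) ** 2             # (x·w_j + softplus(b))^2
        # ||x - w_j||^2 expanded as ||x||^2 - 2 x·w_j + ||w_j||^2, per neuron
        dist2 = (x * x).sum(-1, keepdim=True) - 2 * score + (self.W * self.W).sum(0)
        return self.alpha * num / (dist2 + F.softplus(self.eps))
```

Note that softplus keeps the epsilon strictly positive, so the denominator never vanishes even when x coincides with a neuron's weight vector.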

Model details

| Field | Value |
|---|---|
| Parameters | 261,096,374 |
| Architecture | Nanochat-style GPT with YatNMN-Softplus + scalar_bias MLP (ported from JAX/Flax NNX) |
| Config | d=12, n_embd=768, n_head=12, n_kv_head=12, seq_len=1024, tied embeddings, SSSL sliding window |
| Training data | allenai/c4 (English split), 5.22 B tokens (Chinchilla 20×) |
| Tokenizer | mistralai/Mistral-7B-v0.1 (vocab 32,768) |
| Optimizer | plain AdamW, peak LR 0.03, warmup-cosine schedule |
| Hardware | TPU v6e-8 (TRC), europe-west4-a |
| Final loss (smooth) | 3.06 |

Architecture features

Full nanochat stack, faithfully ported to PyTorch:

  • YatNMN-Softplus + scalar_bias MLP (shared (1,) bias, softplus-positive, learnable Ξ± and Ξ΅)
  • RoPE (base 100,000), split-half layout
  • MHA (n_head = n_kv_head = 12; the code supports GQA via n_kv_head < n_head, but all d=12 models use full MHA)
  • QK-norm with 1.2Γ— scaling (after RoPE)
  • Parameterless RMSNorm (no learnable gain) post-embedding and per block
  • Sliding-window attention with "SSSL" pattern
  • Tied embeddings (lm_head = wte.T)
  • Value embeddings on alternating layers (ResFormer-style)
  • Per-layer learnable residual scalars (resid_lambdas, x0_lambdas)
  • Smear β€” learnable gate on first 24 dims of token embedding mixes in prev token
  • Backout β€” subtract mid-layer residual from late layers
  • Logit soft-cap: 15 Β· tanh(logits / 15)
  • No biases in any Linear
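The logit soft-cap, for instance, is a one-line transform that smoothly bounds pre-softmax logits (a generic sketch of the stated 15 · tanh(logits / 15) rule):

```python
import torch

def soft_cap(logits: torch.Tensor, cap: float = 15.0) -> torch.Tensor:
    # bounds logits smoothly to (-cap, cap) while remaining differentiable
    return cap * torch.tanh(logits / cap)
```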

KV cache

The YatGPTForCausalLM class implements a smear-aware KV cache for fast autoregressive generation. Pass use_cache=True (the default for .generate()).
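The cache follows the usual pattern of appending new key/value tensors along the sequence axis. A generic per-layer sketch (the actual class additionally tracks state for the smear gate, which is omitted here):

```python
import torch

class SimpleKVCache:
    """Generic per-layer KV cache sketch; the model's real cache is smear-aware."""

    def __init__(self):
        self.k = self.v = None

    def update(self, k_new: torch.Tensor, v_new: torch.Tensor):
        # k_new, v_new: (batch, n_head, new_tokens, head_dim)
        if self.k is None:
            self.k, self.v = k_new, v_new
        else:
            self.k = torch.cat([self.k, k_new], dim=2)
            self.v = torch.cat([self.v, v_new], dim=2)
        return self.k, self.v
```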

Files in this repo

```
.
├── config.json                       # HF config with auto_map → the classes below
├── generation_config.json
├── model.safetensors                 # ~1.04 GB, fp32 weights + persistent RoPE buffers
├── yatnmn_gpt.py                     # pure PyTorch Yat_GPT module + YatNMN layer
├── torch_gpt.py                      # shared building blocks (RMSNorm, RoPE, attention)
├── configuration_yatnmn_gpt.py       # PretrainedConfig subclass
├── modeling_yatnmn_gpt.py            # PreTrainedModel + GenerationMixin wrapper with KV cache
└── README.md
```

Wikitext-103 evaluation

| Metric | Value |
|---|---|
| Wikitext-103 test loss | 3.677 |
| Wikitext-103 test PPL | 39.53 |

Evaluated on ~330K tokens from the wikitext-103 test set. The model was trained on C4 only, so this is a zero-shot transfer metric.
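Perplexity here is simply the exponentiated mean per-token cross-entropy:

```python
import math

test_loss = 3.677              # mean cross-entropy per token, in nats
ppl = math.exp(test_loss)
print(round(ppl, 2))           # 39.53
```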

License

Apache 2.0.
