# Modular Multiplication Transformer
A 1-layer, 4-head transformer trained on (a × b) mod 113 that exhibits grokking (delayed generalization long after the training set is memorized). This checkpoint includes the full training history: 400 intermediate checkpoints spanning 40,000 epochs.
## Model Architecture
| Parameter | Value |
|---|---|
| Layers | 1 |
| Attention Heads | 4 |
| d_model | 128 |
| d_head | 32 |
| d_mlp | 512 |
| Activation | ReLU |
| Layer Norm | None |
| Vocabulary | 114 (0-112 + "=" separator) |
| Output Classes | 113 |
| Context Length | 3 tokens [a, b, =] |
| Trainable Parameters | ~230,000 |
Built with TransformerLens (`HookedTransformer`), with layer normalization disabled and bias terms frozen (not trained).
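The stored `checkpoint["config"]` should correspond to a configuration along these lines (a minimal sketch reconstructed from the table above; it is not read from the checkpoint itself, and the exact argument set is an assumption):

```python
from transformer_lens import HookedTransformerConfig

# Sketch of the architecture table above; values follow the table,
# the seed comes from the training table below.
cfg = HookedTransformerConfig(
    n_layers=1,
    n_heads=4,
    d_model=128,
    d_head=32,
    d_mlp=512,
    act_fn="relu",
    normalization_type=None,  # no layer norm
    d_vocab=114,              # 0-112 plus the "=" separator
    d_vocab_out=113,          # output classes
    n_ctx=3,                  # [a, b, =]
    seed=999,
)
```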
## Training Details
| Parameter | Value |
|---|---|
| Optimizer | AdamW |
| Learning Rate | 1e-3 |
| Weight Decay | 1.0 |
| Betas | (0.9, 0.98) |
| Epochs | 40,000 |
| Training Fraction | 30% (3,830 / 12,769 samples) |
| Batch Size | Full-batch |
| Data Seed | 598 |
| Model Seed | 999 |
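Putting those hyperparameters together, the dataset construction and training loop would look roughly like this (an illustrative sketch, not the original training script; `model` is assumed to be a `HookedTransformer` built from the config above):

```python
import torch

p = 113
torch.manual_seed(598)  # data seed from the table

# All 113^2 = 12,769 pairs (a, b), tokenized as [a, b, "="] with "=" = 113.
a = torch.arange(p).repeat_interleave(p)
b = torch.arange(p).repeat(p)
data = torch.stack([a, b, torch.full_like(a, p)], dim=1)
labels = (a * b) % p

# 30% train / 70% test split: 3,830 / 8,939 samples.
perm = torch.randperm(p * p)
n_train = int(0.3 * p * p)
train_idx, test_idx = perm[:n_train], perm[n_train:]

optimizer = torch.optim.AdamW(
    model.parameters(), lr=1e-3, weight_decay=1.0, betas=(0.9, 0.98)
)

# Full-batch training: every training sample in each step.
for epoch in range(40_000):
    logits = model(data[train_idx])[:, -1]  # predict at the "=" position
    loss = torch.nn.functional.cross_entropy(logits, labels[train_idx])
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
```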
## Checkpoint Contents
```python
checkpoint = torch.load("mod_mult_grokking.pth")

checkpoint["model"]              # Final model state_dict
checkpoint["config"]             # HookedTransformerConfig
checkpoint["checkpoints"]        # List of 400 state_dicts (every 100 epochs)
checkpoint["checkpoint_epochs"]  # [0, 100, 200, ..., 39900]
checkpoint["train_losses"]       # 40,000 training loss values
checkpoint["test_losses"]        # 40,000 test loss values
checkpoint["train_accs"]         # 40,000 training accuracy values
checkpoint["test_accs"]          # 40,000 test accuracy values
checkpoint["train_indices"]      # Indices of the 3,830 training samples
checkpoint["test_indices"]       # Indices of the 8,939 test samples
```
## Usage
```python
import torch
from transformer_lens import HookedTransformer

# weights_only=False is required on PyTorch >= 2.6, since the checkpoint
# stores a HookedTransformerConfig object rather than tensors alone.
checkpoint = torch.load("mod_mult_grokking.pth", map_location="cpu", weights_only=False)

model = HookedTransformer(checkpoint["config"])
model.load_state_dict(checkpoint["model"])
model.eval()

# Compute (7 * 16) mod 113 = 112
a, b = 7, 16
separator = 113  # "=" token
input_tokens = torch.tensor([[a, b, separator]])

logits = model(input_tokens)
prediction = logits[0, -1].argmax().item()
print(f"{a} * {b} mod 113 = {prediction}")  # 112
```
## References
- Nanda et al. (2023). "Progress measures for grokking via mechanistic interpretability." ICLR 2023.
- Power et al. (2022). "Grokking: Generalization beyond overfitting on small algorithmic datasets." arXiv:2201.02177.
- TransformerLens: https://github.com/TransformerLensOrg/TransformerLens