Mobile MoE Architecture Search
Collection
A collection of 32 MoE models from 41 experiments exploring expert count, routing strategy, and learning rates for mobile deployment. • 33 items • Updated
Mobile-optimized MoE model configuration from architecture search.
# Load a mobile-optimized MoE checkpoint produced by the architecture search.
#
# Steps: read the saved hyperparameters, build the model skeleton, load the
# trained weights from the safetensors file, and fetch the matching tokenizer
# from the Hub repo.

# Imports grouped at the top (stdlib, then third-party, then project-local)
# per PEP 8 instead of interleaved with the code.
import json

from safetensors.torch import load_file
from transformers import AutoTokenizer

from architecture.model import Qwen3Model  # Your custom model class

# Read the model hyperparameters saved alongside the checkpoint.
# Explicit encoding avoids platform-dependent default codecs.
with open("config.json", encoding="utf-8") as f:
    config = json.load(f)

# Build the model from the config, then fill in the trained weights.
model = Qwen3Model(config)
state_dict = load_file("model.safetensors")
model.load_state_dict(state_dict)

# The tokenizer ships with the Hub repository; reuse it as-is.
tokenizer = AutoTokenizer.from_pretrained("kshitijthakkar/moe-198m-114m-8x2-12L-8exp-balanced")
{
"vocab_size": 151936,
"emb_dim": 512,
"n_heads": 8,
"n_layers": 12,
"n_kv_groups": 2,
"num_experts": 8,
"num_experts_per_tok": 2,
"moe_hidden_dim": 768,
"head_dim": 64,
"max_position_embeddings": 4096,
"rope_base": 1000000.0,
"qk_norm": true
}
{
"model_config": {
"vocab_size": 151936,
"emb_dim": 512,
"n_heads": 8,
"n_layers": 12,
"n_kv_groups": 2,
"num_experts": 8,
"num_experts_per_tok": 2,
"moe_hidden_dim": 768,
"head_dim": 64,
"max_position_embeddings": 4096,
"rope_base": 1000000.0,
"qk_norm": true
},
"learning_rate": 0.0001,
"batch_size": 4,
"context_length": 1024,
"warmup_ratio": 0.1,
"warmup_steps": null,
"weight_decay": 0.1,
"gradient_clip": 1.0,
"gradient_accumulation_steps": 1,
"scheduler_type": "cosine",
"wsd_decay_ratio": 0.1,
"max_steps": 2000,
"eval_steps": 500,
"eval_batches": 20,
"log_steps": 100,
"early_stopping": true,
"early_stopping_patience": 500,
"early_stopping_min_delta": 0.01,
"early_stopping_min_steps": 200,
"track_expert_balance": true,
"expert_balance_log_steps": 100,
"use_wandb": true,
"wandb_project": "moe-architecture-search",
"wandb_entity": null,
"wandb_tags": [
"8exp_balanced",
"architecture-search"
],
"train_data_path": null,
"val_data_path": null,
"output_dir": null,
"experiment_name": "8exp_balanced",
"device": "cuda",
"dtype": "bfloat16",
"gradient_checkpointing": true,
"architecture_name": "8exp_balanced",
"mobile_estimate": {
"tok_per_sec_fp16": 41.06949925149829,
"tok_per_sec_q8": 68.44916541916382,
"tok_per_sec_q4": 95.82883158682934,
"ttft_ms_fp16": 72.35044072727273,
"ttft_ms_q8": 48.23362715151515,
"ttft_ms_q4": 38.58690172121212,
"memory_mb_fp16": 437.8291015625,
"memory_mb_q8": 262.88920898437505,
"memory_mb_q4": 165.93193359375,
"total_params": 198963712,
"active_params": 114029056,
"meets_ttft_target": true,
"meets_throughput_target": true,
"meets_memory_target": true
}
}