Kernels
TaehyunKim github-actions[bot] wyldecat Claude Opus 4.6 (1M context) committed on
Commit
10848ab
·
unverified ·
1 Parent(s): 67f7e11

draft commit for cpu_offload (#23)

Browse files

* draft commit for cpu_offload

* draft commit for cpu_offload [skip-build]

* claude is british

* Add built binary [skip-build]

* fix load, save and add test

* Add built binary [skip-build]

* Fix yapf formatting for pre-commit CI [skip-build]

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

---------

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: WyldeCat <skan1543@gmail.com>
Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. build/torch210-cxx11-cu126-x86_64-linux/_ops.py +3 -3
  2. build/torch210-cxx11-cu126-x86_64-linux/{_optimizer_7aef62f_dirty.abi3.so → _optimizer_5b58933_dirty.abi3.so} +1 -1
  3. build/torch210-cxx11-cu126-x86_64-linux/adamw.py +154 -37
  4. build/torch210-cxx11-cu126-x86_64-linux/core.py +134 -31
  5. build/torch210-cxx11-cu126-x86_64-linux/cpu_offload.py +188 -0
  6. build/torch210-cxx11-cu126-x86_64-linux/distributed/utils.py +0 -6
  7. build/torch210-cxx11-cu126-x86_64-linux/matmul_transpose_triton.py +11 -10
  8. build/torch210-cxx11-cu126-x86_64-linux/muon.py +573 -99
  9. build/torch210-cxx11-cu126-x86_64-linux/newton_schulz.py +206 -20
  10. build/torch210-cxx11-cu126-x86_64-linux/pipeline.py +158 -80
  11. build/torch210-cxx11-cu126-x86_64-linux/qk_clip.py +15 -9
  12. build/torch210-cxx11-cu128-x86_64-linux/_ops.py +3 -3
  13. build/torch210-cxx11-cu128-x86_64-linux/{_optimizer_7aef62f_dirty.abi3.so → _optimizer_5b58933_dirty.abi3.so} +1 -1
  14. build/torch210-cxx11-cu128-x86_64-linux/adamw.py +154 -37
  15. build/torch210-cxx11-cu128-x86_64-linux/core.py +134 -31
  16. build/torch210-cxx11-cu128-x86_64-linux/cpu_offload.py +188 -0
  17. build/torch210-cxx11-cu128-x86_64-linux/distributed/utils.py +0 -6
  18. build/torch210-cxx11-cu128-x86_64-linux/matmul_transpose_triton.py +11 -10
  19. build/torch210-cxx11-cu128-x86_64-linux/muon.py +573 -99
  20. build/torch210-cxx11-cu128-x86_64-linux/newton_schulz.py +206 -20
  21. build/torch210-cxx11-cu128-x86_64-linux/pipeline.py +158 -80
  22. build/torch210-cxx11-cu128-x86_64-linux/qk_clip.py +15 -9
  23. build/torch210-cxx11-cu130-x86_64-linux/_ops.py +3 -3
  24. build/torch210-cxx11-cu130-x86_64-linux/{_optimizer_7aef62f_dirty.abi3.so → _optimizer_5b58933_dirty.abi3.so} +1 -1
  25. build/torch210-cxx11-cu130-x86_64-linux/adamw.py +154 -37
  26. build/torch210-cxx11-cu130-x86_64-linux/core.py +134 -31
  27. build/torch210-cxx11-cu130-x86_64-linux/cpu_offload.py +188 -0
  28. build/torch210-cxx11-cu130-x86_64-linux/distributed/utils.py +0 -6
  29. build/torch210-cxx11-cu130-x86_64-linux/matmul_transpose_triton.py +11 -10
  30. build/torch210-cxx11-cu130-x86_64-linux/muon.py +573 -99
  31. build/torch210-cxx11-cu130-x86_64-linux/newton_schulz.py +206 -20
  32. build/torch210-cxx11-cu130-x86_64-linux/pipeline.py +158 -80
  33. build/torch210-cxx11-cu130-x86_64-linux/qk_clip.py +15 -9
  34. build/torch210-cxx11-rocm70-x86_64-linux/_ops.py +3 -3
  35. build/torch210-cxx11-rocm70-x86_64-linux/{_optimizer_7aef62f_dirty.abi3.so → _optimizer_5b58933_dirty.abi3.so} +1 -1
  36. build/torch210-cxx11-rocm70-x86_64-linux/adamw.py +154 -37
  37. build/torch210-cxx11-rocm70-x86_64-linux/core.py +134 -31
  38. build/torch210-cxx11-rocm70-x86_64-linux/cpu_offload.py +188 -0
  39. build/torch210-cxx11-rocm70-x86_64-linux/distributed/utils.py +0 -6
  40. build/torch210-cxx11-rocm70-x86_64-linux/matmul_transpose_triton.py +11 -10
  41. build/torch210-cxx11-rocm70-x86_64-linux/muon.py +573 -99
  42. build/torch210-cxx11-rocm70-x86_64-linux/newton_schulz.py +206 -20
  43. build/torch210-cxx11-rocm70-x86_64-linux/pipeline.py +158 -80
  44. build/torch210-cxx11-rocm70-x86_64-linux/qk_clip.py +15 -9
  45. build/torch210-cxx11-rocm71-x86_64-linux/_ops.py +3 -3
  46. build/torch210-cxx11-rocm71-x86_64-linux/{_optimizer_7aef62f_dirty.abi3.so → _optimizer_5b58933_dirty.abi3.so} +1 -1
  47. build/torch210-cxx11-rocm71-x86_64-linux/adamw.py +154 -37
  48. build/torch210-cxx11-rocm71-x86_64-linux/core.py +134 -31
  49. build/torch210-cxx11-rocm71-x86_64-linux/cpu_offload.py +188 -0
  50. build/torch210-cxx11-rocm71-x86_64-linux/distributed/utils.py +0 -6
build/torch210-cxx11-cu126-x86_64-linux/_ops.py CHANGED
@@ -1,9 +1,9 @@
1
  import torch
2
- from . import _optimizer_7aef62f_dirty
3
- ops = torch.ops._optimizer_7aef62f_dirty
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
- return f"_optimizer_7aef62f_dirty::{op_name}"
 
1
  import torch
2
+ from . import _optimizer_5b58933_dirty
3
+ ops = torch.ops._optimizer_5b58933_dirty
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
+ return f"_optimizer_5b58933_dirty::{op_name}"
build/torch210-cxx11-cu126-x86_64-linux/{_optimizer_7aef62f_dirty.abi3.so → _optimizer_5b58933_dirty.abi3.so} RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:f095be87ff6185010a3cff4175abbde0b2e50fe1e435dc1db4eaf5bf1f6199ca
3
  size 1940944
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:90ace47a61519aefe759810c803789e7f91e6949ca0b04fc177e311709976334
3
  size 1940944
build/torch210-cxx11-cu126-x86_64-linux/adamw.py CHANGED
@@ -1,8 +1,12 @@
 
1
  from collections import defaultdict
2
  from typing import cast
3
 
4
  import torch
5
  from torch.distributed.tensor import DTensor
 
 
 
6
 
7
 
8
  def fused_adamw(
@@ -72,54 +76,72 @@ def fused_adamw(
72
  )
73
 
74
 
75
- def step_adamw_params(optimizer_state, params, group):
76
- """Run fused AdamW on a list of parameters sharing the same placement.
 
77
 
78
- Args:
79
- optimizer_state: The optimizer's state dict (self.state in Muon).
80
- params: List of parameters to update.
81
- group: Parameter group dict with lr, adamw_betas, adamw_eps, weight_decay.
82
- """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
83
  params_with_grads = []
84
  grads = []
85
  moment1 = []
86
  moment2 = []
87
- max_exp_avg_sqs = []
88
  state_steps = []
89
- lr = group["lr"]
90
- beta1, beta2 = group["adamw_betas"]
91
- eps = group["adamw_eps"]
92
- weight_decay = group["weight_decay"]
93
 
94
  for p in params:
95
  g = p.grad
96
  if g is None:
97
  continue
98
  state = optimizer_state[p]
99
- params_with_grads.append(p)
100
- grads.append(g)
101
  if "step" not in state:
102
- state["step"] = (torch.zeros((),
103
- dtype=torch.float32,
104
- device=p.device))
105
  state["moment1"] = torch.zeros_like(g)
106
  state["moment2"] = torch.zeros_like(g)
107
- moment1.append(state["moment1"])
108
- moment2.append(state["moment2"])
109
  if not isinstance(state["step"], torch.Tensor):
110
- step_tensor = torch.tensor(state["step"],
111
- dtype=torch.float32,
112
- device=p.device)
113
- else:
114
- step_tensor = state["step"]
115
- state_steps.append(step_tensor)
 
 
 
 
 
 
116
 
117
  fused_adamw(
118
  params_with_grads,
119
  grads,
120
  moment1,
121
  moment2,
122
- max_exp_avg_sqs,
123
  state_steps,
124
  amsgrad=False,
125
  beta1=beta1,
@@ -131,24 +153,119 @@ def step_adamw_params(optimizer_state, params, group):
131
  )
132
 
133
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
134
  def step_adamw(optimizer_state, group):
135
  """Dispatch AdamW step, grouping parameters by type and placement.
136
 
 
 
 
137
  Args:
138
  optimizer_state: The optimizer's state dict (self.state in Muon).
139
  group: Parameter group dict.
140
  """
141
  params = group["params"]
 
142
 
143
- # group params with its type and placement
144
- placement_to_params: dict[tuple, list[torch.Tensor]] = defaultdict(list)
145
- for p in params:
146
- match p:
147
- case DTensor():
148
- placement_to_params[tuple([p.placements,
149
- p.device_mesh])].append(p)
150
- case torch.Tensor():
151
- placement_to_params[tuple([torch.Tensor, None])].append(p)
152
-
153
- for group_params in placement_to_params.values():
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
154
  step_adamw_params(optimizer_state, group_params, group)
 
1
+ import logging
2
  from collections import defaultdict
3
  from typing import cast
4
 
5
  import torch
6
  from torch.distributed.tensor import DTensor
7
+ from torch.profiler import record_function
8
+
9
+ logger = logging.getLogger(__name__)
10
 
11
 
12
  def fused_adamw(
 
76
  )
77
 
78
 
79
+ def _to_local(t):
80
+ """Unwrap DTensor to local tensor for fused ops."""
81
+ return t._local_tensor if isinstance(t, DTensor) else t
82
 
83
+
84
+ # ---------------------------------------------------------------------------
85
+ # Caches for eliminating per-step Python overhead.
86
+ #
87
+ # Placement grouping and tensor list assembly are identical every step
88
+ # (params don't change placement, moment/step tensors are the same objects
89
+ # after initialisation). We cache them keyed by id() of the param list
90
+ # stored in param_groups (stable across steps).
91
+ #
92
+ # Only gradients change each step and must be collected fresh.
93
+ # ---------------------------------------------------------------------------
94
+
95
+ # id(group["params"]) → dict[placement_key, list[param]]
96
+ _placement_cache: dict[int, dict[tuple, list]] = {}
97
+
98
+ # id(placement_group_list) → (params_local, moment1, moment2, state_steps)
99
+ _tensor_cache: dict[int, tuple[list, list, list, list]] = {}
100
+
101
+
102
+ def _step_adamw_params_slow(optimizer_state, params, group):
103
+ """Uncached fallback for the rare case where some params lack grads."""
104
  params_with_grads = []
105
  grads = []
106
  moment1 = []
107
  moment2 = []
 
108
  state_steps = []
 
 
 
 
109
 
110
  for p in params:
111
  g = p.grad
112
  if g is None:
113
  continue
114
  state = optimizer_state[p]
115
+ params_with_grads.append(_to_local(p))
116
+ grads.append(_to_local(g))
117
  if "step" not in state:
118
+ state["step"] = torch.zeros((),
119
+ dtype=torch.float32,
120
+ device=p.device)
121
  state["moment1"] = torch.zeros_like(g)
122
  state["moment2"] = torch.zeros_like(g)
123
+ moment1.append(_to_local(state["moment1"]))
124
+ moment2.append(_to_local(state["moment2"]))
125
  if not isinstance(state["step"], torch.Tensor):
126
+ state["step"] = torch.tensor(state["step"],
127
+ dtype=torch.float32,
128
+ device=p.device)
129
+ state_steps.append(state["step"])
130
+
131
+ if not params_with_grads:
132
+ return
133
+
134
+ lr = group["lr"]
135
+ beta1, beta2 = group["adamw_betas"]
136
+ eps = group["adamw_eps"]
137
+ weight_decay = group["weight_decay"]
138
 
139
  fused_adamw(
140
  params_with_grads,
141
  grads,
142
  moment1,
143
  moment2,
144
+ [],
145
  state_steps,
146
  amsgrad=False,
147
  beta1=beta1,
 
153
  )
154
 
155
 
156
def step_adamw_params(optimizer_state, params, group):
    """Run fused AdamW on a list of parameters sharing the same placement.

    The first invocation builds and caches the static tensor lists (local
    params, both moments, step counters); later invocations reuse them and
    only re-collect gradients, which are the sole per-step inputs.

    Args:
        optimizer_state: The optimizer's state dict (self.state in Muon).
        params: List of parameters to update.
        group: Parameter group dict with lr, adamw_betas, adamw_eps, weight_decay.
    """
    # Gradients are the only per-step data; gather them first.
    with record_function("adamw::collect_grads"):
        grads = []
        for param in params:
            if param.grad is None:
                # Uncommon: delegate to the uncached path, which skips
                # grad-less params individually.
                _step_adamw_params_slow(optimizer_state, params, group)
                return
            grads.append(_to_local(param.grad))

    cache_key = id(params)
    if cache_key not in _tensor_cache:
        with record_function("adamw::init_tensor_cache"):
            locals_ = []
            m1 = []
            m2 = []
            steps = []

            for param in params:
                state = optimizer_state[param]
                locals_.append(_to_local(param))
                if "step" not in state:
                    # Fresh parameter: initialise step counter and moments.
                    state["step"] = torch.zeros((),
                                                dtype=torch.float32,
                                                device=param.device)
                    state["moment1"] = torch.zeros_like(param.grad)
                    state["moment2"] = torch.zeros_like(param.grad)
                m1.append(_to_local(state["moment1"]))
                m2.append(_to_local(state["moment2"]))
                if not isinstance(state["step"], torch.Tensor):
                    # Legacy state may carry "step" as a plain number.
                    state["step"] = torch.tensor(state["step"],
                                                 dtype=torch.float32,
                                                 device=param.device)
                steps.append(state["step"])

            _tensor_cache[cache_key] = (locals_, m1, m2, steps)

    params_local, moment1, moment2, state_steps = _tensor_cache[cache_key]

    beta1, beta2 = group["adamw_betas"]

    with record_function("adamw::fused_adamw"):
        fused_adamw(
            params_local,
            grads,
            moment1,
            moment2,
            [],
            state_steps,
            amsgrad=False,
            beta1=beta1,
            beta2=beta2,
            lr=group["lr"],
            weight_decay=group["weight_decay"],
            eps=group["adamw_eps"],
            maximize=False,
        )
229
+
230
+
231
def step_adamw(optimizer_state, group):
    """Dispatch an AdamW step, batching parameters by type and placement.

    The type/placement partition is computed once per param list and cached,
    because placements are static across training steps.

    Args:
        optimizer_state: The optimizer's state dict (self.state in Muon).
        group: Parameter group dict.
    """
    params = group["params"]
    cache_key = id(params)

    if cache_key not in _placement_cache:
        with record_function("adamw::group_by_placement"):
            buckets: dict[tuple, list[torch.Tensor]] = defaultdict(list)
            for p in params:
                grad_shape = p.grad.shape if p.grad is not None else None
                # DTensor must be tested first: it subclasses torch.Tensor.
                if isinstance(p, DTensor):
                    logger.debug(
                        "[AdamW] DTensor param: shape=%s, placements=%s, "
                        "mesh=%s, grad=%s", p.shape, p.placements,
                        p.device_mesh.mesh_dim_names, grad_shape)
                    buckets[(p.placements, p.device_mesh)].append(p)
                elif isinstance(p, torch.Tensor):
                    logger.debug("[AdamW] plain param: shape=%s, grad=%s",
                                 p.shape, grad_shape)
                    buckets[(torch.Tensor, None)].append(p)

            logger.debug("[AdamW] %d placement groups, %d total params",
                         len(buckets), len(params))

            _placement_cache[cache_key] = dict(buckets)

    for bucket in _placement_cache[cache_key].values():
        step_adamw_params(optimizer_state, bucket, group)
build/torch210-cxx11-cu126-x86_64-linux/core.py CHANGED
@@ -1,11 +1,25 @@
 
1
  import math
2
  from dataclasses import dataclass
 
3
 
4
  import torch
5
- import torch.distributed as dist
6
  from torch.distributed import ProcessGroup
7
  from torch.distributed.tensor import DTensor
8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
 
10
  @dataclass
11
  class _muon_state:
@@ -17,26 +31,71 @@ class _muon_state:
17
  qk_clip_state: torch.Tensor | None = None
18
 
19
 
20
- def update_g(optimizer_state, p, g, group, momentum):
21
- """Apply momentum update to gradient.
 
 
 
 
 
 
22
 
23
- Args:
24
- optimizer_state: The optimizer's state dict (self.state in Muon).
25
- p: Parameter tensor.
26
- g: Gradient tensor.
27
- group: Parameter group dict.
28
- momentum: Momentum coefficient.
29
 
30
- Returns:
31
- Momentum-updated gradient tensor.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
  """
33
- state = optimizer_state[p]
34
- buf = state.setdefault("momentum_buffer", torch.zeros_like(g))
35
- torch.add(g, buf, alpha=momentum, out=buf)
36
- if group["nesterov"]:
37
- g.add_(buf, alpha=momentum)
38
- return g
39
- return buf
 
 
 
 
 
 
 
 
 
 
 
40
 
41
 
42
  def update_p(p, u, lr, adjusted_lr, weight_decay):
@@ -49,14 +108,13 @@ def update_p(p, u, lr, adjusted_lr, weight_decay):
49
  adjusted_lr: Size-adjusted learning rate.
50
  weight_decay: Weight decay coefficient.
51
  """
52
- if isinstance(p, torch.nn.Parameter):
53
- # apply weight decay
54
- p.data.mul_(1 - lr * weight_decay)
55
- # apply update
56
- p.data.add_(u, alpha=-adjusted_lr)
57
- else:
58
- p.mul_(1 - lr * weight_decay)
59
- p.add_(u, alpha=-adjusted_lr)
60
 
61
 
62
  def adjust_lr_for_muon(lr, param_shape):
@@ -77,14 +135,55 @@ def adjust_lr_for_muon(lr, param_shape):
77
  return adjusted_lr
78
 
79
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
80
  def default_is_muon(name, x, expert_keys=None):
81
- skip_keys = ["embed_tokens", "lm_head", "tok_embeddings", "output"]
82
- if any(key in name for key in skip_keys):
 
 
 
 
 
 
 
 
 
 
 
 
 
83
  return False
84
  effective_ndim = x.ndim
85
- if expert_keys and any(key in name for key in expert_keys):
 
86
  effective_ndim -= 1
87
- return effective_ndim >= 2
 
 
 
 
 
88
 
89
 
90
  def get_default_muon_param_groups(model, is_muon_func=None, expert_keys=None):
@@ -92,7 +191,7 @@ def get_default_muon_param_groups(model, is_muon_func=None, expert_keys=None):
92
  is_muon_func = lambda n, x: default_is_muon(n, x, expert_keys)
93
 
94
  muon_params, muon_names = [], []
95
- non_muon_params = []
96
 
97
  for n, p in model.named_parameters():
98
  if not p.requires_grad:
@@ -102,6 +201,10 @@ def get_default_muon_param_groups(model, is_muon_func=None, expert_keys=None):
102
  muon_names.append(n)
103
  else:
104
  non_muon_params.append(p)
 
 
 
 
105
 
106
  return [
107
  {
 
1
+ import logging
2
  import math
3
  from dataclasses import dataclass
4
+ from typing import List
5
 
6
  import torch
 
7
  from torch.distributed import ProcessGroup
8
  from torch.distributed.tensor import DTensor
9
 
10
+ # torch.compile wraps modules as OptimizedModule, inserting "_orig_mod" into
11
+ # parameter FQNs. Activation checkpointing similarly inserts
12
+ # "_checkpoint_wrapped_module". Strip these so name-based matching (skip_keys,
13
+ # expert_keys, QK layer parsing) works regardless of wrapper nesting.
14
+ _WRAPPER_PARTS = frozenset({"_orig_mod", "_checkpoint_wrapped_module"})
15
+
16
+ logger = logging.getLogger(__name__)
17
+
18
+
19
def normalize_fqn(name: str) -> str:
    """Drop torch.compile / activation-checkpoint wrapper segments from an FQN.

    Keeps every dotted component of *name* that is not a known wrapper
    component (see ``_WRAPPER_PARTS``), so name-based matching behaves the
    same whether or not the module is wrapped.
    """
    kept = [part for part in name.split(".") if part not in _WRAPPER_PARTS]
    return ".".join(kept)
22
+
23
 
24
  @dataclass
25
  class _muon_state:
 
31
  qk_clip_state: torch.Tensor | None = None
32
 
33
 
34
+ def _batch_momentum(
35
+ grads: List[torch.Tensor],
36
+ momentum_bufs: List[torch.Tensor],
37
+ momentum: torch.Tensor,
38
+ ) -> None:
39
+ """Batched momentum update (no nesterov)."""
40
+ torch._foreach_mul_(momentum_bufs, momentum)
41
+ torch._foreach_add_(momentum_bufs, grads)
42
 
 
 
 
 
 
 
43
 
44
+ def _batch_momentum_nesterov(
45
+ grads: List[torch.Tensor],
46
+ momentum_bufs: List[torch.Tensor],
47
+ momentum: torch.Tensor,
48
+ ) -> None:
49
+ """Batched momentum update with nesterov correction."""
50
+ torch._foreach_mul_(momentum_bufs, momentum)
51
+ torch._foreach_add_(momentum_bufs, grads)
52
+ nesterov_terms = torch._foreach_mul(momentum_bufs, momentum)
53
+ torch._foreach_add_(grads, nesterov_terms)
54
+
55
+
56
+ _compiled_momentum: dict[bool, callable] = {}
57
+ _use_momentum_compile = True
58
+
59
+
60
def set_momentum_compile(enabled: bool):
    """Enable or disable torch.compile for the batched momentum kernels.

    Args:
        enabled: When True, ``batch_pre_ortho`` routes through the compiled
            variants; when False, the eager implementations are used.
    """
    global _use_momentum_compile
    _use_momentum_compile = enabled
64
+
65
+
66
def batch_pre_ortho(
    grads: List[torch.Tensor],
    momentum_bufs: List[torch.Tensor],
    momentum: torch.Tensor,
    nesterov: bool,
) -> None:
    """Run the batched momentum update that precedes orthogonalization.

    Mirrors dion's ``muon_update_pre_orthogonalize``.
    Inputs must be plain CUDA tensors (not DTensor).
    Modifies ``momentum_bufs`` and (for nesterov) ``grads`` in-place.

    When compile is enabled, the nesterov and plain variants are compiled
    and cached separately so the boolean branch never causes a graph break.
    """
    if nesterov:
        impl = _batch_momentum_nesterov
    else:
        impl = _batch_momentum
    if _use_momentum_compile:
        compiled = _compiled_momentum.get(nesterov)
        if compiled is None:
            compiled = torch.compile(impl)
            _compiled_momentum[nesterov] = compiled
        impl = compiled
    impl(grads, momentum_bufs, momentum)
87
+
88
+
89
+ def _update_p_impl(p_data, u_data, lr, adjusted_lr, weight_decay):
90
+ """Weight-decay + update on plain tensors.
91
+
92
+ Not compiled: per-param @torch.compile caused ~0.25ms TorchDynamo cache
93
+ lookup per call × 256+ params = massive overhead. The pipeline path uses
94
+ batched _foreach_* ops instead; this function remains for base() and
95
+ distributed_muon().
96
+ """
97
+ p_data.mul_(1 - lr * weight_decay)
98
+ p_data.add_(u_data, alpha=-adjusted_lr)
99
 
100
 
101
  def update_p(p, u, lr, adjusted_lr, weight_decay):
 
108
  adjusted_lr: Size-adjusted learning rate.
109
  weight_decay: Weight decay coefficient.
110
  """
111
+ # Unwrap Parameter -> underlying data tensor.
112
+ p_data = p.data if isinstance(p, torch.nn.Parameter) else p
113
+ # Unwrap DTensor -> local CUDA tensor for compiled kernel.
114
+ if isinstance(p_data, DTensor):
115
+ p_data = p_data._local_tensor
116
+ u_data = u._local_tensor if isinstance(u, DTensor) else u
117
+ _update_p_impl(p_data, u_data, lr, adjusted_lr, weight_decay)
 
118
 
119
 
120
  def adjust_lr_for_muon(lr, param_shape):
 
135
  return adjusted_lr
136
 
137
 
138
+ def _match_key(parts, key):
139
+ """Check if key matches as contiguous components in parts.
140
+
141
+ Single-component keys (e.g. "experts") match any single component.
142
+ Multi-component keys (e.g. "experts.w1") match as a contiguous subsequence.
143
+ """
144
+ key_parts = key.split(".")
145
+ key_len = len(key_parts)
146
+ if key_len == 1:
147
+ return key in parts
148
+ return any(parts[i:i + key_len] == key_parts
149
+ for i in range(len(parts) - key_len + 1))
150
+
151
+
152
def is_expert_param(name, expert_keys):
    """Return True if *name* matches any key in *expert_keys*.

    Matching is component-wise on the wrapper-stripped FQN, so keys match
    whole dotted path segments rather than arbitrary substrings.
    """
    if not expert_keys:
        return False
    components = normalize_fqn(name).split(".")
    for key in expert_keys:
        if _match_key(components, key):
            return True
    return False
158
+
159
+
160
def default_is_muon(name, x, expert_keys=None):
    """Decide whether parameter *name* / tensor *x* is optimized by Muon.

    Embedding, head, and other special modules are always routed to AdamW.
    Expert parameters have one leading (expert-stack) dimension discounted
    before the >=2-D check that gates Muon eligibility.
    """
    normalized = normalize_fqn(name)
    parts = normalized.split(".")
    skip_keys = [
        "embed_tokens",
        "lm_head",
        "tok_embeddings",
        "output",
        "mhc_attn",
        "mhc_ffn",
        "lambda_proj",
    ]
    # Component-level match: a skip key must equal a full dotted segment.
    for key in skip_keys:
        if key in parts:
            logger.info(
                "[is_muon] %s (orig: %s): skip (matched skip_key), ndim=%d",
                normalized, name, x.ndim)
            return False
    is_expert = is_expert_param(name, expert_keys)
    # Experts stack weights along a leading dim; ignore it for the 2-D test.
    effective_ndim = x.ndim - 1 if is_expert else x.ndim
    result = effective_ndim >= 2
    logger.info(
        "[is_muon] %s (orig: %s): ndim=%d, expert=%s, effective_ndim=%d → %s",
        normalized, name, x.ndim, is_expert, effective_ndim,
        "Muon" if result else "AdamW")
    return result
187
 
188
 
189
  def get_default_muon_param_groups(model, is_muon_func=None, expert_keys=None):
 
191
  is_muon_func = lambda n, x: default_is_muon(n, x, expert_keys)
192
 
193
  muon_params, muon_names = [], []
194
+ non_muon_params, non_muon_names = [], []
195
 
196
  for n, p in model.named_parameters():
197
  if not p.requires_grad:
 
201
  muon_names.append(n)
202
  else:
203
  non_muon_params.append(p)
204
+ non_muon_names.append(n)
205
+
206
+ logger.info("[param_groups] expert_keys=%s, Muon=%d, AdamW=%d",
207
+ expert_keys, len(muon_names), len(non_muon_names))
208
 
209
  return [
210
  {
build/torch210-cxx11-cu126-x86_64-linux/cpu_offload.py ADDED
@@ -0,0 +1,188 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """CPU offloading for optimizer states.
2
+
3
+ Manages a pinned CPU memory pool and async CUDA streams to offload
4
+ optimizer state tensors (momentum buffers, Adam moments) to CPU between
5
+ optimizer steps, freeing GPU memory.
6
+
7
+ All tracked tensors are packed into a single flat pinned CPU buffer
8
+ (per dtype). D2H and H2D copies are performed per-tensor directly
9
+ between individual GPU tensors and their slice of the CPU flat buffer
10
+ — no GPU staging buffer is allocated, so there is **no temporary GPU
11
+ memory spike** during offload or reload.
12
+
13
+ Individual tensor storages are freed after offload via
14
+ ``untyped_storage().resize_(0)``, preserving tensor identity so
15
+ downstream caches remain valid.
16
+ """
17
+
18
+ import logging
19
+ from collections import defaultdict
20
+
21
+ import torch
22
+ from torch.distributed.tensor import DTensor
23
+
24
+ logger = logging.getLogger(__name__)
25
+
26
+
27
+ class CPUOffloadPool:
28
+ """Pinned CPU memory pool for async optimizer state offloading.
29
+
30
+ Tracked tensors are grouped by dtype. Each group gets a single flat
31
+ pinned CPU buffer. D2H / H2D copies are per-tensor (into slices of
32
+ the flat buffer) to avoid allocating a GPU staging buffer.
33
+ """
34
+
35
+ def __init__(self):
36
+ self._managed: list[torch.Tensor] = []
37
+ self._storage_nbytes: dict[int, int] = {} # id(t) → bytes
38
+
39
+ # Per-dtype group: populated on first offload.
40
+ # dtype → dict with keys:
41
+ # "indices" : list[int] managed-list indices
42
+ # "offsets" : list[tuple[int,int]] (start, numel) in flat buf
43
+ # "total" : int total numel
44
+ # "cpu_flat" : Tensor pinned CPU buffer
45
+ self._groups: dict[torch.dtype, dict] = {}
46
+
47
+ self._offload_stream: torch.cuda.Stream | None = None
48
+ self._device: torch.device | None = None
49
+ self._initialized: bool = False
50
+ self._logged: bool = False
51
+
52
+ # ------------------------------------------------------------------
53
+ @staticmethod
54
+ def _local(t: torch.Tensor) -> torch.Tensor:
55
+ """Unwrap DTensor to its local CUDA tensor."""
56
+ return t._local_tensor if isinstance(t, DTensor) else t
57
+
58
+ def _ensure_stream(self):
59
+ if self._offload_stream is None:
60
+ self._offload_stream = torch.cuda.Stream(device=self._device)
61
+
62
+ # ------------------------------------------------------------------
63
+ def track(self, tensor: torch.Tensor):
64
+ """Register a GPU tensor for CPU offloading. Idempotent."""
65
+ tid = id(tensor)
66
+ if tid in self._storage_nbytes:
67
+ return
68
+ local = self._local(tensor)
69
+ if self._device is None:
70
+ self._device = local.device
71
+ self._storage_nbytes[tid] = local.untyped_storage().size()
72
+ self._managed.append(tensor)
73
+
74
+ # ------------------------------------------------------------------
75
+ def _init_buffers(self):
76
+ """Build per-dtype flat buffers on first offload."""
77
+ # Group managed tensors by dtype.
78
+ dtype_map: dict[torch.dtype, list[tuple[int, int]]] = defaultdict(list)
79
+ for idx, t in enumerate(self._managed):
80
+ local = self._local(t)
81
+ dtype_map[local.dtype].append((idx, local.numel()))
82
+
83
+ total_cpu_bytes = 0
84
+ for dtype, entries in dtype_map.items():
85
+ offsets: list[tuple[int, int]] = []
86
+ indices: list[int] = []
87
+ off = 0
88
+ for idx, n in entries:
89
+ indices.append(idx)
90
+ offsets.append((off, n))
91
+ off += n
92
+ cpu_flat = torch.empty(off, dtype=dtype, device="cpu", pin_memory=True)
93
+ self._groups[dtype] = {
94
+ "indices": indices,
95
+ "offsets": offsets,
96
+ "total": off,
97
+ "cpu_flat": cpu_flat,
98
+ }
99
+ total_cpu_bytes += off * cpu_flat.element_size()
100
+
101
+ self._initialized = True
102
+ logger.info(
103
+ "[CPUOffload] Pool initialized: %d tensors, %d dtype group(s), "
104
+ "%.2f MB pinned CPU memory",
105
+ len(self._managed),
106
+ len(self._groups),
107
+ total_cpu_bytes / (1024**2),
108
+ )
109
+
110
+ # ------------------------------------------------------------------
111
+ def offload(self):
112
+ """Per-tensor async D2H into CPU flat buffer, then free GPU storage."""
113
+ if not self._managed:
114
+ return
115
+ if not self._initialized:
116
+ self._init_buffers()
117
+ self._ensure_stream()
118
+
119
+ # Offload stream waits for compute to finish.
120
+ compute_event = torch.cuda.current_stream(
121
+ self._device).record_event()
122
+ self._offload_stream.wait_event(compute_event)
123
+
124
+ offloaded_bytes = 0
125
+
126
+ # Per-tensor D2H copies directly into CPU flat buffer slices.
127
+ # No GPU staging buffer → no temporary GPU memory spike.
128
+ with torch.cuda.stream(self._offload_stream):
129
+ for dtype, grp in self._groups.items():
130
+ indices = grp["indices"]
131
+ offsets = grp["offsets"]
132
+ cpu_flat = grp["cpu_flat"]
133
+
134
+ for i, mgd_idx in enumerate(indices):
135
+ local = self._local(self._managed[mgd_idx])
136
+ off, n = offsets[i]
137
+ cpu_flat[off:off + n].copy_(
138
+ local.reshape(-1), non_blocking=True)
139
+
140
+ offloaded_bytes += grp["total"] * cpu_flat.element_size()
141
+
142
+ # Wait for all D2H copies to land, then free GPU storage.
143
+ self._offload_stream.synchronize()
144
+ for t in self._managed:
145
+ self._local(t).untyped_storage().resize_(0)
146
+
147
+ if not self._logged:
148
+ logger.info("[CPUOffload] Offloaded %.2f MB (GPU → CPU)",
149
+ offloaded_bytes / (1024**2))
150
+
151
+ # ------------------------------------------------------------------
152
+ def reload(self):
153
+ """Per-tensor H2D from CPU flat buffer on the default stream.
154
+
155
+ Runs on the current (default) CUDA stream to avoid stream
156
+ interaction issues with the parallel Muon pipeline. Since
157
+ pinned CPU memory is the source, the copies overlap with
158
+ GPU idle time between steps.
159
+ """
160
+ if not self._managed or not self._initialized:
161
+ return
162
+
163
+ reloaded_bytes = 0
164
+
165
+ # Re-allocate all GPU storages first.
166
+ for t in self._managed:
167
+ local = self._local(t)
168
+ local.untyped_storage().resize_(self._storage_nbytes[id(t)])
169
+
170
+ # Per-tensor H2D copies from CPU flat buffer slices.
171
+ # non_blocking=True with pinned source allows DMA overlap.
172
+ for dtype, grp in self._groups.items():
173
+ indices = grp["indices"]
174
+ offsets = grp["offsets"]
175
+ cpu_flat = grp["cpu_flat"]
176
+
177
+ for i, mgd_idx in enumerate(indices):
178
+ local = self._local(self._managed[mgd_idx])
179
+ off, n = offsets[i]
180
+ local.reshape(-1).copy_(
181
+ cpu_flat[off:off + n], non_blocking=True)
182
+
183
+ reloaded_bytes += grp["total"] * cpu_flat.element_size()
184
+
185
+ if not self._logged:
186
+ logger.info("[CPUOffload] Reloaded %.2f MB (CPU → GPU)",
187
+ reloaded_bytes / (1024**2))
188
+ self._logged = True
build/torch210-cxx11-cu126-x86_64-linux/distributed/utils.py CHANGED
@@ -72,12 +72,6 @@ def get_slices_of_dtensor(
72
  else:
73
  curr_size = target.size()[shard_dim]
74
 
75
- if curr_size % num_chunks != 0:
76
- raise NotImplementedError(
77
- f"Dimension size {curr_size} is not divisible "
78
- f"by number of ranks {num_chunks} for shard "
79
- f"placement on dim {shard_dim}. (shape: {target.shape})")
80
-
81
  # Compute indices for this level of sharding
82
  if isinstance(placement, _StridedShard):
83
  _shard_size, offsets = _StridedShard.local_shard_size_and_offset(
 
72
  else:
73
  curr_size = target.size()[shard_dim]
74
 
 
 
 
 
 
 
75
  # Compute indices for this level of sharding
76
  if isinstance(placement, _StridedShard):
77
  _shard_size, offsets = _StridedShard.local_shard_size_and_offset(
build/torch210-cxx11-cu126-x86_64-linux/matmul_transpose_triton.py CHANGED
@@ -43,6 +43,7 @@ def get_autotune_config():
43
  @triton.autotune(
44
  configs=get_autotune_config(),
45
  key=['M', 'K'],
 
46
  )
47
  @triton.jit
48
  def mmt_kernel(x, y, M, K, stride_xm, stride_xk, stride_ym, stride_yn,
@@ -102,16 +103,10 @@ def mmt_kernel(x, y, M, K, stride_xm, stride_xk, stride_ym, stride_yn,
102
  tl.store(ct_ptrs, tl.permute(c, (1, 0)), mask=ct_mask)
103
 
104
 
105
- def matmul_transpose_assign(d_in, d_out):
106
- assert d_in.is_cuda, "Input `d_in` must be a CUDA tensor"
107
- assert d_out.is_cuda, "Input `d_out` must be a CUDA tensor"
108
- assert d_in.device == d_out.device, "Inputs `d_in` and `d_out` must be on the same CUDA device"
109
- assert d_in.dtype == d_out.dtype, "Inputs must have the same data type"
110
- assert d_in.ndim == 2, "Input `d_in` must be a 2D tensor"
111
- assert d_out.ndim == 2, "Input `d_out` must be a 2D tensor"
112
- assert d_in.size(0) == d_out.size(0) == d_out.size(0), \
113
- "First dimension of `d_in` must match first and second dimension of `d_out`"
114
-
115
  d_in = d_in.contiguous()
116
  M, K = d_in.shape
117
  grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(
@@ -119,3 +114,9 @@ def matmul_transpose_assign(d_in, d_out):
119
  with torch.cuda.device(d_in.device.index):
120
  mmt_kernel[grid](d_in, d_out, M, K, d_in.stride(0), d_in.stride(1),
121
  d_out.stride(0), d_out.stride(1))
 
 
 
 
 
 
 
43
  @triton.autotune(
44
  configs=get_autotune_config(),
45
  key=['M', 'K'],
46
+ restore_value=['y'],
47
  )
48
  @triton.jit
49
  def mmt_kernel(x, y, M, K, stride_xm, stride_xk, stride_ym, stride_yn,
 
103
  tl.store(ct_ptrs, tl.permute(c, (1, 0)), mask=ct_mask)
104
 
105
 
106
+ @torch.library.custom_op("muon::matmul_transpose_assign",
107
+ mutates_args=("d_out", ))
108
+ def matmul_transpose_assign(d_in: torch.Tensor, d_out: torch.Tensor) -> None:
109
+ """Compute d_out = d_in @ d_in.T using an optimized Triton kernel."""
 
 
 
 
 
 
110
  d_in = d_in.contiguous()
111
  M, K = d_in.shape
112
  grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(
 
114
  with torch.cuda.device(d_in.device.index):
115
  mmt_kernel[grid](d_in, d_out, M, K, d_in.stride(0), d_in.stride(1),
116
  d_out.stride(0), d_out.stride(1))
117
+
118
+
119
@matmul_transpose_assign.register_fake
def _(d_in: torch.Tensor, d_out: torch.Tensor) -> None:
    """Meta ("fake") implementation used for tracing / torch.compile.

    Nothing has to be computed here: ``d_out`` is pre-allocated by the
    caller and its in-place mutation is already declared through
    ``mutates_args`` on the real custom op.
    """
    return None
build/torch210-cxx11-cu126-x86_64-linux/muon.py CHANGED
@@ -10,13 +10,16 @@ from torch.profiler import record_function
10
 
11
  from .adamw import step_adamw
12
  from .async_utils import run_pipeline
13
- from .core import (_muon_state, adjust_lr_for_muon,
14
- get_default_muon_param_groups, update_g, update_p)
 
15
  from .distributed.utils import (_is_shard, construct_shard_mesh,
16
  get_slices_of_dtensor)
17
  from .newton_schulz import (COMM_DTYPE, DEFAULT_CHUNK_SIZE_RATIO,
18
- _zeropower_via_newtonschulz5)
19
- from .pipeline import muon_chunk_pipeline
 
 
20
  from .qk_clip import compute_scales, get_qk_clip_info, qk_clip
21
 
22
  logger = logging.getLogger(__name__)
@@ -45,9 +48,21 @@ def _expand_expert_params(names, params, expert_keys):
45
  expanded_params = []
46
 
47
  for n, p in zip(names, params):
48
- is_expert = expert_keys and any(key in n for key in expert_keys)
49
  is_dtensor = isinstance(p.data, DTensor)
50
 
 
 
 
 
 
 
 
 
 
 
 
 
51
  if not is_expert:
52
  assert p.data.ndim <= 2, (
53
  f"Param {n} has ndim={p.data.ndim} but does not match "
@@ -168,7 +183,6 @@ class Muon(torch.optim.Optimizer):
168
  Use shard ranks * DEFAULT_CHUNK_SIZE_RATIO when -1 is specified.
169
  use_distributed_muon: Use distributed muon by Liu et al. (2024).
170
  For testing purpose only.
171
- small_param_numel_threshold: Threshold for classifying parameters as small and falling back to distributed Muon
172
  expert_keys: List of strings to identify expert-parallel parameters.
173
  If any key appears in a parameter's name, its outermost
174
  dimension is treated as the expert dimension and expanded
@@ -193,8 +207,8 @@ class Muon(torch.optim.Optimizer):
193
  warmup_step=5,
194
  chunk_size=-1,
195
  use_distributed_muon=False,
196
- small_param_numel_threshold=65536,
197
- expert_keys=None):
198
  defaults = dict(
199
  lr=lr,
200
  weight_decay=weight_decay,
@@ -228,8 +242,12 @@ class Muon(torch.optim.Optimizer):
228
  self.warmup_step = warmup_step
229
  self.chunk_size = chunk_size
230
  self.use_distributed_muon = use_distributed_muon
231
- self.small_param_numel_threshold = small_param_numel_threshold
232
  self.expert_keys = expert_keys
 
 
 
 
 
233
 
234
  def _calc_flops(self, G, steps):
235
  assert len(G.shape) == 2
@@ -333,8 +351,8 @@ class Muon(torch.optim.Optimizer):
333
  if g is None:
334
  continue
335
 
336
- u = _zeropower_via_newtonschulz5(g.to(COMM_DTYPE),
337
- steps=group["ns_steps"])
338
 
339
  adjusted_lr = adjust_lr_for_muon(lr, p.shape)
340
  update_p(p, u, lr, adjusted_lr, weight_decay)
@@ -355,52 +373,269 @@ class Muon(torch.optim.Optimizer):
355
  weight_decay: float,
356
  qk_logits: list[torch.Tensor | DTensor] | None,
357
  ):
358
- """ Implementation of Distributed Muon by Liu et al. """
359
 
360
- # Momentum is already applied by _step_muon before this method.
361
- for n, p in zip(names, params):
362
- g = p.grad
363
- if g is None:
364
- continue
365
-
366
- # Gather G
367
- if isinstance(p.data, DTensor):
368
- g_full = g.full_tensor()
369
- p_full = p.data.full_tensor()
370
- else:
371
- g_full = g
372
- p_full = p
373
-
374
- u_full = _zeropower_via_newtonschulz5(g_full.to(COMM_DTYPE),
375
- steps=group["ns_steps"])
376
-
377
- adjusted_lr = adjust_lr_for_muon(lr, p_full.shape)
378
- update_p(p_full, u_full, lr, adjusted_lr, weight_decay)
379
 
380
- qk_clip_state = get_qk_clip_info(self.clip_config, n, qk_logits)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
381
 
382
- scales_full = compute_scales(
383
- p_full, qk_clip_state) if qk_clip_state is not None else None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
384
 
385
- if scales_full is not None:
386
- qk_clip(p_full, scales_full, qk_clip_state.head_dim)
 
 
387
 
388
- if isinstance(p.data, DTensor):
389
- ndims = len(p.device_mesh.mesh.shape)
390
- p_replicate = DTensor.from_local(
391
- p_full,
392
- device_mesh=p.device_mesh,
393
- placements=[Replicate() for _ in range(ndims)],
394
- )
395
 
396
- p_sharded = p_replicate.redistribute(
397
- device_mesh=p.device_mesh,
398
- placements=p.placements,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
399
  )
400
 
401
- p.copy_(p_sharded)
402
 
403
- def parallel(self, names, params, group, lr, weight_decay, qk_logits):
 
 
 
 
 
 
 
404
  """
405
  Perform a parallel optimization step using Muon.
406
 
@@ -409,31 +644,23 @@ class Muon(torch.optim.Optimizer):
409
  interleaves multiple chunks so that communication and computation
410
  overlap across chunks (the same overlap previously achieved by the
411
  warmup + main-loop index scheduling).
 
 
 
 
412
  """
413
 
414
  # Momentum is already applied by _step_muon before this method.
415
 
416
- param_to_state, ordered_params = self.init_state_and_assign_params(
417
- names, params, group, qk_logits)
418
-
419
- # Compute local rank for this group's shard process group.
420
- shard_pg = param_to_state[id(ordered_params[0])].process_group
421
- rank = dist.get_rank(group=shard_pg)
422
-
423
- if self.chunk_size == -1:
424
- shard_ranks = dist.get_world_size(param_to_state[id(
425
- ordered_params[0])].process_group)
426
- chunk_size = shard_ranks * DEFAULT_CHUNK_SIZE_RATIO
427
- elif self.chunk_size > 0:
428
- chunk_size = self.chunk_size
429
- else:
430
- raise ValueError("chunk_size must be -1 or a positive integer.")
431
 
432
  def pipelines():
 
433
  for start in range(0, len(ordered_params), chunk_size):
434
  chunk = ordered_params[start:start + chunk_size]
435
  if chunk:
436
- yield muon_chunk_pipeline(
437
  params=chunk,
438
  param_to_state=param_to_state,
439
  rank=rank,
@@ -442,9 +669,11 @@ class Muon(torch.optim.Optimizer):
442
  weight_decay=weight_decay,
443
  none_grad=group["none_grad"],
444
  )
 
 
 
 
445
 
446
- with record_function("muon::barrier"):
447
- dist.barrier()
448
  with record_function("muon::pipeline"):
449
  run_pipeline(pipelines(), max_concurrent=self.warmup_step + 1)
450
 
@@ -456,16 +685,152 @@ class Muon(torch.optim.Optimizer):
456
  names = group["names"]
457
 
458
  # Apply momentum to all params before routing/expansion.
 
459
  with record_function("muon::momentum"):
460
- for n, p in zip(names, params):
461
- g = p.grad
462
- if g is None:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
463
  continue
464
- g = update_g(self.state, p, g, group, momentum)
465
- p.grad = g
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
466
 
467
  # Expand expert params by splitting on dim 0.
468
- names, params = _expand_expert_params(names, params, self.expert_keys)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
469
 
470
  param_dtensors = []
471
  name_dtensors = []
@@ -473,10 +838,10 @@ class Muon(torch.optim.Optimizer):
473
  param_tensors = []
474
  name_tensors = []
475
 
476
- param_dtensors_small = []
477
- name_dtensors_small = []
478
-
479
  if self.use_distributed_muon:
 
480
  self.distributed_muon(names=names,
481
  params=params,
482
  group=group,
@@ -485,8 +850,6 @@ class Muon(torch.optim.Optimizer):
485
  qk_logits=qk_logits)
486
  return
487
 
488
- # For simplicity, we use distributed Muon for small parameters
489
- # whose number of elements is below a threshold.
490
  for n, p in zip(names, params):
491
  if p is None or p.grad is None:
492
  continue
@@ -494,23 +857,28 @@ class Muon(torch.optim.Optimizer):
494
  if all(
495
  isinstance(placement, Replicate)
496
  for placement in p.placements):
 
 
 
497
  param_tensors.append(p)
498
  name_tensors.append(n)
499
- elif p.data.numel() <= self.small_param_numel_threshold:
500
- param_dtensors_small.append(p)
501
- name_dtensors_small.append(n)
502
  else:
 
 
 
 
503
  param_dtensors.append(p)
504
  name_dtensors.append(n)
505
  elif isinstance(p.data, torch.Tensor):
 
 
506
  param_tensors.append(p)
507
  name_tensors.append(n)
508
  else:
509
  raise TypeError(f"Unsupported parameter type: {type(p.data)}")
510
 
511
- logger.debug(
512
- f"[Muon] {len(param_dtensors)} DTensors, {len(param_tensors)} Tensors, "
513
- f"{len(param_dtensors_small)} Small DTensors")
514
 
515
  def group_dtensors(dtensors, names):
516
  # To support different placements, we group parameters by placements
@@ -526,21 +894,6 @@ class Muon(torch.optim.Optimizer):
526
  p.device_mesh])][1].append(p)
527
  return placement_to_params
528
 
529
- if len(param_dtensors_small) > 0:
530
- if not dist.is_initialized():
531
- raise RuntimeError(
532
- "Parallel Muon requires torch.distributed to be initialized."
533
- )
534
-
535
- self.distributed_muon(
536
- params=param_dtensors_small,
537
- names=name_dtensors_small,
538
- group=group,
539
- lr=lr,
540
- weight_decay=weight_decay,
541
- qk_logits=qk_logits,
542
- )
543
-
544
  if len(param_dtensors) > 0:
545
  if not dist.is_initialized():
546
  raise RuntimeError(
@@ -548,7 +901,26 @@ class Muon(torch.optim.Optimizer):
548
  )
549
 
550
  dtensor_group = group_dtensors(param_dtensors, name_dtensors)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
551
  for _, (names, params) in dtensor_group.items():
 
 
552
  self.parallel(
553
  names,
554
  params,
@@ -556,7 +928,10 @@ class Muon(torch.optim.Optimizer):
556
  lr=lr,
557
  weight_decay=weight_decay,
558
  qk_logits=qk_logits,
 
559
  )
 
 
560
 
561
  if len(param_tensors) > 0:
562
  self.base(
@@ -568,6 +943,33 @@ class Muon(torch.optim.Optimizer):
568
  qk_logits=qk_logits,
569
  )
570
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
571
  @torch.no_grad
572
  def step(self, closure=None, qk_logits=None):
573
  """Perform a single optimization step.
@@ -585,10 +987,82 @@ class Muon(torch.optim.Optimizer):
585
  with torch.enable_grad():
586
  loss = closure()
587
 
588
- for group in self.param_groups:
 
 
 
 
 
 
 
589
  if group["use_muon"]:
 
 
590
  self._step_muon(group, qk_logits=qk_logits)
591
  else:
 
 
 
592
  step_adamw(self.state, group)
593
 
 
 
 
 
 
 
 
594
  return loss
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
 
11
  from .adamw import step_adamw
12
  from .async_utils import run_pipeline
13
+ from .core import (_muon_state, adjust_lr_for_muon, batch_pre_ortho,
14
+ get_default_muon_param_groups, is_expert_param, update_p)
15
+ from .cpu_offload import CPUOffloadPool
16
  from .distributed.utils import (_is_shard, construct_shard_mesh,
17
  get_slices_of_dtensor)
18
  from .newton_schulz import (COMM_DTYPE, DEFAULT_CHUNK_SIZE_RATIO,
19
+ _zeropower_via_newtonschulz5,
20
+ zeropower_via_newtonschulz5,
21
+ zeropower_via_newtonschulz5_batched)
22
+ from .pipeline import muon_chunk_pipeline, prelaunch_first_gather
23
  from .qk_clip import compute_scales, get_qk_clip_info, qk_clip
24
 
25
  logger = logging.getLogger(__name__)
 
48
  expanded_params = []
49
 
50
  for n, p in zip(names, params):
51
+ is_expert = is_expert_param(n, expert_keys)
52
  is_dtensor = isinstance(p.data, DTensor)
53
 
54
+ if is_expert:
55
+ if is_dtensor:
56
+ logger.debug(
57
+ "[expand_expert] %s: expert DTensor, shape=%s, "
58
+ "placements=%s, mesh=%s, local_shape=%s", n, p.shape,
59
+ p.placements, p.device_mesh.mesh_dim_names,
60
+ p.to_local().shape)
61
+ else:
62
+ logger.debug(
63
+ "[expand_expert] %s: expert plain tensor, shape=%s", n,
64
+ p.data.shape)
65
+
66
  if not is_expert:
67
  assert p.data.ndim <= 2, (
68
  f"Param {n} has ndim={p.data.ndim} but does not match "
 
183
  Use shard ranks * DEFAULT_CHUNK_SIZE_RATIO when -1 is specified.
184
  use_distributed_muon: Use distributed muon by Liu et al. (2024).
185
  For testing purpose only.
 
186
  expert_keys: List of strings to identify expert-parallel parameters.
187
  If any key appears in a parameter's name, its outermost
188
  dimension is treated as the expert dimension and expanded
 
207
  warmup_step=5,
208
  chunk_size=-1,
209
  use_distributed_muon=False,
210
+ expert_keys=None,
211
+ cpu_offload=False):
212
  defaults = dict(
213
  lr=lr,
214
  weight_decay=weight_decay,
 
242
  self.warmup_step = warmup_step
243
  self.chunk_size = chunk_size
244
  self.use_distributed_muon = use_distributed_muon
 
245
  self.expert_keys = expert_keys
246
+ self.cpu_offload = cpu_offload
247
+ self._cpu_offload_pool = CPUOffloadPool() if cpu_offload else None
248
+ self._offload_initialized = False
249
+ self._parallel_cache: dict[tuple[str, ...], dict] = {}
250
+ self._expert_expand_cache: dict[tuple[int, ...], dict] = {}
251
 
252
  def _calc_flops(self, G, steps):
253
  assert len(G.shape) == 2
 
351
  if g is None:
352
  continue
353
 
354
+ u = zeropower_via_newtonschulz5(g.to(COMM_DTYPE),
355
+ steps=group["ns_steps"])
356
 
357
  adjusted_lr = adjust_lr_for_muon(lr, p.shape)
358
  update_p(p, u, lr, adjusted_lr, weight_decay)
 
373
  weight_decay: float,
374
  qk_logits: list[torch.Tensor | DTensor] | None,
375
  ):
376
+ """Batched Distributed Muon for testing/correctness verification only.
377
 
378
+ Uses all-gather to reconstruct full tensors, computes Newton-Schulz on
379
+ the full grad, then slices back to local shards. This is simpler but
380
+ slower than the parallel pipeline (all2all) path, so it serves as a
381
+ reference implementation for verifying correctness.
382
+ """
383
+ with record_function("distributed_muon"):
384
+ # Momentum is already applied by _step_muon before this method.
385
+ ns_steps = group["ns_steps"]
 
 
 
 
 
 
 
 
 
 
 
386
 
387
+ # Separate plain tensors (no communication) from DTensors.
388
+ plain_names, plain_params = [], []
389
+ dtensor_names, dtensor_params = [], []
390
+ for n, p in zip(names, params):
391
+ if p.grad is None:
392
+ continue
393
+ if isinstance(p.data, DTensor):
394
+ dtensor_names.append(n)
395
+ dtensor_params.append(p)
396
+ else:
397
+ plain_names.append(n)
398
+ plain_params.append(p)
399
+
400
+ # Process plain tensors per-param (no communication).
401
+ for n, p in zip(plain_names, plain_params):
402
+ u = _zeropower_via_newtonschulz5(p.grad.to(COMM_DTYPE),
403
+ steps=ns_steps)
404
+ adjusted_lr = adjust_lr_for_muon(lr, p.shape)
405
+ update_p(p, u, lr, adjusted_lr, weight_decay)
406
+
407
+ qk_clip_state = get_qk_clip_info(self.clip_config, n,
408
+ qk_logits)
409
+ scales_full = compute_scales(
410
+ p, qk_clip_state) if qk_clip_state is not None else None
411
+ if scales_full is not None:
412
+ qk_clip(p, scales_full, qk_clip_state.head_dim)
413
+
414
+ if not dtensor_params:
415
+ return
416
+
417
+ # Group DTensors by (placements, mesh) for batched all-gather.
418
+ placement_groups: dict[tuple,
419
+ tuple[list,
420
+ list]] = defaultdict(lambda: ([], []))
421
+ for n, p in zip(dtensor_names, dtensor_params):
422
+ key = (p.placements, p.device_mesh)
423
+ placement_groups[key][0].append(n)
424
+ placement_groups[key][1].append(p)
425
+
426
+ logger.info(
427
+ "distributed_muon: %d placement groups, %d total dtensors",
428
+ len(placement_groups), len(dtensor_params))
429
+
430
+ for (placements, mesh), (grp_names,
431
+ grp_params) in placement_groups.items():
432
+ shard_mesh, shard_pg, shard_placements = construct_shard_mesh(
433
+ placements, mesh)
434
+ rank = dist.get_rank(shard_pg)
435
+ world_size = dist.get_world_size(shard_pg)
436
+
437
+ logger.info(" group: %d params, placements=%s, world_size=%d",
438
+ len(grp_params), placements, world_size)
439
+
440
+ # Separate params that can be batched (all shard dims evenly
441
+ # divisible) from those needing per-param full_tensor
442
+ # (e.g. MoE gate weights with fewer rows than shard ranks).
443
+ # all_gather_into_tensor requires equal buffer sizes across
444
+ # ranks, so uneven splits must use DTensor full_tensor().
445
+ batch_names, batch_params = [], []
446
+ single_names, single_params = [], []
447
+ for n, p in zip(grp_names, grp_params):
448
+ even = all(p.shape[pl.dim] %
449
+ shard_mesh.mesh.shape[dim_idx] == 0
450
+ for dim_idx, pl in enumerate(shard_placements))
451
+ if even:
452
+ batch_names.append(n)
453
+ batch_params.append(p)
454
+ else:
455
+ single_names.append(n)
456
+ single_params.append(p)
457
+
458
+ # Process uneven-split params per-param via full_tensor().
459
+ for n, p in zip(single_names, single_params):
460
+ with record_function("distributed_muon::newton_schulz"):
461
+ g_full = p.grad.full_tensor().to(COMM_DTYPE)
462
+ u_full = _zeropower_via_newtonschulz5(g_full,
463
+ steps=ns_steps)
464
+ del g_full
465
+ with record_function("distributed_muon::update"):
466
+ adjusted_lr = adjust_lr_for_muon(lr, p.shape)
467
+ p._local_tensor.mul_(1 - lr * weight_decay)
468
+ local_indices = get_slices_of_dtensor(
469
+ p, rank, shard_mesh, shard_placements)
470
+ u_local = u_full[local_indices]
471
+ p._local_tensor.add_(u_local, alpha=-adjusted_lr)
472
+ del u_full
473
+
474
+ qk_clip_state = get_qk_clip_info(
475
+ self.clip_config, n, qk_logits)
476
+ scales_full = compute_scales(
477
+ p, qk_clip_state
478
+ ) if qk_clip_state is not None else None
479
+ if scales_full is not None:
480
+ ratio = p.shape[0] // scales_full.shape[0]
481
+ idx0 = local_indices[0]
482
+ if isinstance(idx0, slice):
483
+ start = idx0.start or 0
484
+ idx0 = torch.arange(start,
485
+ idx0.stop,
486
+ device=scales_full.device)
487
+ row_scales = scales_full[idx0 // ratio]
488
+ p._local_tensor.mul_(row_scales.view(-1, 1))
489
+
490
+ if not batch_params:
491
+ continue
492
 
493
+ logger.info(" batched=%d, single=%d", len(batch_params),
494
+ len(single_params))
495
+
496
+ # Concat all local grad shards into a single flat buffer.
497
+ with record_function("distributed_muon::gather"):
498
+ grad_locals = [
499
+ p.grad.to_local().to(COMM_DTYPE).flatten()
500
+ for p in batch_params
501
+ ]
502
+ numels = [g.numel() for g in grad_locals]
503
+ grad_concat = torch.cat(grad_locals)
504
+ del grad_locals
505
+
506
+ # Single all-gather (replaces N separate full_tensor).
507
+ grad_gathered = torch.empty(
508
+ grad_concat.numel() * world_size,
509
+ dtype=COMM_DTYPE,
510
+ device="cuda",
511
+ )
512
+ dist.all_gather_into_tensor(grad_gathered,
513
+ grad_concat,
514
+ group=shard_pg)
515
+
516
+ total_numel = grad_concat.numel()
517
+ del grad_concat
518
+
519
+ # Precompute per-param offsets within the concat buffer.
520
+ offsets = []
521
+ off = 0
522
+ for ne in numels:
523
+ offsets.append(off)
524
+ off += ne
525
+
526
+ # Per-param: reconstruct full grad → NS → local update.
527
+ for i, (n, p) in enumerate(zip(batch_names, batch_params)):
528
+ with record_function("distributed_muon::newton_schulz"):
529
+ g_full = torch.empty(p.shape,
530
+ dtype=COMM_DTYPE,
531
+ device="cuda")
532
+ for r in range(world_size):
533
+ r_start = r * total_numel + offsets[i]
534
+ shard = grad_gathered[r_start:r_start + numels[i]]
535
+ indices = get_slices_of_dtensor(
536
+ p, r, shard_mesh, shard_placements)
537
+ g_full[indices] = shard.reshape(
538
+ g_full[indices].shape)
539
+
540
+ u_full = _zeropower_via_newtonschulz5(g_full,
541
+ steps=ns_steps)
542
+ del g_full
543
+
544
+ with record_function("distributed_muon::update"):
545
+ adjusted_lr = adjust_lr_for_muon(lr, p.shape)
546
+ p._local_tensor.mul_(1 - lr * weight_decay)
547
+ local_indices = get_slices_of_dtensor(
548
+ p, rank, shard_mesh, shard_placements)
549
+ u_local = u_full[local_indices]
550
+ p._local_tensor.add_(u_local, alpha=-adjusted_lr)
551
+ del u_full
552
+
553
+ qk_clip_state = get_qk_clip_info(
554
+ self.clip_config, n, qk_logits)
555
+ scales_full = compute_scales(
556
+ p, qk_clip_state
557
+ ) if qk_clip_state is not None else None
558
+ if scales_full is not None:
559
+ ratio = p.shape[0] // scales_full.shape[0]
560
+ idx0 = local_indices[0]
561
+ if isinstance(idx0, slice):
562
+ start = idx0.start or 0
563
+ idx0 = torch.arange(start,
564
+ idx0.stop,
565
+ device=scales_full.device)
566
+ row_scales = scales_full[idx0 // ratio]
567
+ p._local_tensor.mul_(row_scales.view(-1, 1))
568
+
569
+ def _setup_parallel(self, names, params, group, qk_logits):
570
+ """Compute (or retrieve cached) parallel pipeline metadata.
571
+
572
+ Returns:
573
+ (ordered_params, param_to_state, rank, chunk_size)
574
+ """
575
+ cache_key = tuple(names)
576
 
577
+ if cache_key not in self._parallel_cache:
578
+ # First call: compute metadata and populate cache.
579
+ param_to_state, ordered_params = self.init_state_and_assign_params(
580
+ names, params, group, qk_logits)
581
 
582
+ shard_pg = param_to_state[id(ordered_params[0])].process_group
583
+ rank = dist.get_rank(group=shard_pg)
 
 
 
 
 
584
 
585
+ if self.chunk_size == -1:
586
+ shard_ranks = dist.get_world_size(shard_pg)
587
+ chunk_size = shard_ranks * DEFAULT_CHUNK_SIZE_RATIO
588
+ elif self.chunk_size > 0:
589
+ chunk_size = self.chunk_size
590
+ else:
591
+ raise ValueError(
592
+ "chunk_size must be -1 or a positive integer.")
593
+
594
+ ordered_names = [
595
+ param_to_state[id(p)].name for p in ordered_params
596
+ ]
597
+ name_to_state = {
598
+ param_to_state[id(p)].name: param_to_state[id(p)]
599
+ for p in ordered_params
600
+ }
601
+ self._parallel_cache[cache_key] = {
602
+ 'ordered_names': ordered_names,
603
+ 'name_to_state': name_to_state,
604
+ 'rank': rank,
605
+ 'chunk_size': chunk_size,
606
+ }
607
+ else:
608
+ # Cached path: rebuild param_to_state with current id(p) keys.
609
+ cache = self._parallel_cache[cache_key]
610
+ rank = cache['rank']
611
+ chunk_size = cache['chunk_size']
612
+
613
+ name_to_param = dict(zip(names, params))
614
+ ordered_params = [name_to_param[n] for n in cache['ordered_names']]
615
+
616
+ param_to_state = {}
617
+ for p, n in zip(ordered_params, cache['ordered_names']):
618
+ cached_state = cache['name_to_state'][n]
619
+ param_to_state[id(p)] = _muon_state(
620
+ worker_rank=cached_state.worker_rank,
621
+ process_group=cached_state.process_group,
622
+ rank_indices=cached_state.rank_indices,
623
+ rank_numels=cached_state.rank_numels,
624
+ name=n,
625
+ qk_clip_state=get_qk_clip_info(self.clip_config, n,
626
+ qk_logits),
627
  )
628
 
629
+ return ordered_params, param_to_state, rank, chunk_size
630
 
631
+ def parallel(self,
632
+ names,
633
+ params,
634
+ group,
635
+ lr,
636
+ weight_decay,
637
+ qk_logits,
638
+ prelaunch_gather=None):
639
  """
640
  Perform a parallel optimization step using Muon.
641
 
 
644
  interleaves multiple chunks so that communication and computation
645
  overlap across chunks (the same overlap previously achieved by the
646
  warmup + main-loop index scheduling).
647
+
648
+ If ``prelaunch_gather`` is provided, it is passed to the first
649
+ chunk's generator to skip re-launching the already in-flight
650
+ A2A gather.
651
  """
652
 
653
  # Momentum is already applied by _step_muon before this method.
654
 
655
+ ordered_params, param_to_state, rank, chunk_size = (
656
+ self._setup_parallel(names, params, group, qk_logits))
 
 
 
 
 
 
 
 
 
 
 
 
 
657
 
658
  def pipelines():
659
+ first = True
660
  for start in range(0, len(ordered_params), chunk_size):
661
  chunk = ordered_params[start:start + chunk_size]
662
  if chunk:
663
+ kwargs = dict(
664
  params=chunk,
665
  param_to_state=param_to_state,
666
  rank=rank,
 
669
  weight_decay=weight_decay,
670
  none_grad=group["none_grad"],
671
  )
672
+ if first and prelaunch_gather is not None:
673
+ kwargs['prelaunch_gather'] = prelaunch_gather
674
+ first = False
675
+ yield muon_chunk_pipeline(**kwargs)
676
 
 
 
677
  with record_function("muon::pipeline"):
678
  run_pipeline(pipelines(), max_concurrent=self.warmup_step + 1)
679
 
 
685
  names = group["names"]
686
 
687
  # Apply momentum to all params before routing/expansion.
688
+ # Batched using _foreach_* ops (compiled, fullgraph=True).
689
  with record_function("muon::momentum"):
690
+ active_params = [p for p in params if p.grad is not None]
691
+ if active_params:
692
+ # Ensure momentum buffers exist (avoid zeros_like when already present).
693
+ for p in active_params:
694
+ if "momentum_buffer" not in self.state[p]:
695
+ self.state[p]["momentum_buffer"] = torch.zeros_like(
696
+ p.grad)
697
+
698
+ # Extract local tensors for compiled batch function.
699
+ local_grads = [
700
+ p.grad._local_tensor
701
+ if isinstance(p.grad, DTensor) else p.grad
702
+ for p in active_params
703
+ ]
704
+ local_bufs = [
705
+ self.state[p]["momentum_buffer"]._local_tensor
706
+ if isinstance(self.state[p]["momentum_buffer"], DTensor)
707
+ else self.state[p]["momentum_buffer"]
708
+ for p in active_params
709
+ ]
710
+
711
+ # Wrap momentum as tensor for torch.compile.
712
+ batch_pre_ortho(local_grads, local_bufs,
713
+ torch.tensor(momentum), group["nesterov"])
714
+
715
+ # For non-nesterov, the result is the momentum buffer.
716
+ if not group["nesterov"]:
717
+ for p in active_params:
718
+ p.grad = self.state[p]["momentum_buffer"]
719
+
720
+ # Identify batched experts for deferred NS.
721
+ # Detection is cheap (condition checks only); actual NS compute is
722
+ # deferred so it can overlap with the first chunk's A2A gather.
723
+ deferred_expert_work = []
724
+ if self.expert_keys:
725
+ batched_expert_indices = []
726
+ for i, (n, p) in enumerate(zip(names, params)):
727
+ if not (is_expert_param(n, self.expert_keys)
728
+ and p.grad is not None):
729
  continue
730
+ # Eligible: plain tensor, or DTensor with no non-dim-0 shards.
731
+ if isinstance(p.data, DTensor):
732
+ has_tp = any(
733
+ _is_shard(pl) and pl.dim != 0 for pl in p.placements)
734
+ if has_tp:
735
+ continue
736
+ batched_expert_indices.append(i)
737
+
738
+ if batched_expert_indices:
739
+ # Save refs for deferred NS; free grads from param list.
740
+ for i in batched_expert_indices:
741
+ p = params[i]
742
+ g = p.grad
743
+ local_g = (g._local_tensor
744
+ if isinstance(g, DTensor) else g)
745
+ local_data = (p.data._local_tensor if isinstance(
746
+ p.data, DTensor) else p.data)
747
+ deferred_expert_work.append((local_data, local_g))
748
+ p.grad = None
749
+
750
+ # Remove batched experts from lists before expansion.
751
+ keep = sorted(
752
+ set(range(len(params))) - set(batched_expert_indices))
753
+ names = [names[i] for i in keep]
754
+ params = [params[i] for i in keep]
755
+
756
+ def _run_deferred_expert_ns():
757
+ """Execute deferred batched expert NS."""
758
+ if not deferred_expert_work:
759
+ return
760
+ with record_function("muon::batched_expert_ns"):
761
+ ns_steps = group["ns_steps"]
762
+ for local_data, local_g in deferred_expert_work:
763
+ u = zeropower_via_newtonschulz5_batched(
764
+ local_g.to(COMM_DTYPE), steps=ns_steps)
765
+ adjusted_lr = adjust_lr_for_muon(lr, local_g.shape[1:])
766
+ local_data.mul_(1 - lr * weight_decay)
767
+ local_data.add_(u, alpha=-adjusted_lr)
768
 
769
  # Expand expert params by splitting on dim 0.
770
+ logger.debug("[_step_muon] before expand: %d params, expert_keys=%s",
771
+ len(params), self.expert_keys)
772
+ if self.expert_keys:
773
+ cache_key = tuple(id(p) for p in params)
774
+ cache = self._expert_expand_cache.get(cache_key)
775
+
776
+ if cache is None:
777
+ # Cold path: full expansion + build cache metadata.
778
+ exp_names, exp_params = _expand_expert_params(
779
+ names, params, self.expert_keys)
780
+
781
+ # Build per-expert-group info for hot-path grad updates.
782
+ grad_info = []
783
+ exp_idx = 0
784
+ for orig_idx, (n, p) in enumerate(zip(names, params)):
785
+ if not is_expert_param(n, self.expert_keys):
786
+ exp_idx += 1
787
+ continue
788
+
789
+ is_dt = isinstance(p.data, DTensor)
790
+ num_experts = (p.to_local() if is_dt else p.data).shape[0]
791
+
792
+ # Detect TP mesh from the first expanded expert param.
793
+ tp_mesh = None
794
+ tp_pls = None
795
+ sample = exp_params[exp_idx]
796
+ if isinstance(sample.data, DTensor):
797
+ tp_mesh = sample.data.device_mesh
798
+ tp_pls = list(sample.data.placements)
799
+
800
+ grad_info.append((orig_idx, num_experts, exp_idx, is_dt,
801
+ tp_mesh, tp_pls))
802
+ exp_idx += num_experts
803
+
804
+ self._expert_expand_cache[cache_key] = {
805
+ 'names': exp_names,
806
+ 'params': exp_params,
807
+ 'grad_info': grad_info,
808
+ }
809
+ names, params = exp_names, exp_params
810
+ else:
811
+ # Hot path: reuse cached params, only update expert grads.
812
+ for (orig_idx, num_experts, exp_start, is_dt, tp_mesh,
813
+ tp_pls) in cache['grad_info']:
814
+ p = params[orig_idx]
815
+ g = p.grad
816
+ local_grad = (g.to_local()
817
+ if is_dt and isinstance(g, DTensor) else g)
818
+ for i in range(num_experts):
819
+ expert_p = cache['params'][exp_start + i]
820
+ sg = local_grad[i]
821
+ if tp_mesh is not None:
822
+ expert_p.grad = DTensor.from_local(
823
+ sg, device_mesh=tp_mesh, placements=tp_pls)
824
+ else:
825
+ expert_p.grad = sg
826
+ p.grad = None
827
+
828
+ names = cache['names']
829
+ params = cache['params']
830
+ else:
831
+ names, params = _expand_expert_params(names, params,
832
+ self.expert_keys)
833
+ logger.debug("[_step_muon] after expand: %d params", len(params))
834
 
835
  param_dtensors = []
836
  name_dtensors = []
 
838
  param_tensors = []
839
  name_tensors = []
840
 
841
+ # distributed_muon is a reference implementation for testing only.
842
+ # The parallel pipeline (all2all) path below is the production path.
 
843
  if self.use_distributed_muon:
844
+ _run_deferred_expert_ns()
845
  self.distributed_muon(names=names,
846
  params=params,
847
  group=group,
 
850
  qk_logits=qk_logits)
851
  return
852
 
 
 
853
  for n, p in zip(names, params):
854
  if p is None or p.grad is None:
855
  continue
 
857
  if all(
858
  isinstance(placement, Replicate)
859
  for placement in p.placements):
860
+ logger.debug(
861
+ "[route] %s → base (DTensor all-Replicate), "
862
+ "shape=%s, placements=%s", n, p.shape, p.placements)
863
  param_tensors.append(p)
864
  name_tensors.append(n)
 
 
 
865
  else:
866
+ logger.debug(
867
+ "[route] %s → parallel (DTensor), shape=%s, "
868
+ "placements=%s, mesh=%s", n, p.shape, p.placements,
869
+ p.device_mesh.mesh_dim_names)
870
  param_dtensors.append(p)
871
  name_dtensors.append(n)
872
  elif isinstance(p.data, torch.Tensor):
873
+ logger.debug("[route] %s → base (plain tensor), shape=%s", n,
874
+ p.data.shape)
875
  param_tensors.append(p)
876
  name_tensors.append(n)
877
  else:
878
  raise TypeError(f"Unsupported parameter type: {type(p.data)}")
879
 
880
+ logger.debug(f"[Muon] {len(param_dtensors)} DTensors → parallel, "
881
+ f"{len(param_tensors)} Tensors → base")
 
882
 
883
  def group_dtensors(dtensors, names):
884
  # To support different placements, we group parameters by placements
 
894
  p.device_mesh])][1].append(p)
895
  return placement_to_params
896
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
897
  if len(param_dtensors) > 0:
898
  if not dist.is_initialized():
899
  raise RuntimeError(
 
901
  )
902
 
903
  dtensor_group = group_dtensors(param_dtensors, name_dtensors)
904
+
905
+ # Pre-launch the first chunk's A2A gather so that the NCCL
906
+ # communication overlaps with the (deferred) batched expert NS
907
+ # compute on the default CUDA stream.
908
+ prelaunch = None
909
+ if deferred_expert_work:
910
+ first_names, first_params = next(iter(dtensor_group.values()))
911
+ ordered, pts, rnk, csz = self._setup_parallel(
912
+ first_names, first_params, group, qk_logits)
913
+ first_chunk = ordered[:csz]
914
+ if first_chunk:
915
+ prelaunch = prelaunch_first_gather(first_chunk, pts, rnk,
916
+ group["none_grad"])
917
+
918
+ _run_deferred_expert_ns()
919
+
920
+ first_group = True
921
  for _, (names, params) in dtensor_group.items():
922
+ pg = prelaunch if first_group else None
923
+ first_group = False
924
  self.parallel(
925
  names,
926
  params,
 
928
  lr=lr,
929
  weight_decay=weight_decay,
930
  qk_logits=qk_logits,
931
+ prelaunch_gather=pg,
932
  )
933
+ else:
934
+ _run_deferred_expert_ns()
935
 
936
  if len(param_tensors) > 0:
937
  self.base(
 
943
  qk_logits=qk_logits,
944
  )
945
 
946
+ def _register_states_for_offload(self):
947
+ """Register all optimizer state tensors with the CPU offload pool.
948
+
949
+ Called once after the first step when states have been lazily created.
950
+ Offloads all param states (momentum buffers for Muon, moment1/moment2
951
+ for AdamW) to free GPU memory between steps.
952
+ """
953
+ pool = self._cpu_offload_pool
954
+ tracked = 0
955
+ for group in self.param_groups:
956
+ for p in group["params"]:
957
+ if p not in self.state:
958
+ continue
959
+ state = self.state[p]
960
+ if group.get("use_muon", False):
961
+ if "momentum_buffer" in state:
962
+ pool.track(state["momentum_buffer"])
963
+ tracked += 1
964
+ else:
965
+ if "moment1" in state:
966
+ pool.track(state["moment1"])
967
+ if "moment2" in state:
968
+ pool.track(state["moment2"])
969
+ tracked += 1
970
+ logger.info("[CPUOffload] Registered %d param states for offload",
971
+ tracked)
972
+
973
  @torch.no_grad
974
  def step(self, closure=None, qk_logits=None):
975
  """Perform a single optimization step.
 
987
  with torch.enable_grad():
988
  loss = closure()
989
 
990
+ # H2D: reload optimizer states from CPU before computation.
991
+ if self.cpu_offload and self._offload_initialized:
992
+ self._cpu_offload_pool.reload()
993
+
994
+ logger.debug("[Muon.step] expert_keys=%s, %d param groups",
995
+ self.expert_keys, len(self.param_groups))
996
+
997
+ for i, group in enumerate(self.param_groups):
998
  if group["use_muon"]:
999
+ logger.debug("[Muon.step] group %d: use_muon=True, %d params",
1000
+ i, len(group["params"]))
1001
  self._step_muon(group, qk_logits=qk_logits)
1002
  else:
1003
+ logger.debug(
1004
+ "[Muon.step] group %d: use_muon=False (AdamW), %d params",
1005
+ i, len(group["params"]))
1006
  step_adamw(self.state, group)
1007
 
1008
+ # D2H: offload optimizer states to CPU after computation.
1009
+ if self.cpu_offload:
1010
+ if not self._offload_initialized:
1011
+ self._register_states_for_offload()
1012
+ self._offload_initialized = True
1013
+ self._cpu_offload_pool.offload()
1014
+
1015
  return loss
1016
+
1017
+ # ------------------------------------------------------------------
1018
+ # Checkpoint support for cpu_offload
1019
+ # ------------------------------------------------------------------
1020
+
1021
+ def state_dict(self) -> dict:
1022
+ """Return optimizer state dict, reloading offloaded states first.
1023
+
1024
+ When ``cpu_offload=True``, optimizer state tensors have their GPU
1025
+ storage freed (``resize_(0)``) between steps. We reload them,
1026
+ snapshot the state dict, then re-offload so the optimizer stays
1027
+ in the expected post-step state. The returned dict holds cloned
1028
+ tensors so they remain valid after the re-offload frees the
1029
+ originals' GPU storage.
1030
+ """
1031
+ if self.cpu_offload and self._offload_initialized:
1032
+ self._cpu_offload_pool.reload()
1033
+ torch.cuda.current_stream().synchronize()
1034
+ sd = super().state_dict()
1035
+ if self.cpu_offload and self._offload_initialized:
1036
+ # Clone state tensors so the returned dict survives re-offload
1037
+ # (which frees GPU storage on the originals via resize_(0)).
1038
+ for k in sd["state"]:
1039
+ sd["state"][k] = {
1040
+ sk: sv.clone() if isinstance(sv, torch.Tensor) else sv
1041
+ for sk, sv in sd["state"][k].items()
1042
+ }
1043
+ self._cpu_offload_pool.offload()
1044
+ return sd
1045
+
1046
+ def load_state_dict(self, state_dict: dict) -> None:
1047
+ """Load optimizer state dict, then offload states if needed.
1048
+
1049
+ After ``super().load_state_dict()`` populates GPU tensors, we
1050
+ re-register them with the offload pool and offload to CPU so the
1051
+ optimizer is in the same post-step state (GPU storage freed).
1052
+ """
1053
+ # If states were offloaded, reload first so storage sizes are
1054
+ # correct for super().load_state_dict() to overwrite.
1055
+ if self.cpu_offload and self._offload_initialized:
1056
+ self._cpu_offload_pool.reload()
1057
+ torch.cuda.current_stream().synchronize()
1058
+
1059
+ super().load_state_dict(state_dict)
1060
+
1061
+ if self.cpu_offload:
1062
+ # Re-create the offload pool since state tensors may be new
1063
+ # objects after load_state_dict.
1064
+ self._cpu_offload_pool = CPUOffloadPool()
1065
+ self._offload_initialized = False
1066
+ self._register_states_for_offload()
1067
+ self._offload_initialized = True
1068
+ self._cpu_offload_pool.offload()
build/torch210-cxx11-cu126-x86_64-linux/newton_schulz.py CHANGED
@@ -1,3 +1,7 @@
 
 
 
 
1
  import torch
2
 
3
  from .matmul_transpose_triton import matmul_transpose_assign
@@ -6,21 +10,134 @@ COMM_DTYPE = torch.bfloat16
6
  DEFAULT_CHUNK_SIZE_RATIO = 4
7
 
8
 
9
- # This code snippet is a modified version adapted from the following GitHub repositories:
10
- # https://github.com/KellerJordan/Muon/blob/master/muon.py
11
- # Muon's Newton–Schulz iteration causes high variance in singular values
12
- # Idea: give each iteration its own 3 coefficients and optimize them via gradient descent.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
  @torch.no_grad()
14
- # matmul_transpose_assign from : https://github.com/nil0x9/flash-muon
15
  def _zeropower_via_newtonschulz5(G, steps):
16
  """
17
- Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a
18
- quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose
19
- of minimizing steps, it turns out to be empirically effective to keep increasing the slope at
20
- zero even beyond the point where the iteration no longer converges all the way to one everywhere
21
- on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T
22
- where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model
23
- performance at all relative to UV^T, where USV^T = G is the SVD.
 
 
 
 
 
 
 
24
  """
25
  assert len(G.shape) == 2
26
  assert G.dtype == COMM_DTYPE
@@ -28,18 +145,14 @@ def _zeropower_via_newtonschulz5(G, steps):
28
 
29
  if G.size(0) > G.size(1):
30
  X = X.T
31
- # Ensure spectral norm is at most 1
32
  X = X / (X.norm() + 1e-7)
 
 
33
  buf1 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device)
34
  buf2 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device)
35
  # Perform the NS iterations
36
- for a, b, c in [
37
- (4.0848, -6.8946, 2.9270),
38
- (3.9505, -6.3029, 2.6377),
39
- (3.7418, -5.5913, 2.3037),
40
- (2.8769, -3.1427, 1.2046),
41
- (2.8366, -3.0525, 1.2012),
42
- ]:
43
  matmul_transpose_assign(X, buf1)
44
  matmul_transpose_assign(buf1, buf2)
45
  buf1.mul_(b).add_(buf2, alpha=c)
@@ -47,4 +160,77 @@ def _zeropower_via_newtonschulz5(G, steps):
47
 
48
  if G.size(0) > G.size(1):
49
  X = X.T
 
50
  return X
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from itertools import repeat
2
+ from math import inf, sqrt
3
+
4
+ import numpy as np
5
  import torch
6
 
7
  from .matmul_transpose_triton import matmul_transpose_assign
 
10
  DEFAULT_CHUNK_SIZE_RATIO = 4
11
 
12
 
13
+ def _optimal_quintic(l, u, max_iter=1000):
14
+ """
15
+ Use the simplified Remez algorithm to find the optimal odd quintic approximant
16
+ to the constant function x -> 1 over the interval [l, u].
17
+
18
+ Returns (a, b, c) for p(x) = ax + bx^3 + cx^5 that minimizes the maximum
19
+ approximation error max_{x in [l,u]} |p(x) - 1|. Iterates by updating the
20
+ two interior equioscillation nodes q, r until convergence. Returns the
21
+ closed-form equioscillating solution when l ≈ u.
22
+
23
+ Raises ValueError if any intermediate value (a, b, c, E, q, r) is non-finite
24
+ (NaN or inf). Raises RuntimeError if convergence is not reached within
25
+ max_iter iterations.
26
+ """
27
+ assert 0 <= l <= u
28
+ if 1 - 5e-6 <= l / u:
29
+ return (15 / 8) / u, (-10 / 8) / (u**3), (3 / 8) / (u**5)
30
+ q = (3 * l + u) / 4
31
+ r = (l + 3 * u) / 4
32
+ E = inf
33
+ for _ in range(max_iter):
34
+ old_E = E
35
+ LHS = np.array([
36
+ [l, l**3, l**5, 1],
37
+ [q, q**3, q**5, -1],
38
+ [r, r**3, r**5, 1],
39
+ [u, u**3, u**5, -1],
40
+ ])
41
+ a, b, c, E = np.linalg.solve(LHS, np.ones(4))
42
+ if not np.all(np.isfinite([a, b, c, E])):
43
+ raise ValueError(f"_optimal_quintic: non-finite solve result "
44
+ f"a={a}, b={b}, c={c}, E={E}")
45
+ q, r = np.sqrt(
46
+ (-3 * b + np.array([-1, 1]) * sqrt(9 * b**2 - 20 * a * c)) /
47
+ (10 * c))
48
+ if not np.all(np.isfinite([q, r])):
49
+ raise ValueError(
50
+ f"_optimal_quintic: non-finite node update q={q}, r={r}")
51
+ if abs(old_E - E) <= 1e-15:
52
+ break
53
+ else:
54
+ raise RuntimeError(
55
+ f"_optimal_quintic: did not converge after {max_iter} iterations")
56
+ return float(a), float(b), float(c)
57
+
58
+
59
+ def _optimal_composition(l, num_iters, safety_factor_eps=0, cushion=0):
60
+ """
61
+ Compute the Polar Express coefficient series for `num_iters` quintic iterations.
62
+
63
+ Builds a sequence of per-step optimal odd quintic coefficients (a, b, c) that
64
+ compose to map singular values from [l, 1] toward 1. At each step:
65
+ 1. Solves `_optimal_quintic` on [max(l, cushion*u), u]. The `cushion`
66
+ prevents near-zero singular values from stalling by raising the effective
67
+ lower bound; if it is active (cushion*u > l), the coefficients are
68
+ rescaled so that p(l) and p(u) are centered around 1 w.r.t. the true [l, u].
69
+ 2. Deflates the coefficients by (1 + safety_factor_eps)^degree for all but the
70
+ last iteration, providing numerical headroom at the cost of a slightly slower
71
+ final convergence step.
72
+ 3. Advances the interval: l <- p(l), u <- 2 - p(l) (by symmetry of p around 1).
73
+
74
+ Returns a list of (a, b, c) tuples, one per iteration.
75
+
76
+ Reference: Amsel et al., "The Polar Express: Optimal Matrix Sign Methods and
77
+ Their Application to the Muon Algorithm", https://arxiv.org/abs/2505.16932
78
+ """
79
+ u = 1
80
+ assert 0 <= l <= u
81
+ safety_factor = 1 + safety_factor_eps
82
+ coefficients = []
83
+ for iter in range(num_iters):
84
+ a, b, c = _optimal_quintic(max(l, cushion * u), u)
85
+ if cushion * u > l:
86
+ pl = a * l + b * l**3 + c * l**5
87
+ pu = a * u + b * u**3 + c * u**5
88
+ rescaler = 2 / (pl + pu)
89
+ a *= rescaler
90
+ b *= rescaler
91
+ c *= rescaler
92
+ if iter < num_iters - 1:
93
+ a /= safety_factor
94
+ b /= safety_factor**3
95
+ c /= safety_factor**5
96
+ coefficients.append((a, b, c))
97
+ l = a * l + b * l**3 + c * l**5
98
+ u = 2 - l
99
+ return coefficients
100
+
101
+
102
+ # Precomputed Polar Express coefficients (a, b, c) for 10 quintic Newton-Schulz
103
+ # iterations. Each tuple is the minimax-optimal (Remez/equioscillation) odd quintic
104
+ # approximant to x->1 over the current singular-value interval, computed once at
105
+ # import time and reused across all optimizer steps.
106
+ #
107
+ # Contrast with the former hardcoded NS coefficients (5 fixed tuples):
108
+ # - Former: empirically tuned to maximize slope at zero; did not converge
109
+ # singular values to 1, yielding US'V^T with S' ~ Uniform(0.5, 1.5) instead
110
+ # of the true polar factor UV^T.
111
+ # - Polar Express: analytically optimal per step, adapting to the shrinking
112
+ # singular-value interval [l, u] as iterations progress; converges all
113
+ # singular values to 1, producing the exact polar factor UV^T.
114
+ _coeffs_list = _optimal_composition(l=1e-3,
115
+ num_iters=10,
116
+ safety_factor_eps=1e-2,
117
+ cushion=0.02)
118
+
119
+
120
+ # This code is adapted from:
121
+ # KellerJordan/Muon (https://github.com/KellerJordan/Muon/blob/master/muon.py)
122
+ # NoahAmsel/PolarExpress (https://github.com/NoahAmsel/PolarExpress)
123
+ # matmul_transpose_assign kernel from nil0x9/flash-muon (https://github.com/nil0x9/flash-muon)
124
  @torch.no_grad()
 
125
  def _zeropower_via_newtonschulz5(G, steps):
126
  """
127
+ Compute the polar factor of G via the Polar Express method.
128
+
129
+ Applies `steps` quintic iterations X <- aX + bX^3 + cX^5, where (a, b, c)
130
+ are the Polar Express coefficients from `_coeffs_list`. Each step is the
131
+ optimal odd quintic approximant to x -> 1 over the current singular-value
132
+ interval, minimizing the maximum approximation error (Remez / minimax criterion).
133
+ The composition maps singular values from [l, 1] to near 1, producing the
134
+ polar factor (orthogonal factor in the polar decomposition G = UP).
135
+
136
+ `_coeffs_list` is precomputed for 10 iterations (l=1e-3, safety_factor_eps=1e-2,
137
+ cushion=0.02). If `steps` exceeds 10, the final coefficient set is repeated.
138
+
139
+ Reference: Amsel et al., "The Polar Express: Optimal Matrix Sign Methods and
140
+ Their Application to the Muon Algorithm", https://arxiv.org/abs/2505.16932
141
  """
142
  assert len(G.shape) == 2
143
  assert G.dtype == COMM_DTYPE
 
145
 
146
  if G.size(0) > G.size(1):
147
  X = X.T
148
+
149
  X = X / (X.norm() + 1e-7)
150
+ hs = _coeffs_list[:steps] + list(
151
+ repeat(_coeffs_list[-1], steps - len(_coeffs_list)))
152
  buf1 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device)
153
  buf2 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device)
154
  # Perform the NS iterations
155
+ for a, b, c in hs:
 
 
 
 
 
 
156
  matmul_transpose_assign(X, buf1)
157
  matmul_transpose_assign(buf1, buf2)
158
  buf1.mul_(b).add_(buf2, alpha=c)
 
160
 
161
  if G.size(0) > G.size(1):
162
  X = X.T
163
+
164
  return X
165
+
166
+
167
+ @torch.no_grad()
168
+ def _zeropower_via_newtonschulz5_batched(G, steps):
169
+ """Batched polar factor computation for 3D (E, out, in) tensors.
170
+
171
+ Same algorithm as ``_zeropower_via_newtonschulz5`` but uses
172
+ ``torch.bmm`` / ``torch.baddbmm`` instead of the 2D Triton kernel,
173
+ processing all E expert matrices in a single batched call.
174
+ """
175
+ assert len(G.shape) == 3
176
+ assert G.dtype == COMM_DTYPE
177
+ X = G
178
+
179
+ if G.size(1) > G.size(2):
180
+ X = X.transpose(-2, -1)
181
+
182
+ # Per-expert Frobenius norm.
183
+ X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7)
184
+
185
+ hs = _coeffs_list[:steps] + list(
186
+ repeat(_coeffs_list[-1], steps - len(_coeffs_list)))
187
+ for a, b, c in hs:
188
+ buf1 = torch.bmm(X, X.transpose(-2, -1))
189
+ buf2 = torch.bmm(buf1, buf1.transpose(-2, -1))
190
+ buf1.mul_(b).add_(buf2, alpha=c)
191
+ X = torch.baddbmm(X, buf1, X, alpha=1.0, beta=a)
192
+
193
+ if G.size(1) > G.size(2):
194
+ X = X.transpose(-2, -1)
195
+
196
+ return X
197
+
198
+
199
+ _ns_per_shape: dict[tuple[int, ...], callable] = {}
200
+ _use_compile = True
201
+
202
+
203
+ def set_ns_compile(enabled: bool):
204
+ """Toggle torch.compile for Newton-Schulz iteration."""
205
+ global _use_compile
206
+ _use_compile = enabled
207
+
208
+
209
+ def zeropower_via_newtonschulz5(G, steps=5):
210
+ if not _use_compile:
211
+ return _zeropower_via_newtonschulz5(G, steps)
212
+ key = G.shape
213
+ if key not in _ns_per_shape:
214
+ _ns_per_shape[key] = torch.compile(_zeropower_via_newtonschulz5,
215
+ options={
216
+ "triton.cudagraphs": True,
217
+ "shape_padding": False
218
+ })
219
+ torch.compiler.cudagraph_mark_step_begin()
220
+ return _ns_per_shape[key](G, steps).clone()
221
+
222
+
223
+ def zeropower_via_newtonschulz5_batched(G, steps=5):
224
+ """Compile-cached batched Newton-Schulz for 3D expert tensors."""
225
+ if not _use_compile:
226
+ return _zeropower_via_newtonschulz5_batched(G, steps)
227
+ key = G.shape
228
+ if key not in _ns_per_shape:
229
+ _ns_per_shape[key] = torch.compile(
230
+ _zeropower_via_newtonschulz5_batched,
231
+ options={
232
+ "triton.cudagraphs": True,
233
+ "shape_padding": False
234
+ })
235
+ torch.compiler.cudagraph_mark_step_begin()
236
+ return _ns_per_shape[key](G, steps).clone()
build/torch210-cxx11-cu126-x86_64-linux/pipeline.py CHANGED
@@ -6,8 +6,8 @@ import torch.distributed as dist
6
  from torch.distributed.tensor import DTensor
7
  from torch.profiler import record_function
8
 
9
- from .core import _muon_state, adjust_lr_for_muon, update_p
10
- from .newton_schulz import COMM_DTYPE, _zeropower_via_newtonschulz5
11
  from .qk_clip import compute_scales
12
 
13
  logger = logging.getLogger(__name__)
@@ -45,26 +45,33 @@ def _launch_gather(
45
  else:
46
  gathered_grads[id(p)] = None
47
 
48
- # Build send buffer
49
- per_dst: list[list[torch.Tensor]] = [[] for _ in range(num_ranks)]
50
  send_counts = [0] * num_ranks
51
-
52
  for p in params:
53
  state = param_to_state[id(p)]
54
- dst = state.worker_rank
55
- assert dst < num_ranks
56
- shard_elems = state.rank_numels[rank]
57
- g = p.grad
58
- g = g.to_local().to(COMM_DTYPE).contiguous()
59
- assert g.numel() == shard_elems
60
- per_dst[dst].append(g.view(-1))
61
- send_counts[dst] += shard_elems
62
-
63
- assert any(
64
- len(v) > 0 for v in
65
- per_dst), "At least one destination rank must receive a sharded tensor"
66
- per_dst_flat = [t for dst in per_dst for t in dst]
67
- send_buf = torch.cat(per_dst_flat, dim=0)
 
 
 
 
 
 
 
 
68
 
69
  # Build recv buffer
70
  recv_counts = [0] * num_ranks
@@ -120,7 +127,8 @@ def _complete_gather(
120
 
121
  shard_view = gathered_grads[id(p)][indices]
122
  n = shard_view.numel()
123
- assert n > 0
 
124
 
125
  sg = recv_buf.narrow(0, off + inner_off, n)
126
  sg = sg.reshape(shard_view.shape)
@@ -143,7 +151,7 @@ def _compute_ns(
143
  """
144
  computed_us: dict[int, torch.Tensor | None] = {}
145
  for p in owned_params:
146
- u = _zeropower_via_newtonschulz5(gathered_grads[id(p)], ns_steps)
147
  gathered_grads[id(p)] = None # free gathered grad
148
  computed_us[id(p)] = u
149
  return computed_us
@@ -163,46 +171,47 @@ def _launch_scatter(
163
  Returns:
164
  work: Async operation handle.
165
  recv_buf: Flat receive buffer (needed by ``_complete_scatter``).
166
- scattered_us: ``{id(p): empty_local_tensor}`` for all params.
 
167
  recv_counts: Per-source-rank element counts.
168
  """
169
- # Allocate scattered-u buffers
 
 
 
170
  scattered_us: dict[int, torch.Tensor] = {}
171
  for p in params:
172
- scattered_us[id(p)] = torch.empty_like(p.to_local(), dtype=COMM_DTYPE)
 
 
173
 
174
- # Build send buffer (from computed_us on owner ranks)
175
- per_dst: list[list[torch.Tensor]] = [[] for _ in range(num_ranks)]
176
  send_counts = [0] * num_ranks
177
-
178
  if owned_params:
179
  for p in owned_params:
180
  state = param_to_state[id(p)]
181
-
182
- assert computed_us[id(p)] is not None
183
- u_full = computed_us[id(p)].to(COMM_DTYPE).contiguous()
184
-
185
- total_sent = 0
186
  for dst_rank in range(num_ranks):
187
- indices = state.rank_indices[dst_rank]
188
- su = u_full[indices].flatten()
189
-
190
- n = su.numel()
191
- assert n > 0
192
 
193
- per_dst[dst_rank].append(su)
194
- send_counts[dst_rank] += n
195
- total_sent += n
196
-
197
- assert total_sent == u_full.numel()
198
-
199
- lengths = [len(v) for v in per_dst]
200
- if all(l > 0 for l in lengths):
201
- assert all(
202
- l == lengths[0] for l in lengths
203
- ), "All destination ranks must have the same number of sharded tensor"
204
- per_dst_flat = [t for dst in per_dst for t in dst]
205
- send_buf = torch.cat(per_dst_flat, dim=0)
 
 
 
 
 
206
  else:
207
  send_buf = torch.empty(0, dtype=COMM_DTYPE, device="cuda")
208
 
@@ -218,7 +227,6 @@ def _launch_scatter(
218
  recv_counts[src] = total
219
 
220
  recv_total = sum(recv_counts)
221
- assert recv_total > 0
222
  recv_buf = torch.empty(recv_total, dtype=COMM_DTYPE, device="cuda")
223
 
224
  # Launch async all-to-all
@@ -242,7 +250,13 @@ def _complete_scatter(
242
  rank: int,
243
  scattered_us: dict[int, torch.Tensor],
244
  ) -> None:
245
- """Copy recv buffer into scattered_us (in-place)."""
 
 
 
 
 
 
246
  off = 0
247
  for src in range(len(recv_counts)):
248
  block = recv_counts[src]
@@ -255,11 +269,11 @@ def _complete_scatter(
255
  if state.worker_rank != src:
256
  continue
257
  n = state.rank_numels[rank]
258
- assert n > 0
 
259
 
260
- flat_local = recv_buf.narrow(0, off + inner_off,
261
- n).view_as(p.to_local())
262
- scattered_us[id(p)].copy_(flat_local)
263
 
264
  inner_off += n
265
 
@@ -275,23 +289,40 @@ def _update_params(
275
  lr: float,
276
  weight_decay: float,
277
  ) -> None:
278
- """Apply weight decay, Muon update, and optional QK clipping."""
279
- for p in params:
280
- state = param_to_state[id(p)]
281
- u_dtensor = DTensor.from_local(
282
- scattered_us[id(p)],
283
- placements=p.placements,
284
- device_mesh=p.device_mesh,
285
- )
286
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
287
  adjusted_lr = adjust_lr_for_muon(lr, p.shape)
288
- update_p(p, u_dtensor, lr, adjusted_lr, weight_decay)
 
 
 
289
 
290
- # QK clipping applied directly on the local tensor to
291
- # avoid DTensor sharding-propagation issues with _StridedShard.
292
- scales_full = compute_scales(
293
- p,
294
- state.qk_clip_state) if state.qk_clip_state is not None else None
 
 
 
 
 
295
  if scales_full is not None:
296
  ratio = p.shape[0] // scales_full.shape[0]
297
  idx0 = state.rank_indices[rank][0]
@@ -304,6 +335,45 @@ def _update_params(
304
  p._local_tensor.mul_(row_scales.view(-1, 1))
305
 
306
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
307
  # ======================================================================
308
  # Main generator – thin orchestrator that wires stages together.
309
  # ======================================================================
@@ -318,6 +388,7 @@ def muon_chunk_pipeline(
318
  lr: float,
319
  weight_decay: float,
320
  none_grad: bool,
 
321
  ) -> Generator[None, None, None]:
322
  """Process one chunk of parameters through the full Muon pipeline.
323
 
@@ -334,9 +405,12 @@ def muon_chunk_pipeline(
334
  runs concurrently on the NCCL stream — no separate ``comm_stream``
335
  is required.
336
 
 
 
 
337
  Yields exactly **2** times:
338
 
339
- 1. After launching async all-to-all gather.
340
  2. After launching async all-to-all scatter.
341
  """
342
  process_group = param_to_state[id(params[0])].process_group
@@ -345,15 +419,19 @@ def muon_chunk_pipeline(
345
  p for p in params if param_to_state[id(p)].worker_rank == rank
346
  ]
347
 
348
- # Stages 1-2: launch async gather.
349
- with record_function("muon::launch_gather"):
350
- work, recv_buf, gathered_grads, recv_counts = _launch_gather(
351
- params, owned_params, param_to_state, rank, num_ranks,
352
- process_group)
353
-
354
- if none_grad:
355
- for p in params:
356
- p.grad = None
 
 
 
 
357
 
358
  yield # --- YIELD 1: other chunks can launch their gather ---
359
 
 
6
  from torch.distributed.tensor import DTensor
7
  from torch.profiler import record_function
8
 
9
+ from .core import _muon_state, adjust_lr_for_muon
10
+ from .newton_schulz import COMM_DTYPE, zeropower_via_newtonschulz5
11
  from .qk_clip import compute_scales
12
 
13
  logger = logging.getLogger(__name__)
 
45
  else:
46
  gathered_grads[id(p)] = None
47
 
48
+ # Build send buffer – batch grad copies via torch.cat
49
+ # (1-2 fused kernels vs N individual narrow().copy_() calls).
50
  send_counts = [0] * num_ranks
 
51
  for p in params:
52
  state = param_to_state[id(p)]
53
+ send_counts[state.worker_rank] += state.rank_numels[rank]
54
+
55
+ total_send = sum(send_counts)
56
+ if total_send > 0:
57
+ # Group grad slices by destination rank in a single pass.
58
+ dst_to_grads = [[] for _ in range(num_ranks)]
59
+ for p in params:
60
+ state = param_to_state[id(p)]
61
+ n = state.rank_numels[rank]
62
+ if n > 0:
63
+ g = p.grad.to_local()
64
+ dst_to_grads[state.worker_rank].append(g.reshape(-1))
65
+
66
+ # Flatten in dst order and cat once.
67
+ all_slices = []
68
+ for dst in range(num_ranks):
69
+ all_slices.extend(dst_to_grads[dst])
70
+ send_buf = torch.cat(all_slices)
71
+ if send_buf.dtype != COMM_DTYPE:
72
+ send_buf = send_buf.to(COMM_DTYPE)
73
+ else:
74
+ send_buf = torch.empty(0, dtype=COMM_DTYPE, device="cuda")
75
 
76
  # Build recv buffer
77
  recv_counts = [0] * num_ranks
 
127
 
128
  shard_view = gathered_grads[id(p)][indices]
129
  n = shard_view.numel()
130
+ if n == 0:
131
+ continue
132
 
133
  sg = recv_buf.narrow(0, off + inner_off, n)
134
  sg = sg.reshape(shard_view.shape)
 
151
  """
152
  computed_us: dict[int, torch.Tensor | None] = {}
153
  for p in owned_params:
154
+ u = zeropower_via_newtonschulz5(gathered_grads[id(p)], ns_steps)
155
  gathered_grads[id(p)] = None # free gathered grad
156
  computed_us[id(p)] = u
157
  return computed_us
 
171
  Returns:
172
  work: Async operation handle.
173
  recv_buf: Flat receive buffer (needed by ``_complete_scatter``).
174
+ scattered_us: Empty dict, populated by ``_complete_scatter`` with
175
+ zero-copy views into ``recv_buf``.
176
  recv_counts: Per-source-rank element counts.
177
  """
178
+ # scattered_us is populated by _complete_scatter with zero-copy views
179
+ # into recv_buf, avoiding N empty_like allocations + N copy_ calls.
180
+ # Pre-seed entries for params whose local shard is empty (rank_numels == 0)
181
+ # so _update_params can iterate all params without KeyError.
182
  scattered_us: dict[int, torch.Tensor] = {}
183
  for p in params:
184
+ if param_to_state[id(p)].rank_numels[rank] == 0:
185
+ scattered_us[id(p)] = torch.empty_like(p.to_local(),
186
+ dtype=COMM_DTYPE)
187
 
188
+ # Build send buffer batch via torch.cat
189
+ # (1 fused kernel vs N*num_ranks individual narrow().copy_() calls).
190
  send_counts = [0] * num_ranks
 
191
  if owned_params:
192
  for p in owned_params:
193
  state = param_to_state[id(p)]
 
 
 
 
 
194
  for dst_rank in range(num_ranks):
195
+ send_counts[dst_rank] += state.rank_numels[dst_rank]
 
 
 
 
196
 
197
+ total_send = sum(send_counts)
198
+ if total_send > 0:
199
+ # Cache u_full conversions to avoid redundant .to() per dst_rank.
200
+ u_fulls = {}
201
+ for p in owned_params:
202
+ u_fulls[id(p)] = computed_us[id(p)].to(COMM_DTYPE).contiguous()
203
+
204
+ # Collect slices in dst order (matches all-to-all send layout).
205
+ all_slices = []
206
+ for dst_rank in range(num_ranks):
207
+ for p in owned_params:
208
+ state = param_to_state[id(p)]
209
+ su = u_fulls[id(p)][state.rank_indices[dst_rank]].flatten()
210
+ if su.numel() > 0:
211
+ all_slices.append(su)
212
+
213
+ send_buf = torch.cat(all_slices) if all_slices else torch.empty(
214
+ 0, dtype=COMM_DTYPE, device="cuda")
215
  else:
216
  send_buf = torch.empty(0, dtype=COMM_DTYPE, device="cuda")
217
 
 
227
  recv_counts[src] = total
228
 
229
  recv_total = sum(recv_counts)
 
230
  recv_buf = torch.empty(recv_total, dtype=COMM_DTYPE, device="cuda")
231
 
232
  # Launch async all-to-all
 
250
  rank: int,
251
  scattered_us: dict[int, torch.Tensor],
252
  ) -> None:
253
+ """Populate scattered_us with zero-copy views into recv_buf.
254
+
255
+ Instead of pre-allocating tensors and copying, we assign views directly
256
+ from ``recv_buf``. This eliminates N ``empty_like`` + N ``copy_`` calls.
257
+ The underlying storage of ``recv_buf`` is kept alive through the views
258
+ until ``scattered_us`` is cleared after ``_update_params``.
259
+ """
260
  off = 0
261
  for src in range(len(recv_counts)):
262
  block = recv_counts[src]
 
269
  if state.worker_rank != src:
270
  continue
271
  n = state.rank_numels[rank]
272
+ if n == 0:
273
+ continue
274
 
275
+ scattered_us[id(p)] = recv_buf.narrow(0, off + inner_off,
276
+ n).view_as(p.to_local())
 
277
 
278
  inner_off += n
279
 
 
289
  lr: float,
290
  weight_decay: float,
291
  ) -> None:
292
+ """Apply weight decay, Muon update, and optional QK clipping.
 
 
 
 
 
 
 
293
 
294
+ Uses batched ``_foreach_mul_`` for weight decay and batched
295
+ ``_foreach_add_`` for the Muon update, grouping parameters by
296
+ adjusted_lr to minimize kernel launches while preserving float32
297
+ precision for the alpha scaling.
298
+ """
299
+ if not params:
300
+ return
301
+
302
+ # Batched weight decay: p *= (1 - lr * wd) — single fused kernel.
303
+ p_locals = [p._local_tensor for p in params]
304
+ torch._foreach_mul_(p_locals, 1.0 - lr * weight_decay)
305
+
306
+ # Group params by adjusted_lr so _foreach_add_ can use a single
307
+ # alpha per group (preserves float32 precision for alpha scaling).
308
+ lr_groups: dict[float, tuple[list, list]] = {}
309
+ for p in params:
310
  adjusted_lr = adjust_lr_for_muon(lr, p.shape)
311
+ if adjusted_lr not in lr_groups:
312
+ lr_groups[adjusted_lr] = ([], [])
313
+ lr_groups[adjusted_lr][0].append(p._local_tensor)
314
+ lr_groups[adjusted_lr][1].append(scattered_us[id(p)])
315
 
316
+ for adjusted_lr, (p_group, u_group) in lr_groups.items():
317
+ torch._foreach_add_(p_group, u_group, alpha=-adjusted_lr)
318
+
319
+ # QK clipping – applied directly on the local tensor to
320
+ # avoid DTensor sharding-propagation issues with _StridedShard.
321
+ for p in params:
322
+ state = param_to_state[id(p)]
323
+ if state.qk_clip_state is None:
324
+ continue
325
+ scales_full = compute_scales(p, state.qk_clip_state)
326
  if scales_full is not None:
327
  ratio = p.shape[0] // scales_full.shape[0]
328
  idx0 = state.rank_indices[rank][0]
 
335
  p._local_tensor.mul_(row_scales.view(-1, 1))
336
 
337
 
338
+ # ======================================================================
339
+ # Pre-launch helper for overlapping first chunk's gather with other work.
340
+ # ======================================================================
341
+
342
+
343
+ @torch.no_grad()
344
+ def prelaunch_first_gather(
345
+ params: list[DTensor],
346
+ param_to_state: dict[int, _muon_state],
347
+ rank: int,
348
+ none_grad: bool,
349
+ ) -> tuple[dist.Work, torch.Tensor, dict[int, torch.Tensor | None], list[int]]:
350
+ """Launch the first chunk's A2A gather early for overlap with other compute.
351
+
352
+ Call this *before* expensive GPU work (e.g. batched expert NS) so that
353
+ the NCCL all-to-all runs concurrently on the NCCL stream while the
354
+ default stream executes compute.
355
+
356
+ Returns the same 4-tuple that ``_launch_gather`` produces, which should
357
+ be passed as ``prelaunch_gather`` to :func:`muon_chunk_pipeline`.
358
+ """
359
+ process_group = param_to_state[id(params[0])].process_group
360
+ num_ranks = dist.get_world_size(group=process_group)
361
+ owned_params = [
362
+ p for p in params if param_to_state[id(p)].worker_rank == rank
363
+ ]
364
+
365
+ with record_function("muon::prelaunch_gather"):
366
+ work, recv_buf, gathered_grads, recv_counts = _launch_gather(
367
+ params, owned_params, param_to_state, rank, num_ranks,
368
+ process_group)
369
+
370
+ if none_grad:
371
+ for p in params:
372
+ p.grad = None
373
+
374
+ return work, recv_buf, gathered_grads, recv_counts
375
+
376
+
377
  # ======================================================================
378
  # Main generator – thin orchestrator that wires stages together.
379
  # ======================================================================
 
388
  lr: float,
389
  weight_decay: float,
390
  none_grad: bool,
391
+ prelaunch_gather: tuple | None = None,
392
  ) -> Generator[None, None, None]:
393
  """Process one chunk of parameters through the full Muon pipeline.
394
 
 
405
  runs concurrently on the NCCL stream — no separate ``comm_stream``
406
  is required.
407
 
408
+ If ``prelaunch_gather`` is provided, the gather was already launched
409
+ by :func:`prelaunch_first_gather` and we skip launching it again.
410
+
411
  Yields exactly **2** times:
412
 
413
+ 1. After launching async all-to-all gather (or immediately if pre-launched).
414
  2. After launching async all-to-all scatter.
415
  """
416
  process_group = param_to_state[id(params[0])].process_group
 
419
  p for p in params if param_to_state[id(p)].worker_rank == rank
420
  ]
421
 
422
+ if prelaunch_gather is not None:
423
+ # Gather was pre-launched; none_grad already handled by caller.
424
+ work, recv_buf, gathered_grads, recv_counts = prelaunch_gather
425
+ else:
426
+ # Normal path: launch async gather.
427
+ with record_function("muon::launch_gather"):
428
+ work, recv_buf, gathered_grads, recv_counts = _launch_gather(
429
+ params, owned_params, param_to_state, rank, num_ranks,
430
+ process_group)
431
+
432
+ if none_grad:
433
+ for p in params:
434
+ p.grad = None
435
 
436
  yield # --- YIELD 1: other chunks can launch their gather ---
437
 
build/torch210-cxx11-cu126-x86_64-linux/qk_clip.py CHANGED
@@ -5,6 +5,8 @@ from dataclasses import dataclass
5
  import torch
6
  from torch.distributed.tensor import DTensor
7
 
 
 
8
  logger = logging.getLogger(__name__)
9
 
10
 
@@ -23,7 +25,7 @@ def parse_qk_layer(name: str) -> tuple[str | None, int]:
23
  'model.7.attn.k_proj.weight' -> ('k_proj', 7)
24
  'model.4.attn.v_proj.weight' -> (None, -1)
25
  """
26
- parts = name.split('.')
27
  if len(parts) < 3:
28
  return None, -1
29
 
@@ -100,23 +102,27 @@ def compute_scales(p, qk_clip_state):
100
  threshold = qk_clip_state.threshold
101
  logit = qk_clip_state.logit
102
 
103
- H_global = p.shape[0] // head_dim
104
- scales_full = torch.ones(H_global, device=p.data.device)
105
- scaling = 0
106
-
107
  for logit_idx, head_idx in enumerate(indices):
108
  v_ele = float(logit[logit_idx])
109
  if v_ele > threshold:
110
  new_scale = math.sqrt(threshold / v_ele)
111
- if new_scale < scales_full[head_idx]:
112
- scales_full[head_idx] = new_scale
113
  logger.info(
114
  f"[{kind}] Head {head_idx} exceeded threshold "
115
  f"(value={v_ele:.4f}, threshold={threshold:.4f}) -> applying scale={new_scale:.4f}"
116
  )
117
- scaling += 1
118
 
119
- return scales_full if scaling > 0 else None
 
 
 
 
 
 
 
120
 
121
 
122
  def qk_clip(p, scales, head_dim):
 
5
  import torch
6
  from torch.distributed.tensor import DTensor
7
 
8
+ from .core import normalize_fqn
9
+
10
  logger = logging.getLogger(__name__)
11
 
12
 
 
25
  'model.7.attn.k_proj.weight' -> ('k_proj', 7)
26
  'model.4.attn.v_proj.weight' -> (None, -1)
27
  """
28
+ parts = normalize_fqn(name).split('.')
29
  if len(parts) < 3:
30
  return None, -1
31
 
 
102
  threshold = qk_clip_state.threshold
103
  logit = qk_clip_state.logit
104
 
105
+ # Check if any head exceeds threshold before allocating.
106
+ head_scales = {}
 
 
107
  for logit_idx, head_idx in enumerate(indices):
108
  v_ele = float(logit[logit_idx])
109
  if v_ele > threshold:
110
  new_scale = math.sqrt(threshold / v_ele)
111
+ if head_idx not in head_scales or new_scale < head_scales[head_idx]:
112
+ head_scales[head_idx] = new_scale
113
  logger.info(
114
  f"[{kind}] Head {head_idx} exceeded threshold "
115
  f"(value={v_ele:.4f}, threshold={threshold:.4f}) -> applying scale={new_scale:.4f}"
116
  )
 
117
 
118
+ if not head_scales:
119
+ return None
120
+
121
+ H_global = p.shape[0] // head_dim
122
+ scales_full = torch.ones(H_global, device=p.data.device)
123
+ for head_idx, scale in head_scales.items():
124
+ scales_full[head_idx] = scale
125
+ return scales_full
126
 
127
 
128
  def qk_clip(p, scales, head_dim):
build/torch210-cxx11-cu128-x86_64-linux/_ops.py CHANGED
@@ -1,9 +1,9 @@
1
  import torch
2
- from . import _optimizer_7aef62f_dirty
3
- ops = torch.ops._optimizer_7aef62f_dirty
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
- return f"_optimizer_7aef62f_dirty::{op_name}"
 
1
  import torch
2
+ from . import _optimizer_5b58933_dirty
3
+ ops = torch.ops._optimizer_5b58933_dirty
4
 
5
def add_op_namespace_prefix(op_name: str):
    """Return *op_name* qualified with this extension's op namespace."""
    return "_optimizer_5b58933_dirty::" + op_name
build/torch210-cxx11-cu128-x86_64-linux/{_optimizer_7aef62f_dirty.abi3.so → _optimizer_5b58933_dirty.abi3.so} RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:4919c48c77c6223dbf668f1461bcec175ef1bd6ea4cec8c2509de12ca7200a62
3
  size 2004144
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1abfa69cd254e0000246a074c0bfa53c2e72bb53cc5fa8216275295cd021c57a
3
  size 2004144
build/torch210-cxx11-cu128-x86_64-linux/adamw.py CHANGED
@@ -1,8 +1,12 @@
 
1
  from collections import defaultdict
2
  from typing import cast
3
 
4
  import torch
5
  from torch.distributed.tensor import DTensor
 
 
 
6
 
7
 
8
  def fused_adamw(
@@ -72,54 +76,72 @@ def fused_adamw(
72
  )
73
 
74
 
75
- def step_adamw_params(optimizer_state, params, group):
76
- """Run fused AdamW on a list of parameters sharing the same placement.
 
77
 
78
- Args:
79
- optimizer_state: The optimizer's state dict (self.state in Muon).
80
- params: List of parameters to update.
81
- group: Parameter group dict with lr, adamw_betas, adamw_eps, weight_decay.
82
- """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
83
  params_with_grads = []
84
  grads = []
85
  moment1 = []
86
  moment2 = []
87
- max_exp_avg_sqs = []
88
  state_steps = []
89
- lr = group["lr"]
90
- beta1, beta2 = group["adamw_betas"]
91
- eps = group["adamw_eps"]
92
- weight_decay = group["weight_decay"]
93
 
94
  for p in params:
95
  g = p.grad
96
  if g is None:
97
  continue
98
  state = optimizer_state[p]
99
- params_with_grads.append(p)
100
- grads.append(g)
101
  if "step" not in state:
102
- state["step"] = (torch.zeros((),
103
- dtype=torch.float32,
104
- device=p.device))
105
  state["moment1"] = torch.zeros_like(g)
106
  state["moment2"] = torch.zeros_like(g)
107
- moment1.append(state["moment1"])
108
- moment2.append(state["moment2"])
109
  if not isinstance(state["step"], torch.Tensor):
110
- step_tensor = torch.tensor(state["step"],
111
- dtype=torch.float32,
112
- device=p.device)
113
- else:
114
- step_tensor = state["step"]
115
- state_steps.append(step_tensor)
 
 
 
 
 
 
116
 
117
  fused_adamw(
118
  params_with_grads,
119
  grads,
120
  moment1,
121
  moment2,
122
- max_exp_avg_sqs,
123
  state_steps,
124
  amsgrad=False,
125
  beta1=beta1,
@@ -131,24 +153,119 @@ def step_adamw_params(optimizer_state, params, group):
131
  )
132
 
133
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
134
  def step_adamw(optimizer_state, group):
135
  """Dispatch AdamW step, grouping parameters by type and placement.
136
 
 
 
 
137
  Args:
138
  optimizer_state: The optimizer's state dict (self.state in Muon).
139
  group: Parameter group dict.
140
  """
141
  params = group["params"]
 
142
 
143
- # group params with its type and placement
144
- placement_to_params: dict[tuple, list[torch.Tensor]] = defaultdict(list)
145
- for p in params:
146
- match p:
147
- case DTensor():
148
- placement_to_params[tuple([p.placements,
149
- p.device_mesh])].append(p)
150
- case torch.Tensor():
151
- placement_to_params[tuple([torch.Tensor, None])].append(p)
152
-
153
- for group_params in placement_to_params.values():
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
154
  step_adamw_params(optimizer_state, group_params, group)
 
1
+ import logging
2
  from collections import defaultdict
3
  from typing import cast
4
 
5
  import torch
6
  from torch.distributed.tensor import DTensor
7
+ from torch.profiler import record_function
8
+
9
+ logger = logging.getLogger(__name__)
10
 
11
 
12
  def fused_adamw(
 
76
  )
77
 
78
 
79
+ def _to_local(t):
80
+ """Unwrap DTensor to local tensor for fused ops."""
81
+ return t._local_tensor if isinstance(t, DTensor) else t
82
 
83
+
84
+ # ---------------------------------------------------------------------------
85
+ # Caches for eliminating per-step Python overhead.
86
+ #
87
+ # Placement grouping and tensor list assembly are identical every step
88
+ # (params don't change placement, moment/step tensors are the same objects
89
+ # after initialisation). We cache them keyed by id() of the param list
90
+ # stored in param_groups (stable across steps).
91
+ #
92
+ # Only gradients change each step and must be collected fresh.
93
+ # ---------------------------------------------------------------------------
94
+
95
+ # id(group["params"]) → dict[placement_key, list[param]]
96
+ _placement_cache: dict[int, dict[tuple, list]] = {}
97
+
98
+ # id(placement_group_list) → (params_local, moment1, moment2, state_steps)
99
+ _tensor_cache: dict[int, tuple[list, list, list, list]] = {}
100
+
101
+
102
+ def _step_adamw_params_slow(optimizer_state, params, group):
103
+ """Uncached fallback for the rare case where some params lack grads."""
104
  params_with_grads = []
105
  grads = []
106
  moment1 = []
107
  moment2 = []
 
108
  state_steps = []
 
 
 
 
109
 
110
  for p in params:
111
  g = p.grad
112
  if g is None:
113
  continue
114
  state = optimizer_state[p]
115
+ params_with_grads.append(_to_local(p))
116
+ grads.append(_to_local(g))
117
  if "step" not in state:
118
+ state["step"] = torch.zeros((),
119
+ dtype=torch.float32,
120
+ device=p.device)
121
  state["moment1"] = torch.zeros_like(g)
122
  state["moment2"] = torch.zeros_like(g)
123
+ moment1.append(_to_local(state["moment1"]))
124
+ moment2.append(_to_local(state["moment2"]))
125
  if not isinstance(state["step"], torch.Tensor):
126
+ state["step"] = torch.tensor(state["step"],
127
+ dtype=torch.float32,
128
+ device=p.device)
129
+ state_steps.append(state["step"])
130
+
131
+ if not params_with_grads:
132
+ return
133
+
134
+ lr = group["lr"]
135
+ beta1, beta2 = group["adamw_betas"]
136
+ eps = group["adamw_eps"]
137
+ weight_decay = group["weight_decay"]
138
 
139
  fused_adamw(
140
  params_with_grads,
141
  grads,
142
  moment1,
143
  moment2,
144
+ [],
145
  state_steps,
146
  amsgrad=False,
147
  beta1=beta1,
 
153
  )
154
 
155
 
156
def step_adamw_params(optimizer_state, params, group):
    """Run fused AdamW on a list of parameters sharing the same placement.

    After the first call, cached tensor lists (params_local, moment1,
    moment2, state_steps) are reused — only gradients are collected fresh.

    Args:
        optimizer_state: The optimizer's state dict (self.state in Muon).
        params: List of parameters to update.
        group: Parameter group dict with lr, adamw_betas, adamw_eps, weight_decay.
    """
    # Collect grads — the only thing that changes each step.
    with record_function("adamw::collect_grads"):
        grads = []
        for p in params:
            g = p.grad
            if g is None:
                # Rare: fall back to slow path that filters per-param.
                _step_adamw_params_slow(optimizer_state, params, group)
                return
            grads.append(_to_local(g))

    # NOTE(review): cache key is id() of the params list object. Safe only
    # while that list is kept alive (e.g. by the optimizer's param_groups);
    # if it were freed and its id reused, a stale cache entry would be
    # returned — confirm no caller rebuilds the params list between steps.
    tensor_key = id(params)
    if tensor_key not in _tensor_cache:
        with record_function("adamw::init_tensor_cache"):
            params_local = []
            moment1 = []
            moment2 = []
            state_steps = []

            for p in params:
                state = optimizer_state[p]
                params_local.append(_to_local(p))
                if "step" not in state:
                    # Lazy state init: fp32 scalar step plus two moment
                    # buffers shaped like the gradient.
                    state["step"] = torch.zeros((),
                                                dtype=torch.float32,
                                                device=p.device)
                    state["moment1"] = torch.zeros_like(p.grad)
                    state["moment2"] = torch.zeros_like(p.grad)
                moment1.append(_to_local(state["moment1"]))
                moment2.append(_to_local(state["moment2"]))
                if not isinstance(state["step"], torch.Tensor):
                    # Checkpoint-loaded scalar step → promote to a tensor
                    # so the fused kernel can update it in place.
                    state["step"] = torch.tensor(state["step"],
                                                 dtype=torch.float32,
                                                 device=p.device)
                state_steps.append(state["step"])

            _tensor_cache[tensor_key] = (params_local, moment1, moment2,
                                         state_steps)

    params_local, moment1, moment2, state_steps = _tensor_cache[tensor_key]

    # Hyperparameters are re-read every step so schedulers that mutate the
    # group dict take effect immediately despite the tensor caching above.
    lr = group["lr"]
    beta1, beta2 = group["adamw_betas"]
    eps = group["adamw_eps"]
    weight_decay = group["weight_decay"]

    with record_function("adamw::fused_adamw"):
        fused_adamw(
            params_local,
            grads,
            moment1,
            moment2,
            [],
            state_steps,
            amsgrad=False,
            beta1=beta1,
            beta2=beta2,
            lr=lr,
            weight_decay=weight_decay,
            eps=eps,
            maximize=False,
        )
229
+
230
+
231
def step_adamw(optimizer_state, group):
    """Dispatch AdamW step, grouping parameters by type and placement.

    Placement grouping is cached after the first call since params never
    change their placement between steps.

    Args:
        optimizer_state: The optimizer's state dict (self.state in Muon).
        group: Parameter group dict.
    """
    params = group["params"]
    # Cache key is the identity of the params list (held by param_groups),
    # so the grouping pass below runs once per optimizer lifetime.
    placement_key = id(params)

    if placement_key not in _placement_cache:
        with record_function("adamw::group_by_placement"):
            placement_to_params: dict[tuple,
                                      list[torch.Tensor]] = defaultdict(list)
            for p in params:
                # NOTE(review): the DTensor arm must stay before the plain
                # torch.Tensor arm — a DTensor would presumably also match
                # the Tensor pattern; confirm ordering if arms are edited.
                match p:
                    case DTensor():
                        logger.debug(
                            "[AdamW] DTensor param: shape=%s, placements=%s, "
                            "mesh=%s, grad=%s", p.shape, p.placements,
                            p.device_mesh.mesh_dim_names,
                            p.grad.shape if p.grad is not None else None)
                        placement_to_params[tuple(
                            [p.placements, p.device_mesh])].append(p)
                    case torch.Tensor():
                        logger.debug(
                            "[AdamW] plain param: shape=%s, grad=%s", p.shape,
                            p.grad.shape if p.grad is not None else None)
                        placement_to_params[tuple([torch.Tensor,
                                                   None])].append(p)

            logger.debug("[AdamW] %d placement groups, %d total params",
                         len(placement_to_params), len(params))

            _placement_cache[placement_key] = dict(placement_to_params)

    for group_params in _placement_cache[placement_key].values():
        step_adamw_params(optimizer_state, group_params, group)
build/torch210-cxx11-cu128-x86_64-linux/core.py CHANGED
@@ -1,11 +1,25 @@
 
1
  import math
2
  from dataclasses import dataclass
 
3
 
4
  import torch
5
- import torch.distributed as dist
6
  from torch.distributed import ProcessGroup
7
  from torch.distributed.tensor import DTensor
8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
 
10
  @dataclass
11
  class _muon_state:
@@ -17,26 +31,71 @@ class _muon_state:
17
  qk_clip_state: torch.Tensor | None = None
18
 
19
 
20
- def update_g(optimizer_state, p, g, group, momentum):
21
- """Apply momentum update to gradient.
 
 
 
 
 
 
22
 
23
- Args:
24
- optimizer_state: The optimizer's state dict (self.state in Muon).
25
- p: Parameter tensor.
26
- g: Gradient tensor.
27
- group: Parameter group dict.
28
- momentum: Momentum coefficient.
29
 
30
- Returns:
31
- Momentum-updated gradient tensor.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
  """
33
- state = optimizer_state[p]
34
- buf = state.setdefault("momentum_buffer", torch.zeros_like(g))
35
- torch.add(g, buf, alpha=momentum, out=buf)
36
- if group["nesterov"]:
37
- g.add_(buf, alpha=momentum)
38
- return g
39
- return buf
 
 
 
 
 
 
 
 
 
 
 
40
 
41
 
42
  def update_p(p, u, lr, adjusted_lr, weight_decay):
@@ -49,14 +108,13 @@ def update_p(p, u, lr, adjusted_lr, weight_decay):
49
  adjusted_lr: Size-adjusted learning rate.
50
  weight_decay: Weight decay coefficient.
51
  """
52
- if isinstance(p, torch.nn.Parameter):
53
- # apply weight decay
54
- p.data.mul_(1 - lr * weight_decay)
55
- # apply update
56
- p.data.add_(u, alpha=-adjusted_lr)
57
- else:
58
- p.mul_(1 - lr * weight_decay)
59
- p.add_(u, alpha=-adjusted_lr)
60
 
61
 
62
  def adjust_lr_for_muon(lr, param_shape):
@@ -77,14 +135,55 @@ def adjust_lr_for_muon(lr, param_shape):
77
  return adjusted_lr
78
 
79
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
80
  def default_is_muon(name, x, expert_keys=None):
81
- skip_keys = ["embed_tokens", "lm_head", "tok_embeddings", "output"]
82
- if any(key in name for key in skip_keys):
 
 
 
 
 
 
 
 
 
 
 
 
 
83
  return False
84
  effective_ndim = x.ndim
85
- if expert_keys and any(key in name for key in expert_keys):
 
86
  effective_ndim -= 1
87
- return effective_ndim >= 2
 
 
 
 
 
88
 
89
 
90
  def get_default_muon_param_groups(model, is_muon_func=None, expert_keys=None):
@@ -92,7 +191,7 @@ def get_default_muon_param_groups(model, is_muon_func=None, expert_keys=None):
92
  is_muon_func = lambda n, x: default_is_muon(n, x, expert_keys)
93
 
94
  muon_params, muon_names = [], []
95
- non_muon_params = []
96
 
97
  for n, p in model.named_parameters():
98
  if not p.requires_grad:
@@ -102,6 +201,10 @@ def get_default_muon_param_groups(model, is_muon_func=None, expert_keys=None):
102
  muon_names.append(n)
103
  else:
104
  non_muon_params.append(p)
 
 
 
 
105
 
106
  return [
107
  {
 
1
+ import logging
2
  import math
3
  from dataclasses import dataclass
4
+ from typing import List
5
 
6
  import torch
 
7
  from torch.distributed import ProcessGroup
8
  from torch.distributed.tensor import DTensor
9
 
10
# torch.compile wraps modules in OptimizedModule and injects "_orig_mod" into
# parameter FQNs; activation checkpointing likewise injects
# "_checkpoint_wrapped_module". These synthetic components are stripped so
# that name-based matching (skip_keys, expert_keys, QK layer parsing) is
# unaffected by wrapper nesting.
_WRAPPER_PARTS = frozenset({"_orig_mod", "_checkpoint_wrapped_module"})

logger = logging.getLogger(__name__)


def normalize_fqn(name: str) -> str:
    """Strip torch.compile / checkpoint wrapper components from a parameter FQN."""
    kept = [component for component in name.split(".")
            if component not in _WRAPPER_PARTS]
    return ".".join(kept)
22
+
23
 
24
  @dataclass
25
  class _muon_state:
 
31
  qk_clip_state: torch.Tensor | None = None
32
 
33
 
34
def _batch_momentum(
    grads: List[torch.Tensor],
    momentum_bufs: List[torch.Tensor],
    momentum: torch.Tensor,
) -> None:
    """Batched momentum update (no nesterov).

    In-place: each buffer becomes ``buf * momentum + grad``. ``grads``
    are only read here.
    """
    torch._foreach_mul_(momentum_bufs, momentum)
    torch._foreach_add_(momentum_bufs, grads)
42
 
 
 
 
 
 
 
43
 
44
def _batch_momentum_nesterov(
    grads: List[torch.Tensor],
    momentum_bufs: List[torch.Tensor],
    momentum: torch.Tensor,
) -> None:
    """Batched momentum update with nesterov correction.

    In-place on both lists: buffers become ``buf * momentum + grad`` and
    grads become ``grad + momentum * buf`` (nesterov look-ahead).
    """
    torch._foreach_mul_(momentum_bufs, momentum)
    torch._foreach_add_(momentum_bufs, grads)
    # Out-of-place mul: the buffers must keep buf*momentum+grad; only the
    # grads receive the extra look-ahead term.
    nesterov_terms = torch._foreach_mul(momentum_bufs, momentum)
    torch._foreach_add_(grads, nesterov_terms)
54
+
55
+
56
+ _compiled_momentum: dict[bool, callable] = {}
57
+ _use_momentum_compile = True
58
+
59
+
60
def set_momentum_compile(enabled: bool):
    """Toggle torch.compile for batched momentum.

    Affects subsequent ``batch_pre_ortho`` calls; already-compiled
    variants stay cached in ``_compiled_momentum`` but are bypassed
    while disabled.
    """
    global _use_momentum_compile
    _use_momentum_compile = enabled
64
+
65
+
66
def batch_pre_ortho(
    grads: List[torch.Tensor],
    momentum_bufs: List[torch.Tensor],
    momentum: torch.Tensor,
    nesterov: bool,
) -> None:
    """Batched momentum update on lists of plain tensors.

    Mirrors dion's ``muon_update_pre_orthogonalize``.
    Inputs must be plain CUDA tensors (not DTensor).
    Modifies ``momentum_bufs`` and (for nesterov) ``grads`` in-place.

    When compile is enabled, uses separately compiled functions for
    nesterov=True/False to avoid graph breaks from the branch.
    """
    if nesterov:
        impl = _batch_momentum_nesterov
    else:
        impl = _batch_momentum
    if _use_momentum_compile:
        compiled = _compiled_momentum.get(nesterov)
        if compiled is None:
            compiled = torch.compile(impl)
            _compiled_momentum[nesterov] = compiled
        impl = compiled
    impl(grads, momentum_bufs, momentum)
87
+
88
+
89
+ def _update_p_impl(p_data, u_data, lr, adjusted_lr, weight_decay):
90
+ """Weight-decay + update on plain tensors.
91
+
92
+ Not compiled: per-param @torch.compile caused ~0.25ms TorchDynamo cache
93
+ lookup per call × 256+ params = massive overhead. The pipeline path uses
94
+ batched _foreach_* ops instead; this function remains for base() and
95
+ distributed_muon().
96
+ """
97
+ p_data.mul_(1 - lr * weight_decay)
98
+ p_data.add_(u_data, alpha=-adjusted_lr)
99
 
100
 
101
  def update_p(p, u, lr, adjusted_lr, weight_decay):
 
108
  adjusted_lr: Size-adjusted learning rate.
109
  weight_decay: Weight decay coefficient.
110
  """
111
+ # Unwrap Parameter -> underlying data tensor.
112
+ p_data = p.data if isinstance(p, torch.nn.Parameter) else p
113
+ # Unwrap DTensor -> local CUDA tensor for compiled kernel.
114
+ if isinstance(p_data, DTensor):
115
+ p_data = p_data._local_tensor
116
+ u_data = u._local_tensor if isinstance(u, DTensor) else u
117
+ _update_p_impl(p_data, u_data, lr, adjusted_lr, weight_decay)
 
118
 
119
 
120
  def adjust_lr_for_muon(lr, param_shape):
 
135
  return adjusted_lr
136
 
137
 
138
+ def _match_key(parts, key):
139
+ """Check if key matches as contiguous components in parts.
140
+
141
+ Single-component keys (e.g. "experts") match any single component.
142
+ Multi-component keys (e.g. "experts.w1") match as a contiguous subsequence.
143
+ """
144
+ key_parts = key.split(".")
145
+ key_len = len(key_parts)
146
+ if key_len == 1:
147
+ return key in parts
148
+ return any(parts[i:i + key_len] == key_parts
149
+ for i in range(len(parts) - key_len + 1))
150
+
151
+
152
def is_expert_param(name, expert_keys):
    """Check if a parameter name matches any expert key (component-level)."""
    if not expert_keys:
        return False
    components = normalize_fqn(name).split(".")
    for key in expert_keys:
        if _match_key(components, key):
            return True
    return False
158
+
159
+
160
def default_is_muon(name, x, expert_keys=None):
    """Decide whether parameter *name*/*x* is routed to Muon (vs AdamW).

    Embedding/head-style parameters are excluded by name; expert params
    have their leading (expert) dim discounted before the >=2-D check.
    """
    normalized = normalize_fqn(name)
    parts = normalized.split(".")
    skip_keys = {
        "embed_tokens",
        "lm_head",
        "tok_embeddings",
        "output",
        "mhc_attn",
        "mhc_ffn",
        "lambda_proj",
    }
    if not skip_keys.isdisjoint(parts):
        logger.info(
            "[is_muon] %s (orig: %s): skip (matched skip_key), ndim=%d",
            normalized, name, x.ndim)
        return False
    is_expert = is_expert_param(name, expert_keys)
    effective_ndim = x.ndim - 1 if is_expert else x.ndim
    result = effective_ndim >= 2
    logger.info(
        "[is_muon] %s (orig: %s): ndim=%d, expert=%s, effective_ndim=%d → %s",
        normalized, name, x.ndim, is_expert, effective_ndim,
        "Muon" if result else "AdamW")
    return result
187
 
188
 
189
  def get_default_muon_param_groups(model, is_muon_func=None, expert_keys=None):
 
191
  is_muon_func = lambda n, x: default_is_muon(n, x, expert_keys)
192
 
193
  muon_params, muon_names = [], []
194
+ non_muon_params, non_muon_names = [], []
195
 
196
  for n, p in model.named_parameters():
197
  if not p.requires_grad:
 
201
  muon_names.append(n)
202
  else:
203
  non_muon_params.append(p)
204
+ non_muon_names.append(n)
205
+
206
+ logger.info("[param_groups] expert_keys=%s, Muon=%d, AdamW=%d",
207
+ expert_keys, len(muon_names), len(non_muon_names))
208
 
209
  return [
210
  {
build/torch210-cxx11-cu128-x86_64-linux/cpu_offload.py ADDED
@@ -0,0 +1,188 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """CPU offloading for optimizer states.
2
+
3
+ Manages a pinned CPU memory pool and async CUDA streams to offload
4
+ optimizer state tensors (momentum buffers, Adam moments) to CPU between
5
+ optimizer steps, freeing GPU memory.
6
+
7
+ All tracked tensors are packed into a single flat pinned CPU buffer
8
+ (per dtype). D2H and H2D copies are performed per-tensor directly
9
+ between individual GPU tensors and their slice of the CPU flat buffer
10
+ — no GPU staging buffer is allocated, so there is **no temporary GPU
11
+ memory spike** during offload or reload.
12
+
13
+ Individual tensor storages are freed after offload via
14
+ ``untyped_storage().resize_(0)``, preserving tensor identity so
15
+ downstream caches remain valid.
16
+ """
17
+
18
+ import logging
19
+ from collections import defaultdict
20
+
21
+ import torch
22
+ from torch.distributed.tensor import DTensor
23
+
24
+ logger = logging.getLogger(__name__)
25
+
26
+
27
+ class CPUOffloadPool:
28
+ """Pinned CPU memory pool for async optimizer state offloading.
29
+
30
+ Tracked tensors are grouped by dtype. Each group gets a single flat
31
+ pinned CPU buffer. D2H / H2D copies are per-tensor (into slices of
32
+ the flat buffer) to avoid allocating a GPU staging buffer.
33
+ """
34
+
35
+ def __init__(self):
36
+ self._managed: list[torch.Tensor] = []
37
+ self._storage_nbytes: dict[int, int] = {} # id(t) → bytes
38
+
39
+ # Per-dtype group: populated on first offload.
40
+ # dtype → dict with keys:
41
+ # "indices" : list[int] managed-list indices
42
+ # "offsets" : list[tuple[int,int]] (start, numel) in flat buf
43
+ # "total" : int total numel
44
+ # "cpu_flat" : Tensor pinned CPU buffer
45
+ self._groups: dict[torch.dtype, dict] = {}
46
+
47
+ self._offload_stream: torch.cuda.Stream | None = None
48
+ self._device: torch.device | None = None
49
+ self._initialized: bool = False
50
+ self._logged: bool = False
51
+
52
+ # ------------------------------------------------------------------
53
+ @staticmethod
54
+ def _local(t: torch.Tensor) -> torch.Tensor:
55
+ """Unwrap DTensor to its local CUDA tensor."""
56
+ return t._local_tensor if isinstance(t, DTensor) else t
57
+
58
+ def _ensure_stream(self):
59
+ if self._offload_stream is None:
60
+ self._offload_stream = torch.cuda.Stream(device=self._device)
61
+
62
+ # ------------------------------------------------------------------
63
+ def track(self, tensor: torch.Tensor):
64
+ """Register a GPU tensor for CPU offloading. Idempotent."""
65
+ tid = id(tensor)
66
+ if tid in self._storage_nbytes:
67
+ return
68
+ local = self._local(tensor)
69
+ if self._device is None:
70
+ self._device = local.device
71
+ self._storage_nbytes[tid] = local.untyped_storage().size()
72
+ self._managed.append(tensor)
73
+
74
+ # ------------------------------------------------------------------
75
+ def _init_buffers(self):
76
+ """Build per-dtype flat buffers on first offload."""
77
+ # Group managed tensors by dtype.
78
+ dtype_map: dict[torch.dtype, list[tuple[int, int]]] = defaultdict(list)
79
+ for idx, t in enumerate(self._managed):
80
+ local = self._local(t)
81
+ dtype_map[local.dtype].append((idx, local.numel()))
82
+
83
+ total_cpu_bytes = 0
84
+ for dtype, entries in dtype_map.items():
85
+ offsets: list[tuple[int, int]] = []
86
+ indices: list[int] = []
87
+ off = 0
88
+ for idx, n in entries:
89
+ indices.append(idx)
90
+ offsets.append((off, n))
91
+ off += n
92
+ cpu_flat = torch.empty(off, dtype=dtype, device="cpu", pin_memory=True)
93
+ self._groups[dtype] = {
94
+ "indices": indices,
95
+ "offsets": offsets,
96
+ "total": off,
97
+ "cpu_flat": cpu_flat,
98
+ }
99
+ total_cpu_bytes += off * cpu_flat.element_size()
100
+
101
+ self._initialized = True
102
+ logger.info(
103
+ "[CPUOffload] Pool initialized: %d tensors, %d dtype group(s), "
104
+ "%.2f MB pinned CPU memory",
105
+ len(self._managed),
106
+ len(self._groups),
107
+ total_cpu_bytes / (1024**2),
108
+ )
109
+
110
+ # ------------------------------------------------------------------
111
+ def offload(self):
112
+ """Per-tensor async D2H into CPU flat buffer, then free GPU storage."""
113
+ if not self._managed:
114
+ return
115
+ if not self._initialized:
116
+ self._init_buffers()
117
+ self._ensure_stream()
118
+
119
+ # Offload stream waits for compute to finish.
120
+ compute_event = torch.cuda.current_stream(
121
+ self._device).record_event()
122
+ self._offload_stream.wait_event(compute_event)
123
+
124
+ offloaded_bytes = 0
125
+
126
+ # Per-tensor D2H copies directly into CPU flat buffer slices.
127
+ # No GPU staging buffer → no temporary GPU memory spike.
128
+ with torch.cuda.stream(self._offload_stream):
129
+ for dtype, grp in self._groups.items():
130
+ indices = grp["indices"]
131
+ offsets = grp["offsets"]
132
+ cpu_flat = grp["cpu_flat"]
133
+
134
+ for i, mgd_idx in enumerate(indices):
135
+ local = self._local(self._managed[mgd_idx])
136
+ off, n = offsets[i]
137
+ cpu_flat[off:off + n].copy_(
138
+ local.reshape(-1), non_blocking=True)
139
+
140
+ offloaded_bytes += grp["total"] * cpu_flat.element_size()
141
+
142
+ # Wait for all D2H copies to land, then free GPU storage.
143
+ self._offload_stream.synchronize()
144
+ for t in self._managed:
145
+ self._local(t).untyped_storage().resize_(0)
146
+
147
+ if not self._logged:
148
+ logger.info("[CPUOffload] Offloaded %.2f MB (GPU → CPU)",
149
+ offloaded_bytes / (1024**2))
150
+
151
+ # ------------------------------------------------------------------
152
+ def reload(self):
153
+ """Per-tensor H2D from CPU flat buffer on the default stream.
154
+
155
+ Runs on the current (default) CUDA stream to avoid stream
156
+ interaction issues with the parallel Muon pipeline. Since
157
+ pinned CPU memory is the source, the copies overlap with
158
+ GPU idle time between steps.
159
+ """
160
+ if not self._managed or not self._initialized:
161
+ return
162
+
163
+ reloaded_bytes = 0
164
+
165
+ # Re-allocate all GPU storages first.
166
+ for t in self._managed:
167
+ local = self._local(t)
168
+ local.untyped_storage().resize_(self._storage_nbytes[id(t)])
169
+
170
+ # Per-tensor H2D copies from CPU flat buffer slices.
171
+ # non_blocking=True with pinned source allows DMA overlap.
172
+ for dtype, grp in self._groups.items():
173
+ indices = grp["indices"]
174
+ offsets = grp["offsets"]
175
+ cpu_flat = grp["cpu_flat"]
176
+
177
+ for i, mgd_idx in enumerate(indices):
178
+ local = self._local(self._managed[mgd_idx])
179
+ off, n = offsets[i]
180
+ local.reshape(-1).copy_(
181
+ cpu_flat[off:off + n], non_blocking=True)
182
+
183
+ reloaded_bytes += grp["total"] * cpu_flat.element_size()
184
+
185
+ if not self._logged:
186
+ logger.info("[CPUOffload] Reloaded %.2f MB (CPU → GPU)",
187
+ reloaded_bytes / (1024**2))
188
+ self._logged = True
build/torch210-cxx11-cu128-x86_64-linux/distributed/utils.py CHANGED
@@ -72,12 +72,6 @@ def get_slices_of_dtensor(
72
  else:
73
  curr_size = target.size()[shard_dim]
74
 
75
- if curr_size % num_chunks != 0:
76
- raise NotImplementedError(
77
- f"Dimension size {curr_size} is not divisible "
78
- f"by number of ranks {num_chunks} for shard "
79
- f"placement on dim {shard_dim}. (shape: {target.shape})")
80
-
81
  # Compute indices for this level of sharding
82
  if isinstance(placement, _StridedShard):
83
  _shard_size, offsets = _StridedShard.local_shard_size_and_offset(
 
72
  else:
73
  curr_size = target.size()[shard_dim]
74
 
 
 
 
 
 
 
75
  # Compute indices for this level of sharding
76
  if isinstance(placement, _StridedShard):
77
  _shard_size, offsets = _StridedShard.local_shard_size_and_offset(
build/torch210-cxx11-cu128-x86_64-linux/matmul_transpose_triton.py CHANGED
@@ -43,6 +43,7 @@ def get_autotune_config():
43
  @triton.autotune(
44
  configs=get_autotune_config(),
45
  key=['M', 'K'],
 
46
  )
47
  @triton.jit
48
  def mmt_kernel(x, y, M, K, stride_xm, stride_xk, stride_ym, stride_yn,
@@ -102,16 +103,10 @@ def mmt_kernel(x, y, M, K, stride_xm, stride_xk, stride_ym, stride_yn,
102
  tl.store(ct_ptrs, tl.permute(c, (1, 0)), mask=ct_mask)
103
 
104
 
105
- def matmul_transpose_assign(d_in, d_out):
106
- assert d_in.is_cuda, "Input `d_in` must be a CUDA tensor"
107
- assert d_out.is_cuda, "Input `d_out` must be a CUDA tensor"
108
- assert d_in.device == d_out.device, "Inputs `d_in` and `d_out` must be on the same CUDA device"
109
- assert d_in.dtype == d_out.dtype, "Inputs must have the same data type"
110
- assert d_in.ndim == 2, "Input `d_in` must be a 2D tensor"
111
- assert d_out.ndim == 2, "Input `d_out` must be a 2D tensor"
112
- assert d_in.size(0) == d_out.size(0) == d_out.size(0), \
113
- "First dimension of `d_in` must match first and second dimension of `d_out`"
114
-
115
  d_in = d_in.contiguous()
116
  M, K = d_in.shape
117
  grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(
@@ -119,3 +114,9 @@ def matmul_transpose_assign(d_in, d_out):
119
  with torch.cuda.device(d_in.device.index):
120
  mmt_kernel[grid](d_in, d_out, M, K, d_in.stride(0), d_in.stride(1),
121
  d_out.stride(0), d_out.stride(1))
 
 
 
 
 
 
 
43
  @triton.autotune(
44
  configs=get_autotune_config(),
45
  key=['M', 'K'],
46
+ restore_value=['y'],
47
  )
48
  @triton.jit
49
  def mmt_kernel(x, y, M, K, stride_xm, stride_xk, stride_ym, stride_yn,
 
103
  tl.store(ct_ptrs, tl.permute(c, (1, 0)), mask=ct_mask)
104
 
105
 
106
+ @torch.library.custom_op("muon::matmul_transpose_assign",
107
+ mutates_args=("d_out", ))
108
+ def matmul_transpose_assign(d_in: torch.Tensor, d_out: torch.Tensor) -> None:
109
+ """Compute d_out = d_in @ d_in.T using an optimized Triton kernel."""
 
 
 
 
 
 
110
  d_in = d_in.contiguous()
111
  M, K = d_in.shape
112
  grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(
 
114
  with torch.cuda.device(d_in.device.index):
115
  mmt_kernel[grid](d_in, d_out, M, K, d_in.stride(0), d_in.stride(1),
116
  d_out.stride(0), d_out.stride(1))
117
+
118
+
119
+ @matmul_transpose_assign.register_fake
120
+ def _(d_in: torch.Tensor, d_out: torch.Tensor) -> None:
121
+ """FakeTensor impl: d_out is already allocated, mutation is declared."""
122
+ pass
build/torch210-cxx11-cu128-x86_64-linux/muon.py CHANGED
@@ -10,13 +10,16 @@ from torch.profiler import record_function
10
 
11
  from .adamw import step_adamw
12
  from .async_utils import run_pipeline
13
- from .core import (_muon_state, adjust_lr_for_muon,
14
- get_default_muon_param_groups, update_g, update_p)
 
15
  from .distributed.utils import (_is_shard, construct_shard_mesh,
16
  get_slices_of_dtensor)
17
  from .newton_schulz import (COMM_DTYPE, DEFAULT_CHUNK_SIZE_RATIO,
18
- _zeropower_via_newtonschulz5)
19
- from .pipeline import muon_chunk_pipeline
 
 
20
  from .qk_clip import compute_scales, get_qk_clip_info, qk_clip
21
 
22
  logger = logging.getLogger(__name__)
@@ -45,9 +48,21 @@ def _expand_expert_params(names, params, expert_keys):
45
  expanded_params = []
46
 
47
  for n, p in zip(names, params):
48
- is_expert = expert_keys and any(key in n for key in expert_keys)
49
  is_dtensor = isinstance(p.data, DTensor)
50
 
 
 
 
 
 
 
 
 
 
 
 
 
51
  if not is_expert:
52
  assert p.data.ndim <= 2, (
53
  f"Param {n} has ndim={p.data.ndim} but does not match "
@@ -168,7 +183,6 @@ class Muon(torch.optim.Optimizer):
168
  Use shard ranks * DEFAULT_CHUNK_SIZE_RATIO when -1 is specified.
169
  use_distributed_muon: Use distributed muon by Liu et al. (2024).
170
  For testing purpose only.
171
- small_param_numel_threshold: Threshold for classifying parameters as small and falling back to distributed Muon
172
  expert_keys: List of strings to identify expert-parallel parameters.
173
  If any key appears in a parameter's name, its outermost
174
  dimension is treated as the expert dimension and expanded
@@ -193,8 +207,8 @@ class Muon(torch.optim.Optimizer):
193
  warmup_step=5,
194
  chunk_size=-1,
195
  use_distributed_muon=False,
196
- small_param_numel_threshold=65536,
197
- expert_keys=None):
198
  defaults = dict(
199
  lr=lr,
200
  weight_decay=weight_decay,
@@ -228,8 +242,12 @@ class Muon(torch.optim.Optimizer):
228
  self.warmup_step = warmup_step
229
  self.chunk_size = chunk_size
230
  self.use_distributed_muon = use_distributed_muon
231
- self.small_param_numel_threshold = small_param_numel_threshold
232
  self.expert_keys = expert_keys
 
 
 
 
 
233
 
234
  def _calc_flops(self, G, steps):
235
  assert len(G.shape) == 2
@@ -333,8 +351,8 @@ class Muon(torch.optim.Optimizer):
333
  if g is None:
334
  continue
335
 
336
- u = _zeropower_via_newtonschulz5(g.to(COMM_DTYPE),
337
- steps=group["ns_steps"])
338
 
339
  adjusted_lr = adjust_lr_for_muon(lr, p.shape)
340
  update_p(p, u, lr, adjusted_lr, weight_decay)
@@ -355,52 +373,269 @@ class Muon(torch.optim.Optimizer):
355
  weight_decay: float,
356
  qk_logits: list[torch.Tensor | DTensor] | None,
357
  ):
358
- """ Implementation of Distributed Muon by Liu et al. """
359
 
360
- # Momentum is already applied by _step_muon before this method.
361
- for n, p in zip(names, params):
362
- g = p.grad
363
- if g is None:
364
- continue
365
-
366
- # Gather G
367
- if isinstance(p.data, DTensor):
368
- g_full = g.full_tensor()
369
- p_full = p.data.full_tensor()
370
- else:
371
- g_full = g
372
- p_full = p
373
-
374
- u_full = _zeropower_via_newtonschulz5(g_full.to(COMM_DTYPE),
375
- steps=group["ns_steps"])
376
-
377
- adjusted_lr = adjust_lr_for_muon(lr, p_full.shape)
378
- update_p(p_full, u_full, lr, adjusted_lr, weight_decay)
379
 
380
- qk_clip_state = get_qk_clip_info(self.clip_config, n, qk_logits)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
381
 
382
- scales_full = compute_scales(
383
- p_full, qk_clip_state) if qk_clip_state is not None else None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
384
 
385
- if scales_full is not None:
386
- qk_clip(p_full, scales_full, qk_clip_state.head_dim)
 
 
387
 
388
- if isinstance(p.data, DTensor):
389
- ndims = len(p.device_mesh.mesh.shape)
390
- p_replicate = DTensor.from_local(
391
- p_full,
392
- device_mesh=p.device_mesh,
393
- placements=[Replicate() for _ in range(ndims)],
394
- )
395
 
396
- p_sharded = p_replicate.redistribute(
397
- device_mesh=p.device_mesh,
398
- placements=p.placements,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
399
  )
400
 
401
- p.copy_(p_sharded)
402
 
403
- def parallel(self, names, params, group, lr, weight_decay, qk_logits):
 
 
 
 
 
 
 
404
  """
405
  Perform a parallel optimization step using Muon.
406
 
@@ -409,31 +644,23 @@ class Muon(torch.optim.Optimizer):
409
  interleaves multiple chunks so that communication and computation
410
  overlap across chunks (the same overlap previously achieved by the
411
  warmup + main-loop index scheduling).
 
 
 
 
412
  """
413
 
414
  # Momentum is already applied by _step_muon before this method.
415
 
416
- param_to_state, ordered_params = self.init_state_and_assign_params(
417
- names, params, group, qk_logits)
418
-
419
- # Compute local rank for this group's shard process group.
420
- shard_pg = param_to_state[id(ordered_params[0])].process_group
421
- rank = dist.get_rank(group=shard_pg)
422
-
423
- if self.chunk_size == -1:
424
- shard_ranks = dist.get_world_size(param_to_state[id(
425
- ordered_params[0])].process_group)
426
- chunk_size = shard_ranks * DEFAULT_CHUNK_SIZE_RATIO
427
- elif self.chunk_size > 0:
428
- chunk_size = self.chunk_size
429
- else:
430
- raise ValueError("chunk_size must be -1 or a positive integer.")
431
 
432
  def pipelines():
 
433
  for start in range(0, len(ordered_params), chunk_size):
434
  chunk = ordered_params[start:start + chunk_size]
435
  if chunk:
436
- yield muon_chunk_pipeline(
437
  params=chunk,
438
  param_to_state=param_to_state,
439
  rank=rank,
@@ -442,9 +669,11 @@ class Muon(torch.optim.Optimizer):
442
  weight_decay=weight_decay,
443
  none_grad=group["none_grad"],
444
  )
 
 
 
 
445
 
446
- with record_function("muon::barrier"):
447
- dist.barrier()
448
  with record_function("muon::pipeline"):
449
  run_pipeline(pipelines(), max_concurrent=self.warmup_step + 1)
450
 
@@ -456,16 +685,152 @@ class Muon(torch.optim.Optimizer):
456
  names = group["names"]
457
 
458
  # Apply momentum to all params before routing/expansion.
 
459
  with record_function("muon::momentum"):
460
- for n, p in zip(names, params):
461
- g = p.grad
462
- if g is None:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
463
  continue
464
- g = update_g(self.state, p, g, group, momentum)
465
- p.grad = g
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
466
 
467
  # Expand expert params by splitting on dim 0.
468
- names, params = _expand_expert_params(names, params, self.expert_keys)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
469
 
470
  param_dtensors = []
471
  name_dtensors = []
@@ -473,10 +838,10 @@ class Muon(torch.optim.Optimizer):
473
  param_tensors = []
474
  name_tensors = []
475
 
476
- param_dtensors_small = []
477
- name_dtensors_small = []
478
-
479
  if self.use_distributed_muon:
 
480
  self.distributed_muon(names=names,
481
  params=params,
482
  group=group,
@@ -485,8 +850,6 @@ class Muon(torch.optim.Optimizer):
485
  qk_logits=qk_logits)
486
  return
487
 
488
- # For simplicity, we use distributed Muon for small parameters
489
- # whose number of elements is below a threshold.
490
  for n, p in zip(names, params):
491
  if p is None or p.grad is None:
492
  continue
@@ -494,23 +857,28 @@ class Muon(torch.optim.Optimizer):
494
  if all(
495
  isinstance(placement, Replicate)
496
  for placement in p.placements):
 
 
 
497
  param_tensors.append(p)
498
  name_tensors.append(n)
499
- elif p.data.numel() <= self.small_param_numel_threshold:
500
- param_dtensors_small.append(p)
501
- name_dtensors_small.append(n)
502
  else:
 
 
 
 
503
  param_dtensors.append(p)
504
  name_dtensors.append(n)
505
  elif isinstance(p.data, torch.Tensor):
 
 
506
  param_tensors.append(p)
507
  name_tensors.append(n)
508
  else:
509
  raise TypeError(f"Unsupported parameter type: {type(p.data)}")
510
 
511
- logger.debug(
512
- f"[Muon] {len(param_dtensors)} DTensors, {len(param_tensors)} Tensors, "
513
- f"{len(param_dtensors_small)} Small DTensors")
514
 
515
  def group_dtensors(dtensors, names):
516
  # To support different placements, we group parameters by placements
@@ -526,21 +894,6 @@ class Muon(torch.optim.Optimizer):
526
  p.device_mesh])][1].append(p)
527
  return placement_to_params
528
 
529
- if len(param_dtensors_small) > 0:
530
- if not dist.is_initialized():
531
- raise RuntimeError(
532
- "Parallel Muon requires torch.distributed to be initialized."
533
- )
534
-
535
- self.distributed_muon(
536
- params=param_dtensors_small,
537
- names=name_dtensors_small,
538
- group=group,
539
- lr=lr,
540
- weight_decay=weight_decay,
541
- qk_logits=qk_logits,
542
- )
543
-
544
  if len(param_dtensors) > 0:
545
  if not dist.is_initialized():
546
  raise RuntimeError(
@@ -548,7 +901,26 @@ class Muon(torch.optim.Optimizer):
548
  )
549
 
550
  dtensor_group = group_dtensors(param_dtensors, name_dtensors)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
551
  for _, (names, params) in dtensor_group.items():
 
 
552
  self.parallel(
553
  names,
554
  params,
@@ -556,7 +928,10 @@ class Muon(torch.optim.Optimizer):
556
  lr=lr,
557
  weight_decay=weight_decay,
558
  qk_logits=qk_logits,
 
559
  )
 
 
560
 
561
  if len(param_tensors) > 0:
562
  self.base(
@@ -568,6 +943,33 @@ class Muon(torch.optim.Optimizer):
568
  qk_logits=qk_logits,
569
  )
570
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
571
  @torch.no_grad
572
  def step(self, closure=None, qk_logits=None):
573
  """Perform a single optimization step.
@@ -585,10 +987,82 @@ class Muon(torch.optim.Optimizer):
585
  with torch.enable_grad():
586
  loss = closure()
587
 
588
- for group in self.param_groups:
 
 
 
 
 
 
 
589
  if group["use_muon"]:
 
 
590
  self._step_muon(group, qk_logits=qk_logits)
591
  else:
 
 
 
592
  step_adamw(self.state, group)
593
 
 
 
 
 
 
 
 
594
  return loss
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
 
11
  from .adamw import step_adamw
12
  from .async_utils import run_pipeline
13
+ from .core import (_muon_state, adjust_lr_for_muon, batch_pre_ortho,
14
+ get_default_muon_param_groups, is_expert_param, update_p)
15
+ from .cpu_offload import CPUOffloadPool
16
  from .distributed.utils import (_is_shard, construct_shard_mesh,
17
  get_slices_of_dtensor)
18
  from .newton_schulz import (COMM_DTYPE, DEFAULT_CHUNK_SIZE_RATIO,
19
+ _zeropower_via_newtonschulz5,
20
+ zeropower_via_newtonschulz5,
21
+ zeropower_via_newtonschulz5_batched)
22
+ from .pipeline import muon_chunk_pipeline, prelaunch_first_gather
23
  from .qk_clip import compute_scales, get_qk_clip_info, qk_clip
24
 
25
  logger = logging.getLogger(__name__)
 
48
  expanded_params = []
49
 
50
  for n, p in zip(names, params):
51
+ is_expert = is_expert_param(n, expert_keys)
52
  is_dtensor = isinstance(p.data, DTensor)
53
 
54
+ if is_expert:
55
+ if is_dtensor:
56
+ logger.debug(
57
+ "[expand_expert] %s: expert DTensor, shape=%s, "
58
+ "placements=%s, mesh=%s, local_shape=%s", n, p.shape,
59
+ p.placements, p.device_mesh.mesh_dim_names,
60
+ p.to_local().shape)
61
+ else:
62
+ logger.debug(
63
+ "[expand_expert] %s: expert plain tensor, shape=%s", n,
64
+ p.data.shape)
65
+
66
  if not is_expert:
67
  assert p.data.ndim <= 2, (
68
  f"Param {n} has ndim={p.data.ndim} but does not match "
 
183
  Use shard ranks * DEFAULT_CHUNK_SIZE_RATIO when -1 is specified.
184
  use_distributed_muon: Use distributed muon by Liu et al. (2024).
185
  For testing purpose only.
 
186
  expert_keys: List of strings to identify expert-parallel parameters.
187
  If any key appears in a parameter's name, its outermost
188
  dimension is treated as the expert dimension and expanded
 
207
  warmup_step=5,
208
  chunk_size=-1,
209
  use_distributed_muon=False,
210
+ expert_keys=None,
211
+ cpu_offload=False):
212
  defaults = dict(
213
  lr=lr,
214
  weight_decay=weight_decay,
 
242
  self.warmup_step = warmup_step
243
  self.chunk_size = chunk_size
244
  self.use_distributed_muon = use_distributed_muon
 
245
  self.expert_keys = expert_keys
246
+ self.cpu_offload = cpu_offload
247
+ self._cpu_offload_pool = CPUOffloadPool() if cpu_offload else None
248
+ self._offload_initialized = False
249
+ self._parallel_cache: dict[tuple[str, ...], dict] = {}
250
+ self._expert_expand_cache: dict[tuple[int, ...], dict] = {}
251
 
252
  def _calc_flops(self, G, steps):
253
  assert len(G.shape) == 2
 
351
  if g is None:
352
  continue
353
 
354
+ u = zeropower_via_newtonschulz5(g.to(COMM_DTYPE),
355
+ steps=group["ns_steps"])
356
 
357
  adjusted_lr = adjust_lr_for_muon(lr, p.shape)
358
  update_p(p, u, lr, adjusted_lr, weight_decay)
 
373
  weight_decay: float,
374
  qk_logits: list[torch.Tensor | DTensor] | None,
375
  ):
376
+ """Batched Distributed Muon for testing/correctness verification only.
377
 
378
+ Uses all-gather to reconstruct full tensors, computes Newton-Schulz on
379
+ the full grad, then slices back to local shards. This is simpler but
380
+ slower than the parallel pipeline (all2all) path, so it serves as a
381
+ reference implementation for verifying correctness.
382
+ """
383
+ with record_function("distributed_muon"):
384
+ # Momentum is already applied by _step_muon before this method.
385
+ ns_steps = group["ns_steps"]
 
 
 
 
 
 
 
 
 
 
 
386
 
387
+ # Separate plain tensors (no communication) from DTensors.
388
+ plain_names, plain_params = [], []
389
+ dtensor_names, dtensor_params = [], []
390
+ for n, p in zip(names, params):
391
+ if p.grad is None:
392
+ continue
393
+ if isinstance(p.data, DTensor):
394
+ dtensor_names.append(n)
395
+ dtensor_params.append(p)
396
+ else:
397
+ plain_names.append(n)
398
+ plain_params.append(p)
399
+
400
+ # Process plain tensors per-param (no communication).
401
+ for n, p in zip(plain_names, plain_params):
402
+ u = _zeropower_via_newtonschulz5(p.grad.to(COMM_DTYPE),
403
+ steps=ns_steps)
404
+ adjusted_lr = adjust_lr_for_muon(lr, p.shape)
405
+ update_p(p, u, lr, adjusted_lr, weight_decay)
406
+
407
+ qk_clip_state = get_qk_clip_info(self.clip_config, n,
408
+ qk_logits)
409
+ scales_full = compute_scales(
410
+ p, qk_clip_state) if qk_clip_state is not None else None
411
+ if scales_full is not None:
412
+ qk_clip(p, scales_full, qk_clip_state.head_dim)
413
+
414
+ if not dtensor_params:
415
+ return
416
+
417
+ # Group DTensors by (placements, mesh) for batched all-gather.
418
+ placement_groups: dict[tuple,
419
+ tuple[list,
420
+ list]] = defaultdict(lambda: ([], []))
421
+ for n, p in zip(dtensor_names, dtensor_params):
422
+ key = (p.placements, p.device_mesh)
423
+ placement_groups[key][0].append(n)
424
+ placement_groups[key][1].append(p)
425
+
426
+ logger.info(
427
+ "distributed_muon: %d placement groups, %d total dtensors",
428
+ len(placement_groups), len(dtensor_params))
429
+
430
+ for (placements, mesh), (grp_names,
431
+ grp_params) in placement_groups.items():
432
+ shard_mesh, shard_pg, shard_placements = construct_shard_mesh(
433
+ placements, mesh)
434
+ rank = dist.get_rank(shard_pg)
435
+ world_size = dist.get_world_size(shard_pg)
436
+
437
+ logger.info(" group: %d params, placements=%s, world_size=%d",
438
+ len(grp_params), placements, world_size)
439
+
440
+ # Separate params that can be batched (all shard dims evenly
441
+ # divisible) from those needing per-param full_tensor
442
+ # (e.g. MoE gate weights with fewer rows than shard ranks).
443
+ # all_gather_into_tensor requires equal buffer sizes across
444
+ # ranks, so uneven splits must use DTensor full_tensor().
445
+ batch_names, batch_params = [], []
446
+ single_names, single_params = [], []
447
+ for n, p in zip(grp_names, grp_params):
448
+ even = all(p.shape[pl.dim] %
449
+ shard_mesh.mesh.shape[dim_idx] == 0
450
+ for dim_idx, pl in enumerate(shard_placements))
451
+ if even:
452
+ batch_names.append(n)
453
+ batch_params.append(p)
454
+ else:
455
+ single_names.append(n)
456
+ single_params.append(p)
457
+
458
+ # Process uneven-split params per-param via full_tensor().
459
+ for n, p in zip(single_names, single_params):
460
+ with record_function("distributed_muon::newton_schulz"):
461
+ g_full = p.grad.full_tensor().to(COMM_DTYPE)
462
+ u_full = _zeropower_via_newtonschulz5(g_full,
463
+ steps=ns_steps)
464
+ del g_full
465
+ with record_function("distributed_muon::update"):
466
+ adjusted_lr = adjust_lr_for_muon(lr, p.shape)
467
+ p._local_tensor.mul_(1 - lr * weight_decay)
468
+ local_indices = get_slices_of_dtensor(
469
+ p, rank, shard_mesh, shard_placements)
470
+ u_local = u_full[local_indices]
471
+ p._local_tensor.add_(u_local, alpha=-adjusted_lr)
472
+ del u_full
473
+
474
+ qk_clip_state = get_qk_clip_info(
475
+ self.clip_config, n, qk_logits)
476
+ scales_full = compute_scales(
477
+ p, qk_clip_state
478
+ ) if qk_clip_state is not None else None
479
+ if scales_full is not None:
480
+ ratio = p.shape[0] // scales_full.shape[0]
481
+ idx0 = local_indices[0]
482
+ if isinstance(idx0, slice):
483
+ start = idx0.start or 0
484
+ idx0 = torch.arange(start,
485
+ idx0.stop,
486
+ device=scales_full.device)
487
+ row_scales = scales_full[idx0 // ratio]
488
+ p._local_tensor.mul_(row_scales.view(-1, 1))
489
+
490
+ if not batch_params:
491
+ continue
492
 
493
+ logger.info(" batched=%d, single=%d", len(batch_params),
494
+ len(single_params))
495
+
496
+ # Concat all local grad shards into a single flat buffer.
497
+ with record_function("distributed_muon::gather"):
498
+ grad_locals = [
499
+ p.grad.to_local().to(COMM_DTYPE).flatten()
500
+ for p in batch_params
501
+ ]
502
+ numels = [g.numel() for g in grad_locals]
503
+ grad_concat = torch.cat(grad_locals)
504
+ del grad_locals
505
+
506
+ # Single all-gather (replaces N separate full_tensor).
507
+ grad_gathered = torch.empty(
508
+ grad_concat.numel() * world_size,
509
+ dtype=COMM_DTYPE,
510
+ device="cuda",
511
+ )
512
+ dist.all_gather_into_tensor(grad_gathered,
513
+ grad_concat,
514
+ group=shard_pg)
515
+
516
+ total_numel = grad_concat.numel()
517
+ del grad_concat
518
+
519
+ # Precompute per-param offsets within the concat buffer.
520
+ offsets = []
521
+ off = 0
522
+ for ne in numels:
523
+ offsets.append(off)
524
+ off += ne
525
+
526
+ # Per-param: reconstruct full grad → NS → local update.
527
+ for i, (n, p) in enumerate(zip(batch_names, batch_params)):
528
+ with record_function("distributed_muon::newton_schulz"):
529
+ g_full = torch.empty(p.shape,
530
+ dtype=COMM_DTYPE,
531
+ device="cuda")
532
+ for r in range(world_size):
533
+ r_start = r * total_numel + offsets[i]
534
+ shard = grad_gathered[r_start:r_start + numels[i]]
535
+ indices = get_slices_of_dtensor(
536
+ p, r, shard_mesh, shard_placements)
537
+ g_full[indices] = shard.reshape(
538
+ g_full[indices].shape)
539
+
540
+ u_full = _zeropower_via_newtonschulz5(g_full,
541
+ steps=ns_steps)
542
+ del g_full
543
+
544
+ with record_function("distributed_muon::update"):
545
+ adjusted_lr = adjust_lr_for_muon(lr, p.shape)
546
+ p._local_tensor.mul_(1 - lr * weight_decay)
547
+ local_indices = get_slices_of_dtensor(
548
+ p, rank, shard_mesh, shard_placements)
549
+ u_local = u_full[local_indices]
550
+ p._local_tensor.add_(u_local, alpha=-adjusted_lr)
551
+ del u_full
552
+
553
+ qk_clip_state = get_qk_clip_info(
554
+ self.clip_config, n, qk_logits)
555
+ scales_full = compute_scales(
556
+ p, qk_clip_state
557
+ ) if qk_clip_state is not None else None
558
+ if scales_full is not None:
559
+ ratio = p.shape[0] // scales_full.shape[0]
560
+ idx0 = local_indices[0]
561
+ if isinstance(idx0, slice):
562
+ start = idx0.start or 0
563
+ idx0 = torch.arange(start,
564
+ idx0.stop,
565
+ device=scales_full.device)
566
+ row_scales = scales_full[idx0 // ratio]
567
+ p._local_tensor.mul_(row_scales.view(-1, 1))
568
+
569
+ def _setup_parallel(self, names, params, group, qk_logits):
570
+ """Compute (or retrieve cached) parallel pipeline metadata.
571
+
572
+ Returns:
573
+ (ordered_params, param_to_state, rank, chunk_size)
574
+ """
575
+ cache_key = tuple(names)
576
 
577
+ if cache_key not in self._parallel_cache:
578
+ # First call: compute metadata and populate cache.
579
+ param_to_state, ordered_params = self.init_state_and_assign_params(
580
+ names, params, group, qk_logits)
581
 
582
+ shard_pg = param_to_state[id(ordered_params[0])].process_group
583
+ rank = dist.get_rank(group=shard_pg)
 
 
 
 
 
584
 
585
+ if self.chunk_size == -1:
586
+ shard_ranks = dist.get_world_size(shard_pg)
587
+ chunk_size = shard_ranks * DEFAULT_CHUNK_SIZE_RATIO
588
+ elif self.chunk_size > 0:
589
+ chunk_size = self.chunk_size
590
+ else:
591
+ raise ValueError(
592
+ "chunk_size must be -1 or a positive integer.")
593
+
594
+ ordered_names = [
595
+ param_to_state[id(p)].name for p in ordered_params
596
+ ]
597
+ name_to_state = {
598
+ param_to_state[id(p)].name: param_to_state[id(p)]
599
+ for p in ordered_params
600
+ }
601
+ self._parallel_cache[cache_key] = {
602
+ 'ordered_names': ordered_names,
603
+ 'name_to_state': name_to_state,
604
+ 'rank': rank,
605
+ 'chunk_size': chunk_size,
606
+ }
607
+ else:
608
+ # Cached path: rebuild param_to_state with current id(p) keys.
609
+ cache = self._parallel_cache[cache_key]
610
+ rank = cache['rank']
611
+ chunk_size = cache['chunk_size']
612
+
613
+ name_to_param = dict(zip(names, params))
614
+ ordered_params = [name_to_param[n] for n in cache['ordered_names']]
615
+
616
+ param_to_state = {}
617
+ for p, n in zip(ordered_params, cache['ordered_names']):
618
+ cached_state = cache['name_to_state'][n]
619
+ param_to_state[id(p)] = _muon_state(
620
+ worker_rank=cached_state.worker_rank,
621
+ process_group=cached_state.process_group,
622
+ rank_indices=cached_state.rank_indices,
623
+ rank_numels=cached_state.rank_numels,
624
+ name=n,
625
+ qk_clip_state=get_qk_clip_info(self.clip_config, n,
626
+ qk_logits),
627
  )
628
 
629
+ return ordered_params, param_to_state, rank, chunk_size
630
 
631
+ def parallel(self,
632
+ names,
633
+ params,
634
+ group,
635
+ lr,
636
+ weight_decay,
637
+ qk_logits,
638
+ prelaunch_gather=None):
639
  """
640
  Perform a parallel optimization step using Muon.
641
 
 
644
  interleaves multiple chunks so that communication and computation
645
  overlap across chunks (the same overlap previously achieved by the
646
  warmup + main-loop index scheduling).
647
+
648
+ If ``prelaunch_gather`` is provided, it is passed to the first
649
+ chunk's generator to skip re-launching the already in-flight
650
+ A2A gather.
651
  """
652
 
653
  # Momentum is already applied by _step_muon before this method.
654
 
655
+ ordered_params, param_to_state, rank, chunk_size = (
656
+ self._setup_parallel(names, params, group, qk_logits))
 
 
 
 
 
 
 
 
 
 
 
 
 
657
 
658
  def pipelines():
659
+ first = True
660
  for start in range(0, len(ordered_params), chunk_size):
661
  chunk = ordered_params[start:start + chunk_size]
662
  if chunk:
663
+ kwargs = dict(
664
  params=chunk,
665
  param_to_state=param_to_state,
666
  rank=rank,
 
669
  weight_decay=weight_decay,
670
  none_grad=group["none_grad"],
671
  )
672
+ if first and prelaunch_gather is not None:
673
+ kwargs['prelaunch_gather'] = prelaunch_gather
674
+ first = False
675
+ yield muon_chunk_pipeline(**kwargs)
676
 
 
 
677
  with record_function("muon::pipeline"):
678
  run_pipeline(pipelines(), max_concurrent=self.warmup_step + 1)
679
 
 
685
  names = group["names"]
686
 
687
  # Apply momentum to all params before routing/expansion.
688
+ # Batched using _foreach_* ops (compiled, fullgraph=True).
689
  with record_function("muon::momentum"):
690
+ active_params = [p for p in params if p.grad is not None]
691
+ if active_params:
692
+ # Ensure momentum buffers exist (avoid zeros_like when already present).
693
+ for p in active_params:
694
+ if "momentum_buffer" not in self.state[p]:
695
+ self.state[p]["momentum_buffer"] = torch.zeros_like(
696
+ p.grad)
697
+
698
+ # Extract local tensors for compiled batch function.
699
+ local_grads = [
700
+ p.grad._local_tensor
701
+ if isinstance(p.grad, DTensor) else p.grad
702
+ for p in active_params
703
+ ]
704
+ local_bufs = [
705
+ self.state[p]["momentum_buffer"]._local_tensor
706
+ if isinstance(self.state[p]["momentum_buffer"], DTensor)
707
+ else self.state[p]["momentum_buffer"]
708
+ for p in active_params
709
+ ]
710
+
711
+ # Wrap momentum as tensor for torch.compile.
712
+ batch_pre_ortho(local_grads, local_bufs,
713
+ torch.tensor(momentum), group["nesterov"])
714
+
715
+ # For non-nesterov, the result is the momentum buffer.
716
+ if not group["nesterov"]:
717
+ for p in active_params:
718
+ p.grad = self.state[p]["momentum_buffer"]
719
+
720
+ # Identify batched experts for deferred NS.
721
+ # Detection is cheap (condition checks only); actual NS compute is
722
+ # deferred so it can overlap with the first chunk's A2A gather.
723
+ deferred_expert_work = []
724
+ if self.expert_keys:
725
+ batched_expert_indices = []
726
+ for i, (n, p) in enumerate(zip(names, params)):
727
+ if not (is_expert_param(n, self.expert_keys)
728
+ and p.grad is not None):
729
  continue
730
+ # Eligible: plain tensor, or DTensor with no non-dim-0 shards.
731
+ if isinstance(p.data, DTensor):
732
+ has_tp = any(
733
+ _is_shard(pl) and pl.dim != 0 for pl in p.placements)
734
+ if has_tp:
735
+ continue
736
+ batched_expert_indices.append(i)
737
+
738
+ if batched_expert_indices:
739
+ # Save refs for deferred NS; free grads from param list.
740
+ for i in batched_expert_indices:
741
+ p = params[i]
742
+ g = p.grad
743
+ local_g = (g._local_tensor
744
+ if isinstance(g, DTensor) else g)
745
+ local_data = (p.data._local_tensor if isinstance(
746
+ p.data, DTensor) else p.data)
747
+ deferred_expert_work.append((local_data, local_g))
748
+ p.grad = None
749
+
750
+ # Remove batched experts from lists before expansion.
751
+ keep = sorted(
752
+ set(range(len(params))) - set(batched_expert_indices))
753
+ names = [names[i] for i in keep]
754
+ params = [params[i] for i in keep]
755
+
756
+ def _run_deferred_expert_ns():
757
+ """Execute deferred batched expert NS."""
758
+ if not deferred_expert_work:
759
+ return
760
+ with record_function("muon::batched_expert_ns"):
761
+ ns_steps = group["ns_steps"]
762
+ for local_data, local_g in deferred_expert_work:
763
+ u = zeropower_via_newtonschulz5_batched(
764
+ local_g.to(COMM_DTYPE), steps=ns_steps)
765
+ adjusted_lr = adjust_lr_for_muon(lr, local_g.shape[1:])
766
+ local_data.mul_(1 - lr * weight_decay)
767
+ local_data.add_(u, alpha=-adjusted_lr)
768
 
769
  # Expand expert params by splitting on dim 0.
770
+ logger.debug("[_step_muon] before expand: %d params, expert_keys=%s",
771
+ len(params), self.expert_keys)
772
+ if self.expert_keys:
773
+ cache_key = tuple(id(p) for p in params)
774
+ cache = self._expert_expand_cache.get(cache_key)
775
+
776
+ if cache is None:
777
+ # Cold path: full expansion + build cache metadata.
778
+ exp_names, exp_params = _expand_expert_params(
779
+ names, params, self.expert_keys)
780
+
781
+ # Build per-expert-group info for hot-path grad updates.
782
+ grad_info = []
783
+ exp_idx = 0
784
+ for orig_idx, (n, p) in enumerate(zip(names, params)):
785
+ if not is_expert_param(n, self.expert_keys):
786
+ exp_idx += 1
787
+ continue
788
+
789
+ is_dt = isinstance(p.data, DTensor)
790
+ num_experts = (p.to_local() if is_dt else p.data).shape[0]
791
+
792
+ # Detect TP mesh from the first expanded expert param.
793
+ tp_mesh = None
794
+ tp_pls = None
795
+ sample = exp_params[exp_idx]
796
+ if isinstance(sample.data, DTensor):
797
+ tp_mesh = sample.data.device_mesh
798
+ tp_pls = list(sample.data.placements)
799
+
800
+ grad_info.append((orig_idx, num_experts, exp_idx, is_dt,
801
+ tp_mesh, tp_pls))
802
+ exp_idx += num_experts
803
+
804
+ self._expert_expand_cache[cache_key] = {
805
+ 'names': exp_names,
806
+ 'params': exp_params,
807
+ 'grad_info': grad_info,
808
+ }
809
+ names, params = exp_names, exp_params
810
+ else:
811
+ # Hot path: reuse cached params, only update expert grads.
812
+ for (orig_idx, num_experts, exp_start, is_dt, tp_mesh,
813
+ tp_pls) in cache['grad_info']:
814
+ p = params[orig_idx]
815
+ g = p.grad
816
+ local_grad = (g.to_local()
817
+ if is_dt and isinstance(g, DTensor) else g)
818
+ for i in range(num_experts):
819
+ expert_p = cache['params'][exp_start + i]
820
+ sg = local_grad[i]
821
+ if tp_mesh is not None:
822
+ expert_p.grad = DTensor.from_local(
823
+ sg, device_mesh=tp_mesh, placements=tp_pls)
824
+ else:
825
+ expert_p.grad = sg
826
+ p.grad = None
827
+
828
+ names = cache['names']
829
+ params = cache['params']
830
+ else:
831
+ names, params = _expand_expert_params(names, params,
832
+ self.expert_keys)
833
+ logger.debug("[_step_muon] after expand: %d params", len(params))
834
 
835
  param_dtensors = []
836
  name_dtensors = []
 
838
  param_tensors = []
839
  name_tensors = []
840
 
841
+ # distributed_muon is a reference implementation for testing only.
842
+ # The parallel pipeline (all2all) path below is the production path.
 
843
  if self.use_distributed_muon:
844
+ _run_deferred_expert_ns()
845
  self.distributed_muon(names=names,
846
  params=params,
847
  group=group,
 
850
  qk_logits=qk_logits)
851
  return
852
 
 
 
853
  for n, p in zip(names, params):
854
  if p is None or p.grad is None:
855
  continue
 
857
  if all(
858
  isinstance(placement, Replicate)
859
  for placement in p.placements):
860
+ logger.debug(
861
+ "[route] %s → base (DTensor all-Replicate), "
862
+ "shape=%s, placements=%s", n, p.shape, p.placements)
863
  param_tensors.append(p)
864
  name_tensors.append(n)
 
 
 
865
  else:
866
+ logger.debug(
867
+ "[route] %s → parallel (DTensor), shape=%s, "
868
+ "placements=%s, mesh=%s", n, p.shape, p.placements,
869
+ p.device_mesh.mesh_dim_names)
870
  param_dtensors.append(p)
871
  name_dtensors.append(n)
872
  elif isinstance(p.data, torch.Tensor):
873
+ logger.debug("[route] %s → base (plain tensor), shape=%s", n,
874
+ p.data.shape)
875
  param_tensors.append(p)
876
  name_tensors.append(n)
877
  else:
878
  raise TypeError(f"Unsupported parameter type: {type(p.data)}")
879
 
880
+ logger.debug(f"[Muon] {len(param_dtensors)} DTensors → parallel, "
881
+ f"{len(param_tensors)} Tensors → base")
 
882
 
883
  def group_dtensors(dtensors, names):
884
  # To support different placements, we group parameters by placements
 
894
  p.device_mesh])][1].append(p)
895
  return placement_to_params
896
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
897
  if len(param_dtensors) > 0:
898
  if not dist.is_initialized():
899
  raise RuntimeError(
 
901
  )
902
 
903
  dtensor_group = group_dtensors(param_dtensors, name_dtensors)
904
+
905
+ # Pre-launch the first chunk's A2A gather so that the NCCL
906
+ # communication overlaps with the (deferred) batched expert NS
907
+ # compute on the default CUDA stream.
908
+ prelaunch = None
909
+ if deferred_expert_work:
910
+ first_names, first_params = next(iter(dtensor_group.values()))
911
+ ordered, pts, rnk, csz = self._setup_parallel(
912
+ first_names, first_params, group, qk_logits)
913
+ first_chunk = ordered[:csz]
914
+ if first_chunk:
915
+ prelaunch = prelaunch_first_gather(first_chunk, pts, rnk,
916
+ group["none_grad"])
917
+
918
+ _run_deferred_expert_ns()
919
+
920
+ first_group = True
921
  for _, (names, params) in dtensor_group.items():
922
+ pg = prelaunch if first_group else None
923
+ first_group = False
924
  self.parallel(
925
  names,
926
  params,
 
928
  lr=lr,
929
  weight_decay=weight_decay,
930
  qk_logits=qk_logits,
931
+ prelaunch_gather=pg,
932
  )
933
+ else:
934
+ _run_deferred_expert_ns()
935
 
936
  if len(param_tensors) > 0:
937
  self.base(
 
943
  qk_logits=qk_logits,
944
  )
945
 
946
+ def _register_states_for_offload(self):
947
+ """Register all optimizer state tensors with the CPU offload pool.
948
+
949
+ Called once after the first step when states have been lazily created.
950
+ Offloads all param states (momentum buffers for Muon, moment1/moment2
951
+ for AdamW) to free GPU memory between steps.
952
+ """
953
+ pool = self._cpu_offload_pool
954
+ tracked = 0
955
+ for group in self.param_groups:
956
+ for p in group["params"]:
957
+ if p not in self.state:
958
+ continue
959
+ state = self.state[p]
960
+ if group.get("use_muon", False):
961
+ if "momentum_buffer" in state:
962
+ pool.track(state["momentum_buffer"])
963
+ tracked += 1
964
+ else:
965
+ if "moment1" in state:
966
+ pool.track(state["moment1"])
967
+ if "moment2" in state:
968
+ pool.track(state["moment2"])
969
+ tracked += 1
970
+ logger.info("[CPUOffload] Registered %d param states for offload",
971
+ tracked)
972
+
973
  @torch.no_grad
974
  def step(self, closure=None, qk_logits=None):
975
  """Perform a single optimization step.
 
987
  with torch.enable_grad():
988
  loss = closure()
989
 
990
+ # H2D: reload optimizer states from CPU before computation.
991
+ if self.cpu_offload and self._offload_initialized:
992
+ self._cpu_offload_pool.reload()
993
+
994
+ logger.debug("[Muon.step] expert_keys=%s, %d param groups",
995
+ self.expert_keys, len(self.param_groups))
996
+
997
+ for i, group in enumerate(self.param_groups):
998
  if group["use_muon"]:
999
+ logger.debug("[Muon.step] group %d: use_muon=True, %d params",
1000
+ i, len(group["params"]))
1001
  self._step_muon(group, qk_logits=qk_logits)
1002
  else:
1003
+ logger.debug(
1004
+ "[Muon.step] group %d: use_muon=False (AdamW), %d params",
1005
+ i, len(group["params"]))
1006
  step_adamw(self.state, group)
1007
 
1008
+ # D2H: offload optimizer states to CPU after computation.
1009
+ if self.cpu_offload:
1010
+ if not self._offload_initialized:
1011
+ self._register_states_for_offload()
1012
+ self._offload_initialized = True
1013
+ self._cpu_offload_pool.offload()
1014
+
1015
  return loss
1016
+
1017
+ # ------------------------------------------------------------------
1018
+ # Checkpoint support for cpu_offload
1019
+ # ------------------------------------------------------------------
1020
+
1021
+ def state_dict(self) -> dict:
1022
+ """Return optimizer state dict, reloading offloaded states first.
1023
+
1024
+ When ``cpu_offload=True``, optimizer state tensors have their GPU
1025
+ storage freed (``resize_(0)``) between steps. We reload them,
1026
+ snapshot the state dict, then re-offload so the optimizer stays
1027
+ in the expected post-step state. The returned dict holds cloned
1028
+ tensors so they remain valid after the re-offload frees the
1029
+ originals' GPU storage.
1030
+ """
1031
+ if self.cpu_offload and self._offload_initialized:
1032
+ self._cpu_offload_pool.reload()
1033
+ torch.cuda.current_stream().synchronize()
1034
+ sd = super().state_dict()
1035
+ if self.cpu_offload and self._offload_initialized:
1036
+ # Clone state tensors so the returned dict survives re-offload
1037
+ # (which frees GPU storage on the originals via resize_(0)).
1038
+ for k in sd["state"]:
1039
+ sd["state"][k] = {
1040
+ sk: sv.clone() if isinstance(sv, torch.Tensor) else sv
1041
+ for sk, sv in sd["state"][k].items()
1042
+ }
1043
+ self._cpu_offload_pool.offload()
1044
+ return sd
1045
+
1046
+ def load_state_dict(self, state_dict: dict) -> None:
1047
+ """Load optimizer state dict, then offload states if needed.
1048
+
1049
+ After ``super().load_state_dict()`` populates GPU tensors, we
1050
+ re-register them with the offload pool and offload to CPU so the
1051
+ optimizer is in the same post-step state (GPU storage freed).
1052
+ """
1053
+ # If states were offloaded, reload first so storage sizes are
1054
+ # correct for super().load_state_dict() to overwrite.
1055
+ if self.cpu_offload and self._offload_initialized:
1056
+ self._cpu_offload_pool.reload()
1057
+ torch.cuda.current_stream().synchronize()
1058
+
1059
+ super().load_state_dict(state_dict)
1060
+
1061
+ if self.cpu_offload:
1062
+ # Re-create the offload pool since state tensors may be new
1063
+ # objects after load_state_dict.
1064
+ self._cpu_offload_pool = CPUOffloadPool()
1065
+ self._offload_initialized = False
1066
+ self._register_states_for_offload()
1067
+ self._offload_initialized = True
1068
+ self._cpu_offload_pool.offload()
build/torch210-cxx11-cu128-x86_64-linux/newton_schulz.py CHANGED
@@ -1,3 +1,7 @@
 
 
 
 
1
  import torch
2
 
3
  from .matmul_transpose_triton import matmul_transpose_assign
@@ -6,21 +10,134 @@ COMM_DTYPE = torch.bfloat16
6
  DEFAULT_CHUNK_SIZE_RATIO = 4
7
 
8
 
9
- # This code snippet is a modified version adapted from the following GitHub repositories:
10
- # https://github.com/KellerJordan/Muon/blob/master/muon.py
11
- # Muon's Newton–Schulz iteration causes high variance in singular values
12
- # Idea: give each iteration its own 3 coefficients and optimize them via gradient descent.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
  @torch.no_grad()
14
- # matmul_transpose_assign from : https://github.com/nil0x9/flash-muon
15
  def _zeropower_via_newtonschulz5(G, steps):
16
  """
17
- Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a
18
- quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose
19
- of minimizing steps, it turns out to be empirically effective to keep increasing the slope at
20
- zero even beyond the point where the iteration no longer converges all the way to one everywhere
21
- on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T
22
- where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model
23
- performance at all relative to UV^T, where USV^T = G is the SVD.
 
 
 
 
 
 
 
24
  """
25
  assert len(G.shape) == 2
26
  assert G.dtype == COMM_DTYPE
@@ -28,18 +145,14 @@ def _zeropower_via_newtonschulz5(G, steps):
28
 
29
  if G.size(0) > G.size(1):
30
  X = X.T
31
- # Ensure spectral norm is at most 1
32
  X = X / (X.norm() + 1e-7)
 
 
33
  buf1 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device)
34
  buf2 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device)
35
  # Perform the NS iterations
36
- for a, b, c in [
37
- (4.0848, -6.8946, 2.9270),
38
- (3.9505, -6.3029, 2.6377),
39
- (3.7418, -5.5913, 2.3037),
40
- (2.8769, -3.1427, 1.2046),
41
- (2.8366, -3.0525, 1.2012),
42
- ]:
43
  matmul_transpose_assign(X, buf1)
44
  matmul_transpose_assign(buf1, buf2)
45
  buf1.mul_(b).add_(buf2, alpha=c)
@@ -47,4 +160,77 @@ def _zeropower_via_newtonschulz5(G, steps):
47
 
48
  if G.size(0) > G.size(1):
49
  X = X.T
 
50
  return X
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from itertools import repeat
2
+ from math import inf, sqrt
3
+
4
+ import numpy as np
5
  import torch
6
 
7
  from .matmul_transpose_triton import matmul_transpose_assign
 
10
  DEFAULT_CHUNK_SIZE_RATIO = 4
11
 
12
 
13
+ def _optimal_quintic(l, u, max_iter=1000):
14
+ """
15
+ Use the simplified Remez algorithm to find the optimal odd quintic approximant
16
+ to the constant function x -> 1 over the interval [l, u].
17
+
18
+ Returns (a, b, c) for p(x) = ax + bx^3 + cx^5 that minimizes the maximum
19
+ approximation error max_{x in [l,u]} |p(x) - 1|. Iterates by updating the
20
+ two interior equioscillation nodes q, r until convergence. Returns the
21
+ closed-form equioscillating solution when l ≈ u.
22
+
23
+ Raises ValueError if any intermediate value (a, b, c, E, q, r) is non-finite
24
+ (NaN or inf). Raises RuntimeError if convergence is not reached within
25
+ max_iter iterations.
26
+ """
27
+ assert 0 <= l <= u
28
+ if 1 - 5e-6 <= l / u:
29
+ return (15 / 8) / u, (-10 / 8) / (u**3), (3 / 8) / (u**5)
30
+ q = (3 * l + u) / 4
31
+ r = (l + 3 * u) / 4
32
+ E = inf
33
+ for _ in range(max_iter):
34
+ old_E = E
35
+ LHS = np.array([
36
+ [l, l**3, l**5, 1],
37
+ [q, q**3, q**5, -1],
38
+ [r, r**3, r**5, 1],
39
+ [u, u**3, u**5, -1],
40
+ ])
41
+ a, b, c, E = np.linalg.solve(LHS, np.ones(4))
42
+ if not np.all(np.isfinite([a, b, c, E])):
43
+ raise ValueError(f"_optimal_quintic: non-finite solve result "
44
+ f"a={a}, b={b}, c={c}, E={E}")
45
+ q, r = np.sqrt(
46
+ (-3 * b + np.array([-1, 1]) * sqrt(9 * b**2 - 20 * a * c)) /
47
+ (10 * c))
48
+ if not np.all(np.isfinite([q, r])):
49
+ raise ValueError(
50
+ f"_optimal_quintic: non-finite node update q={q}, r={r}")
51
+ if abs(old_E - E) <= 1e-15:
52
+ break
53
+ else:
54
+ raise RuntimeError(
55
+ f"_optimal_quintic: did not converge after {max_iter} iterations")
56
+ return float(a), float(b), float(c)
57
+
58
+
59
def _optimal_composition(l, num_iters, safety_factor_eps=0, cushion=0):
    """
    Compute the Polar Express coefficient series for `num_iters` quintic iterations.

    Builds a sequence of per-step optimal odd quintic coefficients (a, b, c) that
    compose to map singular values from [l, 1] toward 1. At each step:
      1. Solves `_optimal_quintic` on [max(l, cushion*u), u]. The `cushion`
         prevents near-zero singular values from stalling by raising the effective
         lower bound; if it is active (cushion*u > l), the coefficients are
         rescaled so that p(l) and p(u) are centered around 1 w.r.t. the true [l, u].
      2. Deflates the coefficients by (1 + safety_factor_eps)^degree for all but the
         last iteration, providing numerical headroom at the cost of a slightly slower
         final convergence step.
      3. Advances the interval: l <- p(l), u <- 2 - p(l) (by symmetry of p around 1).

    Returns a list of (a, b, c) tuples, one per iteration.

    Reference: Amsel et al., "The Polar Express: Optimal Matrix Sign Methods and
    Their Application to the Muon Algorithm", https://arxiv.org/abs/2505.16932
    """
    u = 1
    assert 0 <= l <= u
    safety_factor = 1 + safety_factor_eps
    coefficients = []
    # NOTE: loop variable renamed from `iter` to avoid shadowing the builtin.
    for step_idx in range(num_iters):
        a, b, c = _optimal_quintic(max(l, cushion * u), u)
        if cushion * u > l:
            # Cushion was active: recenter so p(l) and p(u) straddle 1
            # with respect to the true interval [l, u].
            pl = a * l + b * l**3 + c * l**5
            pu = a * u + b * u**3 + c * u**5
            rescaler = 2 / (pl + pu)
            a *= rescaler
            b *= rescaler
            c *= rescaler
        if step_idx < num_iters - 1:
            # Deflate for numerical headroom on all but the final step.
            a /= safety_factor
            b /= safety_factor**3
            c /= safety_factor**5
        coefficients.append((a, b, c))
        # Advance the singular-value interval: p is symmetric around 1.
        l = a * l + b * l**3 + c * l**5
        u = 2 - l
    return coefficients
100
+
101
+
102
# Polar Express coefficient schedule, precomputed once at import time for
# 10 quintic Newton-Schulz iterations and reused across all optimizer steps.
# Each (a, b, c) tuple is the minimax-optimal (Remez / equioscillation) odd
# quintic approximant to x -> 1 over that step's singular-value interval.
#
# Compared with the previous 5 hardcoded NS coefficient tuples:
#   - the old constants were tuned to maximize slope at zero and never drove
#     singular values all the way to 1, yielding US'V^T with
#     S' ~ Uniform(0.5, 1.5) rather than the true polar factor;
#   - the Polar Express schedule is analytically optimal per step, adapts to
#     the shrinking interval [l, u], and converges all singular values to 1,
#     producing the exact polar factor UV^T.
_coeffs_list = _optimal_composition(l=1e-3,
                                    num_iters=10,
                                    safety_factor_eps=1e-2,
                                    cushion=0.02)
118
+
119
+
120
+ # This code is adapted from:
121
+ # KellerJordan/Muon (https://github.com/KellerJordan/Muon/blob/master/muon.py)
122
+ # NoahAmsel/PolarExpress (https://github.com/NoahAmsel/PolarExpress)
123
+ # matmul_transpose_assign kernel from nil0x9/flash-muon (https://github.com/nil0x9/flash-muon)
124
  @torch.no_grad()
 
125
  def _zeropower_via_newtonschulz5(G, steps):
126
  """
127
+ Compute the polar factor of G via the Polar Express method.
128
+
129
+ Applies `steps` quintic iterations X <- aX + bX^3 + cX^5, where (a, b, c)
130
+ are the Polar Express coefficients from `_coeffs_list`. Each step is the
131
+ optimal odd quintic approximant to x -> 1 over the current singular-value
132
+ interval, minimizing the maximum approximation error (Remez / minimax criterion).
133
+ The composition maps singular values from [l, 1] to near 1, producing the
134
+ polar factor (orthogonal factor in the polar decomposition G = UP).
135
+
136
+ `_coeffs_list` is precomputed for 10 iterations (l=1e-3, safety_factor_eps=1e-2,
137
+ cushion=0.02). If `steps` exceeds 10, the final coefficient set is repeated.
138
+
139
+ Reference: Amsel et al., "The Polar Express: Optimal Matrix Sign Methods and
140
+ Their Application to the Muon Algorithm", https://arxiv.org/abs/2505.16932
141
  """
142
  assert len(G.shape) == 2
143
  assert G.dtype == COMM_DTYPE
 
145
 
146
  if G.size(0) > G.size(1):
147
  X = X.T
148
+
149
  X = X / (X.norm() + 1e-7)
150
+ hs = _coeffs_list[:steps] + list(
151
+ repeat(_coeffs_list[-1], steps - len(_coeffs_list)))
152
  buf1 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device)
153
  buf2 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device)
154
  # Perform the NS iterations
155
+ for a, b, c in hs:
 
 
 
 
 
 
156
  matmul_transpose_assign(X, buf1)
157
  matmul_transpose_assign(buf1, buf2)
158
  buf1.mul_(b).add_(buf2, alpha=c)
 
160
 
161
  if G.size(0) > G.size(1):
162
  X = X.T
163
+
164
  return X
165
+
166
+
167
@torch.no_grad()
def _zeropower_via_newtonschulz5_batched(G, steps):
    """Batched polar-factor computation for 3D (E, out, in) tensors.

    Mirrors ``_zeropower_via_newtonschulz5`` but pushes every expert matrix
    through a single ``torch.bmm`` / ``torch.baddbmm`` call per iteration
    instead of the 2D Triton kernel.
    """
    assert len(G.shape) == 3
    assert G.dtype == COMM_DTYPE

    wide = G.size(1) > G.size(2)
    X = G.transpose(-2, -1) if wide else G

    # Normalize each expert matrix by its own Frobenius norm.
    X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7)

    # Extend the precomputed schedule by repeating the final coefficient
    # tuple when more steps are requested than were precomputed.
    schedule = _coeffs_list[:steps] + list(
        repeat(_coeffs_list[-1], steps - len(_coeffs_list)))
    for a, b, c in schedule:
        gram = torch.bmm(X, X.transpose(-2, -1))
        gram_sq = torch.bmm(gram, gram.transpose(-2, -1))
        gram.mul_(b).add_(gram_sq, alpha=c)
        X = torch.baddbmm(X, gram, X, alpha=1.0, beta=a)

    return X.transpose(-2, -1) if wide else X
197
+
198
+
199
# Per-input-shape cache of torch.compile'd NS kernels, shared by the 2D and
# batched entry points (2D and 3D shape keys never collide), plus a global
# toggle for disabling compilation altogether.
_ns_per_shape: dict[tuple[int, ...], callable] = {}
_use_compile = True


def set_ns_compile(enabled: bool):
    """Enable or disable torch.compile for the Newton-Schulz iteration."""
    global _use_compile
    _use_compile = enabled
207
+
208
+
209
def zeropower_via_newtonschulz5(G, steps=5):
    """Newton-Schulz entry point, optionally via a per-shape compiled kernel."""
    if not _use_compile:
        return _zeropower_via_newtonschulz5(G, steps)
    shape_key = G.shape
    fn = _ns_per_shape.get(shape_key)
    if fn is None:
        fn = torch.compile(_zeropower_via_newtonschulz5,
                           options={
                               "triton.cudagraphs": True,
                               "shape_padding": False
                           })
        _ns_per_shape[shape_key] = fn
    torch.compiler.cudagraph_mark_step_begin()
    # Clone: cudagraph outputs are reused buffers; hand callers a stable copy.
    return fn(G, steps).clone()
221
+
222
+
223
def zeropower_via_newtonschulz5_batched(G, steps=5):
    """Compile-cached batched Newton-Schulz for 3D expert tensors."""
    if not _use_compile:
        return _zeropower_via_newtonschulz5_batched(G, steps)
    shape_key = G.shape
    fn = _ns_per_shape.get(shape_key)
    if fn is None:
        fn = torch.compile(_zeropower_via_newtonschulz5_batched,
                           options={
                               "triton.cudagraphs": True,
                               "shape_padding": False
                           })
        _ns_per_shape[shape_key] = fn
    torch.compiler.cudagraph_mark_step_begin()
    # Clone: cudagraph outputs are reused buffers; hand callers a stable copy.
    return fn(G, steps).clone()
build/torch210-cxx11-cu128-x86_64-linux/pipeline.py CHANGED
@@ -6,8 +6,8 @@ import torch.distributed as dist
6
  from torch.distributed.tensor import DTensor
7
  from torch.profiler import record_function
8
 
9
- from .core import _muon_state, adjust_lr_for_muon, update_p
10
- from .newton_schulz import COMM_DTYPE, _zeropower_via_newtonschulz5
11
  from .qk_clip import compute_scales
12
 
13
  logger = logging.getLogger(__name__)
@@ -45,26 +45,33 @@ def _launch_gather(
45
  else:
46
  gathered_grads[id(p)] = None
47
 
48
- # Build send buffer
49
- per_dst: list[list[torch.Tensor]] = [[] for _ in range(num_ranks)]
50
  send_counts = [0] * num_ranks
51
-
52
  for p in params:
53
  state = param_to_state[id(p)]
54
- dst = state.worker_rank
55
- assert dst < num_ranks
56
- shard_elems = state.rank_numels[rank]
57
- g = p.grad
58
- g = g.to_local().to(COMM_DTYPE).contiguous()
59
- assert g.numel() == shard_elems
60
- per_dst[dst].append(g.view(-1))
61
- send_counts[dst] += shard_elems
62
-
63
- assert any(
64
- len(v) > 0 for v in
65
- per_dst), "At least one destination rank must receive a sharded tensor"
66
- per_dst_flat = [t for dst in per_dst for t in dst]
67
- send_buf = torch.cat(per_dst_flat, dim=0)
 
 
 
 
 
 
 
 
68
 
69
  # Build recv buffer
70
  recv_counts = [0] * num_ranks
@@ -120,7 +127,8 @@ def _complete_gather(
120
 
121
  shard_view = gathered_grads[id(p)][indices]
122
  n = shard_view.numel()
123
- assert n > 0
 
124
 
125
  sg = recv_buf.narrow(0, off + inner_off, n)
126
  sg = sg.reshape(shard_view.shape)
@@ -143,7 +151,7 @@ def _compute_ns(
143
  """
144
  computed_us: dict[int, torch.Tensor | None] = {}
145
  for p in owned_params:
146
- u = _zeropower_via_newtonschulz5(gathered_grads[id(p)], ns_steps)
147
  gathered_grads[id(p)] = None # free gathered grad
148
  computed_us[id(p)] = u
149
  return computed_us
@@ -163,46 +171,47 @@ def _launch_scatter(
163
  Returns:
164
  work: Async operation handle.
165
  recv_buf: Flat receive buffer (needed by ``_complete_scatter``).
166
- scattered_us: ``{id(p): empty_local_tensor}`` for all params.
 
167
  recv_counts: Per-source-rank element counts.
168
  """
169
- # Allocate scattered-u buffers
 
 
 
170
  scattered_us: dict[int, torch.Tensor] = {}
171
  for p in params:
172
- scattered_us[id(p)] = torch.empty_like(p.to_local(), dtype=COMM_DTYPE)
 
 
173
 
174
- # Build send buffer (from computed_us on owner ranks)
175
- per_dst: list[list[torch.Tensor]] = [[] for _ in range(num_ranks)]
176
  send_counts = [0] * num_ranks
177
-
178
  if owned_params:
179
  for p in owned_params:
180
  state = param_to_state[id(p)]
181
-
182
- assert computed_us[id(p)] is not None
183
- u_full = computed_us[id(p)].to(COMM_DTYPE).contiguous()
184
-
185
- total_sent = 0
186
  for dst_rank in range(num_ranks):
187
- indices = state.rank_indices[dst_rank]
188
- su = u_full[indices].flatten()
189
-
190
- n = su.numel()
191
- assert n > 0
192
 
193
- per_dst[dst_rank].append(su)
194
- send_counts[dst_rank] += n
195
- total_sent += n
196
-
197
- assert total_sent == u_full.numel()
198
-
199
- lengths = [len(v) for v in per_dst]
200
- if all(l > 0 for l in lengths):
201
- assert all(
202
- l == lengths[0] for l in lengths
203
- ), "All destination ranks must have the same number of sharded tensor"
204
- per_dst_flat = [t for dst in per_dst for t in dst]
205
- send_buf = torch.cat(per_dst_flat, dim=0)
 
 
 
 
 
206
  else:
207
  send_buf = torch.empty(0, dtype=COMM_DTYPE, device="cuda")
208
 
@@ -218,7 +227,6 @@ def _launch_scatter(
218
  recv_counts[src] = total
219
 
220
  recv_total = sum(recv_counts)
221
- assert recv_total > 0
222
  recv_buf = torch.empty(recv_total, dtype=COMM_DTYPE, device="cuda")
223
 
224
  # Launch async all-to-all
@@ -242,7 +250,13 @@ def _complete_scatter(
242
  rank: int,
243
  scattered_us: dict[int, torch.Tensor],
244
  ) -> None:
245
- """Copy recv buffer into scattered_us (in-place)."""
 
 
 
 
 
 
246
  off = 0
247
  for src in range(len(recv_counts)):
248
  block = recv_counts[src]
@@ -255,11 +269,11 @@ def _complete_scatter(
255
  if state.worker_rank != src:
256
  continue
257
  n = state.rank_numels[rank]
258
- assert n > 0
 
259
 
260
- flat_local = recv_buf.narrow(0, off + inner_off,
261
- n).view_as(p.to_local())
262
- scattered_us[id(p)].copy_(flat_local)
263
 
264
  inner_off += n
265
 
@@ -275,23 +289,40 @@ def _update_params(
275
  lr: float,
276
  weight_decay: float,
277
  ) -> None:
278
- """Apply weight decay, Muon update, and optional QK clipping."""
279
- for p in params:
280
- state = param_to_state[id(p)]
281
- u_dtensor = DTensor.from_local(
282
- scattered_us[id(p)],
283
- placements=p.placements,
284
- device_mesh=p.device_mesh,
285
- )
286
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
287
  adjusted_lr = adjust_lr_for_muon(lr, p.shape)
288
- update_p(p, u_dtensor, lr, adjusted_lr, weight_decay)
 
 
 
289
 
290
- # QK clipping applied directly on the local tensor to
291
- # avoid DTensor sharding-propagation issues with _StridedShard.
292
- scales_full = compute_scales(
293
- p,
294
- state.qk_clip_state) if state.qk_clip_state is not None else None
 
 
 
 
 
295
  if scales_full is not None:
296
  ratio = p.shape[0] // scales_full.shape[0]
297
  idx0 = state.rank_indices[rank][0]
@@ -304,6 +335,45 @@ def _update_params(
304
  p._local_tensor.mul_(row_scales.view(-1, 1))
305
 
306
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
307
  # ======================================================================
308
  # Main generator – thin orchestrator that wires stages together.
309
  # ======================================================================
@@ -318,6 +388,7 @@ def muon_chunk_pipeline(
318
  lr: float,
319
  weight_decay: float,
320
  none_grad: bool,
 
321
  ) -> Generator[None, None, None]:
322
  """Process one chunk of parameters through the full Muon pipeline.
323
 
@@ -334,9 +405,12 @@ def muon_chunk_pipeline(
334
  runs concurrently on the NCCL stream — no separate ``comm_stream``
335
  is required.
336
 
 
 
 
337
  Yields exactly **2** times:
338
 
339
- 1. After launching async all-to-all gather.
340
  2. After launching async all-to-all scatter.
341
  """
342
  process_group = param_to_state[id(params[0])].process_group
@@ -345,15 +419,19 @@ def muon_chunk_pipeline(
345
  p for p in params if param_to_state[id(p)].worker_rank == rank
346
  ]
347
 
348
- # Stages 1-2: launch async gather.
349
- with record_function("muon::launch_gather"):
350
- work, recv_buf, gathered_grads, recv_counts = _launch_gather(
351
- params, owned_params, param_to_state, rank, num_ranks,
352
- process_group)
353
-
354
- if none_grad:
355
- for p in params:
356
- p.grad = None
 
 
 
 
357
 
358
  yield # --- YIELD 1: other chunks can launch their gather ---
359
 
 
6
  from torch.distributed.tensor import DTensor
7
  from torch.profiler import record_function
8
 
9
+ from .core import _muon_state, adjust_lr_for_muon
10
+ from .newton_schulz import COMM_DTYPE, zeropower_via_newtonschulz5
11
  from .qk_clip import compute_scales
12
 
13
  logger = logging.getLogger(__name__)
 
45
  else:
46
  gathered_grads[id(p)] = None
47
 
48
+ # Build send buffer – batch grad copies via torch.cat
49
+ # (1-2 fused kernels vs N individual narrow().copy_() calls).
50
  send_counts = [0] * num_ranks
 
51
  for p in params:
52
  state = param_to_state[id(p)]
53
+ send_counts[state.worker_rank] += state.rank_numels[rank]
54
+
55
+ total_send = sum(send_counts)
56
+ if total_send > 0:
57
+ # Group grad slices by destination rank in a single pass.
58
+ dst_to_grads = [[] for _ in range(num_ranks)]
59
+ for p in params:
60
+ state = param_to_state[id(p)]
61
+ n = state.rank_numels[rank]
62
+ if n > 0:
63
+ g = p.grad.to_local()
64
+ dst_to_grads[state.worker_rank].append(g.reshape(-1))
65
+
66
+ # Flatten in dst order and cat once.
67
+ all_slices = []
68
+ for dst in range(num_ranks):
69
+ all_slices.extend(dst_to_grads[dst])
70
+ send_buf = torch.cat(all_slices)
71
+ if send_buf.dtype != COMM_DTYPE:
72
+ send_buf = send_buf.to(COMM_DTYPE)
73
+ else:
74
+ send_buf = torch.empty(0, dtype=COMM_DTYPE, device="cuda")
75
 
76
  # Build recv buffer
77
  recv_counts = [0] * num_ranks
 
127
 
128
  shard_view = gathered_grads[id(p)][indices]
129
  n = shard_view.numel()
130
+ if n == 0:
131
+ continue
132
 
133
  sg = recv_buf.narrow(0, off + inner_off, n)
134
  sg = sg.reshape(shard_view.shape)
 
151
  """
152
  computed_us: dict[int, torch.Tensor | None] = {}
153
  for p in owned_params:
154
+ u = zeropower_via_newtonschulz5(gathered_grads[id(p)], ns_steps)
155
  gathered_grads[id(p)] = None # free gathered grad
156
  computed_us[id(p)] = u
157
  return computed_us
 
171
  Returns:
172
  work: Async operation handle.
173
  recv_buf: Flat receive buffer (needed by ``_complete_scatter``).
174
+ scattered_us: Empty dict, populated by ``_complete_scatter`` with
175
+ zero-copy views into ``recv_buf``.
176
  recv_counts: Per-source-rank element counts.
177
  """
178
+ # scattered_us is populated by _complete_scatter with zero-copy views
179
+ # into recv_buf, avoiding N empty_like allocations + N copy_ calls.
180
+ # Pre-seed entries for params whose local shard is empty (rank_numels == 0)
181
+ # so _update_params can iterate all params without KeyError.
182
  scattered_us: dict[int, torch.Tensor] = {}
183
  for p in params:
184
+ if param_to_state[id(p)].rank_numels[rank] == 0:
185
+ scattered_us[id(p)] = torch.empty_like(p.to_local(),
186
+ dtype=COMM_DTYPE)
187
 
188
+ # Build send buffer batch via torch.cat
189
+ # (1 fused kernel vs N*num_ranks individual narrow().copy_() calls).
190
  send_counts = [0] * num_ranks
 
191
  if owned_params:
192
  for p in owned_params:
193
  state = param_to_state[id(p)]
 
 
 
 
 
194
  for dst_rank in range(num_ranks):
195
+ send_counts[dst_rank] += state.rank_numels[dst_rank]
 
 
 
 
196
 
197
+ total_send = sum(send_counts)
198
+ if total_send > 0:
199
+ # Cache u_full conversions to avoid redundant .to() per dst_rank.
200
+ u_fulls = {}
201
+ for p in owned_params:
202
+ u_fulls[id(p)] = computed_us[id(p)].to(COMM_DTYPE).contiguous()
203
+
204
+ # Collect slices in dst order (matches all-to-all send layout).
205
+ all_slices = []
206
+ for dst_rank in range(num_ranks):
207
+ for p in owned_params:
208
+ state = param_to_state[id(p)]
209
+ su = u_fulls[id(p)][state.rank_indices[dst_rank]].flatten()
210
+ if su.numel() > 0:
211
+ all_slices.append(su)
212
+
213
+ send_buf = torch.cat(all_slices) if all_slices else torch.empty(
214
+ 0, dtype=COMM_DTYPE, device="cuda")
215
  else:
216
  send_buf = torch.empty(0, dtype=COMM_DTYPE, device="cuda")
217
 
 
227
  recv_counts[src] = total
228
 
229
  recv_total = sum(recv_counts)
 
230
  recv_buf = torch.empty(recv_total, dtype=COMM_DTYPE, device="cuda")
231
 
232
  # Launch async all-to-all
 
250
  rank: int,
251
  scattered_us: dict[int, torch.Tensor],
252
  ) -> None:
253
+ """Populate scattered_us with zero-copy views into recv_buf.
254
+
255
+ Instead of pre-allocating tensors and copying, we assign views directly
256
+ from ``recv_buf``. This eliminates N ``empty_like`` + N ``copy_`` calls.
257
+ The underlying storage of ``recv_buf`` is kept alive through the views
258
+ until ``scattered_us`` is cleared after ``_update_params``.
259
+ """
260
  off = 0
261
  for src in range(len(recv_counts)):
262
  block = recv_counts[src]
 
269
  if state.worker_rank != src:
270
  continue
271
  n = state.rank_numels[rank]
272
+ if n == 0:
273
+ continue
274
 
275
+ scattered_us[id(p)] = recv_buf.narrow(0, off + inner_off,
276
+ n).view_as(p.to_local())
 
277
 
278
  inner_off += n
279
 
 
289
  lr: float,
290
  weight_decay: float,
291
  ) -> None:
292
+ """Apply weight decay, Muon update, and optional QK clipping.
 
 
 
 
 
 
 
293
 
294
+ Uses batched ``_foreach_mul_`` for weight decay and batched
295
+ ``_foreach_add_`` for the Muon update, grouping parameters by
296
+ adjusted_lr to minimize kernel launches while preserving float32
297
+ precision for the alpha scaling.
298
+ """
299
+ if not params:
300
+ return
301
+
302
+ # Batched weight decay: p *= (1 - lr * wd) — single fused kernel.
303
+ p_locals = [p._local_tensor for p in params]
304
+ torch._foreach_mul_(p_locals, 1.0 - lr * weight_decay)
305
+
306
+ # Group params by adjusted_lr so _foreach_add_ can use a single
307
+ # alpha per group (preserves float32 precision for alpha scaling).
308
+ lr_groups: dict[float, tuple[list, list]] = {}
309
+ for p in params:
310
  adjusted_lr = adjust_lr_for_muon(lr, p.shape)
311
+ if adjusted_lr not in lr_groups:
312
+ lr_groups[adjusted_lr] = ([], [])
313
+ lr_groups[adjusted_lr][0].append(p._local_tensor)
314
+ lr_groups[adjusted_lr][1].append(scattered_us[id(p)])
315
 
316
+ for adjusted_lr, (p_group, u_group) in lr_groups.items():
317
+ torch._foreach_add_(p_group, u_group, alpha=-adjusted_lr)
318
+
319
+ # QK clipping – applied directly on the local tensor to
320
+ # avoid DTensor sharding-propagation issues with _StridedShard.
321
+ for p in params:
322
+ state = param_to_state[id(p)]
323
+ if state.qk_clip_state is None:
324
+ continue
325
+ scales_full = compute_scales(p, state.qk_clip_state)
326
  if scales_full is not None:
327
  ratio = p.shape[0] // scales_full.shape[0]
328
  idx0 = state.rank_indices[rank][0]
 
335
  p._local_tensor.mul_(row_scales.view(-1, 1))
336
 
337
 
338
+ # ======================================================================
339
+ # Pre-launch helper for overlapping first chunk's gather with other work.
340
+ # ======================================================================
341
+
342
+
343
@torch.no_grad()
def prelaunch_first_gather(
    params: list[DTensor],
    param_to_state: dict[int, _muon_state],
    rank: int,
    none_grad: bool,
) -> tuple[dist.Work, torch.Tensor, dict[int, torch.Tensor | None], list[int]]:
    """Kick off the first chunk's A2A gather ahead of heavy GPU compute.

    Call this *before* expensive default-stream work (e.g. the batched
    expert NS) so the NCCL all-to-all overlaps with compute on the NCCL
    stream.

    Returns the same 4-tuple ``_launch_gather`` produces; hand it to
    :func:`muon_chunk_pipeline` as ``prelaunch_gather``.
    """
    pg = param_to_state[id(params[0])].process_group
    world_size = dist.get_world_size(group=pg)
    owned = [p for p in params if param_to_state[id(p)].worker_rank == rank]

    with record_function("muon::prelaunch_gather"):
        result = _launch_gather(params, owned, param_to_state, rank,
                                world_size, pg)

    if none_grad:
        for p in params:
            p.grad = None

    return result
375
+
376
+
377
  # ======================================================================
378
  # Main generator – thin orchestrator that wires stages together.
379
  # ======================================================================
 
388
  lr: float,
389
  weight_decay: float,
390
  none_grad: bool,
391
+ prelaunch_gather: tuple | None = None,
392
  ) -> Generator[None, None, None]:
393
  """Process one chunk of parameters through the full Muon pipeline.
394
 
 
405
  runs concurrently on the NCCL stream — no separate ``comm_stream``
406
  is required.
407
 
408
+ If ``prelaunch_gather`` is provided, the gather was already launched
409
+ by :func:`prelaunch_first_gather` and we skip launching it again.
410
+
411
  Yields exactly **2** times:
412
 
413
+ 1. After launching async all-to-all gather (or immediately if pre-launched).
414
  2. After launching async all-to-all scatter.
415
  """
416
  process_group = param_to_state[id(params[0])].process_group
 
419
  p for p in params if param_to_state[id(p)].worker_rank == rank
420
  ]
421
 
422
+ if prelaunch_gather is not None:
423
+ # Gather was pre-launched; none_grad already handled by caller.
424
+ work, recv_buf, gathered_grads, recv_counts = prelaunch_gather
425
+ else:
426
+ # Normal path: launch async gather.
427
+ with record_function("muon::launch_gather"):
428
+ work, recv_buf, gathered_grads, recv_counts = _launch_gather(
429
+ params, owned_params, param_to_state, rank, num_ranks,
430
+ process_group)
431
+
432
+ if none_grad:
433
+ for p in params:
434
+ p.grad = None
435
 
436
  yield # --- YIELD 1: other chunks can launch their gather ---
437
 
build/torch210-cxx11-cu128-x86_64-linux/qk_clip.py CHANGED
@@ -5,6 +5,8 @@ from dataclasses import dataclass
5
  import torch
6
  from torch.distributed.tensor import DTensor
7
 
 
 
8
  logger = logging.getLogger(__name__)
9
 
10
 
@@ -23,7 +25,7 @@ def parse_qk_layer(name: str) -> tuple[str | None, int]:
23
  'model.7.attn.k_proj.weight' -> ('k_proj', 7)
24
  'model.4.attn.v_proj.weight' -> (None, -1)
25
  """
26
- parts = name.split('.')
27
  if len(parts) < 3:
28
  return None, -1
29
 
@@ -100,23 +102,27 @@ def compute_scales(p, qk_clip_state):
100
  threshold = qk_clip_state.threshold
101
  logit = qk_clip_state.logit
102
 
103
- H_global = p.shape[0] // head_dim
104
- scales_full = torch.ones(H_global, device=p.data.device)
105
- scaling = 0
106
-
107
  for logit_idx, head_idx in enumerate(indices):
108
  v_ele = float(logit[logit_idx])
109
  if v_ele > threshold:
110
  new_scale = math.sqrt(threshold / v_ele)
111
- if new_scale < scales_full[head_idx]:
112
- scales_full[head_idx] = new_scale
113
  logger.info(
114
  f"[{kind}] Head {head_idx} exceeded threshold "
115
  f"(value={v_ele:.4f}, threshold={threshold:.4f}) -> applying scale={new_scale:.4f}"
116
  )
117
- scaling += 1
118
 
119
- return scales_full if scaling > 0 else None
 
 
 
 
 
 
 
120
 
121
 
122
  def qk_clip(p, scales, head_dim):
 
5
  import torch
6
  from torch.distributed.tensor import DTensor
7
 
8
+ from .core import normalize_fqn
9
+
10
  logger = logging.getLogger(__name__)
11
 
12
 
 
25
  'model.7.attn.k_proj.weight' -> ('k_proj', 7)
26
  'model.4.attn.v_proj.weight' -> (None, -1)
27
  """
28
+ parts = normalize_fqn(name).split('.')
29
  if len(parts) < 3:
30
  return None, -1
31
 
 
102
  threshold = qk_clip_state.threshold
103
  logit = qk_clip_state.logit
104
 
105
+ # Check if any head exceeds threshold before allocating.
106
+ head_scales = {}
 
 
107
  for logit_idx, head_idx in enumerate(indices):
108
  v_ele = float(logit[logit_idx])
109
  if v_ele > threshold:
110
  new_scale = math.sqrt(threshold / v_ele)
111
+ if head_idx not in head_scales or new_scale < head_scales[head_idx]:
112
+ head_scales[head_idx] = new_scale
113
  logger.info(
114
  f"[{kind}] Head {head_idx} exceeded threshold "
115
  f"(value={v_ele:.4f}, threshold={threshold:.4f}) -> applying scale={new_scale:.4f}"
116
  )
 
117
 
118
+ if not head_scales:
119
+ return None
120
+
121
+ H_global = p.shape[0] // head_dim
122
+ scales_full = torch.ones(H_global, device=p.data.device)
123
+ for head_idx, scale in head_scales.items():
124
+ scales_full[head_idx] = scale
125
+ return scales_full
126
 
127
 
128
  def qk_clip(p, scales, head_dim):
build/torch210-cxx11-cu130-x86_64-linux/_ops.py CHANGED
@@ -1,9 +1,9 @@
1
  import torch
2
- from . import _optimizer_7aef62f_dirty
3
- ops = torch.ops._optimizer_7aef62f_dirty
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
- return f"_optimizer_7aef62f_dirty::{op_name}"
 
1
  import torch
2
+ from . import _optimizer_5b58933_dirty
3
+ ops = torch.ops._optimizer_5b58933_dirty
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
+ return f"_optimizer_5b58933_dirty::{op_name}"
build/torch210-cxx11-cu130-x86_64-linux/{_optimizer_7aef62f_dirty.abi3.so → _optimizer_5b58933_dirty.abi3.so} RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:b9c7bb12bc030d4959e880a959b39ea07eb03e16175d7cf03829f9860f52525d
3
  size 2004728
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6869cfabdf45c7092d251846b3099287f8bccd5c5ebe7edf1a5fd21436324349
3
  size 2004728
build/torch210-cxx11-cu130-x86_64-linux/adamw.py CHANGED
@@ -1,8 +1,12 @@
 
1
  from collections import defaultdict
2
  from typing import cast
3
 
4
  import torch
5
  from torch.distributed.tensor import DTensor
 
 
 
6
 
7
 
8
  def fused_adamw(
@@ -72,54 +76,72 @@ def fused_adamw(
72
  )
73
 
74
 
75
- def step_adamw_params(optimizer_state, params, group):
76
- """Run fused AdamW on a list of parameters sharing the same placement.
 
77
 
78
- Args:
79
- optimizer_state: The optimizer's state dict (self.state in Muon).
80
- params: List of parameters to update.
81
- group: Parameter group dict with lr, adamw_betas, adamw_eps, weight_decay.
82
- """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
83
  params_with_grads = []
84
  grads = []
85
  moment1 = []
86
  moment2 = []
87
- max_exp_avg_sqs = []
88
  state_steps = []
89
- lr = group["lr"]
90
- beta1, beta2 = group["adamw_betas"]
91
- eps = group["adamw_eps"]
92
- weight_decay = group["weight_decay"]
93
 
94
  for p in params:
95
  g = p.grad
96
  if g is None:
97
  continue
98
  state = optimizer_state[p]
99
- params_with_grads.append(p)
100
- grads.append(g)
101
  if "step" not in state:
102
- state["step"] = (torch.zeros((),
103
- dtype=torch.float32,
104
- device=p.device))
105
  state["moment1"] = torch.zeros_like(g)
106
  state["moment2"] = torch.zeros_like(g)
107
- moment1.append(state["moment1"])
108
- moment2.append(state["moment2"])
109
  if not isinstance(state["step"], torch.Tensor):
110
- step_tensor = torch.tensor(state["step"],
111
- dtype=torch.float32,
112
- device=p.device)
113
- else:
114
- step_tensor = state["step"]
115
- state_steps.append(step_tensor)
 
 
 
 
 
 
116
 
117
  fused_adamw(
118
  params_with_grads,
119
  grads,
120
  moment1,
121
  moment2,
122
- max_exp_avg_sqs,
123
  state_steps,
124
  amsgrad=False,
125
  beta1=beta1,
@@ -131,24 +153,119 @@ def step_adamw_params(optimizer_state, params, group):
131
  )
132
 
133
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
134
  def step_adamw(optimizer_state, group):
135
  """Dispatch AdamW step, grouping parameters by type and placement.
136
 
 
 
 
137
  Args:
138
  optimizer_state: The optimizer's state dict (self.state in Muon).
139
  group: Parameter group dict.
140
  """
141
  params = group["params"]
 
142
 
143
- # group params with its type and placement
144
- placement_to_params: dict[tuple, list[torch.Tensor]] = defaultdict(list)
145
- for p in params:
146
- match p:
147
- case DTensor():
148
- placement_to_params[tuple([p.placements,
149
- p.device_mesh])].append(p)
150
- case torch.Tensor():
151
- placement_to_params[tuple([torch.Tensor, None])].append(p)
152
-
153
- for group_params in placement_to_params.values():
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
154
  step_adamw_params(optimizer_state, group_params, group)
 
1
+ import logging
2
  from collections import defaultdict
3
  from typing import cast
4
 
5
  import torch
6
  from torch.distributed.tensor import DTensor
7
+ from torch.profiler import record_function
8
+
9
+ logger = logging.getLogger(__name__)
10
 
11
 
12
  def fused_adamw(
 
76
  )
77
 
78
 
79
+ def _to_local(t):
80
+ """Unwrap DTensor to local tensor for fused ops."""
81
+ return t._local_tensor if isinstance(t, DTensor) else t
82
 
83
+
84
+ # ---------------------------------------------------------------------------
85
+ # Caches for eliminating per-step Python overhead.
86
+ #
87
+ # Placement grouping and tensor list assembly are identical every step
88
+ # (params don't change placement, moment/step tensors are the same objects
89
+ # after initialisation). We cache them keyed by id() of the param list
90
+ # stored in param_groups (stable across steps).
91
+ #
92
+ # Only gradients change each step and must be collected fresh.
93
+ # ---------------------------------------------------------------------------
94
+
95
+ # id(group["params"]) → dict[placement_key, list[param]]
96
+ _placement_cache: dict[int, dict[tuple, list]] = {}
97
+
98
+ # id(placement_group_list) → (params_local, moment1, moment2, state_steps)
99
+ _tensor_cache: dict[int, tuple[list, list, list, list]] = {}
100
+
101
+
102
+ def _step_adamw_params_slow(optimizer_state, params, group):
103
+ """Uncached fallback for the rare case where some params lack grads."""
104
  params_with_grads = []
105
  grads = []
106
  moment1 = []
107
  moment2 = []
 
108
  state_steps = []
 
 
 
 
109
 
110
  for p in params:
111
  g = p.grad
112
  if g is None:
113
  continue
114
  state = optimizer_state[p]
115
+ params_with_grads.append(_to_local(p))
116
+ grads.append(_to_local(g))
117
  if "step" not in state:
118
+ state["step"] = torch.zeros((),
119
+ dtype=torch.float32,
120
+ device=p.device)
121
  state["moment1"] = torch.zeros_like(g)
122
  state["moment2"] = torch.zeros_like(g)
123
+ moment1.append(_to_local(state["moment1"]))
124
+ moment2.append(_to_local(state["moment2"]))
125
  if not isinstance(state["step"], torch.Tensor):
126
+ state["step"] = torch.tensor(state["step"],
127
+ dtype=torch.float32,
128
+ device=p.device)
129
+ state_steps.append(state["step"])
130
+
131
+ if not params_with_grads:
132
+ return
133
+
134
+ lr = group["lr"]
135
+ beta1, beta2 = group["adamw_betas"]
136
+ eps = group["adamw_eps"]
137
+ weight_decay = group["weight_decay"]
138
 
139
  fused_adamw(
140
  params_with_grads,
141
  grads,
142
  moment1,
143
  moment2,
144
+ [],
145
  state_steps,
146
  amsgrad=False,
147
  beta1=beta1,
 
153
  )
154
 
155
 
156
+ def step_adamw_params(optimizer_state, params, group):
157
+ """Run fused AdamW on a list of parameters sharing the same placement.
158
+
159
+ After the first call, cached tensor lists (params_local, moment1,
160
+ moment2, state_steps) are reused — only gradients are collected fresh.
161
+
162
+ Args:
163
+ optimizer_state: The optimizer's state dict (self.state in Muon).
164
+ params: List of parameters to update.
165
+ group: Parameter group dict with lr, adamw_betas, adamw_eps, weight_decay.
166
+ """
167
+ # Collect grads — the only thing that changes each step.
168
+ with record_function("adamw::collect_grads"):
169
+ grads = []
170
+ for p in params:
171
+ g = p.grad
172
+ if g is None:
173
+ # Rare: fall back to slow path that filters per-param.
174
+ _step_adamw_params_slow(optimizer_state, params, group)
175
+ return
176
+ grads.append(_to_local(g))
177
+
178
+ tensor_key = id(params)
179
+ if tensor_key not in _tensor_cache:
180
+ with record_function("adamw::init_tensor_cache"):
181
+ params_local = []
182
+ moment1 = []
183
+ moment2 = []
184
+ state_steps = []
185
+
186
+ for p in params:
187
+ state = optimizer_state[p]
188
+ params_local.append(_to_local(p))
189
+ if "step" not in state:
190
+ state["step"] = torch.zeros((),
191
+ dtype=torch.float32,
192
+ device=p.device)
193
+ state["moment1"] = torch.zeros_like(p.grad)
194
+ state["moment2"] = torch.zeros_like(p.grad)
195
+ moment1.append(_to_local(state["moment1"]))
196
+ moment2.append(_to_local(state["moment2"]))
197
+ if not isinstance(state["step"], torch.Tensor):
198
+ state["step"] = torch.tensor(state["step"],
199
+ dtype=torch.float32,
200
+ device=p.device)
201
+ state_steps.append(state["step"])
202
+
203
+ _tensor_cache[tensor_key] = (params_local, moment1, moment2,
204
+ state_steps)
205
+
206
+ params_local, moment1, moment2, state_steps = _tensor_cache[tensor_key]
207
+
208
+ lr = group["lr"]
209
+ beta1, beta2 = group["adamw_betas"]
210
+ eps = group["adamw_eps"]
211
+ weight_decay = group["weight_decay"]
212
+
213
+ with record_function("adamw::fused_adamw"):
214
+ fused_adamw(
215
+ params_local,
216
+ grads,
217
+ moment1,
218
+ moment2,
219
+ [],
220
+ state_steps,
221
+ amsgrad=False,
222
+ beta1=beta1,
223
+ beta2=beta2,
224
+ lr=lr,
225
+ weight_decay=weight_decay,
226
+ eps=eps,
227
+ maximize=False,
228
+ )
229
+
230
+
231
  def step_adamw(optimizer_state, group):
232
  """Dispatch AdamW step, grouping parameters by type and placement.
233
 
234
+ Placement grouping is cached after the first call since params never
235
+ change their placement between steps.
236
+
237
  Args:
238
  optimizer_state: The optimizer's state dict (self.state in Muon).
239
  group: Parameter group dict.
240
  """
241
  params = group["params"]
242
+ placement_key = id(params)
243
 
244
+ if placement_key not in _placement_cache:
245
+ with record_function("adamw::group_by_placement"):
246
+ placement_to_params: dict[tuple,
247
+ list[torch.Tensor]] = defaultdict(list)
248
+ for p in params:
249
+ match p:
250
+ case DTensor():
251
+ logger.debug(
252
+ "[AdamW] DTensor param: shape=%s, placements=%s, "
253
+ "mesh=%s, grad=%s", p.shape, p.placements,
254
+ p.device_mesh.mesh_dim_names,
255
+ p.grad.shape if p.grad is not None else None)
256
+ placement_to_params[tuple(
257
+ [p.placements, p.device_mesh])].append(p)
258
+ case torch.Tensor():
259
+ logger.debug(
260
+ "[AdamW] plain param: shape=%s, grad=%s", p.shape,
261
+ p.grad.shape if p.grad is not None else None)
262
+ placement_to_params[tuple([torch.Tensor,
263
+ None])].append(p)
264
+
265
+ logger.debug("[AdamW] %d placement groups, %d total params",
266
+ len(placement_to_params), len(params))
267
+
268
+ _placement_cache[placement_key] = dict(placement_to_params)
269
+
270
+ for group_params in _placement_cache[placement_key].values():
271
  step_adamw_params(optimizer_state, group_params, group)
build/torch210-cxx11-cu130-x86_64-linux/core.py CHANGED
@@ -1,11 +1,25 @@
 
1
  import math
2
  from dataclasses import dataclass
 
3
 
4
  import torch
5
- import torch.distributed as dist
6
  from torch.distributed import ProcessGroup
7
  from torch.distributed.tensor import DTensor
8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
 
10
  @dataclass
11
  class _muon_state:
@@ -17,26 +31,71 @@ class _muon_state:
17
  qk_clip_state: torch.Tensor | None = None
18
 
19
 
20
- def update_g(optimizer_state, p, g, group, momentum):
21
- """Apply momentum update to gradient.
 
 
 
 
 
 
22
 
23
- Args:
24
- optimizer_state: The optimizer's state dict (self.state in Muon).
25
- p: Parameter tensor.
26
- g: Gradient tensor.
27
- group: Parameter group dict.
28
- momentum: Momentum coefficient.
29
 
30
- Returns:
31
- Momentum-updated gradient tensor.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
  """
33
- state = optimizer_state[p]
34
- buf = state.setdefault("momentum_buffer", torch.zeros_like(g))
35
- torch.add(g, buf, alpha=momentum, out=buf)
36
- if group["nesterov"]:
37
- g.add_(buf, alpha=momentum)
38
- return g
39
- return buf
 
 
 
 
 
 
 
 
 
 
 
40
 
41
 
42
  def update_p(p, u, lr, adjusted_lr, weight_decay):
@@ -49,14 +108,13 @@ def update_p(p, u, lr, adjusted_lr, weight_decay):
49
  adjusted_lr: Size-adjusted learning rate.
50
  weight_decay: Weight decay coefficient.
51
  """
52
- if isinstance(p, torch.nn.Parameter):
53
- # apply weight decay
54
- p.data.mul_(1 - lr * weight_decay)
55
- # apply update
56
- p.data.add_(u, alpha=-adjusted_lr)
57
- else:
58
- p.mul_(1 - lr * weight_decay)
59
- p.add_(u, alpha=-adjusted_lr)
60
 
61
 
62
  def adjust_lr_for_muon(lr, param_shape):
@@ -77,14 +135,55 @@ def adjust_lr_for_muon(lr, param_shape):
77
  return adjusted_lr
78
 
79
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
80
  def default_is_muon(name, x, expert_keys=None):
81
- skip_keys = ["embed_tokens", "lm_head", "tok_embeddings", "output"]
82
- if any(key in name for key in skip_keys):
 
 
 
 
 
 
 
 
 
 
 
 
 
83
  return False
84
  effective_ndim = x.ndim
85
- if expert_keys and any(key in name for key in expert_keys):
 
86
  effective_ndim -= 1
87
- return effective_ndim >= 2
 
 
 
 
 
88
 
89
 
90
  def get_default_muon_param_groups(model, is_muon_func=None, expert_keys=None):
@@ -92,7 +191,7 @@ def get_default_muon_param_groups(model, is_muon_func=None, expert_keys=None):
92
  is_muon_func = lambda n, x: default_is_muon(n, x, expert_keys)
93
 
94
  muon_params, muon_names = [], []
95
- non_muon_params = []
96
 
97
  for n, p in model.named_parameters():
98
  if not p.requires_grad:
@@ -102,6 +201,10 @@ def get_default_muon_param_groups(model, is_muon_func=None, expert_keys=None):
102
  muon_names.append(n)
103
  else:
104
  non_muon_params.append(p)
 
 
 
 
105
 
106
  return [
107
  {
 
1
+ import logging
2
  import math
3
  from dataclasses import dataclass
4
+ from typing import List
5
 
6
  import torch
 
7
  from torch.distributed import ProcessGroup
8
  from torch.distributed.tensor import DTensor
9
 
10
+ # torch.compile wraps modules as OptimizedModule, inserting "_orig_mod" into
11
+ # parameter FQNs. Activation checkpointing similarly inserts
12
+ # "_checkpoint_wrapped_module". Strip these so name-based matching (skip_keys,
13
+ # expert_keys, QK layer parsing) works regardless of wrapper nesting.
14
+ _WRAPPER_PARTS = frozenset({"_orig_mod", "_checkpoint_wrapped_module"})
15
+
16
+ logger = logging.getLogger(__name__)
17
+
18
+
19
+ def normalize_fqn(name: str) -> str:
20
+ """Strip torch.compile / checkpoint wrapper components from a parameter FQN."""
21
+ return ".".join(p for p in name.split(".") if p not in _WRAPPER_PARTS)
22
+
23
 
24
  @dataclass
25
  class _muon_state:
 
31
  qk_clip_state: torch.Tensor | None = None
32
 
33
 
34
+ def _batch_momentum(
35
+ grads: List[torch.Tensor],
36
+ momentum_bufs: List[torch.Tensor],
37
+ momentum: torch.Tensor,
38
+ ) -> None:
39
+ """Batched momentum update (no nesterov)."""
40
+ torch._foreach_mul_(momentum_bufs, momentum)
41
+ torch._foreach_add_(momentum_bufs, grads)
42
 
 
 
 
 
 
 
43
 
44
+ def _batch_momentum_nesterov(
45
+ grads: List[torch.Tensor],
46
+ momentum_bufs: List[torch.Tensor],
47
+ momentum: torch.Tensor,
48
+ ) -> None:
49
+ """Batched momentum update with nesterov correction."""
50
+ torch._foreach_mul_(momentum_bufs, momentum)
51
+ torch._foreach_add_(momentum_bufs, grads)
52
+ nesterov_terms = torch._foreach_mul(momentum_bufs, momentum)
53
+ torch._foreach_add_(grads, nesterov_terms)
54
+
55
+
56
+ _compiled_momentum: dict[bool, callable] = {}
57
+ _use_momentum_compile = True
58
+
59
+
60
+ def set_momentum_compile(enabled: bool):
61
+ """Toggle torch.compile for batched momentum."""
62
+ global _use_momentum_compile
63
+ _use_momentum_compile = enabled
64
+
65
+
66
+ def batch_pre_ortho(
67
+ grads: List[torch.Tensor],
68
+ momentum_bufs: List[torch.Tensor],
69
+ momentum: torch.Tensor,
70
+ nesterov: bool,
71
+ ) -> None:
72
+ """Batched momentum update on lists of plain tensors.
73
+
74
+ Mirrors dion's ``muon_update_pre_orthogonalize``.
75
+ Inputs must be plain CUDA tensors (not DTensor).
76
+ Modifies ``momentum_bufs`` and (for nesterov) ``grads`` in-place.
77
+
78
+ When compile is enabled, uses separately compiled functions for
79
+ nesterov=True/False to avoid graph breaks from the branch.
80
  """
81
+ fn = _batch_momentum_nesterov if nesterov else _batch_momentum
82
+ if _use_momentum_compile:
83
+ if nesterov not in _compiled_momentum:
84
+ _compiled_momentum[nesterov] = torch.compile(fn)
85
+ fn = _compiled_momentum[nesterov]
86
+ fn(grads, momentum_bufs, momentum)
87
+
88
+
89
+ def _update_p_impl(p_data, u_data, lr, adjusted_lr, weight_decay):
90
+ """Weight-decay + update on plain tensors.
91
+
92
+ Not compiled: per-param @torch.compile caused ~0.25ms TorchDynamo cache
93
+ lookup per call × 256+ params = massive overhead. The pipeline path uses
94
+ batched _foreach_* ops instead; this function remains for base() and
95
+ distributed_muon().
96
+ """
97
+ p_data.mul_(1 - lr * weight_decay)
98
+ p_data.add_(u_data, alpha=-adjusted_lr)
99
 
100
 
101
  def update_p(p, u, lr, adjusted_lr, weight_decay):
 
108
  adjusted_lr: Size-adjusted learning rate.
109
  weight_decay: Weight decay coefficient.
110
  """
111
+ # Unwrap Parameter -> underlying data tensor.
112
+ p_data = p.data if isinstance(p, torch.nn.Parameter) else p
113
+ # Unwrap DTensor -> local CUDA tensor for compiled kernel.
114
+ if isinstance(p_data, DTensor):
115
+ p_data = p_data._local_tensor
116
+ u_data = u._local_tensor if isinstance(u, DTensor) else u
117
+ _update_p_impl(p_data, u_data, lr, adjusted_lr, weight_decay)
 
118
 
119
 
120
  def adjust_lr_for_muon(lr, param_shape):
 
135
  return adjusted_lr
136
 
137
 
138
+ def _match_key(parts, key):
139
+ """Check if key matches as contiguous components in parts.
140
+
141
+ Single-component keys (e.g. "experts") match any single component.
142
+ Multi-component keys (e.g. "experts.w1") match as a contiguous subsequence.
143
+ """
144
+ key_parts = key.split(".")
145
+ key_len = len(key_parts)
146
+ if key_len == 1:
147
+ return key in parts
148
+ return any(parts[i:i + key_len] == key_parts
149
+ for i in range(len(parts) - key_len + 1))
150
+
151
+
152
+ def is_expert_param(name, expert_keys):
153
+ """Check if a parameter name matches any expert key (component-level)."""
154
+ if not expert_keys:
155
+ return False
156
+ parts = normalize_fqn(name).split(".")
157
+ return any(_match_key(parts, key) for key in expert_keys)
158
+
159
+
160
  def default_is_muon(name, x, expert_keys=None):
161
+ normalized = normalize_fqn(name)
162
+ parts = normalized.split(".")
163
+ skip_keys = [
164
+ "embed_tokens",
165
+ "lm_head",
166
+ "tok_embeddings",
167
+ "output",
168
+ "mhc_attn",
169
+ "mhc_ffn",
170
+ "lambda_proj",
171
+ ]
172
+ if any(key in parts for key in skip_keys):
173
+ logger.info(
174
+ "[is_muon] %s (orig: %s): skip (matched skip_key), ndim=%d",
175
+ normalized, name, x.ndim)
176
  return False
177
  effective_ndim = x.ndim
178
+ is_expert = is_expert_param(name, expert_keys)
179
+ if is_expert:
180
  effective_ndim -= 1
181
+ result = effective_ndim >= 2
182
+ logger.info(
183
+ "[is_muon] %s (orig: %s): ndim=%d, expert=%s, effective_ndim=%d → %s",
184
+ normalized, name, x.ndim, is_expert, effective_ndim,
185
+ "Muon" if result else "AdamW")
186
+ return result
187
 
188
 
189
  def get_default_muon_param_groups(model, is_muon_func=None, expert_keys=None):
 
191
  is_muon_func = lambda n, x: default_is_muon(n, x, expert_keys)
192
 
193
  muon_params, muon_names = [], []
194
+ non_muon_params, non_muon_names = [], []
195
 
196
  for n, p in model.named_parameters():
197
  if not p.requires_grad:
 
201
  muon_names.append(n)
202
  else:
203
  non_muon_params.append(p)
204
+ non_muon_names.append(n)
205
+
206
+ logger.info("[param_groups] expert_keys=%s, Muon=%d, AdamW=%d",
207
+ expert_keys, len(muon_names), len(non_muon_names))
208
 
209
  return [
210
  {
build/torch210-cxx11-cu130-x86_64-linux/cpu_offload.py ADDED
@@ -0,0 +1,188 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """CPU offloading for optimizer states.
2
+
3
+ Manages a pinned CPU memory pool and async CUDA streams to offload
4
+ optimizer state tensors (momentum buffers, Adam moments) to CPU between
5
+ optimizer steps, freeing GPU memory.
6
+
7
+ All tracked tensors are packed into a single flat pinned CPU buffer
8
+ (per dtype). D2H and H2D copies are performed per-tensor directly
9
+ between individual GPU tensors and their slice of the CPU flat buffer
10
+ — no GPU staging buffer is allocated, so there is **no temporary GPU
11
+ memory spike** during offload or reload.
12
+
13
+ Individual tensor storages are freed after offload via
14
+ ``untyped_storage().resize_(0)``, preserving tensor identity so
15
+ downstream caches remain valid.
16
+ """
17
+
18
+ import logging
19
+ from collections import defaultdict
20
+
21
+ import torch
22
+ from torch.distributed.tensor import DTensor
23
+
24
+ logger = logging.getLogger(__name__)
25
+
26
+
27
+ class CPUOffloadPool:
28
+ """Pinned CPU memory pool for async optimizer state offloading.
29
+
30
+ Tracked tensors are grouped by dtype. Each group gets a single flat
31
+ pinned CPU buffer. D2H / H2D copies are per-tensor (into slices of
32
+ the flat buffer) to avoid allocating a GPU staging buffer.
33
+ """
34
+
35
+ def __init__(self):
36
+ self._managed: list[torch.Tensor] = []
37
+ self._storage_nbytes: dict[int, int] = {} # id(t) → bytes
38
+
39
+ # Per-dtype group: populated on first offload.
40
+ # dtype → dict with keys:
41
+ # "indices" : list[int] managed-list indices
42
+ # "offsets" : list[tuple[int,int]] (start, numel) in flat buf
43
+ # "total" : int total numel
44
+ # "cpu_flat" : Tensor pinned CPU buffer
45
+ self._groups: dict[torch.dtype, dict] = {}
46
+
47
+ self._offload_stream: torch.cuda.Stream | None = None
48
+ self._device: torch.device | None = None
49
+ self._initialized: bool = False
50
+ self._logged: bool = False
51
+
52
+ # ------------------------------------------------------------------
53
+ @staticmethod
54
+ def _local(t: torch.Tensor) -> torch.Tensor:
55
+ """Unwrap DTensor to its local CUDA tensor."""
56
+ return t._local_tensor if isinstance(t, DTensor) else t
57
+
58
+ def _ensure_stream(self):
59
+ if self._offload_stream is None:
60
+ self._offload_stream = torch.cuda.Stream(device=self._device)
61
+
62
+ # ------------------------------------------------------------------
63
+ def track(self, tensor: torch.Tensor):
64
+ """Register a GPU tensor for CPU offloading. Idempotent."""
65
+ tid = id(tensor)
66
+ if tid in self._storage_nbytes:
67
+ return
68
+ local = self._local(tensor)
69
+ if self._device is None:
70
+ self._device = local.device
71
+ self._storage_nbytes[tid] = local.untyped_storage().size()
72
+ self._managed.append(tensor)
73
+
74
+ # ------------------------------------------------------------------
75
+ def _init_buffers(self):
76
+ """Build per-dtype flat buffers on first offload."""
77
+ # Group managed tensors by dtype.
78
+ dtype_map: dict[torch.dtype, list[tuple[int, int]]] = defaultdict(list)
79
+ for idx, t in enumerate(self._managed):
80
+ local = self._local(t)
81
+ dtype_map[local.dtype].append((idx, local.numel()))
82
+
83
+ total_cpu_bytes = 0
84
+ for dtype, entries in dtype_map.items():
85
+ offsets: list[tuple[int, int]] = []
86
+ indices: list[int] = []
87
+ off = 0
88
+ for idx, n in entries:
89
+ indices.append(idx)
90
+ offsets.append((off, n))
91
+ off += n
92
+ cpu_flat = torch.empty(off, dtype=dtype, device="cpu", pin_memory=True)
93
+ self._groups[dtype] = {
94
+ "indices": indices,
95
+ "offsets": offsets,
96
+ "total": off,
97
+ "cpu_flat": cpu_flat,
98
+ }
99
+ total_cpu_bytes += off * cpu_flat.element_size()
100
+
101
+ self._initialized = True
102
+ logger.info(
103
+ "[CPUOffload] Pool initialized: %d tensors, %d dtype group(s), "
104
+ "%.2f MB pinned CPU memory",
105
+ len(self._managed),
106
+ len(self._groups),
107
+ total_cpu_bytes / (1024**2),
108
+ )
109
+
110
+ # ------------------------------------------------------------------
111
+ def offload(self):
112
+ """Per-tensor async D2H into CPU flat buffer, then free GPU storage."""
113
+ if not self._managed:
114
+ return
115
+ if not self._initialized:
116
+ self._init_buffers()
117
+ self._ensure_stream()
118
+
119
+ # Offload stream waits for compute to finish.
120
+ compute_event = torch.cuda.current_stream(
121
+ self._device).record_event()
122
+ self._offload_stream.wait_event(compute_event)
123
+
124
+ offloaded_bytes = 0
125
+
126
+ # Per-tensor D2H copies directly into CPU flat buffer slices.
127
+ # No GPU staging buffer → no temporary GPU memory spike.
128
+ with torch.cuda.stream(self._offload_stream):
129
+ for dtype, grp in self._groups.items():
130
+ indices = grp["indices"]
131
+ offsets = grp["offsets"]
132
+ cpu_flat = grp["cpu_flat"]
133
+
134
+ for i, mgd_idx in enumerate(indices):
135
+ local = self._local(self._managed[mgd_idx])
136
+ off, n = offsets[i]
137
+ cpu_flat[off:off + n].copy_(
138
+ local.reshape(-1), non_blocking=True)
139
+
140
+ offloaded_bytes += grp["total"] * cpu_flat.element_size()
141
+
142
+ # Wait for all D2H copies to land, then free GPU storage.
143
+ self._offload_stream.synchronize()
144
+ for t in self._managed:
145
+ self._local(t).untyped_storage().resize_(0)
146
+
147
+ if not self._logged:
148
+ logger.info("[CPUOffload] Offloaded %.2f MB (GPU → CPU)",
149
+ offloaded_bytes / (1024**2))
150
+
151
+ # ------------------------------------------------------------------
152
+ def reload(self):
153
+ """Per-tensor H2D from CPU flat buffer on the default stream.
154
+
155
+ Runs on the current (default) CUDA stream to avoid stream
156
+ interaction issues with the parallel Muon pipeline. Since
157
+ pinned CPU memory is the source, the copies overlap with
158
+ GPU idle time between steps.
159
+ """
160
+ if not self._managed or not self._initialized:
161
+ return
162
+
163
+ reloaded_bytes = 0
164
+
165
+ # Re-allocate all GPU storages first.
166
+ for t in self._managed:
167
+ local = self._local(t)
168
+ local.untyped_storage().resize_(self._storage_nbytes[id(t)])
169
+
170
+ # Per-tensor H2D copies from CPU flat buffer slices.
171
+ # non_blocking=True with pinned source allows DMA overlap.
172
+ for dtype, grp in self._groups.items():
173
+ indices = grp["indices"]
174
+ offsets = grp["offsets"]
175
+ cpu_flat = grp["cpu_flat"]
176
+
177
+ for i, mgd_idx in enumerate(indices):
178
+ local = self._local(self._managed[mgd_idx])
179
+ off, n = offsets[i]
180
+ local.reshape(-1).copy_(
181
+ cpu_flat[off:off + n], non_blocking=True)
182
+
183
+ reloaded_bytes += grp["total"] * cpu_flat.element_size()
184
+
185
+ if not self._logged:
186
+ logger.info("[CPUOffload] Reloaded %.2f MB (CPU → GPU)",
187
+ reloaded_bytes / (1024**2))
188
+ self._logged = True
build/torch210-cxx11-cu130-x86_64-linux/distributed/utils.py CHANGED
@@ -72,12 +72,6 @@ def get_slices_of_dtensor(
72
  else:
73
  curr_size = target.size()[shard_dim]
74
 
75
- if curr_size % num_chunks != 0:
76
- raise NotImplementedError(
77
- f"Dimension size {curr_size} is not divisible "
78
- f"by number of ranks {num_chunks} for shard "
79
- f"placement on dim {shard_dim}. (shape: {target.shape})")
80
-
81
  # Compute indices for this level of sharding
82
  if isinstance(placement, _StridedShard):
83
  _shard_size, offsets = _StridedShard.local_shard_size_and_offset(
 
72
  else:
73
  curr_size = target.size()[shard_dim]
74
 
 
 
 
 
 
 
75
  # Compute indices for this level of sharding
76
  if isinstance(placement, _StridedShard):
77
  _shard_size, offsets = _StridedShard.local_shard_size_and_offset(
build/torch210-cxx11-cu130-x86_64-linux/matmul_transpose_triton.py CHANGED
@@ -43,6 +43,7 @@ def get_autotune_config():
43
  @triton.autotune(
44
  configs=get_autotune_config(),
45
  key=['M', 'K'],
 
46
  )
47
  @triton.jit
48
  def mmt_kernel(x, y, M, K, stride_xm, stride_xk, stride_ym, stride_yn,
@@ -102,16 +103,10 @@ def mmt_kernel(x, y, M, K, stride_xm, stride_xk, stride_ym, stride_yn,
102
  tl.store(ct_ptrs, tl.permute(c, (1, 0)), mask=ct_mask)
103
 
104
 
105
- def matmul_transpose_assign(d_in, d_out):
106
- assert d_in.is_cuda, "Input `d_in` must be a CUDA tensor"
107
- assert d_out.is_cuda, "Input `d_out` must be a CUDA tensor"
108
- assert d_in.device == d_out.device, "Inputs `d_in` and `d_out` must be on the same CUDA device"
109
- assert d_in.dtype == d_out.dtype, "Inputs must have the same data type"
110
- assert d_in.ndim == 2, "Input `d_in` must be a 2D tensor"
111
- assert d_out.ndim == 2, "Input `d_out` must be a 2D tensor"
112
- assert d_in.size(0) == d_out.size(0) == d_out.size(0), \
113
- "First dimension of `d_in` must match first and second dimension of `d_out`"
114
-
115
  d_in = d_in.contiguous()
116
  M, K = d_in.shape
117
  grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(
@@ -119,3 +114,9 @@ def matmul_transpose_assign(d_in, d_out):
119
  with torch.cuda.device(d_in.device.index):
120
  mmt_kernel[grid](d_in, d_out, M, K, d_in.stride(0), d_in.stride(1),
121
  d_out.stride(0), d_out.stride(1))
 
 
 
 
 
 
 
43
  @triton.autotune(
44
  configs=get_autotune_config(),
45
  key=['M', 'K'],
46
+ restore_value=['y'],
47
  )
48
  @triton.jit
49
  def mmt_kernel(x, y, M, K, stride_xm, stride_xk, stride_ym, stride_yn,
 
103
  tl.store(ct_ptrs, tl.permute(c, (1, 0)), mask=ct_mask)
104
 
105
 
106
+ @torch.library.custom_op("muon::matmul_transpose_assign",
107
+ mutates_args=("d_out", ))
108
+ def matmul_transpose_assign(d_in: torch.Tensor, d_out: torch.Tensor) -> None:
109
+ """Compute d_out = d_in @ d_in.T using an optimized Triton kernel."""
 
 
 
 
 
 
110
  d_in = d_in.contiguous()
111
  M, K = d_in.shape
112
  grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(
 
114
  with torch.cuda.device(d_in.device.index):
115
  mmt_kernel[grid](d_in, d_out, M, K, d_in.stride(0), d_in.stride(1),
116
  d_out.stride(0), d_out.stride(1))
117
+
118
+
119
+ @matmul_transpose_assign.register_fake
120
+ def _(d_in: torch.Tensor, d_out: torch.Tensor) -> None:
121
+ """FakeTensor impl: d_out is already allocated, mutation is declared."""
122
+ pass
build/torch210-cxx11-cu130-x86_64-linux/muon.py CHANGED
@@ -10,13 +10,16 @@ from torch.profiler import record_function
10
 
11
  from .adamw import step_adamw
12
  from .async_utils import run_pipeline
13
- from .core import (_muon_state, adjust_lr_for_muon,
14
- get_default_muon_param_groups, update_g, update_p)
 
15
  from .distributed.utils import (_is_shard, construct_shard_mesh,
16
  get_slices_of_dtensor)
17
  from .newton_schulz import (COMM_DTYPE, DEFAULT_CHUNK_SIZE_RATIO,
18
- _zeropower_via_newtonschulz5)
19
- from .pipeline import muon_chunk_pipeline
 
 
20
  from .qk_clip import compute_scales, get_qk_clip_info, qk_clip
21
 
22
  logger = logging.getLogger(__name__)
@@ -45,9 +48,21 @@ def _expand_expert_params(names, params, expert_keys):
45
  expanded_params = []
46
 
47
  for n, p in zip(names, params):
48
- is_expert = expert_keys and any(key in n for key in expert_keys)
49
  is_dtensor = isinstance(p.data, DTensor)
50
 
 
 
 
 
 
 
 
 
 
 
 
 
51
  if not is_expert:
52
  assert p.data.ndim <= 2, (
53
  f"Param {n} has ndim={p.data.ndim} but does not match "
@@ -168,7 +183,6 @@ class Muon(torch.optim.Optimizer):
168
  Use shard ranks * DEFAULT_CHUNK_SIZE_RATIO when -1 is specified.
169
  use_distributed_muon: Use distributed muon by Liu et al. (2024).
170
  For testing purpose only.
171
- small_param_numel_threshold: Threshold for classifying parameters as small and falling back to distributed Muon
172
  expert_keys: List of strings to identify expert-parallel parameters.
173
  If any key appears in a parameter's name, its outermost
174
  dimension is treated as the expert dimension and expanded
@@ -193,8 +207,8 @@ class Muon(torch.optim.Optimizer):
193
  warmup_step=5,
194
  chunk_size=-1,
195
  use_distributed_muon=False,
196
- small_param_numel_threshold=65536,
197
- expert_keys=None):
198
  defaults = dict(
199
  lr=lr,
200
  weight_decay=weight_decay,
@@ -228,8 +242,12 @@ class Muon(torch.optim.Optimizer):
228
  self.warmup_step = warmup_step
229
  self.chunk_size = chunk_size
230
  self.use_distributed_muon = use_distributed_muon
231
- self.small_param_numel_threshold = small_param_numel_threshold
232
  self.expert_keys = expert_keys
 
 
 
 
 
233
 
234
  def _calc_flops(self, G, steps):
235
  assert len(G.shape) == 2
@@ -333,8 +351,8 @@ class Muon(torch.optim.Optimizer):
333
  if g is None:
334
  continue
335
 
336
- u = _zeropower_via_newtonschulz5(g.to(COMM_DTYPE),
337
- steps=group["ns_steps"])
338
 
339
  adjusted_lr = adjust_lr_for_muon(lr, p.shape)
340
  update_p(p, u, lr, adjusted_lr, weight_decay)
@@ -355,52 +373,269 @@ class Muon(torch.optim.Optimizer):
355
  weight_decay: float,
356
  qk_logits: list[torch.Tensor | DTensor] | None,
357
  ):
358
- """ Implementation of Distributed Muon by Liu et al. """
359
 
360
- # Momentum is already applied by _step_muon before this method.
361
- for n, p in zip(names, params):
362
- g = p.grad
363
- if g is None:
364
- continue
365
-
366
- # Gather G
367
- if isinstance(p.data, DTensor):
368
- g_full = g.full_tensor()
369
- p_full = p.data.full_tensor()
370
- else:
371
- g_full = g
372
- p_full = p
373
-
374
- u_full = _zeropower_via_newtonschulz5(g_full.to(COMM_DTYPE),
375
- steps=group["ns_steps"])
376
-
377
- adjusted_lr = adjust_lr_for_muon(lr, p_full.shape)
378
- update_p(p_full, u_full, lr, adjusted_lr, weight_decay)
379
 
380
- qk_clip_state = get_qk_clip_info(self.clip_config, n, qk_logits)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
381
 
382
- scales_full = compute_scales(
383
- p_full, qk_clip_state) if qk_clip_state is not None else None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
384
 
385
- if scales_full is not None:
386
- qk_clip(p_full, scales_full, qk_clip_state.head_dim)
 
 
387
 
388
- if isinstance(p.data, DTensor):
389
- ndims = len(p.device_mesh.mesh.shape)
390
- p_replicate = DTensor.from_local(
391
- p_full,
392
- device_mesh=p.device_mesh,
393
- placements=[Replicate() for _ in range(ndims)],
394
- )
395
 
396
- p_sharded = p_replicate.redistribute(
397
- device_mesh=p.device_mesh,
398
- placements=p.placements,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
399
  )
400
 
401
- p.copy_(p_sharded)
402
 
403
- def parallel(self, names, params, group, lr, weight_decay, qk_logits):
 
 
 
 
 
 
 
404
  """
405
  Perform a parallel optimization step using Muon.
406
 
@@ -409,31 +644,23 @@ class Muon(torch.optim.Optimizer):
409
  interleaves multiple chunks so that communication and computation
410
  overlap across chunks (the same overlap previously achieved by the
411
  warmup + main-loop index scheduling).
 
 
 
 
412
  """
413
 
414
  # Momentum is already applied by _step_muon before this method.
415
 
416
- param_to_state, ordered_params = self.init_state_and_assign_params(
417
- names, params, group, qk_logits)
418
-
419
- # Compute local rank for this group's shard process group.
420
- shard_pg = param_to_state[id(ordered_params[0])].process_group
421
- rank = dist.get_rank(group=shard_pg)
422
-
423
- if self.chunk_size == -1:
424
- shard_ranks = dist.get_world_size(param_to_state[id(
425
- ordered_params[0])].process_group)
426
- chunk_size = shard_ranks * DEFAULT_CHUNK_SIZE_RATIO
427
- elif self.chunk_size > 0:
428
- chunk_size = self.chunk_size
429
- else:
430
- raise ValueError("chunk_size must be -1 or a positive integer.")
431
 
432
  def pipelines():
 
433
  for start in range(0, len(ordered_params), chunk_size):
434
  chunk = ordered_params[start:start + chunk_size]
435
  if chunk:
436
- yield muon_chunk_pipeline(
437
  params=chunk,
438
  param_to_state=param_to_state,
439
  rank=rank,
@@ -442,9 +669,11 @@ class Muon(torch.optim.Optimizer):
442
  weight_decay=weight_decay,
443
  none_grad=group["none_grad"],
444
  )
 
 
 
 
445
 
446
- with record_function("muon::barrier"):
447
- dist.barrier()
448
  with record_function("muon::pipeline"):
449
  run_pipeline(pipelines(), max_concurrent=self.warmup_step + 1)
450
 
@@ -456,16 +685,152 @@ class Muon(torch.optim.Optimizer):
456
  names = group["names"]
457
 
458
  # Apply momentum to all params before routing/expansion.
 
459
  with record_function("muon::momentum"):
460
- for n, p in zip(names, params):
461
- g = p.grad
462
- if g is None:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
463
  continue
464
- g = update_g(self.state, p, g, group, momentum)
465
- p.grad = g
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
466
 
467
  # Expand expert params by splitting on dim 0.
468
- names, params = _expand_expert_params(names, params, self.expert_keys)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
469
 
470
  param_dtensors = []
471
  name_dtensors = []
@@ -473,10 +838,10 @@ class Muon(torch.optim.Optimizer):
473
  param_tensors = []
474
  name_tensors = []
475
 
476
- param_dtensors_small = []
477
- name_dtensors_small = []
478
-
479
  if self.use_distributed_muon:
 
480
  self.distributed_muon(names=names,
481
  params=params,
482
  group=group,
@@ -485,8 +850,6 @@ class Muon(torch.optim.Optimizer):
485
  qk_logits=qk_logits)
486
  return
487
 
488
- # For simplicity, we use distributed Muon for small parameters
489
- # whose number of elements is below a threshold.
490
  for n, p in zip(names, params):
491
  if p is None or p.grad is None:
492
  continue
@@ -494,23 +857,28 @@ class Muon(torch.optim.Optimizer):
494
  if all(
495
  isinstance(placement, Replicate)
496
  for placement in p.placements):
 
 
 
497
  param_tensors.append(p)
498
  name_tensors.append(n)
499
- elif p.data.numel() <= self.small_param_numel_threshold:
500
- param_dtensors_small.append(p)
501
- name_dtensors_small.append(n)
502
  else:
 
 
 
 
503
  param_dtensors.append(p)
504
  name_dtensors.append(n)
505
  elif isinstance(p.data, torch.Tensor):
 
 
506
  param_tensors.append(p)
507
  name_tensors.append(n)
508
  else:
509
  raise TypeError(f"Unsupported parameter type: {type(p.data)}")
510
 
511
- logger.debug(
512
- f"[Muon] {len(param_dtensors)} DTensors, {len(param_tensors)} Tensors, "
513
- f"{len(param_dtensors_small)} Small DTensors")
514
 
515
  def group_dtensors(dtensors, names):
516
  # To support different placements, we group parameters by placements
@@ -526,21 +894,6 @@ class Muon(torch.optim.Optimizer):
526
  p.device_mesh])][1].append(p)
527
  return placement_to_params
528
 
529
- if len(param_dtensors_small) > 0:
530
- if not dist.is_initialized():
531
- raise RuntimeError(
532
- "Parallel Muon requires torch.distributed to be initialized."
533
- )
534
-
535
- self.distributed_muon(
536
- params=param_dtensors_small,
537
- names=name_dtensors_small,
538
- group=group,
539
- lr=lr,
540
- weight_decay=weight_decay,
541
- qk_logits=qk_logits,
542
- )
543
-
544
  if len(param_dtensors) > 0:
545
  if not dist.is_initialized():
546
  raise RuntimeError(
@@ -548,7 +901,26 @@ class Muon(torch.optim.Optimizer):
548
  )
549
 
550
  dtensor_group = group_dtensors(param_dtensors, name_dtensors)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
551
  for _, (names, params) in dtensor_group.items():
 
 
552
  self.parallel(
553
  names,
554
  params,
@@ -556,7 +928,10 @@ class Muon(torch.optim.Optimizer):
556
  lr=lr,
557
  weight_decay=weight_decay,
558
  qk_logits=qk_logits,
 
559
  )
 
 
560
 
561
  if len(param_tensors) > 0:
562
  self.base(
@@ -568,6 +943,33 @@ class Muon(torch.optim.Optimizer):
568
  qk_logits=qk_logits,
569
  )
570
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
571
  @torch.no_grad
572
  def step(self, closure=None, qk_logits=None):
573
  """Perform a single optimization step.
@@ -585,10 +987,82 @@ class Muon(torch.optim.Optimizer):
585
  with torch.enable_grad():
586
  loss = closure()
587
 
588
- for group in self.param_groups:
 
 
 
 
 
 
 
589
  if group["use_muon"]:
 
 
590
  self._step_muon(group, qk_logits=qk_logits)
591
  else:
 
 
 
592
  step_adamw(self.state, group)
593
 
 
 
 
 
 
 
 
594
  return loss
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
 
11
  from .adamw import step_adamw
12
  from .async_utils import run_pipeline
13
+ from .core import (_muon_state, adjust_lr_for_muon, batch_pre_ortho,
14
+ get_default_muon_param_groups, is_expert_param, update_p)
15
+ from .cpu_offload import CPUOffloadPool
16
  from .distributed.utils import (_is_shard, construct_shard_mesh,
17
  get_slices_of_dtensor)
18
  from .newton_schulz import (COMM_DTYPE, DEFAULT_CHUNK_SIZE_RATIO,
19
+ _zeropower_via_newtonschulz5,
20
+ zeropower_via_newtonschulz5,
21
+ zeropower_via_newtonschulz5_batched)
22
+ from .pipeline import muon_chunk_pipeline, prelaunch_first_gather
23
  from .qk_clip import compute_scales, get_qk_clip_info, qk_clip
24
 
25
  logger = logging.getLogger(__name__)
 
48
  expanded_params = []
49
 
50
  for n, p in zip(names, params):
51
+ is_expert = is_expert_param(n, expert_keys)
52
  is_dtensor = isinstance(p.data, DTensor)
53
 
54
+ if is_expert:
55
+ if is_dtensor:
56
+ logger.debug(
57
+ "[expand_expert] %s: expert DTensor, shape=%s, "
58
+ "placements=%s, mesh=%s, local_shape=%s", n, p.shape,
59
+ p.placements, p.device_mesh.mesh_dim_names,
60
+ p.to_local().shape)
61
+ else:
62
+ logger.debug(
63
+ "[expand_expert] %s: expert plain tensor, shape=%s", n,
64
+ p.data.shape)
65
+
66
  if not is_expert:
67
  assert p.data.ndim <= 2, (
68
  f"Param {n} has ndim={p.data.ndim} but does not match "
 
183
  Use shard ranks * DEFAULT_CHUNK_SIZE_RATIO when -1 is specified.
184
  use_distributed_muon: Use distributed muon by Liu et al. (2024).
185
  For testing purpose only.
 
186
  expert_keys: List of strings to identify expert-parallel parameters.
187
  If any key appears in a parameter's name, its outermost
188
  dimension is treated as the expert dimension and expanded
 
207
  warmup_step=5,
208
  chunk_size=-1,
209
  use_distributed_muon=False,
210
+ expert_keys=None,
211
+ cpu_offload=False):
212
  defaults = dict(
213
  lr=lr,
214
  weight_decay=weight_decay,
 
242
  self.warmup_step = warmup_step
243
  self.chunk_size = chunk_size
244
  self.use_distributed_muon = use_distributed_muon
 
245
  self.expert_keys = expert_keys
246
+ self.cpu_offload = cpu_offload
247
+ self._cpu_offload_pool = CPUOffloadPool() if cpu_offload else None
248
+ self._offload_initialized = False
249
+ self._parallel_cache: dict[tuple[str, ...], dict] = {}
250
+ self._expert_expand_cache: dict[tuple[int, ...], dict] = {}
251
 
252
  def _calc_flops(self, G, steps):
253
  assert len(G.shape) == 2
 
351
  if g is None:
352
  continue
353
 
354
+ u = zeropower_via_newtonschulz5(g.to(COMM_DTYPE),
355
+ steps=group["ns_steps"])
356
 
357
  adjusted_lr = adjust_lr_for_muon(lr, p.shape)
358
  update_p(p, u, lr, adjusted_lr, weight_decay)
 
373
  weight_decay: float,
374
  qk_logits: list[torch.Tensor | DTensor] | None,
375
  ):
376
+ """Batched Distributed Muon for testing/correctness verification only.
377
 
378
+ Uses all-gather to reconstruct full tensors, computes Newton-Schulz on
379
+ the full grad, then slices back to local shards. This is simpler but
380
+ slower than the parallel pipeline (all2all) path, so it serves as a
381
+ reference implementation for verifying correctness.
382
+ """
383
+ with record_function("distributed_muon"):
384
+ # Momentum is already applied by _step_muon before this method.
385
+ ns_steps = group["ns_steps"]
 
 
 
 
 
 
 
 
 
 
 
386
 
387
+ # Separate plain tensors (no communication) from DTensors.
388
+ plain_names, plain_params = [], []
389
+ dtensor_names, dtensor_params = [], []
390
+ for n, p in zip(names, params):
391
+ if p.grad is None:
392
+ continue
393
+ if isinstance(p.data, DTensor):
394
+ dtensor_names.append(n)
395
+ dtensor_params.append(p)
396
+ else:
397
+ plain_names.append(n)
398
+ plain_params.append(p)
399
+
400
+ # Process plain tensors per-param (no communication).
401
+ for n, p in zip(plain_names, plain_params):
402
+ u = _zeropower_via_newtonschulz5(p.grad.to(COMM_DTYPE),
403
+ steps=ns_steps)
404
+ adjusted_lr = adjust_lr_for_muon(lr, p.shape)
405
+ update_p(p, u, lr, adjusted_lr, weight_decay)
406
+
407
+ qk_clip_state = get_qk_clip_info(self.clip_config, n,
408
+ qk_logits)
409
+ scales_full = compute_scales(
410
+ p, qk_clip_state) if qk_clip_state is not None else None
411
+ if scales_full is not None:
412
+ qk_clip(p, scales_full, qk_clip_state.head_dim)
413
+
414
+ if not dtensor_params:
415
+ return
416
+
417
+ # Group DTensors by (placements, mesh) for batched all-gather.
418
+ placement_groups: dict[tuple,
419
+ tuple[list,
420
+ list]] = defaultdict(lambda: ([], []))
421
+ for n, p in zip(dtensor_names, dtensor_params):
422
+ key = (p.placements, p.device_mesh)
423
+ placement_groups[key][0].append(n)
424
+ placement_groups[key][1].append(p)
425
+
426
+ logger.info(
427
+ "distributed_muon: %d placement groups, %d total dtensors",
428
+ len(placement_groups), len(dtensor_params))
429
+
430
+ for (placements, mesh), (grp_names,
431
+ grp_params) in placement_groups.items():
432
+ shard_mesh, shard_pg, shard_placements = construct_shard_mesh(
433
+ placements, mesh)
434
+ rank = dist.get_rank(shard_pg)
435
+ world_size = dist.get_world_size(shard_pg)
436
+
437
+ logger.info(" group: %d params, placements=%s, world_size=%d",
438
+ len(grp_params), placements, world_size)
439
+
440
+ # Separate params that can be batched (all shard dims evenly
441
+ # divisible) from those needing per-param full_tensor
442
+ # (e.g. MoE gate weights with fewer rows than shard ranks).
443
+ # all_gather_into_tensor requires equal buffer sizes across
444
+ # ranks, so uneven splits must use DTensor full_tensor().
445
+ batch_names, batch_params = [], []
446
+ single_names, single_params = [], []
447
+ for n, p in zip(grp_names, grp_params):
448
+ even = all(p.shape[pl.dim] %
449
+ shard_mesh.mesh.shape[dim_idx] == 0
450
+ for dim_idx, pl in enumerate(shard_placements))
451
+ if even:
452
+ batch_names.append(n)
453
+ batch_params.append(p)
454
+ else:
455
+ single_names.append(n)
456
+ single_params.append(p)
457
+
458
+ # Process uneven-split params per-param via full_tensor().
459
+ for n, p in zip(single_names, single_params):
460
+ with record_function("distributed_muon::newton_schulz"):
461
+ g_full = p.grad.full_tensor().to(COMM_DTYPE)
462
+ u_full = _zeropower_via_newtonschulz5(g_full,
463
+ steps=ns_steps)
464
+ del g_full
465
+ with record_function("distributed_muon::update"):
466
+ adjusted_lr = adjust_lr_for_muon(lr, p.shape)
467
+ p._local_tensor.mul_(1 - lr * weight_decay)
468
+ local_indices = get_slices_of_dtensor(
469
+ p, rank, shard_mesh, shard_placements)
470
+ u_local = u_full[local_indices]
471
+ p._local_tensor.add_(u_local, alpha=-adjusted_lr)
472
+ del u_full
473
+
474
+ qk_clip_state = get_qk_clip_info(
475
+ self.clip_config, n, qk_logits)
476
+ scales_full = compute_scales(
477
+ p, qk_clip_state
478
+ ) if qk_clip_state is not None else None
479
+ if scales_full is not None:
480
+ ratio = p.shape[0] // scales_full.shape[0]
481
+ idx0 = local_indices[0]
482
+ if isinstance(idx0, slice):
483
+ start = idx0.start or 0
484
+ idx0 = torch.arange(start,
485
+ idx0.stop,
486
+ device=scales_full.device)
487
+ row_scales = scales_full[idx0 // ratio]
488
+ p._local_tensor.mul_(row_scales.view(-1, 1))
489
+
490
+ if not batch_params:
491
+ continue
492
 
493
+ logger.info(" batched=%d, single=%d", len(batch_params),
494
+ len(single_params))
495
+
496
+ # Concat all local grad shards into a single flat buffer.
497
+ with record_function("distributed_muon::gather"):
498
+ grad_locals = [
499
+ p.grad.to_local().to(COMM_DTYPE).flatten()
500
+ for p in batch_params
501
+ ]
502
+ numels = [g.numel() for g in grad_locals]
503
+ grad_concat = torch.cat(grad_locals)
504
+ del grad_locals
505
+
506
+ # Single all-gather (replaces N separate full_tensor).
507
+ grad_gathered = torch.empty(
508
+ grad_concat.numel() * world_size,
509
+ dtype=COMM_DTYPE,
510
+ device="cuda",
511
+ )
512
+ dist.all_gather_into_tensor(grad_gathered,
513
+ grad_concat,
514
+ group=shard_pg)
515
+
516
+ total_numel = grad_concat.numel()
517
+ del grad_concat
518
+
519
+ # Precompute per-param offsets within the concat buffer.
520
+ offsets = []
521
+ off = 0
522
+ for ne in numels:
523
+ offsets.append(off)
524
+ off += ne
525
+
526
+ # Per-param: reconstruct full grad → NS → local update.
527
+ for i, (n, p) in enumerate(zip(batch_names, batch_params)):
528
+ with record_function("distributed_muon::newton_schulz"):
529
+ g_full = torch.empty(p.shape,
530
+ dtype=COMM_DTYPE,
531
+ device="cuda")
532
+ for r in range(world_size):
533
+ r_start = r * total_numel + offsets[i]
534
+ shard = grad_gathered[r_start:r_start + numels[i]]
535
+ indices = get_slices_of_dtensor(
536
+ p, r, shard_mesh, shard_placements)
537
+ g_full[indices] = shard.reshape(
538
+ g_full[indices].shape)
539
+
540
+ u_full = _zeropower_via_newtonschulz5(g_full,
541
+ steps=ns_steps)
542
+ del g_full
543
+
544
+ with record_function("distributed_muon::update"):
545
+ adjusted_lr = adjust_lr_for_muon(lr, p.shape)
546
+ p._local_tensor.mul_(1 - lr * weight_decay)
547
+ local_indices = get_slices_of_dtensor(
548
+ p, rank, shard_mesh, shard_placements)
549
+ u_local = u_full[local_indices]
550
+ p._local_tensor.add_(u_local, alpha=-adjusted_lr)
551
+ del u_full
552
+
553
+ qk_clip_state = get_qk_clip_info(
554
+ self.clip_config, n, qk_logits)
555
+ scales_full = compute_scales(
556
+ p, qk_clip_state
557
+ ) if qk_clip_state is not None else None
558
+ if scales_full is not None:
559
+ ratio = p.shape[0] // scales_full.shape[0]
560
+ idx0 = local_indices[0]
561
+ if isinstance(idx0, slice):
562
+ start = idx0.start or 0
563
+ idx0 = torch.arange(start,
564
+ idx0.stop,
565
+ device=scales_full.device)
566
+ row_scales = scales_full[idx0 // ratio]
567
+ p._local_tensor.mul_(row_scales.view(-1, 1))
568
+
569
+ def _setup_parallel(self, names, params, group, qk_logits):
570
+ """Compute (or retrieve cached) parallel pipeline metadata.
571
+
572
+ Returns:
573
+ (ordered_params, param_to_state, rank, chunk_size)
574
+ """
575
+ cache_key = tuple(names)
576
 
577
+ if cache_key not in self._parallel_cache:
578
+ # First call: compute metadata and populate cache.
579
+ param_to_state, ordered_params = self.init_state_and_assign_params(
580
+ names, params, group, qk_logits)
581
 
582
+ shard_pg = param_to_state[id(ordered_params[0])].process_group
583
+ rank = dist.get_rank(group=shard_pg)
 
 
 
 
 
584
 
585
+ if self.chunk_size == -1:
586
+ shard_ranks = dist.get_world_size(shard_pg)
587
+ chunk_size = shard_ranks * DEFAULT_CHUNK_SIZE_RATIO
588
+ elif self.chunk_size > 0:
589
+ chunk_size = self.chunk_size
590
+ else:
591
+ raise ValueError(
592
+ "chunk_size must be -1 or a positive integer.")
593
+
594
+ ordered_names = [
595
+ param_to_state[id(p)].name for p in ordered_params
596
+ ]
597
+ name_to_state = {
598
+ param_to_state[id(p)].name: param_to_state[id(p)]
599
+ for p in ordered_params
600
+ }
601
+ self._parallel_cache[cache_key] = {
602
+ 'ordered_names': ordered_names,
603
+ 'name_to_state': name_to_state,
604
+ 'rank': rank,
605
+ 'chunk_size': chunk_size,
606
+ }
607
+ else:
608
+ # Cached path: rebuild param_to_state with current id(p) keys.
609
+ cache = self._parallel_cache[cache_key]
610
+ rank = cache['rank']
611
+ chunk_size = cache['chunk_size']
612
+
613
+ name_to_param = dict(zip(names, params))
614
+ ordered_params = [name_to_param[n] for n in cache['ordered_names']]
615
+
616
+ param_to_state = {}
617
+ for p, n in zip(ordered_params, cache['ordered_names']):
618
+ cached_state = cache['name_to_state'][n]
619
+ param_to_state[id(p)] = _muon_state(
620
+ worker_rank=cached_state.worker_rank,
621
+ process_group=cached_state.process_group,
622
+ rank_indices=cached_state.rank_indices,
623
+ rank_numels=cached_state.rank_numels,
624
+ name=n,
625
+ qk_clip_state=get_qk_clip_info(self.clip_config, n,
626
+ qk_logits),
627
  )
628
 
629
+ return ordered_params, param_to_state, rank, chunk_size
630
 
631
+ def parallel(self,
632
+ names,
633
+ params,
634
+ group,
635
+ lr,
636
+ weight_decay,
637
+ qk_logits,
638
+ prelaunch_gather=None):
639
  """
640
  Perform a parallel optimization step using Muon.
641
 
 
644
  interleaves multiple chunks so that communication and computation
645
  overlap across chunks (the same overlap previously achieved by the
646
  warmup + main-loop index scheduling).
647
+
648
+ If ``prelaunch_gather`` is provided, it is passed to the first
649
+ chunk's generator to skip re-launching the already in-flight
650
+ A2A gather.
651
  """
652
 
653
  # Momentum is already applied by _step_muon before this method.
654
 
655
+ ordered_params, param_to_state, rank, chunk_size = (
656
+ self._setup_parallel(names, params, group, qk_logits))
 
 
 
 
 
 
 
 
 
 
 
 
 
657
 
658
  def pipelines():
659
+ first = True
660
  for start in range(0, len(ordered_params), chunk_size):
661
  chunk = ordered_params[start:start + chunk_size]
662
  if chunk:
663
+ kwargs = dict(
664
  params=chunk,
665
  param_to_state=param_to_state,
666
  rank=rank,
 
669
  weight_decay=weight_decay,
670
  none_grad=group["none_grad"],
671
  )
672
+ if first and prelaunch_gather is not None:
673
+ kwargs['prelaunch_gather'] = prelaunch_gather
674
+ first = False
675
+ yield muon_chunk_pipeline(**kwargs)
676
 
 
 
677
  with record_function("muon::pipeline"):
678
  run_pipeline(pipelines(), max_concurrent=self.warmup_step + 1)
679
 
 
685
  names = group["names"]
686
 
687
  # Apply momentum to all params before routing/expansion.
688
+ # Batched using _foreach_* ops (compiled, fullgraph=True).
689
  with record_function("muon::momentum"):
690
+ active_params = [p for p in params if p.grad is not None]
691
+ if active_params:
692
+ # Ensure momentum buffers exist (avoid zeros_like when already present).
693
+ for p in active_params:
694
+ if "momentum_buffer" not in self.state[p]:
695
+ self.state[p]["momentum_buffer"] = torch.zeros_like(
696
+ p.grad)
697
+
698
+ # Extract local tensors for compiled batch function.
699
+ local_grads = [
700
+ p.grad._local_tensor
701
+ if isinstance(p.grad, DTensor) else p.grad
702
+ for p in active_params
703
+ ]
704
+ local_bufs = [
705
+ self.state[p]["momentum_buffer"]._local_tensor
706
+ if isinstance(self.state[p]["momentum_buffer"], DTensor)
707
+ else self.state[p]["momentum_buffer"]
708
+ for p in active_params
709
+ ]
710
+
711
+ # Wrap momentum as tensor for torch.compile.
712
+ batch_pre_ortho(local_grads, local_bufs,
713
+ torch.tensor(momentum), group["nesterov"])
714
+
715
+ # For non-nesterov, the result is the momentum buffer.
716
+ if not group["nesterov"]:
717
+ for p in active_params:
718
+ p.grad = self.state[p]["momentum_buffer"]
719
+
720
+ # Identify batched experts for deferred NS.
721
+ # Detection is cheap (condition checks only); actual NS compute is
722
+ # deferred so it can overlap with the first chunk's A2A gather.
723
+ deferred_expert_work = []
724
+ if self.expert_keys:
725
+ batched_expert_indices = []
726
+ for i, (n, p) in enumerate(zip(names, params)):
727
+ if not (is_expert_param(n, self.expert_keys)
728
+ and p.grad is not None):
729
  continue
730
+ # Eligible: plain tensor, or DTensor with no non-dim-0 shards.
731
+ if isinstance(p.data, DTensor):
732
+ has_tp = any(
733
+ _is_shard(pl) and pl.dim != 0 for pl in p.placements)
734
+ if has_tp:
735
+ continue
736
+ batched_expert_indices.append(i)
737
+
738
+ if batched_expert_indices:
739
+ # Save refs for deferred NS; free grads from param list.
740
+ for i in batched_expert_indices:
741
+ p = params[i]
742
+ g = p.grad
743
+ local_g = (g._local_tensor
744
+ if isinstance(g, DTensor) else g)
745
+ local_data = (p.data._local_tensor if isinstance(
746
+ p.data, DTensor) else p.data)
747
+ deferred_expert_work.append((local_data, local_g))
748
+ p.grad = None
749
+
750
+ # Remove batched experts from lists before expansion.
751
+ keep = sorted(
752
+ set(range(len(params))) - set(batched_expert_indices))
753
+ names = [names[i] for i in keep]
754
+ params = [params[i] for i in keep]
755
+
756
+ def _run_deferred_expert_ns():
757
+ """Execute deferred batched expert NS."""
758
+ if not deferred_expert_work:
759
+ return
760
+ with record_function("muon::batched_expert_ns"):
761
+ ns_steps = group["ns_steps"]
762
+ for local_data, local_g in deferred_expert_work:
763
+ u = zeropower_via_newtonschulz5_batched(
764
+ local_g.to(COMM_DTYPE), steps=ns_steps)
765
+ adjusted_lr = adjust_lr_for_muon(lr, local_g.shape[1:])
766
+ local_data.mul_(1 - lr * weight_decay)
767
+ local_data.add_(u, alpha=-adjusted_lr)
768
 
769
  # Expand expert params by splitting on dim 0.
770
+ logger.debug("[_step_muon] before expand: %d params, expert_keys=%s",
771
+ len(params), self.expert_keys)
772
+ if self.expert_keys:
773
+ cache_key = tuple(id(p) for p in params)
774
+ cache = self._expert_expand_cache.get(cache_key)
775
+
776
+ if cache is None:
777
+ # Cold path: full expansion + build cache metadata.
778
+ exp_names, exp_params = _expand_expert_params(
779
+ names, params, self.expert_keys)
780
+
781
+ # Build per-expert-group info for hot-path grad updates.
782
+ grad_info = []
783
+ exp_idx = 0
784
+ for orig_idx, (n, p) in enumerate(zip(names, params)):
785
+ if not is_expert_param(n, self.expert_keys):
786
+ exp_idx += 1
787
+ continue
788
+
789
+ is_dt = isinstance(p.data, DTensor)
790
+ num_experts = (p.to_local() if is_dt else p.data).shape[0]
791
+
792
+ # Detect TP mesh from the first expanded expert param.
793
+ tp_mesh = None
794
+ tp_pls = None
795
+ sample = exp_params[exp_idx]
796
+ if isinstance(sample.data, DTensor):
797
+ tp_mesh = sample.data.device_mesh
798
+ tp_pls = list(sample.data.placements)
799
+
800
+ grad_info.append((orig_idx, num_experts, exp_idx, is_dt,
801
+ tp_mesh, tp_pls))
802
+ exp_idx += num_experts
803
+
804
+ self._expert_expand_cache[cache_key] = {
805
+ 'names': exp_names,
806
+ 'params': exp_params,
807
+ 'grad_info': grad_info,
808
+ }
809
+ names, params = exp_names, exp_params
810
+ else:
811
+ # Hot path: reuse cached params, only update expert grads.
812
+ for (orig_idx, num_experts, exp_start, is_dt, tp_mesh,
813
+ tp_pls) in cache['grad_info']:
814
+ p = params[orig_idx]
815
+ g = p.grad
816
+ local_grad = (g.to_local()
817
+ if is_dt and isinstance(g, DTensor) else g)
818
+ for i in range(num_experts):
819
+ expert_p = cache['params'][exp_start + i]
820
+ sg = local_grad[i]
821
+ if tp_mesh is not None:
822
+ expert_p.grad = DTensor.from_local(
823
+ sg, device_mesh=tp_mesh, placements=tp_pls)
824
+ else:
825
+ expert_p.grad = sg
826
+ p.grad = None
827
+
828
+ names = cache['names']
829
+ params = cache['params']
830
+ else:
831
+ names, params = _expand_expert_params(names, params,
832
+ self.expert_keys)
833
+ logger.debug("[_step_muon] after expand: %d params", len(params))
834
 
835
  param_dtensors = []
836
  name_dtensors = []
 
838
  param_tensors = []
839
  name_tensors = []
840
 
841
+ # distributed_muon is a reference implementation for testing only.
842
+ # The parallel pipeline (all2all) path below is the production path.
 
843
  if self.use_distributed_muon:
844
+ _run_deferred_expert_ns()
845
  self.distributed_muon(names=names,
846
  params=params,
847
  group=group,
 
850
  qk_logits=qk_logits)
851
  return
852
 
 
 
853
  for n, p in zip(names, params):
854
  if p is None or p.grad is None:
855
  continue
 
857
  if all(
858
  isinstance(placement, Replicate)
859
  for placement in p.placements):
860
+ logger.debug(
861
+ "[route] %s → base (DTensor all-Replicate), "
862
+ "shape=%s, placements=%s", n, p.shape, p.placements)
863
  param_tensors.append(p)
864
  name_tensors.append(n)
 
 
 
865
  else:
866
+ logger.debug(
867
+ "[route] %s → parallel (DTensor), shape=%s, "
868
+ "placements=%s, mesh=%s", n, p.shape, p.placements,
869
+ p.device_mesh.mesh_dim_names)
870
  param_dtensors.append(p)
871
  name_dtensors.append(n)
872
  elif isinstance(p.data, torch.Tensor):
873
+ logger.debug("[route] %s → base (plain tensor), shape=%s", n,
874
+ p.data.shape)
875
  param_tensors.append(p)
876
  name_tensors.append(n)
877
  else:
878
  raise TypeError(f"Unsupported parameter type: {type(p.data)}")
879
 
880
+ logger.debug(f"[Muon] {len(param_dtensors)} DTensors → parallel, "
881
+ f"{len(param_tensors)} Tensors → base")
 
882
 
883
  def group_dtensors(dtensors, names):
884
  # To support different placements, we group parameters by placements
 
894
  p.device_mesh])][1].append(p)
895
  return placement_to_params
896
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
897
  if len(param_dtensors) > 0:
898
  if not dist.is_initialized():
899
  raise RuntimeError(
 
901
  )
902
 
903
  dtensor_group = group_dtensors(param_dtensors, name_dtensors)
904
+
905
+ # Pre-launch the first chunk's A2A gather so that the NCCL
906
+ # communication overlaps with the (deferred) batched expert NS
907
+ # compute on the default CUDA stream.
908
+ prelaunch = None
909
+ if deferred_expert_work:
910
+ first_names, first_params = next(iter(dtensor_group.values()))
911
+ ordered, pts, rnk, csz = self._setup_parallel(
912
+ first_names, first_params, group, qk_logits)
913
+ first_chunk = ordered[:csz]
914
+ if first_chunk:
915
+ prelaunch = prelaunch_first_gather(first_chunk, pts, rnk,
916
+ group["none_grad"])
917
+
918
+ _run_deferred_expert_ns()
919
+
920
+ first_group = True
921
  for _, (names, params) in dtensor_group.items():
922
+ pg = prelaunch if first_group else None
923
+ first_group = False
924
  self.parallel(
925
  names,
926
  params,
 
928
  lr=lr,
929
  weight_decay=weight_decay,
930
  qk_logits=qk_logits,
931
+ prelaunch_gather=pg,
932
  )
933
+ else:
934
+ _run_deferred_expert_ns()
935
 
936
  if len(param_tensors) > 0:
937
  self.base(
 
943
  qk_logits=qk_logits,
944
  )
945
 
946
+ def _register_states_for_offload(self):
947
+ """Register all optimizer state tensors with the CPU offload pool.
948
+
949
+ Called once after the first step when states have been lazily created.
950
+ Offloads all param states (momentum buffers for Muon, moment1/moment2
951
+ for AdamW) to free GPU memory between steps.
952
+ """
953
+ pool = self._cpu_offload_pool
954
+ tracked = 0
955
+ for group in self.param_groups:
956
+ for p in group["params"]:
957
+ if p not in self.state:
958
+ continue
959
+ state = self.state[p]
960
+ if group.get("use_muon", False):
961
+ if "momentum_buffer" in state:
962
+ pool.track(state["momentum_buffer"])
963
+ tracked += 1
964
+ else:
965
+ if "moment1" in state:
966
+ pool.track(state["moment1"])
967
+ if "moment2" in state:
968
+ pool.track(state["moment2"])
969
+ tracked += 1
970
+ logger.info("[CPUOffload] Registered %d param states for offload",
971
+ tracked)
972
+
973
  @torch.no_grad
974
  def step(self, closure=None, qk_logits=None):
975
  """Perform a single optimization step.
 
987
  with torch.enable_grad():
988
  loss = closure()
989
 
990
+ # H2D: reload optimizer states from CPU before computation.
991
+ if self.cpu_offload and self._offload_initialized:
992
+ self._cpu_offload_pool.reload()
993
+
994
+ logger.debug("[Muon.step] expert_keys=%s, %d param groups",
995
+ self.expert_keys, len(self.param_groups))
996
+
997
+ for i, group in enumerate(self.param_groups):
998
  if group["use_muon"]:
999
+ logger.debug("[Muon.step] group %d: use_muon=True, %d params",
1000
+ i, len(group["params"]))
1001
  self._step_muon(group, qk_logits=qk_logits)
1002
  else:
1003
+ logger.debug(
1004
+ "[Muon.step] group %d: use_muon=False (AdamW), %d params",
1005
+ i, len(group["params"]))
1006
  step_adamw(self.state, group)
1007
 
1008
+ # D2H: offload optimizer states to CPU after computation.
1009
+ if self.cpu_offload:
1010
+ if not self._offload_initialized:
1011
+ self._register_states_for_offload()
1012
+ self._offload_initialized = True
1013
+ self._cpu_offload_pool.offload()
1014
+
1015
  return loss
1016
+
1017
+ # ------------------------------------------------------------------
1018
+ # Checkpoint support for cpu_offload
1019
+ # ------------------------------------------------------------------
1020
+
1021
+ def state_dict(self) -> dict:
1022
+ """Return optimizer state dict, reloading offloaded states first.
1023
+
1024
+ When ``cpu_offload=True``, optimizer state tensors have their GPU
1025
+ storage freed (``resize_(0)``) between steps. We reload them,
1026
+ snapshot the state dict, then re-offload so the optimizer stays
1027
+ in the expected post-step state. The returned dict holds cloned
1028
+ tensors so they remain valid after the re-offload frees the
1029
+ originals' GPU storage.
1030
+ """
1031
+ if self.cpu_offload and self._offload_initialized:
1032
+ self._cpu_offload_pool.reload()
1033
+ torch.cuda.current_stream().synchronize()
1034
+ sd = super().state_dict()
1035
+ if self.cpu_offload and self._offload_initialized:
1036
+ # Clone state tensors so the returned dict survives re-offload
1037
+ # (which frees GPU storage on the originals via resize_(0)).
1038
+ for k in sd["state"]:
1039
+ sd["state"][k] = {
1040
+ sk: sv.clone() if isinstance(sv, torch.Tensor) else sv
1041
+ for sk, sv in sd["state"][k].items()
1042
+ }
1043
+ self._cpu_offload_pool.offload()
1044
+ return sd
1045
+
1046
+ def load_state_dict(self, state_dict: dict) -> None:
1047
+ """Load optimizer state dict, then offload states if needed.
1048
+
1049
+ After ``super().load_state_dict()`` populates GPU tensors, we
1050
+ re-register them with the offload pool and offload to CPU so the
1051
+ optimizer is in the same post-step state (GPU storage freed).
1052
+ """
1053
+ # If states were offloaded, reload first so storage sizes are
1054
+ # correct for super().load_state_dict() to overwrite.
1055
+ if self.cpu_offload and self._offload_initialized:
1056
+ self._cpu_offload_pool.reload()
1057
+ torch.cuda.current_stream().synchronize()
1058
+
1059
+ super().load_state_dict(state_dict)
1060
+
1061
+ if self.cpu_offload:
1062
+ # Re-create the offload pool since state tensors may be new
1063
+ # objects after load_state_dict.
1064
+ self._cpu_offload_pool = CPUOffloadPool()
1065
+ self._offload_initialized = False
1066
+ self._register_states_for_offload()
1067
+ self._offload_initialized = True
1068
+ self._cpu_offload_pool.offload()
build/torch210-cxx11-cu130-x86_64-linux/newton_schulz.py CHANGED
@@ -1,3 +1,7 @@
 
 
 
 
1
  import torch
2
 
3
  from .matmul_transpose_triton import matmul_transpose_assign
@@ -6,21 +10,134 @@ COMM_DTYPE = torch.bfloat16
6
  DEFAULT_CHUNK_SIZE_RATIO = 4
7
 
8
 
9
- # This code snippet is a modified version adapted from the following GitHub repositories:
10
- # https://github.com/KellerJordan/Muon/blob/master/muon.py
11
- # Muon's Newton–Schulz iteration causes high variance in singular values
12
- # Idea: give each iteration its own 3 coefficients and optimize them via gradient descent.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
  @torch.no_grad()
14
- # matmul_transpose_assign from : https://github.com/nil0x9/flash-muon
15
  def _zeropower_via_newtonschulz5(G, steps):
16
  """
17
- Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a
18
- quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose
19
- of minimizing steps, it turns out to be empirically effective to keep increasing the slope at
20
- zero even beyond the point where the iteration no longer converges all the way to one everywhere
21
- on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T
22
- where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model
23
- performance at all relative to UV^T, where USV^T = G is the SVD.
 
 
 
 
 
 
 
24
  """
25
  assert len(G.shape) == 2
26
  assert G.dtype == COMM_DTYPE
@@ -28,18 +145,14 @@ def _zeropower_via_newtonschulz5(G, steps):
28
 
29
  if G.size(0) > G.size(1):
30
  X = X.T
31
- # Ensure spectral norm is at most 1
32
  X = X / (X.norm() + 1e-7)
 
 
33
  buf1 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device)
34
  buf2 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device)
35
  # Perform the NS iterations
36
- for a, b, c in [
37
- (4.0848, -6.8946, 2.9270),
38
- (3.9505, -6.3029, 2.6377),
39
- (3.7418, -5.5913, 2.3037),
40
- (2.8769, -3.1427, 1.2046),
41
- (2.8366, -3.0525, 1.2012),
42
- ]:
43
  matmul_transpose_assign(X, buf1)
44
  matmul_transpose_assign(buf1, buf2)
45
  buf1.mul_(b).add_(buf2, alpha=c)
@@ -47,4 +160,77 @@ def _zeropower_via_newtonschulz5(G, steps):
47
 
48
  if G.size(0) > G.size(1):
49
  X = X.T
 
50
  return X
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from itertools import repeat
2
+ from math import inf, sqrt
3
+
4
+ import numpy as np
5
  import torch
6
 
7
  from .matmul_transpose_triton import matmul_transpose_assign
 
10
  DEFAULT_CHUNK_SIZE_RATIO = 4
11
 
12
 
13
+ def _optimal_quintic(l, u, max_iter=1000):
14
+ """
15
+ Use the simplified Remez algorithm to find the optimal odd quintic approximant
16
+ to the constant function x -> 1 over the interval [l, u].
17
+
18
+ Returns (a, b, c) for p(x) = ax + bx^3 + cx^5 that minimizes the maximum
19
+ approximation error max_{x in [l,u]} |p(x) - 1|. Iterates by updating the
20
+ two interior equioscillation nodes q, r until convergence. Returns the
21
+ closed-form equioscillating solution when l ≈ u.
22
+
23
+ Raises ValueError if any intermediate value (a, b, c, E, q, r) is non-finite
24
+ (NaN or inf). Raises RuntimeError if convergence is not reached within
25
+ max_iter iterations.
26
+ """
27
+ assert 0 <= l <= u
28
+ if 1 - 5e-6 <= l / u:
29
+ return (15 / 8) / u, (-10 / 8) / (u**3), (3 / 8) / (u**5)
30
+ q = (3 * l + u) / 4
31
+ r = (l + 3 * u) / 4
32
+ E = inf
33
+ for _ in range(max_iter):
34
+ old_E = E
35
+ LHS = np.array([
36
+ [l, l**3, l**5, 1],
37
+ [q, q**3, q**5, -1],
38
+ [r, r**3, r**5, 1],
39
+ [u, u**3, u**5, -1],
40
+ ])
41
+ a, b, c, E = np.linalg.solve(LHS, np.ones(4))
42
+ if not np.all(np.isfinite([a, b, c, E])):
43
+ raise ValueError(f"_optimal_quintic: non-finite solve result "
44
+ f"a={a}, b={b}, c={c}, E={E}")
45
+ q, r = np.sqrt(
46
+ (-3 * b + np.array([-1, 1]) * sqrt(9 * b**2 - 20 * a * c)) /
47
+ (10 * c))
48
+ if not np.all(np.isfinite([q, r])):
49
+ raise ValueError(
50
+ f"_optimal_quintic: non-finite node update q={q}, r={r}")
51
+ if abs(old_E - E) <= 1e-15:
52
+ break
53
+ else:
54
+ raise RuntimeError(
55
+ f"_optimal_quintic: did not converge after {max_iter} iterations")
56
+ return float(a), float(b), float(c)
57
+
58
+
59
def _optimal_composition(l, num_iters, safety_factor_eps=0, cushion=0):
    """
    Compute the Polar Express coefficient series for `num_iters` quintic iterations.

    Builds a sequence of per-step optimal odd quintic coefficients (a, b, c) that
    compose to map singular values from [l, 1] toward 1. At each step:
    1. Solves `_optimal_quintic` on [max(l, cushion*u), u]. The `cushion`
       prevents near-zero singular values from stalling by raising the effective
       lower bound; if it is active (cushion*u > l), the coefficients are
       rescaled so that p(l) and p(u) are centered around 1 w.r.t. the true [l, u].
    2. Deflates the coefficients by (1 + safety_factor_eps)^degree for all but the
       last iteration, providing numerical headroom at the cost of a slightly slower
       final convergence step.
    3. Advances the interval: l <- p(l), u <- 2 - p(l) (by symmetry of p around 1).

    Returns a list of (a, b, c) tuples, one per iteration.

    Reference: Amsel et al., "The Polar Express: Optimal Matrix Sign Methods and
    Their Application to the Muon Algorithm", https://arxiv.org/abs/2505.16932
    """
    u = 1
    assert 0 <= l <= u
    safety_factor = 1 + safety_factor_eps
    coefficients = []
    # Loop variable renamed from `iter` — shadowing the builtin made any
    # use of iter() inside this scope a latent bug.
    for step_idx in range(num_iters):
        a, b, c = _optimal_quintic(max(l, cushion * u), u)
        if cushion * u > l:
            # Cushion active: re-center so p(l) + p(u) == 2 on the true
            # interval [l, u].
            pl = a * l + b * l**3 + c * l**5
            pu = a * u + b * u**3 + c * u**5
            rescaler = 2 / (pl + pu)
            a *= rescaler
            b *= rescaler
            c *= rescaler
        if step_idx < num_iters - 1:
            # Deflate by safety_factor^degree for numerical headroom.
            a /= safety_factor
            b /= safety_factor**3
            c /= safety_factor**5
        coefficients.append((a, b, c))
        l = a * l + b * l**3 + c * l**5
        u = 2 - l
    return coefficients
100
+
101
+
102
+ # Precomputed Polar Express coefficients (a, b, c) for 10 quintic Newton-Schulz
103
+ # iterations. Each tuple is the minimax-optimal (Remez/equioscillation) odd quintic
104
+ # approximant to x->1 over the current singular-value interval, computed once at
105
+ # import time and reused across all optimizer steps.
106
+ #
107
+ # Contrast with the former hardcoded NS coefficients (5 fixed tuples):
108
+ # - Former: empirically tuned to maximize slope at zero; did not converge
109
+ # singular values to 1, yielding US'V^T with S' ~ Uniform(0.5, 1.5) instead
110
+ # of the true polar factor UV^T.
111
+ # - Polar Express: analytically optimal per step, adapting to the shrinking
112
+ # singular-value interval [l, u] as iterations progress; converges all
113
+ # singular values to 1, producing the exact polar factor UV^T.
114
+ _coeffs_list = _optimal_composition(l=1e-3,
115
+ num_iters=10,
116
+ safety_factor_eps=1e-2,
117
+ cushion=0.02)
118
+
119
+
120
+ # This code is adapted from:
121
+ # KellerJordan/Muon (https://github.com/KellerJordan/Muon/blob/master/muon.py)
122
+ # NoahAmsel/PolarExpress (https://github.com/NoahAmsel/PolarExpress)
123
+ # matmul_transpose_assign kernel from nil0x9/flash-muon (https://github.com/nil0x9/flash-muon)
124
  @torch.no_grad()
 
125
  def _zeropower_via_newtonschulz5(G, steps):
126
  """
127
+ Compute the polar factor of G via the Polar Express method.
128
+
129
+ Applies `steps` quintic iterations X <- aX + bX^3 + cX^5, where (a, b, c)
130
+ are the Polar Express coefficients from `_coeffs_list`. Each step is the
131
+ optimal odd quintic approximant to x -> 1 over the current singular-value
132
+ interval, minimizing the maximum approximation error (Remez / minimax criterion).
133
+ The composition maps singular values from [l, 1] to near 1, producing the
134
+ polar factor (orthogonal factor in the polar decomposition G = UP).
135
+
136
+ `_coeffs_list` is precomputed for 10 iterations (l=1e-3, safety_factor_eps=1e-2,
137
+ cushion=0.02). If `steps` exceeds 10, the final coefficient set is repeated.
138
+
139
+ Reference: Amsel et al., "The Polar Express: Optimal Matrix Sign Methods and
140
+ Their Application to the Muon Algorithm", https://arxiv.org/abs/2505.16932
141
  """
142
  assert len(G.shape) == 2
143
  assert G.dtype == COMM_DTYPE
 
145
 
146
  if G.size(0) > G.size(1):
147
  X = X.T
148
+
149
  X = X / (X.norm() + 1e-7)
150
+ hs = _coeffs_list[:steps] + list(
151
+ repeat(_coeffs_list[-1], steps - len(_coeffs_list)))
152
  buf1 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device)
153
  buf2 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device)
154
  # Perform the NS iterations
155
+ for a, b, c in hs:
 
 
 
 
 
 
156
  matmul_transpose_assign(X, buf1)
157
  matmul_transpose_assign(buf1, buf2)
158
  buf1.mul_(b).add_(buf2, alpha=c)
 
160
 
161
  if G.size(0) > G.size(1):
162
  X = X.T
163
+
164
  return X
165
+
166
+
167
@torch.no_grad()
def _zeropower_via_newtonschulz5_batched(G, steps):
    """Batched polar-factor computation for a stack of matrices (E, out, in).

    Mirrors ``_zeropower_via_newtonschulz5`` but replaces the 2D Triton
    kernel with ``torch.bmm`` / ``torch.baddbmm`` so all E expert matrices
    are processed by a single batched call per operation.
    """
    assert len(G.shape) == 3
    assert G.dtype == COMM_DTYPE

    # Work on the wide orientation; restore at the end.
    tall = G.size(1) > G.size(2)
    X = G.transpose(-2, -1) if tall else G

    # Normalize each expert matrix by its own Frobenius norm.
    X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7)

    # Coefficient schedule: precomputed Polar Express tuples, repeating the
    # final tuple when more steps are requested than were precomputed.
    last = len(_coeffs_list) - 1
    schedule = [_coeffs_list[min(i, last)] for i in range(steps)]
    for a, b, c in schedule:
        # gram = X X^T; gram_sq = (X X^T)^2 (gram is symmetric).
        gram = torch.bmm(X, X.transpose(-2, -1))
        gram_sq = torch.bmm(gram, gram.transpose(-2, -1))
        gram.mul_(b).add_(gram_sq, alpha=c)
        # X <- a*X + (b*gram + c*gram_sq) @ X  ==  aX + bX^3 + cX^5.
        X = torch.baddbmm(X, gram, X, alpha=1.0, beta=a)

    return X.transpose(-2, -1) if tall else X
197
+
198
+
199
+ _ns_per_shape: dict[tuple[int, ...], callable] = {}
200
+ _use_compile = True
201
+
202
+
203
def set_ns_compile(enabled: bool):
    """Globally enable or disable torch.compile wrapping of the NS iteration."""
    global _use_compile
    _use_compile = enabled
207
+
208
+
209
def zeropower_via_newtonschulz5(G, steps=5):
    """Compile-cached Newton-Schulz polar-factor computation for 2D tensors.

    Falls back to the eager ``_zeropower_via_newtonschulz5`` when
    compilation has been disabled via ``set_ns_compile(False)``; otherwise
    compiles the eager implementation once per input shape and reuses the
    cached compiled callable across steps.

    Args:
        G: 2D gradient tensor (expected in ``COMM_DTYPE``, enforced by the
           wrapped implementation).
        steps: Number of quintic NS iterations to apply.

    Returns:
        Tensor of the same shape as ``G`` holding its approximate polar factor.
    """
    if not _use_compile:
        return _zeropower_via_newtonschulz5(G, steps)
    # One compiled artifact per shape — presumably because cudagraph capture
    # is shape-specialized ("triton.cudagraphs": True below); verify against
    # torch.compile docs if this cache is ever keyed differently.
    key = G.shape
    if key not in _ns_per_shape:
        _ns_per_shape[key] = torch.compile(_zeropower_via_newtonschulz5,
                                           options={
                                               "triton.cudagraphs": True,
                                               "shape_padding": False
                                           })
    # Mark a cudagraph step boundary, and .clone() the result so the caller
    # holds memory that is not owned/reused by the captured graph.
    torch.compiler.cudagraph_mark_step_begin()
    return _ns_per_shape[key](G, steps).clone()
221
+
222
+
223
def zeropower_via_newtonschulz5_batched(G, steps=5):
    """Compile-cached batched Newton-Schulz for 3D expert tensors."""
    if not _use_compile:
        return _zeropower_via_newtonschulz5_batched(G, steps)
    shape_key = G.shape
    # Cache one compiled callable per input shape.
    compiled = _ns_per_shape.get(shape_key)
    if compiled is None:
        compiled = torch.compile(_zeropower_via_newtonschulz5_batched,
                                 options={
                                     "triton.cudagraphs": True,
                                     "shape_padding": False
                                 })
        _ns_per_shape[shape_key] = compiled
    torch.compiler.cudagraph_mark_step_begin()
    return compiled(G, steps).clone()
build/torch210-cxx11-cu130-x86_64-linux/pipeline.py CHANGED
@@ -6,8 +6,8 @@ import torch.distributed as dist
6
  from torch.distributed.tensor import DTensor
7
  from torch.profiler import record_function
8
 
9
- from .core import _muon_state, adjust_lr_for_muon, update_p
10
- from .newton_schulz import COMM_DTYPE, _zeropower_via_newtonschulz5
11
  from .qk_clip import compute_scales
12
 
13
  logger = logging.getLogger(__name__)
@@ -45,26 +45,33 @@ def _launch_gather(
45
  else:
46
  gathered_grads[id(p)] = None
47
 
48
- # Build send buffer
49
- per_dst: list[list[torch.Tensor]] = [[] for _ in range(num_ranks)]
50
  send_counts = [0] * num_ranks
51
-
52
  for p in params:
53
  state = param_to_state[id(p)]
54
- dst = state.worker_rank
55
- assert dst < num_ranks
56
- shard_elems = state.rank_numels[rank]
57
- g = p.grad
58
- g = g.to_local().to(COMM_DTYPE).contiguous()
59
- assert g.numel() == shard_elems
60
- per_dst[dst].append(g.view(-1))
61
- send_counts[dst] += shard_elems
62
-
63
- assert any(
64
- len(v) > 0 for v in
65
- per_dst), "At least one destination rank must receive a sharded tensor"
66
- per_dst_flat = [t for dst in per_dst for t in dst]
67
- send_buf = torch.cat(per_dst_flat, dim=0)
 
 
 
 
 
 
 
 
68
 
69
  # Build recv buffer
70
  recv_counts = [0] * num_ranks
@@ -120,7 +127,8 @@ def _complete_gather(
120
 
121
  shard_view = gathered_grads[id(p)][indices]
122
  n = shard_view.numel()
123
- assert n > 0
 
124
 
125
  sg = recv_buf.narrow(0, off + inner_off, n)
126
  sg = sg.reshape(shard_view.shape)
@@ -143,7 +151,7 @@ def _compute_ns(
143
  """
144
  computed_us: dict[int, torch.Tensor | None] = {}
145
  for p in owned_params:
146
- u = _zeropower_via_newtonschulz5(gathered_grads[id(p)], ns_steps)
147
  gathered_grads[id(p)] = None # free gathered grad
148
  computed_us[id(p)] = u
149
  return computed_us
@@ -163,46 +171,47 @@ def _launch_scatter(
163
  Returns:
164
  work: Async operation handle.
165
  recv_buf: Flat receive buffer (needed by ``_complete_scatter``).
166
- scattered_us: ``{id(p): empty_local_tensor}`` for all params.
 
167
  recv_counts: Per-source-rank element counts.
168
  """
169
- # Allocate scattered-u buffers
 
 
 
170
  scattered_us: dict[int, torch.Tensor] = {}
171
  for p in params:
172
- scattered_us[id(p)] = torch.empty_like(p.to_local(), dtype=COMM_DTYPE)
 
 
173
 
174
- # Build send buffer (from computed_us on owner ranks)
175
- per_dst: list[list[torch.Tensor]] = [[] for _ in range(num_ranks)]
176
  send_counts = [0] * num_ranks
177
-
178
  if owned_params:
179
  for p in owned_params:
180
  state = param_to_state[id(p)]
181
-
182
- assert computed_us[id(p)] is not None
183
- u_full = computed_us[id(p)].to(COMM_DTYPE).contiguous()
184
-
185
- total_sent = 0
186
  for dst_rank in range(num_ranks):
187
- indices = state.rank_indices[dst_rank]
188
- su = u_full[indices].flatten()
189
-
190
- n = su.numel()
191
- assert n > 0
192
 
193
- per_dst[dst_rank].append(su)
194
- send_counts[dst_rank] += n
195
- total_sent += n
196
-
197
- assert total_sent == u_full.numel()
198
-
199
- lengths = [len(v) for v in per_dst]
200
- if all(l > 0 for l in lengths):
201
- assert all(
202
- l == lengths[0] for l in lengths
203
- ), "All destination ranks must have the same number of sharded tensor"
204
- per_dst_flat = [t for dst in per_dst for t in dst]
205
- send_buf = torch.cat(per_dst_flat, dim=0)
 
 
 
 
 
206
  else:
207
  send_buf = torch.empty(0, dtype=COMM_DTYPE, device="cuda")
208
 
@@ -218,7 +227,6 @@ def _launch_scatter(
218
  recv_counts[src] = total
219
 
220
  recv_total = sum(recv_counts)
221
- assert recv_total > 0
222
  recv_buf = torch.empty(recv_total, dtype=COMM_DTYPE, device="cuda")
223
 
224
  # Launch async all-to-all
@@ -242,7 +250,13 @@ def _complete_scatter(
242
  rank: int,
243
  scattered_us: dict[int, torch.Tensor],
244
  ) -> None:
245
- """Copy recv buffer into scattered_us (in-place)."""
 
 
 
 
 
 
246
  off = 0
247
  for src in range(len(recv_counts)):
248
  block = recv_counts[src]
@@ -255,11 +269,11 @@ def _complete_scatter(
255
  if state.worker_rank != src:
256
  continue
257
  n = state.rank_numels[rank]
258
- assert n > 0
 
259
 
260
- flat_local = recv_buf.narrow(0, off + inner_off,
261
- n).view_as(p.to_local())
262
- scattered_us[id(p)].copy_(flat_local)
263
 
264
  inner_off += n
265
 
@@ -275,23 +289,40 @@ def _update_params(
275
  lr: float,
276
  weight_decay: float,
277
  ) -> None:
278
- """Apply weight decay, Muon update, and optional QK clipping."""
279
- for p in params:
280
- state = param_to_state[id(p)]
281
- u_dtensor = DTensor.from_local(
282
- scattered_us[id(p)],
283
- placements=p.placements,
284
- device_mesh=p.device_mesh,
285
- )
286
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
287
  adjusted_lr = adjust_lr_for_muon(lr, p.shape)
288
- update_p(p, u_dtensor, lr, adjusted_lr, weight_decay)
 
 
 
289
 
290
- # QK clipping applied directly on the local tensor to
291
- # avoid DTensor sharding-propagation issues with _StridedShard.
292
- scales_full = compute_scales(
293
- p,
294
- state.qk_clip_state) if state.qk_clip_state is not None else None
 
 
 
 
 
295
  if scales_full is not None:
296
  ratio = p.shape[0] // scales_full.shape[0]
297
  idx0 = state.rank_indices[rank][0]
@@ -304,6 +335,45 @@ def _update_params(
304
  p._local_tensor.mul_(row_scales.view(-1, 1))
305
 
306
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
307
  # ======================================================================
308
  # Main generator – thin orchestrator that wires stages together.
309
  # ======================================================================
@@ -318,6 +388,7 @@ def muon_chunk_pipeline(
318
  lr: float,
319
  weight_decay: float,
320
  none_grad: bool,
 
321
  ) -> Generator[None, None, None]:
322
  """Process one chunk of parameters through the full Muon pipeline.
323
 
@@ -334,9 +405,12 @@ def muon_chunk_pipeline(
334
  runs concurrently on the NCCL stream — no separate ``comm_stream``
335
  is required.
336
 
 
 
 
337
  Yields exactly **2** times:
338
 
339
- 1. After launching async all-to-all gather.
340
  2. After launching async all-to-all scatter.
341
  """
342
  process_group = param_to_state[id(params[0])].process_group
@@ -345,15 +419,19 @@ def muon_chunk_pipeline(
345
  p for p in params if param_to_state[id(p)].worker_rank == rank
346
  ]
347
 
348
- # Stages 1-2: launch async gather.
349
- with record_function("muon::launch_gather"):
350
- work, recv_buf, gathered_grads, recv_counts = _launch_gather(
351
- params, owned_params, param_to_state, rank, num_ranks,
352
- process_group)
353
-
354
- if none_grad:
355
- for p in params:
356
- p.grad = None
 
 
 
 
357
 
358
  yield # --- YIELD 1: other chunks can launch their gather ---
359
 
 
6
  from torch.distributed.tensor import DTensor
7
  from torch.profiler import record_function
8
 
9
+ from .core import _muon_state, adjust_lr_for_muon
10
+ from .newton_schulz import COMM_DTYPE, zeropower_via_newtonschulz5
11
  from .qk_clip import compute_scales
12
 
13
  logger = logging.getLogger(__name__)
 
45
  else:
46
  gathered_grads[id(p)] = None
47
 
48
+ # Build send buffer – batch grad copies via torch.cat
49
+ # (1-2 fused kernels vs N individual narrow().copy_() calls).
50
  send_counts = [0] * num_ranks
 
51
  for p in params:
52
  state = param_to_state[id(p)]
53
+ send_counts[state.worker_rank] += state.rank_numels[rank]
54
+
55
+ total_send = sum(send_counts)
56
+ if total_send > 0:
57
+ # Group grad slices by destination rank in a single pass.
58
+ dst_to_grads = [[] for _ in range(num_ranks)]
59
+ for p in params:
60
+ state = param_to_state[id(p)]
61
+ n = state.rank_numels[rank]
62
+ if n > 0:
63
+ g = p.grad.to_local()
64
+ dst_to_grads[state.worker_rank].append(g.reshape(-1))
65
+
66
+ # Flatten in dst order and cat once.
67
+ all_slices = []
68
+ for dst in range(num_ranks):
69
+ all_slices.extend(dst_to_grads[dst])
70
+ send_buf = torch.cat(all_slices)
71
+ if send_buf.dtype != COMM_DTYPE:
72
+ send_buf = send_buf.to(COMM_DTYPE)
73
+ else:
74
+ send_buf = torch.empty(0, dtype=COMM_DTYPE, device="cuda")
75
 
76
  # Build recv buffer
77
  recv_counts = [0] * num_ranks
 
127
 
128
  shard_view = gathered_grads[id(p)][indices]
129
  n = shard_view.numel()
130
+ if n == 0:
131
+ continue
132
 
133
  sg = recv_buf.narrow(0, off + inner_off, n)
134
  sg = sg.reshape(shard_view.shape)
 
151
  """
152
  computed_us: dict[int, torch.Tensor | None] = {}
153
  for p in owned_params:
154
+ u = zeropower_via_newtonschulz5(gathered_grads[id(p)], ns_steps)
155
  gathered_grads[id(p)] = None # free gathered grad
156
  computed_us[id(p)] = u
157
  return computed_us
 
171
  Returns:
172
  work: Async operation handle.
173
  recv_buf: Flat receive buffer (needed by ``_complete_scatter``).
174
+ scattered_us: Empty dict, populated by ``_complete_scatter`` with
175
+ zero-copy views into ``recv_buf``.
176
  recv_counts: Per-source-rank element counts.
177
  """
178
+ # scattered_us is populated by _complete_scatter with zero-copy views
179
+ # into recv_buf, avoiding N empty_like allocations + N copy_ calls.
180
+ # Pre-seed entries for params whose local shard is empty (rank_numels == 0)
181
+ # so _update_params can iterate all params without KeyError.
182
  scattered_us: dict[int, torch.Tensor] = {}
183
  for p in params:
184
+ if param_to_state[id(p)].rank_numels[rank] == 0:
185
+ scattered_us[id(p)] = torch.empty_like(p.to_local(),
186
+ dtype=COMM_DTYPE)
187
 
188
+ # Build send buffer batch via torch.cat
189
+ # (1 fused kernel vs N*num_ranks individual narrow().copy_() calls).
190
  send_counts = [0] * num_ranks
 
191
  if owned_params:
192
  for p in owned_params:
193
  state = param_to_state[id(p)]
 
 
 
 
 
194
  for dst_rank in range(num_ranks):
195
+ send_counts[dst_rank] += state.rank_numels[dst_rank]
 
 
 
 
196
 
197
+ total_send = sum(send_counts)
198
+ if total_send > 0:
199
+ # Cache u_full conversions to avoid redundant .to() per dst_rank.
200
+ u_fulls = {}
201
+ for p in owned_params:
202
+ u_fulls[id(p)] = computed_us[id(p)].to(COMM_DTYPE).contiguous()
203
+
204
+ # Collect slices in dst order (matches all-to-all send layout).
205
+ all_slices = []
206
+ for dst_rank in range(num_ranks):
207
+ for p in owned_params:
208
+ state = param_to_state[id(p)]
209
+ su = u_fulls[id(p)][state.rank_indices[dst_rank]].flatten()
210
+ if su.numel() > 0:
211
+ all_slices.append(su)
212
+
213
+ send_buf = torch.cat(all_slices) if all_slices else torch.empty(
214
+ 0, dtype=COMM_DTYPE, device="cuda")
215
  else:
216
  send_buf = torch.empty(0, dtype=COMM_DTYPE, device="cuda")
217
 
 
227
  recv_counts[src] = total
228
 
229
  recv_total = sum(recv_counts)
 
230
  recv_buf = torch.empty(recv_total, dtype=COMM_DTYPE, device="cuda")
231
 
232
  # Launch async all-to-all
 
250
  rank: int,
251
  scattered_us: dict[int, torch.Tensor],
252
  ) -> None:
253
+ """Populate scattered_us with zero-copy views into recv_buf.
254
+
255
+ Instead of pre-allocating tensors and copying, we assign views directly
256
+ from ``recv_buf``. This eliminates N ``empty_like`` + N ``copy_`` calls.
257
+ The underlying storage of ``recv_buf`` is kept alive through the views
258
+ until ``scattered_us`` is cleared after ``_update_params``.
259
+ """
260
  off = 0
261
  for src in range(len(recv_counts)):
262
  block = recv_counts[src]
 
269
  if state.worker_rank != src:
270
  continue
271
  n = state.rank_numels[rank]
272
+ if n == 0:
273
+ continue
274
 
275
+ scattered_us[id(p)] = recv_buf.narrow(0, off + inner_off,
276
+ n).view_as(p.to_local())
 
277
 
278
  inner_off += n
279
 
 
289
  lr: float,
290
  weight_decay: float,
291
  ) -> None:
292
+ """Apply weight decay, Muon update, and optional QK clipping.
 
 
 
 
 
 
 
293
 
294
+ Uses batched ``_foreach_mul_`` for weight decay and batched
295
+ ``_foreach_add_`` for the Muon update, grouping parameters by
296
+ adjusted_lr to minimize kernel launches while preserving float32
297
+ precision for the alpha scaling.
298
+ """
299
+ if not params:
300
+ return
301
+
302
+ # Batched weight decay: p *= (1 - lr * wd) — single fused kernel.
303
+ p_locals = [p._local_tensor for p in params]
304
+ torch._foreach_mul_(p_locals, 1.0 - lr * weight_decay)
305
+
306
+ # Group params by adjusted_lr so _foreach_add_ can use a single
307
+ # alpha per group (preserves float32 precision for alpha scaling).
308
+ lr_groups: dict[float, tuple[list, list]] = {}
309
+ for p in params:
310
  adjusted_lr = adjust_lr_for_muon(lr, p.shape)
311
+ if adjusted_lr not in lr_groups:
312
+ lr_groups[adjusted_lr] = ([], [])
313
+ lr_groups[adjusted_lr][0].append(p._local_tensor)
314
+ lr_groups[adjusted_lr][1].append(scattered_us[id(p)])
315
 
316
+ for adjusted_lr, (p_group, u_group) in lr_groups.items():
317
+ torch._foreach_add_(p_group, u_group, alpha=-adjusted_lr)
318
+
319
+ # QK clipping – applied directly on the local tensor to
320
+ # avoid DTensor sharding-propagation issues with _StridedShard.
321
+ for p in params:
322
+ state = param_to_state[id(p)]
323
+ if state.qk_clip_state is None:
324
+ continue
325
+ scales_full = compute_scales(p, state.qk_clip_state)
326
  if scales_full is not None:
327
  ratio = p.shape[0] // scales_full.shape[0]
328
  idx0 = state.rank_indices[rank][0]
 
335
  p._local_tensor.mul_(row_scales.view(-1, 1))
336
 
337
 
338
+ # ======================================================================
339
+ # Pre-launch helper for overlapping first chunk's gather with other work.
340
+ # ======================================================================
341
+
342
+
343
@torch.no_grad()
def prelaunch_first_gather(
    params: list[DTensor],
    param_to_state: dict[int, _muon_state],
    rank: int,
    none_grad: bool,
) -> tuple[dist.Work, torch.Tensor, dict[int, torch.Tensor | None], list[int]]:
    """Launch the first chunk's A2A gather early for overlap with other compute.

    Call this *before* expensive GPU work (e.g. batched expert NS) so that
    the NCCL all-to-all runs concurrently on the NCCL stream while the
    default stream executes compute.

    Returns the same 4-tuple that ``_launch_gather`` produces, which should
    be passed as ``prelaunch_gather`` to :func:`muon_chunk_pipeline`.
    """
    pg = param_to_state[id(params[0])].process_group
    world_size = dist.get_world_size(group=pg)

    # Params whose NS computation this rank owns.
    owners = []
    for p in params:
        if param_to_state[id(p)].worker_rank == rank:
            owners.append(p)

    with record_function("muon::prelaunch_gather"):
        handle, recv_buf, gathered_grads, recv_counts = _launch_gather(
            params, owners, param_to_state, rank, world_size, pg)

    if none_grad:
        for p in params:
            p.grad = None

    return handle, recv_buf, gathered_grads, recv_counts
375
+
376
+
377
  # ======================================================================
378
  # Main generator – thin orchestrator that wires stages together.
379
  # ======================================================================
 
388
  lr: float,
389
  weight_decay: float,
390
  none_grad: bool,
391
+ prelaunch_gather: tuple | None = None,
392
  ) -> Generator[None, None, None]:
393
  """Process one chunk of parameters through the full Muon pipeline.
394
 
 
405
  runs concurrently on the NCCL stream — no separate ``comm_stream``
406
  is required.
407
 
408
+ If ``prelaunch_gather`` is provided, the gather was already launched
409
+ by :func:`prelaunch_first_gather` and we skip launching it again.
410
+
411
  Yields exactly **2** times:
412
 
413
+ 1. After launching async all-to-all gather (or immediately if pre-launched).
414
  2. After launching async all-to-all scatter.
415
  """
416
  process_group = param_to_state[id(params[0])].process_group
 
419
  p for p in params if param_to_state[id(p)].worker_rank == rank
420
  ]
421
 
422
+ if prelaunch_gather is not None:
423
+ # Gather was pre-launched; none_grad already handled by caller.
424
+ work, recv_buf, gathered_grads, recv_counts = prelaunch_gather
425
+ else:
426
+ # Normal path: launch async gather.
427
+ with record_function("muon::launch_gather"):
428
+ work, recv_buf, gathered_grads, recv_counts = _launch_gather(
429
+ params, owned_params, param_to_state, rank, num_ranks,
430
+ process_group)
431
+
432
+ if none_grad:
433
+ for p in params:
434
+ p.grad = None
435
 
436
  yield # --- YIELD 1: other chunks can launch their gather ---
437
 
build/torch210-cxx11-cu130-x86_64-linux/qk_clip.py CHANGED
@@ -5,6 +5,8 @@ from dataclasses import dataclass
5
  import torch
6
  from torch.distributed.tensor import DTensor
7
 
 
 
8
  logger = logging.getLogger(__name__)
9
 
10
 
@@ -23,7 +25,7 @@ def parse_qk_layer(name: str) -> tuple[str | None, int]:
23
  'model.7.attn.k_proj.weight' -> ('k_proj', 7)
24
  'model.4.attn.v_proj.weight' -> (None, -1)
25
  """
26
- parts = name.split('.')
27
  if len(parts) < 3:
28
  return None, -1
29
 
@@ -100,23 +102,27 @@ def compute_scales(p, qk_clip_state):
100
  threshold = qk_clip_state.threshold
101
  logit = qk_clip_state.logit
102
 
103
- H_global = p.shape[0] // head_dim
104
- scales_full = torch.ones(H_global, device=p.data.device)
105
- scaling = 0
106
-
107
  for logit_idx, head_idx in enumerate(indices):
108
  v_ele = float(logit[logit_idx])
109
  if v_ele > threshold:
110
  new_scale = math.sqrt(threshold / v_ele)
111
- if new_scale < scales_full[head_idx]:
112
- scales_full[head_idx] = new_scale
113
  logger.info(
114
  f"[{kind}] Head {head_idx} exceeded threshold "
115
  f"(value={v_ele:.4f}, threshold={threshold:.4f}) -> applying scale={new_scale:.4f}"
116
  )
117
- scaling += 1
118
 
119
- return scales_full if scaling > 0 else None
 
 
 
 
 
 
 
120
 
121
 
122
  def qk_clip(p, scales, head_dim):
 
5
  import torch
6
  from torch.distributed.tensor import DTensor
7
 
8
+ from .core import normalize_fqn
9
+
10
  logger = logging.getLogger(__name__)
11
 
12
 
 
25
  'model.7.attn.k_proj.weight' -> ('k_proj', 7)
26
  'model.4.attn.v_proj.weight' -> (None, -1)
27
  """
28
+ parts = normalize_fqn(name).split('.')
29
  if len(parts) < 3:
30
  return None, -1
31
 
 
102
  threshold = qk_clip_state.threshold
103
  logit = qk_clip_state.logit
104
 
105
+ # Check if any head exceeds threshold before allocating.
106
+ head_scales = {}
 
 
107
  for logit_idx, head_idx in enumerate(indices):
108
  v_ele = float(logit[logit_idx])
109
  if v_ele > threshold:
110
  new_scale = math.sqrt(threshold / v_ele)
111
+ if head_idx not in head_scales or new_scale < head_scales[head_idx]:
112
+ head_scales[head_idx] = new_scale
113
  logger.info(
114
  f"[{kind}] Head {head_idx} exceeded threshold "
115
  f"(value={v_ele:.4f}, threshold={threshold:.4f}) -> applying scale={new_scale:.4f}"
116
  )
 
117
 
118
+ if not head_scales:
119
+ return None
120
+
121
+ H_global = p.shape[0] // head_dim
122
+ scales_full = torch.ones(H_global, device=p.data.device)
123
+ for head_idx, scale in head_scales.items():
124
+ scales_full[head_idx] = scale
125
+ return scales_full
126
 
127
 
128
  def qk_clip(p, scales, head_dim):
build/torch210-cxx11-rocm70-x86_64-linux/_ops.py CHANGED
@@ -1,9 +1,9 @@
1
  import torch
2
- from . import _optimizer_7aef62f_dirty
3
- ops = torch.ops._optimizer_7aef62f_dirty
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
- return f"_optimizer_7aef62f_dirty::{op_name}"
 
1
  import torch
2
+ from . import _optimizer_5b58933_dirty
3
+ ops = torch.ops._optimizer_5b58933_dirty
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
+ return f"_optimizer_5b58933_dirty::{op_name}"
build/torch210-cxx11-rocm70-x86_64-linux/{_optimizer_7aef62f_dirty.abi3.so → _optimizer_5b58933_dirty.abi3.so} RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:00e9d9e1c2306badb97c3b8f2454a47d6335a302101a38c804ad3c7b075168cc
3
  size 1866400
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0102e10121a43f6d5d59a23f2c0e21d88469cc4597d84f7d48b64b0fabfeacdb
3
  size 1866400
build/torch210-cxx11-rocm70-x86_64-linux/adamw.py CHANGED
@@ -1,8 +1,12 @@
 
1
  from collections import defaultdict
2
  from typing import cast
3
 
4
  import torch
5
  from torch.distributed.tensor import DTensor
 
 
 
6
 
7
 
8
  def fused_adamw(
@@ -72,54 +76,72 @@ def fused_adamw(
72
  )
73
 
74
 
75
- def step_adamw_params(optimizer_state, params, group):
76
- """Run fused AdamW on a list of parameters sharing the same placement.
 
77
 
78
- Args:
79
- optimizer_state: The optimizer's state dict (self.state in Muon).
80
- params: List of parameters to update.
81
- group: Parameter group dict with lr, adamw_betas, adamw_eps, weight_decay.
82
- """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
83
  params_with_grads = []
84
  grads = []
85
  moment1 = []
86
  moment2 = []
87
- max_exp_avg_sqs = []
88
  state_steps = []
89
- lr = group["lr"]
90
- beta1, beta2 = group["adamw_betas"]
91
- eps = group["adamw_eps"]
92
- weight_decay = group["weight_decay"]
93
 
94
  for p in params:
95
  g = p.grad
96
  if g is None:
97
  continue
98
  state = optimizer_state[p]
99
- params_with_grads.append(p)
100
- grads.append(g)
101
  if "step" not in state:
102
- state["step"] = (torch.zeros((),
103
- dtype=torch.float32,
104
- device=p.device))
105
  state["moment1"] = torch.zeros_like(g)
106
  state["moment2"] = torch.zeros_like(g)
107
- moment1.append(state["moment1"])
108
- moment2.append(state["moment2"])
109
  if not isinstance(state["step"], torch.Tensor):
110
- step_tensor = torch.tensor(state["step"],
111
- dtype=torch.float32,
112
- device=p.device)
113
- else:
114
- step_tensor = state["step"]
115
- state_steps.append(step_tensor)
 
 
 
 
 
 
116
 
117
  fused_adamw(
118
  params_with_grads,
119
  grads,
120
  moment1,
121
  moment2,
122
- max_exp_avg_sqs,
123
  state_steps,
124
  amsgrad=False,
125
  beta1=beta1,
@@ -131,24 +153,119 @@ def step_adamw_params(optimizer_state, params, group):
131
  )
132
 
133
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
134
  def step_adamw(optimizer_state, group):
135
  """Dispatch AdamW step, grouping parameters by type and placement.
136
 
 
 
 
137
  Args:
138
  optimizer_state: The optimizer's state dict (self.state in Muon).
139
  group: Parameter group dict.
140
  """
141
  params = group["params"]
 
142
 
143
- # group params with its type and placement
144
- placement_to_params: dict[tuple, list[torch.Tensor]] = defaultdict(list)
145
- for p in params:
146
- match p:
147
- case DTensor():
148
- placement_to_params[tuple([p.placements,
149
- p.device_mesh])].append(p)
150
- case torch.Tensor():
151
- placement_to_params[tuple([torch.Tensor, None])].append(p)
152
-
153
- for group_params in placement_to_params.values():
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
154
  step_adamw_params(optimizer_state, group_params, group)
 
1
+ import logging
2
  from collections import defaultdict
3
  from typing import cast
4
 
5
  import torch
6
  from torch.distributed.tensor import DTensor
7
+ from torch.profiler import record_function
8
+
9
+ logger = logging.getLogger(__name__)
10
 
11
 
12
  def fused_adamw(
 
76
  )
77
 
78
 
79
+ def _to_local(t):
80
+ """Unwrap DTensor to local tensor for fused ops."""
81
+ return t._local_tensor if isinstance(t, DTensor) else t
82
 
83
+
84
+ # ---------------------------------------------------------------------------
85
+ # Caches for eliminating per-step Python overhead.
86
+ #
87
+ # Placement grouping and tensor list assembly are identical every step
88
+ # (params don't change placement, moment/step tensors are the same objects
89
+ # after initialisation). We cache them keyed by id() of the param list
90
+ # stored in param_groups (stable across steps).
91
+ #
92
+ # Only gradients change each step and must be collected fresh.
93
+ # ---------------------------------------------------------------------------
94
+
95
+ # id(group["params"]) → dict[placement_key, list[param]]
96
+ _placement_cache: dict[int, dict[tuple, list]] = {}
97
+
98
+ # id(placement_group_list) → (params_local, moment1, moment2, state_steps)
99
+ _tensor_cache: dict[int, tuple[list, list, list, list]] = {}
100
+
101
+
102
+ def _step_adamw_params_slow(optimizer_state, params, group):
103
+ """Uncached fallback for the rare case where some params lack grads."""
104
  params_with_grads = []
105
  grads = []
106
  moment1 = []
107
  moment2 = []
 
108
  state_steps = []
 
 
 
 
109
 
110
  for p in params:
111
  g = p.grad
112
  if g is None:
113
  continue
114
  state = optimizer_state[p]
115
+ params_with_grads.append(_to_local(p))
116
+ grads.append(_to_local(g))
117
  if "step" not in state:
118
+ state["step"] = torch.zeros((),
119
+ dtype=torch.float32,
120
+ device=p.device)
121
  state["moment1"] = torch.zeros_like(g)
122
  state["moment2"] = torch.zeros_like(g)
123
+ moment1.append(_to_local(state["moment1"]))
124
+ moment2.append(_to_local(state["moment2"]))
125
  if not isinstance(state["step"], torch.Tensor):
126
+ state["step"] = torch.tensor(state["step"],
127
+ dtype=torch.float32,
128
+ device=p.device)
129
+ state_steps.append(state["step"])
130
+
131
+ if not params_with_grads:
132
+ return
133
+
134
+ lr = group["lr"]
135
+ beta1, beta2 = group["adamw_betas"]
136
+ eps = group["adamw_eps"]
137
+ weight_decay = group["weight_decay"]
138
 
139
  fused_adamw(
140
  params_with_grads,
141
  grads,
142
  moment1,
143
  moment2,
144
+ [],
145
  state_steps,
146
  amsgrad=False,
147
  beta1=beta1,
 
153
  )
154
 
155
 
156
+ def step_adamw_params(optimizer_state, params, group):
157
+ """Run fused AdamW on a list of parameters sharing the same placement.
158
+
159
+ After the first call, cached tensor lists (params_local, moment1,
160
+ moment2, state_steps) are reused — only gradients are collected fresh.
161
+
162
+ Args:
163
+ optimizer_state: The optimizer's state dict (self.state in Muon).
164
+ params: List of parameters to update.
165
+ group: Parameter group dict with lr, adamw_betas, adamw_eps, weight_decay.
166
+ """
167
+ # Collect grads — the only thing that changes each step.
168
+ with record_function("adamw::collect_grads"):
169
+ grads = []
170
+ for p in params:
171
+ g = p.grad
172
+ if g is None:
173
+ # Rare: fall back to slow path that filters per-param.
174
+ _step_adamw_params_slow(optimizer_state, params, group)
175
+ return
176
+ grads.append(_to_local(g))
177
+
178
+ tensor_key = id(params)
179
+ if tensor_key not in _tensor_cache:
180
+ with record_function("adamw::init_tensor_cache"):
181
+ params_local = []
182
+ moment1 = []
183
+ moment2 = []
184
+ state_steps = []
185
+
186
+ for p in params:
187
+ state = optimizer_state[p]
188
+ params_local.append(_to_local(p))
189
+ if "step" not in state:
190
+ state["step"] = torch.zeros((),
191
+ dtype=torch.float32,
192
+ device=p.device)
193
+ state["moment1"] = torch.zeros_like(p.grad)
194
+ state["moment2"] = torch.zeros_like(p.grad)
195
+ moment1.append(_to_local(state["moment1"]))
196
+ moment2.append(_to_local(state["moment2"]))
197
+ if not isinstance(state["step"], torch.Tensor):
198
+ state["step"] = torch.tensor(state["step"],
199
+ dtype=torch.float32,
200
+ device=p.device)
201
+ state_steps.append(state["step"])
202
+
203
+ _tensor_cache[tensor_key] = (params_local, moment1, moment2,
204
+ state_steps)
205
+
206
+ params_local, moment1, moment2, state_steps = _tensor_cache[tensor_key]
207
+
208
+ lr = group["lr"]
209
+ beta1, beta2 = group["adamw_betas"]
210
+ eps = group["adamw_eps"]
211
+ weight_decay = group["weight_decay"]
212
+
213
+ with record_function("adamw::fused_adamw"):
214
+ fused_adamw(
215
+ params_local,
216
+ grads,
217
+ moment1,
218
+ moment2,
219
+ [],
220
+ state_steps,
221
+ amsgrad=False,
222
+ beta1=beta1,
223
+ beta2=beta2,
224
+ lr=lr,
225
+ weight_decay=weight_decay,
226
+ eps=eps,
227
+ maximize=False,
228
+ )
229
+
230
+
231
  def step_adamw(optimizer_state, group):
232
  """Dispatch AdamW step, grouping parameters by type and placement.
233
 
234
+ Placement grouping is cached after the first call since params never
235
+ change their placement between steps.
236
+
237
  Args:
238
  optimizer_state: The optimizer's state dict (self.state in Muon).
239
  group: Parameter group dict.
240
  """
241
  params = group["params"]
242
+ placement_key = id(params)
243
 
244
+ if placement_key not in _placement_cache:
245
+ with record_function("adamw::group_by_placement"):
246
+ placement_to_params: dict[tuple,
247
+ list[torch.Tensor]] = defaultdict(list)
248
+ for p in params:
249
+ match p:
250
+ case DTensor():
251
+ logger.debug(
252
+ "[AdamW] DTensor param: shape=%s, placements=%s, "
253
+ "mesh=%s, grad=%s", p.shape, p.placements,
254
+ p.device_mesh.mesh_dim_names,
255
+ p.grad.shape if p.grad is not None else None)
256
+ placement_to_params[tuple(
257
+ [p.placements, p.device_mesh])].append(p)
258
+ case torch.Tensor():
259
+ logger.debug(
260
+ "[AdamW] plain param: shape=%s, grad=%s", p.shape,
261
+ p.grad.shape if p.grad is not None else None)
262
+ placement_to_params[tuple([torch.Tensor,
263
+ None])].append(p)
264
+
265
+ logger.debug("[AdamW] %d placement groups, %d total params",
266
+ len(placement_to_params), len(params))
267
+
268
+ _placement_cache[placement_key] = dict(placement_to_params)
269
+
270
+ for group_params in _placement_cache[placement_key].values():
271
  step_adamw_params(optimizer_state, group_params, group)
build/torch210-cxx11-rocm70-x86_64-linux/core.py CHANGED
@@ -1,11 +1,25 @@
 
1
  import math
2
  from dataclasses import dataclass
 
3
 
4
  import torch
5
- import torch.distributed as dist
6
  from torch.distributed import ProcessGroup
7
  from torch.distributed.tensor import DTensor
8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
 
10
  @dataclass
11
  class _muon_state:
@@ -17,26 +31,71 @@ class _muon_state:
17
  qk_clip_state: torch.Tensor | None = None
18
 
19
 
20
- def update_g(optimizer_state, p, g, group, momentum):
21
- """Apply momentum update to gradient.
 
 
 
 
 
 
22
 
23
- Args:
24
- optimizer_state: The optimizer's state dict (self.state in Muon).
25
- p: Parameter tensor.
26
- g: Gradient tensor.
27
- group: Parameter group dict.
28
- momentum: Momentum coefficient.
29
 
30
- Returns:
31
- Momentum-updated gradient tensor.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
  """
33
- state = optimizer_state[p]
34
- buf = state.setdefault("momentum_buffer", torch.zeros_like(g))
35
- torch.add(g, buf, alpha=momentum, out=buf)
36
- if group["nesterov"]:
37
- g.add_(buf, alpha=momentum)
38
- return g
39
- return buf
 
 
 
 
 
 
 
 
 
 
 
40
 
41
 
42
  def update_p(p, u, lr, adjusted_lr, weight_decay):
@@ -49,14 +108,13 @@ def update_p(p, u, lr, adjusted_lr, weight_decay):
49
  adjusted_lr: Size-adjusted learning rate.
50
  weight_decay: Weight decay coefficient.
51
  """
52
- if isinstance(p, torch.nn.Parameter):
53
- # apply weight decay
54
- p.data.mul_(1 - lr * weight_decay)
55
- # apply update
56
- p.data.add_(u, alpha=-adjusted_lr)
57
- else:
58
- p.mul_(1 - lr * weight_decay)
59
- p.add_(u, alpha=-adjusted_lr)
60
 
61
 
62
  def adjust_lr_for_muon(lr, param_shape):
@@ -77,14 +135,55 @@ def adjust_lr_for_muon(lr, param_shape):
77
  return adjusted_lr
78
 
79
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
80
  def default_is_muon(name, x, expert_keys=None):
81
- skip_keys = ["embed_tokens", "lm_head", "tok_embeddings", "output"]
82
- if any(key in name for key in skip_keys):
 
 
 
 
 
 
 
 
 
 
 
 
 
83
  return False
84
  effective_ndim = x.ndim
85
- if expert_keys and any(key in name for key in expert_keys):
 
86
  effective_ndim -= 1
87
- return effective_ndim >= 2
 
 
 
 
 
88
 
89
 
90
  def get_default_muon_param_groups(model, is_muon_func=None, expert_keys=None):
@@ -92,7 +191,7 @@ def get_default_muon_param_groups(model, is_muon_func=None, expert_keys=None):
92
  is_muon_func = lambda n, x: default_is_muon(n, x, expert_keys)
93
 
94
  muon_params, muon_names = [], []
95
- non_muon_params = []
96
 
97
  for n, p in model.named_parameters():
98
  if not p.requires_grad:
@@ -102,6 +201,10 @@ def get_default_muon_param_groups(model, is_muon_func=None, expert_keys=None):
102
  muon_names.append(n)
103
  else:
104
  non_muon_params.append(p)
 
 
 
 
105
 
106
  return [
107
  {
 
1
+ import logging
2
  import math
3
  from dataclasses import dataclass
4
+ from typing import List
5
 
6
  import torch
 
7
  from torch.distributed import ProcessGroup
8
  from torch.distributed.tensor import DTensor
9
 
10
+ # torch.compile wraps modules as OptimizedModule, inserting "_orig_mod" into
11
+ # parameter FQNs. Activation checkpointing similarly inserts
12
+ # "_checkpoint_wrapped_module". Strip these so name-based matching (skip_keys,
13
+ # expert_keys, QK layer parsing) works regardless of wrapper nesting.
14
+ _WRAPPER_PARTS = frozenset({"_orig_mod", "_checkpoint_wrapped_module"})
15
+
16
+ logger = logging.getLogger(__name__)
17
+
18
+
19
+ def normalize_fqn(name: str) -> str:
20
+ """Strip torch.compile / checkpoint wrapper components from a parameter FQN."""
21
+ return ".".join(p for p in name.split(".") if p not in _WRAPPER_PARTS)
22
+
23
 
24
  @dataclass
25
  class _muon_state:
 
31
  qk_clip_state: torch.Tensor | None = None
32
 
33
 
34
+ def _batch_momentum(
35
+ grads: List[torch.Tensor],
36
+ momentum_bufs: List[torch.Tensor],
37
+ momentum: torch.Tensor,
38
+ ) -> None:
39
+ """Batched momentum update (no nesterov)."""
40
+ torch._foreach_mul_(momentum_bufs, momentum)
41
+ torch._foreach_add_(momentum_bufs, grads)
42
 
 
 
 
 
 
 
43
 
44
+ def _batch_momentum_nesterov(
45
+ grads: List[torch.Tensor],
46
+ momentum_bufs: List[torch.Tensor],
47
+ momentum: torch.Tensor,
48
+ ) -> None:
49
+ """Batched momentum update with nesterov correction."""
50
+ torch._foreach_mul_(momentum_bufs, momentum)
51
+ torch._foreach_add_(momentum_bufs, grads)
52
+ nesterov_terms = torch._foreach_mul(momentum_bufs, momentum)
53
+ torch._foreach_add_(grads, nesterov_terms)
54
+
55
+
56
+ _compiled_momentum: dict[bool, callable] = {}
57
+ _use_momentum_compile = True
58
+
59
+
60
+ def set_momentum_compile(enabled: bool):
61
+ """Toggle torch.compile for batched momentum."""
62
+ global _use_momentum_compile
63
+ _use_momentum_compile = enabled
64
+
65
+
66
+ def batch_pre_ortho(
67
+ grads: List[torch.Tensor],
68
+ momentum_bufs: List[torch.Tensor],
69
+ momentum: torch.Tensor,
70
+ nesterov: bool,
71
+ ) -> None:
72
+ """Batched momentum update on lists of plain tensors.
73
+
74
+ Mirrors dion's ``muon_update_pre_orthogonalize``.
75
+ Inputs must be plain CUDA tensors (not DTensor).
76
+ Modifies ``momentum_bufs`` and (for nesterov) ``grads`` in-place.
77
+
78
+ When compile is enabled, uses separately compiled functions for
79
+ nesterov=True/False to avoid graph breaks from the branch.
80
  """
81
+ fn = _batch_momentum_nesterov if nesterov else _batch_momentum
82
+ if _use_momentum_compile:
83
+ if nesterov not in _compiled_momentum:
84
+ _compiled_momentum[nesterov] = torch.compile(fn)
85
+ fn = _compiled_momentum[nesterov]
86
+ fn(grads, momentum_bufs, momentum)
87
+
88
+
89
+ def _update_p_impl(p_data, u_data, lr, adjusted_lr, weight_decay):
90
+ """Weight-decay + update on plain tensors.
91
+
92
+ Not compiled: per-param @torch.compile caused ~0.25ms TorchDynamo cache
93
+ lookup per call × 256+ params = massive overhead. The pipeline path uses
94
+ batched _foreach_* ops instead; this function remains for base() and
95
+ distributed_muon().
96
+ """
97
+ p_data.mul_(1 - lr * weight_decay)
98
+ p_data.add_(u_data, alpha=-adjusted_lr)
99
 
100
 
101
  def update_p(p, u, lr, adjusted_lr, weight_decay):
 
108
  adjusted_lr: Size-adjusted learning rate.
109
  weight_decay: Weight decay coefficient.
110
  """
111
+ # Unwrap Parameter -> underlying data tensor.
112
+ p_data = p.data if isinstance(p, torch.nn.Parameter) else p
113
+ # Unwrap DTensor -> local CUDA tensor for compiled kernel.
114
+ if isinstance(p_data, DTensor):
115
+ p_data = p_data._local_tensor
116
+ u_data = u._local_tensor if isinstance(u, DTensor) else u
117
+ _update_p_impl(p_data, u_data, lr, adjusted_lr, weight_decay)
 
118
 
119
 
120
  def adjust_lr_for_muon(lr, param_shape):
 
135
  return adjusted_lr
136
 
137
 
138
+ def _match_key(parts, key):
139
+ """Check if key matches as contiguous components in parts.
140
+
141
+ Single-component keys (e.g. "experts") match any single component.
142
+ Multi-component keys (e.g. "experts.w1") match as a contiguous subsequence.
143
+ """
144
+ key_parts = key.split(".")
145
+ key_len = len(key_parts)
146
+ if key_len == 1:
147
+ return key in parts
148
+ return any(parts[i:i + key_len] == key_parts
149
+ for i in range(len(parts) - key_len + 1))
150
+
151
+
152
+ def is_expert_param(name, expert_keys):
153
+ """Check if a parameter name matches any expert key (component-level)."""
154
+ if not expert_keys:
155
+ return False
156
+ parts = normalize_fqn(name).split(".")
157
+ return any(_match_key(parts, key) for key in expert_keys)
158
+
159
+
160
  def default_is_muon(name, x, expert_keys=None):
161
+ normalized = normalize_fqn(name)
162
+ parts = normalized.split(".")
163
+ skip_keys = [
164
+ "embed_tokens",
165
+ "lm_head",
166
+ "tok_embeddings",
167
+ "output",
168
+ "mhc_attn",
169
+ "mhc_ffn",
170
+ "lambda_proj",
171
+ ]
172
+ if any(key in parts for key in skip_keys):
173
+ logger.info(
174
+ "[is_muon] %s (orig: %s): skip (matched skip_key), ndim=%d",
175
+ normalized, name, x.ndim)
176
  return False
177
  effective_ndim = x.ndim
178
+ is_expert = is_expert_param(name, expert_keys)
179
+ if is_expert:
180
  effective_ndim -= 1
181
+ result = effective_ndim >= 2
182
+ logger.info(
183
+ "[is_muon] %s (orig: %s): ndim=%d, expert=%s, effective_ndim=%d → %s",
184
+ normalized, name, x.ndim, is_expert, effective_ndim,
185
+ "Muon" if result else "AdamW")
186
+ return result
187
 
188
 
189
  def get_default_muon_param_groups(model, is_muon_func=None, expert_keys=None):
 
191
  is_muon_func = lambda n, x: default_is_muon(n, x, expert_keys)
192
 
193
  muon_params, muon_names = [], []
194
+ non_muon_params, non_muon_names = [], []
195
 
196
  for n, p in model.named_parameters():
197
  if not p.requires_grad:
 
201
  muon_names.append(n)
202
  else:
203
  non_muon_params.append(p)
204
+ non_muon_names.append(n)
205
+
206
+ logger.info("[param_groups] expert_keys=%s, Muon=%d, AdamW=%d",
207
+ expert_keys, len(muon_names), len(non_muon_names))
208
 
209
  return [
210
  {
build/torch210-cxx11-rocm70-x86_64-linux/cpu_offload.py ADDED
@@ -0,0 +1,188 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """CPU offloading for optimizer states.
2
+
3
+ Manages a pinned CPU memory pool and async CUDA streams to offload
4
+ optimizer state tensors (momentum buffers, Adam moments) to CPU between
5
+ optimizer steps, freeing GPU memory.
6
+
7
+ All tracked tensors are packed into a single flat pinned CPU buffer
8
+ (per dtype). D2H and H2D copies are performed per-tensor directly
9
+ between individual GPU tensors and their slice of the CPU flat buffer
10
+ — no GPU staging buffer is allocated, so there is **no temporary GPU
11
+ memory spike** during offload or reload.
12
+
13
+ Individual tensor storages are freed after offload via
14
+ ``untyped_storage().resize_(0)``, preserving tensor identity so
15
+ downstream caches remain valid.
16
+ """
17
+
18
+ import logging
19
+ from collections import defaultdict
20
+
21
+ import torch
22
+ from torch.distributed.tensor import DTensor
23
+
24
+ logger = logging.getLogger(__name__)
25
+
26
+
27
+ class CPUOffloadPool:
28
+ """Pinned CPU memory pool for async optimizer state offloading.
29
+
30
+ Tracked tensors are grouped by dtype. Each group gets a single flat
31
+ pinned CPU buffer. D2H / H2D copies are per-tensor (into slices of
32
+ the flat buffer) to avoid allocating a GPU staging buffer.
33
+ """
34
+
35
+ def __init__(self):
36
+ self._managed: list[torch.Tensor] = []
37
+ self._storage_nbytes: dict[int, int] = {} # id(t) → bytes
38
+
39
+ # Per-dtype group: populated on first offload.
40
+ # dtype → dict with keys:
41
+ # "indices" : list[int] managed-list indices
42
+ # "offsets" : list[tuple[int,int]] (start, numel) in flat buf
43
+ # "total" : int total numel
44
+ # "cpu_flat" : Tensor pinned CPU buffer
45
+ self._groups: dict[torch.dtype, dict] = {}
46
+
47
+ self._offload_stream: torch.cuda.Stream | None = None
48
+ self._device: torch.device | None = None
49
+ self._initialized: bool = False
50
+ self._logged: bool = False
51
+
52
+ # ------------------------------------------------------------------
53
+ @staticmethod
54
+ def _local(t: torch.Tensor) -> torch.Tensor:
55
+ """Unwrap DTensor to its local CUDA tensor."""
56
+ return t._local_tensor if isinstance(t, DTensor) else t
57
+
58
+ def _ensure_stream(self):
59
+ if self._offload_stream is None:
60
+ self._offload_stream = torch.cuda.Stream(device=self._device)
61
+
62
+ # ------------------------------------------------------------------
63
+ def track(self, tensor: torch.Tensor):
64
+ """Register a GPU tensor for CPU offloading. Idempotent."""
65
+ tid = id(tensor)
66
+ if tid in self._storage_nbytes:
67
+ return
68
+ local = self._local(tensor)
69
+ if self._device is None:
70
+ self._device = local.device
71
+ self._storage_nbytes[tid] = local.untyped_storage().size()
72
+ self._managed.append(tensor)
73
+
74
+ # ------------------------------------------------------------------
75
+ def _init_buffers(self):
76
+ """Build per-dtype flat buffers on first offload."""
77
+ # Group managed tensors by dtype.
78
+ dtype_map: dict[torch.dtype, list[tuple[int, int]]] = defaultdict(list)
79
+ for idx, t in enumerate(self._managed):
80
+ local = self._local(t)
81
+ dtype_map[local.dtype].append((idx, local.numel()))
82
+
83
+ total_cpu_bytes = 0
84
+ for dtype, entries in dtype_map.items():
85
+ offsets: list[tuple[int, int]] = []
86
+ indices: list[int] = []
87
+ off = 0
88
+ for idx, n in entries:
89
+ indices.append(idx)
90
+ offsets.append((off, n))
91
+ off += n
92
+ cpu_flat = torch.empty(off, dtype=dtype, device="cpu", pin_memory=True)
93
+ self._groups[dtype] = {
94
+ "indices": indices,
95
+ "offsets": offsets,
96
+ "total": off,
97
+ "cpu_flat": cpu_flat,
98
+ }
99
+ total_cpu_bytes += off * cpu_flat.element_size()
100
+
101
+ self._initialized = True
102
+ logger.info(
103
+ "[CPUOffload] Pool initialized: %d tensors, %d dtype group(s), "
104
+ "%.2f MB pinned CPU memory",
105
+ len(self._managed),
106
+ len(self._groups),
107
+ total_cpu_bytes / (1024**2),
108
+ )
109
+
110
+ # ------------------------------------------------------------------
111
+ def offload(self):
112
+ """Per-tensor async D2H into CPU flat buffer, then free GPU storage."""
113
+ if not self._managed:
114
+ return
115
+ if not self._initialized:
116
+ self._init_buffers()
117
+ self._ensure_stream()
118
+
119
+ # Offload stream waits for compute to finish.
120
+ compute_event = torch.cuda.current_stream(
121
+ self._device).record_event()
122
+ self._offload_stream.wait_event(compute_event)
123
+
124
+ offloaded_bytes = 0
125
+
126
+ # Per-tensor D2H copies directly into CPU flat buffer slices.
127
+ # No GPU staging buffer → no temporary GPU memory spike.
128
+ with torch.cuda.stream(self._offload_stream):
129
+ for dtype, grp in self._groups.items():
130
+ indices = grp["indices"]
131
+ offsets = grp["offsets"]
132
+ cpu_flat = grp["cpu_flat"]
133
+
134
+ for i, mgd_idx in enumerate(indices):
135
+ local = self._local(self._managed[mgd_idx])
136
+ off, n = offsets[i]
137
+ cpu_flat[off:off + n].copy_(
138
+ local.reshape(-1), non_blocking=True)
139
+
140
+ offloaded_bytes += grp["total"] * cpu_flat.element_size()
141
+
142
+ # Wait for all D2H copies to land, then free GPU storage.
143
+ self._offload_stream.synchronize()
144
+ for t in self._managed:
145
+ self._local(t).untyped_storage().resize_(0)
146
+
147
+ if not self._logged:
148
+ logger.info("[CPUOffload] Offloaded %.2f MB (GPU → CPU)",
149
+ offloaded_bytes / (1024**2))
150
+
151
+ # ------------------------------------------------------------------
152
+ def reload(self):
153
+ """Per-tensor H2D from CPU flat buffer on the default stream.
154
+
155
+ Runs on the current (default) CUDA stream to avoid stream
156
+ interaction issues with the parallel Muon pipeline. Since
157
+ pinned CPU memory is the source, the copies overlap with
158
+ GPU idle time between steps.
159
+ """
160
+ if not self._managed or not self._initialized:
161
+ return
162
+
163
+ reloaded_bytes = 0
164
+
165
+ # Re-allocate all GPU storages first.
166
+ for t in self._managed:
167
+ local = self._local(t)
168
+ local.untyped_storage().resize_(self._storage_nbytes[id(t)])
169
+
170
+ # Per-tensor H2D copies from CPU flat buffer slices.
171
+ # non_blocking=True with pinned source allows DMA overlap.
172
+ for dtype, grp in self._groups.items():
173
+ indices = grp["indices"]
174
+ offsets = grp["offsets"]
175
+ cpu_flat = grp["cpu_flat"]
176
+
177
+ for i, mgd_idx in enumerate(indices):
178
+ local = self._local(self._managed[mgd_idx])
179
+ off, n = offsets[i]
180
+ local.reshape(-1).copy_(
181
+ cpu_flat[off:off + n], non_blocking=True)
182
+
183
+ reloaded_bytes += grp["total"] * cpu_flat.element_size()
184
+
185
+ if not self._logged:
186
+ logger.info("[CPUOffload] Reloaded %.2f MB (CPU → GPU)",
187
+ reloaded_bytes / (1024**2))
188
+ self._logged = True
build/torch210-cxx11-rocm70-x86_64-linux/distributed/utils.py CHANGED
@@ -72,12 +72,6 @@ def get_slices_of_dtensor(
72
  else:
73
  curr_size = target.size()[shard_dim]
74
 
75
- if curr_size % num_chunks != 0:
76
- raise NotImplementedError(
77
- f"Dimension size {curr_size} is not divisible "
78
- f"by number of ranks {num_chunks} for shard "
79
- f"placement on dim {shard_dim}. (shape: {target.shape})")
80
-
81
  # Compute indices for this level of sharding
82
  if isinstance(placement, _StridedShard):
83
  _shard_size, offsets = _StridedShard.local_shard_size_and_offset(
 
72
  else:
73
  curr_size = target.size()[shard_dim]
74
 
 
 
 
 
 
 
75
  # Compute indices for this level of sharding
76
  if isinstance(placement, _StridedShard):
77
  _shard_size, offsets = _StridedShard.local_shard_size_and_offset(
build/torch210-cxx11-rocm70-x86_64-linux/matmul_transpose_triton.py CHANGED
@@ -43,6 +43,7 @@ def get_autotune_config():
43
  @triton.autotune(
44
  configs=get_autotune_config(),
45
  key=['M', 'K'],
 
46
  )
47
  @triton.jit
48
  def mmt_kernel(x, y, M, K, stride_xm, stride_xk, stride_ym, stride_yn,
@@ -102,16 +103,10 @@ def mmt_kernel(x, y, M, K, stride_xm, stride_xk, stride_ym, stride_yn,
102
  tl.store(ct_ptrs, tl.permute(c, (1, 0)), mask=ct_mask)
103
 
104
 
105
- def matmul_transpose_assign(d_in, d_out):
106
- assert d_in.is_cuda, "Input `d_in` must be a CUDA tensor"
107
- assert d_out.is_cuda, "Input `d_out` must be a CUDA tensor"
108
- assert d_in.device == d_out.device, "Inputs `d_in` and `d_out` must be on the same CUDA device"
109
- assert d_in.dtype == d_out.dtype, "Inputs must have the same data type"
110
- assert d_in.ndim == 2, "Input `d_in` must be a 2D tensor"
111
- assert d_out.ndim == 2, "Input `d_out` must be a 2D tensor"
112
- assert d_in.size(0) == d_out.size(0) == d_out.size(0), \
113
- "First dimension of `d_in` must match first and second dimension of `d_out`"
114
-
115
  d_in = d_in.contiguous()
116
  M, K = d_in.shape
117
  grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(
@@ -119,3 +114,9 @@ def matmul_transpose_assign(d_in, d_out):
119
  with torch.cuda.device(d_in.device.index):
120
  mmt_kernel[grid](d_in, d_out, M, K, d_in.stride(0), d_in.stride(1),
121
  d_out.stride(0), d_out.stride(1))
 
 
 
 
 
 
 
43
  @triton.autotune(
44
  configs=get_autotune_config(),
45
  key=['M', 'K'],
46
+ restore_value=['y'],
47
  )
48
  @triton.jit
49
  def mmt_kernel(x, y, M, K, stride_xm, stride_xk, stride_ym, stride_yn,
 
103
  tl.store(ct_ptrs, tl.permute(c, (1, 0)), mask=ct_mask)
104
 
105
 
106
+ @torch.library.custom_op("muon::matmul_transpose_assign",
107
+ mutates_args=("d_out", ))
108
+ def matmul_transpose_assign(d_in: torch.Tensor, d_out: torch.Tensor) -> None:
109
+ """Compute d_out = d_in @ d_in.T using an optimized Triton kernel."""
 
 
 
 
 
 
110
  d_in = d_in.contiguous()
111
  M, K = d_in.shape
112
  grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(
 
114
  with torch.cuda.device(d_in.device.index):
115
  mmt_kernel[grid](d_in, d_out, M, K, d_in.stride(0), d_in.stride(1),
116
  d_out.stride(0), d_out.stride(1))
117
+
118
+
119
+ @matmul_transpose_assign.register_fake
120
+ def _(d_in: torch.Tensor, d_out: torch.Tensor) -> None:
121
+ """FakeTensor impl: d_out is already allocated, mutation is declared."""
122
+ pass
build/torch210-cxx11-rocm70-x86_64-linux/muon.py CHANGED
@@ -10,13 +10,16 @@ from torch.profiler import record_function
10
 
11
  from .adamw import step_adamw
12
  from .async_utils import run_pipeline
13
- from .core import (_muon_state, adjust_lr_for_muon,
14
- get_default_muon_param_groups, update_g, update_p)
 
15
  from .distributed.utils import (_is_shard, construct_shard_mesh,
16
  get_slices_of_dtensor)
17
  from .newton_schulz import (COMM_DTYPE, DEFAULT_CHUNK_SIZE_RATIO,
18
- _zeropower_via_newtonschulz5)
19
- from .pipeline import muon_chunk_pipeline
 
 
20
  from .qk_clip import compute_scales, get_qk_clip_info, qk_clip
21
 
22
  logger = logging.getLogger(__name__)
@@ -45,9 +48,21 @@ def _expand_expert_params(names, params, expert_keys):
45
  expanded_params = []
46
 
47
  for n, p in zip(names, params):
48
- is_expert = expert_keys and any(key in n for key in expert_keys)
49
  is_dtensor = isinstance(p.data, DTensor)
50
 
 
 
 
 
 
 
 
 
 
 
 
 
51
  if not is_expert:
52
  assert p.data.ndim <= 2, (
53
  f"Param {n} has ndim={p.data.ndim} but does not match "
@@ -168,7 +183,6 @@ class Muon(torch.optim.Optimizer):
168
  Use shard ranks * DEFAULT_CHUNK_SIZE_RATIO when -1 is specified.
169
  use_distributed_muon: Use distributed muon by Liu et al. (2024).
170
  For testing purpose only.
171
- small_param_numel_threshold: Threshold for classifying parameters as small and falling back to distributed Muon
172
  expert_keys: List of strings to identify expert-parallel parameters.
173
  If any key appears in a parameter's name, its outermost
174
  dimension is treated as the expert dimension and expanded
@@ -193,8 +207,8 @@ class Muon(torch.optim.Optimizer):
193
  warmup_step=5,
194
  chunk_size=-1,
195
  use_distributed_muon=False,
196
- small_param_numel_threshold=65536,
197
- expert_keys=None):
198
  defaults = dict(
199
  lr=lr,
200
  weight_decay=weight_decay,
@@ -228,8 +242,12 @@ class Muon(torch.optim.Optimizer):
228
  self.warmup_step = warmup_step
229
  self.chunk_size = chunk_size
230
  self.use_distributed_muon = use_distributed_muon
231
- self.small_param_numel_threshold = small_param_numel_threshold
232
  self.expert_keys = expert_keys
 
 
 
 
 
233
 
234
  def _calc_flops(self, G, steps):
235
  assert len(G.shape) == 2
@@ -333,8 +351,8 @@ class Muon(torch.optim.Optimizer):
333
  if g is None:
334
  continue
335
 
336
- u = _zeropower_via_newtonschulz5(g.to(COMM_DTYPE),
337
- steps=group["ns_steps"])
338
 
339
  adjusted_lr = adjust_lr_for_muon(lr, p.shape)
340
  update_p(p, u, lr, adjusted_lr, weight_decay)
@@ -355,52 +373,269 @@ class Muon(torch.optim.Optimizer):
355
  weight_decay: float,
356
  qk_logits: list[torch.Tensor | DTensor] | None,
357
  ):
358
- """ Implementation of Distributed Muon by Liu et al. """
359
 
360
- # Momentum is already applied by _step_muon before this method.
361
- for n, p in zip(names, params):
362
- g = p.grad
363
- if g is None:
364
- continue
365
-
366
- # Gather G
367
- if isinstance(p.data, DTensor):
368
- g_full = g.full_tensor()
369
- p_full = p.data.full_tensor()
370
- else:
371
- g_full = g
372
- p_full = p
373
-
374
- u_full = _zeropower_via_newtonschulz5(g_full.to(COMM_DTYPE),
375
- steps=group["ns_steps"])
376
-
377
- adjusted_lr = adjust_lr_for_muon(lr, p_full.shape)
378
- update_p(p_full, u_full, lr, adjusted_lr, weight_decay)
379
 
380
- qk_clip_state = get_qk_clip_info(self.clip_config, n, qk_logits)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
381
 
382
- scales_full = compute_scales(
383
- p_full, qk_clip_state) if qk_clip_state is not None else None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
384
 
385
- if scales_full is not None:
386
- qk_clip(p_full, scales_full, qk_clip_state.head_dim)
 
 
387
 
388
- if isinstance(p.data, DTensor):
389
- ndims = len(p.device_mesh.mesh.shape)
390
- p_replicate = DTensor.from_local(
391
- p_full,
392
- device_mesh=p.device_mesh,
393
- placements=[Replicate() for _ in range(ndims)],
394
- )
395
 
396
- p_sharded = p_replicate.redistribute(
397
- device_mesh=p.device_mesh,
398
- placements=p.placements,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
399
  )
400
 
401
- p.copy_(p_sharded)
402
 
403
- def parallel(self, names, params, group, lr, weight_decay, qk_logits):
 
 
 
 
 
 
 
404
  """
405
  Perform a parallel optimization step using Muon.
406
 
@@ -409,31 +644,23 @@ class Muon(torch.optim.Optimizer):
409
  interleaves multiple chunks so that communication and computation
410
  overlap across chunks (the same overlap previously achieved by the
411
  warmup + main-loop index scheduling).
 
 
 
 
412
  """
413
 
414
  # Momentum is already applied by _step_muon before this method.
415
 
416
- param_to_state, ordered_params = self.init_state_and_assign_params(
417
- names, params, group, qk_logits)
418
-
419
- # Compute local rank for this group's shard process group.
420
- shard_pg = param_to_state[id(ordered_params[0])].process_group
421
- rank = dist.get_rank(group=shard_pg)
422
-
423
- if self.chunk_size == -1:
424
- shard_ranks = dist.get_world_size(param_to_state[id(
425
- ordered_params[0])].process_group)
426
- chunk_size = shard_ranks * DEFAULT_CHUNK_SIZE_RATIO
427
- elif self.chunk_size > 0:
428
- chunk_size = self.chunk_size
429
- else:
430
- raise ValueError("chunk_size must be -1 or a positive integer.")
431
 
432
  def pipelines():
 
433
  for start in range(0, len(ordered_params), chunk_size):
434
  chunk = ordered_params[start:start + chunk_size]
435
  if chunk:
436
- yield muon_chunk_pipeline(
437
  params=chunk,
438
  param_to_state=param_to_state,
439
  rank=rank,
@@ -442,9 +669,11 @@ class Muon(torch.optim.Optimizer):
442
  weight_decay=weight_decay,
443
  none_grad=group["none_grad"],
444
  )
 
 
 
 
445
 
446
- with record_function("muon::barrier"):
447
- dist.barrier()
448
  with record_function("muon::pipeline"):
449
  run_pipeline(pipelines(), max_concurrent=self.warmup_step + 1)
450
 
@@ -456,16 +685,152 @@ class Muon(torch.optim.Optimizer):
456
  names = group["names"]
457
 
458
  # Apply momentum to all params before routing/expansion.
 
459
  with record_function("muon::momentum"):
460
- for n, p in zip(names, params):
461
- g = p.grad
462
- if g is None:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
463
  continue
464
- g = update_g(self.state, p, g, group, momentum)
465
- p.grad = g
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
466
 
467
  # Expand expert params by splitting on dim 0.
468
- names, params = _expand_expert_params(names, params, self.expert_keys)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
469
 
470
  param_dtensors = []
471
  name_dtensors = []
@@ -473,10 +838,10 @@ class Muon(torch.optim.Optimizer):
473
  param_tensors = []
474
  name_tensors = []
475
 
476
- param_dtensors_small = []
477
- name_dtensors_small = []
478
-
479
  if self.use_distributed_muon:
 
480
  self.distributed_muon(names=names,
481
  params=params,
482
  group=group,
@@ -485,8 +850,6 @@ class Muon(torch.optim.Optimizer):
485
  qk_logits=qk_logits)
486
  return
487
 
488
- # For simplicity, we use distributed Muon for small parameters
489
- # whose number of elements is below a threshold.
490
  for n, p in zip(names, params):
491
  if p is None or p.grad is None:
492
  continue
@@ -494,23 +857,28 @@ class Muon(torch.optim.Optimizer):
494
  if all(
495
  isinstance(placement, Replicate)
496
  for placement in p.placements):
 
 
 
497
  param_tensors.append(p)
498
  name_tensors.append(n)
499
- elif p.data.numel() <= self.small_param_numel_threshold:
500
- param_dtensors_small.append(p)
501
- name_dtensors_small.append(n)
502
  else:
 
 
 
 
503
  param_dtensors.append(p)
504
  name_dtensors.append(n)
505
  elif isinstance(p.data, torch.Tensor):
 
 
506
  param_tensors.append(p)
507
  name_tensors.append(n)
508
  else:
509
  raise TypeError(f"Unsupported parameter type: {type(p.data)}")
510
 
511
- logger.debug(
512
- f"[Muon] {len(param_dtensors)} DTensors, {len(param_tensors)} Tensors, "
513
- f"{len(param_dtensors_small)} Small DTensors")
514
 
515
  def group_dtensors(dtensors, names):
516
  # To support different placements, we group parameters by placements
@@ -526,21 +894,6 @@ class Muon(torch.optim.Optimizer):
526
  p.device_mesh])][1].append(p)
527
  return placement_to_params
528
 
529
- if len(param_dtensors_small) > 0:
530
- if not dist.is_initialized():
531
- raise RuntimeError(
532
- "Parallel Muon requires torch.distributed to be initialized."
533
- )
534
-
535
- self.distributed_muon(
536
- params=param_dtensors_small,
537
- names=name_dtensors_small,
538
- group=group,
539
- lr=lr,
540
- weight_decay=weight_decay,
541
- qk_logits=qk_logits,
542
- )
543
-
544
  if len(param_dtensors) > 0:
545
  if not dist.is_initialized():
546
  raise RuntimeError(
@@ -548,7 +901,26 @@ class Muon(torch.optim.Optimizer):
548
  )
549
 
550
  dtensor_group = group_dtensors(param_dtensors, name_dtensors)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
551
  for _, (names, params) in dtensor_group.items():
 
 
552
  self.parallel(
553
  names,
554
  params,
@@ -556,7 +928,10 @@ class Muon(torch.optim.Optimizer):
556
  lr=lr,
557
  weight_decay=weight_decay,
558
  qk_logits=qk_logits,
 
559
  )
 
 
560
 
561
  if len(param_tensors) > 0:
562
  self.base(
@@ -568,6 +943,33 @@ class Muon(torch.optim.Optimizer):
568
  qk_logits=qk_logits,
569
  )
570
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
571
  @torch.no_grad
572
  def step(self, closure=None, qk_logits=None):
573
  """Perform a single optimization step.
@@ -585,10 +987,82 @@ class Muon(torch.optim.Optimizer):
585
  with torch.enable_grad():
586
  loss = closure()
587
 
588
- for group in self.param_groups:
 
 
 
 
 
 
 
589
  if group["use_muon"]:
 
 
590
  self._step_muon(group, qk_logits=qk_logits)
591
  else:
 
 
 
592
  step_adamw(self.state, group)
593
 
 
 
 
 
 
 
 
594
  return loss
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
 
11
  from .adamw import step_adamw
12
  from .async_utils import run_pipeline
13
+ from .core import (_muon_state, adjust_lr_for_muon, batch_pre_ortho,
14
+ get_default_muon_param_groups, is_expert_param, update_p)
15
+ from .cpu_offload import CPUOffloadPool
16
  from .distributed.utils import (_is_shard, construct_shard_mesh,
17
  get_slices_of_dtensor)
18
  from .newton_schulz import (COMM_DTYPE, DEFAULT_CHUNK_SIZE_RATIO,
19
+ _zeropower_via_newtonschulz5,
20
+ zeropower_via_newtonschulz5,
21
+ zeropower_via_newtonschulz5_batched)
22
+ from .pipeline import muon_chunk_pipeline, prelaunch_first_gather
23
  from .qk_clip import compute_scales, get_qk_clip_info, qk_clip
24
 
25
  logger = logging.getLogger(__name__)
 
48
  expanded_params = []
49
 
50
  for n, p in zip(names, params):
51
+ is_expert = is_expert_param(n, expert_keys)
52
  is_dtensor = isinstance(p.data, DTensor)
53
 
54
+ if is_expert:
55
+ if is_dtensor:
56
+ logger.debug(
57
+ "[expand_expert] %s: expert DTensor, shape=%s, "
58
+ "placements=%s, mesh=%s, local_shape=%s", n, p.shape,
59
+ p.placements, p.device_mesh.mesh_dim_names,
60
+ p.to_local().shape)
61
+ else:
62
+ logger.debug(
63
+ "[expand_expert] %s: expert plain tensor, shape=%s", n,
64
+ p.data.shape)
65
+
66
  if not is_expert:
67
  assert p.data.ndim <= 2, (
68
  f"Param {n} has ndim={p.data.ndim} but does not match "
 
183
  Use shard ranks * DEFAULT_CHUNK_SIZE_RATIO when -1 is specified.
184
  use_distributed_muon: Use distributed muon by Liu et al. (2024).
185
  For testing purpose only.
 
186
  expert_keys: List of strings to identify expert-parallel parameters.
187
  If any key appears in a parameter's name, its outermost
188
  dimension is treated as the expert dimension and expanded
 
207
  warmup_step=5,
208
  chunk_size=-1,
209
  use_distributed_muon=False,
210
+ expert_keys=None,
211
+ cpu_offload=False):
212
  defaults = dict(
213
  lr=lr,
214
  weight_decay=weight_decay,
 
242
  self.warmup_step = warmup_step
243
  self.chunk_size = chunk_size
244
  self.use_distributed_muon = use_distributed_muon
 
245
  self.expert_keys = expert_keys
246
+ self.cpu_offload = cpu_offload
247
+ self._cpu_offload_pool = CPUOffloadPool() if cpu_offload else None
248
+ self._offload_initialized = False
249
+ self._parallel_cache: dict[tuple[str, ...], dict] = {}
250
+ self._expert_expand_cache: dict[tuple[int, ...], dict] = {}
251
 
252
  def _calc_flops(self, G, steps):
253
  assert len(G.shape) == 2
 
351
  if g is None:
352
  continue
353
 
354
+ u = zeropower_via_newtonschulz5(g.to(COMM_DTYPE),
355
+ steps=group["ns_steps"])
356
 
357
  adjusted_lr = adjust_lr_for_muon(lr, p.shape)
358
  update_p(p, u, lr, adjusted_lr, weight_decay)
 
373
  weight_decay: float,
374
  qk_logits: list[torch.Tensor | DTensor] | None,
375
  ):
376
+ """Batched Distributed Muon for testing/correctness verification only.
377
 
378
+ Uses all-gather to reconstruct full tensors, computes Newton-Schulz on
379
+ the full grad, then slices back to local shards. This is simpler but
380
+ slower than the parallel pipeline (all2all) path, so it serves as a
381
+ reference implementation for verifying correctness.
382
+ """
383
+ with record_function("distributed_muon"):
384
+ # Momentum is already applied by _step_muon before this method.
385
+ ns_steps = group["ns_steps"]
 
 
 
 
 
 
 
 
 
 
 
386
 
387
+ # Separate plain tensors (no communication) from DTensors.
388
+ plain_names, plain_params = [], []
389
+ dtensor_names, dtensor_params = [], []
390
+ for n, p in zip(names, params):
391
+ if p.grad is None:
392
+ continue
393
+ if isinstance(p.data, DTensor):
394
+ dtensor_names.append(n)
395
+ dtensor_params.append(p)
396
+ else:
397
+ plain_names.append(n)
398
+ plain_params.append(p)
399
+
400
+ # Process plain tensors per-param (no communication).
401
+ for n, p in zip(plain_names, plain_params):
402
+ u = _zeropower_via_newtonschulz5(p.grad.to(COMM_DTYPE),
403
+ steps=ns_steps)
404
+ adjusted_lr = adjust_lr_for_muon(lr, p.shape)
405
+ update_p(p, u, lr, adjusted_lr, weight_decay)
406
+
407
+ qk_clip_state = get_qk_clip_info(self.clip_config, n,
408
+ qk_logits)
409
+ scales_full = compute_scales(
410
+ p, qk_clip_state) if qk_clip_state is not None else None
411
+ if scales_full is not None:
412
+ qk_clip(p, scales_full, qk_clip_state.head_dim)
413
+
414
+ if not dtensor_params:
415
+ return
416
+
417
+ # Group DTensors by (placements, mesh) for batched all-gather.
418
+ placement_groups: dict[tuple,
419
+ tuple[list,
420
+ list]] = defaultdict(lambda: ([], []))
421
+ for n, p in zip(dtensor_names, dtensor_params):
422
+ key = (p.placements, p.device_mesh)
423
+ placement_groups[key][0].append(n)
424
+ placement_groups[key][1].append(p)
425
+
426
+ logger.info(
427
+ "distributed_muon: %d placement groups, %d total dtensors",
428
+ len(placement_groups), len(dtensor_params))
429
+
430
+ for (placements, mesh), (grp_names,
431
+ grp_params) in placement_groups.items():
432
+ shard_mesh, shard_pg, shard_placements = construct_shard_mesh(
433
+ placements, mesh)
434
+ rank = dist.get_rank(shard_pg)
435
+ world_size = dist.get_world_size(shard_pg)
436
+
437
+ logger.info(" group: %d params, placements=%s, world_size=%d",
438
+ len(grp_params), placements, world_size)
439
+
440
+ # Separate params that can be batched (all shard dims evenly
441
+ # divisible) from those needing per-param full_tensor
442
+ # (e.g. MoE gate weights with fewer rows than shard ranks).
443
+ # all_gather_into_tensor requires equal buffer sizes across
444
+ # ranks, so uneven splits must use DTensor full_tensor().
445
+ batch_names, batch_params = [], []
446
+ single_names, single_params = [], []
447
+ for n, p in zip(grp_names, grp_params):
448
+ even = all(p.shape[pl.dim] %
449
+ shard_mesh.mesh.shape[dim_idx] == 0
450
+ for dim_idx, pl in enumerate(shard_placements))
451
+ if even:
452
+ batch_names.append(n)
453
+ batch_params.append(p)
454
+ else:
455
+ single_names.append(n)
456
+ single_params.append(p)
457
+
458
+ # Process uneven-split params per-param via full_tensor().
459
+ for n, p in zip(single_names, single_params):
460
+ with record_function("distributed_muon::newton_schulz"):
461
+ g_full = p.grad.full_tensor().to(COMM_DTYPE)
462
+ u_full = _zeropower_via_newtonschulz5(g_full,
463
+ steps=ns_steps)
464
+ del g_full
465
+ with record_function("distributed_muon::update"):
466
+ adjusted_lr = adjust_lr_for_muon(lr, p.shape)
467
+ p._local_tensor.mul_(1 - lr * weight_decay)
468
+ local_indices = get_slices_of_dtensor(
469
+ p, rank, shard_mesh, shard_placements)
470
+ u_local = u_full[local_indices]
471
+ p._local_tensor.add_(u_local, alpha=-adjusted_lr)
472
+ del u_full
473
+
474
+ qk_clip_state = get_qk_clip_info(
475
+ self.clip_config, n, qk_logits)
476
+ scales_full = compute_scales(
477
+ p, qk_clip_state
478
+ ) if qk_clip_state is not None else None
479
+ if scales_full is not None:
480
+ ratio = p.shape[0] // scales_full.shape[0]
481
+ idx0 = local_indices[0]
482
+ if isinstance(idx0, slice):
483
+ start = idx0.start or 0
484
+ idx0 = torch.arange(start,
485
+ idx0.stop,
486
+ device=scales_full.device)
487
+ row_scales = scales_full[idx0 // ratio]
488
+ p._local_tensor.mul_(row_scales.view(-1, 1))
489
+
490
+ if not batch_params:
491
+ continue
492
 
493
+ logger.info(" batched=%d, single=%d", len(batch_params),
494
+ len(single_params))
495
+
496
+ # Concat all local grad shards into a single flat buffer.
497
+ with record_function("distributed_muon::gather"):
498
+ grad_locals = [
499
+ p.grad.to_local().to(COMM_DTYPE).flatten()
500
+ for p in batch_params
501
+ ]
502
+ numels = [g.numel() for g in grad_locals]
503
+ grad_concat = torch.cat(grad_locals)
504
+ del grad_locals
505
+
506
+ # Single all-gather (replaces N separate full_tensor).
507
+ grad_gathered = torch.empty(
508
+ grad_concat.numel() * world_size,
509
+ dtype=COMM_DTYPE,
510
+ device="cuda",
511
+ )
512
+ dist.all_gather_into_tensor(grad_gathered,
513
+ grad_concat,
514
+ group=shard_pg)
515
+
516
+ total_numel = grad_concat.numel()
517
+ del grad_concat
518
+
519
+ # Precompute per-param offsets within the concat buffer.
520
+ offsets = []
521
+ off = 0
522
+ for ne in numels:
523
+ offsets.append(off)
524
+ off += ne
525
+
526
+ # Per-param: reconstruct full grad → NS → local update.
527
+ for i, (n, p) in enumerate(zip(batch_names, batch_params)):
528
+ with record_function("distributed_muon::newton_schulz"):
529
+ g_full = torch.empty(p.shape,
530
+ dtype=COMM_DTYPE,
531
+ device="cuda")
532
+ for r in range(world_size):
533
+ r_start = r * total_numel + offsets[i]
534
+ shard = grad_gathered[r_start:r_start + numels[i]]
535
+ indices = get_slices_of_dtensor(
536
+ p, r, shard_mesh, shard_placements)
537
+ g_full[indices] = shard.reshape(
538
+ g_full[indices].shape)
539
+
540
+ u_full = _zeropower_via_newtonschulz5(g_full,
541
+ steps=ns_steps)
542
+ del g_full
543
+
544
+ with record_function("distributed_muon::update"):
545
+ adjusted_lr = adjust_lr_for_muon(lr, p.shape)
546
+ p._local_tensor.mul_(1 - lr * weight_decay)
547
+ local_indices = get_slices_of_dtensor(
548
+ p, rank, shard_mesh, shard_placements)
549
+ u_local = u_full[local_indices]
550
+ p._local_tensor.add_(u_local, alpha=-adjusted_lr)
551
+ del u_full
552
+
553
+ qk_clip_state = get_qk_clip_info(
554
+ self.clip_config, n, qk_logits)
555
+ scales_full = compute_scales(
556
+ p, qk_clip_state
557
+ ) if qk_clip_state is not None else None
558
+ if scales_full is not None:
559
+ ratio = p.shape[0] // scales_full.shape[0]
560
+ idx0 = local_indices[0]
561
+ if isinstance(idx0, slice):
562
+ start = idx0.start or 0
563
+ idx0 = torch.arange(start,
564
+ idx0.stop,
565
+ device=scales_full.device)
566
+ row_scales = scales_full[idx0 // ratio]
567
+ p._local_tensor.mul_(row_scales.view(-1, 1))
568
+
569
+ def _setup_parallel(self, names, params, group, qk_logits):
570
+ """Compute (or retrieve cached) parallel pipeline metadata.
571
+
572
+ Returns:
573
+ (ordered_params, param_to_state, rank, chunk_size)
574
+ """
575
+ cache_key = tuple(names)
576
 
577
+ if cache_key not in self._parallel_cache:
578
+ # First call: compute metadata and populate cache.
579
+ param_to_state, ordered_params = self.init_state_and_assign_params(
580
+ names, params, group, qk_logits)
581
 
582
+ shard_pg = param_to_state[id(ordered_params[0])].process_group
583
+ rank = dist.get_rank(group=shard_pg)
 
 
 
 
 
584
 
585
+ if self.chunk_size == -1:
586
+ shard_ranks = dist.get_world_size(shard_pg)
587
+ chunk_size = shard_ranks * DEFAULT_CHUNK_SIZE_RATIO
588
+ elif self.chunk_size > 0:
589
+ chunk_size = self.chunk_size
590
+ else:
591
+ raise ValueError(
592
+ "chunk_size must be -1 or a positive integer.")
593
+
594
+ ordered_names = [
595
+ param_to_state[id(p)].name for p in ordered_params
596
+ ]
597
+ name_to_state = {
598
+ param_to_state[id(p)].name: param_to_state[id(p)]
599
+ for p in ordered_params
600
+ }
601
+ self._parallel_cache[cache_key] = {
602
+ 'ordered_names': ordered_names,
603
+ 'name_to_state': name_to_state,
604
+ 'rank': rank,
605
+ 'chunk_size': chunk_size,
606
+ }
607
+ else:
608
+ # Cached path: rebuild param_to_state with current id(p) keys.
609
+ cache = self._parallel_cache[cache_key]
610
+ rank = cache['rank']
611
+ chunk_size = cache['chunk_size']
612
+
613
+ name_to_param = dict(zip(names, params))
614
+ ordered_params = [name_to_param[n] for n in cache['ordered_names']]
615
+
616
+ param_to_state = {}
617
+ for p, n in zip(ordered_params, cache['ordered_names']):
618
+ cached_state = cache['name_to_state'][n]
619
+ param_to_state[id(p)] = _muon_state(
620
+ worker_rank=cached_state.worker_rank,
621
+ process_group=cached_state.process_group,
622
+ rank_indices=cached_state.rank_indices,
623
+ rank_numels=cached_state.rank_numels,
624
+ name=n,
625
+ qk_clip_state=get_qk_clip_info(self.clip_config, n,
626
+ qk_logits),
627
  )
628
 
629
+ return ordered_params, param_to_state, rank, chunk_size
630
 
631
+ def parallel(self,
632
+ names,
633
+ params,
634
+ group,
635
+ lr,
636
+ weight_decay,
637
+ qk_logits,
638
+ prelaunch_gather=None):
639
  """
640
  Perform a parallel optimization step using Muon.
641
 
 
644
  interleaves multiple chunks so that communication and computation
645
  overlap across chunks (the same overlap previously achieved by the
646
  warmup + main-loop index scheduling).
647
+
648
+ If ``prelaunch_gather`` is provided, it is passed to the first
649
+ chunk's generator to skip re-launching the already in-flight
650
+ A2A gather.
651
  """
652
 
653
  # Momentum is already applied by _step_muon before this method.
654
 
655
+ ordered_params, param_to_state, rank, chunk_size = (
656
+ self._setup_parallel(names, params, group, qk_logits))
 
 
 
 
 
 
 
 
 
 
 
 
 
657
 
658
  def pipelines():
659
+ first = True
660
  for start in range(0, len(ordered_params), chunk_size):
661
  chunk = ordered_params[start:start + chunk_size]
662
  if chunk:
663
+ kwargs = dict(
664
  params=chunk,
665
  param_to_state=param_to_state,
666
  rank=rank,
 
669
  weight_decay=weight_decay,
670
  none_grad=group["none_grad"],
671
  )
672
+ if first and prelaunch_gather is not None:
673
+ kwargs['prelaunch_gather'] = prelaunch_gather
674
+ first = False
675
+ yield muon_chunk_pipeline(**kwargs)
676
 
 
 
677
  with record_function("muon::pipeline"):
678
  run_pipeline(pipelines(), max_concurrent=self.warmup_step + 1)
679
 
 
685
  names = group["names"]
686
 
687
  # Apply momentum to all params before routing/expansion.
688
+ # Batched using _foreach_* ops (compiled, fullgraph=True).
689
  with record_function("muon::momentum"):
690
+ active_params = [p for p in params if p.grad is not None]
691
+ if active_params:
692
+ # Ensure momentum buffers exist (avoid zeros_like when already present).
693
+ for p in active_params:
694
+ if "momentum_buffer" not in self.state[p]:
695
+ self.state[p]["momentum_buffer"] = torch.zeros_like(
696
+ p.grad)
697
+
698
+ # Extract local tensors for compiled batch function.
699
+ local_grads = [
700
+ p.grad._local_tensor
701
+ if isinstance(p.grad, DTensor) else p.grad
702
+ for p in active_params
703
+ ]
704
+ local_bufs = [
705
+ self.state[p]["momentum_buffer"]._local_tensor
706
+ if isinstance(self.state[p]["momentum_buffer"], DTensor)
707
+ else self.state[p]["momentum_buffer"]
708
+ for p in active_params
709
+ ]
710
+
711
+ # Wrap momentum as tensor for torch.compile.
712
+ batch_pre_ortho(local_grads, local_bufs,
713
+ torch.tensor(momentum), group["nesterov"])
714
+
715
+ # For non-nesterov, the result is the momentum buffer.
716
+ if not group["nesterov"]:
717
+ for p in active_params:
718
+ p.grad = self.state[p]["momentum_buffer"]
719
+
720
+ # Identify batched experts for deferred NS.
721
+ # Detection is cheap (condition checks only); actual NS compute is
722
+ # deferred so it can overlap with the first chunk's A2A gather.
723
+ deferred_expert_work = []
724
+ if self.expert_keys:
725
+ batched_expert_indices = []
726
+ for i, (n, p) in enumerate(zip(names, params)):
727
+ if not (is_expert_param(n, self.expert_keys)
728
+ and p.grad is not None):
729
  continue
730
+ # Eligible: plain tensor, or DTensor with no non-dim-0 shards.
731
+ if isinstance(p.data, DTensor):
732
+ has_tp = any(
733
+ _is_shard(pl) and pl.dim != 0 for pl in p.placements)
734
+ if has_tp:
735
+ continue
736
+ batched_expert_indices.append(i)
737
+
738
+ if batched_expert_indices:
739
+ # Save refs for deferred NS; free grads from param list.
740
+ for i in batched_expert_indices:
741
+ p = params[i]
742
+ g = p.grad
743
+ local_g = (g._local_tensor
744
+ if isinstance(g, DTensor) else g)
745
+ local_data = (p.data._local_tensor if isinstance(
746
+ p.data, DTensor) else p.data)
747
+ deferred_expert_work.append((local_data, local_g))
748
+ p.grad = None
749
+
750
+ # Remove batched experts from lists before expansion.
751
+ keep = sorted(
752
+ set(range(len(params))) - set(batched_expert_indices))
753
+ names = [names[i] for i in keep]
754
+ params = [params[i] for i in keep]
755
+
756
+ def _run_deferred_expert_ns():
757
+ """Execute deferred batched expert NS."""
758
+ if not deferred_expert_work:
759
+ return
760
+ with record_function("muon::batched_expert_ns"):
761
+ ns_steps = group["ns_steps"]
762
+ for local_data, local_g in deferred_expert_work:
763
+ u = zeropower_via_newtonschulz5_batched(
764
+ local_g.to(COMM_DTYPE), steps=ns_steps)
765
+ adjusted_lr = adjust_lr_for_muon(lr, local_g.shape[1:])
766
+ local_data.mul_(1 - lr * weight_decay)
767
+ local_data.add_(u, alpha=-adjusted_lr)
768
 
769
  # Expand expert params by splitting on dim 0.
770
+ logger.debug("[_step_muon] before expand: %d params, expert_keys=%s",
771
+ len(params), self.expert_keys)
772
+ if self.expert_keys:
773
+ cache_key = tuple(id(p) for p in params)
774
+ cache = self._expert_expand_cache.get(cache_key)
775
+
776
+ if cache is None:
777
+ # Cold path: full expansion + build cache metadata.
778
+ exp_names, exp_params = _expand_expert_params(
779
+ names, params, self.expert_keys)
780
+
781
+ # Build per-expert-group info for hot-path grad updates.
782
+ grad_info = []
783
+ exp_idx = 0
784
+ for orig_idx, (n, p) in enumerate(zip(names, params)):
785
+ if not is_expert_param(n, self.expert_keys):
786
+ exp_idx += 1
787
+ continue
788
+
789
+ is_dt = isinstance(p.data, DTensor)
790
+ num_experts = (p.to_local() if is_dt else p.data).shape[0]
791
+
792
+ # Detect TP mesh from the first expanded expert param.
793
+ tp_mesh = None
794
+ tp_pls = None
795
+ sample = exp_params[exp_idx]
796
+ if isinstance(sample.data, DTensor):
797
+ tp_mesh = sample.data.device_mesh
798
+ tp_pls = list(sample.data.placements)
799
+
800
+ grad_info.append((orig_idx, num_experts, exp_idx, is_dt,
801
+ tp_mesh, tp_pls))
802
+ exp_idx += num_experts
803
+
804
+ self._expert_expand_cache[cache_key] = {
805
+ 'names': exp_names,
806
+ 'params': exp_params,
807
+ 'grad_info': grad_info,
808
+ }
809
+ names, params = exp_names, exp_params
810
+ else:
811
+ # Hot path: reuse cached params, only update expert grads.
812
+ for (orig_idx, num_experts, exp_start, is_dt, tp_mesh,
813
+ tp_pls) in cache['grad_info']:
814
+ p = params[orig_idx]
815
+ g = p.grad
816
+ local_grad = (g.to_local()
817
+ if is_dt and isinstance(g, DTensor) else g)
818
+ for i in range(num_experts):
819
+ expert_p = cache['params'][exp_start + i]
820
+ sg = local_grad[i]
821
+ if tp_mesh is not None:
822
+ expert_p.grad = DTensor.from_local(
823
+ sg, device_mesh=tp_mesh, placements=tp_pls)
824
+ else:
825
+ expert_p.grad = sg
826
+ p.grad = None
827
+
828
+ names = cache['names']
829
+ params = cache['params']
830
+ else:
831
+ names, params = _expand_expert_params(names, params,
832
+ self.expert_keys)
833
+ logger.debug("[_step_muon] after expand: %d params", len(params))
834
 
835
  param_dtensors = []
836
  name_dtensors = []
 
838
  param_tensors = []
839
  name_tensors = []
840
 
841
+ # distributed_muon is a reference implementation for testing only.
842
+ # The parallel pipeline (all2all) path below is the production path.
 
843
  if self.use_distributed_muon:
844
+ _run_deferred_expert_ns()
845
  self.distributed_muon(names=names,
846
  params=params,
847
  group=group,
 
850
  qk_logits=qk_logits)
851
  return
852
 
 
 
853
  for n, p in zip(names, params):
854
  if p is None or p.grad is None:
855
  continue
 
857
  if all(
858
  isinstance(placement, Replicate)
859
  for placement in p.placements):
860
+ logger.debug(
861
+ "[route] %s → base (DTensor all-Replicate), "
862
+ "shape=%s, placements=%s", n, p.shape, p.placements)
863
  param_tensors.append(p)
864
  name_tensors.append(n)
 
 
 
865
  else:
866
+ logger.debug(
867
+ "[route] %s → parallel (DTensor), shape=%s, "
868
+ "placements=%s, mesh=%s", n, p.shape, p.placements,
869
+ p.device_mesh.mesh_dim_names)
870
  param_dtensors.append(p)
871
  name_dtensors.append(n)
872
  elif isinstance(p.data, torch.Tensor):
873
+ logger.debug("[route] %s → base (plain tensor), shape=%s", n,
874
+ p.data.shape)
875
  param_tensors.append(p)
876
  name_tensors.append(n)
877
  else:
878
  raise TypeError(f"Unsupported parameter type: {type(p.data)}")
879
 
880
+ logger.debug(f"[Muon] {len(param_dtensors)} DTensors → parallel, "
881
+ f"{len(param_tensors)} Tensors → base")
 
882
 
883
  def group_dtensors(dtensors, names):
884
  # To support different placements, we group parameters by placements
 
894
  p.device_mesh])][1].append(p)
895
  return placement_to_params
896
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
897
  if len(param_dtensors) > 0:
898
  if not dist.is_initialized():
899
  raise RuntimeError(
 
901
  )
902
 
903
  dtensor_group = group_dtensors(param_dtensors, name_dtensors)
904
+
905
+ # Pre-launch the first chunk's A2A gather so that the NCCL
906
+ # communication overlaps with the (deferred) batched expert NS
907
+ # compute on the default CUDA stream.
908
+ prelaunch = None
909
+ if deferred_expert_work:
910
+ first_names, first_params = next(iter(dtensor_group.values()))
911
+ ordered, pts, rnk, csz = self._setup_parallel(
912
+ first_names, first_params, group, qk_logits)
913
+ first_chunk = ordered[:csz]
914
+ if first_chunk:
915
+ prelaunch = prelaunch_first_gather(first_chunk, pts, rnk,
916
+ group["none_grad"])
917
+
918
+ _run_deferred_expert_ns()
919
+
920
+ first_group = True
921
  for _, (names, params) in dtensor_group.items():
922
+ pg = prelaunch if first_group else None
923
+ first_group = False
924
  self.parallel(
925
  names,
926
  params,
 
928
  lr=lr,
929
  weight_decay=weight_decay,
930
  qk_logits=qk_logits,
931
+ prelaunch_gather=pg,
932
  )
933
+ else:
934
+ _run_deferred_expert_ns()
935
 
936
  if len(param_tensors) > 0:
937
  self.base(
 
943
  qk_logits=qk_logits,
944
  )
945
 
946
def _register_states_for_offload(self):
    """Register all optimizer state tensors with the CPU offload pool.

    Called once after the first step, when the per-param states have been
    lazily created. Registers every state tensor (the momentum buffer for
    Muon params; moment1/moment2 for AdamW params) with the pool so they
    can be offloaded to CPU between steps, freeing GPU memory.

    Fix vs. original: ``tracked`` is only incremented when at least one
    state tensor was actually registered for the param, so the log line
    reports the true number of offloaded param states (the original
    counted every non-Muon param even when it had no moment tensors).
    """
    pool = self._cpu_offload_pool
    tracked = 0
    for group in self.param_groups:
        for p in group["params"]:
            if p not in self.state:
                continue
            state = self.state[p]
            if group.get("use_muon", False):
                # Muon keeps a single momentum buffer per param.
                if "momentum_buffer" in state:
                    pool.track(state["momentum_buffer"])
                    tracked += 1
            else:
                # AdamW keeps first/second moment estimates.
                registered = False
                if "moment1" in state:
                    pool.track(state["moment1"])
                    registered = True
                if "moment2" in state:
                    pool.track(state["moment2"])
                    registered = True
                if registered:
                    tracked += 1
    logger.info("[CPUOffload] Registered %d param states for offload",
                tracked)
+
973
  @torch.no_grad
974
  def step(self, closure=None, qk_logits=None):
975
  """Perform a single optimization step.
 
987
  with torch.enable_grad():
988
  loss = closure()
989
 
990
+ # H2D: reload optimizer states from CPU before computation.
991
+ if self.cpu_offload and self._offload_initialized:
992
+ self._cpu_offload_pool.reload()
993
+
994
+ logger.debug("[Muon.step] expert_keys=%s, %d param groups",
995
+ self.expert_keys, len(self.param_groups))
996
+
997
+ for i, group in enumerate(self.param_groups):
998
  if group["use_muon"]:
999
+ logger.debug("[Muon.step] group %d: use_muon=True, %d params",
1000
+ i, len(group["params"]))
1001
  self._step_muon(group, qk_logits=qk_logits)
1002
  else:
1003
+ logger.debug(
1004
+ "[Muon.step] group %d: use_muon=False (AdamW), %d params",
1005
+ i, len(group["params"]))
1006
  step_adamw(self.state, group)
1007
 
1008
+ # D2H: offload optimizer states to CPU after computation.
1009
+ if self.cpu_offload:
1010
+ if not self._offload_initialized:
1011
+ self._register_states_for_offload()
1012
+ self._offload_initialized = True
1013
+ self._cpu_offload_pool.offload()
1014
+
1015
  return loss
1016
+
1017
+ # ------------------------------------------------------------------
1018
+ # Checkpoint support for cpu_offload
1019
+ # ------------------------------------------------------------------
1020
+
1021
def state_dict(self) -> dict:
    """Snapshot the optimizer state, transparently handling CPU offload.

    With ``cpu_offload`` enabled and initialized, state tensors have had
    their GPU storage freed (``resize_(0)``) between steps. We reload
    them, take the snapshot, deep-copy every tensor entry so the snapshot
    survives the subsequent re-offload (which frees the originals' GPU
    storage again), then offload so the optimizer returns to its expected
    post-step state.
    """
    offloaded = self.cpu_offload and self._offload_initialized
    if offloaded:
        self._cpu_offload_pool.reload()
        torch.cuda.current_stream().synchronize()

    sd = super().state_dict()

    if offloaded:
        # Clone tensor entries: re-offloading below frees the GPU
        # storage of the originals via resize_(0).
        for key, entry in sd["state"].items():
            cloned = {}
            for sk, sv in entry.items():
                cloned[sk] = sv.clone() if isinstance(sv,
                                                      torch.Tensor) else sv
            sd["state"][key] = cloned
        self._cpu_offload_pool.offload()

    return sd
1045
+
1046
def load_state_dict(self, state_dict: dict) -> None:
    """Restore the optimizer state, then re-establish CPU offload.

    When states were previously offloaded, reload them first so each
    tensor has real storage for ``super().load_state_dict()`` to copy
    into. Afterwards, rebuild the offload pool around the (possibly
    brand-new) state tensors and offload them, restoring the post-step
    invariant that GPU storage is freed between steps.
    """
    was_offloaded = self.cpu_offload and self._offload_initialized
    if was_offloaded:
        # Storage was freed via resize_(0); restore it so the parent
        # class can overwrite the values in place.
        self._cpu_offload_pool.reload()
        torch.cuda.current_stream().synchronize()

    super().load_state_dict(state_dict)

    if self.cpu_offload:
        # State tensors may be new objects after loading, so the old
        # pool's tracked references are stale: start from a fresh pool.
        self._cpu_offload_pool = CPUOffloadPool()
        self._offload_initialized = False
        self._register_states_for_offload()
        self._offload_initialized = True
        self._cpu_offload_pool.offload()
build/torch210-cxx11-rocm70-x86_64-linux/newton_schulz.py CHANGED
@@ -1,3 +1,7 @@
 
 
 
 
1
  import torch
2
 
3
  from .matmul_transpose_triton import matmul_transpose_assign
@@ -6,21 +10,134 @@ COMM_DTYPE = torch.bfloat16
6
  DEFAULT_CHUNK_SIZE_RATIO = 4
7
 
8
 
9
- # This code snippet is a modified version adapted from the following GitHub repositories:
10
- # https://github.com/KellerJordan/Muon/blob/master/muon.py
11
- # Muon's Newton–Schulz iteration causes high variance in singular values
12
- # Idea: give each iteration its own 3 coefficients and optimize them via gradient descent.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
  @torch.no_grad()
14
- # matmul_transpose_assign from : https://github.com/nil0x9/flash-muon
15
  def _zeropower_via_newtonschulz5(G, steps):
16
  """
17
- Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a
18
- quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose
19
- of minimizing steps, it turns out to be empirically effective to keep increasing the slope at
20
- zero even beyond the point where the iteration no longer converges all the way to one everywhere
21
- on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T
22
- where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model
23
- performance at all relative to UV^T, where USV^T = G is the SVD.
 
 
 
 
 
 
 
24
  """
25
  assert len(G.shape) == 2
26
  assert G.dtype == COMM_DTYPE
@@ -28,18 +145,14 @@ def _zeropower_via_newtonschulz5(G, steps):
28
 
29
  if G.size(0) > G.size(1):
30
  X = X.T
31
- # Ensure spectral norm is at most 1
32
  X = X / (X.norm() + 1e-7)
 
 
33
  buf1 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device)
34
  buf2 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device)
35
  # Perform the NS iterations
36
- for a, b, c in [
37
- (4.0848, -6.8946, 2.9270),
38
- (3.9505, -6.3029, 2.6377),
39
- (3.7418, -5.5913, 2.3037),
40
- (2.8769, -3.1427, 1.2046),
41
- (2.8366, -3.0525, 1.2012),
42
- ]:
43
  matmul_transpose_assign(X, buf1)
44
  matmul_transpose_assign(buf1, buf2)
45
  buf1.mul_(b).add_(buf2, alpha=c)
@@ -47,4 +160,77 @@ def _zeropower_via_newtonschulz5(G, steps):
47
 
48
  if G.size(0) > G.size(1):
49
  X = X.T
 
50
  return X
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from itertools import repeat
2
+ from math import inf, sqrt
3
+
4
+ import numpy as np
5
  import torch
6
 
7
  from .matmul_transpose_triton import matmul_transpose_assign
 
10
  DEFAULT_CHUNK_SIZE_RATIO = 4
11
 
12
 
13
+ def _optimal_quintic(l, u, max_iter=1000):
14
+ """
15
+ Use the simplified Remez algorithm to find the optimal odd quintic approximant
16
+ to the constant function x -> 1 over the interval [l, u].
17
+
18
+ Returns (a, b, c) for p(x) = ax + bx^3 + cx^5 that minimizes the maximum
19
+ approximation error max_{x in [l,u]} |p(x) - 1|. Iterates by updating the
20
+ two interior equioscillation nodes q, r until convergence. Returns the
21
+ closed-form equioscillating solution when l ≈ u.
22
+
23
+ Raises ValueError if any intermediate value (a, b, c, E, q, r) is non-finite
24
+ (NaN or inf). Raises RuntimeError if convergence is not reached within
25
+ max_iter iterations.
26
+ """
27
+ assert 0 <= l <= u
28
+ if 1 - 5e-6 <= l / u:
29
+ return (15 / 8) / u, (-10 / 8) / (u**3), (3 / 8) / (u**5)
30
+ q = (3 * l + u) / 4
31
+ r = (l + 3 * u) / 4
32
+ E = inf
33
+ for _ in range(max_iter):
34
+ old_E = E
35
+ LHS = np.array([
36
+ [l, l**3, l**5, 1],
37
+ [q, q**3, q**5, -1],
38
+ [r, r**3, r**5, 1],
39
+ [u, u**3, u**5, -1],
40
+ ])
41
+ a, b, c, E = np.linalg.solve(LHS, np.ones(4))
42
+ if not np.all(np.isfinite([a, b, c, E])):
43
+ raise ValueError(f"_optimal_quintic: non-finite solve result "
44
+ f"a={a}, b={b}, c={c}, E={E}")
45
+ q, r = np.sqrt(
46
+ (-3 * b + np.array([-1, 1]) * sqrt(9 * b**2 - 20 * a * c)) /
47
+ (10 * c))
48
+ if not np.all(np.isfinite([q, r])):
49
+ raise ValueError(
50
+ f"_optimal_quintic: non-finite node update q={q}, r={r}")
51
+ if abs(old_E - E) <= 1e-15:
52
+ break
53
+ else:
54
+ raise RuntimeError(
55
+ f"_optimal_quintic: did not converge after {max_iter} iterations")
56
+ return float(a), float(b), float(c)
57
+
58
+
59
def _optimal_composition(l, num_iters, safety_factor_eps=0, cushion=0):
    """Compute the Polar Express coefficient series for `num_iters` quintic iterations.

    Builds a sequence of per-step optimal odd quintic coefficients (a, b, c)
    that compose to map singular values from [l, 1] toward 1. At each step:
      1. Solves `_optimal_quintic` on [max(l, cushion*u), u]. The `cushion`
         prevents near-zero singular values from stalling by raising the
         effective lower bound; if it is active (cushion*u > l), the
         coefficients are rescaled so that p(l) and p(u) are centered
         around 1 w.r.t. the true [l, u].
      2. Deflates the coefficients by (1 + safety_factor_eps)^degree for all
         but the last iteration, providing numerical headroom at the cost of
         a slightly slower final convergence step.
      3. Advances the interval: l <- p(l), u <- 2 - p(l) (by symmetry of p
         around 1).

    Returns a list of (a, b, c) tuples, one per iteration.

    Reference: Amsel et al., "The Polar Express: Optimal Matrix Sign Methods
    and Their Application to the Muon Algorithm",
    https://arxiv.org/abs/2505.16932
    """
    u = 1
    assert 0 <= l <= u
    safety_factor = 1 + safety_factor_eps
    coefficients = []
    # `step_idx` (was `iter`): avoid shadowing the builtin `iter`.
    for step_idx in range(num_iters):
        a, b, c = _optimal_quintic(max(l, cushion * u), u)
        if cushion * u > l:
            # Cushion was active: re-center p so p(l) and p(u) straddle 1
            # on the true interval [l, u].
            pl = a * l + b * l**3 + c * l**5
            pu = a * u + b * u**3 + c * u**5
            rescaler = 2 / (pl + pu)
            a *= rescaler
            b *= rescaler
            c *= rescaler
        if step_idx < num_iters - 1:
            # Deflate for numerical headroom (degree-matched scaling).
            a /= safety_factor
            b /= safety_factor**3
            c /= safety_factor**5
        coefficients.append((a, b, c))
        l = a * l + b * l**3 + c * l**5
        u = 2 - l
    return coefficients
100
+
101
+
102
+ # Precomputed Polar Express coefficients (a, b, c) for 10 quintic Newton-Schulz
103
+ # iterations. Each tuple is the minimax-optimal (Remez/equioscillation) odd quintic
104
+ # approximant to x->1 over the current singular-value interval, computed once at
105
+ # import time and reused across all optimizer steps.
106
+ #
107
+ # Contrast with the former hardcoded NS coefficients (5 fixed tuples):
108
+ # - Former: empirically tuned to maximize slope at zero; did not converge
109
+ # singular values to 1, yielding US'V^T with S' ~ Uniform(0.5, 1.5) instead
110
+ # of the true polar factor UV^T.
111
+ # - Polar Express: analytically optimal per step, adapting to the shrinking
112
+ # singular-value interval [l, u] as iterations progress; converges all
113
+ # singular values to 1, producing the exact polar factor UV^T.
114
+ _coeffs_list = _optimal_composition(l=1e-3,
115
+ num_iters=10,
116
+ safety_factor_eps=1e-2,
117
+ cushion=0.02)
118
+
119
+
120
+ # This code is adapted from:
121
+ # KellerJordan/Muon (https://github.com/KellerJordan/Muon/blob/master/muon.py)
122
+ # NoahAmsel/PolarExpress (https://github.com/NoahAmsel/PolarExpress)
123
+ # matmul_transpose_assign kernel from nil0x9/flash-muon (https://github.com/nil0x9/flash-muon)
124
  @torch.no_grad()
 
125
  def _zeropower_via_newtonschulz5(G, steps):
126
  """
127
+ Compute the polar factor of G via the Polar Express method.
128
+
129
+ Applies `steps` quintic iterations X <- aX + bX^3 + cX^5, where (a, b, c)
130
+ are the Polar Express coefficients from `_coeffs_list`. Each step is the
131
+ optimal odd quintic approximant to x -> 1 over the current singular-value
132
+ interval, minimizing the maximum approximation error (Remez / minimax criterion).
133
+ The composition maps singular values from [l, 1] to near 1, producing the
134
+ polar factor (orthogonal factor in the polar decomposition G = UP).
135
+
136
+ `_coeffs_list` is precomputed for 10 iterations (l=1e-3, safety_factor_eps=1e-2,
137
+ cushion=0.02). If `steps` exceeds 10, the final coefficient set is repeated.
138
+
139
+ Reference: Amsel et al., "The Polar Express: Optimal Matrix Sign Methods and
140
+ Their Application to the Muon Algorithm", https://arxiv.org/abs/2505.16932
141
  """
142
  assert len(G.shape) == 2
143
  assert G.dtype == COMM_DTYPE
 
145
 
146
  if G.size(0) > G.size(1):
147
  X = X.T
148
+
149
  X = X / (X.norm() + 1e-7)
150
+ hs = _coeffs_list[:steps] + list(
151
+ repeat(_coeffs_list[-1], steps - len(_coeffs_list)))
152
  buf1 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device)
153
  buf2 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device)
154
  # Perform the NS iterations
155
+ for a, b, c in hs:
 
 
 
 
 
 
156
  matmul_transpose_assign(X, buf1)
157
  matmul_transpose_assign(buf1, buf2)
158
  buf1.mul_(b).add_(buf2, alpha=c)
 
160
 
161
  if G.size(0) > G.size(1):
162
  X = X.T
163
+
164
  return X
165
+
166
+
167
@torch.no_grad()
def _zeropower_via_newtonschulz5_batched(G, steps):
    """Batched polar-factor computation for 3D (E, out, in) tensors.

    Same algorithm as ``_zeropower_via_newtonschulz5`` but built on
    ``torch.bmm`` / ``torch.baddbmm`` instead of the 2D Triton kernel,
    so all E expert matrices are processed in one batched call per step.
    """
    assert len(G.shape) == 3
    assert G.dtype == COMM_DTYPE
    X = G

    tall = G.size(1) > G.size(2)
    if tall:
        X = X.transpose(-2, -1)

    # Normalize each expert matrix by its own Frobenius norm.
    X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7)

    # Coefficient schedule: truncate to `steps`, padding with the final
    # tuple when more iterations are requested than precomputed.
    schedule = list(_coeffs_list[:steps])
    while len(schedule) < steps:
        schedule.append(_coeffs_list[-1])

    for a, b, c in schedule:
        gram = torch.bmm(X, X.transpose(-2, -1))  # A = X X^T
        gram_sq = torch.bmm(gram, gram.transpose(-2, -1))  # A^2
        gram.mul_(b).add_(gram_sq, alpha=c)  # bA + cA^2
        # X <- aX + (bA + cA^2) X
        X = torch.baddbmm(X, gram, X, alpha=1.0, beta=a)

    if tall:
        X = X.transpose(-2, -1)

    return X
+
198
+
199
+ _ns_per_shape: dict[tuple[int, ...], callable] = {}
200
+ _use_compile = True
201
+
202
+
203
def set_ns_compile(enabled: bool):
    """Enable or disable ``torch.compile`` wrapping of the NS iteration.

    Only affects future calls; already-compiled variants remain cached
    in ``_ns_per_shape``.
    """
    global _use_compile
    _use_compile = enabled
207
+
208
+
209
def zeropower_via_newtonschulz5(G, steps=5):
    """Shape-cached entry point for the 2D Newton-Schulz iteration.

    With compilation disabled, dispatch straight to the eager
    implementation. Otherwise fetch (or build and cache) a
    ``torch.compile``d variant keyed on ``G.shape``, mark a CUDA-graph
    step boundary, and clone the result (CUDA-graph outputs may be
    reused across steps).
    """
    if not _use_compile:
        return _zeropower_via_newtonschulz5(G, steps)

    compiled = _ns_per_shape.get(G.shape)
    if compiled is None:
        compiled = torch.compile(_zeropower_via_newtonschulz5,
                                 options={
                                     "triton.cudagraphs": True,
                                     "shape_padding": False
                                 })
        _ns_per_shape[G.shape] = compiled
    torch.compiler.cudagraph_mark_step_begin()
    return compiled(G, steps).clone()
221
+
222
+
223
def zeropower_via_newtonschulz5_batched(G, steps=5):
    """Compile-cached batched Newton-Schulz for 3D expert tensors.

    Mirrors ``zeropower_via_newtonschulz5`` but dispatches to the
    batched implementation; compiled variants are cached per input
    shape in ``_ns_per_shape``.
    """
    if not _use_compile:
        return _zeropower_via_newtonschulz5_batched(G, steps)

    compiled = _ns_per_shape.get(G.shape)
    if compiled is None:
        compiled = torch.compile(_zeropower_via_newtonschulz5_batched,
                                 options={
                                     "triton.cudagraphs": True,
                                     "shape_padding": False
                                 })
        _ns_per_shape[G.shape] = compiled
    torch.compiler.cudagraph_mark_step_begin()
    return compiled(G, steps).clone()
build/torch210-cxx11-rocm70-x86_64-linux/pipeline.py CHANGED
@@ -6,8 +6,8 @@ import torch.distributed as dist
6
  from torch.distributed.tensor import DTensor
7
  from torch.profiler import record_function
8
 
9
- from .core import _muon_state, adjust_lr_for_muon, update_p
10
- from .newton_schulz import COMM_DTYPE, _zeropower_via_newtonschulz5
11
  from .qk_clip import compute_scales
12
 
13
  logger = logging.getLogger(__name__)
@@ -45,26 +45,33 @@ def _launch_gather(
45
  else:
46
  gathered_grads[id(p)] = None
47
 
48
- # Build send buffer
49
- per_dst: list[list[torch.Tensor]] = [[] for _ in range(num_ranks)]
50
  send_counts = [0] * num_ranks
51
-
52
  for p in params:
53
  state = param_to_state[id(p)]
54
- dst = state.worker_rank
55
- assert dst < num_ranks
56
- shard_elems = state.rank_numels[rank]
57
- g = p.grad
58
- g = g.to_local().to(COMM_DTYPE).contiguous()
59
- assert g.numel() == shard_elems
60
- per_dst[dst].append(g.view(-1))
61
- send_counts[dst] += shard_elems
62
-
63
- assert any(
64
- len(v) > 0 for v in
65
- per_dst), "At least one destination rank must receive a sharded tensor"
66
- per_dst_flat = [t for dst in per_dst for t in dst]
67
- send_buf = torch.cat(per_dst_flat, dim=0)
 
 
 
 
 
 
 
 
68
 
69
  # Build recv buffer
70
  recv_counts = [0] * num_ranks
@@ -120,7 +127,8 @@ def _complete_gather(
120
 
121
  shard_view = gathered_grads[id(p)][indices]
122
  n = shard_view.numel()
123
- assert n > 0
 
124
 
125
  sg = recv_buf.narrow(0, off + inner_off, n)
126
  sg = sg.reshape(shard_view.shape)
@@ -143,7 +151,7 @@ def _compute_ns(
143
  """
144
  computed_us: dict[int, torch.Tensor | None] = {}
145
  for p in owned_params:
146
- u = _zeropower_via_newtonschulz5(gathered_grads[id(p)], ns_steps)
147
  gathered_grads[id(p)] = None # free gathered grad
148
  computed_us[id(p)] = u
149
  return computed_us
@@ -163,46 +171,47 @@ def _launch_scatter(
163
  Returns:
164
  work: Async operation handle.
165
  recv_buf: Flat receive buffer (needed by ``_complete_scatter``).
166
- scattered_us: ``{id(p): empty_local_tensor}`` for all params.
 
167
  recv_counts: Per-source-rank element counts.
168
  """
169
- # Allocate scattered-u buffers
 
 
 
170
  scattered_us: dict[int, torch.Tensor] = {}
171
  for p in params:
172
- scattered_us[id(p)] = torch.empty_like(p.to_local(), dtype=COMM_DTYPE)
 
 
173
 
174
- # Build send buffer (from computed_us on owner ranks)
175
- per_dst: list[list[torch.Tensor]] = [[] for _ in range(num_ranks)]
176
  send_counts = [0] * num_ranks
177
-
178
  if owned_params:
179
  for p in owned_params:
180
  state = param_to_state[id(p)]
181
-
182
- assert computed_us[id(p)] is not None
183
- u_full = computed_us[id(p)].to(COMM_DTYPE).contiguous()
184
-
185
- total_sent = 0
186
  for dst_rank in range(num_ranks):
187
- indices = state.rank_indices[dst_rank]
188
- su = u_full[indices].flatten()
189
-
190
- n = su.numel()
191
- assert n > 0
192
 
193
- per_dst[dst_rank].append(su)
194
- send_counts[dst_rank] += n
195
- total_sent += n
196
-
197
- assert total_sent == u_full.numel()
198
-
199
- lengths = [len(v) for v in per_dst]
200
- if all(l > 0 for l in lengths):
201
- assert all(
202
- l == lengths[0] for l in lengths
203
- ), "All destination ranks must have the same number of sharded tensor"
204
- per_dst_flat = [t for dst in per_dst for t in dst]
205
- send_buf = torch.cat(per_dst_flat, dim=0)
 
 
 
 
 
206
  else:
207
  send_buf = torch.empty(0, dtype=COMM_DTYPE, device="cuda")
208
 
@@ -218,7 +227,6 @@ def _launch_scatter(
218
  recv_counts[src] = total
219
 
220
  recv_total = sum(recv_counts)
221
- assert recv_total > 0
222
  recv_buf = torch.empty(recv_total, dtype=COMM_DTYPE, device="cuda")
223
 
224
  # Launch async all-to-all
@@ -242,7 +250,13 @@ def _complete_scatter(
242
  rank: int,
243
  scattered_us: dict[int, torch.Tensor],
244
  ) -> None:
245
- """Copy recv buffer into scattered_us (in-place)."""
 
 
 
 
 
 
246
  off = 0
247
  for src in range(len(recv_counts)):
248
  block = recv_counts[src]
@@ -255,11 +269,11 @@ def _complete_scatter(
255
  if state.worker_rank != src:
256
  continue
257
  n = state.rank_numels[rank]
258
- assert n > 0
 
259
 
260
- flat_local = recv_buf.narrow(0, off + inner_off,
261
- n).view_as(p.to_local())
262
- scattered_us[id(p)].copy_(flat_local)
263
 
264
  inner_off += n
265
 
@@ -275,23 +289,40 @@ def _update_params(
275
  lr: float,
276
  weight_decay: float,
277
  ) -> None:
278
- """Apply weight decay, Muon update, and optional QK clipping."""
279
- for p in params:
280
- state = param_to_state[id(p)]
281
- u_dtensor = DTensor.from_local(
282
- scattered_us[id(p)],
283
- placements=p.placements,
284
- device_mesh=p.device_mesh,
285
- )
286
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
287
  adjusted_lr = adjust_lr_for_muon(lr, p.shape)
288
- update_p(p, u_dtensor, lr, adjusted_lr, weight_decay)
 
 
 
289
 
290
- # QK clipping applied directly on the local tensor to
291
- # avoid DTensor sharding-propagation issues with _StridedShard.
292
- scales_full = compute_scales(
293
- p,
294
- state.qk_clip_state) if state.qk_clip_state is not None else None
 
 
 
 
 
295
  if scales_full is not None:
296
  ratio = p.shape[0] // scales_full.shape[0]
297
  idx0 = state.rank_indices[rank][0]
@@ -304,6 +335,45 @@ def _update_params(
304
  p._local_tensor.mul_(row_scales.view(-1, 1))
305
 
306
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
307
  # ======================================================================
308
  # Main generator – thin orchestrator that wires stages together.
309
  # ======================================================================
@@ -318,6 +388,7 @@ def muon_chunk_pipeline(
318
  lr: float,
319
  weight_decay: float,
320
  none_grad: bool,
 
321
  ) -> Generator[None, None, None]:
322
  """Process one chunk of parameters through the full Muon pipeline.
323
 
@@ -334,9 +405,12 @@ def muon_chunk_pipeline(
334
  runs concurrently on the NCCL stream — no separate ``comm_stream``
335
  is required.
336
 
 
 
 
337
  Yields exactly **2** times:
338
 
339
- 1. After launching async all-to-all gather.
340
  2. After launching async all-to-all scatter.
341
  """
342
  process_group = param_to_state[id(params[0])].process_group
@@ -345,15 +419,19 @@ def muon_chunk_pipeline(
345
  p for p in params if param_to_state[id(p)].worker_rank == rank
346
  ]
347
 
348
- # Stages 1-2: launch async gather.
349
- with record_function("muon::launch_gather"):
350
- work, recv_buf, gathered_grads, recv_counts = _launch_gather(
351
- params, owned_params, param_to_state, rank, num_ranks,
352
- process_group)
353
-
354
- if none_grad:
355
- for p in params:
356
- p.grad = None
 
 
 
 
357
 
358
  yield # --- YIELD 1: other chunks can launch their gather ---
359
 
 
6
  from torch.distributed.tensor import DTensor
7
  from torch.profiler import record_function
8
 
9
+ from .core import _muon_state, adjust_lr_for_muon
10
+ from .newton_schulz import COMM_DTYPE, zeropower_via_newtonschulz5
11
  from .qk_clip import compute_scales
12
 
13
  logger = logging.getLogger(__name__)
 
45
  else:
46
  gathered_grads[id(p)] = None
47
 
48
+ # Build send buffer – batch grad copies via torch.cat
49
+ # (1-2 fused kernels vs N individual narrow().copy_() calls).
50
  send_counts = [0] * num_ranks
 
51
  for p in params:
52
  state = param_to_state[id(p)]
53
+ send_counts[state.worker_rank] += state.rank_numels[rank]
54
+
55
+ total_send = sum(send_counts)
56
+ if total_send > 0:
57
+ # Group grad slices by destination rank in a single pass.
58
+ dst_to_grads = [[] for _ in range(num_ranks)]
59
+ for p in params:
60
+ state = param_to_state[id(p)]
61
+ n = state.rank_numels[rank]
62
+ if n > 0:
63
+ g = p.grad.to_local()
64
+ dst_to_grads[state.worker_rank].append(g.reshape(-1))
65
+
66
+ # Flatten in dst order and cat once.
67
+ all_slices = []
68
+ for dst in range(num_ranks):
69
+ all_slices.extend(dst_to_grads[dst])
70
+ send_buf = torch.cat(all_slices)
71
+ if send_buf.dtype != COMM_DTYPE:
72
+ send_buf = send_buf.to(COMM_DTYPE)
73
+ else:
74
+ send_buf = torch.empty(0, dtype=COMM_DTYPE, device="cuda")
75
 
76
  # Build recv buffer
77
  recv_counts = [0] * num_ranks
 
127
 
128
  shard_view = gathered_grads[id(p)][indices]
129
  n = shard_view.numel()
130
+ if n == 0:
131
+ continue
132
 
133
  sg = recv_buf.narrow(0, off + inner_off, n)
134
  sg = sg.reshape(shard_view.shape)
 
151
  """
152
  computed_us: dict[int, torch.Tensor | None] = {}
153
  for p in owned_params:
154
+ u = zeropower_via_newtonschulz5(gathered_grads[id(p)], ns_steps)
155
  gathered_grads[id(p)] = None # free gathered grad
156
  computed_us[id(p)] = u
157
  return computed_us
 
171
  Returns:
172
  work: Async operation handle.
173
  recv_buf: Flat receive buffer (needed by ``_complete_scatter``).
174
+ scattered_us: Empty dict, populated by ``_complete_scatter`` with
175
+ zero-copy views into ``recv_buf``.
176
  recv_counts: Per-source-rank element counts.
177
  """
178
+ # scattered_us is populated by _complete_scatter with zero-copy views
179
+ # into recv_buf, avoiding N empty_like allocations + N copy_ calls.
180
+ # Pre-seed entries for params whose local shard is empty (rank_numels == 0)
181
+ # so _update_params can iterate all params without KeyError.
182
  scattered_us: dict[int, torch.Tensor] = {}
183
  for p in params:
184
+ if param_to_state[id(p)].rank_numels[rank] == 0:
185
+ scattered_us[id(p)] = torch.empty_like(p.to_local(),
186
+ dtype=COMM_DTYPE)
187
 
188
+ # Build send buffer batch via torch.cat
189
+ # (1 fused kernel vs N*num_ranks individual narrow().copy_() calls).
190
  send_counts = [0] * num_ranks
 
191
  if owned_params:
192
  for p in owned_params:
193
  state = param_to_state[id(p)]
 
 
 
 
 
194
  for dst_rank in range(num_ranks):
195
+ send_counts[dst_rank] += state.rank_numels[dst_rank]
 
 
 
 
196
 
197
+ total_send = sum(send_counts)
198
+ if total_send > 0:
199
+ # Cache u_full conversions to avoid redundant .to() per dst_rank.
200
+ u_fulls = {}
201
+ for p in owned_params:
202
+ u_fulls[id(p)] = computed_us[id(p)].to(COMM_DTYPE).contiguous()
203
+
204
+ # Collect slices in dst order (matches all-to-all send layout).
205
+ all_slices = []
206
+ for dst_rank in range(num_ranks):
207
+ for p in owned_params:
208
+ state = param_to_state[id(p)]
209
+ su = u_fulls[id(p)][state.rank_indices[dst_rank]].flatten()
210
+ if su.numel() > 0:
211
+ all_slices.append(su)
212
+
213
+ send_buf = torch.cat(all_slices) if all_slices else torch.empty(
214
+ 0, dtype=COMM_DTYPE, device="cuda")
215
  else:
216
  send_buf = torch.empty(0, dtype=COMM_DTYPE, device="cuda")
217
 
 
227
  recv_counts[src] = total
228
 
229
  recv_total = sum(recv_counts)
 
230
  recv_buf = torch.empty(recv_total, dtype=COMM_DTYPE, device="cuda")
231
 
232
  # Launch async all-to-all
 
250
  rank: int,
251
  scattered_us: dict[int, torch.Tensor],
252
  ) -> None:
253
+ """Populate scattered_us with zero-copy views into recv_buf.
254
+
255
+ Instead of pre-allocating tensors and copying, we assign views directly
256
+ from ``recv_buf``. This eliminates N ``empty_like`` + N ``copy_`` calls.
257
+ The underlying storage of ``recv_buf`` is kept alive through the views
258
+ until ``scattered_us`` is cleared after ``_update_params``.
259
+ """
260
  off = 0
261
  for src in range(len(recv_counts)):
262
  block = recv_counts[src]
 
269
  if state.worker_rank != src:
270
  continue
271
  n = state.rank_numels[rank]
272
+ if n == 0:
273
+ continue
274
 
275
+ scattered_us[id(p)] = recv_buf.narrow(0, off + inner_off,
276
+ n).view_as(p.to_local())
 
277
 
278
  inner_off += n
279
 
 
289
  lr: float,
290
  weight_decay: float,
291
  ) -> None:
292
+ """Apply weight decay, Muon update, and optional QK clipping.
 
 
 
 
 
 
 
293
 
294
+ Uses batched ``_foreach_mul_`` for weight decay and batched
295
+ ``_foreach_add_`` for the Muon update, grouping parameters by
296
+ adjusted_lr to minimize kernel launches while preserving float32
297
+ precision for the alpha scaling.
298
+ """
299
+ if not params:
300
+ return
301
+
302
+ # Batched weight decay: p *= (1 - lr * wd) — single fused kernel.
303
+ p_locals = [p._local_tensor for p in params]
304
+ torch._foreach_mul_(p_locals, 1.0 - lr * weight_decay)
305
+
306
+ # Group params by adjusted_lr so _foreach_add_ can use a single
307
+ # alpha per group (preserves float32 precision for alpha scaling).
308
+ lr_groups: dict[float, tuple[list, list]] = {}
309
+ for p in params:
310
  adjusted_lr = adjust_lr_for_muon(lr, p.shape)
311
+ if adjusted_lr not in lr_groups:
312
+ lr_groups[adjusted_lr] = ([], [])
313
+ lr_groups[adjusted_lr][0].append(p._local_tensor)
314
+ lr_groups[adjusted_lr][1].append(scattered_us[id(p)])
315
 
316
+ for adjusted_lr, (p_group, u_group) in lr_groups.items():
317
+ torch._foreach_add_(p_group, u_group, alpha=-adjusted_lr)
318
+
319
+ # QK clipping – applied directly on the local tensor to
320
+ # avoid DTensor sharding-propagation issues with _StridedShard.
321
+ for p in params:
322
+ state = param_to_state[id(p)]
323
+ if state.qk_clip_state is None:
324
+ continue
325
+ scales_full = compute_scales(p, state.qk_clip_state)
326
  if scales_full is not None:
327
  ratio = p.shape[0] // scales_full.shape[0]
328
  idx0 = state.rank_indices[rank][0]
 
335
  p._local_tensor.mul_(row_scales.view(-1, 1))
336
 
337
 
338
+ # ======================================================================
339
+ # Pre-launch helper for overlapping first chunk's gather with other work.
340
+ # ======================================================================
341
+
342
+
343
+ @torch.no_grad()
344
+ def prelaunch_first_gather(
345
+ params: list[DTensor],
346
+ param_to_state: dict[int, _muon_state],
347
+ rank: int,
348
+ none_grad: bool,
349
+ ) -> tuple[dist.Work, torch.Tensor, dict[int, torch.Tensor | None], list[int]]:
350
+ """Launch the first chunk's A2A gather early for overlap with other compute.
351
+
352
+ Call this *before* expensive GPU work (e.g. batched expert NS) so that
353
+ the NCCL all-to-all runs concurrently on the NCCL stream while the
354
+ default stream executes compute.
355
+
356
+ Returns the same 4-tuple that ``_launch_gather`` produces, which should
357
+ be passed as ``prelaunch_gather`` to :func:`muon_chunk_pipeline`.
358
+ """
359
+ process_group = param_to_state[id(params[0])].process_group
360
+ num_ranks = dist.get_world_size(group=process_group)
361
+ owned_params = [
362
+ p for p in params if param_to_state[id(p)].worker_rank == rank
363
+ ]
364
+
365
+ with record_function("muon::prelaunch_gather"):
366
+ work, recv_buf, gathered_grads, recv_counts = _launch_gather(
367
+ params, owned_params, param_to_state, rank, num_ranks,
368
+ process_group)
369
+
370
+ if none_grad:
371
+ for p in params:
372
+ p.grad = None
373
+
374
+ return work, recv_buf, gathered_grads, recv_counts
375
+
376
+
377
  # ======================================================================
378
  # Main generator – thin orchestrator that wires stages together.
379
  # ======================================================================
 
388
  lr: float,
389
  weight_decay: float,
390
  none_grad: bool,
391
+ prelaunch_gather: tuple | None = None,
392
  ) -> Generator[None, None, None]:
393
  """Process one chunk of parameters through the full Muon pipeline.
394
 
 
405
  runs concurrently on the NCCL stream — no separate ``comm_stream``
406
  is required.
407
 
408
+ If ``prelaunch_gather`` is provided, the gather was already launched
409
+ by :func:`prelaunch_first_gather` and we skip launching it again.
410
+
411
  Yields exactly **2** times:
412
 
413
+ 1. After launching async all-to-all gather (or immediately if pre-launched).
414
  2. After launching async all-to-all scatter.
415
  """
416
  process_group = param_to_state[id(params[0])].process_group
 
419
  p for p in params if param_to_state[id(p)].worker_rank == rank
420
  ]
421
 
422
+ if prelaunch_gather is not None:
423
+ # Gather was pre-launched; none_grad already handled by caller.
424
+ work, recv_buf, gathered_grads, recv_counts = prelaunch_gather
425
+ else:
426
+ # Normal path: launch async gather.
427
+ with record_function("muon::launch_gather"):
428
+ work, recv_buf, gathered_grads, recv_counts = _launch_gather(
429
+ params, owned_params, param_to_state, rank, num_ranks,
430
+ process_group)
431
+
432
+ if none_grad:
433
+ for p in params:
434
+ p.grad = None
435
 
436
  yield # --- YIELD 1: other chunks can launch their gather ---
437
 
build/torch210-cxx11-rocm70-x86_64-linux/qk_clip.py CHANGED
@@ -5,6 +5,8 @@ from dataclasses import dataclass
5
  import torch
6
  from torch.distributed.tensor import DTensor
7
 
 
 
8
  logger = logging.getLogger(__name__)
9
 
10
 
@@ -23,7 +25,7 @@ def parse_qk_layer(name: str) -> tuple[str | None, int]:
23
  'model.7.attn.k_proj.weight' -> ('k_proj', 7)
24
  'model.4.attn.v_proj.weight' -> (None, -1)
25
  """
26
- parts = name.split('.')
27
  if len(parts) < 3:
28
  return None, -1
29
 
@@ -100,23 +102,27 @@ def compute_scales(p, qk_clip_state):
100
  threshold = qk_clip_state.threshold
101
  logit = qk_clip_state.logit
102
 
103
- H_global = p.shape[0] // head_dim
104
- scales_full = torch.ones(H_global, device=p.data.device)
105
- scaling = 0
106
-
107
  for logit_idx, head_idx in enumerate(indices):
108
  v_ele = float(logit[logit_idx])
109
  if v_ele > threshold:
110
  new_scale = math.sqrt(threshold / v_ele)
111
- if new_scale < scales_full[head_idx]:
112
- scales_full[head_idx] = new_scale
113
  logger.info(
114
  f"[{kind}] Head {head_idx} exceeded threshold "
115
  f"(value={v_ele:.4f}, threshold={threshold:.4f}) -> applying scale={new_scale:.4f}"
116
  )
117
- scaling += 1
118
 
119
- return scales_full if scaling > 0 else None
 
 
 
 
 
 
 
120
 
121
 
122
  def qk_clip(p, scales, head_dim):
 
5
  import torch
6
  from torch.distributed.tensor import DTensor
7
 
8
+ from .core import normalize_fqn
9
+
10
  logger = logging.getLogger(__name__)
11
 
12
 
 
25
  'model.7.attn.k_proj.weight' -> ('k_proj', 7)
26
  'model.4.attn.v_proj.weight' -> (None, -1)
27
  """
28
+ parts = normalize_fqn(name).split('.')
29
  if len(parts) < 3:
30
  return None, -1
31
 
 
102
  threshold = qk_clip_state.threshold
103
  logit = qk_clip_state.logit
104
 
105
+ # Check if any head exceeds threshold before allocating.
106
+ head_scales = {}
 
 
107
  for logit_idx, head_idx in enumerate(indices):
108
  v_ele = float(logit[logit_idx])
109
  if v_ele > threshold:
110
  new_scale = math.sqrt(threshold / v_ele)
111
+ if head_idx not in head_scales or new_scale < head_scales[head_idx]:
112
+ head_scales[head_idx] = new_scale
113
  logger.info(
114
  f"[{kind}] Head {head_idx} exceeded threshold "
115
  f"(value={v_ele:.4f}, threshold={threshold:.4f}) -> applying scale={new_scale:.4f}"
116
  )
 
117
 
118
+ if not head_scales:
119
+ return None
120
+
121
+ H_global = p.shape[0] // head_dim
122
+ scales_full = torch.ones(H_global, device=p.data.device)
123
+ for head_idx, scale in head_scales.items():
124
+ scales_full[head_idx] = scale
125
+ return scales_full
126
 
127
 
128
  def qk_clip(p, scales, head_dim):
build/torch210-cxx11-rocm71-x86_64-linux/_ops.py CHANGED
@@ -1,9 +1,9 @@
1
  import torch
2
- from . import _optimizer_7aef62f_dirty
3
- ops = torch.ops._optimizer_7aef62f_dirty
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
- return f"_optimizer_7aef62f_dirty::{op_name}"
 
1
  import torch
2
+ from . import _optimizer_5b58933_dirty
3
+ ops = torch.ops._optimizer_5b58933_dirty
4
 
5
def add_op_namespace_prefix(op_name: str):
    """Return *op_name* qualified with this extension's op namespace."""
    namespace = "_optimizer_5b58933_dirty"
    return f"{namespace}::{op_name}"
build/torch210-cxx11-rocm71-x86_64-linux/{_optimizer_7aef62f_dirty.abi3.so → _optimizer_5b58933_dirty.abi3.so} RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:e67022789ddd9296552fc5ab4075ce96b8b00b75bce057c707e5b5076bbde734
3
  size 1866112
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f41709878a4def27b12f4f9a4f5b767027fb33141e775f64ad04d434fcbe33d9
3
  size 1866112
build/torch210-cxx11-rocm71-x86_64-linux/adamw.py CHANGED
@@ -1,8 +1,12 @@
 
1
  from collections import defaultdict
2
  from typing import cast
3
 
4
  import torch
5
  from torch.distributed.tensor import DTensor
 
 
 
6
 
7
 
8
  def fused_adamw(
@@ -72,54 +76,72 @@ def fused_adamw(
72
  )
73
 
74
 
75
- def step_adamw_params(optimizer_state, params, group):
76
- """Run fused AdamW on a list of parameters sharing the same placement.
 
77
 
78
- Args:
79
- optimizer_state: The optimizer's state dict (self.state in Muon).
80
- params: List of parameters to update.
81
- group: Parameter group dict with lr, adamw_betas, adamw_eps, weight_decay.
82
- """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
83
  params_with_grads = []
84
  grads = []
85
  moment1 = []
86
  moment2 = []
87
- max_exp_avg_sqs = []
88
  state_steps = []
89
- lr = group["lr"]
90
- beta1, beta2 = group["adamw_betas"]
91
- eps = group["adamw_eps"]
92
- weight_decay = group["weight_decay"]
93
 
94
  for p in params:
95
  g = p.grad
96
  if g is None:
97
  continue
98
  state = optimizer_state[p]
99
- params_with_grads.append(p)
100
- grads.append(g)
101
  if "step" not in state:
102
- state["step"] = (torch.zeros((),
103
- dtype=torch.float32,
104
- device=p.device))
105
  state["moment1"] = torch.zeros_like(g)
106
  state["moment2"] = torch.zeros_like(g)
107
- moment1.append(state["moment1"])
108
- moment2.append(state["moment2"])
109
  if not isinstance(state["step"], torch.Tensor):
110
- step_tensor = torch.tensor(state["step"],
111
- dtype=torch.float32,
112
- device=p.device)
113
- else:
114
- step_tensor = state["step"]
115
- state_steps.append(step_tensor)
 
 
 
 
 
 
116
 
117
  fused_adamw(
118
  params_with_grads,
119
  grads,
120
  moment1,
121
  moment2,
122
- max_exp_avg_sqs,
123
  state_steps,
124
  amsgrad=False,
125
  beta1=beta1,
@@ -131,24 +153,119 @@ def step_adamw_params(optimizer_state, params, group):
131
  )
132
 
133
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
134
  def step_adamw(optimizer_state, group):
135
  """Dispatch AdamW step, grouping parameters by type and placement.
136
 
 
 
 
137
  Args:
138
  optimizer_state: The optimizer's state dict (self.state in Muon).
139
  group: Parameter group dict.
140
  """
141
  params = group["params"]
 
142
 
143
- # group params with its type and placement
144
- placement_to_params: dict[tuple, list[torch.Tensor]] = defaultdict(list)
145
- for p in params:
146
- match p:
147
- case DTensor():
148
- placement_to_params[tuple([p.placements,
149
- p.device_mesh])].append(p)
150
- case torch.Tensor():
151
- placement_to_params[tuple([torch.Tensor, None])].append(p)
152
-
153
- for group_params in placement_to_params.values():
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
154
  step_adamw_params(optimizer_state, group_params, group)
 
1
+ import logging
2
  from collections import defaultdict
3
  from typing import cast
4
 
5
  import torch
6
  from torch.distributed.tensor import DTensor
7
+ from torch.profiler import record_function
8
+
9
+ logger = logging.getLogger(__name__)
10
 
11
 
12
  def fused_adamw(
 
76
  )
77
 
78
 
79
+ def _to_local(t):
80
+ """Unwrap DTensor to local tensor for fused ops."""
81
+ return t._local_tensor if isinstance(t, DTensor) else t
82
 
83
+
84
+ # ---------------------------------------------------------------------------
85
+ # Caches for eliminating per-step Python overhead.
86
+ #
87
+ # Placement grouping and tensor list assembly are identical every step
88
+ # (params don't change placement, moment/step tensors are the same objects
89
+ # after initialisation). We cache them keyed by id() of the param list
90
+ # stored in param_groups (stable across steps).
91
+ #
92
+ # Only gradients change each step and must be collected fresh.
93
+ # ---------------------------------------------------------------------------
94
+
95
+ # id(group["params"]) → dict[placement_key, list[param]]
96
+ _placement_cache: dict[int, dict[tuple, list]] = {}
97
+
98
+ # id(placement_group_list) → (params_local, moment1, moment2, state_steps)
99
+ _tensor_cache: dict[int, tuple[list, list, list, list]] = {}
100
+
101
+
102
def _step_adamw_params_slow(optimizer_state, params, group):
    """Uncached fallback for the rare case where some params lack grads.

    Skips params whose ``.grad`` is None, lazily initialises per-param
    state ("step", "moment1", "moment2"), and runs one fused AdamW call
    over the surviving params.

    Args:
        optimizer_state: The optimizer's state dict (self.state in Muon).
        params: Parameters to update; some may have ``grad=None``.
        group: Parameter group dict with lr, adamw_betas, adamw_eps,
            weight_decay.
    """
    params_with_grads = []
    grads = []
    moment1 = []
    moment2 = []
    state_steps = []

    for p in params:
        g = p.grad
        if g is None:
            continue
        state = optimizer_state[p]
        # DTensors are unwrapped to local shards for the fused kernel.
        params_with_grads.append(_to_local(p))
        grads.append(_to_local(g))
        if "step" not in state:
            # Lazy state init on the first step for this param.
            state["step"] = torch.zeros((),
                                        dtype=torch.float32,
                                        device=p.device)
            state["moment1"] = torch.zeros_like(g)
            state["moment2"] = torch.zeros_like(g)
        moment1.append(_to_local(state["moment1"]))
        moment2.append(_to_local(state["moment2"]))
        if not isinstance(state["step"], torch.Tensor):
            # Older checkpoints may store "step" as a plain Python number.
            state["step"] = torch.tensor(state["step"],
                                         dtype=torch.float32,
                                         device=p.device)
        state_steps.append(state["step"])

    # Nothing to do when every param was grad-less.
    if not params_with_grads:
        return

    lr = group["lr"]
    beta1, beta2 = group["adamw_betas"]
    eps = group["adamw_eps"]
    weight_decay = group["weight_decay"]

    fused_adamw(
        params_with_grads,
        grads,
        moment1,
        moment2,
        [],  # max_exp_avg_sqs is unused since amsgrad=False
        state_steps,
        amsgrad=False,
        beta1=beta1,
        beta2=beta2,
        lr=lr,
        weight_decay=weight_decay,
        eps=eps,
        maximize=False,
    )
154
 
155
 
156
def step_adamw_params(optimizer_state, params, group):
    """Run fused AdamW on a list of parameters sharing the same placement.

    After the first call, cached tensor lists (params_local, moment1,
    moment2, state_steps) are reused — only gradients are collected fresh.

    Args:
        optimizer_state: The optimizer's state dict (self.state in Muon).
        params: List of parameters to update.
        group: Parameter group dict with lr, adamw_betas, adamw_eps, weight_decay.
    """
    # Collect grads — the only thing that changes each step.
    with record_function("adamw::collect_grads"):
        grads = []
        for p in params:
            g = p.grad
            if g is None:
                # Rare: fall back to slow path that filters per-param.
                _step_adamw_params_slow(optimizer_state, params, group)
                return
            grads.append(_to_local(g))

    # NOTE(review): cache is keyed by id(params); assumes the list object
    # is stable across steps and never garbage-collected — confirm.
    tensor_key = id(params)
    if tensor_key not in _tensor_cache:
        with record_function("adamw::init_tensor_cache"):
            params_local = []
            moment1 = []
            moment2 = []
            state_steps = []

            for p in params:
                state = optimizer_state[p]
                params_local.append(_to_local(p))
                if "step" not in state:
                    # Lazy per-param state init; p.grad is non-None here
                    # (the grad collection above bailed out otherwise).
                    state["step"] = torch.zeros((),
                                                dtype=torch.float32,
                                                device=p.device)
                    state["moment1"] = torch.zeros_like(p.grad)
                    state["moment2"] = torch.zeros_like(p.grad)
                moment1.append(_to_local(state["moment1"]))
                moment2.append(_to_local(state["moment2"]))
                if not isinstance(state["step"], torch.Tensor):
                    # Older checkpoints may store "step" as a Python number.
                    state["step"] = torch.tensor(state["step"],
                                                 dtype=torch.float32,
                                                 device=p.device)
                state_steps.append(state["step"])

            _tensor_cache[tensor_key] = (params_local, moment1, moment2,
                                         state_steps)

    params_local, moment1, moment2, state_steps = _tensor_cache[tensor_key]

    lr = group["lr"]
    beta1, beta2 = group["adamw_betas"]
    eps = group["adamw_eps"]
    weight_decay = group["weight_decay"]

    with record_function("adamw::fused_adamw"):
        fused_adamw(
            params_local,
            grads,
            moment1,
            moment2,
            [],  # max_exp_avg_sqs is unused since amsgrad=False
            state_steps,
            amsgrad=False,
            beta1=beta1,
            beta2=beta2,
            lr=lr,
            weight_decay=weight_decay,
            eps=eps,
            maximize=False,
        )
229
+
230
+
231
def step_adamw(optimizer_state, group):
    """Dispatch AdamW step, grouping parameters by type and placement.

    Placement grouping is cached after the first call since params never
    change their placement between steps.

    Args:
        optimizer_state: The optimizer's state dict (self.state in Muon).
        group: Parameter group dict.
    """
    params = group["params"]
    # NOTE(review): keyed by id(params); assumes group["params"] stays the
    # same list object for the optimizer's lifetime — confirm.
    placement_key = id(params)

    if placement_key not in _placement_cache:
        with record_function("adamw::group_by_placement"):
            # DTensors are bucketed by (placements, device_mesh); plain
            # tensors all share one bucket keyed (torch.Tensor, None).
            placement_to_params: dict[tuple,
                                      list[torch.Tensor]] = defaultdict(list)
            for p in params:
                match p:
                    case DTensor():
                        logger.debug(
                            "[AdamW] DTensor param: shape=%s, placements=%s, "
                            "mesh=%s, grad=%s", p.shape, p.placements,
                            p.device_mesh.mesh_dim_names,
                            p.grad.shape if p.grad is not None else None)
                        placement_to_params[tuple(
                            [p.placements, p.device_mesh])].append(p)
                    case torch.Tensor():
                        logger.debug(
                            "[AdamW] plain param: shape=%s, grad=%s", p.shape,
                            p.grad.shape if p.grad is not None else None)
                        placement_to_params[tuple([torch.Tensor,
                                                   None])].append(p)

            logger.debug("[AdamW] %d placement groups, %d total params",
                         len(placement_to_params), len(params))

            _placement_cache[placement_key] = dict(placement_to_params)

    for group_params in _placement_cache[placement_key].values():
        step_adamw_params(optimizer_state, group_params, group)
build/torch210-cxx11-rocm71-x86_64-linux/core.py CHANGED
@@ -1,11 +1,25 @@
 
1
  import math
2
  from dataclasses import dataclass
 
3
 
4
  import torch
5
- import torch.distributed as dist
6
  from torch.distributed import ProcessGroup
7
  from torch.distributed.tensor import DTensor
8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
 
10
  @dataclass
11
  class _muon_state:
@@ -17,26 +31,71 @@ class _muon_state:
17
  qk_clip_state: torch.Tensor | None = None
18
 
19
 
20
- def update_g(optimizer_state, p, g, group, momentum):
21
- """Apply momentum update to gradient.
 
 
 
 
 
 
22
 
23
- Args:
24
- optimizer_state: The optimizer's state dict (self.state in Muon).
25
- p: Parameter tensor.
26
- g: Gradient tensor.
27
- group: Parameter group dict.
28
- momentum: Momentum coefficient.
29
 
30
- Returns:
31
- Momentum-updated gradient tensor.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
  """
33
- state = optimizer_state[p]
34
- buf = state.setdefault("momentum_buffer", torch.zeros_like(g))
35
- torch.add(g, buf, alpha=momentum, out=buf)
36
- if group["nesterov"]:
37
- g.add_(buf, alpha=momentum)
38
- return g
39
- return buf
 
 
 
 
 
 
 
 
 
 
 
40
 
41
 
42
  def update_p(p, u, lr, adjusted_lr, weight_decay):
@@ -49,14 +108,13 @@ def update_p(p, u, lr, adjusted_lr, weight_decay):
49
  adjusted_lr: Size-adjusted learning rate.
50
  weight_decay: Weight decay coefficient.
51
  """
52
- if isinstance(p, torch.nn.Parameter):
53
- # apply weight decay
54
- p.data.mul_(1 - lr * weight_decay)
55
- # apply update
56
- p.data.add_(u, alpha=-adjusted_lr)
57
- else:
58
- p.mul_(1 - lr * weight_decay)
59
- p.add_(u, alpha=-adjusted_lr)
60
 
61
 
62
  def adjust_lr_for_muon(lr, param_shape):
@@ -77,14 +135,55 @@ def adjust_lr_for_muon(lr, param_shape):
77
  return adjusted_lr
78
 
79
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
80
  def default_is_muon(name, x, expert_keys=None):
81
- skip_keys = ["embed_tokens", "lm_head", "tok_embeddings", "output"]
82
- if any(key in name for key in skip_keys):
 
 
 
 
 
 
 
 
 
 
 
 
 
83
  return False
84
  effective_ndim = x.ndim
85
- if expert_keys and any(key in name for key in expert_keys):
 
86
  effective_ndim -= 1
87
- return effective_ndim >= 2
 
 
 
 
 
88
 
89
 
90
  def get_default_muon_param_groups(model, is_muon_func=None, expert_keys=None):
@@ -92,7 +191,7 @@ def get_default_muon_param_groups(model, is_muon_func=None, expert_keys=None):
92
  is_muon_func = lambda n, x: default_is_muon(n, x, expert_keys)
93
 
94
  muon_params, muon_names = [], []
95
- non_muon_params = []
96
 
97
  for n, p in model.named_parameters():
98
  if not p.requires_grad:
@@ -102,6 +201,10 @@ def get_default_muon_param_groups(model, is_muon_func=None, expert_keys=None):
102
  muon_names.append(n)
103
  else:
104
  non_muon_params.append(p)
 
 
 
 
105
 
106
  return [
107
  {
 
1
+ import logging
2
  import math
3
  from dataclasses import dataclass
4
+ from typing import List
5
 
6
  import torch
 
7
  from torch.distributed import ProcessGroup
8
  from torch.distributed.tensor import DTensor
9
 
10
# torch.compile wraps modules as OptimizedModule, inserting "_orig_mod" into
# parameter FQNs. Activation checkpointing similarly inserts
# "_checkpoint_wrapped_module". Strip these so name-based matching (skip_keys,
# expert_keys, QK layer parsing) works regardless of wrapper nesting.
_WRAPPER_PARTS = frozenset({"_orig_mod", "_checkpoint_wrapped_module"})

logger = logging.getLogger(__name__)


def normalize_fqn(name: str) -> str:
    """Drop torch.compile / checkpoint wrapper components from a parameter FQN."""
    kept = [part for part in name.split(".") if part not in _WRAPPER_PARTS]
    return ".".join(kept)
22
+
23
 
24
  @dataclass
25
  class _muon_state:
 
31
  qk_clip_state: torch.Tensor | None = None
32
 
33
 
34
def _batch_momentum(
    grads: List[torch.Tensor],
    momentum_bufs: List[torch.Tensor],
    momentum: torch.Tensor,
) -> None:
    """Batched momentum update (no nesterov).

    In-place: for each pair, buf = buf * momentum + grad, using
    ``_foreach_`` ops so all tensors are updated in one batched launch.
    """
    torch._foreach_mul_(momentum_bufs, momentum)
    torch._foreach_add_(momentum_bufs, grads)
42
 
 
 
 
 
 
 
43
 
44
def _batch_momentum_nesterov(
    grads: List[torch.Tensor],
    momentum_bufs: List[torch.Tensor],
    momentum: torch.Tensor,
) -> None:
    """Batched momentum update with nesterov correction.

    In-place: buf = buf * momentum + grad, then grad += buf * momentum
    (the nesterov look-ahead). Both input lists are mutated.
    """
    torch._foreach_mul_(momentum_bufs, momentum)
    torch._foreach_add_(momentum_bufs, grads)
    nesterov_terms = torch._foreach_mul(momentum_bufs, momentum)
    torch._foreach_add_(grads, nesterov_terms)
54
+
55
+
56
# Lazily populated: nesterov flag → torch.compile'd momentum kernel.
_compiled_momentum: dict[bool, callable] = {}
# Module-level switch read by batch_pre_ortho(); toggled via
# set_momentum_compile().
_use_momentum_compile = True


def set_momentum_compile(enabled: bool):
    """Toggle torch.compile for batched momentum."""
    global _use_momentum_compile
    _use_momentum_compile = enabled
64
+
65
+
66
def batch_pre_ortho(
    grads: List[torch.Tensor],
    momentum_bufs: List[torch.Tensor],
    momentum: torch.Tensor,
    nesterov: bool,
) -> None:
    """Batched momentum update on lists of plain tensors.

    Mirrors dion's ``muon_update_pre_orthogonalize``.
    Inputs must be plain CUDA tensors (not DTensor).
    Modifies ``momentum_bufs`` and (for nesterov) ``grads`` in-place.

    When compile is enabled, uses separately compiled functions for
    nesterov=True/False to avoid graph breaks from the branch.
    """
    fn = _batch_momentum_nesterov if nesterov else _batch_momentum
    if _use_momentum_compile:
        # Compile each variant once and memoise, keyed by the nesterov flag.
        if nesterov not in _compiled_momentum:
            _compiled_momentum[nesterov] = torch.compile(fn)
        fn = _compiled_momentum[nesterov]
    fn(grads, momentum_bufs, momentum)
87
+
88
+
89
+ def _update_p_impl(p_data, u_data, lr, adjusted_lr, weight_decay):
90
+ """Weight-decay + update on plain tensors.
91
+
92
+ Not compiled: per-param @torch.compile caused ~0.25ms TorchDynamo cache
93
+ lookup per call × 256+ params = massive overhead. The pipeline path uses
94
+ batched _foreach_* ops instead; this function remains for base() and
95
+ distributed_muon().
96
+ """
97
+ p_data.mul_(1 - lr * weight_decay)
98
+ p_data.add_(u_data, alpha=-adjusted_lr)
99
 
100
 
101
  def update_p(p, u, lr, adjusted_lr, weight_decay):
 
108
  adjusted_lr: Size-adjusted learning rate.
109
  weight_decay: Weight decay coefficient.
110
  """
111
+ # Unwrap Parameter -> underlying data tensor.
112
+ p_data = p.data if isinstance(p, torch.nn.Parameter) else p
113
+ # Unwrap DTensor -> local CUDA tensor for compiled kernel.
114
+ if isinstance(p_data, DTensor):
115
+ p_data = p_data._local_tensor
116
+ u_data = u._local_tensor if isinstance(u, DTensor) else u
117
+ _update_p_impl(p_data, u_data, lr, adjusted_lr, weight_decay)
 
118
 
119
 
120
  def adjust_lr_for_muon(lr, param_shape):
 
135
  return adjusted_lr
136
 
137
 
138
+ def _match_key(parts, key):
139
+ """Check if key matches as contiguous components in parts.
140
+
141
+ Single-component keys (e.g. "experts") match any single component.
142
+ Multi-component keys (e.g. "experts.w1") match as a contiguous subsequence.
143
+ """
144
+ key_parts = key.split(".")
145
+ key_len = len(key_parts)
146
+ if key_len == 1:
147
+ return key in parts
148
+ return any(parts[i:i + key_len] == key_parts
149
+ for i in range(len(parts) - key_len + 1))
150
+
151
+
152
def is_expert_param(name, expert_keys):
    """Return True when *name* matches any expert key at component level."""
    if not expert_keys:
        return False
    components = normalize_fqn(name).split(".")
    for key in expert_keys:
        if _match_key(components, key):
            return True
    return False
158
+
159
+
160
def default_is_muon(name, x, expert_keys=None):
    """Decide whether parameter *name*/*x* should be optimised with Muon.

    Parameters whose FQN contains a skip key (embeddings, output heads and
    a few special projections) always go to AdamW. Everything else uses
    Muon when its effective ndim is >= 2; expert params (matched via
    ``expert_keys``) lose one dim from the count.

    Args:
        name: Parameter FQN; wrapper components are stripped before matching.
        x: The parameter tensor.
        expert_keys: Optional component keys marking expert (MoE) params.

    Returns:
        True when the parameter should be handled by Muon.
    """
    normalized = normalize_fqn(name)
    parts = normalized.split(".")
    skip_keys = [
        "embed_tokens",
        "lm_head",
        "tok_embeddings",
        "output",
        "mhc_attn",
        "mhc_ffn",
        "lambda_proj",
    ]
    # Component-level match (exact path component, not substring).
    if any(key in parts for key in skip_keys):
        logger.info(
            "[is_muon] %s (orig: %s): skip (matched skip_key), ndim=%d",
            normalized, name, x.ndim)
        return False
    effective_ndim = x.ndim
    is_expert = is_expert_param(name, expert_keys)
    if is_expert:
        # Expert tensors stack per-expert matrices along dim 0.
        effective_ndim -= 1
    result = effective_ndim >= 2
    logger.info(
        "[is_muon] %s (orig: %s): ndim=%d, expert=%s, effective_ndim=%d → %s",
        normalized, name, x.ndim, is_expert, effective_ndim,
        "Muon" if result else "AdamW")
    return result
187
 
188
 
189
  def get_default_muon_param_groups(model, is_muon_func=None, expert_keys=None):
 
191
  is_muon_func = lambda n, x: default_is_muon(n, x, expert_keys)
192
 
193
  muon_params, muon_names = [], []
194
+ non_muon_params, non_muon_names = [], []
195
 
196
  for n, p in model.named_parameters():
197
  if not p.requires_grad:
 
201
  muon_names.append(n)
202
  else:
203
  non_muon_params.append(p)
204
+ non_muon_names.append(n)
205
+
206
+ logger.info("[param_groups] expert_keys=%s, Muon=%d, AdamW=%d",
207
+ expert_keys, len(muon_names), len(non_muon_names))
208
 
209
  return [
210
  {
build/torch210-cxx11-rocm71-x86_64-linux/cpu_offload.py ADDED
@@ -0,0 +1,188 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """CPU offloading for optimizer states.
2
+
3
+ Manages a pinned CPU memory pool and async CUDA streams to offload
4
+ optimizer state tensors (momentum buffers, Adam moments) to CPU between
5
+ optimizer steps, freeing GPU memory.
6
+
7
+ All tracked tensors are packed into a single flat pinned CPU buffer
8
+ (per dtype). D2H and H2D copies are performed per-tensor directly
9
+ between individual GPU tensors and their slice of the CPU flat buffer
10
+ — no GPU staging buffer is allocated, so there is **no temporary GPU
11
+ memory spike** during offload or reload.
12
+
13
+ Individual tensor storages are freed after offload via
14
+ ``untyped_storage().resize_(0)``, preserving tensor identity so
15
+ downstream caches remain valid.
16
+ """
17
+
18
+ import logging
19
+ from collections import defaultdict
20
+
21
+ import torch
22
+ from torch.distributed.tensor import DTensor
23
+
24
+ logger = logging.getLogger(__name__)
25
+
26
+
27
class CPUOffloadPool:
    """Pinned CPU memory pool for async optimizer state offloading.

    Tracked tensors are grouped by dtype. Each group gets a single flat
    pinned CPU buffer. D2H / H2D copies are per-tensor (into slices of
    the flat buffer) to avoid allocating a GPU staging buffer.
    """

    def __init__(self):
        # Tensors registered via track(), in registration order.
        self._managed: list[torch.Tensor] = []
        self._storage_nbytes: dict[int, int] = {}  # id(t) → bytes

        # Per-dtype group: populated on first offload.
        # dtype → dict with keys:
        #   "indices" : list[int] managed-list indices
        #   "offsets" : list[tuple[int,int]] (start, numel) in flat buf
        #   "total"   : int total numel
        #   "cpu_flat" : Tensor pinned CPU buffer
        self._groups: dict[torch.dtype, dict] = {}

        # Dedicated stream for D2H copies; created lazily on first offload.
        self._offload_stream: torch.cuda.Stream | None = None
        self._device: torch.device | None = None
        self._initialized: bool = False
        self._logged: bool = False

    # ------------------------------------------------------------------
    @staticmethod
    def _local(t: torch.Tensor) -> torch.Tensor:
        """Unwrap DTensor to its local CUDA tensor."""
        return t._local_tensor if isinstance(t, DTensor) else t

    def _ensure_stream(self):
        # Lazily create the copy stream on the device of the first tracked
        # tensor.
        if self._offload_stream is None:
            self._offload_stream = torch.cuda.Stream(device=self._device)

    # ------------------------------------------------------------------
    def track(self, tensor: torch.Tensor):
        """Register a GPU tensor for CPU offloading. Idempotent."""
        tid = id(tensor)
        if tid in self._storage_nbytes:
            return
        local = self._local(tensor)
        if self._device is None:
            self._device = local.device
        # Remember the storage size so reload() can re-allocate it after
        # offload() shrinks the storage to zero bytes.
        self._storage_nbytes[tid] = local.untyped_storage().size()
        self._managed.append(tensor)

    # ------------------------------------------------------------------
    def _init_buffers(self):
        """Build per-dtype flat buffers on first offload."""
        # Group managed tensors by dtype.
        dtype_map: dict[torch.dtype, list[tuple[int, int]]] = defaultdict(list)
        for idx, t in enumerate(self._managed):
            local = self._local(t)
            dtype_map[local.dtype].append((idx, local.numel()))

        total_cpu_bytes = 0
        for dtype, entries in dtype_map.items():
            offsets: list[tuple[int, int]] = []
            indices: list[int] = []
            off = 0
            for idx, n in entries:
                indices.append(idx)
                offsets.append((off, n))
                off += n
            cpu_flat = torch.empty(off, dtype=dtype, device="cpu", pin_memory=True)
            self._groups[dtype] = {
                "indices": indices,
                "offsets": offsets,
                "total": off,
                "cpu_flat": cpu_flat,
            }
            total_cpu_bytes += off * cpu_flat.element_size()

        self._initialized = True
        logger.info(
            "[CPUOffload] Pool initialized: %d tensors, %d dtype group(s), "
            "%.2f MB pinned CPU memory",
            len(self._managed),
            len(self._groups),
            total_cpu_bytes / (1024**2),
        )

    # ------------------------------------------------------------------
    def offload(self):
        """Per-tensor async D2H into CPU flat buffer, then free GPU storage."""
        if not self._managed:
            return
        if not self._initialized:
            self._init_buffers()
        self._ensure_stream()

        # Offload stream waits for compute to finish.
        compute_event = torch.cuda.current_stream(
            self._device).record_event()
        self._offload_stream.wait_event(compute_event)

        offloaded_bytes = 0

        # Per-tensor D2H copies directly into CPU flat buffer slices.
        # No GPU staging buffer → no temporary GPU memory spike.
        with torch.cuda.stream(self._offload_stream):
            for dtype, grp in self._groups.items():
                indices = grp["indices"]
                offsets = grp["offsets"]
                cpu_flat = grp["cpu_flat"]

                for i, mgd_idx in enumerate(indices):
                    local = self._local(self._managed[mgd_idx])
                    off, n = offsets[i]
                    # NOTE(review): reshape(-1) copies for non-contiguous
                    # tensors; looks like tracked buffers are contiguous —
                    # confirm.
                    cpu_flat[off:off + n].copy_(
                        local.reshape(-1), non_blocking=True)

                offloaded_bytes += grp["total"] * cpu_flat.element_size()

        # Wait for all D2H copies to land, then free GPU storage.
        # Ordering matters: resizing before the copies complete would
        # release memory still being read.
        self._offload_stream.synchronize()
        for t in self._managed:
            self._local(t).untyped_storage().resize_(0)

        if not self._logged:
            logger.info("[CPUOffload] Offloaded %.2f MB (GPU → CPU)",
                        offloaded_bytes / (1024**2))

    # ------------------------------------------------------------------
    def reload(self):
        """Per-tensor H2D from CPU flat buffer on the default stream.

        Runs on the current (default) CUDA stream to avoid stream
        interaction issues with the parallel Muon pipeline. Since
        pinned CPU memory is the source, the copies overlap with
        GPU idle time between steps.
        """
        if not self._managed or not self._initialized:
            return

        reloaded_bytes = 0

        # Re-allocate all GPU storages first.
        for t in self._managed:
            local = self._local(t)
            local.untyped_storage().resize_(self._storage_nbytes[id(t)])

        # Per-tensor H2D copies from CPU flat buffer slices.
        # non_blocking=True with pinned source allows DMA overlap.
        for dtype, grp in self._groups.items():
            indices = grp["indices"]
            offsets = grp["offsets"]
            cpu_flat = grp["cpu_flat"]

            for i, mgd_idx in enumerate(indices):
                local = self._local(self._managed[mgd_idx])
                off, n = offsets[i]
                # NOTE(review): if local were non-contiguous, reshape(-1)
                # would return a copy and this in-place write would be
                # lost — confirm buffers stay contiguous.
                local.reshape(-1).copy_(
                    cpu_flat[off:off + n], non_blocking=True)

            reloaded_bytes += grp["total"] * cpu_flat.element_size()

        if not self._logged:
            logger.info("[CPUOffload] Reloaded %.2f MB (CPU → GPU)",
                        reloaded_bytes / (1024**2))
            self._logged = True
build/torch210-cxx11-rocm71-x86_64-linux/distributed/utils.py CHANGED
@@ -72,12 +72,6 @@ def get_slices_of_dtensor(
72
  else:
73
  curr_size = target.size()[shard_dim]
74
 
75
- if curr_size % num_chunks != 0:
76
- raise NotImplementedError(
77
- f"Dimension size {curr_size} is not divisible "
78
- f"by number of ranks {num_chunks} for shard "
79
- f"placement on dim {shard_dim}. (shape: {target.shape})")
80
-
81
  # Compute indices for this level of sharding
82
  if isinstance(placement, _StridedShard):
83
  _shard_size, offsets = _StridedShard.local_shard_size_and_offset(
 
72
  else:
73
  curr_size = target.size()[shard_dim]
74
 
 
 
 
 
 
 
75
  # Compute indices for this level of sharding
76
  if isinstance(placement, _StridedShard):
77
  _shard_size, offsets = _StridedShard.local_shard_size_and_offset(