
Commit 788ebc4

fix: load fused moe
1 parent 4c57054 commit 788ebc4

4 files changed (+263 −51 lines)


python/sgl_jax/srt/layers/fused_moe.py

Lines changed: 67 additions & 1 deletion
@@ -182,8 +182,64 @@ def __call__(self, hidden_states: jax.Array, router_logits: jax.Array) -> jax.Array:
         Returns:
             MoE layer output, same shape as hidden_states
         """
+
+        # Debug: Print weights using callback to ensure execution
+        def print_w1_stats(w1):
+            print(f"[FusedEPMoE Layer {self.layer_id}] w1 shape: {w1.shape}, dtype: {w1.dtype}")
+            print(f"[FusedEPMoE Layer {self.layer_id}] w1 full:\n{w1}")
+
+        def print_w2_stats(w2):
+            print(f"[FusedEPMoE Layer {self.layer_id}] w2 shape: {w2.shape}, dtype: {w2.dtype}")
+            print(f"[FusedEPMoE Layer {self.layer_id}] w2 full:\n{w2}")
+
+        def print_w1_gate_stats(w1_gate):
+            print(
+                f"[FusedEPMoE Layer {self.layer_id}] w1_gate shape: {w1_gate.shape}, dtype: {w1_gate.dtype}"
+            )
+            print(f"[FusedEPMoE Layer {self.layer_id}] w1_gate full:\n{w1_gate}")
+
+        def print_w1_up_stats(w1_up):
+            print(
+                f"[FusedEPMoE Layer {self.layer_id}] w1_up shape: {w1_up.shape}, dtype: {w1_up.dtype}"
+            )
+            print(f"[FusedEPMoE Layer {self.layer_id}] w1_up full:\n{w1_up}")
+
+        # jax.debug.callback(print_w1_stats, self.w1.value)
+        # jax.debug.callback(print_w2_stats, self.w2.value)
+
+        # Debug: Print w1 gate_proj and up_proj separately (E, 2, H, I)
+        # w1_gate = self.w1.value[:, 0, :, :]  # gate_proj weights
+        # w1_up = self.w1.value[:, 1, :, :]  # up_proj weights
+
+        # jax.debug.callback(print_w1_gate_stats, w1_gate)
+        # jax.debug.callback(print_w1_up_stats, w1_up)
+
         assert hidden_states.ndim == 2

+        # Debug: Input before resharding
+        def print_input_tokens(tokens):
+            print(
+                f"[FusedEPMoE Layer {self.layer_id}] Input tokens shape: {tokens.shape}, dtype: {tokens.dtype}"
+            )
+            print(f"[FusedEPMoE Layer {self.layer_id}] Input tokens full:\n{tokens}")
+
+        # jax.debug.callback(print_input_tokens, hidden_states)
+
+        hidden_states = jax.sharding.reshard(hidden_states, P("tensor", None))
+        router_logits = jax.sharding.reshard(router_logits, P("tensor", None))
+
+        # Debug: Input after resharding
+        def print_router_logits(logits):
+            print(
+                f"[FusedEPMoE Layer {self.layer_id}] Calling fused_ep_moe kernel with top_k={self.num_experts_per_tok}, renormalize={self.renormalize_topk_logits}"
+            )
+            print(
+                f"[FusedEPMoE Layer {self.layer_id}] Router logits shape: {logits.shape}, dtype: {logits.dtype}"
+            )
+            print(f"[FusedEPMoE Layer {self.layer_id}] Router logits full:\n{logits}")
+
+        # jax.debug.callback(print_router_logits, router_logits)
+
         output = fused_ep_moe(
             mesh=self.mesh,
             tokens=hidden_states,
@@ -212,4 +268,14 @@ def __call__(self, hidden_states: jax.Array, router_logits: jax.Array) -> jax.Array:
             # tp_axis_name="data",
         )

-        return output
+        final_output = jax.sharding.reshard(output, P(None))
+
+        def print_final_output(out):
+            print(
+                f"[FusedEPMoE Layer {self.layer_id}] Final output shape: {out.shape}, dtype: {out.dtype}"
+            )
+            print(f"[FusedEPMoE Layer {self.layer_id}] Final output full:\n{out}")
+
+        # jax.debug.callback(print_final_output, final_output)
+
+        return final_output
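
Note on the debug hooks above: inside a jitted layer a bare print() fires only at trace time and shows tracers, which is why the values are routed through jax.debug.callback. A minimal, self-contained sketch of the pattern (the toy layer and names below are illustrative, not code from this repo):

import jax
import jax.numpy as jnp


@jax.jit
def layer(tokens, w):
    hidden = tokens @ w

    def print_hidden(h):
        # Runs on the host with a concrete value, even though layer() is jitted.
        print(f"[debug] hidden shape: {h.shape}, dtype: {h.dtype}")
        print(f"[debug] hidden:\n{h}")

    # A plain print(hidden) here would only show a tracer at trace time;
    # jax.debug.callback defers the print until real values are available.
    jax.debug.callback(print_hidden, hidden)
    return jax.nn.relu(hidden)


out = layer(jnp.ones((4, 8)), jnp.ones((8, 16)))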

python/sgl_jax/srt/layers/moe.py

Lines changed: 17 additions & 0 deletions
@@ -258,6 +258,23 @@ def _detect_device_capabilities(self):
         return False, "cpu"

     def __call__(self, hidden_states, topk_weights, topk_ids) -> jax.Array:
+        # Debug: Print weights for EPMoE using callback
+        def print_wi_0_stats(wi_0):
+            print(f"[EPMoE Layer {self.layer_id}] wi_0 shape: {wi_0.shape}, dtype: {wi_0.dtype}")
+            print(f"[EPMoE Layer {self.layer_id}] wi_0 full:\n{wi_0}")
+
+        def print_wi_1_stats(wi_1):
+            print(f"[EPMoE Layer {self.layer_id}] wi_1 shape: {wi_1.shape}, dtype: {wi_1.dtype}")
+            print(f"[EPMoE Layer {self.layer_id}] wi_1 full:\n{wi_1}")
+
+        def print_wo_stats(wo):
+            print(f"[EPMoE Layer {self.layer_id}] wo shape: {wo.shape}, dtype: {wo.dtype}")
+            print(f"[EPMoE Layer {self.layer_id}] wo full:\n{wo}")
+
+        jax.debug.callback(print_wi_0_stats, self.wi_0.value)
+        jax.debug.callback(print_wi_1_stats, self.wi_1.value)
+        jax.debug.callback(print_wo_stats, self.wo.value)
+
         with jax.sharding.use_abstract_mesh(self.updated_mesh):
             hidden_states_reshard = jax.sharding.reshard(hidden_states, P(None))
             topk_weights_reshard = jax.sharding.reshard(topk_weights, P(None))
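
For reference on the P(...) specs used in the reshard calls here and in fused_moe.py: P(None) replicates an array across the mesh, while P("tensor", None) splits its leading axis over the "tensor" mesh axis. A rough sketch with the long-standing NamedSharding/device_put API (the 1-D mesh below is illustrative and assumes the leading dimension divides the device count):

import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# 1-D device mesh whose single axis is named "tensor", as in the diff.
mesh = Mesh(np.array(jax.devices()), axis_names=("tensor",))

x = jnp.arange(32, dtype=jnp.float32).reshape(8, 4)

# Fully replicated: every device holds the whole (8, 4) array.
x_replicated = jax.device_put(x, NamedSharding(mesh, P(None, None)))

# Row-sharded: the first axis is split across the "tensor" devices.
x_sharded = jax.device_put(x, NamedSharding(mesh, P("tensor", None)))

print(x_replicated.sharding)
print(x_sharded.sharding)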

python/sgl_jax/srt/models/qwen3_moe.py

Lines changed: 60 additions & 24 deletions
@@ -262,12 +262,59 @@ def __call__(
         hidden_states = self.post_attention_layernorm(hidden_states)

         if self.is_moe_layer:
+            # Debug: MLP input using callback
+            def print_mlp_input(hs):
+                print(f"[Layer {self.layer_id}] MLP Input shape: {hs.shape}, dtype: {hs.dtype}")
+                print(f"[Layer {self.layer_id}] MLP Input full:\n{hs}")
+
+            # jax.debug.callback(print_mlp_input, hidden_states)
+
             router_logits = self.moe_gate(hidden_states)
+
+            def print_router_logits_fn(logits):
+                print(
+                    f"[Layer {self.layer_id}] Router Logits shape: {logits.shape}, dtype: {logits.dtype}"
+                )
+                print(f"[Layer {self.layer_id}] Router Logits full:\n{logits}")
+
+            # jax.debug.callback(print_router_logits_fn, router_logits)
+
             if self.use_fused:
+                print(f"[Layer {self.layer_id}] Using FUSED MoE backend")
                 hidden_states = self.mlp(hidden_states, router_logits)
+
+                def print_fused_output(out):
+                    print(
+                        f"[Layer {self.layer_id}] Fused MoE Output shape: {out.shape}, dtype: {out.dtype}"
+                    )
+                    print(f"[Layer {self.layer_id}] Fused MoE Output full:\n{out}")
+
+                # jax.debug.callback(print_fused_output, hidden_states)
             else:
+                print(f"[Layer {self.layer_id}] Using EPMoE backend")
                 topk_weights, topk_ids = self.topk(router_logits)
+
+                def print_topk(weights, ids):
+                    print(
+                        f"[Layer {self.layer_id}] TopK weights shape: {weights.shape}, dtype: {weights.dtype}"
+                    )
+                    print(f"[Layer {self.layer_id}] TopK weights full:\n{weights}")
+                    print(
+                        f"[Layer {self.layer_id}] TopK ids shape: {ids.shape}, dtype: {ids.dtype}"
+                    )
+                    print(f"[Layer {self.layer_id}] TopK ids full:\n{ids}")
+
+                # jax.debug.callback(print_topk, topk_weights, topk_ids)
+
                 hidden_states = self.mlp(hidden_states, topk_weights, topk_ids)
+
+                def print_epmoe_output(out):
+                    print(
+                        f"[Layer {self.layer_id}] EPMoE Output shape: {out.shape}, dtype: {out.dtype}"
+                    )
+                    print(f"[Layer {self.layer_id}] EPMoE Output full:\n{out}")
+
+                # jax.debug.callback(print_epmoe_output, hidden_states)
         else:
             hidden_states = self.mlp(hidden_states)
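
In the EPMoE branch, self.topk turns the router logits into per-token (topk_weights, topk_ids) before they reach the MoE MLP. As a point of reference only (this is the generic technique, not the repo's implementation), a softmax top-k router with optional renormalization looks like:

import jax
import jax.numpy as jnp


def topk_router(router_logits, top_k, renormalize=True):
    # Softmax over the expert dimension, then keep the k largest probabilities.
    probs = jax.nn.softmax(router_logits, axis=-1)        # (num_tokens, num_experts)
    topk_weights, topk_ids = jax.lax.top_k(probs, top_k)  # (num_tokens, top_k) each
    if renormalize:
        # Make the selected experts' weights sum to 1 for every token.
        topk_weights = topk_weights / jnp.sum(topk_weights, axis=-1, keepdims=True)
    return topk_weights, topk_ids


logits = jax.random.normal(jax.random.PRNGKey(0), (4, 8))  # 4 tokens, 8 experts
weights, ids = topk_router(logits, top_k=2)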

@@ -510,37 +557,28 @@ def _create_moe_layer_mappings(self, layer_idx: int, is_mlp_layer: bool) -> dict
             # Fused MoE Mapping
             # w1: fused gate_proj(w1) + up_proj(w3) -> (num_experts, 2, hidden, intermediate)
             # w2: down_proj(w2) -> (num_experts, intermediate, hidden)
-
-            # 1. Fused w1 (gate + up)
-            target_path_w1 = [f"{target_prefix}.mlp.w1"]
-            # Add source keys for gate_proj and up_proj
-            for name in ["gate_proj", "up_proj"]:
-                target_path_w1.extend(
-                    [f"{prefix}.mlp.experts.{i}.{name}.weight" for i in range(num_experts)]
-                )
-
+            w1_expert_keys = []
+            for expert_type in ["gate_proj", "up_proj"]:
+                w1_expert_keys = w1_expert_keys + [
+                    f"{prefix}.mlp.experts.{i}.{expert_type}.weight" for i in range(num_experts)
+                ]
             mappings[f"__MOE_EXPERTS__{prefix}.mlp.w1"] = WeightMapping(
-                target_path=target_path_w1,
+                target_path=[f"{target_prefix}.mlp.w1"] + w1_expert_keys,
                 sharding=("tensor", None, None, None),  # (E, 2, H, I)
                 transpose=True,
-                concat_axis=0,
                 fuse_moe_weights=True,
                 fuse_gate_up=("gate_proj", "up_proj"),
             )
-
-            # 2. w2 (down)
-            target_path_w2 = [f"{target_prefix}.mlp.w2"]
-            target_path_w2.extend(
-                [f"{prefix}.mlp.experts.{i}.down_proj.weight" for i in range(num_experts)]
-            )
-
+            w2_expert_keys = [
+                f"{prefix}.mlp.experts.{i}.down_proj.weight" for i in range(num_experts)
+            ]
             mappings[f"__MOE_EXPERTS__{prefix}.mlp.w2"] = WeightMapping(
-                target_path=target_path_w2,
+                target_path=[f"{target_prefix}.mlp.w2"] + w2_expert_keys,
                 sharding=("tensor", None, None),  # (E, I, H)
                 transpose=True,
-                concat_axis=-1,
             )
         else:
+            # EPMoE mapping - always use expert sharding
            for expert_type in ["gate_proj", "up_proj", "down_proj"]:
                 target_name = {
                     "gate_proj": "wi_0",
@@ -553,9 +591,9 @@ def _create_moe_layer_mappings(self, layer_idx: int, is_mlp_layer: bool) -> dict
                 ]

                 if expert_type == "down_proj":
-                    sharding = ("tensor", None, None)
+                    sharding = ("expert", "tensor", None)
                 else:
-                    sharding = ("tensor", None, None)
+                    sharding = ("expert", None, "tensor")

                 mappings[f"__MOE_EXPERTS__{prefix}.mlp.{target_name}"] = WeightMapping(
                     target_path=[f"{target_prefix}.mlp.{target_name}"] + expert_keys,
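
The sharding change above spreads the EPMoE weights over two mesh axes: experts over "expert" and one matrix dimension over "tensor". A rough placement sketch with NamedSharding (the 1 x N mesh and sizes are illustrative and assume the sharded dimensions divide the corresponding mesh axis sizes):

import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# 2-D (expert, tensor) mesh; 1 x N here only so the sketch runs on any host.
devices = np.array(jax.devices()).reshape(1, -1)
mesh = Mesh(devices, axis_names=("expert", "tensor"))

num_experts, hidden, intermediate = 8, 16, 32
wi_0 = jnp.ones((num_experts, hidden, intermediate))  # (E, H, I)
wo = jnp.ones((num_experts, intermediate, hidden))    # (E, I, H)

# wi_0 / wi_1: experts over "expert", the intermediate dim over "tensor".
wi_0 = jax.device_put(wi_0, NamedSharding(mesh, P("expert", None, "tensor")))

# wo (down_proj): experts over "expert", its intermediate input dim over "tensor".
wo = jax.device_put(wo, NamedSharding(mesh, P("expert", "tensor", None)))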
@@ -598,8 +636,6 @@ def __call__(
         logits_metadata: LogitsMetadata,
     ):
         hidden_states, layers_kv_fused = self.model(forward_batch, token_to_kv_pool)
-        hidden_states = jax.sharding.reshard(hidden_states, jax.sharding.PartitionSpec(None, None))
-
         if not getattr(self.config, "tie_word_embeddings", False):
             output = self.logits_processor(hidden_states, self.lm_head, logits_metadata)
         else:
