
Commit b4a2aa6
fix: load fused moe
1 parent 4c57054

File tree: 3 files changed, +138 -51 lines

python/sgl_jax/srt/layers/fused_moe.py
Lines changed: 5 additions & 1 deletion

@@ -184,6 +184,9 @@ def __call__(self, hidden_states: jax.Array, router_logits: jax.Array) -> jax.Array
         """
         assert hidden_states.ndim == 2
 
+        hidden_states = jax.sharding.reshard(hidden_states, P("tensor", None))
+        router_logits = jax.sharding.reshard(router_logits, P("tensor", None))
+
         output = fused_ep_moe(
             mesh=self.mesh,
             tokens=hidden_states,
@@ -212,4 +215,5 @@ def __call__(self, hidden_states: jax.Array, router_logits: jax.Array) -> jax.Array
             # tp_axis_name="data",
         )
 
-        return output
+        final_output = jax.sharding.reshard(output, P(None))
+        return final_output
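
For context on the partition specs used above: a minimal, self-contained sketch of what P("tensor", None) and P(None) request, written with the stable NamedSharding/jax.device_put API rather than the jax.sharding.reshard helper the diff calls inside its existing mesh. The mesh layout and array shapes below are illustrative assumptions, not taken from the repo.

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Hypothetical 1-D mesh named "tensor", matching the axis name in the diff.
mesh = Mesh(np.array(jax.devices()), axis_names=("tensor",))

# Illustrative shapes: (num_tokens, hidden_size) and (num_tokens, num_experts).
# num_tokens must divide evenly across the "tensor" axis for this placement.
hidden_states = jnp.ones((8, 16))
router_logits = jnp.ones((8, 4))

# P("tensor", None): split dim 0 (tokens) across the "tensor" axis, keep dim 1 whole.
hidden_states = jax.device_put(hidden_states, NamedSharding(mesh, P("tensor", None)))
router_logits = jax.device_put(router_logits, NamedSharding(mesh, P("tensor", None)))

# P(None): no dimension is partitioned, i.e. the array is replicated on every device,
# which is the layout the new reshard of the MoE output requests.
output = jax.device_put(hidden_states, NamedSharding(mesh, P(None)))

print(hidden_states.sharding.spec)  # sharded along "tensor" on dim 0
print(output.sharding.spec)         # replicated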

python/sgl_jax/srt/models/qwen3_moe.py
Lines changed: 14 additions & 24 deletions

@@ -263,6 +263,7 @@ def __call__(
 
         if self.is_moe_layer:
             router_logits = self.moe_gate(hidden_states)
+
             if self.use_fused:
                 hidden_states = self.mlp(hidden_states, router_logits)
             else:
@@ -510,37 +511,28 @@ def _create_moe_layer_mappings(self, layer_idx: int, is_mlp_layer: bool) -> dict
             # Fused MoE Mapping
             # w1: fused gate_proj(w1) + up_proj(w3) -> (num_experts, 2, hidden, intermediate)
             # w2: down_proj(w2) -> (num_experts, intermediate, hidden)
-
-            # 1. Fused w1 (gate + up)
-            target_path_w1 = [f"{target_prefix}.mlp.w1"]
-            # Add source keys for gate_proj and up_proj
-            for name in ["gate_proj", "up_proj"]:
-                target_path_w1.extend(
-                    [f"{prefix}.mlp.experts.{i}.{name}.weight" for i in range(num_experts)]
-                )
-
+            w1_expert_keys = []
+            for expert_type in ["gate_proj", "up_proj"]:
+                w1_expert_keys = w1_expert_keys + [
+                    f"{prefix}.mlp.experts.{i}.{expert_type}.weight" for i in range(num_experts)
+                ]
             mappings[f"__MOE_EXPERTS__{prefix}.mlp.w1"] = WeightMapping(
-                target_path=target_path_w1,
+                target_path=[f"{target_prefix}.mlp.w1"] + w1_expert_keys,
                 sharding=("tensor", None, None, None),  # (E, 2, H, I)
                 transpose=True,
-                concat_axis=0,
                 fuse_moe_weights=True,
                 fuse_gate_up=("gate_proj", "up_proj"),
             )
-
-            # 2. w2 (down)
-            target_path_w2 = [f"{target_prefix}.mlp.w2"]
-            target_path_w2.extend(
-                [f"{prefix}.mlp.experts.{i}.down_proj.weight" for i in range(num_experts)]
-            )
-
+            w2_expert_keys = [
+                f"{prefix}.mlp.experts.{i}.down_proj.weight" for i in range(num_experts)
+            ]
             mappings[f"__MOE_EXPERTS__{prefix}.mlp.w2"] = WeightMapping(
-                target_path=target_path_w2,
+                target_path=[f"{target_prefix}.mlp.w2"] + w2_expert_keys,
                 sharding=("tensor", None, None),  # (E, I, H)
                 transpose=True,
-                concat_axis=-1,
             )
         else:
+            # EPMoE mapping - always use expert sharding
            for expert_type in ["gate_proj", "up_proj", "down_proj"]:
                target_name = {
                    "gate_proj": "wi_0",
@@ -553,9 +545,9 @@ def _create_moe_layer_mappings(self, layer_idx: int, is_mlp_layer: bool) -> dict
                 ]
 
                 if expert_type == "down_proj":
-                    sharding = ("tensor", None, None)
+                    sharding = ("expert", "tensor", None)
                 else:
-                    sharding = ("tensor", None, None)
+                    sharding = ("expert", None, "tensor")
 
                 mappings[f"__MOE_EXPERTS__{prefix}.mlp.{target_name}"] = WeightMapping(
                     target_path=[f"{target_prefix}.mlp.{target_name}"] + expert_keys,
@@ -598,8 +590,6 @@ def __call__(
         logits_metadata: LogitsMetadata,
     ):
         hidden_states, layers_kv_fused = self.model(forward_batch, token_to_kv_pool)
-        hidden_states = jax.sharding.reshard(hidden_states, jax.sharding.PartitionSpec(None, None))
-
         if not getattr(self.config, "tie_word_embeddings", False):
             output = self.logits_processor(hidden_states, self.lm_head, logits_metadata)
         else:
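
On the sharding tuples in the non-fused (EPMoE) branch: the change moves the expert dimension onto an "expert" mesh axis and shards the per-expert matmul dimension along "tensor". A small sketch of how such tuples map to device placement, assuming a hypothetical 2-D mesh with axes named "expert" and "tensor"; the mesh shape and weight shapes are illustrative, not taken from the repo.

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Hypothetical 2-D mesh: first axis shards experts, second shards a matmul dimension.
devices = np.array(jax.devices())
mesh = Mesh(devices.reshape(-1, 1), axis_names=("expert", "tensor"))

num_experts, hidden, intermediate = 8, 16, 32  # must divide evenly across the mesh axes

# Gate/up style weights stacked per expert, placed as ("expert", None, "tensor").
wi_0 = jnp.ones((num_experts, hidden, intermediate))
wi_0 = jax.device_put(wi_0, NamedSharding(mesh, P("expert", None, "tensor")))

# down_proj (wo) style weights stacked per expert, placed as ("expert", "tensor", None).
wo = jnp.ones((num_experts, intermediate, hidden))
wo = jax.device_put(wo, NamedSharding(mesh, P("expert", "tensor", None)))

print(wi_0.sharding.spec)  # PartitionSpec('expert', None, 'tensor')
print(wo.sharding.spec)    # PartitionSpec('expert', 'tensor', None)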

python/sgl_jax/srt/utils/weight_utils.py
Lines changed: 119 additions & 26 deletions

@@ -292,6 +292,9 @@ def load_weights_from_safetensors(
 
         nnx.update(self.model, params)
 
+        # Final verification: check all fused MoE layers
+        self._verify_fused_moe_weights(params, moe_mappings)
+
     def _process_single_moe_group(
         self,
         params: nnx.State,
@@ -358,42 +361,48 @@ def _process_fused_moe_group(
         }
         """
         target_path = mapping.target_path[0]
+        expected_hf_keys = mapping.target_path[1:]
 
         # Step 1: Process gate and up weights separately
+        # Use the predefined order from expected_hf_keys, not sorting
         gate_weights = []
         up_weights = []
 
-        # Process gate weights (w1)
-        for hf_key in sorted(grouped_weights["gate"].keys()):
-            weights = grouped_weights["gate"][hf_key]
+        gate_id, up_id = mapping.fuse_gate_up
 
-            # Concatenate TP shards
-            if mapping.concat_axis is not None and len(weights) > 1:
-                weight = jnp.concatenate(weights, axis=mapping.concat_axis)
-            else:
-                weight = weights[0]
+        # Separate expected keys into gate and up based on fuse_gate_up config
+        for hf_key in expected_hf_keys:
+            if gate_id in hf_key:
+                # This is a gate weight
+                weights = grouped_weights["gate"][hf_key]
 
-            # Transpose
-            if mapping.transpose:
-                weight = jnp.transpose(weight, (1, 0))
+                # Concatenate TP shards
+                if mapping.concat_axis is not None and len(weights) > 1:
+                    weight = jnp.concatenate(weights, axis=mapping.concat_axis)
+                else:
+                    weight = weights[0]
 
-            gate_weights.append(weight)
+                # Transpose
+                if mapping.transpose:
+                    weight = jnp.transpose(weight, (1, 0))
 
-        # Process up weights (w3)
-        for hf_key in sorted(grouped_weights["up"].keys()):
-            weights = grouped_weights["up"][hf_key]
+                gate_weights.append(weight)
 
-            # Concatenate TP shards
-            if mapping.concat_axis is not None and len(weights) > 1:
-                weight = jnp.concatenate(weights, axis=mapping.concat_axis)
-            else:
-                weight = weights[0]
+            elif up_id in hf_key:
+                # This is an up weight
+                weights = grouped_weights["up"][hf_key]
 
-            # Transpose
-            if mapping.transpose:
-                weight = jnp.transpose(weight, (1, 0))
+                # Concatenate TP shards
+                if mapping.concat_axis is not None and len(weights) > 1:
+                    weight = jnp.concatenate(weights, axis=mapping.concat_axis)
+                else:
+                    weight = weights[0]
 
-            up_weights.append(weight)
+                # Transpose
+                if mapping.transpose:
+                    weight = jnp.transpose(weight, (1, 0))
+
+                up_weights.append(weight)
 
         # Step 2: Stack to 3D tensors
         # gate_stacked: (num_experts, hidden_size, intermediate_size)
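
The rewritten loop walks mapping.target_path[1:] in its declared order and routes each HF key to the gate or up list by substring match against fuse_gate_up, instead of sorting key names. A toy sketch of that split followed by stacking into the fused (E, 2, H, I) layout; the key names, shapes, and the final jnp.stack step are illustrative assumptions (the actual Step 2/3 code sits outside this hunk).

import jax.numpy as jnp

num_experts, hidden, intermediate = 4, 8, 16
gate_id, up_id = ("gate_proj", "up_proj")

# Hypothetical per-expert HF weights, already transposed to (hidden, intermediate).
grouped = {
    f"model.layers.0.mlp.experts.{i}.{name}.weight": jnp.ones((hidden, intermediate)) * i
    for i in range(num_experts)
    for name in (gate_id, up_id)
}

# Expected key order: all gate keys first, then all up keys, as built in the mapping.
expected_hf_keys = [
    f"model.layers.0.mlp.experts.{i}.{name}.weight"
    for name in (gate_id, up_id)
    for i in range(num_experts)
]

gate_weights, up_weights = [], []
for hf_key in expected_hf_keys:
    if gate_id in hf_key:
        gate_weights.append(grouped[hf_key])
    elif up_id in hf_key:
        up_weights.append(grouped[hf_key])

# Stack each list to (E, H, I), then pair gate/up along a new axis -> (E, 2, H, I).
gate_stacked = jnp.stack(gate_weights, axis=0)
up_stacked = jnp.stack(up_weights, axis=0)
fused_w1 = jnp.stack([gate_stacked, up_stacked], axis=1)
print(fused_w1.shape)  # (4, 2, 8, 16)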
@@ -422,9 +431,24 @@
 
         # Step 5: Assign to model parameter
         model_param = self._get_param(params, target_path)
-        model_param.value = sharded_weight.astype(model_param.value.dtype)
+        original_dtype = model_param.value.dtype
+        expected_shape = model_param.value.shape
+
+        # Validate shape before assignment
+        if fused_weight.shape != expected_shape:
+            raise ValueError(
+                f"Fused MoE weight shape mismatch for {target_path}: "
+                f"expected {expected_shape}, got {fused_weight.shape}"
+            )
+
+        model_param.value = sharded_weight.astype(original_dtype)
 
-        logger.debug("Assigned fused MoE group %s, final shape: %s", moe_key, fused_weight.shape)
+        # Verify assignment was successful
+        actual_shape = model_param.value.shape
+        if actual_shape != expected_shape:
+            raise RuntimeError(
+                f"Failed to assign fused MoE weight to {target_path}: shape mismatch"
+            )
 
     def _load_dummy_weights(
         self,
@@ -1000,3 +1024,72 @@ def _is_excluded_layer_weight(self, hf_key: str) -> bool:
 
         layer_num = int(parts[2])
         return layer_num >= self.model_config.num_hidden_layers
+
+    def _verify_fused_moe_weights(
+        self, params: nnx.State, moe_mappings: dict[str, WeightMapping]
+    ) -> None:
+        """Verify that all fused MoE weights were loaded correctly."""
+        # Get all fused w1 mappings
+        fused_w1_mappings = {
+            k: v for k, v in moe_mappings.items() if getattr(v, "fuse_moe_weights", False)
+        }
+
+        # Get corresponding w2 mappings (same layer, but w2 instead of w1)
+        w2_mappings = {}
+        for k in fused_w1_mappings:
+            w2_key = k.replace(".w1", ".w2")
+            if w2_key in moe_mappings:
+                w2_mappings[w2_key] = moe_mappings[w2_key]
+
+        if not fused_w1_mappings:
+            return
+
+        all_verified = True
+        verified_count = 0
+
+        # Verify w1 and w2 weights
+        for _, mapping in fused_w1_mappings.items():
+            target_path = mapping.target_path[0]
+            try:
+                model_param = self._get_param(params, target_path)
+                weight_shape = model_param.value.shape
+                weight_values = model_param.value
+
+                if (
+                    len(weight_shape) != 4
+                    or weight_shape[1] != 2
+                    or jnp.all(weight_values == 0)
+                    or jnp.any(jnp.isnan(weight_values))
+                ):
+                    logger.error("✗ %s: Invalid or corrupted weights", target_path)
+                    all_verified = False
+                else:
+                    verified_count += 1
+            except (KeyError, AttributeError, ValueError) as e:
+                logger.error("✗ %s: Failed to access - %s", target_path, str(e))
+                all_verified = False
+
+        for _, mapping in w2_mappings.items():
+            target_path = mapping.target_path[0]
+            try:
+                model_param = self._get_param(params, target_path)
+                weight_shape = model_param.value.shape
+                weight_values = model_param.value
+
+                if (
+                    len(weight_shape) != 3
+                    or jnp.all(weight_values == 0)
+                    or jnp.any(jnp.isnan(weight_values))
+                ):
+                    logger.error("✗ %s (w2): Invalid or corrupted weights", target_path)
+                    all_verified = False
+                else:
+                    verified_count += 1
+            except (KeyError, AttributeError, ValueError) as e:
+                logger.error("✗ %s (w2): Failed to access - %s", target_path, str(e))
+                all_verified = False
+
+        if all_verified:
+            logger.info("✓ Fused MoE weights verified: %d layers", verified_count // 2)
+        else:
+            raise RuntimeError("Fused MoE weight verification failed")
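
The value checks in the verification above are plain array reductions; a tiny standalone illustration of the conditions they flag (shapes below are made up):

import jax.numpy as jnp

def looks_corrupted(w) -> bool:
    # Mirrors the value checks in _verify_fused_moe_weights: an all-zero buffer
    # (never assigned) or any NaN (bad load/cast) counts as corrupted.
    return bool(jnp.all(w == 0) or jnp.any(jnp.isnan(w)))

unassigned = jnp.zeros((2, 2, 4, 8))                         # placeholder never written
nan_poisoned = jnp.ones((2, 4, 8)).at[0, 0, 0].set(jnp.nan)  # one bad element
healthy = jnp.ones((2, 2, 4, 8))

print(looks_corrupted(unassigned))    # True
print(looks_corrupted(nan_poisoned))  # True
print(looks_corrupted(healthy))       # False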
