
Commit 57a995b

change fused moe kernel mesh
1 parent 2cf0141 commit 57a995b

File tree

9 files changed: +133 additions, -59 deletions


.gitignore

Lines changed: 1 addition & 0 deletions
@@ -241,3 +241,4 @@ CLAUDE.md
 
 #gemini code
 .gemini-clipboard
+GEMINI.md

python/sgl_jax/srt/kernels/fused_moe/v1/kernel.py

Lines changed: 24 additions & 14 deletions
@@ -205,6 +205,7 @@ def _fused_ep_moe_kernel(
     top_k: int,
     renormalize_topk_logits: bool,
     ep_axis_name: str,
+    tp_axis_name: str,
     act_fn: str,
     subc_quant_wsz: int | None = None,
     # Kernel tuning params.
@@ -214,8 +215,8 @@ def _fused_ep_moe_kernel(
     bd2: int,  # Block size of hidden_size in w2.
     btc: int,  # Compute size of block tokens for active expert.
     bfc: int,  # Compute size of block intermediate_size.
-    bd1c: int,  # Compute size of block hidden_size.
-    bd2c: int,  # Compute size of block hidden_size.
+    bd1c: int,
+    bd2c: int,
 ):
     my_id = lax.axis_index(ep_axis_name)
     num_devices = lax.axis_size(ep_axis_name)
@@ -260,8 +261,8 @@ def _fused_ep_moe_kernel(
     num_bd2 = cdiv(hidden_size, bd2)
 
     def get_mesh_device_id(ep_rank):
-        dp_rank = jax.lax.axis_index("data")
-        return (dp_rank, ep_rank)
+        tp_rank = jax.lax.axis_index(tp_axis_name)
+        return (ep_rank, tp_rank)
 
     def sync_barrier():
         barrier_sem = pltpu.get_barrier_semaphore()
@@ -1104,6 +1105,7 @@ def _():
         "bd1c",
         "bd2c",
         "ep_axis_name",
+        "tp_axis_name",
     ],
 )
 def fused_ep_moe(
@@ -1134,12 +1136,12 @@ def fused_ep_moe(
     bfc: int,
     bd1c: int,
     bd2c: int,
-    ep_axis_name: str = "tensor",
+    ep_axis_name: str = "expert",
+    tp_axis_name: str = "tensor",
 ):
     # TODO(jevinjiang): move all these assertions to validation function.
     # Assert all other axes have length of 1
     assert len(mesh.shape) == 2, "Expect 2D mesh"
-    assert "data" in mesh.shape and mesh.shape["data"] == 1, "Expect data axis size of 1"
 
     ep_size = mesh.shape[ep_axis_name]
     num_devices = ep_size
@@ -1294,6 +1296,7 @@ def fused_ep_moe(
         top_k=top_k,
         renormalize_topk_logits=renormalize_topk_logits,
         ep_axis_name=ep_axis_name,
+        tp_axis_name=tp_axis_name,
         act_fn=act_fn,
         subc_quant_wsz=subc_quant_wsz,
         bt=bt,
@@ -1479,16 +1482,18 @@ def fused_ep_moe(
         mesh=mesh,
         in_specs=(
             P(ep_axis_name),  # tokens_hbm
-            P(ep_axis_name),  # w1_hbm
-            P(ep_axis_name),  # w2_hbm
-            None if w1_scale is None else P(ep_axis_name),  # w1_scale_hbm
-            None if w2_scale is None else P(ep_axis_name),  # w2_scale_hbm
-            None if b1 is None else P(ep_axis_name),  # b1_hbm
-            None if b2 is None else P(ep_axis_name),  # b2_hbm
+            P(ep_axis_name, None, None, tp_axis_name),  # w1_hbm
+            P(ep_axis_name, tp_axis_name, None),  # w2_hbm
+            (
+                None if w1_scale is None else P(ep_axis_name, None, None, None, tp_axis_name)
+            ),  # w1_scale_hbm
+            None if w2_scale is None else P(ep_axis_name, None, None, tp_axis_name),  # w2_scale_hbm
+            None if b1 is None else P(ep_axis_name, None, tp_axis_name),  # b1_hbm
+            None if b2 is None else P(ep_axis_name, tp_axis_name),  # b2_hbm
             P(ep_axis_name),  # gating_output_hbm
             P(),  # a2a_g_hbm
         ),
-        out_specs=P(ep_axis_name),
+        out_specs=P(ep_axis_name, None),
         check_vma=False,
     )
     def kernel(
@@ -1502,7 +1507,7 @@ def kernel(
         gating_output,
         a2a_g_hbm_scratch,
     ):
-        return fused_moe(
+        results = fused_moe(
             pltpu.with_memory_space_constraint(tokens, pltpu.HBM),  # tokens_hbm
             pltpu.with_memory_space_constraint(w1, pltpu.HBM),  # w1_hbm
             pltpu.with_memory_space_constraint(w2, pltpu.HBM),  # w2_hbm
@@ -1522,6 +1527,11 @@ def kernel(
             pltpu.with_memory_space_constraint(a2a_g_hbm_scratch, pltpu.HBM),  # a2a_g_hbm
         )
 
+        if tp_axis_name in mesh.axis_names:
+            results = jax.lax.psum(results, tp_axis_name)
+
+        return results
+
     a2a_g_hbm_scratch = pl.empty((num_experts, bt, t_packing, hidden_size // t_packing), t_dtype)
     results = kernel(
         tokens,
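
Note: the kernel now assumes a 2D device mesh with an expert-parallel axis and a tensor-parallel axis ("expert" and "tensor" by default) instead of the previous ("data", ep) layout, and psum-reduces each shard's partial result over the tensor axis. A minimal sketch of the mesh layout this expects; the 4x2 split over 8 devices is an assumption for illustration, not something fixed by this commit:

# Sketch only: assumes 8 devices arranged as (ep_size=4, tp_size=2).
import jax
import numpy as np
from jax.sharding import Mesh

devices = np.array(jax.devices()).reshape(4, 2)
mesh = Mesh(devices, axis_names=("expert", "tensor"))

# fused_ep_moe(..., ep_axis_name="expert", tp_axis_name="tensor") then expects
# tokens sharded over "expert", the intermediate dimension of w1/w2 sharded
# over "tensor", and reduces the per-shard outputs with psum over "tensor".
print(mesh.shape)  # {'expert': 4, 'tensor': 2}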

python/sgl_jax/srt/layers/fused_moe.py

Lines changed: 73 additions & 28 deletions
@@ -1,13 +1,17 @@
 """Fused MoE layer using optimized TPU kernel."""
 
+import logging
+
 import jax
 import jax.numpy as jnp
 from flax import nnx
-from jax.sharding import Mesh
+from jax.sharding import Mesh, NamedSharding
 from jax.sharding import PartitionSpec as P
 
 from sgl_jax.srt.kernels.fused_moe.v1.kernel import fused_ep_moe
 
+logger = logging.getLogger(__name__)
+
 
 def _get_default_tile_sizes(hidden_size: int, intermediate_size: int) -> dict[str, int]:
     """
@@ -168,6 +172,33 @@ def __init__(
         self.bd1c = bd1c
         self.bd2c = bd2c
 
+        logger.info(
+            "Initializing FusedEPMoE layer %d: num_experts=%d, "
+            "num_experts_per_tok=%d, ep_size=%d, "
+            "intermediate_dim=%d, activation=%s, "
+            "renormalize_topk_logits=%s",
+            layer_id,
+            num_experts,
+            num_experts_per_tok,
+            ep_size,
+            intermediate_dim,
+            activation,
+            renormalize_topk_logits,
+        )
+        logger.info(
+            "FusedEPMoE layer %d tile sizes: bt=%d, bf=%d, bd1=%d, bd2=%d, "
+            "btc=%d, bfc=%d, bd1c=%d, bd2c=%d",
+            layer_id,
+            bt,
+            bf,
+            bd1,
+            bd2,
+            btc,
+            bfc,
+            bd1c,
+            bd2c,
+        )
+
         # Initialize weights in fused format
         with jax.sharding.use_abstract_mesh(self.updated_mesh):
             self.w1 = nnx.Param(
@@ -203,32 +234,46 @@ def __call__(self, hidden_states: jax.Array, router_logits: jax.Array) -> jax.Ar
         """
         assert hidden_states.ndim == 2
 
-        # Call the fused kernel
-        # Note: ep_size > 1 is handled internally by the kernel via mesh
-        output = fused_ep_moe(
-            mesh=self.moe_mesh,
-            tokens=hidden_states,
-            w1=self.w1.value,
-            w2=self.w2.value,
-            gating_output=router_logits,
-            top_k=self.num_experts_per_tok,
-            renormalize_topk_logits=self.renormalize_topk_logits,
-            act_fn=self.activation,
-            # Tile sizes
-            bt=self.bt,
-            bf=self.bf,
-            bd1=self.bd1,
-            bd2=self.bd2,
-            btc=self.btc,
-            bfc=self.bfc,
-            bd1c=self.bd1c,
-            bd2c=self.bd2c,
-            # Optional parameters (not used in basic case)
-            subc_quant_wsz=None,
-            w1_scale=None,
-            w2_scale=None,
-            b1=None,
-            b2=None,
+        logger.debug(
+            "FusedEPMoE layer %d: Processing %d tokens with %d experts (top-%d)",
+            self.layer_id,
+            hidden_states.shape[0],
+            self.num_experts,
+            self.num_experts_per_tok,
         )
 
-        return output
+        with jax.sharding.use_abstract_mesh(self.updated_mesh):
+            hidden_states = jax.lax.with_sharding_constraint(
+                hidden_states, NamedSharding(self.updated_mesh, P("expert", None))
+            )
+
+            output = fused_ep_moe(
+                mesh=self.moe_mesh,
+                tokens=hidden_states,
+                w1=self.w1.value,
+                w2=self.w2.value,
+                gating_output=router_logits,
+                top_k=self.num_experts_per_tok,
+                renormalize_topk_logits=self.renormalize_topk_logits,
+                act_fn=self.activation,
+                # Tile sizes
+                bt=self.bt,
+                bf=self.bf,
+                bd1=self.bd1,
+                bd2=self.bd2,
+                btc=self.btc,
+                bfc=self.bfc,
+                bd1c=self.bd1c,
+                bd2c=self.bd2c,
+                # Optional parameters (not used in basic case)
+                subc_quant_wsz=None,
+                w1_scale=None,
+                w2_scale=None,
+                b1=None,
+                b2=None,
+            )
+
+        output_pspec = P(*([None] * (output.ndim)))
+        return jax.sharding.reshard(
+            output, jax.sharding.NamedSharding(self.original_mesh, output_pspec)
+        )
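
The layer now pins its input to the expert axis of the updated mesh before calling the kernel and reshards the output back onto the original mesh. A self-contained sketch of that constrain, compute, replicate pattern; the mesh construction, token shape, and the `* 2.0` stand-in for fused_ep_moe are illustrative assumptions, not sgl_jax code:

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Assumed split: all local devices on the "expert" axis, "tensor" axis of size 1.
devices = np.array(jax.devices()).reshape(-1, 1)
mesh = Mesh(devices, axis_names=("expert", "tensor"))

@jax.jit
def moe_layer(tokens):
    # Pin tokens so each expert-parallel rank owns a contiguous block of rows.
    tokens = jax.lax.with_sharding_constraint(tokens, NamedSharding(mesh, P("expert", None)))
    out = tokens * 2.0  # stand-in for the fused kernel call
    # Hand the result back replicated, as the surrounding model expects.
    return jax.lax.with_sharding_constraint(out, NamedSharding(mesh, P(None, None)))

tokens = jnp.ones((mesh.shape["expert"] * 2, 16))  # (num_tokens, hidden), assumed sizes
print(moe_layer(tokens).sharding)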

python/sgl_jax/srt/model_executor/model_runner.py

Lines changed: 1 addition & 0 deletions
@@ -240,6 +240,7 @@ def load_model(self):
         self.model_config.configure_for_tensor_parallel(self.tp_size)
         self.model_config.log_kv_heads_info(self.tp_size)
         self.model_config.hf_config.ep_size = self.ep_size
+        self.model_config.hf_config.moe_backend = self.model_config.moe_backend.value
 
         self.model = self.model_loader.load_model(
             model_config=self.model_config,
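
This copies the MoE backend selection onto the HF config as a plain string so model code can branch on it. A small illustrative sketch of that enum-to-string handoff; MoeBackend and the SimpleNamespace config below are stand-ins, not sgl_jax classes:

from enum import Enum
from types import SimpleNamespace

class MoeBackend(Enum):
    FUSED = "fused"
    EPMOE = "epmoe"

hf_config = SimpleNamespace()
hf_config.moe_backend = MoeBackend.FUSED.value  # store the plain string, not the Enum

# Downstream layer code can then branch on the string value:
use_fused = getattr(hf_config, "moe_backend", "epmoe") == "fused"
print(use_fused)  # True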

python/sgl_jax/srt/models/bailing_moe.py

Lines changed: 2 additions & 2 deletions
@@ -612,7 +612,7 @@ def _create_moe_layer_mappings(self, layer_idx: int, is_mlp_layer: bool) -> dict
             [f"{prefix}.mlp.experts.{i}.{name}.weight" for i in range(num_experts)]
         )
 
-        mappings[f"__MOE_EXPERTS__{prefix}.mlp.experts.w1"] = WeightMapping(
+        mappings[f"__MOE_EXPERTS__{prefix}.mlp.w1"] = WeightMapping(
             target_path=target_path_w1,
             sharding=("expert", None, None, "tensor"),  # (E, 2, H, I/TP)
             transpose=True,
@@ -627,7 +627,7 @@ def _create_moe_layer_mappings(self, layer_idx: int, is_mlp_layer: bool) -> dict
             [f"{prefix}.mlp.experts.{i}.down_proj.weight" for i in range(num_experts)]
         )
 
-        mappings[f"__MOE_EXPERTS__{prefix}.mlp.experts.w2"] = WeightMapping(
+        mappings[f"__MOE_EXPERTS__{prefix}.mlp.w2"] = WeightMapping(
            target_path=target_path_w2,
             sharding=("expert", "tensor", None),  # (E, I/TP, H)
             transpose=True,
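
For reference, the sharding tuples in these mappings name one mesh axis per weight dimension. A sketch of what ("expert", None, None, "tensor") means for the fused w1 when expressed as a PartitionSpec; the mesh split and sizes below are assumptions, and WeightMapping internals are not reproduced here:

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

devices = np.array(jax.devices()).reshape(-1, 1)  # assumed (expert, tensor) split
mesh = Mesh(devices, axis_names=("expert", "tensor"))

E = mesh.shape["expert"] * 2  # number of experts, assumed divisible by the expert axis
H, I = 64, 128                # hidden / intermediate sizes, assumed
w1 = jnp.zeros((E, 2, H, I))  # fused gate+up weights, global shape (E, 2, H, I)
# ("expert", None, None, "tensor"): experts split across "expert", the intermediate
# dimension split across "tensor", the gate/up and hidden dimensions replicated.
w1 = jax.device_put(w1, NamedSharding(mesh, P("expert", None, None, "tensor")))
print(w1.sharding.spec)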

python/sgl_jax/srt/models/grok.py

Lines changed: 14 additions & 8 deletions
@@ -17,6 +17,7 @@
     _yarn_find_correction_range,
     _yarn_get_mscale,
 )
+from sgl_jax.srt.layers.fused_moe import FusedEPMoE
 from sgl_jax.srt.layers.layernorm import RMSNorm, dual_rmsnorm_forward
 from sgl_jax.srt.layers.linear import LinearBase
 from sgl_jax.srt.layers.logits_processor import (
@@ -206,6 +207,8 @@ class Grok1MoE(nnx.Module):
     kernel is used for the forward pass, with outputs reduced across ranks.
     """
 
+    experts: FusedEPMoE | EPMoE
+
     def __init__(
         self,
         config: PretrainedConfig,
@@ -242,8 +245,6 @@ def __init__(
         self.use_fused = self.moe_backend == "fused"
 
         if self.use_fused:
-            from sgl_jax.srt.layers.fused_moe import FusedEPMoE
-
             self.experts = FusedEPMoE(
                 config=config,
                 num_experts=num_experts,
@@ -283,12 +284,14 @@ def __call__(self, hidden_states: jax.Array) -> jax.Array:
         if self.use_fused:
             # Fused kernel: pass router_logits directly
            # Top-K selection is handled internally by the kernel
+            assert isinstance(self.experts, FusedEPMoE)
             return self.experts(hidden_states, router_logits)
         else:
             # EPMoE: compute top-k routing weights using sglang-style approach:
             # 1. Compute global softmax over ALL experts (not just top-k)
             # 2. Select top-k experts based on logits
             # 3. Extract corresponding weights (no renormalization)
+            assert isinstance(self.experts, EPMoE)
             top_k_weights, top_k_indices = self._custom_topk(
                 router_logits, self.top_k, renormalize=False
             )
@@ -939,7 +942,7 @@ def _create_layer_mappings(self, layer_idx: int) -> dict[str, WeightMapping]:
         # w2: down(w2) -> (num_experts, intermediate, hidden)
 
         # 1. Fused w1 (gate + up)
-        target_path_w1 = [f"{target_prefix}.block_sparse_moe.experts.w1"]
+        target_path_w1 = [f"{target_prefix}.block_sparse_moe.w1"]
         # Add source keys for w1 (gate) and w3 (up)
         # Note: Grok experts are 0..N-1
         for name in ["w1", "w3"]:
@@ -950,7 +953,7 @@ def _create_layer_mappings(self, layer_idx: int) -> dict[str, WeightMapping]:
                 ]
             )
 
-        mappings[f"__MOE_EXPERTS__{prefix}.block_sparse_moe.experts.w1"] = WeightMapping(
+        mappings[f"__MOE_EXPERTS__{prefix}.block_sparse_moe.w1"] = WeightMapping(
             target_path=target_path_w1,
             sharding=("expert", None, None, "tensor"),  # (E, 2, H, I/TP)
             transpose=True,
@@ -960,16 +963,15 @@ def _create_layer_mappings(self, layer_idx: int) -> dict[str, WeightMapping]:
         )
 
         # 2. w2 (down)
-        target_path_w2 = [f"{target_prefix}.block_sparse_moe.experts.w2"]
+        target_path_w2 = [f"{target_prefix}.block_sparse_moe.w2"]
         target_path_w2.extend(
             [
                 f"{prefix}.block_sparse_moe.experts.{i}.w2.weight"
                 for i in range(self.config.num_local_experts)
             ]
         )
 
-
-        mappings[f"__MOE_EXPERTS__{prefix}.block_sparse_moe.experts.w2"] = WeightMapping(
+        mappings[f"__MOE_EXPERTS__{prefix}.block_sparse_moe.w2"] = WeightMapping(
             target_path=target_path_w2,
             sharding=("expert", "tensor", None),  # (E, I/TP, H)
             transpose=True,
@@ -987,7 +989,11 @@ def _create_layer_mappings(self, layer_idx: int) -> dict[str, WeightMapping]:
                 ]
             )
 
-            sharding = ("expert", "tensor", None) if target_name == "wo" else ("expert", None, "tensor")
+            sharding = (
+                ("expert", "tensor", None)
+                if target_name == "wo"
+                else ("expert", None, "tensor")
+            )
 
             if name == "w2":
                 # w2 (down_proj) -> wo
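
The class-level `experts: FusedEPMoE | EPMoE` annotation plus the `assert isinstance(...)` in each branch is the usual way to let a type checker narrow a union-typed attribute before calling branch-specific methods. A generic sketch of the pattern; FastImpl and SlowImpl are placeholders, not the real layers:

# Union-attribute narrowing with assert isinstance (placeholder classes).
class FastImpl:
    def run_fused(self, x: int) -> int:
        return x * 2

class SlowImpl:
    def run(self, x: int) -> int:
        return x + 1

class Block:
    experts: FastImpl | SlowImpl

    def __init__(self, use_fused: bool):
        self.use_fused = use_fused
        self.experts = FastImpl() if use_fused else SlowImpl()

    def __call__(self, x: int) -> int:
        if self.use_fused:
            assert isinstance(self.experts, FastImpl)  # narrows the union for the type checker
            return self.experts.run_fused(x)
        assert isinstance(self.experts, SlowImpl)
        return self.experts.run(x)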

python/sgl_jax/srt/models/qwen2_moe.py

Lines changed: 2 additions & 2 deletions
@@ -593,7 +593,7 @@ def _create_moe_layer_mappings(self, layer_idx: int) -> dict:
             [f"{prefix}.mlp.experts.{i}.{name}.weight" for i in range(num_experts)]
         )
 
-        mappings[f"__MOE_EXPERTS__{prefix}.mlp.experts.w1"] = WeightMapping(
+        mappings[f"__MOE_EXPERTS__{prefix}.mlp.w1"] = WeightMapping(
             target_path=target_path_w1,
             sharding=("expert", None, None, "tensor"),  # (E, 2, H, I/TP)
             transpose=True,
@@ -608,7 +608,7 @@ def _create_moe_layer_mappings(self, layer_idx: int) -> dict:
             [f"{prefix}.mlp.experts.{i}.down_proj.weight" for i in range(num_experts)]
         )
 
-        mappings[f"__MOE_EXPERTS__{prefix}.mlp.experts.w2"] = WeightMapping(
+        mappings[f"__MOE_EXPERTS__{prefix}.mlp.w2"] = WeightMapping(
             target_path=target_path_w2,
             sharding=("expert", "tensor", None),  # (E, I/TP, H)
             transpose=True,
