@@ -205,6 +205,7 @@ def _fused_ep_moe_kernel(
     top_k: int,
     renormalize_topk_logits: bool,
     ep_axis_name: str,
+    tp_axis_name: str,
     act_fn: str,
     subc_quant_wsz: int | None = None,
     # Kernel tuning params.
@@ -214,8 +215,8 @@ def _fused_ep_moe_kernel(
     bd2: int,  # Block size of hidden_size in w2.
     btc: int,  # Compute size of block tokens for active expert.
     bfc: int,  # Compute size of block intermediate_size.
-    bd1c: int,  # Compute size of block hidden_size.
-    bd2c: int,  # Compute size of block hidden_size.
+    bd1c: int,
+    bd2c: int,
 ):
     my_id = lax.axis_index(ep_axis_name)
     num_devices = lax.axis_size(ep_axis_name)
@@ -260,8 +261,8 @@ def _fused_ep_moe_kernel(
     num_bd2 = cdiv(hidden_size, bd2)

     def get_mesh_device_id(ep_rank):
-        dp_rank = jax.lax.axis_index("data")
-        return (dp_rank, ep_rank)
+        tp_rank = jax.lax.axis_index(tp_axis_name)
+        return (ep_rank, tp_rank)

     def sync_barrier():
         barrier_sem = pltpu.get_barrier_semaphore()
@@ -1104,6 +1105,7 @@ def _():
         "bd1c",
         "bd2c",
         "ep_axis_name",
+        "tp_axis_name",
     ],
 )
 def fused_ep_moe(
@@ -1134,12 +1136,12 @@ def fused_ep_moe(
     bfc: int,
     bd1c: int,
     bd2c: int,
-    ep_axis_name: str = "tensor",
+    ep_axis_name: str = "expert",
+    tp_axis_name: str = "tensor",
 ):
     # TODO(jevinjiang): move all these assertions to validation function.
     # Assert all other axes have length of 1
     assert len(mesh.shape) == 2, "Expect 2D mesh"
-    assert "data" in mesh.shape and mesh.shape["data"] == 1, "Expect data axis size of 1"

     ep_size = mesh.shape[ep_axis_name]
     num_devices = ep_size
@@ -1294,6 +1296,7 @@ def fused_ep_moe(
         top_k=top_k,
         renormalize_topk_logits=renormalize_topk_logits,
         ep_axis_name=ep_axis_name,
+        tp_axis_name=tp_axis_name,
         act_fn=act_fn,
         subc_quant_wsz=subc_quant_wsz,
         bt=bt,
@@ -1479,16 +1482,18 @@ def fused_ep_moe(
         mesh=mesh,
         in_specs=(
             P(ep_axis_name),  # tokens_hbm
-            P(ep_axis_name),  # w1_hbm
-            P(ep_axis_name),  # w2_hbm
-            None if w1_scale is None else P(ep_axis_name),  # w1_scale_hbm
-            None if w2_scale is None else P(ep_axis_name),  # w2_scale_hbm
-            None if b1 is None else P(ep_axis_name),  # b1_hbm
-            None if b2 is None else P(ep_axis_name),  # b2_hbm
+            P(ep_axis_name, None, None, tp_axis_name),  # w1_hbm
+            P(ep_axis_name, tp_axis_name, None),  # w2_hbm
+            (
+                None if w1_scale is None else P(ep_axis_name, None, None, None, tp_axis_name)
+            ),  # w1_scale_hbm
+            None if w2_scale is None else P(ep_axis_name, None, None, tp_axis_name),  # w2_scale_hbm
+            None if b1 is None else P(ep_axis_name, None, tp_axis_name),  # b1_hbm
+            None if b2 is None else P(ep_axis_name, tp_axis_name),  # b2_hbm
             P(ep_axis_name),  # gating_output_hbm
             P(),  # a2a_g_hbm
         ),
-        out_specs=P(ep_axis_name),
+        out_specs=P(ep_axis_name, None),
         check_vma=False,
     )
     def kernel(
@@ -1502,7 +1507,7 @@ def kernel(
         gating_output,
         a2a_g_hbm_scratch,
     ):
-        return fused_moe(
+        results = fused_moe(
             pltpu.with_memory_space_constraint(tokens, pltpu.HBM),  # tokens_hbm
             pltpu.with_memory_space_constraint(w1, pltpu.HBM),  # w1_hbm
             pltpu.with_memory_space_constraint(w2, pltpu.HBM),  # w2_hbm
@@ -1522,6 +1527,11 @@ def kernel(
             pltpu.with_memory_space_constraint(a2a_g_hbm_scratch, pltpu.HBM),  # a2a_g_hbm
         )

+        if tp_axis_name in mesh.axis_names:
+            results = jax.lax.psum(results, tp_axis_name)
+
+        return results
+
     a2a_g_hbm_scratch = pl.empty((num_experts, bt, t_packing, hidden_size // t_packing), t_dtype)
     results = kernel(
         tokens,
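The sharding intent of this change, made concrete: the kernel now expects a 2D device mesh whose first axis is expert parallelism and whose second is tensor parallelism (matching `get_mesh_device_id`, which now addresses a peer as `(ep_rank, tp_rank)`), with `w1` additionally split over the tensor axis on its last dimension and `w2` on its second dimension, per the new `in_specs`. The sketch below is illustrative only; the mesh sizes, the `EP`/`TP` names, and the sharding-variable names are assumptions, not values taken from this commit.

```python
import numpy as np
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Hypothetical mesh sizes, purely for illustration.
EP, TP = 4, 2  # 4-way expert parallelism x 2-way tensor parallelism
devices = np.array(jax.devices()[: EP * TP]).reshape(EP, TP)
# Axis order matches get_mesh_device_id(), which addresses peers as
# (ep_rank, tp_rank); the new defaults are ep_axis_name="expert",
# tp_axis_name="tensor".
mesh = Mesh(devices, axis_names=("expert", "tensor"))

# Weight placement implied by the new in_specs (w1/w2 shapes not shown here):
# w1: expert dim over "expert", last dim over "tensor".
w1_sharding = NamedSharding(mesh, P("expert", None, None, "tensor"))
# w2: expert dim over "expert", second (contracting) dim over "tensor".
w2_sharding = NamedSharding(mesh, P("expert", "tensor", None))
```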
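The `jax.lax.psum(results, tp_axis_name)` added after the inner `fused_moe` call is the standard tensor-parallel reduction: with `w2` sharded as `P(ep_axis_name, tp_axis_name, None)`, each device presumably contracts only its slice of the intermediate dimension, so its result is a partial sum that must be combined across the tensor axis. A minimal, standalone illustration of that pattern with a plain matmul under `jax.shard_map` (exposed as `jax.experimental.shard_map` in older JAX releases), assuming up to two local devices; none of this code comes from the commit itself.

```python
import functools
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P

# 1D mesh over the locally available devices (at most 2 here).
mesh = Mesh(np.array(jax.devices()[:2]), ("tensor",))

@functools.partial(
    jax.shard_map,
    mesh=mesh,
    in_specs=(P(None, "tensor"), P("tensor", None)),  # split x and w on the shared dim
    out_specs=P(),                                    # fully reduced, replicated output
)
def down_proj(x, w):
    partial = x @ w                         # each device contracts only its slice
    return jax.lax.psum(partial, "tensor")  # combine partial sums, as kernel() now does

x = jnp.ones((8, 64))   # (tokens, intermediate) -- hypothetical sizes
w = jnp.ones((64, 32))  # (intermediate, hidden)
y = down_proj(x, w)     # equal to x @ w
```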