@@ -262,12 +262,59 @@ def __call__(
         hidden_states = self.post_attention_layernorm(hidden_states)
 
         if self.is_moe_layer:
+            # Debug: MLP input using callback
+            def print_mlp_input(hs):
+                print(f"[Layer {self.layer_id}] MLP Input shape: {hs.shape}, dtype: {hs.dtype}")
+                print(f"[Layer {self.layer_id}] MLP Input full:\n{hs}")
+
+            # jax.debug.callback(print_mlp_input, hidden_states)
+
             router_logits = self.moe_gate(hidden_states)
+
+            def print_router_logits_fn(logits):
+                print(
+                    f"[Layer {self.layer_id}] Router Logits shape: {logits.shape}, dtype: {logits.dtype}"
+                )
+                print(f"[Layer {self.layer_id}] Router Logits full:\n{logits}")
+
+            # jax.debug.callback(print_router_logits_fn, router_logits)
+
             if self.use_fused:
+                print(f"[Layer {self.layer_id}] Using FUSED MoE backend")
                 hidden_states = self.mlp(hidden_states, router_logits)
+
+                def print_fused_output(out):
+                    print(
+                        f"[Layer {self.layer_id}] Fused MoE Output shape: {out.shape}, dtype: {out.dtype}"
+                    )
+                    print(f"[Layer {self.layer_id}] Fused MoE Output full:\n{out}")
+
+                # jax.debug.callback(print_fused_output, hidden_states)
             else:
+                print(f"[Layer {self.layer_id}] Using EPMoE backend")
                 topk_weights, topk_ids = self.topk(router_logits)
+
+                def print_topk(weights, ids):
+                    print(
+                        f"[Layer {self.layer_id}] TopK weights shape: {weights.shape}, dtype: {weights.dtype}"
+                    )
+                    print(f"[Layer {self.layer_id}] TopK weights full:\n{weights}")
+                    print(
+                        f"[Layer {self.layer_id}] TopK ids shape: {ids.shape}, dtype: {ids.dtype}"
+                    )
+                    print(f"[Layer {self.layer_id}] TopK ids full:\n{ids}")
+
+                # jax.debug.callback(print_topk, topk_weights, topk_ids)
+
                 hidden_states = self.mlp(hidden_states, topk_weights, topk_ids)
+
+                def print_epmoe_output(out):
+                    print(
+                        f"[Layer {self.layer_id}] EPMoE Output shape: {out.shape}, dtype: {out.dtype}"
+                    )
+                    print(f"[Layer {self.layer_id}] EPMoE Output full:\n{out}")
+
+                # jax.debug.callback(print_epmoe_output, hidden_states)
         else:
             hidden_states = self.mlp(hidden_states)
 
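Note on the debug hooks above: the print helpers are written as functions meant for jax.debug.callback (the calls themselves are left commented out), because a bare print inside jitted/traced code only sees abstract tracers, while a host callback receives the materialized arrays at run time. A minimal, self-contained sketch of that pattern, using a toy forward function and an illustrative tag rather than the layer code in this diff:

import jax
import jax.numpy as jnp


def make_printer(tag):
    # Host-side callback: receives a concrete array even when invoked from jitted code.
    def _print(x):
        print(f"[{tag}] shape: {x.shape}, dtype: {x.dtype}")
        print(f"[{tag}] values:\n{x}")

    return _print


@jax.jit
def forward(hidden_states):
    # Stand-in for a router/gate: any computation producing logits.
    router_logits = hidden_states @ jnp.ones((hidden_states.shape[-1], 4))
    # A plain print() here would only show tracers; the callback fires with real values.
    jax.debug.callback(make_printer("router_logits"), router_logits)
    return router_logits


forward(jnp.ones((2, 8)))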
@@ -510,37 +557,28 @@ def _create_moe_layer_mappings(self, layer_idx: int, is_mlp_layer: bool) -> dict
             # Fused MoE Mapping
             # w1: fused gate_proj(w1) + up_proj(w3) -> (num_experts, 2, hidden, intermediate)
             # w2: down_proj(w2) -> (num_experts, intermediate, hidden)
-
-            # 1. Fused w1 (gate + up)
-            target_path_w1 = [f"{target_prefix}.mlp.w1"]
-            # Add source keys for gate_proj and up_proj
-            for name in ["gate_proj", "up_proj"]:
-                target_path_w1.extend(
-                    [f"{prefix}.mlp.experts.{i}.{name}.weight" for i in range(num_experts)]
-                )
-
+            w1_expert_keys = []
+            for expert_type in ["gate_proj", "up_proj"]:
+                w1_expert_keys = w1_expert_keys + [
+                    f"{prefix}.mlp.experts.{i}.{expert_type}.weight" for i in range(num_experts)
+                ]
             mappings[f"__MOE_EXPERTS__{prefix}.mlp.w1"] = WeightMapping(
-                target_path=target_path_w1,
+                target_path=[f"{target_prefix}.mlp.w1"] + w1_expert_keys,
                 sharding=("tensor", None, None, None),  # (E, 2, H, I)
                 transpose=True,
-                concat_axis=0,
                 fuse_moe_weights=True,
                 fuse_gate_up=("gate_proj", "up_proj"),
             )
-
-            # 2. w2 (down)
-            target_path_w2 = [f"{target_prefix}.mlp.w2"]
-            target_path_w2.extend(
-                [f"{prefix}.mlp.experts.{i}.down_proj.weight" for i in range(num_experts)]
-            )
-
+            w2_expert_keys = [
+                f"{prefix}.mlp.experts.{i}.down_proj.weight" for i in range(num_experts)
+            ]
             mappings[f"__MOE_EXPERTS__{prefix}.mlp.w2"] = WeightMapping(
-                target_path=target_path_w2,
+                target_path=[f"{target_prefix}.mlp.w2"] + w2_expert_keys,
                 sharding=("tensor", None, None),  # (E, I, H)
                 transpose=True,
-                concat_axis=-1,
             )
         else:
+            # EPMoE mapping - always use expert sharding
             for expert_type in ["gate_proj", "up_proj", "down_proj"]:
                 target_name = {
                     "gate_proj": "wi_0",
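For reference, the refactored loop builds w1_expert_keys with all gate_proj keys first, then all up_proj keys, so the fused w1 target path is followed by both halves in a fixed order. A small illustrative sketch of the resulting key list (the prefix and expert count are made up, and how WeightMapping separates the two halves via fuse_gate_up is assumed from the field name, not shown in this diff):

prefix = "model.layers.1"  # hypothetical checkpoint prefix
num_experts = 2            # hypothetical expert count

w1_expert_keys = []
for expert_type in ["gate_proj", "up_proj"]:
    w1_expert_keys = w1_expert_keys + [
        f"{prefix}.mlp.experts.{i}.{expert_type}.weight" for i in range(num_experts)
    ]

# -> ['model.layers.1.mlp.experts.0.gate_proj.weight',
#     'model.layers.1.mlp.experts.1.gate_proj.weight',
#     'model.layers.1.mlp.experts.0.up_proj.weight',
#     'model.layers.1.mlp.experts.1.up_proj.weight']
print(w1_expert_keys)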
@@ -553,9 +591,9 @@ def _create_moe_layer_mappings(self, layer_idx: int, is_mlp_layer: bool) -> dict
                 ]
 
                 if expert_type == "down_proj":
-                    sharding = ("tensor", None, None)
+                    sharding = ("expert", "tensor", None)
                 else:
-                    sharding = ("tensor", None, None)
+                    sharding = ("expert", None, "tensor")
 
                 mappings[f"__MOE_EXPERTS__{prefix}.mlp.{target_name}"] = WeightMapping(
                     target_path=[f"{target_prefix}.mlp.{target_name}"] + expert_keys,
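The sharding change above reads as one mesh-axis name (or None) per weight dimension: EPMoE expert weights are now split along an "expert" mesh axis, with the intermediate dimension split along "tensor", instead of sharding the expert dimension itself across tensor-parallel ranks. A hedged sketch of how such tuples would correspond to JAX shardings, assuming a 2-D device mesh with "expert" and "tensor" axes and per-expert layouts analogous to the fused-path comments ((E, H, I) for wi_0/wi_1, (E, I, H) for wo); WeightMapping's actual translation is not shown in this diff:

import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# Assumed mesh: first axis for expert parallelism, second for tensor parallelism.
devices = np.array(jax.devices()).reshape(-1, 1)
mesh = Mesh(devices, axis_names=("expert", "tensor"))

# ("expert", None, "tensor"): experts over the expert axis, intermediate dim over tensor.
wi_sharding = NamedSharding(mesh, PartitionSpec("expert", None, "tensor"))

# ("expert", "tensor", None): experts over the expert axis, intermediate dim over tensor.
wo_sharding = NamedSharding(mesh, PartitionSpec("expert", "tensor", None))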
@@ -598,8 +636,6 @@ def __call__(
         logits_metadata: LogitsMetadata,
     ):
         hidden_states, layers_kv_fused = self.model(forward_batch, token_to_kv_pool)
-        hidden_states = jax.sharding.reshard(hidden_states, jax.sharding.PartitionSpec(None, None))
-
         if not getattr(self.config, "tie_word_embeddings", False):
             output = self.logits_processor(hidden_states, self.lm_head, logits_metadata)
         else: