@@ -292,6 +292,9 @@ def load_weights_from_safetensors(
292292
293293 nnx .update (self .model , params )
294294
295+ # Final verification: check all fused MoE layers
296+ self ._verify_fused_moe_weights (params , moe_mappings )
297+
295298 def _process_single_moe_group (
296299 self ,
297300 params : nnx .State ,
@@ -358,42 +361,48 @@ def _process_fused_moe_group(
358361 }
359362 """
360363 target_path = mapping .target_path [0 ]
364+ expected_hf_keys = mapping .target_path [1 :]
361365
362366 # Step 1: Process gate and up weights separately
367+ # Use the predefined order from expected_hf_keys, not sorting
363368 gate_weights = []
364369 up_weights = []
365370
366- # Process gate weights (w1)
367- for hf_key in sorted (grouped_weights ["gate" ].keys ()):
368- weights = grouped_weights ["gate" ][hf_key ]
371+ gate_id , up_id = mapping .fuse_gate_up
369372
370- # Concatenate TP shards
371- if mapping . concat_axis is not None and len ( weights ) > 1 :
372- weight = jnp . concatenate ( weights , axis = mapping . concat_axis )
373- else :
374- weight = weights [ 0 ]
373+ # Separate expected keys into gate and up based on fuse_gate_up config
374+ for hf_key in expected_hf_keys :
375+ if gate_id in hf_key :
376+ # This is a gate weight
377+ weights = grouped_weights [ "gate" ][ hf_key ]
375378
376- # Transpose
377- if mapping .transpose :
378- weight = jnp .transpose (weight , (1 , 0 ))
379+ # Concatenate TP shards
380+ if mapping .concat_axis is not None and len (weights ) > 1 :
381+ weight = jnp .concatenate (weights , axis = mapping .concat_axis )
382+ else :
383+ weight = weights [0 ]
379384
380- gate_weights .append (weight )
385+ # Transpose
386+ if mapping .transpose :
387+ weight = jnp .transpose (weight , (1 , 0 ))
381388
382- # Process up weights (w3)
383- for hf_key in sorted (grouped_weights ["up" ].keys ()):
384- weights = grouped_weights ["up" ][hf_key ]
389+ gate_weights .append (weight )
385390
386- # Concatenate TP shards
387- if mapping .concat_axis is not None and len (weights ) > 1 :
388- weight = jnp .concatenate (weights , axis = mapping .concat_axis )
389- else :
390- weight = weights [0 ]
391+ elif up_id in hf_key :
392+ # This is an up weight
393+ weights = grouped_weights ["up" ][hf_key ]
391394
392- # Transpose
393- if mapping .transpose :
394- weight = jnp .transpose (weight , (1 , 0 ))
395+ # Concatenate TP shards
396+ if mapping .concat_axis is not None and len (weights ) > 1 :
397+ weight = jnp .concatenate (weights , axis = mapping .concat_axis )
398+ else :
399+ weight = weights [0 ]
395400
396- up_weights .append (weight )
401+ # Transpose
402+ if mapping .transpose :
403+ weight = jnp .transpose (weight , (1 , 0 ))
404+
405+ up_weights .append (weight )
397406
398407 # Step 2: Stack to 3D tensors
399408 # gate_stacked: (num_experts, hidden_size, intermediate_size)
@@ -422,9 +431,24 @@ def _process_fused_moe_group(
422431
423432 # Step 5: Assign to model parameter
424433 model_param = self ._get_param (params , target_path )
425- model_param .value = sharded_weight .astype (model_param .value .dtype )
434+ original_dtype = model_param .value .dtype
435+ expected_shape = model_param .value .shape
436+
437+ # Validate shape before assignment
438+ if fused_weight .shape != expected_shape :
439+ raise ValueError (
440+ f"Fused MoE weight shape mismatch for { target_path } : "
441+ f"expected { expected_shape } , got { fused_weight .shape } "
442+ )
443+
444+ model_param .value = sharded_weight .astype (original_dtype )
426445
427- logger .debug ("Assigned fused MoE group %s, final shape: %s" , moe_key , fused_weight .shape )
446+ # Verify assignment was successful
447+ actual_shape = model_param .value .shape
448+ if actual_shape != expected_shape :
449+ raise RuntimeError (
450+ f"Failed to assign fused MoE weight to { target_path } : shape mismatch"
451+ )
428452
429453 def _load_dummy_weights (
430454 self ,
@@ -1000,3 +1024,72 @@ def _is_excluded_layer_weight(self, hf_key: str) -> bool:
10001024
10011025 layer_num = int (parts [2 ])
10021026 return layer_num >= self .model_config .num_hidden_layers
1027+
1028+ def _verify_fused_moe_weights (
1029+ self , params : nnx .State , moe_mappings : dict [str , WeightMapping ]
1030+ ) -> None :
1031+ """Verify that all fused MoE weights were loaded correctly."""
1032+ # Get all fused w1 mappings
1033+ fused_w1_mappings = {
1034+ k : v for k , v in moe_mappings .items () if getattr (v , "fuse_moe_weights" , False )
1035+ }
1036+
1037+ # Get corresponding w2 mappings (same layer, but w2 instead of w1)
1038+ w2_mappings = {}
1039+ for k in fused_w1_mappings :
1040+ w2_key = k .replace (".w1" , ".w2" )
1041+ if w2_key in moe_mappings :
1042+ w2_mappings [w2_key ] = moe_mappings [w2_key ]
1043+
1044+ if not fused_w1_mappings :
1045+ return
1046+
1047+ all_verified = True
1048+ verified_count = 0
1049+
1050+ # Verify w1 and w2 weights
1051+ for _ , mapping in fused_w1_mappings .items ():
1052+ target_path = mapping .target_path [0 ]
1053+ try :
1054+ model_param = self ._get_param (params , target_path )
1055+ weight_shape = model_param .value .shape
1056+ weight_values = model_param .value
1057+
1058+ if (
1059+ len (weight_shape ) != 4
1060+ or weight_shape [1 ] != 2
1061+ or jnp .all (weight_values == 0 )
1062+ or jnp .any (jnp .isnan (weight_values ))
1063+ ):
1064+ logger .error ("✗ %s: Invalid or corrupted weights" , target_path )
1065+ all_verified = False
1066+ else :
1067+ verified_count += 1
1068+ except (KeyError , AttributeError , ValueError ) as e :
1069+ logger .error ("✗ %s: Failed to access - %s" , target_path , str (e ))
1070+ all_verified = False
1071+
1072+ for _ , mapping in w2_mappings .items ():
1073+ target_path = mapping .target_path [0 ]
1074+ try :
1075+ model_param = self ._get_param (params , target_path )
1076+ weight_shape = model_param .value .shape
1077+ weight_values = model_param .value
1078+
1079+ if (
1080+ len (weight_shape ) != 3
1081+ or jnp .all (weight_values == 0 )
1082+ or jnp .any (jnp .isnan (weight_values ))
1083+ ):
1084+ logger .error ("✗ %s (w2): Invalid or corrupted weights" , target_path )
1085+ all_verified = False
1086+ else :
1087+ verified_count += 1
1088+ except (KeyError , AttributeError , ValueError ) as e :
1089+ logger .error ("✗ %s (w2): Failed to access - %s" , target_path , str (e ))
1090+ all_verified = False
1091+
1092+ if all_verified :
1093+ logger .info ("✓ Fused MoE weights verified: %d layers" , verified_count // 2 )
1094+ else :
1095+ raise RuntimeError ("Fused MoE weight verification failed" )
0 commit comments