
fix: keep fp32-pinned parameters out of the bf16 cast path in ZeRO-3 (#7747) #7867

Draft

harshang03 wants to merge 1 commit into deepspeedai:master from harshang03:fix/7747-moe-bf16-fp32-pinned-params

Conversation

@harshang03

What this PR does / why we need it

When training with BF16 and ZeRO stage 3, every model parameter is cast to bfloat16 inside the deepspeed.zero.Init context via the global tensor-creation wrappers. This affects MoE router weights (and any other parameters that need full fp32 precision), which silently lose precision and can destabilise training or cause incorrect routing.

This PR introduces an opt-in mechanism: fp32_pinned_parameters, a list of substring patterns in the ZeRO config. Any parameter whose name contains one of these patterns is re-cast to float32 immediately after the bf16 tensor-creation wrappers downcast it, and is tracked through a dedicated fp32 path in the BF16 optimizer.
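The matching rule above can be sketched in plain Python (the helper name `is_fp32_pinned` is hypothetical; only the substring semantics of fp32_pinned_parameters come from this PR):

```python
# Hypothetical helper illustrating the substring-match rule; the function
# name is invented, only the fp32_pinned_parameters semantics described
# above are taken from this PR.
def is_fp32_pinned(param_name: str, patterns: list[str]) -> bool:
    """A parameter is pinned if its name contains any configured pattern."""
    return any(pattern in param_name for pattern in patterns)
```

With patterns ["router.weight", "gate."], a name like "layers.3.router.weight" matches, while "layers.3.mlp.weight" does not.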

How has this been tested?

Manually traced the dtype of a small synthetic MoE model under deepspeed.zero.Init with bf16: {enabled: true} and verified that matched parameters retain torch.float32 dtype while unmatched parameters are correctly torch.bfloat16.
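The manual trace can be approximated without a GPU or a DeepSpeed install by modeling parameters as a name-to-dtype map (everything below is an illustrative stand-in, not the actual test from this PR):

```python
# Illustrative stand-in for the manual dtype trace: parameters are modeled
# as a name -> dtype-string map instead of real tensors. The helper and
# parameter names are hypothetical.
BF16, FP32 = "torch.bfloat16", "torch.float32"

def cast_with_pins(param_names, patterns):
    """Mimic the bf16 cast path, re-pinning matched names to fp32."""
    return {
        name: FP32 if any(p in name for p in patterns) else BF16
        for name in param_names
    }

dtypes = cast_with_pins(
    ["experts.0.w1", "router.weight", "moe.gate.proj"],
    ["router.weight", "gate."],
)
```

Here the router and gate entries come out as fp32 while the expert weight stays bf16, mirroring the expected behaviour of the real trace.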

Configuration changes

"zero_optimization": {
  "stage": 3,
  "fp32_pinned_parameters": ["router.weight", "gate."]
}
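Consuming the new field from a JSON config could look like the sketch below (illustrative only; the real field is parsed into ZeroConfig, not read by hand like this):

```python
import json

# Sketch of reading the new field from a DeepSpeed-style JSON config;
# the hand parsing is illustrative, only the JSON shape follows this PR.
config_text = """
{
  "zero_optimization": {
    "stage": 3,
    "fp32_pinned_parameters": ["router.weight", "gate."]
  }
}
"""
config = json.loads(config_text)
# Default to an empty list so the feature stays strictly opt-in.
patterns = config["zero_optimization"].get("fp32_pinned_parameters", [])
```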

Files changed

  • deepspeed/runtime/zero/config.py — new fp32_pinned_parameters field in ZeroConfig
  • deepspeed/runtime/zero/partition_parameters.py — mark and re-cast matched params to fp32 in _post_init_method and _convert_to_zero_parameters
  • deepspeed/runtime/bf16_optimizer.py — separate fp32-pinned params from bf16 groups; route them through a dedicated fp32 group in the base optimizer; include their gradients in norm/clip computation
  • deepspeed/runtime/engine.py — log active fp32-pinned patterns at optimizer creation for observability
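The bf16_optimizer.py split described above could look roughly like this (a hedged sketch with stand-in parameter objects; only the `ds_fp32_pinned` marker name is taken from this PR, the helper is hypothetical):

```python
from types import SimpleNamespace

# Hedged sketch of separating fp32-pinned params from the bf16 groups in
# _setup_for_real_optimizer; only the `ds_fp32_pinned` attribute name
# comes from the PR, the rest is illustrative.
def split_param_group(params):
    """Return (bf16_params, fp32_pinned_params) for one optimizer group."""
    bf16_params, fp32_params = [], []
    for p in params:
        target = fp32_params if getattr(p, "ds_fp32_pinned", False) else bf16_params
        target.append(p)
    return bf16_params, fp32_params

# Tiny demo with stand-in parameter objects instead of real tensors.
router = SimpleNamespace(name="router.weight", ds_fp32_pinned=True)
expert = SimpleNamespace(name="experts.0.w1")
bf16_group, fp32_group = split_param_group([router, expert])
```

The fp32 group would then be handed to the base optimizer as its own parameter group so its optimizer states stay in fp32.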

Fixes #7747

When bf16 is enabled with ZeRO stage 3, all model parameters were cast
to bfloat16 inside the `Init` context via the global tensor-creation
wrappers.  This caused MoE router weights (and any other parameters that
require full fp32 precision) to silently lose precision, leading to
incorrect routing decisions and training instability.

A new ZeRO config field `fp32_pinned_parameters` (list of name-pattern
strings) lets users designate parameters that must remain in fp32:

  "zero_optimization": {
    "stage": 3,
    "fp32_pinned_parameters": ["router.weight", "gate."]
  }

Changes:
- `config.py`: add `fp32_pinned_parameters` field to `ZeroConfig`
- `partition_parameters.py`: in `_post_init_method` and
  `_convert_to_zero_parameters`, mark matching params with
  `ds_fp32_pinned = True` and re-cast their data to float32 after
  the bf16 tensor-creation wrappers would have downcast them.
- `bf16_optimizer.py`: in `_setup_for_real_optimizer`, separate
  fp32-pinned params from the normal bf16 groups; add them as a
  dedicated fp32 group in the base optimizer so their states are
  kept in fp32.  Include their gradients in norm/clip computation
  (`get_grads_for_norm`) and clear them alongside bf16 params
  (`clear_lp_grads`).
- `engine.py`: log the active fp32-pinned patterns at BF16 optimizer
  creation time for easier debugging.

Fixes deepspeedai#7747

Development

Successfully merging this pull request may close these issues.

[BUG] MoE router parameters are forced to bf16 under DeepSpeed bf16, causing dtype mismatch in fp32 routing logic
