fix: keep fp32-pinned parameters out of the bf16 cast path in ZeRO-3 (#7747)#7867
Draft
harshang03 wants to merge 1 commit into deepspeedai:master
Conversation
When bf16 is enabled with ZeRO stage 3, all model parameters are cast
to bfloat16 inside the `Init` context via the global tensor-creation
wrappers. This causes MoE router weights (and any other parameters that
require full fp32 precision) to silently lose precision, leading to
incorrect routing decisions and training instability.
A new ZeRO config field `fp32_pinned_parameters` (list of name-pattern
strings) lets users designate parameters that must remain in fp32:
```json
"zero_optimization": {
    "stage": 3,
    "fp32_pinned_parameters": ["router.weight", "gate."]
}
```
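A minimal sketch of the sub-string matching this config implies; the helper name below is hypothetical, and DeepSpeed's internal matching logic may differ:

```python
# Hypothetical helper illustrating sub-string matching against
# `fp32_pinned_parameters`; the actual DeepSpeed internals may differ.
def is_fp32_pinned(param_name, patterns):
    """True if the parameter name contains any pinned pattern."""
    return any(pattern in param_name for pattern in patterns)

patterns = ["router.weight", "gate."]
print(is_fp32_pinned("layers.0.moe.router.weight", patterns))   # True
print(is_fp32_pinned("layers.0.attn.out_proj.weight", patterns))  # False
```

Note that a pattern like `"gate."` matches any name containing that sub-string, so patterns should be chosen carefully to avoid pinning unintended parameters.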
Changes:
- `config.py`: add `fp32_pinned_parameters` field to `ZeroConfig`
- `partition_parameters.py`: in `_post_init_method` and
`_convert_to_zero_parameters`, mark matching params with
`ds_fp32_pinned = True` and re-cast their data to float32 after
the bf16 tensor-creation wrappers would have downcast them.
- `bf16_optimizer.py`: in `_setup_for_real_optimizer`, separate
fp32-pinned params from the normal bf16 groups; add them as a
dedicated fp32 group in the base optimizer so their states are
kept in fp32. Include their gradients in norm/clip computation
(`get_grads_for_norm`) and clear them alongside bf16 params
(`clear_lp_grads`).
- `engine.py`: log the active fp32-pinned patterns at BF16 optimizer
creation time for easier debugging.
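The optimizer-side grouping described above can be sketched with lightweight stand-ins. Only the `ds_fp32_pinned` attribute name comes from this PR; everything else here is illustrative, not DeepSpeed's actual code:

```python
# Illustrative sketch (not DeepSpeed's implementation): split parameters into
# the normal bf16 group and a dedicated fp32 group, keyed off the
# `ds_fp32_pinned` marker this PR sets in partition_parameters.py.
class FakeParam:
    """Hypothetical stand-in for a model parameter."""
    def __init__(self, name, ds_fp32_pinned=False):
        self.name = name
        self.ds_fp32_pinned = ds_fp32_pinned

def split_param_groups(params):
    # Pinned params get their own group so optimizer state stays in fp32.
    bf16_group = [p for p in params if not getattr(p, "ds_fp32_pinned", False)]
    fp32_pinned_group = [p for p in params if getattr(p, "ds_fp32_pinned", False)]
    return bf16_group, fp32_pinned_group

params = [FakeParam("attn.q_proj.weight"),
          FakeParam("moe.router.weight", ds_fp32_pinned=True)]
bf16, pinned = split_param_groups(params)
print([p.name for p in pinned])  # ['moe.router.weight']
```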
Fixes deepspeedai#7747
What this PR does / why we need it
When training with BF16 and ZeRO stage 3, every model parameter is cast to bfloat16 inside the `deepspeed.zero.Init` context via the global tensor-creation wrappers. This affects MoE router weights (and any other parameters that need full fp32 precision), which silently lose precision and can destabilise training or cause incorrect routing.

This PR introduces an opt-in mechanism: `fp32_pinned_parameters`, a list of sub-string patterns in the ZeRO config. Any parameter whose name contains one of these patterns is re-cast to `float32` right after the bf16 wrappers would have downcast it, and is tracked through a dedicated fp32 path in the BF16 optimizer.

How has this been tested?
Manually traced the dtype of a small synthetic MoE model under `deepspeed.zero.Init` with `bf16: {enabled: true}` and verified that matched parameters retain `torch.float32` dtype while unmatched parameters are correctly `torch.bfloat16`.

Configuration changes
Files changed

- `deepspeed/runtime/zero/config.py`: new `fp32_pinned_parameters` field in `ZeroConfig`
- `deepspeed/runtime/zero/partition_parameters.py`: mark and re-cast matched params to fp32 in `_post_init_method` and `_convert_to_zero_parameters`
- `deepspeed/runtime/bf16_optimizer.py`: separate fp32-pinned params from bf16 groups; route them through a dedicated fp32 group in the base optimizer; include their gradients in norm/clip computation
- `deepspeed/runtime/engine.py`: log active fp32-pinned patterns at optimizer creation for observability

Fixes #7747
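The manual dtype check described under "How has this been tested?" can be reproduced with a standalone sketch like the following. This is not the PR's code: the `finalize_param` helper and names are illustrative, emulating the bf16 downcast applied by the tensor-creation wrappers followed by the fp32 re-cast for pinned parameters:

```python
import torch

# Standalone sketch, not DeepSpeed's implementation: a tensor arrives already
# downcast to bf16 (as the tensor-creation wrappers would produce), and pinned
# parameters are re-cast back to fp32, mirroring the behavior this PR adds.
patterns = ["router.weight", "gate."]

def finalize_param(name, tensor):
    """Re-cast fp32-pinned parameters back to float32 after a bf16 downcast."""
    if any(pattern in name for pattern in patterns):
        return tensor.to(torch.float32)
    return tensor

router_w = finalize_param("moe.router.weight", torch.zeros(4, dtype=torch.bfloat16))
proj_w = finalize_param("attn.proj.weight", torch.zeros(4, dtype=torch.bfloat16))
print(router_w.dtype, proj_w.dtype)  # torch.float32 torch.bfloat16
```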