
only add parameter with grads to parameter group #7869

Open
delock wants to merge 4 commits into master from gma/fix_muon_partial_training

Conversation

@delock
Collaborator

@delock delock commented Feb 22, 2026

This PR fixes a bug that occurs when the Muon optimizer is used to train only part of the model parameters.

When training only part of the model parameters (with all others frozen), it can happen that all trainable parameters are assigned to the Muon optimizer and none to AdamW, or vice versa. In that case one of `muon_params` and `non_muon_params` contains only non-trainable parameters, which eventually causes the failure below.

A reasonable fix is to add only parameters that require gradients to `muon_params` and `non_muon_params`; in the case above, one of the parameter groups then becomes empty and is filtered out immediately (see the sketch after the traceback below).

 [rank3]: Traceback (most recent call last):                                                                               
 [rank3]:   File "/home/gma/transfer_qwen/finetune_moe.py", line 904, in <module>                                          
 [rank3]:     main(args)                                                                                                   
 [rank3]:   File "/home/gma/transfer_qwen/finetune_moe.py", line 709, in main
 [rank3]:     model_engine, optimizer, train_dataloader, lr_scheduler = deepspeed.initialize(                              
 [rank3]:                                                               ^^^^^^^^^^^^^^^^^^^^^                              
 [rank3]:   File "/home/gma/DeepSpeed/deepspeed/__init__.py", line 214, in initialize
 [rank3]:     engine = DeepSpeedEngine(args=args,
 [rank3]:              ^^^^^^^^^^^^^^^^^^^^^^^^^^
 [rank3]:   File "/home/gma/DeepSpeed/deepspeed/runtime/engine.py", line 363, in __init__
 [rank3]:     self._configure_optimizer(optimizer, model_parameters)
 [rank3]:   File "/home/gma/DeepSpeed/deepspeed/runtime/engine.py", line 1585, in _configure_optimizer
 [rank3]:     self.optimizer = self._configure_zero_optimizer(basic_optimizer)
 [rank3]:                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 [rank3]:   File "/home/gma/DeepSpeed/deepspeed/runtime/engine.py", line 1893, in _configure_zero_optimizer
 [rank3]:     optimizer = Stage1And2ZeroOptimizer(
 [rank3]:                 ^^^^^^^^^^^^^^^^^^^^^^^^
 [rank3]:   File "/home/gma/DeepSpeed/deepspeed/runtime/zero/stage_1_and_2.py", line 403, in __init__
 [rank3]:     flattened_buffer = self.flatten_dense_tensors_aligned(
 [rank3]:                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 [rank3]:   File "/home/gma/DeepSpeed/deepspeed/runtime/zero/stage_1_and_2.py", line 1044, in flatten_dense_tensors_aligned
 [rank3]:     return self.flatten(align_dense_tensors(tensor_list, alignment))
 [rank3]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 [rank3]:   File "/raid/gma/miniforge3/envs/ds/lib/python3.12/site-packages/torch/_utils.py", line 571, in
 _flatten_dense_tensors
 [rank3]:     return torch._C._nn.flatten_dense_tensors(tensors)
 [rank3]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 [rank3]: ValueError: torch.cat(): expected a non-empty list of Tensors
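
For illustration, here is a minimal sketch of the filtering described above. The helper name, the `use_muon` group flag, and the 2-D/non-2-D split heuristic are assumptions for the example, not DeepSpeed's exact internals:

```python
import torch

def split_trainable_params(model):
    """Illustrative sketch of the fix, not DeepSpeed's actual helper.

    Only parameters that require gradients are placed in either group, so a
    group that would otherwise hold nothing but frozen parameters stays empty
    and can be dropped before ZeRO flattens it.
    """
    muon_params, non_muon_params = [], []
    for p in model.parameters():
        if not p.requires_grad:  # skip frozen parameters entirely
            continue
        # Assumed heuristic: Muon handles 2-D weight matrices; everything
        # else (biases, norms, embeddings) falls back to AdamW.
        if p.ndim >= 2:
            muon_params.append(p)
        else:
            non_muon_params.append(p)

    # Empty groups are filtered out here, so the ZeRO stage 1/2 optimizer
    # never tries to flatten an empty tensor list (the torch.cat() error in
    # the traceback above).
    groups = []
    if muon_params:
        groups.append({"params": muon_params, "use_muon": True})
    if non_muon_params:
        groups.append({"params": non_muon_params, "use_muon": False})
    return groups
```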

Signed-off-by: Ma, Guokai <guokai.ma@intel.com>

@chatgpt-codex-connector chatgpt-codex-connector bot left a comment


💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: f0265ef538


@PKUWZP PKUWZP self-requested a review February 22, 2026 17:43
Collaborator

@PKUWZP PKUWZP left a comment


@delock Thanks so much for the fix. In addition to the changes, can we add a test case that freezes some parameters and verifies deepspeed.initialize completes without error under the Muon optimizer path? Thanks!

@delock
Collaborator Author

delock commented Feb 23, 2026

@delock Thanks so much for the fix. In addition to the changes, can we add a test case that freezes some parameters and verifies deepspeed.initialize completes without error under the Muon optimizer path? Thanks!

@PKUWZP Good suggestion! Let me add the test

@delock delock requested a review from loadams as a code owner February 23, 2026 03:01
@deepspeedai deepspeedai deleted a comment from PawnOfDelock Feb 23, 2026
@delock delock requested a review from PKUWZP February 23, 2026 03:02
@delock
Collaborator Author

delock commented Feb 23, 2026

@delock Thanks so much for the fix. In addition to the changes, can we add a test case that freezes some parameters and verifies deepspeed.initialize completes without error under the Muon optimizer path? Thanks!

A new test has been added; it fails on master with the exact error described above and passes with this fix.
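
For reference, a hedged sketch of the kind of test this adds; the model, config values, optimizer type string, and test name are placeholder assumptions for illustration, not the exact test in this PR:

```python
import torch
import deepspeed

def test_muon_with_partially_frozen_model():
    # A single Linear layer: its 2-D weight would land in the Muon group and
    # its 1-D bias in the AdamW group. Freezing the bias leaves the AdamW
    # group with no trainable parameters, which is the failing case.
    model = torch.nn.Linear(8, 4, bias=True)
    model.bias.requires_grad = False

    ds_config = {
        "train_batch_size": 1,
        "optimizer": {"type": "Muon", "params": {"lr": 1e-3}},
        "zero_optimization": {"stage": 2},
    }

    # On master this raised
    #   ValueError: torch.cat(): expected a non-empty list of Tensors
    # inside the ZeRO stage 1/2 flattening; with the fix it should complete.
    engine, optimizer, _, _ = deepspeed.initialize(
        model=model,
        model_parameters=[p for p in model.parameters() if p.requires_grad],
        config=ds_config,
    )
    assert engine is not None
```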

delock and others added 3 commits February 22, 2026 19:03
Signed-off-by: Ma, Guokai <guokai.ma@intel.com>
Signed-off-by: Ma, Guokai <guokai.ma@intel.com>
@delock
Collaborator Author

delock commented Feb 25, 2026

Hi @PKUWZP, this PR is ready for review again. A test has been added to expose the issue it intends to fix.
