
[BUG] Fix: Fix gradient norm calculation and dynamic shape blocking in PP+ZeRO1 collective communication#7847

Open
Thinksky5124 wants to merge 1 commit into deepspeedai:master from Thinksky5124:master

Conversation

@Thinksky5124

Describe the bug

This commit fixes gradient-norm calculation bugs that appear when DeepSpeed Pipeline Parallelism (PP) is used together with ZeRO Stage 1 (zero1), covering the following aspects:

  1. PipelineEngine Buffer Type Consistency under Dynamic Shapes
    In deepspeed/runtime/pipe/engine.py, the activation buffer previously did not enforce a dtype conversion, which could leave buffers with inconsistent types and cause downstream calculation errors. The return value is now explicitly cast to the target dtype, guaranteeing type consistency.

  2. ZeRO Stage 1/2 Gradient Normalization Logic Correction
    In deepspeed/runtime/zero/stage_1_and_2.py, on both the CPU-offload and regular paths, the previous gradient normalization performed redundant communication and normalized incorrectly:

complete_grad_norm_calculation_for_cpu_offload now only computes the local squared L2 norm without cross-rank communication, avoiding redundant normalization and double counting.

get_grad_norm_direct only supports L2 norm, directly accumulates the local gradient squared sum, and avoids double counting for pipeline parallel parameters.

scaled_global_norm unifies the normalization process: first accumulates the squared sum of gradients for all groups locally, then uses all_reduce to aggregate, and finally takes the square root to obtain the global L2 norm, ensuring consistency with ZeRO design.

  3. Code Redundancy and Exception Handling Optimization
    Removed the inf constant and the unnecessary norm_type branches; only the L2 norm is now supported, which simplifies the logic. Stricter skipping of None gradients and pipeline-replicated parameters improves robustness.

After these fixes, gradient normalization in PP+ZeRO1 scenarios is accurate: double counting, type inconsistency, and redundant communication are avoided, improving training stability and performance. Reviewers are advised to focus on the correctness and compatibility of the gradient normalization code to ensure consistent behavior across the different parallel/offload scenarios.
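The unified flow described above (accumulate local squared sums, all-reduce once, take a single square root) can be sketched as follows. This is a minimal, framework-free simulation: `per_rank_grads` and the summation standing in for the all-reduce are illustrative assumptions, not the actual DeepSpeed/torch.distributed calls.

```python
import math

def local_sq_norm(grads):
    """Local squared L2 norm over this rank's gradient shards; no communication."""
    return sum(g * g for shard in grads for g in shard)

def scaled_global_norm(per_rank_grads):
    """Each rank accumulates its local squared sum, the sums are all-reduced
    (here: a plain Python sum stands in for dist.all_reduce with SUM), and the
    square root is taken exactly once to obtain the global L2 norm."""
    local_sums = [local_sq_norm(grads) for grads in per_rank_grads]
    total_sq = sum(local_sums)  # stand-in for the single all_reduce
    return math.sqrt(total_sq)

# Two ranks holding disjoint gradient shards (ZeRO-style partitioning).
rank0 = [[3.0], [0.0]]
rank1 = [[4.0]]
print(scaled_global_norm([rank0, rank1]))  # 5.0
```

Taking the square root only after the reduction is what avoids the double normalization the old code performed when each rank computed a full norm and the results were reduced again.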

DeepSpeed Config

{
    "train_micro_batch_size_per_gpu": 1,
    "gradient_accumulation_steps": 4,
    "steps_per_print": 1,
    "zero_optimization": {
        "stage": 1
    }
}
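To reproduce the setup, the config above would be passed to deepspeed.initialize; the snippet below only sanity-checks the JSON shown and deliberately avoids importing DeepSpeed.

```python
import json

# The exact config from the bug report.
ds_config = json.loads("""
{
    "train_micro_batch_size_per_gpu": 1,
    "gradient_accumulation_steps": 4,
    "steps_per_print": 1,
    "zero_optimization": {
        "stage": 1
    }
}
""")

# The reported bug requires ZeRO stage 1 (combined with pipeline parallelism).
assert ds_config["zero_optimization"]["stage"] == 1
print(ds_config["gradient_accumulation_steps"])  # 4
```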


@chatgpt-codex-connector chatgpt-codex-connector bot left a comment


💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: 99697687fc


Comment on lines 1997 to +1998
  if self.has_moe_layers:
-     self._average_expert_grad_norms(norm_groups)
+     self._average_expert_grad_norms(local_total_sq_norm)


P1 Badge Preserve per-group norms when averaging MoE expert grads

When has_moe_layers is true, _average_expert_grad_norms expects a per-parameter-group collection (it iterates by group index and checks self.is_moe_param_group[i]), but this call now passes a single accumulated tensor. In runs with multiple optimizer groups, that means MoE scaling is applied to at most index 0 (or to the already-mixed total), so expert and non-expert contributions are mis-scaled before clipping/overflow logic, producing incorrect global grad norms.
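The mismatch can be illustrated with a stripped-down stand-in for _average_expert_grad_norms (the name, the scaling factor, and the averaging itself are simplified assumptions here): because it indexes its argument per parameter group, passing a single accumulated scalar loses the per-group structure the function relies on.

```python
def average_expert_grad_norms(norm_groups, is_moe_param_group, ep_world_size=2):
    """Simplified stand-in: rescale each MoE group's norm by the
    expert-parallel world size, leaving non-MoE groups untouched."""
    out = list(norm_groups)
    for i, norm in enumerate(norm_groups):
        if is_moe_param_group[i]:
            out[i] = norm / ep_world_size  # stand-in for the EP average
    return out

is_moe = [False, True]

# Per-group norms: only the MoE group (index 1) is rescaled.
print(average_expert_grad_norms([3.0, 4.0], is_moe))  # [3.0, 2.0]

# A single accumulated total carries no group index, so per-group logic
# applied to it rescales the already-mixed sum (or nothing) incorrectly.
```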



local_total_sq_norm = local_total_sq_norm.to(torch.cuda.current_device())


P1 Badge Avoid hard-coding CUDA device for norm reduction

This forces scaled_global_norm onto torch.cuda.current_device() even though the optimizer code is written against DeepSpeed's accelerator abstraction. On non-CUDA backends (or CPU-only execution), this line raises before all_reduce, so gradient norm computation and optimizer step fail outright; the tensor should stay on self.device or use get_accelerator().current_device_name().
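A portable pattern looks like the sketch below. get_accelerator().current_device_name() is DeepSpeed's actual accelerator abstraction, but the import guard and CPU fallback are assumptions added so the snippet runs without DeepSpeed or CUDA installed.

```python
def reduction_device_name():
    """Pick the device for the norm all-reduce via DeepSpeed's accelerator
    abstraction when available, falling back to CPU otherwise."""
    try:
        from deepspeed.accelerator import get_accelerator
        return get_accelerator().current_device_name()
    except Exception:
        # CPU-only fallback for this sketch; the point is to avoid
        # hard-coding torch.cuda.current_device().
        return "cpu"

# local_total_sq_norm = local_total_sq_norm.to(reduction_device_name())
print(reduction_device_name())
```

Keeping the tensor on self.device, as the review suggests, is equally valid; either choice removes the hard CUDA dependency.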

