
fix: correct DistributedAttention output shape and pad uneven sequence lengths (#7842)#7868

Draft
harshang03 wants to merge 1 commit into deepspeedai:master from harshang03:fix/7842-dist-attn-uneven-seq

Conversation

@harshang03

What this PR does / why we need it

DistributedAttention produced numerically incorrect results whenever the global sequence length was not evenly divisible by the sequence-parallel world size. Two bugs were found and fixed:

Bug 1 – Wrong output shape formula (silent corruption even with divisible lengths)

In _generate_layout_params, the post_all2all_res_shape for the batch_dim_idx=1, scatter_idx < 2 path (seq-first layout, output all-to-all) was computed as [bs, seq_world_size * global_seq_len, num_local_head // seq_world_size, head_dim]. After the permute, the tensor is [S/P, B, P, H/P, D]; collapsing the two head dimensions must yield [S/P, B, H, D], not the formula above. This caused the returned tensor to have the right number of elements but the wrong shape, silently misinterpreting data.
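The shape mismatch can be reproduced with plain shape arithmetic; a minimal sketch with hypothetical sizes (not taken from the PR), where `num_local_head = H/P`:

```python
import math

# Hypothetical sizes for illustration only.
S, B, H, D, P = 8, 2, 4, 16, 2      # global seq, batch, total heads, head dim, SP world size
num_local_head = H // P              # heads held per rank before the output all-to-all

# After the all-to-all + permute, the tensor is [S/P, B, P, H/P, D].
permuted_shape = (S // P, B, P, num_local_head, D)

# Collapsing the two head dimensions must yield [S/P, B, H, D] (the fix) ...
fixed_shape = (S // P, B, P * num_local_head, D)
assert fixed_shape == (S // P, B, H, D)

# ... whereas the old formula gives a different layout with the same
# element count, so no error is raised and the data is silently misread.
old_shape = (B, P * S, num_local_head // P, D)
assert old_shape != fixed_shape
assert math.prod(old_shape) == math.prod(fixed_shape)
```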

Bug 2 – Unequal local sequence lengths crash dist.all_to_all_single

When S % P != 0, different ranks hold different numbers of tokens (ceil(S/P) vs floor(S/P)). dist.all_to_all_single with implicit equal splits then silently reads/writes beyond allocated buffers, producing garbage outputs with no exception raised.
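The split mismatch can be illustrated in pure Python with a hypothetical S=10, P=4 (the collective itself is not invoked here):

```python
import math

# Hypothetical example: S = 10 tokens over P = 4 sequence-parallel ranks.
S, P = 10, 4

# Local lengths without padding: ranks disagree (3 vs 2 tokens).
local_lens = [S // P + (1 if r < S % P else 0) for r in range(P)]
assert local_lens == [3, 3, 2, 2]

# dist.all_to_all_single with no split sizes assumes every rank contributes
# tensor.size(0) // P elements, so unequal local lengths make ranks
# read/write mismatched buffer regions with no error raised.

# Padding every rank to the max local length restores the equal-split invariant.
max_local = math.ceil(S / P)
padded = [max_local] * P
assert all(l == padded[0] for l in padded)
assert P * max_local - S == 2    # two padding tokens total in this example
```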

Fix

  • deepspeed/sequence/layer.py: correct post_all2all_res_shape for the batch_dim_idx=1, scatter_idx<2 case; add DistributedAttention._pad_to_seq_world_size(); in forward(), call _get_max_local_seq_len() to align all ranks to the same local sequence length, pad Q/K/V before the all-to-all, and slice the output back to the original length. F.pad is differentiable so gradients flow correctly without extra ctx bookkeeping.
  • deepspeed/utils/groups.py: add _get_max_local_seq_len(local_seq_len, group) — a single allreduce(MAX) so that all ranks can agree on the padded sequence length without multiple round trips.
  • deepspeed/sequence/fpdt_layer.py: replace the hard assert global_seq_len % sp_size == 0 in FPDT_InputConstruct.__init__ with graceful zero-padding of tokens/labels/loss_mask/position_ids to the nearest multiple of sp_size.
  • deepspeed/sequence/__init__.py: export DistributedAttention and _SeqAllToAll for easier downstream imports.
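The pad-then-slice round trip can be sketched in a single process with the collectives stubbed out. `_get_max_local_seq_len` and the pad/slice steps follow the description above, but the bodies here are illustrative stand-ins, not the PR's actual implementation:

```python
def _get_max_local_seq_len(all_local_lens):
    # In the real helper this is a single dist.all_reduce(..., op=MAX);
    # here a plain max() stands in for the collective.
    return max(all_local_lens)

def pad_and_slice(tokens, max_len, pad_value=0):
    # Stand-in for F.pad on the sequence dimension plus the final slice;
    # F.pad is differentiable, so gradients flow through both steps.
    original_len = len(tokens)
    padded = tokens + [pad_value] * (max_len - original_len)
    # ... all-to-all, local attention, output all-to-all run on `padded` ...
    return padded[:original_len]

# Hypothetical per-rank token shards for S = 10, P = 4.
ranks_tokens = [[1, 2, 3], [4, 5, 6], [7, 8], [9, 10]]
max_len = _get_max_local_seq_len([len(t) for t in ranks_tokens])
assert max_len == 3
# Padding then slicing is an identity on the real tokens of every rank.
assert all(pad_and_slice(t, max_len) == t for t in ranks_tokens)
```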

Fixes #7842

…e lengths

Two related bugs caused numerically incorrect results in DistributedAttention
when the global sequence length is not evenly divisible by the sequence-parallel
world size (scatter_idx < 2, batch_dim_idx == 1 path):

1. Wrong post_all2all_res_shape for seq-first layout (batch_dim_idx=1,
   scatter_idx < 2).  After the all-to-all + permute the tensor is
   [S/P, B, P, H/P, D]; collapsing the two head dims must give
   [S/P, B, H, D] not [B, P*S, H/P^2, D] as previously computed.
   This caused silent data corruption even for evenly-divisible sequences.

2. Unequal local sequence lengths across ranks (S % P != 0) mean that
   dist.all_to_all_single with implicit equal splits silently reads/writes
   out of bounds, producing garbage values without raising an exception.

Fix:
- Correct post_all2all_res_shape in _generate_layout_params for the
  batch_dim_idx=1, scatter_idx<2 case (layer.py).
- Add _get_max_local_seq_len() to groups.py: a single allreduce(MAX) that
  lets every rank agree on the largest local sequence length.
- In DistributedAttention.forward, call _get_max_local_seq_len() and pad
  Q/K/V on the sequence dimension to max_local_seq_len before the all-to-all.
  Slice the output back to the original local_seq_len after the output
  all-to-all.  F.pad is differentiable, so backward gradients are handled
  correctly without explicit ctx bookkeeping (layer.py).
- In FPDT_InputConstruct.__init__, pad global_seq_len to the nearest
  multiple of sp_size before computing chunk sizes, replacing the hard
  assert with a graceful pad (fpdt_layer.py).
- Export DistributedAttention and _SeqAllToAll from the sequence package
  __init__.py for easier import by downstream users.

Fixes deepspeedai#7842
@PKUWZP self-requested a review February 22, 2026 18:10
Collaborator

@PKUWZP left a comment

@harshang03 Thanks so much for fixing the bugs. Here are my comments:

I understand that this PR tries to fix two bugs: 1. fix the tensor shape; 2. pad the uneven sequences.

Bug 1 — Shape Fix (Correct)

I think the fix at layer.py:46-49 is correct. Tracing the tensor shapes:

  • Input: [global_seq_len, bs, num_local_head, head_dim]
  • After all-to-all + permute (1,2,0,3,4): [global_seq_len//P, bs, P, num_local_head, head_dim]
  • Collapse last two head dims → [global_seq_len//P, bs, P * num_local_head, head_dim]

The old formula [bs, P * global_seq_len, num_local_head // P, head_dim] was incorrect in every dimension. The fix correctly changes this to
[global_seq_len // P, bs, P * num_local_head, head_dim].

@sfc-gh-truwase @tohtana feel free to chime in and confirm. I think it's from the FPDT paper in MLSys (https://arxiv.org/pdf/2408.16978)?


Bug 2 — Uneven Sequence Padding (I feel the fix for this bug has some issues)

The current fix performs an allreduce(MAX) to find the largest local sequence length across ranks; Q/K/V are zero-padded to that length before the all-to-all, and the output is sliced back afterward. @harshang03 feel free to chime in if my understanding is incorrect.

A couple of potential issues:

  1. Zero-padding corrupts attention for real tokens. After the first all-to-all, each rank sees the full (padded) sequence. During local attention, padding tokens participate as keys/values:
    - Zero keys → dot(q, 0) = 0 → exp(0) = 1 in the softmax, so padding positions receive non-trivial attention weight.
    - Zero values → contribute nothing to the weighted sum, but the extra exp(0) terms inflate the softmax normalization denominator, shrinking the weights of real tokens. This can make outputs for real (non-padded) tokens numerically different from a non-SP baseline.
I suggest one of the following:
   - Pass an attention mask through to the local attention to mask out padded positions, or
   - Document this as a known approximation that becomes negligible for long sequences.
  2. Allreduce on every forward pass (_get_max_local_seq_len). This introduces a global synchronization point on every forward call. In practice, local sequence lengths are typically constant across a training run. Consider caching the result or allowing users to pass the max length directly to avoid this overhead.
  3. Backward pass correctness. The output slicing output[:, :local_seq_len] (or output[:local_seq_len]) creates a view, so gradients flow back correctly through the slice. However, the zero-padded positions receive gradients from the attention computation, meaning F.pad's backward will produce gradients that reflect the (incorrect) attention over padding tokens. This may cause minor training instabilities. @harshang03 Not sure if I understand it correctly.

FPDT Padding (Has Issues)

  1. attention_mask is not padded (fpdt_layer.py). tokens, labels, loss_mask, and position_ids are all padded, but attention_mask is left unchanged. If attention_mask has a sequence-length dimension (which it often does), this will cause a shape mismatch downstream.
  2. Second assert may still cause problems. After padding, global_seq_len % sp_size == 0 is satisfied, but the next assertion global_seq_len % args.ds_sequence_parallel_fpdt_chunk_size == 0 could still fail if the padded length isn't divisible by chunk_size. The PR only addresses the first divisibility constraint.
  3. raw_global_seq_len stored but never used. self.raw_global_seq_len is set but never referenced — the generate() method doesn't trim back to the original length, so padded positions will flow through training, including into the loss (mitigated only if loss_mask zeros them out).
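One possible remedy for point 2, sketched with hypothetical numbers: pad to a multiple of lcm(sp_size, chunk_size) so both divisibility asserts are satisfied at once (this helper is illustrative, not from the PR):

```python
import math

def padded_len(global_seq_len, sp_size, chunk_size):
    # Pad to the nearest multiple of lcm(sp_size, chunk_size) so that both
    # global_seq_len % sp_size == 0 and
    # global_seq_len % chunk_size == 0 hold after padding.
    multiple = math.lcm(sp_size, chunk_size)
    return math.ceil(global_seq_len / multiple) * multiple

# Hypothetical numbers: padding 1000 tokens to sp_size=4 alone gives 1000,
# which still violates a chunk_size of 256; the lcm-based pad does not.
L = padded_len(1000, 4, 256)
assert L % 4 == 0 and L % 256 == 0
assert L == 1024
```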

Other Issues

  1. Unused imports in fpdt_layer.py. Both DistributedAttention and _get_max_local_seq_len are imported but never used in the FPDT changes:
    from .layer import single_all_to_all, apply_rotary_pos_emb, DistributedAttention  # DistributedAttention unused
    from deepspeed.utils.groups import _get_max_local_seq_len  # unused
  2. Exporting a private symbol. _SeqAllToAll (prefixed with _) is exported from __init__.py, effectively promoting it to a public API. If downstream users aren't expected to use it directly, it shouldn't be exported.
  3. No tests. This is a draft PR, but the fix is non-trivial and touches a correctness-critical code path. Tests should cover:
    - The shape fix (Bug 1) with batch_dim_idx=1, scatter_idx < 2
    - Uneven sequence lengths (S % P != 0) producing correct outputs
    - Round-trip correctness: verify SP output matches non-SP baseline

I suggest we separate this PR into two PRs. The first one covers the tensor shape fix (which is straightforward); the second one fixes the uneven sequence padding bugs.
