fix: correct DistributedAttention output shape and pad uneven sequence lengths (#7842)#7868
harshang03 wants to merge 1 commit into deepspeedai:master from
Conversation
…e lengths

Two related bugs caused numerically incorrect results in DistributedAttention when the global sequence length is not evenly divisible by the sequence-parallel world size (scatter_idx < 2, batch_dim_idx == 1 path):

1. Wrong post_all2all_res_shape for the seq-first layout (batch_dim_idx=1, scatter_idx < 2). After the all-to-all + permute the tensor is [S/P, B, P, H/P, D]; collapsing the two head dims must give [S/P, B, H, D], not [B, P*S, H/P^2, D] as previously computed. This caused silent data corruption even for evenly divisible sequences.
2. Unequal local sequence lengths across ranks (S % P != 0) mean that dist.all_to_all_single with implicit equal splits silently reads/writes out of bounds, producing garbage values without raising an exception.

Fix:
- Correct post_all2all_res_shape in _generate_layout_params for the batch_dim_idx=1, scatter_idx<2 case (layer.py).
- Add _get_max_local_seq_len() to groups.py: a single allreduce(MAX) that lets every rank agree on the largest local sequence length.
- In DistributedAttention.forward, call _get_max_local_seq_len() and pad Q/K/V on the sequence dimension to max_local_seq_len before the all-to-all; slice the output back to the original local_seq_len after the output all-to-all. F.pad is differentiable, so backward gradients are handled correctly without explicit ctx bookkeeping (layer.py).
- In FPDT_InputConstruct.__init__, pad global_seq_len to the nearest multiple of sp_size before computing chunk sizes, replacing the hard assert with a graceful pad (fpdt_layer.py).
- Export DistributedAttention and _SeqAllToAll from the sequence package __init__.py for easier import by downstream users.

Fixes deepspeedai#7842
@harshang03 Thanks so much for fixing the bugs. Here are my comments:
I understand that this PR tries to fix two bugs: 1) correcting the tensor shape; 2) padding uneven sequences.
Bug 1 — Shape Fix (Correct)
I think the fix at layer.py:46-49 is correct. Tracing the tensor shapes:
- Input: `[global_seq_len, bs, num_local_head, head_dim]`
- After all-to-all + permute `(1,2,0,3,4)`: `[global_seq_len//P, bs, P, num_local_head, head_dim]`
- Collapse last two head dims → `[global_seq_len//P, bs, P * num_local_head, head_dim]`
The old formula `[bs, P * global_seq_len, num_local_head // P, head_dim]` was incorrect in every dimension. The fix correctly changes this to `[global_seq_len // P, bs, P * num_local_head, head_dim]`.
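The trace above can be checked on a single process by simulating the all-to-all with a reshape (a minimal sketch; the dimension values are arbitrary):

```python
import torch

# Single-process sketch of the seq-first output path (batch_dim_idx=1,
# scatter_idx < 2); the all-to-all is simulated by a reshape since one
# process holds all the data.
S, B, H, D, P = 8, 2, 12, 4, 4

x = torch.randn(S, B, H // P, D)        # per-rank input: [S, B, H/P, D]
t = x.reshape(P, S // P, B, H // P, D)  # split the seq dim for the all-to-all
t = t.permute(1, 2, 0, 3, 4)            # [S/P, B, P, H/P, D]
out = t.reshape(S // P, B, H, D)        # collapse the two head dims

print(out.shape)   # torch.Size([2, 2, 12, 4])
```

Note that reshaping `out` to the old formula's shape would actually succeed (the element counts coincide) but scramble the data layout, which is why the bug was silent.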
@sfc-gh-truwase @tohtana feel free to chime in and confirm. I think it's from the FPDT paper in MLSys (https://arxiv.org/pdf/2408.16978)?
Bug 2 — Uneven Sequence Padding (I think the fix has some issues)
The current fix does an allreduce(MAX) to find the largest local sequence length across ranks; Q/K/V are zero-padded to that length before the all-to-all, and the output is sliced back afterward. @harshang03 feel free to chime in if my understanding is incorrect.
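For context on why the local lengths can differ in the first place: with the usual split, the first `S % P` ranks hold one extra token. A sketch (the helper name is mine, not DeepSpeed code):

```python
def local_seq_lens(S, P):
    """Per-rank token counts when S % P != 0: the first S % P ranks
    get ceil(S/P) tokens, the rest floor(S/P). Illustrative only."""
    base, rem = divmod(S, P)
    return [base + 1 if r < rem else base for r in range(P)]

print(local_seq_lens(10, 4))   # [3, 3, 2, 2]
```

The allreduce(MAX) in the fix simply brings every rank up to the largest of these values.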
A couple of potential issues:
- Zero-padding corrupts attention for real tokens. After the first all-to-all, each rank sees the full (padded) sequence. During local attention, padding tokens participate as keys/values:
  - Zero keys → `dot(q, 0) = 0` → `exp(0) = 1` in softmax → non-trivial attention weight
  - Zero values → contribute zero to the weighted sum, but the padding keys inflate the softmax normalization denominator. This could make the outputs for real (non-padded) tokens numerically different from a non-SP baseline.
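The denominator effect can be demonstrated numerically (a self-contained sketch, not DeepSpeed code):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
q = torch.randn(1, 4, 8)            # [batch, real_seq_len, dim]
k = torch.randn(1, 4, 8)
v = torch.randn(1, 4, 8)

def attn(q, k, v):
    scores = q @ k.transpose(-2, -1) / (q.shape[-1] ** 0.5)
    return F.softmax(scores, dim=-1) @ v

baseline = attn(q, k, v)

# Zero-pad K/V on the sequence dim by 2 positions, as the PR does before
# the all-to-all. Each padding key scores 0 -> exp(0) = 1 in every row's
# softmax denominator, shrinking the weights on real tokens.
pad = (0, 0, 0, 2)                  # F.pad order: last dim first
padded = attn(q, F.pad(k, pad), F.pad(v, pad))

print(torch.allclose(baseline, padded))   # False: real-token outputs shifted
```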
I suggest a few things:
- Pass an attention mask through to the local attention to mask out padded positions, or
- Document this as a known approximation that becomes negligible for long sequences
- Allreduce every forward pass (`_get_max_local_seq_len`). This introduces a global synchronization point on every forward call. In practice, local sequence lengths are typically constant across a training run. Consider caching the result or allowing users to pass the max length directly to avoid this overhead.
- Backward pass correctness. The output slicing `output[:, :local_seq_len]` (or `output[:local_seq_len]`) creates a view, so gradients flow back correctly through the slice. However, the zero-padded positions receive gradients from the attention computation, meaning `F.pad`'s backward will produce gradients that reflect the (incorrect) attention over padding tokens. This may cause minor training instabilities. @harshang03 Not sure if I understand it correctly.
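The gradient-flow claim can be checked directly; this sketch substitutes a toy mixing op for attention:

```python
import torch
import torch.nn.functional as F

x = torch.randn(4, 8, requires_grad=True)   # [local_seq_len, dim]
padded = F.pad(x, (0, 0, 0, 2))             # pad seq dim: 4 -> 6 rows

# Stand-in for attention: every output row mixes all rows, so padded
# rows influence real rows' outputs (the concern raised above).
out = padded.sum(dim=0, keepdim=True) + padded

# Slice back to the original length before the loss, as the PR does.
out[:4].sum().backward()

# F.pad's backward drops the gradient at the padded rows; x only sees
# gradients through the slice, so shapes stay consistent.
print(x.grad.shape)        # torch.Size([4, 8])
```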
FPDT Padding (Has Issues)
- attention_mask is not padded (`fpdt_layer.py`). `tokens`, `labels`, `loss_mask`, and `position_ids` are all padded, but `attention_mask` is left unchanged. If `attention_mask` has a sequence-length dimension (which it often does), this will cause a shape mismatch downstream.
- Second assert may still cause problems. After padding, `global_seq_len % sp_size == 0` is satisfied, but the next assertion `global_seq_len % args.ds_sequence_parallel_fpdt_chunk_size == 0` could still fail if the padded length isn't divisible by chunk_size. The PR only addresses the first divisibility constraint.
- `raw_global_seq_len` stored but never used. `self.raw_global_seq_len` is set but never referenced — the `generate()` method doesn't trim back to the original length, so padded positions will flow through training, including into the loss (mitigated only if `loss_mask` zeros them out).
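On the second assert: padding to a multiple of `lcm(sp_size, chunk_size)` instead of `sp_size` alone would satisfy both divisibility constraints. A sketch (the helper name is mine):

```python
import math

def pad_to_multiple(n, m):
    """Round n up to the nearest multiple of m; return (padded, pad_amount).
    Illustrative helper, not the PR's code."""
    padded = math.ceil(n / m) * m
    return padded, padded - n

# Padding to sp_size alone only satisfies the first assert; padding to
# lcm(sp_size, chunk_size) would satisfy both.
print(pad_to_multiple(1001, 8))                 # (1008, 7)
print(pad_to_multiple(1001, math.lcm(8, 256)))  # (1024, 23)
```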
Other Issues
- Unused imports in `fpdt_layer.py`. Both `DistributedAttention` and `_get_max_local_seq_len` are imported but never used in the FPDT changes:
  `from .layer import single_all_to_all, apply_rotary_pos_emb, DistributedAttention` (`DistributedAttention` unused)
  `from deepspeed.utils.groups import _get_max_local_seq_len` (unused)
- Exporting a private symbol. `_SeqAllToAll` (prefixed with `_`) is exported from `__init__.py`, effectively promoting it to a public API. If downstream users aren't expected to use it directly, it shouldn't be exported.
- No tests. This is a draft PR, but the fix is non-trivial and touches a correctness-critical code path. Tests should cover:
- The shape fix (Bug 1) with `batch_dim_idx=1, scatter_idx < 2`
- Uneven sequence lengths (`S % P != 0`) producing correct outputs
- Round-trip correctness: verify SP output matches non-SP baseline
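A single-process stand-in for the uneven-length round-trip check might look like this (illustrative; a real test would run under `torch.distributed`):

```python
import torch
import torch.nn.functional as F

# Uneven split (S % P != 0): pad each rank's local chunk to the max local
# length, then slice back and verify the original sequence is recovered.
S, P = 10, 4
x = torch.arange(S, dtype=torch.float)
chunks = list(torch.tensor_split(x, P))          # lengths [3, 3, 2, 2]

max_len = max(c.numel() for c in chunks)
padded = [F.pad(c, (0, max_len - c.numel())) for c in chunks]
restored = torch.cat([p[:c.numel()] for p, c in zip(padded, chunks)])

print(torch.equal(restored, x))   # True
```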
I suggest we separate this PR into two PRs: the first covering the tensor shape fix (which is straightforward), and the second fixing the uneven sequence padding bugs.
What this PR does / why we need it
`DistributedAttention` produced numerically incorrect results whenever the global sequence length was not evenly divisible by the sequence-parallel world size. Two bugs were found and fixed:

Bug 1 – Wrong output shape formula (silent corruption even with divisible lengths)
In `_generate_layout_params`, the `post_all2all_res_shape` for the `batch_dim_idx=1, scatter_idx < 2` path (seq-first layout, output all-to-all) was computed as `[bs, seq_world_size * global_seq_len, num_local_head // seq_world_size, head_dim]`. After the permute, the tensor is `[S/P, B, P, H/P, D]`; collapsing the two head dimensions must yield `[S/P, B, H, D]`, not the formula above. This caused the returned tensor to have the right number of elements but the wrong shape, silently misinterpreting data.

Bug 2 – Unequal local sequence lengths crash `dist.all_to_all_single`

When `S % P != 0`, different ranks hold different numbers of tokens (`ceil(S/P)` vs `floor(S/P)`). `dist.all_to_all_single` with implicit equal splits then silently reads/writes beyond allocated buffers, producing garbage outputs with no exception raised.

Fix
- `deepspeed/sequence/layer.py`: correct `post_all2all_res_shape` for the `batch_dim_idx=1, scatter_idx<2` case; add `DistributedAttention._pad_to_seq_world_size()`; in `forward()`, call `_get_max_local_seq_len()` to align all ranks to the same local sequence length, pad Q/K/V before the all-to-all, and slice the output back to the original length. `F.pad` is differentiable so gradients flow correctly without extra ctx bookkeeping.
- `deepspeed/utils/groups.py`: add `_get_max_local_seq_len(local_seq_len, group)` — a single `allreduce(MAX)` so that all ranks can agree on the padded sequence length without multiple round trips.
- `deepspeed/sequence/fpdt_layer.py`: replace the hard `assert global_seq_len % sp_size == 0` in `FPDT_InputConstruct.__init__` with graceful zero-padding of tokens/labels/loss_mask/position_ids to the nearest multiple of `sp_size`.
- `deepspeed/sequence/__init__.py`: export `DistributedAttention` and `_SeqAllToAll` for easier downstream imports.

Fixes #7842
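A minimal sketch of the pad-then-slice idea described above (illustrative names; not the PR's exact `_pad_to_seq_world_size` implementation):

```python
import torch
import torch.nn.functional as F

def pad_then_slice(x, max_len, seq_dim=0):
    """Zero-pad the sequence dim of x up to max_len and remember the
    original length so the output can be sliced back after the
    all-to-all. Hypothetical helper for illustration."""
    orig = x.shape[seq_dim]
    if orig < max_len:
        # F.pad's pad tuple runs from the last dim backwards.
        pad = [0, 0] * (x.dim() - seq_dim - 1) + [0, max_len - orig]
        x = F.pad(x, pad)
    return x, orig

x = torch.randn(5, 2, 4, 8)            # [local_seq_len, bs, heads, dim]
padded, orig = pad_then_slice(x, 8)    # pad 5 -> 8 on the seq dim
restored = padded[:orig]               # slice back after the collective

print(padded.shape, torch.equal(restored, x))
```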