Introduce all_reduce_hook to support gradient aggregation across replica groups.#7764
zhengchenyu wants to merge 3 commits into deepspeedai:master
Conversation
Introduce all_reduce_hook to support gradient aggregation across replica groups. Signed-off-by: zhengchenyu <zhengchenyu16@163.com>
@zhengchenyu thanks for the PR. Can you provide some clarification for the motivation?
We already provide a form of this functionality in hpZ component of ZeRO++. Have you explored whether hpZ would meet your needs?
My understanding is that replica groups are only relevant for ZeRO stage 3, since the lower stages don't do parameter partitioning. Can you explain how replica groups arise in your workload?
@sfc-gh-truwase Thanks for your review!
Regarding ZeRO++: it cannot solve problem (2). It can solve problem (1), but there is a cost involved, since we must introduce extra …

Regarding MICS: for ZeRO stage 3, these two problems do not exist. For stages 1/2, problem (1) does not arise, but if the optimizer parameters are considered when loading the checkpoint, problem (2) still occurs.
Thanks for sharing more details.
@sfc-gh-truwase Thanks for your reply.

Yes, I agree that existing options like …
Using replica groups offers the following advantages:

- For stage 3, it ensures that parameter gathers during forward and backward occur only within the replica group.
- Checkpointing is performed only on `replica_group_rank=0`, guaranteeing a constant checkpoint world size and avoiding the universal-checkpoint transformations when scaling up or down.

Without a hook, we could all-reduce gradients within the replica group after backward and before `optimizer.step`, but we would have to wait for all buckets to complete and thus could not leverage any concurrency advantage.
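To make the idea concrete, here is a minimal, torch-free sketch of the two building blocks discussed above: partitioning ranks into replica groups and averaging a gradient bucket within one group. The names `build_replica_groups` and `all_reduce_hook` are hypothetical illustrations, not the PR's actual API; a real implementation would call `torch.distributed.all_reduce` on the process group for each bucket as it becomes ready, which is what preserves per-bucket overlap.

```python
def build_replica_groups(world_size: int, replica_group_size: int):
    """Partition ranks into contiguous replica groups, e.g. 8 ranks with
    group size 4 -> [[0, 1, 2, 3], [4, 5, 6, 7]]. (Hypothetical helper.)"""
    assert world_size % replica_group_size == 0
    return [list(range(start, start + replica_group_size))
            for start in range(0, world_size, replica_group_size)]


def all_reduce_hook(grads_by_rank: dict, group: list):
    """Average one gradient bucket across the ranks of a single replica
    group, in place. Stands in for dist.all_reduce(bucket, group=pg)
    followed by division by the group size; firing this per ready bucket
    is what lets communication overlap with the remaining backward pass."""
    n = len(group)
    bucket_len = len(grads_by_rank[group[0]])
    averaged = [sum(grads_by_rank[rank][i] for rank in group) / n
                for i in range(bucket_len)]
    for rank in group:
        grads_by_rank[rank] = list(averaged)
```

For example, with `build_replica_groups(8, 4)` the all-reduce for ranks 0–3 never touches ranks 4–7, mirroring the claim that gathers and reductions stay inside the replica group.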
I know MICS has similar functionality, but it currently only supports ZeRO stage 3. Additionally, I want to use this feature for compatibility with architectures like TorchFT.