Fix DDP checkpoint loading by using model.module.load_state_dict #437
Description
Summary
This PR fixes a common DistributedDataParallel (DDP) checkpoint loading error in multi-GPU setups by modifying the state_dict loading logic to use model.module.load_state_dict() instead of model.load_state_dict(). This ensures compatibility with checkpoints saved without the "module." prefix (e.g., from single-GPU or non-DDP runs). Additionally, it updates checkpoint saving to always strip the DDP prefix via model.module.state_dict(), making saved files portable across single- and multi-GPU environments. It also adds time.sleep(5) before checkpoint loading to keep distributed processes in sync, preventing race conditions where non-rank-0 processes attempt to load before the file is fully written.
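A minimal sketch of the intended save/load logic is shown below. The helper names (save_checkpoint, load_checkpoint) and the checkpoint dict layout ({"model": ...}) are illustrative assumptions, not the exact code in this PR.

```python
import time

import torch
from torch.nn.parallel import DistributedDataParallel as DDP


def save_checkpoint(model, path):
    # Always save the unwrapped weights so the file carries no "module."
    # prefix and stays portable between single-GPU and DDP runs.
    raw_model = model.module if isinstance(model, DDP) else model
    torch.save({"model": raw_model.state_dict()}, path)


def load_checkpoint(model, path, device="cpu"):
    # Short pause so rank 0 can finish writing the file before other ranks
    # read it (the synchronization approach described in this PR).
    time.sleep(5)
    checkpoint = torch.load(path, map_location=device)
    state_dict = checkpoint["model"]
    # Load into the underlying module when the model is DDP-wrapped, so
    # prefix-free checkpoints from single-GPU or non-DDP runs load cleanly.
    target = model.module if isinstance(model, DDP) else model
    target.load_state_dict(state_dict)
    return checkpoint
```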
Fixed Issue

RuntimeError: Error(s) in loading state_dict for DistributedDataParallel: Missing key(s) in state_dict during evaluation or resume in distributed mode.

Motivation and Context
PyTorch's DDP wraps the model, so its state_dict parameter keys carry a "module." prefix for multi-GPU synchronization. If checkpoints are saved without this prefix (common in RF-DETR's default trainer), loading them into a DDP-wrapped model fails. This is a frequent pain point in distributed DETR variants (see the PyTorch docs on Saving and Loading Models and community discussions such as this Stack Overflow thread). The changes make RF-DETR's checkpoint handling DDP-aware without breaking single-GPU usage.
Dependencies

Type of change

Bug fix (non-breaking change which fixes an issue)
How has this change been tested, please provide a testcase or example of how you tested the change?
Tested on a multi-GPU setup (2x Tesla V100s via torchrun --nproc_per_node=2) with RF-DETR segmentation fine-tuning:

Reproduce Error (Pre-Fix):

- Save a checkpoint from a single-GPU/non-DDP run (checkpoint_best_total.pth, without the "module." prefix).
- Run torchrun --nproc_per_node=2 main.py --run_test --resume checkpoint_best_total.pth.
- Hit the RuntimeError on key mismatch (missing "module." prefixed keys).

Verify Fix (Post-Merge):

- Apply the changes to main.py (load/save hooks around line 502 and the checkpoint callbacks).
- Re-run the same command on 2 GPUs and on a single GPU (nproc_per_node=1): no prefix errors in either case.

Full test script snippet:
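A minimal stand-in for that test script is sketched below; the toy model and the temporary checkpoint path are placeholders for the real RF-DETR model and checkpoint_best_total.pth, and it runs on CPU with a single-process gloo group so the behavior can be checked anywhere.

```python
import os
import tempfile

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

# Stand-in model; in the real test this is the RF-DETR model built by main.py.
model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.ReLU(), torch.nn.Linear(8, 2))

# Simulate a single-GPU/non-DDP run: checkpoint saved without the "module." prefix.
ckpt_path = os.path.join(tempfile.mkdtemp(), "checkpoint_best_total.pth")
torch.save({"model": model.state_dict()}, ckpt_path)

# Simulate the distributed resume (single process, gloo backend, CPU).
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29502")
dist.init_process_group(backend="gloo", rank=0, world_size=1)
ddp_model = DDP(model)

state = torch.load(ckpt_path, map_location="cpu")["model"]

# Pre-fix behavior: loading into the wrapper raises the missing-key error.
try:
    ddp_model.load_state_dict(state)
except RuntimeError as e:
    print("Reproduced:", str(e).splitlines()[0])

# Post-fix behavior: loading into the underlying module succeeds.
ddp_model.module.load_state_dict(state)
print("Fix verified: prefix-free checkpoint loaded into DDP-wrapped model.")

dist.destroy_process_group()
```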
Ran on PyTorch 2.1.0, CUDA 12.1; no regressions in non-DDP mode.
Any specific deployment considerations
- Consider noting the --master_port flag in docs for cluster runs to avoid port conflicts.
- Loading now goes through the underlying model (model.module); new saves are prefix-free for broader compatibility.

Docs