From 1864a657b76e65860a10d0efb734b83f72a03e35 Mon Sep 17 00:00:00 2001 From: CedricHwong <997630814@qq.com> Date: Fri, 26 Dec 2025 07:59:44 +0000 Subject: [PATCH 1/3] Fix MSE calibration amax sync in distributed Signed-off-by: CedricHwong <997630814@qq.com> --- modelopt/torch/quantization/model_calib.py | 92 ++++++++++++++++++- .../quantization/test_mse_calibrate_sync.py | 83 +++++++++++++++++ 2 files changed, 174 insertions(+), 1 deletion(-) create mode 100644 tests/gpu/torch/quantization/test_mse_calibrate_sync.py diff --git a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py index d4cf249fe..97a944e91 100644 --- a/modelopt/torch/quantization/model_calib.py +++ b/modelopt/torch/quantization/model_calib.py @@ -266,7 +266,97 @@ def quant_func(x, amax, quantizer=module): # Step 4: Compute optimal amax and load it finish_stats_collection(model, method="mse") - # TODO: Sync amax across distributed processes + if not distributed_sync: + return + + def sync_quantizer_amax_across_dp_ep(quantizer, parallel_state): + """Synchronize the amax across all ranks in the data parallel and expert parallel groups.""" + if isinstance(quantizer, SequentialQuantizer): + for _q in quantizer: + sync_quantizer_amax_across_dp_ep(_q, parallel_state) + return + if getattr(quantizer, "_amax", None) is not None: + quantizer.sync_amax_across_distributed_group(parallel_state.data_parallel_group) + quantizer.sync_amax_across_distributed_group(parallel_state.expert_model_parallel_group) + # TODO: create sync_bias_across_distributed_group + + for name, module in model.named_modules(): + if isinstance(module, QuantModule): + for child in module.children(): + if isinstance(child, (TensorQuantizer, SequentialQuantizer)): + sync_quantizer_amax_across_dp_ep(child, module.parallel_state) + + def sync_quantizer_amax_across_tp( + quantizer: TensorQuantizer | SequentialQuantizer, + linear_name: str, + quantizer_type: str, + axes_for_sync: list, + parallel_state: ParallelState, + ): + # Syncing amax across TP for sequential quantizer + if isinstance(quantizer, SequentialQuantizer): + for _q in quantizer: + # Syncing amax across TP for sequential quantizer + sync_quantizer_amax_across_tp( + _q, linear_name, quantizer_type, axes_for_sync, parallel_state + ) + return + # sync is not needed for block quantization + if quantizer.block_sizes is not None: + if hasattr(quantizer, "_padding"): + warnings.warn( + f"Found block-quantized padded {quantizer_type} for {linear_name}, amax will" + " not be synced correctly." 
+                )
+            # Skip amax sync for INT4 / W4A8 block quantization
+            # Sync amax for NVFP4 (dynamic per-block, static per-tensor quantized scale)
+            if quantizer.block_sizes.get("type", None) != "dynamic":
+                return
+
+        if quantizer.axis in axes_for_sync and quantizer.amax is not None:
+            quantizer.sync_amax_across_distributed_group(parallel_state.tensor_parallel_group)
+
+    for name, module in model.named_modules():
+        if getattr(module, "_parallel_state", None) is None:
+            continue
+
+        if is_quantized_column_parallel_linear(module):
+            sync_quantizer_amax_across_tp(
+                module.input_quantizer,
+                name,
+                "input_quantizer",
+                axes_for_sync=[None, -1],
+                parallel_state=module.parallel_state,
+            )
+
+            sync_quantizer_amax_across_tp(
+                module.weight_quantizer,
+                name,
+                "weight_quantizer",
+                axes_for_sync=[None, -1],
+                parallel_state=module.parallel_state,
+            )
+
+        if is_quantized_row_parallel_linear(module):
+            sync_quantizer_amax_across_tp(
+                module.input_quantizer,
+                name,
+                "input_quantizer",
+                axes_for_sync=[None],
+                parallel_state=module.parallel_state,
+            )
+
+            sync_quantizer_amax_across_tp(
+                module.weight_quantizer,
+                name,
+                "weight_quantizer",
+                axes_for_sync=[None, 0],
+                parallel_state=module.parallel_state,
+            )
+
+    for name, module in model.named_modules():
+        if hasattr(module, "sync_moe_local_experts_amax"):
+            module.sync_moe_local_experts_amax()
 
 
 def enable_stats_collection(model: nn.Module):
diff --git a/tests/gpu/torch/quantization/test_mse_calibrate_sync.py b/tests/gpu/torch/quantization/test_mse_calibrate_sync.py
new file mode 100644
index 000000000..79798db9b
--- /dev/null
+++ b/tests/gpu/torch/quantization/test_mse_calibrate_sync.py
@@ -0,0 +1,83 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+ +import copy +from functools import partial + +import pytest +import torch +import torch.distributed as dist +import torch.nn as nn +from _test_utils.torch.distributed.utils import get_device_counts, spawn_multiprocess_job + +import modelopt.torch.quantization as mtq + + +def _test_mse_calibrate_sync(distributed_sync: bool, rank: int, size: int) -> None: + model = nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 16)).cuda() + + config = copy.deepcopy(mtq.INT8_DEFAULT_CFG) + config["algorithm"] = { + "method": "mse", + "num_steps": 16, + "start_multiplier": 0.001, + "stop_multiplier": 4.0, + "distributed_sync": distributed_sync, + } + + def forward_loop(model): + torch.manual_seed(1234 + rank) + scale = 1.0 if rank == 0 else 100.0 + for _ in range(4): + model(torch.randn(64, 16, device="cuda") * scale) + + model = mtq.quantize(model, config, forward_loop) + + target = next(module for module in model.modules() if hasattr(module, "input_quantizer")) + amax_val = target.input_quantizer.amax.detach().float().max() + + gathered = [torch.zeros_like(amax_val) for _ in range(size)] + dist.all_gather(gathered, amax_val) + + if size < 2 or rank != 0: + return + + values = torch.stack(gathered) + if distributed_sync: + assert torch.allclose(values, values[0], rtol=0, atol=0), ( + "Expected amax values to be synchronized across ranks, but got " + f"{values.tolist()}" + ) + else: + assert (values.max() - values.min()) > 10.0, ( + "Expected amax values to differ across ranks when sync is disabled, but got " + f"{values.tolist()}" + ) + + +@pytest.mark.parametrize("device_count", get_device_counts()) +def test_mse_calibrate_with_sync(device_count): + spawn_multiprocess_job( + size=device_count, job=partial(_test_mse_calibrate_sync, True), backend="nccl" + ) + + +@pytest.mark.parametrize("device_count", get_device_counts()) +def test_mse_calibrate_without_sync(device_count): + if device_count < 2: + pytest.skip("need 2 GPUs") + spawn_multiprocess_job( + size=device_count, job=partial(_test_mse_calibrate_sync, False), backend="nccl" + ) From 06d9b0c93f7147eae20cc3548abcc2a626a4b34a Mon Sep 17 00:00:00 2001 From: CedricHwong <997630814@qq.com> Date: Fri, 26 Dec 2025 11:22:05 +0000 Subject: [PATCH 2/3] Update changelog for MSE calibration sync Signed-off-by: CedricHwong <997630814@qq.com> --- CHANGELOG.rst | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 826b51160..cfbf20981 100755 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -15,6 +15,10 @@ NVIDIA Model Optimizer Changelog (Linux) - Add support to enable custom emulated quantization backend. See :meth:`register_quant_backend `` for more details. See an example in ``tests/unit/torch/quantization/test_custom_backend.py``. - Add ``examples/llm_qad`` for QAD training with Megatron-LM. +**Bug Fixes** + +- Synchronize MSE calibration amax across distributed groups (DP/EP/TP) to keep quantization parameters consistent. + **Deprecations** - Deprecate ``num_query_groups`` parameter in Minitron pruning (``mcore_minitron``). You can use ModelOpt 0.40.0 or earlier instead if you need to prune it. 
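
The changelog entry above summarizes PATCH 1/3 at a high level. The sketch below illustrates what "synchronize amax across a distributed group" amounts to, assuming the existing TensorQuantizer.sync_amax_across_distributed_group helper performs an in-place MAX all-reduce over the given process group; the standalone function name and the bare `group` argument are illustrative placeholders, not APIs added by these patches.

    import torch
    import torch.distributed as dist

    def sync_amax_max_reduce(amax: torch.Tensor, group) -> torch.Tensor:
        # Each rank contributes the amax it calibrated on its local shard of
        # data; after the MAX all-reduce every rank holds the same (largest)
        # value, so the derived quantization scales match across DP/EP/TP ranks.
        dist.all_reduce(amax, op=dist.ReduceOp.MAX, group=group)
        return amax
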
From 41180e9462a0dfcc6f384eb3ef7b479de73e3d05 Mon Sep 17 00:00:00 2001 From: CedricHwong <997630814@qq.com> Date: Fri, 26 Dec 2025 17:14:29 +0000 Subject: [PATCH 3/3] Sync bias across distributed calibration Signed-off-by: CedricHwong <997630814@qq.com> --- modelopt/torch/quantization/model_calib.py | 6 +- .../nn/modules/tensor_quantizer.py | 48 +++++++ .../quantization/test_mse_calibrate_sync.py | 122 ++++++++++++++++++ 3 files changed, 174 insertions(+), 2 deletions(-) diff --git a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py index 97a944e91..355289094 100644 --- a/modelopt/torch/quantization/model_calib.py +++ b/modelopt/torch/quantization/model_calib.py @@ -93,7 +93,8 @@ def sync_quantizer_amax_across_dp_ep(quantizer, parallel_state): if getattr(quantizer, "_amax", None) is not None: quantizer.sync_amax_across_distributed_group(parallel_state.data_parallel_group) quantizer.sync_amax_across_distributed_group(parallel_state.expert_model_parallel_group) - # TODO: create sync_bias_across_distributed_group + quantizer.sync_bias_across_distributed_group(parallel_state.data_parallel_group) + quantizer.sync_bias_across_distributed_group(parallel_state.expert_model_parallel_group) for name, module in model.named_modules(): if isinstance(module, QuantModule): @@ -278,7 +279,8 @@ def sync_quantizer_amax_across_dp_ep(quantizer, parallel_state): if getattr(quantizer, "_amax", None) is not None: quantizer.sync_amax_across_distributed_group(parallel_state.data_parallel_group) quantizer.sync_amax_across_distributed_group(parallel_state.expert_model_parallel_group) - # TODO: create sync_bias_across_distributed_group + quantizer.sync_bias_across_distributed_group(parallel_state.data_parallel_group) + quantizer.sync_bias_across_distributed_group(parallel_state.expert_model_parallel_group) for name, module in model.named_modules(): if isinstance(module, QuantModule): diff --git a/modelopt/torch/quantization/nn/modules/tensor_quantizer.py b/modelopt/torch/quantization/nn/modules/tensor_quantizer.py index de7407f70..a7b07adae 100644 --- a/modelopt/torch/quantization/nn/modules/tensor_quantizer.py +++ b/modelopt/torch/quantization/nn/modules/tensor_quantizer.py @@ -1169,6 +1169,54 @@ def sync_amax_across_distributed_group(self, parallel_group: DistributedProcessG "if happening during modelopt restore." 
) + def sync_bias_across_distributed_group(self, parallel_group: DistributedProcessGroup): + """Synchronize the bias across all ranks in the given group.""" + if not parallel_group.is_initialized(): + return + if self.bias_calibrator is None or self.bias_type != "static": + return + + bias = self.bias_calibrator.compute_bias() + if bias is None: + return + + try: + if self.bias_method == "mean": + cnt = float(getattr(self.bias_calibrator, "_cnt", 0)) + bias_sum = bias.float() * cnt + cnt_tensor = torch.tensor(cnt, device=bias_sum.device, dtype=bias_sum.dtype) + dist.all_reduce(bias_sum, op=dist.ReduceOp.SUM, group=parallel_group.group) + dist.all_reduce(cnt_tensor, op=dist.ReduceOp.SUM, group=parallel_group.group) + if cnt_tensor.item() > 0: + bias_avg = (bias_sum / cnt_tensor).to(bias.dtype) + else: + bias_avg = bias + self.bias_value = bias_avg + self.bias_calibrator._calib_bias = bias_avg.detach().clone() + self.bias_calibrator._cnt = int(cnt_tensor.item()) + elif self.bias_method == "max_min": + calib_max = getattr(self.bias_calibrator, "_calib_max", None) + calib_min = getattr(self.bias_calibrator, "_calib_min", None) + if calib_max is None: + calib_max = torch.full_like(bias, -float("inf")) + if calib_min is None: + calib_min = torch.full_like(bias, float("inf")) + dist.all_reduce(calib_max, op=dist.ReduceOp.MAX, group=parallel_group.group) + dist.all_reduce(calib_min, op=dist.ReduceOp.MIN, group=parallel_group.group) + bias_val = ((calib_max + calib_min) / 2).to(bias.dtype) + self.bias_value = bias_val + self.bias_calibrator._calib_max = calib_max.detach().clone() + self.bias_calibrator._calib_min = calib_min.detach().clone() + self.bias_calibrator._calib_bias = bias_val.detach().clone() + else: + warnings.warn(f"Unsupported bias method: {self.bias_method}; skipping bias sync.") + except RuntimeError as e: + warnings.warn( + f"Failed to synchronize bias: {e}, probably because the tensor is on a device which is not" + "supported by the current distributed backend. This warning can be ignored" + "if happening during modelopt restore." 
+ ) + @contextlib.contextmanager def disable_pre_quant_scale(self): """Context manager to turn off pre_quant_scale inside this quantizer.""" diff --git a/tests/gpu/torch/quantization/test_mse_calibrate_sync.py b/tests/gpu/torch/quantization/test_mse_calibrate_sync.py index 79798db9b..9e0b5b2ee 100644 --- a/tests/gpu/torch/quantization/test_mse_calibrate_sync.py +++ b/tests/gpu/torch/quantization/test_mse_calibrate_sync.py @@ -67,6 +67,96 @@ def forward_loop(model): ) +def _test_mse_calibrate_bias_sync(distributed_sync: bool, rank: int, size: int) -> None: + for bias_method in ["mean", "max_min"]: + model = nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 16)).cuda() + + config = copy.deepcopy(mtq.INT8_DEFAULT_CFG) + config["quant_cfg"]["*input_quantizer"]["bias"] = { + 0: None, + "type": "static", + "method": bias_method, + } + config["algorithm"] = { + "method": "mse", + "num_steps": 16, + "start_multiplier": 0.001, + "stop_multiplier": 4.0, + "distributed_sync": distributed_sync, + } + + def forward_loop(model): + torch.manual_seed(4321 + rank) + offset = 0.0 if rank == 0 else 10.0 + for _ in range(4): + model(torch.randn(64, 16, device="cuda") * 0.1 + offset) + + model = mtq.quantize(model, config, forward_loop) + + target = next(module for module in model.modules() if hasattr(module, "input_quantizer")) + bias_val = target.input_quantizer.bias_value.detach().float().mean() + + gathered = [torch.zeros_like(bias_val) for _ in range(size)] + dist.all_gather(gathered, bias_val) + + if size < 2 or rank != 0: + continue + + values = torch.stack(gathered) + if distributed_sync: + assert torch.allclose(values, values[0], rtol=0, atol=0), ( + f"Expected bias values to be synchronized across ranks for {bias_method}, but got " + f"{values.tolist()}" + ) + else: + assert (values.max() - values.min()) > 5.0, ( + f"Expected bias values to differ across ranks for {bias_method} when sync is disabled, " + f"but got {values.tolist()}" + ) + + +def _test_max_calibrate_bias_sync(distributed_sync: bool, rank: int, size: int) -> None: + for bias_method in ["mean", "max_min"]: + model = nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 16)).cuda() + + config = copy.deepcopy(mtq.INT8_DEFAULT_CFG) + config["quant_cfg"]["*input_quantizer"]["bias"] = { + 0: None, + "type": "static", + "method": bias_method, + } + config["algorithm"] = {"method": "max", "distributed_sync": distributed_sync} + + def forward_loop(model): + torch.manual_seed(9876 + rank) + offset = 0.0 if rank == 0 else 10.0 + for _ in range(4): + model(torch.randn(64, 16, device="cuda") * 0.1 + offset) + + model = mtq.quantize(model, config, forward_loop) + + target = next(module for module in model.modules() if hasattr(module, "input_quantizer")) + bias_val = target.input_quantizer.bias_value.detach().float().mean() + + gathered = [torch.zeros_like(bias_val) for _ in range(size)] + dist.all_gather(gathered, bias_val) + + if size < 2 or rank != 0: + continue + + values = torch.stack(gathered) + if distributed_sync: + assert torch.allclose(values, values[0], rtol=0, atol=0), ( + f"Expected bias values to be synchronized across ranks for {bias_method}, but got " + f"{values.tolist()}" + ) + else: + assert (values.max() - values.min()) > 5.0, ( + f"Expected bias values to differ across ranks for {bias_method} when sync is disabled, " + f"but got {values.tolist()}" + ) + + @pytest.mark.parametrize("device_count", get_device_counts()) def test_mse_calibrate_with_sync(device_count): spawn_multiprocess_job( @@ -81,3 +171,35 @@ def 
test_mse_calibrate_without_sync(device_count): spawn_multiprocess_job( size=device_count, job=partial(_test_mse_calibrate_sync, False), backend="nccl" ) + + +@pytest.mark.parametrize("device_count", get_device_counts()) +def test_mse_calibrate_bias_with_sync(device_count): + spawn_multiprocess_job( + size=device_count, job=partial(_test_mse_calibrate_bias_sync, True), backend="nccl" + ) + + +@pytest.mark.parametrize("device_count", get_device_counts()) +def test_mse_calibrate_bias_without_sync(device_count): + if device_count < 2: + pytest.skip("need 2 GPUs") + spawn_multiprocess_job( + size=device_count, job=partial(_test_mse_calibrate_bias_sync, False), backend="nccl" + ) + + +@pytest.mark.parametrize("device_count", get_device_counts()) +def test_max_calibrate_bias_with_sync(device_count): + spawn_multiprocess_job( + size=device_count, job=partial(_test_max_calibrate_bias_sync, True), backend="nccl" + ) + + +@pytest.mark.parametrize("device_count", get_device_counts()) +def test_max_calibrate_bias_without_sync(device_count): + if device_count < 2: + pytest.skip("need 2 GPUs") + spawn_multiprocess_job( + size=device_count, job=partial(_test_max_calibrate_bias_sync, False), backend="nccl" + )
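
The tests above show how the `distributed_sync` option is exercised. For reference, a minimal standalone sketch of the same flow outside the test harness follows; the toy model, tensor sizes, and calibration loop mirror the test code, and a torch.distributed process group is assumed to be initialized already (e.g. via torchrun) so the post-calibration all-reduces have a group to run over.

    import copy

    import torch
    import torch.nn as nn

    import modelopt.torch.quantization as mtq

    # Toy model, mirroring the unit test; any quantizable model works the same way.
    model = nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 16)).cuda()

    config = copy.deepcopy(mtq.INT8_DEFAULT_CFG)
    config["algorithm"] = {
        "method": "mse",
        "num_steps": 16,
        "start_multiplier": 0.001,
        "stop_multiplier": 4.0,
        # With distributed_sync=True, the MSE-calibrated amax (and, after
        # PATCH 3/3, any static bias) is reconciled across the data, expert
        # and tensor parallel groups once calibration finishes.
        "distributed_sync": True,
    }

    def forward_loop(model):
        # Each rank calibrates on its own (possibly differently scaled) data.
        for _ in range(4):
            model(torch.randn(64, 16, device="cuda"))

    model = mtq.quantize(model, config, forward_loop)
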