Skip to content

Commit 89266fb

Browse files
committed
Fix batchnorm inference, also for batch_size > 1
1 parent 44ea99a commit 89266fb

File tree

3 files changed

+14
-7
lines changed

3 files changed

+14
-7
lines changed

README.md

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,9 +34,9 @@ Currently supported and tested models from [onnx_zoo](https://github.com/onnx/mo
3434

3535
## Limitations
3636
Known current version limitations are:
37-
- `batch_size > 1` could deliver unexpected results due to ambiguity of onnx's BatchNorm layer.
38-
That is why in this case for now we raise an assertion error.
39-
Set `experimental=True` in `ConvertModel` to be able to use `batch_size > 1`.
37+
- `batch_size > 1` requires `experimental=True` in `ConvertModel`.
38+
BatchNorm layers use inference mode (running statistics), which is correct for most ONNX models
39+
exported for inference. If your model was exported in training mode, results may differ.
4040
- Fine tuning and training of converted models was not tested yet, only inference.
4141

4242
## Development

onnx2pytorch/convert/model.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -96,8 +96,9 @@ def __init__(
9696
batch_dim: int
9797
Dimension of the batch.
9898
experimental: bool
99-
Experimental implementation allows batch_size > 1. However,
100-
batchnorm layers could potentially produce false outputs.
99+
Experimental implementation allows batch_size > 1.
100+
BatchNorm layers use inference mode (running statistics), which is
101+
correct for most ONNX models exported for inference.
101102
enable_pruning: bool
102103
Track kept/pruned indices between different calls to forward pass.
103104
@@ -145,8 +146,11 @@ def __init__(
145146

146147
if experimental:
147148
warnings.warn(
148-
"Using experimental implementation that allows 'batch_size > 1'."
149-
"Batchnorm layers could potentially produce false outputs."
149+
"Using experimental implementation that allows 'batch_size > 1'. "
150+
"BatchNorm layers use inference mode (running statistics). "
151+
"This is correct for ONNX models exported for inference, but may be "
152+
"incorrect if the model was exported in training mode.",
153+
UserWarning,
150154
)
151155

152156
def forward(self, *input_list, **input_dict):

onnx2pytorch/operations/batchnorm.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,9 @@ def __init__(self, torch_params, *args, **kwargs):
4848
for key, value in zip(keys, torch_params):
4949
getattr(self.bnu, key).data = value
5050

51+
# Set to eval mode to use running statistics (ONNX inference behavior)
52+
self.bnu.eval()
53+
5154
def forward(self, X, scale=None, B=None, input_mean=None, input_var=None):
5255
if self.has_lazy:
5356
self.bnu.initialize_parameters(X)

0 commit comments

Comments
 (0)