
Commit 3cfc8f8

Remove experimental flag, since batchnorm is fixed.
1 parent 741b2ea

2 files changed (+4 -19 lines)

README.md

Lines changed: 3 additions & 3 deletions
@@ -34,9 +34,9 @@ Currently supported and tested models from [onnx_zoo](https://github.com/onnx/mo

 ## Limitations
 Known current version limitations are:
-- `batch_size > 1` requires `experimental=True` in `ConvertModel`.
-  BatchNorm layers use inference mode (running statistics), which is correct for most ONNX models
-  exported for inference. If your model was exported in training mode, results may differ.
+- `batch_size > 1` is now supported by default.
+  BatchNorm layers use inference mode (running statistics), which is correct for ONNX models
+  exported for inference.
 - Fine tuning and training of converted models was not tested yet, only inference.

 ## Development
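
With the flag gone, batched inference is expected to work out of the box. A minimal usage sketch (assuming `ConvertModel` is importable from the package top level and takes a loaded ONNX model as its first argument; the file name and input shape are placeholders):

```python
import onnx
import torch
from onnx2pytorch import ConvertModel  # assumed top-level import

# Load an ONNX model exported for inference (path is a placeholder).
onnx_model = onnx.load("model.onnx")
pytorch_model = ConvertModel(onnx_model)  # no experimental=True required anymore

pytorch_model.eval()  # BatchNorm layers use running statistics
with torch.no_grad():
    batch = torch.randn(8, 3, 224, 224)  # batch_size > 1, supported by default
    outputs = pytorch_model(batch)
```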

onnx2pytorch/convert/model.py

Lines changed: 1 addition & 16 deletions
@@ -96,9 +96,7 @@ def __init__(
         batch_dim: int
             Dimension of the batch.
         experimental: bool
-            Experimental implementation allows batch_size > 1.
-            BatchNorm layers use inference mode (running statistics), which is
-            correct for most ONNX models exported for inference.
+            At the moment it does not do anything anymore. Default: False
         enable_pruning: bool
             Track kept/pruned indices between different calls to forward pass.

@@ -144,15 +142,6 @@ def __init__(
             self.onnx_model.graph, self, self.mapping
         )

-        if experimental:
-            warnings.warn(
-                "Using experimental implementation that allows 'batch_size > 1'. "
-                "BatchNorm layers use inference mode (running statistics). "
-                "This is correct for ONNX models exported for inference, but may be "
-                "incorrect if the model was exported in training mode.",
-                UserWarning,
-            )
-
     def forward(self, *input_list, **input_dict):
         if len(input_list) > 0 and len(input_dict) > 0:
             raise ValueError(
@@ -164,10 +153,6 @@ def forward(self, *input_list, **input_dict):
         if len(input_dict) > 0:
             inputs = [input_dict[key] for key in self.input_names]

-        if not self.experimental and inputs[0].shape[self.batch_dim] > 1:
-            raise NotImplementedError(
-                "Input with larger batch size than 1 not supported yet."
-            )
         activations = dict(zip(self.input_names, inputs))
         still_needed_by = deepcopy(self.needed_by)

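
The removed guard in `forward` was the last place the `experimental` flag had any effect. Dropping it is safe because of the BatchNorm behaviour described in the README change: in inference mode a BatchNorm layer normalizes every sample with its stored running statistics, so one sample's output does not depend on the rest of the batch. A self-contained illustration in plain PyTorch (not code from this repository):

```python
import torch
import torch.nn as nn

# In eval mode, BatchNorm uses running_mean/running_var instead of batch
# statistics, so each sample is normalized independently of its batch mates.
bn = nn.BatchNorm2d(3)
bn.eval()

x = torch.randn(4, 3, 8, 8)
one_by_one = torch.cat([bn(x[i:i + 1]) for i in range(4)])  # batch_size = 1, four times
batched = bn(x)                                             # batch_size = 4 at once
assert torch.allclose(one_by_one, batched, atol=1e-6)
```

This mirrors the commit message: with BatchNorm fixed to inference mode, the `batch_size > 1` restriction in `ConvertModel.forward` is no longer needed.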
