@@ -96,9 +96,7 @@ def __init__(
9696 batch_dim: int
9797 Dimension of the batch.
9898 experimental: bool
99- Experimental implementation allows batch_size > 1.
100- BatchNorm layers use inference mode (running statistics), which is
101- correct for most ONNX models exported for inference.
99+ Deprecated: this flag no longer has any effect. Default: False
102100 enable_pruning: bool
103101 Track kept/pruned indices between different calls to forward pass.
104102
@@ -144,15 +142,6 @@ def __init__(
144142 self .onnx_model .graph , self , self .mapping
145143 )
146144
147- if experimental :
148- warnings .warn (
149- "Using experimental implementation that allows 'batch_size > 1'. "
150- "BatchNorm layers use inference mode (running statistics). "
151- "This is correct for ONNX models exported for inference, but may be "
152- "incorrect if the model was exported in training mode." ,
153- UserWarning ,
154- )
155-
156145 def forward (self , * input_list , ** input_dict ):
157146 if len (input_list ) > 0 and len (input_dict ) > 0 :
158147 raise ValueError (
@@ -164,10 +153,6 @@ def forward(self, *input_list, **input_dict):
164153 if len (input_dict ) > 0 :
165154 inputs = [input_dict [key ] for key in self .input_names ]
166155
167- if not self .experimental and inputs [0 ].shape [self .batch_dim ] > 1 :
168- raise NotImplementedError (
169- "Input with larger batch size than 1 not supported yet."
170- )
171156 activations = dict (zip (self .input_names , inputs ))
172157 still_needed_by = deepcopy (self .needed_by )
173158
0 commit comments