Skip to content

Commit 741b2ea

Browse files
committed
Add tests for batchnorm layer.
1 parent 89266fb commit 741b2ea

File tree

1 file changed

+307
-0
lines changed

1 file changed

+307
-0
lines changed
Lines changed: 307 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,307 @@
1+
import numpy as np
2+
import onnxruntime as ort
3+
import pytest
4+
import torch
5+
from onnx import helper, TensorProto
6+
7+
from onnx2pytorch.convert import ConvertModel
8+
from onnx2pytorch.operations.batchnorm import BatchNormWrapper
9+
10+
11+
@pytest.mark.parametrize(
    "batch_size,channels,height,width,epsilon,momentum",
    [
        # batch_size == 1
        (1, 3, 5, 5, 1e-5, 0.9),
        # batch_size > 1 (the critical case)
        (2, 3, 5, 5, 1e-5, 0.9),
        (4, 3, 5, 5, 1e-5, 0.9),
        (8, 16, 7, 7, 1e-5, 0.9),
        # different epsilons
        (2, 3, 5, 5, 1e-3, 0.9),
        (2, 3, 5, 5, 1e-7, 0.9),
        # different momentums
        (2, 3, 5, 5, 1e-5, 0.1),
        (2, 3, 5, 5, 1e-5, 0.99),
        # different channel counts / spatial dimensions
        (2, 8, 10, 10, 1e-5, 0.9),
        (2, 16, 3, 3, 1e-5, 0.9),
    ],
)
def test_batchnorm_onnxruntime(batch_size, channels, height, width, epsilon, momentum):
    """Compare the converted BatchNormalization node with onnxruntime.

    A single-node ONNX graph (opset 11) is built from seeded random
    parameters, executed with onnxruntime to obtain the reference output,
    then converted with ``ConvertModel(..., experimental=True)`` and run on
    the same input. Both outputs must agree within rtol=atol=1e-5.
    """
    np.random.seed(42)
    torch.manual_seed(42)

    io_shape = [batch_size, channels, height, width]

    def draw(*dims):
        # One float32 sample from the seeded NumPy generator.
        return np.random.randn(*dims).astype(np.float32)

    # The draw order (X, scale, bias, mean, var) determines the values each
    # tensor receives from the seeded generator, so it must not be reordered.
    X = draw(*io_shape)
    scale = draw(channels)
    bias = draw(channels)
    mean = draw(channels)
    var = np.abs(draw(channels)) + 0.1  # keep the variance strictly positive

    # Graph initializers, in the order the node consumes them.
    initializers = [
        helper.make_tensor(name, TensorProto.FLOAT, [channels], values.tolist())
        for name, values in (
            ("scale", scale),
            ("B", bias),
            ("mean", mean),
            ("var", var),
        )
    ]

    node = helper.make_node(
        "BatchNormalization",
        inputs=["X", "scale", "B", "mean", "var"],
        outputs=["Y"],
        epsilon=epsilon,
        momentum=momentum,
    )
    graph = helper.make_graph(
        [node],
        "batchnorm_test",
        [helper.make_tensor_value_info("X", TensorProto.FLOAT, io_shape)],
        [helper.make_tensor_value_info("Y", TensorProto.FLOAT, io_shape)],
        initializers,
    )
    model = helper.make_model(
        graph, opset_imports=[helper.make_opsetid("", 11)], ir_version=8
    )

    # Reference result from onnxruntime.
    session = ort.InferenceSession(model.SerializeToString())
    reference = session.run(None, {"X": X})[0]

    # Result from the converted PyTorch module.
    converted = ConvertModel(model, experimental=True)
    with torch.no_grad():
        actual = converted(torch.from_numpy(X))

    torch.testing.assert_close(
        actual,
        torch.from_numpy(reference),
        rtol=1e-5,
        atol=1e-5,
        msg=f"BatchNorm mismatch for batch_size={batch_size}, channels={channels}",
    )
100+
101+
102+
def test_batchnorm_bias_fix():
    """Test that the bias parameter is correctly applied (not overwritten by scale).

    A BatchNormalization node with scale=2, bias=5, mean=0 and var=1 is run
    through onnxruntime (ground truth) and through the converted PyTorch
    model. Two checks are made:

    1. The converted output matches onnxruntime within rtol=atol=1e-5.
    2. The per-channel output mean equals ``2 * mean(X) + 5``, which only
       holds when the bias (5) rather than the scale (2) is used as the
       additive term.
    """
    np.random.seed(42)

    batch_size = 2
    channels = 4
    height, width = 5, 5
    io_shape = [batch_size, channels, height, width]

    X = np.random.randn(*io_shape).astype(np.float32)

    # Distinct scale and bias so that mixing them up changes the output.
    scale_value = 2.0
    bias_value = 5.0  # must NOT end up equal to the scale
    scale = np.ones(channels, dtype=np.float32) * scale_value
    bias = np.ones(channels, dtype=np.float32) * bias_value
    mean = np.zeros(channels, dtype=np.float32)
    var = np.ones(channels, dtype=np.float32)

    # Create ONNX model
    input_tensor = helper.make_tensor_value_info("X", TensorProto.FLOAT, io_shape)
    output_tensor = helper.make_tensor_value_info("Y", TensorProto.FLOAT, io_shape)

    scale_init = helper.make_tensor(
        "scale", TensorProto.FLOAT, [channels], scale.tolist()
    )
    bias_init = helper.make_tensor("B", TensorProto.FLOAT, [channels], bias.tolist())
    mean_init = helper.make_tensor("mean", TensorProto.FLOAT, [channels], mean.tolist())
    var_init = helper.make_tensor("var", TensorProto.FLOAT, [channels], var.tolist())

    bn_node = helper.make_node(
        "BatchNormalization",
        inputs=["X", "scale", "B", "mean", "var"],
        outputs=["Y"],
        epsilon=1e-5,
    )

    graph = helper.make_graph(
        [bn_node],
        "batchnorm_bias_test",
        [input_tensor],
        [output_tensor],
        [scale_init, bias_init, mean_init, var_init],
    )

    model = helper.make_model(
        graph, opset_imports=[helper.make_opsetid("", 11)], ir_version=8
    )

    # Run with onnxruntime (ground truth)
    ort_session = ort.InferenceSession(model.SerializeToString())
    expected_Y = ort_session.run(None, {"X": X})[0]

    # Convert to PyTorch
    o2p_model = ConvertModel(model, experimental=True)
    X_torch = torch.from_numpy(X)

    with torch.no_grad():
        o2p_output = o2p_model(X_torch)

    # If bias was incorrectly set to scale (the bug), outputs would differ
    torch.testing.assert_close(
        o2p_output,
        torch.from_numpy(expected_Y),
        rtol=1e-5,
        atol=1e-5,
        msg="Bias parameter was not correctly applied",
    )

    # With mean=0 and var=1 the layer reduces to
    #   Y = (X - 0) / sqrt(1 + eps) * 2 + 5  ~=  2 * X + 5,
    # so each channel's output mean must be 2 * mean(X) + 5. If the scale (2)
    # were used as the bias, the result would be off by 3.
    #
    # The tolerance is deliberately tight: a check against 5.0 with rtol=1 and
    # atol=1 accepts any mean within 1 + 1 * 5 = 6 of 5.0, which would also
    # accept the buggy value of about 2 and therefore verify nothing.
    channel_axes = (0, 2, 3)
    expected_mean = scale_value * X_torch.mean(dim=channel_axes) + bias_value
    torch.testing.assert_close(
        o2p_output.mean(dim=channel_axes),
        expected_mean,
        rtol=0.0,
        atol=1e-3,  # absorbs the eps term and float32 accumulation error
        msg="Per-channel output mean does not reflect the bias",
    )
182+
183+
184+
def test_batchnorm_eval_mode():
    """Test that BatchNorm uses eval mode (running statistics).

    Checks that the wrapped module reports ``training == False``, that the
    output keeps the input shape, and that the output for a sample does not
    depend on the other samples in the batch. The last property only holds
    when normalization uses the stored running statistics instead of
    statistics computed from the current batch.
    """
    # Seed so that a failure can be reproduced.
    torch.manual_seed(42)

    channels = 4
    scale = torch.ones(channels)
    bias = torch.zeros(channels)
    running_mean = torch.randn(channels)
    running_var = torch.abs(torch.randn(channels)) + 0.1

    # Create BatchNormWrapper
    bn_wrapper = BatchNormWrapper([scale, bias, running_mean, running_var])

    # Verify it's in eval mode
    assert not bn_wrapper.bnu.training, "BatchNorm should be in eval mode"

    # Test with batch_size > 1
    X = torch.randn(4, channels, 5, 5)

    with torch.no_grad():
        output = bn_wrapper(X)
        # Same first sample, but fed alone: batch statistics would differ
        # from those of the full batch, running statistics would not.
        first_sample_alone = bn_wrapper(X[:1])

    # Verify output shape
    assert output.shape == X.shape

    # In eval mode the result for a sample is independent of its batch.
    torch.testing.assert_close(
        first_sample_alone,
        output[:1],
        rtol=1e-5,
        atol=1e-5,
        msg="BatchNorm output depends on batch composition (not using running stats)",
    )
208+
209+
210+
def test_batchnorm_formula():
    """Test that BatchNorm implements the correct formula.

    The wrapper output is compared with a manual evaluation of
    ``Y = scale * (X - mean) / sqrt(var + epsilon) + bias``, broadcasting the
    per-channel parameters over the batch and spatial dimensions.
    """
    # Seed so that a failure can be reproduced.
    torch.manual_seed(42)

    batch_size = 2
    channels = 3
    height, width = 4, 4

    X = torch.randn(batch_size, channels, height, width)

    scale = torch.ones(channels) * 2.0
    bias = torch.ones(channels) * 3.0
    # Distinct, non-trivial statistics per channel. With mean=0 and var=1 the
    # formula reduces to roughly scale * X + bias, so a wrapper that ignored
    # or swapped the statistics would still pass.
    mean = torch.tensor([-0.5, 0.0, 1.5])
    var = torch.tensor([0.25, 1.0, 4.0])
    epsilon = 1e-5

    # Manual computation: Y = scale * (X - mean) / sqrt(var + epsilon) + bias
    per_channel = (1, -1, 1, 1)
    normalized = (X - mean.view(per_channel)) / torch.sqrt(
        var.view(per_channel) + epsilon
    )
    expected = scale.view(per_channel) * normalized + bias.view(per_channel)

    # Using BatchNormWrapper
    bn_wrapper = BatchNormWrapper([scale, bias, mean, var], eps=epsilon)
    output = bn_wrapper(X)

    torch.testing.assert_close(output, expected, rtol=1e-5, atol=1e-5)
235+
236+
237+
@pytest.mark.parametrize("batch_size", [1, 2, 4, 8])
def test_batchnorm_consistency_across_batch_sizes(batch_size):
    """Check that the converted BatchNorm matches onnxruntime for each batch size.

    For the given ``batch_size`` a one-node BatchNormalization graph
    (opset 11, epsilon=1e-5) with seeded random parameters is evaluated by
    onnxruntime and by the converted PyTorch model; the two outputs must
    agree within rtol=atol=1e-5.
    """
    np.random.seed(42)
    torch.manual_seed(42)

    channels = 8
    height, width = 6, 6
    dims = [batch_size, channels, height, width]

    # Seeded input; it is drawn before the parameters so every tensor gets
    # the same values from the generator on each run.
    X = np.random.randn(*dims).astype(np.float32)

    # Per-channel parameters keyed by initializer name, in draw order
    # (scale, bias, mean, var).
    params = {}
    for key in ("scale", "B", "mean"):
        params[key] = np.random.randn(channels).astype(np.float32)
    params["var"] = np.abs(np.random.randn(channels).astype(np.float32)) + 0.1

    # Assemble the ONNX model.
    graph_inputs = [helper.make_tensor_value_info("X", TensorProto.FLOAT, dims)]
    graph_outputs = [helper.make_tensor_value_info("Y", TensorProto.FLOAT, dims)]
    initializers = [
        helper.make_tensor(key, TensorProto.FLOAT, [channels], array.tolist())
        for key, array in params.items()
    ]
    batchnorm = helper.make_node(
        "BatchNormalization",
        inputs=["X", "scale", "B", "mean", "var"],
        outputs=["Y"],
        epsilon=1e-5,
    )
    graph = helper.make_graph(
        [batchnorm],
        "batchnorm_consistency_test",
        graph_inputs,
        graph_outputs,
        initializers,
    )
    model = helper.make_model(
        graph, opset_imports=[helper.make_opsetid("", 11)], ir_version=8
    )

    # Reference output from onnxruntime.
    session = ort.InferenceSession(model.SerializeToString())
    (reference,) = session.run(None, {"X": X})[:1]

    # Output of the converted PyTorch model.
    torch_model = ConvertModel(model, experimental=True)
    with torch.no_grad():
        result = torch_model(torch.from_numpy(X))

    # Should match onnxruntime regardless of batch size.
    torch.testing.assert_close(
        result,
        torch.from_numpy(reference),
        rtol=1e-5,
        atol=1e-5,
        msg=f"BatchNorm failed for batch_size={batch_size}",
    )

0 commit comments

Comments
 (0)