Skip to content

Commit cd453ac

Browse files
test_instancenorm.py
1 parent 0a2d3c6 commit cd453ac

File tree

1 file changed

+62
-0
lines changed

1 file changed

+62
-0
lines changed
Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
import numpy as np
2+
import torch
3+
import pytest
4+
5+
from onnx2pytorch.operations import InstanceNormWrapper
6+
7+
8+
def instancenorm_reference(x, s, bias, eps):
9+
dims_x = len(x.shape)
10+
axis = tuple(range(2, dims_x))
11+
mean = np.mean(x, axis=axis, keepdims=True)
12+
var = np.var(x, axis=axis, keepdims=True)
13+
dim_ones = (1,) * (dims_x - 2)
14+
s = s.reshape(-1, *dim_ones)
15+
bias = bias.reshape(-1, *dim_ones)
16+
return s * (x - mean) / np.sqrt(var + eps) + bias
17+
18+
19+
@pytest.fixture
20+
def x_np():
21+
# input size: (1, 2, 1, 3)
22+
return np.array([[[[-1, 0, 1]], [[2, 3, 4]]]]).astype(np.float32)
23+
24+
25+
@pytest.fixture
26+
def s_np():
27+
return np.array([1.0, 1.5]).astype(np.float32)
28+
29+
30+
@pytest.fixture
31+
def b_np():
32+
return np.array([0, 1]).astype(np.float32)
33+
34+
35+
def test_instancenorm(x_np, s_np, b_np):
36+
eps = 1e-5
37+
x = torch.from_numpy(x_np)
38+
s = torch.from_numpy(s_np)
39+
b = torch.from_numpy(b_np)
40+
41+
exp_y = instancenorm_reference(x_np, s_np, b_np, eps).astype(np.float32)
42+
exp_y_shape = (1, 2, 1, 3)
43+
op = InstanceNormWrapper([s, b], eps=eps)
44+
y = op(x)
45+
46+
assert y.shape == exp_y_shape
47+
assert np.allclose(y.detach().numpy(), exp_y, rtol=1e-5, atol=1e-5)
48+
49+
50+
def test_instancenorm_lazy(x_np, s_np, b_np):
51+
eps = 1e-5
52+
x = torch.from_numpy(x_np)
53+
s = torch.from_numpy(s_np)
54+
b = torch.from_numpy(b_np)
55+
56+
exp_y = instancenorm_reference(x_np, s_np, b_np, eps).astype(np.float32)
57+
exp_y_shape = (1, 2, 1, 3)
58+
op = InstanceNormWrapper([], eps=eps)
59+
y = op(x, s, b)
60+
61+
assert y.shape == exp_y_shape
62+
assert np.allclose(y.detach().numpy(), exp_y, rtol=1e-5, atol=1e-5)

0 commit comments

Comments
 (0)