|
| 1 | +import numpy as np |
| 2 | +import torch |
| 3 | +import pytest |
| 4 | + |
| 5 | +from onnx2pytorch.operations import InstanceNormWrapper |
| 6 | + |
| 7 | + |
| 8 | +def instancenorm_reference(x, s, bias, eps): |
| 9 | + dims_x = len(x.shape) |
| 10 | + axis = tuple(range(2, dims_x)) |
| 11 | + mean = np.mean(x, axis=axis, keepdims=True) |
| 12 | + var = np.var(x, axis=axis, keepdims=True) |
| 13 | + dim_ones = (1,) * (dims_x - 2) |
| 14 | + s = s.reshape(-1, *dim_ones) |
| 15 | + bias = bias.reshape(-1, *dim_ones) |
| 16 | + return s * (x - mean) / np.sqrt(var + eps) + bias |
| 17 | + |
| 18 | + |
| 19 | +@pytest.fixture |
| 20 | +def x_np(): |
| 21 | + # input size: (1, 2, 1, 3) |
| 22 | + return np.array([[[[-1, 0, 1]], [[2, 3, 4]]]]).astype(np.float32) |
| 23 | + |
| 24 | + |
| 25 | +@pytest.fixture |
| 26 | +def s_np(): |
| 27 | + return np.array([1.0, 1.5]).astype(np.float32) |
| 28 | + |
| 29 | + |
| 30 | +@pytest.fixture |
| 31 | +def b_np(): |
| 32 | + return np.array([0, 1]).astype(np.float32) |
| 33 | + |
| 34 | + |
| 35 | +def test_instancenorm(x_np, s_np, b_np): |
| 36 | + eps = 1e-5 |
| 37 | + x = torch.from_numpy(x_np) |
| 38 | + s = torch.from_numpy(s_np) |
| 39 | + b = torch.from_numpy(b_np) |
| 40 | + |
| 41 | + exp_y = instancenorm_reference(x_np, s_np, b_np, eps).astype(np.float32) |
| 42 | + exp_y_shape = (1, 2, 1, 3) |
| 43 | + op = InstanceNormWrapper([s, b], eps=eps) |
| 44 | + y = op(x) |
| 45 | + |
| 46 | + assert y.shape == exp_y_shape |
| 47 | + assert np.allclose(y.detach().numpy(), exp_y, rtol=1e-5, atol=1e-5) |
| 48 | + |
| 49 | + |
| 50 | +def test_instancenorm_lazy(x_np, s_np, b_np): |
| 51 | + eps = 1e-5 |
| 52 | + x = torch.from_numpy(x_np) |
| 53 | + s = torch.from_numpy(s_np) |
| 54 | + b = torch.from_numpy(b_np) |
| 55 | + |
| 56 | + exp_y = instancenorm_reference(x_np, s_np, b_np, eps).astype(np.float32) |
| 57 | + exp_y_shape = (1, 2, 1, 3) |
| 58 | + op = InstanceNormWrapper([], eps=eps) |
| 59 | + y = op(x, s, b) |
| 60 | + |
| 61 | + assert y.shape == exp_y_shape |
| 62 | + assert np.allclose(y.detach().numpy(), exp_y, rtol=1e-5, atol=1e-5) |
0 commit comments