Skip to content

Commit 5f95c61

Browse files
committed
Add tests for reducel2
1 parent 8b1a83e commit 5f95c61

File tree

2 files changed

+253
-3
lines changed

2 files changed

+253
-3
lines changed

onnx2pytorch/operations/reducel2.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,14 @@
11
import torch
22
from torch import nn
33

4+
45
class ReduceL2(nn.Module):
56
def __init__(
67
self, opset_version, dim=None, keepdim=True, noop_with_empty_axes=False
78
):
89
self.opset_version = opset_version
910
self.dim = dim
10-
self.keepdim = keepdim
11+
self.keepdim = bool(keepdim)
1112
self.noop_with_empty_axes = noop_with_empty_axes
1213
super().__init__()
1314

@@ -21,11 +22,11 @@ def forward(self, data: torch.Tensor, axes: torch.Tensor = None):
2122
return data
2223
else:
2324
dims = tuple(range(data.ndim))
24-
25+
2526
if isinstance(dims, int):
2627
dim = dims
2728
else:
28-
dim=tuple(list(dims))
29+
dim = tuple(list(dims))
2930

3031
ret = torch.sqrt(torch.sum(torch.square(data), dim=dim, keepdim=self.keepdim))
3132
return ret
Lines changed: 249 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,249 @@
1+
import numpy as np
2+
import onnx
3+
import pytest
4+
import torch
5+
6+
from onnx2pytorch.convert.operations import convert_operations
7+
from onnx2pytorch.operations import ReduceL2
8+
9+
10+
@pytest.fixture
11+
def tensor():
12+
return torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
13+
14+
15+
def test_reduce_l2_older_opset_version(tensor):
16+
shape = [3, 2, 2]
17+
axes = np.array([2], dtype=np.int64)
18+
keepdims = 0
19+
20+
data = np.reshape(np.arange(1, np.prod(shape) + 1, dtype=np.float32), shape)
21+
op = ReduceL2(opset_version=10, keepdim=keepdims, dim=axes)
22+
23+
reduced = np.sqrt(
24+
np.sum(a=np.square(data), axis=tuple(axes), keepdims=keepdims == 1)
25+
)
26+
27+
out = op(torch.from_numpy(data), axes=axes)
28+
np.testing.assert_array_equal(out, reduced)
29+
30+
31+
def test_do_not_keepdims_older_opset_version() -> None:
32+
opset_version = 10
33+
shape = [3, 2, 2]
34+
axes = np.array([2], dtype=np.int64)
35+
keepdims = 0
36+
37+
node = onnx.helper.make_node(
38+
"ReduceL2",
39+
inputs=["data"],
40+
outputs=["reduced"],
41+
keepdims=keepdims,
42+
axes=axes,
43+
)
44+
graph = onnx.helper.make_graph([node], "test_reduce_l2_do_not_keepdims", [], [])
45+
46+
ops = list(convert_operations(graph, opset_version))
47+
op = ops[0][2]
48+
49+
assert isinstance(op, ReduceL2)
50+
51+
data = np.reshape(np.arange(1, np.prod(shape) + 1, dtype=np.float32), shape)
52+
# print(data)
53+
# [[[1., 2.], [3., 4.]], [[5., 6.], [7., 8.]], [[9., 10.], [11., 12.]]]
54+
55+
reduced = np.sqrt(
56+
np.sum(a=np.square(data), axis=tuple(axes), keepdims=keepdims == 1)
57+
)
58+
# print(reduced)
59+
# [[2.23606798, 5.],
60+
# [7.81024968, 10.63014581],
61+
# [13.45362405, 16.2788206]]
62+
63+
out = op(torch.from_numpy(data))
64+
np.testing.assert_array_equal(out, reduced)
65+
66+
np.random.seed(0)
67+
data = np.random.uniform(-10, 10, shape).astype(np.float32)
68+
reduced = np.sqrt(
69+
np.sum(a=np.square(data), axis=tuple(axes), keepdims=keepdims == 1)
70+
)
71+
72+
out = op(torch.from_numpy(data))
73+
np.testing.assert_array_equal(out, reduced)
74+
75+
76+
def test_do_not_keepdims() -> None:
77+
shape = [3, 2, 2]
78+
axes = np.array([2], dtype=np.int64)
79+
keepdims = 0
80+
81+
node = onnx.helper.make_node(
82+
"ReduceL2",
83+
inputs=["data", "axes"],
84+
outputs=["reduced"],
85+
keepdims=keepdims,
86+
)
87+
graph = onnx.helper.make_graph([node], "test_reduce_l2_do_not_keepdims", [], [])
88+
ops = list(convert_operations(graph, 18))
89+
op = ops[0][2]
90+
91+
assert isinstance(op, ReduceL2)
92+
93+
data = np.reshape(np.arange(1, np.prod(shape) + 1, dtype=np.float32), shape)
94+
# print(data)
95+
# [[[1., 2.], [3., 4.]], [[5., 6.], [7., 8.]], [[9., 10.], [11., 12.]]]
96+
97+
reduced = np.sqrt(
98+
np.sum(a=np.square(data), axis=tuple(axes), keepdims=keepdims == 1)
99+
)
100+
# print(reduced)
101+
# [[2.23606798, 5.],
102+
# [7.81024968, 10.63014581],
103+
# [13.45362405, 16.2788206]]
104+
105+
out = op(torch.from_numpy(data), axes=axes)
106+
np.testing.assert_array_equal(out, reduced)
107+
108+
np.random.seed(0)
109+
data = np.random.uniform(-10, 10, shape).astype(np.float32)
110+
reduced = np.sqrt(
111+
np.sum(a=np.square(data), axis=tuple(axes), keepdims=keepdims == 1)
112+
)
113+
114+
out = op(torch.from_numpy(data), axes=axes)
115+
np.testing.assert_array_equal(out, reduced)
116+
117+
118+
def test_export_keepdims() -> None:
119+
shape = [3, 2, 2]
120+
axes = np.array([2], dtype=np.int64)
121+
keepdims = 1
122+
123+
node = onnx.helper.make_node(
124+
"ReduceL2",
125+
inputs=["data", "axes"],
126+
outputs=["reduced"],
127+
keepdims=keepdims,
128+
)
129+
graph = onnx.helper.make_graph([node], "test_reduce_l2_do_not_keepdims", [], [])
130+
ops = list(convert_operations(graph, 18))
131+
op = ops[0][2]
132+
133+
data = np.reshape(np.arange(1, np.prod(shape) + 1, dtype=np.float32), shape)
134+
# print(data)
135+
# [[[1., 2.], [3., 4.]], [[5., 6.], [7., 8.]], [[9., 10.], [11., 12.]]]
136+
137+
reduced = np.sqrt(
138+
np.sum(a=np.square(data), axis=tuple(axes), keepdims=keepdims == 1)
139+
)
140+
# print(reduced)
141+
# [[[2.23606798], [5.]]
142+
# [[7.81024968], [10.63014581]]
143+
# [[13.45362405], [16.2788206 ]]]
144+
145+
out = op(torch.from_numpy(data), axes=axes)
146+
np.testing.assert_array_equal(out, reduced)
147+
148+
np.random.seed(0)
149+
data = np.random.uniform(-10, 10, shape).astype(np.float32)
150+
reduced = np.sqrt(
151+
np.sum(a=np.square(data), axis=tuple(axes), keepdims=keepdims == 1)
152+
)
153+
154+
out = op(torch.from_numpy(data), axes=axes)
155+
np.testing.assert_array_equal(out, reduced)
156+
157+
158+
def test_export_default_axes_keepdims() -> None:
159+
shape = [3, 2, 2]
160+
axes = np.array([], dtype=np.int64)
161+
keepdims = 1
162+
163+
node = onnx.helper.make_node(
164+
"ReduceL2", inputs=["data", "axes"], outputs=["reduced"], keepdims=keepdims
165+
)
166+
graph = onnx.helper.make_graph([node], "test_reduce_l2_do_not_keepdims", [], [])
167+
ops = list(convert_operations(graph, 18))
168+
op = ops[0][2]
169+
170+
data = np.reshape(np.arange(1, np.prod(shape) + 1, dtype=np.float32), shape)
171+
# print(data)
172+
# [[[1., 2.], [3., 4.]], [[5., 6.], [7., 8.]], [[9., 10.], [11., 12.]]]
173+
174+
reduced = np.sqrt(np.sum(a=np.square(data), axis=None, keepdims=keepdims == 1))
175+
# print(reduced)
176+
# [[[25.49509757]]]
177+
178+
out = op(torch.from_numpy(data), axes=axes)
179+
np.testing.assert_array_equal(out, reduced)
180+
181+
np.random.seed(0)
182+
data = np.random.uniform(-10, 10, shape).astype(np.float32)
183+
reduced = np.sqrt(np.sum(a=np.square(data), axis=None, keepdims=keepdims == 1))
184+
185+
out = op(torch.from_numpy(data), axes=axes)
186+
np.testing.assert_array_equal(out, reduced)
187+
188+
189+
def test_export_negative_axes_keepdims() -> None:
190+
shape = [3, 2, 2]
191+
axes = np.array([-1], dtype=np.int64)
192+
keepdims = 1
193+
194+
node = onnx.helper.make_node(
195+
"ReduceL2",
196+
inputs=["data", "axes"],
197+
outputs=["reduced"],
198+
keepdims=keepdims,
199+
)
200+
graph = onnx.helper.make_graph([node], "test_reduce_l2_do_not_keepdims", [], [])
201+
ops = list(convert_operations(graph, 18))
202+
op = ops[0][2]
203+
204+
data = np.reshape(np.arange(1, np.prod(shape) + 1, dtype=np.float32), shape)
205+
# print(data)
206+
# [[[1., 2.], [3., 4.]], [[5., 6.], [7., 8.]], [[9., 10.], [11., 12.]]]
207+
208+
reduced = np.sqrt(
209+
np.sum(a=np.square(data), axis=tuple(axes), keepdims=keepdims == 1)
210+
)
211+
# print(reduced)
212+
# [[[2.23606798], [5.]]
213+
# [[7.81024968], [10.63014581]]
214+
# [[13.45362405], [16.2788206 ]]]
215+
216+
out = op(torch.from_numpy(data), axes=axes)
217+
np.testing.assert_array_equal(out, reduced)
218+
219+
np.random.seed(0)
220+
data = np.random.uniform(-10, 10, shape).astype(np.float32)
221+
reduced = np.sqrt(
222+
np.sum(a=np.square(data), axis=tuple(axes), keepdims=keepdims == 1)
223+
)
224+
225+
out = op(torch.from_numpy(data), axes=axes)
226+
np.testing.assert_array_equal(out, reduced)
227+
228+
229+
def test_export_empty_set() -> None:
230+
shape = [2, 0, 4]
231+
keepdims = 1
232+
reduced_shape = [2, 1, 4]
233+
234+
node = onnx.helper.make_node(
235+
"ReduceL2",
236+
inputs=["data", "axes"],
237+
outputs=["reduced"],
238+
keepdims=keepdims,
239+
)
240+
graph = onnx.helper.make_graph([node], "test_reduce_l2_do_not_keepdims", [], [])
241+
ops = list(convert_operations(graph, 18))
242+
op = ops[0][2]
243+
244+
data = np.array([], dtype=np.float32).reshape(shape)
245+
axes = np.array([1], dtype=np.int64)
246+
reduced = np.array(np.zeros(reduced_shape, dtype=np.float32))
247+
248+
out = op(torch.from_numpy(data), axes=axes)
249+
np.testing.assert_array_equal(out, reduced)

0 commit comments

Comments
 (0)