Skip to content

Commit 780e008

Browse files
committed
Add ReduceSumSquare with tests.
1 parent 8cbd450 commit 780e008

File tree

4 files changed

+273
-0
lines changed

4 files changed

+273
-0
lines changed

onnx2pytorch/convert/operations.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -232,6 +232,10 @@ def convert_operations(onnx_graph, opset_version, batch_dim=0, enable_pruning=Tr
232232
op = partial(torch.prod, **kwargs)
233233
elif node.op_type == "ReduceSum":
234234
op = ReduceSum(opset_version=opset_version, **extract_attributes(node))
235+
elif node.op_type == "ReduceSumSquare":
236+
op = ReduceSumSquare(
237+
opset_version=opset_version, **extract_attributes(node)
238+
)
235239
elif node.op_type == "ReduceL2":
236240
op = ReduceL2(opset_version=opset_version, **extract_attributes(node))
237241
elif node.op_type == "Relu":

onnx2pytorch/operations/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
from .randomuniformlike import RandomUniformLike
3131
from .reducemax import ReduceMax
3232
from .reducesum import ReduceSum
33+
from .reducesumsquare import ReduceSumSquare
3334
from .reducel2 import ReduceL2
3435
from .reshape import Reshape
3536
from .resize import Resize, Upsample
@@ -80,6 +81,7 @@
8081
"RandomUniformLike",
8182
"ReduceMax",
8283
"ReduceSum",
84+
"ReduceSumSquare",
8385
"ReduceL2",
8486
"Reshape",
8587
"Resize",
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
import torch
2+
from torch import nn
3+
4+
5+
class ReduceSumSquare(nn.Module):
6+
"""
7+
Computes the sum of the squared elements of the input tensor's elements along the provided axes.
8+
9+
Equivalent to ReduceSum(Square(data), axes, keepdim).
10+
"""
11+
12+
def __init__(
13+
self, opset_version, dim=None, keepdim=True, noop_with_empty_axes=False
14+
):
15+
self.opset_version = opset_version
16+
self.dim = dim
17+
self.keepdim = bool(keepdim)
18+
self.noop_with_empty_axes = noop_with_empty_axes
19+
super().__init__()
20+
21+
def forward(self, data: torch.Tensor, axes: torch.Tensor = None):
22+
# In opset < 13, axes is an attribute (self.dim)
23+
# In opset >= 13, axes is an optional input
24+
if self.opset_version < 13:
25+
dims = self.dim
26+
else:
27+
dims = axes
28+
29+
if dims is None:
30+
if self.noop_with_empty_axes:
31+
return data
32+
else:
33+
# Reduce over all dimensions
34+
dims = tuple(range(data.ndim))
35+
36+
if isinstance(dims, int):
37+
dim = dims
38+
else:
39+
dim = tuple(list(dims))
40+
41+
# Compute sum of squares: sum(x^2)
42+
ret = torch.sum(torch.square(data), dim=dim, keepdim=self.keepdim)
43+
return ret
Lines changed: 224 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,224 @@
1+
import numpy as np
2+
import onnxruntime as ort
3+
import pytest
4+
import torch
5+
from onnx import helper, TensorProto
6+
7+
from onnx2pytorch.convert import ConvertModel
8+
from onnx2pytorch.operations.reducesumsquare import ReduceSumSquare
9+
10+
11+
@pytest.mark.parametrize(
12+
"input_shape,axes,keepdims",
13+
[
14+
# Test with different axes
15+
([3, 4, 5], [0], 1),
16+
([3, 4, 5], [1], 1),
17+
([3, 4, 5], [2], 1),
18+
([3, 4, 5], [-1], 1),
19+
# Test with multiple axes
20+
([3, 4, 5], [0, 1], 1),
21+
([3, 4, 5], [1, 2], 1),
22+
([3, 4, 5], [0, 2], 1),
23+
# Test with keepdims=0
24+
([3, 4, 5], [1], 0),
25+
([3, 4, 5], [0, 2], 0),
26+
# Test with all axes (None means reduce all)
27+
([3, 4, 5], None, 1),
28+
([3, 4, 5], None, 0),
29+
# Test 2D inputs
30+
([5, 10], [0], 1),
31+
([5, 10], [1], 1),
32+
([5, 10], None, 1),
33+
# Test 1D inputs
34+
([10], [0], 1),
35+
([10], None, 1),
36+
],
37+
)
38+
def test_reducesumsquare_onnxruntime(input_shape, axes, keepdims):
39+
"""Test ReduceSumSquare against onnxruntime."""
40+
np.random.seed(42)
41+
42+
# Create input
43+
X = np.random.randn(*input_shape).astype(np.float32)
44+
45+
# Create ONNX graph with ReduceSumSquare node
46+
input_tensor = helper.make_tensor_value_info("X", TensorProto.FLOAT, input_shape)
47+
output_tensor = helper.make_tensor_value_info("Y", TensorProto.FLOAT, None)
48+
49+
# Use axes as attribute (supported in all opset versions)
50+
node_attrs = {"keepdims": keepdims}
51+
if axes is not None:
52+
node_attrs["axes"] = axes
53+
54+
reducesumsquare_node = helper.make_node(
55+
"ReduceSumSquare",
56+
inputs=["X"],
57+
outputs=["Y"],
58+
**node_attrs,
59+
)
60+
61+
graph = helper.make_graph(
62+
[reducesumsquare_node],
63+
"reducesumsquare_test",
64+
[input_tensor],
65+
[output_tensor],
66+
)
67+
68+
model = helper.make_model(
69+
graph, opset_imports=[helper.make_opsetid("", 11)], ir_version=8
70+
)
71+
72+
# Run with onnxruntime
73+
ort_session = ort.InferenceSession(model.SerializeToString())
74+
ort_outputs = ort_session.run(None, {"X": X})
75+
expected_Y = ort_outputs[0]
76+
77+
# Convert to PyTorch and run
78+
o2p_model = ConvertModel(model, experimental=True)
79+
X_torch = torch.from_numpy(X)
80+
81+
with torch.no_grad():
82+
o2p_output = o2p_model(X_torch)
83+
84+
# Compare outputs
85+
torch.testing.assert_close(
86+
o2p_output,
87+
torch.from_numpy(expected_Y),
88+
rtol=1e-5,
89+
atol=1e-5,
90+
)
91+
92+
93+
def test_reducesumsquare_formula():
94+
"""Test that ReduceSumSquare implements sum(x^2)."""
95+
X = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
96+
97+
# Manual computation
98+
expected_all = torch.sum(X**2)
99+
expected_axis0 = torch.sum(X**2, dim=0, keepdim=True)
100+
expected_axis1 = torch.sum(X**2, dim=1, keepdim=True)
101+
102+
# Test reduce all
103+
104+
op_all = ReduceSumSquare(opset_version=13, dim=None, keepdim=True)
105+
result_all = op_all(X)
106+
torch.testing.assert_close(
107+
result_all, expected_all.view(1, 1), rtol=1e-6, atol=1e-6
108+
)
109+
110+
# Test reduce axis 0
111+
op_axis0 = ReduceSumSquare(opset_version=11, dim=0, keepdim=True)
112+
result_axis0 = op_axis0(X)
113+
torch.testing.assert_close(result_axis0, expected_axis0, rtol=1e-6, atol=1e-6)
114+
115+
# Test reduce axis 1
116+
op_axis1 = ReduceSumSquare(opset_version=11, dim=1, keepdim=True)
117+
result_axis1 = op_axis1(X)
118+
torch.testing.assert_close(result_axis1, expected_axis1, rtol=1e-6, atol=1e-6)
119+
120+
121+
def test_reducesumsquare_keepdims():
122+
"""Test keepdims parameter."""
123+
X = torch.randn(2, 3, 4)
124+
125+
# With keepdims=True
126+
op_keep = ReduceSumSquare(opset_version=11, dim=1, keepdim=True)
127+
result_keep = op_keep(X)
128+
assert result_keep.shape == (2, 1, 4)
129+
130+
# With keepdims=False
131+
op_no_keep = ReduceSumSquare(opset_version=11, dim=1, keepdim=False)
132+
result_no_keep = op_no_keep(X)
133+
assert result_no_keep.shape == (2, 4)
134+
135+
# Values should be the same (just different shapes)
136+
torch.testing.assert_close(
137+
result_keep.squeeze(1), result_no_keep, rtol=1e-6, atol=1e-6
138+
)
139+
140+
141+
def test_reducesumsquare_noop_with_empty_axes():
142+
"""Test noop_with_empty_axes parameter."""
143+
X = torch.randn(2, 3, 4)
144+
145+
# With noop_with_empty_axes=True and no axes, should return input unchanged
146+
op_noop = ReduceSumSquare(
147+
opset_version=13, dim=None, keepdim=True, noop_with_empty_axes=True
148+
)
149+
result_noop = op_noop(X)
150+
torch.testing.assert_close(result_noop, X, rtol=1e-6, atol=1e-6)
151+
152+
# With noop_with_empty_axes=False and no axes, should reduce all
153+
op_reduce = ReduceSumSquare(
154+
opset_version=13, dim=None, keepdim=True, noop_with_empty_axes=False
155+
)
156+
result_reduce = op_reduce(X)
157+
expected = torch.sum(X**2).view(1, 1, 1)
158+
torch.testing.assert_close(result_reduce, expected, rtol=1e-6, atol=1e-6)
159+
160+
161+
def test_reducesumsquare_with_axes_input():
162+
"""Test with axes as an input tensor (for frameworks that support it)."""
163+
X = torch.randn(2, 3, 4)
164+
165+
# Opset 13+ supports axes as input
166+
op = ReduceSumSquare(opset_version=13, dim=None, keepdim=True)
167+
168+
# Provide axes as a tensor
169+
axes = torch.tensor([0, 2], dtype=torch.int64)
170+
result = op(X, axes)
171+
172+
# Expected: reduce along axes 0 and 2
173+
expected = torch.sum(X**2, dim=(0, 2), keepdim=True)
174+
torch.testing.assert_close(result, expected, rtol=1e-6, atol=1e-6)
175+
assert result.shape == (1, 3, 1)
176+
177+
178+
def test_reducesumsquare_vs_reducesum_square():
179+
"""Test that ReduceSumSquare(x) == ReduceSum(Square(x))."""
180+
X = torch.randn(3, 4, 5)
181+
182+
# ReduceSumSquare
183+
op_sumsquare = ReduceSumSquare(opset_version=11, dim=1, keepdim=True)
184+
result_sumsquare = op_sumsquare(X)
185+
186+
# ReduceSum(Square(x))
187+
result_square_sum = torch.sum(X**2, dim=1, keepdim=True)
188+
189+
torch.testing.assert_close(
190+
result_sumsquare, result_square_sum, rtol=1e-6, atol=1e-6
191+
)
192+
193+
194+
def test_reducesumsquare_negative_axis():
195+
"""Test with negative axis values."""
196+
X = torch.randn(2, 3, 4)
197+
198+
# axis=-1 should be equivalent to axis=2
199+
op_neg = ReduceSumSquare(opset_version=11, dim=-1, keepdim=True)
200+
result_neg = op_neg(X)
201+
202+
op_pos = ReduceSumSquare(opset_version=11, dim=2, keepdim=True)
203+
result_pos = op_pos(X)
204+
205+
torch.testing.assert_close(result_neg, result_pos, rtol=1e-6, atol=1e-6)
206+
207+
208+
def test_reducesumsquare_gradient():
209+
"""Test that gradients flow correctly through ReduceSumSquare."""
210+
X = torch.randn(2, 3, 4, requires_grad=True)
211+
212+
op = ReduceSumSquare(opset_version=11, dim=1, keepdim=True)
213+
result = op(X)
214+
215+
# Compute gradient
216+
loss = result.sum()
217+
loss.backward()
218+
219+
# Gradient of sum(x^2) with respect to x is 2x
220+
# After summing along dim=1, gradient should be 2x broadcast along dim=1
221+
expected_grad = 2 * X
222+
223+
assert X.grad is not None
224+
torch.testing.assert_close(X.grad, expected_grad, rtol=1e-5, atol=1e-5)

0 commit comments

Comments
 (0)