Skip to content

Commit 4418de8

Browse files
test_utils.py get_ops_names
1 parent 29f49f5 commit 4418de8

File tree

1 file changed

+112
-2
lines changed

1 file changed

+112
-2
lines changed

tests/onnx2pytorch/test_utils.py

Lines changed: 112 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,14 @@
1-
import torch
2-
import pytest
31
import numpy as np
2+
import onnx
3+
import pytest
4+
import torch
45
from torch import nn
56
from onnx.backend.test.case.node.pad import pad_impl
67

78
from onnx2pytorch.helpers import to_onnx
89
from onnx2pytorch.utils import (
910
is_constant,
11+
get_ops_names,
1012
get_selection,
1113
assign_values_to_dim,
1214
get_activation_value,
@@ -146,3 +148,111 @@ def weight():
146148
def test_extract_padding_params(weight, onnx_pads, torch_pads):
147149
out_pads = extract_padding_params(onnx_pads)
148150
assert out_pads == torch_pads
151+
152+
153+
def test_get_ops_names():
154+
y_in = onnx.helper.make_tensor_value_info("y_in", onnx.TensorProto.FLOAT, [1])
155+
y_out = onnx.helper.make_tensor_value_info("y_out", onnx.TensorProto.FLOAT, [1])
156+
scan_out = onnx.helper.make_tensor_value_info(
157+
"scan_out", onnx.TensorProto.FLOAT, []
158+
)
159+
cond_in = onnx.helper.make_tensor_value_info("cond_in", onnx.TensorProto.BOOL, [])
160+
cond_out = onnx.helper.make_tensor_value_info("cond_out", onnx.TensorProto.BOOL, [])
161+
iter_count = onnx.helper.make_tensor_value_info(
162+
"iter_count", onnx.TensorProto.INT64, []
163+
)
164+
165+
x = np.array([1, 2, 3, 4, 5]).astype(np.float32)
166+
167+
x_const_node = onnx.helper.make_node(
168+
"Constant",
169+
inputs=[],
170+
outputs=["x"],
171+
value=onnx.helper.make_tensor(
172+
name="const_tensor_x",
173+
data_type=onnx.TensorProto.FLOAT,
174+
dims=x.shape,
175+
vals=x.flatten().astype(float),
176+
),
177+
)
178+
179+
one_const_node = onnx.helper.make_node(
180+
"Constant",
181+
inputs=[],
182+
outputs=["one"],
183+
value=onnx.helper.make_tensor(
184+
name="const_tensor_one", data_type=onnx.TensorProto.INT64, dims=(), vals=[1]
185+
),
186+
)
187+
188+
i_add_node = onnx.helper.make_node(
189+
"Add", inputs=["iter_count", "one"], outputs=["end"]
190+
)
191+
192+
start_unsqueeze_node = onnx.helper.make_node(
193+
"Unsqueeze", inputs=["iter_count"], outputs=["slice_start"], axes=[0]
194+
)
195+
196+
end_unsqueeze_node = onnx.helper.make_node(
197+
"Unsqueeze", inputs=["end"], outputs=["slice_end"], axes=[0]
198+
)
199+
200+
slice_node = onnx.helper.make_node(
201+
"Slice", inputs=["x", "slice_start", "slice_end"], outputs=["slice_out"]
202+
)
203+
204+
y_add_node = onnx.helper.make_node(
205+
"Add", inputs=["y_in", "slice_out"], outputs=["y_out"]
206+
)
207+
208+
identity_node = onnx.helper.make_node(
209+
"Identity", inputs=["cond_in"], outputs=["cond_out"]
210+
)
211+
212+
scan_identity_node = onnx.helper.make_node(
213+
"Identity", inputs=["y_out"], outputs=["scan_out"]
214+
)
215+
216+
loop_body = onnx.helper.make_graph(
217+
[
218+
identity_node,
219+
x_const_node,
220+
one_const_node,
221+
i_add_node,
222+
start_unsqueeze_node,
223+
end_unsqueeze_node,
224+
slice_node,
225+
y_add_node,
226+
scan_identity_node,
227+
],
228+
"loop_body",
229+
[iter_count, cond_in, y_in],
230+
[cond_out, y_out, scan_out],
231+
)
232+
233+
node = onnx.helper.make_node(
234+
"Loop",
235+
inputs=["trip_count", "cond", "y"],
236+
outputs=["res_y", "res_scan"],
237+
body=loop_body,
238+
)
239+
240+
trip_count = onnx.helper.make_tensor_value_info(
241+
"trip_count", onnx.TensorProto.INT64, []
242+
)
243+
cond = onnx.helper.make_tensor_value_info("cond", onnx.TensorProto.BOOL, [])
244+
y = onnx.helper.make_tensor_value_info("y", onnx.TensorProto.FLOAT, [1])
245+
res_y = onnx.helper.make_tensor_value_info("res_y", onnx.TensorProto.FLOAT, [1])
246+
res_scan = onnx.helper.make_tensor_value_info(
247+
"res_scan", onnx.TensorProto.FLOAT, []
248+
)
249+
250+
graph_def = onnx.helper.make_graph(
251+
nodes=[node],
252+
name="test-model",
253+
inputs=[trip_count, cond, y],
254+
outputs=[res_y, res_scan],
255+
)
256+
257+
ops_names = set(["Add", "Constant", "Identity", "Loop", "Slice", "Unsqueeze"])
258+
assert get_ops_names(graph_def) == ops_names

0 commit comments

Comments
 (0)