|
1 | | -import torch |
2 | | -import pytest |
3 | 1 | import numpy as np |
| 2 | +import onnx |
| 3 | +import pytest |
| 4 | +import torch |
4 | 5 | from torch import nn |
5 | 6 | from onnx.backend.test.case.node.pad import pad_impl |
6 | 7 |
|
7 | 8 | from onnx2pytorch.helpers import to_onnx |
8 | 9 | from onnx2pytorch.utils import ( |
9 | 10 | is_constant, |
| 11 | + get_ops_names, |
10 | 12 | get_selection, |
11 | 13 | assign_values_to_dim, |
12 | 14 | get_activation_value, |
@@ -146,3 +148,111 @@ def weight(): |
146 | 148 | def test_extract_padding_params(weight, onnx_pads, torch_pads): |
147 | 149 | out_pads = extract_padding_params(onnx_pads) |
148 | 150 | assert out_pads == torch_pads |
| 151 | + |
| 152 | + |
| 153 | +def test_get_ops_names(): |
| 154 | + y_in = onnx.helper.make_tensor_value_info("y_in", onnx.TensorProto.FLOAT, [1]) |
| 155 | + y_out = onnx.helper.make_tensor_value_info("y_out", onnx.TensorProto.FLOAT, [1]) |
| 156 | + scan_out = onnx.helper.make_tensor_value_info( |
| 157 | + "scan_out", onnx.TensorProto.FLOAT, [] |
| 158 | + ) |
| 159 | + cond_in = onnx.helper.make_tensor_value_info("cond_in", onnx.TensorProto.BOOL, []) |
| 160 | + cond_out = onnx.helper.make_tensor_value_info("cond_out", onnx.TensorProto.BOOL, []) |
| 161 | + iter_count = onnx.helper.make_tensor_value_info( |
| 162 | + "iter_count", onnx.TensorProto.INT64, [] |
| 163 | + ) |
| 164 | + |
| 165 | + x = np.array([1, 2, 3, 4, 5]).astype(np.float32) |
| 166 | + |
| 167 | + x_const_node = onnx.helper.make_node( |
| 168 | + "Constant", |
| 169 | + inputs=[], |
| 170 | + outputs=["x"], |
| 171 | + value=onnx.helper.make_tensor( |
| 172 | + name="const_tensor_x", |
| 173 | + data_type=onnx.TensorProto.FLOAT, |
| 174 | + dims=x.shape, |
| 175 | + vals=x.flatten().astype(float), |
| 176 | + ), |
| 177 | + ) |
| 178 | + |
| 179 | + one_const_node = onnx.helper.make_node( |
| 180 | + "Constant", |
| 181 | + inputs=[], |
| 182 | + outputs=["one"], |
| 183 | + value=onnx.helper.make_tensor( |
| 184 | + name="const_tensor_one", data_type=onnx.TensorProto.INT64, dims=(), vals=[1] |
| 185 | + ), |
| 186 | + ) |
| 187 | + |
| 188 | + i_add_node = onnx.helper.make_node( |
| 189 | + "Add", inputs=["iter_count", "one"], outputs=["end"] |
| 190 | + ) |
| 191 | + |
| 192 | + start_unsqueeze_node = onnx.helper.make_node( |
| 193 | + "Unsqueeze", inputs=["iter_count"], outputs=["slice_start"], axes=[0] |
| 194 | + ) |
| 195 | + |
| 196 | + end_unsqueeze_node = onnx.helper.make_node( |
| 197 | + "Unsqueeze", inputs=["end"], outputs=["slice_end"], axes=[0] |
| 198 | + ) |
| 199 | + |
| 200 | + slice_node = onnx.helper.make_node( |
| 201 | + "Slice", inputs=["x", "slice_start", "slice_end"], outputs=["slice_out"] |
| 202 | + ) |
| 203 | + |
| 204 | + y_add_node = onnx.helper.make_node( |
| 205 | + "Add", inputs=["y_in", "slice_out"], outputs=["y_out"] |
| 206 | + ) |
| 207 | + |
| 208 | + identity_node = onnx.helper.make_node( |
| 209 | + "Identity", inputs=["cond_in"], outputs=["cond_out"] |
| 210 | + ) |
| 211 | + |
| 212 | + scan_identity_node = onnx.helper.make_node( |
| 213 | + "Identity", inputs=["y_out"], outputs=["scan_out"] |
| 214 | + ) |
| 215 | + |
| 216 | + loop_body = onnx.helper.make_graph( |
| 217 | + [ |
| 218 | + identity_node, |
| 219 | + x_const_node, |
| 220 | + one_const_node, |
| 221 | + i_add_node, |
| 222 | + start_unsqueeze_node, |
| 223 | + end_unsqueeze_node, |
| 224 | + slice_node, |
| 225 | + y_add_node, |
| 226 | + scan_identity_node, |
| 227 | + ], |
| 228 | + "loop_body", |
| 229 | + [iter_count, cond_in, y_in], |
| 230 | + [cond_out, y_out, scan_out], |
| 231 | + ) |
| 232 | + |
| 233 | + node = onnx.helper.make_node( |
| 234 | + "Loop", |
| 235 | + inputs=["trip_count", "cond", "y"], |
| 236 | + outputs=["res_y", "res_scan"], |
| 237 | + body=loop_body, |
| 238 | + ) |
| 239 | + |
| 240 | + trip_count = onnx.helper.make_tensor_value_info( |
| 241 | + "trip_count", onnx.TensorProto.INT64, [] |
| 242 | + ) |
| 243 | + cond = onnx.helper.make_tensor_value_info("cond", onnx.TensorProto.BOOL, []) |
| 244 | + y = onnx.helper.make_tensor_value_info("y", onnx.TensorProto.FLOAT, [1]) |
| 245 | + res_y = onnx.helper.make_tensor_value_info("res_y", onnx.TensorProto.FLOAT, [1]) |
| 246 | + res_scan = onnx.helper.make_tensor_value_info( |
| 247 | + "res_scan", onnx.TensorProto.FLOAT, [] |
| 248 | + ) |
| 249 | + |
| 250 | + graph_def = onnx.helper.make_graph( |
| 251 | + nodes=[node], |
| 252 | + name="test-model", |
| 253 | + inputs=[trip_count, cond, y], |
| 254 | + outputs=[res_y, res_scan], |
| 255 | + ) |
| 256 | + |
| 257 | + ops_names = set(["Add", "Constant", "Identity", "Loop", "Slice", "Unsqueeze"]) |
| 258 | + assert get_ops_names(graph_def) == ops_names |
0 commit comments