
Commit dc7684c

Merge pull request #15 from calvinmccarter-at-lightmatter/reducesum
LSTM conversion, training & multi-device support, and more
2 parents 8eb6ae8 + 4418de8 commit dc7684c


69 files changed, +2963 -281 lines

README.md

Lines changed: 2 additions & 2 deletions

@@ -52,7 +52,7 @@ The Uncompromising Code Formatter: [Black](https://github.com/psf/black)
 ```black {source_file_or_directory}```
 
 Install it into pre-commit hook to always commit nicely formatted code:
-```pre-commmit install```
+```pre-commit install```
 
 ### Testing
 [Pytest](https://docs.pytest.org/en/latest/) and [tox](https://tox.readthedocs.io/en/latest/).
@@ -66,4 +66,4 @@ Add any custom models to `./fixtures` folder to test their conversion.
 ### Debugging
 Set `ConvertModel(..., debug=True)` to compare each converted
 activation from pytorch with the activation from onnxruntime.
-This helps identify where in the graph the activations start to differ.
+This helps identify where in the graph the activations start to differ.
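For context, the debugging workflow this README section describes looks roughly like the following minimal sketch (the model file name and input shape are placeholders, not from this commit):

```python
import onnx
import torch

from onnx2pytorch import ConvertModel

# Placeholder model path and input shape, for illustration only.
onnx_model = onnx.load("model.onnx")

# With debug=True, each converted activation is compared against the
# corresponding onnxruntime activation during the forward pass, so the
# first node whose outputs diverge is easy to locate.
pytorch_model = ConvertModel(onnx_model, debug=True)
out = pytorch_model(torch.randn(1, 3, 224, 224))
```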

onnx2pytorch/constants.py

Lines changed: 22 additions & 0 deletions

@@ -0,0 +1,22 @@
+from torch import nn
+from torch.nn.modules.conv import _ConvNd
+from torch.nn.modules.pooling import _MaxPoolNd
+from onnx2pytorch.operations import (
+    BatchNormWrapper,
+    InstanceNormWrapper,
+    Loop,
+    LSTMWrapper,
+    Split,
+    TopK,
+)
+
+
+COMPOSITE_LAYERS = (nn.Sequential,)
+MULTIOUTPUT_LAYERS = (_MaxPoolNd, Loop, LSTMWrapper, Split, TopK)
+STANDARD_LAYERS = (
+    _ConvNd,
+    BatchNormWrapper,
+    InstanceNormWrapper,
+    LSTMWrapper,
+    nn.Linear,
+)
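The new constants group converted layers by kind; presumably the converter dispatches on them with `isinstance` checks. A minimal sketch of that pattern (the `kind_of` helper is hypothetical, not part of this commit):

```python
from torch import nn

from onnx2pytorch.constants import (
    COMPOSITE_LAYERS,
    MULTIOUTPUT_LAYERS,
    STANDARD_LAYERS,
)


def kind_of(layer: nn.Module) -> str:
    """Hypothetical helper: classify a layer using the tuples above."""
    if isinstance(layer, COMPOSITE_LAYERS):
        return "composite"
    if isinstance(layer, MULTIOUTPUT_LAYERS):
        return "multioutput"
    if isinstance(layer, STANDARD_LAYERS):
        return "standard"
    return "other"


print(kind_of(nn.Linear(4, 2)))  # standard
print(kind_of(nn.MaxPool2d(2)))  # multioutput
```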

onnx2pytorch/convert/attribute.py

Lines changed: 102 additions & 53 deletions

@@ -41,6 +41,8 @@ def extract_attr_values(attr):
         value = numpy_helper.to_array(attr.t)
     elif attr.type == AttributeType["STRING"]:
         value = attr.s.decode()
+    elif attr.type == AttributeType["GRAPH"]:
+        value = attr.g
     else:
         raise NotImplementedError(
             "Extraction of attribute type {} not implemented.".format(attr.type)
@@ -52,21 +54,27 @@ def extract_attributes(node):
     """Extract onnx attributes. Map onnx feature naming to pytorch."""
     kwargs = {}
     for attr in node.attribute:
-        if attr.name == "dilations":
-            kwargs["dilation"] = extract_attr_values(attr)
-        elif attr.name == "group":
-            kwargs["groups"] = extract_attr_values(attr)
-        elif attr.name == "kernel_shape":
-            kwargs["kernel_size"] = extract_attr_values(attr)
-        elif attr.name == "pads":
-            params = extract_attr_values(attr)
-            if node.op_type == "Pad":
-                kwargs["padding"] = extract_padding_params(params)
+        if attr.name == "activation_alpha":
+            kwargs["activation_alpha"] = extract_attr_values(attr)
+        elif attr.name == "activation_beta":
+            kwargs["activation_beta"] = extract_attr_values(attr)
+        elif attr.name == "activations":
+            kwargs["activations"] = extract_attr_values(attr)
+        elif attr.name == "alpha":
+            if node.op_type == "LeakyRelu":
+                kwargs["negative_slope"] = extract_attr_values(attr)
+            elif node.op_type in ("Elu", "ThresholdedRelu"):
+                kwargs["alpha"] = extract_attr_values(attr)
             else:
-                # Works for Conv, MaxPooling and other layers from convert_layer func
-                kwargs["padding"] = extract_padding_params_for_conv_layer(params)
-        elif attr.name == "strides":
-            kwargs["stride"] = extract_attr_values(attr)
+                kwargs["weight_multiplier"] = extract_attr_values(attr)
+        elif attr.name == "auto_pad":
+            value = extract_attr_values(attr)
+            if value == "NOTSET":
+                pass
+            else:
+                raise NotImplementedError(
+                    "auto_pad={} functionality not implemented.".format(value)
+                )
         elif attr.name == "axis" and node.op_type == "Flatten":
             kwargs["start_dim"] = extract_attr_values(attr)
         elif attr.name == "axis" or attr.name == "axes":
@@ -75,62 +83,103 @@ def extract_attributes(node):
                 kwargs["dim"] = v[0]
             else:
                 kwargs["dim"] = v
-        elif attr.name == "keepdims":
-            kwargs["keepdim"] = bool(extract_attr_values(attr))
+        elif attr.name == "beta":
+            kwargs["bias_multiplier"] = extract_attr_values(attr)
+        elif attr.name == "body":
+            kwargs["body"] = extract_attr_values(attr)
+        elif attr.name == "ceil_mode":
+            kwargs["ceil_mode"] = bool(extract_attr_values(attr))
+        elif attr.name == "center_point_box":
+            kwargs["center_point_box"] = extract_attr_values(attr)
+        elif attr.name == "clip":
+            kwargs["clip"] = extract_attr_values(attr)
+        elif attr.name == "coordinate_transformation_mode":
+            arg = extract_attr_values(attr)
+            if arg == "align_corners":
+                kwargs["align_corners"] = True
+            else:
+                warnings.warn(
+                    "Pytorch's interpolate uses no coordinate_transformation_mode={}. "
+                    "Result might differ.".format(arg)
+                )
+        elif attr.name == "dilations":
+            kwargs["dilation"] = extract_attr_values(attr)
+        elif attr.name == "direction":
+            kwargs["direction"] = extract_attr_values(attr)
+        elif attr.name == "ends":
+            kwargs["ends"] = extract_attr_values(attr)
         elif attr.name == "epsilon":
             kwargs["eps"] = extract_attr_values(attr)
+        elif attr.name == "group":
+            kwargs["groups"] = extract_attr_values(attr)
+        elif attr.name == "hidden_size":
+            kwargs["hidden_size"] = extract_attr_values(attr)
+        elif attr.name == "input_forget":
+            kwargs["input_forget"] = extract_attr_values(attr)
+        elif attr.name == "keepdims":
+            kwargs["keepdim"] = bool(extract_attr_values(attr))
+        elif attr.name == "kernel_shape":
+            kwargs["kernel_size"] = extract_attr_values(attr)
+        elif attr.name == "largest":
+            kwargs["largest"] = extract_attr_values(attr)
+        elif attr.name == "layout":
+            kwargs["layout"] = extract_attr_values(attr)
+        elif attr.name == "mode":
+            kwargs["mode"] = extract_attr_values(attr)
         elif attr.name == "momentum":
             kwargs["momentum"] = extract_attr_values(attr)
-        elif attr.name == "ceil_mode":
-            kwargs["ceil_mode"] = bool(extract_attr_values(attr))
-        elif attr.name == "value":
-            kwargs["constant"] = extract_attr_values(attr)
+        elif attr.name == "noop_with_empty_axes":
+            kwargs["noop_with_empty_axes"] = extract_attr_values(attr)
+        elif attr.name == "output_shape" and node.op_type == "ConvTranspose":
+            raise NotImplementedError(
+                "ConvTranspose with dynamic padding not implemented."
+            )
+        elif attr.name == "pads":
+            params = extract_attr_values(attr)
+            if node.op_type == "Pad":
+                kwargs["padding"] = extract_padding_params(params)
+            else:
+                # Works for Conv, MaxPooling and other layers from convert_layer func
+                kwargs["padding"] = extract_padding_params_for_conv_layer(params)
         elif attr.name == "perm":
             kwargs["dims"] = extract_attr_values(attr)
-        elif attr.name == "split":
-            kwargs["split_size_or_sections"] = extract_attr_values(attr)
+        elif attr.name == "repeats":
+            kwargs["repeats"] = extract_attr_values(attr)
+        elif attr.name == "sorted":
+            kwargs["sorted"] = extract_attr_values(attr)
+        elif attr.name == "sparse_value":
+            kwargs["constant"] = extract_attr_values(attr)
         elif attr.name == "spatial":
             kwargs["spatial"] = extract_attr_values(attr)  # Batch norm parameter
+        elif attr.name == "split":
+            kwargs["split_size_or_sections"] = extract_attr_values(attr)
+        elif attr.name == "strides":
+            kwargs["stride"] = extract_attr_values(attr)
+        elif attr.name == "starts":
+            kwargs["starts"] = extract_attr_values(attr)
         elif attr.name == "to":
             kwargs["dtype"] = TENSOR_PROTO_MAPPING[extract_attr_values(attr)].lower()
-        elif attr.name == "mode":
-            kwargs["mode"] = extract_attr_values(attr)
         elif attr.name == "transB":
             kwargs["transpose_weight"] = not extract_attr_values(attr)
         elif attr.name == "transA":
             kwargs["transpose_activation"] = bool(extract_attr_values(attr))
-        elif attr.name == "alpha" and node.op_type == "LeakyRelu":
-            kwargs["negative_slope"] = extract_attr_values(attr)
-        elif attr.name == "alpha" and node.op_type == "Elu":
-            kwargs["alpha"] = extract_attr_values(attr)
-        elif attr.name == "alpha":
-            kwargs["weight_multiplier"] = extract_attr_values(attr)
-        elif attr.name == "beta":
-            kwargs["bias_multiplier"] = extract_attr_values(attr)
-        elif attr.name == "starts":
-            kwargs["starts"] = extract_attr_values(attr)
-        elif attr.name == "ends":
-            kwargs["ends"] = extract_attr_values(attr)
-        elif attr.name == "coordinate_transformation_mode":
-            arg = extract_attr_values(attr)
-            if arg == "align_corners":
-                kwargs["align_corners"] = True
-            else:
-                warnings.warn(
-                    "Pytorch's interpolate uses no coordinate_transformation_mode={}. "
-                    "Result might differ.".format(arg)
-                )
+        elif attr.name == "value":
+            kwargs["constant"] = extract_attr_values(attr)
+        elif attr.name == "value_float":
+            kwargs["constant"] = extract_attr_values(attr)
+        elif attr.name == "value_floats":
+            kwargs["constant"] = extract_attr_values(attr)
+        elif attr.name == "value_int":
+            kwargs["constant"] = extract_attr_values(attr)
+        elif attr.name == "value_ints":
+            kwargs["constant"] = extract_attr_values(attr)
+        elif attr.name == "value_string":
+            kwargs["constant"] = extract_attr_values(attr)
+        elif attr.name == "value_strings":
+            kwargs["constant"] = extract_attr_values(attr)
         elif node.op_type == "Resize":
            # These parameters are not used, warn in Resize operator
             kwargs[attr.name] = extract_attr_values(attr)
-        elif attr.name == "auto_pad":
-            value = extract_attr_values(attr)
-            if value == "NOTSET":
-                pass
-            else:
-                raise NotImplementedError(
-                    "auto_pad={} functionality not implemented.".format(value)
-                )
         else:
             raise NotImplementedError(
                 "Extraction of attribute {} not implemented.".format(attr.name)

onnx2pytorch/convert/debug.py

Lines changed: 10 additions & 5 deletions

@@ -10,13 +10,18 @@ def debug_model_conversion(onnx_model, inputs, pred_act, node, rtol=1e-3, atol=1
         raise TypeError("inputs should be in a list.")
 
     if not all(isinstance(x, np.ndarray) for x in inputs):
-        inputs = [x.detach().numpy() for x in inputs]
+        inputs = [x.detach().cpu().numpy() for x in inputs]
 
     exp_act = get_activation_value(onnx_model, inputs, list(node.output))
     if isinstance(pred_act, list):
+        assert len(exp_act) == len(pred_act)
         for a, b in zip(exp_act, pred_act):
-            assert torch.allclose(torch.from_numpy(a), b, rtol=rtol, atol=atol)
+            exp = torch.from_numpy(a).cpu()
+            pred = b.cpu()
+            assert torch.equal(torch.tensor(exp.shape), torch.tensor(pred.shape))
+            assert torch.allclose(exp, pred, rtol=rtol, atol=atol)
     else:
-        a = torch.from_numpy(exp_act[0])
-        b = pred_act
-        assert torch.allclose(a, b, rtol=rtol, atol=atol)
+        exp = torch.from_numpy(exp_act[0]).cpu()
+        pred = pred_act.cpu()
+        assert torch.equal(torch.tensor(exp.shape), torch.tensor(pred.shape))
+        assert torch.allclose(exp, pred, rtol=rtol, atol=atol)
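Worth noting about the added shape assertions: `torch.allclose` follows broadcasting semantics (like `numpy.allclose`), so tensors of different shapes can still compare as close. The explicit shape check catches that case; a small demonstration:

```python
import torch

a = torch.zeros(1, 3)
b = torch.zeros(3)

# allclose broadcasts, so the shape mismatch alone does not fail the check:
print(torch.allclose(a, b))  # True

# The shape comparison used in debug.py flags the mismatch:
print(torch.equal(torch.tensor(a.shape), torch.tensor(b.shape)))  # False
```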
