Commit c80e879

Merge pull request #1 from ToriML/develop

New models support

2 parents: 305cf96 + e678e93

25 files changed: +672 −57 lines

README.md

Lines changed: 24 additions & 6 deletions

@@ -1,7 +1,8 @@
 # ONNX to PyTorch
-[![GitHub License](https://img.shields.io/badge/License-Apache-2.svg)](https://opensource.org/licenses/Apache-2.0)
+![PyPI - License](https://img.shields.io/pypi/l/onnx2pytorch?color)
 [![CircleCI](https://circleci.com/gh/ToriML/onnx2pytorch.svg?style=shield)](https://app.circleci.com/pipelines/github/ToriML/onnx2pytorch)
 [![Downloads](https://pepy.tech/badge/onnx2pytorch)](https://pepy.tech/project/onnx2pytorch)
+![PyPI](https://img.shields.io/pypi/v/onnx2pytorch)
 
 A library to transform ONNX models into PyTorch. This library enables use of the PyTorch
 backend and all of its great features for manipulation of neural networks.

@@ -19,10 +20,17 @@ pytorch_model = ConvertModel(onnx_model)
 ```
 
 Currently supported and tested models from the [onnx_zoo](https://github.com/onnx/models):
-- MobileNet
-- ResNet
-- ShuffleNet
-- Bert
+- [MobileNet](https://github.com/onnx/models/tree/master/vision/classification/mobilenet)
+- [ResNet](https://github.com/onnx/models/tree/master/vision/classification/resnet)
+- [ShuffleNet_V2](https://github.com/onnx/models/tree/master/vision/classification/shufflenet)
+- [BERT-Squad](https://github.com/onnx/models/tree/master/text/machine_comprehension/bert-squad)
+- [EfficientNet-Lite4](https://github.com/onnx/models/tree/master/vision/classification/efficientnet-lite4)
+- [Fast Neural Style Transfer](https://github.com/onnx/models/tree/master/vision/style_transfer/fast_neural_style)
+- [Super Resolution](https://github.com/onnx/models/tree/master/vision/super_resolution/sub_pixel_cnn_2016)
+- [YOLOv4](https://github.com/onnx/models/tree/master/vision/object_detection_segmentation/yolov4)
+  (not exactly the same: nearest-neighbour interpolation differs in PyTorch)
+- [U-net](https://pytorch.org/hub/mateuszbuda_brain-segmentation-pytorch_unet/)
+  (converted from PyTorch to ONNX and then back)
 
 ## Limitations
 Known current version limitations are:

@@ -48,4 +56,14 @@ Install it into pre-commit hook to always commit nicely formatted code:
 
 ### Testing
 [Pytest](https://docs.pytest.org/en/latest/) and [tox](https://tox.readthedocs.io/en/latest/).
-```tox```
+```tox```
+
+#### Test fixtures
+To test the complete conversion of an ONNX model, download the pre-trained models:
+```./download_fixtures.sh```
+Use the `--all` flag to download more models.
+Add any custom models to the `./fixtures` folder to test their conversion.
+
+### Debugging
+Set `ConvertModel(..., debug=True)` to compare each converted
+activation from PyTorch with the activation from onnxruntime.
+This helps identify where in the graph the activations start to differ.
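The README's conversion snippet and the new `debug` flag combine as follows; a minimal sketch, assuming a fixture fetched by `./download_fixtures.sh` (the file path is illustrative) and that `ConvertModel` is importable from the package root:

```python
# Minimal sketch of the conversion flow; the fixture path is illustrative.
import onnx
from onnx2pytorch import ConvertModel

onnx_model = onnx.load("fixtures/mobilenetv2-1.0.onnx")
pytorch_model = ConvertModel(onnx_model)

# With debug=True, each converted activation is checked against
# onnxruntime during the forward pass, per the Debugging section above.
debug_model = ConvertModel(onnx_model, debug=True)
```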

download_fixtures.sh

Mode changed: 100644 → 100755
Lines changed: 33 additions & 13 deletions

@@ -1,3 +1,4 @@
+#!/usr/bin/env bash
 mkdir -p fixtures
 cd fixtures
 
@@ -6,22 +7,41 @@ if [[ ! -f mobilenetv2-1.0.onnx ]]; then
   curl -o mobilenetv2-1.0.onnx https://s3.amazonaws.com/onnx-model-zoo/mobilenet/mobilenetv2-1.0/mobilenetv2-1.0.onnx
 fi
 
-#if [[ ! -f resnet18v1.onnx ]]; then
-#  echo Downloading resnet18v1
-#  curl -o resnet18v1.onnx https://s3.amazonaws.com/onnx-model-zoo/resnet/resnet18v1/resnet18v1.onnx
-#fi
-
 if [[ ! -f shufflenet_v2.onnx ]]; then
   echo Downloading shufflenet_v2
   curl -LJo shufflenet_v2.onnx https://github.com/onnx/models/blob/master/vision/classification/shufflenet/model/shufflenet-v2-10.onnx\?raw\=true
 fi
 
-#if [[ ! -f bertsquad-10.onnx ]]; then
-#  echo Downloading bertsquad-10
-#  curl -LJo bertsquad-10.onnx https://github.com/onnx/models/blob/master/text/machine_comprehension/bert-squad/model/bertsquad-10.onnx\?raw\=true
-#fi
+if [[ $1 == "--all" ]]; then
+  if [[ ! -f resnet18v1.onnx ]]; then
+    echo Downloading resnet18v1
+    curl -o resnet18v1.onnx https://s3.amazonaws.com/onnx-model-zoo/resnet/resnet18v1/resnet18v1.onnx
+  fi
+
+  if [[ ! -f bertsquad-10.onnx ]]; then
+    echo Downloading bertsquad-10
+    curl -LJo bertsquad-10.onnx https://github.com/onnx/models/blob/master/text/machine_comprehension/bert-squad/model/bertsquad-10.onnx\?raw\=true
+  fi
+
+  if [[ ! -f yolo_v4.onnx ]]; then
+    echo Downloading yolo_v4
+    curl -LJo yolo_v4.onnx https://github.com/onnx/models/blob/master/vision/object_detection_segmentation/yolov4/model/yolov4.onnx\?raw\=true
+  fi
+
+  if [[ ! -f super_res.onnx ]]; then
+    echo Downloading super_res
+    curl -LJo super_res.onnx https://github.com/onnx/models/blob/master/vision/super_resolution/sub_pixel_cnn_2016/model/super-resolution-10.onnx\?raw\=true
+  fi
+
+  if [[ ! -f fast_neural_style.onnx ]]; then
+    echo Downloading fast_neural_style
+    curl -LJo fast_neural_style.onnx https://github.com/onnx/models/blob/master/vision/style_transfer/fast_neural_style/model/rain-princess-9.onnx\?raw\=true
+  fi
+
+  if [[ ! -f efficientnet-lite4.onnx ]]; then
+    echo Downloading efficientnet-lite4
+    curl -LJo efficientnet-lite4.onnx https://github.com/onnx/models/blob/master/vision/classification/efficientnet-lite4/model/efficientnet-lite4-11.onnx\?raw\=true
+  fi
+fi
 
-#if [[ ! -f yolo_v4.onnx ]]; then
-#  echo Downloading yolo_v4
-#  curl -LJo yolo_v4.onnx https://github.com/onnx/models/blob/master/vision/object_detection_segmentation/yolov4/model/yolov4.onnx\?raw\=true
-#fi
+echo "All models downloaded."

onnx2pytorch/convert/attribute.py

Lines changed: 30 additions & 2 deletions

@@ -1,7 +1,12 @@
+import warnings
+
 import onnx
 from onnx import numpy_helper
 
-from onnx2pytorch.utils import to_pytorch_params
+from onnx2pytorch.utils import (
+    extract_padding_params_for_conv_layer,
+    extract_padding_params,
+)
 
 TENSOR_PROTO_MAPPING = dict([i[::-1] for i in onnx.TensorProto.DataType.items()])
 
@@ -54,7 +59,12 @@ def extract_attributes(node):
         elif attr.name == "kernel_shape":
             kwargs["kernel_size"] = extract_attr_values(attr)
         elif attr.name == "pads":
-            kwargs["padding"] = to_pytorch_params(extract_attr_values(attr))
+            params = extract_attr_values(attr)
+            if node.op_type == "Pad":
+                kwargs["padding"] = extract_padding_params(params)
+            else:
+                # Works for Conv, MaxPool and other layers built by convert_layer
+                kwargs["padding"] = extract_padding_params_for_conv_layer(params)
         elif attr.name == "strides":
             kwargs["stride"] = extract_attr_values(attr)
         elif attr.name == "axis" and node.op_type == "Flatten":
@@ -89,10 +99,28 @@ def extract_attributes(node):
             kwargs["transpose_weight"] = not extract_attr_values(attr)
         elif attr.name == "transA":
             kwargs["transpose_activation"] = bool(extract_attr_values(attr))
+        elif attr.name == "alpha" and node.op_type == "LeakyRelu":
+            kwargs["negative_slope"] = extract_attr_values(attr)
         elif attr.name == "alpha":
             kwargs["weight_multiplier"] = extract_attr_values(attr)
         elif attr.name == "beta":
             kwargs["bias_multiplier"] = extract_attr_values(attr)
+        elif attr.name == "starts":
+            kwargs["starts"] = extract_attr_values(attr)
+        elif attr.name == "ends":
+            kwargs["ends"] = extract_attr_values(attr)
+        elif attr.name == "coordinate_transformation_mode":
+            arg = extract_attr_values(attr)
+            if arg == "align_corners":
+                kwargs["align_corners"] = True
+            else:
+                warnings.warn(
+                    "Pytorch's interpolate does not support "
+                    "coordinate_transformation_mode={}. Result might differ.".format(arg)
+                )
+        elif node.op_type == "Resize":
+            # These parameters are not used here; the Resize operator warns about them
+            kwargs[attr.name] = extract_attr_values(attr)
         elif attr.name == "auto_pad":
             value = extract_attr_values(attr)
             if value == "NOTSET":
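To illustrate the `pads` branch above: ONNX stores `pads` as begin values followed by end values per spatial axis, so a symmetric `[1, 1, 1, 1]` on a Conv node should reduce to plain symmetric padding. A sketch using `onnx.helper` for node construction; the exact value returned by `extract_padding_params_for_conv_layer` is not shown in this diff, so the print is only for inspection:

```python
# Illustrative only: build a Conv node with symmetric pads and run it
# through extract_attributes.
import onnx
from onnx2pytorch.convert.attribute import extract_attributes

node = onnx.helper.make_node(
    "Conv",
    inputs=["x", "w"],
    outputs=["y"],
    kernel_shape=[3, 3],
    pads=[1, 1, 1, 1],  # [begin_h, begin_w, end_h, end_w], symmetric here
    strides=[1, 1],
)
kwargs = extract_attributes(node)
print(kwargs["kernel_size"], kwargs["stride"], kwargs["padding"])
```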

onnx2pytorch/convert/debug.py

New file — Lines changed: 22 additions & 0 deletions

@@ -0,0 +1,22 @@
+import torch
+import numpy as np
+
+from onnx2pytorch.utils import get_activation_value
+
+
+def debug_model_conversion(onnx_model, inputs, pred_act, node, rtol=1e-3, atol=1e-4):
+    """Check that the converted pytorch activations match those from onnxruntime."""
+    if not isinstance(inputs, list):
+        raise TypeError("inputs should be in a list.")
+
+    if not all(isinstance(x, np.ndarray) for x in inputs):
+        inputs = [x.detach().numpy() for x in inputs]
+
+    exp_act = get_activation_value(onnx_model, inputs, list(node.output))
+    if isinstance(pred_act, list):
+        for a, b in zip(exp_act, pred_act):
+            assert torch.allclose(torch.from_numpy(a), b, rtol=rtol, atol=atol)
+    else:
+        a = torch.from_numpy(exp_act[0])
+        b = pred_act
+        assert torch.allclose(a, b, rtol=rtol, atol=atol)
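The assertions above rely on `torch.allclose` with the defaults `rtol=1e-3, atol=1e-4`; a tiny self-contained illustration of what that tolerance accepts:

```python
import torch

# allclose passes when |a - b| <= atol + rtol * |b|, elementwise.
a = torch.tensor([1.0000, 2.0000])
b = torch.tensor([1.0004, 2.0019])
print(torch.allclose(a, b, rtol=1e-3, atol=1e-4))  # True: both diffs within bound
```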

onnx2pytorch/convert/layer.py

Lines changed: 23 additions & 1 deletion

@@ -2,7 +2,7 @@
 from torch import nn
 from onnx import numpy_helper
 
-from onnx2pytorch.operations import BatchNormUnsafe
+from onnx2pytorch.operations import BatchNormUnsafe, InstanceNormUnsafe
 from onnx2pytorch.convert.attribute import extract_attributes, extract_attr_values
 
 
@@ -45,6 +45,7 @@ def convert_layer(node, layer_type, params=None):
     )
 
     if params:
+        pad_layer = None
         weight, bias = extract_params(params)
         kwargs["bias"] = bias is not None
         kwargs["in_channels"] = weight.dims[1] * kwargs.get("groups", 1)
@@ -56,9 +57,15 @@ def convert_layer(node, layer_type, params=None):
             kwargs["in_channels"],
         )
 
+        # If padding is a layer, remove it from kwargs and prepend it later
+        if isinstance(kwargs["padding"], nn.Module):
+            pad_layer = kwargs.pop("padding")
+
         # initialize layer and load weights
         layer = layer(**kwargs)
         load_params(layer, weight, bias)
+        if pad_layer is not None:
+            layer = nn.Sequential(pad_layer, layer)
     else:
         # initialize operations without parameters (MaxPool, AvgPool, etc.)
         layer = layer(**kwargs)
@@ -80,6 +87,21 @@ def convert_batch_norm_layer(node, params):
     return layer
 
 
+def convert_instance_norm_layer(node, params):
+    kwargs = extract_attributes(node)
+    # Skips the input dimension check, which is not possible before the forward pass
+    layer = InstanceNormUnsafe
+
+    kwargs["num_features"] = params[0].dims[0]
+    # initialize layer and load weights
+    layer = layer(**kwargs)
+    keys = ["weight", "bias"]
+    for key, value in zip(keys, params):
+        getattr(layer, key).data = torch.from_numpy(numpy_helper.to_array(value))
+
+    return layer
+
+
 def convert_linear_layer(node, params):
     """Convert linear layer from onnx node and params."""
     # Default Gemm attributes
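The `nn.Sequential` wrapping above exists because torch convolutions only accept symmetric padding, so an asymmetric ONNX `pads` value has to become an explicit pad layer in front of the conv. A hand-built sketch of the resulting module; the pad values and channel sizes are illustrative:

```python
import torch
from torch import nn

# Asymmetric ONNX pads cannot be expressed via Conv2d's `padding` kwarg,
# so a pad layer is prepended, mirroring the pad_layer branch above.
pad_layer = nn.ConstantPad2d((0, 1, 0, 1), 0.0)  # (left, right, top, bottom)
conv = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3)
layer = nn.Sequential(pad_layer, conv)

x = torch.randn(1, 3, 32, 32)
print(layer(x).shape)  # torch.Size([1, 16, 31, 31]): pad to 33x33, then 3x3 conv
```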

onnx2pytorch/convert/model.py

Lines changed: 24 additions & 4 deletions

@@ -8,10 +8,13 @@
 from torch.jit import TracerWarning
 from torch.nn.modules.conv import _ConvNd
 from torch.nn.modules.batchnorm import _BatchNorm
+from torch.nn.modules.instancenorm import _InstanceNorm
 from torch.nn.modules.linear import Identity
 
 from onnx2pytorch.operations import Split
+from onnx2pytorch.convert.debug import debug_model_conversion
 from onnx2pytorch.convert.operations import convert_operations
+from onnx2pytorch.utils import get_inputs_names
 
 
 class InitParameters(dict):
@@ -30,7 +33,9 @@ def get(self, item, default):
 
 
 class ConvertModel(nn.Module):
-    def __init__(self, onnx_model: onnx.ModelProto, batch_dim=0, experimental=False):
+    def __init__(
+        self, onnx_model: onnx.ModelProto, batch_dim=0, experimental=False, debug=False
+    ):
         """
         Convert onnx model to pytorch.
 
@@ -53,6 +58,7 @@ def __init__(self, onnx_model: onnx.ModelProto, batch_dim=0, experimental=False)
         self.onnx_model = onnx_model
         self.batch_dim = batch_dim
         self.experimental = experimental
+        self.debug = debug
         self.mapping = {}
         for op_id, op_name, op in convert_operations(onnx_model, batch_dim):
             setattr(self, op_name, op)
@@ -62,6 +68,8 @@ def __init__(self, onnx_model: onnx.ModelProto, batch_dim=0, experimental=False)
             {tensor.name: tensor for tensor in self.onnx_model.graph.initializer}
         )
 
+        self.input_names = get_inputs_names(onnx_model)
+
         if experimental:
             warnings.warn(
                 "Using experimental implementation that allows 'batch_size > 1'."
@@ -74,8 +82,7 @@ def forward(self, *input):
                 "Input with larger batch size than 1 not supported yet."
             )
         # TODO figure out how to store only necessary activations.
-        input_names = [x.name for x in self.onnx_model.graph.input]
-        activations = dict(zip(input_names, input))
+        activations = dict(zip(self.input_names, input))
 
         for node in self.onnx_model.graph.node:
             # Identifying the layer ids and names
@@ -93,7 +100,11 @@ def forward(self, *input):
             # if first layer choose input as in_activations
             # if not in_op_names and len(node.input) == 1:
             #     in_activations = input
-            if isinstance(op, (nn.Linear, _ConvNd, _BatchNorm)):
+            layer_types = (nn.Linear, _ConvNd, _BatchNorm, _InstanceNorm)
+            if isinstance(op, layer_types) or (
+                isinstance(op, nn.Sequential)
+                and any(isinstance(x, layer_types) for x in op.modules())
+            ):
                 in_activations = [
                     activations[in_op_id]
                     for in_op_id in node.input
@@ -122,6 +133,15 @@ def forward(self, *input):
             else:
                 activations[out_op_id] = op(*in_activations)
 
+            if self.debug:
+                # compare the pytorch activations with those from onnxruntime
+                debug_model_conversion(
+                    self.onnx_model,
+                    [activations[x] for x in self.input_names],
+                    activations[out_op_id],
+                    node,
+                )
+
         # collect all outputs
         outputs = [activations[x.name] for x in self.onnx_model.graph.output]
         if len(outputs) == 1:
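What `debug=True` automates per node can also be checked once at the model output. A hedged end-to-end sketch, assuming onnxruntime is installed and reusing the tolerances from `debug_model_conversion`; the model path and input shape are illustrative:

```python
import numpy as np
import onnx
import onnxruntime as ort
import torch
from onnx2pytorch import ConvertModel

path = "fixtures/mobilenetv2-1.0.onnx"  # illustrative fixture
onnx_model = onnx.load(path)
model = ConvertModel(onnx_model)

x = torch.randn(1, 3, 224, 224)
with torch.no_grad():
    torch_out = model(x)

# Reference output from onnxruntime on the same input.
sess = ort.InferenceSession(path, providers=["CPUExecutionProvider"])
input_name = sess.get_inputs()[0].name
ort_out = sess.run(None, {input_name: x.numpy()})[0]

# Same tolerances as debug_model_conversion's defaults.
assert np.allclose(torch_out.numpy(), ort_out, rtol=1e-3, atol=1e-4)
```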

onnx2pytorch/convert/operations.py

Lines changed: 16 additions & 1 deletion

@@ -2,16 +2,19 @@
 
 import torch
 from torch import nn
+from torch.nn import functional as F
 from onnx import numpy_helper
 
 from onnx2pytorch.convert.attribute import extract_attributes
 from onnx2pytorch.convert.layer import (
     convert_layer,
     convert_linear_layer,
     convert_batch_norm_layer,
+    convert_instance_norm_layer,
 )
 from onnx2pytorch.operations import *
 from onnx2pytorch.operations.base import OperatorWrapper
+from onnx2pytorch.operations import Resize, Upsample
 from onnx2pytorch.utils import value_wrapper
 
 
@@ -41,6 +44,8 @@ def convert_operations(onnx_model, batch_dim=0):
             op = convert_layer(node, "Conv", params)
         elif node.op_type == "Relu":
             op = nn.ReLU(inplace=True)
+        elif node.op_type == "LeakyRelu":
+            op = nn.LeakyReLU(**extract_attributes(node), inplace=True)
         elif node.op_type == "Sigmoid":
             op = nn.Sigmoid()
         elif node.op_type == "MaxPool":
@@ -54,6 +59,8 @@ def convert_operations(onnx_model, batch_dim=0):
             op.feature_dim = batch_dim + 1  # Necessary for transformers
         elif node.op_type == "BatchNormalization":
             op = convert_batch_norm_layer(node, params=params)
+        elif node.op_type == "InstanceNormalization":
+            op = convert_instance_norm_layer(node, params=params)
         elif node.op_type == "Concat":
             op = partial(torch.cat, **extract_attributes(node))
         elif node.op_type == "Constant":
@@ -75,7 +82,7 @@ def convert_operations(onnx_model, batch_dim=0):
         elif node.op_type == "ConstantOfShape":
             op = ConstantOfShape(**extract_attributes(node))
         elif node.op_type == "Slice":
-            op = Slice()
+            op = Slice(**extract_attributes(node))
         elif node.op_type == "Cast":
             op = Cast(**extract_attributes(node))
         elif node.op_type == "Where":
@@ -136,6 +143,10 @@ def convert_operations(onnx_model, batch_dim=0):
             op = convert_layer(node, "ConvTranspose", params)
         elif node.op_type == "Identity":
             op = nn.Identity()
+        elif node.op_type == "Resize":
+            op = Resize(**extract_attributes(node))
+        elif node.op_type == "Upsample":
+            op = Upsample(**extract_attributes(node))
         elif node.op_type == "OneHot":
             op = OneHot(**extract_attributes(node))
         elif node.op_type == "Pad":
@@ -146,6 +157,10 @@ def convert_operations(onnx_model, batch_dim=0):
             op = OperatorWrapper(torch.tanh)
         elif node.op_type == "Erf":
             op = OperatorWrapper(torch.erf)
+        elif node.op_type == "Log":
+            op = OperatorWrapper(torch.log)
+        elif node.op_type == "Exp":
+            op = OperatorWrapper(torch.exp)
         else:
             op = getattr(torch, node.op_type.lower(), None)
             if op is None:
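As a quick illustration of the `LeakyRelu` branch above: `extract_attributes` renames the ONNX `alpha` attribute to `negative_slope` (see attribute.py earlier in this commit), so the constructed op behaves as:

```python
import torch
from torch import nn

# What convert_operations builds for a LeakyRelu node with alpha=0.1;
# the alpha -> negative_slope rename happens in extract_attributes.
op = nn.LeakyReLU(negative_slope=0.1, inplace=True)
print(op(torch.tensor([-1.0, 2.0])))  # tensor([-0.1000,  2.0000])
```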
