
Commit 8eb6ae8

Merge pull request #9 from ToriML/develop
Version 0.3.0
2 parents c80e879 + 1ca6b40 commit 8eb6ae8

16 files changed: +189 -22 lines changed

onnx2pytorch/convert/attribute.py

Lines changed: 2 additions & 0 deletions
@@ -101,6 +101,8 @@ def extract_attributes(node):
             kwargs["transpose_activation"] = bool(extract_attr_values(attr))
         elif attr.name == "alpha" and node.op_type == "LeakyRelu":
             kwargs["negative_slope"] = extract_attr_values(attr)
+        elif attr.name == "alpha" and node.op_type == "Elu":
+            kwargs["alpha"] = extract_attr_values(attr)
         elif attr.name == "alpha":
             kwargs["weight_multiplier"] = extract_attr_values(attr)
         elif attr.name == "beta":
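
The new Elu branch forwards the ONNX `alpha` attribute to `nn.ELU`'s `alpha` argument. A minimal sketch of the equivalence (the tensor values here are illustrative, not taken from the library's tests):

import torch
from torch import nn

# ONNX Elu computes alpha * (exp(x) - 1) for x < 0 and x otherwise,
# which matches nn.ELU(alpha=...).
alpha = 0.5
x = torch.tensor([-2.0, 0.0, 2.0])
expected = torch.where(x < 0, alpha * (torch.exp(x) - 1), x)
assert torch.allclose(nn.ELU(alpha=alpha)(x), expected)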

onnx2pytorch/convert/layer.py

Lines changed: 13 additions & 8 deletions
@@ -44,8 +44,8 @@ def convert_layer(node, layer_type, params=None):
             "Unexpected length of kernel_size dimension: {}".format(kernel_size_length)
         )

+    pad_layer = None
     if params:
-        pad_layer = None
         weight, bias = extract_params(params)
         kwargs["bias"] = bias is not None
         kwargs["in_channels"] = weight.dims[1] * kwargs.get("groups", 1)
@@ -58,18 +58,23 @@ def convert_layer(node, layer_type, params=None):
         )

         # if padding is a layer, remove from kwargs and prepend later
-        if isinstance(kwargs["padding"], nn.Module):
+        if "padding" in kwargs and isinstance(kwargs["padding"], nn.Module):
             pad_layer = kwargs.pop("padding")

         # initialize layer and load weights
         layer = layer(**kwargs)
         load_params(layer, weight, bias)
-        if pad_layer is not None:
-            layer = nn.Sequential(pad_layer, layer)
     else:
         # initialize operations without parameters (MaxPool, AvgPool, etc.)
+
+        # if padding is a layer, remove from kwargs and prepend later
+        if "padding" in kwargs and isinstance(kwargs["padding"], nn.Module):
+            pad_layer = kwargs.pop("padding")
         layer = layer(**kwargs)

+    if pad_layer is not None:
+        layer = nn.Sequential(pad_layer, layer)
+
     return layer


@@ -80,8 +85,8 @@ def convert_batch_norm_layer(node, params):
     kwargs["num_features"] = params[0].dims[0]
     # initialize layer and load weights
     layer = layer(**kwargs)
-    key = ["weight", "bias", "running_mean", "running_var"]
-    for key, value in zip(key, params):
+    keys = ["weight", "bias", "running_mean", "running_var"]
+    for key, value in zip(keys, params):
         getattr(layer, key).data = torch.from_numpy(numpy_helper.to_array(value))

     return layer
@@ -95,8 +100,8 @@ def convert_instance_norm_layer(node, params):
     kwargs["num_features"] = params[0].dims[0]
     # initialize layer and load weights
     layer = layer(**kwargs)
-    key = ["weight", "bias"]
-    for key, value in zip(key, params):
+    keys = ["weight", "bias"]
+    for key, value in zip(keys, params):
         getattr(layer, key).data = torch.from_numpy(numpy_helper.to_array(value))

     return layer
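
Hoisting `pad_layer` above both branches means an explicit padding module is now prepended for parameterless ops (MaxPool, AvgPool, etc.) as well, not only for layers with weights, and the `nn.Sequential` wrap happens once at the end. A rough sketch of that pattern with assumed kwargs (the real function derives them from the ONNX node):

import torch
from torch import nn

# Assumed kwargs: asymmetric padding that PyTorch pooling cannot express directly.
kwargs = {"kernel_size": 2, "stride": 2, "padding": nn.ConstantPad2d((0, 1, 0, 1), 0)}

pad_layer = None
if "padding" in kwargs and isinstance(kwargs["padding"], nn.Module):
    pad_layer = kwargs.pop("padding")

layer = nn.MaxPool2d(**kwargs)
if pad_layer is not None:
    layer = nn.Sequential(pad_layer, layer)

print(layer(torch.randn(1, 3, 5, 5)).shape)  # torch.Size([1, 3, 3, 3])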

onnx2pytorch/convert/operations.py

Lines changed: 18 additions & 3 deletions
@@ -20,7 +20,7 @@

 def convert_operations(onnx_model, batch_dim=0):
     """
-    Convert onnx model operations. Yields onnx's operator_id, opeartor_name and
+    Convert onnx model operations. Yields onnx's operator_id, operator_name and
     converted pytorch operator.

     Parameters
@@ -35,6 +35,7 @@ def convert_operations(onnx_model, batch_dim=0):
     iterator: (op_id, op_name, op)
     """
     weights = {tensor.name: tensor for tensor in onnx_model.graph.initializer}
+    opset_version = onnx_model.opset_import[0].version

     for i, node in enumerate(onnx_model.graph.node):
         # extract only useful inputs
@@ -46,6 +47,8 @@ def convert_operations(onnx_model, batch_dim=0):
             op = nn.ReLU(inplace=True)
         elif node.op_type == "LeakyRelu":
             op = nn.LeakyReLU(**extract_attributes(node), inplace=True)
+        elif node.op_type == "Elu":
+            op = nn.ELU(**extract_attributes(node), inplace=True)
         elif node.op_type == "Sigmoid":
             op = nn.Sigmoid()
         elif node.op_type == "MaxPool":
@@ -73,14 +76,18 @@ def convert_operations(onnx_model, batch_dim=0):
             op = Reshape(shape)
         elif node.op_type == "Shape":
             op = Shape()
+        elif node.op_type == "Expand":
+            op = Expand()
         elif node.op_type == "Gather":
             op = Gather(**extract_attributes(node))
         elif node.op_type == "Squeeze":
-            op = Squeeze(**extract_attributes(node))
+            op = Squeeze(opset_version=opset_version, **extract_attributes(node))
         elif node.op_type == "Unsqueeze":
-            op = partial(torch.unsqueeze, **extract_attributes(node))
+            op = Unsqueeze(opset_version=opset_version, **extract_attributes(node))
         elif node.op_type == "ConstantOfShape":
             op = ConstantOfShape(**extract_attributes(node))
+        elif node.op_type == "Range":
+            op = Range()
         elif node.op_type == "Slice":
             op = Slice(**extract_attributes(node))
         elif node.op_type == "Cast":
@@ -161,6 +168,14 @@ def convert_operations(onnx_model, batch_dim=0):
             op = OperatorWrapper(torch.log)
         elif node.op_type == "Exp":
             op = OperatorWrapper(torch.exp)
+        elif node.op_type == "Reciprocal":
+            op = OperatorWrapper(torch.reciprocal)
+        elif node.op_type == "And":
+            op = OperatorWrapper(torch.logical_and)
+        elif node.op_type == "Or":
+            op = OperatorWrapper(torch.logical_or)
+        elif node.op_type == "Not":
+            op = OperatorWrapper(torch.logical_not)
         else:
             op = getattr(torch, node.op_type.lower(), None)
             if op is None:
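
The converter now reads the model's opset version once and passes it to `Squeeze` and `Unsqueeze`, since opset 13 turned their `axes` attribute into a runtime input. A small self-contained sketch of where that version lives (the Relu graph below is purely illustrative):

import onnx
from onnx import helper, TensorProto

node = helper.make_node("Relu", inputs=["x"], outputs=["y"])
graph = helper.make_graph(
    [node],
    "demo",
    [helper.make_tensor_value_info("x", TensorProto.FLOAT, [1, 3])],
    [helper.make_tensor_value_info("y", TensorProto.FLOAT, [1, 3])],
)
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])

# This is the value convert_operations reads and threads through to the ops.
print(model.opset_import[0].version)  # 13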

onnx2pytorch/operations/__init__.py

Lines changed: 6 additions & 0 deletions
@@ -1,5 +1,6 @@
 from .add import Add
 from .batchnorm import BatchNormUnsafe
+from .expand import Expand
 from .instancenorm import InstanceNormUnsafe
 from .cast import Cast
 from .constant import ConstantOfShape
@@ -8,16 +9,19 @@
 from .onehot import OneHot
 from .pad import Pad
 from .pooling import GlobalAveragePool
+from .range import Range
 from .reshape import Reshape
 from .shape import Shape
 from .slice import Slice
 from .split import Split
 from .squeeze import Squeeze
 from .resize import Resize, Upsample
+from .unsqueeze import Unsqueeze

 __all__ = [
     "Add",
     "BatchNormUnsafe",
+    "Expand",
     "InstanceNormUnsafe",
     "Cast",
     "ConstantOfShape",
@@ -26,11 +30,13 @@
     "OneHot",
     "Pad",
     "GlobalAveragePool",
+    "Range",
     "Reshape",
     "Shape",
     "Slice",
     "Split",
     "Squeeze",
     "Resize",
+    "Unsqueeze",
     "Upsample",
 ]

onnx2pytorch/operations/expand.py

Lines changed: 8 additions & 0 deletions
@@ -0,0 +1,8 @@
+import torch
+from torch import nn
+
+
+class Expand(nn.Module):
+    def forward(self, input: torch.Tensor, shape: torch.Tensor):
+        # return input.expand(torch.Size(shape))
+        return input * torch.ones(torch.Size(shape), dtype=input.dtype)
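
Multiplying by a ones tensor applies full numpy-style broadcasting between the input and the target shape, which is what ONNX Expand specifies; plain `input.expand` (the commented-out line) rejects target dims smaller than the input's. A short illustration of the difference:

import torch

x = torch.tensor([[1.0, 2.0, 3.0]])      # shape (1, 3)

# ONNX Expand broadcasts input and target shape against each other,
# so a target of (2, 1) still yields (2, 3); x.expand(2, 1) would raise.
out = x * torch.ones((2, 1), dtype=x.dtype)
print(out.shape)  # torch.Size([2, 3])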

onnx2pytorch/operations/range.py

Lines changed: 7 additions & 0 deletions
@@ -0,0 +1,7 @@
+import torch
+from torch import nn
+
+
+class Range(nn.Module):
+    def forward(self, start: torch.Tensor, limit: torch.Tensor, delta: torch.Tensor):
+        return torch.arange(start=start, end=limit, step=delta)
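
ONNX Range maps directly onto `torch.arange`. In a converted graph the three inputs arrive as scalar tensors; plain Python ints are used below just to keep the sketch short:

from onnx2pytorch.operations import Range

op = Range()
print(op(2, 10, 3))  # tensor([2, 5, 8])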

onnx2pytorch/operations/slice.py

Lines changed: 2 additions & 0 deletions
@@ -28,6 +28,8 @@ def forward(
         if steps is None:
             steps = tuple(1 for _ in axes)

+        axes = [input.ndim + x if x < 0 else x for x in axes]
+
         selection = [slice(None) for _ in range(max(axes) + 1)]
         for i, axis in enumerate(axes):
             selection[axis] = slice(starts[i], ends[i], steps[i])
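
The added line maps negative ONNX axes to their positive equivalents so they can index the `selection` list. A tiny standalone sketch of the normalization on an assumed 4-D input:

import torch

x = torch.randn(2, 3, 4, 5)
axes = [-1, 1]

# Negative axes count from the last dimension, so shift them by input.ndim.
axes = [x.ndim + a if a < 0 else a for a in axes]
print(axes)  # [3, 1]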

onnx2pytorch/operations/squeeze.py

Lines changed: 12 additions & 6 deletions
@@ -5,16 +5,22 @@


 class Squeeze(Operator):
-    def __init__(self, dim=None):
+    def __init__(self, opset_version, dim=None):
+        self.opset_version = opset_version
         self.dim = dim
         super().__init__()

-    def forward(self, input):
-        if self.dim is None:
+    def forward(self, input: torch.Tensor, axes: torch.Tensor = None):
+        if self.opset_version < 13:
+            dims = self.dim
+        else:
+            dims = axes
+
+        if dims is None:
             return torch.squeeze(input)
-        elif isinstance(self.dim, int):
-            return torch.squeeze(input, dim=self.dim)
+        elif isinstance(dims, int):
+            return torch.squeeze(input, dim=dims)
         else:
-            for dim in sorted(self.dim, reverse=True):
+            for dim in sorted(dims, reverse=True):
                 input = torch.squeeze(input, dim=dim)
             return input
onnx2pytorch/operations/unsqueeze.py

Lines changed: 25 additions & 0 deletions
@@ -0,0 +1,25 @@
+import torch
+from torch import nn
+
+from onnx2pytorch.operations.base import Operator
+
+
+class Unsqueeze(Operator):
+    def __init__(self, opset_version, dim=None):
+        self.opset_version = opset_version
+        self.dim = dim
+        super().__init__()
+
+    def forward(self, data: torch.Tensor, axes: torch.Tensor = None):
+        if self.opset_version < 13:
+            dims = self.dim
+        else:
+            dims = torch.Size(axes)
+        if dims is None:
+            raise ValueError("Unsqueeze expects axes")
+        elif isinstance(dims, int):
+            return torch.unsqueeze(data, dim=dims)
+        else:
+            for dim in sorted(dims, reverse=True):
+                data = torch.unsqueeze(data, dim=dim)
+            return data
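
Both `Squeeze` and `Unsqueeze` now carry the opset version because opset 13 moved `axes` from a node attribute to a tensor input: below 13 the converter keeps passing `dim` at construction time, from 13 on the axes tensor is handed to `forward`. A minimal sketch of the pre-13 path (the shapes and dims are illustrative):

import torch
from onnx2pytorch.operations import Squeeze, Unsqueeze

x = torch.randn(1, 3, 1, 4)

# opset < 13: axes come from the node attribute, fixed at construction time.
squeeze = Squeeze(opset_version=11, dim=[0, 2])
print(squeeze(x).shape)                  # torch.Size([3, 4])

unsqueeze = Unsqueeze(opset_version=11, dim=[0])
print(unsqueeze(squeeze(x)).shape)       # torch.Size([1, 3, 4])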

setup.py

Lines changed: 1 addition & 1 deletion
@@ -12,7 +12,7 @@

 setup(
     name="onnx2pytorch",
-    version="0.2.0",
+    version="0.3.0",
     description="Library to transform onnx model to pytorch.",
     license="apache-2.0",
     author="Talmaj Marinc",
