Skip to content

Commit 1301703

Browse files
Avoid calling torch.clamp when neither min nor max is given
1 parent 8750c19 commit 1301703

File tree

6 files changed

+62
-1
lines changed

6 files changed

+62
-1
lines changed

download_fixtures.sh

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,12 @@ if [[ $1 == "--all" ]]; then
4242
echo Downloading efficientnet-lite4
4343
curl -LJo efficientnet-lite4.onnx https://github.com/onnx/models/blob/master/vision/classification/efficientnet-lite4/model/efficientnet-lite4-11.onnx\?raw\=true
4444
fi
45+
46+
if [[ ! -f mobilenetv2-7.onnx ]]; then
47+
echo Downloading mobilenetv2-7
48+
curl -LJo mobilenetv2-7.onnx https://github.com/onnx/models/raw/master/vision/classification/mobilenet/model/mobilenetv2-7.onnx\?raw\=true
49+
fi
50+
4551
fi
4652

4753
echo "All models downloaded."

onnx2pytorch/convert/attribute.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,10 @@ def extract_attributes(node):
124124
kwargs["largest"] = extract_attr_values(attr)
125125
elif attr.name == "layout":
126126
kwargs["layout"] = extract_attr_values(attr)
127+
elif attr.name == "max":
128+
kwargs["max"] = extract_attr_values(attr)
129+
elif attr.name == "min":
130+
kwargs["min"] = extract_attr_values(attr)
127131
elif attr.name == "mode":
128132
kwargs["mode"] = extract_attr_values(attr)
129133
elif attr.name == "momentum":

onnx2pytorch/convert/operations.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ def convert_operations(onnx_graph, opset_version, batch_dim=0, enable_pruning=Tr
9393
elif node.op_type == "Ceil":
9494
op = OperatorWrapper(torch.ceil)
9595
elif node.op_type == "Clip":
96-
op = OperatorWrapper(torch.clamp)
96+
op = Clip(**extract_attributes(node))
9797
elif node.op_type == "Concat":
9898
op = partial(torch.cat, **extract_attributes(node))
9999
elif node.op_type == "Constant":

onnx2pytorch/operations/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from .batchnorm import BatchNormWrapper
33
from .bitshift import BitShift
44
from .cast import Cast
5+
from .clip import Clip
56
from .constant import Constant
67
from .constantofshape import ConstantOfShape
78
from .div import Div
@@ -41,6 +42,7 @@
4142
"BatchNormWrapper",
4243
"BitShift",
4344
"Cast",
45+
"Clip",
4446
"Constant",
4547
"ConstantOfShape",
4648
"Div",

onnx2pytorch/operations/clip.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
import torch
2+
from torch import nn
3+
4+
5+
class Clip(nn.Module):
    """ONNX Clip: clamp the input into ``[min, max]``.

    Bounds may come from node attributes (opset <= 6, stored at construction)
    or from runtime inputs (opset >= 11, passed to ``forward``); runtime
    values take precedence. When no bound is available at all, the input is
    returned unchanged, since ``torch.clamp`` requires at least one bound.
    """

    def __init__(self, min=None, max=None):
        super().__init__()
        self.min = min
        self.max = max

    def forward(self, input, min=None, max=None):
        # Fall back to the stored attribute only when no runtime bound is given.
        lower = self.min if min is None else min
        upper = self.max if max is None else max
        # torch.clamp raises if both bounds are None, so short-circuit to identity.
        if lower is None and upper is None:
            return input
        return torch.clamp(input, min=lower, max=upper)
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
import numpy as np
2+
import pytest
3+
import torch
4+
5+
from onnx2pytorch.operations.clip import Clip
6+
7+
8+
def test_clip():
    """Clip must agree with np.clip for every combination of bounds,
    including the no-bound case, where it must return the input unchanged."""
    x_np = np.random.randn(3, 4, 5).astype(np.float32)
    x = torch.from_numpy(x_np)

    # Both bounds set.
    op = Clip(min=-1, max=1)
    exp_y_np = np.clip(x_np, -1, 1)
    exp_y = torch.from_numpy(exp_y_np)
    assert torch.equal(op(x), exp_y)

    # Lower bound only.
    op = Clip(min=0)
    exp_y_np = np.clip(x_np, 0, np.inf)
    exp_y = torch.from_numpy(exp_y_np)
    assert torch.equal(op(x), exp_y)

    # Upper bound only. Use -np.inf rather than np.NINF: the NINF alias was
    # removed in NumPy 2.0, so NINF raises AttributeError on modern numpy.
    op = Clip(max=0)
    exp_y_np = np.clip(x_np, -np.inf, 0)
    exp_y = torch.from_numpy(exp_y_np)
    assert torch.equal(op(x), exp_y)

    # No bounds: identity pass-through.
    op = Clip()
    exp_y_np = np.clip(x_np, -np.inf, np.inf)
    exp_y = torch.from_numpy(exp_y_np)
    assert torch.equal(op(x), exp_y)

0 commit comments

Comments
 (0)