Skip to content

Commit 40a0ab3

Browse files
authored
ReduceL2:ReduceL2_forward (#55)
* fixes to work with verigram models * ReduceL2 op is added to __init__ file * revert Clip operation
1 parent 4aeab14 commit 40a0ab3

File tree

3 files changed

+35
-0
lines changed

3 files changed

+35
-0
lines changed

onnx2pytorch/convert/operations.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -214,6 +214,8 @@ def convert_operations(onnx_graph, opset_version, batch_dim=0, enable_pruning=Tr
214214
op = partial(torch.prod, **kwargs)
215215
elif node.op_type == "ReduceSum":
216216
op = ReduceSum(opset_version=opset_version, **extract_attributes(node))
217+
elif node.op_type == "ReduceL2":
218+
op = ReduceL2(opset_version=opset_version, **extract_attributes(node))
217219
elif node.op_type == "Relu":
218220
op = nn.ReLU(inplace=True)
219221
elif node.op_type == "Reshape":

onnx2pytorch/operations/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from .range import Range
2323
from .reducemax import ReduceMax
2424
from .reducesum import ReduceSum
25+
from .reducel2 import ReduceL2
2526
from .reshape import Reshape
2627
from .resize import Resize, Upsample
2728
from .scatter import Scatter
@@ -63,6 +64,7 @@
6364
"Range",
6465
"ReduceMax",
6566
"ReduceSum",
67+
"ReduceL2",
6668
"Reshape",
6769
"Resize",
6870
"Scatter",
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
import torch
2+
from torch import nn
3+
4+
class ReduceL2(nn.Module):
5+
def __init__(
6+
self, opset_version, dim=None, keepdim=True, noop_with_empty_axes=False
7+
):
8+
self.opset_version = opset_version
9+
self.dim = dim
10+
self.keepdim = keepdim
11+
self.noop_with_empty_axes = noop_with_empty_axes
12+
super().__init__()
13+
14+
def forward(self, data: torch.Tensor, axes: torch.Tensor = None):
15+
if self.opset_version < 13:
16+
dims = self.dim
17+
else:
18+
dims = axes
19+
if dims is None:
20+
if self.noop_with_empty_axes:
21+
return data
22+
else:
23+
dims = tuple(range(data.ndim))
24+
25+
if isinstance(dims, int):
26+
dim = dims
27+
else:
28+
dim=tuple(list(dims))
29+
30+
ret = torch.sqrt(torch.sum(torch.square(data), dim=dim, keepdim=self.keepdim))
31+
return ret

0 commit comments

Comments
 (0)