Skip to content

Commit 6961176

Browse files
reducemax
1 parent 31825a7 commit 6961176

File tree

3 files changed

+18
-3
lines changed

3 files changed

+18
-3
lines changed

onnx2pytorch/convert/operations.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -199,9 +199,7 @@ def convert_operations(onnx_graph, opset_version, batch_dim=0, enable_pruning=Tr
199199
elif node.op_type == "Reciprocal":
200200
op = OperatorWrapper(torch.reciprocal)
201201
elif node.op_type == "ReduceMax":
202-
kwargs = dict(keepdim=True)
203-
kwargs.update(extract_attributes(node))
204-
op = partial(torch.max, **kwargs)
202+
op = ReduceMax(**extract_attributes(node))
205203
elif node.op_type == "ReduceMean":
206204
kwargs = dict(keepdim=True)
207205
kwargs.update(extract_attributes(node))

onnx2pytorch/operations/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from .pad import Pad
2222
from .prelu import PRelu
2323
from .range import Range
24+
from .reducemax import ReduceMax
2425
from .reducesum import ReduceSum
2526
from .reshape import Reshape
2627
from .resize import Resize, Upsample
@@ -61,6 +62,7 @@
6162
"Pad",
6263
"PRelu",
6364
"Range",
65+
"ReduceMax",
6466
"ReduceSum",
6567
"Reshape",
6668
"Resize",
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
import torch
2+
from torch import nn
3+
4+
5+
class ReduceMax(nn.Module):
6+
def __init__(self, dim=None, keepdim=True):
7+
self.dim = dim
8+
self.keepdim = keepdim
9+
super().__init__()
10+
11+
def forward(self, data: torch.Tensor):
12+
dim = self.dim
13+
if dim is None:
14+
dim = tuple(range(data.ndim))
15+
return torch.amax(data, dim=dim, keepdim=self.keepdim)

0 commit comments

Comments
 (0)