Skip to content

Commit 4aeab14

Browse files
reducemax
1 parent da17e48 commit 4aeab14

File tree

3 files changed

+18
-3
lines changed

3 files changed

+18
-3
lines changed

onnx2pytorch/convert/operations.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -199,9 +199,7 @@ def convert_operations(onnx_graph, opset_version, batch_dim=0, enable_pruning=Tr
199199
elif node.op_type == "Reciprocal":
200200
op = OperatorWrapper(torch.reciprocal)
201201
elif node.op_type == "ReduceMax":
202-
kwargs = dict(keepdim=True)
203-
kwargs.update(extract_attributes(node))
204-
op = partial(torch.max, **kwargs)
202+
op = ReduceMax(**extract_attributes(node))
205203
elif node.op_type == "ReduceMean":
206204
kwargs = dict(keepdim=True)
207205
kwargs.update(extract_attributes(node))

onnx2pytorch/operations/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from .pad import Pad
2121
from .prelu import PRelu
2222
from .range import Range
23+
from .reducemax import ReduceMax
2324
from .reducesum import ReduceSum
2425
from .reshape import Reshape
2526
from .resize import Resize, Upsample
@@ -60,6 +61,7 @@
6061
"Pad",
6162
"PRelu",
6263
"Range",
64+
"ReduceMax",
6365
"ReduceSum",
6466
"Reshape",
6567
"Resize",
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
import torch
2+
from torch import nn
3+
4+
5+
class ReduceMax(nn.Module):
6+
def __init__(self, dim=None, keepdim=True):
7+
self.dim = dim
8+
self.keepdim = keepdim
9+
super().__init__()
10+
11+
def forward(self, data: torch.Tensor):
12+
dim = self.dim
13+
if dim is None:
14+
dim = tuple(range(data.ndim))
15+
return torch.amax(data, dim=dim, keepdim=self.keepdim)

0 commit comments

Comments
 (0)