From 343ac8f78291b526cb646fcfd3c9f956a6d40474 Mon Sep 17 00:00:00 2001 From: Ryan Young Date: Wed, 19 Feb 2025 16:31:16 -0800 Subject: [PATCH] kwargs reduceprod --- onnx2pytorch/convert/operations.py | 37 ++++++++++++++++++------------ 1 file changed, 22 insertions(+), 15 deletions(-) diff --git a/onnx2pytorch/convert/operations.py b/onnx2pytorch/convert/operations.py index c8c56af..f472dcc 100644 --- a/onnx2pytorch/convert/operations.py +++ b/onnx2pytorch/convert/operations.py @@ -4,27 +4,21 @@ import numpy as np import onnx import torch +from onnx import numpy_helper from torch import nn from torch.nn import functional as F -from onnx import numpy_helper from torch.nn.modules.linear import Identity from onnx2pytorch.convert.attribute import extract_attributes -from onnx2pytorch.convert.layer import ( - convert_layer, - convert_linear_layer, - convert_batch_norm_layer, - convert_instance_norm_layer, - convert_lstm_layer, -) +from onnx2pytorch.convert.layer import (convert_batch_norm_layer, + convert_instance_norm_layer, + convert_layer, convert_linear_layer, + convert_lstm_layer) from onnx2pytorch.operations import * +from onnx2pytorch.operations import Hardsigmoid, Resize, Upsample from onnx2pytorch.operations.base import OperatorWrapper -from onnx2pytorch.operations import Resize, Upsample, Hardsigmoid -from onnx2pytorch.utils import ( - get_inputs_names, - get_outputs_names, - value_wrapper, -) +from onnx2pytorch.utils import (get_inputs_names, get_outputs_names, + value_wrapper) def get_buffer_name(param_name): @@ -211,7 +205,20 @@ def convert_operations(onnx_graph, opset_version, batch_dim=0, enable_pruning=Tr elif node.op_type == "ReduceProd": kwargs = dict(keepdim=True) kwargs.update(extract_attributes(node)) - op = partial(torch.prod, **kwargs) + def reduceprod_wrapper(x, **kw): + # When no reduction axis is specified, + # we must simulate "keepdim=True" by outputting a tensor of the same rank. + if 'axes' not in kw and 'dim' not in kw: + original_dim = x.dim() + # Compute the product over all elements + out = torch.prod(x) + # Reshape to have as many dimensions as the original input (all ones) + out = out.view([1] * original_dim) + else: + out = torch.prod(x, **kw) + return out + # Use the kwargs when binding reduceprod_wrapper. + op = lambda x: reduceprod_wrapper(x, **kwargs) elif node.op_type == "ReduceSum": op = ReduceSum(opset_version=opset_version, **extract_attributes(node)) elif node.op_type == "ReduceL2":