88from torch .jit import TracerWarning
99from torch .nn .modules .conv import _ConvNd
1010from torch .nn .modules .batchnorm import _BatchNorm
11+ from torch .nn .modules .instancenorm import _InstanceNorm
1112from torch .nn .modules .linear import Identity
1213
1314from onnx2pytorch .operations import Split
15+ from onnx2pytorch .convert .debug import debug_model_conversion
1416from onnx2pytorch .convert .operations import convert_operations
17+ from onnx2pytorch .utils import get_inputs_names
1518
1619
1720class InitParameters (dict ):
@@ -30,7 +33,9 @@ def get(self, item, default):
3033
3134
3235class ConvertModel (nn .Module ):
33- def __init__ (self , onnx_model : onnx .ModelProto , batch_dim = 0 , experimental = False ):
36+ def __init__ (
37+ self , onnx_model : onnx .ModelProto , batch_dim = 0 , experimental = False , debug = False
38+ ):
3439 """
3540 Convert onnx model to pytorch.
3641
@@ -53,6 +58,7 @@ def __init__(self, onnx_model: onnx.ModelProto, batch_dim=0, experimental=False)
5358 self .onnx_model = onnx_model
5459 self .batch_dim = batch_dim
5560 self .experimental = experimental
61+ self .debug = debug
5662 self .mapping = {}
5763 for op_id , op_name , op in convert_operations (onnx_model , batch_dim ):
5864 setattr (self , op_name , op )
@@ -62,6 +68,8 @@ def __init__(self, onnx_model: onnx.ModelProto, batch_dim=0, experimental=False)
6268 {tensor .name : tensor for tensor in self .onnx_model .graph .initializer }
6369 )
6470
71+ self .input_names = get_inputs_names (onnx_model )
72+
6573 if experimental :
6674 warnings .warn (
6775 "Using experimental implementation that allows 'batch_size > 1'."
@@ -74,8 +82,7 @@ def forward(self, *input):
7482 "Input with larger batch size than 1 not supported yet."
7583 )
7684 # TODO figure out how to store only necessary activations.
77- input_names = [x .name for x in self .onnx_model .graph .input ]
78- activations = dict (zip (input_names , input ))
85+ activations = dict (zip (self .input_names , input ))
7986
8087 for node in self .onnx_model .graph .node :
8188 # Identifying the layer ids and names
@@ -93,7 +100,11 @@ def forward(self, *input):
93100 # if first layer choose input as in_activations
94101 # if not in_op_names and len(node.input) == 1:
95102 # in_activations = input
96- if isinstance (op , (nn .Linear , _ConvNd , _BatchNorm )):
103+ layer_types = (nn .Linear , _ConvNd , _BatchNorm , _InstanceNorm )
104+ if isinstance (op , layer_types ) or (
105+ isinstance (op , nn .Sequential )
106+ and any (isinstance (x , layer_types ) for x in op .modules ())
107+ ):
97108 in_activations = [
98109 activations [in_op_id ]
99110 for in_op_id in node .input
@@ -122,6 +133,15 @@ def forward(self, *input):
122133 else :
123134 activations [out_op_id ] = op (* in_activations )
124135
136+ if self .debug :
137+ # compare if the activations of pytorch are the same as from onnxruntime
138+ debug_model_conversion (
139+ self .onnx_model ,
140+ [activations [x ] for x in self .input_names ],
141+ activations [out_op_id ],
142+ node ,
143+ )
144+
125145 # collect all outputs
126146 outputs = [activations [x .name ] for x in self .onnx_model .graph .output ]
127147 if len (outputs ) == 1 :
0 commit comments