2020
2121def convert_operations (onnx_model , batch_dim = 0 ):
2222 """
23- Convert onnx model operations. Yields onnx's operator_id, opeartor_name and
23+ Convert onnx model operations. Yields onnx's operator_id, operator_name and
2424 converted pytorch operator.
2525
2626 Parameters
@@ -35,6 +35,7 @@ def convert_operations(onnx_model, batch_dim=0):
3535 iterator: (op_id, op_name, op)
3636 """
3737 weights = {tensor .name : tensor for tensor in onnx_model .graph .initializer }
38+ opset_version = onnx_model .opset_import [0 ].version
3839
3940 for i , node in enumerate (onnx_model .graph .node ):
4041 # extract only useful inputs
@@ -46,6 +47,8 @@ def convert_operations(onnx_model, batch_dim=0):
4647 op = nn .ReLU (inplace = True )
4748 elif node .op_type == "LeakyRelu" :
4849 op = nn .LeakyReLU (** extract_attributes (node ), inplace = True )
50+ elif node .op_type == "Elu" :
51+ op = nn .ELU (** extract_attributes (node ), inplace = True )
4952 elif node .op_type == "Sigmoid" :
5053 op = nn .Sigmoid ()
5154 elif node .op_type == "MaxPool" :
@@ -73,14 +76,18 @@ def convert_operations(onnx_model, batch_dim=0):
7376 op = Reshape (shape )
7477 elif node .op_type == "Shape" :
7578 op = Shape ()
79+ elif node .op_type == "Expand" :
80+ op = Expand ()
7681 elif node .op_type == "Gather" :
7782 op = Gather (** extract_attributes (node ))
7883 elif node .op_type == "Squeeze" :
79- op = Squeeze (** extract_attributes (node ))
84+ op = Squeeze (opset_version = opset_version , ** extract_attributes (node ))
8085 elif node .op_type == "Unsqueeze" :
81- op = partial ( torch . unsqueeze , ** extract_attributes (node ))
86+ op = Unsqueeze ( opset_version = opset_version , ** extract_attributes (node ))
8287 elif node .op_type == "ConstantOfShape" :
8388 op = ConstantOfShape (** extract_attributes (node ))
89+ elif node .op_type == "Range" :
90+ op = Range ()
8491 elif node .op_type == "Slice" :
8592 op = Slice (** extract_attributes (node ))
8693 elif node .op_type == "Cast" :
@@ -161,6 +168,14 @@ def convert_operations(onnx_model, batch_dim=0):
161168 op = OperatorWrapper (torch .log )
162169 elif node .op_type == "Exp" :
163170 op = OperatorWrapper (torch .exp )
171+ elif node .op_type == "Reciprocal" :
172+ op = OperatorWrapper (torch .reciprocal )
173+ elif node .op_type == "And" :
174+ op = OperatorWrapper (torch .logical_and )
175+ elif node .op_type == "Or" :
176+ op = OperatorWrapper (torch .logical_or )
177+ elif node .op_type == "Not" :
178+ op = OperatorWrapper (torch .logical_not )
164179 else :
165180 op = getattr (torch , node .op_type .lower (), None )
166181 if op is None :
0 commit comments