@@ -41,6 +41,8 @@ def extract_attr_values(attr):
4141 value = numpy_helper .to_array (attr .t )
4242 elif attr .type == AttributeType ["STRING" ]:
4343 value = attr .s .decode ()
44+ elif attr .type == AttributeType ["GRAPH" ]:
45+ value = attr .g
4446 else :
4547 raise NotImplementedError (
4648 "Extraction of attribute type {} not implemented." .format (attr .type )
@@ -52,21 +54,27 @@ def extract_attributes(node):
5254 """Extract onnx attributes. Map onnx feature naming to pytorch."""
5355 kwargs = {}
5456 for attr in node .attribute :
55- if attr .name == "dilations" :
56- kwargs ["dilation" ] = extract_attr_values (attr )
57- elif attr .name == "group" :
58- kwargs ["groups" ] = extract_attr_values (attr )
59- elif attr .name == "kernel_shape" :
60- kwargs ["kernel_size" ] = extract_attr_values (attr )
61- elif attr .name == "pads" :
62- params = extract_attr_values (attr )
63- if node .op_type == "Pad" :
64- kwargs ["padding" ] = extract_padding_params (params )
57+ if attr .name == "activation_alpha" :
58+ kwargs ["activation_alpha" ] = extract_attr_values (attr )
59+ elif attr .name == "activation_beta" :
60+ kwargs ["activation_beta" ] = extract_attr_values (attr )
61+ elif attr .name == "activations" :
62+ kwargs ["activations" ] = extract_attr_values (attr )
63+ elif attr .name == "alpha" :
64+ if node .op_type == "LeakyRelu" :
65+ kwargs ["negative_slope" ] = extract_attr_values (attr )
66+ elif node .op_type in ("Elu" , "ThresholdedRelu" ):
67+ kwargs ["alpha" ] = extract_attr_values (attr )
6568 else :
66- # Works for Conv, MaxPooling and other layers from convert_layer func
67- kwargs ["padding" ] = extract_padding_params_for_conv_layer (params )
68- elif attr .name == "strides" :
69- kwargs ["stride" ] = extract_attr_values (attr )
69+ kwargs ["weight_multiplier" ] = extract_attr_values (attr )
70+ elif attr .name == "auto_pad" :
71+ value = extract_attr_values (attr )
72+ if value == "NOTSET" :
73+ pass
74+ else :
75+ raise NotImplementedError (
76+ "auto_pad={} functionality not implemented." .format (value )
77+ )
7078 elif attr .name == "axis" and node .op_type == "Flatten" :
7179 kwargs ["start_dim" ] = extract_attr_values (attr )
7280 elif attr .name == "axis" or attr .name == "axes" :
@@ -75,62 +83,103 @@ def extract_attributes(node):
7583 kwargs ["dim" ] = v [0 ]
7684 else :
7785 kwargs ["dim" ] = v
78- elif attr .name == "keepdims" :
79- kwargs ["keepdim" ] = bool (extract_attr_values (attr ))
86+ elif attr .name == "beta" :
87+ kwargs ["bias_multiplier" ] = extract_attr_values (attr )
88+ elif attr .name == "body" :
89+ kwargs ["body" ] = extract_attr_values (attr )
90+ elif attr .name == "ceil_mode" :
91+ kwargs ["ceil_mode" ] = bool (extract_attr_values (attr ))
92+ elif attr .name == "center_point_box" :
93+ kwargs ["center_point_box" ] = extract_attr_values (attr )
94+ elif attr .name == "clip" :
95+ kwargs ["clip" ] = extract_attr_values (attr )
96+ elif attr .name == "coordinate_transformation_mode" :
97+ arg = extract_attr_values (attr )
98+ if arg == "align_corners" :
99+ kwargs ["align_corners" ] = True
100+ else :
101+ warnings .warn (
102+ "Pytorch's interpolate uses no coordinate_transformation_mode={}. "
103+ "Result might differ." .format (arg )
104+ )
105+ elif attr .name == "dilations" :
106+ kwargs ["dilation" ] = extract_attr_values (attr )
107+ elif attr .name == "direction" :
108+ kwargs ["direction" ] = extract_attr_values (attr )
109+ elif attr .name == "ends" :
110+ kwargs ["ends" ] = extract_attr_values (attr )
80111 elif attr .name == "epsilon" :
81112 kwargs ["eps" ] = extract_attr_values (attr )
113+ elif attr .name == "group" :
114+ kwargs ["groups" ] = extract_attr_values (attr )
115+ elif attr .name == "hidden_size" :
116+ kwargs ["hidden_size" ] = extract_attr_values (attr )
117+ elif attr .name == "input_forget" :
118+ kwargs ["input_forget" ] = extract_attr_values (attr )
119+ elif attr .name == "keepdims" :
120+ kwargs ["keepdim" ] = bool (extract_attr_values (attr ))
121+ elif attr .name == "kernel_shape" :
122+ kwargs ["kernel_size" ] = extract_attr_values (attr )
123+ elif attr .name == "largest" :
124+ kwargs ["largest" ] = extract_attr_values (attr )
125+ elif attr .name == "layout" :
126+ kwargs ["layout" ] = extract_attr_values (attr )
127+ elif attr .name == "mode" :
128+ kwargs ["mode" ] = extract_attr_values (attr )
82129 elif attr .name == "momentum" :
83130 kwargs ["momentum" ] = extract_attr_values (attr )
84- elif attr .name == "ceil_mode" :
85- kwargs ["ceil_mode" ] = bool (extract_attr_values (attr ))
86- elif attr .name == "value" :
87- kwargs ["constant" ] = extract_attr_values (attr )
131+ elif attr .name == "noop_with_empty_axes" :
132+ kwargs ["noop_with_empty_axes" ] = extract_attr_values (attr )
133+ elif attr .name == "output_shape" and node .op_type == "ConvTranspose" :
134+ raise NotImplementedError (
135+ "ConvTranspose with dynamic padding not implemented."
136+ )
137+ elif attr .name == "pads" :
138+ params = extract_attr_values (attr )
139+ if node .op_type == "Pad" :
140+ kwargs ["padding" ] = extract_padding_params (params )
141+ else :
142+ # Works for Conv, MaxPooling and other layers from convert_layer func
143+ kwargs ["padding" ] = extract_padding_params_for_conv_layer (params )
88144 elif attr .name == "perm" :
89145 kwargs ["dims" ] = extract_attr_values (attr )
90- elif attr .name == "split" :
91- kwargs ["split_size_or_sections" ] = extract_attr_values (attr )
146+ elif attr .name == "repeats" :
147+ kwargs ["repeats" ] = extract_attr_values (attr )
148+ elif attr .name == "sorted" :
149+ kwargs ["sorted" ] = extract_attr_values (attr )
150+ elif attr .name == "sparse_value" :
151+ kwargs ["constant" ] = extract_attr_values (attr )
92152 elif attr .name == "spatial" :
93153 kwargs ["spatial" ] = extract_attr_values (attr ) # Batch norm parameter
154+ elif attr .name == "split" :
155+ kwargs ["split_size_or_sections" ] = extract_attr_values (attr )
156+ elif attr .name == "strides" :
157+ kwargs ["stride" ] = extract_attr_values (attr )
158+ elif attr .name == "starts" :
159+ kwargs ["starts" ] = extract_attr_values (attr )
94160 elif attr .name == "to" :
95161 kwargs ["dtype" ] = TENSOR_PROTO_MAPPING [extract_attr_values (attr )].lower ()
96- elif attr .name == "mode" :
97- kwargs ["mode" ] = extract_attr_values (attr )
98162 elif attr .name == "transB" :
99163 kwargs ["transpose_weight" ] = not extract_attr_values (attr )
100164 elif attr .name == "transA" :
101165 kwargs ["transpose_activation" ] = bool (extract_attr_values (attr ))
102- elif attr .name == "alpha" and node .op_type == "LeakyRelu" :
103- kwargs ["negative_slope" ] = extract_attr_values (attr )
104- elif attr .name == "alpha" and node .op_type == "Elu" :
105- kwargs ["alpha" ] = extract_attr_values (attr )
106- elif attr .name == "alpha" :
107- kwargs ["weight_multiplier" ] = extract_attr_values (attr )
108- elif attr .name == "beta" :
109- kwargs ["bias_multiplier" ] = extract_attr_values (attr )
110- elif attr .name == "starts" :
111- kwargs ["starts" ] = extract_attr_values (attr )
112- elif attr .name == "ends" :
113- kwargs ["ends" ] = extract_attr_values (attr )
114- elif attr .name == "coordinate_transformation_mode" :
115- arg = extract_attr_values (attr )
116- if arg == "align_corners" :
117- kwargs ["align_corners" ] = True
118- else :
119- warnings .warn (
120- "Pytorch's interpolate uses no coordinate_transformation_mode={}. "
121- "Result might differ." .format (arg )
122- )
166+ elif attr .name == "value" :
167+ kwargs ["constant" ] = extract_attr_values (attr )
168+ elif attr .name == "value_float" :
169+ kwargs ["constant" ] = extract_attr_values (attr )
170+ elif attr .name == "value_floats" :
171+ kwargs ["constant" ] = extract_attr_values (attr )
172+ elif attr .name == "value_int" :
173+ kwargs ["constant" ] = extract_attr_values (attr )
174+ elif attr .name == "value_ints" :
175+ kwargs ["constant" ] = extract_attr_values (attr )
176+ elif attr .name == "value_string" :
177+ kwargs ["constant" ] = extract_attr_values (attr )
178+ elif attr .name == "value_strings" :
179+ kwargs ["constant" ] = extract_attr_values (attr )
123180 elif node .op_type == "Resize" :
124181 # These parameters are not used, warn in Resize operator
125182 kwargs [attr .name ] = extract_attr_values (attr )
126- elif attr .name == "auto_pad" :
127- value = extract_attr_values (attr )
128- if value == "NOTSET" :
129- pass
130- else :
131- raise NotImplementedError (
132- "auto_pad={} functionality not implemented." .format (value )
133- )
134183 else :
135184 raise NotImplementedError (
136185 "Extraction of attribute {} not implemented." .format (attr .name )
0 commit comments