
Commit f233a73

Add support for gru_linear_before_reset=0
1 parent dad2054 commit f233a73

File tree

3 files changed: +327 -10 lines changed

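For context: the commit implements the ONNX GRU attribute linear_before_reset. Per the ONNX operator spec (and restated in the docstring added to onnx2pytorch/operations/gru.py below), the two modes differ only in how the candidate hidden state is formed; with g the hidden activation (tanh by default):

    \tilde{h}_t = g\left( X_t W_h^\top + (r_t \odot H_{t-1}) R_h^\top + R_{bh} + W_{bh} \right)    (linear_before_reset = 0)

    \tilde{h}_t = g\left( X_t W_h^\top + r_t \odot (H_{t-1} R_h^\top + R_{bh}) + W_{bh} \right)    (linear_before_reset = 1)

PyTorch's nn.GRU computes the second form, which is why mode 0 needs the custom cell added in this commit.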

onnx2pytorch/convert/layer.py

Lines changed: 3 additions & 6 deletions
@@ -412,7 +412,7 @@ def convert_gru_layer(node, weights):
         direction="forward",
         hidden_size=None,
         layout=0,
-        linear_before_reset=0,
+        linear_before_reset=0,  # ONNX spec default
     )
     dc.update(extract_attributes(node))
     if dc["activation_alpha"] is not None:
@@ -436,10 +436,7 @@ def convert_gru_layer(node, weights):
         raise NotImplementedError(
             "GRU not implemented for layout={}".format(dc["layout"])
         )
-    if dc["linear_before_reset"] != 0:
-        raise NotImplementedError(
-            "GRU linear_before_reset={}".format(dc["linear_before_reset"])
-        )
+    # linear_before_reset is now supported for both 0 and 1
 
     kwargs = {
         "input_size": W.shape[2],
@@ -570,5 +567,5 @@ def convert_gru_layer(node, weights):
     )
     getattr(gru_layer, "bias_hh_l0").data = Rb_rzn
 
-    layer = GRUWrapper(gru_layer)
+    layer = GRUWrapper(gru_layer, linear_before_reset=dc["linear_before_reset"])
     return layer

onnx2pytorch/operations/gru.py

Lines changed: 153 additions & 4 deletions
@@ -1,31 +1,180 @@
+import torch
 from torch import nn
 
 
 class GRUWrapper(nn.Module):
-    """Wraps a 1-layer nn.GRU to match the API of an ONNX GRU.
+    """Wraps a 1-layer nn.GRU or custom GRU to match the API of an ONNX GRU.
 
     It expects h_0 as a separate input rather than as a tuple,
     and returns h_n as a separate output rather than as a tuple.
+
+    Supports both linear_before_reset=0 and linear_before_reset=1.
     """
 
-    def __init__(self, gru_module: nn.GRU):
+    def __init__(self, gru_module, linear_before_reset=1):
         super().__init__()
         self.gru = gru_module
+        self.linear_before_reset = linear_before_reset
+
+        # For linear_before_reset=0, we need a custom forward pass
+        if linear_before_reset == 0 and isinstance(gru_module, nn.GRU):
+            # Extract parameters from the PyTorch GRU for the custom implementation
+            self.input_size = gru_module.input_size
+            self.hidden_size = gru_module.hidden_size
+            self.bidirectional = gru_module.bidirectional
 
     def forward(self, input, h_0=None):
         (seq_len, batch, input_size) = input.shape
         num_layers = 1
-        num_directions = self.gru.bidirectional + 1
+        num_directions = (
+            self.gru.bidirectional + 1 if hasattr(self.gru, "bidirectional") else 1
+        )
         hidden_size = self.gru.hidden_size
 
         if h_0 is None or h_0.numel() == 0:
             h_0 = None
 
-        output, h_n = self.gru(input, h_0)
+        if self.linear_before_reset == 1:
+            # Use the standard PyTorch GRU (linear_before_reset=1 is PyTorch's default)
+            output, h_n = self.gru(input, h_0)
+        else:
+            # Custom implementation for linear_before_reset=0
+            output, h_n = self._forward_linear_before_reset_0(input, h_0)
 
         # Y has shape (seq_length, num_directions, batch_size, hidden_size)
         Y = output.view(seq_len, batch, num_directions, hidden_size).transpose(1, 2)
         # Y_h has shape (num_directions, batch_size, hidden_size)
         Y_h = h_n.view(num_layers, num_directions, batch, hidden_size).squeeze(0)
 
         return Y, Y_h
+
+    def _forward_linear_before_reset_0(self, input, h_0):
+        """Custom GRU forward with linear_before_reset=0 (ONNX/TensorFlow default).
+
+        Key difference from linear_before_reset=1 (PyTorch's default):
+        - linear_before_reset=0: ht = g(Xt*(Wh^T) + (rt (.) Ht-1)*(Rh^T) + Rbh + Wbh)
+          The reset gate is applied to the hidden state BEFORE the matrix multiplication.
+        - linear_before_reset=1: ht = g(Xt*(Wh^T) + (rt (.) (Ht-1*(Rh^T) + Rbh)) + Wbh)
+          The reset gate is applied AFTER the matrix multiplication and bias addition.
+
+        Equations for linear_before_reset=0:
+            r_t = sigmoid(W_ir @ x_t + b_ir + W_hr @ h_{t-1} + b_hr)
+            z_t = sigmoid(W_iz @ x_t + b_iz + W_hz @ h_{t-1} + b_hz)
+            n_t = tanh(W_in @ x_t + b_in + (r_t * h_{t-1}) @ W_hn + b_hn)
+            h_t = (1 - z_t) * n_t + z_t * h_{t-1}
+        """
+        seq_len, batch, input_size = input.shape
+        hidden_size = self.hidden_size
+        num_directions = 2 if self.bidirectional else 1
+
+        if h_0 is None:
+            h_0 = torch.zeros(
+                num_directions,
+                batch,
+                hidden_size,
+                device=input.device,
+                dtype=input.dtype,
+            )
+
+        # Extract weights from the PyTorch GRU.
+        # PyTorch stores weights as: weight_ih_l0, weight_hh_l0, bias_ih_l0, bias_hh_l0
+        # For bidirectional: also weight_ih_l0_reverse, weight_hh_l0_reverse, etc.
+
+        def gru_cell_linear_before_reset_0(
+            x_t, h_prev, weight_ih, weight_hh, bias_ih, bias_hh
+        ):
+            """Single GRU cell with linear_before_reset=0."""
+            # Split weights for the gates: reset, update, new
+            # PyTorch order: [reset, update, new]
+            hidden_size = h_prev.size(1)
+
+            # Input-to-hidden weights
+            W_ir, W_iz, W_in = weight_ih.chunk(3, 0)
+            # Hidden-to-hidden weights
+            W_hr, W_hz, W_hn = weight_hh.chunk(3, 0)
+            # Input biases
+            b_ir, b_iz, b_in = (
+                bias_ih.chunk(3, 0) if bias_ih is not None else (None, None, None)
+            )
+            # Hidden biases
+            b_hr, b_hz, b_hn = (
+                bias_hh.chunk(3, 0) if bias_hh is not None else (None, None, None)
+            )
+
+            # Reset gate
+            r_t = torch.sigmoid(
+                x_t @ W_ir.t()
+                + (b_ir if b_ir is not None else 0)
+                + h_prev @ W_hr.t()
+                + (b_hr if b_hr is not None else 0)
+            )
+
+            # Update gate
+            z_t = torch.sigmoid(
+                x_t @ W_iz.t()
+                + (b_iz if b_iz is not None else 0)
+                + h_prev @ W_hz.t()
+                + (b_hz if b_hz is not None else 0)
+            )
+
+            # New gate (linear_before_reset=0 version)
+            # Note: the reset gate is applied to h_prev BEFORE the matrix multiplication
+            # ONNX spec: ht = g(Xt*(Wh^T) + (rt (.) Ht-1)*(Rh^T) + Rbh + Wbh)
+            n_t = torch.tanh(
+                x_t @ W_in.t()
+                + (b_in if b_in is not None else 0)
+                + (r_t * h_prev) @ W_hn.t()
+                + (b_hn if b_hn is not None else 0)
+            )
+
+            # Hidden state update
+            h_t = (1 - z_t) * n_t + z_t * h_prev
+
+            return h_t
+
+        # Process the sequence in the forward direction
+        outputs_forward = []
+        h_forward = h_0[0]
+
+        for t in range(seq_len):
+            h_forward = gru_cell_linear_before_reset_0(
+                input[t],
+                h_forward,
+                self.gru.weight_ih_l0,
+                self.gru.weight_hh_l0,
+                self.gru.bias_ih_l0 if self.gru.bias else None,
+                self.gru.bias_hh_l0 if self.gru.bias else None,
+            )
+            outputs_forward.append(h_forward)
+
+        if self.bidirectional:
+            # Process the backward direction
+            outputs_backward = []
+            h_backward = h_0[1]
+
+            for t in range(seq_len - 1, -1, -1):
+                h_backward = gru_cell_linear_before_reset_0(
+                    input[t],
+                    h_backward,
+                    self.gru.weight_ih_l0_reverse,
+                    self.gru.weight_hh_l0_reverse,
+                    self.gru.bias_ih_l0_reverse if self.gru.bias else None,
+                    self.gru.bias_hh_l0_reverse if self.gru.bias else None,
+                )
+                outputs_backward.append(h_backward)
+
+            outputs_backward.reverse()
+
+            # Concatenate forward and backward outputs
+            output = torch.stack(
+                [
+                    torch.cat([outputs_forward[t], outputs_backward[t]], dim=1)
+                    for t in range(seq_len)
+                ]
+            )
+            h_n = torch.stack([h_forward, h_backward])
+        else:
+            output = torch.stack(outputs_forward)
+            h_n = h_forward.unsqueeze(0)
+
+        return output, h_n
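
For orientation, a minimal usage sketch of the wrapper follows. It is not part of the commit; it assumes the import path implied by the file header above (onnx2pytorch.operations.gru) and nn.GRU's default bias=True.

    import torch
    from torch import nn

    # Assumed import path, taken from the file header above.
    from onnx2pytorch.operations.gru import GRUWrapper

    # A standard 1-layer, unidirectional PyTorch GRU supplies the weights;
    # the wrapper reuses them but evaluates the ONNX linear_before_reset=0 equations.
    gru = nn.GRU(input_size=3, hidden_size=5)
    wrapper = GRUWrapper(gru, linear_before_reset=0)

    x = torch.randn(7, 2, 3)    # (seq_len, batch, input_size)
    h_0 = torch.zeros(1, 2, 5)  # (num_directions, batch, hidden_size)

    with torch.no_grad():
        Y, Y_h = wrapper(x, h_0)

    print(Y.shape)    # torch.Size([7, 1, 2, 5]) -> (seq_len, num_directions, batch, hidden_size)
    print(Y_h.shape)  # torch.Size([1, 2, 5])    -> (num_directions, batch, hidden_size)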
Lines changed: 171 additions & 0 deletions
@@ -0,0 +1,171 @@
+import io
+import numpy as np
+import onnx
+import onnxruntime as ort
+import pytest
+import torch
+from onnx import helper, TensorProto
+
+from onnx2pytorch.convert import ConvertModel
+
+
+@pytest.mark.parametrize(
+    "bidirectional, input_size, hidden_size, seq_len, batch, test_seq_len, test_batch",
+    [
+        (False, 3, 5, 23, 4, 23, 4),
+        (False, 3, 5, 23, 4, 37, 4),
+        (False, 3, 5, 23, 4, 23, 7),
+        (True, 3, 5, 23, 4, 23, 4),
+        (True, 3, 5, 23, 4, 37, 4),
+        (True, 3, 5, 23, 4, 23, 7),
+    ],
+)
+def test_single_layer_gru(
+    bidirectional, input_size, hidden_size, seq_len, batch, test_seq_len, test_batch
+):
+    torch.manual_seed(42)
+    num_layers = 1
+    num_directions = bidirectional + 1
+    gru = torch.nn.GRU(
+        input_size=input_size,
+        hidden_size=hidden_size,
+        num_layers=num_layers,
+        bidirectional=bidirectional,
+    )
+    input = torch.randn(seq_len, batch, input_size)
+    h_0 = torch.randn(num_layers * num_directions, batch, hidden_size)
+    output, h_n = gru(input, h_0)
+    bitstream = io.BytesIO()
+    torch.onnx.export(
+        model=gru,
+        args=(input, h_0),
+        f=bitstream,
+        input_names=["input", "h_0"],
+        opset_version=11,
+        dynamo=False,  # Use the legacy exporter for GRU compatibility
+        dynamic_axes={
+            "input": {0: "seq_len", 1: "batch"},
+            "h_0": {1: "batch"},
+        },
+    )
+    bitstream_data = bitstream.getvalue()
+
+    onnx_gru = onnx.ModelProto.FromString(bitstream_data)
+    o2p_gru = ConvertModel(onnx_gru, experimental=True)
+    with torch.no_grad():
+        o2p_output, o2p_h_n = o2p_gru(input, h_0)
+    torch.testing.assert_close(o2p_output, output, rtol=1e-6, atol=1e-6)
+    torch.testing.assert_close(o2p_h_n, h_n, rtol=1e-6, atol=1e-6)
+
+    onnx_gru = onnx.ModelProto.FromString(bitstream_data)
+    o2p_gru = ConvertModel(onnx_gru, experimental=True)
+    with torch.no_grad():
+        o2p_output, o2p_h_n = o2p_gru(h_0=h_0, input=input)
+    torch.testing.assert_close(o2p_output, output, rtol=1e-6, atol=1e-6)
+    torch.testing.assert_close(o2p_h_n, h_n, rtol=1e-6, atol=1e-6)
+    with pytest.raises(KeyError):
+        o2p_output, o2p_h_n = o2p_gru(input=input)
+    with pytest.raises(Exception):
+        # Even though initial states are optional for nn.GRU(),
+        # we adhere to the onnxruntime convention that inputs are provided
+        # as either all positional or all keyword arguments.
+        o2p_output, o2p_h_n = o2p_gru(input, h_0=h_0)
+
+
+@pytest.mark.parametrize("linear_before_reset", [0, 1])
+@pytest.mark.parametrize("bidirectional", [False, True])
+def test_gru_linear_before_reset(linear_before_reset, bidirectional):
+    """Test GRU with both linear_before_reset=0 (ONNX/TensorFlow default) and =1 (PyTorch default)."""
+    torch.manual_seed(42)
+    np.random.seed(42)
+
+    input_size = 3
+    hidden_size = 4
+    seq_len = 5
+    batch = 2
+    num_directions = 2 if bidirectional else 1
+
+    # Create input and initial hidden state
+    X = np.random.randn(seq_len, batch, input_size).astype(np.float32)
+    initial_h = np.random.randn(num_directions, batch, hidden_size).astype(np.float32)
+
+    # Create random weights for the GRU
+    # W shape: [num_directions, 3*hidden_size, input_size]
+    W = np.random.randn(num_directions, 3 * hidden_size, input_size).astype(np.float32)
+    # R shape: [num_directions, 3*hidden_size, hidden_size]
+    R = np.random.randn(num_directions, 3 * hidden_size, hidden_size).astype(np.float32)
+    # B shape: [num_directions, 6*hidden_size] (Wb and Rb concatenated)
+    B = np.random.randn(num_directions, 6 * hidden_size).astype(np.float32)
+
+    # Create an ONNX graph with a GRU node
+    input_tensor = helper.make_tensor_value_info(
+        "X", TensorProto.FLOAT, [seq_len, batch, input_size]
+    )
+    initial_h_tensor = helper.make_tensor_value_info(
+        "initial_h", TensorProto.FLOAT, [num_directions, batch, hidden_size]
+    )
+    output_tensor = helper.make_tensor_value_info(
+        "Y", TensorProto.FLOAT, [seq_len, num_directions, batch, hidden_size]
+    )
+    output_h_tensor = helper.make_tensor_value_info(
+        "Y_h", TensorProto.FLOAT, [num_directions, batch, hidden_size]
+    )
+
+    W_initializer = helper.make_tensor(
+        "W", TensorProto.FLOAT, W.shape, W.flatten().tolist()
+    )
+    R_initializer = helper.make_tensor(
+        "R", TensorProto.FLOAT, R.shape, R.flatten().tolist()
+    )
+    B_initializer = helper.make_tensor(
+        "B", TensorProto.FLOAT, B.shape, B.flatten().tolist()
+    )
+
+    direction = "bidirectional" if bidirectional else "forward"
+    gru_node = helper.make_node(
+        "GRU",
+        inputs=["X", "W", "R", "B", "", "initial_h"],
+        outputs=["Y", "Y_h"],
+        hidden_size=hidden_size,
+        linear_before_reset=linear_before_reset,
+        direction=direction,
+    )
+
+    graph = helper.make_graph(
+        [gru_node],
+        "gru_test",
+        [input_tensor, initial_h_tensor],
+        [output_tensor, output_h_tensor],
+        [W_initializer, R_initializer, B_initializer],
+    )
+
+    model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 14)])
+    onnx.checker.check_model(model)
+
+    # Run with onnxruntime to get the expected output
+    ort_session = ort.InferenceSession(model.SerializeToString())
+    ort_inputs = {"X": X, "initial_h": initial_h}
+    ort_outputs = ort_session.run(None, ort_inputs)
+    expected_Y, expected_Y_h = ort_outputs
+
+    # Convert to PyTorch and run
+    o2p_gru = ConvertModel(model, experimental=True)
+    X_torch = torch.from_numpy(X)
+    initial_h_torch = torch.from_numpy(initial_h)
+
+    with torch.no_grad():
+        o2p_output, o2p_h_n = o2p_gru(X_torch, initial_h_torch)
+
+    # Compare with onnxruntime outputs
+    torch.testing.assert_close(
+        o2p_output,
+        torch.from_numpy(expected_Y),
+        rtol=1e-5,
+        atol=1e-5,
+    )
+    torch.testing.assert_close(
+        o2p_h_n,
+        torch.from_numpy(expected_Y_h),
+        rtol=1e-5,
+        atol=1e-5,
+    )
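
A note on running the new cases locally: the test file's path is not shown in this diff, so the snippet below (an assumption, not part of the commit) selects the parametrized test by keyword instead of by path.

    # Hypothetical local run from the repository root; picks up the new cases by name.
    import pytest

    raise SystemExit(pytest.main(["-q", "-k", "gru_linear_before_reset"]))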
