|
| 1 | +import io |
| 2 | +import numpy as np |
| 3 | +import onnx |
| 4 | +import onnxruntime as ort |
| 5 | +import pytest |
| 6 | +import torch |
| 7 | +from onnx import helper, TensorProto |
| 8 | + |
| 9 | +from onnx2pytorch.convert import ConvertModel |
| 10 | + |
| 11 | + |
@pytest.mark.parametrize(
    "bidirectional, input_size, hidden_size, seq_len, batch, test_seq_len, test_batch",
    [
        (False, 3, 5, 23, 4, 23, 4),
        (False, 3, 5, 23, 4, 37, 4),
        (False, 3, 5, 23, 4, 23, 7),
        (True, 3, 5, 23, 4, 23, 4),
        (True, 3, 5, 23, 4, 37, 4),
        (True, 3, 5, 23, 4, 23, 7),
    ],
)
def test_single_layer_gru(
    bidirectional, input_size, hidden_size, seq_len, batch, test_seq_len, test_batch
):
    """Round-trip a single-layer ``torch.nn.GRU`` through ONNX and ``ConvertModel``.

    The GRU is exported with example tensors shaped by ``seq_len`` and
    ``batch`` (both declared as dynamic axes), then the converted model is
    evaluated on tensors shaped by ``test_seq_len`` and ``test_batch`` and
    compared against the original module on those same tensors.  Using the
    test-time shapes is what makes the parametrized cases with a different
    sequence length or batch size actually exercise the dynamic axes.

    The converted model is called with all-positional and all-keyword
    arguments; a missing ``h_0`` keyword must raise ``KeyError`` and mixing
    positional with keyword arguments must raise.
    """
    torch.manual_seed(42)
    num_layers = 1
    num_directions = bidirectional + 1
    gru = torch.nn.GRU(
        input_size=input_size,
        hidden_size=hidden_size,
        num_layers=num_layers,
        bidirectional=bidirectional,
    )

    # Example tensors used only to trace/export the module.
    export_input = torch.randn(seq_len, batch, input_size)
    export_h_0 = torch.randn(num_layers * num_directions, batch, hidden_size)
    bitstream = io.BytesIO()
    torch.onnx.export(
        model=gru,
        args=(export_input, export_h_0),
        f=bitstream,
        input_names=["input", "h_0"],
        opset_version=11,
        dynamo=False,  # Use legacy exporter for GRU compatibility
        dynamic_axes={
            "input": {0: "seq_len", 1: "batch"},
            "h_0": {1: "batch"},
        },
    )
    bitstream_data = bitstream.getvalue()

    # Evaluation tensors: may differ from the export shapes along the
    # dynamic axes (sequence length and batch size).
    test_input = torch.randn(test_seq_len, test_batch, input_size)
    test_h_0 = torch.randn(num_layers * num_directions, test_batch, hidden_size)
    with torch.no_grad():
        output, h_n = gru(test_input, test_h_0)

    # All-positional call.
    onnx_gru = onnx.ModelProto.FromString(bitstream_data)
    o2p_gru = ConvertModel(onnx_gru, experimental=True)
    with torch.no_grad():
        o2p_output, o2p_h_n = o2p_gru(test_input, test_h_0)
    torch.testing.assert_close(o2p_output, output, rtol=1e-6, atol=1e-6)
    torch.testing.assert_close(o2p_h_n, h_n, rtol=1e-6, atol=1e-6)

    # All-keyword call, on a model freshly converted from the same bytes.
    onnx_gru = onnx.ModelProto.FromString(bitstream_data)
    o2p_gru = ConvertModel(onnx_gru, experimental=True)
    with torch.no_grad():
        o2p_output, o2p_h_n = o2p_gru(h_0=test_h_0, input=test_input)
    torch.testing.assert_close(o2p_output, output, rtol=1e-6, atol=1e-6)
    torch.testing.assert_close(o2p_h_n, h_n, rtol=1e-6, atol=1e-6)

    with pytest.raises(KeyError):
        o2p_gru(input=test_input)
    with pytest.raises(Exception):
        # Even though initial states are optional for nn.GRU(),
        # we adhere to onnxruntime convention that inputs are provided
        # as either all positional or all keyword arguments.
        o2p_gru(test_input, h_0=test_h_0)
| 73 | + |
| 74 | + |
@pytest.mark.parametrize("linear_before_reset", [0, 1])
@pytest.mark.parametrize("bidirectional", [False, True])
def test_gru_linear_before_reset(linear_before_reset, bidirectional):
    """Compare the converted GRU with onnxruntime for each ``linear_before_reset`` value.

    A one-node ONNX graph holding a GRU with random weights is built by hand,
    executed with onnxruntime to obtain reference outputs, converted with
    ``ConvertModel`` and executed again.  Both ``linear_before_reset=0`` (the
    ONNX/TensorFlow default) and ``linear_before_reset=1`` (the PyTorch
    default) are covered, in forward and bidirectional mode.
    """
    torch.manual_seed(42)
    np.random.seed(42)

    input_size, hidden_size, seq_len, batch = 3, 4, 5, 2
    num_directions = 2 if bidirectional else 1

    def random_f32(*shape):
        # Standard-normal sample of the requested shape, as float32.
        return np.random.randn(*shape).astype(np.float32)

    # Sampling order is kept fixed (X, initial_h, W, R, B) so the seeded
    # values are reproducible.
    X = random_f32(seq_len, batch, input_size)
    initial_h = random_f32(num_directions, batch, hidden_size)
    # W: [num_directions, 3*hidden_size, input_size]
    W = random_f32(num_directions, 3 * hidden_size, input_size)
    # R: [num_directions, 3*hidden_size, hidden_size]
    R = random_f32(num_directions, 3 * hidden_size, hidden_size)
    # B: [num_directions, 6*hidden_size] (Wb and Rb concatenated)
    B = random_f32(num_directions, 6 * hidden_size)

    # Graph inputs/outputs with fully static shapes.
    state_shape = [num_directions, batch, hidden_size]
    graph_inputs = [
        helper.make_tensor_value_info(name, TensorProto.FLOAT, shape)
        for name, shape in (
            ("X", [seq_len, batch, input_size]),
            ("initial_h", state_shape),
        )
    ]
    graph_outputs = [
        helper.make_tensor_value_info(name, TensorProto.FLOAT, shape)
        for name, shape in (
            ("Y", [seq_len, num_directions, batch, hidden_size]),
            ("Y_h", state_shape),
        )
    ]

    # Weights are embedded in the graph as initializers.
    initializers = [
        helper.make_tensor(
            name, TensorProto.FLOAT, values.shape, values.flatten().tolist()
        )
        for name, values in (("W", W), ("R", R), ("B", B))
    ]

    # The empty string skips the optional sequence_lens input.
    gru_node = helper.make_node(
        "GRU",
        inputs=["X", "W", "R", "B", "", "initial_h"],
        outputs=["Y", "Y_h"],
        hidden_size=hidden_size,
        linear_before_reset=linear_before_reset,
        direction="bidirectional" if bidirectional else "forward",
    )

    graph = helper.make_graph(
        [gru_node],
        "gru_test",
        graph_inputs,
        graph_outputs,
        initializers,
    )
    model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 14)])
    onnx.checker.check_model(model)

    # Reference results from onnxruntime.
    ort_session = ort.InferenceSession(model.SerializeToString())
    expected_Y, expected_Y_h = ort_session.run(
        None, {"X": X, "initial_h": initial_h}
    )

    # Same model, converted to PyTorch.
    o2p_gru = ConvertModel(model, experimental=True)
    with torch.no_grad():
        o2p_output, o2p_h_n = o2p_gru(
            torch.from_numpy(X), torch.from_numpy(initial_h)
        )

    # Full sequence output first, then the final hidden state.
    for actual, expected in ((o2p_output, expected_Y), (o2p_h_n, expected_Y_h)):
        torch.testing.assert_close(
            actual,
            torch.from_numpy(expected),
            rtol=1e-5,
            atol=1e-5,
        )
0 commit comments