@@ -53,17 +53,17 @@ def test_single_layer_lstm(
5353 o2p_lstm = ConvertModel (onnx_lstm , experimental = True )
5454 with torch .no_grad ():
5555 o2p_output , o2p_h_n , o2p_c_n = o2p_lstm (input , h_0 , c_0 )
56- torch .testing .assert_allclose (o2p_output , output , rtol = 1e-6 , atol = 1e-6 )
57- torch .testing .assert_allclose (o2p_h_n , h_n , rtol = 1e-6 , atol = 1e-6 )
58- torch .testing .assert_allclose (o2p_c_n , c_n , rtol = 1e-6 , atol = 1e-6 )
56+ torch .testing .assert_close (o2p_output , output , rtol = 1e-6 , atol = 1e-6 )
57+ torch .testing .assert_close (o2p_h_n , h_n , rtol = 1e-6 , atol = 1e-6 )
58+ torch .testing .assert_close (o2p_c_n , c_n , rtol = 1e-6 , atol = 1e-6 )
5959
6060 onnx_lstm = onnx .ModelProto .FromString (bitstream_data )
6161 o2p_lstm = ConvertModel (onnx_lstm , experimental = True )
6262 with torch .no_grad ():
6363 o2p_output , o2p_h_n , o2p_c_n = o2p_lstm (h_0 = h_0 , input = input , c_0 = c_0 )
64- torch .testing .assert_allclose (o2p_output , output , rtol = 1e-6 , atol = 1e-6 )
65- torch .testing .assert_allclose (o2p_h_n , h_n , rtol = 1e-6 , atol = 1e-6 )
66- torch .testing .assert_allclose (o2p_c_n , c_n , rtol = 1e-6 , atol = 1e-6 )
64+ torch .testing .assert_close (o2p_output , output , rtol = 1e-6 , atol = 1e-6 )
65+ torch .testing .assert_close (o2p_h_n , h_n , rtol = 1e-6 , atol = 1e-6 )
66+ torch .testing .assert_close (o2p_c_n , c_n , rtol = 1e-6 , atol = 1e-6 )
6767 with pytest .raises (KeyError ):
6868 o2p_output , o2p_h_n , o2p_c_n = o2p_lstm (h_0 = h_0 , input = input )
6969 with pytest .raises (Exception ):
0 commit comments