Skip to content

Commit d5bf79c

Browse files
add a test for Reshape operator with enable_pruning=False
1 parent 0d92be4 commit d5bf79c

File tree

1 file changed

+4
-2
lines changed

1 file changed

+4
-2
lines changed

tests/onnx2pytorch/operations/test_reshape.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,8 @@ def pruned_inp():
1414
return torch.rand(35, 1, 160)
1515

1616

17-
def test_reshape(inp, pruned_inp):
17+
@pytest.mark.parametrize("enable_pruning", [True, False])
18+
def test_reshape(inp, pruned_inp, enable_pruning):
1819
"""Pass shape in forward."""
1920
op = Reshape(enable_pruning=True)
2021
shape = torch.Size((35, 2, 100))
@@ -32,7 +33,8 @@ def test_reshape(inp, pruned_inp):
3233
assert out.shape == expected_shape
3334

3435

35-
def test_reshape_2(inp, pruned_inp):
36+
@pytest.mark.parametrize("enable_pruning", [True, False])
37+
def test_reshape_2(inp, pruned_inp, enable_pruning):
3638
"""Pass shape in init."""
3739
shape = torch.Size((35, 2, 100))
3840
op = Reshape(enable_pruning=True, shape=shape)

0 commit comments

Comments
 (0)