Skip to content

Commit 29f49f5

Browse files
test slice _to_positive_step
1 parent d5bf79c commit 29f49f5

File tree

2 files changed

+17
-4
lines changed

2 files changed

+17
-4
lines changed

onnx2pytorch/operations/slice.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,12 @@
44

55
def _to_positive_step(orig_slice, N):
66
"""
7-
Convert a slice object with a negative step to an equivalent one with a
8-
positive step, computed using N, the length of the iterable being sliced.
9-
This is because PyTorch currently does not support slicing a tensor with
10-
a negative step.
7+
Convert a slice object with a negative step to one with a positive step.
8+
Accessing an iterable with the positive-stepped slice, followed by flipping
9+
the result, should be equivalent to accessing the tensor with the original
10+
slice. Computing positive-step slice requires using N, the length of the
11+
iterable being sliced. This is because PyTorch currently does not support
12+
slicing a tensor with a negative step.
1113
"""
1214
# Get rid of backward slices
1315
start, stop, step = orig_slice.indices(N)

tests/onnx2pytorch/operations/test_slice.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import pytest
44

55
from onnx2pytorch.operations import Slice
6+
from onnx2pytorch.operations.slice import _to_positive_step
67

78

89
@pytest.fixture
@@ -120,3 +121,13 @@ def test_slice_neg_steps(x, init):
120121
else:
121122
op = Slice()
122123
assert torch.equal(op(x, starts, ends, axes, steps), y)
124+
125+
126+
def test_to_positive_step():
127+
assert _to_positive_step(slice(-1, None, -1), 8) == slice(0, 8, 1)
128+
assert _to_positive_step(slice(-2, None, -1), 8) == slice(0, 7, 1)
129+
assert _to_positive_step(slice(None, -1, -1), 8) == slice(0, 0, 1)
130+
assert _to_positive_step(slice(None, -2, -1), 8) == slice(7, 8, 1)
131+
assert _to_positive_step(slice(None, None, -1), 8) == slice(0, 8, 1)
132+
assert _to_positive_step(slice(8, 1, -2), 8) == slice(3, 8, 2)
133+
assert _to_positive_step(slice(8, 0, -2), 8) == slice(1, 8, 2)

0 commit comments

Comments
 (0)