Skip to content

Commit 9247cbc

Browse files
committed
Enforce black formatting.
1 parent e70771d commit 9247cbc

File tree

6 files changed

+8
-5
lines changed

6 files changed

+8
-5
lines changed

onnx2pytorch/operations/expand.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,5 @@
44

55
class Expand(nn.Module):
66
def forward(self, input: torch.Tensor, shape: torch.Tensor):
7-
#return input.expand(torch.Size(shape))
7+
# return input.expand(torch.Size(shape))
88
return input * torch.ones(torch.Size(shape), dtype=input.dtype)
9-

onnx2pytorch/operations/squeeze.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ def __init__(self, opset_version, dim=None):
1010
self.dim = dim
1111
super().__init__()
1212

13-
def forward(self, input: torch.Tensor, axes: torch.Tensor=None):
13+
def forward(self, input: torch.Tensor, axes: torch.Tensor = None):
1414
if self.opset_version < 13:
1515
dims = self.dim
1616
else:

onnx2pytorch/operations/unsqueeze.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,14 @@
33

44
from onnx2pytorch.operations.base import Operator
55

6+
67
class Unsqueeze(Operator):
78
def __init__(self, opset_version, dim=None):
89
self.opset_version = opset_version
910
self.dim = dim
1011
super().__init__()
1112

12-
def forward(self, data: torch.Tensor, axes: torch.Tensor=None):
13+
def forward(self, data: torch.Tensor, axes: torch.Tensor = None):
1314
if self.opset_version < 13:
1415
dims = self.dim
1516
else:
@@ -22,4 +23,3 @@ def forward(self, data: torch.Tensor, axes: torch.Tensor=None):
2223
for dim in sorted(dims, reverse=True):
2324
data = torch.unsqueeze(data, dim=dim)
2425
return data
25-

tests/onnx2pytorch/operations/test_expand.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
from onnx2pytorch.operations.expand import Expand
55

6+
67
def test_expand_dim_changed():
78
op = Expand()
89
inp = torch.reshape(torch.arange(0, 3, dtype=torch.float32), [3, 1])
@@ -12,6 +13,7 @@ def test_expand_dim_changed():
1213
assert tuple(op(inp, new_shape).shape) == exp_shape
1314
assert torch.equal(op(inp, new_shape), exp)
1415

16+
1517
def test_expand_dim_unchanged():
1618
op = Expand()
1719
inp = torch.reshape(torch.arange(0, 3, dtype=torch.int32), [3, 1])

tests/onnx2pytorch/operations/test_range.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
from onnx2pytorch.operations.range import Range
55

6+
67
@pytest.mark.parametrize(
78
"start, limit, delta, expected",
89
[

tests/onnx2pytorch/operations/test_squeeze.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ def test_squeeze_v11(inp, dim, exp_shape):
2323
op = Squeeze(opset_version=11, dim=dim)
2424
assert tuple(op(inp).shape) == exp_shape
2525

26+
2627
@pytest.mark.parametrize(
2728
"dim, exp_shape",
2829
[

0 commit comments

Comments
 (0)