Skip to content

Commit 03e81a3

Browse files
authored
Merge pull request #66 from Talmaj/talmaj/updates
Fix instance norm tests Update pre-commit config Update tox.ini Fix test_clip Add lint_and_test GitHub action Remove circleci
2 parents 89a3fe7 + 6d2812a commit 03e81a3

File tree

9 files changed

+84
-61
lines changed

9 files changed

+84
-61
lines changed

.circleci/config.yml

Lines changed: 0 additions & 32 deletions
This file was deleted.
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
name: Lint and Test
2+
3+
on: [pull_request]
4+
5+
jobs:
6+
test:
7+
runs-on: ubuntu-latest
8+
strategy:
9+
matrix:
10+
python-version: [ '3.9', '3.10', '3.11', '3.12' ]
11+
12+
steps:
13+
- name: Checkout code
14+
uses: actions/checkout@v4
15+
16+
- name: Set up Python ${{ matrix.python-version }}
17+
uses: actions/setup-python@v5
18+
with:
19+
python-version: ${{ matrix.python-version }}
20+
21+
- name: Cache dependencies
22+
uses: actions/cache@v4
23+
with:
24+
path: ~/.cache/pip
25+
key: ${{ runner.os }}-pip-${{ matrix.python-version }}-${{ hashFiles('**/requirements.txt') }}
26+
restore-keys: |
27+
${{ runner.os }}-pip-${{ matrix.python-version }}-
28+
${{ runner.os }}-pip-
29+
30+
- name: Install dependencies
31+
run: |
32+
pip install tox tox-gh-actions
33+
34+
- name: Run tests
35+
run: |
36+
bash download_fixtures.sh
37+
tox
38+
39+
- name: Upload coverage to GitHub Artifacts
40+
uses: actions/upload-artifact@v4
41+
with:
42+
name: coverage-${{ matrix.python-version }}
43+
path: htmlcov/

.pre-commit-config.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
repos:
22
- repo: https://github.com/psf/black
3-
rev: stable
3+
rev: 24.8.0
44
hooks:
55
- id: black
6-
language_version: python3.8
6+
language_version: python3.10

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# ONNX to PyTorch
22
![PyPI - License](https://img.shields.io/pypi/l/onnx2pytorch?color)
3-
[![CircleCI](https://circleci.com/gh/ToriML/onnx2pytorch.svg?style=shield)](https://app.circleci.com/pipelines/github/ToriML/onnx2pytorch)
3+
[![Lint and Test](https://github.com/Talmaj/onnx2pytorch/actions/workflows/lint_and_test.yml/badge.svg)](https://github.com/Talmaj/onnx2pytorch/actions/workflows/lint_and_test.yml)
44
[![Downloads](https://pepy.tech/badge/onnx2pytorch)](https://pepy.tech/project/onnx2pytorch)
55
![PyPI](https://img.shields.io/pypi/v/onnx2pytorch)
66

download_fixtures.sh

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ fi
99

1010
if [[ ! -f shufflenet_v2.onnx ]]; then
1111
echo Downloading shufflenet_v2
12-
curl -LJo shufflenet_v2.onnx https://github.com/onnx/models/blob/master/vision/classification/shufflenet/model/shufflenet-v2-10.onnx\?raw\=true
12+
curl -LJo shufflenet_v2.onnx https://github.com/onnx/models/raw/main/validated/vision/classification/shufflenet/model/shufflenet-v2-10.onnx
1313
fi
1414

1515
if [[ $1 == "--all" ]]; then
@@ -20,32 +20,32 @@ if [[ $1 == "--all" ]]; then
2020

2121
if [[ ! -f bertsquad-10.onnx ]]; then
2222
echo Downloading bertsquad-10
23-
curl -LJo bertsquad-10.onnx https://github.com/onnx/models/blob/master/text/machine_comprehension/bert-squad/model/bertsquad-10.onnx\?raw\=true
23+
curl -LJo bertsquad-10.onnx https://github.com/onnx/models/raw/main/validated/text/machine_comprehension/bert-squad/model/bertsquad-10.onnx
2424
fi
2525

2626
if [[ ! -f yolo_v4.onnx ]]; then
2727
echo Downloading yolo_v4
28-
curl -LJo yolo_v4.onnx https://github.com/onnx/models/blob/master/vision/object_detection_segmentation/yolov4/model/yolov4.onnx\?raw\=true
28+
curl -LJo yolo_v4.onnx https://github.com/onnx/models/raw/main/validated/vision/object_detection_segmentation/yolov4/model/yolov4.onnx
2929
fi
3030

3131
if [[ ! -f super_res.onnx ]]; then
3232
echo Downloading super_res
33-
curl -LJo super_res.onnx https://github.com/onnx/models/blob/master/vision/super_resolution/sub_pixel_cnn_2016/model/super-resolution-10.onnx\?raw\=true
33+
curl -LJo super_res.onnx https://github.com/onnx/models/raw/main/validated/vision/super_resolution/sub_pixel_cnn_2016/model/super-resolution-10.onnx
3434
fi
3535

3636
if [[ ! -f fast_neural_style.onnx ]]; then
3737
echo Downloading fast_neural_style
38-
curl -LJo fast_neural_style.onnx https://github.com/onnx/models/blob/master/vision/style_transfer/fast_neural_style/model/rain-princess-9.onnx\?raw\=true
38+
curl -LJo fast_neural_style.onnx https://github.com/onnx/models/raw/main/validated/vision/style_transfer/fast_neural_style/model/rain-princess-9.onnx
3939
fi
4040

4141
if [[ ! -f efficientnet-lite4.onnx ]]; then
4242
echo Downloading efficientnet-lite4
43-
curl -LJo efficientnet-lite4.onnx https://github.com/onnx/models/blob/master/vision/classification/efficientnet-lite4/model/efficientnet-lite4-11.onnx\?raw\=true
43+
curl -LJo efficientnet-lite4.onnx https://github.com/onnx/models/raw/main/validated/vision/classification/efficientnet-lite4/model/efficientnet-lite4-11.onnx
4444
fi
4545

4646
if [[ ! -f mobilenetv2-7.onnx ]]; then
4747
echo Downloading mobilenetv2-7
48-
curl -LJo mobilenetv2-7.onnx https://github.com/onnx/models/raw/master/vision/classification/mobilenet/model/mobilenetv2-7.onnx\?raw\=true
48+
curl -LJo mobilenetv2-7.onnx https://github.com/onnx/models/raw/main/validated/vision/classification/mobilenet/model/mobilenetv2-7.onnx
4949
fi
5050

5151
fi

onnx2pytorch/operations/instancenorm.py

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -7,16 +7,13 @@
77
from torch.nn.modules.batchnorm import _LazyNormBase
88

99
class _LazyInstanceNorm(_LazyNormBase, _InstanceNorm):
10-
1110
cls_to_become = _InstanceNorm
1211

13-
1412
except ImportError:
1513
from torch.nn.modules.lazy import LazyModuleMixin
1614
from torch.nn.parameter import UninitializedBuffer, UninitializedParameter
1715

1816
class _LazyInstanceNorm(LazyModuleMixin, _InstanceNorm):
19-
2017
weight: UninitializedParameter # type: ignore[assignment]
2118
bias: UninitializedParameter # type: ignore[assignment]
2219

@@ -78,24 +75,29 @@ def initialize_parameters(self, input) -> None: # type: ignore[override]
7875
self.reset_parameters()
7976

8077

81-
class LazyInstanceNormUnsafe(_LazyInstanceNorm):
78+
class InstanceNormMixin:
8279
"""Skips dimension check."""
8380

8481
def __init__(self, *args, affine=True, **kwargs):
82+
self.no_batch_dim = None # no_batch_dim has to be set at runtime
8583
super().__init__(*args, affine=affine, **kwargs)
8684

85+
def set_no_dim_batch_dim(self, no_batch_dim):
86+
self.no_batch_dim = no_batch_dim
87+
8788
def _check_input_dim(self, input):
8889
return
8990

91+
def _get_no_batch_dim(self):
92+
return self.no_batch_dim
9093

91-
class InstanceNormUnsafe(_InstanceNorm):
92-
"""Skips dimension check."""
9394

94-
def __init__(self, *args, affine=True, **kwargs):
95-
super().__init__(*args, affine=affine, **kwargs)
95+
class LazyInstanceNormUnsafe(InstanceNormMixin, _LazyInstanceNorm):
96+
pass
9697

97-
def _check_input_dim(self, input):
98-
return
98+
99+
class InstanceNormUnsafe(InstanceNormMixin, _InstanceNorm):
100+
pass
99101

100102

101103
class InstanceNormWrapper(torch.nn.Module):
@@ -120,4 +122,7 @@ def forward(self, input, scale=None, B=None):
120122
if B is not None:
121123
getattr(self.inu, "bias").data = B
122124

125+
if self.inu.no_batch_dim is None:
126+
self.inu.set_no_dim_batch_dim(input.dim() - 1)
127+
123128
return self.inu.forward(input)

tests/onnx2pytorch/convert/test_lstm.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -52,17 +52,17 @@ def test_single_layer_lstm(
5252
o2p_lstm = ConvertModel(onnx_lstm, experimental=True)
5353
with torch.no_grad():
5454
o2p_output, o2p_h_n, o2p_c_n = o2p_lstm(input, h_0, c_0)
55-
assert torch.equal(o2p_output, output)
56-
assert torch.equal(o2p_h_n, h_n)
57-
assert torch.equal(o2p_c_n, c_n)
55+
torch.testing.assert_allclose(o2p_output, output, rtol=1e-6, atol=1e-6)
56+
torch.testing.assert_allclose(o2p_h_n, h_n, rtol=1e-6, atol=1e-6)
57+
torch.testing.assert_allclose(o2p_c_n, c_n, rtol=1e-6, atol=1e-6)
5858

5959
onnx_lstm = onnx.ModelProto.FromString(bitstream_data)
6060
o2p_lstm = ConvertModel(onnx_lstm, experimental=True)
6161
with torch.no_grad():
6262
o2p_output, o2p_h_n, o2p_c_n = o2p_lstm(h_0=h_0, input=input, c_0=c_0)
63-
assert torch.equal(o2p_output, output)
64-
assert torch.equal(o2p_h_n, h_n)
65-
assert torch.equal(o2p_c_n, c_n)
63+
torch.testing.assert_allclose(o2p_output, output, rtol=1e-6, atol=1e-6)
64+
torch.testing.assert_allclose(o2p_h_n, h_n, rtol=1e-6, atol=1e-6)
65+
torch.testing.assert_allclose(o2p_c_n, c_n, rtol=1e-6, atol=1e-6)
6666
with pytest.raises(KeyError):
6767
o2p_output, o2p_h_n, o2p_c_n = o2p_lstm(h_0=h_0, input=input)
6868
with pytest.raises(Exception):

tests/onnx2pytorch/operations/test_clip.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,11 @@ def test_clip():
2020
assert torch.equal(op(x), exp_y)
2121

2222
op = Clip(max=0)
23-
exp_y_np = np.clip(x_np, np.NINF, 0)
23+
exp_y_np = np.clip(x_np, -np.inf, 0)
2424
exp_y = torch.from_numpy(exp_y_np)
2525
assert torch.equal(op(x), exp_y)
2626

2727
op = Clip()
28-
exp_y_np = np.clip(x_np, np.NINF, np.inf)
28+
exp_y_np = np.clip(x_np, -np.inf, np.inf)
2929
exp_y = torch.from_numpy(exp_y_np)
3030
assert torch.equal(op(x), exp_y)

tox.ini

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,14 @@
44
# and then run "tox" from this directory.
55

66
[tox]
7-
envlist = clean,py36,py37,py38,py38-torch19,py39
7+
envlist = clean,py39,py310,py311,py312
8+
9+
[gh-actions]
10+
python =
11+
3.9: py39
12+
3.10: py310
13+
3.11: py311
14+
3.12: py312
815

916
[testenv]
1017
passenv =

0 commit comments

Comments
 (0)