Skip to content

Commit 8750c19

Browse files
authored
Merge pull request #23 from calvinmccarter-at-lightmatter/master
Fix _LazyBatchNorm import broken by torch 1.10.0
2 parents ef9ae6c + bfb0677 commit 8750c19

File tree

3 files changed

+85
-63
lines changed

3 files changed

+85
-63
lines changed

onnx2pytorch/operations/batchnorm.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,19 @@
11
import warnings
22

33
from torch import nn
4-
from torch.nn.modules.batchnorm import _BatchNorm, _LazyBatchNorm
4+
from torch.nn.modules.batchnorm import _BatchNorm
5+
6+
try:
7+
from torch.nn.modules.batchnorm import _LazyNormBase
8+
9+
class _LazyBatchNorm(_LazyNormBase, _BatchNorm):
10+
11+
cls_to_become = _BatchNorm
12+
13+
14+
except ImportError:
15+
# for torch < 1.10.0
16+
from torch.nn.modules.batchnorm import _LazyBatchNorm
517

618

719
class LazyBatchNormUnsafe(_LazyBatchNorm):

onnx2pytorch/operations/instancenorm.py

Lines changed: 70 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -2,71 +2,80 @@
22
import torch
33

44
from torch.nn.modules.instancenorm import _InstanceNorm
5-
from torch.nn.modules.lazy import LazyModuleMixin
6-
from torch.nn.parameter import UninitializedBuffer, UninitializedParameter
7-
8-
9-
class _LazyInstanceNorm(LazyModuleMixin, _InstanceNorm):
10-
11-
weight: UninitializedParameter # type: ignore[assignment]
12-
bias: UninitializedParameter # type: ignore[assignment]
13-
14-
cls_to_become = _InstanceNorm
15-
16-
def __init__(
17-
self,
18-
eps=1e-5,
19-
momentum=0.1,
20-
affine=True,
21-
track_running_stats=True,
22-
device=None,
23-
dtype=None,
24-
) -> None:
25-
factory_kwargs = {"device": device, "dtype": dtype}
26-
super(_LazyInstanceNorm, self).__init__(
27-
# affine and track_running_stats are hardcoded to False to
28-
# avoid creating tensors that will soon be overwritten.
29-
0,
30-
eps,
31-
momentum,
32-
False,
33-
False,
34-
**factory_kwargs,
35-
)
36-
self.affine = affine
37-
self.track_running_stats = track_running_stats
38-
if self.affine:
39-
self.weight = UninitializedParameter(**factory_kwargs)
40-
self.bias = UninitializedParameter(**factory_kwargs)
41-
if self.track_running_stats:
42-
self.running_mean = UninitializedBuffer(**factory_kwargs)
43-
self.running_var = UninitializedBuffer(**factory_kwargs)
44-
self.num_batches_tracked = torch.tensor(
45-
0,
46-
dtype=torch.long,
47-
**{k: v for k, v in factory_kwargs.items() if k != "dtype"},
48-
)
495

50-
def reset_parameters(self) -> None:
51-
if not self.has_uninitialized_params() and self.num_features != 0:
52-
super().reset_parameters()
6+
try:
7+
from torch.nn.modules.batchnorm import _LazyNormBase
8+
9+
class _LazyInstanceNorm(_LazyNormBase, _InstanceNorm):
10+
11+
cls_to_become = _InstanceNorm
12+
13+
14+
except ImportError:
15+
from torch.nn.modules.lazy import LazyModuleMixin
16+
from torch.nn.parameter import UninitializedBuffer, UninitializedParameter
5317

54-
def initialize_parameters(self, input) -> None: # type: ignore[override]
55-
if self.has_uninitialized_params():
56-
self.num_features = input.shape[1]
18+
class _LazyInstanceNorm(LazyModuleMixin, _InstanceNorm):
19+
20+
weight: UninitializedParameter # type: ignore[assignment]
21+
bias: UninitializedParameter # type: ignore[assignment]
22+
23+
cls_to_become = _InstanceNorm
24+
25+
def __init__(
26+
self,
27+
eps=1e-5,
28+
momentum=0.1,
29+
affine=True,
30+
track_running_stats=True,
31+
device=None,
32+
dtype=None,
33+
) -> None:
34+
factory_kwargs = {"device": device, "dtype": dtype}
35+
super(_LazyInstanceNorm, self).__init__(
36+
# affine and track_running_stats are hardcoded to False to
37+
# avoid creating tensors that will soon be overwritten.
38+
0,
39+
eps,
40+
momentum,
41+
False,
42+
False,
43+
**factory_kwargs,
44+
)
45+
self.affine = affine
46+
self.track_running_stats = track_running_stats
5747
if self.affine:
58-
assert isinstance(self.weight, UninitializedParameter)
59-
assert isinstance(self.bias, UninitializedParameter)
60-
self.weight.materialize((self.num_features,))
61-
self.bias.materialize((self.num_features,))
48+
self.weight = UninitializedParameter(**factory_kwargs)
49+
self.bias = UninitializedParameter(**factory_kwargs)
6250
if self.track_running_stats:
63-
self.running_mean.materialize(
64-
(self.num_features,)
65-
) # type:ignore[union-attr]
66-
self.running_var.materialize(
67-
(self.num_features,)
68-
) # type:ignore[union-attr]
69-
self.reset_parameters()
51+
self.running_mean = UninitializedBuffer(**factory_kwargs)
52+
self.running_var = UninitializedBuffer(**factory_kwargs)
53+
self.num_batches_tracked = torch.tensor(
54+
0,
55+
dtype=torch.long,
56+
**{k: v for k, v in factory_kwargs.items() if k != "dtype"},
57+
)
58+
59+
def reset_parameters(self) -> None:
60+
if not self.has_uninitialized_params() and self.num_features != 0:
61+
super().reset_parameters()
62+
63+
def initialize_parameters(self, input) -> None: # type: ignore[override]
64+
if self.has_uninitialized_params():
65+
self.num_features = input.shape[1]
66+
if self.affine:
67+
assert isinstance(self.weight, UninitializedParameter)
68+
assert isinstance(self.bias, UninitializedParameter)
69+
self.weight.materialize((self.num_features,))
70+
self.bias.materialize((self.num_features,))
71+
if self.track_running_stats:
72+
self.running_mean.materialize(
73+
(self.num_features,)
74+
) # type:ignore[union-attr]
75+
self.running_var.materialize(
76+
(self.num_features,)
77+
) # type:ignore[union-attr]
78+
self.reset_parameters()
7079

7180

7281
class LazyInstanceNormUnsafe(_LazyInstanceNorm):

tox.ini

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,15 @@
44
# and then run "tox" from this directory.
55

66
[tox]
7-
envlist = clean,py36,py37,py38
7+
envlist = clean,py36,py37,py38,py38-torch19,py39
88

99
[testenv]
1010
passenv =
1111
CIRCLE*
1212
KMP_DUPLICATE_LIB_OK
1313
deps =
1414
-rrequirements.txt
15+
torch19: torch <= 1.9.0
1516
pytest-cov
1617
commands =
1718
pytest --cov --cov-append --cov-report term --cov-report html tests/

0 commit comments

Comments
 (0)