@@ -2,71 +2,80 @@
 import torch
 
 from torch.nn.modules.instancenorm import _InstanceNorm
-from torch.nn.modules.lazy import LazyModuleMixin
-from torch.nn.parameter import UninitializedBuffer, UninitializedParameter
-
-
-class _LazyInstanceNorm(LazyModuleMixin, _InstanceNorm):
-
-    weight: UninitializedParameter  # type: ignore[assignment]
-    bias: UninitializedParameter  # type: ignore[assignment]
-
-    cls_to_become = _InstanceNorm
-
-    def __init__(
-        self,
-        eps=1e-5,
-        momentum=0.1,
-        affine=True,
-        track_running_stats=True,
-        device=None,
-        dtype=None,
-    ) -> None:
-        factory_kwargs = {"device": device, "dtype": dtype}
-        super(_LazyInstanceNorm, self).__init__(
-            # affine and track_running_stats are hardcoded to False to
-            # avoid creating tensors that will soon be overwritten.
-            0,
-            eps,
-            momentum,
-            False,
-            False,
-            **factory_kwargs,
-        )
-        self.affine = affine
-        self.track_running_stats = track_running_stats
-        if self.affine:
-            self.weight = UninitializedParameter(**factory_kwargs)
-            self.bias = UninitializedParameter(**factory_kwargs)
-        if self.track_running_stats:
-            self.running_mean = UninitializedBuffer(**factory_kwargs)
-            self.running_var = UninitializedBuffer(**factory_kwargs)
-            self.num_batches_tracked = torch.tensor(
-                0,
-                dtype=torch.long,
-                **{k: v for k, v in factory_kwargs.items() if k != "dtype"},
-            )
 
-    def reset_parameters(self) -> None:
-        if not self.has_uninitialized_params() and self.num_features != 0:
-            super().reset_parameters()
+try:
+    from torch.nn.modules.batchnorm import _LazyNormBase
+
+    class _LazyInstanceNorm(_LazyNormBase, _InstanceNorm):
+
+        cls_to_become = _InstanceNorm
+
+
+except ImportError:
+    from torch.nn.modules.lazy import LazyModuleMixin
+    from torch.nn.parameter import UninitializedBuffer, UninitializedParameter
 
-    def initialize_parameters(self, input) -> None:  # type: ignore[override]
-        if self.has_uninitialized_params():
-            self.num_features = input.shape[1]
+    class _LazyInstanceNorm(LazyModuleMixin, _InstanceNorm):
+
+        weight: UninitializedParameter  # type: ignore[assignment]
+        bias: UninitializedParameter  # type: ignore[assignment]
+
+        cls_to_become = _InstanceNorm
+
+        def __init__(
+            self,
+            eps=1e-5,
+            momentum=0.1,
+            affine=True,
+            track_running_stats=True,
+            device=None,
+            dtype=None,
+        ) -> None:
+            factory_kwargs = {"device": device, "dtype": dtype}
+            super(_LazyInstanceNorm, self).__init__(
+                # affine and track_running_stats are hardcoded to False to
+                # avoid creating tensors that will soon be overwritten.
+                0,
+                eps,
+                momentum,
+                False,
+                False,
+                **factory_kwargs,
+            )
+            self.affine = affine
+            self.track_running_stats = track_running_stats
             if self.affine:
-                assert isinstance(self.weight, UninitializedParameter)
-                assert isinstance(self.bias, UninitializedParameter)
-                self.weight.materialize((self.num_features,))
-                self.bias.materialize((self.num_features,))
+                self.weight = UninitializedParameter(**factory_kwargs)
+                self.bias = UninitializedParameter(**factory_kwargs)
             if self.track_running_stats:
-                self.running_mean.materialize(
-                    (self.num_features,)
-                )  # type:ignore[union-attr]
-                self.running_var.materialize(
-                    (self.num_features,)
-                )  # type:ignore[union-attr]
-            self.reset_parameters()
+                self.running_mean = UninitializedBuffer(**factory_kwargs)
+                self.running_var = UninitializedBuffer(**factory_kwargs)
+                self.num_batches_tracked = torch.tensor(
+                    0,
+                    dtype=torch.long,
+                    **{k: v for k, v in factory_kwargs.items() if k != "dtype"},
+                )
+
+        def reset_parameters(self) -> None:
+            if not self.has_uninitialized_params() and self.num_features != 0:
+                super().reset_parameters()
+
+        def initialize_parameters(self, input) -> None:  # type: ignore[override]
+            if self.has_uninitialized_params():
+                self.num_features = input.shape[1]
+                if self.affine:
+                    assert isinstance(self.weight, UninitializedParameter)
+                    assert isinstance(self.bias, UninitializedParameter)
+                    self.weight.materialize((self.num_features,))
+                    self.bias.materialize((self.num_features,))
+                if self.track_running_stats:
+                    self.running_mean.materialize(
+                        (self.num_features,)
+                    )  # type:ignore[union-attr]
+                    self.running_var.materialize(
+                        (self.num_features,)
+                    )  # type:ignore[union-attr]
+                self.reset_parameters()
 
 
 class LazyInstanceNormUnsafe(_LazyInstanceNorm):
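Net effect of this change: on PyTorch builds that ship _LazyNormBase (the shared base behind the upstream lazy norm layers in torch.nn.modules.batchnorm), _LazyInstanceNorm inherits all of the lazy-initialization machinery and only has to declare cls_to_become; the hand-rolled implementation survives as a fallback for older PyTorch versions where that import fails. A minimal sketch of the lazy behavior both branches are meant to provide, shown here with the upstream torch.nn.LazyInstanceNorm1d rather than this module's class (assumes a PyTorch version recent enough to ship it; the shapes are illustrative):

import torch
import torch.nn as nn

# num_features is not given up front; it is inferred from the first batch.
norm = nn.LazyInstanceNorm1d(affine=True, track_running_stats=True)
print(norm.weight)              # <UninitializedParameter>: no shape yet

x = torch.randn(4, 16, 50)      # (batch, channels, length)
y = norm(x)                     # first forward materializes params and buffers

print(norm.weight.shape)        # torch.Size([16]), taken from input.shape[1]
print(norm.running_mean.shape)  # torch.Size([16])
print(type(norm).__name__)      # 'InstanceNorm1d', swapped in via cls_to_become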
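For readers of the fallback branch: UninitializedParameter and UninitializedBuffer are shapeless placeholders, and materialize() is what turns them into real tensors once the feature count is known, which is exactly what initialize_parameters does on the first input. A standalone sketch of that mechanism (illustrative shapes, outside any module; the values stay uninitialized memory until something like reset_parameters fills them):

import torch
from torch.nn.parameter import UninitializedBuffer, UninitializedParameter

weight = UninitializedParameter()   # placeholder: no shape, no usable data
running_mean = UninitializedBuffer()

x = torch.randn(8, 16, 32)          # (batch, channels, length)
num_features = x.shape[1]           # 16, mirroring initialize_parameters

weight.materialize((num_features,))        # now a real nn.Parameter
running_mean.materialize((num_features,))  # now a plain tensor buffer

print(type(weight).__name__, tuple(weight.shape))              # Parameter (16,)
print(type(running_mean).__name__, tuple(running_mean.shape))  # Tensor (16,)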