     from torch.nn.modules.batchnorm import _LazyNormBase

     class _LazyInstanceNorm(_LazyNormBase, _InstanceNorm):
-
         cls_to_become = _InstanceNorm

-
 except ImportError:
     from torch.nn.modules.lazy import LazyModuleMixin
     from torch.nn.parameter import UninitializedBuffer, UninitializedParameter

     class _LazyInstanceNorm(LazyModuleMixin, _InstanceNorm):
-
         weight: UninitializedParameter  # type: ignore[assignment]
         bias: UninitializedParameter  # type: ignore[assignment]

@@ -78,24 +75,29 @@ def initialize_parameters(self, input) -> None:  # type: ignore[override]
             self.reset_parameters()


-class LazyInstanceNormUnsafe(_LazyInstanceNorm):
+class InstanceNormMixin:
     """Skips dimension check."""

     def __init__(self, *args, affine=True, **kwargs):
+        self.no_batch_dim = None  # no_batch_dim has to be set at runtime
         super().__init__(*args, affine=affine, **kwargs)

+    def set_no_dim_batch_dim(self, no_batch_dim):
+        self.no_batch_dim = no_batch_dim
+
     def _check_input_dim(self, input):
         return

+    def _get_no_batch_dim(self):
+        return self.no_batch_dim

-class InstanceNormUnsafe(_InstanceNorm):
-    """Skips dimension check."""

-    def __init__(self, *args, affine=True, **kwargs):
-        super().__init__(*args, affine=affine, **kwargs)
+class LazyInstanceNormUnsafe(InstanceNormMixin, _LazyInstanceNorm):
+    pass

-    def _check_input_dim(self, input):
-        return
+
+class InstanceNormUnsafe(InstanceNormMixin, _InstanceNorm):
+    pass


 class InstanceNormWrapper(torch.nn.Module):
@@ -120,4 +122,7 @@ def forward(self, input, scale=None, B=None):
         if B is not None:
             getattr(self.inu, "bias").data = B

+        if self.inu.no_batch_dim is None:
+            self.inu.set_no_dim_batch_dim(input.dim() - 1)
+
         return self.inu.forward(input)
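
For orientation, the refactored pattern can be exercised standalone. The sketch below is not part of the patch: the class name _RankAgnosticInstanceNorm and the inline calls are illustrative assumptions, and it presumes a torch build in which _InstanceNorm.forward() consults _get_no_batch_dim() (on older builds only the skipped _check_input_dim matters).

# Standalone sketch of the mixin pattern; names are illustrative, not from the patch.
import torch
from torch.nn.modules.instancenorm import _InstanceNorm

class _RankAgnosticInstanceNorm(_InstanceNorm):
    def __init__(self, *args, affine=True, **kwargs):
        # Plain attribute, assigned before Module.__init__ so forward() can read it later.
        self.no_batch_dim = None
        super().__init__(*args, affine=affine, **kwargs)

    def _check_input_dim(self, input):
        return  # skip the fixed-rank check that InstanceNorm1d/2d/3d perform

    def _get_no_batch_dim(self):
        return self.no_batch_dim

norm = _RankAgnosticInstanceNorm(num_features=4, affine=True)
x = torch.randn(2, 4, 8, 8)  # NCHW input whose rank is unknown until runtime

# Mirrors InstanceNormWrapper.forward: derive the unbatched rank from the first input.
if norm.no_batch_dim is None:
    norm.no_batch_dim = x.dim() - 1

print(norm(x).shape)  # torch.Size([2, 4, 8, 8])

Note the design choice: setting no_batch_dim to input.dim() - 1 guarantees input.dim() != _get_no_batch_dim() for that input, so the unbatched branch of _InstanceNorm.forward is never taken and the ONNX-style N x C x ... input is always normalized per sample.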