@@ -32,36 +32,39 @@ def forward(self, x, y):
3232 return x
3333
3434
35- @ACTIVATION_LAYERS .register_module ()
36- class SiLU (nn .Module ):
37- r"""Applies the Sigmoid Linear Unit (SiLU) function, element-wise.
38- The SiLU function is also known as the swish function.
39- Args:
40- input (bool, optional): Use inplace operation or not.
41- Defaults to `False`.
42- """
35+ if 'SiLU' not in ACTIVATION_LAYERS :
4336
44- def __init__ (self , inplace = False ):
45- super ().__init__ ()
46- if digit_version (
47- torch .__version__ ) < digit_version ('1.7.0' ) and inplace :
48- mmcv .print_log ('Inplace version of \' SiLU\' is not supported for '
49- f'torch < 1.7.0, found \' { torch .version } \' .' )
50- self .inplace = inplace
51-
52- def forward (self , x ):
53- """Forward function for SiLU.
37+ @ACTIVATION_LAYERS .register_module ()
38+ class SiLU (nn .Module ):
39+ r"""Applies the Sigmoid Linear Unit (SiLU) function, element-wise.
40+ The SiLU function is also known as the swish function.
5441 Args:
55- x (torch.Tensor): Input tensor.
56-
57- Returns:
58- torch.Tensor: Tensor after activation.
42+ input (bool, optional): Use inplace operation or not.
43+ Defaults to `False`.
5944 """
6045
61- if digit_version (torch .__version__ ) < digit_version ('1.7.0' ):
62- return x * torch .sigmoid (x )
46+ def __init__ (self , inplace = False ):
47+ super ().__init__ ()
48+ if digit_version (
49+ torch .__version__ ) < digit_version ('1.7.0' ) and inplace :
50+ mmcv .print_log ('Inplace version of \' SiLU\' is not supported '
51+ 'for torch < 1.7.0, found '
52+ f'\' { torch .version } \' .' )
53+ self .inplace = inplace
54+
55+ def forward (self , x ):
56+ """Forward function for SiLU.
57+ Args:
58+ x (torch.Tensor): Input tensor.
59+
60+ Returns:
61+ torch.Tensor: Tensor after activation.
62+ """
63+
64+ if digit_version (torch .__version__ ) < digit_version ('1.7.0' ):
65+ return x * torch .sigmoid (x )
6366
64- return F .silu (x , inplace = self .inplace )
67+ return F .silu (x , inplace = self .inplace )
6568
6669
6770@MODULES .register_module ()
0 commit comments