Commit 2725457

[Fix] fix SiLU activation (#447)
* fix SiLU activation
* revise registry condition of SiLU
1 parent f6551e1 commit 2725457
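
The second bullet refers to the registration guard in the diff below; the first keeps the hand-written fallback silu(x) = x * sigmoid(x) that the module uses on torch < 1.7.0, where F.silu is unavailable. As a quick sanity check (a minimal sketch, not part of the commit, assuming torch >= 1.7.0 is installed so that F.silu exists), the two paths agree numerically:

import torch
import torch.nn.functional as F

x = torch.randn(4, 8)

# Manual SiLU, the fallback path kept for torch < 1.7.0.
manual = x * torch.sigmoid(x)

# Built-in SiLU, used on torch >= 1.7.0.
builtin = F.silu(x)

assert torch.allclose(manual, builtin)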

2 files changed: +28 -26 lines changed


.pre-commit-config.yaml (0 additions, 1 deletion)

@@ -34,7 +34,6 @@ repos:
     rev: 0.7.9
     hooks:
       - id: mdformat
-        language_version: python3.7
         args: ["--number", "--table-width", "200"]
         additional_dependencies:
           - mdformat-openmmlab

mmgen/models/architectures/ddpm/modules.py (28 additions, 25 deletions)

@@ -32,36 +32,39 @@ def forward(self, x, y):
         return x
 
 
-@ACTIVATION_LAYERS.register_module()
-class SiLU(nn.Module):
-    r"""Applies the Sigmoid Linear Unit (SiLU) function, element-wise.
-    The SiLU function is also known as the swish function.
-    Args:
-        input (bool, optional): Use inplace operation or not.
-            Defaults to `False`.
-    """
+if 'SiLU' not in ACTIVATION_LAYERS:
 
-    def __init__(self, inplace=False):
-        super().__init__()
-        if digit_version(
-                torch.__version__) < digit_version('1.7.0') and inplace:
-            mmcv.print_log('Inplace version of \'SiLU\' is not supported for '
-                           f'torch < 1.7.0, found \'{torch.version}\'.')
-        self.inplace = inplace
-
-    def forward(self, x):
-        """Forward function for SiLU.
+    @ACTIVATION_LAYERS.register_module()
+    class SiLU(nn.Module):
+        r"""Applies the Sigmoid Linear Unit (SiLU) function, element-wise.
+        The SiLU function is also known as the swish function.
         Args:
-            x (torch.Tensor): Input tensor.
-
-        Returns:
-            torch.Tensor: Tensor after activation.
+            input (bool, optional): Use inplace operation or not.
+                Defaults to `False`.
         """
 
-        if digit_version(torch.__version__) < digit_version('1.7.0'):
-            return x * torch.sigmoid(x)
+        def __init__(self, inplace=False):
+            super().__init__()
+            if digit_version(
+                    torch.__version__) < digit_version('1.7.0') and inplace:
+                mmcv.print_log('Inplace version of \'SiLU\' is not supported '
+                               'for torch < 1.7.0, found '
+                               f'\'{torch.version}\'.')
+            self.inplace = inplace
+
+        def forward(self, x):
+            """Forward function for SiLU.
+            Args:
+                x (torch.Tensor): Input tensor.
+
+            Returns:
+                torch.Tensor: Tensor after activation.
+            """
+
+            if digit_version(torch.__version__) < digit_version('1.7.0'):
+                return x * torch.sigmoid(x)
 
-        return F.silu(x, inplace=self.inplace)
+            return F.silu(x, inplace=self.inplace)
 
 
 @MODULES.register_module()
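
The new guard registers mmgen's SiLU only when no 'SiLU' entry exists in ACTIVATION_LAYERS yet, so the class no longer collides with mmcv builds that already register a SiLU activation (mmcv's Registry refuses to register the same name twice). A minimal usage sketch, not part of the commit, assuming mmgen and mmcv are installed and using mmcv's config-driven build_activation_layer:

import torch
from mmcv.cnn import build_activation_layer

# Importing the patched module runs the guarded registration above:
# 'SiLU' ends up in ACTIVATION_LAYERS exactly once, whether it comes
# from mmcv itself or from mmgen's fallback class.
import mmgen.models.architectures.ddpm.modules  # noqa: F401

act = build_activation_layer(dict(type='SiLU'))  # inplace defaults to False
out = act(torch.randn(2, 3))
print(out.shape)  # torch.Size([2, 3]); SiLU is applied element-wise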
