@@ -72,11 +72,12 @@ def build_transformer(cfg, return_intermediate=False):
 # ----------------- Transformer Encoder modules -----------------
 class TransformerEncoderLayer(nn.Module):
     def __init__(self,
-                 d_model :int = 256,
-                 num_heads :int = 8,
-                 ffn_dim :int = 1024,
-                 dropout :float = 0.1,
-                 act_type :str = "relu",
+                 d_model :int = 256,
+                 num_heads :int = 8,
+                 ffn_dim :int = 1024,
+                 dropout :float = 0.1,
+                 act_type :str = "relu",
+                 pre_norm :bool = False,
                  ):
         super().__init__()
         # ----------- Basic parameters -----------
@@ -85,6 +86,7 @@ def __init__(self,
         self.ffn_dim = ffn_dim
         self.dropout = dropout
         self.act_type = act_type
+        self.pre_norm = pre_norm
         # ----------- Basic parameters -----------
         # Multi-head Self-Attn
         self.self_attn = nn.MultiheadAttention(d_model, num_heads, dropout=dropout, batch_first=True)
@@ -97,7 +99,27 @@ def __init__(self,
     def with_pos_embed(self, tensor, pos):
         return tensor if pos is None else tensor + pos
 
-    def forward(self, src, pos_embed):
+    def forward_pre_norm(self, src, pos_embed):
+        """
+        Input:
+            src: [torch.Tensor] -> [B, N, C]
+            pos_embed: [torch.Tensor] -> [B, N, C]
+        Output:
+            src: [torch.Tensor] -> [B, N, C]
+        """
+        src = self.norm(src)
+        q = k = self.with_pos_embed(src, pos_embed)
+
+        # -------------- MHSA --------------
+        src2 = self.self_attn(q, k, value=src)[0]
+        src = src + self.dropout(src2)
+
+        # -------------- FFN --------------
+        src = self.ffn(src)
+
+        return src
+
+    def forward_post_norm(self, src, pos_embed):
         """
         Input:
             src: [torch.Tensor] -> [B, N, C]
@@ -117,15 +139,22 @@ def forward(self, src, pos_embed):
 
         return src
 
+    def forward(self, src, pos_embed):
+        if self.pre_norm:
+            return self.forward_pre_norm(src, pos_embed)
+        else:
+            return self.forward_post_norm(src, pos_embed)
+
 class TransformerEncoder(nn.Module):
     def __init__(self,
                  d_model :int = 256,
                  num_heads :int = 8,
                  num_layers :int = 1,
                  ffn_dim :int = 1024,
-                 pe_temperature : float = 10000.,
+                 pe_temperature :float = 10000.,
                  dropout :float = 0.1,
                  act_type :str = "relu",
+                 pre_norm :bool = False,
                  ):
         super().__init__()
         # ----------- Basic parameters -----------
@@ -135,11 +164,12 @@ def __init__(self,
         self.ffn_dim = ffn_dim
         self.dropout = dropout
         self.act_type = act_type
+        self.pre_norm = pre_norm
         self.pe_temperature = pe_temperature
         self.pos_embed = None
         # ----------- Basic parameters -----------
         self.encoder_layers = get_clones(
-            TransformerEncoderLayer(d_model, num_heads, ffn_dim, dropout, act_type), num_layers)
+            TransformerEncoderLayer(d_model, num_heads, ffn_dim, dropout, act_type, pre_norm), num_layers)
 
     def build_2d_sincos_position_embedding(self, device, w, h, embed_dim=256, temperature=10000.):
         assert embed_dim % 4 == 0, \
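A minimal usage sketch of the new pre_norm switch (not part of the diff; the import path is hypothetical, and the sketch assumes the layer's norm, dropout and ffn submodules defined elsewhere in this file, with inputs following the [B, N, C] convention from the docstrings):

import torch
from transformer import TransformerEncoderLayer  # hypothetical import path

# pre_norm=True routes forward() to forward_pre_norm, which applies
# self.norm to the input before self-attention (pre-norm formulation).
layer = TransformerEncoderLayer(d_model=256, num_heads=8, ffn_dim=1024,
                                dropout=0.1, act_type="relu", pre_norm=True)

src = torch.randn(2, 100, 256)        # [B, N, C]
pos_embed = torch.randn(2, 100, 256)  # [B, N, C]
out = layer(src, pos_embed)           # same shape as src: [2, 100, 256]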