@@ -122,6 +122,13 @@ def add_args(parser):
122122 'Must be used with adaptive_loss criterion' ),
123123 parser .add_argument ('--adaptive-softmax-dropout' , type = float , metavar = 'D' ,
124124 help = 'sets adaptive softmax dropout for the tail projections' )
125+ # args for "Cross+Self-Attention for Transformer Models" (Peitz et al., 2019)
126+ parser .add_argument ('--no-cross-attention' , default = False , action = 'store_true' ,
127+ help = 'do not perform cross-attention' )
128+ parser .add_argument ('--cross-self-attention' , default = False , action = 'store_true' ,
129+ help = 'perform cross+self-attention' )
130+ parser .add_argument ('--layer-wise-attention' , default = False , action = 'store_true' ,
131+ help = 'perform layer-wise attention (cross-attention or cross+self-attention)' )
125132 # fmt: on
126133
127134 @classmethod
@@ -180,7 +187,12 @@ def build_encoder(cls, args, src_dict, embed_tokens):
180187
@classmethod
def build_decoder(cls, args, tgt_dict, embed_tokens):
    """Construct the ``TransformerDecoder`` for this model.

    Args:
        args: parsed options object. ``no_cross_attention`` is read from it
            and treated as False when the attribute is absent.
        tgt_dict: forwarded unchanged as the decoder's second argument.
        embed_tokens: forwarded unchanged as the decoder's third argument.

    Returns:
        A ``TransformerDecoder`` whose ``no_encoder_attn`` keyword is set
        from ``args.no_cross_attention``.
    """
    # Read with a default so an args object lacking the attribute still works
    # (presumably for args created before the option existed -- TODO confirm).
    skip_cross_attention = getattr(args, 'no_cross_attention', False)
    return TransformerDecoder(
        args,
        tgt_dict,
        embed_tokens,
        no_encoder_attn=skip_cross_attention,
    )
184196
185197
186198class TransformerEncoder (FairseqEncoder ):
@@ -211,6 +223,8 @@ def __init__(self, args, dictionary, embed_tokens):
211223 learned = args .encoder_learned_pos ,
212224 ) if not args .no_token_positional_embeddings else None
213225
226+ self .layer_wise_attention = getattr (args , 'layer_wise_attention' , False )
227+
214228 self .layers = nn .ModuleList ([])
215229 self .layers .extend ([
216230 TransformerEncoderLayer (args )
@@ -230,21 +244,29 @@ def forward_embedding(self, src_tokens):
230244 x = F .dropout (x , p = self .dropout , training = self .training )
231245 return x , embed
232246
233- def forward (self , src_tokens , src_lengths , cls_input = None ):
247+ def forward (self , src_tokens , src_lengths , cls_input = None , return_all_hiddens = False ):
234248 """
235249 Args:
236250 src_tokens (LongTensor): tokens in the source language of shape
237251 `(batch, src_len)`
238252 src_lengths (torch.LongTensor): lengths of each source sentence of
239253 shape `(batch)`
254+ return_all_hiddens (bool, optional): also return all of the
255+ intermediate hidden states (default: False).
240256
241257 Returns:
242258 dict:
243259 - **encoder_out** (Tensor): the last encoder layer's output of
244260 shape `(src_len, batch, embed_dim)`
245261 - **encoder_padding_mask** (ByteTensor): the positions of
246262 padding elements of shape `(batch, src_len)`
263+ - **encoder_states** (List[Tensor]): all intermediate
264+ hidden states of shape `(src_len, batch, embed_dim)`.
265+ Only populated if *return_all_hiddens* is True.
247266 """
267+ if self .layer_wise_attention :
268+ return_all_hiddens = True
269+
248270 x , encoder_embedding = self .forward_embedding (src_tokens )
249271
250272 # B x T x C -> T x B x C
@@ -255,17 +277,24 @@ def forward(self, src_tokens, src_lengths, cls_input=None):
255277 if not encoder_padding_mask .any ():
256278 encoder_padding_mask = None
257279
280+ encoder_states = [] if return_all_hiddens else None
281+
258282 # encoder layers
259283 for layer in self .layers :
260284 x = layer (x , encoder_padding_mask )
285+ if return_all_hiddens :
286+ encoder_states .append (x )
261287
262288 if self .layer_norm :
263289 x = self .layer_norm (x )
290+ if return_all_hiddens :
291+ encoder_states [- 1 ] = x
264292
265293 return {
266294 'encoder_out' : x , # T x B x C
267295 'encoder_padding_mask' : encoder_padding_mask , # B x T
268296 'encoder_embedding' : encoder_embedding , # B x T x C
297+ 'encoder_states' : encoder_states , # List[T x B x C]
269298 }
270299
271300 def reorder_encoder_out (self , encoder_out , new_order ):
@@ -285,6 +314,9 @@ def reorder_encoder_out(self, encoder_out, new_order):
285314 if encoder_out ['encoder_padding_mask' ] is not None :
286315 encoder_out ['encoder_padding_mask' ] = \
287316 encoder_out ['encoder_padding_mask' ].index_select (0 , new_order )
317+ if encoder_out .get ('encoder_states' , None ) is not None :
318+ for idx , state in enumerate (encoder_out ['encoder_states' ]):
319+ encoder_out ['encoder_states' ][idx ] = state .index_select (1 , new_order )
288320 return encoder_out
289321
290322 def max_positions (self ):
@@ -293,6 +325,14 @@ def max_positions(self):
293325 return self .max_source_positions
294326 return min (self .max_source_positions , self .embed_positions .max_positions ())
295327
def buffered_future_mask(self, tensor):
    """Return a cached square mask sized by the first dimension of *tensor*.

    The mask is stored on ``self._future_mask`` and reused across calls. It
    is built by passing a ``(dim, dim)`` tensor through
    ``utils.fill_with_neg_inf`` and then ``torch.triu(..., 1)``, which keeps
    only the entries strictly above the main diagonal and zeroes the rest.
    Presumably the kept entries are ``-inf`` so the result acts as an
    additive "no peeking ahead" attention mask -- confirm against
    ``utils.fill_with_neg_inf``.

    Args:
        tensor: tensor whose ``size(0)`` gives ``dim`` and whose device (and
            ``new`` constructor) is used when the cache must be created.

    Returns:
        The ``[:dim, :dim]`` slice of the cached mask.
    """
    dim = tensor.size(0)

    # (Re)create the cache when it does not exist yet, was reset to None, or
    # lives on a different device than the incoming tensor.
    cached = getattr(self, '_future_mask', None)
    if cached is None or cached.device != tensor.device:
        self._future_mask = torch.triu(
            utils.fill_with_neg_inf(tensor.new(dim, dim)), 1
        )

    # Grow the cache in place when it is smaller than what is requested now;
    # the resized storage is refilled before the triangle is taken again.
    if self._future_mask.size(0) < dim:
        self._future_mask = torch.triu(
            utils.fill_with_neg_inf(self._future_mask.resize_(dim, dim)), 1
        )

    # A larger cached mask is simply sliced down to the requested size.
    return self._future_mask[:dim, :dim]
335+
296336 def upgrade_state_dict_named (self , state_dict , name ):
297337 """Upgrade a (possibly old) state dict for new versions of fairseq."""
298338 if isinstance (self .embed_positions , SinusoidalPositionalEmbedding ):
@@ -350,6 +390,9 @@ def __init__(self, args, dictionary, embed_tokens, no_encoder_attn=False):
350390 learned = args .decoder_learned_pos ,
351391 ) if not args .no_token_positional_embeddings else None
352392
393+ self .cross_self_attention = getattr (args , 'cross_self_attention' , False )
394+ self .layer_wise_attention = getattr (args , 'layer_wise_attention' , False )
395+
353396 self .layers = nn .ModuleList ([])
354397 self .layers .extend ([
355398 TransformerDecoderLayer (args , no_encoder_attn )
@@ -435,14 +478,26 @@ def extract_features(self, prev_output_tokens, encoder_out=None, incremental_sta
435478
436479 inner_states = [x ]
437480
481+ self_attn_padding_mask = prev_output_tokens .eq (self .padding_idx )
482+ if not self_attn_padding_mask .any () and not self .cross_self_attention :
483+ self_attn_padding_mask = None
484+
438485 # decoder layers
439- for layer in self .layers :
486+ for idx , layer in enumerate (self .layers ):
487+ encoder_state = None
488+ if encoder_out is not None :
489+ if self .layer_wise_attention :
490+ encoder_state = encoder_out ['encoder_states' ][idx ]
491+ else :
492+ encoder_state = encoder_out ['encoder_out' ]
493+
440494 x , attn = layer (
441495 x ,
442- encoder_out [ 'encoder_out' ] if encoder_out is not None else None ,
496+ encoder_state ,
443497 encoder_out ['encoder_padding_mask' ] if encoder_out is not None else None ,
444498 incremental_state ,
445499 self_attn_mask = self .buffered_future_mask (x ) if incremental_state is None else None ,
500+ self_attn_padding_mask = self_attn_padding_mask ,
446501 )
447502 inner_states .append (x )
448503
@@ -553,6 +608,9 @@ def base_architecture(args):
553608 args .share_all_embeddings = getattr (args , 'share_all_embeddings' , False )
554609 args .no_token_positional_embeddings = getattr (args , 'no_token_positional_embeddings' , False )
555610 args .adaptive_input = getattr (args , 'adaptive_input' , False )
611+ args .no_cross_attention = getattr (args , 'no_cross_attention' , False )
612+ args .cross_self_attention = getattr (args , 'cross_self_attention' , False )
613+ args .layer_wise_attention = getattr (args , 'layer_wise_attention' , False )
556614
557615 args .decoder_output_dim = getattr (args , 'decoder_output_dim' , args .decoder_embed_dim )
558616 args .decoder_input_dim = getattr (args , 'decoder_input_dim' , args .decoder_embed_dim )
0 commit comments