Add expected_attention_loss_multiplier hparam to allow scaling the attention_loss.

PiperOrigin-RevId: 187549764
T2T Team authored and Ryan Sepassi committed Mar 2, 2018
commit c1dbc3607d497d560aca6a2eae0ca2985c8b43d4
9 changes: 6 additions & 3 deletions tensor2tensor/layers/common_attention.py
@@ -327,14 +327,17 @@ def add_standard_attention_hparams(hparams):
   return hparams


-def encoder_decoder_attention_loss(expected_attention, actual_attentions):
+def encoder_decoder_attention_loss(expected_attention,
+                                   actual_attentions,
+                                   loss_multiplier=1.0):
   """Computes encdec attention loss between expected and actual attentions.

   Args:
     expected_attention: Tensor storing the expected encoder-decoder attention
       weights with shape [batch_size, target_length, input_length].
     actual_attentions: Dictionary with actual attention weights for different
       attention types and hidden layers.
+    loss_multiplier: multiplier for the attention loss.

   Returns:
     MSE loss between the actual and expected attention weights.
@@ -351,8 +354,8 @@ def encoder_decoder_attention_loss(expected_attention, actual_attentions):
   # Reduce mean across all layers (axis=0) and all heads (axis=2) to get a
   # tensor with shape [batch_size, target_length, input_length].
   actual_attention_weights = tf.reduce_mean(actual_attention_weights, [0, 2])
-  return tf.losses.mean_squared_error(expected_attention,
-                                      actual_attention_weights)
+  return tf.losses.mean_squared_error(
+      expected_attention, actual_attention_weights) * loss_multiplier


 @expert_utils.add_name_scope()
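For context, a minimal NumPy sketch of what the scaled loss computes. It is not part of the diff, and the stacked-weight shape is an assumption inferred from the reduce_mean comment above.

import numpy as np

# Assumed shapes: expected attention is [batch_size, target_length, input_length];
# the stacked actual weights are [num_layers, batch_size, num_heads,
# target_length, input_length], so averaging over axes (0, 2) collapses
# layers and heads.
num_layers, batch_size, num_heads, target_length, input_length = 2, 4, 8, 5, 7
expected = np.random.random_sample((batch_size, target_length, input_length))
actual = np.random.random_sample(
    (num_layers, batch_size, num_heads, target_length, input_length))

loss_multiplier = 2.0  # hypothetical value; the new hparam defaults to 1.0
actual_mean = actual.mean(axis=(0, 2))  # -> [batch_size, target_length, input_length]
attention_loss = np.mean((expected - actual_mean) ** 2) * loss_multiplier
print(attention_loss)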
11 changes: 10 additions & 1 deletion tensor2tensor/models/transformer.py
@@ -176,7 +176,8 @@ def body(self, features):
     expected_attentions = features.get("expected_attentions")
     if expected_attentions is not None:
       attention_loss = common_attention.encoder_decoder_attention_loss(
-          expected_attentions, self.attention_weights)
+          expected_attentions, self.attention_weights,
+          hparams.expected_attention_loss_multiplier)
       return decoder_output, {"attention_loss": attention_loss}

     return decoder_output
@@ -1462,3 +1463,11 @@ def transformer_librispeech_tpu():
   librispeech.set_librispeech_length_hparams(hparams)
   return hparams

+
+@registry.register_hparams
+def transformer_supervised_attention():
+  """Hparams for supervised attention problems."""
+  hparams = transformer_base()
+  # Multiplier to the encoder-decoder expected attention loss.
+  hparams.add_hparam("expected_attention_loss_multiplier", 1.0)
+  return hparams
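As a usage note rather than part of this commit: since the multiplier lives in a registered hparams set, one way to change it is to register a derived set. The name transformer_supervised_attention_2x below is hypothetical.

from tensor2tensor.models import transformer
from tensor2tensor.utils import registry


@registry.register_hparams
def transformer_supervised_attention_2x():
  """Hypothetical variant that doubles the weight of the attention loss."""
  hparams = transformer.transformer_supervised_attention()
  hparams.expected_attention_loss_multiplier = 2.0
  return hparams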
3 changes: 2 additions & 1 deletion tensor2tensor/models/transformer_test.py
@@ -208,7 +208,8 @@ def testTransformerWithoutProblem(self):
         [BATCH_SIZE, TARGET_LENGTH, 1, hparams.hidden_size])

   def testTransformerWithEncoderDecoderAttentionLoss(self):
-    model, features = self.getModel(transformer.transformer_small())
+    model, features = self.getModel(
+        transformer.transformer_supervised_attention())
     expected_attention_weights = np.random.random_sample(
         size=(BATCH_SIZE, TARGET_LENGTH, INPUT_LENGTH))
     features["expected_attentions"] = tf.constant(
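Finally, a quick sanity check of the new default, again as a sketch rather than part of the change:

from tensor2tensor.models import transformer

hparams = transformer.transformer_supervised_attention()
assert hparams.expected_attention_loss_multiplier == 1.0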