Merged
Commits (25)
ca967c1  Internal change (Mar 2, 2018)
908981e  Set evaluation_master to master in RunConfig (Mar 2, 2018)
e2e61ac  Add option to do soft EM instead of hard EM (a-googler, Mar 5, 2018)
f92d901  Add an inv_temp hparam for controlling softness of EM (a-googler, Mar 6, 2018)
f36f82c  More Librispeech subsets to help with mixed clean and noisy data trai… (a-googler, Mar 6, 2018)
40d1f15  proper em - P(c_i) is computed using ema_count instead of actual counts (a-googler, Mar 6, 2018)
8320faf  increase pseudo-count to 1.0 and now there's no NaN in training (a-googler, Mar 7, 2018)
d83d992  Use logits instead of probs to compute supervised attention loss. (a-googler, Mar 8, 2018)
7056827  Why do we need stop gradient here? (a-googler, Mar 8, 2018)
e1e8fbb  Add expected_attention_loss_type hparam to supervised_attention to al… (a-googler, Mar 8, 2018)
5ee776d  ema_count trainable should be False; this was causing the weird dp be… (a-googler, Mar 8, 2018)
7293efc  Fix multi-logit loss computation error. (aidangomez, Mar 8, 2018)
75d2aef  Basic autoencoder and improvements in image modality. (Mar 9, 2018)
c4e6fab  Change batch size for hparam config (Mar 9, 2018)
9ae5bc2  Make Vanilla GAN work, based on Compare GAN code. (Mar 9, 2018)
1568e9b  internal (Mar 9, 2018)
95053b4  Bump release number. (Mar 9, 2018)
9a638df  Documentation for cloud TPU for Image Transformer. Additional default… (Mar 9, 2018)
3de51ab  Add smoothed L0 prior and trainable logits for cluster probabilities. (Mar 9, 2018)
6a6d9fe  Added the ema count smoothing update inside the else. (Mar 9, 2018)
6e846f2  Make text_encoder unicode conversion a pass-through (Mar 9, 2018)
d8080a1  Pass in decode_hp to _interactive_input_fn and remove summaries when (a-googler, Mar 10, 2018)
c7495b5  six.iteritems for Py3 (Mar 10, 2018)
329123f  Update Travis tests for Py3 to run TF 1.6 (Mar 10, 2018)
688f4d5  Update ISSUE_TEMPLATE (Mar 10, 2018)
Use logits instead of probs to compute supervised attention loss.
PiperOrigin-RevId: 188254122
T2T Team authored and Ryan Sepassi committed Mar 9, 2018
commit d83d992c9bdbeef94b9004c4a4142a1cf32f8ebf
22 changes: 12 additions & 10 deletions tensor2tensor/layers/common_attention.py
@@ -334,28 +334,29 @@ def encoder_decoder_attention_loss(expected_attention,
 
   Args:
     expected_attention: Tensor storing the expected encoder-decoder attention
-      weights with shape [batch_size, target_length, input_length].
-    actual_attentions: Dictionary with actual attention weights for different
+      logits with shape [batch_size, target_length, input_length].
+    actual_attentions: Dictionary with actual attention logits for different
       attention types and hidden layers.
     loss_multiplier: multiplier for the attention loss.
 
   Returns:
-    MSE loss between the actual and expected attention weights.
+    KL_divergence loss between the actual and expected attention logits.
   """
-  # For each hidden layer, we have an attention weight tensor with shape
+  # For each hidden layer, we have an attention logits tensor with shape
   # [batch_size, num_heads, target_length, input_length].
-  actual_encdec_attention_weights = [
+  actual_encdec_attention_logits = [
       t for layer_key, t in actual_attentions.items()
-      if "encdec_attention" in layer_key
+      if "encdec_attention" in layer_key and layer_key.endswith("/logits")
   ]
   # Stack all hidden layer attention weight tensors to get a tensor with shape
   # [num_hidden_layers, batch_size, num_heads, target_length, input_length].
-  actual_attention_weights = tf.stack(actual_encdec_attention_weights)
+  actual_attention_logits = tf.stack(actual_encdec_attention_logits)
   # Reduce mean across all layers (axis=0) and all heads (axis=2) to get a
   # tensor with shape [batch_size, target_length, input_length].
-  actual_attention_weights = tf.reduce_mean(actual_attention_weights, [0, 2])
-  return tf.losses.mean_squared_error(
-      expected_attention, actual_attention_weights) * loss_multiplier
+  actual_attention_logits = tf.reduce_mean(actual_attention_logits, [0, 2])
+  p = tf.contrib.distributions.Categorical(logits=expected_attention)
+  q = tf.contrib.distributions.Categorical(logits=actual_attention_logits)
+  return tf.contrib.distributions.kl_divergence(p, q) * loss_multiplier
 
 
 @expert_utils.add_name_scope()
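
Note: the new return value treats every (batch, target position) pair as a categorical distribution over input positions and measures KL(expected || actual) directly on logits, instead of an MSE over softmax weights. A minimal standalone sketch of that computation, assuming TF 1.x with tf.contrib available; the shapes and toy logit values below are illustrative only, not taken from the library:

import tensorflow as tf

# Toy example: batch_size=1, target_length=2, input_length=3 (values made up).
expected_attention = tf.constant([[[2.0, 0.0, -1.0],
                                   [0.0, 1.0, 0.0]]])       # expected logits
actual_attention_logits = tf.constant([[[1.5, 0.2, -0.5],
                                        [0.1, 0.8, 0.2]]])  # e.g. mean over layers and heads

# One categorical distribution per (batch, target) position over input positions.
p = tf.contrib.distributions.Categorical(logits=expected_attention)
q = tf.contrib.distributions.Categorical(logits=actual_attention_logits)

# KL(p || q) has shape [batch_size, target_length]; the function above scales it
# by loss_multiplier before returning.
kl = tf.contrib.distributions.kl_divergence(p, q)

with tf.Session() as sess:
  print(sess.run(kl))  # per-position divergences, all finite and >= 0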
@@ -1305,6 +1306,7 @@ def dot_product_attention(q,
     weights = tf.nn.softmax(logits, name="attention_weights")
     if save_weights_to is not None:
       save_weights_to[scope.name] = weights
+      save_weights_to[scope.name + "/logits"] = logits
     # dropping out the attention links for each of the heads
     weights = common_layers.dropout_with_broadcast_dims(
         weights, 1.0 - dropout_rate, broadcast_dims=dropout_broadcast_dims)
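
The second hunk is what makes the "/logits" entries exist in the first place: whenever a save_weights_to dict is passed, dot_product_attention now records the raw pre-softmax logits next to the softmax weights. A minimal sketch of how a caller might separate the two kinds of entries; the example keys and the names saved, encdec_logits, and encdec_weights are assumptions for illustration, not output copied from the library:

# Suppose `saved` was passed as save_weights_to=saved while building the model.
# Based on the keys written above, it would then hold pairs such as:
#   "<scope.name>"         -> softmax attention weights
#   "<scope.name>/logits"  -> pre-softmax attention logits
saved = {}

encdec_logits = {
    key: tensor for key, tensor in saved.items()
    if "encdec_attention" in key and key.endswith("/logits")
}
encdec_weights = {
    key: tensor for key, tensor in saved.items()
    if "encdec_attention" in key and not key.endswith("/logits")
}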