Merged
Changes from 1 commit
Commits
25 commits
ca967c1
Internal change
Mar 2, 2018
908981e
Set evaluation_master to master in RunConfig
Mar 2, 2018
e2e61ac
Add option to do soft EM instead of hard EM
a-googler Mar 5, 2018
f92d901
Add an inv_temp hparam for controlling softness of EM
a-googler Mar 6, 2018
f36f82c
More Librispeech subsets to help with mixed clean and noisy data trai…
a-googler Mar 6, 2018
40d1f15
proper em - P(c_i) is computed using ema_count instead of actual counts
a-googler Mar 6, 2018
8320faf
increase pseudo-count to 1.0 and now there's no NaN in training
a-googler Mar 7, 2018
d83d992
Use logits instead of probs to compute supervised attention loss.
a-googler Mar 8, 2018
7056827
Why do we need stop gradient here?
a-googler Mar 8, 2018
e1e8fbb
Add expected_attention_loss_type hparam to supervised_attention to al…
a-googler Mar 8, 2018
5ee776d
ema_count trainable should be False; this was causing the weird dp be…
a-googler Mar 8, 2018
7293efc
Fix multi-logit loss computation error.
aidangomez Mar 8, 2018
75d2aef
Basic autoencoder and improvements in image modality.
Mar 9, 2018
c4e6fab
Change batch size for hparam config
Mar 9, 2018
9ae5bc2
Make Vanilla GAN work, based on Compare GAN code.
Mar 9, 2018
1568e9b
internal
Mar 9, 2018
95053b4
Bump release number.
Mar 9, 2018
9a638df
Documentation for cloud TPU for Image Transformer. Additional default…
Mar 9, 2018
3de51ab
Add smoothed L0 prior and trainable logits for cluster probabilities.
Mar 9, 2018
6a6d9fe
Added the ema count smoothing update inside the else.
Mar 9, 2018
6e846f2
Make text_encoder unicode conversion a pass-through
Mar 9, 2018
d8080a1
Pass in decode_hp to _interactive_input_fn and remove summaries when
a-googler Mar 10, 2018
c7495b5
six.iteritems for Py3
Mar 10, 2018
329123f
Update Travis tests for Py3 to run TF 1.6
Mar 10, 2018
688f4d5
Update ISSUE_TEMPLATE
Mar 10, 2018
Add smoothed L0 prior and trainable logits for cluster probabilities.
PiperOrigin-RevId: 188533148
Ashish Vaswani authored and Ryan Sepassi committed Mar 10, 2018
commit 3de51abb9c4b755cf18a524e5147b4cf867b90c4
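In short, this commit adds a smoothed-L0 penalty on the per-block cluster probabilities, optimized in a MAP-EM fashion against the EMA assignment counts, with the probabilities themselves coming from new trainable logits. A minimal NumPy sketch of the added loss term, with illustrative names and shapes rather than the library API:

```python
import numpy as np

def smoothed_l0_loss(ema_count, c_probs, slo_alpha=10.0, slo_beta=0.5):
    """Sketch of the smoothed-L0 penalty introduced in this commit.

    ema_count: [num_blocks, block_v_size] EMA counts of cluster assignments.
    c_probs:   [num_blocks, block_v_size] cluster probabilities, i.e. a
               softmax over the trainable c_logits.
    """
    num_blocks, block_v_size = ema_count.shape
    # Expected log-likelihood of the assignments under the cluster probs.
    ell = np.sum(ema_count * np.log(c_probs))
    # Smoothed-L0 prior: exp(-p / beta) is largest for probabilities near
    # zero, so rewarding it encourages sparsity over clusters.
    slo_prior = slo_alpha * np.sum(np.exp(-c_probs / slo_beta))
    # Negated MAP objective, normalized as in discretization.py.
    return -(ell + slo_prior) / (num_blocks * block_v_size)
```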
52 changes: 42 additions & 10 deletions tensor2tensor/layers/discretization.py
@@ -69,7 +69,8 @@ def nearest_neighbor(x,
random_top_k=1,
soft_em=False,
inv_temp=1.0,
ema_count=None):
ema_count=None,
c_probs=None):
"""Find the nearest element in means to elements in x.

Args:
@@ -82,7 +83,8 @@ def nearest_neighbor(x,
inv_temp: Inverse temperature for soft EM (Default: 1.)
ema_count: Table of counts for each embedding corresponding to how many
examples in a batch it was the closest to (Default: None).

c_probs: Precomputed probabilities of clusters may be given, for example in
the case of smoothed L0 priors (Default: None).
Returns:
Tensor with nearest element in mean encoded in one-hot notation.
"""
@@ -93,10 +95,17 @@ def nearest_neighbor(x,
scalar_prod = tf.transpose(scalar_prod, perm=[1, 0, 2])
dist = x_norm_sq + tf.transpose(
means_norm_sq, perm=[2, 0, 1]) - 2 * scalar_prod
if soft_em:
# computing cluster probabilities
if soft_em or c_probs is not None:
ema_count = tf.expand_dims(ema_count + 1.0, 0)
nearest_hot = tf.nn.softmax(-inv_temp * dist, axis=-1) * (
ema_count / tf.reduce_sum(ema_count, 2, keepdims=True))
if c_probs is not None:
# c_probs comes from a softmax over the cluster logits; expand dims to
# match the shape produced in the else branch.
c_probs = tf.expand_dims(c_probs, 0)
else:
c_probs = ema_count / tf.reduce_sum(ema_count, 2, keepdims=True)
if soft_em:
nearest_hot = tf.nn.softmax(-inv_temp * dist, axis=-1) * c_probs
nearest_hot /= tf.reduce_sum(nearest_hot, 2, keepdims=True)
else:
if random_top_k > 1:
@@ -119,7 +128,8 @@ def embedding_lookup(x,
random_top_k=1,
soft_em=False,
inv_temp=1.0,
ema_count=None):
ema_count=None,
c_probs=None):
"""Compute nearest neighbors and loss for training the embeddings via DVQ.

Args:
@@ -133,13 +143,15 @@ def embedding_lookup(x,
inv_temp: Inverse temperature for soft EM (Default: 1.)
ema_count: Table of counts for each embedding corresponding to how many
examples in a batch it was the closest to (Default: None).
c_probs: Precomputed cluster probabilities, for example in the case of
smoothed L0 priors (Default: None).

Returns:
The nearest neighbor in one hot form, the nearest neighbor itself, the
commitment loss, embedding training loss.
"""
x_means_hot = nearest_neighbor(x, means, block_v_size, random_top_k, soft_em,
inv_temp, ema_count)
inv_temp, ema_count, c_probs)
x_means_hot_flat = tf.reshape(x_means_hot, [-1, num_blocks, block_v_size])
x_means = tf.matmul(tf.transpose(x_means_hot_flat, perm=[1, 0, 2]), means)
x_means = tf.transpose(x_means, [1, 0, 2])
@@ -405,7 +417,11 @@ def discrete_bottleneck(x,
summary=True,
dp_strength=1.0,
dp_decay=1.0,
dp_alpha=0.5):
dp_alpha=0.5,
slo=False,
slo_alpha=10,
slo_beta=0.5,
c_logits=None):
"""Discretization bottleneck for latent variables.

Args:
@@ -447,6 +463,11 @@ def discrete_bottleneck(x,
dp_decay: Decay the dp_strength using an exponential decay using this
term (Default: 1.0).
dp_alpha: Alpha term (pseudo-count) in Dirichlet Process (Default: 0.5).
slo: Whether to use a smoothed L0 prior on cluster probabilities
(Default: False).
slo_alpha: Alpha scale for the smoothed L0 prior (Default: 10).
slo_beta: Beta smoothing term for the smoothed L0 prior (Default: 0.5).
c_logits: A [num_blocks, block_v_size] tensor of logits for
computing cluster probabilities (Default: None).

Returns:
Embedding to pass to the decoder, discrete latent, loss, and the embedding
@@ -532,10 +553,13 @@ def discrete_bottleneck(x,
c = tf.argmax(hot, axis=-1)
h1 = tf.layers.dense(hot, hidden_size, name='dae_dense')
elif bottleneck_kind == 'dvq':
c_probs = None
if c_logits is not None:
c_probs = tf.nn.softmax(c_logits, axis=-1)
x_reshaped = reshape_fn(x)
x_means_hot, x_means, q_loss, e_loss = embedding_lookup(
x_reshaped, means, num_blocks, block_v_size, random_top_k, soft_em,
inv_temp, ema_count)
inv_temp, ema_count, c_probs)

# Get the discrete latent representation
x_means_idx = tf.argmax(x_means_hot, axis=-1)
@@ -568,6 +592,7 @@ def discrete_bottleneck(x,
# Adding a term that puts a Dirichlet prior over cluster probabilities
# Hopefully it'll encourage rich get richer behaviors
dp_prior_loss = 0.
slo_loss = 0.
if dp_strength > 0.0:
# Decay dp_strength over time to make it less important
dp_strength = tf.train.exponential_decay(
@@ -581,6 +606,13 @@
dp_prior_loss = -1.0 * tf.reduce_sum(dp_prior_loss)
dp_prior_loss /= (num_blocks * block_v_size)

# if using smoothed L0
if slo:
# expected log likelihood
ell = tf.reduce_sum(ema_count * tf.log(c_probs))
# the prior component in the loss for MAP EM.
slo_prior = slo_alpha * tf.reduce_sum(tf.exp(-1.*c_probs/slo_beta))
slo_loss = -1. * (ell + slo_prior)/(num_blocks * block_v_size)
x_means_hot_flat = tf.reshape(
x_means_hot, shape=[-1, num_blocks, block_v_size])
dw = tf.matmul(
@@ -596,7 +628,7 @@
with tf.control_dependencies([e_loss]):
update_means = tf.assign(means, updated_ema_means)
with tf.control_dependencies([update_means]):
l = beta * e_loss + dp_strength * dp_prior_loss
l = beta * e_loss + dp_strength * dp_prior_loss + slo_loss
else:
l = q_loss + beta * e_loss

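For readers following the soft-EM change in nearest_neighbor above, here is a standalone sketch of that branch, assuming dist has shape [batch, num_blocks, block_v_size]; the function name is illustrative, not part of the library:

```python
import tensorflow as tf

def soft_nearest_hot(dist, ema_count, inv_temp=1.0, c_probs=None):
  """Sketch of the soft-EM branch of nearest_neighbor.

  dist:      [batch, num_blocks, block_v_size] squared distances to means.
  ema_count: [num_blocks, block_v_size] EMA counts per cluster.
  c_probs:   optional [num_blocks, block_v_size] precomputed probabilities,
             e.g. a softmax over trainable cluster logits.
  """
  if c_probs is None:
    # Estimate P(c_i) from the smoothed EMA counts.
    counts = tf.expand_dims(ema_count + 1.0, 0)
    c_probs = counts / tf.reduce_sum(counts, 2, keepdims=True)
  else:
    c_probs = tf.expand_dims(c_probs, 0)
  # Responsibilities: a softmax over distances, reweighted by P(c_i)
  # and renormalized over the cluster axis.
  nearest_hot = tf.nn.softmax(-inv_temp * dist, axis=-1) * c_probs
  return nearest_hot / tf.reduce_sum(nearest_hot, 2, keepdims=True)
```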
20 changes: 16 additions & 4 deletions tensor2tensor/models/research/transformer_vae.py
@@ -486,8 +486,10 @@ def __init__(self, *args, **kwargs):
summary=_DO_SUMMARIES,
dp_strength=self._hparams.dp_strength,
dp_decay=self._hparams.dp_decay,
dp_alpha=self._hparams.dp_alpha)

dp_alpha=self._hparams.dp_alpha,
slo=self._hparams.slo,
slo_alpha=self._hparams.slo_alpha,
slo_beta=self._hparams.slo_beta)
# Set the discretization bottleneck specific things here
if self._hparams.bottleneck_kind == "dvq":
block_dim = int(self._hparams.hidden_size // self._hparams.num_blocks)
@@ -513,7 +515,6 @@ def __init__(self, *args, **kwargs):
tf.logging.info("Using slices for DVQ")
else:
raise ValueError("Unknown reshape method")

means = tf.get_variable(
name="means",
shape=[self._hparams.num_blocks, block_v_size, block_dim],
@@ -530,12 +531,20 @@ def __init__(self, *args, **kwargs):
"ema_means", initializer=means.initialized_value(),
trainable=False)

# Create trainable cluster logits if we are using smoothed L0
c_logits = None
if self._hparams.slo:
# Softmax logits for the cluster probabilities
c_logits = tf.get_variable(
"c_logits", [self._hparams.num_blocks, block_v_size],
initializer=tf.uniform_unit_scaling_initializer())
# Update bottleneck
self._hparams.bottleneck = partial(
self._hparams.bottleneck,
means=means,
ema_count=ema_count,
ema_means=ema_means)
ema_means=ema_means,
c_logits=c_logits)

@property
def has_input(self):
@@ -643,6 +652,9 @@ def transformer_ae_small():
hparams.add_hparam("dp_alpha", 0.5)
hparams.add_hparam("dp_strength", 0.25)
hparams.add_hparam("dp_decay", 1.0)
hparams.add_hparam("slo", False) # for smoothed L0.
hparams.add_hparam("slo_alpha", 0.25)
hparams.add_hparam("slo_beta", 0.5)
hparams.add_hparam("unmasked_percentage", 0.1)
hparams.add_hparam("do_ae", True)
hparams.add_hparam("do_mask", True)
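To try the new prior, the hparams added to transformer_ae_small can be overridden in the usual way; a minimal sketch, with illustrative override values:

```python
from tensor2tensor.models.research import transformer_vae

hparams = transformer_vae.transformer_ae_small()
hparams.slo = True        # enable the smoothed L0 prior
hparams.slo_alpha = 0.25  # strength of the prior term
hparams.slo_beta = 0.5    # smoothing scale; smaller values approach hard L0
```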