Skip to content
This repository was archived by the owner on Jul 7, 2023. It is now read-only.
Prev Previous commit
Next Next commit
Allow for decompress outside do_mask. It seems to be independent of t…
…hat.

PiperOrigin-RevId: 195744845
  • Loading branch information
Ashish Vaswani authored and lukaszkaiser committed May 8, 2018
commit 2ee03caad784d8c7d5b6c22938c0df9403bc5358
15 changes: 9 additions & 6 deletions tensor2tensor/models/research/transformer_vae.py
Original file line number Diff line number Diff line change
Expand Up @@ -422,6 +422,14 @@ def bn_inputs():
latents_dense = tf.pad(latents_dense,
[[0, 0], [1, 0], [0, 0], [0, 0]]) + pos

# decompressing the dense latents
for i in range(hparams.num_compress_steps):
j = hparams.num_compress_steps - i - 1
d = residual_conv(d, 1, (3, 1), hparams, "decompress_rc_%d" % j)
if hparams.do_attend_decompress:
d = attend(d, inputs, hparams, "decompress_attend_%d" % j)
d = decompress_step(d, hparams, i > 0, False, "decompress_%d" % j)

# Masking.
if hparams.do_mask:
masking = common_layers.inverse_lin_decay(hparams.mask_startup_steps)
Expand All @@ -437,12 +445,7 @@ def bn_inputs():
mask = tf.less(masking, tf.random_uniform(
common_layers.shape_list(targets)[:-1]))
mask = tf.expand_dims(tf.to_float(mask), 3)
for i in range(hparams.num_compress_steps):
j = hparams.num_compress_steps - i - 1
d = residual_conv(d, 1, (3, 1), hparams, "decompress_rc_%d" % j)
if hparams.do_attend_decompress:
d = attend(d, inputs, hparams, "decompress_attend_%d" % j)
d = decompress_step(d, hparams, i > 0, False, "decompress_%d" % j)

# targets is always [batch, length, 1, depth]
targets = mask * targets + (1.0 - mask) * d
# reshape back to 4d here
Expand Down