Removed the separate "extra" latent-prediction model from transformer_vae: the doubled-size hparams copy (hparams_ex), the extra input/target embeddings and second encoder, the stop_gradient on the embedded latents, and the 20K-step learning-rate warmup on the latent prediction loss.
PiperOrigin-RevId: 195723846
Ashish Vaswani authored and lukaszkaiser committed May 8, 2018
commit edf4e05502f15f4e4baafe16d5129beb4b0dcc04
27 changes: 3 additions & 24 deletions tensor2tensor/models/research/transformer_vae.py
@@ -18,7 +18,6 @@
 from __future__ import division
 from __future__ import print_function
 
-import copy
 import functools
 import math

@@ -150,10 +149,6 @@ def decode_transformer(encoder_output,
                        causal=True):
   """Original Transformer decoder."""
   orig_hparams = hparams
-  if name == "extra":
-    hparams = hparams.ex
-    targets = tf.layers.dense(
-        targets, hparams.hidden_size, name="extra_tgt_embed")
   with tf.variable_scope(name):
     if task is None:
       task = hparams.task
@@ -331,17 +326,6 @@ def ae_transformer_internal(inputs,
   if hparams.do_refine:
     _DO_SUMMARIES = False
 
-  # Change hyperparameters for the latent prediction model.
-  hparams_ex = copy.copy(hparams)
-  hparams_ex.filter_size *= 2
-  hparams_ex.hidden_size *= 2
-  hparams_ex.dropout = 0.0
-  hparams_ex.relu_dropout = 0.0
-  hparams_ex.z_dropout = 0.0
-  hparams_ex.layer_prepostprocess_dropout = 0.0
-  hparams_ex.symbol_dropout = 0.0
-  hparams.ex = hparams_ex
-
   # Prepare.
   if inputs is not None:
     batch_size = common_layers.shape_list(inputs)[0]
@@ -352,10 +336,8 @@
   # Encoder.
   if inputs is not None:
     inputs = common_layers.flatten4d3d(inputs)
-    inputs_ex = tf.layers.dense(
-        tf.stop_gradient(inputs), hparams_ex.hidden_size, name="extra_embed")
     inputs, ed = encode(inputs, target_space, hparams, "input_enc")
-    inputs_ex, ed_ex = encode(inputs_ex, target_space, hparams_ex, "extra_ienc")
+    inputs_ex, ed_ex = inputs, ed
   else:
     ed, inputs_ex, ed_ex = None, None, None
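Note (not part of the commit): a minimal sketch of the wiring after this hunk, with encode_fn standing in for transformer_vae.encode and all names hypothetical. The latent predictor now conditions on the same encoder output and attention bias as the main decoder, instead of running a second, wider encoder over a stop_gradient copy of the inputs.

def shared_encoder_outputs(inputs, target_space, hparams, encode_fn):
  """Sketch: the 'extra' latent-prediction path reuses the main encoding."""
  # Single encoder pass, as in the new code path.
  inputs, ed = encode_fn(inputs, target_space, hparams, "input_enc")
  # The latent predictor simply aliases the same tensors.
  inputs_ex, ed_ex = inputs, ed
  return inputs, ed, inputs_ex, ed_ex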

@@ -394,7 +376,7 @@ def ae_transformer_internal(inputs,
       if hparams.bottleneck_kind not in ["dense", "vae"]:
         latents_pred = decode_transformer(
             inputs_ex, ed_ex,
-            tf.stop_gradient(embed(latents_discrete)), hparams, "extra",
+            embed(latents_discrete), hparams, "extra",
             task="translate")
         _, latent_pred_loss = ae_latent_softmax(
             latents_pred, tf.stop_gradient(latents_discrete), hparams)
@@ -487,10 +469,7 @@ def refine_res():
   nonlatent_steps = hparams.mask_startup_steps
   latent_time = tf.less(nonlatent_steps,
                         tf.to_int32(tf.train.get_global_step()))
-  # Learning rate warmup for the latent model for 20K steps.
-  latent_warmup = tf.to_float(tf.train.get_global_step()) - nonlatent_steps
-  latent_warmup = tf.maximum(0.0, tf.minimum(1.0, latent_warmup / 20000.0))
-  losses["latent_pred"] *= tf.to_float(latent_time) * latent_warmup
+  losses["latent_pred"] *= tf.to_float(latent_time)
   return res, losses, cache
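Note (not part of the commit): a minimal sketch, using the TF 1.x ops already present in this file, of the linear warmup this hunk deletes; the function name and arguments are hypothetical. Before the change, the latent prediction loss was ramped from 0 to 1 over the 20K steps after mask_startup_steps; after it, the loss is simply gated on having passed mask_startup_steps.

import tensorflow as tf

def latent_pred_loss_scale(global_step, nonlatent_steps, warmup_steps=20000.0):
  """Multiplier formerly applied to losses["latent_pred"]."""
  # 1.0 once global_step has passed nonlatent_steps, else 0.0.
  gate = tf.to_float(tf.less(nonlatent_steps, tf.to_int32(global_step)))
  # Linear ramp from 0 to 1 over warmup_steps steps past nonlatent_steps.
  warmup = tf.to_float(global_step) - tf.to_float(nonlatent_steps)
  warmup = tf.maximum(0.0, tf.minimum(1.0, warmup / warmup_steps))
  return gate * warmup  # After this commit the multiplier is just `gate`.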

