diff --git a/.travis.yml b/.travis.yml
index 339a0f606..1f32a4e60 100644
--- a/.travis.yml
+++ b/.travis.yml
@@ -16,7 +16,7 @@ matrix:
     - python: "3.6"
       env: TF_VERSION="1.4.*"
     - python: "3.6"
-      env: TF_VERSION="1.6.*"
+      env: TF_VERSION="1.5.*"
 before_install:
   - echo "deb [arch=amd64] http://storage.googleapis.com/tensorflow-serving-apt stable tensorflow-model-server tensorflow-model-server-universal" | sudo tee /etc/apt/sources.list.d/tensorflow-serving.list
   - curl https://storage.googleapis.com/tensorflow-serving-apt/tensorflow-serving.release.pub.gpg | sudo apt-key add -
diff --git a/.github/ISSUE_TEMPLATE.md b/ISSUE_TEMPLATE.md
similarity index 69%
rename from .github/ISSUE_TEMPLATE.md
rename to ISSUE_TEMPLATE.md
index a2d93e81c..477fa82fc 100644
--- a/.github/ISSUE_TEMPLATE.md
+++ b/ISSUE_TEMPLATE.md
@@ -6,7 +6,7 @@
 ### *TensorFlow* and *tensor2tensor* versions
 
-
+> …
 
@@ -16,7 +16,7 @@
 ### In case of bug report: Error log
 
-
+> …
 
diff --git a/README.md b/README.md
index cd9ab9331..dc6457482 100644
--- a/README.md
+++ b/README.md
@@ -15,7 +15,7 @@ welcome](https://img.shields.io/badge/contributions-welcome-brightgreen.svg)](CO
 of deep learning models and datasets designed to make deep learning more
 accessible and [accelerate ML
 research](https://research.googleblog.com/2017/06/accelerating-deep-learning-research.html).
- is actively used and maintained by researchers and engineers within the
+T2T is actively used and maintained by researchers and engineers within the
 [Google Brain team](https://research.google.com/teams/brain/) and a community
 of users. We're eager to collaborate with you too, so feel free to
 [open an issue on GitHub](https://github.com/tensorflow/tensor2tensor/issues)
diff --git a/docs/cloud_tpu.md b/docs/cloud_tpu.md
index 931f75bbf..cfc0c0a96 100644
--- a/docs/cloud_tpu.md
+++ b/docs/cloud_tpu.md
@@ -18,6 +18,17 @@ See the official tutorial for [running Transformer
 on Cloud TPUs](https://cloud.google.com/tpu/docs/tutorials/transformer)
 for some examples and try out your own problems.
 
+Image Transformer:
+* `imagetransformer` with `imagetransformer_base_tpu` (or
+  `imagetransformer_tiny_tpu`)
+* `img2img_transformer` with `img2img_transformer_base_tpu` (or
+  `img2img_transformer_tiny_tpu`)
+
+You can run the `ImageTransformer` model on problems like unconditional or
+conditional image generation, and the `Img2ImgTransformer` model on super
+resolution. We have run them on datasets like CelebA, CIFAR, and ImageNet,
+but they should work with any other image dataset.
+
 Residual networks:
 * `resnet` with `resnet_50` (or `resnet_18` or `resnet_34`)
 * `revnet` with `revnet_104` (or `revnet_38_cifar`)
diff --git a/docs/walkthrough.md b/docs/walkthrough.md
index 755d080b6..dc6457482 100644
--- a/docs/walkthrough.md
+++ b/docs/walkthrough.md
@@ -15,7 +15,7 @@ welcome](https://img.shields.io/badge/contributions-welcome-brightgreen.svg)](CO
 of deep learning models and datasets designed to make deep learning more
 accessible and [accelerate ML
 research](https://research.googleblog.com/2017/06/accelerating-deep-learning-research.html).
- is actively used and maintained by researchers and engineers within the
+T2T is actively used and maintained by researchers and engineers within the
 [Google Brain team](https://research.google.com/teams/brain/) and a community
 of users. We're eager to collaborate with you too, so feel free to
 [open an issue on GitHub](https://github.com/tensorflow/tensor2tensor/issues)
@@ -154,7 +154,7 @@ For all translation problems, we suggest to try the Transformer model:
 this should reach a BLEU score of about 28 on the English-German data-set,
 which is close to state-of-the-art. If training on a single GPU, try the
 `--hparams_set=transformer_base_single_gpu` setting. For very good results
-or larger data-sets (e.g., for English-French)m, try the big model
+or larger data-sets (e.g., for English-French), try the big model
 with `--hparams_set=transformer_big`.
 
 ## Basics
diff --git a/setup.py b/setup.py
index 01a2a6f33..e35412520 100644
--- a/setup.py
+++ b/setup.py
@@ -5,7 +5,7 @@
 setup(
     name='tensor2tensor',
-    version='1.5.4',
+    version='1.5.5',
     description='Tensor2Tensor',
     author='Google Inc.',
     author_email='no-reply@google.com',
diff --git a/tensor2tensor/data_generators/cifar.py b/tensor2tensor/data_generators/cifar.py
index ac23a95b5..3dd5c8f39 100644
--- a/tensor2tensor/data_generators/cifar.py
+++ b/tensor2tensor/data_generators/cifar.py
@@ -124,7 +124,8 @@ def preprocess_example(self, example, mode, unused_hparams):
     image.set_shape([_CIFAR10_IMAGE_SIZE, _CIFAR10_IMAGE_SIZE, 3])
     if mode == tf.estimator.ModeKeys.TRAIN:
       image = image_utils.cifar_image_augmentation(image)
-    image = tf.image.per_image_standardization(image)
+    if not self._was_reversed:
+      image = tf.image.per_image_standardization(image)
     example["inputs"] = image
     return example
 
@@ -151,7 +152,8 @@ class ImageCifar10Plain(ImageCifar10):
   def preprocess_example(self, example, mode, unused_hparams):
     image = example["inputs"]
     image.set_shape([_CIFAR10_IMAGE_SIZE, _CIFAR10_IMAGE_SIZE, 3])
-    image = tf.image.per_image_standardization(image)
+    if not self._was_reversed:
+      image = tf.image.per_image_standardization(image)
     example["inputs"] = image
     return example
 
@@ -179,7 +181,8 @@ def dataset_filename(self):
   def preprocess_example(self, example, mode, unused_hparams):
     image = example["inputs"]
     image = image_utils.resize_by_area(image, 8)
-    image = tf.image.per_image_standardization(image)
+    if not self._was_reversed:
+      image = tf.image.per_image_standardization(image)
     example["inputs"] = image
     return example
 
@@ -192,7 +195,6 @@ def dataset_filename(self):
     return "image_cifar10_plain"  # Reuse CIFAR-10 plain data.
 
   def preprocess_example(self, example, unused_mode, unused_hparams):
-    inputs = example["inputs"]
     # For Img2Img resize input and output images as desired.
-    example["inputs"] = image_utils.resize_by_area(inputs, 8)
+    example["inputs"] = image_utils.resize_by_area(example["inputs"], 8)
@@ -330,7 +332,8 @@ def preprocess_example(self, example, mode, unused_hparams):
     image.set_shape([_CIFAR100_IMAGE_SIZE, _CIFAR100_IMAGE_SIZE, 3])
     if mode == tf.estimator.ModeKeys.TRAIN:
       image = image_utils.cifar_image_augmentation(image)
-    image = tf.image.per_image_standardization(image)
+    if not self._was_reversed:
+      image = tf.image.per_image_standardization(image)
     example["inputs"] = image
     return example
 
@@ -357,7 +360,8 @@ class ImageCifar100Plain(ImageCifar100):
   def preprocess_example(self, example, mode, unused_hparams):
     image = example["inputs"]
     image.set_shape([_CIFAR100_IMAGE_SIZE, _CIFAR100_IMAGE_SIZE, 3])
-    image = tf.image.per_image_standardization(image)
+    if not self._was_reversed:
+      image = tf.image.per_image_standardization(image)
     example["inputs"] = image
     return example
 
@@ -385,7 +389,8 @@ def dataset_filename(self):
   def preprocess_example(self, example, mode, unused_hparams):
     image = example["inputs"]
     image = image_utils.resize_by_area(image, 8)
-    image = tf.image.per_image_standardization(image)
+    if not self._was_reversed:
+      image = tf.image.per_image_standardization(image)
     example["inputs"] = image
     return example
 
@@ -398,7 +403,6 @@ def dataset_filename(self):
     return "image_cifar100_plain"  # Reuse CIFAR-100 plain data.
 
   def preprocess_example(self, example, unused_mode, unused_hparams):
-    inputs = example["inputs"]
     # For Img2Img resize input and output images as desired.
-    example["inputs"] = image_utils.resize_by_area(inputs, 8)
+    example["inputs"] = image_utils.resize_by_area(example["inputs"], 8)
diff --git a/tensor2tensor/data_generators/gym.py b/tensor2tensor/data_generators/gym.py
index c50b0db6b..1030a43b5 100644
--- a/tensor2tensor/data_generators/gym.py
+++ b/tensor2tensor/data_generators/gym.py
@@ -35,6 +35,8 @@
 import tensorflow as tf
 
+
+
 flags = tf.flags
 FLAGS = flags.FLAGS
 
@@ -157,7 +159,6 @@ def num_steps(self):
     return 5000
 
-
 @registry.register_problem
 class GymPongTrajectoriesFromPolicy(GymDiscreteProblem):
   """Pong game, loaded actions."""
@@ -197,7 +198,7 @@ def generator(self, data_dir, tmp_dir):
       model_saver.restore(sess, FLAGS.model_path)
       for item in super(GymPongTrajectoriesFromPolicy,
                         self).generator(data_dir, tmp_dir):
-          yield item
+        yield item
 
 # TODO(blazej0): For training of atari agents wrappers are usually used.
 # Below we have a hacky solution which is a workaround to be used together
diff --git a/tensor2tensor/data_generators/image_utils.py b/tensor2tensor/data_generators/image_utils.py
index 7c3d946e1..c77eb11e8 100644
--- a/tensor2tensor/data_generators/image_utils.py
+++ b/tensor2tensor/data_generators/image_utils.py
@@ -26,6 +26,7 @@
 from tensor2tensor.data_generators import generator_utils
 from tensor2tensor.data_generators import problem
 from tensor2tensor.data_generators import text_encoder
+from tensor2tensor.utils import metrics
 from tensor2tensor.utils import registry
 
 import tensorflow as tf
@@ -64,9 +65,19 @@ def example_reading_spec(self, label_repr=None):
     return data_fields, data_items_to_decoders
 
   def preprocess_example(self, example, mode, hparams):
-    example["inputs"] = tf.image.per_image_standardization(example["inputs"])
+    if not self._was_reversed:
+      example["inputs"] = tf.image.per_image_standardization(example["inputs"])
     return example
 
+  def eval_metrics(self):
+    eval_metrics = [
+        metrics.Metrics.ACC, metrics.Metrics.ACC_TOP5,
+        metrics.Metrics.ACC_PER_SEQ, metrics.Metrics.NEG_LOG_PERPLEXITY
+    ]
+    if self._was_reversed:
+      eval_metrics += [metrics.Metrics.IMAGE_SUMMARY]
+    return eval_metrics
+
 
 class Image2ClassProblem(ImageProblem):
   """Base class for image classification problems."""
diff --git a/tensor2tensor/data_generators/imagenet.py b/tensor2tensor/data_generators/imagenet.py
index e1de105f2..db555ad9b 100644
--- a/tensor2tensor/data_generators/imagenet.py
+++ b/tensor2tensor/data_generators/imagenet.py
@@ -334,7 +334,8 @@ def distorted_bounding_box_crop(image,
   Returns:
     (cropped image `Tensor`, distorted bbox `Tensor`).
   """
-  with tf.name_scope(scope, default_name="distorted_bounding_box_crop", values=[image, bbox]):
+  with tf.name_scope(scope, default_name="distorted_bounding_box_crop",
+                     values=[image, bbox]):
     # Each bounding box has shape [1, num_boxes, box coords] and
     # the coordinates are ordered [ymin, xmin, ymax, xmax].
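A note on the recurring `_was_reversed` guard above: T2T problems can be registered in reversed form (e.g. `image_mnist_rev` or `image_cifar10_rev`), in which case the images become the *targets* of a generative model. `tf.image.per_image_standardization` maps pixels to zero-mean, unit-variance floats, which destroys the 256-way symbol structure the image modality needs to predict, so standardization is now applied only when the problem is not reversed. A minimal sketch of the pattern (hypothetical subclass, not part of this diff):

```python
import tensorflow as tf
from tensor2tensor.data_generators import image_utils

class MyImageProblem(image_utils.ImageProblem):  # hypothetical example
  def preprocess_example(self, example, mode, hparams):
    # Standardize only when images are inputs; for a reversed ("_rev")
    # problem the raw 0-255 integer pixels are the prediction targets.
    if not self._was_reversed:
      example["inputs"] = tf.image.per_image_standardization(
          example["inputs"])
    return example
```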
diff --git a/tensor2tensor/data_generators/librispeech.py b/tensor2tensor/data_generators/librispeech.py
index ab8376847..81c532286 100644
--- a/tensor2tensor/data_generators/librispeech.py
+++ b/tensor2tensor/data_generators/librispeech.py
@@ -39,7 +39,7 @@
         "train-other-500"
     ],
 ]
-_LIBRISPEECH_TEST_DATASETS = [
+_LIBRISPEECH_DEV_DATASETS = [
     [
         "http://www.openslr.org/resources/12/dev-clean.tar.gz",
         "dev-clean"
@@ -49,6 +49,16 @@
         "dev-other"
     ],
 ]
+_LIBRISPEECH_TEST_DATASETS = [
+    [
+        "http://www.openslr.org/resources/12/test-clean.tar.gz",
+        "test-clean"
+    ],
+    [
+        "http://www.openslr.org/resources/12/test-other.tar.gz",
+        "test-other"
+    ],
+]
 
 
 def _collect_data(directory, input_ext, transcription_ext):
@@ -72,7 +82,7 @@ def _collect_data(directory, input_ext, transcription_ext):
       assert key not in data_files
       media_name = "%s.%s"%(media_base, input_ext)
       media_path = os.path.join(root, media_name)
-      data_files[key] = (media_path, label)
+      data_files[key] = (media_base, media_path, label)
   return data_files
 
 
@@ -82,7 +92,8 @@ class Librispeech(speech_recognition.SpeechRecognitionProblem):
 
   # Select only the clean data
   TRAIN_DATASETS = _LIBRISPEECH_TRAIN_DATASETS
-  DEV_DATASETS = _LIBRISPEECH_TEST_DATASETS
+  DEV_DATASETS = _LIBRISPEECH_DEV_DATASETS
+  TEST_DATASETS = _LIBRISPEECH_TEST_DATASETS
 
   @property
   def num_shards(self):
@@ -96,6 +107,10 @@ def use_subword_tokenizer(self):
   def num_dev_shards(self):
     return 1
 
+  @property
+  def num_test_shards(self):
+    return 1
+
   @property
   def use_train_shards_for_dev(self):
     """If true, we only generate training data and hold out shards for dev."""
@@ -127,13 +142,19 @@ def generator(self, data_dir, tmp_dir, datasets,
     audio_encoder = encoders["waveforms"]
     text_encoder = encoders["targets"]
 
-    for media_file, text_data in sorted(data_pairs)[start_from:]:
+    for utt_id, media_file, text_data in sorted(data_pairs)[start_from:]:
       if how_many > 0 and i == how_many:
         return
       i += 1
+      wav_data = audio_encoder.encode(media_file)
+      spk_id, unused_book_id, _ = utt_id.split("-")
       yield {
-          "waveforms": audio_encoder.encode(media_file),
-          "targets": text_encoder.encode(text_data)
+          "waveforms": wav_data,
+          "waveform_lens": [len(wav_data)],
+          "targets": text_encoder.encode(text_data),
+          "raw_transcript": [text_data],
+          "utt_id": [utt_id],
+          "spk_id": [spk_id],
       }
 
   def generate_data(self, data_dir, tmp_dir, task_id=-1):
@@ -141,6 +162,11 @@ def generate_data(self, data_dir, tmp_dir, task_id=-1):
         data_dir, self.num_shards, shuffled=False)
     dev_paths = self.dev_filepaths(
         data_dir, self.num_dev_shards, shuffled=False)
+    test_paths = self.test_filepaths(
+        data_dir, self.num_test_shards, shuffled=True)
+
+    generator_utils.generate_files(
+        self.generator(data_dir, tmp_dir, self.TEST_DATASETS), test_paths)
 
     if self.use_train_shards_for_dev:
       all_paths = train_paths + dev_paths
@@ -153,22 +179,51 @@ def generate_data(self, data_dir, tmp_dir, task_id=-1):
           self.generator(data_dir, tmp_dir, self.DEV_DATASETS), dev_paths)
 
 
+@registry.register_problem()
+class LibrispeechTrainFullTestClean(Librispeech):
+  """Problem to train on full 960h, but evaluate on clean data only."""
+
+  def training_filepaths(self, data_dir, num_shards, shuffled):
+    return Librispeech.training_filepaths(data_dir, num_shards, shuffled)
+
+  def dev_filepaths(self, data_dir, num_shards, shuffled):
+    return LibrispeechClean.dev_filepaths(data_dir, num_shards, shuffled)
+
+  def test_filepaths(self, data_dir, num_shards, shuffled):
+    return LibrispeechClean.test_filepaths(data_dir, num_shards, shuffled)
+
+  def generate_data(self, data_dir, tmp_dir, task_id=-1):
+    raise Exception("Generate librispeech and librispeech_clean data.")
+
+
 @registry.register_problem()
 class LibrispeechCleanSmall(Librispeech):
-  """Problem spec for Librispeech using 100h clean train data."""
+  """Problem spec for Librispeech using 100h clean train and clean eval data."""
 
   # Select only the clean data
   TRAIN_DATASETS = _LIBRISPEECH_TRAIN_DATASETS[:1]
-  DEV_DATASETS = _LIBRISPEECH_TEST_DATASETS[:1]
+  DEV_DATASETS = _LIBRISPEECH_DEV_DATASETS[:1]
+  TEST_DATASETS = _LIBRISPEECH_TEST_DATASETS[:1]
 
 
 @registry.register_problem()
 class LibrispeechClean(Librispeech):
-  """Problem spec for Librispeech using 460h clean train data."""
+  """Problem spec for Librispeech using 460h clean train and clean eval data."""
 
   # Select only the clean data
   TRAIN_DATASETS = _LIBRISPEECH_TRAIN_DATASETS[:2]
-  DEV_DATASETS = _LIBRISPEECH_TEST_DATASETS[:1]
+  DEV_DATASETS = _LIBRISPEECH_DEV_DATASETS[:1]
+  TEST_DATASETS = _LIBRISPEECH_TEST_DATASETS[:1]
+
+
+@registry.register_problem()
+class LibrispeechNoisy(Librispeech):
+  """Problem spec for Librispeech using 500h noisy train and noisy eval data."""
+
+  # Select only the noisy data
+  TRAIN_DATASETS = _LIBRISPEECH_TRAIN_DATASETS[2:]
+  DEV_DATASETS = _LIBRISPEECH_DEV_DATASETS[1:]
+  TEST_DATASETS = _LIBRISPEECH_TEST_DATASETS[1:]
 
 
 # TODO(lukaszkaiser): clean up hparams or remove from here.
diff --git a/tensor2tensor/data_generators/mnist.py b/tensor2tensor/data_generators/mnist.py
index 5fefbd476..ef40f62e6 100644
--- a/tensor2tensor/data_generators/mnist.py
+++ b/tensor2tensor/data_generators/mnist.py
@@ -162,7 +162,8 @@ def train_shards(self):
   def preprocess_example(self, example, mode, unused_hparams):
     image = example["inputs"]
     image.set_shape([_MNIST_IMAGE_SIZE, _MNIST_IMAGE_SIZE, 1])
-    image = tf.image.per_image_standardization(image)
+    if not self._was_reversed:
+      image = tf.image.per_image_standardization(image)
     example["inputs"] = image
     return example
diff --git a/tensor2tensor/data_generators/ptb.py b/tensor2tensor/data_generators/ptb.py
index 5c96c2629..af455749d 100644
--- a/tensor2tensor/data_generators/ptb.py
+++ b/tensor2tensor/data_generators/ptb.py
@@ -82,6 +82,10 @@ def _maybe_download_corpus(tmp_dir, vocab_type):
 
   Args:
     tmp_dir: directory containing dataset.
+    vocab_type: which vocabulary to use.
+
+  Returns:
+    The list of extracted file names.
   """
   filename = os.path.basename(PTB_URL)
   compressed_filepath = generator_utils.maybe_download(
diff --git a/tensor2tensor/data_generators/text_encoder.py b/tensor2tensor/data_generators/text_encoder.py
index fa057ade9..bafdcb615 100644
--- a/tensor2tensor/data_generators/text_encoder.py
+++ b/tensor2tensor/data_generators/text_encoder.py
@@ -62,10 +62,10 @@
 if six.PY2:
 
   def native_to_unicode(s):
-    return s if isinstance(s, unicode) else s.decode("utf8")
+    return s if isinstance(s, unicode) else s.decode("utf-8")
 
   def unicode_to_native(s):
-    return s.encode("utf-8")
+    return s.encode("utf-8") if isinstance(s, unicode) else s
 else:  # No conversion required on Python >= 3.
 
   def native_to_unicode(s):
diff --git a/tensor2tensor/layers/common_attention.py b/tensor2tensor/layers/common_attention.py
index ddb6c3c89..5b5251955 100644
--- a/tensor2tensor/layers/common_attention.py
+++ b/tensor2tensor/layers/common_attention.py
@@ -327,35 +327,62 @@ def add_standard_attention_hparams(hparams):
   return hparams
 
 
-def encoder_decoder_attention_loss(expected_attention,
+def encoder_decoder_attention_loss(expected_attention_logits,
                                    actual_attentions,
+                                   loss_type="kl_divergence",
                                    loss_multiplier=1.0):
   """Computes encdec attention loss between expected and actual attentions.
 
   Args:
-    expected_attention: Tensor storing the expected encoder-decoder attention
-      weights with shape [batch_size, target_length, input_length].
-    actual_attentions: Dictionary with actual attention weights for different
+    expected_attention_logits: Tensor storing the expected encoder-decoder
+      attention logits with shape [batch_size, target_length, input_length].
+    actual_attentions: Dictionary with actual attention logits for different
       attention types and hidden layers.
+    loss_type: type of the loss function.
     loss_multiplier: multiplier for the attention loss.
 
   Returns:
-    MSE loss between the actual and expected attention weights.
+    KL_divergence loss between the actual and expected attention logits.
   """
-  # For each hidden layer, we have an attention weight tensor with shape
-  # [batch_size, num_heads, target_length, input_length].
-  actual_encdec_attention_weights = [
-      t for layer_key, t in actual_attentions.items()
-      if "encdec_attention" in layer_key
-  ]
-  # Stack all hidden layer attention weight tensors to get a tensor with shape
-  # [num_hidden_layers, batch_size, num_heads, target_length, input_length].
-  actual_attention_weights = tf.stack(actual_encdec_attention_weights)
-  # Reduce mean across all layers (axis=0) and all heads (axis=2) to get a
-  # tensor with shape [batch_size, target_length, input_length].
-  actual_attention_weights = tf.reduce_mean(actual_attention_weights, [0, 2])
-  return tf.losses.mean_squared_error(
-      expected_attention, actual_attention_weights) * loss_multiplier
+
+  def combine_attentions(attention_list):
+    """Combine different layer attentions and then average over layers/heads."""
+    # Stack all hidden layer attention tensors to get a tensor with shape
+    # [num_hidden_layers, batch_size, num_heads, target_length, input_length].
+    attentions = tf.stack(attention_list)
+    # Reduce mean across all layers (axis=0) and all heads (axis=2) to get a
+    # tensor with shape [batch_size, target_length, input_length].
+    return tf.reduce_mean(attentions, [0, 2])
+
+  def kl_divergence_loss(expected_logits, actual_logits):
+    p = tf.contrib.distributions.Categorical(logits=expected_logits)
+    q = tf.contrib.distributions.Categorical(logits=actual_logits)
+    return tf.contrib.distributions.kl_divergence(p, q)
+
+  def mse_loss(expected_logits, actual_weights):
+    expected_weights = tf.nn.softmax(expected_logits)
+    return tf.losses.mean_squared_error(expected_weights, actual_weights)
+
+  # For each hidden layer, we have attention-logit and attention-weight tensors
+  # with shape [batch_size, num_heads, target_length, input_length].
+  loss = 0.0
+  if loss_type == "mse":
+    actual_encdec_attention_weights = [
+        t for layer_key, t in actual_attentions.items()
+        if "encdec_attention" in layer_key and not layer_key.endswith("/logits")
+    ]
+    actual_attention_weights = combine_attentions(
+        actual_encdec_attention_weights)
+    loss = mse_loss(expected_attention_logits, actual_attention_weights)
+  else:
+    actual_encdec_attention_logits = [
+        t for layer_key, t in actual_attentions.items()
+        if "encdec_attention" in layer_key and layer_key.endswith("/logits")
+    ]
+    actual_attention_logits = combine_attentions(
+        actual_encdec_attention_logits)
+    loss = kl_divergence_loss(expected_attention_logits,
+                              actual_attention_logits)
+  return loss * loss_multiplier
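For intuition on the new `kl_divergence` loss type: the `Categorical` distributions are built from logits, so the KL term compares the expected and actual attention distributions at each target position. A self-contained NumPy sketch of the same quantity (names here are ours, not T2T's):

```python
import numpy as np

def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

# One target position attending over 4 source positions.
expected_logits = np.array([2.0, 0.0, 0.0, -1.0])
actual_logits = np.array([1.0, 0.5, 0.0, -0.5])

p = softmax(expected_logits)  # expected attention distribution
q = softmax(actual_logits)    # model's attention distribution
kl = np.sum(p * (np.log(p) - np.log(q)))  # KL(p || q) >= 0
print(kl)
```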
 
 
 @expert_utils.add_name_scope()
@@ -1305,6 +1332,7 @@ def dot_product_attention(q,
     weights = tf.nn.softmax(logits, name="attention_weights")
     if save_weights_to is not None:
       save_weights_to[scope.name] = weights
+      save_weights_to[scope.name + "/logits"] = logits
     # dropping out the attention links for each of the heads
     weights = common_layers.dropout_with_broadcast_dims(
         weights, 1.0 - dropout_rate, broadcast_dims=dropout_broadcast_dims)
diff --git a/tensor2tensor/layers/common_layers.py b/tensor2tensor/layers/common_layers.py
index fba76a342..7a999d3b4 100644
--- a/tensor2tensor/layers/common_layers.py
+++ b/tensor2tensor/layers/common_layers.py
@@ -1618,7 +1618,8 @@ def padded_cross_entropy(logits,
                          labels,
                          label_smoothing,
                          weights_fn=weights_nonzero,
-                         reduce_sum=True):
+                         reduce_sum=True,
+                         gaussian=False):
   """Compute cross-entropy assuming 0s are padding.
 
   Computes a loss numerator (the sum of losses), and loss denominator
@@ -1631,12 +1632,19 @@ def padded_cross_entropy(logits,
     label_smoothing: a floating point `Scalar`.
     weights_fn: A function from labels to weights.
     reduce_sum: a Boolean, whether to sum at the end or not.
+    gaussian: If true, use a Gaussian distribution for label smoothing.
 
   Returns:
     loss_numerator: a `Scalar`.  Sum of losses.
     loss_denominator: a `Scalar`.  The number of non-padding target tokens.
+
+  Raises:
+    ValueError: in case of unsupported argument types.
   """
   if isinstance(logits, FactoredTensor):
+    if gaussian:
+      raise ValueError("Factored padded cross entropy with Gaussian smoothing "
+                       "is not implemented yet.")
     return padded_cross_entropy_factored(
         logits,
         labels,
@@ -1653,7 +1661,8 @@ def padded_cross_entropy(logits,
     labels = tf.reshape(labels, [-1])
   else:
     logits, labels = pad_with_zeros(logits, labels)
-  xent = smoothing_cross_entropy(logits, labels, vocab_size, confidence)
+  xent = smoothing_cross_entropy(logits, labels, vocab_size, confidence,
+                                 gaussian=gaussian)
   weights = weights_fn(labels)
   if not reduce_sum:
     return xent * weights, weights
@@ -1688,7 +1697,7 @@ def smoothing_cross_entropy(logits,
         confidence * tf.log(confidence) + tf.to_float(vocab_size - 1) *
         low_confidence * tf.log(low_confidence + 1e-20))
 
-  if gaussian:
+  if gaussian and confidence > 0.0:
     labels = tf.cast(labels, tf.float32)
 
     normal_dist = tf.distributions.Normal(loc=labels, scale=confidence)
@@ -2629,3 +2638,30 @@ def dense(x, units, **kwargs):
     return _recompute_grad(fn, [x])
   else:
     return fn(x)
+
+
+def mix(x1, x2, steps, is_training,
+        min_prob=0.0, max_prob=1.0, mode="lin", simple=False):
+  """Mix x1 and x2, starting with x2 and gradually going towards x1."""
+  if not is_training:
+    return x1
+
+  def get_res():
+    """Create the result.
+
+    Separate function to speed it up later (see below).
+    """
+    if mode == "lin":
+      alpha_p = inverse_lin_decay(steps)
+    else:
+      alpha_p = inverse_exp_decay(steps)
+    alpha_p = alpha_p * (max_prob - min_prob) + min_prob
+    if simple:
+      return alpha_p * x1 + (1.0 - alpha_p) * x2
+    alpha = tf.random_uniform(shape_list(x1))
+    alpha = tf.to_float(tf.less(alpha, alpha_p))
+    return alpha * x1 + (1.0 - alpha) * x2
+
+  if max_prob < 1.0:
+    return get_res()
+
+  # Prevent sampling after steps is passed to speed it up.
+  return tf.cond(tf.less(tf.train.get_global_step(), steps),
+                 get_res, lambda: x1)
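`mix` is used by the new autoencoders as a warm-up schedule: early in training the result is mostly `x2`, and as the global step approaches `steps` it shifts to `x1`, either as a deterministic blend (`simple=True`) or by sampling an elementwise Bernoulli mask. A hypothetical usage sketch, assuming a graph that already has a global step:

```python
import tensorflow as tf
from tensor2tensor.layers import common_layers

tf.train.get_or_create_global_step()  # mix() reads the global step
x1 = tf.zeros([8, 16])            # e.g. bottlenecked activations
x2 = tf.random_normal([8, 16])    # e.g. raw activations
# Early in training the result is mostly x2; by step 3000 it is x1.
out = common_layers.mix(x1, x2, steps=3000, is_training=True, simple=True)
```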
diff --git a/tensor2tensor/layers/discretization.py b/tensor2tensor/layers/discretization.py
index ccb00ab6b..16f21473a 100644
--- a/tensor2tensor/layers/discretization.py
+++ b/tensor2tensor/layers/discretization.py
@@ -63,7 +63,14 @@ def slice_hidden(x, hidden_size, num_blocks):
   return x_sliced
 
 
-def nearest_neighbor(x, means, block_v_size, random_top_k=1):
+def nearest_neighbor(x,
+                     means,
+                     block_v_size,
+                     random_top_k=1,
+                     soft_em=False,
+                     inv_temp=1.0,
+                     ema_count=None,
+                     c_probs=None):
   """Find the nearest element in means to elements in x.
 
   Args:
     x: Batch of encoder continuous latent states sliced/projected into shape
       [-1, num_blocks, block_dim].
     means: Embedding table of shape [num_blocks, block_v_size, block_dim].
     block_v_size: Number of table entries per block.
     random_top_k: Noisy top-k if this is bigger than 1 (Default: 1).
-
+    soft_em: If True then use soft EM rather than hard EM (Default: False).
+    inv_temp: Inverse temperature for soft EM (Default: 1.0).
+    ema_count: Table of counts for each embedding corresponding to how many
+      examples in a batch it was the closest to (Default: None).
+    c_probs: Precomputed probabilities of clusters may be given, for example
+      in the case of smoothed L0 priors.
   Returns:
     Tensor with nearest element in mean encoded in one-hot notation.
   """
@@ -83,20 +95,40 @@ def nearest_neighbor(x, means, block_v_size, random_top_k=1):
     scalar_prod = tf.transpose(scalar_prod, perm=[1, 0, 2])
     dist = x_norm_sq + tf.transpose(
         means_norm_sq, perm=[2, 0, 1]) - 2 * scalar_prod
-    if random_top_k > 1:
-      _, top_k_idx = tf.nn.top_k(-dist, k=random_top_k)
-      nearest_idx = tf.gather(
-          top_k_idx,
-          tf.random_uniform(
-              [1], minval=0, maxval=random_top_k - 1, dtype=tf.int32),
-          axis=-1)
+    # Computing cluster probabilities.
+    if soft_em or c_probs is not None:
+      if c_probs is not None:
+        # Expand dims to match inv_temp.
+        c_probs = tf.expand_dims(c_probs, 0)
+      else:
+        ema_count = tf.expand_dims(ema_count + 1., 0)
+        c_probs = ema_count / tf.reduce_sum(ema_count, 2, keepdims=True)
+    if soft_em:
+      nearest_hot = tf.nn.softmax(-inv_temp * dist, axis=-1) * c_probs
+      nearest_hot /= tf.reduce_sum(nearest_hot, 2, keepdims=True)
     else:
-      nearest_idx = tf.argmax(-dist, axis=-1)
-    nearest_hot = tf.one_hot(nearest_idx, block_v_size)
-    return tf.stop_gradient(nearest_hot)
-
-
-def embedding_lookup(x, means, num_blocks, block_v_size, random_top_k=1):
+      if random_top_k > 1:
+        _, top_k_idx = tf.nn.top_k(-dist, k=random_top_k)
+        nearest_idx = tf.gather(
+            top_k_idx,
+            tf.random_uniform(
+                [1], minval=0, maxval=random_top_k - 1, dtype=tf.int32),
+            axis=-1)
+      else:
+        nearest_idx = tf.argmax(-dist, axis=-1)
+      nearest_hot = tf.one_hot(nearest_idx, block_v_size)
+    return nearest_hot
+
+
+def embedding_lookup(x,
+                     means,
+                     num_blocks,
+                     block_v_size,
+                     random_top_k=1,
+                     soft_em=False,
+                     inv_temp=1.0,
+                     ema_count=None,
+                     c_probs=None):
   """Compute nearest neighbors and loss for training the embeddings via DVQ.
 
   Args:
@@ -106,12 +138,19 @@ def embedding_lookup(x, means, num_blocks, block_v_size, random_top_k=1):
     num_blocks: Number of blocks in DVQ.
     block_v_size: Number of table entries per block.
     random_top_k: Noisy top-k if this is bigger than 1 (Default: 1).
+    soft_em: If True then use soft EM rather than hard EM (Default: False).
+    inv_temp: Inverse temperature for soft EM (Default: 1.0).
+    ema_count: Table of counts for each embedding corresponding to how many
+      examples in a batch it was the closest to (Default: None).
+    c_probs: Precomputed cluster probabilities may be passed, for example in
+      the case of smoothed L0.
 
   Returns:
     The nearest neighbor in one hot form, the nearest neighbor itself, the
     commitment loss, embedding training loss.
   """
-  x_means_hot = nearest_neighbor(x, means, block_v_size, random_top_k)
+  x_means_hot = nearest_neighbor(x, means, block_v_size, random_top_k, soft_em,
+                                 inv_temp, ema_count, c_probs)
   x_means_hot_flat = tf.reshape(x_means_hot, [-1, num_blocks, block_v_size])
   x_means = tf.matmul(tf.transpose(x_means_hot_flat, perm=[1, 0, 2]), means)
   x_means = tf.transpose(x_means, [1, 0, 2])
@@ -366,6 +405,8 @@ def discrete_bottleneck(x,
                         decay=0.999,
                         discrete_mix=0.5,
                         random_top_k=1,
+                        soft_em=False,
+                        inv_temp=1.0,
                         epsilon=1e-5,
                         softmax_k=0,
                         kl_warmup_steps=150000,
@@ -375,7 +416,11 @@ def discrete_bottleneck(x,
                         summary=True,
                         dp_strength=1.0,
                         dp_decay=1.0,
-                        dp_alpha=0.5):
+                        dp_alpha=0.5,
+                        slo=False,
+                        slo_alpha=10,
+                        slo_beta=0.5,
+                        c_logits=None):
   """Discretization bottleneck for latent variables.
 
   Args:
@@ -402,6 +447,8 @@ def discrete_bottleneck(x,
     discrete_mix: Factor for mixing discrete and non-discrete input for semhash
       (Default: 0.5).
     random_top_k: Noisy top-k for DVQ (Default: 1).
+    soft_em: If True then use soft EM rather than hard EM (Default: False).
+    inv_temp: Inverse temperature for soft EM (Default: 1.0).
     epsilon: Epsilon parameter for DVQ (Default: 1e-5).
     softmax_k: If > 1 then do top-k softmax (Default: 0).
     kl_warmup_steps: Number of steps for kl warmup (Default: 150000).
@@ -415,6 +462,11 @@ def discrete_bottleneck(x,
     dp_decay: Decay the dp_strength using an exponential decay using this
       term (Default: 1.0).
     dp_alpha: Alpha term (pseudo-count) in Dirichlet Process (Default: 0.5).
+    slo: If True, use a smoothed L0 prior on the cluster probabilities.
+    slo_alpha: Alpha for the smoothed L0 prior.
+    slo_beta: Beta for the smoothed L0 prior.
+    c_logits: A [num_blocks, block_v_size] tensor of logits for
+      computing cluster probabilities.
 
   Returns:
     Embedding to pass to the decoder, discrete latent, loss, and the embedding
@@ -500,9 +552,13 @@ def discrete_bottleneck(x,
       c = tf.argmax(hot, axis=-1)
       h1 = tf.layers.dense(hot, hidden_size, name='dae_dense')
     elif bottleneck_kind == 'dvq':
+      c_probs = None
+      if c_logits is not None:
+        c_probs = tf.nn.softmax(c_logits, axis=-1)
       x_reshaped = reshape_fn(x)
       x_means_hot, x_means, q_loss, e_loss = embedding_lookup(
-          x_reshaped, means, num_blocks, block_v_size, random_top_k)
+          x_reshaped, means, num_blocks, block_v_size, random_top_k, soft_em,
+          inv_temp, ema_count, c_probs)
 
       # Get the discrete latent representation
       x_means_idx = tf.argmax(x_means_hot, axis=-1)
@@ -535,6 +591,7 @@ def discrete_bottleneck(x,
         # Adding a term that puts a Dirichlet prior over cluster probabilities
         # Hopefully it'll encourage rich-get-richer behaviors
         dp_prior_loss = 0.
+        slo_loss = 0.
         if dp_strength > 0.0:
           # Decay dp_strength over time to make it less important
           dp_strength = tf.train.exponential_decay(
@@ -548,6 +605,13 @@ def discrete_bottleneck(x,
           dp_prior_loss = -1.0 * tf.reduce_sum(dp_prior_loss)
           dp_prior_loss /= (num_blocks * block_v_size)
 
+        # If using smoothed L0:
+        if slo:
+          # Expected log likelihood.
+          ell = tf.reduce_sum(ema_count * tf.log(c_probs))
+          # The prior component in the loss for MAP EM.
+          slo_prior = slo_alpha * tf.reduce_sum(tf.exp(-1. * c_probs / slo_beta))
+          slo_loss = -1. * (ell + slo_prior) / (num_blocks * block_v_size)
         x_means_hot_flat = tf.reshape(
             x_means_hot, shape=[-1, num_blocks, block_v_size])
         dw = tf.matmul(
@@ -563,7 +627,7 @@ def discrete_bottleneck(x,
         with tf.control_dependencies([e_loss]):
           update_means = tf.assign(means, updated_ema_means)
           with tf.control_dependencies([update_means]):
-            l = beta * e_loss + dp_strength * dp_prior_loss
+            l = beta * e_loss + dp_strength * dp_prior_loss + slo_loss
       else:
         l = q_loss + beta * e_loss
diff --git a/tensor2tensor/layers/modalities.py b/tensor2tensor/layers/modalities.py
index 8e1cbd5fb..57228ada3 100644
--- a/tensor2tensor/layers/modalities.py
+++ b/tensor2tensor/layers/modalities.py
@@ -194,7 +194,7 @@ def loss(self, logits, targets):
 @registry.register_image_modality("default")
 class ImageModality(modality.Modality):
   """Modality for images."""
-  NUM_CHANNELS = 3
+  PIXEL_EMBEDDING_SIZE = 64
 
   def bottom(self, inputs):
     with tf.variable_scope(self.name):
@@ -205,35 +205,50 @@ def bottom(self, inputs):
 
   def targets_bottom(self, inputs):
     with tf.variable_scope(self.name):
-      # Reshape inputs to 2-d tensor and embed the RGB pixel values.
-      ret = common_layers.embedding(
-          tf.to_int32(common_layers.flatten4d3d(inputs)),
-          self.top_dimensionality,
-          self._body_input_depth,
-          name="input_rgb_embedding")
-      if self._model_hparams.multiply_embedding_mode == "sqrt_depth":
-        ret *= self._body_input_depth**0.5
-
-      reshape_shape = common_layers.shape_list(inputs)[:3]
-      reshape_shape.append(self._body_input_depth * 3)
-      ret = tf.reshape(ret, reshape_shape)
-      return tf.layers.dense(ret, self._body_input_depth)
+      if not context.in_eager_mode():
+        tf.summary.image("targets_bottom",
+                         tf.cast(inputs, tf.uint8), max_outputs=1)
+      inputs_shape = common_layers.shape_list(inputs)
+      if len(inputs_shape) != 4:
+        raise ValueError("Assuming images given as int tensors in the format "
+                         "[batch, height, width, channels] (256 values).")
+      # We embed each of 256=self.top_dimensionality possible pixel values.
+      embedding_var = tf.get_variable(
+          "pixel_embedding",
+          [self.top_dimensionality, self.PIXEL_EMBEDDING_SIZE])
+      hot_inputs = tf.one_hot(tf.to_int32(inputs), self.top_dimensionality)
+      hot_inputs = tf.reshape(hot_inputs, [-1, self.top_dimensionality])
+      embedded = tf.matmul(hot_inputs, embedding_var)
+      # Let's now merge all channels that were embedded into a single vector.
+      merged_size = self.PIXEL_EMBEDDING_SIZE * inputs_shape[3]
+      embedded = tf.reshape(embedded, inputs_shape[:3] + [merged_size])
+      merged = tf.layers.dense(embedded, self._body_input_depth,
+                               name="merge_pixel_embedded_channels")
+      return merged
 
   def top(self, body_output, _):
+    # TODO(lukaszkaiser): is this a universal enough way to get channels?
+    num_channels = self._model_hparams.problem_instances[0].num_channels
     with tf.variable_scope("rgb_softmax"):
       body_output_shape = common_layers.shape_list(body_output)
       reshape_shape = body_output_shape[:3]
-      dim = body_output_shape[-1] // 3
-      reshape_shape.extend([self.NUM_CHANNELS, dim])
-
-      out = tf.reshape(body_output, reshape_shape)
-      res = tf.layers.dense(out, self.top_dimensionality)
+      reshape_shape.extend([num_channels, self.top_dimensionality])
+      res = tf.layers.dense(body_output, self.top_dimensionality * num_channels)
+      res = tf.reshape(res, reshape_shape)
       if not tf.get_variable_scope().reuse:
         res_argmax = tf.cast(tf.argmax(res, axis=-1), tf.uint8)
         tf.summary.image("result", res_argmax, max_outputs=1)
       return res
 
+  def loss(self, logits, targets):
+    """Compute loss numerator and denominator for one shard of output."""
+    return common_layers.padded_cross_entropy(
+        logits,
+        targets,
+        self._model_hparams.label_smoothing,
+        weights_fn=self.targets_weights_fn,
+        gaussian=True)
+
 
 @registry.register_image_modality("image_channel_compress")
 class ImageChannelCompressModality(modality.Modality):
diff --git a/tensor2tensor/models/__init__.py b/tensor2tensor/models/__init__.py
index 075840f2f..301c6d42a 100644
--- a/tensor2tensor/models/__init__.py
+++ b/tensor2tensor/models/__init__.py
@@ -41,6 +41,7 @@
 from tensor2tensor.models.research import aligned
 from tensor2tensor.models.research import attention_lm
 from tensor2tensor.models.research import attention_lm_moe
+from tensor2tensor.models.research import autoencoders
 from tensor2tensor.models.research import basic_conv_gen
 from tensor2tensor.models.research import cycle_gan
 from tensor2tensor.models.research import gene_expression
diff --git a/tensor2tensor/models/basic.py b/tensor2tensor/models/basic.py
index 35ae204b5..42d5f12db 100644
--- a/tensor2tensor/models/basic.py
+++ b/tensor2tensor/models/basic.py
@@ -44,6 +44,52 @@ def body(self, features):
     return tf.expand_dims(tf.expand_dims(x, axis=1), axis=1)  # 4D For T2T.
 
 
+@registry.register_model
+class BasicAutoencoder(t2t_model.T2TModel):
+  """A basic autoencoder, try with image_mnist_rev or image_cifar10_rev."""
+
+  def bottleneck(self, x, res_size):
+    hparams = self._hparams
+    x = tf.layers.dense(x, hparams.bottleneck_size, name="bottleneck")
+    x = tf.nn.dropout(x, keep_prob=1.0 - hparams.dropout)
+    x = tf.layers.dense(x, res_size, name="unbottleneck")
+    return x
+
+  def body(self, features):
+    hparams = self._hparams
+    is_training = hparams.mode == tf.estimator.ModeKeys.TRAIN
+    x = features["targets"]
+    shape = common_layers.shape_list(x)
+    kernel = (hparams.kernel_height, hparams.kernel_width)
+    is1d = shape[2] == 1
+    kernel = (hparams.kernel_height, 1) if is1d else kernel
+    strides = (2, 1) if is1d else (2, 2)
+    x, _ = common_layers.pad_to_same_length(
+        x, x, final_length_divisible_by=2**hparams.num_hidden_layers, axis=1)
+    if not is1d:
+      x, _ = common_layers.pad_to_same_length(
+          x, x, final_length_divisible_by=2**hparams.num_hidden_layers, axis=2)
+    # Down-convolutions.
+    for i in xrange(hparams.num_hidden_layers):
+      x = tf.layers.conv2d(
+          x, hparams.hidden_size * 2**(i + 1), kernel, strides=strides,
+          padding="SAME", activation=tf.nn.relu, name="conv_%d" % i)
+      x = common_layers.layer_norm(x)
+    # Bottleneck (mix during early training, not too important but very stable).
+    b = self.bottleneck(x, hparams.hidden_size * 2**hparams.num_hidden_layers)
+    x = common_layers.mix(b, x, hparams.bottleneck_warmup_steps, is_training)
+    # Up-convolutions.
+    for i in xrange(hparams.num_hidden_layers):
+      j = hparams.num_hidden_layers - i - 1
+      x = tf.layers.conv2d_transpose(
+          x, hparams.hidden_size * 2**j, kernel, strides=strides,
+          padding="SAME", activation=tf.nn.relu, name="deconv_%d" % j)
+      x = common_layers.layer_norm(x)
+    res = x[:, :shape[1], :shape[2], :]
+    return common_layers.mix(res, features["targets"],
+                             hparams.bottleneck_warmup_steps // 2, is_training)
+
+
 @registry.register_hparams
 def basic_fc_small():
   """Small fully connected model."""
@@ -57,3 +103,26 @@ def basic_fc_small():
   hparams.weight_decay = 0.0
   hparams.dropout = 0.0
   return hparams
+
+
+@registry.register_hparams
+def basic_autoencoder():
+  """Basic autoencoder model."""
+  hparams = common_hparams.basic_params1()
+  hparams.optimizer = "Adam"
+  hparams.learning_rate_constant = 0.0002
+  hparams.learning_rate_warmup_steps = 500
+  hparams.learning_rate_schedule = "constant * linear_warmup"
+  hparams.label_smoothing = 0.05
+  hparams.batch_size = 128
+  hparams.hidden_size = 64
+  hparams.num_hidden_layers = 4
+  hparams.initializer = "uniform_unit_scaling"
+  hparams.initializer_gain = 1.0
+  hparams.weight_decay = 0.0
+  hparams.kernel_height = 4
+  hparams.kernel_width = 4
+  hparams.dropout = 0.1
+  hparams.add_hparam("bottleneck_size", 128)
+  hparams.add_hparam("bottleneck_warmup_steps", 3000)
+  return hparams
diff --git a/tensor2tensor/models/image_transformer.py b/tensor2tensor/models/image_transformer.py
index dbb58d0b1..0f6244e36 100644
--- a/tensor2tensor/models/image_transformer.py
+++ b/tensor2tensor/models/image_transformer.py
@@ -654,12 +654,25 @@ def update_hparams_for_tpu(hparams):
   hparams.batch_size = 4
 
 
+@registry.register_hparams
+def imagetransformer_base_tpu():
+  """Hparams for training imagetransformer on tpu."""
+  hparams = imagetransformer_base()
+  update_hparams_for_tpu(hparams)
+  hparams.batch_size = 4
+  hparams.num_heads = 4   # heads are expensive on tpu
+  hparams.hidden_size = 256
+  hparams.filter_size = 512
+  hparams.num_hidden_layers = 8
+  hparams.sampling_method = "random"
+  return hparams
+
+
 @registry.register_hparams
 def imagetransformer_sep_channels_8l_tpu():
   """Hparams for training imagetransformer on tpu."""
   hparams = imagetransformer_sep_channels_8l()
   update_hparams_for_tpu(hparams)
-  hparams.batch_size = 1
+  hparams.batch_size = 4
   hparams.num_heads = 4   # heads are expensive on tpu
   hparams.shared_embedding_and_softmax_weights = False
   return hparams
diff --git a/tensor2tensor/models/image_transformer_2d.py b/tensor2tensor/models/image_transformer_2d.py
index 83166a937..101126d31 100644
--- a/tensor2tensor/models/image_transformer_2d.py
+++ b/tensor2tensor/models/image_transformer_2d.py
@@ -453,11 +453,11 @@ def update_hparams_for_tpu(hparams):
 
 
 @registry.register_hparams
-def img2mg_transformer_base_tpu():
+def img2img_transformer_base_tpu():
   """Hparams for training img2img_transformer on tpu."""
   hparams = img2img_transformer_base()
   update_hparams_for_tpu(hparams)
-  hparams.batch_size = 4
+  hparams.batch_size = 2
   hparams.num_heads = 4   # heads are expensive on tpu
   hparams.num_decoder_layers = 8
   hparams.num_encoder_layers = 4
@@ -466,8 +466,8 @@ def img2img_transformer_base_tpu():
 
 
 @registry.register_hparams
-def img2mg_transformer_tiny_tpu():
-  hparams = img2mg_transformer_base_tpu()
+def img2img_transformer_tiny_tpu():
+  hparams = img2img_transformer_base_tpu()
   hparams.num_hidden_layers = 2
   hparams.hidden_size = 16
   hparams.batch_size = 2
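The `mix`-based warm-ups in `BasicAutoencoder` above (and in `BasicDiscreteAutoencoder` below) stage the difficulty: for roughly the first `bottleneck_warmup_steps` the decoder sees mostly un-bottlenecked activations (and even raw targets), after which the bottleneck takes over. Assuming `inverse_lin_decay` ramps roughly linearly to 1.0 at the given step, the mixing probability behaves like this toy sketch (not T2T code):

```python
# Probability that an element of the mixed result comes from x1
# (the bottlenecked path) at a given training step.
def mix_prob(step, warmup_steps):
    return min(step / float(warmup_steps), 1.0)

for step in (0, 1500, 3000, 6000):
    print(step, mix_prob(step, 3000))  # -> 0.0, 0.5, 1.0, 1.0
```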
diff --git a/tensor2tensor/models/research/autoencoders.py b/tensor2tensor/models/research/autoencoders.py
new file mode 100644
index 000000000..67690f551
--- /dev/null
+++ b/tensor2tensor/models/research/autoencoders.py
@@ -0,0 +1,53 @@
+# coding=utf-8
+# Copyright 2018 The Tensor2Tensor Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Autoencoders."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+# Dependency imports
+
+from tensor2tensor.layers import common_layers
+from tensor2tensor.models import basic
+from tensor2tensor.utils import registry
+
+import tensorflow as tf
+
+
+@registry.register_model
+class BasicDiscreteAutoencoder(basic.BasicAutoencoder):
+  """Basic autoencoder with a binarized (straight-through) bottleneck."""
+
+  def bottleneck(self, x, res_size):
+    hparams = self._hparams
+    x = tf.tanh(tf.layers.dense(x, hparams.bottleneck_size, name="bottleneck"))
+    d = x + tf.stop_gradient(2 * tf.to_float(tf.less(0.0, x)) - 1.0 - x)
+    y = tf.nn.dropout(x, keep_prob=1.0 - hparams.dropout)
+    x = common_layers.mix(d, y, hparams.discretize_warmup_steps,
+                          hparams.mode == tf.estimator.ModeKeys.TRAIN)
+    x = tf.layers.dense(x, res_size, name="unbottleneck")
+    return x
+
+
+@registry.register_hparams
+def basic_discrete_autoencoder():
+  """Basic discrete autoencoder model."""
+  hparams = basic.basic_autoencoder()
+  hparams.hidden_size = 128
+  hparams.bottleneck_size = 512
+  hparams.bottleneck_warmup_steps = 3000
+  hparams.add_hparam("discretize_warmup_steps", 5000)
+  return hparams
diff --git a/tensor2tensor/models/research/basic_conv_gen.py b/tensor2tensor/models/research/basic_conv_gen.py
index c04e0b891..b0235eb25 100644
--- a/tensor2tensor/models/research/basic_conv_gen.py
+++ b/tensor2tensor/models/research/basic_conv_gen.py
@@ -1,4 +1,3 @@
-
 # coding=utf-8
 # Copyright 2018 The Tensor2Tensor Authors.
 #
@@ -34,22 +33,30 @@ class BasicConvGen(t2t_model.T2TModel):
 
   def body(self, features):
-    print(features)
     filters = self.hparams.hidden_size
     cur_frame = tf.to_float(features["inputs"])
     prev_frame = tf.to_float(features["inputs_prev"])
-    print(features["inputs"].shape, cur_frame.shape, prev_frame.shape)
+    action_embedding_size = 32
+    action_space_size = 10
+    kernel = (3, 3)
+    # Gather all inputs.
     action = common_layers.embedding(tf.to_int64(features["action"]),
-                                     10, filters)
-    action = tf.reshape(action, [-1, 1, 1, filters])
-
-    frames = tf.concat([cur_frame, prev_frame], axis=3)
-    h1 = tf.layers.conv2d(frames, filters, kernel_size=(3, 3), padding="SAME")
-    h2 = tf.layers.conv2d(tf.nn.relu(h1 + action), filters,
-                          kernel_size=(5, 5), padding="SAME")
-    res = tf.layers.conv2d(tf.nn.relu(h2 + action), 3 * 256,
-                           kernel_size=(3, 3), padding="SAME")
-
+                                     action_space_size, action_embedding_size)
+    action = tf.reshape(action, [-1, 1, 1, action_embedding_size])
+    # Broadcast the action embedding across the spatial dimensions before
+    # concatenating (tf.concat does not broadcast).
+    action = tf.tile(action, [1, tf.shape(cur_frame)[1],
+                              tf.shape(cur_frame)[2], 1])
+    frames = tf.concat([cur_frame, prev_frame, action], axis=3)
+    x = tf.layers.conv2d(frames, filters, kernel, activation=tf.nn.relu,
+                         strides=(2, 2), padding="SAME")
+    # Run a stack of convolutions.
+    for _ in xrange(self.hparams.num_hidden_layers):
+      y = tf.layers.conv2d(x, filters, kernel, activation=tf.nn.relu,
+                           strides=(1, 1), padding="SAME")
+      x = common_layers.layer_norm(x + y)
+    # Up-convolve.
+    x = tf.layers.conv2d_transpose(
+        x, filters, kernel, activation=tf.nn.relu,
+        strides=(2, 2), padding="SAME")
+    # Output size is 3 * 256 for 3-channel color space.
+    res = tf.layers.conv2d(x, 3 * 256, kernel, padding="SAME")
     height = tf.shape(res)[1]
     width = tf.shape(res)[2]
     res = tf.reshape(res, [-1, height, width, 3, 256])
@@ -58,7 +65,7 @@ def body(self, features):
 
 @registry.register_hparams
 def basic_conv_small():
-  # """Small conv model."""
+  """Small conv model."""
   hparams = common_hparams.basic_params1()
   hparams.hidden_size = 32
   hparams.batch_size = 2
diff --git a/tensor2tensor/models/research/transformer_vae.py b/tensor2tensor/models/research/transformer_vae.py
index a5bb3ff85..ab15b31af 100644
--- a/tensor2tensor/models/research/transformer_vae.py
+++ b/tensor2tensor/models/research/transformer_vae.py
@@ -477,6 +477,8 @@ def __init__(self, *args, **kwargs):
         decay=self._hparams.decay,
         discrete_mix=self._hparams.d_mix,
         random_top_k=self._hparams.random_top_k,
+        soft_em=self._hparams.soft_em,
+        inv_temp=self._hparams.inv_temp,
         epsilon=self._hparams.epsilon,
         softmax_k=self._hparams.softmax_k,
         kl_warmup_steps=self._hparams.kl_warmup_steps,
@@ -484,8 +486,10 @@ def __init__(self, *args, **kwargs):
         summary=_DO_SUMMARIES,
         dp_strength=self._hparams.dp_strength,
         dp_decay=self._hparams.dp_decay,
-        dp_alpha=self._hparams.dp_alpha)
-
+        dp_alpha=self._hparams.dp_alpha,
+        slo=self._hparams.slo,
+        slo_alpha=self._hparams.slo_alpha,
+        slo_beta=self._hparams.slo_beta)
     # Set the discretization bottleneck specific things here
     if self._hparams.bottleneck_kind == "dvq":
       block_dim = int(self._hparams.hidden_size // self._hparams.num_blocks)
@@ -511,7 +515,6 @@ def __init__(self, *args, **kwargs):
         tf.logging.info("Using slices for DVQ")
       else:
         raise ValueError("Unknown reshape method")
-
       means = tf.get_variable(
           name="means",
           shape=[self._hparams.num_blocks, block_v_size, block_dim],
@@ -521,17 +524,27 @@ def __init__(self, *args, **kwargs):
       if self._hparams.ema:
         ema_count = tf.get_variable(
             "ema_count", [self._hparams.num_blocks, block_v_size],
-            initializer=tf.constant_initializer(0))
+            initializer=tf.constant_initializer(0),
+            trainable=False)
         with tf.colocate_with(means):
           ema_means = tf.get_variable(
-              "ema_means", initializer=means.initialized_value())
-
+              "ema_means", initializer=means.initialized_value(),
+              trainable=False)
+
+      # Create the shadow variables if we are using smoothed L0.
+      c_logits = None
+      if self._hparams.slo:
+        # Softmax logits for the cluster probabilities.
+        c_logits = tf.get_variable(
+            "c_logits", [self._hparams.num_blocks, block_v_size],
+            initializer=tf.uniform_unit_scaling_initializer())
       # Update bottleneck
       self._hparams.bottleneck = partial(
           self._hparams.bottleneck,
           means=means,
           ema_count=ema_count,
-          ema_means=ema_means)
+          ema_means=ema_means,
+          c_logits=c_logits)
 
   @property
   def has_input(self):
@@ -639,6 +652,9 @@ def transformer_ae_small():
   hparams.add_hparam("dp_alpha", 0.5)
   hparams.add_hparam("dp_strength", 0.25)
   hparams.add_hparam("dp_decay", 1.0)
+  hparams.add_hparam("slo", False)  # For smoothed L0.
+ hparams.add_hparam("slo_alpha", 0.25) + hparams.add_hparam("slo_beta", 0.5) hparams.add_hparam("unmasked_percentage", 0.1) hparams.add_hparam("do_ae", True) hparams.add_hparam("do_mask", True) @@ -664,6 +680,8 @@ def transformer_ae_small(): hparams.add_hparam("decay", 0.999) hparams.add_hparam("ema", True) hparams.add_hparam("random_top_k", 1) + hparams.add_hparam("soft_em", False) + hparams.add_hparam("inv_temp", 1.0) hparams.kl_warmup_steps = 150000 hparams.force_full_predict = True diff --git a/tensor2tensor/models/transformer.py b/tensor2tensor/models/transformer.py index 2cddbee8e..9e0142fbc 100644 --- a/tensor2tensor/models/transformer.py +++ b/tensor2tensor/models/transformer.py @@ -177,6 +177,7 @@ def body(self, features): if expected_attentions is not None: attention_loss = common_attention.encoder_decoder_attention_loss( expected_attentions, self.attention_weights, + hparams.expected_attention_loss_type, hparams.expected_attention_loss_multiplier) return decoder_output, {"attention_loss": attention_loss} @@ -1468,6 +1469,8 @@ def transformer_librispeech_tpu(): def transformer_supervised_attention(): """Hparams for supervised attention problems.""" hparams = transformer_base() + # Attention loss type (KL-divergence or MSE). + hparams.add_hparam("expected_attention_loss_type", "kl_divergence") # Multiplier to the encoder-decoder expected attention loss. hparams.add_hparam("expected_attention_loss_multiplier", 1.0) return hparams diff --git a/tensor2tensor/models/vanilla_gan.py b/tensor2tensor/models/vanilla_gan.py index 1d66bbba2..100d60549 100644 --- a/tensor2tensor/models/vanilla_gan.py +++ b/tensor2tensor/models/vanilla_gan.py @@ -25,104 +25,170 @@ # Dependency imports from tensor2tensor.layers import common_hparams -from tensor2tensor.layers import common_layers from tensor2tensor.utils import registry from tensor2tensor.utils import t2t_model import tensorflow as tf -def generator(z, hparams, reuse=False): - """Initalizes generator layers.""" +def lrelu(input_, leak=0.2, name="lrelu"): + return tf.maximum(input_, leak * input_, name=name) - g_h1 = tf.layers.dense( - z, hparams.hidden_dim, activation=tf.nn.relu, name="l1", reuse=reuse) - g_log_prob = tf.layers.dense( - g_h1, hparams.height * hparams.width, name="logp", reuse=reuse) - g_prob = tf.nn.sigmoid(g_log_prob) - return g_prob +def deconv2d( + input_, output_shape, k_h, k_w, d_h, d_w, stddev=0.02, name="deconv2d"): + with tf.variable_scope(name): + w = tf.get_variable( + "w", [k_h, k_w, output_shape[-1], input_.get_shape()[-1]], + initializer=tf.random_normal_initializer(stddev=stddev)) + deconv = tf.nn.conv2d_transpose( + input_, w, output_shape=output_shape, strides=[1, d_h, d_w, 1]) + biases = tf.get_variable( + "biases", [output_shape[-1]], initializer=tf.constant_initializer(0.0)) + return tf.reshape(tf.nn.bias_add(deconv, biases), deconv.get_shape()) -def discriminator(x, hparams, reuse=False): - """Initalizes discriminator layers.""" - d_h1 = tf.layers.dense( - x, hparams.hidden_dim, activation=tf.nn.relu, name="d_h1", reuse=reuse) - d_logit = tf.layers.dense(d_h1, 1, name="d_logit", reuse=reuse) - d_prob = tf.nn.sigmoid(d_logit) +class AbstractGAN(t2t_model.T2TModel): + """Base class for all GANs.""" - return d_prob, d_logit + def discriminator(self, x, is_training, reuse=False): + """Discriminator architecture based on InfoGAN. + Args: + x: input images, shape [bs, h, w, channels] + is_training: boolean, are we in train or eval model. + reuse: boolean, should params be re-used. 
-def reverse_grad(x): - return tf.stop_gradient(2 * x) - x + Returns: + out_logit: the output logits (before sigmoid). + """ + hparams = self._hparams + with tf.variable_scope( + "discriminator", reuse=reuse, + initializer=tf.random_normal_initializer(stddev=0.02)): + batch_size = hparams.batch_size + # Mapping x from [bs, h, w, c] to [bs, 1] + net = tf.layers.conv2d(x, 64, (4, 4), strides=(2, 2), + padding="SAME", name="d_conv1") + # [bs, h/2, w/2, 64] + net = lrelu(net) + net = tf.layers.conv2d(net, 128, (4, 4), strides=(2, 2), + padding="SAME", name="d_conv2") + # [bs, h/4, w/4, 128] + if hparams.discriminator_batchnorm: + net = tf.layers.batch_normalization(net, training=is_training, + momentum=0.999, name="d_bn2") + net = lrelu(net) + size = hparams.height * hparams.width + net = tf.reshape(net, [batch_size, size * 8]) # [bs, h * w * 8] + net = tf.layers.dense(net, 1024, name="d_fc3") # [bs, 1024] + if hparams.discriminator_batchnorm: + net = tf.layers.batch_normalization(net, training=is_training, + momentum=0.999, name="d_bn3") + net = lrelu(net) + out_logit = tf.layers.dense(net, 1, name="d_fc4") # [bs, 1] + return out_logit + + def generator(self, z, is_training, reuse=False): + """Generator outputting image in [0, 1].""" + hparams = self._hparams + height = hparams.height + width = hparams.width + batch_size = hparams.batch_size + with tf.variable_scope( + "generator", reuse=reuse, + initializer=tf.random_normal_initializer(stddev=0.02)): + net = tf.layers.dense(z, 1024, name="g_fc1") + net = tf.layers.batch_normalization(net, training=is_training, + momentum=0.999, name="g_bn1") + net = lrelu(net) + net = tf.layers.dense(net, 128 * (height // 4) * (width // 4), + name="g_fc2") + net = tf.layers.batch_normalization(net, training=is_training, + momentum=0.999, name="g_bn2") + net = lrelu(net) + net = tf.reshape(net, [batch_size, height // 4, width // 4, 128]) + net = deconv2d(net, [batch_size, height // 2, width // 2, 64], + 4, 4, 2, 2, name="g_dc3") + net = tf.layers.batch_normalization(net, training=is_training, + momentum=0.999, name="g_bn3") + net = lrelu(net) + net = deconv2d(net, [batch_size, height, width, hparams.c_dim], + 4, 4, 2, 2, name="g_dc4") + out = tf.nn.sigmoid(net) + return out + + def body(self, features): + """Body of the model. + Args: + features: a dictionary with the tensors. -def vanilla_gan_internal(inputs, hparams, train): - with tf.variable_scope("vanilla_gan", reuse=tf.AUTO_REUSE): - batch_size, height, width, _ = common_layers.shape_list(inputs) - assert height == hparams.height - assert width == hparams.width + Returns: + A pair (predictions, losses) where preditions is the generated image + and losses is a dictionary of losses (that get added for the final loss). + """ + features["targets"] = features["inputs"] + is_training = self.hparams.mode == tf.estimator.ModeKeys.TRAIN - # Currently uses only one of RGB - x = inputs - x = x[:, :, :, 0] - x = tf.reshape(x, [batch_size, height * width]) + # Input images. + inputs = features["inputs"] - # Generate a fake image + # Noise vector. z = tf.random_uniform( - shape=[batch_size, hparams.random_sample_size], + shape=[self._hparams.batch_size, self._hparams.z_size], minval=-1, maxval=1, name="z") - g_sample = generator(z, hparams) - - # Discriminate on the real image - d_real, _ = discriminator(x, hparams) - - # Discriminate on the fake image - d_fake, _ = discriminator(reverse_grad(g_sample), hparams, reuse=True) - # GAN losses - d_loss = -tf.reduce_mean( - tf.log(d_real + hparams.epsilon) + tf.log(1. 
-        tf.log(d_real + hparams.epsilon) + tf.log(1. - d_fake))
-    g_loss = -tf.reduce_mean(tf.log(d_fake + hparams.epsilon))
-
-    losses = {}
+    # Discriminator output for real images.
+    d_real_logits = self.discriminator(
+        inputs, is_training=is_training, reuse=False)
+
+    # Discriminator output for fake images.
+    g = self.generator(z, is_training=is_training, reuse=False)
+    d_fake_logits_g = self.discriminator(
+        g, is_training=is_training, reuse=True)
+    # Discriminator doesn't backprop to generator.
+    d_fake_logits_d = self.discriminator(
+        tf.stop_gradient(g), is_training=is_training, reuse=True)
+
+    # Loss on real and fake data.
+    d_loss_real = tf.reduce_mean(
+        tf.nn.sigmoid_cross_entropy_with_logits(
+            logits=d_real_logits, labels=tf.ones_like(d_real_logits)))
+    d_loss_fake_g = tf.reduce_mean(
+        tf.nn.sigmoid_cross_entropy_with_logits(
+            logits=d_fake_logits_g, labels=tf.zeros_like(d_fake_logits_g)))
+    d_loss_fake_d = tf.reduce_mean(
+        tf.nn.sigmoid_cross_entropy_with_logits(
+            logits=d_fake_logits_d, labels=tf.zeros_like(d_fake_logits_d)))
+    d_loss = d_loss_real + d_loss_fake_d
+
+    losses = {}  # All losses get added at the end.
     losses["discriminator"] = d_loss
-    losses["generator"] = g_loss
-    # Include a dummy training loss to skip self.top and self.loss
+    losses["generator"] = -d_loss_fake_g
+    # Include a dummy training loss to skip self.loss.
     losses["training"] = tf.constant(0., dtype=tf.float32)
-    summary_g_image = tf.reshape(g_sample[0, :], [1, height, width, 1])
+    hparams = self._hparams
+    summary_g_image = tf.reshape(g[0, :], [1, hparams.height, hparams.width, 1])
     tf.summary.image("generated", summary_g_image, max_outputs=1)
-    if train:
+    if is_training:
       # Returns a dummy output and the losses dictionary.
-      return tf.zeros([batch_size, 1]), losses
-    else:
-      return g_sample, losses
+      return tf.zeros_like(inputs), losses
+    return tf.reshape(g, tf.shape(inputs)), losses
+
+  def top(self, body_output, features):
+    """Override the top function to not do anything."""
+    return body_output
 
 
 @registry.register_model
-class VanillaGan(t2t_model.T2TModel):
+class VanillaGan(AbstractGAN):
   """Simple GAN for demonstration."""
 
-  def body(self, features):
-    """Computes the generator and discriminator loss.
-
-    Args:
-      features: A dictionary of key to Tensor. "inputs" should be an image.
-
-    Returns:
-      output: Tensor containing one zero. GANs do not make use of the modality
-        loss.
-      losses: a dictionary of losses containing the generator and discriminator
-        losses.
-    """
-    train = self.hparams.mode == tf.estimator.ModeKeys.TRAIN
-    return vanilla_gan_internal(features["inputs"], self.hparams, train)
-
   def infer(self,
             features=None,
             decode_length=50,
@@ -137,7 +203,7 @@ def infer(self,
         maxval=1,
         name="z")
 
-    g_sample = generator(z, self._hparams)
+    g_sample = self.generator(z, is_training=False)
     return g_sample
 
 
@@ -145,12 +211,12 @@ def infer(self,
 def vanilla_gan():
   """Basic parameters for a vanilla_gan."""
   hparams = common_hparams.basic_params1()
-
-  hparams.batch_size = 32
   hparams.label_smoothing = 0.0
-  hparams.add_hparam("hidden_dim", 128)
-  hparams.add_hparam("random_sample_size", 100)
+  hparams.hidden_size = 128
+  hparams.batch_size = 64
+  hparams.add_hparam("z_size", 64)
+  hparams.add_hparam("c_dim", 1)
   hparams.add_hparam("height", 28)
   hparams.add_hparam("width", 28)
-  hparams.add_hparam("epsilon", 1e-4)
+  hparams.add_hparam("discriminator_batchnorm", int(True))
   return hparams
diff --git a/tensor2tensor/utils/decoding.py b/tensor2tensor/utils/decoding.py
index 3d18b4d10..a81318731 100644
--- a/tensor2tensor/utils/decoding.py
+++ b/tensor2tensor/utils/decoding.py
@@ -43,7 +43,7 @@ def decode_hparams(overrides=""):
   hp = tf.contrib.training.HParams(
       save_images=False,
       problem_idx=0,
-      extra_length=50,
+      extra_length=100,
      batch_size=0,
       beam_size=4,
       alpha=0.6,
@@ -343,7 +343,8 @@ def decode_interactively(estimator, hparams, decode_hp):
   """Interactive decoding."""
 
   def input_fn():
-    gen_fn = make_input_fn_from_generator(_interactive_input_fn(hparams))
+    gen_fn = make_input_fn_from_generator(
+        _interactive_input_fn(hparams, decode_hp))
     example = gen_fn()
     example = _interactive_input_tensor_to_features_dict(example, hparams)
     return example
@@ -405,7 +406,7 @@ def _decode_batch_input_fn(problem_id, num_decode_batches, sorted_inputs,
   }
 
 
-def _interactive_input_fn(hparams):
+def _interactive_input_fn(hparams, decode_hp):
   """Generator that reads from the terminal and yields "interactive inputs".
 
   Due to temporary limitations in tf.learn, if we don't want to reload the
@@ -417,14 +418,15 @@ def _interactive_input_fn(hparams):
 
   Args:
     hparams: model hparams
+    decode_hp: decode hparams
 
   Yields:
    numpy arrays
 
   Raises:
     Exception: when `input_type` is invalid.
   """
-  num_samples = 1
-  decode_length = 100
+  num_samples = decode_hp.num_samples
+  decode_length = decode_hp.extra_length
   input_type = "text"
   problem_id = 0
   p_hparams = hparams.problems[problem_id]
diff --git a/tensor2tensor/utils/metrics.py b/tensor2tensor/utils/metrics.py
index 1f9e2ed00..bb31a4dec 100644
--- a/tensor2tensor/utils/metrics.py
+++ b/tensor2tensor/utils/metrics.py
@@ -23,6 +23,7 @@
 # Dependency imports
 
 import numpy as np
+import six
 
 from tensor2tensor.layers import common_layers
 from tensor2tensor.utils import bleu_hook
@@ -243,22 +244,25 @@ def set_recall(predictions, labels, weights_fn=common_layers.weights_nonzero):
   return tf.to_float(tf.equal(labels, predictions)), weights
 
 
-def image_summary(predictions, hparams):
+def image_summary(predictions, features, hparams):
   """Reshapes predictions and passes it to tensorboard.
 
   Args:
-    predictions: A Tensor of scores of shape [batch, nlabels].
-    hparams: model_hparams
+    predictions: The predicted image (logits).
+    features: The features dictionary with tensors.
+    hparams: model hparams.
 
   Returns:
-    summary_proto: containing the summary image for predictions
-    weights: A Tensor of zeros of shape [batch, nlabels].
+    summary_proto: a summary proto containing the images.
+    weights: A Tensor of zeros of the same shape as predictions.
""" - predictions_reshaped = tf.reshape( - predictions, [-1, hparams.height, hparams.width, hparams.colors]) - return tf.summary.image( - "image_summary", predictions_reshaped, - max_outputs=1), tf.zeros_like(predictions) + del hparams + results = tf.cast(tf.argmax(predictions, axis=-1), tf.uint8) + gold = tf.cast(features["targets"], tf.uint8) + summary1 = tf.summary.image("prediction", results, max_outputs=2) + summary2 = tf.summary.image("data", gold, max_outputs=2) + summary = tf.summary.merge([summary1, summary2]) + return summary, tf.zeros_like(predictions) def create_evaluation_metrics(problems, model_hparams): @@ -318,23 +322,39 @@ def wrapped_metric_fn(): def image_wrapped_metric_fn(predictions, labels, weights_fn=common_layers.weights_nonzero): - _, _ = labels, weights_fn - return metric_fn(predictions, model_hparams) + del weights_fn + return metric_fn(predictions, labels, model_hparams) tm = problem_instance.get_hparams().target_modality - if isinstance(tm, tuple): - tm = registry.create_modality(tm, model_hparams) - weights_fn = tm.targets_weights_fn - - for metric in metrics: - metric_fn = METRICS_FNS[metric] - metric_name = "metrics-%s/%s" % (problem_name, metric) - if metric == Metrics.IMAGE_SUMMARY: - eval_metrics[metric_name] = image_wrapped_metric_fn - else: - problem_metric_fn = make_problem_specific_metric_fn( - metric_fn, problem_idx, weights_fn) - eval_metrics[metric_name] = problem_metric_fn + if isinstance(tm, dict): + for k, v in six.iteritems(tm): + if isinstance(v, tuple): + v = registry.create_modality(v, model_hparams) + weights_fn = v.targets_weights_fn + + for metric in metrics: + metric_fn = METRICS_FNS[metric] + metric_name = "metrics-%s/%s/%s" % (problem_name, k, metric) + if metric == Metrics.IMAGE_SUMMARY: + eval_metrics[metric_name] = image_wrapped_metric_fn + else: + problem_metric_fn = make_problem_specific_metric_fn( + metric_fn, problem_idx, weights_fn) + eval_metrics[metric_name] = problem_metric_fn + else: + if isinstance(tm, tuple): + tm = registry.create_modality(tm, model_hparams) + weights_fn = tm.targets_weights_fn + + for metric in metrics: + metric_fn = METRICS_FNS[metric] + metric_name = "metrics-%s/%s" % (problem_name, metric) + if metric == Metrics.IMAGE_SUMMARY: + eval_metrics[metric_name] = image_wrapped_metric_fn + else: + problem_metric_fn = make_problem_specific_metric_fn( + metric_fn, problem_idx, weights_fn) + eval_metrics[metric_name] = problem_metric_fn return eval_metrics diff --git a/tensor2tensor/utils/t2t_model.py b/tensor2tensor/utils/t2t_model.py index 085cc821f..eef6c5dcb 100644 --- a/tensor2tensor/utils/t2t_model.py +++ b/tensor2tensor/utils/t2t_model.py @@ -92,6 +92,7 @@ def __init__(self, if not problem_hparams and hasattr(hparams, "problems"): problem_hparams = hparams.problems[0] + print(problem_hparams) self._problem_hparams = problem_hparams # Setup hparams @@ -138,7 +139,7 @@ def call(self, features): sharded_logits, losses = self.model_fn_sharded(sharded_features) if isinstance(sharded_logits, dict): concat_logits = {} - for k, v in sharded_logits.iteritems(): + for k, v in six.iteritems(sharded_logits): concat_logits[k] = tf.concat(v, 0) return concat_logits, losses else: @@ -171,7 +172,7 @@ def model_fn_sharded(self, sharded_features): if isinstance(body_out, dict): sharded_logits = {} sharded_losses = {} - for k, v in body_out.iteritems(): + for k, v in six.iteritems(body_out): sharded_logits[k] = dp(self.top, v, datashard_to_features) sharded_losses[k] = dp(self.loss, sharded_logits[k], 
diff --git a/tensor2tensor/utils/t2t_model.py b/tensor2tensor/utils/t2t_model.py
index 085cc821f..eef6c5dcb 100644
--- a/tensor2tensor/utils/t2t_model.py
+++ b/tensor2tensor/utils/t2t_model.py
@@ -138,7 +139,7 @@ def call(self, features):
       sharded_logits, losses = self.model_fn_sharded(sharded_features)
       if isinstance(sharded_logits, dict):
         concat_logits = {}
-        for k, v in sharded_logits.iteritems():
+        for k, v in six.iteritems(sharded_logits):
           concat_logits[k] = tf.concat(v, 0)
         return concat_logits, losses
       else:
@@ -171,7 +172,7 @@ def model_fn_sharded(self, sharded_features):
       if isinstance(body_out, dict):
         sharded_logits = {}
         sharded_losses = {}
-        for k, v in body_out.iteritems():
+        for k, v in six.iteritems(body_out):
           sharded_logits[k] = dp(self.top, v, datashard_to_features)
           sharded_losses[k] = dp(self.loss, sharded_logits[k],
                                  datashard_to_features)
@@ -189,8 +190,8 @@ def model_fn_sharded(self, sharded_features):
     else:
       sharded_logits, sharded_losses = dp(self.model_fn, datashard_to_features)
       if isinstance(sharded_logits[0], dict):
-        temp_dict = {k: [] for k, _ in sharded_logits[0].iteritems()}
-        for k, _ in sharded_logits[0].iteritems():
+        temp_dict = {k: [] for k, _ in six.iteritems(sharded_logits[0])}
+        for k, _ in six.iteritems(sharded_logits[0]):
           for l in sharded_logits:
             temp_dict[k].append(l[k])
         sharded_logits = temp_dict
@@ -250,12 +251,19 @@ def bottom(self, features):
       all_previous_modalities.append(input_modality.name)
 
     # Transform the targets (for autoregressive models)
     target_modality = self._problem_hparams.target_modality
-    with tf.variable_scope(target_modality.name):
-      log_info("Transforming 'targets' with %s.targets_bottom",
-               target_modality.name)
-      transformed_features["targets"] = target_modality.targets_bottom(
-          features["targets"])
+    if isinstance(target_modality, dict):
+      for k, v in six.iteritems(target_modality):
+        # TODO(aidangomez): share variables across modalities?
+        with tf.variable_scope("%s/%s" % (v.name, k)):
+          log_info("Transforming '%s' with %s.targets_bottom", k, v.name)
+          transformed_features[k] = v.targets_bottom(features[k])
+    else:
+      with tf.variable_scope(target_modality.name):
+        log_info("Transforming 'targets' with %s.targets_bottom",
+                 target_modality.name)
+        transformed_features["targets"] = target_modality.targets_bottom(
+            features["targets"])
 
     for key in features:
       if key not in transformed_features:
@@ -284,12 +295,11 @@ def body(self, features):
     """
     raise NotImplementedError("Abstract Method")
 
-  def _top_single(self, body_output, features):
-    if not self._problem_hparams:
+  def _top_single(self, body_output, target_modality, features):
+    if not target_modality:
       log_warn("Without a Problem, T2TModel.top is a passthrough.")
       return body_output
 
-    target_modality = self._problem_hparams.target_modality
     with tf.variable_scope(target_modality.name):
       log_info("Transforming body output with %s.top", target_modality.name)
       last_only = (
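bottom now mirrors whatever structure the problem declares: a dict target_modality means body must consume and emit the same keys. A minimal sketch of a conforming model (the class name and the shared layer are hypothetical):

    @registry.register_model
    class TwoHeadedModel(t2t_model.T2TModel):
      """Hypothetical model emitting one output per target key."""

      def body(self, features):
        # features["targets"] and features["A"] were each transformed by
        # their modality's targets_bottom in the hunk above.
        hidden = tf.layers.dense(features["inputs"], 64, name="shared")
        # Keys must match problem_hparams.target_modality exactly; the
        # assert in top() below checks this.
        return {"targets": hidden, "A": hidden}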
@@ -310,32 +320,60 @@ def _top_single(self, body_output, features):
 
   def top(self, body_output, features):
     if isinstance(body_output, dict):
+      if self._problem_hparams:
+        target_modality = self._problem_hparams.target_modality
+      else:
+        target_modality = {k: None for k in body_output.keys()}
+      assert set(body_output.keys()) == set(target_modality.keys()), (
+          "The keys of model_body's returned logits dict must match the keys "
+          "of problem_hparams.target_modality's dict.")
       logits = {}
-      for k, v in body_output.iteritems():
-        logits[k] = self._top_single(v, features)
+      for k, v in six.iteritems(body_output):
+        with tf.variable_scope(k):  # TODO(aidangomez): share variables here?
+          logits[k] = self._top_single(v, target_modality[k], features)
       return logits
     else:
-      return self._top_single(body_output, features)
+      if self._problem_hparams:
+        target_modality = self._problem_hparams.target_modality
+      else:
+        target_modality = None
+      assert not isinstance(target_modality, dict), (
+          "model_body must return a dictionary of logits when "
+          "problem_hparams.target_modality is a dict.")
+      return self._top_single(body_output, target_modality, features)
 
-  def _loss_single(self, logits, features):
-    if not self._problem_hparams:
+  def _loss_single(self, logits, target_modality, features):
+    if not target_modality:
       log_warn(_no_problem_err("loss"))
       return (tf.constant(0., dtype=tf.float32),
               tf.constant(1., dtype=tf.float32))
 
-    target_modality = self._problem_hparams.target_modality
     loss_num, loss_den = target_modality.loss(logits, features["targets"])
     loss_num *= self._problem_hparams.loss_multiplier
     return loss_num, loss_den
 
   def loss(self, logits, features):
     if isinstance(logits, dict):
+      if self._problem_hparams:
+        target_modality = self._problem_hparams.target_modality
+      else:
+        target_modality = {k: None for k in logits.keys()}
+      assert set(logits.keys()) == set(target_modality.keys()), (
+          "The keys of model_body's returned logits dict must match the keys "
+          "of problem_hparams.target_modality's dict.")
       losses = {}
-      for k, v in logits.iteritems():
-        losses[k] = self._loss_single(v, features)
-      return tf.add_n([n / d for n, d in logits.values()])
+      for k, v in six.iteritems(logits):
+        losses[k] = self._loss_single(v, target_modality[k], features)
+      return tf.add_n([n / d for n, d in losses.values()])
     else:
-      return self._loss_single(logits, features)
+      if self._problem_hparams:
+        target_modality = self._problem_hparams.target_modality
+      else:
+        target_modality = None
+      assert not isinstance(target_modality, dict), (
+          "model_body must return a dictionary of logits when "
+          "problem_hparams.target_modality is a dict.")
+      return self._loss_single(logits, target_modality, features)
 
   def optimize(self, loss, num_async_replicas=1):
     """Return a training op minimizing loss."""
@@ -386,12 +424,21 @@ def _create_modalities(self, problem_hparams, hparams):
       input_modality[f] = registry.create_modality(modality_spec, hparams)
     problem_hparams.input_modality = input_modality
 
-    target_modality_spec = problem_hparams.target_modality
-    if target_modality_name:
-      _warn_changed_modality_type(target_modality_name, target_modality_spec[0],
-                                  "target")
-      target_modality_spec = (target_modality_name, target_modality_spec[1])
-    target_modality = registry.create_modality(target_modality_spec, hparams)
+    if isinstance(problem_hparams.target_modality, dict):
+      target_modality = {}
+      for f, modality_spec in six.iteritems(problem_hparams.target_modality):
+        if target_modality_name:
+          _warn_changed_modality_type(target_modality_name, modality_spec[0],
+                                      "target_modality/%s" % f)
+          modality_spec = (target_modality_name, modality_spec[1])
+        target_modality[f] = registry.create_modality(modality_spec, hparams)
+    else:
+      target_modality_spec = problem_hparams.target_modality
+      if target_modality_name:
+        _warn_changed_modality_type(target_modality_name,
+                                    target_modality_spec[0], "target")
+        target_modality_spec = (target_modality_name, target_modality_spec[1])
+      target_modality = registry.create_modality(target_modality_spec, hparams)
     problem_hparams.target_modality = target_modality
 
   def prepare_features_for_infer(self, features):
@@ -880,7 +927,7 @@ def estimator_model_fn(cls,
 
     # Set known shapes
     if use_tpu:
       if isinstance(logits, dict):
-        for k, v in logits.iteritems():
+        for k, v in six.iteritems(logits):
           if "scalar/" in k:
             continue
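Each _loss_single call returns a (numerator, denominator) pair, and for dict logits the loss hunk above sums the per-key means; note it also corrects the earlier `logits.values()` to `losses.values()`, which previously unpacked logits tensors instead of loss pairs. The arithmetic, with illustrative numbers:

    losses = {"targets": (6.0, 3.0), "A": (2.0, 2.0)}  # (loss_num, loss_den)
    total = sum(n / d for n, d in losses.values())     # 6/3 + 2/2
    assert total == 3.0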
@@ -941,9 +988,9 @@ def estimator_spec_eval(self, features, logits, labels, loss, losses_dict):
     problem = hparams.problem_instances[0]
 
     if common_layers.is_on_tpu():
-      eval_metrics_fn = _create_tpu_eval_metrics_fn(problem, hparams)
       _remove_summaries()
       if isinstance(logits, dict):
+        eval_metrics_fn = _create_tpu_eval_metrics_fn(problem, hparams)
         # For TPU, logits dict will be passed as keyword arguments to
         # eval_metrics_fn. Here we add the labels to those arguments.
         logits.update({"labels": labels})
@@ -952,6 +999,7 @@ def estimator_spec_eval(self, features, logits, labels, loss, losses_dict):
             eval_metrics=(eval_metrics_fn, logits),
             loss=loss)
       else:
+        eval_metrics_fn = _create_tpu_eval_metrics_fn(problem, hparams)
         return tf.contrib.tpu.TPUEstimatorSpec(
             tf.estimator.ModeKeys.EVAL,
             eval_metrics=(eval_metrics_fn, [logits, labels]),
@@ -960,13 +1008,24 @@ def estimator_spec_eval(self, features, logits, labels, loss, losses_dict):
       eval_metrics_fns = metrics.create_evaluation_metrics([problem], hparams)
       eval_metrics = {}
       for metric_name, metric_fn in six.iteritems(eval_metrics_fns):
-        eval_metrics[metric_name] = metric_fn(logits, features)
-
-      return tf.estimator.EstimatorSpec(
-          tf.estimator.ModeKeys.EVAL,
-          predictions={"predictions": logits},
-          eval_metric_ops=eval_metrics,
-          loss=loss)
+        if isinstance(logits, dict):
+          # The key sits in the middle of metric_name: "metrics-%s/%s/%s".
+          k = metric_name.split("/")[1]
+          eval_metrics[metric_name] = metric_fn(logits[k], features)
+        else:
+          eval_metrics[metric_name] = metric_fn(logits, features)
+
+      if isinstance(logits, dict):
+        return tf.estimator.EstimatorSpec(
+            tf.estimator.ModeKeys.EVAL,
+            predictions=logits,
+            eval_metric_ops=eval_metrics,
+            loss=loss)
+      return tf.estimator.EstimatorSpec(
+          tf.estimator.ModeKeys.EVAL,
+          predictions={"predictions": logits},
+          eval_metric_ops=eval_metrics,
+          loss=loss)
 
   def estimator_spec_predict(self, features):
     """Construct EstimatorSpec for PREDICT mode."""
@@ -1001,6 +1058,8 @@ def estimator_spec_predict(self, features):
     if "scores" in predictions:
       export_out["scores"] = predictions["scores"]
 
+    _remove_summaries()
+
     return tf.estimator.EstimatorSpec(
         tf.estimator.ModeKeys.PREDICT,
         predictions=predictions,
@@ -1066,28 +1125,49 @@ def _create_dummy_vars():
 
 def _create_tpu_eval_metrics_fn(problem, hparams):
   """Create the metrics_fn that TPUEstimatorSpec expects."""
 
-  tm = problem.get_hparams().target_modality
-  if isinstance(tm, tuple):
-    tm = registry.create_modality(tm, hparams)
-  weights_fn = tm.targets_weights_fn
+  metric_fns = []
+  eval_metrics = problem.eval_metrics()
 
-  def make_metric_fn(metric_fn):
+  tm = problem.get_hparams().target_modality
+  if isinstance(tm, dict):
+    for k, v in six.iteritems(tm):
+      if isinstance(v, tuple):
+        v = registry.create_modality(v, hparams)
+      weights_fn = v.targets_weights_fn
+
+      def make_metric_fn(metric_fn):
+
+        def wrapped_metric_fn(logits, labels, weights_fn=weights_fn):
+          num, den = metric_fn(logits, labels, weights_fn=weights_fn)
+          return tf.metrics.mean(num, den)
+
+        return wrapped_metric_fn
+
+      for metric in eval_metrics:
+        if metric in TPU_METRIC_BLACKLIST:
+          log_warn("Skipping eval metric %s in TPU_METRIC_BLACKLIST", metric)
+          continue
+        name = "%s/metrics-%s/%s" % (k, problem.name, metric)
+        metric_fns.append((name, make_metric_fn(metrics.METRICS_FNS[metric])))
+  else:
+    if isinstance(tm, tuple):
+      tm = registry.create_modality(tm, hparams)
+    weights_fn = tm.targets_weights_fn
 
-    def wrapped_metric_fn(logits, labels):
-      num, den = metric_fn(logits, labels, weights_fn=weights_fn)
-      return tf.metrics.mean(num, den)
+    def make_metric_fn(metric_fn):
 
-    return wrapped_metric_fn
+      def wrapped_metric_fn(logits, labels):
+        num, den = metric_fn(logits, labels, weights_fn=weights_fn)
+        return tf.metrics.mean(num, den)
 
-  metric_fns = []
-  eval_metrics = problem.eval_metrics()
+      return wrapped_metric_fn
 
-  for metric in eval_metrics:
-    if metric in TPU_METRIC_BLACKLIST:
-      log_warn("Skipping eval metric %s in TPU_METRIC_BLACKLIST", metric)
-      continue
-    name = "metrics-%s/%s" % (problem.name, metric)
-    metric_fns.append((name, make_metric_fn(metrics.METRICS_FNS[metric])))
+    for metric in eval_metrics:
+      if metric in TPU_METRIC_BLACKLIST:
+        log_warn("Skipping eval metric %s in TPU_METRIC_BLACKLIST", metric)
+        continue
+      name = "metrics-%s/%s" % (problem.name, metric)
+      metric_fns.append((name, make_metric_fn(metrics.METRICS_FNS[metric])))
 
   def all_metrics_fn(logits=None, labels=None, **kwargs):
     """Construct metrics dictionary."""
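One subtlety in the dict branch above: wrapped_metric_fn freezes the current weights_fn through a default argument, because a plain closure created inside the loop would see only the last key's weights function once the loop finishes. The idiom in isolation:

    bound = [lambda i=i: i for i in range(3)]  # default arg binds per iteration
    late = [lambda: i for i in range(3)]       # closure binds at call time
    assert [f() for f in bound] == [0, 1, 2]
    assert [f() for f in late] == [2, 2, 2]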
@@ -1098,8 +1178,11 @@ def all_metrics_fn(logits=None, labels=None, **kwargs):
 
     for name, fn in metric_fns:
       if isinstance(logits, dict):
-        for k, v in logits.iteritems():
-          metrics_dict["%s/%s" % (name, k)] = fn(v, labels)
+        for k, v in six.iteritems(logits):
+          if isinstance(labels, dict):
+            metrics_dict["%s/%s" % (name, k)] = fn(v, labels[k])
+          else:
+            metrics_dict["%s/%s" % (name, k)] = fn(v, labels)
       else:
         metrics_dict[name] = fn(logits, labels)
 
diff --git a/tensor2tensor/utils/trainer_lib.py b/tensor2tensor/utils/trainer_lib.py
index dd1442517..1eb2442b4 100644
--- a/tensor2tensor/utils/trainer_lib.py
+++ b/tensor2tensor/utils/trainer_lib.py
@@ -117,6 +117,7 @@ def create_run_config(master="",
       use_tpu=use_tpu)
   run_config_args = {
       "master": master,
+      "evaluation_master": master,
       "model_dir": model_dir,
       "session_config": session_config,
      "save_summary_steps": 100,
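Setting evaluation_master alongside master keeps continuous evaluation pointed at the same TensorFlow master as training. A sketch of the call site, assuming only these two arguments need overriding (address and path are illustrative):

    from tensor2tensor.utils import trainer_lib

    run_config = trainer_lib.create_run_config(
        master="grpc://10.240.1.2:8470", model_dir="/tmp/t2t_train")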
diff --git a/tensor2tensor/utils/trainer_lib_test.py b/tensor2tensor/utils/trainer_lib_test.py
index ef117274e..a9a6e692c 100644
--- a/tensor2tensor/utils/trainer_lib_test.py
+++ b/tensor2tensor/utils/trainer_lib_test.py
@@ -76,9 +76,8 @@ def testExperiment(self):
 
   def testModel(self):
     # HParams
-    hparams = trainer_lib.create_hparams("transformer_tiny",
-                                         data_dir=self.data_dir,
-                                         problem_name="tiny_algo")
+    hparams = trainer_lib.create_hparams(
+        "transformer_tiny", data_dir=self.data_dir, problem_name="tiny_algo")
 
     # Dataset
     problem = hparams.problem_instances[0]
@@ -102,6 +101,43 @@ def testModel(self):
     self.assertAllEqual(logits_shape, [10, None, 1, 1, 4])
     self.assertEqual(loss_val.shape, tuple())
 
+  def testMultipleTargetModalities(self):
+    # HParams
+    hparams = trainer_lib.create_hparams(
+        "transformer_tiny", data_dir=self.data_dir, problem_name="tiny_algo")
+    tm = hparams.problem_instances[0].get_hparams().target_modality
+    hparams.problem_instances[0].get_hparams().target_modality = {
+        "targets": tm,
+        "A": tm,
+        "B": tm
+    }
+
+    # Dataset
+    problem = hparams.problem_instances[0]
+    dataset = problem.dataset(tf.estimator.ModeKeys.TRAIN, self.data_dir)
+    dataset = dataset.repeat(None).padded_batch(10, dataset.output_shapes)
+    features = dataset.make_one_shot_iterator().get_next()
+    features = problem_lib.standardize_shapes(features)
+    features["A"] = features["B"] = features["targets"]
+
+    # Model
+    model = registry.model("transformer")(hparams, tf.estimator.ModeKeys.TRAIN)
+
+    def body(args, mb=model.body):
+      out = mb(args)
+      return {"targets": out, "A": out, "B": out}
+
+    model.body = body
+
+    logits, losses = model(features)
+
+    self.assertTrue("training" in losses)
+    loss = losses["training"]
+
+    with self.test_session() as sess:
+      sess.run(tf.global_variables_initializer())
+      sess.run([logits, loss])
+
 
 if __name__ == "__main__":
   tf.test.main()
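Outside the test harness the same contract applies to any registered model: override the problem's target_modality with a dict, provide a tensor for every key, and read the summed loss back out. A condensed recap (names mirror the test above and are illustrative):

    p_hp = hparams.problem_instances[0].get_hparams()
    tm = p_hp.target_modality
    p_hp.target_modality = {"targets": tm, "A": tm}  # one modality per key
    features["A"] = features["targets"]              # one tensor per key
    logits, losses = model(features)                 # logits is a dict
    loss = losses["training"]                        # tf.add_n over per-key losses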