diff --git a/.travis.yml b/.travis.yml
index 4fc315147..694915038 100644
--- a/.travis.yml
+++ b/.travis.yml
@@ -10,13 +10,13 @@ env:
   matrix:
     - TF_VERSION="1.4.*"
     - TF_VERSION="1.5.*"
-    - TF_VERSION="1.6.0rc1"
+    - TF_VERSION="1.6.*"
 matrix:
   exclude:
     - python: "3.6"
       env: TF_VERSION="1.4.*"
     - python: "3.6"
-      env: TF_VERSION="1.6.0rc1"
+      env: TF_VERSION="1.6.*"
 before_install:
   - echo "deb [arch=amd64] http://storage.googleapis.com/tensorflow-serving-apt stable tensorflow-model-server tensorflow-model-server-universal" | sudo tee /etc/apt/sources.list.d/tensorflow-serving.list
   - curl https://storage.googleapis.com/tensorflow-serving-apt/tensorflow-serving.release.pub.gpg | sudo apt-key add -
diff --git a/README.md b/README.md
index a569fb80a..755d080b6 100644
--- a/README.md
+++ b/README.md
@@ -12,10 +12,10 @@ welcome](https://img.shields.io/badge/contributions-welcome-brightgreen.svg)](CO
 [Tensor2Tensor](https://github.com/tensorflow/tensor2tensor), or
 [T2T](https://github.com/tensorflow/tensor2tensor) for short, is a library
-of deep learning models and datasets designed to [accelerate deep learning
-research](https://research.googleblog.com/2017/06/accelerating-deep-learning-research.html) and make it more accessible.
-
-T2T is actively used and maintained by researchers and engineers within the
+of deep learning models and datasets designed to make deep learning more
+accessible and [accelerate ML
+research](https://research.googleblog.com/2017/06/accelerating-deep-learning-research.html).
+T2T is actively used and maintained by researchers and engineers within the
 [Google Brain team](https://research.google.com/teams/brain/) and a community
 of users. We're eager to collaborate with you too, so feel free to
 [open an issue on GitHub](https://github.com/tensorflow/tensor2tensor/issues)
@@ -368,6 +368,7 @@ T2T](https://research.googleblog.com/2017/06/accelerating-deep-learning-research
 * [Discrete Autoencoders for Sequence Models](https://arxiv.org/abs/1801.09797)
 * [Generating Wikipedia by Summarizing Long Sequences](https://arxiv.org/abs/1801.10198)
-* [Image Transformer](https://openreview.net/forum?id=r16Vyf-0-)
+* [Image Transformer](https://arxiv.org/abs/1802.05751)
+* [Training Tips for the Transformer Model](http://ufallab.ms.mff.cuni.cz/~popel/training-tips-transformer.pdf)
 *Note: This is not an official Google product.*
diff --git a/docs/index.md b/docs/index.md
index 060a878a3..8860e03b7 100644
--- a/docs/index.md
+++ b/docs/index.md
@@ -11,8 +11,9 @@ welcome](https://img.shields.io/badge/contributions-welcome-brightgreen.svg)](CO
 [Tensor2Tensor](https://github.com/tensorflow/tensor2tensor), or
 [T2T](https://github.com/tensorflow/tensor2tensor) for short, is a library
-of deep learning models and datasets designed to [accelerate deep learning
-research](https://research.googleblog.com/2017/06/accelerating-deep-learning-research.html) and make it more accessible.
+of deep learning models and datasets designed to make deep learning more
+accessible and [accelerate ML
+research](https://research.googleblog.com/2017/06/accelerating-deep-learning-research.html).
 ## Basics
diff --git a/docs/new_problem.md b/docs/new_problem.md
index fab76d90d..371ae3daa 100644
--- a/docs/new_problem.md
+++ b/docs/new_problem.md
@@ -9,6 +9,10 @@ welcome](https://img.shields.io/badge/contributions-welcome-brightgreen.svg)](CO
 [![Gitter](https://img.shields.io/gitter/room/nwjs/nw.js.svg)](https://gitter.im/tensor2tensor/Lobby)
 [![License](https://img.shields.io/badge/License-Apache%202.0-brightgreen.svg)](https://opensource.org/licenses/Apache-2.0)
+Another good overview of this part, together with training, is given in
+[The Cloud ML Poetry Blog
+Post](https://cloud.google.com/blog/big-data/2018/02/cloud-poetry-training-and-hyperparameter-tuning-custom-text-models-on-cloud-ml-engine).
+
 Let's add a new dataset together and train the
 [Transformer](https://github.com/tensorflow/tensor2tensor/tree/master/tensor2tensor/models/transformer.py)
 model on it. We'll give the model a line of poetry, and it will learn to
diff --git a/docs/walkthrough.md b/docs/walkthrough.md
index a569fb80a..755d080b6 100644
--- a/docs/walkthrough.md
+++ b/docs/walkthrough.md
@@ -12,10 +12,10 @@ welcome](https://img.shields.io/badge/contributions-welcome-brightgreen.svg)](CO
 [Tensor2Tensor](https://github.com/tensorflow/tensor2tensor), or
 [T2T](https://github.com/tensorflow/tensor2tensor) for short, is a library
-of deep learning models and datasets designed to [accelerate deep learning
-research](https://research.googleblog.com/2017/06/accelerating-deep-learning-research.html) and make it more accessible.
-
-T2T is actively used and maintained by researchers and engineers within the
+of deep learning models and datasets designed to make deep learning more
+accessible and [accelerate ML
+research](https://research.googleblog.com/2017/06/accelerating-deep-learning-research.html).
+T2T is actively used and maintained by researchers and engineers within the
 [Google Brain team](https://research.google.com/teams/brain/) and a community
 of users.
We're eager to collaborate with you too, so feel free to [open an issue on GitHub](https://github.com/tensorflow/tensor2tensor/issues) @@ -368,6 +368,7 @@ T2T](https://research.googleblog.com/2017/06/accelerating-deep-learning-research * [Discrete Autoencoders for Sequence Models](https://arxiv.org/abs/1801.09797) * [Generating Wikipedia by Summarizing Long Sequences](https://arxiv.org/abs/1801.10198) -* [Image Transformer](https://openreview.net/forum?id=r16Vyf-0-) +* [Image Transformer](https://arxiv.org/abs/1802.05751) +* [Training Tips for the Transformer Model](http://ufallab.ms.mff.cuni.cz/~popel/training-tips-transformer.pdf) *Note: This is not an official Google product.* diff --git a/setup.py b/setup.py index c30c752dd..01a2a6f33 100644 --- a/setup.py +++ b/setup.py @@ -5,7 +5,7 @@ setup( name='tensor2tensor', - version='1.5.3', + version='1.5.4', description='Tensor2Tensor', author='Google Inc.', author_email='no-reply@google.com', diff --git a/tensor2tensor/data_generators/generator_utils.py b/tensor2tensor/data_generators/generator_utils.py index 9ccda4c0e..4339a0068 100644 --- a/tensor2tensor/data_generators/generator_utils.py +++ b/tensor2tensor/data_generators/generator_utils.py @@ -147,8 +147,9 @@ def generate_files(generator, output_filenames, max_cases=None): if outputs_exist(output_filenames): tf.logging.info("Skipping generator because outputs files exist") return + tmp_filenames = [fname + ".incomplete" for fname in output_filenames] num_shards = len(output_filenames) - writers = [tf.python_io.TFRecordWriter(fname) for fname in output_filenames] + writers = [tf.python_io.TFRecordWriter(fname) for fname in tmp_filenames] counter, shard = 0, 0 for case in generator: if case is None: @@ -165,6 +166,9 @@ def generate_files(generator, output_filenames, max_cases=None): for writer in writers: writer.close() + for tmp_name, final_name in zip(tmp_filenames, output_filenames): + tf.gfile.Rename(tmp_name, final_name) + tf.logging.info("Generated %s Examples", counter) diff --git a/tensor2tensor/data_generators/gym.py b/tensor2tensor/data_generators/gym.py index aa00e4189..06b5ad0f3 100644 --- a/tensor2tensor/data_generators/gym.py +++ b/tensor2tensor/data_generators/gym.py @@ -19,39 +19,30 @@ from __future__ import division from __future__ import print_function -import os +import functools # Dependency imports -import numpy as np -import functools import gym +import numpy as np -from tensor2tensor.rl import rl_trainer_lib -from tensor2tensor.rl.envs import atari_wrappers -from tensor2tensor.models.research import rl from tensor2tensor.data_generators import generator_utils from tensor2tensor.data_generators import problem +from tensor2tensor.models.research import rl +from tensor2tensor.rl.envs import atari_wrappers from tensor2tensor.utils import registry import tensorflow as tf + + flags = tf.flags FLAGS = flags.FLAGS flags.DEFINE_string("model_path", "", "File with model for pong") -def gym_lib(): - """Access to gym to allow for import of this file without a gym install.""" - try: - import gym # pylint: disable=g-import-not-at-top - except ImportError: - raise ImportError("pip install gym to use gym-based Problems") - return gym - - class GymDiscreteProblem(problem.Problem): """Gym environment with discrete actions and rewards.""" @@ -67,7 +58,7 @@ def env_name(self): @property def env(self): if self._env is None: - self._env = gym_lib().make(self.env_name) + self._env = gym.make(self.env_name) return self._env @property @@ -157,8 +148,6 @@ def num_steps(self): 
return 5000 - - @registry.register_problem class GymPongTrajectoriesFromPolicy(GymDiscreteProblem): """Pong game, loaded actions.""" @@ -167,28 +156,34 @@ def __init__(self, event_dir, *args, **kwargs): super(GymPongTrajectoriesFromPolicy, self).__init__(*args, **kwargs) self._env = None self._event_dir = event_dir - env_spec = lambda: atari_wrappers.wrap_atari( - gym.make("PongNoFrameskip-v4"), warp=False, frame_skip=4, frame_stack=False) + env_spec = lambda: atari_wrappers.wrap_atari( # pylint: disable=g-long-lambda + gym.make("PongNoFrameskip-v4"), + warp=False, + frame_skip=4, + frame_stack=False) hparams = rl.atari_base() with tf.variable_scope("train"): policy_lambda = hparams.network policy_factory = tf.make_template( - "network", - functools.partial(policy_lambda, env_spec().action_space, hparams)) - self._max_frame_pl = tf.placeholder(tf.float32, self.env.observation_space.shape) - actor_critic = policy_factory(tf.expand_dims(tf.expand_dims(self._max_frame_pl, 0), 0)) + "network", + functools.partial(policy_lambda, env_spec().action_space, hparams)) + self._max_frame_pl = tf.placeholder( + tf.float32, self.env.observation_space.shape) + actor_critic = policy_factory(tf.expand_dims(tf.expand_dims( + self._max_frame_pl, 0), 0)) policy = actor_critic.policy self._last_policy_op = policy.mode() self._last_action = self.env.action_space.sample() self._skip = 4 self._skip_step = 0 - self._obs_buffer = np.zeros((2,) + self.env.observation_space.shape, dtype=np.uint8) + self._obs_buffer = np.zeros((2,) + self.env.observation_space.shape, + dtype=np.uint8) self._sess = tf.Session() model_saver = tf.train.Saver(tf.global_variables(".*network_parameters.*")) model_saver.restore(self._sess, FLAGS.model_path) # TODO(blazej0): For training of atari agents wrappers are usually used. - # Below we have a hacky solution which is a temporary workaround to be used together + # Below we have a hacky solution which is a workaround to be used together # with atari_wrappers.MaxAndSkipEnv. def get_action(self, observation=None): if self._skip_step == self._skip - 2: self._obs_buffer[0] = observation @@ -197,7 +192,8 @@ def get_action(self, observation=None): if self._skip_step == 0: max_frame = self._obs_buffer.max(axis=0) self._last_action = int(self._sess.run( - self._last_policy_op, feed_dict={self._max_frame_pl: max_frame})[0, 0]) + self._last_policy_op, + feed_dict={self._max_frame_pl: max_frame})[0, 0]) return self._last_action @property diff --git a/tensor2tensor/data_generators/imagenet.py b/tensor2tensor/data_generators/imagenet.py index a071c820b..a378f0726 100644 --- a/tensor2tensor/data_generators/imagenet.py +++ b/tensor2tensor/data_generators/imagenet.py @@ -19,19 +19,84 @@ from __future__ import division from __future__ import print_function +import os # Dependency imports +from tensor2tensor.data_generators import generator_utils from tensor2tensor.data_generators import image_utils from tensor2tensor.utils import registry import tensorflow as tf +# URLs and filenames for IMAGENET 32x32 data from +# https://arxiv.org/abs/1601.06759. +_IMAGENET_SMALL_ROOT_URL = "http://image-net.org/small/" +_IMAGENET_SMALL_URLS = [ + "train_32x32.tar", "valid_32x32.tar"] +_IMAGENET_SMALL_TRAIN_PREFIX = "train_32x32" +_IMAGENET_SMALL_EVAL_PREFIX = "valid_32x32" +_IMAGENET_SMALL_IMAGE_SIZE = 32 + + +# URLs and filenames for IMAGENET 64x64 data. 
+_IMAGENET_MEDIUM_ROOT_URL = "http://image-net.org/small/" +_IMAGENET_MEDIUM_URLS = [ + "train_64x64.tar", "valid_64x64.tar"] +_IMAGENET_MEDIUM_TRAIN_PREFIX = "train_64x64" +_IMAGENET_MEDIUM_EVAL_PREFIX = "valid_64x64" +_IMAGENET_MEDIUM_IMAGE_SIZE = 64 + # Derived from ImageNet data MEAN_RGB = [0.485, 0.456, 0.406] STDDEV_RGB = [0.229, 0.224, 0.225] +def imagenet_pixelrnn_generator(tmp_dir, + training, + size=_IMAGENET_SMALL_IMAGE_SIZE): + """Image generator for Imagenet 64x64 downsampled images. + + It assumes that the data has been downloaded from + http://image-net.org/small/*_32x32.tar or + http://image-net.org/small/*_64x64.tar into tmp_dir. + Args: + tmp_dir: path to temporary storage directory. + training: a Boolean; if true, we use the train set, otherwise the test set. + size: image size (assumes height and width are same) + + Yields: + A dictionary representing the images with the following fields: + * image/encoded: the string encoding the image as JPEG, + * image/format: the string "jpeg" representing image format, + * image/height: an integer representing the height, + * image/width: an integer representing the width. + Every field is actually a list of the corresponding type. + """ + if size == _IMAGENET_SMALL_IMAGE_SIZE: + train_prefix = _IMAGENET_SMALL_TRAIN_PREFIX + eval_prefix = _IMAGENET_SMALL_EVAL_PREFIX + else: + train_prefix = _IMAGENET_MEDIUM_TRAIN_PREFIX + eval_prefix = _IMAGENET_MEDIUM_EVAL_PREFIX + prefix = train_prefix if training else eval_prefix + images_filepath = os.path.join(tmp_dir, prefix) + image_files = tf.gfile.Glob(images_filepath + "/*") + height = size + width = size + const_label = 0 + for filename in image_files: + with tf.gfile.Open(filename, "r") as f: + encoded_image = f.read() + yield { + "image/encoded": [encoded_image], + "image/format": ["png"], + "image/class/label": [const_label], + "image/height": [height], + "image/width": [width] + } + + def imagenet_preprocess_example(example, mode, resize_size=None): """Preprocessing used for Imagenet and similar problems.""" resize_size = resize_size or [299, 299] @@ -123,6 +188,40 @@ def preprocess_example(self, example, mode, _): return example +@registry.register_problem +class ImageImagenet64Gen(ImageImagenet): + """Cifar-10 Tune.""" + + @property + def train_shards(self): + return 1024 + + @property + def dev_shards(self): + return 10 + + def generate_data(self, data_dir, tmp_dir, task_id=-1): + generator_utils.generate_dataset_and_shuffle( + self.generator(data_dir, tmp_dir, True), + self.training_filepaths(data_dir, self.train_shards, shuffled=True), + self.generator(data_dir, tmp_dir, False), + self.dev_filepaths(data_dir, self.dev_shards, shuffled=True)) + + def generator(self, data_dir, tmp_dir, is_training): + if is_training: + return imagenet_pixelrnn_generator( + tmp_dir, int(True), size=_IMAGENET_MEDIUM_IMAGE_SIZE) + else: + return imagenet_pixelrnn_generator( + tmp_dir, int(False), size=_IMAGENET_MEDIUM_IMAGE_SIZE) + + def preprocess_example(self, example, mode, unused_hparams): + example["inputs"].set_shape([_IMAGENET_MEDIUM_IMAGE_SIZE, + _IMAGENET_MEDIUM_IMAGE_SIZE, 3]) + example["inputs"] = tf.to_int64(example["inputs"]) + return example + + @registry.register_problem class ImageImagenet64(ImageImagenet32): """Imagenet rescaled to 64x64.""" diff --git a/tensor2tensor/data_generators/inspect.py b/tensor2tensor/data_generators/inspect.py index 0b113bb98..c8fb85deb 100644 --- a/tensor2tensor/data_generators/inspect.py +++ b/tensor2tensor/data_generators/inspect.py @@ -60,6 +60,8 @@ def 
main(_): total_sequences = 0 total_input_tokens = 0 total_target_tokens = 0 + nonpadding_input_tokens = 0 + nonpadding_target_tokens = 0 max_input_length = 0 max_target_length = 0 for record in reader: @@ -71,6 +73,8 @@ def main(_): print("INPUTS:\n" + encoder.decode(inputs) if encoder else inputs) if FLAGS.print_targets: print("TARGETS:\n" + encoder.decode(targets) if encoder else targets) + nonpadding_input_tokens += len(inputs) - inputs.count(0) + nonpadding_target_tokens += len(targets) - targets.count(0) total_input_tokens += len(inputs) total_target_tokens += len(targets) total_sequences += 1 @@ -83,6 +87,8 @@ def main(_): print("total_sequences: %d" % total_sequences) print("total_input_tokens: %d" % total_input_tokens) print("total_target_tokens: %d" % total_target_tokens) + print("nonpadding_input_tokens: %d" % nonpadding_input_tokens) + print("nonpadding_target_tokens: %d" % nonpadding_target_tokens) print("max_input_length: %d" % max_input_length) print("max_target_length: %d" % max_target_length) diff --git a/tensor2tensor/data_generators/lm1b.py b/tensor2tensor/data_generators/lm1b.py index 4f14c1040..e875a810d 100644 --- a/tensor2tensor/data_generators/lm1b.py +++ b/tensor2tensor/data_generators/lm1b.py @@ -87,10 +87,10 @@ def _train_data_filenames(tmp_dir): def _dev_data_filenames(tmp_dir): - return os.path.join(tmp_dir, - "1-billion-word-language-modeling-benchmark-r13output", - "heldout-monolingual.tokenized.shuffled", - "news.en.heldout-00000-of-00050") + return [os.path.join(tmp_dir, + "1-billion-word-language-modeling-benchmark-r13output", + "heldout-monolingual.tokenized.shuffled", + "news.en.heldout-00000-of-00050")] def _maybe_download_corpus(tmp_dir): @@ -147,7 +147,7 @@ def _get_or_build_subword_text_encoder(tmp_dir, vocab_filepath, target_size): @registry.register_problem -class LanguagemodelLm1b32k(text_problems.Text2TextProblem): +class LanguagemodelLm1b32k(text_problems.Text2SelfProblem): """A language model on the 1B words corpus.""" @property @@ -187,7 +187,7 @@ class LanguagemodelLm1b8kPacked(LanguagemodelLm1b32k): Happy TPU Training. Ratio of dev tokens (including eos) to dev words (including eos) - 207351 / 159658 = 1.29872; multiply ppx by this to compare results. + 207351 / 159658 = 1.29872; multiply log-ppl by this to compare results. """ @property @@ -201,8 +201,25 @@ def packed_length(self): @registry.register_problem class LanguagemodelLm1bCharacters(LanguagemodelLm1b32k): - """A language model on the 1B words corpus, character level.""" + """A language model on the 1B words corpus, character level. + + Ratio of dev chars (including eos) to dev words (including eos) + 826189 / 159658 = 5.174742; multiply log-ppl by this to compare results. + """ @property def vocab_type(self): return text_problems.VocabType.CHARACTER + + +@registry.register_problem +class LanguagemodelLm1bCharactersPacked(LanguagemodelLm1bCharacters): + """Packed version. + + Ratio of dev chars (including eos) to dev words (including eos) + 826189 / 159658 = 5.174742; multiply log-ppl by this to compare results. 
+ """ + + @property + def packed_length(self): + return 1024 diff --git a/tensor2tensor/data_generators/problem.py b/tensor2tensor/data_generators/problem.py index 68f573b5f..5faf5175b 100644 --- a/tensor2tensor/data_generators/problem.py +++ b/tensor2tensor/data_generators/problem.py @@ -421,25 +421,31 @@ def get_hparams(self, model_hparams=None): return self._hparams def maybe_reverse_features(self, feature_map): + """Reverse features between inputs and targets if the problem is '_rev'.""" if not self._was_reversed: return inputs, targets = feature_map["inputs"], feature_map["targets"] feature_map["inputs"], feature_map["targets"] = targets, inputs if "inputs_segmentation" in feature_map: - inputs, targets = feature_map["inputs_segmentation"], feature_map["targets_segmentation"] - feature_map["inputs_segmentation"], feature_map["targets_segmentation"] = targets, inputs + inputs_seg = feature_map["inputs_segmentation"] + targets_seg = feature_map["targets_segmentation"] + feature_map["inputs_segmentation"] = targets_seg + feature_map["targets_segmentation"] = inputs_seg if "inputs_position" in feature_map: - inputs, targets = feature_map["inputs_position"], feature_map["targets_position"] - feature_map["inputs_position"], feature_map["targets_position"] = targets, inputs - + inputs_pos = feature_map["inputs_position"] + targets_pos = feature_map["targets_position"] + feature_map["inputs_position"] = targets_pos + feature_map["targets_position"] = inputs_pos def maybe_copy_features(self, feature_map): if not self._was_copy: return feature_map["targets"] = feature_map["inputs"] - if "inputs_segmentation" in feature_map: - feature_map["targets_segmentation"] = feature_map["inputs_segmentation"] - if "inputs_position" in feature_map: + if ("inputs_segmentation" in feature_map and + "targets_segmentation" not in feature_map): + feature_map["targets_segmentation"] = feature_map["inputs_segmentation"] + if ("inputs_position" in feature_map and + "targets_position" not in feature_map): feature_map["targets_position"] = feature_map["inputs_position"] def dataset(self, diff --git a/tensor2tensor/data_generators/text_problems.py b/tensor2tensor/data_generators/text_problems.py index 73e6bf4c7..862cd6b0c 100644 --- a/tensor2tensor/data_generators/text_problems.py +++ b/tensor2tensor/data_generators/text_problems.py @@ -186,10 +186,11 @@ def generate_text_for_vocab(self, data_dir, tmp_dir): @property def vocab_filename(self): if self.vocab_type == VocabType.SUBWORD: - return "vocab.%s.%d.%s" % (self.name, self.approx_vocab_size, + return "vocab.%s.%d.%s" % (self.dataset_filename(), + self.approx_vocab_size, VocabType.SUBWORD) else: - return "vocab.%s.%s" % (self.name, VocabType.TOKEN) + return "vocab.%s.%s" % (self.dataset_filename(), VocabType.TOKEN) def get_or_create_vocab(self, data_dir, tmp_dir, force_get=False): if self.vocab_type == VocabType.CHARACTER: diff --git a/tensor2tensor/data_generators/translate_enmk.py b/tensor2tensor/data_generators/translate_enmk.py index 94f4d213a..ce73d634d 100644 --- a/tensor2tensor/data_generators/translate_enmk.py +++ b/tensor2tensor/data_generators/translate_enmk.py @@ -27,10 +27,6 @@ from tensor2tensor.data_generators import translate from tensor2tensor.utils import registry -import tensorflow as tf - -FLAGS = tf.flags.FLAGS - # End-of-sentence marker. EOS = text_encoder.EOS_ID @@ -50,6 +46,10 @@ ]] +# See this PR on github for some results with Transformer on these Problems. 
+# https://github.com/tensorflow/tensor2tensor/pull/626 + + @registry.register_problem class TranslateEnmkSetimes32k(translate.TranslateProblem): """Problem spec for SETimes En-Mk translation.""" @@ -66,6 +66,7 @@ def source_data_files(self, dataset_split): train = dataset_split == problem.DatasetSplit.TRAIN return _ENMK_TRAIN_DATASETS if train else _ENMK_TEST_DATASETS + @registry.register_problem class TranslateEnmkSetimesCharacters(translate.TranslateProblem): """Problem spec for SETimes En-Mk translation.""" diff --git a/tensor2tensor/data_generators/translate_envi.py b/tensor2tensor/data_generators/translate_envi.py index 8b5d96864..2003316eb 100644 --- a/tensor2tensor/data_generators/translate_envi.py +++ b/tensor2tensor/data_generators/translate_envi.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Data generators for translation data-sets.""" +"""Data generators for En-Vi translation.""" from __future__ import absolute_import from __future__ import division @@ -26,10 +26,6 @@ from tensor2tensor.data_generators import translate from tensor2tensor.utils import registry -import tensorflow as tf - -FLAGS = tf.flags.FLAGS - # End-of-sentence marker. EOS = text_encoder.EOS_ID @@ -48,6 +44,10 @@ ]] +# See this PR on github for some results with Transformer on this Problem. +# https://github.com/tensorflow/tensor2tensor/pull/611 + + @registry.register_problem class TranslateEnviIwslt32k(translate.TranslateProblem): """Problem spec for IWSLT'15 En-Vi translation.""" diff --git a/tensor2tensor/data_generators/wiki.py b/tensor2tensor/data_generators/wiki.py index 9909a1267..c6a724a70 100644 --- a/tensor2tensor/data_generators/wiki.py +++ b/tensor2tensor/data_generators/wiki.py @@ -224,7 +224,7 @@ class LanguagemodelWikiNorefV8kL1k(LanguagemodelWikiXmlV8kL1k): @property def vocab_filename(self): - return "vocab.wiki_noref" + return "vocab.wiki_noref.%d" % self.approx_vocab_size def filepath_to_unicode_text(self, filepath): """Overriddes the base class to clean up the xml dump before tokenizing.""" diff --git a/tensor2tensor/layers/common_attention.py b/tensor2tensor/layers/common_attention.py index a9346c34d..ddb6c3c89 100644 --- a/tensor2tensor/layers/common_attention.py +++ b/tensor2tensor/layers/common_attention.py @@ -327,7 +327,9 @@ def add_standard_attention_hparams(hparams): return hparams -def encoder_decoder_attention_loss(expected_attention, actual_attentions): +def encoder_decoder_attention_loss(expected_attention, + actual_attentions, + loss_multiplier=1.0): """Computes encdec attention loss between expected and actual attentions. Args: @@ -335,6 +337,7 @@ def encoder_decoder_attention_loss(expected_attention, actual_attentions): weights with shape [batch_size, target_length, input_length]. actual_attentions: Dictionary with actual attention weights for different attention types and hidden layers. + loss_multiplier: multiplier for the attention loss. Returns: MSE loss between the actual and expected attention weights. @@ -351,8 +354,8 @@ def encoder_decoder_attention_loss(expected_attention, actual_attentions): # Reduce mean across all layers (axis=0) and all heads (axis=2) to get a # tensor with shape [batch_size, target_length, input_length]. 
actual_attention_weights = tf.reduce_mean(actual_attention_weights, [0, 2]) - return tf.losses.mean_squared_error(expected_attention, - actual_attention_weights) + return tf.losses.mean_squared_error( + expected_attention, actual_attention_weights) * loss_multiplier @expert_utils.add_name_scope() diff --git a/tensor2tensor/layers/common_image_attention.py b/tensor2tensor/layers/common_image_attention.py index 0e7ac4a4e..47b96577e 100644 --- a/tensor2tensor/layers/common_image_attention.py +++ b/tensor2tensor/layers/common_image_attention.py @@ -74,7 +74,6 @@ def local_attention_2d(x, hparams, attention_type="local_attention_2d"): def local_attention_1d(x, - self_attention_bias, hparams, attention_type="local_unmasked", q_padding="VALID", @@ -86,7 +85,7 @@ def local_attention_1d(x, y = common_attention.multihead_attention( x, None, - self_attention_bias, + None, hparams.attention_key_channels or hparams.hidden_size, hparams.attention_value_channels or hparams.hidden_size, hparams.hidden_size, @@ -107,7 +106,6 @@ def local_attention_1d(x, def dilated_attention_1d(x, - self_attention_bias, hparams, attention_type="masked_dilated_1d", q_padding="VALID", @@ -120,7 +118,7 @@ def dilated_attention_1d(x, y = common_attention.multihead_attention( x, None, - self_attention_bias, + None, hparams.attention_key_channels or hparams.hidden_size, hparams.attention_value_channels or hparams.hidden_size, hparams.hidden_size, @@ -152,6 +150,8 @@ def local_global_attention(x, [x_global, x_local] = tf.split(x, 2, axis=-1) split_hidden_size = int(hparams.hidden_size / 2) split_heads = int(hparams.num_heads / 2) + if self_attention_bias is not None: + self_attention_bias = get_self_attention_bias(x) y_global = common_attention.multihead_attention( x_global, None, @@ -169,7 +169,7 @@ def local_global_attention(x, y_local = common_attention.multihead_attention( x_local, None, - self_attention_bias, + None, hparams.attention_key_channels or split_hidden_size, hparams.attention_value_channels or split_hidden_size, split_hidden_size, @@ -194,6 +194,8 @@ def full_self_attention(x, kv_padding="LEFT"): """Full self-attention layer.""" x, x_shape, is_4d = maybe_reshape_4d_to_3d(x) + if self_attention_bias is not None: + self_attention_bias = get_self_attention_bias(x) with tf.variable_scope("self_att"): y = common_attention.multihead_attention( x, @@ -241,9 +243,9 @@ def encdec_attention_1d(x, def transformer_decoder_layers(inputs, encoder_output, - bias, num_layers, hparams, + self_attention_bias=None, attention_type=AttentionType.LOCAL_2D, name="transformer"): """Multi layer transformer.""" @@ -260,21 +262,21 @@ def transformer_decoder_layers(inputs, attention_type="masked_local_attention_2d") elif attention_type == AttentionType.LOCAL_1D: y = local_attention_1d(common_layers.layer_preprocess(x, hparams), - bias, hparams, + hparams, attention_type="local_mask_right", q_padding="LEFT", kv_padding="LEFT") elif attention_type == AttentionType.GLOCAL: y = local_global_attention(common_layers.layer_preprocess(x, hparams), - bias, hparams, + self_attention_bias, hparams, q_padding="LEFT", kv_padding="LEFT") elif attention_type == AttentionType.DILATED: y = dilated_attention_1d(common_layers.layer_preprocess(x, hparams), - bias, hparams, q_padding="LEFT", + hparams, q_padding="LEFT", kv_padding="LEFT", gap_size=hparams.gap_sizes[layer]) elif attention_type == AttentionType.GLOBAL: y = full_self_attention(common_layers.layer_preprocess(x, hparams), - bias, hparams, + self_attention_bias, hparams, q_padding="LEFT", 
kv_padding="LEFT") x = common_layers.layer_postprocess(x, y, hparams) # enc-dec attention + skip connections @@ -309,7 +311,7 @@ def transformer_encoder_layers(inputs, attention_type="local_attention_2d") elif attention_type == AttentionType.LOCAL_1D: y = local_attention_1d(common_layers.layer_preprocess(x, hparams), - self_attention_bias, hparams, + hparams, attention_type="local_unmasked", q_padding=q_padding, kv_padding=kv_padding) elif attention_type == AttentionType.GLOBAL: @@ -356,6 +358,22 @@ def ffn_layer(x, hparams): return y +def get_self_attention_bias(x): + """Creates masked self attention bias. + + Args: + x: A tensor of shape [batch, length, depth] + + Returns: + self_attention_bias: A tensor of shape [length, length, 1] + """ + + x_shape = common_layers.shape_list(x) + self_attention_bias = common_attention.attention_bias_lower_triangle( + x_shape[1]) + return self_attention_bias + + def transformer_layers_sharded(dp, ps_devices, inputs, @@ -381,7 +399,7 @@ def transformer_layers_sharded(dp, attention_type="masked_local_attention_2d")) elif attention_type == AttentionType.LOCAL_1D: y = dp(local_attention_1d(common_layers.layer_preprocess(x, hparams), - self_attention_bias, hparams, + hparams, attention_type="local_mask_right", q_padding="LEFT", kv_padding="LEFT")) elif attention_type == AttentionType.GLOCAL: @@ -389,6 +407,7 @@ def transformer_layers_sharded(dp, common_layers.layer_preprocess(x, hparams), self_attention_bias, hparams, q_padding="LEFT", kv_padding="LEFT")) elif attention_type == AttentionType.GLOBAL: + self_attention_bias = dp(get_self_attention_bias(x)) y = dp(full_self_attention(common_layers.layer_preprocess(x, hparams), self_attention_bias, hparams, q_padding="LEFT", kv_padding="LEFT")) @@ -509,8 +528,6 @@ def prepare_decoder(targets, hparams): # Preprocess image x = prepare_image(targets, hparams, name="dec_channels") x_shape = common_layers.shape_list(x) - # mask out upper triangle to avoid looking into the future. - bias = common_attention.attention_bias_lower_triangle(x_shape[1]*x_shape[2]) if hparams.dec_attention_type == AttentionType.LOCAL_2D: x = common_attention.right_shift_blockwise(x, hparams.query_shape) x = add_pos_signals(x, hparams, "dec_pos") @@ -522,7 +539,7 @@ def prepare_decoder(targets, hparams): x = tf.reshape(x, [targets_shape[0], x_shape[1], x_shape[2], hparams.hidden_size]) x = add_pos_signals(x, hparams, "dec_pos") - return x, x_shape[1], x_shape[2], bias + return x, x_shape[1], x_shape[2] def prepare_image(inputs, hparams, name=None): diff --git a/tensor2tensor/models/__init__.py b/tensor2tensor/models/__init__.py index c78d1f52a..32ef49901 100644 --- a/tensor2tensor/models/__init__.py +++ b/tensor2tensor/models/__init__.py @@ -44,6 +44,7 @@ from tensor2tensor.models.research import cycle_gan from tensor2tensor.models.research import gene_expression from tensor2tensor.models.research import multimodel +from tensor2tensor.models.research import rl from tensor2tensor.models.research import super_lm from tensor2tensor.models.research import transformer_moe from tensor2tensor.models.research import transformer_revnet diff --git a/tensor2tensor/models/image_transformer.py b/tensor2tensor/models/image_transformer.py index a7e00245f..dbb58d0b1 100644 --- a/tensor2tensor/models/image_transformer.py +++ b/tensor2tensor/models/image_transformer.py @@ -50,7 +50,7 @@ def body(self, features): tf.summary.image("targets", targets, max_outputs=1) # Prepare decoder inputs and bias. 
- decoder_input, rows, cols, bias = cia.prepare_decoder(targets, hparams) + decoder_input, rows, cols = cia.prepare_decoder(targets, hparams) # Add class label to decoder input. if not hparams.unconditional: decoder_input += tf.reshape( @@ -59,7 +59,6 @@ def body(self, features): decoder_output = cia.transformer_decoder_layers( decoder_input, None, - bias, hparams.num_decoder_layers or hparams.num_hidden_layers, hparams, attention_type=hparams.dec_attention_type, @@ -90,8 +89,8 @@ def body_sharded(self, sharded_features): kv_padding = "LEFT" # Prepare decoder inputs and bias. - decoder_input, rows, cols, attention_bias = dp(cia.prepare_decoder_inputs, - inputs, targets, hparams) + decoder_input, rows, cols = dp(cia.prepare_decoder_inputs, + inputs, targets, hparams) # Run decoder. decoder_output, extra_loss = cia.transformer_layers_sharded( @@ -100,7 +99,7 @@ def body_sharded(self, sharded_features): decoder_input, hparams.num_hidden_layers, hparams, - self_attention_bias=attention_bias, + self_attention_bias=None, enc_output=None, attention_type=hparams.dec_attention_type, q_padding=q_padding, @@ -243,6 +242,18 @@ def imagetransformer_base_8l_8h_big_cond_dr03_dan(): return hparams +@registry.register_hparams +def imagetransformer_base_10l_8h_big_uncond_dr03_dan_64(): + """big 1d model for unconditional generation on imagenet.""" + hparams = imagetransformer_base_10l_8h_big_cond_dr03_dan() + hparams.unconditional = True + hparams.max_length = 14000 + hparams.batch_size = 1 + hparams.img_len = 64 + hparams.layer_prepostprocess_dropout = 0.1 + return hparams + + @registry.register_hparams def imagetransformer_base_8l_8h_big_cond_dr03_dan_128(): hparams = imagetransformer_base_8l_8h_big_cond_dr03_dan() @@ -264,7 +275,6 @@ def imagetransformer_base_10l_8h_big_uncond_dr03_dan(): """Best unconditional Cifar10 gen param.""" hparams = imagetransformer_base_10l_8h_big_cond_dr03_dan() hparams.num_decoder_layers = 10 - hparams.unconditional = True return hparams diff --git a/tensor2tensor/models/image_transformer_2d.py b/tensor2tensor/models/image_transformer_2d.py index 046fa06ee..83166a937 100644 --- a/tensor2tensor/models/image_transformer_2d.py +++ b/tensor2tensor/models/image_transformer_2d.py @@ -50,7 +50,7 @@ def body(self, features): hparams.mode == tf.contrib.learn.ModeKeys.INFER): tf.summary.image("targets", targets, max_outputs=1) - decoder_input, rows, cols, bias = cia.prepare_decoder( + decoder_input, rows, cols = cia.prepare_decoder( targets, hparams) # Add class label to decoder input. 
if not hparams.unconditional: @@ -58,7 +58,7 @@ def body(self, features): [targets_shape[0], 1, 1, hparams.hidden_size]) decoder_output = cia.transformer_decoder_layers( - decoder_input, None, bias, + decoder_input, None, hparams.num_decoder_layers, hparams, attention_type=hparams.dec_attention_type, @@ -88,11 +88,11 @@ def body(self, features): hparams, attention_type=hparams.enc_attention_type, name="encoder") - decoder_input, rows, cols, bias = cia.prepare_decoder( + decoder_input, rows, cols = cia.prepare_decoder( targets, hparams) decoder_output = cia.transformer_decoder_layers( decoder_input, - encoder_output, bias, + encoder_output, hparams.num_decoder_layers, hparams, attention_type=hparams.dec_attention_type, @@ -230,6 +230,28 @@ def imagetransformer2d_base_8l_8_32_big(): return hparams +@registry.register_hparams +def imagetransformer_base_10l_8h_big_uncond_dr03_dan_64_2d(): + """big 1d model for unconditional generation on imagenet.""" + hparams = image_transformer2d_base() + hparams.unconditional = True + hparams.hidden_size = 512 + hparams.batch_size = 1 + hparams.img_len = 64 + hparams.num_heads = 8 + hparams.filter_size = 2048 + hparams.batch_size = 1 + hparams.max_length = 3075 + hparams.max_length = 14000 + hparams.layer_preprocess_sequence = "none" + hparams.layer_postprocess_sequence = "dan" + hparams.layer_prepostprocess_dropout = 0.1 + hparams.dec_attention_type = cia.AttentionType.LOCAL_2D + hparams.query_shape = (16, 16) + hparams.memory_flange = (8, 8) + return hparams + + @registry.register_hparams def imagetransformer2d_base_8l_8_64_64by64(): """hparams fo 12 layer big 2d model for imagenet 64x64.""" diff --git a/tensor2tensor/models/research/rl.py b/tensor2tensor/models/research/rl.py index f552914c2..858d6964e 100644 --- a/tensor2tensor/models/research/rl.py +++ b/tensor2tensor/models/research/rl.py @@ -69,6 +69,7 @@ def discrete_action_base(): @registry.register_hparams def atari_base(): + """Atari base parameters.""" hparams = discrete_action_base() hparams.learning_rate = 16e-5 hparams.num_agents = 5 diff --git a/tensor2tensor/models/research/transformer_vae.py b/tensor2tensor/models/research/transformer_vae.py index 4b37528ea..a5bb3ff85 100644 --- a/tensor2tensor/models/research/transformer_vae.py +++ b/tensor2tensor/models/research/transformer_vae.py @@ -338,10 +338,7 @@ def ae_transformer_internal(inputs, targets, _ = common_layers.pad_to_same_length( targets, max_targets_len_from_inputs, final_length_divisible_by=2**hparams.num_compress_steps) - if hparams.ae_input: - targets_c = compress(targets, inputs, False, hparams, "compress") - else: - targets_c = compress(targets, None, False, hparams, "compress") + targets_c = compress(targets, inputs, False, hparams, "compress") if hparams.mode != tf.estimator.ModeKeys.PREDICT: # Compress and bottleneck. 
latents_dense, latents_discrete, extra_loss, embed = hparams.bottleneck( @@ -638,8 +635,6 @@ def transformer_ae_small(): # Reshape method for DVQ: slice, project hparams.add_hparam("reshape_method", "slice") hparams.add_hparam("trainable_projections", False) - # Add option to pass the input to the autoencoder - hparams.add_hparam("ae_input", False) # Hparams for Dirichlet process process hparams.add_hparam("dp_alpha", 0.5) hparams.add_hparam("dp_strength", 0.25) diff --git a/tensor2tensor/models/transformer.py b/tensor2tensor/models/transformer.py index 09b252291..2cddbee8e 100644 --- a/tensor2tensor/models/transformer.py +++ b/tensor2tensor/models/transformer.py @@ -176,7 +176,8 @@ def body(self, features): expected_attentions = features.get("expected_attentions") if expected_attentions is not None: attention_loss = common_attention.encoder_decoder_attention_loss( - expected_attentions, self.attention_weights) + expected_attentions, self.attention_weights, + hparams.expected_attention_loss_multiplier) return decoder_output, {"attention_loss": attention_loss} return decoder_output @@ -1462,3 +1463,11 @@ def transformer_librispeech_tpu(): librispeech.set_librispeech_length_hparams(hparams) return hparams + +@registry.register_hparams +def transformer_supervised_attention(): + """Hparams for supervised attention problems.""" + hparams = transformer_base() + # Multiplier to the encoder-decoder expected attention loss. + hparams.add_hparam("expected_attention_loss_multiplier", 1.0) + return hparams diff --git a/tensor2tensor/models/transformer_test.py b/tensor2tensor/models/transformer_test.py index 88581ac87..53e4616b9 100644 --- a/tensor2tensor/models/transformer_test.py +++ b/tensor2tensor/models/transformer_test.py @@ -208,7 +208,8 @@ def testTransformerWithoutProblem(self): [BATCH_SIZE, TARGET_LENGTH, 1, hparams.hidden_size]) def testTransformerWithEncoderDecoderAttentionLoss(self): - model, features = self.getModel(transformer.transformer_small()) + model, features = self.getModel( + transformer.transformer_supervised_attention()) expected_attention_weights = np.random.random_sample( size=(BATCH_SIZE, TARGET_LENGTH, INPUT_LENGTH)) features["expected_attentions"] = tf.constant( diff --git a/tensor2tensor/models/xception_test.py b/tensor2tensor/models/xception_test.py index 8e29b3bd0..90d2ba9fb 100644 --- a/tensor2tensor/models/xception_test.py +++ b/tensor2tensor/models/xception_test.py @@ -32,7 +32,7 @@ class XceptionTest(tf.test.TestCase): - def _testXception(self, img_size, output_size): + def _testXception(self, img_size): vocab_size = 9 batch_size = 3 x = np.random.random_integers( @@ -42,6 +42,7 @@ def _testXception(self, img_size, output_size): hparams = xception.xception_tiny() p_hparams = problem_hparams.test_problem_hparams(vocab_size, vocab_size) p_hparams.input_modality["inputs"] = (registry.Modalities.IMAGE, None) + p_hparams.target_modality = (registry.Modalities.CLASS_LABEL, vocab_size) with self.test_session() as session: features = { "inputs": tf.constant(x, dtype=tf.int32), @@ -51,13 +52,13 @@ def _testXception(self, img_size, output_size): logits, _ = model(features) session.run(tf.global_variables_initializer()) res = session.run(logits) - self.assertEqual(res.shape, output_size + (1, vocab_size)) + self.assertEqual(res.shape, (batch_size, 1, 1, 1, vocab_size)) - def testXceptionSmall(self): - self._testXception(img_size=9, output_size=(3, 5, 5)) + def testXceptionSmallImage(self): + self._testXception(img_size=9) - def testXceptionLarge(self): - 
self._testXception(img_size=256, output_size=(3, 8, 8)) + def testXceptionLargeImage(self): + self._testXception(img_size=256) if __name__ == "__main__": diff --git a/tensor2tensor/rl/README.md b/tensor2tensor/rl/README.md index 9ff2a0f71..46e40403f 100644 --- a/tensor2tensor/rl/README.md +++ b/tensor2tensor/rl/README.md @@ -20,4 +20,3 @@ Currently the only supported algorithm is Proximy Policy Optimization - PPO. ```python tensor2tensor/bin/t2t-trainer --generate_data --data_dir=~/t2t_data --problems=gym_pong_trajectories_from_policy --hparams_set=base_atari --model_path [model]``` ```python tensor2tensor/bin/t2t-datagen --data_dir=~/t2t_data --tmp_dir=~/t2t_data/tmp --problem=gym_pong_trajectories_from_policy --model_path [model]``` - diff --git a/tensor2tensor/rl/envs/atari_wrappers.py b/tensor2tensor/rl/envs/atari_wrappers.py index 20c6a81dc..b8dd425ec 100644 --- a/tensor2tensor/rl/envs/atari_wrappers.py +++ b/tensor2tensor/rl/envs/atari_wrappers.py @@ -1,37 +1,65 @@ -# Copied from baselines -# https://github.com/openai/baselines/blob/master/baselines/common/atari_wrappers.py +# coding=utf-8 +# Copyright 2018 The Tensor2Tensor Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Various wrappers copied for Gym Baselines.""" -# Various wrappers copied from Baselines -import gym -import numpy as np from collections import deque import gym -from gym import spaces +import numpy as np + + +# Adapted from the link below. +# https://github.com/openai/baselines/blob/master/baselines/common/atari_wrappers.py class WarpFrame(gym.ObservationWrapper): + """Wrap a frame.""" + def __init__(self, env): """Warp frames to 84x84 as done in the Nature paper and later work.""" gym.ObservationWrapper.__init__(self, env) self.width = 84 self.height = 84 - self.observation_space = spaces.Box(low=0, high=255, - shape=(self.height, self.width, 1), dtype=np.uint8) + self.observation_space = gym.spaces.Box( + low=0, high=255, + shape=(self.height, self.width, 1), dtype=np.uint8) def observation(self, frame): - import cv2 + import cv2 # pylint: disable=g-import-not-at-top cv2.ocl.setUseOpenCL(False) frame = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY) - frame = cv2.resize(frame, (self.width, self.height), interpolation=cv2.INTER_AREA) + frame = cv2.resize(frame, (self.width, self.height), + interpolation=cv2.INTER_AREA) return frame[:, :, None] + class LazyFrames(object): + """Lazy frame storage.""" + def __init__(self, frames): - """This object ensures that common frames between the observations are only stored once. - It exists purely to optimize memory usage which can be huge for DQN's 1M frames replay - buffers. - This object should only be converted to numpy array before being passed to the model. - You'd not believe how complex the previous solution was.""" + """Lazy frame storage. + + This object ensures that common frames between the observations + are only stored once. It exists purely to optimize memory usage + which can be huge for DQN's 1M frames replay buffers. 
+ This object should only be converted to numpy array before being passed + to the model. + + Args: + frames: the frames. + """ self._frames = frames def __array__(self, dtype=None): @@ -40,19 +68,18 @@ def __array__(self, dtype=None): out = out.astype(dtype) return out + class FrameStack(gym.Wrapper): + """Stack frames.""" + def __init__(self, env, k): - """Stack k last frames. - Returns lazy array, which is much more memory efficient. - See Also - -------- - baselines.common.atari_wrappers.LazyFrames - """ + """Stack k last frames. Returns lazy array, memory efficient.""" gym.Wrapper.__init__(self, env) self.k = k self.frames = deque([], maxlen=k) shp = env.observation_space.shape - self.observation_space = spaces.Box(low=0, high=255, shape=(shp[0], shp[1], shp[2] * k), dtype=np.uint8) + self.observation_space = gym.spaces.Box( + low=0, high=255, shape=(shp[0], shp[1], shp[2] * k), dtype=np.uint8) def reset(self): ob = self.env.reset() @@ -69,16 +96,20 @@ def _get_ob(self): assert len(self.frames) == self.k return LazyFrames(list(self.frames)) + class MaxAndSkipEnv(gym.Wrapper): + """Max and skip env.""" + def __init__(self, env, skip=4): - """Return only every `skip`-th frame""" + """Return only every `skip`-th frame.""" gym.Wrapper.__init__(self, env) - # most recent raw observations (for max pooling across time steps) - self._obs_buffer = np.zeros((2,) + env.observation_space.shape, dtype=np.uint8) + # Most recent raw observations (for max pooling across time steps). + self._obs_buffer = np.zeros((2,) + env.observation_space.shape, + dtype=np.uint8) self._skip = skip - def reset(self): - return self.env.reset() + def reset(self, **kwargs): + return self.env.reset(**kwargs) def step(self, action): """Repeat action, sum reward, and max over last observations.""" @@ -97,8 +128,6 @@ def step(self, action): return max_frame, total_reward, done, info - def reset(self, **kwargs): - return self.env.reset(**kwargs) def wrap_atari(env, warp=False, frame_skip=False, frame_stack=False): if warp: diff --git a/tensor2tensor/rl/rl_trainer_lib.py b/tensor2tensor/rl/rl_trainer_lib.py index 3d243a97b..3193b7044 100644 --- a/tensor2tensor/rl/rl_trainer_lib.py +++ b/tensor2tensor/rl/rl_trainer_lib.py @@ -77,8 +77,9 @@ def define_train(hparams, environment_spec, event_dir): def train(hparams, environment_spec, event_dir=None): """Train.""" if environment_spec == "stacked_pong": - environment_spec = lambda: atari_wrappers.wrap_atari( - gym.make("PongNoFrameskip-v4"), warp=False, frame_skip=4, frame_stack=False) + environment_spec = lambda: atari_wrappers.wrap_atari( # pylint: disable=g-long-lambda + gym.make("PongNoFrameskip-v4"), + warp=False, frame_skip=4, frame_stack=False) train_summary_op, eval_summary_op, _ = define_train(hparams, environment_spec, event_dir) if event_dir: @@ -100,5 +101,7 @@ def train(hparams, environment_spec, event_dir=None): summary = sess.run(eval_summary_op) if summary_writer: summary_writer.add_summary(summary, epoch_index) - if model_saver and hparams.save_models_every_epochs and epoch_index % hparams.save_models_every_epochs == 0: - model_saver.save(sess, os.path.join(event_dir, "model{}.ckpt".format(epoch_index))) + if (model_saver and hparams.save_models_every_epochs and + epoch_index % hparams.save_models_every_epochs == 0): + model_saver.save(sess, os.path.join(event_dir, + "model{}.ckpt".format(epoch_index))) diff --git a/tensor2tensor/utils/cloud_mlengine.py b/tensor2tensor/utils/cloud_mlengine.py index 9d0cc0f4a..e3993717a 100755 --- 
a/tensor2tensor/utils/cloud_mlengine.py
+++ b/tensor2tensor/utils/cloud_mlengine.py
@@ -89,41 +89,23 @@ def flags_as_args():
   return args
 
 
-def machine_config(num_gpus=1, use_tpu=False, master_type=None):
-  """Return dict specifying machine config for trainingInput."""
+def get_default_master_type(num_gpus=1, use_tpu=False):
+  """Returns master_type for trainingInput."""
   if use_tpu:
-    master_type = 'standard_tpu'
+    return 'standard_tpu'
   elif num_gpus <= 0:
-    master_type = master_type or 'standard'
-    cpu_types = ['standard', 'large_model', 'complex_model_s',
-                 'complex_model_m', 'complex_model_l']
-    if master_type not in cpu_types:
-      raise ValueError('Expected `cloudml_engine_master_type` to be one of %s '
-                       'when `worker_gpu` <= 0, found %s.', str(cpu_types),
-                       master_type)
-  elif num_gpus >= 1:
-    if num_gpus == 1:
-      if master_type != 'standard_gpu':
-        master_type = 'standard_p100'
-    elif num_gpus == 4:
-      if master_type != 'complex_model_m_gpu':
-        master_type = 'complex_model_m_p100'
-    elif num_gpus == 8:
-      master_type = 'complex_model_l_gpu'
-    else:
-      raise ValueError('Must use exactly 1, 4, or 8 GPUs.')
-  assert master_type
-  return {
-      'scaleTier': 'CUSTOM',
-      'masterType': master_type
-  }
+    return 'standard'
+  elif num_gpus == 1:
+    return 'standard_p100'
+  elif num_gpus == 4:
+    return 'complex_model_m_p100'
+  elif num_gpus == 8:
+    return 'complex_model_l_gpu'
+  assert False
 
 
 def configure_job():
   """Construct jobSpec for ML Engine job."""
-  train_dir = FLAGS.output_dir
-  assert train_dir.startswith('gs://')
-
   # See documentation:
   # https://cloud.google.com/ml-engine/reference/rest/v1/projects.jobs#traininginput
   training_input = {
@@ -132,15 +114,13 @@ def configure_job():
       'region': cloud.default_region(),
       'runtimeVersion': '1.4',
       'pythonVersion': '3.5' if sys.version_info.major == 3 else '2.7',
-      'jobDir': train_dir,
-  }
-  training_input.update(
-      machine_config(
+      'jobDir': FLAGS.output_dir,
+      'scaleTier': 'CUSTOM',
+      'masterType': FLAGS.cloud_mlengine_master_type or get_default_master_type(
           num_gpus=FLAGS.worker_gpu,
-          use_tpu=FLAGS.use_tpu,
-          master_type=FLAGS.cloud_mlengine_master_type))
+          use_tpu=FLAGS.use_tpu)
+  }
   if FLAGS.hparams_range:
-    assert FLAGS.autotune_objective
     tf.logging.info('Configuring hyperparameter tuning.')
     training_input['hyperparameters'] = configure_autotune(
         FLAGS.hparams_range,
@@ -278,15 +258,40 @@ def configure_usr_dir(job_spec, usr_tar):
   job_spec['trainingInput']['args'].extend(usr_args)
 
 
-def launch():
-  """Launch t2t_trainer on Cloud ML Engine."""
+def validate_flags():
+  """Validates flags are set to acceptable values for CloudML Engine runs."""
   assert not FLAGS.cloud_tpu
   assert not job_dir()
   assert FLAGS.output_dir.startswith('gs://')
   assert FLAGS.data_dir.startswith('gs://')
   assert FLAGS.worker_replicas <= 1
   assert FLAGS.ps_replicas <= 0
+  if FLAGS.hparams_range:
+    assert FLAGS.autotune_objective
+  if FLAGS.worker_gpu:
+    assert FLAGS.worker_gpu in [1, 4, 8]
+  if FLAGS.cloud_mlengine_master_type:
+    if FLAGS.use_tpu:
+      assert FLAGS.cloud_mlengine_master_type == 'standard_tpu'
+    elif FLAGS.worker_gpu:
+      if FLAGS.worker_gpu == 1:
+        assert FLAGS.cloud_mlengine_master_type in ['standard_gpu',
+                                                    'standard_p100']
+      elif FLAGS.worker_gpu == 4:
+        assert FLAGS.cloud_mlengine_master_type in ['complex_model_m_gpu',
+                                                    'complex_model_m_p100']
+      else:
+        assert FLAGS.cloud_mlengine_master_type == 'complex_model_l_gpu'
+    else:
+      assert FLAGS.cloud_mlengine_master_type in ['standard', 'large_model',
+                                                  'complex_model_s',
+                                                  'complex_model_m',
+                                                  'complex_model_l']
+
+def 
launch(): + """Launch t2t_trainer on Cloud ML Engine.""" + validate_flags() job_spec = configure_job() job_name = job_spec['jobId'] tf.logging.info('Launching job %s with ML Engine spec:\n%s', job_name,