This repository was archived by the owner on Jul 7, 2023. It is now read-only.
Merged
Changes from 1 commit
Commits (25)
ca967c1  Internal change (Mar 2, 2018)
908981e  Set evaluation_master to master in RunConfig (Mar 2, 2018)
e2e61ac  Add option to do soft EM instead of hard EM (a-googler, Mar 5, 2018)
f92d901  Add an inv_temp hparam for controlling softness of EM (a-googler, Mar 6, 2018)
f36f82c  More Librispeech subsets to help with mixed clean and noisy data trai… (a-googler, Mar 6, 2018)
40d1f15  proper em - P(c_i) is computed using ema_count instead of actual counts (a-googler, Mar 6, 2018)
8320faf  increase pseudo-count to 1.0 and now there's no NaN in training (a-googler, Mar 7, 2018)
d83d992  Use logits instead of probs to compute supervised attention loss. (a-googler, Mar 8, 2018)
7056827  Why do we need stop gradient here? (a-googler, Mar 8, 2018)
e1e8fbb  Add expected_attention_loss_type hparam to supervised_attention to al… (a-googler, Mar 8, 2018)
5ee776d  ema_count trainable should be False; this was causing the weird dp be… (a-googler, Mar 8, 2018)
7293efc  Fix multi-logit loss computation error. (aidangomez, Mar 8, 2018)
75d2aef  Basic autoencoder and improvements in image modality. (Mar 9, 2018)
c4e6fab  Change batch size for hparam config (Mar 9, 2018)
9ae5bc2  Make Vanilla GAN work, based on Compare GAN code. (Mar 9, 2018)
1568e9b  internal (Mar 9, 2018)
95053b4  Bump release number. (Mar 9, 2018)
9a638df  Documentation for cloud TPU for Image Transformer. Additional default… (Mar 9, 2018)
3de51ab  Add smoothed L0 prior and trainable logits for cluster probabilities. (Mar 9, 2018)
6a6d9fe  Added the ema count smoothing update inside the else. (Mar 9, 2018)
6e846f2  Make text_encoder unicode conversion a pass-through (Mar 9, 2018)
d8080a1  Pass in decode_hp to _interactive_input_fn and remove summaries when (a-googler, Mar 10, 2018)
c7495b5  six.iteritems for Py3 (Mar 10, 2018)
329123f  Update Travis tests for Py3 to run TF 1.6 (Mar 10, 2018)
688f4d5  Update ISSUE_TEMPLATE (Mar 10, 2018)
Internal change
PiperOrigin-RevId: 187670348
Ryan Sepassi committed Mar 9, 2018
commit ca967c155d6e1976841bc286ef42066d15f3641c
22 changes: 0 additions & 22 deletions .github/ISSUE_TEMPLATE.md

This file was deleted.

1 change: 1 addition & 0 deletions .travis.yml
@@ -41,6 +41,7 @@ script:
--ignore=tensor2tensor/problems_test.py
--ignore=tensor2tensor/bin/t2t_trainer_test.py
--ignore=tensor2tensor/data_generators/algorithmic_math_test.py
--ignore=tensor2tensor/rl/rl_trainer_lib_test.py
- pytest tensor2tensor/utils/registry_test.py
- pytest tensor2tensor/utils/trainer_lib_test.py
- pytest tensor2tensor/visualization/visualization_test.py
2 changes: 1 addition & 1 deletion README.md
@@ -154,7 +154,7 @@ For all translation problems, we suggest to try the Transformer model:
this should reach a BLEU score of about 28 on the English-German data-set,
which is close to state-of-the art. If training on a single GPU, try the
`--hparams_set=transformer_base_single_gpu` setting. For very good results
or larger data-sets (e.g., for English-French), try the big model
or larger data-sets (e.g., for English-French)m, try the big model
with `--hparams_set=transformer_big`.

## Basics
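The README excerpt above names the hparams sets to reach for on translation problems. Those `--hparams_set` values correspond to registered hparams functions in `tensor2tensor/models/transformer.py`; a minimal sketch of inspecting one directly, assuming that module layout, is:

```python
from tensor2tensor.models import transformer

# --hparams_set selects a registered hparams function by name; calling the
# function directly yields the same HParams object the trainer would build.
hparams = transformer.transformer_base_single_gpu()
print(hparams.hidden_size, hparams.batch_size, hparams.learning_rate)
```

Under the same assumption, the big model recommended for English-French would be `transformer.transformer_big()`.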
2 changes: 1 addition & 1 deletion docs/distributed_training.md
@@ -5,7 +5,7 @@ training.

T2T uses TensorFlow Estimators and so distributed training is configured with
the `TF_CONFIG` environment variable that is read by the
[RunConfig](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/estimator/run_config.py)
[RunConfig](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/learn/python/learn/estimators/run_config.py)
along with a set of flags.

## `TF_CONFIG`
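The doc excerpt above configures distributed training through the `TF_CONFIG` environment variable that `RunConfig` reads. A minimal sketch of what one process might export, assuming the standard Estimator cluster/task layout (hosts, ports, and the task assignment are placeholders, not values from this PR):

```python
import json
import os

# Hypothetical cluster layout; addresses and ports are placeholders.
cluster = {
    "master": ["10.0.0.1:2222"],
    "ps": ["10.0.0.2:2222"],
    "worker": ["10.0.0.3:2222", "10.0.0.4:2222"],
}

# Each process exports its own TF_CONFIG before launching training so that
# RunConfig can learn the cluster layout and this process's role in it.
os.environ["TF_CONFIG"] = json.dumps({
    "cluster": cluster,
    "task": {"type": "worker", "index": 0},
})
```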
6 changes: 3 additions & 3 deletions docs/index.md
@@ -69,7 +69,7 @@ For language modeling, we have these data-sets in T2T:
* LM1B (a billion-word corpus): `--problems=languagemodel_lm1b32k` for
subword-level modeling and `--problems=languagemodel_lm1b_characters`
for character-level modeling.

We suggest to start with `--model=transformer` on this task and use
`--hparams_set=transformer_small` for PTB and
`--hparams_set=transformer_base` for LM1B.
@@ -95,7 +95,7 @@ For speech-to-text, we have these data-sets in T2T:
For summarizing longer text into shorter one we have these data-sets:
* CNN/DailyMail articles summarized into a few sentences:
`--problems=summarize_cnn_dailymail32k`

We suggest to use `--model=transformer` and
`--hparams_set=transformer_prepend` for this task.
This yields good ROUGE scores.
@@ -118,5 +118,5 @@ For all translation problems, we suggest to try the Transformer model:
this should reach a BLEU score of about 28 on the English-German data-set,
which is close to state-of-the art. If training on a single GPU, try the
`--hparams_set=transformer_base_single_gpu` setting. For very good results
or larger data-sets (e.g., for English-French), try the big model
or larger data-sets (e.g., for English-French)m, try the big model
with `--hparams_set=transformer_big`.
4 changes: 2 additions & 2 deletions tensor2tensor/bin/t2t_trainer.py
@@ -328,12 +328,12 @@ def main(argv):
if argv:
set_hparams_from_args(argv[1:])
hparams = create_hparams()
if is_chief():
save_metadata(hparams)

with maybe_cloud_tpu():
exp_fn = create_experiment_fn()
exp = exp_fn(create_run_config(hparams), hparams)
if is_chief():
save_metadata(hparams)
execute_schedule(exp)


47 changes: 15 additions & 32 deletions tensor2tensor/data_generators/gym.py
@@ -35,6 +35,8 @@
import tensorflow as tf




flags = tf.flags
FLAGS = flags.FLAGS

@@ -48,17 +50,6 @@ def __init__(self, *args, **kwargs):
super(GymDiscreteProblem, self).__init__(*args, **kwargs)
self._env = None

def example_reading_spec(self, label_repr=None):

data_fields = {
"inputs": tf.FixedLenFeature([210, 160, 3], tf.int64),
"inputs_prev": tf.FixedLenFeature([210, 160, 3], tf.int64),
"targets": tf.FixedLenFeature([210, 160, 3], tf.int64),
"action": tf.FixedLenFeature([1], tf.int64)
}

return data_fields, None

@property
def env_name(self):
# This is the name of the Gym environment for this problem.
@@ -142,7 +133,7 @@ class GymPongRandom5k(GymDiscreteProblem):

@property
def env_name(self):
return "PongNoFrameskip-v4"
return "Pong-v0"

@property
def num_actions(self):
@@ -157,30 +148,21 @@ def num_steps(self):
return 5000



@registry.register_problem
class GymPongTrajectoriesFromPolicy(GymDiscreteProblem):
"""Pong game, loaded actions."""

def __init__(self, *args, **kwargs):
def __init__(self, event_dir, *args, **kwargs):
super(GymPongTrajectoriesFromPolicy, self).__init__(*args, **kwargs)
self._env = None
self._last_policy_op = None
self._max_frame_pl = None
self._last_action = self.env.action_space.sample()
self._skip = 4
self._skip_step = 0
self._obs_buffer = np.zeros((2,) + self.env.observation_space.shape,
dtype=np.uint8)

def generator(self, data_dir, tmp_dir):
self._event_dir = event_dir
env_spec = lambda: atari_wrappers.wrap_atari( # pylint: disable=g-long-lambda
gym.make("PongNoFrameskip-v4"),
warp=False,
frame_skip=4,
frame_stack=False)
hparams = rl.atari_base()
with tf.variable_scope("train", reuse=tf.AUTO_REUSE):
with tf.variable_scope("train"):
policy_lambda = hparams.network
policy_factory = tf.make_template(
"network",
@@ -191,13 +173,14 @@ def generator(self, data_dir, tmp_dir):
self._max_frame_pl, 0), 0))
policy = actor_critic.policy
self._last_policy_op = policy.mode()
with tf.Session() as sess:
model_saver = tf.train.Saver(
tf.global_variables(".*network_parameters.*"))
model_saver.restore(sess, FLAGS.model_path)
for item in super(GymPongTrajectoriesFromPolicy,
self).generator(data_dir, tmp_dir):
yield item
self._last_action = self.env.action_space.sample()
self._skip = 4
self._skip_step = 0
self._obs_buffer = np.zeros((2,) + self.env.observation_space.shape,
dtype=np.uint8)
self._sess = tf.Session()
model_saver = tf.train.Saver(tf.global_variables(".*network_parameters.*"))
model_saver.restore(self._sess, FLAGS.model_path)

# TODO(blazej0): For training of atari agents wrappers are usually used.
# Below we have a hacky solution which is a workaround to be used together
@@ -208,7 +191,7 @@ def get_action(self, observation=None):
self._skip_step = (self._skip_step + 1) % self._skip
if self._skip_step == 0:
max_frame = self._obs_buffer.max(axis=0)
self._last_action = int(tf.get_default_session().run(
self._last_action = int(self._sess.run(
self._last_policy_op,
feed_dict={self._max_frame_pl: max_frame})[0, 0])
return self._last_action
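To make the `get_action` logic above easier to follow, here is a minimal NumPy-only sketch of the same frame-skip pattern: the policy is consulted once every `skip` observations on the pixel-wise max of the two most recently buffered frames, and the last action is repeated in between. The `policy_fn` callable, observation shape, and buffer-indexing detail are assumptions for illustration, not code from this PR.

```python
import numpy as np


class FrameSkipActionSelector(object):
    """Illustrative stand-in for the buffering in GymPongTrajectoriesFromPolicy."""

    def __init__(self, policy_fn, obs_shape, num_actions, skip=4):
        self._policy_fn = policy_fn  # maps one max-pooled frame to an action id
        self._skip = skip
        self._skip_step = 0
        self._obs_buffer = np.zeros((2,) + obs_shape, dtype=np.uint8)
        self._last_action = np.random.randint(num_actions)

    def get_action(self, observation):
        # Keep only the two most recent raw frames (assumed indexing scheme).
        self._obs_buffer[self._skip_step % 2] = observation
        self._skip_step = (self._skip_step + 1) % self._skip
        if self._skip_step == 0:
            # Max-pool the buffered frames to remove Atari flicker, then
            # query the policy for a fresh action.
            max_frame = self._obs_buffer.max(axis=0)
            self._last_action = int(self._policy_fn(max_frame))
        # Between policy queries the previous action is repeated.
        return self._last_action
```

With a dummy policy such as `lambda frame: 0`, the policy is consulted on every fourth call and its action is returned for the calls in between, mirroring the `_skip`/`_obs_buffer` fields used above.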
2 changes: 1 addition & 1 deletion tensor2tensor/data_generators/imagenet.py
@@ -334,7 +334,7 @@ def distorted_bounding_box_crop(image,
Returns:
(cropped image `Tensor`, distorted bbox `Tensor`).
"""
with tf.name_scope(scope, default_name="distorted_bounding_box_crop", values=[image, bbox]):
with tf.name_scope(scope, "distorted_bounding_box_crop", [image, bbox]):
# Each bounding box has shape [1, num_boxes, box coords] and
# the coordinates are ordered [ymin, xmin, ymax, xmax].

69 changes: 26 additions & 43 deletions tensor2tensor/data_generators/ptb.py
@@ -77,37 +77,6 @@ def _get_token_encoder(vocab_dir, vocab_name, filename):
return text_encoder.TokenTextEncoder(vocab_path)


def _maybe_download_corpus(tmp_dir, vocab_type):
"""Download and unpack the corpus.

Args:
tmp_dir: directory containing dataset.
"""
filename = os.path.basename(PTB_URL)
compressed_filepath = generator_utils.maybe_download(
tmp_dir, filename, PTB_URL)
ptb_files = []
ptb_char_files = []

with tarfile.open(compressed_filepath, "r:gz") as tgz:
files = []
# Selecting only relevant files.
for m in tgz.getmembers():
if "ptb" in m.name and ".txt" in m.name:
if "char" in m.name:
ptb_char_files += [m.name]
else:
ptb_files += [m.name]
files += [m]

tgz.extractall(tmp_dir, members=files)

if vocab_type == text_problems.VocabType.CHARACTER:
return ptb_char_files
else:
return ptb_files


@registry.register_problem
class LanguagemodelPtb10k(text_problems.Text2SelfProblem):
"""PTB, 10k vocab."""
@@ -122,10 +91,6 @@ def dataset_splits(self):
"shards": 1,
}]

@property
def is_generate_per_split(self):
return True

@property
def vocab_filename(self):
return "vocab.lmptb.10000"
@@ -135,7 +100,28 @@ def vocab_type(self):
return text_problems.VocabType.TOKEN

def generate_samples(self, data_dir, tmp_dir, dataset_split):
files = _maybe_download_corpus(tmp_dir, self.vocab_type)
filename = os.path.basename(PTB_URL)
compressed_filepath = generator_utils.maybe_download(
tmp_dir, filename, PTB_URL)
ptb_files = []
ptb_char_files = []
with tarfile.open(compressed_filepath, "r:gz") as tgz:
files = []
# Selecting only relevant files.
for m in tgz.getmembers():
if "ptb" in m.name and ".txt" in m.name:
if "char" in m.name:
ptb_char_files += [m.name]
else:
ptb_files += [m.name]
files += [m]

tgz.extractall(tmp_dir, members=files)

if self.vocab_type == text_problems.VocabType.CHARACTER:
files = ptb_char_files
else:
files = ptb_files

train_file, valid_file = None, None
for filename in files:
@@ -152,13 +138,10 @@ def generate_samples(self, data_dir, tmp_dir, dataset_split):
train = dataset_split == problem.DatasetSplit.TRAIN
filepath = train_file if train else valid_file

def _generate_samples():
with tf.gfile.GFile(filepath, "r") as f:
for line in f:
line = " ".join(line.replace("\n", " %s " % EOS).split())
yield {"targets": line}

return _generate_samples()
with tf.gfile.GFile(filepath, "r") as f:
for line in f:
line = " ".join(line.replace("\n", " %s " % EOS).split())
yield {"targets": line}


@registry.register_problem
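As a quick check of the line handling at the end of `generate_samples` above, the snippet below walks one sample sentence through the same transform; the sentence and the `EOS` literal are assumptions for the example (in the module, `EOS` comes from `text_encoder`).

```python
EOS = "<EOS>"  # stand-in for the module-level EOS constant

raw = " no it was n't black monday \n"
# Collapse runs of whitespace and replace the newline with the EOS marker,
# mirroring the line built inside generate_samples.
targets = " ".join(raw.replace("\n", " %s " % EOS).split())
print({"targets": targets})  # {'targets': "no it was n't black monday <EOS>"}
```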
22 changes: 11 additions & 11 deletions tensor2tensor/layers/common_layers.py
@@ -76,7 +76,7 @@ def comma_separated_string_to_integer_list(s):

def saturating_sigmoid(x):
"""Saturating sigmoid: 1.2 * sigmoid(x) - 0.1 cut to [0, 1]."""
with tf.name_scope("saturating_sigmoid", values=[x]):
with tf.name_scope("saturating_sigmoid", [x]):
y = tf.sigmoid(x)
return tf.minimum(1.0, tf.maximum(0.0, 1.2 * y - 0.1))

@@ -173,7 +173,7 @@ def shakeshake(xs, equal_grad=False):

def convert_rgb_to_real(x):
"""Conversion of pixel values to real numbers."""
with tf.name_scope("rgb_to_real", values=[x]):
with tf.name_scope("rgb_to_real", [x]):
x = tf.to_float(x)
# Use the formula (value/128) - 1 to convert each channel value into a
# real number in the range -1 to 1.
@@ -795,7 +795,7 @@ def subseparable_conv_block(inputs, filters, dilation_rates_and_kernel_sizes,

def pool(inputs, window_size, pooling_type, padding, strides=(1, 1)):
"""Pooling (supports "LEFT")."""
with tf.name_scope("pool", values=[inputs]):
with tf.name_scope("pool", [inputs]):
static_shape = inputs.get_shape()
if not static_shape or len(static_shape) != 4:
raise ValueError("Inputs to conv must have statically known rank 4.")
@@ -950,7 +950,7 @@ def simple_attention(target, source, bias=None):
Returns:
a `Tensor` with same shape as `target`
"""
with tf.name_scope("simple_attention", values=[target, source]):
with tf.name_scope("simple_attention", [target, source]):
target_shape = shape_list(target)
source_shape = shape_list(source)
target = tf.reshape(
@@ -1516,7 +1516,7 @@ def pad_to_same_length(x, y, final_length_divisible_by=1, axis=1):
"""Pad tensors x and y on axis 1 so that they have the same length."""
if axis not in [1, 2]:
raise ValueError("Only axis=1 and axis=2 supported for now.")
with tf.name_scope("pad_to_same_length", values=[x, y]):
with tf.name_scope("pad_to_same_length", [x, y]):
x_length = shape_list(x)[axis]
y_length = shape_list(y)[axis]
max_length = tf.maximum(x_length, y_length)
@@ -1551,7 +1551,7 @@ def padding_list(length_diff, arg):

def pad_with_zeros(logits, labels):
"""Pad labels on the length dimension to match logits length."""
with tf.name_scope("pad_with_zeros", values=[logits, labels]):
with tf.name_scope("pad_with_zeros", [logits, labels]):
logits, labels = pad_to_same_length(logits, labels)
if len(labels.shape.as_list()) == 3: # 2-d labels.
logits, labels = pad_to_same_length(logits, labels, axis=2)
@@ -1645,7 +1645,7 @@ def padded_cross_entropy(logits,
reduce_sum=reduce_sum)
confidence = 1.0 - label_smoothing
vocab_size = shape_list(logits)[-1]
with tf.name_scope("padded_cross_entropy", values=[logits, labels]):
with tf.name_scope("padded_cross_entropy", [logits, labels]):
if len(logits.get_shape().as_list()) == 2:
# Deal with the case where we did not insert extra dimensions due to
# TPU issues. No pad-to-same-length happens in this case.
@@ -1679,7 +1679,7 @@ def smoothing_cross_entropy(logits,
Returns:

"""
with tf.name_scope("smoothing_cross_entropy", values=[logits, labels]):
with tf.name_scope("smoothing_cross_entropy", [logits, labels]):
# Low confidence is given to all non-true labels, uniformly.
low_confidence = (1.0 - confidence) / tf.to_float(vocab_size - 1)
# Normalizing constant is the best cross-entropy value with soft targets.
@@ -1726,7 +1726,7 @@ def global_pool_1d(inputs, pooling_type="MAX", mask=None):
output: A tensor of dimensions batch_size x input_dims
dimension containing the sequences of transformed vectors.
"""
with tf.name_scope("global_pool", values=[inputs]):
with tf.name_scope("global_pool", [inputs]):
if mask is not None:
mask = tf.expand_dims(mask, axis=2)
inputs = tf.multiply(inputs, mask)
@@ -1763,7 +1763,7 @@ def running_global_pool_1d(inputs, pooling_type="MAX"):
dimension containing the running 'totals'.
"""
del pooling_type
with tf.name_scope("running_global_pool", values=[inputs]):
with tf.name_scope("running_global_pool", [inputs]):
scan_fct = tf.maximum
# Permute inputs so seq_length is first.
elems = tf.transpose(inputs, [1, 0, 2])
@@ -2119,7 +2119,7 @@ def padded_cross_entropy_factored(factored_logits,
a = factored_logits.a
b = factored_logits.b
confidence = 1.0 - label_smoothing
with tf.name_scope("padded_cross_entropy_factored", values=[a, b, labels]):
with tf.name_scope("padded_cross_entropy_factored", [a, b, labels]):
labels_flat = tf.reshape(labels, [-1])
a_flat = tf.reshape(a, [-1, shape_list(b)[1]])
xent = smoothing_cross_entropy_factored(a_flat, b, labels_flat,
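Several hunks above toggle between `tf.name_scope(name, values=[...])` and the positional form `tf.name_scope(name, [...])`. Assuming the TF 1.x signature `tf.name_scope(name, default_name=None, values=None)`, the sketch below puts the two call styles side by side: in the positional form the list lands in `default_name`, which is ignored when `name` is given, while the keyword form passes the tensors as `values` so `name_scope` can pick the owning graph.

```python
import tensorflow as tf  # TF 1.x assumed, matching the Travis config in this PR

x = tf.constant([1.0, 2.0])

# Keyword form: the tensors are passed as `values`.
with tf.name_scope("saturating_sigmoid", values=[x]) as scope_a:
    y_a = tf.minimum(1.0, tf.maximum(0.0, 1.2 * tf.sigmoid(x) - 0.1))

# Positional form: the second argument is `default_name`, unused here
# because an explicit `name` is supplied.
with tf.name_scope("saturating_sigmoid", [x]) as scope_b:
    y_b = tf.minimum(1.0, tf.maximum(0.0, 1.2 * tf.sigmoid(x) - 0.1))

# Both scopes are derived from the first argument (the second is uniquified).
print(scope_a, scope_b)  # e.g. saturating_sigmoid/ saturating_sigmoid_1/
```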