Skip to content
This repository was archived by the owner on Jul 7, 2023. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
ca967c1
Internal change
Mar 2, 2018
908981e
Set evaluation_master to master in RunConfig
Mar 2, 2018
e2e61ac
Add option to do soft EM instead of hard EM
a-googler Mar 5, 2018
f92d901
Add an inv_temp hparam for controlling softness of EM
a-googler Mar 6, 2018
f36f82c
More Librispeech subsets to help with mixed clean and noisy data trai…
a-googler Mar 6, 2018
40d1f15
proper em - P(c_i) is computed using ema_count instead of actual counts
a-googler Mar 6, 2018
8320faf
increase pseudo-count to 1.0 and now there's no NaN in training
a-googler Mar 7, 2018
d83d992
Use logits instead of probs to compute supervised attention loss.
a-googler Mar 8, 2018
7056827
Why do we need stop gradient here?
a-googler Mar 8, 2018
e1e8fbb
Add expected_attention_loss_type hparam to supervised_attention to al…
a-googler Mar 8, 2018
5ee776d
ema_count trainable should be False; this was causing the weird dp be…
a-googler Mar 8, 2018
7293efc
Fix multi-logit loss computation error.
aidangomez Mar 8, 2018
75d2aef
Basic autoencoder and improvements in image modality.
Mar 9, 2018
c4e6fab
Change batch size for hparam config
Mar 9, 2018
9ae5bc2
Make Vanilla GAN work, based on Compare GAN code.
Mar 9, 2018
1568e9b
internal
Mar 9, 2018
95053b4
Bump release number.
Mar 9, 2018
9a638df
Documentation for cloud TPU for Image Transformer. Additional default…
Mar 9, 2018
3de51ab
Add smoothed L0 prior and trainable logits for cluster probabilities.
Mar 9, 2018
6a6d9fe
Added the ema count smoothing update inside the else.
Mar 9, 2018
6e846f2
Make text_encoder unicode conversion a pass-through
Mar 9, 2018
d8080a1
Pass in decode_hp to _interactive_input_fn and remove summaries when
a-googler Mar 10, 2018
c7495b5
six.iteritems for Py3
Mar 10, 2018
329123f
Update Travis tests for Py3 to run TF 1.6
Mar 10, 2018
688f4d5
Update ISSUE_TEMPLATE
Mar 10, 2018
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Basic autoencoder and improvements in image modality.
PiperOrigin-RevId: 188428275
  • Loading branch information
Lukasz Kaiser authored and Ryan Sepassi committed Mar 9, 2018
commit 75d2aefe85a8fc453bf6f818607c1012a82c113f
20 changes: 12 additions & 8 deletions tensor2tensor/data_generators/cifar.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,8 @@ def preprocess_example(self, example, mode, unused_hparams):
image.set_shape([_CIFAR10_IMAGE_SIZE, _CIFAR10_IMAGE_SIZE, 3])
if mode == tf.estimator.ModeKeys.TRAIN:
image = image_utils.cifar_image_augmentation(image)
image = tf.image.per_image_standardization(image)
if not self._was_reversed:
image = tf.image.per_image_standardization(image)
example["inputs"] = image
return example

Expand All @@ -151,7 +152,8 @@ class ImageCifar10Plain(ImageCifar10):
def preprocess_example(self, example, mode, unused_hparams):
image = example["inputs"]
image.set_shape([_CIFAR10_IMAGE_SIZE, _CIFAR10_IMAGE_SIZE, 3])
image = tf.image.per_image_standardization(image)
if not self._was_reversed:
image = tf.image.per_image_standardization(image)
example["inputs"] = image
return example

Expand Down Expand Up @@ -179,7 +181,8 @@ def dataset_filename(self):
def preprocess_example(self, example, mode, unused_hparams):
image = example["inputs"]
image = image_utils.resize_by_area(image, 8)
image = tf.image.per_image_standardization(image)
if not self._was_reversed:
image = tf.image.per_image_standardization(image)
example["inputs"] = image
return example

Expand All @@ -192,7 +195,6 @@ def dataset_filename(self):
return "image_cifar10_plain" # Reuse CIFAR-10 plain data.

def preprocess_example(self, example, unused_mode, unused_hparams):

inputs = example["inputs"]
# For Img2Img resize input and output images as desired.
example["inputs"] = image_utils.resize_by_area(inputs, 8)
Expand Down Expand Up @@ -330,7 +332,8 @@ def preprocess_example(self, example, mode, unused_hparams):
image.set_shape([_CIFAR100_IMAGE_SIZE, _CIFAR100_IMAGE_SIZE, 3])
if mode == tf.estimator.ModeKeys.TRAIN:
image = image_utils.cifar_image_augmentation(image)
image = tf.image.per_image_standardization(image)
if not self._was_reversed:
image = tf.image.per_image_standardization(image)
example["inputs"] = image
return example

Expand All @@ -357,7 +360,8 @@ class ImageCifar100Plain(ImageCifar100):
def preprocess_example(self, example, mode, unused_hparams):
image = example["inputs"]
image.set_shape([_CIFAR100_IMAGE_SIZE, _CIFAR100_IMAGE_SIZE, 3])
image = tf.image.per_image_standardization(image)
if not self._was_reversed:
image = tf.image.per_image_standardization(image)
example["inputs"] = image
return example

Expand Down Expand Up @@ -385,7 +389,8 @@ def dataset_filename(self):
def preprocess_example(self, example, mode, unused_hparams):
image = example["inputs"]
image = image_utils.resize_by_area(image, 8)
image = tf.image.per_image_standardization(image)
if not self._was_reversed:
image = tf.image.per_image_standardization(image)
example["inputs"] = image
return example

Expand All @@ -398,7 +403,6 @@ def dataset_filename(self):
return "image_cifar100_plain" # Reuse CIFAR-100 plain data.

def preprocess_example(self, example, unused_mode, unused_hparams):

inputs = example["inputs"]
# For Img2Img resize input and output images as desired.
example["inputs"] = image_utils.resize_by_area(inputs, 8)
Expand Down
13 changes: 12 additions & 1 deletion tensor2tensor/data_generators/image_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from tensor2tensor.data_generators import generator_utils
from tensor2tensor.data_generators import problem
from tensor2tensor.data_generators import text_encoder
from tensor2tensor.utils import metrics
from tensor2tensor.utils import registry

import tensorflow as tf
Expand Down Expand Up @@ -64,9 +65,19 @@ def example_reading_spec(self, label_repr=None):
return data_fields, data_items_to_decoders

def preprocess_example(self, example, mode, hparams):
example["inputs"] = tf.image.per_image_standardization(example["inputs"])
if not self._was_reversed:
example["inputs"] = tf.image.per_image_standardization(example["inputs"])
return example

def eval_metrics(self):
  """Return the list of metrics to evaluate for this image problem.

  The accuracy, top-5 accuracy, per-sequence accuracy and negative
  log-perplexity metrics are always included. When the problem was reversed
  (`self._was_reversed` is truthy) the image summary metric is added too.

  Returns:
    A new list of `metrics.Metrics` members.
  """
  selected = [
      metrics.Metrics.ACC,
      metrics.Metrics.ACC_TOP5,
      metrics.Metrics.ACC_PER_SEQ,
      metrics.Metrics.NEG_LOG_PERPLEXITY,
  ]
  if self._was_reversed:
    selected.append(metrics.Metrics.IMAGE_SUMMARY)
  return selected


class Image2ClassProblem(ImageProblem):
"""Base class for image classification problems."""
Expand Down
3 changes: 2 additions & 1 deletion tensor2tensor/data_generators/mnist.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,8 @@ def train_shards(self):
def preprocess_example(self, example, mode, unused_hparams):
image = example["inputs"]
image.set_shape([_MNIST_IMAGE_SIZE, _MNIST_IMAGE_SIZE, 1])
image = tf.image.per_image_standardization(image)
if not self._was_reversed:
image = tf.image.per_image_standardization(image)
example["inputs"] = image
return example

Expand Down
42 changes: 39 additions & 3 deletions tensor2tensor/layers/common_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -1618,7 +1618,8 @@ def padded_cross_entropy(logits,
labels,
label_smoothing,
weights_fn=weights_nonzero,
reduce_sum=True):
reduce_sum=True,
gaussian=False):
"""Compute cross-entropy assuming 0s are padding.

Computes a loss numerator (the sum of losses), and loss denominator
Expand All @@ -1631,12 +1632,19 @@ def padded_cross_entropy(logits,
label_smoothing: a floating point `Scalar`.
weights_fn: A function from labels to weights.
reduce_sum: a Boolean, whether to sum at the end or not.
gaussian: If true, use a gaussian distribution for label smoothing

Returns:
loss_numerator: a `Scalar`. Sum of losses.
loss_denominator: a `Scalar. The number of non-padding target tokens.

Raises:
ValueError: in case of unsupported argument types.
"""
if isinstance(logits, FactoredTensor):
if gaussian:
raise ValueError("Factored padded cross entropy with Gaussian smoothing "
"is not implemented yet.")
return padded_cross_entropy_factored(
logits,
labels,
Expand All @@ -1653,7 +1661,8 @@ def padded_cross_entropy(logits,
labels = tf.reshape(labels, [-1])
else:
logits, labels = pad_with_zeros(logits, labels)
xent = smoothing_cross_entropy(logits, labels, vocab_size, confidence)
xent = smoothing_cross_entropy(logits, labels, vocab_size, confidence,
gaussian=gaussian)
weights = weights_fn(labels)
if not reduce_sum:
return xent * weights, weights
Expand Down Expand Up @@ -1688,7 +1697,7 @@ def smoothing_cross_entropy(logits,
confidence * tf.log(confidence) + tf.to_float(vocab_size - 1) *
low_confidence * tf.log(low_confidence + 1e-20))

if gaussian:
if gaussian and confidence > 0.0:
labels = tf.cast(labels, tf.float32)

normal_dist = tf.distributions.Normal(loc=labels, scale=confidence)
Expand Down Expand Up @@ -2629,3 +2638,30 @@ def dense(x, units, **kwargs):
return _recompute_grad(fn, [x])
else:
return fn(x)


def mix(x1, x2, steps, is_training,
        min_prob=0.0, max_prob=1.0, mode="lin", simple=False):
  """Mix `x1` and `x2`, starting with `x2` and moving towards `x1`.

  Args:
    x1: the value the mixture moves towards; returned as-is when not training.
    x2: the value the mixture starts from.
    steps: passed to the decay function; when `max_prob` is not below 1.0 it
      is also the global-step threshold after which `x1` is returned directly.
    is_training: a Boolean; if false, no mixing is done.
    min_prob: lower end of the range the decay value is rescaled to.
    max_prob: upper end of the range the decay value is rescaled to.
    mode: "lin" selects `inverse_lin_decay`, any other value selects
      `inverse_exp_decay`.
    simple: if true, interpolate with the rescaled decay value itself;
      otherwise draw a random 0/1 mask shaped like `x1` whose entries are 1
      where a uniform sample is below that value.

  Returns:
    `x1` when not training, otherwise the mixture of `x1` and `x2`.
  """
  if not is_training:
    return x1

  def _blend():
    """Build the mixture; kept as a function so tf.cond can skip it."""
    decay_fn = inverse_lin_decay if mode == "lin" else inverse_exp_decay
    prob = decay_fn(steps) * (max_prob - min_prob) + min_prob
    if simple:
      weight = prob
    else:
      noise = tf.random_uniform(shape_list(x1))
      weight = tf.to_float(tf.less(noise, prob))
    return weight * x1 + (1.0 - weight) * x2

  # With max_prob below 1.0 the mixture is always built, with no step check.
  if max_prob < 1.0:
    return _blend()

  # Prevent sampling after steps is passed to speed it up.
  still_warming_up = tf.less(tf.train.get_global_step(), steps)
  return tf.cond(still_warming_up, _blend, lambda: x1)
55 changes: 35 additions & 20 deletions tensor2tensor/layers/modalities.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ def loss(self, logits, targets):
@registry.register_image_modality("default")
class ImageModality(modality.Modality):
"""Modality for images."""
NUM_CHANNELS = 3
PIXEL_EMBEDDING_SIZE = 64

def bottom(self, inputs):
with tf.variable_scope(self.name):
Expand All @@ -205,35 +205,50 @@ def bottom(self, inputs):

def targets_bottom(self, inputs):
with tf.variable_scope(self.name):
# Reshape inputs to 2-d tensor and embed the RGB pixel values.
ret = common_layers.embedding(
tf.to_int32(common_layers.flatten4d3d(inputs)),
self.top_dimensionality,
self._body_input_depth,
name="input_rgb_embedding")
if self._model_hparams.multiply_embedding_mode == "sqrt_depth":
ret *= self._body_input_depth**0.5

reshape_shape = common_layers.shape_list(inputs)[:3]
reshape_shape.append(self._body_input_depth * 3)
ret = tf.reshape(ret, reshape_shape)
return tf.layers.dense(ret, self._body_input_depth)
if not context.in_eager_mode():
tf.summary.image("targets_bottom",
tf.cast(inputs, tf.uint8), max_outputs=1)
inputs_shape = common_layers.shape_list(inputs)
if len(inputs_shape) != 4:
raise ValueError("Assuming images given as int tensors in the format "
"[batch, height, width, channels] (256 values).")
# We embed each of 256=self.top_dimensionality possible pixel values.
embedding_var = tf.get_variable(
"pixel_embedding",
[self.top_dimensionality, self.PIXEL_EMBEDDING_SIZE])
hot_inputs = tf.one_hot(tf.to_int32(inputs), self.top_dimensionality)
hot_inputs = tf.reshape(hot_inputs, [-1, self.top_dimensionality])
embedded = tf.matmul(hot_inputs, embedding_var)
# Let's now merge all channels that were embedded into a single vector.
merged_size = self.PIXEL_EMBEDDING_SIZE * inputs_shape[3]
embedded = tf.reshape(embedded, inputs_shape[:3] + [merged_size])
merged = tf.layers.dense(embedded, self._body_input_depth,
name="merge_pixel_embedded_channels")
return merged

def top(self, body_output, _):
# TODO(lukaszkaiser): is this a universal enough way to get channels?
num_channels = self._model_hparams.problem_instances[0].num_channels
with tf.variable_scope("rgb_softmax"):

body_output_shape = common_layers.shape_list(body_output)
reshape_shape = body_output_shape[:3]
dim = body_output_shape[-1] // 3
reshape_shape.extend([self.NUM_CHANNELS, dim])

out = tf.reshape(body_output, reshape_shape)
res = tf.layers.dense(out, self.top_dimensionality)
reshape_shape.extend([num_channels, self.top_dimensionality])
res = tf.layers.dense(body_output, self.top_dimensionality * num_channels)
res = tf.reshape(res, reshape_shape)
if not tf.get_variable_scope().reuse:
res_argmax = tf.cast(tf.argmax(res, axis=-1), tf.uint8)
tf.summary.image("result", res_argmax, max_outputs=1)
return res

def loss(self, logits, targets):
  """Compute loss numerator and denominator for one shard of output.

  Delegates to `common_layers.padded_cross_entropy`, using the model's
  `label_smoothing` hparam and this modality's target weights function, and
  asks for the gaussian label-smoothing variant.

  Args:
    logits: the logits passed through to `padded_cross_entropy`.
    targets: the target labels passed through to `padded_cross_entropy`.

  Returns:
    Whatever `padded_cross_entropy` returns for these arguments.
  """
  smoothing = self._model_hparams.label_smoothing
  weights_fn = self.targets_weights_fn
  return common_layers.padded_cross_entropy(
      logits, targets, smoothing, weights_fn=weights_fn, gaussian=True)


@registry.register_image_modality("image_channel_compress")
class ImageChannelCompressModality(modality.Modality):
Expand Down
1 change: 1 addition & 0 deletions tensor2tensor/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
from tensor2tensor.models.research import aligned
from tensor2tensor.models.research import attention_lm
from tensor2tensor.models.research import attention_lm_moe
from tensor2tensor.models.research import autoencoders
from tensor2tensor.models.research import cycle_gan
from tensor2tensor.models.research import gene_expression
from tensor2tensor.models.research import multimodel
Expand Down
69 changes: 69 additions & 0 deletions tensor2tensor/models/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,52 @@ def body(self, features):
return tf.expand_dims(tf.expand_dims(x, axis=1), axis=1) # 4D For T2T.


@registry.register_model
class BasicAutoencoder(t2t_model.T2TModel):
  """A basic autoencoder, try with image_mnist_rev or image_cifar10_rev."""

  def bottleneck(self, x, res_size):
    """Squeeze `x` through a dense bottleneck and project it back out.

    Args:
      x: a Tensor to compress.
      res_size: an integer, the number of units of the final dense layer.

    Returns:
      The result of dense(`bottleneck_size`) -> dropout -> dense(`res_size`).
    """
    hparams = self._hparams
    x = tf.layers.dense(x, hparams.bottleneck_size, name="bottleneck")
    x = tf.nn.dropout(x, keep_prob=1.0 - hparams.dropout)
    x = tf.layers.dense(x, res_size, name="unbottleneck")
    return x

  def body(self, features):
    """Encode `features["targets"]`, bottleneck it, and decode it again.

    Args:
      features: a dict; only the "targets" entry is read here.

    Returns:
      The decoder output cropped back to the original size of axes 1 and 2 of
      the targets, mixed with the targets themselves during training (see
      `common_layers.mix`).
    """
    hparams = self._hparams
    is_training = hparams.mode == tf.estimator.ModeKeys.TRAIN
    x = features["targets"]
    shape = common_layers.shape_list(x)
    kernel = (hparams.kernel_height, hparams.kernel_width)
    # When axis 2 has size 1, only convolve and stride along axis 1.
    is1d = shape[2] == 1
    kernel = (hparams.kernel_height, 1) if is1d else kernel
    strides = (2, 1) if is1d else (2, 2)
    # Every layer strides by 2, so pad to a multiple of 2**num_hidden_layers.
    x, _ = common_layers.pad_to_same_length(
        x, x, final_length_divisible_by=2**hparams.num_hidden_layers, axis=1)
    if not is1d:
      x, _ = common_layers.pad_to_same_length(
          x, x, final_length_divisible_by=2**hparams.num_hidden_layers, axis=2)
    # Down-convolutions.
    # `range` rather than `xrange`: the latter is not a builtin on Python 3.
    for i in range(hparams.num_hidden_layers):
      x = tf.layers.conv2d(
          x, hparams.hidden_size * 2**(i + 1), kernel, strides=strides,
          padding="SAME", activation=tf.nn.relu, name="conv_%d" % i)
      x = common_layers.layer_norm(x)
    # Bottleneck (mix during early training, not too important but very stable).
    b = self.bottleneck(x, hparams.hidden_size * 2**hparams.num_hidden_layers)
    x = common_layers.mix(b, x, hparams.bottleneck_warmup_steps, is_training)
    # Up-convolutions, with layer sizes and names mirroring the encoder.
    for i in range(hparams.num_hidden_layers):
      j = hparams.num_hidden_layers - i - 1
      x = tf.layers.conv2d_transpose(
          x, hparams.hidden_size * 2**j, kernel, strides=strides,
          padding="SAME", activation=tf.nn.relu, name="deconv_%d" % j)
      x = common_layers.layer_norm(x)
    # Crop away the padding added before the down-convolutions.
    res = x[:, :shape[1], :shape[2], :]
    return common_layers.mix(res, features["targets"],
                             hparams.bottleneck_warmup_steps // 2, is_training)


@registry.register_hparams
def basic_fc_small():
"""Small fully connected model."""
Expand All @@ -57,3 +103,26 @@ def basic_fc_small():
hparams.weight_decay = 0.0
hparams.dropout = 0.0
return hparams


@registry.register_hparams
def basic_autoencoder():
  """Hyperparameters for the basic autoencoder model.

  Starts from `common_hparams.basic_params1()`, overrides the values listed
  below and registers two extra hparams for the bottleneck.

  Returns:
    The configured hparams object.
  """
  hparams = common_hparams.basic_params1()
  # (name, value) pairs, applied in order as plain attribute assignments.
  overrides = (
      ("optimizer", "Adam"),
      ("learning_rate_constant", 0.0002),
      ("learning_rate_warmup_steps", 500),
      ("learning_rate_schedule", "constant * linear_warmup"),
      ("label_smoothing", 0.05),
      ("batch_size", 128),
      ("hidden_size", 64),
      ("num_hidden_layers", 4),
      ("initializer", "uniform_unit_scaling"),
      ("initializer_gain", 1.0),
      ("weight_decay", 0.0),
      ("kernel_height", 4),
      ("kernel_width", 4),
      ("dropout", 0.1),
  )
  for name, value in overrides:
    setattr(hparams, name, value)
  # These two are added with add_hparam rather than assigned.
  hparams.add_hparam("bottleneck_size", 128)
  hparams.add_hparam("bottleneck_warmup_steps", 3000)
  return hparams
Loading