1.6.2 #771

Merged

Changes from 1 commit
Corrections to autoencoders and atari ROM paths.
PiperOrigin-RevId: 195750518
Lukasz Kaiser committed May 8, 2018
commit b748dbefb57340156c357611c1195c1538e5dc55
108 changes: 88 additions & 20 deletions tensor2tensor/data_generators/gym.py
@@ -28,6 +28,7 @@
from tensor2tensor.data_generators import problem
from tensor2tensor.data_generators import video_utils

from tensor2tensor.models.research import autoencoders
from tensor2tensor.models.research import rl
from tensor2tensor.rl import collect
from tensor2tensor.rl.envs import tf_atari_wrappers as atari
@@ -42,7 +43,9 @@
flags = tf.flags
FLAGS = flags.FLAGS

flags.DEFINE_string("agent_policy_path", "", "File with model for pong")

flags.DEFINE_string("agent_policy_path", "", "File with model for agent.")
flags.DEFINE_string("autoencoder_path", "", "File with model for autoencoder.")


class GymDiscreteProblem(video_utils.VideoProblem):
@@ -99,6 +102,14 @@ def env(self):
def num_actions(self):
return self.env.action_space.n

@property
def frame_height(self):
return self.env.observation_space.shape[0]

@property
def frame_width(self):
return self.env.observation_space.shape[1]

@property
def num_rewards(self):
raise NotImplementedError()
@@ -150,14 +161,6 @@ class GymPongRandom5k(GymDiscreteProblem):
def env_name(self):
return "PongDeterministic-v4"

@property
def frame_height(self):
return 210

@property
def frame_width(self):
return 160

@property
def min_reward(self):
return -1
@@ -181,7 +184,36 @@ def num_steps(self):


@registry.register_problem
class GymDiscreteProblemWithAgent(GymPongRandom5k):
class GymFreewayRandom5k(GymDiscreteProblem):
"""Freeway game, random actions."""

@property
def env_name(self):
return "FreewayDeterministic-v4"

@property
def min_reward(self):
return 0

@property
def num_rewards(self):
return 2

@property
def num_steps(self):
return 5000


@registry.register_problem
class GymFreewayRandom50k(GymFreewayRandom5k):
"""Freeway game, random actions."""

@property
def num_steps(self):
return 50000


class GymDiscreteProblemWithAgent(GymDiscreteProblem):
"""Gym environment with discrete actions and rewards and an agent."""

def __init__(self, *args, **kwargs):
@@ -190,7 +222,7 @@ def __init__(self, *args, **kwargs):
self.debug_dump_frames_path = "debug_frames_env"

# defaults
self.environment_spec = lambda: gym.make("PongDeterministic-v4")
self.environment_spec = lambda: gym.make(self.env_name)
self.in_graph_wrappers = []
self.collect_hparams = rl.atari_base()
self.settable_num_steps = 20000
@@ -210,7 +242,7 @@ def _setup(self):
generator_batch_env = batch_env_factory(
self.environment_spec, env_hparams, num_agents=1, xvfb=False)

with tf.variable_scope("", reuse=tf.AUTO_REUSE):
with tf.variable_scope(tf.get_variable_scope(), reuse=tf.AUTO_REUSE):
if FLAGS.agent_policy_path:
policy_lambda = self.collect_hparams.network
else:
@@ -223,7 +255,7 @@
create_scope_now_=True,
unique_name_="network")

with tf.variable_scope("", reuse=tf.AUTO_REUSE):
with tf.variable_scope(tf.get_variable_scope(), reuse=tf.AUTO_REUSE):
self.collect_hparams.epoch_length = 10
_, self.collect_trigger_op = collect.define_collect(
policy_factory, generator_batch_env, self.collect_hparams,
@@ -238,6 +270,22 @@ def restore_networks(self, sess):
tf.global_variables(".*network_parameters.*"))
model_saver.restore(sess, FLAGS.agent_policy_path)

def autoencode(self, image, sess):
with tf.Graph().as_default():
hparams = autoencoders.autoencoder_discrete_pong()
hparams.data_dir = "unused"
hparams.problem_hparams = self.get_hparams(hparams)
hparams.problem = self
model = autoencoders.AutoencoderOrderedDiscrete(
hparams, tf.estimator.ModeKeys.EVAL)
img = tf.constant(image)
img = tf.to_int32(tf.reshape(
img, [1, 1, self.frame_height, self.frame_width, self.num_channels]))
encoded = model.encode(img)
model_saver = tf.train.Saver(tf.global_variables())
model_saver.restore(sess, FLAGS.autoencoder_path)
return sess.run(encoded)

def generate_encoded_samples(self, data_dir, tmp_dir, unused_dataset_split):
self._setup()
self.debug_dump_frames_path = os.path.join(
@@ -246,17 +294,14 @@ def generate_encoded_samples(self, data_dir, tmp_dir, unused_dataset_split):
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
self.restore_networks(sess)
# Actions are shifted by 1 by MemoryWrapper, compensate here.
avilable_data_size = sess.run(self.avilable_data_size_op)
if avilable_data_size < 1:
sess.run(self.collect_trigger_op)
pieces_generated = 0
observ, reward, _, _ = sess.run(self.data_get_op)
while pieces_generated < self.num_steps + self.warm_up:
avilable_data_size = sess.run(self.avilable_data_size_op)
if avilable_data_size < 1:
sess.run(self.collect_trigger_op)
next_observ, next_reward, action, _ = sess.run(self.data_get_op)
observ, reward, action, _, img = sess.run(self.data_get_op)
if FLAGS.autoencoder_path:
observ = self.autoencode(img, sess)
yield {"image/encoded": [observ],
"image/format": ["png"],
"image/height": [self.frame_height],
@@ -265,7 +310,6 @@ def generate_encoded_samples(self, data_dir, tmp_dir, unused_dataset_split):
"done": [int(False)],
"reward": [int(reward) - self.min_reward]}
pieces_generated += 1
observ, reward = next_observ, next_reward


@registry.register_problem
@@ -286,3 +330,27 @@ def restore_networks(self, sess):
ckpts = tf.train.get_checkpoint_state(FLAGS.output_dir)
ckpt = ckpts.model_checkpoint_path
env_model_loader.restore(sess, ckpt)


@registry.register_problem
class GymSimulatedDiscreteProblemWithAgentOnPong(
GymSimulatedDiscreteProblemWithAgent, GymPongRandom5k):
pass


@registry.register_problem
class GymDiscreteProblemWithAgentOnPong(
GymDiscreteProblemWithAgent, GymPongRandom5k):
pass


@registry.register_problem
class GymSimulatedDiscreteProblemWithAgentOnFreeway(
GymSimulatedDiscreteProblemWithAgent, GymFreewayRandom5k):
pass


@registry.register_problem
class GymDiscreteProblemWithAgentOnFreeway(
GymDiscreteProblemWithAgent, GymFreewayRandom5k):
pass
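
The four registered classes above pick a game by mixin inheritance: the agent (or simulated-agent) machinery comes from one base class and the per-game constants from the other. A minimal sketch of fetching such a problem by its auto-derived snake_case name, assuming the usual t2t registry API:

from tensor2tensor.utils import registry

# Registered names are the snake_case form of the class names above.
problem = registry.problem("gym_discrete_problem_with_agent_on_freeway")
# Frame sizes now come from the env itself rather than per-game constants:
#   problem.frame_height == env.observation_space.shape[0]
#   problem.frame_width  == env.observation_space.shape[1]
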
2 changes: 1 addition & 1 deletion tensor2tensor/layers/modalities.py
@@ -546,7 +546,7 @@ def loss(self, logits, targets):
logits,
targets,
self._model_hparams.label_smoothing,
cutoff=0.001,
cutoff=0.02,
weights_fn=self.targets_weights_fn)


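The only change in this file raises the loss cutoff from 0.001 to 0.02. Assuming the cutoff behaves as it commonly does in t2t's padded cross-entropy helper (that helper is not shown in this diff), per-position losses below the threshold are clipped to zero, so positions the model already predicts confidently stop contributing gradient:

import tensorflow as tf

def cutoff_loss(xent, cutoff=0.02):
  # Positions whose cross-entropy is already below `cutoff` contribute
  # nothing; only harder positions keep producing gradient.
  return tf.nn.relu(xent - cutoff)
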
12 changes: 10 additions & 2 deletions tensor2tensor/models/basic.py
@@ -30,6 +30,7 @@

@registry.register_model
class BasicFcRelu(t2t_model.T2TModel):
"""Basic fully-connected + ReLU model."""

def body(self, features):
hparams = self.hparams
@@ -49,6 +50,7 @@ class BasicAutoencoder(t2t_model.T2TModel):

def __init__(self, *args, **kwargs):
super(BasicAutoencoder, self).__init__(*args, **kwargs)
self.cur_bottleneck_tensor = None
self.is1d = None

def bottleneck(self, x):
@@ -120,6 +122,7 @@ def body(self, features):
x = self.encoder(x)
# Bottleneck (mix during early training, not too important but stable).
b = self.bottleneck(x)
self.cur_bottleneck_tensor = b
b_loss = self.bottleneck_loss(b)
b = self.unbottleneck(b, common_layers.shape_list(x)[-1])
b = common_layers.mix(b, x, hparams.bottleneck_warmup_steps, is_training)
@@ -153,8 +156,13 @@ def sample(self):
# Sample in [-1, 1] as the bottleneck is under tanh.
return 2.0 * tf.random_uniform(size) - 1.0

def infer(self, features=None, decode_length=50, beam_size=1, top_beams=1,
alpha=0.0):
def encode(self, x, *args, **kwargs):
"""Auto-encode x and return the bottleneck."""
features = {"targets": x}
self(features) # pylint: disable=not-callable
return self.cur_bottleneck_tensor

def infer(self, features, *args, **kwargs):
"""Produce predictions from the model by sampling."""
# Inputs and features preparation needed to handle edge cases.
if not features:
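The new encode() works through a side effect: running the model stores the bottleneck in self.cur_bottleneck_tensor, which encode() then reads back. A hypothetical, framework-free sketch of that pattern (not the real T2TModel call path):

class TinyAutoencoder(object):
  """Illustrates the store-then-read bottleneck pattern from the diff."""

  def __init__(self):
    self.cur_bottleneck_tensor = None

  def __call__(self, features):
    # Stand-in for encoder + bottleneck; body() does the real work in T2T.
    bottleneck = [v * 0.5 for v in features["targets"]]
    self.cur_bottleneck_tensor = bottleneck
    return bottleneck

  def encode(self, x):
    self({"targets": x})  # run the model for its side effect
    return self.cur_bottleneck_tensor
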
24 changes: 20 additions & 4 deletions tensor2tensor/models/research/autoencoders.py
@@ -20,6 +20,7 @@

# Dependency imports

from tensor2tensor.layers import common_attention
from tensor2tensor.layers import common_layers
from tensor2tensor.layers import discretization
from tensor2tensor.models import basic
@@ -226,6 +227,7 @@ def decoder(self, x):
name="residual_%d" % r)
x += tf.nn.dropout(y, 1.0 - hparams.residual_dropout)
x = common_layers.layer_norm(x)
x = common_attention.add_timing_signal_nd(x)
return x


@@ -297,6 +299,9 @@ def sample(self):
class AutoencoderOrderedDiscrete(AutoencoderResidualDiscrete):
"""Ordered discrete autoencoder."""

def bottleneck_loss(self, unused_b):
return 0.0

def bottleneck(self, x):
hparams = self.hparams
noise = hparams.bottleneck_noise
@@ -418,7 +423,7 @@ def autoencoder_autoregressive():
"""Autoregressive autoencoder model."""
hparams = basic.basic_autoencoder()
hparams.add_hparam("autoregressive_forget_base", False)
hparams.add_hparam("autoregressive_mode", "conv3")
hparams.add_hparam("autoregressive_mode", "none")
hparams.add_hparam("autoregressive_dropout", 0.4)
hparams.add_hparam("autoregressive_decode_steps", 0)
hparams.add_hparam("autoregressive_eval_pure_autoencoder", False)
@@ -429,10 +434,10 @@
def autoencoder_residual():
"""Residual autoencoder model."""
hparams = autoencoder_autoregressive()
hparams.optimizer = "Adam"
hparams.learning_rate_constant = 0.0001
hparams.optimizer = "Adafactor"
hparams.learning_rate_constant = 0.2
hparams.learning_rate_warmup_steps = 500
hparams.learning_rate_schedule = "constant * linear_warmup"
hparams.learning_rate_schedule = "constant * linear_warmup * rsqrt_decay"
hparams.dropout = 0.05
hparams.num_hidden_layers = 5
hparams.hidden_size = 64
@@ -494,6 +499,17 @@ def autoencoder_ordered_discrete():
return hparams


@registry.register_hparams
def autoencoder_discrete_pong():
"""Discrete autoencoder model for compressing pong frames."""
hparams = autoencoder_ordered_discrete()
hparams.bottleneck_size = 24
hparams.dropout = 0.2
hparams.batch_size = 2
hparams.bottleneck_noise = 0.4
return hparams


@registry.register_hparams
def autoencoder_stacked():
"""Stacked autoencoder model."""
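Registered hparams sets are plain functions returning an HParams object, so the new Pong set can be fetched and tweaked directly; gym.py's autoencode() above takes the same route. A small usage sketch (the override value is arbitrary):

from tensor2tensor.models.research import autoencoders

hparams = autoencoders.autoencoder_discrete_pong()
hparams.bottleneck_noise = 0.3  # override a single field for an experiment
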
6 changes: 4 additions & 2 deletions tensor2tensor/models/research/basic_conv_gen.py
@@ -22,6 +22,7 @@

import six

from tensor2tensor.layers import common_attention
from tensor2tensor.layers import common_hparams
from tensor2tensor.layers import common_layers
from tensor2tensor.utils import registry
@@ -106,6 +107,7 @@ def body(self, features):
shape = common_layers.shape_list(y)
x = x[:, :shape[1], :shape[2], :]
x = common_layers.layer_norm(x + y)
x = common_attention.add_timing_signal_nd(x)

# Cut down to original size.
x = x[:, :inputs_shape[1], :inputs_shape[2], :]
@@ -167,14 +169,14 @@ def basic_conv():
hparams.batch_size = 8
hparams.num_hidden_layers = 2
hparams.optimizer = "Adafactor"
hparams.learning_rate_constant = 0.5
hparams.learning_rate_constant = 1.5
hparams.learning_rate_warmup_steps = 1500
hparams.learning_rate_schedule = "linear_warmup * constant * rsqrt_decay"
hparams.label_smoothing = 0.0
hparams.initializer = "uniform_unit_scaling"
hparams.initializer_gain = 1.0
hparams.weight_decay = 0.0
hparams.dropout = 0.2
hparams.dropout = 0.5
hparams.add_hparam("num_compress_steps", 6)
hparams.add_hparam("filter_double_steps", 5)
return hparams
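
Both this model and the autoencoder decoder above now add a timing signal after layer norm, giving the convolutional stack an explicit notion of position. A simplified 1-D NumPy rendition of the sinusoidal signal, for illustration only (the real common_attention.add_timing_signal_nd handles arbitrary spatial ranks):

import numpy as np

def timing_signal_1d(length, channels, max_timescale=1.0e4):
  # Sinusoids at geometrically spaced frequencies; concatenated sin/cos
  # give each position a distinct, smoothly varying code.
  position = np.arange(length, dtype=np.float32)[:, None]
  num_timescales = channels // 2
  log_increment = np.log(max_timescale) / max(num_timescales - 1, 1)
  inv_timescales = np.exp(np.arange(num_timescales) * -log_increment)
  scaled_time = position * inv_timescales[None, :]
  return np.concatenate([np.sin(scaled_time), np.cos(scaled_time)], axis=1)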