Skip to content
This repository was archived by the owner on Jul 7, 2023. It is now read-only.
Merged

1.6.2 #771

Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next commit
Adding Lambada dataset as a new problem to T2T.
PiperOrigin-RevId: 195414151
  • Loading branch information
T2T Team authored and lukaszkaiser committed May 8, 2018
commit b7e5f484b0a66a5a0d2ee11a2f09a94ec3d6d506
1 change: 1 addition & 0 deletions tensor2tensor/data_generators/all_problems.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
"tensor2tensor.data_generators.ice_parsing",
"tensor2tensor.data_generators.imagenet",
"tensor2tensor.data_generators.imdb",
"tensor2tensor.data_generators.lambada",
"tensor2tensor.data_generators.librispeech",
"tensor2tensor.data_generators.lm1b",
"tensor2tensor.data_generators.mnist",
Expand Down
71 changes: 11 additions & 60 deletions tensor2tensor/data_generators/gym.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
flags = tf.flags
FLAGS = flags.FLAGS

flags.DEFINE_string("agent_policy_path", "", "File with model for agent")
flags.DEFINE_string("agent_policy_path", "", "File with model for pong")


class GymDiscreteProblem(video_utils.VideoProblem):
Expand Down Expand Up @@ -99,14 +99,6 @@ def env(self):
def num_actions(self):
return self.env.action_space.n

@property
def frame_height(self):
return self.env.observation_space.shape[0]

@property
def frame_width(self):
return self.env.observation_space.shape[1]

@property
def num_rewards(self):
raise NotImplementedError()
Expand Down Expand Up @@ -159,58 +151,37 @@ def env_name(self):
return "PongDeterministic-v4"

@property
def min_reward(self):
return -1

@property
def num_rewards(self):
return 3

@property
def num_steps(self):
return 5000


@registry.register_problem
class GymPongRandom50k(GymPongRandom5k):
"""Pong game, random actions."""

@property
def num_steps(self):
return 50000

@registry.register_problem
class GymFreewayRandom5k(GymDiscreteProblem):
"""Freeway game, random actions."""
def frame_height(self):
return 210

@property
def env_name(self):
return "FreewayDeterministic-v4"
def frame_width(self):
return 160

@property
def min_reward(self):
return 0
return -1

@property
def num_rewards(self):
return 2
return 3

@property
def num_steps(self):
return 5000


@registry.register_problem
class GymFreewayRandom50k(GymFreewayRandom5k):
"""Freeway game, random actions."""
class GymPongRandom50k(GymPongRandom5k):
"""Pong game, random actions."""

@property
def num_steps(self):
return 50000


@registry.register_problem
class GymDiscreteProblemWithAgent(GymDiscreteProblem):
class GymDiscreteProblemWithAgent(GymPongRandom5k):
"""Gym environment with discrete actions and rewards and an agent."""

def __init__(self, *args, **kwargs):
Expand All @@ -219,7 +190,7 @@ def __init__(self, *args, **kwargs):
self.debug_dump_frames_path = "debug_frames_env"

# defaults
self.environment_spec = lambda: gym.make(self.env_name)
self.environment_spec = lambda: gym.make("PongDeterministic-v4")
self.in_graph_wrappers = []
self.collect_hparams = rl.atari_base()
self.settable_num_steps = 20000
Expand Down Expand Up @@ -315,23 +286,3 @@ def restore_networks(self, sess):
ckpts = tf.train.get_checkpoint_state(FLAGS.output_dir)
ckpt = ckpts.model_checkpoint_path
env_model_loader.restore(sess, ckpt)


@registry.register_problem
class GymSimulatedDiscreteProblemWithAgentOnPong(GymSimulatedDiscreteProblemWithAgent, GymPongRandom5k):
pass


@registry.register_problem
class GymDiscreteProblemWithAgentOnPong(GymDiscreteProblemWithAgent, GymPongRandom5k):
pass


@registry.register_problem
class GymSimulatedDiscreteProblemWithAgentOnFreeway(GymSimulatedDiscreteProblemWithAgent, GymFreewayRandom5k):
pass


@registry.register_problem
class GymDiscreteProblemWithAgentOnFreeway(GymDiscreteProblemWithAgent, GymFreewayRandom5k):
pass
Loading