Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
pre-commit.
  • Loading branch information
gianlucadecola committed Apr 10, 2022
commit 02b81e5b3f34e661d2e1bf7022ab5fc74fc35122
40 changes: 17 additions & 23 deletions gym/utils/play.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,15 @@


class PlayableGame:
def __init__(self, env, keys_to_action=None):
def __init__(self, env, keys_to_action=None, zoom=None):
self.env = env
self.relevant_keys = self.get_relevant_keys(keys_to_action)
self.relevant_keys = self._get_relevant_keys(keys_to_action)
self.video_size = self._get_video_size(zoom)
self.screen = pygame.display.set_mode(self.video_size)
self.pressed_keys = []
self.running = True

def get_relevant_keys(self, keys_to_action):
def _get_relevant_keys(self, keys_to_action):
if keys_to_action is None:
if hasattr(self.env, "get_keys_to_action"):
keys_to_action = self.env.get_keys_to_action()
Expand All @@ -41,7 +43,7 @@ def get_relevant_keys(self, keys_to_action):
relevant_keys = set(sum(map(list, keys_to_action.keys()), []))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

map is nice (functional programming is great), but in Python it's typically more readable to use a comprehension, so it'd be something like
set(sum(list(key) for key in keys_to_action), []))

But is this actually what we want this to do? I'm not sure what each key is meant to be (see previous comment about types), so take another look at the logic here. Intuition tells me that you might have wanted to just do set(keys_to_actions.keys()) or something like that (which might be equivalent to just set(keys_to_actions))

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm going to explore it more in depth, this logic was already there in the old play(), it seems also to me that set(keys_to_actions.keys()) may be fine

return relevant_keys

def get_video_size(self, zoom=None):
def _get_video_size(self, zoom=None):
rendered = self.env.render(mode="rgb_array")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you put like a TODO here so that we remember to update this when the render API change goes through?

video_size = [rendered.shape[1], rendered.shape[0]]

Expand All @@ -62,9 +64,8 @@ def process_event(self, event):
elif event.type == pygame.QUIT:
self.running = False
elif event.type == VIDEORESIZE:
video_size = event.size
screen = pygame.display.set_mode(video_size)
print(video_size)
self.video_size = event.size
self.screen = pygame.display.set_mode(self.video_size)


def display_arr(screen, arr, video_size, transpose):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add type hints here and to play()?

Expand Down Expand Up @@ -132,37 +133,30 @@ def callback(obs_t, obs_tp1, action, rew, done, info):
If None, default key_to_action mapping for that env is used, if provided.
"""
env.reset()
game = PlayableGame(env, keys_to_action)
game = PlayableGame(env, keys_to_action, zoom)

video_size = game.get_video_size(zoom)

env_done = True

screen = pygame.display.set_mode(video_size)
done = True
clock = pygame.time.Clock()

while game.running:
if env_done:
env_done = False
if done:
done = False
obs = env.reset()
else:
action = keys_to_action.get(tuple(sorted(game.pressed_keys)), 0)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we always safely assume that 0 is a good default action? (if I'm understanding this correctly)
Maybe this should be part of the specification?

In particular this seems like it'd crash with any continuous-action environments?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes it is actually the default action if no relevant keys are pressed; I have the same doubt regarding 0, I will explore it

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@RedTachyon About this I think we can safely say that, since we can not map a continuos-action environment to discrete keyboard presses it is ok to leave 0 as default because play() will crash before since there is no keys_to_action. (Maybe we could add an explicit message for continuos-action environment?)
An example of a keys to action mapping it is indeed this:

def get_keys_to_action(self):
# Control with left and right arrow keys.
return {(): 1, (276,): 0, (275,): 2, (275, 276): 1}

prev_obs = obs
obs, rew, env_done, info = env.step(action)
obs, rew, done, info = env.step(action)
if callback is not None:
callback(prev_obs, obs, action, rew, env_done, info)
callback(prev_obs, obs, action, rew, done, info)
if obs is not None:
rendered = env.render(mode="rgb_array")
display_arr(screen, rendered, transpose=transpose, video_size=video_size)
display_arr(
game.screen, rendered, transpose=transpose, video_size=game.video_size
)

# process pygame events
for event in pygame.event.get():
game.process_event(event)
# test events, set key states
if event.type == VIDEORESIZE:
video_size = event.size
screen = pygame.display.set_mode(video_size)
print(video_size)

pygame.display.flip()
clock.tick(fps)
Expand Down
25 changes: 9 additions & 16 deletions tests/utils/test_play.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,11 @@ class DummyEnvSpec:

class DummyPlayEnv(gym.Env):
def step(self, action):
...
obs = np.zeros((1, 1))
rew = 0
done = False
info = {}
return obs, rew, done, info

def reset(self):
...
Expand Down Expand Up @@ -64,16 +68,14 @@ def test_play_relevant_keys_with_env_attribute():
def test_video_size_no_zoom():
env = DummyPlayEnv()
game = PlayableGame(env, dummy_keys_to_action())
video_size = game.get_video_size()
assert video_size == list(env.render().shape)
assert game.video_size == list(env.render().shape)


def test_video_size_zoom():
env = DummyPlayEnv()
game = PlayableGame(env, dummy_keys_to_action())
zoom_value = 2.2
video_size = game.get_video_size(zoom=zoom_value)
assert video_size == tuple(int(shape * zoom_value) for shape in env.render().shape)
zoom = 2.2
game = PlayableGame(env, dummy_keys_to_action(), zoom)
assert game.video_size == tuple(int(shape * zoom) for shape in env.render().shape)


def test_keyboard_quit_event():
Expand Down Expand Up @@ -118,12 +120,3 @@ def test_keyboard_keyup_event():
event = MockKeyEvent(pygame.KEYUP, RELEVANT_KEY)
game.process_event(event)
assert game.pressed_keys == []


# def test_play_loop():
# env = DummyPlayEnv()
# keys_to_action = {
# (ord('a'),): 0,
# (ord('d'),): 1
# }
# play(env, keys_to_action=keys_to_action)