Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
refactoring play function. Tests for keys to action mapping.
  • Loading branch information
gianlucadecola committed Apr 10, 2022
commit 87f34da886c6e48b3732a271cfde289b6d8bf3d3
36 changes: 24 additions & 12 deletions gym/utils/play.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,27 @@
from pygame.locals import VIDEORESIZE


class PlayableGame:
def __init__(self, env):
self.env = env

def get_relevant_keys(self, keys_to_action=None):
if keys_to_action is None:
if hasattr(self.env, "get_keys_to_action"):
keys_to_action = self.env.get_keys_to_action()
elif hasattr(self.env.unwrapped, "get_keys_to_action"):
keys_to_action = self.env.unwrapped.get_keys_to_action()
else:
assert False, (
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Explicitly raise some Exception instead of a failing assert, you can either use one of the existing exceptions like AttributeError or something; or make a new one in this file to make it more descriptive.

self.env.spec.id
+ " does not have explicit key to action mapping, "
+ "please specify one manually"
)
relevant_keys = set(sum(map(list, keys_to_action.keys()), []))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

map is nice (functional programming is great), but in Python it's typically more readable to use a comprehension, so it'd be something like
set(sum(list(key) for key in keys_to_action), []))

But is this actually what we want this to do? I'm not sure what each key is meant to be (see previous comment about types), so take another look at the logic here. Intuition tells me that you might have wanted to just do set(keys_to_actions.keys()) or something like that (which might be equivalent to just set(keys_to_actions))

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm going to explore it more in depth, this logic was already there in the old play(), it seems also to me that set(keys_to_actions.keys()) may be fine

return relevant_keys



def display_arr(screen, arr, video_size, transpose):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add type hints here and to play()?

arr_min, arr_max = arr.min(), arr.max()
arr = 255.0 * (arr - arr_min) / (arr_max - arr_min)
Expand Down Expand Up @@ -83,20 +104,11 @@ def callback(obs_t, obs_tp1, action, rew, done, info):
If None, default key_to_action mapping for that env is used, if provided.
"""
env.reset()
game = PlayableGame(env)

rendered = env.render(mode="rgb_array")

if keys_to_action is None:
if hasattr(env, "get_keys_to_action"):
keys_to_action = env.get_keys_to_action()
elif hasattr(env.unwrapped, "get_keys_to_action"):
keys_to_action = env.unwrapped.get_keys_to_action()
else:
assert False, (
env.spec.id
+ " does not have explicit key to action mapping, "
+ "please specify one manually"
)
relevant_keys = set(sum(map(list, keys_to_action.keys()), []))
relevant_keys = game.get_relevant_keys(keys_to_action)

video_size = [rendered.shape[1], rendered.shape[0]]
if zoom is not None:
Expand Down
72 changes: 72 additions & 0 deletions tests/utils/test_play.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
from dataclasses import dataclass
import pytest
import numpy as np
import gym

from gym.utils.play import PlayableGame
from gym.utils.play import play



@dataclass
class DummyEnvSpec():
id: str


class DummyPlayEnv(gym.Env):

def step(self, action):
...

def reset(self):
...

def render(self, mode):
return np.zeros((1,1))


def dummy_keys_to_action():
return {(ord('a'),): 0, (ord('d'),): 1}


def test_play_relvant_keys():
env = DummyPlayEnv()
keys_to_action = {
(ord('a'),): 0,
(ord('d'),): 1
}
game = PlayableGame(env)
relevant_keys = game.get_relevant_keys(keys_to_action)
assert relevant_keys == {97, 100}


def test_play_revant_keys_no_mapping():
env = DummyPlayEnv()
env.spec = DummyEnvSpec("DummyPlayEnv")
game = PlayableGame(env)

with pytest.raises(AssertionError) as info:
game.get_relevant_keys()


def test_play_relevant_keys_with_env_attribute():
"""Env has a keys_to_action attribute
"""
env = DummyPlayEnv()
env.get_keys_to_action = dummy_keys_to_action
game = PlayableGame(env)
relevant_keys = game.get_relevant_keys()
assert relevant_keys == {97, 100}






# def test_play_loop():
# env = DummyPlayEnv()
# keys_to_action = {
# (ord('a'),): 0,
# (ord('d'),): 1
# }
# play(env, keys_to_action=keys_to_action)