-
Notifications
You must be signed in to change notification settings - Fork 8.7k
Add test gym utils play. Fix #2729 #2743
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 9 commits
87f34da
7982347
2a73571
02b81e5
7c65be7
1b3485f
48d5f07
3b9e08d
9a1d4b6
a6a7cd9
42b5ead
7dabd13
97fa4e9
808d0a1
bf4d2d8
52b37cc
5a8f6a4
b5aae1b
e9d333b
b0081c2
a02323c
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||
|---|---|---|---|---|---|---|---|---|
| @@ -1,12 +1,15 @@ | ||||||||
| import argparse | ||||||||
| from typing import Tuple | ||||||||
|
|
||||||||
| import matplotlib | ||||||||
| import pygame | ||||||||
| from pygame.event import Event | ||||||||
|
|
||||||||
| import gym | ||||||||
| from gym import logger | ||||||||
| from gym import Env, logger | ||||||||
|
|
||||||||
| try: | ||||||||
| import matplotlib | ||||||||
|
|
||||||||
| matplotlib.use("TkAgg") | ||||||||
| import matplotlib.pyplot as plt | ||||||||
| except ImportError as e: | ||||||||
|
|
@@ -18,6 +21,55 @@ | |||||||
| from pygame.locals import VIDEORESIZE | ||||||||
|
|
||||||||
|
|
||||||||
| class PlayableGame: | ||||||||
| def __init__(self, env: Env, keys_to_action: dict = None, zoom: float = None): | ||||||||
| self.env = env | ||||||||
| self.relevant_keys = self._get_relevant_keys(keys_to_action) | ||||||||
| self.video_size = self._get_video_size(zoom) | ||||||||
| self.screen = pygame.display.set_mode(self.video_size) | ||||||||
| self.pressed_keys = [] | ||||||||
| self.running = True | ||||||||
|
|
||||||||
| def _get_relevant_keys(self, keys_to_action: dict) -> set: | ||||||||
|
||||||||
| if keys_to_action is None: | ||||||||
| if hasattr(self.env, "get_keys_to_action"): | ||||||||
| keys_to_action = self.env.get_keys_to_action() | ||||||||
| elif hasattr(self.env.unwrapped, "get_keys_to_action"): | ||||||||
| keys_to_action = self.env.unwrapped.get_keys_to_action() | ||||||||
| else: | ||||||||
| assert False, ( | ||||||||
|
||||||||
| self.env.spec.id | ||||||||
| + " does not have explicit key to action mapping, " | ||||||||
| + "please specify one manually" | ||||||||
| ) | ||||||||
| relevant_keys = set(sum(map(list, keys_to_action.keys()), [])) | ||||||||
|
||||||||
| return relevant_keys | ||||||||
|
|
||||||||
| def _get_video_size(self, zoom: float = None) -> Tuple[int, int]: | ||||||||
|
||||||||
| rendered = self.env.render(mode="rgb_array") | ||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Can you put a TODO here so that we remember to update this when the render API change goes through? |
||||||||
| video_size = [rendered.shape[1], rendered.shape[0]] | ||||||||
|
|
||||||||
| if zoom is not None: | ||||||||
| video_size = int(video_size[0] * zoom), int(video_size[1] * zoom) | ||||||||
|
|
||||||||
| return video_size | ||||||||
|
|
||||||||
| def process_event(self, event: Event) -> None: | ||||||||
| if event.type == pygame.KEYDOWN: | ||||||||
| if event.key in self.relevant_keys: | ||||||||
| self.pressed_keys.append(event.key) | ||||||||
| elif event.key == 27: | ||||||||
|
||||||||
| self.running = False | ||||||||
| elif event.type == pygame.KEYUP: | ||||||||
| if event.key in self.relevant_keys: | ||||||||
| self.pressed_keys.remove(event.key) | ||||||||
| elif event.type == pygame.QUIT: | ||||||||
| self.running = False | ||||||||
| elif event.type == VIDEORESIZE: | ||||||||
| self.video_size = event.size | ||||||||
| self.screen = pygame.display.set_mode(self.video_size) | ||||||||
|
|
||||||||
|
|
||||||||
| def display_arr(screen, arr, video_size, transpose): | ||||||||
|
||||||||
| arr_min, arr_max = arr.min(), arr.max() | ||||||||
| arr = 255.0 * (arr - arr_min) / (arr_max - arr_min) | ||||||||
|
|
@@ -83,63 +135,30 @@ def callback(obs_t, obs_tp1, action, rew, done, info): | |||||||
| If None, default key_to_action mapping for that env is used, if provided. | ||||||||
| """ | ||||||||
| env.reset() | ||||||||
| rendered = env.render(mode="rgb_array") | ||||||||
| game = PlayableGame(env, keys_to_action, zoom) | ||||||||
|
|
||||||||
| if keys_to_action is None: | ||||||||
| if hasattr(env, "get_keys_to_action"): | ||||||||
| keys_to_action = env.get_keys_to_action() | ||||||||
| elif hasattr(env.unwrapped, "get_keys_to_action"): | ||||||||
| keys_to_action = env.unwrapped.get_keys_to_action() | ||||||||
| else: | ||||||||
| assert False, ( | ||||||||
| env.spec.id | ||||||||
| + " does not have explicit key to action mapping, " | ||||||||
| + "please specify one manually" | ||||||||
| ) | ||||||||
| relevant_keys = set(sum(map(list, keys_to_action.keys()), [])) | ||||||||
|
|
||||||||
| video_size = [rendered.shape[1], rendered.shape[0]] | ||||||||
| if zoom is not None: | ||||||||
| video_size = int(video_size[0] * zoom), int(video_size[1] * zoom) | ||||||||
|
|
||||||||
| pressed_keys = [] | ||||||||
| running = True | ||||||||
| env_done = True | ||||||||
|
|
||||||||
| screen = pygame.display.set_mode(video_size) | ||||||||
| done = True | ||||||||
| clock = pygame.time.Clock() | ||||||||
|
|
||||||||
| while running: | ||||||||
| if env_done: | ||||||||
| env_done = False | ||||||||
| while game.running: | ||||||||
| if done: | ||||||||
| done = False | ||||||||
| obs = env.reset() | ||||||||
| else: | ||||||||
| action = keys_to_action.get(tuple(sorted(pressed_keys)), 0) | ||||||||
| action = keys_to_action.get(tuple(sorted(game.pressed_keys)), 0) | ||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Can we always safely assume that `0` is a valid action? In particular, this seems like it'd crash with any continuous-action environment.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes it is actually the default action if no relevant keys are pressed; I have the same doubt regarding
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. @RedTachyon About this, I think we can safely say that, since we cannot map a continuous-action environment to discrete keyboard presses, it is ok to leave gym/gym/envs/classic_control/mountain_car.py Lines 240 to 242 in 5ae6bf9
|
||||||||
| prev_obs = obs | ||||||||
| obs, rew, env_done, info = env.step(action) | ||||||||
| obs, rew, done, info = env.step(action) | ||||||||
| if callback is not None: | ||||||||
| callback(prev_obs, obs, action, rew, env_done, info) | ||||||||
| callback(prev_obs, obs, action, rew, done, info) | ||||||||
| if obs is not None: | ||||||||
| rendered = env.render(mode="rgb_array") | ||||||||
| display_arr(screen, rendered, transpose=transpose, video_size=video_size) | ||||||||
| display_arr( | ||||||||
| game.screen, rendered, transpose=transpose, video_size=game.video_size | ||||||||
| ) | ||||||||
|
|
||||||||
| # process pygame events | ||||||||
| for event in pygame.event.get(): | ||||||||
| # test events, set key states | ||||||||
| if event.type == pygame.KEYDOWN: | ||||||||
| if event.key in relevant_keys: | ||||||||
| pressed_keys.append(event.key) | ||||||||
| elif event.key == 27: | ||||||||
| running = False | ||||||||
| elif event.type == pygame.KEYUP: | ||||||||
| if event.key in relevant_keys: | ||||||||
| pressed_keys.remove(event.key) | ||||||||
| elif event.type == pygame.QUIT: | ||||||||
| running = False | ||||||||
| elif event.type == VIDEORESIZE: | ||||||||
| video_size = event.size | ||||||||
| screen = pygame.display.set_mode(video_size) | ||||||||
| print(video_size) | ||||||||
| game.process_event(event) | ||||||||
|
|
||||||||
| pygame.display.flip() | ||||||||
| clock.tick(fps) | ||||||||
|
|
||||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,161 @@ | ||
| from dataclasses import dataclass, field | ||
| from typing import Callable, Optional, Tuple | ||
|
|
||
| import numpy as np | ||
| import pygame | ||
| import pytest | ||
| from pygame import KEYDOWN, QUIT, event | ||
| from pygame.event import Event | ||
|
|
||
| import gym | ||
| from gym.utils.play import PlayableGame, play | ||
|
|
||
| RELEVANT_KEY = 100 | ||
| IRRELEVANT_KEY = 1 | ||
|
|
||
|
|
||
| @dataclass | ||
| class DummyEnvSpec: | ||
| id: str | ||
|
|
||
|
|
||
| class DummyPlayEnv(gym.Env): | ||
| def step(self, action): | ||
| obs = np.zeros((1, 1)) | ||
| rew, done, info = 1, False, {} | ||
| return obs, rew, done, info | ||
|
|
||
| def reset(self): | ||
| ... | ||
|
|
||
| def render(self, mode="rgb_array"): | ||
| return np.zeros((1, 1)) | ||
|
|
||
|
|
||
| class PlayStatus: | ||
| def __init__(self, callback: Callable): | ||
| self.data_callback = callback | ||
| self.cumulative_reward = 0 | ||
|
|
||
| def callback(self, obs_t, obs_tp1, action, rew, done, info): | ||
| self.cumulative_reward += self.data_callback( | ||
| obs_t, obs_tp1, action, rew, done, info | ||
| ) | ||
|
|
||
|
|
||
| # set of key events to inject into the play loop as callback | ||
| callback_events = [ | ||
| Event(KEYDOWN, {"key": RELEVANT_KEY}), | ||
| Event(KEYDOWN, {"key": RELEVANT_KEY}), | ||
| Event(QUIT), | ||
| ] | ||
|
|
||
|
|
||
| def callback(obs_t, obs_tp1, action, rew, done, info): | ||
| event.post(callback_events.pop(0)) | ||
| return rew | ||
|
|
||
|
|
||
| def dummy_keys_to_action(): | ||
| return {(ord("a"),): 0, (ord("d"),): 1} | ||
|
|
||
|
|
||
| @pytest.fixture(autouse=True) | ||
| def close_pygame(): | ||
| yield | ||
| pygame.quit() | ||
|
|
||
|
|
||
| def test_play_relevant_keys(): | ||
| env = DummyPlayEnv() | ||
| game = PlayableGame(env, dummy_keys_to_action()) | ||
| assert game.relevant_keys == {97, 100} | ||
|
||
|
|
||
|
|
||
| def test_play_relevant_keys_no_mapping(): | ||
| env = DummyPlayEnv() | ||
| env.spec = DummyEnvSpec("DummyPlayEnv") | ||
|
|
||
| with pytest.raises(AssertionError) as info: | ||
| PlayableGame(env) | ||
|
|
||
|
|
||
| def test_play_relevant_keys_with_env_attribute(): | ||
| """Env has a keys_to_action attribute""" | ||
| env = DummyPlayEnv() | ||
| env.get_keys_to_action = dummy_keys_to_action | ||
| game = PlayableGame(env) | ||
| assert game.relevant_keys == {97, 100} | ||
|
|
||
|
|
||
| def test_video_size_no_zoom(): | ||
| env = DummyPlayEnv() | ||
| game = PlayableGame(env, dummy_keys_to_action()) | ||
| assert game.video_size == list(env.render().shape) | ||
|
|
||
|
|
||
| def test_video_size_zoom(): | ||
| env = DummyPlayEnv() | ||
| zoom = 2.2 | ||
| game = PlayableGame(env, dummy_keys_to_action(), zoom) | ||
| assert game.video_size == tuple(int(shape * zoom) for shape in env.render().shape) | ||
|
|
||
|
|
||
| def test_keyboard_quit_event(): | ||
| env = DummyPlayEnv() | ||
| game = PlayableGame(env, dummy_keys_to_action()) | ||
| event = Event(pygame.KEYDOWN, {"key": 27}) | ||
| assert game.running == True | ||
| game.process_event(event) | ||
| assert game.running == False | ||
|
|
||
|
|
||
| def test_pygame_quit_event(): | ||
| env = DummyPlayEnv() | ||
| game = PlayableGame(env, dummy_keys_to_action()) | ||
| event = Event(pygame.QUIT) | ||
| assert game.running == True | ||
| game.process_event(event) | ||
| assert game.running == False | ||
|
|
||
|
|
||
| def test_keyboard_relevant_keydown_event(): | ||
| env = DummyPlayEnv() | ||
| game = PlayableGame(env, dummy_keys_to_action()) | ||
| event = Event(pygame.KEYDOWN, {"key": RELEVANT_KEY}) | ||
| game.process_event(event) | ||
| assert game.pressed_keys == [RELEVANT_KEY] | ||
|
|
||
|
|
||
| def test_keyboard_irrelevant_keydown_event(): | ||
| env = DummyPlayEnv() | ||
| game = PlayableGame(env, dummy_keys_to_action()) | ||
| event = Event(pygame.KEYDOWN, {"key": IRRELEVANT_KEY}) | ||
| game.process_event(event) | ||
| assert game.pressed_keys == [] | ||
|
|
||
|
|
||
| def test_keyboard_keyup_event(): | ||
| env = DummyPlayEnv() | ||
| game = PlayableGame(env, dummy_keys_to_action()) | ||
| event = Event(pygame.KEYDOWN, {"key": RELEVANT_KEY}) | ||
| game.process_event(event) | ||
| event = Event(pygame.KEYUP, {"key": RELEVANT_KEY}) | ||
| game.process_event(event) | ||
| assert game.pressed_keys == [] | ||
|
|
||
|
|
||
| def test_play_loop(): | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is it possible to do the same test with some actual environment? Say, CartPole, define the game, input a bunch of random actions in a loop just to see that it doesn't crash. For a more "advanced" version, if possible, you could instantiate one PlayableGame of an env with a fixed seed, and one normal env with the same fixed seed. You input the same sequence of actions through the game interface, and through the regular
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Yes, following what was suggested by @pseudo-rnd-thoughts, I'm adding a test with an actual environment. It's a little bit tricky to manage the KEYDOWN and KEYUP events with pygame, but I'm on the way |
||
| env = DummyPlayEnv() | ||
| cumulative_env_reward = 0 | ||
| for s in range( | ||
| len(callback_events) | ||
| ): # we run the same number of steps executed with play() | ||
| _, rew, _, _ = env.step(None) | ||
| cumulative_env_reward += rew | ||
|
|
||
| env_play = DummyPlayEnv() | ||
| status = PlayStatus(callback) | ||
| play(env_play, callback=status.callback, keys_to_action=dummy_keys_to_action()) | ||
|
|
||
| assert status.cumulative_reward == cumulative_env_reward | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
`keys_to_action: Optional[dict]`
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Even better if we can specify the key/value types on the dict (I guess values are arbitrary actions, but keys should be keypresses? how are they represented?)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It's `Dict[Tuple[int], int]`, where the key is `(ord(key),)`:
gym/gym/envs/classic_control/mountain_car.py
Lines 240 to 242 in 5ae6bf9
reference of the actions:
gym/gym/envs/classic_control/mountain_car.py
Lines 50 to 54 in 5ae6bf9
gym/gym/utils/play.py
Lines 125 to 135 in 9a1d4b6