diff --git a/gym/utils/play.py b/gym/utils/play.py index 74667820581..f04ccae4a62 100644 --- a/gym/utils/play.py +++ b/gym/utils/play.py @@ -1,12 +1,16 @@ -import argparse +from typing import Callable, Dict, Optional, Tuple -import matplotlib import pygame +from numpy.typing import NDArray +from pygame import Surface +from pygame.event import Event import gym -from gym import logger +from gym import Env, logger try: + import matplotlib + matplotlib.use("TkAgg") import matplotlib.pyplot as plt except ImportError as e: @@ -18,7 +22,71 @@ from pygame.locals import VIDEORESIZE -def display_arr(screen, arr, video_size, transpose): +class MissingKeysToAction(Exception): + """Raised when the environment does not have + a default keys_to_action mapping + """ + + +class PlayableGame: + def __init__( + self, + env: Env, + keys_to_action: Optional[Dict[Tuple[int], int]] = None, + zoom: Optional[float] = None, + ): + self.env = env + self.relevant_keys = self._get_relevant_keys(keys_to_action) + self.video_size = self._get_video_size(zoom) + self.screen = pygame.display.set_mode(self.video_size) + self.pressed_keys = [] + self.running = True + + def _get_relevant_keys( + self, keys_to_action: Optional[Dict[Tuple[int], int]] = None + ) -> set: + if keys_to_action is None: + if hasattr(self.env, "get_keys_to_action"): + keys_to_action = self.env.get_keys_to_action() + elif hasattr(self.env.unwrapped, "get_keys_to_action"): + keys_to_action = self.env.unwrapped.get_keys_to_action() + else: + raise MissingKeysToAction( + "%s does not have explicit key to action mapping, " + "please specify one manually" % self.env.spec.id + ) + relevant_keys = set(sum((list(k) for k in keys_to_action.keys()), [])) + return relevant_keys + + def _get_video_size(self, zoom: Optional[float] = None) -> Tuple[int, int]: + # TODO: this needs to be updated when the render API change goes through + rendered = self.env.render(mode="rgb_array") + video_size = [rendered.shape[1], rendered.shape[0]] + + if 
zoom is not None: + video_size = int(video_size[0] * zoom), int(video_size[1] * zoom) + + return video_size + + def process_event(self, event: Event) -> None: + if event.type == pygame.KEYDOWN: + if event.key in self.relevant_keys: + self.pressed_keys.append(event.key) + elif event.key == pygame.K_ESCAPE: + self.running = False + elif event.type == pygame.KEYUP: + if event.key in self.relevant_keys: + self.pressed_keys.remove(event.key) + elif event.type == pygame.QUIT: + self.running = False + elif event.type == VIDEORESIZE: + self.video_size = event.size + self.screen = pygame.display.set_mode(self.video_size) + + +def display_arr( + screen: Surface, arr: NDArray, video_size: Tuple[int, int], transpose: bool +): arr_min, arr_max = arr.min(), arr.max() arr = 255.0 * (arr - arr_min) / (arr_max - arr_min) pyg_img = pygame.surfarray.make_surface(arr.swapaxes(0, 1) if transpose else arr) @@ -26,7 +94,15 @@ def display_arr(screen, arr, video_size, transpose): screen.blit(pyg_img, (0, 0)) -def play(env, transpose=True, fps=30, zoom=None, callback=None, keys_to_action=None): +def play( + env: Env, + transpose: Optional[bool] = True, + fps: Optional[int] = 30, + zoom: Optional[float] = None, + callback: Optional[Callable] = None, + keys_to_action: Optional[Dict[Tuple[int], int]] = None, + seed: Optional[int] = None, +): """Allows one to play the game using keyboard. To simply play the game use: @@ -81,65 +157,35 @@ def callback(obs_t, obs_tp1, action, rew, done, info): # ... } If None, default key_to_action mapping for that env is used, if provided. + seed: int or None + Random seed used when resetting the environment. If None, no seed is used. 
""" - env.reset() - rendered = env.render(mode="rgb_array") - - if keys_to_action is None: - if hasattr(env, "get_keys_to_action"): - keys_to_action = env.get_keys_to_action() - elif hasattr(env.unwrapped, "get_keys_to_action"): - keys_to_action = env.unwrapped.get_keys_to_action() - else: - assert False, ( - env.spec.id - + " does not have explicit key to action mapping, " - + "please specify one manually" - ) - relevant_keys = set(sum(map(list, keys_to_action.keys()), [])) - - video_size = [rendered.shape[1], rendered.shape[0]] - if zoom is not None: - video_size = int(video_size[0] * zoom), int(video_size[1] * zoom) + env.reset(seed=seed) + game = PlayableGame(env, keys_to_action, zoom) - pressed_keys = [] - running = True - env_done = True - - screen = pygame.display.set_mode(video_size) + done = True clock = pygame.time.Clock() - while running: - if env_done: - env_done = False - obs = env.reset() + while game.running: + if done: + done = False + obs = env.reset(seed=seed) else: - action = keys_to_action.get(tuple(sorted(pressed_keys)), 0) + action = keys_to_action.get(tuple(sorted(game.pressed_keys)), 0) prev_obs = obs - obs, rew, env_done, info = env.step(action) + obs, rew, done, info = env.step(action) if callback is not None: - callback(prev_obs, obs, action, rew, env_done, info) + callback(prev_obs, obs, action, rew, done, info) if obs is not None: + # TODO: this needs to be updated when the render API change goes through rendered = env.render(mode="rgb_array") - display_arr(screen, rendered, transpose=transpose, video_size=video_size) + display_arr( + game.screen, rendered, transpose=transpose, video_size=game.video_size + ) # process pygame events for event in pygame.event.get(): - # test events, set key states - if event.type == pygame.KEYDOWN: - if event.key in relevant_keys: - pressed_keys.append(event.key) - elif event.key == 27: - running = False - elif event.type == pygame.KEYUP: - if event.key in relevant_keys: - pressed_keys.remove(event.key) - 
elif event.type == pygame.QUIT: - running = False - elif event.type == VIDEORESIZE: - video_size = event.size - screen = pygame.display.set_mode(video_size) - print(video_size) + game.process_event(event) pygame.display.flip() clock.tick(fps) @@ -180,20 +226,3 @@ def callback(self, obs_t, obs_tp1, action, rew, done, info): ) self.ax[i].set_xlim(xmin, xmax) plt.pause(0.000001) - - -def main(): - parser = argparse.ArgumentParser() - parser.add_argument( - "--env", - type=str, - default="MontezumaRevengeNoFrameskip-v4", - help="Define Environment", - ) - args = parser.parse_args() - env = gym.make(args.env) - play(env, zoom=4, fps=60) - - -if __name__ == "__main__": - main() diff --git a/tests/utils/test_play.py b/tests/utils/test_play.py new file mode 100644 index 00000000000..0f0ee1e46eb --- /dev/null +++ b/tests/utils/test_play.py @@ -0,0 +1,213 @@ +from dataclasses import dataclass, field +from typing import Callable, Optional, Tuple + +import numpy as np +import pygame +import pytest +from pygame import KEYDOWN, KEYUP, QUIT, event +from pygame.event import Event + +import gym +from gym.utils.play import MissingKeysToAction, PlayableGame, play + +RELEVANT_KEY_1 = ord("a") # 97 +RELEVANT_KEY_2 = ord("d") # 100 +IRRELEVANT_KEY = 1 + + +@dataclass +class DummyEnvSpec: + id: str + + +class DummyPlayEnv(gym.Env): + def step(self, action): + obs = np.zeros((1, 1)) + rew, done, info = 1, False, {} + return obs, rew, done, info + + def reset(self, seed=None): + ... 
+ + def render(self, mode="rgb_array"): + return np.zeros((1, 1)) + + +class PlayStatus: + def __init__(self, callback: Callable): + self.data_callback = callback + self.cumulative_reward = 0 + self.last_observation = None + + def callback(self, obs_t, obs_tp1, action, rew, done, info): + _, obs_tp1, _, rew, _, _ = self.data_callback( + obs_t, obs_tp1, action, rew, done, info + ) + self.cumulative_reward += rew + self.last_observation = obs_tp1 + + +def dummy_keys_to_action(): + return {(RELEVANT_KEY_1,): 0, (RELEVANT_KEY_2,): 1} + + +@pytest.fixture(autouse=True) +def close_pygame(): + yield + pygame.quit() + + +def test_play_relevant_keys(): + env = DummyPlayEnv() + game = PlayableGame(env, dummy_keys_to_action()) + assert game.relevant_keys == {RELEVANT_KEY_1, RELEVANT_KEY_2} + + +def test_play_relevant_keys_no_mapping(): + env = DummyPlayEnv() + env.spec = DummyEnvSpec("DummyPlayEnv") + + with pytest.raises(MissingKeysToAction) as info: + PlayableGame(env) + + +def test_play_relevant_keys_with_env_attribute(): + """Env has a keys_to_action attribute""" + env = DummyPlayEnv() + env.get_keys_to_action = dummy_keys_to_action + game = PlayableGame(env) + assert game.relevant_keys == {97, 100} + + +def test_video_size_no_zoom(): + env = DummyPlayEnv() + game = PlayableGame(env, dummy_keys_to_action()) + assert game.video_size == list(env.render().shape) + + +def test_video_size_zoom(): + env = DummyPlayEnv() + zoom = 2.2 + game = PlayableGame(env, dummy_keys_to_action(), zoom) + assert game.video_size == tuple(int(shape * zoom) for shape in env.render().shape) + + +def test_keyboard_quit_event(): + env = DummyPlayEnv() + game = PlayableGame(env, dummy_keys_to_action()) + event = Event(pygame.KEYDOWN, {"key": pygame.K_ESCAPE}) + assert game.running == True + game.process_event(event) + assert game.running == False + + +def test_pygame_quit_event(): + env = DummyPlayEnv() + game = PlayableGame(env, dummy_keys_to_action()) + event = Event(pygame.QUIT) + assert 
game.running == True + game.process_event(event) + assert game.running == False + + +def test_keyboard_relevant_keydown_event(): + env = DummyPlayEnv() + game = PlayableGame(env, dummy_keys_to_action()) + event = Event(pygame.KEYDOWN, {"key": RELEVANT_KEY_1}) + game.process_event(event) + assert game.pressed_keys == [RELEVANT_KEY_1] + + +def test_keyboard_irrelevant_keydown_event(): + env = DummyPlayEnv() + game = PlayableGame(env, dummy_keys_to_action()) + event = Event(pygame.KEYDOWN, {"key": IRRELEVANT_KEY}) + game.process_event(event) + assert game.pressed_keys == [] + + +def test_keyboard_keyup_event(): + env = DummyPlayEnv() + game = PlayableGame(env, dummy_keys_to_action()) + event = Event(pygame.KEYDOWN, {"key": RELEVANT_KEY_1}) + game.process_event(event) + event = Event(pygame.KEYUP, {"key": RELEVANT_KEY_1}) + game.process_event(event) + assert game.pressed_keys == [] + + +def test_play_loop(): + # set of key events to inject into the play loop as callback + callback_events = [ + Event(KEYDOWN, {"key": RELEVANT_KEY_1}), + Event(KEYDOWN, {"key": RELEVANT_KEY_1}), + Event(QUIT), + ] + + def callback(obs_t, obs_tp1, action, rew, done, info): + event.post(callback_events.pop(0)) + return obs_t, obs_tp1, action, rew, done, info + + env = DummyPlayEnv() + cumulative_env_reward = 0 + for s in range( + len(callback_events) + ): # we run the same number of steps executed with play() + _, rew, _, _ = env.step(None) + cumulative_env_reward += rew + + env_play = DummyPlayEnv() + status = PlayStatus(callback) + play(env_play, callback=status.callback, keys_to_action=dummy_keys_to_action()) + + assert status.cumulative_reward == cumulative_env_reward + + +def test_play_loop_real_env(): + SEED = 42 + ENV = "CartPole-v1" + + # set of key events to inject into the play loop as callback + callback_events = [ + Event(KEYDOWN, {"key": RELEVANT_KEY_1}), + Event(KEYUP, {"key": RELEVANT_KEY_1}), + Event(KEYDOWN, {"key": RELEVANT_KEY_2}), + Event(KEYUP, {"key": RELEVANT_KEY_2}), 
+ Event(KEYDOWN, {"key": RELEVANT_KEY_1}), + Event(KEYUP, {"key": RELEVANT_KEY_1}), + Event(KEYDOWN, {"key": RELEVANT_KEY_1}), + Event(KEYUP, {"key": RELEVANT_KEY_1}), + Event(KEYDOWN, {"key": RELEVANT_KEY_2}), + Event(KEYUP, {"key": RELEVANT_KEY_2}), + Event(QUIT), + ] + keydown_events = [k for k in callback_events if k.type == KEYDOWN] + + def callback(obs_t, obs_tp1, action, rew, done, info): + pygame_event = callback_events.pop(0) + event.post(pygame_event) + + # after releasing a key, post new events until + # we have one keydown + while pygame_event.type == KEYUP: + pygame_event = callback_events.pop(0) + event.post(pygame_event) + + return obs_t, obs_tp1, action, rew, done, info + + env = gym.make(ENV) + env.reset(seed=SEED) + keys_to_action = dummy_keys_to_action() + + # first action is 0 because at the first iteration + # we can not inject a callback event into play() + env.step(0) + for e in keydown_events: + action = keys_to_action[(e.key,)] + obs, _, _, _ = env.step(action) + + env_play = gym.make(ENV) + status = PlayStatus(callback) + play(env_play, callback=status.callback, keys_to_action=keys_to_action, seed=SEED) + + assert (status.last_observation == obs).all()