From ee712998d01b6584c2dd7f75e816a525e32dbae0 Mon Sep 17 00:00:00 2001 From: Carlos Luis Date: Sat, 19 Feb 2022 22:47:38 +0100 Subject: [PATCH 001/153] Remove references to GoalEnv --- stable_baselines3/common/env_checker.py | 4 ++-- stable_baselines3/common/envs/bit_flipping_env.py | 4 ++-- tests/test_vec_normalize.py | 2 +- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/stable_baselines3/common/env_checker.py b/stable_baselines3/common/env_checker.py index c4e566991c..9ffad69daa 100644 --- a/stable_baselines3/common/env_checker.py +++ b/stable_baselines3/common/env_checker.py @@ -176,8 +176,8 @@ def _check_returned_values(env: gym.Env, observation_space: spaces.Space, action assert isinstance(done, bool), "The `done` signal must be a boolean" assert isinstance(info, dict), "The `info` returned by `step()` must be a python dictionary" - if isinstance(env, gym.GoalEnv): - # For a GoalEnv, the keys are checked at reset + # Goal conditioned env + if hasattr(env, "compute_reward"): assert reward == env.compute_reward(obs["achieved_goal"], obs["desired_goal"], info) diff --git a/stable_baselines3/common/envs/bit_flipping_env.py b/stable_baselines3/common/envs/bit_flipping_env.py index c5d713aa27..50ca1511b8 100644 --- a/stable_baselines3/common/envs/bit_flipping_env.py +++ b/stable_baselines3/common/envs/bit_flipping_env.py @@ -2,13 +2,13 @@ from typing import Any, Dict, Optional, Union import numpy as np -from gym import GoalEnv, spaces +from gym import Env, spaces from gym.envs.registration import EnvSpec from stable_baselines3.common.type_aliases import GymStepReturn -class BitFlippingEnv(GoalEnv): +class BitFlippingEnv(Env): """ Simple bit flipping env, useful to test HER. The goal is to flip all the bits to get a vector of ones. diff --git a/tests/test_vec_normalize.py b/tests/test_vec_normalize.py index c3d1d3065f..6d319ca9c8 100644 --- a/tests/test_vec_normalize.py +++ b/tests/test_vec_normalize.py @@ -40,7 +40,7 @@ def reset(self): return np.array([self.returned_rewards[self.return_reward_idx]]) -class DummyDictEnv(gym.GoalEnv): +class DummyDictEnv(gym.Env): """ Dummy gym goal env for testing purposes """ From 65343f536a99173d6e7814fc6ba6f1ad7962bb57 Mon Sep 17 00:00:00 2001 From: Carlos Luis Date: Sat, 19 Feb 2022 22:56:33 +0100 Subject: [PATCH 002/153] Fix env tests --- setup.cfg | 1 + tests/test_envs.py | 32 ++++++++++++++++++++++---------- 2 files changed, 23 insertions(+), 10 deletions(-) diff --git a/setup.cfg b/setup.cfg index 73ae3dbcc7..9ad143ead5 100644 --- a/setup.cfg +++ b/setup.cfg @@ -13,6 +13,7 @@ filterwarnings = ignore:Parameters to load are deprecated.:DeprecationWarning ignore:the imp module is deprecated in favour of importlib:PendingDeprecationWarning ignore::UserWarning:gym + ignore::DeprecationWarning:gym [pytype] inputs = stable_baselines3 diff --git a/tests/test_envs.py b/tests/test_envs.py index b859ed703b..400ce5dec1 100644 --- a/tests/test_envs.py +++ b/tests/test_envs.py @@ -26,6 +26,8 @@ SimpleMultiObsEnv, ] +GYM_MESSAGE = "Function `rng.randint" + @pytest.mark.parametrize("env_id", ["CartPole-v0", "Pendulum-v1"]) def test_env(env_id): @@ -43,8 +45,8 @@ def test_env(env_id): if env_id == "Pendulum-v1": assert len(record) == 1 else: - # The other environments must pass without warning - assert len(record) == 0 + # The other environments are expected to raise warning introduced by Gym. + check_rand_warning(record) @pytest.mark.parametrize("env_class", ENV_CLASSES) @@ -52,8 +54,8 @@ def test_custom_envs(env_class): env = env_class() with pytest.warns(None) as record: check_env(env) - # No warnings for custom envs - assert len(record) == 0 + # Only randint warning coming from gym + check_rand_warning(record) @pytest.mark.parametrize( @@ -71,8 +73,12 @@ def test_bit_flipping(kwargs): with pytest.warns(None) as record: check_env(env) - # No warnings for custom envs - assert len(record) == 0 + # Only randint warning coming from gym + check_rand_warning(record) + + +def check_rand_warning(record): + assert all(GYM_MESSAGE in warning.message.args[0] for warning in record) def test_high_dimension_action_space(): @@ -150,13 +156,19 @@ def test_non_default_action_spaces(new_action_space): with pytest.warns(None) as record: check_env(env) - # No warnings for custom envs - assert len(record) == 0 + # Only randint warning coming from gym + check_rand_warning(record) + # Change the action space env.action_space = new_action_space - with pytest.warns(UserWarning): - check_env(env) + # Gym raises error for Boxed spaces if low > high + if env.action_space.low[0] > env.action_space.high[0]: + with pytest.raises(ValueError): + check_env(env) + else: + with pytest.warns(UserWarning): + check_env(env) def check_reset_assert_error(env, new_reset_return): From 4d794f36061d5f4c03e6ed405cd5879a626f315f Mon Sep 17 00:00:00 2001 From: Carlos Luis Date: Sun, 20 Feb 2022 00:15:16 +0100 Subject: [PATCH 003/153] Fix bug in test creating invalid box space --- tests/test_save_load.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_save_load.py b/tests/test_save_load.py index 7d810c70eb..821567a536 100644 --- a/tests/test_save_load.py +++ b/tests/test_save_load.py @@ -27,7 +27,7 @@ def select_env(model_class: BaseAlgorithm) -> gym.Env: if model_class == DQN: return IdentityEnv(10) else: - return IdentityEnvBox(10) + return IdentityEnvBox(10, 10) @pytest.mark.parametrize("model_class", MODEL_LIST) From 513ed083bdd0241775d707c55ac15338f6413b3a Mon Sep 17 00:00:00 2001 From: Carlos Luis Date: Mon, 21 Feb 2022 20:27:15 +0100 Subject: [PATCH 004/153] Add classic_control extra packages from gym --- setup.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/setup.py b/setup.py index eabf30c66e..dab2df975f 100644 --- a/setup.py +++ b/setup.py @@ -115,8 +115,8 @@ "extra": [ # For render "opencv-python", - # For atari games, - "gym[atari,accept-rom-license]>=0.21", + # For atari games and classic control envs + "gym[atari,accept-rom-license,classic_control]>=0.21", "pillow", # Tensorboard support "tensorboard>=2.2.0", From e5195a0cde24b8859abd298fed3482cf35c37f0a Mon Sep 17 00:00:00 2001 From: Carlos Luis Date: Mon, 28 Feb 2022 22:56:45 +0100 Subject: [PATCH 005/153] Change back to gym 0.22 for testing --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 3922ad8afc..32756915b4 100644 --- a/setup.py +++ b/setup.py @@ -73,7 +73,7 @@ packages=[package for package in find_packages() if package.startswith("stable_baselines3")], package_data={"stable_baselines3": ["py.typed", "version.txt"]}, install_requires=[ - "gym==0.21", # Fixed version due to breaking changes in 0.22 + "gym>=0.21", "numpy", "torch>=1.8.1", # For saving models From 435f5fbcbee31ac2cf29c27dfa73f9a3900433a8 Mon Sep 17 00:00:00 2001 From: Carlos Luis Date: Mon, 28 Feb 2022 22:57:13 +0100 Subject: [PATCH 006/153] Fix failing set_env test --- tests/test_save_load.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_save_load.py b/tests/test_save_load.py index 469147ce6e..f9deb72cf5 100644 --- a/tests/test_save_load.py +++ b/tests/test_save_load.py @@ -27,7 +27,7 @@ def select_env(model_class: BaseAlgorithm) -> gym.Env: if model_class == DQN: return IdentityEnv(10) else: - return IdentityEnvBox(10, 10) + return IdentityEnvBox(-10, 10) @pytest.mark.parametrize("model_class", MODEL_LIST) From f64346a66ea2307c7fb86c4f57f58481676ec5e2 Mon Sep 17 00:00:00 2001 From: Carlos Luis Date: Thu, 3 Mar 2022 20:58:45 +0100 Subject: [PATCH 007/153] Fix test failiing due to deprectation of env.seed --- stable_baselines3/common/vec_env/dummy_vec_env.py | 7 +++++-- tests/test_vec_monitor.py | 2 +- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/stable_baselines3/common/vec_env/dummy_vec_env.py b/stable_baselines3/common/vec_env/dummy_vec_env.py index 5eb87cdb82..3be3996983 100644 --- a/stable_baselines3/common/vec_env/dummy_vec_env.py +++ b/stable_baselines3/common/vec_env/dummy_vec_env.py @@ -56,9 +56,12 @@ def seed(self, seed: Optional[int] = None) -> List[Union[None, int]]: seeds.append(env.seed(seed + idx)) return seeds - def reset(self) -> VecEnvObs: + def reset(self, seed: Optional[int] = None) -> VecEnvObs: for env_idx in range(self.num_envs): - obs = self.envs[env_idx].reset() + if seed: + obs = self.envs[env_idx].reset(seed=seed) + else: + obs = self.envs[env_idx].reset() self._save_obs(env_idx, obs) return self._obs_from_buf() diff --git a/tests/test_vec_monitor.py b/tests/test_vec_monitor.py index 974202b318..7daf7a3e75 100644 --- a/tests/test_vec_monitor.py +++ b/tests/test_vec_monitor.py @@ -133,7 +133,7 @@ def test_vec_monitor_ppo(recwarn): Test the `VecMonitor` with PPO """ env = DummyVecEnv([lambda: gym.make("CartPole-v1")]) - env.seed(0) + env.reset(seed=0) monitor_env = VecMonitor(env) model = PPO("MlpPolicy", monitor_env, verbose=1, n_steps=64, device="cpu") model.learn(total_timesteps=250) From daaa84c41e3b99968e34f96f241998a91e735013 Mon Sep 17 00:00:00 2001 From: Carlos Luis Date: Tue, 8 Mar 2022 21:01:14 +0100 Subject: [PATCH 008/153] Adjust mean reward threshold in failing test --- tests/test_identity.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_identity.py b/tests/test_identity.py index f5bbc49467..da64d2dcfc 100644 --- a/tests/test_identity.py +++ b/tests/test_identity.py @@ -28,7 +28,7 @@ def test_discrete(model_class, env): model = model_class("MlpPolicy", env_, gamma=0.4, seed=1, **kwargs).learn(n_steps) - evaluate_policy(model, env_, n_eval_episodes=20, reward_threshold=90, warn=False) + evaluate_policy(model, env_, n_eval_episodes=20, reward_threshold=86, warn=False) obs = env.reset() assert np.shape(model.predict(obs)[0]) == np.shape(obs) From e62edde2da5fdc8da88e37cb3eee34d0d49c1b99 Mon Sep 17 00:00:00 2001 From: Carlos Luis Date: Tue, 8 Mar 2022 21:48:29 +0100 Subject: [PATCH 009/153] Fix her test failing due to rng --- tests/test_her.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_her.py b/tests/test_her.py index 0f6d75f6fa..52927a72b4 100644 --- a/tests/test_her.py +++ b/tests/test_her.py @@ -237,7 +237,7 @@ def test_save_load_replay_buffer(tmp_path, recwarn, online_sampling, truncate_la train_freq=4, buffer_size=int(2e4), policy_kwargs=dict(net_arch=[64]), - seed=1, + seed=0, ) model.learn(200) if online_sampling: From 9a41c515a269b87112a19e5b39e880839adc8ab3 Mon Sep 17 00:00:00 2001 From: Carlos Luis Date: Tue, 8 Mar 2022 22:01:50 +0100 Subject: [PATCH 010/153] Change seed and revert reward threshold to 90 --- tests/test_identity.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_identity.py b/tests/test_identity.py index da64d2dcfc..ba190ebeb4 100644 --- a/tests/test_identity.py +++ b/tests/test_identity.py @@ -26,9 +26,9 @@ def test_discrete(model_class, env): # slightly higher budget n_steps = 3500 - model = model_class("MlpPolicy", env_, gamma=0.4, seed=1, **kwargs).learn(n_steps) + model = model_class("MlpPolicy", env_, gamma=0.4, seed=2, **kwargs).learn(n_steps) - evaluate_policy(model, env_, n_eval_episodes=20, reward_threshold=86, warn=False) + evaluate_policy(model, env_, n_eval_episodes=20, reward_threshold=90, warn=False) obs = env.reset() assert np.shape(model.predict(obs)[0]) == np.shape(obs) From 9c73732fa36943c6b1cb955b0ea5b98dced4489d Mon Sep 17 00:00:00 2001 From: Carlos Luis Date: Wed, 16 Mar 2022 21:21:50 +0100 Subject: [PATCH 011/153] Pin gym version --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 32756915b4..7ec54e6a14 100644 --- a/setup.py +++ b/setup.py @@ -73,7 +73,7 @@ packages=[package for package in find_packages() if package.startswith("stable_baselines3")], package_data={"stable_baselines3": ["py.typed", "version.txt"]}, install_requires=[ - "gym>=0.21", + "gym==0.23.1", "numpy", "torch>=1.8.1", # For saving models From 110be7806e34fdd41648b3edcd7e1c74c90a0d26 Mon Sep 17 00:00:00 2001 From: Carlos Luis Date: Wed, 16 Mar 2022 21:26:33 +0100 Subject: [PATCH 012/153] Make VecEnv compatible with gym seeding change --- stable_baselines3/common/envs/bit_flipping_env.py | 4 +++- stable_baselines3/common/envs/identity_env.py | 8 ++++++-- stable_baselines3/common/envs/multi_input_envs.py | 7 +++++-- stable_baselines3/common/vec_env/base_vec_env.py | 3 ++- stable_baselines3/common/vec_env/dummy_vec_env.py | 2 +- stable_baselines3/common/vec_env/subproc_vec_env.py | 9 ++++++--- stable_baselines3/common/vec_env/vec_frame_stack.py | 4 ++-- stable_baselines3/common/vec_env/vec_normalize.py | 5 +++-- tests/test_dict_env.py | 6 +++++- tests/test_gae.py | 11 +++++++++-- tests/test_identity.py | 2 +- tests/test_spaces.py | 10 ++++++++-- tests/test_vec_envs.py | 3 ++- tests/test_vec_normalize.py | 13 ++++++++++--- 14 files changed, 63 insertions(+), 24 deletions(-) diff --git a/stable_baselines3/common/envs/bit_flipping_env.py b/stable_baselines3/common/envs/bit_flipping_env.py index 50ca1511b8..1dff4b0176 100644 --- a/stable_baselines3/common/envs/bit_flipping_env.py +++ b/stable_baselines3/common/envs/bit_flipping_env.py @@ -157,7 +157,9 @@ def _get_obs(self) -> Dict[str, Union[int, np.ndarray]]: ] ) - def reset(self) -> Dict[str, Union[int, np.ndarray]]: + def reset(self, seed: Optional[int] = None) -> Dict[str, Union[int, np.ndarray]]: + if seed is not None: + self.obs_space.seed(seed) self.current_step = 0 self.state = self.obs_space.sample() return self._get_obs() diff --git a/stable_baselines3/common/envs/identity_env.py b/stable_baselines3/common/envs/identity_env.py index 8f6ccd2dce..e06ab63ea8 100644 --- a/stable_baselines3/common/envs/identity_env.py +++ b/stable_baselines3/common/envs/identity_env.py @@ -32,7 +32,9 @@ def __init__(self, dim: Optional[int] = None, space: Optional[Space] = None, ep_ self.num_resets = -1 # Becomes 0 after __init__ exits. self.reset() - def reset(self) -> GymObs: + def reset(self, seed: Optional[int] = None) -> GymObs: + if seed is not None: + super().reset(seed=seed) self.current_step = 0 self.num_resets += 1 self._choose_next_state() @@ -136,7 +138,9 @@ def __init__( self.ep_length = 10 self.current_step = 0 - def reset(self) -> np.ndarray: + def reset(self, seed: Optional[int] = None) -> np.ndarray: + if seed is not None: + super().reset(seed=seed) self.current_step = 0 return self.observation_space.sample() diff --git a/stable_baselines3/common/envs/multi_input_envs.py b/stable_baselines3/common/envs/multi_input_envs.py index 177a641663..6c6c17ffa1 100644 --- a/stable_baselines3/common/envs/multi_input_envs.py +++ b/stable_baselines3/common/envs/multi_input_envs.py @@ -1,4 +1,4 @@ -from typing import Dict, Union +from typing import Dict, Optional, Union import gym import numpy as np @@ -166,12 +166,15 @@ def render(self, mode: str = "human") -> None: """ print(self.log) - def reset(self) -> Dict[str, np.ndarray]: + def reset(self, seed: Optional[int] = None) -> Dict[str, np.ndarray]: """ Resets the environment state and step count and returns reset observation. + :param seed: :return: observation dict {'vec': ..., 'img': ...} """ + if seed is not None: + super().reset(seed=seed) self.count = 0 if not self.random_start: self.state = 0 diff --git a/stable_baselines3/common/vec_env/base_vec_env.py b/stable_baselines3/common/vec_env/base_vec_env.py index d3e624af9b..1736f63668 100644 --- a/stable_baselines3/common/vec_env/base_vec_env.py +++ b/stable_baselines3/common/vec_env/base_vec_env.py @@ -61,7 +61,7 @@ def __init__(self, num_envs: int, observation_space: gym.spaces.Space, action_sp self.action_space = action_space @abstractmethod - def reset(self) -> VecEnvObs: + def reset(self, seed: Optional[int] = None) -> VecEnvObs: """ Reset all the environments and return an array of observations, or a tuple of observation arrays. @@ -70,6 +70,7 @@ def reset(self) -> VecEnvObs: be cancelled and step_wait() should not be called until step_async() is invoked again. + :param seed: The random seed in case we want to set it / change it. :return: observation """ raise NotImplementedError() diff --git a/stable_baselines3/common/vec_env/dummy_vec_env.py b/stable_baselines3/common/vec_env/dummy_vec_env.py index 3be3996983..06a1c768b3 100644 --- a/stable_baselines3/common/vec_env/dummy_vec_env.py +++ b/stable_baselines3/common/vec_env/dummy_vec_env.py @@ -53,7 +53,7 @@ def step_wait(self) -> VecEnvStepReturn: def seed(self, seed: Optional[int] = None) -> List[Union[None, int]]: seeds = list() for idx, env in enumerate(self.envs): - seeds.append(env.seed(seed + idx)) + seeds.append(env.reset(seed=seed + idx)) return seeds def reset(self, seed: Optional[int] = None) -> VecEnvObs: diff --git a/stable_baselines3/common/vec_env/subproc_vec_env.py b/stable_baselines3/common/vec_env/subproc_vec_env.py index 1050f3e332..45392fbe7e 100644 --- a/stable_baselines3/common/vec_env/subproc_vec_env.py +++ b/stable_baselines3/common/vec_env/subproc_vec_env.py @@ -127,9 +127,12 @@ def seed(self, seed: Optional[int] = None) -> List[Union[None, int]]: remote.send(("seed", seed + idx)) return [remote.recv() for remote in self.remotes] - def reset(self) -> VecEnvObs: - for remote in self.remotes: - remote.send(("reset", None)) + def reset(self, seed: Optional[int] = None) -> VecEnvObs: + for idx, remote in enumerate(self.remotes): + if seed is not None: + remote.send(("reset", seed + idx)) + else: + remote.send(("reset", None)) obs = [remote.recv() for remote in self.remotes] return _flatten_obs(obs, self.observation_space) diff --git a/stable_baselines3/common/vec_env/vec_frame_stack.py b/stable_baselines3/common/vec_env/vec_frame_stack.py index e06d5125e0..4a8cf5239f 100644 --- a/stable_baselines3/common/vec_env/vec_frame_stack.py +++ b/stable_baselines3/common/vec_env/vec_frame_stack.py @@ -51,11 +51,11 @@ def step_wait( return observations, rewards, dones, infos - def reset(self) -> Union[np.ndarray, Dict[str, np.ndarray]]: + def reset(self, seed: Optional[int] = None) -> Union[np.ndarray, Dict[str, np.ndarray]]: """ Reset all environments """ - observation = self.venv.reset() # pytype:disable=annotation-type-mismatch + observation = self.venv.reset(seed) # pytype:disable=annotation-type-mismatch observation = self.stackedobs.reset(observation) return observation diff --git a/stable_baselines3/common/vec_env/vec_normalize.py b/stable_baselines3/common/vec_env/vec_normalize.py index f3ee588aba..7620b6ab9e 100644 --- a/stable_baselines3/common/vec_env/vec_normalize.py +++ b/stable_baselines3/common/vec_env/vec_normalize.py @@ -250,12 +250,13 @@ def get_original_reward(self) -> np.ndarray: """ return self.old_reward.copy() - def reset(self) -> Union[np.ndarray, Dict[str, np.ndarray]]: + def reset(self, seed: Optional[int] = None) -> Union[np.ndarray, Dict[str, np.ndarray]]: """ Reset all environments + :param seed: :return: first observation of the episode """ - obs = self.venv.reset() + obs = self.venv.reset(seed) self.old_obs = obs self.returns = np.zeros(self.num_envs) if self.training and self.norm_obs: diff --git a/tests/test_dict_env.py b/tests/test_dict_env.py index 93b13b40e2..d7c8003d63 100644 --- a/tests/test_dict_env.py +++ b/tests/test_dict_env.py @@ -1,3 +1,5 @@ +from typing import Optional + import gym import numpy as np import pytest @@ -72,7 +74,9 @@ def step(self, action): def compute_reward(self, achieved_goal, desired_goal, info): return np.zeros((len(achieved_goal),)) - def reset(self): + def reset(self, seed: Optional[int] = None): + if seed is not None: + self.observation_space.seed(seed) return self.observation_space.sample() def render(self, mode="human"): diff --git a/tests/test_gae.py b/tests/test_gae.py index 54e03b8b1a..91275db8d5 100644 --- a/tests/test_gae.py +++ b/tests/test_gae.py @@ -1,3 +1,5 @@ +from typing import Optional + import gym import numpy as np import pytest @@ -19,7 +21,9 @@ def __init__(self, max_steps=8): def seed(self, seed): self.observation_space.seed(seed) - def reset(self): + def reset(self, seed: Optional[int] = None): + if seed is not None: + self.observation_space.seed(seed) self.n_steps = 0 return self.observation_space.sample() @@ -43,7 +47,10 @@ def __init__(self, n_states=4): self.action_space = gym.spaces.Box(low=-1, high=1, shape=(2,), dtype=np.float32) self.current_state = 0 - def reset(self): + def reset(self, seed: Optional[int] = None): + if seed is not None: + super().reset(seed=seed) + self.current_state = 0 return self.current_state diff --git a/tests/test_identity.py b/tests/test_identity.py index ba190ebeb4..b4dee2f41e 100644 --- a/tests/test_identity.py +++ b/tests/test_identity.py @@ -26,7 +26,7 @@ def test_discrete(model_class, env): # slightly higher budget n_steps = 3500 - model = model_class("MlpPolicy", env_, gamma=0.4, seed=2, **kwargs).learn(n_steps) + model = model_class("MlpPolicy", env_, gamma=0.4, seed=3, **kwargs).learn(n_steps) evaluate_policy(model, env_, n_eval_episodes=20, reward_threshold=90, warn=False) obs = env.reset() diff --git a/tests/test_spaces.py b/tests/test_spaces.py index 54994b2b5d..bbcea2d7c8 100644 --- a/tests/test_spaces.py +++ b/tests/test_spaces.py @@ -1,3 +1,5 @@ +from typing import Optional + import gym import numpy as np import pytest @@ -13,7 +15,9 @@ def __init__(self, nvec): self.observation_space = gym.spaces.MultiDiscrete(nvec) self.action_space = gym.spaces.Box(low=-1, high=1, shape=(2,), dtype=np.float32) - def reset(self): + def reset(self, seed: Optional[int] = None): + if seed is not None: + super().reset(seed=seed) return self.observation_space.sample() def step(self, action): @@ -26,7 +30,9 @@ def __init__(self, n): self.observation_space = gym.spaces.MultiBinary(n) self.action_space = gym.spaces.Box(low=-1, high=1, shape=(2,), dtype=np.float32) - def reset(self): + def reset(self, seed: Optional[int] = None): + if seed is not None: + super().reset(seed=seed) return self.observation_space.sample() def step(self, action): diff --git a/tests/test_vec_envs.py b/tests/test_vec_envs.py index 9a4c1189b3..ebf2acf894 100644 --- a/tests/test_vec_envs.py +++ b/tests/test_vec_envs.py @@ -2,6 +2,7 @@ import functools import itertools import multiprocessing +from typing import Optional import gym import numpy as np @@ -25,7 +26,7 @@ def __init__(self, space): self.current_step = 0 self.ep_length = 4 - def reset(self): + def reset(self, seed: Optional[int] = None): self.current_step = 0 self._choose_next_state() return self.state diff --git a/tests/test_vec_normalize.py b/tests/test_vec_normalize.py index 3951ec9a8d..90b6fe7f53 100644 --- a/tests/test_vec_normalize.py +++ b/tests/test_vec_normalize.py @@ -1,5 +1,6 @@ import operator import warnings +from typing import Optional import gym import numpy as np @@ -36,7 +37,9 @@ def step(self, action): returned_value = self.returned_rewards[index] return np.array([returned_value]), returned_value, self.t == len(self.returned_rewards), {} - def reset(self): + def reset(self, seed: Optional[int] = None): + if seed is not None: + super().reset(seed=seed) self.t = 0 return np.array([self.returned_rewards[self.return_reward_idx]]) @@ -57,7 +60,9 @@ def __init__(self): ) self.action_space = spaces.Box(low=-1, high=1, shape=(3,), dtype=np.float32) - def reset(self): + def reset(self, seed: Optional[int] = None): + if seed is not None: + super().reset(seed=seed) return self.observation_space.sample() def step(self, action): @@ -87,7 +92,9 @@ def __init__(self): ) self.action_space = spaces.Box(low=-1, high=1, shape=(3,), dtype=np.float32) - def reset(self): + def reset(self, seed: Optional[int] = None): + if seed is not None: + super().reset(seed=seed) return self.observation_space.sample() def step(self, action): From dc9c645c41b7745ff587082a9074b1ec9290a441 Mon Sep 17 00:00:00 2001 From: Carlos Luis Date: Thu, 17 Mar 2022 11:55:31 +0100 Subject: [PATCH 013/153] Revert change to VecEnv reset signature --- stable_baselines3/common/vec_env/base_vec_env.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/stable_baselines3/common/vec_env/base_vec_env.py b/stable_baselines3/common/vec_env/base_vec_env.py index 1736f63668..d3e624af9b 100644 --- a/stable_baselines3/common/vec_env/base_vec_env.py +++ b/stable_baselines3/common/vec_env/base_vec_env.py @@ -61,7 +61,7 @@ def __init__(self, num_envs: int, observation_space: gym.spaces.Space, action_sp self.action_space = action_space @abstractmethod - def reset(self, seed: Optional[int] = None) -> VecEnvObs: + def reset(self) -> VecEnvObs: """ Reset all the environments and return an array of observations, or a tuple of observation arrays. @@ -70,7 +70,6 @@ def reset(self, seed: Optional[int] = None) -> VecEnvObs: be cancelled and step_wait() should not be called until step_async() is invoked again. - :param seed: The random seed in case we want to set it / change it. :return: observation """ raise NotImplementedError() From e1c6e1be5cc763d8a3a2a5ceb8f957d9f987400c Mon Sep 17 00:00:00 2001 From: Carlos Luis Date: Thu, 17 Mar 2022 12:02:02 +0100 Subject: [PATCH 014/153] Change subprocenv seed cmd to call reset instead --- stable_baselines3/common/vec_env/subproc_vec_env.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/stable_baselines3/common/vec_env/subproc_vec_env.py b/stable_baselines3/common/vec_env/subproc_vec_env.py index 45392fbe7e..30f551baf7 100644 --- a/stable_baselines3/common/vec_env/subproc_vec_env.py +++ b/stable_baselines3/common/vec_env/subproc_vec_env.py @@ -33,7 +33,7 @@ def _worker( observation = env.reset() remote.send((observation, reward, done, info)) elif cmd == "seed": - remote.send(env.seed(data)) + remote.send(env.reset(seed=data)) elif cmd == "reset": observation = env.reset() remote.send(observation) From 29bd22273e340ba370aaa8219c4dd4358d927cf9 Mon Sep 17 00:00:00 2001 From: Carlos Luis Date: Thu, 17 Mar 2022 12:40:57 +0100 Subject: [PATCH 015/153] Fix type check --- stable_baselines3/common/vec_env/vec_frame_stack.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/stable_baselines3/common/vec_env/vec_frame_stack.py b/stable_baselines3/common/vec_env/vec_frame_stack.py index 4a8cf5239f..e06d5125e0 100644 --- a/stable_baselines3/common/vec_env/vec_frame_stack.py +++ b/stable_baselines3/common/vec_env/vec_frame_stack.py @@ -51,11 +51,11 @@ def step_wait( return observations, rewards, dones, infos - def reset(self, seed: Optional[int] = None) -> Union[np.ndarray, Dict[str, np.ndarray]]: + def reset(self) -> Union[np.ndarray, Dict[str, np.ndarray]]: """ Reset all environments """ - observation = self.venv.reset(seed) # pytype:disable=annotation-type-mismatch + observation = self.venv.reset() # pytype:disable=annotation-type-mismatch observation = self.stackedobs.reset(observation) return observation From b1730f4e6add8be0ee58019b438543c28fe0617c Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Thu, 17 Mar 2022 14:40:42 +0100 Subject: [PATCH 016/153] Add backward compat --- .../common/vec_env/dummy_vec_env.py | 15 +++++++----- .../common/vec_env/subproc_vec_env.py | 23 +++++++++++-------- .../common/vec_env/vec_frame_stack.py | 3 +-- .../common/vec_env/vec_normalize.py | 5 ++-- tests/test_vec_envs.py | 17 ++++++++++++++ tests/test_vec_monitor.py | 2 +- 6 files changed, 44 insertions(+), 21 deletions(-) diff --git a/stable_baselines3/common/vec_env/dummy_vec_env.py b/stable_baselines3/common/vec_env/dummy_vec_env.py index 06a1c768b3..93d566d3fd 100644 --- a/stable_baselines3/common/vec_env/dummy_vec_env.py +++ b/stable_baselines3/common/vec_env/dummy_vec_env.py @@ -1,5 +1,6 @@ from collections import OrderedDict from copy import deepcopy +from inspect import signature from typing import Any, Callable, List, Optional, Sequence, Type, Union import gym @@ -53,15 +54,17 @@ def step_wait(self) -> VecEnvStepReturn: def seed(self, seed: Optional[int] = None) -> List[Union[None, int]]: seeds = list() for idx, env in enumerate(self.envs): - seeds.append(env.reset(seed=seed + idx)) + if "seed" in signature(env.unwrapped.reset).parameters: + # gym >= 0.23.1 + seeds.append(env.reset(seed=seed + idx)) + else: + # Backward compatibility + seeds.append(env.seed(seed=seed + idx)) return seeds - def reset(self, seed: Optional[int] = None) -> VecEnvObs: + def reset(self) -> VecEnvObs: for env_idx in range(self.num_envs): - if seed: - obs = self.envs[env_idx].reset(seed=seed) - else: - obs = self.envs[env_idx].reset() + obs = self.envs[env_idx].reset() self._save_obs(env_idx, obs) return self._obs_from_buf() diff --git a/stable_baselines3/common/vec_env/subproc_vec_env.py b/stable_baselines3/common/vec_env/subproc_vec_env.py index 30f551baf7..012605fd25 100644 --- a/stable_baselines3/common/vec_env/subproc_vec_env.py +++ b/stable_baselines3/common/vec_env/subproc_vec_env.py @@ -1,5 +1,6 @@ import multiprocessing as mp from collections import OrderedDict +from inspect import signature from typing import Any, Callable, List, Optional, Sequence, Tuple, Type, Union import gym @@ -14,8 +15,10 @@ ) -def _worker( - remote: mp.connection.Connection, parent_remote: mp.connection.Connection, env_fn_wrapper: CloudpickleWrapper +def _worker( # noqa: C901 + remote: mp.connection.Connection, + parent_remote: mp.connection.Connection, + env_fn_wrapper: CloudpickleWrapper, ) -> None: # Import here to avoid a circular import from stable_baselines3.common.env_util import is_wrapped @@ -33,7 +36,12 @@ def _worker( observation = env.reset() remote.send((observation, reward, done, info)) elif cmd == "seed": - remote.send(env.reset(seed=data)) + if "seed" in signature(env.unwrapped.reset).parameters: + # gym >= 0.23.1 + remote.send(env.reset(seed=data)) + else: + # Backward compatibility + remote.send(env.seed(seed=data)) elif cmd == "reset": observation = env.reset() remote.send(observation) @@ -127,12 +135,9 @@ def seed(self, seed: Optional[int] = None) -> List[Union[None, int]]: remote.send(("seed", seed + idx)) return [remote.recv() for remote in self.remotes] - def reset(self, seed: Optional[int] = None) -> VecEnvObs: - for idx, remote in enumerate(self.remotes): - if seed is not None: - remote.send(("reset", seed + idx)) - else: - remote.send(("reset", None)) + def reset(self) -> VecEnvObs: + for remote in self.remotes: + remote.send(("reset", None)) obs = [remote.recv() for remote in self.remotes] return _flatten_obs(obs, self.observation_space) diff --git a/stable_baselines3/common/vec_env/vec_frame_stack.py b/stable_baselines3/common/vec_env/vec_frame_stack.py index e06d5125e0..5fdb866f87 100644 --- a/stable_baselines3/common/vec_env/vec_frame_stack.py +++ b/stable_baselines3/common/vec_env/vec_frame_stack.py @@ -55,8 +55,7 @@ def reset(self) -> Union[np.ndarray, Dict[str, np.ndarray]]: """ Reset all environments """ - observation = self.venv.reset() # pytype:disable=annotation-type-mismatch - + observation = self.venv.reset() observation = self.stackedobs.reset(observation) return observation diff --git a/stable_baselines3/common/vec_env/vec_normalize.py b/stable_baselines3/common/vec_env/vec_normalize.py index 7620b6ab9e..f3ee588aba 100644 --- a/stable_baselines3/common/vec_env/vec_normalize.py +++ b/stable_baselines3/common/vec_env/vec_normalize.py @@ -250,13 +250,12 @@ def get_original_reward(self) -> np.ndarray: """ return self.old_reward.copy() - def reset(self, seed: Optional[int] = None) -> Union[np.ndarray, Dict[str, np.ndarray]]: + def reset(self) -> Union[np.ndarray, Dict[str, np.ndarray]]: """ Reset all environments - :param seed: :return: first observation of the episode """ - obs = self.venv.reset(seed) + obs = self.venv.reset() self.old_obs = obs self.returns = np.zeros(self.num_envs) if self.training and self.norm_obs: diff --git a/tests/test_vec_envs.py b/tests/test_vec_envs.py index ebf2acf894..d89adec6ca 100644 --- a/tests/test_vec_envs.py +++ b/tests/test_vec_envs.py @@ -441,3 +441,20 @@ def make_monitored_env(): vec_env = VecFrameStack(vec_env, n_stack=2) assert vec_env.env_is_wrapped(Monitor) == [False, True] + + +@pytest.mark.parametrize("vec_env_class", VEC_ENV_CLASSES) +def test_backward_compat_seed(vec_env_class): + def make_env(): + env = CustomGymEnv(gym.spaces.Box(low=np.zeros(2), high=np.ones(2))) + # Patch reset function to remove seed param + env.reset = env.observation_space.sample + env.seed = env.observation_space.seed + return env + + vec_env = vec_env_class([make_env for _ in range(N_ENVS)]) + vec_env.seed(3) + obs = vec_env.reset() + vec_env.seed(3) + new_obs = vec_env.reset() + assert np.allclose(new_obs, obs) diff --git a/tests/test_vec_monitor.py b/tests/test_vec_monitor.py index 7daf7a3e75..b2c33f9927 100644 --- a/tests/test_vec_monitor.py +++ b/tests/test_vec_monitor.py @@ -133,7 +133,7 @@ def test_vec_monitor_ppo(recwarn): Test the `VecMonitor` with PPO """ env = DummyVecEnv([lambda: gym.make("CartPole-v1")]) - env.reset(seed=0) + env.seed(seed=0) monitor_env = VecMonitor(env) model = PPO("MlpPolicy", monitor_env, verbose=1, n_steps=64, device="cpu") model.learn(total_timesteps=250) From 00e794680ccb92dda5f2bdca0e6d90fe3f1ea592 Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Fri, 25 Mar 2022 18:02:57 +0100 Subject: [PATCH 017/153] Add `compat_gym_seed` helper --- stable_baselines3/common/base_class.py | 5 ++++- stable_baselines3/common/env_util.py | 3 ++- stable_baselines3/common/utils.py | 16 ++++++++++++++++ .../common/vec_env/dummy_vec_env.py | 11 ++++------- .../common/vec_env/subproc_vec_env.py | 9 ++------- tests/test_monitor.py | 6 +++--- 6 files changed, 31 insertions(+), 19 deletions(-) diff --git a/stable_baselines3/common/base_class.py b/stable_baselines3/common/base_class.py index 25c26382c7..11422226b3 100644 --- a/stable_baselines3/common/base_class.py +++ b/stable_baselines3/common/base_class.py @@ -23,6 +23,7 @@ from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule from stable_baselines3.common.utils import ( check_for_correct_spaces, + compat_gym_seed, get_device, get_schedule_fn, get_system_info, @@ -572,10 +573,12 @@ def set_random_seed(self, seed: Optional[int] = None) -> None: return set_random_seed(seed, using_cuda=self.device.type == th.device("cuda").type) self.action_space.seed(seed) + # self.env is always a VecEnv if self.env is not None: self.env.seed(seed) + # Eval env may be a gym.Env, hence the call to compat_gym_seed() if self.eval_env is not None: - self.eval_env.seed(seed) + compat_gym_seed(self.eval_env, seed=seed) def set_parameters( self, diff --git a/stable_baselines3/common/env_util.py b/stable_baselines3/common/env_util.py index 520c50a5f4..82cfc3e968 100644 --- a/stable_baselines3/common/env_util.py +++ b/stable_baselines3/common/env_util.py @@ -5,6 +5,7 @@ from stable_baselines3.common.atari_wrappers import AtariWrapper from stable_baselines3.common.monitor import Monitor +from stable_baselines3.common.utils import compat_gym_seed from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv, VecEnv @@ -81,7 +82,7 @@ def _init(): else: env = env_id(**env_kwargs) if seed is not None: - env.seed(seed + rank) + compat_gym_seed(env, seed=seed + rank) env.action_space.seed(seed + rank) # Wrap the env in a Monitor wrapper # to have additional training information diff --git a/stable_baselines3/common/utils.py b/stable_baselines3/common/utils.py index 8504c8d4bd..b47510789f 100644 --- a/stable_baselines3/common/utils.py +++ b/stable_baselines3/common/utils.py @@ -3,6 +3,7 @@ import platform import random from collections import deque +from inspect import signature from itertools import zip_longest from typing import Dict, Iterable, Optional, Tuple, Union @@ -503,3 +504,18 @@ def get_system_info(print_info: bool = True) -> Tuple[Dict[str, str], str]: if print_info: print(env_info_str) return env_info, env_info_str + + +def compat_gym_seed(env: GymEnv, seed: int) -> None: + """ + Compatibility helper to seed Gym envs. + + :param env: The Gym environment. + :param seed: The seed for the pseudo random generator + """ + if isinstance(env, gym.Env) and "seed" in signature(env.unwrapped.reset).parameters: + # gym >= 0.23.1 + env.reset(seed=seed) + else: + # VecEnv and backward compatibility + env.seed(seed) diff --git a/stable_baselines3/common/vec_env/dummy_vec_env.py b/stable_baselines3/common/vec_env/dummy_vec_env.py index 93d566d3fd..50a8feb090 100644 --- a/stable_baselines3/common/vec_env/dummy_vec_env.py +++ b/stable_baselines3/common/vec_env/dummy_vec_env.py @@ -1,6 +1,5 @@ from collections import OrderedDict from copy import deepcopy -from inspect import signature from typing import Any, Callable, List, Optional, Sequence, Type, Union import gym @@ -52,14 +51,12 @@ def step_wait(self) -> VecEnvStepReturn: return (self._obs_from_buf(), np.copy(self.buf_rews), np.copy(self.buf_dones), deepcopy(self.buf_infos)) def seed(self, seed: Optional[int] = None) -> List[Union[None, int]]: + # Avoid circular import + from stable_baselines3.common.utils import compat_gym_seed + seeds = list() for idx, env in enumerate(self.envs): - if "seed" in signature(env.unwrapped.reset).parameters: - # gym >= 0.23.1 - seeds.append(env.reset(seed=seed + idx)) - else: - # Backward compatibility - seeds.append(env.seed(seed=seed + idx)) + seeds.append(compat_gym_seed(env, seed=seed + idx)) return seeds def reset(self) -> VecEnvObs: diff --git a/stable_baselines3/common/vec_env/subproc_vec_env.py b/stable_baselines3/common/vec_env/subproc_vec_env.py index 012605fd25..c834b2ea96 100644 --- a/stable_baselines3/common/vec_env/subproc_vec_env.py +++ b/stable_baselines3/common/vec_env/subproc_vec_env.py @@ -1,6 +1,5 @@ import multiprocessing as mp from collections import OrderedDict -from inspect import signature from typing import Any, Callable, List, Optional, Sequence, Tuple, Type, Union import gym @@ -22,6 +21,7 @@ def _worker( # noqa: C901 ) -> None: # Import here to avoid a circular import from stable_baselines3.common.env_util import is_wrapped + from stable_baselines3.common.utils import compat_gym_seed parent_remote.close() env = env_fn_wrapper.var() @@ -36,12 +36,7 @@ def _worker( # noqa: C901 observation = env.reset() remote.send((observation, reward, done, info)) elif cmd == "seed": - if "seed" in signature(env.unwrapped.reset).parameters: - # gym >= 0.23.1 - remote.send(env.reset(seed=data)) - else: - # Backward compatibility - remote.send(env.seed(seed=data)) + remote.send(compat_gym_seed(env, seed=data)) elif cmd == "reset": observation = env.reset() remote.send(observation) diff --git a/tests/test_monitor.py b/tests/test_monitor.py index d3d041b4de..ec9dd0016d 100644 --- a/tests/test_monitor.py +++ b/tests/test_monitor.py @@ -13,7 +13,7 @@ def test_monitor(tmp_path): Test the monitor wrapper """ env = gym.make("CartPole-v1") - env.seed(0) + env.reset(seed=0) monitor_file = os.path.join(str(tmp_path), "stable_baselines-test-{}.monitor.csv".format(uuid.uuid4())) monitor_env = Monitor(env, monitor_file) monitor_env.reset() @@ -55,7 +55,7 @@ def test_monitor_load_results(tmp_path): """ tmp_path = str(tmp_path) env1 = gym.make("CartPole-v1") - env1.seed(0) + env1.reset(seed=0) monitor_file1 = os.path.join(tmp_path, "stable_baselines-test-{}.monitor.csv".format(uuid.uuid4())) monitor_env1 = Monitor(env1, monitor_file1) @@ -75,7 +75,7 @@ def test_monitor_load_results(tmp_path): assert results_size1 == episode_count1 env2 = gym.make("CartPole-v1") - env2.seed(0) + env2.reset(seed=0) monitor_file2 = os.path.join(tmp_path, "stable_baselines-test-{}.monitor.csv".format(uuid.uuid4())) monitor_env2 = Monitor(env2, monitor_file2) monitor_files = get_monitor_files(tmp_path) From a116a1a2a658fea24034c31111773301e4794601 Mon Sep 17 00:00:00 2001 From: Carlos Luis Date: Mon, 11 Apr 2022 23:19:12 +0200 Subject: [PATCH 018/153] Add goal env checks in env_checker --- stable_baselines3/common/env_checker.py | 40 +++++++++++++++++++++---- 1 file changed, 35 insertions(+), 5 deletions(-) diff --git a/stable_baselines3/common/env_checker.py b/stable_baselines3/common/env_checker.py index 9ffad69daa..dca915307c 100644 --- a/stable_baselines3/common/env_checker.py +++ b/stable_baselines3/common/env_checker.py @@ -93,6 +93,27 @@ def _check_nan(env: gym.Env) -> None: _, _, _, _ = vec_env.step(action) +def _is_goal_env(env: gym.Env) -> bool: + """ + Check if the env uses the convention for goal-conditioned envs (previously, the gym.GoalEnv interface) + """ + return hasattr(env, "compute_reward") + + +def _check_goal_env_obs(obs: dict, observation_space: spaces.Space, method_name: str) -> None: + """ + Check that an environment implementing the `compute_rewards()` method (previously known as + GoalEnv in gym) contains at least three elements, namely `observation`, `desired_goal`, and + `achieved_goal`. + """ + for key in ["observation", "achieved_goal", "desired_goal"]: + if key not in observation_space.spaces: + raise AssertionError( + f"The observation returned by the `{method_name}()` method of a goal-conditioned env requires the {key}" + "key to be part of the observation dictionary." + ) + + def _check_obs(obs: Union[tuple, dict, np.ndarray, int], observation_space: spaces.Space, method_name: str) -> None: """ Check that the observation returned by the environment @@ -141,7 +162,9 @@ def _check_returned_values(env: gym.Env, observation_space: spaces.Space, action # because env inherits from gym.Env, we assume that `reset()` and `step()` methods exists obs = env.reset() - if isinstance(observation_space, spaces.Dict): + if _is_goal_env(env): + _check_goal_env_obs(obs, observation_space, "reset") + elif isinstance(observation_space, spaces.Dict): assert isinstance(obs, dict), "The observation returned by `reset()` must be a dictionary" for key in observation_space.spaces.keys(): try: @@ -160,14 +183,15 @@ def _check_returned_values(env: gym.Env, observation_space: spaces.Space, action # Unpack obs, reward, done, info = data - if isinstance(observation_space, spaces.Dict): + if _is_goal_env(env): + _check_goal_env_obs(obs, observation_space, "step") + elif isinstance(observation_space, spaces.Dict): assert isinstance(obs, dict), "The observation returned by `step()` must be a dictionary" for key in observation_space.spaces.keys(): try: _check_obs(obs[key], observation_space.spaces[key], "step") except AssertionError as e: raise AssertionError(f"Error while checking key={key}: " + str(e)) - else: _check_obs(obs, observation_space, "step") @@ -183,8 +207,9 @@ def _check_returned_values(env: gym.Env, observation_space: spaces.Space, action def _check_spaces(env: gym.Env) -> None: """ - Check that the observation and action spaces are defined - and inherit from gym.spaces.Space. + Check that the observation and action spaces are defined and inherit from gym.spaces.Space. For + envs that follow the goal-conditioned standard (previously, the gym.GoalEnv interface) we check + the observation space is gym.spaces.Dict """ # Helper to link to the code, because gym has no proper documentation gym_spaces = " cf https://github.com/openai/gym/blob/master/gym/spaces/" @@ -195,6 +220,11 @@ def _check_spaces(env: gym.Env) -> None: assert isinstance(env.observation_space, spaces.Space), "The observation space must inherit from gym.spaces" + gym_spaces assert isinstance(env.action_space, spaces.Space), "The action space must inherit from gym.spaces" + gym_spaces + if _is_goal_env(env): + assert isinstance( + env.observation_space, spaces.Dict + ), "Goal conditioned envs (previously gym.GoalEnv) require the observation space to be gym.spaces.Dict" + # Check render cannot be covered by CI def _check_render(env: gym.Env, warn: bool = True, headless: bool = False) -> None: # pragma: no cover From 87809194008c03bf3657fa384a1a15456a4912e1 Mon Sep 17 00:00:00 2001 From: Carlos Luis Date: Mon, 11 Apr 2022 23:19:58 +0200 Subject: [PATCH 019/153] Add docs on HER requirements for envs --- docs/modules/her.rst | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/docs/modules/her.rst b/docs/modules/her.rst index 82bf745d65..21c6d0c8d8 100644 --- a/docs/modules/her.rst +++ b/docs/modules/her.rst @@ -22,7 +22,10 @@ It creates "virtual" transitions by relabeling transitions (changing the desired .. warning:: - HER requires the environment to inherits from `gym.GoalEnv `_ + HER requires the environment to follow the legacy `gym.GoalEnv `_ + In short, the `gym.Env` must have: + - a vectorized implementation of ``compute_reward()`` + - a dictionary observation space with at least three keys: ``observation``, ``achieved_goal`` and ``desired_goal`` .. warning:: From c2ab5cdac0dcaf1d0789433f5de76e3004d11fe9 Mon Sep 17 00:00:00 2001 From: Carlos Luis Date: Wed, 13 Apr 2022 23:10:49 +0200 Subject: [PATCH 020/153] Capture user warning in test with inverted box space --- tests/test_envs.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_envs.py b/tests/test_envs.py index e6c6c9515a..a42247b99e 100644 --- a/tests/test_envs.py +++ b/tests/test_envs.py @@ -165,7 +165,7 @@ def test_non_default_action_spaces(new_action_space): # Gym raises error for Boxed spaces if low > high if env.action_space.low[0] > env.action_space.high[0]: - with pytest.raises(ValueError): + with pytest.raises(ValueError), pytest.warns(UserWarning): check_env(env) else: with pytest.warns(UserWarning): From cb50e9e12e9b4f75fff7f2492e99e35e8bb2f492 Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN Date: Mon, 18 Apr 2022 21:34:45 +0200 Subject: [PATCH 021/153] Update ale-py version --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 94fc4a6727..cca07cc05c 100644 --- a/setup.py +++ b/setup.py @@ -117,7 +117,7 @@ "opencv-python", "pygame", # For atari games, - "ale-py~=0.7.4", + "ale-py~=0.7.5", "autorom[accept-rom-license]~=0.4.2", "pillow", # Tensorboard support From de3f086fcd0ec155fc525d95d266847d1deb15ce Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Tue, 19 Apr 2022 09:53:30 +0200 Subject: [PATCH 022/153] Fix randint --- stable_baselines3/common/atari_wrappers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/stable_baselines3/common/atari_wrappers.py b/stable_baselines3/common/atari_wrappers.py index 832ad9f235..b395348e14 100644 --- a/stable_baselines3/common/atari_wrappers.py +++ b/stable_baselines3/common/atari_wrappers.py @@ -33,7 +33,7 @@ def reset(self, **kwargs) -> np.ndarray: if self.override_num_noops is not None: noops = self.override_num_noops else: - noops = self.unwrapped.np_random.randint(1, self.noop_max + 1) + noops = self.unwrapped.np_random.integers(1, self.noop_max + 1) assert noops > 0 obs = np.zeros(0) for _ in range(noops): From 1980db2faab03c54d242c1b7242dd1437e99e5eb Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Tue, 19 Apr 2022 10:26:00 +0200 Subject: [PATCH 023/153] Allow noop_max to be zero --- stable_baselines3/common/atari_wrappers.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/stable_baselines3/common/atari_wrappers.py b/stable_baselines3/common/atari_wrappers.py index b395348e14..ef4b03e063 100644 --- a/stable_baselines3/common/atari_wrappers.py +++ b/stable_baselines3/common/atari_wrappers.py @@ -235,8 +235,10 @@ def __init__( terminal_on_life_loss: bool = True, clip_reward: bool = True, ): - env = NoopResetEnv(env, noop_max=noop_max) - env = MaxAndSkipEnv(env, skip=frame_skip) + if noop_max > 0: + env = NoopResetEnv(env, noop_max=noop_max) + if frame_skip > 0: + env = MaxAndSkipEnv(env, skip=frame_skip) if terminal_on_life_loss: env = EpisodicLifeEnv(env) if "FIRE" in env.unwrapped.get_action_meanings(): From 90adf8fb3ea10ac67bd8a58082e8287f848de044 Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Sun, 8 May 2022 15:12:11 +0200 Subject: [PATCH 024/153] Update changelog --- docs/misc/changelog.rst | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 42a1d5a3c8..f63bf3c24b 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -12,9 +12,11 @@ Breaking Changes: - Changed the way policy "aliases" are handled ("MlpPolicy", "CnnPolicy", ...), removing the former ``register_policy`` helper, ``policy_base`` parameter and using ``policy_aliases`` static attributes instead (@Gregwar) - SB3 now requires PyTorch >= 1.11 +- Switched minimum Gym version to 0.23.1 (@carlosluis) New Features: ^^^^^^^^^^^^^ +- ``noop_max`` and ``frame_skip`` are now allowed to be equal to zero when using ``AtariWrapper`` SB3-Contrib ^^^^^^^^^^^ @@ -48,7 +50,7 @@ Release 1.5.0 (2022-03-25) Breaking Changes: ^^^^^^^^^^^^^^^^^ -- Switched minimum Gym version to 0.21.0. +- Switched minimum Gym version to 0.21.0 New Features: ^^^^^^^^^^^^^ @@ -968,4 +970,4 @@ And all the contributors: @wkirgsn @AechPro @CUN-bjy @batu @IljaAvadiev @timokau @kachayev @cleversonahum @eleurent @ac-93 @cove9988 @theDebugger811 @hsuehch @Demetrio92 @thomasgubler @IperGiove @ScheiklP @simoninithomas @armandpl @manuel-delverme @Gautam-J @gianlucadecola @buoyancy99 @caburu @xy9485 -@Gregwar @ycheng517 @quantitative-technologies @bcollazo +@Gregwar @ycheng517 @quantitative-technologies @bcollazo @carlosluis From 0f240fa0a0d230c04cd30e87b9cd0c41e6d40667 Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Tue, 10 May 2022 11:21:16 +0200 Subject: [PATCH 025/153] Update docker image --- .gitlab-ci.yml | 2 +- Dockerfile | 3 +++ docs/misc/changelog.rst | 1 + scripts/build_docker.sh | 6 +++--- 4 files changed, 8 insertions(+), 4 deletions(-) diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index 20953d2f6a..d9a3f7120e 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -1,4 +1,4 @@ -image: stablebaselines/stable-baselines3-cpu:1.4.1a0 +image: stablebaselines/stable-baselines3-cpu:1.5.1a6 type-check: script: diff --git a/Dockerfile b/Dockerfile index 8dfbbbf4cc..e85a3c46ad 100644 --- a/Dockerfile +++ b/Dockerfile @@ -3,6 +3,9 @@ FROM $PARENT_IMAGE ARG PYTORCH_DEPS=cpuonly ARG PYTHON_VERSION=3.7 +# for tzdata +ENV DEBIAN_FRONTEND="noninteractive" TZ="Europe/Paris" + RUN apt-get update && apt-get install -y --no-install-recommends \ build-essential \ cmake \ diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index d57d2b0c85..14a46f15bf 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -34,6 +34,7 @@ Deprecations: Others: ^^^^^^^ - Upgraded to Python 3.7+ syntax using ``pyupgrade`` +- Updated docker base image to Ubuntu 20.04 and cuda 11.3 Documentation: ^^^^^^^^^^^^^^ diff --git a/scripts/build_docker.sh b/scripts/build_docker.sh index 13ac86b17d..3f0d5ae7c9 100755 --- a/scripts/build_docker.sh +++ b/scripts/build_docker.sh @@ -1,14 +1,14 @@ #!/bin/bash -CPU_PARENT=ubuntu:18.04 -GPU_PARENT=nvidia/cuda:10.1-cudnn7-runtime-ubuntu18.04 +CPU_PARENT=ubuntu:20.04 +GPU_PARENT=nvidia/cuda:11.3.1-base-ubuntu20.04 TAG=stablebaselines/stable-baselines3 VERSION=$(cat ./stable_baselines3/version.txt) if [[ ${USE_GPU} == "True" ]]; then PARENT=${GPU_PARENT} - PYTORCH_DEPS="cudatoolkit=10.1" + PYTORCH_DEPS="cudatoolkit=11.3" else PARENT=${CPU_PARENT} PYTORCH_DEPS="cpuonly" From 3087d58b0bb8b513c0b260e68c22d533d562c753 Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Fri, 13 May 2022 17:33:00 +0200 Subject: [PATCH 026/153] Update doc conda env and dockerfile --- Dockerfile | 2 +- docs/conda_env.yml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/Dockerfile b/Dockerfile index e85a3c46ad..96588ef91d 100644 --- a/Dockerfile +++ b/Dockerfile @@ -23,7 +23,7 @@ RUN curl -o ~/miniconda.sh https://repo.anaconda.com/miniconda/Miniconda3-latest ~/miniconda.sh -b -p /opt/conda && \ rm ~/miniconda.sh && \ /opt/conda/bin/conda install -y python=$PYTHON_VERSION numpy pyyaml scipy ipython mkl mkl-include && \ - /opt/conda/bin/conda install -y pytorch $PYTORCH_DEPS -c pytorch && \ + /opt/conda/bin/conda install -y pytorch=1.11 $PYTORCH_DEPS -c pytorch && \ /opt/conda/bin/conda clean -ya ENV PATH /opt/conda/bin:$PATH diff --git a/docs/conda_env.yml b/docs/conda_env.yml index a01d37bcec..d3f363df74 100644 --- a/docs/conda_env.yml +++ b/docs/conda_env.yml @@ -6,7 +6,7 @@ dependencies: - cpuonly=1.0=0 - pip=21.1 - python=3.7 - - pytorch=1.8.1=py3.7_cpu_0 + - pytorch=1.11.0=py3.7_cpu_0 - pip: - gym>=0.17.2 - cloudpickle From 76290160f5df09b52836ff80f35e81e86bdcc2a6 Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Fri, 13 May 2022 20:32:48 +0200 Subject: [PATCH 027/153] Custom envs should not have any warnings --- docs/modules/her.rst | 6 +++--- tests/test_envs.py | 29 ++++++++++------------------- 2 files changed, 13 insertions(+), 22 deletions(-) diff --git a/docs/modules/her.rst b/docs/modules/her.rst index 21c6d0c8d8..7f3feefbc6 100644 --- a/docs/modules/her.rst +++ b/docs/modules/her.rst @@ -22,10 +22,10 @@ It creates "virtual" transitions by relabeling transitions (changing the desired .. warning:: - HER requires the environment to follow the legacy `gym.GoalEnv `_ - In short, the `gym.Env` must have: + HER requires the environment to follow the legacy `gym.GoalEnv interface `_ + In short, the ``gym.Env`` must have: - a vectorized implementation of ``compute_reward()`` - - a dictionary observation space with at least three keys: ``observation``, ``achieved_goal`` and ``desired_goal`` + - a dictionary observation space with three keys: ``observation``, ``achieved_goal`` and ``desired_goal`` .. warning:: diff --git a/tests/test_envs.py b/tests/test_envs.py index a42247b99e..b340e965d8 100644 --- a/tests/test_envs.py +++ b/tests/test_envs.py @@ -46,8 +46,8 @@ def test_env(env_id): if env_id == "Pendulum-v1": assert len(record) == 1 else: - # The other environments are expected to raise warning introduced by Gym. - check_rand_warning(record) + # The other environments must pass without warning + assert len(record) == 0 @pytest.mark.parametrize("env_class", ENV_CLASSES) @@ -55,8 +55,8 @@ def test_custom_envs(env_class): env = env_class() with warnings.catch_warnings(record=True) as record: check_env(env) - # Only randint warning coming from gym - check_rand_warning(record) + # No warnings for custom envs + assert len(record) == 0 @pytest.mark.parametrize( @@ -74,12 +74,8 @@ def test_bit_flipping(kwargs): with warnings.catch_warnings(record=True) as record: check_env(env) - # Only randint warning coming from gym - check_rand_warning(record) - - -def check_rand_warning(record): - assert all(GYM_MESSAGE in warning.message.args[0] for warning in record) + # No warnings for custom envs + assert len(record) == 0 def test_high_dimension_action_space(): @@ -157,19 +153,14 @@ def test_non_default_action_spaces(new_action_space): with warnings.catch_warnings(record=True) as record: check_env(env) - # Only randint warning coming from gym - check_rand_warning(record) + # No warnings for custom envs + assert len(record) == 0 # Change the action space env.action_space = new_action_space - # Gym raises error for Boxed spaces if low > high - if env.action_space.low[0] > env.action_space.high[0]: - with pytest.raises(ValueError), pytest.warns(UserWarning): - check_env(env) - else: - with pytest.warns(UserWarning): - check_env(env) + with pytest.warns(UserWarning): + check_env(env) def check_reset_assert_error(env, new_reset_return): From 7bb643bf363f6d265bfe782d6c33f0515321df37 Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Fri, 13 May 2022 20:58:55 +0200 Subject: [PATCH 028/153] Fix test for numpy >= 1.21 --- tests/test_envs.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/tests/test_envs.py b/tests/test_envs.py index b340e965d8..69dabe5be2 100644 --- a/tests/test_envs.py +++ b/tests/test_envs.py @@ -27,8 +27,6 @@ SimpleMultiObsEnv, ] -GYM_MESSAGE = "Function `rng.randint" - @pytest.mark.parametrize("env_id", ["CartPole-v0", "Pendulum-v1"]) def test_env(env_id): @@ -159,8 +157,14 @@ def test_non_default_action_spaces(new_action_space): # Change the action space env.action_space = new_action_space - with pytest.warns(UserWarning): - check_env(env) + low, high = new_action_space.low[0], new_action_space.high[0] + # numpy >= 1.21 raises a ValueError + if int(np.__version__.split(".")[1]) >= 21 and (low > high): + with pytest.raises(ValueError), pytest.warns(UserWarning): + check_env(env) + else: + with pytest.warns(UserWarning): + check_env(env) def check_reset_assert_error(env, new_reset_return): From 706f07271cb7582dd05243c61adf1223dd9d4a51 Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Fri, 13 May 2022 21:40:39 +0200 Subject: [PATCH 029/153] Add check for vectorized compute reward --- stable_baselines3/common/env_checker.py | 52 +++++++++++++++++++++---- tests/test_envs.py | 11 ++++++ 2 files changed, 56 insertions(+), 7 deletions(-) diff --git a/stable_baselines3/common/env_checker.py b/stable_baselines3/common/env_checker.py index dca915307c..20ad629322 100644 --- a/stable_baselines3/common/env_checker.py +++ b/stable_baselines3/common/env_checker.py @@ -1,5 +1,5 @@ import warnings -from typing import Union +from typing import Any, Dict, Union import gym import numpy as np @@ -102,19 +102,56 @@ def _is_goal_env(env: gym.Env) -> bool: def _check_goal_env_obs(obs: dict, observation_space: spaces.Space, method_name: str) -> None: """ - Check that an environment implementing the `compute_rewards()` method (previously known as - GoalEnv in gym) contains at least three elements, namely `observation`, `desired_goal`, and - `achieved_goal`. + Check that an environment implementing the `compute_rewards()` method + (previously known as GoalEnv in gym) contains three elements, + namely `observation`, `desired_goal`, and `achieved_goal`. """ + assert len(observation_space.spaces) == 3, ( + "A goal conditioned env must contain 3 observation keys: `observation`, `desired_goal`, and `achieved_goal`." + f"The current observation contains {len(observation_space.spaces)} keys: {list(observation_space.spaces.keys())}" + ) + for key in ["observation", "achieved_goal", "desired_goal"]: if key not in observation_space.spaces: raise AssertionError( - f"The observation returned by the `{method_name}()` method of a goal-conditioned env requires the {key}" - "key to be part of the observation dictionary." + f"The observation returned by the `{method_name}()` method of a goal-conditioned env requires the '{key}' " + "key to be part of the observation dictionary. " + f"Current keys are {list(observation_space.spaces.keys())}" ) -def _check_obs(obs: Union[tuple, dict, np.ndarray, int], observation_space: spaces.Space, method_name: str) -> None: +def _check_goal_env_compute_reward( + obs: Dict[str, Union[np.ndarray, int]], + env: gym.Env, + reward: float, + info: Dict[str, Any], +): + """ + Check that reward is computed with `compute_reward` + and that the implementation is vectorized. + """ + achieved_goal, desired_goal = obs["achieved_goal"], obs["desired_goal"] + assert reward == env.compute_reward( + achieved_goal, desired_goal, info + ), "The reward was not computed with `compute_reward()`" + + achieved_goal, desired_goal = np.array(achieved_goal), np.array(desired_goal) + batch_achieved_goals = np.array([achieved_goal, achieved_goal]) + batch_desired_goals = np.array([desired_goal, desired_goal]) + if isinstance(achieved_goal, int) or len(achieved_goal.shape) == 0: + batch_achieved_goals = batch_achieved_goals.reshape(2, 1) + batch_desired_goals = batch_desired_goals.reshape(2, 1) + batch_infos = np.array([info, info]) + rewards = env.compute_reward(batch_achieved_goals, batch_desired_goals, batch_infos) + assert rewards.shape == (2,), f"Unexpected shape for vectorized computation of reward: {rewards.shape} != (2,)" + assert rewards[0] == reward, f"Vectorized computation of reward differs from single computation: {rewards[0]} != {reward}" + + +def _check_obs( + obs: Union[tuple, dict, np.ndarray, int], + observation_space: spaces.Space, + method_name: str, +) -> None: """ Check that the observation returned by the environment correspond to the declared one. @@ -185,6 +222,7 @@ def _check_returned_values(env: gym.Env, observation_space: spaces.Space, action if _is_goal_env(env): _check_goal_env_obs(obs, observation_space, "step") + _check_goal_env_compute_reward(obs, env, reward, info) elif isinstance(observation_space, spaces.Dict): assert isinstance(obs, dict), "The observation returned by `step()` must be a dictionary" for key in observation_space.spaces.keys(): diff --git a/tests/test_envs.py b/tests/test_envs.py index 69dabe5be2..b883559f05 100644 --- a/tests/test_envs.py +++ b/tests/test_envs.py @@ -75,6 +75,17 @@ def test_bit_flipping(kwargs): # No warnings for custom envs assert len(record) == 0 + # Remove a key, must throw an error + obs_space = env.observation_space.spaces["observation"] + del env.observation_space.spaces["observation"] + with pytest.raises(AssertionError): + check_env(env) + + # Rename a key, must throw an error + env.observation_space.spaces["obs"] = obs_space + with pytest.raises(AssertionError): + check_env(env) + def test_high_dimension_action_space(): """ From 77d188fb6f5b74fd81d9bed1d61dacb6bed453d1 Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Sun, 29 May 2022 17:52:19 +0200 Subject: [PATCH 030/153] Bump to gym 0.24 --- docs/misc/changelog.rst | 2 +- setup.py | 2 +- tests/test_envs.py | 1 + 3 files changed, 3 insertions(+), 2 deletions(-) diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 5b7c45e1de..e54b35c395 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -12,7 +12,7 @@ Breaking Changes: - Changed the way policy "aliases" are handled ("MlpPolicy", "CnnPolicy", ...), removing the former ``register_policy`` helper, ``policy_base`` parameter and using ``policy_aliases`` static attributes instead (@Gregwar) - SB3 now requires PyTorch >= 1.11 -- Switched minimum Gym version to 0.23.1 (@carlosluis) +- Switched minimum Gym version to 0.24 (@carlosluis) New Features: ^^^^^^^^^^^^^ diff --git a/setup.py b/setup.py index f3b6ec5f1d..c4d175a3da 100644 --- a/setup.py +++ b/setup.py @@ -73,7 +73,7 @@ packages=[package for package in find_packages() if package.startswith("stable_baselines3")], package_data={"stable_baselines3": ["py.typed", "version.txt"]}, install_requires=[ - "gym==0.23.1", + "gym==0.24", "numpy", "torch>=1.11", # For saving models diff --git a/tests/test_envs.py b/tests/test_envs.py index cc8073ceb2..df8c396f62 100644 --- a/tests/test_envs.py +++ b/tests/test_envs.py @@ -175,6 +175,7 @@ def test_non_default_action_spaces(new_action_space): # the rest only warning if not np.all(np.isfinite(env.action_space.low)): with pytest.raises(AssertionError), pytest.warns(UserWarning): + check_env(env) # numpy >= 1.21 raises a ValueError elif int(np.__version__.split(".")[1]) >= 21 and (low > high): with pytest.raises(ValueError), pytest.warns(UserWarning): From 68cec40622ab3527252dc3866b1bdc8a9773a45d Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Sun, 29 May 2022 18:03:30 +0200 Subject: [PATCH 031/153] Fix gym default step docstring --- stable_baselines3/common/envs/bit_flipping_env.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/stable_baselines3/common/envs/bit_flipping_env.py b/stable_baselines3/common/envs/bit_flipping_env.py index a768f9d182..5089ee7d30 100644 --- a/stable_baselines3/common/envs/bit_flipping_env.py +++ b/stable_baselines3/common/envs/bit_flipping_env.py @@ -165,6 +165,12 @@ def reset(self, seed: Optional[int] = None) -> Dict[str, Union[int, np.ndarray]] return self._get_obs() def step(self, action: Union[np.ndarray, int]) -> GymStepReturn: + """ + Step into the env. + + :param action: + :return: + """ if self.continuous: self.state[action > 0] = 1 - self.state[action > 0] else: From 0072b77156c006ada8a1d6e26ce347ed85a83eeb Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Sun, 29 May 2022 18:14:35 +0200 Subject: [PATCH 032/153] Test downgrading gym --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index c4d175a3da..f3b6ec5f1d 100644 --- a/setup.py +++ b/setup.py @@ -73,7 +73,7 @@ packages=[package for package in find_packages() if package.startswith("stable_baselines3")], package_data={"stable_baselines3": ["py.typed", "version.txt"]}, install_requires=[ - "gym==0.24", + "gym==0.23.1", "numpy", "torch>=1.11", # For saving models From 07a85b83b9297fd3653109fde286d647b5c5f190 Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Sun, 29 May 2022 18:36:17 +0200 Subject: [PATCH 033/153] Revert "Test downgrading gym" This reverts commit 0072b77156c006ada8a1d6e26ce347ed85a83eeb. --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index f3b6ec5f1d..c4d175a3da 100644 --- a/setup.py +++ b/setup.py @@ -73,7 +73,7 @@ packages=[package for package in find_packages() if package.startswith("stable_baselines3")], package_data={"stable_baselines3": ["py.typed", "version.txt"]}, install_requires=[ - "gym==0.23.1", + "gym==0.24", "numpy", "torch>=1.11", # For saving models From d755cc6320a12151e7871d321591b8aeecd458d0 Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Sun, 29 May 2022 18:40:44 +0200 Subject: [PATCH 034/153] Fix protobuf error --- docs/conda_env.yml | 2 +- setup.py | 3 +++ 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/docs/conda_env.yml b/docs/conda_env.yml index d3f363df74..3f4f25e862 100644 --- a/docs/conda_env.yml +++ b/docs/conda_env.yml @@ -8,7 +8,7 @@ dependencies: - python=3.7 - pytorch=1.11.0=py3.7_cpu_0 - pip: - - gym>=0.17.2 + - gym==0.24 - cloudpickle - opencv-python-headless - pandas diff --git a/setup.py b/setup.py index c4d175a3da..030a943e84 100644 --- a/setup.py +++ b/setup.py @@ -111,6 +111,9 @@ "sphinxcontrib.spelling", # Type hints support "sphinx-autodoc-typehints", + # weird error with protobuf > 4.0 + # probably due to tensorboard + "protobuf~=3.19.0", ], "extra": [ # For render From 99b91eb117ffb1b2126738f0aed0c92c8db9ca2a Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Sun, 29 May 2022 18:49:20 +0200 Subject: [PATCH 035/153] Fix in dependencies --- Makefile | 3 ++- docs/conda_env.yml | 6 +++--- setup.py | 7 ++----- 3 files changed, 7 insertions(+), 9 deletions(-) diff --git a/Makefile b/Makefile index 9954c7d7b1..02851cf45e 100644 --- a/Makefile +++ b/Makefile @@ -29,7 +29,8 @@ check-codestyle: commit-checks: format type lint doc: - cd docs && make html + # Prevent weird error due to protobuf + cd docs && PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=cpp make html spelling: cd docs && make spelling diff --git a/docs/conda_env.yml b/docs/conda_env.yml index 3f4f25e862..243b2540ae 100644 --- a/docs/conda_env.yml +++ b/docs/conda_env.yml @@ -4,7 +4,7 @@ channels: - defaults dependencies: - cpuonly=1.0=0 - - pip=21.1 + - pip=22.1.1 - python=3.7 - pytorch=1.11.0=py3.7_cpu_0 - pip: @@ -15,6 +15,6 @@ dependencies: - numpy - matplotlib - sphinx_autodoc_typehints - - sphinx>=4.2 + - sphinx~=4.2 # See https://github.com/readthedocs/sphinx_rtd_theme/issues/1115 - - sphinx_rtd_theme>=1.0 + - sphinx_rtd_theme~=1.0 diff --git a/setup.py b/setup.py index 030a943e84..23726f6719 100644 --- a/setup.py +++ b/setup.py @@ -101,19 +101,16 @@ # Reformat "black", # For toy text Gym envs - "scipy>=1.4.1", + "scipy~=1.4.1", ], "docs": [ - "sphinx", + "sphinx~=4.5.0", "sphinx-autobuild", "sphinx-rtd-theme", # For spelling "sphinxcontrib.spelling", # Type hints support "sphinx-autodoc-typehints", - # weird error with protobuf > 4.0 - # probably due to tensorboard - "protobuf~=3.19.0", ], "extra": [ # For render From 1d7da08585577de8c40b18c9b520413ca4cc08f5 Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Sun, 29 May 2022 19:51:27 +0200 Subject: [PATCH 036/153] Fix protobuf dep --- setup.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/setup.py b/setup.py index 23726f6719..c277848a7c 100644 --- a/setup.py +++ b/setup.py @@ -122,6 +122,9 @@ "pillow", # Tensorboard support "tensorboard>=2.2.0", + # Protobuf >= 4 has breaking changes + # which does play well with tensorboard + "protobuf~=3.19.0", # Checking memory taken by replay buffer "psutil", ], From cf7e43889ecd0bb27723c98ae61ff9c85020166c Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Sun, 29 May 2022 23:34:37 +0200 Subject: [PATCH 037/153] Use newest version of cartpole --- tests/test_callbacks.py | 2 +- tests/test_envs.py | 2 +- tests/test_predict.py | 2 +- tests/test_utils.py | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/test_callbacks.py b/tests/test_callbacks.py index 6576f7dc32..e6e2ec1784 100644 --- a/tests/test_callbacks.py +++ b/tests/test_callbacks.py @@ -102,7 +102,7 @@ def test_callbacks(tmp_path, model_class): def select_env(model_class) -> str: if model_class is DQN: - return "CartPole-v0" + return "CartPole-v1" else: return "Pendulum-v1" diff --git a/tests/test_envs.py b/tests/test_envs.py index df8c396f62..3397b3d950 100644 --- a/tests/test_envs.py +++ b/tests/test_envs.py @@ -28,7 +28,7 @@ ] -@pytest.mark.parametrize("env_id", ["CartPole-v0", "Pendulum-v1"]) +@pytest.mark.parametrize("env_id", ["CartPole-v1", "Pendulum-v1"]) def test_env(env_id): """ Check that environmnent integrated in Gym pass the test. diff --git a/tests/test_predict.py b/tests/test_predict.py index 853f4d11db..da84cab361 100644 --- a/tests/test_predict.py +++ b/tests/test_predict.py @@ -41,7 +41,7 @@ def test_auto_wrap(model_class): # Use different environment for DQN if model_class is DQN: - env_name = "CartPole-v0" + env_name = "CartPole-v1" else: env_name = "Pendulum-v1" env = gym.make(env_name) diff --git a/tests/test_utils.py b/tests/test_utils.py index 67f2ad1a32..a4f1c73ccc 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -229,7 +229,7 @@ def test_evaluate_policy_monitors(vec_env_class): # Also test VecEnvs n_eval_episodes = 3 n_envs = 2 - env_id = "CartPole-v0" + env_id = "CartPole-v1" model = A2C("MlpPolicy", env_id, seed=0) def make_eval_env(with_monitor, wrapper_class=gym.Wrapper): From 626db1d45e0975a56e535bddfcf32a3f101e1576 Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Fri, 10 Jun 2022 14:36:27 +0200 Subject: [PATCH 038/153] Update gym --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index c63cc4585d..baf09c0895 100644 --- a/setup.py +++ b/setup.py @@ -73,7 +73,7 @@ packages=[package for package in find_packages() if package.startswith("stable_baselines3")], package_data={"stable_baselines3": ["py.typed", "version.txt"]}, install_requires=[ - "gym==0.24", + "gym==0.24.1", "numpy", "torch>=1.11", # For saving models From f86450138665a5e8ed4489f685c67cd0f0662e11 Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Tue, 14 Jun 2022 16:55:15 +0200 Subject: [PATCH 039/153] Fix warning --- tests/test_vec_monitor.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_vec_monitor.py b/tests/test_vec_monitor.py index c9907a5492..ea93413acb 100644 --- a/tests/test_vec_monitor.py +++ b/tests/test_vec_monitor.py @@ -132,7 +132,7 @@ def test_vec_monitor_ppo(recwarn): """ Test the `VecMonitor` with PPO """ - env = DummyVecEnv([lambda: gym.make("CartPole-v1")]) + env = DummyVecEnv([lambda: gym.make("CartPole-v1", disable_env_checker=True)]) env.seed(seed=0) monitor_env = VecMonitor(env) model = PPO("MlpPolicy", monitor_env, verbose=1, n_steps=64, device="cpu") @@ -140,7 +140,7 @@ def test_vec_monitor_ppo(recwarn): # No warnings because using `VecMonitor` evaluate_policy(model, monitor_env) - assert len(recwarn) == 0 + assert len(recwarn) == 0, f"{[str(warning) for warning in recwarn]}" def test_vec_monitor_warn(): From 5fa3cd96b707f71b1a44439e3748edad9fc648c8 Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Thu, 23 Jun 2022 21:25:28 +0200 Subject: [PATCH 040/153] Loosen required scipy version --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index baf09c0895..ebf0b19d21 100644 --- a/setup.py +++ b/setup.py @@ -101,7 +101,7 @@ # Reformat "black", # For toy text Gym envs - "scipy~=1.4.1", + "scipy>=1.4.1", ], "docs": [ "sphinx~=4.5.0", From e79148be6e5a79f52981cf5b993fc456a278960f Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Thu, 23 Jun 2022 21:28:43 +0200 Subject: [PATCH 041/153] Scipy no longer needed --- setup.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/setup.py b/setup.py index ebf0b19d21..fed5305ff4 100644 --- a/setup.py +++ b/setup.py @@ -100,8 +100,6 @@ "isort>=5.0", # Reformat "black", - # For toy text Gym envs - "scipy>=1.4.1", ], "docs": [ "sphinx~=4.5.0", From 1ad5f7841a7ba3cdb8e2fbfcb8fa881af1d25647 Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Mon, 18 Jul 2022 11:52:43 +0200 Subject: [PATCH 042/153] Try gym 0.25 --- docs/conda_env.yml | 2 +- setup.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/conda_env.yml b/docs/conda_env.yml index 78a7032c53..101c83cdb9 100644 --- a/docs/conda_env.yml +++ b/docs/conda_env.yml @@ -8,7 +8,7 @@ dependencies: - python=3.7 - pytorch=1.11.0=py3.7_cpu_0 - pip: - - gym==0.24 + - gym==0.25 - cloudpickle - opencv-python-headless - pandas diff --git a/setup.py b/setup.py index cda7c8eed8..013280f8a7 100644 --- a/setup.py +++ b/setup.py @@ -73,7 +73,7 @@ packages=[package for package in find_packages() if package.startswith("stable_baselines3")], package_data={"stable_baselines3": ["py.typed", "version.txt"]}, install_requires=[ - "gym==0.24.1", + "gym==0.25", "numpy", "torch>=1.11", # For saving models From 3f0b531cb38b04b764f5f83609eee32cd2a848af Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Mon, 25 Jul 2022 15:06:27 +0200 Subject: [PATCH 043/153] Silence warnings from gym --- setup.cfg | 1 + stable_baselines3/__init__.py | 4 ++++ 2 files changed, 5 insertions(+) diff --git a/setup.cfg b/setup.cfg index 5bc66c20cb..9f20cfb907 100644 --- a/setup.cfg +++ b/setup.cfg @@ -15,6 +15,7 @@ filterwarnings = ignore::UserWarning:gym ignore:SelectableGroups dict interface is deprecated.:DeprecationWarning ignore:`np.bool` is a deprecated alias for the builtin `bool`:DeprecationWarning + ignore:.*step API:DeprecationWarning:gym markers = expensive: marks tests as expensive (deselect with '-m "not expensive"') diff --git a/stable_baselines3/__init__.py b/stable_baselines3/__init__.py index d73f5f095e..4792f6c152 100644 --- a/stable_baselines3/__init__.py +++ b/stable_baselines3/__init__.py @@ -1,4 +1,5 @@ import os +import warnings from stable_baselines3.a2c import A2C from stable_baselines3.common.utils import get_system_info @@ -14,6 +15,9 @@ with open(version_file) as file_handler: __version__ = file_handler.read().strip() +# Silence Gym warnings due to new API +warnings.filterwarnings("ignore", message=r".*step API", module="gym") + def HER(*args, **kwargs): raise ImportError( From b27e5553aa7a24e70d0d0ad2b882613acd075839 Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Mon, 25 Jul 2022 23:22:44 +0200 Subject: [PATCH 044/153] Filter warnings during tests --- setup.cfg | 8 ++++---- tests/test_vec_monitor.py | 4 ++++ 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/setup.cfg b/setup.cfg index 9f20cfb907..bd04ac9e92 100644 --- a/setup.cfg +++ b/setup.cfg @@ -10,11 +10,11 @@ filterwarnings = # Tensorboard warnings ignore::DeprecationWarning:tensorboard # Gym warnings - ignore:Parameters to load are deprecated.:DeprecationWarning - ignore:the imp module is deprecated in favour of importlib:PendingDeprecationWarning + ; ignore:Parameters to load are deprecated.:DeprecationWarning + ; ignore:the imp module is deprecated in favour of importlib:PendingDeprecationWarning ignore::UserWarning:gym - ignore:SelectableGroups dict interface is deprecated.:DeprecationWarning - ignore:`np.bool` is a deprecated alias for the builtin `bool`:DeprecationWarning + ; ignore:SelectableGroups dict interface is deprecated.:DeprecationWarning + ; ignore:`np.bool` is a deprecated alias for the builtin `bool`:DeprecationWarning ignore:.*step API:DeprecationWarning:gym markers = expensive: marks tests as expensive (deselect with '-m "not expensive"') diff --git a/tests/test_vec_monitor.py b/tests/test_vec_monitor.py index ea93413acb..bbf5e8d216 100644 --- a/tests/test_vec_monitor.py +++ b/tests/test_vec_monitor.py @@ -2,6 +2,7 @@ import json import os import uuid +import warnings import gym import pandas @@ -132,6 +133,9 @@ def test_vec_monitor_ppo(recwarn): """ Test the `VecMonitor` with PPO """ + # Remove Gym Warnings + warnings.filterwarnings(action="ignore", category=DeprecationWarning, module="gym") + env = DummyVecEnv([lambda: gym.make("CartPole-v1", disable_env_checker=True)]) env.seed(seed=0) monitor_env = VecMonitor(env) From 8c65748bc1879c42b1c67b86ee22d321c9830cdb Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Sat, 1 Oct 2022 12:05:13 +0200 Subject: [PATCH 045/153] Update doc --- .github/ISSUE_TEMPLATE/custom_env.md | 7 ++++--- docs/conda_env.yml | 2 +- docs/guide/examples.rst | 15 ++++++++------- docs/guide/quickstart.rst | 20 +++++++++++++------- docs/misc/changelog.rst | 2 +- 5 files changed, 27 insertions(+), 19 deletions(-) diff --git a/.github/ISSUE_TEMPLATE/custom_env.md b/.github/ISSUE_TEMPLATE/custom_env.md index 0a12a68bb6..f28d370b88 100644 --- a/.github/ISSUE_TEMPLATE/custom_env.md +++ b/.github/ISSUE_TEMPLATE/custom_env.md @@ -44,19 +44,20 @@ from stable_baselines3.common.env_checker import check_env class CustomEnv(gym.Env): def __init__(self): - super(CustomEnv, self).__init__() + super().__init__() self.observation_space = gym.spaces.Box(low=-np.inf, high=np.inf, shape=(14,)) self.action_space = gym.spaces.Box(low=-1, high=1, shape=(6,)) def reset(self): - return self.observation_space.sample() + return self.observation_space.sample(), {} def step(self, action): obs = self.observation_space.sample() reward = 1.0 done = False + truncated = False info = {} - return obs, reward, done, info + return obs, reward, done, truncated, info env = CustomEnv() check_env(env) diff --git a/docs/conda_env.yml b/docs/conda_env.yml index 101c83cdb9..7b89ba92bd 100644 --- a/docs/conda_env.yml +++ b/docs/conda_env.yml @@ -8,7 +8,7 @@ dependencies: - python=3.7 - pytorch=1.11.0=py3.7_cpu_0 - pip: - - gym==0.25 + - gym==0.26 - cloudpickle - opencv-python-headless - pandas diff --git a/docs/guide/examples.rst b/docs/guide/examples.rst index 247c86cd9e..426cd8fffc 100644 --- a/docs/guide/examples.rst +++ b/docs/guide/examples.rst @@ -94,11 +94,12 @@ In the following example, we will train, save and load a DQN model on the Lunar mean_reward, std_reward = evaluate_policy(model, model.get_env(), n_eval_episodes=10) # Enjoy trained agent - obs = env.reset() + vec_env = model.get_env() + obs = vec_env.reset() for i in range(1000): action, _states = model.predict(obs, deterministic=True) - obs, rewards, dones, info = env.step(action) - env.render() + obs, rewards, dones, info = vec_env.step(action) + vec_env.render() Multiprocessing: Unleashing the Power of Vectorized Environments @@ -470,19 +471,19 @@ The parking env is a goal-conditioned continuous control task, in which the vehi # HER must be loaded with the env model = SAC.load("her_sac_highway", env=env) - obs = env.reset() + obs, info = env.reset() # Evaluate the agent episode_reward = 0 for _ in range(100): action, _ = model.predict(obs, deterministic=True) - obs, reward, done, info = env.step(action) + obs, reward, done, truncated, info = env.step(action) env.render() episode_reward += reward - if done or info.get("is_success", False): + if done or truncated or info.get("is_success", False): print("Reward:", episode_reward, "Success?", info.get("is_success", False)) episode_reward = 0.0 - obs = env.reset() + obs, info = env.reset() Learning Rate Schedule diff --git a/docs/guide/quickstart.rst b/docs/guide/quickstart.rst index 064139d253..a1c5473440 100644 --- a/docs/guide/quickstart.rst +++ b/docs/guide/quickstart.rst @@ -14,18 +14,24 @@ Here is a quick example of how to train and run A2C on a CartPole environment: from stable_baselines3 import A2C - env = gym.make('CartPole-v1') + env = gym.make("CartPole-v1") - model = A2C('MlpPolicy', env, verbose=1) + model = A2C("MlpPolicy", env, verbose=1) model.learn(total_timesteps=10000) - obs = env.reset() + # Note: Gym 0.26+ reset() returns a tuple + # where SB3 VecEnv only return an observation + obs, info = env.reset() for i in range(1000): action, _state = model.predict(obs, deterministic=True) - obs, reward, done, info = env.step(action) + # Note: Gym 0.26+ step() returns an additional boolean + # "truncated" where SB3 store truncation information + # in info["TimeLimit.truncated"] + obs, reward, done, truncated, info = env.step(action) env.render() - if done: - obs = env.reset() + # Note: reset is automated in SB3 VecEnv + if done or truncated: + obs, info = env.reset() .. note:: @@ -40,4 +46,4 @@ the policy is registered: from stable_baselines3 import A2C - model = A2C('MlpPolicy', 'CartPole-v1').learn(10000) + model = A2C("MlpPolicy", "CartPole-v1").learn(10000) diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index e509a3ef7c..8ad6c29529 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -1047,6 +1047,6 @@ And all the contributors: @eleurent @ac-93 @cove9988 @theDebugger811 @hsuehch @Demetrio92 @thomasgubler @IperGiove @ScheiklP @simoninithomas @armandpl @manuel-delverme @Gautam-J @gianlucadecola @buoyancy99 @caburu @xy9485 @Gregwar @ycheng517 @quantitative-technologies @bcollazo @git-thor @TibiGG @cool-RR @MWeltevrede -@carlosluis @arjun-kg +@carlosluis @arjun-kg @tlpss @Melanol @qgallouedec @francescoluciano @jlp-ue @burakdmb @timothe-chaumont @honglu2875 @anand-bala @hughperkins @sidney-tio @AlexPasqua @dominicgkerr @Akhilez @Rocamonde From 0874da186b248a640298983483480101124bd3cb Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Sat, 1 Oct 2022 12:05:49 +0200 Subject: [PATCH 046/153] Update requirements --- README.md | 6 +++--- setup.py | 12 ++++++------ 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/README.md b/README.md index 973a180da7..07872a661d 100644 --- a/README.md +++ b/README.md @@ -124,12 +124,12 @@ env = gym.make("CartPole-v1") model = PPO("MlpPolicy", env, verbose=1) model.learn(total_timesteps=10_000) -obs = env.reset() +obs, info = env.reset() for i in range(1000): action, _states = model.predict(obs, deterministic=True) - obs, reward, done, info = env.step(action) + obs, reward, done, truncated, info = env.step(action) env.render() - if done: + if done or truncated: obs = env.reset() env.close() diff --git a/setup.py b/setup.py index a4d1fd2896..a4365a5705 100644 --- a/setup.py +++ b/setup.py @@ -48,13 +48,13 @@ model = PPO("MlpPolicy", env, verbose=1) model.learn(total_timesteps=10_000) -obs = env.reset() +obs, info = env.reset() for i in range(1000): action, _states = model.predict(obs, deterministic=True) - obs, reward, done, info = env.step(action) + obs, reward, done, truncated, info = env.step(action) env.render() - if done: - obs = env.reset() + if done or truncated: + obs, info = env.reset() ``` Or just train a model with a one liner if [the environment is registered in Gym](https://www.gymlibrary.ml/content/environment_creation/) and if [the policy is registered](https://stable-baselines3.readthedocs.io/en/master/guide/custom_policy.html): @@ -73,7 +73,7 @@ packages=[package for package in find_packages() if package.startswith("stable_baselines3")], package_data={"stable_baselines3": ["py.typed", "version.txt"]}, install_requires=[ - "gym==0.25", + "gym==0.26", "numpy", "torch>=1.11", # For saving models @@ -117,7 +117,7 @@ "opencv-python", "pygame", # For atari games, - "ale-py~=0.7.5", + "ale-py~=0.8.0", "autorom[accept-rom-license]~=0.4.2", "pillow", # Tensorboard support From 26ceefc983e75b45ca30963f550d3633e3790f2d Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Sat, 1 Oct 2022 12:06:22 +0200 Subject: [PATCH 047/153] Add gym 26 compat in vec env --- stable_baselines3/common/type_aliases.py | 2 + stable_baselines3/common/utils.py | 75 ++++++++++++++++++- .../common/vec_env/base_vec_env.py | 1 + .../common/vec_env/dummy_vec_env.py | 18 ++++- .../common/vec_env/subproc_vec_env.py | 21 ++++-- 5 files changed, 104 insertions(+), 13 deletions(-) diff --git a/stable_baselines3/common/type_aliases.py b/stable_baselines3/common/type_aliases.py index f4c29ab279..509169dc71 100644 --- a/stable_baselines3/common/type_aliases.py +++ b/stable_baselines3/common/type_aliases.py @@ -11,7 +11,9 @@ GymEnv = Union[gym.Env, vec_env.VecEnv] GymObs = Union[Tuple, Dict[str, Any], np.ndarray, int] +Gym26ResetReturn = Tuple[GymObs, Dict] GymStepReturn = Tuple[GymObs, float, bool, Dict] +Gym26StepReturn = Tuple[GymObs, float, bool, bool, Dict] TensorDict = Dict[Union[str, int], th.Tensor] OptimizerStateDict = Dict[str, Any] MaybeCallback = Union[None, Callable, List[callbacks.BaseCallback], callbacks.BaseCallback] diff --git a/stable_baselines3/common/utils.py b/stable_baselines3/common/utils.py index 8c20f987ce..d91dc94c2f 100644 --- a/stable_baselines3/common/utils.py +++ b/stable_baselines3/common/utils.py @@ -2,6 +2,7 @@ import os import platform import random +import warnings from collections import deque from inspect import signature from itertools import zip_longest @@ -20,7 +21,16 @@ SummaryWriter = None from stable_baselines3.common.logger import Logger, configure -from stable_baselines3.common.type_aliases import GymEnv, Schedule, TensorDict, TrainFreq, TrainFrequencyUnit +from stable_baselines3.common.type_aliases import ( + Gym26StepReturn, + GymEnv, + GymObs, + GymStepReturn, + Schedule, + TensorDict, + TrainFreq, + TrainFrequencyUnit, +) def set_random_seed(seed: int, using_cuda: bool = False) -> None: @@ -535,3 +545,66 @@ def compat_gym_seed(env: GymEnv, seed: int) -> None: else: # VecEnv and backward compatibility env.seed(seed) + + +def compat_gym_26_step( + step_returns: Union[GymStepReturn, Gym26StepReturn], + warn=True, +) -> Gym26StepReturn: + """ + Transform step returns to old step API irrespective of input API. + This makes env written with both gym < 0.26 and gym > 0.26 + compatible with SB3. + This is a simplified version of the helper found in Gym. + + :param step_returns: Items returned by `env.step()`. + Can be (obs, reward, done, info) or (obs, reward, terminated, truncated, info) + :param warn: Whether to warn or not the user + :return: (obs, reward, done, info) with info["TimeLimit.truncated"] = truncated + """ + if len(step_returns) == 4: + if warn: + warnings.warn( + "You are using gym API < 0.26, please upgrade to gym > 0.26 " + "step API with 5 values returned (obs, reward, terminated, truncated, info) " + "instead of the current 4 values (obs, reward, done, info). " + "Please read Gym documentation for more information.", + DeprecationWarning, + ) + observations, rewards, done, infos = step_returns + truncated = infos.get("TimeLimit.truncated", False) # pytype: disable=attribute-error] + return observations, rewards, done, truncated, infos # pytype: disable=bad-return-type + else: + assert len(step_returns) == 5, f"The step function returned {len(step_returns)} values instead of 5" + return step_returns # pytype: disable=bad-return-type + + +def compat_gym_26_reset( + reset_return: Union[GymObs, Tuple[GymObs, Dict]], + warn: bool = True, +) -> Tuple[GymObs, Dict]: + """ + Transform gym reset return to new API (gym > 0.26) + + :param reset_return: (obs) or (obs, info) + :param warn: Whether to warn or not the user + :return: (obs, info) + """ + if isinstance(reset_return, tuple): + assert len(reset_return), ( + "The tuple returned by the reset() function " + f"has a length of {len(reset_return)} != 2. " + "It should only return (obs, info)" + ) + return reset_return + else: + if warn: + warnings.warn( + "You are using gym API < 0.26, please upgrade to gym > 0.26 " + "reset API with 2 values returned (obs, info) " + "instead of the current single value (obs). " + "Please read Gym documentation for more information.", + DeprecationWarning, + ) + + return reset_return, {} diff --git a/stable_baselines3/common/vec_env/base_vec_env.py b/stable_baselines3/common/vec_env/base_vec_env.py index 98706050c7..8c09d3c8c4 100644 --- a/stable_baselines3/common/vec_env/base_vec_env.py +++ b/stable_baselines3/common/vec_env/base_vec_env.py @@ -59,6 +59,7 @@ def __init__(self, num_envs: int, observation_space: gym.spaces.Space, action_sp self.num_envs = num_envs self.observation_space = observation_space self.action_space = action_space + self.reset_infos = [{} for _ in range(num_envs)] # store info returns by the reset method @abstractmethod def reset(self) -> VecEnvObs: diff --git a/stable_baselines3/common/vec_env/dummy_vec_env.py b/stable_baselines3/common/vec_env/dummy_vec_env.py index 4162b27ac5..b4f0f28554 100644 --- a/stable_baselines3/common/vec_env/dummy_vec_env.py +++ b/stable_baselines3/common/vec_env/dummy_vec_env.py @@ -39,14 +39,21 @@ def step_async(self, actions: np.ndarray) -> None: self.actions = actions def step_wait(self) -> VecEnvStepReturn: + # Avoid circular imports + from stable_baselines3.common.utils import compat_gym_26_reset, compat_gym_26_step + for env_idx in range(self.num_envs): - obs, self.buf_rews[env_idx], self.buf_dones[env_idx], self.buf_infos[env_idx] = self.envs[env_idx].step( - self.actions[env_idx] + obs, self.buf_rews[env_idx], done, truncated, self.buf_infos[env_idx] = compat_gym_26_step( + self.envs[env_idx].step(self.actions[env_idx]) ) + # convert to SB3 VecEnv api + self.buf_dones[env_idx] = done or truncated + self.buf_infos[env_idx]["TimeLimit.truncated"] = truncated + if self.buf_dones[env_idx]: # save final observation where user can get it, then reset self.buf_infos[env_idx]["terminal_observation"] = obs - obs = self.envs[env_idx].reset() + obs, self.reset_infos[env_idx] = compat_gym_26_reset(self.envs[env_idx].reset()) self._save_obs(env_idx, obs) return (self._obs_from_buf(), np.copy(self.buf_rews), np.copy(self.buf_dones), deepcopy(self.buf_infos)) @@ -62,8 +69,11 @@ def seed(self, seed: Optional[int] = None) -> List[Union[None, int]]: return seeds def reset(self) -> VecEnvObs: + # Avoid circular imports + from stable_baselines3.common.utils import compat_gym_26_reset + for env_idx in range(self.num_envs): - obs = self.envs[env_idx].reset() + obs, self.reset_infos[env_idx] = compat_gym_26_reset(self.envs[env_idx].reset()) self._save_obs(env_idx, obs) return self._obs_from_buf() diff --git a/stable_baselines3/common/vec_env/subproc_vec_env.py b/stable_baselines3/common/vec_env/subproc_vec_env.py index 06ad7c7563..f5f85cf331 100644 --- a/stable_baselines3/common/vec_env/subproc_vec_env.py +++ b/stable_baselines3/common/vec_env/subproc_vec_env.py @@ -21,25 +21,29 @@ def _worker( # noqa: C901 ) -> None: # Import here to avoid a circular import from stable_baselines3.common.env_util import is_wrapped - from stable_baselines3.common.utils import compat_gym_seed + from stable_baselines3.common.utils import compat_gym_26_reset, compat_gym_26_step, compat_gym_seed parent_remote.close() env = env_fn_wrapper.var() + reset_info = {} while True: try: cmd, data = remote.recv() if cmd == "step": - observation, reward, done, info = env.step(data) + observation, reward, done, truncated, info = compat_gym_26_step(env.step(data)) + # convert to SB3 VecEnv api + done = done or truncated + info["TimeLimit.truncated"] = truncated if done: # save final observation where user can get it, then reset info["terminal_observation"] = observation - observation = env.reset() - remote.send((observation, reward, done, info)) + observation, reset_info = compat_gym_26_reset(env.reset()) + remote.send((observation, reward, done, info, reset_info)) elif cmd == "seed": remote.send(compat_gym_seed(env, seed=data)) elif cmd == "reset": - observation = env.reset() - remote.send(observation) + observation, reset_info = compat_gym_26_reset(env.reset()) + remote.send((observation, reset_info)) elif cmd == "render": remote.send(env.render(data)) elif cmd == "close": @@ -122,7 +126,7 @@ def step_async(self, actions: np.ndarray) -> None: def step_wait(self) -> VecEnvStepReturn: results = [remote.recv() for remote in self.remotes] self.waiting = False - obs, rews, dones, infos = zip(*results) + obs, rews, dones, infos, self.reset_infos = zip(*results) return _flatten_obs(obs, self.observation_space), np.stack(rews), np.stack(dones), infos def seed(self, seed: Optional[int] = None) -> List[Union[None, int]]: @@ -135,7 +139,8 @@ def seed(self, seed: Optional[int] = None) -> List[Union[None, int]]: def reset(self) -> VecEnvObs: for remote in self.remotes: remote.send(("reset", None)) - obs = [remote.recv() for remote in self.remotes] + results = [remote.recv() for remote in self.remotes] + obs, self.reset_infos = zip(*results) return _flatten_obs(obs, self.observation_space) def close(self) -> None: From a8c579aa9a6f183bc8dda83cb38fb4db5d739d1e Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Sat, 1 Oct 2022 12:22:18 +0200 Subject: [PATCH 048/153] Fixes in envs and tests for gym 0.26+ --- stable_baselines3/common/atari_wrappers.py | 54 ++++++++++--------- stable_baselines3/common/env_checker.py | 13 +++-- .../common/envs/bit_flipping_env.py | 17 +++--- stable_baselines3/common/envs/identity_env.py | 30 +++++------ .../common/envs/multi_input_envs.py | 15 +++--- stable_baselines3/common/monitor.py | 12 ++--- .../common/off_policy_algorithm.py | 1 - stable_baselines3/common/utils.py | 6 +-- stable_baselines3/td3/td3.py | 1 - tests/test_buffers.py | 12 ++--- tests/test_cnn.py | 4 +- tests/test_dict_env.py | 8 +-- tests/test_env_checker.py | 5 +- tests/test_envs.py | 42 +++++++++------ tests/test_gae.py | 14 +++-- tests/test_her.py | 2 +- tests/test_identity.py | 11 ++-- tests/test_logger.py | 4 +- tests/test_monitor.py | 12 ++--- tests/test_predict.py | 10 ++-- tests/test_run.py | 1 + tests/test_spaces.py | 12 ++--- tests/test_utils.py | 11 ++-- tests/test_vec_check_nan.py | 4 +- tests/test_vec_envs.py | 14 ++--- tests/test_vec_normalize.py | 13 ++--- 26 files changed, 176 insertions(+), 152 deletions(-) diff --git a/stable_baselines3/common/atari_wrappers.py b/stable_baselines3/common/atari_wrappers.py index b1a200ba0f..ba08da7b5b 100644 --- a/stable_baselines3/common/atari_wrappers.py +++ b/stable_baselines3/common/atari_wrappers.py @@ -1,3 +1,5 @@ +from typing import Dict, Tuple + import gym import numpy as np from gym import spaces @@ -9,7 +11,7 @@ except ImportError: cv2 = None -from stable_baselines3.common.type_aliases import GymObs, GymStepReturn +from stable_baselines3.common.type_aliases import Gym26ResetReturn, Gym26StepReturn class NoopResetEnv(gym.Wrapper): @@ -28,7 +30,7 @@ def __init__(self, env: gym.Env, noop_max: int = 30): self.noop_action = 0 assert env.unwrapped.get_action_meanings()[0] == "NOOP" - def reset(self, **kwargs) -> np.ndarray: + def reset(self, **kwargs) -> Tuple[np.ndarray, Dict]: self.env.reset(**kwargs) if self.override_num_noops is not None: noops = self.override_num_noops @@ -36,11 +38,12 @@ def reset(self, **kwargs) -> np.ndarray: noops = self.unwrapped.np_random.integers(1, self.noop_max + 1) assert noops > 0 obs = np.zeros(0) + info = {} for _ in range(noops): - obs, _, done, _ = self.env.step(self.noop_action) - if done: - obs = self.env.reset(**kwargs) - return obs + obs, _, done, truncated, info = self.env.step(self.noop_action) + if done or truncated: + obs, info = self.env.reset(**kwargs) + return obs, info class FireResetEnv(gym.Wrapper): @@ -55,15 +58,15 @@ def __init__(self, env: gym.Env): assert env.unwrapped.get_action_meanings()[1] == "FIRE" assert len(env.unwrapped.get_action_meanings()) >= 3 - def reset(self, **kwargs) -> np.ndarray: + def reset(self, **kwargs) -> Tuple[np.ndarray, Dict]: self.env.reset(**kwargs) - obs, _, done, _ = self.env.step(1) - if done: + obs, _, done, truncated, _ = self.env.step(1) + if done or truncated: self.env.reset(**kwargs) - obs, _, done, _ = self.env.step(2) - if done: + obs, _, done, truncated, _ = self.env.step(2) + if done or truncated: self.env.reset(**kwargs) - return obs + return obs, {} class EpisodicLifeEnv(gym.Wrapper): @@ -79,21 +82,21 @@ def __init__(self, env: gym.Env): self.lives = 0 self.was_real_done = True - def step(self, action: int) -> GymStepReturn: - obs, reward, done, info = self.env.step(action) + def step(self, action: int) -> Gym26StepReturn: + obs, reward, done, truncated, info = self.env.step(action) self.was_real_done = done # check current lives, make loss of life terminal, # then update lives to handle bonus lives lives = self.env.unwrapped.ale.lives() if 0 < lives < self.lives: - # for Qbert sometimes we stay in lives == 0 condtion for a few frames + # for Qbert sometimes we stay in lives == 0 condition for a few frames # so its important to keep lives > 0, so that we only reset once # the environment advertises done. done = True self.lives = lives - return obs, reward, done, info + return obs, reward, done, truncated, info - def reset(self, **kwargs) -> np.ndarray: + def reset(self, **kwargs) -> Tuple[np.ndarray, Dict]: """ Calls the Gym environment reset, only when lives are exhausted. This way all states are still reachable even though lives are episodic, @@ -103,12 +106,12 @@ def reset(self, **kwargs) -> np.ndarray: :return: the first observation of the environment """ if self.was_real_done: - obs = self.env.reset(**kwargs) + obs, info = self.env.reset(**kwargs) else: # no-op step to advance from terminal/lost life state - obs, _, _, _ = self.env.step(0) + obs, _, _, info = self.env.step(0) self.lives = self.env.unwrapped.ale.lives() - return obs + return obs, info class MaxAndSkipEnv(gym.Wrapper): @@ -125,7 +128,7 @@ def __init__(self, env: gym.Env, skip: int = 4): self._obs_buffer = np.zeros((2,) + env.observation_space.shape, dtype=env.observation_space.dtype) self._skip = skip - def step(self, action: int) -> GymStepReturn: + def step(self, action: int) -> Gym26StepReturn: """ Step the environment with the given action Repeat action, sum reward, and max over last observations. @@ -134,9 +137,10 @@ def step(self, action: int) -> GymStepReturn: :return: observation, reward, done, information """ total_reward = 0.0 - done = None + terminated = truncated = False for i in range(self._skip): - obs, reward, done, info = self.env.step(action) + obs, reward, terminated, truncated, info = self.env.step(action) + done = terminated or truncated if i == self._skip - 2: self._obs_buffer[0] = obs if i == self._skip - 1: @@ -148,9 +152,9 @@ def step(self, action: int) -> GymStepReturn: # doesn't matter max_frame = self._obs_buffer.max(axis=0) - return max_frame, total_reward, done, info + return max_frame, total_reward, terminated, truncated, info - def reset(self, **kwargs) -> GymObs: + def reset(self, **kwargs) -> Gym26ResetReturn: return self.env.reset(**kwargs) diff --git a/stable_baselines3/common/env_checker.py b/stable_baselines3/common/env_checker.py index 4033bff290..c383c8edab 100644 --- a/stable_baselines3/common/env_checker.py +++ b/stable_baselines3/common/env_checker.py @@ -197,7 +197,11 @@ def _check_returned_values(env: gym.Env, observation_space: spaces.Space, action Check the returned values by the env when calling `.reset()` or `.step()` methods. """ # because env inherits from gym.Env, we assume that `reset()` and `step()` methods exists - obs = env.reset() + reset_returns = env.reset() + assert isinstance(reset_returns, tuple), "`reset()` must return a tuple (obs, info)" + assert len(reset_returns) == 2, f"`reset()` must return a tuple of size 2 (obs, info), not {len(reset_returns)}" + obs, info = reset_returns + assert isinstance(info, dict), "The second element of the tuple return by `reset()` must be a dictionary" if _is_goal_env(env): _check_goal_env_obs(obs, observation_space, "reset") @@ -222,10 +226,10 @@ def _check_returned_values(env: gym.Env, observation_space: spaces.Space, action action = action_space.sample() data = env.step(action) - assert len(data) == 4, "The `step()` method must return four values: obs, reward, done, info" + assert len(data) == 5, "The `step()` method must return four values: obs, reward, terminated, truncated, info" # Unpack - obs, reward, done, info = data + obs, reward, terminated, truncated, info = data if _is_goal_env(env): _check_goal_env_obs(obs, observation_space, "step") @@ -250,7 +254,8 @@ def _check_returned_values(env: gym.Env, observation_space: spaces.Space, action # We also allow int because the reward will be cast to float assert isinstance(reward, (float, int)), "The reward returned by `step()` must be a float" - assert isinstance(done, bool), "The `done` signal must be a boolean" + assert isinstance(terminated, bool), "The `terminated` signal must be a boolean" + assert isinstance(truncated, bool), "The `truncated` signal must be a boolean" assert isinstance(info, dict), "The `info` returned by `step()` must be a python dictionary" # Goal conditioned env diff --git a/stable_baselines3/common/envs/bit_flipping_env.py b/stable_baselines3/common/envs/bit_flipping_env.py index 5089ee7d30..0fc93a6cf6 100644 --- a/stable_baselines3/common/envs/bit_flipping_env.py +++ b/stable_baselines3/common/envs/bit_flipping_env.py @@ -1,11 +1,11 @@ from collections import OrderedDict -from typing import Any, Dict, Optional, Union +from typing import Any, Dict, Optional, Tuple, Union import numpy as np from gym import Env, spaces from gym.envs.registration import EnvSpec -from stable_baselines3.common.type_aliases import GymStepReturn +from stable_baselines3.common.type_aliases import Gym26StepReturn class BitFlippingEnv(Env): @@ -25,7 +25,7 @@ class BitFlippingEnv(Env): :param channel_first: Whether to use channel-first or last image. """ - spec = EnvSpec("BitFlippingEnv-v0") + spec = EnvSpec("BitFlippingEnv-v0", "no-entry-point") def __init__( self, @@ -157,14 +157,14 @@ def _get_obs(self) -> Dict[str, Union[int, np.ndarray]]: ] ) - def reset(self, seed: Optional[int] = None) -> Dict[str, Union[int, np.ndarray]]: + def reset(self, seed: Optional[int] = None) -> Tuple[Dict[str, Union[int, np.ndarray]], Dict]: if seed is not None: self.obs_space.seed(seed) self.current_step = 0 self.state = self.obs_space.sample() - return self._get_obs() + return self._get_obs(), {} - def step(self, action: Union[np.ndarray, int]) -> GymStepReturn: + def step(self, action: Union[np.ndarray, int]) -> Gym26StepReturn: """ Step into the env. @@ -181,8 +181,9 @@ def step(self, action: Union[np.ndarray, int]) -> GymStepReturn: self.current_step += 1 # Episode terminate when we reached the goal or the max number of steps info = {"is_success": done} - done = done or self.current_step >= self.max_steps - return obs, reward, done, info + truncated = self.current_step >= self.max_steps + done = done or truncated + return obs, reward, done, truncated, info def compute_reward( self, achieved_goal: Union[int, np.ndarray], desired_goal: Union[int, np.ndarray], _info: Optional[Dict[str, Any]] diff --git a/stable_baselines3/common/envs/identity_env.py b/stable_baselines3/common/envs/identity_env.py index e06ab63ea8..aa71696111 100644 --- a/stable_baselines3/common/envs/identity_env.py +++ b/stable_baselines3/common/envs/identity_env.py @@ -1,10 +1,10 @@ -from typing import Optional, Union +from typing import Dict, Optional, Tuple, Union import numpy as np from gym import Env, Space from gym.spaces import Box, Discrete, MultiBinary, MultiDiscrete -from stable_baselines3.common.type_aliases import GymObs, GymStepReturn +from stable_baselines3.common.type_aliases import Gym26ResetReturn, Gym26StepReturn class IdentityEnv(Env): @@ -32,20 +32,20 @@ def __init__(self, dim: Optional[int] = None, space: Optional[Space] = None, ep_ self.num_resets = -1 # Becomes 0 after __init__ exits. self.reset() - def reset(self, seed: Optional[int] = None) -> GymObs: + def reset(self, seed: Optional[int] = None) -> Gym26ResetReturn: if seed is not None: super().reset(seed=seed) self.current_step = 0 self.num_resets += 1 self._choose_next_state() - return self.state + return self.state, {} - def step(self, action: Union[int, np.ndarray]) -> GymStepReturn: + def step(self, action: Union[int, np.ndarray]) -> Gym26StepReturn: reward = self._get_reward(action) self._choose_next_state() self.current_step += 1 - done = self.current_step >= self.ep_length - return self.state, reward, done, {} + done = truncated = self.current_step >= self.ep_length + return self.state, reward, done, truncated, {} def _choose_next_state(self) -> None: self.state = self.action_space.sample() @@ -71,12 +71,12 @@ def __init__(self, low: float = -1.0, high: float = 1.0, eps: float = 0.05, ep_l super().__init__(ep_length=ep_length, space=space) self.eps = eps - def step(self, action: np.ndarray) -> GymStepReturn: + def step(self, action: np.ndarray) -> Gym26StepReturn: reward = self._get_reward(action) self._choose_next_state() self.current_step += 1 - done = self.current_step >= self.ep_length - return self.state, reward, done, {} + done = truncated = self.current_step >= self.ep_length + return self.state, reward, done, truncated, {} def _get_reward(self, action: np.ndarray) -> float: return 1.0 if (self.state - self.eps) <= action <= (self.state + self.eps) else 0.0 @@ -138,17 +138,17 @@ def __init__( self.ep_length = 10 self.current_step = 0 - def reset(self, seed: Optional[int] = None) -> np.ndarray: + def reset(self, seed: Optional[int] = None) -> Tuple[np.ndarray, Dict]: if seed is not None: super().reset(seed=seed) self.current_step = 0 - return self.observation_space.sample() + return self.observation_space.sample(), {} - def step(self, action: Union[np.ndarray, int]) -> GymStepReturn: + def step(self, action: Union[np.ndarray, int]) -> Gym26StepReturn: reward = 0.0 self.current_step += 1 - done = self.current_step >= self.ep_length - return self.observation_space.sample(), reward, done, {} + done = truncated = self.current_step >= self.ep_length + return self.observation_space.sample(), reward, done, truncated, {} def render(self, mode: str = "human") -> None: pass diff --git a/stable_baselines3/common/envs/multi_input_envs.py b/stable_baselines3/common/envs/multi_input_envs.py index 3cb0cc7fa3..c7b5973daa 100644 --- a/stable_baselines3/common/envs/multi_input_envs.py +++ b/stable_baselines3/common/envs/multi_input_envs.py @@ -1,9 +1,9 @@ -from typing import Dict, Optional, Union +from typing import Dict, Optional, Tuple, Union import gym import numpy as np -from stable_baselines3.common.type_aliases import GymStepReturn +from stable_baselines3.common.type_aliases import Gym26StepReturn class SimpleMultiObsEnv(gym.Env): @@ -120,7 +120,7 @@ def init_possible_transitions(self) -> None: self.right_possible = [0, 1, 2, 12, 13, 14] self.up_possible = [4, 8, 12, 7, 11, 15] - def step(self, action: Union[int, float, np.ndarray]) -> GymStepReturn: + def step(self, action: Union[int, float, np.ndarray]) -> Gym26StepReturn: """ Run one timestep of the environment's dynamics. When end of episode is reached, you are responsible for calling `reset()` @@ -152,11 +152,12 @@ def step(self, action: Union[int, float, np.ndarray]) -> GymStepReturn: got_to_end = self.state == self.max_state reward = 1 if got_to_end else reward - done = self.count > self.max_count or got_to_end + truncated = self.count > self.max_count + done = got_to_end or truncated self.log = f"Went {self.action2str[action]} in state {prev_state}, got to state {self.state}" - return self.get_state_mapping(), reward, done, {"got_to_end": got_to_end} + return self.get_state_mapping(), reward, done, truncated, {"got_to_end": got_to_end} def render(self, mode: str = "human") -> None: """ @@ -166,7 +167,7 @@ def render(self, mode: str = "human") -> None: """ print(self.log) - def reset(self, seed: Optional[int] = None) -> Dict[str, np.ndarray]: + def reset(self, seed: Optional[int] = None) -> Tuple[Dict[str, np.ndarray], Dict]: """ Resets the environment state and step count and returns reset observation. @@ -180,4 +181,4 @@ def reset(self, seed: Optional[int] = None) -> Dict[str, np.ndarray]: self.state = 0 else: self.state = np.random.randint(0, self.max_state) - return self.state_mapping[self.state] + return self.state_mapping[self.state], {} diff --git a/stable_baselines3/common/monitor.py b/stable_baselines3/common/monitor.py index 1e56fdb78b..499753553f 100644 --- a/stable_baselines3/common/monitor.py +++ b/stable_baselines3/common/monitor.py @@ -11,7 +11,7 @@ import numpy as np import pandas -from stable_baselines3.common.type_aliases import GymObs, GymStepReturn +from stable_baselines3.common.type_aliases import Gym26ResetReturn, Gym26StepReturn class Monitor(gym.Wrapper): @@ -61,7 +61,7 @@ def __init__( self.total_steps = 0 self.current_reset_info = {} # extra info about the current episode, that was passed in during reset() - def reset(self, **kwargs) -> GymObs: + def reset(self, **kwargs) -> Gym26ResetReturn: """ Calls the Gym environment reset. Can only be called if the environment is over, or if allow_early_resets is True @@ -82,7 +82,7 @@ def reset(self, **kwargs) -> GymObs: self.current_reset_info[key] = value return self.env.reset(**kwargs) - def step(self, action: Union[np.ndarray, int]) -> GymStepReturn: + def step(self, action: Union[np.ndarray, int]) -> Gym26StepReturn: """ Step the environment with the given action @@ -91,9 +91,9 @@ def step(self, action: Union[np.ndarray, int]) -> GymStepReturn: """ if self.needs_reset: raise RuntimeError("Tried to step environment that needs reset") - observation, reward, done, info = self.env.step(action) + observation, reward, done, truncated, info = self.env.step(action) self.rewards.append(reward) - if done: + if done or truncated: self.needs_reset = True ep_rew = sum(self.rewards) ep_len = len(self.rewards) @@ -108,7 +108,7 @@ def step(self, action: Union[np.ndarray, int]) -> GymStepReturn: self.results_writer.write_row(ep_info) info["episode"] = ep_info self.total_steps += 1 - return observation, reward, done, info + return observation, reward, done, truncated, info def close(self) -> None: """ diff --git a/stable_baselines3/common/off_policy_algorithm.py b/stable_baselines3/common/off_policy_algorithm.py index c23223dc12..f53e07e9cd 100644 --- a/stable_baselines3/common/off_policy_algorithm.py +++ b/stable_baselines3/common/off_policy_algorithm.py @@ -618,7 +618,6 @@ def collect_rollouts( # Log training infos if log_interval is not None and self._episode_num % log_interval == 0: self._dump_logs() - callback.on_rollout_end() return RolloutReturn(num_collected_steps * env.num_envs, num_collected_episodes, continue_training) diff --git a/stable_baselines3/common/utils.py b/stable_baselines3/common/utils.py index d91dc94c2f..ff521a3dc1 100644 --- a/stable_baselines3/common/utils.py +++ b/stable_baselines3/common/utils.py @@ -572,11 +572,11 @@ def compat_gym_26_step( DeprecationWarning, ) observations, rewards, done, infos = step_returns - truncated = infos.get("TimeLimit.truncated", False) # pytype: disable=attribute-error] - return observations, rewards, done, truncated, infos # pytype: disable=bad-return-type + truncated = infos.get("TimeLimit.truncated", False) # pytype: disable=attribute-error] + return observations, rewards, done, truncated, infos # pytype: disable=bad-return-type else: assert len(step_returns) == 5, f"The step function returned {len(step_returns)} values instead of 5" - return step_returns # pytype: disable=bad-return-type + return step_returns # pytype: disable=bad-return-type def compat_gym_26_reset( diff --git a/stable_baselines3/td3/td3.py b/stable_baselines3/td3/td3.py index 62e33f56e8..ae7895dc83 100644 --- a/stable_baselines3/td3/td3.py +++ b/stable_baselines3/td3/td3.py @@ -154,7 +154,6 @@ def train(self, gradient_steps: int, batch_size: int = 100) -> None: self._update_learning_rate([self.actor.optimizer, self.critic.optimizer]) actor_losses, critic_losses = [], [] - for _ in range(gradient_steps): self._n_updates += 1 diff --git a/tests/test_buffers.py b/tests/test_buffers.py index 0e028e670d..4bd2d27939 100644 --- a/tests/test_buffers.py +++ b/tests/test_buffers.py @@ -27,15 +27,15 @@ def __init__(self): def reset(self): self._t = 0 obs = self._observations[0] - return obs + return obs, {} def step(self, action): self._t += 1 index = self._t % len(self._observations) obs = self._observations[index] - done = self._t >= self._ep_length + done = truncated = self._t >= self._ep_length reward = self._rewards[index] - return obs, reward, done, {} + return obs, reward, done, truncated, {} class DummyDictEnv(gym.Env): @@ -55,15 +55,15 @@ def __init__(self): def reset(self): self._t = 0 obs = {key: self._observations[0] for key in self.observation_space.spaces.keys()} - return obs + return obs, {} def step(self, action): self._t += 1 index = self._t % len(self._observations) obs = {key: self._observations[index] for key in self.observation_space.spaces.keys()} - done = self._t >= self._ep_length + done = truncated = self._t >= self._ep_length reward = self._rewards[index] - return obs, reward, done, {} + return obs, reward, done, truncated, {} @pytest.mark.parametrize("replay_buffer_cls", [ReplayBuffer, DictReplayBuffer]) diff --git a/tests/test_cnn.py b/tests/test_cnn.py index 03f089db99..48dce0821b 100644 --- a/tests/test_cnn.py +++ b/tests/test_cnn.py @@ -35,7 +35,7 @@ def test_cnn(tmp_path, model_class): # FakeImageEnv is channel last by default and should be wrapped assert is_vecenv_wrapped(model.get_env(), VecTransposeImage) - obs = env.reset() + obs, _ = env.reset() # Test stochastic predict with channel last input if model_class == DQN: @@ -238,7 +238,7 @@ def test_channel_first_env(tmp_path): assert not is_vecenv_wrapped(model.get_env(), VecTransposeImage) - obs = env.reset() + obs, _ = env.reset() action, _ = model.predict(obs, deterministic=True) diff --git a/tests/test_dict_env.py b/tests/test_dict_env.py index d7c8003d63..1f832bfeb8 100644 --- a/tests/test_dict_env.py +++ b/tests/test_dict_env.py @@ -68,8 +68,8 @@ def seed(self, seed=None): def step(self, action): reward = 0.0 - done = False - return self.observation_space.sample(), reward, done, {} + done = truncated = False + return self.observation_space.sample(), reward, done, truncated, {} def compute_reward(self, achieved_goal, desired_goal, info): return np.zeros((len(achieved_goal),)) @@ -77,7 +77,7 @@ def compute_reward(self, achieved_goal, desired_goal, info): def reset(self, seed: Optional[int] = None): if seed is not None: self.observation_space.seed(seed) - return self.observation_space.sample() + return self.observation_space.sample(), {} def render(self, mode="human"): pass @@ -109,7 +109,7 @@ def test_consistency(model_class): dict_env = gym.wrappers.TimeLimit(dict_env, 100) env = gym.wrappers.FlattenObservation(dict_env) dict_env.seed(10) - obs = dict_env.reset() + obs, _ = dict_env.reset() kwargs = {} n_steps = 256 diff --git a/tests/test_env_checker.py b/tests/test_env_checker.py index 0b0a82d8fe..2313defc9f 100644 --- a/tests/test_env_checker.py +++ b/tests/test_env_checker.py @@ -14,11 +14,12 @@ def step(self, action): observation = np.array([1.0, 1.5, 0.5], dtype=self.observation_space.dtype) reward = 1 done = True + truncated = False info = {} - return observation, reward, done, info + return observation, reward, done, truncated, info def reset(self): - return np.array([1.0, 1.5, 0.5], dtype=self.observation_space.dtype) + return np.array([1.0, 1.5, 0.5], dtype=self.observation_space.dtype), {} def render(self, mode="human"): pass diff --git a/tests/test_envs.py b/tests/test_envs.py index a8d603322a..5c00b9473f 100644 --- a/tests/test_envs.py +++ b/tests/test_envs.py @@ -98,7 +98,7 @@ def test_high_dimension_action_space(): # Patch to avoid error def patched_step(_action): - return env.observation_space.sample(), 0.0, False, {} + return env.observation_space.sample(), 0.0, False, False, {} env.step = patched_step check_env(env) @@ -127,10 +127,10 @@ def test_non_default_spaces(new_obs_space): env = FakeImageEnv() env.observation_space = new_obs_space # Patch methods to avoid errors - env.reset = new_obs_space.sample + env.reset = lambda: (new_obs_space.sample(), {}) def patched_step(_action): - return new_obs_space.sample(), 0.0, False, {} + return new_obs_space.sample(), 0.0, False, False, {} env.step = patched_step with pytest.warns(UserWarning): @@ -193,7 +193,7 @@ def check_reset_assert_error(env, new_reset_return): """ def wrong_reset(): - return new_reset_return + return new_reset_return, {} # Patch the reset method with a wrong one env.reset = wrong_reset @@ -211,6 +211,11 @@ def test_common_failures_reset(): # The observation is not a numpy array check_reset_assert_error(env, 1) + # Return only obs (gym < 0.26) + env.reset = env.observation_space.sample + with pytest.raises(AssertionError): + check_env(env) + # Return not only the observation check_reset_assert_error(env, (env.observation_space.sample(), False)) @@ -223,10 +228,10 @@ def test_common_failures_reset(): wrong_obs = {**env.observation_space.sample(), "extra_key": None} check_reset_assert_error(env, wrong_obs) - obs = env.reset() + obs, _ = env.reset() def wrong_reset(self): - return {"img": obs["img"], "vec": obs["img"]} + return {"img": obs["img"], "vec": obs["img"]}, {} env.reset = types.MethodType(wrong_reset, env) with pytest.raises(AssertionError) as excinfo: @@ -259,33 +264,38 @@ def test_common_failures_step(): env = IdentityEnvBox() # Wrong shape for the observation - check_step_assert_error(env, (np.ones((4,)), 1.0, False, {})) + check_step_assert_error(env, (np.ones((4,)), 1.0, False, False, {})) # Obs is not a numpy array - check_step_assert_error(env, (1, 1.0, False, {})) + check_step_assert_error(env, (1, 1.0, False, False, {})) # Return a wrong reward - check_step_assert_error(env, (env.observation_space.sample(), np.ones(1), False, {})) + check_step_assert_error(env, (env.observation_space.sample(), np.ones(1), False, False, {})) # Info dict is not returned - check_step_assert_error(env, (env.observation_space.sample(), 0.0, False)) + check_step_assert_error(env, (env.observation_space.sample(), 0.0, False, False)) + + # Truncated is not returned (gym < 0.26) + check_step_assert_error(env, (env.observation_space.sample(), 0.0, False, {})) # Done is not a boolean - check_step_assert_error(env, (env.observation_space.sample(), 0.0, 3.0, {})) - check_step_assert_error(env, (env.observation_space.sample(), 0.0, 1, {})) + check_step_assert_error(env, (env.observation_space.sample(), 0.0, 3.0, False, {})) + check_step_assert_error(env, (env.observation_space.sample(), 0.0, 1, False, {})) + # Truncated is not a boolean + check_step_assert_error(env, (env.observation_space.sample(), 0.0, False, 1.0, {})) env = SimpleMultiObsEnv() # Observation keys and observation space keys must match wrong_obs = env.observation_space.sample() wrong_obs.pop("img") - check_step_assert_error(env, (wrong_obs, 0.0, False, {})) + check_step_assert_error(env, (wrong_obs, 0.0, False, False, {})) wrong_obs = {**env.observation_space.sample(), "extra_key": None} - check_step_assert_error(env, (wrong_obs, 0.0, False, {})) + check_step_assert_error(env, (wrong_obs, 0.0, False, False, {})) - obs = env.reset() + obs, _ = env.reset() def wrong_step(self, action): - return {"img": obs["vec"], "vec": obs["vec"]}, 0.0, False, {} + return {"img": obs["vec"], "vec": obs["vec"]}, 0.0, False, False, {} env.step = types.MethodType(wrong_step, env) with pytest.raises(AssertionError) as excinfo: diff --git a/tests/test_gae.py b/tests/test_gae.py index 4bb371caa2..35f01ac77b 100644 --- a/tests/test_gae.py +++ b/tests/test_gae.py @@ -25,18 +25,22 @@ def reset(self, seed: Optional[int] = None): if seed is not None: self.observation_space.seed(seed) self.n_steps = 0 - return self.observation_space.sample() + return self.observation_space.sample(), {} def step(self, action): self.n_steps += 1 - done = False + done = truncated = False reward = 0.0 if self.n_steps >= self.max_steps: reward = 1.0 done = True + # To simplify GAE computation checks, + # we do not consider truncation here. + # Truncations are checked in InfiniteHorizonEnv + truncated = False - return self.observation_space.sample(), reward, done, {} + return self.observation_space.sample(), reward, done, truncated, {} class InfiniteHorizonEnv(gym.Env): @@ -52,11 +56,11 @@ def reset(self, seed: Optional[int] = None): super().reset(seed=seed) self.current_state = 0 - return self.current_state + return self.current_state, {} def step(self, action): self.current_state = (self.current_state + 1) % self.n_states - return self.current_state, 1.0, False, {} + return self.current_state, 1.0, False, False, {} class CheckGAECallback(BaseCallback): diff --git a/tests/test_her.py b/tests/test_her.py index b83ba1dc0f..c1bc515ed6 100644 --- a/tests/test_her.py +++ b/tests/test_her.py @@ -143,7 +143,7 @@ def test_save_load(tmp_path, model_class, use_sde, online_sampling): model.learn(total_timesteps=150) - obs = env.reset() + obs, _ = env.reset() observations = {key: [] for key in obs.keys()} for _ in range(10): diff --git a/tests/test_identity.py b/tests/test_identity.py index b4dee2f41e..66443b1b33 100644 --- a/tests/test_identity.py +++ b/tests/test_identity.py @@ -15,21 +15,17 @@ def test_discrete(model_class, env): env_ = DummyVecEnv([lambda: env]) kwargs = {} - n_steps = 3000 + n_steps = 2500 if model_class == DQN: kwargs = dict(learning_starts=0) - n_steps = 4000 # DQN only support discrete actions if isinstance(env, (IdentityEnvMultiDiscrete, IdentityEnvMultiBinary)): return - elif model_class == A2C: - # slightly higher budget - n_steps = 3500 model = model_class("MlpPolicy", env_, gamma=0.4, seed=3, **kwargs).learn(n_steps) evaluate_policy(model, env_, n_eval_episodes=20, reward_threshold=90, warn=False) - obs = env.reset() + obs, _ = env.reset() assert np.shape(model.predict(obs)[0]) == np.shape(obs) @@ -38,9 +34,10 @@ def test_discrete(model_class, env): def test_continuous(model_class): env = IdentityEnvBox(eps=0.5) - n_steps = {A2C: 3500, PPO: 3000, SAC: 700, TD3: 500, DDPG: 500}[model_class] + n_steps = {A2C: 2000, PPO: 2500, SAC: 700, TD3: 500, DDPG: 500}[model_class] kwargs = dict(policy_kwargs=dict(net_arch=[64, 64]), seed=0, gamma=0.95) + if model_class in [TD3]: n_actions = 1 action_noise = NormalActionNoise(mean=np.zeros(n_actions), sigma=0.1 * np.ones(n_actions)) diff --git a/tests/test_logger.py b/tests/test_logger.py index 92b65e8a53..9bf05213a9 100644 --- a/tests/test_logger.py +++ b/tests/test_logger.py @@ -354,12 +354,12 @@ def __init__(self, delay: float = 0.01): self.action_space = gym.spaces.Discrete(2) def reset(self): - return self.observation_space.sample() + return self.observation_space.sample(), {} def step(self, action): time.sleep(self.delay) obs = self.observation_space.sample() - return obs, 0.0, True, {} + return obs, 0.0, True, False, {} class InMemoryLogger(Logger): diff --git a/tests/test_monitor.py b/tests/test_monitor.py index d7400c1e3f..c580fcf49b 100644 --- a/tests/test_monitor.py +++ b/tests/test_monitor.py @@ -22,10 +22,10 @@ def test_monitor(tmp_path): ep_lengths = [] ep_len, ep_reward = 0, 0 for _ in range(total_steps): - _, reward, done, _ = monitor_env.step(monitor_env.action_space.sample()) + _, reward, done, truncated, _ = monitor_env.step(monitor_env.action_space.sample()) ep_len += 1 ep_reward += reward - if done: + if done or truncated: ep_rewards.append(ep_reward) ep_lengths.append(ep_len) monitor_env.reset() @@ -75,8 +75,8 @@ def test_monitor_load_results(tmp_path): monitor_env1.reset() episode_count1 = 0 for _ in range(1000): - _, _, done, _ = monitor_env1.step(monitor_env1.action_space.sample()) - if done: + _, _, done, truncated, _ = monitor_env1.step(monitor_env1.action_space.sample()) + if done or truncated: episode_count1 += 1 monitor_env1.reset() @@ -98,8 +98,8 @@ def test_monitor_load_results(tmp_path): monitor_env2 = Monitor(env2, monitor_file2, override_existing=False) monitor_env2.reset() for _ in range(1000): - _, _, done, _ = monitor_env2.step(monitor_env2.action_space.sample()) - if done: + _, _, done, truncated, _ = monitor_env2.step(monitor_env2.action_space.sample()) + if done or truncated: episode_count2 += 1 monitor_env2.reset() diff --git a/tests/test_predict.py b/tests/test_predict.py index c6e499abbc..0a6855a02f 100644 --- a/tests/test_predict.py +++ b/tests/test_predict.py @@ -29,10 +29,10 @@ def __init__(self): self.action_space = SubClassedBox(-1, 1, shape=(2,), dtype=np.float32) def reset(self): - return self.observation_space.sample() + return self.observation_space.sample(), {} def step(self, action): - return self.observation_space.sample(), 0.0, np.random.rand() > 0.5, {} + return self.observation_space.sample(), 0.0, np.random.rand() > 0.5, False, {} @pytest.mark.parametrize("model_class", MODEL_LIST) @@ -71,7 +71,7 @@ def test_predict(model_class, env_id, device): env = gym.make(env_id) vec_env = DummyVecEnv([lambda: gym.make(env_id), lambda: gym.make(env_id)]) - obs = env.reset() + obs, _ = env.reset() action, _ = model.predict(obs) assert isinstance(action, np.ndarray) assert action.shape == env.action_space.shape @@ -97,7 +97,7 @@ def test_dqn_epsilon_greedy(): env = IdentityEnv(2) model = DQN("MlpPolicy", env) model.exploration_rate = 1.0 - obs = env.reset() + obs, _ = env.reset() # is vectorized should not crash with discrete obs action, _ = model.predict(obs, deterministic=False) assert env.action_space.contains(action) @@ -108,5 +108,5 @@ def test_subclassed_space_env(model_class): env = CustomSubClassedSpaceEnv() model = model_class("MlpPolicy", env, policy_kwargs=dict(net_arch=[32])) model.learn(300) - obs = env.reset() + obs, _ = env.reset() env.step(model.predict(obs)) diff --git a/tests/test_run.py b/tests/test_run.py index 655182da1b..9dec724d73 100644 --- a/tests/test_run.py +++ b/tests/test_run.py @@ -26,6 +26,7 @@ def test_deterministic_pg(model_class, action_noise): verbose=1, create_eval_env=True, buffer_size=250, + gradient_steps=1, action_noise=action_noise, ) model.learn(total_timesteps=300, eval_freq=250) diff --git a/tests/test_spaces.py b/tests/test_spaces.py index fb1469722c..2f9c796068 100644 --- a/tests/test_spaces.py +++ b/tests/test_spaces.py @@ -18,10 +18,10 @@ def __init__(self, nvec): def reset(self, seed: Optional[int] = None): if seed is not None: super().reset(seed=seed) - return self.observation_space.sample() + return self.observation_space.sample(), {} def step(self, action): - return self.observation_space.sample(), 0.0, False, {} + return self.observation_space.sample(), 0.0, False, False, {} class DummyMultiBinary(gym.Env): @@ -33,10 +33,10 @@ def __init__(self, n): def reset(self, seed: Optional[int] = None): if seed is not None: super().reset(seed=seed) - return self.observation_space.sample() + return self.observation_space.sample(), {} def step(self, action): - return self.observation_space.sample(), 0.0, False, {} + return self.observation_space.sample(), 0.0, False, False, {} class DummyMultidimensionalAction(gym.Env): @@ -46,10 +46,10 @@ def __init__(self): self.action_space = gym.spaces.Box(low=-1, high=1, shape=(2, 2), dtype=np.float32) def reset(self): - return self.observation_space.sample() + return self.observation_space.sample(), {} def step(self, action): - return self.observation_space.sample(), 0.0, False, {} + return self.observation_space.sample(), 0.0, False, False, {} @pytest.mark.parametrize("model_class", [SAC, TD3, DQN]) diff --git a/tests/test_utils.py b/tests/test_utils.py index 2b3d595f52..35d5ae1e1e 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -191,17 +191,18 @@ def __init__(self, env): self.needs_reset = True def step(self, action): - obs, reward, done, info = self.env.step(action) - self.needs_reset = done + obs, reward, done, truncated, info = self.env.step(action) + self.needs_reset = done or truncated self.last_obs = obs - return obs, reward, True, info + return obs, reward, True, truncated, info def reset(self, **kwargs): + info = {} if self.needs_reset: - obs = self.env.reset(**kwargs) + obs, info = self.env.reset(**kwargs) self.last_obs = obs self.needs_reset = False - return self.last_obs + return self.last_obs, info @pytest.mark.parametrize("n_envs", [1, 2, 5, 7]) diff --git a/tests/test_vec_check_nan.py b/tests/test_vec_check_nan.py index 962355782d..f09a7aae4a 100644 --- a/tests/test_vec_check_nan.py +++ b/tests/test_vec_check_nan.py @@ -24,11 +24,11 @@ def step(action): obs = float("inf") else: obs = 0 - return [obs], 0.0, False, {} + return [obs], 0.0, False, False, {} @staticmethod def reset(): - return [0.0] + return [0.0], {} def render(self, mode="human", close=False): pass diff --git a/tests/test_vec_envs.py b/tests/test_vec_envs.py index c9bdd0853b..be357655e0 100644 --- a/tests/test_vec_envs.py +++ b/tests/test_vec_envs.py @@ -31,14 +31,14 @@ def reset(self, seed: Optional[int] = None): self.seed(seed) self.current_step = 0 self._choose_next_state() - return self.state + return self.state, {} def step(self, action): reward = float(np.random.rand()) self._choose_next_state() self.current_step += 1 - done = self.current_step >= self.ep_length - return self.state, reward, done, {} + done = truncated = self.current_step >= self.ep_length + return self.state, reward, done, truncated, {} def _choose_next_state(self): self.state = self.observation_space.sample() @@ -147,13 +147,13 @@ def __init__(self, max_steps): def reset(self): self.current_step = 0 - return np.array([self.current_step], dtype="int") + return np.array([self.current_step], dtype="int"), {} def step(self, action): prev_step = self.current_step self.current_step += 1 - done = self.current_step >= self.max_steps - return np.array([prev_step], dtype="int"), 0.0, done, {} + done = truncated = self.current_step >= self.max_steps + return np.array([prev_step], dtype="int"), 0.0, done, truncated, {} @pytest.mark.parametrize("vec_env_class", VEC_ENV_CLASSES) @@ -452,7 +452,7 @@ def test_backward_compat_seed(vec_env_class): def make_env(): env = CustomGymEnv(gym.spaces.Box(low=np.zeros(2), high=np.ones(2))) # Patch reset function to remove seed param - env.reset = env.observation_space.sample + env.reset = lambda: (env.observation_space.sample(), {}) env.seed = env.observation_space.seed return env diff --git a/tests/test_vec_normalize.py b/tests/test_vec_normalize.py index b9e3cbf7f6..a2c35f49b6 100644 --- a/tests/test_vec_normalize.py +++ b/tests/test_vec_normalize.py @@ -35,13 +35,14 @@ def step(self, action): self.t += 1 index = (self.t + self.return_reward_idx) % len(self.returned_rewards) returned_value = self.returned_rewards[index] - return np.array([returned_value]), returned_value, self.t == len(self.returned_rewards), {} + done = truncated = self.t == len(self.returned_rewards) + return np.array([returned_value]), returned_value, done, truncated, {} def reset(self, seed: Optional[int] = None): if seed is not None: super().reset(seed=seed) self.t = 0 - return np.array([self.returned_rewards[self.return_reward_idx]]) + return np.array([self.returned_rewards[self.return_reward_idx]]), {} class DummyDictEnv(gym.Env): @@ -63,13 +64,13 @@ def __init__(self): def reset(self, seed: Optional[int] = None): if seed is not None: super().reset(seed=seed) - return self.observation_space.sample() + return self.observation_space.sample(), {} def step(self, action): obs = self.observation_space.sample() reward = self.compute_reward(obs["achieved_goal"], obs["desired_goal"], {}) done = np.random.rand() > 0.8 - return obs, reward, done, {} + return obs, reward, done, False, {} def compute_reward(self, achieved_goal: np.ndarray, desired_goal: np.ndarray, _info) -> np.float32: distance = np.linalg.norm(achieved_goal - desired_goal, axis=-1) @@ -95,12 +96,12 @@ def __init__(self): def reset(self, seed: Optional[int] = None): if seed is not None: super().reset(seed=seed) - return self.observation_space.sample() + return self.observation_space.sample(), {} def step(self, action): obs = self.observation_space.sample() done = np.random.rand() > 0.8 - return obs, 0.0, done, {} + return obs, 0.0, done, False, {} def allclose(obs_1, obs_2): From 6ed30791176fef0fc9e55ba07eff0017ea12c108 Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Sat, 1 Oct 2022 12:25:53 +0200 Subject: [PATCH 049/153] Enforce gym 0.26 api --- stable_baselines3/common/utils.py | 67 ------------------- .../common/vec_env/dummy_vec_env.py | 13 ++-- .../common/vec_env/subproc_vec_env.py | 8 +-- 3 files changed, 8 insertions(+), 80 deletions(-) diff --git a/stable_baselines3/common/utils.py b/stable_baselines3/common/utils.py index ff521a3dc1..f3c88f2364 100644 --- a/stable_baselines3/common/utils.py +++ b/stable_baselines3/common/utils.py @@ -2,7 +2,6 @@ import os import platform import random -import warnings from collections import deque from inspect import signature from itertools import zip_longest @@ -22,10 +21,7 @@ from stable_baselines3.common.logger import Logger, configure from stable_baselines3.common.type_aliases import ( - Gym26StepReturn, GymEnv, - GymObs, - GymStepReturn, Schedule, TensorDict, TrainFreq, @@ -545,66 +541,3 @@ def compat_gym_seed(env: GymEnv, seed: int) -> None: else: # VecEnv and backward compatibility env.seed(seed) - - -def compat_gym_26_step( - step_returns: Union[GymStepReturn, Gym26StepReturn], - warn=True, -) -> Gym26StepReturn: - """ - Transform step returns to old step API irrespective of input API. - This makes env written with both gym < 0.26 and gym > 0.26 - compatible with SB3. - This is a simplified version of the helper found in Gym. - - :param step_returns: Items returned by `env.step()`. - Can be (obs, reward, done, info) or (obs, reward, terminated, truncated, info) - :param warn: Whether to warn or not the user - :return: (obs, reward, done, info) with info["TimeLimit.truncated"] = truncated - """ - if len(step_returns) == 4: - if warn: - warnings.warn( - "You are using gym API < 0.26, please upgrade to gym > 0.26 " - "step API with 5 values returned (obs, reward, terminated, truncated, info) " - "instead of the current 4 values (obs, reward, done, info). " - "Please read Gym documentation for more information.", - DeprecationWarning, - ) - observations, rewards, done, infos = step_returns - truncated = infos.get("TimeLimit.truncated", False) # pytype: disable=attribute-error] - return observations, rewards, done, truncated, infos # pytype: disable=bad-return-type - else: - assert len(step_returns) == 5, f"The step function returned {len(step_returns)} values instead of 5" - return step_returns # pytype: disable=bad-return-type - - -def compat_gym_26_reset( - reset_return: Union[GymObs, Tuple[GymObs, Dict]], - warn: bool = True, -) -> Tuple[GymObs, Dict]: - """ - Transform gym reset return to new API (gym > 0.26) - - :param reset_return: (obs) or (obs, info) - :param warn: Whether to warn or not the user - :return: (obs, info) - """ - if isinstance(reset_return, tuple): - assert len(reset_return), ( - "The tuple returned by the reset() function " - f"has a length of {len(reset_return)} != 2. " - "It should only return (obs, info)" - ) - return reset_return - else: - if warn: - warnings.warn( - "You are using gym API < 0.26, please upgrade to gym > 0.26 " - "reset API with 2 values returned (obs, info) " - "instead of the current single value (obs). " - "Please read Gym documentation for more information.", - DeprecationWarning, - ) - - return reset_return, {} diff --git a/stable_baselines3/common/vec_env/dummy_vec_env.py b/stable_baselines3/common/vec_env/dummy_vec_env.py index b4f0f28554..c663558a79 100644 --- a/stable_baselines3/common/vec_env/dummy_vec_env.py +++ b/stable_baselines3/common/vec_env/dummy_vec_env.py @@ -40,11 +40,9 @@ def step_async(self, actions: np.ndarray) -> None: def step_wait(self) -> VecEnvStepReturn: # Avoid circular imports - from stable_baselines3.common.utils import compat_gym_26_reset, compat_gym_26_step - for env_idx in range(self.num_envs): - obs, self.buf_rews[env_idx], done, truncated, self.buf_infos[env_idx] = compat_gym_26_step( - self.envs[env_idx].step(self.actions[env_idx]) + obs, self.buf_rews[env_idx], done, truncated, self.buf_infos[env_idx] = self.envs[env_idx].step( + self.actions[env_idx] ) # convert to SB3 VecEnv api self.buf_dones[env_idx] = done or truncated @@ -53,7 +51,7 @@ def step_wait(self) -> VecEnvStepReturn: if self.buf_dones[env_idx]: # save final observation where user can get it, then reset self.buf_infos[env_idx]["terminal_observation"] = obs - obs, self.reset_infos[env_idx] = compat_gym_26_reset(self.envs[env_idx].reset()) + obs, self.reset_infos[env_idx] = self.envs[env_idx].reset() self._save_obs(env_idx, obs) return (self._obs_from_buf(), np.copy(self.buf_rews), np.copy(self.buf_dones), deepcopy(self.buf_infos)) @@ -69,11 +67,8 @@ def seed(self, seed: Optional[int] = None) -> List[Union[None, int]]: return seeds def reset(self) -> VecEnvObs: - # Avoid circular imports - from stable_baselines3.common.utils import compat_gym_26_reset - for env_idx in range(self.num_envs): - obs, self.reset_infos[env_idx] = compat_gym_26_reset(self.envs[env_idx].reset()) + obs, self.reset_infos[env_idx] = self.envs[env_idx].reset() self._save_obs(env_idx, obs) return self._obs_from_buf() diff --git a/stable_baselines3/common/vec_env/subproc_vec_env.py b/stable_baselines3/common/vec_env/subproc_vec_env.py index f5f85cf331..367a87f136 100644 --- a/stable_baselines3/common/vec_env/subproc_vec_env.py +++ b/stable_baselines3/common/vec_env/subproc_vec_env.py @@ -21,7 +21,7 @@ def _worker( # noqa: C901 ) -> None: # Import here to avoid a circular import from stable_baselines3.common.env_util import is_wrapped - from stable_baselines3.common.utils import compat_gym_26_reset, compat_gym_26_step, compat_gym_seed + from stable_baselines3.common.utils import compat_gym_seed parent_remote.close() env = env_fn_wrapper.var() @@ -30,19 +30,19 @@ def _worker( # noqa: C901 try: cmd, data = remote.recv() if cmd == "step": - observation, reward, done, truncated, info = compat_gym_26_step(env.step(data)) + observation, reward, done, truncated, info = env.step(data) # convert to SB3 VecEnv api done = done or truncated info["TimeLimit.truncated"] = truncated if done: # save final observation where user can get it, then reset info["terminal_observation"] = observation - observation, reset_info = compat_gym_26_reset(env.reset()) + observation, reset_info = env.reset() remote.send((observation, reward, done, info, reset_info)) elif cmd == "seed": remote.send(compat_gym_seed(env, seed=data)) elif cmd == "reset": - observation, reset_info = compat_gym_26_reset(env.reset()) + observation, reset_info = env.reset() remote.send((observation, reset_info)) elif cmd == "render": remote.send(env.render(data)) From 95bb4d6b880f21e8e3117d29d4391937f12fedbf Mon Sep 17 00:00:00 2001 From: tlips Date: Sun, 2 Oct 2022 12:49:40 +0200 Subject: [PATCH 050/153] format --- stable_baselines3/common/utils.py | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/stable_baselines3/common/utils.py b/stable_baselines3/common/utils.py index f3c88f2364..8c20f987ce 100644 --- a/stable_baselines3/common/utils.py +++ b/stable_baselines3/common/utils.py @@ -20,13 +20,7 @@ SummaryWriter = None from stable_baselines3.common.logger import Logger, configure -from stable_baselines3.common.type_aliases import ( - GymEnv, - Schedule, - TensorDict, - TrainFreq, - TrainFrequencyUnit, -) +from stable_baselines3.common.type_aliases import GymEnv, Schedule, TensorDict, TrainFreq, TrainFrequencyUnit def set_random_seed(seed: int, using_cuda: bool = False) -> None: From c4517f29a1071f73b835152b6e6329bf6e882ede Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Sun, 2 Oct 2022 15:29:38 +0200 Subject: [PATCH 051/153] Fix formatting --- stable_baselines3/common/utils.py | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/stable_baselines3/common/utils.py b/stable_baselines3/common/utils.py index f3c88f2364..8c20f987ce 100644 --- a/stable_baselines3/common/utils.py +++ b/stable_baselines3/common/utils.py @@ -20,13 +20,7 @@ SummaryWriter = None from stable_baselines3.common.logger import Logger, configure -from stable_baselines3.common.type_aliases import ( - GymEnv, - Schedule, - TensorDict, - TrainFreq, - TrainFrequencyUnit, -) +from stable_baselines3.common.type_aliases import GymEnv, Schedule, TensorDict, TrainFreq, TrainFrequencyUnit def set_random_seed(seed: int, using_cuda: bool = False) -> None: From 9ac7592c2f14bee93d083d69e8cdc56333b56644 Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Sun, 2 Oct 2022 16:14:49 +0200 Subject: [PATCH 052/153] Fix dependencies --- setup.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/setup.py b/setup.py index a4365a5705..f6336f74a4 100644 --- a/setup.py +++ b/setup.py @@ -94,6 +94,8 @@ "pytype", # Lint code "flake8>=3.8", + # flake8 not compatible with importlib-metadata>5.0 + "importlib-metadata~=4.13" # Find likely bugs "flake8-bugbear", # Sort imports From d2e687320f6e49cfd7f6a93cc0e801f0358a3307 Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Sun, 2 Oct 2022 16:26:48 +0200 Subject: [PATCH 053/153] Fix syntax --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index f6336f74a4..d296e7e113 100644 --- a/setup.py +++ b/setup.py @@ -95,7 +95,7 @@ # Lint code "flake8>=3.8", # flake8 not compatible with importlib-metadata>5.0 - "importlib-metadata~=4.13" + "importlib-metadata~=4.13", # Find likely bugs "flake8-bugbear", # Sort imports From 969c1cfdcad1932e8a5e78198eb91fb547c417fa Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Sun, 2 Oct 2022 17:10:36 +0200 Subject: [PATCH 054/153] Cleanup doc and warnings --- README.md | 12 +++++++----- docs/guide/quickstart.rst | 18 +++++++----------- setup.cfg | 5 ----- setup.py | 15 +++++++++------ stable_baselines3/__init__.py | 4 ---- .../common/vec_env/base_vec_env.py | 2 +- tests/test_her.py | 2 +- tests/test_vec_monitor.py | 4 ---- 8 files changed, 25 insertions(+), 37 deletions(-) diff --git a/README.md b/README.md index 07872a661d..58521688ac 100644 --- a/README.md +++ b/README.md @@ -124,13 +124,15 @@ env = gym.make("CartPole-v1") model = PPO("MlpPolicy", env, verbose=1) model.learn(total_timesteps=10_000) -obs, info = env.reset() +vec_env = model.get_env() +obs = vec_env.reset() for i in range(1000): action, _states = model.predict(obs, deterministic=True) - obs, reward, done, truncated, info = env.step(action) - env.render() - if done or truncated: - obs = env.reset() + obs, reward, done, info = vec_env.step(action) + vec_env.render() + # VecEnv resets automatically + # if done: + # obs = env.reset() env.close() ``` diff --git a/docs/guide/quickstart.rst b/docs/guide/quickstart.rst index a1c5473440..aa2ddee4ce 100644 --- a/docs/guide/quickstart.rst +++ b/docs/guide/quickstart.rst @@ -19,19 +19,15 @@ Here is a quick example of how to train and run A2C on a CartPole environment: model = A2C("MlpPolicy", env, verbose=1) model.learn(total_timesteps=10000) - # Note: Gym 0.26+ reset() returns a tuple - # where SB3 VecEnv only return an observation - obs, info = env.reset() + vec_env = model.get_env() + obs = vec_env.reset() for i in range(1000): action, _state = model.predict(obs, deterministic=True) - # Note: Gym 0.26+ step() returns an additional boolean - # "truncated" where SB3 store truncation information - # in info["TimeLimit.truncated"] - obs, reward, done, truncated, info = env.step(action) - env.render() - # Note: reset is automated in SB3 VecEnv - if done or truncated: - obs, info = env.reset() + obs, reward, done, info = vec_env.step(action) + vec_env.render() + # VecEnv resets automatically + # if done: + # obs = vec_env.reset() .. note:: diff --git a/setup.cfg b/setup.cfg index bd04ac9e92..63bd337b29 100644 --- a/setup.cfg +++ b/setup.cfg @@ -10,12 +10,7 @@ filterwarnings = # Tensorboard warnings ignore::DeprecationWarning:tensorboard # Gym warnings - ; ignore:Parameters to load are deprecated.:DeprecationWarning - ; ignore:the imp module is deprecated in favour of importlib:PendingDeprecationWarning ignore::UserWarning:gym - ; ignore:SelectableGroups dict interface is deprecated.:DeprecationWarning - ; ignore:`np.bool` is a deprecated alias for the builtin `bool`:DeprecationWarning - ignore:.*step API:DeprecationWarning:gym markers = expensive: marks tests as expensive (deselect with '-m "not expensive"') diff --git a/setup.py b/setup.py index d296e7e113..40f928f649 100644 --- a/setup.py +++ b/setup.py @@ -48,13 +48,16 @@ model = PPO("MlpPolicy", env, verbose=1) model.learn(total_timesteps=10_000) -obs, info = env.reset() +vec_env = model.get_env() +obs = vec_env.reset() for i in range(1000): action, _states = model.predict(obs, deterministic=True) - obs, reward, done, truncated, info = env.step(action) - env.render() - if done or truncated: - obs, info = env.reset() + obs, reward, done, info = vec_env.step(action) + vec_env.render() + # VecEnv resets automatically + # if done: + # obs = vec_env.reset() + ``` Or just train a model with a one liner if [the environment is registered in Gym](https://www.gymlibrary.ml/content/environment_creation/) and if [the policy is registered](https://stable-baselines3.readthedocs.io/en/master/guide/custom_policy.html): @@ -104,7 +107,7 @@ "black", ], "docs": [ - "sphinx~=4.5.0", + "sphinx", "sphinx-autobuild", "sphinx-rtd-theme", # For spelling diff --git a/stable_baselines3/__init__.py b/stable_baselines3/__init__.py index 4792f6c152..d73f5f095e 100644 --- a/stable_baselines3/__init__.py +++ b/stable_baselines3/__init__.py @@ -1,5 +1,4 @@ import os -import warnings from stable_baselines3.a2c import A2C from stable_baselines3.common.utils import get_system_info @@ -15,9 +14,6 @@ with open(version_file) as file_handler: __version__ = file_handler.read().strip() -# Silence Gym warnings due to new API -warnings.filterwarnings("ignore", message=r".*step API", module="gym") - def HER(*args, **kwargs): raise ImportError( diff --git a/stable_baselines3/common/vec_env/base_vec_env.py b/stable_baselines3/common/vec_env/base_vec_env.py index 8c09d3c8c4..28ef076849 100644 --- a/stable_baselines3/common/vec_env/base_vec_env.py +++ b/stable_baselines3/common/vec_env/base_vec_env.py @@ -59,7 +59,7 @@ def __init__(self, num_envs: int, observation_space: gym.spaces.Space, action_sp self.num_envs = num_envs self.observation_space = observation_space self.action_space = action_space - self.reset_infos = [{} for _ in range(num_envs)] # store info returns by the reset method + self.reset_infos = [{} for _ in range(num_envs)] # store info returned by the reset method @abstractmethod def reset(self) -> VecEnvObs: diff --git a/tests/test_her.py b/tests/test_her.py index c1bc515ed6..c3ce19ed75 100644 --- a/tests/test_her.py +++ b/tests/test_her.py @@ -394,7 +394,7 @@ def test_performance_her(online_sampling, n_bits): buffer_size=int(1e5), ) - model.learn(total_timesteps=5000, log_interval=50) + model.learn(total_timesteps=3500, log_interval=50) # 90% training success assert np.mean(model.ep_success_buffer) > 0.90 diff --git a/tests/test_vec_monitor.py b/tests/test_vec_monitor.py index bbf5e8d216..ea93413acb 100644 --- a/tests/test_vec_monitor.py +++ b/tests/test_vec_monitor.py @@ -2,7 +2,6 @@ import json import os import uuid -import warnings import gym import pandas @@ -133,9 +132,6 @@ def test_vec_monitor_ppo(recwarn): """ Test the `VecMonitor` with PPO """ - # Remove Gym Warnings - warnings.filterwarnings(action="ignore", category=DeprecationWarning, module="gym") - env = DummyVecEnv([lambda: gym.make("CartPole-v1", disable_env_checker=True)]) env.seed(seed=0) monitor_env = VecMonitor(env) From 2fcd07241c5491dc9e0be8ad72829a1b44caabc8 Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Sun, 2 Oct 2022 17:27:28 +0200 Subject: [PATCH 055/153] Faster tests --- tests/test_identity.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/test_identity.py b/tests/test_identity.py index 66443b1b33..d4b584016d 100644 --- a/tests/test_identity.py +++ b/tests/test_identity.py @@ -34,7 +34,7 @@ def test_discrete(model_class, env): def test_continuous(model_class): env = IdentityEnvBox(eps=0.5) - n_steps = {A2C: 2000, PPO: 2500, SAC: 700, TD3: 500, DDPG: 500}[model_class] + n_steps = {A2C: 2000, PPO: 2000, SAC: 400, TD3: 400, DDPG: 400}[model_class] kwargs = dict(policy_kwargs=dict(net_arch=[64, 64]), seed=0, gamma=0.95) @@ -42,7 +42,9 @@ def test_continuous(model_class): n_actions = 1 action_noise = NormalActionNoise(mean=np.zeros(n_actions), sigma=0.1 * np.ones(n_actions)) kwargs["action_noise"] = action_noise + elif model_class == PPO: + kwargs = dict(n_steps=512, n_epochs=5) - model = model_class("MlpPolicy", env, **kwargs).learn(n_steps) + model = model_class("MlpPolicy", env, learning_rate=1e-3, **kwargs).learn(n_steps) evaluate_policy(model, env, n_eval_episodes=20, reward_threshold=90, warn=False) From dd67a202e05d8bb9a6361c7f858f3c7d4bed944d Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Sun, 2 Oct 2022 18:20:03 +0200 Subject: [PATCH 056/153] Higher budget for HER perf test (revert prev change) --- tests/test_her.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_her.py b/tests/test_her.py index c3ce19ed75..c1bc515ed6 100644 --- a/tests/test_her.py +++ b/tests/test_her.py @@ -394,7 +394,7 @@ def test_performance_her(online_sampling, n_bits): buffer_size=int(1e5), ) - model.learn(total_timesteps=3500, log_interval=50) + model.learn(total_timesteps=5000, log_interval=50) # 90% training success assert np.mean(model.ep_success_buffer) > 0.90 From 9ae6fa2ead62aa35ad839578316aacafe53f21b2 Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Wed, 5 Oct 2022 14:57:57 +0200 Subject: [PATCH 057/153] Fixes and update doc --- docs/misc/changelog.rst | 36 +++++++++++++++++-- docs/modules/her.rst | 3 +- setup.py | 2 +- stable_baselines3/common/atari_wrappers.py | 22 ++++++------ .../common/vec_env/dummy_vec_env.py | 4 +-- .../common/vec_env/subproc_vec_env.py | 4 +-- stable_baselines3/version.txt | 2 +- 7 files changed, 52 insertions(+), 21 deletions(-) diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index db51e670ff..65ed29b242 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -3,6 +3,40 @@ Changelog ========== +Release 1.6.2a0 (WIP) +--------------------------- + +.. warning:: + + This version will be the last one supporting ``gym``, + we recommend switching to `gymnasium `_. + You can find a migration guide here: TODO + + +Breaking Changes: +^^^^^^^^^^^^^^^^^ +- Switched minimum Gym version to 0.26 (@carlosluis, @arjun-kg, @tlpss) + + +New Features: +^^^^^^^^^^^^^ + +SB3-Contrib +^^^^^^^^^^^ + +Bug Fixes: +^^^^^^^^^^ + +Deprecations: +^^^^^^^^^^^^^ + +Others: +^^^^^^^ + +Documentation: +^^^^^^^^^^^^^^ + + Release 1.6.1 (2022-09-29) --------------------------- @@ -67,7 +101,6 @@ Release 1.6.0 (2022-07-11) Breaking Changes: ^^^^^^^^^^^^^^^^^ -- Switched minimum Gym version to 0.24 (@carlosluis) - Changed the way policy "aliases" are handled ("MlpPolicy", "CnnPolicy", ...), removing the former ``register_policy`` helper, ``policy_base`` parameter and using ``policy_aliases`` static attributes instead (@Gregwar) - SB3 now requires PyTorch >= 1.11 @@ -112,7 +145,6 @@ Documentation: - Added link to PPO ICLR blog post - Added remark about breaking Markov assumption and timeout handling - Added doc about MLFlow integration via custom logger (@git-thor) -- Updated tutorials to work with Gym 0.23 (@arjun-kg) - Updated Huggingface integration doc - Added copy button for code snippets - Added doc about EnvPool and Isaac Gym support diff --git a/docs/modules/her.rst b/docs/modules/her.rst index c0a6abb738..817a991cfb 100644 --- a/docs/modules/her.rst +++ b/docs/modules/her.rst @@ -19,10 +19,9 @@ It creates "virtual" transitions by relabeling transitions (changing the desired but a replay buffer class ``HerReplayBuffer`` that must be passed to an off-policy algorithm when using ``MultiInputPolicy`` (to have Dict observation support). - .. warning:: - HER requires the environment to follow the legacy `gym.GoalEnv interface `_ + HER requires the environment to follow the legacy `gym_robotics.GoalEnv interface `_ In short, the ``gym.Env`` must have: - a vectorized implementation of ``compute_reward()`` - a dictionary observation space with three keys: ``observation``, ``achieved_goal`` and ``desired_goal`` diff --git a/setup.py b/setup.py index 40f928f649..e4305a2b93 100644 --- a/setup.py +++ b/setup.py @@ -76,7 +76,7 @@ packages=[package for package in find_packages() if package.startswith("stable_baselines3")], package_data={"stable_baselines3": ["py.typed", "version.txt"]}, install_requires=[ - "gym==0.26", + "gym==0.26.2", "numpy", "torch>=1.11", # For saving models diff --git a/stable_baselines3/common/atari_wrappers.py b/stable_baselines3/common/atari_wrappers.py index ba08da7b5b..56dccb0a10 100644 --- a/stable_baselines3/common/atari_wrappers.py +++ b/stable_baselines3/common/atari_wrappers.py @@ -40,8 +40,8 @@ def reset(self, **kwargs) -> Tuple[np.ndarray, Dict]: obs = np.zeros(0) info = {} for _ in range(noops): - obs, _, done, truncated, info = self.env.step(self.noop_action) - if done or truncated: + obs, _, terminated, truncated, info = self.env.step(self.noop_action) + if terminated or truncated: obs, info = self.env.reset(**kwargs) return obs, info @@ -60,11 +60,11 @@ def __init__(self, env: gym.Env): def reset(self, **kwargs) -> Tuple[np.ndarray, Dict]: self.env.reset(**kwargs) - obs, _, done, truncated, _ = self.env.step(1) - if done or truncated: + obs, _, terminated, truncated, _ = self.env.step(1) + if terminated or truncated: self.env.reset(**kwargs) - obs, _, done, truncated, _ = self.env.step(2) - if done or truncated: + obs, _, terminated, truncated, _ = self.env.step(2) + if terminated or truncated: self.env.reset(**kwargs) return obs, {} @@ -83,8 +83,8 @@ def __init__(self, env: gym.Env): self.was_real_done = True def step(self, action: int) -> Gym26StepReturn: - obs, reward, done, truncated, info = self.env.step(action) - self.was_real_done = done + obs, reward, terminated, truncated, info = self.env.step(action) + self.was_real_done = terminated or truncated # check current lives, make loss of life terminal, # then update lives to handle bonus lives lives = self.env.unwrapped.ale.lives() @@ -92,9 +92,9 @@ def step(self, action: int) -> Gym26StepReturn: # for Qbert sometimes we stay in lives == 0 condition for a few frames # so its important to keep lives > 0, so that we only reset once # the environment advertises done. - done = True + terminated = True self.lives = lives - return obs, reward, done, truncated, info + return obs, reward, terminated, truncated, info def reset(self, **kwargs) -> Tuple[np.ndarray, Dict]: """ @@ -109,7 +109,7 @@ def reset(self, **kwargs) -> Tuple[np.ndarray, Dict]: obs, info = self.env.reset(**kwargs) else: # no-op step to advance from terminal/lost life state - obs, _, _, info = self.env.step(0) + obs, _, _, _, info = self.env.step(0) self.lives = self.env.unwrapped.ale.lives() return obs, info diff --git a/stable_baselines3/common/vec_env/dummy_vec_env.py b/stable_baselines3/common/vec_env/dummy_vec_env.py index c663558a79..b80efd4ae1 100644 --- a/stable_baselines3/common/vec_env/dummy_vec_env.py +++ b/stable_baselines3/common/vec_env/dummy_vec_env.py @@ -41,11 +41,11 @@ def step_async(self, actions: np.ndarray) -> None: def step_wait(self) -> VecEnvStepReturn: # Avoid circular imports for env_idx in range(self.num_envs): - obs, self.buf_rews[env_idx], done, truncated, self.buf_infos[env_idx] = self.envs[env_idx].step( + obs, self.buf_rews[env_idx], terminated, truncated, self.buf_infos[env_idx] = self.envs[env_idx].step( self.actions[env_idx] ) # convert to SB3 VecEnv api - self.buf_dones[env_idx] = done or truncated + self.buf_dones[env_idx] = terminated or truncated self.buf_infos[env_idx]["TimeLimit.truncated"] = truncated if self.buf_dones[env_idx]: diff --git a/stable_baselines3/common/vec_env/subproc_vec_env.py b/stable_baselines3/common/vec_env/subproc_vec_env.py index 367a87f136..d133f1d639 100644 --- a/stable_baselines3/common/vec_env/subproc_vec_env.py +++ b/stable_baselines3/common/vec_env/subproc_vec_env.py @@ -30,9 +30,9 @@ def _worker( # noqa: C901 try: cmd, data = remote.recv() if cmd == "step": - observation, reward, done, truncated, info = env.step(data) + observation, reward, terminated, truncated, info = env.step(data) # convert to SB3 VecEnv api - done = done or truncated + done = terminated or truncated info["TimeLimit.truncated"] = truncated if done: # save final observation where user can get it, then reset diff --git a/stable_baselines3/version.txt b/stable_baselines3/version.txt index 9c6d6293b1..35a785a76f 100644 --- a/stable_baselines3/version.txt +++ b/stable_baselines3/version.txt @@ -1 +1 @@ -1.6.1 +2.0.0a0 From 056454bace90cb21e33d1b5b5265688239968629 Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Wed, 5 Oct 2022 15:04:36 +0200 Subject: [PATCH 058/153] Fix doc build --- docs/misc/changelog.rst | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 65ed29b242..e9c9c0fef4 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -8,9 +8,8 @@ Release 1.6.2a0 (WIP) .. warning:: - This version will be the last one supporting ``gym``, - we recommend switching to `gymnasium `_. - You can find a migration guide here: TODO + This version will be the last one supporting ``gym``, we recommend switching to `gymnasium `_. + You can find a migration guide here: TODO Breaking Changes: From 9ad927b631ca0b57bf5bd6c3a28c39cb96b4976d Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Mon, 31 Oct 2022 13:55:56 +0100 Subject: [PATCH 059/153] Fix breaking change --- stable_baselines3/common/envs/bit_flipping_env.py | 1 - stable_baselines3/common/envs/identity_env.py | 9 ++++++--- stable_baselines3/common/envs/multi_input_envs.py | 2 +- stable_baselines3/common/vec_env/dummy_vec_env.py | 4 +++- stable_baselines3/common/vec_env/subproc_vec_env.py | 2 +- 5 files changed, 11 insertions(+), 7 deletions(-) diff --git a/stable_baselines3/common/envs/bit_flipping_env.py b/stable_baselines3/common/envs/bit_flipping_env.py index 0fc93a6cf6..82cc1a0047 100644 --- a/stable_baselines3/common/envs/bit_flipping_env.py +++ b/stable_baselines3/common/envs/bit_flipping_env.py @@ -182,7 +182,6 @@ def step(self, action: Union[np.ndarray, int]) -> Gym26StepReturn: # Episode terminate when we reached the goal or the max number of steps info = {"is_success": done} truncated = self.current_step >= self.max_steps - done = done or truncated return obs, reward, done, truncated, info def compute_reward( diff --git a/stable_baselines3/common/envs/identity_env.py b/stable_baselines3/common/envs/identity_env.py index aa71696111..fc5b200fe0 100644 --- a/stable_baselines3/common/envs/identity_env.py +++ b/stable_baselines3/common/envs/identity_env.py @@ -44,7 +44,8 @@ def step(self, action: Union[int, np.ndarray]) -> Gym26StepReturn: reward = self._get_reward(action) self._choose_next_state() self.current_step += 1 - done = truncated = self.current_step >= self.ep_length + done = False + truncated = self.current_step >= self.ep_length return self.state, reward, done, truncated, {} def _choose_next_state(self) -> None: @@ -75,7 +76,8 @@ def step(self, action: np.ndarray) -> Gym26StepReturn: reward = self._get_reward(action) self._choose_next_state() self.current_step += 1 - done = truncated = self.current_step >= self.ep_length + done = False + truncated = self.current_step >= self.ep_length return self.state, reward, done, truncated, {} def _get_reward(self, action: np.ndarray) -> float: @@ -147,7 +149,8 @@ def reset(self, seed: Optional[int] = None) -> Tuple[np.ndarray, Dict]: def step(self, action: Union[np.ndarray, int]) -> Gym26StepReturn: reward = 0.0 self.current_step += 1 - done = truncated = self.current_step >= self.ep_length + done = False + truncated = self.current_step >= self.ep_length return self.observation_space.sample(), reward, done, truncated, {} def render(self, mode: str = "human") -> None: diff --git a/stable_baselines3/common/envs/multi_input_envs.py b/stable_baselines3/common/envs/multi_input_envs.py index c7b5973daa..36af4e4405 100644 --- a/stable_baselines3/common/envs/multi_input_envs.py +++ b/stable_baselines3/common/envs/multi_input_envs.py @@ -153,7 +153,7 @@ def step(self, action: Union[int, float, np.ndarray]) -> Gym26StepReturn: got_to_end = self.state == self.max_state reward = 1 if got_to_end else reward truncated = self.count > self.max_count - done = got_to_end or truncated + done = got_to_end self.log = f"Went {self.action2str[action]} in state {prev_state}, got to state {self.state}" diff --git a/stable_baselines3/common/vec_env/dummy_vec_env.py b/stable_baselines3/common/vec_env/dummy_vec_env.py index b80efd4ae1..2f34c4cfa2 100644 --- a/stable_baselines3/common/vec_env/dummy_vec_env.py +++ b/stable_baselines3/common/vec_env/dummy_vec_env.py @@ -46,7 +46,9 @@ def step_wait(self) -> VecEnvStepReturn: ) # convert to SB3 VecEnv api self.buf_dones[env_idx] = terminated or truncated - self.buf_infos[env_idx]["TimeLimit.truncated"] = truncated + # See https://github.com/openai/gym/issues/3102 + # Gym 0.26 introduces a breaking change + self.buf_infos[env_idx]["TimeLimit.truncated"] = truncated and not terminated if self.buf_dones[env_idx]: # save final observation where user can get it, then reset diff --git a/stable_baselines3/common/vec_env/subproc_vec_env.py b/stable_baselines3/common/vec_env/subproc_vec_env.py index d133f1d639..8aa3e59e3d 100644 --- a/stable_baselines3/common/vec_env/subproc_vec_env.py +++ b/stable_baselines3/common/vec_env/subproc_vec_env.py @@ -33,7 +33,7 @@ def _worker( # noqa: C901 observation, reward, terminated, truncated, info = env.step(data) # convert to SB3 VecEnv api done = terminated or truncated - info["TimeLimit.truncated"] = truncated + info["TimeLimit.truncated"] = truncated and not terminated if done: # save final observation where user can get it, then reset info["terminal_observation"] = observation From 3ca1b734ef956b7e369464d3e0ac394868899169 Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Mon, 31 Oct 2022 16:05:53 +0100 Subject: [PATCH 060/153] Fixes for rendering --- stable_baselines3/common/vec_env/dummy_vec_env.py | 2 +- stable_baselines3/common/vec_env/vec_video_recorder.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/stable_baselines3/common/vec_env/dummy_vec_env.py b/stable_baselines3/common/vec_env/dummy_vec_env.py index 2f34c4cfa2..3b26592a2c 100644 --- a/stable_baselines3/common/vec_env/dummy_vec_env.py +++ b/stable_baselines3/common/vec_env/dummy_vec_env.py @@ -94,7 +94,7 @@ def render(self, mode: str = "human") -> Optional[np.ndarray]: :param mode: The rendering type. """ if self.num_envs == 1: - return self.envs[0].render(mode=mode) + return self.envs[0].render() else: return super().render(mode=mode) diff --git a/stable_baselines3/common/vec_env/vec_video_recorder.py b/stable_baselines3/common/vec_env/vec_video_recorder.py index 70d74ebe4c..54ba4964f6 100644 --- a/stable_baselines3/common/vec_env/vec_video_recorder.py +++ b/stable_baselines3/common/vec_env/vec_video_recorder.py @@ -48,6 +48,7 @@ def __init__( metadata = temp_env.metadata self.env.metadata = metadata + self.env.render_mode = "rgb_array" self.record_video_trigger = record_video_trigger self.video_recorder = None From 0f5374f2c2665b53cbcb76f5d495dc7b2c438d74 Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Mon, 31 Oct 2022 16:16:30 +0100 Subject: [PATCH 061/153] Rename variables in monitor --- stable_baselines3/common/monitor.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/stable_baselines3/common/monitor.py b/stable_baselines3/common/monitor.py index 499753553f..3f26a74dbc 100644 --- a/stable_baselines3/common/monitor.py +++ b/stable_baselines3/common/monitor.py @@ -87,13 +87,13 @@ def step(self, action: Union[np.ndarray, int]) -> Gym26StepReturn: Step the environment with the given action :param action: the action - :return: observation, reward, done, information + :return: observation, reward, terminated, truncated, information """ if self.needs_reset: raise RuntimeError("Tried to step environment that needs reset") - observation, reward, done, truncated, info = self.env.step(action) + observation, reward, terminated, truncated, info = self.env.step(action) self.rewards.append(reward) - if done or truncated: + if terminated or truncated: self.needs_reset = True ep_rew = sum(self.rewards) ep_len = len(self.rewards) @@ -108,7 +108,7 @@ def step(self, action: Union[np.ndarray, int]) -> Gym26StepReturn: self.results_writer.write_row(ep_info) info["episode"] = ep_info self.total_steps += 1 - return observation, reward, done, truncated, info + return observation, reward, terminated, truncated, info def close(self) -> None: """ From 3320e782b7f7497f140ab69dae588d9134313dc6 Mon Sep 17 00:00:00 2001 From: tlpss Date: Thu, 3 Nov 2022 21:17:46 +0100 Subject: [PATCH 062/153] update render method for gym 0.26 API backwards compatible (mode argument is allowed) while using the gym 0.26 API (render mode is determined at environment creation) --- .../common/vec_env/base_vec_env.py | 61 ++++++++++++------- .../common/vec_env/dummy_vec_env.py | 8 +-- .../common/vec_env/subproc_vec_env.py | 20 +++--- tests/test_vec_envs.py | 17 +++++- 4 files changed, 70 insertions(+), 36 deletions(-) diff --git a/stable_baselines3/common/vec_env/base_vec_env.py b/stable_baselines3/common/vec_env/base_vec_env.py index 28ef076849..98a398bb34 100644 --- a/stable_baselines3/common/vec_env/base_vec_env.py +++ b/stable_baselines3/common/vec_env/base_vec_env.py @@ -1,6 +1,6 @@ import inspect -import warnings from abc import ABC, abstractmethod +from multiprocessing.sharedctypes import Value from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, Type, Union import cloudpickle @@ -55,10 +55,17 @@ class VecEnv(ABC): metadata = {"render.modes": ["human", "rgb_array"]} - def __init__(self, num_envs: int, observation_space: gym.spaces.Space, action_space: gym.spaces.Space): + def __init__( + self, + num_envs: int, + observation_space: gym.spaces.Space, + action_space: gym.spaces.Space, + render_mode: Optional[str] = None, + ): self.num_envs = num_envs self.observation_space = observation_space self.action_space = action_space + self.render_mode = render_mode self.reset_infos = [{} for _ in range(num_envs)] # store info returned by the reset method @abstractmethod @@ -162,35 +169,45 @@ def step(self, actions: np.ndarray) -> VecEnvStepReturn: self.step_async(actions) return self.step_wait() - def get_images(self) -> Sequence[np.ndarray]: + def get_render_output(self) -> Sequence[Optional[np.ndarray]]: """ - Return RGB images from each environment + Return Render output from each environment """ raise NotImplementedError - def render(self, mode: str = "human") -> Optional[np.ndarray]: + def render(self, mode: Optional[str] = None) -> Optional[np.ndarray]: """ Gym environment rendering :param mode: the rendering type """ - try: - imgs = self.get_images() - except NotImplementedError: - warnings.warn(f"Render not defined for {self}") - return - # Create a big image by tiling images from subprocesses - bigimg = tile_images(imgs) - if mode == "human": - import cv2 # pytype:disable=import-error + if mode and self.render_mode != mode: + raise ValueError( + f"""starting from gym v0.26, render modes are determined during the initialization of the environment. + We allow to pass a mode argument to maintain a backwards compatible VecEnv API, but the mode ({mode}) + has to be the same as the environment render mode ({self.render_mode}) whichs is not the case.""" + ) + + mode = self.render_mode - cv2.imshow("vecenv", bigimg[:, :, ::-1]) - cv2.waitKey(1) - elif mode == "rgb_array": + # call the render method of the environments + render_output = self.get_render_output() + + if mode == "rgb_array": + + # Create a big image by tiling images from subprocesses + bigimg = tile_images(render_output) return bigimg + + elif mode == "rgb_array_list": + # TODO: a new 'rgb_array_list' mode has been defined and should be handled. + raise NotImplementedError("This mode has not yet been implemented in Stable Baselines.") + else: - raise NotImplementedError(f"Render mode {mode} is not supported by VecEnvs") + # other render methods are simply ignored. + # for 'human' or None, the render output will be a List of None values + return @abstractmethod def seed(self, seed: Optional[int] = None) -> List[Union[None, int]]: @@ -251,6 +268,7 @@ def __init__( venv: VecEnv, observation_space: Optional[gym.spaces.Space] = None, action_space: Optional[gym.spaces.Space] = None, + render_mode: Optional[str] = None, ): self.venv = venv VecEnv.__init__( @@ -258,6 +276,7 @@ def __init__( num_envs=venv.num_envs, observation_space=observation_space or venv.observation_space, action_space=action_space or venv.action_space, + render_mode=render_mode, ) self.class_attributes = dict(inspect.getmembers(self.__class__)) @@ -278,11 +297,11 @@ def seed(self, seed: Optional[int] = None) -> List[Union[None, int]]: def close(self) -> None: return self.venv.close() - def render(self, mode: str = "human") -> Optional[np.ndarray]: + def render(self, mode: Optional[str] = None) -> Optional[np.ndarray]: return self.venv.render(mode=mode) - def get_images(self) -> Sequence[np.ndarray]: - return self.venv.get_images() + def get_render_output(self) -> Sequence[np.ndarray]: + return self.venv.get_render_output() def get_attr(self, attr_name: str, indices: VecEnvIndices = None) -> List[Any]: return self.venv.get_attr(attr_name, indices) diff --git a/stable_baselines3/common/vec_env/dummy_vec_env.py b/stable_baselines3/common/vec_env/dummy_vec_env.py index 3b26592a2c..7990829721 100644 --- a/stable_baselines3/common/vec_env/dummy_vec_env.py +++ b/stable_baselines3/common/vec_env/dummy_vec_env.py @@ -24,7 +24,7 @@ class DummyVecEnv(VecEnv): def __init__(self, env_fns: List[Callable[[], gym.Env]]): self.envs = [fn() for fn in env_fns] env = self.envs[0] - VecEnv.__init__(self, len(env_fns), env.observation_space, env.action_space) + VecEnv.__init__(self, len(env_fns), env.observation_space, env.action_space, env.render_mode) obs_space = env.observation_space self.keys, shapes, dtypes = obs_space_info(obs_space) @@ -78,10 +78,10 @@ def close(self) -> None: for env in self.envs: env.close() - def get_images(self) -> Sequence[np.ndarray]: - return [env.render(mode="rgb_array") for env in self.envs] + def get_render_output(self) -> Sequence[Optional[np.ndarray]]: + return [env.render() for env in self.envs] - def render(self, mode: str = "human") -> Optional[np.ndarray]: + def render(self, mode: Optional[str] = None) -> Optional[np.ndarray]: """ Gym environment rendering. If there are multiple environments then they are tiled together in one image via ``BaseVecEnv.render()``. diff --git a/stable_baselines3/common/vec_env/subproc_vec_env.py b/stable_baselines3/common/vec_env/subproc_vec_env.py index 8aa3e59e3d..446162d56d 100644 --- a/stable_baselines3/common/vec_env/subproc_vec_env.py +++ b/stable_baselines3/common/vec_env/subproc_vec_env.py @@ -45,13 +45,15 @@ def _worker( # noqa: C901 observation, reset_info = env.reset() remote.send((observation, reset_info)) elif cmd == "render": - remote.send(env.render(data)) + remote.send(env.render()) elif cmd == "close": env.close() remote.close() break elif cmd == "get_spaces": remote.send((env.observation_space, env.action_space)) + elif cmd == "get_render_mode": + remote.send(env.render_mode) elif cmd == "env_method": method = getattr(env, data[0]) remote.send(method(*data[1], **data[2])) @@ -116,7 +118,10 @@ def __init__(self, env_fns: List[Callable[[], gym.Env]], start_method: Optional[ self.remotes[0].send(("get_spaces", None)) observation_space, action_space = self.remotes[0].recv() - VecEnv.__init__(self, len(env_fns), observation_space, action_space) + + self.remotes[0].send(("get_render_mode", None)) + render_mode = self.remotes[0].recv() + VecEnv.__init__(self, len(env_fns), observation_space, action_space, render_mode) def step_async(self, actions: np.ndarray) -> None: for remote, action in zip(self.remotes, actions): @@ -155,13 +160,12 @@ def close(self) -> None: process.join() self.closed = True - def get_images(self) -> Sequence[np.ndarray]: + def get_render_output(self) -> Sequence[Optional[np.ndarray]]: for pipe in self.remotes: - # gather images from subprocesses - # `mode` will be taken into account later - pipe.send(("render", "rgb_array")) - imgs = [pipe.recv() for pipe in self.remotes] - return imgs + # gather render return from subprocesses + pipe.send(("render", None)) + outputs = [pipe.recv() for pipe in self.remotes] + return outputs def get_attr(self, attr_name: str, indices: VecEnvIndices = None) -> List[Any]: """Return attribute from vectorized environment (see base class).""" diff --git a/tests/test_vec_envs.py b/tests/test_vec_envs.py index be357655e0..d7bd817cdd 100644 --- a/tests/test_vec_envs.py +++ b/tests/test_vec_envs.py @@ -2,6 +2,7 @@ import functools import itertools import multiprocessing +from multiprocessing.sharedctypes import Value from typing import Optional import gym @@ -25,6 +26,7 @@ def __init__(self, space): self.observation_space = space self.current_step = 0 self.ep_length = 4 + self.render_mode = "rgb_array" def reset(self, seed: Optional[int] = None): if seed is not None: @@ -43,8 +45,8 @@ def step(self, action): def _choose_next_state(self): self.state = self.observation_space.sample() - def render(self, mode="human"): - if mode == "rgb_array": + def render(self): + if self.render_mode == "rgb_array": return np.zeros((4, 4, 3)) def seed(self, seed=None): @@ -83,9 +85,18 @@ def make_env(): # Test seed method vec_env.seed(0) + # Test render method call # vec_env.render() # we need a X server to test the "human" mode - vec_env.render(mode="rgb_array") + array_explicit_mode = vec_env.render(mode="rgb_array") + # test render withouth argument (new gym API style) + array_implicit_mode = vec_env.render() + assert np.array_equal(array_implicit_mode, array_explicit_mode) + + # test error if you try different render mode + with pytest.raises(ValueError): + vec_env.render(mode="human") + env_method_results = vec_env.env_method("custom_method", 1, indices=None, dim_1=2) setattr_results = [] # Set current_step to an arbitrary value From 1596ea45605ed7be85a9ac8708960335b777640a Mon Sep 17 00:00:00 2001 From: tlpss Date: Thu, 3 Nov 2022 21:27:11 +0100 Subject: [PATCH 063/153] update tests and docs to new gym render API --- docs/guide/checking_nan.rst | 2 +- docs/guide/custom_env.rst | 2 +- stable_baselines3/common/env_checker.py | 30 ++++++------------- .../common/vec_env/base_vec_env.py | 1 - tests/test_dict_env.py | 2 +- tests/test_env_checker.py | 2 +- tests/test_vec_check_nan.py | 2 +- tests/test_vec_envs.py | 1 - tests/test_vec_extract_dict_obs.py | 2 +- 9 files changed, 15 insertions(+), 29 deletions(-) diff --git a/docs/guide/checking_nan.rst b/docs/guide/checking_nan.rst index 29f9e318c7..4de484b476 100644 --- a/docs/guide/checking_nan.rst +++ b/docs/guide/checking_nan.rst @@ -129,7 +129,7 @@ It will monitor the actions, observations, and rewards, indicating what action o def reset(self): return [0.0] - def render(self, mode="human", close=False): + def render(self, close=False): pass # Create environment diff --git a/docs/guide/custom_env.rst b/docs/guide/custom_env.rst index 9fbb5277ca..dfd491a46d 100644 --- a/docs/guide/custom_env.rst +++ b/docs/guide/custom_env.rst @@ -45,7 +45,7 @@ That is to say, your environment must implement the following methods (and inher def reset(self): ... return observation # reward, done, info can't be included - def render(self, mode="human"): + def render(self): ... def close (self): ... diff --git a/stable_baselines3/common/env_checker.py b/stable_baselines3/common/env_checker.py index c383c8edab..f1840a191c 100644 --- a/stable_baselines3/common/env_checker.py +++ b/stable_baselines3/common/env_checker.py @@ -285,9 +285,9 @@ def _check_spaces(env: gym.Env) -> None: # Check render cannot be covered by CI -def _check_render(env: gym.Env, warn: bool = True, headless: bool = False) -> None: # pragma: no cover +def _check_render(env: gym.Env) -> None: # pragma: no cover """ - Check the declared render modes and the `render()`/`close()` + Check the instantiated render mode (if any) by calling the `render()`/`close()` method of the environment. :param env: The environment to check @@ -295,24 +295,12 @@ def _check_render(env: gym.Env, warn: bool = True, headless: bool = False) -> No :param headless: Whether to disable render modes that require a graphical interface. False by default. """ - render_modes = env.metadata.get("render.modes") - if render_modes is None: - if warn: - warnings.warn( - "No render modes was declared in the environment " - " (env.metadata['render.modes'] is None or not defined), " - "you may have trouble when calling `.render()`" - ) - - else: - # Don't check render mode that require a - # graphical interface (useful for CI) - if headless and "human" in render_modes: - render_modes.remove("human") - # Check all declared render modes - for render_mode in render_modes: - env.render(mode=render_mode) - env.close() + # render_modes = env.metadata.get("render.modes") + # TODO: if we want to check all render modes, + # we need to initialize new environments so the class should be passed as argument. + if env.render_mode: + env.render() + env.close() def check_env(env: gym.Env, warn: bool = True, skip_render_check: bool = True) -> None: @@ -377,7 +365,7 @@ def check_env(env: gym.Env, warn: bool = True, skip_render_check: bool = True) - # ==== Check the render method and the declared render modes ==== if not skip_render_check: - _check_render(env, warn=warn) # pragma: no cover + _check_render(env) # pragma: no cover # The check only works with numpy arrays if _is_numpy_array_space(observation_space) and _is_numpy_array_space(action_space): diff --git a/stable_baselines3/common/vec_env/base_vec_env.py b/stable_baselines3/common/vec_env/base_vec_env.py index 98a398bb34..4985e57258 100644 --- a/stable_baselines3/common/vec_env/base_vec_env.py +++ b/stable_baselines3/common/vec_env/base_vec_env.py @@ -1,6 +1,5 @@ import inspect from abc import ABC, abstractmethod -from multiprocessing.sharedctypes import Value from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, Type, Union import cloudpickle diff --git a/tests/test_dict_env.py b/tests/test_dict_env.py index 1f832bfeb8..8e1a08a524 100644 --- a/tests/test_dict_env.py +++ b/tests/test_dict_env.py @@ -79,7 +79,7 @@ def reset(self, seed: Optional[int] = None): self.observation_space.seed(seed) return self.observation_space.sample(), {} - def render(self, mode="human"): + def render(self): pass diff --git a/tests/test_env_checker.py b/tests/test_env_checker.py index 2313defc9f..034ba498d8 100644 --- a/tests/test_env_checker.py +++ b/tests/test_env_checker.py @@ -21,7 +21,7 @@ def step(self, action): def reset(self): return np.array([1.0, 1.5, 0.5], dtype=self.observation_space.dtype), {} - def render(self, mode="human"): + def render(self): pass diff --git a/tests/test_vec_check_nan.py b/tests/test_vec_check_nan.py index f09a7aae4a..10ecdac99b 100644 --- a/tests/test_vec_check_nan.py +++ b/tests/test_vec_check_nan.py @@ -30,7 +30,7 @@ def step(action): def reset(): return [0.0], {} - def render(self, mode="human", close=False): + def render(self, close=False): pass diff --git a/tests/test_vec_envs.py b/tests/test_vec_envs.py index d7bd817cdd..99fa33565d 100644 --- a/tests/test_vec_envs.py +++ b/tests/test_vec_envs.py @@ -2,7 +2,6 @@ import functools import itertools import multiprocessing -from multiprocessing.sharedctypes import Value from typing import Optional import gym diff --git a/tests/test_vec_extract_dict_obs.py b/tests/test_vec_extract_dict_obs.py index 15074425ed..6aa4abdbde 100644 --- a/tests/test_vec_extract_dict_obs.py +++ b/tests/test_vec_extract_dict_obs.py @@ -29,7 +29,7 @@ def step_wait(self): def reset(self): return {"rgb": np.zeros((self.num_envs, 86, 86))} - def render(self, mode="human", close=False): + def render(self, close=False): pass From 008fdce185842de6b20d6610fcf1b0e4428fe4cd Mon Sep 17 00:00:00 2001 From: tlpss Date: Mon, 7 Nov 2022 22:17:44 +0100 Subject: [PATCH 064/153] undo removal of render modes metatadata check --- stable_baselines3/common/env_checker.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/stable_baselines3/common/env_checker.py b/stable_baselines3/common/env_checker.py index f1840a191c..808cdf6f3a 100644 --- a/stable_baselines3/common/env_checker.py +++ b/stable_baselines3/common/env_checker.py @@ -285,7 +285,7 @@ def _check_spaces(env: gym.Env) -> None: # Check render cannot be covered by CI -def _check_render(env: gym.Env) -> None: # pragma: no cover +def _check_render(env: gym.Env, warn: bool = False) -> None: # pragma: no cover """ Check the instantiated render mode (if any) by calling the `render()`/`close()` method of the environment. @@ -295,8 +295,16 @@ def _check_render(env: gym.Env) -> None: # pragma: no cover :param headless: Whether to disable render modes that require a graphical interface. False by default. """ - # render_modes = env.metadata.get("render.modes") - # TODO: if we want to check all render modes, + render_modes = env.metadata.get("render.modes") + if render_modes is None: + if warn: + warnings.warn( + "No render modes was declared in the environment " + " (env.metadata['render.modes'] is None or not defined), " + "you may have trouble when calling `.render()`" + ) + + # TODO: if we want to check all declared render modes, # we need to initialize new environments so the class should be passed as argument. if env.render_mode: env.render() @@ -365,7 +373,7 @@ def check_env(env: gym.Env, warn: bool = True, skip_render_check: bool = True) - # ==== Check the render method and the declared render modes ==== if not skip_render_check: - _check_render(env) # pragma: no cover + _check_render(env, warn) # pragma: no cover # The check only works with numpy arrays if _is_numpy_array_space(observation_space) and _is_numpy_array_space(action_space): From 93bd9888d2e5d08f3f0ff12c9423c0301614fa2a Mon Sep 17 00:00:00 2001 From: tlpss Date: Mon, 7 Nov 2022 22:54:38 +0100 Subject: [PATCH 065/153] set rgb_array as default render mode for gym.make --- stable_baselines3/common/env_util.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/stable_baselines3/common/env_util.py b/stable_baselines3/common/env_util.py index bb440f99c5..119eb4439a 100644 --- a/stable_baselines3/common/env_util.py +++ b/stable_baselines3/common/env_util.py @@ -74,6 +74,7 @@ def make_vec_env( :return: The wrapped environment """ env_kwargs = {} if env_kwargs is None else env_kwargs + vec_env_kwargs = {} if vec_env_kwargs is None else vec_env_kwargs monitor_kwargs = {} if monitor_kwargs is None else monitor_kwargs wrapper_kwargs = {} if wrapper_kwargs is None else wrapper_kwargs @@ -81,7 +82,10 @@ def make_vec_env( def make_env(rank): def _init(): if isinstance(env_id, str): - env = gym.make(env_id, **env_kwargs) + # if the render mode was not specified, we set it to `rgb_array` as default. + kwargs = {"render_mode": "rgb_array"} + kwargs.update(env_kwargs) + env = gym.make(env_id, **kwargs) else: env = env_id(**env_kwargs) if seed is not None: From 53da2d009b7e61f548d1ef4ef8592dda83af2886 Mon Sep 17 00:00:00 2001 From: tlpss Date: Sun, 20 Nov 2022 17:21:17 +0100 Subject: [PATCH 066/153] undo changes & raise warning if not 'rgb_array' --- stable_baselines3/common/vec_env/base_vec_env.py | 11 +++++------ stable_baselines3/common/vec_env/dummy_vec_env.py | 6 +++++- stable_baselines3/common/vec_env/subproc_vec_env.py | 2 +- 3 files changed, 11 insertions(+), 8 deletions(-) diff --git a/stable_baselines3/common/vec_env/base_vec_env.py b/stable_baselines3/common/vec_env/base_vec_env.py index 4985e57258..4165e4591f 100644 --- a/stable_baselines3/common/vec_env/base_vec_env.py +++ b/stable_baselines3/common/vec_env/base_vec_env.py @@ -168,7 +168,7 @@ def step(self, actions: np.ndarray) -> VecEnvStepReturn: self.step_async(actions) return self.step_wait() - def get_render_output(self) -> Sequence[Optional[np.ndarray]]: + def get_images(self) -> Sequence[Optional[np.ndarray]]: """ Return Render output from each environment """ @@ -191,12 +191,11 @@ def render(self, mode: Optional[str] = None) -> Optional[np.ndarray]: mode = self.render_mode # call the render method of the environments - render_output = self.get_render_output() + images = self.get_images() if mode == "rgb_array": - # Create a big image by tiling images from subprocesses - bigimg = tile_images(render_output) + bigimg = tile_images(images) return bigimg elif mode == "rgb_array_list": @@ -299,8 +298,8 @@ def close(self) -> None: def render(self, mode: Optional[str] = None) -> Optional[np.ndarray]: return self.venv.render(mode=mode) - def get_render_output(self) -> Sequence[np.ndarray]: - return self.venv.get_render_output() + def get_images(self) -> Sequence[np.ndarray]: + return self.venv.get_images() def get_attr(self, attr_name: str, indices: VecEnvIndices = None) -> List[Any]: return self.venv.get_attr(attr_name, indices) diff --git a/stable_baselines3/common/vec_env/dummy_vec_env.py b/stable_baselines3/common/vec_env/dummy_vec_env.py index 7990829721..c1067740e9 100644 --- a/stable_baselines3/common/vec_env/dummy_vec_env.py +++ b/stable_baselines3/common/vec_env/dummy_vec_env.py @@ -78,7 +78,11 @@ def close(self) -> None: for env in self.envs: env.close() - def get_render_output(self) -> Sequence[Optional[np.ndarray]]: + def get_images(self) -> Sequence[Optional[np.ndarray]]: + if self.envs[0].render_mode != "rgb_array": + raise RuntimeWarning( + "The render mode is {self.envs[0].render_mode}, but this method assumes it is `rgb_array` to obtain images." + ) return [env.render() for env in self.envs] def render(self, mode: Optional[str] = None) -> Optional[np.ndarray]: diff --git a/stable_baselines3/common/vec_env/subproc_vec_env.py b/stable_baselines3/common/vec_env/subproc_vec_env.py index 446162d56d..2d8c7e3500 100644 --- a/stable_baselines3/common/vec_env/subproc_vec_env.py +++ b/stable_baselines3/common/vec_env/subproc_vec_env.py @@ -160,7 +160,7 @@ def close(self) -> None: process.join() self.closed = True - def get_render_output(self) -> Sequence[Optional[np.ndarray]]: + def get_images(self) -> Sequence[Optional[np.ndarray]]: for pipe in self.remotes: # gather render return from subprocesses pipe.send(("render", None)) From 54fb37e09e0c8de82ea6097b449dd362e2e75187 Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Mon, 28 Nov 2022 16:45:50 +0100 Subject: [PATCH 067/153] Fix type check --- Makefile | 6 ++++ stable_baselines3/common/env_checker.py | 28 +++++++++++-------- .../common/envs/bit_flipping_env.py | 4 ++- stable_baselines3/common/envs/identity_env.py | 4 +-- .../common/envs/multi_input_envs.py | 2 +- stable_baselines3/common/evaluation.py | 9 ++++-- stable_baselines3/common/noise.py | 2 +- tests/test_dict_env.py | 4 +-- tests/test_gae.py | 6 ++-- tests/test_spaces.py | 6 ++-- tests/test_vec_envs.py | 4 +-- tests/test_vec_normalize.py | 8 +++--- 12 files changed, 51 insertions(+), 32 deletions(-) diff --git a/Makefile b/Makefile index a1ef72f04b..d2c5830461 100644 --- a/Makefile +++ b/Makefile @@ -10,6 +10,12 @@ pytype: mypy: mypy ${LINT_PATHS} +missing-annotations: + mypy --disallow-untyped-calls --disallow-untyped-defs --ignore-missing-imports stable_baselines3 + +# missing docstrings +# pylint -d R,C,W,E -e C0116 stable_baselines3 -j 4 + type: pytype mypy lint: diff --git a/stable_baselines3/common/env_checker.py b/stable_baselines3/common/env_checker.py index c383c8edab..63ad4d3ec7 100644 --- a/stable_baselines3/common/env_checker.py +++ b/stable_baselines3/common/env_checker.py @@ -100,7 +100,7 @@ def _is_goal_env(env: gym.Env) -> bool: return hasattr(env, "compute_reward") -def _check_goal_env_obs(obs: dict, observation_space: spaces.Space, method_name: str) -> None: +def _check_goal_env_obs(obs: dict, observation_space: spaces.Dict, method_name: str) -> None: """ Check that an environment implementing the `compute_rewards()` method (previously known as GoalEnv in gym) contains three elements, @@ -131,7 +131,7 @@ def _check_goal_env_compute_reward( and that the implementation is vectorized. """ achieved_goal, desired_goal = obs["achieved_goal"], obs["desired_goal"] - assert reward == env.compute_reward( + assert reward == env.compute_reward( # type: ignore[attr-defined] achieved_goal, desired_goal, info ), "The reward was not computed with `compute_reward()`" @@ -142,7 +142,7 @@ def _check_goal_env_compute_reward( batch_achieved_goals = batch_achieved_goals.reshape(2, 1) batch_desired_goals = batch_desired_goals.reshape(2, 1) batch_infos = np.array([info, info]) - rewards = env.compute_reward(batch_achieved_goals, batch_desired_goals, batch_infos) + rewards = env.compute_reward(batch_achieved_goals, batch_desired_goals, batch_infos) # type: ignore[attr-defined] assert rewards.shape == (2,), f"Unexpected shape for vectorized computation of reward: {rewards.shape} != (2,)" assert rewards[0] == reward, f"Vectorized computation of reward differs from single computation: {rewards[0]} != {reward}" @@ -204,6 +204,8 @@ def _check_returned_values(env: gym.Env, observation_space: spaces.Space, action assert isinstance(info, dict), "The second element of the tuple return by `reset()` must be a dictionary" if _is_goal_env(env): + # Make mypy happy, already checked + assert isinstance(observation_space, spaces.Dict) _check_goal_env_obs(obs, observation_space, "reset") elif isinstance(observation_space, spaces.Dict): assert isinstance(obs, dict), "The observation returned by `reset()` must be a dictionary" @@ -232,6 +234,8 @@ def _check_returned_values(env: gym.Env, observation_space: spaces.Space, action obs, reward, terminated, truncated, info = data if _is_goal_env(env): + # Make mypy happy, already checked + assert isinstance(observation_space, spaces.Dict) _check_goal_env_obs(obs, observation_space, "step") _check_goal_env_compute_reward(obs, env, reward, info) elif isinstance(observation_space, spaces.Dict): @@ -305,14 +309,16 @@ def _check_render(env: gym.Env, warn: bool = True, headless: bool = False) -> No ) else: - # Don't check render mode that require a - # graphical interface (useful for CI) - if headless and "human" in render_modes: - render_modes.remove("human") - # Check all declared render modes - for render_mode in render_modes: - env.render(mode=render_mode) - env.close() + # FIXME: render check need to be updated + # # Don't check render mode that require a + # # graphical interface (useful for CI) + # if headless and "human" in render_modes: + # render_modes.remove("human") + # # Check all declared render modes + # for render_mode in render_modes: + # env.render(mode=render_mode) + # env.close() + pass def check_env(env: gym.Env, warn: bool = True, skip_render_check: bool = True) -> None: diff --git a/stable_baselines3/common/envs/bit_flipping_env.py b/stable_baselines3/common/envs/bit_flipping_env.py index 82cc1a0047..01e09357fd 100644 --- a/stable_baselines3/common/envs/bit_flipping_env.py +++ b/stable_baselines3/common/envs/bit_flipping_env.py @@ -157,7 +157,9 @@ def _get_obs(self) -> Dict[str, Union[int, np.ndarray]]: ] ) - def reset(self, seed: Optional[int] = None) -> Tuple[Dict[str, Union[int, np.ndarray]], Dict]: + def reset( + self, *, seed: Optional[int] = None, options: Optional[Dict] = None + ) -> Tuple[Dict[str, Union[int, np.ndarray]], Dict]: if seed is not None: self.obs_space.seed(seed) self.current_step = 0 diff --git a/stable_baselines3/common/envs/identity_env.py b/stable_baselines3/common/envs/identity_env.py index fc5b200fe0..10982d4935 100644 --- a/stable_baselines3/common/envs/identity_env.py +++ b/stable_baselines3/common/envs/identity_env.py @@ -32,7 +32,7 @@ def __init__(self, dim: Optional[int] = None, space: Optional[Space] = None, ep_ self.num_resets = -1 # Becomes 0 after __init__ exits. self.reset() - def reset(self, seed: Optional[int] = None) -> Gym26ResetReturn: + def reset(self, *, seed: Optional[int] = None, options: Optional[Dict] = None) -> Gym26ResetReturn: if seed is not None: super().reset(seed=seed) self.current_step = 0 @@ -140,7 +140,7 @@ def __init__( self.ep_length = 10 self.current_step = 0 - def reset(self, seed: Optional[int] = None) -> Tuple[np.ndarray, Dict]: + def reset(self, *, seed: Optional[int] = None, options: Optional[Dict] = None) -> Tuple[np.ndarray, Dict]: if seed is not None: super().reset(seed=seed) self.current_step = 0 diff --git a/stable_baselines3/common/envs/multi_input_envs.py b/stable_baselines3/common/envs/multi_input_envs.py index 36af4e4405..439626ca60 100644 --- a/stable_baselines3/common/envs/multi_input_envs.py +++ b/stable_baselines3/common/envs/multi_input_envs.py @@ -167,7 +167,7 @@ def render(self, mode: str = "human") -> None: """ print(self.log) - def reset(self, seed: Optional[int] = None) -> Tuple[Dict[str, np.ndarray], Dict]: + def reset(self, *, seed: Optional[int] = None, options: Optional[Dict] = None) -> Tuple[Dict[str, np.ndarray], Dict]: """ Resets the environment state and step count and returns reset observation. diff --git a/stable_baselines3/common/evaluation.py b/stable_baselines3/common/evaluation.py index ff18137853..0a4e498e3d 100644 --- a/stable_baselines3/common/evaluation.py +++ b/stable_baselines3/common/evaluation.py @@ -59,7 +59,7 @@ def evaluate_policy( from stable_baselines3.common.monitor import Monitor if not isinstance(env, VecEnv): - env = DummyVecEnv([lambda: env]) + env = DummyVecEnv([lambda: env]) # type: ignore[list-item, return-value] is_monitor_wrapped = is_vecenv_wrapped(env, VecMonitor) or env.env_is_wrapped(Monitor)[0] @@ -85,7 +85,12 @@ def evaluate_policy( states = None episode_starts = np.ones((env.num_envs,), dtype=bool) while (episode_counts < episode_count_targets).any(): - actions, states = model.predict(observations, state=states, episode_start=episode_starts, deterministic=deterministic) + actions, states = model.predict( + observations, # type: ignore[arg-type] + state=states, + episode_start=episode_starts, + deterministic=deterministic, + ) observations, rewards, dones, infos = env.step(actions) current_rewards += rewards current_lengths += 1 diff --git a/stable_baselines3/common/noise.py b/stable_baselines3/common/noise.py index baa72e9a7e..5e8632dffb 100644 --- a/stable_baselines3/common/noise.py +++ b/stable_baselines3/common/noise.py @@ -10,7 +10,7 @@ class ActionNoise(ABC): The action noise base class """ - def __init__(self): + def __init__(self) -> None: super().__init__() def reset(self) -> None: diff --git a/tests/test_dict_env.py b/tests/test_dict_env.py index 1f832bfeb8..650a032fc0 100644 --- a/tests/test_dict_env.py +++ b/tests/test_dict_env.py @@ -1,4 +1,4 @@ -from typing import Optional +from typing import Dict, Optional import gym import numpy as np @@ -74,7 +74,7 @@ def step(self, action): def compute_reward(self, achieved_goal, desired_goal, info): return np.zeros((len(achieved_goal),)) - def reset(self, seed: Optional[int] = None): + def reset(self, *, seed: Optional[int] = None, options: Optional[Dict] = None): if seed is not None: self.observation_space.seed(seed) return self.observation_space.sample(), {} diff --git a/tests/test_gae.py b/tests/test_gae.py index 35f01ac77b..7aed1d01df 100644 --- a/tests/test_gae.py +++ b/tests/test_gae.py @@ -1,4 +1,4 @@ -from typing import Optional +from typing import Dict, Optional import gym import numpy as np @@ -21,7 +21,7 @@ def __init__(self, max_steps=8): def seed(self, seed): self.observation_space.seed(seed) - def reset(self, seed: Optional[int] = None): + def reset(self, *, seed: Optional[int] = None, options: Optional[Dict] = None): if seed is not None: self.observation_space.seed(seed) self.n_steps = 0 @@ -51,7 +51,7 @@ def __init__(self, n_states=4): self.action_space = gym.spaces.Box(low=-1, high=1, shape=(2,), dtype=np.float32) self.current_state = 0 - def reset(self, seed: Optional[int] = None): + def reset(self, *, seed: Optional[int] = None, options: Optional[Dict] = None): if seed is not None: super().reset(seed=seed) diff --git a/tests/test_spaces.py b/tests/test_spaces.py index 2f9c796068..81fb2367b0 100644 --- a/tests/test_spaces.py +++ b/tests/test_spaces.py @@ -1,4 +1,4 @@ -from typing import Optional +from typing import Dict, Optional import gym import numpy as np @@ -15,7 +15,7 @@ def __init__(self, nvec): self.observation_space = gym.spaces.MultiDiscrete(nvec) self.action_space = gym.spaces.Box(low=-1, high=1, shape=(2,), dtype=np.float32) - def reset(self, seed: Optional[int] = None): + def reset(self, *, seed: Optional[int] = None, options: Optional[Dict] = None): if seed is not None: super().reset(seed=seed) return self.observation_space.sample(), {} @@ -30,7 +30,7 @@ def __init__(self, n): self.observation_space = gym.spaces.MultiBinary(n) self.action_space = gym.spaces.Box(low=-1, high=1, shape=(2,), dtype=np.float32) - def reset(self, seed: Optional[int] = None): + def reset(self, *, seed: Optional[int] = None, options: Optional[Dict] = None): if seed is not None: super().reset(seed=seed) return self.observation_space.sample(), {} diff --git a/tests/test_vec_envs.py b/tests/test_vec_envs.py index 76fdf61914..b4eecf736d 100644 --- a/tests/test_vec_envs.py +++ b/tests/test_vec_envs.py @@ -2,7 +2,7 @@ import functools import itertools import multiprocessing -from typing import Optional +from typing import Dict, Optional import gym import numpy as np @@ -26,7 +26,7 @@ def __init__(self, space): self.current_step = 0 self.ep_length = 4 - def reset(self, seed: Optional[int] = None): + def reset(self, *, seed: Optional[int] = None, options: Optional[Dict] = None): if seed is not None: self.seed(seed) self.current_step = 0 diff --git a/tests/test_vec_normalize.py b/tests/test_vec_normalize.py index f2876f2e1f..a0239a44b8 100644 --- a/tests/test_vec_normalize.py +++ b/tests/test_vec_normalize.py @@ -1,5 +1,5 @@ import operator -from typing import Optional +from typing import Dict, Optional import gym import numpy as np @@ -37,7 +37,7 @@ def step(self, action): done = truncated = self.t == len(self.returned_rewards) return np.array([returned_value]), returned_value, done, truncated, {} - def reset(self, seed: Optional[int] = None): + def reset(self, *, seed: Optional[int] = None, options: Optional[Dict] = None): if seed is not None: super().reset(seed=seed) self.t = 0 @@ -60,7 +60,7 @@ def __init__(self): ) self.action_space = spaces.Box(low=-1, high=1, shape=(3,), dtype=np.float32) - def reset(self, seed: Optional[int] = None): + def reset(self, *, seed: Optional[int] = None, options: Optional[Dict] = None): if seed is not None: super().reset(seed=seed) return self.observation_space.sample(), {} @@ -92,7 +92,7 @@ def __init__(self): ) self.action_space = spaces.Box(low=-1, high=1, shape=(3,), dtype=np.float32) - def reset(self, seed: Optional[int] = None): + def reset(self, *, seed: Optional[int] = None, options: Optional[Dict] = None): if seed is not None: super().reset(seed=seed) return self.observation_space.sample(), {} From 07ca2711f17ae81fbeecf31c554226af792a80fe Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Mon, 28 Nov 2022 19:15:45 +0100 Subject: [PATCH 068/153] Remove recursion and fix type checking --- stable_baselines3/common/env_checker.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/stable_baselines3/common/env_checker.py b/stable_baselines3/common/env_checker.py index 9975de6edb..4362e282be 100644 --- a/stable_baselines3/common/env_checker.py +++ b/stable_baselines3/common/env_checker.py @@ -97,9 +97,8 @@ def _is_goal_env(env: gym.Env) -> bool: """ Check if the env uses the convention for goal-conditioned envs (previously, the gym.GoalEnv interface) """ - if isinstance(env, gym.Wrapper): # We need to unwrap the env since gym.Wrapper has the compute_reward method - return _is_goal_env(env.unwrapped) - return hasattr(env, "compute_reward") + # We need to unwrap the env since gym.Wrapper has the compute_reward method + return hasattr(env.unwrapped, "compute_reward") def _check_goal_env_obs(obs: dict, observation_space: spaces.Dict, method_name: str) -> None: @@ -262,6 +261,8 @@ def _check_returned_values(env: gym.Env, observation_space: spaces.Space, action # Goal conditioned env if _is_goal_env(env): + # for mypy, env.unwrapped was checked by _is_goal_env() + assert hasattr(env, "compute_reward") assert reward == env.compute_reward(obs["achieved_goal"], obs["desired_goal"], info) From c0a6a181fbde4d6c1b2c77a281fc1df063af06d6 Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Wed, 30 Nov 2022 00:04:34 +0100 Subject: [PATCH 069/153] Remove hacks for protobuf and gym 0.24 --- Makefile | 3 +-- tests/test_vec_monitor.py | 2 +- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/Makefile b/Makefile index d2c5830461..c70f2f0e10 100644 --- a/Makefile +++ b/Makefile @@ -40,8 +40,7 @@ check-codestyle: commit-checks: format type lint doc: - # Prevent weird error due to protobuf - cd docs && PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=cpp make html + cd docs && make html spelling: cd docs && make spelling diff --git a/tests/test_vec_monitor.py b/tests/test_vec_monitor.py index ea93413acb..ed2ea89bf8 100644 --- a/tests/test_vec_monitor.py +++ b/tests/test_vec_monitor.py @@ -132,7 +132,7 @@ def test_vec_monitor_ppo(recwarn): """ Test the `VecMonitor` with PPO """ - env = DummyVecEnv([lambda: gym.make("CartPole-v1", disable_env_checker=True)]) + env = DummyVecEnv([lambda: gym.make("CartPole-v1")]) env.seed(seed=0) monitor_env = VecMonitor(env) model = PPO("MlpPolicy", monitor_env, verbose=1, n_steps=64, device="cpu") From b954703c25c3c4876ed29994dc53b65c00cd0aa0 Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Wed, 30 Nov 2022 00:33:54 +0100 Subject: [PATCH 070/153] Fix type annotations --- stable_baselines3/common/env_checker.py | 2 +- stable_baselines3/common/env_util.py | 22 ++++++++++++++-------- stable_baselines3/common/torch_layers.py | 6 +++++- 3 files changed, 20 insertions(+), 10 deletions(-) diff --git a/stable_baselines3/common/env_checker.py b/stable_baselines3/common/env_checker.py index 4362e282be..6b87be45d9 100644 --- a/stable_baselines3/common/env_checker.py +++ b/stable_baselines3/common/env_checker.py @@ -126,7 +126,7 @@ def _check_goal_env_compute_reward( env: gym.Env, reward: float, info: Dict[str, Any], -): +) -> None: """ Check that reward is computed with `compute_reward` and that the implementation is vectorized. diff --git a/stable_baselines3/common/env_util.py b/stable_baselines3/common/env_util.py index fc54867cab..2ae4d213e6 100644 --- a/stable_baselines3/common/env_util.py +++ b/stable_baselines3/common/env_util.py @@ -25,7 +25,7 @@ def unwrap_wrapper(env: gym.Env, wrapper_class: Type[gym.Wrapper]) -> Optional[g return None -def is_wrapped(env: Type[gym.Env], wrapper_class: Type[gym.Wrapper]) -> bool: +def is_wrapped(env: gym.Env, wrapper_class: Type[gym.Wrapper]) -> bool: """ Check if a given environment has been wrapped with a given wrapper. @@ -73,13 +73,19 @@ def make_vec_env( :param wrapper_kwargs: Keyword arguments to pass to the ``Wrapper`` class constructor. :return: The wrapped environment """ - env_kwargs = {} if env_kwargs is None else env_kwargs - vec_env_kwargs = {} if vec_env_kwargs is None else vec_env_kwargs - monitor_kwargs = {} if monitor_kwargs is None else monitor_kwargs - wrapper_kwargs = {} if wrapper_kwargs is None else wrapper_kwargs + env_kwargs = env_kwargs or {} + vec_env_kwargs = vec_env_kwargs or {} + monitor_kwargs = monitor_kwargs or {} + wrapper_kwargs = wrapper_kwargs or {} + assert vec_env_kwargs is not None # for mypy + + def make_env(rank: int) -> Callable[[], gym.Env]: + def _init() -> gym.Env: + # For type checker: + assert monitor_kwargs is not None + assert wrapper_kwargs is not None + assert env_kwargs is not None - def make_env(rank): - def _init(): if isinstance(env_id, str): env = gym.make(env_id, **env_kwargs) else: @@ -91,7 +97,7 @@ def _init(): # to have additional training information monitor_path = os.path.join(monitor_dir, str(rank)) if monitor_dir is not None else None # Create the monitor folder if needed - if monitor_path is not None: + if monitor_path is not None and monitor_dir is not None: os.makedirs(monitor_dir, exist_ok=True) env = Monitor(env, filename=monitor_path, **monitor_kwargs) # Optionally, wrap the environment with the provided wrapper diff --git a/stable_baselines3/common/torch_layers.py b/stable_baselines3/common/torch_layers.py index 5de2af1ce4..1bbdf0c72a 100644 --- a/stable_baselines3/common/torch_layers.py +++ b/stable_baselines3/common/torch_layers.py @@ -57,7 +57,11 @@ class NatureCNN(BaseFeaturesExtractor): This corresponds to the number of unit for the last layer. """ - def __init__(self, observation_space: gym.spaces.Box, features_dim: int = 512): + def __init__(self, observation_space: gym.Space, features_dim: int = 512): + assert isinstance(observation_space, gym.spaces.Box), ( + "NatureCNN must be used with a gym.spaces.Box ", + f"observation space, not {observation_space}", + ) super().__init__(observation_space, features_dim) # We assume CxHxW images (channels first) # Re-ordering will be done by pre-preprocessing or wrapper From 870139c8dd7e3ae240cd73b69e9471ccc5dfb27a Mon Sep 17 00:00:00 2001 From: tlpss Date: Sat, 3 Dec 2022 18:36:22 +0100 Subject: [PATCH 071/153] reuse existing render_mode attribute --- stable_baselines3/common/vec_env/dummy_vec_env.py | 2 +- stable_baselines3/common/vec_env/subproc_vec_env.py | 8 +++++--- tests/test_vec_check_nan.py | 2 +- 3 files changed, 7 insertions(+), 5 deletions(-) diff --git a/stable_baselines3/common/vec_env/dummy_vec_env.py b/stable_baselines3/common/vec_env/dummy_vec_env.py index 664624a9a5..8d68a0c6a8 100644 --- a/stable_baselines3/common/vec_env/dummy_vec_env.py +++ b/stable_baselines3/common/vec_env/dummy_vec_env.py @@ -90,7 +90,7 @@ def close(self) -> None: env.close() def get_images(self) -> Sequence[Optional[np.ndarray]]: - if self.envs[0].render_mode != "rgb_array": + if self.render_mode != "rgb_array": raise RuntimeWarning( "The render mode is {self.envs[0].render_mode}, but this method assumes it is `rgb_array` to obtain images." ) diff --git a/stable_baselines3/common/vec_env/subproc_vec_env.py b/stable_baselines3/common/vec_env/subproc_vec_env.py index 2d8c7e3500..7c0381f1ef 100644 --- a/stable_baselines3/common/vec_env/subproc_vec_env.py +++ b/stable_baselines3/common/vec_env/subproc_vec_env.py @@ -52,8 +52,6 @@ def _worker( # noqa: C901 break elif cmd == "get_spaces": remote.send((env.observation_space, env.action_space)) - elif cmd == "get_render_mode": - remote.send(env.render_mode) elif cmd == "env_method": method = getattr(env, data[0]) remote.send(method(*data[1], **data[2])) @@ -119,7 +117,7 @@ def __init__(self, env_fns: List[Callable[[], gym.Env]], start_method: Optional[ self.remotes[0].send(("get_spaces", None)) observation_space, action_space = self.remotes[0].recv() - self.remotes[0].send(("get_render_mode", None)) + self.remotes[0].send(("get_attr", "render_mode")) render_mode = self.remotes[0].recv() VecEnv.__init__(self, len(env_fns), observation_space, action_space, render_mode) @@ -161,6 +159,10 @@ def close(self) -> None: self.closed = True def get_images(self) -> Sequence[Optional[np.ndarray]]: + if self.render_mode != "rgb_array": + raise RuntimeWarning( + "The render mode is {self.envs[0].render_mode}, but this method assumes it is `rgb_array` to obtain images." + ) for pipe in self.remotes: # gather render return from subprocesses pipe.send(("render", None)) diff --git a/tests/test_vec_check_nan.py b/tests/test_vec_check_nan.py index 10ecdac99b..48b203e891 100644 --- a/tests/test_vec_check_nan.py +++ b/tests/test_vec_check_nan.py @@ -30,7 +30,7 @@ def step(action): def reset(): return [0.0], {} - def render(self, close=False): + def render(self): pass From bc5335ff62e4f240a38e8498a70183ec9dc3f5cb Mon Sep 17 00:00:00 2001 From: tlpss Date: Sat, 3 Dec 2022 18:36:38 +0100 Subject: [PATCH 072/153] return tiled images for 'human' render mode --- stable_baselines3/common/vec_env/base_vec_env.py | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/stable_baselines3/common/vec_env/base_vec_env.py b/stable_baselines3/common/vec_env/base_vec_env.py index 4165e4591f..49c0bc2c1a 100644 --- a/stable_baselines3/common/vec_env/base_vec_env.py +++ b/stable_baselines3/common/vec_env/base_vec_env.py @@ -170,7 +170,7 @@ def step(self, actions: np.ndarray) -> VecEnvStepReturn: def get_images(self) -> Sequence[Optional[np.ndarray]]: """ - Return Render output from each environment + Return RGB images from each environment when available """ raise NotImplementedError @@ -193,7 +193,7 @@ def render(self, mode: Optional[str] = None) -> Optional[np.ndarray]: # call the render method of the environments images = self.get_images() - if mode == "rgb_array": + if mode == "rgb_array" or mode == "human": # Create a big image by tiling images from subprocesses bigimg = tile_images(images) return bigimg @@ -202,11 +202,6 @@ def render(self, mode: Optional[str] = None) -> Optional[np.ndarray]: # TODO: a new 'rgb_array_list' mode has been defined and should be handled. raise NotImplementedError("This mode has not yet been implemented in Stable Baselines.") - else: - # other render methods are simply ignored. - # for 'human' or None, the render output will be a List of None values - return - @abstractmethod def seed(self, seed: Optional[int] = None) -> List[Union[None, int]]: """ From 11cf07f569bfe2f5e08bcfaa45ddfe5a37abd948 Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Tue, 6 Dec 2022 17:46:16 +0100 Subject: [PATCH 073/153] Allow to use opencv for human render, fix typos --- .../common/vec_env/base_vec_env.py | 41 ++++++++++++++----- .../common/vec_env/dummy_vec_env.py | 2 +- .../common/vec_env/subproc_vec_env.py | 2 +- tests/test_vec_envs.py | 12 +++--- 4 files changed, 40 insertions(+), 17 deletions(-) diff --git a/stable_baselines3/common/vec_env/base_vec_env.py b/stable_baselines3/common/vec_env/base_vec_env.py index 49c0bc2c1a..7315f4eb59 100644 --- a/stable_baselines3/common/vec_env/base_vec_env.py +++ b/stable_baselines3/common/vec_env/base_vec_env.py @@ -181,26 +181,47 @@ def render(self, mode: Optional[str] = None) -> Optional[np.ndarray]: :param mode: the rendering type """ - if mode and self.render_mode != mode: + if mode == "human" and self.render_mode != mode: + # Special case, if the render_mode="rgb_array" + # we can still display that image using opencv + assert self.render_mode == "rgb_array", ( + f"You tried to render a VecEnv with mode='{mode}' " + "but the render mode defined when initializing the environment must be " + f"'human' or 'rgb_array', not '{self.render_mode}'." + ) + + elif mode and self.render_mode != mode: raise ValueError( - f"""starting from gym v0.26, render modes are determined during the initialization of the environment. - We allow to pass a mode argument to maintain a backwards compatible VecEnv API, but the mode ({mode}) - has to be the same as the environment render mode ({self.render_mode}) whichs is not the case.""" + f"""Starting from gym v0.26, render modes are determined during the initialization of the environment. + We allow to pass a mode argument to maintain a backwards compatible VecEnv API, but the mode ({mode}) + has to be the same as the environment render mode ({self.render_mode}) which is not the case.""" ) - mode = self.render_mode + mode = mode or self.render_mode - # call the render method of the environments - images = self.get_images() + # TODO: handle the case where mode == self.render_mode == "human" + # In that case, we can try to call `self.env.render()` but it might + # crash for subprocesses (TO BE TESTED) + # if self.render_mode == "human" if mode == "rgb_array" or mode == "human": + # call the render method of the environments + images = self.get_images() # Create a big image by tiling images from subprocesses bigimg = tile_images(images) - return bigimg - elif mode == "rgb_array_list": + if mode == "human": + # Display it using OpenCV + import cv2 # pytype:disable=import-error + + cv2.imshow("vecenv", bigimg[:, :, ::-1]) + cv2.waitKey(1) + else: + return bigimg + + else: # TODO: a new 'rgb_array_list' mode has been defined and should be handled. - raise NotImplementedError("This mode has not yet been implemented in Stable Baselines.") + raise NotImplementedError(f"The render mode {mode} has not yet been implemented in Stable Baselines.") @abstractmethod def seed(self, seed: Optional[int] = None) -> List[Union[None, int]]: diff --git a/stable_baselines3/common/vec_env/dummy_vec_env.py b/stable_baselines3/common/vec_env/dummy_vec_env.py index 8d68a0c6a8..b1aa68f100 100644 --- a/stable_baselines3/common/vec_env/dummy_vec_env.py +++ b/stable_baselines3/common/vec_env/dummy_vec_env.py @@ -92,7 +92,7 @@ def close(self) -> None: def get_images(self) -> Sequence[Optional[np.ndarray]]: if self.render_mode != "rgb_array": raise RuntimeWarning( - "The render mode is {self.envs[0].render_mode}, but this method assumes it is `rgb_array` to obtain images." + f"The render mode is {self.render_mode}, but this method assumes it is `rgb_array` to obtain images." ) return [env.render() for env in self.envs] diff --git a/stable_baselines3/common/vec_env/subproc_vec_env.py b/stable_baselines3/common/vec_env/subproc_vec_env.py index 7c0381f1ef..b08dcd7bb3 100644 --- a/stable_baselines3/common/vec_env/subproc_vec_env.py +++ b/stable_baselines3/common/vec_env/subproc_vec_env.py @@ -161,7 +161,7 @@ def close(self) -> None: def get_images(self) -> Sequence[Optional[np.ndarray]]: if self.render_mode != "rgb_array": raise RuntimeWarning( - "The render mode is {self.envs[0].render_mode}, but this method assumes it is `rgb_array` to obtain images." + f"The render mode is {self.render_mode}, but this method assumes it is `rgb_array` to obtain images." ) for pipe in self.remotes: # gather render return from subprocesses diff --git a/tests/test_vec_envs.py b/tests/test_vec_envs.py index 4b93d2448f..4833039122 100644 --- a/tests/test_vec_envs.py +++ b/tests/test_vec_envs.py @@ -17,7 +17,7 @@ class CustomGymEnv(gym.Env): - def __init__(self, space): + def __init__(self, space, render_mode: str = "rgb_array"): """ Custom gym environment for testing purposes """ @@ -25,7 +25,7 @@ def __init__(self, space): self.observation_space = space self.current_step = 0 self.ep_length = 4 - self.render_mode = "rgb_array" + self.render_mode = render_mode def reset(self, *, seed: Optional[int] = None, options: Optional[Dict] = None): if seed is not None: @@ -96,15 +96,17 @@ def make_env(): vec_env.seed(0) # Test render method call - # vec_env.render() # we need a X server to test the "human" mode array_explicit_mode = vec_env.render(mode="rgb_array") - # test render withouth argument (new gym API style) + # test render without argument (new gym API style) array_implicit_mode = vec_env.render() assert np.array_equal(array_implicit_mode, array_explicit_mode) # test error if you try different render mode with pytest.raises(ValueError): - vec_env.render(mode="human") + vec_env.render(mode="something_else") + + # we need a X server to test the "human" mode (uses OpenCV) + # vec_env.render(mode="human") env_method_results = vec_env.env_method("custom_method", 1, indices=None, dim_1=2) setattr_results = [] From 3f75a8aedc9ee183238919d7a71734a0ef68872c Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Wed, 7 Dec 2022 10:42:54 +0100 Subject: [PATCH 074/153] Add warning when using non-zero start with Discrete (fixes #1197) --- stable_baselines3/common/env_checker.py | 14 +++++++++++++- tests/test_envs.py | 4 ++++ 2 files changed, 17 insertions(+), 1 deletion(-) diff --git a/stable_baselines3/common/env_checker.py b/stable_baselines3/common/env_checker.py index 6b87be45d9..a229b83258 100644 --- a/stable_baselines3/common/env_checker.py +++ b/stable_baselines3/common/env_checker.py @@ -56,9 +56,15 @@ def _check_unsupported_spaces(env: gym.Env, observation_space: spaces.Space, act if isinstance(observation_space, spaces.Dict): nested_dict = False - for space in observation_space.spaces.values(): + for key, space in observation_space.spaces.items(): if isinstance(space, spaces.Dict): nested_dict = True + if isinstance(space, spaces.Discrete) and space.start != 0: + warnings.warn( + f"Discrete observation space (key '{key}') with a non-zero start is not supported by Stable-Baselines3. " + "You can use a wrapper or update your observation space." + ) + if nested_dict: warnings.warn( "Nested observation spaces are not supported by Stable Baselines3 " @@ -77,6 +83,12 @@ def _check_unsupported_spaces(env: gym.Env, observation_space: spaces.Space, act "which is supported by SB3." ) + if isinstance(observation_space, spaces.Discrete) and observation_space.start != 0: + warnings.warn( + "Discrete observation space with a non-zero start is not supported by Stable-Baselines3. " + "You can use a wrapper or update your observation space." + ) + if not _is_numpy_array_space(action_space): warnings.warn( "The action space is not based off a numpy array. Typically this means it's either a Dict or Tuple space. " diff --git a/tests/test_envs.py b/tests/test_envs.py index 5c00b9473f..30c4593c7c 100644 --- a/tests/test_envs.py +++ b/tests/test_envs.py @@ -121,6 +121,10 @@ def patched_step(_action): spaces.Dict({"position": spaces.Dict({"abs": spaces.Discrete(5), "rel": spaces.Discrete(2)})}), # Small image inside a dict spaces.Dict({"img": spaces.Box(low=0, high=255, shape=(32, 32, 3), dtype=np.uint8)}), + # Non zero start index + spaces.Discrete(3, start=-1), + # Non zero start index inside a Dict + spaces.Dict({"obs": spaces.Discrete(3, start=1)}), ], ) def test_non_default_spaces(new_obs_space): From 6251fdc42daf071d29740fe11ef1841794347453 Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Sun, 18 Dec 2022 19:31:05 +0100 Subject: [PATCH 075/153] Fix type checking --- stable_baselines3/common/atari_wrappers.py | 23 ++++++++++++++-------- stable_baselines3/common/preprocessing.py | 8 ++++---- 2 files changed, 19 insertions(+), 12 deletions(-) diff --git a/stable_baselines3/common/atari_wrappers.py b/stable_baselines3/common/atari_wrappers.py index 05c9888efd..434933c358 100644 --- a/stable_baselines3/common/atari_wrappers.py +++ b/stable_baselines3/common/atari_wrappers.py @@ -28,7 +28,7 @@ def __init__(self, env: gym.Env, noop_max: int = 30) -> None: self.noop_max = noop_max self.override_num_noops = None self.noop_action = 0 - assert env.unwrapped.get_action_meanings()[0] == "NOOP" + assert env.unwrapped.get_action_meanings()[0] == "NOOP" # type: ignore[attr-defined] def reset(self, **kwargs) -> Tuple[np.ndarray, Dict]: self.env.reset(**kwargs) @@ -38,7 +38,7 @@ def reset(self, **kwargs) -> Tuple[np.ndarray, Dict]: noops = self.unwrapped.np_random.integers(1, self.noop_max + 1) assert noops > 0 obs = np.zeros(0) - info = {} + info: Dict = {} for _ in range(noops): obs, _, terminated, truncated, info = self.env.step(self.noop_action) if terminated or truncated: @@ -55,8 +55,8 @@ class FireResetEnv(gym.Wrapper): def __init__(self, env: gym.Env) -> None: super().__init__(env) - assert env.unwrapped.get_action_meanings()[1] == "FIRE" - assert len(env.unwrapped.get_action_meanings()) >= 3 + assert env.unwrapped.get_action_meanings()[1] == "FIRE" # type: ignore[attr-defined] + assert len(env.unwrapped.get_action_meanings()) >= 3 # type: ignore[attr-defined] def reset(self, **kwargs) -> Tuple[np.ndarray, Dict]: self.env.reset(**kwargs) @@ -87,7 +87,7 @@ def step(self, action: int) -> Gym26StepReturn: self.was_real_done = terminated or truncated # check current lives, make loss of life terminal, # then update lives to handle bonus lives - lives = self.env.unwrapped.ale.lives() + lives = self.env.unwrapped.ale.lives() # type: ignore[attr-defined] if 0 < lives < self.lives: # for Qbert sometimes we stay in lives == 0 condition for a few frames # so its important to keep lives > 0, so that we only reset once @@ -110,7 +110,7 @@ def reset(self, **kwargs) -> Tuple[np.ndarray, Dict]: else: # no-op step to advance from terminal/lost life state obs, _, _, _, info = self.env.step(0) - self.lives = self.env.unwrapped.ale.lives() + self.lives = self.env.unwrapped.ale.lives() # type: ignore[attr-defined] return obs, info @@ -125,6 +125,8 @@ class MaxAndSkipEnv(gym.Wrapper): def __init__(self, env: gym.Env, skip: int = 4) -> None: super().__init__(env) # most recent raw observations (for max pooling across time steps) + assert env.observation_space.dtype is not None, "No dtype specified for the observation space" + assert env.observation_space.shape is not None, "No shape defined for the observation space" self._obs_buffer = np.zeros((2,) + env.observation_space.shape, dtype=env.observation_space.dtype) self._skip = skip @@ -192,8 +194,13 @@ def __init__(self, env: gym.Env, width: int = 84, height: int = 84) -> None: super().__init__(env) self.width = width self.height = height + assert isinstance(env.observation_space, spaces.Box), f"Expected Box space, got {env.observation_space}" + self.observation_space = spaces.Box( - low=0, high=255, shape=(self.height, self.width, 1), dtype=env.observation_space.dtype + low=0, + high=255, + shape=(self.height, self.width, 1), + dtype=env.observation_space.dtype, # type: ignore[arg-type] ) def observation(self, frame: np.ndarray) -> np.ndarray: @@ -245,7 +252,7 @@ def __init__( env = MaxAndSkipEnv(env, skip=frame_skip) if terminal_on_life_loss: env = EpisodicLifeEnv(env) - if "FIRE" in env.unwrapped.get_action_meanings(): + if "FIRE" in env.unwrapped.get_action_meanings(): # type: ignore[attr-defined] env = FireResetEnv(env) env = WarpFrame(env, width=screen_size, height=screen_size) if clip_reward: diff --git a/stable_baselines3/common/preprocessing.py b/stable_baselines3/common/preprocessing.py index 76dab705d0..33d43b5b87 100644 --- a/stable_baselines3/common/preprocessing.py +++ b/stable_baselines3/common/preprocessing.py @@ -151,10 +151,7 @@ def get_obs_shape( return (int(len(observation_space.nvec)),) elif isinstance(observation_space, spaces.MultiBinary): # Number of binary features - if type(observation_space.n) in [tuple, list, np.ndarray]: - return tuple(observation_space.n) - else: - return (int(observation_space.n),) + return observation_space.shape elif isinstance(observation_space, spaces.Dict): return {key: get_obs_shape(subspace) for (key, subspace) in observation_space.spaces.items()} # type: ignore[misc] @@ -198,6 +195,9 @@ def get_action_dim(action_space: spaces.Space) -> int: return int(len(action_space.nvec)) elif isinstance(action_space, spaces.MultiBinary): # Number of binary actions + assert isinstance(action_space.n, int), ( + "Multi-dimensional MultiBinary action space is not supported. " "You can flatten it instead." + ) return int(action_space.n) else: raise NotImplementedError(f"{action_space} action space is not supported") From 6b80c9372729ecdaa3f8467b8156535f3d1583df Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Sun, 18 Dec 2022 20:39:42 +0100 Subject: [PATCH 076/153] Bug fixes and handle more cases --- stable_baselines3/common/base_class.py | 6 +++- stable_baselines3/common/env_util.py | 5 +++- .../common/vec_env/base_vec_env.py | 30 ++++++++++++------- .../common/vec_env/dummy_vec_env.py | 10 +------ 4 files changed, 30 insertions(+), 21 deletions(-) diff --git a/stable_baselines3/common/base_class.py b/stable_baselines3/common/base_class.py index 13092296ed..e310e2c6d5 100644 --- a/stable_baselines3/common/base_class.py +++ b/stable_baselines3/common/base_class.py @@ -51,7 +51,11 @@ def maybe_make_env(env: Union[GymEnv, str, None], verbose: int) -> Optional[GymE if isinstance(env, str): if verbose >= 1: print(f"Creating environment from the given name '{env}'") - env = gym.make(env) + # Set render_mode to `rgb_array` as default, so we can record video + try: + env = gym.make(env, render_mode="rgb_array") + except TypeError: + env = gym.make(env) return env diff --git a/stable_baselines3/common/env_util.py b/stable_baselines3/common/env_util.py index dce16f2243..cf10246495 100644 --- a/stable_baselines3/common/env_util.py +++ b/stable_baselines3/common/env_util.py @@ -90,7 +90,10 @@ def _init() -> gym.Env: # if the render mode was not specified, we set it to `rgb_array` as default. kwargs = {"render_mode": "rgb_array"} kwargs.update(env_kwargs) - env = gym.make(env_id, **kwargs) + try: + env = gym.make(env_id, **kwargs) + except TypeError: + env = gym.make(env_id, **env_kwargs) else: env = env_id(**env_kwargs) if seed is not None: diff --git a/stable_baselines3/common/vec_env/base_vec_env.py b/stable_baselines3/common/vec_env/base_vec_env.py index 7315f4eb59..db4bc55909 100644 --- a/stable_baselines3/common/vec_env/base_vec_env.py +++ b/stable_baselines3/common/vec_env/base_vec_env.py @@ -1,4 +1,5 @@ import inspect +import warnings from abc import ABC, abstractmethod from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, Type, Union @@ -184,25 +185,34 @@ def render(self, mode: Optional[str] = None) -> Optional[np.ndarray]: if mode == "human" and self.render_mode != mode: # Special case, if the render_mode="rgb_array" # we can still display that image using opencv - assert self.render_mode == "rgb_array", ( - f"You tried to render a VecEnv with mode='{mode}' " - "but the render mode defined when initializing the environment must be " - f"'human' or 'rgb_array', not '{self.render_mode}'." - ) + if self.render_mode != "rgb_array": + warnings.warn( + f"You tried to render a VecEnv with mode='{mode}' " + "but the render mode defined when initializing the environment must be " + f"'human' or 'rgb_array', not '{self.render_mode}'." + ) + return elif mode and self.render_mode != mode: - raise ValueError( + warnings.warn( f"""Starting from gym v0.26, render modes are determined during the initialization of the environment. We allow to pass a mode argument to maintain a backwards compatible VecEnv API, but the mode ({mode}) has to be the same as the environment render mode ({self.render_mode}) which is not the case.""" ) + return mode = mode or self.render_mode - # TODO: handle the case where mode == self.render_mode == "human" - # In that case, we can try to call `self.env.render()` but it might - # crash for subprocesses (TO BE TESTED) - # if self.render_mode == "human" + if mode is None: + warnings.warn("You tried to call render() but no `render_mode` was passed to the env constructor.") + return + + # mode == self.render_mode == "human" + # In that case, we try to call `self.env.render()` but it might + # crash for subprocesses + if self.render_mode == "human": + self.env_method("render") + return if mode == "rgb_array" or mode == "human": # call the render method of the environments diff --git a/stable_baselines3/common/vec_env/dummy_vec_env.py b/stable_baselines3/common/vec_env/dummy_vec_env.py index b1aa68f100..68a794a1fc 100644 --- a/stable_baselines3/common/vec_env/dummy_vec_env.py +++ b/stable_baselines3/common/vec_env/dummy_vec_env.py @@ -100,18 +100,10 @@ def render(self, mode: Optional[str] = None) -> Optional[np.ndarray]: """ Gym environment rendering. If there are multiple environments then they are tiled together in one image via ``BaseVecEnv.render()``. - Otherwise (if ``self.num_envs == 1``), we pass the render call directly to the - underlying environment. - - Therefore, some arguments such as ``mode`` will have values that are valid - only when ``num_envs == 1``. :param mode: The rendering type. """ - if self.num_envs == 1: - return self.envs[0].render() - else: - return super().render(mode=mode) + return super().render(mode=mode) def _save_obs(self, env_idx: int, obs: VecEnvObs) -> None: for key in self.keys: From c09fa740e8d0a39b6c744a7a27f439984d3825c9 Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Sun, 18 Dec 2022 20:45:10 +0100 Subject: [PATCH 077/153] Throw proper warnings --- stable_baselines3/common/vec_env/dummy_vec_env.py | 4 +++- stable_baselines3/common/vec_env/subproc_vec_env.py | 4 +++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/stable_baselines3/common/vec_env/dummy_vec_env.py b/stable_baselines3/common/vec_env/dummy_vec_env.py index 68a794a1fc..8dcf26c010 100644 --- a/stable_baselines3/common/vec_env/dummy_vec_env.py +++ b/stable_baselines3/common/vec_env/dummy_vec_env.py @@ -1,3 +1,4 @@ +import warnings from collections import OrderedDict from copy import deepcopy from typing import Any, Callable, List, Optional, Sequence, Type, Union @@ -91,9 +92,10 @@ def close(self) -> None: def get_images(self) -> Sequence[Optional[np.ndarray]]: if self.render_mode != "rgb_array": - raise RuntimeWarning( + warnings.warn( f"The render mode is {self.render_mode}, but this method assumes it is `rgb_array` to obtain images." ) + return [None for _ in self.envs] return [env.render() for env in self.envs] def render(self, mode: Optional[str] = None) -> Optional[np.ndarray]: diff --git a/stable_baselines3/common/vec_env/subproc_vec_env.py b/stable_baselines3/common/vec_env/subproc_vec_env.py index b08dcd7bb3..775e30902e 100644 --- a/stable_baselines3/common/vec_env/subproc_vec_env.py +++ b/stable_baselines3/common/vec_env/subproc_vec_env.py @@ -1,4 +1,5 @@ import multiprocessing as mp +import warnings from collections import OrderedDict from typing import Any, Callable, List, Optional, Sequence, Tuple, Type, Union @@ -160,9 +161,10 @@ def close(self) -> None: def get_images(self) -> Sequence[Optional[np.ndarray]]: if self.render_mode != "rgb_array": - raise RuntimeWarning( + warnings.warn( f"The render mode is {self.render_mode}, but this method assumes it is `rgb_array` to obtain images." ) + return [None for _ in self.remotes] for pipe in self.remotes: # gather render return from subprocesses pipe.send(("render", None)) From 480a793ba493281b60acd86fbf8b728550d13853 Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Sun, 18 Dec 2022 20:48:48 +0100 Subject: [PATCH 078/153] Update test --- tests/test_vec_envs.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_vec_envs.py b/tests/test_vec_envs.py index 4833039122..ddd4a37de3 100644 --- a/tests/test_vec_envs.py +++ b/tests/test_vec_envs.py @@ -101,8 +101,8 @@ def make_env(): array_implicit_mode = vec_env.render() assert np.array_equal(array_implicit_mode, array_explicit_mode) - # test error if you try different render mode - with pytest.raises(ValueError): + # test warning if you try different render mode + with pytest.warns(UserWarning): vec_env.render(mode="something_else") # we need a X server to test the "human" mode (uses OpenCV) From e03b885eddc903c1509e2188b6dedbbab5c74166 Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Sun, 18 Dec 2022 20:52:38 +0100 Subject: [PATCH 079/153] Fix new metadata name --- stable_baselines3/common/env_checker.py | 4 ++-- stable_baselines3/common/vec_env/base_vec_env.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/stable_baselines3/common/env_checker.py b/stable_baselines3/common/env_checker.py index d798269ee2..0bd30dcbf1 100644 --- a/stable_baselines3/common/env_checker.py +++ b/stable_baselines3/common/env_checker.py @@ -310,12 +310,12 @@ def _check_render(env: gym.Env, warn: bool = False) -> None: # pragma: no cover :param headless: Whether to disable render modes that require a graphical interface. False by default. """ - render_modes = env.metadata.get("render.modes") + render_modes = env.metadata.get("render_modes") if render_modes is None: if warn: warnings.warn( "No render modes was declared in the environment " - " (env.metadata['render.modes'] is None or not defined), " + "(env.metadata['render_modes'] is None or not defined), " "you may have trouble when calling `.render()`" ) diff --git a/stable_baselines3/common/vec_env/base_vec_env.py b/stable_baselines3/common/vec_env/base_vec_env.py index db4bc55909..3cac2cdfb5 100644 --- a/stable_baselines3/common/vec_env/base_vec_env.py +++ b/stable_baselines3/common/vec_env/base_vec_env.py @@ -53,7 +53,7 @@ class VecEnv(ABC): :param action_space: the action space """ - metadata = {"render.modes": ["human", "rgb_array"]} + metadata = {"render_modes": ["human", "rgb_array"]} def __init__( self, From e4248dfb1c2287987e5d6adff8792d7c663df6d0 Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Sun, 18 Dec 2022 21:07:39 +0100 Subject: [PATCH 080/153] Ignore numpy warnings --- tests/test_vec_monitor.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/test_vec_monitor.py b/tests/test_vec_monitor.py index ed2ea89bf8..ab988b6b40 100644 --- a/tests/test_vec_monitor.py +++ b/tests/test_vec_monitor.py @@ -2,6 +2,7 @@ import json import os import uuid +import warnings import gym import pandas @@ -132,6 +133,7 @@ def test_vec_monitor_ppo(recwarn): """ Test the `VecMonitor` with PPO """ + warnings.filterwarnings(action="ignore", category=DeprecationWarning, module=r".*passive_env_checker") env = DummyVecEnv([lambda: gym.make("CartPole-v1")]) env.seed(seed=0) monitor_env = VecMonitor(env) From 1f8ccbe047667302cf05fc7826d00cc337fb455a Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Sun, 18 Dec 2022 21:24:44 +0100 Subject: [PATCH 081/153] Fixes in vec recorder --- stable_baselines3/common/vec_env/vec_video_recorder.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/stable_baselines3/common/vec_env/vec_video_recorder.py b/stable_baselines3/common/vec_env/vec_video_recorder.py index 54ba4964f6..73f7441b93 100644 --- a/stable_baselines3/common/vec_env/vec_video_recorder.py +++ b/stable_baselines3/common/vec_env/vec_video_recorder.py @@ -48,7 +48,7 @@ def __init__( metadata = temp_env.metadata self.env.metadata = metadata - self.env.render_mode = "rgb_array" + assert self.env.render_mode == "rgb_array", f"The render_mode must be 'rgb_array', not {self.env.render_mode}" self.record_video_trigger = record_video_trigger self.video_recorder = None @@ -111,4 +111,4 @@ def close(self) -> None: self.close_video_recorder() def __del__(self): - self.close() + self.close_video_recorder() From f4e978a6de10a8bbfddaf0d9b07ef145824c68a6 Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Sun, 18 Dec 2022 21:35:00 +0100 Subject: [PATCH 082/153] Global ignore --- setup.cfg | 1 + tests/test_vec_monitor.py | 2 -- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/setup.cfg b/setup.cfg index f383e95431..1b149558e5 100644 --- a/setup.cfg +++ b/setup.cfg @@ -14,6 +14,7 @@ filterwarnings = ignore::DeprecationWarning:tensorboard # Gym warnings ignore::UserWarning:gym + ignore::DeprecationWarning:.*passive_env_checker.* markers = expensive: marks tests as expensive (deselect with '-m "not expensive"') diff --git a/tests/test_vec_monitor.py b/tests/test_vec_monitor.py index ab988b6b40..ed2ea89bf8 100644 --- a/tests/test_vec_monitor.py +++ b/tests/test_vec_monitor.py @@ -2,7 +2,6 @@ import json import os import uuid -import warnings import gym import pandas @@ -133,7 +132,6 @@ def test_vec_monitor_ppo(recwarn): """ Test the `VecMonitor` with PPO """ - warnings.filterwarnings(action="ignore", category=DeprecationWarning, module=r".*passive_env_checker") env = DummyVecEnv([lambda: gym.make("CartPole-v1")]) env.seed(seed=0) monitor_env = VecMonitor(env) From f98903a6ecaf6b23437bcd86ae515f263d3fed19 Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Sun, 18 Dec 2022 21:56:06 +0100 Subject: [PATCH 083/153] Filter local warning too --- tests/test_vec_monitor.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/test_vec_monitor.py b/tests/test_vec_monitor.py index ed2ea89bf8..ab988b6b40 100644 --- a/tests/test_vec_monitor.py +++ b/tests/test_vec_monitor.py @@ -2,6 +2,7 @@ import json import os import uuid +import warnings import gym import pandas @@ -132,6 +133,7 @@ def test_vec_monitor_ppo(recwarn): """ Test the `VecMonitor` with PPO """ + warnings.filterwarnings(action="ignore", category=DeprecationWarning, module=r".*passive_env_checker") env = DummyVecEnv([lambda: gym.make("CartPole-v1")]) env.seed(seed=0) monitor_env = VecMonitor(env) From 29086a5e18381f343ffb598ad099938fb87a5b1d Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Mon, 19 Dec 2022 13:57:17 +0100 Subject: [PATCH 084/153] Monkey patch not needed for gym 26 --- stable_baselines3/__init__.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/stable_baselines3/__init__.py b/stable_baselines3/__init__.py index 680e254536..0775a8ec5d 100644 --- a/stable_baselines3/__init__.py +++ b/stable_baselines3/__init__.py @@ -1,7 +1,5 @@ import os -import numpy as np - from stable_baselines3.a2c import A2C from stable_baselines3.common.utils import get_system_info from stable_baselines3.ddpg import DDPG @@ -11,10 +9,6 @@ from stable_baselines3.sac import SAC from stable_baselines3.td3 import TD3 -# Small monkey patch so gym 0.21 is compatible with numpy >= 1.24 -# TODO: remove when upgrading to gym 0.26 -np.bool = bool # type: ignore[attr-defined] - # Read version from file version_file = os.path.join(os.path.dirname(__file__), "version.txt") with open(version_file) as file_handler: From 6a3f45d77376faad64bb56fd28f7c7101968e020 Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Fri, 23 Dec 2022 22:51:54 +0100 Subject: [PATCH 085/153] Add doc of VecEnv vs Gym API --- docs/guide/vec_envs.rst | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/docs/guide/vec_envs.rst b/docs/guide/vec_envs.rst index b074dad447..ac8b414a38 100644 --- a/docs/guide/vec_envs.rst +++ b/docs/guide/vec_envs.rst @@ -44,6 +44,32 @@ SubprocVecEnv ✔️ ✔️ ✔️ ✔️ ✔️ For more information, see Python's `multiprocessing guidelines `_. +VecEnv API vs Gym API +--------------------- + +For consistency across Stable-Baselines3 (SB3) versions and because of its special requirements and features, +SB3 VecEnv API is not the same as Gym API. +SB3 VecEnv API is actually close to Gym 0.21 API but differs to Gym 0.26+ API: + +- the ``reset()`` method only returns the observation (``obs = vec_env.reset()``) and not a tuple, the info at reset are stored in ``vec_env.reset_infos``. +- only the initial call to ``vec_env.reset()`` is required, environments are reset automatically afterward (and ``reset_infos`` is updated automatically). +- the ``vec_env.step(actions)`` method expects an array as input + (with a batch size corresponding to the number of environments) and returns a 4-tuple (and not a 5-tuple): ``obs, rewards, dones, infos`` instead of ``obs, reward, terminated, truncated, info`` + where ``dones = terminated or truncated`` (for each env). + ``obs, rewards, dones`` are numpy arrays with shape ``(n_envs, shape_for_single_env)`` (so with a batch dimension). + Additional information is passed via the ``infos`` value which is a list of dictionaries. +- at the end of an episode, ``infos[env_idx]["TimeLimit.truncated"] = truncated and not terminated`` + tells the user if an episode was truncated or not: + you should bootstrap when ``infos[env_idx]["TimeLimit.truncated"] is True`` or ``dones[env_idx] is False``. + Note: compared to Gym 0.26+ ``infos[env_idx]["TimeLimit.truncated"]`` and ``terminated`` `are mutually exclusive `_. +- at the end of an episode, because the environment resets automatically, + we provide ``infos[env_idx]["terminal_observation"]`` which contains the last observation + of an episode (and can be used when bootstrapping, see note in the previous section) +- if you pass ``render_mode="rgb_array"`` to your Gym env, a corresponding VecEnv can automatically show the rendered image + by calling ``vec_env.render(mode="human")``. This is different from Gym which currently `doesn't allow multiple render modes `_ + and doesn't allow passing a ``mode`` parameter. Note that if ``render_mode != "rgb_array"``, you can only call ``vec_env.render()`` (without argument or with ``mode=env.render_mode``). + + Vectorized Environments Wrappers -------------------------------- From c645d49525bef790dcbd628af2e2308f8982d995 Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Fri, 23 Dec 2022 23:19:10 +0100 Subject: [PATCH 086/153] Add render test --- tests/test_vec_envs.py | 63 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 63 insertions(+) diff --git a/tests/test_vec_envs.py b/tests/test_vec_envs.py index ddd4a37de3..c85b189bca 100644 --- a/tests/test_vec_envs.py +++ b/tests/test_vec_envs.py @@ -2,12 +2,15 @@ import functools import itertools import multiprocessing +import os +import warnings from typing import Dict, Optional import gym import numpy as np import pytest +from stable_baselines3.common.env_util import make_vec_env from stable_baselines3.common.monitor import Monitor from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv, VecFrameStack, VecNormalize @@ -515,3 +518,63 @@ def make_env(): assert not np.allclose(rewards[1], rewards[2]) vec_env.close() + + +@pytest.mark.parametrize("vec_env_class", VEC_ENV_CLASSES) +def test_render(vec_env_class): + # Skip if no X-Server + if not os.environ.get("DISPLAY"): + pytest.skip("No X-Server") + + env_id = "Pendulum-v1" + # DummyVecEnv human render is currently + # buggy because of gym: + # https://github.com/carlosluis/stable-baselines3/pull/3#issuecomment-1356863808 + n_envs = 2 + # Human render + vec_env = make_vec_env( + env_id, + n_envs, + vec_env_cls=vec_env_class, + env_kwargs=dict(render_mode="human"), + ) + + vec_env.reset() + vec_env.render() + + with pytest.warns(UserWarning): + vec_env.render("rgb_array") + + with pytest.warns(UserWarning): + vec_env.render(mode="blah") + + for _ in range(10): + vec_env.step([vec_env.action_space.sample() for _ in range(n_envs)]) + vec_env.render() + + vec_env.close() + # rgb_array render, which allows human_render + # thanks to OpenCV + vec_env = make_vec_env( + env_id, + n_envs, + vec_env_cls=vec_env_class, + env_kwargs=dict(render_mode="rgb_array"), + ) + + vec_env.reset() + with warnings.catch_warnings(record=True) as record: + vec_env.render() + vec_env.render("rgb_array") + vec_env.render(mode="human") + + # No warnings for using human mode + assert len(record) == 0 + + with pytest.warns(UserWarning): + vec_env.render(mode="blah") + + for _ in range(10): + vec_env.step([vec_env.action_space.sample() for _ in range(n_envs)]) + vec_env.render() + vec_env.close() From 3f6413d310650740316b0993383045cfd4b97721 Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Fri, 23 Dec 2022 23:32:17 +0100 Subject: [PATCH 087/153] Fix return type --- stable_baselines3/common/vec_env/base_vec_env.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/stable_baselines3/common/vec_env/base_vec_env.py b/stable_baselines3/common/vec_env/base_vec_env.py index 3cac2cdfb5..ca2c755c97 100644 --- a/stable_baselines3/common/vec_env/base_vec_env.py +++ b/stable_baselines3/common/vec_env/base_vec_env.py @@ -324,7 +324,7 @@ def close(self) -> None: def render(self, mode: Optional[str] = None) -> Optional[np.ndarray]: return self.venv.render(mode=mode) - def get_images(self) -> Sequence[np.ndarray]: + def get_images(self) -> Sequence[Optional[np.ndarray]]: return self.venv.get_images() def get_attr(self, attr_name: str, indices: VecEnvIndices = None) -> List[Any]: From ff609adfc0a8b2a7eb7620945440a0b9baa449a0 Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Fri, 23 Dec 2022 23:49:38 +0100 Subject: [PATCH 088/153] Update VecEnv vs Gym API doc --- docs/guide/vec_envs.rst | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/docs/guide/vec_envs.rst b/docs/guide/vec_envs.rst index ac8b414a38..3f37efc891 100644 --- a/docs/guide/vec_envs.rst +++ b/docs/guide/vec_envs.rst @@ -68,6 +68,12 @@ SB3 VecEnv API is actually close to Gym 0.21 API but differs to Gym 0.26+ API: - if you pass ``render_mode="rgb_array"`` to your Gym env, a corresponding VecEnv can automatically show the rendered image by calling ``vec_env.render(mode="human")``. This is different from Gym which currently `doesn't allow multiple render modes `_ and doesn't allow passing a ``mode`` parameter. Note that if ``render_mode != "rgb_array"``, you can only call ``vec_env.render()`` (without argument or with ``mode=env.render_mode``). +- the ``reset()`` method doesn't take any parameter. If you want to seed the pseudo-random generator, + you should call ``vec_env.seed(seed=seed)`` and ``obs = vec_env.reset()`` afterward. + If your Gym env implements a ``seed()`` method then it will be called, + otherwise ``env.reset(seed=seed)`` will be called (in that case, you will need two resets to set the seed). +- methods and attributes of the underlying Gym envs can be accessed, called and set using ``vec_env.get_attr("attribute_name")``, + ``vec_env.env_method("method_name", args1, args2, kwargs1=kwargs1)`` and ``vec_env.set_attr("attribute_name", new_value)``. Vectorized Environments Wrappers From b428a271ae7a278870483c94e855fb713da38f32 Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Thu, 5 Jan 2023 16:33:28 +0100 Subject: [PATCH 089/153] Fix for custom render mode --- stable_baselines3/common/vec_env/base_vec_env.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/stable_baselines3/common/vec_env/base_vec_env.py b/stable_baselines3/common/vec_env/base_vec_env.py index 36bff4e79d..50020d9324 100644 --- a/stable_baselines3/common/vec_env/base_vec_env.py +++ b/stable_baselines3/common/vec_env/base_vec_env.py @@ -231,8 +231,10 @@ def render(self, mode: Optional[str] = None) -> Optional[np.ndarray]: return bigimg else: - # TODO: a new 'rgb_array_list' mode has been defined and should be handled. - raise NotImplementedError(f"The render mode {mode} has not yet been implemented in Stable Baselines.") + # Other render modes: + # In that case, we try to call `self.env.render()` but it might + # crash for subprocesses + return self.env_method("render") @abstractmethod def seed(self, seed: Optional[int] = None) -> List[Union[None, int]]: From d98dc9e84c5ae047fd2cbff652cd030cf97bcfe7 Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Thu, 5 Jan 2023 16:41:05 +0100 Subject: [PATCH 090/153] Fix return type --- stable_baselines3/common/vec_env/base_vec_env.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/stable_baselines3/common/vec_env/base_vec_env.py b/stable_baselines3/common/vec_env/base_vec_env.py index 50020d9324..572a7a1328 100644 --- a/stable_baselines3/common/vec_env/base_vec_env.py +++ b/stable_baselines3/common/vec_env/base_vec_env.py @@ -234,7 +234,8 @@ def render(self, mode: Optional[str] = None) -> Optional[np.ndarray]: # Other render modes: # In that case, we try to call `self.env.render()` but it might # crash for subprocesses - return self.env_method("render") + # and we don't return the values + self.env_method("render") @abstractmethod def seed(self, seed: Optional[int] = None) -> List[Union[None, int]]: From cd286669ca1fdcd0016bab45955e44ac5b9005a0 Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Mon, 16 Jan 2023 22:53:21 +0100 Subject: [PATCH 091/153] Fix type checking --- stable_baselines3/common/monitor.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/stable_baselines3/common/monitor.py b/stable_baselines3/common/monitor.py index 9d9884957c..41ceef4be5 100644 --- a/stable_baselines3/common/monitor.py +++ b/stable_baselines3/common/monitor.py @@ -43,9 +43,10 @@ def __init__( self.t_start = time.time() self.results_writer = None if filename is not None: + env_id = env.spec.id if env.spec is not None else None self.results_writer = ResultsWriter( filename, - header={"t_start": self.t_start, "env_id": env.spec and env.spec.id}, + header={"t_start": self.t_start, "env_id": env_id}, extra_keys=reset_keywords + info_keywords, override_existing=override_existing, ) From c9430ecc7036a0578ee8b683641da336a14d37ad Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20GALLOU=C3=89DEC?= Date: Tue, 24 Jan 2023 14:24:22 +0100 Subject: [PATCH 092/153] check test env test_buffer --- tests/test_buffers.py | 18 ++++++++++++++++-- 1 file changed, 16 insertions(+), 2 deletions(-) diff --git a/tests/test_buffers.py b/tests/test_buffers.py index 4bd2d27939..93c3e12823 100644 --- a/tests/test_buffers.py +++ b/tests/test_buffers.py @@ -5,6 +5,7 @@ from gym import spaces from stable_baselines3.common.buffers import DictReplayBuffer, DictRolloutBuffer, ReplayBuffer, RolloutBuffer +from stable_baselines3.common.env_checker import check_env from stable_baselines3.common.env_util import make_vec_env from stable_baselines3.common.type_aliases import DictReplayBufferSamples, ReplayBufferSamples from stable_baselines3.common.utils import get_device @@ -16,10 +17,12 @@ class DummyEnv(gym.Env): Custom gym environment for testing purposes """ + render_mode = None + def __init__(self): self.action_space = spaces.Box(1, 5, (1,)) self.observation_space = spaces.Box(1, 5, (1,)) - self._observations = [1, 2, 3, 4, 5] + self._observations = np.array([[1.0], [2.0], [3.0], [4.0], [5.0]], dtype=np.float32) self._rewards = [1, 2, 3, 4, 5] self._t = 0 self._ep_length = 100 @@ -43,11 +46,13 @@ class DummyDictEnv(gym.Env): Custom gym environment for testing purposes """ + render_mode = None + def __init__(self): self.action_space = spaces.Box(1, 5, (1,)) space = spaces.Box(1, 5, (1,)) self.observation_space = spaces.Dict({"observation": space, "achieved_goal": space, "desired_goal": space}) - self._observations = [1, 2, 3, 4, 5] + self._observations = np.array([[1.0], [2.0], [3.0], [4.0], [5.0]], dtype=np.float32) self._rewards = [1, 2, 3, 4, 5] self._t = 0 self._ep_length = 100 @@ -66,6 +71,15 @@ def step(self, action): return obs, reward, done, truncated, {} +@pytest.mark.parametrize("env_cls", [DummyEnv, DummyDictEnv]) +def test_test_env(env_cls): + # Check the env used for testing + check_env(env_cls()) + + +check_env(DummyEnv()) + + @pytest.mark.parametrize("replay_buffer_cls", [ReplayBuffer, DictReplayBuffer]) def test_replay_buffer_normalization(replay_buffer_cls): env = {ReplayBuffer: DummyEnv, DictReplayBuffer: DummyDictEnv}[replay_buffer_cls] From 546928cea0bb367d845e7ef68052bdffa17424a7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20GALLOU=C3=89DEC?= Date: Tue, 24 Jan 2023 14:44:55 +0100 Subject: [PATCH 093/153] skip render check --- tests/test_buffers.py | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/tests/test_buffers.py b/tests/test_buffers.py index 93c3e12823..5cacb443b7 100644 --- a/tests/test_buffers.py +++ b/tests/test_buffers.py @@ -17,8 +17,6 @@ class DummyEnv(gym.Env): Custom gym environment for testing purposes """ - render_mode = None - def __init__(self): self.action_space = spaces.Box(1, 5, (1,)) self.observation_space = spaces.Box(1, 5, (1,)) @@ -46,8 +44,6 @@ class DummyDictEnv(gym.Env): Custom gym environment for testing purposes """ - render_mode = None - def __init__(self): self.action_space = spaces.Box(1, 5, (1,)) space = spaces.Box(1, 5, (1,)) @@ -74,10 +70,7 @@ def step(self, action): @pytest.mark.parametrize("env_cls", [DummyEnv, DummyDictEnv]) def test_test_env(env_cls): # Check the env used for testing - check_env(env_cls()) - - -check_env(DummyEnv()) + check_env(env_cls(), skip_render_check=True) @pytest.mark.parametrize("replay_buffer_cls", [ReplayBuffer, DictReplayBuffer]) From 8462dbb0ebbb31b33f2fb1c7a1a25eaa518f068d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20GALLOU=C3=89DEC?= Date: Tue, 24 Jan 2023 14:45:20 +0100 Subject: [PATCH 094/153] check env test_dict_env --- tests/test_dict_env.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/tests/test_dict_env.py b/tests/test_dict_env.py index 3eb3b51ce3..31b423c83a 100644 --- a/tests/test_dict_env.py +++ b/tests/test_dict_env.py @@ -7,6 +7,7 @@ from stable_baselines3 import A2C, DDPG, DQN, PPO, SAC, TD3 from stable_baselines3.common.env_util import make_vec_env +from stable_baselines3.common.env_checker import check_env from stable_baselines3.common.envs import BitFlippingEnv, SimpleMultiObsEnv from stable_baselines3.common.evaluation import evaluate_policy from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv, VecFrameStack, VecNormalize @@ -71,9 +72,6 @@ def step(self, action): done = truncated = False return self.observation_space.sample(), reward, done, truncated, {} - def compute_reward(self, achieved_goal, desired_goal, info): - return np.zeros((len(achieved_goal),)) - def reset(self, *, seed: Optional[int] = None, options: Optional[Dict] = None): if seed is not None: self.observation_space.seed(seed) @@ -83,6 +81,15 @@ def render(self): pass +@pytest.mark.parametrize("use_discrete_actions", [True, False]) +@pytest.mark.parametrize("channel_last", [True, False]) +@pytest.mark.parametrize("nested_dict_obs", [True, False]) +@pytest.mark.parametrize("vec_only", [True, False]) +def test_test_env(use_discrete_actions, channel_last, nested_dict_obs, vec_only): + # Check the env used for testing + check_env(DummyDictEnv(use_discrete_actions, channel_last, nested_dict_obs, vec_only)) + + @pytest.mark.parametrize("policy", ["MlpPolicy", "CnnPolicy"]) def test_policy_hint(policy): # Common mistake: using the wrong policy From 205b987a2dd83a629a527d4f3735d65f94af74e3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20GALLOU=C3=89DEC?= Date: Tue, 24 Jan 2023 14:45:53 +0100 Subject: [PATCH 095/153] test_env test_gae --- tests/test_gae.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/tests/test_gae.py b/tests/test_gae.py index 564f2e8a28..35c5689596 100644 --- a/tests/test_gae.py +++ b/tests/test_gae.py @@ -8,6 +8,7 @@ from stable_baselines3 import A2C, PPO, SAC from stable_baselines3.common.callbacks import BaseCallback +from stable_baselines3.common.env_checker import check_env from stable_baselines3.common.policies import ActorCriticPolicy @@ -121,6 +122,12 @@ def forward(self, obs, deterministic=False): return actions, values, log_prob +@pytest.mark.parametrize("env_cls", [CustomEnv, InfiniteHorizonEnv]) +def test_env(env_cls): + # Check the env used for testing + check_env(env_cls(), skip_render_check=True) + + @pytest.mark.parametrize("model_class", [A2C, PPO]) @pytest.mark.parametrize("gae_lambda", [1.0, 0.9]) @pytest.mark.parametrize("gamma", [1.0, 0.99]) From 1946082f10d3ebca3aa135a627e6271a6464134b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20GALLOU=C3=89DEC?= Date: Tue, 24 Jan 2023 14:51:51 +0100 Subject: [PATCH 096/153] check envs in remaining tests --- tests/test_dict_env.py | 2 +- tests/test_logger.py | 7 +++++++ tests/test_predict.py | 7 +++++++ tests/test_spaces.py | 7 +++++++ 4 files changed, 22 insertions(+), 1 deletion(-) diff --git a/tests/test_dict_env.py b/tests/test_dict_env.py index 31b423c83a..945d5e6b7d 100644 --- a/tests/test_dict_env.py +++ b/tests/test_dict_env.py @@ -6,8 +6,8 @@ from gym import spaces from stable_baselines3 import A2C, DDPG, DQN, PPO, SAC, TD3 -from stable_baselines3.common.env_util import make_vec_env from stable_baselines3.common.env_checker import check_env +from stable_baselines3.common.env_util import make_vec_env from stable_baselines3.common.envs import BitFlippingEnv, SimpleMultiObsEnv from stable_baselines3.common.evaluation import evaluate_policy from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv, VecFrameStack, VecNormalize diff --git a/tests/test_logger.py b/tests/test_logger.py index 85affe1a75..bd218516bb 100644 --- a/tests/test_logger.py +++ b/tests/test_logger.py @@ -13,6 +13,7 @@ from pandas.errors import EmptyDataError from stable_baselines3 import A2C, DQN +from stable_baselines3.common.env_checker import check_env from stable_baselines3.common.logger import ( DEBUG, INFO, @@ -363,6 +364,12 @@ def step(self, action): return obs, 0.0, True, False, {} +@pytest.mark.parametrize("env_cls", [TimeDelayEnv]) +def test_env(env_cls): + # Check the env used for testing + check_env(env_cls(), skip_render_check=True) + + class InMemoryLogger(Logger): """ Logger that keeps key/value pairs in memory without any writers. diff --git a/tests/test_predict.py b/tests/test_predict.py index 553abea4c1..4f8c3d1e9c 100644 --- a/tests/test_predict.py +++ b/tests/test_predict.py @@ -5,6 +5,7 @@ from gym import spaces from stable_baselines3 import A2C, DQN, PPO, SAC, TD3 +from stable_baselines3.common.env_checker import check_env from stable_baselines3.common.envs import IdentityEnv from stable_baselines3.common.utils import get_device from stable_baselines3.common.vec_env import DummyVecEnv @@ -36,6 +37,12 @@ def step(self, action): return self.observation_space.sample(), 0.0, np.random.rand() > 0.5, False, {} +@pytest.mark.parametrize("env_cls", [CustomSubClassedSpaceEnv]) +def test_env(env_cls): + # Check the env used for testing + check_env(env_cls(), skip_render_check=True) + + @pytest.mark.parametrize("model_class", MODEL_LIST) def test_auto_wrap(model_class): """Test auto wrapping of env into a VecEnv.""" diff --git a/tests/test_spaces.py b/tests/test_spaces.py index eff7662535..82fdc636a3 100644 --- a/tests/test_spaces.py +++ b/tests/test_spaces.py @@ -6,6 +6,7 @@ from gym import spaces from stable_baselines3 import A2C, DDPG, DQN, PPO, SAC, TD3 +from stable_baselines3.common.env_checker import check_env from stable_baselines3.common.env_util import make_vec_env from stable_baselines3.common.evaluation import evaluate_policy @@ -53,6 +54,12 @@ def step(self, action): return self.observation_space.sample(), 0.0, False, False, {} +@pytest.mark.parametrize("env", [DummyMultiDiscreteSpace([4, 3]), DummyMultiBinary(8), DummyMultiBinary((3, 2))]) +def test_env(env): + # Check the env used for testing + check_env(env, skip_render_check=True) + + @pytest.mark.parametrize("model_class", [SAC, TD3, DQN]) @pytest.mark.parametrize("env", [DummyMultiDiscreteSpace([4, 3]), DummyMultiBinary(8), DummyMultiBinary((3, 2))]) def test_identity_spaces(model_class, env): From 7460782ff3af7487db51b812587fa17ff848f25f Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Wed, 25 Jan 2023 15:59:46 +0100 Subject: [PATCH 097/153] Update tests --- tests/test_buffers.py | 5 +++-- tests/test_dict_env.py | 8 ++++++-- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/tests/test_buffers.py b/tests/test_buffers.py index 5cacb443b7..f988f91360 100644 --- a/tests/test_buffers.py +++ b/tests/test_buffers.py @@ -68,9 +68,10 @@ def step(self, action): @pytest.mark.parametrize("env_cls", [DummyEnv, DummyDictEnv]) -def test_test_env(env_cls): +def test_env(env_cls): # Check the env used for testing - check_env(env_cls(), skip_render_check=True) + # Do not warn for assymetric space + check_env(env_cls(), warn=False, skip_render_check=True) @pytest.mark.parametrize("replay_buffer_cls", [ReplayBuffer, DictReplayBuffer]) diff --git a/tests/test_dict_env.py b/tests/test_dict_env.py index 945d5e6b7d..42aa468a7f 100644 --- a/tests/test_dict_env.py +++ b/tests/test_dict_env.py @@ -85,9 +85,13 @@ def render(self): @pytest.mark.parametrize("channel_last", [True, False]) @pytest.mark.parametrize("nested_dict_obs", [True, False]) @pytest.mark.parametrize("vec_only", [True, False]) -def test_test_env(use_discrete_actions, channel_last, nested_dict_obs, vec_only): +def test_env(use_discrete_actions, channel_last, nested_dict_obs, vec_only): # Check the env used for testing - check_env(DummyDictEnv(use_discrete_actions, channel_last, nested_dict_obs, vec_only)) + if nested_dict_obs: + with pytest.warns(UserWarning, match="Nested observation spaces are not supported"): + check_env(DummyDictEnv(use_discrete_actions, channel_last, nested_dict_obs, vec_only)) + else: + check_env(DummyDictEnv(use_discrete_actions, channel_last, nested_dict_obs, vec_only)) @pytest.mark.parametrize("policy", ["MlpPolicy", "CnnPolicy"]) From e5575d8abbf20a1b6b2fa7b527275a8045ee6811 Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Wed, 25 Jan 2023 16:20:14 +0100 Subject: [PATCH 098/153] Add warning for Discrete action space with non-zero (#1295) --- stable_baselines3/common/env_checker.py | 6 ++++++ tests/test_envs.py | 8 ++++++++ 2 files changed, 14 insertions(+) diff --git a/stable_baselines3/common/env_checker.py b/stable_baselines3/common/env_checker.py index 2b9afd5cc3..ab7d404f61 100644 --- a/stable_baselines3/common/env_checker.py +++ b/stable_baselines3/common/env_checker.py @@ -93,6 +93,12 @@ def _check_unsupported_spaces(env: gym.Env, observation_space: spaces.Space, act "You can use a wrapper or update your observation space." ) + if isinstance(action_space, spaces.Discrete) and action_space.start != 0: + warnings.warn( + "Discrete action space with a non-zero start is not supported by Stable-Baselines3. " + "You can use a wrapper or update your action space." + ) + if not _is_numpy_array_space(action_space): warnings.warn( "The action space is not based off a numpy array. Typically this means it's either a Dict or Tuple space. " diff --git a/tests/test_envs.py b/tests/test_envs.py index 30c4593c7c..82bd6a6c07 100644 --- a/tests/test_envs.py +++ b/tests/test_envs.py @@ -160,6 +160,8 @@ def patched_step(_action): spaces.Box(low=-np.inf, high=1, shape=(2,), dtype=np.float32), # Almost good, except for one dim spaces.Box(low=np.array([-1, -1, -1]), high=np.array([1, 1, 0.99]), dtype=np.float32), + # Non zero start index + spaces.Discrete(3, start=-1), ], ) def test_non_default_action_spaces(new_action_space): @@ -174,6 +176,12 @@ def test_non_default_action_spaces(new_action_space): # Change the action space env.action_space = new_action_space + # Discrete action space + if isinstance(new_action_space, spaces.Discrete): + with pytest.warns(UserWarning): + check_env(env) + return + low, high = new_action_space.low[0], new_action_space.high[0] # Unbounded action space throws an error, # the rest only warning From b787d98e1f9cf5f8a700c512bd7a1e9e64822323 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= <45557362+qgallouedec@users.noreply.github.com> Date: Sat, 28 Jan 2023 12:52:10 +0100 Subject: [PATCH 099/153] Fix atari annotation --- stable_baselines3/common/atari_wrappers.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/stable_baselines3/common/atari_wrappers.py b/stable_baselines3/common/atari_wrappers.py index ff9c0e4ff3..595fc6d9ab 100644 --- a/stable_baselines3/common/atari_wrappers.py +++ b/stable_baselines3/common/atari_wrappers.py @@ -11,7 +11,7 @@ except ImportError: cv2 = None -from stable_baselines3.common.type_aliases import Gym26ResetReturn, Gym26StepReturn +from stable_baselines3.common.type_aliases import Gym26StepReturn class StickyActionEnv(gym.Wrapper): @@ -30,11 +30,11 @@ def __init__(self, env: gym.Env, action_repeat_probability: float) -> None: self.action_repeat_probability = action_repeat_probability assert env.unwrapped.get_action_meanings()[0] == "NOOP" - def reset(self, **kwargs) -> GymObs: + def reset(self, **kwargs) -> Tuple[np.ndarray, Dict]: self._sticky_action = 0 # NOOP return self.env.reset(**kwargs) - def step(self, action: int) -> GymStepReturn: + def step(self, action: int) -> Gym26StepReturn: if self.np_random.random() >= self.action_repeat_probability: self._sticky_action = action return self.env.step(self._sticky_action) From 85bb0d4aed29d434a44ee6c708b36b4e32d88962 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= <45557362+qgallouedec@users.noreply.github.com> Date: Sat, 28 Jan 2023 14:32:14 +0100 Subject: [PATCH 100/153] ignore get_action_meanings [attr-defined] --- stable_baselines3/common/atari_wrappers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/stable_baselines3/common/atari_wrappers.py b/stable_baselines3/common/atari_wrappers.py index 595fc6d9ab..e1f1ea59bd 100644 --- a/stable_baselines3/common/atari_wrappers.py +++ b/stable_baselines3/common/atari_wrappers.py @@ -28,7 +28,7 @@ class StickyActionEnv(gym.Wrapper): def __init__(self, env: gym.Env, action_repeat_probability: float) -> None: super().__init__(env) self.action_repeat_probability = action_repeat_probability - assert env.unwrapped.get_action_meanings()[0] == "NOOP" + assert env.unwrapped.get_action_meanings()[0] == "NOOP" # type: ignore[attr-defined] def reset(self, **kwargs) -> Tuple[np.ndarray, Dict]: self._sticky_action = 0 # NOOP From afa1c735fb763d3a8ccf6468a2ae0455020c5bc6 Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Mon, 6 Feb 2023 22:56:17 +0100 Subject: [PATCH 101/153] Fix mypy issues --- .../common/vec_env/stacked_observations.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/stable_baselines3/common/vec_env/stacked_observations.py b/stable_baselines3/common/vec_env/stacked_observations.py index a26812cd91..555c6f23fe 100644 --- a/stable_baselines3/common/vec_env/stacked_observations.py +++ b/stable_baselines3/common/vec_env/stacked_observations.py @@ -40,12 +40,12 @@ def __init__( if not isinstance(channels_order, Mapping): channels_order = {key: channels_order for key in observation_space.spaces.keys()} self.sub_stacked_observations = { - key: StackedObservations(num_envs, n_stack, subspace, channels_order[key]) + key: StackedObservations(num_envs, n_stack, subspace, channels_order[key]) # type: ignore[arg-type] for key, subspace in observation_space.spaces.items() } self.stacked_observation_space = spaces.Dict( {key: substack_obs.stacked_observation_space for key, substack_obs in self.sub_stacked_observations.items()} - ) # type: spaces.Dict # make mypy happy + ) # type: Union[spaces.Dict, spaces.Box] # make mypy happy elif isinstance(observation_space, spaces.Box): if isinstance(channels_order, Mapping): raise TypeError("When the observation space is Box, channels_order can't be a dict.") @@ -55,7 +55,11 @@ def __init__( ) low = np.repeat(observation_space.low, n_stack, axis=self.repeat_axis) high = np.repeat(observation_space.high, n_stack, axis=self.repeat_axis) - self.stacked_observation_space = spaces.Box(low=low, high=high, dtype=observation_space.dtype) + self.stacked_observation_space = spaces.Box( + low=low, + high=high, + dtype=observation_space.dtype, # type: ignore[arg-type] + ) self.stacked_obs = np.zeros((num_envs,) + self.stacked_shape, dtype=observation_space.dtype) else: raise TypeError( @@ -125,7 +129,7 @@ def stack_observation_space(self, observation_space: Union[spaces.Box, spaces.Di ) low = np.repeat(observation_space.low, self.n_stack, axis=self.repeat_axis) high = np.repeat(observation_space.high, self.n_stack, axis=self.repeat_axis) - return spaces.Box(low=low, high=high, dtype=observation_space.dtype) + return spaces.Box(low=low, high=high, dtype=observation_space.dtype) # type: ignore[arg-type] def reset(self, observation: TObs) -> TObs: """ From e085ce1b1736c4932054486fb62949e51612e370 Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Sat, 11 Feb 2023 17:52:20 +0100 Subject: [PATCH 102/153] Add patch for gym/gymnasium transition --- stable_baselines3/common/vec_env/patch_gym.py | 73 +++++++++++++++++++ 1 file changed, 73 insertions(+) create mode 100644 stable_baselines3/common/vec_env/patch_gym.py diff --git a/stable_baselines3/common/vec_env/patch_gym.py b/stable_baselines3/common/vec_env/patch_gym.py new file mode 100644 index 0000000000..0477d41e8f --- /dev/null +++ b/stable_baselines3/common/vec_env/patch_gym.py @@ -0,0 +1,73 @@ +import warnings +from typing import Callable + +import gymnasium + +try: + import gym + + gym_installed = True +except ImportError: + gym_installed = False + +def _patch_env_generator(env_fn: Callable[[], gymnasium.Env]) -> Callable[[], gymnasium.Env]: + """ + Taken from https://github.com/thu-ml/tianshou. + + Takes an environment generator and patches it to return Gymnasium envs. + This function takes the environment generator ``env_fn`` and returns a patched + generator, without invoking ``env_fn``. The original generator may return + Gymnasium or OpenAI Gym environments, but the patched generator wraps + the result of ``env_fn`` in a shimmy wrapper to convert it to Gymnasium, + if necessary. + + :param env_fn: a function that returns an environment + :return: Patched generator + """ + + def patched() -> gymnasium.Env: + env = env_fn() + + # Gymnasium env, no patching to be done + if isinstance(env, gymnasium.Env): + return env + + if not gym_installed or not isinstance(env, gym.Env): + raise ValueError( + f"Environment generator returned a {type(env)}, not a Gymnasium " + f"environment. In this case, we expect OpenAI Gym to be " + f"installed and the environment to be an OpenAI Gym environment." + ) + + try: + import shimmy + except ImportError as e: + raise ImportError( + "Missing shimmy installation. You provided an environment generator " + "that returned an OpenAI Gym environment. " + "Stable-Baselines3 (SB3) has transitioned to using Gymnasium internally. " + "In order to use OpenAI Gym environments with SB3, you need to " + "install shimmy (`pip install shimmy`)." + ) from e + + warnings.warn( + "You provided an environment generator that returned an OpenAI Gym " + "environment. We strongly recommend transitioning to Gymnasium " + "environments. " + "Stable-Baselines3 is automatically wrapping your environments in a compatibility " + "layer, which could potentially cause issues." + ) + + # gym version only goes to 0.26.2 + gym_version = int(gym.__version__.split(".")[1]) + if gym_version >= 26: + return shimmy.GymV26CompatibilityV0(env=env) + elif gym_version >= 21: + # TODO: rename to GymV21CompatibilityV0 + return shimmy.GymV22CompatibilityV0(env=env) + else: + raise Exception( + f"Found OpenAI Gym version {gym.__version__}. " f"SB3 only supports OpenAI Gym environments of version>=0.21.0" + ) + + return patched From 8fb8c891947bfe58a3149113b1c84da53f9a4de2 Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Sat, 11 Feb 2023 17:56:14 +0100 Subject: [PATCH 103/153] Switch to gymnasium --- docs/guide/callbacks.rst | 10 +++++----- docs/guide/checking_nan.rst | 4 ++-- docs/guide/custom_env.rst | 6 +++--- docs/guide/custom_policy.rst | 8 ++++---- docs/guide/examples.rst | 16 ++++++++-------- docs/guide/integrations.rst | 4 ++-- docs/guide/quickstart.rst | 2 +- docs/guide/tensorboard.rst | 2 +- docs/misc/changelog.rst | 2 +- docs/modules/a2c.rst | 2 +- docs/modules/ddpg.rst | 2 +- docs/modules/dqn.rst | 2 +- docs/modules/ppo.rst | 2 +- docs/modules/sac.rst | 2 +- docs/modules/td3.rst | 2 +- setup.py | 8 ++++---- stable_baselines3/a2c/a2c.py | 2 +- stable_baselines3/common/atari_wrappers.py | 4 ++-- stable_baselines3/common/base_class.py | 4 ++-- stable_baselines3/common/buffers.py | 2 +- stable_baselines3/common/callbacks.py | 2 +- stable_baselines3/common/distributions.py | 2 +- stable_baselines3/common/env_checker.py | 12 +++++++----- stable_baselines3/common/env_util.py | 2 +- .../common/envs/bit_flipping_env.py | 4 ++-- stable_baselines3/common/envs/identity_env.py | 4 ++-- .../common/envs/multi_input_envs.py | 4 ++-- stable_baselines3/common/evaluation.py | 2 +- stable_baselines3/common/monitor.py | 2 +- stable_baselines3/common/off_policy_algorithm.py | 2 +- stable_baselines3/common/on_policy_algorithm.py | 2 +- stable_baselines3/common/policies.py | 2 +- stable_baselines3/common/preprocessing.py | 2 +- stable_baselines3/common/torch_layers.py | 4 ++-- stable_baselines3/common/type_aliases.py | 2 +- stable_baselines3/common/utils.py | 4 ++-- stable_baselines3/common/vec_env/base_vec_env.py | 6 +++--- .../common/vec_env/dummy_vec_env.py | 8 ++++---- .../common/vec_env/stacked_observations.py | 2 +- .../common/vec_env/subproc_vec_env.py | 7 ++++--- stable_baselines3/common/vec_env/util.py | 2 +- .../common/vec_env/vec_frame_stack.py | 2 +- .../common/vec_env/vec_normalize.py | 2 +- .../common/vec_env/vec_transpose.py | 2 +- .../common/vec_env/vec_video_recorder.py | 2 +- stable_baselines3/dqn/dqn.py | 2 +- stable_baselines3/dqn/policies.py | 2 +- stable_baselines3/ppo/ppo.py | 2 +- stable_baselines3/sac/policies.py | 2 +- stable_baselines3/sac/sac.py | 2 +- stable_baselines3/td3/policies.py | 2 +- stable_baselines3/td3/td3.py | 2 +- tests/test_buffers.py | 4 ++-- tests/test_callbacks.py | 2 +- tests/test_cnn.py | 2 +- tests/test_dict_env.py | 4 ++-- tests/test_distributions.py | 2 +- tests/test_env_checker.py | 4 ++-- tests/test_envs.py | 4 ++-- tests/test_gae.py | 4 ++-- tests/test_her.py | 2 +- tests/test_logger.py | 4 ++-- tests/test_monitor.py | 2 +- tests/test_predict.py | 4 ++-- tests/test_preprocessing.py | 2 +- tests/test_run.py | 2 +- tests/test_save_load.py | 2 +- tests/test_spaces.py | 4 ++-- tests/test_train_eval_mode.py | 2 +- tests/test_utils.py | 4 ++-- tests/test_vec_check_nan.py | 4 ++-- tests/test_vec_envs.py | 4 ++-- tests/test_vec_extract_dict_obs.py | 2 +- tests/test_vec_monitor.py | 2 +- tests/test_vec_normalize.py | 4 ++-- tests/test_vec_stacked_obs.py | 2 +- 76 files changed, 130 insertions(+), 127 deletions(-) diff --git a/docs/guide/callbacks.rst b/docs/guide/callbacks.rst index 632743bbfa..e08c003673 100644 --- a/docs/guide/callbacks.rst +++ b/docs/guide/callbacks.rst @@ -210,7 +210,7 @@ It will save the best model if ``best_model_save_path`` folder is specified and .. code-block:: python - import gym + import gymnasium as gym from stable_baselines3 import SAC from stable_baselines3.common.callbacks import EvalCallback @@ -260,7 +260,7 @@ Alternatively, you can pass directly a list of callbacks to the ``learn()`` meth .. code-block:: python - import gym + import gymnasium as gym from stable_baselines3 import SAC from stable_baselines3.common.callbacks import CallbackList, CheckpointCallback, EvalCallback @@ -290,7 +290,7 @@ It must be used with the :ref:`EvalCallback` and use the event triggered by a ne .. code-block:: python - import gym + import gymnasium as gym from stable_baselines3 import SAC from stable_baselines3.common.callbacks import EvalCallback, StopTrainingOnRewardThreshold @@ -322,7 +322,7 @@ An :ref:`EventCallback` that will trigger its child callback every ``n_steps`` t .. code-block:: python - import gym + import gymnasium as gym from stable_baselines3 import PPO from stable_baselines3.common.callbacks import CheckpointCallback, EveryNTimesteps @@ -379,7 +379,7 @@ It must be used with the :ref:`EvalCallback` and use the event triggered after e .. code-block:: python - import gym + import gymnasium as gym from stable_baselines3 import SAC from stable_baselines3.common.callbacks import EvalCallback, StopTrainingOnNoModelImprovement diff --git a/docs/guide/checking_nan.rst b/docs/guide/checking_nan.rst index 7395fbd8b8..8d0de36e86 100644 --- a/docs/guide/checking_nan.rst +++ b/docs/guide/checking_nan.rst @@ -100,8 +100,8 @@ It will monitor the actions, observations, and rewards, indicating what action o .. code-block:: python - import gym - from gym import spaces + import gymnasium as gym + from gymnasium import spaces import numpy as np from stable_baselines3 import PPO diff --git a/docs/guide/custom_env.rst b/docs/guide/custom_env.rst index 2392bbb31d..822c8215ec 100644 --- a/docs/guide/custom_env.rst +++ b/docs/guide/custom_env.rst @@ -26,9 +26,9 @@ That is to say, your environment must implement the following methods (and inher .. code-block:: python - import gym + import gymnasium as gym import numpy as np - from gym import spaces + from gymnasium import spaces class CustomEnv(gym.Env): @@ -91,7 +91,7 @@ Optionally, you can also register the environment with gym, that will allow you .. code-block:: python - from gym.envs.registration import register + from gymnasium.envs.registration import register # Example for the CartPole environment register( # unique identifier for the env `name-version` diff --git a/docs/guide/custom_policy.rst b/docs/guide/custom_policy.rst index dae60485b8..ebc69155d3 100644 --- a/docs/guide/custom_policy.rst +++ b/docs/guide/custom_policy.rst @@ -79,7 +79,7 @@ using ``policy_kwargs`` parameter: .. code-block:: python - import gym + import gymnasium as gym import torch as th from stable_baselines3 import PPO @@ -121,7 +121,7 @@ that derives from ``BaseFeaturesExtractor`` and then pass it to the model when t import torch as th import torch.nn as nn - from gym import spaces + from gymnasium import spaces from stable_baselines3 import PPO from stable_baselines3.common.torch_layers import BaseFeaturesExtractor @@ -186,7 +186,7 @@ downsampling and "vector" with a single linear layer. .. code-block:: python - import gym + import gymnasium as gym import torch as th from torch import nn @@ -286,7 +286,7 @@ If your task requires even more granular control over the policy/value architect from typing import Callable, Dict, List, Optional, Tuple, Type, Union - from gym import spaces + from gymnasium import spaces import torch as th from torch import nn diff --git a/docs/guide/examples.rst b/docs/guide/examples.rst index 5e8a47d8f5..03129508f8 100644 --- a/docs/guide/examples.rst +++ b/docs/guide/examples.rst @@ -64,7 +64,7 @@ In the following example, we will train, save and load a DQN model on the Lunar .. code-block:: python - import gym + import gymnasium as gym from stable_baselines3 import DQN from stable_baselines3.common.evaluation import evaluate_policy @@ -115,7 +115,7 @@ Multiprocessing: Unleashing the Power of Vectorized Environments .. code-block:: python - import gym + import gymnasium as gym import numpy as np from stable_baselines3 import PPO @@ -173,7 +173,7 @@ Multiprocessing with off-policy algorithms .. code-block:: python - import gym + import gymnasium as gym from stable_baselines3 import SAC from stable_baselines3.common.env_util import make_vec_env @@ -229,7 +229,7 @@ If your callback returns False, training is aborted early. import os - import gym + import gymnasium as gym import numpy as np import matplotlib.pyplot as plt @@ -372,7 +372,7 @@ will compute a running average and standard deviation of input features (it can .. code-block:: python import os - import gym + import gymnasium as gym import pybullet_envs from stable_baselines3.common.vec_env import DummyVecEnv, VecNormalize @@ -430,7 +430,7 @@ The parking env is a goal-conditioned continuous control task, in which the vehi .. code-block:: python - import gym + import gymnasium as gym import highway_env import numpy as np @@ -625,7 +625,7 @@ A2C policy gradient updates on the model. from typing import Dict - import gym + import gymnasium as gym import numpy as np import torch as th @@ -742,7 +742,7 @@ Record a mp4 video (here using a random agent). .. code-block:: python - import gym + import gymnasium as gym from stable_baselines3.common.vec_env import VecVideoRecorder, DummyVecEnv env_id = "CartPole-v1" diff --git a/docs/guide/integrations.rst b/docs/guide/integrations.rst index 49bbdb248b..14573cdec3 100644 --- a/docs/guide/integrations.rst +++ b/docs/guide/integrations.rst @@ -13,7 +13,7 @@ The full documentation is available here: https://docs.wandb.ai/guides/integrati .. code-block:: python - import gym + import gymnasium as gym import wandb from wandb.integration.sb3 import WandbCallback @@ -86,7 +86,7 @@ For instance ``sb3/demo-hf-CartPole-v1``: .. code-block:: python - import gym + import gymnasium as gym from huggingface_sb3 import load_from_hub from stable_baselines3 import PPO diff --git a/docs/guide/quickstart.rst b/docs/guide/quickstart.rst index e809f1ba06..b22ac54dac 100644 --- a/docs/guide/quickstart.rst +++ b/docs/guide/quickstart.rst @@ -16,7 +16,7 @@ Here is a quick example of how to train and run A2C on a CartPole environment: .. code-block:: python - import gym + import gymnasium as gym from stable_baselines3 import A2C diff --git a/docs/guide/tensorboard.rst b/docs/guide/tensorboard.rst index 2699d4a86d..720c3ded2a 100644 --- a/docs/guide/tensorboard.rst +++ b/docs/guide/tensorboard.rst @@ -190,7 +190,7 @@ Here is an example of how to render an episode and log the resulting video to Te from typing import Any, Dict - import gym + import gymnasium as gym import torch as th from stable_baselines3 import A2C diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index f5e0e96688..5b78a5604d 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -131,7 +131,7 @@ Others: - Upgraded GitHub CI/setup-python to v4 and checkout to v3 - Set tensors construction directly on the device (~8% speed boost on GPU) - Monkey-patched ``np.bool = bool`` so gym 0.21 is compatible with NumPy 1.24+ -- Standardized the use of ``from gym import spaces`` +- Standardized the use of ``from gymnasium import spaces`` - Modified ``get_system_info`` to avoid issue linked to copy-pasting on GitHub issue Documentation: diff --git a/docs/modules/a2c.rst b/docs/modules/a2c.rst index 670da617cf..84a94eaab8 100644 --- a/docs/modules/a2c.rst +++ b/docs/modules/a2c.rst @@ -53,7 +53,7 @@ Train a A2C agent on ``CartPole-v1`` using 4 environments. .. code-block:: python - import gym + import gymnasium as gym from stable_baselines3 import A2C from stable_baselines3.common.env_util import make_vec_env diff --git a/docs/modules/ddpg.rst b/docs/modules/ddpg.rst index c484a1c935..4ac28ccb3b 100644 --- a/docs/modules/ddpg.rst +++ b/docs/modules/ddpg.rst @@ -61,7 +61,7 @@ This example is only to demonstrate the use of the library and its functions, an .. code-block:: python - import gym + import gymnasium as gym import numpy as np from stable_baselines3 import DDPG diff --git a/docs/modules/dqn.rst b/docs/modules/dqn.rst index 8648606cc1..0569aa5283 100644 --- a/docs/modules/dqn.rst +++ b/docs/modules/dqn.rst @@ -56,7 +56,7 @@ This example is only to demonstrate the use of the library and its functions, an .. code-block:: python - import gym + import gymnasium as gym from stable_baselines3 import DQN diff --git a/docs/modules/ppo.rst b/docs/modules/ppo.rst index d0c425fb5e..a822cb4369 100644 --- a/docs/modules/ppo.rst +++ b/docs/modules/ppo.rst @@ -65,7 +65,7 @@ Train a PPO agent on ``CartPole-v1`` using 4 environments. .. code-block:: python - import gym + import gymnasium as gym from stable_baselines3 import PPO from stable_baselines3.common.env_util import make_vec_env diff --git a/docs/modules/sac.rst b/docs/modules/sac.rst index e7f9057d50..0e9bb3f645 100644 --- a/docs/modules/sac.rst +++ b/docs/modules/sac.rst @@ -68,7 +68,7 @@ This example is only to demonstrate the use of the library and its functions, an .. code-block:: python - import gym + import gymnasium as gym import numpy as np from stable_baselines3 import SAC diff --git a/docs/modules/td3.rst b/docs/modules/td3.rst index d039ae71c3..7c17e644d2 100644 --- a/docs/modules/td3.rst +++ b/docs/modules/td3.rst @@ -61,7 +61,7 @@ This example is only to demonstrate the use of the library and its functions, an .. code-block:: python - import gym + import gymnasium as gym import numpy as np from stable_baselines3 import TD3 diff --git a/setup.py b/setup.py index b72de46732..cc83347291 100644 --- a/setup.py +++ b/setup.py @@ -39,11 +39,11 @@ Here is a quick example of how to train and run PPO on a cartpole environment: ```python -import gym +import gymnasium from stable_baselines3 import PPO -env = gym.make("CartPole-v1") +env = gymnasium.make("CartPole-v1") model = PPO("MlpPolicy", env, verbose=1) model.learn(total_timesteps=10_000) @@ -60,7 +60,7 @@ ``` -Or just train a model with a one liner if [the environment is registered in Gym](https://www.gymlibrary.ml/content/environment_creation/) and if [the policy is registered](https://stable-baselines3.readthedocs.io/en/master/guide/custom_policy.html): +Or just train a model with a one liner if [the environment is registered in Gymnasium](https://gymnasium.farama.org/tutorials/gymnasium_basics/environment_creation/) and if [the policy is registered](https://stable-baselines3.readthedocs.io/en/master/guide/custom_policy.html): ```python from stable_baselines3 import PPO @@ -76,7 +76,7 @@ packages=[package for package in find_packages() if package.startswith("stable_baselines3")], package_data={"stable_baselines3": ["py.typed", "version.txt"]}, install_requires=[ - "gym==0.26.2", + "gymnasium==0.26.2", "numpy", "torch>=1.11", 'typing_extensions>=4.0,<5; python_version < "3.8.0"', diff --git a/stable_baselines3/a2c/a2c.py b/stable_baselines3/a2c/a2c.py index 9e8b40cb07..658ff847de 100644 --- a/stable_baselines3/a2c/a2c.py +++ b/stable_baselines3/a2c/a2c.py @@ -1,7 +1,7 @@ from typing import Any, Dict, Optional, Type, TypeVar, Union import torch as th -from gym import spaces +from gymnasium import spaces from torch.nn import functional as F from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm diff --git a/stable_baselines3/common/atari_wrappers.py b/stable_baselines3/common/atari_wrappers.py index e1f1ea59bd..615b72b634 100644 --- a/stable_baselines3/common/atari_wrappers.py +++ b/stable_baselines3/common/atari_wrappers.py @@ -1,8 +1,8 @@ from typing import Dict, Tuple -import gym +import gymnasium as gym import numpy as np -from gym import spaces +from gymnasium import spaces try: import cv2 # pytype:disable=import-error diff --git a/stable_baselines3/common/base_class.py b/stable_baselines3/common/base_class.py index d104789c6d..c7ee58efcc 100644 --- a/stable_baselines3/common/base_class.py +++ b/stable_baselines3/common/base_class.py @@ -8,10 +8,10 @@ from collections import deque from typing import Any, Dict, Iterable, List, Optional, Tuple, Type, TypeVar, Union -import gym +import gymnasium as gym import numpy as np import torch as th -from gym import spaces +from gymnasium import spaces from stable_baselines3.common import utils from stable_baselines3.common.callbacks import BaseCallback, CallbackList, ConvertCallback, ProgressBarCallback diff --git a/stable_baselines3/common/buffers.py b/stable_baselines3/common/buffers.py index 273dba9e03..b34766c838 100644 --- a/stable_baselines3/common/buffers.py +++ b/stable_baselines3/common/buffers.py @@ -4,7 +4,7 @@ import numpy as np import torch as th -from gym import spaces +from gymnasium import spaces from stable_baselines3.common.preprocessing import get_action_dim, get_obs_shape from stable_baselines3.common.type_aliases import ( diff --git a/stable_baselines3/common/callbacks.py b/stable_baselines3/common/callbacks.py index 69a21ab70e..46db088240 100644 --- a/stable_baselines3/common/callbacks.py +++ b/stable_baselines3/common/callbacks.py @@ -3,7 +3,7 @@ from abc import ABC, abstractmethod from typing import Any, Callable, Dict, List, Optional, Union -import gym +import gymnasium as gym import numpy as np from stable_baselines3.common.logger import Logger diff --git a/stable_baselines3/common/distributions.py b/stable_baselines3/common/distributions.py index b1cd439a24..170b0de2ba 100644 --- a/stable_baselines3/common/distributions.py +++ b/stable_baselines3/common/distributions.py @@ -5,7 +5,7 @@ import numpy as np import torch as th -from gym import spaces +from gymnasium import spaces from torch import nn from torch.distributions import Bernoulli, Categorical, Normal diff --git a/stable_baselines3/common/env_checker.py b/stable_baselines3/common/env_checker.py index ab7d404f61..2a44b2a4c7 100644 --- a/stable_baselines3/common/env_checker.py +++ b/stable_baselines3/common/env_checker.py @@ -1,9 +1,9 @@ import warnings from typing import Any, Dict, Union -import gym +import gymnasium as gym import numpy as np -from gym import spaces +from gymnasium import spaces from stable_baselines3.common.preprocessing import is_image_space_channels_first from stable_baselines3.common.vec_env import DummyVecEnv, VecCheckNan @@ -215,7 +215,7 @@ def _check_returned_values(env: gym.Env, observation_space: spaces.Space, action """ Check the returned values by the env when calling `.reset()` or `.step()` methods. """ - # because env inherits from gym.Env, we assume that `reset()` and `step()` methods exists + # because env inherits from gymnasium.Env, we assume that `reset()` and `step()` methods exists reset_returns = env.reset() assert isinstance(reset_returns, tuple), "`reset()` must return a tuple (obs, info)" assert len(reset_returns) == 2, f"`reset()` must return a tuple of size 2 (obs, info), not {len(reset_returns)}" @@ -300,8 +300,10 @@ def _check_spaces(env: gym.Env) -> None: assert hasattr(env, "observation_space"), "You must specify an observation space (cf gym.spaces)" + gym_spaces assert hasattr(env, "action_space"), "You must specify an action space (cf gym.spaces)" + gym_spaces - assert isinstance(env.observation_space, spaces.Space), "The observation space must inherit from gym.spaces" + gym_spaces - assert isinstance(env.action_space, spaces.Space), "The action space must inherit from gym.spaces" + gym_spaces + assert isinstance(env.observation_space, spaces.Space), ( + "The observation space must inherit from gymnasium.spaces" + gym_spaces + ) + assert isinstance(env.action_space, spaces.Space), "The action space must inherit from gymnasium.spaces" + gym_spaces if _is_goal_env(env): assert isinstance( diff --git a/stable_baselines3/common/env_util.py b/stable_baselines3/common/env_util.py index cf10246495..d9a13c9a3b 100644 --- a/stable_baselines3/common/env_util.py +++ b/stable_baselines3/common/env_util.py @@ -1,7 +1,7 @@ import os from typing import Any, Callable, Dict, Optional, Type, Union -import gym +import gymnasium as gym from stable_baselines3.common.atari_wrappers import AtariWrapper from stable_baselines3.common.monitor import Monitor diff --git a/stable_baselines3/common/envs/bit_flipping_env.py b/stable_baselines3/common/envs/bit_flipping_env.py index 01e09357fd..cdbed8fe06 100644 --- a/stable_baselines3/common/envs/bit_flipping_env.py +++ b/stable_baselines3/common/envs/bit_flipping_env.py @@ -2,8 +2,8 @@ from typing import Any, Dict, Optional, Tuple, Union import numpy as np -from gym import Env, spaces -from gym.envs.registration import EnvSpec +from gymnasium import Env, spaces +from gymnasium.envs.registration import EnvSpec from stable_baselines3.common.type_aliases import Gym26StepReturn diff --git a/stable_baselines3/common/envs/identity_env.py b/stable_baselines3/common/envs/identity_env.py index 9f8234f72f..3d54dffaf3 100644 --- a/stable_baselines3/common/envs/identity_env.py +++ b/stable_baselines3/common/envs/identity_env.py @@ -1,8 +1,8 @@ from typing import Any, Dict, Generic, Optional, Tuple, TypeVar, Union -import gym +import gymnasium as gym import numpy as np -from gym import spaces +from gymnasium import spaces from stable_baselines3.common.type_aliases import Gym26StepReturn diff --git a/stable_baselines3/common/envs/multi_input_envs.py b/stable_baselines3/common/envs/multi_input_envs.py index 65ec372bf5..a498d0ac94 100644 --- a/stable_baselines3/common/envs/multi_input_envs.py +++ b/stable_baselines3/common/envs/multi_input_envs.py @@ -1,8 +1,8 @@ from typing import Dict, Optional, Tuple, Union -import gym +import gymnasium as gym import numpy as np -from gym import spaces +from gymnasium import spaces from stable_baselines3.common.type_aliases import Gym26StepReturn diff --git a/stable_baselines3/common/evaluation.py b/stable_baselines3/common/evaluation.py index 593b407d89..0c9921bb64 100644 --- a/stable_baselines3/common/evaluation.py +++ b/stable_baselines3/common/evaluation.py @@ -1,7 +1,7 @@ import warnings from typing import Any, Callable, Dict, List, Optional, Tuple, Union -import gym +import gymnasium as gym import numpy as np from stable_baselines3.common import type_aliases diff --git a/stable_baselines3/common/monitor.py b/stable_baselines3/common/monitor.py index 41ceef4be5..875ab5adc3 100644 --- a/stable_baselines3/common/monitor.py +++ b/stable_baselines3/common/monitor.py @@ -7,7 +7,7 @@ from glob import glob from typing import Any, Dict, List, Optional, Tuple, Union -import gym +import gymnasium as gym import numpy as np import pandas diff --git a/stable_baselines3/common/off_policy_algorithm.py b/stable_baselines3/common/off_policy_algorithm.py index c1ab215841..a14302eaae 100644 --- a/stable_baselines3/common/off_policy_algorithm.py +++ b/stable_baselines3/common/off_policy_algorithm.py @@ -8,7 +8,7 @@ import numpy as np import torch as th -from gym import spaces +from gymnasium import spaces from stable_baselines3.common.base_class import BaseAlgorithm from stable_baselines3.common.buffers import DictReplayBuffer, ReplayBuffer diff --git a/stable_baselines3/common/on_policy_algorithm.py b/stable_baselines3/common/on_policy_algorithm.py index 44d8b26b0e..cf557df6f0 100644 --- a/stable_baselines3/common/on_policy_algorithm.py +++ b/stable_baselines3/common/on_policy_algorithm.py @@ -4,7 +4,7 @@ import numpy as np import torch as th -from gym import spaces +from gymnasium import spaces from stable_baselines3.common.base_class import BaseAlgorithm from stable_baselines3.common.buffers import DictRolloutBuffer, RolloutBuffer diff --git a/stable_baselines3/common/policies.py b/stable_baselines3/common/policies.py index 457274a148..636abee9f4 100644 --- a/stable_baselines3/common/policies.py +++ b/stable_baselines3/common/policies.py @@ -9,7 +9,7 @@ import numpy as np import torch as th -from gym import spaces +from gymnasium import spaces from torch import nn from stable_baselines3.common.distributions import ( diff --git a/stable_baselines3/common/preprocessing.py b/stable_baselines3/common/preprocessing.py index 79f3fbeec6..6b35481b4d 100644 --- a/stable_baselines3/common/preprocessing.py +++ b/stable_baselines3/common/preprocessing.py @@ -3,7 +3,7 @@ import numpy as np import torch as th -from gym import spaces +from gymnasium import spaces from torch.nn import functional as F diff --git a/stable_baselines3/common/torch_layers.py b/stable_baselines3/common/torch_layers.py index 9e72774671..ad6c7eef15 100644 --- a/stable_baselines3/common/torch_layers.py +++ b/stable_baselines3/common/torch_layers.py @@ -1,8 +1,8 @@ from typing import Dict, List, Tuple, Type, Union -import gym +import gymnasium as gym import torch as th -from gym import spaces +from gymnasium import spaces from torch import nn from stable_baselines3.common.preprocessing import get_flattened_obs_dim, is_image_space diff --git a/stable_baselines3/common/type_aliases.py b/stable_baselines3/common/type_aliases.py index 037c0e58c6..e7f2975290 100644 --- a/stable_baselines3/common/type_aliases.py +++ b/stable_baselines3/common/type_aliases.py @@ -4,7 +4,7 @@ from enum import Enum from typing import Any, Callable, Dict, List, NamedTuple, Optional, Tuple, Union -import gym +import gymnasium as gym import numpy as np import torch as th diff --git a/stable_baselines3/common/utils.py b/stable_baselines3/common/utils.py index 48be56d7cf..9978df739f 100644 --- a/stable_baselines3/common/utils.py +++ b/stable_baselines3/common/utils.py @@ -8,10 +8,10 @@ from itertools import zip_longest from typing import Dict, Iterable, List, Optional, Tuple, Union -import gym +import gymnasium as gym import numpy as np import torch as th -from gym import spaces +from gymnasium import spaces import stable_baselines3 as sb3 diff --git a/stable_baselines3/common/vec_env/base_vec_env.py b/stable_baselines3/common/vec_env/base_vec_env.py index 572a7a1328..fca0d5e604 100644 --- a/stable_baselines3/common/vec_env/base_vec_env.py +++ b/stable_baselines3/common/vec_env/base_vec_env.py @@ -4,9 +4,9 @@ from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, Type, Union import cloudpickle -import gym +import gymnasium as gym import numpy as np -from gym import spaces +from gymnasium import spaces # Define type aliases here to avoid circular import # Used when we want to access one or more VecEnv @@ -196,7 +196,7 @@ def render(self, mode: Optional[str] = None) -> Optional[np.ndarray]: elif mode and self.render_mode != mode: warnings.warn( - f"""Starting from gym v0.26, render modes are determined during the initialization of the environment. + f"""Starting from gymnasium v0.26, render modes are determined during the initialization of the environment. We allow to pass a mode argument to maintain a backwards compatible VecEnv API, but the mode ({mode}) has to be the same as the environment render mode ({self.render_mode}) which is not the case.""" ) diff --git a/stable_baselines3/common/vec_env/dummy_vec_env.py b/stable_baselines3/common/vec_env/dummy_vec_env.py index 8dcf26c010..da429e8bcc 100644 --- a/stable_baselines3/common/vec_env/dummy_vec_env.py +++ b/stable_baselines3/common/vec_env/dummy_vec_env.py @@ -3,17 +3,17 @@ from copy import deepcopy from typing import Any, Callable, List, Optional, Sequence, Type, Union -import gym +import gymnasium as gym import numpy as np from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvIndices, VecEnvObs, VecEnvStepReturn from stable_baselines3.common.vec_env.util import copy_obs_dict, dict_to_obs, obs_space_info - +from stable_baselines3.common.vec_env.patch_gym import _patch_env_generator class DummyVecEnv(VecEnv): """ Creates a simple vectorized wrapper for multiple environments, calling each environment in sequence on the current - Python process. This is useful for computationally simple environment such as ``cartpole-v1``, + Python process. This is useful for computationally simple environment such as ``Cartpole-v1``, as the overhead of multiprocess or multithread outweighs the environment computation time. This can also be used for RL methods that require a vectorized environment, but that you want a single environments to train with. @@ -24,7 +24,7 @@ class DummyVecEnv(VecEnv): """ def __init__(self, env_fns: List[Callable[[], gym.Env]]): - self.envs = [fn() for fn in env_fns] + self.envs = [_patch_env_generator(fn)() for fn in env_fns] if len(set([id(env.unwrapped) for env in self.envs])) != len(self.envs): raise ValueError( "You tried to create multiple environments, but the function to create them returned the same instance " diff --git a/stable_baselines3/common/vec_env/stacked_observations.py b/stable_baselines3/common/vec_env/stacked_observations.py index 555c6f23fe..883dc84191 100644 --- a/stable_baselines3/common/vec_env/stacked_observations.py +++ b/stable_baselines3/common/vec_env/stacked_observations.py @@ -2,7 +2,7 @@ from typing import Any, Dict, Generic, List, Mapping, Optional, Tuple, TypeVar, Union import numpy as np -from gym import spaces +from gymnasium import spaces from stable_baselines3.common.preprocessing import is_image_space, is_image_space_channels_first diff --git a/stable_baselines3/common/vec_env/subproc_vec_env.py b/stable_baselines3/common/vec_env/subproc_vec_env.py index 9fac06f3f6..6aea50abe2 100644 --- a/stable_baselines3/common/vec_env/subproc_vec_env.py +++ b/stable_baselines3/common/vec_env/subproc_vec_env.py @@ -3,9 +3,10 @@ from collections import OrderedDict from typing import Any, Callable, List, Optional, Sequence, Tuple, Type, Union -import gym +import gymnasium as gym import numpy as np -from gym import spaces +from gymnasium import spaces +from stable_baselines3.common.vec_env.patch_gym import _patch_env_generator from stable_baselines3.common.vec_env.base_vec_env import ( CloudpickleWrapper, @@ -26,7 +27,7 @@ def _worker( # noqa: C901 from stable_baselines3.common.utils import compat_gym_seed parent_remote.close() - env = env_fn_wrapper.var() + env = _patch_env_generator(env_fn_wrapper.var)() reset_info = {} while True: try: diff --git a/stable_baselines3/common/vec_env/util.py b/stable_baselines3/common/vec_env/util.py index 7d318acffd..6d55db8179 100644 --- a/stable_baselines3/common/vec_env/util.py +++ b/stable_baselines3/common/vec_env/util.py @@ -5,7 +5,7 @@ from typing import Any, Dict, List, Tuple import numpy as np -from gym import spaces +from gymnasium import spaces from stable_baselines3.common.preprocessing import check_for_nested_spaces from stable_baselines3.common.vec_env.base_vec_env import VecEnvObs diff --git a/stable_baselines3/common/vec_env/vec_frame_stack.py b/stable_baselines3/common/vec_env/vec_frame_stack.py index 75c80e9e85..200201f060 100644 --- a/stable_baselines3/common/vec_env/vec_frame_stack.py +++ b/stable_baselines3/common/vec_env/vec_frame_stack.py @@ -1,7 +1,7 @@ from typing import Any, Dict, List, Mapping, Optional, Tuple, Union import numpy as np -from gym import spaces +from gymnasium import spaces from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvWrapper from stable_baselines3.common.vec_env.stacked_observations import StackedObservations diff --git a/stable_baselines3/common/vec_env/vec_normalize.py b/stable_baselines3/common/vec_env/vec_normalize.py index 8b984133d2..3514bf1d64 100644 --- a/stable_baselines3/common/vec_env/vec_normalize.py +++ b/stable_baselines3/common/vec_env/vec_normalize.py @@ -3,7 +3,7 @@ from typing import Any, Dict, List, Optional, Union import numpy as np -from gym import spaces +from gymnasium import spaces from stable_baselines3.common import utils from stable_baselines3.common.preprocessing import is_image_space diff --git a/stable_baselines3/common/vec_env/vec_transpose.py b/stable_baselines3/common/vec_env/vec_transpose.py index b6b0ad832f..beb6039617 100644 --- a/stable_baselines3/common/vec_env/vec_transpose.py +++ b/stable_baselines3/common/vec_env/vec_transpose.py @@ -2,7 +2,7 @@ from typing import Dict, Union import numpy as np -from gym import spaces +from gymnasium import spaces from stable_baselines3.common.preprocessing import is_image_space, is_image_space_channels_first from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvStepReturn, VecEnvWrapper diff --git a/stable_baselines3/common/vec_env/vec_video_recorder.py b/stable_baselines3/common/vec_env/vec_video_recorder.py index db69994004..6f670054ba 100644 --- a/stable_baselines3/common/vec_env/vec_video_recorder.py +++ b/stable_baselines3/common/vec_env/vec_video_recorder.py @@ -1,7 +1,7 @@ import os from typing import Callable -from gym.wrappers.monitoring import video_recorder +from gymnasium.wrappers.monitoring import video_recorder from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvObs, VecEnvStepReturn, VecEnvWrapper from stable_baselines3.common.vec_env.dummy_vec_env import DummyVecEnv diff --git a/stable_baselines3/dqn/dqn.py b/stable_baselines3/dqn/dqn.py index ea1946ad71..fb83f83812 100644 --- a/stable_baselines3/dqn/dqn.py +++ b/stable_baselines3/dqn/dqn.py @@ -3,7 +3,7 @@ import numpy as np import torch as th -from gym import spaces +from gymnasium import spaces from torch.nn import functional as F from stable_baselines3.common.buffers import ReplayBuffer diff --git a/stable_baselines3/dqn/policies.py b/stable_baselines3/dqn/policies.py index 22e6d0a956..5843873e26 100644 --- a/stable_baselines3/dqn/policies.py +++ b/stable_baselines3/dqn/policies.py @@ -1,7 +1,7 @@ from typing import Any, Dict, List, Optional, Type import torch as th -from gym import spaces +from gymnasium import spaces from torch import nn from stable_baselines3.common.policies import BasePolicy diff --git a/stable_baselines3/ppo/ppo.py b/stable_baselines3/ppo/ppo.py index 3ea67567ee..cdf8a9297e 100644 --- a/stable_baselines3/ppo/ppo.py +++ b/stable_baselines3/ppo/ppo.py @@ -3,7 +3,7 @@ import numpy as np import torch as th -from gym import spaces +from gymnasium import spaces from torch.nn import functional as F from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm diff --git a/stable_baselines3/sac/policies.py b/stable_baselines3/sac/policies.py index e756097b19..418e5cc227 100644 --- a/stable_baselines3/sac/policies.py +++ b/stable_baselines3/sac/policies.py @@ -1,7 +1,7 @@ from typing import Any, Dict, List, Optional, Tuple, Type, Union import torch as th -from gym import spaces +from gymnasium import spaces from torch import nn from stable_baselines3.common.distributions import SquashedDiagGaussianDistribution, StateDependentNoiseDistribution diff --git a/stable_baselines3/sac/sac.py b/stable_baselines3/sac/sac.py index d1a6610354..855f300c64 100644 --- a/stable_baselines3/sac/sac.py +++ b/stable_baselines3/sac/sac.py @@ -2,7 +2,7 @@ import numpy as np import torch as th -from gym import spaces +from gymnasium import spaces from torch.nn import functional as F from stable_baselines3.common.buffers import ReplayBuffer diff --git a/stable_baselines3/td3/policies.py b/stable_baselines3/td3/policies.py index 6c4a1e9c38..10546083a1 100644 --- a/stable_baselines3/td3/policies.py +++ b/stable_baselines3/td3/policies.py @@ -1,7 +1,7 @@ from typing import Any, Dict, List, Optional, Type, Union import torch as th -from gym import spaces +from gymnasium import spaces from torch import nn from stable_baselines3.common.policies import BasePolicy, ContinuousCritic diff --git a/stable_baselines3/td3/td3.py b/stable_baselines3/td3/td3.py index c844a99e4a..1430724af0 100644 --- a/stable_baselines3/td3/td3.py +++ b/stable_baselines3/td3/td3.py @@ -2,7 +2,7 @@ import numpy as np import torch as th -from gym import spaces +from gymnasium import spaces from torch.nn import functional as F from stable_baselines3.common.buffers import ReplayBuffer diff --git a/tests/test_buffers.py b/tests/test_buffers.py index f988f91360..5d9236173c 100644 --- a/tests/test_buffers.py +++ b/tests/test_buffers.py @@ -1,8 +1,8 @@ -import gym +import gymnasium as gym import numpy as np import pytest import torch as th -from gym import spaces +from gymnasium import spaces from stable_baselines3.common.buffers import DictReplayBuffer, DictRolloutBuffer, ReplayBuffer, RolloutBuffer from stable_baselines3.common.env_checker import check_env diff --git a/tests/test_callbacks.py b/tests/test_callbacks.py index 420a16a448..5365555a44 100644 --- a/tests/test_callbacks.py +++ b/tests/test_callbacks.py @@ -1,7 +1,7 @@ import os import shutil -import gym +import gymnasium as gym import numpy as np import pytest diff --git a/tests/test_cnn.py b/tests/test_cnn.py index 1c59d69943..e32438c270 100644 --- a/tests/test_cnn.py +++ b/tests/test_cnn.py @@ -4,7 +4,7 @@ import numpy as np import pytest import torch as th -from gym import spaces +from gymnasium import spaces from stable_baselines3 import A2C, DQN, PPO, SAC, TD3 from stable_baselines3.common.envs import FakeImageEnv diff --git a/tests/test_dict_env.py b/tests/test_dict_env.py index 42aa468a7f..bd97bb51df 100644 --- a/tests/test_dict_env.py +++ b/tests/test_dict_env.py @@ -1,9 +1,9 @@ from typing import Dict, Optional -import gym +import gymnasium as gym import numpy as np import pytest -from gym import spaces +from gymnasium import spaces from stable_baselines3 import A2C, DDPG, DQN, PPO, SAC, TD3 from stable_baselines3.common.env_checker import check_env diff --git a/tests/test_distributions.py b/tests/test_distributions.py index e782182f4c..48eae12d0c 100644 --- a/tests/test_distributions.py +++ b/tests/test_distributions.py @@ -1,7 +1,7 @@ from copy import deepcopy from typing import Tuple -import gym +import gymnasium as gym import numpy as np import pytest import torch as th diff --git a/tests/test_env_checker.py b/tests/test_env_checker.py index e55a73fa1f..5a3f377f80 100644 --- a/tests/test_env_checker.py +++ b/tests/test_env_checker.py @@ -1,7 +1,7 @@ -import gym +import gymnasium as gym import numpy as np import pytest -from gym import spaces +from gymnasium import spaces from stable_baselines3.common.env_checker import check_env diff --git a/tests/test_envs.py b/tests/test_envs.py index 82bd6a6c07..aeb248fbb0 100644 --- a/tests/test_envs.py +++ b/tests/test_envs.py @@ -1,10 +1,10 @@ import types import warnings -import gym +import gymnasium as gym import numpy as np import pytest -from gym import spaces +from gymnasium import spaces from stable_baselines3.common.env_checker import check_env from stable_baselines3.common.envs import ( diff --git a/tests/test_gae.py b/tests/test_gae.py index 35c5689596..3c22c7cd8c 100644 --- a/tests/test_gae.py +++ b/tests/test_gae.py @@ -1,10 +1,10 @@ from typing import Dict, Optional -import gym +import gymnasium as gym import numpy as np import pytest import torch as th -from gym import spaces +from gymnasium import spaces from stable_baselines3 import A2C, PPO, SAC from stable_baselines3.common.callbacks import BaseCallback diff --git a/tests/test_her.py b/tests/test_her.py index c1bc515ed6..2e385d51e3 100644 --- a/tests/test_her.py +++ b/tests/test_her.py @@ -3,7 +3,7 @@ import warnings from copy import deepcopy -import gym +import gymnasium as gym import numpy as np import pytest import torch as th diff --git a/tests/test_logger.py b/tests/test_logger.py index bd218516bb..26b5322999 100644 --- a/tests/test_logger.py +++ b/tests/test_logger.py @@ -4,11 +4,11 @@ from typing import Sequence from unittest import mock -import gym +import gymnasium as gym import numpy as np import pytest import torch as th -from gym import spaces +from gymnasium import spaces from matplotlib import pyplot as plt from pandas.errors import EmptyDataError diff --git a/tests/test_monitor.py b/tests/test_monitor.py index c580fcf49b..d847a926bb 100644 --- a/tests/test_monitor.py +++ b/tests/test_monitor.py @@ -2,7 +2,7 @@ import os import uuid -import gym +import gymnasium as gym import pandas from stable_baselines3.common.monitor import Monitor, get_monitor_files, load_results diff --git a/tests/test_predict.py b/tests/test_predict.py index 4f8c3d1e9c..247fe91725 100644 --- a/tests/test_predict.py +++ b/tests/test_predict.py @@ -1,8 +1,8 @@ -import gym +import gymnasium as gym import numpy as np import pytest import torch as th -from gym import spaces +from gymnasium import spaces from stable_baselines3 import A2C, DQN, PPO, SAC, TD3 from stable_baselines3.common.env_checker import check_env diff --git a/tests/test_preprocessing.py b/tests/test_preprocessing.py index 89f869b453..b8a5891c7f 100644 --- a/tests/test_preprocessing.py +++ b/tests/test_preprocessing.py @@ -1,5 +1,5 @@ import torch -from gym import spaces +from gymnasium import spaces from stable_baselines3.common.preprocessing import get_obs_shape, preprocess_obs diff --git a/tests/test_run.py b/tests/test_run.py index ca7548ff96..31c7b956ed 100644 --- a/tests/test_run.py +++ b/tests/test_run.py @@ -1,4 +1,4 @@ -import gym +import gymnasium as gym import numpy as np import pytest diff --git a/tests/test_save_load.py b/tests/test_save_load.py index 9d3d537b76..2f227adf6a 100644 --- a/tests/test_save_load.py +++ b/tests/test_save_load.py @@ -8,7 +8,7 @@ from collections import OrderedDict from copy import deepcopy -import gym +import gymnasium as gym import numpy as np import pytest import torch as th diff --git a/tests/test_spaces.py b/tests/test_spaces.py index 82fdc636a3..6d18fcef84 100644 --- a/tests/test_spaces.py +++ b/tests/test_spaces.py @@ -1,9 +1,9 @@ from typing import Dict, Optional -import gym +import gymnasium as gym import numpy as np import pytest -from gym import spaces +from gymnasium import spaces from stable_baselines3 import A2C, DDPG, DQN, PPO, SAC, TD3 from stable_baselines3.common.env_checker import check_env diff --git a/tests/test_train_eval_mode.py b/tests/test_train_eval_mode.py index f3a012fd4e..dcbda74e1c 100644 --- a/tests/test_train_eval_mode.py +++ b/tests/test_train_eval_mode.py @@ -1,6 +1,6 @@ from typing import Union -import gym +import gymnasium as gym import numpy as np import pytest import torch as th diff --git a/tests/test_utils.py b/tests/test_utils.py index 0992c3572a..2fd83b3cc8 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,11 +1,11 @@ import os import shutil -import gym +import gymnasium as gym import numpy as np import pytest import torch as th -from gym import spaces +from gymnasium import spaces import stable_baselines3 as sb3 from stable_baselines3 import A2C diff --git a/tests/test_vec_check_nan.py b/tests/test_vec_check_nan.py index 48b203e891..1253be6e52 100644 --- a/tests/test_vec_check_nan.py +++ b/tests/test_vec_check_nan.py @@ -1,7 +1,7 @@ -import gym +import gymnasium as gym import numpy as np import pytest -from gym import spaces +from gymnasium import spaces from stable_baselines3.common.vec_env import DummyVecEnv, VecCheckNan diff --git a/tests/test_vec_envs.py b/tests/test_vec_envs.py index 6129438fe6..c3073f21d1 100644 --- a/tests/test_vec_envs.py +++ b/tests/test_vec_envs.py @@ -6,10 +6,10 @@ import warnings from typing import Dict, Optional -import gym +import gymnasium as gym import numpy as np import pytest -from gym import spaces +from gymnasium import spaces from stable_baselines3.common.env_util import make_vec_env from stable_baselines3.common.monitor import Monitor diff --git a/tests/test_vec_extract_dict_obs.py b/tests/test_vec_extract_dict_obs.py index 6aa4abdbde..23c234bec4 100644 --- a/tests/test_vec_extract_dict_obs.py +++ b/tests/test_vec_extract_dict_obs.py @@ -1,5 +1,5 @@ import numpy as np -from gym import spaces +from gymnasium import spaces from stable_baselines3 import PPO from stable_baselines3.common.vec_env import VecExtractDictObs, VecMonitor diff --git a/tests/test_vec_monitor.py b/tests/test_vec_monitor.py index ab988b6b40..1a0e94d909 100644 --- a/tests/test_vec_monitor.py +++ b/tests/test_vec_monitor.py @@ -4,7 +4,7 @@ import uuid import warnings -import gym +import gymnasium as gym import pandas import pytest diff --git a/tests/test_vec_normalize.py b/tests/test_vec_normalize.py index 0de43d1531..fb4e76152d 100644 --- a/tests/test_vec_normalize.py +++ b/tests/test_vec_normalize.py @@ -1,10 +1,10 @@ import operator from typing import Any, Dict, Optional -import gym +import gymnasium as gym import numpy as np import pytest -from gym import spaces +from gymnasium import spaces from stable_baselines3 import SAC, TD3, HerReplayBuffer from stable_baselines3.common.monitor import Monitor diff --git a/tests/test_vec_stacked_obs.py b/tests/test_vec_stacked_obs.py index 0a7aa39f10..4b2c614447 100644 --- a/tests/test_vec_stacked_obs.py +++ b/tests/test_vec_stacked_obs.py @@ -1,5 +1,5 @@ import numpy as np -from gym import spaces +from gymnasium import spaces from stable_baselines3.common.vec_env.stacked_observations import StackedObservations From 45b20424b13882aecca1c14c90ba54d27e1701db Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Sat, 11 Feb 2023 18:02:05 +0100 Subject: [PATCH 104/153] Rely on signature instead of version --- setup.py | 2 +- .../common/vec_env/dummy_vec_env.py | 3 ++- stable_baselines3/common/vec_env/patch_gym.py | 16 +++++++--------- .../common/vec_env/subproc_vec_env.py | 2 +- 4 files changed, 11 insertions(+), 12 deletions(-) diff --git a/setup.py b/setup.py index cc83347291..2b474efdf4 100644 --- a/setup.py +++ b/setup.py @@ -76,7 +76,7 @@ packages=[package for package in find_packages() if package.startswith("stable_baselines3")], package_data={"stable_baselines3": ["py.typed", "version.txt"]}, install_requires=[ - "gymnasium==0.26.2", + "gymnasium==0.27.1", "numpy", "torch>=1.11", 'typing_extensions>=4.0,<5; python_version < "3.8.0"', diff --git a/stable_baselines3/common/vec_env/dummy_vec_env.py b/stable_baselines3/common/vec_env/dummy_vec_env.py index da429e8bcc..3d79c93f84 100644 --- a/stable_baselines3/common/vec_env/dummy_vec_env.py +++ b/stable_baselines3/common/vec_env/dummy_vec_env.py @@ -7,8 +7,9 @@ import numpy as np from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvIndices, VecEnvObs, VecEnvStepReturn -from stable_baselines3.common.vec_env.util import copy_obs_dict, dict_to_obs, obs_space_info from stable_baselines3.common.vec_env.patch_gym import _patch_env_generator +from stable_baselines3.common.vec_env.util import copy_obs_dict, dict_to_obs, obs_space_info + class DummyVecEnv(VecEnv): """ diff --git a/stable_baselines3/common/vec_env/patch_gym.py b/stable_baselines3/common/vec_env/patch_gym.py index 0477d41e8f..2f236fbe87 100644 --- a/stable_baselines3/common/vec_env/patch_gym.py +++ b/stable_baselines3/common/vec_env/patch_gym.py @@ -1,4 +1,5 @@ import warnings +from inspect import signature from typing import Callable import gymnasium @@ -10,9 +11,10 @@ except ImportError: gym_installed = False + def _patch_env_generator(env_fn: Callable[[], gymnasium.Env]) -> Callable[[], gymnasium.Env]: """ - Taken from https://github.com/thu-ml/tianshou. + Adapted from https://github.com/thu-ml/tianshou. Takes an environment generator and patches it to return Gymnasium envs. This function takes the environment generator ``env_fn`` and returns a patched @@ -58,16 +60,12 @@ def patched() -> gymnasium.Env: "layer, which could potentially cause issues." ) - # gym version only goes to 0.26.2 - gym_version = int(gym.__version__.split(".")[1]) - if gym_version >= 26: + if "seed" in signature(env.unwrapped.reset).parameters: + # Gym 0.26+ env return shimmy.GymV26CompatibilityV0(env=env) - elif gym_version >= 21: + else: + # Gym 0.21 env # TODO: rename to GymV21CompatibilityV0 return shimmy.GymV22CompatibilityV0(env=env) - else: - raise Exception( - f"Found OpenAI Gym version {gym.__version__}. " f"SB3 only supports OpenAI Gym environments of version>=0.21.0" - ) return patched diff --git a/stable_baselines3/common/vec_env/subproc_vec_env.py b/stable_baselines3/common/vec_env/subproc_vec_env.py index 6aea50abe2..5b77fd85c7 100644 --- a/stable_baselines3/common/vec_env/subproc_vec_env.py +++ b/stable_baselines3/common/vec_env/subproc_vec_env.py @@ -6,7 +6,6 @@ import gymnasium as gym import numpy as np from gymnasium import spaces -from stable_baselines3.common.vec_env.patch_gym import _patch_env_generator from stable_baselines3.common.vec_env.base_vec_env import ( CloudpickleWrapper, @@ -15,6 +14,7 @@ VecEnvObs, VecEnvStepReturn, ) +from stable_baselines3.common.vec_env.patch_gym import _patch_env_generator def _worker( # noqa: C901 From cf1ec25e0339591adfd9c47d148434c28eca0c82 Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Sat, 11 Feb 2023 18:22:37 +0100 Subject: [PATCH 105/153] More patches --- stable_baselines3/common/base_class.py | 3 + stable_baselines3/common/env_util.py | 4 + stable_baselines3/common/utils.py | 2 +- .../common/vec_env/dummy_vec_env.py | 4 +- stable_baselines3/common/vec_env/patch_gym.py | 88 +++++++++---------- .../common/vec_env/subproc_vec_env.py | 4 +- 6 files changed, 52 insertions(+), 53 deletions(-) diff --git a/stable_baselines3/common/base_class.py b/stable_baselines3/common/base_class.py index c7ee58efcc..3621bb7bf4 100644 --- a/stable_baselines3/common/base_class.py +++ b/stable_baselines3/common/base_class.py @@ -39,6 +39,7 @@ is_vecenv_wrapped, unwrap_vec_normalize, ) +from stable_baselines3.common.vec_env.patch_gym import _patch_env SelfBaseAlgorithm = TypeVar("SelfBaseAlgorithm", bound="BaseAlgorithm") @@ -204,6 +205,8 @@ def _wrap_env(env: GymEnv, verbose: int = 0, monitor_wrapper: bool = True) -> Ve :return: The wrapped environment. """ if not isinstance(env, VecEnv): + # Patch to support gym 0.21/0.26 and gymnasium + env = _patch_env(env) if not is_wrapped(env, Monitor) and monitor_wrapper: if verbose >= 1: print("Wrapping the env with a `Monitor` wrapper") diff --git a/stable_baselines3/common/env_util.py b/stable_baselines3/common/env_util.py index d9a13c9a3b..0ed44607fb 100644 --- a/stable_baselines3/common/env_util.py +++ b/stable_baselines3/common/env_util.py @@ -7,6 +7,7 @@ from stable_baselines3.common.monitor import Monitor from stable_baselines3.common.utils import compat_gym_seed from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv, VecEnv +from stable_baselines3.common.vec_env.patch_gym import _patch_env def unwrap_wrapper(env: gym.Env, wrapper_class: Type[gym.Wrapper]) -> Optional[gym.Wrapper]: @@ -96,6 +97,9 @@ def _init() -> gym.Env: env = gym.make(env_id, **env_kwargs) else: env = env_id(**env_kwargs) + # Patch to support gym 0.21/0.26 and gymnasium + env = _patch_env(env) + if seed is not None: compat_gym_seed(env, seed=seed + rank) env.action_space.seed(seed + rank) diff --git a/stable_baselines3/common/utils.py b/stable_baselines3/common/utils.py index 9978df739f..eaf4cd6bbc 100644 --- a/stable_baselines3/common/utils.py +++ b/stable_baselines3/common/utils.py @@ -533,7 +533,7 @@ def compat_gym_seed(env: GymEnv, seed: int) -> None: :param env: The Gym environment. :param seed: The seed for the pseudo random generator """ - if isinstance(env, gym.Env) and "seed" in signature(env.unwrapped.reset).parameters: + if "seed" in signature(env.unwrapped.reset).parameters: # gym >= 0.23.1 env.reset(seed=seed) else: diff --git a/stable_baselines3/common/vec_env/dummy_vec_env.py b/stable_baselines3/common/vec_env/dummy_vec_env.py index 3d79c93f84..85ecbb479d 100644 --- a/stable_baselines3/common/vec_env/dummy_vec_env.py +++ b/stable_baselines3/common/vec_env/dummy_vec_env.py @@ -7,7 +7,7 @@ import numpy as np from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvIndices, VecEnvObs, VecEnvStepReturn -from stable_baselines3.common.vec_env.patch_gym import _patch_env_generator +from stable_baselines3.common.vec_env.patch_gym import _patch_env from stable_baselines3.common.vec_env.util import copy_obs_dict, dict_to_obs, obs_space_info @@ -25,7 +25,7 @@ class DummyVecEnv(VecEnv): """ def __init__(self, env_fns: List[Callable[[], gym.Env]]): - self.envs = [_patch_env_generator(fn)() for fn in env_fns] + self.envs = [_patch_env(fn()) for fn in env_fns] if len(set([id(env.unwrapped) for env in self.envs])) != len(self.envs): raise ValueError( "You tried to create multiple environments, but the function to create them returned the same instance " diff --git a/stable_baselines3/common/vec_env/patch_gym.py b/stable_baselines3/common/vec_env/patch_gym.py index 2f236fbe87..6fe346d8b7 100644 --- a/stable_baselines3/common/vec_env/patch_gym.py +++ b/stable_baselines3/common/vec_env/patch_gym.py @@ -1,6 +1,6 @@ import warnings from inspect import signature -from typing import Callable +from typing import Union import gymnasium @@ -12,60 +12,52 @@ gym_installed = False -def _patch_env_generator(env_fn: Callable[[], gymnasium.Env]) -> Callable[[], gymnasium.Env]: +def _patch_env(env: Union[gym.Env, gymnasium.Env]) -> gymnasium.Env: """ Adapted from https://github.com/thu-ml/tianshou. - Takes an environment generator and patches it to return Gymnasium envs. - This function takes the environment generator ``env_fn`` and returns a patched - generator, without invoking ``env_fn``. The original generator may return - Gymnasium or OpenAI Gym environments, but the patched generator wraps - the result of ``env_fn`` in a shimmy wrapper to convert it to Gymnasium, + Takes an environment and patches it to return Gymnasium env. + This function takes the environment object and returns a patched + env, using shimmy wrapper to convert it to Gymnasium, if necessary. - :param env_fn: a function that returns an environment - :return: Patched generator + :param env: A gym/gymnasium env + :return: Patched env (gymnasium env) """ - def patched() -> gymnasium.Env: - env = env_fn() + # Gymnasium env, no patching to be done + if isinstance(env, gymnasium.Env): + return env - # Gymnasium env, no patching to be done - if isinstance(env, gymnasium.Env): - return env - - if not gym_installed or not isinstance(env, gym.Env): - raise ValueError( - f"Environment generator returned a {type(env)}, not a Gymnasium " - f"environment. In this case, we expect OpenAI Gym to be " - f"installed and the environment to be an OpenAI Gym environment." - ) - - try: - import shimmy - except ImportError as e: - raise ImportError( - "Missing shimmy installation. You provided an environment generator " - "that returned an OpenAI Gym environment. " - "Stable-Baselines3 (SB3) has transitioned to using Gymnasium internally. " - "In order to use OpenAI Gym environments with SB3, you need to " - "install shimmy (`pip install shimmy`)." - ) from e - - warnings.warn( - "You provided an environment generator that returned an OpenAI Gym " - "environment. We strongly recommend transitioning to Gymnasium " - "environments. " - "Stable-Baselines3 is automatically wrapping your environments in a compatibility " - "layer, which could potentially cause issues." + if not gym_installed or not isinstance(env, gym.Env): + raise ValueError( + f"Environment generator returned a {type(env)}, not a Gymnasium " + f"environment. In this case, we expect OpenAI Gym to be " + f"installed and the environment to be an OpenAI Gym environment." ) - if "seed" in signature(env.unwrapped.reset).parameters: - # Gym 0.26+ env - return shimmy.GymV26CompatibilityV0(env=env) - else: - # Gym 0.21 env - # TODO: rename to GymV21CompatibilityV0 - return shimmy.GymV22CompatibilityV0(env=env) - - return patched + try: + import shimmy + except ImportError as e: + raise ImportError( + "Missing shimmy installation. You provided an environment generator " + "that returned an OpenAI Gym environment. " + "Stable-Baselines3 (SB3) has transitioned to using Gymnasium internally. " + "In order to use OpenAI Gym environments with SB3, you need to " + "install shimmy (`pip install shimmy`)." + ) from e + + warnings.warn( + "You provided an environment generator that returned an OpenAI Gym " + "environment. We strongly recommend transitioning to Gymnasium " + "environments. " + "Stable-Baselines3 is automatically wrapping your environments in a compatibility " + "layer, which could potentially cause issues." + ) + + if "seed" in signature(env.unwrapped.reset).parameters: + # Gym 0.26+ env + return shimmy.GymV26CompatibilityV0(env=env) + # Gym 0.21 env + # TODO: rename to GymV21CompatibilityV0 + return shimmy.GymV22CompatibilityV0(env=env) diff --git a/stable_baselines3/common/vec_env/subproc_vec_env.py b/stable_baselines3/common/vec_env/subproc_vec_env.py index 5b77fd85c7..15c9d7d87c 100644 --- a/stable_baselines3/common/vec_env/subproc_vec_env.py +++ b/stable_baselines3/common/vec_env/subproc_vec_env.py @@ -14,7 +14,7 @@ VecEnvObs, VecEnvStepReturn, ) -from stable_baselines3.common.vec_env.patch_gym import _patch_env_generator +from stable_baselines3.common.vec_env.patch_gym import _patch_env def _worker( # noqa: C901 @@ -27,7 +27,7 @@ def _worker( # noqa: C901 from stable_baselines3.common.utils import compat_gym_seed parent_remote.close() - env = _patch_env_generator(env_fn_wrapper.var)() + env = _patch_env(env_fn_wrapper.var()) reset_info = {} while True: try: From 951baeeeaa430113a85e30a0a21e9a8dc7d57ab8 Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Sat, 11 Feb 2023 18:41:50 +0100 Subject: [PATCH 106/153] Type ignore because of https://github.com/Farama-Foundation/Gymnasium/pull/39 --- setup.cfg | 2 ++ stable_baselines3/common/env_checker.py | 2 +- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/setup.cfg b/setup.cfg index a8aa96f9ab..b2445953cc 100644 --- a/setup.cfg +++ b/setup.cfg @@ -28,6 +28,7 @@ follow_imports = silent show_error_codes = True exclude = (?x)( stable_baselines3/a2c/a2c.py$ + | stable_baselines3/common/atari_wrappers.py$ | stable_baselines3/common/base_class.py$ | stable_baselines3/common/buffers.py$ | stable_baselines3/common/callbacks.py$ @@ -35,6 +36,7 @@ exclude = (?x)( | stable_baselines3/common/envs/bit_flipping_env.py$ | stable_baselines3/common/envs/identity_env.py$ | stable_baselines3/common/envs/multi_input_envs.py$ + | stable_baselines3/common/monitor.py$ | stable_baselines3/common/logger.py$ | stable_baselines3/common/off_policy_algorithm.py$ | stable_baselines3/common/on_policy_algorithm.py$ diff --git a/stable_baselines3/common/env_checker.py b/stable_baselines3/common/env_checker.py index 2a44b2a4c7..75bd07025b 100644 --- a/stable_baselines3/common/env_checker.py +++ b/stable_baselines3/common/env_checker.py @@ -256,7 +256,7 @@ def _check_returned_values(env: gym.Env, observation_space: spaces.Space, action # Make mypy happy, already checked assert isinstance(observation_space, spaces.Dict) _check_goal_env_obs(obs, observation_space, "step") - _check_goal_env_compute_reward(obs, env, reward, info) + _check_goal_env_compute_reward(obs, env, reward, info) # type: ignore[arg-type] elif isinstance(observation_space, spaces.Dict): assert isinstance(obs, dict), "The observation returned by `step()` must be a dictionary" From 25e70b4377737275da398c640e5e76240ec35a14 Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Sat, 11 Feb 2023 18:58:32 +0100 Subject: [PATCH 107/153] Fix doc build --- stable_baselines3/common/vec_env/patch_gym.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/stable_baselines3/common/vec_env/patch_gym.py b/stable_baselines3/common/vec_env/patch_gym.py index 6fe346d8b7..9513dfad6a 100644 --- a/stable_baselines3/common/vec_env/patch_gym.py +++ b/stable_baselines3/common/vec_env/patch_gym.py @@ -12,7 +12,7 @@ gym_installed = False -def _patch_env(env: Union[gym.Env, gymnasium.Env]) -> gymnasium.Env: +def _patch_env(env: Union["gym.Env", gymnasium.Env]) -> gymnasium.Env: """ Adapted from https://github.com/thu-ml/tianshou. From f23f43acb6fd311a9eaf2165c482f8672d9a7a92 Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Sat, 11 Feb 2023 19:30:40 +0100 Subject: [PATCH 108/153] Fix pytype errors --- stable_baselines3/common/vec_env/patch_gym.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/stable_baselines3/common/vec_env/patch_gym.py b/stable_baselines3/common/vec_env/patch_gym.py index 9513dfad6a..c808a2cbdb 100644 --- a/stable_baselines3/common/vec_env/patch_gym.py +++ b/stable_baselines3/common/vec_env/patch_gym.py @@ -5,7 +5,7 @@ import gymnasium try: - import gym + import gym # pytype: disable=import-error gym_installed = True except ImportError: @@ -37,7 +37,7 @@ def _patch_env(env: Union["gym.Env", gymnasium.Env]) -> gymnasium.Env: ) try: - import shimmy + import shimmy # pytype: disable=import-error except ImportError as e: raise ImportError( "Missing shimmy installation. You provided an environment generator " From 1cc406e4643abc9940729ab7c6513f977deaa081 Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Mon, 13 Feb 2023 10:08:32 +0100 Subject: [PATCH 109/153] Fix atari requirement --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 2b474efdf4..20c5635a0e 100644 --- a/setup.py +++ b/setup.py @@ -124,7 +124,7 @@ "opencv-python", "pygame", # For atari games, - "ale-py~=0.8.0", + "shimmy[atari]~=0.2", "autorom[accept-rom-license]~=0.4.2", "pillow", # Tensorboard support From bfcb0ee2c4f2c3b055ac1074e685a471e1d4f45d Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Mon, 13 Feb 2023 13:41:29 +0100 Subject: [PATCH 110/153] Update env checker due to change in dtype for Discrete --- stable_baselines3/common/env_checker.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/stable_baselines3/common/env_checker.py b/stable_baselines3/common/env_checker.py index 75bd07025b..9bedb4855c 100644 --- a/stable_baselines3/common/env_checker.py +++ b/stable_baselines3/common/env_checker.py @@ -182,7 +182,9 @@ def _check_obs(obs: Union[tuple, dict, np.ndarray, int], observation_space: spac # The check for a GoalEnv is done by the base class if isinstance(observation_space, spaces.Discrete): - assert isinstance(obs, int), f"The observation returned by `{method_name}()` method must be an int" + # Since https://github.com/Farama-Foundation/Gymnasium/pull/141, + # `sample()` will return a np.int64 instead of an int + assert np.issubdtype(obs, np.integer), f"The observation returned by `{method_name}()` method must be an int" elif _is_numpy_array_space(observation_space): assert isinstance(obs, np.ndarray), f"The observation returned by `{method_name}()` method must be a numpy array" From 360625a31528eed2a2810fd14d75c6941b07b536 Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Mon, 13 Feb 2023 14:28:33 +0100 Subject: [PATCH 111/153] Fix type hint --- stable_baselines3/common/env_checker.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/stable_baselines3/common/env_checker.py b/stable_baselines3/common/env_checker.py index 9bedb4855c..dc785e405d 100644 --- a/stable_baselines3/common/env_checker.py +++ b/stable_baselines3/common/env_checker.py @@ -184,7 +184,7 @@ def _check_obs(obs: Union[tuple, dict, np.ndarray, int], observation_space: spac if isinstance(observation_space, spaces.Discrete): # Since https://github.com/Farama-Foundation/Gymnasium/pull/141, # `sample()` will return a np.int64 instead of an int - assert np.issubdtype(obs, np.integer), f"The observation returned by `{method_name}()` method must be an int" + assert np.issubdtype(type(obs), np.integer), f"The observation returned by `{method_name}()` method must be an int" elif _is_numpy_array_space(observation_space): assert isinstance(obs, np.ndarray), f"The observation returned by `{method_name}()` method must be a numpy array" From f0147740b07af273e09a3984f439e4f6175b5343 Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Mon, 13 Feb 2023 16:12:32 +0100 Subject: [PATCH 112/153] Convert spaces for saved models --- stable_baselines3/common/base_class.py | 6 ++- stable_baselines3/common/vec_env/patch_gym.py | 52 ++++++++++++++++--- 2 files changed, 50 insertions(+), 8 deletions(-) diff --git a/stable_baselines3/common/base_class.py b/stable_baselines3/common/base_class.py index 3621bb7bf4..9cfab2a837 100644 --- a/stable_baselines3/common/base_class.py +++ b/stable_baselines3/common/base_class.py @@ -39,7 +39,7 @@ is_vecenv_wrapped, unwrap_vec_normalize, ) -from stable_baselines3.common.vec_env.patch_gym import _patch_env +from stable_baselines3.common.vec_env.patch_gym import _convert_space, _patch_env SelfBaseAlgorithm = TypeVar("SelfBaseAlgorithm", bound="BaseAlgorithm") @@ -689,6 +689,10 @@ def load( if "observation_space" not in data or "action_space" not in data: raise KeyError("The observation_space and action_space were not given, can't verify new environments") + # Gym -> Gymnasium space conversion + for key in {"observation_space", "action_space"}: + data[key] = _convert_space(data[key]) + if env is not None: # Wrap first if needed env = cls._wrap_env(env, data["verbose"]) diff --git a/stable_baselines3/common/vec_env/patch_gym.py b/stable_baselines3/common/vec_env/patch_gym.py index c808a2cbdb..caae26bab2 100644 --- a/stable_baselines3/common/vec_env/patch_gym.py +++ b/stable_baselines3/common/vec_env/patch_gym.py @@ -12,7 +12,7 @@ gym_installed = False -def _patch_env(env: Union["gym.Env", gymnasium.Env]) -> gymnasium.Env: +def _patch_env(env: Union["gym.Env", gymnasium.Env]) -> gymnasium.Env: # pragma: no cover """ Adapted from https://github.com/thu-ml/tianshou. @@ -31,7 +31,7 @@ def _patch_env(env: Union["gym.Env", gymnasium.Env]) -> gymnasium.Env: if not gym_installed or not isinstance(env, gym.Env): raise ValueError( - f"Environment generator returned a {type(env)}, not a Gymnasium " + f"The environment is of type {type(env)}, not a Gymnasium " f"environment. In this case, we expect OpenAI Gym to be " f"installed and the environment to be an OpenAI Gym environment." ) @@ -40,17 +40,15 @@ def _patch_env(env: Union["gym.Env", gymnasium.Env]) -> gymnasium.Env: import shimmy # pytype: disable=import-error except ImportError as e: raise ImportError( - "Missing shimmy installation. You provided an environment generator " - "that returned an OpenAI Gym environment. " + "Missing shimmy installation. You an OpenAI Gym environment. " "Stable-Baselines3 (SB3) has transitioned to using Gymnasium internally. " "In order to use OpenAI Gym environments with SB3, you need to " "install shimmy (`pip install shimmy`)." ) from e warnings.warn( - "You provided an environment generator that returned an OpenAI Gym " - "environment. We strongly recommend transitioning to Gymnasium " - "environments. " + "You provided an OpenAI Gym environment. " + "We strongly recommend transitioning to Gymnasium environments. " "Stable-Baselines3 is automatically wrapping your environments in a compatibility " "layer, which could potentially cause issues." ) @@ -61,3 +59,43 @@ def _patch_env(env: Union["gym.Env", gymnasium.Env]) -> gymnasium.Env: # Gym 0.21 env # TODO: rename to GymV21CompatibilityV0 return shimmy.GymV22CompatibilityV0(env=env) + + +def _convert_space(space: Union["gym.Space", gymnasium.Space]) -> gymnasium.Space: # pragma: no cover + """ + Takes a space and patches it to return Gymnasium Space. + This function takes the space object and returns a patched + space, using shimmy wrapper to convert it to Gymnasium, + if necessary. + + :param env: A gym/gymnasium Space + :return: Patched space (gymnasium Space) + """ + + # Gymnasium space, no convertion to be done + if isinstance(space, gymnasium.Space): + return space + + if not gym_installed or not isinstance(space, gym.Space): + raise ValueError( + f"The space is of type {type(space)}, not a Gymnasium " + f"space. In this case, we expect OpenAI Gym to be " + f"installed and the space to be an OpenAI Gym space." + ) + + try: + import shimmy # pytype: disable=import-error + except ImportError as e: + raise ImportError( + "Missing shimmy installation. You provided an OpenAI Gym space. " + "Stable-Baselines3 (SB3) has transitioned to using Gymnasium internally. " + "In order to use OpenAI Gym space with SB3, you need to " + "install shimmy (`pip install shimmy`)." + ) from e + + warnings.warn( + "You loaded a model that was trained using OpenAI Gym. " + "We strongly recommend transitioning to Gymnasium by saving that model again." + ) + + return shimmy.openai_gym_compatibility._convert_space(space) From fefe177e584f3d8f6ef4eef1e7eb1be04d891eb6 Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Mon, 13 Feb 2023 16:27:13 +0100 Subject: [PATCH 113/153] Ignore pytype --- stable_baselines3/common/base_class.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/stable_baselines3/common/base_class.py b/stable_baselines3/common/base_class.py index 9cfab2a837..e5448727dd 100644 --- a/stable_baselines3/common/base_class.py +++ b/stable_baselines3/common/base_class.py @@ -691,7 +691,7 @@ def load( # Gym -> Gymnasium space conversion for key in {"observation_space", "action_space"}: - data[key] = _convert_space(data[key]) + data[key] = _convert_space(data[key]) # type: disable=unsupported-operands if env is not None: # Wrap first if needed From 1c5ca7b8f6b2c254a8636a1b36c8baf6d53dcf4c Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Mon, 13 Feb 2023 16:27:30 +0100 Subject: [PATCH 114/153] Remove gitlab CI --- .gitlab-ci.yml | 23 ----------------------- 1 file changed, 23 deletions(-) delete mode 100644 .gitlab-ci.yml diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml deleted file mode 100644 index a9e622e182..0000000000 --- a/.gitlab-ci.yml +++ /dev/null @@ -1,23 +0,0 @@ -image: stablebaselines/stable-baselines3-cpu:1.5.1a6 - -type-check: - script: - - pip install pytype mypy --upgrade - - make type - -pytest: - script: - - pip install tqdm rich # for progress bar - - python --version - # MKL_THREADING_LAYER=GNU to avoid MKL_THREADING_LAYER=INTEL incompatibility error - - MKL_THREADING_LAYER=GNU make pytest - coverage: '/^TOTAL.+?(\d+\%)$/' - -doc-build: - script: - - make doc - -lint-check: - script: - - make check-codestyle - - make lint From b730d221fdb12f9afdb120cd4dba7c6200a25a5c Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Mon, 13 Feb 2023 16:47:56 +0100 Subject: [PATCH 115/153] Disable pytype for convert space --- stable_baselines3/common/base_class.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/stable_baselines3/common/base_class.py b/stable_baselines3/common/base_class.py index e5448727dd..9d2527dd0c 100644 --- a/stable_baselines3/common/base_class.py +++ b/stable_baselines3/common/base_class.py @@ -691,7 +691,7 @@ def load( # Gym -> Gymnasium space conversion for key in {"observation_space", "action_space"}: - data[key] = _convert_space(data[key]) # type: disable=unsupported-operands + data[key] = _convert_space(data[key]) # pytype: disable=unsupported-operands if env is not None: # Wrap first if needed From a436891b1710f5bc6b77460ea188351ba214ce4b Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Mon, 13 Feb 2023 17:31:38 +0100 Subject: [PATCH 116/153] Fix undefined info --- stable_baselines3/common/atari_wrappers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/stable_baselines3/common/atari_wrappers.py b/stable_baselines3/common/atari_wrappers.py index 615b72b634..99c834e012 100644 --- a/stable_baselines3/common/atari_wrappers.py +++ b/stable_baselines3/common/atari_wrappers.py @@ -135,7 +135,7 @@ def reset(self, **kwargs) -> Tuple[np.ndarray, Dict]: obs, info = self.env.reset(**kwargs) else: # no-op step to advance from terminal/lost life state - obs, _, terminated, truncated, _ = self.env.step(0) + obs, _, terminated, truncated, info = self.env.step(0) # The no-op step can lead to a game over, so we need to check it again # to see if we should reset the environment and avoid the From 75217fa0538a437d1d280e7080f2a41c8a778cb6 Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Mon, 13 Feb 2023 17:31:38 +0100 Subject: [PATCH 117/153] Fix undefined info --- stable_baselines3/common/atari_wrappers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/stable_baselines3/common/atari_wrappers.py b/stable_baselines3/common/atari_wrappers.py index e1f1ea59bd..01ab0dc816 100644 --- a/stable_baselines3/common/atari_wrappers.py +++ b/stable_baselines3/common/atari_wrappers.py @@ -135,7 +135,7 @@ def reset(self, **kwargs) -> Tuple[np.ndarray, Dict]: obs, info = self.env.reset(**kwargs) else: # no-op step to advance from terminal/lost life state - obs, _, terminated, truncated, _ = self.env.step(0) + obs, _, terminated, truncated, info = self.env.step(0) # The no-op step can lead to a game over, so we need to check it again # to see if we should reset the environment and avoid the From 65af7c1b431cf994719d3332f3c6a150fab82e4c Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Sun, 19 Feb 2023 14:48:23 +0100 Subject: [PATCH 118/153] Upgrade shimmy --- setup.py | 2 +- stable_baselines3/common/vec_env/patch_gym.py | 7 +++---- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/setup.py b/setup.py index 20c5635a0e..6d0f62d3fb 100644 --- a/setup.py +++ b/setup.py @@ -124,7 +124,7 @@ "opencv-python", "pygame", # For atari games, - "shimmy[atari]~=0.2", + "shimmy[atari]~=0.2.1", "autorom[accept-rom-license]~=0.4.2", "pillow", # Tensorboard support diff --git a/stable_baselines3/common/vec_env/patch_gym.py b/stable_baselines3/common/vec_env/patch_gym.py index caae26bab2..b86c522364 100644 --- a/stable_baselines3/common/vec_env/patch_gym.py +++ b/stable_baselines3/common/vec_env/patch_gym.py @@ -43,7 +43,7 @@ def _patch_env(env: Union["gym.Env", gymnasium.Env]) -> gymnasium.Env: # pragma "Missing shimmy installation. You an OpenAI Gym environment. " "Stable-Baselines3 (SB3) has transitioned to using Gymnasium internally. " "In order to use OpenAI Gym environments with SB3, you need to " - "install shimmy (`pip install shimmy`)." + "install shimmy (`pip install 'shimmy>=0.2.1'`)." ) from e warnings.warn( @@ -57,8 +57,7 @@ def _patch_env(env: Union["gym.Env", gymnasium.Env]) -> gymnasium.Env: # pragma # Gym 0.26+ env return shimmy.GymV26CompatibilityV0(env=env) # Gym 0.21 env - # TODO: rename to GymV21CompatibilityV0 - return shimmy.GymV22CompatibilityV0(env=env) + return shimmy.GymV21CompatibilityV0(env=env) def _convert_space(space: Union["gym.Space", gymnasium.Space]) -> gymnasium.Space: # pragma: no cover @@ -90,7 +89,7 @@ def _convert_space(space: Union["gym.Space", gymnasium.Space]) -> gymnasium.Spac "Missing shimmy installation. You provided an OpenAI Gym space. " "Stable-Baselines3 (SB3) has transitioned to using Gymnasium internally. " "In order to use OpenAI Gym space with SB3, you need to " - "install shimmy (`pip install shimmy`)." + "install shimmy (`pip install 'shimmy>=0.2.1'`)." ) from e warnings.warn( From ad48559b4e704f89896edb91bc3621e9236d267c Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Sun, 19 Feb 2023 15:08:33 +0100 Subject: [PATCH 119/153] Fix wrappers type annotation (need PR from Gymnasium) --- stable_baselines3/common/atari_wrappers.py | 40 +++++++++++----------- stable_baselines3/common/env_util.py | 2 +- stable_baselines3/common/monitor.py | 16 ++++----- stable_baselines3/common/type_aliases.py | 4 ++- 4 files changed, 31 insertions(+), 31 deletions(-) diff --git a/stable_baselines3/common/atari_wrappers.py b/stable_baselines3/common/atari_wrappers.py index 99c834e012..026f66855a 100644 --- a/stable_baselines3/common/atari_wrappers.py +++ b/stable_baselines3/common/atari_wrappers.py @@ -1,9 +1,11 @@ -from typing import Dict, Tuple +from typing import Dict, SupportsFloat import gymnasium as gym import numpy as np from gymnasium import spaces +from stable_baselines3.common.type_aliases import AtariResetReturn, AtariStepReturn + try: import cv2 # pytype:disable=import-error @@ -11,10 +13,8 @@ except ImportError: cv2 = None -from stable_baselines3.common.type_aliases import Gym26StepReturn - -class StickyActionEnv(gym.Wrapper): +class StickyActionEnv(gym.Wrapper[np.ndarray, int, np.ndarray, int]): """ Sticky action. @@ -30,17 +30,17 @@ def __init__(self, env: gym.Env, action_repeat_probability: float) -> None: self.action_repeat_probability = action_repeat_probability assert env.unwrapped.get_action_meanings()[0] == "NOOP" # type: ignore[attr-defined] - def reset(self, **kwargs) -> Tuple[np.ndarray, Dict]: + def reset(self, **kwargs) -> AtariResetReturn: self._sticky_action = 0 # NOOP return self.env.reset(**kwargs) - def step(self, action: int) -> Gym26StepReturn: + def step(self, action: int) -> AtariStepReturn: if self.np_random.random() >= self.action_repeat_probability: self._sticky_action = action return self.env.step(self._sticky_action) -class NoopResetEnv(gym.Wrapper): +class NoopResetEnv(gym.Wrapper[np.ndarray, int, np.ndarray, int]): """ Sample initial states by taking random number of no-ops on reset. No-op is assumed to be action 0. @@ -56,7 +56,7 @@ def __init__(self, env: gym.Env, noop_max: int = 30) -> None: self.noop_action = 0 assert env.unwrapped.get_action_meanings()[0] == "NOOP" # type: ignore[attr-defined] - def reset(self, **kwargs) -> Tuple[np.ndarray, Dict]: + def reset(self, **kwargs) -> AtariResetReturn: self.env.reset(**kwargs) if self.override_num_noops is not None: noops = self.override_num_noops @@ -72,7 +72,7 @@ def reset(self, **kwargs) -> Tuple[np.ndarray, Dict]: return obs, info -class FireResetEnv(gym.Wrapper): +class FireResetEnv(gym.Wrapper[np.ndarray, int, np.ndarray, int]): """ Take action on reset for environments that are fixed until firing. @@ -84,7 +84,7 @@ def __init__(self, env: gym.Env) -> None: assert env.unwrapped.get_action_meanings()[1] == "FIRE" # type: ignore[attr-defined] assert len(env.unwrapped.get_action_meanings()) >= 3 # type: ignore[attr-defined] - def reset(self, **kwargs) -> Tuple[np.ndarray, Dict]: + def reset(self, **kwargs) -> AtariResetReturn: self.env.reset(**kwargs) obs, _, terminated, truncated, _ = self.env.step(1) if terminated or truncated: @@ -95,7 +95,7 @@ def reset(self, **kwargs) -> Tuple[np.ndarray, Dict]: return obs, {} -class EpisodicLifeEnv(gym.Wrapper): +class EpisodicLifeEnv(gym.Wrapper[np.ndarray, int, np.ndarray, int]): """ Make end-of-life == end-of-episode, but only reset on true game over. Done by DeepMind for the DQN and co. since it helps value estimation. @@ -108,7 +108,7 @@ def __init__(self, env: gym.Env) -> None: self.lives = 0 self.was_real_done = True - def step(self, action: int) -> Gym26StepReturn: + def step(self, action: int) -> AtariStepReturn: obs, reward, terminated, truncated, info = self.env.step(action) self.was_real_done = terminated or truncated # check current lives, make loss of life terminal, @@ -122,7 +122,7 @@ def step(self, action: int) -> Gym26StepReturn: self.lives = lives return obs, reward, terminated, truncated, info - def reset(self, **kwargs) -> Tuple[np.ndarray, Dict]: + def reset(self, **kwargs) -> AtariResetReturn: """ Calls the Gym environment reset, only when lives are exhausted. This way all states are still reachable even though lives are episodic, @@ -146,7 +146,7 @@ def reset(self, **kwargs) -> Tuple[np.ndarray, Dict]: return obs, info -class MaxAndSkipEnv(gym.Wrapper): +class MaxAndSkipEnv(gym.Wrapper[np.ndarray, int, np.ndarray, int]): """ Return only every ``skip``-th frame (frameskipping) and return the max between the two last frames. @@ -164,7 +164,7 @@ def __init__(self, env: gym.Env, skip: int = 4) -> None: self._obs_buffer = np.zeros((2,) + env.observation_space.shape, dtype=env.observation_space.dtype) self._skip = skip - def step(self, action: int) -> Gym26StepReturn: + def step(self, action: int) -> AtariStepReturn: """ Step the environment with the given action Repeat action, sum reward, and max over last observations. @@ -181,7 +181,7 @@ def step(self, action: int) -> Gym26StepReturn: self._obs_buffer[0] = obs if i == self._skip - 1: self._obs_buffer[1] = obs - total_reward += reward + total_reward += float(reward) if done: break # Note that the observation on the done=True frame @@ -201,17 +201,17 @@ class ClipRewardEnv(gym.RewardWrapper): def __init__(self, env: gym.Env) -> None: super().__init__(env) - def reward(self, reward: float) -> float: + def reward(self, reward: SupportsFloat) -> float: """ Bin reward to {+1, 0, -1} by its sign. :param reward: :return: """ - return np.sign(reward) + return np.sign(float(reward)) -class WarpFrame(gym.ObservationWrapper): +class WarpFrame(gym.ObservationWrapper[np.ndarray, int, np.ndarray]): """ Convert to grayscale and warp frames to 84x84 (default) as done in the Nature paper and later work. @@ -246,7 +246,7 @@ def observation(self, frame: np.ndarray) -> np.ndarray: return frame[:, :, None] -class AtariWrapper(gym.Wrapper): +class AtariWrapper(gym.Wrapper[np.ndarray, int, np.ndarray, int]): """ Atari 2600 preprocessings diff --git a/stable_baselines3/common/env_util.py b/stable_baselines3/common/env_util.py index 0ed44607fb..c3b73909e6 100644 --- a/stable_baselines3/common/env_util.py +++ b/stable_baselines3/common/env_util.py @@ -92,7 +92,7 @@ def _init() -> gym.Env: kwargs = {"render_mode": "rgb_array"} kwargs.update(env_kwargs) try: - env = gym.make(env_id, **kwargs) + env = gym.make(env_id, **kwargs) # type: ignore[arg-type] except TypeError: env = gym.make(env_id, **env_kwargs) else: diff --git a/stable_baselines3/common/monitor.py b/stable_baselines3/common/monitor.py index 875ab5adc3..85ffe9b1db 100644 --- a/stable_baselines3/common/monitor.py +++ b/stable_baselines3/common/monitor.py @@ -5,16 +5,14 @@ import os import time from glob import glob -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, SupportsFloat, Tuple, Union import gymnasium as gym -import numpy as np import pandas +from gymnasium.core import ActType, ObsType -from stable_baselines3.common.type_aliases import Gym26ResetReturn, Gym26StepReturn - -class Monitor(gym.Wrapper): +class Monitor(gym.Wrapper[ObsType, ActType, ObsType, ActType]): """ A monitor wrapper for Gym environments, it is used to know the episode reward, length, time and other data. @@ -46,7 +44,7 @@ def __init__( env_id = env.spec.id if env.spec is not None else None self.results_writer = ResultsWriter( filename, - header={"t_start": self.t_start, "env_id": env_id}, + header={"t_start": self.t_start, "env_id": str(env_id)}, extra_keys=reset_keywords + info_keywords, override_existing=override_existing, ) @@ -63,7 +61,7 @@ def __init__( # extra info about the current episode, that was passed in during reset() self.current_reset_info: Dict[str, Any] = {} - def reset(self, **kwargs) -> Gym26ResetReturn: + def reset(self, **kwargs) -> Tuple[ObsType, Dict[str, Any]]: """ Calls the Gym environment reset. Can only be called if the environment is over, or if allow_early_resets is True @@ -84,7 +82,7 @@ def reset(self, **kwargs) -> Gym26ResetReturn: self.current_reset_info[key] = value return self.env.reset(**kwargs) - def step(self, action: Union[np.ndarray, int]) -> Gym26StepReturn: + def step(self, action: ActType) -> Tuple[ObsType, SupportsFloat, bool, bool, Dict[str, Any]]: """ Step the environment with the given action @@ -94,7 +92,7 @@ def step(self, action: Union[np.ndarray, int]) -> Gym26StepReturn: if self.needs_reset: raise RuntimeError("Tried to step environment that needs reset") observation, reward, terminated, truncated, info = self.env.step(action) - self.rewards.append(reward) + self.rewards.append(float(reward)) if terminated or truncated: self.needs_reset = True ep_rew = sum(self.rewards) diff --git a/stable_baselines3/common/type_aliases.py b/stable_baselines3/common/type_aliases.py index e7f2975290..21478bc262 100644 --- a/stable_baselines3/common/type_aliases.py +++ b/stable_baselines3/common/type_aliases.py @@ -2,7 +2,7 @@ import sys from enum import Enum -from typing import Any, Callable, Dict, List, NamedTuple, Optional, Tuple, Union +from typing import Any, Callable, Dict, List, NamedTuple, Optional, SupportsFloat, Tuple, Union import gymnasium as gym import numpy as np @@ -18,8 +18,10 @@ GymEnv = Union[gym.Env, vec_env.VecEnv] GymObs = Union[Tuple, Dict[str, Any], np.ndarray, int] Gym26ResetReturn = Tuple[GymObs, Dict] +AtariResetReturn = Tuple[np.ndarray, Dict[str, Any]] GymStepReturn = Tuple[GymObs, float, bool, Dict] Gym26StepReturn = Tuple[GymObs, float, bool, bool, Dict] +AtariStepReturn = Tuple[np.ndarray, SupportsFloat, bool, bool, Dict[str, Any]] TensorDict = Dict[Union[str, int], th.Tensor] OptimizerStateDict = Dict[str, Any] MaybeCallback = Union[None, Callable, List[callbacks.BaseCallback], callbacks.BaseCallback] From dc742fb062162cf9d6f6c549d7df1525b8b44e62 Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Sun, 19 Feb 2023 15:18:02 +0100 Subject: [PATCH 120/153] Fix gymnasium dependency --- setup.cfg | 2 -- setup.py | 3 ++- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/setup.cfg b/setup.cfg index b2445953cc..a8aa96f9ab 100644 --- a/setup.cfg +++ b/setup.cfg @@ -28,7 +28,6 @@ follow_imports = silent show_error_codes = True exclude = (?x)( stable_baselines3/a2c/a2c.py$ - | stable_baselines3/common/atari_wrappers.py$ | stable_baselines3/common/base_class.py$ | stable_baselines3/common/buffers.py$ | stable_baselines3/common/callbacks.py$ @@ -36,7 +35,6 @@ exclude = (?x)( | stable_baselines3/common/envs/bit_flipping_env.py$ | stable_baselines3/common/envs/identity_env.py$ | stable_baselines3/common/envs/multi_input_envs.py$ - | stable_baselines3/common/monitor.py$ | stable_baselines3/common/logger.py$ | stable_baselines3/common/off_policy_algorithm.py$ | stable_baselines3/common/on_policy_algorithm.py$ diff --git a/setup.py b/setup.py index 6d0f62d3fb..ef44f54bca 100644 --- a/setup.py +++ b/setup.py @@ -76,7 +76,8 @@ packages=[package for package in find_packages() if package.startswith("stable_baselines3")], package_data={"stable_baselines3": ["py.typed", "version.txt"]}, install_requires=[ - "gymnasium==0.27.1", + # TODO(antonin): update to point to a release number once it is merged + "git+https://github.com/pseudo-rnd-thoughts/Gymnasium@fix-wrapper-type-hints", "numpy", "torch>=1.11", 'typing_extensions>=4.0,<5; python_version < "3.8.0"', From ed8bcf32f1e25c6f5743aa72d36ebe6212ca78f3 Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Sun, 19 Feb 2023 15:20:28 +0100 Subject: [PATCH 121/153] Fix dependency declaration --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index ef44f54bca..86ca023b3e 100644 --- a/setup.py +++ b/setup.py @@ -77,7 +77,7 @@ package_data={"stable_baselines3": ["py.typed", "version.txt"]}, install_requires=[ # TODO(antonin): update to point to a release number once it is merged - "git+https://github.com/pseudo-rnd-thoughts/Gymnasium@fix-wrapper-type-hints", + "gymnasium @ git+https://github.com/pseudo-rnd-thoughts/Gymnasium@fix-wrapper-type-hints", "numpy", "torch>=1.11", 'typing_extensions>=4.0,<5; python_version < "3.8.0"', From 75c22667c40c9b4285ca044913b37dc1a6253ad8 Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Sun, 19 Feb 2023 15:44:15 +0100 Subject: [PATCH 122/153] Cap pygame version for python 3.7 --- setup.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 86ca023b3e..0ddaaf92cd 100644 --- a/setup.py +++ b/setup.py @@ -123,7 +123,9 @@ "extra": [ # For render "opencv-python", - "pygame", + 'pygame; python_version >= "3.8.0"', + # See https://github.com/pygame/pygame/issues/3572 + 'pygame>=2.0,<2.1.3; python_version < "3.8.0"', # For atari games, "shimmy[atari]~=0.2.1", "autorom[accept-rom-license]~=0.4.2", From ce829c3a501b6a63be127f294a9739ad3c15d0e1 Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Thu, 23 Feb 2023 22:59:34 +0100 Subject: [PATCH 123/153] Point to master branch (v0.28.0) --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 0ddaaf92cd..70050958e6 100644 --- a/setup.py +++ b/setup.py @@ -77,7 +77,7 @@ package_data={"stable_baselines3": ["py.typed", "version.txt"]}, install_requires=[ # TODO(antonin): update to point to a release number once it is merged - "gymnasium @ git+https://github.com/pseudo-rnd-thoughts/Gymnasium@fix-wrapper-type-hints", + "gymnasium @ git+https://github.com/Farama-Foundation/Gymnasium@master", "numpy", "torch>=1.11", 'typing_extensions>=4.0,<5; python_version < "3.8.0"', From 645b7baf26484922c068ac31a0acf96fe95d7f91 Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Thu, 23 Feb 2023 23:02:22 +0100 Subject: [PATCH 124/153] Fix: use main not master branch --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 70050958e6..52d3d518fe 100644 --- a/setup.py +++ b/setup.py @@ -77,7 +77,7 @@ package_data={"stable_baselines3": ["py.typed", "version.txt"]}, install_requires=[ # TODO(antonin): update to point to a release number once it is merged - "gymnasium @ git+https://github.com/Farama-Foundation/Gymnasium@master", + "gymnasium @ git+https://github.com/Farama-Foundation/Gymnasium@main", "numpy", "torch>=1.11", 'typing_extensions>=4.0,<5; python_version < "3.8.0"', From 06ad5a88c713db46604c0e3b1127b6ae5a2be45b Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Sat, 25 Feb 2023 15:46:58 +0100 Subject: [PATCH 125/153] Rename done to terminated --- docs/guide/examples.rst | 4 ++-- stable_baselines3/common/envs/bit_flipping_env.py | 6 +++--- stable_baselines3/common/envs/identity_env.py | 12 ++++++------ stable_baselines3/common/envs/multi_input_envs.py | 4 ++-- tests/test_buffers.py | 8 ++++---- tests/test_dict_env.py | 4 ++-- tests/test_env_checker.py | 4 ++-- tests/test_gae.py | 6 +++--- tests/test_monitor.py | 12 ++++++------ tests/test_utils.py | 4 ++-- tests/test_vec_envs.py | 8 ++++---- tests/test_vec_normalize.py | 4 ++-- 12 files changed, 38 insertions(+), 38 deletions(-) diff --git a/docs/guide/examples.rst b/docs/guide/examples.rst index c91cd2aa37..033a11b79c 100644 --- a/docs/guide/examples.rst +++ b/docs/guide/examples.rst @@ -477,10 +477,10 @@ The parking env is a goal-conditioned continuous control task, in which the vehi episode_reward = 0 for _ in range(100): action, _ = model.predict(obs, deterministic=True) - obs, reward, done, truncated, info = env.step(action) + obs, reward, terminated, truncated, info = env.step(action) env.render() episode_reward += reward - if done or truncated or info.get("is_success", False): + if terminated or truncated or info.get("is_success", False): print("Reward:", episode_reward, "Success?", info.get("is_success", False)) episode_reward = 0.0 obs, info = env.reset() diff --git a/stable_baselines3/common/envs/bit_flipping_env.py b/stable_baselines3/common/envs/bit_flipping_env.py index 01e09357fd..19d386e4c9 100644 --- a/stable_baselines3/common/envs/bit_flipping_env.py +++ b/stable_baselines3/common/envs/bit_flipping_env.py @@ -179,12 +179,12 @@ def step(self, action: Union[np.ndarray, int]) -> Gym26StepReturn: self.state[action] = 1 - self.state[action] obs = self._get_obs() reward = float(self.compute_reward(obs["achieved_goal"], obs["desired_goal"], None)) - done = reward == 0 + terminated = reward == 0 self.current_step += 1 # Episode terminate when we reached the goal or the max number of steps - info = {"is_success": done} + info = {"is_success": terminated} truncated = self.current_step >= self.max_steps - return obs, reward, done, truncated, info + return obs, reward, terminated, truncated, info def compute_reward( self, achieved_goal: Union[int, np.ndarray], desired_goal: Union[int, np.ndarray], _info: Optional[Dict[str, Any]] diff --git a/stable_baselines3/common/envs/identity_env.py b/stable_baselines3/common/envs/identity_env.py index 9f8234f72f..90e1fdb149 100644 --- a/stable_baselines3/common/envs/identity_env.py +++ b/stable_baselines3/common/envs/identity_env.py @@ -46,9 +46,9 @@ def step(self, action: T) -> Tuple[T, float, bool, bool, Dict[str, Any]]: reward = self._get_reward(action) self._choose_next_state() self.current_step += 1 - done = False + terminated = False truncated = self.current_step >= self.ep_length - return self.state, reward, done, truncated, {} + return self.state, reward, terminated, truncated, {} def _choose_next_state(self) -> None: self.state = self.action_space.sample() @@ -78,9 +78,9 @@ def step(self, action: np.ndarray) -> Tuple[np.ndarray, float, bool, bool, Dict[ reward = self._get_reward(action) self._choose_next_state() self.current_step += 1 - done = False + terminated = False truncated = self.current_step >= self.ep_length - return self.state, reward, done, truncated, {} + return self.state, reward, terminated, truncated, {} def _get_reward(self, action: np.ndarray) -> float: return 1.0 if (self.state - self.eps) <= action <= (self.state + self.eps) else 0.0 @@ -151,9 +151,9 @@ def reset(self, *, seed: Optional[int] = None, options: Optional[Dict] = None) - def step(self, action: Union[np.ndarray, int]) -> Gym26StepReturn: reward = 0.0 self.current_step += 1 - done = False + terminated = False truncated = self.current_step >= self.ep_length - return self.observation_space.sample(), reward, done, truncated, {} + return self.observation_space.sample(), reward, terminated, truncated, {} def render(self, mode: str = "human") -> None: pass diff --git a/stable_baselines3/common/envs/multi_input_envs.py b/stable_baselines3/common/envs/multi_input_envs.py index 65ec372bf5..8fc9ac04fc 100644 --- a/stable_baselines3/common/envs/multi_input_envs.py +++ b/stable_baselines3/common/envs/multi_input_envs.py @@ -154,11 +154,11 @@ def step(self, action: Union[float, np.ndarray]) -> Gym26StepReturn: got_to_end = self.state == self.max_state reward = 1 if got_to_end else reward truncated = self.count > self.max_count - done = got_to_end + terminated = got_to_end self.log = f"Went {self.action2str[action]} in state {prev_state}, got to state {self.state}" - return self.get_state_mapping(), reward, done, truncated, {"got_to_end": got_to_end} + return self.get_state_mapping(), reward, terminated, truncated, {"got_to_end": got_to_end} def render(self, mode: str = "human") -> None: """ diff --git a/tests/test_buffers.py b/tests/test_buffers.py index f988f91360..ff392f2f71 100644 --- a/tests/test_buffers.py +++ b/tests/test_buffers.py @@ -34,9 +34,9 @@ def step(self, action): self._t += 1 index = self._t % len(self._observations) obs = self._observations[index] - done = truncated = self._t >= self._ep_length + terminated = truncated = self._t >= self._ep_length reward = self._rewards[index] - return obs, reward, done, truncated, {} + return obs, reward, terminated, truncated, {} class DummyDictEnv(gym.Env): @@ -62,9 +62,9 @@ def step(self, action): self._t += 1 index = self._t % len(self._observations) obs = {key: self._observations[index] for key in self.observation_space.spaces.keys()} - done = truncated = self._t >= self._ep_length + terminated = truncated = self._t >= self._ep_length reward = self._rewards[index] - return obs, reward, done, truncated, {} + return obs, reward, terminated, truncated, {} @pytest.mark.parametrize("env_cls", [DummyEnv, DummyDictEnv]) diff --git a/tests/test_dict_env.py b/tests/test_dict_env.py index 42aa468a7f..0265fca444 100644 --- a/tests/test_dict_env.py +++ b/tests/test_dict_env.py @@ -69,8 +69,8 @@ def seed(self, seed=None): def step(self, action): reward = 0.0 - done = truncated = False - return self.observation_space.sample(), reward, done, truncated, {} + terminated = truncated = False + return self.observation_space.sample(), reward, terminated, truncated, {} def reset(self, *, seed: Optional[int] = None, options: Optional[Dict] = None): if seed is not None: diff --git a/tests/test_env_checker.py b/tests/test_env_checker.py index e55a73fa1f..7e547fcb59 100644 --- a/tests/test_env_checker.py +++ b/tests/test_env_checker.py @@ -13,10 +13,10 @@ class ActionDictTestEnv(gym.Env): def step(self, action): observation = np.array([1.0, 1.5, 0.5], dtype=self.observation_space.dtype) reward = 1 - done = True + terminated = True truncated = False info = {} - return observation, reward, done, truncated, info + return observation, reward, terminated, truncated, info def reset(self): return np.array([1.0, 1.5, 0.5], dtype=self.observation_space.dtype), {} diff --git a/tests/test_gae.py b/tests/test_gae.py index 35c5689596..58e3e4158e 100644 --- a/tests/test_gae.py +++ b/tests/test_gae.py @@ -32,17 +32,17 @@ def reset(self, *, seed: Optional[int] = None, options: Optional[Dict] = None): def step(self, action): self.n_steps += 1 - done = truncated = False + terminated = truncated = False reward = 0.0 if self.n_steps >= self.max_steps: reward = 1.0 - done = True + terminated = True # To simplify GAE computation checks, # we do not consider truncation here. # Truncations are checked in InfiniteHorizonEnv truncated = False - return self.observation_space.sample(), reward, done, truncated, {} + return self.observation_space.sample(), reward, terminated, truncated, {} class InfiniteHorizonEnv(gym.Env): diff --git a/tests/test_monitor.py b/tests/test_monitor.py index c580fcf49b..481ef2178b 100644 --- a/tests/test_monitor.py +++ b/tests/test_monitor.py @@ -22,10 +22,10 @@ def test_monitor(tmp_path): ep_lengths = [] ep_len, ep_reward = 0, 0 for _ in range(total_steps): - _, reward, done, truncated, _ = monitor_env.step(monitor_env.action_space.sample()) + _, reward, terminated, truncated, _ = monitor_env.step(monitor_env.action_space.sample()) ep_len += 1 ep_reward += reward - if done or truncated: + if terminated or truncated: ep_rewards.append(ep_reward) ep_lengths.append(ep_len) monitor_env.reset() @@ -75,8 +75,8 @@ def test_monitor_load_results(tmp_path): monitor_env1.reset() episode_count1 = 0 for _ in range(1000): - _, _, done, truncated, _ = monitor_env1.step(monitor_env1.action_space.sample()) - if done or truncated: + _, _, terminated, truncated, _ = monitor_env1.step(monitor_env1.action_space.sample()) + if terminated or truncated: episode_count1 += 1 monitor_env1.reset() @@ -98,8 +98,8 @@ def test_monitor_load_results(tmp_path): monitor_env2 = Monitor(env2, monitor_file2, override_existing=False) monitor_env2.reset() for _ in range(1000): - _, _, done, truncated, _ = monitor_env2.step(monitor_env2.action_space.sample()) - if done or truncated: + _, _, terminated, truncated, _ = monitor_env2.step(monitor_env2.action_space.sample()) + if terminated or truncated: episode_count2 += 1 monitor_env2.reset() diff --git a/tests/test_utils.py b/tests/test_utils.py index d3cac07424..96ed6a5610 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -229,8 +229,8 @@ def __init__(self, env): self.needs_reset = True def step(self, action): - obs, reward, done, truncated, info = self.env.step(action) - self.needs_reset = done or truncated + obs, reward, terminated, truncated, info = self.env.step(action) + self.needs_reset = terminated or truncated self.last_obs = obs return obs, reward, True, truncated, info diff --git a/tests/test_vec_envs.py b/tests/test_vec_envs.py index 6129438fe6..a1ec3803b7 100644 --- a/tests/test_vec_envs.py +++ b/tests/test_vec_envs.py @@ -42,8 +42,8 @@ def step(self, action): reward = float(np.random.rand()) self._choose_next_state() self.current_step += 1 - done = truncated = self.current_step >= self.ep_length - return self.state, reward, done, truncated, {} + terminated = truncated = self.current_step >= self.ep_length + return self.state, reward, terminated, truncated, {} def _choose_next_state(self): self.state = self.observation_space.sample() @@ -178,8 +178,8 @@ def reset(self): def step(self, action): prev_step = self.current_step self.current_step += 1 - done = truncated = self.current_step >= self.max_steps - return np.array([prev_step], dtype="int"), 0.0, done, truncated, {} + terminated = truncated = self.current_step >= self.max_steps + return np.array([prev_step], dtype="int"), 0.0, terminated, truncated, {} @pytest.mark.parametrize("vec_env_class", VEC_ENV_CLASSES) diff --git a/tests/test_vec_normalize.py b/tests/test_vec_normalize.py index bbe9c4faa6..7c9eabfd82 100644 --- a/tests/test_vec_normalize.py +++ b/tests/test_vec_normalize.py @@ -35,8 +35,8 @@ def step(self, action): self.t += 1 index = (self.t + self.return_reward_idx) % len(self.returned_rewards) returned_value = self.returned_rewards[index] - done = truncated = self.t == len(self.returned_rewards) - return np.array([returned_value]), returned_value, done, truncated, {} + terminated = truncated = self.t == len(self.returned_rewards) + return np.array([returned_value]), returned_value, terminated, truncated, {} def reset(self, *, seed: Optional[int] = None, options: Optional[Dict] = None): if seed is not None: From 93c10cff9811f558aa84c2afc92f10b15cc3d548 Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Sat, 25 Feb 2023 16:04:04 +0100 Subject: [PATCH 126/153] Fix pygame dependency for python 3.7 --- setup.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/setup.py b/setup.py index b72de46732..4d767c8b97 100644 --- a/setup.py +++ b/setup.py @@ -122,7 +122,9 @@ "extra": [ # For render "opencv-python", - "pygame", + 'pygame; python_version >= "3.8.0"', + # See https://github.com/pygame/pygame/issues/3572 + 'pygame>=2.0,<2.1.3; python_version < "3.8.0"', # For atari games, "ale-py~=0.8.0", "autorom[accept-rom-license]~=0.4.2", From 29201a75556635c54f9a9cd6926ac61c9f074f79 Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Mon, 27 Feb 2023 13:48:09 +0100 Subject: [PATCH 127/153] Rename gym to gymnasium --- stable_baselines3/common/vec_env/vec_check_nan.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/stable_baselines3/common/vec_env/vec_check_nan.py b/stable_baselines3/common/vec_env/vec_check_nan.py index 98ad217f66..170f36ec8d 100644 --- a/stable_baselines3/common/vec_env/vec_check_nan.py +++ b/stable_baselines3/common/vec_env/vec_check_nan.py @@ -2,7 +2,7 @@ from typing import List, Tuple import numpy as np -from gym import spaces +from gymnasium import spaces from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvObs, VecEnvStepReturn, VecEnvWrapper From ae92d23e3627051184e086ce43e6217aab145419 Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Fri, 3 Mar 2023 17:48:38 +0100 Subject: [PATCH 128/153] Update Gymnasium --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 52d3d518fe..ef8e95ccec 100644 --- a/setup.py +++ b/setup.py @@ -77,7 +77,7 @@ package_data={"stable_baselines3": ["py.typed", "version.txt"]}, install_requires=[ # TODO(antonin): update to point to a release number once it is merged - "gymnasium @ git+https://github.com/Farama-Foundation/Gymnasium@main", + "gymnasium @ git+https://github.com/pseudo-rnd-thoughts/Gymnasium@update-env-spec", "numpy", "torch>=1.11", 'typing_extensions>=4.0,<5; python_version < "3.8.0"', From 62966995737bad36866d7539aa7611a40ae99929 Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Fri, 3 Mar 2023 23:43:29 +0100 Subject: [PATCH 129/153] Fix test --- tests/test_her.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_her.py b/tests/test_her.py index 2e385d51e3..a70ec71ecb 100644 --- a/tests/test_her.py +++ b/tests/test_her.py @@ -348,7 +348,7 @@ def test_get_max_episode_length(): # Set max_episode_steps to None env.spec.max_episode_steps = None - vec_env = DummyVecEnv([lambda: env]) + vec_env = DummyVecEnv([lambda: gym.make(env.spec)]) with pytest.raises(ValueError): get_time_limit(vec_env, current_max_episode_length=None) From a2a03c536ad24cf92a9510ec931cacdeb9f22c78 Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Tue, 7 Mar 2023 23:17:39 +0100 Subject: [PATCH 130/153] Fix tests --- tests/test_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_utils.py b/tests/test_utils.py index fd9f722867..e7d750e60f 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -28,7 +28,7 @@ @pytest.mark.parametrize("env_id", ["CartPole-v1", lambda: gym.make("CartPole-v1")]) @pytest.mark.parametrize("n_envs", [1, 2]) @pytest.mark.parametrize("vec_env_cls", [None, SubprocVecEnv]) -@pytest.mark.parametrize("wrapper_class", [None, gym.wrappers.TimeLimit]) +@pytest.mark.parametrize("wrapper_class", [None, gym.wrappers.RecordEpisodeStatistics]) def test_make_vec_env(env_id, n_envs, vec_env_cls, wrapper_class): env = make_vec_env(env_id, n_envs, vec_env_cls=vec_env_cls, wrapper_class=wrapper_class, monitor_dir=None, seed=0) From b82cacd44a1cb9fc3bb762a37f45b439d3afd733 Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Sat, 11 Mar 2023 23:00:36 +0100 Subject: [PATCH 131/153] Forks don't have access to private variables --- .github/workflows/ci.yml | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index fc47e9dae8..50f636d510 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -36,10 +36,10 @@ jobs: pip install torch==1.11+cpu -f https://download.pytorch.org/whl/torch_stable.html # Install Atari Roms - pip install autorom - wget $ATARI_ROMS - base64 Roms.tar.gz.b64 --decode &> Roms.tar.gz - AutoROM --accept-license --source-file Roms.tar.gz + # pip install autorom + # wget $ATARI_ROMS + # base64 Roms.tar.gz.b64 --decode &> Roms.tar.gz + # AutoROM --accept-license --source-file Roms.tar.gz pip install .[extra,tests,docs] # Use headless version From d0f5e8ab98608d9468b0dc31527566bbb357e04e Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Sun, 12 Mar 2023 18:52:14 +0100 Subject: [PATCH 132/153] Fix linter warnings --- stable_baselines3/common/atari_wrappers.py | 2 +- stable_baselines3/common/vec_env/stacked_observations.py | 2 +- stable_baselines3/common/vec_env/subproc_vec_env.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/stable_baselines3/common/atari_wrappers.py b/stable_baselines3/common/atari_wrappers.py index 01ab0dc816..706e0f40fd 100644 --- a/stable_baselines3/common/atari_wrappers.py +++ b/stable_baselines3/common/atari_wrappers.py @@ -161,7 +161,7 @@ def __init__(self, env: gym.Env, skip: int = 4) -> None: # most recent raw observations (for max pooling across time steps) assert env.observation_space.dtype is not None, "No dtype specified for the observation space" assert env.observation_space.shape is not None, "No shape defined for the observation space" - self._obs_buffer = np.zeros((2,) + env.observation_space.shape, dtype=env.observation_space.dtype) + self._obs_buffer = np.zeros((2, *env.observation_space.shape), dtype=env.observation_space.dtype) self._skip = skip def step(self, action: int) -> Gym26StepReturn: diff --git a/stable_baselines3/common/vec_env/stacked_observations.py b/stable_baselines3/common/vec_env/stacked_observations.py index 555c6f23fe..9fac9735ff 100644 --- a/stable_baselines3/common/vec_env/stacked_observations.py +++ b/stable_baselines3/common/vec_env/stacked_observations.py @@ -60,7 +60,7 @@ def __init__( high=high, dtype=observation_space.dtype, # type: ignore[arg-type] ) - self.stacked_obs = np.zeros((num_envs,) + self.stacked_shape, dtype=observation_space.dtype) + self.stacked_obs = np.zeros((num_envs, *self.stacked_shape), dtype=observation_space.dtype) else: raise TypeError( f"StackedObservations only supports Box and Dict as observation spaces. {observation_space} was provided." diff --git a/stable_baselines3/common/vec_env/subproc_vec_env.py b/stable_baselines3/common/vec_env/subproc_vec_env.py index 9fac06f3f6..73d65106f9 100644 --- a/stable_baselines3/common/vec_env/subproc_vec_env.py +++ b/stable_baselines3/common/vec_env/subproc_vec_env.py @@ -16,7 +16,7 @@ ) -def _worker( # noqa: C901 +def _worker( remote: mp.connection.Connection, parent_remote: mp.connection.Connection, env_fn_wrapper: CloudpickleWrapper, From 5328921b58d11904b9c5a6352987751af607ddcd Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Tue, 14 Mar 2023 12:59:15 +0100 Subject: [PATCH 133/153] Update read the doc env --- docs/conda_env.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/conda_env.yml b/docs/conda_env.yml index 7b89ba92bd..0545eef3c3 100644 --- a/docs/conda_env.yml +++ b/docs/conda_env.yml @@ -4,11 +4,11 @@ channels: - defaults dependencies: - cpuonly=1.0=0 - - pip=22.1.1 + - pip=22.3.1 - python=3.7 - pytorch=1.11.0=py3.7_cpu_0 - pip: - - gym==0.26 + - gymnasium - cloudpickle - opencv-python-headless - pandas From 986e6c0efde96365f08a9bd086719b663caeadd2 Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Mon, 20 Mar 2023 13:25:13 +0100 Subject: [PATCH 134/153] Fix env checker for GoalEnv --- stable_baselines3/common/env_checker.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/stable_baselines3/common/env_checker.py b/stable_baselines3/common/env_checker.py index 2f8359b066..8c54cbcfa5 100644 --- a/stable_baselines3/common/env_checker.py +++ b/stable_baselines3/common/env_checker.py @@ -252,14 +252,16 @@ def _check_returned_values(env: gym.Env, observation_space: spaces.Space, action # Unpack obs, reward, terminated, truncated, info = data - if _is_goal_env(env): - # Make mypy happy, already checked - assert isinstance(observation_space, spaces.Dict) - _check_goal_env_obs(obs, observation_space, "step") - _check_goal_env_compute_reward(obs, env, reward, info) - elif isinstance(observation_space, spaces.Dict): + if isinstance(observation_space, spaces.Dict): assert isinstance(obs, dict), "The observation returned by `step()` must be a dictionary" + # Additional checks for GoalEnvs + if _is_goal_env(env): + # Make mypy happy, already checked + assert isinstance(observation_space, spaces.Dict) + _check_goal_env_obs(obs, observation_space, "step") + _check_goal_env_compute_reward(obs, env, reward, info) + if not obs.keys() == observation_space.spaces.keys(): raise AssertionError( "The observation keys returned by `step()` must match the observation " From e5a1e9ef40be2d78268cc10fa96fccad9c1e96ce Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Mon, 20 Mar 2023 13:32:12 +0100 Subject: [PATCH 135/153] Fix import --- stable_baselines3/her/her_replay_buffer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/stable_baselines3/her/her_replay_buffer.py b/stable_baselines3/her/her_replay_buffer.py index 5a438b411e..91816c4f00 100644 --- a/stable_baselines3/her/her_replay_buffer.py +++ b/stable_baselines3/her/her_replay_buffer.py @@ -4,7 +4,7 @@ import numpy as np import torch as th -from gym import spaces +from gymnasium import spaces from stable_baselines3.common.buffers import DictReplayBuffer from stable_baselines3.common.type_aliases import DictReplayBufferSamples, TensorDict From 331853a6df96f8c8b0d3f96353fe014e446447e2 Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Wed, 29 Mar 2023 22:13:34 +0200 Subject: [PATCH 136/153] Update env checker (more info) and fix dtype --- stable_baselines3/common/env_checker.py | 23 +++++++++++-------- .../common/envs/bit_flipping_env.py | 2 +- 2 files changed, 15 insertions(+), 10 deletions(-) diff --git a/stable_baselines3/common/env_checker.py b/stable_baselines3/common/env_checker.py index c6cf097cbd..0cda5e5080 100644 --- a/stable_baselines3/common/env_checker.py +++ b/stable_baselines3/common/env_checker.py @@ -191,27 +191,32 @@ def _check_obs(obs: Union[tuple, dict, np.ndarray, int], observation_space: spac # check obs dimensions, dtype and bounds assert observation_space.shape == obs.shape, ( f"The observation returned by the `{method_name}()` method does not match the shape " - f"of the given observation space. Expected: {observation_space.shape}, actual shape: {obs.shape}" + f"of the given observation space {observation_space}. " + f"Expected: {observation_space.shape}, actual shape: {obs.shape}" ) - assert observation_space.dtype == obs.dtype, ( - f"The observation returned by the `{method_name}()` method does not match the data type " - f"of the given observation space. Expected: {observation_space.dtype}, actual dtype: {obs.dtype}" + assert np.can_cast(obs.dtype, observation_space.dtype), ( + f"The observation returned by the `{method_name}()` method does not match the data type (cannot cast) " + f"of the given observation space {observation_space}. " + f"Expected: {observation_space.dtype}, actual dtype: {obs.dtype}" ) if isinstance(observation_space, spaces.Box): assert np.all(obs >= observation_space.low), ( f"The observation returned by the `{method_name}()` method does not match the lower bound " - f"of the given observation space. Expected: obs >= {np.min(observation_space.low)}, " + f"of the given observation space {observation_space}." + f"Expected: obs >= {np.min(observation_space.low)}, " f"actual min value: {np.min(obs)} at index {np.argmin(obs)}" ) assert np.all(obs <= observation_space.high), ( f"The observation returned by the `{method_name}()` method does not match the upper bound " - f"of the given observation space. Expected: obs <= {np.max(observation_space.high)}, " + f"of the given observation space {observation_space}. " + f"Expected: obs <= {np.max(observation_space.high)}, " f"actual max value: {np.max(obs)} at index {np.argmax(obs)}" ) - assert observation_space.contains( - obs - ), f"The observation returned by the `{method_name}()` method does not match the given observation space" + assert observation_space.contains(obs), ( + f"The observation returned by the `{method_name}()` method " + f"does not match the given observation space {observation_space}" + ) def _check_box_obs(observation_space: spaces.Box, key: str = "") -> None: diff --git a/stable_baselines3/common/envs/bit_flipping_env.py b/stable_baselines3/common/envs/bit_flipping_env.py index 19d386e4c9..090985dcc9 100644 --- a/stable_baselines3/common/envs/bit_flipping_env.py +++ b/stable_baselines3/common/envs/bit_flipping_env.py @@ -96,7 +96,7 @@ def __init__( self.discrete_obs_space = discrete_obs_space self.image_obs_space = image_obs_space self.state = None - self.desired_goal = np.ones((n_bits,)) + self.desired_goal = np.ones((n_bits,), dtype=self.observation_space["desired_goal"].dtype) if max_steps is None: max_steps = n_bits self.max_steps = max_steps From 68861b6cb19ad7e2ff16baa63c32af9f229ab1db Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Wed, 29 Mar 2023 23:23:35 +0200 Subject: [PATCH 137/153] Use micromamab for Docker --- Dockerfile | 43 +++++++++++---------------------------- scripts/build_docker.sh | 6 +++--- scripts/run_docker_cpu.sh | 4 ++-- scripts/run_docker_gpu.sh | 4 ++-- 4 files changed, 19 insertions(+), 38 deletions(-) diff --git a/Dockerfile b/Dockerfile index 96588ef91d..712a795d1a 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,44 +1,25 @@ ARG PARENT_IMAGE FROM $PARENT_IMAGE ARG PYTORCH_DEPS=cpuonly -ARG PYTHON_VERSION=3.7 +ARG PYTHON_VERSION=3.8 +ARG MAMBA_DOCKERFILE_ACTIVATE=1 # (otherwise python will not be found) -# for tzdata -ENV DEBIAN_FRONTEND="noninteractive" TZ="Europe/Paris" - -RUN apt-get update && apt-get install -y --no-install-recommends \ - build-essential \ - cmake \ - git \ - curl \ - ca-certificates \ - libjpeg-dev \ - libpng-dev \ - libglib2.0-0 && \ - rm -rf /var/lib/apt/lists/* - -# Install Anaconda and dependencies -RUN curl -o ~/miniconda.sh https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh && \ - chmod +x ~/miniconda.sh && \ - ~/miniconda.sh -b -p /opt/conda && \ - rm ~/miniconda.sh && \ - /opt/conda/bin/conda install -y python=$PYTHON_VERSION numpy pyyaml scipy ipython mkl mkl-include && \ - /opt/conda/bin/conda install -y pytorch=1.11 $PYTORCH_DEPS -c pytorch && \ - /opt/conda/bin/conda clean -ya -ENV PATH /opt/conda/bin:$PATH - -ENV CODE_DIR /root/code +ENV CODE_DIR /home/$MAMBA_USER # Copy setup file only to install dependencies -COPY ./setup.py ${CODE_DIR}/stable-baselines3/setup.py -COPY ./stable_baselines3/version.txt ${CODE_DIR}/stable-baselines3/stable_baselines3/version.txt +COPY --chown=$MAMBA_USER:$MAMBA_USER ./setup.py ${CODE_DIR}/stable-baselines3/setup.py +COPY --chown=$MAMBA_USER:$MAMBA_USER ./stable_baselines3/version.txt ${CODE_DIR}/stable-baselines3/stable_baselines3/version.txt + +# Install micromamba env and dependencies +RUN micromamba install -n base -y python=$PYTHON_VERSION \ + pytorch $PYTORCH_DEPS -c conda-forge -c pytorch -c nvidia && \ + micromamba clean --all --yes -RUN \ - cd ${CODE_DIR}/stable-baselines3 3&& \ +RUN cd ${CODE_DIR}/stable-baselines3 && \ pip install -e .[extra,tests,docs] && \ # Use headless version for docker pip uninstall -y opencv-python && \ pip install opencv-python-headless && \ - rm -rf $HOME/.cache/pip + pip cache purge CMD /bin/bash diff --git a/scripts/build_docker.sh b/scripts/build_docker.sh index 3f0d5ae7c9..c1a4a5608f 100755 --- a/scripts/build_docker.sh +++ b/scripts/build_docker.sh @@ -1,14 +1,14 @@ #!/bin/bash -CPU_PARENT=ubuntu:20.04 -GPU_PARENT=nvidia/cuda:11.3.1-base-ubuntu20.04 +CPU_PARENT=mambaorg/micromamba:1.4-kinetic +GPU_PARENT=mambaorg/micromamba:1.4.1-focal-cuda-11.7.1 TAG=stablebaselines/stable-baselines3 VERSION=$(cat ./stable_baselines3/version.txt) if [[ ${USE_GPU} == "True" ]]; then PARENT=${GPU_PARENT} - PYTORCH_DEPS="cudatoolkit=11.3" + PYTORCH_DEPS="pytorch-cuda=11.7" else PARENT=${CPU_PARENT} PYTORCH_DEPS="cpuonly" diff --git a/scripts/run_docker_cpu.sh b/scripts/run_docker_cpu.sh index 6dfafd2b90..db6c6493b6 100755 --- a/scripts/run_docker_cpu.sh +++ b/scripts/run_docker_cpu.sh @@ -7,5 +7,5 @@ echo "Executing in the docker (cpu image):" echo $cmd_line docker run -it --rm --network host --ipc=host \ - --mount src=$(pwd),target=/root/code/stable-baselines3,type=bind stablebaselines/stable-baselines3-cpu:latest \ - bash -c "cd /root/code/stable-baselines3/ && $cmd_line" + --mount src=$(pwd),target=/home/mamba/stable-baselines3,type=bind stablebaselines/stable-baselines3-cpu:latest \ + bash -c "cd /home/mamba/stable-baselines3/ && $cmd_line" diff --git a/scripts/run_docker_gpu.sh b/scripts/run_docker_gpu.sh index 19e16067a4..fa8aae9c43 100755 --- a/scripts/run_docker_gpu.sh +++ b/scripts/run_docker_gpu.sh @@ -15,5 +15,5 @@ else fi docker run -it ${NVIDIA_ARG} --rm --network host --ipc=host \ - --mount src=$(pwd),target=/root/code/stable-baselines3,type=bind stablebaselines/stable-baselines3:latest \ - bash -c "cd /root/code/stable-baselines3/ && $cmd_line" + --mount src=$(pwd),target=/home/mamba/stable-baselines3,type=bind stablebaselines/stable-baselines3:latest \ + bash -c "cd /home/mamba/stable-baselines3/ && $cmd_line" From 9f0d5d85f2dfa6c09f7552d9bafab5d64cf4ce33 Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Wed, 29 Mar 2023 23:27:07 +0200 Subject: [PATCH 138/153] Update dependencies --- setup.py | 4 ++-- stable_baselines3/common/atari_wrappers.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/setup.py b/setup.py index cbbea480a9..4e1aea9185 100644 --- a/setup.py +++ b/setup.py @@ -93,7 +93,7 @@ extra_packages = extra_no_roms + [ # noqa: RUF005 # For atari roms, - "autorom[accept-rom-license]~=0.5.5", + "autorom[accept-rom-license]~=0.6.0", ] @@ -139,7 +139,7 @@ # For spelling "sphinxcontrib.spelling", # Type hints support - "sphinx-autodoc-typehints==1.21.1", # TODO: remove version constraint, see #1290 + "sphinx-autodoc-typehints", # Copy button for code snippets "sphinx_copybutton", ], diff --git a/stable_baselines3/common/atari_wrappers.py b/stable_baselines3/common/atari_wrappers.py index 706e0f40fd..1264d27fab 100644 --- a/stable_baselines3/common/atari_wrappers.py +++ b/stable_baselines3/common/atari_wrappers.py @@ -170,7 +170,7 @@ def step(self, action: int) -> Gym26StepReturn: Repeat action, sum reward, and max over last observations. :param action: the action - :return: observation, reward, done, information + :return: observation, reward, terminated, truncated, information """ total_reward = 0.0 terminated = truncated = False From 6617e6e73cb3a70f3e88cea780ea12bed95c099e Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Thu, 30 Mar 2023 00:21:45 +0200 Subject: [PATCH 139/153] Clarify VecEnv doc --- docs/guide/vec_envs.rst | 32 ++++++++++++++++++++++++++------ 1 file changed, 26 insertions(+), 6 deletions(-) diff --git a/docs/guide/vec_envs.rst b/docs/guide/vec_envs.rst index 3bfe69187b..ea99444d1a 100644 --- a/docs/guide/vec_envs.rst +++ b/docs/guide/vec_envs.rst @@ -52,26 +52,46 @@ SB3 VecEnv API is not the same as Gym API. SB3 VecEnv API is actually close to Gym 0.21 API but differs to Gym 0.26+ API: - the ``reset()`` method only returns the observation (``obs = vec_env.reset()``) and not a tuple, the info at reset are stored in ``vec_env.reset_infos``. + - only the initial call to ``vec_env.reset()`` is required, environments are reset automatically afterward (and ``reset_infos`` is updated automatically). + - the ``vec_env.step(actions)`` method expects an array as input (with a batch size corresponding to the number of environments) and returns a 4-tuple (and not a 5-tuple): ``obs, rewards, dones, infos`` instead of ``obs, reward, terminated, truncated, info`` where ``dones = terminated or truncated`` (for each env). ``obs, rewards, dones`` are numpy arrays with shape ``(n_envs, shape_for_single_env)`` (so with a batch dimension). Additional information is passed via the ``infos`` value which is a list of dictionaries. + - at the end of an episode, ``infos[env_idx]["TimeLimit.truncated"] = truncated and not terminated`` tells the user if an episode was truncated or not: - you should bootstrap when ``infos[env_idx]["TimeLimit.truncated"] is True`` or ``dones[env_idx] is False``. + you should bootstrap if ``infos[env_idx]["TimeLimit.truncated"] is True`` (episode over due to a timeout/truncation) + or ``dones[env_idx] is False`` (episode not finished). Note: compared to Gym 0.26+ ``infos[env_idx]["TimeLimit.truncated"]`` and ``terminated`` `are mutually exclusive `_. + The conversion from SB3 to Gym API is + + .. code-block:: python + + # done is True at the end of an episode + # dones[env_idx] = terminated[env_idx] or truncated[env_idx] + # In SB3, truncated and terminated are mutually exclusive + # infos[env_idx]["TimeLimit.truncated"] = truncated and not terminated + # terminated[env_idx] tells you whether you should bootstrap or not: + # when the episode has not ended or when the termination was a timeout/truncation + terminated[env_idx] = dones[env_idx] and not infos[env_idx]["TimeLimit.truncated"] + should_bootstrap[env_idx] = not terminated[env_idx] + + - at the end of an episode, because the environment resets automatically, we provide ``infos[env_idx]["terminal_observation"]`` which contains the last observation of an episode (and can be used when bootstrapping, see note in the previous section) -- if you pass ``render_mode="rgb_array"`` to your Gym env, a corresponding VecEnv can automatically show the rendered image - by calling ``vec_env.render(mode="human")``. This is different from Gym which currently `doesn't allow multiple render modes `_ - and doesn't allow passing a ``mode`` parameter. Note that if ``render_mode != "rgb_array"``, you can only call ``vec_env.render()`` (without argument or with ``mode=env.render_mode``). + +- to overcome the current Gymnasium limitation (only one render mode allowed per env instance, see `issue #100 `_), + we recommend using ``render_mode="rgb_array"`` since we can both have the image as a numpy array and display it with OpenCV. + if no mode is passed or ``mode="rgb_array"`` is passed when calling ``vec_env.render`` then we use the default mode, otherwise, we use the OpenCV display. + Note that if ``render_mode != "rgb_array"``, you can only call ``vec_env.render()`` (without argument or with ``mode=env.render_mode``). + - the ``reset()`` method doesn't take any parameter. If you want to seed the pseudo-random generator, you should call ``vec_env.seed(seed=seed)`` and ``obs = vec_env.reset()`` afterward. - If your Gym env implements a ``seed()`` method then it will be called, - otherwise ``env.reset(seed=seed)`` will be called (in that case, you will need two resets to set the seed). + - methods and attributes of the underlying Gym envs can be accessed, called and set using ``vec_env.get_attr("attribute_name")``, ``vec_env.env_method("method_name", args1, args2, kwargs1=kwargs1)`` and ``vec_env.set_attr("attribute_name", new_value)``. From 39eae4ec4a47c62e8ab0b683fd39de5bc4b27b5f Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Thu, 30 Mar 2023 00:25:07 +0200 Subject: [PATCH 140/153] Fix Gymnasium version --- setup.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/setup.py b/setup.py index 09ad3d7c41..a7a3fcc8b3 100644 --- a/setup.py +++ b/setup.py @@ -102,8 +102,7 @@ packages=[package for package in find_packages() if package.startswith("stable_baselines3")], package_data={"stable_baselines3": ["py.typed", "version.txt"]}, install_requires=[ - # TODO(antonin): update to point to a release number once it is merged - "gymnasium @ git+https://github.com/Farama-Foundation/Gymnasium@main", + "gymnasium==0.28.1", "numpy", "torch>=1.11", 'typing_extensions>=4.0,<5; python_version < "3.8.0"', From b853bc655f5f7aa3386d4c9b378c378d20f963fc Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Thu, 30 Mar 2023 00:28:19 +0200 Subject: [PATCH 141/153] Copy file only after mamba install --- Dockerfile | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/Dockerfile b/Dockerfile index 712a795d1a..421324dfff 100644 --- a/Dockerfile +++ b/Dockerfile @@ -4,17 +4,17 @@ ARG PYTORCH_DEPS=cpuonly ARG PYTHON_VERSION=3.8 ARG MAMBA_DOCKERFILE_ACTIVATE=1 # (otherwise python will not be found) +# Install micromamba env and dependencies +RUN micromamba install -n base -y python=$PYTHON_VERSION \ + pytorch $PYTORCH_DEPS -c conda-forge -c pytorch -c nvidia && \ + micromamba clean --all --yes + ENV CODE_DIR /home/$MAMBA_USER # Copy setup file only to install dependencies COPY --chown=$MAMBA_USER:$MAMBA_USER ./setup.py ${CODE_DIR}/stable-baselines3/setup.py COPY --chown=$MAMBA_USER:$MAMBA_USER ./stable_baselines3/version.txt ${CODE_DIR}/stable-baselines3/stable_baselines3/version.txt -# Install micromamba env and dependencies -RUN micromamba install -n base -y python=$PYTHON_VERSION \ - pytorch $PYTORCH_DEPS -c conda-forge -c pytorch -c nvidia && \ - micromamba clean --all --yes - RUN cd ${CODE_DIR}/stable-baselines3 && \ pip install -e .[extra,tests,docs] && \ # Use headless version for docker From 621f64fbfd3e0d47b070e28fd475a10ca47af26f Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Thu, 30 Mar 2023 00:36:52 +0200 Subject: [PATCH 142/153] [ci skip] Update docker doc --- docs/guide/install.rst | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/guide/install.rst b/docs/guide/install.rst index 312b86fcb5..2a00830517 100644 --- a/docs/guide/install.rst +++ b/docs/guide/install.rst @@ -131,7 +131,7 @@ Run the nvidia-docker GPU image .. code-block:: bash - docker run -it --runtime=nvidia --rm --network host --ipc=host --name test --mount src="$(pwd)",target=/root/code/stable-baselines3,type=bind stablebaselines/stable-baselines3 bash -c 'cd /root/code/stable-baselines3/ && pytest tests/' + docker run -it --runtime=nvidia --rm --network host --ipc=host --name test --mount src="$(pwd)",target=/home/mamba/stable-baselines3,type=bind stablebaselines/stable-baselines3 bash -c 'cd /home/mamba/stable-baselines3/ && pytest tests/' Or, with the shell file: @@ -143,7 +143,7 @@ Run the docker CPU image .. code-block:: bash - docker run -it --rm --network host --ipc=host --name test --mount src="$(pwd)",target=/root/code/stable-baselines3,type=bind stablebaselines/stable-baselines3-cpu bash -c 'cd /root/code/stable-baselines3/ && pytest tests/' + docker run -it --rm --network host --ipc=host --name test --mount src="$(pwd)",target=/home/mamba/stable-baselines3,type=bind stablebaselines/stable-baselines3-cpu bash -c 'cd /home/mamba/stable-baselines3/ && pytest tests/' Or, with the shell file: @@ -165,7 +165,7 @@ Explanation of the docker command: - ``--name test`` give explicitly the name ``test`` to the container, otherwise it will be assigned a random name - ``--mount src=...`` give access of the local directory (``pwd`` - command) to the container (it will be map to ``/root/code/stable-baselines``), so + command) to the container (it will be map to ``/home/mamba/stable-baselines``), so all the logs created in the container in this folder will be kept - ``bash -c '...'`` Run command inside the docker image, here run the tests (``pytest tests/``) From 5e1f5076687a46ccf18a31dd6e58ba0e48b94650 Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Thu, 30 Mar 2023 01:01:48 +0200 Subject: [PATCH 143/153] Polish code --- pyproject.toml | 2 +- stable_baselines3/common/env_checker.py | 7 +++++-- stable_baselines3/common/preprocessing.py | 6 +++--- tests/test_buffers.py | 6 ++++-- tests/test_vec_envs.py | 6 ++++-- tests/test_vec_normalize.py | 11 ++++++----- 6 files changed, 23 insertions(+), 15 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 3e3359f36e..9b5ca1bbcf 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -83,7 +83,7 @@ filterwarnings = [ # Tensorboard warnings "ignore::DeprecationWarning:tensorboard", # Gymnasium warnings - # "ignore::UserWarning:gym", + "ignore::UserWarning:gymnasium", # "ignore::DeprecationWarning:.*passive_env_checker.*", ] markers = [ diff --git a/stable_baselines3/common/env_checker.py b/stable_baselines3/common/env_checker.py index bb100da0e4..43678bc747 100644 --- a/stable_baselines3/common/env_checker.py +++ b/stable_baselines3/common/env_checker.py @@ -250,7 +250,7 @@ def _check_returned_values(env: gym.Env, observation_space: spaces.Space, action assert isinstance(reset_returns, tuple), "`reset()` must return a tuple (obs, info)" assert len(reset_returns) == 2, f"`reset()` must return a tuple of size 2 (obs, info), not {len(reset_returns)}" obs, info = reset_returns - assert isinstance(info, dict), "The second element of the tuple return by `reset()` must be a dictionary" + assert isinstance(info, dict), f"The second element of the tuple return by `reset()` must be a dictionary not {info}" if _is_goal_env(env): # Make mypy happy, already checked @@ -277,7 +277,10 @@ def _check_returned_values(env: gym.Env, observation_space: spaces.Space, action action = action_space.sample() data = env.step(action) - assert len(data) == 5, "The `step()` method must return four values: obs, reward, terminated, truncated, info" + assert len(data) == 5, ( + "The `step()` method must return five values: " + f"obs, reward, terminated, truncated, info. Actual: {len(data)} values returned." + ) # Unpack obs, reward, terminated, truncated, info = data diff --git a/stable_baselines3/common/preprocessing.py b/stable_baselines3/common/preprocessing.py index 6b35481b4d..bc0959480c 100644 --- a/stable_baselines3/common/preprocessing.py +++ b/stable_baselines3/common/preprocessing.py @@ -202,9 +202,9 @@ def get_action_dim(action_space: spaces.Space) -> int: return int(len(action_space.nvec)) elif isinstance(action_space, spaces.MultiBinary): # Number of binary actions - assert isinstance(action_space.n, int), ( - "Multi-dimensional MultiBinary action space is not supported. " "You can flatten it instead." - ) + assert isinstance( + action_space.n, int + ), "Multi-dimensional MultiBinary action space is not supported. You can flatten it instead." return int(action_space.n) else: raise NotImplementedError(f"{action_space} action space is not supported") diff --git a/tests/test_buffers.py b/tests/test_buffers.py index e7dfddcaef..825002c929 100644 --- a/tests/test_buffers.py +++ b/tests/test_buffers.py @@ -34,7 +34,8 @@ def step(self, action): self._t += 1 index = self._t % len(self._observations) obs = self._observations[index] - terminated = truncated = self._t >= self._ep_length + terminated = False + truncated = self._t >= self._ep_length reward = self._rewards[index] return obs, reward, terminated, truncated, {} @@ -63,7 +64,8 @@ def step(self, action): self._t += 1 index = self._t % len(self._observations) obs = {key: self._observations[index] for key in self.observation_space.spaces.keys()} - terminated = truncated = self._t >= self._ep_length + terminated = False + truncated = self._t >= self._ep_length reward = self._rewards[index] return obs, reward, terminated, truncated, {} diff --git a/tests/test_vec_envs.py b/tests/test_vec_envs.py index 466d27dced..6bc7e74db1 100644 --- a/tests/test_vec_envs.py +++ b/tests/test_vec_envs.py @@ -42,7 +42,8 @@ def step(self, action): reward = float(np.random.rand()) self._choose_next_state() self.current_step += 1 - terminated = truncated = self.current_step >= self.ep_length + terminated = False + truncated = self.current_step >= self.ep_length return self.state, reward, terminated, truncated, {} def _choose_next_state(self): @@ -178,7 +179,8 @@ def reset(self): def step(self, action): prev_step = self.current_step self.current_step += 1 - terminated = truncated = self.current_step >= self.max_steps + terminated = False + truncated = self.current_step >= self.max_steps return np.array([prev_step], dtype="int"), 0.0, terminated, truncated, {} diff --git a/tests/test_vec_normalize.py b/tests/test_vec_normalize.py index c3e1ab6c90..ae59047951 100644 --- a/tests/test_vec_normalize.py +++ b/tests/test_vec_normalize.py @@ -35,7 +35,8 @@ def step(self, action): self.t += 1 index = (self.t + self.return_reward_idx) % len(self.returned_rewards) returned_value = self.returned_rewards[index] - terminated = truncated = self.t == len(self.returned_rewards) + terminated = False + truncated = self.t == len(self.returned_rewards) return np.array([returned_value]), returned_value, terminated, truncated, {} def reset(self, *, seed: Optional[int] = None, options: Optional[Dict] = None): @@ -69,8 +70,8 @@ def reset(self, *, seed: Optional[int] = None, options: Optional[Dict] = None): def step(self, action): obs = self.observation_space.sample() reward = self.compute_reward(obs["achieved_goal"], obs["desired_goal"], {}) - done = np.random.rand() > 0.8 - return obs, reward, done, False, {} + terminated = np.random.rand() > 0.8 + return obs, reward, terminated, False, {} def compute_reward(self, achieved_goal: np.ndarray, desired_goal: np.ndarray, _info) -> np.float32: distance = np.linalg.norm(achieved_goal - desired_goal, axis=-1) @@ -100,8 +101,8 @@ def reset(self, *, seed: Optional[int] = None, options: Optional[Dict] = None): def step(self, action): obs = self.observation_space.sample() - done = np.random.rand() > 0.8 - return obs, 0.0, done, False, {} + terminated = np.random.rand() > 0.8 + return obs, 0.0, terminated, False, {} def allclose(obs_1, obs_2): From e77937f8dcbe73e5b6d01e730461a1c1156ce792 Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Mon, 3 Apr 2023 10:49:12 +0200 Subject: [PATCH 144/153] Reformat --- tests/test_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_utils.py b/tests/test_utils.py index e0a9c9b1bd..af4888c866 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -190,7 +190,7 @@ def dummy_callback(locals_, _globals): policy.n_callback_calls = 0 # type: ignore[assignment, attr-defined] _, episode_lengths = evaluate_policy( policy, # type: ignore[arg-type] - model.get_env(), # type: ignore[arg-type] + model.get_env(), # type: ignore[arg-type] n_eval_episodes, deterministic=True, render=False, @@ -209,7 +209,7 @@ def dummy_callback(locals_, _globals): episode_rewards, _ = evaluate_policy( policy, # type: ignore[arg-type] - model.get_env(), # type: ignore[arg-type] + model.get_env(), # type: ignore[arg-type] n_eval_episodes, return_episode_rewards=True, ) From 13905672401f26705c37df15f7050d5db490ec26 Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Wed, 5 Apr 2023 20:12:48 +0200 Subject: [PATCH 145/153] Remove deprecated features --- docs/misc/changelog.rst | 39 +++++++++++++++++++ .../common/vec_env/stacked_observations.py | 30 -------------- stable_baselines3/her/her_replay_buffer.py | 9 ----- 3 files changed, 39 insertions(+), 39 deletions(-) diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 8b851661bc..834332c6e0 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -3,6 +3,45 @@ Changelog ========== +Release 2.0.0a1 (WIP) +-------------------------- + +**Gymnasium support** + +.. warning:: + + Stable-Baselines3 (SB3) v2.0 will be the last one supporting python 3.7 (end of life in June 2023). + We highly recommended you to upgrade to Python >= 3.8. + + +Breaking Changes: +^^^^^^^^^^^^^^^^^ +- Switched to Gymnasium as primary backend, Gym 0.21 and 0.26 are still supported via the ``shimmy`` package +- The deprecated ``online_sampling`` argument of ``HerReplayBuffer`` was removed +- Removed deprecated ``stack_observation_space`` method of ``StackedObservations`` + +New Features: +^^^^^^^^^^^^^ + +`SB3-Contrib`_ +^^^^^^^^^^^^^^ + +`RL Zoo`_ +^^^^^^^^^ + +Bug Fixes: +^^^^^^^^^^ + +Deprecations: +^^^^^^^^^^^^^ + +Others: +^^^^^^^ + +Documentation: +^^^^^^^^^^^^^^ +- Added documentation about ``VecEnv`` API vs Gym API + Release 1.8.0a14 (WIP) -------------------------- diff --git a/stable_baselines3/common/vec_env/stacked_observations.py b/stable_baselines3/common/vec_env/stacked_observations.py index 27ad4e98a5..21f362f0cb 100644 --- a/stable_baselines3/common/vec_env/stacked_observations.py +++ b/stable_baselines3/common/vec_env/stacked_observations.py @@ -101,36 +101,6 @@ def compute_stacking( stacked_shape[repeat_axis] *= n_stack return channels_first, stack_dimension, tuple(stacked_shape), repeat_axis - def stack_observation_space(self, observation_space: Union[spaces.Box, spaces.Dict]) -> Union[spaces.Box, spaces.Dict]: - """ - This function is deprecated. - - As an alternative, use - - .. code-block:: python - - low = np.repeat(observation_space.low, stacked_observation.n_stack, axis=stacked_observation.repeat_axis) - high = np.repeat(observation_space.high, stacked_observation.n_stack, axis=stacked_observation.repeat_axis) - stacked_observation_space = spaces.Box(low=low, high=high, dtype=observation_space.dtype) - - :return: New observation space with stacked dimensions - """ - warnings.warn( - "stack_observation_space is deprecated and will be removed in the next SB3 release. " - "Please refer to the docstring for a workaround.", - DeprecationWarning, - ) - if isinstance(observation_space, spaces.Dict): - return spaces.Dict( - { - key: sub_stacked_observation.stack_observation_space(sub_stacked_observation.observation_space) - for key, sub_stacked_observation in self.sub_stacked_observations.items() - } - ) - low = np.repeat(observation_space.low, self.n_stack, axis=self.repeat_axis) - high = np.repeat(observation_space.high, self.n_stack, axis=self.repeat_axis) - return spaces.Box(low=low, high=high, dtype=observation_space.dtype) # type: ignore[arg-type] - def reset(self, observation: TObs) -> TObs: """ Reset the stacked_obs, add the reset observation to the stack, and return the stack. diff --git a/stable_baselines3/her/her_replay_buffer.py b/stable_baselines3/her/her_replay_buffer.py index 91816c4f00..9e06d6444b 100644 --- a/stable_baselines3/her/her_replay_buffer.py +++ b/stable_baselines3/her/her_replay_buffer.py @@ -58,7 +58,6 @@ def __init__( n_sampled_goal: int = 4, goal_selection_strategy: Union[GoalSelectionStrategy, str] = "future", copy_info_dict: bool = False, - online_sampling: Optional[bool] = None, ): super().__init__( buffer_size, @@ -72,14 +71,6 @@ def __init__( self.env = env self.copy_info_dict = copy_info_dict - if online_sampling is not None: - assert online_sampling is True, "Since v1.8.0, SB3 only supports online sampling with HerReplayBuffer." - warnings.warn( - "Since v1.8.0, the `online_sampling` argument is deprecated " - "as SB3 only supports online sampling with HerReplayBuffer. It will be removed in v2.0", - stacklevel=1, - ) - # convert goal_selection_strategy into GoalSelectionStrategy if string if isinstance(goal_selection_strategy, str): self.goal_selection_strategy = KEY_TO_GOAL_STRATEGY[goal_selection_strategy.lower()] From 0e1e50d455af48bca4d3e2c1989ef67819ee9643 Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Wed, 5 Apr 2023 20:17:34 +0200 Subject: [PATCH 146/153] Ignore warning --- tests/test_her.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_her.py b/tests/test_her.py index 1db12be41f..e17336bc5f 100644 --- a/tests/test_her.py +++ b/tests/test_her.py @@ -260,7 +260,7 @@ def env_fn(): del model.replay_buffer with pytest.raises(AttributeError): - model.replay_buffer + model.replay_buffer # noqa: B018 # Check that there is no warning assert len(recwarn) == 0 From 9e572371ae3e9a092e87d7190f1b4892e6f405dc Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Thu, 13 Apr 2023 12:53:32 +0200 Subject: [PATCH 147/153] Update doc --- docs/guide/install.rst | 13 ------------- stable_baselines3/common/env_checker.py | 3 +-- .../common/vec_env/stacked_observations.py | 2 +- 3 files changed, 2 insertions(+), 16 deletions(-) diff --git a/docs/guide/install.rst b/docs/guide/install.rst index 2a00830517..68f8f764fa 100644 --- a/docs/guide/install.rst +++ b/docs/guide/install.rst @@ -54,19 +54,6 @@ Bleeding-edge version pip install git+https://github.com/DLR-RM/stable-baselines3 -.. note:: - - If you want to use Gymnasium (or the latest Gym version 0.24+), you have to use - - .. code-block:: bash - - pip install git+https://github.com/DLR-RM/stable-baselines3@feat/gymnasium-support - pip install git+https://github.com/Stable-Baselines-Team/stable-baselines3-contrib@feat/gymnasium-support - - - See `PR #1327 `_ for more information. - - Development version ------------------- diff --git a/stable_baselines3/common/env_checker.py b/stable_baselines3/common/env_checker.py index 43678bc747..cc8be48ef6 100644 --- a/stable_baselines3/common/env_checker.py +++ b/stable_baselines3/common/env_checker.py @@ -366,8 +366,7 @@ def _check_render(env: gym.Env, warn: bool = False) -> None: # pragma: no cover "you may have trouble when calling `.render()`" ) - # TODO: if we want to check all declared render modes, - # we need to initialize new environments so the class should be passed as argument. + # Only check currrent render mode if env.render_mode: env.render() env.close() diff --git a/stable_baselines3/common/vec_env/stacked_observations.py b/stable_baselines3/common/vec_env/stacked_observations.py index 02abf8702c..bf375e1650 100644 --- a/stable_baselines3/common/vec_env/stacked_observations.py +++ b/stable_baselines3/common/vec_env/stacked_observations.py @@ -31,7 +31,7 @@ def __init__( self, num_envs: int, n_stack: int, - observation_space: Union[spaces.Box, spaces.Dict], # Replace by Space[TObs] in gym>=0.26 + observation_space: Union[spaces.Box, spaces.Dict], channels_order: Optional[Union[str, Mapping[str, Optional[str]]]] = None, ) -> None: self.n_stack = n_stack From 3bc8918b40526cb8c0ec234cf6a24740b2216886 Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Thu, 13 Apr 2023 13:28:33 +0200 Subject: [PATCH 148/153] Update examples and changelog --- docs/guide/examples.rst | 82 +++++++++++++++++++++-------------------- docs/misc/changelog.rst | 17 ++++++--- pyproject.toml | 1 - 3 files changed, 55 insertions(+), 45 deletions(-) diff --git a/docs/guide/examples.rst b/docs/guide/examples.rst index 300296aa92..5935f504a6 100644 --- a/docs/guide/examples.rst +++ b/docs/guide/examples.rst @@ -123,18 +123,18 @@ Multiprocessing: Unleashing the Power of Vectorized Environments from stable_baselines3.common.env_util import make_vec_env from stable_baselines3.common.utils import set_random_seed - def make_env(env_id, rank, seed=0): + def make_env(env_id: str, rank: int, seed: int = 0): """ Utility function for multiprocessed env. - :param env_id: (str) the environment ID - :param num_env: (int) the number of environments you wish to have in subprocesses - :param seed: (int) the inital seed for RNG - :param rank: (int) index of the subprocess + :param env_id: the environment ID + :param num_env: the number of environments you wish to have in subprocesses + :param seed: the inital seed for RNG + :param rank: index of the subprocess """ def _init(): - env = gym.make(env_id) - env.seed(seed + rank) + env = gym.make(env_id, render_mode="human") + env.reset(seed=seed + rank) return env set_random_seed(seed) return _init @@ -143,21 +143,21 @@ Multiprocessing: Unleashing the Power of Vectorized Environments env_id = "CartPole-v1" num_cpu = 4 # Number of processes to use # Create the vectorized environment - env = SubprocVecEnv([make_env(env_id, i) for i in range(num_cpu)]) + vec_env = SubprocVecEnv([make_env(env_id, i) for i in range(num_cpu)]) # Stable Baselines provides you with make_vec_env() helper # which does exactly the previous steps for you. # You can choose between `DummyVecEnv` (usually faster) and `SubprocVecEnv` # env = make_vec_env(env_id, n_envs=num_cpu, seed=0, vec_env_cls=SubprocVecEnv) - model = PPO("MlpPolicy", env, verbose=1) + model = PPO("MlpPolicy", vec_env, verbose=1) model.learn(total_timesteps=25_000) - obs = env.reset() + obs = vec_env.reset() for _ in range(1000): action, _states = model.predict(obs) - obs, rewards, dones, info = env.step(action) - env.render() + obs, rewards, dones, info = vec_env.step(action) + vec_env.render() Multiprocessing with off-policy algorithms @@ -178,12 +178,12 @@ Multiprocessing with off-policy algorithms from stable_baselines3 import SAC from stable_baselines3.common.env_util import make_vec_env - env = make_vec_env("Pendulum-v0", n_envs=4, seed=0) + vec_env = make_vec_env("Pendulum-v0", n_envs=4, seed=0) # We collect 4 transitions per call to `ènv.step()` # and performs 2 gradient steps per call to `ènv.step()` # if gradient_steps=-1, then we would do 4 gradients steps per call to `ènv.step()` - model = SAC("MlpPolicy", env, train_freq=1, gradient_steps=2, verbose=1) + model = SAC("MlpPolicy", vec_env, train_freq=1, gradient_steps=2, verbose=1) model.learn(total_timesteps=10_000) @@ -337,18 +337,18 @@ and multiprocessing for you. To install the Atari environments, run the command # There already exists an environment generator # that will make and wrap atari environments correctly. # Here we are also multi-worker training (n_envs=4 => 4 environments) - env = make_atari_env("PongNoFrameskip-v4", n_envs=4, seed=0) + vec_env = make_atari_env("PongNoFrameskip-v4", n_envs=4, seed=0) # Frame-stacking with 4 frames - env = VecFrameStack(env, n_stack=4) + vec_env = VecFrameStack(vec_env, n_stack=4) - model = A2C("CnnPolicy", env, verbose=1) + model = A2C("CnnPolicy", vec_env, verbose=1) model.learn(total_timesteps=25_000) - obs = env.reset() + obs = vec_env.reset() while True: - action, _states = model.predict(obs) - obs, rewards, dones, info = env.step(action) - env.render() + action, _states = model.predict(obs, deterministic=False) + obs, rewards, dones, info = vec_env.step(action) + vec_env.render("human") PyBullet: Normalizing input features @@ -378,12 +378,16 @@ will compute a running average and standard deviation of input features (it can from stable_baselines3.common.vec_env import DummyVecEnv, VecNormalize from stable_baselines3 import PPO - env = DummyVecEnv([lambda: gym.make("HalfCheetahBulletEnv-v0")]) + # Note: pybullet is not compatible yet with Gymnasium + # you might need to use `import rl_zoo3.gym_patches` + # and use gym (not Gymnasium) to instanciate the env + # Alternatively, you can use the MuJoCo equivalent "HalfCheetah-v4" + vec_env = DummyVecEnv([lambda: gym.make("HalfCheetahBulletEnv-v0")]) # Automatically normalize the input features and reward - env = VecNormalize(env, norm_obs=True, norm_reward=True, + vec_env = VecNormalize(vec_env, norm_obs=True, norm_reward=True, clip_obs=10.) - model = PPO("MlpPolicy", env) + model = PPO("MlpPolicy", vec_env) model.learn(total_timesteps=2000) # Don't forget to save the VecNormalize statistics when saving the agent @@ -393,18 +397,18 @@ will compute a running average and standard deviation of input features (it can env.save(stats_path) # To demonstrate loading - del model, env + del model, vec_env # Load the saved statistics - env = DummyVecEnv([lambda: gym.make("HalfCheetahBulletEnv-v0")]) - env = VecNormalize.load(stats_path, env) + vec_env = DummyVecEnv([lambda: gym.make("HalfCheetahBulletEnv-v0")]) + vec_env = VecNormalize.load(stats_path, vec_env) # do not update them at test time - env.training = False + vec_env.training = False # reward normalization is not needed at test time - env.norm_reward = False + vec_env.norm_reward = False # Load the agent - model = PPO.load(log_dir + "ppo_halfcheetah", env=env) + model = PPO.load(log_dir + "ppo_halfcheetah", env=vec_env) Hindsight Experience Replay (HER) @@ -662,7 +666,7 @@ A2C policy gradient updates on the model. # Keep top 10% n_elite = pop_size // 10 # Retrieve the environment - env = model.get_env() + vec_env = model.get_env() for iteration in range(10): # Create population of candidates and evaluate them @@ -674,7 +678,7 @@ A2C policy gradient updates on the model. # we give it (policy parameters) model.policy.load_state_dict(candidate, strict=False) # Evaluate the candidate - fitness, _ = evaluate_policy(model, env) + fitness, _ = evaluate_policy(model, vec_env) population.append((candidate, fitness)) # Take top 10% and use average over their parameters as next mean parameter top_candidates = sorted(population, key=lambda x: x[1], reverse=True)[:n_elite] @@ -745,21 +749,21 @@ Record a mp4 video (here using a random agent). video_folder = "logs/videos/" video_length = 100 - env = DummyVecEnv([lambda: gym.make(env_id)]) + vec_env = DummyVecEnv([lambda: gym.make(env_id, render_mode="rgb_array")]) - obs = env.reset() + obs = vec_env.reset() # Record the video starting at the first step - env = VecVideoRecorder(env, video_folder, + vec_env = VecVideoRecorder(vec_env, video_folder, record_video_trigger=lambda x: x == 0, video_length=video_length, name_prefix=f"random-agent-{env_id}") - env.reset() + vec_env.reset() for _ in range(video_length + 1): - action = [env.action_space.sample()] - obs, _, _, _ = env.step(action) + action = [vec_env.action_space.sample()] + obs, _, _, _ = vec_env.step(action) # Save the video - env.close() + vec_env.close() Bonus: Make a GIF of a Trained Agent diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 6e881885e4..f4336ead59 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -16,14 +16,16 @@ Release 2.0.0a3 (WIP) Breaking Changes: ^^^^^^^^^^^^^^^^^ -- Switched to Gymnasium as primary backend, Gym 0.21 and 0.26 are still supported via the ``shimmy`` package +- Switched to Gymnasium as primary backend, Gym 0.21 and 0.26 are still supported via the ``shimmy`` package (@carlosluis, @arjun-kg, @tlpss) - The deprecated ``online_sampling`` argument of ``HerReplayBuffer`` was removed - Removed deprecated ``stack_observation_space`` method of ``StackedObservations`` - Renamed environment output observations in ``evaluate_policy`` to prevent shadowing the input observations during callbacks (@npit) +- Upgraded wrappers and custom environment to Gymnasium New Features: ^^^^^^^^^^^^^ + `SB3-Contrib`_ ^^^^^^^^^^^^^^ @@ -39,10 +41,17 @@ Deprecations: Others: ^^^^^^^ +- Upgraded docker images to use mamba/micromamba and CUDA 11.7 +- Updated env checker to reflect what subset of Gymnasium is supported and improve GoalEnv checks +- Improve type annotation of wrappers +- Tests envs are now checked too +- Added render test for ``VecEnv`` Documentation: ^^^^^^^^^^^^^^ - Added documentation about ``VecEnv`` API vs Gym API +- Upgraded tutorials to Gymnasium API +- Make it more explicit when using ``VecEnv`` vs Gym env Release 1.8.0 (2023-04-07) @@ -156,7 +165,6 @@ Breaking Changes: please use an ``EvalCallback`` instead - Removed deprecated ``sde_net_arch`` parameter - Removed ``ret`` attributes in ``VecNormalize``, please use ``returns`` instead -- Switched minimum Gym version to 0.26 (@carlosluis, @arjun-kg, @tlpss) - ``VecNormalize`` now updates the observation space when normalizing images New Features: @@ -216,7 +224,7 @@ Others: - Upgraded GitHub CI/setup-python to v4 and checkout to v3 - Set tensors construction directly on the device (~8% speed boost on GPU) - Monkey-patched ``np.bool = bool`` so gym 0.21 is compatible with NumPy 1.24+ -- Standardized the use of ``from gymnasium import spaces`` +- Standardized the use of ``from gym import spaces`` - Modified ``get_system_info`` to avoid issue linked to copy-pasting on GitHub issue Documentation: @@ -340,7 +348,7 @@ Breaking Changes: New Features: ^^^^^^^^^^^^^ -- ``noop_max`` and ``frame_skip`` are now allowed to be equal to zero when using ``AtariWrapper`` + `SB3-Contrib`_ ^^^^^^^^^^^^^^ @@ -366,7 +374,6 @@ Deprecations: Others: ^^^^^^^ - Upgraded to Python 3.7+ syntax using ``pyupgrade`` -- Updated docker base image to Ubuntu 20.04 and cuda 11.3 - Removed redundant double-check for nested observations from ``BaseAlgorithm._wrap_env`` (@TibiGG) Documentation: diff --git a/pyproject.toml b/pyproject.toml index c43b4376e9..aea195d458 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -82,7 +82,6 @@ filterwarnings = [ "ignore::DeprecationWarning:tensorboard", # Gymnasium warnings "ignore::UserWarning:gymnasium", - # "ignore::DeprecationWarning:.*passive_env_checker.*", ] markers = [ "expensive: marks tests as expensive (deselect with '-m \"not expensive\"')" From 923bd466404b69c66a86541871f963b18ebd69de Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN Date: Thu, 13 Apr 2023 13:32:35 +0200 Subject: [PATCH 149/153] Fix type annotation bundle (SAC, TD3, A2C, PPO, base class) (#1436) * Fix SAC type hints, improve DQN ones * Fix A2C and TD3 type hints * Fix PPO type hints * Fix on-policy type hints * Fix base class type annotation, do not use defaults * Update version --- docs/misc/changelog.rst | 5 ++ pyproject.toml | 10 +-- stable_baselines3/common/base_class.py | 73 +++++++++++-------- stable_baselines3/common/buffers.py | 11 ++- stable_baselines3/common/callbacks.py | 2 +- .../common/off_policy_algorithm.py | 5 +- .../common/on_policy_algorithm.py | 32 +++++--- stable_baselines3/common/policies.py | 5 +- stable_baselines3/common/save_util.py | 2 +- stable_baselines3/common/type_aliases.py | 2 +- stable_baselines3/common/utils.py | 4 +- stable_baselines3/dqn/dqn.py | 13 ++-- stable_baselines3/dqn/policies.py | 10 +-- stable_baselines3/ppo/ppo.py | 4 +- stable_baselines3/sac/policies.py | 36 ++++++--- stable_baselines3/sac/sac.py | 23 +++--- stable_baselines3/td3/policies.py | 27 +++++-- stable_baselines3/td3/td3.py | 14 +++- 18 files changed, 170 insertions(+), 108 deletions(-) diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index f4336ead59..fddda2a40c 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -41,6 +41,11 @@ Deprecations: Others: ^^^^^^^ +- Fixed ``stable_baselines3/a2c/*.py`` type hints +- Fixed ``stable_baselines3/ppo/*.py`` type hints +- Fixed ``stable_baselines3/sac/*.py`` type hints +- Fixed ``stable_baselines3/td3/*.py`` type hints +- Fixed ``stable_baselines3/common/base_class.py`` type hints - Upgraded docker images to use mamba/micromamba and CUDA 11.7 - Updated env checker to reflect what subset of Gymnasium is supported and improve GoalEnv checks - Improve type annotation of wrappers diff --git a/pyproject.toml b/pyproject.toml index aea195d458..b44edf5299 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -35,9 +35,7 @@ ignore_missing_imports = true follow_imports = "silent" show_error_codes = true exclude = """(?x)( - stable_baselines3/a2c/a2c.py$ - | stable_baselines3/common/base_class.py$ - | stable_baselines3/common/buffers.py$ + stable_baselines3/common/buffers.py$ | stable_baselines3/common/callbacks.py$ | stable_baselines3/common/distributions.py$ | stable_baselines3/common/envs/bit_flipping_env.py$ @@ -45,7 +43,6 @@ exclude = """(?x)( | stable_baselines3/common/envs/multi_input_envs.py$ | stable_baselines3/common/logger.py$ | stable_baselines3/common/off_policy_algorithm.py$ - | stable_baselines3/common/on_policy_algorithm.py$ | stable_baselines3/common/policies.py$ | stable_baselines3/common/save_util.py$ | stable_baselines3/common/sb2_compat/rmsprop_tf_like.py$ @@ -62,11 +59,6 @@ exclude = """(?x)( | stable_baselines3/common/vec_env/vec_transpose.py$ | stable_baselines3/common/vec_env/vec_video_recorder.py$ | stable_baselines3/her/her_replay_buffer.py$ - | stable_baselines3/ppo/ppo.py$ - | stable_baselines3/sac/policies.py$ - | stable_baselines3/sac/sac.py$ - | stable_baselines3/td3/policies.py$ - | stable_baselines3/td3/td3.py$ | tests/test_logger.py$ | tests/test_train_eval_mode.py$ )""" diff --git a/stable_baselines3/common/base_class.py b/stable_baselines3/common/base_class.py index 475a7d3710..ece511a249 100644 --- a/stable_baselines3/common/base_class.py +++ b/stable_baselines3/common/base_class.py @@ -22,7 +22,7 @@ from stable_baselines3.common.policies import BasePolicy from stable_baselines3.common.preprocessing import check_for_nested_spaces, is_image_space, is_image_space_channels_first from stable_baselines3.common.save_util import load_from_zip_file, recursive_getattr, recursive_setattr, save_to_zip_file -from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule +from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule, TensorDict from stable_baselines3.common.utils import ( check_for_correct_spaces, get_device, @@ -44,7 +44,7 @@ SelfBaseAlgorithm = TypeVar("SelfBaseAlgorithm", bound="BaseAlgorithm") -def maybe_make_env(env: Union[GymEnv, str, None], verbose: int) -> Optional[GymEnv]: +def maybe_make_env(env: Union[GymEnv, str], verbose: int) -> GymEnv: """If env is a string, make the environment; otherwise, return env. :param env: The environment to learn from. @@ -52,13 +52,14 @@ def maybe_make_env(env: Union[GymEnv, str, None], verbose: int) -> Optional[GymE :return A Gym (vector) environment. """ if isinstance(env, str): + env_id = env if verbose >= 1: - print(f"Creating environment from the given name '{env}'") + print(f"Creating environment from the given name '{env_id}'") # Set render_mode to `rgb_array` as default, so we can record video try: - env = gym.make(env, render_mode="rgb_array") + env = gym.make(env_id, render_mode="rgb_array") except TypeError: - env = gym.make(env) + env = gym.make(env_id) return env @@ -95,6 +96,11 @@ class BaseAlgorithm(ABC): # Policy aliases (see _get_policy_from_name()) policy_aliases: Dict[str, Type[BasePolicy]] = {} policy: BasePolicy + observation_space: spaces.Space + action_space: spaces.Space + n_envs: int + lr_schedule: Schedule + _logger: Logger def __init__( self, @@ -111,8 +117,8 @@ def __init__( seed: Optional[int] = None, use_sde: bool = False, sde_sample_freq: int = -1, - supported_action_spaces: Optional[Tuple[spaces.Space, ...]] = None, - ): + supported_action_spaces: Optional[Tuple[Type[spaces.Space], ...]] = None, + ) -> None: if isinstance(policy, str): self.policy_class = self._get_policy_from_name(policy) else: @@ -122,14 +128,9 @@ def __init__( if verbose >= 1: print(f"Using {self.device} device") - self.env = None # type: Optional[GymEnv] - # get VecNormalize object if needed - self._vec_normalize_env = unwrap_vec_normalize(env) self.verbose = verbose self.policy_kwargs = {} if policy_kwargs is None else policy_kwargs - self.observation_space: spaces.Space - self.action_space: spaces.Space - self.n_envs: int + self.num_timesteps = 0 # Used for updating schedules self._total_timesteps = 0 @@ -137,10 +138,9 @@ def __init__( self._num_timesteps_at_start = 0 self.seed = seed self.action_noise: Optional[ActionNoise] = None - self.start_time = None + self.start_time = 0.0 self.learning_rate = learning_rate self.tensorboard_log = tensorboard_log - self.lr_schedule = None # type: Optional[Schedule] self._last_obs = None # type: Optional[Union[np.ndarray, Dict[str, np.ndarray]]] self._last_episode_starts = None # type: Optional[np.ndarray] # When using VecNormalize: @@ -151,17 +151,17 @@ def __init__( self.sde_sample_freq = sde_sample_freq # Track the training progress remaining (from 1 to 0) # this is used to update the learning rate - self._current_progress_remaining = 1 + self._current_progress_remaining = 1.0 # Buffers for logging self._stats_window_size = stats_window_size self.ep_info_buffer = None # type: Optional[deque] self.ep_success_buffer = None # type: Optional[deque] # For logging (and TD3 delayed updates) self._n_updates = 0 # type: int - # The logger object - self._logger = None # type: Logger # Whether the user passed a custom logger or not self._custom_logger = False + self.env: Optional[VecEnv] = None + self._vec_normalize_env: Optional[VecNormalize] = None # Create and wrap the env if needed if env is not None: @@ -173,6 +173,9 @@ def __init__( self.n_envs = env.num_envs self.env = env + # get VecNormalize object if needed + self._vec_normalize_env = unwrap_vec_normalize(env) + if supported_action_spaces is not None: assert isinstance(self.action_space, supported_action_spaces), ( f"The algorithm only supports {supported_action_spaces} as action spaces " @@ -217,7 +220,7 @@ def _wrap_env(env: GymEnv, verbose: int = 0, monitor_wrapper: bool = True) -> Ve env = Monitor(env) if verbose >= 1: print("Wrapping the env in a DummyVecEnv.") - env = DummyVecEnv([lambda: env]) + env = DummyVecEnv([lambda: env]) # type: ignore[list-item, return-value] # Make sure that dict-spaces are not nested (not supported) check_for_nested_spaces(env.observation_space) @@ -230,11 +233,11 @@ def _wrap_env(env: GymEnv, verbose: int = 0, monitor_wrapper: bool = True) -> Ve # the other channel last), VecTransposeImage will throw an error for space in env.observation_space.spaces.values(): wrap_with_vectranspose = wrap_with_vectranspose or ( - is_image_space(space) and not is_image_space_channels_first(space) + is_image_space(space) and not is_image_space_channels_first(space) # type: ignore[arg-type] ) else: wrap_with_vectranspose = is_image_space(env.observation_space) and not is_image_space_channels_first( - env.observation_space + env.observation_space # type: ignore[arg-type] ) if wrap_with_vectranspose: @@ -416,7 +419,10 @@ def _setup_learn( # Avoid resetting the environment when calling ``.learn()`` consecutive times if reset_num_timesteps or self._last_obs is None: - self._last_obs = self.env.reset() # pytype: disable=annotation-type-mismatch + assert self.env is not None + # pytype: disable=annotation-type-mismatch + self._last_obs = self.env.reset() # type: ignore[assignment] + # pytype: enable=annotation-type-mismatch self._last_episode_starts = np.ones((self.env.num_envs,), dtype=bool) # Retrieve unnormalized observation for saving into the buffer if self._vec_normalize_env is not None: @@ -439,6 +445,9 @@ def _update_info_buffer(self, infos: List[Dict[str, Any]], dones: Optional[np.nd :param infos: List of additional information about the transition. :param dones: Termination signals """ + assert self.ep_info_buffer is not None + assert self.ep_success_buffer is not None + if dones is None: dones = np.array([False] * len(infos)) for idx, info in enumerate(infos): @@ -562,7 +571,7 @@ def set_random_seed(self, seed: Optional[int] = None) -> None: def set_parameters( self, - load_path_or_dict: Union[str, Dict[str, Dict]], + load_path_or_dict: Union[str, TensorDict], exact_match: bool = True, device: Union[th.device, str] = "auto", ) -> None: @@ -578,7 +587,7 @@ def set_parameters( can be used to update only specific parameters. :param device: Device on which the code should run. """ - params = None + params = {} if isinstance(load_path_or_dict, dict): params = load_path_or_dict else: @@ -616,7 +625,7 @@ def set_parameters( # # Solution: Just load the state-dict as is, and trust # the user has provided a sensible state dictionary. - attr.load_state_dict(params[name]) + attr.load_state_dict(params[name]) # type: ignore[arg-type] else: # Assume attr is th.nn.Module attr.load_state_dict(params[name], strict=exact_match) @@ -674,6 +683,9 @@ def load( # noqa: C901 print_system_info=print_system_info, ) + assert data is not None, "No data found in the saved file" + assert params is not None, "No params found in the saved file" + # Remove stored device information and replace with ours if "policy_kwargs" in data: if "device" in data["policy_kwargs"]: @@ -714,13 +726,14 @@ def load( # noqa: C901 if "env" in data: env = data["env"] - # noinspection PyArgumentList - model = cls( # pytype: disable=not-instantiable,wrong-keyword-args + # pytype: disable=not-instantiable,wrong-keyword-args + model = cls( policy=data["policy_class"], env=env, device=device, - _init_setup_model=False, # pytype: disable=not-instantiable,wrong-keyword-args + _init_setup_model=False, # type: ignore[call-arg] ) + # pytype: enable=not-instantiable,wrong-keyword-args # load parameters model.__dict__.update(data) @@ -758,12 +771,12 @@ def load( # noqa: C901 continue # Set the data attribute directly to avoid issue when using optimizers # See https://github.com/DLR-RM/stable-baselines3/issues/391 - recursive_setattr(model, name + ".data", pytorch_variables[name].data) + recursive_setattr(model, f"{name}.data", pytorch_variables[name].data) # Sample gSDE exploration matrix, so it uses the right device # see issue #44 if model.use_sde: - model.policy.reset_noise() # pytype: disable=attribute-error + model.policy.reset_noise() # type: ignore[operator] # pytype: disable=attribute-error return model def get_parameters(self) -> Dict[str, Dict]: diff --git a/stable_baselines3/common/buffers.py b/stable_baselines3/common/buffers.py index 8ae1360f6f..e52f08f693 100644 --- a/stable_baselines3/common/buffers.py +++ b/stable_baselines3/common/buffers.py @@ -335,6 +335,15 @@ class RolloutBuffer(BaseBuffer): :param n_envs: Number of parallel environments """ + observations: np.ndarray + actions: np.ndarray + rewards: np.ndarray + advantages: np.ndarray + returns: np.ndarray + episode_starts: np.ndarray + log_probs: np.ndarray + values: np.ndarray + def __init__( self, buffer_size: int, @@ -348,8 +357,6 @@ def __init__( super().__init__(buffer_size, observation_space, action_space, device, n_envs=n_envs) self.gae_lambda = gae_lambda self.gamma = gamma - self.observations, self.actions, self.rewards, self.advantages = None, None, None, None - self.returns, self.episode_starts, self.values, self.log_probs = None, None, None, None self.generator_ready = False self.reset() diff --git a/stable_baselines3/common/callbacks.py b/stable_baselines3/common/callbacks.py index fa2a356588..c9b8a33673 100644 --- a/stable_baselines3/common/callbacks.py +++ b/stable_baselines3/common/callbacks.py @@ -313,7 +313,7 @@ class ConvertCallback(BaseCallback): :param verbose: Verbosity level: 0 for no output, 1 for info messages, 2 for debug messages """ - def __init__(self, callback: Callable[[Dict[str, Any], Dict[str, Any]], bool], verbose: int = 0): + def __init__(self, callback: Optional[Callable[[Dict[str, Any], Dict[str, Any]], bool]], verbose: int = 0): super().__init__(verbose) self.callback = callback diff --git a/stable_baselines3/common/off_policy_algorithm.py b/stable_baselines3/common/off_policy_algorithm.py index 8c9f05a4cc..e3e6c594a6 100644 --- a/stable_baselines3/common/off_policy_algorithm.py +++ b/stable_baselines3/common/off_policy_algorithm.py @@ -75,6 +75,8 @@ class OffPolicyAlgorithm(BaseAlgorithm): :param supported_action_spaces: The action spaces supported by the algorithm. """ + actor: th.nn.Module + def __init__( self, policy: Union[str, Type[BasePolicy]], @@ -129,6 +131,7 @@ def __init__( self.gradient_steps = gradient_steps self.action_noise = action_noise self.optimize_memory_usage = optimize_memory_usage + self.replay_buffer: Optional[ReplayBuffer] = None self.replay_buffer_class = replay_buffer_class self.replay_buffer_kwargs = replay_buffer_kwargs or {} self._episode_storage = None @@ -136,8 +139,6 @@ def __init__( # Save train freq parameter, will be converted later to TrainFreq object self.train_freq = train_freq - self.actor = None # type: Optional[th.nn.Module] - self.replay_buffer: Optional[ReplayBuffer] = None # Update policy keyword arguments if sde_support: self.policy_kwargs["use_sde"] = self.use_sde diff --git a/stable_baselines3/common/on_policy_algorithm.py b/stable_baselines3/common/on_policy_algorithm.py index 768aa01373..87e192990b 100644 --- a/stable_baselines3/common/on_policy_algorithm.py +++ b/stable_baselines3/common/on_policy_algorithm.py @@ -52,6 +52,9 @@ class OnPolicyAlgorithm(BaseAlgorithm): :param supported_action_spaces: The action spaces supported by the algorithm. """ + rollout_buffer: RolloutBuffer + policy: ActorCriticPolicy + def __init__( self, policy: Union[str, Type[ActorCriticPolicy]], @@ -97,7 +100,6 @@ def __init__( self.ent_coef = ent_coef self.vf_coef = vf_coef self.max_grad_norm = max_grad_norm - self.rollout_buffer = None if _init_setup_model: self._setup_model() @@ -117,13 +119,11 @@ def _setup_model(self) -> None: gae_lambda=self.gae_lambda, n_envs=self.n_envs, ) - self.policy = self.policy_class( # pytype:disable=not-instantiable - self.observation_space, - self.action_space, - self.lr_schedule, - use_sde=self.use_sde, - **self.policy_kwargs # pytype:disable=not-instantiable + # pytype:disable=not-instantiable + self.policy = self.policy_class( # type: ignore[assignment] + self.observation_space, self.action_space, self.lr_schedule, use_sde=self.use_sde, **self.policy_kwargs ) + # pytype:enable=not-instantiable self.policy = self.policy.to(self.device) def collect_rollouts( @@ -201,16 +201,23 @@ def collect_rollouts( ): terminal_obs = self.policy.obs_to_tensor(infos[idx]["terminal_observation"])[0] with th.no_grad(): - terminal_value = self.policy.predict_values(terminal_obs)[0] + terminal_value = self.policy.predict_values(terminal_obs)[0] # type: ignore[arg-type] rewards[idx] += self.gamma * terminal_value - rollout_buffer.add(self._last_obs, actions, rewards, self._last_episode_starts, values, log_probs) - self._last_obs = new_obs + rollout_buffer.add( + self._last_obs, # type: ignore[arg-type] + actions, + rewards, + self._last_episode_starts, # type: ignore[arg-type] + values, + log_probs, + ) + self._last_obs = new_obs # type: ignore[assignment] self._last_episode_starts = dones with th.no_grad(): # Compute value for the last timestep - values = self.policy.predict_values(obs_as_tensor(new_obs, self.device)) + values = self.policy.predict_values(obs_as_tensor(new_obs, self.device)) # type: ignore[arg-type] rollout_buffer.compute_returns_and_advantage(last_values=values, dones=dones) @@ -246,6 +253,8 @@ def learn( callback.on_training_start(locals(), globals()) + assert self.env is not None + while self.num_timesteps < total_timesteps: continue_training = self.collect_rollouts(self.env, callback, self.rollout_buffer, n_rollout_steps=self.n_steps) @@ -257,6 +266,7 @@ def learn( # Display training infos if log_interval is not None and iteration % log_interval == 0: + assert self.ep_info_buffer is not None time_elapsed = max((time.time_ns() - self.start_time) / 1e9, sys.float_info.epsilon) fps = int((self.num_timesteps - self._num_timesteps_at_start) / time_elapsed) self.logger.record("time/iterations", iteration, exclude="tensorboard") diff --git a/stable_baselines3/common/policies.py b/stable_baselines3/common/policies.py index 5f11787009..32be95e831 100644 --- a/stable_baselines3/common/policies.py +++ b/stable_baselines3/common/policies.py @@ -58,6 +58,8 @@ class BaseModel(nn.Module): excluding the learning rate, to pass to the optimizer """ + optimizer: th.optim.Optimizer + def __init__( self, observation_space: spaces.Space, @@ -84,7 +86,6 @@ def __init__( self.optimizer_class = optimizer_class self.optimizer_kwargs = optimizer_kwargs - self.optimizer: th.optim.Optimizer self.features_extractor_class = features_extractor_class self.features_extractor_kwargs = features_extractor_kwargs @@ -279,6 +280,8 @@ class BasePolicy(BaseModel, ABC): or not using a ``tanh()`` function. """ + features_extractor: BaseFeaturesExtractor + def __init__(self, *args, squash_output: bool = False, **kwargs): super().__init__(*args, **kwargs) self._squash_output = squash_output diff --git a/stable_baselines3/common/save_util.py b/stable_baselines3/common/save_util.py index e5aeb662be..3c01a3f268 100644 --- a/stable_baselines3/common/save_util.py +++ b/stable_baselines3/common/save_util.py @@ -367,7 +367,7 @@ def load_from_zip_file( device: Union[th.device, str] = "auto", verbose: int = 0, print_system_info: bool = False, -) -> Tuple[Optional[Dict[str, Any]], Optional[TensorDict], Optional[TensorDict]]: +) -> Tuple[Optional[Dict[str, Any]], TensorDict, Optional[TensorDict]]: """ Load model data from a .zip archive diff --git a/stable_baselines3/common/type_aliases.py b/stable_baselines3/common/type_aliases.py index 21478bc262..8744923da1 100644 --- a/stable_baselines3/common/type_aliases.py +++ b/stable_baselines3/common/type_aliases.py @@ -22,7 +22,7 @@ GymStepReturn = Tuple[GymObs, float, bool, Dict] Gym26StepReturn = Tuple[GymObs, float, bool, bool, Dict] AtariStepReturn = Tuple[np.ndarray, SupportsFloat, bool, bool, Dict[str, Any]] -TensorDict = Dict[Union[str, int], th.Tensor] +TensorDict = Dict[str, th.Tensor] OptimizerStateDict = Dict[str, Any] MaybeCallback = Union[None, Callable, List[callbacks.BaseCallback], callbacks.BaseCallback] diff --git a/stable_baselines3/common/utils.py b/stable_baselines3/common/utils.py index e107e6ae33..d20cc85293 100644 --- a/stable_baselines3/common/utils.py +++ b/stable_baselines3/common/utils.py @@ -471,9 +471,7 @@ def polyak_update( th.add(target_param.data, param.data, alpha=tau, out=target_param.data) -def obs_as_tensor( - obs: Union[np.ndarray, Dict[Union[str, int], np.ndarray]], device: th.device -) -> Union[th.Tensor, TensorDict]: +def obs_as_tensor(obs: Union[np.ndarray, Dict[str, np.ndarray]], device: th.device) -> Union[th.Tensor, TensorDict]: """ Moves the observation to the given device. diff --git a/stable_baselines3/dqn/dqn.py b/stable_baselines3/dqn/dqn.py index 3fabd59e3a..b85a30f808 100644 --- a/stable_baselines3/dqn/dqn.py +++ b/stable_baselines3/dqn/dqn.py @@ -11,7 +11,7 @@ from stable_baselines3.common.policies import BasePolicy from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule from stable_baselines3.common.utils import get_linear_fn, get_parameters_by_name, polyak_update -from stable_baselines3.dqn.policies import CnnPolicy, DQNPolicy, MlpPolicy, MultiInputPolicy +from stable_baselines3.dqn.policies import CnnPolicy, DQNPolicy, MlpPolicy, MultiInputPolicy, QNetwork SelfDQN = TypeVar("SelfDQN", bound="DQN") @@ -67,6 +67,11 @@ class DQN(OffPolicyAlgorithm): "CnnPolicy": CnnPolicy, "MultiInputPolicy": MultiInputPolicy, } + # Linear schedule will be defined in `_setup_model()` + exploration_schedule: Schedule + q_net: QNetwork + q_net_target: QNetwork + policy: DQNPolicy def __init__( self, @@ -131,10 +136,6 @@ def __init__( self.max_grad_norm = max_grad_norm # "epsilon" for the epsilon-greedy exploration self.exploration_rate = 0.0 - # Linear schedule will be defined in `_setup_model()` - self.exploration_schedule: Schedule - self.q_net: th.nn.Module - self.q_net_target: th.nn.Module if _init_setup_model: self._setup_model() @@ -164,8 +165,6 @@ def _setup_model(self) -> None: self.target_update_interval = max(self.target_update_interval // self.n_envs, 1) def _create_aliases(self) -> None: - # For type checker: - assert isinstance(self.policy, DQNPolicy) self.q_net = self.policy.q_net self.q_net_target = self.policy.q_net_target diff --git a/stable_baselines3/dqn/policies.py b/stable_baselines3/dqn/policies.py index 2527357381..fcdb95890b 100644 --- a/stable_baselines3/dqn/policies.py +++ b/stable_baselines3/dqn/policies.py @@ -27,6 +27,8 @@ class QNetwork(BasePolicy): dividing by 255.0 (True by default) """ + action_space: spaces.Discrete + def __init__( self, observation_space: spaces.Space, @@ -50,7 +52,6 @@ def __init__( self.net_arch = net_arch self.activation_fn = activation_fn self.features_dim = features_dim - assert isinstance(self.action_space, spaces.Discrete) action_dim = int(self.action_space.n) # number of actions q_net = create_mlp(self.features_dim, action_dim, self.net_arch, self.activation_fn) self.q_net = nn.Sequential(*q_net) @@ -62,8 +63,6 @@ def forward(self, obs: th.Tensor) -> th.Tensor: :param obs: Observation :return: The estimated Q-Value for each action. """ - # For type checker: - assert isinstance(self.features_extractor, BaseFeaturesExtractor) return self.q_net(self.extract_features(obs, self.features_extractor)) def _predict(self, observation: th.Tensor, deterministic: bool = True) -> th.Tensor: @@ -106,6 +105,9 @@ class DQNPolicy(BasePolicy): excluding the learning rate, to pass to the optimizer """ + q_net: QNetwork + q_net_target: QNetwork + def __init__( self, observation_space: spaces.Space, @@ -146,8 +148,6 @@ def __init__( "normalize_images": normalize_images, } - self.q_net: QNetwork - self.q_net_target: QNetwork self._build(lr_schedule) def _build(self, lr_schedule: Schedule) -> None: diff --git a/stable_baselines3/ppo/ppo.py b/stable_baselines3/ppo/ppo.py index 1e3d3ad1a8..0df51dc61d 100644 --- a/stable_baselines3/ppo/ppo.py +++ b/stable_baselines3/ppo/ppo.py @@ -183,10 +183,10 @@ def train(self) -> None: # Update optimizer learning rate self._update_learning_rate(self.policy.optimizer) # Compute current clip range - clip_range = self.clip_range(self._current_progress_remaining) + clip_range = self.clip_range(self._current_progress_remaining) # type: ignore[operator] # Optional: clip range for the value function if self.clip_range_vf is not None: - clip_range_vf = self.clip_range_vf(self._current_progress_remaining) + clip_range_vf = self.clip_range_vf(self._current_progress_remaining) # type: ignore[operator] entropy_losses = [] pg_losses, value_losses = [], [] diff --git a/stable_baselines3/sac/policies.py b/stable_baselines3/sac/policies.py index 418e5cc227..8902629d4a 100644 --- a/stable_baselines3/sac/policies.py +++ b/stable_baselines3/sac/policies.py @@ -45,10 +45,12 @@ class Actor(BasePolicy): dividing by 255.0 (True by default) """ + action_space: spaces.Box + def __init__( self, observation_space: spaces.Space, - action_space: spaces.Space, + action_space: spaces.Box, net_arch: List[int], features_extractor: nn.Module, features_dim: int, @@ -96,9 +98,9 @@ def __init__( if clip_mean > 0.0: self.mu = nn.Sequential(self.mu, nn.Hardtanh(min_val=-clip_mean, max_val=clip_mean)) else: - self.action_dist = SquashedDiagGaussianDistribution(action_dim) + self.action_dist = SquashedDiagGaussianDistribution(action_dim) # type: ignore[assignment] self.mu = nn.Linear(last_layer_dim, action_dim) - self.log_std = nn.Linear(last_layer_dim, action_dim) + self.log_std = nn.Linear(last_layer_dim, action_dim) # type: ignore[assignment] def _get_constructor_parameters(self) -> Dict[str, Any]: data = super()._get_constructor_parameters() @@ -157,7 +159,7 @@ def get_action_dist_params(self, obs: th.Tensor) -> Tuple[th.Tensor, th.Tensor, if self.use_sde: return mean_actions, self.log_std, dict(latent_sde=latent_pi) # Unstructured exploration (Original implementation) - log_std = self.log_std(latent_pi) + log_std = self.log_std(latent_pi) # type: ignore[operator] # Original Implementation to cap the standard deviation log_std = th.clamp(log_std, LOG_STD_MIN, LOG_STD_MAX) return mean_actions, log_std, {} @@ -205,10 +207,14 @@ class SACPolicy(BasePolicy): between the actor and the critic (this saves computation time) """ + actor: Actor + critic: ContinuousCritic + critic_target: ContinuousCritic + def __init__( self, observation_space: spaces.Space, - action_space: spaces.Space, + action_space: spaces.Box, lr_schedule: Schedule, net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None, activation_fn: Type[nn.Module] = nn.ReLU, @@ -267,15 +273,17 @@ def __init__( } ) - self.actor, self.actor_target = None, None - self.critic, self.critic_target = None, None self.share_features_extractor = share_features_extractor self._build(lr_schedule) def _build(self, lr_schedule: Schedule) -> None: self.actor = self.make_actor() - self.actor.optimizer = self.optimizer_class(self.actor.parameters(), lr=lr_schedule(1), **self.optimizer_kwargs) + self.actor.optimizer = self.optimizer_class( + self.actor.parameters(), + lr=lr_schedule(1), # type: ignore[call-arg] + **self.optimizer_kwargs, + ) if self.share_features_extractor: self.critic = self.make_critic(features_extractor=self.actor.features_extractor) @@ -286,13 +294,17 @@ def _build(self, lr_schedule: Schedule) -> None: # Create a separate features extractor for the critic # this requires more memory and computation self.critic = self.make_critic(features_extractor=None) - critic_parameters = self.critic.parameters() + critic_parameters = list(self.critic.parameters()) # Critic target should not share the features extractor with critic self.critic_target = self.make_critic(features_extractor=None) self.critic_target.load_state_dict(self.critic.state_dict()) - self.critic.optimizer = self.optimizer_class(critic_parameters, lr=lr_schedule(1), **self.optimizer_kwargs) + self.critic.optimizer = self.optimizer_class( + critic_parameters, + lr=lr_schedule(1), # type: ignore[call-arg] + **self.optimizer_kwargs, + ) # Target networks should always be in eval mode self.critic_target.set_training_mode(False) @@ -386,7 +398,7 @@ class CnnPolicy(SACPolicy): def __init__( self, observation_space: spaces.Space, - action_space: spaces.Space, + action_space: spaces.Box, lr_schedule: Schedule, net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None, activation_fn: Type[nn.Module] = nn.ReLU, @@ -452,7 +464,7 @@ class MultiInputPolicy(SACPolicy): def __init__( self, observation_space: spaces.Space, - action_space: spaces.Space, + action_space: spaces.Box, lr_schedule: Schedule, net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None, activation_fn: Type[nn.Module] = nn.ReLU, diff --git a/stable_baselines3/sac/sac.py b/stable_baselines3/sac/sac.py index 278811196e..de344a4537 100644 --- a/stable_baselines3/sac/sac.py +++ b/stable_baselines3/sac/sac.py @@ -8,10 +8,10 @@ from stable_baselines3.common.buffers import ReplayBuffer from stable_baselines3.common.noise import ActionNoise from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm -from stable_baselines3.common.policies import BasePolicy +from stable_baselines3.common.policies import BasePolicy, ContinuousCritic from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule from stable_baselines3.common.utils import get_parameters_by_name, polyak_update -from stable_baselines3.sac.policies import CnnPolicy, MlpPolicy, MultiInputPolicy, SACPolicy +from stable_baselines3.sac.policies import Actor, CnnPolicy, MlpPolicy, MultiInputPolicy, SACPolicy SelfSAC = TypeVar("SelfSAC", bound="SAC") @@ -82,6 +82,10 @@ class SAC(OffPolicyAlgorithm): "CnnPolicy": CnnPolicy, "MultiInputPolicy": MultiInputPolicy, } + policy: SACPolicy + actor: Actor + critic: ContinuousCritic + critic_target: ContinuousCritic def __init__( self, @@ -137,7 +141,7 @@ def __init__( sde_sample_freq=sde_sample_freq, use_sde_at_warmup=use_sde_at_warmup, optimize_memory_usage=optimize_memory_usage, - supported_action_spaces=(spaces.Box), + supported_action_spaces=(spaces.Box,), support_multi_env=True, ) @@ -147,7 +151,7 @@ def __init__( # Inverse of the reward scale self.ent_coef = ent_coef self.target_update_interval = target_update_interval - self.ent_coef_optimizer = None + self.ent_coef_optimizer: Optional[th.optim.Adam] = None if _init_setup_model: self._setup_model() @@ -161,7 +165,7 @@ def _setup_model(self) -> None: # Target entropy is used when learning the entropy coefficient if self.target_entropy == "auto": # automatically set target entropy if needed - self.target_entropy = -np.prod(self.env.action_space.shape).astype(np.float32) + self.target_entropy = float(-np.prod(self.env.action_space.shape).astype(np.float32)) # type: ignore else: # Force conversion # this will also throw an error for unexpected string @@ -208,7 +212,7 @@ def train(self, gradient_steps: int, batch_size: int = 64) -> None: for gradient_step in range(gradient_steps): # Sample replay buffer - replay_data = self.replay_buffer.sample(batch_size, env=self._vec_normalize_env) + replay_data = self.replay_buffer.sample(batch_size, env=self._vec_normalize_env) # type: ignore[union-attr] # We need to sample because `log_std` may have changed between two gradient steps if self.use_sde: @@ -219,7 +223,7 @@ def train(self, gradient_steps: int, batch_size: int = 64) -> None: log_prob = log_prob.reshape(-1, 1) ent_coef_loss = None - if self.ent_coef_optimizer is not None: + if self.ent_coef_optimizer is not None and self.log_ent_coef is not None: # Important: detach the variable from the graph # so we don't change it with other losses # see https://github.com/rail-berkeley/softlearning/issues/60 @@ -233,7 +237,7 @@ def train(self, gradient_steps: int, batch_size: int = 64) -> None: # Optimize entropy coefficient, also called # entropy temperature or alpha in the paper - if ent_coef_loss is not None: + if ent_coef_loss is not None and self.ent_coef_optimizer is not None: self.ent_coef_optimizer.zero_grad() ent_coef_loss.backward() self.ent_coef_optimizer.step() @@ -255,7 +259,8 @@ def train(self, gradient_steps: int, batch_size: int = 64) -> None: # Compute critic loss critic_loss = 0.5 * sum(F.mse_loss(current_q, target_q_values) for current_q in current_q_values) - critic_losses.append(critic_loss.item()) + assert isinstance(critic_loss, th.Tensor) # for type checker + critic_losses.append(critic_loss.item()) # type: ignore[union-attr] # Optimize the critic self.critic.optimizer.zero_grad() diff --git a/stable_baselines3/td3/policies.py b/stable_baselines3/td3/policies.py index 10546083a1..12117df89a 100644 --- a/stable_baselines3/td3/policies.py +++ b/stable_baselines3/td3/policies.py @@ -35,7 +35,7 @@ class Actor(BasePolicy): def __init__( self, observation_space: spaces.Space, - action_space: spaces.Space, + action_space: spaces.Box, net_arch: List[int], features_extractor: nn.Module, features_dim: int, @@ -106,10 +106,15 @@ class TD3Policy(BasePolicy): between the actor and the critic (this saves computation time) """ + actor: Actor + actor_target: Actor + critic: ContinuousCritic + critic_target: ContinuousCritic + def __init__( self, observation_space: spaces.Space, - action_space: spaces.Space, + action_space: spaces.Box, lr_schedule: Schedule, net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None, activation_fn: Type[nn.Module] = nn.ReLU, @@ -160,8 +165,6 @@ def __init__( } ) - self.actor, self.actor_target = None, None - self.critic, self.critic_target = None, None self.share_features_extractor = share_features_extractor self._build(lr_schedule) @@ -174,7 +177,11 @@ def _build(self, lr_schedule: Schedule) -> None: # Initialize the target to have the same weights as the actor self.actor_target.load_state_dict(self.actor.state_dict()) - self.actor.optimizer = self.optimizer_class(self.actor.parameters(), lr=lr_schedule(1), **self.optimizer_kwargs) + self.actor.optimizer = self.optimizer_class( + self.actor.parameters(), + lr=lr_schedule(1), # type: ignore[call-arg] + **self.optimizer_kwargs, + ) if self.share_features_extractor: self.critic = self.make_critic(features_extractor=self.actor.features_extractor) @@ -190,7 +197,11 @@ def _build(self, lr_schedule: Schedule) -> None: self.critic_target = self.make_critic(features_extractor=None) self.critic_target.load_state_dict(self.critic.state_dict()) - self.critic.optimizer = self.optimizer_class(self.critic.parameters(), lr=lr_schedule(1), **self.optimizer_kwargs) + self.critic.optimizer = self.optimizer_class( + self.critic.parameters(), + lr=lr_schedule(1), # type: ignore[call-arg] + **self.optimizer_kwargs, + ) # Target networks should always be in eval mode self.actor_target.set_training_mode(False) @@ -272,7 +283,7 @@ class CnnPolicy(TD3Policy): def __init__( self, observation_space: spaces.Space, - action_space: spaces.Space, + action_space: spaces.Box, lr_schedule: Schedule, net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None, activation_fn: Type[nn.Module] = nn.ReLU, @@ -326,7 +337,7 @@ class MultiInputPolicy(TD3Policy): def __init__( self, observation_space: spaces.Dict, - action_space: spaces.Space, + action_space: spaces.Box, lr_schedule: Schedule, net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None, activation_fn: Type[nn.Module] = nn.ReLU, diff --git a/stable_baselines3/td3/td3.py b/stable_baselines3/td3/td3.py index 8f73bbb66d..10cea8efa3 100644 --- a/stable_baselines3/td3/td3.py +++ b/stable_baselines3/td3/td3.py @@ -8,10 +8,10 @@ from stable_baselines3.common.buffers import ReplayBuffer from stable_baselines3.common.noise import ActionNoise from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm -from stable_baselines3.common.policies import BasePolicy +from stable_baselines3.common.policies import BasePolicy, ContinuousCritic from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule from stable_baselines3.common.utils import get_parameters_by_name, polyak_update -from stable_baselines3.td3.policies import CnnPolicy, MlpPolicy, MultiInputPolicy, TD3Policy +from stable_baselines3.td3.policies import Actor, CnnPolicy, MlpPolicy, MultiInputPolicy, TD3Policy SelfTD3 = TypeVar("SelfTD3", bound="TD3") @@ -70,6 +70,11 @@ class TD3(OffPolicyAlgorithm): "CnnPolicy": CnnPolicy, "MultiInputPolicy": MultiInputPolicy, } + policy: TD3Policy + actor: Actor + actor_target: Actor + critic: ContinuousCritic + critic_target: ContinuousCritic def __init__( self, @@ -120,7 +125,7 @@ def __init__( seed=seed, sde_support=False, optimize_memory_usage=optimize_memory_usage, - supported_action_spaces=(spaces.Box), + supported_action_spaces=(spaces.Box,), support_multi_env=True, ) @@ -157,7 +162,7 @@ def train(self, gradient_steps: int, batch_size: int = 100) -> None: for _ in range(gradient_steps): self._n_updates += 1 # Sample replay buffer - replay_data = self.replay_buffer.sample(batch_size, env=self._vec_normalize_env) + replay_data = self.replay_buffer.sample(batch_size, env=self._vec_normalize_env) # type: ignore[union-attr] with th.no_grad(): # Select action according to policy and add clipped noise @@ -175,6 +180,7 @@ def train(self, gradient_steps: int, batch_size: int = 100) -> None: # Compute critic loss critic_loss = sum(F.mse_loss(current_q, target_q_values) for current_q in current_q_values) + assert isinstance(critic_loss, th.Tensor) critic_losses.append(critic_loss.item()) # Optimize the critics From 7f4ded59df563cd60653d82ea53139cf72fbf2a1 Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Thu, 13 Apr 2023 15:45:13 +0200 Subject: [PATCH 150/153] Disable mypy for python 3.7 --- .github/workflows/ci.yml | 2 ++ docs/misc/changelog.rst | 2 +- stable_baselines3/version.txt | 2 +- 3 files changed, 4 insertions(+), 2 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 52f5f38951..7c238cdd4b 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -55,6 +55,8 @@ jobs: - name: Type check run: | make type + # skip mypy type check for python3.7 (result is different to all other versions) + if: "!(matrix.python-version == '3.7')" - name: Test with pytest run: | make pytest diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index fddda2a40c..5f2d0b7604 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -3,7 +3,7 @@ Changelog ========== -Release 2.0.0a3 (WIP) +Release 2.0.0a4 (WIP) -------------------------- **Gymnasium support** diff --git a/stable_baselines3/version.txt b/stable_baselines3/version.txt index 82d45615df..997bba2390 100644 --- a/stable_baselines3/version.txt +++ b/stable_baselines3/version.txt @@ -1 +1 @@ -2.0.0a3 +2.0.0a4 From eafd84cd2fda34dacc3ac47ed284e28a1acde6d6 Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Thu, 13 Apr 2023 15:49:29 +0200 Subject: [PATCH 151/153] Rename Gym26StepReturn --- stable_baselines3/common/envs/bit_flipping_env.py | 4 ++-- stable_baselines3/common/envs/identity_env.py | 4 ++-- stable_baselines3/common/envs/multi_input_envs.py | 4 ++-- stable_baselines3/common/type_aliases.py | 5 ++--- 4 files changed, 8 insertions(+), 9 deletions(-) diff --git a/stable_baselines3/common/envs/bit_flipping_env.py b/stable_baselines3/common/envs/bit_flipping_env.py index 11e9377b94..ec0de2bf9d 100644 --- a/stable_baselines3/common/envs/bit_flipping_env.py +++ b/stable_baselines3/common/envs/bit_flipping_env.py @@ -5,7 +5,7 @@ from gymnasium import Env, spaces from gymnasium.envs.registration import EnvSpec -from stable_baselines3.common.type_aliases import Gym26StepReturn +from stable_baselines3.common.type_aliases import GymStepReturn class BitFlippingEnv(Env): @@ -166,7 +166,7 @@ def reset( self.state = self.obs_space.sample() return self._get_obs(), {} - def step(self, action: Union[np.ndarray, int]) -> Gym26StepReturn: + def step(self, action: Union[np.ndarray, int]) -> GymStepReturn: """ Step into the env. diff --git a/stable_baselines3/common/envs/identity_env.py b/stable_baselines3/common/envs/identity_env.py index 05e8023bfa..99a6649997 100644 --- a/stable_baselines3/common/envs/identity_env.py +++ b/stable_baselines3/common/envs/identity_env.py @@ -4,7 +4,7 @@ import numpy as np from gymnasium import spaces -from stable_baselines3.common.type_aliases import Gym26StepReturn +from stable_baselines3.common.type_aliases import GymStepReturn T = TypeVar("T", int, np.ndarray) @@ -148,7 +148,7 @@ def reset(self, *, seed: Optional[int] = None, options: Optional[Dict] = None) - self.current_step = 0 return self.observation_space.sample(), {} - def step(self, action: Union[np.ndarray, int]) -> Gym26StepReturn: + def step(self, action: Union[np.ndarray, int]) -> GymStepReturn: reward = 0.0 self.current_step += 1 terminated = False diff --git a/stable_baselines3/common/envs/multi_input_envs.py b/stable_baselines3/common/envs/multi_input_envs.py index 1717405226..3bb07106a6 100644 --- a/stable_baselines3/common/envs/multi_input_envs.py +++ b/stable_baselines3/common/envs/multi_input_envs.py @@ -4,7 +4,7 @@ import numpy as np from gymnasium import spaces -from stable_baselines3.common.type_aliases import Gym26StepReturn +from stable_baselines3.common.type_aliases import GymStepReturn class SimpleMultiObsEnv(gym.Env): @@ -121,7 +121,7 @@ def init_possible_transitions(self) -> None: self.right_possible = [0, 1, 2, 12, 13, 14] self.up_possible = [4, 8, 12, 7, 11, 15] - def step(self, action: Union[float, np.ndarray]) -> Gym26StepReturn: + def step(self, action: Union[float, np.ndarray]) -> GymStepReturn: """ Run one timestep of the environment's dynamics. When end of episode is reached, you are responsible for calling `reset()` diff --git a/stable_baselines3/common/type_aliases.py b/stable_baselines3/common/type_aliases.py index 8744923da1..d38d7cf737 100644 --- a/stable_baselines3/common/type_aliases.py +++ b/stable_baselines3/common/type_aliases.py @@ -17,10 +17,9 @@ GymEnv = Union[gym.Env, vec_env.VecEnv] GymObs = Union[Tuple, Dict[str, Any], np.ndarray, int] -Gym26ResetReturn = Tuple[GymObs, Dict] +GymResetReturn = Tuple[GymObs, Dict] AtariResetReturn = Tuple[np.ndarray, Dict[str, Any]] -GymStepReturn = Tuple[GymObs, float, bool, Dict] -Gym26StepReturn = Tuple[GymObs, float, bool, bool, Dict] +GymStepReturn = Tuple[GymObs, float, bool, bool, Dict] AtariStepReturn = Tuple[np.ndarray, SupportsFloat, bool, bool, Dict[str, Any]] TensorDict = Dict[str, th.Tensor] OptimizerStateDict = Dict[str, Any] From e613373f35ce4f9ab31ba3ff79c90c1f1f98de01 Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Thu, 13 Apr 2023 16:50:03 +0200 Subject: [PATCH 152/153] Update continuous critic type annotation --- stable_baselines3/common/policies.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/stable_baselines3/common/policies.py b/stable_baselines3/common/policies.py index 32be95e831..fb6f1ee003 100644 --- a/stable_baselines3/common/policies.py +++ b/stable_baselines3/common/policies.py @@ -898,10 +898,12 @@ class ContinuousCritic(BaseModel): between the actor and the critic (this saves computation time) """ + features_extractor: BaseFeaturesExtractor + def __init__( self, observation_space: spaces.Space, - action_space: spaces.Space, + action_space: spaces.Box, net_arch: List[int], features_extractor: nn.Module, features_dim: int, From 33cd8db3b54ac503c1b8c513bfc2d9c3425d2333 Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Thu, 13 Apr 2023 17:46:57 +0200 Subject: [PATCH 153/153] Fix pytype complain --- stable_baselines3/common/policies.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/stable_baselines3/common/policies.py b/stable_baselines3/common/policies.py index fb6f1ee003..21d2034d66 100644 --- a/stable_baselines3/common/policies.py +++ b/stable_baselines3/common/policies.py @@ -898,14 +898,12 @@ class ContinuousCritic(BaseModel): between the actor and the critic (this saves computation time) """ - features_extractor: BaseFeaturesExtractor - def __init__( self, observation_space: spaces.Space, action_space: spaces.Box, net_arch: List[int], - features_extractor: nn.Module, + features_extractor: BaseFeaturesExtractor, features_dim: int, activation_fn: Type[nn.Module] = nn.ReLU, normalize_images: bool = True,