From e6b0a40152d1df1cd10d180b562df43628c8ab70 Mon Sep 17 00:00:00 2001
From: arjun_kg
Date: Thu, 14 Apr 2022 21:26:04 +0530
Subject: [PATCH 01/37] New Step API with terminated, truncated bools instead of done

Old API - `done=True` if the episode ends in any way.
New API - `terminated=True` if the environment terminates (e.g. due to task
completion or failure); `truncated=True` if the episode is truncated by a time
limit.

Changes
1. All existing environments are changed to the new API, without direct
   support for the old API.
2. Vector envs are changed to the new API, without direct support for the
   old API.
3. All wrappers (except TimeLimit and OrderEnforcing) are changed to the new
   API, without direct support for the old API.
4. TimeLimit, OrderEnforcing, and wrappers which don't change the step
   function support both APIs.

StepCompatibility Wrapper
1. This wrapper is added to convert between the old and new APIs in either
   direction.
2. Takes a `return_two_dones` argument in `__init__`; `True` (new API) by
   default.
3. The wrapper is applied in `make`, with `return_two_dones=True` by default.
   This can be overridden per environment, e.g.
   `gym.make("CartPole-v1", return_two_dones=False)`.

StepCompatibilityVector Wrapper - Transforms a vector environment to the old
API when constructed with `return_two_dones=False`.

Misc
1. Autoreset bug fixed (hacky) by using a local variable instead of mutating
   `self.autoreset` on the `EnvSpec`.
---
 README.md | 2 +-
 gym/core.py | 39 ++++++--
 gym/envs/box2d/bipedal_walker.py | 14 +--
 gym/envs/box2d/car_racing.py | 14 +--
 gym/envs/box2d/lunar_lander.py | 14 +--
 gym/envs/classic_control/acrobot.py | 14 +--
 gym/envs/classic_control/cartpole.py | 33 +++----
 .../continuous_mountain_car.py | 16 ++--
 gym/envs/classic_control/mountain_car.py | 14 +--
 gym/envs/classic_control/pendulum.py | 6 +-
 gym/envs/mujoco/ant.py | 9 +-
 gym/envs/mujoco/ant_v3.py | 22 ++---
 gym/envs/mujoco/half_cheetah.py | 10 +-
 gym/envs/mujoco/half_cheetah_v3.py | 8 +-
 gym/envs/mujoco/hopper.py | 4 +-
 gym/envs/mujoco/hopper_v3.py | 22 ++---
 gym/envs/mujoco/humanoid.py | 5 +-
 gym/envs/mujoco/humanoid_v3.py | 22 ++---
 gym/envs/mujoco/humanoidstandup.py | 12 +--
 gym/envs/mujoco/inverted_double_pendulum.py | 14 +--
 gym/envs/mujoco/inverted_pendulum.py | 15 ++-
 gym/envs/mujoco/mujoco_env.py | 4 +-
 gym/envs/mujoco/pusher.py | 17 ++--
 gym/envs/mujoco/reacher.py | 17 ++--
 gym/envs/mujoco/swimmer.py | 8 +-
 gym/envs/mujoco/swimmer_v3.py | 7 +-
 gym/envs/mujoco/walker2d.py | 4 +-
 gym/envs/mujoco/walker2d_v3.py | 10 +-
 gym/envs/registration.py | 18 +++-
 gym/envs/toy_text/blackjack.py | 8 +-
 gym/envs/toy_text/cliffwalking.py | 10 +-
 gym/envs/toy_text/frozen_lake.py | 8 +-
 gym/envs/toy_text/taxi.py | 12 ++-
 gym/utils/env_checker.py | 17 +++-
 gym/utils/play.py | 17 ++--
 gym/vector/__init__.py | 11 ++-
 gym/vector/async_vector_env.py | 30 +++---
 gym/vector/step_compatibility_vector.py | 36 ++++++++
 gym/vector/sync_vector_env.py | 21 +++--
 gym/vector/vector_env.py | 7 +-
 gym/wrappers/__init__.py | 1 +
 gym/wrappers/atari_preprocessing.py | 16 ++--
 gym/wrappers/autoreset.py | 42 ++++-----
 gym/wrappers/frame_stack.py | 4 +-
 gym/wrappers/normalize.py | 12 ++-
 gym/wrappers/order_enforcing.py | 3 +-
 gym/wrappers/record_episode_statistics.py | 12 ++-
 gym/wrappers/record_video.py | 14 +--
 gym/wrappers/step_compatibility.py | 69 ++++++++++++++
 gym/wrappers/time_limit.py | 20 ++--
 gym/wrappers/transform_reward.py | 2 +-
 tests/envs/test_determinism.py | 5 +-
 tests/envs/test_envs.py | 9 +-
 tests/envs/test_mujoco_v2_to_v3_conversion.py | 19 +++-
 tests/test_core.py | 7 +-
 tests/utils/test_env_checker.py | 9 +-
 tests/utils/test_terminated_truncated.py | 91 +++++++++++++++++++
 tests/vector/test_async_vector_env.py | 23 +++-
 .../vector/test_step_compatibility_vector.py | 89 ++++++++++++++++++
 tests/vector/test_sync_vector_env.py | 17 ++--
 tests/vector/test_vector_env.py | 16 ++--
 tests/vector/utils.py | 8 +-
 tests/wrappers/nested_dict_test.py | 4 +-
 tests/wrappers/test_atari_preprocessing.py | 6 +-
 tests/wrappers/test_autoreset.py | 37 +++++---
 tests/wrappers/test_clip_action.py | 7 +-
 tests/wrappers/test_filter_observation.py | 4 +-
 tests/wrappers/test_frame_stack.py | 4 +-
 tests/wrappers/test_normalize.py | 16 +++-
 tests/wrappers/test_pixel_observation.py | 4 +-
 .../test_record_episode_statistics.py | 8 +-
 tests/wrappers/test_record_video.py | 10 +-
 tests/wrappers/test_rescale_action.py | 4 +-
 tests/wrappers/test_step_compatibility.py | 75 +++++++++++++++
 tests/wrappers/test_time_aware_observation.py | 4 +-
 tests/wrappers/test_time_limit_info.py | 1 +
 tests/wrappers/test_transform_observation.py | 13 ++-
 tests/wrappers/test_transform_reward.py | 12 +--
 78 files changed, 899 insertions(+), 369 deletions(-)
 create mode 100644 gym/vector/step_compatibility_vector.py
 create mode 100644 gym/wrappers/step_compatibility.py
 create mode 100644 tests/utils/test_terminated_truncated.py
 create mode 100644 tests/vector/test_step_compatibility_vector.py
 create mode 100644 tests/wrappers/test_step_compatibility.py
 create mode 100644 tests/wrappers/test_time_limit_info.py

diff --git a/README.md b/README.md
index 37b63f2d162..737f16bce82 100644
--- a/README.md
+++ b/README.md
@@ -27,7 +27,7 @@ observation, info = env.reset(seed=42, return_info=True)
 for _ in range(1000):
     action = env.action_space.sample()
-    observation, reward, done, info = env.step(action)
+    observation, reward, terminated, truncated, info = env.step(action)
 
-    if done:
+    if terminated or truncated:
         observation, info = env.reset(return_info=True)
 
diff --git a/gym/core.py b/gym/core.py
index eabaea69508..cad98c2ec93 100644
--- a/gym/core.py
+++ b/gym/core.py
@@ -61,12 +61,17 @@ def np_random(self, value: RandomNumberGenerator):
         self._np_random = value
 
     @abstractmethod
-    def step(self, action: ActType) -> Tuple[ObsType, float, bool, dict]:
+    def step(
+        self, action: ActType
+    ) -> Union[
+        Tuple[ObsType, float, bool, bool, dict], Tuple[ObsType, float, bool, dict]
+    ]:
         """Run one timestep of the environment's dynamics. When end of
         episode is reached, you are responsible for calling :meth:`reset`
         to reset this environment's state.
 
-        Accepts an action and returns a tuple (observation, reward, done, info).
+        Accepts an action and returns either a tuple (observation, reward, terminated, truncated, info) or a tuple
+        (observation, reward, done, info). The latter is deprecated and will be removed in future versions.
 
         Args:
             action (object): an action provided by the agent
 
@@ -76,13 +81,17 @@
         Returns:
             observation (object): agent's observation of the current environment. This will be an element of the environment's :attr:`observation_space`. This may, for instance, be a numpy array containing the positions and velocities of certain objects.
             reward (float) : amount of reward returned after previous action
-            done (bool): whether the episode has ended, in which case further :meth:`step` calls will return undefined results.
A done signal may be emitted for different reasons: Maybe the task underlying the environment was solved successfully, a certain timelimit was exceeded, or the physics simulation has entered an invalid state. ``info`` may contain additional information regarding the reason for a ``done`` signal. + terminated (bool): whether the episode has ended due to a termination, in which case further step() calls will return undefined results + truncated (bool): whether the episode has ended due to a truncation, in which case further step() calls will return undefined results info (dict): contains auxiliary diagnostic information (helpful for debugging, learning, and logging). This might, for instance, contain: - metrics that describe the agent's performance or - state variables that are hidden from observations or - information that distinguishes truncation and termination or - individual reward terms that are combined to produce the total reward + + (deprecated) + done (bool): whether the episode has ended due to any reason, in which case further step() calls will return undefined results """ raise NotImplementedError @@ -290,7 +299,11 @@ def metadata(self) -> dict: def metadata(self, value): self._metadata = value - def step(self, action: ActType) -> Tuple[ObsType, float, bool, dict]: + def step( + self, action: ActType + ) -> Union[ + Tuple[ObsType, float, bool, bool, dict], Tuple[ObsType, float, bool, dict] + ]: return self.env.step(action) def reset(self, **kwargs) -> Union[ObsType, tuple[ObsType, dict]]: @@ -325,8 +338,13 @@ def reset(self, **kwargs): return self.observation(self.env.reset(**kwargs)) def step(self, action): - observation, reward, done, info = self.env.step(action) - return self.observation(observation), reward, done, info + step_returns = self.env.step(action) + if len(step_returns) == 5: + observation, reward, terminated, truncated, info = step_returns + return self.observation(observation), reward, terminated, truncated, info + else: + observation, reward, done, info = step_returns + return self.observation(observation), reward, done, info @abstractmethod def observation(self, observation): @@ -338,8 +356,13 @@ def reset(self, **kwargs): return self.env.reset(**kwargs) def step(self, action): - observation, reward, done, info = self.env.step(action) - return observation, self.reward(reward), done, info + step_returns = self.env.step(action) + if len(step_returns) == 5: + observation, reward, terminated, truncated, info = step_returns + return observation, self.reward(reward), terminated, truncated, info + else: + observation, reward, done, info = step_returns + return observation, self.reward(reward), done, info @abstractmethod def reward(self, reward): diff --git a/gym/envs/box2d/bipedal_walker.py b/gym/envs/box2d/bipedal_walker.py index ee072cf5311..ce260ce5064 100644 --- a/gym/envs/box2d/bipedal_walker.py +++ b/gym/envs/box2d/bipedal_walker.py @@ -581,13 +581,13 @@ def step(self, action: np.ndarray): reward -= 0.00035 * MOTORS_TORQUE * np.clip(np.abs(a), 0, 1) # normalized to about -50.0 using heuristic, more optimal agent should spend less - done = False + terminated = False if self.game_over or pos[0] < 0: reward = -100 - done = True + terminated = True if pos[0] > (TERRAIN_LENGTH - TERRAIN_GRASS) * TERRAIN_STEP: - done = True - return np.array(state, dtype=np.float32), reward, done, {} + terminated = True + return np.array(state, dtype=np.float32), reward, terminated, False, {} def render(self, mode: str = "human"): import pygame @@ -757,9 +757,9 @@ def __init__(self): 
SUPPORT_KNEE_ANGLE = +0.1 supporting_knee_angle = SUPPORT_KNEE_ANGLE while True: - s, r, done, info = env.step(a) + s, r, terminated, truncated, info = env.step(a) total_reward += r - if steps % 20 == 0 or done: + if steps % 20 == 0 or terminated or truncated: print("\naction " + str([f"{x:+0.2f}" for x in a])) print(f"step {steps} total_reward {total_reward:+0.2f}") print("hull " + str([f"{x:+0.2f}" for x in s[0:4]])) @@ -823,5 +823,5 @@ def __init__(self): a = np.clip(0.5 * a, -1.0, 1.0) env.render() - if done: + if terminated or truncated: break diff --git a/gym/envs/box2d/car_racing.py b/gym/envs/box2d/car_racing.py index b95c2cde04b..8e28dffc1ba 100644 --- a/gym/envs/box2d/car_racing.py +++ b/gym/envs/box2d/car_racing.py @@ -415,7 +415,7 @@ def step(self, action): self.state = self.render("state_pixels") step_reward = 0 - done = False + terminated = False if action is not None: # First step without action, called from reset() self.reward -= 0.1 # We actually don't want to count fuel spent, we want car to be faster. @@ -424,13 +424,13 @@ def step(self, action): step_reward = self.reward - self.prev_reward self.prev_reward = self.reward if self.tile_visited_count == len(self.track) or self.new_lap: - done = True + terminated = True x, y = self.car.hull.position if abs(x) > PLAYFIELD or abs(y) > PLAYFIELD: - done = True + terminated = True step_reward = -100 - return self.state, step_reward, done, {} + return self.state, step_reward, terminated, False, {} def render(self, mode="human"): import pygame @@ -660,13 +660,13 @@ def register_input(): restart = False while True: register_input() - s, r, done, info = env.step(a) + s, r, terminated, truncated, info = env.step(a) total_reward += r - if steps % 200 == 0 or done: + if steps % 200 == 0 or terminated or truncated: print("\naction " + str([f"{x:+0.2f}" for x in a])) print(f"step {steps} total_reward {total_reward:+0.2f}") steps += 1 isopen = env.render() - if done or restart or isopen == False: + if terminated or truncated or restart or isopen == False: break env.close() diff --git a/gym/envs/box2d/lunar_lander.py b/gym/envs/box2d/lunar_lander.py index 0d9c6c0484b..a93e50a6061 100644 --- a/gym/envs/box2d/lunar_lander.py +++ b/gym/envs/box2d/lunar_lander.py @@ -473,14 +473,14 @@ def step(self, action): ) # less fuel spent is better, about -30 for heuristic landing reward -= s_power * 0.03 - done = False + terminated = False if self.game_over or abs(state[0]) >= 1.0: - done = True + terminated = True reward = -100 if not self.lander.awake: - done = True + terminated = True reward = +100 - return np.array(state, dtype=np.float32), reward, done, {} + return np.array(state, dtype=np.float32), reward, terminated, False, {} def render(self, mode="human"): import pygame @@ -654,7 +654,7 @@ def demo_heuristic_lander(env, seed=None, render=False): s = env.reset(seed=seed) while True: a = heuristic(env, s) - s, r, done, info = env.step(a) + s, r, terminated, truncated, info = env.step(a) total_reward += r if render: @@ -662,11 +662,11 @@ def demo_heuristic_lander(env, seed=None, render=False): if still_open == False: break - if steps % 20 == 0 or done: + if steps % 20 == 0 or terminated or truncated: print("observations:", " ".join([f"{x:+0.2f}" for x in s])) print(f"step {steps} total_reward {total_reward:+0.2f}") steps += 1 - if done: + if terminated or truncated: break if render: env.close() diff --git a/gym/envs/classic_control/acrobot.py b/gym/envs/classic_control/acrobot.py index 78b5d08944f..69de3919398 100644 --- 
a/gym/envs/classic_control/acrobot.py +++ b/gym/envs/classic_control/acrobot.py @@ -82,12 +82,12 @@ class AcrobotEnv(core.Env): Each parameter in the underlying state (`theta1`, `theta2`, and the two angular velocities) is initialized uniformly between -0.1 and 0.1. This means both links are pointing downwards with some initial stochasticity. - ### Episode Termination + ### Episode End - The episode terminates if one of the following occurs: - 1. The free end reaches the target height, which is constructed as: + The episode ends if one of the following occurs: + 1. Termination: The free end reaches the target height, which is constructed as: `-cos(theta1) - cos(theta2 + theta1) > 1.0` - 2. Episode length is greater than 500 (200 for v0) + 2. Truncation: Episode length is greater than 500 (200 for v0) ### Arguments @@ -206,9 +206,9 @@ def step(self, a): ns[2] = bound(ns[2], -self.MAX_VEL_1, self.MAX_VEL_1) ns[3] = bound(ns[3], -self.MAX_VEL_2, self.MAX_VEL_2) self.state = ns - terminal = self._terminal() - reward = -1.0 if not terminal else 0.0 - return (self._get_ob(), reward, terminal, {}) + terminated = self._terminal() + reward = -1.0 if not terminated else 0.0 + return (self._get_ob(), reward, terminated, False, {}) def _get_ob(self): s = self.state diff --git a/gym/envs/classic_control/cartpole.py b/gym/envs/classic_control/cartpole.py index 8bb67762f84..43ecd775d4d 100644 --- a/gym/envs/classic_control/cartpole.py +++ b/gym/envs/classic_control/cartpole.py @@ -56,12 +56,13 @@ class CartPoleEnv(gym.Env[np.ndarray, Union[int, np.ndarray]]): All observations are assigned a uniformly random value in `(-0.05, 0.05)` - ### Episode Termination + ### Episode End - The episode terminates if any one of the following occurs: - 1. Pole Angle is greater than ±12° - 2. Cart Position is greater than ±2.4 (center of the cart reaches the edge of the display) - 3. Episode length is greater than 500 (200 for v0) + The episode ends if any one of the following occurs: + + 1. Termination: Pole Angle is greater than ±12° + 2. Termination: Cart Position is greater than ±2.4 (center of the cart reaches the edge of the display) + 3. Truncation: Episode length is greater than 500 (200 for v0) ### Arguments @@ -109,7 +110,7 @@ def __init__(self): self.isopen = True self.state = None - self.steps_beyond_done = None + self.steps_beyond_terminated = None def step(self, action): err_msg = f"{action!r} ({type(action)}) invalid" @@ -143,31 +144,31 @@ def step(self, action): self.state = (x, x_dot, theta, theta_dot) - done = bool( + terminated = bool( x < -self.x_threshold or x > self.x_threshold or theta < -self.theta_threshold_radians or theta > self.theta_threshold_radians ) - if not done: + if not terminated: reward = 1.0 - elif self.steps_beyond_done is None: + elif self.steps_beyond_terminated is None: # Pole just fell! - self.steps_beyond_done = 0 + self.steps_beyond_terminated = 0 reward = 1.0 else: - if self.steps_beyond_done == 0: + if self.steps_beyond_terminated == 0: logger.warn( "You are calling 'step()' even though this " - "environment has already returned done = True. You " - "should always call 'reset()' once you receive 'done = " + "environment has already returned terminated = True. You " + "should always call 'reset()' once you receive 'terminated = " "True' -- any further steps are undefined behavior." 
) - self.steps_beyond_done += 1 + self.steps_beyond_terminated += 1 reward = 0.0 - return np.array(self.state, dtype=np.float32), reward, done, {} + return np.array(self.state, dtype=np.float32), reward, terminated, False, {} def reset( self, @@ -178,7 +179,7 @@ def reset( ): super().reset(seed=seed) self.state = self.np_random.uniform(low=-0.05, high=0.05, size=(4,)) - self.steps_beyond_done = None + self.steps_beyond_terminated = None if not return_info: return np.array(self.state, dtype=np.float32) else: diff --git a/gym/envs/classic_control/continuous_mountain_car.py b/gym/envs/classic_control/continuous_mountain_car.py index b7a45e63c2d..049d766b585 100644 --- a/gym/envs/classic_control/continuous_mountain_car.py +++ b/gym/envs/classic_control/continuous_mountain_car.py @@ -76,11 +76,11 @@ class Continuous_MountainCarEnv(gym.Env): The position of the car is assigned a uniform random value in `[-0.6 , -0.4]`. The starting velocity of the car is always assigned to 0. - ### Episode Termination + ### Episode End - The episode terminates if either of the following happens: - 1. The position of the car is greater than or equal to 0.45 (the goal position on top of the right hill) - 2. The length of the episode is 999. + The episode ends if either of the following happens: + 1. Termination: The position of the car is greater than or equal to 0.45 (the goal position on top of the right hill) + 2. Truncation: The length of the episode is 999. ### Arguments @@ -145,15 +145,17 @@ def step(self, action: np.ndarray): velocity = 0 # Convert a possible numpy bool to a Python bool. - done = bool(position >= self.goal_position and velocity >= self.goal_velocity) + terminated = bool( + position >= self.goal_position and velocity >= self.goal_velocity + ) reward = 0 - if done: + if terminated: reward = 100.0 reward -= math.pow(action[0], 2) * 0.1 self.state = np.array([position, velocity], dtype=np.float32) - return self.state, reward, done, {} + return self.state, reward, terminated, False, {} def reset( self, diff --git a/gym/envs/classic_control/mountain_car.py b/gym/envs/classic_control/mountain_car.py index 6c9a555b517..9422165d500 100644 --- a/gym/envs/classic_control/mountain_car.py +++ b/gym/envs/classic_control/mountain_car.py @@ -72,11 +72,11 @@ class MountainCarEnv(gym.Env): The position of the car is assigned a uniform random value in *[-0.6 , -0.4]*. The starting velocity of the car is always assigned to 0. - ### Episode Termination + ### Episode End - The episode terminates if either of the following happens: - 1. The position of the car is greater than or equal to 0.5 (the goal position on top of the right hill) - 2. The length of the episode is 200. + The episode ends if either of the following happens: + 1. Termination: The position of the car is greater than or equal to 0.5 (the goal position on top of the right hill) + 2. Truncation: The length of the episode is 200. 
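The Termination/Truncation split documented above is what downstream training code keys on: a truncated episode may still have return beyond the final step, so value-based agents should bootstrap on `truncated` but not on `terminated`. A minimal sketch of a loop under the new API (the environment id and the loop itself are illustrative, not part of this patch):

```python
import gym

env = gym.make("MountainCar-v0")  # returns the new 5-tuple by default in this patch
obs = env.reset()
while True:
    obs, reward, terminated, truncated, info = env.step(env.action_space.sample())
    # terminated: the MDP reached a real end state -- do not bootstrap from obs.
    # truncated: the time limit cut the episode short -- bootstrapping is fine.
    if terminated or truncated:
        break
env.close()
```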
    ### Arguments
 
@@ -125,11 +125,13 @@ def step(self, action: int):
         if position == self.min_position and velocity < 0:
             velocity = 0
 
-        done = bool(position >= self.goal_position and velocity >= self.goal_velocity)
+        terminated = bool(
+            position >= self.goal_position and velocity >= self.goal_velocity
+        )
         reward = -1.0
 
         self.state = (position, velocity)
-        return np.array(self.state, dtype=np.float32), reward, done, {}
+        return np.array(self.state, dtype=np.float32), reward, terminated, False, {}
 
     def reset(
         self,
diff --git a/gym/envs/classic_control/pendulum.py b/gym/envs/classic_control/pendulum.py
index bd29042de54..1a27461742e 100644
--- a/gym/envs/classic_control/pendulum.py
+++ b/gym/envs/classic_control/pendulum.py
@@ -58,9 +58,9 @@ class PendulumEnv(gym.Env):
     The starting state is a random angle in *[-pi, pi]* and a random angular velocity in *[-1,1]*.
 
-    ### Episode Termination
+    ### Episode Truncation
 
-    The episode terminates at 200 time steps.
+    The episode truncates at 200 time steps.
 
     ### Arguments
 
@@ -118,7 +118,7 @@ def step(self, u):
         newth = th + newthdot * dt
 
         self.state = np.array([newth, newthdot])
-        return self._get_obs(), -costs, False, {}
+        return self._get_obs(), -costs, False, False, {}
 
     def reset(
         self,
diff --git a/gym/envs/mujoco/ant.py b/gym/envs/mujoco/ant.py
index e61b787db98..215915ac2d8 100644
--- a/gym/envs/mujoco/ant.py
+++ b/gym/envs/mujoco/ant.py
@@ -21,13 +21,16 @@ def step(self, a):
         survive_reward = 1.0
         reward = forward_reward - ctrl_cost - contact_cost + survive_reward
         state = self.state_vector()
-        notdone = np.isfinite(state).all() and state[2] >= 0.2 and state[2] <= 1.0
-        done = not notdone
+        not_terminated = (
+            np.isfinite(state).all() and state[2] >= 0.2 and state[2] <= 1.0
+        )
+        terminated = not not_terminated
         ob = self._get_obs()
         return (
             ob,
             reward,
-            done,
+            terminated,
+            False,
             dict(
                 reward_forward=forward_reward,
                 reward_ctrl=-ctrl_cost,
diff --git a/gym/envs/mujoco/ant_v3.py b/gym/envs/mujoco/ant_v3.py
index aeffa507523..06f5a6ef634 100644
--- a/gym/envs/mujoco/ant_v3.py
+++ b/gym/envs/mujoco/ant_v3.py
@@ -122,19 +122,19 @@ class AntEnv(mujoco_env.MujocoEnv, utils.EzPickle):
     to be slightly high, thereby indicating a standing up ant. The initial orientation
     is designed to make it face forward as well.
 
-    ### Episode Termination
+    ### Episode End
 
     The ant is said to be unhealthy if any of the following happens:
 
     1. Any of the state space values is no longer finite
     2. The z-coordinate of the torso is **not** in the closed interval given by `healthy_z_range` (defaults to [0.2, 1.0])
 
     If `terminate_when_unhealthy=True` is passed during construction (which is the default),
-    the episode terminates when any of the following happens:
+    the episode ends when any of the following happens:
 
-    1. The episode duration reaches a 1000 timesteps
-    2. The ant is unhealthy
+    1. Truncation: The episode duration reaches a 1000 timesteps
+    2. Termination: The ant is unhealthy
 
-    If `terminate_when_unhealthy=False` is passed, the episode is terminated only when 1000 timesteps are exceeded.
+    If `terminate_when_unhealthy=False` is passed, the episode is ended only when 1000 timesteps are exceeded.
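A sketch of what the `terminate_when_unhealthy` flag means for the two end signals; `Ant-v3` and its registered 1000-step limit come from the docstring above, the rest is illustrative:

```python
import gym

# With the default terminate_when_unhealthy=True, an unhealthy ant sets
# terminated=True; the TimeLimit wrapper reports the step cap via truncated.
env = gym.make("Ant-v3", terminate_when_unhealthy=False)
obs = env.reset()
while True:
    obs, reward, terminated, truncated, info = env.step(env.action_space.sample())
    assert not terminated  # health checks are disabled by the flag
    if truncated:  # only the 1000-step limit can end the episode now
        break
```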
### Arguments @@ -156,7 +156,7 @@ class AntEnv(mujoco_env.MujocoEnv, utils.EzPickle): | `ctrl_cost_weight` | **float** | `0.5` | Weight for *ctrl_cost* term (see section on reward) | | `contact_cost_weight` | **float** | `5e-4` | Weight for *contact_cost* term (see section on reward) | | `healthy_reward` | **float** | `1` | Constant reward given if the ant is "healthy" after timestep | - | `terminate_when_unhealthy` | **bool**| `True` | If true, issue a done signal if the z-coordinate of the torso is no longer in the `healthy_z_range` | + | `terminate_when_unhealthy` | **bool**| `True` | If true, issue a terminated signal if the z-coordinate of the torso is no longer in the `healthy_z_range` | | `healthy_z_range` | **tuple** | `(0.2, 1)` | The ant is considered healthy if the z-coordinate of the torso is in this range | | `contact_force_range` | **tuple** | `(-1, 1)` | Contact forces are clipped to this range in the computation of *contact_cost* | | `reset_noise_scale` | **float** | `0.1` | Scale of random perturbations of initial position and velocity (see section on Starting State) | @@ -234,9 +234,9 @@ def is_healthy(self): return is_healthy @property - def done(self): - done = not self.is_healthy if self._terminate_when_unhealthy else False - return done + def terminated(self): + terminated = not self.is_healthy if self._terminate_when_unhealthy else False + return terminated def step(self, action): xy_position_before = self.get_body_com("torso")[:2].copy() @@ -256,7 +256,7 @@ def step(self, action): costs = ctrl_cost + contact_cost reward = rewards - costs - done = self.done + terminated = self.terminated observation = self._get_obs() info = { "reward_forward": forward_reward, @@ -271,7 +271,7 @@ def step(self, action): "forward_reward": forward_reward, } - return observation, reward, done, info + return observation, reward, terminated, False, info def _get_obs(self): position = self.sim.data.qpos.flat.copy() diff --git a/gym/envs/mujoco/half_cheetah.py b/gym/envs/mujoco/half_cheetah.py index 53a206fb656..339f7815ba1 100644 --- a/gym/envs/mujoco/half_cheetah.py +++ b/gym/envs/mujoco/half_cheetah.py @@ -17,8 +17,14 @@ def step(self, action): reward_ctrl = -0.1 * np.square(action).sum() reward_run = (xposafter - xposbefore) / self.dt reward = reward_ctrl + reward_run - done = False - return ob, reward, done, dict(reward_run=reward_run, reward_ctrl=reward_ctrl) + terminated = False + return ( + ob, + reward, + terminated, + False, + dict(reward_run=reward_run, reward_ctrl=reward_ctrl), + ) def _get_obs(self): return np.concatenate( diff --git a/gym/envs/mujoco/half_cheetah_v3.py b/gym/envs/mujoco/half_cheetah_v3.py index 64104f867a8..13f9232e7cd 100644 --- a/gym/envs/mujoco/half_cheetah_v3.py +++ b/gym/envs/mujoco/half_cheetah_v3.py @@ -95,8 +95,8 @@ class HalfCheetahEnv(mujoco_env.MujocoEnv, utils.EzPickle): normal noise with a mean of 0 and standard deviation of `reset_noise_scale` is added to the initial velocity values of all zeros. - ### Episode Termination - The episode terminates when the episode length is greater than 1000. + ### Episode End + The episode truncates when the episode length is greater than 1000. 
### Arguments @@ -167,7 +167,7 @@ def step(self, action): observation = self._get_obs() reward = forward_reward - ctrl_cost - done = False + terminated = False info = { "x_position": x_position_after, "x_velocity": x_velocity, @@ -175,7 +175,7 @@ def step(self, action): "reward_ctrl": -ctrl_cost, } - return observation, reward, done, info + return observation, reward, terminated, False, info def _get_obs(self): position = self.sim.data.qpos.flat.copy() diff --git a/gym/envs/mujoco/hopper.py b/gym/envs/mujoco/hopper.py index ad459bda2a2..41a822fed66 100644 --- a/gym/envs/mujoco/hopper.py +++ b/gym/envs/mujoco/hopper.py @@ -18,14 +18,14 @@ def step(self, a): reward += alive_bonus reward -= 1e-3 * np.square(a).sum() s = self.state_vector() - done = not ( + terminated = not ( np.isfinite(s).all() and (np.abs(s[2:]) < 100).all() and (height > 0.7) and (abs(ang) < 0.2) ) ob = self._get_obs() - return ob, reward, done, {} + return ob, reward, terminated, False, {} def _get_obs(self): return np.concatenate( diff --git a/gym/envs/mujoco/hopper_v3.py b/gym/envs/mujoco/hopper_v3.py index 807b6976f1b..ae4bc8fc4cb 100644 --- a/gym/envs/mujoco/hopper_v3.py +++ b/gym/envs/mujoco/hopper_v3.py @@ -87,7 +87,7 @@ class HopperEnv(mujoco_env.MujocoEnv, utils.EzPickle): (0.0, 1.25, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0) with a uniform noise in the range of [-`reset_noise_scale`, `reset_noise_scale`] added to the values for stochasticity. - ### Episode Termination + ### Episode End The hopper is said to be unhealthy if any of the following happens: 1. An element of `observation[1:]` (if `exclude_current_positions_from_observation=True`, else `observation[2:]`) is no longer contained in the closed interval specified by the argument `healthy_state_range` @@ -95,12 +95,12 @@ class HopperEnv(mujoco_env.MujocoEnv, utils.EzPickle): 3. The angle (`observation[1]` if `exclude_current_positions_from_observation=True`, else `observation[2]`) is no longer contained in the closed interval specified by the argument `healthy_angle_range` If `terminate_when_unhealthy=True` is passed during construction (which is the default), - the episode terminates when any of the following happens: + the episode ends when any of the following happens: - 1. The episode duration reaches a 1000 timesteps - 2. The hopper is unhealthy + 1. Truncation: The episode duration reaches a 1000 timesteps + 2. Termination: The hopper is unhealthy - If `terminate_when_unhealthy=False` is passed, the episode is terminated only when 1000 timesteps are exceeded. + If `terminate_when_unhealthy=False` is passed, the episode is ended only when 1000 timesteps are exceeded. 
    ### Arguments
 
@@ -122,7 +122,7 @@ class HopperEnv(mujoco_env.MujocoEnv, utils.EzPickle):
     | `forward_reward_weight` | **float** | `1.0` | Weight for *forward_reward* term (see section on reward) |
     | `ctrl_cost_weight` | **float** | `0.001` | Weight for *ctrl_cost* reward (see section on reward) |
     | `healthy_reward` | **float** | `1` | Constant reward given if the ant is "healthy" after timestep |
-    | `terminate_when_unhealthy` | **bool**| `True` | If true, issue a done signal if the hopper is no longer healthy |
+    | `terminate_when_unhealthy` | **bool**| `True` | If true, issue a terminated signal if the hopper is no longer healthy |
     | `healthy_state_range` | **tuple** | `(-100, 100)` | The elements of `observation[1:]` (if `exclude_current_positions_from_observation=True`, else `observation[2:]`) must be in this range for the hopper to be considered healthy |
     | `healthy_z_range` | **tuple** | `(0.7, float("inf"))` | The z-coordinate must be in this range for the hopper to be considered healthy |
     | `healthy_angle_range` | **tuple** | `(-0.2, 0.2)` | The angle given by `observation[1]` (if `exclude_current_positions_from_observation=True`, else `observation[2]`) must be in this range for the hopper to be considered healthy |
@@ -201,9 +201,9 @@ def is_healthy(self):
         return is_healthy
 
     @property
-    def done(self):
-        done = not self.is_healthy if self._terminate_when_unhealthy else False
-        return done
+    def terminated(self):
+        terminated = not self.is_healthy if self._terminate_when_unhealthy else False
+        return terminated
 
     def _get_obs(self):
         position = self.sim.data.qpos.flat.copy()
@@ -231,13 +231,13 @@ def step(self, action):
 
         observation = self._get_obs()
         reward = rewards - costs
-        done = self.done
+        terminated = self.terminated
         info = {
             "x_position": x_position_after,
             "x_velocity": x_velocity,
         }
 
-        return observation, reward, done, info
+        return observation, reward, terminated, False, info
 
     def reset_model(self):
         noise_low = -self._reset_noise_scale
diff --git a/gym/envs/mujoco/humanoid.py b/gym/envs/mujoco/humanoid.py
index d025419937e..bb4890560f5 100644
--- a/gym/envs/mujoco/humanoid.py
+++ b/gym/envs/mujoco/humanoid.py
@@ -40,11 +40,12 @@ def step(self, a):
         quad_impact_cost = min(quad_impact_cost, 10)
         reward = lin_vel_cost - quad_ctrl_cost - quad_impact_cost + alive_bonus
         qpos = self.sim.data.qpos
-        done = bool((qpos[2] < 1.0) or (qpos[2] > 2.0))
+        terminated = bool((qpos[2] < 1.0) or (qpos[2] > 2.0))
         return (
             self._get_obs(),
             reward,
-            done,
+            terminated,
+            False,
             dict(
                 reward_linvel=lin_vel_cost,
                 reward_quadctrl=-quad_ctrl_cost,
diff --git a/gym/envs/mujoco/humanoid_v3.py b/gym/envs/mujoco/humanoid_v3.py
index a17887b0ff6..ccd84b86ba6 100644
--- a/gym/envs/mujoco/humanoid_v3.py
+++ b/gym/envs/mujoco/humanoid_v3.py
@@ -163,17 +163,17 @@ class HumanoidEnv(mujoco_env.MujocoEnv, utils.EzPickle):
     selected to be high, thereby indicating a standing up humanoid. The initial orientation is
     designed to make it face forward as well.
 
-    ### Episode Termination
+    ### Episode End
 
     The humanoid is said to be unhealthy if the z-position of the torso is no longer contained in the
     closed interval specified by the argument `healthy_z_range`.
 
     If `terminate_when_unhealthy=True` is passed during construction (which is the default),
-    the episode terminates when any of the following happens:
+    the episode ends when any of the following happens:
 
-    1. The episode duration reaches a 1000 timesteps
-    3. The humanoid is unhealthy
+    1. Truncation: The episode duration reaches a 1000 timesteps
+    2.
Termination: The humanoid is unhealthy - If `terminate_when_unhealthy=False` is passed, the episode is terminated only when 1000 timesteps are exceeded. + If `terminate_when_unhealthy=False` is passed, the episode is ended only when 1000 timesteps are exceeded. ### Arguments @@ -197,7 +197,7 @@ class HumanoidEnv(mujoco_env.MujocoEnv, utils.EzPickle): | `ctrl_cost_weight` | **float** | `0.1` | Weight for *ctrl_cost* term (see section on reward) | | `contact_cost_weight` | **float** | `5e-7` | Weight for *contact_cost* term (see section on reward) | | `healthy_reward` | **float** | `5.0` | Constant reward given if the humanoid is "healthy" after timestep | - | `terminate_when_unhealthy` | **bool**| `True` | If true, issue a done signal if the z-coordinate of the torso is no longer in the `healthy_z_range` | + | `terminate_when_unhealthy` | **bool**| `True` | If true, issue a terminated signal if the z-coordinate of the torso is no longer in the `healthy_z_range` | | `healthy_z_range` | **tuple** | `(1.0, 2.0)` | The humanoid is considered healthy if the z-coordinate of the torso is in this range | | `reset_noise_scale` | **float** | `1e-2` | Scale of random perturbations of initial position and velocity (see section on Starting State) | | `exclude_current_positions_from_observation`| **bool** | `True`| Whether or not to omit the x- and y-coordinates from observations. Excluding the position can serve as an inductive bias to induce position-agnostic behavior in policies | @@ -268,9 +268,9 @@ def is_healthy(self): return is_healthy @property - def done(self): - done = (not self.is_healthy) if self._terminate_when_unhealthy else False - return done + def terminated(self): + terminated = (not self.is_healthy) if self._terminate_when_unhealthy else False + return terminated def _get_obs(self): position = self.sim.data.qpos.flat.copy() @@ -315,7 +315,7 @@ def step(self, action): observation = self._get_obs() reward = rewards - costs - done = self.done + terminated = self.terminated info = { "reward_linvel": forward_reward, "reward_quadctrl": -ctrl_cost, @@ -329,7 +329,7 @@ def step(self, action): "forward_reward": forward_reward, } - return observation, reward, done, info + return observation, reward, terminated, False, info def reset_model(self): noise_low = -self._reset_noise_scale diff --git a/gym/envs/mujoco/humanoidstandup.py b/gym/envs/mujoco/humanoidstandup.py index d1a6b47427d..91014383f7e 100644 --- a/gym/envs/mujoco/humanoidstandup.py +++ b/gym/envs/mujoco/humanoidstandup.py @@ -149,11 +149,11 @@ class HumanoidStandupEnv(mujoco_env.MujocoEnv, utils.EzPickle): to be low, thereby indicating a laying down humanoid. The initial orientation is designed to make it face forward as well. - ### Episode Termination - The episode terminates when any of the following happens: + ### Episode End + The episode ends when any of the following happens: - 1. The episode duration reaches a 1000 timesteps - 2. Any of the state space values is no longer finite + 1. Truncation: The episode duration reaches a 1000 timesteps + 2. 
Termination: Any of the state space values is no longer finite
 
     ### Arguments
 
@@ -203,11 +203,11 @@ def step(self, a):
         quad_impact_cost = min(quad_impact_cost, 10)
 
         reward = uph_cost - quad_ctrl_cost - quad_impact_cost + 1
-        done = bool(False)
 
         return (
             self._get_obs(),
             reward,
-            done,
+            False,
+            False,
             dict(
                 reward_linup=uph_cost,
                 reward_quadctrl=-quad_ctrl_cost,
diff --git a/gym/envs/mujoco/inverted_double_pendulum.py b/gym/envs/mujoco/inverted_double_pendulum.py
index f523d170e68..15f2cd9dd1d 100644
--- a/gym/envs/mujoco/inverted_double_pendulum.py
+++ b/gym/envs/mujoco/inverted_double_pendulum.py
@@ -84,12 +84,12 @@ class InvertedDoublePendulumEnv(mujoco_env.MujocoEnv, utils.EzPickle):
     of [-0.1, 0.1] added to the positional values (cart position and pole angles) and standard
     normal force with a standard deviation of 0.1 added to the velocity values for stochasticity.
 
-    ### Episode Termination
-    The episode terminates when any of the following happens:
+    ### Episode End
+    The episode ends when any of the following happens:
 
-    1. The episode duration reaches 1000 timesteps.
-    2. Any of the state space values is no longer finite.
-    3. The y_coordinate of the tip of the second pole *is less than or equal* to 1. The maximum standing height of the system is 1.196 m when all the parts are perpendicularly vertical on top of each other).
+    1. Truncation: The episode duration reaches 1000 timesteps.
+    2. Termination: Any of the state space values is no longer finite.
+    3. Termination: The y_coordinate of the tip of the second pole *is less than or equal* to 1. The maximum standing height of the system is 1.196 m when all the parts are perpendicularly vertical on top of each other.
 
     ### Arguments
 
@@ -123,8 +123,8 @@ def step(self, action):
         vel_penalty = 1e-3 * v1**2 + 5e-3 * v2**2
         alive_bonus = 10
         r = alive_bonus - dist_penalty - vel_penalty
-        done = bool(y <= 1)
-        return ob, r, done, {}
+        terminated = bool(y <= 1)
+        return ob, r, terminated, False, {}
 
     def _get_obs(self):
         return np.concatenate(
diff --git a/gym/envs/mujoco/inverted_pendulum.py b/gym/envs/mujoco/inverted_pendulum.py
index 46472cf43a9..94ce46f0a69 100644
--- a/gym/envs/mujoco/inverted_pendulum.py
+++ b/gym/envs/mujoco/inverted_pendulum.py
@@ -55,12 +55,12 @@ class InvertedPendulumEnv(mujoco_env.MujocoEnv, utils.EzPickle):
     (0.0, 0.0, 0.0, 0.0) with a uniform noise in the range
     of [-0.01, 0.01] added to the values for stochasticity.
 
-    ### Episode Termination
-    The episode terminates when any of the following happens:
+    ### Episode End
+    The episode ends when any of the following happens:
 
-    1. The episode duration reaches 1000 timesteps.
-    2. Any of the state space values is no longer finite.
-    3. The absolutely value of the vertical angle between the pole and the cart is greater than 0.2 radian.
+    1. Truncation: The episode duration reaches 1000 timesteps.
+    2. Termination: Any of the state space values is no longer finite.
+    3. Termination: The absolute value of the vertical angle between the pole and the cart is greater than 0.2 radian.
### Arguments @@ -89,9 +89,8 @@ def step(self, a): reward = 1.0 self.do_simulation(a, self.frame_skip) ob = self._get_obs() - notdone = np.isfinite(ob).all() and (np.abs(ob[1]) <= 0.2) - done = not notdone - return ob, reward, done, {} + terminated = not np.isfinite(ob).all() or (np.abs(ob[1]) > 0.2) + return ob, reward, terminated, False, {} def reset_model(self): qpos = self.init_qpos + self.np_random.uniform( diff --git a/gym/envs/mujoco/mujoco_env.py b/gym/envs/mujoco/mujoco_env.py index 1d04f2f9061..ebf3d22b83f 100644 --- a/gym/envs/mujoco/mujoco_env.py +++ b/gym/envs/mujoco/mujoco_env.py @@ -69,8 +69,8 @@ def __init__(self, model_path, frame_skip): self._set_action_space() action = self.action_space.sample() - observation, _reward, done, _info = self.step(action) - assert not done + observation, _reward, terminated, truncated, _info = self.step(action) + assert not (terminated or truncated) self._set_observation_space(observation) diff --git a/gym/envs/mujoco/pusher.py b/gym/envs/mujoco/pusher.py index 3be870e0176..787cb94ba21 100644 --- a/gym/envs/mujoco/pusher.py +++ b/gym/envs/mujoco/pusher.py @@ -99,12 +99,12 @@ class PusherEnv(mujoco_env.MujocoEnv, utils.EzPickle): The default framerate is 5 with each frame lasting for 0.01, giving rise to a *dt = 5 * 0.01 = 0.05* - ### Episode Termination + ### Episode End - The episode terminates when any of the following happens: + The episode ends when any of the following happens: - 1. The episode duration reaches a 100 timesteps. - 2. Any of the state space values is no longer finite. + 1. Truncation: The episode duration reaches a 100 timesteps. + 2. Termination: Any of the state space values is no longer finite. ### Arguments @@ -143,8 +143,13 @@ def step(self, a): self.do_simulation(a, self.frame_skip) ob = self._get_obs() - done = False - return ob, reward, done, dict(reward_dist=reward_dist, reward_ctrl=reward_ctrl) + return ( + ob, + reward, + False, + False, + dict(reward_dist=reward_dist, reward_ctrl=reward_ctrl), + ) def viewer_setup(self): self.viewer.cam.trackbodyid = -1 diff --git a/gym/envs/mujoco/reacher.py b/gym/envs/mujoco/reacher.py index 0df3a974bd3..153472abec5 100644 --- a/gym/envs/mujoco/reacher.py +++ b/gym/envs/mujoco/reacher.py @@ -88,12 +88,12 @@ class ReacherEnv(mujoco_env.MujocoEnv, utils.EzPickle): element ("fingertip" - "target") is calculated at the end once everything is set. The default setting has a framerate of 2 and a *dt = 2 * 0.01 = 0.02* - ### Episode Termination + ### Episode End - The episode terminates when any of the following happens: + The episode ends when any of the following happens: - 1. The episode duration reaches a 50 timesteps (with a new random target popping up if the reacher's fingertip reaches it before 50 timesteps) - 2. Any of the state space values is no longer finite. + 1. Truncation: The episode duration reaches a 50 timesteps (with a new random target popping up if the reacher's fingertip reaches it before 50 timesteps) + 2. Termination: Any of the state space values is no longer finite. 
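The `time_limit.py` hunk is not included in this excerpt, but per the commit message the TimeLimit wrapper now reports its cut-off through `truncated` instead of `done` plus an info key. A sketch of the expected behaviour against the 50-step limit described above (assuming `Reacher-v2` is registered with `max_episode_steps=50`; illustrative, not the wrapper's actual code):

```python
import gym

env = gym.make("Reacher-v2")
env.reset()
for t in range(60):
    obs, reward, terminated, truncated, info = env.step(env.action_space.sample())
    if truncated:
        # Reacher never terminates on its own, so the only end signal is
        # the TimeLimit truncation after 50 steps.
        assert t == 49 and not terminated
        break
```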
### Arguments @@ -128,8 +128,13 @@ def step(self, a): reward = reward_dist + reward_ctrl self.do_simulation(a, self.frame_skip) ob = self._get_obs() - done = False - return ob, reward, done, dict(reward_dist=reward_dist, reward_ctrl=reward_ctrl) + return ( + ob, + reward, + False, + False, + dict(reward_dist=reward_dist, reward_ctrl=reward_ctrl), + ) def viewer_setup(self): self.viewer.cam.trackbodyid = 0 diff --git a/gym/envs/mujoco/swimmer.py b/gym/envs/mujoco/swimmer.py index 429852f79a8..b7768621966 100644 --- a/gym/envs/mujoco/swimmer.py +++ b/gym/envs/mujoco/swimmer.py @@ -18,7 +18,13 @@ def step(self, a): reward_ctrl = -ctrl_cost_coeff * np.square(a).sum() reward = reward_fwd + reward_ctrl ob = self._get_obs() - return ob, reward, False, dict(reward_fwd=reward_fwd, reward_ctrl=reward_ctrl) + return ( + ob, + reward, + False, + False, + dict(reward_fwd=reward_fwd, reward_ctrl=reward_ctrl), + ) def _get_obs(self): qpos = self.sim.data.qpos diff --git a/gym/envs/mujoco/swimmer_v3.py b/gym/envs/mujoco/swimmer_v3.py index db07f238f28..6946b52e7d0 100644 --- a/gym/envs/mujoco/swimmer_v3.py +++ b/gym/envs/mujoco/swimmer_v3.py @@ -88,8 +88,8 @@ class SwimmerEnv(mujoco_env.MujocoEnv, utils.EzPickle): ### Starting State All observations start in state (0,0,0,0,0,0,0,0) with a Uniform noise in the range of [-`reset_noise_scale`, `reset_noise_scale`] is added to the initial state for stochasticity. - ### Episode Termination - The episode terminates when the episode length is greater than 1000. + ### Episode End + The episode ends when the episode length is greater than 1000. ### Arguments @@ -161,7 +161,6 @@ def step(self, action): observation = self._get_obs() reward = forward_reward - ctrl_cost - done = False info = { "reward_fwd": forward_reward, "reward_ctrl": -ctrl_cost, @@ -173,7 +172,7 @@ def step(self, action): "forward_reward": forward_reward, } - return observation, reward, done, info + return observation, reward, False, False, info def _get_obs(self): position = self.sim.data.qpos.flat.copy() diff --git a/gym/envs/mujoco/walker2d.py b/gym/envs/mujoco/walker2d.py index 915ff45f7f1..e7fdb0a56f6 100644 --- a/gym/envs/mujoco/walker2d.py +++ b/gym/envs/mujoco/walker2d.py @@ -17,9 +17,9 @@ def step(self, a): reward = (posafter - posbefore) / self.dt reward += alive_bonus reward -= 1e-3 * np.square(a).sum() - done = not (height > 0.8 and height < 2.0 and ang > -1.0 and ang < 1.0) + terminated = not (height > 0.8 and height < 2.0 and ang > -1.0 and ang < 1.0) ob = self._get_obs() - return ob, reward, done, {} + return ob, reward, terminated, False, {} def _get_obs(self): qpos = self.sim.data.qpos diff --git a/gym/envs/mujoco/walker2d_v3.py b/gym/envs/mujoco/walker2d_v3.py index 2c091810f28..92bf4f564d7 100644 --- a/gym/envs/mujoco/walker2d_v3.py +++ b/gym/envs/mujoco/walker2d_v3.py @@ -92,7 +92,7 @@ class Walker2dEnv(mujoco_env.MujocoEnv, utils.EzPickle): (0.0, 1.25, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0) with a uniform noise in the range of [-`reset_noise_scale`, `reset_noise_scale`] added to the values for stochasticity. - ### Episode Termination + ### Episode End The walker is said to be unhealthy if any of the following happens: 1. Any of the state space values is no longer finite @@ -100,12 +100,12 @@ class Walker2dEnv(mujoco_env.MujocoEnv, utils.EzPickle): 3. 
The absolute value of the angle (`observation[1]` if `exclude_current_positions_from_observation=False`, else `observation[2]`) is ***not*** in the closed interval specified by `healthy_angle_range` If `terminate_when_unhealthy=True` is passed during construction (which is the default), - the episode terminates when any of the following happens: + the episode ends when any of the following happens: - 1. The episode duration reaches a 1000 timesteps - 2. The walker is unhealthy + 1. Truncation: The episode duration reaches a 1000 timesteps + 2. Termination: The walker is unhealthy - If `terminate_when_unhealthy=False` is passed, the episode is terminated only when 1000 timesteps are exceeded. + If `terminate_when_unhealthy=False` is passed, the episode is ended only when 1000 timesteps are exceeded. ### Arguments diff --git a/gym/envs/registration.py b/gym/envs/registration.py index 70f0b9345e5..ff2b30adf95 100644 --- a/gym/envs/registration.py +++ b/gym/envs/registration.py @@ -91,6 +91,7 @@ class EnvSpec: max_episode_steps: The maximum number of steps that an episode can consist of order_enforce: Whether to wrap the environment in an orderEnforcing wrapper autoreset: Whether the environment should automatically reset when it reaches the done state + return_two_dones: Whether the environment step method returns two bools or one bool as episode ending signal kwargs: The kwargs to pass to the environment class """ @@ -102,6 +103,7 @@ class EnvSpec: max_episode_steps: Optional[int] = field(default=None) order_enforce: bool = field(default=True) autoreset: bool = field(default=False) + return_two_dones: bool = field(default=True) kwargs: dict = field(default_factory=dict) namespace: Optional[str] = field(init=False) name: str = field(init=False) @@ -135,8 +137,16 @@ def make(self, **kwargs) -> Env: _kwargs.update(kwargs) if "autoreset" in _kwargs: - self.autoreset = _kwargs["autoreset"] + autoreset = _kwargs["autoreset"] del _kwargs["autoreset"] + else: + autoreset = self.autoreset + + if "return_two_dones" in _kwargs: + return_two_dones = _kwargs["return_two_dones"] + del _kwargs["return_two_dones"] + else: + return_two_dones = self.return_two_dones if callable(self.entry_point): env = self.entry_point(**_kwargs) @@ -149,6 +159,10 @@ def make(self, **kwargs) -> Env: spec.kwargs = _kwargs env.unwrapped.spec = spec + from gym.wrappers import StepCompatibility + + env = StepCompatibility(env, return_two_dones) + if self.order_enforce: from gym.wrappers.order_enforcing import OrderEnforcing @@ -160,7 +174,7 @@ def make(self, **kwargs) -> Env: env = TimeLimit(env, max_episode_steps=env.spec.max_episode_steps) - if self.autoreset: + if autoreset: from gym.wrappers.autoreset import AutoResetWrapper env = AutoResetWrapper(env) diff --git a/gym/envs/toy_text/blackjack.py b/gym/envs/toy_text/blackjack.py index fa384961579..8b0143bdba6 100644 --- a/gym/envs/toy_text/blackjack.py +++ b/gym/envs/toy_text/blackjack.py @@ -130,13 +130,13 @@ def step(self, action): if action: # hit: add a card to players hand and return self.player.append(draw_card(self.np_random)) if is_bust(self.player): - done = True + terminated = True reward = -1.0 else: - done = False + terminated = False reward = 0.0 else: # stick: play out the dealers hand, and score - done = True + terminated = True while sum_hand(self.dealer) < 17: self.dealer.append(draw_card(self.np_random)) reward = cmp(score(self.player), score(self.dealer)) @@ -151,7 +151,7 @@ def step(self, action): ): # Natural gives extra points, but doesn't autowin. 
Legacy implementation reward = 1.5 - return self._get_obs(), reward, done, {} + return self._get_obs(), reward, terminated, False, {} def _get_obs(self): return (sum_hand(self.player), self.dealer[0], usable_ace(self.player)) diff --git a/gym/envs/toy_text/cliffwalking.py b/gym/envs/toy_text/cliffwalking.py index 27188b3b00f..15c4f33c55d 100644 --- a/gym/envs/toy_text/cliffwalking.py +++ b/gym/envs/toy_text/cliffwalking.py @@ -106,7 +106,7 @@ def _calculate_transition_prob(self, current, delta): Determine the outcome for an action. Transition Prob is always 1.0. :param current: Current position on the grid as (row, col) :param delta: Change in position for transition - :return: (1.0, new_state, reward, done) + :return: (1.0, new_state, reward, terminated) """ new_position = np.array(current) + np.array(delta) new_position = self._limit_coordinates(new_position).astype(int) @@ -115,16 +115,16 @@ def _calculate_transition_prob(self, current, delta): return [(1.0, self.start_state_index, -100, False)] terminal_state = (self.shape[0] - 1, self.shape[1] - 1) - is_done = tuple(new_position) == terminal_state - return [(1.0, new_state, -1, is_done)] + is_terminated = tuple(new_position) == terminal_state + return [(1.0, new_state, -1, is_terminated)] def step(self, a): transitions = self.P[self.s][a] i = categorical_sample([t[0] for t in transitions], self.np_random) - p, s, r, d = transitions[i] + p, s, r, t = transitions[i] self.s = s self.lastaction = a - return (int(s), r, d, {"prob": p}) + return (int(s), r, t, False, {"prob": p}) def reset( self, diff --git a/gym/envs/toy_text/frozen_lake.py b/gym/envs/toy_text/frozen_lake.py index 07a29a8faec..33ea0f8f618 100644 --- a/gym/envs/toy_text/frozen_lake.py +++ b/gym/envs/toy_text/frozen_lake.py @@ -174,9 +174,9 @@ def update_probability_matrix(row, col, action): newrow, newcol = inc(row, col, action) newstate = to_s(newrow, newcol) newletter = desc[newrow, newcol] - done = bytes(newletter) in b"GH" + terminated = bytes(newletter) in b"GH" reward = float(newletter == b"G") - return newstate, reward, done + return newstate, reward, terminated for row in range(nrow): for col in range(ncol): @@ -212,10 +212,10 @@ def update_probability_matrix(row, col, action): def step(self, a): transitions = self.P[self.s][a] i = categorical_sample([t[0] for t in transitions], self.np_random) - p, s, r, d = transitions[i] + p, s, r, t = transitions[i] self.s = s self.lastaction = a - return (int(s), r, d, {"prob": p}) + return (int(s), r, t, False, {"prob": p}) def reset( self, diff --git a/gym/envs/toy_text/taxi.py b/gym/envs/toy_text/taxi.py index bbe6fbf0946..1074962a535 100644 --- a/gym/envs/toy_text/taxi.py +++ b/gym/envs/toy_text/taxi.py @@ -132,7 +132,7 @@ def __init__(self): reward = ( -1 ) # default reward when there is no pickup/dropoff - done = False + terminated = False taxi_loc = (row, col) if action == 0: @@ -151,7 +151,7 @@ def __init__(self): elif action == 5: # dropoff if (taxi_loc == locs[dest_idx]) and pass_idx == 4: new_pass_idx = dest_idx - done = True + terminated = True reward = 20 elif (taxi_loc in locs) and pass_idx == 4: new_pass_idx = locs.index(taxi_loc) @@ -160,7 +160,9 @@ def __init__(self): new_state = self.encode( new_row, new_col, new_pass_idx, dest_idx ) - self.P[state][action].append((1.0, new_state, reward, done)) + self.P[state][action].append( + (1.0, new_state, reward, terminated) + ) self.initial_state_distrib /= self.initial_state_distrib.sum() self.action_space = spaces.Discrete(num_actions) self.observation_space = 
spaces.Discrete(num_states) @@ -206,10 +208,10 @@ def decode(self, i): def step(self, a): transitions = self.P[self.s][a] i = categorical_sample([t[0] for t in transitions], self.np_random) - p, s, r, d = transitions[i] + p, s, r, t = transitions[i] self.s = s self.lastaction = a - return (int(s), r, d, {"prob": p}) + return (int(s), r, t, False, {"prob": p}) def reset( self, diff --git a/gym/utils/env_checker.py b/gym/utils/env_checker.py index 625e3b75dad..2128d39150d 100644 --- a/gym/utils/env_checker.py +++ b/gym/utils/env_checker.py @@ -53,7 +53,7 @@ def _check_nan(env: gym.Env, check_inf: bool = True) -> None: """Check for NaN and Inf.""" for _ in range(10): action = env.action_space.sample() - observation, reward, _, _ = env.step(action) + observation, reward, _, _, _ = env.step(action) if np.any(np.isnan(observation)): logger.warn("Encountered NaN value in observations.") @@ -191,11 +191,14 @@ def _check_returned_values( data = env.step(action) assert ( - len(data) == 4 - ), "The `step()` method must return four values: obs, reward, done, info" + len(data) == 5 or len(data) == 4 + ), "The `step()` method must return either four values: obs, reward, done, info, or five values: obs, reward, terminated, truncated, info" # Unpack - obs, reward, done, info = data + if len(data) == 4: + obs, reward, done, info = data + else: + obs, reward, terminated, truncated, info = data if isinstance(observation_space, spaces.Dict): assert isinstance( @@ -214,7 +217,11 @@ def _check_returned_values( assert isinstance( reward, (float, int, np.float32) ), "The reward returned by `step()` must be a float" - assert isinstance(done, bool), "The `done` signal must be a boolean" + if len(data) == 4: + assert isinstance(done, bool), "The `done` signal must be a boolean" + else: + assert isinstance(terminated, bool), "The `terminated` signal must be a boolean" + assert isinstance(truncated, bool), "The `truncated` signal must be a boolean" assert isinstance( info, dict ), "The `info` returned by `step()` must be a python dictionary" diff --git a/gym/utils/play.py b/gym/utils/play.py index 74667820581..edfe2edd22b 100644 --- a/gym/utils/play.py +++ b/gym/utils/play.py @@ -41,7 +41,7 @@ def play(env, transpose=True, fps=30, zoom=None, callback=None, keys_to_action=N gym.utils.play.PlayPlot. Here's a sample code for plotting the reward for last 5 second of gameplay. - def callback(obs_t, obs_tp1, action, rew, done, info): + def callback(obs_t, obs_tp1, action, rew, terminated, truncated, info): return [rew,] plotter = PlayPlot(callback, 30 * 5, ["reward"]) @@ -68,7 +68,8 @@ def callback(obs_t, obs_tp1, action, rew, done, info): obs_tp1: observation after performing action action: action that was executed rew: reward that was received - done: whether the environment is done or not + terminated: whether the environment is terminated or not + truncated: whether the environment is truncated or not info: debug info keys_to_action: dict: tuple(int) -> int or None Mapping from keys pressed to action performed. 
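A sketch of the updated seven-argument callback wired into `PlayPlot`, following the docstring above (the Atari id is illustrative and requires the Atari extras):

```python
import gym
from gym.utils.play import PlayPlot, play

def callback(obs_t, obs_tp1, action, rew, terminated, truncated, info):
    # The callback now receives both end-of-episode flags instead of `done`.
    return [rew]

plotter = PlayPlot(callback, 30 * 5, ["reward"])
env = gym.make("Pong-v4")
play(env, callback=plotter.callback)
```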
@@ -116,9 +117,11 @@ def callback(obs_t, obs_tp1, action, rew, done, info): else: action = keys_to_action.get(tuple(sorted(pressed_keys)), 0) prev_obs = obs - obs, rew, env_done, info = env.step(action) + obs, rew, env_terminated, env_truncated, info = env.step(action) if callback is not None: - callback(prev_obs, obs, action, rew, env_done, info) + callback( + prev_obs, obs, action, rew, env_terminated, env_truncated, info + ) if obs is not None: rendered = env.render(mode="rgb_array") display_arr(screen, rendered, transpose=transpose, video_size=video_size) @@ -164,8 +167,10 @@ def __init__(self, callback, horizon_timesteps, plot_names): self.cur_plot = [None for _ in range(num_plots)] self.data = [deque(maxlen=horizon_timesteps) for _ in range(num_plots)] - def callback(self, obs_t, obs_tp1, action, rew, done, info): - points = self.data_callback(obs_t, obs_tp1, action, rew, done, info) + def callback(self, obs_t, obs_tp1, action, rew, terminated, truncated, info): + points = self.data_callback( + obs_t, obs_tp1, action, rew, terminated, truncated, info + ) for point, data_series in zip(points, self.data): data_series.append(point) self.t += 1 diff --git a/gym/vector/__init__.py b/gym/vector/__init__.py index 4f3b94814fd..ed9fea0a887 100644 --- a/gym/vector/__init__.py +++ b/gym/vector/__init__.py @@ -4,13 +4,16 @@ Iterable = (tuple, list) from gym.vector.async_vector_env import AsyncVectorEnv +from gym.vector.step_compatibility_vector import StepCompatibilityVector from gym.vector.sync_vector_env import SyncVectorEnv from gym.vector.vector_env import VectorEnv, VectorEnvWrapper __all__ = ["AsyncVectorEnv", "SyncVectorEnv", "VectorEnv", "VectorEnvWrapper", "make"] -def make(id, num_envs=1, asynchronous=True, wrappers=None, **kwargs): +def make( + id, num_envs=1, asynchronous=True, wrappers=None, return_two_dones=True, **kwargs +): """Create a vectorized environment from multiple copies of an environment, from its id. @@ -62,4 +65,8 @@ def _make_env(): return env env_fns = [_make_env for _ in range(num_envs)] - return AsyncVectorEnv(env_fns) if asynchronous else SyncVectorEnv(env_fns) + return ( + StepCompatibilityVector(AsyncVectorEnv(env_fns), return_two_dones) + if asynchronous + else StepCompatibilityVector(SyncVectorEnv(env_fns), return_two_dones) + ) diff --git a/gym/vector/async_vector_env.py b/gym/vector/async_vector_env.py index e9c1c99f4aa..cdcab830223 100644 --- a/gym/vector/async_vector_env.py +++ b/gym/vector/async_vector_env.py @@ -75,7 +75,7 @@ class AsyncVectorEnv(VectorEnv): worker : callable, optional If set, then use that worker in a subprocess instead of a default one. Can be useful to override some inner vector env logic, for instance, - how resets on done are handled. + how resets on termination or truncation are handled. Warning ------- @@ -376,8 +376,11 @@ def step_wait(self, timeout=None): rewards : :obj:`np.ndarray`, dtype :obj:`np.float_` A vector of rewards from the vectorized environment. - dones : :obj:`np.ndarray`, dtype :obj:`np.bool_` - A vector whose entries indicate whether the episode has ended. + terminateds : :obj:`np.ndarray`, dtype :obj:`np.bool_` + A vector whose entries indicate whether the episode has ended due to termination. + + truncateds: :obj:`np.ndarray`, dtype :obj:`np.bool_` + A vector whose entries indicate whether the episode has ended due to truncation. infos : list of dict A list of auxiliary diagnostic information dicts from sub-environments. 
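A sketch of consuming the split vectors documented above from a vectorized environment (id and sizes illustrative):

```python
import gym

envs = gym.vector.make("CartPole-v1", num_envs=4)  # new API by default
observations = envs.reset()
observations, rewards, terminateds, truncateds, infos = envs.step(
    envs.action_space.sample()
)
# One flag per sub-environment; an ended sub-env has already been auto-reset,
# with its last observation stashed under info["closing_observation"].
for i, (terminated, truncated) in enumerate(zip(terminateds, truncateds)):
    if terminated or truncated:
        last_obs = infos[i]["closing_observation"]
```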
@@ -410,7 +413,7 @@ def step_wait(self, timeout=None): results, successes = zip(*[pipe.recv() for pipe in self.parent_pipes]) self._raise_if_errors(successes) self._state = AsyncState.DEFAULT - observations_list, rewards, dones, infos = zip(*results) + observations_list, rewards, terminateds, truncateds, infos = zip(*results) if not self.shared_memory: self.observations = concatenate( @@ -422,7 +425,8 @@ def step_wait(self, timeout=None): return ( deepcopy(self.observations) if self.copy else self.observations, np.array(rewards), - np.array(dones, dtype=np.bool_), + np.array(terminateds, dtype=np.bool_), + np.array(truncateds, dtype=np.bool_), infos, ) @@ -647,11 +651,11 @@ def _worker(index, env_fn, pipe, parent_pipe, shared_memory, error_queue): pipe.send((observation, True)) elif command == "step": - observation, reward, done, info = env.step(data) - if done: - info["terminal_observation"] = observation + observation, reward, terminated, truncated, info = env.step(data) + if terminated or truncated: + info["closing_observation"] = observation observation = env.reset() - pipe.send(((observation, reward, done, info), True)) + pipe.send(((observation, reward, terminated, truncated, info), True)) elif command == "seed": env.seed(data) pipe.send((None, True)) @@ -716,14 +720,14 @@ def _worker_shared_memory(index, env_fn, pipe, parent_pipe, shared_memory, error ) pipe.send((None, True)) elif command == "step": - observation, reward, done, info = env.step(data) - if done: - info["terminal_observation"] = observation + observation, reward, terminated, truncated, info = env.step(data) + if terminated or truncated: + info["closing_observation"] = observation observation = env.reset() write_to_shared_memory( observation_space, index, observation, shared_memory ) - pipe.send(((None, reward, done, info), True)) + pipe.send(((None, reward, terminated, truncated, info), True)) elif command == "seed": env.seed(data) pipe.send((None, True)) diff --git a/gym/vector/step_compatibility_vector.py b/gym/vector/step_compatibility_vector.py new file mode 100644 index 00000000000..d0e23b60431 --- /dev/null +++ b/gym/vector/step_compatibility_vector.py @@ -0,0 +1,36 @@ +import numpy as np + +import gym +from gym import logger +from gym.vector.vector_env import VectorEnvWrapper + + +class StepCompatibilityVector(VectorEnvWrapper): + def __init__(self, env, return_two_dones=True): + super().__init__(env) + self._return_two_dones = return_two_dones + + def step_wait(self): + step_returns = self.env.step_wait() + if self._return_two_dones: + return step_returns + else: + return self._step_returns_new_to_old(step_returns) + + def _step_returns_new_to_old(self, step_returns): + assert len(step_returns) == 5 + observations, rewards, terminateds, truncateds, infos = step_returns + logger.warn( + "Using a vector wrapper to transform new step API (which returns two bool vectors terminateds, truncateds) into old (returns one bool vector dones). " + "This wrapper will be removed in the future. " + "It is recommended to upgrade your accompanying code instead to be compatible with the new API, and use the new API. 
" + ) + dones = [] + for i in range(len(terminateds)): + dones.append(terminateds[i] or truncateds[i]) + if truncateds[i]: + infos[i]["TimeLimit.truncated"] = not terminateds[i] + return observations, rewards, np.array(dones, dtype=np.bool_), infos + + def __del__(self): + self.env.__del__() diff --git a/gym/vector/sync_vector_env.py b/gym/vector/sync_vector_env.py index 97913499bfd..4e08a845d30 100644 --- a/gym/vector/sync_vector_env.py +++ b/gym/vector/sync_vector_env.py @@ -72,7 +72,8 @@ def __init__(self, env_fns, observation_space=None, action_space=None, copy=True self.single_observation_space, n=self.num_envs, fn=np.zeros ) self._rewards = np.zeros((self.num_envs,), dtype=np.float64) - self._dones = np.zeros((self.num_envs,), dtype=np.bool_) + self._terminateds = np.zeros((self.num_envs,), dtype=np.bool_) + self._truncateds = np.zeros((self.num_envs,), dtype=np.bool_) self._actions = None def seed(self, seed=None): @@ -98,7 +99,8 @@ def reset_wait( seed = [seed + i for i in range(self.num_envs)] assert len(seed) == self.num_envs - self._dones[:] = False + self._terminateds[:] = False + self._truncateds[:] = False observations = [] data_list = [] for env, single_seed in zip(self.envs, seed): @@ -135,9 +137,15 @@ def step_async(self, actions): def step_wait(self): observations, infos = [], [] for i, (env, action) in enumerate(zip(self.envs, self._actions)): - observation, self._rewards[i], self._dones[i], info = env.step(action) - if self._dones[i]: - info["terminal_observation"] = observation + ( + observation, + self._rewards[i], + self._terminateds[i], + self._truncateds[i], + info, + ) = env.step(action) + if self._terminateds[i] or self._truncateds[i]: + info["closing_observation"] = observation observation = env.reset() observations.append(observation) infos.append(info) @@ -148,7 +156,8 @@ def step_wait(self): return ( deepcopy(self.observations) if self.copy else self.observations, np.copy(self._rewards), - np.copy(self._dones), + np.copy(self._terminateds), + np.copy(self._truncateds), infos, ) diff --git a/gym/vector/vector_env.py b/gym/vector/vector_env.py index b6f497919ca..43254cc0c2f 100644 --- a/gym/vector/vector_env.py +++ b/gym/vector/vector_env.py @@ -101,8 +101,11 @@ def step(self, actions): rewards : :obj:`np.ndarray`, dtype :obj:`np.float_` A vector of rewards from the vectorized environment. - dones : :obj:`np.ndarray`, dtype :obj:`np.bool_` - A vector whose entries indicate whether the episode has ended. + terminateds : :obj:`np.ndarray`, dtype :obj:`np.bool_` + A vector whose entries indicate whether the episode has terminated. + + truncateds : :obj:`np.ndarray`, dtype :obj:`np.bool_` + A vector whose entries indicate whether the episode has truncated. infos : list of dict A list of auxiliary diagnostic information dicts from sub-environments. 
diff --git a/gym/wrappers/__init__.py b/gym/wrappers/__init__.py index caadb4074f7..8a405c62bcd 100644 --- a/gym/wrappers/__init__.py +++ b/gym/wrappers/__init__.py @@ -12,6 +12,7 @@ from gym.wrappers.record_video import RecordVideo, capped_cubic_video_schedule from gym.wrappers.rescale_action import RescaleAction from gym.wrappers.resize_observation import ResizeObservation +from gym.wrappers.step_compatibility import StepCompatibility from gym.wrappers.time_aware_observation import TimeAwareObservation from gym.wrappers.time_limit import TimeLimit from gym.wrappers.transform_observation import TransformObservation diff --git a/gym/wrappers/atari_preprocessing.py b/gym/wrappers/atari_preprocessing.py index 33a70c3ae58..1caa178d3c8 100644 --- a/gym/wrappers/atari_preprocessing.py +++ b/gym/wrappers/atari_preprocessing.py @@ -31,7 +31,7 @@ class AtariPreprocessing(gym.Wrapper): noop_max (int): max number of no-ops frame_skip (int): the frequency at which the agent experiences the game. screen_size (int): resize Atari frame - terminal_on_life_loss (bool): if True, then step() returns done=True whenever a + terminal_on_life_loss (bool): if True, then step() returns terminated=True whenever a life is lost. grayscale_obs (bool): if True, then gray scale observation is returned, otherwise, RGB observation is returned. @@ -108,16 +108,16 @@ def step(self, action): R = 0.0 for t in range(self.frame_skip): - _, reward, done, info = self.env.step(action) + _, reward, terminated, truncated, info = self.env.step(action) R += reward - self.game_over = done + self.game_over = terminated if self.terminal_on_life_loss: new_lives = self.ale.lives() - done = done or new_lives < self.lives + terminated = terminated or new_lives < self.lives self.lives = new_lives - if done: + if terminated or truncated: break if t == self.frame_skip - 2: if self.grayscale_obs: @@ -129,7 +129,7 @@ def step(self, action): self.ale.getScreenGrayscale(self.obs_buffer[0]) else: self.ale.getScreenRGB(self.obs_buffer[0]) - return self._get_obs(), R, done, info + return self._get_obs(), R, terminated, truncated, info def reset(self, **kwargs): # NoopReset @@ -145,9 +145,9 @@ def reset(self, **kwargs): else 0 ) for _ in range(noops): - _, _, done, step_info = self.env.step(0) + _, _, terminated, truncated, step_info = self.env.step(0) reset_info.update(step_info) - if done: + if terminated or truncated: if kwargs.get("return_info", False): _, reset_info = self.env.reset(**kwargs) else: diff --git a/gym/wrappers/autoreset.py b/gym/wrappers/autoreset.py index 408d8b243e3..4873872cbe0 100644 --- a/gym/wrappers/autoreset.py +++ b/gym/wrappers/autoreset.py @@ -6,55 +6,55 @@ class AutoResetWrapper(gym.Wrapper): A class for providing an automatic reset functionality for gym environments when calling self.step(). 
-    When calling step causes self.env.step() to return done,
+    When calling step causes self.env.step() to return terminated=True or truncated=True,
     self.env.reset() is called,
     and the return format of self.step() is as follows:

-    new_obs, terminal_reward, terminal_done, info
+    new_obs, closing_reward, closing_terminated, closing_truncated, closing_info

     new_obs is the first observation after calling self.env.reset(),

-    terminal_reward is the reward after calling self.env.step(),
+    closing_reward is the reward after calling self.env.step(),
     prior to calling self.env.reset()

-    terminal_done is always True
+    (closing_terminated or closing_truncated) is True

     info is a dict containing all the keys from the info dict returned by
-    the call to self.env.reset(), with an additional key "terminal_observation"
+    the call to self.env.reset(), with an additional key "closing_observation"
     containing the observation returned by the last call to self.env.step()
-    and "terminal_info" containing the info dict returned by the last call
+    and "closing_info" containing the info dict returned by the last call
     to self.env.step().

-    If done is not true when self.env.step() is called, self.step() returns
-    obs, reward, done, and info as normal.
+    If (terminated or truncated) is not true when self.env.step() is called, self.step() returns
+    obs, reward, terminated, truncated, and info as normal.

     Warning: When using this wrapper to collect rollouts, note
-    that the when self.env.step() returns done, a
+    that when self.env.step() returns terminated=True or truncated=True, a
     new observation from after calling self.env.reset() is returned
-    by self.step() alongside the terminal reward and done state from the
-    previous episode . If you need the terminal state from the previous
-    episode, you need to retrieve it via the the "terminal_observation" key
+    by self.step() alongside the closing reward and termination state from the
+    previous episode. If you need the closing state from the previous
+    episode, you need to retrieve it via the "closing_observation" key
     in the info dict.
     Make sure you know what you're doing if you use this wrapper!
""" def step(self, action): - obs, reward, done, info = self.env.step(action) + obs, reward, terminated, truncated, info = self.env.step(action) - if done: + if terminated or truncated: new_obs, new_info = self.env.reset(return_info=True) assert ( - "terminal_observation" not in new_info - ), 'info dict cannot contain key "terminal_observation" ' + "closing_observation" not in new_info + ), 'info dict cannot contain key "closing_observation" ' assert ( - "terminal_info" not in new_info - ), 'info dict cannot contain key "terminal_info" ' + "closing_info" not in new_info + ), 'info dict cannot contain key "closing_info" ' - new_info["terminal_observation"] = obs - new_info["terminal_info"] = info + new_info["closing_observation"] = obs + new_info["closing_info"] = info obs = new_obs info = new_info - return obs, reward, done, info + return obs, reward, terminated, truncated, info diff --git a/gym/wrappers/frame_stack.py b/gym/wrappers/frame_stack.py index 0af589bdff8..f67387be29f 100644 --- a/gym/wrappers/frame_stack.py +++ b/gym/wrappers/frame_stack.py @@ -115,9 +115,9 @@ def observation(self): return LazyFrames(list(self.frames), self.lz4_compress) def step(self, action): - observation, reward, done, info = self.env.step(action) + observation, reward, terminated, truncated, info = self.env.step(action) self.frames.append(observation) - return self.observation(), reward, done, info + return self.observation(), reward, terminated, truncated, info def reset(self, **kwargs): if kwargs.get("return_info", False): diff --git a/gym/wrappers/normalize.py b/gym/wrappers/normalize.py index d12a87a2ad7..cc8585c2e05 100644 --- a/gym/wrappers/normalize.py +++ b/gym/wrappers/normalize.py @@ -57,12 +57,12 @@ def __init__( self.epsilon = epsilon def step(self, action): - obs, rews, dones, infos = self.env.step(action) + obs, rews, terminateds, truncateds, infos = self.env.step(action) if self.is_vector_env: obs = self.normalize(obs) else: obs = self.normalize(np.array([obs]))[0] - return obs, rews, dones, infos + return obs, rews, terminateds, truncateds, infos def reset(self, **kwargs): return_info = kwargs.get("return_info", False) @@ -100,15 +100,19 @@ def __init__( self.epsilon = epsilon def step(self, action): - obs, rews, dones, infos = self.env.step(action) + obs, rews, terminateds, truncateds, infos = self.env.step(action) if not self.is_vector_env: rews = np.array([rews]) self.returns = self.returns * self.gamma + rews rews = self.normalize(rews) + if not self.is_vector_env: # TODO: Check this + dones = terminateds or truncateds + else: + dones = np.bitwise_or(terminateds, truncateds) self.returns[dones] = 0.0 if not self.is_vector_env: rews = rews[0] - return obs, rews, dones, infos + return obs, rews, terminateds, truncateds, infos def normalize(self, rews): self.return_rms.update(self.returns) diff --git a/gym/wrappers/order_enforcing.py b/gym/wrappers/order_enforcing.py index f6f33fa5c0e..5c60b007ce6 100644 --- a/gym/wrappers/order_enforcing.py +++ b/gym/wrappers/order_enforcing.py @@ -10,8 +10,7 @@ def __init__(self, env): def step(self, action): assert self._has_reset, "Cannot call env.step() before calling reset()" - observation, reward, done, info = self.env.step(action) - return observation, reward, done, info + return self.env.step(action) def reset(self, **kwargs): self._has_reset = True diff --git a/gym/wrappers/record_episode_statistics.py b/gym/wrappers/record_episode_statistics.py index ab5f51192ef..0ec4a541851 100644 --- a/gym/wrappers/record_episode_statistics.py +++ 
b/gym/wrappers/record_episode_statistics.py @@ -26,16 +26,17 @@ def reset(self, **kwargs): return observations def step(self, action): - observations, rewards, dones, infos = super().step(action) + observations, rewards, terminateds, truncateds, infos = super().step(action) self.episode_returns += rewards self.episode_lengths += 1 if not self.is_vector_env: infos = [infos] - dones = [dones] + terminateds = [terminateds] + truncateds = [truncateds] else: infos = list(infos) # Convert infos to mutable type - for i in range(len(dones)): - if dones[i]: + for i in range(len(terminateds)): + if terminateds[i] or truncateds[i]: infos[i] = infos[i].copy() episode_return = self.episode_returns[i] episode_length = self.episode_lengths[i] @@ -55,6 +56,7 @@ def step(self, action): return ( observations, rewards, - dones if self.is_vector_env else dones[0], + terminateds if self.is_vector_env else terminateds[0], + truncateds if self.is_vector_env else truncateds[0], infos if self.is_vector_env else infos[0], ) diff --git a/gym/wrappers/record_video.py b/gym/wrappers/record_video.py index dd2efa295cc..0c21229ea51 100644 --- a/gym/wrappers/record_video.py +++ b/gym/wrappers/record_video.py @@ -1,6 +1,8 @@ import os from typing import Callable, Optional +import numpy as np + import gym from gym import logger from gym.wrappers.monitoring import video_recorder @@ -83,14 +85,14 @@ def _video_enabled(self): return self.episode_trigger(self.episode_id) def step(self, action): - observations, rewards, dones, infos = super().step(action) + observations, rewards, terminateds, truncateds, infos = super().step(action) # increment steps and episodes self.step_id += 1 if not self.is_vector_env: - if dones: + if terminateds or truncateds: self.episode_id += 1 - elif dones[0]: + elif terminateds[0] or truncateds[0]: self.episode_id += 1 if self.recording: @@ -101,15 +103,15 @@ def step(self, action): self.close_video_recorder() else: if not self.is_vector_env: - if dones: + if terminateds or truncateds: self.close_video_recorder() - elif dones[0]: + elif terminateds[0] or truncateds[0]: self.close_video_recorder() elif self._video_enabled(): self.start_video_recorder() - return observations, rewards, dones, infos + return observations, rewards, terminateds, truncateds, infos def close_video_recorder(self) -> None: if self.recording: diff --git a/gym/wrappers/step_compatibility.py b/gym/wrappers/step_compatibility.py new file mode 100644 index 00000000000..09a30251d11 --- /dev/null +++ b/gym/wrappers/step_compatibility.py @@ -0,0 +1,69 @@ +import gym +from gym import logger + + +class StepCompatibility(gym.Wrapper): + def __init__(self, env, return_two_dones=True): + super().__init__(env) + self._return_two_dones = return_two_dones + + def step(self, action): + step_returns = self.env.step(action) + if self._return_two_dones: + if len(step_returns) == 5: + logger.warn( + "Using an environment with new step API that returns two bools terminated, truncated instead of one bool done. " + "Take care to update supporting code to be compatible with this API" + ) + return step_returns + else: + return self._step_returns_old_to_new(step_returns) + else: + if len(step_returns) == 4: + logger.warn( + "Core environment uses old step API which returns one boolean (done). 
Please upgrade to new API to return two booleans - terminated, truncated"
+            )
+            return step_returns
+        elif len(step_returns) == 5:
+            return self._step_returns_new_to_old(step_returns)
+
+    def _step_returns_old_to_new(self, step_returns):
+        assert len(step_returns) == 4
+        logger.warn(
+            "Using a wrapper to transform env with old step API into new. This wrapper will be removed in the future. "
+            "It is recommended to upgrade the core env to the new step API. "
+            "If 'TimeLimit.truncated' is set at truncation, terminated and truncated values will be accurate. "
+            "Otherwise, `terminated=done` and `truncated=False`."
+        )
+
+        obs, rew, done, info = step_returns
+        if "TimeLimit.truncated" not in info:
+            terminated = done
+            truncated = False
+        elif info["TimeLimit.truncated"]:
+            terminated = False
+            truncated = True
+        else:
+            # info["TimeLimit.truncated"] exists but is False: the core environment terminated
+            # and the maximum number of timesteps was exceeded at the same step.
+            terminated = True
+            truncated = True
+
+        return obs, rew, terminated, truncated, info
+
+    def _step_returns_new_to_old(self, step_returns):
+        assert len(step_returns) == 5
+        logger.warn(
+            "Using a wrapper to transform new step API (which returns two booleans terminated, truncated) into old (returns one boolean done). "
+            "This wrapper will be removed in the future. "
+            "It is recommended to upgrade your accompanying code instead to be compatible with the new API, and use the new API. "
+        )
+
+        obs, reward, terminated, truncated, info = step_returns
+        done = terminated or truncated
+        if truncated:
+            info[
+                "TimeLimit.truncated"
+            ] = not terminated  # to be consistent with old API
+        return obs, reward, done, info
diff --git a/gym/wrappers/time_limit.py b/gym/wrappers/time_limit.py
index 43c9f9a60ef..15a724a47e1 100644
--- a/gym/wrappers/time_limit.py
+++ b/gym/wrappers/time_limit.py
@@ -14,12 +14,20 @@ def __init__(self, env, max_episode_steps=None):
         self._elapsed_steps = None

     def step(self, action):
-        observation, reward, done, info = self.env.step(action)
-        self._elapsed_steps += 1
-        if self._elapsed_steps >= self._max_episode_steps:
-            info["TimeLimit.truncated"] = not done
-            done = True
-        return observation, reward, done, info
+        step_returns = self.env.step(action)
+        self._elapsed_steps += 1
+        if len(step_returns) == 4:
+            observation, reward, done, info = step_returns
+            if self._elapsed_steps >= self._max_episode_steps:
+                info["TimeLimit.truncated"] = not done
+                done = True
+            return observation, reward, done, info
+        else:
+            observation, reward, terminated, truncated, info = step_returns
+            if self._elapsed_steps >= self._max_episode_steps:
+                truncated = True
+                info["TimeLimit.truncated"] = truncated
+            return observation, reward, terminated, truncated, info

     def reset(self, **kwargs):
         self._elapsed_steps = 0
diff --git a/gym/wrappers/transform_reward.py b/gym/wrappers/transform_reward.py
index 8586df4ba2e..229d9022c01 100644
--- a/gym/wrappers/transform_reward.py
+++ b/gym/wrappers/transform_reward.py
@@ -10,7 +10,7 @@ class TransformReward(RewardWrapper):
         >>> env = gym.make('CartPole-v1')
         >>> env = TransformReward(env, lambda r: 0.01*r)
         >>> env.reset()
-        >>> observation, reward, done, info = env.step(env.action_space.sample())
+        >>> observation, reward, terminated, truncated, info = env.step(env.action_space.sample())
         >>> reward
         0.01
diff --git a/tests/envs/test_determinism.py b/tests/envs/test_determinism.py
index fef13aacd69..f134643abaf 100644
--- a/tests/envs/test_determinism.py
+++
b/tests/envs/test_determinism.py @@ -45,12 +45,13 @@ def test_env(spec): assert_equals(initial_observation1, initial_observation2) - for i, ((o1, r1, d1, i1), (o2, r2, d2, i2)) in enumerate( + for i, ((o1, r1, term1, trunc1, i1), (o2, r2, term2, trunc2, i2)) in enumerate( zip(step_responses1, step_responses2) ): assert_equals(o1, o2, f"[{i}] ") assert r1 == r2, f"[{i}] r1: {r1}, r2: {r2}" - assert d1 == d2, f"[{i}] d1: {d1}, d2: {d2}" + assert term1 == term2, f"[{i}] term1: {term1}, term2: {term2}" + assert trunc1 == trunc2, f"[{i}] trunc1: {trunc1}, trunc2: {trunc2}" # Go returns a Pachi game board in info, which doesn't # properly check equality. For now, we hack around this by diff --git a/tests/envs/test_envs.py b/tests/envs/test_envs.py index 23de61b98b2..fd010b5b0d5 100644 --- a/tests/envs/test_envs.py +++ b/tests/envs/test_envs.py @@ -37,12 +37,13 @@ def test_env(spec): ), f"Reset observation dtype: {ob.dtype}, expected: {ob_space.dtype}" a = act_space.sample() - observation, reward, done, _info = env.step(a) + observation, reward, terminated, truncated, _info = env.step(a) assert ob_space.contains( observation ), f"Step observation: {observation!r} not in space" assert np.isscalar(reward), f"{reward} is not a scalar for {env}" - assert isinstance(done, bool), f"Expected {done} to be a boolean" + assert isinstance(terminated, bool), f"Expected {terminated} to be a boolean" + assert isinstance(truncated, bool), f"Expected {truncated} to be a boolean" if isinstance(ob_space, Box): assert ( observation.dtype == ob_space.dtype @@ -84,8 +85,8 @@ def test_random_rollout(): assert env.observation_space.contains(ob) a = agent(ob) assert env.action_space.contains(a) - (ob, _reward, done, _info) = env.step(a) - if done: + (ob, _reward, terminated, truncated, _info) = env.step(a) + if terminated or truncated: break env.close() diff --git a/tests/envs/test_mujoco_v2_to_v3_conversion.py b/tests/envs/test_mujoco_v2_to_v3_conversion.py index 201d497667e..247c1baedb5 100644 --- a/tests/envs/test_mujoco_v2_to_v3_conversion.py +++ b/tests/envs/test_mujoco_v2_to_v3_conversion.py @@ -19,13 +19,26 @@ def verify_environments_match( for i in range(num_actions): action = old_environment.action_space.sample() - old_observation, old_reward, old_done, old_info = old_environment.step(action) - new_observation, new_reward, new_done, new_info = new_environment.step(action) + ( + old_observation, + old_reward, + old_terminated, + old_truncated, + old_info, + ) = old_environment.step(action) + ( + new_observation, + new_reward, + new_terminated, + new_truncated, + new_info, + ) = new_environment.step(action) eps = 1e-6 np.testing.assert_allclose(old_observation, new_observation, atol=eps) np.testing.assert_allclose(old_reward, new_reward, atol=eps) - np.testing.assert_allclose(old_done, new_done, atol=eps) + np.testing.assert_equal(old_terminated, new_terminated, atol=eps) + np.testing.assert_equal(old_truncated, new_truncated, atol=eps) for key in old_info: np.testing.assert_allclose(old_info[key], new_info[key], atol=eps) diff --git a/tests/test_core.py b/tests/test_core.py index 2d6b0dcd305..ef3f7d7d449 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -25,7 +25,7 @@ def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None): def step(self, action): observation = self.observation_space.sample() # Dummy observation - return (observation, 0.0, False, {}) + return (observation, 0.0, False, False, {}) class UnknownSpacesEnv(core.Env): @@ -54,11 +54,12 @@ def reset( def step(self, 
action): observation = self.observation_space.sample() # Dummy observation - return (observation, 0.0, False, {}) + return (observation, 0.0, False, False, {}) class OldStyleEnv(core.Env): - """This environment doesn't accept any arguments in reset, ideally we want to support this too (for now)""" + """This environment doesn't accept any arguments in reset, step returns one bool instead of two, + ideally we want to support this too (for now)""" def __init__(self): pass diff --git a/tests/utils/test_env_checker.py b/tests/utils/test_env_checker.py index b50ec4c39e1..a846fdb9927 100644 --- a/tests/utils/test_env_checker.py +++ b/tests/utils/test_env_checker.py @@ -15,8 +15,9 @@ class ActionDictTestEnv(gym.Env): def step(self, action): observation = np.array([1.0, 1.5, 0.5]) reward = 1 - done = True - return observation, reward, done + terminated = True + truncated = True + return observation, reward, terminated, truncated def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None): super().reset(seed=seed) @@ -27,12 +28,12 @@ def render(self, mode="human"): def test_check_env_dict_action(): - # Environment.step() only returns 3 values: obs, reward, done. Not info! + # Environment.step() only returns 4 values: obs, reward, terminated, truncated. Not info! test_env = ActionDictTestEnv() with pytest.raises(AssertionError) as errorinfo: check_env(env=test_env, warn=True) assert ( str(errorinfo.value) - == "The `step()` method must return four values: obs, reward, done, info" + == "The `step()` method must return four values: obs, reward, terminated, truncated, info" ) diff --git a/tests/utils/test_terminated_truncated.py b/tests/utils/test_terminated_truncated.py new file mode 100644 index 00000000000..a4d810bce3a --- /dev/null +++ b/tests/utils/test_terminated_truncated.py @@ -0,0 +1,91 @@ +import pytest + +import gym +from gym.spaces import Discrete +from gym.vector import AsyncVectorEnv, SyncVectorEnv +from gym.wrappers import TimeLimit + + +# An environment where termination happens after 20 steps +class DummyEnv(gym.Env): + def __init__(self): + self.action_space = Discrete(2) + self.observation_space = Discrete(2) + self.terminal_timestep = 20 + + self.timestep = 0 + + def step(self, action): + self.timestep += 1 + terminated = True if self.timestep >= self.terminal_timestep else False + truncated = False + + return 0, 0, terminated, truncated, {} + + def reset(self): + self.timestep = 0 + return 0 + + +@pytest.mark.parametrize("time_limit", [10, 20, 30]) +def test_terminated_truncated(time_limit): + test_env = TimeLimit(DummyEnv(), time_limit) + + terminated = False + truncated = False + test_env.reset() + while not (terminated or truncated): + _, _, terminated, truncated, _ = test_env.step(0) + + if test_env.terminal_timestep < time_limit: + assert terminated + assert not truncated + elif test_env.terminal_timestep == time_limit: + assert ( + terminated + ), "`terminated` should be True even when termination and truncation happen at the same step" + assert ( + truncated + ), "`truncated` should be True even when termination and truncation occur at same step " + else: + assert not terminated + assert truncated + + +def test_terminated_truncated_vector(): + env0 = TimeLimit(DummyEnv(), 10) + env1 = TimeLimit(DummyEnv(), 20) + env2 = TimeLimit(DummyEnv(), 30) + + async_env = AsyncVectorEnv([lambda: env0, lambda: env1, lambda: env2]) + async_env.reset() + terminateds = [False, False, False] + truncateds = [False, False, False] + counter = 0 + while not all([x or y for x, y 
in zip(terminateds, truncateds)]):
+        counter += 1
+        _, _, terminateds, truncateds, _ = async_env.step(
+            async_env.action_space.sample()
+        )
+    assert counter == 20
+    assert all(terminateds == [False, True, True])
+    assert all(truncateds == [True, True, False])
+
+    sync_env = SyncVectorEnv([lambda: env0, lambda: env1, lambda: env2])
+    sync_env.reset()
+    terminateds = [False, False, False]
+    truncateds = [False, False, False]
+    counter = 0
+    while not all([x or y for x, y in zip(terminateds, truncateds)]):
+        counter += 1
+        _, _, terminateds, truncateds, _ = sync_env.step(
+            sync_env.action_space.sample()
+        )
+    assert counter == 20
+    assert all(terminateds == [False, True, True])
+    assert all(truncateds == [True, True, False])
+
+
+if __name__ == "__main__":
+    test_terminated_truncated(10)
diff --git a/tests/vector/test_async_vector_env.py b/tests/vector/test_async_vector_env.py
index 41104799019..8159565f0a7 100644
--- a/tests/vector/test_async_vector_env.py
+++ b/tests/vector/test_async_vector_env.py
@@ -82,7 +82,7 @@ def test_step_async_vector_env(shared_memory, use_single_action_space):
             actions = [env.single_action_space.sample() for _ in range(8)]
         else:
             actions = env.action_space.sample()
-        observations, rewards, dones, _ = env.step(actions)
+        observations, rewards, terminateds, truncateds, _ = env.step(actions)
     finally:
         env.close()
@@ -97,10 +97,15 @@ def test_step_async_vector_env(shared_memory, use_single_action_space):
     assert rewards.ndim == 1
     assert rewards.size == 8
-    assert isinstance(dones, np.ndarray)
-    assert dones.dtype == np.bool_
-    assert dones.ndim == 1
-    assert dones.size == 8
+    assert isinstance(terminateds, np.ndarray)
+    assert terminateds.dtype == np.bool_
+    assert terminateds.ndim == 1
+    assert terminateds.size == 8
+
+    assert isinstance(truncateds, np.ndarray)
+    assert truncateds.dtype == np.bool_
+    assert truncateds.ndim == 1
+    assert truncateds.size == 8
 @pytest.mark.parametrize("shared_memory", [True, False])
@@ -180,7 +185,9 @@ def test_step_timeout_async_vector_env(shared_memory):
            env = AsyncVectorEnv(env_fns, shared_memory=shared_memory)
            observations = env.reset()
            env.step_async([0.1, 0.1, 0.3, 0.1])
-            observations, rewards, dones, _ = env.step_wait(timeout=0.1)
+            observations, rewards, terminateds, truncateds, _ = env.step_wait(
+                timeout=0.1
+            )
        finally:
            env.close(terminate=True)
@@ -222,7 +229,7 @@ def test_step_out_of_order_async_vector_env(shared_memory):
            env = AsyncVectorEnv(env_fns, shared_memory=shared_memory)
            actions = env.action_space.sample()
            observations = env.reset()
-            observations, rewards, dones, infos = env.step_wait()
+            observations, rewards, terminateds, truncateds, infos = env.step_wait()
        except AlreadyPendingCallError as exception:
            assert exception.name == "step"
            raise
@@ -272,7 +279,7 @@ def test_custom_space_async_vector_env():
         assert isinstance(env.action_space, Tuple)
         actions = ("action-2", "action-3", "action-5", "action-7")
-        step_observations, rewards, dones, _ = env.step(actions)
+        step_observations, rewards, terminateds, truncateds, _ = env.step(actions)
     finally:
         env.close()
diff --git a/tests/vector/test_step_compatibility_vector.py b/tests/vector/test_step_compatibility_vector.py
new file mode 100644
index 00000000000..ddd0a971266
--- /dev/null
+++ b/tests/vector/test_step_compatibility_vector.py
@@ -0,0 +1,89 @@
+import numpy as np
+import pytest
+
+import gym
+from gym.spaces import Discrete
+from gym.vector import AsyncVectorEnv, StepCompatibilityVector, SyncVectorEnv
+from gym.wrappers import StepCompatibility
+
+
+class OldStepEnv(gym.Env):
+    def __init__(self):
+        self.action_space = Discrete(2)
+        self.observation_space = Discrete(2)
+
+    def reset(self):
+        return 0
+
+    def step(self, action):
+        obs = self.observation_space.sample()
+        rew = 0
+        done = False
+        info = {}
+        return obs, rew, done, info
+
+
+class NewStepEnv(gym.Env):
+    def __init__(self):
+        self.action_space = Discrete(2)
+        self.observation_space = Discrete(2)
+
+    def reset(self):
+        return 0
+
+    def step(self, action):
+        obs = self.observation_space.sample()
+        rew = 0
+        terminated = False
+        truncated = False
+        info = {}
+        return obs, rew, terminated, truncated, info
+
+
+@pytest.mark.parametrize("VecEnv", [AsyncVectorEnv, SyncVectorEnv])
+def test_vector_step_compatibility_new_env(VecEnv):
+
+    envs = [
+        StepCompatibility(OldStepEnv()),
+        NewStepEnv(),
+    ]  # input to vec env must be in new step api
+
+    vec_env = StepCompatibilityVector(
+        VecEnv([lambda env=env: env for env in envs]), return_two_dones=False
+    )
+    vec_env.reset()
+    step_returns = vec_env.step([0, 0])
+    assert len(step_returns) == 4
+    _, _, dones, _ = step_returns
+    assert dones.dtype == np.bool_
+
+    vec_env = StepCompatibilityVector(VecEnv([lambda env=env: env for env in envs]))
+    vec_env.reset()
+    step_returns = vec_env.step([0, 0])
+    assert len(step_returns) == 5
+    _, _, terminateds, truncateds, _ = step_returns
+    assert terminateds.dtype == np.bool_
+    assert truncateds.dtype == np.bool_
+
+
+@pytest.mark.parametrize("async_bool", [True, False])
+def test_vector_step_compatibility_existing(async_bool):
+
+    env = gym.vector.make(
+        "CartPole-v1", num_envs=3, asynchronous=async_bool, return_two_dones=False
+    )
+    env.reset()
+    step_returns = env.step(env.action_space.sample())
+    assert len(step_returns) == 4
+    _, _, dones, _ = step_returns
+    assert dones.dtype == np.bool_
+
+    env = gym.vector.make(
+        "CartPole-v1", num_envs=3, asynchronous=async_bool, return_two_dones=True
+    )
+    env.reset()
+    step_returns = env.step(env.action_space.sample())
+    assert len(step_returns) == 5
+    _, _, terminateds, truncateds, _ = step_returns
+    assert terminateds.dtype == np.bool_
+    assert truncateds.dtype == np.bool_
diff --git a/tests/vector/test_sync_vector_env.py b/tests/vector/test_sync_vector_env.py
index 623803238ce..d12fd1d0099 100644
--- a/tests/vector/test_sync_vector_env.py
+++ b/tests/vector/test_sync_vector_env.py
@@ -76,7 +76,7 @@ def test_step_sync_vector_env(use_single_action_space):
             actions = [env.single_action_space.sample() for _ in range(8)]
         else:
             actions = env.action_space.sample()
-        observations, rewards, dones, _ = env.step(actions)
+        observations, rewards, terminateds, truncateds, _ = env.step(actions)
     finally:
         env.close()
@@ -91,10 +91,15 @@ def test_step_sync_vector_env(use_single_action_space):
     assert rewards.ndim == 1
     assert rewards.size == 8
-    assert isinstance(dones, np.ndarray)
-    assert dones.dtype == np.bool_
-    assert dones.ndim == 1
-    assert dones.size == 8
+    assert isinstance(terminateds, np.ndarray)
+    assert terminateds.dtype == np.bool_
+    assert terminateds.ndim == 1
+    assert terminateds.size == 8
+
+    assert isinstance(truncateds, np.ndarray)
+    assert truncateds.dtype == np.bool_
+    assert truncateds.ndim == 1
+    assert truncateds.size == 8
 def test_call_sync_vector_env():
@@ -150,7 +155,7 @@ def test_custom_space_sync_vector_env():
         assert isinstance(env.action_space, Tuple)
         actions = ("action-2", "action-3", "action-5", "action-7")
-        step_observations, rewards, dones, _ = env.step(actions)
+        step_observations, rewards, terminateds, truncateds, _ = env.step(actions)
finally: env.close() diff --git a/tests/vector/test_vector_env.py b/tests/vector/test_vector_env.py index 82870d79c29..c1cca786522 100644 --- a/tests/vector/test_vector_env.py +++ b/tests/vector/test_vector_env.py @@ -31,19 +31,19 @@ def test_vector_env_equal(shared_memory): assert actions in sync_env.action_space # fmt: off - async_observations, async_rewards, async_dones, async_infos = async_env.step(actions) - sync_observations, sync_rewards, sync_dones, sync_infos = sync_env.step(actions) + async_observations, async_rewards, async_terminateds, async_truncateds, async_infos = async_env.step(actions) + sync_observations, sync_rewards, sync_terminateds, sync_truncateds, sync_infos = sync_env.step(actions) # fmt: on - for idx in range(len(sync_dones)): - if sync_dones[idx]: - assert "terminal_observation" in async_infos[idx] - assert "terminal_observation" in sync_infos[idx] - assert sync_dones[idx] + for idx in range(len(sync_terminateds)): + if sync_terminateds[idx] or sync_truncateds[idx]: + assert "closing_observation" in async_infos[idx] + assert "closing_observation" in sync_infos[idx] assert np.all(async_observations == sync_observations) assert np.all(async_rewards == sync_rewards) - assert np.all(async_dones == sync_dones) + assert np.all(async_terminateds == sync_terminateds) + assert np.all(async_truncateds == sync_truncateds) finally: async_env.close() diff --git a/tests/vector/utils.py b/tests/vector/utils.py index 0eadb672642..e394fb59e43 100644 --- a/tests/vector/utils.py +++ b/tests/vector/utils.py @@ -67,8 +67,8 @@ def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None): def step(self, action): time.sleep(action) observation = self.observation_space.sample() - reward, done = 0.0, False - return observation, reward, done, {} + reward, terminated, truncated = 0.0, False, False + return observation, reward, terminated, truncated, {} class CustomSpace(gym.Space): @@ -102,8 +102,8 @@ def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None): def step(self, action): observation = f"step({action:s})" - reward, done = 0.0, False - return observation, reward, done, {} + reward, terminated, truncated = 0.0, False, False + return observation, reward, terminated, truncated, {} def make_env(env_name, seed): diff --git a/tests/wrappers/nested_dict_test.py b/tests/wrappers/nested_dict_test.py index bde47054137..0c71b212dde 100644 --- a/tests/wrappers/nested_dict_test.py +++ b/tests/wrappers/nested_dict_test.py @@ -29,8 +29,8 @@ def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None): def step(self, action): del action observation = self.observation_space.sample() - reward, terminal, info = 0.0, False, {} - return observation, reward, terminal, info + reward, terminated, truncated, info = 0.0, False, False, {} + return observation, reward, terminated, truncated, info NESTED_DICT_TEST_CASES = ( diff --git a/tests/wrappers/test_atari_preprocessing.py b/tests/wrappers/test_atari_preprocessing.py index e36d3768838..a7bf6b40ea7 100644 --- a/tests/wrappers/test_atari_preprocessing.py +++ b/tests/wrappers/test_atari_preprocessing.py @@ -74,13 +74,13 @@ def test_atari_preprocessing_scale(env_fn): noop_max=0, ) obs = env.reset().flatten() - done, step_i = False, 0 + terminated, truncated, step_i = False, False, 0 max_obs = 1 if scaled else 255 assert (0 <= obs).all() and ( obs <= max_obs ).all(), f"Obs. 
must be in range [0,{max_obs}]" - while not done or step_i <= max_test_steps: - obs, _, done, _ = env.step(env.action_space.sample()) + while not (terminated or truncated) or step_i <= max_test_steps: + obs, _, terminated, truncated, _ = env.step(env.action_space.sample()) obs = obs.flatten() assert (0 <= obs).all() and ( obs <= max_obs diff --git a/tests/wrappers/test_autoreset.py b/tests/wrappers/test_autoreset.py index 76c035f87dd..952921387bd 100644 --- a/tests/wrappers/test_autoreset.py +++ b/tests/wrappers/test_autoreset.py @@ -34,6 +34,7 @@ def step(self, action): return ( np.array([self.count]), 1 if self.count > 2 else 0, + False, self.count > 2, {"count": self.count}, ) @@ -84,9 +85,10 @@ def test_make_autoreset_true(spec): env.unwrapped.reset = MagicMock(side_effect=env.unwrapped.reset) - done = False - while not done: - obs, reward, done, info = env.step(env.action_space.sample()) + terminated = False + truncated = False + while not terminated and not truncated: + obs, reward, terminated, truncated, info = env.step(env.action_space.sample()) assert isinstance(env, AutoResetWrapper) assert env.unwrapped.reset.called @@ -115,32 +117,37 @@ def test_autoreset_autoreset(): assert obs == np.array([0]) assert info == {"count": 0} action = 1 - obs, reward, done, info = env.step(action) + obs, reward, terminated, truncated, info = env.step(action) assert obs == np.array([1]) assert reward == 0 - assert done == False + assert terminated == False + assert truncated == False assert info == {"count": 1} - obs, reward, done, info = env.step(action) + obs, reward, terminated, truncated, info = env.step(action) assert obs == np.array([2]) - assert done == False + assert terminated == False + assert truncated == False assert reward == 0 assert info == {"count": 2} - obs, reward, done, info = env.step(action) + obs, reward, terminated, truncated, info = env.step(action) assert obs == np.array([0]) - assert done == True + assert terminated == False + assert truncated == True assert reward == 1 assert info == { "count": 0, - "terminal_observation": np.array([3]), - "terminal_info": {"count": 3}, + "closing_observation": np.array([3]), + "closing_info": {"count": 3}, } - obs, reward, done, info = env.step(action) + obs, reward, terminated, truncated, info = env.step(action) assert obs == np.array([1]) assert reward == 0 - assert done == False + assert terminated == False + assert truncated == False assert info == {"count": 1} - obs, reward, done, info = env.step(action) + obs, reward, terminated, truncated, info = env.step(action) assert obs == np.array([2]) assert reward == 0 - assert done == False + assert terminated == False + assert truncated == False assert info == {"count": 2} diff --git a/tests/wrappers/test_clip_action.py b/tests/wrappers/test_clip_action.py index aebf867b6e0..8b4a4a4fd23 100644 --- a/tests/wrappers/test_clip_action.py +++ b/tests/wrappers/test_clip_action.py @@ -17,10 +17,11 @@ def test_clip_action(): actions = [[0.4], [1.2], [-0.3], [0.0], [-2.5]] for action in actions: - obs1, r1, d1, _ = env.step( + obs1, r1, term1, trunc1, _ = env.step( np.clip(action, env.action_space.low, env.action_space.high) ) - obs2, r2, d2, _ = wrapped_env.step(action) + obs2, r2, term2, trunc2, _ = wrapped_env.step(action) assert np.allclose(r1, r2) assert np.allclose(obs1, obs2) - assert d1 == d2 + assert term1 == term2 + assert trunc1 == trunc2 diff --git a/tests/wrappers/test_filter_observation.py b/tests/wrappers/test_filter_observation.py index e7d5ef2b052..de46ae2dc20 100644 --- 
a/tests/wrappers/test_filter_observation.py +++ b/tests/wrappers/test_filter_observation.py @@ -32,8 +32,8 @@ def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None): def step(self, action): del action observation = self.observation_space.sample() - reward, terminal, info = 0.0, False, {} - return observation, reward, terminal, info + reward, terminated, truncated, info = 0.0, False, False, {} + return observation, reward, terminated, truncated, info FILTER_OBSERVATION_TEST_CASES = ( diff --git a/tests/wrappers/test_frame_stack.py b/tests/wrappers/test_frame_stack.py index b9af3002c1f..903f9926c1b 100644 --- a/tests/wrappers/test_frame_stack.py +++ b/tests/wrappers/test_frame_stack.py @@ -42,8 +42,8 @@ def test_frame_stack(env_id, num_stack, lz4_compress): for _ in range(num_stack**2): action = env.action_space.sample() - dup_obs, _, _, _ = dup.step(action) - obs, _, _, _ = env.step(action) + dup_obs, _, _, _, _ = dup.step(action) + obs, _, _, _, _ = env.step(action) assert np.allclose(obs[-1], dup_obs) assert len(obs) == num_stack diff --git a/tests/wrappers/test_normalize.py b/tests/wrappers/test_normalize.py index 13bf32011be..ee73163a1d4 100644 --- a/tests/wrappers/test_normalize.py +++ b/tests/wrappers/test_normalize.py @@ -22,7 +22,13 @@ def __init__(self, return_reward_idx=0): def step(self, action): self.t += 1 - return np.array([self.t]), self.t, self.t == len(self.returned_rewards), {} + return ( + np.array([self.t]), + self.t, + self.t == len(self.returned_rewards), + False, + {}, + ) def reset( self, @@ -94,7 +100,7 @@ def test_normalize_observation_vector_env(): env_fns = [make_env(0), make_env(1)] envs = gym.vector.SyncVectorEnv(env_fns) envs.reset() - obs, reward, _, _ = envs.step(envs.action_space.sample()) + obs, reward, _, _, _ = envs.step(envs.action_space.sample()) np.testing.assert_almost_equal(obs, np.array([[1], [2]]), decimal=4) np.testing.assert_almost_equal(reward, np.array([1, 2]), decimal=4) @@ -107,7 +113,7 @@ def test_normalize_observation_vector_env(): np.mean([0.5]), # the mean of first observations [[0, 1]] decimal=4, ) - obs, reward, _, _ = envs.step(envs.action_space.sample()) + obs, reward, _, _, _ = envs.step(envs.action_space.sample()) assert_almost_equal( envs.obs_rms.mean, np.mean([1.0]), # the mean of first and second observations [[0, 1], [1, 2]] @@ -120,13 +126,13 @@ def test_normalize_return_vector_env(): envs = gym.vector.SyncVectorEnv(env_fns) envs = NormalizeReward(envs) obs = envs.reset() - obs, reward, _, _ = envs.step(envs.action_space.sample()) + obs, reward, _, _, _ = envs.step(envs.action_space.sample()) assert_almost_equal( envs.return_rms.mean, np.mean([1.5]), # the mean of first returns [[1, 2]] decimal=4, ) - obs, reward, _, _ = envs.step(envs.action_space.sample()) + obs, reward, _, _, _ = envs.step(envs.action_space.sample()) assert_almost_equal( envs.return_rms.mean, np.mean( diff --git a/tests/wrappers/test_pixel_observation.py b/tests/wrappers/test_pixel_observation.py index 95f094579cd..480e2cec16c 100644 --- a/tests/wrappers/test_pixel_observation.py +++ b/tests/wrappers/test_pixel_observation.py @@ -27,8 +27,8 @@ def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None): def step(self, action): del action observation = self.observation_space.sample() - reward, terminal, info = 0.0, False, {} - return observation, reward, terminal, info + reward, terminated, truncated, info = 0.0, False, False, {} + return observation, reward, terminated, truncated, info class 
FakeArrayObservationEnvironment(FakeEnvironment): diff --git a/tests/wrappers/test_record_episode_statistics.py b/tests/wrappers/test_record_episode_statistics.py index d9633409eb3..3bb9ec58611 100644 --- a/tests/wrappers/test_record_episode_statistics.py +++ b/tests/wrappers/test_record_episode_statistics.py @@ -16,8 +16,8 @@ def test_record_episode_statistics(env_id, deque_size): assert env.episode_returns[0] == 0.0 assert env.episode_lengths[0] == 0 for t in range(env.spec.max_episode_steps): - _, _, done, info = env.step(env.action_space.sample()) - if done: + _, _, terminated, truncated, info = env.step(env.action_space.sample()) + if terminated or truncated: assert "episode" in info assert all([item in info["episode"] for item in ["r", "l", "t"]]) break @@ -50,9 +50,9 @@ def test_record_episode_statistics_with_vectorenv(num_envs, asynchronous): ) envs.reset() for _ in range(max_episode_step + 1): - _, _, dones, infos = envs.step(envs.action_space.sample()) + _, _, terminateds, truncateds, infos = envs.step(envs.action_space.sample()) for idx, info in enumerate(infos): - if dones[idx]: + if terminateds[idx] or truncateds[idx]: assert "episode" in info assert all([item in info["episode"] for item in ["r", "l", "t"]]) break diff --git a/tests/wrappers/test_record_video.py b/tests/wrappers/test_record_video.py index 0757c1bec54..878b3af80ef 100644 --- a/tests/wrappers/test_record_video.py +++ b/tests/wrappers/test_record_video.py @@ -19,8 +19,8 @@ def test_record_video_using_default_trigger(): env.reset() for _ in range(199): action = env.action_space.sample() - _, _, done, _ = env.step(action) - if done: + _, _, terminated, truncated, _ = env.step(action) + if terminated or truncated: env.reset() env.close() assert os.path.isdir("videos") @@ -68,8 +68,8 @@ def test_record_video_step_trigger(): env.reset() for _ in range(199): action = env.action_space.sample() - _, _, done, _ = env.step(action) - if done: + _, _, terminated, truncated, _ = env.step(action) + if terminated or truncated: env.reset() env.close() assert os.path.isdir("videos") @@ -96,7 +96,7 @@ def test_record_video_within_vector(): envs = gym.wrappers.RecordEpisodeStatistics(envs) envs.reset() for i in range(199): - _, _, _, infos = envs.step(envs.action_space.sample()) + _, _, _, _, infos = envs.step(envs.action_space.sample()) for info in infos: if "episode" in info.keys(): print(f"episode_reward={info['episode']['r']}") diff --git a/tests/wrappers/test_rescale_action.py b/tests/wrappers/test_rescale_action.py index 6db5ad5fa75..fc71929e718 100644 --- a/tests/wrappers/test_rescale_action.py +++ b/tests/wrappers/test_rescale_action.py @@ -20,10 +20,10 @@ def test_rescale_action(): wrapped_obs = wrapped_env.reset(seed=seed) assert np.allclose(obs, wrapped_obs) - obs, reward, _, _ = env.step([1.5]) + obs, reward, _, _, _ = env.step([1.5]) with pytest.raises(AssertionError): wrapped_env.step([1.5]) - wrapped_obs, wrapped_reward, _, _ = wrapped_env.step([0.75]) + wrapped_obs, wrapped_reward, _, _, _ = wrapped_env.step([0.75]) assert np.allclose(obs, wrapped_obs) assert np.allclose(reward, wrapped_reward) diff --git a/tests/wrappers/test_step_compatibility.py b/tests/wrappers/test_step_compatibility.py new file mode 100644 index 00000000000..22853064f5f --- /dev/null +++ b/tests/wrappers/test_step_compatibility.py @@ -0,0 +1,75 @@ +import pytest + +import gym +from gym.spaces import Discrete +from gym.wrappers import StepCompatibility + + +class OldStepEnv(gym.Env): + def __init__(self): + self.action_space = Discrete(2) + 
self.observation_space = Discrete(2) + + def step(self, action): + obs = self.observation_space.sample() + rew = 0 + done = False + info = {} + return obs, rew, done, info + + +class NewStepEnv(gym.Env): + def __init__(self): + self.action_space = Discrete(2) + self.observation_space = Discrete(2) + + def step(self, action): + obs = self.observation_space.sample() + rew = 0 + terminated = False + truncated = False + info = {} + return obs, rew, terminated, truncated, info + + +@pytest.mark.parametrize("env", [OldStepEnv, NewStepEnv]) +def test_step_compatibility_to_old_api(env): + env = StepCompatibility(env(), False) + step_returns = env.step(0) + assert len(step_returns) == 4 + _, _, done, _ = step_returns + assert isinstance(done, bool) + + +@pytest.mark.parametrize("env", [OldStepEnv, NewStepEnv]) +@pytest.mark.parametrize("return_two_dones", [None, True]) +def test_step_compatibility_to_new_api(env, return_two_dones): + if return_two_dones is None: + env = StepCompatibility(env()) # default behavior is to convert to new api + else: + env = StepCompatibility(env(), return_two_dones) + step_returns = env.step(0) + assert len(step_returns) == 5 + _, _, terminated, truncated, _ = step_returns + assert isinstance(terminated, bool) + assert isinstance(truncated, bool) + + +@pytest.mark.parametrize("return_two_dones", [None, True, False]) +def test_step_compatibility_in_make(return_two_dones): + if return_two_dones is None: + env = gym.make("CartPole-v1") # check default behavior + else: + env = gym.make("CartPole-v1", return_two_dones=return_two_dones) + + env.reset() + step_returns = env.step(0) + if return_two_dones == True or return_two_dones == None: # new api + assert len(step_returns) == 5 + _, _, terminated, truncated, _ = step_returns + assert isinstance(terminated, bool) + assert isinstance(truncated, bool) + else: # old api + assert len(step_returns) == 4 + _, _, done, _ = step_returns + assert isinstance(done, bool) diff --git a/tests/wrappers/test_time_aware_observation.py b/tests/wrappers/test_time_aware_observation.py index a996d608cdc..80bd5c3d69e 100644 --- a/tests/wrappers/test_time_aware_observation.py +++ b/tests/wrappers/test_time_aware_observation.py @@ -17,12 +17,12 @@ def test_time_aware_observation(env_id): assert wrapped_obs[-1] == 0.0 assert wrapped_obs.shape[0] == obs.shape[0] + 1 - wrapped_obs, _, _, _ = wrapped_env.step(env.action_space.sample()) + wrapped_obs, _, _, _, _ = wrapped_env.step(env.action_space.sample()) assert wrapped_env.t == 1.0 assert wrapped_obs[-1] == 1.0 assert wrapped_obs.shape[0] == obs.shape[0] + 1 - wrapped_obs, _, _, _ = wrapped_env.step(env.action_space.sample()) + wrapped_obs, _, _, _, _ = wrapped_env.step(env.action_space.sample()) assert wrapped_env.t == 2.0 assert wrapped_obs[-1] == 2.0 assert wrapped_obs.shape[0] == obs.shape[0] + 1 diff --git a/tests/wrappers/test_time_limit_info.py b/tests/wrappers/test_time_limit_info.py new file mode 100644 index 00000000000..792d6005489 --- /dev/null +++ b/tests/wrappers/test_time_limit_info.py @@ -0,0 +1 @@ +# diff --git a/tests/wrappers/test_transform_observation.py b/tests/wrappers/test_transform_observation.py index fc1076ae4f7..098578677a9 100644 --- a/tests/wrappers/test_transform_observation.py +++ b/tests/wrappers/test_transform_observation.py @@ -18,8 +18,15 @@ def test_transform_observation(env_id): assert np.allclose(wrapped_obs, affine_transform(obs)) action = env.action_space.sample() - obs, reward, done, _ = env.step(action) - wrapped_obs, wrapped_reward, wrapped_done, _ = 
wrapped_env.step(action) + obs, reward, terminated, truncated, _ = env.step(action) + ( + wrapped_obs, + wrapped_reward, + wrapped_terminated, + wrapped_truncated, + _, + ) = wrapped_env.step(action) assert np.allclose(wrapped_obs, affine_transform(obs)) assert np.allclose(wrapped_reward, reward) - assert wrapped_done == done + assert wrapped_terminated == terminated + assert wrapped_truncated == truncated diff --git a/tests/wrappers/test_transform_reward.py b/tests/wrappers/test_transform_reward.py index c7badb7a2d0..74e061025dc 100644 --- a/tests/wrappers/test_transform_reward.py +++ b/tests/wrappers/test_transform_reward.py @@ -17,8 +17,8 @@ def test_transform_reward(env_id): env.reset(seed=0) wrapped_env.reset(seed=0) - _, reward, _, _ = env.step(action) - _, wrapped_reward, _, _ = wrapped_env.step(action) + _, reward, _, _, _ = env.step(action) + _, wrapped_reward, _, _, _ = wrapped_env.step(action) assert wrapped_reward == scale * reward del env, wrapped_env @@ -33,8 +33,8 @@ def test_transform_reward(env_id): env.reset(seed=0) wrapped_env.reset(seed=0) - _, reward, _, _ = env.step(action) - _, wrapped_reward, _, _ = wrapped_env.step(action) + _, reward, _, _, _ = env.step(action) + _, wrapped_reward, _, _, _ = wrapped_env.step(action) assert abs(wrapped_reward) < abs(reward) assert wrapped_reward == -0.0005 or wrapped_reward == 0.0002 @@ -49,8 +49,8 @@ def test_transform_reward(env_id): for _ in range(1000): action = env.action_space.sample() - _, wrapped_reward, done, _ = wrapped_env.step(action) + _, wrapped_reward, terminated, truncated, _ = wrapped_env.step(action) assert wrapped_reward in [-1.0, 0.0, 1.0] - if done: + if terminated or truncated: break del env, wrapped_env From a0c44755437eaf2a40d20ca20671d76603b049b6 Mon Sep 17 00:00:00 2001 From: arjun_kg Date: Wed, 20 Apr 2022 22:21:40 +0530 Subject: [PATCH 02/37] Setting return_two_dones=False as default --- README.md | 2 +- gym/envs/mujoco/walker2d_v3.py | 14 +++++----- gym/envs/registration.py | 2 +- gym/vector/__init__.py | 2 ++ gym/wrappers/step_compatibility.py | 9 ++++++- tests/envs/test_action_dim_check.py | 2 +- tests/envs/test_determinism.py | 4 +-- tests/envs/test_envs.py | 13 ++++++---- tests/envs/test_mujoco_v2_to_v3_conversion.py | 8 +++--- tests/utils/test_play.py | 26 +++++++++---------- .../vector/test_step_compatibility_vector.py | 4 +++ tests/vector/utils.py | 5 ++-- tests/wrappers/test_autoreset.py | 8 +++--- tests/wrappers/test_clip_action.py | 2 +- .../test_record_episode_statistics.py | 11 +++++--- tests/wrappers/test_record_video.py | 12 ++++----- tests/wrappers/test_rescale_action.py | 6 ++--- tests/wrappers/test_step_compatibility.py | 25 +++++++++--------- tests/wrappers/test_time_aware_observation.py | 2 +- tests/wrappers/test_transform_observation.py | 4 +-- tests/wrappers/test_transform_reward.py | 18 ++++++++----- 21 files changed, 103 insertions(+), 76 deletions(-) diff --git a/README.md b/README.md index 737f16bce82..37b63f2d162 100644 --- a/README.md +++ b/README.md @@ -27,7 +27,7 @@ observation, info = env.reset(seed=42, return_info=True) for _ in range(1000): action = env.action_space.sample() - observation, reward, terminated, truncated, info = env.step(action) + observation, reward, done, info = env.step(action) if done: observation, info = env.reset(return_info=True) diff --git a/gym/envs/mujoco/walker2d_v3.py b/gym/envs/mujoco/walker2d_v3.py index 92bf4f564d7..db5b709402e 100644 --- a/gym/envs/mujoco/walker2d_v3.py +++ b/gym/envs/mujoco/walker2d_v3.py @@ -15,7 +15,7 @@ class 
Walker2dEnv(mujoco_env.MujocoEnv, utils.EzPickle):
     """
     ### Description
-    This environment builds on the hopper environment based on the work terminated by Erez, Tassa, and Todorov
+    This environment builds on the hopper environment based on the work done by Erez, Tassa, and Todorov
     in ["Infinite Horizon Model Predictive Control for Nonlinear Periodic Tasks"](http://www.roboticsproceedings.org/rss07/p10.pdf)
     by adding another set of legs making it possible for the robot to walk forward instead of hop.
     Like other Mujoco environments, this environment aims to increase the number of independent state
@@ -127,7 +127,7 @@ class Walker2dEnv(mujoco_env.MujocoEnv, utils.EzPickle):
     | `forward_reward_weight` | **float** | `1.0` | Weight for *forward_reward* term (see section on reward) |
     | `ctrl_cost_weight` | **float** | `1e-3` | Weight for *ctr_cost* term (see section on reward) |
     | `healthy_reward` | **float** | `1.0` | Constant reward given if the ant is "healthy" after timestep |
-    | `terminate_when_unhealthy` | **bool**| `True` | If true, issue a done signal if the z-coordinate of the walker is no longer healthy |
+    | `terminate_when_unhealthy` | **bool**| `True` | If true, issue a terminated signal if the z-coordinate of the walker is no longer healthy |
     | `healthy_z_range` | **tuple** | `(0.8, 2)` | The z-coordinate of the top of the walker must be in this range to be considered healthy |
     | `healthy_angle_range` | **tuple** | `(-1, 1)` | The angle must be in this range to be considered healthy|
     | `reset_noise_scale` | **float** | `5e-3` | Scale of random perturbations of initial position and velocity (see section on Starting State) |
@@ -199,9 +199,9 @@ def is_healthy(self):
         return is_healthy

     @property
-    def done(self):
-        done = not self.is_healthy if self._terminate_when_unhealthy else False
-        return done
+    def terminated(self):
+        terminated = not self.is_healthy if self._terminate_when_unhealthy else False
+        return terminated

     def _get_obs(self):
         position = self.sim.data.qpos.flat.copy()
@@ -229,13 +229,13 @@ def step(self, action):
         observation = self._get_obs()
         reward = rewards - costs
-        done = self.done
+        terminated = self.terminated
         info = {
             "x_position": x_position_after,
             "x_velocity": x_velocity,
         }

-        return observation, reward, done, info
+        return observation, reward, terminated, False, info

     def reset_model(self):
         noise_low = -self._reset_noise_scale
diff --git a/gym/envs/registration.py b/gym/envs/registration.py
index 48fc2e5c512..fbd6cea12f5 100644
--- a/gym/envs/registration.py
+++ b/gym/envs/registration.py
@@ -103,7 +103,7 @@ class EnvSpec:
     max_episode_steps: Optional[int] = field(default=None)
     order_enforce: bool = field(default=True)
     autoreset: bool = field(default=False)
-    return_two_dones: bool = field(default=True)
+    return_two_dones: bool = field(default=False)
     kwargs: dict = field(default_factory=dict)
     namespace: Optional[str] = field(init=False)
     name: str = field(init=False)
diff --git a/gym/vector/__init__.py b/gym/vector/__init__.py
index ed9fea0a887..ed37aeb3cc4 100644
--- a/gym/vector/__init__.py
+++ b/gym/vector/__init__.py
@@ -7,6 +7,7 @@
 from gym.vector.step_compatibility_vector import StepCompatibilityVector
 from gym.vector.sync_vector_env import SyncVectorEnv
 from gym.vector.vector_env import VectorEnv, VectorEnvWrapper
+from gym.wrappers import StepCompatibility

 __all__ = ["AsyncVectorEnv", "SyncVectorEnv", "VectorEnv", "VectorEnvWrapper", "make"]
@@ -52,6 +53,7 @@ def make(
     def _make_env():
         env = make_(id, **kwargs)
+        env = StepCompatibility(env,
return_two_dones=True) if wrappers is not None: if callable(wrappers): env = wrappers(env) diff --git a/gym/wrappers/step_compatibility.py b/gym/wrappers/step_compatibility.py index 09a30251d11..3992af863a6 100644 --- a/gym/wrappers/step_compatibility.py +++ b/gym/wrappers/step_compatibility.py @@ -3,9 +3,15 @@ class StepCompatibility(gym.Wrapper): - def __init__(self, env, return_two_dones=True): + def __init__(self, env, return_two_dones=False): super().__init__(env) self._return_two_dones = return_two_dones + if not self._return_two_dones: + logger.warn( + "Initializing environment in old step API which returns one bool instead of two. " + "Note that vector API and most wrappers would not work as these have been upgraded to the new API. " + "To use these features, please set `return_two_dones=True` in make to use new API (see docs for more details)." + ) def step(self, action): step_returns = self.env.step(action) @@ -23,6 +29,7 @@ def step(self, action): logger.warn( "Core environment uses old step API which returns one boolean (done). Please upgrade to new API to return two booleans - terminated, truncated" ) + return step_returns elif len(step_returns) == 5: return self._step_returns_new_to_old(step_returns) diff --git a/tests/envs/test_action_dim_check.py b/tests/envs/test_action_dim_check.py index 448847cd6ec..5828be16f1a 100644 --- a/tests/envs/test_action_dim_check.py +++ b/tests/envs/test_action_dim_check.py @@ -11,7 +11,7 @@ @pytest.mark.skipif(skip_mujoco, reason=SKIP_MUJOCO_WARNING_MESSAGE) @pytest.mark.parametrize("environment_id", ENVIRONMENT_IDS) def test_serialize_deserialize(environment_id): - env = envs.make(environment_id) + env = envs.make(environment_id, return_two_dones=True) env.reset() with pytest.raises(ValueError, match="Action dimension mismatch"): diff --git a/tests/envs/test_determinism.py b/tests/envs/test_determinism.py index f134643abaf..d842c66c293 100644 --- a/tests/envs/test_determinism.py +++ b/tests/envs/test_determinism.py @@ -9,14 +9,14 @@ def test_env(spec): # Note that this precludes running this test in multiple # threads. However, we probably already can't do multithreading # due to some environments. 
- env1 = spec.make() + env1 = spec.make(return_two_dones=True) initial_observation1 = env1.reset(seed=0) env1.action_space.seed(0) action_samples1 = [env1.action_space.sample() for i in range(4)] step_responses1 = [env1.step(action) for action in action_samples1] env1.close() - env2 = spec.make() + env2 = spec.make(return_two_dones=True) initial_observation2 = env2.reset(seed=0) env2.action_space.seed(0) action_samples2 = [env2.action_space.sample() for i in range(4)] diff --git a/tests/envs/test_envs.py b/tests/envs/test_envs.py index fd010b5b0d5..3efd854d59d 100644 --- a/tests/envs/test_envs.py +++ b/tests/envs/test_envs.py @@ -17,7 +17,7 @@ def test_env(spec): # Capture warnings with pytest.warns(None) as warnings: - env = spec.make() + env = spec.make(return_two_dones=True) # Test if env adheres to Gym API check_env(env, warn=True, skip_render_check=True) @@ -63,7 +63,7 @@ def test_env(spec): def test_reset_info(spec): with pytest.warns(None) as warnings: - env = spec.make() + env = spec.make(return_two_dones=True) ob_space = env.observation_space obs = env.reset() @@ -78,7 +78,10 @@ def test_reset_info(spec): # Run a longer rollout on some environments def test_random_rollout(): - for env in [envs.make("CartPole-v1"), envs.make("FrozenLake-v1")]: + for env in [ + envs.make("CartPole-v1", return_two_dones=True), + envs.make("FrozenLake-v1", return_two_dones=True), + ]: agent = lambda ob: env.action_space.sample() ob = env.reset() for _ in range(10): @@ -93,8 +96,8 @@ def test_random_rollout(): def test_env_render_result_is_immutable(): environs = [ - envs.make("Taxi-v3"), - envs.make("FrozenLake-v1"), + envs.make("Taxi-v3", return_two_dones=True), + envs.make("FrozenLake-v1", return_two_dones=True), ] for env in environs: diff --git a/tests/envs/test_mujoco_v2_to_v3_conversion.py b/tests/envs/test_mujoco_v2_to_v3_conversion.py index 247c1baedb5..b0b9d7ad4d9 100644 --- a/tests/envs/test_mujoco_v2_to_v3_conversion.py +++ b/tests/envs/test_mujoco_v2_to_v3_conversion.py @@ -9,8 +9,8 @@ def verify_environments_match( old_environment_id, new_environment_id, seed=1, num_actions=1000 ): - old_environment = envs.make(old_environment_id) - new_environment = envs.make(new_environment_id) + old_environment = envs.make(old_environment_id, return_two_dones=True) + new_environment = envs.make(new_environment_id, return_two_dones=True) old_reset_observation = old_environment.reset(seed=seed) new_reset_observation = new_environment.reset(seed=seed) @@ -37,8 +37,8 @@ def verify_environments_match( eps = 1e-6 np.testing.assert_allclose(old_observation, new_observation, atol=eps) np.testing.assert_allclose(old_reward, new_reward, atol=eps) - np.testing.assert_equal(old_terminated, new_terminated, atol=eps) - np.testing.assert_equal(old_truncated, new_truncated, atol=eps) + np.testing.assert_equal(old_terminated, new_terminated) + np.testing.assert_equal(old_truncated, new_truncated) for key in old_info: np.testing.assert_allclose(old_info[key], new_info[key], atol=eps) diff --git a/tests/utils/test_play.py b/tests/utils/test_play.py index 0f0ee1e46eb..f54a866b186 100644 --- a/tests/utils/test_play.py +++ b/tests/utils/test_play.py @@ -23,8 +23,8 @@ class DummyEnvSpec: class DummyPlayEnv(gym.Env): def step(self, action): obs = np.zeros((1, 1)) - rew, done, info = 1, False, {} - return obs, rew, done, info + rew, terminated, truncated, info = 1, False, False, {} + return obs, rew, terminated, truncated, info def reset(self, seed=None): ... 
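For orientation, a minimal stand-alone sketch (illustrative only, not part of this patch) of an env and a play-style callback written directly against the new API, mirroring the `DummyPlayEnv` fixture above:

```python
import numpy as np
import gym


class ToyPlayEnv(gym.Env):
    """Hypothetical stand-in env on the new 5-tuple step API."""

    def reset(self, seed=None):
        return np.zeros((1, 1))

    def step(self, action):
        obs, rew, terminated, truncated, info = np.zeros((1, 1)), 1, False, False, {}
        return obs, rew, terminated, truncated, info


def callback(obs_t, obs_tp1, action, rew, terminated, truncated, info):
    # play() callbacks receive two booleans in place of the old `done`.
    if terminated or truncated:
        print("episode ended")
    return obs_t, obs_tp1, action, rew, terminated, truncated, info
```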
@@ -39,9 +39,9 @@ def __init__(self, callback: Callable): self.cumulative_reward = 0 self.last_observation = None - def callback(self, obs_t, obs_tp1, action, rew, done, info): - _, obs_tp1, _, rew, _, _ = self.data_callback( - obs_t, obs_tp1, action, rew, done, info + def callback(self, obs_t, obs_tp1, action, rew, terminated, truncated, info): + _, obs_tp1, _, rew, _, _, _ = self.data_callback( + obs_t, obs_tp1, action, rew, terminated, truncated, info ) self.cumulative_reward += rew self.last_observation = obs_tp1 @@ -144,16 +144,16 @@ def test_play_loop(): Event(QUIT), ] - def callback(obs_t, obs_tp1, action, rew, done, info): + def callback(obs_t, obs_tp1, action, rew, terminated, truncated, info): event.post(callback_events.pop(0)) - return obs_t, obs_tp1, action, rew, done, info + return obs_t, obs_tp1, action, rew, terminated, truncated, info env = DummyPlayEnv() cumulative_env_reward = 0 for s in range( len(callback_events) ): # we run the same number of steps executed with play() - _, rew, _, _ = env.step(None) + _, rew, _, _, _ = env.step(None) cumulative_env_reward += rew env_play = DummyPlayEnv() @@ -183,7 +183,7 @@ def test_play_loop_real_env(): ] keydown_events = [k for k in callback_events if k.type == KEYDOWN] - def callback(obs_t, obs_tp1, action, rew, done, info): + def callback(obs_t, obs_tp1, action, rew, terminated, truncated, info): pygame_event = callback_events.pop(0) event.post(pygame_event) @@ -193,9 +193,9 @@ def callback(obs_t, obs_tp1, action, rew, done, info): pygame_event = callback_events.pop(0) event.post(pygame_event) - return obs_t, obs_tp1, action, rew, done, info + return obs_t, obs_tp1, action, rew, terminated, truncated, info - env = gym.make(ENV) + env = gym.make(ENV, return_two_dones=True) env.reset(seed=SEED) keys_to_action = dummy_keys_to_action() @@ -204,9 +204,9 @@ def callback(obs_t, obs_tp1, action, rew, done, info): env.step(0) for e in keydown_events: action = keys_to_action[(e.key,)] - obs, _, _, _ = env.step(action) + obs, _, _, _, _ = env.step(action) - env_play = gym.make(ENV) + env_play = gym.make(ENV, return_two_dones=True) status = PlayStatus(callback) play(env_play, callback=status.callback, keys_to_action=keys_to_action, seed=SEED) diff --git a/tests/vector/test_step_compatibility_vector.py b/tests/vector/test_step_compatibility_vector.py index ddd0a971266..2096dedbd0a 100644 --- a/tests/vector/test_step_compatibility_vector.py +++ b/tests/vector/test_step_compatibility_vector.py @@ -56,6 +56,7 @@ def test_vector_step_compatibility_new_env(VecEnv): assert len(step_returns) == 4 _, _, dones, _ = step_returns assert dones.dtype == np.bool_ + vec_env.close() vec_env = StepCompatibilityVector(VecEnv([lambda: env for env in envs])) vec_env.reset() @@ -64,6 +65,7 @@ def test_vector_step_compatibility_new_env(VecEnv): _, _, terminateds, truncateds, _ = step_returns assert terminateds.dtype == np.bool_ assert truncateds.dtype == np.bool_ + vec_env.close() @pytest.mark.parametrize("async_bool", [True, False]) @@ -77,6 +79,7 @@ def test_vector_step_compatibility_existing(async_bool): assert len(step_returns) == 4 _, _, dones, _ = step_returns assert dones.dtype == np.bool_ + env.close() env = gym.vector.make( "CartPole-v1", num_envs=3, asynchronous=async_bool, return_two_dones=True @@ -87,3 +90,4 @@ def test_vector_step_compatibility_existing(async_bool): _, _, terminateds, truncateds, _ = step_returns assert terminateds.dtype == np.bool_ assert truncateds.dtype == np.bool_ + env.close() diff --git a/tests/vector/utils.py 
b/tests/vector/utils.py index e394fb59e43..d04a6586d0b 100644 --- a/tests/vector/utils.py +++ b/tests/vector/utils.py @@ -106,9 +106,10 @@ def step(self, action): return observation, reward, terminated, truncated, {} -def make_env(env_name, seed): +def make_env(env_name, seed, return_two_dones=True): + # return_two_dones=True, only for compatibility with vector tests, to be removed at v1.0 def _make(): - env = gym.make(env_name) + env = gym.make(env_name, return_two_dones=return_two_dones) env.reset(seed=seed) return env diff --git a/tests/wrappers/test_autoreset.py b/tests/wrappers/test_autoreset.py index 952921387bd..431a5550ada 100644 --- a/tests/wrappers/test_autoreset.py +++ b/tests/wrappers/test_autoreset.py @@ -54,7 +54,7 @@ def reset( def test_autoreset_reset_info(): - env = gym.make("CartPole-v1") + env = gym.make("CartPole-v1", return_two_dones=True) env = AutoResetWrapper(env) ob_space = env.observation_space obs = env.reset() @@ -77,7 +77,7 @@ def test_make_autoreset_true(spec): """ env = None with pytest.warns(None) as warnings: - env = spec.make(autoreset=True) + env = spec.make(autoreset=True, return_two_dones=True) ob_space = env.observation_space obs = env.reset(seed=0) @@ -98,7 +98,7 @@ def test_make_autoreset_true(spec): def test_make_autoreset_false(spec): env = None with pytest.warns(None) as warnings: - env = spec.make(autoreset=False) + env = spec.make(autoreset=False, return_two_dones=True) assert not isinstance(env, AutoResetWrapper) @@ -106,7 +106,7 @@ def test_make_autoreset_false(spec): def test_make_autoreset_default_false(spec): env = None with pytest.warns(None) as warnings: - env = spec.make() + env = spec.make(return_two_dones=True) assert not isinstance(env, AutoResetWrapper) diff --git a/tests/wrappers/test_clip_action.py b/tests/wrappers/test_clip_action.py index 8b4a4a4fd23..5d5a93506b5 100644 --- a/tests/wrappers/test_clip_action.py +++ b/tests/wrappers/test_clip_action.py @@ -6,7 +6,7 @@ def test_clip_action(): # mountaincar: action-based rewards - make_env = lambda: gym.make("MountainCarContinuous-v0") + make_env = lambda: gym.make("MountainCarContinuous-v0", return_two_dones=True) env = make_env() wrapped_env = ClipAction(make_env()) diff --git a/tests/wrappers/test_record_episode_statistics.py b/tests/wrappers/test_record_episode_statistics.py index 3bb9ec58611..54d31068e2c 100644 --- a/tests/wrappers/test_record_episode_statistics.py +++ b/tests/wrappers/test_record_episode_statistics.py @@ -8,7 +8,7 @@ @pytest.mark.parametrize("env_id", ["CartPole-v1", "Pendulum-v1"]) @pytest.mark.parametrize("deque_size", [2, 5]) def test_record_episode_statistics(env_id, deque_size): - env = gym.make(env_id) + env = gym.make(env_id, return_two_dones=True) env = RecordEpisodeStatistics(env, deque_size) for n in range(5): @@ -26,7 +26,7 @@ def test_record_episode_statistics(env_id, deque_size): def test_record_episode_statistics_reset_info(): - env = gym.make("CartPole-v1") + env = gym.make("CartPole-v1", return_two_dones=True) env = RecordEpisodeStatistics(env) ob_space = env.observation_space obs = env.reset() @@ -41,7 +41,12 @@ def test_record_episode_statistics_reset_info(): ("num_envs", "asynchronous"), [(1, False), (1, True), (4, False), (4, True)] ) def test_record_episode_statistics_with_vectorenv(num_envs, asynchronous): - envs = gym.vector.make("CartPole-v1", num_envs=num_envs, asynchronous=asynchronous) + envs = gym.vector.make( + "CartPole-v1", + num_envs=num_envs, + asynchronous=asynchronous, + return_two_dones=True, + ) envs = 
RecordEpisodeStatistics(envs) max_episode_step = ( envs.env_fns[0]().spec.max_episode_steps diff --git a/tests/wrappers/test_record_video.py b/tests/wrappers/test_record_video.py index 878b3af80ef..516323bc0bc 100644 --- a/tests/wrappers/test_record_video.py +++ b/tests/wrappers/test_record_video.py @@ -14,7 +14,7 @@ def test_record_video_using_default_trigger(): - env = gym.make("CartPole-v1") + env = gym.make("CartPole-v1", return_two_dones=True) env = gym.wrappers.RecordVideo(env, "videos") env.reset() for _ in range(199): @@ -32,7 +32,7 @@ def test_record_video_using_default_trigger(): def test_record_video_reset_return_info(): - env = gym.make("CartPole-v1") + env = gym.make("CartPole-v1", return_two_dones=True) env = gym.wrappers.RecordVideo(env, "videos", step_trigger=lambda x: x % 100 == 0) ob_space = env.observation_space obs, info = env.reset(return_info=True) @@ -42,7 +42,7 @@ def test_record_video_reset_return_info(): assert ob_space.contains(obs) assert isinstance(info, dict) - env = gym.make("CartPole-v1") + env = gym.make("CartPole-v1", return_two_dones=True) env = gym.wrappers.RecordVideo(env, "videos", step_trigger=lambda x: x % 100 == 0) ob_space = env.observation_space obs = env.reset(return_info=False) @@ -51,7 +51,7 @@ def test_record_video_reset_return_info(): shutil.rmtree("videos") assert ob_space.contains(obs) - env = gym.make("CartPole-v1") + env = gym.make("CartPole-v1", return_two_dones=True) env = gym.wrappers.RecordVideo(env, "videos", step_trigger=lambda x: x % 100 == 0) ob_space = env.observation_space obs = env.reset() @@ -62,7 +62,7 @@ def test_record_video_reset_return_info(): def test_record_video_step_trigger(): - env = gym.make("CartPole-v1") + env = gym.make("CartPole-v1", return_two_dones=True) env._max_episode_steps = 20 env = gym.wrappers.RecordVideo(env, "videos", step_trigger=lambda x: x % 100 == 0) env.reset() @@ -80,7 +80,7 @@ def test_record_video_step_trigger(): def make_env(gym_id, seed): def thunk(): - env = gym.make(gym_id) + env = gym.make(gym_id, return_two_dones=True) env._max_episode_steps = 20 if seed == 1: env = gym.wrappers.RecordVideo( diff --git a/tests/wrappers/test_rescale_action.py b/tests/wrappers/test_rescale_action.py index fc71929e718..abade7c9705 100644 --- a/tests/wrappers/test_rescale_action.py +++ b/tests/wrappers/test_rescale_action.py @@ -6,13 +6,13 @@ def test_rescale_action(): - env = gym.make("CartPole-v1") + env = gym.make("CartPole-v1", return_two_dones=True) with pytest.raises(AssertionError): env = RescaleAction(env, -1, 1) del env - env = gym.make("Pendulum-v1") - wrapped_env = RescaleAction(gym.make("Pendulum-v1"), -1, 1) + env = gym.make("Pendulum-v1", return_two_dones=True) + wrapped_env = RescaleAction(gym.make("Pendulum-v1", return_two_dones=True), -1, 1) seed = 0 diff --git a/tests/wrappers/test_step_compatibility.py b/tests/wrappers/test_step_compatibility.py index 22853064f5f..f1125f97aa8 100644 --- a/tests/wrappers/test_step_compatibility.py +++ b/tests/wrappers/test_step_compatibility.py @@ -33,26 +33,25 @@ def step(self, action): @pytest.mark.parametrize("env", [OldStepEnv, NewStepEnv]) -def test_step_compatibility_to_old_api(env): - env = StepCompatibility(env(), False) +def test_step_compatibility_to_new_api(env): + env = StepCompatibility(env(), True) step_returns = env.step(0) - assert len(step_returns) == 4 - _, _, done, _ = step_returns - assert isinstance(done, bool) + _, _, terminated, truncated, _ = step_returns + assert isinstance(terminated, bool) + assert isinstance(truncated, bool) 
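As a usage sketch of what these assertions pin down (assuming this branch of gym; `NewStepEnv` mirrors the fixture defined at the top of this test module):

```python
import gym
from gym.spaces import Discrete
from gym.wrappers import StepCompatibility


class NewStepEnv(gym.Env):
    # Mirror of the test fixture: an env already on the 5-tuple API.
    observation_space = Discrete(2)
    action_space = Discrete(2)

    def reset(self):
        return 0

    def step(self, action):
        return 0, 0.0, False, False, {}


# return_two_dones=True passes the new API through (or converts up to it)...
obs, rew, terminated, truncated, info = StepCompatibility(NewStepEnv(), True).step(0)
# ...while False folds the two booleans back into a single `done`.
obs, rew, done, info = StepCompatibility(NewStepEnv(), False).step(0)
```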
@pytest.mark.parametrize("env", [OldStepEnv, NewStepEnv]) -@pytest.mark.parametrize("return_two_dones", [None, True]) -def test_step_compatibility_to_new_api(env, return_two_dones): +@pytest.mark.parametrize("return_two_dones", [None, False]) +def test_step_compatibility_to_old_api(env, return_two_dones): if return_two_dones is None: - env = StepCompatibility(env()) # default behavior is to convert to new api + env = StepCompatibility(env()) # default behavior is to retain old API else: env = StepCompatibility(env(), return_two_dones) step_returns = env.step(0) - assert len(step_returns) == 5 - _, _, terminated, truncated, _ = step_returns - assert isinstance(terminated, bool) - assert isinstance(truncated, bool) + assert len(step_returns) == 4 + _, _, done, _ = step_returns + assert isinstance(done, bool) @pytest.mark.parametrize("return_two_dones", [None, True, False]) @@ -64,7 +63,7 @@ def test_step_compatibility_in_make(return_two_dones): env.reset() step_returns = env.step(0) - if return_two_dones == True or return_two_dones == None: # new api + if return_two_dones == True: # new api assert len(step_returns) == 5 _, _, terminated, truncated, _ = step_returns assert isinstance(terminated, bool) diff --git a/tests/wrappers/test_time_aware_observation.py b/tests/wrappers/test_time_aware_observation.py index 80bd5c3d69e..bdf803346e1 100644 --- a/tests/wrappers/test_time_aware_observation.py +++ b/tests/wrappers/test_time_aware_observation.py @@ -6,7 +6,7 @@ @pytest.mark.parametrize("env_id", ["CartPole-v1", "Pendulum-v1"]) def test_time_aware_observation(env_id): - env = gym.make(env_id) + env = gym.make(env_id, return_two_dones=True) wrapped_env = TimeAwareObservation(env) assert wrapped_env.observation_space.shape[0] == env.observation_space.shape[0] + 1 diff --git a/tests/wrappers/test_transform_observation.py b/tests/wrappers/test_transform_observation.py index 098578677a9..c2c4ede3a2e 100644 --- a/tests/wrappers/test_transform_observation.py +++ b/tests/wrappers/test_transform_observation.py @@ -8,9 +8,9 @@ @pytest.mark.parametrize("env_id", ["CartPole-v1", "Pendulum-v1"]) def test_transform_observation(env_id): affine_transform = lambda x: 3 * x + 2 - env = gym.make(env_id) + env = gym.make(env_id, return_two_dones=True) wrapped_env = TransformObservation( - gym.make(env_id), lambda obs: affine_transform(obs) + gym.make(env_id, return_two_dones=True), lambda obs: affine_transform(obs) ) obs = env.reset(seed=0) diff --git a/tests/wrappers/test_transform_reward.py b/tests/wrappers/test_transform_reward.py index 74e061025dc..0e8cb32bdd4 100644 --- a/tests/wrappers/test_transform_reward.py +++ b/tests/wrappers/test_transform_reward.py @@ -10,8 +10,10 @@ def test_transform_reward(env_id): # use case #1: scale scales = [0.1, 200] for scale in scales: - env = gym.make(env_id) - wrapped_env = TransformReward(gym.make(env_id), lambda r: scale * r) + env = gym.make(env_id, return_two_dones=True) + wrapped_env = TransformReward( + gym.make(env_id, return_two_dones=True), lambda r: scale * r + ) action = env.action_space.sample() env.reset(seed=0) @@ -26,8 +28,10 @@ def test_transform_reward(env_id): # use case #2: clip min_r = -0.0005 max_r = 0.0002 - env = gym.make(env_id) - wrapped_env = TransformReward(gym.make(env_id), lambda r: np.clip(r, min_r, max_r)) + env = gym.make(env_id, return_two_dones=True) + wrapped_env = TransformReward( + gym.make(env_id, return_two_dones=True), lambda r: np.clip(r, min_r, max_r) + ) action = env.action_space.sample() env.reset(seed=0) @@ -41,8 +45,10 @@ def 
test_transform_reward(env_id): del env, wrapped_env # use case #3: sign - env = gym.make(env_id) - wrapped_env = TransformReward(gym.make(env_id), lambda r: np.sign(r)) + env = gym.make(env_id, return_two_dones=True) + wrapped_env = TransformReward( + gym.make(env_id, return_two_dones=True), lambda r: np.sign(r) + ) env.reset(seed=0) wrapped_env.reset(seed=0) From 2aabc30df6065984dfb42631e10daaf2bc75e0b2 Mon Sep 17 00:00:00 2001 From: arjun_kg Date: Thu, 21 Apr 2022 00:17:14 +0530 Subject: [PATCH 03/37] update warnings --- gym/vector/step_compatibility_vector.py | 2 +- gym/wrappers/step_compatibility.py | 10 +++++----- tests/wrappers/test_step_compatibility.py | 7 ++++++- 3 files changed, 12 insertions(+), 7 deletions(-) diff --git a/gym/vector/step_compatibility_vector.py b/gym/vector/step_compatibility_vector.py index d0e23b60431..b34143f466e 100644 --- a/gym/vector/step_compatibility_vector.py +++ b/gym/vector/step_compatibility_vector.py @@ -20,7 +20,7 @@ def step_wait(self): def _step_returns_new_to_old(self, step_returns): assert len(step_returns) == 5 observations, rewards, terminateds, truncateds, infos = step_returns - logger.warn( + logger.deprecation( "Using a vector wrapper to transform new step API (which returns two bool vectors terminateds, truncateds) into old (returns one bool vector dones). " "This wrapper will be removed in the future. " "It is recommended to upgrade your accompanying code instead to be compatible with the new API, and use the new API. " diff --git a/gym/wrappers/step_compatibility.py b/gym/wrappers/step_compatibility.py index 3992af863a6..a6c7f8fe1ee 100644 --- a/gym/wrappers/step_compatibility.py +++ b/gym/wrappers/step_compatibility.py @@ -7,7 +7,7 @@ def __init__(self, env, return_two_dones=False): super().__init__(env) self._return_two_dones = return_two_dones if not self._return_two_dones: - logger.warn( + logger.deprecation( "Initializing environment in old step API which returns one bool instead of two. " "Note that vector API and most wrappers would not work as these have been upgraded to the new API. " "To use these features, please set `return_two_dones=True` in make to use new API (see docs for more details)." @@ -17,7 +17,7 @@ def step(self, action): step_returns = self.env.step(action) if self._return_two_dones: if len(step_returns) == 5: - logger.warn( + logger.deprecation( "Using an environment with new step API that returns two bools terminated, truncated instead of one bool done. " "Take care to update supporting code to be compatible with this API" ) @@ -26,7 +26,7 @@ def step(self, action): return self._step_returns_old_to_new(step_returns) else: if len(step_returns) == 4: - logger.warn( + logger.deprecation( "Core environment uses old step API which returns one boolean (done). Please upgrade to new API to return two booleans - terminated, truncated" ) @@ -36,7 +36,7 @@ def step(self, action): def _step_returns_old_to_new(self, step_returns): assert len(step_returns) == 4 - logger.warn( + logger.deprecation( "Using a wrapper to transform env with old step API into new. This wrapper will be removed in the future. " "It is recommended to upgrade the core env to the new step API." 
"If 'TimeLimit.truncated' is set at truncation, terminated and truncated values will be accurate" @@ -61,7 +61,7 @@ def _step_returns_old_to_new(self, step_returns): def _step_returns_new_to_old(self, step_returns): assert len(step_returns) == 5 - logger.warn( + logger.deprecation( "Using a wrapper to transform new step API (which returns two booleans terminated, truncated) into old (returns one boolean done). " "This wrapper will be removed in the future. " "It is recommended to upgrade your accompanying code instead to be compatible with the new API, and use the new API. " diff --git a/tests/wrappers/test_step_compatibility.py b/tests/wrappers/test_step_compatibility.py index f1125f97aa8..213c7ed2d3c 100644 --- a/tests/wrappers/test_step_compatibility.py +++ b/tests/wrappers/test_step_compatibility.py @@ -1,3 +1,5 @@ +import warnings + import pytest import gym @@ -57,7 +59,10 @@ def test_step_compatibility_to_old_api(env, return_two_dones): @pytest.mark.parametrize("return_two_dones", [None, True, False]) def test_step_compatibility_in_make(return_two_dones): if return_two_dones is None: - env = gym.make("CartPole-v1") # check default behavior + with pytest.warns( + DeprecationWarning, match="Initializing environment in old step API" + ): + env = gym.make("CartPole-v1") else: env = gym.make("CartPole-v1", return_two_dones=return_two_dones) From 1babe4e18ea73205d71b5b615c8689cbb44efea4 Mon Sep 17 00:00:00 2001 From: arjun_kg Date: Thu, 21 Apr 2022 15:36:29 +0530 Subject: [PATCH 04/37] pytest - ignore deprecation warnings --- pyproject.toml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/pyproject.toml b/pyproject.toml index bbeae1ed72b..184536f7f22 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -18,3 +18,6 @@ strict = ["gym/version.py", "gym/logger.py"] reportMissingImports = true reportMissingTypeStubs = false verboseOutput = true + +[tool.pytest.ini_options] +filterwarnings = ["ignore::DeprecationWarning"] \ No newline at end of file From c9c6add6b0d330625e86a98f6fff02d81276444b Mon Sep 17 00:00:00 2001 From: arjun_kg Date: Thu, 21 Apr 2022 20:48:31 +0530 Subject: [PATCH 05/37] Only ignore step api deprecation warnings --- gym/vector/step_compatibility_vector.py | 2 +- gym/wrappers/step_compatibility.py | 10 +++++----- pyproject.toml | 2 +- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/gym/vector/step_compatibility_vector.py b/gym/vector/step_compatibility_vector.py index b34143f466e..f277fc4f43f 100644 --- a/gym/vector/step_compatibility_vector.py +++ b/gym/vector/step_compatibility_vector.py @@ -21,7 +21,7 @@ def _step_returns_new_to_old(self, step_returns): assert len(step_returns) == 5 observations, rewards, terminateds, truncateds, infos = step_returns logger.deprecation( - "Using a vector wrapper to transform new step API (which returns two bool vectors terminateds, truncateds) into old (returns one bool vector dones). " + "[StepAPI] Using a vector wrapper to transform new step API (which returns two bool vectors terminateds, truncateds) into old (returns one bool vector dones). " "This wrapper will be removed in the future. " "It is recommended to upgrade your accompanying code instead to be compatible with the new API, and use the new API. 
" ) diff --git a/gym/wrappers/step_compatibility.py b/gym/wrappers/step_compatibility.py index a6c7f8fe1ee..099d32f0e17 100644 --- a/gym/wrappers/step_compatibility.py +++ b/gym/wrappers/step_compatibility.py @@ -8,7 +8,7 @@ def __init__(self, env, return_two_dones=False): self._return_two_dones = return_two_dones if not self._return_two_dones: logger.deprecation( - "Initializing environment in old step API which returns one bool instead of two. " + "[StepAPI] Initializing environment in old step API which returns one bool instead of two. " "Note that vector API and most wrappers would not work as these have been upgraded to the new API. " "To use these features, please set `return_two_dones=True` in make to use new API (see docs for more details)." ) @@ -18,7 +18,7 @@ def step(self, action): if self._return_two_dones: if len(step_returns) == 5: logger.deprecation( - "Using an environment with new step API that returns two bools terminated, truncated instead of one bool done. " + "[StepAPI] Using an environment with new step API that returns two bools terminated, truncated instead of one bool done. " "Take care to update supporting code to be compatible with this API" ) return step_returns @@ -27,7 +27,7 @@ def step(self, action): else: if len(step_returns) == 4: logger.deprecation( - "Core environment uses old step API which returns one boolean (done). Please upgrade to new API to return two booleans - terminated, truncated" + "[StepAPI] Core environment uses old step API which returns one boolean (done). Please upgrade to new API to return two booleans - terminated, truncated" ) return step_returns @@ -37,7 +37,7 @@ def step(self, action): def _step_returns_old_to_new(self, step_returns): assert len(step_returns) == 4 logger.deprecation( - "Using a wrapper to transform env with old step API into new. This wrapper will be removed in the future. " + "[StepAPI] Using a wrapper to transform env with old step API into new. This wrapper will be removed in the future. " "It is recommended to upgrade the core env to the new step API." "If 'TimeLimit.truncated' is set at truncation, terminated and truncated values will be accurate" "Otherwise, `terminated=done` and `truncated=False`" @@ -62,7 +62,7 @@ def _step_returns_old_to_new(self, step_returns): def _step_returns_new_to_old(self, step_returns): assert len(step_returns) == 5 logger.deprecation( - "Using a wrapper to transform new step API (which returns two booleans terminated, truncated) into old (returns one boolean done). " + "[StepAPI] Using a wrapper to transform new step API (which returns two booleans terminated, truncated) into old (returns one boolean done). " "This wrapper will be removed in the future. " "It is recommended to upgrade your accompanying code instead to be compatible with the new API, and use the new API. 
" ) diff --git a/pyproject.toml b/pyproject.toml index 184536f7f22..4a932770056 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -20,4 +20,4 @@ reportMissingTypeStubs = false verboseOutput = true [tool.pytest.ini_options] -filterwarnings = ["ignore::DeprecationWarning"] \ No newline at end of file +filterwarnings = ['ignore:.*\[StepAPI\].*:DeprecationWarning'] # to be removed at 1.0 when old step API is removed From c5fe53cb9b51f04a6e1e5d2e2137ab1d8e014590 Mon Sep 17 00:00:00 2001 From: arjun_kg Date: Fri, 22 Apr 2022 00:44:11 +0530 Subject: [PATCH 06/37] fix duplicate wrapping bug in vector envs --- gym/vector/__init__.py | 3 +-- gym/wrappers/step_compatibility.py | 2 +- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/gym/vector/__init__.py b/gym/vector/__init__.py index ed37aeb3cc4..308d44d34ea 100644 --- a/gym/vector/__init__.py +++ b/gym/vector/__init__.py @@ -52,8 +52,7 @@ def make( from gym.envs import make as make_ def _make_env(): - env = make_(id, **kwargs) - env = StepCompatibility(env, return_two_dones=True) + env = make_(id, return_two_dones=True, **kwargs) if wrappers is not None: if callable(wrappers): env = wrappers(env) diff --git a/gym/wrappers/step_compatibility.py b/gym/wrappers/step_compatibility.py index 099d32f0e17..4a1f88a5ad0 100644 --- a/gym/wrappers/step_compatibility.py +++ b/gym/wrappers/step_compatibility.py @@ -39,7 +39,7 @@ def _step_returns_old_to_new(self, step_returns): logger.deprecation( "[StepAPI] Using a wrapper to transform env with old step API into new. This wrapper will be removed in the future. " "It is recommended to upgrade the core env to the new step API." - "If 'TimeLimit.truncated' is set at truncation, terminated and truncated values will be accurate" + "If 'TimeLimit.truncated' is set at truncation, terminated and truncated values will be accurate. " "Otherwise, `terminated=done` and `truncated=False`" ) From 6af7182a3f461554604799cf1ca367b3da2c0bf0 Mon Sep 17 00:00:00 2001 From: arjun_kg Date: Mon, 25 Apr 2022 15:26:08 +0530 Subject: [PATCH 07/37] edit docstrings, comments, warnings --- gym/core.py | 5 ++-- gym/vector/step_compatibility_vector.py | 19 ++++++++++++-- gym/wrappers/step_compatibility.py | 34 ++++++++++++++++++++----- pyproject.toml | 2 +- 4 files changed, 47 insertions(+), 13 deletions(-) diff --git a/gym/core.py b/gym/core.py index fdc6fe705ed..90e8723876e 100644 --- a/gym/core.py +++ b/gym/core.py @@ -81,13 +81,12 @@ def step( Returns: observation (object): agent's observation of the current environment. This will be an element of the environment's :attr:`observation_space`. This may, for instance, be a numpy array containing the positions and velocities of certain objects. reward (float) : amount of reward returned after previous action - terminated (bool): whether the episode has ended due to a termination, in which case further step() calls will return undefined results - truncated (bool): whether the episode has ended due to a truncation, in which case further step() calls will return undefined results + terminated (bool): whether the episode has ended due to reaching a terminal state intrinsic to the core environment, in which case further step() calls will return undefined results + truncated (bool): whether the episode has ended due to a truncation, i.e., a timelimit outside the scope of the problem defined in the environment. info (dict): contains auxiliary diagnostic information (helpful for debugging, learning, and logging). 
This might, for instance, contain: - metrics that describe the agent's performance or - state variables that are hidden from observations or - - information that distinguishes truncation and termination or - individual reward terms that are combined to produce the total reward (deprecated) diff --git a/gym/vector/step_compatibility_vector.py b/gym/vector/step_compatibility_vector.py index f277fc4f43f..c8fba8644c2 100644 --- a/gym/vector/step_compatibility_vector.py +++ b/gym/vector/step_compatibility_vector.py @@ -6,6 +6,21 @@ class StepCompatibilityVector(VectorEnvWrapper): + r"""A wrapper which can transform a vector environment to a new or old step API. + + Old step API refers to step() method returning (observation, reward, done, info) + New step API refers to step() method returning (observation, reward, terminated, truncated, info) + (Refer to docs for details on the API change) + + This wrapper is to be used to ease transition to new API. It will be removed in v1.0 + + Parameters + ---------- + env (gym.vector.VectorEnv): the vector env to wrap. Has to be in new step API + return_two_dones (bool): True to use vector env with new step API, False to use vector env with old step API. (True by default) + + """ + def __init__(self, env, return_two_dones=True): super().__init__(env) self._return_two_dones = return_two_dones @@ -21,8 +36,8 @@ def _step_returns_new_to_old(self, step_returns): assert len(step_returns) == 5 observations, rewards, terminateds, truncateds, infos = step_returns logger.deprecation( - "[StepAPI] Using a vector wrapper to transform new step API (which returns two bool vectors terminateds, truncateds) into old (returns one bool vector dones). " - "This wrapper will be removed in the future. " + "Using a vector wrapper to transform new step API (which returns two bool vectors terminateds, truncateds) into old (returns one bool vector dones). " + "This wrapper will be removed in v1.0. " "It is recommended to upgrade your accompanying code instead to be compatible with the new API, and use the new API. " ) dones = [] diff --git a/gym/wrappers/step_compatibility.py b/gym/wrappers/step_compatibility.py index 4a1f88a5ad0..7099b866510 100644 --- a/gym/wrappers/step_compatibility.py +++ b/gym/wrappers/step_compatibility.py @@ -3,12 +3,28 @@ class StepCompatibility(gym.Wrapper): + r"""A wrapper which can transform an environment from new step API to old and vice-versa. + + Old step API refers to step() method returning (observation, reward, done, info) + New step API refers to step() method returning (observation, reward, terminated, truncated, info) + (Refer to docs for details on the API change) + + This wrapper is to be used to ease transition to new API and for backward compatibility. It will be removed in v1.0 + + + Parameters + ---------- + env (gym.Env): the env to wrap. Can be in old or new API + return_two_dones (bool): True to use env with new step API, False to use env with old step API. (False by default) + + """ + def __init__(self, env, return_two_dones=False): super().__init__(env) self._return_two_dones = return_two_dones if not self._return_two_dones: logger.deprecation( - "[StepAPI] Initializing environment in old step API which returns one bool instead of two. " + "Initializing environment in old step API which returns one bool instead of two. " "Note that vector API and most wrappers would not work as these have been upgraded to the new API. " "To use these features, please set `return_two_dones=True` in make to use new API (see docs for more details)." 
)
@@ -18,8 +34,8 @@ def step(self, action):
         if self._return_two_dones:
             if len(step_returns) == 5:
                 logger.deprecation(
-                    "[StepAPI] Using an environment with new step API that returns two bools terminated, truncated instead of one bool done. "
-                    "Take care to update supporting code to be compatible with this API"
+                    "Using an environment with new step API that returns two bools terminated, truncated instead of one bool done. "
+                    "Take care to update supporting code to be compatible with this API"
                 )
                 return step_returns
             else:
@@ -27,7 +43,7 @@ def step(self, action):
         else:
             if len(step_returns) == 4:
                 logger.deprecation(
-                    "[StepAPI] Core environment uses old step API which returns one boolean (done). Please upgrade to new API to return two booleans - terminated, truncated"
+                    "Core environment uses old step API which returns one boolean (done). Please upgrade to new API to return two booleans - terminated, truncated"
                 )
 
                 return step_returns
@@ -35,9 +51,11 @@ def step(self, action):
             return self._step_returns_new_to_old(step_returns)
 
     def _step_returns_old_to_new(self, step_returns):
+        # Method to transform old step API to new
+
         assert len(step_returns) == 4
         logger.deprecation(
-            "[StepAPI] Using a wrapper to transform env with old step API into new. This wrapper will be removed in the future. "
+            "Using a wrapper to transform env with old step API into new. This wrapper will be removed in v1.0. "
             "It is recommended to upgrade the core env to the new step API."
             "If 'TimeLimit.truncated' is set at truncation, terminated and truncated values will be accurate. "
             "Otherwise, `terminated=done` and `truncated=False`"
         )
@@ -60,10 +78,12 @@ def _step_returns_old_to_new(self, step_returns):
         return obs, rew, terminated, truncated, info
 
     def _step_returns_new_to_old(self, step_returns):
+        # Method to transform new step API to old
+
         assert len(step_returns) == 5
         logger.deprecation(
-            "[StepAPI] Using a wrapper to transform new step API (which returns two booleans terminated, truncated) into old (returns one boolean done). "
-            "This wrapper will be removed in the future. "
+            "Using a wrapper to transform new step API (which returns two booleans terminated, truncated) into old (returns one boolean done). "
+            "This wrapper will be removed in v1.0. "
             "It is recommended to upgrade your accompanying code instead to be compatible with the new API, and use the new API. 
" ) diff --git a/pyproject.toml b/pyproject.toml index 4a932770056..b82e848aa1e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -20,4 +20,4 @@ reportMissingTypeStubs = false verboseOutput = true [tool.pytest.ini_options] -filterwarnings = ['ignore:.*\[StepAPI\].*:DeprecationWarning'] # to be removed at 1.0 when old step API is removed +filterwarnings = ['ignore:.*step API.*:DeprecationWarning'] # to be removed at 1.0 when old step API is removed From 68ef969c3d79ee3bb79bcb3d49fe620a5e043a98 Mon Sep 17 00:00:00 2001 From: arjun_kg Date: Wed, 4 May 2022 13:17:43 +0530 Subject: [PATCH 08/37] step compatibility for wrappers, vectors --- gym/core.py | 5 +- gym/envs/registration.py | 8 +- gym/vector/__init__.py | 10 +- gym/vector/async_vector_env.py | 11 +- gym/vector/step_compatibility_vector.py | 82 +++++------- gym/vector/sync_vector_env.py | 5 +- gym/vector/vector_env.py | 4 + gym/wrappers/__init__.py | 2 +- gym/wrappers/autoreset.py | 10 +- gym/wrappers/frame_stack.py | 7 +- gym/wrappers/normalize.py | 12 +- gym/wrappers/record_episode_statistics.py | 12 +- gym/wrappers/record_video.py | 14 +- gym/wrappers/step_compatibility.py | 148 +++++++++++++++------- gym/wrappers/time_aware_observation.py | 3 + gym/wrappers/time_limit.py | 2 +- 16 files changed, 217 insertions(+), 118 deletions(-) diff --git a/gym/core.py b/gym/core.py index 83d007ff969..097e628cabf 100644 --- a/gym/core.py +++ b/gym/core.py @@ -303,7 +303,7 @@ def step( ) -> Union[ Tuple[ObsType, float, bool, bool, dict], Tuple[ObsType, float, bool, dict] ]: - return self.env.step(action) + return self._get_env_step_returns(action) def reset(self, **kwargs) -> Union[ObsType, tuple[ObsType, dict]]: return self.env.reset(**kwargs) @@ -327,6 +327,9 @@ def __repr__(self): def unwrapped(self) -> Env: return self.env.unwrapped + def _get_env_step_returns(self, action): + return self.env.step(action) + class ObservationWrapper(Wrapper): def reset(self, **kwargs): diff --git a/gym/envs/registration.py b/gym/envs/registration.py index 864b22944bf..46931ffab4d 100644 --- a/gym/envs/registration.py +++ b/gym/envs/registration.py @@ -97,7 +97,7 @@ class EnvSpec: max_episode_steps: Optional[int] = field(default=None) order_enforce: bool = field(default=True) autoreset: bool = field(default=False) - return_two_dones: bool = field(default=False) + new_step_api: bool = field(default=False) kwargs: dict = field(default_factory=dict) namespace: Optional[str] = field(init=False) @@ -429,7 +429,7 @@ def make( id: str | EnvSpec, max_episode_steps: Optional[int] = None, autoreset: bool = False, - return_two_dones: bool = False, + new_step_api: bool = False, **kwargs, ) -> Env: """ @@ -439,7 +439,7 @@ def make( id: Name of the environment. max_episode_steps: Maximum length of an episode (TimeLimit wrapper). autoreset: Whether to automatically reset the environment after each episode (AutoResetWrapper). - return_two_dones: Whether to use old or new step API (StepCompatibility wrapper). Will be removed at v1.0 + new_step_api: Whether to use old or new step API (StepCompatibility wrapper). Will be removed at v1.0 kwargs: Additional arguments to pass to the environment constructor. Returns: An instance of the environment. 
@@ -495,7 +495,7 @@ def make( if spec_.order_enforce: env = OrderEnforcing(env) - env = StepCompatibility(env, return_two_dones) + env = StepCompatibility(env, new_step_api) if max_episode_steps is not None: env = TimeLimit(env, max_episode_steps) diff --git a/gym/vector/__init__.py b/gym/vector/__init__.py index 308d44d34ea..b04319ffa58 100644 --- a/gym/vector/__init__.py +++ b/gym/vector/__init__.py @@ -4,16 +4,14 @@ Iterable = (tuple, list) from gym.vector.async_vector_env import AsyncVectorEnv -from gym.vector.step_compatibility_vector import StepCompatibilityVector from gym.vector.sync_vector_env import SyncVectorEnv from gym.vector.vector_env import VectorEnv, VectorEnvWrapper -from gym.wrappers import StepCompatibility __all__ = ["AsyncVectorEnv", "SyncVectorEnv", "VectorEnv", "VectorEnvWrapper", "make"] def make( - id, num_envs=1, asynchronous=True, wrappers=None, return_two_dones=True, **kwargs + id, num_envs=1, asynchronous=True, wrappers=None, new_step_api=False, **kwargs ): """Create a vectorized environment from multiple copies of an environment, from its id. @@ -52,7 +50,7 @@ def make( from gym.envs import make as make_ def _make_env(): - env = make_(id, return_two_dones=True, **kwargs) + env = make_(id, new_step_api=True, **kwargs) if wrappers is not None: if callable(wrappers): env = wrappers(env) @@ -67,7 +65,7 @@ def _make_env(): env_fns = [_make_env for _ in range(num_envs)] return ( - StepCompatibilityVector(AsyncVectorEnv(env_fns), return_two_dones) + AsyncVectorEnv(env_fns, new_step_api=new_step_api) if asynchronous - else StepCompatibilityVector(SyncVectorEnv(env_fns), return_two_dones) + else SyncVectorEnv(env_fns, new_step_api=new_step_api) ) diff --git a/gym/vector/async_vector_env.py b/gym/vector/async_vector_env.py index 6af3b357b9b..f097ba0e5c4 100644 --- a/gym/vector/async_vector_env.py +++ b/gym/vector/async_vector_env.py @@ -14,6 +14,7 @@ CustomSpaceError, NoAsyncCallError, ) +from gym.vector.step_compatibility_vector import step_api_vector_compatibility from gym.vector.utils import ( CloudpickleWrapper, clear_mpi_env_vars, @@ -25,6 +26,7 @@ write_to_shared_memory, ) from gym.vector.vector_env import VectorEnv +from gym.wrappers.step_compatibility import step_to_new_api __all__ = ["AsyncVectorEnv"] @@ -36,6 +38,7 @@ class AsyncState(Enum): WAITING_CALL = "call" +@step_api_vector_compatibility class AsyncVectorEnv(VectorEnv): """Vectorized environment that runs multiple environments in parallel. It uses `multiprocessing`_ processes, and pipes for communication. 
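Taken together, the two entry points now look like the following sketch (assuming this branch of gym; defaults as set in the hunks above):

```python
import gym

# Single env: old 4-tuple API by default, new 5-tuple API on request
env = gym.make("CartPole-v1")
obs = env.reset(seed=0)
obs, reward, done, info = env.step(env.action_space.sample())

env = gym.make("CartPole-v1", new_step_api=True)
obs = env.reset(seed=0)
obs, reward, terminated, truncated, info = env.step(env.action_space.sample())

# Vector envs: same flag, boolean vectors instead of scalars
envs = gym.vector.make("CartPole-v1", num_envs=3)
envs.reset(seed=0)
obs, rewards, dones, infos = envs.step(envs.action_space.sample())

envs = gym.vector.make("CartPole-v1", num_envs=3, new_step_api=True)
envs.reset(seed=0)
obs, rewards, terminateds, truncateds, infos = envs.step(envs.action_space.sample())
```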
@@ -650,7 +653,9 @@ def _worker(index, env_fn, pipe, parent_pipe, shared_memory, error_queue): pipe.send((observation, True)) elif command == "step": - observation, reward, terminated, truncated, info = env.step(data) + observation, reward, terminated, truncated, info = step_to_new_api( + env.step(data) + ) if terminated or truncated: info["closing_observation"] = observation observation = env.reset() @@ -719,7 +724,9 @@ def _worker_shared_memory(index, env_fn, pipe, parent_pipe, shared_memory, error ) pipe.send((None, True)) elif command == "step": - observation, reward, terminated, truncated, info = env.step(data) + observation, reward, terminated, truncated, info = step_to_new_api( + env.step(data) + ) if terminated or truncated: info["closing_observation"] = observation observation = env.reset() diff --git a/gym/vector/step_compatibility_vector.py b/gym/vector/step_compatibility_vector.py index c8fba8644c2..01ccfbcd41f 100644 --- a/gym/vector/step_compatibility_vector.py +++ b/gym/vector/step_compatibility_vector.py @@ -1,51 +1,37 @@ -import numpy as np - -import gym -from gym import logger from gym.vector.vector_env import VectorEnvWrapper +from gym.wrappers.step_compatibility import step_to_new_api, step_to_old_api + + +def step_api_vector_compatibility(VectorEnvClass): + class StepCompatibilityVector(VectorEnvWrapper): + r"""A wrapper which can transform a vector environment to a new or old step API. + + Old step API refers to step() method returning (observation, reward, done, info) + New step API refers to step() method returning (observation, reward, terminated, truncated, info) + (Refer to docs for details on the API change) + + This wrapper is to be used to ease transition to new API. It will be removed in v1.0 + + Parameters + ---------- + env (gym.vector.VectorEnv): the vector env to wrap. Has to be in new step API + new_step_api (bool): True to use vector env with new step API, False to use vector env with old step API. (True by default) + + """ + + def __init__(self, *args, **kwargs): + self.new_step_api = kwargs.get("new_step_api", False) + kwargs.pop("new_step_api", None) + super().__init__(VectorEnvClass(*args, **kwargs)) + + def step_wait(self): + step_returns = self.env.step_wait() + if self.new_step_api: + return step_to_new_api(step_returns) + else: + return step_to_old_api(step_returns) + def __del__(self): + self.env.__del__() -class StepCompatibilityVector(VectorEnvWrapper): - r"""A wrapper which can transform a vector environment to a new or old step API. - - Old step API refers to step() method returning (observation, reward, done, info) - New step API refers to step() method returning (observation, reward, terminated, truncated, info) - (Refer to docs for details on the API change) - - This wrapper is to be used to ease transition to new API. It will be removed in v1.0 - - Parameters - ---------- - env (gym.vector.VectorEnv): the vector env to wrap. Has to be in new step API - return_two_dones (bool): True to use vector env with new step API, False to use vector env with old step API. 
(False by default)
 
        """
 
        def __init__(self, *args, **kwargs):
            self.new_step_api = kwargs.get("new_step_api", False)
            kwargs.pop("new_step_api", None)
            super().__init__(VectorEnvClass(*args, **kwargs))
 
        def step_wait(self):
            step_returns = self.env.step_wait()
            if self.new_step_api:
                return step_to_new_api(step_returns)
            else:
                return step_to_old_api(step_returns)
 
        def __del__(self):
            self.env.__del__()
 
    return StepCompatibilityVector
diff --git a/gym/vector/sync_vector_env.py
index ebe25c511b4..78652b478c1 100644
--- a/gym/vector/sync_vector_env.py
+++ b/gym/vector/sync_vector_env.py
@@ -3,12 +3,15 @@
 
 import numpy as np
 
+from gym.vector.step_compatibility_vector import step_api_vector_compatibility
 from gym.vector.utils import concatenate, create_empty_array, iterate
 from gym.vector.vector_env import VectorEnv
+from gym.wrappers.step_compatibility import step_to_new_api
 
 __all__ = ["SyncVectorEnv"]
 
 
+@step_api_vector_compatibility
 class SyncVectorEnv(VectorEnv):
     """Vectorized environment that serially runs multiple environments.
 
@@ -141,7 +144,7 @@ def step_wait(self):
                     self._terminateds[i],
                     self._truncateds[i],
                     info,
-                ) = env.step(action)
+                ) = step_to_new_api(env.step(action))
                 if self._terminateds[i] or self._truncateds[i]:
                     info["closing_observation"] = observation
                     observation = env.reset()
diff --git a/gym/vector/vector_env.py
index 2b6a7e8ddfd..78004e8f5dd 100644
--- a/gym/vector/vector_env.py
+++ b/gym/vector/vector_env.py
@@ -224,6 +224,10 @@ def __repr__(self):
         else:
             return f"{self.__class__.__name__}({self.spec.id}, {self.num_envs})"
 
+    @staticmethod
+    def get_env_step_return(env, action):
+        return env.step(action)
+
 
 class VectorEnvWrapper(VectorEnv):
     r"""Wraps the vectorized environment to allow a modular transformation.
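The decorator above replaces the vector-env class with a wrapper that constructs it; the same idea in a generic, self-contained sketch (illustrative only, not gym code):

```python
import numpy as np


def with_step_api_flag(vector_env_cls):
    # Intercept a `new_step_api` kwarg, build the real env, and convert the
    # step output on the way out -- the shape used by the decorator above.
    class Converted:
        def __init__(self, *args, **kwargs):
            self.new_step_api = kwargs.pop("new_step_api", False)
            self.env = vector_env_cls(*args, **kwargs)

        def step(self, action):
            obs, rews, terminateds, truncateds, infos = self.env.step(action)
            if self.new_step_api:
                return obs, rews, terminateds, truncateds, infos
            return obs, rews, np.logical_or(terminateds, truncateds), infos

        def __getattr__(self, name):
            return getattr(self.env, name)  # delegate everything else

    return Converted
```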
diff --git a/gym/wrappers/__init__.py b/gym/wrappers/__init__.py index d0e70218941..549af6e2792 100644 --- a/gym/wrappers/__init__.py +++ b/gym/wrappers/__init__.py @@ -11,7 +11,7 @@ from gym.wrappers.record_video import RecordVideo, capped_cubic_video_schedule from gym.wrappers.rescale_action import RescaleAction from gym.wrappers.resize_observation import ResizeObservation -from gym.wrappers.step_compatibility import StepCompatibility +from gym.wrappers.step_compatibility import StepCompatibility, step_api_compatibility from gym.wrappers.time_aware_observation import TimeAwareObservation from gym.wrappers.time_limit import TimeLimit from gym.wrappers.transform_observation import TransformObservation diff --git a/gym/wrappers/autoreset.py b/gym/wrappers/autoreset.py index 4873872cbe0..ef829b8f41a 100644 --- a/gym/wrappers/autoreset.py +++ b/gym/wrappers/autoreset.py @@ -1,6 +1,8 @@ import gym +from gym.wrappers.step_compatibility import step_api_compatibility +@step_api_compatibility class AutoResetWrapper(gym.Wrapper): """ A class for providing an automatic reset functionality @@ -38,8 +40,14 @@ class AutoResetWrapper(gym.Wrapper): use this wrapper! """ + new_step_api = True # whether this wrapper is written in new API (assumed old API if not present) + + def __init__(self, env: gym.Env) -> None: + super().__init__(env) + self.new_step_api = True + def step(self, action): - obs, reward, terminated, truncated, info = self.env.step(action) + obs, reward, terminated, truncated, info = self._get_env_step_returns(action) if terminated or truncated: diff --git a/gym/wrappers/frame_stack.py b/gym/wrappers/frame_stack.py index 8a9c36924fd..5281dd50863 100644 --- a/gym/wrappers/frame_stack.py +++ b/gym/wrappers/frame_stack.py @@ -4,6 +4,7 @@ from gym import ObservationWrapper from gym.spaces import Box +from gym.wrappers import step_api_compatibility class LazyFrames: @@ -62,6 +63,7 @@ def _check_decompress(self, frame): return frame +@step_api_compatibility class FrameStack(ObservationWrapper): r"""Observation wrapper that stacks the observations in a rolling manner. @@ -93,6 +95,7 @@ class FrameStack(ObservationWrapper): lz4_compress (bool): use lz4 to compress the frames internally """ + new_step_api = True def __init__(self, env, num_stack, lz4_compress=False): super().__init__(env) @@ -114,7 +117,9 @@ def observation(self): return LazyFrames(list(self.frames), self.lz4_compress) def step(self, action): - observation, reward, terminated, truncated, info = self.env.step(action) + observation, reward, terminated, truncated, info = self._get_env_step_returns( + action + ) self.frames.append(observation) return self.observation(), reward, terminated, truncated, info diff --git a/gym/wrappers/normalize.py b/gym/wrappers/normalize.py index 372f6654d54..2b420308f34 100644 --- a/gym/wrappers/normalize.py +++ b/gym/wrappers/normalize.py @@ -1,6 +1,7 @@ import numpy as np import gym +from gym.wrappers import step_api_compatibility # taken from https://github.com/openai/baselines/blob/master/baselines/common/vec_env/vec_normalize.py @@ -50,6 +51,8 @@ class NormalizeObservation(gym.core.Wrapper): epsilon: A stability parameter that is used when scaling the observations. 
""" + new_step_api = True + def __init__( self, env, @@ -65,7 +68,7 @@ def __init__( self.epsilon = epsilon def step(self, action): - obs, rews, terminateds, truncateds, infos = self.env.step(action) + obs, rews, terminateds, truncateds, infos = self._get_env_step_returns(action) if self.is_vector_env: obs = self.normalize(obs) else: @@ -92,6 +95,7 @@ def normalize(self, obs): return (obs - self.obs_rms.mean) / np.sqrt(self.obs_rms.var + self.epsilon) +@step_api_compatibility class NormalizeReward(gym.core.Wrapper): """This wrapper will normalize immediate rewards s.t. their exponential moving average has a fixed variance. @@ -106,6 +110,8 @@ class NormalizeReward(gym.core.Wrapper): gamma (float): The discount factor that is used in the exponential moving average. """ + new_step_api = True + def __init__( self, env, @@ -121,12 +127,12 @@ def __init__( self.epsilon = epsilon def step(self, action): - obs, rews, terminateds, truncateds, infos = self.env.step(action) + obs, rews, terminateds, truncateds, infos = self._get_env_step_returns(action) if not self.is_vector_env: rews = np.array([rews]) self.returns = self.returns * self.gamma + rews rews = self.normalize(rews) - if not self.is_vector_env: # TODO: Check this + if not self.is_vector_env: dones = terminateds or truncateds else: dones = np.bitwise_or(terminateds, truncateds) diff --git a/gym/wrappers/record_episode_statistics.py b/gym/wrappers/record_episode_statistics.py index c541ac2c035..3b18e866edb 100644 --- a/gym/wrappers/record_episode_statistics.py +++ b/gym/wrappers/record_episode_statistics.py @@ -4,8 +4,10 @@ import numpy as np import gym +from gym.wrappers import step_api_compatibility +@step_api_compatibility class RecordEpisodeStatistics(gym.Wrapper): """This wrapper will keep track of cumulative rewards and episode lengths. @@ -35,6 +37,8 @@ class RecordEpisodeStatistics(gym.Wrapper): length_queue: The lengths of the last `deque_size`-many episodes """ + new_step_api = True + def __init__(self, env, deque_size=100): super().__init__(env) self.num_envs = getattr(env, "num_envs", 1) @@ -53,7 +57,13 @@ def reset(self, **kwargs): return observations def step(self, action): - observations, rewards, terminateds, truncateds, infos = super().step(action) + ( + observations, + rewards, + terminateds, + truncateds, + infos, + ) = self._get_env_step_returns(action) self.episode_returns += rewards self.episode_lengths += 1 if not self.is_vector_env: diff --git a/gym/wrappers/record_video.py b/gym/wrappers/record_video.py index e0e684e47e4..4d9b53bd637 100644 --- a/gym/wrappers/record_video.py +++ b/gym/wrappers/record_video.py @@ -1,10 +1,9 @@ import os from typing import Callable -import numpy as np - import gym from gym import logger +from gym.wrappers import step_api_compatibility from gym.wrappers.monitoring import video_recorder @@ -15,6 +14,7 @@ def capped_cubic_video_schedule(episode_id): return episode_id % 1000 == 0 +@step_api_compatibility class RecordVideo(gym.Wrapper): """This wrapper records videos of rollouts. 
@@ -37,6 +37,8 @@ class RecordVideo(gym.Wrapper):
         name_prefix (str): Will be prepended to the filename of the recordings
     """
 
+    new_step_api = True
+
     def __init__(
         self,
         env,
@@ -106,7 +108,13 @@ def _video_enabled(self):
         return self.episode_trigger(self.episode_id)
 
     def step(self, action):
-        observations, rewards, terminateds, truncateds, infos = super().step(action)
+        (
+            observations,
+            rewards,
+            terminateds,
+            truncateds,
+            infos,
+        ) = self._get_env_step_returns(action)
 
         # increment steps and episodes
         self.step_id += 1
diff --git a/gym/wrappers/step_compatibility.py
index 7099b866510..b0579305fa5 100644
--- a/gym/wrappers/step_compatibility.py
+++ b/gym/wrappers/step_compatibility.py
@@ -15,44 +15,38 @@ class StepCompatibility(gym.Wrapper):
     Parameters
     ----------
     env (gym.Env): the env to wrap. Can be in old or new API
-    return_two_dones (bool): True to use env with new step API, False to use env with old step API. (False by default)
+    new_step_api (bool): True to use env with new step API, False to use env with old step API. (False by default)
 
     """
 
-    def __init__(self, env, return_two_dones=False):
+    def __init__(self, env: gym.Env, new_step_api=False):
         super().__init__(env)
-        self._return_two_dones = return_two_dones
-        if not self._return_two_dones:
+        self.new_step_api = new_step_api
+        if not self.new_step_api:
             logger.deprecation(
                 "Initializing environment in old step API which returns one bool instead of two. "
                 "Note that vector API and most wrappers would not work as these have been upgraded to the new API. "
-                "To use these features, please set `return_two_dones=True` in make to use new API (see docs for more details)."
+                "To use these features, please set `new_step_api=True` in make to use new API (see docs for more details)."
             )
 
     def step(self, action):
         step_returns = self.env.step(action)
-        if self._return_two_dones:
-            if len(step_returns) == 5:
-                logger.deprecation(
-                    "Using an environment with new step API that returns two bools terminated, truncated instead of one bool done. "
-                    "Take care to update supporting code to be compatible with this API"
-                )
-                return step_returns
-            else:
-                return self._step_returns_old_to_new(step_returns)
+        if self.new_step_api:
+            return step_to_new_api(step_returns)
         else:
-            if len(step_returns) == 4:
-                logger.deprecation(
-                    "Core environment uses old step API which returns one boolean (done). Please upgrade to new API to return two booleans - terminated, truncated"
-                )
+            return step_to_old_api(step_returns)
 
-                return step_returns
-            elif len(step_returns) == 5:
-                return self._step_returns_new_to_old(step_returns)
 
-    def _step_returns_old_to_new(self, step_returns):
-        # Method to transform old step API to new
+def step_to_new_api(step_returns, is_vector_env=False):
+    # Method to transform step returns to new step API
 
+    if len(step_returns) == 5:
+        logger.deprecation(
+            "Using an environment with new step API that returns two bools terminated, truncated instead of one bool done. "
+            "Take care to update supporting code to be compatible with this API"
+        )
+        return step_returns
+    else:
         assert len(step_returns) == 4
         logger.deprecation(
             "Using a wrapper to transform env with old step API into new. This wrapper will be removed in v1.0. "
             "It is recommended to upgrade the core env to the new step API."
             "If 'TimeLimit.truncated' is set at truncation, terminated and truncated values will be accurate. "
             "Otherwise, `terminated=done` and `truncated=False`"
         )
 
-        obs, rew, done, info = step_returns
-        if "TimeLimit.truncated" not in info:
-            terminated = done
-            truncated = False
-        elif info["TimeLimit.truncated"]:
-            terminated = False
-            truncated = True
-        else:
-            # This means info["TimeLimit.truncated"] exists but is False, which means the core environment had already terminated,
-            # but it also exceeded maximum timesteps at the same step.
+        observations, rewards, dones, infos = step_returns
+
+        terminateds = []
+        truncateds = []
+        if not is_vector_env:
+            dones = [dones]
+            infos = [infos]
+        for i in range(len(dones)):
+            if "TimeLimit.truncated" not in infos[i]:
+                terminateds.append(dones[i])
+                truncateds.append(False)
+            elif infos[i]["TimeLimit.truncated"]:
+                terminateds.append(False)
+                truncateds.append(True)
+            else:
+                # This means info["TimeLimit.truncated"] exists but is False, which means the core environment
+                # had already terminated, but it also exceeded maximum timesteps at the same step.
+
+                terminateds.append(True)
+                truncateds.append(True)
+
+        return (
+            observations,
+            rewards,
+            terminateds if is_vector_env else terminateds[0],
+            truncateds if is_vector_env else truncateds[0],
+            infos if is_vector_env else infos[0],
+        )
 
-            terminated = True
-            truncated = True
 
-        return obs, rew, terminated, truncated, info
+def step_to_old_api(step_returns, is_vector_env=False):
+    # Method to transform step returns to old step API
 
-    def _step_returns_new_to_old(self, step_returns):
-        # Method to transform new step API to old
+    if len(step_returns) == 4:
+        logger.deprecation(
+            "Core environment uses old step API which returns one boolean (done). Please upgrade to new API to return two booleans - terminated, truncated"
+        )
+        return step_returns
+    else:
         assert len(step_returns) == 5
         logger.deprecation(
             "Using a wrapper to transform new step API (which returns two booleans terminated, truncated) into old (returns one boolean done). "
-            "This wrapper will be removed in v1.0 "
+            "This wrapper will be removed in v1.0. "
             "It is recommended to upgrade your accompanying code instead to be compatible with the new API, and use the new API. "
         )
 
-        obs, reward, terminated, truncated, info = step_returns
-        done = terminated or truncated
-        if truncated:
-            info[
-                "TimeLimit.truncated"
-            ] = not terminated  # to be consistent with old API
-        return obs, reward, done, info
+        observations, rewards, terminateds, truncateds, infos = step_returns
+        dones = []
+        if not is_vector_env:
+            terminateds = [terminateds]
+            truncateds = [truncateds]
+            infos = [infos]
+
+        for i in range(len(terminateds)):
+            dones.append(terminateds[i] or truncateds[i])
+            # to be consistent with old API
+            if truncateds[i]:
+                infos[i]["TimeLimit.truncated"] = not terminateds[i]
+        return (
+            observations,
+            rewards,
+            dones if is_vector_env else dones[0],
+            infos if is_vector_env else infos[0],
+        )
+
+
+def step_api_compatibility(WrapperClass):
+    """
+    A step API compatibility wrapper function to transform wrappers in new step API to old
+    """
+
+    class StepCompatibilityWrapper(StepCompatibility):
+        def __init__(self, env: gym.Wrapper, output_new_step_api: bool = False):
+            # Build the wrapped instance once and reuse it, so the env is not wrapped twice
+            wrapped = WrapperClass(env)
+            super().__init__(wrapped, output_new_step_api)
+            if hasattr(WrapperClass, "new_step_api"):
+                self.has_new_step_api = WrapperClass.new_step_api
+            else:
+                self.has_new_step_api = False
+            self.wrap = wrapped
+
+        def _get_env_step_returns(self, action):
+            return (
+                step_to_new_api(self.wrap.step(action))
+                if self.has_new_step_api
+                else step_to_old_api(self.wrap.step(action))
+            )
+
+    return StepCompatibilityWrapper
+
+
+# def check_is_new_api(env: Union[gym.Env, gym.Wrapper]):
+#     env_copy = deepcopy(env)
+#     env_copy.reset()
+#     step_returns = env_copy.step(env_copy.action_space.sample())
+#     del env_copy
+#     return len(step_returns) == 5
diff --git a/gym/wrappers/time_aware_observation.py
index 7f22e681480..c623be82783 100644
--- a/gym/wrappers/time_aware_observation.py
+++ b/gym/wrappers/time_aware_observation.py
@@ -2,8 +2,10 @@
 
 from gym import ObservationWrapper
 from gym.spaces import Box
+from gym.wrappers import step_api_compatibility
 
 
+@step_api_compatibility
 class TimeAwareObservation(ObservationWrapper):
     r"""Augment the observation with current time step in the trajectory.
 
@@ -12,6 +14,7 @@ class TimeAwareObservation(ObservationWrapper):
        support pixel observation space yet.
""" + new_step_api = True def __init__(self, env): super().__init__(env) diff --git a/gym/wrappers/time_limit.py b/gym/wrappers/time_limit.py index b73a400711b..ebf85ea6f49 100644 --- a/gym/wrappers/time_limit.py +++ b/gym/wrappers/time_limit.py @@ -28,7 +28,7 @@ def __init__(self, env, max_episode_steps: Optional[int] = None): self._elapsed_steps = None def step(self, action): - step_returns = self.env.step(action) + step_returns = self._get_env_step_returns(action) if len(step_returns) == 4: observation, reward, done, info = self.env.step(action) if self._elapsed_steps >= self._max_episode_steps: From f06343b86920f30e253902cc1f957593273c95cc Mon Sep 17 00:00:00 2001 From: arjun_kg Date: Wed, 4 May 2022 13:21:17 +0530 Subject: [PATCH 09/37] reset tests back to old api --- tests/envs/spec_list.py | 8 +- tests/envs/test_action_dim_check.py | 4 +- tests/envs/test_atari_legacy_env_specs.py | 8 +- tests/envs/test_determinism.py | 9 +- tests/envs/test_envs.py | 26 ++-- tests/envs/test_frozenlake_dfs.py | 3 + tests/envs/test_mujoco_v2_to_v3_conversion.py | 23 +--- tests/envs/test_registration.py | 123 +++++++++++++----- tests/spaces/test_spaces.py | 8 +- tests/test_core.py | 7 +- tests/utils/test_env_checker.py | 9 +- tests/utils/test_play.py | 4 +- tests/vector/test_async_vector_env.py | 33 ++--- tests/vector/test_shared_memory.py | 15 +-- tests/vector/test_spaces.py | 69 +--------- tests/vector/test_sync_vector_env.py | 50 +------ tests/vector/test_vector_env.py | 16 +-- tests/vector/test_vector_env_wrapper.py | 1 + tests/vector/utils.py | 23 ++-- tests/wrappers/nested_dict_test.py | 6 +- tests/wrappers/test_atari_preprocessing.py | 6 +- tests/wrappers/test_autoreset.py | 64 ++++----- tests/wrappers/test_clip_action.py | 12 +- tests/wrappers/test_filter_observation.py | 4 +- tests/wrappers/test_frame_stack.py | 12 +- tests/wrappers/test_normalize.py | 16 +-- tests/wrappers/test_order_enforcing.py | 3 + tests/wrappers/test_pixel_observation.py | 4 +- .../test_record_episode_statistics.py | 20 ++- tests/wrappers/test_record_video.py | 31 +++-- tests/wrappers/test_rescale_action.py | 10 +- tests/wrappers/test_time_aware_observation.py | 6 +- tests/wrappers/test_time_limit.py | 3 + tests/wrappers/test_transform_observation.py | 21 +-- tests/wrappers/test_transform_reward.py | 30 ++--- 35 files changed, 294 insertions(+), 393 deletions(-) diff --git a/tests/envs/spec_list.py b/tests/envs/spec_list.py index a0f192b7c2a..11c816f6bd1 100644 --- a/tests/envs/spec_list.py +++ b/tests/envs/spec_list.py @@ -11,7 +11,7 @@ skip_mujoco = not (os.environ.get("MUJOCO_KEY")) if not skip_mujoco: try: - import mujoco_py # noqa:F401 + import mujoco_py except ImportError: skip_mujoco = True @@ -24,12 +24,12 @@ def should_skip_env_spec_for_tests(spec): if skip_mujoco and ep.startswith("gym.envs.mujoco"): return True try: - import gym.envs.atari # noqa:F401 + import gym.envs.atari except ImportError: if ep.startswith("gym.envs.atari"): return True try: - import Box2D # noqa:F401 + import Box2D except ImportError: if ep.startswith("gym.envs.box2d"): return True @@ -50,6 +50,6 @@ def should_skip_env_spec_for_tests(spec): spec_list = [ spec - for spec in sorted(envs.registry.values(), key=lambda x: x.id) + for spec in sorted(envs.registry.all(), key=lambda x: x.id) if spec.entry_point is not None and not should_skip_env_spec_for_tests(spec) ] diff --git a/tests/envs/test_action_dim_check.py b/tests/envs/test_action_dim_check.py index 51c74ef7f88..448847cd6ec 100644 --- a/tests/envs/test_action_dim_check.py +++ 
b/tests/envs/test_action_dim_check.py @@ -1,3 +1,5 @@ +import pickle + import pytest from gym import envs @@ -9,7 +11,7 @@ @pytest.mark.skipif(skip_mujoco, reason=SKIP_MUJOCO_WARNING_MESSAGE) @pytest.mark.parametrize("environment_id", ENVIRONMENT_IDS) def test_serialize_deserialize(environment_id): - env = envs.make(environment_id, return_two_dones=True) + env = envs.make(environment_id) env.reset() with pytest.raises(ValueError, match="Action dimension mismatch"): diff --git a/tests/envs/test_atari_legacy_env_specs.py b/tests/envs/test_atari_legacy_env_specs.py index 5a1c406e4cd..37e99245c8d 100644 --- a/tests/envs/test_atari_legacy_env_specs.py +++ b/tests/envs/test_atari_legacy_env_specs.py @@ -1,11 +1,11 @@ -from itertools import product - import pytest -from gym.envs.registration import registry - pytest.importorskip("gym.envs.atari") +from itertools import product + +from gym.envs.registration import registry + def test_ale_legacy_env_specs(): versions = ["-v0", "-v4"] diff --git a/tests/envs/test_determinism.py b/tests/envs/test_determinism.py index d842c66c293..fef13aacd69 100644 --- a/tests/envs/test_determinism.py +++ b/tests/envs/test_determinism.py @@ -9,14 +9,14 @@ def test_env(spec): # Note that this precludes running this test in multiple # threads. However, we probably already can't do multithreading # due to some environments. - env1 = spec.make(return_two_dones=True) + env1 = spec.make() initial_observation1 = env1.reset(seed=0) env1.action_space.seed(0) action_samples1 = [env1.action_space.sample() for i in range(4)] step_responses1 = [env1.step(action) for action in action_samples1] env1.close() - env2 = spec.make(return_two_dones=True) + env2 = spec.make() initial_observation2 = env2.reset(seed=0) env2.action_space.seed(0) action_samples2 = [env2.action_space.sample() for i in range(4)] @@ -45,13 +45,12 @@ def test_env(spec): assert_equals(initial_observation1, initial_observation2) - for i, ((o1, r1, term1, trunc1, i1), (o2, r2, term2, trunc2, i2)) in enumerate( + for i, ((o1, r1, d1, i1), (o2, r2, d2, i2)) in enumerate( zip(step_responses1, step_responses2) ): assert_equals(o1, o2, f"[{i}] ") assert r1 == r2, f"[{i}] r1: {r1}, r2: {r2}" - assert term1 == term2, f"[{i}] term1: {term1}, term2: {term2}" - assert trunc1 == trunc2, f"[{i}] trunc1: {trunc1}, trunc2: {trunc2}" + assert d1 == d2, f"[{i}] d1: {d1}, d2: {d2}" # Go returns a Pachi game board in info, which doesn't # properly check equality. 
For now, we hack around this by diff --git a/tests/envs/test_envs.py b/tests/envs/test_envs.py index 4819452789c..23de61b98b2 100644 --- a/tests/envs/test_envs.py +++ b/tests/envs/test_envs.py @@ -17,7 +17,7 @@ def test_env(spec): # Capture warnings with pytest.warns(None) as warnings: - env = spec.make(return_two_dones=True) + env = spec.make() # Test if env adheres to Gym API check_env(env, warn=True, skip_render_check=True) @@ -37,13 +37,12 @@ def test_env(spec): ), f"Reset observation dtype: {ob.dtype}, expected: {ob_space.dtype}" a = act_space.sample() - observation, reward, terminated, truncated, _info = env.step(a) + observation, reward, done, _info = env.step(a) assert ob_space.contains( observation ), f"Step observation: {observation!r} not in space" assert np.isscalar(reward), f"{reward} is not a scalar for {env}" - assert isinstance(terminated, bool), f"Expected {terminated} to be a boolean" - assert isinstance(truncated, bool), f"Expected {truncated} to be a boolean" + assert isinstance(done, bool), f"Expected {done} to be a boolean" if isinstance(ob_space, Box): assert ( observation.dtype == ob_space.dtype @@ -62,8 +61,8 @@ def test_env(spec): @pytest.mark.parametrize("spec", spec_list, ids=[spec.id for spec in spec_list]) def test_reset_info(spec): - with pytest.warns(None): - env = spec.make(return_two_dones=True) + with pytest.warns(None) as warnings: + env = spec.make() ob_space = env.observation_space obs = env.reset() @@ -78,22 +77,23 @@ def test_reset_info(spec): # Run a longer rollout on some environments def test_random_rollout(): - for env in [envs.make("CartPole-v1", return_two_dones=True), envs.make("FrozenLake-v1", return_two_dones=True)]: + for env in [envs.make("CartPole-v1"), envs.make("FrozenLake-v1")]: + agent = lambda ob: env.action_space.sample() ob = env.reset() for _ in range(10): assert env.observation_space.contains(ob) - action = env.action_space.sample() - assert env.action_space.contains(action) - (ob, _reward, terminated, truncated, _info) = env.step(action) - if terminated or truncated: + a = agent(ob) + assert env.action_space.contains(a) + (ob, _reward, done, _info) = env.step(a) + if done: break env.close() def test_env_render_result_is_immutable(): environs = [ - envs.make("Taxi-v3", return_two_dones=True), - envs.make("FrozenLake-v1", return_two_dones=True), + envs.make("Taxi-v3"), + envs.make("FrozenLake-v1"), ] for env in environs: diff --git a/tests/envs/test_frozenlake_dfs.py b/tests/envs/test_frozenlake_dfs.py index b620cbcb042..3cfeddf2b73 100644 --- a/tests/envs/test_frozenlake_dfs.py +++ b/tests/envs/test_frozenlake_dfs.py @@ -1,3 +1,6 @@ +import numpy as np +import pytest + from gym.envs.toy_text.frozen_lake import generate_random_map diff --git a/tests/envs/test_mujoco_v2_to_v3_conversion.py b/tests/envs/test_mujoco_v2_to_v3_conversion.py index b0b9d7ad4d9..201d497667e 100644 --- a/tests/envs/test_mujoco_v2_to_v3_conversion.py +++ b/tests/envs/test_mujoco_v2_to_v3_conversion.py @@ -9,8 +9,8 @@ def verify_environments_match( old_environment_id, new_environment_id, seed=1, num_actions=1000 ): - old_environment = envs.make(old_environment_id, return_two_dones=True) - new_environment = envs.make(new_environment_id, return_two_dones=True) + old_environment = envs.make(old_environment_id) + new_environment = envs.make(new_environment_id) old_reset_observation = old_environment.reset(seed=seed) new_reset_observation = new_environment.reset(seed=seed) @@ -19,26 +19,13 @@ def verify_environments_match( for i in range(num_actions): action = 
old_environment.action_space.sample() - ( - old_observation, - old_reward, - old_terminated, - old_truncated, - old_info, - ) = old_environment.step(action) - ( - new_observation, - new_reward, - new_terminated, - new_truncated, - new_info, - ) = new_environment.step(action) + old_observation, old_reward, old_done, old_info = old_environment.step(action) + new_observation, new_reward, new_done, new_info = new_environment.step(action) eps = 1e-6 np.testing.assert_allclose(old_observation, new_observation, atol=eps) np.testing.assert_allclose(old_reward, new_reward, atol=eps) - np.testing.assert_equal(old_terminated, new_terminated) - np.testing.assert_equal(old_truncated, new_truncated) + np.testing.assert_allclose(old_done, new_done, atol=eps) for key in old_info: np.testing.assert_allclose(old_info[key], new_info[key], atol=eps) diff --git a/tests/envs/test_registration.py b/tests/envs/test_registration.py index 707f0ee04f4..125b7cef591 100644 --- a/tests/envs/test_registration.py +++ b/tests/envs/test_registration.py @@ -2,8 +2,9 @@ import gym from gym import envs, error -from gym.envs import register, spec +from gym.envs import registration from gym.envs.classic_control import cartpole +from gym.envs.registration import EnvSpec, EnvSpecTree class ArgumentEnv(gym.Env): @@ -54,8 +55,8 @@ def register_some_envs(): for version in versions: env_id = f"{namespace}/{versioned_name}-v{version}" - del gym.envs.registry[env_id] - del gym.envs.registry[f"{namespace}/{unversioned_name}"] + del gym.envs.registry.env_specs[env_id] + del gym.envs.registry.env_specs[f"{namespace}/{unversioned_name}"] def test_make(): @@ -82,15 +83,10 @@ def test_make(): ], ) def test_register(env_id, namespace, name, version): - register(env_id) + envs.register(env_id) assert gym.envs.spec(env_id).id == env_id - full_name = f"{name}" - if namespace: - full_name = f"{namespace}/{full_name}" - if version is not None: - full_name = f"{full_name}-v{version}" - assert full_name in gym.envs.registry.keys() - del gym.envs.registry[env_id] + assert version in gym.envs.registry.env_specs.tree[namespace][name].keys() + del gym.envs.registry.env_specs[env_id] @pytest.mark.parametrize( @@ -103,7 +99,7 @@ def test_register(env_id, namespace, name, version): ) def test_register_error(env_id): with pytest.raises(error.Error, match="Malformed environment ID"): - register(env_id) + envs.register(env_id) @pytest.mark.parametrize( @@ -192,23 +188,27 @@ def test_spec_with_kwargs(): def test_missing_lookup(): - register(id="Test1-v0", entry_point=None) - register(id="Test1-v15", entry_point=None) - register(id="Test1-v9", entry_point=None) - register(id="Other1-v100", entry_point=None) - - with pytest.raises(error.DeprecatedEnv): - spec("Test1-v1") + registry = registration.EnvRegistry() + registry.register(id="Test-v0", entry_point=None) + registry.register(id="Test-v15", entry_point=None) + registry.register(id="Test-v9", entry_point=None) + registry.register(id="Other-v100", entry_point=None) + try: + registry.spec("Test-v1") # must match an env name but not the version above + except error.DeprecatedEnv: + pass + else: + assert False try: - spec("Test1-v1000") + registry.spec("Test-v1000") except error.UnregisteredEnv: pass else: assert False try: - spec("Unknown1-v1") + registry.spec("Unknown-v1") except error.UnregisteredEnv: pass else: @@ -216,8 +216,9 @@ def test_missing_lookup(): def test_malformed_lookup(): + registry = registration.EnvRegistry() try: - spec("“Breakout-v0”") + registry.spec("“Breakout-v0”") except error.Error as 
e: assert "Malformed environment ID" in f"{e}", f"Unexpected message: {e}" else: @@ -225,47 +226,99 @@ def test_malformed_lookup(): def test_versioned_lookups(): - register("test/Test2-v5") + registry = registration.EnvRegistry() + registry.register("test/Test-v5") with pytest.raises(error.VersionNotFound): - spec("test/Test2-v9") + registry.spec("test/Test-v9") with pytest.raises(error.DeprecatedEnv): - spec("test/Test2-v4") + registry.spec("test/Test-v4") - assert spec("test/Test2-v5") + assert registry.spec("test/Test-v5") def test_default_lookups(): - register("test/Test3") + registry = registration.EnvRegistry() + registry.register("test/Test") with pytest.raises(error.DeprecatedEnv): - spec("test/Test3-v0") + registry.spec("test/Test-v0") # Lookup default - spec("test/Test3") + registry.spec("test/Test") + + +def test_env_spec_tree(): + spec_tree = EnvSpecTree() + + # Add with namespace + spec = EnvSpec("test/Test-v0") + spec_tree["test/Test-v0"] = spec + assert spec_tree.tree.keys() == {"test"} + assert spec_tree.tree["test"].keys() == {"Test"} + assert spec_tree.tree["test"]["Test"].keys() == {0} + assert spec_tree.tree["test"]["Test"][0] == spec + assert spec_tree["test/Test-v0"] == spec + + # Add without namespace + spec = EnvSpec("Test-v0") + spec_tree["Test-v0"] = spec + assert spec_tree.tree.keys() == {"test", None} + assert spec_tree.tree[None].keys() == {"Test"} + assert spec_tree.tree[None]["Test"].keys() == {0} + assert spec_tree.tree[None]["Test"][0] == spec + + # Delete last version deletes entire subtree + del spec_tree["test/Test-v0"] + assert spec_tree.tree.keys() == {None} + + # Append second version for same name + spec_tree["Test-v1"] = EnvSpec("Test-v1") + assert spec_tree.tree.keys() == {None} + assert spec_tree.tree[None].keys() == {"Test"} + assert spec_tree.tree[None]["Test"].keys() == {0, 1} + + # Deleting one version leaves other + del spec_tree["Test-v0"] + assert spec_tree.tree.keys() == {None} + assert spec_tree.tree[None].keys() == {"Test"} + assert spec_tree.tree[None]["Test"].keys() == {1} + + # Add without version + myenv = "MyAwesomeEnv" + spec = EnvSpec(myenv) + spec_tree[myenv] = spec + assert spec_tree.tree.keys() == {None} + assert myenv in spec_tree.tree[None].keys() + assert spec_tree.tree[None][myenv].keys() == {None} + assert spec_tree.tree[None][myenv][None] == spec + assert spec_tree.__repr__() == "├──Test: [ v1 ]\n" + f"└──{myenv}: [ ]\n" def test_register_versioned_unversioned(): # Register versioned then unversioned versioned_env = "Test/MyEnv-v0" - register(versioned_env) + envs.register(versioned_env) assert gym.envs.spec(versioned_env).id == versioned_env unversioned_env = "Test/MyEnv" with pytest.raises(error.RegistrationError): - register(unversioned_env) + envs.register(unversioned_env) # Clean everything - del gym.envs.registry[versioned_env] + del gym.envs.registry.env_specs[versioned_env] # Register unversioned then versioned - register(unversioned_env) + with pytest.warns(UserWarning): + envs.register(unversioned_env) assert gym.envs.spec(unversioned_env).id == unversioned_env with pytest.raises(error.RegistrationError): - register(versioned_env) + envs.register(versioned_env) # Clean everything - del gym.envs.registry[unversioned_env] + envs_list = [versioned_env, unversioned_env] + for env in envs_list: + del gym.envs.registry.env_specs[env] def test_return_latest_versioned_env(register_some_envs): diff --git a/tests/spaces/test_spaces.py b/tests/spaces/test_spaces.py index 85739028eb6..975fffa2af2 100644 --- 
a/tests/spaces/test_spaces.py +++ b/tests/spaces/test_spaces.py @@ -569,20 +569,20 @@ def test_infinite_space(space): # but floats are unbounded for infinite if np.any(space.high != 0): assert ( - space.is_bounded("above") is False + space.is_bounded("above") == False ), "inf upper bound supposed to be unbounded" else: assert ( - space.is_bounded("above") is True + space.is_bounded("above") == True ), "non-inf upper bound supposed to be bounded" if np.any(space.low != 0): assert ( - space.is_bounded("below") is False + space.is_bounded("below") == False ), "inf lower bound supposed to be unbounded" else: assert ( - space.is_bounded("below") is True + space.is_bounded("below") == True ), "non-inf lower bound supposed to be bounded" # check for dtype diff --git a/tests/test_core.py b/tests/test_core.py index ef3f7d7d449..2d6b0dcd305 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -25,7 +25,7 @@ def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None): def step(self, action): observation = self.observation_space.sample() # Dummy observation - return (observation, 0.0, False, False, {}) + return (observation, 0.0, False, {}) class UnknownSpacesEnv(core.Env): @@ -54,12 +54,11 @@ def reset( def step(self, action): observation = self.observation_space.sample() # Dummy observation - return (observation, 0.0, False, False, {}) + return (observation, 0.0, False, {}) class OldStyleEnv(core.Env): - """This environment doesn't accept any arguments in reset, step returns one bool instead of two, - ideally we want to support this too (for now)""" + """This environment doesn't accept any arguments in reset, ideally we want to support this too (for now)""" def __init__(self): pass diff --git a/tests/utils/test_env_checker.py b/tests/utils/test_env_checker.py index a846fdb9927..b50ec4c39e1 100644 --- a/tests/utils/test_env_checker.py +++ b/tests/utils/test_env_checker.py @@ -15,9 +15,8 @@ class ActionDictTestEnv(gym.Env): def step(self, action): observation = np.array([1.0, 1.5, 0.5]) reward = 1 - terminated = True - truncated = True - return observation, reward, terminated, truncated + done = True + return observation, reward, done def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None): super().reset(seed=seed) @@ -28,12 +27,12 @@ def render(self, mode="human"): def test_check_env_dict_action(): - # Environment.step() only returns 4 values: obs, reward, terminated, truncated. Not info! + # Environment.step() only returns 3 values: obs, reward, done. Not info! 
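# --- Illustrative sketch (not part of the patch): the four-value shape the
# checker expects from an old-API env, for contrast with the deliberately
# broken three-value step above. The Dict action space is assumed from the
# test's name; everything here is hypothetical.
import numpy as np
import gym

class FixedActionDictTestEnv(gym.Env):
    action_space = gym.spaces.Dict({"position": gym.spaces.Discrete(1)})
    observation_space = gym.spaces.Box(low=-1.0, high=2.0, shape=(3,))

    def step(self, action):
        observation = np.array([1.0, 1.5, 0.5])
        return observation, 1.0, True, {}  # obs, reward, done, info

    def reset(self):
        return np.array([1.0, 1.5, 0.5])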
test_env = ActionDictTestEnv() with pytest.raises(AssertionError) as errorinfo: check_env(env=test_env, warn=True) assert ( str(errorinfo.value) - == "The `step()` method must return four values: obs, reward, terminated, truncated, info" + == "The `step()` method must return four values: obs, reward, done, info" ) diff --git a/tests/utils/test_play.py b/tests/utils/test_play.py index b1834e3a4eb..cc94f4a94e6 100644 --- a/tests/utils/test_play.py +++ b/tests/utils/test_play.py @@ -168,7 +168,7 @@ def callback(obs_t, obs_tp1, action, rew, terminated, truncated, info): return obs_t, obs_tp1, action, rew, terminated, truncated, info - env = gym.make(ENV, return_two_dones=True) + env = gym.make(ENV) env.reset(seed=SEED) keys_to_action = dummy_keys_to_action() @@ -179,7 +179,7 @@ def callback(obs_t, obs_tp1, action, rew, terminated, truncated, info): action = keys_to_action[(e.key,)] obs, _, _, _, _ = env.step(action) - env_play = gym.make(ENV, return_two_dones=True) + env_play = gym.make(ENV) status = PlayStatus(callback) play(env_play, callback=status.callback, keys_to_action=keys_to_action, seed=SEED) diff --git a/tests/vector/test_async_vector_env.py b/tests/vector/test_async_vector_env.py index 513b7413aa9..41104799019 100644 --- a/tests/vector/test_async_vector_env.py +++ b/tests/vector/test_async_vector_env.py @@ -82,7 +82,7 @@ def test_step_async_vector_env(shared_memory, use_single_action_space): actions = [env.single_action_space.sample() for _ in range(8)] else: actions = env.action_space.sample() - observations, rewards, terminateds, truncateds, _ = env.step(actions) + observations, rewards, dones, _ = env.step(actions) finally: env.close() @@ -97,15 +97,10 @@ def test_step_async_vector_env(shared_memory, use_single_action_space): assert rewards.ndim == 1 assert rewards.size == 8 - assert isinstance(terminateds, np.ndarray) - assert terminateds.dtype == np.bool_ - assert terminateds.ndim == 1 - assert terminateds.size == 8 - - assert isinstance(truncateds, np.ndarray) - assert truncateds.dtype == np.bool_ - assert truncateds.ndim == 1 - assert truncateds.size == 8 + assert isinstance(dones, np.ndarray) + assert dones.dtype == np.bool_ + assert dones.ndim == 1 + assert dones.size == 8 @pytest.mark.parametrize("shared_memory", [True, False]) @@ -172,7 +167,7 @@ def test_reset_timeout_async_vector_env(shared_memory): try: env = AsyncVectorEnv(env_fns, shared_memory=shared_memory) env.reset_async() - env.reset_wait(timeout=0.1) + observations = env.reset_wait(timeout=0.1) finally: env.close(terminate=True) @@ -183,11 +178,9 @@ def test_step_timeout_async_vector_env(shared_memory): with pytest.raises(TimeoutError): try: env = AsyncVectorEnv(env_fns, shared_memory=shared_memory) - env.reset() + observations = env.reset() env.step_async([0.1, 0.1, 0.3, 0.1]) - observations, rewards, terminateds, truncateds, _ = env.step_wait( - timeout=0.1 - ) + observations, rewards, dones, _ = env.step_wait(timeout=0.1) finally: env.close(terminate=True) @@ -199,7 +192,7 @@ def test_reset_out_of_order_async_vector_env(shared_memory): with pytest.raises(NoAsyncCallError): try: env = AsyncVectorEnv(env_fns, shared_memory=shared_memory) - env.reset_wait() + observations = env.reset_wait() except NoAsyncCallError as exception: assert exception.name == "reset" raise @@ -210,7 +203,7 @@ def test_reset_out_of_order_async_vector_env(shared_memory): try: env = AsyncVectorEnv(env_fns, shared_memory=shared_memory) actions = env.action_space.sample() - env.reset() + observations = env.reset() env.step_async(actions) 
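# --- Illustrative sketch (not part of the patch): the call order the async
# API expects; pairing every *_async call with its *_wait avoids the
# NoAsyncCallError / AlreadyPendingCallError cases exercised in these tests.
# Assumes the CartPole `make_env` helper from tests/vector/utils.py.
from gym.vector import AsyncVectorEnv
from tests.vector.utils import make_env

envs = AsyncVectorEnv([make_env("CartPole-v1", seed=i) for i in range(4)])
observations = envs.reset()                  # always reset first
envs.step_async(envs.action_space.sample())  # then submit actions...
observations, rewards, dones, infos = envs.step_wait()  # ...and collect them
envs.close()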
env.reset_async() except NoAsyncCallError as exception: @@ -229,7 +222,7 @@ def test_step_out_of_order_async_vector_env(shared_memory): env = AsyncVectorEnv(env_fns, shared_memory=shared_memory) actions = env.action_space.sample() observations = env.reset() - observations, rewards, terminateds, truncateds, infos = env.step_wait() + observations, rewards, dones, infos = env.step_wait() except AlreadyPendingCallError as exception: assert exception.name == "step" raise @@ -255,7 +248,7 @@ def test_already_closed_async_vector_env(shared_memory): with pytest.raises(ClosedEnvironmentError): env = AsyncVectorEnv(env_fns, shared_memory=shared_memory) env.close() - env.reset() + observations = env.reset() @pytest.mark.parametrize("shared_memory", [True, False]) @@ -279,7 +272,7 @@ def test_custom_space_async_vector_env(): assert isinstance(env.action_space, Tuple) actions = ("action-2", "action-3", "action-5", "action-7") - step_observations, rewards, termianteds, truncateds, _ = env.step(actions) + step_observations, rewards, dones, _ = env.step(actions) finally: env.close() diff --git a/tests/vector/test_shared_memory.py b/tests/vector/test_shared_memory.py index 5d18452a881..120d109dd9f 100644 --- a/tests/vector/test_shared_memory.py +++ b/tests/vector/test_shared_memory.py @@ -65,7 +65,8 @@ def assert_nested_type(lhs, rhs, n): # Assert the length of the array assert len(lhs[:]) == n * len(rhs[:]) # Assert the data type - assert isinstance(lhs[0], type(rhs[0])) + assert type(lhs[0]) == type(rhs[0]) # noqa: E721 + else: raise TypeError(f"Got unknown type `{type(lhs)}`.") @@ -82,7 +83,7 @@ def assert_nested_type(lhs, rhs, n): def test_create_shared_memory_custom_space(n, ctx, space): ctx = mp if (ctx is None) else mp.get_context(ctx) with pytest.raises(CustomSpaceError): - create_shared_memory(space, n=n, ctx=ctx) + shared_memory = create_shared_memory(space, n=n, ctx=ctx) @pytest.mark.parametrize( @@ -123,10 +124,6 @@ def write(i, shared_memory, sample): assert_nested_equal(shared_memory_n8, samples) -def _process_write(space, i, shared_memory, sample): - write_to_shared_memory(space, i, sample, shared_memory) - - @pytest.mark.parametrize( "space", spaces, ids=[space.__class__.__name__ for space in spaces] ) @@ -156,13 +153,15 @@ def assert_nested_equal(lhs, rhs, space, n): else: raise TypeError(f"Got unknown type `{type(space)}`") + def write(i, shared_memory, sample): + write_to_shared_memory(space, i, sample, shared_memory) + shared_memory_n8 = create_shared_memory(space, n=8) memory_view_n8 = read_from_shared_memory(space, shared_memory_n8, n=8) samples = [space.sample() for _ in range(8)] processes = [ - Process(target=_process_write, args=(space, i, shared_memory_n8, samples[i])) - for i in range(8) + Process(target=write, args=(i, shared_memory_n8, samples[i])) for i in range(8) ] for process in processes: diff --git a/tests/vector/test_spaces.py b/tests/vector/test_spaces.py index fa7aee0b4ff..d01d5a45ad9 100644 --- a/tests/vector/test_spaces.py +++ b/tests/vector/test_spaces.py @@ -1,13 +1,9 @@ -import copy - import numpy as np import pytest -from numpy.testing import assert_array_equal -from gym import Space from gym.spaces import Box, Dict, MultiDiscrete, Tuple from gym.vector.utils.spaces import batch_space, iterate -from tests.vector.utils import CustomSpace, assert_rng_equal, custom_spaces, spaces +from tests.vector.utils import CustomSpace, custom_spaces, spaces expected_batch_spaces_4 = [ Box(low=-1.0, high=1.0, shape=(4,), dtype=np.float64), @@ -133,66 +129,3 @@ def 
test_iterate_custom_space(space, batch_space): for i, item in enumerate(iterator): assert item in space assert i == 3 - - -@pytest.mark.parametrize( - "space", spaces, ids=[space.__class__.__name__ for space in spaces] -) -@pytest.mark.parametrize("n", [4, 5], ids=[f"n={n}" for n in [4, 5]]) -@pytest.mark.parametrize( - "base_seed", [123, 456], ids=[f"seed={base_seed}" for base_seed in [123, 456]] -) -def test_rng_different_at_each_index(space: Space, n: int, base_seed: int): - """ - Tests that the rng values produced at each index are different - to prevent if the rng is copied for each subspace - """ - space.seed(base_seed) - - batched_space = batch_space(space, n) - assert space.np_random is not batched_space.np_random - assert_rng_equal(space.np_random, batched_space.np_random) - - batched_sample = batched_space.sample() - sample = list(iterate(batched_space, batched_sample)) - assert not all(np.all(element == sample[0]) for element in sample), sample - - -@pytest.mark.parametrize( - "space", spaces, ids=[space.__class__.__name__ for space in spaces] -) -@pytest.mark.parametrize("n", [1, 2, 5], ids=[f"n={n}" for n in [1, 2, 5]]) -@pytest.mark.parametrize( - "base_seed", [123, 456], ids=[f"seed={base_seed}" for base_seed in [123, 456]] -) -def test_deterministic(space: Space, n: int, base_seed: int): - """Tests the batched spaces are deterministic by using a copied version""" - # Copy the spaces and check that the np_random are not reference equal - space_a = space - space_a.seed(base_seed) - space_b = copy.deepcopy(space_a) - assert_rng_equal(space_a.np_random, space_b.np_random) - assert space_a.np_random is not space_b.np_random - - # Batch the spaces and check that the np_random are not reference equal - space_a_batched = batch_space(space_a, n) - space_b_batched = batch_space(space_b, n) - assert_rng_equal(space_a_batched.np_random, space_b_batched.np_random) - assert space_a_batched.np_random is not space_b_batched.np_random - # Create that the batched space is not reference equal to the origin spaces - assert space_a.np_random is not space_a_batched.np_random - - # Check that batched space a and b random number generator are not effected by the original space - space_a.sample() - space_a_batched_sample = space_a_batched.sample() - space_b_batched_sample = space_b_batched.sample() - for a_sample, b_sample in zip( - iterate(space_a_batched, space_a_batched_sample), - iterate(space_b_batched, space_b_batched_sample), - ): - if isinstance(a_sample, tuple): - assert len(a_sample) == len(b_sample) - for a_subsample, b_subsample in zip(a_sample, b_sample): - assert_array_equal(a_subsample, b_subsample) - else: - assert_array_equal(a_sample, b_sample) diff --git a/tests/vector/test_sync_vector_env.py b/tests/vector/test_sync_vector_env.py index d27d81e18ae..623803238ce 100644 --- a/tests/vector/test_sync_vector_env.py +++ b/tests/vector/test_sync_vector_env.py @@ -1,16 +1,9 @@ import numpy as np import pytest -from gym.envs.registration import EnvSpec from gym.spaces import Box, Discrete, MultiDiscrete, Tuple from gym.vector.sync_vector_env import SyncVectorEnv -from tests.envs.spec_list import spec_list -from tests.vector.utils import ( - CustomSpace, - assert_rng_equal, - make_custom_space_env, - make_env, -) +from tests.vector.utils import CustomSpace, make_custom_space_env, make_env def test_create_sync_vector_env(): @@ -83,7 +76,7 @@ def test_step_sync_vector_env(use_single_action_space): actions = [env.single_action_space.sample() for _ in range(8)] else: actions = 
env.action_space.sample() - observations, rewards, terminateds, truncateds, _ = env.step(actions) + observations, rewards, dones, _ = env.step(actions) finally: env.close() @@ -98,15 +91,10 @@ def test_step_sync_vector_env(use_single_action_space): assert rewards.ndim == 1 assert rewards.size == 8 - assert isinstance(terminateds, np.ndarray) - assert terminateds.dtype == np.bool_ - assert terminateds.ndim == 1 - assert terminateds.size == 8 - - assert isinstance(truncateds, np.ndarray) - assert truncateds.dtype == np.bool_ - assert truncateds.ndim == 1 - assert truncateds.size == 8 + assert isinstance(dones, np.ndarray) + assert dones.dtype == np.bool_ + assert dones.ndim == 1 + assert dones.size == 8 def test_call_sync_vector_env(): @@ -162,7 +150,7 @@ def test_custom_space_sync_vector_env(): assert isinstance(env.action_space, Tuple) actions = ("action-2", "action-3", "action-5", "action-7") - step_observations, rewards, terminateds, truncateds, _ = env.step(actions) + step_observations, rewards, dones, _ = env.step(actions) finally: env.close() @@ -179,27 +167,3 @@ def test_custom_space_sync_vector_env(): "step(action-5)", "step(action-7)", ) - - -def test_sync_vector_env_seed(): - env = make_env("BipedalWalker-v3", seed=123)() - sync_vector_env = SyncVectorEnv([make_env("BipedalWalker-v3", seed=123)]) - - assert_rng_equal(env.action_space.np_random, sync_vector_env.action_space.np_random) - for _ in range(100): - env_action = env.action_space.sample() - vector_action = sync_vector_env.action_space.sample() - assert np.all(env_action == vector_action) - - -@pytest.mark.parametrize("spec", spec_list, ids=[spec.id for spec in spec_list]) -def test_sync_vector_determinism(spec: EnvSpec, seed: int = 123, n: int = 3): - """Check that for all environments, the sync vector envs produce the same action samples using the same seeds""" - env_1 = SyncVectorEnv([make_env(spec.id, seed=seed) for _ in range(n)]) - env_2 = SyncVectorEnv([make_env(spec.id, seed=seed) for _ in range(n)]) - assert_rng_equal(env_1.action_space.np_random, env_2.action_space.np_random) - - for _ in range(100): - env_1_samples = env_1.action_space.sample() - env_2_samples = env_2.action_space.sample() - assert np.all(env_1_samples == env_2_samples) diff --git a/tests/vector/test_vector_env.py b/tests/vector/test_vector_env.py index c1cca786522..82870d79c29 100644 --- a/tests/vector/test_vector_env.py +++ b/tests/vector/test_vector_env.py @@ -31,19 +31,19 @@ def test_vector_env_equal(shared_memory): assert actions in sync_env.action_space # fmt: off - async_observations, async_rewards, async_terminateds, async_truncateds, async_infos = async_env.step(actions) - sync_observations, sync_rewards, sync_terminateds, sync_truncateds, sync_infos = sync_env.step(actions) + async_observations, async_rewards, async_dones, async_infos = async_env.step(actions) + sync_observations, sync_rewards, sync_dones, sync_infos = sync_env.step(actions) # fmt: on - for idx in range(len(sync_terminateds)): - if sync_terminateds[idx] or sync_truncateds[idx]: - assert "closing_observation" in async_infos[idx] - assert "closing_observation" in sync_infos[idx] + for idx in range(len(sync_dones)): + if sync_dones[idx]: + assert "terminal_observation" in async_infos[idx] + assert "terminal_observation" in sync_infos[idx] + assert sync_dones[idx] assert np.all(async_observations == sync_observations) assert np.all(async_rewards == sync_rewards) - assert np.all(async_terminateds == sync_terminateds) - assert np.all(async_truncateds == sync_truncateds) + 
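# --- Illustrative sketch (not part of the patch): consuming the
# "terminal_observation" convention asserted above. When a sub-environment
# finishes, the vector env resets it automatically and keeps the final
# observation in that sub-env's info dict (old API; CartPole assumed):
import gym

envs = gym.vector.make("CartPole-v1", num_envs=4)
envs.reset()
observations, rewards, dones, infos = envs.step(envs.action_space.sample())
for idx, done in enumerate(dones):
    if done:
        final_obs = infos[idx]["terminal_observation"]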
assert np.all(async_dones == sync_dones) finally: async_env.close() diff --git a/tests/vector/test_vector_env_wrapper.py b/tests/vector/test_vector_env_wrapper.py index 156eaa47f64..4c8d165d175 100644 --- a/tests/vector/test_vector_env_wrapper.py +++ b/tests/vector/test_vector_env_wrapper.py @@ -1,3 +1,4 @@ +import gym from gym.vector import VectorEnvWrapper, make diff --git a/tests/vector/utils.py b/tests/vector/utils.py index 0e3febe2491..0eadb672642 100644 --- a/tests/vector/utils.py +++ b/tests/vector/utils.py @@ -5,7 +5,6 @@ import gym from gym.spaces import Box, Dict, Discrete, MultiBinary, MultiDiscrete, Tuple -from gym.utils.seeding import RandomNumberGenerator spaces = [ Box(low=np.array(-1.0), high=np.array(1.0), dtype=np.float64), @@ -68,18 +67,18 @@ def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None): def step(self, action): time.sleep(action) observation = self.observation_space.sample() - reward, terminated, truncated = 0.0, False, False - return observation, reward, terminated, truncated, {} + reward, done = 0.0, False + return observation, reward, done, {} class CustomSpace(gym.Space): """Minimal custom observation space.""" def sample(self): - return self.np_random.integers(0, 10, ()) + return "sample" def contains(self, x): - return 0 <= x <= 10 + return isinstance(x, str) def __eq__(self, other): return isinstance(other, CustomSpace) @@ -103,15 +102,13 @@ def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None): def step(self, action): observation = f"step({action:s})" - reward, terminated, truncated = 0.0, False, False - return observation, reward, terminated, truncated, {} + reward, done = 0.0, False + return observation, reward, done, {} -def make_env(env_name, seed, return_two_dones=True): - # return_two_dones=True, only for compatibility with vector tests, to be removed at v1.0 +def make_env(env_name, seed): def _make(): - env = gym.make(env_name, return_two_dones=return_two_dones) - env.action_space.seed(seed) + env = gym.make(env_name) env.reset(seed=seed) return env @@ -134,7 +131,3 @@ def _make(): return env return _make - - -def assert_rng_equal(rng_1: RandomNumberGenerator, rng_2: RandomNumberGenerator): - assert rng_1.bit_generator.state == rng_2.bit_generator.state diff --git a/tests/wrappers/nested_dict_test.py b/tests/wrappers/nested_dict_test.py index 0f446a4433f..bde47054137 100644 --- a/tests/wrappers/nested_dict_test.py +++ b/tests/wrappers/nested_dict_test.py @@ -5,7 +5,7 @@ import pytest import gym -from gym.spaces import Box, Dict, Tuple +from gym.spaces import Box, Dict, Discrete, Tuple from gym.wrappers import FilterObservation, FlattenObservation @@ -29,8 +29,8 @@ def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None): def step(self, action): del action observation = self.observation_space.sample() - reward, terminated, truncated, info = 0.0, False, False, {} - return observation, reward, terminated, truncated, info + reward, terminal, info = 0.0, False, {} + return observation, reward, terminal, info NESTED_DICT_TEST_CASES = ( diff --git a/tests/wrappers/test_atari_preprocessing.py b/tests/wrappers/test_atari_preprocessing.py index a7bf6b40ea7..e36d3768838 100644 --- a/tests/wrappers/test_atari_preprocessing.py +++ b/tests/wrappers/test_atari_preprocessing.py @@ -74,13 +74,13 @@ def test_atari_preprocessing_scale(env_fn): noop_max=0, ) obs = env.reset().flatten() - terminated, truncated, step_i = False, False, 0 + done, step_i = False, 0 max_obs = 1 if scaled else 255 
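# --- Illustrative sketch (not part of the patch): what scale_obs toggles in
# the loop tested here. With scale_obs=True frames come back as floats in
# [0, 1], otherwise as uint8 in [0, 255]. The Pong env id is an assumption;
# the test builds its envs elsewhere.
import gym
from gym.wrappers import AtariPreprocessing

env = AtariPreprocessing(
    gym.make("PongNoFrameskip-v4"),
    screen_size=84,
    grayscale_obs=True,
    scale_obs=True,
    noop_max=0,
)
obs = env.reset()
assert 0.0 <= obs.min() and obs.max() <= 1.0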
assert (0 <= obs).all() and ( obs <= max_obs ).all(), f"Obs. must be in range [0,{max_obs}]" - while not (terminated or truncated) or step_i <= max_test_steps: - obs, _, terminated, truncated, _ = env.step(env.action_space.sample()) + while not done or step_i <= max_test_steps: + obs, _, done, _ = env.step(env.action_space.sample()) obs = obs.flatten() assert (0 <= obs).all() and ( obs <= max_obs diff --git a/tests/wrappers/test_autoreset.py b/tests/wrappers/test_autoreset.py index 278e98e08c5..76c035f87dd 100644 --- a/tests/wrappers/test_autoreset.py +++ b/tests/wrappers/test_autoreset.py @@ -1,3 +1,4 @@ +import types from typing import Optional from unittest.mock import MagicMock @@ -22,9 +23,7 @@ class DummyResetEnv(gym.Env): metadata = {} def __init__(self): - self.action_space = gym.spaces.Box( - low=np.array([-1.0]), high=np.array([1.0]), dtype=np.float64 - ) + self.action_space = gym.spaces.Box(low=np.array([-1.0]), high=np.array([1.0])) self.observation_space = gym.spaces.Box( low=np.array([-1.0]), high=np.array([1.0]) ) @@ -35,7 +34,6 @@ def step(self, action): return ( np.array([self.count]), 1 if self.count > 2 else 0, - False, self.count > 2, {"count": self.count}, ) @@ -55,7 +53,7 @@ def reset( def test_autoreset_reset_info(): - env = gym.make("CartPole-v1", return_two_dones=True) + env = gym.make("CartPole-v1") env = AutoResetWrapper(env) ob_space = env.observation_space obs = env.reset() @@ -65,7 +63,6 @@ def test_autoreset_reset_info(): obs, info = env.reset(return_info=True) assert ob_space.contains(obs) assert isinstance(info, dict) - env.close() @pytest.mark.parametrize("spec", spec_list, ids=[spec.id for spec in spec_list]) @@ -78,40 +75,37 @@ def test_make_autoreset_true(spec): amount of time with random actions, which is true as of the time of adding this test. 
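# --- Illustrative sketch (not part of the patch): requesting autoreset at
# construction, as this test does via spec.make(autoreset=True). On done, the
# wrapper has already reset: step() returns the new episode's first
# observation, and the finished episode's last observation and info are kept
# under the keys checked further below (old API; CartPole assumed).
import gym

env = gym.make("CartPole-v1", autoreset=True)
obs = env.reset(seed=0)
done = False
while not done:
    obs, reward, done, info = env.step(env.action_space.sample())
final_obs = info["terminal_observation"]
final_info = info["terminal_info"]
env.close()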
""" env = None - with pytest.warns(None): - env = spec.make(autoreset=True, return_two_dones=True) + with pytest.warns(None) as warnings: + env = spec.make(autoreset=True) - env.reset(seed=0) + ob_space = env.observation_space + obs = env.reset(seed=0) env.action_space.seed(0) env.unwrapped.reset = MagicMock(side_effect=env.unwrapped.reset) - terminated = False - truncated = False - while not terminated and not truncated: - obs, reward, terminated, truncated, info = env.step(env.action_space.sample()) + done = False + while not done: + obs, reward, done, info = env.step(env.action_space.sample()) assert isinstance(env, AutoResetWrapper) assert env.unwrapped.reset.called - env.close() @pytest.mark.parametrize("spec", spec_list, ids=[spec.id for spec in spec_list]) def test_make_autoreset_false(spec): env = None - with pytest.warns(None): - env = spec.make(autoreset=False, return_two_dones=True) + with pytest.warns(None) as warnings: + env = spec.make(autoreset=False) assert not isinstance(env, AutoResetWrapper) - env.close() @pytest.mark.parametrize("spec", spec_list, ids=[spec.id for spec in spec_list]) def test_make_autoreset_default_false(spec): env = None - with pytest.warns(None): - env = spec.make(return_two_dones=True) + with pytest.warns(None) as warnings: + env = spec.make() assert not isinstance(env, AutoResetWrapper) - env.close() def test_autoreset_autoreset(): @@ -121,38 +115,32 @@ def test_autoreset_autoreset(): assert obs == np.array([0]) assert info == {"count": 0} action = 1 - obs, reward, terminated, truncated, info = env.step(action) + obs, reward, done, info = env.step(action) assert obs == np.array([1]) assert reward == 0 - assert terminated is False - assert truncated is False + assert done == False assert info == {"count": 1} - obs, reward, terminated, truncated, info = env.step(action) + obs, reward, done, info = env.step(action) assert obs == np.array([2]) - assert terminated is False - assert truncated is False + assert done == False assert reward == 0 assert info == {"count": 2} - obs, reward, terminated, truncated, info = env.step(action) + obs, reward, done, info = env.step(action) assert obs == np.array([0]) - assert terminated is False - assert truncated is True + assert done == True assert reward == 1 assert info == { "count": 0, - "closing_observation": np.array([3]), - "closing_info": {"count": 3}, + "terminal_observation": np.array([3]), + "terminal_info": {"count": 3}, } - obs, reward, terminated, truncated, info = env.step(action) + obs, reward, done, info = env.step(action) assert obs == np.array([1]) assert reward == 0 - assert terminated is False - assert truncated is False + assert done == False assert info == {"count": 1} - obs, reward, terminated, truncated, info = env.step(action) + obs, reward, done, info = env.step(action) assert obs == np.array([2]) assert reward == 0 - assert terminated is False - assert truncated is False + assert done == False assert info == {"count": 2} - env.close() diff --git a/tests/wrappers/test_clip_action.py b/tests/wrappers/test_clip_action.py index 58364088d09..aebf867b6e0 100644 --- a/tests/wrappers/test_clip_action.py +++ b/tests/wrappers/test_clip_action.py @@ -6,8 +6,9 @@ def test_clip_action(): # mountaincar: action-based rewards - env = gym.make("MountainCarContinuous-v0", return_two_dones=True) - wrapped_env = ClipAction(gym.make("MountainCarContinuous-v0")) + make_env = lambda: gym.make("MountainCarContinuous-v0") + env = make_env() + wrapped_env = ClipAction(make_env()) seed = 0 @@ -16,11 +17,10 @@ def 
test_clip_action(): actions = [[0.4], [1.2], [-0.3], [0.0], [-2.5]] for action in actions: - obs1, r1, term1, trunc1, _ = env.step( + obs1, r1, d1, _ = env.step( np.clip(action, env.action_space.low, env.action_space.high) ) - obs2, r2, term2, trunc2, _ = wrapped_env.step(action) + obs2, r2, d2, _ = wrapped_env.step(action) assert np.allclose(r1, r2) assert np.allclose(obs1, obs2) - assert term1 == term2 - assert trunc1 == trunc2 + assert d1 == d2 diff --git a/tests/wrappers/test_filter_observation.py b/tests/wrappers/test_filter_observation.py index de46ae2dc20..e7d5ef2b052 100644 --- a/tests/wrappers/test_filter_observation.py +++ b/tests/wrappers/test_filter_observation.py @@ -32,8 +32,8 @@ def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None): def step(self, action): del action observation = self.observation_space.sample() - reward, terminated, truncated, info = 0.0, False, False, {} - return observation, reward, terminated, truncated, info + reward, terminal, info = 0.0, False, {} + return observation, reward, terminal, info FILTER_OBSERVATION_TEST_CASES = ( diff --git a/tests/wrappers/test_frame_stack.py b/tests/wrappers/test_frame_stack.py index 1a9ac8b6931..b9af3002c1f 100644 --- a/tests/wrappers/test_frame_stack.py +++ b/tests/wrappers/test_frame_stack.py @@ -1,6 +1,9 @@ -import numpy as np import pytest +pytest.importorskip("gym.envs.atari") + +import numpy as np + import gym from gym.wrappers import FrameStack @@ -10,9 +13,6 @@ lz4 = None -pytest.importorskip("gym.envs.atari") - - @pytest.mark.parametrize("env_id", ["CartPole-v1", "Pendulum-v1", "Pong-v0"]) @pytest.mark.parametrize("num_stack", [2, 3, 4]) @pytest.mark.parametrize( @@ -42,8 +42,8 @@ def test_frame_stack(env_id, num_stack, lz4_compress): for _ in range(num_stack**2): action = env.action_space.sample() - dup_obs, _, _, _, _ = dup.step(action) - obs, _, _, _, _ = env.step(action) + dup_obs, _, _, _ = dup.step(action) + obs, _, _, _ = env.step(action) assert np.allclose(obs[-1], dup_obs) assert len(obs) == num_stack diff --git a/tests/wrappers/test_normalize.py b/tests/wrappers/test_normalize.py index ee73163a1d4..13bf32011be 100644 --- a/tests/wrappers/test_normalize.py +++ b/tests/wrappers/test_normalize.py @@ -22,13 +22,7 @@ def __init__(self, return_reward_idx=0): def step(self, action): self.t += 1 - return ( - np.array([self.t]), - self.t, - self.t == len(self.returned_rewards), - False, - {}, - ) + return np.array([self.t]), self.t, self.t == len(self.returned_rewards), {} def reset( self, @@ -100,7 +94,7 @@ def test_normalize_observation_vector_env(): env_fns = [make_env(0), make_env(1)] envs = gym.vector.SyncVectorEnv(env_fns) envs.reset() - obs, reward, _, _, _ = envs.step(envs.action_space.sample()) + obs, reward, _, _ = envs.step(envs.action_space.sample()) np.testing.assert_almost_equal(obs, np.array([[1], [2]]), decimal=4) np.testing.assert_almost_equal(reward, np.array([1, 2]), decimal=4) @@ -113,7 +107,7 @@ def test_normalize_observation_vector_env(): np.mean([0.5]), # the mean of first observations [[0, 1]] decimal=4, ) - obs, reward, _, _, _ = envs.step(envs.action_space.sample()) + obs, reward, _, _ = envs.step(envs.action_space.sample()) assert_almost_equal( envs.obs_rms.mean, np.mean([1.0]), # the mean of first and second observations [[0, 1], [1, 2]] @@ -126,13 +120,13 @@ def test_normalize_return_vector_env(): envs = gym.vector.SyncVectorEnv(env_fns) envs = NormalizeReward(envs) obs = envs.reset() - obs, reward, _, _, _ = envs.step(envs.action_space.sample()) + obs, 
reward, _, _ = envs.step(envs.action_space.sample()) assert_almost_equal( envs.return_rms.mean, np.mean([1.5]), # the mean of first returns [[1, 2]] decimal=4, ) - obs, reward, _, _, _ = envs.step(envs.action_space.sample()) + obs, reward, _, _ = envs.step(envs.action_space.sample()) assert_almost_equal( envs.return_rms.mean, np.mean( diff --git a/tests/wrappers/test_order_enforcing.py b/tests/wrappers/test_order_enforcing.py index 47dd4597121..9b9290aec76 100644 --- a/tests/wrappers/test_order_enforcing.py +++ b/tests/wrappers/test_order_enforcing.py @@ -1,3 +1,6 @@ +import numpy as np +import pytest + import gym from gym.wrappers import OrderEnforcing diff --git a/tests/wrappers/test_pixel_observation.py b/tests/wrappers/test_pixel_observation.py index 480e2cec16c..95f094579cd 100644 --- a/tests/wrappers/test_pixel_observation.py +++ b/tests/wrappers/test_pixel_observation.py @@ -27,8 +27,8 @@ def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None): def step(self, action): del action observation = self.observation_space.sample() - reward, terminated, truncated, info = 0.0, False, False, {} - return observation, reward, terminated, truncated, info + reward, terminal, info = 0.0, False, {} + return observation, reward, terminal, info class FakeArrayObservationEnvironment(FakeEnvironment): diff --git a/tests/wrappers/test_record_episode_statistics.py b/tests/wrappers/test_record_episode_statistics.py index 2581480af50..d9633409eb3 100644 --- a/tests/wrappers/test_record_episode_statistics.py +++ b/tests/wrappers/test_record_episode_statistics.py @@ -1,3 +1,4 @@ +import numpy as np import pytest import gym @@ -7,7 +8,7 @@ @pytest.mark.parametrize("env_id", ["CartPole-v1", "Pendulum-v1"]) @pytest.mark.parametrize("deque_size", [2, 5]) def test_record_episode_statistics(env_id, deque_size): - env = gym.make(env_id, return_two_dones=True) + env = gym.make(env_id) env = RecordEpisodeStatistics(env, deque_size) for n in range(5): @@ -15,8 +16,8 @@ def test_record_episode_statistics(env_id, deque_size): assert env.episode_returns[0] == 0.0 assert env.episode_lengths[0] == 0 for t in range(env.spec.max_episode_steps): - _, _, terminated, truncated, info = env.step(env.action_space.sample()) - if terminated or truncated: + _, _, done, info = env.step(env.action_space.sample()) + if done: assert "episode" in info assert all([item in info["episode"] for item in ["r", "l", "t"]]) break @@ -25,7 +26,7 @@ def test_record_episode_statistics(env_id, deque_size): def test_record_episode_statistics_reset_info(): - env = gym.make("CartPole-v1", return_two_dones=True) + env = gym.make("CartPole-v1") env = RecordEpisodeStatistics(env) ob_space = env.observation_space obs = env.reset() @@ -40,12 +41,7 @@ def test_record_episode_statistics_reset_info(): ("num_envs", "asynchronous"), [(1, False), (1, True), (4, False), (4, True)] ) def test_record_episode_statistics_with_vectorenv(num_envs, asynchronous): - envs = gym.vector.make( - "CartPole-v1", - num_envs=num_envs, - asynchronous=asynchronous, - return_two_dones=True, - ) + envs = gym.vector.make("CartPole-v1", num_envs=num_envs, asynchronous=asynchronous) envs = RecordEpisodeStatistics(envs) max_episode_step = ( envs.env_fns[0]().spec.max_episode_steps @@ -54,9 +50,9 @@ def test_record_episode_statistics_with_vectorenv(num_envs, asynchronous): ) envs.reset() for _ in range(max_episode_step + 1): - _, _, terminateds, truncateds, infos = envs.step(envs.action_space.sample()) + _, _, dones, infos = envs.step(envs.action_space.sample()) for 
idx, info in enumerate(infos): - if terminateds[idx] or truncateds[idx]: + if dones[idx]: assert "episode" in info assert all([item in info["episode"] for item in ["r", "l", "t"]]) break diff --git a/tests/wrappers/test_record_video.py b/tests/wrappers/test_record_video.py index dd1d93077dc..0757c1bec54 100644 --- a/tests/wrappers/test_record_video.py +++ b/tests/wrappers/test_record_video.py @@ -1,19 +1,26 @@ import os import shutil +import numpy as np +import pytest + import gym -from gym.wrappers import capped_cubic_video_schedule +from gym.wrappers import ( + RecordEpisodeStatistics, + RecordVideo, + capped_cubic_video_schedule, +) def test_record_video_using_default_trigger(): - env = gym.make("CartPole-v1", return_two_dones=True) + env = gym.make("CartPole-v1") env = gym.wrappers.RecordVideo(env, "videos") env.reset() for _ in range(199): action = env.action_space.sample() - _, _, terminated, truncated, _ = env.step(action) - if terminated or truncated: + _, _, done, _ = env.step(action) + if done: env.reset() env.close() assert os.path.isdir("videos") @@ -25,7 +32,7 @@ def test_record_video_using_default_trigger(): def test_record_video_reset_return_info(): - env = gym.make("CartPole-v1", return_two_dones=True) + env = gym.make("CartPole-v1") env = gym.wrappers.RecordVideo(env, "videos", step_trigger=lambda x: x % 100 == 0) ob_space = env.observation_space obs, info = env.reset(return_info=True) @@ -35,7 +42,7 @@ def test_record_video_reset_return_info(): assert ob_space.contains(obs) assert isinstance(info, dict) - env = gym.make("CartPole-v1", return_two_dones=True) + env = gym.make("CartPole-v1") env = gym.wrappers.RecordVideo(env, "videos", step_trigger=lambda x: x % 100 == 0) ob_space = env.observation_space obs = env.reset(return_info=False) @@ -44,7 +51,7 @@ def test_record_video_reset_return_info(): shutil.rmtree("videos") assert ob_space.contains(obs) - env = gym.make("CartPole-v1", return_two_dones=True) + env = gym.make("CartPole-v1") env = gym.wrappers.RecordVideo(env, "videos", step_trigger=lambda x: x % 100 == 0) ob_space = env.observation_space obs = env.reset() @@ -55,14 +62,14 @@ def test_record_video_reset_return_info(): def test_record_video_step_trigger(): - env = gym.make("CartPole-v1", return_two_dones=True) + env = gym.make("CartPole-v1") env._max_episode_steps = 20 env = gym.wrappers.RecordVideo(env, "videos", step_trigger=lambda x: x % 100 == 0) env.reset() for _ in range(199): action = env.action_space.sample() - _, _, terminated, truncated, _ = env.step(action) - if terminated or truncated: + _, _, done, _ = env.step(action) + if done: env.reset() env.close() assert os.path.isdir("videos") @@ -73,7 +80,7 @@ def test_record_video_step_trigger(): def make_env(gym_id, seed): def thunk(): - env = gym.make(gym_id, return_two_dones=True) + env = gym.make(gym_id) env._max_episode_steps = 20 if seed == 1: env = gym.wrappers.RecordVideo( @@ -89,7 +96,7 @@ def test_record_video_within_vector(): envs = gym.wrappers.RecordEpisodeStatistics(envs) envs.reset() for i in range(199): - _, _, _, _, infos = envs.step(envs.action_space.sample()) + _, _, _, infos = envs.step(envs.action_space.sample()) for info in infos: if "episode" in info.keys(): print(f"episode_reward={info['episode']['r']}") diff --git a/tests/wrappers/test_rescale_action.py b/tests/wrappers/test_rescale_action.py index abade7c9705..6db5ad5fa75 100644 --- a/tests/wrappers/test_rescale_action.py +++ b/tests/wrappers/test_rescale_action.py @@ -6,13 +6,13 @@ def test_rescale_action(): - env = 
gym.make("CartPole-v1", return_two_dones=True) + env = gym.make("CartPole-v1") with pytest.raises(AssertionError): env = RescaleAction(env, -1, 1) del env - env = gym.make("Pendulum-v1", return_two_dones=True) - wrapped_env = RescaleAction(gym.make("Pendulum-v1", return_two_dones=True), -1, 1) + env = gym.make("Pendulum-v1") + wrapped_env = RescaleAction(gym.make("Pendulum-v1"), -1, 1) seed = 0 @@ -20,10 +20,10 @@ def test_rescale_action(): wrapped_obs = wrapped_env.reset(seed=seed) assert np.allclose(obs, wrapped_obs) - obs, reward, _, _, _ = env.step([1.5]) + obs, reward, _, _ = env.step([1.5]) with pytest.raises(AssertionError): wrapped_env.step([1.5]) - wrapped_obs, wrapped_reward, _, _, _ = wrapped_env.step([0.75]) + wrapped_obs, wrapped_reward, _, _ = wrapped_env.step([0.75]) assert np.allclose(obs, wrapped_obs) assert np.allclose(reward, wrapped_reward) diff --git a/tests/wrappers/test_time_aware_observation.py b/tests/wrappers/test_time_aware_observation.py index bdf803346e1..a996d608cdc 100644 --- a/tests/wrappers/test_time_aware_observation.py +++ b/tests/wrappers/test_time_aware_observation.py @@ -6,7 +6,7 @@ @pytest.mark.parametrize("env_id", ["CartPole-v1", "Pendulum-v1"]) def test_time_aware_observation(env_id): - env = gym.make(env_id, return_two_dones=True) + env = gym.make(env_id) wrapped_env = TimeAwareObservation(env) assert wrapped_env.observation_space.shape[0] == env.observation_space.shape[0] + 1 @@ -17,12 +17,12 @@ def test_time_aware_observation(env_id): assert wrapped_obs[-1] == 0.0 assert wrapped_obs.shape[0] == obs.shape[0] + 1 - wrapped_obs, _, _, _, _ = wrapped_env.step(env.action_space.sample()) + wrapped_obs, _, _, _ = wrapped_env.step(env.action_space.sample()) assert wrapped_env.t == 1.0 assert wrapped_obs[-1] == 1.0 assert wrapped_obs.shape[0] == obs.shape[0] + 1 - wrapped_obs, _, _, _, _ = wrapped_env.step(env.action_space.sample()) + wrapped_obs, _, _, _ = wrapped_env.step(env.action_space.sample()) assert wrapped_env.t == 2.0 assert wrapped_obs[-1] == 2.0 assert wrapped_obs.shape[0] == obs.shape[0] + 1 diff --git a/tests/wrappers/test_time_limit.py b/tests/wrappers/test_time_limit.py index fa7f70da430..32e6e5d2ad0 100644 --- a/tests/wrappers/test_time_limit.py +++ b/tests/wrappers/test_time_limit.py @@ -1,3 +1,6 @@ +import numpy as np +import pytest + import gym from gym.wrappers import TimeLimit diff --git a/tests/wrappers/test_transform_observation.py b/tests/wrappers/test_transform_observation.py index 0d108cb35d3..fc1076ae4f7 100644 --- a/tests/wrappers/test_transform_observation.py +++ b/tests/wrappers/test_transform_observation.py @@ -7,12 +7,10 @@ @pytest.mark.parametrize("env_id", ["CartPole-v1", "Pendulum-v1"]) def test_transform_observation(env_id): - def affine_transform(x): - return 3 * x + 2 - - env = gym.make(env_id, return_two_dones=True) + affine_transform = lambda x: 3 * x + 2 + env = gym.make(env_id) wrapped_env = TransformObservation( - gym.make(env_id, return_two_dones=True), lambda obs: affine_transform(obs) + gym.make(env_id), lambda obs: affine_transform(obs) ) obs = env.reset(seed=0) @@ -20,15 +18,8 @@ def affine_transform(x): assert np.allclose(wrapped_obs, affine_transform(obs)) action = env.action_space.sample() - obs, reward, terminated, truncated, _ = env.step(action) - ( - wrapped_obs, - wrapped_reward, - wrapped_terminated, - wrapped_truncated, - _, - ) = wrapped_env.step(action) + obs, reward, done, _ = env.step(action) + wrapped_obs, wrapped_reward, wrapped_done, _ = wrapped_env.step(action) assert 
np.allclose(wrapped_obs, affine_transform(obs)) assert np.allclose(wrapped_reward, reward) - assert wrapped_terminated == terminated - assert wrapped_truncated == truncated + assert wrapped_done == done diff --git a/tests/wrappers/test_transform_reward.py b/tests/wrappers/test_transform_reward.py index 0e8cb32bdd4..c7badb7a2d0 100644 --- a/tests/wrappers/test_transform_reward.py +++ b/tests/wrappers/test_transform_reward.py @@ -10,17 +10,15 @@ def test_transform_reward(env_id): # use case #1: scale scales = [0.1, 200] for scale in scales: - env = gym.make(env_id, return_two_dones=True) - wrapped_env = TransformReward( - gym.make(env_id, return_two_dones=True), lambda r: scale * r - ) + env = gym.make(env_id) + wrapped_env = TransformReward(gym.make(env_id), lambda r: scale * r) action = env.action_space.sample() env.reset(seed=0) wrapped_env.reset(seed=0) - _, reward, _, _, _ = env.step(action) - _, wrapped_reward, _, _, _ = wrapped_env.step(action) + _, reward, _, _ = env.step(action) + _, wrapped_reward, _, _ = wrapped_env.step(action) assert wrapped_reward == scale * reward del env, wrapped_env @@ -28,35 +26,31 @@ def test_transform_reward(env_id): # use case #2: clip min_r = -0.0005 max_r = 0.0002 - env = gym.make(env_id, return_two_dones=True) - wrapped_env = TransformReward( - gym.make(env_id, return_two_dones=True), lambda r: np.clip(r, min_r, max_r) - ) + env = gym.make(env_id) + wrapped_env = TransformReward(gym.make(env_id), lambda r: np.clip(r, min_r, max_r)) action = env.action_space.sample() env.reset(seed=0) wrapped_env.reset(seed=0) - _, reward, _, _, _ = env.step(action) - _, wrapped_reward, _, _, _ = wrapped_env.step(action) + _, reward, _, _ = env.step(action) + _, wrapped_reward, _, _ = wrapped_env.step(action) assert abs(wrapped_reward) < abs(reward) assert wrapped_reward == -0.0005 or wrapped_reward == 0.0002 del env, wrapped_env # use case #3: sign - env = gym.make(env_id, return_two_dones=True) - wrapped_env = TransformReward( - gym.make(env_id, return_two_dones=True), lambda r: np.sign(r) - ) + env = gym.make(env_id) + wrapped_env = TransformReward(gym.make(env_id), lambda r: np.sign(r)) env.reset(seed=0) wrapped_env.reset(seed=0) for _ in range(1000): action = env.action_space.sample() - _, wrapped_reward, terminated, truncated, _ = wrapped_env.step(action) + _, wrapped_reward, done, _ = wrapped_env.step(action) assert wrapped_reward in [-1.0, 0.0, 1.0] - if terminated or truncated: + if done: break del env, wrapped_env From 794737bbcb82db7068959a84ac1ce0b9f7847304 Mon Sep 17 00:00:00 2001 From: arjun_kg Date: Wed, 4 May 2022 14:39:47 +0530 Subject: [PATCH 10/37] fix circular import --- gym/wrappers/__init__.py | 2 +- gym/wrappers/frame_stack.py | 2 +- gym/wrappers/normalize.py | 2 +- gym/wrappers/record_episode_statistics.py | 2 +- gym/wrappers/record_video.py | 2 +- gym/wrappers/time_aware_observation.py | 2 +- 6 files changed, 6 insertions(+), 6 deletions(-) diff --git a/gym/wrappers/__init__.py b/gym/wrappers/__init__.py index 549af6e2792..d0e70218941 100644 --- a/gym/wrappers/__init__.py +++ b/gym/wrappers/__init__.py @@ -11,7 +11,7 @@ from gym.wrappers.record_video import RecordVideo, capped_cubic_video_schedule from gym.wrappers.rescale_action import RescaleAction from gym.wrappers.resize_observation import ResizeObservation -from gym.wrappers.step_compatibility import StepCompatibility, step_api_compatibility +from gym.wrappers.step_compatibility import StepCompatibility from gym.wrappers.time_aware_observation import TimeAwareObservation from 
gym.wrappers.time_limit import TimeLimit from gym.wrappers.transform_observation import TransformObservation diff --git a/gym/wrappers/frame_stack.py b/gym/wrappers/frame_stack.py index 5281dd50863..b6ff76571a6 100644 --- a/gym/wrappers/frame_stack.py +++ b/gym/wrappers/frame_stack.py @@ -4,7 +4,7 @@ from gym import ObservationWrapper from gym.spaces import Box -from gym.wrappers import step_api_compatibility +from gym.wrappers.step_compatibility import step_api_compatibility class LazyFrames: diff --git a/gym/wrappers/normalize.py b/gym/wrappers/normalize.py index 2b420308f34..29a932abff5 100644 --- a/gym/wrappers/normalize.py +++ b/gym/wrappers/normalize.py @@ -1,7 +1,7 @@ import numpy as np import gym -from gym.wrappers import step_api_compatibility +from gym.wrappers.step_compatibility import step_api_compatibility # taken from https://github.com/openai/baselines/blob/master/baselines/common/vec_env/vec_normalize.py diff --git a/gym/wrappers/record_episode_statistics.py b/gym/wrappers/record_episode_statistics.py index 3b18e866edb..b793a5b2743 100644 --- a/gym/wrappers/record_episode_statistics.py +++ b/gym/wrappers/record_episode_statistics.py @@ -4,7 +4,7 @@ import numpy as np import gym -from gym.wrappers import step_api_compatibility +from gym.wrappers.step_compatibility import step_api_compatibility @step_api_compatibility diff --git a/gym/wrappers/record_video.py b/gym/wrappers/record_video.py index 4d9b53bd637..7a746bb47a7 100644 --- a/gym/wrappers/record_video.py +++ b/gym/wrappers/record_video.py @@ -3,7 +3,7 @@ import gym from gym import logger -from gym.wrappers import step_api_compatibility +from gym.wrappers.step_compatibility import step_api_compatibility from gym.wrappers.monitoring import video_recorder diff --git a/gym/wrappers/time_aware_observation.py b/gym/wrappers/time_aware_observation.py index c623be82783..837bd80894b 100644 --- a/gym/wrappers/time_aware_observation.py +++ b/gym/wrappers/time_aware_observation.py @@ -2,7 +2,7 @@ from gym import ObservationWrapper from gym.spaces import Box -from gym.wrappers import step_api_compatibility +from gym.wrappers.step_compatibility import step_api_compatibility @step_api_compatibility From f89e5da1bab010806acfc8fc5e78c2193185573d Mon Sep 17 00:00:00 2001 From: arjun_kg Date: Wed, 4 May 2022 15:20:52 +0530 Subject: [PATCH 11/37] merge tests with master --- tests/envs/spec_list.py | 8 +- tests/envs/test_action_dim_check.py | 2 - tests/envs/test_atari_legacy_env_specs.py | 8 +- tests/envs/test_envs.py | 9 +- tests/envs/test_frozenlake_dfs.py | 3 - tests/envs/test_registration.py | 123 +++++------------- tests/spaces/test_spaces.py | 8 +- tests/utils/test_play.py | 16 +-- tests/vector/test_async_vector_env.py | 10 +- tests/vector/test_shared_memory.py | 15 ++- tests/vector/test_spaces.py | 69 +++++++++- tests/vector/test_sync_vector_env.py | 33 ++++- tests/vector/test_vector_env_wrapper.py | 1 - tests/vector/utils.py | 10 +- tests/wrappers/nested_dict_test.py | 2 +- tests/wrappers/test_autoreset.py | 32 ++--- tests/wrappers/test_clip_action.py | 5 +- tests/wrappers/test_frame_stack.py | 8 +- tests/wrappers/test_order_enforcing.py | 3 - .../test_record_episode_statistics.py | 1 - tests/wrappers/test_record_video.py | 9 +- tests/wrappers/test_time_limit.py | 3 - tests/wrappers/test_transform_observation.py | 4 +- 23 files changed, 208 insertions(+), 174 deletions(-) diff --git a/tests/envs/spec_list.py b/tests/envs/spec_list.py index 11c816f6bd1..a0f192b7c2a 100644 --- a/tests/envs/spec_list.py +++ 
b/tests/envs/spec_list.py @@ -11,7 +11,7 @@ skip_mujoco = not (os.environ.get("MUJOCO_KEY")) if not skip_mujoco: try: - import mujoco_py + import mujoco_py # noqa:F401 except ImportError: skip_mujoco = True @@ -24,12 +24,12 @@ def should_skip_env_spec_for_tests(spec): if skip_mujoco and ep.startswith("gym.envs.mujoco"): return True try: - import gym.envs.atari + import gym.envs.atari # noqa:F401 except ImportError: if ep.startswith("gym.envs.atari"): return True try: - import Box2D + import Box2D # noqa:F401 except ImportError: if ep.startswith("gym.envs.box2d"): return True @@ -50,6 +50,6 @@ def should_skip_env_spec_for_tests(spec): spec_list = [ spec - for spec in sorted(envs.registry.all(), key=lambda x: x.id) + for spec in sorted(envs.registry.values(), key=lambda x: x.id) if spec.entry_point is not None and not should_skip_env_spec_for_tests(spec) ] diff --git a/tests/envs/test_action_dim_check.py b/tests/envs/test_action_dim_check.py index 448847cd6ec..a6479240858 100644 --- a/tests/envs/test_action_dim_check.py +++ b/tests/envs/test_action_dim_check.py @@ -1,5 +1,3 @@ -import pickle - import pytest from gym import envs diff --git a/tests/envs/test_atari_legacy_env_specs.py b/tests/envs/test_atari_legacy_env_specs.py index 37e99245c8d..5a1c406e4cd 100644 --- a/tests/envs/test_atari_legacy_env_specs.py +++ b/tests/envs/test_atari_legacy_env_specs.py @@ -1,11 +1,11 @@ -import pytest - -pytest.importorskip("gym.envs.atari") - from itertools import product +import pytest + from gym.envs.registration import registry +pytest.importorskip("gym.envs.atari") + def test_ale_legacy_env_specs(): versions = ["-v0", "-v4"] diff --git a/tests/envs/test_envs.py b/tests/envs/test_envs.py index 23de61b98b2..b0a33da2ed4 100644 --- a/tests/envs/test_envs.py +++ b/tests/envs/test_envs.py @@ -61,7 +61,7 @@ def test_env(spec): @pytest.mark.parametrize("spec", spec_list, ids=[spec.id for spec in spec_list]) def test_reset_info(spec): - with pytest.warns(None) as warnings: + with pytest.warns(None): env = spec.make() ob_space = env.observation_space @@ -78,13 +78,12 @@ def test_reset_info(spec): # Run a longer rollout on some environments def test_random_rollout(): for env in [envs.make("CartPole-v1"), envs.make("FrozenLake-v1")]: - agent = lambda ob: env.action_space.sample() ob = env.reset() for _ in range(10): assert env.observation_space.contains(ob) - a = agent(ob) - assert env.action_space.contains(a) - (ob, _reward, done, _info) = env.step(a) + action = env.action_space.sample() + assert env.action_space.contains(action) + (ob, _reward, done, _info) = env.step(action) if done: break env.close() diff --git a/tests/envs/test_frozenlake_dfs.py b/tests/envs/test_frozenlake_dfs.py index 3cfeddf2b73..b620cbcb042 100644 --- a/tests/envs/test_frozenlake_dfs.py +++ b/tests/envs/test_frozenlake_dfs.py @@ -1,6 +1,3 @@ -import numpy as np -import pytest - from gym.envs.toy_text.frozen_lake import generate_random_map diff --git a/tests/envs/test_registration.py b/tests/envs/test_registration.py index 125b7cef591..707f0ee04f4 100644 --- a/tests/envs/test_registration.py +++ b/tests/envs/test_registration.py @@ -2,9 +2,8 @@ import gym from gym import envs, error -from gym.envs import registration +from gym.envs import register, spec from gym.envs.classic_control import cartpole -from gym.envs.registration import EnvSpec, EnvSpecTree class ArgumentEnv(gym.Env): @@ -55,8 +54,8 @@ def register_some_envs(): for version in versions: env_id = f"{namespace}/{versioned_name}-v{version}" - del 
gym.envs.registry.env_specs[env_id] - del gym.envs.registry.env_specs[f"{namespace}/{unversioned_name}"] + del gym.envs.registry[env_id] + del gym.envs.registry[f"{namespace}/{unversioned_name}"] def test_make(): @@ -83,10 +82,15 @@ def test_make(): ], ) def test_register(env_id, namespace, name, version): - envs.register(env_id) + register(env_id) assert gym.envs.spec(env_id).id == env_id - assert version in gym.envs.registry.env_specs.tree[namespace][name].keys() - del gym.envs.registry.env_specs[env_id] + full_name = f"{name}" + if namespace: + full_name = f"{namespace}/{full_name}" + if version is not None: + full_name = f"{full_name}-v{version}" + assert full_name in gym.envs.registry.keys() + del gym.envs.registry[env_id] @pytest.mark.parametrize( @@ -99,7 +103,7 @@ def test_register(env_id, namespace, name, version): ) def test_register_error(env_id): with pytest.raises(error.Error, match="Malformed environment ID"): - envs.register(env_id) + register(env_id) @pytest.mark.parametrize( @@ -188,27 +192,23 @@ def test_spec_with_kwargs(): def test_missing_lookup(): - registry = registration.EnvRegistry() - registry.register(id="Test-v0", entry_point=None) - registry.register(id="Test-v15", entry_point=None) - registry.register(id="Test-v9", entry_point=None) - registry.register(id="Other-v100", entry_point=None) - try: - registry.spec("Test-v1") # must match an env name but not the version above - except error.DeprecatedEnv: - pass - else: - assert False + register(id="Test1-v0", entry_point=None) + register(id="Test1-v15", entry_point=None) + register(id="Test1-v9", entry_point=None) + register(id="Other1-v100", entry_point=None) + + with pytest.raises(error.DeprecatedEnv): + spec("Test1-v1") try: - registry.spec("Test-v1000") + spec("Test1-v1000") except error.UnregisteredEnv: pass else: assert False try: - registry.spec("Unknown-v1") + spec("Unknown1-v1") except error.UnregisteredEnv: pass else: @@ -216,9 +216,8 @@ def test_missing_lookup(): def test_malformed_lookup(): - registry = registration.EnvRegistry() try: - registry.spec("“Breakout-v0”") + spec("“Breakout-v0”") except error.Error as e: assert "Malformed environment ID" in f"{e}", f"Unexpected message: {e}" else: @@ -226,99 +225,47 @@ def test_malformed_lookup(): def test_versioned_lookups(): - registry = registration.EnvRegistry() - registry.register("test/Test-v5") + register("test/Test2-v5") with pytest.raises(error.VersionNotFound): - registry.spec("test/Test-v9") + spec("test/Test2-v9") with pytest.raises(error.DeprecatedEnv): - registry.spec("test/Test-v4") + spec("test/Test2-v4") - assert registry.spec("test/Test-v5") + assert spec("test/Test2-v5") def test_default_lookups(): - registry = registration.EnvRegistry() - registry.register("test/Test") + register("test/Test3") with pytest.raises(error.DeprecatedEnv): - registry.spec("test/Test-v0") + spec("test/Test3-v0") # Lookup default - registry.spec("test/Test") - - -def test_env_spec_tree(): - spec_tree = EnvSpecTree() - - # Add with namespace - spec = EnvSpec("test/Test-v0") - spec_tree["test/Test-v0"] = spec - assert spec_tree.tree.keys() == {"test"} - assert spec_tree.tree["test"].keys() == {"Test"} - assert spec_tree.tree["test"]["Test"].keys() == {0} - assert spec_tree.tree["test"]["Test"][0] == spec - assert spec_tree["test/Test-v0"] == spec - - # Add without namespace - spec = EnvSpec("Test-v0") - spec_tree["Test-v0"] = spec - assert spec_tree.tree.keys() == {"test", None} - assert spec_tree.tree[None].keys() == {"Test"} - assert 
spec_tree.tree[None]["Test"].keys() == {0} - assert spec_tree.tree[None]["Test"][0] == spec - - # Delete last version deletes entire subtree - del spec_tree["test/Test-v0"] - assert spec_tree.tree.keys() == {None} - - # Append second version for same name - spec_tree["Test-v1"] = EnvSpec("Test-v1") - assert spec_tree.tree.keys() == {None} - assert spec_tree.tree[None].keys() == {"Test"} - assert spec_tree.tree[None]["Test"].keys() == {0, 1} - - # Deleting one version leaves other - del spec_tree["Test-v0"] - assert spec_tree.tree.keys() == {None} - assert spec_tree.tree[None].keys() == {"Test"} - assert spec_tree.tree[None]["Test"].keys() == {1} - - # Add without version - myenv = "MyAwesomeEnv" - spec = EnvSpec(myenv) - spec_tree[myenv] = spec - assert spec_tree.tree.keys() == {None} - assert myenv in spec_tree.tree[None].keys() - assert spec_tree.tree[None][myenv].keys() == {None} - assert spec_tree.tree[None][myenv][None] == spec - assert spec_tree.__repr__() == "├──Test: [ v1 ]\n" + f"└──{myenv}: [ ]\n" + spec("test/Test3") def test_register_versioned_unversioned(): # Register versioned then unversioned versioned_env = "Test/MyEnv-v0" - envs.register(versioned_env) + register(versioned_env) assert gym.envs.spec(versioned_env).id == versioned_env unversioned_env = "Test/MyEnv" with pytest.raises(error.RegistrationError): - envs.register(unversioned_env) + register(unversioned_env) # Clean everything - del gym.envs.registry.env_specs[versioned_env] + del gym.envs.registry[versioned_env] # Register unversioned then versioned - with pytest.warns(UserWarning): - envs.register(unversioned_env) + register(unversioned_env) assert gym.envs.spec(unversioned_env).id == unversioned_env with pytest.raises(error.RegistrationError): - envs.register(versioned_env) + register(versioned_env) # Clean everything - envs_list = [versioned_env, unversioned_env] - for env in envs_list: - del gym.envs.registry.env_specs[env] + del gym.envs.registry[unversioned_env] def test_return_latest_versioned_env(register_some_envs): diff --git a/tests/spaces/test_spaces.py b/tests/spaces/test_spaces.py index 975fffa2af2..85739028eb6 100644 --- a/tests/spaces/test_spaces.py +++ b/tests/spaces/test_spaces.py @@ -569,20 +569,20 @@ def test_infinite_space(space): # but floats are unbounded for infinite if np.any(space.high != 0): assert ( - space.is_bounded("above") == False + space.is_bounded("above") is False ), "inf upper bound supposed to be unbounded" else: assert ( - space.is_bounded("above") == True + space.is_bounded("above") is True ), "non-inf upper bound supposed to be bounded" if np.any(space.low != 0): assert ( - space.is_bounded("below") == False + space.is_bounded("below") is False ), "inf lower bound supposed to be unbounded" else: assert ( - space.is_bounded("below") == True + space.is_bounded("below") is True ), "non-inf lower bound supposed to be bounded" # check for dtype diff --git a/tests/utils/test_play.py b/tests/utils/test_play.py index cc94f4a94e6..02acce160de 100644 --- a/tests/utils/test_play.py +++ b/tests/utils/test_play.py @@ -23,8 +23,8 @@ class DummyEnvSpec: class DummyPlayEnv(gym.Env): def step(self, action): obs = np.zeros((1, 1)) - rew, terminated, truncated, info = 1, False, False, {} - return obs, rew, terminated, truncated, info + rew, done, info = 1, False, {} + return obs, rew, done, info def reset(self, seed=None): ... 
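For reference, the two calling conventions this patch series converts between look like this from the caller's side. A minimal sketch only, not an excerpt from the patch: the env id is an example, and it assumes the `new_step_api` flag that the series wires through `gym.make`.

import gym

# Old step API (the default): step() returns a single `done` bool.
env = gym.make("CartPole-v1")
env.reset()
obs, reward, done, info = env.step(env.action_space.sample())

# New step API (opt-in): step() returns separate `terminated` and `truncated` bools.
env = gym.make("CartPole-v1", new_step_api=True)
env.reset()
obs, reward, terminated, truncated, info = env.step(env.action_space.sample())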
@@ -39,9 +39,9 @@ def __init__(self, callback: Callable): self.cumulative_reward = 0 self.last_observation = None - def callback(self, obs_t, obs_tp1, action, rew, terminated, truncated, info): - _, obs_tp1, _, rew, _, _, _ = self.data_callback( - obs_t, obs_tp1, action, rew, terminated, truncated, info + def callback(self, obs_t, obs_tp1, action, rew, done, info): + _, obs_tp1, _, rew, _, _ = self.data_callback( + obs_t, obs_tp1, action, rew, done, info ) self.cumulative_reward += rew self.last_observation = obs_tp1 @@ -156,7 +156,7 @@ def test_play_loop_real_env(): ] keydown_events = [k for k in callback_events if k.type == KEYDOWN] - def callback(obs_t, obs_tp1, action, rew, terminated, truncated, info): + def callback(obs_t, obs_tp1, action, rew, done, info): pygame_event = callback_events.pop(0) event.post(pygame_event) @@ -166,7 +166,7 @@ def callback(obs_t, obs_tp1, action, rew, terminated, truncated, info): pygame_event = callback_events.pop(0) event.post(pygame_event) - return obs_t, obs_tp1, action, rew, terminated, truncated, info + return obs_t, obs_tp1, action, rew, done, info env = gym.make(ENV) env.reset(seed=SEED) @@ -177,7 +177,7 @@ def callback(obs_t, obs_tp1, action, rew, terminated, truncated, info): env.step(0) for e in keydown_events: action = keys_to_action[(e.key,)] - obs, _, _, _, _ = env.step(action) + obs, _, _, _ = env.step(action) env_play = gym.make(ENV) status = PlayStatus(callback) diff --git a/tests/vector/test_async_vector_env.py b/tests/vector/test_async_vector_env.py index 41104799019..a02f0617b53 100644 --- a/tests/vector/test_async_vector_env.py +++ b/tests/vector/test_async_vector_env.py @@ -167,7 +167,7 @@ def test_reset_timeout_async_vector_env(shared_memory): try: env = AsyncVectorEnv(env_fns, shared_memory=shared_memory) env.reset_async() - observations = env.reset_wait(timeout=0.1) + env.reset_wait(timeout=0.1) finally: env.close(terminate=True) @@ -178,7 +178,7 @@ def test_step_timeout_async_vector_env(shared_memory): with pytest.raises(TimeoutError): try: env = AsyncVectorEnv(env_fns, shared_memory=shared_memory) - observations = env.reset() + env.reset() env.step_async([0.1, 0.1, 0.3, 0.1]) observations, rewards, dones, _ = env.step_wait(timeout=0.1) finally: @@ -192,7 +192,7 @@ def test_reset_out_of_order_async_vector_env(shared_memory): with pytest.raises(NoAsyncCallError): try: env = AsyncVectorEnv(env_fns, shared_memory=shared_memory) - observations = env.reset_wait() + env.reset_wait() except NoAsyncCallError as exception: assert exception.name == "reset" raise @@ -203,7 +203,7 @@ def test_reset_out_of_order_async_vector_env(shared_memory): try: env = AsyncVectorEnv(env_fns, shared_memory=shared_memory) actions = env.action_space.sample() - observations = env.reset() + env.reset() env.step_async(actions) env.reset_async() except NoAsyncCallError as exception: @@ -248,7 +248,7 @@ def test_already_closed_async_vector_env(shared_memory): with pytest.raises(ClosedEnvironmentError): env = AsyncVectorEnv(env_fns, shared_memory=shared_memory) env.close() - observations = env.reset() + env.reset() @pytest.mark.parametrize("shared_memory", [True, False]) diff --git a/tests/vector/test_shared_memory.py b/tests/vector/test_shared_memory.py index 120d109dd9f..5d18452a881 100644 --- a/tests/vector/test_shared_memory.py +++ b/tests/vector/test_shared_memory.py @@ -65,8 +65,7 @@ def assert_nested_type(lhs, rhs, n): # Assert the length of the array assert len(lhs[:]) == n * len(rhs[:]) # Assert the data type - assert type(lhs[0]) == type(rhs[0]) # noqa: 
E721 - + assert isinstance(lhs[0], type(rhs[0])) else: raise TypeError(f"Got unknown type `{type(lhs)}`.") @@ -83,7 +82,7 @@ def assert_nested_type(lhs, rhs, n): def test_create_shared_memory_custom_space(n, ctx, space): ctx = mp if (ctx is None) else mp.get_context(ctx) with pytest.raises(CustomSpaceError): - shared_memory = create_shared_memory(space, n=n, ctx=ctx) + create_shared_memory(space, n=n, ctx=ctx) @pytest.mark.parametrize( @@ -124,6 +123,10 @@ def write(i, shared_memory, sample): assert_nested_equal(shared_memory_n8, samples) +def _process_write(space, i, shared_memory, sample): + write_to_shared_memory(space, i, sample, shared_memory) + + @pytest.mark.parametrize( "space", spaces, ids=[space.__class__.__name__ for space in spaces] ) @@ -153,15 +156,13 @@ def assert_nested_equal(lhs, rhs, space, n): else: raise TypeError(f"Got unknown type `{type(space)}`") - def write(i, shared_memory, sample): - write_to_shared_memory(space, i, sample, shared_memory) - shared_memory_n8 = create_shared_memory(space, n=8) memory_view_n8 = read_from_shared_memory(space, shared_memory_n8, n=8) samples = [space.sample() for _ in range(8)] processes = [ - Process(target=write, args=(i, shared_memory_n8, samples[i])) for i in range(8) + Process(target=_process_write, args=(space, i, shared_memory_n8, samples[i])) + for i in range(8) ] for process in processes: diff --git a/tests/vector/test_spaces.py b/tests/vector/test_spaces.py index d01d5a45ad9..fa7aee0b4ff 100644 --- a/tests/vector/test_spaces.py +++ b/tests/vector/test_spaces.py @@ -1,9 +1,13 @@ +import copy + import numpy as np import pytest +from numpy.testing import assert_array_equal +from gym import Space from gym.spaces import Box, Dict, MultiDiscrete, Tuple from gym.vector.utils.spaces import batch_space, iterate -from tests.vector.utils import CustomSpace, custom_spaces, spaces +from tests.vector.utils import CustomSpace, assert_rng_equal, custom_spaces, spaces expected_batch_spaces_4 = [ Box(low=-1.0, high=1.0, shape=(4,), dtype=np.float64), @@ -129,3 +133,66 @@ def test_iterate_custom_space(space, batch_space): for i, item in enumerate(iterator): assert item in space assert i == 3 + + +@pytest.mark.parametrize( + "space", spaces, ids=[space.__class__.__name__ for space in spaces] ) +@pytest.mark.parametrize("n", [4, 5], ids=[f"n={n}" for n in [4, 5]]) +@pytest.mark.parametrize( + "base_seed", [123, 456], ids=[f"seed={base_seed}" for base_seed in [123, 456]] +) +def test_rng_different_at_each_index(space: Space, n: int, base_seed: int): + """ + Tests that the rng values produced at each index are different, + to check that the rng is not simply copied for each subspace + """ + space.seed(base_seed) + + batched_space = batch_space(space, n) + assert space.np_random is not batched_space.np_random + assert_rng_equal(space.np_random, batched_space.np_random) + + batched_sample = batched_space.sample() + sample = list(iterate(batched_space, batched_sample)) + assert not all(np.all(element == sample[0]) for element in sample), sample + + +@pytest.mark.parametrize( + "space", spaces, ids=[space.__class__.__name__ for space in spaces] +) +@pytest.mark.parametrize("n", [1, 2, 5], ids=[f"n={n}" for n in [1, 2, 5]]) +@pytest.mark.parametrize( + "base_seed", [123, 456], ids=[f"seed={base_seed}" for base_seed in [123, 456]] +) +def test_deterministic(space: Space, n: int, base_seed: int): + """Tests the batched spaces are deterministic by using a copied version""" + # Copy the spaces and check that the np_random are not reference equal + space_a = space +
space_a.seed(base_seed) + space_b = copy.deepcopy(space_a) + assert_rng_equal(space_a.np_random, space_b.np_random) + assert space_a.np_random is not space_b.np_random + + # Batch the spaces and check that the np_random are not reference equal + space_a_batched = batch_space(space_a, n) + space_b_batched = batch_space(space_b, n) + assert_rng_equal(space_a_batched.np_random, space_b_batched.np_random) + assert space_a_batched.np_random is not space_b_batched.np_random + # Check that the batched space is not reference equal to the origin spaces + assert space_a.np_random is not space_a_batched.np_random + + # Check that batched space a and b random number generator are not affected by the original space + space_a.sample() + space_a_batched_sample = space_a_batched.sample() + space_b_batched_sample = space_b_batched.sample() + for a_sample, b_sample in zip( + iterate(space_a_batched, space_a_batched_sample), + iterate(space_b_batched, space_b_batched_sample), + ): + if isinstance(a_sample, tuple): + assert len(a_sample) == len(b_sample) + for a_subsample, b_subsample in zip(a_sample, b_sample): + assert_array_equal(a_subsample, b_subsample) + else: + assert_array_equal(a_sample, b_sample) diff --git a/tests/vector/test_sync_vector_env.py b/tests/vector/test_sync_vector_env.py index 623803238ce..6407327e87f 100644 --- a/tests/vector/test_sync_vector_env.py +++ b/tests/vector/test_sync_vector_env.py @@ -1,9 +1,16 @@ import numpy as np import pytest +from gym.envs.registration import EnvSpec from gym.spaces import Box, Discrete, MultiDiscrete, Tuple from gym.vector.sync_vector_env import SyncVectorEnv -from tests.vector.utils import CustomSpace, make_custom_space_env, make_env +from tests.envs.spec_list import spec_list +from tests.vector.utils import ( + CustomSpace, + assert_rng_equal, + make_custom_space_env, + make_env, +) def test_create_sync_vector_env(): @@ -167,3 +174,27 @@ def test_custom_space_sync_vector_env(): "step(action-5)", "step(action-7)", ) + + +def test_sync_vector_env_seed(): + env = make_env("BipedalWalker-v3", seed=123)() + sync_vector_env = SyncVectorEnv([make_env("BipedalWalker-v3", seed=123)]) + + assert_rng_equal(env.action_space.np_random, sync_vector_env.action_space.np_random) + for _ in range(100): + env_action = env.action_space.sample() + vector_action = sync_vector_env.action_space.sample() + assert np.all(env_action == vector_action) + + +@pytest.mark.parametrize("spec", spec_list, ids=[spec.id for spec in spec_list]) +def test_sync_vector_determinism(spec: EnvSpec, seed: int = 123, n: int = 3): + """Check that for all environments, the sync vector envs produce the same action samples using the same seeds""" + env_1 = SyncVectorEnv([make_env(spec.id, seed=seed) for _ in range(n)]) + env_2 = SyncVectorEnv([make_env(spec.id, seed=seed) for _ in range(n)]) + assert_rng_equal(env_1.action_space.np_random, env_2.action_space.np_random) + + for _ in range(100): + env_1_samples = env_1.action_space.sample() + env_2_samples = env_2.action_space.sample() + assert np.all(env_1_samples == env_2_samples) diff --git a/tests/vector/test_vector_env_wrapper.py b/tests/vector/test_vector_env_wrapper.py index 4c8d165d175..156eaa47f64 100644 --- a/tests/vector/test_vector_env_wrapper.py +++ b/tests/vector/test_vector_env_wrapper.py @@ -1,4 +1,3 @@ -import gym from gym.vector import VectorEnvWrapper, make diff --git a/tests/vector/utils.py b/tests/vector/utils.py index 0eadb672642..b500163ae5d 100644 --- a/tests/vector/utils.py +++ b/tests/vector/utils.py @@ -5,6 +5,7 @@ import
gym from gym.spaces import Box, Dict, Discrete, MultiBinary, MultiDiscrete, Tuple +from gym.utils.seeding import RandomNumberGenerator spaces = [ Box(low=np.array(-1.0), high=np.array(1.0), dtype=np.float64), @@ -75,10 +76,10 @@ class CustomSpace(gym.Space): """Minimal custom observation space.""" def sample(self): - return "sample" + return self.np_random.integers(0, 10, ()) def contains(self, x): - return isinstance(x, str) + return 0 <= x <= 10 def __eq__(self, other): return isinstance(other, CustomSpace) @@ -109,6 +110,7 @@ def step(self, action): def make_env(env_name, seed): def _make(): env = gym.make(env_name) + env.action_space.seed(seed) env.reset(seed=seed) return env @@ -131,3 +133,7 @@ def _make(): return env return _make + + +def assert_rng_equal(rng_1: RandomNumberGenerator, rng_2: RandomNumberGenerator): + assert rng_1.bit_generator.state == rng_2.bit_generator.state diff --git a/tests/wrappers/nested_dict_test.py b/tests/wrappers/nested_dict_test.py index bde47054137..87899724e32 100644 --- a/tests/wrappers/nested_dict_test.py +++ b/tests/wrappers/nested_dict_test.py @@ -5,7 +5,7 @@ import pytest import gym -from gym.spaces import Box, Dict, Discrete, Tuple +from gym.spaces import Box, Dict, Tuple from gym.wrappers import FilterObservation, FlattenObservation diff --git a/tests/wrappers/test_autoreset.py b/tests/wrappers/test_autoreset.py index 76c035f87dd..39b47c756a4 100644 --- a/tests/wrappers/test_autoreset.py +++ b/tests/wrappers/test_autoreset.py @@ -1,4 +1,3 @@ -import types from typing import Optional from unittest.mock import MagicMock @@ -23,7 +22,9 @@ class DummyResetEnv(gym.Env): metadata = {} def __init__(self): - self.action_space = gym.spaces.Box(low=np.array([-1.0]), high=np.array([1.0])) + self.action_space = gym.spaces.Box( + low=np.array([-1.0]), high=np.array([1.0]), dtype=np.float64 + ) self.observation_space = gym.spaces.Box( low=np.array([-1.0]), high=np.array([1.0]) ) @@ -63,6 +64,7 @@ def test_autoreset_reset_info(): obs, info = env.reset(return_info=True) assert ob_space.contains(obs) assert isinstance(info, dict) + env.close() @pytest.mark.parametrize("spec", spec_list, ids=[spec.id for spec in spec_list]) @@ -74,12 +76,10 @@ def test_make_autoreset_true(spec): Note: This test assumes that all first-party environments will terminate in a finite amount of time with random actions, which is true as of the time of adding this test. 
""" - env = None - with pytest.warns(None) as warnings: + with pytest.warns(None): env = spec.make(autoreset=True) - ob_space = env.observation_space - obs = env.reset(seed=0) + env.reset(seed=0) env.action_space.seed(0) env.unwrapped.reset = MagicMock(side_effect=env.unwrapped.reset) @@ -90,22 +90,23 @@ def test_make_autoreset_true(spec): assert isinstance(env, AutoResetWrapper) assert env.unwrapped.reset.called + env.close() @pytest.mark.parametrize("spec", spec_list, ids=[spec.id for spec in spec_list]) def test_make_autoreset_false(spec): - env = None - with pytest.warns(None) as warnings: + with pytest.warns(None): env = spec.make(autoreset=False) assert not isinstance(env, AutoResetWrapper) + env.close() @pytest.mark.parametrize("spec", spec_list, ids=[spec.id for spec in spec_list]) def test_make_autoreset_default_false(spec): - env = None - with pytest.warns(None) as warnings: + with pytest.warns(None): env = spec.make() assert not isinstance(env, AutoResetWrapper) + env.close() def test_autoreset_autoreset(): @@ -118,16 +119,16 @@ def test_autoreset_autoreset(): obs, reward, done, info = env.step(action) assert obs == np.array([1]) assert reward == 0 - assert done == False + assert done is False assert info == {"count": 1} obs, reward, done, info = env.step(action) assert obs == np.array([2]) - assert done == False + assert done is False assert reward == 0 assert info == {"count": 2} obs, reward, done, info = env.step(action) assert obs == np.array([0]) - assert done == True + assert done is True assert reward == 1 assert info == { "count": 0, @@ -137,10 +138,11 @@ def test_autoreset_autoreset(): obs, reward, done, info = env.step(action) assert obs == np.array([1]) assert reward == 0 - assert done == False + assert done is False assert info == {"count": 1} obs, reward, done, info = env.step(action) assert obs == np.array([2]) assert reward == 0 - assert done == False + assert done is False assert info == {"count": 2} + env.close() diff --git a/tests/wrappers/test_clip_action.py b/tests/wrappers/test_clip_action.py index aebf867b6e0..d00290f1bda 100644 --- a/tests/wrappers/test_clip_action.py +++ b/tests/wrappers/test_clip_action.py @@ -6,9 +6,8 @@ def test_clip_action(): # mountaincar: action-based rewards - make_env = lambda: gym.make("MountainCarContinuous-v0") - env = make_env() - wrapped_env = ClipAction(make_env()) + env = gym.make("MountainCarContinuous-v0") + wrapped_env = ClipAction(gym.make("MountainCarContinuous-v0")) seed = 0 diff --git a/tests/wrappers/test_frame_stack.py b/tests/wrappers/test_frame_stack.py index b9af3002c1f..4c0bdd88da3 100644 --- a/tests/wrappers/test_frame_stack.py +++ b/tests/wrappers/test_frame_stack.py @@ -1,8 +1,5 @@ -import pytest - -pytest.importorskip("gym.envs.atari") - import numpy as np +import pytest import gym from gym.wrappers import FrameStack @@ -13,6 +10,9 @@ lz4 = None +pytest.importorskip("gym.envs.atari") + + @pytest.mark.parametrize("env_id", ["CartPole-v1", "Pendulum-v1", "Pong-v0"]) @pytest.mark.parametrize("num_stack", [2, 3, 4]) @pytest.mark.parametrize( diff --git a/tests/wrappers/test_order_enforcing.py b/tests/wrappers/test_order_enforcing.py index 9b9290aec76..47dd4597121 100644 --- a/tests/wrappers/test_order_enforcing.py +++ b/tests/wrappers/test_order_enforcing.py @@ -1,6 +1,3 @@ -import numpy as np -import pytest - import gym from gym.wrappers import OrderEnforcing diff --git a/tests/wrappers/test_record_episode_statistics.py b/tests/wrappers/test_record_episode_statistics.py index d9633409eb3..4b159877549 100644 
--- a/tests/wrappers/test_record_episode_statistics.py +++ b/tests/wrappers/test_record_episode_statistics.py @@ -1,4 +1,3 @@ -import numpy as np import pytest import gym diff --git a/tests/wrappers/test_record_video.py b/tests/wrappers/test_record_video.py index 0757c1bec54..2627c6668d0 100644 --- a/tests/wrappers/test_record_video.py +++ b/tests/wrappers/test_record_video.py @@ -1,15 +1,8 @@ import os import shutil -import numpy as np -import pytest - import gym -from gym.wrappers import ( - RecordEpisodeStatistics, - RecordVideo, - capped_cubic_video_schedule, -) +from gym.wrappers import capped_cubic_video_schedule def test_record_video_using_default_trigger(): diff --git a/tests/wrappers/test_time_limit.py b/tests/wrappers/test_time_limit.py index 32e6e5d2ad0..fa7f70da430 100644 --- a/tests/wrappers/test_time_limit.py +++ b/tests/wrappers/test_time_limit.py @@ -1,6 +1,3 @@ -import numpy as np -import pytest - import gym from gym.wrappers import TimeLimit diff --git a/tests/wrappers/test_transform_observation.py b/tests/wrappers/test_transform_observation.py index fc1076ae4f7..695edce0780 100644 --- a/tests/wrappers/test_transform_observation.py +++ b/tests/wrappers/test_transform_observation.py @@ -7,7 +7,9 @@ @pytest.mark.parametrize("env_id", ["CartPole-v1", "Pendulum-v1"]) def test_transform_observation(env_id): - affine_transform = lambda x: 3 * x + 2 + def affine_transform(x): + return 3 * x + 2 + env = gym.make(env_id) wrapped_env = TransformObservation( gym.make(env_id), lambda obs: affine_transform(obs) From 8b518bb58964aa8d93addb190d607994c75d341e Mon Sep 17 00:00:00 2001 From: arjun_kg Date: Thu, 5 May 2022 11:16:20 +0530 Subject: [PATCH 12/37] existing code, tests work --- gym/core.py | 8 +- gym/envs/registration.py | 11 ++- gym/utils/env_checker.py | 2 +- gym/utils/play.py | 24 ++--- .../step_api_compatibility.py} | 89 ++++--------------- gym/vector/async_vector_env.py | 50 +++++++---- gym/vector/step_compatibility_vector.py | 37 -------- gym/vector/sync_vector_env.py | 32 ++++--- gym/vector/vector_env.py | 13 ++- gym/wrappers/__init__.py | 2 +- gym/wrappers/autoreset.py | 18 ++-- gym/wrappers/frame_stack.py | 16 ++-- gym/wrappers/normalize.py | 33 ++++--- gym/wrappers/record_episode_statistics.py | 27 +++--- gym/wrappers/record_video.py | 16 ++-- gym/wrappers/step_api_compatibility.py | 38 ++++++++ gym/wrappers/time_aware_observation.py | 13 +-- gym/wrappers/time_limit.py | 34 +++---- tests/utils/test_terminated_truncated.py | 20 ++--- tests/vector/test_vector_env.py | 4 +- tests/wrappers/test_autoreset.py | 4 +- .../test_record_episode_statistics.py | 3 + 22 files changed, 246 insertions(+), 248 deletions(-) rename gym/{wrappers/step_compatibility.py => utils/step_api_compatibility.py} (51%) delete mode 100644 gym/vector/step_compatibility_vector.py create mode 100644 gym/wrappers/step_api_compatibility.py diff --git a/gym/core.py b/gym/core.py index 097e628cabf..e61965259c5 100644 --- a/gym/core.py +++ b/gym/core.py @@ -237,13 +237,14 @@ class Wrapper(Env[ObsType, ActType]): """ - def __init__(self, env: Env): + def __init__(self, env: Env, new_step_api: bool = False): self.env = env self._action_space: spaces.Space | None = None self._observation_space: spaces.Space | None = None self._reward_range: tuple[SupportsFloat, SupportsFloat] | None = None self._metadata: dict | None = None + self.new_step_api = new_step_api def __getattr__(self, name): if name.startswith("_"): @@ -303,7 +304,7 @@ def step( ) -> Union[ Tuple[ObsType, float, bool, bool, dict], Tuple[ObsType, 
float, bool, dict] ]: - return self._get_env_step_returns(action) + return self.env.step(action) def reset(self, **kwargs) -> Union[ObsType, tuple[ObsType, dict]]: return self.env.reset(**kwargs) @@ -327,9 +328,6 @@ def __repr__(self): def unwrapped(self) -> Env: return self.env.unwrapped - def _get_env_step_returns(self, action): - return self.env.step(action) - class ObservationWrapper(Wrapper): def reset(self, **kwargs): diff --git a/gym/envs/registration.py b/gym/envs/registration.py index 46931ffab4d..3dde3b41106 100644 --- a/gym/envs/registration.py +++ b/gym/envs/registration.py @@ -24,7 +24,12 @@ import numpy as np from gym.envs.__relocated__ import internal_env_relocation_map -from gym.wrappers import AutoResetWrapper, OrderEnforcing, StepCompatibility, TimeLimit +from gym.wrappers import ( + AutoResetWrapper, + OrderEnforcing, + StepAPICompatibility, + TimeLimit, +) if sys.version_info < (3, 10): import importlib_metadata as metadata # type: ignore @@ -439,7 +444,7 @@ def make( id: Name of the environment. max_episode_steps: Maximum length of an episode (TimeLimit wrapper). autoreset: Whether to automatically reset the environment after each episode (AutoResetWrapper). - new_step_api: Whether to use old or new step API (StepCompatibility wrapper). Will be removed at v1.0 + new_step_api: Whether to use old or new step API (StepAPICompatibility wrapper). Will be removed at v1.0 kwargs: Additional arguments to pass to the environment constructor. Returns: An instance of the environment. @@ -495,7 +500,7 @@ def make( if spec_.order_enforce: env = OrderEnforcing(env) - env = StepCompatibility(env, new_step_api) + env = StepAPICompatibility(env, new_step_api) if max_episode_steps is not None: env = TimeLimit(env, max_episode_steps) diff --git a/gym/utils/env_checker.py b/gym/utils/env_checker.py index 2128d39150d..e06065bd327 100644 --- a/gym/utils/env_checker.py +++ b/gym/utils/env_checker.py @@ -53,7 +53,7 @@ def _check_nan(env: gym.Env, check_inf: bool = True) -> None: """Check for NaN and Inf.""" for _ in range(10): action = env.action_space.sample() - observation, reward, _, _, _ = env.step(action) + observation, reward, _, _ = env.step(action) if np.any(np.isnan(observation)): logger.warn("Encountered NaN value in observations.") diff --git a/gym/utils/play.py b/gym/utils/play.py index 32e32b9f269..649cd267388 100644 --- a/gym/utils/play.py +++ b/gym/utils/play.py @@ -117,7 +117,7 @@ def play( gym.utils.play.PlayPlot. Here's sample code for plotting the reward for the last 5 seconds of gameplay. - def callback(obs_t, obs_tp1, action, rew, terminated, truncated, info): + def callback(obs_t, obs_tp1, action, rew, done, info): return [rew,] plotter = PlayPlot(callback, 30 * 5, ["reward"]) @@ -144,8 +144,7 @@ def callback(obs_t, obs_tp1, action, rew, terminated, truncated, info): obs_tp1: observation after performing action action: action that was executed rew: reward that was received - terminated: whether the environment is terminated or not - truncated: whether the environment is truncated or not + done: whether the environment is done or not info: debug info keys_to_action: dict: tuple(int) -> int or None Mapping from keys pressed to action performed.
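As an illustration of the `keys_to_action` mapping described above, here is a hypothetical sketch: the key constants come from pygame, and the action numbers and env id are assumed. Note from the hunk below that play() looks up `tuple(sorted(game.pressed_keys))`, so each key tuple must be written in sorted order of key codes.

import gym
import pygame
from gym.utils.play import play

# Hypothetical mapping: 'a' alone triggers action 1, space plus 'a' triggers action 2.
# pygame.K_SPACE (32) sorts before pygame.K_a (97), so the combined tuple is (K_SPACE, K_a).
keys_to_action = {
    (pygame.K_a,): 1,
    (pygame.K_SPACE, pygame.K_a): 2,
}
play(gym.make("Pong-v0"), keys_to_action=keys_to_action)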
@@ -164,22 +163,19 @@ def callback(obs_t, obs_tp1, action, rew, terminated, truncated, info): env.reset(seed=seed) game = PlayableGame(env, keys_to_action, zoom) - terminated = True - truncated = True + done = True clock = pygame.time.Clock() while game.running: - if terminated or truncated: - terminated = False - truncated = False + if done: + done = False obs = env.reset(seed=seed) else: action = keys_to_action.get(tuple(sorted(game.pressed_keys)), 0) prev_obs = obs - obs, rew, terminated, truncated, info = env.step(action) + obs, rew, done, info = env.step(action) if callback is not None: - callback(prev_obs, obs, action, rew, terminated, truncated, info) - + callback(prev_obs, obs, action, rew, done, info) if obs is not None: # TODO: this needs to be updated when the render API change goes through rendered = env.render(mode="rgb_array") @@ -217,10 +213,8 @@ def __init__(self, callback, horizon_timesteps, plot_names): self.cur_plot = [None for _ in range(num_plots)] self.data = [deque(maxlen=horizon_timesteps) for _ in range(num_plots)] - def callback(self, obs_t, obs_tp1, action, rew, terminated, truncated, info): - points = self.data_callback( - obs_t, obs_tp1, action, rew, terminated, truncated, info - ) + def callback(self, obs_t, obs_tp1, action, rew, done, info): + points = self.data_callback(obs_t, obs_tp1, action, rew, done, info) for point, data_series in zip(points, self.data): data_series.append(point) self.t += 1 diff --git a/gym/wrappers/step_compatibility.py b/gym/utils/step_api_compatibility.py similarity index 51% rename from gym/wrappers/step_compatibility.py rename to gym/utils/step_api_compatibility.py index b0579305fa5..fe0301772db 100644 --- a/gym/wrappers/step_compatibility.py +++ b/gym/utils/step_api_compatibility.py @@ -1,54 +1,20 @@ -import gym -from gym import logger +import numpy as np - -class StepCompatibility(gym.Wrapper): - r"""A wrapper which can transform an environment from new step API to old and vice-versa. - - Old step API refers to step() method returning (observation, reward, done, info) - New step API refers to step() method returning (observation, reward, terminated, truncated, info) - (Refer to docs for details on the API change) - - This wrapper is to be used to ease transition to new API and for backward compatibility. It will be removed in v1.0 - - - Parameters - ---------- - env (gym.Env): the env to wrap. Can be in old or new API - new_step_api (bool): True to use env with new step API, False to use env with old step API. (False by default) - - """ - - def __init__(self, env: gym.Env, new_step_api=False): - super().__init__(env) - self.new_step_api = new_step_api - if not self.new_step_api: - logger.deprecation( - "Initializing environment in old step API which returns one bool instead of two. " - "Note that vector API and most wrappers would not work as these have been upgraded to the new API. " - "To use these features, please set `new_step_api=True` in make to use new API (see docs for more details)." - ) - - def step(self, action): - step_returns = self.env.step(action) - if self.new_step_api: - return step_to_new_api(step_returns) - else: - return step_to_old_api(step_returns) +from gym.logger import deprecation def step_to_new_api(step_returns, is_vector_env=False): # Method to transform step returns to new step API if len(step_returns) == 5: - logger.deprecation( + deprecation( "Using an environment with new step API that returns two bools terminated, truncated instead of one bool done. 
" "Take care to supporting code to be compatible with this API" ) return step_returns else: assert len(step_returns) == 4 - logger.deprecation( + deprecation( "Using a wrapper to transform env with old step API into new. This wrapper will be removed in v1.0. " "It is recommended to upgrade the core env to the new step API." "If 'TimeLimit.truncated' is set at truncation, terminated and truncated values will be accurate. " @@ -79,8 +45,8 @@ def step_to_new_api(step_returns, is_vector_env=False): return ( observations, rewards, - terminateds if is_vector_env else terminateds[0], - truncateds if is_vector_env else truncateds[0], + np.array(terminateds, dtype=np.bool_) if is_vector_env else terminateds[0], + np.array(truncateds, dtype=np.bool_) if is_vector_env else truncateds[0], infos if is_vector_env else infos[0], ) @@ -89,14 +55,14 @@ def step_to_old_api(step_returns, is_vector_env=False): # Method to transform step returns to old step API if len(step_returns) == 4: - logger.deprecation( + deprecation( "Core environment uses old step API which returns one boolean (done). Please upgrade to new API to return two booleans - terminated, truncated" ) return step_returns else: assert len(step_returns) == 5 - logger.deprecation( + deprecation( "Using a wrapper to transform new step API (which returns two booleans terminated, truncated) into old (returns one boolean done). " "This wrapper will be removed in v1.0 " "It is recommended to upgrade your accompanying code instead to be compatible with the new API, and use the new API. " @@ -117,38 +83,15 @@ def step_to_old_api(step_returns, is_vector_env=False): return ( observations, rewards, - dones if is_vector_env else dones[0], + np.array(dones, dtype=np.bool_) if is_vector_env else dones[0], infos if is_vector_env else infos[0], ) -def step_api_compatibility(WrapperClass): - """ - A step API compatibility wrapper function to transform wrappers in new step API to old - """ - - class StepCompatibilityWrapper(StepCompatibility): - def __init__(self, env: gym.Wrapper, output_new_step_api: bool = False): - super().__init__(WrapperClass(env), output_new_step_api) - if hasattr(WrapperClass, "new_step_api"): - self.has_new_step_api = WrapperClass.new_step_api - else: - self.has_new_step_api = False - self.wrap = WrapperClass(env) - - def _get_env_step_returns(self, action): - return ( - step_to_new_api(self.wrap.step(action)) - if self.has_new_step_api - else step_to_old_api(self.wrap.step(action)) - ) - - return StepCompatibilityWrapper - - -# def check_is_new_api(env: Union[gym.Env, gym.Wrapper]): -# env_copy = deepcopy(env) -# env_copy.reset() -# step_returns = env_copy.step(env_copy.action_space.sample()) -# del env_copy -# return len(step_returns) == 5 +def step_api_compatibility( + step_returns, new_step_api: bool = False, is_vector_env: bool = False +): + if new_step_api: + return step_to_new_api(step_returns, is_vector_env) + else: + return step_to_old_api(step_returns, is_vector_env) diff --git a/gym/vector/async_vector_env.py b/gym/vector/async_vector_env.py index f097ba0e5c4..53c6a582430 100644 --- a/gym/vector/async_vector_env.py +++ b/gym/vector/async_vector_env.py @@ -14,7 +14,7 @@ CustomSpaceError, NoAsyncCallError, ) -from gym.vector.step_compatibility_vector import step_api_vector_compatibility +from gym.utils.step_api_compatibility import step_api_compatibility from gym.vector.utils import ( CloudpickleWrapper, clear_mpi_env_vars, @@ -26,7 +26,6 @@ write_to_shared_memory, ) from gym.vector.vector_env import VectorEnv -from 
gym.wrappers.step_compatibility import step_to_new_api __all__ = ["AsyncVectorEnv"] @@ -38,7 +37,6 @@ class AsyncState(Enum): WAITING_CALL = "call" -@step_api_vector_compatibility class AsyncVectorEnv(VectorEnv): """Vectorized environment that runs multiple environments in parallel. It uses `multiprocessing`_ processes, and pipes for communication. @@ -122,6 +120,7 @@ def __init__( context=None, daemon=True, worker=None, + new_step_api=False, ): ctx = mp.get_context(context) self.env_fns = env_fns @@ -139,6 +138,7 @@ def __init__( num_envs=len(env_fns), observation_space=observation_space, action_space=action_space, + new_step_api=new_step_api, ) if self.shared_memory: @@ -415,7 +415,13 @@ def step_wait(self, timeout=None): results, successes = zip(*[pipe.recv() for pipe in self.parent_pipes]) self._raise_if_errors(successes) self._state = AsyncState.DEFAULT - observations_list, rewards, terminateds, truncateds, infos = zip(*results) + ( + observations_list, + rewards, + terminateds, + truncateds, + infos, + ) = step_api_compatibility(tuple(zip(*results)), True, True) if not self.shared_memory: self.observations = concatenate( @@ -424,12 +430,16 @@ def step_wait(self, timeout=None): self.observations, ) - return ( - deepcopy(self.observations) if self.copy else self.observations, - np.array(rewards), - np.array(terminateds, dtype=np.bool_), - np.array(truncateds, dtype=np.bool_), - infos, + return step_api_compatibility( + ( + deepcopy(self.observations) if self.copy else self.observations, + np.array(rewards), + np.array(terminateds, dtype=np.bool_), + np.array(truncateds, dtype=np.bool_), + infos, + ), + self.new_step_api, + True, ) def call_async(self, name, *args, **kwargs): @@ -653,9 +663,13 @@ def _worker(index, env_fn, pipe, parent_pipe, shared_memory, error_queue): pipe.send((observation, True)) elif command == "step": - observation, reward, terminated, truncated, info = step_to_new_api( - env.step(data) - ) + ( + observation, + reward, + terminated, + truncated, + info, + ) = step_api_compatibility(env.step(data), True) if terminated or truncated: info["closing_observation"] = observation observation = env.reset() @@ -724,9 +738,13 @@ def _worker_shared_memory(index, env_fn, pipe, parent_pipe, shared_memory, error ) pipe.send((None, True)) elif command == "step": - observation, reward, terminated, truncated, info = step_to_new_api( - env.step(data) - ) + ( + observation, + reward, + terminated, + truncated, + info, + ) = step_api_compatibility(env.step(data), True) if terminated or truncated: info["closing_observation"] = observation observation = env.reset() diff --git a/gym/vector/step_compatibility_vector.py b/gym/vector/step_compatibility_vector.py deleted file mode 100644 index 01ccfbcd41f..00000000000 --- a/gym/vector/step_compatibility_vector.py +++ /dev/null @@ -1,37 +0,0 @@ -from gym.vector.vector_env import VectorEnvWrapper -from gym.wrappers.step_compatibility import step_to_new_api, step_to_old_api - - -def step_api_vector_compatibility(VectorEnvClass): - class StepCompatibilityVector(VectorEnvWrapper): - r"""A wrapper which can transform a vector environment to a new or old step API. - - Old step API refers to step() method returning (observation, reward, done, info) - New step API refers to step() method returning (observation, reward, terminated, truncated, info) - (Refer to docs for details on the API change) - - This wrapper is to be used to ease transition to new API. 
It will be removed in v1.0 - - Parameters - ---------- - env (gym.vector.VectorEnv): the vector env to wrap. Has to be in new step API - new_step_api (bool): True to use vector env with new step API, False to use vector env with old step API. (True by default) - - """ - - def __init__(self, *args, **kwargs): - self.new_step_api = kwargs.get("new_step_api", False) - kwargs.pop("new_step_api", None) - super().__init__(VectorEnvClass(*args, **kwargs)) - - def step_wait(self): - step_returns = self.env.step_wait() - if self.new_step_api: - return step_to_new_api(step_returns) - else: - return step_to_old_api(step_returns) - - def __del__(self): - self.env.__del__() - - return StepCompatibilityVector diff --git a/gym/vector/sync_vector_env.py b/gym/vector/sync_vector_env.py index 78652b478c1..c68654dd1c0 100644 --- a/gym/vector/sync_vector_env.py +++ b/gym/vector/sync_vector_env.py @@ -3,15 +3,13 @@ import numpy as np -from gym.vector.step_compatibility_vector import step_api_vector_compatibility +from gym.utils.step_api_compatibility import step_api_compatibility from gym.vector.utils import concatenate, create_empty_array, iterate from gym.vector.vector_env import VectorEnv -from gym.wrappers.step_compatibility import step_to_new_api __all__ = ["SyncVectorEnv"] -@step_api_vector_compatibility class SyncVectorEnv(VectorEnv): """Vectorized environment that serially runs multiple environments. @@ -53,7 +51,14 @@ class SyncVectorEnv(VectorEnv): [-0.85009176, 0.5266346 , 0.60007906]], dtype=float32) """ - def __init__(self, env_fns, observation_space=None, action_space=None, copy=True): + def __init__( + self, + env_fns, + observation_space=None, + action_space=None, + copy=True, + new_step_api=False, + ): self.env_fns = env_fns self.envs = [env_fn() for env_fn in env_fns] self.copy = copy @@ -66,6 +71,7 @@ def __init__(self, env_fns, observation_space=None, action_space=None, copy=True num_envs=len(env_fns), observation_space=observation_space, action_space=action_space, + new_step_api=new_step_api, ) self._check_spaces() @@ -144,7 +150,7 @@ def step_wait(self): self._terminateds[i], self._truncateds[i], info, - ) = step_to_new_api(env.step(action)) + ) = step_api_compatibility(env.step(action), True) if self._terminateds[i] or self._truncateds[i]: info["closing_observation"] = observation observation = env.reset() @@ -154,12 +160,16 @@ def step_wait(self): self.single_observation_space, observations, self.observations ) - return ( - deepcopy(self.observations) if self.copy else self.observations, - np.copy(self._rewards), - np.copy(self._terminateds), - np.copy(self._truncateds), - infos, + return step_api_compatibility( + ( + deepcopy(self.observations) if self.copy else self.observations, + np.copy(self._rewards), + np.copy(self._terminateds), + np.copy(self._truncateds), + infos, + ), + new_step_api=self.new_step_api, + is_vector_env=True, ) def call(self, name, *args, **kwargs): diff --git a/gym/vector/vector_env.py b/gym/vector/vector_env.py index 78004e8f5dd..ce34b7e689c 100644 --- a/gym/vector/vector_env.py +++ b/gym/vector/vector_env.py @@ -33,7 +33,7 @@ class VectorEnv(gym.Env): Action space of a single environment. 
""" - def __init__(self, num_envs, observation_space, action_space): + def __init__(self, num_envs, observation_space, action_space, new_step_api=False): self.num_envs = num_envs self.is_vector_env = True self.observation_space = batch_space(observation_space, n=num_envs) @@ -47,6 +47,8 @@ def __init__(self, num_envs, observation_space, action_space): self.single_observation_space = observation_space self.single_action_space = action_space + self.new_step_api = new_step_api + def reset_async( self, seed: Optional[Union[int, List[int]]] = None, @@ -269,6 +271,15 @@ def close_extras(self, **kwargs): def seed(self, seed=None): return self.env.seed(seed) + def call(self, *args, **kwargs): + return self.env.call(*args, **kwargs) + + # def setattr(self, name, values): + # return self.env.set_attr(name, values) + + # def getattr(self, name): + # return self.env.getattr(name) + # implicitly forward all other methods and attributes to self.env def __getattr__(self, name): if name.startswith("_"): diff --git a/gym/wrappers/__init__.py b/gym/wrappers/__init__.py index d0e70218941..b38f8843193 100644 --- a/gym/wrappers/__init__.py +++ b/gym/wrappers/__init__.py @@ -11,7 +11,7 @@ from gym.wrappers.record_video import RecordVideo, capped_cubic_video_schedule from gym.wrappers.rescale_action import RescaleAction from gym.wrappers.resize_observation import ResizeObservation -from gym.wrappers.step_compatibility import StepCompatibility +from gym.wrappers.step_api_compatibility import StepAPICompatibility from gym.wrappers.time_aware_observation import TimeAwareObservation from gym.wrappers.time_limit import TimeLimit from gym.wrappers.transform_observation import TransformObservation diff --git a/gym/wrappers/autoreset.py b/gym/wrappers/autoreset.py index ef829b8f41a..64c50540ec0 100644 --- a/gym/wrappers/autoreset.py +++ b/gym/wrappers/autoreset.py @@ -1,8 +1,7 @@ import gym -from gym.wrappers.step_compatibility import step_api_compatibility +from gym.utils.step_api_compatibility import step_api_compatibility -@step_api_compatibility class AutoResetWrapper(gym.Wrapper): """ A class for providing an automatic reset functionality @@ -40,14 +39,13 @@ class AutoResetWrapper(gym.Wrapper): use this wrapper! 
""" - new_step_api = True # whether this wrapper is written in new API (assumed old API if not present) - - def __init__(self, env: gym.Env) -> None: - super().__init__(env) - self.new_step_api = True + def __init__(self, env: gym.Env, new_step_api: bool = False) -> None: + super().__init__(env, new_step_api) def step(self, action): - obs, reward, terminated, truncated, info = self._get_env_step_returns(action) + obs, reward, terminated, truncated, info = step_api_compatibility( + self.env.step(action), new_step_api=True + ) if terminated or truncated: @@ -65,4 +63,6 @@ def step(self, action): obs = new_obs info = new_info - return obs, reward, terminated, truncated, info + return step_api_compatibility( + (obs, reward, terminated, truncated, info), self.new_step_api + ) diff --git a/gym/wrappers/frame_stack.py b/gym/wrappers/frame_stack.py index b6ff76571a6..fa4cb73fdc5 100644 --- a/gym/wrappers/frame_stack.py +++ b/gym/wrappers/frame_stack.py @@ -4,7 +4,7 @@ from gym import ObservationWrapper from gym.spaces import Box -from gym.wrappers.step_compatibility import step_api_compatibility +from gym.utils.step_api_compatibility import step_api_compatibility class LazyFrames: @@ -63,7 +63,6 @@ def _check_decompress(self, frame): return frame -@step_api_compatibility class FrameStack(ObservationWrapper): r"""Observation wrapper that stacks the observations in a rolling manner. @@ -95,10 +94,9 @@ class FrameStack(ObservationWrapper): lz4_compress (bool): use lz4 to compress the frames internally """ - new_step_api = True - def __init__(self, env, num_stack, lz4_compress=False): - super().__init__(env) + def __init__(self, env, num_stack, lz4_compress=False, new_step_api: bool = False): + super().__init__(env, new_step_api) self.num_stack = num_stack self.lz4_compress = lz4_compress @@ -117,11 +115,13 @@ def observation(self): return LazyFrames(list(self.frames), self.lz4_compress) def step(self, action): - observation, reward, terminated, truncated, info = self._get_env_step_returns( - action + observation, reward, terminated, truncated, info = step_api_compatibility( + self.env.step(action), True ) self.frames.append(observation) - return self.observation(), reward, terminated, truncated, info + return step_api_compatibility( + (self.observation(), reward, terminated, truncated, info), self.new_step_api + ) def reset(self, **kwargs): if kwargs.get("return_info", False): diff --git a/gym/wrappers/normalize.py b/gym/wrappers/normalize.py index 29a932abff5..0c2c78866a2 100644 --- a/gym/wrappers/normalize.py +++ b/gym/wrappers/normalize.py @@ -1,7 +1,7 @@ import numpy as np import gym -from gym.wrappers.step_compatibility import step_api_compatibility +from gym.utils.step_api_compatibility import step_api_compatibility # taken from https://github.com/openai/baselines/blob/master/baselines/common/vec_env/vec_normalize.py @@ -51,14 +51,13 @@ class NormalizeObservation(gym.core.Wrapper): epsilon: A stability parameter that is used when scaling the observations. 
""" - new_step_api = True - def __init__( self, env, epsilon=1e-8, + new_step_api=False, ): - super().__init__(env) + super().__init__(env, new_step_api) self.num_envs = getattr(env, "num_envs", 1) self.is_vector_env = getattr(env, "is_vector_env", False) if self.is_vector_env: @@ -68,12 +67,18 @@ def __init__( self.epsilon = epsilon def step(self, action): - obs, rews, terminateds, truncateds, infos = self._get_env_step_returns(action) + obs, rews, terminateds, truncateds, infos = step_api_compatibility( + self.env.step(action), True, self.is_vector_env + ) if self.is_vector_env: obs = self.normalize(obs) else: obs = self.normalize(np.array([obs]))[0] - return obs, rews, terminateds, truncateds, infos + return step_api_compatibility( + (obs, rews, terminateds, truncateds, infos), + self.new_step_api, + self.is_vector_env, + ) def reset(self, **kwargs): return_info = kwargs.get("return_info", False) @@ -95,7 +100,6 @@ def normalize(self, obs): return (obs - self.obs_rms.mean) / np.sqrt(self.obs_rms.var + self.epsilon) -@step_api_compatibility class NormalizeReward(gym.core.Wrapper): """This wrapper will normalize immediate rewards s.t. their exponential moving average has a fixed variance. @@ -110,15 +114,14 @@ class NormalizeReward(gym.core.Wrapper): gamma (float): The discount factor that is used in the exponential moving average. """ - new_step_api = True - def __init__( self, env, gamma=0.99, epsilon=1e-8, + new_step_api=False, ): - super().__init__(env) + super().__init__(env, new_step_api) self.num_envs = getattr(env, "num_envs", 1) self.is_vector_env = getattr(env, "is_vector_env", False) self.return_rms = RunningMeanStd(shape=()) @@ -127,7 +130,9 @@ def __init__( self.epsilon = epsilon def step(self, action): - obs, rews, terminateds, truncateds, infos = self._get_env_step_returns(action) + obs, rews, terminateds, truncateds, infos = step_api_compatibility( + self.env.step(action), True, self.is_vector_env + ) if not self.is_vector_env: rews = np.array([rews]) self.returns = self.returns * self.gamma + rews @@ -139,7 +144,11 @@ def step(self, action): self.returns[dones] = 0.0 if not self.is_vector_env: rews = rews[0] - return obs, rews, terminateds, truncateds, infos + return step_api_compatibility( + (obs, rews, terminateds, truncateds, infos), + self.new_step_api, + self.is_vector_env, + ) def normalize(self, rews): self.return_rms.update(self.returns) diff --git a/gym/wrappers/record_episode_statistics.py b/gym/wrappers/record_episode_statistics.py index b793a5b2743..2a9e94c28be 100644 --- a/gym/wrappers/record_episode_statistics.py +++ b/gym/wrappers/record_episode_statistics.py @@ -4,10 +4,9 @@ import numpy as np import gym -from gym.wrappers.step_compatibility import step_api_compatibility +from gym.utils.step_api_compatibility import step_api_compatibility -@step_api_compatibility class RecordEpisodeStatistics(gym.Wrapper): """This wrapper will keep track of cumulative rewards and episode lengths. 
@@ -37,10 +36,8 @@ class RecordEpisodeStatistics(gym.Wrapper): length_queue: The lengths of the last `deque_size`-many episodes """ - new_step_api = True - - def __init__(self, env, deque_size=100): - super().__init__(env) + def __init__(self, env, deque_size=100, new_step_api=False): + super().__init__(env, new_step_api) self.num_envs = getattr(env, "num_envs", 1) self.t0 = time.perf_counter() self.episode_count = 0 @@ -63,7 +60,7 @@ def step(self, action): terminateds, truncateds, infos, - ) = self._get_env_step_returns(action) + ) = step_api_compatibility(self.env.step(action), True, self.is_vector_env) self.episode_returns += rewards self.episode_lengths += 1 if not self.is_vector_env: @@ -90,10 +87,14 @@ def step(self, action): self.episode_lengths[i] = 0 if self.is_vector_env: infos = tuple(infos) - return ( - observations, - rewards, - terminateds if self.is_vector_env else terminateds[0], - truncateds if self.is_vector_env else truncateds[0], - infos if self.is_vector_env else infos[0], + return step_api_compatibility( + ( + observations, + rewards, + terminateds if self.is_vector_env else terminateds[0], + truncateds if self.is_vector_env else truncateds[0], + infos if self.is_vector_env else infos[0], + ), + self.new_step_api, + self.is_vector_env, ) diff --git a/gym/wrappers/record_video.py b/gym/wrappers/record_video.py index 7a746bb47a7..64c47caff75 100644 --- a/gym/wrappers/record_video.py +++ b/gym/wrappers/record_video.py @@ -3,7 +3,7 @@ import gym from gym import logger -from gym.wrappers.step_compatibility import step_api_compatibility +from gym.utils.step_api_compatibility import step_api_compatibility from gym.wrappers.monitoring import video_recorder @@ -14,7 +14,6 @@ def capped_cubic_video_schedule(episode_id): return episode_id % 1000 == 0 -@step_api_compatibility class RecordVideo(gym.Wrapper): """This wrapper records videos of rollouts. @@ -37,8 +36,6 @@ class RecordVideo(gym.Wrapper): name_prefix (str): Will be prepended to the filename of the recordings """ - new_step_api = True - def __init__( self, env, @@ -47,8 +44,9 @@ def __init__( step_trigger: Callable[[int], bool] = None, video_length: int = 0, name_prefix: str = "rl-video", + new_step_api: bool = False, ): - super().__init__(env) + super().__init__(env, new_step_api) if episode_trigger is None and step_trigger is None: episode_trigger = capped_cubic_video_schedule @@ -114,7 +112,7 @@ def step(self, action): terminateds, truncateds, infos, - ) = self._get_env_step_returns(action) + ) = step_api_compatibility(self.env.step(action), True, self.is_vector_env) # increment steps and episodes self.step_id += 1 @@ -140,7 +138,11 @@ def step(self, action): elif self._video_enabled(): self.start_video_recorder() - return observations, rewards, terminateds, truncateds, infos + return step_api_compatibility( + (observations, rewards, terminateds, truncateds, infos), + self.new_step_api, + self.is_vector_env, + ) def close_video_recorder(self) -> None: if self.recording: diff --git a/gym/wrappers/step_api_compatibility.py b/gym/wrappers/step_api_compatibility.py new file mode 100644 index 00000000000..cceffcc6c71 --- /dev/null +++ b/gym/wrappers/step_api_compatibility.py @@ -0,0 +1,38 @@ +import gym +from gym.logger import deprecation +from gym.utils.step_api_compatibility import step_to_new_api, step_to_old_api + + +class StepAPICompatibility(gym.Wrapper): + r"""A wrapper which can transform an environment from new step API to old and vice-versa. 
+ + Old step API refers to step() method returning (observation, reward, done, info) + New step API refers to step() method returning (observation, reward, terminated, truncated, info) + (Refer to docs for details on the API change) + + This wrapper is to be used to ease transition to new API and for backward compatibility. It will be removed in v1.0 + + + Parameters + ---------- + env (gym.Env): the env to wrap. Can be in old or new API + new_step_api (bool): True to use env with new step API, False to use env with old step API. (False by default) + + """ + + def __init__(self, env: gym.Env, new_step_api=False): + super().__init__(env) + self.new_step_api = new_step_api + if not self.new_step_api: + deprecation( + "Initializing environment in old step API which returns one bool instead of two. " + "Note that vector API and most wrappers would not work as these have been upgraded to the new API. " + "To use these features, please set `new_step_api=True` in make to use new API (see docs for more details)." + ) + + def step(self, action): + step_returns = self.env.step(action) + if self.new_step_api: + return step_to_new_api(step_returns) + else: + return step_to_old_api(step_returns) diff --git a/gym/wrappers/time_aware_observation.py b/gym/wrappers/time_aware_observation.py index 837bd80894b..6d39aed0f92 100644 --- a/gym/wrappers/time_aware_observation.py +++ b/gym/wrappers/time_aware_observation.py @@ -2,10 +2,9 @@ from gym import ObservationWrapper from gym.spaces import Box -from gym.wrappers.step_compatibility import step_api_compatibility +from gym.utils.step_api_compatibility import step_api_compatibility -@step_api_compatibility class TimeAwareObservation(ObservationWrapper): r"""Augment the observation with current time step in the trajectory. @@ -14,22 +13,24 @@ class TimeAwareObservation(ObservationWrapper): support pixel observation space yet. """ - new_step_api = True - def __init__(self, env): - super().__init__(env) + def __init__(self, env, new_step_api=False): + super().__init__(env, new_step_api) assert isinstance(env.observation_space, Box) assert env.observation_space.dtype == np.float32 low = np.append(self.observation_space.low, 0.0) high = np.append(self.observation_space.high, np.inf) self.observation_space = Box(low, high, dtype=np.float32) + self.is_vector_env = getattr(env, "is_vector_env", False) def observation(self, observation): return np.append(observation, self.t) def step(self, action): self.t += 1 - return super().step(action) + return step_api_compatibility( + super().step(action), self.new_step_api, self.is_vector_env + ) def reset(self, **kwargs): self.t = 0 diff --git a/gym/wrappers/time_limit.py b/gym/wrappers/time_limit.py index ebf85ea6f49..a574fb0b71c 100644 --- a/gym/wrappers/time_limit.py +++ b/gym/wrappers/time_limit.py @@ -1,6 +1,7 @@ from typing import Optional import gym +from gym.utils.step_api_compatibility import step_api_compatibility class TimeLimit(gym.Wrapper): @@ -18,8 +19,10 @@ class TimeLimit(gym.Wrapper): max_episode_steps (Optional[int]): The maximum number of steps until a done-signal occurs. 
If it is `None`, the value from `env.spec` (if available) will be used """ - def __init__(self, env, max_episode_steps: Optional[int] = None): - super().__init__(env) + def __init__( + self, env, max_episode_steps: Optional[int] = None, new_step_api: bool = False + ): + super().__init__(env, new_step_api) if max_episode_steps is None and self.env.spec is not None: max_episode_steps = env.spec.max_episode_steps if self.env.spec is not None: @@ -28,20 +31,19 @@ def __init__(self, env, max_episode_steps: Optional[int] = None): self._elapsed_steps = None def step(self, action): - step_returns = self._get_env_step_returns(action) - if len(step_returns) == 4: - observation, reward, done, info = self.env.step(action) - if self._elapsed_steps >= self._max_episode_steps: - info["TimeLimit.truncated"] = not done - done = True - return observation, reward, done, info - else: - observation, reward, terminated, truncated, info = step_returns - self._elapsed_steps += 1 - if self._elapsed_steps >= self._max_episode_steps: - truncated = True - info["TimeLimit.truncated"] = truncated - return observation, reward, terminated, truncated, info + observation, reward, terminated, truncated, info = step_api_compatibility( + self.env.step(action), + True, + ) + self._elapsed_steps += 1 + + if self._elapsed_steps >= self._max_episode_steps: + truncated = True + + return step_api_compatibility( + (observation, reward, terminated, truncated, info), + self.new_step_api, + ) def reset(self, **kwargs): self._elapsed_steps = 0 diff --git a/tests/utils/test_terminated_truncated.py b/tests/utils/test_terminated_truncated.py index a4d810bce3a..e74fdc85378 100644 --- a/tests/utils/test_terminated_truncated.py +++ b/tests/utils/test_terminated_truncated.py @@ -29,7 +29,7 @@ def reset(self): @pytest.mark.parametrize("time_limit", [10, 20, 30]) def test_terminated_truncated(time_limit): - test_env = TimeLimit(DummyEnv(), time_limit) + test_env = TimeLimit(DummyEnv(), time_limit, new_step_api=True) terminated = False truncated = False @@ -53,11 +53,13 @@ def test_terminated_truncated(time_limit): def test_terminated_truncated_vector(): - env0 = TimeLimit(DummyEnv(), 10) - env1 = TimeLimit(DummyEnv(), 20) - env2 = TimeLimit(DummyEnv(), 30) + env0 = TimeLimit(DummyEnv(), 10, new_step_api=True) + env1 = TimeLimit(DummyEnv(), 20, new_step_api=True) + env2 = TimeLimit(DummyEnv(), 30, new_step_api=True) - async_env = AsyncVectorEnv([lambda: env0, lambda: env1, lambda: env2]) + async_env = AsyncVectorEnv( + [lambda: env0, lambda: env1, lambda: env2], new_step_api=True + ) async_env.reset() terminateds = [False, False, False] truncateds = [False, False, False] @@ -72,7 +74,9 @@ def test_terminated_truncated_vector(): assert all(terminateds == [False, True, True]) assert all(truncateds == [True, True, False]) - sync_env = SyncVectorEnv([lambda: env0, lambda: env1, lambda: env2]) + sync_env = SyncVectorEnv( + [lambda: env0, lambda: env1, lambda: env2], new_step_api=True + ) sync_env.reset() terminateds = [False, False, False] truncateds = [False, False, False] @@ -85,7 +89,3 @@ def test_terminated_truncated_vector(): assert counter == 20 assert all(terminateds == [False, True, True]) assert all(truncateds == [True, True, False]) - - -if __name__ == "__main__": - test_terminated_truncated(10) diff --git a/tests/vector/test_vector_env.py b/tests/vector/test_vector_env.py index 82870d79c29..5544f8660f4 100644 --- a/tests/vector/test_vector_env.py +++ b/tests/vector/test_vector_env.py @@ -37,8 +37,8 @@ def 
test_vector_env_equal(shared_memory): for idx in range(len(sync_dones)): if sync_dones[idx]: - assert "terminal_observation" in async_infos[idx] - assert "terminal_observation" in sync_infos[idx] + assert "closing_observation" in async_infos[idx] + assert "closing_observation" in sync_infos[idx] assert sync_dones[idx] assert np.all(async_observations == sync_observations) diff --git a/tests/wrappers/test_autoreset.py b/tests/wrappers/test_autoreset.py index 39b47c756a4..f676f7d0688 100644 --- a/tests/wrappers/test_autoreset.py +++ b/tests/wrappers/test_autoreset.py @@ -132,8 +132,8 @@ def test_autoreset_autoreset(): assert reward == 1 assert info == { "count": 0, - "terminal_observation": np.array([3]), - "terminal_info": {"count": 3}, + "closing_observation": np.array([3]), + "closing_info": {"count": 3}, } obs, reward, done, info = env.step(action) assert obs == np.array([1]) diff --git a/tests/wrappers/test_record_episode_statistics.py b/tests/wrappers/test_record_episode_statistics.py index 4b159877549..9ec7666db4c 100644 --- a/tests/wrappers/test_record_episode_statistics.py +++ b/tests/wrappers/test_record_episode_statistics.py @@ -55,3 +55,6 @@ def test_record_episode_statistics_with_vectorenv(num_envs, asynchronous): assert "episode" in info assert all([item in info["episode"] for item in ["r", "l", "t"]]) break + + +test_record_episode_statistics("Pendulum-v1", 2) From 9a2a9af452e15f07cacedb8c4884492220c8b744 Mon Sep 17 00:00:00 2001 From: arjun_kg Date: Thu, 5 May 2022 11:46:45 +0530 Subject: [PATCH 13/37] fix compat at registration, tests --- gym/envs/registration.py | 9 +++--- .../vector/test_step_compatibility_vector.py | 20 +++++------- tests/wrappers/test_step_compatibility.py | 31 ++++++++++--------- tests/wrappers/test_time_limit_info.py | 1 - 4 files changed, 28 insertions(+), 33 deletions(-) delete mode 100644 tests/wrappers/test_time_limit_info.py diff --git a/gym/envs/registration.py b/gym/envs/registration.py index 3dde3b41106..1a10daa9aa4 100644 --- a/gym/envs/registration.py +++ b/gym/envs/registration.py @@ -497,18 +497,17 @@ def make( env.unwrapped.spec = spec_ + env = StepAPICompatibility(env, new_step_api) if spec_.order_enforce: env = OrderEnforcing(env) - env = StepAPICompatibility(env, new_step_api) - if max_episode_steps is not None: - env = TimeLimit(env, max_episode_steps) + env = TimeLimit(env, max_episode_steps, new_step_api) elif spec_.max_episode_steps is not None: - env = TimeLimit(env, spec_.max_episode_steps) + env = TimeLimit(env, spec_.max_episode_steps, new_step_api) if autoreset: - env = AutoResetWrapper(env) + env = AutoResetWrapper(env, new_step_api) return env diff --git a/tests/vector/test_step_compatibility_vector.py b/tests/vector/test_step_compatibility_vector.py index 2096dedbd0a..4a14c97d84f 100644 --- a/tests/vector/test_step_compatibility_vector.py +++ b/tests/vector/test_step_compatibility_vector.py @@ -3,8 +3,8 @@ import gym from gym.spaces import Discrete -from gym.vector import AsyncVectorEnv, StepCompatibilityVector, SyncVectorEnv -from gym.wrappers import StepCompatibility +from gym.vector import AsyncVectorEnv, SyncVectorEnv +from gym.wrappers import StepAPICompatibility class OldStepEnv(gym.Env): @@ -44,13 +44,11 @@ def step(self, action): def test_vector_step_compatibility_new_env(VecEnv): envs = [ - StepCompatibility(OldStepEnv()), + OldStepEnv(), NewStepEnv(), - ] # input to vec env must be in new step api + ] - vec_env = StepCompatibilityVector( - VecEnv([lambda: env for env in envs]), return_two_dones=False - ) + vec_env 
= VecEnv([lambda: env for env in envs]) vec_env.reset() step_returns = vec_env.step([0, 0]) assert len(step_returns) == 4 @@ -58,7 +56,7 @@ def test_vector_step_compatibility_new_env(VecEnv): assert dones.dtype == np.bool_ vec_env.close() - vec_env = StepCompatibilityVector(VecEnv([lambda: env for env in envs])) + vec_env = VecEnv([lambda: env for env in envs], new_step_api=True) vec_env.reset() step_returns = vec_env.step([0, 0]) assert len(step_returns) == 5 @@ -71,9 +69,7 @@ def test_vector_step_compatibility_new_env(VecEnv): @pytest.mark.parametrize("async_bool", [True, False]) def test_vector_step_compatibility_existing(async_bool): - env = gym.vector.make( - "CartPole-v1", num_envs=3, asynchronous=async_bool, return_two_dones=False - ) + env = gym.vector.make("CartPole-v1", num_envs=3, asynchronous=async_bool) env.reset() step_returns = env.step(env.action_space.sample()) assert len(step_returns) == 4 @@ -82,7 +78,7 @@ def test_vector_step_compatibility_existing(async_bool): env.close() env = gym.vector.make( - "CartPole-v1", num_envs=3, asynchronous=async_bool, return_two_dones=True + "CartPole-v1", num_envs=3, asynchronous=async_bool, new_step_api=True ) env.reset() step_returns = env.step(env.action_space.sample()) diff --git a/tests/wrappers/test_step_compatibility.py b/tests/wrappers/test_step_compatibility.py index 213c7ed2d3c..ab30afb9531 100644 --- a/tests/wrappers/test_step_compatibility.py +++ b/tests/wrappers/test_step_compatibility.py @@ -1,10 +1,8 @@ -import warnings - import pytest import gym from gym.spaces import Discrete -from gym.wrappers import StepCompatibility +from gym.wrappers import StepAPICompatibility class OldStepEnv(gym.Env): @@ -36,7 +34,7 @@ def step(self, action): @pytest.mark.parametrize("env", [OldStepEnv, NewStepEnv]) def test_step_compatibility_to_new_api(env): - env = StepCompatibility(env(), True) + env = StepAPICompatibility(env(), True) step_returns = env.step(0) _, _, terminated, truncated, _ = step_returns assert isinstance(terminated, bool) @@ -44,36 +42,39 @@ def test_step_compatibility_to_new_api(env): @pytest.mark.parametrize("env", [OldStepEnv, NewStepEnv]) -@pytest.mark.parametrize("return_two_dones", [None, False]) -def test_step_compatibility_to_old_api(env, return_two_dones): - if return_two_dones is None: - env = StepCompatibility(env()) # default behavior is to retain old API +@pytest.mark.parametrize("new_step_api", [None, False]) +def test_step_compatibility_to_old_api(env, new_step_api): + if new_step_api is None: + env = StepAPICompatibility(env()) # default behavior is to retain old API else: - env = StepCompatibility(env(), return_two_dones) + env = StepAPICompatibility(env(), new_step_api) step_returns = env.step(0) assert len(step_returns) == 4 _, _, done, _ = step_returns assert isinstance(done, bool) -@pytest.mark.parametrize("return_two_dones", [None, True, False]) -def test_step_compatibility_in_make(return_two_dones): - if return_two_dones is None: +@pytest.mark.parametrize("new_step_api", [None, True, False]) +def test_step_compatibility_in_make(new_step_api): + if new_step_api is None: with pytest.warns( DeprecationWarning, match="Initializing environment in old step API" ): env = gym.make("CartPole-v1") else: - env = gym.make("CartPole-v1", return_two_dones=return_two_dones) + env = gym.make("CartPole-v1", new_step_api=new_step_api) env.reset() step_returns = env.step(0) - if return_two_dones == True: # new api + if new_step_api: assert len(step_returns) == 5 _, _, terminated, truncated, _ = step_returns assert 
isinstance(terminated, bool) assert isinstance(truncated, bool) - else: # old api + else: assert len(step_returns) == 4 _, _, done, _ = step_returns assert isinstance(done, bool) + + +test_step_compatibility_in_make(True) diff --git a/tests/wrappers/test_time_limit_info.py b/tests/wrappers/test_time_limit_info.py deleted file mode 100644 index 792d6005489..00000000000 --- a/tests/wrappers/test_time_limit_info.py +++ /dev/null @@ -1 +0,0 @@ -# From 29eafe5d2cae44c338dc65c6f956d854517e4223 Mon Sep 17 00:00:00 2001 From: arjun_kg Date: Thu, 5 May 2022 15:28:46 +0530 Subject: [PATCH 14/37] docstrings, tests passing --- gym/utils/step_api_compatibility.py | 49 ++++++++++++++++--- gym/wrappers/step_api_compatibility.py | 16 +++--- pyproject.toml | 2 +- .../vector/test_step_compatibility_vector.py | 1 - .../test_record_episode_statistics.py | 3 -- tests/wrappers/test_step_compatibility.py | 3 -- 6 files changed, 52 insertions(+), 22 deletions(-) diff --git a/gym/utils/step_api_compatibility.py b/gym/utils/step_api_compatibility.py index fe0301772db..48fddf33285 100644 --- a/gym/utils/step_api_compatibility.py +++ b/gym/utils/step_api_compatibility.py @@ -4,7 +4,12 @@ def step_to_new_api(step_returns, is_vector_env=False): - # Method to transform step returns to new step API + """Function to transform step returns to new step API irrespective of input API + + Args: + step_returns (tuple): Items returned by step(). Can be (obs, rew, done, info) or (obs, rew, terminated, truncated, info) + is_vector_env (bool): Whether the step_returns are from a vector environment + """ if len(step_returns) == 5: deprecation( @@ -15,8 +20,8 @@ def step_to_new_api(step_returns, is_vector_env=False): else: assert len(step_returns) == 4 deprecation( - "Using a wrapper to transform env with old step API into new. This wrapper will be removed in v1.0. " - "It is recommended to upgrade the core env to the new step API." + "Transforming code with old step API into new. " + "It is recommended to upgrade the core env to the new step API. This can also be done by setting `new_step_api=True` at make. " "If 'TimeLimit.truncated' is set at truncation, terminated and truncated values will be accurate. " "Otherwise, `terminated=done` and `truncated=False`" ) @@ -52,20 +57,24 @@ def step_to_new_api(step_returns, is_vector_env=False): def step_to_old_api(step_returns, is_vector_env=False): - # Method to transform step returns to old step API + """Function to transform step returns to old step API irrespective of input API + + Args: + step_returns (tuple): Items returned by step(). Can be (obs, rew, done, info) or (obs, rew, terminated, truncated, info) + is_vector_env (bool): Whether the step_returns are from a vector environment + """ if len(step_returns) == 4: deprecation( - "Core environment uses old step API which returns one boolean (done). Please upgrade to new API to return two booleans - terminated, truncated" + "Using old step API which returns one boolean (done). Please upgrade to new API to return two booleans - terminated, truncated" ) return step_returns else: assert len(step_returns) == 5 deprecation( - "Using a wrapper to transform new step API (which returns two booleans terminated, truncated) into old (returns one boolean done). " - "This wrapper will be removed in v1.0 " - "It is recommended to upgrade your accompanying code instead to be compatible with the new API, and use the new API. " + "Transforming code in new step API (which returns two booleans terminated, truncated) into old (returns one boolean done). 
" + "It is recommended to upgrade accompanying code to be compatible with the new API, and use the new API by setting `new_step_api=True`. " ) observations, rewards, terminateds, truncateds, infos = step_returns @@ -91,6 +100,30 @@ def step_to_old_api(step_returns, is_vector_env=False): def step_api_compatibility( step_returns, new_step_api: bool = False, is_vector_env: bool = False ): + """Function to transform step returns to the API specified by `new_step_api` bool. + + Old step API refers to step() method returning (observation, reward, done, info) + New step API refers to step() method returning (observation, reward, terminated, truncated, info) + (Refer to docs for details on the API change) + + Args: + step_returns (tuple): Items returned by step(). Can be (obs, rew, done, info) or (obs, rew, terminated, truncated, info) + new_step_api (bool): Whether the output should be in new step API or old (False by default) + is_vector_env (bool): Whether the step_returns are from a vector environment + + Returns: + step_returns (tuple): Depending on `new_step_api` bool, it can return (obs, rew, done, info) or (obs, rew, terminated, truncated, info) + + Examples: + This function can be used to ensure compatibility in step interfaces with conflicting API. Eg. if env is written in old API, + wrapper is written in new API, and the final step output is desired to be in old API. + + >>> obs, rew, done, info = step_api_compatibility(env.step(action)) + >>> obs, rew, terminated, truncated, info = step_api_compatibility(env.step(action), new_step_api=True) + >>> observations, rewards, dones, infos = step_api_compatibility(vec_env.step(action), is_vector_env=True) + + """ + if new_step_api: return step_to_new_api(step_returns, is_vector_env) else: diff --git a/gym/wrappers/step_api_compatibility.py b/gym/wrappers/step_api_compatibility.py index cceffcc6c71..ab3baf6e109 100644 --- a/gym/wrappers/step_api_compatibility.py +++ b/gym/wrappers/step_api_compatibility.py @@ -10,14 +10,19 @@ class StepAPICompatibility(gym.Wrapper): New step API refers to step() method returning (observation, reward, terminated, truncated, info) (Refer to docs for details on the API change) - This wrapper is to be used to ease transition to new API and for backward compatibility. It will be removed in v1.0 + This wrapper is to be used to ease transition to new API and for backward compatibility. - - Parameters - ---------- + Args: env (gym.Env): the env to wrap. Can be in old or new API new_step_api (bool): True to use env with new step API, False to use env with old step API. (False by default) + Examples: + >>> env = gym.make("CartPole-v1") + >>> env # wrapper applied by default, set to old API + >>>> + >>> env = gym.make("CartPole-v1", new_step_api=True) # set to new API + >>> env = StepAPICompatibility(CustomEnv(), new_step_api=True) # manually using wrapper on unregistered envs + """ def __init__(self, env: gym.Env, new_step_api=False): @@ -26,8 +31,7 @@ def __init__(self, env: gym.Env, new_step_api=False): if not self.new_step_api: deprecation( "Initializing environment in old step API which returns one bool instead of two. " - "Note that vector API and most wrappers would not work as these have been upgraded to the new API. " - "To use these features, please set `new_step_api=True` in make to use new API (see docs for more details)." + "It is recommended to set `new_step_api=True` to use new step API. This will be the default behaviour in future. 
" ) def step(self, action): diff --git a/pyproject.toml b/pyproject.toml index b82e848aa1e..00842e2626d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -20,4 +20,4 @@ reportMissingTypeStubs = false verboseOutput = true [tool.pytest.ini_options] -filterwarnings = ['ignore:.*step API.*:DeprecationWarning'] # to be removed at 1.0 when old step API is removed +filterwarnings = ['ignore:.*step API.*:DeprecationWarning'] # to be removed when old step API is removed diff --git a/tests/vector/test_step_compatibility_vector.py b/tests/vector/test_step_compatibility_vector.py index 4a14c97d84f..d0305300fc7 100644 --- a/tests/vector/test_step_compatibility_vector.py +++ b/tests/vector/test_step_compatibility_vector.py @@ -4,7 +4,6 @@ import gym from gym.spaces import Discrete from gym.vector import AsyncVectorEnv, SyncVectorEnv -from gym.wrappers import StepAPICompatibility class OldStepEnv(gym.Env): diff --git a/tests/wrappers/test_record_episode_statistics.py b/tests/wrappers/test_record_episode_statistics.py index 9ec7666db4c..4b159877549 100644 --- a/tests/wrappers/test_record_episode_statistics.py +++ b/tests/wrappers/test_record_episode_statistics.py @@ -55,6 +55,3 @@ def test_record_episode_statistics_with_vectorenv(num_envs, asynchronous): assert "episode" in info assert all([item in info["episode"] for item in ["r", "l", "t"]]) break - - -test_record_episode_statistics("Pendulum-v1", 2) diff --git a/tests/wrappers/test_step_compatibility.py b/tests/wrappers/test_step_compatibility.py index ab30afb9531..83557f02db6 100644 --- a/tests/wrappers/test_step_compatibility.py +++ b/tests/wrappers/test_step_compatibility.py @@ -75,6 +75,3 @@ def test_step_compatibility_in_make(new_step_api): assert len(step_returns) == 4 _, _, done, _ = step_returns assert isinstance(done, bool) - - -test_step_compatibility_in_make(True) From 97f36d3cafd76a8bd8ffecdac94ce1d337e956d0 Mon Sep 17 00:00:00 2001 From: arjun_kg Date: Sat, 28 May 2022 21:48:57 +0530 Subject: [PATCH 15/37] dealing with conflicts --- gym/core.py | 49 ++++++++--- gym/envs/mujoco/ant_v4.py | 27 +++--- gym/envs/mujoco/half_cheetah_v4.py | 8 +- gym/envs/mujoco/hopper_v4.py | 30 ++++--- gym/envs/mujoco/humanoid_v4.py | 25 +++--- gym/envs/mujoco/humanoidstandup_v4.py | 12 +-- .../mujoco/inverted_double_pendulum_v4.py | 14 +-- gym/envs/mujoco/inverted_pendulum_v4.py | 15 ++-- gym/envs/mujoco/pusher_v4.py | 17 ++-- gym/envs/mujoco/reacher_v4.py | 17 ++-- gym/envs/mujoco/swimmer_v4.py | 7 +- gym/envs/mujoco/walker2d_v4.py | 29 ++++--- gym/envs/registration.py | 18 +++- gym/envs/toy_text/cliffwalking.py | 2 +- gym/utils/env_checker.py | 3 +- gym/utils/play.py | 3 + gym/utils/step_api_compatibility.py | 86 +++++++++++++++---- gym/vector/__init__.py | 2 + gym/vector/async_vector_env.py | 11 ++- gym/vector/sync_vector_env.py | 2 + gym/vector/vector_env.py | 19 ++-- gym/wrappers/atari_preprocessing.py | 19 ++-- gym/wrappers/autoreset.py | 34 +++++--- gym/wrappers/frame_stack.py | 19 +++- gym/wrappers/normalize.py | 17 ++-- gym/wrappers/record_episode_statistics.py | 38 +++++--- gym/wrappers/record_video.py | 11 ++- gym/wrappers/step_api_compatibility.py | 15 ++++ gym/wrappers/time_aware_observation.py | 5 +- gym/wrappers/time_limit.py | 15 +++- pyproject.toml | 3 + tests/envs/test_action_dim_check.py | 7 +- tests/vector/test_vector_env.py | 8 +- tests/vector/test_vector_env_info.py | 20 ++--- tests/wrappers/test_vector_list_info.py | 4 +- 35 files changed, 419 insertions(+), 192 deletions(-) diff --git a/gym/core.py b/gym/core.py index 
1356aebf8fe..4f31d0eae5e 100644 --- a/gym/core.py +++ b/gym/core.py @@ -66,11 +66,16 @@ def np_random(self) -> RandomNumberGenerator: def np_random(self, value: RandomNumberGenerator): self._np_random = value - def step(self, action: ActType) -> Tuple[ObsType, float, bool, dict]: + def step( + self, action: ActType + ) -> Union[ + Tuple[ObsType, float, bool, bool, dict], Tuple[ObsType, float, bool, dict] + ]: """Run one timestep of the environment's dynamics. When end of episode is reached, you are responsible for calling :meth:`reset` to reset this environment's state. - Accepts an action and returns a tuple `(observation, reward, done, info)`. + Accepts an action and returns either a tuple `(observation, reward, terminated, truncated, info)`, or a tuple + (observation, reward, done, info). The latter is deprecated and will be removed in future versions. Args: action (ActType): an action provided by the agent @@ -79,14 +84,18 @@ def step(self, action: ActType) -> Tuple[ObsType, float, bool, dict]: observation (object): this will be an element of the environment's :attr:`observation_space`. This may, for instance, be a numpy array containing the positions and velocities of certain objects. reward (float): The amount of reward returned as a result of taking the action. - done (bool): A boolean value for if the episode has ended, in which case further :meth:`step` calls will return undefined results. - A done signal may be emitted for different reasons: Maybe the task underlying the environment was solved successfully, - a certain timelimit was exceeded, or the physics simulation has entered an invalid state. + terminated (bool): whether the episode has ended due to reaching a terminal state intrinsic to the core environment, in which case further step() calls will return undefined results + truncated (bool): whether the episode has ended due to a truncation, i.e., a timelimit outside the scope of the problem defined in the environment. info (dictionary): A dictionary that may contain additional information regarding the reason for a ``done`` signal. `info` contains auxiliary diagnostic information (helpful for debugging, learning, and logging). This might, for instance, contain: metrics that describe the agent's performance state, variables that are hidden from observations, information that distinguishes truncation and termination or individual reward terms that are combined to produce the total reward + + (deprecated) + done (bool): A boolean value for if the episode has ended, in which case further :meth:`step` calls will return undefined results. + A done signal may be emitted for different reasons: Maybe the task underlying the environment was solved successfully, + a certain timelimit was exceeded, or the physics simulation has entered an invalid state. """ raise NotImplementedError @@ -242,11 +251,12 @@ class Wrapper(Env[ObsType, ActType]): Don't forget to call ``super().__init__(env)`` if the subclass overrides :meth:`__init__`. """ - def __init__(self, env: Env): + def __init__(self, env: Env, new_step_api: bool = False): """Wraps an environment to allow a modular transformation of the :meth:`step` and :meth:`reset` methods. 
Args:
            env: The environment to wrap
+           new_step_api: Whether the wrapper's step method will output in new or old step API
        """
        self.env = env
@@ -254,6 +264,7 @@ def __init__(self, env: Env):
        self._observation_space: Optional[spaces.Space] = None
        self._reward_range: Optional[Tuple[SupportsFloat, SupportsFloat]] = None
        self._metadata: Optional[dict] = None
+       self.new_step_api = new_step_api

    def __getattr__(self, name):
        """Returns an attribute with ``name``, unless ``name`` starts with an underscore."""
@@ -315,9 +326,13 @@ def metadata(self) -> dict:
    def metadata(self, value):
        self._metadata = value

-   def step(self, action: ActType) -> Tuple[ObsType, float, bool, dict]:
+   def step(
+       self, action: ActType
+   ) -> Union[
+       Tuple[ObsType, float, bool, bool, dict], Tuple[ObsType, float, bool, dict]
+   ]:
        """Steps through the environment with action."""
-       return self.env.step(action)
+       return self.env.step(action)  # ! Does this take self.new_step_api into account?

    def reset(self, **kwargs) -> Union[ObsType, Tuple[ObsType, dict]]:
        """Resets the environment with kwargs."""
@@ -387,8 +402,13 @@ def reset(self, **kwargs):

    def step(self, action):
        """Returns a modified observation using :meth:`self.observation` after calling :meth:`env.step`."""
-       observation, reward, done, info = self.env.step(action)
-       return self.observation(observation), reward, done, info
+       step_returns = self.env.step(action)
+       if len(step_returns) == 5:
+           observation, reward, terminated, truncated, info = step_returns
+           return self.observation(observation), reward, terminated, truncated, info
+       else:
+           observation, reward, done, info = step_returns
+           return self.observation(observation), reward, done, info

    def observation(self, observation):
        """Returns a modified observation."""
@@ -421,8 +441,13 @@ def reward(self, reward):

    def step(self, action):
        """Modifies the reward using :meth:`self.reward` after the environment :meth:`env.step`."""
-       observation, reward, done, info = self.env.step(action)
-       return observation, self.reward(reward), done, info
+       step_returns = self.env.step(action)
+       if len(step_returns) == 5:
+           observation, reward, terminated, truncated, info = step_returns
+           return observation, self.reward(reward), terminated, truncated, info
+       else:
+           observation, reward, done, info = step_returns
+           return observation, self.reward(reward), done, info

    def reward(self, reward):
        """Returns a modified ``reward``."""
diff --git a/gym/envs/mujoco/ant_v4.py b/gym/envs/mujoco/ant_v4.py
index bdc26112ca0..2f6e0372077 100644
--- a/gym/envs/mujoco/ant_v4.py
+++ b/gym/envs/mujoco/ant_v4.py
@@ -131,12 +131,19 @@ class AntEnv(mujoco_env.MujocoEnv, utils.EzPickle):
    to be slightly high, thereby indicating a standing up ant. The initial orientation
    is designed to make it face forward as well.

-   ### Episode Termination
-   The episode terminates when any of the following happens:
+   ### Episode End
+   The ant is said to be unhealthy if any of the following happens:

-   1. The episode duration reaches a 1000 timesteps
-   2. Any of the state space values is no longer finite
-   3. The y-orientation (index 2) in the state is **not** in the range `[0.2, 1.0]`
+   1. Any of the state space values is no longer finite
+   2. The z-coordinate of the torso is **not** in the closed interval given by `healthy_z_range` (defaults to [0.2, 1.0])
+
+   If `terminate_when_unhealthy=True` is passed during construction (which is the default),
+   the episode ends when any of the following happens:
+
+   1. Truncation: The episode duration reaches a 1000 timesteps
+   2.
Termination: The ant is unhealthy
+
+   If `terminate_when_unhealthy=False` is passed, the episode is ended only when 1000 timesteps are exceeded.

    ### Arguments
@@ -229,9 +236,9 @@ def is_healthy(self):
        return is_healthy

    @property
-   def done(self):
-       done = not self.is_healthy if self._terminate_when_unhealthy else False
-       return done
+   def terminated(self):
+       terminated = not self.is_healthy if self._terminate_when_unhealthy else False
+       return terminated

    def step(self, action):
        xy_position_before = self.get_body_com("torso")[:2].copy()
@@ -248,7 +255,7 @@ def step(self, action):

        costs = ctrl_cost = self.control_cost(action)

-       done = self.done
+       terminated = self.terminated
        observation = self._get_obs()
        info = {
            "reward_forward": forward_reward,
@@ -268,7 +275,7 @@ def step(self, action):

        reward = rewards - costs

-       return observation, reward, done, info
+       return observation, reward, terminated, False, info

    def _get_obs(self):
        position = self.data.qpos.flat.copy()
diff --git a/gym/envs/mujoco/half_cheetah_v4.py b/gym/envs/mujoco/half_cheetah_v4.py
index dc834dd8f93..7ce02b3994e 100644
--- a/gym/envs/mujoco/half_cheetah_v4.py
+++ b/gym/envs/mujoco/half_cheetah_v4.py
@@ -117,8 +117,8 @@ class HalfCheetahEnv(mujoco_env.MujocoEnv, utils.EzPickle):
    normal noise with a mean of 0 and standard deviation of 0.1 is added to the
    initial velocity values of all zeros.

-   ### Episode Termination
-   The episode terminates when the episode length is greater than 1000.
+   ### Episode End
+   The episode truncates when the episode length is greater than 1000.

    ### Arguments
@@ -184,7 +184,7 @@ def step(self, action):
        observation = self._get_obs()
        reward = forward_reward - ctrl_cost
-       done = False
+       terminated = False
        info = {
            "x_position": x_position_after,
            "x_velocity": x_velocity,
            "reward_run": forward_reward,
            "reward_ctrl": -ctrl_cost,
        }

-       return observation, reward, done, info
+       return observation, reward, terminated, False, info

    def _get_obs(self):
        position = self.data.qpos.flat.copy()
diff --git a/gym/envs/mujoco/hopper_v4.py b/gym/envs/mujoco/hopper_v4.py
index 776cedf63eb..bcc3a269dd9 100644
--- a/gym/envs/mujoco/hopper_v4.py
+++ b/gym/envs/mujoco/hopper_v4.py
@@ -104,14 +104,20 @@ class HopperEnv(mujoco_env.MujocoEnv, utils.EzPickle):
    (0.0, 1.25, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0) with a uniform noise
    in the range of [-0.005, 0.005] added to the values for stochasticity.

-   ### Episode Termination
-   The episode terminates when any of the following happens:
+   ### Episode End
+   The hopper is said to be unhealthy if any of the following happens:

-   1. The episode duration reaches a 1000 timesteps
-   2. Any of the state space values is no longer finite
-   3. The absolute value of any of the state variable indexed (angle and beyond) is greater than 100
-   4. The height of the hopper becomes greater than 0.7 metres (hopper has hopped too high).
-   5. The absolute value of the angle (index 2) is less than 0.2 radians (hopper has fallen down).
+   1. An element of `observation[1:]` (if `exclude_current_positions_from_observation=True`, else `observation[2:]`) is no longer contained in the closed interval specified by the argument `healthy_state_range`
+   2. The height of the hopper (`observation[0]` if `exclude_current_positions_from_observation=True`, else `observation[1]`) is no longer contained in the closed interval specified by the argument `healthy_z_range` (usually meaning that it has fallen)
+   3.
The angle (`observation[1]` if `exclude_current_positions_from_observation=True`, else `observation[2]`) is no longer contained in the closed interval specified by the argument `healthy_angle_range`
+
+   If `terminate_when_unhealthy=True` is passed during construction (which is the default),
+   the episode ends when any of the following happens:
+
+   1. Truncation: The episode duration reaches a 1000 timesteps
+   2. Termination: The hopper is unhealthy
+
+   If `terminate_when_unhealthy=False` is passed, the episode is ended only when 1000 timesteps are exceeded.

    ### Arguments
@@ -201,9 +207,9 @@ def is_healthy(self):
        return is_healthy

    @property
-   def done(self):
-       done = not self.is_healthy if self._terminate_when_unhealthy else False
-       return done
+   def terminated(self):
+       terminated = not self.is_healthy if self._terminate_when_unhealthy else False
+       return terminated

    def _get_obs(self):
        position = self.data.qpos.flat.copy()
@@ -231,13 +237,13 @@ def step(self, action):

        observation = self._get_obs()
        reward = rewards - costs
-       done = self.done
+       terminated = self.terminated
        info = {
            "x_position": x_position_after,
            "x_velocity": x_velocity,
        }

-       return observation, reward, done, info
+       return observation, reward, terminated, False, info

    def reset_model(self):
        noise_low = -self._reset_noise_scale
diff --git a/gym/envs/mujoco/humanoid_v4.py b/gym/envs/mujoco/humanoid_v4.py
index 992503f6fd0..3c65781fce9 100644
--- a/gym/envs/mujoco/humanoid_v4.py
+++ b/gym/envs/mujoco/humanoid_v4.py
@@ -169,12 +169,17 @@ class HumanoidEnv(mujoco_env.MujocoEnv, utils.EzPickle):
    selected to be high, thereby indicating a standing up humanoid. The initial
    orientation is designed to make it face forward as well.

-   ### Episode Termination
-   The episode terminates when any of the following happens:
+   ### Episode End
+   The humanoid is said to be unhealthy if the z-position of the torso is no longer contained in the
+   closed interval specified by the argument `healthy_z_range`.

-   1. The episode duration reaches a 1000 timesteps
-   2. Any of the state space values is no longer finite
-   3. The z-coordinate of the torso (index 0 in state space OR index 2 in the table) is **not** in the range `[1.0, 2.0]` (the humanoid has fallen or is about to fall beyond recovery).
+   If `terminate_when_unhealthy=True` is passed during construction (which is the default),
+   the episode ends when any of the following happens:
+
+   1. Truncation: The episode duration reaches a 1000 timesteps
+   2. Termination: The humanoid is unhealthy
+
+   If `terminate_when_unhealthy=False` is passed, the episode is ended only when 1000 timesteps are exceeded.
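    For instance, the two end-of-episode signals can then be told apart directly
    (a sketch; assumes the `Humanoid-v4` registration and `new_step_api=True` in make):

    >>> env = gym.make("Humanoid-v4", new_step_api=True)
    >>> env.reset()
    >>> obs, reward, terminated, truncated, info = env.step(env.action_space.sample())
    >>> # terminated: the humanoid became unhealthy; truncated: the 1000-step limit was hit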
### Arguments
@@ -248,9 +253,9 @@ def is_healthy(self):
        return is_healthy

    @property
-   def done(self):
-       done = (not self.is_healthy) if self._terminate_when_unhealthy else False
-       return done
+   def terminated(self):
+       terminated = (not self.is_healthy) if self._terminate_when_unhealthy else False
+       return terminated

    def _get_obs(self):
        position = self.data.qpos.flat.copy()
@@ -293,7 +298,7 @@ def step(self, action):

        observation = self._get_obs()
        reward = rewards - ctrl_cost
-       done = self.done
+       terminated = self.terminated
        info = {
            "reward_linvel": forward_reward,
            "reward_quadctrl": -ctrl_cost,
@@ -306,7 +311,7 @@ def step(self, action):
            "forward_reward": forward_reward,
        }

-       return observation, reward, done, info
+       return observation, reward, terminated, False, info

    def reset_model(self):
        noise_low = -self._reset_noise_scale
diff --git a/gym/envs/mujoco/humanoidstandup_v4.py b/gym/envs/mujoco/humanoidstandup_v4.py
index c15d8c022b6..8759bb3c16d 100644
--- a/gym/envs/mujoco/humanoidstandup_v4.py
+++ b/gym/envs/mujoco/humanoidstandup_v4.py
@@ -156,11 +156,11 @@ class HumanoidStandupEnv(mujoco_env.MujocoEnv, utils.EzPickle):
    to be low, thereby indicating a laying down humanoid. The initial orientation is
    designed to make it face forward as well.

-   ### Episode Termination
-   The episode terminates when any of the following happens:
+   ### Episode End
+   The episode ends when any of the following happens:

-   1. The episode duration reaches a 1000 timesteps
-   2. Any of the state space values is no longer finite
+   1. Truncation: The episode duration reaches a 1000 timesteps
+   2. Termination: Any of the state space values is no longer finite

    ### Arguments
@@ -218,11 +218,11 @@ def step(self, a):
        quad_impact_cost = min(quad_impact_cost, 10)
        reward = uph_cost - quad_ctrl_cost - quad_impact_cost + 1

-       done = bool(False)
        return (
            self._get_obs(),
            reward,
-           done,
+           False,
+           False,
            dict(
                reward_linup=uph_cost,
                reward_quadctrl=-quad_ctrl_cost,
diff --git a/gym/envs/mujoco/inverted_double_pendulum_v4.py b/gym/envs/mujoco/inverted_double_pendulum_v4.py
index ca5fafbbe96..eb6a00e545f 100644
--- a/gym/envs/mujoco/inverted_double_pendulum_v4.py
+++ b/gym/envs/mujoco/inverted_double_pendulum_v4.py
@@ -84,12 +84,12 @@ class InvertedDoublePendulumEnv(mujoco_env.MujocoEnv, utils.EzPickle):
    of [-0.1, 0.1] added to the positional values (cart position and pole angles) and standard
    normal force with a standard deviation of 0.1 added to the velocity values for stochasticity.

-   ### Episode Termination
-   The episode terminates when any of the following happens:
+   ### Episode End
+   The episode ends when any of the following happens:

-   1. The episode duration reaches 1000 timesteps.
-   2. Any of the state space values is no longer finite.
-   3. The y_coordinate of the tip of the second pole *is less than or equal* to 1. The maximum standing height of the system is 1.196 m when all the parts are perpendicularly vertical on top of each other).
+   1. Truncation: The episode duration reaches 1000 timesteps.
+   2. Termination: Any of the state space values is no longer finite.
+   3. Termination: The y_coordinate of the tip of the second pole *is less than or equal* to 1. The maximum standing height of the system is 1.196 m when all the parts are perpendicularly vertical on top of each other.
### Arguments
@@ -131,8 +131,8 @@ def step(self, action):
        vel_penalty = 1e-3 * v1**2 + 5e-3 * v2**2
        alive_bonus = 10
        r = alive_bonus - dist_penalty - vel_penalty
-       done = bool(y <= 1)
-       return ob, r, done, {}
+       terminated = bool(y <= 1)
+       return ob, r, terminated, False, {}

    def _get_obs(self):
        return np.concatenate(
diff --git a/gym/envs/mujoco/inverted_pendulum_v4.py b/gym/envs/mujoco/inverted_pendulum_v4.py
index 0044b99f41e..e26ce1d2446 100644
--- a/gym/envs/mujoco/inverted_pendulum_v4.py
+++ b/gym/envs/mujoco/inverted_pendulum_v4.py
@@ -55,12 +55,12 @@ class InvertedPendulumEnv(mujoco_env.MujocoEnv, utils.EzPickle):
    (0.0, 0.0, 0.0, 0.0) with a uniform noise in the range
    of [-0.01, 0.01] added to the values for stochasticity.

-   ### Episode Termination
-   The episode terminates when any of the following happens:
+   ### Episode End
+   The episode ends when any of the following happens:

-   1. The episode duration reaches 1000 timesteps.
-   2. Any of the state space values is no longer finite.
-   3. The absolute value of the vertical angle between the pole and the cart is greater than 0.2 radians.
+   1. Truncation: The episode duration reaches 1000 timesteps.
+   2. Termination: Any of the state space values is no longer finite.
+   3. Termination: The absolute value of the vertical angle between the pole and the cart is greater than 0.2 radians.

    ### Arguments
@@ -98,9 +98,8 @@ def step(self, a):
        reward = 1.0
        self.do_simulation(a, self.frame_skip)
        ob = self._get_obs()
-       notdone = np.isfinite(ob).all() and (np.abs(ob[1]) <= 0.2)
-       done = not notdone
-       return ob, reward, done, {}
+       terminated = not np.isfinite(ob).all() or (np.abs(ob[1]) > 0.2)
+       return ob, reward, terminated, False, {}

    def reset_model(self):
        qpos = self.init_qpos + self.np_random.uniform(
diff --git a/gym/envs/mujoco/pusher_v4.py b/gym/envs/mujoco/pusher_v4.py
index acdd94bf150..7406308b89b 100644
--- a/gym/envs/mujoco/pusher_v4.py
+++ b/gym/envs/mujoco/pusher_v4.py
@@ -98,12 +98,12 @@ class PusherEnv(mujoco_env.MujocoEnv, utils.EzPickle):
    The default framerate is 5 with each frame lasting for 0.01, giving rise to a *dt = 5 * 0.01 = 0.05*

-   ### Episode Termination
+   ### Episode End

-   The episode terminates when any of the following happens:
+   The episode ends when any of the following happens:

-   1. The episode duration reaches a 100 timesteps.
-   2. Any of the state space values is no longer finite.
+   1. Truncation: The episode duration reaches a 100 timesteps.
+   2. Termination: Any of the state space values is no longer finite.

    ### Arguments
@@ -147,8 +147,13 @@ def step(self, a):

        self.do_simulation(a, self.frame_skip)
        ob = self._get_obs()
-       done = False
-       return ob, reward, done, dict(reward_dist=reward_dist, reward_ctrl=reward_ctrl)
+       return (
+           ob,
+           reward,
+           False,
+           False,
+           dict(reward_dist=reward_dist, reward_ctrl=reward_ctrl),
+       )

    def viewer_setup(self):
        self.viewer.cam.trackbodyid = -1
diff --git a/gym/envs/mujoco/reacher_v4.py b/gym/envs/mujoco/reacher_v4.py
index 278770fa320..50ba0c2b036 100644
--- a/gym/envs/mujoco/reacher_v4.py
+++ b/gym/envs/mujoco/reacher_v4.py
@@ -88,12 +88,12 @@ class ReacherEnv(mujoco_env.MujocoEnv, utils.EzPickle):
    element ("fingertip" - "target") is calculated at the end once everything is set. The default setting has a framerate of 2 and a *dt = 2 * 0.01 = 0.02*

-   ### Episode Termination
+   ### Episode End

-   The episode terminates when any of the following happens:
+   The episode ends when any of the following happens:

-   1.
The episode duration reaches a 50 timesteps (with a new random target popping up if the reacher's fingertip reaches it before 50 timesteps) - 2. Any of the state space values is no longer finite. + 1. Truncation: The episode duration reaches a 50 timesteps (with a new random target popping up if the reacher's fingertip reaches it before 50 timesteps) + 2. Termination: Any of the state space values is no longer finite. ### Arguments @@ -133,8 +133,13 @@ def step(self, a): reward = reward_dist + reward_ctrl self.do_simulation(a, self.frame_skip) ob = self._get_obs() - done = False - return ob, reward, done, dict(reward_dist=reward_dist, reward_ctrl=reward_ctrl) + return ( + ob, + reward, + False, + False, + dict(reward_dist=reward_dist, reward_ctrl=reward_ctrl), + ) def viewer_setup(self): self.viewer.cam.trackbodyid = 0 diff --git a/gym/envs/mujoco/swimmer_v4.py b/gym/envs/mujoco/swimmer_v4.py index aa499d14362..ddb5c58eea7 100644 --- a/gym/envs/mujoco/swimmer_v4.py +++ b/gym/envs/mujoco/swimmer_v4.py @@ -102,8 +102,8 @@ class SwimmerEnv(mujoco_env.MujocoEnv, utils.EzPickle): ### Starting State All observations start in state (0,0,0,0,0,0,0,0) with a Uniform noise in the range of [-0.1, 0.1] is added to the initial state for stochasticity. - ### Episode Termination - The episode terminates when the episode length is greater than 1000. + ### Episode End + The episode truncates when the episode length is greater than 1000. ### Arguments @@ -169,7 +169,6 @@ def step(self, action): observation = self._get_obs() reward = forward_reward - ctrl_cost - done = False info = { "reward_fwd": forward_reward, "reward_ctrl": -ctrl_cost, @@ -181,7 +180,7 @@ def step(self, action): "forward_reward": forward_reward, } - return observation, reward, done, info + return observation, reward, False, False, info def _get_obs(self): position = self.data.qpos.flat.copy() diff --git a/gym/envs/mujoco/walker2d_v4.py b/gym/envs/mujoco/walker2d_v4.py index e3085d8a121..77c38e75d88 100644 --- a/gym/envs/mujoco/walker2d_v4.py +++ b/gym/envs/mujoco/walker2d_v4.py @@ -122,13 +122,20 @@ class Walker2dEnv(mujoco_env.MujocoEnv, utils.EzPickle): (0.0, 1.25, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0) with a uniform noise in the range of [-0.005, 0.005] added to the values for stochasticity. - ### Episode Termination - The episode terminates when any of the following happens: + ### Episode End + The walker is said to be unhealthy if any of the following happens: - 1. The episode duration reaches a 1000 timesteps - 2. Any of the state space values is no longer finite - 3. The height of the walker (index 1) is ***not*** in the range `[0.8, 2]` - 4. The absolute value of the angle (index 2) is ***not*** in the range `[-1, 1]` + 1. Any of the state space values is no longer finite + 2. The height of the walker is ***not*** in the closed interval specified by `healthy_z_range` + 3. The absolute value of the angle (`observation[1]` if `exclude_current_positions_from_observation=False`, else `observation[2]`) is ***not*** in the closed interval specified by `healthy_angle_range` + + If `terminate_when_unhealthy=True` is passed during construction (which is the default), + the episode ends when any of the following happens: + + 1. Truncation: The episode duration reaches a 1000 timesteps + 2. Termination: The walker is unhealthy + + If `terminate_when_unhealthy=False` is passed, the episode is ended only when 1000 timesteps are exceeded. 
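    For example, switching the health check off leaves the time limit as the only
    way an episode can end (a sketch; assumes the `Walker2d-v4` registration):

    >>> env = gym.make("Walker2d-v4", terminate_when_unhealthy=False, new_step_api=True)
    >>> env.reset()
    >>> obs, reward, terminated, truncated, info = env.step(env.action_space.sample())
    >>> # terminated stays False here; only truncated (via TimeLimit) can end the episode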
### Arguments @@ -212,9 +219,9 @@ def is_healthy(self): return is_healthy @property - def done(self): - done = not self.is_healthy if self._terminate_when_unhealthy else False - return done + def terminated(self): + terminated = not self.is_healthy if self._terminate_when_unhealthy else False + return terminated def _get_obs(self): position = self.data.qpos.flat.copy() @@ -242,13 +249,13 @@ def step(self, action): observation = self._get_obs() reward = rewards - costs - done = self.done + terminated = self.terminated info = { "x_position": x_position_after, "x_velocity": x_velocity, } - return observation, reward, done, info + return observation, reward, terminated, False, info def reset_model(self): noise_low = -self._reset_noise_scale diff --git a/gym/envs/registration.py b/gym/envs/registration.py index a265d536a21..e5e36e07d12 100644 --- a/gym/envs/registration.py +++ b/gym/envs/registration.py @@ -24,7 +24,12 @@ from gym.envs.__relocated__ import internal_env_relocation_map from gym.utils.env_checker import check_env -from gym.wrappers import AutoResetWrapper, OrderEnforcing, TimeLimit +from gym.wrappers import ( + AutoResetWrapper, + OrderEnforcing, + StepAPICompatibility, + TimeLimit, +) if sys.version_info < (3, 10): import importlib_metadata as metadata # type: ignore @@ -522,6 +527,7 @@ def make( id: Union[str, EnvSpec], max_episode_steps: Optional[int] = None, autoreset: bool = False, + new_step_api: bool = False, disable_env_checker: bool = False, **kwargs, ) -> Env: @@ -536,6 +542,7 @@ def make( id: Name of the environment. Optionally, a module to import can be included, eg. 'module:Env-v0' max_episode_steps: Maximum length of an episode (TimeLimit wrapper). autoreset: Whether to automatically reset the environment after each episode (AutoResetWrapper). + new_step_api: Whether to use old or new step API (StepAPICompatibility wrapper). Will be removed at v1.0 disable_env_checker: If to disable the environment checker kwargs: Additional arguments to pass to the environment constructor. @@ -623,6 +630,15 @@ def make( f"You can set `disable_env_checker=True` to disable this check." ) + if not disable_env_checker: + try: + check_env(env) + except Exception as e: + logger.warn( + f"Env check failed with the following message: {e}\n" + f"You can set `disable_env_checker=True` to disable this check." 
+ ) + return env diff --git a/gym/envs/toy_text/cliffwalking.py b/gym/envs/toy_text/cliffwalking.py index 7ab613dadfd..18ff2021d08 100644 --- a/gym/envs/toy_text/cliffwalking.py +++ b/gym/envs/toy_text/cliffwalking.py @@ -107,7 +107,7 @@ def _calculate_transition_prob(self, current, delta): delta: Change in position for transition Returns: - Tuple of ``(1.0, new_state, reward, done)`` + Tuple of ``(1.0, new_state, reward, terminated)`` """ new_position = np.array(current) + np.array(delta) new_position = self._limit_coordinates(new_position).astype(int) diff --git a/gym/utils/env_checker.py b/gym/utils/env_checker.py index 8dcc790e60d..67c3089c5d4 100644 --- a/gym/utils/env_checker.py +++ b/gym/utils/env_checker.py @@ -19,6 +19,7 @@ import gym from gym import logger from gym.spaces import Box, Dict, Discrete, Space, Tuple +from gym.utils.step_api_compatibility import step_api_compatibility def _is_numpy_array_space(space: Space) -> bool: @@ -71,7 +72,7 @@ def _check_nan(env: gym.Env, check_inf: bool = True): """ for _ in range(10): action = env.action_space.sample() - observation, reward, done, _ = env.step(action) + observation, reward, done, _ = step_api_compatibility(env.step(action), False) if done: env.reset() diff --git a/gym/utils/play.py b/gym/utils/play.py index c96769f29d5..42263d41f41 100644 --- a/gym/utils/play.py +++ b/gym/utils/play.py @@ -1,4 +1,7 @@ """Utilities of visualising an environment.""" + +# TODO: Convert to new step API in 1.0 + from collections import deque from typing import Callable, Dict, List, Optional, Tuple, Union diff --git a/gym/utils/step_api_compatibility.py b/gym/utils/step_api_compatibility.py index 48fddf33285..f0454153420 100644 --- a/gym/utils/step_api_compatibility.py +++ b/gym/utils/step_api_compatibility.py @@ -1,16 +1,36 @@ +"""Contains methods for step compatibility, from old-to-new and new-to-old API, to be removed in 1.0.""" +from typing import Tuple, Union + import numpy as np +from gym.core import ObsType from gym.logger import deprecation +OldStepType = Tuple[ + Union[ObsType, np.ndarray], + Union[float, np.ndarray], + Union[bool, np.ndarray], + Union[dict, list], +] + +NewStepType = Tuple[ + Union[ObsType, np.ndarray], + Union[float, np.ndarray], + Union[bool, np.ndarray], + Union[bool, np.ndarray], + Union[dict, list], +] + -def step_to_new_api(step_returns, is_vector_env=False): - """Function to transform step returns to new step API irrespective of input API +def step_to_new_api( + step_returns: Union[OldStepType, NewStepType], is_vector_env=False +) -> NewStepType: + """Function to transform step returns to new step API irrespective of input API. Args: step_returns (tuple): Items returned by step(). Can be (obs, rew, done, info) or (obs, rew, terminated, truncated, info) is_vector_env (bool): Whether the step_returns are from a vector environment """ - if len(step_returns) == 5: deprecation( "Using an environment with new step API that returns two bools terminated, truncated instead of one bool done. 
" @@ -32,12 +52,26 @@ def step_to_new_api(step_returns, is_vector_env=False): truncateds = [] if not is_vector_env: dones = [dones] - infos = [infos] + for i in range(len(dones)): - if "TimeLimit.truncated" not in infos[i]: + if "TimeLimit.truncated" not in infos or ( + is_vector_env + and "TimeLimit.truncated" in infos + and not infos["_TimeLimit.truncated"][ + i + ] # if mask is False, it's the same as TimeLimit.truncated attribute not being present + ): terminateds.append(dones[i]) truncateds.append(False) - elif infos[i]["TimeLimit.truncated"]: + elif ( + infos["TimeLimit.truncated"] + if not is_vector_env + else ( # handle vector info as both dict and list + infos["TimeLimit.truncated"][i] + if isinstance(infos, dict) + else infos[i]["TimeLimit.truncated"] + ) + ): terminateds.append(False) truncateds.append(True) else: @@ -52,18 +86,19 @@ def step_to_new_api(step_returns, is_vector_env=False): rewards, np.array(terminateds, dtype=np.bool_) if is_vector_env else terminateds[0], np.array(truncateds, dtype=np.bool_) if is_vector_env else truncateds[0], - infos if is_vector_env else infos[0], + infos, ) -def step_to_old_api(step_returns, is_vector_env=False): - """Function to transform step returns to old step API irrespective of input API +def step_to_old_api( + step_returns: Union[NewStepType, OldStepType], is_vector_env: bool = False +) -> OldStepType: + """Function to transform step returns to old step API irrespective of input API. Args: step_returns (tuple): Items returned by step(). Can be (obs, rew, done, info) or (obs, rew, terminated, truncated, info) is_vector_env (bool): Whether the step_returns are from a vector environment """ - if len(step_returns) == 4: deprecation( "Using old step API which returns one boolean (done). Please upgrade to new API to return two booleans - terminated, truncated" @@ -82,24 +117,41 @@ def step_to_old_api(step_returns, is_vector_env=False): if not is_vector_env: terminateds = [terminateds] truncateds = [truncateds] - infos = [infos] - for i in range(len(terminateds)): + n_envs = len(terminateds) + + for i in range(n_envs): dones.append(terminateds[i] or truncateds[i]) # to be consistent with old API if truncateds[i]: - infos[i]["TimeLimit.truncated"] = not terminateds[i] + if is_vector_env: + # handle vector infos for dict and list + if isinstance(infos, dict): + if "TimeLimit.truncated" not in infos: + # TODO: This should ideally not be done manually and should use vector_env's _add_info() + infos["TimeLimit.truncated"] = np.zeros(n_envs, dtype=bool) + infos["_TimeLimit.truncated"] = np.zeros(n_envs, dtype=bool) + + infos["TimeLimit.truncated"][i] = not terminateds[i] + infos["_TimeLimit.truncated"][i] = True + else: + # if vector info is a list + infos[i]["TimeLimit.truncated"] = not terminateds[i] + else: + infos["TimeLimit.truncated"] = not terminateds[i] return ( observations, rewards, np.array(dones, dtype=np.bool_) if is_vector_env else dones[0], - infos if is_vector_env else infos[0], + infos, ) def step_api_compatibility( - step_returns, new_step_api: bool = False, is_vector_env: bool = False -): + step_returns: Union[NewStepType, OldStepType], + new_step_api: bool = False, + is_vector_env: bool = False, +) -> Union[NewStepType, OldStepType]: """Function to transform step returns to the API specified by `new_step_api` bool. 
Old step API refers to step() method returning (observation, reward, done, info) @@ -121,9 +173,7 @@ def step_api_compatibility( >>> obs, rew, done, info = step_api_compatibility(env.step(action)) >>> obs, rew, terminated, truncated, info = step_api_compatibility(env.step(action), new_step_api=True) >>> observations, rewards, dones, infos = step_api_compatibility(vec_env.step(action), is_vector_env=True) - """ - if new_step_api: return step_to_new_api(step_returns, is_vector_env) else: diff --git a/gym/vector/__init__.py b/gym/vector/__init__.py index 7edcef883a3..cbf31abd02f 100644 --- a/gym/vector/__init__.py +++ b/gym/vector/__init__.py @@ -13,6 +13,7 @@ def make( num_envs: int = 1, asynchronous: bool = True, wrappers: Optional[Union[callable, List[callable]]] = None, + new_step_api: bool = False, **kwargs, ) -> VectorEnv: """Create a vectorized environment from multiple copies of an environment, from its id. @@ -32,6 +33,7 @@ def make( num_envs: Number of copies of the environment. asynchronous: If `True`, wraps the environments in an :class:`AsyncVectorEnv` (which uses `multiprocessing`_ to run the environments in parallel). If ``False``, wraps the environments in a :class:`SyncVectorEnv`. wrappers: If not ``None``, then apply the wrappers to each internal environment during creation. + new_step_api: If True, the vector environment's step method outputs two booleans `terminated`, `truncated` instead of one `done`. **kwargs: Keywords arguments applied during gym.make Returns: diff --git a/gym/vector/async_vector_env.py b/gym/vector/async_vector_env.py index 3f40c71abbb..aaffe62ae69 100644 --- a/gym/vector/async_vector_env.py +++ b/gym/vector/async_vector_env.py @@ -67,6 +67,7 @@ def __init__( context: Optional[str] = None, daemon: bool = True, worker: Optional[callable] = None, + new_step_api: bool = False, ): """Vectorized environment that runs multiple environments in parallel. @@ -86,6 +87,7 @@ def __init__( so for some environments you may want to have it set to ``False``. worker: If set, then use that worker in a subprocess instead of a default one. Can be useful to override some inner vector env logic, for instance, how resets on done are handled. + new_step_api: If True, step method returns 2 bools - terminated, truncated, instead of 1 bool - done Warnings: worker is an advanced mode option. It provides a high degree of flexibility and a high chance to shoot yourself in the foot; thus, if you are writing your own worker, it is recommended to start @@ -340,7 +342,7 @@ def step_wait( timeout: Number of seconds before the call to :meth:`step_wait` times out. If ``None``, the call to :meth:`step_wait` never times out. Returns: - The batched environment step information, obs, reward, done and info + The batched environment step information, (obs, reward, terminated, truncated, info) or (obs, reward, done, info) depending on new_step_api Raises: ClosedEnvironmentError: If the environment was closed (if :meth:`close` was previously called). @@ -360,16 +362,17 @@ def step_wait( f"The call to `step_wait` has timed out after {timeout} second(s)." 
) - observations_list, rewards, dones, infos = [], [], [], {} + observations_list, rewards, terminateds, truncateds, infos = [], [], [], [], {} successes = [] for i, pipe in enumerate(self.parent_pipes): result, success = pipe.recv() - obs, rew, done, info = result + obs, rew, terminated, truncated, info = step_api_compatibility(result, True) successes.append(success) observations_list.append(obs) rewards.append(rew) - dones.append(done) + terminateds.append(terminated) + truncateds.append(truncated) infos = self._add_info(infos, info, i) self._raise_if_errors(successes) diff --git a/gym/vector/sync_vector_env.py b/gym/vector/sync_vector_env.py index 02ae0f2c775..25264a45bd1 100644 --- a/gym/vector/sync_vector_env.py +++ b/gym/vector/sync_vector_env.py @@ -5,6 +5,7 @@ import numpy as np from gym.spaces import Space +from gym.utils.step_api_compatibility import step_api_compatibility from gym.vector.utils import concatenate, create_empty_array, iterate from gym.vector.vector_env import VectorEnv @@ -32,6 +33,7 @@ def __init__( observation_space: Space = None, action_space: Space = None, copy: bool = True, + new_step_api: bool = False, ): """Vectorized environment that serially runs multiple environments. diff --git a/gym/vector/vector_env.py b/gym/vector/vector_env.py index 46d62db7196..103cdfb5606 100644 --- a/gym/vector/vector_env.py +++ b/gym/vector/vector_env.py @@ -1,5 +1,5 @@ """Base class for vectorized environments.""" -from typing import Any, List, Optional, Union +from typing import Any, List, Optional, Tuple, Union import numpy as np @@ -24,7 +24,11 @@ class VectorEnv(gym.Env): """ def __init__( - self, num_envs: int, observation_space: gym.Space, action_space: gym.Space + self, + num_envs: int, + observation_space: gym.Space, + action_space: gym.Space, + new_step_api: bool = False, ): """Base class for vectorized environments. @@ -32,6 +36,7 @@ def __init__( num_envs: Number of environments in the vectorized environment. observation_space: Observation space of a single environment. action_space: Action space of a single environment. + new_step_api (bool): Whether the vector env's step method outputs two boolean arrays (new API) or one boolean array (old API) """ self.num_envs = num_envs self.is_vector_env = True @@ -136,7 +141,7 @@ def step(self, actions): actions: element of :attr:`action_space` Batch of actions. Returns: - Batch of observations, rewards, done and infos + Batch of (observations, rewards, terminateds, truncateds, infos) or (observations, rewards, dones, infos) """ self.step_async(actions) return self.step_wait() @@ -144,7 +149,7 @@ def step(self, actions): def call_async(self, name, *args, **kwargs): """Calls a method name for each parallel environment asynchronously.""" - def call_wait(self, **kwargs) -> List[Any]: + def call_wait(self, **kwargs) -> List[Any]: # type: ignore """After calling a method in :meth:`call_async`, this function collects the results.""" def call(self, name: str, *args, **kwargs) -> List[Any]: @@ -252,7 +257,7 @@ def _add_info(self, infos: dict, info: dict, env_num: int) -> dict: infos[k], infos[f"_{k}"] = info_array, array_mask return infos - def _init_info_arrays(self, dtype: type) -> np.ndarray: + def _init_info_arrays(self, dtype: type) -> Tuple[np.ndarray, np.ndarray]: """Initialize the info array. Initialize the info array. 
If the dtype is numeric @@ -293,10 +298,6 @@ def __repr__(self) -> str: else: return f"{self.__class__.__name__}({self.spec.id}, {self.num_envs})" - @staticmethod - def get_env_step_return(env, action): - return env.step(action) - class VectorEnvWrapper(VectorEnv): """Wraps the vectorized environment to allow a modular transformation. diff --git a/gym/wrappers/atari_preprocessing.py b/gym/wrappers/atari_preprocessing.py index 7e2db2c1d8a..794dc826de3 100644 --- a/gym/wrappers/atari_preprocessing.py +++ b/gym/wrappers/atari_preprocessing.py @@ -4,6 +4,7 @@ import gym from gym.error import DependencyNotInstalled from gym.spaces import Box +from gym.utils.step_api_compatibility import step_api_compatibility try: import cv2 @@ -38,6 +39,7 @@ def __init__( grayscale_obs: bool = True, grayscale_newaxis: bool = False, scale_obs: bool = False, + new_step_api: bool = False, ): """Wrapper for Atari 2600 preprocessing. @@ -59,7 +61,7 @@ def __init__( DependencyNotInstalled: opencv-python package not installed ValueError: Disable frame-skipping in the original env """ - super().__init__(env) + super().__init__(env, new_step_api) if cv2 is None: raise DependencyNotInstalled( "opencv-python package not installed, run `pip install gym[other]` to get dependencies for atari" @@ -117,14 +119,16 @@ def step(self, action): total_reward = 0.0 for t in range(self.frame_skip): - _, reward, done, info = self.env.step(action) + _, reward, terminated, truncated, info = step_api_compatibility( + self.env.step(action), True + ) total_reward += reward - self.game_over = done + self.game_over = terminated if self.terminal_on_life_loss: new_lives = self.ale.lives() - done = done or new_lives < self.lives - self.game_over = done + terminated = terminated or new_lives < self.lives + self.game_over = terminated self.lives = new_lives if terminated or truncated: @@ -139,7 +143,10 @@ def step(self, action): self.ale.getScreenGrayscale(self.obs_buffer[0]) else: self.ale.getScreenRGB(self.obs_buffer[0]) - return self._get_obs(), total_reward, done, info + return step_api_compatibility( + (self._get_obs(), total_reward, terminated, truncated, info), + self.new_step_api, + ) def reset(self, **kwargs): """Resets the environment using preprocessing.""" diff --git a/gym/wrappers/autoreset.py b/gym/wrappers/autoreset.py index ec7e211129d..8f6d24db165 100644 --- a/gym/wrappers/autoreset.py +++ b/gym/wrappers/autoreset.py @@ -1,4 +1,4 @@ -"""Wrapper that autoreset environments when `done=True`.""" +"""Wrapper that autoreset environments when `terminated=True` or `truncated=True`.""" import gym from gym.utils.step_api_compatibility import step_api_compatibility @@ -6,23 +6,33 @@ class AutoResetWrapper(gym.Wrapper): """A class for providing an automatic reset functionality for gym environments when calling :meth:`self.step`. - When calling step causes :meth:`Env.step` to return done, :meth:`Env.reset` is called, - and the return format of :meth:`self.step` is as follows: ``(new_obs, terminal_reward, terminal_done, info)`` + When calling step causes :meth:`Env.step` to return `terminated=True` or `truncated=True`, :meth:`Env.reset` is called, + and the return format of :meth:`self.step` is as follows: ``(new_obs, closing_reward, closing_terminated, closing_truncated, info)`` + with new step API and ``(new_obs, closing_reward, closing_done, info)`` with the old step API. 
- ``new_obs`` is the first observation after calling :meth:`self.env.reset` - - ``terminal_reward`` is the reward after calling :meth:`self.env.step`, prior to calling :meth:`self.env.reset`. - - ``terminal_done`` is always True + - ``closing_reward`` is the reward after calling :meth:`self.env.step`, prior to calling :meth:`self.env.reset`. + - ``closing_done`` is always True. In the new API, either ``closing_terminated`` or ``closing_truncated`` is True - ``info`` is a dict containing all the keys from the info dict returned by the call to :meth:`self.env.reset`, - with an additional key "terminal_observation" containing the observation returned by the last call to :meth:`self.env.step` - and "terminal_info" containing the info dict returned by the last call to :meth:`self.env.step`. + with an additional key "closing_observation" containing the observation returned by the last call to :meth:`self.env.step` + and "closing_info" containing the info dict returned by the last call to :meth:`self.env.step`. Warning: When using this wrapper to collect rollouts, note that when :meth:`Env.step` returns done, a new observation from after calling :meth:`Env.reset` is returned by :meth:`Env.step` alongside the - terminal reward and done state from the previous episode. - If you need the terminal state from the previous episode, you need to retrieve it via the - "terminal_observation" key in the info dict. + closing reward and done state from the previous episode. + If you need the closing state from the previous episode, you need to retrieve it via the + "closing_observation" key in the info dict. Make sure you know what you're doing if you use this wrapper! """ + def __init__(self, env: gym.Env, new_step_api: bool = False): + """A class for providing an automatic reset functionality for gym environments when calling :meth:`self.step`. + + Args: + env (gym.Env): The environment to apply the wrapper + new_step_api (bool): Whether the wrapper's step method outputs two booleans (new API) or one boolean (old API) + """ + super().__init__(env, new_step_api) + def step(self, action): """Steps through the environment with action and resets the environment if a done-signal is encountered. @@ -32,7 +42,9 @@ def step(self, action): Returns: The autoreset environment :meth:`step` """ - obs, reward, done, info = self.env.step(action) + obs, reward, terminated, truncated, info = step_api_compatibility( + self.env.step(action), True + ) if terminated or truncated: diff --git a/gym/wrappers/frame_stack.py b/gym/wrappers/frame_stack.py index 715557c9e28..a6ca7193f63 100644 --- a/gym/wrappers/frame_stack.py +++ b/gym/wrappers/frame_stack.py @@ -123,15 +123,22 @@ class FrameStack(gym.ObservationWrapper): (4, 96, 96, 3) """ - def __init__(self, env: gym.Env, num_stack: int, lz4_compress: bool = False): + def __init__( + self, + env: gym.Env, + num_stack: int, + lz4_compress: bool = False, + new_step_api: bool = False, + ): """Observation wrapper that stacks the observations in a rolling manner. 
        Args:
            env (Env): The environment to apply the wrapper
            num_stack (int): The number of frames to stack
            lz4_compress (bool): Use lz4 to compress the frames internally
+            new_step_api (bool): Whether the wrapper's step method outputs two booleans (new API) or one boolean (old API)
         """
-        super().__init__(env)
+        super().__init__(env, new_step_api)
         self.num_stack = num_stack
         self.lz4_compress = lz4_compress
@@ -166,9 +173,13 @@ def step(self, action):
         Returns:
             Stacked observations, reward, done and information from the environment
         """
-        observation, reward, done, info = self.env.step(action)
+        observation, reward, terminated, truncated, info = step_api_compatibility(
+            self.env.step(action), True
+        )
         self.frames.append(observation)
-        return self.observation(None), reward, done, info
+        return step_api_compatibility(
+            (self.observation(None), reward, terminated, truncated, info), self.new_step_api
+        )

     def reset(self, **kwargs):
         """Reset the environment with kwargs.
diff --git a/gym/wrappers/normalize.py b/gym/wrappers/normalize.py
index 9a96d484941..53b7ca2e28a 100644
--- a/gym/wrappers/normalize.py
+++ b/gym/wrappers/normalize.py
@@ -55,14 +55,15 @@ class NormalizeObservation(gym.core.Wrapper):
     newly instantiated or the policy was changed recently.
     """

-    def __init__(self, env: gym.Env, epsilon: float = 1e-8):
+    def __init__(self, env: gym.Env, epsilon: float = 1e-8, new_step_api: bool = False):
         """This wrapper will normalize observations s.t. each coordinate is centered with unit variance.

         Args:
             env (Env): The environment to apply the wrapper
             epsilon: A stability parameter that is used when scaling the observations.
+            new_step_api (bool): Whether the wrapper's step method outputs two booleans (new API) or one boolean (old API)
         """
-        super().__init__(env)
+        super().__init__(env, new_step_api)
         self.num_envs = getattr(env, "num_envs", 1)
         self.is_vector_env = getattr(env, "is_vector_env", False)
         if self.is_vector_env:
@@ -73,7 +74,9 @@ def __init__(self, env: gym.Env, epsilon: float = 1e-8):

     def step(self, action):
         """Steps through the environment and normalizes the observation."""
-        obs, rews, dones, infos = self.env.step(action)
+        obs, rews, terminateds, truncateds, infos = step_api_compatibility(
+            self.env.step(action), True, self.is_vector_env
+        )
         if self.is_vector_env:
             obs = self.normalize(obs)
         else:
@@ -121,6 +124,7 @@ def __init__(
         env: gym.Env,
         gamma: float = 0.99,
         epsilon: float = 1e-8,
+        new_step_api: bool = False,
     ):
         """This wrapper will normalize immediate rewards s.t. their exponential moving average has a fixed variance.

         Args:
             env (env): The environment to apply the wrapper
             epsilon (float): A stability parameter
             gamma (float): The discount factor that is used in the exponential moving average.
+ new_step_api (bool): Whether the wrapper's step method outputs two booleans (new API) or one boolean (old API) """ - super().__init__(env) + super().__init__(env, new_step_api) self.num_envs = getattr(env, "num_envs", 1) self.is_vector_env = getattr(env, "is_vector_env", False) self.return_rms = RunningMeanStd(shape=()) @@ -139,7 +144,9 @@ def __init__( def step(self, action): """Steps through the environment, normalizing the rewards returned.""" - obs, rews, dones, infos = self.env.step(action) + obs, rews, terminateds, truncateds, infos = step_api_compatibility( + self.env.step(action), True, self.is_vector_env + ) if not self.is_vector_env: rews = np.array([rews]) self.returns = self.returns * self.gamma + rews diff --git a/gym/wrappers/record_episode_statistics.py b/gym/wrappers/record_episode_statistics.py index f241e266f4f..41c061721bd 100644 --- a/gym/wrappers/record_episode_statistics.py +++ b/gym/wrappers/record_episode_statistics.py @@ -76,14 +76,15 @@ class RecordEpisodeStatistics(gym.Wrapper): length_queue: The lengths of the last ``deque_size``-many episodes """ - def __init__(self, env: gym.Env, deque_size: int = 100): + def __init__(self, env: gym.Env, deque_size: int = 100, new_step_api: bool = False): """This wrapper will keep track of cumulative rewards and episode lengths. Args: env (Env): The environment to apply the wrapper deque_size: The size of the buffers :attr:`return_queue` and :attr:`length_queue` + new_step_api (bool): Whether the wrapper's step method outputs two booleans (new API) or one boolean (old API) """ - super().__init__(env) + super().__init__(env, new_step_api) self.num_envs = getattr(env, "num_envs", 1) self.t0 = time.perf_counter() self.episode_count = 0 @@ -102,18 +103,26 @@ def reset(self, **kwargs): def step(self, action): """Steps through the environment, recording the episode statistics.""" - observations, rewards, dones, infos = super().step(action) + ( + observations, + rewards, + terminateds, + truncateds, + infos, + ) = step_api_compatibility(self.env.step(action), True, self.is_vector_env) assert isinstance( infos, dict ), f"`info` dtype is {type(infos)} while supported dtype is `dict`. This may be due to usage of other wrappers in the wrong order." self.episode_returns += rewards self.episode_lengths += 1 if not self.is_vector_env: - dones = [dones] - dones = list(dones) + terminateds = [terminateds] + truncateds = [truncateds] + terminateds = list(terminateds) + truncateds = list(truncateds) - for i in range(len(dones)): - if dones[i]: + for i in range(len(terminateds)): + if terminateds[i] or truncateds[i]: episode_return = self.episode_returns[i] episode_length = self.episode_lengths[i] episode_info = { @@ -134,9 +143,14 @@ def step(self, action): self.episode_count += 1 self.episode_returns[i] = 0 self.episode_lengths[i] = 0 - return ( - observations, - rewards, - dones if self.is_vector_env else dones[0], - infos, + return step_api_compatibility( + ( + observations, + rewards, + terminateds if self.is_vector_env else terminateds[0], + truncateds if self.is_vector_env else truncateds[0], + infos, + ), + self.new_step_api, + self.is_vector_env, ) diff --git a/gym/wrappers/record_video.py b/gym/wrappers/record_video.py index 9780c45e434..02f1417401d 100644 --- a/gym/wrappers/record_video.py +++ b/gym/wrappers/record_video.py @@ -58,8 +58,9 @@ def __init__( video_length (int): The length of recorded episodes. If 0, entire episodes are recorded. 
Otherwise, snippets of the specified length are captured name_prefix (str): Will be prepended to the filename of the recordings + new_step_api (bool): Whether the wrapper's step method outputs two booleans (new API) or one boolean (old API) """ - super().__init__(env) + super().__init__(env, new_step_api) if episode_trigger is None and step_trigger is None: episode_trigger = capped_cubic_video_schedule @@ -123,7 +124,13 @@ def _video_enabled(self): def step(self, action): """Steps through the environment using action, recording observations if :attr:`self.recording`.""" - observations, rewards, dones, infos = super().step(action) + ( + observations, + rewards, + terminateds, + truncateds, + infos, + ) = step_api_compatibility(self.env.step(action), True, self.is_vector_env) # increment steps and episodes self.step_id += 1 diff --git a/gym/wrappers/step_api_compatibility.py b/gym/wrappers/step_api_compatibility.py index ab3baf6e109..67759a99ad4 100644 --- a/gym/wrappers/step_api_compatibility.py +++ b/gym/wrappers/step_api_compatibility.py @@ -1,3 +1,4 @@ +"""Implementation of StepAPICompatibility wrapper class for transforming envs between new and old step API.""" import gym from gym.logger import deprecation from gym.utils.step_api_compatibility import step_to_new_api, step_to_old_api @@ -26,6 +27,12 @@ class StepAPICompatibility(gym.Wrapper): """ def __init__(self, env: gym.Env, new_step_api=False): + """A wrapper which can transform an environment from new step API to old and vice-versa. + + Args: + env (gym.Env): the env to wrap. Can be in old or new API + new_step_api (bool): Whether the wrapper's step method outputs two booleans (new API) or one boolean (old API) + """ super().__init__(env) self.new_step_api = new_step_api if not self.new_step_api: @@ -35,6 +42,14 @@ def __init__(self, env: gym.Env, new_step_api=False): ) def step(self, action): + """Steps through the environment, returning 5 or 4 items depending on `new_step_api`. + + Args: + action: action to step through the environment with + + Returns: + (observation, reward, terminated, truncated, info) or (observation, reward, done, info) + """ step_returns = self.env.step(action) if self.new_step_api: return step_to_new_api(step_returns) diff --git a/gym/wrappers/time_aware_observation.py b/gym/wrappers/time_aware_observation.py index 76f781345ff..2307eb06334 100644 --- a/gym/wrappers/time_aware_observation.py +++ b/gym/wrappers/time_aware_observation.py @@ -22,13 +22,14 @@ class TimeAwareObservation(gym.ObservationWrapper): array([ 0.03881167, -0.16021058, 0.0220928 , 0.28875574, 1. ]) """ - def __init__(self, env: gym.Env): + def __init__(self, env: gym.Env, new_step_api: bool = False): """Initialize :class:`TimeAwareObservation` that requires an environment with a flat :class:`Box` observation space. 
Args: env: The environment to apply the wrapper + new_step_api (bool): Whether the wrapper's step method outputs two booleans (new API) or one boolean (old API) """ - super().__init__(env) + super().__init__(env, new_step_api) assert isinstance(env.observation_space, Box) assert env.observation_space.dtype == np.float32 low = np.append(self.observation_space.low, 0.0) diff --git a/gym/wrappers/time_limit.py b/gym/wrappers/time_limit.py index 2637a97591c..92b41c58654 100644 --- a/gym/wrappers/time_limit.py +++ b/gym/wrappers/time_limit.py @@ -21,14 +21,20 @@ class TimeLimit(gym.Wrapper): >>> env = TimeLimit(env, max_episode_steps=1000) """ - def __init__(self, env: gym.Env, max_episode_steps: Optional[int] = None): + def __init__( + self, + env: gym.Env, + max_episode_steps: Optional[int] = None, + new_step_api: bool = False, + ): """Initializes the :class:`TimeLimit` wrapper with an environment and the number of steps after which truncation will occur. Args: env: The environment to apply the wrapper max_episode_steps: An optional max episode steps (if ``Ǹone``, ``env.spec.max_episode_steps`` is used) + new_step_api (bool): Whether the wrapper's step method outputs two booleans (new API) or one boolean (old API) """ - super().__init__(env) + super().__init__(env, new_step_api) if max_episode_steps is None and self.env.spec is not None: max_episode_steps = env.spec.max_episode_steps if self.env.spec is not None: @@ -47,7 +53,10 @@ def step(self, action): when truncated (the number of steps elapsed >= max episode steps) or "TimeLimit.truncated"=False if the environment terminated """ - observation, reward, done, info = self.env.step(action) + observation, reward, terminated, truncated, info = step_api_compatibility( + self.env.step(action), + True, + ) self._elapsed_steps += 1 if self._elapsed_steps >= self._max_episode_steps: diff --git a/pyproject.toml b/pyproject.toml index b886e94b071..560d92eec32 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -42,3 +42,6 @@ reportUntypedFunctionDecorator = "none" reportMissingTypeStubs = false reportUnboundVariable = "warning" reportGeneralTypeIssues ="none" + +[tool.pytest.ini_options] +filterwarnings = ['ignore:.*step API.*:DeprecationWarning'] # to be removed when old step API is removed diff --git a/tests/envs/test_action_dim_check.py b/tests/envs/test_action_dim_check.py index 6b5032258e1..f21da2a15b4 100644 --- a/tests/envs/test_action_dim_check.py +++ b/tests/envs/test_action_dim_check.py @@ -1,3 +1,4 @@ +import warnings from typing import List import numpy as np @@ -30,7 +31,11 @@ def filters_envs_action_space_type( """ filtered_envs = [] for spec in env_spec_list: - env = gym.make(spec.id) + with warnings.catch_warnings(): + warnings.filterwarnings( + "ignore", message=".*step API.*", category=DeprecationWarning + ) # since this function is outside scope of pytest warning suppression + env = gym.make(spec.id) if isinstance(env.action_space, action_space): filtered_envs.append(env) return filtered_envs diff --git a/tests/vector/test_vector_env.py b/tests/vector/test_vector_env.py index 697831ec6ff..ed1bdccb76d 100644 --- a/tests/vector/test_vector_env.py +++ b/tests/vector/test_vector_env.py @@ -36,10 +36,10 @@ def test_vector_env_equal(shared_memory): # fmt: on if any(sync_dones): - assert "terminal_observation" in async_infos - assert "_terminal_observation" in async_infos - assert "terminal_observation" in sync_infos - assert "_terminal_observation" in sync_infos + assert "closing_observation" in async_infos + assert 
"_closing_observation" in async_infos + assert "closing_observation" in sync_infos + assert "_closing_observation" in sync_infos assert np.all(async_observations == sync_observations) assert np.all(async_rewards == sync_rewards) diff --git a/tests/vector/test_vector_env_info.py b/tests/vector/test_vector_env_info.py index 3d79d88d743..eb5a2a3dd2e 100644 --- a/tests/vector/test_vector_env_info.py +++ b/tests/vector/test_vector_env_info.py @@ -20,18 +20,18 @@ def test_vector_env_info(asynchronous): action = env.action_space.sample() _, _, dones, infos = env.step(action) if any(dones): - assert len(infos["terminal_observation"]) == NUM_ENVS - assert len(infos["_terminal_observation"]) == NUM_ENVS + assert len(infos["closing_observation"]) == NUM_ENVS + assert len(infos["_closing_observation"]) == NUM_ENVS - assert isinstance(infos["terminal_observation"], np.ndarray) - assert isinstance(infos["_terminal_observation"], np.ndarray) + assert isinstance(infos["closing_observation"], np.ndarray) + assert isinstance(infos["_closing_observation"], np.ndarray) for i, done in enumerate(dones): if done: - assert infos["_terminal_observation"][i] + assert infos["_closing_observation"][i] else: - assert not infos["_terminal_observation"][i] - assert infos["terminal_observation"][i] is None + assert not infos["_closing_observation"][i] + assert infos["closing_observation"][i] is None @pytest.mark.parametrize("concurrent_ends", [1, 2, 3]) @@ -47,8 +47,8 @@ def test_vector_env_info_concurrent_termination(concurrent_ends): for i, done in enumerate(dones): if i < concurrent_ends: assert done - assert infos["_terminal_observation"][i] + assert infos["_closing_observation"][i] else: - assert not infos["_terminal_observation"][i] - assert infos["terminal_observation"][i] is None + assert not infos["_closing_observation"][i] + assert infos["closing_observation"][i] is None return diff --git a/tests/wrappers/test_vector_list_info.py b/tests/wrappers/test_vector_list_info.py index 1df737e1ed2..162dab7a0cf 100644 --- a/tests/wrappers/test_vector_list_info.py +++ b/tests/wrappers/test_vector_list_info.py @@ -32,9 +32,9 @@ def test_info_to_list(): _, _, dones, list_info = wrapped_env.step(action) for i, done in enumerate(dones): if done: - assert "terminal_observation" in list_info[i] + assert "closing_observation" in list_info[i] else: - assert "terminal_observation" not in list_info[i] + assert "closing_observation" not in list_info[i] def test_info_to_list_statistics(): From 63d3d1957856e1362e65b1b2261e0c5c8a146319 Mon Sep 17 00:00:00 2001 From: arjun_kg Date: Sat, 28 May 2022 22:04:03 +0530 Subject: [PATCH 16/37] update wrapper class to use step compatibility --- gym/core.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/gym/core.py b/gym/core.py index 4f31d0eae5e..b805a74425b 100644 --- a/gym/core.py +++ b/gym/core.py @@ -332,7 +332,11 @@ def step( Tuple[ObsType, float, bool, bool, dict], Tuple[ObsType, float, bool, dict] ]: """Steps through the environment with action.""" - return self.env.step(action) # ! Does this take self.new_step_api into account? 
+ from gym.utils.step_api_compatibility import ( # avoid circular import + step_api_compatibility, + ) + + return step_api_compatibility(self.env.step(action), self.new_step_api) def reset(self, **kwargs) -> Union[ObsType, Tuple[ObsType, dict]]: """Resets the environment with kwargs.""" From 9ce03cbdcecfffc26ff32119098a0b3003aacbd8 Mon Sep 17 00:00:00 2001 From: arjun_kg Date: Thu, 2 Jun 2022 06:29:28 +0530 Subject: [PATCH 17/37] add warning for play --- gym/utils/play.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/gym/utils/play.py b/gym/utils/play.py index 42263d41f41..d7d888de955 100644 --- a/gym/utils/play.py +++ b/gym/utils/play.py @@ -202,6 +202,11 @@ def play( seed: Random seed used when resetting the environment. If None, no seed is used. noop: The action used when no key input has been entered, or the entered key combination is unknown. """ + + deprecation( + "`play.py` currently supports only the old step API which returns one boolean, however this will soon be updated to support only the new step api that returns two bools." + ) + env.reset(seed=seed) key_code_to_action = {} From f93295f6606ef07f0e9b8f13a58bd8643b428cc7 Mon Sep 17 00:00:00 2001 From: arjun_kg Date: Thu, 2 Jun 2022 06:30:18 +0530 Subject: [PATCH 18/37] add todo --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 560d92eec32..0d9aaf0a472 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -44,4 +44,4 @@ reportUnboundVariable = "warning" reportGeneralTypeIssues ="none" [tool.pytest.ini_options] -filterwarnings = ['ignore:.*step API.*:DeprecationWarning'] # to be removed when old step API is removed +filterwarnings = ['ignore:.*step API.*:DeprecationWarning'] # TODO: to be removed when old step API is removed From 19404943c698ce28a9f4952745e1557e6d387fe9 Mon Sep 17 00:00:00 2001 From: arjun_kg Date: Thu, 2 Jun 2022 06:31:57 +0530 Subject: [PATCH 19/37] replace 'closing' with 'final' --- gym/vector/async_vector_env.py | 4 ++-- gym/vector/sync_vector_env.py | 2 +- gym/wrappers/autoreset.py | 26 ++++++++++++------------- tests/vector/test_vector_env.py | 8 ++++---- tests/vector/test_vector_env_info.py | 20 +++++++++---------- tests/wrappers/test_autoreset.py | 4 ++-- tests/wrappers/test_vector_list_info.py | 4 ++-- 7 files changed, 34 insertions(+), 34 deletions(-) diff --git a/gym/vector/async_vector_env.py b/gym/vector/async_vector_env.py index aaffe62ae69..29610bfa5d7 100644 --- a/gym/vector/async_vector_env.py +++ b/gym/vector/async_vector_env.py @@ -621,7 +621,7 @@ def _worker(index, env_fn, pipe, parent_pipe, shared_memory, error_queue): info, ) = step_api_compatibility(env.step(data), True) if terminated or truncated: - info["closing_observation"] = observation + info["final_observation"] = observation observation = env.reset() pipe.send(((observation, reward, terminated, truncated, info), True)) elif command == "seed": @@ -696,7 +696,7 @@ def _worker_shared_memory(index, env_fn, pipe, parent_pipe, shared_memory, error info, ) = step_api_compatibility(env.step(data), True) if terminated or truncated: - info["closing_observation"] = observation + info["final_observation"] = observation observation = env.reset() write_to_shared_memory( observation_space, index, observation, shared_memory diff --git a/gym/vector/sync_vector_env.py b/gym/vector/sync_vector_env.py index 25264a45bd1..714ced9e351 100644 --- a/gym/vector/sync_vector_env.py +++ b/gym/vector/sync_vector_env.py @@ -163,7 +163,7 @@ def step_wait(self): info, ) = 
step_api_compatibility(env.step(action), True) if self._terminateds[i] or self._truncateds[i]: - info["closing_observation"] = observation + info["final_observation"] = observation observation = env.reset() observations.append(observation) infos = self._add_info(infos, info, i) diff --git a/gym/wrappers/autoreset.py b/gym/wrappers/autoreset.py index 8f6d24db165..5c31d80223b 100644 --- a/gym/wrappers/autoreset.py +++ b/gym/wrappers/autoreset.py @@ -7,20 +7,20 @@ class AutoResetWrapper(gym.Wrapper): """A class for providing an automatic reset functionality for gym environments when calling :meth:`self.step`. When calling step causes :meth:`Env.step` to return `terminated=True` or `truncated=True`, :meth:`Env.reset` is called, - and the return format of :meth:`self.step` is as follows: ``(new_obs, closing_reward, closing_terminated, closing_truncated, info)`` - with new step API and ``(new_obs, closing_reward, closing_done, info)`` with the old step API. + and the return format of :meth:`self.step` is as follows: ``(new_obs, final_reward, final_terminated, final_truncated, info)`` + with new step API and ``(new_obs, final_reward, final_done, info)`` with the old step API. - ``new_obs`` is the first observation after calling :meth:`self.env.reset` - - ``closing_reward`` is the reward after calling :meth:`self.env.step`, prior to calling :meth:`self.env.reset`. - - ``closing_done`` is always True. In the new API, either ``closing_terminated`` or ``closing_truncated`` is True + - ``final_reward`` is the reward after calling :meth:`self.env.step`, prior to calling :meth:`self.env.reset`. + - ``final_done`` is always True. In the new API, either ``final_terminated`` or ``final_truncated`` is True - ``info`` is a dict containing all the keys from the info dict returned by the call to :meth:`self.env.reset`, - with an additional key "closing_observation" containing the observation returned by the last call to :meth:`self.env.step` - and "closing_info" containing the info dict returned by the last call to :meth:`self.env.step`. + with an additional key "final_observation" containing the observation returned by the last call to :meth:`self.env.step` + and "final_info" containing the info dict returned by the last call to :meth:`self.env.step`. Warning: When using this wrapper to collect rollouts, note that when :meth:`Env.step` returns done, a new observation from after calling :meth:`Env.reset` is returned by :meth:`Env.step` alongside the closing reward and done state from the previous episode. If you need the closing state from the previous episode, you need to retrieve it via the - "closing_observation" key in the info dict. + "final_observation" key in the info dict. Make sure you know what you're doing if you use this wrapper! 
""" @@ -50,14 +50,14 @@ def step(self, action): new_obs, new_info = self.env.reset(return_info=True) assert ( - "closing_observation" not in new_info - ), 'info dict cannot contain key "closing_observation" ' + "final_observation" not in new_info + ), 'info dict cannot contain key "final_observation" ' assert ( - "closing_info" not in new_info - ), 'info dict cannot contain key "closing_info" ' + "final_info" not in new_info + ), 'info dict cannot contain key "final_info" ' - new_info["closing_observation"] = obs - new_info["closing_info"] = info + new_info["final_observation"] = obs + new_info["final_info"] = info obs = new_obs info = new_info diff --git a/tests/vector/test_vector_env.py b/tests/vector/test_vector_env.py index ed1bdccb76d..748bc9b06ec 100644 --- a/tests/vector/test_vector_env.py +++ b/tests/vector/test_vector_env.py @@ -36,10 +36,10 @@ def test_vector_env_equal(shared_memory): # fmt: on if any(sync_dones): - assert "closing_observation" in async_infos - assert "_closing_observation" in async_infos - assert "closing_observation" in sync_infos - assert "_closing_observation" in sync_infos + assert "final_observation" in async_infos + assert "_final_observation" in async_infos + assert "final_observation" in sync_infos + assert "_final_observation" in sync_infos assert np.all(async_observations == sync_observations) assert np.all(async_rewards == sync_rewards) diff --git a/tests/vector/test_vector_env_info.py b/tests/vector/test_vector_env_info.py index eb5a2a3dd2e..9aa729f0365 100644 --- a/tests/vector/test_vector_env_info.py +++ b/tests/vector/test_vector_env_info.py @@ -20,18 +20,18 @@ def test_vector_env_info(asynchronous): action = env.action_space.sample() _, _, dones, infos = env.step(action) if any(dones): - assert len(infos["closing_observation"]) == NUM_ENVS - assert len(infos["_closing_observation"]) == NUM_ENVS + assert len(infos["final_observation"]) == NUM_ENVS + assert len(infos["_final_observation"]) == NUM_ENVS - assert isinstance(infos["closing_observation"], np.ndarray) - assert isinstance(infos["_closing_observation"], np.ndarray) + assert isinstance(infos["final_observation"], np.ndarray) + assert isinstance(infos["_final_observation"], np.ndarray) for i, done in enumerate(dones): if done: - assert infos["_closing_observation"][i] + assert infos["_final_observation"][i] else: - assert not infos["_closing_observation"][i] - assert infos["closing_observation"][i] is None + assert not infos["_final_observation"][i] + assert infos["final_observation"][i] is None @pytest.mark.parametrize("concurrent_ends", [1, 2, 3]) @@ -47,8 +47,8 @@ def test_vector_env_info_concurrent_termination(concurrent_ends): for i, done in enumerate(dones): if i < concurrent_ends: assert done - assert infos["_closing_observation"][i] + assert infos["_final_observation"][i] else: - assert not infos["_closing_observation"][i] - assert infos["closing_observation"][i] is None + assert not infos["_final_observation"][i] + assert infos["final_observation"][i] is None return diff --git a/tests/wrappers/test_autoreset.py b/tests/wrappers/test_autoreset.py index 25a7344f428..3515e1260c8 100644 --- a/tests/wrappers/test_autoreset.py +++ b/tests/wrappers/test_autoreset.py @@ -132,8 +132,8 @@ def test_autoreset_wrapper_autoreset(): assert reward == 1 assert info == { "count": 0, - "closing_observation": np.array([3]), - "closing_info": {"count": 3}, + "final_observation": np.array([3]), + "final_info": {"count": 3}, } obs, reward, done, info = env.step(action) diff --git 
a/tests/wrappers/test_vector_list_info.py b/tests/wrappers/test_vector_list_info.py index 162dab7a0cf..f8cd0f16c9f 100644 --- a/tests/wrappers/test_vector_list_info.py +++ b/tests/wrappers/test_vector_list_info.py @@ -32,9 +32,9 @@ def test_info_to_list(): _, _, dones, list_info = wrapped_env.step(action) for i, done in enumerate(dones): if done: - assert "closing_observation" in list_info[i] + assert "final_observation" in list_info[i] else: - assert "closing_observation" not in list_info[i] + assert "final_observation" not in list_info[i] def test_info_to_list_statistics(): From f12b5fbea7a0c4ec7a1555ecf75d432634bdae82 Mon Sep 17 00:00:00 2001 From: arjun_kg Date: Thu, 2 Jun 2022 06:37:01 +0530 Subject: [PATCH 20/37] fix pre-commit --- gym/utils/play.py | 1 - 1 file changed, 1 deletion(-) diff --git a/gym/utils/play.py b/gym/utils/play.py index d7d888de955..039609960a1 100644 --- a/gym/utils/play.py +++ b/gym/utils/play.py @@ -202,7 +202,6 @@ def play( seed: Random seed used when resetting the environment. If None, no seed is used. noop: The action used when no key input has been entered, or the entered key combination is unknown. """ - deprecation( "`play.py` currently supports only the old step API which returns one boolean, however this will soon be updated to support only the new step api that returns two bools." ) From aa5a0710be6e8b6d7182549626557a345ea8c6ff Mon Sep 17 00:00:00 2001 From: arjun_kg Date: Fri, 3 Jun 2022 09:54:19 +0530 Subject: [PATCH 21/37] remove previously missed `done` references --- gym/core.py | 8 ++++---- gym/error.py | 2 +- gym/vector/async_vector_env.py | 2 +- gym/wrappers/atari_preprocessing.py | 2 +- gym/wrappers/autoreset.py | 8 ++++---- gym/wrappers/frame_stack.py | 2 +- gym/wrappers/record_video.py | 2 +- gym/wrappers/time_limit.py | 13 ++++++++----- gym/wrappers/vector_list_info.py | 16 ++++++++++++---- 9 files changed, 33 insertions(+), 22 deletions(-) diff --git a/gym/core.py b/gym/core.py index b805a74425b..7a9b64b79d8 100644 --- a/gym/core.py +++ b/gym/core.py @@ -86,11 +86,11 @@ def step( reward (float): The amount of reward returned as a result of taking the action. terminated (bool): whether the episode has ended due to reaching a terminal state intrinsic to the core environment, in which case further step() calls will return undefined results truncated (bool): whether the episode has ended due to a truncation, i.e., a timelimit outside the scope of the problem defined in the environment. - info (dictionary): A dictionary that may contain additional information regarding the reason for a ``done`` signal. - `info` contains auxiliary diagnostic information (helpful for debugging, learning, and logging). + info (dictionary): `info` contains auxiliary diagnostic information (helpful for debugging, learning, and logging). This might, for instance, contain: metrics that describe the agent's performance state, variables that are - hidden from observations, information that distinguishes truncation and termination or individual reward terms - that are combined to produce the total reward + hidden from observations, or individual reward terms that are combined to produce the total reward. + It also can contain information that distinguishes truncation and termination, however this is deprecated in favour + of returning two booleans, and will be removed in a future version. (deprecated) done (bool): A boolean value for if the episode has ended, in which case further :meth:`step` calls will return undefined results. 
diff --git a/gym/error.py b/gym/error.py index 8219c5d13f8..9a9b8899886 100644 --- a/gym/error.py +++ b/gym/error.py @@ -58,7 +58,7 @@ class ResetNeeded(Error): class ResetNotAllowed(Error): - """When the monitor is active, raised when the user tries to step an environment that's not yet done.""" + """When the monitor is active, raised when the user tries to step an environment that's not yet terminated or truncated.""" class InvalidAction(Error): diff --git a/gym/vector/async_vector_env.py b/gym/vector/async_vector_env.py index 29610bfa5d7..6af82c0ffe9 100644 --- a/gym/vector/async_vector_env.py +++ b/gym/vector/async_vector_env.py @@ -86,7 +86,7 @@ def __init__( the head process quits. However, ``daemon=True`` prevents subprocesses to spawn children, so for some environments you may want to have it set to ``False``. worker: If set, then use that worker in a subprocess instead of a default one. - Can be useful to override some inner vector env logic, for instance, how resets on done are handled. + Can be useful to override some inner vector env logic, for instance, how resets on termination or truncation are handled. new_step_api: If True, step method returns 2 bools - terminated, truncated, instead of 1 bool - done Warnings: worker is an advanced mode option. It provides a high degree of flexibility and a high chance diff --git a/gym/wrappers/atari_preprocessing.py b/gym/wrappers/atari_preprocessing.py index 794dc826de3..31245fcccde 100644 --- a/gym/wrappers/atari_preprocessing.py +++ b/gym/wrappers/atari_preprocessing.py @@ -48,7 +48,7 @@ def __init__( noop_max (int): For No-op reset, the max number no-ops actions are taken at reset, to turn off, set to 0. frame_skip (int): The number of frames between new observation the agents observations effecting the frequency at which the agent experiences the game. screen_size (int): resize Atari frame - terminal_on_life_loss (bool): `if True`, then :meth:`step()` returns `done=True` whenever a + terminal_on_life_loss (bool): `if True`, then :meth:`step()` returns `terminated=True` whenever a life is lost. grayscale_obs (bool): if True, then gray scale observation is returned, otherwise, RGB observation is returned. diff --git a/gym/wrappers/autoreset.py b/gym/wrappers/autoreset.py index 5c31d80223b..6e20c92ffed 100644 --- a/gym/wrappers/autoreset.py +++ b/gym/wrappers/autoreset.py @@ -16,10 +16,10 @@ class AutoResetWrapper(gym.Wrapper): with an additional key "final_observation" containing the observation returned by the last call to :meth:`self.env.step` and "final_info" containing the info dict returned by the last call to :meth:`self.env.step`. - Warning: When using this wrapper to collect rollouts, note that when :meth:`Env.step` returns done, a + Warning: When using this wrapper to collect rollouts, note that when :meth:`Env.step` returns `terminated` or `truncated`, a new observation from after calling :meth:`Env.reset` is returned by :meth:`Env.step` alongside the - closing reward and done state from the previous episode. - If you need the closing state from the previous episode, you need to retrieve it via the + final reward and done state from the previous episode. + If you need the final state from the previous episode, you need to retrieve it via the "final_observation" key in the info dict. Make sure you know what you're doing if you use this wrapper! 
""" @@ -34,7 +34,7 @@ def __init__(self, env: gym.Env, new_step_api: bool = False): super().__init__(env, new_step_api) def step(self, action): - """Steps through the environment with action and resets the environment if a done-signal is encountered. + """Steps through the environment with action and resets the environment if a terminated or truncated signal is encountered. Args: action: The action to take diff --git a/gym/wrappers/frame_stack.py b/gym/wrappers/frame_stack.py index a6ca7193f63..c7eb19aaacd 100644 --- a/gym/wrappers/frame_stack.py +++ b/gym/wrappers/frame_stack.py @@ -171,7 +171,7 @@ def step(self, action): action: The action to step through the environment with Returns: - Stacked observations, reward, done and information from the environment + Stacked observations, reward, terminated, truncated, and information from the environment """ observation, reward, terminated, truncated, info = step_api_compatibility( self.env.step(action), True diff --git a/gym/wrappers/record_video.py b/gym/wrappers/record_video.py index 02f1417401d..5c97fad6ff0 100644 --- a/gym/wrappers/record_video.py +++ b/gym/wrappers/record_video.py @@ -33,7 +33,7 @@ class RecordVideo(gym.Wrapper): They should be functions returning a boolean that indicates whether a recording should be started at the current episode or step, respectively. If neither :attr:`episode_trigger` nor ``step_trigger`` is passed, a default ``episode_trigger`` will be employed. - By default, the recording will be stopped once a `done` signal has been emitted by the environment. However, you can + By default, the recording will be stopped once a `terminated` or `truncated` signal has been emitted by the environment. However, you can also create recordings of fixed length (possibly spanning several episodes) by passing a strictly positive value for ``video_length``. """ diff --git a/gym/wrappers/time_limit.py b/gym/wrappers/time_limit.py index 92b41c58654..8e9f67f4ae9 100644 --- a/gym/wrappers/time_limit.py +++ b/gym/wrappers/time_limit.py @@ -6,13 +6,16 @@ class TimeLimit(gym.Wrapper): - """This wrapper will issue a `done` signal if a maximum number of timesteps is exceeded. + """This wrapper will issue a `truncated` signal if a maximum number of timesteps is exceeded. - Oftentimes, it is **very** important to distinguish `done` signals that were produced by the - :class:`TimeLimit` wrapper (truncations) and those that originate from the underlying environment (terminations). - This can be done by looking at the ``info`` that is returned when `done`-signal was issued. + If a truncation is not defined inside the environment itself, this is the only place that the truncation signal is issued. + Critically, this is different from the `terminated` signal that originates from the underlying environment as part of the MDP. + + (deprecated) + This information is passed through ``info`` that is returned when `done`-signal was issued. The done-signal originates from the time limit (i.e. it signifies a *truncation*) if and only if - the key `"TimeLimit.truncated"` exists in ``info`` and the corresponding value is ``True``. + the key `"TimeLimit.truncated"` exists in ``info`` and the corresponding value is ``True``. This will be removed in favour + of only issuing a `truncated` signal in future versions. 
Example: >>> from gym.envs.classic_control import CartPoleEnv diff --git a/gym/wrappers/vector_list_info.py b/gym/wrappers/vector_list_info.py index bf5cbc82933..874367f2811 100644 --- a/gym/wrappers/vector_list_info.py +++ b/gym/wrappers/vector_list_info.py @@ -3,6 +3,7 @@ from typing import List import gym +from gym.utils.step_api_compatibility import step_api_compatibility class VectorListInfo(gym.Wrapper): @@ -29,23 +30,30 @@ class VectorListInfo(gym.Wrapper): """ - def __init__(self, env): + def __init__(self, env, new_step_api=False): """This wrapper will convert the info into the list format. Args: env (Env): The environment to apply the wrapper + new_step_api (bool): Whether the wrapper's step method outputs two booleans (new API) or one boolean (old API) """ assert getattr( env, "is_vector_env", False ), "This wrapper can only be used in vectorized environments." - super().__init__(env) + super().__init__(env, new_step_api) def step(self, action): """Steps through the environment, convert dict info to list.""" - observation, reward, done, infos = self.env.step(action) + observation, reward, terminated, truncated, infos = step_api_compatibility( + self.env.step(action), True, True + ) list_info = self._convert_info_to_list(infos) - return observation, reward, done, list_info + return step_api_compatibility( + (observation, reward, terminated, truncated, list_info), + self.new_step_api, + True, + ) def reset(self, **kwargs): """Resets the environment using kwargs.""" From e135b9ee5f70a37eb8b8ccf8492366be16228485 Mon Sep 17 00:00:00 2001 From: arjun_kg Date: Fri, 3 Jun 2022 14:54:22 +0530 Subject: [PATCH 22/37] fix step compat in atari wrapper reset --- gym/wrappers/atari_preprocessing.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/gym/wrappers/atari_preprocessing.py b/gym/wrappers/atari_preprocessing.py index 31245fcccde..9015686191b 100644 --- a/gym/wrappers/atari_preprocessing.py +++ b/gym/wrappers/atari_preprocessing.py @@ -163,7 +163,9 @@ def reset(self, **kwargs): else 0 ) for _ in range(noops): - _, _, terminated, truncated, step_info = self.env.step(0) + _, _, terminated, truncated, step_info = step_api_compatibility( + self.env.step(0), True + ) reset_info.update(step_info) if terminated or truncated: if kwargs.get("return_info", False): From 1f11077ea84fabdfd7a800bad58881870cd6343d Mon Sep 17 00:00:00 2001 From: arjun_kg Date: Tue, 7 Jun 2022 22:46:23 +0530 Subject: [PATCH 23/37] fix tests with step returning np.bool_ --- gym/envs/mujoco/inverted_pendulum.py | 2 +- gym/envs/mujoco/inverted_pendulum_v4.py | 2 +- gym/envs/registration.py | 18 +++++++----------- gym/utils/env_checker.py | 1 - gym/utils/passive_env_checker.py | 12 +++++++++--- gym/vector/__init__.py | 11 +++++++++-- 6 files changed, 27 insertions(+), 19 deletions(-) diff --git a/gym/envs/mujoco/inverted_pendulum.py b/gym/envs/mujoco/inverted_pendulum.py index 850c3460aaa..ae02d89db67 100644 --- a/gym/envs/mujoco/inverted_pendulum.py +++ b/gym/envs/mujoco/inverted_pendulum.py @@ -15,7 +15,7 @@ def step(self, a): reward = 1.0 self.do_simulation(a, self.frame_skip) ob = self._get_obs() - terminated = not np.isfinite(ob).all() or (np.abs(ob[1]) > 0.2) + terminated = bool(not np.isfinite(ob).all() or (np.abs(ob[1]) > 0.2)) return ob, reward, terminated, False, {} def reset_model(self): diff --git a/gym/envs/mujoco/inverted_pendulum_v4.py b/gym/envs/mujoco/inverted_pendulum_v4.py index e26ce1d2446..c2b7b64b640 100644 --- a/gym/envs/mujoco/inverted_pendulum_v4.py +++ 
b/gym/envs/mujoco/inverted_pendulum_v4.py @@ -98,7 +98,7 @@ def step(self, a): reward = 1.0 self.do_simulation(a, self.frame_skip) ob = self._get_obs() - terminated = not np.isfinite(ob).all() or (np.abs(ob[1]) > 0.2) + terminated = bool(not np.isfinite(ob).all() or (np.abs(ob[1]) > 0.2)) return ob, reward, terminated, False, {} def reset_model(self): diff --git a/gym/envs/registration.py b/gym/envs/registration.py index ee22bee5e54..1883b8b7e4b 100644 --- a/gym/envs/registration.py +++ b/gym/envs/registration.py @@ -23,7 +23,12 @@ import numpy as np from gym.envs.__relocated__ import internal_env_relocation_map -from gym.wrappers import AutoResetWrapper, OrderEnforcing, StepAPICompatibility, TimeLimit +from gym.wrappers import ( + AutoResetWrapper, + OrderEnforcing, + StepAPICompatibility, + TimeLimit, +) from gym.wrappers.env_checker import PassiveEnvChecker if sys.version_info < (3, 10): @@ -602,7 +607,7 @@ def make( # Run the environment checker as the lowest level wrapper if disable_env_checker is False: env = PassiveEnvChecker(env) - + env = StepAPICompatibility(env, new_step_api) # Add the order enforcing wrapper @@ -619,15 +624,6 @@ def make( if autoreset: env = AutoResetWrapper(env, new_step_api) - if not disable_env_checker: - try: - check_env(env) - except Exception as e: - logger.warn( - f"Env check failed with the following message: {e}\n" - f"You can set `disable_env_checker=True` to disable this check." - ) - return env diff --git a/gym/utils/env_checker.py b/gym/utils/env_checker.py index 82419815a37..e8d7fbeb2bb 100644 --- a/gym/utils/env_checker.py +++ b/gym/utils/env_checker.py @@ -27,7 +27,6 @@ passive_env_reset_check, passive_env_step_check, ) -from gym.utils.step_api_compatibility import step_api_compatibility def data_equivalence(data_1, data_2) -> bool: diff --git a/gym/utils/passive_env_checker.py b/gym/utils/passive_env_checker.py index d49ea7faf8f..c921f98d942 100644 --- a/gym/utils/passive_env_checker.py +++ b/gym/utils/passive_env_checker.py @@ -255,12 +255,18 @@ def passive_env_step_check(env, action): if len(result) == 4: obs, reward, done, info = result - assert isinstance(done, bool), "The `done` signal must be a boolean" + assert isinstance( + done, bool + ), f"The `done` signal is of type `{type(done)}`. It must be a boolean" elif len(result) == 5: obs, reward, terminated, truncated, info = result - assert isinstance(terminated, bool), "The `terminated` signal must be a boolean" - assert isinstance(truncated, bool), "The `truncated` signal must be a boolean" + assert isinstance( + terminated, bool + ), f"The `terminated` signal is of type `{type(terminated)}`. It must be a boolean" + assert isinstance( + truncated, bool + ), f"The `truncated` signal is of type `{type(truncated)}`. It must be a boolean." assert ( terminated is False or truncated is False ), "Only `terminated` or `truncated` can be true, not both."
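The `bool(...)` casts added to the inverted pendulum environments above exist precisely to satisfy these asserts: NumPy comparisons return `np.bool_`, which is not a subclass of Python's `bool`. A minimal illustration (assuming only NumPy):

    import numpy as np

    ob = np.array([0.0, 0.3])
    terminated = np.abs(ob[1]) > 0.2           # np.bool_, not bool
    assert not isinstance(terminated, bool)    # would trip the checker
    assert isinstance(bool(terminated), bool)  # the cast used in the patch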
diff --git a/gym/vector/__init__.py b/gym/vector/__init__.py index 6a9c6218829..a0ea7df0cdb 100644 --- a/gym/vector/__init__.py +++ b/gym/vector/__init__.py @@ -48,7 +48,10 @@ def create_env(_disable_env_checker): def _make_env(): env = gym.envs.registration.make( - id, disable_env_checker=_disable_env_checker, new_step_api=True, **kwargs + id, + disable_env_checker=_disable_env_checker, + new_step_api=True, + **kwargs, ) if wrappers is not None: if callable(wrappers): @@ -68,4 +71,8 @@ def _make_env(): create_env(env_num == 0 and disable_env_checker is False) for env_num in range(num_envs) ] - return AsyncVectorEnv(env_fns, new_step_api=new_step_api) if asynchronous else SyncVectorEnv(env_fns, new_step_api=new_step_api) + return ( + AsyncVectorEnv(env_fns, new_step_api=new_step_api) + if asynchronous + else SyncVectorEnv(env_fns, new_step_api=new_step_api) + ) From e861fbc0b0b4de6f8e1de6038fcf06d412264263 Mon Sep 17 00:00:00 2001 From: arjun_kg Date: Tue, 7 Jun 2022 23:19:15 +0530 Subject: [PATCH 24/37] remove warning for using new api --- gym/utils/step_api_compatibility.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/gym/utils/step_api_compatibility.py b/gym/utils/step_api_compatibility.py index f0454153420..e47db04eb40 100644 --- a/gym/utils/step_api_compatibility.py +++ b/gym/utils/step_api_compatibility.py @@ -32,10 +32,6 @@ def step_to_new_api( is_vector_env (bool): Whether the step_returns are from a vector environment """ if len(step_returns) == 5: - deprecation( - "Using an environment with new step API that returns two bools terminated, truncated instead of one bool done. " - "Take care to supporting code to be compatible with this API" - ) return step_returns else: assert len(step_returns) == 4 From 8e56f459ac741acaf7f6a41532b9d22f1fedf5cb Mon Sep 17 00:00:00 2001 From: arjun_kg Date: Wed, 8 Jun 2022 09:56:33 +0530 Subject: [PATCH 25/37] pre-commit fixes --- gym/envs/classic_control/acrobot.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gym/envs/classic_control/acrobot.py b/gym/envs/classic_control/acrobot.py index 290757e0fcf..8152b1ac1c3 100644 --- a/gym/envs/classic_control/acrobot.py +++ b/gym/envs/classic_control/acrobot.py @@ -223,7 +223,7 @@ def step(self, a): self.state = ns terminated = self._terminal() reward = -1.0 if not terminated else 0.0 - + self.renderer.render_step() return (self._get_ob(), reward, terminated, False, {}) From 5e8f085beece46eec300e75e9ef7f7e4e0e40c9f Mon Sep 17 00:00:00 2001 From: arjun_kg Date: Mon, 20 Jun 2022 23:08:26 +0530 Subject: [PATCH 26/37] new API does not include 'TimeLimit.truncated' in info --- gym/envs/box2d/car_racing.py | 7 +++--- gym/utils/step_api_compatibility.py | 37 ++++++++++++++++++++++------- 2 files changed, 32 insertions(+), 12 deletions(-) diff --git a/gym/envs/box2d/car_racing.py b/gym/envs/box2d/car_racing.py index cdb7da7c307..db517d64dae 100644 --- a/gym/envs/box2d/car_racing.py +++ b/gym/envs/box2d/car_racing.py @@ -483,7 +483,7 @@ def step(self, action: Union[np.ndarray, int]): step_reward = 0 terminated = False - info = {} + truncated = False if action is not None: # First step without action, called from reset() self.reward -= 0.1 # We actually don't want to count fuel spent, we want car to be faster. 
@@ -492,18 +492,17 @@ def step(self, action: Union[np.ndarray, int]): step_reward = self.reward - self.prev_reward self.prev_reward = self.reward if self.tile_visited_count == len(self.track) or self.new_lap: - truncated = True # Truncation due to finishing lap # This should not be treated as a failure # but like a timeout - info["TimeLimit.truncated"] = True + truncated = True x, y = self.car.hull.position if abs(x) > PLAYFIELD or abs(y) > PLAYFIELD: terminated = True step_reward = -100 self.renderer.render_step() - return self.state, step_reward, terminated, truncated, info + return self.state, step_reward, terminated, truncated, {} def render(self, mode: str = "human"): if self.render_mode is not None: diff --git a/gym/utils/step_api_compatibility.py b/gym/utils/step_api_compatibility.py index e47db04eb40..6c830b642f0 100644 --- a/gym/utils/step_api_compatibility.py +++ b/gym/utils/step_api_compatibility.py @@ -50,33 +50,55 @@ def step_to_new_api( dones = [dones] for i in range(len(dones)): - if "TimeLimit.truncated" not in infos or ( + # For every condition, handling - info single env / info vector env (list) / info vector env (dict) + + # TimeLimit.truncated attribute not present - implies either terminated or episode still ongoing based on `done` + if (not is_vector_env and "TimeLimit.truncated" not in infos) or ( is_vector_env - and "TimeLimit.truncated" in infos - and not infos["_TimeLimit.truncated"][ - i - ] # if mask is False, it's the same as TimeLimit.truncated attribute not being present + and ( + ( + isinstance(infos, list) + and "TimeLimit.truncated" not in infos[i] + ) # vector env, list info api + or ( + "TimeLimit.truncated" in infos + and not infos["_TimeLimit.truncated"][i] + ) # vector env, dict info api, if mask is False, it's the same as TimeLimit.truncated attribute not being present for env 'i' + ) ): + terminateds.append(dones[i]) truncateds.append(False) + + # This means info["TimeLimit.truncated"] exists and is True, which means the truncation has occurred but termination has not. elif ( infos["TimeLimit.truncated"] if not is_vector_env - else ( # handle vector info as both dict and list + else ( infos["TimeLimit.truncated"][i] if isinstance(infos, dict) else infos[i]["TimeLimit.truncated"] ) ): + assert dones[i] is True terminateds.append(False) truncateds.append(True) else: # This means info["TimeLimit.truncated"] exists but is False, which means the core environment had already terminated, # but it also exceeded maximum timesteps at the same step. 
- + assert dones[i] is True terminateds.append(True) truncateds.append(True) + # removing "TimeLimit.truncated" from info + if isinstance(infos, list): + infos[i].pop(["TimeLimit.truncated"], None) + + # if info dict vector, can only pop after all envs are processed (also for single env) + if isinstance(infos, dict): + infos.pop("TimeLimit.truncated", None) + infos.pop("TimeLimit.truncated_", None) + return ( observations, rewards, @@ -118,7 +140,6 @@ def step_to_old_api( for i in range(n_envs): dones.append(terminateds[i] or truncateds[i]) - # to be consistent with old API if truncateds[i]: if is_vector_env: # handle vector infos for dict and list From cdb35160d5f867f492e24f5dbc0f9d709cb220e2 Mon Sep 17 00:00:00 2001 From: arjun_kg Date: Mon, 20 Jun 2022 23:58:49 +0530 Subject: [PATCH 27/37] fix checks, tests --- gym/envs/box2d/lunar_lander.py | 3 ++- gym/utils/step_api_compatibility.py | 13 ++++++++----- tests/envs/test_action_dim_check.py | 8 ++++++-- 3 files changed, 16 insertions(+), 8 deletions(-) diff --git a/gym/envs/box2d/lunar_lander.py b/gym/envs/box2d/lunar_lander.py index 90b51678dfd..f5ceed335a6 100644 --- a/gym/envs/box2d/lunar_lander.py +++ b/gym/envs/box2d/lunar_lander.py @@ -11,6 +11,7 @@ from gym.error import DependencyNotInstalled from gym.utils import EzPickle, colorize from gym.utils.renderer import Renderer +from gym.utils.step_api_compatibility import step_api_compatibility try: import Box2D @@ -762,7 +763,7 @@ def demo_heuristic_lander(env, seed=None, render=False): s = env.reset(seed=seed) while True: a = heuristic(env, s) - s, r, terminated, truncated, info = env.step(a) + s, r, terminated, truncated, info = step_api_compatibility(env.step(a), True) total_reward += r if render: diff --git a/gym/utils/step_api_compatibility.py b/gym/utils/step_api_compatibility.py index 6c830b642f0..8877df6a206 100644 --- a/gym/utils/step_api_compatibility.py +++ b/gym/utils/step_api_compatibility.py @@ -61,8 +61,11 @@ def step_to_new_api( and "TimeLimit.truncated" not in infos[i] ) # vector env, list info api or ( - "TimeLimit.truncated" in infos - and not infos["_TimeLimit.truncated"][i] + "TimeLimit.truncated" not in infos + or ( + "TimeLimit.truncated" in infos + and not infos["_TimeLimit.truncated"][i] + ) ) # vector env, dict info api, if mask is False, it's the same as TimeLimit.truncated attribute not being present for env 'i' ) ): @@ -70,7 +73,7 @@ def step_to_new_api( terminateds.append(dones[i]) truncateds.append(False) - # This means info["TimeLimit.truncated"] exists and is True, which means the truncation has occurred but termination has not. + # This means info["TimeLimit.truncated"] exists and this elif checks if it is True, which means the truncation has occurred but termination has not. elif ( infos["TimeLimit.truncated"] if not is_vector_env @@ -80,13 +83,13 @@ def step_to_new_api( else infos[i]["TimeLimit.truncated"] ) ): - assert dones[i] is True + assert dones[i] terminateds.append(False) truncateds.append(True) else: # This means info["TimeLimit.truncated"] exists but is False, which means the core environment had already terminated, # but it also exceeded maximum timesteps at the same step. 
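# Taken together, the branches above implement the following mapping (an
# illustrative summary, assuming "TimeLimit.truncated" is only ever written
# by the TimeLimit wrapper):
#   done=False                                    -> terminated=False, truncated=False
#   done=True,  "TimeLimit.truncated" not in info -> terminated=True,  truncated=False
#   done=True,  info["TimeLimit.truncated"]=True  -> terminated=False, truncated=True
#   done=True,  info["TimeLimit.truncated"]=False -> terminated=True,  truncated=True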
- assert dones[i] is True + assert dones[i] terminateds.append(True) truncateds.append(True) diff --git a/tests/envs/test_action_dim_check.py b/tests/envs/test_action_dim_check.py index 61b3ca2eb0d..af857567795 100644 --- a/tests/envs/test_action_dim_check.py +++ b/tests/envs/test_action_dim_check.py @@ -112,7 +112,9 @@ def test_box_actions_out_of_bound(env: gym.Env): zip(env.action_space.bounded_above, env.action_space.bounded_below) ): if is_upper_bound: - obs, _, _, _ = env.step(upper_bounds) + obs, _, _, _, _ = env.step( + upper_bounds + ) # `env` is unwrapped, and in new step API oob_action = upper_bounds.copy() oob_action[i] += np.cast[dtype](OOB_VALUE) @@ -122,7 +124,9 @@ def test_box_actions_out_of_bound(env: gym.Env): assert np.alltrue(obs == oob_obs) if is_lower_bound: - obs, _, _, _ = env.step(lower_bounds) + obs, _, _, _, _ = env.step( + lower_bounds + ) # `env` is unwrapped, and in new step API oob_action = lower_bounds.copy() oob_action[i] -= np.cast[dtype](OOB_VALUE) From 8cc2074574084eb24129af51263c9895318bdba5 Mon Sep 17 00:00:00 2001 From: arjun_kg Date: Tue, 21 Jun 2022 00:15:48 +0530 Subject: [PATCH 28/37] vector info mask - fix wrong underscore --- gym/utils/step_api_compatibility.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gym/utils/step_api_compatibility.py b/gym/utils/step_api_compatibility.py index 8877df6a206..8f86890f77d 100644 --- a/gym/utils/step_api_compatibility.py +++ b/gym/utils/step_api_compatibility.py @@ -100,7 +100,7 @@ def step_to_new_api( # if info dict vector, can only pop after all envs are processed (also for single env) if isinstance(infos, dict): infos.pop("TimeLimit.truncated", None) - infos.pop("TimeLimit.truncated_", None) + infos.pop("_TimeLimit.truncated", None) return ( observations, From 2f83d55246fad4e5b5528b64d83cb87f14a00d1e Mon Sep 17 00:00:00 2001 From: arjun_kg Date: Tue, 21 Jun 2022 21:09:52 +0530 Subject: [PATCH 29/37] dont remove from info --- gym/utils/step_api_compatibility.py | 21 +++++++++------------ 1 file changed, 9 insertions(+), 12 deletions(-) diff --git a/gym/utils/step_api_compatibility.py b/gym/utils/step_api_compatibility.py index 8f86890f77d..5fa9c1acd66 100644 --- a/gym/utils/step_api_compatibility.py +++ b/gym/utils/step_api_compatibility.py @@ -93,15 +93,6 @@ def step_to_new_api( terminateds.append(True) truncateds.append(True) - # removing "TimeLimit.truncated" from info - if isinstance(infos, list): - infos[i].pop(["TimeLimit.truncated"], None) - - # if info dict vector, can only pop after all envs are processed (also for single env) - if isinstance(infos, dict): - infos.pop("TimeLimit.truncated", None) - infos.pop("_TimeLimit.truncated", None) - return ( observations, rewards, @@ -152,13 +143,19 @@ def step_to_old_api( infos["TimeLimit.truncated"] = np.zeros(n_envs, dtype=bool) infos["_TimeLimit.truncated"] = np.zeros(n_envs, dtype=bool) - infos["TimeLimit.truncated"][i] = not terminateds[i] + infos["TimeLimit.truncated"][i] = ( + not terminateds[i] or infos["TimeLimit.truncated"][i] + ) infos["_TimeLimit.truncated"][i] = True else: # if vector info is a list - infos[i]["TimeLimit.truncated"] = not terminateds[i] + infos[i]["TimeLimit.truncated"] = not terminateds[i] or infos[ + i + ].get("TimeLimit.truncated", False) else: - infos["TimeLimit.truncated"] = not terminateds[i] + infos["TimeLimit.truncated"] = not terminateds[i] or infos.get( + "TimeLimit.truncated", False + ) return ( observations, rewards, From b1660cf71005851b80e77b23d4b824a7987bcdd3 Mon Sep 17 00:00:00 2001 From: 
arjun_kg Date: Tue, 21 Jun 2022 21:43:10 +0530 Subject: [PATCH 30/37] edit definitions --- gym/core.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/gym/core.py b/gym/core.py index 54feb643095..f18d7bd47dd 100644 --- a/gym/core.py +++ b/gym/core.py @@ -135,8 +135,11 @@ def step( observation (object): this will be an element of the environment's :attr:`observation_space`. This may, for instance, be a numpy array containing the positions and velocities of certain objects. reward (float): The amount of reward returned as a result of taking the action. - terminated (bool): whether the episode has ended due to reaching a terminal state intrinsic to the core environment, in which case further step() calls will return undefined results - truncated (bool): whether the episode has ended due to a truncation, i.e., a timelimit outside the scope of the problem defined in the environment. + terminated (bool): whether a `terminal state` (as defined under the MDP of the task) is reached. + In this case further step() calls could return undefined results. + truncated (bool): whether a truncation condition outside the scope of the MDP is satisfied. + Typically a time limit, but could also be used to indicate the agent physically going out of bounds. + Can be used to end the episode prematurely before a `terminal state` is reached. info (dictionary): `info` contains auxiliary diagnostic information (helpful for debugging, learning, and logging). This might, for instance, contain: metrics that describe the agent's performance state, variables that are hidden from observations, or individual reward terms that are combined to produce the total reward. From ea10e7a1dcfa772edc1a310f5eaaa283d6dd791c Mon Sep 17 00:00:00 2001 From: arjun_kg Date: Tue, 21 Jun 2022 21:48:53 +0530 Subject: [PATCH 31/37] remove whitespaces :/ --- gym/core.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/gym/core.py b/gym/core.py index f18d7bd47dd..1b472665409 100644 --- a/gym/core.py +++ b/gym/core.py @@ -135,10 +135,10 @@ def step( observation (object): this will be an element of the environment's :attr:`observation_space`. This may, for instance, be a numpy array containing the positions and velocities of certain objects. reward (float): The amount of reward returned as a result of taking the action. - terminated (bool): whether a `terminal state` (as defined under the MDP of the task) is reached. - In this case further step() calls could return undefined results. - truncated (bool): whether a truncation condition outside the scope of the MDP is satisfied. - Typically a time limit, but could also be used to indicate the agent physically going out of bounds. + terminated (bool): whether a `terminal state` (as defined under the MDP of the task) is reached. + In this case further step() calls could return undefined results. + truncated (bool): whether a truncation condition outside the scope of the MDP is satisfied. + Typically a time limit, but could also be used to indicate the agent physically going out of bounds. Can be used to end the episode prematurely before a `terminal state` is reached. info (dictionary): `info` contains auxiliary diagnostic information (helpful for debugging, learning, and logging).
This might, for instance, contain: metrics that describe the agent's performance state, variables that are From d7dff2c3afd4ac7bb52caac25e04ae1085ef3181 Mon Sep 17 00:00:00 2001 From: Arjun KG Date: Sun, 3 Jul 2022 16:24:50 +0530 Subject: [PATCH 32/37] update tests --- tests/envs/test_env_implementation.py | 2 +- tests/envs/test_make.py | 12 +++++++++++- 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/tests/envs/test_env_implementation.py b/tests/envs/test_env_implementation.py index 29043d0782e..c0b3884e23c 100644 --- a/tests/envs/test_env_implementation.py +++ b/tests/envs/test_env_implementation.py @@ -141,4 +141,4 @@ def test_taxi_encode_decode(): assert ( env.encode(*env.decode(state)) == state ), f"state={state}, encode(decode(state))={env.encode(*env.decode(state))}" - state, _, _, _ = env.step(env.action_space.sample()) + state, _, _, _, _ = env.step(env.action_space.sample()) diff --git a/tests/envs/test_make.py b/tests/envs/test_make.py index 4d45fe276d0..1eeeda90756 100644 --- a/tests/envs/test_make.py +++ b/tests/envs/test_make.py @@ -13,6 +13,14 @@ from tests.envs.utils_envs import ArgumentEnv, RegisterDuringMakeEnv from tests.wrappers.utils import has_wrapper +IGNORE_WARNINGS = [ + f"\x1b[33mWARN: {message}\x1b[0m" + for message in [ + "Initializing environment in old step API which returns one bool instead of two. It is recommended to set `new_step_api=True` to use new step API. This will be the default behaviour in future." + ] +] + + gym.register( "RegisterDuringMakeEnv-v0", entry_point="tests.envs.utils_envs:RegisterDuringMakeEnv", @@ -172,7 +180,9 @@ def test_make_render_mode(): assert env.render_mode == valid_render_modes[0] env.close() - assert len(warnings) == 0 + for warning in warnings.list: + if warning.message.args[0] not in IGNORE_WARNINGS: + raise gym.error.Error(f"Unexpected warning: {warning.message}") # Make sure that native rendering is used when possible env = gym.make("CartPole-v1", render_mode="human", disable_env_checker=True) From b2c10a4d5b279a1e9a35edcec3fabc465456755b Mon Sep 17 00:00:00 2001 From: Arjun KG Date: Sun, 3 Jul 2022 17:13:19 +0530 Subject: [PATCH 33/37] fix pattern --- tests/envs/test_make.py | 10 +--------- 1 file changed, 1 insertion(+), 9 deletions(-) diff --git a/tests/envs/test_make.py b/tests/envs/test_make.py index 1eeeda90756..b58636641d5 100644 --- a/tests/envs/test_make.py +++ b/tests/envs/test_make.py @@ -13,14 +13,6 @@ from tests.envs.utils_envs import ArgumentEnv, RegisterDuringMakeEnv from tests.wrappers.utils import has_wrapper -IGNORE_WARNINGS = [ - f"\x1b[33mWARN: {message}\x1b[0m" - for message in [ - "Initializing environment in old step API which returns one bool instead of two. It is recommended to set `new_step_api=True` to use new step API. This will be the default behaviour in future." 
- ] -] - - gym.register( "RegisterDuringMakeEnv-v0", entry_point="tests.envs.utils_envs:RegisterDuringMakeEnv", @@ -181,7 +173,7 @@ def test_make_render_mode(): env.close() for warning in warnings.list: - if warning.message.args[0] not in IGNORE_WARNINGS: + if not re.compile(".*step API.*").match(warning.message.args[0]): raise gym.error.Error(f"Unexpected warning: {warning.message}") # Make sure that native rendering is used when possible From 6553bedfd0d2d2a78c3c1e97b156703fb184bf2c Mon Sep 17 00:00:00 2001 From: Arjun KG Date: Mon, 4 Jul 2022 21:45:38 +0530 Subject: [PATCH 34/37] restructure warnings --- gym/core.py | 6 ++++++ gym/utils/step_api_compatibility.py | 17 ----------------- gym/vector/vector_env.py | 5 +++++ 3 files changed, 11 insertions(+), 17 deletions(-) diff --git a/gym/core.py b/gym/core.py index 0ae8cf17fd6..2c351f832f3 100644 --- a/gym/core.py +++ b/gym/core.py @@ -321,6 +321,12 @@ def __init__(self, env: Env, new_step_api: bool = False): self._metadata: Optional[dict] = None self.new_step_api = new_step_api + if not self.new_step_api: + deprecation( + "Initializing wrapper in old step API which returns one bool instead of two. " + "It is recommended to set `new_step_api=True` to use new step API. This will be the default behaviour in future. " + ) + def __getattr__(self, name): """Returns an attribute with ``name``, unless ``name`` starts with an underscore.""" if name.startswith("_"): diff --git a/gym/utils/step_api_compatibility.py b/gym/utils/step_api_compatibility.py index 5fa9c1acd66..2be07dbe35c 100644 --- a/gym/utils/step_api_compatibility.py +++ b/gym/utils/step_api_compatibility.py @@ -4,7 +4,6 @@ import numpy as np from gym.core import ObsType -from gym.logger import deprecation OldStepType = Tuple[ Union[ObsType, np.ndarray], @@ -35,13 +34,6 @@ def step_to_new_api( return step_returns else: assert len(step_returns) == 4 - deprecation( - "Transforming code with old step API into new. " - "It is recommended to upgrade the core env to the new step API. This can also be done by setting `new_step_api=True` at make. " - "If 'TimeLimit.truncated' is set at truncation, terminated and truncated values will be accurate. " - "Otherwise, `terminated=done` and `truncated=False`" - ) - observations, rewards, dones, infos = step_returns terminateds = [] @@ -112,18 +104,9 @@ def step_to_old_api( is_vector_env (bool): Whether the step_returns are from a vector environment """ if len(step_returns) == 4: - deprecation( - "Using old step API which returns one boolean (done). Please upgrade to new API to return two booleans - terminated, truncated" - ) - return step_returns else: assert len(step_returns) == 5 - deprecation( - "Transforming code in new step API (which returns two booleans terminated, truncated) into old (returns one boolean done). " - "It is recommended to upgrade accompanying code to be compatible with the new API, and use the new API by setting `new_step_api=True`. " - ) - observations, rewards, terminateds, truncateds, infos = step_returns dones = [] if not is_vector_env: diff --git a/gym/vector/vector_env.py b/gym/vector/vector_env.py index 103cdfb5606..b2edc6334f9 100644 --- a/gym/vector/vector_env.py +++ b/gym/vector/vector_env.py @@ -52,6 +52,11 @@ def __init__( self.single_action_space = action_space self.new_step_api = new_step_api + if not self.new_step_api: + deprecation( + "Initializing vector env in old step API which returns one bool array instead of two. " + "It is recommended to set `new_step_api=True` to use new step API. 
This will be the default behaviour in future. " + ) def reset_async( self, From 50d367eaffb600096b4d3d3595a3d5eb38f1a25a Mon Sep 17 00:00:00 2001 From: Arjun KG Date: Mon, 4 Jul 2022 22:26:26 +0530 Subject: [PATCH 35/37] fix incorrect warning --- gym/wrappers/clip_action.py | 2 +- gym/wrappers/env_checker.py | 2 +- gym/wrappers/filter_observation.py | 2 +- gym/wrappers/flatten_observation.py | 2 +- gym/wrappers/gray_scale_observation.py | 2 +- gym/wrappers/human_rendering.py | 2 +- gym/wrappers/order_enforcing.py | 2 +- gym/wrappers/pixel_observation.py | 2 +- gym/wrappers/rescale_action.py | 2 +- gym/wrappers/resize_observation.py | 2 +- gym/wrappers/transform_observation.py | 2 +- gym/wrappers/transform_reward.py | 2 +- 12 files changed, 12 insertions(+), 12 deletions(-) diff --git a/gym/wrappers/clip_action.py b/gym/wrappers/clip_action.py index de236384768..58d981e96a1 100644 --- a/gym/wrappers/clip_action.py +++ b/gym/wrappers/clip_action.py @@ -26,7 +26,7 @@ def __init__(self, env: gym.Env): env: The environment to apply the wrapper """ assert isinstance(env.action_space, Box) - super().__init__(env) + super().__init__(env, new_step_api=True) def action(self, action): """Clips the action within the valid bounds. diff --git a/gym/wrappers/env_checker.py b/gym/wrappers/env_checker.py index 4fe0c011c96..785cd69f40d 100644 --- a/gym/wrappers/env_checker.py +++ b/gym/wrappers/env_checker.py @@ -15,7 +15,7 @@ class PassiveEnvChecker(gym.Wrapper): def __init__(self, env): """Initialises the wrapper with the environments, run the observation and action space tests.""" - super().__init__(env) + super().__init__(env, new_step_api=True) assert hasattr( env, "action_space" diff --git a/gym/wrappers/filter_observation.py b/gym/wrappers/filter_observation.py index 922c8288038..bcbe13b5065 100644 --- a/gym/wrappers/filter_observation.py +++ b/gym/wrappers/filter_observation.py @@ -35,7 +35,7 @@ def __init__(self, env: gym.Env, filter_keys: Sequence[str] = None): ValueError: If the environment's observation space is not :class:`spaces.Dict` ValueError: If any of the `filter_keys` are not included in the original `env`'s observation space """ - super().__init__(env) + super().__init__(env, new_step_api=True) wrapped_observation_space = env.observation_space if not isinstance(wrapped_observation_space, spaces.Dict): diff --git a/gym/wrappers/flatten_observation.py b/gym/wrappers/flatten_observation.py index fe6518b875b..95aa13e0d01 100644 --- a/gym/wrappers/flatten_observation.py +++ b/gym/wrappers/flatten_observation.py @@ -25,7 +25,7 @@ def __init__(self, env: gym.Env): Args: env: The environment to apply the wrapper """ - super().__init__(env) + super().__init__(env, new_step_api=True) self.observation_space = spaces.flatten_space(env.observation_space) def observation(self, observation): diff --git a/gym/wrappers/gray_scale_observation.py b/gym/wrappers/gray_scale_observation.py index 565b5f32840..ca2a2c84323 100644 --- a/gym/wrappers/gray_scale_observation.py +++ b/gym/wrappers/gray_scale_observation.py @@ -28,7 +28,7 @@ def __init__(self, env: gym.Env, keep_dim: bool = False): keep_dim (bool): If `True`, a singleton dimension will be added, i.e. observations are of the shape AxBx1. Otherwise, they are of shape AxB. 
""" - super().__init__(env) + super().__init__(env, new_step_api=True) self.keep_dim = keep_dim assert ( diff --git a/gym/wrappers/human_rendering.py b/gym/wrappers/human_rendering.py index 34b9d829c14..5a4f568ddc8 100644 --- a/gym/wrappers/human_rendering.py +++ b/gym/wrappers/human_rendering.py @@ -45,7 +45,7 @@ def __init__(self, env): Args: env: The environment that is being wrapped """ - super().__init__(env) + super().__init__(env, new_step_api=True) assert env.render_mode in [ "single_rgb_array", "rgb_array", diff --git a/gym/wrappers/order_enforcing.py b/gym/wrappers/order_enforcing.py index 79cfb12abd7..2515a3311c4 100644 --- a/gym/wrappers/order_enforcing.py +++ b/gym/wrappers/order_enforcing.py @@ -26,7 +26,7 @@ def __init__(self, env: gym.Env, disable_render_order_enforcing: bool = False): env: The environment to wrap disable_render_order_enforcing: If to disable render order enforcing """ - super().__init__(env) + super().__init__(env, new_step_api=True) self._has_reset: bool = False self._disable_render_order_enforcing: bool = disable_render_order_enforcing diff --git a/gym/wrappers/pixel_observation.py b/gym/wrappers/pixel_observation.py index 0cc6d72ed37..a16eed2bc75 100644 --- a/gym/wrappers/pixel_observation.py +++ b/gym/wrappers/pixel_observation.py @@ -77,7 +77,7 @@ def __init__( specified ``pixel_keys``. TypeError: When an unexpected pixel type is used """ - super().__init__(env) + super().__init__(env, new_step_api=True) # Avoid side-effects that occur when render_kwargs is manipulated render_kwargs = copy.deepcopy(render_kwargs) diff --git a/gym/wrappers/rescale_action.py b/gym/wrappers/rescale_action.py index bf3cf6cd157..c5f2238159e 100644 --- a/gym/wrappers/rescale_action.py +++ b/gym/wrappers/rescale_action.py @@ -45,7 +45,7 @@ def __init__( ), f"expected Box action space, got {type(env.action_space)}" assert np.less_equal(min_action, max_action).all(), (min_action, max_action) - super().__init__(env) + super().__init__(env, new_step_api=True) self.min_action = ( np.zeros(env.action_space.shape, dtype=env.action_space.dtype) + min_action ) diff --git a/gym/wrappers/resize_observation.py b/gym/wrappers/resize_observation.py index 8ed7cc91092..7fbff4e8baa 100644 --- a/gym/wrappers/resize_observation.py +++ b/gym/wrappers/resize_observation.py @@ -32,7 +32,7 @@ def __init__(self, env: gym.Env, shape: Union[tuple, int]): env: The environment to apply the wrapper shape: The shape of the resized observations """ - super().__init__(env) + super().__init__(env, new_step_api=True) if isinstance(shape, int): shape = (shape, shape) assert all(x > 0 for x in shape), shape diff --git a/gym/wrappers/transform_observation.py b/gym/wrappers/transform_observation.py index 2af2e9afb40..4da9db5bac9 100644 --- a/gym/wrappers/transform_observation.py +++ b/gym/wrappers/transform_observation.py @@ -27,7 +27,7 @@ def __init__(self, env: gym.Env, f: Callable[[Any], Any]): env: The environment to apply the wrapper f: A function that transforms the observation """ - super().__init__(env) + super().__init__(env, new_step_api=True) assert callable(f) self.f = f diff --git a/gym/wrappers/transform_reward.py b/gym/wrappers/transform_reward.py index a17a8ef1bc0..13278182d6b 100644 --- a/gym/wrappers/transform_reward.py +++ b/gym/wrappers/transform_reward.py @@ -28,7 +28,7 @@ def __init__(self, env: gym.Env, f: Callable[[float], float]): env: The environment to apply the wrapper f: A function that transforms the reward """ - super().__init__(env) + super().__init__(env, new_step_api=True) 
assert callable(f) self.f = f From d71836fb23fc8f722ac8a4b677f311f4407d995e Mon Sep 17 00:00:00 2001 From: Arjun KG Date: Mon, 4 Jul 2022 22:35:48 +0530 Subject: [PATCH 36/37] fix incorrect warnings (properly) --- gym/wrappers/step_api_compatibility.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gym/wrappers/step_api_compatibility.py b/gym/wrappers/step_api_compatibility.py index 67759a99ad4..6c081b67be9 100644 --- a/gym/wrappers/step_api_compatibility.py +++ b/gym/wrappers/step_api_compatibility.py @@ -33,7 +33,7 @@ def __init__(self, env: gym.Env, new_step_api=False): env (gym.Env): the env to wrap. Can be in old or new API new_step_api (bool): Whether the wrapper's step method outputs two booleans (new API) or one boolean (old API) """ - super().__init__(env) + super().__init__(env, new_step_api) self.new_step_api = new_step_api if not self.new_step_api: deprecation( From a74762504e8f9e31061ca03e8335540c654d5e1a Mon Sep 17 00:00:00 2001 From: Arjun KG Date: Tue, 5 Jul 2022 12:00:49 +0530 Subject: [PATCH 37/37] add warning to env checker --- gym/utils/passive_env_checker.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/gym/utils/passive_env_checker.py b/gym/utils/passive_env_checker.py index 4088bdc501d..67e9cb10548 100644 --- a/gym/utils/passive_env_checker.py +++ b/gym/utils/passive_env_checker.py @@ -4,6 +4,7 @@ import numpy as np from gym import error, logger, spaces +from gym.logger import deprecation def _check_box_observation_space(observation_space: spaces.Box): @@ -253,6 +254,10 @@ def passive_env_step_check(env, action): """A passive check for the environment step, investigating the returning data then returning the data unchanged.""" result = env.step(action) if len(result) == 4: + deprecation( + "Core environment is written in old step API which returns one bool instead of two. " + "It is recommended to rewrite the environment with new step API. " + ) obs, reward, done, info = result assert isinstance(