diff --git a/gym/core.py b/gym/core.py index aea74e6ad0f..cb8a45715e3 100644 --- a/gym/core.py +++ b/gym/core.py @@ -130,11 +130,16 @@ def np_random(self) -> RandomNumberGenerator: def np_random(self, value: RandomNumberGenerator): self._np_random = value - def step(self, action: ActType) -> Tuple[ObsType, float, bool, dict]: + def step( + self, action: ActType + ) -> Union[ + Tuple[ObsType, float, bool, bool, dict], Tuple[ObsType, float, bool, dict] + ]: """Run one timestep of the environment's dynamics. When end of episode is reached, you are responsible for calling :meth:`reset` to reset this environment's state. - Accepts an action and returns a tuple `(observation, reward, done, info)`. + Accepts an action and returns either a tuple `(observation, reward, terminated, truncated, info)`, or a tuple + (observation, reward, done, info). The latter is deprecated and will be removed in future versions. Args: action (ActType): an action provided by the agent @@ -143,14 +148,21 @@ def step(self, action: ActType) -> Tuple[ObsType, float, bool, dict]: observation (object): this will be an element of the environment's :attr:`observation_space`. This may, for instance, be a numpy array containing the positions and velocities of certain objects. reward (float): The amount of reward returned as a result of taking the action. + terminated (bool): whether a `terminal state` (as defined under the MDP of the task) is reached. + In this case further step() calls could return undefined results. + truncated (bool): whether a truncation condition outside the scope of the MDP is satisfied. + Typically a timelimit, but could also be used to indicate agent physically going out of bounds. + Can be used to end the episode prematurely before a `terminal state` is reached. + info (dictionary): `info` contains auxiliary diagnostic information (helpful for debugging, learning, and logging). + This might, for instance, contain: metrics that describe the agent's performance state, variables that are + hidden from observations, or individual reward terms that are combined to produce the total reward. + It also can contain information that distinguishes truncation and termination, however this is deprecated in favour + of returning two booleans, and will be removed in a future version. + + (deprecated) done (bool): A boolean value for if the episode has ended, in which case further :meth:`step` calls will return undefined results. A done signal may be emitted for different reasons: Maybe the task underlying the environment was solved successfully, a certain timelimit was exceeded, or the physics simulation has entered an invalid state. - info (dictionary): A dictionary that may contain additional information regarding the reason for a ``done`` signal. - `info` contains auxiliary diagnostic information (helpful for debugging, learning, and logging). - This might, for instance, contain: metrics that describe the agent's performance state, variables that are - hidden from observations, information that distinguishes truncation and termination or individual reward terms - that are combined to produce the total reward """ raise NotImplementedError @@ -298,11 +310,12 @@ class Wrapper(Env[ObsType, ActType]): Don't forget to call ``super().__init__(env)`` if the subclass overrides :meth:`__init__`. """ - def __init__(self, env: Env): + def __init__(self, env: Env, new_step_api: bool = False): """Wraps an environment to allow a modular transformation of the :meth:`step` and :meth:`reset` methods. 
Args: env: The environment to wrap + new_step_api: Whether the wrapper's step method will output in new or old step API """ self.env = env @@ -310,6 +323,13 @@ def __init__(self, env: Env): self._observation_space: Optional[spaces.Space] = None self._reward_range: Optional[Tuple[SupportsFloat, SupportsFloat]] = None self._metadata: Optional[dict] = None + self.new_step_api = new_step_api + + if not self.new_step_api: + deprecation( + "Initializing wrapper in old step API which returns one bool instead of two. " + "It is recommended to set `new_step_api=True` to use new step API. This will be the default behaviour in future. " + ) def __getattr__(self, name): """Returns an attribute with ``name``, unless ``name`` starts with an underscore.""" @@ -391,9 +411,17 @@ def _np_random(self): "Can't access `_np_random` of a wrapper, use `.unwrapped._np_random` or `.np_random`." ) - def step(self, action: ActType) -> Tuple[ObsType, float, bool, dict]: + def step( + self, action: ActType + ) -> Union[ + Tuple[ObsType, float, bool, bool, dict], Tuple[ObsType, float, bool, dict] + ]: """Steps through the environment with action.""" - return self.env.step(action) + from gym.utils.step_api_compatibility import ( # avoid circular import + step_api_compatibility, + ) + + return step_api_compatibility(self.env.step(action), self.new_step_api) def reset(self, **kwargs) -> Union[ObsType, Tuple[ObsType, dict]]: """Resets the environment with kwargs.""" @@ -463,8 +491,13 @@ def reset(self, **kwargs): def step(self, action): """Returns a modified observation using :meth:`self.observation` after calling :meth:`env.step`.""" - observation, reward, done, info = self.env.step(action) - return self.observation(observation), reward, done, info + step_returns = self.env.step(action) + if len(step_returns) == 5: + observation, reward, terminated, truncated, info = step_returns + return self.observation(observation), reward, terminated, truncated, info + else: + observation, reward, done, info = step_returns + return self.observation(observation), reward, done, info def observation(self, observation): """Returns a modified observation.""" @@ -497,8 +530,13 @@ def reward(self, reward): def step(self, action): """Modifies the reward using :meth:`self.reward` after the environment :meth:`env.step`.""" - observation, reward, done, info = self.env.step(action) - return observation, self.reward(reward), done, info + step_returns = self.env.step(action) + if len(step_returns) == 5: + observation, reward, terminated, truncated, info = step_returns + return observation, self.reward(reward), terminated, truncated, info + else: + observation, reward, done, info = step_returns + return observation, self.reward(reward), done, info def reward(self, reward): """Returns a modified ``reward``.""" diff --git a/gym/envs/box2d/bipedal_walker.py b/gym/envs/box2d/bipedal_walker.py index 61230b86b94..207d1fe7f06 100644 --- a/gym/envs/box2d/bipedal_walker.py +++ b/gym/envs/box2d/bipedal_walker.py @@ -599,15 +599,14 @@ def step(self, action: np.ndarray): reward -= 0.00035 * MOTORS_TORQUE * np.clip(np.abs(a), 0, 1) # normalized to about -50.0 using heuristic, more optimal agent should spend less - done = False + terminated = False if self.game_over or pos[0] < 0: reward = -100 - done = True + terminated = True if pos[0] > (TERRAIN_LENGTH - TERRAIN_GRASS) * TERRAIN_STEP: - done = True - + terminated = True self.renderer.render_step() - return np.array(state, dtype=np.float32), reward, done, {} + return np.array(state, dtype=np.float32), reward, 
terminated, False, {} def render(self, mode: str = "human"): if self.render_mode is not None: @@ -789,9 +788,9 @@ def __init__(self): SUPPORT_KNEE_ANGLE = +0.1 supporting_knee_angle = SUPPORT_KNEE_ANGLE while True: - s, r, done, info = env.step(a) + s, r, terminated, truncated, info = env.step(a) total_reward += r - if steps % 20 == 0 or done: + if steps % 20 == 0 or terminated or truncated: print("\naction " + str([f"{x:+0.2f}" for x in a])) print(f"step {steps} total_reward {total_reward:+0.2f}") print("hull " + str([f"{x:+0.2f}" for x in s[0:4]])) @@ -854,5 +853,5 @@ def __init__(self): a[3] = knee_todo[1] a = np.clip(0.5 * a, -1.0, 1.0) - if done: + if terminated or truncated: break diff --git a/gym/envs/box2d/car_racing.py b/gym/envs/box2d/car_racing.py index ad7455d4788..710f4215bb5 100644 --- a/gym/envs/box2d/car_racing.py +++ b/gym/envs/box2d/car_racing.py @@ -526,8 +526,8 @@ def step(self, action: Union[np.ndarray, int]): self.state = self._render("single_state_pixels") step_reward = 0 - done = False - info = {} + terminated = False + truncated = False if action is not None: # First step without action, called from reset() self.reward -= 0.1 # We actually don't want to count fuel spent, we want car to be faster. @@ -536,18 +536,17 @@ def step(self, action: Union[np.ndarray, int]): step_reward = self.reward - self.prev_reward self.prev_reward = self.reward if self.tile_visited_count == len(self.track) or self.new_lap: - done = True - # Termination due to finishing lap + # Truncation due to finishing lap # This should not be treated as a failure # but like a timeout - info["TimeLimit.truncated"] = True + truncated = True x, y = self.car.hull.position if abs(x) > PLAYFIELD or abs(y) > PLAYFIELD: - done = True + terminated = True step_reward = -100 self.renderer.render_step() - return self.state, step_reward, done, info + return self.state, step_reward, terminated, truncated, {} def render(self, mode: str = "human"): if self.render_mode is not None: @@ -811,13 +810,13 @@ def register_input(): restart = False while True: register_input() - s, r, done, info = env.step(a) + s, r, terminated, truncated, info = env.step(a) total_reward += r - if steps % 200 == 0 or done: + if steps % 200 == 0 or terminated or truncated: print("\naction " + str([f"{x:+0.2f}" for x in a])) print(f"step {steps} total_reward {total_reward:+0.2f}") steps += 1 isopen = env.render() - if done or restart or isopen is False: + if terminated or truncated or restart or isopen is False: break env.close() diff --git a/gym/envs/box2d/lunar_lander.py b/gym/envs/box2d/lunar_lander.py index b0022d8bfbb..e4ab2ec348e 100644 --- a/gym/envs/box2d/lunar_lander.py +++ b/gym/envs/box2d/lunar_lander.py @@ -11,6 +11,7 @@ from gym.error import DependencyNotInstalled from gym.utils import EzPickle, colorize from gym.utils.renderer import Renderer +from gym.utils.step_api_compatibility import step_api_compatibility try: import Box2D @@ -577,16 +578,15 @@ def step(self, action): ) # less fuel spent is better, about -30 for heuristic landing reward -= s_power * 0.03 - done = False + terminated = False if self.game_over or abs(state[0]) >= 1.0: - done = True + terminated = True reward = -100 if not self.lander.awake: - done = True + terminated = True reward = +100 - self.renderer.render_step() - return np.array(state, dtype=np.float32), reward, done, {} + return np.array(state, dtype=np.float32), reward, terminated, False, {} def render(self, mode="human"): if self.render_mode is not None: @@ -771,7 +771,7 @@ def 
demo_heuristic_lander(env, seed=None, render=False): s = env.reset(seed=seed) while True: a = heuristic(env, s) - s, r, done, info = env.step(a) + s, r, terminated, truncated, info = step_api_compatibility(env.step(a), True) total_reward += r if render: @@ -779,11 +779,11 @@ def demo_heuristic_lander(env, seed=None, render=False): if still_open is False: break - if steps % 20 == 0 or done: + if steps % 20 == 0 or terminated or truncated: print("observations:", " ".join([f"{x:+0.2f}" for x in s])) print(f"step {steps} total_reward {total_reward:+0.2f}") steps += 1 - if done: + if terminated or truncated: break if render: env.close() diff --git a/gym/envs/classic_control/acrobot.py b/gym/envs/classic_control/acrobot.py index 3387a101cb4..60e16abfb5e 100644 --- a/gym/envs/classic_control/acrobot.py +++ b/gym/envs/classic_control/acrobot.py @@ -86,12 +86,12 @@ class AcrobotEnv(core.Env): Each parameter in the underlying state (`theta1`, `theta2`, and the two angular velocities) is initialized uniformly between -0.1 and 0.1. This means both links are pointing downwards with some initial stochasticity. - ### Episode Termination + ### Episode End - The episode terminates if one of the following occurs: - 1. The free end reaches the target height, which is constructed as: + The episode ends if one of the following occurs: + 1. Termination: The free end reaches the target height, which is constructed as: `-cos(theta1) - cos(theta2 + theta1) > 1.0` - 2. Episode length is greater than 500 (200 for v0) + 2. Truncation: Episode length is greater than 500 (200 for v0) ### Arguments @@ -226,11 +226,11 @@ def step(self, a): ns[2] = bound(ns[2], -self.MAX_VEL_1, self.MAX_VEL_1) ns[3] = bound(ns[3], -self.MAX_VEL_2, self.MAX_VEL_2) self.state = ns - terminal = self._terminal() - reward = -1.0 if not terminal else 0.0 + terminated = self._terminal() + reward = -1.0 if not terminated else 0.0 self.renderer.render_step() - return self._get_ob(), reward, terminal, {} + return (self._get_ob(), reward, terminated, False, {}) def _get_ob(self): s = self.state diff --git a/gym/envs/classic_control/cartpole.py b/gym/envs/classic_control/cartpole.py index 52302cae2cf..5b98f898df5 100644 --- a/gym/envs/classic_control/cartpole.py +++ b/gym/envs/classic_control/cartpole.py @@ -65,12 +65,13 @@ class CartPoleEnv(gym.Env[np.ndarray, Union[int, np.ndarray]]): All observations are assigned a uniformly random value in `(-0.05, 0.05)` - ### Episode Termination + ### Episode End - The episode terminates if any one of the following occurs: - 1. Pole Angle is greater than ±12° - 2. Cart Position is greater than ±2.4 (center of the cart reaches the edge of the display) - 3. Episode length is greater than 500 (200 for v0) + The episode ends if any one of the following occurs: + + 1. Termination: Pole Angle is greater than ±12° + 2. Termination: Cart Position is greater than ±2.4 (center of the cart reaches the edge of the display) + 3. 
Truncation: Episode length is greater than 500 (200 for v0) ### Arguments @@ -126,7 +127,7 @@ def __init__(self, render_mode: Optional[str] = None): self.isopen = True self.state = None - self.steps_beyond_done = None + self.steps_beyond_terminated = None def step(self, action): err_msg = f"{action!r} ({type(action)}) invalid" @@ -160,32 +161,32 @@ def step(self, action): self.state = (x, x_dot, theta, theta_dot) - done = bool( + terminated = bool( x < -self.x_threshold or x > self.x_threshold or theta < -self.theta_threshold_radians or theta > self.theta_threshold_radians ) - if not done: + if not terminated: reward = 1.0 - elif self.steps_beyond_done is None: + elif self.steps_beyond_terminated is None: # Pole just fell! - self.steps_beyond_done = 0 + self.steps_beyond_terminated = 0 reward = 1.0 else: - if self.steps_beyond_done == 0: + if self.steps_beyond_terminated == 0: logger.warn( "You are calling 'step()' even though this " - "environment has already returned done = True. You " - "should always call 'reset()' once you receive 'done = " + "environment has already returned terminated = True. You " + "should always call 'reset()' once you receive 'terminated = " "True' -- any further steps are undefined behavior." ) - self.steps_beyond_done += 1 + self.steps_beyond_terminated += 1 reward = 0.0 self.renderer.render_step() - return np.array(self.state, dtype=np.float32), reward, done, {} + return np.array(self.state, dtype=np.float32), reward, terminated, False, {} def reset( self, @@ -201,7 +202,7 @@ def reset( options, -0.05, 0.05 # default low ) # default high self.state = self.np_random.uniform(low=low, high=high, size=(4,)) - self.steps_beyond_done = None + self.steps_beyond_terminated = None self.renderer.reset() self.renderer.render_step() if not return_info: diff --git a/gym/envs/classic_control/continuous_mountain_car.py b/gym/envs/classic_control/continuous_mountain_car.py index a0eeb012a8d..5abf77e314e 100644 --- a/gym/envs/classic_control/continuous_mountain_car.py +++ b/gym/envs/classic_control/continuous_mountain_car.py @@ -84,11 +84,11 @@ class Continuous_MountainCarEnv(gym.Env): The position of the car is assigned a uniform random value in `[-0.6 , -0.4]`. The starting velocity of the car is always assigned to 0. - ### Episode Termination + ### Episode End - The episode terminates if either of the following happens: - 1. The position of the car is greater than or equal to 0.45 (the goal position on top of the right hill) - 2. The length of the episode is 999. + The episode ends if either of the following happens: + 1. Termination: The position of the car is greater than or equal to 0.45 (the goal position on top of the right hill) + 2. Truncation: The length of the episode is 999. ### Arguments @@ -161,17 +161,18 @@ def step(self, action: np.ndarray): velocity = 0 # Convert a possible numpy bool to a Python bool. 
- done = bool(position >= self.goal_position and velocity >= self.goal_velocity) + terminated = bool( + position >= self.goal_position and velocity >= self.goal_velocity + ) reward = 0 - if done: + if terminated: reward = 100.0 reward -= math.pow(action[0], 2) * 0.1 self.state = np.array([position, velocity], dtype=np.float32) - self.renderer.render_step() - return self.state, reward, done, {} + return self.state, reward, terminated, False, {} def reset( self, diff --git a/gym/envs/classic_control/mountain_car.py b/gym/envs/classic_control/mountain_car.py index d74cde99fa6..640186db7fb 100644 --- a/gym/envs/classic_control/mountain_car.py +++ b/gym/envs/classic_control/mountain_car.py @@ -78,11 +78,11 @@ class MountainCarEnv(gym.Env): The position of the car is assigned a uniform random value in *[-0.6 , -0.4]*. The starting velocity of the car is always assigned to 0. - ### Episode Termination + ### Episode End - The episode terminates if either of the following happens: - 1. The position of the car is greater than or equal to 0.5 (the goal position on top of the right hill) - 2. The length of the episode is 200. + The episode ends if either of the following happens: + 1. Termination: The position of the car is greater than or equal to 0.5 (the goal position on top of the right hill) + 2. Truncation: The length of the episode is 200. ### Arguments @@ -139,13 +139,14 @@ def step(self, action: int): if position == self.min_position and velocity < 0: velocity = 0 - done = bool(position >= self.goal_position and velocity >= self.goal_velocity) + terminated = bool( + position >= self.goal_position and velocity >= self.goal_velocity + ) reward = -1.0 self.state = (position, velocity) - self.renderer.render_step() - return np.array(self.state, dtype=np.float32), reward, done, {} + return np.array(self.state, dtype=np.float32), reward, terminated, False, {} def reset( self, diff --git a/gym/envs/classic_control/pendulum.py b/gym/envs/classic_control/pendulum.py index 927bd9e921c..73bbcbcca03 100644 --- a/gym/envs/classic_control/pendulum.py +++ b/gym/envs/classic_control/pendulum.py @@ -68,9 +68,9 @@ class PendulumEnv(gym.Env): The starting state is a random angle in *[-pi, pi]* and a random angular velocity in *[-1,1]*. - ### Episode Termination + ### Episode Truncation - The episode terminates at 200 time steps. + The episode truncates at 200 time steps. 
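Under the new step API an episode ends when either of the two booleans is set, so a rollout loop checks both. A minimal sketch, assuming an environment created through `gym.make` with the `new_step_api=True` flag added later in this diff (`Pendulum-v1` is only an illustrative id; here `terminated` never fires and the loop ends on the 200-step truncation described above):

```python
import gym

# Minimal rollout sketch under the new five-tuple step API.
env = gym.make("Pendulum-v1", new_step_api=True)

obs = env.reset(seed=0)
episode_return = 0.0
while True:
    action = env.action_space.sample()  # stand-in for a real policy
    obs, reward, terminated, truncated, info = env.step(action)
    episode_return += reward
    if terminated or truncated:  # either flag ends the episode
        break
env.close()
print(f"episode return: {episode_return:.2f}")
```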
### Arguments @@ -136,7 +136,7 @@ def step(self, u): self.state = np.array([newth, newthdot]) self.renderer.render_step() - return self._get_obs(), -costs, False, {} + return self._get_obs(), -costs, False, False, {} def reset( self, diff --git a/gym/envs/mujoco/ant.py b/gym/envs/mujoco/ant.py index 39f13c7793d..1a981633d87 100644 --- a/gym/envs/mujoco/ant.py +++ b/gym/envs/mujoco/ant.py @@ -41,13 +41,16 @@ def step(self, a): survive_reward = 1.0 reward = forward_reward - ctrl_cost - contact_cost + survive_reward state = self.state_vector() - notdone = np.isfinite(state).all() and state[2] >= 0.2 and state[2] <= 1.0 - done = not notdone + not_terminated = ( + np.isfinite(state).all() and state[2] >= 0.2 and state[2] <= 1.0 + ) + terminated = not not_terminated ob = self._get_obs() return ( ob, reward, - done, + terminated, + False, dict( reward_forward=forward_reward, reward_ctrl=-ctrl_cost, diff --git a/gym/envs/mujoco/ant_v3.py b/gym/envs/mujoco/ant_v3.py index 8bd6ba1e401..4a6fa6bd334 100644 --- a/gym/envs/mujoco/ant_v3.py +++ b/gym/envs/mujoco/ant_v3.py @@ -97,9 +97,9 @@ def is_healthy(self): return is_healthy @property - def done(self): - done = not self.is_healthy if self._terminate_when_unhealthy else False - return done + def terminated(self): + terminated = not self.is_healthy if self._terminate_when_unhealthy else False + return terminated def step(self, action): xy_position_before = self.get_body_com("torso")[:2].copy() @@ -121,7 +121,7 @@ def step(self, action): self.renderer.render_step() reward = rewards - costs - done = self.done + terminated = self.terminated observation = self._get_obs() info = { "reward_forward": forward_reward, @@ -136,7 +136,7 @@ def step(self, action): "forward_reward": forward_reward, } - return observation, reward, done, info + return observation, reward, terminated, False, info def _get_obs(self): position = self.sim.data.qpos.flat.copy() diff --git a/gym/envs/mujoco/ant_v4.py b/gym/envs/mujoco/ant_v4.py index 645a2f7d884..232272ab7ef 100644 --- a/gym/envs/mujoco/ant_v4.py +++ b/gym/envs/mujoco/ant_v4.py @@ -124,19 +124,19 @@ class AntEnv(MujocoEnv, utils.EzPickle): to be slightly high, thereby indicating a standing up ant. The initial orientation is designed to make it face forward as well. - ### Episode Termination + ### Episode End The ant is said to be unhealthy if any of the following happens: 1. Any of the state space values is no longer finite 2. The z-coordinate of the torso is **not** in the closed interval given by `healthy_z_range` (defaults to [0.2, 1.0]) If `terminate_when_unhealthy=True` is passed during construction (which is the default), - the episode terminates when any of the following happens: + the episode ends when any of the following happens: - 1. The episode duration reaches a 1000 timesteps - 2. The ant is unhealthy + 1. Truncation: The episode duration reaches a 1000 timesteps + 2. Termination: The ant is unhealthy - If `terminate_when_unhealthy=False` is passed, the episode is terminated only when 1000 timesteps are exceeded. + If `terminate_when_unhealthy=False` is passed, the episode is ended only when 1000 timesteps are exceeded.
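The reason the docstrings now label each condition as termination or truncation is that learning code treats the two differently: a value target should not bootstrap past a true terminal state, but should keep bootstrapping through a truncation such as the time limit. A hedged, illustrative sketch of that distinction (the `value_fn` and `gamma` names are placeholders, not part of Gym):

```python
def td_target(reward, next_obs, terminated, truncated, value_fn, gamma=0.99):
    """Illustrative one-step TD target using the two episode-end flags."""
    if terminated:
        # Terminal state of the MDP: its value is zero, do not bootstrap.
        return reward
    # Still running, or merely truncated (e.g. by a time limit): bootstrap as usual.
    return reward + gamma * value_fn(next_obs)
```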
### Arguments @@ -263,9 +263,9 @@ def is_healthy(self): return is_healthy @property - def done(self): - done = not self.is_healthy if self._terminate_when_unhealthy else False - return done + def terminated(self): + terminated = not self.is_healthy if self._terminate_when_unhealthy else False + return terminated def step(self, action): xy_position_before = self.get_body_com("torso")[:2].copy() @@ -282,7 +282,7 @@ def step(self, action): costs = ctrl_cost = self.control_cost(action) - done = self.done + terminated = self.terminated observation = self._get_obs() info = { "reward_forward": forward_reward, @@ -303,7 +303,7 @@ def step(self, action): reward = rewards - costs self.renderer.render_step() - return observation, reward, done, info + return observation, reward, terminated, False, info def _get_obs(self): position = self.data.qpos.flat.copy() diff --git a/gym/envs/mujoco/half_cheetah.py b/gym/envs/mujoco/half_cheetah.py index 727dba84d2f..069b4d146dd 100644 --- a/gym/envs/mujoco/half_cheetah.py +++ b/gym/envs/mujoco/half_cheetah.py @@ -35,8 +35,14 @@ def step(self, action): reward_ctrl = -0.1 * np.square(action).sum() reward_run = (xposafter - xposbefore) / self.dt reward = reward_ctrl + reward_run - done = False - return ob, reward, done, dict(reward_run=reward_run, reward_ctrl=reward_ctrl) + terminated = False + return ( + ob, + reward, + terminated, + False, + dict(reward_run=reward_run, reward_ctrl=reward_ctrl), + ) def _get_obs(self): return np.concatenate( diff --git a/gym/envs/mujoco/half_cheetah_v3.py b/gym/envs/mujoco/half_cheetah_v3.py index addf2dac051..07d0d74c8c8 100644 --- a/gym/envs/mujoco/half_cheetah_v3.py +++ b/gym/envs/mujoco/half_cheetah_v3.py @@ -75,7 +75,7 @@ def step(self, action): observation = self._get_obs() reward = forward_reward - ctrl_cost - done = False + terminated = False info = { "x_position": x_position_after, "x_velocity": x_velocity, @@ -83,7 +83,7 @@ def step(self, action): "reward_ctrl": -ctrl_cost, } - return observation, reward, done, info + return observation, reward, terminated, False, info def _get_obs(self): position = self.sim.data.qpos.flat.copy() diff --git a/gym/envs/mujoco/half_cheetah_v4.py b/gym/envs/mujoco/half_cheetah_v4.py index f949b5ed9de..7ecf6de6268 100644 --- a/gym/envs/mujoco/half_cheetah_v4.py +++ b/gym/envs/mujoco/half_cheetah_v4.py @@ -98,8 +98,8 @@ class HalfCheetahEnv(MujocoEnv, utils.EzPickle): normal noise with a mean of 0 and standard deviation of `reset_noise_scale` is added to the initial velocity values of all zeros. - ### Episode Termination - The episode terminates when the episode length is greater than 1000. + ### Episode End + The episode truncates when the episode length is greater than 1000. 
### Arguments @@ -192,7 +192,7 @@ def step(self, action): observation = self._get_obs() reward = forward_reward - ctrl_cost - done = False + terminated = False info = { "x_position": x_position_after, "x_velocity": x_velocity, @@ -201,7 +201,7 @@ def step(self, action): } self.renderer.render_step() - return observation, reward, done, info + return observation, reward, terminated, False, info def _get_obs(self): position = self.data.qpos.flat.copy() diff --git a/gym/envs/mujoco/hopper.py b/gym/envs/mujoco/hopper.py index 77f689f2ede..e0b9fa59cdf 100644 --- a/gym/envs/mujoco/hopper.py +++ b/gym/envs/mujoco/hopper.py @@ -36,14 +36,14 @@ def step(self, a): reward += alive_bonus reward -= 1e-3 * np.square(a).sum() s = self.state_vector() - done = not ( + terminated = not ( np.isfinite(s).all() and (np.abs(s[2:]) < 100).all() and (height > 0.7) and (abs(ang) < 0.2) ) ob = self._get_obs() - return ob, reward, done, {} + return ob, reward, terminated, False, {} def _get_obs(self): return np.concatenate( diff --git a/gym/envs/mujoco/hopper_v3.py b/gym/envs/mujoco/hopper_v3.py index 1e118c95d08..c3db1a9d3e5 100644 --- a/gym/envs/mujoco/hopper_v3.py +++ b/gym/envs/mujoco/hopper_v3.py @@ -101,9 +101,9 @@ def is_healthy(self): return is_healthy @property - def done(self): - done = not self.is_healthy if self._terminate_when_unhealthy else False - return done + def terminated(self): + terminated = not self.is_healthy if self._terminate_when_unhealthy else False + return terminated def _get_obs(self): position = self.sim.data.qpos.flat.copy() @@ -133,13 +133,13 @@ def step(self, action): observation = self._get_obs() reward = rewards - costs - done = self.done + terminated = self.terminated info = { "x_position": x_position_after, "x_velocity": x_velocity, } - return observation, reward, done, info + return observation, reward, terminated, False, info def reset_model(self): noise_low = -self._reset_noise_scale diff --git a/gym/envs/mujoco/hopper_v4.py b/gym/envs/mujoco/hopper_v4.py index 407d4622755..3d2acb6e1c6 100644 --- a/gym/envs/mujoco/hopper_v4.py +++ b/gym/envs/mujoco/hopper_v4.py @@ -87,7 +87,7 @@ class HopperEnv(MujocoEnv, utils.EzPickle): (0.0, 1.25, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0) with a uniform noise in the range of [-`reset_noise_scale`, `reset_noise_scale`] added to the values for stochasticity. - ### Episode Termination + ### Episode End The hopper is said to be unhealthy if any of the following happens: 1. An element of `observation[1:]` (if `exclude_current_positions_from_observation=True`, else `observation[2:]`) is no longer contained in the closed interval specified by the argument `healthy_state_range` @@ -95,12 +95,12 @@ class HopperEnv(MujocoEnv, utils.EzPickle): 3. The angle (`observation[1]` if `exclude_current_positions_from_observation=True`, else `observation[2]`) is no longer contained in the closed interval specified by the argument `healthy_angle_range` If `terminate_when_unhealthy=True` is passed during construction (which is the default), - the episode terminates when any of the following happens: + the episode ends when any of the following happens: - 1. The episode duration reaches a 1000 timesteps - 2. The hopper is unhealthy + 1. Truncation: The episode duration reaches a 1000 timesteps + 2. Termination: The hopper is unhealthy - If `terminate_when_unhealthy=False` is passed, the episode is terminated only when 1000 timesteps are exceeded. + If `terminate_when_unhealthy=False` is passed, the episode is ended only when 1000 timesteps are exceeded. 
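Because the `terminated` property shown below evaluates to `not self.is_healthy if self._terminate_when_unhealthy else False`, passing `terminate_when_unhealthy=False` means the flag can never fire and only the time-limit truncation ends an episode. A sketch, assuming the keyword is forwarded to the environment constructor through `gym.make`'s kwargs:

```python
import gym

# Sketch: with termination disabled, episodes in Hopper-v4 can only end by
# truncation (the 1000-step TimeLimit), never by the health check.
env = gym.make("Hopper-v4", terminate_when_unhealthy=False, new_step_api=True)

obs = env.reset(seed=0)
while True:
    obs, reward, terminated, truncated, info = env.step(env.action_space.sample())
    assert not terminated  # never set in this configuration
    if truncated:
        break
env.close()
```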
### Arguments @@ -223,9 +223,9 @@ def is_healthy(self): return is_healthy @property - def done(self): - done = not self.is_healthy if self._terminate_when_unhealthy else False - return done + def terminated(self): + terminated = not self.is_healthy if self._terminate_when_unhealthy else False + return terminated def _get_obs(self): position = self.data.qpos.flat.copy() @@ -253,14 +253,14 @@ def step(self, action): observation = self._get_obs() reward = rewards - costs - done = self.done + terminated = self.terminated info = { "x_position": x_position_after, "x_velocity": x_velocity, } self.renderer.render_step() - return observation, reward, done, info + return observation, reward, terminated, False, info def reset_model(self): noise_low = -self._reset_noise_scale diff --git a/gym/envs/mujoco/humanoid.py b/gym/envs/mujoco/humanoid.py index 0c98e17fba5..b67179ee858 100644 --- a/gym/envs/mujoco/humanoid.py +++ b/gym/envs/mujoco/humanoid.py @@ -60,11 +60,12 @@ def step(self, a): quad_impact_cost = min(quad_impact_cost, 10) reward = lin_vel_cost - quad_ctrl_cost - quad_impact_cost + alive_bonus qpos = self.sim.data.qpos - done = bool((qpos[2] < 1.0) or (qpos[2] > 2.0)) + terminated = bool((qpos[2] < 1.0) or (qpos[2] > 2.0)) return ( self._get_obs(), reward, - done, + terminated, + False, dict( reward_linvel=lin_vel_cost, reward_quadctrl=-quad_ctrl_cost, diff --git a/gym/envs/mujoco/humanoid_v3.py b/gym/envs/mujoco/humanoid_v3.py index 47271dc28c3..81c4b3d6f54 100644 --- a/gym/envs/mujoco/humanoid_v3.py +++ b/gym/envs/mujoco/humanoid_v3.py @@ -99,9 +99,9 @@ def is_healthy(self): return is_healthy @property - def done(self): - done = (not self.is_healthy) if self._terminate_when_unhealthy else False - return done + def terminated(self): + terminated = (not self.is_healthy) if self._terminate_when_unhealthy else False + return terminated def _get_obs(self): position = self.sim.data.qpos.flat.copy() @@ -148,7 +148,7 @@ def step(self, action): observation = self._get_obs() reward = rewards - costs - done = self.done + terminated = self.terminated info = { "reward_linvel": forward_reward, "reward_quadctrl": -ctrl_cost, @@ -162,7 +162,7 @@ def step(self, action): "forward_reward": forward_reward, } - return observation, reward, done, info + return observation, reward, terminated, False, info def reset_model(self): noise_low = -self._reset_noise_scale diff --git a/gym/envs/mujoco/humanoid_v4.py b/gym/envs/mujoco/humanoid_v4.py index 9c71bcdb9c3..46620d7740d 100644 --- a/gym/envs/mujoco/humanoid_v4.py +++ b/gym/envs/mujoco/humanoid_v4.py @@ -165,18 +165,17 @@ class HumanoidEnv(MujocoEnv, utils.EzPickle): selected to be high, thereby indicating a standing up humanoid. The initial orientation is designed to make it face forward as well. - ### Episode Termination + ### Episode End The humanoid is said to be unhealthy if the z-position of the torso is no longer contained in the closed interval specified by the argument `healthy_z_range`. If `terminate_when_unhealthy=True` is passed during construction (which is the default), - the episode terminates when any of the following happens: + the episode ends when any of the following happens: - 1. The episode duration reaches a 1000 timesteps - 3. The humanoid is unhealthy - - If `terminate_when_unhealthy=False` is passed, the episode is terminated only when 1000 timesteps are exceeded. + 1. Truncation: The episode duration reaches a 1000 timesteps + 3. 
Termination: The humanoid is unhealthy + If `terminate_when_unhealthy=False` is passed, the episode is ended only when 1000 timesteps are exceeded. ### Arguments @@ -281,9 +280,9 @@ def is_healthy(self): return is_healthy @property - def done(self): - done = (not self.is_healthy) if self._terminate_when_unhealthy else False - return done + def terminated(self): + terminated = (not self.is_healthy) if self._terminate_when_unhealthy else False + return terminated def _get_obs(self): position = self.data.qpos.flat.copy() @@ -326,7 +325,7 @@ def step(self, action): observation = self._get_obs() reward = rewards - ctrl_cost - done = self.done + terminated = self.terminated info = { "reward_linvel": forward_reward, "reward_quadctrl": -ctrl_cost, @@ -340,7 +339,7 @@ def step(self, action): } self.renderer.render_step() - return observation, reward, done, info + return observation, reward, terminated, False, info def reset_model(self): noise_low = -self._reset_noise_scale diff --git a/gym/envs/mujoco/humanoidstandup.py b/gym/envs/mujoco/humanoidstandup.py index e381bdab1c9..06ef153e95d 100644 --- a/gym/envs/mujoco/humanoidstandup.py +++ b/gym/envs/mujoco/humanoidstandup.py @@ -56,11 +56,11 @@ def step(self, a): self.renderer.render_step() - done = bool(False) return ( self._get_obs(), reward, - done, + False, + False, dict( reward_linup=uph_cost, reward_quadctrl=-quad_ctrl_cost, diff --git a/gym/envs/mujoco/humanoidstandup_v4.py b/gym/envs/mujoco/humanoidstandup_v4.py index 1bd52e4cf71..c22e7f63e51 100644 --- a/gym/envs/mujoco/humanoidstandup_v4.py +++ b/gym/envs/mujoco/humanoidstandup_v4.py @@ -151,11 +151,11 @@ class HumanoidStandupEnv(MujocoEnv, utils.EzPickle): to be low, thereby indicating a laying down humanoid. The initial orientation is designed to make it face forward as well. - ### Episode Termination - The episode terminates when any of the following happens: + ### Episode End + The episode ends when any of the following happens: - 1. The episode duration reaches a 1000 timesteps - 2. Any of the state space values is no longer finite + 1. Truncation: The episode duration reaches a 1000 timesteps + 2. Termination: Any of the state space values is no longer finite ### Arguments @@ -228,11 +228,11 @@ def step(self, a): self.renderer.render_step() - done = bool(False) return ( self._get_obs(), reward, - done, + False, + False, dict( reward_linup=uph_cost, reward_quadctrl=-quad_ctrl_cost, diff --git a/gym/envs/mujoco/inverted_double_pendulum.py b/gym/envs/mujoco/inverted_double_pendulum.py index 04c7e223590..3f41fc0778c 100644 --- a/gym/envs/mujoco/inverted_double_pendulum.py +++ b/gym/envs/mujoco/inverted_double_pendulum.py @@ -40,8 +40,8 @@ def step(self, action): vel_penalty = 1e-3 * v1**2 + 5e-3 * v2**2 alive_bonus = 10 r = alive_bonus - dist_penalty - vel_penalty - done = bool(y <= 1) - return ob, r, done, {} + terminated = bool(y <= 1) + return ob, r, terminated, False, {} def _get_obs(self): return np.concatenate( diff --git a/gym/envs/mujoco/inverted_double_pendulum_v4.py b/gym/envs/mujoco/inverted_double_pendulum_v4.py index bde6e5fdb16..c9d472d1596 100644 --- a/gym/envs/mujoco/inverted_double_pendulum_v4.py +++ b/gym/envs/mujoco/inverted_double_pendulum_v4.py @@ -85,12 +85,12 @@ class InvertedDoublePendulumEnv(MujocoEnv, utils.EzPickle): of [-0.1, 0.1] added to the positional values (cart position and pole angles) and standard normal force with a standard deviation of 0.1 added to the velocity values for stochasticity. 
- ### Episode Termination - The episode terminates when any of the following happens: + ### Episode End + The episode ends when any of the following happens: - 1. The episode duration reaches 1000 timesteps. - 2. Any of the state space values is no longer finite. - 3. The y_coordinate of the tip of the second pole *is less than or equal* to 1. The maximum standing height of the system is 1.196 m when all the parts are perpendicularly vertical on top of each other). + 1. Truncation: The episode duration reaches 1000 timesteps. + 2. Termination: Any of the state space values is no longer finite. + 3. Termination: The y_coordinate of the tip of the second pole *is less than or equal* to 1. The maximum standing height of the system is 1.196 m when all the parts are perpendicularly vertical on top of each other). ### Arguments @@ -143,11 +143,9 @@ def step(self, action): vel_penalty = 1e-3 * v1**2 + 5e-3 * v2**2 alive_bonus = 10 r = alive_bonus - dist_penalty - vel_penalty - done = bool(y <= 1) - + terminated = bool(y <= 1) self.renderer.render_step() - - return ob, r, done, {} + return ob, r, terminated, False, {} def _get_obs(self): return np.concatenate( diff --git a/gym/envs/mujoco/inverted_pendulum.py b/gym/envs/mujoco/inverted_pendulum.py index c0cbaa689e7..e41dbe45a8f 100644 --- a/gym/envs/mujoco/inverted_pendulum.py +++ b/gym/envs/mujoco/inverted_pendulum.py @@ -35,9 +35,8 @@ def step(self, a): self.renderer.render_step() ob = self._get_obs() - notdone = np.isfinite(ob).all() and (np.abs(ob[1]) <= 0.2) - done = not notdone - return ob, reward, done, {} + terminated = bool(not np.isfinite(ob).all() or (np.abs(ob[1]) > 0.2)) + return ob, reward, terminated, False, {} def reset_model(self): qpos = self.init_qpos + self.np_random.uniform( diff --git a/gym/envs/mujoco/inverted_pendulum_v4.py b/gym/envs/mujoco/inverted_pendulum_v4.py index 7fef812d07b..cc029672fd9 100644 --- a/gym/envs/mujoco/inverted_pendulum_v4.py +++ b/gym/envs/mujoco/inverted_pendulum_v4.py @@ -56,12 +56,12 @@ class InvertedPendulumEnv(MujocoEnv, utils.EzPickle): (0.0, 0.0, 0.0, 0.0) with a uniform noise in the range of [-0.01, 0.01] added to the values for stochasticity. - ### Episode Termination - The episode terminates when any of the following happens: + ### Episode End + The episode ends when any of the following happens: - 1. The episode duration reaches 1000 timesteps. - 2. Any of the state space values is no longer finite. - 3. The absolute value of the vertical angle between the pole and the cart is greater than 0.2 radians. + 1. Truncation: The episode duration reaches 1000 timesteps. + 2. Termination: Any of the state space values is no longer finite. + 3. Termination: The absolute value of the vertical angle between the pole and the cart is greater than 0.2 radians.
### Arguments @@ -109,12 +109,9 @@ def step(self, a): reward = 1.0 self.do_simulation(a, self.frame_skip) ob = self._get_obs() - notdone = np.isfinite(ob).all() and (np.abs(ob[1]) <= 0.2) - done = not notdone - + terminated = bool(not np.isfinite(ob).all() or (np.abs(ob[1]) > 0.2)) self.renderer.render_step() - - return ob, reward, done, {} + return ob, reward, terminated, False, {} def reset_model(self): qpos = self.init_qpos + self.np_random.uniform( diff --git a/gym/envs/mujoco/pusher.py b/gym/envs/mujoco/pusher.py index 20165fb499f..b4c7fe1c12b 100644 --- a/gym/envs/mujoco/pusher.py +++ b/gym/envs/mujoco/pusher.py @@ -38,8 +38,13 @@ def step(self, a): self.renderer.render_step() ob = self._get_obs() - done = False - return ob, reward, done, dict(reward_dist=reward_dist, reward_ctrl=reward_ctrl) + return ( + ob, + reward, + False, + False, + dict(reward_dist=reward_dist, reward_ctrl=reward_ctrl), + ) def viewer_setup(self): assert self.viewer is not None diff --git a/gym/envs/mujoco/pusher_v4.py b/gym/envs/mujoco/pusher_v4.py index e9859262e78..306e272d2a5 100644 --- a/gym/envs/mujoco/pusher_v4.py +++ b/gym/envs/mujoco/pusher_v4.py @@ -99,12 +99,12 @@ class PusherEnv(MujocoEnv, utils.EzPickle): The default framerate is 5 with each frame lasting for 0.01, giving rise to a *dt = 5 * 0.01 = 0.05* - ### Episode Termination + ### Episode End - The episode terminates when any of the following happens: + The episode ends when any of the following happens: - 1. The episode duration reaches a 100 timesteps. - 2. Any of the state space values is no longer finite. + 1. Truncation: The episode duration reaches a 100 timesteps. + 2. Termination: Any of the state space values is no longer finite. ### Arguments @@ -157,11 +157,14 @@ def step(self, a): self.do_simulation(a, self.frame_skip) ob = self._get_obs() - done = False - self.renderer.render_step() - - return ob, reward, done, dict(reward_dist=reward_dist, reward_ctrl=reward_ctrl) + return ( + ob, + reward, + False, + False, + dict(reward_dist=reward_dist, reward_ctrl=reward_ctrl), + ) def viewer_setup(self): assert self.viewer is not None diff --git a/gym/envs/mujoco/reacher.py b/gym/envs/mujoco/reacher.py index c5495ef3604..73438666b95 100644 --- a/gym/envs/mujoco/reacher.py +++ b/gym/envs/mujoco/reacher.py @@ -34,8 +34,13 @@ def step(self, a): self.renderer.render_step() ob = self._get_obs() - done = False - return ob, reward, done, dict(reward_dist=reward_dist, reward_ctrl=reward_ctrl) + return ( + ob, + reward, + False, + False, + dict(reward_dist=reward_dist, reward_ctrl=reward_ctrl), + ) def viewer_setup(self): assert self.viewer is not None diff --git a/gym/envs/mujoco/reacher_v4.py b/gym/envs/mujoco/reacher_v4.py index 4c54f701eb4..f0c334c0edd 100644 --- a/gym/envs/mujoco/reacher_v4.py +++ b/gym/envs/mujoco/reacher_v4.py @@ -89,12 +89,12 @@ class ReacherEnv(MujocoEnv, utils.EzPickle): element ("fingertip" - "target") is calculated at the end once everything is set. The default setting has a framerate of 2 and a *dt = 2 * 0.01 = 0.02* - ### Episode Termination + ### Episode End - The episode terminates when any of the following happens: + The episode ends when any of the following happens: - 1. The episode duration reaches a 50 timesteps (with a new random target popping up if the reacher's fingertip reaches it before 50 timesteps) - 2. Any of the state space values is no longer finite. + 1. 
Truncation: The episode duration reaches a 50 timesteps (with a new random target popping up if the reacher's fingertip reaches it before 50 timesteps) + 2. Termination: Any of the state space values is no longer finite. ### Arguments @@ -143,11 +143,14 @@ def step(self, a): reward = reward_dist + reward_ctrl self.do_simulation(a, self.frame_skip) ob = self._get_obs() - done = False - self.renderer.render_step() - - return ob, reward, done, dict(reward_dist=reward_dist, reward_ctrl=reward_ctrl) + return ( + ob, + reward, + False, + False, + dict(reward_dist=reward_dist, reward_ctrl=reward_ctrl), + ) def viewer_setup(self): assert self.viewer is not None diff --git a/gym/envs/mujoco/swimmer.py b/gym/envs/mujoco/swimmer.py index 15d7e8735d6..137a97eb52e 100644 --- a/gym/envs/mujoco/swimmer.py +++ b/gym/envs/mujoco/swimmer.py @@ -36,7 +36,13 @@ def step(self, a): reward_ctrl = -ctrl_cost_coeff * np.square(a).sum() reward = reward_fwd + reward_ctrl ob = self._get_obs() - return ob, reward, False, dict(reward_fwd=reward_fwd, reward_ctrl=reward_ctrl) + return ( + ob, + reward, + False, + False, + dict(reward_fwd=reward_fwd, reward_ctrl=reward_ctrl), + ) def _get_obs(self): qpos = self.sim.data.qpos diff --git a/gym/envs/mujoco/swimmer_v3.py b/gym/envs/mujoco/swimmer_v3.py index b20412744e2..d17c099630e 100644 --- a/gym/envs/mujoco/swimmer_v3.py +++ b/gym/envs/mujoco/swimmer_v3.py @@ -73,7 +73,6 @@ def step(self, action): observation = self._get_obs() reward = forward_reward - ctrl_cost - done = False info = { "reward_fwd": forward_reward, "reward_ctrl": -ctrl_cost, @@ -85,7 +84,7 @@ def step(self, action): "forward_reward": forward_reward, } - return observation, reward, done, info + return observation, reward, False, False, info def _get_obs(self): position = self.sim.data.qpos.flat.copy() diff --git a/gym/envs/mujoco/swimmer_v4.py b/gym/envs/mujoco/swimmer_v4.py index 97ec4d7cc40..29b1221e4e5 100644 --- a/gym/envs/mujoco/swimmer_v4.py +++ b/gym/envs/mujoco/swimmer_v4.py @@ -89,8 +89,8 @@ class SwimmerEnv(MujocoEnv, utils.EzPickle): ### Starting State All observations start in state (0,0,0,0,0,0,0,0) with a Uniform noise in the range of [-`reset_noise_scale`, `reset_noise_scale`] is added to the initial state for stochasticity. - ### Episode Termination - The episode terminates when the episode length is greater than 1000. + ### Episode End + The episode truncates when the episode length is greater than 1000. 
### Arguments @@ -183,7 +183,6 @@ def step(self, action): observation = self._get_obs() reward = forward_reward - ctrl_cost - done = False info = { "reward_fwd": forward_reward, "reward_ctrl": -ctrl_cost, @@ -196,7 +195,7 @@ def step(self, action): } self.renderer.render_step() - return observation, reward, done, info + return observation, reward, False, False, info def _get_obs(self): position = self.data.qpos.flat.copy() diff --git a/gym/envs/mujoco/walker2d.py b/gym/envs/mujoco/walker2d.py index 786a80158a7..12ec9630f31 100644 --- a/gym/envs/mujoco/walker2d.py +++ b/gym/envs/mujoco/walker2d.py @@ -35,10 +35,9 @@ def step(self, a): reward = (posafter - posbefore) / self.dt reward += alive_bonus reward -= 1e-3 * np.square(a).sum() - done = not (height > 0.8 and height < 2.0 and ang > -1.0 and ang < 1.0) + terminated = not (height > 0.8 and height < 2.0 and ang > -1.0 and ang < 1.0) ob = self._get_obs() - - return ob, reward, done, {} + return ob, reward, terminated, False, {} def _get_obs(self): qpos = self.sim.data.qpos diff --git a/gym/envs/mujoco/walker2d_v3.py b/gym/envs/mujoco/walker2d_v3.py index 98b226a383d..8688804fe27 100644 --- a/gym/envs/mujoco/walker2d_v3.py +++ b/gym/envs/mujoco/walker2d_v3.py @@ -92,9 +92,9 @@ def is_healthy(self): return is_healthy @property - def done(self): - done = not self.is_healthy if self._terminate_when_unhealthy else False - return done + def terminated(self): + terminated = not self.is_healthy if self._terminate_when_unhealthy else False + return terminated def _get_obs(self): position = self.sim.data.qpos.flat.copy() @@ -123,13 +123,13 @@ def step(self, action): observation = self._get_obs() reward = rewards - costs - done = self.done + terminated = self.terminated info = { "x_position": x_position_after, "x_velocity": x_velocity, } - return observation, reward, done, info + return observation, reward, terminated, False, info def reset_model(self): noise_low = -self._reset_noise_scale diff --git a/gym/envs/mujoco/walker2d_v4.py b/gym/envs/mujoco/walker2d_v4.py index 390c5dc5951..795480fb7e4 100644 --- a/gym/envs/mujoco/walker2d_v4.py +++ b/gym/envs/mujoco/walker2d_v4.py @@ -92,7 +92,7 @@ class Walker2dEnv(MujocoEnv, utils.EzPickle): (0.0, 1.25, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0) with a uniform noise in the range of [-`reset_noise_scale`, `reset_noise_scale`] added to the values for stochasticity. - ### Episode Termination + ### Episode End The walker is said to be unhealthy if any of the following happens: 1. Any of the state space values is no longer finite @@ -100,12 +100,12 @@ class Walker2dEnv(MujocoEnv, utils.EzPickle): 3. The absolute value of the angle (`observation[1]` if `exclude_current_positions_from_observation=False`, else `observation[2]`) is ***not*** in the closed interval specified by `healthy_angle_range` If `terminate_when_unhealthy=True` is passed during construction (which is the default), - the episode terminates when any of the following happens: + the episode ends when any of the following happens: - 1. The episode duration reaches a 1000 timesteps - 2. The walker is unhealthy + 1. Truncation: The episode duration reaches a 1000 timesteps + 2. Termination: The walker is unhealthy - If `terminate_when_unhealthy=False` is passed, the episode is terminated only when 1000 timesteps are exceeded. + If `terminate_when_unhealthy=False` is passed, the episode is ended only when 1000 timesteps are exceeded. 
### Arguments @@ -221,9 +221,9 @@ def is_healthy(self): return is_healthy @property - def done(self): - done = not self.is_healthy if self._terminate_when_unhealthy else False - return done + def terminated(self): + terminated = not self.is_healthy if self._terminate_when_unhealthy else False + return terminated def _get_obs(self): position = self.data.qpos.flat.copy() @@ -251,14 +251,14 @@ def step(self, action): observation = self._get_obs() reward = rewards - costs - done = self.done + terminated = self.terminated info = { "x_position": x_position_after, "x_velocity": x_velocity, } self.renderer.render_step() - return observation, reward, done, info + return observation, reward, terminated, False, info def reset_model(self): noise_low = -self._reset_noise_scale diff --git a/gym/envs/registration.py b/gym/envs/registration.py index d558420ffa1..9ac4ef1f178 100644 --- a/gym/envs/registration.py +++ b/gym/envs/registration.py @@ -23,7 +23,13 @@ import numpy as np from gym.envs.__relocated__ import internal_env_relocation_map -from gym.wrappers import AutoResetWrapper, HumanRendering, OrderEnforcing, TimeLimit +from gym.wrappers import ( + AutoResetWrapper, + HumanRendering, + OrderEnforcing, + StepAPICompatibility, + TimeLimit, +) from gym.wrappers.env_checker import PassiveEnvChecker if sys.version_info < (3, 10): @@ -118,6 +124,7 @@ class EnvSpec: max_episode_steps: Optional[int] = field(default=None) order_enforce: bool = field(default=True) autoreset: bool = field(default=False) + new_step_api: bool = field(default=False) kwargs: dict = field(default_factory=dict) namespace: Optional[str] = field(init=False) @@ -522,6 +529,7 @@ def make( id: Union[str, EnvSpec], max_episode_steps: Optional[int] = None, autoreset: bool = False, + new_step_api: bool = False, disable_env_checker: bool = False, **kwargs, ) -> Env: @@ -531,6 +539,7 @@ def make( id: Name of the environment. Optionally, a module to import can be included, eg. 'module:Env-v0' max_episode_steps: Maximum length of an episode (TimeLimit wrapper). autoreset: Whether to automatically reset the environment after each episode (AutoResetWrapper). + new_step_api: Whether to use old or new step API (StepAPICompatibility wrapper). Will be removed at v1.0 disable_env_checker: If to disable the environment checker kwargs: Additional arguments to pass to the environment constructor. 
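As a usage sketch of the plumbing above (illustrative only; `CartPole-v1` stands in for any registered id), the flag passed to `gym.make` is forwarded to the step-API, time-limit and autoreset wrappers in the hunk that follows:

```python
import gym

# Opt in to the new five-tuple API; max_episode_steps/autoreset are handled by
# the TimeLimit and AutoResetWrapper layers, which receive the same flag below.
env = gym.make("CartPole-v1", max_episode_steps=500, autoreset=True, new_step_api=True)
obs = env.reset(seed=0)
obs, reward, terminated, truncated, info = env.step(env.action_space.sample())

# The default (new_step_api=False) keeps the legacy four-tuple for now:
legacy_env = gym.make("CartPole-v1")
obs = legacy_env.reset(seed=0)
obs, reward, done, info = legacy_env.step(legacy_env.action_space.sample())
```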
@@ -644,19 +653,21 @@ def make( if disable_env_checker is False: env = PassiveEnvChecker(env) + env = StepAPICompatibility(env, new_step_api) + # Add the order enforcing wrapper if spec_.order_enforce: env = OrderEnforcing(env) # Add the time limit wrapper if max_episode_steps is not None: - env = TimeLimit(env, max_episode_steps) + env = TimeLimit(env, max_episode_steps, new_step_api) elif spec_.max_episode_steps is not None: - env = TimeLimit(env, spec_.max_episode_steps) + env = TimeLimit(env, spec_.max_episode_steps, new_step_api) # Add the autoreset wrapper if autoreset: - env = AutoResetWrapper(env) + env = AutoResetWrapper(env, new_step_api) return env diff --git a/gym/envs/toy_text/blackjack.py b/gym/envs/toy_text/blackjack.py index e99943f6c90..702fcb1151d 100644 --- a/gym/envs/toy_text/blackjack.py +++ b/gym/envs/toy_text/blackjack.py @@ -137,13 +137,13 @@ def step(self, action): if action: # hit: add a card to players hand and return self.player.append(draw_card(self.np_random)) if is_bust(self.player): - done = True + terminated = True reward = -1.0 else: - done = False + terminated = False reward = 0.0 else: # stick: play out the dealers hand, and score - done = True + terminated = True while sum_hand(self.dealer) < 17: self.dealer.append(draw_card(self.np_random)) reward = cmp(score(self.player), score(self.dealer)) @@ -158,9 +158,8 @@ def step(self, action): ): # Natural gives extra points, but doesn't autowin. Legacy implementation reward = 1.5 - self.renderer.render_step() - return self._get_obs(), reward, done, {} + return self._get_obs(), reward, terminated, False, {} def _get_obs(self): return (sum_hand(self.player), self.dealer[0], usable_ace(self.player)) diff --git a/gym/envs/toy_text/cliffwalking.py b/gym/envs/toy_text/cliffwalking.py index 24300a9accf..1de1354a5f3 100644 --- a/gym/envs/toy_text/cliffwalking.py +++ b/gym/envs/toy_text/cliffwalking.py @@ -111,7 +111,7 @@ def _calculate_transition_prob(self, current, delta): delta: Change in position for transition Returns: - Tuple of ``(1.0, new_state, reward, done)`` + Tuple of ``(1.0, new_state, reward, terminated)`` """ new_position = np.array(current) + np.array(delta) new_position = self._limit_coordinates(new_position).astype(int) @@ -120,17 +120,17 @@ def _calculate_transition_prob(self, current, delta): return [(1.0, self.start_state_index, -100, False)] terminal_state = (self.shape[0] - 1, self.shape[1] - 1) - is_done = tuple(new_position) == terminal_state - return [(1.0, new_state, -1, is_done)] + is_terminated = tuple(new_position) == terminal_state + return [(1.0, new_state, -1, is_terminated)] def step(self, a): transitions = self.P[self.s][a] i = categorical_sample([t[0] for t in transitions], self.np_random) - p, s, r, d = transitions[i] + p, s, r, t = transitions[i] self.s = s self.lastaction = a self.renderer.render_step() - return (int(s), r, d, {"prob": p}) + return (int(s), r, t, False, {"prob": p}) def reset( self, diff --git a/gym/envs/toy_text/frozen_lake.py b/gym/envs/toy_text/frozen_lake.py index fd262f7440b..595e9e387a4 100644 --- a/gym/envs/toy_text/frozen_lake.py +++ b/gym/envs/toy_text/frozen_lake.py @@ -201,9 +201,9 @@ def update_probability_matrix(row, col, action): newrow, newcol = inc(row, col, action) newstate = to_s(newrow, newcol) newletter = desc[newrow, newcol] - done = bytes(newletter) in b"GH" + terminated = bytes(newletter) in b"GH" reward = float(newletter == b"G") - return newstate, reward, done + return newstate, reward, terminated for row in range(nrow): for col in 
range(ncol): @@ -242,13 +242,11 @@ def update_probability_matrix(row, col, action): def step(self, a): transitions = self.P[self.s][a] i = categorical_sample([t[0] for t in transitions], self.np_random) - p, s, r, d = transitions[i] + p, s, r, t = transitions[i] self.s = s self.lastaction = a - self.renderer.render_step() - - return (int(s), r, d, {"prob": p}) + return (int(s), r, t, False, {"prob": p}) def reset( self, diff --git a/gym/envs/toy_text/taxi.py b/gym/envs/toy_text/taxi.py index 2fcf737af23..002df54947c 100644 --- a/gym/envs/toy_text/taxi.py +++ b/gym/envs/toy_text/taxi.py @@ -156,7 +156,7 @@ def __init__(self, render_mode: Optional[str] = None): reward = ( -1 ) # default reward when there is no pickup/dropoff - done = False + terminated = False taxi_loc = (row, col) if action == 0: @@ -175,7 +175,7 @@ def __init__(self, render_mode: Optional[str] = None): elif action == 5: # dropoff if (taxi_loc == locs[dest_idx]) and pass_idx == 4: new_pass_idx = dest_idx - done = True + terminated = True reward = 20 elif (taxi_loc in locs) and pass_idx == 4: new_pass_idx = locs.index(taxi_loc) @@ -184,7 +184,9 @@ def __init__(self, render_mode: Optional[str] = None): new_state = self.encode( new_row, new_col, new_pass_idx, dest_idx ) - self.P[state][action].append((1.0, new_state, reward, done)) + self.P[state][action].append( + (1.0, new_state, reward, terminated) + ) self.initial_state_distrib /= self.initial_state_distrib.sum() self.action_space = spaces.Discrete(num_actions) self.observation_space = spaces.Discrete(num_states) @@ -254,12 +256,11 @@ def action_mask(self, state: int): def step(self, a): transitions = self.P[self.s][a] i = categorical_sample([t[0] for t in transitions], self.np_random) - p, s, r, d = transitions[i] + p, s, r, t = transitions[i] self.s = s self.lastaction = a self.renderer.render_step() - - return int(s), r, d, {"prob": p, "action_mask": self.action_mask(s)} + return (int(s), r, t, False, {"prob": p, "action_mask": self.action_mask(s)}) def reset( self, diff --git a/gym/error.py b/gym/error.py index 8219c5d13f8..9a9b8899886 100644 --- a/gym/error.py +++ b/gym/error.py @@ -58,7 +58,7 @@ class ResetNeeded(Error): class ResetNotAllowed(Error): - """When the monitor is active, raised when the user tries to step an environment that's not yet done.""" + """When the monitor is active, raised when the user tries to step an environment that's not yet terminated or truncated.""" class InvalidAction(Error): diff --git a/gym/utils/passive_env_checker.py b/gym/utils/passive_env_checker.py index 41f97145002..67e9cb10548 100644 --- a/gym/utils/passive_env_checker.py +++ b/gym/utils/passive_env_checker.py @@ -4,6 +4,7 @@ import numpy as np from gym import error, logger, spaces +from gym.logger import deprecation def _check_box_observation_space(observation_space: spaces.Box): @@ -253,14 +254,24 @@ def passive_env_step_check(env, action): """A passive check for the environment step, investigating the returning data then returning the data unchanged.""" result = env.step(action) if len(result) == 4: + deprecation( + "Core environment is written in old step API which returns one bool instead of two. " + "It is recommended to rewrite the environment with new step API. 
" + ) obs, reward, done, info = result - assert isinstance(done, bool), "The `done` signal must be a boolean" + assert isinstance( + done, bool + ), f"The `done` signal is of type `{type(done)}` must be a boolean" elif len(result) == 5: obs, reward, terminated, truncated, info = result - assert isinstance(terminated, bool), "The `terminated` signal must be a boolean" - assert isinstance(truncated, bool), "The `truncated` signal must be a boolean" + assert isinstance( + terminated, bool + ), f"The `terminated` signal is of type `{type(terminated)}`. It must be a boolean" + assert isinstance( + truncated, bool + ), f"The `truncated` signal of type `{type(truncated)}`. It must be a boolean." assert ( terminated is False or truncated is False ), "Only `terminated` or `truncated` can be true, not both." diff --git a/gym/utils/play.py b/gym/utils/play.py index 1b46b968716..5793021c2a1 100644 --- a/gym/utils/play.py +++ b/gym/utils/play.py @@ -1,4 +1,7 @@ """Utilities of visualising an environment.""" + +# TODO: Convert to new step API in 1.0 + from collections import deque from typing import Callable, Dict, List, Optional, Tuple, Union @@ -208,6 +211,10 @@ def play( seed: Random seed used when resetting the environment. If None, no seed is used. noop: The action used when no key input has been entered, or the entered key combination is unknown. """ + deprecation( + "`play.py` currently supports only the old step API which returns one boolean, however this will soon be updated to support only the new step api that returns two bools." + ) + env.reset(seed=seed) if keys_to_action is None: diff --git a/gym/utils/step_api_compatibility.py b/gym/utils/step_api_compatibility.py new file mode 100644 index 00000000000..2be07dbe35c --- /dev/null +++ b/gym/utils/step_api_compatibility.py @@ -0,0 +1,180 @@ +"""Contains methods for step compatibility, from old-to-new and new-to-old API, to be removed in 1.0.""" +from typing import Tuple, Union + +import numpy as np + +from gym.core import ObsType + +OldStepType = Tuple[ + Union[ObsType, np.ndarray], + Union[float, np.ndarray], + Union[bool, np.ndarray], + Union[dict, list], +] + +NewStepType = Tuple[ + Union[ObsType, np.ndarray], + Union[float, np.ndarray], + Union[bool, np.ndarray], + Union[bool, np.ndarray], + Union[dict, list], +] + + +def step_to_new_api( + step_returns: Union[OldStepType, NewStepType], is_vector_env=False +) -> NewStepType: + """Function to transform step returns to new step API irrespective of input API. + + Args: + step_returns (tuple): Items returned by step(). 
Can be (obs, rew, done, info) or (obs, rew, terminated, truncated, info) + is_vector_env (bool): Whether the step_returns are from a vector environment + """ + if len(step_returns) == 5: + return step_returns + else: + assert len(step_returns) == 4 + observations, rewards, dones, infos = step_returns + + terminateds = [] + truncateds = [] + if not is_vector_env: + dones = [dones] + + for i in range(len(dones)): + # For every condition, handling - info single env / info vector env (list) / info vector env (dict) + + # TimeLimit.truncated attribute not present - implies either terminated or episode still ongoing based on `done` + if (not is_vector_env and "TimeLimit.truncated" not in infos) or ( + is_vector_env + and ( + ( + isinstance(infos, list) + and "TimeLimit.truncated" not in infos[i] + ) # vector env, list info api + or ( + "TimeLimit.truncated" not in infos + or ( + "TimeLimit.truncated" in infos + and not infos["_TimeLimit.truncated"][i] + ) + ) # vector env, dict info api, if mask is False, it's the same as TimeLimit.truncated attribute not being present for env 'i' + ) + ): + + terminateds.append(dones[i]) + truncateds.append(False) + + # This means info["TimeLimit.truncated"] exists and this elif checks if it is True, which means the truncation has occurred but termination has not. + elif ( + infos["TimeLimit.truncated"] + if not is_vector_env + else ( + infos["TimeLimit.truncated"][i] + if isinstance(infos, dict) + else infos[i]["TimeLimit.truncated"] + ) + ): + assert dones[i] + terminateds.append(False) + truncateds.append(True) + else: + # This means info["TimeLimit.truncated"] exists but is False, which means the core environment had already terminated, + # but it also exceeded maximum timesteps at the same step. + assert dones[i] + terminateds.append(True) + truncateds.append(True) + + return ( + observations, + rewards, + np.array(terminateds, dtype=np.bool_) if is_vector_env else terminateds[0], + np.array(truncateds, dtype=np.bool_) if is_vector_env else truncateds[0], + infos, + ) + + +def step_to_old_api( + step_returns: Union[NewStepType, OldStepType], is_vector_env: bool = False +) -> OldStepType: + """Function to transform step returns to old step API irrespective of input API. + + Args: + step_returns (tuple): Items returned by step(). 
Can be (obs, rew, done, info) or (obs, rew, terminated, truncated, info) + is_vector_env (bool): Whether the step_returns are from a vector environment + """ + if len(step_returns) == 4: + return step_returns + else: + assert len(step_returns) == 5 + observations, rewards, terminateds, truncateds, infos = step_returns + dones = [] + if not is_vector_env: + terminateds = [terminateds] + truncateds = [truncateds] + + n_envs = len(terminateds) + + for i in range(n_envs): + dones.append(terminateds[i] or truncateds[i]) + if truncateds[i]: + if is_vector_env: + # handle vector infos for dict and list + if isinstance(infos, dict): + if "TimeLimit.truncated" not in infos: + # TODO: This should ideally not be done manually and should use vector_env's _add_info() + infos["TimeLimit.truncated"] = np.zeros(n_envs, dtype=bool) + infos["_TimeLimit.truncated"] = np.zeros(n_envs, dtype=bool) + + infos["TimeLimit.truncated"][i] = ( + not terminateds[i] or infos["TimeLimit.truncated"][i] + ) + infos["_TimeLimit.truncated"][i] = True + else: + # if vector info is a list + infos[i]["TimeLimit.truncated"] = not terminateds[i] or infos[ + i + ].get("TimeLimit.truncated", False) + else: + infos["TimeLimit.truncated"] = not terminateds[i] or infos.get( + "TimeLimit.truncated", False + ) + return ( + observations, + rewards, + np.array(dones, dtype=np.bool_) if is_vector_env else dones[0], + infos, + ) + + +def step_api_compatibility( + step_returns: Union[NewStepType, OldStepType], + new_step_api: bool = False, + is_vector_env: bool = False, +) -> Union[NewStepType, OldStepType]: + """Function to transform step returns to the API specified by `new_step_api` bool. + + Old step API refers to step() method returning (observation, reward, done, info) + New step API refers to step() method returning (observation, reward, terminated, truncated, info) + (Refer to docs for details on the API change) + + Args: + step_returns (tuple): Items returned by step(). Can be (obs, rew, done, info) or (obs, rew, terminated, truncated, info) + new_step_api (bool): Whether the output should be in new step API or old (False by default) + is_vector_env (bool): Whether the step_returns are from a vector environment + + Returns: + step_returns (tuple): Depending on `new_step_api` bool, it can return (obs, rew, done, info) or (obs, rew, terminated, truncated, info) + + Examples: + This function can be used to ensure compatibility in step interfaces with conflicting API. Eg. if env is written in old API, + wrapper is written in new API, and the final step output is desired to be in old API. + + >>> obs, rew, done, info = step_api_compatibility(env.step(action)) + >>> obs, rew, terminated, truncated, info = step_api_compatibility(env.step(action), new_step_api=True) + >>> observations, rewards, dones, infos = step_api_compatibility(vec_env.step(action), is_vector_env=True) + """ + if new_step_api: + return step_to_new_api(step_returns, is_vector_env) + else: + return step_to_old_api(step_returns, is_vector_env) diff --git a/gym/vector/__init__.py b/gym/vector/__init__.py index b71ec11940d..cbf7978c647 100644 --- a/gym/vector/__init__.py +++ b/gym/vector/__init__.py @@ -15,6 +15,7 @@ def make( asynchronous: bool = True, wrappers: Optional[Union[callable, List[callable]]] = None, disable_env_checker: bool = False, + new_step_api: bool = False, **kwargs, ) -> VectorEnv: """Create a vectorized environment from multiple copies of an environment, from its id. 
@@ -35,6 +36,7 @@ def make( asynchronous: If `True`, wraps the environments in an :class:`AsyncVectorEnv` (which uses `multiprocessing`_ to run the environments in parallel). If ``False``, wraps the environments in a :class:`SyncVectorEnv`. wrappers: If not ``None``, then apply the wrappers to each internal environment during creation. disable_env_checker: If to disable the env checker, if True it will only run on the first environment created. + new_step_api: If True, the vector environment's step method outputs two booleans `terminated`, `truncated` instead of one `done`. **kwargs: Keywords arguments applied during gym.make Returns: @@ -46,7 +48,10 @@ def create_env(_disable_env_checker): def _make_env(): env = gym.envs.registration.make( - id, disable_env_checker=_disable_env_checker, **kwargs + id, + disable_env_checker=_disable_env_checker, + new_step_api=True, + **kwargs, ) if wrappers is not None: if callable(wrappers): @@ -65,4 +70,8 @@ def _make_env(): env_fns = [ create_env(disable_env_checker or env_num > 0) for env_num in range(num_envs) ] - return AsyncVectorEnv(env_fns) if asynchronous else SyncVectorEnv(env_fns) + return ( + AsyncVectorEnv(env_fns, new_step_api=new_step_api) + if asynchronous + else SyncVectorEnv(env_fns, new_step_api=new_step_api) + ) diff --git a/gym/vector/async_vector_env.py b/gym/vector/async_vector_env.py index faf0c96d329..0c71d959736 100644 --- a/gym/vector/async_vector_env.py +++ b/gym/vector/async_vector_env.py @@ -17,6 +17,7 @@ CustomSpaceError, NoAsyncCallError, ) +from gym.utils.step_api_compatibility import step_api_compatibility from gym.vector.utils import ( CloudpickleWrapper, clear_mpi_env_vars, @@ -66,6 +67,7 @@ def __init__( context: Optional[str] = None, daemon: bool = True, worker: Optional[callable] = None, + new_step_api: bool = False, ): """Vectorized environment that runs multiple environments in parallel. @@ -84,7 +86,8 @@ def __init__( the head process quits. However, ``daemon=True`` prevents subprocesses to spawn children, so for some environments you may want to have it set to ``False``. worker: If set, then use that worker in a subprocess instead of a default one. - Can be useful to override some inner vector env logic, for instance, how resets on done are handled. + Can be useful to override some inner vector env logic, for instance, how resets on termination or truncation are handled. + new_step_api: If True, step method returns 2 bools - terminated, truncated, instead of 1 bool - done Warnings: worker is an advanced mode option. It provides a high degree of flexibility and a high chance to shoot yourself in the foot; thus, if you are writing your own worker, it is recommended to start @@ -112,6 +115,7 @@ def __init__( num_envs=len(env_fns), observation_space=observation_space, action_space=action_space, + new_step_api=new_step_api, ) if self.shared_memory: @@ -338,7 +342,7 @@ def step_wait( timeout: Number of seconds before the call to :meth:`step_wait` times out. If ``None``, the call to :meth:`step_wait` never times out. Returns: - The batched environment step information, obs, reward, done and info + The batched environment step information, (obs, reward, terminated, truncated, info) or (obs, reward, done, info) depending on new_step_api Raises: ClosedEnvironmentError: If the environment was closed (if :meth:`close` was previously called). @@ -358,16 +362,17 @@ def step_wait( f"The call to `step_wait` has timed out after {timeout} second(s)." 
) - observations_list, rewards, dones, infos = [], [], [], {} + observations_list, rewards, terminateds, truncateds, infos = [], [], [], [], {} successes = [] for i, pipe in enumerate(self.parent_pipes): result, success = pipe.recv() - obs, rew, done, info = result + obs, rew, terminated, truncated, info = step_api_compatibility(result, True) successes.append(success) observations_list.append(obs) rewards.append(rew) - dones.append(done) + terminateds.append(terminated) + truncateds.append(truncated) infos = self._add_info(infos, info, i) self._raise_if_errors(successes) @@ -380,11 +385,16 @@ def step_wait( self.observations, ) - return ( - deepcopy(self.observations) if self.copy else self.observations, - np.array(rewards), - np.array(dones, dtype=np.bool_), - infos, + return step_api_compatibility( + ( + deepcopy(self.observations) if self.copy else self.observations, + np.array(rewards), + np.array(terminateds, dtype=np.bool_), + np.array(truncateds, dtype=np.bool_), + infos, + ), + self.new_step_api, + True, ) def call_async(self, name: str, *args, **kwargs): @@ -604,11 +614,17 @@ def _worker(index, env_fn, pipe, parent_pipe, shared_memory, error_queue): pipe.send((observation, True)) elif command == "step": - observation, reward, done, info = env.step(data) - if done: - info["terminal_observation"] = observation + ( + observation, + reward, + terminated, + truncated, + info, + ) = step_api_compatibility(env.step(data), True) + if terminated or truncated: + info["final_observation"] = observation observation = env.reset() - pipe.send(((observation, reward, done, info), True)) + pipe.send(((observation, reward, terminated, truncated, info), True)) elif command == "seed": env.seed(data) pipe.send((None, True)) @@ -673,14 +689,20 @@ def _worker_shared_memory(index, env_fn, pipe, parent_pipe, shared_memory, error ) pipe.send((None, True)) elif command == "step": - observation, reward, done, info = env.step(data) - if done: - info["terminal_observation"] = observation + ( + observation, + reward, + terminated, + truncated, + info, + ) = step_api_compatibility(env.step(data), True) + if terminated or truncated: + info["final_observation"] = observation observation = env.reset() write_to_shared_memory( observation_space, index, observation, shared_memory ) - pipe.send(((None, reward, done, info), True)) + pipe.send(((None, reward, terminated, truncated, info), True)) elif command == "seed": env.seed(data) pipe.send((None, True)) diff --git a/gym/vector/sync_vector_env.py b/gym/vector/sync_vector_env.py index f470461ecd1..cc3408e7adb 100644 --- a/gym/vector/sync_vector_env.py +++ b/gym/vector/sync_vector_env.py @@ -6,6 +6,7 @@ from gym import Env from gym.spaces import Space +from gym.utils.step_api_compatibility import step_api_compatibility from gym.vector.utils import concatenate, create_empty_array, iterate from gym.vector.vector_env import VectorEnv @@ -33,6 +34,7 @@ def __init__( observation_space: Space = None, action_space: Space = None, copy: bool = True, + new_step_api: bool = False, ): """Vectorized environment that serially runs multiple environments. 
@@ -60,6 +62,7 @@ def __init__( num_envs=len(self.envs), observation_space=observation_space, action_space=action_space, + new_step_api=new_step_api, ) self._check_spaces() @@ -67,7 +70,8 @@ def __init__( self.single_observation_space, n=self.num_envs, fn=np.zeros ) self._rewards = np.zeros((self.num_envs,), dtype=np.float64) - self._dones = np.zeros((self.num_envs,), dtype=np.bool_) + self._terminateds = np.zeros((self.num_envs,), dtype=np.bool_) + self._truncateds = np.zeros((self.num_envs,), dtype=np.bool_) self._actions = None def seed(self, seed: Optional[Union[int, Sequence[int]]] = None): @@ -108,7 +112,8 @@ def reset_wait( seed = [seed + i for i in range(self.num_envs)] assert len(seed) == self.num_envs - self._dones[:] = False + self._terminateds[:] = False + self._truncateds[:] = False observations = [] infos = {} for i, (env, single_seed) in enumerate(zip(self.envs, seed)): @@ -151,9 +156,15 @@ def step_wait(self): """ observations, infos = [], {} for i, (env, action) in enumerate(zip(self.envs, self._actions)): - observation, self._rewards[i], self._dones[i], info = env.step(action) - if self._dones[i]: - info["terminal_observation"] = observation + ( + observation, + self._rewards[i], + self._terminateds[i], + self._truncateds[i], + info, + ) = step_api_compatibility(env.step(action), True) + if self._terminateds[i] or self._truncateds[i]: + info["final_observation"] = observation observation = env.reset() observations.append(observation) infos = self._add_info(infos, info, i) @@ -161,11 +172,16 @@ def step_wait(self): self.single_observation_space, observations, self.observations ) - return ( - deepcopy(self.observations) if self.copy else self.observations, - np.copy(self._rewards), - np.copy(self._dones), - infos, + return step_api_compatibility( + ( + deepcopy(self.observations) if self.copy else self.observations, + np.copy(self._rewards), + np.copy(self._terminateds), + np.copy(self._truncateds), + infos, + ), + new_step_api=self.new_step_api, + is_vector_env=True, ) def call(self, name, *args, **kwargs) -> tuple: diff --git a/gym/vector/vector_env.py b/gym/vector/vector_env.py index b418f708953..ad2710e02ae 100644 --- a/gym/vector/vector_env.py +++ b/gym/vector/vector_env.py @@ -1,5 +1,5 @@ """Base class for vectorized environments.""" -from typing import Any, List, Optional, Union +from typing import Any, List, Optional, Tuple, Union import numpy as np @@ -24,7 +24,11 @@ class VectorEnv(gym.Env): """ def __init__( - self, num_envs: int, observation_space: gym.Space, action_space: gym.Space + self, + num_envs: int, + observation_space: gym.Space, + action_space: gym.Space, + new_step_api: bool = False, ): """Base class for vectorized environments. @@ -32,6 +36,7 @@ def __init__( num_envs: Number of environments in the vectorized environment. observation_space: Observation space of a single environment. action_space: Action space of a single environment. + new_step_api (bool): Whether the vector env's step method outputs two boolean arrays (new API) or one boolean array (old API) """ self.num_envs = num_envs self.is_vector_env = True @@ -46,6 +51,13 @@ def __init__( self.single_observation_space = observation_space self.single_action_space = action_space + self.new_step_api = new_step_api + if not self.new_step_api: + deprecation( + "Initializing vector env in old step API which returns one bool array instead of two. " + "It is recommended to set `new_step_api=True` to use new step API. This will be the default behaviour in future. 
" + ) + def reset_async( self, seed: Optional[Union[int, List[int]]] = None, @@ -135,7 +147,7 @@ def step(self, actions): actions: element of :attr:`action_space` Batch of actions. Returns: - Batch of observations, rewards, done and infos + Batch of (observations, rewards, terminateds, truncateds, infos) or (observations, rewards, dones, infos) """ self.step_async(actions) return self.step_wait() @@ -143,7 +155,7 @@ def step(self, actions): def call_async(self, name, *args, **kwargs): """Calls a method name for each parallel environment asynchronously.""" - def call_wait(self, **kwargs) -> List[Any]: + def call_wait(self, **kwargs) -> List[Any]: # type: ignore """After calling a method in :meth:`call_async`, this function collects the results.""" def call(self, name: str, *args, **kwargs) -> List[Any]: @@ -251,7 +263,7 @@ def _add_info(self, infos: dict, info: dict, env_num: int) -> dict: infos[k], infos[f"_{k}"] = info_array, array_mask return infos - def _init_info_arrays(self, dtype: type) -> np.ndarray: + def _init_info_arrays(self, dtype: type) -> Tuple[np.ndarray, np.ndarray]: """Initialize the info array. Initialize the info array. If the dtype is numeric diff --git a/gym/wrappers/__init__.py b/gym/wrappers/__init__.py index 483c421b1a8..856c806a6fe 100644 --- a/gym/wrappers/__init__.py +++ b/gym/wrappers/__init__.py @@ -14,6 +14,7 @@ from gym.wrappers.record_video import RecordVideo, capped_cubic_video_schedule from gym.wrappers.rescale_action import RescaleAction from gym.wrappers.resize_observation import ResizeObservation +from gym.wrappers.step_api_compatibility import StepAPICompatibility from gym.wrappers.time_aware_observation import TimeAwareObservation from gym.wrappers.time_limit import TimeLimit from gym.wrappers.transform_observation import TransformObservation diff --git a/gym/wrappers/atari_preprocessing.py b/gym/wrappers/atari_preprocessing.py index 8e61dbd3405..96c3ee7b176 100644 --- a/gym/wrappers/atari_preprocessing.py +++ b/gym/wrappers/atari_preprocessing.py @@ -3,6 +3,7 @@ import gym from gym.spaces import Box +from gym.utils.step_api_compatibility import step_api_compatibility try: import cv2 @@ -37,6 +38,7 @@ def __init__( grayscale_obs: bool = True, grayscale_newaxis: bool = False, scale_obs: bool = False, + new_step_api: bool = False, ): """Wrapper for Atari 2600 preprocessing. @@ -45,7 +47,7 @@ def __init__( noop_max (int): For No-op reset, the max number no-ops actions are taken at reset, to turn off, set to 0. frame_skip (int): The number of frames between new observation the agents observations effecting the frequency at which the agent experiences the game. screen_size (int): resize Atari frame - terminal_on_life_loss (bool): `if True`, then :meth:`step()` returns `done=True` whenever a + terminal_on_life_loss (bool): `if True`, then :meth:`step()` returns `terminated=True` whenever a life is lost. grayscale_obs (bool): if True, then gray scale observation is returned, otherwise, RGB observation is returned. 
@@ -58,7 +60,7 @@ def __init__( DependencyNotInstalled: opencv-python package not installed ValueError: Disable frame-skipping in the original env """ - super().__init__(env) + super().__init__(env, new_step_api) if cv2 is None: raise gym.error.DependencyNotInstalled( "opencv-python package not installed, run `pip install gym[other]` to get dependencies for atari" @@ -114,20 +116,22 @@ def __init__( def step(self, action): """Applies the preprocessing for an :meth:`env.step`.""" - total_reward, done, info = 0.0, False, {} + total_reward, terminated, truncated, info = 0.0, False, False, {} for t in range(self.frame_skip): - _, reward, done, info = self.env.step(action) + _, reward, terminated, truncated, info = step_api_compatibility( + self.env.step(action), True + ) total_reward += reward - self.game_over = done + self.game_over = terminated if self.terminal_on_life_loss: new_lives = self.ale.lives() - done = done or new_lives < self.lives - self.game_over = done + terminated = terminated or new_lives < self.lives + self.game_over = terminated self.lives = new_lives - if done: + if terminated or truncated: break if t == self.frame_skip - 2: if self.grayscale_obs: @@ -139,7 +143,10 @@ def step(self, action): self.ale.getScreenGrayscale(self.obs_buffer[0]) else: self.ale.getScreenRGB(self.obs_buffer[0]) - return self._get_obs(), total_reward, done, info + return step_api_compatibility( + (self._get_obs(), total_reward, terminated, truncated, info), + self.new_step_api, + ) def reset(self, **kwargs): """Resets the environment using preprocessing.""" @@ -156,9 +163,11 @@ def reset(self, **kwargs): else 0 ) for _ in range(noops): - _, _, done, step_info = self.env.step(0) + _, _, terminated, truncated, step_info = step_api_compatibility( + self.env.step(0), True + ) reset_info.update(step_info) - if done: + if terminated or truncated: if kwargs.get("return_info", False): _, reset_info = self.env.reset(**kwargs) else: diff --git a/gym/wrappers/autoreset.py b/gym/wrappers/autoreset.py index baa667e8053..6e20c92ffed 100644 --- a/gym/wrappers/autoreset.py +++ b/gym/wrappers/autoreset.py @@ -1,29 +1,40 @@ -"""Wrapper that autoreset environments when `done=True`.""" +"""Wrapper that autoreset environments when `terminated=True` or `truncated=True`.""" import gym +from gym.utils.step_api_compatibility import step_api_compatibility class AutoResetWrapper(gym.Wrapper): """A class for providing an automatic reset functionality for gym environments when calling :meth:`self.step`. - When calling step causes :meth:`Env.step` to return done, :meth:`Env.reset` is called, - and the return format of :meth:`self.step` is as follows: ``(new_obs, terminal_reward, terminal_done, info)`` + When calling step causes :meth:`Env.step` to return `terminated=True` or `truncated=True`, :meth:`Env.reset` is called, + and the return format of :meth:`self.step` is as follows: ``(new_obs, final_reward, final_terminated, final_truncated, info)`` + with new step API and ``(new_obs, final_reward, final_done, info)`` with the old step API. - ``new_obs`` is the first observation after calling :meth:`self.env.reset` - - ``terminal_reward`` is the reward after calling :meth:`self.env.step`, prior to calling :meth:`self.env.reset`. - - ``terminal_done`` is always True + - ``final_reward`` is the reward after calling :meth:`self.env.step`, prior to calling :meth:`self.env.reset`. + - ``final_done`` is always True. 
In the new API, either ``final_terminated`` or ``final_truncated`` is True - ``info`` is a dict containing all the keys from the info dict returned by the call to :meth:`self.env.reset`, - with an additional key "terminal_observation" containing the observation returned by the last call to :meth:`self.env.step` - and "terminal_info" containing the info dict returned by the last call to :meth:`self.env.step`. + with an additional key "final_observation" containing the observation returned by the last call to :meth:`self.env.step` + and "final_info" containing the info dict returned by the last call to :meth:`self.env.step`. - Warning: When using this wrapper to collect rollouts, note that when :meth:`Env.step` returns done, a + Warning: When using this wrapper to collect rollouts, note that when :meth:`Env.step` returns `terminated` or `truncated`, a new observation from after calling :meth:`Env.reset` is returned by :meth:`Env.step` alongside the - terminal reward and done state from the previous episode. - If you need the terminal state from the previous episode, you need to retrieve it via the - "terminal_observation" key in the info dict. + final reward and done state from the previous episode. + If you need the final state from the previous episode, you need to retrieve it via the + "final_observation" key in the info dict. Make sure you know what you're doing if you use this wrapper! """ + def __init__(self, env: gym.Env, new_step_api: bool = False): + """A class for providing an automatic reset functionality for gym environments when calling :meth:`self.step`. + + Args: + env (gym.Env): The environment to apply the wrapper + new_step_api (bool): Whether the wrapper's step method outputs two booleans (new API) or one boolean (old API) + """ + super().__init__(env, new_step_api) + def step(self, action): - """Steps through the environment with action and resets the environment if a done-signal is encountered. + """Steps through the environment with action and resets the environment if a terminated or truncated signal is encountered. 
Args: action: The action to take @@ -31,22 +42,26 @@ def step(self, action): Returns: The autoreset environment :meth:`step` """ - obs, reward, done, info = self.env.step(action) + obs, reward, terminated, truncated, info = step_api_compatibility( + self.env.step(action), True + ) - if done: + if terminated or truncated: new_obs, new_info = self.env.reset(return_info=True) assert ( - "terminal_observation" not in new_info - ), 'info dict cannot contain key "terminal_observation" ' + "final_observation" not in new_info + ), 'info dict cannot contain key "final_observation" ' assert ( - "terminal_info" not in new_info - ), 'info dict cannot contain key "terminal_info" ' + "final_info" not in new_info + ), 'info dict cannot contain key "final_info" ' - new_info["terminal_observation"] = obs - new_info["terminal_info"] = info + new_info["final_observation"] = obs + new_info["final_info"] = info obs = new_obs info = new_info - return obs, reward, done, info + return step_api_compatibility( + (obs, reward, terminated, truncated, info), self.new_step_api + ) diff --git a/gym/wrappers/clip_action.py b/gym/wrappers/clip_action.py index de236384768..58d981e96a1 100644 --- a/gym/wrappers/clip_action.py +++ b/gym/wrappers/clip_action.py @@ -26,7 +26,7 @@ def __init__(self, env: gym.Env): env: The environment to apply the wrapper """ assert isinstance(env.action_space, Box) - super().__init__(env) + super().__init__(env, new_step_api=True) def action(self, action): """Clips the action within the valid bounds. diff --git a/gym/wrappers/env_checker.py b/gym/wrappers/env_checker.py index 4fe0c011c96..785cd69f40d 100644 --- a/gym/wrappers/env_checker.py +++ b/gym/wrappers/env_checker.py @@ -15,7 +15,7 @@ class PassiveEnvChecker(gym.Wrapper): def __init__(self, env): """Initialises the wrapper with the environments, run the observation and action space tests.""" - super().__init__(env) + super().__init__(env, new_step_api=True) assert hasattr( env, "action_space" diff --git a/gym/wrappers/filter_observation.py b/gym/wrappers/filter_observation.py index 922c8288038..bcbe13b5065 100644 --- a/gym/wrappers/filter_observation.py +++ b/gym/wrappers/filter_observation.py @@ -35,7 +35,7 @@ def __init__(self, env: gym.Env, filter_keys: Sequence[str] = None): ValueError: If the environment's observation space is not :class:`spaces.Dict` ValueError: If any of the `filter_keys` are not included in the original `env`'s observation space """ - super().__init__(env) + super().__init__(env, new_step_api=True) wrapped_observation_space = env.observation_space if not isinstance(wrapped_observation_space, spaces.Dict): diff --git a/gym/wrappers/flatten_observation.py b/gym/wrappers/flatten_observation.py index fe6518b875b..95aa13e0d01 100644 --- a/gym/wrappers/flatten_observation.py +++ b/gym/wrappers/flatten_observation.py @@ -25,7 +25,7 @@ def __init__(self, env: gym.Env): Args: env: The environment to apply the wrapper """ - super().__init__(env) + super().__init__(env, new_step_api=True) self.observation_space = spaces.flatten_space(env.observation_space) def observation(self, observation): diff --git a/gym/wrappers/frame_stack.py b/gym/wrappers/frame_stack.py index 83aa440aba1..c7eb19aaacd 100644 --- a/gym/wrappers/frame_stack.py +++ b/gym/wrappers/frame_stack.py @@ -7,6 +7,7 @@ import gym from gym.error import DependencyNotInstalled from gym.spaces import Box +from gym.utils.step_api_compatibility import step_api_compatibility class LazyFrames: @@ -122,15 +123,22 @@ class FrameStack(gym.ObservationWrapper): (4, 96, 96, 
3) """ - def __init__(self, env: gym.Env, num_stack: int, lz4_compress: bool = False): + def __init__( + self, + env: gym.Env, + num_stack: int, + lz4_compress: bool = False, + new_step_api: bool = False, + ): """Observation wrapper that stacks the observations in a rolling manner. Args: env (Env): The environment to apply the wrapper num_stack (int): The number of frames to stack lz4_compress (bool): Use lz4 to compress the frames internally + new_step_api (bool): Whether the wrapper's step method outputs two booleans (new API) or one boolean (old API) """ - super().__init__(env) + super().__init__(env, new_step_api) self.num_stack = num_stack self.lz4_compress = lz4_compress @@ -163,11 +171,15 @@ def step(self, action): action: The action to step through the environment with Returns: - Stacked observations, reward, done and information from the environment + Stacked observations, reward, terminated, truncated, and information from the environment """ - observation, reward, done, info = self.env.step(action) + observation, reward, terminated, truncated, info = step_api_compatibility( + self.env.step(action), True + ) self.frames.append(observation) - return self.observation(None), reward, done, info + return step_api_compatibility( + (self.observation(), reward, terminated, truncated, info), self.new_step_api + ) def reset(self, **kwargs): """Reset the environment with kwargs. diff --git a/gym/wrappers/gray_scale_observation.py b/gym/wrappers/gray_scale_observation.py index 1c626f41f4f..cf8a2ea05c7 100644 --- a/gym/wrappers/gray_scale_observation.py +++ b/gym/wrappers/gray_scale_observation.py @@ -28,7 +28,7 @@ def __init__(self, env: gym.Env, keep_dim: bool = False): keep_dim (bool): If `True`, a singleton dimension will be added, i.e. observations are of the shape AxBx1. Otherwise, they are of shape AxB. """ - super().__init__(env) + super().__init__(env, new_step_api=True) self.keep_dim = keep_dim assert ( diff --git a/gym/wrappers/human_rendering.py b/gym/wrappers/human_rendering.py index 50bc12751d9..e2b234cf903 100644 --- a/gym/wrappers/human_rendering.py +++ b/gym/wrappers/human_rendering.py @@ -45,7 +45,7 @@ def __init__(self, env): Args: env: The environment that is being wrapped """ - super().__init__(env) + super().__init__(env, new_step_api=True) assert env.render_mode in [ "single_rgb_array", "rgb_array", diff --git a/gym/wrappers/normalize.py b/gym/wrappers/normalize.py index f026e23f79d..0c6ab04a48b 100644 --- a/gym/wrappers/normalize.py +++ b/gym/wrappers/normalize.py @@ -2,6 +2,7 @@ import numpy as np import gym +from gym.utils.step_api_compatibility import step_api_compatibility # taken from https://github.com/openai/baselines/blob/master/baselines/common/vec_env/vec_normalize.py @@ -54,14 +55,15 @@ class NormalizeObservation(gym.core.Wrapper): newly instantiated or the policy was changed recently. """ - def __init__(self, env: gym.Env, epsilon: float = 1e-8): + def __init__(self, env: gym.Env, epsilon: float = 1e-8, new_step_api: bool = False): """This wrapper will normalize observations s.t. each coordinate is centered with unit variance. Args: env (Env): The environment to apply the wrapper epsilon: A stability parameter that is used when scaling the observations. 
+ new_step_api (bool): Whether the wrapper's step method outputs two booleans (new API) or one boolean (old API) """ - super().__init__(env) + super().__init__(env, new_step_api) self.num_envs = getattr(env, "num_envs", 1) self.is_vector_env = getattr(env, "is_vector_env", False) if self.is_vector_env: @@ -72,12 +74,18 @@ def __init__(self, env: gym.Env, epsilon: float = 1e-8): def step(self, action): """Steps through the environment and normalizes the observation.""" - obs, rews, dones, infos = self.env.step(action) + obs, rews, terminateds, truncateds, infos = step_api_compatibility( + self.env.step(action), True, self.is_vector_env + ) if self.is_vector_env: obs = self.normalize(obs) else: obs = self.normalize(np.array([obs]))[0] - return obs, rews, dones, infos + return step_api_compatibility( + (obs, rews, terminateds, truncateds, infos), + self.new_step_api, + self.is_vector_env, + ) def reset(self, **kwargs): """Resets the environment and normalizes the observation.""" @@ -117,6 +125,7 @@ def __init__( env: gym.Env, gamma: float = 0.99, epsilon: float = 1e-8, + new_step_api: bool = False, ): """This wrapper will normalize immediate rewards s.t. their exponential moving average has a fixed variance. @@ -124,8 +133,9 @@ def __init__( env (env): The environment to apply the wrapper epsilon (float): A stability parameter gamma (float): The discount factor that is used in the exponential moving average. + new_step_api (bool): Whether the wrapper's step method outputs two booleans (new API) or one boolean (old API) """ - super().__init__(env) + super().__init__(env, new_step_api) self.num_envs = getattr(env, "num_envs", 1) self.is_vector_env = getattr(env, "is_vector_env", False) self.return_rms = RunningMeanStd(shape=()) @@ -135,15 +145,25 @@ def __init__( def step(self, action): """Steps through the environment, normalizing the rewards returned.""" - obs, rews, dones, infos = self.env.step(action) + obs, rews, terminateds, truncateds, infos = step_api_compatibility( + self.env.step(action), True, self.is_vector_env + ) if not self.is_vector_env: rews = np.array([rews]) self.returns = self.returns * self.gamma + rews rews = self.normalize(rews) + if not self.is_vector_env: + dones = terminateds or truncateds + else: + dones = np.bitwise_or(terminateds, truncateds) self.returns[dones] = 0.0 if not self.is_vector_env: rews = rews[0] - return obs, rews, dones, infos + return step_api_compatibility( + (obs, rews, terminateds, truncateds, infos), + self.new_step_api, + self.is_vector_env, + ) def normalize(self, rews): """Normalizes the rewards with the running mean rewards and their variance.""" diff --git a/gym/wrappers/order_enforcing.py b/gym/wrappers/order_enforcing.py index d9f853e72bc..0e9da7f878e 100644 --- a/gym/wrappers/order_enforcing.py +++ b/gym/wrappers/order_enforcing.py @@ -26,7 +26,7 @@ def __init__(self, env: gym.Env, disable_render_order_enforcing: bool = False): env: The environment to wrap disable_render_order_enforcing: If to disable render order enforcing """ - super().__init__(env) + super().__init__(env, new_step_api=True) self._has_reset: bool = False self._disable_render_order_enforcing: bool = disable_render_order_enforcing diff --git a/gym/wrappers/pixel_observation.py b/gym/wrappers/pixel_observation.py index 8e7c92e2eff..3e5e4262243 100644 --- a/gym/wrappers/pixel_observation.py +++ b/gym/wrappers/pixel_observation.py @@ -77,7 +77,7 @@ def __init__( specified ``pixel_keys``. 
TypeError: When an unexpected pixel type is used """ - super().__init__(env) + super().__init__(env, new_step_api=True) # Avoid side-effects that occur when render_kwargs is manipulated render_kwargs = copy.deepcopy(render_kwargs) diff --git a/gym/wrappers/record_episode_statistics.py b/gym/wrappers/record_episode_statistics.py index f13f9ec8ebc..26cdb98f895 100644 --- a/gym/wrappers/record_episode_statistics.py +++ b/gym/wrappers/record_episode_statistics.py @@ -6,6 +6,7 @@ import numpy as np import gym +from gym.utils.step_api_compatibility import step_api_compatibility def add_vector_episode_statistics( @@ -76,14 +77,15 @@ class RecordEpisodeStatistics(gym.Wrapper): length_queue: The lengths of the last ``deque_size``-many episodes """ - def __init__(self, env: gym.Env, deque_size: int = 100): + def __init__(self, env: gym.Env, deque_size: int = 100, new_step_api: bool = False): """This wrapper will keep track of cumulative rewards and episode lengths. Args: env (Env): The environment to apply the wrapper deque_size: The size of the buffers :attr:`return_queue` and :attr:`length_queue` + new_step_api (bool): Whether the wrapper's step method outputs two booleans (new API) or one boolean (old API) """ - super().__init__(env) + super().__init__(env, new_step_api) self.num_envs = getattr(env, "num_envs", 1) self.t0 = time.perf_counter() self.episode_count = 0 @@ -102,18 +104,26 @@ def reset(self, **kwargs): def step(self, action): """Steps through the environment, recording the episode statistics.""" - observations, rewards, dones, infos = super().step(action) + ( + observations, + rewards, + terminateds, + truncateds, + infos, + ) = step_api_compatibility(self.env.step(action), True, self.is_vector_env) assert isinstance( infos, dict ), f"`info` dtype is {type(infos)} while supported dtype is `dict`. This may be due to usage of other wrappers in the wrong order." self.episode_returns += rewards self.episode_lengths += 1 if not self.is_vector_env: - dones = [dones] - dones = list(dones) + terminateds = [terminateds] + truncateds = [truncateds] + terminateds = list(terminateds) + truncateds = list(truncateds) - for i in range(len(dones)): - if dones[i]: + for i in range(len(terminateds)): + if terminateds[i] or truncateds[i]: episode_return = self.episode_returns[i] episode_length = self.episode_lengths[i] episode_info = { @@ -134,9 +144,14 @@ def step(self, action): self.episode_count += 1 self.episode_returns[i] = 0 self.episode_lengths[i] = 0 - return ( - observations, - rewards, - dones if self.is_vector_env else dones[0], - infos, + return step_api_compatibility( + ( + observations, + rewards, + terminateds if self.is_vector_env else terminateds[0], + truncateds if self.is_vector_env else truncateds[0], + infos, + ), + self.new_step_api, + self.is_vector_env, ) diff --git a/gym/wrappers/record_video.py b/gym/wrappers/record_video.py index 5bc7b7b5efc..8736915576e 100644 --- a/gym/wrappers/record_video.py +++ b/gym/wrappers/record_video.py @@ -4,6 +4,7 @@ import gym from gym import logger +from gym.utils.step_api_compatibility import step_api_compatibility from gym.wrappers.monitoring import video_recorder @@ -32,7 +33,7 @@ class RecordVideo(gym.Wrapper): They should be functions returning a boolean that indicates whether a recording should be started at the current episode or step, respectively. If neither :attr:`episode_trigger` nor ``step_trigger`` is passed, a default ``episode_trigger`` will be employed. 
- By default, the recording will be stopped once a `done` signal has been emitted by the environment. However, you can + By default, the recording will be stopped once a `terminated` or `truncated` signal has been emitted by the environment. However, you can also create recordings of fixed length (possibly spanning several episodes) by passing a strictly positive value for ``video_length``. """ @@ -45,6 +46,7 @@ def __init__( step_trigger: Callable[[int], bool] = None, video_length: int = 0, name_prefix: str = "rl-video", + new_step_api: bool = False, ): """Wrapper records videos of rollouts. @@ -56,8 +58,9 @@ def __init__( video_length (int): The length of recorded episodes. If 0, entire episodes are recorded. Otherwise, snippets of the specified length are captured name_prefix (str): Will be prepended to the filename of the recordings + new_step_api (bool): Whether the wrapper's step method outputs two booleans (new API) or one boolean (old API) """ - super().__init__(env) + super().__init__(env, new_step_api) if episode_trigger is None and step_trigger is None: episode_trigger = capped_cubic_video_schedule @@ -83,7 +86,8 @@ def __init__( self.video_length = video_length self.recording = False - self.done = False + self.terminated = False + self.truncated = False self.recorded_frames = 0 self.is_vector_env = getattr(env, "is_vector_env", False) self.episode_id = 0 @@ -91,7 +95,8 @@ def __init__( def reset(self, **kwargs): """Reset the environment using kwargs and then starts recording if video enabled.""" observations = super().reset(**kwargs) - self.done = False + self.terminated = False + self.truncated = False if self.recording: assert self.video_recorder is not None self.video_recorder.frames = [] @@ -132,18 +137,26 @@ def _video_enabled(self): def step(self, action): """Steps through the environment using action, recording observations if :attr:`self.recording`.""" - observations, rewards, dones, infos = super().step(action) - - if not self.done: + ( + observations, + rewards, + terminateds, + truncateds, + infos, + ) = step_api_compatibility(self.env.step(action), True, self.is_vector_env) + + if not (self.terminated or self.truncated): # increment steps and episodes self.step_id += 1 if not self.is_vector_env: - if dones: + if terminateds or truncateds: self.episode_id += 1 - self.done = True - elif dones[0]: + self.terminated = terminateds + self.truncated = truncateds + elif terminateds[0] or truncateds[0]: self.episode_id += 1 - self.done = True + self.terminated = terminateds[0] + self.truncated = truncateds[0] if self.recording: assert self.video_recorder is not None @@ -154,15 +167,19 @@ def step(self, action): self.close_video_recorder() else: if not self.is_vector_env: - if dones: + if terminateds or truncateds: self.close_video_recorder() - elif dones[0]: + elif terminateds[0] or truncateds[0]: self.close_video_recorder() elif self._video_enabled(): self.start_video_recorder() - return observations, rewards, dones, infos + return step_api_compatibility( + (observations, rewards, terminateds, truncateds, infos), + self.new_step_api, + self.is_vector_env, + ) def close_video_recorder(self): """Closes the video recorder if currently recording.""" diff --git a/gym/wrappers/rescale_action.py b/gym/wrappers/rescale_action.py index bf3cf6cd157..c5f2238159e 100644 --- a/gym/wrappers/rescale_action.py +++ b/gym/wrappers/rescale_action.py @@ -45,7 +45,7 @@ def __init__( ), f"expected Box action space, got {type(env.action_space)}" assert np.less_equal(min_action, max_action).all(), 
(min_action, max_action) - super().__init__(env) + super().__init__(env, new_step_api=True) self.min_action = ( np.zeros(env.action_space.shape, dtype=env.action_space.dtype) + min_action ) diff --git a/gym/wrappers/resize_observation.py b/gym/wrappers/resize_observation.py index 4f486a97bdf..29116be7f09 100644 --- a/gym/wrappers/resize_observation.py +++ b/gym/wrappers/resize_observation.py @@ -32,7 +32,7 @@ def __init__(self, env: gym.Env, shape: Union[tuple, int]): env: The environment to apply the wrapper shape: The shape of the resized observations """ - super().__init__(env) + super().__init__(env, new_step_api=True) if isinstance(shape, int): shape = (shape, shape) assert all(x > 0 for x in shape), shape diff --git a/gym/wrappers/step_api_compatibility.py b/gym/wrappers/step_api_compatibility.py new file mode 100644 index 00000000000..6c081b67be9 --- /dev/null +++ b/gym/wrappers/step_api_compatibility.py @@ -0,0 +1,57 @@ +"""Implementation of StepAPICompatibility wrapper class for transforming envs between new and old step API.""" +import gym +from gym.logger import deprecation +from gym.utils.step_api_compatibility import step_to_new_api, step_to_old_api + + +class StepAPICompatibility(gym.Wrapper): + r"""A wrapper which can transform an environment from new step API to old and vice-versa. + + Old step API refers to step() method returning (observation, reward, done, info) + New step API refers to step() method returning (observation, reward, terminated, truncated, info) + (Refer to docs for details on the API change) + + This wrapper is to be used to ease transition to new API and for backward compatibility. + + Args: + env (gym.Env): the env to wrap. Can be in old or new API + new_step_api (bool): True to use env with new step API, False to use env with old step API. (False by default) + + Examples: + >>> env = gym.make("CartPole-v1") + >>> env # wrapper applied by default, set to old API + >>>> + >>> env = gym.make("CartPole-v1", new_step_api=True) # set to new API + >>> env = StepAPICompatibility(CustomEnv(), new_step_api=True) # manually using wrapper on unregistered envs + + """ + + def __init__(self, env: gym.Env, new_step_api=False): + """A wrapper which can transform an environment from new step API to old and vice-versa. + + Args: + env (gym.Env): the env to wrap. Can be in old or new API + new_step_api (bool): Whether the wrapper's step method outputs two booleans (new API) or one boolean (old API) + """ + super().__init__(env, new_step_api) + self.new_step_api = new_step_api + if not self.new_step_api: + deprecation( + "Initializing environment in old step API which returns one bool instead of two. " + "It is recommended to set `new_step_api=True` to use new step API. This will be the default behaviour in future. " + ) + + def step(self, action): + """Steps through the environment, returning 5 or 4 items depending on `new_step_api`. 
+ + Args: + action: action to step through the environment with + + Returns: + (observation, reward, terminated, truncated, info) or (observation, reward, done, info) + """ + step_returns = self.env.step(action) + if self.new_step_api: + return step_to_new_api(step_returns) + else: + return step_to_old_api(step_returns) diff --git a/gym/wrappers/time_aware_observation.py b/gym/wrappers/time_aware_observation.py index d2f655b6e92..2307eb06334 100644 --- a/gym/wrappers/time_aware_observation.py +++ b/gym/wrappers/time_aware_observation.py @@ -3,6 +3,7 @@ import gym from gym.spaces import Box +from gym.utils.step_api_compatibility import step_api_compatibility class TimeAwareObservation(gym.ObservationWrapper): @@ -21,18 +22,20 @@ class TimeAwareObservation(gym.ObservationWrapper): array([ 0.03881167, -0.16021058, 0.0220928 , 0.28875574, 1. ]) """ - def __init__(self, env: gym.Env): + def __init__(self, env: gym.Env, new_step_api: bool = False): """Initialize :class:`TimeAwareObservation` that requires an environment with a flat :class:`Box` observation space. Args: env: The environment to apply the wrapper + new_step_api (bool): Whether the wrapper's step method outputs two booleans (new API) or one boolean (old API) """ - super().__init__(env) + super().__init__(env, new_step_api) assert isinstance(env.observation_space, Box) assert env.observation_space.dtype == np.float32 low = np.append(self.observation_space.low, 0.0) high = np.append(self.observation_space.high, np.inf) self.observation_space = Box(low, high, dtype=np.float32) + self.is_vector_env = getattr(env, "is_vector_env", False) def observation(self, observation): """Adds to the observation with the current time step. @@ -55,7 +58,9 @@ def step(self, action): The environment's step using the action. """ self.t += 1 - return super().step(action) + return step_api_compatibility( + super().step(action), self.new_step_api, self.is_vector_env + ) def reset(self, **kwargs): """Reset the environment setting the time to zero. diff --git a/gym/wrappers/time_limit.py b/gym/wrappers/time_limit.py index 1776db4df97..8e9f67f4ae9 100644 --- a/gym/wrappers/time_limit.py +++ b/gym/wrappers/time_limit.py @@ -2,16 +2,20 @@ from typing import Optional import gym +from gym.utils.step_api_compatibility import step_api_compatibility class TimeLimit(gym.Wrapper): - """This wrapper will issue a `done` signal if a maximum number of timesteps is exceeded. + """This wrapper will issue a `truncated` signal if a maximum number of timesteps is exceeded. - Oftentimes, it is **very** important to distinguish `done` signals that were produced by the - :class:`TimeLimit` wrapper (truncations) and those that originate from the underlying environment (terminations). - This can be done by looking at the ``info`` that is returned when `done`-signal was issued. + If a truncation is not defined inside the environment itself, this is the only place that the truncation signal is issued. + Critically, this is different from the `terminated` signal that originates from the underlying environment as part of the MDP. + + (deprecated) + This information is passed through ``info`` that is returned when `done`-signal was issued. The done-signal originates from the time limit (i.e. it signifies a *truncation*) if and only if - the key `"TimeLimit.truncated"` exists in ``info`` and the corresponding value is ``True``. + the key `"TimeLimit.truncated"` exists in ``info`` and the corresponding value is ``True``. 
This will be removed in favour + of only issuing a `truncated` signal in future versions. Example: >>> from gym.envs.classic_control import CartPoleEnv @@ -20,14 +24,20 @@ class TimeLimit(gym.Wrapper): >>> env = TimeLimit(env, max_episode_steps=1000) """ - def __init__(self, env: gym.Env, max_episode_steps: Optional[int] = None): + def __init__( + self, + env: gym.Env, + max_episode_steps: Optional[int] = None, + new_step_api: bool = False, + ): """Initializes the :class:`TimeLimit` wrapper with an environment and the number of steps after which truncation will occur. Args: env: The environment to apply the wrapper max_episode_steps: An optional max episode steps (if ``Ǹone``, ``env.spec.max_episode_steps`` is used) + new_step_api (bool): Whether the wrapper's step method outputs two booleans (new API) or one boolean (old API) """ - super().__init__(env) + super().__init__(env, new_step_api) if max_episode_steps is None and self.env.spec is not None: max_episode_steps = env.spec.max_episode_steps if self.env.spec is not None: @@ -46,15 +56,19 @@ def step(self, action): when truncated (the number of steps elapsed >= max episode steps) or "TimeLimit.truncated"=False if the environment terminated """ - observation, reward, done, info = self.env.step(action) + observation, reward, terminated, truncated, info = step_api_compatibility( + self.env.step(action), + True, + ) self._elapsed_steps += 1 + if self._elapsed_steps >= self._max_episode_steps: - # TimeLimit.truncated key may have been already set by the environment - # do not overwrite it - episode_truncated = not done or info.get("TimeLimit.truncated", False) - info["TimeLimit.truncated"] = episode_truncated - done = True - return observation, reward, done, info + truncated = True + + return step_api_compatibility( + (observation, reward, terminated, truncated, info), + self.new_step_api, + ) def reset(self, **kwargs): """Resets the environment with :param:`**kwargs` and sets the number of steps elapsed to zero. 
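Aside (editor's note, not part of the patch): a minimal sketch of how the terminated/truncated split shown above is consumed, assuming the `new_step_api` flag and the `gym.utils.step_api_compatibility` helper introduced in this diff; the `CartPole-v1` id and the random-action loop are purely illustrative.

import gym
from gym.utils.step_api_compatibility import step_api_compatibility

# Environment created in the new step API: step() returns five items.
env = gym.make("CartPole-v1", new_step_api=True)
obs = env.reset()
terminated = truncated = False
while not (terminated or truncated):
    # `terminated` ends the episode via the MDP; `truncated` comes from
    # outside the MDP, e.g. the TimeLimit wrapper hitting max_episode_steps.
    obs, reward, terminated, truncated, info = env.step(env.action_space.sample())

# Old-API consumers can collapse the two flags back into a single `done`:
obs, reward, done, info = step_api_compatibility(
    (obs, reward, terminated, truncated, info), new_step_api=False
)

The same helper handles the reverse direction: passing an old 4-tuple with `new_step_api=True` recovers `terminated`/`truncated`, inferring truncation from the deprecated `info["TimeLimit.truncated"]` key when it is present.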
diff --git a/gym/wrappers/transform_observation.py b/gym/wrappers/transform_observation.py index 2af2e9afb40..4da9db5bac9 100644 --- a/gym/wrappers/transform_observation.py +++ b/gym/wrappers/transform_observation.py @@ -27,7 +27,7 @@ def __init__(self, env: gym.Env, f: Callable[[Any], Any]): env: The environment to apply the wrapper f: A function that transforms the observation """ - super().__init__(env) + super().__init__(env, new_step_api=True) assert callable(f) self.f = f diff --git a/gym/wrappers/transform_reward.py b/gym/wrappers/transform_reward.py index 4bb6e18cf95..13278182d6b 100644 --- a/gym/wrappers/transform_reward.py +++ b/gym/wrappers/transform_reward.py @@ -16,7 +16,7 @@ class TransformReward(RewardWrapper): >>> env = gym.make('CartPole-v1') >>> env = TransformReward(env, lambda r: 0.01*r) >>> env.reset() - >>> observation, reward, done, info = env.step(env.action_space.sample()) + >>> observation, reward, terminated, truncated, info = env.step(env.action_space.sample()) >>> reward 0.01 """ @@ -28,7 +28,7 @@ def __init__(self, env: gym.Env, f: Callable[[float], float]): env: The environment to apply the wrapper f: A function that transforms the reward """ - super().__init__(env) + super().__init__(env, new_step_api=True) assert callable(f) self.f = f diff --git a/gym/wrappers/vector_list_info.py b/gym/wrappers/vector_list_info.py index bf5cbc82933..874367f2811 100644 --- a/gym/wrappers/vector_list_info.py +++ b/gym/wrappers/vector_list_info.py @@ -3,6 +3,7 @@ from typing import List import gym +from gym.utils.step_api_compatibility import step_api_compatibility class VectorListInfo(gym.Wrapper): @@ -29,23 +30,30 @@ class VectorListInfo(gym.Wrapper): """ - def __init__(self, env): + def __init__(self, env, new_step_api=False): """This wrapper will convert the info into the list format. Args: env (Env): The environment to apply the wrapper + new_step_api (bool): Whether the wrapper's step method outputs two booleans (new API) or one boolean (old API) """ assert getattr( env, "is_vector_env", False ), "This wrapper can only be used in vectorized environments." 
- super().__init__(env) + super().__init__(env, new_step_api) def step(self, action): """Steps through the environment, convert dict info to list.""" - observation, reward, done, infos = self.env.step(action) + observation, reward, terminated, truncated, infos = step_api_compatibility( + self.env.step(action), True, True + ) list_info = self._convert_info_to_list(infos) - return observation, reward, done, list_info + return step_api_compatibility( + (observation, reward, terminated, truncated, list_info), + self.new_step_api, + True, + ) def reset(self, **kwargs): """Resets the environment using kwargs.""" diff --git a/pyproject.toml b/pyproject.toml index d88b4f9aacd..efed648d2fd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -34,5 +34,8 @@ reportPrivateUsage = "warning" reportUntypedFunctionDecorator = "none" reportMissingTypeStubs = false reportUnboundVariable = "warning" -reportGeneralTypeIssues = "none" -reportInvalidTypeVarUse = "none" \ No newline at end of file +reportGeneralTypeIssues ="none" +reportInvalidTypeVarUse = "none" + +[tool.pytest.ini_options] +filterwarnings = ['ignore:.*step API.*:DeprecationWarning'] # TODO: to be removed when old step API is removed diff --git a/tests/envs/test_action_dim_check.py b/tests/envs/test_action_dim_check.py index 61b3ca2eb0d..af857567795 100644 --- a/tests/envs/test_action_dim_check.py +++ b/tests/envs/test_action_dim_check.py @@ -112,7 +112,9 @@ def test_box_actions_out_of_bound(env: gym.Env): zip(env.action_space.bounded_above, env.action_space.bounded_below) ): if is_upper_bound: - obs, _, _, _ = env.step(upper_bounds) + obs, _, _, _, _ = env.step( + upper_bounds + ) # `env` is unwrapped, and in new step API oob_action = upper_bounds.copy() oob_action[i] += np.cast[dtype](OOB_VALUE) @@ -122,7 +124,9 @@ def test_box_actions_out_of_bound(env: gym.Env): assert np.alltrue(obs == oob_obs) if is_lower_bound: - obs, _, _, _ = env.step(lower_bounds) + obs, _, _, _, _ = env.step( + lower_bounds + ) # `env` is unwrapped, and in new step API oob_action = lower_bounds.copy() oob_action[i] -= np.cast[dtype](OOB_VALUE) diff --git a/tests/envs/test_env_implementation.py b/tests/envs/test_env_implementation.py index 0a44fe11c89..d1988ef5b7c 100644 --- a/tests/envs/test_env_implementation.py +++ b/tests/envs/test_env_implementation.py @@ -144,7 +144,7 @@ def test_taxi_encode_decode(): assert ( env.encode(*env.decode(state)) == state ), f"state={state}, encode(decode(state))={env.encode(*env.decode(state))}" - state, _, _, _ = env.step(env.action_space.sample()) + state, _, _, _, _ = env.step(env.action_space.sample()) @pytest.mark.parametrize( diff --git a/tests/envs/test_make.py b/tests/envs/test_make.py index 4d45fe276d0..b58636641d5 100644 --- a/tests/envs/test_make.py +++ b/tests/envs/test_make.py @@ -172,7 +172,9 @@ def test_make_render_mode(): assert env.render_mode == valid_render_modes[0] env.close() - assert len(warnings) == 0 + for warning in warnings.list: + if not re.compile(".*step API.*").match(warning.message.args[0]): + raise gym.error.Error(f"Unexpected warning: {warning.message}") # Make sure that native rendering is used when possible env = gym.make("CartPole-v1", render_mode="human", disable_env_checker=True) diff --git a/tests/utils/test_terminated_truncated.py b/tests/utils/test_terminated_truncated.py new file mode 100644 index 00000000000..e74fdc85378 --- /dev/null +++ b/tests/utils/test_terminated_truncated.py @@ -0,0 +1,91 @@ +import pytest + +import gym +from gym.spaces import Discrete +from gym.vector import 
AsyncVectorEnv, SyncVectorEnv +from gym.wrappers import TimeLimit + + +# An environment where termination happens after 20 steps +class DummyEnv(gym.Env): + def __init__(self): + self.action_space = Discrete(2) + self.observation_space = Discrete(2) + self.terminal_timestep = 20 + + self.timestep = 0 + + def step(self, action): + self.timestep += 1 + terminated = True if self.timestep >= self.terminal_timestep else False + truncated = False + + return 0, 0, terminated, truncated, {} + + def reset(self): + self.timestep = 0 + return 0 + + +@pytest.mark.parametrize("time_limit", [10, 20, 30]) +def test_terminated_truncated(time_limit): + test_env = TimeLimit(DummyEnv(), time_limit, new_step_api=True) + + terminated = False + truncated = False + test_env.reset() + while not (terminated or truncated): + _, _, terminated, truncated, _ = test_env.step(0) + + if test_env.terminal_timestep < time_limit: + assert terminated + assert not truncated + elif test_env.terminal_timestep == time_limit: + assert ( + terminated + ), "`terminated` should be True even when termination and truncation happen at the same step" + assert ( + truncated + ), "`truncated` should be True even when termination and truncation occur at same step " + else: + assert not terminated + assert truncated + + +def test_terminated_truncated_vector(): + env0 = TimeLimit(DummyEnv(), 10, new_step_api=True) + env1 = TimeLimit(DummyEnv(), 20, new_step_api=True) + env2 = TimeLimit(DummyEnv(), 30, new_step_api=True) + + async_env = AsyncVectorEnv( + [lambda: env0, lambda: env1, lambda: env2], new_step_api=True + ) + async_env.reset() + terminateds = [False, False, False] + truncateds = [False, False, False] + counter = 0 + while not all([x or y for x, y in zip(terminateds, truncateds)]): + counter += 1 + _, _, terminateds, truncateds, _ = async_env.step( + async_env.action_space.sample() + ) + print(counter) + assert counter == 20 + assert all(terminateds == [False, True, True]) + assert all(truncateds == [True, True, False]) + + sync_env = SyncVectorEnv( + [lambda: env0, lambda: env1, lambda: env2], new_step_api=True + ) + sync_env.reset() + terminateds = [False, False, False] + truncateds = [False, False, False] + counter = 0 + while not all([x or y for x, y in zip(terminateds, truncateds)]): + counter += 1 + _, _, terminateds, truncateds, _ = sync_env.step( + async_env.action_space.sample() + ) + assert counter == 20 + assert all(terminateds == [False, True, True]) + assert all(truncateds == [True, True, False]) diff --git a/tests/vector/test_step_compatibility_vector.py b/tests/vector/test_step_compatibility_vector.py new file mode 100644 index 00000000000..d0305300fc7 --- /dev/null +++ b/tests/vector/test_step_compatibility_vector.py @@ -0,0 +1,88 @@ +import numpy as np +import pytest + +import gym +from gym.spaces import Discrete +from gym.vector import AsyncVectorEnv, SyncVectorEnv + + +class OldStepEnv(gym.Env): + def __init__(self): + self.action_space = Discrete(2) + self.observation_space = Discrete(2) + + def reset(self): + return 0 + + def step(self, action): + obs = self.observation_space.sample() + rew = 0 + done = False + info = {} + return obs, rew, done, info + + +class NewStepEnv(gym.Env): + def __init__(self): + self.action_space = Discrete(2) + self.observation_space = Discrete(2) + + def reset(self): + return 0 + + def step(self, action): + obs = self.observation_space.sample() + rew = 0 + terminated = False + truncated = False + info = {} + return obs, rew, terminated, truncated, info + + 
+@pytest.mark.parametrize("VecEnv", [AsyncVectorEnv, SyncVectorEnv])
+def test_vector_step_compatibility_new_env(VecEnv):
+
+    envs = [
+        OldStepEnv(),
+        NewStepEnv(),
+    ]
+
+    vec_env = VecEnv([lambda env=env: env for env in envs])  # bind env per lambda
+    vec_env.reset()
+    step_returns = vec_env.step([0, 0])
+    assert len(step_returns) == 4
+    _, _, dones, _ = step_returns
+    assert dones.dtype == np.bool_
+    vec_env.close()
+
+    vec_env = VecEnv([lambda env=env: env for env in envs], new_step_api=True)
+    vec_env.reset()
+    step_returns = vec_env.step([0, 0])
+    assert len(step_returns) == 5
+    _, _, terminateds, truncateds, _ = step_returns
+    assert terminateds.dtype == np.bool_
+    assert truncateds.dtype == np.bool_
+    vec_env.close()
+
+
+@pytest.mark.parametrize("async_bool", [True, False])
+def test_vector_step_compatibility_existing(async_bool):
+
+    env = gym.vector.make("CartPole-v1", num_envs=3, asynchronous=async_bool)
+    env.reset()
+    step_returns = env.step(env.action_space.sample())
+    assert len(step_returns) == 4
+    _, _, dones, _ = step_returns
+    assert dones.dtype == np.bool_
+    env.close()
+
+    env = gym.vector.make(
+        "CartPole-v1", num_envs=3, asynchronous=async_bool, new_step_api=True
+    )
+    env.reset()
+    step_returns = env.step(env.action_space.sample())
+    assert len(step_returns) == 5
+    _, _, terminateds, truncateds, _ = step_returns
+    assert terminateds.dtype == np.bool_
+    assert truncateds.dtype == np.bool_
+    env.close()
diff --git a/tests/vector/test_vector_env.py b/tests/vector/test_vector_env.py
index 8317ad6fcef..d74e646bedc 100644
--- a/tests/vector/test_vector_env.py
+++ b/tests/vector/test_vector_env.py
@@ -36,10 +36,10 @@ def test_vector_env_equal(shared_memory):
             # fmt: on
 
             if any(sync_dones):
-                assert "terminal_observation" in async_infos
-                assert "_terminal_observation" in async_infos
-                assert "terminal_observation" in sync_infos
-                assert "_terminal_observation" in sync_infos
+                assert "final_observation" in async_infos
+                assert "_final_observation" in async_infos
+                assert "final_observation" in sync_infos
+                assert "_final_observation" in sync_infos
 
             assert np.all(async_observations == sync_observations)
             assert np.all(async_rewards == sync_rewards)
diff --git a/tests/vector/test_vector_env_info.py b/tests/vector/test_vector_env_info.py
index a7ef4e8da0e..33849bdeca3 100644
--- a/tests/vector/test_vector_env_info.py
+++ b/tests/vector/test_vector_env_info.py
@@ -22,18 +22,18 @@ def test_vector_env_info(asynchronous):
         action = env.action_space.sample()
         _, _, dones, infos = env.step(action)
         if any(dones):
-            assert len(infos["terminal_observation"]) == NUM_ENVS
-            assert len(infos["_terminal_observation"]) == NUM_ENVS
+            assert len(infos["final_observation"]) == NUM_ENVS
+            assert len(infos["_final_observation"]) == NUM_ENVS
 
-            assert isinstance(infos["terminal_observation"], np.ndarray)
-            assert isinstance(infos["_terminal_observation"], np.ndarray)
+            assert isinstance(infos["final_observation"], np.ndarray)
+            assert isinstance(infos["_final_observation"], np.ndarray)
 
             for i, done in enumerate(dones):
                 if done:
-                    assert infos["_terminal_observation"][i]
+                    assert infos["_final_observation"][i]
                 else:
-                    assert not infos["_terminal_observation"][i]
-                    assert infos["terminal_observation"][i] is None
+                    assert not infos["_final_observation"][i]
+                    assert infos["final_observation"][i] is None
 
 
 @pytest.mark.parametrize("concurrent_ends", [1, 2, 3])
@@ -49,8 +49,8 @@ def test_vector_env_info_concurrent_termination(concurrent_ends):
     for i, done in enumerate(dones):
         if i < concurrent_ends:
             assert done
-            assert infos["_terminal_observation"][i]
+            assert infos["_final_observation"][i]
         else:
-            assert not infos["_terminal_observation"][i]
-            assert infos["terminal_observation"][i] is None
+            assert not infos["_final_observation"][i]
+            assert infos["final_observation"][i] is None
     return
diff --git a/tests/wrappers/test_autoreset.py b/tests/wrappers/test_autoreset.py
index bb5a4b709fd..e4ed3f9b593 100644
--- a/tests/wrappers/test_autoreset.py
+++ b/tests/wrappers/test_autoreset.py
@@ -136,8 +136,8 @@ def test_autoreset_wrapper_autoreset():
     assert reward == 1
     assert info == {
         "count": 0,
-        "terminal_observation": np.array([3]),
-        "terminal_info": {"count": 3},
+        "final_observation": np.array([3]),
+        "final_info": {"count": 3},
     }
 
     obs, reward, done, info = env.step(action)
diff --git a/tests/wrappers/test_step_compatibility.py b/tests/wrappers/test_step_compatibility.py
new file mode 100644
index 00000000000..83557f02db6
--- /dev/null
+++ b/tests/wrappers/test_step_compatibility.py
@@ -0,0 +1,77 @@
+import pytest
+
+import gym
+from gym.spaces import Discrete
+from gym.wrappers import StepAPICompatibility
+
+
+class OldStepEnv(gym.Env):
+    def __init__(self):
+        self.action_space = Discrete(2)
+        self.observation_space = Discrete(2)
+
+    def step(self, action):
+        obs = self.observation_space.sample()
+        rew = 0
+        done = False
+        info = {}
+        return obs, rew, done, info
+
+
+class NewStepEnv(gym.Env):
+    def __init__(self):
+        self.action_space = Discrete(2)
+        self.observation_space = Discrete(2)
+
+    def step(self, action):
+        obs = self.observation_space.sample()
+        rew = 0
+        terminated = False
+        truncated = False
+        info = {}
+        return obs, rew, terminated, truncated, info
+
+
+@pytest.mark.parametrize("env", [OldStepEnv, NewStepEnv])
+def test_step_compatibility_to_new_api(env):
+    env = StepAPICompatibility(env(), True)
+    step_returns = env.step(0)
+    _, _, terminated, truncated, _ = step_returns
+    assert isinstance(terminated, bool)
+    assert isinstance(truncated, bool)
+
+
+@pytest.mark.parametrize("env", [OldStepEnv, NewStepEnv])
+@pytest.mark.parametrize("new_step_api", [None, False])
+def test_step_compatibility_to_old_api(env, new_step_api):
+    if new_step_api is None:
+        env = StepAPICompatibility(env())  # default behavior is to retain old API
+    else:
+        env = StepAPICompatibility(env(), new_step_api)
+    step_returns = env.step(0)
+    assert len(step_returns) == 4
+    _, _, done, _ = step_returns
+    assert isinstance(done, bool)
+
+
+@pytest.mark.parametrize("new_step_api", [None, True, False])
+def test_step_compatibility_in_make(new_step_api):
+    if new_step_api is None:
+        with pytest.warns(
+            DeprecationWarning, match="Initializing environment in old step API"
+        ):
+            env = gym.make("CartPole-v1")
+    else:
+        env = gym.make("CartPole-v1", new_step_api=new_step_api)
+
+    env.reset()
+    step_returns = env.step(0)
+    if new_step_api:
+        assert len(step_returns) == 5
+        _, _, terminated, truncated, _ = step_returns
+        assert isinstance(terminated, bool)
+        assert isinstance(truncated, bool)
+    else:
+        assert len(step_returns) == 4
+        _, _, done, _ = step_returns
+        assert isinstance(done, bool)
diff --git a/tests/wrappers/test_vector_list_info.py b/tests/wrappers/test_vector_list_info.py
index c14eee9cfdd..26c6e772876 100644
--- a/tests/wrappers/test_vector_list_info.py
+++ b/tests/wrappers/test_vector_list_info.py
@@ -32,9 +32,9 @@ def test_info_to_list():
         _, _, dones, list_info = wrapped_env.step(action)
         for i, done in enumerate(dones):
             if done:
-                assert "terminal_observation" in list_info[i]
+                assert "final_observation" in list_info[i]
             else:
-                assert "terminal_observation" not in list_info[i]
+                assert "final_observation" not in list_info[i]
 
 
 def test_info_to_list_statistics():
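Note on the conversion these tests exercise: gym's `step_api_compatibility` utility and the `StepAPICompatibility` wrapper translate between the deprecated `(obs, reward, done, info)` return and the new `(obs, reward, terminated, truncated, info)` return. The sketch below is only a rough mental model of that translation, not the code in `gym/utils/step_api_compatibility.py`; the helper names `to_new_step_api`/`to_old_step_api` and the reliance on the `"TimeLimit.truncated"` info key are assumptions made for illustration.

# Illustrative sketch only (assumed helpers, not gym's real implementation).
def to_new_step_api(step_returns):
    """Map an old-API step return onto the new five-tuple form."""
    if len(step_returns) == 5:  # already in the new API
        return step_returns
    obs, reward, done, info = step_returns
    # Assumption: old-API time-limit truncation is flagged via the info dict.
    truncated = bool(info.get("TimeLimit.truncated", False))
    terminated = done and not truncated
    return obs, reward, terminated, truncated, info


def to_old_step_api(step_returns):
    """Map a new-API step return back onto the deprecated four-tuple form."""
    if len(step_returns) == 4:  # already in the old API
        return step_returns
    obs, reward, terminated, truncated, info = step_returns
    done = terminated or truncated
    if truncated:
        info = {**info, "TimeLimit.truncated": not terminated}
    return obs, reward, done, info

For example, a training loop written against the old API could wrap each call as `obs, reward, terminated, truncated, info = to_new_step_api(env.step(action))` and treat `terminated or truncated` as the old `done`.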