diff --git a/gymnasium/core.py b/gymnasium/core.py index f2bc4063ff..14a27bfce8 100644 --- a/gymnasium/core.py +++ b/gymnasium/core.py @@ -201,7 +201,7 @@ def np_random(self) -> np.random.Generator: Instances of `np.random.Generator` """ if self._np_random is None: - self._np_random, seed = seeding.np_random() + self._np_random, _ = seeding.np_random() return self._np_random @np_random.setter @@ -234,7 +234,10 @@ def __exit__(self, *args: Any): WrapperActType = TypeVar("WrapperActType") -class Wrapper(Env[WrapperObsType, WrapperActType]): +class Wrapper( + Env[WrapperObsType, WrapperActType], + Generic[WrapperObsType, WrapperActType, ObsType, ActType], +): """Wraps a :class:`gymnasium.Env` to allow a modular transformation of the :meth:`step` and :meth:`reset` methods. This class is the base class of all wrappers to change the behavior of the underlying environment. @@ -391,7 +394,7 @@ def unwrapped(self) -> Env[ObsType, ActType]: return self.env.unwrapped -class ObservationWrapper(Wrapper[WrapperObsType, ActType]): +class ObservationWrapper(Wrapper[WrapperObsType, ActType, ObsType, ActType]): """Superclass of wrappers that can modify observations using :meth:`observation` for :meth:`reset` and :meth:`step`. If you would like to apply a function to only the observation before @@ -434,7 +437,7 @@ def observation(self, observation: ObsType) -> WrapperObsType: raise NotImplementedError -class RewardWrapper(Wrapper[ObsType, ActType]): +class RewardWrapper(Wrapper[ObsType, ActType, ObsType, ActType]): """Superclass of wrappers that can modify the returning reward from a step. If you would like to apply a function to the reward that is returned by the base environment before @@ -467,7 +470,7 @@ def reward(self, reward: SupportsFloat) -> SupportsFloat: raise NotImplementedError -class ActionWrapper(Wrapper[ObsType, WrapperActType]): +class ActionWrapper(Wrapper[ObsType, WrapperActType, ObsType, ActType]): """Superclass of wrappers that can modify the action before :meth:`env.step`. If you would like to apply a function to the action before passing it to the base environment, diff --git a/gymnasium/experimental/wrappers/common.py b/gymnasium/experimental/wrappers/common.py index 223cfd1026..efac3f3aa5 100644 --- a/gymnasium/experimental/wrappers/common.py +++ b/gymnasium/experimental/wrappers/common.py @@ -15,7 +15,7 @@ import gymnasium as gym from gymnasium import Env -from gymnasium.core import ActType, ObsType, RenderFrame, WrapperActType, WrapperObsType +from gymnasium.core import ActType, ObsType, RenderFrame from gymnasium.error import ResetNeeded from gymnasium.utils.passive_env_checker import ( check_action_space, @@ -26,10 +26,10 @@ ) -class AutoresetV0(gym.Wrapper): +class AutoresetV0(gym.Wrapper[ObsType, ActType, ObsType, ActType]): """A class for providing an automatic reset functionality for gymnasium environments when calling :meth:`self.step`.""" - def __init__(self, env: gym.Env): + def __init__(self, env: gym.Env[ObsType, ActType]): """A class for providing an automatic reset functionality for gymnasium environments when calling :meth:`self.step`. 
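Reviewer note: a minimal sketch of how downstream code is expected to parameterize the new four-argument `Wrapper` generic from the hunk above. The `ArrayObs` class and its `Dict`/`int` parameters are hypothetical, not part of this diff, and the subscript only works once the `Generic[...]` base lands.

```python
from typing import Any, Dict

import numpy as np

import gymnasium as gym


# Wrapper[WrapperObsType, WrapperActType, ObsType, ActType]:
# the first pair types what the wrapper exposes to its caller,
# the second pair types the environment being wrapped.
class ArrayObs(gym.Wrapper[np.ndarray, int, Dict[str, Any], int]):
    """Hypothetical wrapper exposing a dict-observation env as arrays."""

    def __init__(self, env: gym.Env[Dict[str, Any], int]):
        super().__init__(env)
```

With this convention, wrappers that do not change a type simply repeat it, which is why the hunks below pin `ObservationWrapper` to `Wrapper[WrapperObsType, ActType, ObsType, ActType]` and `RewardWrapper` to `Wrapper[ObsType, ActType, ObsType, ActType]`.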
Args: @@ -40,8 +40,8 @@ def __init__(self, env: gym.Env): self._reset_options: dict[str, Any] | None = None def step( - self, action: WrapperActType - ) -> tuple[WrapperObsType, SupportsFloat, bool, bool, dict]: + self, action: ActType + ) -> tuple[ObsType, SupportsFloat, bool, bool, dict[str, Any]]: """Steps through the environment with action and resets the environment if a terminated or truncated signal is encountered in the previous step. Args: @@ -51,7 +51,7 @@ def step( The autoreset environment :meth:`step` """ if self._episode_ended: - obs, info = super().reset(options=self._reset_options) + obs, info = self.env.reset(options=self._reset_options) self._episode_ended = True return obs, 0, False, False, info else: @@ -61,14 +61,14 @@ def step( def reset( self, *, seed: int | None = None, options: dict[str, Any] | None = None - ) -> tuple[WrapperObsType, dict[str, Any]]: + ) -> tuple[ObsType, dict[str, Any]]: """Resets the environment, saving the options used.""" self._episode_ended = False self._reset_options = options return super().reset(seed=seed, options=self._reset_options) -class PassiveEnvCheckerV0(gym.Wrapper): +class PassiveEnvCheckerV0(gym.Wrapper[ObsType, ActType, ObsType, ActType]): """A passive environment checker wrapper that surrounds the step, reset and render functions to check they follow the gymnasium API.""" def __init__(self, env: Env[ObsType, ActType]): @@ -89,8 +89,8 @@ def __init__(self, env: Env[ObsType, ActType]): self._checked_render: bool = False def step( - self, action: WrapperActType - ) -> tuple[WrapperObsType, SupportsFloat, bool, bool, dict[str, Any]]: + self, action: ActType + ) -> tuple[ObsType, SupportsFloat, bool, bool, dict[str, Any]]: """Steps through the environment that on the first call will run the `passive_env_step_check`.""" if self._checked_step is False: self._checked_step = True @@ -100,7 +100,7 @@ def step( def reset( self, *, seed: int | None = None, options: dict[str, Any] | None = None - ) -> tuple[WrapperObsType, dict[str, Any]]: + ) -> tuple[ObsType, dict[str, Any]]: """Resets the environment that on the first call will run the `passive_env_reset_check`.""" if self._checked_reset is False: self._checked_reset = True @@ -117,7 +117,7 @@ def render(self) -> RenderFrame | list[RenderFrame] | None: return self.env.render() -class OrderEnforcingV0(gym.Wrapper): +class OrderEnforcingV0(gym.Wrapper[ObsType, ActType, ObsType, ActType]): """A wrapper that will produce an error if :meth:`step` is called before an initial :meth:`reset`. Example: @@ -139,7 +139,11 @@ class OrderEnforcingV0(gym.Wrapper): >>> env.close() """ - def __init__(self, env: gym.Env, disable_render_order_enforcing: bool = False): + def __init__( + self, + env: gym.Env[ObsType, ActType], + disable_render_order_enforcing: bool = False, + ): """A wrapper that will produce an error if :meth:`step` is called before an initial :meth:`reset`. 
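A usage sketch for `AutoresetV0` as typed above, assuming the experimental export path. Per the hunk, when the previous step terminated or truncated, the next `step()` call resets the wrapped env and returns `(reset_obs, 0, False, False, reset_info)` instead of stepping.

```python
import gymnasium as gym
from gymnasium.experimental.wrappers import AutoresetV0

env = AutoresetV0(gym.make("CartPole-v1"))
obs, info = env.reset(seed=0)
for _ in range(200):
    # No manual reset needed: the step after an episode ends performs
    # the reset itself, reusing the options saved in reset() above.
    obs, reward, terminated, truncated, info = env.step(env.action_space.sample())
env.close()
```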
Args: @@ -150,17 +154,15 @@ def __init__(self, env: gym.Env, disable_render_order_enforcing: bool = False): self._has_reset: bool = False self._disable_render_order_enforcing: bool = disable_render_order_enforcing - def step( - self, action: WrapperActType - ) -> tuple[WrapperObsType, SupportsFloat, bool, bool, dict]: - """Steps through the environment with `kwargs`.""" + def step(self, action: ActType) -> tuple[ObsType, SupportsFloat, bool, bool, dict]: + """Steps through the environment.""" if not self._has_reset: raise ResetNeeded("Cannot call env.step() before calling env.reset()") return super().step(action) def reset( self, *, seed: int | None = None, options: dict[str, Any] | None = None - ) -> tuple[WrapperObsType, dict[str, Any]]: + ) -> tuple[ObsType, dict[str, Any]]: """Resets the environment with `kwargs`.""" self._has_reset = True return super().reset(seed=seed, options=options) @@ -180,7 +182,7 @@ def has_reset(self): return self._has_reset -class RecordEpisodeStatisticsV0(gym.Wrapper): +class RecordEpisodeStatisticsV0(gym.Wrapper[ObsType, ActType, ObsType, ActType]): """This wrapper will keep track of cumulative rewards and episode lengths. At the end of an episode, the statistics of the episode will be added to ``info`` @@ -244,13 +246,13 @@ def __init__( self.episode_reward: float = -1 self.episode_length: int = -1 - self.episode_time_length_buffer = deque(maxlen=buffer_length) - self.episode_reward_buffer = deque(maxlen=buffer_length) - self.episode_length_buffer = deque(maxlen=buffer_length) + self.episode_time_length_buffer: deque[int] = deque(maxlen=buffer_length) + self.episode_reward_buffer: deque[float] = deque(maxlen=buffer_length) + self.episode_length_buffer: deque[int] = deque(maxlen=buffer_length) def step( - self, action: WrapperActType - ) -> tuple[WrapperObsType, SupportsFloat, bool, bool, dict[str, Any]]: + self, action: ActType + ) -> tuple[ObsType, SupportsFloat, bool, bool, dict[str, Any]]: """Steps through the environment, recording the episode statistics.""" obs, reward, terminated, truncated, info = super().step(action) @@ -279,7 +281,7 @@ def step( def reset( self, *, seed: int | None = None, options: dict[str, Any] | None = None - ) -> tuple[WrapperObsType, dict[str, Any]]: + ) -> tuple[ObsType, dict[str, Any]]: """Resets the environment using seed and options and resets the episode rewards and lengths.""" obs, info = super().reset(seed=seed, options=options) diff --git a/gymnasium/experimental/wrappers/conversion/jax_to_numpy.py b/gymnasium/experimental/wrappers/conversion/jax_to_numpy.py index 9f4c264a66..cc7a39b2d0 100644 --- a/gymnasium/experimental/wrappers/conversion/jax_to_numpy.py +++ b/gymnasium/experimental/wrappers/conversion/jax_to_numpy.py @@ -9,7 +9,7 @@ import numpy as np from gymnasium import Env, Wrapper -from gymnasium.core import RenderFrame, WrapperActType, WrapperObsType +from gymnasium.core import ActType, ObsType, RenderFrame, WrapperActType, WrapperObsType from gymnasium.error import DependencyNotInstalled @@ -92,7 +92,7 @@ def _iterable_jax_to_numpy( return type(value)(jax_to_numpy(v) for v in value) -class JaxToNumpyV0(Wrapper): +class JaxToNumpyV0(Wrapper[WrapperObsType, WrapperActType, ObsType, ActType]): """Wraps a jax environment so that it can be interacted with through numpy arrays. Actions must be provided as numpy arrays and observations will be returned as numpy arrays. 
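A sketch of the `OrderEnforcingV0` guard shown above, assuming the experimental export path: calling `step()` before the first `reset()` raises `ResetNeeded`.

```python
import gymnasium as gym
from gymnasium.error import ResetNeeded
from gymnasium.experimental.wrappers import OrderEnforcingV0

env = OrderEnforcingV0(gym.make("CartPole-v1"))
try:
    env.step(0)          # stepping before the first reset...
except ResetNeeded:
    pass                 # ...raises, per the _has_reset guard above
obs, info = env.reset(seed=0)
obs, reward, terminated, truncated, info = env.step(0)  # fine after reset
env.close()
```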
@@ -102,7 +102,7 @@ class JaxToNumpyV0(Wrapper): The reason for this is jax does not support non-array values, therefore numpy ``int_32(5) -> DeviceArray([5], dtype=jnp.int23)`` """ - def __init__(self, env: Env): + def __init__(self, env: Env[ObsType, ActType]): """Wraps an environment such that the input and outputs are numpy arrays. Args: diff --git a/gymnasium/experimental/wrappers/lambda_action.py b/gymnasium/experimental/wrappers/lambda_action.py index 703525d514..f9ba7439bd 100644 --- a/gymnasium/experimental/wrappers/lambda_action.py +++ b/gymnasium/experimental/wrappers/lambda_action.py @@ -16,18 +16,18 @@ import numpy as np import gymnasium as gym -from gymnasium.core import ActType, WrapperActType +from gymnasium.core import ActType, ObsType, WrapperActType from gymnasium.spaces import Box, Space -class LambdaActionV0(gym.ActionWrapper): +class LambdaActionV0(gym.ActionWrapper[ObsType, WrapperActType, ActType]): """A wrapper that provides a function to modify the action passed to :meth:`step`.""" def __init__( self, - env: gym.Env, + env: gym.Env[ObsType, ActType], func: Callable[[WrapperActType], ActType], - action_space: Space | None, + action_space: Space[WrapperActType] | None, ): """Initialize LambdaAction. @@ -47,7 +47,7 @@ def action(self, action: WrapperActType) -> ActType: return self.func(action) -class ClipActionV0(LambdaActionV0): +class ClipActionV0(LambdaActionV0[ObsType, WrapperActType, ActType]): """Clip the continuous action within the valid :class:`Box` observation space bound. Example: @@ -63,7 +63,7 @@ class ClipActionV0(LambdaActionV0): ... # Executes the action np.array([1.0, -1.0, 0]) in the base environment """ - def __init__(self, env: gym.Env): + def __init__(self, env: gym.Env[ObsType, ActType]): """A wrapper for clipping continuous actions within the valid bound. Args: @@ -83,7 +83,7 @@ def __init__(self, env: gym.Env): ) -class RescaleActionV0(LambdaActionV0): +class RescaleActionV0(LambdaActionV0[ObsType, WrapperActType, ActType]): """Affinely rescales the continuous action space of the environment to the range [min_action, max_action]. The base environment :attr:`env` must have an action space of type :class:`spaces.Box`. If :attr:`min_action` @@ -107,7 +107,7 @@ class RescaleActionV0(LambdaActionV0): def __init__( self, - env: gym.Env, + env: gym.Env[ObsType, ActType], min_action: float | int | np.ndarray, max_action: float | int | np.ndarray, ): diff --git a/gymnasium/experimental/wrappers/lambda_observations.py b/gymnasium/experimental/wrappers/lambda_observations.py index feb5434799..5debea8858 100644 --- a/gymnasium/experimental/wrappers/lambda_observations.py +++ b/gymnasium/experimental/wrappers/lambda_observations.py @@ -31,7 +31,7 @@ from gymnasium.spaces import Box, Dict, utils -class LambdaObservationV0(gym.ObservationWrapper): +class LambdaObservationV0(gym.ObservationWrapper[WrapperObsType, ActType, ObsType]): """Transforms an observation via a function provided to the wrapper. The function :attr:`func` will be applied to all observations. @@ -50,9 +50,9 @@ class LambdaObservationV0(gym.ObservationWrapper): def __init__( self, - env: gym.Env, + env: gym.Env[ObsType, ActType], func: Callable[[ObsType], Any], - observation_space: gym.Space | None, + observation_space: gym.Space[WrapperObsType] | None, ): """Constructor for the lambda observation wrapper. 
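The `LambdaActionV0` hunks above make the typing direction explicit: `func` maps `WrapperActType` to `ActType`, and the optional `action_space` is `Space[WrapperActType]`, i.e. the space the wrapper exposes. A hedged sketch, assuming the experimental export path; the discretization mapping below is illustrative only.

```python
import numpy as np

import gymnasium as gym
from gymnasium.experimental.wrappers import LambdaActionV0
from gymnasium.spaces import Discrete

env = gym.make("MountainCarContinuous-v0")   # Box(-1.0, 1.0, (1,)) actions
env = LambdaActionV0(
    env,
    # WrapperActType (Discrete index) -> ActType (Box array)
    func=lambda act: np.array([act - 1.0], dtype=np.float32),
    action_space=Discrete(3),                # what the wrapper exposes
)
obs, info = env.reset(seed=0)
obs, reward, terminated, truncated, info = env.step(2)  # forwarded as [1.0]
env.close()
```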
@@ -72,7 +72,7 @@ def observation(self, observation: ObsType) -> Any: return self.func(observation) -class FilterObservationV0(LambdaObservationV0): +class FilterObservationV0(LambdaObservationV0[WrapperObsType, ActType, ObsType]): """Filter Dict observation space by the keys. Example: @@ -91,7 +91,9 @@ class FilterObservationV0(LambdaObservationV0): ({'time': 0}, 1.0, False, False, {}) """ - def __init__(self, env: gym.Env, filter_keys: Sequence[str | int]): + def __init__( + self, env: gym.Env[ObsType, ActType], filter_keys: Sequence[str | int] + ): """Constructor for an environment with a dictionary observation space where all :attr:`filter_keys` are in the observation space keys.""" assert isinstance(filter_keys, Sequence) @@ -169,7 +171,7 @@ def __init__(self, env: gym.Env, filter_keys: Sequence[str | int]): self.filter_keys: Final[Sequence[str | int]] = filter_keys -class FlattenObservationV0(LambdaObservationV0): +class FlattenObservationV0(LambdaObservationV0[WrapperObsType, ActType, ObsType]): """Observation wrapper that flattens the observation. Example: @@ -186,7 +188,7 @@ class FlattenObservationV0(LambdaObservationV0): (27648,) """ - def __init__(self, env: gym.Env): + def __init__(self, env: gym.Env[ObsType, ActType]): """Constructor for any environment's observation space that implements ``spaces.utils.flatten_space`` and ``spaces.utils.flatten``.""" super().__init__( env, @@ -195,7 +197,7 @@ def __init__(self, env: gym.Env): ) -class GrayscaleObservationV0(LambdaObservationV0): +class GrayscaleObservationV0(LambdaObservationV0[WrapperObsType, ActType, ObsType]): """Observation wrapper that converts an RGB image to grayscale. The :attr:`keep_dim` will keep the channel dimension @@ -214,7 +216,7 @@ class GrayscaleObservationV0(LambdaObservationV0): (96, 96, 1) """ - def __init__(self, env: gym.Env, keep_dim: bool = False): + def __init__(self, env: gym.Env[ObsType, ActType], keep_dim: bool = False): """Constructor for an RGB image based environments to make the image grayscale.""" assert isinstance(env.observation_space, spaces.Box) assert ( @@ -258,7 +260,7 @@ def __init__(self, env: gym.Env, keep_dim: bool = False): ) -class ResizeObservationV0(LambdaObservationV0): +class ResizeObservationV0(LambdaObservationV0[WrapperObsType, ActType, ObsType]): """Resizes image observations using OpenCV to shape. Example: @@ -272,7 +274,7 @@ class ResizeObservationV0(LambdaObservationV0): (32, 32, 3) """ - def __init__(self, env: gym.Env, shape: tuple[int, ...]): + def __init__(self, env: gym.Env[ObsType, ActType], shape: tuple[int, ...]): """Constructor that requires an image environment observation space with a shape.""" assert isinstance(env.observation_space, spaces.Box) assert len(env.observation_space.shape) in [2, 3] @@ -304,7 +306,7 @@ def __init__(self, env: gym.Env, shape: tuple[int, ...]): ) -class ReshapeObservationV0(LambdaObservationV0): +class ReshapeObservationV0(LambdaObservationV0[WrapperObsType, ActType, ObsType]): """Reshapes array based observations to shapes. 
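Mirroring the noise example in the `LambdaObservationV0` docstring above, a sketch showing how the newly typed `observation_space: gym.Space[WrapperObsType] | None` lines up with what `func` returns; here the space is unchanged, so the env's own space is passed through.

```python
import numpy as np

import gymnasium as gym
from gymnasium.experimental.wrappers import LambdaObservationV0

env = gym.make("CartPole-v1")
env = LambdaObservationV0(
    env,
    # func: ObsType -> WrapperObsType (same shape/dtype in this case)
    func=lambda obs: (obs + np.random.uniform(0, 0.1, obs.shape)).astype(np.float32),
    observation_space=env.observation_space,
)
obs, info = env.reset(seed=42)
env.close()
```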
Example: @@ -318,7 +320,7 @@ class ReshapeObservationV0(LambdaObservationV0): (24, 4, 96, 1, 3) """ - def __init__(self, env: gym.Env, shape: int | tuple[int, ...]): + def __init__(self, env: gym.Env[ObsType, ActType], shape: int | tuple[int, ...]): """Constructor for env with Box observation space that has a shape product equal to the new shape product.""" assert isinstance(env.observation_space, spaces.Box) assert np.product(shape) == np.product(env.observation_space.shape) @@ -337,7 +339,7 @@ def __init__(self, env: gym.Env, shape: int | tuple[int, ...]): super().__init__(env, lambda obs: jp.reshape(obs, shape), new_observation_space) -class RescaleObservationV0(LambdaObservationV0): +class RescaleObservationV0(LambdaObservationV0[WrapperObsType, ActType, ObsType]): """Linearly rescales observation to between a minimum and maximum value. Example: @@ -353,7 +355,7 @@ class RescaleObservationV0(LambdaObservationV0): def __init__( self, - env: gym.Env, + env: gym.Env[ObsType, ActType], min_obs: np.floating | np.integer | np.ndarray, max_obs: np.floating | np.integer | np.ndarray, ): @@ -402,10 +404,10 @@ def __init__( ) -class DtypeObservationV0(LambdaObservationV0): +class DtypeObservationV0(LambdaObservationV0[WrapperObsType, ActType, ObsType]): """Observation wrapper for transforming the dtype of an observation.""" - def __init__(self, env: gym.Env, dtype: Any): + def __init__(self, env: gym.Env[ObsType, ActType], dtype: Any): """Constructor for Dtype, this is only valid with :class:`Box`, :class:`Discrete`, :class:`MultiDiscrete` and :class:`MultiBinary` observation spaces.""" assert isinstance( env.observation_space, @@ -446,7 +448,7 @@ def __init__(self, env: gym.Env, dtype: Any): super().__init__(env, lambda obs: dtype(obs), new_observation_space) -class PixelObservationV0(LambdaObservationV0): +class PixelObservationV0(LambdaObservationV0[WrapperObsType, ActType, ObsType]): """Augment observations by pixel values. Observations of this wrapper will be dictionaries of images. @@ -499,7 +501,7 @@ def __init__( ) -class NormalizeObservationV0(ObservationWrapper): +class NormalizeObservationV0(ObservationWrapper[WrapperObsType, ActType, ObsType]): """This wrapper will normalize observations s.t. each coordinate is centered with unit variance. The property `_update_running_mean` allows to freeze/continue the running mean calculation of the observation @@ -511,7 +513,7 @@ class NormalizeObservationV0(ObservationWrapper): newly instantiated or the policy was changed recently. """ - def __init__(self, env: gym.Env, epsilon: float = 1e-8): + def __init__(self, env: gym.Env[ObsType, ActType], epsilon: float = 1e-8): """This wrapper will normalize observations s.t. each coordinate is centered with unit variance. 
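`DtypeObservationV0` above has no docstring example, so a brief usage sketch, assuming the experimental export path; per the constructor, the observation is converted by calling `dtype(obs)` directly.

```python
import numpy as np

import gymnasium as gym
from gymnasium.experimental.wrappers import DtypeObservationV0

env = gym.make("CartPole-v1")                 # float32 observations
env = DtypeObservationV0(env, dtype=np.float64)
obs, info = env.reset(seed=0)
assert obs.dtype == np.float64                # obs is now dtype(obs)
env.close()
```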
Args: diff --git a/gymnasium/experimental/wrappers/lambda_reward.py b/gymnasium/experimental/wrappers/lambda_reward.py index e373dab2d3..fcb838cb8a 100644 --- a/gymnasium/experimental/wrappers/lambda_reward.py +++ b/gymnasium/experimental/wrappers/lambda_reward.py @@ -3,7 +3,6 @@ * ``LambdaReward`` - Transforms the reward by a function * ``ClipReward`` - Clips the reward between a minimum and maximum value """ - from __future__ import annotations from typing import Any, Callable, SupportsFloat @@ -11,12 +10,12 @@ import numpy as np import gymnasium as gym -from gymnasium.core import WrapperActType, WrapperObsType +from gymnasium.core import ActType, ObsType from gymnasium.error import InvalidBound from gymnasium.experimental.wrappers.utils import RunningMeanStd -class LambdaRewardV0(gym.RewardWrapper): +class LambdaRewardV0(gym.RewardWrapper[ObsType, ActType]): """A reward wrapper that allows a custom function to modify the step reward. Example: @@ -32,7 +31,7 @@ class LambdaRewardV0(gym.RewardWrapper): def __init__( self, - env: gym.Env, + env: gym.Env[ObsType, ActType], func: Callable[[SupportsFloat], SupportsFloat], ): """Initialize LambdaRewardV0 wrapper. @@ -54,7 +53,7 @@ def reward(self, reward: SupportsFloat) -> SupportsFloat: return self.func(reward) -class ClipRewardV0(LambdaRewardV0): +class ClipRewardV0(LambdaRewardV0[ObsType, ActType]): """A wrapper that clips the rewards for an environment between an upper and lower bound. Example: @@ -70,7 +69,7 @@ class ClipRewardV0(LambdaRewardV0): def __init__( self, - env: gym.Env, + env: gym.Env[ObsType, ActType], min_reward: float | np.ndarray | None = None, max_reward: float | np.ndarray | None = None, ): @@ -93,7 +92,7 @@ def __init__( super().__init__(env, lambda x: np.clip(x, a_min=min_reward, a_max=max_reward)) -class NormalizeRewardV0(gym.Wrapper): +class NormalizeRewardV0(gym.Wrapper[ObsType, ActType, ObsType, ActType]): r"""This wrapper will normalize immediate rewards s.t. their exponential moving average has a fixed variance. The exponential moving average will have variance :math:`(1 - \gamma)^2`. 
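Combining the two docstring examples above into one chain, to show that `LambdaRewardV0[ObsType, ActType]` and its `ClipRewardV0` subclass compose with the new two-parameter `RewardWrapper`; import path as in the other sketches.

```python
import gymnasium as gym
from gymnasium.experimental.wrappers import ClipRewardV0, LambdaRewardV0

env = gym.make("CartPole-v1")
env = LambdaRewardV0(env, lambda r: 2 * r + 1)       # per-step 1.0 -> 3.0
env = ClipRewardV0(env, min_reward=0, max_reward=2)  # 3.0 -> 2.0
obs, info = env.reset(seed=0)
_, reward, _, _, _ = env.step(0)
assert reward == 2.0
env.close()
```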
@@ -109,7 +108,7 @@ class NormalizeRewardV0(gym.Wrapper): def __init__( self, - env: gym.Env, + env: gym.Env[ObsType, ActType], gamma: float = 0.99, epsilon: float = 1e-8, ): @@ -138,8 +137,8 @@ def update_running_mean(self, setting: bool): self._update_running_mean = setting def step( - self, action: WrapperActType - ) -> tuple[WrapperObsType, SupportsFloat, bool, bool, dict[str, Any]]: + self, action: ActType + ) -> tuple[ObsType, SupportsFloat, bool, bool, dict[str, Any]]: """Steps through the environment, normalizing the reward returned.""" obs, reward, terminated, truncated, info = super().step(action) self.discounted_reward = self.discounted_reward * self.gamma * ( @@ -147,7 +146,7 @@ def step( ) + float(reward) return obs, self.normalize(float(reward)), terminated, truncated, info - def normalize(self, reward): + def normalize(self, reward: SupportsFloat): """Normalizes the rewards with the running mean rewards and their variance.""" if self._update_running_mean: self.rewards_running_means.update(self.discounted_reward) diff --git a/gymnasium/experimental/wrappers/rendering.py b/gymnasium/experimental/wrappers/rendering.py index 2dc5cc187e..795d3a7d6c 100644 --- a/gymnasium/experimental/wrappers/rendering.py +++ b/gymnasium/experimental/wrappers/rendering.py @@ -14,11 +14,11 @@ import gymnasium as gym from gymnasium import error, logger -from gymnasium.core import ActType, ObsType, RenderFrame, WrapperActType, WrapperObsType +from gymnasium.core import ActType, ObsType, RenderFrame from gymnasium.error import DependencyNotInstalled -class RenderCollectionV0(gym.Wrapper): +class RenderCollectionV0(gym.Wrapper[ObsType, ActType, ObsType, ActType]): """Collect rendered frames of an environment such ``render`` returns a ``list[RenderedFrame]``.""" def __init__( @@ -52,8 +52,8 @@ def render_mode(self): return f"{self.env.render_mode}_list" def step( - self, action: WrapperActType - ) -> tuple[WrapperObsType, SupportsFloat, bool, bool, dict]: + self, action: ActType + ) -> tuple[ObsType, SupportsFloat, bool, bool, dict[str, Any]]: """Perform a step in the base environment and collect a frame.""" output = super().step(action) self.frame_list.append(super().render()) @@ -61,7 +61,7 @@ def step( def reset( self, *, seed: int | None = None, options: dict[str, Any] | None = None - ) -> tuple[WrapperObsType, dict[str, Any]]: + ) -> tuple[ObsType, dict[str, Any]]: """Reset the base environment, eventually clear the frame_list, and collect a frame.""" output = super().reset(seed=seed, options=options) @@ -71,7 +71,7 @@ def reset( return output - def render(self) -> RenderFrame | list[RenderFrame] | None: + def render(self) -> list[RenderFrame]: """Returns the collection of frames and, if pop_frames = True, clears it.""" frames = self.frame_list if self.pop_frames: @@ -80,7 +80,7 @@ def render(self) -> RenderFrame | list[RenderFrame] | None: return frames -class RecordVideoV0(gym.Wrapper): +class RecordVideoV0(gym.Wrapper[ObsType, ActType, ObsType, ActType]): """This wrapper records videos of rollouts. Usually, you only want to record episodes intermittently, say every hundredth episode. 
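A sketch for `RenderCollectionV0` as typed above, assuming the experimental export path and the default `pop_frames`/`reset_clean` behavior, under which a frame is captured on `reset()` and on every `step()`.

```python
import gymnasium as gym
from gymnasium.experimental.wrappers import RenderCollectionV0

env = gym.make("CartPole-v1", render_mode="rgb_array")
env = RenderCollectionV0(env)       # render_mode reports "rgb_array_list"
env.reset(seed=0)
for _ in range(5):
    env.step(env.action_space.sample())
frames = env.render()               # list[RenderFrame], per the narrowed return type
assert len(frames) == 6             # one frame from reset, five from steps
env.close()
```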
@@ -98,10 +98,10 @@ class RecordVideoV0(gym.Wrapper): def __init__( self, - env: gym.Env, + env: gym.Env[ObsType, ActType], video_folder: str, - episode_trigger: Callable[[int], bool] = None, - step_trigger: Callable[[int], bool] = None, + episode_trigger: Callable[[int], bool] | None = None, + step_trigger: Callable[[int], bool] | None = None, video_length: int = 0, name_prefix: str = "rl-video", disable_logger: bool = False, @@ -155,13 +155,13 @@ def capped_cubic_video_schedule(episode_id: int) -> bool: ) os.makedirs(self.video_folder, exist_ok=True) - self.name_prefix = name_prefix - self._video_name = None - self.frames_per_sec = self.metadata.get("render_fps", 30) - self.video_length = video_length if video_length != 0 else float("inf") - self.recording = False - self.recorded_frames = [] - self.render_history = [] + self.name_prefix: str = name_prefix + self._video_name: str | None = None + self.frames_per_sec: int = self.metadata.get("render_fps", 30) + self.video_length: int = video_length if video_length != 0 else float("inf") + self.recording: bool = False + self.recorded_frames: list[RenderFrame] = [] + self.render_history: list[RenderFrame] = [] self.step_id = -1 self.episode_id = -1 @@ -187,7 +187,7 @@ def _capture_frame(self): def reset( self, *, seed: int | None = None, options: dict[str, Any] | None = None - ) -> tuple[WrapperObsType, dict[str, Any]]: + ) -> tuple[ObsType, dict[str, Any]]: """Reset the environment and eventually starts a new recording.""" obs, info = super().reset(seed=seed, options=options) self.episode_id += 1 @@ -205,8 +205,8 @@ def reset( return obs, info def step( - self, action: WrapperActType - ) -> tuple[WrapperObsType, SupportsFloat, bool, bool, dict[str, Any]]: + self, action: ActType + ) -> tuple[ObsType, SupportsFloat, bool, bool, dict[str, Any]]: """Steps through the environment using action, recording observations if :attr:`self.recording`.""" obs, rew, terminated, truncated, info = self.env.step(action) self.step_id += 1 @@ -221,7 +221,7 @@ def step( return obs, rew, terminated, truncated, info - def start_recording(self, video_name): + def start_recording(self, video_name: str): """Start a new recording. If it is already recording, stops the current recording before starting the new one.""" if self.recording: self.stop_recording() @@ -252,7 +252,7 @@ def stop_recording(self): self.recording = False self._video_name = None - def render(self): + def render(self) -> RenderFrame | list[RenderFrame]: """Compute the render frames as specified by render_mode attribute during initialization of the environment.""" render_out = super().render() if self.recording and isinstance(render_out, List): @@ -277,7 +277,7 @@ def __del__(self): logger.warn("Unable to save last video! Did you call close()?") -class HumanRenderingV0(gym.Wrapper): +class HumanRenderingV0(gym.Wrapper[ObsType, ActType, ObsType, ActType]): """Performs human rendering for an environment that only supports "rgb_array"rendering. This wrapper is particularly useful when you have implemented an environment that can produce @@ -311,7 +311,7 @@ class HumanRenderingV0(gym.Wrapper): [] """ - def __init__(self, env): + def __init__(self, env: gym.Env[ObsType, ActType]): """Initialize a :class:`HumanRendering` instance. 
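A usage sketch for `RecordVideoV0` reflecting the signature change above: both triggers are now explicitly `Optional`, and omitting them falls back to the capped cubic schedule mentioned in the hunk. Import path and folder name are illustrative.

```python
import gymnasium as gym
from gymnasium.experimental.wrappers import RecordVideoV0

env = gym.make("CartPole-v1", render_mode="rgb_array")
env = RecordVideoV0(env, video_folder="videos", name_prefix="cartpole")
obs, info = env.reset(seed=0)
for _ in range(100):
    obs, _, terminated, truncated, _ = env.step(env.action_space.sample())
    if terminated or truncated:
        obs, info = env.reset()
env.close()  # flushes any in-progress recording
```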
Args: @@ -339,9 +339,7 @@ def render_mode(self): """Always returns ``'human'``.""" return "human" - def step( - self, action: WrapperActType - ) -> tuple[WrapperObsType, SupportsFloat, bool, bool, dict]: + def step(self, action: ActType) -> tuple[ObsType, SupportsFloat, bool, bool, dict]: """Perform a step in the base environment and render a frame to the screen.""" result = super().step(action) self._render_frame() @@ -349,13 +347,13 @@ def step( def reset( self, *, seed: int | None = None, options: dict[str, Any] | None = None - ) -> tuple[WrapperObsType, dict[str, Any]]: + ) -> tuple[ObsType, dict[str, Any]]: """Reset the base environment and render a frame to the screen.""" result = super().reset(seed=seed, options=options) self._render_frame() return result - def render(self): + def render(self) -> None: """This method doesn't do much, actual rendering is performed in :meth:`step` and :meth:`reset`.""" return None diff --git a/gymnasium/experimental/wrappers/stateful_action.py b/gymnasium/experimental/wrappers/stateful_action.py index 5d3a821c57..c466602444 100644 --- a/gymnasium/experimental/wrappers/stateful_action.py +++ b/gymnasium/experimental/wrappers/stateful_action.py @@ -4,18 +4,20 @@ from typing import Any import gymnasium as gym -from gymnasium.core import ActionWrapper, ActType, WrapperActType, WrapperObsType +from gymnasium.core import ActionWrapper, ActType, ObsType from gymnasium.error import InvalidProbability -class StickyActionV0(ActionWrapper): +class StickyActionV0(ActionWrapper[ObsType, ActType, ActType]): """Wrapper which adds a probability of repeating the previous action. This wrapper follows the implementation proposed by `Machado et al., 2018 `_ in Section 5.2 on page 12. """ - def __init__(self, env: gym.Env, repeat_action_probability: float): + def __init__( + self, env: gym.Env[ObsType, ActType], repeat_action_probability: float + ): """Initialize StickyAction wrapper. Args: @@ -29,17 +31,17 @@ def __init__(self, env: gym.Env, repeat_action_probability: float): super().__init__(env) self.repeat_action_probability = repeat_action_probability - self.last_action: WrapperActType | None = None + self.last_action: ActType | None = None def reset( self, *, seed: int | None = None, options: dict[str, Any] | None = None - ) -> tuple[WrapperObsType, dict[str, Any]]: + ) -> tuple[ObsType, dict[str, Any]]: """Reset the environment.""" self.last_action = None return super().reset(seed=seed, options=options) - def action(self, action: WrapperActType) -> ActType: + def action(self, action: ActType) -> ActType: """Execute the action.""" if ( self.last_action is not None diff --git a/gymnasium/experimental/wrappers/stateful_observation.py b/gymnasium/experimental/wrappers/stateful_observation.py index 061663015a..ef7210b331 100644 --- a/gymnasium/experimental/wrappers/stateful_observation.py +++ b/gymnasium/experimental/wrappers/stateful_observation.py @@ -26,10 +26,10 @@ from gymnasium.vector.utils import batch_space, concatenate, create_empty_array, iterate -class DelayObservationV0(gym.ObservationWrapper): +class DelayObservationV0(gym.ObservationWrapper[ObsType, ActType, ObsType]): """Wrapper which adds a delay to the returned observation.""" - def __init__(self, env: gym.Env, delay: int): + def __init__(self, env: gym.Env[ObsType, ActType], delay: int): """Initialize the DelayObservation wrapper. 
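A sketch for `StickyActionV0`, now typed as `ActionWrapper[ObsType, ActType, ActType]` since it never changes the action type; import path assumed as above.

```python
import gymnasium as gym
from gymnasium.experimental.wrappers import StickyActionV0

env = gym.make("CartPole-v1")
# With probability 0.25, step() replays self.last_action instead of the
# action passed in (Machado et al., 2018, Section 5.2).
env = StickyActionV0(env, repeat_action_probability=0.25)
obs, info = env.reset(seed=0)     # reset clears last_action, per the hunk below
obs, reward, terminated, truncated, info = env.step(1)
env.close()
```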
Args: @@ -45,13 +45,13 @@ def __init__(self, env: gym.Env, delay: int): assert 0 < delay self.delay: Final[int] = delay - self.observation_queue: Final[deque] = deque() + self.observation_queue: Final[deque[ObsType]] = deque() super().__init__(env) def reset( self, *, seed: int | None = None, options: dict[str, Any] | None = None - ) -> tuple[WrapperObsType, dict[str, Any]]: + ) -> tuple[ObsType, dict[str, Any]]: """Resets the environment, clearing the observation queue.""" self.observation_queue.clear() @@ -67,7 +67,7 @@ def observation(self, observation: ObsType) -> ObsType: return jp.zeros_like(observation) -class TimeAwareObservationV0(gym.ObservationWrapper): +class TimeAwareObservationV0(gym.ObservationWrapper[WrapperObsType, ActType, ObsType]): """Augment the observation with time information of the episode. Time can be represented as a normalized value between [0,1] @@ -104,7 +104,7 @@ class TimeAwareObservationV0(gym.ObservationWrapper): def __init__( self, - env: gym.Env, + env: gym.Env[ObsType, ActType], flatten: bool = False, normalize_time: bool = True, *, @@ -212,7 +212,7 @@ def reset( return super().reset(seed=seed, options=options) -class FrameStackObservationV0(gym.Wrapper): +class FrameStackObservationV0(gym.Wrapper[WrapperObsType, ActType, ObsType, ActType]): """Observation wrapper that stacks the observations in a rolling manner. For example, if the number of stacks is 4, then the returned observation contains @@ -302,7 +302,7 @@ def reset( info, ) - def _init_stacked_obs(self) -> deque: + def _init_stacked_obs(self) -> deque[ObsType]: return deque( iterate( self.observation_space, diff --git a/gymnasium/utils/passive_env_checker.py b/gymnasium/utils/passive_env_checker.py index 46d220ae05..171f89d257 100644 --- a/gymnasium/utils/passive_env_checker.py +++ b/gymnasium/utils/passive_env_checker.py @@ -8,6 +8,13 @@ from gymnasium import Space, error, logger, spaces +__all__ = [ + "env_render_passive_checker", + "env_reset_passive_checker", + "env_step_passive_checker", +] + + def _check_box_observation_space(observation_space: spaces.Box): """Checks that a :class:`Box` observation space is defined in a sensible way. diff --git a/tests/experimental/wrappers/test_record_video.py b/tests/experimental/wrappers/test_record_video.py index 20b35d6e29..1dc05b11d6 100644 --- a/tests/experimental/wrappers/test_record_video.py +++ b/tests/experimental/wrappers/test_record_video.py @@ -23,6 +23,7 @@ def test_record_video_using_default_trigger(): env.close() assert os.path.isdir("videos") mp4_files = [file for file in os.listdir("videos") if file.endswith(".mp4")] + assert env.episode_trigger is not None assert len(mp4_files) == sum( env.episode_trigger(i) for i in range(episode_count + 1) ) @@ -46,6 +47,7 @@ def test_record_video_while_rendering(): env.close() assert os.path.isdir("videos") mp4_files = [file for file in os.listdir("videos") if file.endswith(".mp4")] + assert env.episode_trigger is not None assert len(mp4_files) == sum( env.episode_trigger(i) for i in range(episode_count + 1) )
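Closing sketch: `FrameStackObservationV0` is the clearest case where all four `Wrapper` parameters differ meaningfully, since the exposed observation space is a batched version of the wrapped one. Import path assumed as in the other sketches; the stacked shape below follows from batching a `Box(4,)` space.

```python
import gymnasium as gym
from gymnasium.experimental.wrappers import FrameStackObservationV0

env = gym.make("CartPole-v1")                 # ObsType: Box with shape (4,)
env = FrameStackObservationV0(env, 4)         # WrapperObsType: shape (4, 4)
obs, info = env.reset(seed=0)
assert obs.shape == (4, 4)                    # stack of the last 4 frames
env.close()
```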