Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 8 additions & 5 deletions gymnasium/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ def np_random(self) -> np.random.Generator:
Instances of `np.random.Generator`
"""
if self._np_random is None:
self._np_random, seed = seeding.np_random()
self._np_random, _ = seeding.np_random()
return self._np_random

@np_random.setter
Expand Down Expand Up @@ -234,7 +234,10 @@ def __exit__(self, *args: Any):
WrapperActType = TypeVar("WrapperActType")


class Wrapper(Env[WrapperObsType, WrapperActType]):
class Wrapper(
Env[WrapperObsType, WrapperActType],
Generic[WrapperObsType, WrapperActType, ObsType, ActType],
):
"""Wraps a :class:`gymnasium.Env` to allow a modular transformation of the :meth:`step` and :meth:`reset` methods.

This class is the base class of all wrappers to change the behavior of the underlying environment.
Expand Down Expand Up @@ -391,7 +394,7 @@ def unwrapped(self) -> Env[ObsType, ActType]:
return self.env.unwrapped


class ObservationWrapper(Wrapper[WrapperObsType, ActType]):
class ObservationWrapper(Wrapper[WrapperObsType, ActType, ObsType, ActType]):
"""Superclass of wrappers that can modify observations using :meth:`observation` for :meth:`reset` and :meth:`step`.

If you would like to apply a function to only the observation before
Expand Down Expand Up @@ -434,7 +437,7 @@ def observation(self, observation: ObsType) -> WrapperObsType:
raise NotImplementedError


class RewardWrapper(Wrapper[ObsType, ActType]):
class RewardWrapper(Wrapper[ObsType, ActType, ObsType, ActType]):
"""Superclass of wrappers that can modify the returning reward from a step.

If you would like to apply a function to the reward that is returned by the base environment before
Expand Down Expand Up @@ -467,7 +470,7 @@ def reward(self, reward: SupportsFloat) -> SupportsFloat:
raise NotImplementedError


class ActionWrapper(Wrapper[ObsType, WrapperActType]):
class ActionWrapper(Wrapper[ObsType, WrapperActType, ObsType, ActType]):
"""Superclass of wrappers that can modify the action before :meth:`env.step`.

If you would like to apply a function to the action before passing it to the base environment,
Expand Down
52 changes: 27 additions & 25 deletions gymnasium/experimental/wrappers/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

import gymnasium as gym
from gymnasium import Env
from gymnasium.core import ActType, ObsType, RenderFrame, WrapperActType, WrapperObsType
from gymnasium.core import ActType, ObsType, RenderFrame
from gymnasium.error import ResetNeeded
from gymnasium.utils.passive_env_checker import (
check_action_space,
Expand All @@ -26,10 +26,10 @@
)


class AutoresetV0(gym.Wrapper):
class AutoresetV0(gym.Wrapper[ObsType, ActType, ObsType, ActType]):
"""A class for providing an automatic reset functionality for gymnasium environments when calling :meth:`self.step`."""

def __init__(self, env: gym.Env):
def __init__(self, env: gym.Env[ObsType, ActType]):
"""A class for providing an automatic reset functionality for gymnasium environments when calling :meth:`self.step`.

Args:
Expand All @@ -40,8 +40,8 @@ def __init__(self, env: gym.Env):
self._reset_options: dict[str, Any] | None = None

def step(
self, action: WrapperActType
) -> tuple[WrapperObsType, SupportsFloat, bool, bool, dict]:
self, action: ActType
) -> tuple[ObsType, SupportsFloat, bool, bool, dict[str, Any]]:
"""Steps through the environment with action and resets the environment if a terminated or truncated signal is encountered in the previous step.

Args:
Expand All @@ -51,7 +51,7 @@ def step(
The autoreset environment :meth:`step`
"""
if self._episode_ended:
obs, info = super().reset(options=self._reset_options)
obs, info = self.env.reset(options=self._reset_options)
self._episode_ended = True
return obs, 0, False, False, info
else:
Expand All @@ -61,14 +61,14 @@ def step(

def reset(
self, *, seed: int | None = None, options: dict[str, Any] | None = None
) -> tuple[WrapperObsType, dict[str, Any]]:
) -> tuple[ObsType, dict[str, Any]]:
"""Resets the environment, saving the options used."""
self._episode_ended = False
self._reset_options = options
return super().reset(seed=seed, options=self._reset_options)


class PassiveEnvCheckerV0(gym.Wrapper):
class PassiveEnvCheckerV0(gym.Wrapper[ObsType, ActType, ObsType, ActType]):
"""A passive environment checker wrapper that surrounds the step, reset and render functions to check they follow the gymnasium API."""

def __init__(self, env: Env[ObsType, ActType]):
Expand All @@ -89,8 +89,8 @@ def __init__(self, env: Env[ObsType, ActType]):
self._checked_render: bool = False

def step(
self, action: WrapperActType
) -> tuple[WrapperObsType, SupportsFloat, bool, bool, dict[str, Any]]:
self, action: ActType
) -> tuple[ObsType, SupportsFloat, bool, bool, dict[str, Any]]:
"""Steps through the environment that on the first call will run the `passive_env_step_check`."""
if self._checked_step is False:
self._checked_step = True
Expand All @@ -100,7 +100,7 @@ def step(

def reset(
self, *, seed: int | None = None, options: dict[str, Any] | None = None
) -> tuple[WrapperObsType, dict[str, Any]]:
) -> tuple[ObsType, dict[str, Any]]:
"""Resets the environment that on the first call will run the `passive_env_reset_check`."""
if self._checked_reset is False:
self._checked_reset = True
Expand All @@ -117,7 +117,7 @@ def render(self) -> RenderFrame | list[RenderFrame] | None:
return self.env.render()


class OrderEnforcingV0(gym.Wrapper):
class OrderEnforcingV0(gym.Wrapper[ObsType, ActType, ObsType, ActType]):
"""A wrapper that will produce an error if :meth:`step` is called before an initial :meth:`reset`.

Example:
Expand All @@ -139,7 +139,11 @@ class OrderEnforcingV0(gym.Wrapper):
>>> env.close()
"""

def __init__(self, env: gym.Env, disable_render_order_enforcing: bool = False):
def __init__(
self,
env: gym.Env[ObsType, ActType],
disable_render_order_enforcing: bool = False,
):
"""A wrapper that will produce an error if :meth:`step` is called before an initial :meth:`reset`.

Args:
Expand All @@ -150,17 +154,15 @@ def __init__(self, env: gym.Env, disable_render_order_enforcing: bool = False):
self._has_reset: bool = False
self._disable_render_order_enforcing: bool = disable_render_order_enforcing

def step(
self, action: WrapperActType
) -> tuple[WrapperObsType, SupportsFloat, bool, bool, dict]:
"""Steps through the environment with `kwargs`."""
def step(self, action: ActType) -> tuple[ObsType, SupportsFloat, bool, bool, dict]:
"""Steps through the environment."""
if not self._has_reset:
raise ResetNeeded("Cannot call env.step() before calling env.reset()")
return super().step(action)

def reset(
self, *, seed: int | None = None, options: dict[str, Any] | None = None
) -> tuple[WrapperObsType, dict[str, Any]]:
) -> tuple[ObsType, dict[str, Any]]:
"""Resets the environment with `seed` and `options`."""
self._has_reset = True
return super().reset(seed=seed, options=options)
Expand All @@ -180,7 +182,7 @@ def has_reset(self):
return self._has_reset


class RecordEpisodeStatisticsV0(gym.Wrapper):
class RecordEpisodeStatisticsV0(gym.Wrapper[ObsType, ActType, ObsType, ActType]):
"""This wrapper will keep track of cumulative rewards and episode lengths.

At the end of an episode, the statistics of the episode will be added to ``info``
Expand Down Expand Up @@ -244,13 +246,13 @@ def __init__(
self.episode_reward: float = -1
self.episode_length: int = -1

self.episode_time_length_buffer = deque(maxlen=buffer_length)
self.episode_reward_buffer = deque(maxlen=buffer_length)
self.episode_length_buffer = deque(maxlen=buffer_length)
self.episode_time_length_buffer: deque[int] = deque(maxlen=buffer_length)
self.episode_reward_buffer: deque[float] = deque(maxlen=buffer_length)
self.episode_length_buffer: deque[int] = deque(maxlen=buffer_length)

def step(
self, action: WrapperActType
) -> tuple[WrapperObsType, SupportsFloat, bool, bool, dict[str, Any]]:
self, action: ActType
) -> tuple[ObsType, SupportsFloat, bool, bool, dict[str, Any]]:
"""Steps through the environment, recording the episode statistics."""
obs, reward, terminated, truncated, info = super().step(action)

Expand Down Expand Up @@ -279,7 +281,7 @@ def step(

def reset(
self, *, seed: int | None = None, options: dict[str, Any] | None = None
) -> tuple[WrapperObsType, dict[str, Any]]:
) -> tuple[ObsType, dict[str, Any]]:
"""Resets the environment using seed and options and resets the episode rewards and lengths."""
obs, info = super().reset(seed=seed, options=options)

Expand Down
6 changes: 3 additions & 3 deletions gymnasium/experimental/wrappers/conversion/jax_to_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import numpy as np

from gymnasium import Env, Wrapper
from gymnasium.core import RenderFrame, WrapperActType, WrapperObsType
from gymnasium.core import ActType, ObsType, RenderFrame, WrapperActType, WrapperObsType
from gymnasium.error import DependencyNotInstalled


Expand Down Expand Up @@ -92,7 +92,7 @@ def _iterable_jax_to_numpy(
return type(value)(jax_to_numpy(v) for v in value)


class JaxToNumpyV0(Wrapper):
class JaxToNumpyV0(Wrapper[WrapperObsType, WrapperActType, ObsType, ActType]):
"""Wraps a jax environment so that it can be interacted with through numpy arrays.

Actions must be provided as numpy arrays and observations will be returned as numpy arrays.
Expand All @@ -102,7 +102,7 @@ class JaxToNumpyV0(Wrapper):
The reason for this is that jax does not support non-array values, therefore numpy ``int_32(5) -> DeviceArray([5], dtype=jnp.int32)``
"""

def __init__(self, env: Env):
def __init__(self, env: Env[ObsType, ActType]):
"""Wraps an environment such that the input and outputs are numpy arrays.

Args:
Expand Down
16 changes: 8 additions & 8 deletions gymnasium/experimental/wrappers/lambda_action.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,18 +16,18 @@
import numpy as np

import gymnasium as gym
from gymnasium.core import ActType, WrapperActType
from gymnasium.core import ActType, ObsType, WrapperActType
from gymnasium.spaces import Box, Space


class LambdaActionV0(gym.ActionWrapper):
class LambdaActionV0(gym.ActionWrapper[ObsType, WrapperActType, ActType]):
"""A wrapper that provides a function to modify the action passed to :meth:`step`."""

def __init__(
self,
env: gym.Env,
env: gym.Env[ObsType, ActType],
func: Callable[[WrapperActType], ActType],
action_space: Space | None,
action_space: Space[WrapperActType] | None,
):
"""Initialize LambdaAction.

Expand All @@ -47,7 +47,7 @@ def action(self, action: WrapperActType) -> ActType:
return self.func(action)


class ClipActionV0(LambdaActionV0):
class ClipActionV0(LambdaActionV0[ObsType, WrapperActType, ActType]):
"""Clip the continuous action within the valid :class:`Box` action space bound.

Example:
Expand All @@ -63,7 +63,7 @@ class ClipActionV0(LambdaActionV0):
... # Executes the action np.array([1.0, -1.0, 0]) in the base environment
"""

def __init__(self, env: gym.Env):
def __init__(self, env: gym.Env[ObsType, ActType]):
"""A wrapper for clipping continuous actions within the valid bound.

Args:
Expand All @@ -83,7 +83,7 @@ def __init__(self, env: gym.Env):
)


class RescaleActionV0(LambdaActionV0):
class RescaleActionV0(LambdaActionV0[ObsType, WrapperActType, ActType]):
"""Affinely rescales the continuous action space of the environment to the range [min_action, max_action].

The base environment :attr:`env` must have an action space of type :class:`spaces.Box`. If :attr:`min_action`
Expand All @@ -107,7 +107,7 @@ class RescaleActionV0(LambdaActionV0):

def __init__(
self,
env: gym.Env,
env: gym.Env[ObsType, ActType],
min_action: float | int | np.ndarray,
max_action: float | int | np.ndarray,
):
Expand Down
Loading