-
-
Notifications
You must be signed in to change notification settings - Fork 1.2k
Update the type hinting for core.py #39
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 8 commits
046c18c
1cd0a4c
8eaa5bd
c43e48c
cb964fa
8cbf484
8559e12
51cfe93
cb6c92c
19224f5
53f3f88
78cec07
5f04cd8
ec71d7d
c49e12e
53b9ac6
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -76,14 +76,16 @@ class Env(Generic[ObsType, ActType]): | |
| def np_random(self) -> np.random.Generator: | ||
| """Returns the environment's internal :attr:`_np_random` that if not set will initialize with a random seed.""" | ||
| if self._np_random is None: | ||
| self._np_random, seed = seeding.np_random() | ||
| self._np_random, _ = seeding.np_random() | ||
| return self._np_random | ||
|
|
||
| @np_random.setter | ||
| def np_random(self, value: np.random.Generator): | ||
| self._np_random = value | ||
|
|
||
| def step(self, action: ActType) -> Tuple[ObsType, float, bool, bool, dict]: | ||
| def step( | ||
| self, action: ActType | ||
| ) -> Tuple[ObsType, float, bool, bool, Dict[str, Any]]: | ||
| """Run one timestep of the environment's dynamics. | ||
|
|
||
| When end of episode is reached, you are responsible for calling :meth:`reset` to reset this environment's state. | ||
|
|
@@ -113,8 +115,8 @@ def reset( | |
| self, | ||
| *, | ||
| seed: Optional[int] = None, | ||
| options: Optional[dict] = None, | ||
| ) -> Tuple[ObsType, dict]: | ||
| options: Optional[Dict[str, Any]] = None, | ||
| ) -> Tuple[ObsType, Dict[str, Any]]: | ||
| """Resets the environment to an initial state and returns the initial observation. | ||
|
|
||
| This method can reset the environment's random number generator(s) if ``seed`` is an integer or | ||
|
|
@@ -134,7 +136,6 @@ def reset( | |
| options (optional dict): Additional information to specify how the environment is reset (optional, | ||
| depending on the specific environment) | ||
|
|
||
|
|
||
| Returns: | ||
| observation (object): Observation of the initial state. This will be an element of :attr:`observation_space` | ||
| (typically a numpy array) and is analogous to the observation returned by :meth:`step`. | ||
|
|
@@ -175,7 +176,7 @@ def close(self): | |
| pass | ||
|
|
||
| @property | ||
| def unwrapped(self) -> "Env": | ||
| def unwrapped(self) -> "Env[ObsType, ActType]": | ||
| """Returns the base non-wrapped environment. | ||
|
|
||
| Returns: | ||
|
|
@@ -194,14 +195,18 @@ def __enter__(self): | |
| """Support with-statement for the environment.""" | ||
| return self | ||
|
|
||
| def __exit__(self, *args): | ||
| def __exit__(self, *args: List[Any]): | ||
pseudo-rnd-thoughts marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| """Support with-statement for the environment.""" | ||
| self.close() | ||
| # propagate exception | ||
| return False | ||
|
|
||
|
|
||
| class Wrapper(Env[ObsType, ActType]): | ||
| WrapperObsType = TypeVar("WrapperObsType") | ||
| WrapperActType = TypeVar("WrapperActType") | ||
|
|
||
|
|
||
| class Wrapper(Env[WrapperObsType, WrapperActType]): | ||
| """Wraps an environment to allow a modular transformation of the :meth:`step` and :meth:`reset` methods. | ||
|
|
||
| This class is the base class for all wrappers. The subclass could override | ||
|
|
@@ -212,55 +217,63 @@ class Wrapper(Env[ObsType, ActType]): | |
| Don't forget to call ``super().__init__(env)`` if the subclass overrides :meth:`__init__`. | ||
| """ | ||
|
|
||
| def __init__(self, env: Env): | ||
| def __init__(self, env: Env[ObsType, ActType]): | ||
| """Wraps an environment to allow a modular transformation of the :meth:`step` and :meth:`reset` methods. | ||
|
|
||
| Args: | ||
| env: The environment to wrap | ||
| """ | ||
| self.env = env | ||
|
|
||
| self._action_space: Optional[spaces.Space] = None | ||
| self._observation_space: Optional[spaces.Space] = None | ||
| self._action_space: Optional[spaces.Space[WrapperActType]] = None | ||
| self._observation_space: Optional[spaces.Space[WrapperObsType]] = None | ||
| self._reward_range: Optional[Tuple[SupportsFloat, SupportsFloat]] = None | ||
| self._metadata: Optional[dict] = None | ||
| self._metadata: Optional[Dict[str, Any]] = None | ||
|
|
||
| def __getattr__(self, name): | ||
| def __getattr__(self, name: str): | ||
| """Returns an attribute with ``name``, unless ``name`` starts with an underscore.""" | ||
| if name.startswith("_"): | ||
| if name == "_np_random": | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I'm not entirely sure about this: it's out of scope, and it seems like a guardrail that isn't necessarily needed.
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This was a guardrail that already existed, but it was bugged, so it didn't work. I had to modify this function to re-add it such that the correct error occurred. We could replace it with a warning? I added it to gym a couple of months ago as I thought I had found a bug, but what was actually happening was that I was using the wrapper
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I think it's fine for now; we can reconsider this guardrail in general in the future. |
||
| raise AttributeError( | ||
| "Can't access `_np_random` of a wrapper, use `self.unwrapped._np_random` or `self.np_random`." | ||
| ) | ||
| elif name.startswith("_"): | ||
| raise AttributeError(f"accessing private attribute '{name}' is prohibited") | ||
| return getattr(self.env, name) | ||
|
|
||
| @property | ||
| def spec(self): | ||
| def spec(self) -> "EnvSpec": | ||
| """Returns the environment specification.""" | ||
| return self.env.spec | ||
|
|
||
| @classmethod | ||
| def class_name(cls): | ||
| def class_name(cls) -> str: | ||
| """Returns the class name of the wrapper.""" | ||
| return cls.__name__ | ||
|
|
||
| @property | ||
| def action_space(self) -> spaces.Space[ActType]: | ||
| def action_space( | ||
| self, | ||
| ) -> Union[spaces.Space[ActType], spaces.Space[WrapperActType]]: | ||
| """Returns the action space of the environment.""" | ||
| if self._action_space is None: | ||
| return self.env.action_space | ||
| return self._action_space | ||
|
|
||
| @action_space.setter | ||
| def action_space(self, space: spaces.Space): | ||
| def action_space(self, space: spaces.Space[WrapperActType]): | ||
| self._action_space = space | ||
|
|
||
| @property | ||
| def observation_space(self) -> spaces.Space: | ||
| def observation_space( | ||
| self, | ||
| ) -> Union[spaces.Space[ObsType], spaces.Space[WrapperObsType]]: | ||
| """Returns the observation space of the environment.""" | ||
| if self._observation_space is None: | ||
| return self.env.observation_space | ||
| return self._observation_space | ||
|
|
||
| @observation_space.setter | ||
| def observation_space(self, space: spaces.Space): | ||
| def observation_space(self, space: spaces.Space[WrapperObsType]): | ||
| self._observation_space = space | ||
|
|
||
| @property | ||
|
|
@@ -275,14 +288,14 @@ def reward_range(self, value: Tuple[SupportsFloat, SupportsFloat]): | |
| self._reward_range = value | ||
|
|
||
| @property | ||
| def metadata(self) -> dict: | ||
| def metadata(self) -> Dict[str, Any]: | ||
| """Returns the environment metadata.""" | ||
| if self._metadata is None: | ||
| return self.env.metadata | ||
| return self._metadata | ||
|
|
||
| @metadata.setter | ||
| def metadata(self, value): | ||
| def metadata(self, value: Dict[str, Any]): | ||
| self._metadata = value | ||
|
|
||
| @property | ||
|
|
@@ -296,11 +309,15 @@ def np_random(self) -> np.random.Generator: | |
| return self.env.np_random | ||
|
|
||
| @np_random.setter | ||
| def np_random(self, value): | ||
| def np_random(self, value: np.random.Generator): | ||
| self.env.np_random = value | ||
|
|
||
| @property | ||
| def _np_random(self): | ||
| """This code will never be run due to __getattr__ being called prior this. | ||
|
|
||
| It seems that @property overwrites the variable (`_np_random`) meaning that __getattr__ gets called with the missing variable. | ||
| """ | ||
| raise AttributeError( | ||
| "Can't access `_np_random` of a wrapper, use `.unwrapped._np_random` or `.np_random`." | ||
| ) | ||
|
|
@@ -309,15 +326,15 @@ def step(self, action: ActType) -> Tuple[ObsType, float, bool, bool, dict]: | |
| """Steps through the environment with action.""" | ||
| return self.env.step(action) | ||
|
|
||
| def reset(self, **kwargs) -> Tuple[ObsType, dict]: | ||
| """Resets the environment with kwargs.""" | ||
| return self.env.reset(**kwargs) | ||
| def reset( | ||
| self, *, seed: Optional[int] = None, options: Optional[Dict[str, Any]] = None | ||
| ) -> Tuple[WrapperObsType, Dict[str, Any]]: | ||
| """Resets the environment with a seed and options.""" | ||
| return self.env.reset(seed=seed, options=options) | ||
|
|
||
| def render( | ||
| self, *args, **kwargs | ||
| ) -> Optional[Union[RenderFrame, List[RenderFrame]]]: | ||
| def render(self) -> Optional[Union[RenderFrame, List[RenderFrame]]]: | ||
| """Renders the environment.""" | ||
| return self.env.render(*args, **kwargs) | ||
| return self.env.render() | ||
|
|
||
| def close(self): | ||
| """Closes the environment.""" | ||
|
|
@@ -332,12 +349,12 @@ def __repr__(self): | |
| return str(self) | ||
|
|
||
| @property | ||
| def unwrapped(self) -> Env: | ||
| def unwrapped(self) -> Env[ObsType, ActType]: | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Not necessarily true - consider the wrappers that discretize actions or add pixel observations.
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Good point, removed. |
||
| """Returns the base environment of the wrapper.""" | ||
| return self.env.unwrapped | ||
|
|
||
|
|
||
| class ObservationWrapper(Wrapper): | ||
| class ObservationWrapper(Wrapper[WrapperObsType, ActType]): | ||
| """Superclass of wrappers that can modify observations using :meth:`observation` for :meth:`reset` and :meth:`step`. | ||
|
|
||
| If you would like to apply a function to the observation that is returned by the base environment before | ||
|
|
@@ -365,22 +382,26 @@ def observation(self, obs): | |
| index of the timestep to the observation. | ||
| """ | ||
|
|
||
| def reset(self, **kwargs): | ||
| def reset( | ||
| self, *, seed: Optional[int] = None, options: Optional[Dict[str, Any]] = None | ||
| ) -> Tuple[WrapperObsType, Dict[str, Any]]: | ||
| """Resets the environment, returning a modified observation using :meth:`self.observation`.""" | ||
| obs, info = self.env.reset(**kwargs) | ||
| obs, info = self.env.reset(seed=seed, options=options) | ||
| return self.observation(obs), info | ||
|
|
||
| def step(self, action): | ||
| def step( | ||
| self, action: ActType | ||
| ) -> Tuple[WrapperObsType, float, bool, bool, Dict[str, Any]]: | ||
| """Returns a modified observation using :meth:`self.observation` after calling :meth:`env.step`.""" | ||
| observation, reward, terminated, truncated, info = self.env.step(action) | ||
| return self.observation(observation), reward, terminated, truncated, info | ||
|
|
||
| def observation(self, observation): | ||
| def observation(self, observation: ObsType) -> WrapperObsType: | ||
| """Returns a modified observation.""" | ||
| raise NotImplementedError | ||
|
|
||
|
|
||
| class RewardWrapper(Wrapper): | ||
| class RewardWrapper(Wrapper[ObsType, ActType]): | ||
| """Superclass of wrappers that can modify the returning reward from a step. | ||
|
|
||
| If you would like to apply a function to the reward that is returned by the base environment before | ||
|
|
@@ -393,28 +414,30 @@ class RewardWrapper(Wrapper): | |
| because it is intrinsic), we want to clip the reward to a range to gain some numerical stability. | ||
| To do that, we could, for instance, implement the following wrapper:: | ||
|
|
||
| class ClipReward(gym.RewardWrapper): | ||
| class ClipReward(gymnasium.RewardWrapper): | ||
pseudo-rnd-thoughts marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| def __init__(self, env, min_reward, max_reward): | ||
| super().__init__(env) | ||
| self.min_reward = min_reward | ||
| self.max_reward = max_reward | ||
| self.reward_range = (min_reward, max_reward) | ||
|
|
||
| def reward(self, reward): | ||
| return np.clip(reward, self.min_reward, self.max_reward) | ||
| def reward(self, r: float) -> float: | ||
| return np.clip(r, self.min_reward, self.max_reward) | ||
| """ | ||
|
|
||
| def step(self, action): | ||
| def step( | ||
| self, action: ActType | ||
| ) -> Tuple[ObsType, float, bool, bool, Dict[str, Any]]: | ||
| """Modifies the reward using :meth:`self.reward` after the environment :meth:`env.step`.""" | ||
| observation, reward, terminated, truncated, info = self.env.step(action) | ||
| return observation, self.reward(reward), terminated, truncated, info | ||
|
|
||
| def reward(self, reward): | ||
| def reward(self, reward: SupportsFloat) -> float: | ||
| """Returns a modified ``reward``.""" | ||
| raise NotImplementedError | ||
|
|
||
|
|
||
| class ActionWrapper(Wrapper): | ||
| class ActionWrapper(Wrapper[ObsType, WrapperActType]): | ||
| """Superclass of wrappers that can modify the action before :meth:`env.step`. | ||
|
|
||
| If you would like to apply a function to the action before passing it to the base environment, | ||
|
|
@@ -446,14 +469,16 @@ def action(self, act): | |
| Among others, Gymnasium provides the action wrappers :class:`ClipAction` and :class:`RescaleAction`. | ||
| """ | ||
|
|
||
| def step(self, action): | ||
| def step( | ||
| self, action: WrapperActType | ||
| ) -> Tuple[ObsType, float, bool, bool, Dict[str, Any]]: | ||
| """Runs the environment :meth:`env.step` using the modified ``action`` from :meth:`self.action`.""" | ||
| return self.env.step(self.action(action)) | ||
|
|
||
| def action(self, action): | ||
| def action(self, action: WrapperActType) -> ActType: | ||
| """Returns a modified action before :meth:`env.step` is called.""" | ||
| raise NotImplementedError | ||
|
|
||
| def reverse_action(self, action): | ||
| def reverse_action(self, action: ActType) -> WrapperActType: | ||
| """Returns a reversed ``action``.""" | ||
| raise NotImplementedError | ||
Uh oh!
There was an error while loading. Please reload this page.