Closed

Changes from 16 commits (46 commits in total)

Commits
692492c
Update requirements
araffin May 29, 2022
f9b0fb5
Merge branch 'master' into feat/gym-0.24
araffin May 29, 2022
3172ee6
Updates for newest gym version
araffin May 29, 2022
ba37700
Load the env only when needed
araffin May 31, 2022
dd65644
Merge branch 'master' into feat/gym-0.24
araffin Jun 10, 2022
aefcea6
Pin seaborn dependency
araffin Jun 10, 2022
d685d32
Update dependencies
araffin Jun 14, 2022
e5ef2b3
Fix pytorch download
araffin Jun 14, 2022
7f4ecbb
Merge branch 'master' into feat/gym-0.24
araffin Jun 23, 2022
6829642
Update requirements
araffin Jun 23, 2022
60ada0e
Update scipy requirement
araffin Jun 23, 2022
90b27b4
No scipy needed
araffin Jun 23, 2022
cdb8474
Merge branch 'master' of https://github.com/DLR-RM/rl-baselines3-zoo …
qgallouedec Oct 26, 2022
65f2bbc
Load the env only when needed
qgallouedec Oct 26, 2022
15fe4a0
Remove duplicated seaborn dependency
qgallouedec Oct 26, 2022
9a7fedd
done to terminated, truncated
qgallouedec Oct 26, 2022
e98c2a2
Revert VecEnv step modification (still return done)
qgallouedec Oct 26, 2022
b5b8c24
gym26 reset format
qgallouedec Oct 26, 2022
b17afd4
lint step wrappers
qgallouedec Oct 26, 2022
cb8e1b7
update render
qgallouedec Oct 26, 2022
865c5e0
black
qgallouedec Oct 26, 2022
f425058
gym 0.24 -> gym 0.26
qgallouedec Oct 26, 2022
478198c
Update ale-py dependency
qgallouedec Oct 26, 2022
9937d2a
Merge branch 'master' into feat/gym-0.24
araffin Oct 31, 2022
d689a90
Fixes bundle for gym 0.26
araffin Oct 31, 2022
bbdf8a6
Fix requirement
araffin Oct 31, 2022
d508eea
Update highway env
araffin Nov 14, 2022
8144345
Merge branch 'master' into feat/gym-0.24
araffin Nov 28, 2022
50fb759
Merge branch 'master' into feat/gym-0.24
araffin Dec 13, 2022
022e64b
Merge branch 'master' into feat/gym-0.24
araffin Dec 18, 2022
aad58ab
Fix panda env check
araffin Dec 18, 2022
edf185b
Pass render mode for no vel env
araffin Dec 19, 2022
70ddfb1
Declared exported methods
araffin Dec 19, 2022
a7bdc5e
Skip panda env for now
araffin Dec 19, 2022
d258dab
Fixes due to render_mode
araffin Dec 19, 2022
f272bad
Upgrade highway env version
araffin Dec 20, 2022
068adc8
Add gym patches
araffin Dec 22, 2022
54aa5d1
Another patch for gym
araffin Dec 22, 2022
0affb2c
Fix lint warning
araffin Dec 22, 2022
4e019e8
Patch Atari game video recording
araffin Dec 23, 2022
23f7c1a
Enable tests when display is available
araffin Dec 23, 2022
c54b547
Merge branch 'master' into feat/gym-0.24
araffin Jan 2, 2023
7bac0b3
Fix type checker
araffin Jan 2, 2023
2e1ce9b
Merge branch 'master' into feat/gym-0.24
araffin Jan 5, 2023
8a69bf5
Merge branch 'master' into feat/gym-0.24
araffin Feb 13, 2023
305e0cc
Remove gitlab file
araffin Feb 13, 2023
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
@@ -29,7 +29,7 @@ jobs:
run: |
python -m pip install --upgrade pip
# cpu version of pytorch - faster to download
- pip install torch==1.11+cpu -f https://download.pytorch.org/whl/torch_stable.html
+ pip install torch==1.11.0+cpu --extra-index-url https://download.pytorch.org/whl/cpu
pip install pybullet==3.1.9
pip install -r requirements.txt
# Use headless version
2 changes: 1 addition & 1 deletion .github/workflows/trained_agents.yml
@@ -29,7 +29,7 @@ jobs:
run: |
python -m pip install --upgrade pip
# cpu version of pytorch - faster to download
- pip install torch==1.11+cpu -f https://download.pytorch.org/whl/torch_stable.html
+ pip install torch==1.11.0+cpu --extra-index-url https://download.pytorch.org/whl/cpu
pip install pybullet==3.1.9
pip install -r requirements.txt
# Use headless version
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -1,6 +1,7 @@
## Release 1.7.0a1 (WIP)

### Breaking Changes
+- Upgraded to gym 0.24

### New Features
- Specifying custom policies in yaml file is now supported (@Rick-v-E)
2 changes: 1 addition & 1 deletion README.md
@@ -289,7 +289,7 @@ for multiple, specify a list:

```yaml
env_wrapper:
-  - rl_zoo3.wrappers.DoneOnSuccessWrapper:
+  - rl_zoo3.wrappers.TruncatedOnSuccessWrapper:
reward_offset: 1.0
- sb3_contrib.common.wrappers.TimeFeatureWrapper
```
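As a rough illustration of what that `env_wrapper` list amounts to, the wrappers are applied in order around the created env. This is a sketch only, not the zoo's actual construction code, and the goal-based env id is hypothetical (`FetchPush-v1` assumes a robotics env package is installed):

```python
# Illustrative only: rl_zoo3 builds this wrapper chain internally from the YAML.
import gym
from sb3_contrib.common.wrappers import TimeFeatureWrapper

from rl_zoo3.wrappers import TruncatedOnSuccessWrapper

env = gym.make("FetchPush-v1")  # hypothetical goal-based env id
env = TruncatedOnSuccessWrapper(env, reward_offset=1.0)  # first entry in the list
env = TimeFeatureWrapper(env)  # second entry
```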
2 changes: 1 addition & 1 deletion hyperparams/her.yml
@@ -59,7 +59,7 @@ FetchSlide-v1:
FetchPickAndPlace-v1:
env_wrapper:
- sb3_contrib.common.wrappers.TimeFeatureWrapper
-# - rl_zoo3.wrappers.DoneOnSuccessWrapper:
+# - rl_zoo3.wrappers.TruncatedOnSuccessWrapper:
# reward_offset: 0
# n_successes: 4
# - stable_baselines3.common.monitor.Monitor
2 changes: 1 addition & 1 deletion hyperparams/ppo.yml
@@ -317,7 +317,7 @@ MiniGrid-FourRooms-v0:
learning_rate: 2.5e-4
clip_range: 0.2

-CarRacing-v0:
+CarRacing-v1:
env_wrapper:
- rl_zoo3.wrappers.FrameSkip:
skip: 2
2 changes: 1 addition & 1 deletion hyperparams/sac.yml
@@ -161,7 +161,7 @@ MinitaurBulletDuckEnv-v0:
learning_starts: 10000

# To be tuned
-CarRacing-v0:
+CarRacing-v1:
env_wrapper:
- rl_zoo3.wrappers.FrameSkip:
skip: 2
17 changes: 10 additions & 7 deletions requirements.txt
@@ -1,20 +1,23 @@
-gym==0.21
-stable-baselines3[extra,tests,docs]>=1.6.2
-sb3-contrib>=1.6.2
+gym==0.24.1
+# stable-baselines3[extra,tests,docs]>=1.5.1a7
+git+https://github.com/carlosluis/stable-baselines3@fix_tests#egg=stable_baselines3[extra,tests,docs]
+# sb3-contrib>=1.5.0
+git+https://github.com/Stable-Baselines-Team/stable-baselines3-contrib@feat/new-gym-version
box2d-py==2.3.8
pybullet
gym-minigrid
-scikit-optimize
-optuna
+# scikit-optimize
+optuna~=2.10.1
pytablewriter~=0.64
pyyaml>=5.1
cloudpickle>=1.5.0
plotly
-panda-gym==1.1.1 # tmp fix: until compatibility with panda-gym v2
+panda-gym~=2.0.2
rliable>=1.0.5
wandb
ale-py==0.7.5
huggingface_sb3>=2.2.1, <3.*
-seaborn
+seaborn~=0.11.2
tqdm
rich
+importlib-metadata~=4.13 # flake8 not compatible with importlib-metadata>5.0
15 changes: 9 additions & 6 deletions rl_zoo3/enjoy.py
@@ -188,8 +188,10 @@ def enjoy(): # noqa: C901
"clip_range": lambda _: 0.0,
}

-model = ALGOS[algo].load(model_path, env=env, custom_objects=custom_objects, device=args.device, **kwargs)
+if "HerReplayBuffer" in hyperparams.get("replay_buffer_class", ""):
+kwargs["env"] = env

+model = ALGOS[algo].load(model_path, custom_objects=custom_objects, device=args.device, **kwargs)
obs = env.reset()

# Deterministic by default except for atari games
@@ -218,9 +220,9 @@ def enjoy(): # noqa: C901
episode_start=episode_start,
deterministic=deterministic,
)
-obs, reward, done, infos = env.step(action)
+obs, reward, termination, truncation, infos = env.step(action)

-episode_start = done
+episode_start = termination or truncation

if not args.no_render:
env.render("human")
@@ -236,8 +238,8 @@ def enjoy(): # noqa: C901
if episode_infos is not None:
print(f"Atari Episode Score: {episode_infos['r']:.2f}")
print("Atari Episode Length", episode_infos["l"])

-if done and not is_atari and args.verbose > 0:
+# TODO: episode_start is a confusing name here, should we rename to episode_end?
+if episode_start and not is_atari and args.verbose > 0:
# NOTE: for env using VecNormalize, the mean reward
# is a normalized reward when `--norm_reward` flag is passed
print(f"Episode Reward: {episode_reward:.2f}")
@@ -248,7 +250,8 @@ def enjoy(): # noqa: C901
ep_len = 0

# Reset also when the goal is achieved when using HER
-if done and infos[0].get("is_success") is not None:
+# TODO: episode_start is a confusing name here, should we rename to episode_end?
+if episode_start and infos[0].get("is_success") is not None:
if args.verbose > 1:
print("Success?", infos[0].get("is_success", False))

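For reference, the loop above follows the gym 0.26-style API, where `reset()` returns `(obs, info)` and `step()` returns five values. A minimal sketch of that convention on a plain, non-vectorized env, assuming gym>=0.26 and the built-in `CartPole-v1` (this is not the zoo's code, which goes through VecEnv):

```python
import gym

env = gym.make("CartPole-v1")
obs, info = env.reset(seed=0)  # reset now takes a seed and returns (obs, info)
done = False
while not done:
    action = env.action_space.sample()
    # gym>=0.26: five return values instead of (obs, reward, done, info)
    obs, reward, terminated, truncated, info = env.step(action)
    done = terminated or truncated  # the episode ends on either flag
env.close()
```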
6 changes: 5 additions & 1 deletion rl_zoo3/exp_manager.py
@@ -4,6 +4,7 @@
import time
import warnings
from collections import OrderedDict
+from copy import deepcopy
from pathlib import Path
from pprint import pprint
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
@@ -570,11 +571,14 @@ def create_envs(self, n_envs: int, eval_env: bool = False, no_log: bool = False)

# On most env, SubprocVecEnv does not help and is quite memory hungry
# therefore we use DummyVecEnv by default
+# Fix for gym 0.24, to keep old behavior
+env_kwargs = deepcopy(self.env_kwargs)
+env_kwargs.update(disable_env_checker=True)
env = make_vec_env(
env_id=self.env_name.gym_id,
n_envs=n_envs,
seed=self.seed,
-env_kwargs=self.env_kwargs,
+env_kwargs=env_kwargs,
monitor_dir=log_dir,
wrapper_class=self.env_wrapper,
vec_env_cls=self.vec_env_class,
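The `disable_env_checker=True` kwarg keeps the pre-0.24 behaviour: from gym 0.24 on, `gym.make()` wraps new envs in a passive env checker unless told otherwise. A rough sketch of the same idea outside `ExperimentManager`, assuming SB3's `make_vec_env` forwards `env_kwargs` to `gym.make()` (the empty `env_kwargs` dict stands in for `self.env_kwargs`):

```python
from copy import deepcopy

from stable_baselines3.common.env_util import make_vec_env

env_kwargs = {}  # stand-in for self.env_kwargs
env_kwargs = deepcopy(env_kwargs)  # copy so the stored kwargs are not mutated
env_kwargs.update(disable_env_checker=True)  # forwarded to gym.make()

vec_env = make_vec_env("CartPole-v1", n_envs=2, env_kwargs=env_kwargs)
```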
4 changes: 2 additions & 2 deletions rl_zoo3/record_video.py
@@ -153,8 +153,8 @@
episode_start=episode_starts,
deterministic=deterministic,
)
-obs, _, dones, _ = env.step(action)
-episode_starts = dones
+obs, _, terminated, truncated, _ = env.step(action)
+episode_starts = np.logical_or(terminated, truncated)
if not args.no_render:
env.render()
except KeyboardInterrupt:
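Unlike the scalar flags in `enjoy.py`, the recording script steps a vectorized env, so `terminated` and `truncated` come back as boolean arrays and `np.logical_or` combines them element-wise. A small illustration with made-up values:

```python
import numpy as np

# one entry per parallel environment
terminated = np.array([False, True, False])
truncated = np.array([False, False, True])  # e.g. stopped by a TimeLimit

episode_starts = np.logical_or(terminated, truncated)
print(episode_starts)  # [False  True  True]
```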
5 changes: 5 additions & 0 deletions rl_zoo3/utils.py
@@ -2,6 +2,7 @@
import glob
import importlib
import os
+from copy import deepcopy
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union

import gym
@@ -231,6 +232,10 @@ def create_test_env(
vec_env_cls = SubprocVecEnv
# start_method = 'spawn' for thread safe

+# Fix for gym 0.24, to keep old behavior
+env_kwargs = deepcopy(env_kwargs)
+env_kwargs.update(disable_env_checker=True)

env = make_vec_env(
env_id,
n_envs=n_envs,
48 changes: 26 additions & 22 deletions rl_zoo3/wrappers.py
@@ -1,9 +1,11 @@
+from typing import Optional

import gym
import numpy as np
from sb3_contrib.common.wrappers import TimeFeatureWrapper # noqa: F401 (backward compatibility)


-class DoneOnSuccessWrapper(gym.Wrapper):
+class TruncatedOnSuccessWrapper(gym.Wrapper):
"""
Reset on success and offsets the reward.
Useful for GoalEnv.
@@ -15,20 +17,21 @@ def __init__(self, env: gym.Env, reward_offset: float = 0.0, n_successes: int =
self.n_successes = n_successes
self.current_successes = 0

-def reset(self):
+def reset(self, seed: Optional[int] = None):
self.current_successes = 0
-return self.env.reset()
+kwargs = {} if seed is None else {"seed": seed}
+return self.env.reset(**kwargs)

def step(self, action):
-obs, reward, done, info = self.env.step(action)
+obs, reward, terminated, truncated, info = self.env.step(action)
if info.get("is_success", False):
self.current_successes += 1
else:
self.current_successes = 0
# number of successes in a row
-done = done or self.current_successes >= self.n_successes
+truncated = truncated or self.current_successes >= self.n_successes
reward += self.reward_offset
-return obs, reward, done, info
+return obs, reward, terminated, truncated, info

def compute_reward(self, achieved_goal, desired_goal, info):
reward = self.env.compute_reward(achieved_goal, desired_goal, info)
@@ -103,17 +106,17 @@ def reset(self):
return self.env.reset()

def step(self, action):
-obs, reward, done, info = self.env.step(action)
+obs, reward, terminated, truncated, info = self.env.step(action)

self.accumulated_reward += reward
self.current_step += 1

-if self.current_step % self.delay == 0 or done:
+if self.current_step % self.delay == 0 or terminated or truncated:
reward = self.accumulated_reward
self.accumulated_reward = 0.0
else:
reward = 0.0
-return obs, reward, done, info
+return obs, reward, terminated, truncated, info


class HistoryWrapper(gym.Wrapper):
@@ -155,24 +158,25 @@ def __init__(self, env: gym.Env, horizon: int = 2):
def _create_obs_from_history(self):
return np.concatenate((self.obs_history, self.action_history))

-def reset(self):
+def reset(self, seed: Optional[int] = None):
# Flush the history
self.obs_history[...] = 0
self.action_history[...] = 0
-obs = self.env.reset()
+kwargs = {} if seed is None else {"seed": seed}
+obs = self.env.reset(**kwargs)
self.obs_history[..., -obs.shape[-1] :] = obs
return self._create_obs_from_history()

def step(self, action):
-obs, reward, done, info = self.env.step(action)
+obs, reward, terminated, truncated, info = self.env.step(action)
last_ax_size = obs.shape[-1]

self.obs_history = np.roll(self.obs_history, shift=-last_ax_size, axis=-1)
self.obs_history[..., -obs.shape[-1] :] = obs

self.action_history = np.roll(self.action_history, shift=-action.shape[-1], axis=-1)
self.action_history[..., -action.shape[-1] :] = action
-return self._create_obs_from_history(), reward, done, info
+return self._create_obs_from_history(), reward, terminated, truncated, info


class HistoryWrapperObsDict(gym.Wrapper):
@@ -214,11 +218,12 @@ def __init__(self, env: gym.Env, horizon: int = 2):
def _create_obs_from_history(self):
return np.concatenate((self.obs_history, self.action_history))

-def reset(self):
+def reset(self, seed: Optional[int] = None):
# Flush the history
self.obs_history[...] = 0
self.action_history[...] = 0
-obs_dict = self.env.reset()
+kwargs = {} if seed is None else {"seed": seed}
+obs_dict = self.env.reset(**kwargs)
obs = obs_dict["observation"]
self.obs_history[..., -obs.shape[-1] :] = obs

@@ -227,7 +232,7 @@ def reset(self):
return obs_dict

def step(self, action):
-obs_dict, reward, done, info = self.env.step(action)
+obs_dict, reward, terminated, truncated, info = self.env.step(action)
obs = obs_dict["observation"]
last_ax_size = obs.shape[-1]

@@ -239,7 +244,7 @@ def reset(self):

obs_dict["observation"] = self._create_obs_from_history()

-return obs_dict, reward, done, info
+return obs_dict, reward, terminated, truncated, info


class FrameSkip(gym.Wrapper):
Expand All @@ -260,17 +265,16 @@ def step(self, action: np.ndarray):
Repeat action, sum reward.

:param action: the action
-:return: observation, reward, done, information
+:return: observation, reward, terminated, truncated, information
"""
total_reward = 0.0
-done = None
for _ in range(self._skip):
-obs, reward, done, info = self.env.step(action)
+obs, reward, terminated, truncated, info = self.env.step(action)
total_reward += reward
-if done:
+if terminated or truncated:
break

-return obs, total_reward, done, info
+return obs, total_reward, terminated, truncated, info

def reset(self):
return self.env.reset()
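The wrapper changes above all follow the same two patterns: `reset()` optionally forwards a `seed`, and `step()` passes the `(terminated, truncated)` pair through instead of a single `done`. A minimal sketch of that pattern under the gym>=0.26 API; the class name is invented for illustration:

```python
from typing import Optional

import gym


class PassThroughWrapper(gym.Wrapper):  # invented name, not part of rl_zoo3
    def reset(self, seed: Optional[int] = None):
        # forward the seed only when one was given, as the wrappers above do
        kwargs = {} if seed is None else {"seed": seed}
        return self.env.reset(**kwargs)

    def step(self, action):
        obs, reward, terminated, truncated, info = self.env.step(action)
        # a wrapper that ends episodes early (e.g. on success) should set
        # truncated rather than terminated, mirroring TruncatedOnSuccessWrapper
        return obs, reward, terminated, truncated, info
```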
1 change: 1 addition & 0 deletions tests/test_hyperparams_opt.py
@@ -47,6 +47,7 @@ def test_optimize(tmp_path, sampler, pruner, experiment):
args = ["-n", str(N_STEPS), "--algo", algo, "--env", env_id, "-params", 'policy_kwargs:"dict(net_arch=[32])"', "n_envs:1"]
args += ["n_steps:10"] if algo == "ppo" else []
args += [
"--no-optim-plots",
"--seed",
"14",
"--log-folder",