Merged
Changes from 1 commit
33 commits
87a30fb
Qdagger
vwxyzjn Jan 9, 2023
f315b7c
Merge branch 'master' into qdagger
sdpkjc May 6, 2023
691e4f3
Merge branch 'master' into qdagger
sdpkjc May 6, 2023
553f0d9
qdagger gym to gymnasium, fix bugs and get it working again
sdpkjc May 6, 2023
4ee3e03
add Student ReplayBuffer & epsilon_greedy in online phase
sdpkjc May 8, 2023
1910383
preliminary torch implementation
sdpkjc May 12, 2023
bfef85b
fix qdagger_dqn_atari_impalacnn
sdpkjc May 19, 2023
0eb0cca
add tests
sdpkjc May 19, 2023
3a766b3
add tests
sdpkjc May 19, 2023
0395f53
Merge branch 'master' into qdagger
sdpkjc May 19, 2023
ec8b455
fix set real_next_obs
sdpkjc May 21, 2023
dd795c8
update style
sdpkjc May 29, 2023
fd87abc
fix update jax target network in offline phase
sdpkjc May 29, 2023
7dffccf
Fix & Overhaul
sdpkjc May 30, 2023
9800615
Make rich to default package & Add specific algorithm dependencies env
sdpkjc May 30, 2023
e871113
fix requirements files
sdpkjc May 30, 2023
f20fad3
fix requirements files
sdpkjc May 30, 2023
d21a2ec
preliminary docs
sdpkjc Jun 5, 2023
e223cdd
fix docs
sdpkjc Jun 5, 2023
cd17955
Merge branch 'master' into qdagger
sdpkjc Jun 6, 2023
a1f88de
update docs link in script
sdpkjc Jun 7, 2023
6fe0d2d
add learning curves imgs
sdpkjc Jun 7, 2023
3cd410e
update experiment results
sdpkjc Jun 7, 2023
7443b50
enhance docs
sdpkjc Jun 8, 2023
6c6dded
add benchmark script
sdpkjc Jun 8, 2023
f787ea6
update docs
sdpkjc Jun 8, 2023
d9b206c
update docs
sdpkjc Jun 8, 2023
dd73a4b
auto set --teacher-policy-hf-repo
sdpkjc Jun 9, 2023
1a0a412
fix & update docs
sdpkjc Jun 9, 2023
9200f34
fix pre-commit
sdpkjc Jun 9, 2023
bf81e5d
fix auto set teacher_policy_hf_repo
sdpkjc Jun 9, 2023
4e2655f
Update qdagger.md
vwxyzjn Jun 9, 2023
6ac8f62
add compare plots
sdpkjc Jun 9, 2023
preliminary docs
sdpkjc committed Jun 5, 2023
commit d21a2ec15bad1309a41f94e94a8b340bc4a26167
23 changes: 23 additions & 0 deletions LICENSE
@@ -291,3 +291,26 @@ THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE.

--------------------------------------------------------------------------------

Code in `cleanrl/qdagger_dqn_atari_impalacnn.py` and `cleanrl/qdagger_dqn_atari_jax_impalacnn.py` is adapted from https://github.com/google-research/reincarnating_rl

**NOTE**: the original repo did not fill out the copyright section in its license, so the following copyright notice is copied as-is per the license requirement.
See https://github.com/google-research/reincarnating_rl/blob/a1d402f48a9f8658ca6aa0ddf416ab391745ff2c/LICENSE#L189


Copyright [yyyy] [name of copyright owner]

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
2 changes: 2 additions & 0 deletions README.md
@@ -157,6 +157,8 @@ You may also use a prebuilt development environment hosted in Gitpod:
| | [`td3_continuous_action_jax.py`](https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/td3_continuous_action_jax.py), [docs](https://docs.cleanrl.dev/rl-algorithms/td3/#td3_continuous_action_jaxpy) |
| ✅ [Phasic Policy Gradient (PPG)](https://arxiv.org/abs/2009.04416) | [`ppg_procgen.py`](https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/ppg_procgen.py), [docs](https://docs.cleanrl.dev/rl-algorithms/ppg/#ppg_procgenpy) |
| ✅ [Random Network Distillation (RND)](https://arxiv.org/abs/1810.12894) | [`ppo_rnd_envpool.py`](https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/ppo_rnd_envpool.py), [docs](/rl-algorithms/ppo-rnd/#ppo_rnd_envpoolpy) |
| ✅ [Qdagger](https://arxiv.org/abs/2206.01626) | [`qdagger_dqn_atari_impalacnn.py`](https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/qdagger_dqn_atari_impalacnn.py), [docs](https://docs.cleanrl.dev/rl-algorithms/qdagger/#qdagger_dqn_atari_impalacnnpy) |
| | [`qdagger_dqn_atari_jax_impalacnn.py`](https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/qdagger_dqn_atari_jax_impalacnn.py), [docs](https://docs.cleanrl.dev/rl-algorithms/qdagger/#qdagger_dqn_atari_jax_impalacnnpy) |


## Open RL Benchmark
2 changes: 2 additions & 0 deletions docs/rl-algorithms/overview.md
@@ -29,3 +29,5 @@
| | :material-github: [`td3_continuous_action_jax.py`](https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/td3_continuous_action_jax.py), :material-file-document: [docs](/rl-algorithms/td3/#td3_continuous_action_jaxpy) |
| ✅ [Phasic Policy Gradient (PPG)](https://arxiv.org/abs/2009.04416) | :material-github: [`ppg_procgen.py`](https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/ppg_procgen.py), :material-file-document: [docs](/rl-algorithms/ppg/#ppg_procgenpy) |
| ✅ [Random Network Distillation (RND)](https://arxiv.org/abs/1810.12894) | :material-github: [`ppo_rnd_envpool.py`](https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/ppo_rnd_envpool.py), :material-file-document: [docs](/rl-algorithms/ppo-rnd/#ppo_rnd_envpoolpy) |
| ✅ [Qdagger](https://arxiv.org/abs/2206.01626) | :material-github: [`qdagger_dqn_atari_impalacnn.py`](https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/qdagger_dqn_atari_impalacnn.py), :material-file-document: [docs](https://docs.cleanrl.dev/rl-algorithms/qdagger/#qdagger_dqn_atari_impalacnnpy) |
| | :material-github: [`qdagger_dqn_atari_jax_impalacnn.py`](https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/qdagger_dqn_atari_jax_impalacnn.py), :material-file-document: [docs](https://docs.cleanrl.dev/rl-algorithms/qdagger/#qdagger_dqn_atari_jax_impalacnnpy) |
125 changes: 125 additions & 0 deletions docs/rl-algorithms/qdagger.md
@@ -0,0 +1,125 @@
# Qdagger

## Overview

QDagger is a *reincarnating reinforcement learning* method: instead of training an agent from scratch, it reuses prior computation, namely a suboptimal teacher policy and, optionally, the teacher's replay data, to accelerate learning. The student agent is first trained in an offline phase that combines the standard TD loss with a Dagger-style distillation loss toward the teacher's policy; training then continues online while the distillation loss is annealed away.


Original papers:

* [Reincarnating Reinforcement Learning: Reusing Prior Computation to Accelerate Progress](https://arxiv.org/abs/2206.01626)
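As a rough sketch of the idea above (the function names, the distillation weight `lam`, and the temperature `tau` below are illustrative, not CleanRL's exact implementation): the student minimizes the usual TD loss plus a weighted cross-entropy between the teacher's and student's softmax policies over Q-values.

```python
import numpy as np

def softmax(q, tau):
    """Boltzmann policy over Q-values with temperature tau."""
    z = q / tau - (q / tau).max(axis=-1, keepdims=True)  # for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def qdagger_loss(td_loss, student_q, teacher_q, lam=1.0, tau=1.0):
    """TD loss plus a Dagger-style distillation term toward the teacher policy.

    Hypothetical sketch: in QDagger, lam is annealed toward 0 as the
    student starts to match or exceed the teacher's performance.
    """
    teacher_pi = softmax(teacher_q, tau)                           # (batch, n_actions)
    student_log_pi = np.log(softmax(student_q, tau) + 1e-8)
    distill = -(teacher_pi * student_log_pi).sum(axis=-1).mean()   # cross-entropy
    return td_loss + lam * distill
```

With `lam=0` this reduces to plain Q-learning; a large `lam` pulls the student toward imitating the teacher.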

## Implemented Variants

| Variants Implemented | Description |
| ----------- | ----------- |
| :material-github: [`qdagger_dqn_atari_impalacnn.py`](https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/qdagger_dqn_atari_impalacnn.py), :material-file-document: [docs](/rl-algorithms/qdagger/#qdagger_dqn_atari_impalacnnpy) | For playing Atari games. It uses convolutional layers and common atari-based pre-processing techniques. |
| :material-github: [`qdagger_dqn_atari_jax_impalacnn.py`](https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/qdagger_dqn_atari_jax_impalacnn.py), :material-file-document: [docs](/rl-algorithms/qdagger/#qdagger_dqn_atari_jax_impalacnnpy) | For playing Atari games. It uses convolutional layers and common atari-based pre-processing techniques. |


Below are our single-file implementations of Qdagger:


## `qdagger_dqn_atari_impalacnn.py`

The [qdagger_dqn_atari_impalacnn.py](https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/qdagger_dqn_atari_impalacnn.py) has the following features:

* For playing Atari games. It uses convolutional layers and common atari-based pre-processing techniques.
* Works with Atari's pixel `Box` observation space of shape `(210, 160, 3)`
* Works with the `Discrete` action space

### Usage


=== "poetry"

```bash
poetry install -E atari
poetry run python cleanrl/qdagger_dqn_atari_impalacnn.py --env-id BreakoutNoFrameskip-v4 --teacher-policy-hf-repo cleanrl/BreakoutNoFrameskip-v4-dqn_atari-seed1
poetry run python cleanrl/qdagger_dqn_atari_impalacnn.py --env-id PongNoFrameskip-v4 --teacher-policy-hf-repo cleanrl/PongNoFrameskip-v4-dqn_atari-seed1
```

=== "pip"

```bash
pip install -r requirements/requirements-atari.txt
python cleanrl/qdagger_dqn_atari_impalacnn.py --env-id BreakoutNoFrameskip-v4 --teacher-policy-hf-repo cleanrl/BreakoutNoFrameskip-v4-dqn_atari-seed1
python cleanrl/qdagger_dqn_atari_impalacnn.py --env-id PongNoFrameskip-v4 --teacher-policy-hf-repo cleanrl/PongNoFrameskip-v4-dqn_atari-seed1
```


### Explanation of the logged metrics

Running `python cleanrl/qdagger_dqn_atari_impalacnn.py` will automatically record various metrics such as episodic returns and losses in TensorBoard. Below is the documentation for these metrics:

* `charts/episodic_return`: episodic return of the game
* `charts/SPS`: number of steps per second
* `losses/td_loss`: the mean squared error (MSE) between the Q values at timestep $t$ and the Bellman update target estimated using the reward $r_t$ and the Q values at timestep $t+1$, thus minimizing the *one-step* temporal difference. Formally, it can be expressed by the equation below.
$$
J(\theta^{Q}) = \mathbb{E}_{(s,a,r,s') \sim \mathcal{D}} \big[ (Q(s, a) - y)^2 \big],
$$
where the Bellman update target is $y = r + \gamma \, \max_{a'} Q^{'}(s', a')$ and $\mathcal{D}$ is the replay buffer.
* `losses/q_values`: implemented as `qf1(data.observations, data.actions).view(-1)`, it is the average Q value of the sampled data in the replay buffer; useful for gauging whether under- or overestimation happens.
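For concreteness, the one-step TD target in the `losses/td_loss` equation can be computed as in this minimal sketch (the function name and array shapes are illustrative, not CleanRL's exact code):

```python
import numpy as np

def td_target(rewards, next_q_values, dones, gamma=0.99):
    """One-step TD target y = r + gamma * max_a' Q'(s', a'),
    with bootstrapping disabled on terminal transitions."""
    # next_q_values: (batch, n_actions) from the target network Q'
    return rewards + gamma * next_q_values.max(axis=1) * (1.0 - dones)
```

The `(1.0 - dones)` factor zeroes out the bootstrap term when `s'` is terminal, so the target collapses to the raw reward.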


### Implementation details

WIP

### Experiment results

WIP

## `qdagger_dqn_atari_jax_impalacnn.py`


The [qdagger_dqn_atari_jax_impalacnn.py](https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/qdagger_dqn_atari_jax_impalacnn.py) has the following features:

* Uses [JAX](https://github.com/google/jax), [Flax](https://github.com/google/flax), and [Optax](https://github.com/deepmind/optax) instead of `torch`. [qdagger_dqn_atari_jax_impalacnn.py](https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/qdagger_dqn_atari_jax_impalacnn.py) is roughly 25%-50% faster than [qdagger_dqn_atari_impalacnn.py](https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/qdagger_dqn_atari_impalacnn.py).
* For playing Atari games. It uses convolutional layers and common atari-based pre-processing techniques.
* Works with Atari's pixel `Box` observation space of shape `(210, 160, 3)`
* Works with the `Discrete` action space

### Usage


=== "poetry"

```bash
poetry install -E "atari jax"
poetry run pip install --upgrade "jax[cuda]==0.3.17" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
poetry run python cleanrl/qdagger_dqn_atari_jax_impalacnn.py --env-id BreakoutNoFrameskip-v4 --teacher-policy-hf-repo cleanrl/BreakoutNoFrameskip-v4-dqn_atari_jax-seed1
poetry run python cleanrl/qdagger_dqn_atari_jax_impalacnn.py --env-id PongNoFrameskip-v4 --teacher-policy-hf-repo cleanrl/PongNoFrameskip-v4-dqn_atari_jax-seed1
```

=== "pip"

```bash
pip install -r requirements/requirements-atari.txt
pip install -r requirements/requirements-jax.txt
pip install --upgrade "jax[cuda]==0.3.17" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
python cleanrl/qdagger_dqn_atari_jax_impalacnn.py --env-id BreakoutNoFrameskip-v4 --teacher-policy-hf-repo cleanrl/BreakoutNoFrameskip-v4-dqn_atari_jax-seed1
python cleanrl/qdagger_dqn_atari_jax_impalacnn.py --env-id PongNoFrameskip-v4 --teacher-policy-hf-repo cleanrl/PongNoFrameskip-v4-dqn_atari_jax-seed1
```


???+ warning

Note that JAX does not work on Windows :fontawesome-brands-windows:. The official [docs](https://github.com/google/jax#installation) recommend using Windows Subsystem for Linux (WSL) to install JAX.

### Explanation of the logged metrics

See [related docs](/rl-algorithms/qdagger/#explanation-of-the-logged-metrics) for `qdagger_dqn_atari_impalacnn.py`.

### Implementation details

See [related docs](/rl-algorithms/qdagger/#implementation-details) for `qdagger_dqn_atari_impalacnn.py`.

### Experiment results

WIP
1 change: 1 addition & 0 deletions mkdocs.yml
@@ -48,6 +48,7 @@ nav:
- rl-algorithms/ppg.md
- rl-algorithms/ppo-rnd.md
- rl-algorithms/rpo.md
- rl-algorithms/qdagger.md
- Advanced:
- advanced/hyperparameter-tuning.md
- advanced/resume-training.md