Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
42 commits
Select commit Hold shift + click to select a range
18b643b
add draft of SAC discrete implementation
timoklein Aug 29, 2022
c3c98bd
run pre-commit
timoklein Aug 29, 2022
ec31dc4
Use log softmax instead of author's log-pi code
timoklein Aug 31, 2022
deb37e8
Revert to cleanrl SAC delay implementation (it's more stable)
timoklein Aug 31, 2022
a1fdd2b
Remove docstrings and duplicate code
timoklein Aug 31, 2022
977a83a
Use correct clipreward wrapper
timoklein Aug 31, 2022
f2ea3e6
fix bug in log softmax calculation
timoklein Sep 6, 2022
48af04c
adhere to cleanrl log_prob naming
timoklein Sep 6, 2022
b2a09a0
fix bug in entropy target calculation
timoklein Sep 6, 2022
89680c7
change layer initialization to match existing cleanrl codebase
timoklein Sep 6, 2022
b1d7d44
working minimal diff version
timoklein Sep 19, 2022
61e1c74
implement original learning update frequency
timoklein Sep 20, 2022
7cd1e3a
parameterize the entropy scale for autotuning
timoklein Oct 3, 2022
61c46fc
add benchmarking script
timoklein Oct 3, 2022
4915e4c
rename target entropy factor and set new default value
timoklein Oct 6, 2022
6f7251f
add docs draft
timoklein Nov 5, 2022
23b60ff
fix SAC-discrete links to work pre merge
timoklein Nov 10, 2022
10ee9f0
add preliminary result table for SAC-discrete
timoklein Nov 10, 2022
8430fd8
clean up todos and add header
timoklein Nov 10, 2022
a17768c
minimize diff between sac_atari and sac_continuous
timoklein Nov 11, 2022
d6a507c
add sac-discrete end2end test
timoklein Nov 11, 2022
a7ea6f4
SAC-discrete docs rework
timoklein Nov 11, 2022
9f6493c
Update SAC-discrete @100k results
timoklein Nov 12, 2022
59a6d00
Fix doc links and unify naming in code
timoklein Nov 12, 2022
1304b7a
update docs
vwxyzjn Nov 13, 2022
3a3f41b
fix target update frequency (see PR #323)
timoklein Nov 24, 2022
80187ad
clarify comment regarding CNN encoder sharing
timoklein Nov 24, 2022
e9cb494
Merge remote-tracking branch 'upstream/master' into sac-discrete
timoklein Nov 25, 2022
e199e39
fix benchmark installation
timoklein Nov 25, 2022
bb27fa1
fix eps in minimal diff version and improve code readability
timoklein Dec 3, 2022
6a46632
add docs for eps and finalize code
timoklein Dec 5, 2022
cad5fff
use no_grad for actor Q-vals and re-use action-probs & log-probs in a…
timoklein Dec 7, 2022
0cf47f1
update docs for new code and settings
timoklein Dec 14, 2022
61988c4
fix links to point to main branch
timoklein Dec 14, 2022
6e17005
update sac-discrete training plots
timoklein Dec 19, 2022
33b00f3
new sac-d training plots
timoklein Dec 19, 2022
5dabafb
update results table and fix link
timoklein Dec 19, 2022
90b2fd5
fix pong chart title
timoklein Dec 19, 2022
a763994
add Jimmy Ba name as exception to code spell check
timoklein Jan 13, 2023
071cdbb
change target_entropy_scale default value to same value as experiments
timoklein Jan 13, 2023
dcc2633
Merge remote-tracking branch 'upstream/master' into sac-discrete
timoklein Jan 13, 2023
c671a92
remove blank line at end of pre-commit
timoklein Jan 13, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
use no_grad for actor Q-vals and re-use action-probs & log-probs in a…
…lpha loss
  • Loading branch information
timoklein committed Dec 7, 2022
commit cad5fff474bf43e9463ccea1e5e42e14f1395db9
15 changes: 7 additions & 8 deletions cleanrl/sac_atari.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def parse_args():
help="Entropy regularization coefficient.")
parser.add_argument("--autotune", type=lambda x:bool(strtobool(x)), default=True, nargs="?", const=True,
help="automatic tuning of the entropy coefficient")
parser.add_argument("--target-entropy-scale", type=float, default=0.88,
parser.add_argument("--target-entropy-scale", type=float, default=0.9,
help="coefficient for scaling the autotune entropy target")
args = parser.parse_args()
# fmt: on
Expand Down Expand Up @@ -298,9 +298,10 @@ def get_action(self, x):

# ACTOR training
_, log_pi, action_probs = actor.get_action(data.observations)
qf1_values = qf1(data.observations)
qf2_values = qf2(data.observations)
min_qf_values = torch.min(qf1_values, qf2_values)
with torch.no_grad():
qf1_values = qf1(data.observations)
qf2_values = qf2(data.observations)
min_qf_values = torch.min(qf1_values, qf2_values)
# no need for reparameterization, the expectation can be calculated for discrete actions
actor_loss = (action_probs * ((alpha * log_pi) - min_qf_values)).mean()

Expand All @@ -309,10 +310,8 @@ def get_action(self, x):
actor_optimizer.step()

if args.autotune:
# use action probabilities for temperature loss
with torch.no_grad():
_, log_pi, action_probs = actor.get_action(data.observations)
alpha_loss = (action_probs * (-log_alpha * (log_pi + target_entropy))).mean()
# re-use action probabilities for temperature loss
alpha_loss = (action_probs.detach() * (-log_alpha * (log_pi + target_entropy).detach())).mean()

a_optimizer.zero_grad()
alpha_loss.backward()
Expand Down
2 changes: 1 addition & 1 deletion docs/rl-algorithms/sac.md
Original file line number Diff line number Diff line change
Expand Up @@ -354,7 +354,7 @@ Surpassing Human-Level Performance on ImageNet Classification"](https://arxiv.or
alpha = log_alpha.exp().item()
```

4. [`sac_atari.py`](https://github.com/timoklein/cleanrl/blob/sac-discrete/cleanrl/sac_atari.py) uses `--target-entropy-scale=0.88` while the [SAC-discrete paper](https://arxiv.org/abs/1910.07207) uses `--target-entropy-scale=0.98` due to improved stability when training for more than 100k steps. Tuning this parameter to the environment at hand is advised and can lead to significant performance gains.
4. [`sac_atari.py`](https://github.com/timoklein/cleanrl/blob/sac-discrete/cleanrl/sac_atari.py) uses `--target-entropy-scale=0.9` while the [SAC-discrete paper](https://arxiv.org/abs/1910.07207) uses `--target-entropy-scale=0.98` due to improved stability when training for more than 100k steps. Tuning this parameter to the environment at hand is advised and can lead to significant performance gains.

5. [`sac_atari.py`](https://github.com/timoklein/cleanrl/blob/sac-discrete/cleanrl/sac_atari.py) performs learning updates only on every $n^{\text{th}}$ step. This leads to improved stability and prevents the agent's performance from degenerating during longer training runs.
Note the difference to [`sac_continuous_action.py`](https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/sac_continuous_action.py): [`sac_atari.py`](https://github.com/timoklein/cleanrl/blob/sac-discrete/cleanrl/sac_atari.py) updates every $n^{\text{th}}$ environment step and does a single update of actor and critic on every update step. [`sac_continuous_action.py`](https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/sac_continuous_action.py) updates the critic every step and the actor every $n^{\text{th}}$ step. It then compensates for the delayed actor updates by performing $n$ actor update steps.
Expand Down