From 54e661ddcb91a32099e5489279b90c37acfbe9b8 Mon Sep 17 00:00:00 2001 From: Costa Huang Date: Tue, 27 Dec 2022 21:14:41 -0500 Subject: [PATCH] chore: simplify ppo gae code --- cleanrl/ppo_atari_envpool_xla_jax_scan.py | 6 ++++-- tests/test_jax_compute_gae.py | 6 ++++-- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/cleanrl/ppo_atari_envpool_xla_jax_scan.py b/cleanrl/ppo_atari_envpool_xla_jax_scan.py index 9cbbe67fb..7f374cef1 100644 --- a/cleanrl/ppo_atari_envpool_xla_jax_scan.py +++ b/cleanrl/ppo_atari_envpool_xla_jax_scan.py @@ -323,8 +323,10 @@ def compute_gae( _, advantages = jax.lax.scan( compute_gae_once, advantages, (dones[1:], values[1:], values[:-1], storage.rewards), reverse=True ) - storage = storage.replace(advantages=advantages) - storage = storage.replace(returns=storage.advantages + storage.values) + storage = storage.replace( + advantages=advantages, + returns=storage.advantages + storage.values, + ) return storage def ppo_loss(params, x, a, logp, mb_advantages, mb_returns): diff --git a/tests/test_jax_compute_gae.py b/tests/test_jax_compute_gae.py index de0763fa0..9f4f0f739 100644 --- a/tests/test_jax_compute_gae.py +++ b/tests/test_jax_compute_gae.py @@ -39,8 +39,10 @@ def compute_gae_scan( _, advantages = jax.lax.scan( compute_gae_once_fn, advantages, (dones[1:], values[1:], values[:-1], storage.rewards), reverse=True ) - storage = storage.replace(advantages=advantages) - storage = storage.replace(returns=storage.advantages + storage.values) + storage = storage.replace( + advantages=advantages, + returns=storage.advantages + storage.values, + ) return storage def compute_gae_python_loop(