Pull request: Closed · 108 commits
1b585d6 initial commit (vwxyzjn, Oct 13, 2022)
fa82356 pre-commit (vwxyzjn, Oct 13, 2022)
4074eee Add hub integration (vwxyzjn, Oct 13, 2022)
4436ce4 pre-commit (vwxyzjn, Oct 14, 2022)
df41e3d use CommitOperation (vwxyzjn, Oct 18, 2022)
a98383d Fix pre-commit (vwxyzjn, Oct 18, 2022)
b430540 refactor (vwxyzjn, Oct 18, 2022)
dd8ee86 Merge branch 'master' into hf-integration (vwxyzjn, Oct 18, 2022)
8144562 push changes (vwxyzjn, Oct 27, 2022)
2f20e17 refactor (vwxyzjn, Oct 27, 2022)
fdfc2a5 fix pre-commit (vwxyzjn, Nov 16, 2022)
56413f8 pre-commit (vwxyzjn, Nov 16, 2022)
b1b1dbd Merge branch 'master' into hf-integration (vwxyzjn, Nov 16, 2022)
f6865d4 close the env and writer after eval (vwxyzjn, Nov 16, 2022)
fbe986c support dqn jax (vwxyzjn, Nov 17, 2022)
83aa010 pre-commit (vwxyzjn, Nov 17, 2022)
ba1bfdb Update cleanrl_utils/huggingface.py (vwxyzjn, Nov 17, 2022)
aee6809 address comments (vwxyzjn, Nov 17, 2022)
80a460f update docs (vwxyzjn, Nov 17, 2022)
40be7d8 support dqn_atari_jax (vwxyzjn, Dec 10, 2022)
65ded2a bug fix and docs (vwxyzjn, Dec 13, 2022)
133e6bd Add cleanrl to the hf's `metadata` (vwxyzjn, Dec 13, 2022)
10d0b79 Merge branch 'master' into hf-integration (vwxyzjn, Dec 15, 2022)
ca60f24 include huggingface integration (vwxyzjn, Dec 15, 2022)
b165e35 test for enjoy.py (vwxyzjn, Dec 15, 2022)
7163d0d bump version, pip install extra hack (vwxyzjn, Dec 15, 2022)
27d9b3d Update cleanrl_utils/huggingface.py (vwxyzjn, Dec 16, 2022)
2a2208f Update cleanrl_utils/huggingface.py (vwxyzjn, Dec 16, 2022)
4ac5631 Update cleanrl_utils/huggingface.py (vwxyzjn, Dec 16, 2022)
40358b1 Update cleanrl_utils/huggingface.py (vwxyzjn, Dec 16, 2022)
df68d57 Update cleanrl_utils/huggingface.py (vwxyzjn, Dec 16, 2022)
7dddfbd Update cleanrl_utils/huggingface.py (vwxyzjn, Dec 16, 2022)
954723f update docs (vwxyzjn, Dec 16, 2022)
fb858ae update pre-commit (vwxyzjn, Dec 16, 2022)
b508f66 quick fix (vwxyzjn, Dec 16, 2022)
7d5193b bug fix (vwxyzjn, Dec 16, 2022)
c390b8d lazy load modules to avoid dependency issues (vwxyzjn, Dec 20, 2022)
cc456d6 Add huggingface shields (vwxyzjn, Dec 20, 2022)
fd5a737 Add emoji (vwxyzjn, Dec 20, 2022)
3b0af25 Update docs (vwxyzjn, Dec 20, 2022)
ff0be11 pre-commit (vwxyzjn, Dec 20, 2022)
9bd034e Update docs (vwxyzjn, Dec 20, 2022)
78022d7 Update docs (vwxyzjn, Dec 20, 2022)
aae8d4d Merge branch 'master' into hf-integration (kinalmehta, Dec 30, 2022)
1c2cd40 fix: use `algorithm_variant_filename` in model card reproduction script (kinalmehta, Dec 31, 2022)
e172a0c typo fix (kinalmehta, Dec 31, 2022)
c733514 feat: add hf support for c51 (kinalmehta, Dec 31, 2022)
15be698 formatting fix (kinalmehta, Dec 31, 2022)
8fac8e3 support pulling variant depdencies directly (vwxyzjn, Dec 31, 2022)
35d6fc7 support model saving for `ppo_atari_envpool_xla_jax_scan` (vwxyzjn, Dec 31, 2022)
1ce42c9 Merge branch 'master' into hf-integration (vwxyzjn, Dec 31, 2022)
8990794 support `ppo_atari_envpool_xla_jax_scan` (vwxyzjn, Jan 1, 2023)
ea4a71d quick change (vwxyzjn, Jan 1, 2023)
091b5a6 PPO with machado Atari preprocessing (vwxyzjn, Jan 1, 2023)
e641b1f black (vwxyzjn, Jan 1, 2023)
001337d update benchmark script (vwxyzjn, Jan 1, 2023)
e88cdbc update benchmark script (vwxyzjn, Jan 1, 2023)
3ecc3f5 change the default frames to 200M (50M steps) (vwxyzjn, Jan 1, 2023)
b897317 deal with truncation properly (vwxyzjn, Jan 2, 2023)
67cab32 add truncation and termination metrics (vwxyzjn, Jan 2, 2023)
9a1486a add epidose length stats (vwxyzjn, Jan 2, 2023)
3a90d67 handle truncation and video recording better (vwxyzjn, Jan 3, 2023)
3ada13b push changes (vwxyzjn, Jan 26, 2023)
9c1d220 things kind of work now (vwxyzjn, Jan 26, 2023)
92f1769 stats (vwxyzjn, Jan 26, 2023)
38be44e pre-commit (vwxyzjn, Jan 26, 2023)
c2b18b5 pmap worked (vwxyzjn, Jan 27, 2023)
ab732a6 test multi-devices (vwxyzjn, Jan 27, 2023)
27dd034 cache (vwxyzjn, Jan 27, 2023)
c1f7301 delete (vwxyzjn, Jan 27, 2023)
73ae15c set some common setting (vwxyzjn, Jan 27, 2023)
38c14ec allow `params-queue-timeout` (vwxyzjn, Jan 27, 2023)
aa37525 update some experiments (vwxyzjn, Jan 28, 2023)
1d85943 [WIP] support multithreads per actor (vwxyzjn, Jan 29, 2023)
907bdbd improve profile (vwxyzjn, Jan 29, 2023)
f9ee991 refactor order (vwxyzjn, Jan 29, 2023)
a93a1f5 call `jax.device_put_sharded` in the learner (vwxyzjn, Jan 29, 2023)
0eefb50 use `jax.put_device_replicated` (vwxyzjn, Jan 29, 2023)
0dd591c Warm up the learner before unleashing actor (vwxyzjn, Feb 4, 2023)
6fb29d1 add model eval, add slurm friendly setting (vwxyzjn, Feb 5, 2023)
c6705a1 args check (vwxyzjn, Feb 5, 2023)
d6eab1b Merge branch 'master' into ppo_machado (vwxyzjn, Feb 5, 2023)
c7de59c update poetry version (vwxyzjn, Feb 5, 2023)
94d09df Empty-Commit (Feb 5, 2023)
c3fd478 trigger ci (vwxyzjn, Feb 5, 2023)
c27a532 fix problem (vwxyzjn, Feb 5, 2023)
18b7458 bump poetry version (vwxyzjn, Feb 5, 2023)
c3c831d Merge branch 'ppo_machado' of https://github.com/vwxyzjn/cleanrl into… (Feb 5, 2023)
c224e19 utility improvement (vwxyzjn, Feb 5, 2023)
c92d3e1 add eval components (vwxyzjn, Feb 5, 2023)
ecdf2e1 deal with episode truncations in eval (vwxyzjn, Feb 5, 2023)
be26f77 IMPALA atari wrapper setting (vwxyzjn, Feb 5, 2023)
464072d hack: ensure the program always ends (vwxyzjn, Feb 6, 2023)
4e93341 workaround for sail-sg/envpool#239 (vwxyzjn, Feb 8, 2023)
4d7895d no longer need timeout (vwxyzjn, Feb 8, 2023)
9c46850 make sure the program ends (vwxyzjn, Feb 9, 2023)
52e2638 push changes (vwxyzjn, Feb 10, 2023)
e61a1b0 support `jax.distributed` (vwxyzjn, Feb 16, 2023)
8f29c82 distributed support (vwxyzjn, Feb 16, 2023)
2657f59 pre-commit (vwxyzjn, Feb 16, 2023)
100ff51 put data transfer in the learner (vwxyzjn, Feb 17, 2023)
abfe835 quick change (vwxyzjn, Feb 17, 2023)
2a89940 slurm support (vwxyzjn, Feb 17, 2023)
eab04f0 quick fix (vwxyzjn, Feb 17, 2023)
85c6aac push changes (vwxyzjn, Feb 17, 2023)
42a800b quick change (vwxyzjn, Feb 17, 2023)
9b52acb fix distributed bug (vwxyzjn, Feb 19, 2023)
753a9cb pre-commit (vwxyzjn, Feb 19, 2023)
Commit 8f29c82d48e25878674e67c4a7595feec1a01a29: distributed support (vwxyzjn, committed Feb 16, 2023)
@@ -1,17 +1,3 @@
"""
* 🥼 Test throughput (see docs):
* python sebulba_ppo_envpool.py --actor-device-ids 0 --learner-device-ids 1 --profile --test-actor-learner-throughput --total-timesteps 500000 --track
* python sebulba_ppo_envpool.py --actor-device-ids 0 --learner-device-ids 1 2 --profile --test-actor-learner-throughput --total-timesteps 500000 --track
* this will help us diagnose the throughput issue
* python sebulba_ppo_envpool.py --actor-device-ids 0 --num-actor-threads 2 --learner-device-ids 1 --profile --test-actor-learner-throughput --total-timesteps 500000 --track
* 🔥 Best performance so far (more GPUs -> faster)
* python sebulba_ppo_envpool.py --actor-device-ids 0 --learner-device-ids 0 --track
* python sebulba_ppo_envpool.py --actor-device-ids 0 --learner-device-ids 0 1 --track
* python sebulba_ppo_envpool.py --actor-device-ids 0 --learner-device-ids 1 2 3 --track
* python sebulba_ppo_envpool.py --actor-device-ids 0 --learner-device-ids 1 2 3 4 --track
* python sebulba_ppo_envpool.py --actor-device-ids 0 --learner-device-ids 1 2 3 4 5 6 --track
* (this actually doesn't work that well) python sebulba_ppo_envpool.py --actor-device-ids 0 --learner-device-ids 1 2 3 4 5 6 7 --num-envs 70 --async-batch-size 35 --track
"""
# docs and experiment results can be found at https://docs.cleanrl.dev/rl-algorithms/ppo/#ppo_atari_envpool_async_jax_scan_impalanet_machadopy
import argparse
import os
@@ -108,25 +94,21 @@ def parse_args():
help="the target KL divergence threshold")

parser.add_argument("--actor-device-ids", type=int, nargs="+", default=[0], # type is actually List[int]
help="the device ids that actor workers will use")
help="the device ids that actor workers will use (currently only support 1 device)")
parser.add_argument("--learner-device-ids", type=int, nargs="+", default=[0], # type is actually List[int]
help="the device ids that actor workers will use")
parser.add_argument("--num-actor-threads", type=int, default=1,
help="the number of actor threads")
parser.add_argument("--profile", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
help="whether to call block_until_ready() for profiling")
help="the device ids that learner workers will use")
parser.add_argument("--distributed", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
help="whether to use `jax.distributed`")
parser.add_argument("--profile", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
help="whether to call block_until_ready() for profiling")
parser.add_argument("--test-actor-learner-throughput", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
help="whether to test actor-learner throughput by removing the actor-learner communication")
args = parser.parse_args()
args.batch_size = int(args.num_envs * args.num_steps)
args.minibatch_size = int(args.batch_size // args.num_minibatches)
args.num_updates = args.total_timesteps // args.batch_size
args.local_batch_size = int(args.num_envs * args.num_steps)
args.local_minibatch_size = int(args.local_batch_size // args.num_minibatches)
args.num_updates = args.total_timesteps // args.local_batch_size
args.async_update = int(args.num_envs / args.async_batch_size)
assert len(args.actor_device_ids) == 1, "only 1 actor_device_ids is supported now"
if args.num_actor_threads > 1:
warnings.warn("⚠️ !!!! `num_actor_threads` > 1 is not tested with learning; see docs for detail")
# fmt: on
return args
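The hunk above replaces the old global `batch_size` fields with per-process `local_*` sizes that are later scaled by `world_size`. A minimal sketch of that arithmetic, with illustrative values in place of the script's CLI arguments:

```python
# Sketch of the local vs. world batch-size bookkeeping in parse_args().
# The concrete numbers are assumptions for illustration only.
num_envs = 64            # envs per process
num_steps = 128          # rollout length
num_minibatches = 4
total_timesteps = 50_000_000
world_size = 2           # jax.process_count() when --distributed is set

local_batch_size = int(num_envs * num_steps)
local_minibatch_size = int(local_batch_size // num_minibatches)
world_batch_size = local_batch_size * world_size
num_updates = total_timesteps // (local_batch_size * world_size)

print(local_batch_size, local_minibatch_size, world_batch_size, num_updates)
```

Note that `num_updates` divides by the world batch size, so adding processes shortens the run in updates while keeping the total environment steps fixed.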

@@ -136,14 +118,12 @@ def parse_args():
) # 108000 is the max number of frames in an Atari game, divided by 4 to account for frame skipping


def make_env(env_id, seed, num_envs, async_batch_size=1, num_threads=None, thread_affinity_offset=-1):
def make_env(env_id, seed, num_envs, async_batch_size=1):
def thunk():
envs = envpool.make(
env_id,
env_type="gym",
num_envs=num_envs,
num_threads=num_threads if num_threads is not None else async_batch_size,
thread_affinity_offset=thread_affinity_offset,
batch_size=async_batch_size,
episodic_life=True, # Espeholt et al., 2018, Tab. G.1
repeat_action_probability=0, # Hessel et al., 2022 (Muesli) Tab. 10
@@ -304,17 +284,14 @@ def f(carry, x):


def rollout(
i,
num_threads, # =None,
thread_affinity_offset, # =-1,
key: jax.random.PRNGKey,
args,
rollout_queue,
params_queue: queue.Queue,
writer,
learner_devices,
):
envs = make_env(args.env_id, args.seed, args.num_envs, args.async_batch_size, num_threads, thread_affinity_offset)()
envs = make_env(args.env_id, args.seed, args.num_envs, args.async_batch_size)()
len_actor_device_ids = len(args.actor_device_ids)
global_step = 0
# TRY NOT TO MODIFY: start the game
@@ -370,7 +347,7 @@ def rollout(
env_recv_time_start = time.time()
next_obs, next_reward, next_done, info = envs.recv()
env_recv_time += time.time() - env_recv_time_start
global_step += len(next_done) * args.num_actor_threads * len_actor_device_ids * args.world_size
global_step += len(next_done) * len_actor_device_ids * args.world_size
env_id = info["env_id"]

inference_time_start = time.time()
@@ -413,9 +390,8 @@ def rollout(
avg_episodic_return = np.mean(returned_episode_returns)
writer.add_scalar("charts/avg_episodic_return", avg_episodic_return, global_step)
writer.add_scalar("charts/avg_episodic_length", np.mean(returned_episode_lengths), global_step)
if i == 0:
print(f"global_step={global_step}, avg_episodic_return={avg_episodic_return}")
print("SPS:", int(global_step / (time.time() - start_time)))
print(f"global_step={global_step}, avg_episodic_return={avg_episodic_return}")
print("SPS:", int(global_step / (time.time() - start_time)))
writer.add_scalar("charts/SPS", int(global_step / (time.time() - start_time)), global_step)

writer.add_scalar("stats/truncations", np.sum(truncations), global_step)
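With `num_actor_threads` removed, the step counter now scales only by the actor-device count and `world_size`. A toy sketch of that accounting (the sizes are assumptions, and the env interaction is stubbed out):

```python
import time

# Each envs.recv() returns a batch of `async_batch_size` transitions; the
# global step counter scales that by actor-device count and world size.
len_actor_device_ids = 1   # assumed: one actor device
world_size = 2             # assumed: two distributed processes
async_batch_size = 16

global_step = 0
start_time = time.time()
for _ in range(100):  # stand-in for 100 envs.recv() calls
    next_done = [False] * async_batch_size
    global_step += len(next_done) * len_actor_device_ids * world_size

# Steps per second, guarded against a zero elapsed time in this toy loop.
sps = int(global_step / max(time.time() - start_time, 1e-9))
```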
@@ -460,7 +436,6 @@ def rollout(
int(
args.num_envs
* args.num_steps
* args.num_actor_threads
* len_actor_device_ids
* args.world_size
/ (time.time() - update_time_start)
@@ -644,9 +619,9 @@ def update_minibatch(agent_state, minibatch):
args.world_size = jax.process_count()
args.local_rank = jax.process_index()
args.world_num_envs = args.num_envs * args.world_size
args.world_batch_size = args.batch_size * args.world_size
args.world_minibatch_size = args.minibatch_size * args.world_size
args.num_updates = args.total_timesteps // (args.batch_size * args.world_size)
args.world_batch_size = args.local_batch_size * args.world_size
args.world_minibatch_size = args.local_minibatch_size * args.world_size
args.num_updates = args.total_timesteps // (args.local_batch_size * args.world_size)
args.async_update = int(args.num_envs / args.async_batch_size)
local_devices = jax.local_devices()
global_devices = jax.devices()
@@ -706,7 +681,7 @@ def linear_schedule(count):
),
tx=optax.chain(
optax.clip_by_global_norm(args.max_grad_norm),
optax.adam(
optax.inject_hyperparams(optax.adam)(
learning_rate=linear_schedule if args.anneal_lr else args.learning_rate, eps=1e-5
),
),
@@ -724,33 +699,21 @@ def parse_args():

rollout_queue = queue.Queue(maxsize=1)
params_queues = []
num_cpus = mp.cpu_count()
fair_num_cpus = num_cpus // len(args.actor_device_ids)

class DummyWriter:
def add_scalar(self, arg0, arg1, arg3):
pass

dummy_writer = DummyWriter()
for d_idx, d_id in enumerate(args.actor_device_ids):
for j in range(args.num_actor_threads):
params_queue = queue.Queue(maxsize=1)
params_queue.put(jax.device_put(flax.jax_utils.unreplicate(agent_state.params), local_devices[d_id]))
threading.Thread(
target=rollout,
args=(
j,
fair_num_cpus if args.num_actor_threads > 1 else None,
j * args.num_actor_threads if args.num_actor_threads > 1 else -1,
jax.device_put(key, local_devices[d_id]),
args,
rollout_queue,
params_queue,
writer if d_idx == 0 and j == 0 else dummy_writer,
learner_devices,
),
).start()
params_queues.append(params_queue)
params_queue = queue.Queue(maxsize=1)
params_queue.put(jax.device_put(flax.jax_utils.unreplicate(agent_state.params), local_devices[d_id]))
threading.Thread(
target=rollout,
args=(
jax.device_put(key, local_devices[d_id]),
args,
rollout_queue,
params_queue,
writer,
learner_devices,
),
).start()
params_queues.append(params_queue)
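The single-thread spawn above follows a simple hand-off pattern: a size-1 `params_queue` per actor device carries fresh parameters in, and a shared `rollout_queue` carries trajectories back. A stdlib-only sketch of that pattern (the names mirror the script, but the payloads are stand-ins, not the real params or rollouts):

```python
import queue
import threading

def rollout(params_queue, rollout_queue, num_updates):
    # Actor thread: block for the newest params, then hand back a rollout.
    for _ in range(num_updates):
        params = params_queue.get()
        rollout_queue.put(("rollout computed with", params))

params_queue = queue.Queue(maxsize=1)   # learner -> actor
rollout_queue = queue.Queue(maxsize=1)  # actor -> learner
thread = threading.Thread(target=rollout, args=(params_queue, rollout_queue, 3))
thread.start()

collected = []
for version in range(3):                 # learner loop
    params_queue.put({"policy_version": version})
    collected.append(rollout_queue.get())
thread.join()
```

The `maxsize=1` queues enforce lock-step freshness: the learner cannot publish params faster than the actor consumes them, and the actor cannot run ahead of the learner.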

rollout_queue_get_time = deque(maxlen=10)
learner_policy_version = 0
@@ -784,10 +747,9 @@ def add_scalar(self, arg0, arg1, arg3):
)
if learner_policy_version == 1 or not args.test_actor_learner_throughput:
for d_idx, d_id in enumerate(args.actor_device_ids):
for j in range(args.num_actor_threads):
params_queues[d_idx * args.num_actor_threads + j].put(
jax.device_put(flax.jax_utils.unreplicate(agent_state.params), local_devices[d_id])
)
params_queues[d_idx].put(
jax.device_put(flax.jax_utils.unreplicate(agent_state.params), local_devices[d_id])
)
if args.profile:
v_loss[-1, -1, -1].block_until_ready()
writer.add_scalar("stats/training_time", time.time() - training_time_start, global_step)
@@ -799,7 +761,7 @@ def add_scalar(self, arg0, arg1, arg3):
)

# TRY NOT TO MODIFY: record rewards for plotting purposes
# writer.add_scalar("charts/learning_rate", agent_state.opt_state[1].hyperparams["learning_rate"][0].item(), global_step)
writer.add_scalar("charts/learning_rate", agent_state.opt_state[1].hyperparams["learning_rate"][0].item(), global_step)
writer.add_scalar("losses/value_loss", v_loss[-1, -1, -1].item(), global_step)
writer.add_scalar("losses/policy_loss", pg_loss[-1, -1, -1].item(), global_step)
writer.add_scalar("losses/entropy", entropy_loss[-1, -1, -1].item(), global_step)