Merged (changes from 2 commits)
cleanrl/dqn.py: 33 changes (19 additions, 14 deletions)
@@ -5,7 +5,7 @@
import time
from distutils.util import strtobool

import gym
import gymnasium as gym
import numpy as np
import torch
import torch.nn as nn
@@ -48,6 +48,8 @@ def parse_args():
help="total timesteps of the experiments")
parser.add_argument("--learning-rate", type=float, default=2.5e-4,
help="the learning rate of the optimizer")
parser.add_argument("--num-envs", type=int, default=1,
help="the number of parallel game environments")
parser.add_argument("--buffer-size", type=int, default=10000,
help="the replay memory buffer size")
parser.add_argument("--gamma", type=float, default=0.99,
@@ -80,9 +82,7 @@ def thunk():
if capture_video:
if idx == 0:
env = gym.wrappers.RecordVideo(env, f"videos/{run_name}")
env.seed(seed)
env.action_space.seed(seed)
env.observation_space.seed(seed)
return env

return thunk
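Note: Gymnasium drops `Env.seed()`, which is why the three `env.seed()`/`*_space.seed()` calls disappear from the thunk; reproducibility now flows through `reset(seed=...)` on the vector env below. A rough sketch of a Gymnasium-style thunk, not this PR's exact code (the `RecordEpisodeStatistics` wrapper and the `render_mode` handling are assumptions inferred from the `info["episode"]` logging further down):

```python
import gymnasium as gym

def make_env(env_id, seed, idx, capture_video, run_name):
    def thunk():
        # render_mode must be "rgb_array" for RecordVideo to work in Gymnasium
        env = gym.make(env_id, render_mode="rgb_array" if capture_video and idx == 0 else None)
        if capture_video and idx == 0:
            env = gym.wrappers.RecordVideo(env, f"videos/{run_name}")
        env = gym.wrappers.RecordEpisodeStatistics(env)  # fills info["episode"] for logging
        # env.seed() no longer exists; only the action space (used by epsilon-greedy
        # sampling) still has its own seed() method worth calling here.
        env.action_space.seed(seed)
        return env

    return thunk
```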
@@ -139,8 +139,11 @@ def linear_schedule(start_e: float, end_e: float, duration: int, t: int):
device = torch.device("cuda" if torch.cuda.is_available() and args.cuda else "cpu")

# env setup
envs = gym.vector.SyncVectorEnv([make_env(args.env_id, args.seed, 0, args.capture_video, run_name)])
envs = gym.vector.SyncVectorEnv(
[make_env(args.env_id, args.seed + i, i, args.capture_video, run_name) for i in range(args.num_envs)]
)
assert isinstance(envs.single_action_space, gym.spaces.Discrete), "only discrete action space is supported"
assert envs.num_envs == 1, "vectorized envs are not supported at the moment"
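Note: with `--num-envs` exposed, each thunk receives an offset seed so sub-environments do not replay identical trajectories, while the new assert still pins `num_envs` to 1 for DQN. A minimal sketch of the vectorized setup and the new reset semantics (hypothetical env id and seed values; `reset(seed=int)` seeds sub-env `i` with `seed + i`):

```python
import gymnasium as gym

seed, num_envs = 1, 1  # num_envs must stay 1 while the assert above holds
envs = gym.vector.SyncVectorEnv([lambda: gym.make("CartPole-v1") for _ in range(num_envs)])
assert isinstance(envs.single_action_space, gym.spaces.Discrete)

obs, info = envs.reset(seed=seed)  # Gymnasium reset returns (obs, info)
print(obs.shape)                   # (num_envs, 4) for CartPole-v1
```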

q_network = QNetwork(envs).to(device)
optimizer = optim.Adam(q_network.parameters(), lr=args.learning_rate)
@@ -152,12 +155,12 @@ def linear_schedule(start_e: float, end_e: float, duration: int, t: int):
envs.single_observation_space,
envs.single_action_space,
device,
handle_timeout_termination=True,
handle_timeout_termination=False,
)
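Note: `handle_timeout_termination=True` relies on SB3's old-gym convention of reading `info["TimeLimit.truncated"]` inside `ReplayBuffer.add`; with Gymnasium the script gets truncation directly from `envs.step()` and patches the terminal observation itself, so the flag is switched off. A sketch of the buffer construction under that assumption (requires an SB3 release that accepts Gymnasium spaces):

```python
import gymnasium as gym
import torch
from stable_baselines3.common.buffers import ReplayBuffer

envs = gym.vector.SyncVectorEnv([lambda: gym.make("CartPole-v1")])
rb = ReplayBuffer(
    10_000,                            # --buffer-size
    envs.single_observation_space,
    envs.single_action_space,
    torch.device("cpu"),
    handle_timeout_termination=False,  # truncation is handled manually at rb.add() time
)
```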
start_time = time.time()

# TRY NOT TO MODIFY: start the game
obs = envs.reset()
obs, _ = envs.reset(seed=args.seed)
for global_step in range(args.total_timesteps):
# ALGO LOGIC: put action logic here
epsilon = linear_schedule(args.start_e, args.end_e, args.exploration_fraction * args.total_timesteps, global_step)
@@ -168,23 +171,25 @@ def linear_schedule(start_e: float, end_e: float, duration: int, t: int):
actions = torch.argmax(q_values, dim=1).cpu().numpy()

# TRY NOT TO MODIFY: execute the game and log data.
next_obs, rewards, dones, infos = envs.step(actions)
next_obs, rewards, terminated, truncated, infos = envs.step(actions)
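Note: the old 4-tuple `(obs, reward, done, info)` becomes a 5-tuple in Gymnasium, splitting `done` into `terminated` (true MDP termination) and `truncated` (time-limit cut-off). A small self-contained sketch of the new API, with random actions standing in for the Q-network:

```python
import gymnasium as gym
import numpy as np

envs = gym.vector.SyncVectorEnv([lambda: gym.make("CartPole-v1")])
obs, _ = envs.reset(seed=0)
actions = envs.action_space.sample()  # random stand-in for argmax over Q-values
next_obs, rewards, terminated, truncated, infos = envs.step(actions)

# terminated: the MDP actually ended (e.g. the pole fell) -> do not bootstrap
# truncated:  the episode was cut off by a time limit -> still bootstrap
episode_over = np.logical_or(terminated, truncated)
```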

# TRY NOT TO MODIFY: record rewards for plotting purposes
for info in infos:
if "episode" in info.keys():
if "final_info" in infos:
for info in infos["final_info"]:
# Skip the envs that are not done
if info is None:
continue
print(f"global_step={global_step}, episodic_return={info['episode']['r']}")
writer.add_scalar("charts/episodic_return", info["episode"]["r"], global_step)
writer.add_scalar("charts/episodic_length", info["episode"]["l"], global_step)
writer.add_scalar("charts/epsilon", epsilon, global_step)
break
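Note: the old `for info in infos:` loop breaks because Gymnasium vector envs return `infos` as one dict of batched values rather than a list of per-env dicts; finished episodes surface through `infos["final_info"]`, with `None` entries for sub-envs that are still running. A rough end-to-end sketch of that pattern, assuming the autoreset convention this PR targets and `RecordEpisodeStatistics` applied inside the thunk:

```python
import gymnasium as gym

envs = gym.vector.SyncVectorEnv(
    [lambda: gym.wrappers.RecordEpisodeStatistics(gym.make("CartPole-v1"))]
)
obs, _ = envs.reset(seed=0)
for _ in range(500):
    obs, rewards, terminated, truncated, infos = envs.step(envs.action_space.sample())
    if "final_info" in infos:
        for info in infos["final_info"]:
            if info is None or "episode" not in info:
                continue  # this sub-env did not finish on this step
            print("episodic_return:", info["episode"]["r"])
```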

# TRY NOT TO MODIFY: save data to replay buffer; handle `terminal_observation`
# TRY NOT TO MODIFY: save data to replay buffer; handle `final_observation`
real_next_obs = next_obs.copy()
for idx, d in enumerate(dones):
for idx, d in enumerate(truncated):
if d:
real_next_obs[idx] = infos[idx]["terminal_observation"]
rb.add(obs, real_next_obs, actions, rewards, dones, infos)
real_next_obs[idx] = infos["final_observation"][idx]
rb.add(obs, real_next_obs, actions, rewards, terminated, infos)
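Note: when a sub-env finishes, the vector env autoresets it before `step()` returns, so `next_obs` already holds the first observation of the next episode; the real last observation only survives in `infos["final_observation"]`. Only truncated envs are patched because terminated transitions have their next state masked out of the TD target by `(1 - terminated)` anyway, and storing `terminated` (not `truncated`) as the done flag keeps time-limit endings bootstrapped. A commented restatement of that logic as a helper (a sketch, not the PR's code):

```python
def store_transition(rb, obs, next_obs, actions, rewards, terminated, truncated, infos):
    """Store one vectorized transition, patching autoreset observations."""
    real_next_obs = next_obs.copy()
    for idx, trunc in enumerate(truncated):
        if trunc:
            # recover the true terminal observation hidden by the autoreset
            real_next_obs[idx] = infos["final_observation"][idx]
    # Only true termination is stored as done, so the TD target
    # r + gamma * max_a Q_target(s', a) * (1 - done) still bootstraps
    # through time-limit truncations.
    rb.add(obs, real_next_obs, actions, rewards, terminated, infos)
```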

# TRY NOT TO MODIFY: CRUCIAL step easy to overlook
obs = next_obs
cleanrl/dqn_jax.py: 35 changes (20 additions, 15 deletions)
@@ -7,7 +7,7 @@

import flax
import flax.linen as nn
import gym
import gymnasium as gym
import jax
import jax.numpy as jnp
import numpy as np
@@ -46,6 +46,8 @@ def parse_args():
help="total timesteps of the experiments")
parser.add_argument("--learning-rate", type=float, default=2.5e-4,
help="the learning rate of the optimizer")
parser.add_argument("--num-envs", type=int, default=1,
help="the number of parallel game environments")
parser.add_argument("--buffer-size", type=int, default=10000,
help="the replay memory buffer size")
parser.add_argument("--gamma", type=float, default=0.99,
@@ -78,9 +80,7 @@ def thunk():
if capture_video:
if idx == 0:
env = gym.wrappers.RecordVideo(env, f"videos/{run_name}")
env.seed(seed)
env.action_space.seed(seed)
env.observation_space.seed(seed)
return env

return thunk
@@ -137,10 +137,13 @@ def linear_schedule(start_e: float, end_e: float, duration: int, t: int):
key, q_key = jax.random.split(key, 2)

# env setup
envs = gym.vector.SyncVectorEnv([make_env(args.env_id, args.seed, 0, args.capture_video, run_name)])
envs = gym.vector.SyncVectorEnv(
[make_env(args.env_id, args.seed + i, i, args.capture_video, run_name) for i in range(args.num_envs)]
)
assert isinstance(envs.single_action_space, gym.spaces.Discrete), "only discrete action space is supported"
assert envs.num_envs == 1, "vectorized envs are not supported yet"

obs = envs.reset()
obs, _ = envs.reset(seed=args.seed)

q_network = QNetwork(action_dim=envs.single_action_space.n)
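Note: in the Jax script, `envs.reset()` now happens before the network is built because Flax initializes parameters lazily from a sample input, so the batched observation from the first Gymnasium reset doubles as the shape probe. A standalone sketch (the tiny network below is illustrative, not the diff's `QNetwork`):

```python
import flax.linen as nn
import gymnasium as gym
import jax

class TinyQNetwork(nn.Module):
    action_dim: int

    @nn.compact
    def __call__(self, x):
        x = nn.relu(nn.Dense(64)(x))
        return nn.Dense(self.action_dim)(x)

envs = gym.vector.SyncVectorEnv([lambda: gym.make("CartPole-v1")])
obs, _ = envs.reset(seed=0)                           # shape (1, 4)
q_network = TinyQNetwork(action_dim=envs.single_action_space.n)
params = q_network.init(jax.random.PRNGKey(0), obs)   # shapes inferred from obs
q_values = q_network.apply(params, obs)               # shape (1, 2)
```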

@@ -160,7 +163,7 @@ def linear_schedule(start_e: float, end_e: float, duration: int, t: int):
envs.single_observation_space,
envs.single_action_space,
"cpu",
handle_timeout_termination=True,
handle_timeout_termination=False,
)

@jax.jit
@@ -181,7 +184,7 @@ def mse_loss(params):
start_time = time.time()

# TRY NOT TO MODIFY: start the game
obs = envs.reset()
obs, _ = envs.reset(seed=args.seed)
for global_step in range(args.total_timesteps):
# ALGO LOGIC: put action logic here
epsilon = linear_schedule(args.start_e, args.end_e, args.exploration_fraction * args.total_timesteps, global_step)
@@ -193,23 +196,25 @@ def mse_loss(params):
actions = jax.device_get(actions)

# TRY NOT TO MODIFY: execute the game and log data.
next_obs, rewards, dones, infos = envs.step(actions)
next_obs, rewards, terminated, truncated, infos = envs.step(actions)

# TRY NOT TO MODIFY: record rewards for plotting purposes
for info in infos:
if "episode" in info.keys():
if "final_info" in infos:
for info in infos["final_info"]:
# Skip the envs that are not done
if info is None:
continue
print(f"global_step={global_step}, episodic_return={info['episode']['r']}")
writer.add_scalar("charts/episodic_return", info["episode"]["r"], global_step)
writer.add_scalar("charts/episodic_length", info["episode"]["l"], global_step)
writer.add_scalar("charts/epsilon", epsilon, global_step)
break

# TRY NOT TO MODIFY: save data to replay buffer; handle `terminal_observation`
# TRY NOT TO MODIFY: save data to replay buffer; handle `final_observation`
real_next_obs = next_obs.copy()
for idx, d in enumerate(dones):
for idx, d in enumerate(truncated):
if d:
real_next_obs[idx] = infos[idx]["terminal_observation"]
rb.add(obs, real_next_obs, actions, rewards, dones, infos)
real_next_obs[idx] = infos["final_observation"][idx]
rb.add(obs, real_next_obs, actions, rewards, terminated, infos)

# TRY NOT TO MODIFY: CRUCIAL step easy to overlook
obs = next_obs