71 changes: 36 additions & 35 deletions cleanrl/c51.py
@@ -197,41 +197,42 @@ def linear_schedule(start_e: float, end_e: float, duration: int, t: int):
         obs = next_obs

         # ALGO LOGIC: training.
-        if global_step > args.learning_starts and global_step % args.train_frequency == 0:
-            data = rb.sample(args.batch_size)
-            with torch.no_grad():
-                _, next_pmfs = target_network.get_action(data.next_observations)
-                next_atoms = data.rewards + args.gamma * target_network.atoms * (1 - data.dones)
-                # projection
-                delta_z = target_network.atoms[1] - target_network.atoms[0]
-                tz = next_atoms.clamp(args.v_min, args.v_max)
-
-                b = (tz - args.v_min) / delta_z
-                l = b.floor().clamp(0, args.n_atoms - 1)
-                u = b.ceil().clamp(0, args.n_atoms - 1)
-                # (l == u).float() handles the case where bj is exactly an integer
-                # example bj = 1, then the upper ceiling should be uj= 2, and lj= 1
-                d_m_l = (u + (l == u).float() - b) * next_pmfs
-                d_m_u = (b - l) * next_pmfs
-                target_pmfs = torch.zeros_like(next_pmfs)
-                for i in range(target_pmfs.size(0)):
-                    target_pmfs[i].index_add_(0, l[i].long(), d_m_l[i])
-                    target_pmfs[i].index_add_(0, u[i].long(), d_m_u[i])
-
-            _, old_pmfs = q_network.get_action(data.observations, data.actions.flatten())
-            loss = (-(target_pmfs * old_pmfs.clamp(min=1e-5, max=1 - 1e-5).log()).sum(-1)).mean()
-
-            if global_step % 100 == 0:
-                writer.add_scalar("losses/loss", loss.item(), global_step)
-                old_val = (old_pmfs * q_network.atoms).sum(1)
-                writer.add_scalar("losses/q_values", old_val.mean().item(), global_step)
-                print("SPS:", int(global_step / (time.time() - start_time)))
-                writer.add_scalar("charts/SPS", int(global_step / (time.time() - start_time)), global_step)
-
-            # optimize the model
-            optimizer.zero_grad()
-            loss.backward()
-            optimizer.step()
+        if global_step > args.learning_starts:
+            if global_step % args.train_frequency == 0:
+                data = rb.sample(args.batch_size)
+                with torch.no_grad():
+                    _, next_pmfs = target_network.get_action(data.next_observations)
+                    next_atoms = data.rewards + args.gamma * target_network.atoms * (1 - data.dones)
+                    # projection
+                    delta_z = target_network.atoms[1] - target_network.atoms[0]
+                    tz = next_atoms.clamp(args.v_min, args.v_max)
+
+                    b = (tz - args.v_min) / delta_z
+                    l = b.floor().clamp(0, args.n_atoms - 1)
+                    u = b.ceil().clamp(0, args.n_atoms - 1)
+                    # (l == u).float() handles the case where bj is exactly an integer
+                    # example bj = 1, then the upper ceiling should be uj= 2, and lj= 1
+                    d_m_l = (u + (l == u).float() - b) * next_pmfs
+                    d_m_u = (b - l) * next_pmfs
+                    target_pmfs = torch.zeros_like(next_pmfs)
+                    for i in range(target_pmfs.size(0)):
+                        target_pmfs[i].index_add_(0, l[i].long(), d_m_l[i])
+                        target_pmfs[i].index_add_(0, u[i].long(), d_m_u[i])
+
+                _, old_pmfs = q_network.get_action(data.observations, data.actions.flatten())
+                loss = (-(target_pmfs * old_pmfs.clamp(min=1e-5, max=1 - 1e-5).log()).sum(-1)).mean()
+
+                if global_step % 100 == 0:
+                    writer.add_scalar("losses/loss", loss.item(), global_step)
+                    old_val = (old_pmfs * q_network.atoms).sum(1)
+                    writer.add_scalar("losses/q_values", old_val.mean().item(), global_step)
+                    print("SPS:", int(global_step / (time.time() - start_time)))
+                    writer.add_scalar("charts/SPS", int(global_step / (time.time() - start_time)), global_step)
+
+                # optimize the model
+                optimizer.zero_grad()
+                loss.backward()
+                optimizer.step()

         # update the target network
         if global_step % args.target_network_frequency == 0:
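The projection loop above distributes each shifted atom's probability mass onto its two neighbouring support atoms, so the target distribution stays on the fixed grid. A minimal standalone sketch of that computation on a toy five-atom support; the constants and variable names below are illustrative only, not taken from c51.py:

import torch

# Toy support: 5 atoms between v_min = 0 and v_max = 4 (made-up values).
v_min, v_max, n_atoms, gamma = 0.0, 4.0, 5, 0.99
atoms = torch.linspace(v_min, v_max, n_atoms)  # [0, 1, 2, 3, 4]
delta_z = atoms[1] - atoms[0]                  # 1.0

# One transition with reward 1.0 and a uniform next-state distribution.
reward, done = 1.0, 0.0
next_pmf = torch.full((n_atoms,), 1.0 / n_atoms)

# Shift and clip the support, exactly as in the diff.
next_atoms = reward + gamma * atoms * (1 - done)
tz = next_atoms.clamp(v_min, v_max)
b = (tz - v_min) / delta_z
l = b.floor().clamp(0, n_atoms - 1)
u = b.ceil().clamp(0, n_atoms - 1)

# Split each atom's mass between its lower and upper neighbours;
# the (l == u) term keeps the mass when b lands exactly on an atom.
d_m_l = (u + (l == u).float() - b) * next_pmf
d_m_u = (b - l) * next_pmf
target_pmf = torch.zeros(n_atoms)
target_pmf.index_add_(0, l.long(), d_m_l)
target_pmf.index_add_(0, u.long(), d_m_u)
print(target_pmf, target_pmf.sum())  # still sums to 1.0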
71 changes: 36 additions & 35 deletions cleanrl/c51_atari.py
@@ -219,41 +219,42 @@ def linear_schedule(start_e: float, end_e: float, duration: int, t: int):
         obs = next_obs

         # ALGO LOGIC: training.
-        if global_step > args.learning_starts and global_step % args.train_frequency == 0:
-            data = rb.sample(args.batch_size)
-            with torch.no_grad():
-                _, next_pmfs = target_network.get_action(data.next_observations)
-                next_atoms = data.rewards + args.gamma * target_network.atoms * (1 - data.dones)
-                # projection
-                delta_z = target_network.atoms[1] - target_network.atoms[0]
-                tz = next_atoms.clamp(args.v_min, args.v_max)
-
-                b = (tz - args.v_min) / delta_z
-                l = b.floor().clamp(0, args.n_atoms - 1)
-                u = b.ceil().clamp(0, args.n_atoms - 1)
-                # (l == u).float() handles the case where bj is exactly an integer
-                # example bj = 1, then the upper ceiling should be uj= 2, and lj= 1
-                d_m_l = (u + (l == u).float() - b) * next_pmfs
-                d_m_u = (b - l) * next_pmfs
-                target_pmfs = torch.zeros_like(next_pmfs)
-                for i in range(target_pmfs.size(0)):
-                    target_pmfs[i].index_add_(0, l[i].long(), d_m_l[i])
-                    target_pmfs[i].index_add_(0, u[i].long(), d_m_u[i])
-
-            _, old_pmfs = q_network.get_action(data.observations, data.actions.flatten())
-            loss = (-(target_pmfs * old_pmfs.clamp(min=1e-5, max=1 - 1e-5).log()).sum(-1)).mean()
-
-            if global_step % 100 == 0:
-                writer.add_scalar("losses/loss", loss.item(), global_step)
-                old_val = (old_pmfs * q_network.atoms).sum(1)
-                writer.add_scalar("losses/q_values", old_val.mean().item(), global_step)
-                print("SPS:", int(global_step / (time.time() - start_time)))
-                writer.add_scalar("charts/SPS", int(global_step / (time.time() - start_time)), global_step)
-
-            # optimize the model
-            optimizer.zero_grad()
-            loss.backward()
-            optimizer.step()
+        if global_step > args.learning_starts:
+            if global_step % args.train_frequency == 0:
+                data = rb.sample(args.batch_size)
+                with torch.no_grad():
+                    _, next_pmfs = target_network.get_action(data.next_observations)
+                    next_atoms = data.rewards + args.gamma * target_network.atoms * (1 - data.dones)
+                    # projection
+                    delta_z = target_network.atoms[1] - target_network.atoms[0]
+                    tz = next_atoms.clamp(args.v_min, args.v_max)
+
+                    b = (tz - args.v_min) / delta_z
+                    l = b.floor().clamp(0, args.n_atoms - 1)
+                    u = b.ceil().clamp(0, args.n_atoms - 1)
+                    # (l == u).float() handles the case where bj is exactly an integer
+                    # example bj = 1, then the upper ceiling should be uj= 2, and lj= 1
+                    d_m_l = (u + (l == u).float() - b) * next_pmfs
+                    d_m_u = (b - l) * next_pmfs
+                    target_pmfs = torch.zeros_like(next_pmfs)
+                    for i in range(target_pmfs.size(0)):
+                        target_pmfs[i].index_add_(0, l[i].long(), d_m_l[i])
+                        target_pmfs[i].index_add_(0, u[i].long(), d_m_u[i])
+
+                _, old_pmfs = q_network.get_action(data.observations, data.actions.flatten())
+                loss = (-(target_pmfs * old_pmfs.clamp(min=1e-5, max=1 - 1e-5).log()).sum(-1)).mean()
+
+                if global_step % 100 == 0:
+                    writer.add_scalar("losses/loss", loss.item(), global_step)
+                    old_val = (old_pmfs * q_network.atoms).sum(1)
+                    writer.add_scalar("losses/q_values", old_val.mean().item(), global_step)
+                    print("SPS:", int(global_step / (time.time() - start_time)))
+                    writer.add_scalar("charts/SPS", int(global_step / (time.time() - start_time)), global_step)
+
+                # optimize the model
+                optimizer.zero_grad()
+                loss.backward()
+                optimizer.step()

         # update the target network
         if global_step % args.target_network_frequency == 0:
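In both c51 files the restructuring leaves the training branch firing on exactly the same steps; splitting the combined condition into nested ifs only introduces an outer global_step > args.learning_starts scope that later logic can share. A tiny sketch of that equivalence, with made-up counter values:

learning_starts, train_frequency, total_steps = 3, 2, 10

def trains_combined(step):
    # old structure: one combined condition
    return step > learning_starts and step % train_frequency == 0

def trains_nested(step):
    # new structure: nested conditions
    if step > learning_starts:
        if step % train_frequency == 0:
            return True
    return False

assert all(trains_combined(s) == trains_nested(s) for s in range(total_steps))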
37 changes: 19 additions & 18 deletions cleanrl/dqn.py
@@ -182,24 +182,25 @@ def linear_schedule(start_e: float, end_e: float, duration: int, t: int):
         obs = next_obs

         # ALGO LOGIC: training.
-        if global_step > args.learning_starts and global_step % args.train_frequency == 0:
-            data = rb.sample(args.batch_size)
-            with torch.no_grad():
-                target_max, _ = target_network(data.next_observations).max(dim=1)
-                td_target = data.rewards.flatten() + args.gamma * target_max * (1 - data.dones.flatten())
-            old_val = q_network(data.observations).gather(1, data.actions).squeeze()
-            loss = F.mse_loss(td_target, old_val)
-
-            if global_step % 100 == 0:
-                writer.add_scalar("losses/td_loss", loss, global_step)
-                writer.add_scalar("losses/q_values", old_val.mean().item(), global_step)
-                print("SPS:", int(global_step / (time.time() - start_time)))
-                writer.add_scalar("charts/SPS", int(global_step / (time.time() - start_time)), global_step)
-
-            # optimize the model
-            optimizer.zero_grad()
-            loss.backward()
-            optimizer.step()
+        if global_step > args.learning_starts:
+            if global_step % args.train_frequency == 0:
+                data = rb.sample(args.batch_size)
+                with torch.no_grad():
+                    target_max, _ = target_network(data.next_observations).max(dim=1)
+                    td_target = data.rewards.flatten() + args.gamma * target_max * (1 - data.dones.flatten())
+                old_val = q_network(data.observations).gather(1, data.actions).squeeze()
+                loss = F.mse_loss(td_target, old_val)
+
+                if global_step % 100 == 0:
+                    writer.add_scalar("losses/td_loss", loss, global_step)
+                    writer.add_scalar("losses/q_values", old_val.mean().item(), global_step)
+                    print("SPS:", int(global_step / (time.time() - start_time)))
+                    writer.add_scalar("charts/SPS", int(global_step / (time.time() - start_time)), global_step)
+
+                # optimize the model
+                optimizer.zero_grad()
+                loss.backward()
+                optimizer.step()

         # update the target network
         if global_step % args.target_network_frequency == 0:
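In the DQN block above, gather(1, data.actions) picks out Q(s, a) for the action stored in the replay buffer, and the TD target bootstraps from the target network's best next-state value. A small sketch of the same arithmetic on dummy tensors; the shapes and values here are illustrative only:

import torch

gamma, batch, n_actions = 0.99, 4, 3

# Dummy stand-ins for the replay-buffer sample and the two networks' outputs.
q_values = torch.rand(batch, n_actions)        # q_network(data.observations)
next_q_values = torch.rand(batch, n_actions)   # target_network(data.next_observations)
actions = torch.randint(n_actions, (batch, 1)) # data.actions
rewards = torch.rand(batch)                    # data.rewards.flatten()
dones = torch.zeros(batch)                     # data.dones.flatten()

# Same arithmetic as the diff: max over next-state actions, 1-step bootstrap,
# then Q(s, a) for the stored action via gather.
target_max, _ = next_q_values.max(dim=1)
td_target = rewards + gamma * target_max * (1 - dones)
old_val = q_values.gather(1, actions).squeeze()
loss = torch.nn.functional.mse_loss(td_target, old_val)
print(td_target.shape, old_val.shape, loss.item())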
37 changes: 19 additions & 18 deletions cleanrl/dqn_atari.py
@@ -204,24 +204,25 @@ def linear_schedule(start_e: float, end_e: float, duration: int, t: int):
         obs = next_obs

         # ALGO LOGIC: training.
-        if global_step > args.learning_starts and global_step % args.train_frequency == 0:
-            data = rb.sample(args.batch_size)
-            with torch.no_grad():
-                target_max, _ = target_network(data.next_observations).max(dim=1)
-                td_target = data.rewards.flatten() + args.gamma * target_max * (1 - data.dones.flatten())
-            old_val = q_network(data.observations).gather(1, data.actions).squeeze()
-            loss = F.mse_loss(td_target, old_val)
-
-            if global_step % 100 == 0:
-                writer.add_scalar("losses/td_loss", loss, global_step)
-                writer.add_scalar("losses/q_values", old_val.mean().item(), global_step)
-                print("SPS:", int(global_step / (time.time() - start_time)))
-                writer.add_scalar("charts/SPS", int(global_step / (time.time() - start_time)), global_step)
-
-            # optimize the model
-            optimizer.zero_grad()
-            loss.backward()
-            optimizer.step()
+        if global_step > args.learning_starts:
+            if global_step % args.train_frequency == 0:
+                data = rb.sample(args.batch_size)
+                with torch.no_grad():
+                    target_max, _ = target_network(data.next_observations).max(dim=1)
+                    td_target = data.rewards.flatten() + args.gamma * target_max * (1 - data.dones.flatten())
+                old_val = q_network(data.observations).gather(1, data.actions).squeeze()
+                loss = F.mse_loss(td_target, old_val)
+
+                if global_step % 100 == 0:
+                    writer.add_scalar("losses/td_loss", loss, global_step)
+                    writer.add_scalar("losses/q_values", old_val.mean().item(), global_step)
+                    print("SPS:", int(global_step / (time.time() - start_time)))
+                    writer.add_scalar("charts/SPS", int(global_step / (time.time() - start_time)), global_step)
+
+                # optimize the model
+                optimizer.zero_grad()
+                loss.backward()
+                optimizer.step()

         # update the target network
         if global_step % args.target_network_frequency == 0:
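Each of the PyTorch hunks sits below def linear_schedule(start_e, end_e, duration, t), the epsilon annealing helper shown in the hunk headers. A hedged sketch of how such a schedule is commonly written and used for epsilon-greedy exploration; the function body and the usage loop below are a plausible reconstruction for illustration, not copied from the repository:

import random

def linear_schedule(start_e: float, end_e: float, duration: int, t: int):
    # Linearly anneal from start_e to end_e over `duration` steps, then stay at end_e.
    slope = (end_e - start_e) / duration
    return max(slope * t + start_e, end_e)

# Hypothetical usage for epsilon-greedy action selection.
for global_step in range(5):
    epsilon = linear_schedule(1.0, 0.05, duration=4, t=global_step)
    if random.random() < epsilon:
        action = random.randrange(3)  # explore (3 = made-up action count)
    else:
        action = 0                    # exploit: argmax over Q-values would go here
    print(global_step, round(epsilon, 3), action)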
35 changes: 18 additions & 17 deletions cleanrl/dqn_atari_jax.py
@@ -235,23 +235,24 @@ def mse_loss(params):
         obs = next_obs

         # ALGO LOGIC: training.
-        if global_step > args.learning_starts and global_step % args.train_frequency == 0:
-            data = rb.sample(args.batch_size)
-            # perform a gradient-descent step
-            loss, old_val, q_state = update(
-                q_state,
-                data.observations.numpy(),
-                data.actions.numpy(),
-                data.next_observations.numpy(),
-                data.rewards.flatten().numpy(),
-                data.dones.flatten().numpy(),
-            )
-
-            if global_step % 100 == 0:
-                writer.add_scalar("losses/td_loss", jax.device_get(loss), global_step)
-                writer.add_scalar("losses/q_values", jax.device_get(old_val).mean(), global_step)
-                print("SPS:", int(global_step / (time.time() - start_time)))
-                writer.add_scalar("charts/SPS", int(global_step / (time.time() - start_time)), global_step)
+        if global_step > args.learning_starts:
+            if global_step % args.train_frequency == 0:
+                data = rb.sample(args.batch_size)
+                # perform a gradient-descent step
+                loss, old_val, q_state = update(
+                    q_state,
+                    data.observations.numpy(),
+                    data.actions.numpy(),
+                    data.next_observations.numpy(),
+                    data.rewards.flatten().numpy(),
+                    data.dones.flatten().numpy(),
+                )
+
+                if global_step % 100 == 0:
+                    writer.add_scalar("losses/td_loss", jax.device_get(loss), global_step)
+                    writer.add_scalar("losses/q_values", jax.device_get(old_val).mean(), global_step)
+                    print("SPS:", int(global_step / (time.time() - start_time)))
+                    writer.add_scalar("charts/SPS", int(global_step / (time.time() - start_time)), global_step)

         # update the target network
         if global_step % args.target_network_frequency == 0:
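The update(...) call in the JAX hunks refers to a jitted helper defined earlier in the file (its inner mse_loss(params) appears in the hunk header). A hedged, self-contained sketch of what such a TD update typically looks like; QNetwork, the target_params field on the train state, the layer sizes, and gamma = 0.99 are assumptions for illustration, not the file's exact definitions:

import flax
import flax.linen as nn
import jax
import jax.numpy as jnp
import numpy as np
import optax
from flax.training.train_state import TrainState


class QNetwork(nn.Module):
    n_actions: int = 3  # made-up action count

    @nn.compact
    def __call__(self, x):
        x = nn.relu(nn.Dense(64)(x))
        return nn.Dense(self.n_actions)(x)


class QTrainState(TrainState):
    target_params: flax.core.FrozenDict  # extra slot for the target network


q_network = QNetwork()
obs_dim, gamma = 4, 0.99
params = q_network.init(jax.random.PRNGKey(0), jnp.zeros((1, obs_dim)))
q_state = QTrainState.create(
    apply_fn=q_network.apply, params=params, target_params=params, tx=optax.adam(2.5e-4)
)


@jax.jit
def update(q_state, observations, actions, next_observations, rewards, dones):
    # 1-step TD target computed from the target parameters.
    q_next = q_network.apply(q_state.target_params, next_observations)
    next_q_value = rewards + (1 - dones) * gamma * jnp.max(q_next, axis=-1)

    def mse_loss(params):
        q_pred = q_network.apply(params, observations)
        q_pred = q_pred[jnp.arange(q_pred.shape[0]), actions.squeeze()]
        return ((q_pred - next_q_value) ** 2).mean(), q_pred

    (loss, old_val), grads = jax.value_and_grad(mse_loss, has_aux=True)(q_state.params)
    q_state = q_state.apply_gradients(grads=grads)
    return loss, old_val, q_state


# Dummy batch mirroring the .numpy() arrays passed in the diff.
batch = 8
loss, old_val, q_state = update(
    q_state,
    np.zeros((batch, obs_dim), dtype=np.float32),
    np.zeros((batch, 1), dtype=np.int32),
    np.zeros((batch, obs_dim), dtype=np.float32),
    np.ones(batch, dtype=np.float32),
    np.zeros(batch, dtype=np.float32),
)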
35 changes: 18 additions & 17 deletions cleanrl/dqn_jax.py
@@ -207,23 +207,24 @@ def mse_loss(params):
         obs = next_obs

         # ALGO LOGIC: training.
-        if global_step > args.learning_starts and global_step % args.train_frequency == 0:
-            data = rb.sample(args.batch_size)
-            # perform a gradient-descent step
-            loss, old_val, q_state = update(
-                q_state,
-                data.observations.numpy(),
-                data.actions.numpy(),
-                data.next_observations.numpy(),
-                data.rewards.flatten().numpy(),
-                data.dones.flatten().numpy(),
-            )
-
-            if global_step % 100 == 0:
-                writer.add_scalar("losses/td_loss", jax.device_get(loss), global_step)
-                writer.add_scalar("losses/q_values", jax.device_get(old_val).mean(), global_step)
-                print("SPS:", int(global_step / (time.time() - start_time)))
-                writer.add_scalar("charts/SPS", int(global_step / (time.time() - start_time)), global_step)
+        if global_step > args.learning_starts:
+            if global_step % args.train_frequency == 0:
+                data = rb.sample(args.batch_size)
+                # perform a gradient-descent step
+                loss, old_val, q_state = update(
+                    q_state,
+                    data.observations.numpy(),
+                    data.actions.numpy(),
+                    data.next_observations.numpy(),
+                    data.rewards.flatten().numpy(),
+                    data.dones.flatten().numpy(),
+                )
+
+                if global_step % 100 == 0:
+                    writer.add_scalar("losses/td_loss", jax.device_get(loss), global_step)
+                    writer.add_scalar("losses/q_values", jax.device_get(old_val).mean(), global_step)
+                    print("SPS:", int(global_step / (time.time() - start_time)))
+                    writer.add_scalar("charts/SPS", int(global_step / (time.time() - start_time)), global_step)

         # update the target network
         if global_step % args.target_network_frequency == 0:
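The context line closing each hunk ("# update the target network", gated on args.target_network_frequency) is where the online parameters are copied into the target network. In the Flax train-state style these JAX scripts use, a hard update is typically a .replace() of the target parameter slot; a toy sketch with a made-up state dataclass and frequency:

import jax.numpy as jnp
from flax import struct


@struct.dataclass
class ToyState:
    params: jnp.ndarray
    target_params: jnp.ndarray


state = ToyState(params=jnp.ones(3), target_params=jnp.zeros(3))
target_network_frequency = 500  # illustrative value

for global_step in range(1001):
    if global_step % target_network_frequency == 0:
        # Hard update: copy the online parameters into the target slot.
        state = state.replace(target_params=state.params)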