diff --git a/cleanrl/c51.py b/cleanrl/c51.py
index 64cf9e134..b4ebb5473 100755
--- a/cleanrl/c51.py
+++ b/cleanrl/c51.py
@@ -197,41 +197,42 @@ def linear_schedule(start_e: float, end_e: float, duration: int, t: int):
         obs = next_obs
 
         # ALGO LOGIC: training.
-        if global_step > args.learning_starts and global_step % args.train_frequency == 0:
-            data = rb.sample(args.batch_size)
-            with torch.no_grad():
-                _, next_pmfs = target_network.get_action(data.next_observations)
-                next_atoms = data.rewards + args.gamma * target_network.atoms * (1 - data.dones)
-                # projection
-                delta_z = target_network.atoms[1] - target_network.atoms[0]
-                tz = next_atoms.clamp(args.v_min, args.v_max)
-
-                b = (tz - args.v_min) / delta_z
-                l = b.floor().clamp(0, args.n_atoms - 1)
-                u = b.ceil().clamp(0, args.n_atoms - 1)
-                # (l == u).float() handles the case where bj is exactly an integer
-                # example bj = 1, then the upper ceiling should be uj= 2, and lj= 1
-                d_m_l = (u + (l == u).float() - b) * next_pmfs
-                d_m_u = (b - l) * next_pmfs
-                target_pmfs = torch.zeros_like(next_pmfs)
-                for i in range(target_pmfs.size(0)):
-                    target_pmfs[i].index_add_(0, l[i].long(), d_m_l[i])
-                    target_pmfs[i].index_add_(0, u[i].long(), d_m_u[i])
-
-            _, old_pmfs = q_network.get_action(data.observations, data.actions.flatten())
-            loss = (-(target_pmfs * old_pmfs.clamp(min=1e-5, max=1 - 1e-5).log()).sum(-1)).mean()
-
-            if global_step % 100 == 0:
-                writer.add_scalar("losses/loss", loss.item(), global_step)
-                old_val = (old_pmfs * q_network.atoms).sum(1)
-                writer.add_scalar("losses/q_values", old_val.mean().item(), global_step)
-                print("SPS:", int(global_step / (time.time() - start_time)))
-                writer.add_scalar("charts/SPS", int(global_step / (time.time() - start_time)), global_step)
-
-            # optimize the model
-            optimizer.zero_grad()
-            loss.backward()
-            optimizer.step()
+        if global_step > args.learning_starts:
+            if global_step % args.train_frequency == 0:
+                data = rb.sample(args.batch_size)
+                with torch.no_grad():
+                    _, next_pmfs = target_network.get_action(data.next_observations)
+                    next_atoms = data.rewards + args.gamma * target_network.atoms * (1 - data.dones)
+                    # projection
+                    delta_z = target_network.atoms[1] - target_network.atoms[0]
+                    tz = next_atoms.clamp(args.v_min, args.v_max)
+
+                    b = (tz - args.v_min) / delta_z
+                    l = b.floor().clamp(0, args.n_atoms - 1)
+                    u = b.ceil().clamp(0, args.n_atoms - 1)
+                    # (l == u).float() handles the case where bj is exactly an integer
+                    # example bj = 1, then the upper ceiling should be uj= 2, and lj= 1
+                    d_m_l = (u + (l == u).float() - b) * next_pmfs
+                    d_m_u = (b - l) * next_pmfs
+                    target_pmfs = torch.zeros_like(next_pmfs)
+                    for i in range(target_pmfs.size(0)):
+                        target_pmfs[i].index_add_(0, l[i].long(), d_m_l[i])
+                        target_pmfs[i].index_add_(0, u[i].long(), d_m_u[i])
+
+                _, old_pmfs = q_network.get_action(data.observations, data.actions.flatten())
+                loss = (-(target_pmfs * old_pmfs.clamp(min=1e-5, max=1 - 1e-5).log()).sum(-1)).mean()
+
+                if global_step % 100 == 0:
+                    writer.add_scalar("losses/loss", loss.item(), global_step)
+                    old_val = (old_pmfs * q_network.atoms).sum(1)
+                    writer.add_scalar("losses/q_values", old_val.mean().item(), global_step)
+                    print("SPS:", int(global_step / (time.time() - start_time)))
+                    writer.add_scalar("charts/SPS", int(global_step / (time.time() - start_time)), global_step)
+
+                # optimize the model
+                optimizer.zero_grad()
+                loss.backward()
+                optimizer.step()
 
         # update the target network
         if global_step % args.target_network_frequency == 0:
diff --git a/cleanrl/c51_atari.py b/cleanrl/c51_atari.py
index 75c1af27c..63ee12d32 100755
--- a/cleanrl/c51_atari.py
+++ b/cleanrl/c51_atari.py
@@ -219,41 +219,42 @@ def linear_schedule(start_e: float, end_e: float, duration: int, t: int):
         obs = next_obs
 
         # ALGO LOGIC: training.
-        if global_step > args.learning_starts and global_step % args.train_frequency == 0:
-            data = rb.sample(args.batch_size)
-            with torch.no_grad():
-                _, next_pmfs = target_network.get_action(data.next_observations)
-                next_atoms = data.rewards + args.gamma * target_network.atoms * (1 - data.dones)
-                # projection
-                delta_z = target_network.atoms[1] - target_network.atoms[0]
-                tz = next_atoms.clamp(args.v_min, args.v_max)
-
-                b = (tz - args.v_min) / delta_z
-                l = b.floor().clamp(0, args.n_atoms - 1)
-                u = b.ceil().clamp(0, args.n_atoms - 1)
-                # (l == u).float() handles the case where bj is exactly an integer
-                # example bj = 1, then the upper ceiling should be uj= 2, and lj= 1
-                d_m_l = (u + (l == u).float() - b) * next_pmfs
-                d_m_u = (b - l) * next_pmfs
-                target_pmfs = torch.zeros_like(next_pmfs)
-                for i in range(target_pmfs.size(0)):
-                    target_pmfs[i].index_add_(0, l[i].long(), d_m_l[i])
-                    target_pmfs[i].index_add_(0, u[i].long(), d_m_u[i])
-
-            _, old_pmfs = q_network.get_action(data.observations, data.actions.flatten())
-            loss = (-(target_pmfs * old_pmfs.clamp(min=1e-5, max=1 - 1e-5).log()).sum(-1)).mean()
-
-            if global_step % 100 == 0:
-                writer.add_scalar("losses/loss", loss.item(), global_step)
-                old_val = (old_pmfs * q_network.atoms).sum(1)
-                writer.add_scalar("losses/q_values", old_val.mean().item(), global_step)
-                print("SPS:", int(global_step / (time.time() - start_time)))
-                writer.add_scalar("charts/SPS", int(global_step / (time.time() - start_time)), global_step)
-
-            # optimize the model
-            optimizer.zero_grad()
-            loss.backward()
-            optimizer.step()
+        if global_step > args.learning_starts:
+            if global_step % args.train_frequency == 0:
+                data = rb.sample(args.batch_size)
+                with torch.no_grad():
+                    _, next_pmfs = target_network.get_action(data.next_observations)
+                    next_atoms = data.rewards + args.gamma * target_network.atoms * (1 - data.dones)
+                    # projection
+                    delta_z = target_network.atoms[1] - target_network.atoms[0]
+                    tz = next_atoms.clamp(args.v_min, args.v_max)
+
+                    b = (tz - args.v_min) / delta_z
+                    l = b.floor().clamp(0, args.n_atoms - 1)
+                    u = b.ceil().clamp(0, args.n_atoms - 1)
+                    # (l == u).float() handles the case where bj is exactly an integer
+                    # example bj = 1, then the upper ceiling should be uj= 2, and lj= 1
+                    d_m_l = (u + (l == u).float() - b) * next_pmfs
+                    d_m_u = (b - l) * next_pmfs
+                    target_pmfs = torch.zeros_like(next_pmfs)
+                    for i in range(target_pmfs.size(0)):
+                        target_pmfs[i].index_add_(0, l[i].long(), d_m_l[i])
+                        target_pmfs[i].index_add_(0, u[i].long(), d_m_u[i])
+
+                _, old_pmfs = q_network.get_action(data.observations, data.actions.flatten())
+                loss = (-(target_pmfs * old_pmfs.clamp(min=1e-5, max=1 - 1e-5).log()).sum(-1)).mean()
+
+                if global_step % 100 == 0:
+                    writer.add_scalar("losses/loss", loss.item(), global_step)
+                    old_val = (old_pmfs * q_network.atoms).sum(1)
+                    writer.add_scalar("losses/q_values", old_val.mean().item(), global_step)
+                    print("SPS:", int(global_step / (time.time() - start_time)))
+                    writer.add_scalar("charts/SPS", int(global_step / (time.time() - start_time)), global_step)
+
+                # optimize the model
+                optimizer.zero_grad()
+                loss.backward()
+                optimizer.step()
 
         # update the target network
         if global_step % args.target_network_frequency == 0:
diff --git a/cleanrl/dqn.py b/cleanrl/dqn.py
index 3c28e5ce4..7ff157414 100644
--- a/cleanrl/dqn.py
+++ b/cleanrl/dqn.py
@@ -182,24 +182,25 @@ def linear_schedule(start_e: float, end_e: float, duration: int, t: int):
         obs = next_obs
 
         # ALGO LOGIC: training.
-        if global_step > args.learning_starts and global_step % args.train_frequency == 0:
-            data = rb.sample(args.batch_size)
-            with torch.no_grad():
-                target_max, _ = target_network(data.next_observations).max(dim=1)
-                td_target = data.rewards.flatten() + args.gamma * target_max * (1 - data.dones.flatten())
-            old_val = q_network(data.observations).gather(1, data.actions).squeeze()
-            loss = F.mse_loss(td_target, old_val)
-
-            if global_step % 100 == 0:
-                writer.add_scalar("losses/td_loss", loss, global_step)
-                writer.add_scalar("losses/q_values", old_val.mean().item(), global_step)
-                print("SPS:", int(global_step / (time.time() - start_time)))
-                writer.add_scalar("charts/SPS", int(global_step / (time.time() - start_time)), global_step)
-
-            # optimize the model
-            optimizer.zero_grad()
-            loss.backward()
-            optimizer.step()
+        if global_step > args.learning_starts:
+            if global_step % args.train_frequency == 0:
+                data = rb.sample(args.batch_size)
+                with torch.no_grad():
+                    target_max, _ = target_network(data.next_observations).max(dim=1)
+                    td_target = data.rewards.flatten() + args.gamma * target_max * (1 - data.dones.flatten())
+                old_val = q_network(data.observations).gather(1, data.actions).squeeze()
+                loss = F.mse_loss(td_target, old_val)
+
+                if global_step % 100 == 0:
+                    writer.add_scalar("losses/td_loss", loss, global_step)
+                    writer.add_scalar("losses/q_values", old_val.mean().item(), global_step)
+                    print("SPS:", int(global_step / (time.time() - start_time)))
+                    writer.add_scalar("charts/SPS", int(global_step / (time.time() - start_time)), global_step)
+
+                # optimize the model
+                optimizer.zero_grad()
+                loss.backward()
+                optimizer.step()
 
         # update the target network
         if global_step % args.target_network_frequency == 0:
diff --git a/cleanrl/dqn_atari.py b/cleanrl/dqn_atari.py
index 0191c92c0..95c1325e8 100644
--- a/cleanrl/dqn_atari.py
+++ b/cleanrl/dqn_atari.py
@@ -204,24 +204,25 @@ def linear_schedule(start_e: float, end_e: float, duration: int, t: int):
         obs = next_obs
 
         # ALGO LOGIC: training.
-        if global_step > args.learning_starts and global_step % args.train_frequency == 0:
-            data = rb.sample(args.batch_size)
-            with torch.no_grad():
-                target_max, _ = target_network(data.next_observations).max(dim=1)
-                td_target = data.rewards.flatten() + args.gamma * target_max * (1 - data.dones.flatten())
-            old_val = q_network(data.observations).gather(1, data.actions).squeeze()
-            loss = F.mse_loss(td_target, old_val)
-
-            if global_step % 100 == 0:
-                writer.add_scalar("losses/td_loss", loss, global_step)
-                writer.add_scalar("losses/q_values", old_val.mean().item(), global_step)
-                print("SPS:", int(global_step / (time.time() - start_time)))
-                writer.add_scalar("charts/SPS", int(global_step / (time.time() - start_time)), global_step)
-
-            # optimize the model
-            optimizer.zero_grad()
-            loss.backward()
-            optimizer.step()
+        if global_step > args.learning_starts:
+            if global_step % args.train_frequency == 0:
+                data = rb.sample(args.batch_size)
+                with torch.no_grad():
+                    target_max, _ = target_network(data.next_observations).max(dim=1)
+                    td_target = data.rewards.flatten() + args.gamma * target_max * (1 - data.dones.flatten())
+                old_val = q_network(data.observations).gather(1, data.actions).squeeze()
+                loss = F.mse_loss(td_target, old_val)
+
+                if global_step % 100 == 0:
+                    writer.add_scalar("losses/td_loss", loss, global_step)
+                    writer.add_scalar("losses/q_values", old_val.mean().item(), global_step)
+                    print("SPS:", int(global_step / (time.time() - start_time)))
+                    writer.add_scalar("charts/SPS", int(global_step / (time.time() - start_time)), global_step)
+
+                # optimize the model
+                optimizer.zero_grad()
+                loss.backward()
+                optimizer.step()
 
         # update the target network
         if global_step % args.target_network_frequency == 0:
diff --git a/cleanrl/dqn_atari_jax.py b/cleanrl/dqn_atari_jax.py
index 3d060e9f1..869000aaa 100644
--- a/cleanrl/dqn_atari_jax.py
+++ b/cleanrl/dqn_atari_jax.py
@@ -235,23 +235,24 @@ def mse_loss(params):
         obs = next_obs
 
         # ALGO LOGIC: training.
-        if global_step > args.learning_starts and global_step % args.train_frequency == 0:
-            data = rb.sample(args.batch_size)
-            # perform a gradient-descent step
-            loss, old_val, q_state = update(
-                q_state,
-                data.observations.numpy(),
-                data.actions.numpy(),
-                data.next_observations.numpy(),
-                data.rewards.flatten().numpy(),
-                data.dones.flatten().numpy(),
-            )
-
-            if global_step % 100 == 0:
-                writer.add_scalar("losses/td_loss", jax.device_get(loss), global_step)
-                writer.add_scalar("losses/q_values", jax.device_get(old_val).mean(), global_step)
-                print("SPS:", int(global_step / (time.time() - start_time)))
-                writer.add_scalar("charts/SPS", int(global_step / (time.time() - start_time)), global_step)
+        if global_step > args.learning_starts:
+            if global_step % args.train_frequency == 0:
+                data = rb.sample(args.batch_size)
+                # perform a gradient-descent step
+                loss, old_val, q_state = update(
+                    q_state,
+                    data.observations.numpy(),
+                    data.actions.numpy(),
+                    data.next_observations.numpy(),
+                    data.rewards.flatten().numpy(),
+                    data.dones.flatten().numpy(),
+                )
+
+                if global_step % 100 == 0:
+                    writer.add_scalar("losses/td_loss", jax.device_get(loss), global_step)
+                    writer.add_scalar("losses/q_values", jax.device_get(old_val).mean(), global_step)
+                    print("SPS:", int(global_step / (time.time() - start_time)))
+                    writer.add_scalar("charts/SPS", int(global_step / (time.time() - start_time)), global_step)
 
         # update the target network
         if global_step % args.target_network_frequency == 0:
diff --git a/cleanrl/dqn_jax.py b/cleanrl/dqn_jax.py
index 9319468e2..86c9bc21e 100644
--- a/cleanrl/dqn_jax.py
+++ b/cleanrl/dqn_jax.py
@@ -207,23 +207,24 @@ def mse_loss(params):
         obs = next_obs
 
         # ALGO LOGIC: training.
-        if global_step > args.learning_starts and global_step % args.train_frequency == 0:
-            data = rb.sample(args.batch_size)
-            # perform a gradient-descent step
-            loss, old_val, q_state = update(
-                q_state,
-                data.observations.numpy(),
-                data.actions.numpy(),
-                data.next_observations.numpy(),
-                data.rewards.flatten().numpy(),
-                data.dones.flatten().numpy(),
-            )
-
-            if global_step % 100 == 0:
-                writer.add_scalar("losses/td_loss", jax.device_get(loss), global_step)
-                writer.add_scalar("losses/q_values", jax.device_get(old_val).mean(), global_step)
-                print("SPS:", int(global_step / (time.time() - start_time)))
-                writer.add_scalar("charts/SPS", int(global_step / (time.time() - start_time)), global_step)
+        if global_step > args.learning_starts:
+            if global_step % args.train_frequency == 0:
+                data = rb.sample(args.batch_size)
+                # perform a gradient-descent step
+                loss, old_val, q_state = update(
+                    q_state,
+                    data.observations.numpy(),
+                    data.actions.numpy(),
+                    data.next_observations.numpy(),
+                    data.rewards.flatten().numpy(),
+                    data.dones.flatten().numpy(),
+                )
+
+                if global_step % 100 == 0:
+                    writer.add_scalar("losses/td_loss", jax.device_get(loss), global_step)
+                    writer.add_scalar("losses/q_values", jax.device_get(old_val).mean(), global_step)
+                    print("SPS:", int(global_step / (time.time() - start_time)))
+                    writer.add_scalar("charts/SPS", int(global_step / (time.time() - start_time)), global_step)
 
         # update the target network
         if global_step % args.target_network_frequency == 0:
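The hunks above only re-indent the training block: the combined `if global_step > args.learning_starts and global_step % args.train_frequency == 0:` condition becomes a nested pair of `if` statements, so the sampled batch, loss, and optimizer step are unchanged. Below is a minimal, standalone sketch (not part of the patch) of the categorical projection step that the c51.py / c51_atari.py hunks move under the nested block; the atom count, value range, reward, and next-state pmf are made-up values chosen only to show how probability mass is split between neighboring atoms.

# illustrative sketch only -- not part of the diff above
import torch

n_atoms, v_min, v_max, gamma = 5, 0.0, 4.0, 0.99          # made-up support
atoms = torch.linspace(v_min, v_max, n_atoms)              # [0, 1, 2, 3, 4]
delta_z = atoms[1] - atoms[0]

# one transition with a made-up next-state distribution (batch size 1)
rewards = torch.tensor([[1.0]])
dones = torch.tensor([[0.0]])
next_pmfs = torch.tensor([[0.1, 0.2, 0.4, 0.2, 0.1]])

# Bellman-shift the support, then project back onto the fixed atoms
next_atoms = rewards + gamma * atoms * (1 - dones)          # (1, n_atoms)
tz = next_atoms.clamp(v_min, v_max)
b = (tz - v_min) / delta_z
l = b.floor().clamp(0, n_atoms - 1)
u = b.ceil().clamp(0, n_atoms - 1)
d_m_l = (u + (l == u).float() - b) * next_pmfs              # mass sent to the lower atom
d_m_u = (b - l) * next_pmfs                                 # mass sent to the upper atom

target_pmfs = torch.zeros_like(next_pmfs)
for i in range(target_pmfs.size(0)):
    target_pmfs[i].index_add_(0, l[i].long(), d_m_l[i])
    target_pmfs[i].index_add_(0, u[i].long(), d_m_u[i])

print(target_pmfs, target_pmfs.sum())                       # projected pmf still sums to 1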