SAC-discrete implementation #270
Merged

Commits (42, changes shown from all commits):
18b643b (timoklein) add draft of SAC discrete implementation
c3c98bd (timoklein) run pre-commit
ec31dc4 (timoklein) Use log softmax instead of author's log-pi code
deb37e8 (timoklein) Revert to cleanrl SAC delay implementation (it's more stable)
a1fdd2b (timoklein) Remove docstrings and duplicate code
977a83a (timoklein) Use correct clipreward wrapper
f2ea3e6 (timoklein) fix bug in log softmax calculation
48af04c (timoklein) adhere to cleanrl log_prob naming
b2a09a0 (timoklein) fix bug in entropy target calculation
89680c7 (timoklein) change layer initialization to match existing cleanrl codebase
b1d7d44 (timoklein) working minimal diff version
61e1c74 (timoklein) implement original learning update frequency
7cd1e3a (timoklein) parameterize the entropy scale for autotuning
61c46fc (timoklein) add benchmarking script
4915e4c (timoklein) rename target entropy factor and set new default value
6f7251f (timoklein) add docs draft
23b60ff (timoklein) fix SAC-discrete links to work pre merge
10ee9f0 (timoklein) add preliminary result table for SAC-discrete
8430fd8 (timoklein) clean up todos and add header
a17768c (timoklein) minimize diff between sac_atari and sac_continuous
d6a507c (timoklein) add sac-discrete end2end test
a7ea6f4 (timoklein) SAC-discrete docs rework
9f6493c (timoklein) Update SAC-discrete @100k results
59a6d00 (timoklein) Fix doc links and unify naming in code
1304b7a (vwxyzjn) update docs
3a3f41b (timoklein) fix target update frequency (see PR #323)
80187ad (timoklein) clarify comment regarding CNN encoder sharing
e9cb494 (timoklein) Merge remote-tracking branch 'upstream/master' into sac-discrete
e199e39 (timoklein) fix benchmark installation
bb27fa1 (timoklein) fix eps in minimal diff version and improve code readability
6a46632 (timoklein) add docs for eps and finalize code
cad5fff (timoklein) use no_grad for actor Q-vals and re-use action-probs & log-probs in a…
0cf47f1 (timoklein) update docs for new code and settings
61988c4 (timoklein) fix links to point to main branch
6e17005 (timoklein) update sac-discrete training plots
33b00f3 (timoklein) new sac-d training plots
5dabafb (timoklein) update results table and fix link
90b2fd5 (timoklein) fix pong chart title
a763994 (timoklein) add Jimmy Ba name as exception to code spell check
071cdbb (timoklein) change target_entropy_scale default value to same value as experiments
dcc2633 (timoklein) Merge remote-tracking branch 'upstream/master' into sac-discrete
c671a92 (timoklein) remove blank line at end of pre-commit
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Benchmark script (new file, 6 lines):

poetry install --with atari
OMP_NUM_THREADS=1 python -m cleanrl_utils.benchmark \
    --env-ids PongNoFrameskip-v4 BreakoutNoFrameskip-v4 BeamRiderNoFrameskip-v4 \
    --command "poetry run python cleanrl/sac_atari.py --cuda True --track" \
    --num-seeds 3 \
    --workers 2
cleanrl/sac_atari.py (new file, 342 lines):

# docs and experiment results can be found at https://docs.cleanrl.dev/rl-algorithms/sac/#sac_ataripy
import argparse
import os
import random
import time
from distutils.util import strtobool

import gym
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from stable_baselines3.common.atari_wrappers import (
    ClipRewardEnv,
    EpisodicLifeEnv,
    FireResetEnv,
    MaxAndSkipEnv,
    NoopResetEnv,
)
from stable_baselines3.common.buffers import ReplayBuffer
from torch.distributions.categorical import Categorical
from torch.utils.tensorboard import SummaryWriter


def parse_args():
    # fmt: off
    parser = argparse.ArgumentParser()
    parser.add_argument("--exp-name", type=str, default=os.path.basename(__file__).rstrip(".py"),
        help="the name of this experiment")
    parser.add_argument("--seed", type=int, default=1,
        help="seed of the experiment")
    parser.add_argument("--torch-deterministic", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True,
        help="if toggled, `torch.backends.cudnn.deterministic=False`")
    parser.add_argument("--cuda", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True,
        help="if toggled, cuda will be enabled by default")
    parser.add_argument("--track", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
        help="if toggled, this experiment will be tracked with Weights and Biases")
    parser.add_argument("--wandb-project-name", type=str, default="cleanRL",
        help="the wandb's project name")
    parser.add_argument("--wandb-entity", type=str, default=None,
        help="the entity (team) of wandb's project")
    parser.add_argument("--capture-video", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
        help="whether to capture videos of the agent performances (check out `videos` folder)")

    # Algorithm specific arguments
    parser.add_argument("--env-id", type=str, default="BeamRiderNoFrameskip-v4",
        help="the id of the environment")
    parser.add_argument("--total-timesteps", type=int, default=5000000,
        help="total timesteps of the experiments")
    parser.add_argument("--buffer-size", type=int, default=int(1e6),
        help="the replay memory buffer size")  # smaller than in original paper but evaluation is done only for 100k steps anyway
    parser.add_argument("--gamma", type=float, default=0.99,
        help="the discount factor gamma")
    parser.add_argument("--tau", type=float, default=1.0,
        help="target smoothing coefficient (default: 1)")  # Default is 1 to perform replacement update
    parser.add_argument("--batch-size", type=int, default=64,
        help="the batch size of sample from the replay memory")
    parser.add_argument("--learning-starts", type=int, default=2e4,
        help="timestep to start learning")
    parser.add_argument("--policy-lr", type=float, default=3e-4,
        help="the learning rate of the policy network optimizer")
    parser.add_argument("--q-lr", type=float, default=3e-4,
        help="the learning rate of the Q network optimizer")
    parser.add_argument("--update-frequency", type=int, default=4,
        help="the frequency of training updates")
    parser.add_argument("--target-network-frequency", type=int, default=8000,
        help="the frequency of updates for the target networks")
    parser.add_argument("--alpha", type=float, default=0.2,
        help="entropy regularization coefficient")
    parser.add_argument("--autotune", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True,
        help="automatic tuning of the entropy coefficient")
    parser.add_argument("--target-entropy-scale", type=float, default=0.89,
        help="coefficient for scaling the autotune entropy target")
    args = parser.parse_args()
    # fmt: on
    return args


def make_env(env_id, seed, idx, capture_video, run_name):
    def thunk():
        env = gym.make(env_id)
        env = gym.wrappers.RecordEpisodeStatistics(env)
        if capture_video:
            if idx == 0:
                env = gym.wrappers.RecordVideo(env, f"videos/{run_name}")
        env = NoopResetEnv(env, noop_max=30)
        env = MaxAndSkipEnv(env, skip=4)
        env = EpisodicLifeEnv(env)
        if "FIRE" in env.unwrapped.get_action_meanings():
            env = FireResetEnv(env)
        env = ClipRewardEnv(env)
        env = gym.wrappers.ResizeObservation(env, (84, 84))
        env = gym.wrappers.GrayScaleObservation(env)
        env = gym.wrappers.FrameStack(env, 4)
        env.seed(seed)
        env.action_space.seed(seed)
        env.observation_space.seed(seed)
        return env

    return thunk


def layer_init(layer, bias_const=0.0):
    nn.init.kaiming_normal_(layer.weight)
    torch.nn.init.constant_(layer.bias, bias_const)
    return layer


# ALGO LOGIC: initialize agent here:
# NOTE: Sharing a CNN encoder between Actor and Critics is not recommended for SAC without stopping actor gradients
# See the SAC+AE paper https://arxiv.org/abs/1910.01741 for more info
# TL;DR The actor's gradients mess up the representation when using a joint encoder
class SoftQNetwork(nn.Module):
    def __init__(self, envs):
        super().__init__()
        obs_shape = envs.single_observation_space.shape
        self.conv = nn.Sequential(
            layer_init(nn.Conv2d(obs_shape[0], 32, kernel_size=8, stride=4)),
            nn.ReLU(),
            layer_init(nn.Conv2d(32, 64, kernel_size=4, stride=2)),
            nn.ReLU(),
            layer_init(nn.Conv2d(64, 64, kernel_size=3, stride=1)),
            nn.Flatten(),
        )

        with torch.inference_mode():
            output_dim = self.conv(torch.zeros(1, *obs_shape)).shape[1]

        self.fc1 = layer_init(nn.Linear(output_dim, 512))
        self.fc_q = layer_init(nn.Linear(512, envs.single_action_space.n))

    def forward(self, x):
        x = F.relu(self.conv(x / 255.0))
        x = F.relu(self.fc1(x))
        q_vals = self.fc_q(x)
        return q_vals


class Actor(nn.Module):
    def __init__(self, envs):
        super().__init__()
        obs_shape = envs.single_observation_space.shape
        self.conv = nn.Sequential(
            layer_init(nn.Conv2d(obs_shape[0], 32, kernel_size=8, stride=4)),
            nn.ReLU(),
            layer_init(nn.Conv2d(32, 64, kernel_size=4, stride=2)),
            nn.ReLU(),
            layer_init(nn.Conv2d(64, 64, kernel_size=3, stride=1)),
            nn.Flatten(),
        )

        with torch.inference_mode():
            output_dim = self.conv(torch.zeros(1, *obs_shape)).shape[1]

        self.fc1 = layer_init(nn.Linear(output_dim, 512))
        self.fc_logits = layer_init(nn.Linear(512, envs.single_action_space.n))

    def forward(self, x):
        x = F.relu(self.conv(x))
        x = F.relu(self.fc1(x))
        logits = self.fc_logits(x)

        return logits

    def get_action(self, x):
        logits = self(x / 255.0)
        policy_dist = Categorical(logits=logits)
        action = policy_dist.sample()
        # Action probabilities for calculating the adapted soft-Q loss
        action_probs = policy_dist.probs
        log_prob = F.log_softmax(logits, dim=1)
        return action, log_prob, action_probs


if __name__ == "__main__":
    args = parse_args()
    run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{int(time.time())}"
    if args.track:
        import wandb

        wandb.init(
            project=args.wandb_project_name,
            entity=args.wandb_entity,
            sync_tensorboard=True,
            config=vars(args),
            name=run_name,
            monitor_gym=True,
            save_code=True,
        )
    writer = SummaryWriter(f"runs/{run_name}")
    writer.add_text(
        "hyperparameters",
        "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])),
    )

    # TRY NOT TO MODIFY: seeding
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.backends.cudnn.deterministic = args.torch_deterministic

    device = torch.device("cuda" if torch.cuda.is_available() and args.cuda else "cpu")

    # env setup
    envs = gym.vector.SyncVectorEnv([make_env(args.env_id, args.seed, 0, args.capture_video, run_name)])
    assert isinstance(envs.single_action_space, gym.spaces.Discrete), "only discrete action space is supported"

    actor = Actor(envs).to(device)
    qf1 = SoftQNetwork(envs).to(device)
    qf2 = SoftQNetwork(envs).to(device)
    qf1_target = SoftQNetwork(envs).to(device)
    qf2_target = SoftQNetwork(envs).to(device)
    qf1_target.load_state_dict(qf1.state_dict())
    qf2_target.load_state_dict(qf2.state_dict())
    # TRY NOT TO MODIFY: eps=1e-4 increases numerical stability
    q_optimizer = optim.Adam(list(qf1.parameters()) + list(qf2.parameters()), lr=args.q_lr, eps=1e-4)
    actor_optimizer = optim.Adam(list(actor.parameters()), lr=args.policy_lr, eps=1e-4)

    # Automatic entropy tuning
    if args.autotune:
        target_entropy = -args.target_entropy_scale * torch.log(1 / torch.tensor(envs.single_action_space.n))
        log_alpha = torch.zeros(1, requires_grad=True, device=device)
        alpha = log_alpha.exp().item()
        a_optimizer = optim.Adam([log_alpha], lr=args.q_lr, eps=1e-4)
    else:
        alpha = args.alpha

    rb = ReplayBuffer(
        args.buffer_size,
        envs.single_observation_space,
        envs.single_action_space,
        device,
        handle_timeout_termination=True,
    )
    start_time = time.time()

    # TRY NOT TO MODIFY: start the game
    obs = envs.reset()
    for global_step in range(args.total_timesteps):
        # ALGO LOGIC: put action logic here
        if global_step < args.learning_starts:
            actions = np.array([envs.single_action_space.sample() for _ in range(envs.num_envs)])
        else:
            actions, _, _ = actor.get_action(torch.Tensor(obs).to(device))
            actions = actions.detach().cpu().numpy()

        # TRY NOT TO MODIFY: execute the game and log data.
        next_obs, rewards, dones, infos = envs.step(actions)

        # TRY NOT TO MODIFY: record rewards for plotting purposes
        for info in infos:
            if "episode" in info.keys():
                print(f"global_step={global_step}, episodic_return={info['episode']['r']}")
                writer.add_scalar("charts/episodic_return", info["episode"]["r"], global_step)
                writer.add_scalar("charts/episodic_length", info["episode"]["l"], global_step)
                break

        # TRY NOT TO MODIFY: save data to replay buffer; handle `terminal_observation`
        real_next_obs = next_obs.copy()
        for idx, d in enumerate(dones):
            if d:
                real_next_obs[idx] = infos[idx]["terminal_observation"]
        rb.add(obs, real_next_obs, actions, rewards, dones, infos)

        # TRY NOT TO MODIFY: CRUCIAL step easy to overlook
        obs = next_obs

        # ALGO LOGIC: training.
        if global_step > args.learning_starts:
            if global_step % args.update_frequency == 0:
                data = rb.sample(args.batch_size)
                # CRITIC training
                with torch.no_grad():
                    _, next_state_log_pi, next_state_action_probs = actor.get_action(data.next_observations)
                    qf1_next_target = qf1_target(data.next_observations)
                    qf2_next_target = qf2_target(data.next_observations)
                    # we can use the action probabilities instead of MC sampling to estimate the expectation
                    min_qf_next_target = next_state_action_probs * (
                        torch.min(qf1_next_target, qf2_next_target) - alpha * next_state_log_pi
                    )
                    # adapt Q-target for discrete Q-function
                    min_qf_next_target = min_qf_next_target.sum(dim=1)
                    next_q_value = data.rewards.flatten() + (1 - data.dones.flatten()) * args.gamma * (min_qf_next_target)

                # use Q-values only for the taken actions
                qf1_values = qf1(data.observations)
                qf2_values = qf2(data.observations)
                qf1_a_values = qf1_values.gather(1, data.actions.long()).view(-1)
                qf2_a_values = qf2_values.gather(1, data.actions.long()).view(-1)
                qf1_loss = F.mse_loss(qf1_a_values, next_q_value)
                qf2_loss = F.mse_loss(qf2_a_values, next_q_value)
                qf_loss = qf1_loss + qf2_loss

                q_optimizer.zero_grad()
                qf_loss.backward()
                q_optimizer.step()

                # ACTOR training
                _, log_pi, action_probs = actor.get_action(data.observations)
                with torch.no_grad():
                    qf1_values = qf1(data.observations)
                    qf2_values = qf2(data.observations)
                    min_qf_values = torch.min(qf1_values, qf2_values)
                # no need for reparameterization, the expectation can be calculated for discrete actions
                actor_loss = (action_probs * ((alpha * log_pi) - min_qf_values)).mean()

                actor_optimizer.zero_grad()
                actor_loss.backward()
                actor_optimizer.step()

                if args.autotune:
                    # re-use action probabilities for temperature loss
                    alpha_loss = (action_probs.detach() * (-log_alpha * (log_pi + target_entropy).detach())).mean()

                    a_optimizer.zero_grad()
                    alpha_loss.backward()
                    a_optimizer.step()
                    alpha = log_alpha.exp().item()

            # update the target networks
            if global_step % args.target_network_frequency == 0:
                for param, target_param in zip(qf1.parameters(), qf1_target.parameters()):
                    target_param.data.copy_(args.tau * param.data + (1 - args.tau) * target_param.data)
                for param, target_param in zip(qf2.parameters(), qf2_target.parameters()):
                    target_param.data.copy_(args.tau * param.data + (1 - args.tau) * target_param.data)

            if global_step % 100 == 0:
                writer.add_scalar("losses/qf1_values", qf1_a_values.mean().item(), global_step)
                writer.add_scalar("losses/qf2_values", qf2_a_values.mean().item(), global_step)
                writer.add_scalar("losses/qf1_loss", qf1_loss.item(), global_step)
                writer.add_scalar("losses/qf2_loss", qf2_loss.item(), global_step)
                writer.add_scalar("losses/qf_loss", qf_loss.item() / 2.0, global_step)
                writer.add_scalar("losses/actor_loss", actor_loss.item(), global_step)
                writer.add_scalar("losses/alpha", alpha, global_step)
                print("SPS:", int(global_step / (time.time() - start_time)))
                writer.add_scalar("charts/SPS", int(global_step / (time.time() - start_time)), global_step)
                if args.autotune:
                    writer.add_scalar("losses/alpha_loss", alpha_loss.item(), global_step)

    envs.close()
    writer.close()
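For reference, the three updates in the training loop above can be written out as the discrete-action SAC objectives they implement. This is a restatement of what the code computes, not part of the diff; note that the `.mean()` reductions in the actor and temperature losses average over the action dimension as well, which only rescales those losses by a constant factor of 1/|A|. Because `tau` defaults to 1.0, the target-network update every `target_network_frequency` steps is a hard copy (replacement update).

```latex
% Critic target, computed under torch.no_grad() above (Q-bar are the target critics):
y = r + \gamma\,(1 - d)\sum_{a'} \pi(a' \mid s')
      \Big[\min\big(\bar{Q}_1(s', a'),\, \bar{Q}_2(s', a')\big) - \alpha \log \pi(a' \mid s')\Big]

% Critic loss on the taken action a:
L_Q = \mathrm{MSE}\big(Q_1(s, a),\, y\big) + \mathrm{MSE}\big(Q_2(s, a),\, y\big)

% Actor loss (exact expectation over the discrete action set, no reparameterization):
L_\pi = \mathbb{E}_{s \sim \mathcal{D}} \sum_{a} \pi(a \mid s)
        \Big[\alpha \log \pi(a \mid s) - \min\big(Q_1(s, a),\, Q_2(s, a)\big)\Big]

% Temperature loss when --autotune is on, with target entropy
% \bar{H} = \text{target\_entropy\_scale} \cdot \log|A|:
L_\alpha = \mathbb{E}_{s \sim \mathcal{D}} \sum_{a} \pi(a \mid s)
           \Big[-\log\alpha \,\big(\log \pi(a \mid s) + \bar{H}\big)\Big]
```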
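The comment "we can use the action probabilities instead of MC sampling to estimate the expectation" is the central change relative to continuous SAC. The sketch below is a minimal, self-contained illustration of that idea with made-up toy tensors (the variable names and numbers are for illustration only and are not part of the diff); it mirrors the critic-target computation in the loop above, where per-action terms are weighted by `action_probs` and summed over the action dimension:

```python
import torch

# Hypothetical toy numbers for a single state with 3 discrete actions.
q1 = torch.tensor([[1.0, 2.0, 0.5]])        # Q-values from the first target critic
q2 = torch.tensor([[1.2, 1.8, 0.7]])        # Q-values from the second target critic
logits = torch.tensor([[0.1, 0.9, -0.4]])   # actor logits for the same state
alpha = 0.2                                  # entropy temperature

probs = torch.softmax(logits, dim=1)          # pi(a|s), one probability per action
log_probs = torch.log_softmax(logits, dim=1)  # log pi(a|s)

# Exact expectation over actions: sum_a pi(a|s) * [min(Q1, Q2) - alpha * log pi(a|s)].
# No action sampling is needed, which is why the discrete variant skips the
# reparameterization trick used in continuous SAC.
soft_state_value = (probs * (torch.min(q1, q2) - alpha * log_probs)).sum(dim=1)
print(soft_state_value)  # one scalar per state
```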
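The `--target-entropy-scale` default of 0.89 is easiest to read through the line `target_entropy = -args.target_entropy_scale * torch.log(1 / torch.tensor(envs.single_action_space.n))`. A small sketch of what that evaluates to, assuming a hypothetical 18-action Atari environment (the numbers are illustrative and are not benchmark settings from this PR):

```python
import torch

n_actions = 18                 # assumed full Atari action set, for illustration
target_entropy_scale = 0.89    # default from parse_args() above

# The maximum possible entropy of a policy over n discrete actions is log(n),
# attained by the uniform policy.
max_entropy = torch.log(torch.tensor(float(n_actions)))

# Same expression as in the script: -scale * log(1/n) == scale * log(n).
target_entropy = -target_entropy_scale * torch.log(1 / torch.tensor(float(n_actions)))

print(max_entropy.item())     # ~2.89 nats
print(target_entropy.item())  # ~2.57 nats, i.e. 89% of the maximum entropy
```

In other words, autotuning drives the temperature so that the policy keeps roughly 89% of the maximum achievable entropy for the given action space.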