diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..3b1d04e
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,2 @@
+*.pth
+*.local
\ No newline at end of file
diff --git a/.idea/misc.xml b/.idea/misc.xml
new file mode 100644
index 0000000..a8008ae
--- /dev/null
+++ b/.idea/misc.xml
@@ -0,0 +1,4 @@
+
+
+
+
\ No newline at end of file
diff --git a/.idea/vcs.xml b/.idea/vcs.xml
new file mode 100644
index 0000000..94a25f7
--- /dev/null
+++ b/.idea/vcs.xml
@@ -0,0 +1,6 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<project version="4">
+ <component name="VcsDirectoryMappings">
+ <mapping directory="$PROJECT_DIR$" vcs="Git" />
+ </component>
+</project>
\ No newline at end of file
diff --git a/PPO.py b/PPO.py
index 3b36d33..611a75c 100644
--- a/PPO.py
+++ b/PPO.py
@@ -1,8 +1,13 @@
+import os
+from collections import deque
+import numpy as np
import torch
import torch.nn as nn
from torch.distributions import Categorical
import gym
+from tensorboardX import SummaryWriter
+
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
class Memory:
@@ -20,6 +25,9 @@ def clear_memory(self):
del self.rewards[:]
del self.is_terminals[:]
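+ # number of stored transitions; the training loop uses this to decide when to update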
+ def __len__(self):
+ return len(self.rewards)
+
class ActorCritic(nn.Module):
def __init__(self, state_dim, action_dim, n_latent_var):
super(ActorCritic, self).__init__()
@@ -46,17 +54,14 @@ def __init__(self, state_dim, action_dim, n_latent_var):
def forward(self):
raise NotImplementedError
- def act(self, state, memory):
+ def act(self, state):
state = torch.from_numpy(state).float().to(device)
action_probs = self.action_layer(state)
dist = Categorical(action_probs)
action = dist.sample()
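+ # also return the action's log-probability so the caller can store it for the PPO ratio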
+ log_prob = dist.log_prob(action)
- memory.states.append(state)
- memory.actions.append(action)
- memory.logprobs.append(dist.log_prob(action))
-
- return action.item()
+ return action.cpu().numpy(), log_prob
def evaluate(self, state, action):
action_probs = self.action_layer(state)
@@ -70,7 +75,7 @@ def evaluate(self, state, action):
return action_logprobs, torch.squeeze(state_value), dist_entropy
class PPO:
- def __init__(self, state_dim, action_dim, n_latent_var, lr, betas, gamma, K_epochs, eps_clip):
+ def __init__(self, state_dim, action_dim, n_latent_var, lr, betas, gamma, K_epochs, eps_clip, writer):
self.lr = lr
self.betas = betas
self.gamma = gamma
@@ -83,8 +88,9 @@ def __init__(self, state_dim, action_dim, n_latent_var, lr, betas, gamma, K_epoc
self.policy_old.load_state_dict(self.policy.state_dict())
self.MseLoss = nn.MSELoss()
+ self.writer = writer
- def update(self, memory):
+ def update(self, memory):
# Monte Carlo estimate of state rewards:
rewards = []
discounted_reward = 0
@@ -95,7 +101,7 @@ def update(self, memory):
rewards.insert(0, discounted_reward)
# Normalizing the rewards:
- rewards = torch.tensor(rewards).to(device)
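+ # memory.rewards now holds tensors, so the discounted returns are stacked rather than wrapped in torch.tensor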
+ rewards = torch.stack(rewards).to(device)
rewards = (rewards - rewards.mean()) / (rewards.std() + 1e-5)
# convert list to tensor
@@ -104,10 +110,15 @@ def update(self, memory):
old_logprobs = torch.stack(memory.logprobs).to(device).detach()
# Optimize policy for K epochs:
+ policy_losses = []
+ value_losses = []
+ entropy_losses = []
+ losses = []
+
for _ in range(self.K_epochs):
# Evaluating old actions and values :
logprobs, state_values, dist_entropy = self.policy.evaluate(old_states, old_actions)
-
+
# Finding the ratio (pi_theta / pi_theta__old):
ratios = torch.exp(logprobs - old_logprobs.detach())
@@ -115,30 +126,49 @@ def update(self, memory):
advantages = rewards - state_values.detach()
surr1 = ratios * advantages
surr2 = torch.clamp(ratios, 1-self.eps_clip, 1+self.eps_clip) * advantages
- loss = -torch.min(surr1, surr2) + 0.5*self.MseLoss(state_values, rewards) - 0.01*dist_entropy
+
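+ # split the clipped objective into its components so each can be logged separately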
+ policy_loss = -torch.min(surr1, surr2)
+ value_loss = 0.5*self.MseLoss(state_values, rewards)
+ entropy_loss = - 0.01*dist_entropy
+ loss = policy_loss + value_loss + entropy_loss
# take gradient step
self.optimizer.zero_grad()
loss.mean().backward()
self.optimizer.step()
-
+
+ policy_losses.append(policy_loss.mean().item())
+ value_losses.append(value_loss.mean().item())
+ entropy_losses.append(entropy_loss.mean().item())
+ losses.append(loss.mean().item())
+
# Copy new weights into old policy:
self.policy_old.load_state_dict(self.policy.state_dict())
+
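+ # average the per-epoch losses for logging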
+ avg_policy_loss = np.mean(policy_losses)
+ avg_value_loss = np.mean(value_losses)
+ avg_entropy_loss = np.mean(entropy_losses)
+ avg_loss = np.mean(losses)
+
+ return avg_policy_loss, avg_value_loss, avg_entropy_loss, avg_loss
+
+
def main():
############## Hyperparameters ##############
+ experiment_name = "ppo_original"
env_name = "LunarLander-v2"
# creating environment
env = gym.make(env_name)
state_dim = env.observation_space.shape[0]
action_dim = 4
render = False
- solved_reward = 230 # stop training if avg_reward > solved_reward
- log_interval = 20 # print avg reward in the interval
+ average_interval = 100 # episodes in the moving-average window for scores and lengths
+ solved_score = 230 # stop training if avg_score > solved_score
max_episodes = 50000 # max training episodes
max_timesteps = 300 # max timesteps in one episode
n_latent_var = 64 # number of variables in hidden layer
- update_timestep = 2000 # update policy every n timesteps
+ update_memory_size = 2000 # update policy once the memory holds this many transitions
lr = 0.002
betas = (0.9, 0.999)
gamma = 0.99 # discount factor
@@ -146,63 +176,81 @@ def main():
eps_clip = 0.2 # clip parameter for PPO
random_seed = None
#############################################
-
+
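+ # per-experiment directory for TensorBoard logs and the saved model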
+ exp_dir = os.path.join("experiments", experiment_name)
+ os.makedirs(exp_dir, exist_ok=True)
+ writer = SummaryWriter(exp_dir)
+
if random_seed:
torch.manual_seed(random_seed)
env.seed(random_seed)
memory = Memory()
- ppo = PPO(state_dim, action_dim, n_latent_var, lr, betas, gamma, K_epochs, eps_clip)
- print(lr,betas)
-
+ ppo = PPO(state_dim, action_dim, n_latent_var, lr, betas, gamma, K_epochs, eps_clip, writer)
+
# logging variables
- running_reward = 0
- avg_length = 0
- timestep = 0
-
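+ # rolling windows over the last average_interval episodes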
+ scores = deque(maxlen=average_interval)
+ lengths = deque(maxlen=average_interval)
+
# training loop
for i_episode in range(1, max_episodes+1):
+ episode_score = 0
state = env.reset()
for t in range(max_timesteps):
- timestep += 1
-
+
# Running policy_old:
- action = ppo.policy_old.act(state, memory)
- state, reward, done, _ = env.step(action)
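+ # select an action without tracking gradients; the transition is stored in Memory below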
+ with torch.no_grad():
+ action, log_prob = ppo.policy_old.act(state)
+ next_state, reward, done, _ = env.step(action)
+
+ # update episode score
+ episode_score += reward
- # Saving reward and is_terminal:
- memory.rewards.append(reward)
+ # Appending to the Memory as tensors
+ memory.states.append(torch.from_numpy(state).float())
+ memory.actions.append(torch.tensor(action).long())
+ memory.logprobs.append(log_prob)
+ memory.rewards.append(torch.tensor(reward).float())
memory.is_terminals.append(done)
- # update if its time
- if timestep % update_timestep == 0:
- ppo.update(memory)
+ # update if the memory is big enough
+ memory_size = len(memory)
+ if memory_size >= update_memory_size:
+ # update ppo
+ avg_policy_loss, avg_value_loss, avg_entropy_loss, avg_loss = ppo.update(memory)
+
+ writer.add_scalar("info/avg_policy_loss", avg_policy_loss, i_episode)
+ writer.add_scalar("info/avg_value_loss", avg_value_loss, i_episode)
+ writer.add_scalar("info/avg_entropy_loss", avg_entropy_loss, i_episode)
+ writer.add_scalar("info/avg_ppo_loss", avg_loss, i_episode)
+
+ # clear memory
memory.clear_memory()
- timestep = 0
-
- running_reward += reward
+
if render:
env.render()
+
+ state = next_state
+
+ # if game is over
if done:
+ scores.append(episode_score)
break
-
- avg_length += t
+
+ # record episode length
+ lengths.append(t)
# stop training if avg_reward > solved_reward
- if running_reward > (log_interval*solved_reward):
+ avg_score = np.mean(scores)
+ avg_length = np.mean(lengths)
+ writer.add_scalar("info/avg_score", avg_score, i_episode)
+ writer.add_scalar("info/avg_length", avg_length, i_episode)
+
+ if avg_score > solved_score:
print("########## Solved! ##########")
- torch.save(ppo.policy.state_dict(), './PPO_{}.pth'.format(env_name))
+ torch.save(ppo.policy.state_dict(), './{}/PPO_{}.pth'.format(exp_dir, env_name))
break
- # logging
- if i_episode % log_interval == 0:
- avg_length = int(avg_length/log_interval)
- running_reward = int((running_reward/log_interval))
-
- print('Episode {} \t avg length: {} \t reward: {}'.format(i_episode, avg_length, running_reward))
- running_reward = 0
- avg_length = 0
-
if __name__ == '__main__':
main()