Skip to content
This repository was archived by the owner on Aug 28, 2025. It is now read-only.
Merged
Changes from 1 commit
Commits
Show all changes
35 commits
Select commit Hold shift + click to select a range
e211e32
augmentation
awaelchli Mar 14, 2023
d7fa299
barlow twins
awaelchli Mar 14, 2023
a42d37f
debug
awaelchli Mar 14, 2023
575b5cc
update
awaelchli Mar 14, 2023
1fe713e
multi
awaelchli Mar 14, 2023
a90e5a2
update
awaelchli Mar 14, 2023
24acb0c
barlow
awaelchli Mar 14, 2023
856baa7
gan
awaelchli Mar 14, 2023
b0fd90b
datamodule
awaelchli Mar 14, 2023
86e67ef
x
awaelchli Mar 14, 2023
145cf1c
update
awaelchli Mar 14, 2023
59475e3
update
awaelchli Mar 14, 2023
7a2411f
update
awaelchli Mar 14, 2023
f2c16d2
update
awaelchli Mar 14, 2023
ccb4c1e
dqn
awaelchli Mar 14, 2023
b934208
debug
awaelchli Mar 14, 2023
afd61ab
x
awaelchli Mar 14, 2023
f78cd2e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 14, 2023
6aa48b2
lightning>=2.0.0rc0
Borda Mar 14, 2023
108724e
2.0
Borda Mar 14, 2023
73477fe
Merge branch 'main' into upgrade/examples
awaelchli Mar 14, 2023
1b6eaf3
cifar10 baseline
awaelchli Mar 14, 2023
a871081
update
awaelchli Mar 14, 2023
ef01f36
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 14, 2023
a2b0a1d
update gitingore
awaelchli Mar 14, 2023
2795750
Merge remote-tracking branch 'origin/upgrade/examples' into upgrade/e…
awaelchli Mar 14, 2023
d2f9246
bolts
awaelchli Mar 14, 2023
6a7a4f7
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 14, 2023
da6f9e3
Merge remote-tracking branch 'origin/upgrade/examples' into upgrade/e…
awaelchli Mar 14, 2023
3017f5c
fix datamodules
awaelchli Mar 14, 2023
b1e9acf
revert to old
awaelchli Mar 14, 2023
f95a03f
Merge branch 'main' into upgrade/examples
Borda Mar 14, 2023
9577958
Merge remote-tracking branch 'origin/upgrade/examples' into upgrade/e…
awaelchli Mar 14, 2023
bc383c3
Merge branch 'main' into upgrade/examples
Borda Mar 14, 2023
dfba373
2.0
Borda Mar 14, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
revert to old
  • Loading branch information
awaelchli committed Mar 14, 2023
commit b1e9acf3a6a864b8f9897931d0101a4864818181
26 changes: 18 additions & 8 deletions lightning_examples/reinforce-learning-DQN/dqn.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,13 @@
from typing import Iterator, List, Tuple

import gym
import lightning as L
import numpy as np
import pandas as pd
import seaborn as sn
import torch
from IPython.display import display
from lightning.pytorch.loggers import CSVLogger
from IPython.core.display import display
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.loggers import CSVLogger
from torch import Tensor, nn
from torch.optim import Adam, Optimizer
from torch.utils.data import DataLoader
Expand Down Expand Up @@ -176,7 +176,7 @@ def play_step(
action = self.get_action(net, epsilon, device)

# do step in the environment
new_state, reward, done, _, _ = self.env.step(action)
new_state, reward, done, _ = self.env.step(action)

exp = Experience(self.state, action, reward, done, new_state)

Expand All @@ -193,7 +193,7 @@ def play_step(


# %%
class DQNLightning(L.LightningModule):
class DQNLightning(LightningModule):
"""Basic DQN Model."""

def __init__(
Expand Down Expand Up @@ -291,7 +291,17 @@ def get_epsilon(self, start: int, end: int, frames: int) -> float:
return end
return start - (self.global_step / frames) * (start - end)

def training_step(self, batch, batch_idx):
def training_step(self, batch: Tuple[Tensor, Tensor], nb_batch) -> OrderedDict:
"""Carries out a single step through the environment to update the replay buffer. Then calculates loss
based on the minibatch recieved.

Args:
batch: current mini batch of replay data
nb_batch: batch number

Returns:
Training loss and log metrics
"""
device = self.get_device(batch)
epsilon = self.get_epsilon(self.hparams.eps_start, self.hparams.eps_end, self.hparams.eps_last_frame)
self.log("epsilon", epsilon)
Expand Down Expand Up @@ -353,9 +363,9 @@ def get_device(self, batch) -> str:

model = DQNLightning()

trainer = L.Trainer(
trainer = Trainer(
accelerator="auto",
devices=1,
devices=1 if torch.cuda.is_available() else None, # limiting got iPython runs
max_epochs=150,
val_check_interval=50,
logger=CSVLogger(save_dir="logs/"),
Expand Down