Skip to content
This repository was archived by the owner on Aug 28, 2025. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
35 commits
Select commit Hold shift + click to select a range
e211e32
augmentation
awaelchli Mar 14, 2023
d7fa299
barlow twins
awaelchli Mar 14, 2023
a42d37f
debug
awaelchli Mar 14, 2023
575b5cc
update
awaelchli Mar 14, 2023
1fe713e
multi
awaelchli Mar 14, 2023
a90e5a2
update
awaelchli Mar 14, 2023
24acb0c
barlow
awaelchli Mar 14, 2023
856baa7
gan
awaelchli Mar 14, 2023
b0fd90b
datamodule
awaelchli Mar 14, 2023
86e67ef
x
awaelchli Mar 14, 2023
145cf1c
update
awaelchli Mar 14, 2023
59475e3
update
awaelchli Mar 14, 2023
7a2411f
update
awaelchli Mar 14, 2023
f2c16d2
update
awaelchli Mar 14, 2023
ccb4c1e
dqn
awaelchli Mar 14, 2023
b934208
debug
awaelchli Mar 14, 2023
afd61ab
x
awaelchli Mar 14, 2023
f78cd2e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 14, 2023
6aa48b2
lightning>=2.0.0rc0
Borda Mar 14, 2023
108724e
2.0
Borda Mar 14, 2023
73477fe
Merge branch 'main' into upgrade/examples
awaelchli Mar 14, 2023
1b6eaf3
cifar10 baseline
awaelchli Mar 14, 2023
a871081
update
awaelchli Mar 14, 2023
ef01f36
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 14, 2023
a2b0a1d
update gitingore
awaelchli Mar 14, 2023
2795750
Merge remote-tracking branch 'origin/upgrade/examples' into upgrade/e…
awaelchli Mar 14, 2023
d2f9246
bolts
awaelchli Mar 14, 2023
6a7a4f7
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 14, 2023
da6f9e3
Merge remote-tracking branch 'origin/upgrade/examples' into upgrade/e…
awaelchli Mar 14, 2023
3017f5c
fix datamodules
awaelchli Mar 14, 2023
b1e9acf
revert to old
awaelchli Mar 14, 2023
f95a03f
Merge branch 'main' into upgrade/examples
Borda Mar 14, 2023
9577958
Merge remote-tracking branch 'origin/upgrade/examples' into upgrade/e…
awaelchli Mar 14, 2023
bc383c3
Merge branch 'main' into upgrade/examples
Borda Mar 14, 2023
dfba373
2.0
Borda Mar 14, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
datamodule
  • Loading branch information
awaelchli committed Mar 14, 2023
commit b0fd90bc84220922101c66989852d769451534f6
2 changes: 1 addition & 1 deletion lightning_examples/datamodules/.meta.yml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
title: PyTorch Lightning DataModules
author: PL team
created: 2020-12-21
updated: 2021-06-07
updated: 2023-03-15
license: CC BY-SA
build: 3
description: This notebook will walk you through how to start using Datamodules. With
Expand Down
33 changes: 14 additions & 19 deletions lightning_examples/datamodules/datamodules.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,7 @@

import torch
import torch.nn.functional as F
from pytorch_lightning import LightningDataModule, LightningModule, Trainer
from pytorch_lightning.callbacks.progress import TQDMProgressBar
import lightning as L
from torch import nn
from torch.utils.data import DataLoader, random_split
from torchmetrics.functional import accuracy
Expand All @@ -34,7 +33,7 @@


# %%
class LitMNIST(LightningModule):
class LitMNIST(L.LightningModule):
def __init__(self, data_dir=PATH_DATASETS, hidden_size=64, learning_rate=2e-4):
super().__init__()

Expand Down Expand Up @@ -69,13 +68,13 @@ def forward(self, x):
x = self.model(x)
return F.log_softmax(x, dim=1)

def training_step(self, batch, batch_idx):
def training_step(self, batch):
x, y = batch
logits = self(x)
loss = F.nll_loss(logits, y)
return loss

def validation_step(self, batch, batch_idx):
def validation_step(self, batch):
x, y = batch
logits = self(x)
loss = F.nll_loss(logits, y)
Expand Down Expand Up @@ -122,11 +121,10 @@ def test_dataloader(self):

# %%
model = LitMNIST()
trainer = Trainer(
trainer = L.Trainer(
max_epochs=2,
accelerator="auto",
devices=1 if torch.cuda.is_available() else None, # limiting got iPython runs
callbacks=[TQDMProgressBar(refresh_rate=20)],
devices=1,
)
trainer.fit(model)

Expand Down Expand Up @@ -163,7 +161,7 @@ def test_dataloader(self):


# %%
class MNISTDataModule(LightningDataModule):
class MNISTDataModule(L.LightningDataModule):
def __init__(self, data_dir: str = PATH_DATASETS):
super().__init__()
self.data_dir = data_dir
Expand Down Expand Up @@ -211,7 +209,7 @@ def test_dataloader(self):


# %%
class LitModel(LightningModule):
class LitModel(L.LightningModule):
def __init__(self, channels, width, height, num_classes, hidden_size=64, learning_rate=2e-4):
super().__init__()

Expand All @@ -238,13 +236,13 @@ def forward(self, x):
x = self.model(x)
return F.log_softmax(x, dim=1)

def training_step(self, batch, batch_idx):
def training_step(self, batch):
x, y = batch
logits = self(x)
loss = F.nll_loss(logits, y)
return loss

def validation_step(self, batch, batch_idx):
def validation_step(self, batch):
x, y = batch
logits = self(x)
loss = F.nll_loss(logits, y)
Expand All @@ -269,11 +267,10 @@ def configure_optimizers(self):
# Init model from datamodule's attributes
model = LitModel(*dm.dims, dm.num_classes)
# Init trainer
trainer = Trainer(
trainer = L.Trainer(
max_epochs=3,
callbacks=[TQDMProgressBar(refresh_rate=20)],
accelerator="auto",
devices=1 if torch.cuda.is_available() else None, # limiting got iPython runs
devices=1,
)
# Pass the datamodule as arg to trainer.fit to override model hooks :)
trainer.fit(model, dm)
Expand All @@ -285,7 +282,7 @@ def configure_optimizers(self):


# %%
class CIFAR10DataModule(LightningDataModule):
class CIFAR10DataModule(L.LightningDataModule):
def __init__(self, data_dir: str = "./"):
super().__init__()
self.data_dir = data_dir
Expand Down Expand Up @@ -334,11 +331,9 @@ def test_dataloader(self):
# %%
dm = CIFAR10DataModule()
model = LitModel(*dm.dims, dm.num_classes, hidden_size=256)
tqdm_progress_bar = TQDMProgressBar(refresh_rate=20)
trainer = Trainer(
max_epochs=5,
accelerator="auto",
devices=1 if torch.cuda.is_available() else None, # limiting got iPython runs
callbacks=[tqdm_progress_bar],
devices=1,
)
trainer.fit(model, dm)