This repository was archived by the owner on Aug 28, 2025. It is now read-only.
Merged
Changes from 1 commit
Commits (35)
e211e32  augmentation (awaelchli, Mar 14, 2023)
d7fa299  barlow twins (awaelchli, Mar 14, 2023)
a42d37f  debug (awaelchli, Mar 14, 2023)
575b5cc  update (awaelchli, Mar 14, 2023)
1fe713e  multi (awaelchli, Mar 14, 2023)
a90e5a2  update (awaelchli, Mar 14, 2023)
24acb0c  barlow (awaelchli, Mar 14, 2023)
856baa7  gan (awaelchli, Mar 14, 2023)
b0fd90b  datamodule (awaelchli, Mar 14, 2023)
86e67ef  x (awaelchli, Mar 14, 2023)
145cf1c  update (awaelchli, Mar 14, 2023)
59475e3  update (awaelchli, Mar 14, 2023)
7a2411f  update (awaelchli, Mar 14, 2023)
f2c16d2  update (awaelchli, Mar 14, 2023)
ccb4c1e  dqn (awaelchli, Mar 14, 2023)
b934208  debug (awaelchli, Mar 14, 2023)
afd61ab  x (awaelchli, Mar 14, 2023)
f78cd2e  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Mar 14, 2023)
6aa48b2  lightning>=2.0.0rc0 (Borda, Mar 14, 2023)
108724e  2.0 (Borda, Mar 14, 2023)
73477fe  Merge branch 'main' into upgrade/examples (awaelchli, Mar 14, 2023)
1b6eaf3  cifar10 baseline (awaelchli, Mar 14, 2023)
a871081  update (awaelchli, Mar 14, 2023)
ef01f36  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Mar 14, 2023)
a2b0a1d  update gitingore (awaelchli, Mar 14, 2023)
2795750  Merge remote-tracking branch 'origin/upgrade/examples' into upgrade/e… (awaelchli, Mar 14, 2023)
d2f9246  bolts (awaelchli, Mar 14, 2023)
6a7a4f7  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Mar 14, 2023)
da6f9e3  Merge remote-tracking branch 'origin/upgrade/examples' into upgrade/e… (awaelchli, Mar 14, 2023)
3017f5c  fix datamodules (awaelchli, Mar 14, 2023)
b1e9acf  revert to old (awaelchli, Mar 14, 2023)
f95a03f  Merge branch 'main' into upgrade/examples (Borda, Mar 14, 2023)
9577958  Merge remote-tracking branch 'origin/upgrade/examples' into upgrade/e… (awaelchli, Mar 14, 2023)
bc383c3  Merge branch 'main' into upgrade/examples (Borda, Mar 14, 2023)
dfba373  2.0 (Borda, Mar 14, 2023)
gan
awaelchli committed Mar 14, 2023
commit 856baa7a117271f85e071fe851843a45c4bbfd22
2 changes: 1 addition & 1 deletion lightning_examples/basic-gan/.meta.yaml
@@ -1,7 +1,7 @@
 title: PyTorch Lightning Basic GAN Tutorial
 author: PL team
 created: 2020-12-21
-updated: 2021-06-16
+updated: 2023-03-15
 license: CC BY-SA
 build: 5
 tags:
84 changes: 45 additions & 39 deletions lightning_examples/basic-gan/gan.py
@@ -7,8 +7,7 @@
 import torch.nn.functional as F
 import torchvision
 import torchvision.transforms as transforms
-from pytorch_lightning import LightningDataModule, LightningModule, Trainer
-from pytorch_lightning.callbacks.progress import TQDMProgressBar
+import lightning as L
 from torch.utils.data import DataLoader, random_split
 from torchvision.datasets import MNIST

@@ -20,11 +19,11 @@
 # ### MNIST DataModule
 #
 # Below, we define a DataModule for the MNIST Dataset. To learn more about DataModules, check out our tutorial
-# on them or see the [latest release docs](https://pytorch-lightning.readthedocs.io/en/stable/data/datamodule.html).
+# on them or see the [latest release docs](https://lightning.ai/docs/pytorch/stable/data/datamodule.html).


 # %%
-class MNISTDataModule(LightningDataModule):
+class MNISTDataModule(L.LightningDataModule):
     def __init__(
         self,
         data_dir: str = PATH_DATASETS,
@@ -145,7 +144,7 @@ def forward(self, img):


 # %%
-class GAN(LightningModule):
+class GAN(L.LightningModule):
     def __init__(
         self,
         channels,
@@ -160,6 +159,7 @@ def __init__(
     ):
         super().__init__()
         self.save_hyperparameters()
+        self.automatic_optimization = False

         # networks
         data_shape = (channels, width, height)
@@ -176,54 +176,61 @@ def forward(self, z):
     def adversarial_loss(self, y_hat, y):
         return F.binary_cross_entropy(y_hat, y)

-    def training_step(self, batch, batch_idx, optimizer_idx):
+    def training_step(self, batch):
         imgs, _ = batch

+        optimizer_g, optimizer_d = self.optimizers()
+
         # sample noise
         z = torch.randn(imgs.shape[0], self.hparams.latent_dim)
         z = z.type_as(imgs)

         # train generator
-        if optimizer_idx == 0:
-
-            # generate images
-            self.generated_imgs = self(z)
+        # generate images
+        self.toggle_optimizer(optimizer_g)
+        self.generated_imgs = self(z)

-            # log sampled images
-            sample_imgs = self.generated_imgs[:6]
-            grid = torchvision.utils.make_grid(sample_imgs)
-            self.logger.experiment.add_image("generated_images", grid, 0)
+        # log sampled images
+        sample_imgs = self.generated_imgs[:6]
+        grid = torchvision.utils.make_grid(sample_imgs)
+        self.logger.experiment.add_image("generated_images", grid, 0)

-            # ground truth result (ie: all fake)
-            # put on GPU because we created this tensor inside training_loop
-            valid = torch.ones(imgs.size(0), 1)
-            valid = valid.type_as(imgs)
+        # ground truth result (ie: all fake)
+        # put on GPU because we created this tensor inside training_loop
+        valid = torch.ones(imgs.size(0), 1)
+        valid = valid.type_as(imgs)

-            # adversarial loss is binary cross-entropy
-            g_loss = self.adversarial_loss(self.discriminator(self(z)), valid)
-            self.log("g_loss", g_loss, prog_bar=True)
-            return g_loss
+        # adversarial loss is binary cross-entropy
+        g_loss = self.adversarial_loss(self.discriminator(self(z)), valid)
+        self.log("g_loss", g_loss, prog_bar=True)
+        self.manual_backward(g_loss)
+        optimizer_g.step()
+        optimizer_g.zero_grad()
+        self.untoggle_optimizer(optimizer_g)

         # train discriminator
-        if optimizer_idx == 1:
-            # Measure discriminator's ability to classify real from generated samples
+        # Measure discriminator's ability to classify real from generated samples
+        self.toggle_optimizer(optimizer_d)

-            # how well can it label as real?
-            valid = torch.ones(imgs.size(0), 1)
-            valid = valid.type_as(imgs)
+        # how well can it label as real?
+        valid = torch.ones(imgs.size(0), 1)
+        valid = valid.type_as(imgs)

-            real_loss = self.adversarial_loss(self.discriminator(imgs), valid)
+        real_loss = self.adversarial_loss(self.discriminator(imgs), valid)

-            # how well can it label as fake?
-            fake = torch.zeros(imgs.size(0), 1)
-            fake = fake.type_as(imgs)
+        # how well can it label as fake?
+        fake = torch.zeros(imgs.size(0), 1)
+        fake = fake.type_as(imgs)

-            fake_loss = self.adversarial_loss(self.discriminator(self(z).detach()), fake)
+        fake_loss = self.adversarial_loss(self.discriminator(self(z).detach()), fake)

-            # discriminator loss is the average of these
-            d_loss = (real_loss + fake_loss) / 2
-            self.log("d_loss", d_loss, prog_bar=True)
-            return d_loss
+        # discriminator loss is the average of these
+        d_loss = (real_loss + fake_loss) / 2
+        self.log("d_loss", d_loss, prog_bar=True)
+        self.manual_backward(d_loss)
+        optimizer_d.step()
+        optimizer_d.zero_grad()
+        self.untoggle_optimizer(optimizer_d)

     def configure_optimizers(self):
         lr = self.hparams.lr
@@ -246,11 +253,10 @@ def on_validation_epoch_end(self):
 # %%
 dm = MNISTDataModule()
 model = GAN(*dm.dims)
-trainer = Trainer(
+trainer = L.Trainer(
     accelerator="auto",
-    devices=1 if torch.cuda.is_available() else None,  # limiting got iPython runs
+    devices=1,
     max_epochs=5,
-    callbacks=[TQDMProgressBar(refresh_rate=20)],
 )
 trainer.fit(model, dm)

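For anyone upgrading their own multi-optimizer modules along the same lines, the pattern this commit adopts can be reduced to a short sketch. The module below is hypothetical (toy networks and losses, not the GAN from the tutorial); only the Lightning 2.0 calls it demonstrates, `automatic_optimization = False`, `self.optimizers()`, `toggle_optimizer`/`untoggle_optimizer`, `manual_backward`, and the explicit `step()`/`zero_grad()`, mirror what the diff above does.

```python
import torch
import lightning as L


class TwoOptimizerSketch(L.LightningModule):
    """Hypothetical module showing the Lightning 2.0 manual-optimization flow."""

    def __init__(self):
        super().__init__()
        # Lightning 2.0 removed the `optimizer_idx` argument; opt into manual optimization instead.
        self.automatic_optimization = False
        self.net_a = torch.nn.Linear(8, 1)
        self.net_b = torch.nn.Linear(8, 1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        opt_a, opt_b = self.optimizers()

        # First optimizer: toggle so only its parameters are trained, then step manually.
        self.toggle_optimizer(opt_a)
        loss_a = torch.nn.functional.mse_loss(self.net_a(x), y)
        self.manual_backward(loss_a)
        opt_a.step()
        opt_a.zero_grad()
        self.untoggle_optimizer(opt_a)

        # Second optimizer: same sequence.
        self.toggle_optimizer(opt_b)
        loss_b = torch.nn.functional.mse_loss(self.net_b(x), y)
        self.manual_backward(loss_b)
        opt_b.step()
        opt_b.zero_grad()
        self.untoggle_optimizer(opt_b)

    def configure_optimizers(self):
        return [
            torch.optim.Adam(self.net_a.parameters(), lr=1e-3),
            torch.optim.Adam(self.net_b.parameters(), lr=1e-3),
        ]
```

The old `optimizer_idx` branches in `training_step` have no equivalent in 2.0, which is why the GAN example switches to this explicit per-optimizer loop.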
2 changes: 1 addition & 1 deletion lightning_examples/cifar10-baseline/baseline.py
@@ -76,7 +76,7 @@ def create_model():

 # %% [markdown]
 # ### Lightning Module
-# Check out the [`configure_optimizers`](https://pytorch-lightning.readthedocs.io/en/stable/common/lightning_module.html#configure-optimizers)
+# Check out the [`configure_optimizers`](https://lightning.ai/docs/pytorch/stable/common/lightning_module.html#configure-optimizers)
 # method to use custom Learning Rate schedulers. The OneCycleLR with SGD will get you to around 92-93% accuracy
 # in 20-30 epochs and 93-94% accuracy in 40-50 epochs. Feel free to experiment with different
 # LR schedules from https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate
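The markdown cell above points readers at `configure_optimizers` for custom LR schedules. As a rough sketch (assumed hyperparameters and dataset size, not the tutorial's exact code), wiring OneCycleLR with SGD in a LightningModule looks like this:

```python
import torch
import lightning as L


class LitResnetSketch(L.LightningModule):
    # Model definition and training_step omitted; this sketch only shows the optimizer wiring.

    def configure_optimizers(self):
        optimizer = torch.optim.SGD(self.parameters(), lr=0.05, momentum=0.9, weight_decay=5e-4)
        scheduler = torch.optim.lr_scheduler.OneCycleLR(
            optimizer,
            max_lr=0.1,                      # illustrative peak learning rate
            epochs=self.trainer.max_epochs,  # schedule spans the whole run
            steps_per_epoch=45000 // 256,    # assumed train-set size / batch size
        )
        # "interval": "step" tells Lightning to advance the schedule every batch.
        return {
            "optimizer": optimizer,
            "lr_scheduler": {"scheduler": scheduler, "interval": "step"},
        }
```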
2 changes: 1 addition & 1 deletion lightning_examples/datamodules/.meta.yml
@@ -8,7 +8,7 @@ description: This notebook will walk you through how to start using Datamodules.
   the release of `pytorch-lightning` version 0.9.0, we have included a new class called
   `LightningDataModule` to help you decouple data related hooks from your `LightningModule`.
   The most up-to-date documentation on datamodules can be found
-  [here](https://pytorch-lightning.readthedocs.io/en/stable/data/datamodule.html).
+  [here](https://lightning.ai/docs/pytorch/stable/data/datamodule.html).
 requirements:
   - torchvision
 accelerator:
@@ -34,7 +34,7 @@
 # criteria (a multi-phase extension of ``EarlyStopping`` packaged with FinetuningScheduler), user-specified epoch transitions or a composition of the two (the default mode).
 # A [FinetuningScheduler](https://finetuning-scheduler.readthedocs.io/en/stable/api/finetuning_scheduler.fts.html#finetuning_scheduler.fts.FinetuningScheduler) training session completes when the
 # final phase of the schedule has its stopping criteria met. See
-# the [early stopping documentation](https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.callbacks.EarlyStopping.html) for more details on that callback's configuration.
+# the [early stopping documentation](https://lightning.ai/docs/pytorch/stable/api/pytorch_lightning.callbacks.EarlyStopping.html) for more details on that callback's configuration.
 #
 # ![FinetuningScheduler explicit loss animation](fts_explicit_loss_anim.gif){height="272px" width="376px"}

@@ -118,7 +118,7 @@
 # ## Resuming Scheduled Fine-Tuning Training Sessions
 #
 # Resumption of scheduled fine-tuning training is identical to the continuation of
-# [other training sessions](https://pytorch-lightning.readthedocs.io/en/stable/common/trainer.html) with the caveat that the provided checkpoint must have been saved by a [FinetuningScheduler](https://finetuning-scheduler.readthedocs.io/en/stable/api/finetuning_scheduler.fts.html#finetuning_scheduler.fts.FinetuningScheduler) session.
+# [other training sessions](https://lightning.ai/docs/pytorch/stable/common/trainer.html) with the caveat that the provided checkpoint must have been saved by a [FinetuningScheduler](https://finetuning-scheduler.readthedocs.io/en/stable/api/finetuning_scheduler.fts.html#finetuning_scheduler.fts.FinetuningScheduler) session.
 # [FinetuningScheduler](https://finetuning-scheduler.readthedocs.io/en/stable/api/finetuning_scheduler.fts.html#finetuning_scheduler.fts.FinetuningScheduler) uses [FTSCheckpoint](https://finetuning-scheduler.readthedocs.io/en/stable/api/finetuning_scheduler.fts_supporters.html#finetuning_scheduler.fts_supporters.FTSCheckpoint) (an extension of ``ModelCheckpoint``) to maintain schedule state with special metadata.
 #
 #
@@ -138,7 +138,7 @@
 # trainer.fit(..., ckpt_path="some/path/to/my_kth_best_checkpoint.ckpt")
 # ```
 #
-# Note that similar to the behavior of [ModelCheckpoint](https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.callbacks.ModelCheckpoint.html), (specifically [this PR](https://github.com/Lightning-AI/lightning/pull/12045)),
+# Note that similar to the behavior of [ModelCheckpoint](https://lightning.ai/docs/pytorch/stable/api/pytorch_lightning.callbacks.ModelCheckpoint.html), (specifically [this PR](https://github.com/Lightning-AI/lightning/pull/12045)),
 # when resuming training with a different [FTSCheckpoint](https://finetuning-scheduler.readthedocs.io/en/stable/api/finetuning_scheduler.fts_supporters.html#finetuning_scheduler.fts_supporters.FTSCheckpoint) ``dirpath`` from the provided
 # checkpoint, the new training session's checkpoint state will be re-initialized at the resumption depth with the provided checkpoint being set as the best checkpoint.

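The hunks above only swap documentation links, but for context, a resumed FinetuningScheduler session looks roughly like the sketch below. The model, datamodule, schedule path, and checkpoint path are placeholders; the callback and the `ckpt_path` argument follow the finetuning-scheduler and Lightning docs linked above, and exact keyword names should be verified against those docs.

```python
import lightning as L
from finetuning_scheduler import FinetuningScheduler

# `model` and `datamodule` stand in for whatever LightningModule / LightningDataModule
# the fine-tuning session was originally launched with (hypothetical placeholders).
trainer = L.Trainer(
    accelerator="auto",
    devices=1,
    callbacks=[FinetuningScheduler(ft_schedule="path/to/ft_schedule.yaml")],
)

# Resume from a checkpoint written by a previous FinetuningScheduler session;
# FTSCheckpoint stores the extra schedule metadata the callback needs.
trainer.fit(model, datamodule=datamodule, ckpt_path="some/path/to/my_kth_best_checkpoint.ckpt")
```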
6 changes: 3 additions & 3 deletions lightning_examples/mnist-hello-world/hello-world.py
@@ -84,17 +84,17 @@ def configure_optimizers(self):
 #
 # ### Note what the following built-in functions are doing:
 #
-# 1. [prepare_data()](https://pytorch-lightning.readthedocs.io/en/stable/common/lightning_module.html#prepare-data) 💾
+# 1. [prepare_data()](https://lightning.ai/docs/pytorch/stable/common/lightning_module.html#prepare-data) 💾
 # - This is where we can download the dataset. We point to our desired dataset and ask torchvision's `MNIST` dataset class to download if the dataset isn't found there.
 # - **Note we do not make any state assignments in this function** (i.e. `self.something = ...`)
 #
-# 2. [setup(stage)](https://pytorch-lightning.readthedocs.io/en/stable/common/lightning_module.html#setup) ⚙️
+# 2. [setup(stage)](https://lightning.ai/docs/pytorch/stable/common/lightning_module.html#setup) ⚙️
 # - Loads in data from file and prepares PyTorch tensor datasets for each split (train, val, test).
 # - Setup expects a 'stage' arg which is used to separate logic for 'fit' and 'test'.
 # - If you don't mind loading all your datasets at once, you can set up a condition to allow for both 'fit' related setup and 'test' related setup to run whenever `None` is passed to `stage` (or ignore it altogether and exclude any conditionals).
 # - **Note this runs across all GPUs and it *is* safe to make state assignments here**
 #
-# 3. [x_dataloader()](https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.core.hooks.DataHooks.html#pytorch_lightning.core.hooks.DataHooks.train_dataloader) ♻️
+# 3. [x_dataloader()](https://lightning.ai/docs/pytorch/stable/api/pytorch_lightning.core.hooks.DataHooks.html#pytorch_lightning.core.hooks.DataHooks.train_dataloader) ♻️
 # - `train_dataloader()`, `val_dataloader()`, and `test_dataloader()` all return PyTorch `DataLoader` instances that are created by wrapping their respective datasets that we prepared in `setup()`


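The three notes above map directly onto the DataModule hooks. As a minimal sketch (assumed paths, split sizes, and batch size, not the tutorial's exact class), the division of responsibilities looks like this:

```python
import lightning as L
from torch.utils.data import DataLoader, random_split
from torchvision import transforms
from torchvision.datasets import MNIST


class MNISTDataModuleSketch(L.LightningDataModule):
    def __init__(self, data_dir: str = "./data", batch_size: int = 64):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.transform = transforms.ToTensor()

    def prepare_data(self):
        # 1. Download only; runs once, so no state assignments here.
        MNIST(self.data_dir, train=True, download=True)
        MNIST(self.data_dir, train=False, download=True)

    def setup(self, stage: str):
        # 2. Build the per-split datasets; runs on every process, safe to assign state.
        if stage == "fit":
            full = MNIST(self.data_dir, train=True, transform=self.transform)
            self.mnist_train, self.mnist_val = random_split(full, [55000, 5000])
        if stage == "test":
            self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform)

    # 3. Dataloaders wrap the datasets prepared in setup().
    def train_dataloader(self):
        return DataLoader(self.mnist_train, batch_size=self.batch_size)

    def val_dataloader(self):
        return DataLoader(self.mnist_val, batch_size=self.batch_size)

    def test_dataloader(self):
        return DataLoader(self.mnist_test, batch_size=self.batch_size)
```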
2 changes: 1 addition & 1 deletion lightning_examples/mnist-tpu-training/.meta.yml
@@ -8,7 +8,7 @@ tags:
   - Image
 description: In this notebook, we'll train a model on TPUs. Updating one Trainer flag is all you need for that.
   The most up to documentation related to TPU training can be found
-  [here](https://pytorch-lightning.readthedocs.io/en/stable/accelerators/tpu.html).
+  [here](https://lightning.ai/docs/pytorch/stable/accelerators/tpu.html).
 requirements:
   - torchvision
 accelerator:
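As the description says, moving this example to TPUs is a single Trainer argument. A minimal sketch of what that flag looks like in Lightning 2.0 (the model and datamodule are whatever the notebook defines, referenced here only in a comment):

```python
import lightning as L

# On a TPU runtime, requesting the TPU accelerator is the only change:
trainer = L.Trainer(accelerator="tpu", devices=8, max_epochs=3)
# trainer.fit(model, datamodule=dm)  # model/dm as defined elsewhere in the notebook
```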
2 changes: 1 addition & 1 deletion lightning_examples/mnist-tpu-training/mnist-tpu.py
@@ -23,7 +23,7 @@
 # ### Defining The `MNISTDataModule`
 #
 # Below we define `MNISTDataModule`. You can learn more about datamodules
-# in [docs](https://pytorch-lightning.readthedocs.io/en/stable/data/datamodule.html).
+# in [docs](https://lightning.ai/docs/pytorch/stable/data/datamodule.html).


 # %%