This repository was archived by the owner on Aug 28, 2025. It is now read-only.
Merged
Changes from 1 commit
Commits (29):
c1c8168  update optimization (awaelchli, Mar 13, 2023)
0504638  update inception (awaelchli, Mar 13, 2023)
da590db  update attention (awaelchli, Mar 13, 2023)
ec3d50a  incomplete GNN (awaelchli, Mar 13, 2023)
0b92f1d  update energy models (awaelchli, Mar 13, 2023)
a5466e4  update deep autoencoders (awaelchli, Mar 13, 2023)
88722ac  normalizing flows incomplete (awaelchli, Mar 13, 2023)
47f1cfe  autoregressive (awaelchli, Mar 13, 2023)
cfab2ee  vit (awaelchli, Mar 13, 2023)
b4d7be3  meta learning (awaelchli, Mar 13, 2023)
8ab731c  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Mar 13, 2023)
18c2430  Apply suggestions from code review (awaelchli, Mar 13, 2023)
000c803  update gnn (awaelchli, Mar 13, 2023)
3824605  simclr (awaelchli, Mar 13, 2023)
d1a93e7  update (awaelchli, Mar 13, 2023)
6e04b38  Merge branch 'upgrade/course' of github.com:Lightning-AI/tutorials in… (awaelchli, Mar 13, 2023)
9377298  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Mar 13, 2023)
b24b122  Update course_UvA-DL/05-transformers-and-MH-attention/Transformers_MH… (awaelchli, Mar 13, 2023)
099c50a  update (awaelchli, Mar 13, 2023)
82f2be6  links (awaelchli, Mar 13, 2023)
a6a1a17  Merge branch 'main' into upgrade/course (Borda, Mar 14, 2023)
5733cc0  Merge branch 'main' into upgrade/course (Borda, Mar 14, 2023)
28accf9  2.0.0rc0 (Borda, Mar 14, 2023)
10615e9  lightning (Borda, Mar 14, 2023)
217be55  2.0 (Borda, Mar 14, 2023)
1dc811a  Apply suggestions from code review (awaelchli, Mar 14, 2023)
bb1e40a  docs build fix (awaelchli, Mar 14, 2023)
835a628  stupid sphinx (awaelchli, Mar 14, 2023)
94e6725  Merge branch 'main' into upgrade/course (Borda, Mar 14, 2023)
simclr
awaelchli committed Mar 13, 2023
commit 3824605b80f9a674678dac775dd436d30e8f0c7f
41 changes: 21 additions & 20 deletions course_UvA-DL/13-contrastive-learning/SimCLR.py
@@ -35,23 +35,23 @@

import matplotlib
import matplotlib.pyplot as plt
-import pytorch_lightning as pl
+import lightning as L
import seaborn as sns
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data as data
import torchvision
-from IPython.display import set_matplotlib_formats
-from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
+import matplotlib_inline.backend_inline
+from lightning.pytorch.callbacks import LearningRateMonitor, ModelCheckpoint
from torchvision import transforms
from torchvision.datasets import STL10
from tqdm.notebook import tqdm

plt.set_cmap("cividis")
# %matplotlib inline
-set_matplotlib_formats("svg", "pdf")  # For export
+matplotlib_inline.backend_inline.set_matplotlib_formats("svg", "pdf")  # For export
matplotlib.rcParams["lines.linewidth"] = 2.0
sns.set()

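For readers applying the same migration to their own code, a condensed sketch of what this import hunk amounts to (assuming Lightning >= 2.0 and the matplotlib-inline package; the grouping is illustrative, not the file's exact layout):

# Lightning 1.x style, removed by this commit:
#   import pytorch_lightning as pl
#   from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
#   from IPython.display import set_matplotlib_formats
# Lightning 2.0 style, added by this commit:
import lightning as L
from lightning.pytorch.callbacks import LearningRateMonitor, ModelCheckpoint

# set_matplotlib_formats is no longer imported from IPython.display; the
# replacement lives in the matplotlib-inline package:
import matplotlib_inline.backend_inline

matplotlib_inline.backend_inline.set_matplotlib_formats("svg", "pdf")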
@@ -67,7 +67,7 @@
NUM_WORKERS = os.cpu_count()

# Setting the seed
-pl.seed_everything(42)
+L.seed_everything(42)

# Ensure that all operations are deterministic on GPU (if used) for reproducibility
torch.backends.cudnn.deterministic = True
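As a side note on this hunk, a minimal reproducibility sketch of how the surrounding setup is intended to read; the workers=True flag and the cudnn.benchmark line are illustrative additions, not part of the tutorial:

import torch
import lightning as L

L.seed_everything(42, workers=True)  # seeds Python, NumPy, and PyTorch; workers=True also seeds DataLoader workers
torch.backends.cudnn.deterministic = True  # request deterministic cuDNN kernels
torch.backends.cudnn.benchmark = False  # autotuning can otherwise pick non-deterministic algorithms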
@@ -215,7 +215,7 @@ def __call__(self, x):

# %%
# Visualize some examples
-pl.seed_everything(42)
+L.seed_everything(42)
NUM_IMAGES = 6
imgs = torch.stack([img for idx in range(NUM_IMAGES) for img in unlabeled_data[idx][0]], dim=0)
img_grid = torchvision.utils.make_grid(imgs, nrow=6, normalize=True, pad_value=0.9)
Expand Down Expand Up @@ -275,7 +275,7 @@ def __call__(self, x):


# %%
-class SimCLR(pl.LightningModule):
+class SimCLR(L.LightningModule):
def __init__(self, hidden_dim, lr, temperature, weight_decay, max_epochs=500):
super().__init__()
self.save_hyperparameters()
Expand Down Expand Up @@ -355,15 +355,15 @@ def validation_step(self, batch, batch_idx):

# %%
def train_simclr(batch_size, max_epochs=500, **kwargs):
-trainer = pl.Trainer(
+trainer = L.Trainer(
default_root_dir=os.path.join(CHECKPOINT_PATH, "SimCLR"),
-gpus=1 if str(device) == "cuda:0" else 0,
+accelerator="auto",
+devices=1,
max_epochs=max_epochs,
callbacks=[
ModelCheckpoint(save_weights_only=True, mode="max", monitor="val_acc_top5"),
LearningRateMonitor("epoch"),
],
-progress_bar_refresh_rate=1,
)
trainer.logger._default_hp_metric = None # Optional logging argument that we don't need

@@ -390,7 +390,7 @@ def train_simclr(batch_size, max_epochs=500, **kwargs):
pin_memory=True,
num_workers=NUM_WORKERS,
)
-pl.seed_everything(42)  # To be reproducible
+L.seed_everything(42)  # To be reproducible
model = SimCLR(max_epochs=max_epochs, **kwargs)
trainer.fit(model, train_loader, val_loader)
# Load best checkpoint after training
@@ -442,7 +442,7 @@ def train_simclr(batch_size, max_epochs=500, **kwargs):


# %%
-class LogisticRegression(pl.LightningModule):
+class LogisticRegression(L.LightningModule):
def __init__(self, feature_dim, num_classes, lr, weight_decay, max_epochs=100):
super().__init__()
self.save_hyperparameters()
@@ -538,15 +538,16 @@ def prepare_data_features(model, dataset):

# %%
def train_logreg(batch_size, train_feats_data, test_feats_data, model_suffix, max_epochs=100, **kwargs):
-trainer = pl.Trainer(
+trainer = L.Trainer(
default_root_dir=os.path.join(CHECKPOINT_PATH, "LogisticRegression"),
-gpus=1 if str(device) == "cuda:0" else 0,
+accelerator="auto",
+devices=1,
max_epochs=max_epochs,
callbacks=[
ModelCheckpoint(save_weights_only=True, mode="max", monitor="val_acc"),
LearningRateMonitor("epoch"),
],
-progress_bar_refresh_rate=0,
+enable_progress_bar=False,
check_val_every_n_epoch=10,
)
trainer.logger._default_hp_metric = None
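The tutorial disables the default hp_metric by assigning to the private trainer.logger._default_hp_metric attribute; here is a sketch of an equivalent route through the public API, assuming the default TensorBoard logger is acceptable (CHECKPOINT_PATH and max_epochs are taken from the surrounding code):

import os
import lightning as L
from lightning.pytorch.loggers import TensorBoardLogger

# Passing default_hp_metric=False to the logger avoids reaching into private attributes.
logger = TensorBoardLogger(
    save_dir=os.path.join(CHECKPOINT_PATH, "LogisticRegression"),
    default_hp_metric=False,
)
trainer = L.Trainer(
    logger=logger,
    accelerator="auto",
    devices=1,
    max_epochs=max_epochs,
    enable_progress_bar=False,
    check_val_every_n_epoch=10,
)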
@@ -663,7 +664,7 @@ def get_smaller_dataset(original_dataset, num_imgs_per_label):


# %%
-class ResNet(pl.LightningModule):
+class ResNet(L.LightningModule):
def __init__(self, num_classes, lr, weight_decay, max_epochs=100):
super().__init__()
self.save_hyperparameters()
@@ -729,15 +730,15 @@ def test_step(self, batch, batch_idx):

# %%
def train_resnet(batch_size, max_epochs=100, **kwargs):
-trainer = pl.Trainer(
+trainer = L.Trainer(
default_root_dir=os.path.join(CHECKPOINT_PATH, "ResNet"),
-gpus=1 if str(device) == "cuda:0" else 0,
+accelerator="auto",
+devices=1,
max_epochs=max_epochs,
callbacks=[
ModelCheckpoint(save_weights_only=True, mode="max", monitor="val_acc"),
LearningRateMonitor("epoch"),
],
-progress_bar_refresh_rate=1,
check_val_every_n_epoch=2,
)
trainer.logger._default_hp_metric = None
@@ -761,7 +762,7 @@ def train_resnet(batch_size, max_epochs=100, **kwargs):
print("Found pretrained model at %s, loading..." % pretrained_filename)
model = ResNet.load_from_checkpoint(pretrained_filename)
else:
-pl.seed_everything(42)  # To be reproducible
+L.seed_everything(42)  # To be reproducible
model = ResNet(**kwargs)
trainer.fit(model, train_loader, test_loader)
model = ResNet.load_from_checkpoint(trainer.checkpoint_callback.best_model_path)