UvA DL Tutorial 9: Updating NF architecture
phlippe committed Jan 4, 2023
commit 0ed14372da7c9475d33168183e4e69ddd164062f
30 changes: 17 additions & 13 deletions course_UvA-DL/09-normalizing-flows/NF_image_modeling.py
@@ -490,7 +490,7 @@ def visualize_dequantization(quants, prior=None):
     # Prior over discrete values. If not given, a uniform is assumed
     if prior is None:
         prior = np.ones(quants, dtype=np.float32) / quants
-    prior = prior / prior.sum() * quants  # In the following, we assume 1 for each value means uniform distribution
+    prior = prior / prior.sum()  # Ensure proper categorical distribution
 
     inp = torch.arange(-4, 4, 0.01).view(-1, 1, 1, 1)  # Possible continuous values we want to consider
     ldj = torch.zeros(inp.shape[0])
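
Editor's note on the changed line above: dividing by the sum turns an arbitrary non-negative prior into a proper categorical distribution (entries sum to 1), whereas the old version rescaled it so the entries averaged to 1. A minimal illustration, not part of the committed file and with made-up values:

    import numpy as np

    prior = np.array([2.0, 1.0, 1.0], dtype=np.float32)  # unnormalized weights over 3 discrete values
    prior = prior / prior.sum()                           # -> [0.5, 0.25, 0.25], sums to 1
    # Old behaviour: prior / prior.sum() * quants would give [1.5, 0.75, 0.75], i.e. an average of 1 per value
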
@@ -798,21 +798,24 @@ def forward(self, x):
 
 
 class LayerNormChannels(nn.Module):
-    def __init__(self, c_in):
-        """This module applies layer norm across channels in an image.
-
-        Has been shown to work well with ResNet connections.
-        Args:
-            c_in: Number of channels of the input
+    def __init__(self, c_in, eps=1e-5):
+        """
+        This module applies layer norm across channels in an image.
+        Inputs:
+            c_in - Number of channels of the input
+            eps - Small constant to stabilize std
         """
         super().__init__()
-        self.layer_norm = nn.LayerNorm(c_in)
+        self.gamma = nn.Parameter(torch.ones(1, c_in, 1, 1))
+        self.beta = nn.Parameter(torch.zeros(1, c_in, 1, 1))
+        self.eps = eps
 
     def forward(self, x):
-        x = x.permute(0, 2, 3, 1)
-        x = self.layer_norm(x)
-        x = x.permute(0, 3, 1, 2)
-        return x
+        mean = x.mean(dim=1, keepdim=True)
+        var = x.var(dim=1, unbiased=False, keepdim=True)
+        y = (x - mean) / torch.sqrt(var + self.eps)
+        y = y * self.gamma + self.beta
+        return y
 
 
 class GatedConv(nn.Module):
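
Editor's note on the rewritten LayerNormChannels: it normalizes each pixel across its channels, so at default initialization (gamma=1, beta=0) it should match the previous permute + nn.LayerNorm + permute construction up to numerical precision. A sketch of a sanity check, assuming torch and torch.nn are imported as in the rest of the tutorial file and LayerNormChannels is the module defined above:

    import torch
    import torch.nn as nn

    x = torch.randn(4, 16, 8, 8)        # batch of 4 images, 16 channels, 8x8 pixels
    y = LayerNormChannels(c_in=16)(x)
    print(y.mean(dim=1).abs().max())    # close to 0: per-pixel mean over channels is removed
    # Reference: the old implementation via nn.LayerNorm on a channels-last layout
    ref = nn.LayerNorm(16)(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
    print(torch.allclose(y, ref, atol=1e-5))  # expected True at initialization
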
@@ -825,7 +828,8 @@ def __init__(self, c_in, c_hidden):
         """
         super().__init__()
         self.net = nn.Sequential(
-            nn.Conv2d(c_in, c_hidden, kernel_size=3, padding=1),
+            ConcatELU(),
+            nn.Conv2d(2 * c_in, c_hidden, kernel_size=3, padding=1),
             ConcatELU(),
             nn.Conv2d(2 * c_hidden, 2 * c_in, kernel_size=1),
         )
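
Editor's note on why the first convolution now expects 2 * c_in channels: ConcatELU doubles the channel count of its input (the tutorial defines it earlier in the file; the stand-in below, concatenating F.elu(x) and F.elu(-x) along the channel axis, is an assumption for illustration). A small shape check:

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class ConcatELU(nn.Module):
        def forward(self, x):
            return torch.cat([F.elu(x), F.elu(-x)], dim=1)  # doubles the channels

    c_in, c_hidden = 3, 32
    net = nn.Sequential(
        ConcatELU(),                                              # 3 -> 6 channels
        nn.Conv2d(2 * c_in, c_hidden, kernel_size=3, padding=1),  # 6 -> 32 channels
        ConcatELU(),                                              # 32 -> 64 channels
        nn.Conv2d(2 * c_hidden, 2 * c_in, kernel_size=1),         # 64 -> 6 channels (value and gate)
    )
    print(net(torch.randn(1, c_in, 8, 8)).shape)  # torch.Size([1, 6, 8, 8])
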