
Commit 4db39b5

Merge pull request labmlai#168 from jakehsiao/patch-3
Add activation for time embedding and dropout for Residual block in DDPM UNet
2 parents bdaa667 + 8053f3f commit 4db39b5

1 file changed (+6 −3 lines)

labml_nn/diffusion/ddpm/unet.py

Lines changed: 6 additions & 3 deletions
@@ -26,6 +26,7 @@
 
 import torch
 from torch import nn
+import torch.nn.functional as F
 
 from labml_helpers.module import Module
 
@@ -91,12 +92,13 @@ class ResidualBlock(Module):
     Each resolution is processed with two residual blocks.
     """
 
-    def __init__(self, in_channels: int, out_channels: int, time_channels: int, n_groups: int = 32):
+    def __init__(self, in_channels: int, out_channels: int, time_channels: int, n_groups: int = 32, dropout_rate: float = 0.1):
         """
         * `in_channels` is the number of input channels
         * `out_channels` is the number of input channels
         * `time_channels` is the number channels in the time step ($t$) embeddings
         * `n_groups` is the number of groups for [group normalization](../../normalization/group_norm/index.html)
+        * `dropout_rate` is the dropout rate
         """
         super().__init__()
         # Group normalization and the first convolution layer
@@ -118,6 +120,7 @@ def __init__(self, in_channels: int, out_channels: int, time_channels: int, n_gr
 
         # Linear layer for time embeddings
         self.time_emb = nn.Linear(time_channels, out_channels)
+        self.time_act = Swish()
 
     def forward(self, x: torch.Tensor, t: torch.Tensor):
         """
@@ -127,9 +130,9 @@ def forward(self, x: torch.Tensor, t: torch.Tensor):
         # First convolution layer
         h = self.conv1(self.act1(self.norm1(x)))
         # Add time embeddings
-        h += self.time_emb(t)[:, :, None, None]
+        h += self.time_emb(self.time_act(t))[:, :, None, None]
         # Second convolution layer
-        h = self.conv2(self.act2(self.norm2(h)))
+        h = self.conv2(F.dropout(self.act2(self.norm2(h)), self.dropout_rate))
 
         # Add the shortcut connection and return
         return h + self.shortcut(x)
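
For readers who want to see the change in context, here is a minimal, self-contained sketch of what the residual block looks like after this commit. It is an illustration rather than the repository's exact code: `Swish` is defined inline (in `labml_nn` it is defined elsewhere in `unet.py`), the normalization/convolution/shortcut layers are filled in the way the surrounding UNet defines them, and dropout is applied through an `nn.Dropout` module instead of the committed bare `F.dropout(...)` call so that it is automatically disabled in `eval()` mode (the diff also does not show `self.dropout_rate` being assigned in `__init__`, so the sketch stores the rate in the module).

```python
import torch
from torch import nn


class Swish(nn.Module):
    """Swish activation: x * sigmoid(x), as used in the labml_nn DDPM UNet."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * torch.sigmoid(x)


class ResidualBlock(nn.Module):
    """Residual block with time-embedding activation and dropout (sketch)."""

    def __init__(self, in_channels: int, out_channels: int, time_channels: int,
                 n_groups: int = 32, dropout_rate: float = 0.1):
        super().__init__()
        # First group normalization, activation, and 3x3 convolution
        self.norm1 = nn.GroupNorm(n_groups, in_channels)
        self.act1 = Swish()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        # Second group normalization, activation, and 3x3 convolution
        self.norm2 = nn.GroupNorm(n_groups, out_channels)
        self.act2 = Swish()
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        # 1x1 convolution on the shortcut when channel counts differ
        if in_channels != out_channels:
            self.shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1)
        else:
            self.shortcut = nn.Identity()
        # Linear projection of the time embedding, preceded by an activation
        # (the activation is one of the two additions in this commit)
        self.time_emb = nn.Linear(time_channels, out_channels)
        self.time_act = Swish()
        # Dropout before the second convolution (the other addition);
        # nn.Dropout respects model.train()/model.eval(), unlike a bare F.dropout call
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        # First convolution
        h = self.conv1(self.act1(self.norm1(x)))
        # Add the (activated) time embedding, broadcast over the spatial dimensions
        h += self.time_emb(self.time_act(t))[:, :, None, None]
        # Dropout, then the second convolution
        h = self.conv2(self.dropout(self.act2(self.norm2(h))))
        # Add the shortcut connection and return
        return h + self.shortcut(x)


# Quick shape check
block = ResidualBlock(in_channels=64, out_channels=128, time_channels=256)
x = torch.randn(2, 64, 32, 32)
t = torch.randn(2, 256)
print(block(x, t).shape)  # torch.Size([2, 128, 32, 32])
```

The two changes the commit introduces are visible here: the time embedding is passed through `self.time_act` (Swish) before the linear projection, and dropout is inserted between the second normalization/activation and the second convolution.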

0 commit comments
