2626
2727import torch
2828from torch import nn
29+ import torch .nn .functional as F
2930
3031from labml_helpers .module import Module
3132
@@ -91,12 +92,13 @@ class ResidualBlock(Module):
9192 Each resolution is processed with two residual blocks.
9293 """
9394
94- def __init__ (self , in_channels : int , out_channels : int , time_channels : int , n_groups : int = 32 ):
95+ def __init__ (self , in_channels : int , out_channels : int , time_channels : int , n_groups : int = 32 , dropout_rate : float = 0.1 ):
9596 """
9697 * `in_channels` is the number of input channels
9798 * `out_channels` is the number of output channels
9899 * `time_channels` is the number of channels in the time step ($t$) embeddings
99100 * `n_groups` is the number of groups for [group normalization](../../normalization/group_norm/index.html)
101+ * `dropout_rate` is the dropout rate
100102 """
101103 super ().__init__ ()
102104 # Group normalization and the first convolution layer
@@ -118,6 +120,7 @@ def __init__(self, in_channels: int, out_channels: int, time_channels: int, n_gr
118120
119121 # Linear layer for time embeddings
120122 self .time_emb = nn .Linear (time_channels , out_channels )
123+ self .time_act = Swish ()
121124
122125 def forward (self , x : torch .Tensor , t : torch .Tensor ):
123126 """
@@ -127,9 +130,9 @@ def forward(self, x: torch.Tensor, t: torch.Tensor):
127130 # First convolution layer
128131 h = self .conv1 (self .act1 (self .norm1 (x )))
129132 # Add time embeddings
130- h += self .time_emb (t )[:, :, None , None ]
133+ h += self .time_emb (self . time_act ( t ) )[:, :, None , None ]
131134 # Second convolution layer
132- h = self .conv2 (self .act2 (self .norm2 (h )))
135+ h = self .conv2 (F . dropout ( self .act2 (self .norm2 (h )), self . dropout_rate ))
133136
134137 # Add the shortcut connection and return
135138 return h + self .shortcut (x )
0 commit comments