44"""BatchNorm (BN) utility functions and custom batch-size BN implementations"""
55
66from functools import partial
7+
78import torch
8- import torch .distributed as dist
99import torch .nn as nn
10- from torch .autograd .function import Function
11-
12- import slowfast .utils .distributed as du
10+ from pytorchvideo .layers .batch_norm import NaiveSyncBatchNorm3d , NaiveSyncBatchNorm1d # noqa
1311
1412
1513def get_norm (cfg ):
@@ -26,7 +24,8 @@ def get_norm(cfg):
         return partial(SubBatchNorm3d, num_splits=cfg.BN.NUM_SPLITS)
     elif cfg.BN.NORM_TYPE == "sync_batchnorm":
         return partial(
-            NaiveSyncBatchNorm3d, num_sync_devices=cfg.BN.NUM_SYNC_DEVICES
+            NaiveSyncBatchNorm3d, num_sync_devices=cfg.BN.NUM_SYNC_DEVICES,
+            global_sync=cfg.BN.GLOBAL_SYNC
         )
     else:
         raise NotImplementedError(
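The dispatch above pre-binds config values with functools.partial and lets the model-building code supply the channel count later. A minimal self-contained sketch of the same pattern (illustrative only, not from this commit; plain nn.BatchNorm3d stands in for the synced variant, which needs a distributed run, and the num_features call plus eps/momentum values are assumptions about how the builder is consumed):

    from functools import partial
    import torch
    import torch.nn as nn

    # Stand-in for get_norm's "sync_batchnorm" branch: config-derived kwargs are
    # pre-bound, and the channel count is supplied later by model-building code.
    norm_builder = partial(nn.BatchNorm3d, eps=1e-5, momentum=0.1)
    bn = norm_builder(num_features=64)            # like norm_module(num_features=dim_out)
    out = bn(torch.randn(2, 64, 4, 7, 7))         # (N, C, T, H, W) video features
    print(out.shape)                              # torch.Size([2, 64, 4, 7, 7])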
@@ -107,159 +106,3 @@ def forward(self, x):
         x = x * self.weight.view((-1, 1, 1, 1))
         x = x + self.bias.view((-1, 1, 1, 1))
         return x
-
-
-class GroupGather(Function):
-    """
-    GroupGather performs all gather on each of the local process/ GPU groups.
-    """
-
-    @staticmethod
-    def forward(ctx, input, num_sync_devices, num_groups):
-        """
-        Perform forwarding, gathering the stats across different process/ GPU
-        group.
-        """
-        ctx.num_sync_devices = num_sync_devices
-        ctx.num_groups = num_groups
-
-        input_list = [
-            torch.zeros_like(input) for k in range(du.get_local_size())
-        ]
-        dist.all_gather(
-            input_list, input, async_op=False, group=du._LOCAL_PROCESS_GROUP
-        )
-
-        inputs = torch.stack(input_list, dim=0)
-        if num_groups > 1:
-            rank = du.get_local_rank()
-            group_idx = rank // num_sync_devices
-            inputs = inputs[
-                group_idx
-                * num_sync_devices : (group_idx + 1)
-                * num_sync_devices
-            ]
-        inputs = torch.sum(inputs, dim=0)
-        return inputs
-
-    @staticmethod
-    def backward(ctx, grad_output):
-        """
-        Perform backwarding, gathering the gradients across different process/ GPU
-        group.
-        """
-        grad_output_list = [
-            torch.zeros_like(grad_output) for k in range(du.get_local_size())
-        ]
-        dist.all_gather(
-            grad_output_list,
-            grad_output,
-            async_op=False,
-            group=du._LOCAL_PROCESS_GROUP,
-        )
-
-        grads = torch.stack(grad_output_list, dim=0)
-        if ctx.num_groups > 1:
-            rank = du.get_local_rank()
-            group_idx = rank // ctx.num_sync_devices
-            grads = grads[
-                group_idx
-                * ctx.num_sync_devices : (group_idx + 1)
-                * ctx.num_sync_devices
-            ]
-        grads = torch.sum(grads, dim=0)
-        return grads, None, None
-
-
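For readers skimming the removal: GroupGather gathers the stats vector from every GPU on the machine, then keeps only the slice belonging to the caller's group of num_sync_devices ranks before summing. A self-contained sketch of that slicing (illustrative only; the all_gather result is faked with constant tensors, so no distributed setup is needed):

    import torch

    # Pretend 8 local GPUs each contributed a stats vector via all_gather;
    # rank r's vector is filled with the value r so the group sums are easy to check.
    local_size, num_sync_devices = 8, 2
    gathered = torch.stack(
        [torch.full((3,), float(r)) for r in range(local_size)], dim=0
    )

    rank = 5                               # this process's local rank
    group_idx = rank // num_sync_devices   # group 2, i.e. ranks 4 and 5
    group = gathered[group_idx * num_sync_devices : (group_idx + 1) * num_sync_devices]
    print(torch.sum(group, dim=0))         # tensor([9., 9., 9.])  (4 + 5)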
-class NaiveSyncBatchNorm3d(nn.BatchNorm3d):
-    def __init__(self, num_sync_devices, **args):
-        """
-        Naive version of Synchronized 3D BatchNorm.
-        Args:
-            num_sync_devices (int): number of device to sync.
-            args (list): other arguments.
-        """
-        self.num_sync_devices = num_sync_devices
-        if self.num_sync_devices > 0:
-            assert du.get_local_size() % self.num_sync_devices == 0, (
-                du.get_local_size(),
-                self.num_sync_devices,
-            )
-            self.num_groups = du.get_local_size() // self.num_sync_devices
-        else:
-            self.num_sync_devices = du.get_local_size()
-            self.num_groups = 1
-        super(NaiveSyncBatchNorm3d, self).__init__(**args)
-
-    def forward(self, input):
-        if du.get_local_size() == 1 or not self.training:
-            return super().forward(input)
-
-        assert input.shape[0] > 0, "SyncBatchNorm does not support empty inputs"
-        C = input.shape[1]
-        mean = torch.mean(input, dim=[0, 2, 3, 4])
-        meansqr = torch.mean(input * input, dim=[0, 2, 3, 4])
-
-        vec = torch.cat([mean, meansqr], dim=0)
-        vec = GroupGather.apply(vec, self.num_sync_devices, self.num_groups) * (
-            1.0 / self.num_sync_devices
-        )
-
-        mean, meansqr = torch.split(vec, C)
-        var = meansqr - mean * mean
-        self.running_mean += self.momentum * (mean.detach() - self.running_mean)
-        self.running_var += self.momentum * (var.detach() - self.running_var)
-
-        invstd = torch.rsqrt(var + self.eps)
-        scale = self.weight * invstd
-        bias = self.bias - mean * scale
-        scale = scale.reshape(1, -1, 1, 1, 1)
-        bias = bias.reshape(1, -1, 1, 1, 1)
-        return input * scale + bias
-
-
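The removed forward relies on the identity that averaging per-device means and per-device mean squares reproduces the whole-batch statistics, provided every device holds the same per-GPU batch size (the "naive" assumption). A single-process sketch of that identity (illustrative only, not from this commit):

    import torch

    x1 = torch.randn(2, 3, 4, 4, 4)   # stand-in for device 0's (N, C, T, H, W) batch
    x2 = torch.randn(2, 3, 4, 4, 4)   # stand-in for device 1's batch (same size)

    def stats(x):
        # per-channel mean and mean of squares, as in the removed forward
        return torch.mean(x, dim=[0, 2, 3, 4]), torch.mean(x * x, dim=[0, 2, 3, 4])

    (m1, s1), (m2, s2) = stats(x1), stats(x2)
    mean = (m1 + m2) / 2               # GroupGather sum scaled by 1 / num_sync_devices
    meansqr = (s1 + s2) / 2
    var = meansqr - mean * mean        # biased variance: Var[x] = E[x^2] - E[x]^2

    full = torch.cat([x1, x2], dim=0)  # what a single combined batch would have been
    assert torch.allclose(mean, full.mean(dim=[0, 2, 3, 4]), atol=1e-5)
    assert torch.allclose(var, full.var(dim=[0, 2, 3, 4], unbiased=False), atol=1e-5)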
-class NaiveSyncBatchNorm1d(nn.BatchNorm1d):
-    def __init__(self, num_sync_devices, **args):
-        """
-        Naive version of Synchronized 1D BatchNorm.
-        Args:
-            num_sync_devices (int): number of device to sync.
-            args (list): other arguments.
-        """
-        self.num_sync_devices = num_sync_devices
-        if self.num_sync_devices > 0:
-            assert du.get_local_size() % self.num_sync_devices == 0, (
-                du.get_local_size(),
-                self.num_sync_devices,
-            )
-            self.num_groups = du.get_local_size() // self.num_sync_devices
-        else:
-            self.num_sync_devices = du.get_local_size()
-            self.num_groups = 1
-        super(NaiveSyncBatchNorm1d, self).__init__(**args)
-
-    def forward(self, input):
-        if du.get_local_size() == 1 or not self.training:
-            return super().forward(input)
-
-        assert input.shape[0] > 0, "SyncBatchNorm does not support empty inputs"
-        C = input.shape[1]
-        mean = torch.mean(input, dim=[0])
-        meansqr = torch.mean(input * input, dim=[0])
-
-        vec = torch.cat([mean, meansqr], dim=0)
-        vec = GroupGather.apply(vec, self.num_sync_devices, self.num_groups) * (
-            1.0 / self.num_sync_devices
-        )
-
-        mean, meansqr = torch.split(vec, C)
-        var = meansqr - mean * mean
-        self.running_mean += self.momentum * (mean.detach() - self.running_mean)
-        self.running_var += self.momentum * (var.detach() - self.running_var)
-
-        invstd = torch.rsqrt(var + self.eps)
-        scale = self.weight * invstd
-        bias = self.bias - mean * scale
-        scale = scale.reshape(1, -1)
-        bias = bias.reshape(1, -1)
-        return input * scale + bias
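Both removed classes end by folding the normalization into a per-channel scale and bias. A small single-device sanity sketch (illustrative only, assuming training mode so standard BatchNorm1d also normalizes with batch statistics and biased variance, matching the folded form above):

    import torch
    import torch.nn as nn

    x = torch.randn(8, 16)
    bn = nn.BatchNorm1d(16).train()

    mean = x.mean(dim=0)
    var = (x * x).mean(dim=0) - mean * mean      # E[x^2] - E[x]^2, as above
    invstd = torch.rsqrt(var + bn.eps)
    scale = bn.weight * invstd
    bias = bn.bias - mean * scale
    folded = x * scale.reshape(1, -1) + bias.reshape(1, -1)

    assert torch.allclose(folded, bn(x), atol=1e-5)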