1+ # -*- coding: utf-8 -*-
2+ # File : batchnorm.py
3+ # Author : Jiayuan Mao
4+ 5+ # Date : 27/01/2018
6+ #
7+ # This file is part of Synchronized-BatchNorm-PyTorch.
8+ # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9+ # Distributed under MIT License.
10+
11+ import collections
12+
13+ import torch
14+ import torch .nn .functional as F
15+
16+ from torch .nn .modules .batchnorm import _BatchNorm
17+ from torch .nn .parallel ._functions import ReduceAddCoalesced , Broadcast
18+
19+ from .comm import SyncMaster
20+
# Public API: only the three dimension-specific layers are exported; the shared
# base class and the helpers in this module stay private.
__all__ = [
    'SynchronizedBatchNorm1d',
    'SynchronizedBatchNorm2d',
    'SynchronizedBatchNorm3d',
]
22+
23+
def _sum_ft(tensor):
    """Sum ``tensor`` over its first and its last dimension.

    The first dimension is reduced first, then the last dimension of the
    intermediate result, so a 3-D input yields a 1-D result indexed by the
    middle dimension.
    """
    without_first = tensor.sum(dim=0)
    without_first_and_last = without_first.sum(dim=-1)
    return without_first_and_last
27+
28+
def _unsqueeze_ft(tensor):
    """Add a new size-1 dimension at the front and another at the tail of ``tensor``.

    A 1-D tensor of length ``C`` therefore becomes a tensor of shape ``(1, C, 1)``.
    """
    with_front = tensor.unsqueeze(0)
    with_front_and_tail = with_front.unsqueeze(-1)
    return with_front_and_tail
32+
33+
# Statistics one replica contributes to the reduction: the per-channel sum, the
# per-channel sum of squares, and the number of elements behind each of them.
_ChildMessage = collections.namedtuple('_ChildMessage', 'sum ssum sum_size')

# Reply handed back to every replica. The first field is named ``sum`` but is filled
# with the mean by _SynchronizedBatchNorm._data_parallel_master; the second field is
# the inverse standard deviation.
_MasterMessage = collections.namedtuple('_MasterMessage', 'sum inv_std')
36+
37+
class _SynchronizedBatchNorm(_BatchNorm):
    """Batch normalization whose training statistics are pooled over all replicas.

    Outside of data-parallel training the layer simply calls ``F.batch_norm``.
    Once ``__data_parallel_replicate__`` has marked an instance as a replica and
    the module is in training mode, every replica sends its per-channel sum and
    sum of squares through the ``SyncMaster`` machinery; the master callback
    reduces them, derives one mean / inverse standard deviation for the whole
    batch and hands it back to every replica.

    NOTE(review): ``forward`` never calls ``_check_input_dim``, so the dimension
    checks defined by subclasses are not enforced on this code path.

    Args:
        num_features: number of channels ``C`` of the expected input.
        eps: value added to the variance for numerical stability. Default: 1e-5
        momentum: weight of the newest statistics in the running estimates.
            Default: 0.1
        affine: when ``True`` the layer has learnable ``weight`` and ``bias``.
            Default: ``True``
    """

    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True):
        super(_SynchronizedBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum, affine=affine)

        # The reduction callback handed to the master is bound to this instance.
        self._sync_master = SyncMaster(self._data_parallel_master)

        # Replica bookkeeping; filled in by __data_parallel_replicate__.
        self._is_parallel = False
        self._parallel_id = None
        self._slave_pipe = None

    def forward(self, input):
        """Normalize ``input`` of shape ``(B, C, ...)`` and return a tensor of the same shape."""
        # If it is not parallel computation or is in evaluation mode, use PyTorch's implementation.
        if not (self._is_parallel and self.training):
            return F.batch_norm(
                input, self.running_mean, self.running_var, self.weight, self.bias,
                self.training, self.momentum, self.eps)

        # Resize the input to (B, C, -1).
        input_shape = input.size()
        input = input.view(input.size(0), self.num_features, -1)

        # Compute the per-channel sum and square-sum over this replica's elements.
        sum_size = input.size(0) * input.size(2)
        input_sum = _sum_ft(input)
        input_ssum = _sum_ft(input ** 2)

        # Reduce-and-broadcast the statistics. Replica 0 drives the master side, every
        # other replica goes through the pipe it registered with the master.
        message = _ChildMessage(input_sum, input_ssum, sum_size)
        if self._parallel_id == 0:
            mean, inv_std = self._sync_master.run_master(message)
        else:
            mean, inv_std = self._slave_pipe.run_slave(message)

        # Compute the output.
        if self.affine:
            # MJY:: Fuse the multiplication for speed.
            output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std * self.weight) + _unsqueeze_ft(self.bias)
        else:
            output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std)

        # Reshape it back to the caller's layout.
        return output.view(input_shape)

    def __data_parallel_replicate__(self, ctx, copy_id):
        """Mark this instance as replica ``copy_id`` and wire it to the shared master.

        ``ctx`` is an object shared between the replicas of one module; replica 0
        stores its ``SyncMaster`` on it and every other replica registers with that
        master. Presumably invoked by a data-parallel replication hook with replica 0
        first (the other branch reads ``ctx.sync_master``) -- the caller is not
        visible in this file.
        """
        self._is_parallel = True
        self._parallel_id = copy_id

        # parallel_id == 0 means master device.
        if self._parallel_id == 0:
            ctx.sync_master = self._sync_master
        else:
            self._slave_pipe = ctx.sync_master.register_slave(copy_id)

    def _data_parallel_master(self, intermediates):
        """Reduce the sum and square-sum, compute the statistics, and broadcast it.

        Args:
            intermediates: iterable of ``(identifier, _ChildMessage)`` pairs, one per
                replica.

        Returns:
            A list of ``(identifier, _MasterMessage)`` pairs, one per replica, where
            the message holds the mean and inverse standard deviation broadcast to
            that replica's device.
        """

        # Always using same "device order" makes the ReduceAdd operation faster.
        # Thanks to:: Tete Xiao (http://tetexiao.com/)
        intermediates = sorted(intermediates, key=lambda i: i[1].sum.get_device())

        # Flat list [sum_0, ssum_0, sum_1, ssum_1, ...] as expected by ReduceAddCoalesced.
        to_reduce = [i[1][:2] for i in intermediates]
        to_reduce = [j for i in to_reduce for j in i]  # flatten
        target_gpus = [i[1].sum.get_device() for i in intermediates]

        sum_size = sum([i[1].sum_size for i in intermediates])
        # Reduce onto the first device; "2" is the number of tensors per device.
        sum_, ssum = ReduceAddCoalesced.apply(target_gpus[0], 2, *to_reduce)
        mean, inv_std = self._compute_mean_std(sum_, ssum, sum_size)

        # The result is flat: (mean, inv_std) for the first device, then the next, ...
        broadcasted = Broadcast.apply(target_gpus, mean, inv_std)

        outputs = []
        for i, rec in enumerate(intermediates):
            outputs.append((rec[0], _MasterMessage(*broadcasted[i * 2:i * 2 + 2])))

        return outputs

    def _compute_mean_std(self, sum_, ssum, size):
        """Compute the mean and inverse standard-deviation from the sum and square-sum.

        This method also updates ``running_mean`` and ``running_var`` (the latter with
        the unbiased variance).

        Args:
            sum_: per-channel sum of all elements.
            ssum: per-channel sum of all squared elements.
            size: number of elements that were summed per channel; must be > 1.

        Returns:
            ``(mean, inv_std)`` with ``inv_std = 1 / sqrt(biased_var + eps)``.
        """
        assert size > 1, 'BatchNorm computes unbiased standard-deviation, which requires size > 1.'
        mean = sum_ / size
        sumvar = ssum - sum_ * mean
        unbias_var = sumvar / (size - 1)
        bias_var = sumvar / size

        self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean.data
        self.running_var = (1 - self.momentum) * self.running_var + self.momentum * unbias_var.data

        # eps is *added* to the variance, matching y = (x - mean) / sqrt(Var + eps) and
        # the F.batch_norm path used by forward() outside of parallel training. Using
        # eps as a lower bound instead would normalize differently whenever Var > eps.
        # The clamp at zero only guards against a slightly negative variance caused by
        # cancellation in ``ssum - sum_ * mean``.
        return mean, (bias_var.clamp(min=0) + self.eps) ** -0.5
126+
127+
class SynchronizedBatchNorm1d(_SynchronizedBatchNorm):
    r"""Applies Synchronized Batch Normalization over a 2d or 3d input that is seen
    as a mini-batch.

    .. math::

        y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta

    Unlike the built-in PyTorch BatchNorm1d, the mean and standard-deviation are
    reduced across all devices during training.

    When a network is wrapped in `nn.DataParallel`, PyTorch's own implementation
    normalizes the tensor on each device with the statistics of that device only.
    That is fast and easy to implement, but the statistics might be inaccurate.
    This synchronized version instead computes the statistics over all training
    samples distributed on the devices.

    In the one-GPU or CPU-only case this module falls back to the built-in
    PyTorch implementation.

    The mean and standard-deviation are calculated per-dimension over the
    mini-batches; gamma and beta are learnable parameter vectors of size C
    (where C is the input size).

    During training, the layer keeps a running estimate of the computed mean and
    variance, updated with a default momentum of 0.1. During evaluation, these
    running estimates are used for normalization.

    Because the statistics are computed per channel `C` on `(N, L)` slices, this
    is commonly called Temporal BatchNorm.

    Args:
        num_features: num_features from an expected input of size
            `batch_size x num_features [x width]`
        eps: a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum: the value used for the running_mean and running_var
            computation. Default: 0.1
        affine: a boolean value that when set to ``True``, gives the layer
            learnable affine parameters. Default: ``True``

    Shape:
        - Input: :math:`(N, C)` or :math:`(N, C, L)`
        - Output: :math:`(N, C)` or :math:`(N, C, L)` (same shape as input)

    Examples:
        >>> # With Learnable Parameters
        >>> m = SynchronizedBatchNorm1d(100)
        >>> # Without Learnable Parameters
        >>> m = SynchronizedBatchNorm1d(100, affine=False)
        >>> input = torch.autograd.Variable(torch.randn(20, 100))
        >>> output = m(input)
    """

    def _check_input_dim(self, input):
        """Raise ``ValueError`` unless ``input`` is 2D or 3D, then defer to the parent check."""
        ndim = input.dim()
        if ndim not in (2, 3):
            raise ValueError('expected 2D or 3D input (got {}D input)'.format(ndim))
        super(SynchronizedBatchNorm1d, self)._check_input_dim(input)
178+
179+
class SynchronizedBatchNorm2d(_SynchronizedBatchNorm):
    r"""Applies Synchronized Batch Normalization over a 4d input that is seen as a
    mini-batch of 3d inputs.

    .. math::

        y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta

    Unlike the built-in PyTorch BatchNorm2d, the mean and standard-deviation are
    reduced across all devices during training.

    When a network is wrapped in `nn.DataParallel`, PyTorch's own implementation
    normalizes the tensor on each device with the statistics of that device only.
    That is fast and easy to implement, but the statistics might be inaccurate.
    This synchronized version instead computes the statistics over all training
    samples distributed on the devices.

    In the one-GPU or CPU-only case this module falls back to the built-in
    PyTorch implementation.

    The mean and standard-deviation are calculated per-dimension over the
    mini-batches; gamma and beta are learnable parameter vectors of size C
    (where C is the input size).

    During training, the layer keeps a running estimate of the computed mean and
    variance, updated with a default momentum of 0.1. During evaluation, these
    running estimates are used for normalization.

    Because the statistics are computed per channel `C` on `(N, H, W)` slices,
    this is commonly called Spatial BatchNorm.

    Args:
        num_features: num_features from an expected input of
            size batch_size x num_features x height x width
        eps: a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum: the value used for the running_mean and running_var
            computation. Default: 0.1
        affine: a boolean value that when set to ``True``, gives the layer
            learnable affine parameters. Default: ``True``

    Shape:
        - Input: :math:`(N, C, H, W)`
        - Output: :math:`(N, C, H, W)` (same shape as input)

    Examples:
        >>> # With Learnable Parameters
        >>> m = SynchronizedBatchNorm2d(100)
        >>> # Without Learnable Parameters
        >>> m = SynchronizedBatchNorm2d(100, affine=False)
        >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45))
        >>> output = m(input)
    """

    def _check_input_dim(self, input):
        """Raise ``ValueError`` unless ``input`` is 4D, then defer to the parent check."""
        ndim = input.dim()
        if ndim != 4:
            raise ValueError('expected 4D input (got {}D input)'.format(ndim))
        super(SynchronizedBatchNorm2d, self)._check_input_dim(input)
230+
231+
class SynchronizedBatchNorm3d(_SynchronizedBatchNorm):
    r"""Applies Synchronized Batch Normalization over a 5d input that is seen as a
    mini-batch of 4d inputs.

    .. math::

        y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta

    Unlike the built-in PyTorch BatchNorm3d, the mean and standard-deviation are
    reduced across all devices during training.

    When a network is wrapped in `nn.DataParallel`, PyTorch's own implementation
    normalizes the tensor on each device with the statistics of that device only.
    That is fast and easy to implement, but the statistics might be inaccurate.
    This synchronized version instead computes the statistics over all training
    samples distributed on the devices.

    In the one-GPU or CPU-only case this module falls back to the built-in
    PyTorch implementation.

    The mean and standard-deviation are calculated per-dimension over the
    mini-batches; gamma and beta are learnable parameter vectors of size C
    (where C is the input size).

    During training, the layer keeps a running estimate of the computed mean and
    variance, updated with a default momentum of 0.1. During evaluation, these
    running estimates are used for normalization.

    Because the statistics are computed per channel `C` on `(N, D, H, W)` slices,
    this is commonly called Volumetric BatchNorm or Spatio-temporal BatchNorm.

    Args:
        num_features: num_features from an expected input of
            size batch_size x num_features x depth x height x width
        eps: a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum: the value used for the running_mean and running_var
            computation. Default: 0.1
        affine: a boolean value that when set to ``True``, gives the layer
            learnable affine parameters. Default: ``True``

    Shape:
        - Input: :math:`(N, C, D, H, W)`
        - Output: :math:`(N, C, D, H, W)` (same shape as input)

    Examples:
        >>> # With Learnable Parameters
        >>> m = SynchronizedBatchNorm3d(100)
        >>> # Without Learnable Parameters
        >>> m = SynchronizedBatchNorm3d(100, affine=False)
        >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45, 10))
        >>> output = m(input)
    """

    def _check_input_dim(self, input):
        """Raise ``ValueError`` unless ``input`` is 5D, then defer to the parent check."""
        ndim = input.dim()
        if ndim != 5:
            raise ValueError('expected 5D input (got {}D input)'.format(ndim))
        super(SynchronizedBatchNorm3d, self)._check_input_dim(input)
0 commit comments