Skip to content

Commit 0638089

Browse files
authored
Merge pull request #18 from JudasDie/zpzhang
Fix BN error
2 parents 5146b32 + bb6c276 commit 0638089

File tree

9 files changed

+605
-6
lines changed

9 files changed

+605
-6
lines changed

main.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
SCALE = 0.125
2020

21+
os.environ['CUDA_VISIBLE_DEVICES'] = '4,5,6,7'
2122
def parse_options():
2223
parser = argparse.ArgumentParser()
2324
parser.add_argument('--frame_num', '-n', type=int, default=10,
@@ -55,8 +56,9 @@ def parse_options():
5556

5657
def main(args):
5758

58-
model = modeling.VOSNet(model=args.model).cuda()
59-
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
59+
# model = modeling.VOSNet(model=args.model).cuda()
60+
# model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
61+
model = modeling.VOSNet(model=args.model, sync_bn=True).cuda()
6062
model = DistributedDataParallel(model, device_ids=[args.local_rank], broadcast_buffers=False)
6163

6264
criterion = CrossEntropy(temperature=args.temperature).cuda()

modeling/network.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,37 @@
11
import torch.nn as nn
22

33
from modeling.backbone.resnet import resnet18, resnet50, resnet101
4+
from modeling.sync_batchnorm.batchnorm import SynchronizedBatchNorm2d # additional codes
45

56

67
class VOSNet(nn.Module):
78

89
def __init__(self,
9-
model='resnet18'):
10+
model='resnet18', sync_bn=False):
1011

1112
super(VOSNet, self).__init__()
1213
self.model = model
14+
15+
# additional codes
16+
if sync_bn:
17+
print("Using SynchronizedBatchNorm2d.")
18+
BatchNorm = SynchronizedBatchNorm2d
19+
else:
20+
BatchNorm = nn.BatchNorm2d
1321

1422
if model == 'resnet18':
15-
resnet = resnet18(pretrained=True)
23+
# resnet = resnet18(pretrained=True)
24+
resnet = resnet18(pretrained=True, BatchNorm=BatchNorm) # additional codes
1625
self.backbone = nn.Sequential(*list(resnet.children())[0:8])
1726
elif model == 'resnet50':
18-
resnet = resnet50(pretrained=True)
27+
# resnet = resnet50(pretrained=True)
28+
resnet = resnet50(pretrained=True, BatchNorm=BatchNorm) # additional codes
1929
self.backbone = nn.Sequential(*list(resnet.children())[0:8])
2030
self.adjust_dim = nn.Conv2d(1024, 256, kernel_size=1, stride=1, padding=0, bias=False)
2131
self.bn256 = nn.BatchNorm2d(256)
2232
elif model == 'resnet101':
23-
resnet = resnet101(pretrained=True)
33+
# resnet = resnet101(pretrained=True)
34+
resnet = resnet101(pretrained=True, BatchNorm=BatchNorm) # additional codes
2435
self.backbone = nn.Sequential(*list(resnet.children())[0:8])
2536
self.adjust_dim = nn.Conv2d(1024, 256, kernel_size=1, stride=1, padding=0, bias=False)
2637
self.bn256 = nn.BatchNorm2d(256)
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
# -*- coding: utf-8 -*-
2+
# File : __init__.py
3+
# Author : Jiayuan Mao
4+
5+
# Date : 27/01/2018
6+
#
7+
# This file is part of Synchronized-BatchNorm-PyTorch.
8+
# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9+
# Distributed under MIT License.
10+
11+
from .batchnorm import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, SynchronizedBatchNorm3d
12+
from .replicate import DataParallelWithCallback, patch_replication_callback
Lines changed: 282 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,282 @@
1+
# -*- coding: utf-8 -*-
2+
# File : batchnorm.py
3+
# Author : Jiayuan Mao
4+
5+
# Date : 27/01/2018
6+
#
7+
# This file is part of Synchronized-BatchNorm-PyTorch.
8+
# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9+
# Distributed under MIT License.
10+
11+
import collections
12+
13+
import torch
14+
import torch.nn.functional as F
15+
16+
from torch.nn.modules.batchnorm import _BatchNorm
17+
from torch.nn.parallel._functions import ReduceAddCoalesced, Broadcast
18+
19+
from .comm import SyncMaster
20+
21+
__all__ = ['SynchronizedBatchNorm1d', 'SynchronizedBatchNorm2d', 'SynchronizedBatchNorm3d']
22+
23+
24+
def _sum_ft(tensor):
    """Sum a tensor over its first and its last dimension.

    Args:
        tensor: tensor to reduce.

    Returns:
        The tensor with the first dimension summed out, then the last
        dimension of that intermediate result summed out.
    """
    summed_over_first = tensor.sum(dim=0)
    return summed_over_first.sum(dim=-1)
27+
28+
29+
def _unsqueeze_ft(tensor):
    """Add a new dimension of size one at the front and at the tail.

    Args:
        tensor: tensor to expand.

    Returns:
        A tensor with one extra leading and one extra trailing dimension.
    """
    with_front = tensor.unsqueeze(0)
    return with_front.unsqueeze(-1)
32+
33+
34+
# Message a replica hands to the master: the per-channel sum, the per-channel
# square-sum, and the number of elements that were summed per channel.
_ChildMessage = collections.namedtuple('_ChildMessage', ['sum', 'ssum', 'sum_size'])
# Message the master hands back to each replica: the mean (stored in the
# field named `sum`) and the inverse standard deviation.
_MasterMessage = collections.namedtuple('_MasterMessage', ['sum', 'inv_std'])
36+
37+
38+
class _SynchronizedBatchNorm(_BatchNorm):
    """Shared implementation of batch normalization with cross-replica statistics.

    In parallel training mode every replica computes the per-channel sum and
    square-sum of its own input. The replica with ``_parallel_id == 0`` (the
    master) reduces these, derives the mean and the inverse standard
    deviation, and broadcasts them back to all replicas.

    Parallel mode is only switched on by ``__data_parallel_replicate__``.
    Until that hook has been invoked on the module, and always in evaluation
    mode, ``forward`` falls back to ``F.batch_norm`` on the local input.

    Args:
        num_features: number of channels ``C`` of the expected input.
        eps: value added to the variance for numerical stability.
        momentum: weight of the newest batch statistics when the running
            mean/variance are updated.
        affine: when ``True`` the layer has learnable ``weight`` and ``bias``.
    """

    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True):
        super(_SynchronizedBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum, affine=affine)

        # The master object gets `_data_parallel_master` as its callback.
        self._sync_master = SyncMaster(self._data_parallel_master)

        # Replica bookkeeping; filled in by `__data_parallel_replicate__`.
        self._is_parallel = False
        self._parallel_id = None
        self._slave_pipe = None

    def forward(self, input):
        """Normalize ``input`` per channel.

        Args:
            input: tensor whose first dimension is the batch and whose second
                dimension has ``num_features`` channels.

        Returns:
            A tensor of the same shape as ``input``.
        """
        # If it is not parallel computation or is in evaluation mode, use PyTorch's implementation.
        if not (self._is_parallel and self.training):
            return F.batch_norm(
                input, self.running_mean, self.running_var, self.weight, self.bias,
                self.training, self.momentum, self.eps)

        # Resize the input to (B, C, -1). `reshape` is used instead of `view`
        # so that non-contiguous inputs (e.g. after a permute) are accepted.
        input_shape = input.size()
        input = input.reshape(input.size(0), self.num_features, -1)

        # Compute the sum and square-sum.
        sum_size = input.size(0) * input.size(2)
        input_sum = _sum_ft(input)
        input_ssum = _sum_ft(input ** 2)

        # Reduce-and-broadcast the statistics.
        message = _ChildMessage(input_sum, input_ssum, sum_size)
        if self._parallel_id == 0:
            mean, inv_std = self._sync_master.run_master(message)
        else:
            mean, inv_std = self._slave_pipe.run_slave(message)

        # Compute the output.
        centered = input - _unsqueeze_ft(mean)
        if self.affine:
            # MJY:: Fuse the multiplication for speed.
            output = centered * _unsqueeze_ft(inv_std * self.weight) + _unsqueeze_ft(self.bias)
        else:
            output = centered * _unsqueeze_ft(inv_std)

        # Reshape it.
        return output.view(input_shape)

    def __data_parallel_replicate__(self, ctx, copy_id):
        """Mark this module as replica ``copy_id`` and wire it to the master.

        NOTE(review): presumably invoked by the replication callback from the
        sibling ``replicate`` module; nothing in this file calls it.

        Args:
            ctx: object shared between the replicas; replica 0 stores its
                sync master on it, the other replicas read it back.
            copy_id: index of this replica; ``0`` denotes the master.
        """
        self._is_parallel = True
        self._parallel_id = copy_id

        # parallel_id == 0 means master device.
        if self._parallel_id == 0:
            ctx.sync_master = self._sync_master
        else:
            self._slave_pipe = ctx.sync_master.register_slave(copy_id)

    def _data_parallel_master(self, intermediates):
        """Reduce the sum and square-sum, compute the statistics, and broadcast it.

        Args:
            intermediates: sequence of ``(identifier, _ChildMessage)`` pairs,
                one per replica.

        Returns:
            A list of ``(identifier, _MasterMessage)`` pairs in device order.
        """

        # Always using same "device order" makes the ReduceAdd operation faster.
        # Thanks to:: Tete Xiao (http://tetexiao.com/)
        intermediates = sorted(intermediates, key=lambda i: i[1].sum.get_device())

        # Flattened [sum_0, ssum_0, sum_1, ssum_1, ...] in device order.
        to_reduce = [stat for _, msg in intermediates for stat in msg[:2]]
        target_gpus = [msg.sum.get_device() for _, msg in intermediates]

        sum_size = sum(msg.sum_size for _, msg in intermediates)
        sum_, ssum = ReduceAddCoalesced.apply(target_gpus[0], 2, *to_reduce)
        mean, inv_std = self._compute_mean_std(sum_, ssum, sum_size)

        broadcasted = Broadcast.apply(target_gpus, mean, inv_std)

        # `broadcasted` is flat: two consecutive entries (mean, inv_std) per replica.
        return [
            (rec[0], _MasterMessage(*broadcasted[i * 2:i * 2 + 2]))
            for i, rec in enumerate(intermediates)
        ]

    def _compute_mean_std(self, sum_, ssum, size):
        """Compute the mean and inverse standard deviation from sum and square-sum.

        This method also maintains the moving averages ``running_mean`` and
        ``running_var`` (the latter with the unbiased variance).

        Args:
            sum_: per-channel sum of the inputs.
            ssum: per-channel sum of the squared inputs.
            size: number of elements that contributed to each channel.

        Returns:
            ``(mean, inv_std)`` where ``inv_std = (biased_var + eps) ** -0.5``.

        Raises:
            ValueError: if ``size`` is not greater than one.
        """
        # Raised explicitly (not asserted) so the check survives `python -O`.
        if not size > 1:
            raise ValueError('BatchNorm computes unbiased standard-deviation, which requires size > 1.')
        mean = sum_ / size
        sumvar = ssum - sum_ * mean
        unbias_var = sumvar / (size - 1)
        bias_var = sumvar / size

        self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean.detach()
        self.running_var = (1 - self.momentum) * self.running_var + self.momentum * unbias_var.detach()

        # Add eps to the variance (rather than clamping the variance at eps)
        # so the result follows y = (x - mean) / sqrt(Var + eps), the same
        # formula the F.batch_norm fallback path applies.
        return mean, (bias_var + self.eps) ** -0.5
126+
127+
128+
class SynchronizedBatchNorm1d(_SynchronizedBatchNorm):
    r"""Applies Synchronized Batch Normalization over a 2d or 3d input that is seen as a
    mini-batch.

    .. math::

        y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta

    This module differs from the built-in PyTorch BatchNorm1d as the mean and
    standard-deviation are reduced across all devices during training.

    For example, when one uses `nn.DataParallel` to wrap the network during
    training, PyTorch's implementation normalizes the tensor on each device using
    the statistics only on that device, which accelerates the computation and
    is also easy to implement, but the statistics might be inaccurate.
    Instead, in this synchronized version, the statistics will be computed
    over all training samples distributed on multiple devices.

    Note that, for the one-GPU or CPU-only case, this module behaves exactly the
    same as the built-in PyTorch implementation.

    The mean and standard-deviation are calculated per-dimension over
    the mini-batches and gamma and beta are learnable parameter vectors
    of size C (where C is the input size).

    During training, this layer keeps a running estimate of its computed mean
    and variance. The running sum is kept with a default momentum of 0.1.

    During evaluation, this running mean/variance is used for normalization.

    Because the BatchNorm is done over the `C` dimension, computing statistics
    on `(N, L)` slices, it's common terminology to call this Temporal BatchNorm.

    Args:
        num_features: num_features from an expected input of size
            `batch_size x num_features [x width]`
        eps: a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum: the value used for the running_mean and running_var
            computation. Default: 0.1
        affine: a boolean value that when set to ``True``, gives the layer learnable
            affine parameters. Default: ``True``

    Shape:
        - Input: :math:`(N, C)` or :math:`(N, C, L)`
        - Output: :math:`(N, C)` or :math:`(N, C, L)` (same shape as input)

    Examples:
        >>> # With Learnable Parameters
        >>> m = SynchronizedBatchNorm1d(100)
        >>> # Without Learnable Parameters
        >>> m = SynchronizedBatchNorm1d(100, affine=False)
        >>> input = torch.randn(20, 100)
        >>> output = m(input)
    """

    def _check_input_dim(self, input):
        """Validate the rank and the channel count of ``input``.

        Raises:
            ValueError: if ``input`` is neither 2D nor 3D, or if its second
                dimension does not have ``num_features`` entries.
        """
        if input.dim() not in (2, 3):
            raise ValueError('expected 2D or 3D input (got {}D input)'
                             .format(input.dim()))
        # The check is done here instead of delegating to the parent class:
        # in recent PyTorch releases `_BatchNorm._check_input_dim` is abstract
        # and raises NotImplementedError, which would reject valid input.
        if input.size(1) != self.num_features:
            raise ValueError('got {}-feature tensor, expected {}'
                             .format(input.size(1), self.num_features))
178+
179+
180+
class SynchronizedBatchNorm2d(_SynchronizedBatchNorm):
    r"""Applies Batch Normalization over a 4d input that is seen as a mini-batch
    of 3d inputs.

    .. math::

        y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta

    This module differs from the built-in PyTorch BatchNorm2d as the mean and
    standard-deviation are reduced across all devices during training.

    For example, when one uses `nn.DataParallel` to wrap the network during
    training, PyTorch's implementation normalizes the tensor on each device using
    the statistics only on that device, which accelerates the computation and
    is also easy to implement, but the statistics might be inaccurate.
    Instead, in this synchronized version, the statistics will be computed
    over all training samples distributed on multiple devices.

    Note that, for the one-GPU or CPU-only case, this module behaves exactly the
    same as the built-in PyTorch implementation.

    The mean and standard-deviation are calculated per-dimension over
    the mini-batches and gamma and beta are learnable parameter vectors
    of size C (where C is the input size).

    During training, this layer keeps a running estimate of its computed mean
    and variance. The running sum is kept with a default momentum of 0.1.

    During evaluation, this running mean/variance is used for normalization.

    Because the BatchNorm is done over the `C` dimension, computing statistics
    on `(N, H, W)` slices, it's common terminology to call this Spatial BatchNorm.

    Args:
        num_features: num_features from an expected input of
            size batch_size x num_features x height x width
        eps: a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum: the value used for the running_mean and running_var
            computation. Default: 0.1
        affine: a boolean value that when set to ``True``, gives the layer learnable
            affine parameters. Default: ``True``

    Shape:
        - Input: :math:`(N, C, H, W)`
        - Output: :math:`(N, C, H, W)` (same shape as input)

    Examples:
        >>> # With Learnable Parameters
        >>> m = SynchronizedBatchNorm2d(100)
        >>> # Without Learnable Parameters
        >>> m = SynchronizedBatchNorm2d(100, affine=False)
        >>> input = torch.randn(20, 100, 35, 45)
        >>> output = m(input)
    """

    def _check_input_dim(self, input):
        """Validate the rank and the channel count of ``input``.

        Raises:
            ValueError: if ``input`` is not 4D, or if its second dimension
                does not have ``num_features`` entries.
        """
        if input.dim() != 4:
            raise ValueError('expected 4D input (got {}D input)'
                             .format(input.dim()))
        # The check is done here instead of delegating to the parent class:
        # in recent PyTorch releases `_BatchNorm._check_input_dim` is abstract
        # and raises NotImplementedError, which would reject valid input.
        if input.size(1) != self.num_features:
            raise ValueError('got {}-feature tensor, expected {}'
                             .format(input.size(1), self.num_features))
230+
231+
232+
class SynchronizedBatchNorm3d(_SynchronizedBatchNorm):
    r"""Applies Batch Normalization over a 5d input that is seen as a mini-batch
    of 4d inputs.

    .. math::

        y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta

    This module differs from the built-in PyTorch BatchNorm3d as the mean and
    standard-deviation are reduced across all devices during training.

    For example, when one uses `nn.DataParallel` to wrap the network during
    training, PyTorch's implementation normalizes the tensor on each device using
    the statistics only on that device, which accelerates the computation and
    is also easy to implement, but the statistics might be inaccurate.
    Instead, in this synchronized version, the statistics will be computed
    over all training samples distributed on multiple devices.

    Note that, for the one-GPU or CPU-only case, this module behaves exactly the
    same as the built-in PyTorch implementation.

    The mean and standard-deviation are calculated per-dimension over
    the mini-batches and gamma and beta are learnable parameter vectors
    of size C (where C is the input size).

    During training, this layer keeps a running estimate of its computed mean
    and variance. The running sum is kept with a default momentum of 0.1.

    During evaluation, this running mean/variance is used for normalization.

    Because the BatchNorm is done over the `C` dimension, computing statistics
    on `(N, D, H, W)` slices, it's common terminology to call this Volumetric BatchNorm
    or Spatio-temporal BatchNorm.

    Args:
        num_features: num_features from an expected input of
            size batch_size x num_features x depth x height x width
        eps: a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum: the value used for the running_mean and running_var
            computation. Default: 0.1
        affine: a boolean value that when set to ``True``, gives the layer learnable
            affine parameters. Default: ``True``

    Shape:
        - Input: :math:`(N, C, D, H, W)`
        - Output: :math:`(N, C, D, H, W)` (same shape as input)

    Examples:
        >>> # With Learnable Parameters
        >>> m = SynchronizedBatchNorm3d(100)
        >>> # Without Learnable Parameters
        >>> m = SynchronizedBatchNorm3d(100, affine=False)
        >>> input = torch.randn(20, 100, 35, 45, 10)
        >>> output = m(input)
    """

    def _check_input_dim(self, input):
        """Validate the rank and the channel count of ``input``.

        Raises:
            ValueError: if ``input`` is not 5D, or if its second dimension
                does not have ``num_features`` entries.
        """
        if input.dim() != 5:
            raise ValueError('expected 5D input (got {}D input)'
                             .format(input.dim()))
        # The check is done here instead of delegating to the parent class:
        # in recent PyTorch releases `_BatchNorm._check_input_dim` is abstract
        # and raises NotImplementedError, which would reject valid input.
        if input.size(1) != self.num_features:
            raise ValueError('got {}-feature tensor, expected {}'
                             .format(input.size(1), self.num_features))

0 commit comments

Comments
 (0)