Commit fdbf43d
Merge branch 'main' into patch-1
2 parents: c1de548 + 84cb0ac

5 files changed: 47 additions, 239 deletions

slowfast/config/defaults.py (11 additions, 2 deletions)

@@ -110,10 +110,19 @@
 # NUM_SPLITS splits, and run BN on each of them separately independently.
 _C.BN.NUM_SPLITS = 1

-# Parameter for NaiveSyncBatchNorm3d, where the stats across `NUM_SYNC_DEVICES`
-# devices will be synchronized.
+# Parameter for NaiveSyncBatchNorm, where the stats across `NUM_SYNC_DEVICES`
+# devices will be synchronized. `NUM_SYNC_DEVICES` cannot be larger than number of
+# devices per machine; if global sync is desired, set `GLOBAL_SYNC`.
+# By default ONLY applies to NaiveSyncBatchNorm3d; consider also setting
+# CONTRASTIVE.BN_SYNC_MLP if appropriate.
 _C.BN.NUM_SYNC_DEVICES = 1

+# Parameter for NaiveSyncBatchNorm. Setting `GLOBAL_SYNC` to True synchronizes
+# stats across all devices, across all machines; in this case, `NUM_SYNC_DEVICES`
+# must be set to None.
+# By default ONLY applies to NaiveSyncBatchNorm3d; consider also setting
+# CONTRASTIVE.BN_SYNC_MLP if appropriate.
+_C.BN.GLOBAL_SYNC = False

 # ---------------------------------------------------------------------------- #
 # Training options.

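Note: the new comments describe a constraint rather than code, so here is a small, self-contained sketch of the rule they imply: per-machine sync is bounded by the number of GPUs on one machine, and global sync expects `NUM_SYNC_DEVICES` to be None. The helper `check_bn_sync_cfg`, its arguments, and the value 8 are hypothetical and not part of the commit.

# Hypothetical validation sketch of the BN sync config semantics described above.
from typing import Optional


def check_bn_sync_cfg(
    num_sync_devices: Optional[int], global_sync: bool, gpus_per_machine: int
) -> None:
    """Validate the relationship between NUM_SYNC_DEVICES and GLOBAL_SYNC."""
    if global_sync:
        # Global sync spans all machines; the per-machine group size is unused.
        assert num_sync_devices is None, (
            "GLOBAL_SYNC=True expects NUM_SYNC_DEVICES to be None"
        )
        return
    # Per-machine sync: the group cannot exceed the devices on one machine.
    assert num_sync_devices is not None and 1 <= num_sync_devices <= gpus_per_machine, (
        "NUM_SYNC_DEVICES cannot be larger than the number of devices per machine"
    )


check_bn_sync_cfg(num_sync_devices=8, global_sync=False, gpus_per_machine=8)
check_bn_sync_cfg(num_sync_devices=None, global_sync=True, gpus_per_machine=8)
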
slowfast/models/batchnorm_helper.py (4 additions, 161 deletions)

@@ -4,12 +4,10 @@
 """BatchNorm (BN) utility functions and custom batch-size BN implementations"""

 from functools import partial
+
 import torch
-import torch.distributed as dist
 import torch.nn as nn
-from torch.autograd.function import Function
-
-import slowfast.utils.distributed as du
+from pytorchvideo.layers.batch_norm import NaiveSyncBatchNorm3d, NaiveSyncBatchNorm1d  # noqa


 def get_norm(cfg):

@@ -26,7 +24,8 @@ def get_norm(cfg):
         return partial(SubBatchNorm3d, num_splits=cfg.BN.NUM_SPLITS)
     elif cfg.BN.NORM_TYPE == "sync_batchnorm":
         return partial(
-            NaiveSyncBatchNorm3d, num_sync_devices=cfg.BN.NUM_SYNC_DEVICES
+            NaiveSyncBatchNorm3d, num_sync_devices=cfg.BN.NUM_SYNC_DEVICES,
+            global_sync=cfg.BN.GLOBAL_SYNC
         )
     else:
         raise NotImplementedError(

@@ -107,159 +106,3 @@ def forward(self, x):
         x = x * self.weight.view((-1, 1, 1, 1))
         x = x + self.bias.view((-1, 1, 1, 1))
         return x
-
-
-class GroupGather(Function):
-    """
-    GroupGather performs all gather on each of the local process/ GPU groups.
-    """
-
-    @staticmethod
-    def forward(ctx, input, num_sync_devices, num_groups):
-        """
-        Perform forwarding, gathering the stats across different process/ GPU
-        group.
-        """
-        ctx.num_sync_devices = num_sync_devices
-        ctx.num_groups = num_groups
-
-        input_list = [
-            torch.zeros_like(input) for k in range(du.get_local_size())
-        ]
-        dist.all_gather(
-            input_list, input, async_op=False, group=du._LOCAL_PROCESS_GROUP
-        )
-
-        inputs = torch.stack(input_list, dim=0)
-        if num_groups > 1:
-            rank = du.get_local_rank()
-            group_idx = rank // num_sync_devices
-            inputs = inputs[
-                group_idx
-                * num_sync_devices : (group_idx + 1)
-                * num_sync_devices
-            ]
-        inputs = torch.sum(inputs, dim=0)
-        return inputs
-
-    @staticmethod
-    def backward(ctx, grad_output):
-        """
-        Perform backwarding, gathering the gradients across different process/ GPU
-        group.
-        """
-        grad_output_list = [
-            torch.zeros_like(grad_output) for k in range(du.get_local_size())
-        ]
-        dist.all_gather(
-            grad_output_list,
-            grad_output,
-            async_op=False,
-            group=du._LOCAL_PROCESS_GROUP,
-        )
-
-        grads = torch.stack(grad_output_list, dim=0)
-        if ctx.num_groups > 1:
-            rank = du.get_local_rank()
-            group_idx = rank // ctx.num_sync_devices
-            grads = grads[
-                group_idx
-                * ctx.num_sync_devices : (group_idx + 1)
-                * ctx.num_sync_devices
-            ]
-        grads = torch.sum(grads, dim=0)
-        return grads, None, None
-
-
-class NaiveSyncBatchNorm3d(nn.BatchNorm3d):
-    def __init__(self, num_sync_devices, **args):
-        """
-        Naive version of Synchronized 3D BatchNorm.
-        Args:
-            num_sync_devices (int): number of device to sync.
-            args (list): other arguments.
-        """
-        self.num_sync_devices = num_sync_devices
-        if self.num_sync_devices > 0:
-            assert du.get_local_size() % self.num_sync_devices == 0, (
-                du.get_local_size(),
-                self.num_sync_devices,
-            )
-            self.num_groups = du.get_local_size() // self.num_sync_devices
-        else:
-            self.num_sync_devices = du.get_local_size()
-            self.num_groups = 1
-        super(NaiveSyncBatchNorm3d, self).__init__(**args)
-
-    def forward(self, input):
-        if du.get_local_size() == 1 or not self.training:
-            return super().forward(input)
-
-        assert input.shape[0] > 0, "SyncBatchNorm does not support empty inputs"
-        C = input.shape[1]
-        mean = torch.mean(input, dim=[0, 2, 3, 4])
-        meansqr = torch.mean(input * input, dim=[0, 2, 3, 4])
-
-        vec = torch.cat([mean, meansqr], dim=0)
-        vec = GroupGather.apply(vec, self.num_sync_devices, self.num_groups) * (
-            1.0 / self.num_sync_devices
-        )
-
-        mean, meansqr = torch.split(vec, C)
-        var = meansqr - mean * mean
-        self.running_mean += self.momentum * (mean.detach() - self.running_mean)
-        self.running_var += self.momentum * (var.detach() - self.running_var)
-
-        invstd = torch.rsqrt(var + self.eps)
-        scale = self.weight * invstd
-        bias = self.bias - mean * scale
-        scale = scale.reshape(1, -1, 1, 1, 1)
-        bias = bias.reshape(1, -1, 1, 1, 1)
-        return input * scale + bias
-
-
-class NaiveSyncBatchNorm1d(nn.BatchNorm1d):
-    def __init__(self, num_sync_devices, **args):
-        """
-        Naive version of Synchronized 1D BatchNorm.
-        Args:
-            num_sync_devices (int): number of device to sync.
-            args (list): other arguments.
-        """
-        self.num_sync_devices = num_sync_devices
-        if self.num_sync_devices > 0:
-            assert du.get_local_size() % self.num_sync_devices == 0, (
-                du.get_local_size(),
-                self.num_sync_devices,
-            )
-            self.num_groups = du.get_local_size() // self.num_sync_devices
-        else:
-            self.num_sync_devices = du.get_local_size()
-            self.num_groups = 1
-        super(NaiveSyncBatchNorm1d, self).__init__(**args)
-
-    def forward(self, input):
-        if du.get_local_size() == 1 or not self.training:
-            return super().forward(input)
-
-        assert input.shape[0] > 0, "SyncBatchNorm does not support empty inputs"
-        C = input.shape[1]
-        mean = torch.mean(input, dim=[0])
-        meansqr = torch.mean(input * input, dim=[0])
-
-        vec = torch.cat([mean, meansqr], dim=0)
-        vec = GroupGather.apply(vec, self.num_sync_devices, self.num_groups) * (
-            1.0 / self.num_sync_devices
-        )
-
-        mean, meansqr = torch.split(vec, C)
-        var = meansqr - mean * mean
-        self.running_mean += self.momentum * (mean.detach() - self.running_mean)
-        self.running_var += self.momentum * (var.detach() - self.running_var)
-
-        invstd = torch.rsqrt(var + self.eps)
-        scale = self.weight * invstd
-        bias = self.bias - mean * scale
-        scale = scale.reshape(1, -1)
-        bias = bias.reshape(1, -1)
-        return input * scale + bias

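Note: after this change the local GroupGather/NaiveSyncBatchNorm implementations are gone and the sync BN layers come from pytorchvideo. A minimal sketch of how the slimmed-down get_norm() result is meant to be consumed, assuming pytorchvideo is installed and that NaiveSyncBatchNorm3d accepts the num_sync_devices / global_sync keywords the diff passes through; the literal values 1 and False stand in for cfg.BN.NUM_SYNC_DEVICES and cfg.BN.GLOBAL_SYNC.

# Sketch only; mirrors what get_norm(cfg) returns for NORM_TYPE == "sync_batchnorm".
from functools import partial

from pytorchvideo.layers.batch_norm import NaiveSyncBatchNorm3d

norm_module = partial(
    NaiveSyncBatchNorm3d,
    num_sync_devices=1,   # cfg.BN.NUM_SYNC_DEVICES
    global_sync=False,    # cfg.BN.GLOBAL_SYNC
)

# Model code then calls the factory like any BatchNorm3d constructor;
# outside a distributed run it behaves like plain BatchNorm3d.
bn = norm_module(num_features=64)
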
slowfast/models/head_helper.py (17 additions, 2 deletions)

@@ -155,6 +155,7 @@ def __init__(
         flatten=False,
         xavier_init=True,
         bn_sync_num=1,
+        global_sync=False,
     ):
         super(MLPHead, self).__init__()
         self.flatten = flatten

@@ -164,10 +165,12 @@ def __init__(
         mlp_layers[-1].xavier_init = xavier_init
         for i in range(1, num_layers):
             if bn_on:
-                if bn_sync_num > 1:
+                if global_sync or bn_sync_num > 1:
                     mlp_layers.append(
                         NaiveSyncBatchNorm1d(
-                            num_sync_devices=bn_sync_num, num_features=mlp_dim
+                            num_sync_devices=bn_sync_num,
+                            global_sync=global_sync,
+                            num_features=mlp_dim
                         )
                     )
                 else:

@@ -266,6 +269,10 @@ def __init__(
                 bn_sync_num=cfg.BN.NUM_SYNC_DEVICES
                 if cfg.CONTRASTIVE.BN_SYNC_MLP
                 else 1,
+                global_sync=(
+                    cfg.CONTRASTIVE.BN_SYNC_MLP and
+                    cfg.BN.GLOBAL_SYNC
+                ),
             )

         # Softmax for evaluation and testing.

@@ -294,6 +301,10 @@ def __init__(
                     bn_sync_num=cfg.BN.NUM_SYNC_DEVICES
                     if cfg.CONTRASTIVE.BN_SYNC_MLP
                     else 1,
+                    global_sync=(
+                        cfg.CONTRASTIVE.BN_SYNC_MLP and
+                        cfg.BN.GLOBAL_SYNC
+                    ),
                 )
                 self.predictors.append(local_mlp)

@@ -525,6 +536,10 @@ def __init__(
                 bn_sync_num=cfg.BN.NUM_SYNC_DEVICES
                 if cfg.CONTRASTIVE.BN_SYNC_MLP
                 else 1,
+                global_sync=(
+                    cfg.CONTRASTIVE.BN_SYNC_MLP and
+                    cfg.BN.GLOBAL_SYNC
+                ),
             )
         self.detach_final_fc = cfg.MODEL.DETACH_FINAL_FC


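Note: every MLPHead construction site now receives the same pair of keyword arguments. A hedged sketch of that pattern under the MLPHead signature shown above; the wrapper name build_projector and its positional arguments (dim_in, dim_out, mlp_dim, num_layers) are illustrative and not part of the commit.

# Hypothetical wrapper showing how global_sync is derived from the config.
from slowfast.models.head_helper import MLPHead


def build_projector(cfg, dim_in, dim_out, mlp_dim, num_layers):
    return MLPHead(
        dim_in,
        dim_out,
        mlp_dim,
        num_layers,
        bn_on=True,
        # Per-machine group size, only when the contrastive MLP wants sync BN.
        bn_sync_num=cfg.BN.NUM_SYNC_DEVICES if cfg.CONTRASTIVE.BN_SYNC_MLP else 1,
        # Global sync for the MLP BN layers only when both flags are on.
        global_sync=(cfg.CONTRASTIVE.BN_SYNC_MLP and cfg.BN.GLOBAL_SYNC),
    )
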
slowfast/models/optimizer.py (7 additions, 0 deletions)

@@ -115,6 +115,13 @@ def construct_optimizer(model, cfg):
             eps=1e-08,
             weight_decay=cfg.SOLVER.WEIGHT_DECAY,
         )
+    elif cfg.SOLVER.OPTIMIZING_METHOD == "mt_adamw":
+        optimizer = torch.optim._multi_tensor.AdamW(
+            optim_params,
+            lr=cfg.SOLVER.BASE_LR,
+            eps=1e-08,
+            weight_decay=cfg.SOLVER.WEIGHT_DECAY,
+        )
     else:
         raise NotImplementedError(
             "Does not support {} optimizer".format(cfg.SOLVER.OPTIMIZING_METHOD)

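Note: the new "mt_adamw" choice wires up PyTorch's multi-tensor AdamW, which batches the per-parameter update math into fewer, larger kernel launches. A sketch of the dispatch under the assumption of the cfg fields used above; be aware that torch.optim._multi_tensor is a private namespace, and recent PyTorch releases expose the same behavior via the foreach argument of torch.optim.AdamW.

# Sketch of the optimizer selection; make_adamw is a hypothetical helper that
# mirrors the new branch in construct_optimizer.
import torch


def make_adamw(optim_params, cfg):
    if cfg.SOLVER.OPTIMIZING_METHOD == "mt_adamw":
        # Private multi-tensor implementation, as used by this commit.
        return torch.optim._multi_tensor.AdamW(
            optim_params,
            lr=cfg.SOLVER.BASE_LR,
            eps=1e-08,
            weight_decay=cfg.SOLVER.WEIGHT_DECAY,
        )
    # Plain AdamW fallback (cfg.SOLVER.OPTIMIZING_METHOD == "adamw").
    return torch.optim.AdamW(
        optim_params,
        lr=cfg.SOLVER.BASE_LR,
        eps=1e-08,
        weight_decay=cfg.SOLVER.WEIGHT_DECAY,
    )
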
slowfast/utils/distributed.py (8 additions, 74 deletions)

@@ -9,25 +9,14 @@
 import torch
 import torch.distributed as dist

-_LOCAL_PROCESS_GROUP = None
-
-
-def cat_all_gather(tensors, local=False):
-    """Performs the concatenated all_reduce operation on the provided tensors."""
-    if local:
-        gather_sz = get_local_size()
-    else:
-        gather_sz = torch.distributed.get_world_size()
-    tensors_gather = [torch.ones_like(tensors) for _ in range(gather_sz)]
-    torch.distributed.all_gather(
-        tensors_gather,
-        tensors,
-        async_op=False,
-        group=_LOCAL_PROCESS_GROUP if local else None,
-    )
-    output = torch.cat(tensors_gather, dim=0)
-    return output
-
+from pytorchvideo.layers.distributed import (  # noqa
+    get_world_size,
+    cat_all_gather,
+    init_distributed_training,
+    get_local_size,
+    get_local_rank,
+    get_local_process_group,
+)

 def all_gather(tensors):
     """

@@ -128,17 +117,6 @@ def is_root_proc():
     return True


-def get_world_size():
-    """
-    Get the size of the world.
-    """
-    if not dist.is_available():
-        return 1
-    if not dist.is_initialized():
-        return 1
-    return dist.get_world_size()
-
-
 def get_rank():
     """
     Get the rank of the current process.

@@ -282,50 +260,6 @@ def all_gather_unaligned(data, group=None):
     return data_list


-def init_distributed_training(cfg):
-    """
-    Initialize variables needed for distributed training.
-    """
-    if cfg.NUM_GPUS <= 1:
-        return
-    num_gpus_per_machine = cfg.NUM_GPUS
-    num_machines = dist.get_world_size() // num_gpus_per_machine
-    for i in range(num_machines):
-        ranks_on_i = list(
-            range(i * num_gpus_per_machine, (i + 1) * num_gpus_per_machine)
-        )
-        pg = dist.new_group(ranks_on_i)
-        if i == cfg.SHARD_ID:
-            global _LOCAL_PROCESS_GROUP
-            _LOCAL_PROCESS_GROUP = pg
-
-
-def get_local_size() -> int:
-    """
-    Returns:
-        The size of the per-machine process group,
-        i.e. the number of processes per machine.
-    """
-    if not dist.is_available():
-        return 1
-    if not dist.is_initialized():
-        return 1
-    return dist.get_world_size(group=_LOCAL_PROCESS_GROUP)
-
-
-def get_local_rank() -> int:
-    """
-    Returns:
-        The rank of the current process within the local (per-machine) process group.
-    """
-    if not dist.is_available():
-        return 0
-    if not dist.is_initialized():
-        return 0
-    assert _LOCAL_PROCESS_GROUP is not None
-    return dist.get_rank(group=_LOCAL_PROCESS_GROUP)
-
-
 class GatherLayer(torch.autograd.Function):
     """Gather tensors from all process, supporting backward propagation."""


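Note: the deleted helpers are not removed from the project's API; they now live in pytorchvideo and are re-exported by slowfast.utils.distributed. A small sketch, assuming pytorchvideo is installed, showing that both import paths resolve to the same objects; the print line is illustrative only.

# Sketch: the re-export means old call sites using slowfast.utils.distributed
# keep working, while new code can import from pytorchvideo directly.
from pytorchvideo.layers.distributed import cat_all_gather, get_local_size

import slowfast.utils.distributed as du

# Same function objects, just re-exported by the slimmed-down module.
assert du.cat_all_gather is cat_all_gather
assert du.get_local_size is get_local_size

print("processes on this machine:", get_local_size())  # 1 when not distributed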