-
Notifications
You must be signed in to change notification settings - Fork 148
Expand file tree
/
Copy pathoptimizer.py
More file actions
133 lines (116 loc) · 4.8 KB
/
optimizer.py
File metadata and controls
133 lines (116 loc) · 4.8 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
""" Optimizer Factory w/ Custom Weight Decay
Hacked together by / Copyright 2020 Ross Wightman
"""
import re
import torch
from torch import optim as optim
from utils.distributed import is_main_process
import logging
logger = logging.getLogger(__name__)
try:
from apex.optimizers import FusedNovoGrad, FusedAdam, FusedLAMB, FusedSGD
has_apex = True
except ImportError:
has_apex = False
def add_weight_decay(model, weight_decay, no_decay_list=(), filter_bias_and_bn=True):
    """Pair every trainable parameter of ``model`` with its weight decay.

    Args:
        model: object whose ``named_parameters()`` yields ``(name, param)``.
        weight_decay: decay assigned to parameters that are not exempted.
        no_decay_list: container of parameter names that always get decay 0.
        filter_bias_and_bn: when truthy, parameters with a 1-D shape or a
            name ending in ".bias" also get decay 0.

    Returns:
        List of ``[name, param, weight_decay]`` lists, in iteration order.
        Parameters with ``requires_grad`` false are left out.
    """
    entries = []
    for param_name, tensor in model.named_parameters():
        if tensor.requires_grad:  # frozen weights are not optimized
            exempt = (
                filter_bias_and_bn
                and (len(tensor.shape) == 1 or param_name.endswith(".bias"))
            ) or param_name in no_decay_list
            entries.append([param_name, tensor, 0 if exempt else weight_decay])
    return entries
def add_different_lr(named_param_tuples_or_model, diff_lr_names, diff_lr, default_lr):
    """Attach a learning rate to every ``[name, param, weight_decay]`` entry.

    A parameter gets ``diff_lr`` when any pattern in ``diff_lr_names``
    matches its name via ``re.search``; otherwise it gets ``default_lr``.

    Args:
        named_param_tuples_or_model: iterable of ``[name, param, weight_decay]``.
            NOTE(review): the argument name suggests an nn.Module is also
            accepted, but this function only iterates 3-item entries.
        diff_lr_names: iterable of regular-expression patterns (str).
        diff_lr: learning rate for parameters whose name matches a pattern.
        default_lr: learning rate for all other parameters.

    Returns:
        List of ``[name, param, weight_decay, lr]`` lists, in input order.
    """
    logger.info(f"diff_names: {diff_lr_names}, diff_lr: {diff_lr}")

    with_lr = []
    for name, param, wd in named_param_tuples_or_model:
        # any() stops at the first matching pattern.
        matched = any(
            re.search(pattern, name) is not None for pattern in diff_lr_names
        )
        if matched:
            logger.info(f"param {name} use different_lr: {diff_lr}")
            chosen_lr = diff_lr
        else:
            chosen_lr = default_lr
        with_lr.append([name, param, wd, chosen_lr])

    # Per-parameter summary is only logged when is_main_process() is true.
    if is_main_process():
        for name, _, wd, lr in with_lr:
            logger.info(f"param {name}: wd: {wd}, lr: {lr}")
    return with_lr
def create_optimizer_params_group(named_param_tuples_with_lr):
    """Bucket parameters into optimizer param groups by (weight decay, lr).

    Args:
        named_param_tuples_with_lr: iterable of
            ``[name, param, weight_decay, lr]`` entries.

    Returns:
        List of dicts with keys ``params``, ``weight_decay`` and ``lr``.
        Groups are ordered by first appearance of each weight decay, then by
        first appearance of each lr within that weight decay.
    """
    # weight decay -> lr -> params; nesting keeps the output order stable.
    buckets = {}
    for _name, param, decay, rate in named_param_tuples_with_lr:
        buckets.setdefault(decay, {}).setdefault(rate, []).append(param)

    param_groups = []
    for wd, by_lr in buckets.items():
        for lr, p in by_lr.items():
            param_groups.append({"params": p, "weight_decay": wd, "lr": lr})
            logger.info(f"optimizer -- lr={lr} wd={wd} len(p)={len(p)}")
    return param_groups
def create_optimizer(args, model, filter_bias_and_bn=True):
    """Build a ``torch.optim`` optimizer for ``model`` from ``args``.

    Trainable parameters are split into param groups by weight decay and
    learning rate via ``add_weight_decay``, ``add_different_lr`` and
    ``create_optimizer_params_group``.

    Args:
        args: config object. Reads ``opt``, ``weight_decay`` and ``lr``, plus
            ``momentum`` for the SGD variants. Optional attributes:
            ``different_lr`` (with ``enable``, ``module_names``, ``lr``),
            ``opt_eps``, ``opt_betas`` and ``opt_args`` (a mapping of extra
            keyword arguments, merged last).
        model: object exposing ``named_parameters()``. If it also defines
            ``no_weight_decay()``, the names it returns get weight decay 0.
        filter_bias_and_bn: forwarded to ``add_weight_decay``.

    Returns:
        ``optim.SGD`` with Nesterov momentum for "sgd"/"nesterov",
        ``optim.SGD`` without it for "momentum", ``optim.Adam`` for "adam",
        or ``optim.AdamW`` for "adamw". Only the part of ``args.opt`` after
        the last "_" selects the optimizer.

    Raises:
        ValueError: if the optimizer name is not one of the above.
        AssertionError: if ``args.opt`` contains "fused" while APEX or CUDA
            is unavailable.
    """
    opt_lower = args.opt.lower()
    weight_decay = args.weight_decay

    # Check for modules that require a different learning rate.
    different_lr = getattr(args, "different_lr", None)
    if different_lr is not None and different_lr.enable:
        diff_lr_module_names = different_lr.module_names
        diff_lr = different_lr.lr
    else:
        diff_lr_module_names = []
        diff_lr = None

    no_decay = model.no_weight_decay() if hasattr(model, 'no_weight_decay') else ()

    named_param_tuples = add_weight_decay(
        model, weight_decay, no_decay, filter_bias_and_bn)
    named_param_tuples = add_different_lr(
        named_param_tuples, diff_lr_module_names, diff_lr, args.lr)
    parameters = create_optimizer_params_group(named_param_tuples)

    # NOTE(review): a "fused" prefix only enforces the APEX/CUDA requirement;
    # the optimizers constructed below are still the plain torch.optim ones.
    if 'fused' in opt_lower:
        assert has_apex and torch.cuda.is_available(), 'APEX and CUDA required for fused optimizers'

    # Defaults for the optimizer; each param group carries its own lr and
    # weight_decay, which take precedence over these.
    opt_args = dict(lr=args.lr, weight_decay=weight_decay)
    if getattr(args, 'opt_eps', None) is not None:
        opt_args['eps'] = args.opt_eps
    if getattr(args, 'opt_betas', None) is not None:
        opt_args['betas'] = args.opt_betas
    if getattr(args, 'opt_args', None) is not None:
        opt_args.update(args.opt_args)

    opt_name = opt_lower.split('_')[-1]

    if opt_name in ('sgd', 'nesterov', 'momentum'):
        # eps and betas are Adam-family arguments; optim.SGD does not accept
        # them, so drop both rather than fail with a TypeError.
        opt_args.pop('eps', None)
        opt_args.pop('betas', None)
        # "sgd" and "nesterov" both enable Nesterov momentum; "momentum" does not.
        return optim.SGD(
            parameters,
            momentum=args.momentum,
            nesterov=opt_name != 'momentum',
            **opt_args,
        )
    if opt_name == 'adam':
        return optim.Adam(parameters, **opt_args)
    if opt_name == 'adamw':
        return optim.AdamW(parameters, **opt_args)

    # Raise explicitly instead of asserting: asserts are stripped under -O
    # and the previous assert expression discarded its message.
    raise ValueError(f"Invalid optimizer: {args.opt!r}")