17 changes: 13 additions & 4 deletions libai/optim/build.py
@@ -46,6 +46,7 @@ def get_default_optimizer_params(
weight_decay_bias=None,
clip_grad_max_norm=None,
clip_grad_norm_type=None,
lr_factor_func=None,
overrides=None,
):
"""
@@ -58,6 +59,10 @@
in optimizer.
weight_decay_norm: override weight decay for params in normalization layers
weight_decay_bias: override weight decay for bias parameters
lr_factor_func: function that maps a parameter's full name to a multiplicative
factor applied to that parameter's learning rate. Note that setting this option
requires also setting ``base_lr``, e.g.
``lr_factor_func = lambda module_name: 0.1 if "transformer" in module_name else 1``
overrides: if not `None`, provides values for optimizer hyperparameters
(LR, weight decay) for module parameters with a given name; e.g.
``{"embedding": {"lr": 0.01, "weight_decay": 0.1}}`` will set the LR and
@@ -92,7 +97,9 @@
if "bias" in overrides:
raise ValueError("Conflicting overrides for 'bias'")
overrides["bias"] = bias_overrides

if lr_factor_func is not None:
if base_lr is None:
raise ValueError("lr_factor_func requires base_lr")
norm_module_types = (
LayerNorm,
flow.nn.BatchNorm1d,
@@ -108,8 +115,8 @@
)
params = []
memo = set()
for module in model.modules():
for model_param_name, value in module.named_parameters(recurse=False):
for module_name, module in model.named_modules():
for module_param_name, value in module.named_parameters(recurse=False):
if not value.requires_grad:
continue
# Avoid duplicating parameters
@@ -120,7 +127,9 @@
hyperparams = copy.copy(defaults)
if isinstance(module, norm_module_types) and weight_decay_norm is not None:
hyperparams["weight_decay"] = weight_decay_norm
hyperparams.update(overrides.get(model_param_name, {}))
if lr_factor_func is not None:
hyperparams["lr"] *= lr_factor_func(f"{module_name}.{module_param_name}")
hyperparams.update(overrides.get(f"{module_name}.{module_param_name}", {}))
params.append({"params": [value], **hyperparams})
return reduce_param_groups(params)

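A minimal usage sketch (not part of the diff) of how the new `lr_factor_func` hook and the fully qualified `overrides` keys might be used together. `ToyModel`, the `weight_decay` keyword, and the optimizer choice are illustrative assumptions; only `get_default_optimizer_params`, `base_lr`, `lr_factor_func`, and `overrides` are confirmed by the diff above.

```python
# Sketch only: ToyModel and some keyword arguments are assumptions for
# illustration; check the full signature of get_default_optimizer_params.
import oneflow as flow

from libai.optim.build import get_default_optimizer_params


class ToyModel(flow.nn.Module):
    def __init__(self):
        super().__init__()
        self.embedding = flow.nn.Embedding(100, 16)
        self.transformer = flow.nn.Linear(16, 16)  # stand-in for a transformer block
        self.head = flow.nn.Linear(16, 2)

    def forward(self, x):
        return self.head(self.transformer(self.embedding(x)))


model = ToyModel()

param_groups = get_default_optimizer_params(
    model,
    base_lr=1e-3,  # required whenever lr_factor_func is given
    weight_decay=0.01,  # assumed keyword, not shown in the diff
    # Scale down the LR of every parameter whose full dotted name
    # ("transformer.weight", "transformer.bias", ...) mentions "transformer".
    lr_factor_func=lambda name: 0.1 if "transformer" in name else 1.0,
    # Overrides are now keyed by the full "module_name.module_param_name"
    # string rather than by the bare parameter name.
    overrides={"embedding.weight": {"lr": 1e-4, "weight_decay": 0.0}},
)
optimizer = flow.optim.AdamW(param_groups)
```

Keying `overrides` by the full `module_name.module_param_name` string, as the last hunk does, avoids collisions between parameters that share a bare name across modules (e.g. every module's `weight`).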