fix optimizer config
ETOgaosion committed May 26, 2025
commit 4d1ee68224b1c4908eaf3f9ffb30be58ce7ccb84
1 change: 1 addition & 0 deletions docs/examples/config.rst
@@ -377,6 +377,7 @@ ____________________________________________________
 .. code:: yaml

     optim:
+      optimizer: adam
       lr: 1e-6
       clip_grad: 1.0
       total_training_steps: -1 # must be override by program
1 change: 1 addition & 0 deletions verl/trainer/config/ppo_megatron_trainer.yaml
@@ -59,6 +59,7 @@ actor_rollout_ref:
     data_loader_seed: null
     shuffle: False
     optim:
+      optimizer: adam
       lr: 1e-6
       clip_grad: 1.0
       total_training_steps: -1 # must be override by program
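
Note: both hunks above add the same key. `optimizer: adam` becomes an explicit, documented field of the `optim` block instead of a value hard-coded when the Megatron optimizer config is built (see the megatron_utils.py hunk further down). A minimal sketch of how such a block reaches the Megatron side; the OmegaConf stand-in and the reduced key set are illustrative, and it assumes the portion of init_megatron_optim_config not shown in this commit reads no additional required keys:

    from omegaconf import OmegaConf
    from verl.utils.megatron_utils import init_megatron_optim_config

    # Stand-in for the `optim` node defined in ppo_megatron_trainer.yaml above.
    optim = OmegaConf.create({
        "optimizer": "adam",         # key added by this commit
        "lr": 1e-6,
        "clip_grad": 1.0,
        "total_training_steps": -1,  # presumably consumed by the scheduler, not OptimizerConfig
    })

    megatron_cfg = init_megatron_optim_config(optim)  # returns a Megatron OptimizerConfig
    assert megatron_cfg.optimizer == "adam"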
2 changes: 1 addition & 1 deletion verl/utils/megatron/optimizer.py
@@ -37,7 +37,7 @@ def get_megatron_optimizer(

 def get_megatron_optimizer_param_scheduler(
     optimizer,
-    config: OptimizerConfig,
+    config,
 ):
     """
     Get the optimizer parameter scheduler for Megatron.
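
Note: dropping the OptimizerConfig annotation matches the call sites in megatron_workers.py below, which now hand the scheduler the original `optim` block (the dict-like config from the YAML above) rather than the converted Megatron OptimizerConfig, presumably so it can read scheduler-side fields such as `total_training_steps` that OptimizerConfig does not carry. A hedged sketch of the updated call (keyword names follow the workers' diff; `actor_optimizer` and `optim_config` stand in for the objects built there):

    from verl.utils.megatron.optimizer import get_megatron_optimizer_param_scheduler

    actor_optimizer_scheduler = get_megatron_optimizer_param_scheduler(
        optimizer=actor_optimizer,
        config=optim_config,  # raw optim mapping, hence the relaxed annotation
    )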
7 changes: 4 additions & 3 deletions verl/utils/megatron_utils.py
@@ -200,10 +200,11 @@ def convert_config(hf_config: PretrainedConfig, megatron_config) -> TransformerC

 def init_megatron_optim_config(optim_config: Dict) -> OptimizerConfig:
     config = OptimizerConfig(
-        optimizer="adam",
+        optimizer=optim_config.get("optimizer", "adam"),
         lr=optim_config.get("lr"),
-        clip_grad=optim_config.get("clip_grad"),
-        weight_decay=optim_config.get("weight_decay"),
+        min_lr=optim_config.get("min_lr", None),
+        clip_grad=optim_config.get("clip_grad", 1.0),
+        weight_decay=optim_config.get("weight_decay", 0.01),
         bf16=True,
         params_dtype=torch.bfloat16,
         use_distributed_optimizer=True,
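
Note: besides reading the new `optimizer` key, this hunk gives `clip_grad` and `weight_decay` explicit fallbacks and plumbs `min_lr` through to Megatron's OptimizerConfig. A hedged illustration of the fallback behaviour, again assuming the part of the function not shown in the hunk above reads no further required keys:

    from verl.utils.megatron_utils import init_megatron_optim_config

    # Only `lr` is supplied; the rest fall back to the defaults introduced above.
    cfg = init_megatron_optim_config({"lr": 1e-6})
    # Expected per the .get() defaults: cfg.optimizer == "adam", cfg.clip_grad == 1.0,
    # cfg.weight_decay == 0.01, cfg.min_lr is None.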
8 changes: 4 additions & 4 deletions verl/workers/megatron_workers.py
@@ -192,8 +192,8 @@ def megatron_actor_model_provider(pre_process, post_process):

         # TODO: add more optimizer args into config
         if self._is_actor:
-            optim_config = init_megatron_optim_config(optim_config)
-            actor_optimizer = get_megatron_optimizer(model=actor_module, config=optim_config)
+            optim_config_megatron = init_megatron_optim_config(optim_config)
+            actor_optimizer = get_megatron_optimizer(model=actor_module, config=optim_config_megatron)
             actor_optimizer_scheduler = get_megatron_optimizer_param_scheduler(optimizer=actor_optimizer, config=optim_config)
         else:
             optim_config = None
@@ -666,8 +666,8 @@ def megatron_critic_model_provider(pre_process, post_process):
         print_model_size(critic_module[0])

         # TODO: add more optimizer args into config
-        optim_config = init_megatron_optim_config(optim_config)
-        critic_optimizer = get_megatron_optimizer(model=critic_module, config=optim_config)
+        optim_config_megatron = init_megatron_optim_config(optim_config)
+        critic_optimizer = get_megatron_optimizer(model=critic_module, config=optim_config_megatron)
         critic_optimizer_scheduler = get_megatron_optimizer_param_scheduler(optimizer=critic_optimizer, config=optim_config)
         torch.cuda.empty_cache()
         return critic_module, critic_optimizer, critic_optimizer_scheduler, self.hf_config, optim_config
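
Note: the rename matters because `optim_config` is still used after the optimizer is built: it is passed to get_megatron_optimizer_param_scheduler and, in the critic path, returned to the caller, so overwriting it with the converted Megatron OptimizerConfig (as the old code did) handed those later uses the wrong object. This is also what the relaxed annotation in optimizer.py above reflects. A condensed before/after sketch of the actor branch (illustrative, not the exact worker code):

    from verl.utils.megatron.optimizer import get_megatron_optimizer, get_megatron_optimizer_param_scheduler
    from verl.utils.megatron_utils import init_megatron_optim_config

    # Before: the raw optim block was shadowed by the Megatron OptimizerConfig,
    # so the scheduler (and the critic's return value) received the wrong object.
    optim_config = init_megatron_optim_config(optim_config)
    actor_optimizer = get_megatron_optimizer(model=actor_module, config=optim_config)

    # After: keep both objects and use each where it belongs.
    optim_config_megatron = init_megatron_optim_config(optim_config)
    actor_optimizer = get_megatron_optimizer(model=actor_module, config=optim_config_megatron)
    actor_optimizer_scheduler = get_megatron_optimizer_param_scheduler(
        optimizer=actor_optimizer, config=optim_config
    )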