diff --git a/.github/workflows/e2e_ppo_trainer_megatron.yml b/.github/workflows/e2e_ppo_trainer_megatron.yml
index 8fc466ec76a..a6d8566146d 100644
--- a/.github/workflows/e2e_ppo_trainer_megatron.yml
+++ b/.github/workflows/e2e_ppo_trainer_megatron.yml
@@ -108,19 +108,16 @@ jobs:
         run: |
           ray stop --force
           ALL_OFFLOAD=True VAL_BEFORE_TRAIN=True TEST_FREQ=1 SAVE_FREQ=1 MODEL_ID=Qwen/Qwen3-0.6B bash tests/e2e/run_ppo_trainer_megatron.sh
-      - name: Running GSM8K E2E training tests with 3D parallelism on 8 L20 GPUs with Megatron (Qwen3) after resuming
+      - name: Running GSM8K E2E training tests with 3D parallelism on 8 L20 GPUs with Megatron (Qwen3) testing learning rate scheduler
         run: |
           ray stop --force
-          RESUME_MODE=auto MODEL_ID=Qwen/Qwen3-0.6B bash tests/e2e/run_ppo_trainer_megatron.sh
+          LR_WARMUP_STEPS=1 TOTAL_TRAIN_STEPS=2 MODEL_ID=Qwen/Qwen3-0.6B bash tests/e2e/run_ppo_trainer_megatron.sh
+
       - name: Test Megatron checkpoints merging function (Qwen3 Actor and Critic)
         run: |
           exp_name="qwen3-0.6b-megatron-gsm8k-minimal"
           python scripts/model_merger.py test --backend megatron --tie-word-embedding --local_dir checkpoints/verl-test/${exp_name}/global_step_1/actor --test_hf_dir checkpoints/verl-test/${exp_name}/global_step_1/actor/huggingface
           python scripts/model_merger.py test --backend megatron --is-value-model --local_dir checkpoints/verl-test/${exp_name}/global_step_1/critic --test_hf_dir checkpoints/verl-test/${exp_name}/global_step_1/critic/huggingface
-      - name: Running GRPO GSM8K E2E training tests with 3D parallelism on 8 L20 GPUs with Megatron (Qwen3)
-        run: |
-          ray stop --force
-          ADV_ESTIMATOR=grpo USE_DYNAMIC_BSZ=False MODEL_ID=Qwen/Qwen3-0.6B bash tests/e2e/run_ppo_trainer_megatron.sh
       - name: clean up
         run: |
           rm -rf checkpoints
diff --git a/docs/examples/config.rst b/docs/examples/config.rst
index e9cb89bb15f..a3f910b3645 100644
--- a/docs/examples/config.rst
+++ b/docs/examples/config.rst
@@ -374,6 +374,37 @@ Reference model will be enabled when ``actor.use_kl_loss`` or/and ``algorithm.us
 
 .. note:: **NOTED**: In this config field, users only need to select from ``dummy_megatron``, ``dummy_dtensor``, ``dummy_hf`` for rollout initialization and our hybrid engine will select the corresponding weight loader (i.e., ``megatron``, ``dtensor``, ``hf``) during actor/rollout weight synchronization.
 
+
+Megatron Optimizer and Optimizer Parameter Scheduler
+____________________________________________________
+
+.. code:: yaml
+
+    optim:
+      optimizer: adam
+      lr: 1e-6
+      clip_grad: 1.0
+      total_training_steps: -1 # must be override by program
+      lr_warmup_init: 0.0 # initial learning rate for warmup, default to 0.0
+      lr_warmup_steps: -1 # Prioritized. Negative values mean delegating to lr_warmup_steps_ratio.
+      lr_warmup_steps_ratio: 0. # the total steps will be injected during runtime
+      lr_decay_steps: null
+      lr_decay_style: linear # select from constant/linear/cosine/inverse_square_root
+      min_lr: 0.0 # minimum learning rate, default to 0.0
+      weight_decay: 0.01
+      weight_decay_incr_style: constant # select from constant/linear/cosine
+      lr_wsd_decay_style: exponential # select from constant/exponential/cosine
+      lr_wsd_decay_steps: null
+      use_checkpoint_opt_param_scheduler: False # use checkpoint optimizer parameter scheduler
+
+
+Notice that there are some differences between the Megatron optimizer API and the FSDP optimizer API:
+
+- The Megatron optimizer scheduler names the period after warmup ``lr_decay_steps``, so the former ``warmup_style`` option now corresponds to ``lr_decay_style``, i.e. the style of lr decay applied after warmup (see the sketch after this list).
+- The Megatron optimizer also supports scheduling the weight decay (see ``weight_decay_incr_style``).
+- ``use_checkpoint_opt_param_scheduler`` determines whether to use the checkpointed optimizer parameter scheduler: if set to True, the scheduler state is saved with the checkpoint and restored from it when resuming training.
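+
+For example, here is a minimal sketch of how the warmup and decay windows are resolved from this
+config (mirroring ``get_megatron_optimizer_param_scheduler`` in ``verl/utils/megatron/optimizer.py``;
+the concrete numbers are illustrative only, not defaults):
+
+.. code:: python
+
+    total_training_steps = 100   # injected by the trainer at runtime
+    lr_decay_steps = None        # optim.lr_decay_steps
+    lr_warmup_steps = None       # optim.lr_warmup_steps
+    lr_warmup_steps_ratio = 0.1  # optim.lr_warmup_steps_ratio
+
+    # The decay window defaults to the whole training run.
+    if lr_decay_steps is None:
+        lr_decay_steps = total_training_steps
+
+    # Warmup steps fall back to the ratio when unset or non-positive.
+    if lr_warmup_steps is None or lr_warmup_steps <= 0:
+        lr_warmup_steps = int(lr_warmup_steps_ratio * lr_decay_steps)  # -> 10 warmup steps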
+
+
 Critic Model
 ~~~~~~~~~~~~
diff --git a/tests/e2e/run_ppo_trainer_megatron.sh b/tests/e2e/run_ppo_trainer_megatron.sh
index 91a3d5ac8ad..d85f68623c9 100644
--- a/tests/e2e/run_ppo_trainer_megatron.sh
+++ b/tests/e2e/run_ppo_trainer_megatron.sh
@@ -101,6 +101,8 @@ CRITIC_GRAD_OFFLOAD=${CRITIC_GRAD_OFFLOAD:-$COMMON_GRAD_OFFLOAD}
 CRITIC_OPTIMIZER_OFFLOAD=${CRITIC_OPTIMIZER_OFFLOAD:-$COMMON_OPTIMIZER_OFFLOAD}
 RM_PARAM_OFFLOAD=${RM_PARAM_OFFLOAD:-$COMMON_PARAM_OFFLOAD}
 
+LR_WARMUP_STEPS=${LR_WARMUP_STEPS:-null}
+
 CHECKPOINT_CONTENTS=['model','hf_model','optimizer','extra']
 SKIP_SAVE_HF_MODEL=${SKIP_SAVE_HF_MODEL:-0}
 if [ $SKIP_SAVE_HF_MODEL -eq 1 ]; then
@@ -134,7 +136,7 @@ for ENGINE in "${ENGINES[@]}"; do
         data.filter_overlong_prompts=True \
         data.truncation='error' \
         actor_rollout_ref.model.path="${MODEL_PATH}" \
-        actor_rollout_ref.actor.optim.lr=1e-6 \
+        actor_rollout_ref.actor.optim.lr_warmup_steps=$LR_WARMUP_STEPS \
         actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \
         actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=${train_traj_micro_bsz_per_gpu} \
         actor_rollout_ref.actor.use_dynamic_bsz=${USE_DYNAMIC_BSZ} \
@@ -170,6 +172,7 @@ for ENGINE in "${ENGINES[@]}"; do
         actor_rollout_ref.ref.megatron.use_dist_checkpointing=${USE_DIST_CKPT} \
         actor_rollout_ref.ref.megatron.dist_checkpointing_path=${DIST_CKPT_PATH} \
         critic.optim.lr=2e-5 \
+        critic.optim.lr_warmup_steps=$LR_WARMUP_STEPS \
         critic.model.path="${MODEL_PATH}" \
         critic.model.enable_gradient_checkpointing=False \
         critic.ppo_micro_batch_size_per_gpu=${train_traj_micro_bsz_per_gpu} \
diff --git a/verl/trainer/config/ppo_megatron_trainer.yaml b/verl/trainer/config/ppo_megatron_trainer.yaml
index e827c68a6d0..c79f7bb0401 100644
--- a/verl/trainer/config/ppo_megatron_trainer.yaml
+++ b/verl/trainer/config/ppo_megatron_trainer.yaml
@@ -62,14 +62,21 @@ actor_rollout_ref:
     data_loader_seed: null
     shuffle: False
     optim:
+      optimizer: adam
       lr: 1e-6
       clip_grad: 1.0
-      lr_warmup_steps: -1 # Prioritized. Negative values mean delegating to lr_warmup_steps_ratio.
-      lr_warmup_steps_ratio: 0. # the total steps will be injected during runtime
-      min_lr_ratio: null # only useful for warmup with cosine
-      warmup_style: constant # select from constant/cosine
       total_training_steps: -1 # must be override by program
+      lr_warmup_init: 0.0 # initial learning rate for warmup, default to 0.0
+      lr_warmup_steps: null # Prioritized. None, 0 or Negative values mean delegating to lr_warmup_steps_ratio.
+      lr_warmup_steps_ratio: 0. # the total steps will be injected during runtime
+      lr_decay_steps: null
+      lr_decay_style: linear # select from constant/linear/cosine/inverse_square_root
+      min_lr: 0.0 # minimum learning rate, default to 0.0
       weight_decay: 0.01
+      weight_decay_incr_style: constant # select from constant/linear/cosine
+      lr_wsd_decay_style: exponential # select from constant/exponential/cosine
+      lr_wsd_decay_steps: null
+      use_checkpoint_opt_param_scheduler: False # use checkpoint optimizer parameter scheduler
     megatron:
       param_offload: False
       grad_offload: False
@@ -177,13 +184,21 @@ critic:
   rollout_n: ${actor_rollout_ref.rollout.n}
   strategy: megatron
   optim:
-    lr: 1e-5
+    optimizer: adam
+    lr: 1e-6
     clip_grad: 1.0
-    lr_warmup_steps_ratio: 0. # the total steps will be injected during runtime
-    min_lr_ratio: null # only useful for warmup with cosine
-    warmup_style: constant # select from constant/cosine
     total_training_steps: -1 # must be override by program
+    lr_warmup_init: 0.0 # initial learning rate for warmup, default to 0.0
+    lr_warmup_steps: null # Prioritized. None, 0 or Negative values mean delegating to lr_warmup_steps_ratio.
+    lr_warmup_steps_ratio: 0. # the total steps will be injected during runtime
+    lr_decay_steps: null
+    lr_decay_style: linear # select from constant/linear/cosine/inverse_square_root
+    min_lr: 0.0 # minimum learning rate, default to 0.0
     weight_decay: 0.01
+    weight_decay_incr_style: constant # select from constant/linear/cosine
+    lr_wsd_decay_style: exponential # select from constant/exponential/cosine
+    lr_wsd_decay_steps: null
+    use_checkpoint_opt_param_scheduler: False # use checkpoint optimizer parameter scheduler
   model:
     path: ~/models/deepseek-llm-7b-chat
     tokenizer_path: ${actor_rollout_ref.model.path}
diff --git a/verl/utils/checkpoint/megatron_checkpoint_manager.py b/verl/utils/checkpoint/megatron_checkpoint_manager.py
index d24ed91abac..cd6b0f8cd31 100644
--- a/verl/utils/checkpoint/megatron_checkpoint_manager.py
+++ b/verl/utils/checkpoint/megatron_checkpoint_manager.py
@@ -30,6 +30,7 @@
     get_hf_model_checkpoint_path,
     get_model_checkpoint_path,
     get_optimizer_checkpoint_path,
+    get_optimizer_scheduler_checkpoint_path,
     get_rng_states_checkpoint_path,
 )
 
@@ -63,7 +64,9 @@ def __init__(
         share_embeddings_and_output_weights: bool,
         tokenizer,
         optimizer,
+        optimizer_scheduler,
         use_distributed_optimizer: bool,
+        use_checkpoint_opt_param_scheduler: bool = False,
         checkpoint_contents: Optional[list] = None,
         **kwargs,
     ):
@@ -72,7 +75,7 @@ def __init__(
         super().__init__(
             model,
             optimizer=optimizer,
-            lr_scheduler=None,
+            lr_scheduler=optimizer_scheduler,
             processing_class=tokenizer,
             checkpoint_contents=checkpoint_contents,
         )
@@ -88,6 +91,7 @@ def __init__(
         self.share_embeddings_and_output_weights = share_embeddings_and_output_weights
         self.model_path = self.config.model.path
         self.use_distributed_optimizer = use_distributed_optimizer
+        self.use_checkpoint_opt_param_scheduler = use_checkpoint_opt_param_scheduler
 
         self.rank = torch.distributed.get_rank()
 
@@ -213,6 +217,15 @@ def load_checkpoint(self, local_path: str, hdfs_path: str = None, del_local_afte
         if "extra" in self.checkpoint_contents:
             self.load_rng_states(local_path)
 
+        if self.use_checkpoint_opt_param_scheduler:
+            optimizer_scheduler_path = get_optimizer_scheduler_checkpoint_path(local_path, only_rank0_save=False)
+            if os.path.exists(optimizer_scheduler_path):
+                print(f"Loading optimizer scheduler from {optimizer_scheduler_path}")
+                state_dict = torch.load(optimizer_scheduler_path, weights_only=False)
+                if self.lr_scheduler is not None:
+                    self.lr_scheduler.load_state_dict(state_dict)
+            else:
+                print(f"Optimizer scheduler path {optimizer_scheduler_path} does not exist, skipping loading.")
 
         if del_local_after_load:
             try:
@@ -324,4 +337,11 @@ def save_checkpoint(self, local_path: str, hdfs_path: str = None, global_step: i
                 torch.save(rng_state, rng_state_path)
                 print(f"Rank {self.rank} saving rng states to {rng_state_path}")
 
+        optimizer_scheduler_path = get_optimizer_scheduler_checkpoint_path(local_path, only_rank0_save=False)
+        if self.lr_scheduler is not None:
+            state_dict = self.lr_scheduler.state_dict()
+            torch.save(state_dict, optimizer_scheduler_path)
+            if self.rank == 0:
+                print(f"Rank {self.rank} saving optimizer scheduler state to {optimizer_scheduler_path}")
+
         self.previous_saved_paths.append(local_path)
diff --git a/verl/utils/megatron/optimizer.py b/verl/utils/megatron/optimizer.py
index 30ebf6cc956..fd7eb357105 100644
--- a/verl/utils/megatron/optimizer.py
+++ b/verl/utils/megatron/optimizer.py
@@ -15,6 +15,7 @@
 
 from megatron.core.optimizer import OptimizerConfig
 from megatron.core.optimizer import get_megatron_optimizer as get_megatron_optimizer_native
+from megatron.core.optimizer_param_scheduler import OptimizerParamScheduler
 
 
 def get_megatron_optimizer(
@@ -34,4 +35,44 @@ def get_megatron_optimizer(
     )
 
 
-# TODO: add get_optimizer_param_scheduler(optimizer) to implement lr scheuler.
+def get_megatron_optimizer_param_scheduler(
+    optimizer,
+    config,
+):
+    """
+    Get the optimizer parameter scheduler for Megatron.
+    """
+    if config.get("lr_decay_steps", None) is None:
+        config.lr_decay_steps = config.total_training_steps
+    wsd_decay_steps = None
+    if config.get("lr_wsd_decay_steps", None) is not None:
+        wsd_decay_steps = config.lr_wsd_decay_steps
+    if config.get("lr_warmup_steps_ratio", None) is not None and (config.get("lr_warmup_steps", None) is None or config.lr_warmup_steps <= 0):
+        config.lr_warmup_steps = int(config.lr_warmup_steps_ratio * config.lr_decay_steps)
+
+    opt_param_scheduler = OptimizerParamScheduler(
+        optimizer,
+        init_lr=config.lr_warmup_init,
+        max_lr=config.lr,
+        min_lr=config.min_lr,
+        lr_warmup_steps=config.lr_warmup_steps,
+        lr_decay_steps=config.lr_decay_steps,
+        lr_decay_style=config.lr_decay_style,
+        start_wd=config.weight_decay,
+        end_wd=config.weight_decay,
+        wd_incr_steps=config.total_training_steps,
+        wd_incr_style=config.weight_decay_incr_style,
+        use_checkpoint_opt_param_scheduler=config.use_checkpoint_opt_param_scheduler,
+        override_opt_param_scheduler=(not config.use_checkpoint_opt_param_scheduler),
+        wsd_decay_steps=wsd_decay_steps,
+        lr_wsd_decay_style=config.lr_wsd_decay_style,
+    )
+
+    return opt_param_scheduler
+
+
+def get_megatron_last_lr(optimizer):
+    """
+    Get the last learning rate from the optimizer parameter scheduler.
+    """
+    return optimizer.param_groups[0]["lr"]
diff --git a/verl/utils/megatron_utils.py b/verl/utils/megatron_utils.py
index 17f3c8211fc..ad3dfe5100c 100644
--- a/verl/utils/megatron_utils.py
+++ b/verl/utils/megatron_utils.py
@@ -200,10 +200,11 @@ def convert_config(hf_config: PretrainedConfig, megatron_config) -> TransformerConfig:
     config = OptimizerConfig(
-        optimizer="adam",
+        optimizer=optim_config.get("optimizer", "adam"),
         lr=optim_config.get("lr"),
-        clip_grad=optim_config.get("clip_grad"),
-        weight_decay=optim_config.get("weight_decay"),
+        min_lr=optim_config.get("min_lr", None),
+        clip_grad=optim_config.get("clip_grad", 1.0),
+        weight_decay=optim_config.get("weight_decay", 0.01),
         bf16=True,
         params_dtype=torch.bfloat16,
         use_distributed_optimizer=True,
@@ -455,7 +456,7 @@ def get_optimizer_checkpoint_path(checkpoint_path, use_distributed_optimizer=Tru
     return os.path.join(checkpoint_path, "optim", f"distrib_optim_pp{pp_rank}_tp{tp_rank}_cp{cp_rank}_dp{dp_rank}.pt")
 
 
-def get_rng_states_checkpoint_path(checkpoint_path, only_rank0_save=True):
+def get_rng_states_checkpoint_path(checkpoint_path, only_rank0_save=False):
     # save rng states cause interrupts
     os.makedirs(os.path.join(checkpoint_path, "rng_states"), exist_ok=True)
     if only_rank0_save:
@@ -467,6 +468,18 @@ def get_rng_states_checkpoint_path(checkpoint_path, only_rank0_save=True):
     return os.path.join(checkpoint_path, "rng_states", f"rng_states_pp{pp_rank}_tp{tp_rank}_cp{cp_rank}_dp{dp_rank}.pt")
 
 
+def get_optimizer_scheduler_checkpoint_path(checkpoint_path, only_rank0_save=False):
+    # save rng states cause interrupts
+    os.makedirs(os.path.join(checkpoint_path, "optimizer_scheduler"), exist_ok=True)
+    if only_rank0_save:
+        return os.path.join(checkpoint_path, "optimizer_scheduler", "optimizer_scheduler.pt")
+    dp_rank = mpu.get_data_parallel_rank()
+    pp_rank = mpu.get_pipeline_model_parallel_rank()
+    tp_rank = mpu.get_tensor_model_parallel_rank()
+    cp_rank = mpu.get_context_parallel_rank()
+    return os.path.join(checkpoint_path, "optimizer_scheduler", f"optimizer_scheduler_pp{pp_rank}_tp{tp_rank}_cp{cp_rank}_dp{dp_rank}.pt")
+
+
 def convert_megatron_model_to_transformers_model(
     name,
     param,
diff --git a/verl/workers/actor/megatron_actor.py b/verl/workers/actor/megatron_actor.py
index f0338006b47..5027d594fd3 100644
--- a/verl/workers/actor/megatron_actor.py
+++ b/verl/workers/actor/megatron_actor.py
@@ -509,8 +509,7 @@ def update_policy(self, dataloader: Iterable[DataProto]) -> Dict:
                 append_to_dict(metrics, metric[0])  # append the metric from this micro-batch to global metrics.
 
         update_successful, grad_norm, num_zeros_in_grad = self.actor_optimizer.step()
-        learning_rate = self.actor_optimizer.param_groups[-1]["lr"]
-        data = {"actor/grad_norm": grad_norm, "actor/lr": learning_rate}
+        data = {"actor/grad_norm": grad_norm}
         append_to_dict(metrics, data)
 
         if update_successful:
diff --git a/verl/workers/megatron_workers.py b/verl/workers/megatron_workers.py
index 211beee7c02..b07a3b13247 100644
--- a/verl/workers/megatron_workers.py
+++ b/verl/workers/megatron_workers.py
@@ -142,7 +142,7 @@ def __init__(self, config: DictConfig, role: str):
     def _build_model_optimizer(self, model_path, optim_config, override_model_config, override_transformer_config):
         from megatron.core.models.gpt.gpt_model import ModelType
 
-        from verl.utils.megatron.optimizer import get_megatron_optimizer
+        from verl.utils.megatron.optimizer import get_megatron_optimizer, get_megatron_optimizer_param_scheduler
         from verl.utils.megatron_utils import get_model, init_megatron_optim_config
         from verl.utils.model import get_generation_config, print_model_size
 
@@ -195,15 +195,17 @@ def megatron_actor_model_provider(pre_process, post_process):
 
         # TODO: add more optimizer args into config
         if self._is_actor:
-            optim_config = init_megatron_optim_config(optim_config)
-            actor_optimizer = get_megatron_optimizer(model=actor_module, config=optim_config)
+            optim_config_megatron = init_megatron_optim_config(optim_config)
+            actor_optimizer = get_megatron_optimizer(model=actor_module, config=optim_config_megatron)
+            actor_optimizer_scheduler = get_megatron_optimizer_param_scheduler(optimizer=actor_optimizer, config=optim_config)
         else:
             optim_config = None
             actor_optimizer = None
+            actor_optimizer_scheduler = None
 
         log_gpu_memory_usage("After actor optimizer init", logger=logger)
 
-        return actor_module, actor_optimizer, self.hf_config, optim_config
+        return actor_module, actor_optimizer, actor_optimizer_scheduler, self.hf_config, optim_config
 
     def _build_rollout(self, trust_remote_code=False):
         from torch.distributed.device_mesh import init_device_mesh
@@ -336,7 +338,7 @@ def init_model(self):
         if self._is_actor or self._is_rollout:
             # we need the model for actor and rollout
             optim_config = self.config.actor.optim if self._is_actor else None
-            self.actor_module, self.actor_optimizer, self.actor_model_config, self.actor_optim_config = self._build_model_optimizer(
+            self.actor_module, self.actor_optimizer, self.actor_optimizer_scheduler, self.actor_model_config, self.actor_optim_config = self._build_model_optimizer(
                 model_path=self.config.model.path,
                 optim_config=optim_config,
                 override_model_config=override_model_config,
@@ -397,7 +399,9 @@ def init_model(self):
                 share_embeddings_and_output_weights=self.share_embeddings_and_output_weights,
                 tokenizer=self.tokenizer,
                 optimizer=self.actor_optimizer,
+                optimizer_scheduler=self.actor_optimizer_scheduler,
                 use_distributed_optimizer=self.config.actor.megatron.use_distributed_optimizer,
+                use_checkpoint_opt_param_scheduler=self.config.actor.optim.use_checkpoint_opt_param_scheduler,
                 checkpoint_contents=self.config.actor.checkpoint.contents,
             )
             torch.cuda.empty_cache()
@@ -424,6 +428,10 @@ def update_actor(self, data: DataProto):
             global_num_tokens = data.meta_info["global_token_num"]
             estimated_flops, promised_flops = self.flops_counter.estimate_flops(global_num_tokens, delta_time)
             metrics["perf/mfu/actor"] = estimated_flops * self.config.actor.ppo_epochs / promised_flops / self.world_size
+            from verl.utils.megatron.optimizer import get_megatron_last_lr
+
+            metrics["actor/lr"] = get_megatron_last_lr(self.actor_optimizer)
+            self.actor_optimizer_scheduler.step(1)
 
         # TODO: here, we should return all metrics
         output = DataProto(meta_info={"metrics": metrics})
@@ -595,7 +603,7 @@ def __init__(self, config):
     def _build_critic_model_optimizer(self, model_path, optim_config, override_model_config, override_transformer_config):
         from megatron.core.models.gpt.gpt_model import ModelType
 
-        from verl.utils.megatron.optimizer import get_megatron_optimizer
+        from verl.utils.megatron.optimizer import get_megatron_optimizer, get_megatron_optimizer_param_scheduler
         from verl.utils.megatron_utils import get_model, init_megatron_optim_config
         from verl.utils.model import print_model_size
 
@@ -632,10 +640,11 @@ def megatron_critic_model_provider(pre_process, post_process):
             print_model_size(critic_module[0])
 
         # TODO: add more optimizer args into config
-        optim_config = init_megatron_optim_config(optim_config)
-        critic_optimizer = get_megatron_optimizer(model=critic_module, config=optim_config)
+        optim_config_megatron = init_megatron_optim_config(optim_config)
+        critic_optimizer = get_megatron_optimizer(model=critic_module, config=optim_config_megatron)
+        critic_optimizer_scheduler = get_megatron_optimizer_param_scheduler(optimizer=critic_optimizer, config=optim_config)
         torch.cuda.empty_cache()
-        return critic_module, critic_optimizer, self.hf_config, optim_config
+        return critic_module, critic_optimizer, critic_optimizer_scheduler, self.hf_config, optim_config
 
     @register(dispatch_mode=Dispatch.ONE_TO_ALL)
     def init_model(self):
@@ -653,7 +662,7 @@ def init_model(self):
         override_transformer_config = OmegaConf.to_container(self.config.megatron.get("override_transformer_config", OmegaConf.create()), resolve=True)
         self.param_dtype = torch.bfloat16
         self.dtype = PrecisionType.to_dtype(self.param_dtype)
-        self.critic_module, self.critic_optimizer, self.critic_model_config, critic_optimizer_config = self._build_critic_model_optimizer(
+        self.critic_module, self.critic_optimizer, self.critic_optimizer_scheduler, self.critic_model_config, critic_optimizer_config = self._build_critic_model_optimizer(
            model_path=self.config.model.path,
            optim_config=self.config.optim,
            override_model_config=override_model_config,
@@ -685,7 +694,9 @@ def init_model(self):
             share_embeddings_and_output_weights=False,
             tokenizer=self.tokenizer,
             optimizer=self.critic_optimizer,
+            optimizer_scheduler=self.critic_optimizer_scheduler,
             use_distributed_optimizer=self.config.megatron.use_distributed_optimizer,
+            use_checkpoint_opt_param_scheduler=self.config.optim.use_checkpoint_opt_param_scheduler,
             checkpoint_contents=self.config.checkpoint.contents,
         )
 
@@ -721,6 +732,11 @@ def update_critic(self, data: DataProto):
             global_num_tokens = data.meta_info["global_token_num"]
             estimated_flops, promised_flops = self.flops_counter.estimate_flops(global_num_tokens, delta_time)
             metrics["perf/mfu/critic"] = estimated_flops * self.config.ppo_epochs / promised_flops / self.world_size
+            from verl.utils.megatron.optimizer import get_megatron_last_lr
+
+            metrics["critic/lr"] = get_megatron_last_lr(self.critic_optimizer)
+            self.critic_optimizer_scheduler.step(1)
+
         output = DataProto(batch=None, meta_info={"metrics": metrics})
 
         if self._is_offload_param:
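
Taken together, the changes above construct a Megatron ``OptimizerParamScheduler`` for both the actor and the critic, log the current learning rate, advance the scheduler once per update, and optionally carry its state through checkpoints. Below is a rough usage sketch of the new helpers (hypothetical driver code, not part of the patch; ``optimizer`` and ``optim_config`` are assumed to come from the worker setup shown above):

.. code:: python

    import torch

    from verl.utils.megatron.optimizer import (
        get_megatron_last_lr,
        get_megatron_optimizer_param_scheduler,
    )


    def run_updates(optimizer, optim_config, run_one_ppo_update, scheduler_ckpt_path):
        # Build the Megatron OptimizerParamScheduler from the verl optim config.
        scheduler = get_megatron_optimizer_param_scheduler(optimizer=optimizer, config=optim_config)

        for _ in range(optim_config.total_training_steps):
            run_one_ppo_update()                  # optimizer.step() happens inside the PPO update
            lr = get_megatron_last_lr(optimizer)  # logged by the workers as actor/lr or critic/lr
            print(f"current lr: {lr}")
            scheduler.step(1)                     # advance warmup/decay by one training step

        # With use_checkpoint_opt_param_scheduler=True, the scheduler state is saved next to the
        # other checkpoint files and restored on resume (see megatron_checkpoint_manager.py above).
        torch.save(scheduler.state_dict(), scheduler_ckpt_path)
        scheduler.load_state_dict(torch.load(scheduler_ckpt_path, weights_only=False))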