12 changes: 9 additions & 3 deletions docs/examples/config.rst
@@ -137,7 +137,11 @@ Actor/Rollout/Reference Policy
optimizer_offload: False
fsdp_size: -1
checkpoint:
contents: ['model', 'optimizer', 'extra']
# What to include in saved checkpoints
# with 'hf_model' you can also save the whole model in hf format; by default only the sharded model checkpoint is saved to save space
save_contents: ['model', 'optimizer', 'extra']
# For more flexibility, you can specify the contents to load from the checkpoint.
load_contents: ${actor_rollout_ref.actor.checkpoint.save_contents}
ref:
fsdp_config:
param_offload: False
@@ -267,9 +271,11 @@ Actor/Rollout/Reference Policy

- ``actor_rollout_ref.actor.checkpoint``: The configurations of checkpoint function in actor

- ``contents``: The contents to save in the checkpoint. By default, we save model, optimizer and extra information in the checkpoint.
- ``save_contents``: The contents to save in the checkpoint. By default, we save the model, optimizer, and extra information in the checkpoint.
The extra information currently includes RNG states and the lr_scheduler (the latter supported with FSDP); Megatron opt_param_scheduler support is coming soon.
We do not store hf_model in checkpoint by default, but we provide a tool in `scripts/model_merge.py` to convert checkpoint format to hf format.
We do not store the hf_model in the checkpoint by default, but we provide a tool in ``scripts/model_merge.py`` to convert the saved checkpoint to hf format.

- ``load_contents``: The contents to load from the checkpoint; you can specify loading contents that differ from what was saved. By default, it is the same as ``save_contents`` (see the sketch below).
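The two fields interact through standard OmegaConf interpolation: ``load_contents`` simply resolves to whatever ``save_contents`` holds unless it is overridden. A minimal, self-contained sketch (illustrative only; the nested keys mirror the config above, and the snippet is not part of the verl API):

from omegaconf import OmegaConf

# Hypothetical stand-alone config mirroring the checkpoint section above.
cfg = OmegaConf.create({
    "actor_rollout_ref": {
        "actor": {
            "checkpoint": {
                "save_contents": ["model", "optimizer", "extra"],
                # load_contents falls back to save_contents via interpolation
                "load_contents": "${actor_rollout_ref.actor.checkpoint.save_contents}",
            }
        }
    }
})

ckpt = cfg.actor_rollout_ref.actor.checkpoint
print(list(ckpt.load_contents))  # ['model', 'optimizer', 'extra']

# Overriding save_contents also changes the effective load_contents,
# while an explicit load_contents override decouples the two.
ckpt.save_contents = ["model", "hf_model"]
print(list(ckpt.load_contents))  # ['model', 'hf_model']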

**Reference Model**

254 changes: 111 additions & 143 deletions recipe/spin/fsdp_workers.py

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion recipe/sppo/sppo_worker.py
@@ -112,5 +112,5 @@ def init_model(self):
optimizer=self.actor.actor_optimizer,
lr_scheduler=self.actor_lr_scheduler,
processing_class=self.processor if self.processor is not None else self.tokenizer,
checkpoint_contents=self.config.actor.checkpoint.contents,
checkpoint_contents=self.config.actor.checkpoint,
)
2 changes: 1 addition & 1 deletion tests/e2e/ppo_trainer/run_function_reward.sh
@@ -97,7 +97,7 @@ python3 -m verl.trainer.main_ppo \
actor_rollout_ref.actor.fsdp_config.optimizer_offload=${ACTOR_FSDP_OPTIMIZER_OFFLOAD} \
actor_rollout_ref.actor.fsdp_config.fsdp_size=${FSDP_SIZE} \
actor_rollout_ref.actor.ulysses_sequence_parallel_size="${SP_SIZE}" \
actor_rollout_ref.actor.checkpoint.contents=${CHECKPOINT_CONTENTS} \
actor_rollout_ref.actor.checkpoint.save_contents=${CHECKPOINT_CONTENTS} \
actor_rollout_ref.actor.use_kl_loss="${USE_KL}" \
actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=${train_traj_micro_bsz_per_gpu} \
actor_rollout_ref.rollout.tensor_model_parallel_size=2 \
6 changes: 4 additions & 2 deletions tests/e2e/run_ppo_trainer_megatron.sh
@@ -2,6 +2,8 @@
set -xeuo pipefail

export CUDA_DEVICE_MAX_CONNECTIONS=1 # For megatron communication/computation overlapping
export VERL_LOGGING_LEVEL=INFO
export VERL_PPO_LOGGING_LEVEL=INFO

NUM_GPUS=${NUM_GPUS:-8}

@@ -155,7 +157,7 @@ for ENGINE in "${ENGINES[@]}"; do
actor_rollout_ref.actor.use_kl_loss=True \
actor_rollout_ref.actor.kl_loss_coef=0.001 \
actor_rollout_ref.actor.kl_loss_type=low_var_kl \
actor_rollout_ref.actor.checkpoint.contents=$CHECKPOINT_CONTENTS \
actor_rollout_ref.actor.checkpoint.save_contents=$CHECKPOINT_CONTENTS \
actor_rollout_ref.rollout.name="${ENGINE}" \
actor_rollout_ref.rollout.tensor_model_parallel_size=$ROLLOUT_TP \
actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \
@@ -188,7 +190,7 @@ for ENGINE in "${ENGINES[@]}"; do
critic.megatron.grad_offload=${CRITIC_GRAD_OFFLOAD} \
critic.megatron.use_dist_checkpointing=${USE_DIST_CKPT} \
critic.megatron.dist_checkpointing_path=${DIST_CKPT_PATH} \
critic.checkpoint.contents=$CHECKPOINT_CONTENTS \
critic.checkpoint.save_contents=$CHECKPOINT_CONTENTS \
reward_model.enable=True \
reward_model.model.path="${MODEL_PATH}" \
reward_model.micro_batch_size_per_gpu=${train_traj_micro_bsz_per_gpu} \
3 changes: 2 additions & 1 deletion verl/models/llama/megatron/checkpoint_utils/llama_loader.py
@@ -55,7 +55,8 @@ def load_state_dict_to_megatron_llama(state_dict, wrapped_models, config, params
from megatron.core.transformer.module import Float16Module
from torch.nn.parallel import DistributedDataParallel as torchDDP

from verl.utils.megatron_utils import print_rank_0, unwrap_model
from verl.utils.logger import print_rank_0
from verl.utils.megatron_utils import unwrap_model

start_time = time.time()

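This import change — ``print_rank_0`` now comes from ``verl.utils.logger`` rather than ``verl.utils.megatron_utils`` — is repeated across the Megatron loader/saver files below. For readers unfamiliar with the helper, a typical rank-0-only print looks like the sketch below (an assumption based on the name; the actual implementation in ``verl/utils/logger.py`` may differ):

import torch.distributed as dist

def print_rank_0(message):
    # Print only once, from the global rank-0 process, to avoid duplicated logs
    # across data/tensor/pipeline parallel ranks.
    if dist.is_available() and dist.is_initialized():
        if dist.get_rank() == 0:
            print(message, flush=True)
    else:
        print(message, flush=True)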
Original file line number Diff line number Diff line change
@@ -55,7 +55,8 @@ def load_state_dict_to_megatron_llama(state_dict, wrapped_models, config, params
from megatron.core.transformer.module import Float16Module
from torch.nn.parallel import DistributedDataParallel as torchDDP

from verl.utils.megatron_utils import print_rank_0, unwrap_model
from verl.utils.logger import print_rank_0
from verl.utils.megatron_utils import unwrap_model

start_time = time.time()

3 changes: 2 additions & 1 deletion verl/models/llama/megatron/checkpoint_utils/llama_saver.py
@@ -21,7 +21,8 @@
from megatron.core.transformer.module import Float16Module
from torch.nn.parallel import DistributedDataParallel as torchDDP

from verl.utils.megatron_utils import print_rank_0, unwrap_model
from verl.utils.logger import print_rank_0
from verl.utils.megatron_utils import unwrap_model


def _megatron_calc_global_rank(tp_rank: int = 0, dp_rank: int = 0, pp_rank: int = 0):
5 changes: 3 additions & 2 deletions verl/models/mcore/loader.py
@@ -56,7 +56,8 @@ def load_state_dict_to_megatron_gptmodel(state_dict, wrapped_models, config, par
from megatron.core.transformer.module import Float16Module
from torch.nn.parallel import DistributedDataParallel as torchDDP

from verl.utils.megatron_utils import print_rank_0, unwrap_model
from verl.utils.logger import print_rank_0
from verl.utils.megatron_utils import unwrap_model

start_time = time.time()

@@ -382,7 +383,7 @@ def _broadcast_tp_shard_tensor_qkv(tensor, q_name, k_name, v_name, bias=False) -
sync_layer.self_attention.linear_qkv.layer_norm_weight if dst_pp_rank == pp_rank else None,
f"{layer_name}.input_layernorm.weight",
)

if f"{layer_name}.self_attn.q_norm.weight" in state_dict:
_broadcast_tensor(
sync_layer.self_attention.q_layernorm.weight if dst_pp_rank == pp_rank else None,
3 changes: 2 additions & 1 deletion verl/models/mcore/saver.py
@@ -22,7 +22,8 @@
from megatron.core.transformer.module import Float16Module
from torch.nn.parallel import DistributedDataParallel as torchDDP

from verl.utils.megatron_utils import print_rank_0, unwrap_model
from verl.utils.logger import print_rank_0
from verl.utils.megatron_utils import unwrap_model


def _megatron_calc_global_rank(tp_rank: int = 0, dp_rank: int = 0, pp_rank: int = 0, cp_rank: int = 0, ep_rank: int = 0):
Expand Down
3 changes: 2 additions & 1 deletion verl/models/qwen2/megatron/checkpoint_utils/qwen2_loader.py
@@ -53,7 +53,8 @@ def load_state_dict_to_megatron_qwen2(state_dict, wrapped_models, config, params
from megatron.core.transformer.module import Float16Module
from torch.nn.parallel import DistributedDataParallel as torchDDP

from verl.utils.megatron_utils import print_rank_0, unwrap_model
from verl.utils.logger import print_rank_0
from verl.utils.megatron_utils import unwrap_model

start_time = time.time()

Original file line number Diff line number Diff line change
@@ -53,7 +53,8 @@ def load_state_dict_to_megatron_qwen2(state_dict, wrapped_models, config, params
from megatron.core.transformer.module import Float16Module
from torch.nn.parallel import DistributedDataParallel as torchDDP

from verl.utils.megatron_utils import print_rank_0, unwrap_model
from verl.utils.logger import print_rank_0
from verl.utils.megatron_utils import unwrap_model

start_time = time.time()

3 changes: 2 additions & 1 deletion verl/models/qwen2/megatron/checkpoint_utils/qwen2_saver.py
@@ -21,7 +21,8 @@
from megatron.core.transformer.module import Float16Module
from torch.nn.parallel import DistributedDataParallel as torchDDP

from verl.utils.megatron_utils import print_rank_0, unwrap_model
from verl.utils.logger import print_rank_0
from verl.utils.megatron_utils import unwrap_model


def _megatron_calc_global_rank(tp_rank: int = 0, dp_rank: int = 0, pp_rank: int = 0):
11 changes: 9 additions & 2 deletions verl/trainer/config/ppo_megatron_trainer.yaml
@@ -101,7 +101,11 @@ actor_rollout_ref:
save_path: null # the path to save the profile result
load_weight: True
checkpoint:
contents: ['model', 'optimizer', 'extra'] # with 'hf_model' you can save whole model as hf format, now only use sharded model checkpoint to save space
# What to include in saved checkpoints
# with 'hf_model' you can also save the whole model in hf format; by default only the sharded model checkpoint is saved to save space
save_contents: ['model', 'optimizer', 'extra']
# For more flexibility, you can specify the contents to load from the checkpoint.
load_contents: ${actor_rollout_ref.actor.checkpoint.save_contents}
ref:
strategy: megatron
use_torch_compile: ${actor_rollout_ref.actor.use_torch_compile}
@@ -246,7 +250,10 @@ critic:
kl_coef: 0.001
loss_agg_mode: ${actor_rollout_ref.actor.loss_agg_mode}
checkpoint:
contents: ['model', 'optimizer', 'extra'] # with 'hf_model' you can save whole model as hf format, now only use sharded model checkpoint to save space
# What to include in saved checkpoints
# with 'hf_model' you can also save the whole model in hf format; by default only the sharded model checkpoint is saved to save space
save_contents: ['model', 'optimizer', 'extra']
load_contents: ${critic.checkpoint.save_contents}

reward_model:
enable: False
8 changes: 6 additions & 2 deletions verl/trainer/config/ppo_trainer.yaml
@@ -206,7 +206,10 @@ actor_rollout_ref:

# What to include in saved checkpoints
# with 'hf_model' you can also save the whole model in hf format; by default only the sharded model checkpoint is saved to save space
contents: ['model', 'optimizer', 'extra']
save_contents: ['model', 'optimizer', 'extra']

# For more flexibility, you can specify the contents to load from the checkpoint.
load_contents: ${actor_rollout_ref.actor.checkpoint.save_contents}

# optimizer configs
optim:
@@ -580,7 +583,8 @@ critic:

# What to include in saved checkpoints
# with 'hf_model' you can also save the whole model in hf format; by default only the sharded model checkpoint is saved to save space
contents: ['model', 'optimizer', 'extra']
save_contents: ['model', 'optimizer', 'extra']
load_contents: ${critic.checkpoint.save_contents}

# configs for the reward model
reward_model:
66 changes: 61 additions & 5 deletions verl/utils/checkpoint/checkpoint_manager.py
@@ -11,16 +11,18 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import random
import shutil
import tempfile
from typing import Optional, Union
from typing import Union

import numpy as np
import torch
import torch.distributed
from filelock import FileLock
from omegaconf import DictConfig
from transformers import PreTrainedTokenizer, ProcessorMixin

from verl.utils.device import is_cuda_available, is_npu_available
@@ -47,22 +49,76 @@ def __init__(
optimizer: torch.optim.Optimizer,
lr_scheduler: torch.optim.lr_scheduler.LRScheduler = None,
processing_class: Union[PreTrainedTokenizer, ProcessorMixin] = None,
checkpoint_contents: Optional[list] = None,
checkpoint_contents: DictConfig = None,
):
if checkpoint_contents is None:
checkpoint_contents = ["model", "optimizer", "extra"]
checkpoint_load_contents = checkpoint_contents.get("load_contents", None) if checkpoint_contents else None
checkpoint_save_contents = checkpoint_contents.get("save_contents", None) if checkpoint_contents else None
if checkpoint_load_contents is None:
checkpoint_load_contents = ["model", "optimizer", "extra"]
if checkpoint_save_contents is None:
checkpoint_save_contents = ["model", "optimizer", "extra"]
self.previous_global_step = None
self.previous_saved_paths = []

self.model = model
self.optimizer = optimizer
self.lr_scheduler = lr_scheduler
self.processing_class = processing_class
self.checkpoint_contents = checkpoint_contents
self.checkpoint_load_contents = checkpoint_load_contents
self.checkpoint_save_contents = checkpoint_save_contents

self.rank = torch.distributed.get_rank()
self.world_size = torch.distributed.get_world_size()

@property
def should_save_model(self) -> bool:
"""
Returns True if 'model' is in checkpoint_save_contents, indicating the model state should be saved.
"""
return "model" in self.checkpoint_save_contents

@property
def should_save_optimizer(self) -> bool:
"""
Returns True if 'optimizer' is in checkpoint_save_contents, indicating the optimizer state should be saved.
"""
return "optimizer" in self.checkpoint_save_contents

@property
def should_save_extra(self) -> bool:
"""
Returns True if 'extra' is in checkpoint_save_contents, indicating the extra state should be saved.
"""
return "extra" in self.checkpoint_save_contents

@property
def should_save_hf_model(self) -> bool:
"""
Returns True if 'hf_model' is in checkpoint_save_contents, indicating the model should be converted to hf model and saved.
"""
return "hf_model" in self.checkpoint_save_contents

@property
def should_load_model(self) -> bool:
"""
Returns True if 'model' is in checkpoint_load_contents, indicating the model state should be loaded.
"""
return "model" in self.checkpoint_load_contents

@property
def should_load_optimizer(self) -> bool:
"""
Returns True if 'optimizer' is in checkpoint_load_contents, indicating the optimizer state should be loaded.
"""
return "optimizer" in self.checkpoint_load_contents

@property
def should_load_extra(self) -> bool:
"""
Returns True if 'extra' is in checkpoint_load_contents, indicating the extra state should be loaded.
"""
return "extra" in self.checkpoint_load_contents

def load_checkpoint(self, local_path: str, hdfs_path: str = None, del_local_after_load: bool = False):
raise NotImplementedError

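For context on how the new constructor contract is used: callers now hand over the whole checkpoint sub-config (e.g. ``config.actor.checkpoint``, as in ``sppo_worker.py`` above) instead of a plain list, and concrete managers branch on the boolean properties. A small, self-contained sketch of the fallback logic and the flag semantics (``resolve_checkpoint_contents`` is a hypothetical helper written only for this example, not part of the diff):

from omegaconf import DictConfig, OmegaConf

def resolve_checkpoint_contents(checkpoint_contents: DictConfig = None):
    """Reproduce the fallback logic from __init__ above, for illustration only."""
    load = checkpoint_contents.get("load_contents", None) if checkpoint_contents else None
    save = checkpoint_contents.get("save_contents", None) if checkpoint_contents else None
    load = ["model", "optimizer", "extra"] if load is None else load
    save = ["model", "optimizer", "extra"] if save is None else save
    return load, save

# Save the sharded model plus an hf export, but reload only the model weights.
ckpt_cfg = OmegaConf.create({"save_contents": ["model", "hf_model"], "load_contents": ["model"]})
load, save = resolve_checkpoint_contents(ckpt_cfg)
print("optimizer" in load)   # False -> should_load_optimizer would be False
print("hf_model" in save)    # True  -> should_save_hf_model would be True

# Passing None keeps the old behavior: everything defaults to ['model', 'optimizer', 'extra'].
print(resolve_checkpoint_contents(None))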