12 changes: 9 additions & 3 deletions docs/examples/config.rst
@@ -137,7 +137,11 @@ Actor/Rollout/Reference Policy
optimizer_offload: False
fsdp_size: -1
checkpoint:
contents: ['model', 'optimizer', 'extra']
# What to include in saved checkpoints
# with 'hf_model' you can save the whole model in HF format; by default only the sharded model checkpoint is saved to save space
save_contents: ['model', 'optimizer', 'extra']
# For more flexibility, you can specify the contents to load from the checkpoint.
load_contents: ${actor_rollout_ref.actor.checkpoint.save_contents}
ref:
fsdp_config:
param_offload: False
@@ -267,9 +271,11 @@ Actor/Rollout/Reference Policy

- ``actor_rollout_ref.actor.checkpoint``: The configurations of checkpoint function in actor

- ``contents``: The contents to save in the checkpoint. By default, we save model, optimizer and extra information in the checkpoint.
- ``save_contents``: The contents to save in the checkpoint. By default, we save the model, optimizer, and extra information in the checkpoint.
The extra information currently includes RNG states and, for FSDP, the lr_scheduler; support for the Megatron opt_param_scheduler is coming soon.
We do not store hf_model in checkpoint by default, but we provide a tool in `scripts/model_merge.py` to convert checkpoint format to hf format.
We do not store the hf_model in the checkpoint by default, but we provide a tool in ``scripts/model_merge.py`` to convert a saved checkpoint to HF format.

- ``load_contents``: The contents to load from the checkpoint; you can specify contents different from those that were saved. By default, it is the same as ``save_contents``.
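A minimal, hypothetical sketch (not part of this change) of how the ``${...}`` interpolation above makes ``load_contents`` default to ``save_contents`` when the config is resolved with OmegaConf, which verl's Hydra-based configs build on:

    # Illustration only: resolve the checkpoint config and inspect the defaults.
    from omegaconf import OmegaConf

    cfg = OmegaConf.create({
        "actor_rollout_ref": {
            "actor": {
                "checkpoint": {
                    "save_contents": ["model", "optimizer", "extra"],
                    "load_contents": "${actor_rollout_ref.actor.checkpoint.save_contents}",
                }
            }
        }
    })

    ckpt = cfg.actor_rollout_ref.actor.checkpoint
    print(OmegaConf.to_container(ckpt, resolve=True))
    # {'save_contents': ['model', 'optimizer', 'extra'],
    #  'load_contents': ['model', 'optimizer', 'extra']}

    # Loading can be narrowed independently, e.g. to skip the optimizer state on resume:
    ckpt.load_contents = ["model", "extra"]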

**Reference Model**

6 changes: 4 additions & 2 deletions recipe/spin/fsdp_workers.py
@@ -134,7 +134,8 @@ def init_model(self):
optimizer=self.actor.actor_optimizer,
lr_scheduler=self.actor_lr_scheduler,
processing_class=self.processor if self.processor is not None else self.tokenizer,
checkpoint_contents=self.config.actor.checkpoint.contents)
checkpoint_load_contents=self.config.actor.checkpoint.load_contents,
checkpoint_save_contents=self.config.actor.checkpoint.save_contents)


if self._is_actor:
@@ -144,7 +145,8 @@ def init_model(self):
optimizer=self.actor.actor_optimizer,
lr_scheduler=self.actor_lr_scheduler,
processing_class=self.processor if self.processor is not None else self.tokenizer,
checkpoint_contents=self.config.actor.checkpoint.contents)
checkpoint_load_contents=self.config.actor.checkpoint.load_contents,
checkpoint_save_contents=self.config.actor.checkpoint.save_contents)

@register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
def compute_ref_log_prob(self, data: DataProto):
3 changes: 2 additions & 1 deletion recipe/sppo/sppo_worker.py
@@ -112,5 +112,6 @@ def init_model(self):
optimizer=self.actor.actor_optimizer,
lr_scheduler=self.actor_lr_scheduler,
processing_class=self.processor if self.processor is not None else self.tokenizer,
checkpoint_contents=self.config.actor.checkpoint.contents,
checkpoint_load_contents=self.config.actor.checkpoint.load_contents,
checkpoint_save_contents=self.config.actor.checkpoint.save_contents,
)
2 changes: 1 addition & 1 deletion tests/e2e/ppo_trainer/run_function_reward.sh
@@ -97,7 +97,7 @@ python3 -m verl.trainer.main_ppo \
actor_rollout_ref.actor.fsdp_config.optimizer_offload=${ACTOR_FSDP_OPTIMIZER_OFFLOAD} \
actor_rollout_ref.actor.fsdp_config.fsdp_size=${FSDP_SIZE} \
actor_rollout_ref.actor.ulysses_sequence_parallel_size="${SP_SIZE}" \
actor_rollout_ref.actor.checkpoint.contents=${CHECKPOINT_CONTENTS} \
actor_rollout_ref.actor.checkpoint.save_contents=${CHECKPOINT_CONTENTS} \
actor_rollout_ref.actor.use_kl_loss="${USE_KL}" \
actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=${train_traj_micro_bsz_per_gpu} \
actor_rollout_ref.rollout.tensor_model_parallel_size=2 \
4 changes: 2 additions & 2 deletions tests/e2e/run_ppo_trainer_megatron.sh
@@ -155,7 +155,7 @@ for ENGINE in "${ENGINES[@]}"; do
actor_rollout_ref.actor.use_kl_loss=True \
actor_rollout_ref.actor.kl_loss_coef=0.001 \
actor_rollout_ref.actor.kl_loss_type=low_var_kl \
actor_rollout_ref.actor.checkpoint.contents=$CHECKPOINT_CONTENTS \
actor_rollout_ref.actor.checkpoint.save_contents=$CHECKPOINT_CONTENTS \
actor_rollout_ref.rollout.name="${ENGINE}" \
actor_rollout_ref.rollout.tensor_model_parallel_size=$ROLLOUT_TP \
actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \
@@ -188,7 +188,7 @@ for ENGINE in "${ENGINES[@]}"; do
critic.megatron.grad_offload=${CRITIC_GRAD_OFFLOAD} \
critic.megatron.use_dist_checkpointing=${USE_DIST_CKPT} \
critic.megatron.dist_checkpointing_path=${DIST_CKPT_PATH} \
critic.checkpoint.contents=$CHECKPOINT_CONTENTS \
critic.checkpoint.save_contents=$CHECKPOINT_CONTENTS \
reward_model.enable=True \
reward_model.model.path="${MODEL_PATH}" \
reward_model.micro_batch_size_per_gpu=${train_traj_micro_bsz_per_gpu} \
3 changes: 2 additions & 1 deletion verl/models/llama/megatron/checkpoint_utils/llama_loader.py
@@ -55,7 +55,8 @@ def load_state_dict_to_megatron_llama(state_dict, wrapped_models, config, params
from megatron.core.transformer.module import Float16Module
from torch.nn.parallel import DistributedDataParallel as torchDDP

from verl.utils.megatron_utils import print_rank_0, unwrap_model
from verl.utils.logger import print_rank_0
from verl.utils.megatron_utils import unwrap_model

start_time = time.time()

@@ -55,7 +55,8 @@ def load_state_dict_to_megatron_llama(state_dict, wrapped_models, config, params
from megatron.core.transformer.module import Float16Module
from torch.nn.parallel import DistributedDataParallel as torchDDP

from verl.utils.megatron_utils import print_rank_0, unwrap_model
from verl.utils.logger import print_rank_0
from verl.utils.megatron_utils import unwrap_model

start_time = time.time()

3 changes: 2 additions & 1 deletion verl/models/llama/megatron/checkpoint_utils/llama_saver.py
@@ -21,7 +21,8 @@
from megatron.core.transformer.module import Float16Module
from torch.nn.parallel import DistributedDataParallel as torchDDP

from verl.utils.megatron_utils import print_rank_0, unwrap_model
from verl.utils.logger import print_rank_0
from verl.utils.megatron_utils import unwrap_model


def _megatron_calc_global_rank(tp_rank: int = 0, dp_rank: int = 0, pp_rank: int = 0):
5 changes: 3 additions & 2 deletions verl/models/mcore/loader.py
@@ -56,7 +56,8 @@ def load_state_dict_to_megatron_gptmodel(state_dict, wrapped_models, config, par
from megatron.core.transformer.module import Float16Module
from torch.nn.parallel import DistributedDataParallel as torchDDP

from verl.utils.megatron_utils import print_rank_0, unwrap_model
from verl.utils.logger import print_rank_0
from verl.utils.megatron_utils import unwrap_model

start_time = time.time()

@@ -382,7 +383,7 @@ def _broadcast_tp_shard_tensor_qkv(tensor, q_name, k_name, v_name, bias=False) -
sync_layer.self_attention.linear_qkv.layer_norm_weight if dst_pp_rank == pp_rank else None,
f"{layer_name}.input_layernorm.weight",
)

if f"{layer_name}.self_attn.q_norm.weight" in state_dict:
_broadcast_tensor(
sync_layer.self_attention.q_layernorm.weight if dst_pp_rank == pp_rank else None,
3 changes: 2 additions & 1 deletion verl/models/mcore/saver.py
@@ -22,7 +22,8 @@
from megatron.core.transformer.module import Float16Module
from torch.nn.parallel import DistributedDataParallel as torchDDP

from verl.utils.megatron_utils import print_rank_0, unwrap_model
from verl.utils.logger import print_rank_0
from verl.utils.megatron_utils import unwrap_model


def _megatron_calc_global_rank(tp_rank: int = 0, dp_rank: int = 0, pp_rank: int = 0, cp_rank: int = 0, ep_rank: int = 0):
3 changes: 2 additions & 1 deletion verl/models/qwen2/megatron/checkpoint_utils/qwen2_loader.py
@@ -53,7 +53,8 @@ def load_state_dict_to_megatron_qwen2(state_dict, wrapped_models, config, params
from megatron.core.transformer.module import Float16Module
from torch.nn.parallel import DistributedDataParallel as torchDDP

from verl.utils.megatron_utils import print_rank_0, unwrap_model
from verl.utils.logger import print_rank_0
from verl.utils.megatron_utils import unwrap_model

start_time = time.time()

@@ -53,7 +53,8 @@ def load_state_dict_to_megatron_qwen2(state_dict, wrapped_models, config, params
from megatron.core.transformer.module import Float16Module
from torch.nn.parallel import DistributedDataParallel as torchDDP

from verl.utils.megatron_utils import print_rank_0, unwrap_model
from verl.utils.logger import print_rank_0
from verl.utils.megatron_utils import unwrap_model

start_time = time.time()

3 changes: 2 additions & 1 deletion verl/models/qwen2/megatron/checkpoint_utils/qwen2_saver.py
@@ -21,7 +21,8 @@
from megatron.core.transformer.module import Float16Module
from torch.nn.parallel import DistributedDataParallel as torchDDP

from verl.utils.megatron_utils import print_rank_0, unwrap_model
from verl.utils.logger import print_rank_0
from verl.utils.megatron_utils import unwrap_model


def _megatron_calc_global_rank(tp_rank: int = 0, dp_rank: int = 0, pp_rank: int = 0):
11 changes: 9 additions & 2 deletions verl/trainer/config/ppo_megatron_trainer.yaml
@@ -101,7 +101,11 @@ actor_rollout_ref:
save_path: null # the path to save the profile result
load_weight: True
checkpoint:
contents: ['model', 'optimizer', 'extra'] # with 'hf_model' you can save whole model as hf format, now only use sharded model checkpoint to save space
# What to include in saved checkpoints
# with 'hf_model' you can save the whole model in HF format; by default only the sharded model checkpoint is saved to save space
save_contents: ['model', 'optimizer', 'extra']
# For more flexibility, you can specify the contents to load from the checkpoint.
load_contents: ${actor_rollout_ref.actor.checkpoint.save_contents}
ref:
strategy: megatron
use_torch_compile: ${actor_rollout_ref.actor.use_torch_compile}
@@ -246,7 +250,10 @@ critic:
kl_coef: 0.001
loss_agg_mode: ${actor_rollout_ref.actor.loss_agg_mode}
checkpoint:
contents: ['model', 'optimizer', 'extra'] # with 'hf_model' you can save whole model as hf format, now only use sharded model checkpoint to save space
# What to include in saved checkpoints
# with 'hf_model' you can save the whole model in HF format; by default only the sharded model checkpoint is saved to save space
save_contents: ['model', 'optimizer', 'extra']
load_contents: ${critic.checkpoint.save_contents}

reward_model:
enable: False
8 changes: 6 additions & 2 deletions verl/trainer/config/ppo_trainer.yaml
@@ -206,7 +206,10 @@ actor_rollout_ref:

# What to include in saved checkpoints
# with 'hf_model' you can save the whole model in HF format; by default only the sharded model checkpoint is saved to save space
contents: ['model', 'optimizer', 'extra']
save_contents: ['model', 'optimizer', 'extra']

# For more flexibility, you can specify the contents to load from the checkpoint.
load_contents: ${actor_rollout_ref.actor.checkpoint.save_contents}

# optimizer configs
optim:
@@ -580,7 +583,8 @@ critic:

# What to include in saved checkpoints
# with 'hf_model' you can save the whole model in HF format; by default only the sharded model checkpoint is saved to save space
contents: ['model', 'optimizer', 'extra']
save_contents: ['model', 'optimizer', 'extra']
load_contents: ${critic.checkpoint.save_contents}

# configs for the reward model
reward_model:
61 changes: 57 additions & 4 deletions verl/utils/checkpoint/checkpoint_manager.py
@@ -47,22 +47,75 @@ def __init__(
optimizer: torch.optim.Optimizer,
lr_scheduler: torch.optim.lr_scheduler.LRScheduler = None,
processing_class: Union[PreTrainedTokenizer, ProcessorMixin] = None,
checkpoint_contents: Optional[list] = None,
checkpoint_load_contents: Optional[list] = None,
checkpoint_save_contents: Optional[list] = None,
):
if checkpoint_contents is None:
checkpoint_contents = ["model", "optimizer", "extra"]
if checkpoint_load_contents is None:
checkpoint_load_contents = ["model", "optimizer", "extra"]
if checkpoint_save_contents is None:
checkpoint_save_contents = ["model", "optimizer", "extra"]
self.previous_global_step = None
self.previous_saved_paths = []

self.model = model
self.optimizer = optimizer
self.lr_scheduler = lr_scheduler
self.processing_class = processing_class
self.checkpoint_contents = checkpoint_contents
self.checkpoint_load_contents = checkpoint_load_contents
self.checkpoint_save_contents = checkpoint_save_contents

self.rank = torch.distributed.get_rank()
self.world_size = torch.distributed.get_world_size()

@property
def should_save_model(self) -> bool:
"""
Returns True if 'model' is in checkpoint_save_contents, indicating the model state should be saved.
"""
return "model" in self.checkpoint_save_contents

@property
def should_save_optimizer(self) -> bool:
"""
Returns True if 'optimizer' is in checkpoint_save_contents, indicating the optimizer state should be saved.
"""
return "optimizer" in self.checkpoint_save_contents

@property
def should_save_extra(self) -> bool:
"""
Returns True if 'extra' is in checkpoint_save_contents, indicating the extra state should be saved.
"""
return "extra" in self.checkpoint_save_contents

@property
def should_save_hf_model(self) -> bool:
"""
Returns True if 'hf_model' is in checkpoint_save_contents, indicating the model should be converted to hf model and saved.
"""
return "hf_model" in self.checkpoint_save_contents

@property
def should_load_model(self) -> bool:
"""
Returns True if 'model' is in checkpoint_load_contents, indicating the model state should be loaded.
"""
return "model" in self.checkpoint_load_contents

@property
def should_load_optimizer(self) -> bool:
"""
Returns True if 'optimizer' is in checkpoint_load_contents, indicating the optimizer state should be loaded.
"""
return "optimizer" in self.checkpoint_load_contents

@property
def should_load_extra(self) -> bool:
"""
Returns True if 'extra' is in checkpoint_load_contents, indicating the extra state should be loaded.
"""
return "extra" in self.checkpoint_load_contents

def load_checkpoint(self, local_path: str, hdfs_path: str = None, del_local_after_load: bool = False):
raise NotImplementedError

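To make the intent of the new flags concrete, below is a standalone, hypothetical sketch (invented names; not verl's actual FSDP or Megatron manager) of a save routine gated by the ``should_save_*`` properties defined above:

    # Sketch only: branch the save path on the manager's should_save_* flags.
    import os
    import torch

    def save_checkpoint_sketch(manager, local_path: str, global_step: int) -> None:
        os.makedirs(local_path, exist_ok=True)
        if manager.should_save_model:
            torch.save(manager.model.state_dict(),
                       os.path.join(local_path, f"model_rank_{manager.rank}.pt"))
        if manager.should_save_optimizer:
            torch.save(manager.optimizer.state_dict(),
                       os.path.join(local_path, f"optim_rank_{manager.rank}.pt"))
        if manager.should_save_extra:
            extra = {"global_step": global_step, "rng": torch.get_rng_state()}
            if manager.lr_scheduler is not None:
                extra["lr_scheduler"] = manager.lr_scheduler.state_dict()
            torch.save(extra,
                       os.path.join(local_path, f"extra_state_rank_{manager.rank}.pt"))
        if manager.should_save_hf_model:
            # Exporting the whole model in HF format is left to the concrete manager
            # (or done offline with scripts/model_merge.py); placeholder directory here.
            os.makedirs(os.path.join(local_path, "huggingface"), exist_ok=True)

A symmetric load routine would consult ``should_load_model``, ``should_load_optimizer``, and ``should_load_extra`` before restoring each piece, which is what allows, for example, resuming model weights without the optimizer state.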