Skip to content

Commit 62f5495

Browse files
Yangruipishiyouga
authored and committed
[ppo] feat: add critic valuehead model support for multi-modal PPO (verl-project#1839)
### Checklist Before Starting - [ ] Search for similar PR(s). ### What does this PR do? - 支持多模的 PPO,主要是复用 trl 的 `AutoModelForCausalLMWithValueHead` 作为 critic valuehead model ### High-Level Design > Demonstrate the high-level design if this PR is complex. ### Specific Changes > List the specific changes. ### API > Demonstrate how the API changes if any. ### Usage Example > Provide usage example(s) for easier usage. ```python # Add code snippet or script demonstrating how to use this ``` ### Test > For changes that can not be tested by CI (e.g., algorithm implementation, new model support), validate by experiment(s) and show results like training curve plots, evaluation results, etc. ### Additional Info. - **Issue Number**: Fixes issue # or discussion # if any. - **Training**: [Note which backend this PR will affect: FSDP, Megatron, both, or none] - **Inference**: [Note which backend this PR will affect: vLLM, SGLang, both, or none] ### Checklist Before Submitting - [ ] Read the [Contribute Guide](https://github.com/volcengine/verl?tab=readme-ov-file#contribution-guide). - [ ] Apply [pre-commit checks](https://github.com/volcengine/verl?tab=readme-ov-file#code-linting-and-formatting). - [ ] Add `[BREAKING]` to the PR title if it breaks any API. - [ ] Update the documentation about your changes in the [docs](https://github.com/volcengine/verl/tree/main/docs). - [ ] New CI unit test(s) are added to cover the code path. - [ ] Rely on existing unit tests on CI that covers the code path. --------- Co-authored-by: Yaowei Zheng <hiyouga@buaa.edu.cn>
1 parent 684d57b commit 62f5495

File tree

7 files changed

+124
-13
lines changed

7 files changed

+124
-13
lines changed

.github/workflows/e2e_ppo_trainer.yml

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -181,13 +181,13 @@ jobs:
181181
fetch-depth: 0
182182
- name: Install the current repository
183183
run: |
184-
pip3 install -e .[test,geo,vllm]
184+
pip3 install -e .[test,gpu,vllm,geo,trl]
185185
# Geo3k
186186
- name: Prepare Geo3k dataset
187187
run: |
188188
ray stop --force
189189
python3 examples/data_preprocess/geo3k.py
190-
- name: Running Geo3k VLM E2E training tests on 8 L20 GPUs with rmpad using function rm
190+
- name: Running Geo3k VLM GRPO E2E training tests on 8 L20 GPUs with rmpad using function rm
191191
run: |
192192
ray stop --force
193193
TRAIN_FILES=$HOME/data/geo3k/train.parquet VAL_FILES=$HOME/data/geo3k/test.parquet \
@@ -197,6 +197,16 @@ jobs:
197197
SP_SIZE=2 \
198198
bash tests/e2e/ppo_trainer/run_function_reward.sh
199199
200+
- name: Running Geo3k VLM PPO E2E training tests on 8 L20 GPUs with rmpad using function rm
201+
run: |
202+
ray stop --force
203+
TRAIN_FILES=$HOME/data/geo3k/train.parquet VAL_FILES=$HOME/data/geo3k/test.parquet \
204+
MAX_PROMPT_LEN=1536 MAX_RESPONSE_LEN=1536 \
205+
MODEL_ID=Qwen/Qwen2-VL-2B-Instruct \
206+
ADV_ESTIMATOR=gae RM_PAD=True USE_KL=True ENABLE_CHUNKED_PREFILL=False \
207+
SP_SIZE=2 \
208+
bash tests/e2e/ppo_trainer/run_function_reward.sh
209+
200210
e2e_ppo_trainer_sglang:
201211
runs-on: [L20x8]
202212
needs: pre_commit_for_ppo
@@ -364,4 +374,4 @@ jobs:
364374
ADV_ESTIMATOR=grpo RM_PAD=True USE_KL=True ENABLE_CHUNKED_PREFILL=False \
365375
ENGINE=sglang GPU_MEMORY_UTILIZATION=0.6 ACTOR_FSDP_PARAM_OFFLOAD=True \
366376
ACTOR_FSDP_OPTIMIZER_OFFLOAD=True REF_FSDP_PARAM_OFFLOAD=True \
367-
bash tests/e2e/ppo_trainer/run_function_reward.sh
377+
bash tests/e2e/ppo_trainer/run_function_reward.sh

setup.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@
5555
"torch-memory-saver>=0.0.5",
5656
"torch==2.6.0",
5757
]
58+
TRL_REQUIRES = ["trl<=0.9.6"]
5859

5960
extras_require = {
6061
"test": TEST_REQUIRES,
@@ -64,6 +65,7 @@
6465
"math": MATH_REQUIRES,
6566
"vllm": VLLM_REQUIRES,
6667
"sglang": SGLANG_REQUIRES,
68+
"trl": TRL_REQUIRES,
6769
}
6870

6971

verl/models/transformers/monkey_patch.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
from transformers.modeling_flash_attention_utils import _flash_attention_forward
2626
from transformers.modeling_utils import PreTrainedModel
2727

28+
from verl.utils.import_utils import is_trl_available
2829
from verl.utils.ulysses import (
2930
gather_heads_scatter_seq,
3031
gather_seq_scatter_heads,
@@ -156,6 +157,16 @@ def apply_monkey_patch(
156157
assert num_key_value_heads % ulysses_sp_size == 0 or ulysses_sp_size % num_key_value_heads == 0, (
157158
f"num_key_value_heads {num_key_value_heads} must be divisible by ulysses_sp_size {ulysses_sp_size}or vise versa. Upon ulysses_sp_size % num_key_value_heads == 0,kv heads are repeated to ensure correctness."
158159
)
160+
161+
if is_trl_available():
162+
from trl import AutoModelForCausalLMWithValueHead
163+
164+
def state_dict(self, *args, **kwargs):
165+
return torch.nn.Module.state_dict(self, *args, **kwargs)
166+
167+
AutoModelForCausalLMWithValueHead.state_dict = state_dict
168+
print("Monkey patch state_dict in AutoModelForCausalLMWithValueHead. ")
169+
159170
# TODO: VLM models only, unify monkey patch to LLM models.
160171
if model.config.model_type == "qwen2_5_vl":
161172
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (

verl/utils/import_utils.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,15 @@ def is_sglang_available():
4848
return sglang_spec is not None
4949

5050

51+
@cache
52+
def is_trl_available():
53+
try:
54+
trl_spec = importlib.util.find_spec("trl")
55+
except ModuleNotFoundError:
56+
trl_spec = None
57+
return trl_spec is not None
58+
59+
5160
def import_external_libs(external_libs=None):
5261
if external_libs is None:
5362
return

verl/utils/model.py

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
)
3434

3535
from verl.models.registry import ModelRegistry
36+
from verl.utils.import_utils import is_trl_available
3637

3738

3839
class LambdaLayer(nn.Module):
@@ -469,3 +470,70 @@ def get_parallel_gptmodel_from_config(tfconfig, hf_config, pre_process=None, pos
469470

470471
parallel_model.output_layer = LinearForLastLayer(input_size=tfconfig.hidden_size, output_size=1, config=tfconfig)
471472
return parallel_model
473+
474+
475+
def patch_valuehead_model(model) -> None:
476+
from types import MethodType
477+
478+
from transformers import PreTrainedModel
479+
480+
from trl import AutoModelForCausalLMWithValueHead
481+
482+
def tie_weights(self: "AutoModelForCausalLMWithValueHead") -> None:
483+
if isinstance(self.pretrained_model, PreTrainedModel):
484+
self.pretrained_model.tie_weights()
485+
486+
def get_input_embeddings(self: "AutoModelForCausalLMWithValueHead") -> torch.nn.Module:
487+
if isinstance(self.pretrained_model, PreTrainedModel):
488+
return self.pretrained_model.get_input_embeddings()
489+
490+
def get_output_embeddings(self: "AutoModelForCausalLMWithValueHead") -> torch.nn.Module:
491+
if isinstance(self.pretrained_model, PreTrainedModel):
492+
return self.pretrained_model.get_output_embeddings()
493+
494+
def can_generate(self):
495+
return False
496+
497+
ignore_modules = [name for name, _ in model.named_parameters() if "pretrained_model" in name]
498+
setattr(model, "_keys_to_ignore_on_save", ignore_modules)
499+
setattr(model, "tie_weights", MethodType(tie_weights, model))
500+
setattr(model, "get_input_embeddings", MethodType(get_input_embeddings, model))
501+
setattr(model, "get_output_embeddings", MethodType(get_output_embeddings, model))
502+
setattr(model, "can_generate", MethodType(can_generate, model))
503+
setattr(model, "_no_split_modules", getattr(model.pretrained_model, "_no_split_modules", []))
504+
505+
506+
def load_valuehead_model(local_path, torch_dtype, model_config, trust_remote_code):
507+
from transformers import AutoModelForTokenClassification, AutoModelForCausalLM, AutoModelForVision2Seq
508+
509+
try:
510+
model = AutoModelForTokenClassification.from_pretrained(
511+
pretrained_model_name_or_path=local_path,
512+
torch_dtype=torch_dtype,
513+
config=model_config,
514+
attn_implementation="flash_attention_2",
515+
trust_remote_code=trust_remote_code,
516+
)
517+
return model
518+
except BaseException as e:
519+
if not is_trl_available():
520+
raise RuntimeError(f"model({local_path}) is not a value head model, please install trl to make it valid") from e
521+
522+
assert is_trl_available()
523+
524+
from trl import AutoModelForCausalLMWithValueHead
525+
526+
if type(model_config) in AutoModelForVision2Seq._model_mapping.keys():
527+
module_class = AutoModelForVision2Seq
528+
else:
529+
module_class = AutoModelForCausalLM
530+
ori_model = module_class.from_pretrained(
531+
pretrained_model_name_or_path=local_path,
532+
torch_dtype=torch_dtype,
533+
config=model_config,
534+
attn_implementation="flash_attention_2",
535+
trust_remote_code=trust_remote_code,
536+
)
537+
model = AutoModelForCausalLMWithValueHead.from_pretrained(ori_model)
538+
patch_valuehead_model(model)
539+
return model

verl/workers/critic/dp_critic.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -99,8 +99,13 @@ def _forward_micro_batch(self, micro_batch):
9999
**multi_modal_inputs,
100100
use_cache=False,
101101
) # prevent model thinks we are generating
102-
values_rmpad = output.logits
103-
values_rmpad = values_rmpad.squeeze(0) # (total_nnz)
102+
103+
if hasattr(self.critic_module, "v_head"):
104+
# For trl.AutoModelForCausalLMWithValueHead
105+
values_rmpad = output[2].squeeze(0).unsqueeze(-1)
106+
else:
107+
values_rmpad = output.logits
108+
values_rmpad = values_rmpad.squeeze(0) # (total_nnz)
104109

105110
# gather output if sp > 1
106111
if self.ulysses_sequence_parallel_size > 1:
@@ -117,7 +122,11 @@ def _forward_micro_batch(self, micro_batch):
117122
**multi_modal_inputs,
118123
use_cache=False,
119124
) # prevent model thinks we are generating
120-
values = output.logits
125+
if hasattr(self.critic_module, "v_head"):
126+
# For trl.AutoModelForCausalLMWithValueHead
127+
values = output[2]
128+
else:
129+
values = output.logits
121130
values = values[:, -response_length - 1 : -1].squeeze(-1)
122131
return values
123132

@@ -213,7 +222,7 @@ def update_critic(self, data: DataProto):
213222
micro_batches, _ = rearrange_micro_batches(batch=mini_batch, max_token_len=max_token_len)
214223
else:
215224
micro_batches = mini_batch.split(self.config.ppo_micro_batch_size_per_gpu)
216-
self.gradient_accumulation = self.config.ppo_mini_batch_size // self.config.ppo_micro_batch_size_per_gpu
225+
self.gradient_accumulation = self.config.ppo_mini_batch_size // self.config.ppo_micro_batch_size_per_gpu
217226

218227
self.critic_optimizer.zero_grad()
219228

verl/workers/fsdp_workers.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -823,7 +823,7 @@ def _build_critic_model_optimizer(self, config):
823823
from torch import optim
824824
from torch.distributed.fsdp import MixedPrecision
825825

826-
from verl.utils.model import print_model_size
826+
from verl.utils.model import load_valuehead_model, print_model_size
827827
from verl.utils.torch_dtypes import PrecisionType
828828

829829
use_shm = config.model.get("use_shm", False)
@@ -864,11 +864,13 @@ def _build_critic_model_optimizer(self, config):
864864
warnings.simplefilter("ignore")
865865
critic_model_config.classifier_dropout = 0.0
866866
critic_model_config.hidden_dropout = "0"
867-
critic_module = AutoModelForTokenClassification.from_pretrained(
868-
pretrained_model_name_or_path=local_path,
869-
torch_dtype=torch_dtype,
870-
config=critic_model_config,
871-
trust_remote_code=config.model.get("trust_remote_code", False),
867+
critic_model_config.summary_dropout_prob = 0.0
868+
869+
critic_module = load_valuehead_model(
870+
local_path,
871+
torch_dtype,
872+
critic_model_config,
873+
config.model.get("trust_remote_code", False),
872874
)
873875

874876
use_remove_padding = config.model.get("use_remove_padding", False)

0 commit comments

Comments
 (0)