From 29506eac2e4937bcd9d2a3d284ebfa3e70583c16 Mon Sep 17 00:00:00 2001 From: Patrik Bartak Date: Fri, 25 Apr 2025 14:43:29 +0200 Subject: [PATCH 1/4] add custom chat template option --- examples/split_placement/config/ppo_trainer_split.yaml | 1 + examples/split_placement/main_ppo_split.py | 2 +- verl/trainer/config/ppo_trainer.yaml | 3 +++ verl/trainer/main_ppo.py | 8 ++++++-- verl/utils/dataset/rl_dataset.py | 4 ++++ verl/utils/dataset/sft_dataset.py | 3 ++- verl/utils/tokenizer.py | 9 +++++++-- 7 files changed, 24 insertions(+), 6 deletions(-) diff --git a/examples/split_placement/config/ppo_trainer_split.yaml b/examples/split_placement/config/ppo_trainer_split.yaml index 7d918a58ec4..959368f1dbe 100644 --- a/examples/split_placement/config/ppo_trainer_split.yaml +++ b/examples/split_placement/config/ppo_trainer_split.yaml @@ -3,6 +3,7 @@ data: train_files: ~/data/rlhf/gsm8k/train.parquet val_files: ~/data/rlhf/gsm8k/test.parquet prompt_key: prompt + custom_chat_template: null max_prompt_length: 512 max_response_length: 512 train_batch_size: 1024 diff --git a/examples/split_placement/main_ppo_split.py b/examples/split_placement/main_ppo_split.py index 35e1d3cfe33..d7cd5822b56 100644 --- a/examples/split_placement/main_ppo_split.py +++ b/examples/split_placement/main_ppo_split.py @@ -120,7 +120,7 @@ def main_task(config): # instantiate tokenizer from verl.utils import hf_tokenizer - tokenizer = hf_tokenizer(local_path) + tokenizer = hf_tokenizer(local_path, custom_chat_template=config.data.custom_chat_template) # define worker classes if config.actor_rollout_ref.actor.strategy == "fsdp": diff --git a/verl/trainer/config/ppo_trainer.yaml b/verl/trainer/config/ppo_trainer.yaml index 8b474df27cd..01c6f839e0b 100644 --- a/verl/trainer/config/ppo_trainer.yaml +++ b/verl/trainer/config/ppo_trainer.yaml @@ -4,6 +4,7 @@ data: val_files: ~/data/rlhf/gsm8k/test.parquet prompt_key: prompt reward_fn_key: data_source + custom_chat_template: null max_prompt_length: 512 max_response_length: 512 train_batch_size: 1024 @@ -181,6 +182,8 @@ reward_model: forward_max_token_len_per_gpu: ${critic.forward_max_token_len_per_gpu} reward_manager: naive launch_reward_fn_async: False # custom reward function executed async on CPU, during log_prob + num_examine: 0 + val_num_examine: 5 custom_reward_function: path: null diff --git a/verl/trainer/main_ppo.py b/verl/trainer/main_ppo.py index 1721c88af52..56eb843acdd 100644 --- a/verl/trainer/main_ppo.py +++ b/verl/trainer/main_ppo.py @@ -100,8 +100,12 @@ def run(self, config): # instantiate tokenizer from verl.utils import hf_processor, hf_tokenizer - trust_remote_code = config.data.get("trust_remote_code", False) - tokenizer = hf_tokenizer(local_path, trust_remote_code=trust_remote_code) + + custom_chat_template = config.data.get('custom_chat_template', None) + trust_remote_code = config.data.get('trust_remote_code', False) + tokenizer = hf_tokenizer(local_path, + trust_remote_code=trust_remote_code, + custom_chat_template=custom_chat_template) processor = hf_processor(local_path, use_fast=True) # used for multimodal LLM, could be none # define worker classes diff --git a/verl/utils/dataset/rl_dataset.py b/verl/utils/dataset/rl_dataset.py index 76a49948300..b777eeb2ee6 100644 --- a/verl/utils/dataset/rl_dataset.py +++ b/verl/utils/dataset/rl_dataset.py @@ -75,6 +75,7 @@ def __init__( self.image_key = config.get("image_key", "images") self.video_key = config.get("video_key", "videos") self.max_prompt_length = config.get("max_prompt_length", 1024) + # self.chat_template_key = config.get("chat_template_key", "chat_template") self.return_raw_chat = config.get("return_raw_chat", False) self.truncation = config.get("truncation", "error") @@ -83,6 +84,9 @@ def __init__( self.num_workers = config.get("filter_overlong_prompts_workers", max(1, os.cpu_count() // 4)) self.num_workers = min(self.num_workers, os.cpu_count()) + no_template = "{{ bos_token }}{{ messages }}" + self.tokenizer.chat_template = no_template + # whether to store the dataset in state_dict() # default not store self.serialize_dataset = False diff --git a/verl/utils/dataset/sft_dataset.py b/verl/utils/dataset/sft_dataset.py index 86b29d61392..c3d9d3793f2 100644 --- a/verl/utils/dataset/sft_dataset.py +++ b/verl/utils/dataset/sft_dataset.py @@ -45,6 +45,7 @@ def __init__(self, parquet_files: Union[str, List[str]], tokenizer, config): response_dict_keys = config.get("response_dict_keys", None) max_length = config.get("max_length", 1024) truncation = config.get("truncation", "error") + custom_chat_template = config.get('custom_chat_template', None) assert truncation in ["error", "left", "right"] self.truncation = truncation @@ -54,7 +55,7 @@ def __init__(self, parquet_files: Union[str, List[str]], tokenizer, config): self.parquet_files = parquet_files if isinstance(tokenizer, str): - tokenizer = hf_tokenizer(tokenizer) + tokenizer = hf_tokenizer(tokenizer, custom_chat_template=custom_chat_template) self.tokenizer: PreTrainedTokenizer = tokenizer self.prompt_key = prompt_key if isinstance(prompt_key, (tuple, list)) else [prompt_key] diff --git a/verl/utils/tokenizer.py b/verl/utils/tokenizer.py index c82c57ae5f7..87decdb94e1 100644 --- a/verl/utils/tokenizer.py +++ b/verl/utils/tokenizer.py @@ -33,7 +33,7 @@ def set_pad_token_id(tokenizer): warnings.warn(f"tokenizer.pad_token is None. Now set to {tokenizer.eos_token}") -def hf_tokenizer(name_or_path, correct_pad_token=True, correct_gemma2=True, **kwargs): +def hf_tokenizer(name_or_path, correct_pad_token=True, correct_gemma2=True, custom_chat_template=None, **kwargs): """Create a huggingface pretrained tokenizer which correctness handles eos and pad tokens. Args: @@ -41,7 +41,7 @@ def hf_tokenizer(name_or_path, correct_pad_token=True, correct_gemma2=True, **kw name (str): The name of the tokenizer. correct_pad_token (bool): Whether to correct the pad token id. correct_gemma2 (bool): Whether to correct the gemma2 tokenizer. - + custom_chat_template (str | None): The chat template that overrides the default template. Returns: transformers.PreTrainedTokenizer: The pretrained tokenizer. @@ -58,6 +58,11 @@ def hf_tokenizer(name_or_path, correct_pad_token=True, correct_gemma2=True, **kw tokenizer = AutoTokenizer.from_pretrained(name_or_path, **kwargs) if correct_pad_token: set_pad_token_id(tokenizer) + + # Chat template can be overridden, or set if the tokenizer does not have a chat template + if custom_chat_template is not None: + tokenizer.chat_template = custom_chat_template + return tokenizer From 9a680e4ef863b53a1a1adefaba10283b142db7b2 Mon Sep 17 00:00:00 2001 From: Patrik Bartak Date: Fri, 25 Apr 2025 14:53:48 +0200 Subject: [PATCH 2/4] update main_ppo_split --- examples/split_placement/main_ppo_split.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/examples/split_placement/main_ppo_split.py b/examples/split_placement/main_ppo_split.py index d7cd5822b56..a727d18bab9 100644 --- a/examples/split_placement/main_ppo_split.py +++ b/examples/split_placement/main_ppo_split.py @@ -120,7 +120,8 @@ def main_task(config): # instantiate tokenizer from verl.utils import hf_tokenizer - tokenizer = hf_tokenizer(local_path, custom_chat_template=config.data.custom_chat_template) + custom_chat_template = config.data.get('custom_chat_template', None) + tokenizer = hf_tokenizer(local_path, custom_chat_template=custom_chat_template) # define worker classes if config.actor_rollout_ref.actor.strategy == "fsdp": From b82e5f19443942ac3fe0d5d13404fa156ce6f531 Mon Sep 17 00:00:00 2001 From: Patrik Bartak Date: Fri, 25 Apr 2025 14:54:20 +0200 Subject: [PATCH 3/4] remove num_examine --- verl/trainer/config/ppo_trainer.yaml | 2 -- 1 file changed, 2 deletions(-) diff --git a/verl/trainer/config/ppo_trainer.yaml b/verl/trainer/config/ppo_trainer.yaml index 01c6f839e0b..630f3c04d64 100644 --- a/verl/trainer/config/ppo_trainer.yaml +++ b/verl/trainer/config/ppo_trainer.yaml @@ -182,8 +182,6 @@ reward_model: forward_max_token_len_per_gpu: ${critic.forward_max_token_len_per_gpu} reward_manager: naive launch_reward_fn_async: False # custom reward function executed async on CPU, during log_prob - num_examine: 0 - val_num_examine: 5 custom_reward_function: path: null From 42bf275df587a82a4ee1f47bb014b80a65d6c483 Mon Sep 17 00:00:00 2001 From: Patrik Bartak Date: Fri, 25 Apr 2025 14:55:09 +0200 Subject: [PATCH 4/4] fix --- verl/utils/dataset/rl_dataset.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/verl/utils/dataset/rl_dataset.py b/verl/utils/dataset/rl_dataset.py index b777eeb2ee6..76a49948300 100644 --- a/verl/utils/dataset/rl_dataset.py +++ b/verl/utils/dataset/rl_dataset.py @@ -75,7 +75,6 @@ def __init__( self.image_key = config.get("image_key", "images") self.video_key = config.get("video_key", "videos") self.max_prompt_length = config.get("max_prompt_length", 1024) - # self.chat_template_key = config.get("chat_template_key", "chat_template") self.return_raw_chat = config.get("return_raw_chat", False) self.truncation = config.get("truncation", "error") @@ -84,9 +83,6 @@ def __init__( self.num_workers = config.get("filter_overlong_prompts_workers", max(1, os.cpu_count() // 4)) self.num_workers = min(self.num_workers, os.cpu_count()) - no_template = "{{ bos_token }}{{ messages }}" - self.tokenizer.chat_template = no_template - # whether to store the dataset in state_dict() # default not store self.serialize_dataset = False