diff --git a/examples/split_placement/config/ppo_trainer_split.yaml b/examples/split_placement/config/ppo_trainer_split.yaml
index 7d918a58ec4..959368f1dbe 100644
--- a/examples/split_placement/config/ppo_trainer_split.yaml
+++ b/examples/split_placement/config/ppo_trainer_split.yaml
@@ -3,6 +3,7 @@ data:
   train_files: ~/data/rlhf/gsm8k/train.parquet
   val_files: ~/data/rlhf/gsm8k/test.parquet
   prompt_key: prompt
+  custom_chat_template: null
   max_prompt_length: 512
   max_response_length: 512
   train_batch_size: 1024
diff --git a/examples/split_placement/main_ppo_split.py b/examples/split_placement/main_ppo_split.py
index 35e1d3cfe33..a727d18bab9 100644
--- a/examples/split_placement/main_ppo_split.py
+++ b/examples/split_placement/main_ppo_split.py
@@ -120,7 +120,8 @@ def main_task(config):
     # instantiate tokenizer
     from verl.utils import hf_tokenizer

-    tokenizer = hf_tokenizer(local_path)
+    custom_chat_template = config.data.get('custom_chat_template', None)
+    tokenizer = hf_tokenizer(local_path, custom_chat_template=custom_chat_template)

     # define worker classes
     if config.actor_rollout_ref.actor.strategy == "fsdp":
diff --git a/verl/trainer/config/ppo_trainer.yaml b/verl/trainer/config/ppo_trainer.yaml
index 1b6668dce8a..bcd697a330b 100644
--- a/verl/trainer/config/ppo_trainer.yaml
+++ b/verl/trainer/config/ppo_trainer.yaml
@@ -4,6 +4,7 @@ data:
   val_files: ~/data/rlhf/gsm8k/test.parquet
   prompt_key: prompt
   reward_fn_key: data_source
+  custom_chat_template: null
  max_prompt_length: 512
   max_response_length: 512
   train_batch_size: 1024
diff --git a/verl/trainer/main_ppo.py b/verl/trainer/main_ppo.py
index fab382a37bb..375bded48f8 100644
--- a/verl/trainer/main_ppo.py
+++ b/verl/trainer/main_ppo.py
@@ -98,8 +98,12 @@ def run(self, config):
         # instantiate tokenizer
         from verl.utils import hf_processor, hf_tokenizer

-        trust_remote_code = config.data.get("trust_remote_code", False)
-        tokenizer = hf_tokenizer(local_path, trust_remote_code=trust_remote_code)
+
+        custom_chat_template = config.data.get('custom_chat_template', None)
+        trust_remote_code = config.data.get('trust_remote_code', False)
+        tokenizer = hf_tokenizer(local_path,
+                                 trust_remote_code=trust_remote_code,
+                                 custom_chat_template=custom_chat_template)
         processor = hf_processor(local_path, use_fast=True)  # used for multimodal LLM, could be none

         # define worker classes
diff --git a/verl/utils/dataset/sft_dataset.py b/verl/utils/dataset/sft_dataset.py
index aef245cfacc..6815dc6b475 100644
--- a/verl/utils/dataset/sft_dataset.py
+++ b/verl/utils/dataset/sft_dataset.py
@@ -45,6 +45,7 @@ def __init__(self, parquet_files: Union[str, List[str]], tokenizer, config):
         response_dict_keys = config.get("response_dict_keys", None)
         max_length = config.get("max_length", 1024)
         truncation = config.get("truncation", "error")
+        custom_chat_template = config.get('custom_chat_template', None)
         assert truncation in ["error", "left", "right"]
         self.truncation = truncation

@@ -54,7 +55,7 @@ def __init__(self, parquet_files: Union[str, List[str]], tokenizer, config):
         self.parquet_files = parquet_files
         if isinstance(tokenizer, str):
-            tokenizer = hf_tokenizer(tokenizer)
+            tokenizer = hf_tokenizer(tokenizer, custom_chat_template=custom_chat_template)
         self.tokenizer: PreTrainedTokenizer = tokenizer
         self.prompt_key = prompt_key if isinstance(prompt_key, (tuple, list)) else [prompt_key]
diff --git a/verl/utils/tokenizer.py b/verl/utils/tokenizer.py
index 019683f02cc..a4ba0755f19 100644
--- a/verl/utils/tokenizer.py
+++ b/verl/utils/tokenizer.py
@@ -33,7 +33,7 @@ def set_pad_token_id(tokenizer):
         warnings.warn(f"tokenizer.pad_token is None. Now set to {tokenizer.eos_token}", stacklevel=1)


-def hf_tokenizer(name_or_path, correct_pad_token=True, correct_gemma2=True, **kwargs):
+def hf_tokenizer(name_or_path, correct_pad_token=True, correct_gemma2=True, custom_chat_template=None, **kwargs):
     """Create a huggingface pretrained tokenizer which correctness handles eos and pad tokens.

     Args:
@@ -41,7 +41,7 @@ def hf_tokenizer(name_or_path, correct_pad_token=True, correct_gemma2=True, **kwargs):
         name (str): The name of the tokenizer.
         correct_pad_token (bool): Whether to correct the pad token id.
         correct_gemma2 (bool): Whether to correct the gemma2 tokenizer.
-
+        custom_chat_template (str | None): The chat template that overrides the default template.
     Returns:

         transformers.PreTrainedTokenizer: The pretrained tokenizer.
@@ -58,6 +58,11 @@ def hf_tokenizer(name_or_path, correct_pad_token=True, correct_gemma2=True, **kwargs):
     tokenizer = AutoTokenizer.from_pretrained(name_or_path, **kwargs)
     if correct_pad_token:
         set_pad_token_id(tokenizer)
+
+    # Chat template can be overridden, or set if the tokenizer does not have a chat template
+    if custom_chat_template is not None:
+        tokenizer.chat_template = custom_chat_template
+
     return tokenizer
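
For reference, a minimal sketch of how the new `custom_chat_template` option could be exercised once the patch is applied. The Jinja template string and the model name below are illustrative assumptions, not part of the patch; the only interfaces relied on are `verl.utils.hf_tokenizer` (modified above) and the standard `transformers` `apply_chat_template` API.

```python
from verl.utils import hf_tokenizer

# Hypothetical minimal Jinja chat template, used only for illustration.
CUSTOM_TEMPLATE = (
    "{% for message in messages %}"
    "{{ message['role'] }}: {{ message['content'] }}\n"
    "{% endfor %}"
)

# With the patch applied, the string overrides tokenizer.chat_template
# (or sets it if the tokenizer shipped without one).
tokenizer = hf_tokenizer("Qwen/Qwen2.5-0.5B-Instruct", custom_chat_template=CUSTOM_TEMPLATE)

prompt = tokenizer.apply_chat_template(
    [{"role": "user", "content": "What is 2 + 2?"}],
    tokenize=False,
)
print(prompt)  # user: What is 2 + 2?
```

The same string can be supplied through the trainer config by setting `data.custom_chat_template` (added above with a default of `null`, which keeps the tokenizer's built-in template).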