1 change: 1 addition & 0 deletions examples/split_placement/config/ppo_trainer_split.yaml

@@ -3,6 +3,7 @@ data:
   train_files: ~/data/rlhf/gsm8k/train.parquet
   val_files: ~/data/rlhf/gsm8k/test.parquet
   prompt_key: prompt
+  custom_chat_template: null
   max_prompt_length: 512
   max_response_length: 512
   train_batch_size: 1024
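For orientation (not part of the diff): custom_chat_template defaults to null, which keeps the tokenizer's built-in template. When set, it is expected to hold a Jinja chat-template string of the kind transformers stores in tokenizer.chat_template. A minimal sketch with an invented ChatML-style template, written as the Python string the trainer ultimately receives:

    # Hypothetical template value; the PR adds only the hook, not a template.
    CUSTOM_TEMPLATE = (
        "{% for message in messages %}"
        "{{ '<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>\n' }}"
        "{% endfor %}"
        "{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"
    )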
3 changes: 2 additions & 1 deletion examples/split_placement/main_ppo_split.py

@@ -120,7 +120,8 @@ def main_task(config):
     # instantiate tokenizer
     from verl.utils import hf_tokenizer

-    tokenizer = hf_tokenizer(local_path)
+    custom_chat_template = config.data.get('custom_chat_template', None)
+    tokenizer = hf_tokenizer(local_path, custom_chat_template=custom_chat_template)

     # define worker classes
     if config.actor_rollout_ref.actor.strategy == "fsdp":
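Note the backward-compatible default: when the key is absent or null, hf_tokenizer receives None and, as the verl/utils/tokenizer.py hunk below shows, leaves the tokenizer's template untouched. A quick sketch of that no-op path (model name hypothetical):

    from verl.utils import hf_tokenizer

    tok = hf_tokenizer("Qwen/Qwen2.5-0.5B-Instruct")  # hypothetical model
    stock_template = tok.chat_template  # whatever the checkpoint ships, possibly None

    same = hf_tokenizer("Qwen/Qwen2.5-0.5B-Instruct", custom_chat_template=None)
    assert same.chat_template == stock_template  # None means no override is applied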
1 change: 1 addition & 0 deletions verl/trainer/config/ppo_trainer.yaml

@@ -4,6 +4,7 @@ data:
   val_files: ~/data/rlhf/gsm8k/test.parquet
   prompt_key: prompt
   reward_fn_key: data_source
+  custom_chat_template: null
   max_prompt_length: 512
   max_response_length: 512
   train_batch_size: 1024
8 changes: 6 additions & 2 deletions verl/trainer/main_ppo.py

@@ -98,8 +98,12 @@ def run(self, config):
         # instantiate tokenizer
         from verl.utils import hf_processor, hf_tokenizer

-        trust_remote_code = config.data.get("trust_remote_code", False)
-        tokenizer = hf_tokenizer(local_path, trust_remote_code=trust_remote_code)
+        custom_chat_template = config.data.get('custom_chat_template', None)
+        trust_remote_code = config.data.get('trust_remote_code', False)
+        tokenizer = hf_tokenizer(local_path,
+                                 trust_remote_code=trust_remote_code,
+                                 custom_chat_template=custom_chat_template)
+
         processor = hf_processor(local_path, use_fast=True)  # used for multimodal LLM, could be none

         # define worker classes
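Since custom_chat_template is a named parameter of hf_tokenizer while trust_remote_code travels through **kwargs into AutoTokenizer.from_pretrained, the two options compose freely. A hedged sketch (checkpoint path invented, CUSTOM_TEMPLATE from the earlier sketch):

    from verl.utils import hf_tokenizer

    tokenizer = hf_tokenizer(
        "~/models/my-custom-model",            # hypothetical local checkpoint
        trust_remote_code=True,                # forwarded to AutoTokenizer.from_pretrained
        custom_chat_template=CUSTOM_TEMPLATE,  # consumed by hf_tokenizer itself
    )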
3 changes: 2 additions & 1 deletion verl/utils/dataset/sft_dataset.py

@@ -45,6 +45,7 @@ def __init__(self, parquet_files: Union[str, List[str]], tokenizer, config):
         response_dict_keys = config.get("response_dict_keys", None)
         max_length = config.get("max_length", 1024)
         truncation = config.get("truncation", "error")
+        custom_chat_template = config.get('custom_chat_template', None)

         assert truncation in ["error", "left", "right"]
         self.truncation = truncation
@@ -54,7 +55,7 @@ def __init__(self, parquet_files: Union[str, List[str]], tokenizer, config):

         self.parquet_files = parquet_files
         if isinstance(tokenizer, str):
-            tokenizer = hf_tokenizer(tokenizer)
+            tokenizer = hf_tokenizer(tokenizer, custom_chat_template=custom_chat_template)
         self.tokenizer: PreTrainedTokenizer = tokenizer

         self.prompt_key = prompt_key if isinstance(prompt_key, (tuple, list)) else [prompt_key]
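A usage sketch for the dataset path (file path, model name, and config keys are illustrative, and the exact set of supported keys may differ): because options are read with config.get, a plain dict suffices when the tokenizer is passed as a name string:

    from verl.utils.dataset.sft_dataset import SFTDataset

    config = {
        "prompt_key": "prompt",
        "response_key": "response",
        "max_length": 1024,
        "truncation": "error",
        "custom_chat_template": CUSTOM_TEMPLATE,  # hypothetical template string
    }
    dataset = SFTDataset("~/data/sft/train.parquet", "Qwen/Qwen2.5-0.5B-Instruct", config)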
9 changes: 7 additions & 2 deletions verl/utils/tokenizer.py

@@ -33,15 +33,15 @@ def set_pad_token_id(tokenizer):
         warnings.warn(f"tokenizer.pad_token is None. Now set to {tokenizer.eos_token}", stacklevel=1)


-def hf_tokenizer(name_or_path, correct_pad_token=True, correct_gemma2=True, **kwargs):
+def hf_tokenizer(name_or_path, correct_pad_token=True, correct_gemma2=True, custom_chat_template=None, **kwargs):
     """Create a huggingface pretrained tokenizer which correctness handles eos and pad tokens.

     Args:

         name (str): The name of the tokenizer.
         correct_pad_token (bool): Whether to correct the pad token id.
         correct_gemma2 (bool): Whether to correct the gemma2 tokenizer.
-
+        custom_chat_template (str | None): The chat template that overrides the default template.
     Returns:

         transformers.PreTrainedTokenizer: The pretrained tokenizer.
@@ -58,6 +58,11 @@ def hf_tokenizer(name_or_path, correct_pad_token=True, correct_gemma2=True, **kwargs):
     tokenizer = AutoTokenizer.from_pretrained(name_or_path, **kwargs)
     if correct_pad_token:
         set_pad_token_id(tokenizer)
+
+    # Chat template can be overridden, or set if the tokenizer does not have a chat template
+    if custom_chat_template is not None:
+        tokenizer.chat_template = custom_chat_template
+
     return tokenizer
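To see the override end to end (model name and messages are invented; assumes transformers is installed and CUSTOM_TEMPLATE is the sketch above):

    from verl.utils import hf_tokenizer

    tokenizer = hf_tokenizer("Qwen/Qwen2.5-0.5B-Instruct", custom_chat_template=CUSTOM_TEMPLATE)

    messages = [{"role": "user", "content": "Hello!"}]
    prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    # With the ChatML-style sketch, prompt renders as:
    # <|im_start|>user\nHello!<|im_end|>\n<|im_start|>assistant\n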