diff --git a/trl/trainer/reward_config.py b/trl/trainer/reward_config.py index 9a3aabc39e..a5b9f2c426 100644 --- a/trl/trainer/reward_config.py +++ b/trl/trainer/reward_config.py @@ -13,41 +13,99 @@ # limitations under the License. from dataclasses import dataclass, field -from typing import Optional +from typing import Any, Optional from transformers import TrainingArguments @dataclass -class RewardConfig(TrainingArguments): +class SFTConfig(TrainingArguments): r""" - Configuration class for the [`RewardTrainer`]. + Configuration class for the [`SFTTrainer`]. - This class includes only the parameters that are specific to Reward training. For a full list of training - arguments, please refer to the [`~transformers.TrainingArguments`] documentation. Note that default values in this - class may differ from those in [`~transformers.TrainingArguments`]. + This class includes only the parameters that are specific to SFT training. For a full list of training arguments, + please refer to the [`~transformers.TrainingArguments`] documentation. Note that default values in this class may + differ from those in [`~transformers.TrainingArguments`]. Using [`~transformers.HfArgumentParser`] we can turn this class into [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the command line. Parameters: - max_length (`int` or `None`, *optional*, defaults to `1024`): - Maximum length of the sequences (prompt + completion) in the batch, filters out entries that exceed the - limit. This argument is required if you want to use the default data collator. - disable_dropout (`bool`, *optional*, defaults to `True`): - Whether to disable dropout in the model. + > Parameters that control the model + + model_init_kwargs (`dict[str, Any]`, *optional*): + Keyword arguments for [`~transformers.AutoModelForCausalLM.from_pretrained`], used when the `model` + argument of the [`SFTTrainer`] is provided as a string. If you're training a MoE architecture and want to + include the load balancing/auxilliary loss as a part of the final loss, remember to set + `output_router_logits=True` in this dictionary. + chat_template_path (`str`, *optional*): + If specified, sets the model's chat template. This can either be the path to a tokenizer (local directory + or Hugging Face Hub model) or a direct path to a Jinja template file. When using a Jinja file, you must + ensure that any special tokens referenced in the template are added to the tokenizer and that the model's + embedding layer is resized accordingly. + + > Parameters that control the data preprocessing + + dataset_text_field (`str`, *optional*, defaults to `"text"`): + Name of the column that contains text data in the dataset. + dataset_kwargs (`dict[str, Any]`, *optional*): + Dictionary of optional keyword arguments for the dataset preparation. The only supported key is + `skip_prepare_dataset`. When the model is a VLM, `skip_prepare_dataset` is automatically treated as `True` + regardless of the provided value, since preprocessing is done on the fly. dataset_num_proc (`int`, *optional*): Number of processes to use for processing the dataset. - center_rewards_coefficient (`float`, *optional*): - Coefficient to incentivize the reward model to output mean-zero rewards (proposed by - https://huggingface.co/papers/2312.09244, Eq. 2). Recommended value: `0.01`. - remove_unused_columns (`bool`, *optional*, defaults to `False`): - Whether to remove the columns that are not used by the model's forward pass. 
Can be `True` only if the - dataset is pretokenized. + eos_token (`str`, *optional*): + Token used to indicate the end of a turn or sequence. If `None`, it defaults to + `processing_class.eos_token`. + pad_token (`int`, *optional*): + Token used for padding. If `None`, it defaults to `processing_class.pad_token`, or if that is also `None`, + it falls back to `processing_class.eos_token`. + max_length (`int` or `None`, *optional*, defaults to `1024`): + Maximum length of the tokenized sequence. Sequences longer than `max_length` are truncated from the right. + If `None`, no truncation is applied. When packing is enabled, this value sets the sequence length. + packing (`bool`, *optional*, defaults to `False`): + Whether to group multiple sequences into fixed-length blocks to improve computational efficiency and reduce + padding. Uses `max_length` to define sequence length. + packing_strategy (`str`, *optional*, defaults to `"bfd"`): + Strategy for packing sequences. Can be either `"bfd"` (best-fit decreasing, default), or `"wrapped"`. + padding_free (`bool`, *optional*, defaults to `False`): + Whether to perform forward passes without padding by flattening all sequences in the batch into a single + continuous sequence. This reduces memory usage by eliminating padding overhead. Currently, this is only + supported with the FlashAttention 2 or 3, which can efficiently handle the flattened batch structure. When + packing is enabled with strategy `"bfd"`, padding-free is enabled, regardless of the value of this + parameter. + pad_to_multiple_of (`int`, *optional*): + If set, the sequences will be padded to a multiple of this value. + eval_packing (`bool`, *optional*): + Whether to pack the eval dataset. If `None`, uses the same value as `packing`. + + > Parameters that control the training + + completion_only_loss (`bool`, *optional*): + Whether to compute loss only on the completion part of the sequence. If set to `True`, loss is computed + only on the completion, which is supported only for [prompt-completion](#prompt-completion) datasets. If + `False`, loss is computed on the entire sequence. If `None` (default), the behavior depends on the dataset: + loss is computed on the completion for [prompt-completion](#prompt-completion) datasets, and on the full + sequence for [language modeling](#language-modeling) datasets. + assistant_only_loss (`bool`, *optional*, defaults to `False`): + Whether to compute loss only on the assistant part of the sequence. If set to `True`, loss is computed only + on the assistant responses, which is supported only for [conversational](#conversational) datasets. If + `False`, loss is computed on the entire sequence. + loss_type (`str`, *optional*, defaults to `"nll"`): + Type of loss to use. Possible values are `"nll"` (negative log-likelihood, default) and `"dft"` (Dynamic + Fine-Tuning, as described in [this paper](https://huggingface.co/papers/2508.05629)). + activation_offloading (`bool`, *optional*, defaults to `False`): + Whether to offload the activations to the CPU. """ + _VALID_DICT_FIELDS = TrainingArguments._VALID_DICT_FIELDS + ["model_init_kwargs"] + # Parameters whose default values are overridden from TrainingArguments + learning_rate: float = field( + default=2e-5, + metadata={"help": "The initial learning rate for AdamW."}, + ) logging_steps: float = field( default=10, metadata={ @@ -70,37 +128,135 @@ class may differ from those in [`~transformers.TrainingArguments`]. 
}, ) - max_length: Optional[int] = field( - default=1024, + # Parameters that control the model + model_init_kwargs: Optional[dict[str, Any]] = field( + default=None, metadata={ - "help": "Maximum length of the sequences (prompt + completion) in the batch, filters out entries that " - "exceed the limit. This argument is required if you want to use the default data collator." + "help": "Keyword arguments for `AutoModelForCausalLM.from_pretrained`, used when the `model` argument of " + "the `SFTTrainer` is provided as a string. If you're training a MoE architecture and want to include the " + "load balancing/auxilliary loss as a part of the final loss, remember to set `output_router_logits=True` " + "in this dictionary." }, ) - disable_dropout: bool = field( - default=True, - metadata={"help": "Whether to disable dropout in the model and reference model."}, + chat_template_path: Optional[str] = field( + default=None, + metadata={ + "help": "If specified, sets the model's chat template. This can either be the path to a tokenizer (local " + "directory or Hugging Face Hub model) or a direct path to a Jinja template file. When using a Jinja file, " + "you must ensure that any special tokens referenced in the template are added to the tokenizer and " + "that the model's embedding layer is resized accordingly." + }, + ) + + # Parameters that control the data preprocessing + dataset_text_field: str = field( + default="text", + metadata={"help": "Name of the column that contains text data in the dataset."}, + ) + dataset_kwargs: Optional[dict[str, Any]] = field( + default=None, + metadata={ + "help": "Dictionary of optional keyword arguments for the dataset preparation. The only supported key is " + "`skip_prepare_dataset`. If the model is a VLM, `skip_prepare_dataset` value is ignored. When the model " + "is a VLM, `skip_prepare_dataset` is automatically treated as `True` regardless of the provided value, " + "since preprocessing is done on the fly." + }, ) dataset_num_proc: Optional[int] = field( default=None, metadata={"help": "Number of processes to use for processing the dataset."}, ) - center_rewards_coefficient: Optional[float] = field( + eos_token: Optional[str] = field( default=None, metadata={ - "help": "Coefficient to incentivize the reward model to output mean-zero rewards (proposed by " - "https://huggingface.co/papers/2312.09244, Eq. 2). Recommended value: `0.01`." + "help": "Token used to indicate the end of a turn or sequence. If `None`, it defaults to `processing_class.eos_token`." }, ) - remove_unused_columns: bool = field( + pad_token: Optional[str] = field( + default=None, + metadata={ + "help": "Token used for padding. If `None`, it defaults to `processing_class.pad_token`, or if that " + "is also `None`, it falls back to `processing_class.eos_token`." + }, + ) + max_length: Optional[int] = field( + default=1024, + metadata={ + "help": "Maximum length of the tokenized sequence. Sequences longer than `max_length` are truncated from" + "the right. If `None`, no truncation is applied. When packing is enabled, this value sets the " + "sequence length." + }, + ) + packing: bool = field( + default=False, + metadata={ + "help": "Whether to group multiple sequences into fixed-length blocks to improve computational efficiency " + "and reduce padding. Uses `max_length` to define sequence length." + }, + ) + packing_strategy: str = field( + default="bfd", + metadata={ + "help": "Strategy for packing sequences. Can be either `'bfd'` (best-fit decreasing, default), or " + "`'wrapped'`." 
+ }, + ) + padding_free: bool = field( + default=False, + metadata={ + "help": "Whether to perform forward passes without padding by flattening all sequences in the batch into " + "a single continuous sequence. This reduces memory usage by eliminating padding overhead. Currently, this " + "is only supported with the FlashAttention 2 or 3, which can efficiently handle the flattened batch " + "structure. When packing is enabled with strategy `'bfd'`, padding-free is enabled, regardless of the " + "value of this parameter." + }, + ) + pad_to_multiple_of: Optional[int] = field( + default=None, + metadata={"help": "If set, the sequences will be padded to a multiple of this value."}, + ) + eval_packing: Optional[bool] = field( + default=None, + metadata={"help": "Whether to pack the eval dataset. If `None`, uses the same value as `packing`."}, + ) + + # Parameters that control the training + completion_only_loss: Optional[bool] = field( + default=None, + metadata={ + "help": ( + "Whether to compute loss only on the completion part of the sequence. If set to `True`, loss is " + "computed only on the completion, which is supported only for prompt-completion datasets. If `False`, " + "loss is computed on the entire sequence. If `None` (default), the behavior depends on the dataset: " + "loss is computed on the completion for prompt-completion datasets, and on the full sequence for " + "language modeling datasets." + ) + }, + ) + assistant_only_loss: bool = field( default=False, metadata={ - "help": "Whether to remove the columns that are not used by the model's forward pass. Can be `True` only " - "if the dataset is pretokenized." + "help": ( + "Whether to compute loss only on the assistant part of the sequence. If set to `True`, loss is " + "computed only on the assistant responses, which is supported only for conversational datasets. If `False`, " + "loss is computed on the entire sequence." + ) + }, + ) + loss_type: str = field( + default="nll", + metadata={ + "help": ( + 'Type of loss to use. Possible values are `"nll"` (negative log-likelihood, default) and `"dft"` ' + "(Dynamic Fine-Tuning, as described in https://huggingface.co/papers/2508.05629)." + ) }, ) + activation_offloading: bool = field( + default=False, + metadata={"help": "Whether to offload the activations to the CPU."}, + ) def __post_init__(self): self.bf16 = not (self.fp16) if self.bf16 is None else self.bf16 - super().__post_init__() diff --git a/trl/trainer/reward_trainer.py b/trl/trainer/reward_trainer.py index 804e4d8bf3..bdaeedfdd2 100644 --- a/trl/trainer/reward_trainer.py +++ b/trl/trainer/reward_trainer.py @@ -12,19 +12,22 @@ # See the License for the specific language governing permissions and # limitations under the License. 
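As a reference for reviewers, here is a minimal sketch of how the `SFTConfig` options defined above combine in practice. It assumes the public `trl` exports (`SFTConfig`, `SFTTrainer`) and a placeholder dataset name; the comments paraphrase the help strings added in this patch rather than guaranteed behavior.

```python
# Minimal sketch (not part of this diff): wiring together the options introduced above.
from datasets import load_dataset
from trl import SFTConfig, SFTTrainer

dataset = load_dataset("trl-lib/Capybara", split="train")  # placeholder dataset

config = SFTConfig(
    output_dir="Qwen2-0.5B-SFT",
    max_length=1024,            # packing blocks also use this length
    packing=True,               # group sequences into fixed-length blocks
    packing_strategy="bfd",     # best-fit decreasing; implies padding-free collation
    loss_type="nll",            # or "dft" for Dynamic Fine-Tuning
    completion_only_loss=None,  # inferred from the dataset type when left as None
)
trainer = SFTTrainer(model="Qwen/Qwen2-0.5B-Instruct", args=config, train_dataset=dataset)
trainer.train()
```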
+import contextlib import os from collections import defaultdict -from dataclasses import FrozenInstanceError, replace +from collections.abc import Mapping +from dataclasses import dataclass from pathlib import Path -from typing import Any, Callable, Optional, Union +from typing import Any, Callable, Optional, TypeVar, Union -import pandas as pd import torch import torch.nn as nn +import transformers from accelerate import PartialState, logging -from accelerate.utils import gather_object -from datasets import Dataset +from datasets import Dataset, IterableDataset from transformers import ( + AutoConfig, + AutoProcessor, BaseImageProcessor, DataCollator, FeatureExtractionMixin, @@ -32,198 +35,837 @@ PreTrainedTokenizerBase, ProcessorMixin, Trainer, + TrainingArguments, is_wandb_available, ) +from transformers.data.data_collator import DataCollatorMixin from transformers.trainer_callback import TrainerCallback -from transformers.trainer_pt_utils import nested_detach from transformers.trainer_utils import EvalPrediction -from transformers.utils import is_peft_available, is_rich_available +from transformers.utils import is_peft_available -from ..data_utils import maybe_apply_chat_template -from ..models import prepare_peft_model -from .reward_config import RewardConfig +from ..data_utils import ( + apply_chat_template, + is_conversational, + is_conversational_from_value, + maybe_convert_to_chatml, + pack_dataset, + prepare_multimodal_messages, + truncate_dataset, +) +from ..models import clone_chat_template, get_act_offloading_ctx_manager, prepare_peft_model +from .sft_config import SFTConfig from .utils import ( - RewardDataCollatorWithPadding, - compute_accuracy, - decode_and_strip_padding, - disable_dropout_in_model, + entropy_from_logits, + flush_left, generate_model_card, get_comet_experiment_url, - log_table_to_comet_experiment, - print_rich_table, + pad, + selective_log_softmax, ) if is_peft_available(): - from peft import PeftModel + from peft import PeftConfig, PeftModel if is_wandb_available(): import wandb - logger = logging.get_logger(__name__) +TListOrMapping = TypeVar("TListOrMapping", list, Mapping) + -def _tokenize(batch: dict[str, list[Any]], tokenizer: "PreTrainedTokenizerBase") -> dict[str, list[Any]]: - """Tokenize a batch from a reward modelling dataset.""" - new_examples = { - "input_ids_chosen": [], - "attention_mask_chosen": [], - "input_ids_rejected": [], - "attention_mask_rejected": [], - } - for chosen, rejected in zip(batch["chosen"], batch["rejected"]): - tokenized_chosen = tokenizer(chosen) - tokenized_rejected = tokenizer(rejected) - new_examples["input_ids_chosen"].append(tokenized_chosen["input_ids"]) - new_examples["attention_mask_chosen"].append(tokenized_chosen["attention_mask"]) - new_examples["input_ids_rejected"].append(tokenized_rejected["input_ids"]) - new_examples["attention_mask_rejected"].append(tokenized_rejected["attention_mask"]) +def remove_none_values(example: TListOrMapping) -> TListOrMapping: + """ + Recursively removes entries with `None` values from a nested structure (list or dictionary). + + Args: + example (`list` or `Mapping`): + Input nested structure (list or dictionary) from which to remove `None`. - return new_examples + Example: + ```python + >>> [ + ... { + ... "a": {"aa": None, "ab": 1}, + ... "b": "my_string", + ... } + ... 
] + >>> remove_none_values(example) + [{'a': {'ab': 1}, 'b': 'my_string'}] + ``` + """ + if isinstance(example, list): + return [remove_none_values(value) if isinstance(value, (dict, list)) else value for value in example] + elif isinstance(example, Mapping): + return { + key: remove_none_values(value) if isinstance(value, (dict, list)) else value + for key, value in example.items() + if value is not None + } + else: + raise TypeError("Input must be a list or a dictionary.") -class RewardTrainer(Trainer): +@dataclass +class DataCollatorForLanguageModeling(DataCollatorMixin): """ - Trainer for custom reward. + Data collator used for language modeling data. Inputs are dynamically padded to the maximum length of a batch. + + This collator expects each example in the input list to be a dictionary containing at least the `"input_ids"` key. + If the input contains a `"completion_mask"`, it is used to set the labels to `-100` for tokens that are not in the + completion. If `"assistant_masks"` are present, they are used to set the labels to `-100` for tokens that are not + in the assistant part of the sequence. The collator returns a dictionary containing the following keys: + - `"input_ids"`: Tensor of input IDs, padded to the maximum length of the batch. + - `"attention_mask"`: Tensor of attention mask, padded to the maximum length of the batch. + - `"position_ids"`: Tensor of position IDs, padded to the maximum length of the batch. + - `"labels"`: Tensor of labels, padded to the maximum length of the batch. If `completion_only_loss` is set to + `True`, tokens that are not in the completion are set to -100. If `assistant_masks` are present, tokens that are + not in the assistant part of the sequence are set to -100. Args: - model ([`~transformers.PreTrainedModel`] or `torch.nn.Module`, *optional*): - Model to be trained, preferably an [`~transformers.AutoModelForSequenceClassification`]. - args ([`RewardConfig`], *optional*): - Training arguments. + pad_token_id (`int`): + Token ID to use for padding. + completion_only_loss (`bool`, *optional*, defaults to `True`): + When the input contains a completion mask (`completion_mask`), the labels are set to -100 for the tokens + that are no in the completion. + padding_free (`bool`, *optional*, defaults to `False`): + If set to `True`, the sequences will be flattened into a single sequence, and the position IDs will be + generated accordingly. + pad_to_multiple_of (`int`, *optional*): + If set, the sequences will be padded to a multiple of this value. + return_tensors (`str`, *optional*, defaults to `"pt"`): + Type of Tensor to return. Only `"pt"` is currently supported. + + Examples: + ```python + >>> from trl.trainer.sft_trainer import DataCollatorForLanguageModeling + + >>> collator = DataCollatorForLanguageModeling(pad_token_id=0) + >>> examples = [{"input_ids": [1, 2, 3]}, {"input_ids": [4, 5]}] + >>> collator(examples) + {'input_ids': tensor([[ 1, 2, 3], + [ 4, 5, 0]]), + 'attention_mask': tensor([[ 1, 1, 1], + [ 1, 1, 0]]), + 'position_ids': tensor([[0, 1, 2], + [0, 1, 0]]), + 'labels': tensor([[ 1, 2, 3], + [ 4, 5, -100]])} + + >>> # With completion mask + >>> examples = [ + ... {"input_ids": [1, 2, 3], "completion_mask": [0, 1, 1]}, + ... {"input_ids": [4, 5], "completion_mask": [0, 1]}, + ... 
] + >>> collator(examples) + {'input_ids': tensor([[ 1, 2, 3], + [ 4, 5, 0]]), + 'attention_mask': tensor([[ 1, 1, 1], + [ 1, 1, 0]]), + 'position_ids': tensor([[0, 1, 2], + [0, 1, 0]]), + 'labels': tensor([[-100, 2, 3], + [-100, 5, -100]])} + + >>> # With padding_free + >>> collator = DataCollatorForLanguageModeling(pad_token_id=0, padding_free=True) + >>> collator(examples) + {'input_ids': tensor([[ 1, 2, 3, 4, 5]]), + 'attention_mask': tensor([[1, 1, 1, 1, 1]]), + 'position_ids': tensor([[0, 1, 2, 0, 1]]), + 'labels': tensor([[1, 2, 3, 4, 5]])} + ``` + """ + + pad_token_id: int + completion_only_loss: bool = True + padding_free: bool = False + return_position_ids: bool = True + pad_to_multiple_of: Optional[int] = None + return_tensors: str = "pt" + + def torch_call(self, examples: list[Union[list[int], Any, dict[str, Any]]]) -> dict[str, Any]: + # Convert to tensor + input_ids = [torch.tensor(example["input_ids"]) for example in examples] + + # Check if we have meaningful seq_lengths from packing (restarting sequences) + has_packed_position_ids = self.return_position_ids and "seq_lengths" in examples[0] and self.padding_free + + # For packing with position_ids, we should NOT create attention_mask as it causes + # FlashAttention to ignore position_ids and compute wrong cu_seq_lens from the all-1s mask + if not has_packed_position_ids: + attention_mask = [torch.ones_like(input_ids) for input_ids in input_ids] + + if self.return_position_ids: + if "seq_lengths" in examples[0]: + position_ids = self.get_position_ids_from_packed_seq_lengths( + [example["seq_lengths"] for example in examples] + ) + else: + position_ids = [torch.arange(len(ids)) for ids in input_ids] + if "labels" in examples[0]: + labels = [torch.tensor(example["labels"]) for example in examples] + else: + labels = [torch.tensor(example["input_ids"]) for example in examples] + if self.completion_only_loss and "completion_mask" in examples[0]: + completion_mask = [torch.tensor(example["completion_mask"]) for example in examples] + if "assistant_masks" in examples[0]: + assistant_masks = [torch.tensor(example["assistant_masks"]) for example in examples] + + # If padding_free, flatten everything into a single sequence + output = {} + if self.padding_free: + input_ids = [torch.cat(input_ids, dim=0)] + if not has_packed_position_ids: + attention_mask = [torch.cat(attention_mask, dim=0)] + if self.return_position_ids: + position_ids = [torch.cat(position_ids, dim=0)] + labels = [torch.cat(labels, dim=0)] + if self.completion_only_loss and "completion_mask" in examples[0]: + completion_mask = [torch.cat(completion_mask, dim=0)] + if "assistant_masks" in examples[0]: + assistant_masks = [torch.cat(assistant_masks, dim=0)] + + # Pad + output["input_ids"] = pad( + input_ids, + padding_value=self.pad_token_id, + padding_side="right", + pad_to_multiple_of=self.pad_to_multiple_of, + ) + if not has_packed_position_ids: + output["attention_mask"] = pad( + attention_mask, padding_value=0, padding_side="right", pad_to_multiple_of=self.pad_to_multiple_of + ) + if self.return_position_ids: + output["position_ids"] = pad( + position_ids, padding_value=0, padding_side="right", pad_to_multiple_of=self.pad_to_multiple_of + ) + output["labels"] = pad( + labels, padding_value=-100, padding_side="right", pad_to_multiple_of=self.pad_to_multiple_of + ) + if self.completion_only_loss and "completion_mask" in examples[0]: + completion_mask = pad( + completion_mask, padding_value=0, padding_side="right", pad_to_multiple_of=self.pad_to_multiple_of + ) + 
output["labels"][completion_mask == 0] = -100 # mask everything that is not in the completion + if "assistant_masks" in examples[0]: + assistant_masks = pad( + assistant_masks, padding_value=0, padding_side="right", pad_to_multiple_of=self.pad_to_multiple_of + ) + output["labels"][assistant_masks == 0] = -100 + return output + + @staticmethod + def get_position_ids_from_packed_seq_lengths(batch_seq_lengths: list[list[int]]) -> list[torch.Tensor]: + """ + Get position IDs for packed sequences. + + Args: + batch_seq_lengths (`list[list[int]]`): + A list of lists containing the lengths of each individual document in the packed batch. + + Return: + `list[torch.Tensor]`: + A list of tensors containing the position IDs for each packed sequence. + """ + # Get lengths per row + example_lengths = [sum(seq_lengths) for seq_lengths in batch_seq_lengths] + # Flat list of lengths + batch_seq_lengths = torch.tensor( + [seq_length for seq_lengths in batch_seq_lengths for seq_length in seq_lengths] + ) + position_ids = torch.ones(sum(example_lengths), dtype=batch_seq_lengths.dtype) + position_ids[0] = 0 + # Reset position ids to 0 at the start of each sequence + position_ids[batch_seq_lengths[:-1].cumsum(0)] = -(batch_seq_lengths[:-1] - 1) + position_ids = position_ids.cumsum(0) + # Split back into one tensor per example + return list(position_ids.split(example_lengths)) + + +@dataclass +class DataCollatorForVisionLanguageModeling(DataCollatorMixin): + """ + Data collator for vision-language modeling tasks. + + Unlike text-only datasets—where the collator typically receives pre-tokenized inputs ready for batching, + vision-language data processing involves converting images into pixel values. This conversion is disk-intensive, + making upfront preprocessing of the entire dataset impractical. Therefore, this collator performs tokenization and + image processing on-the-fly to efficiently prepare batches. + + Each input example should be a dictionary containing at least: + - An `"images"` key holding the image data. + - [language modeling](#language-modeling) type: either a `"messages"` key for conversational inputs or a `"text"` + key for standard text inputs. + - [prompt-completion](#prompt-completion) type: keys `"prompt"` and `"completion"` for the prompt and completion. + + The collator outputs a dictionary including: + - `"input_ids"`: Tensor of token IDs. + - `"attention_mask"`: Tensor indicating attention mask. + - `"pixel_values"`: Tensor representing image pixel values. + - `"labels"`: Tensor for training labels. + + Additional keys may be present depending on the processor, such as `"image_grid_thw"`. + + Args: + processor (`ProcessorMixin`): + The processor used to tokenize text and process images. It must be a subclass of `ProcessorMixin` and + include a `tokenizer` with a defined `pad_token_id`. + max_length (`int` or `None`, optional, defaults to `None`): + Maximum sequence length for input tokens. If `None`, no truncation is applied. + completion_only_loss (`bool`, *optional*, defaults to `False`): + Whether to compute loss only on the completion part of the sequence. When `True`, the labels for the prompt + part are set to -100. It requires the dataset type to be prompt-completion. + pad_to_multiple_of (`int` or `None`, optional, defaults to `None`): + If set, the sequences will be padded to a multiple of this value. + dataset_text_field (`str`, optional, defaults to `"text"`): + Name of the column that contains text data in the dataset. 
This parameter is only relevant for [standard + datasets format](dataset_formats#standard). + return_tensors (`str`, optional, defaults to `"pt"`): + The tensor type to return. Currently, only `"pt"` (PyTorch tensors) is supported. + + Example: + ```python + >>> from trl.trainer.sft_trainer import DataCollatorForVisionLanguageModeling + >>> from transformers import AutoProcessor + + >>> processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct") + >>> collator = DataCollatorForVisionLanguageModeling(processor) + >>> examples = [ + ... {"images": [Image.open("image_0.png")], "messages": [{"role": "user", "content": "What is this?"}]}, + ... {"images": [Image.open("image_1.png")], "messages": [{"role": "user", "content": "Describe this image."}]}, + ... ] + >>> collator(examples) + {'input_ids': tensor([[151644, 8948, 198, 2610, 525, 264, 10950, 17847, 13, 151645, 198, + 151644, 872, 198, 151652, 151655, 151655, 151655, 151655, 151653, 3838, 374, + 419, 30, 151645, 198], + [151644, 8948, 198, 2610, 525, 264, 10950, 17847, 13, 151645, 198, + 151644, 872, 198, 151652, 151655, 151655, 151655, 151655, 151653, 74785, 419, + 2168, 13, 151645, 198]]), + 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]), + 'pixel_values': tensor([[-0.9893, 0.1785, 1.5362, ..., -0.0582, 0.8661, -0.2431], + [-0.2302, 0.9522, -1.1061, ..., 0.0555, 1.3354, -0.6412], + [ 1.2150, 0.9084, 0.7041, ..., 0.2404, -0.8403, -0.5133], + ..., + [ 0.6895, 0.2807, 0.2515, ..., -0.2004, -1.2100, 0.0555], + [ 0.8209, -0.9748, 1.5654, ..., 1.6055, -0.4706, 0.5817], + [-1.0915, 0.4559, 0.9230, ..., 0.5106, 0.0982, -0.1720]]), + 'image_grid_thw': tensor([[1, 4, 4], + [1, 4, 4]]), + 'labels': tensor([[151644, 8948, 198, 2610, 525, 264, 10950, 17847, 13, 151645, 198, + 151644, 872, 198, 151652, 151655, 151655, 151655, 151655, 151653, 3838, 374, + 419, 30, 151645, 198], + [151644, 8948, 198, 2610, 525, 264, 10950, 17847, 13, 151645, 198, + 151644, 872, 198, 151652, 151655, 151655, 151655, 151655, 151653, 74785, 419, + 2168, 13, 151645, 198]])} + ``` + """ + + processor: ProcessorMixin + max_length: Optional[int] = None + completion_only_loss: bool = False # default not used in practice; SFTTrainer always passes the relevant value + pad_to_multiple_of: Optional[int] = None + dataset_text_field: str = "text" + return_tensors: str = "pt" + + def torch_call(self, examples: list[Union[list[int], Any, dict[str, Any]]]) -> dict[str, Any]: + if "messages" in examples[0] or self.dataset_text_field in examples[0]: + if self.completion_only_loss: + raise ValueError( + "The `completion_only_loss` argument is not supported for language modeling datasets." 
+ ) + return self._collate_language_modeling(examples) + elif "prompt" in examples[0] and "completion" in examples[0]: + return self._collate_prompt_completion(examples) + else: + raise KeyError(f"Unexpected input keys in examples: {list(examples[0].keys())}.") + + def _collate_language_modeling(self, examples: list[Union[list[int], Any, dict[str, Any]]]) -> dict[str, Any]: + images = [example["images"] for example in examples] + + if "messages" in examples[0]: # conversational case + for example in examples: + prepare_multimodal_messages(example["messages"], len(example["images"])) + messages = [example["messages"] for example in examples] + texts = self.processor.apply_chat_template(messages) + elif self.dataset_text_field in examples[0]: # standard case + texts = [example[self.dataset_text_field] for example in examples] + else: + raise KeyError( + "The input examples must contain either 'messages' for conversational data or 'text' for standard " + "data." + ) + + output = self.processor( + images=images, + text=texts, + padding=True, + padding_side="right", + pad_to_multiple_of=self.pad_to_multiple_of, + truncation=self.max_length is not None, + max_length=self.max_length, + return_tensors=self.return_tensors, + add_special_tokens=False, # to avoid adding the BOS, twice see https://huggingface.co/blog/qgallouedec/gotchas-in-tokenizer-behavior#7-chat-template-and-tokenization-dont-compose-due-to-special-tokens + ) + labels = output["input_ids"].clone() + labels[output["attention_mask"] == 0] = -100 + # We mask only padding tokens (-100) in the labels. Vision tokens are left unchanged because their handling in + # loss computation has to be done by the model, and masking them here would be infeasible in practice as vision + # token definitions vary across architectures. + output["labels"] = labels + return output + + def _collate_prompt_completion(self, examples: list[Union[list[int], Any, dict[str, Any]]]) -> dict[str, Any]: + if self.pad_to_multiple_of is not None: + raise NotImplementedError( + "Padding to a multiple of a value is not yet implemented for vision-language modeling and " + "prompt-completion data yet." 
+ ) + images = [example["images"] for example in examples] + if is_conversational(examples[0]): # conversational case + for example in examples: + prepare_multimodal_messages(example["prompt"] + example["completion"], len(example["images"])) + examples = [apply_chat_template(example, self.processor) for example in examples] + + prompts = [example["prompt"] for example in examples] + completions = [example["completion"] for example in examples] + + processed_prompts = self.processor( + images=images, + text=prompts, + padding=True, + padding_side="left", + return_tensors=self.return_tensors, + add_special_tokens=False, # to avoid adding the BOS, twice see https://huggingface.co/blog/qgallouedec/gotchas-in-tokenizer-behavior#7-chat-template-and-tokenization-dont-compose-due-to-special-tokens + ) + processed_completions = self.processor( + text=completions, + padding=True, + padding_side="right", + return_tensors=self.return_tensors, + add_special_tokens=False, # to avoid adding the BOS, twice see https://huggingface.co/blog/qgallouedec/gotchas-in-tokenizer-behavior#7-chat-template-and-tokenization-dont-compose-due-to-special-tokens + ) + + # Concatenate prompts and completions + prompt_ids, completion_ids = processed_prompts["input_ids"], processed_completions["input_ids"] + prompt_mask, completion_mask = processed_prompts["attention_mask"], processed_completions["attention_mask"] + input_ids = torch.cat((prompt_ids, completion_ids), dim=1) + attention_mask = torch.cat((prompt_mask, completion_mask), dim=1) + completion_mask = torch.cat((torch.zeros_like(prompt_mask), completion_mask), dim=1) + + # Flush left to reduce padding + attention_mask, input_ids, completion_mask = flush_left(attention_mask, input_ids, completion_mask) + + # Truncate if necessary + if self.max_length is not None: + input_ids = input_ids[:, : self.max_length] + attention_mask = attention_mask[:, : self.max_length] + completion_mask = completion_mask[:, : self.max_length] + + # Create labels and mask padding tokens + labels = input_ids.clone() + labels[attention_mask == 0] = -100 + if self.completion_only_loss: + labels[completion_mask == 0] = -100 + + # Build the output dictionary + output = processed_prompts # we take processed_prompts because it contains the images + output["input_ids"] = input_ids + output["attention_mask"] = attention_mask + output["labels"] = labels + return output + + +def dft_loss(outputs, labels, num_items_in_batch): + """ + DFT loss function, as presented in [On the Generalization of SFT: A Reinforcement Learning Perspective with Reward + Rectification](https://huggingface.co/papers/2508.05629) + """ + labels = nn.functional.pad(labels, (0, 1), value=-100) + shift_labels = labels[..., 1:].contiguous() + loss_mask = shift_labels != -100 + shift_labels[~loss_mask] = 0 + logprobs = selective_log_softmax(outputs.logits, shift_labels) + per_token_loss = -logprobs.exp().detach() * logprobs + loss = (per_token_loss * loss_mask).sum() / num_items_in_batch + return loss + + +class SFTTrainer(Trainer): + """ + Trainer for Supervised Fine-Tuning (SFT) method. + + This class is a wrapper around the [`~transformers.Trainer`] class and inherits all of its attributes and methods. 
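Since `dft_loss` above is the core of the new `loss_type="dft"` option, a small standalone illustration of the same math may help review. This is a sketch with toy tensors, not TRL internals: it skips the label shifting and the `num_items_in_batch` normalization, and simply contrasts the plain NLL term with the probability-weighted DFT term.

```python
# Rough illustration (assumed toy shapes): DFT reweights the standard per-token
# NLL term -log p(y_t) by the detached predicted probability p(y_t).
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(1, 4, 8)             # (batch, seq, vocab) for a toy 8-token vocab
labels = torch.tensor([[2, 5, 1, -100]])  # last position is ignored

logprobs = F.log_softmax(logits, dim=-1)
mask = labels != -100
target = labels.clamp(min=0)
tok_logp = logprobs.gather(-1, target.unsqueeze(-1)).squeeze(-1)  # log p(y_t) per token

nll = -(tok_logp * mask).sum() / mask.sum()
dft = -(tok_logp.exp().detach() * tok_logp * mask).sum() / mask.sum()
print(nll, dft)  # DFT gives less gradient weight to tokens the model currently finds unlikely
```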
+ + Example: + + ```python + from datasets import load_dataset + from trl import SFTTrainer + + dataset = load_dataset("roneneldan/TinyStories", split="train[:1%]") + + trainer = SFTTrainer(model="Qwen/Qwen2-0.5B-Instruct", train_dataset=dataset) + trainer.train() + ``` + + Args: + model (`Union[str, PreTrainedModel]`): + Model to be trained. Can be either: + + - A string, being the *model id* of a pretrained model hosted inside a model repo on huggingface.co, or a + path to a *directory* containing model weights saved using + [`~transformers.PreTrainedModel.save_pretrained`], e.g., `'./my_model_directory/'`. The model is loaded + using `.from_pretrained` (where `` is derived from the model + config) with the keyword arguments in `args.model_init_kwargs`. + - A [`~transformers.PreTrainedModel`] object. + If you're training a model with an MoE architecture and want to include the load balancing/auxilliary loss + as a part of the final loss, remember to set the `output_router_logits` config of the model to `True`. + args ([`SFTConfig`], *optional*): + Configuration for this trainer. If `None`, a default configuration is used. data_collator ([`~transformers.DataCollator`], *optional*): - The data collator to use for training. If None is specified, the default data collator - [`~trainer.utils.RewardDataCollatorWithPadding`] will be used which will pad the sequences to the maximum - length of the sequences in the batch, given a dataset of paired sequences. - train_dataset ([`~datasets.Dataset`], *optional*): - The dataset to use for training. - eval_dataset ([`~datasets.Dataset`], *optional*): - The dataset to use for evaluation. - processing_class ([`~transformers.PreTrainedTokenizerBase`], [`~transformers.BaseImageProcessor`], [`~transformers.FeatureExtractionMixin`] or [`~transformers.ProcessorMixin`], *optional*): - Processing class used to process the data. If provided, will be used to automatically process the inputs - for the model, and it will be saved along the model to make it easier to rerun an interrupted training or - reuse the fine-tuned model. - model_init (`Callable[[], transformers.PreTrainedModel]`, *optional*): - The model initializer to use for training. If None is specified, the default model initializer will be - used. - compute_metrics (`Callable[[transformers.EvalPrediction], dict]`, *optional*, defaults to [`~trainer.utils.compute_accuracy`]): - Function to compute metrics at evaluation. Must take in an [`~transformers.EvalPrediction`] and return a - dictionary string to float. - callbacks (`list` of [`~transformers.TrainerCallback`], *optional*): - Callbacks to use during training. - optimizers (`tuple` of `torch.optim.Optimizer` and `torch.optim.lr_scheduler.LambdaLR`, *optional*, defaults to `(None, None)`): - Tuple containing the optimizer and the learning rate scheduler to use for training. + Function to use to form a batch from a list of elements of the processed `train_dataset` or `eval_dataset`. + Will default to [`~trainer.sft_trainer.DataCollatorForLanguageModeling`] if the model is a language model + and [`~trainer.sft_trainer.DataCollatorForVisionLanguageModeling`] if the model is a vision-language model. + train_dataset ([`~datasets.Dataset`] or [`~datasets.IterableDataset`]): + Dataset to use for training. SFT supports both [language modeling](#language-modeling) type and + [prompt-completion](#prompt-completion) type. The format of the samples can be either: + + - [Standard](dataset_formats#standard): Each sample contains plain text. 
+ - [Conversational](dataset_formats#conversational): Each sample contains structured messages (e.g., role + and content). + + The trainer also supports processed datasets (tokenized) as long as they contain an `input_ids` field. + eval_dataset ([`~datasets.Dataset`], [`~datasets.IterableDataset`] or `dict[str, Union[Dataset, IterableDataset]]`): + Dataset to use for evaluation. It must meet the same requirements as `train_dataset`. + processing_class ([`~transformers.PreTrainedTokenizerBase`], [`~transformers.ProcessorMixin`], *optional*): + Processing class used to process the data. If `None`, the processing class is loaded from the model's name + with [`~transformers.AutoProcessor.from_pretrained`]. A padding token, `tokenizer.pad_token`, must be set. + If the processing class has not set a padding token, `tokenizer.eos_token` will be used as the default. + compute_loss_func (`Callable`, *optional*): + A function that accepts the raw model outputs, labels, and the number of items in the entire accumulated + batch (batch_size * gradient_accumulation_steps) and returns the loss. For example, see the default [loss + function](https://github.com/huggingface/transformers/blob/052e652d6d53c2b26ffde87e039b723949a53493/src/transformers/trainer.py#L3618) + used by [`Trainer`]. + compute_metrics (`Callable[[EvalPrediction], dict]`, *optional*): + The function that will be used to compute metrics at evaluation. Must take a + [`~transformers.EvalPrediction`] and return a dictionary string to metric values. When passing + [`SFTConfig`] with `batch_eval_metrics` set to `True`, your `compute_metrics` function must take a boolean + `compute_result` argument. This will be triggered after the last eval batch to signal that the function + needs to calculate and return the global summary statistics rather than accumulating the batch-level + statistics. + callbacks (list of [`~transformers.TrainerCallback`], *optional*): + List of callbacks to customize the training loop. Will add those to the list of default callbacks detailed + in [here](https://huggingface.co/docs/transformers/main_classes/callback). + + If you want to remove one of the default callbacks used, use the [`~transformers.Trainer.remove_callback`] + method. + optimizers (`tuple[Optional[torch.optim.Optimizer], Optional[torch.optim.lr_scheduler.LambdaLR]]`, *optional*, defaults to `(None, None)`): + A tuple containing the optimizer and the scheduler to use. Will default to an instance of `AdamW` on your + model and a scheduler given by [`~transformers.get_linear_schedule_with_warmup`] controlled by `args`. + optimizer_cls_and_kwargs (`tuple[Type[torch.optim.Optimizer], Dict[str, Any]]`, *optional*): + A tuple containing the optimizer class and keyword arguments to use. Overrides `optim` and `optim_args` in + `args`. Incompatible with the `optimizers` argument. + + Unlike `optimizers`, this argument avoids the need to place model parameters on the correct devices before + initializing the Trainer. preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`, *optional*): - Function to preprocess the logits before computing the metrics. Must take in the `logits` and `labels` and - return the logits to be used for metrics computation. - peft_config (`dict`, *optional*): - PEFT configuration to use PEFT for training. If `None`, PEFT is not used. If provided, the `model` will be - wrapped with the specified PEFT adapter. + A function that preprocess the logits right before caching them at each evaluation step. 
Must take two + tensors, the logits and the labels, and return the logits once processed as desired. The modifications made + by this function will be reflected in the predictions received by `compute_metrics`. + + Note that the labels (second parameter) will be `None` if the dataset does not have them. + peft_config ([`~peft.PeftConfig`], *optional*): + PEFT configuration used to wrap the model. If `None`, the model is not wrapped. + formatting_func (`Callable`, *optional*): + Formatting function applied to the dataset before tokenization. Applying the formatting function explicitly + converts the dataset into a [language modeling](#language-modeling) type. """ - _tag_names = ["trl", "reward-trainer"] + _tag_names = ["trl", "sft"] def __init__( self, - model: Optional[Union[PreTrainedModel, nn.Module]] = None, - args: Optional[RewardConfig] = None, - data_collator: Optional[DataCollator] = None, - train_dataset: Optional[Dataset] = None, + model: Union[str, nn.Module, PreTrainedModel], + args: Optional[Union[SFTConfig, TrainingArguments]] = None, + data_collator: Optional[DataCollator] = None, # type: ignore + train_dataset: Optional[Union[Dataset, IterableDataset]] = None, eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None, - processing_class: Optional[ - Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin] - ] = None, - model_init: Optional[Callable[[], PreTrainedModel]] = None, + processing_class: Optional[Union[PreTrainedTokenizerBase, ProcessorMixin]] = None, + compute_loss_func: Optional[Callable] = None, compute_metrics: Optional[Callable[[EvalPrediction], dict]] = None, callbacks: Optional[list[TrainerCallback]] = None, - optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = ( - None, - None, - ), + optimizers: tuple[Optional[torch.optim.Optimizer], Optional[torch.optim.lr_scheduler.LambdaLR]] = (None, None), + optimizer_cls_and_kwargs: Optional[tuple[type[torch.optim.Optimizer], dict[str, Any]]] = None, preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None, - peft_config: Optional[dict] = None, + peft_config: Optional["PeftConfig"] = None, + formatting_func: Optional[Callable[[dict], str]] = None, ): - if peft_config is not None or (is_peft_available() and isinstance(model, PeftModel)): - model = prepare_peft_model(model, peft_config, args) + # Args + if args is None: + model_name = model if isinstance(model, str) else model.config._name_or_path + model_name = model_name.split("/")[-1] + args = SFTConfig(f"{model_name}-SFT") + elif isinstance(args, TrainingArguments) and not isinstance(args, SFTConfig): + dict_args = args.to_dict() + dict_args["hub_token"] = args.hub_token # to_dict hides the hub_token + dict_args.pop("push_to_hub_token") + args = SFTConfig(**dict_args) - # Disable dropout in the model - if args.disable_dropout: - disable_dropout_in_model(model) + # Model + model_init_kwargs = args.model_init_kwargs or {} + if isinstance(model, str): + model_id = model + dtype = model_init_kwargs.get("dtype") + if isinstance(dtype, torch.dtype) or dtype == "auto" or dtype is None: + pass # dtype is already a torch.dtype or "auto" or None + elif isinstance(dtype, str) and dtype in ["bfloat16", "float16", "float32"]: + dtype = getattr(torch, dtype) + model_init_kwargs["dtype"] = dtype + else: + raise ValueError( + "Invalid `dtype` passed to `SFTConfig`. 
Expected either 'auto' or a string representing " + f"a valid `torch.dtype` (e.g., 'float32'), but got {dtype}." + ) + config = AutoConfig.from_pretrained(model_id) + architecture = getattr(transformers, config.architectures[0]) + model = architecture.from_pretrained(model_id, **model_init_kwargs) + else: + model_id = model.config._name_or_path + if args.model_init_kwargs is not None: + logger.warning( + "You passed `model_init_kwargs` to the `SFTConfig`, but your model is already instantiated. " + "The `model_init_kwargs` will be ignored." + ) + + # Processing class + if processing_class is None: + processing_class = AutoProcessor.from_pretrained(model_id) - if compute_metrics is None: - compute_metrics = compute_accuracy + # Handle pad token for processors or tokenizers + if isinstance(processing_class, ProcessorMixin): + tokenizer = processing_class.tokenizer + self._is_vlm = True + elif isinstance(processing_class, PreTrainedTokenizerBase): + tokenizer = processing_class + self._is_vlm = False + else: + raise TypeError("The `processing_class` must be either a `PreTrainedTokenizerBase` or a `ProcessorMixin`") - if data_collator is None: - if processing_class is None: + if args.eos_token is not None: + eos_token = args.eos_token + eos_token_id = tokenizer.convert_tokens_to_ids(eos_token) + if eos_token_id is None: raise ValueError( - "A processing_class must be specified when using the default RewardDataCollatorWithPadding" + f"The specified `eos_token` ('{eos_token}') is not found in the vocabulary of the given " + f"`processing_class` ({processing_class.__class__.__name__}). Ensure that the `eos_token` exists " + "in the vocabulary before using it as an EOS token." ) + tokenizer.eos_token_id = eos_token_id - max_length = args.max_length + if args.chat_template_path is not None: + if os.path.isfile(args.chat_template_path) and args.chat_template_path.endswith((".jinja", ".j2")): + with open(args.chat_template_path, encoding="utf-8") as chat_template_file: + processing_class.chat_template = chat_template_file.read() + added_tokens = [] + else: + model, processing_class, added_tokens = clone_chat_template( + model, processing_class, args.chat_template_path + ) + else: + added_tokens = [] - data_collator = RewardDataCollatorWithPadding(processing_class) + # Catch some wrong configurations related to VLMs + if self._is_vlm and args.packing: + raise ValueError( + "Packing is not supported for vision-language models. Please set `packing=False` in the SFTConfig." + ) + if self._is_vlm and args.padding_free: + raise ValueError( + "Padding-free training is yet not supported for vision-language models. Please set " + "`padding_free=False` in the `SFTConfig`." + ) + if self._is_vlm and args.assistant_only_loss: + raise ValueError( + "Assistant-only loss is not yet supported for vision-language models. Please set " + "`assistant_only_loss=False` in the `SFTConfig`." 
+ ) + + # PEFT configuration and model wrapping + if peft_config is not None: + if added_tokens: + # Ensure that the added tokens are trainable + if peft_config.trainable_token_indices is None: + peft_config.trainable_token_indices = {"embed_tokens": added_tokens} + elif "embed_tokens" not in peft_config.trainable_token_indices: + peft_config.trainable_token_indices["embed_tokens"] = added_tokens + else: + peft_config.trainable_token_indices["embed_tokens"].extend(added_tokens) + + # Ensure that the lm_head is trainable + if peft_config.modules_to_save is None or "lm_head" not in peft_config.modules_to_save: + logger.warning( + "Cloning chat template added new tokens to the tokenizer, but 'lm_head' is not in PEFT's " + "`modules_to_save`. As a result, the model may not learn to generate outputs with these new " + "tokens, leading to degraded generation quality. To fix this, add " + "`modules_to_save=['lm_head']` to your PEFT configuration." + ) + + if peft_config.modules_to_save is None: + peft_config.modules_to_save = ["lm_head"] + else: + peft_config.modules_to_save.append("lm_head") + + # In Prompt Tuning a small set of trainable virtual tokens (continuous prompt embeddings) is prepended to the + # input. We store the number of these tokens so we can account for them correctly when calculating accuracy. + self.num_virtual_tokens = 0 - if args.remove_unused_columns: - try: # for bc before https://github.com/huggingface/transformers/pull/25435 - args.remove_unused_columns = False - except FrozenInstanceError: - args = replace(args, remove_unused_columns=False) - # warn users + if peft_config is not None or (is_peft_available() and isinstance(model, PeftModel)): + model = prepare_peft_model(model, peft_config, args) + if model.active_adapter in model.peft_config: + peft_model_config = model.peft_config[model.active_adapter] + self.num_virtual_tokens = getattr(peft_model_config, "num_virtual_tokens", 0) + + # Data collator + # BFD packing requires padding-free mode; otherwise, the collator outputs padded attention masks, causing + # FlashAttention to ignore position_ids and recompute them incorrectly from the padded attention mask. + self.padding_free = args.padding_free or (args.packing and args.packing_strategy == "bfd") + use_flash_attention = model.config._attn_implementation in [ + "flash_attention_2", + "flash_attention_3", + "kernels-community/vllm-flash-attn3", + ] + if self.padding_free: + if data_collator is not None: + raise ValueError("Passing a custom data collator is not supported when using padding-free.") + if args.packing and args.packing_strategy == "wrapped": + logger.warning( + "You are passing `padding_free=True` with the 'wrapped' packing strategy, which is not " + "recommended. Please refer to the documentation to understand why this is not recommended." + ) + if not use_flash_attention: logger.warning( - "When using RewardDataCollatorWithPadding, you should set `remove_unused_columns=False` in your RewardConfig" - " we have set it for you, but you should do it yourself in the future.", + "Padding-free training is enabled, but the attention implementation is not set to " + "'flash_attention_2'. Padding-free training flattens batches into a single sequence, and " + "'flash_attention_2' is the only known attention mechanism that reliably supports this. Using " + "other implementations may lead to unexpected behavior. 
To ensure compatibility, set " + "`attn_implementation='flash_attention_2'` in the model configuration, or verify that your " + "attention mechanism can handle flattened sequences." + ) + if args.per_device_train_batch_size == 1 and not args.packing: + logger.warning( + "You are using a per_device_train_batch_size of 1 with padding-free training. Using a batch size " + "of 1 anihilate the benefits of padding-free training. Please consider increasing the batch size " + "to at least 2." ) - self.use_reward_data_collator = True + # Decide whether to use completion-only loss: if not specified, then it is set to True if the dataset format + # is prompt-completion, and False if the dataset format is language modeling. + dataset_sample = next(iter(train_dataset)) + if args.completion_only_loss is None: + self.completion_only_loss = "prompt" in dataset_sample and "completion" in dataset_sample else: - self.use_reward_data_collator = False - - # The trainer estimates the number of FLOPs (floating-point operations) using the number of elements in the - # input tensor associated with the key "input_ids". However, in Reward, the sampled data does not include the - # "input_ids" key. Instead, the available keys are "input_ids_chosen" and "input_ids_rejected". As a result, - # the trainer issues the warning: "Could not estimate the number of tokens of the input, floating-point - # operations will not be computed." To suppress this warning, we set the "estimate_tokens" key in the model's - # "warnings_issued" dictionary to True. This acts as a flag to indicate that the warning has already been - # issued. - model.warnings_issued["estimate_tokens"] = True - - if "input_ids_chosen" not in train_dataset.column_names: - with PartialState().main_process_first(): - fn_kwargs = {"tokenizer": processing_class} - train_dataset = train_dataset.map(maybe_apply_chat_template, fn_kwargs={"tokenizer": processing_class}) - train_dataset = train_dataset.map( - _tokenize, - batched=True, - fn_kwargs=fn_kwargs, - num_proc=args.dataset_num_proc, + self.completion_only_loss = args.completion_only_loss + + if data_collator is None and not self._is_vlm: + # Get the pad token: if not provided, use the one from the processing class or the eos token + # if the processing class does not have a pad token. + pad_token = args.pad_token or tokenizer.pad_token or tokenizer.eos_token + pad_token_id = tokenizer.convert_tokens_to_ids(pad_token) + if pad_token_id is None: + raise ValueError( + f"The specified `pad_token` ('{pad_token}') is not found in the vocabulary of the given " + f"`processing_class` ({processing_class.__class__.__name__}). Ensure that the `pad_token` exists " + "in the vocabulary before using it as a padding token." ) - # This filter is important because otherwise you get samples that exceed the model's context length and - # get truncated => noisy signal the chosen/rejected label gets lost. The downside is that the - # user might get surprised if N samples are missing from training. 
- train_dataset = train_dataset.filter( - lambda x: len(x["input_ids_chosen"]) <= max_length and len(x["input_ids_rejected"]) <= max_length, - num_proc=args.dataset_num_proc, + data_collator = DataCollatorForLanguageModeling( + pad_token_id=pad_token_id, + completion_only_loss=self.completion_only_loss, + padding_free=self.padding_free, + # Using position_ids without flash_attn hurts the training + return_position_ids=use_flash_attention, + pad_to_multiple_of=args.pad_to_multiple_of, + ) + elif data_collator is None and self._is_vlm: + data_collator = DataCollatorForVisionLanguageModeling( + processor=processing_class, + max_length=args.max_length, + completion_only_loss=self.completion_only_loss, + pad_to_multiple_of=args.pad_to_multiple_of, + dataset_text_field=args.dataset_text_field, + ) + + if args.packing and args.packing_strategy == "bfd" and not use_flash_attention: + logger.warning( + "You are using packing, but the attention implementation is not set to 'flash_attention_2' or " + "'kernels-community/vllm-flash-attn3'. Packing flattens batches into a single sequence, and Flash " + "Attention is the only known attention mechanisms that reliably support this. Using other " + "implementations may lead to cross-contamination between batches. To avoid this, either disable " + "packing by setting `packing=False`, or set `attn_implementation='flash_attention_2'` or " + "`attn_implementation='kernels-community/vllm-flash-attn3'` in the model configuration." + ) + if args.assistant_only_loss and not is_conversational(dataset_sample): + raise ValueError( + "You set `assistant_only_loss=True`, but the dataset is not conversational. This option is only " + "supported for conversational datasets." + ) + + # Dataset + # Skip dataset preparation if `skip_prepare_dataset=True` in `dataset_kwargs`, or if it's a VLM, where + # preprocessing (e.g., image-to-pixel conversion) is too costly and done on the fly instead. + skip_prepare_dataset = ( + args.dataset_kwargs is not None and args.dataset_kwargs.get("skip_prepare_dataset", False) or self._is_vlm + ) + if not skip_prepare_dataset: + if self.completion_only_loss and formatting_func: + raise ValueError( + "A formatting function was provided while `completion_only_loss=True`, which is incompatible. " + "Using a formatter converts the dataset to a language modeling type, conflicting with " + "completion-only loss. To resolve this, apply your formatting function before passing the " + "dataset, or disable `completion_only_loss` in `SFTConfig`." ) - if eval_dataset is not None: - eval_dataset = eval_dataset.map( - maybe_apply_chat_template, fn_kwargs={"tokenizer": processing_class} - ) - eval_dataset = eval_dataset.map( - _tokenize, - fn_kwargs=fn_kwargs, - batched=True, - num_proc=args.dataset_num_proc, - ) - # This filter is important because otherwise you get samples that exceed the model's context length and - # get truncated => noisy signal the chosen/rejected label gets lost. The downside is that the - # user might get surprised if N samples are missing from training. 
- eval_dataset = eval_dataset.filter( - lambda x: len(x["input_ids_chosen"]) <= max_length - and len(x["input_ids_rejected"]) <= max_length, - num_proc=args.dataset_num_proc, + train_dataset = self._prepare_dataset( + train_dataset, processing_class, args, args.packing, formatting_func, "train" + ) + if eval_dataset is not None: + packing = args.packing if args.eval_packing is None else args.eval_packing + if isinstance(eval_dataset, dict): + eval_dataset = { + key: self._prepare_dataset(dataset, processing_class, args, packing, formatting_func, key) + for key, dataset in eval_dataset.items() + } + else: + eval_dataset = self._prepare_dataset( + eval_dataset, processing_class, args, packing, formatting_func, "eval" ) + # Loss function + if args.loss_type == "nll": + pass # use the default loss + elif args.loss_type == "dft": + if compute_loss_func is not None: + raise ValueError( + "You passed a `compute_loss_func` together with `loss_type='dft'` to the `SFTTrainer`. " + "When using `loss_type='dft'`, the loss function is internally set to the DFT loss, so passing a " + "`compute_loss_func` is not allowed." + ) + compute_loss_func = dft_loss + else: + raise ValueError(f"Invalid `loss_type` {args.loss_type} passed. Supported values are 'nll' and 'dft'.") + + # Initialize the metrics + self._metrics = {"train": defaultdict(list), "eval": defaultdict(list)} + self._total_train_tokens = 0 + + # Initialize the Trainer. Parent class will handle: + # - DeepSpeed configuration (through create_accelerator_and_postprocess) + # - FSDP setup + # - Distributed training setup + # - Optimizer and scheduler creation + super().__init__( model=model, args=args, @@ -231,124 +873,322 @@ def __init__( train_dataset=train_dataset, eval_dataset=eval_dataset, processing_class=processing_class, - model_init=model_init, + compute_loss_func=compute_loss_func, compute_metrics=compute_metrics, callbacks=callbacks, optimizers=optimizers, + optimizer_cls_and_kwargs=optimizer_cls_and_kwargs, preprocess_logits_for_metrics=preprocess_logits_for_metrics, ) + # Initialize activation offloading context + if self.args.activation_offloading: + self.maybe_activation_offload_context = get_act_offloading_ctx_manager(model=self.model) + else: + self.maybe_activation_offload_context = contextlib.nullcontext() + # Add tags for models that have been loaded with the correct transformers version if hasattr(self.model, "add_model_tags"): self.model.add_model_tags(self._tag_names) - def compute_loss( + self.aux_loss_enabled = getattr(model.config, "output_router_logits", False) + + def _prepare_dataset( self, - model: Union[PreTrainedModel, nn.Module], - inputs: dict[str, Union[torch.Tensor, Any]], - return_outputs=False, - num_items_in_batch=None, - ) -> Union[torch.Tensor, tuple[torch.Tensor, dict[str, torch.Tensor]]]: - rewards_chosen = model( - input_ids=inputs["input_ids_chosen"], - attention_mask=inputs["attention_mask_chosen"], - return_dict=True, - )["logits"] - rewards_rejected = model( - input_ids=inputs["input_ids_rejected"], - attention_mask=inputs["attention_mask_rejected"], - return_dict=True, - )["logits"] - # calculate loss, optionally modulate with margin - if "margin" in inputs: - loss = -nn.functional.logsigmoid(rewards_chosen - rewards_rejected - inputs["margin"]).mean() - else: - loss = -nn.functional.logsigmoid(rewards_chosen - rewards_rejected).mean() + dataset: Union[Dataset, IterableDataset], + processing_class: Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin], + args: 
SFTConfig, + packing: bool, + formatting_func: Optional[Callable[[dict], str]], + dataset_name: str, + ) -> Union[Dataset, IterableDataset]: + # Tabular backends like Arrow/Parquet insert `None` for mismatched keys in nested structures. Clean them from + # sampled data. + if isinstance(dataset, Dataset): # IterableDataset does not support `with_transform` + dataset = dataset.with_transform(remove_none_values) - if self.args.center_rewards_coefficient is not None: - loss += self.args.center_rewards_coefficient * torch.mean((rewards_chosen + rewards_rejected) ** 2) + # If the dataset is already preprocessed (tokenized), skip the processing steps. + column_names = list(next(iter(dataset)).keys()) + is_processed = "input_ids" in column_names - if return_outputs: - return loss, { - "rewards_chosen": rewards_chosen, - "rewards_rejected": rewards_rejected, - } - return loss + # Build the kwargs for the `map` function + map_kwargs = {} + if isinstance(dataset, Dataset): # IterableDataset does not support num_proc + map_kwargs["num_proc"] = args.dataset_num_proc - def prediction_step( - self, - model: Union[PreTrainedModel, nn.Module], - inputs: dict[str, Union[torch.Tensor, Any]], - prediction_loss_only: bool, - ignore_keys: Optional[list[str]] = None, - ) -> tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]: - inputs = self._prepare_inputs(inputs) - if ignore_keys is None: - if hasattr(self.model, "config"): - ignore_keys = getattr(self.model.config, "keys_to_ignore_at_inference", []) - else: - ignore_keys = [] + with PartialState().main_process_first(): + # Apply the formatting function if any + if formatting_func is not None and is_processed: + logger.warning( + "You passed a dataset that is already processed (contains an `input_ids` field) together with a " + "formatting function. Therefore `formatting_func` will be ignored. 
Either remove the " + "`formatting_func` or pass a dataset that is not already processed.", + ) - with torch.no_grad(): - loss, logits_dict = self.compute_loss(model, inputs, return_outputs=True) + if formatting_func is not None and not is_processed: + if isinstance(dataset, Dataset): # `IterableDataset.map` does not support `desc` + map_kwargs["desc"] = f"Applying formatting function to {dataset_name} dataset" - if prediction_loss_only: - return (loss, None, None) + def _func(example): + return {"text": formatting_func(example)} - loss = loss.detach() - logits = tuple(v for k, v in logits_dict.items() if k not in ignore_keys) - logits = nested_detach(logits) - # Stack accepted against rejected, mean over logits - # and softmax to get preferences between accepted and rejected to sum to 1 - logits = torch.stack(logits).mean(dim=2).softmax(dim=0).T + dataset = dataset.map(_func, batched=False, **map_kwargs) - labels = torch.zeros(logits.shape[0]) - labels = self._prepare_inputs(labels) + if not is_processed: + # Convert the dataset to ChatML if needed + first_example = next(iter(dataset)) + if is_conversational_from_value(first_example): + if isinstance(dataset, Dataset): # `IterableDataset.map` does not support `desc` + map_kwargs["desc"] = f"Converting {dataset_name} dataset to ChatML" + column_names = next(iter(dataset)).keys() + dataset = dataset.map( + maybe_convert_to_chatml, + remove_columns="conversations" if "conversations" in column_names else None, + **map_kwargs, + ) - return loss, logits, labels + # Apply the chat template if needed + first_example = next(iter(dataset)) + if not is_conversational(first_example): + if isinstance(dataset, Dataset): # `IterableDataset.map` does not support `desc` + map_kwargs["desc"] = f"Adding EOS to {dataset_name} dataset" - def evaluate(self, *args, **kwargs): - num_print_samples = kwargs.pop("num_print_samples", 4) - self.visualize_samples(num_print_samples) - return super().evaluate(*args, **kwargs) + def add_eos(example, eos_token): + if "text" in example and not example["text"].endswith(eos_token): # language modeling case + example["text"] = example["text"] + eos_token + elif "completion" in example and not example["completion"].endswith(eos_token): + example["completion"] = example["completion"] + eos_token + return example - def visualize_samples(self, num_print_samples: int): - """ - Visualize the reward model logits prediction + dataset = dataset.map( + add_eos, + fn_kwargs={"eos_token": processing_class.eos_token}, + remove_columns="messages" if "messages" in column_names else None, # renamed to "text" + **map_kwargs, + ) - Args: - num_print_samples (`int`, defaults to `4`): - The number of samples to print. Set to `-1` to print all samples. 
- """ - eval_dataloader = self.get_eval_dataloader() - table = defaultdict(list) - for _, inputs in enumerate(eval_dataloader): - _, logits, _ = self.prediction_step(self.model, inputs, prediction_loss_only=False) - chosen_text = decode_and_strip_padding(inputs["input_ids_chosen"], self.processing_class) - rejected_text = decode_and_strip_padding(inputs["input_ids_rejected"], self.processing_class) - table["chosen_text"].extend(gather_object(chosen_text)) - table["rejected_text"].extend(gather_object(rejected_text)) - table["logits"].extend( - gather_object([[round(inner_item, 4) for inner_item in item] for item in logits.tolist()]) - ) - if num_print_samples >= 0 and len(table["chosen_text"]) >= num_print_samples: - break - df = pd.DataFrame(table) - if self.accelerator.process_index == 0: - if is_rich_available(): - print_rich_table(df[:num_print_samples]) - if "wandb" in self.args.report_to: - import wandb - - if wandb.run is not None: - wandb.log({"completions": wandb.Table(dataframe=df)}) - - if "comet_ml" in self.args.report_to: - log_table_to_comet_experiment( - name="completions.csv", - table=df, + # Tokenize the dataset + if isinstance(dataset, Dataset): # `IterableDataset.map` does not support `desc` + map_kwargs["desc"] = f"Tokenizing {dataset_name} dataset" + + def tokenize(example, processing_class, dataset_text_field, assistant_only_loss): + if "prompt" in example: # prompt-completion case + output = {} + if is_conversational(example): + prompt_ids = processing_class.apply_chat_template( + example["prompt"], + tools=example.get("tools"), + **example.get("chat_template_kwargs", {}), + ) + prompt_completion_processed = processing_class.apply_chat_template( + example["prompt"] + example["completion"], + return_dict=True, + return_assistant_tokens_mask=assistant_only_loss, + tools=example.get("tools"), + **example.get("chat_template_kwargs", {}), + ) + prompt_completion_ids = prompt_completion_processed["input_ids"] + if "assistant_masks" in prompt_completion_processed: + output["assistant_masks"] = prompt_completion_processed["assistant_masks"] + else: + prompt_ids = processing_class(text=example["prompt"])["input_ids"] + prompt_completion_ids = processing_class(text=example["prompt"] + example["completion"])[ + "input_ids" + ] + + # Check if the tokenized prompt starts with the tokenized prompt+completion + if not prompt_completion_ids[: len(prompt_ids)] == prompt_ids: + logger.warning( + "Mismatch between tokenized prompt and the start of tokenized prompt+completion. " + "This may be due to unexpected tokenizer behavior, whitespace issues, or special " + "token handling. Verify that the tokenizer is processing text consistently." + ) + + # Create a completion mask + completion_mask = [0] * len(prompt_ids) + [1] * (len(prompt_completion_ids) - len(prompt_ids)) + output["input_ids"] = prompt_completion_ids + output["completion_mask"] = completion_mask + + else: # language modeling case + if is_conversational(example): + processed = processing_class.apply_chat_template( + example["messages"], + return_dict=True, + return_assistant_tokens_mask=assistant_only_loss, + tools=example.get("tools"), + **example.get("chat_template_kwargs", {}), + ) + if "assistant_masks" in processed and 1 not in processed["assistant_masks"]: + raise RuntimeError( + "You're using `assistant_only_loss=True`, but at least one example has no " + "assistant tokens. This usually means the tokenizer's chat template doesn't " + "generate assistant masks — it may be missing the `{% generation %}` keyword. 
Please " + "check the template and ensure it's correctly configured to support assistant " + "masking." + ) + output = {k: processed[k] for k in ("input_ids", "assistant_masks") if k in processed} + else: + output = {"input_ids": processing_class(text=example[dataset_text_field])["input_ids"]} + return output + + dataset = dataset.map( + tokenize, + fn_kwargs={ + "processing_class": processing_class, + "dataset_text_field": args.dataset_text_field, + "assistant_only_loss": args.assistant_only_loss, + }, + **map_kwargs, ) + # Pack or truncate + if packing: + if args.max_length is None: + raise ValueError("When packing is enabled, `max_length` can't be `None`.") + if isinstance(dataset, Dataset): # `IterableDataset.map` does not support `desc` + map_kwargs["desc"] = f"Packing {dataset_name} dataset" + + columns = ["input_ids"] + if "completion_mask" in dataset.column_names: + columns.append("completion_mask") + if "assistant_masks" in dataset.column_names: + columns.append("assistant_masks") + + dataset = dataset.select_columns(columns) + + # Packing adds new column "seq_lengths" needed for document aware FlashAttention + dataset = pack_dataset(dataset, args.max_length, args.packing_strategy, map_kwargs) + elif args.max_length is not None: + if isinstance(dataset, Dataset): # `IterableDataset.map` does not support `desc` + map_kwargs["desc"] = f"Truncating {dataset_name} dataset" + dataset = truncate_dataset(dataset, args.max_length, map_kwargs) + # For Liger kernel, ensure only the essential columns + if args.use_liger_kernel: + collator_expected_keys = {"input_ids", "seq_lengths", "completion_mask", "assistant_masks"} + dataset = dataset.select_columns(collator_expected_keys.intersection(dataset.column_names)) + + return dataset + + def _set_signature_columns_if_needed(self): + # If `self.args.remove_unused_columns` is True, non-signature columns are removed. + # By default, this method sets `self._signature_columns` to the model's expected inputs (usually, "input_ids" + # and "attention_mask"). When using `train_on_completion_only` we add a "completion_mask" column to the + # dataset. So we need to override the default signature columns to include "completion_mask" as well. + if self._signature_columns is None: + if self._is_vlm: + self._signature_columns = ["messages", "prompt", "completion", "images"] + else: + self._signature_columns = ["input_ids", "labels", "seq_lengths", "completion_mask", "assistant_masks"] + + def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None): + """ + Compute training loss and additionally compute token accuracies + """ + mode = "train" if self.model.training else "eval" + + # Set aside labels as it will be dropped by super().compute_loss() if a custom `compute_loss_func` is used. + # This can be removed when this issue is fixed. + labels = inputs["labels"] + + # If not set, defaults from model config and may warn since cache isn't compatible with gradient checkpointing + inputs["use_cache"] = False + (loss, outputs) = super().compute_loss( + model, inputs, return_outputs=True, num_items_in_batch=num_items_in_batch + ) + + # Compute entropy + if not self.args.use_liger_kernel: # liger doesn't return logits + with torch.no_grad(): + per_token_entropy = entropy_from_logits(outputs.logits) + if "attention_mask" in inputs: + attention_mask = inputs["attention_mask"] + # When using Prompt Tuning, we need to add attention for the virtual tokens (all set to 1). 
+ virtual_attention_mask = torch.ones( + attention_mask.size(0), self.num_virtual_tokens, device=attention_mask.device + ) + attention_mask = torch.cat((virtual_attention_mask, attention_mask), dim=1) + entropy = torch.sum(per_token_entropy * attention_mask) / attention_mask.sum() + elif "position_ids" in inputs: + entropy = torch.mean(per_token_entropy) + else: + raise ValueError("Expected 'attention_mask' or 'position_ids' in inputs.") + entropy = self.accelerator.gather_for_metrics(entropy).mean().item() + self._metrics[mode]["entropy"].append(entropy) + + if mode == "train": + # When using padding-free, the attention_mask is not present in the inputs, instead we have cu_seq_lens_q, + # cu_seq_lens_k, and max_length_k, max_length_q and position_ids. + if "attention_mask" in inputs: + num_tokens_in_batch = self.accelerator.gather_for_metrics(inputs["attention_mask"].sum()).sum().item() + elif "position_ids" in inputs: + local_num_tokens = torch.tensor(inputs["position_ids"].size(1), device=inputs["position_ids"].device) + num_tokens_in_batch = self.accelerator.gather_for_metrics(local_num_tokens).sum().item() + else: + raise ValueError("Expected 'attention_mask' or 'position_ids' in inputs.") + self._total_train_tokens += num_tokens_in_batch + self._metrics[mode]["num_tokens"] = [self._total_train_tokens] + + # Compute token accuracy if we have labels and if the model is not using Liger (no logits) + if not self.args.use_liger_kernel: + with torch.no_grad(): + if "shift_labels" in inputs: + # When using CP, labels are pre-shifted. We must use these (and cannot manually shift) because: + # - The first discarded token from inputs["labels"] actually belongs to process n-1 + # - The last logits require the label from process n+1 + shift_logits = outputs.logits.contiguous() + shift_labels = inputs["shift_labels"] + else: + shift_logits = outputs.logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + + # When using Prompt Tuning, skip the virtual tokens in logits before accuracy computation, since they do + # not correspond to actual input labels. + shift_logits = shift_logits[:, self.num_virtual_tokens :, :] + + # Get predictions + predictions = shift_logits.argmax(dim=-1) + + # Create mask for non-padding tokens (assuming ignore_index is -100) + mask = shift_labels != -100 + + # Calculate accuracy only on non-padding tokens + correct_predictions = (predictions == shift_labels) & mask + total_tokens = mask.sum() + correct_tokens = correct_predictions.sum() + + # Gather the correct_tokens and total_tokens across all processes + correct_tokens = self.accelerator.gather_for_metrics(correct_tokens) + total_tokens = self.accelerator.gather_for_metrics(total_tokens) + + # Compute the mean token accuracy and log it + total_sum = total_tokens.sum() + accuracy = (correct_tokens.sum() / total_sum).item() if total_sum > 0 else 0.0 + self._metrics[mode]["mean_token_accuracy"].append(accuracy) + if self.aux_loss_enabled: + aux_loss = outputs.aux_loss + aux_loss = self.accelerator.gather_for_metrics(aux_loss).mean().item() + self._metrics[mode]["aux_loss"].append(aux_loss) + + return (loss, outputs) if return_outputs else loss + + # Override training step to add activation offloading context. 
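
The override below is effectively free when offloading is disabled: `maybe_activation_offload_context` was set in `__init__` to either the activation-offloading manager or `contextlib.nullcontext()`, so the `with` block always exists but only does work when `activation_offloading=True`. A minimal sketch of that pattern in isolation (this `training_step` is a toy stand-in, not the trainer's method):

import contextlib

def training_step(offload_ctx=None):
    # Use the provided offloading context when there is one, otherwise a do-nothing context,
    # so the step body is written once and wrapped unconditionally.
    with offload_ctx or contextlib.nullcontext():
        return "stepped"

print(training_step())  # "stepped", with no offloading overhead
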
+ def training_step(self, *args, **kwargs): + with self.maybe_activation_offload_context: + return super().training_step(*args, **kwargs) + + def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> None: + mode = "train" if self.model.training else "eval" + metrics = {key: sum(val) / len(val) for key, val in self._metrics[mode].items()} # average the metrics + + # This method can be called both in training and evaluation. When called in evaluation, the keys in `logs` + # start with "eval_". We need to add the prefix "eval_" to the keys in `metrics` to match the format. + if mode == "eval": + metrics = {f"eval_{key}": val for key, val in metrics.items()} + + logs.update(metrics) + super().log(logs, start_time) + self._metrics[mode].clear() + # Ensure the model card is saved along with the checkpoint def _save_checkpoint(self, model, trial): if self.args.hub_model_id is None: @@ -404,10 +1244,10 @@ def create_model_card( model_name=model_name, hub_model_id=self.hub_model_id, dataset_name=dataset_name, - tags=tags, + tags=list(tags), wandb_url=wandb.run.url if is_wandb_available() and wandb.run is not None else None, comet_url=get_comet_experiment_url(), - trainer_name="Reward", + trainer_name="SFT", ) model_card.save(os.path.join(self.args.output_dir, "README.md"))
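
Most of the behavior added in this diff is driven by `SFTConfig` fields. A minimal end-to-end sketch, assuming a tiny prompt-completion dataset; the dataset contents, model id, and output directory are illustrative and not part of the change:

from datasets import Dataset
from trl import SFTConfig, SFTTrainer

# Made-up prompt-completion data: completion_only_loss=None resolves to True for this format.
train_dataset = Dataset.from_dict(
    {
        "prompt": ["What is the capital of France?"],
        "completion": [" Paris."],
    }
)

config = SFTConfig(
    output_dir="sft-demo",        # illustrative output folder
    max_length=1024,              # used for truncation, and as the block size when packing
    packing=False,                # set True to group sequences into fixed-length blocks
    completion_only_loss=None,    # None: inferred from the dataset format
    loss_type="nll",              # or "dft"
    activation_offloading=False,  # set True to offload activations to the CPU
)

trainer = SFTTrainer(
    model="Qwen/Qwen2.5-0.5B-Instruct",  # any causal LM checkpoint; this id is only an example
    args=config,
    train_dataset=train_dataset,
)
trainer.train()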