diff --git a/trl/trainer/dpo_config.py b/trl/trainer/dpo_config.py index befb0d29216..f2474fa1461 100644 --- a/trl/trainer/dpo_config.py +++ b/trl/trainer/dpo_config.py @@ -13,23 +13,11 @@ # limitations under the License. from dataclasses import dataclass, field -from enum import Enum -from typing import Any, Callable, Optional, Union +from typing import Any, Optional from transformers import TrainingArguments -class FDivergenceType(Enum): - REVERSE_KL = "reverse_kl" - JS_DIVERGENCE = "js_divergence" - ALPHA_DIVERGENCE = "alpha_divergence" - - -class FDivergenceConstants: - ALPHA_DIVERGENCE_COEF_KEY = "alpha_divergence_coef" - ALPHA_DIVERGENCE_COEF_DEFAULT = 1.0 - - @dataclass class DPOConfig(TrainingArguments): r""" @@ -44,155 +32,44 @@ class DPOConfig(TrainingArguments): command line. Parameters: - > Parameters that control the model and reference model + > Parameters that control the model model_init_kwargs (`dict[str, Any]` or `None`, *optional*, defaults to `None`): - Keyword arguments for `AutoModelForCausalLM.from_pretrained`, used when the `model` argument of the - [`DPOTrainer`] is provided as a string. - ref_model_init_kwargs (`dict[str, Any]` or `None`, *optional*, defaults to `None`): - Keyword arguments for `AutoModelForCausalLM.from_pretrained`, used when the `ref_model` argument of the - [`DPOTrainer`] is provided as a string. - model_adapter_name (`str` or `None`, *optional*, defaults to `None`): - Name of the train target PEFT adapter, when using LoRA with multiple adapters. - ref_adapter_name (`str` or `None`, *optional*, defaults to `None`): - Name of the reference PEFT adapter, when using LoRA with multiple adapters. - force_use_ref_model (`bool`, *optional*, defaults to `False`): - If you provide a PEFT model as the active model and wish to use a different model for the `ref_model`, set - this flag to `True`. - disable_dropout (`bool`, *optional*, defaults to `True`): - Whether to disable dropout in the model and reference model. - use_logits_to_keep (`bool`, *optional*, defaults to `False`): - If `True`, only a specified number of logits are computed in the forward pass. This can be useful for - saving memory and speeding up training by not computing the logits for all tokens, especially in scenarios - when working with very long prompts where labels are ignored (-100). + Keyword arguments for [`~transformers.AutoModelForCausalLM.from_pretrained`], used when the `model` + argument of the [`DPOTrainer`] is provided as a string. > Parameters that control the data preprocessing dataset_num_proc (`int` or `None`, *optional*, defaults to `None`): Number of processes to use for processing the dataset. - padding_value (`int` or `None`, *optional*, defaults to `None`): - Padding value to use. If `None`, the padding value of the tokenizer is used. - label_pad_token_id (`int`, *optional*, defaults to `-100`): - Padding value to use for labels. + pad_token (`int` or `None`, *optional*, defaults to `None`): + Token used for padding. If `None`, it defaults to `processing_class.pad_token`, or if that is also `None`, + it falls back to `processing_class.eos_token`. max_prompt_length (`int` or `None`, *optional*, defaults to `512`): - Maximum length of the prompt. + Maximum length of the prompt part of the sequence. If `None`, no truncation is applied. max_completion_length (`int` or `None`, *optional*, defaults to `None`): - Maximum length of the completion. + Maximum length of the completion part of the sequence. If `None`, no truncation is applied. max_length (`int` or `None`, *optional*, defaults to `1024`): - Maximum length of the full sequence (prompt + completion). - truncation_mode (`str`, *optional*, defaults to `"keep_end"`): - Truncation mode to use when the sequence exceeds `max_length`. Possible values are `"keep_end"` and - `"keep_start"`. + Maximum length of the tokenized sequence. Sequences longer than `max_length` are truncated from the right. + If `None`, no truncation is applied. padding_free (`bool`, *optional*, defaults to `False`): Whether to perform forward passes without padding by flattening all sequences in the batch into a single continuous sequence. This reduces memory usage by eliminating padding overhead. Currently, this is only - supported with the `flash_attention_2` attention implementation, which can efficiently handle the flattened - batch structure. - precompute_ref_log_probs (`bool`, *optional*, defaults to `False`): - Whether to precompute the log probabilities from the reference model. Setting this to `True` allows - training without needing the reference model during training, which can help reduce GPU memory usage. If - set to `False` (default), the reference model will be used during training to compute log probabilities - on-the-fly. - precompute_ref_batch_size (`int` or `None`, *optional*, defaults to `None`): - Batch size to use when precomputing reference model log probabilities. This can be set higher than the - training batch size to speed up preprocessing. If `None`, defaults to `per_device_train_batch_size` for - training and `per_device_eval_batch_size` for evaluation. - tools (`Optional[list[Union[dict, Callable]]]`, *optional*, defaults to `None`): - List of tools (callable functions) that will be accessible to the model. If the template does not support - function calling, this argument will have no effect. + supported with the FlashAttention 2 or 3, which can efficiently handle the flattened batch structure. + pad_to_multiple_of (`int` or `None`, *optional*, defaults to `None`): + If set, the sequences will be padded to a multiple of this value. > Parameters that control the training - loss_type (`str` or `list[str]`, *optional*, defaults to `"sigmoid"`): - Type of loss to use. Possible values are: - - - `"sigmoid"`: sigmoid loss from the original [DPO](https://huggingface.co/papers/2305.18290) paper. - - `"hinge"`: hinge loss on the normalized likelihood from the - [SLiC](https://huggingface.co/papers/2305.10425) paper. - - `"ipo"`: IPO loss from the [IPO](https://huggingface.co/papers/2310.12036) paper. - - `"exo_pair"`: pairwise EXO loss from the [EXO](https://huggingface.co/papers/2402.00856) paper. - - `"nca_pair"`: pairwise NCA loss from the [NCA](https://huggingface.co/papers/2402.05369) paper. - - `"robust"`: unbiased estimate of the DPO loss that is robust to preference noise from the [Robust - DPO](https://huggingface.co/papers/2403.00409) paper. - - `"bco_pair"`: pairwise BCO loss from the [BCO](https://huggingface.co/papers/2404.04656) paper. - - `"sppo_hard"`: SPPO loss with hard label from the [SPPO](https://huggingface.co/papers/2405.00675) - paper. - - `"aot"`: AOT loss for paired datasets from the [AOT](https://huggingface.co/papers/2406.05882) paper. - - `"aot_pair"`: AOT loss for unpaired datasets from the [AOT](https://huggingface.co/papers/2406.05882) - paper. - - `"discopop"`: DiscoPOP (a.k.a Log-Ratio Modulated Loss, LRML) loss from the - [DiscoPOP](https://huggingface.co/papers/2406.08414) paper. - - `"apo_zero"`: APO-zero loss from the [APO](https://huggingface.co/papers/2408.06266) paper. - - `"apo_down"`: APO-down loss from the [APO](https://huggingface.co/papers/2408.06266) paper. - - `"sft"`: Negative log-likelihood loss (standard supervised fine-tuning loss). - - Multiple loss types can be combined using comma separation (e.g., `["sigmoid", "bco_pair", "sft"]` for - [MPO](https://huggingface.co/papers/2411.10442)). The `loss_weights` parameter can be used to specify - corresponding weights for each loss type. - - use_liger_loss (`bool`, *optional*, defaults to `False`): - Whether to use Liger loss. - base_model_attribute_name (`str`, *optional*, defaults to `"model"`): - Name of the attribute in the model that contains the base model. This is used to get the base model from - the model when the model does not have a `get_decoder` method in the case when `use_liger_loss` is `True`. - beta (`float`, *optional*, defaults to `0.1`): - Parameter controlling the deviation from the reference model. Higher β means less deviation from the - reference model. For the IPO loss (`loss_type="ipo"`), β is the regularization parameter denoted by τ in - the [paper](https://huggingface.co/papers/2310.12036). - f_divergence_type (`str`, *optional*, defaults to `FDivergenceType.REVERSE_KL`): - Type of f-divergence regularization function to compute divergence between policy and reference model. - f_alpha_divergence_coef (`float`, *optional*, defaults to `1.0`): - α coefficient in the α-divergence u^-α regularization function for DPO loss. - reference_free (`bool`, *optional*, defaults to `False`): - Whether to ignore the provided reference model and implicitly use a reference model that assigns equal - probability to all responses. - label_smoothing (`float`, *optional*, defaults to `0.0`): - Robust DPO label smoothing parameter from the [cDPO report](https://ericmitchell.ai/cdpo.pdf) and [Robust - DPO](https://huggingface.co/papers/2403.00409) paper that should be between `0.0` and `0.5`. - use_weighting (`bool`, *optional*, defaults to `False`): - Whether to weight the loss as done in the [WPO paper](https://huggingface.co/papers/2406.11827). - rpo_alpha (`float`, *optional*, defaults to `None`): - α parameter from the [RPO paper](https://huggingface.co/papers/2404.19733) (v3), which controls the - weighting of the NLL term in the loss. If `None`, no weighting is applied and the loss is the same as the - DPO loss. The paper recommends `rpo_alpha=1.0`. - ld_alpha (`float` or `None`, *optional*, defaults to `None`): - α parameter from the [LD-DPO paper](https://huggingface.co/papers/2409.06411), which controls the weighting - of the verbose token log-probabilities in responses. If `None`, no weighting is applied to the verbose - part, and the loss is equivalent to the standard DPO loss. The paper recommends setting `ld_alpha` between - `0.0` and `1.0`. - discopop_tau (`float`, *optional*, defaults to `0.05`): - τ/temperature parameter from the [DiscoPOP](https://huggingface.co/papers/2406.08414) paper, which controls - the shape of log ratio modulated loss. The paper recommends the default value `discopop_tau=0.05`. - loss_weights (`list[float]` or `None`, *optional*, defaults to `None`): - List of loss weights for multi-loss combinations. Used when combining multiple loss types. Example: `[0.8, - 0.2, 1.0]` for [MPO](https://huggingface.co/papers/2411.10442). If not provided, defaults to equal weights - (`1.0`) for all loss types. - sync_ref_model (`bool`, *optional*, defaults to `False`): - Whether to synchronize the reference model with the active model every `ref_model_sync_steps` steps, using - the `ref_model_mixup_alpha` parameter. This synchronization originates from the - [TR-DPO](https://huggingface.co/papers/2404.09656) paper. - ref_model_mixup_alpha (`float`, *optional*, defaults to `0.6`): - α parameter from the [TR-DPO](https://huggingface.co/papers/2404.09656) paper, which controls the mix - between the current policy and the previous reference policy during updates. The reference policy is - updated according to the equation: `π_ref = α * π_θ + (1 - α) * π_ref_prev`. To use this parameter, you - must set `sync_ref_model=True`. - ref_model_sync_steps (`int`, *optional*, defaults to `512`): - τ parameter from the [TR-DPO](https://huggingface.co/papers/2404.09656) paper, which determines how - frequently the current policy is synchronized with the reference policy. To use this parameter, you must - set `sync_ref_model=True`. - - > Parameters that control the logging - - generate_during_eval (`bool`, *optional*, defaults to `False`): - Whether to generate and log completions from both the model and the reference model to W&B or Comet during - evaluation. + activation_offloading (`bool`, *optional*, defaults to `False`): + Whether to offload the activations to the CPU. """ - _VALID_DICT_FIELDS = TrainingArguments._VALID_DICT_FIELDS + ["model_init_kwargs", "ref_model_init_kwargs"] + _VALID_DICT_FIELDS = TrainingArguments._VALID_DICT_FIELDS + ["model_init_kwargs"] # Parameters whose default values are overridden from TrainingArguments learning_rate: float = field( - default=1e-6, + default=2e-5, metadata={"help": "The initial learning rate for AdamW."}, ) logging_steps: float = field( @@ -216,8 +93,18 @@ class DPOConfig(TrainingArguments): "`fp16` is not set." }, ) + # Note: In transformers>=4.54.0, `average_tokens_across_devices` defaults to True. Overriding this setting is only + # needed for earlier versions. Once we require transformers>=4.54.0, this line can be safely removed. + # See https://github.com/huggingface/transformers/pull/39395 + average_tokens_across_devices: bool = field( + default=True, + metadata={ + "help": "Whether or not to average tokens across devices. If enabled, will use all_reduce to synchronize " + "num_tokens_in_batch for precise loss calculation. Reference: https://github.com/huggingface/transformers/issues/34242 " + }, + ) - # Parameters that control the model and reference model + # Parameters that control the model model_init_kwargs: Optional[dict[str, Any]] = field( default=None, metadata={ @@ -225,243 +112,58 @@ class DPOConfig(TrainingArguments): "the `DPOTrainer` is provided as a string." }, ) - ref_model_init_kwargs: Optional[dict[str, Any]] = field( - default=None, - metadata={ - "help": "Keyword arguments for `AutoModelForCausalLM.from_pretrained`, used when the `ref_model` argument " - "of the `DPOTrainer` is provided as a string." - }, - ) - model_adapter_name: Optional[str] = field( - default=None, - metadata={"help": "Name of the train target PEFT adapter, when using LoRA with multiple adapters."}, - ) - ref_adapter_name: Optional[str] = field( - default=None, - metadata={"help": "Name of the reference PEFT adapter, when using LoRA with multiple adapters."}, - ) - force_use_ref_model: bool = field( - default=False, - metadata={ - "help": "If you provide a PEFT model as the active model and wish to use a different model for the " - "`ref_model`, set this flag to `True`." - }, - ) - disable_dropout: bool = field( - default=True, - metadata={"help": "Whether to disable dropout in the model and reference model."}, - ) - use_logits_to_keep: bool = field( - default=False, - metadata={ - "help": "If `True`, only a specified number of logits are computed in the forward pass. This can be " - "useful for saving memory and speeding up training by not computing the logits for all tokens, especially " - "in scenarios when working with very long prompts where labels are ignored (-100)." - }, - ) # Parameters that control the data preprocessing dataset_num_proc: Optional[int] = field( default=None, metadata={"help": "Number of processes to use for processing the dataset."}, ) - padding_value: Optional[int] = field( + pad_token: Optional[str] = field( default=None, - metadata={"help": "Padding value to use. If `None`, the padding value of the tokenizer is used."}, - ) - label_pad_token_id: int = field( - default=-100, - metadata={"help": "Padding value to use for labels."}, + metadata={ + "help": "Token used for padding. If `None`, it defaults to `processing_class.pad_token`, or if that " + "is also `None`, it falls back to `processing_class.eos_token`." + }, ) max_prompt_length: Optional[int] = field( default=512, - metadata={"help": "Maximum length of the prompt."}, + metadata={ + "help": "Maximum length of the prompt part of the sequence. If `None`, no truncation is applied. When packing is enabled, this value sets the prompt length." + }, ) max_completion_length: Optional[int] = field( default=None, - metadata={"help": "Maximum length of the completion."}, + metadata={ + "help": "Maximum length of the completion part of the sequence. If `None`, no truncation is applied. When packing is enabled, this value sets the completion length." + }, ) max_length: Optional[int] = field( default=1024, - metadata={"help": "Maximum length of the full sequence (prompt + completion)."}, - ) - truncation_mode: str = field( - default="keep_end", metadata={ - "help": "Truncation mode to use when the sequence exceeds `max_length`. Possible values are `'keep_end'` " - "and `'keep_start'`.", - "choices": ["keep_end", "keep_start"], + "help": "Maximum length of the tokenized sequence. Sequences longer than `max_length` are truncated from" + "the right. If `None`, no truncation is applied." }, ) padding_free: bool = field( default=False, metadata={ "help": "Whether to perform forward passes without padding by flattening all sequences in the batch into " - "a single continuous sequence. This reduces memory usage by eliminating padding overhead. Currently, " - "this is only supported with the `flash_attention_2` attention implementation, which can efficiently " - "handle the flattened batch structure." + "a single continuous sequence. This reduces memory usage by eliminating padding overhead. Currently, this " + "is only supported with the FlashAttention 2 or 3, which can efficiently handle the flattened batch " + "structure." }, ) - precompute_ref_log_probs: bool = field( - default=False, - metadata={ - "help": "Whether to precompute the log probabilities from the reference model. Setting this to `True` " - "allows training without needing the reference model during training, which can help reduce GPU memory " - "usage. If set to `False` (default), the reference model will be used during training to compute log " - "probabilities on-the-fly." - }, - ) - precompute_ref_batch_size: Optional[int] = field( + pad_to_multiple_of: Optional[int] = field( default=None, - metadata={ - "help": "Batch size to use when precomputing reference model log probabilities. This can be set higher " - "than the training batch size to speed up preprocessing. If `None`, defaults to " - "`per_device_train_batch_size` for training and `per_device_eval_batch_size` for evaluation." - }, - ) - tools: Optional[list[Union[dict, Callable]]] = field( - default=None, - metadata={ - "help": "List of tools (callable functions) that will be accessible to the model. If the template does " - "not support function calling, this argument will have no effect." - }, + metadata={"help": "If set, the sequences will be padded to a multiple of this value."}, ) # Parameters that control the training - loss_type: list[str] = field( - default_factory=lambda: ["sigmoid"], - metadata={ - "help": "Type of loss to use. Possible values are: `'sigmoid'`, `'hinge'`, `'ipo'`, `'exo_pair'`, " - "`'nca_pair'`, `'robust'`, `'bco_pair'`, `'sppo_hard'`, `'aot'`, `'aot_pair'`, `'discopop'`, " - "`'apo_zero'`, `'apo_down'` and `'sft'`. Multiple loss types can be combined using comma separation " - "(e.g., `['sigmoid', 'bco_pair', 'sft']` for MPO). The `loss_weights` parameter can be used to specify " - "corresponding weights for each loss type." - }, - ) - use_liger_loss: bool = field( + activation_offloading: bool = field( default=False, - metadata={"help": "Whether to use Liger loss."}, - ) - base_model_attribute_name: str = field( - default="model", - metadata={ - "help": "Name of the attribute in the model that contains the base model. This is used to get the base " - "model from the model when the model does not have a `get_decoder` method in the case when " - "`use_liger_loss` is `True`." - }, - ) - beta: float = field( - default=0.1, - metadata={ - "help": "Parameter controlling the deviation from the reference model. " - "Higher β means less deviation from the reference model." - }, - ) - f_divergence_type: FDivergenceType = field( - default=FDivergenceType.REVERSE_KL, - metadata={ - "help": "Type of f-divergence regularization function to compute divergence between policy and reference " - "model." - }, - ) - f_alpha_divergence_coef: float = field( - default=1.0, - metadata={"help": "α coefficient in the α-divergence u^-α regularization function for DPO loss."}, - ) - reference_free: bool = field( - default=False, - metadata={ - "help": "Whether to ignore the provided reference model and implicitly use a reference model that assigns " - "equal probability to all responses." - }, - ) - label_smoothing: float = field( - default=0.0, - metadata={ - "help": "Robust DPO label smoothing parameter from the cDPO report and Robust DPO paper that should " - "be between `0.0` and `0.5`." - }, - ) - use_weighting: bool = field( - default=False, - metadata={"help": "Whether to weight the loss as done in the WPO paper."}, - ) - rpo_alpha: Optional[float] = field( - default=None, - metadata={ - "help": "α parameter from the RPO paper (v3), which controls the weighting of the NLL term in the loss. " - "If `None`, no weighting is applied and the loss is the same as the DPO loss. The paper recommends " - "`rpo_alpha=1.0`." - }, - ) - ld_alpha: Optional[float] = field( - default=None, - metadata={ - "help": "α parameter from the LD-DPO paper, which controls the weighting of the verbose token " - "log-probabilities in responses. If `None`, no weighting is applied to the verbose part, and the loss is " - "equivalent to the standard DPO loss. The paper recommends setting `ld_alpha` between `0.0` and `1.0`.", - }, - ) - discopop_tau: float = field( - default=0.05, - metadata={ - "help": "τ/temperature parameter from the DiscoPOP paper, which controls the shape of log ratio modulated " - "loss. The paper recommends the default value `discopop_tau=0.05`." - }, - ) - loss_weights: Optional[list[float]] = field( - default=None, - metadata={ - "help": "List of loss weights for multi-loss combinations. Used when combining multiple loss types. " - "Example: `[0.8, 0.2, 1.0]` for MPO. If not provided, defaults to equal weights (`1.0`) for all loss " - "types." - }, - ) - sync_ref_model: bool = field( - default=False, - metadata={ - "help": "Whether to synchronize the reference model with the active model every `ref_model_sync_steps` " - "steps, using the `ref_model_mixup_alpha` parameter." - }, - ) - ref_model_mixup_alpha: float = field( - default=0.6, - metadata={ - "help": "α parameter from the TR-DPO paper, which controls the mix between the current policy and the " - "previous reference policy during updates. The reference policy is updated according to the equation: " - "`π_ref = α * π_θ + (1 - α) * π_ref_prev`. To use this parameter, you must set `sync_ref_model=True`." - }, - ) - ref_model_sync_steps: int = field( - default=512, - metadata={ - "help": "τ parameter from the TR-DPO paper, which determines how frequently the current policy is " - "synchronized with the reference policy. To use this parameter, you must set `sync_ref_model=True`." - }, - ) - - # Parameters that control the logging - generate_during_eval: bool = field( - default=False, - metadata={ - "help": "Whether to generate and log completions from both the model and the reference model to W&B, MLFLow " - "or Comet during evaluation." - }, + metadata={"help": "Whether to offload the activations to the CPU."}, ) def __post_init__(self): self.bf16 = not (self.fp16) if self.bf16 is None else self.bf16 - - # Normalize loss_type to string format for internal use - if hasattr(self.loss_type, "__len__") and len(self.loss_type) == 1: - self.loss_type = self.loss_type[0] - - # Validate loss_type - if self.loss_weights is not None: - loss_types = self.loss_type if isinstance(self.loss_type, list) else [self.loss_type] - if len(self.loss_weights) != len(loss_types): - raise ValueError( - f"Length of loss_weights list ({self.loss_weights}) must match number of loss types " - f"({loss_types})." - ) super().__post_init__() diff --git a/trl/trainer/dpo_trainer.py b/trl/trainer/dpo_trainer.py index 5d4ae69461f..50a322056b2 100644 --- a/trl/trainer/dpo_trainer.py +++ b/trl/trainer/dpo_trainer.py @@ -12,29 +12,23 @@ # See the License for the specific language governing permissions and # limitations under the License. -import inspect +import contextlib import os -import random -import textwrap import warnings from collections import defaultdict -from contextlib import contextmanager, nullcontext +from collections.abc import Mapping from dataclasses import dataclass from pathlib import Path -from typing import Any, Callable, Literal, Optional, Union - -import pandas as pd +from typing import Any, Callable, Optional, TypeVar, Union +import torch.nn.functional as F import torch import torch.nn as nn -import torch.nn.functional as F +import transformers from accelerate import PartialState -from accelerate.utils import tqdm from datasets import Dataset, IterableDataset -from torch import autocast -from torch.utils.data import DataLoader from transformers import ( - AutoModelForCausalLM, - AutoTokenizer, + AutoConfig, + AutoProcessor, BaseImageProcessor, DataCollator, FeatureExtractionMixin, @@ -42,141 +36,320 @@ PreTrainedTokenizerBase, ProcessorMixin, Trainer, -) -from transformers.data.data_collator import DataCollatorMixin -from transformers.integrations import ( - is_comet_available, - is_mlflow_available, + TrainingArguments, is_wandb_available, ) -from transformers.models.auto.modeling_auto import MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES +from transformers.data.data_collator import DataCollatorMixin from transformers.trainer_callback import TrainerCallback -from transformers.trainer_utils import EvalLoopOutput -from transformers.utils import is_liger_kernel_available, is_peft_available - -from ..data_utils import maybe_apply_chat_template, maybe_extract_prompt -from ..models import create_reference_model, prepare_deepspeed -from ..models.utils import prepare_fsdp -from .callbacks import SyncRefModelCallback -from .dpo_config import DPOConfig, FDivergenceConstants, FDivergenceType -from .utils import ( - RunningMoments, - cap_exp, - disable_dropout_in_model, - empty_cache, - flush_left, - flush_right, - generate_model_card, - get_comet_experiment_url, - log_table_to_comet_experiment, - pad, - pad_to_length, - peft_module_casting_to_bf16, - selective_log_softmax, -) - +from transformers.trainer_utils import EvalPrediction +from transformers.utils import is_peft_available -if is_peft_available(): - from peft import ( - PeftConfig, - PeftModel, - get_peft_model, - prepare_model_for_kbit_training, - ) +from ..data_utils import extract_prompt, is_conversational, truncate_dataset +from ..models import get_act_offloading_ctx_manager, prepare_peft_model +from .dpo_config import DPOConfig +from .utils import generate_model_card, get_comet_experiment_url, pad -if is_liger_kernel_available(): - from liger_kernel.chunked_loss import LigerFusedLinearDPOLoss +if is_peft_available(): + from peft import PeftConfig, PeftModel if is_wandb_available(): import wandb -if is_mlflow_available(): - import mlflow +TListOrMapping = TypeVar("TListOrMapping", list, Mapping) -def shift_tokens_right(input_ids: torch.Tensor, decoder_start_token_id: int) -> torch.Tensor: - """Shift input ids one token to the right, and pad with pad_token_id""" - shifted_input_ids = input_ids.new_zeros(input_ids.shape) - shifted_input_ids[:, 1:] = input_ids[:, :-1].clone() - shifted_input_ids[:, 0] = decoder_start_token_id + +def remove_none_values(example: TListOrMapping) -> TListOrMapping: + """ + Recursively removes entries with `None` values from a nested structure (list or dictionary). + + Args: + example (`list` or `Mapping`): + Input nested structure (list or dictionary) from which to remove `None`. + + Example: + ```python + >>> [ + ... { + ... "a": {"aa": None, "ab": 1}, + ... "b": "my_string", + ... } + ... ] + >>> remove_none_values(example) + [{'a': {'ab': 1}, 'b': 'my_string'}] + ``` + """ + if isinstance(example, list): + return [remove_none_values(value) if isinstance(value, (dict, list)) else value for value in example] + elif isinstance(example, Mapping): + return { + key: remove_none_values(value) if isinstance(value, (dict, list)) else value + for key, value in example.items() + if value is not None + } + else: + raise TypeError("Input must be a list or a dictionary.") @dataclass -class DataCollatorForPreference(DataCollatorMixin): +class DataCollatorForLanguageModeling(DataCollatorMixin): """ - Data collator used for preference data. Inputs are dynamically padded to the maximum length of a batch if they are - not all of the same length. + Data collator used for language modeling data. Inputs are dynamically padded to the maximum length of a batch. + + This collator expects each example in the input list to be a dictionary containing at least the `"input_ids"` key. + If the input contains a `"completion_mask"`, it is used to set the labels to `-100` for tokens that are not in the + completion. If `"assistant_masks"` are present, they are used to set the labels to `-100` for tokens that are not + in the assistant part of the sequence. The collator returns a dictionary containing the following keys: + - `"input_ids"`: Tensor of input IDs, padded to the maximum length of the batch. + - `"attention_mask"`: Tensor of attention mask, padded to the maximum length of the batch. + - `"position_ids"`: Tensor of position IDs, padded to the maximum length of the batch. + - `"labels"`: Tensor of labels, padded to the maximum length of the batch. If `completion_only_loss` is set to + `True`, tokens that are not in the completion are set to -100. If `assistant_masks` are present, tokens that are + not in the assistant part of the sequence are set to -100. Args: pad_token_id (`int`): Token ID to use for padding. + padding_free (`bool`, *optional*, defaults to `False`): + If set to `True`, the sequences will be flattened into a single sequence, and the position IDs will be + generated accordingly. The attention mask will be set to 1 for all tokens. + pad_to_multiple_of (`int` or `None`, *optional*, defaults to `None`): + If set, the sequences will be padded to a multiple of this value. return_tensors (`str`, *optional*, defaults to `"pt"`): Type of Tensor to return. Only `"pt"` is currently supported. Examples: ```python - >>> from trl import DataCollatorForPreference + >>> from trl import DataCollatorForLanguageModeling - >>> collator = DataCollatorForPreference(pad_token_id=0) + >>> collator = DataCollatorForLanguageModeling(pad_token_id=0) + >>> examples = [{"input_ids": [1, 2, 3]}, {"input_ids": [4, 5]}] + >>> collator(examples) + {'input_ids': tensor([[ 1, 2, 3], + [ 4, 5, 0]]), + 'attention_mask': tensor([[ 1, 1, 1], + [ 1, 1, 0]]), + 'position_ids': tensor([[0, 1, 2], + [0, 1, 0]]), + 'labels': tensor([[ 1, 2, 3], + [ 4, 5, -100]])} + + >>> # With completion mask >>> examples = [ - ... {"prompt_input_ids": [1, 2, 3], "chosen_input_ids": [4, 5], "rejected_input_ids": [6]}, - ... {"prompt_input_ids": [7, 8], "chosen_input_ids": [9, 10], "rejected_input_ids": [11, 12, 13]}, + ... {"input_ids": [1, 2, 3], "completion_mask": [0, 1, 1]}, + ... {"input_ids": [4, 5], "completion_mask": [0, 1]}, ... ] >>> collator(examples) - {'prompt_input_ids': tensor([[1, 2, 3], - [0, 7, 8]]), - 'prompt_attention_mask': tensor([[1, 1, 1], - [0, 1, 1]]), - 'chosen_input_ids': tensor([[ 4, 5], - [ 9, 10]]), - 'chosen_attention_mask': tensor([[1, 1], - [1, 1]]), - 'rejected_input_ids': tensor([[ 6, 0, 0], - [11, 12, 13]]), - 'rejected_attention_mask': tensor([[1, 0, 0], - [1, 1, 1]]) - } + {'input_ids': tensor([[ 1, 2, 3], + [ 4, 5, 0]]), + 'attention_mask': tensor([[ 1, 1, 1], + [ 1, 1, 0]]), + 'position_ids': tensor([[0, 1, 2], + [0, 1, 0]]), + 'labels': tensor([[-100, 2, 3], + [-100, 5, -100]])} + + >>> # With padding_free + >>> collator = DataCollatorForLanguageModeling(pad_token_id=0, padding_free=True) + >>> collator(examples) + {'input_ids': tensor([[ 1, 2, 3, 4, 5]]), + 'attention_mask': tensor([[1, 1, 1, 1, 1]]), + 'position_ids': tensor([[0, 1, 2, 0, 1]]), + 'labels': tensor([[1, 2, 3, 4, 5]])} ``` """ pad_token_id: int + padding_free: bool = False + return_position_ids: bool = True + pad_to_multiple_of: Optional[int] = None return_tensors: str = "pt" def torch_call(self, examples: list[Union[list[int], Any, dict[str, Any]]]) -> dict[str, Any]: # Convert to tensor - prompt_input_ids = [torch.tensor(example["prompt_input_ids"]) for example in examples] - prompt_attention_mask = [torch.ones_like(input_ids) for input_ids in prompt_input_ids] - chosen_input_ids = [torch.tensor(example["chosen_input_ids"]) for example in examples] - chosen_attention_mask = [torch.ones_like(input_ids) for input_ids in chosen_input_ids] - rejected_input_ids = [torch.tensor(example["rejected_input_ids"]) for example in examples] - rejected_attention_mask = [torch.ones_like(input_ids) for input_ids in rejected_input_ids] - if "pixel_values" in examples[0]: - pixel_values = [torch.tensor(example["pixel_values"]) for example in examples] - if "pixel_attention_mask" in examples[0]: - pixel_attention_mask = [torch.tensor(example["pixel_attention_mask"]) for example in examples] - if "ref_chosen_logps" in examples[0] and "ref_rejected_logps" in examples[0]: - ref_chosen_logps = torch.tensor([example["ref_chosen_logps"] for example in examples]) - ref_rejected_logps = torch.tensor([example["ref_rejected_logps"] for example in examples]) + chosen_ids = [torch.tensor(example["chosen_ids"]) for example in examples] + rejected_ids = [torch.tensor(example["rejected_ids"]) for example in examples] + input_ids = chosen_ids + rejected_ids + + # Check if we have meaningful seq_lengths from packing (restarting sequences) + has_packed_position_ids = self.return_position_ids and "seq_lengths" in examples[0] and self.padding_free + + # For packing with position_ids, we should NOT create attention_mask as it causes + # FlashAttention to ignore position_ids and compute wrong cu_seq_lens from the all-1s mask + if not has_packed_position_ids: + attention_mask = [torch.ones_like(input_ids) for input_ids in input_ids] + + if self.return_position_ids: + if "seq_lengths" in examples[0]: + position_ids = self._convert_seq_lengths_to_position_ids( + [example["seq_lengths"] for example in examples] + ) + else: + position_ids = [torch.arange(len(ids)) for ids in input_ids] + chosen_mask = [torch.tensor(example["chosen_mask"]) for example in examples] + rejected_mask = [torch.tensor(example["rejected_mask"]) for example in examples] + completion_mask = chosen_mask + rejected_mask # Pad output = {} - output["prompt_input_ids"] = pad(prompt_input_ids, padding_value=self.pad_token_id, padding_side="left") - output["prompt_attention_mask"] = pad(prompt_attention_mask, padding_value=0, padding_side="left") - output["chosen_input_ids"] = pad(chosen_input_ids, padding_value=self.pad_token_id) - output["chosen_attention_mask"] = pad(chosen_attention_mask, padding_value=0) - output["rejected_input_ids"] = pad(rejected_input_ids, padding_value=self.pad_token_id) - output["rejected_attention_mask"] = pad(rejected_attention_mask, padding_value=0) - if "pixel_values" in examples[0]: - output["pixel_values"] = pad(pixel_values, padding_value=0.0) - if "pixel_attention_mask" in examples[0]: - output["pixel_attention_mask"] = pad(pixel_attention_mask, padding_value=0) - if "image_sizes" in examples[0]: - output["image_sizes"] = torch.tensor([example["image_sizes"] for example in examples]) - if "ref_chosen_logps" in examples[0] and "ref_rejected_logps" in examples[0]: - output["ref_chosen_logps"] = ref_chosen_logps - output["ref_rejected_logps"] = ref_rejected_logps + if self.padding_free: + output["input_ids"] = torch.cat(input_ids, dim=0).unsqueeze(0) + if not has_packed_position_ids: + output["attention_mask"] = torch.cat(attention_mask, dim=0).unsqueeze(0) + if self.return_position_ids: + output["position_ids"] = torch.cat(position_ids, dim=0).unsqueeze(0) + else: + output["input_ids"] = pad( + input_ids, + padding_value=self.pad_token_id, + padding_side="right", + pad_to_multiple_of=self.pad_to_multiple_of, + ) + output["attention_mask"] = pad( + attention_mask, padding_value=0, padding_side="right", pad_to_multiple_of=self.pad_to_multiple_of + ) + if self.return_position_ids: + output["position_ids"] = pad( + position_ids, padding_value=0, padding_side="right", pad_to_multiple_of=self.pad_to_multiple_of + ) + output["completion_mask"] = pad( + completion_mask, padding_value=0, padding_side="right", pad_to_multiple_of=self.pad_to_multiple_of + ) + + return output + + @staticmethod + def _convert_seq_lengths_to_position_ids(batch_seq_lengths: list[list[int]]) -> list[torch.Tensor]: + example_lengths = [sum(seq_lengths) for seq_lengths in batch_seq_lengths] + batch_seq_lengths = torch.tensor( + [seq_length for seq_lengths in batch_seq_lengths for seq_length in seq_lengths] + ) + position_ids = torch.ones(sum(example_lengths), dtype=batch_seq_lengths.dtype) + position_ids[0] = 0 + position_ids[batch_seq_lengths[:-1].cumsum(0)] = -(batch_seq_lengths[:-1] - 1) + position_ids = position_ids.cumsum(0) + return list(position_ids.split(example_lengths)) + + +@dataclass +class DataCollatorForVisionLanguageModeling(DataCollatorMixin): + """ + Data collator for vision-language modeling tasks. + + Unlike text-only datasets—where the collator typically receives pre-tokenized inputs ready for batching, + vision-language data processing involves converting images into pixel values. This conversion is disk-intensive, + making upfront preprocessing of the entire dataset impractical. Therefore, this collator performs tokenization and + image processing on-the-fly to efficiently prepare batches. + + Each input example should be a dictionary containing at least: + - An `"images"` key holding the image data. + - Either a `"messages"` key for conversational inputs or a `"text"` key for standard text inputs. + + The collator outputs a dictionary including: + - `"input_ids"`: Tensor of token IDs. + - `"attention_mask"`: Tensor indicating attention mask. + - `"pixel_values"`: Tensor representing image pixel values. + - `"labels"`: Tensor for training labels. + + Additional keys may be present depending on the processor, such as `"image_grid_thw"`. + Args: + processor (`ProcessorMixin`): + The processor used to tokenize text and process images. It must be a subclass of `ProcessorMixin` and + include a `tokenizer` with a defined `pad_token_id`. + max_length (`int` or `None`, optional, defaults to `None`): + Maximum sequence length for input tokens. If `None`, no truncation is applied. + pad_to_multiple_of (`int` or `None`, optional, defaults to `None`): + If set, the sequences will be padded to a multiple of this value. + return_tensors (`str`, optional, defaults to `"pt"`): + The tensor type to return. Currently, only `"pt"` (PyTorch tensors) is supported. + + Example: + ```python + >>> from trl import DataCollatorForVisionLanguageModeling + >>> from transformers import AutoProcessor + + >>> processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct") + >>> collator = DataCollatorForVisionLanguageModeling(processor) + >>> examples = [ + ... {"images": [Image.open("image_0.png")], "messages": [{"role": "user", "content": "What is this?"}]}, + ... {"images": [Image.open("image_1.png")], "messages": [{"role": "user", "content": "Describe this image."}]}, + ... ] + >>> collator(examples) + {'input_ids': tensor([[151644, 8948, 198, 2610, 525, 264, 10950, 17847, 13, 151645, 198, + 151644, 872, 198, 151652, 151655, 151655, 151655, 151655, 151653, 3838, 374, + 419, 30, 151645, 198], + [151644, 8948, 198, 2610, 525, 264, 10950, 17847, 13, 151645, 198, + 151644, 872, 198, 151652, 151655, 151655, 151655, 151655, 151653, 74785, 419, + 2168, 13, 151645, 198]]), + 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]), + 'pixel_values': tensor([[-0.9893, 0.1785, 1.5362, ..., -0.0582, 0.8661, -0.2431], + [-0.2302, 0.9522, -1.1061, ..., 0.0555, 1.3354, -0.6412], + [ 1.2150, 0.9084, 0.7041, ..., 0.2404, -0.8403, -0.5133], + ..., + [ 0.6895, 0.2807, 0.2515, ..., -0.2004, -1.2100, 0.0555], + [ 0.8209, -0.9748, 1.5654, ..., 1.6055, -0.4706, 0.5817], + [-1.0915, 0.4559, 0.9230, ..., 0.5106, 0.0982, -0.1720]]), + 'image_grid_thw': tensor([[1, 4, 4], + [1, 4, 4]]), + 'labels': tensor([[151644, 8948, 198, 2610, 525, 264, 10950, 17847, 13, 151645, 198, + 151644, 872, 198, 151652, 151655, 151655, 151655, 151655, 151653, 3838, 374, + 419, 30, 151645, 198], + [151644, 8948, 198, 2610, 525, 264, 10950, 17847, 13, 151645, 198, + 151644, 872, 198, 151652, 151655, 151655, 151655, 151655, 151653, 74785, 419, + 2168, 13, 151645, 198]])} + ``` + """ + + processor: ProcessorMixin + max_length: Optional[int] = None + pad_to_multiple_of: Optional[int] = None + return_tensors: str = "pt" + + def torch_call(self, examples: list[Union[list[int], Any, dict[str, Any]]]) -> dict[str, Any]: + images = [example["images"] for example in examples] + + if "messages" in examples[0]: # conversational case + for example in examples: + image_included = False + for message in example["messages"]: + if message["role"] == "user": + if isinstance(message["content"], str) and not image_included: + message["content"] = [{"type": "image"}, {"type": "text", "text": message["content"]}] + image_included = True + elif isinstance(message["content"], str) and image_included: + message["content"] = [{"type": "text", "text": message["content"]}] + if message["role"] == "assistant": + if isinstance(message["content"], str): + message["content"] = [{"type": "text", "text": message["content"]}] + messages = [example["messages"] for example in examples] + texts = self.processor.apply_chat_template(messages, images=images) + elif "text" in examples[0]: # standard case + texts = [example["text"] for example in examples] + else: + raise KeyError( + "The input examples must contain either 'messages' for conversational data or 'text' for standard " + "data." + ) + + output = self.processor( + images=images, + text=texts, + padding=True, + pad_to_multiple_of=self.pad_to_multiple_of, + truncation=self.max_length is not None, + max_length=self.max_length, + return_tensors=self.return_tensors, + add_special_tokens=False, # to avoid adding the BOS, twice see https://huggingface.co/blog/qgallouedec/gotchas-in-tokenizer-behavior#7-chat-template-and-tokenization-dont-compose-due-to-special-tokens + ) + labels = output["input_ids"].clone() + labels[labels == self.processor.tokenizer.pad_token_id] = -100 + # We mask only padding tokens (-100) in the labels. Vision tokens are left unchanged because their handling in + # loss computation has to be done by the model, and masking them here would be infeasible in practice as vision + # token definitions vary across architectures. + output["labels"] = labels return output @@ -186,6 +359,18 @@ class DPOTrainer(Trainer): This class is a wrapper around the [`transformers.Trainer`] class and inherits all of its attributes and methods. + Example: + + ```python + from datasets import load_dataset + from trl import DPOTrainer + + dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train") + + trainer = DPOTrainer(model="Qwen/Qwen2-0.5B-Instruct", train_dataset=dataset) + trainer.train() + ``` + Args: model (`Union[str, PreTrainedModel]`): Model to be trained. Can be either: @@ -196,18 +381,14 @@ class DPOTrainer(Trainer): using [`~transformers.AutoModelForCausalLM.from_pretrained`] with the keyword arguments in `args.model_init_kwargs`. - A [`~transformers.PreTrainedModel`] object. Only causal language models are supported. - ref_model (`PreTrainedModelWrapper`): - Hugging Face transformer model with a casual language modelling head. Used for implicit reward computation - and loss. If no reference model is provided, the trainer will create a reference model with the same - architecture as the model to be optimized. args ([`DPOConfig`], *optional*, defaults to `None`): Configuration for this trainer. If `None`, a default configuration is used. data_collator (`DataCollator`, *optional*): Function to use to form a batch from a list of elements of the processed `train_dataset` or `eval_dataset`. - Will default to [`DataCollatorForPreference`]. + Will default to a custom [`DataCollatorForLanguageModeling`]. train_dataset ([`~datasets.Dataset`] or [`~datasets.IterableDataset`]): - Dataset to use for training. DPO supports [preference](#preference) type and. The format of the samples can - be either: + Dataset to use for training. DPO supports both [language modeling](#language-modeling) type and + [prompt-completion](#prompt-completion) type. The format of the samples can be either: - [Standard](dataset_formats#standard): Each sample contains plain text. - [Conversational](dataset_formats#conversational): Each sample contains structured messages (e.g., role @@ -215,15 +396,10 @@ class DPOTrainer(Trainer): eval_dataset ([`~datasets.Dataset`], [`~datasets.IterableDataset`] or `dict[str, Union[Dataset, IterableDataset]]`): Dataset to use for evaluation. It must meet the same requirements as `train_dataset`. - processing_class ([`~transformers.PreTrainedTokenizerBase`], [`~transformers.BaseImageProcessor`], [`~transformers.FeatureExtractionMixin`] or [`~transformers.ProcessorMixin`], *optional*, defaults to `None`): + processing_class ([`~transformers.PreTrainedTokenizerBase`], [`~transformers.ProcessorMixin`] or `None`, *optional*, defaults to `None`): Processing class used to process the data. If `None`, the processing class is loaded from the model's name - with [`~transformers.AutoTokenizer.from_pretrained`]. - compute_metrics (`Callable[[EvalPrediction], dict]`, *optional*): - The function that will be used to compute metrics at evaluation. Must take a [`EvalPrediction`] and return - a dictionary string to metric values. *Note* When passing TrainingArgs with `batch_eval_metrics` set to - `True`, your compute_metrics function must take a boolean `compute_result` argument. This will be triggered - after the last eval batch to signal that the function needs to calculate and return the global summary - statistics rather than accumulating the batch-level statistics. + with [`~transformers.AutoProcessor.from_pretrained`]. A padding token, `tokenizer.pad_token`, must be set. + If the processing class has not set a padding token, `tokenizer.eos_token` will be used as the default. callbacks (list of [`~transformers.TrainerCallback`], *optional*, defaults to `None`): List of callbacks to customize the training loop. Will add those to the list of default callbacks detailed in [here](https://huggingface.co/docs/transformers/main_classes/callback). @@ -238,6 +414,9 @@ class DPOTrainer(Trainer): `None`): A tuple containing the optimizer class and keyword arguments to use. Overrides `optim` and `optim_args` in `args`. Incompatible with the `optimizers` argument. + + Unlike `optimizers`, this argument avoids the need to place model parameters on the correct devices before + initializing the Trainer. preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`, *optional*, defaults to `None`): A function that preprocess the logits right before caching them at each evaluation step. Must take two @@ -254,15 +433,13 @@ class DPOTrainer(Trainer): def __init__( self, model: Union[str, nn.Module, PreTrainedModel], - ref_model: Optional[Union[PreTrainedModel, nn.Module, str]] = None, - args: Optional[DPOConfig] = None, + args: Optional[Union[DPOConfig, TrainingArguments]] = None, data_collator: Optional[DataCollator] = None, # type: ignore train_dataset: Optional[Union[Dataset, IterableDataset]] = None, - eval_dataset: Optional[Union[Dataset, IterableDataset, dict[str, Union[Dataset, IterableDataset]]]] = None, - processing_class: Optional[ - Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin] - ] = None, - compute_metrics: Optional[Callable[[EvalLoopOutput], dict]] = None, + eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None, + processing_class: Optional[Union[PreTrainedTokenizerBase, ProcessorMixin]] = None, + compute_loss_func: Optional[Callable] = None, + compute_metrics: Optional[Callable[[EvalPrediction], dict]] = None, callbacks: Optional[list[TrainerCallback]] = None, optimizers: tuple[Optional[torch.optim.Optimizer], Optional[torch.optim.lr_scheduler.LambdaLR]] = (None, None), optimizer_cls_and_kwargs: Optional[tuple[type[torch.optim.Optimizer], dict[str, Any]]] = None, @@ -270,125 +447,84 @@ def __init__( peft_config: Optional["PeftConfig"] = None, ): # Args - model_id = model if isinstance(model, str) else model.config._name_or_path if args is None: - model_name = model_id.split("/")[-1] + model_name = model if isinstance(model, str) else model.config._name_or_path + model_name = model_name.split("/")[-1] args = DPOConfig(f"{model_name}-DPO") + elif isinstance(args, TrainingArguments) and not isinstance(args, DPOConfig): + dict_args = args.to_dict() + dict_args["hub_token"] = args.hub_token # to_dict hides the hub_token + dict_args.pop("push_to_hub_token") + args = DPOConfig(**dict_args) - # Handle the tokenizer - if processing_class is None: - processing_class = AutoTokenizer.from_pretrained(model_id) - - if args.padding_value is not None: - self.padding_value = args.padding_value - else: - if hasattr(processing_class, "pad_token_id") and processing_class.pad_token_id is not None: - self.padding_value = processing_class.pad_token_id - elif hasattr(processing_class, "tokenizer") and processing_class.tokenizer.pad_token_id is not None: - self.padding_value = processing_class.tokenizer.pad_token_id + # Model + model_init_kwargs = args.model_init_kwargs or {} + if isinstance(model, str): + model_id = model + torch_dtype = model_init_kwargs.get("torch_dtype") + if isinstance(torch_dtype, torch.dtype) or torch_dtype == "auto" or torch_dtype is None: + pass # torch_dtype is already a torch.dtype or "auto" or None + elif isinstance(torch_dtype, str) and torch_dtype in ["bfloat16", "float16", "float32"]: + torch_dtype = getattr(torch, torch_dtype) + model_init_kwargs["torch_dtype"] = torch_dtype else: raise ValueError( - "`padding_value` is not specified in `DPOConfig`, and `pad_token_id` is missing in the " - "`processing_class`. Please either set the `padding_value` argument in `DPOConfig`, or set " - "`tokenizer.pad_token` (e.g., `tokenizer.pad_token = tokenizer.eos_token`) before instantiating " - "the trainer." + "Invalid `torch_dtype` passed to `DPOConfig`. Expected either 'auto' or a string representing " + f"a valid `torch.dtype` (e.g., 'float32'), but got {torch_dtype}." + ) + config = AutoConfig.from_pretrained(model_id) + architecture = getattr(transformers, config.architectures[0]) + model = architecture.from_pretrained(model_id, **model_init_kwargs) + self.ref_model = architecture.from_pretrained(model_id, **model_init_kwargs) + else: + model_id = model.config._name_or_path + if args.model_init_kwargs is not None: + warnings.warn( + "You passed `model_init_kwargs` to the `DPOConfig`, but your model is already instantiated. " + "The `model_init_kwargs` will be ignored." ) - # Model - if not isinstance(model, str) and ref_model is model: - raise ValueError( - "`model` and `ref_model` cannot be the same object. If you want `ref_model` to be the " - "same as `model`, you must mass a copy of it, or `None` if you use peft." - ) - - if args.model_init_kwargs is not None and not isinstance(model, str): - warnings.warn( - "You passed model_init_kwargs to the `DPOConfig`, but your model is already instantiated. " - "The `model_init_kwargs` will be ignored." - ) - if isinstance(model, str): - model = self._create_model_from_path(model, args) - - if args.ref_model_init_kwargs is not None and not isinstance(ref_model, str): - warnings.warn( - "You passed ref_model_init_kwargs to the `DPOConfig`, but your ref_model is already instantiated. " - "The `ref_model_init_kwargs` will be ignored." - ) - if isinstance(ref_model, str): - ref_model = self._create_model_from_path(ref_model, args, is_ref=True) - - # PEFT configuration and model wrapping - model = self._prepare_peft_model(model, ref_model, peft_config, args) + # Processing class + if processing_class is None: + processing_class = AutoProcessor.from_pretrained(model_id) + + # Handle pad token for processors or tokenizers + if isinstance(processing_class, ProcessorMixin): + tokenizer = processing_class.tokenizer + self._is_vlm = True + elif isinstance(processing_class, PreTrainedTokenizerBase): + tokenizer = processing_class + self._is_vlm = False + else: + raise TypeError("The `processing_class` must be either a `PreTrainedTokenizerBase` or a `ProcessorMixin`") - if args.generate_during_eval and not (is_wandb_available() or is_comet_available() or is_mlflow_available()): + # Catch some wrong configurations related to VLMs + if self._is_vlm and args.padding_free: raise ValueError( - "`generate_during_eval=True` requires Weights and Biases, MLFlow or Comet to be installed." - " Please install `wandb`, `mlflow` or `comet-ml` to resolve." + "Padding-free training is yet not supported for vision-language models. Please set " + "`padding_free=False` in the `DPOConfig`." ) - self.is_encoder_decoder = model.config.is_encoder_decoder - self.is_vision_model = model.config.model_type in MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES.keys() - self.is_peft_model = is_peft_available() and isinstance(model, PeftModel) - self.model_adapter_name = args.model_adapter_name - self.ref_adapter_name = args.ref_adapter_name - self.reference_free = args.reference_free - - if ref_model: - self.ref_model = ref_model - elif self.is_peft_model or args.precompute_ref_log_probs: - # The `model` with adapters turned off will be used as the reference model - self.ref_model = None - else: - self.ref_model = create_reference_model(model) - - # Disable dropout in the model and reference model - if args.disable_dropout: - disable_dropout_in_model(model) - if self.ref_model is not None: - disable_dropout_in_model(self.ref_model) - - # Liger kernel - if args.use_liger_loss: - if not is_liger_kernel_available(): - raise ImportError( - "You set `use_liger_loss=True` but the liger kernel is not available. " - "Please install liger-kernel first: `pip install liger-kernel`" - ) - if args.loss_type != "sigmoid": - raise ValueError( - "You set `use_liger_loss=True` but the loss type is not `sigmoid`. " - "Please set `loss_type='sigmoid'` to use the liger kernel." - ) - self.dpo_loss_fn = LigerFusedLinearDPOLoss( - ignore_index=args.label_pad_token_id, - beta=args.beta, - use_ref_model=not args.reference_free, - average_log_prob=False, - ) - # The trainer estimates the number of FLOPs (floating-point operations) using the number of elements in the - # input tensor associated with the key "input_ids". However, in DPO, the sampled data does not include the - # "input_ids" key. Instead, the available keys are "prompt_input_ids", "chosen_input_ids", and - # "rejected_input_ids". As a result, the trainer issues the warning: "Could not estimate the number of tokens - # of the input, floating-point operations will not be computed." To suppress this warning, we set the - # "estimate_tokens" key in the model's "warnings_issued" dictionary to True. This acts as a flag to indicate - # that the warning has already been issued. - model.warnings_issued["estimate_tokens"] = True + # In Prompt Tuning a small set of trainable virtual tokens (continuous prompt embeddings) is prepended to the + # input. We store the number of these tokens so we can account for them correctly when calculating accuracy. + self.num_virtual_tokens = 0 + + if peft_config is not None or (is_peft_available() and isinstance(model, PeftModel)): + model = prepare_peft_model(model, peft_config, args) + if model.active_adapter in model.peft_config: + peft_model_config = model.peft_config[model.active_adapter] + self.num_virtual_tokens = getattr(peft_model_config, "num_virtual_tokens", 0) # Data collator - if data_collator is None: - data_collator = DataCollatorForPreference(pad_token_id=self.padding_value) - - self.generate_during_eval = args.generate_during_eval - self.label_pad_token_id = args.label_pad_token_id - self.max_prompt_length = args.max_prompt_length - self.max_completion_length = args.max_completion_length - self.max_length = args.max_length - self.truncation_mode = args.truncation_mode - self.precompute_ref_log_probs = args.precompute_ref_log_probs - self.use_logits_to_keep = args.use_logits_to_keep - - if args.padding_free: - if model.config._attn_implementation != "flash_attention_2": + self.padding_free = args.padding_free + use_flash_attention = model.config._attn_implementation in [ + "flash_attention_2", + "kernels-community/vllm-flash-attn3", + ] + if self.padding_free: + if data_collator is not None: + raise ValueError("Passing a custom data collator is not supported when using padding-free.") + if not use_flash_attention: warnings.warn( "Padding-free training is enabled, but the attention implementation is not set to " "'flash_attention_2'. Padding-free training flattens batches into a single sequence, and " @@ -403,56 +539,60 @@ def __init__( "of 1 anihilate the benefits of padding-free training. Please consider increasing the batch size " "to at least 2." ) - self.padding_free = args.padding_free - # Since ref_logs are precomputed on the first call to get_train/eval_dataloader - # keep track of first called to avoid computation of future calls - self._precomputed_train_ref_log_probs = False - self._precomputed_eval_ref_log_probs = False - - self.beta = args.beta - self.label_smoothing = args.label_smoothing - self.loss_type = args.loss_type if isinstance(args.loss_type, list) else [args.loss_type] - self.loss_weights = args.loss_weights - self.aux_loss_enabled = getattr(model.config, "output_router_logits", False) - self.use_weighting = args.use_weighting - self.aux_loss_coef = getattr(model.config, "router_aux_loss_coef", 0.0) - if self.aux_loss_enabled and self.aux_loss_coef == 0.0: - warnings.warn( - "You set `output_router_logits` to `True` in the model config, but `router_aux_loss_coef` is set to " - "`0.0`, meaning the auxiliary loss will not be used. Either set `router_aux_loss_coef` to a value " - "greater than `0.0`, or set `output_router_logits` to `False` if you don't want to use the auxiliary " - "loss.", - UserWarning, - ) - for loss_type in self.loss_type: - if ( - loss_type in ["hinge", "ipo", "bco_pair", "sppo_hard", "nca_pair", "apo_zero", "apo_down"] - and args.label_smoothing > 0 - ): - warnings.warn( - f"You are using the {loss_type} loss type that does not support label smoothing. The " - "`label_smoothing` parameter will be ignored. Set `label_smoothing` to `0.0` to remove this warning.", - UserWarning, + # Decide whether to use completion-only loss: if not specified, then it is set to True if the dataset format + # is prompt-completion, and False if the dataset format is language modeling. + dataset_sample = next(iter(train_dataset)) + + if data_collator is None and not self._is_vlm: + # Get the pad token: if not provided, use the one from the processing class or the eos token + # if the processing class does not have a pad token. + pad_token = args.pad_token or tokenizer.pad_token or tokenizer.eos_token + pad_token_id = tokenizer.convert_tokens_to_ids(pad_token) + if pad_token_id is None: + raise ValueError( + f"The specified `pad_token` ('{pad_token}') is not found in the vocabulary of the given " + f"`processing_class` ({processing_class.__class__.__name__}). Ensure that the `pad_token` exists " + "in the vocabulary before using it as a padding token." ) - if loss_type == "kto_pair": - raise ValueError("Support for kto_pair has been removed in DPOTrainer. Please use KTOTrainer.") - - self._stored_metrics = defaultdict(lambda: defaultdict(list)) - self.f_divergence_type = args.f_divergence_type - self.f_divergence_params = {FDivergenceConstants.ALPHA_DIVERGENCE_COEF_KEY: args.f_alpha_divergence_coef} - self.dataset_num_proc = args.dataset_num_proc - - # Dataset preparation - train_dataset = self._prepare_dataset(train_dataset, processing_class, args, "train") - if eval_dataset is not None: - if isinstance(eval_dataset, dict): - eval_dataset = { - key: self._prepare_dataset(dataset, processing_class, args, key) - for key, dataset in eval_dataset.items() - } - else: - eval_dataset = self._prepare_dataset(eval_dataset, processing_class, args, "eval") + data_collator = DataCollatorForLanguageModeling( + pad_token_id=pad_token_id, + padding_free=self.padding_free, + # Using position_ids without flash_attn hurts the training + return_position_ids=use_flash_attention, + pad_to_multiple_of=args.pad_to_multiple_of, + ) + elif data_collator is None and self._is_vlm: + data_collator = DataCollatorForVisionLanguageModeling( + processor=processing_class, + max_length=args.max_length, + pad_to_multiple_of=args.pad_to_multiple_of, + ) + + # Dataset + # Skip dataset preparation if it's a VLM, where preprocessing (e.g., image-to-pixel conversion) is too costly + # and done on the fly instead. + skip_prepare_dataset = self._is_vlm + if not skip_prepare_dataset: + train_dataset = self._prepare_dataset(train_dataset, processing_class, args, "train") + if eval_dataset is not None: + if isinstance(eval_dataset, dict): + eval_dataset = { + key: self._prepare_dataset(dataset, processing_class, args, key) + for key, dataset in eval_dataset.items() + } + else: + eval_dataset = self._prepare_dataset(eval_dataset, processing_class, args, "eval") + + # Initialize the metrics + self._metrics = {"train": defaultdict(list), "eval": defaultdict(list)} + self._total_train_tokens = 0 + + # Initialize the Trainer. Parent class will handle: + # - DeepSpeed configuration (through create_accelerator_and_postprocess) + # - FSDP setup + # - Distributed training setup + # - Optimizer and scheduler creation super().__init__( model=model, @@ -461,6 +601,7 @@ def __init__( train_dataset=train_dataset, eval_dataset=eval_dataset, processing_class=processing_class, + compute_loss_func=compute_loss_func, compute_metrics=compute_metrics, callbacks=callbacks, optimizers=optimizers, @@ -468,153 +609,16 @@ def __init__( preprocess_logits_for_metrics=preprocess_logits_for_metrics, ) - # Gradient accumulation requires scaled loss. Normally, loss scaling in the parent class depends on whether the - # model accepts loss-related kwargs. Since we compute our own loss, this check is irrelevant. We set - # self.model_accepts_loss_kwargs to False to enable scaling. - self.model_accepts_loss_kwargs = False + # Initialize activation offloading context + if self.args.activation_offloading: + self.maybe_activation_offload_context = get_act_offloading_ctx_manager(model=self.model) + else: + self.maybe_activation_offload_context = contextlib.nullcontext() # Add tags for models that have been loaded with the correct transformers version if hasattr(self.model, "add_model_tags"): self.model.add_model_tags(self._tag_names) - if not hasattr(self, "accelerator"): - raise AttributeError( - "Your `Trainer` does not have an `accelerator` object. Consider upgrading `transformers`." - ) - - # Deepspeed Zero-3 does not support precompute_ref_log_probs - if self.is_deepspeed_enabled: - if self.accelerator.state.deepspeed_plugin.zero_stage == 3 and self.precompute_ref_log_probs: - raise ValueError( - "You cannot use `precompute_ref_log_probs=True` with Deepspeed ZeRO-3. Please set `precompute_ref_log_probs=False`." - ) - - if self.ref_model is None: - if not (self.is_peft_model or self.precompute_ref_log_probs): - raise ValueError( - "No reference model and model is not a Peft model. Try setting `precompute_ref_log_probs=True`" - ) - if args.sync_ref_model: - raise ValueError( - "You currently cannot use `ref_model=None` with TR-DPO method. Please provide `ref_model`." - ) - else: - if self.is_deepspeed_enabled: - self.ref_model = prepare_deepspeed(self.ref_model, self.accelerator) - elif self.is_fsdp_enabled: - self.ref_model = prepare_fsdp(self.ref_model, self.accelerator) - else: - self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True) - - if args.sync_ref_model: - if self.precompute_ref_log_probs: - raise ValueError( - "You cannot use `precompute_ref_log_probs=True` with TR-DPO method. Please set `precompute_ref_log_probs=False`." - ) - - self.add_callback(SyncRefModelCallback(ref_model=self.ref_model, accelerator=self.accelerator)) - - if "bco_pair" in self.loss_type: - self.running = RunningMoments(self.accelerator) - - def _create_model_from_path(self, model_path: str, args: DPOConfig, is_ref: bool = False) -> PreTrainedModel: - """Creates a model from a path or model identifier.""" - if not is_ref: - model_init_kwargs = args.model_init_kwargs or {} - else: - model_init_kwargs = args.ref_model_init_kwargs or {} - - # Handle torch dtype - torch_dtype = model_init_kwargs.get("torch_dtype") - if isinstance(torch_dtype, torch.dtype) or torch_dtype == "auto" or torch_dtype is None: - pass # torch_dtype is already a torch.dtype or "auto" or None - elif isinstance(torch_dtype, str): # it's a str, but not "auto" - torch_dtype = getattr(torch, torch_dtype) - model_init_kwargs["torch_dtype"] = torch_dtype - else: - raise ValueError( - "Invalid `torch_dtype` passed to `DPOConfig`. Expected either 'auto' or a string representing " - f"a `torch.dtype` (e.g., 'float32'), but got {torch_dtype}." - ) - # Disable caching if gradient checkpointing is enabled (not supported) - # if args.gradient_checkpointing: - # model_init_kwargs["use_cache"] = False - - # Create model - model = AutoModelForCausalLM.from_pretrained(model_path, **model_init_kwargs) - return model - - def _prepare_peft_model( - self, model: PreTrainedModel, ref_model: PreTrainedModel, peft_config: Any, args: DPOConfig - ) -> PreTrainedModel: - """Prepares a model for PEFT training.""" - # Initialize this variable to False. This helps tracking the case when `peft_module_casting_to_bf16` - # has been called in order to properly call autocast if needed. - self._peft_has_been_casted_to_bf16 = False - - if not is_peft_available() and peft_config is not None: - raise ValueError( - "PEFT is not installed and you passed a `peft_config` in the trainer's kwargs, please install it to use the PEFT models" - ) - elif is_peft_available() and peft_config is not None: - # if model is a peft model and we have a peft_config, we merge and unload it first - if isinstance(model, PeftModel): - model = model.merge_and_unload() - - if ref_model is not None and not args.force_use_ref_model: - raise ValueError( - "You passed both a ref_model and a peft_config. For training PEFT adapters with DPO there is no need to pass a reference" - " model. Please pass `ref_model=None` in case you want to train PEFT adapters, or pass a ref_model with `force_use_ref_model=True` in DPOTrainer's init." - " if you want to use a different ref_model." - ) - - if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False): - _support_gc_kwargs = hasattr( - args, "gradient_checkpointing_kwargs" - ) and "gradient_checkpointing_kwargs" in list( - inspect.signature(prepare_model_for_kbit_training).parameters - ) - - prepare_model_kwargs = {"use_gradient_checkpointing": args.gradient_checkpointing} - - if _support_gc_kwargs: - prepare_model_kwargs["gradient_checkpointing_kwargs"] = args.gradient_checkpointing_kwargs - - model = prepare_model_for_kbit_training(model, **prepare_model_kwargs) - - else: - model = self._prepare_gradient_checkpointing(model, args) - - # get peft model with the given config - model = get_peft_model(model, peft_config) - if args.bf16 and getattr(model, "is_loaded_in_4bit", False): - peft_module_casting_to_bf16(model) - # If args.bf16 we need to explicitly call `generate` with torch amp autocast context manager - self._peft_has_been_casted_to_bf16 = True - - else: - model = self._prepare_gradient_checkpointing(model, args) - - return model - - def _prepare_gradient_checkpointing(self, model: PreTrainedModel, args: DPOConfig): - """Prepare the gradienting checkpointing for the model.""" - # For models that use gradient_checkpointing, we need to attach a hook that enables input - # to explicitly have `requires_grad=True`, otherwise training will either silently - # fail or completely fail. - if args.gradient_checkpointing: - # For backward compatibility with older versions of transformers - if hasattr(model, "enable_input_require_grads"): - model.enable_input_require_grads() - else: - - def make_inputs_require_grad(module, input, output): - output.requires_grad_(True) - - model.get_input_embeddings().register_forward_hook(make_inputs_require_grad) - - return model - def _prepare_dataset( self, dataset: Union[Dataset, IterableDataset], @@ -622,1303 +626,213 @@ def _prepare_dataset( args: DPOConfig, dataset_name: str, ) -> Union[Dataset, IterableDataset]: + # Tabular backends like Arrow/Parquet insert `None` for mismatched keys in nested structures. Clean them from + # sampled data. + if isinstance(dataset, Dataset): # IterableDataset does not support `with_transform` + dataset = dataset.with_transform(remove_none_values) + # Build the kwargs for the `map` function map_kwargs = {} - if isinstance(dataset, Dataset): # IterableDataset does not support num_proc nor writer_batch_size + if isinstance(dataset, Dataset): # IterableDataset does not support num_proc map_kwargs["num_proc"] = args.dataset_num_proc - map_kwargs["writer_batch_size"] = 10 with PartialState().main_process_first(): - # Extract prompt if needed - if isinstance(dataset, Dataset): # `IterableDataset.map` does not support `desc` - map_kwargs["desc"] = f"Extracting prompt in {dataset_name} dataset" - dataset = dataset.map(maybe_extract_prompt, **map_kwargs) + # Extract the prompt if needed + first_example = next(iter(dataset)) + if "prompt" not in first_example: + if isinstance(dataset, Dataset): # `IterableDataset.map` does not support `desc` + map_kwargs["desc"] = f"Extracting prompt from {dataset_name} dataset" + dataset = dataset.map(extract_prompt, **map_kwargs) # Apply the chat template if needed - if isinstance(dataset, Dataset): # `IterableDataset.map` does not support `desc` - map_kwargs["desc"] = f"Applying chat template to {dataset_name} dataset" - dataset = dataset.map( - maybe_apply_chat_template, fn_kwargs={"tokenizer": processing_class, "tools": args.tools}, **map_kwargs - ) + first_example = next(iter(dataset)) + if not is_conversational(first_example): + if isinstance(dataset, Dataset): # `IterableDataset.map` does not support `desc` + map_kwargs["desc"] = f"Adding EOS to {dataset_name} dataset" + + def add_eos(example, eos_token): + if "text" in example and not example["text"].endswith(eos_token): # language modeling case + example["text"] = example["text"] + eos_token + elif "completion" in example and not example["completion"].endswith(eos_token): + example["completion"] = example["completion"] + eos_token + return example + + column_names = list(next(iter(dataset)).keys()) + dataset = dataset.map( + add_eos, + fn_kwargs={"eos_token": processing_class.eos_token}, + remove_columns="messages" if "messages" in column_names else None, # renamed to "text" + **map_kwargs, + ) # Tokenize the dataset if isinstance(dataset, Dataset): # `IterableDataset.map` does not support `desc` map_kwargs["desc"] = f"Tokenizing {dataset_name} dataset" - dataset = dataset.map( - self.tokenize_row if not self.is_vision_model else self.process_row, - remove_columns=["chosen", "rejected"], - fn_kwargs={ - "processing_class": processing_class, - "max_prompt_length": args.max_prompt_length, - "max_completion_length": args.max_completion_length, - # for enc-dec, we add the special tokens ([bos_token] + prompt + [eos_token]; completion + [eos_token]) - "add_special_tokens": False, - }, - **map_kwargs, - ) - - return dataset - - @staticmethod - def tokenize_row( - features: dict[str, str], - processing_class: PreTrainedTokenizerBase, - max_prompt_length: Optional[int] = None, - max_completion_length: Optional[int] = None, - add_special_tokens: bool = True, - ) -> dict[str, list[int]]: - """ - Tokenize a row of the dataset. - - Args: - features (`dict[str, str]`): - Row of the dataset, should contain the keys `"prompt"`, `"chosen"`, and `"rejected"`. - processing_class (`PreTrainedTokenizerBase`): - Processing class used to process the data. - max_prompt_length (`int` or `None`): - Maximum length of the prompt sequence. If `None`, the prompt sequence is not truncated. - max_completion_length (`int` or `None`): - Maximum length of the completion sequences. If `None`, the completion sequences are not truncated. - add_special_tokens (`bool`): - Whether to add special tokens to the sequences. Typically used for encoder-decoder models. If `True`, - the prompt sequence will have a bos token prepended and an eos token appended. In any case, the - completion sequences will have an eos token appended. - - Returns: - `dict[str, list[int]]`: - Tokenized sequences with the keys `"prompt_input_ids"`, `"chosen_input_ids"`, and - `"rejected_input_ids". - - Example: - ```python - >>> from transformers import GPT2Tokenizer - - >>> tokenizer = GPT2Tokenizer.from_pretrained("gpt2") - >>> features = {"prompt": "The sky is", "chosen": " blue", "rejected": " green"} - >>> DPOTrainer.tokenize_row( - ... features, tokenizer, max_prompt_length=3, max_completion_length=3, add_special_tokens=False - ... ) - {'prompt_input_ids': [464, 6766, 318], 'chosen_input_ids': [4171, 50256], 'rejected_input_ids': [4077, 50256]} - ``` - """ - tokenizer = processing_class # the processing class is a tokenizer - prompt_input_ids = tokenizer(features["prompt"], add_special_tokens=False)["input_ids"] - chosen_input_ids = tokenizer(features["chosen"], add_special_tokens=False)["input_ids"] - rejected_input_ids = tokenizer(features["rejected"], add_special_tokens=False)["input_ids"] - - # Add special tokens (typically for encoder-decoder models) - if add_special_tokens: - if tokenizer.bos_token_id is not None: - prompt_input_ids = [tokenizer.bos_token_id] + prompt_input_ids - if tokenizer.eos_token_id is not None: - prompt_input_ids = prompt_input_ids + [tokenizer.eos_token_id] - chosen_input_ids = chosen_input_ids + [tokenizer.eos_token_id] - rejected_input_ids = rejected_input_ids + [tokenizer.eos_token_id] - - # Truncate prompt and completion sequences - if max_prompt_length is not None: - prompt_input_ids = prompt_input_ids[-max_prompt_length:] - if max_completion_length is not None: - chosen_input_ids = chosen_input_ids[:max_completion_length] - rejected_input_ids = rejected_input_ids[:max_completion_length] - - return { - "prompt_input_ids": prompt_input_ids, - "chosen_input_ids": chosen_input_ids, - "rejected_input_ids": rejected_input_ids, - } - - @staticmethod - def process_row( - features: dict[str, str], - processing_class: PreTrainedTokenizerBase, - max_prompt_length: Optional[int] = None, - max_completion_length: Optional[int] = None, - add_special_tokens: bool = True, - ) -> dict[str, list[int]]: - """ - Same as `tokenize_row` but for vision models. Please refer to `tokenize_row` for more information. - """ - processor, tokenizer = processing_class, processing_class.tokenizer # the processing class is a processor - processed_features = processor(images=features["images"], text=features["prompt"], add_special_tokens=False) - - prompt_input_ids = processed_features["input_ids"][0] - pixel_values = processed_features["pixel_values"][0] - chosen_input_ids = tokenizer(features["chosen"], add_special_tokens=False)["input_ids"] - rejected_input_ids = tokenizer(features["rejected"], add_special_tokens=False)["input_ids"] - - # Add special tokens (typically for encoder-decoder models) - if add_special_tokens: - if tokenizer.bos_token_id is not None: - prompt_input_ids = [tokenizer.bos_token_id] + prompt_input_ids - if tokenizer.eos_token_id is not None: - prompt_input_ids = prompt_input_ids + [tokenizer.eos_token_id] - chosen_input_ids = chosen_input_ids + [tokenizer.eos_token_id] - rejected_input_ids = rejected_input_ids + [tokenizer.eos_token_id] - - # Truncate prompt and completion sequences - if max_prompt_length is not None: - prompt_input_ids = prompt_input_ids[-max_prompt_length:] - if max_completion_length is not None: - chosen_input_ids = chosen_input_ids[:max_completion_length] - rejected_input_ids = rejected_input_ids[:max_completion_length] - - output = { - "prompt_input_ids": prompt_input_ids, - "pixel_values": pixel_values, - "chosen_input_ids": chosen_input_ids, - "rejected_input_ids": rejected_input_ids, - } - - if "pixel_attention_mask" in processed_features: - output["pixel_attention_mask"] = processed_features["pixel_attention_mask"][0] - if "image_sizes" in processed_features: - output["image_sizes"] = processed_features["image_sizes"][0] - - return output - - def _set_signature_columns_if_needed(self): - # If `self.args.remove_unused_columns` is True, non-signature columns are removed. - # By default, this method sets `self._signature_columns` to the model's expected inputs. - # In DPOTrainer, we preprocess data, so using the model's signature columns doesn't work. - # Instead, we set them to the columns expected by `DataCollatorForPreference`, hence the override. - if self._signature_columns is None: - self._signature_columns = [ - "prompt_input_ids", - "chosen_input_ids", - "rejected_input_ids", - "image_sizes", - "ref_chosen_logps", - "ref_rejected_logps", - ] - - def get_train_dataloader(self) -> DataLoader: - """ - Returns the training [`~torch.utils.data.DataLoader`]. - - Subclass of transformers.src.transformers.trainer.get_train_dataloader to precompute `ref_log_probs`. - """ - - if self.precompute_ref_log_probs and not self._precomputed_train_ref_log_probs: - batch_size = self.args.precompute_ref_batch_size or self.args.per_device_train_batch_size - dataloader_params = { - "batch_size": batch_size, - "collate_fn": self.data_collator, - "num_workers": self.args.dataloader_num_workers, - "pin_memory": self.args.dataloader_pin_memory, - "shuffle": False, - } - - # prepare dataloader - data_loader = self.accelerator.prepare(DataLoader(self.train_dataset, **dataloader_params)) - - ref_chosen_logps = [] - ref_rejected_logps = [] - for padded_batch in tqdm(iterable=data_loader, desc="Train dataset reference log probs"): - ref_chosen_logp, ref_rejected_logp = self.compute_ref_log_probs(padded_batch) - ref_chosen_logp, ref_rejected_logp = self.accelerator.gather_for_metrics( - (ref_chosen_logp, ref_rejected_logp) - ) - ref_chosen_logps.append(ref_chosen_logp.cpu()) - ref_rejected_logps.append(ref_rejected_logp.cpu()) - - # Unnecessary cache clearing to avoid OOM - empty_cache() - self.accelerator.free_memory() - - all_ref_chosen_logps = torch.cat(ref_chosen_logps).float().numpy() - all_ref_rejected_logps = torch.cat(ref_rejected_logps).float().numpy() - - self.train_dataset = self.train_dataset.add_column(name="ref_chosen_logps", column=all_ref_chosen_logps) - self.train_dataset = self.train_dataset.add_column( - name="ref_rejected_logps", column=all_ref_rejected_logps - ) - - self._precomputed_train_ref_log_probs = True - - return super().get_train_dataloader() - - def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader: - """ - Returns the evaluation [`~torch.utils.data.DataLoader`]. - - Subclass of transformers.src.transformers.trainer.get_eval_dataloader to precompute `ref_log_probs`. - - Args: - eval_dataset (`torch.utils.data.Dataset`, *optional*): - If provided, will override `self.eval_dataset`. If it is a [`~datasets.Dataset`], columns not accepted - by the `model.forward()` method are automatically removed. It must implement `__len__`. - """ - if eval_dataset is None and self.eval_dataset is None: - raise ValueError("Trainer: evaluation requires an eval_dataset.") - eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset - - if self.precompute_ref_log_probs and not self._precomputed_eval_ref_log_probs: - batch_size = self.args.precompute_ref_batch_size or self.args.per_device_eval_batch_size - dataloader_params = { - "batch_size": batch_size, - "collate_fn": self.data_collator, - "num_workers": self.args.dataloader_num_workers, - "pin_memory": self.args.dataloader_pin_memory, - "shuffle": False, - } - - # prepare dataloader - data_loader = self.accelerator.prepare(DataLoader(eval_dataset, **dataloader_params)) - - ref_chosen_logps = [] - ref_rejected_logps = [] - for padded_batch in tqdm(iterable=data_loader, desc="Eval dataset reference log probs"): - ref_chosen_logp, ref_rejected_logp = self.compute_ref_log_probs(padded_batch) - ref_chosen_logp, ref_rejected_logp = self.accelerator.gather_for_metrics( - (ref_chosen_logp, ref_rejected_logp) - ) - ref_chosen_logps.append(ref_chosen_logp.cpu()) - ref_rejected_logps.append(ref_rejected_logp.cpu()) - - all_ref_chosen_logps = torch.cat(ref_chosen_logps).float().numpy() - all_ref_rejected_logps = torch.cat(ref_rejected_logps).float().numpy() - - eval_dataset = eval_dataset.add_column(name="ref_chosen_logps", column=all_ref_chosen_logps) - eval_dataset = eval_dataset.add_column(name="ref_rejected_logps", column=all_ref_rejected_logps) - - # Save calculated ref_chosen_logps and ref_rejected_logps to the eval_dataset for subsequent runs - if self.eval_dataset is not None: - self.eval_dataset = eval_dataset - self._precomputed_eval_ref_log_probs = True - - return super().get_eval_dataloader(eval_dataset=eval_dataset) - - @contextmanager - def null_ref_context(self): - """Context manager for handling null reference model (that is, peft adapter manipulation).""" - with ( - self.accelerator.unwrap_model(self.model).disable_adapter() - if self.is_peft_model and not self.ref_adapter_name - else nullcontext() - ): - if self.ref_adapter_name: - self.model.set_adapter(self.ref_adapter_name) - yield - if self.ref_adapter_name: - self.model.set_adapter(self.model_adapter_name or "default") - - def compute_ref_log_probs(self, batch: dict[str, torch.LongTensor]) -> tuple[torch.Tensor, torch.Tensor]: - """Computes log probabilities of the reference model for a single padded batch of a DPO specific dataset.""" - compte_ref_context_manager = ( - autocast(self.accelerator.device.type) if self._peft_has_been_casted_to_bf16 else nullcontext() - ) - with torch.no_grad(), compte_ref_context_manager: - if self.ref_model is None: - with self.null_ref_context(): - ref_model_output = self.concatenated_forward(self.model, batch, is_ref_model=True) - else: - ref_model_output = self.concatenated_forward(self.ref_model, batch, is_ref_model=True) - return ref_model_output["chosen_logps"], ref_model_output["rejected_logps"] - - @staticmethod - def concatenated_inputs( - batch: dict[str, Union[list, torch.LongTensor]], padding_value: int - ) -> dict[str, torch.LongTensor]: - """ - Concatenate the `chosen` and `rejected` inputs from the batch into a single tensor for both the prompt and - completion sequences. - - Args: - batch (`dict[str, Union[list, torch.LongTensor]]`): - A batch of input data. The batch must contain the following keys: - - - `"prompt_input_ids"`: Tensor of shape `(batch_size, prompt_length)` representing the prompt input - IDs. - - `"chosen_input_ids"`: Tensor of shape `(batch_size, chosen_length)` representing the chosen - completion input IDs. - - `"rejected_input_ids"`: Tensor of shape `(batch_size, rejected_length)` representing the rejected - completion input IDs. - - `"prompt_pixel_values"` (optional): Tensor for pixel values, if available. - - `"prompt_pixel_attention_mask"` (optional): Tensor for pixel attention masks, if available. - - padding_value (`int`): - The padding value to use for the concatenated completion sequences (`chosen_input_ids` and - `rejected_input_ids`). - - Returns: - `dict[str, torch.LongTensor]`: A dictionary containing: - - - `"prompt_input_ids"`: Concatenated prompt input IDs of shape `(2 * batch_size, prompt_length)`. - - `"completion_input_ids"`: Concatenated chosen and rejected completion input IDs of shape `(2 * - batch_size, max_completion_length)`. - - `"prompt_attention_mask"`: Concatenated prompt attention masks of shape `(2 * batch_size, - prompt_length)`. - - `"completion_attention_mask"`: Concatenated chosen and rejected attention masks of shape `(2 * - batch_size, max_completion_length)`. - - `"pixel_values"` (optional): Concatenated pixel values if `"prompt_pixel_values"` are present. - - `"pixel_attention_mask"` (optional): Concatenated pixel attention masks if - `"prompt_pixel_attention_mask"` are present. - - Notes: - The completion input IDs and attention masks are padded to the maximum completion length of the chosen or - rejected sequences. - """ - output = {} - - # For the prompt, the input_ids are the same for both the chosen and rejected responses - output["prompt_input_ids"] = torch.cat([batch["prompt_input_ids"], batch["prompt_input_ids"]], dim=0) - output["prompt_attention_mask"] = torch.cat( - [batch["prompt_attention_mask"], batch["prompt_attention_mask"]], dim=0 - ) - if "pixel_values" in batch: - output["pixel_values"] = torch.cat([batch["pixel_values"], batch["pixel_values"]], dim=0) - - if "pixel_attention_mask" in batch: - output["pixel_attention_mask"] = torch.cat( - [batch["pixel_attention_mask"], batch["pixel_attention_mask"]], dim=0 - ) - if "image_sizes" in batch: - output["image_sizes"] = torch.cat([batch["image_sizes"], batch["image_sizes"]], dim=0) - - # Concatenate the chosen and rejected completions - max_completion_length = max(batch["chosen_input_ids"].shape[1], batch["rejected_input_ids"].shape[1]) - output["completion_input_ids"] = torch.cat( - ( - pad_to_length(batch["chosen_input_ids"], max_completion_length, pad_value=padding_value), - pad_to_length(batch["rejected_input_ids"], max_completion_length, pad_value=padding_value), - ), - ) - output["completion_attention_mask"] = torch.cat( - ( - pad_to_length(batch["chosen_attention_mask"], max_completion_length, pad_value=0), - pad_to_length(batch["rejected_attention_mask"], max_completion_length, pad_value=0), - ), - ) - - return output - - def dpo_loss( - self, - chosen_logps: torch.FloatTensor, - rejected_logps: torch.FloatTensor, - ref_chosen_logps: torch.FloatTensor, - ref_rejected_logps: torch.FloatTensor, - loss_type: str = "sigmoid", - model_output: dict[str, torch.FloatTensor] = None, - ) -> tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]: - """ - Compute the DPO loss for a batch of policy and reference model log probabilities. - - Args: - chosen_logps (`torch.FloatTensor`): - Log probabilities of the model for the chosen responses. Shape: `(batch_size,)`. - rejected_logps (`torch.FloatTensor`): - Log probabilities of the model for the rejected responses. Shape: `(batch_size,)`. - ref_chosen_logps (`torch.FloatTensor`): - Log probabilities of the reference model for the chosen responses. Shape: `(batch_size,)`. - ref_rejected_logps (`torch.FloatTensor`): - Log probabilities of the reference model for the rejected responses. Shape: `(batch_size,)`. - - Returns: - A tuple of three tensors: `(losses, chosen_rewards, rejected_rewards)`. The losses tensor contains the DPO - loss for each example in the batch. The `chosen_rewards` and `rejected_rewards` tensors contain the rewards - for the chosen and rejected responses, respectively. - """ - device = self.accelerator.device - - # Get the log ratios for the chosen and rejected responses - chosen_logratios = chosen_logps.to(device) - (not self.reference_free) * ref_chosen_logps.to(device) - rejected_logratios = rejected_logps.to(device) - (not self.reference_free) * ref_rejected_logps.to(device) - - if self.f_divergence_type == FDivergenceType.ALPHA_DIVERGENCE.value: - # The alpha-divergence formula: (1 - u^-alpha) / alpha - # The divergence difference between the chosen and rejected sample is: - # (1 - u[w]^-alpha) / alpha - (1 - u[l]^-alpha) / alpha - # = (u[l]^-alpha - u[w]^-alpha) / alpha - # where u[w] and u[l] are the policy/reference probability ratios - # for the chosen and rejected samples, respectively. - alpha_coef = FDivergenceConstants.ALPHA_DIVERGENCE_COEF_DEFAULT - if self.f_divergence_params and FDivergenceConstants.ALPHA_DIVERGENCE_COEF_KEY in self.f_divergence_params: - alpha_coef = float(self.f_divergence_params[FDivergenceConstants.ALPHA_DIVERGENCE_COEF_KEY]) - logits = (cap_exp(rejected_logratios * -alpha_coef) - cap_exp(chosen_logratios * -alpha_coef)) / alpha_coef - else: - logratios = chosen_logps - rejected_logps - if self.reference_free: - ref_logratios = torch.tensor([0], dtype=logratios.dtype, device=logratios.device) - else: - ref_logratios = ref_chosen_logps - ref_rejected_logps - - logratios = logratios.to(self.accelerator.device) - ref_logratios = ref_logratios.to(self.accelerator.device) - logits = logratios - ref_logratios - - if self.f_divergence_type == FDivergenceType.JS_DIVERGENCE.value: - # The js-divergence formula: log(2 * u / (1 + u)) - # The divergence difference between the chosen and rejected sample is: - # log(2 * u[w] / (1 + u[w])) - log(2 * u[l] / (1 + u[l])) - # = log(u[w]) - log(u[l]) - (log(1 + u[w]) - log(1 + u[l])) - # where u[w] and u[l] are the policy/reference probability ratios - # for the chosen and rejected samples, respectively. - logits -= F.softplus(chosen_logratios) - F.softplus(rejected_logratios) - - # The beta is a temperature parameter for the DPO loss, typically something in the range of 0.1 to 0.5. - # We ignore the reference model as beta -> 0. The label_smoothing parameter encodes our uncertainty about the - # labels and calculates a conservative DPO loss. - if loss_type == "sigmoid": - losses = ( - -F.logsigmoid(self.beta * logits) * (1 - self.label_smoothing) - - F.logsigmoid(-self.beta * logits) * self.label_smoothing - ) - - elif loss_type == "robust": - losses = ( - -F.logsigmoid(self.beta * logits) * (1 - self.label_smoothing) - + F.logsigmoid(-self.beta * logits) * self.label_smoothing - ) / (1 - 2 * self.label_smoothing) - - elif loss_type == "exo_pair": - # eqn (16) of the EXO paper: https://huggingface.co/papers/2402.00856 - import math - - if self.label_smoothing == 0: - self.label_smoothing = 1e-3 - losses = (self.beta * logits).sigmoid() * ( - F.logsigmoid(self.beta * logits) - math.log(1 - self.label_smoothing) - ) + (-self.beta * logits).sigmoid() * (F.logsigmoid(-self.beta * logits) - math.log(self.label_smoothing)) - - elif loss_type == "hinge": - losses = torch.relu(1 - self.beta * logits) - - elif loss_type == "ipo": - # eqn (17) of the paper where beta is the regularization parameter for the IPO loss, denoted by tau in the paper. - losses = (logits - 1 / (2 * self.beta)) ** 2 - - elif loss_type == "bco_pair": - chosen_logratios = chosen_logps - ref_chosen_logps - rejected_logratios = rejected_logps - ref_rejected_logps - chosen_rewards = self.beta * chosen_logratios - rejected_rewards = self.beta * rejected_logratios - rewards = torch.cat((chosen_rewards, rejected_rewards), 0).mean().detach() - self.running.update(rewards) - delta = self.running.mean - losses = -F.logsigmoid((self.beta * chosen_logratios) - delta) - F.logsigmoid( - -(self.beta * rejected_logratios - delta) - ) - - elif loss_type == "sppo_hard": - # In the paper (https://huggingface.co/papers/2405.00675), SPPO employs a soft probability approach, - # estimated using the PairRM score. The probability calculation is conducted outside of the trainer class. - # The version described here is the hard probability version, where P in Equation (4.7) of Algorithm 1 is - # set to 1 for the winner and 0 for the loser. - a = chosen_logps - ref_chosen_logps - b = rejected_logps - ref_rejected_logps - losses = (a - 0.5 / self.beta) ** 2 + (b + 0.5 / self.beta) ** 2 - - elif loss_type == "nca_pair": - chosen_rewards = (chosen_logps - ref_chosen_logps) * self.beta - rejected_rewards = (rejected_logps - ref_rejected_logps) * self.beta - losses = ( - -F.logsigmoid(chosen_rewards) - - 0.5 * F.logsigmoid(-chosen_rewards) - - 0.5 * F.logsigmoid(-rejected_rewards) - ) - - elif loss_type == "aot_pair": - chosen_logratios = chosen_logps - ref_chosen_logps - rejected_logratios = rejected_logps - ref_rejected_logps - chosen_logratios_sorted, _ = torch.sort(chosen_logratios, dim=0) - rejected_logratios_sorted, _ = torch.sort(rejected_logratios, dim=0) - delta = chosen_logratios_sorted - rejected_logratios_sorted - losses = ( - -F.logsigmoid(self.beta * delta) * (1 - self.label_smoothing) - - F.logsigmoid(-self.beta * delta) * self.label_smoothing - ) - - elif loss_type == "aot": - logratios = chosen_logps - rejected_logps - ref_logratios = ref_chosen_logps - ref_rejected_logps - logratios_sorted, _ = torch.sort(logratios, dim=0) - ref_logratios_sorted, _ = torch.sort(ref_logratios, dim=0) - delta = logratios_sorted - ref_logratios_sorted - losses = ( - -F.logsigmoid(self.beta * delta) * (1 - self.label_smoothing) - - F.logsigmoid(-self.beta * delta) * self.label_smoothing - ) - - elif loss_type == "apo_zero": - # Eqn (7) of the APO paper (https://huggingface.co/papers/2408.06266) - # Use this loss when you believe the chosen outputs are better than your model's default output - losses_chosen = 1 - F.sigmoid(self.beta * chosen_logratios) # Increase chosen likelihood - losses_rejected = F.sigmoid(self.beta * rejected_logratios) # Decrease rejected likelihood - losses = losses_chosen + losses_rejected - - elif loss_type == "apo_down": - # Eqn (8) of the APO paper (https://huggingface.co/papers/2408.06266) - # Use this loss when you believe the chosen outputs are worse than your model's default output. - # Decrease chosen likelihood and decrease rejected likelihood more - losses_chosen = F.sigmoid(self.beta * chosen_logratios) - losses_rejected = 1 - F.sigmoid(self.beta * (chosen_logratios - rejected_logratios)) - losses = losses_chosen + losses_rejected - - elif loss_type == "discopop": - # Eqn (5) of the DiscoPOP paper (https://huggingface.co/papers/2406.08414) - # This loss was discovered with LLM discovery - logratios = chosen_logps - rejected_logps - ref_logratios = ref_chosen_logps - ref_rejected_logps - logits = logratios - ref_logratios - logits = logits * self.beta - # Modulate the mixing coefficient based on the log ratio magnitudes - log_ratio_modulation = torch.sigmoid(logits / self.args.discopop_tau) - logistic_component = -F.logsigmoid(logits) - exp_component = torch.exp(-logits) - # Blend between logistic and exponential component based on log ratio modulation - losses = logistic_component * (1 - log_ratio_modulation) + exp_component * log_ratio_modulation - - elif loss_type == "sft": - # SFT loss is the negative log likelihood loss on chosen responses - # This acts as the generation loss component in MPO - sft_loss = model_output["nll_loss"] - # Create losses tensor with same shape as other losses (per-sample) - batch_size = chosen_logps.shape[0] - losses = sft_loss.expand(batch_size) - # For SFT, we don't have preference rewards, so use zeros - chosen_rewards = torch.zeros_like(chosen_logps) - rejected_rewards = torch.zeros_like(rejected_logps) - - else: - raise ValueError( - f"Unknown loss type: {self.loss_type}. Should be one of ['sigmoid', 'hinge', 'ipo', 'exo_pair', " - "'nca_pair', 'robust', 'bco_pair', 'sppo_hard', 'aot', 'aot_pair', 'discopop', 'apo_zero', " - "'apo_down', 'sft']" - ) - - chosen_rewards = self.beta * (chosen_logps.to(device) - ref_chosen_logps.to(device)).detach() - rejected_rewards = self.beta * (rejected_logps.to(device) - ref_rejected_logps.to(device)).detach() - - return losses, chosen_rewards, rejected_rewards - - def _compute_loss_liger( - self, model: nn.Module, batch: dict[str, Union[list, torch.LongTensor]] - ) -> dict[str, torch.Tensor]: - unwrapped_model = self.accelerator.unwrap_model(model) - concatenated_batch = self.concatenated_inputs(batch, padding_value=self.padding_value) - - model_kwargs = {} - if self.aux_loss_enabled: - model_kwargs["output_router_logits"] = True - - # Add the pixel values and attention masks for vision models - if "pixel_values" in concatenated_batch: - model_kwargs["pixel_values"] = concatenated_batch["pixel_values"] - if "pixel_attention_mask" in concatenated_batch: - model_kwargs["pixel_attention_mask"] = concatenated_batch["pixel_attention_mask"] - if "image_sizes" in concatenated_batch: - model_kwargs["image_sizes"] = concatenated_batch["image_sizes"] - - prompt_attention_mask = concatenated_batch["prompt_attention_mask"] - completion_attention_mask = concatenated_batch["completion_attention_mask"] - - if self.is_encoder_decoder: - # 1. Get encoder outputs - encoder_outputs = unwrapped_model.get_encoder()( - concatenated_batch["prompt_input_ids"], - attention_mask=concatenated_batch["prompt_attention_mask"], - return_dict=True, - ) - # 2. Prepare decoder inputs - decoder_input_ids = shift_tokens_right( - concatenated_batch["completion_input_ids"], - unwrapped_model.config.decoder_start_token_id, - ) - # 3. Get decoder outputs - decoder_outputs = unwrapped_model.get_decoder()( - input_ids=decoder_input_ids, - attention_mask=concatenated_batch["completion_attention_mask"], - encoder_hidden_states=encoder_outputs.last_hidden_state, - encoder_attention_mask=concatenated_batch["prompt_attention_mask"], - use_cache=False, - ) - hidden_states = decoder_outputs.last_hidden_state - - ref_hidden_states = None - if not self.reference_free and self.ref_model is not None: - unwrapped_ref_model = self.accelerator.unwrap_model(self.ref_model) - ref_encoder_outputs = unwrapped_ref_model.get_encoder()( - concatenated_batch["prompt_input_ids"], - attention_mask=concatenated_batch["prompt_attention_mask"], - return_dict=True, - ) - ref_decoder_outputs = unwrapped_ref_model.get_decoder()( - input_ids=decoder_input_ids, - attention_mask=concatenated_batch["completion_attention_mask"], - encoder_hidden_states=ref_encoder_outputs.last_hidden_state, - encoder_attention_mask=concatenated_batch["prompt_attention_mask"], - use_cache=False, - ) - ref_hidden_states = ref_decoder_outputs.last_hidden_state - elif not self.reference_free: - with self.null_ref_context(): - ref_encoder_outputs = unwrapped_model.get_encoder()( - concatenated_batch["prompt_input_ids"], - attention_mask=concatenated_batch["prompt_attention_mask"], - return_dict=True, + def tokenize(example, processing_class): + output = {} + if is_conversational(example): + prompt_ids = processing_class.apply_chat_template( + example["prompt"], + tools=example.get("tools"), + **example.get("chat_template_kwargs", {}), ) - ref_decoder_outputs = unwrapped_model.get_decoder()( - input_ids=decoder_input_ids, - attention_mask=concatenated_batch["completion_attention_mask"], - encoder_hidden_states=ref_encoder_outputs.last_hidden_state, - encoder_attention_mask=concatenated_batch["prompt_attention_mask"], - use_cache=False, + prompt_chosen_ids = processing_class.apply_chat_template( + example["prompt"] + example["chosen"], + tools=example.get("tools"), + **example.get("chat_template_kwargs", {}), ) - ref_hidden_states = ref_decoder_outputs.last_hidden_state - - labels = concatenated_batch["completion_input_ids"] - loss_mask = completion_attention_mask.bool() - else: - # For decoder-only models - input_ids = torch.cat( - (concatenated_batch["prompt_input_ids"], concatenated_batch["completion_input_ids"]), dim=1 - ) - attention_mask = torch.cat( - (concatenated_batch["prompt_attention_mask"], concatenated_batch["completion_attention_mask"]), - dim=1, - ) - # Mask the prompt but not the completion for the loss - loss_mask = torch.cat( - (torch.zeros_like(prompt_attention_mask), completion_attention_mask), - dim=1, - ) - - # Flush and truncate - if self.max_length is not None and self.max_length < attention_mask.size(1): - if self.truncation_mode == "keep_start": - # Flush left to reduce the memory usage - # [[0, 0, x, x, x, x], -> [[x, x, x, x], - # [0, x, x, x, 0, 0]] [x, x, x, 0]] - attention_mask, input_ids, loss_mask = flush_left(attention_mask, input_ids, loss_mask) - attention_mask = attention_mask[:, : self.max_length] - input_ids = input_ids[:, : self.max_length] - loss_mask = loss_mask[:, : self.max_length] - elif self.truncation_mode == "keep_end": - # Flush right before truncating left, then flush left - # [[0, 0, x, x, x, x], -> [[0, 0, x, x], - # [0, x, x, x, 0, 0]] [0, x, x, x]] - attention_mask, input_ids, loss_mask = flush_right(attention_mask, input_ids, loss_mask) - input_ids = input_ids[:, -self.max_length :] - attention_mask = attention_mask[:, -self.max_length :] - loss_mask = loss_mask[:, -self.max_length :] - attention_mask, input_ids, loss_mask = flush_left(attention_mask, input_ids, loss_mask) - else: - raise ValueError( - f"Unknown truncation mode: '{self.truncation_mode}'. Should be one of ['keep_end', " - "'keep_start']." + prompt_rejected_ids = processing_class.apply_chat_template( + example["prompt"] + example["rejected"], + tools=example.get("tools"), + **example.get("chat_template_kwargs", {}), ) - else: - # Flush left to reduce the memory usage - # [[0, 0, x, x, x, x], -> [[x, x, x, x], - # [0, x, x, x, 0, 0]] [x, x, x, 0]] - attention_mask, input_ids, loss_mask = flush_left(attention_mask, input_ids, loss_mask) - - # Add logits_to_keep optimization - if self.use_logits_to_keep: - first_compute_index = loss_mask.nonzero(as_tuple=True)[1].min() - logits_to_keep = (loss_mask.shape[1] - first_compute_index).item() + 1 - model_kwargs["logits_to_keep"] = logits_to_keep - - model_kwargs["output_hidden_states"] = True - - # Add padding-free training support - if self.padding_free: - input_ids = input_ids[attention_mask.bool()].unsqueeze(0) - loss_mask = loss_mask[attention_mask.bool()].unsqueeze(0) - position_ids = attention_mask.cumsum(1)[attention_mask.bool()].unsqueeze(0) - 1 - model_kwargs["position_ids"] = position_ids - else: - model_kwargs["attention_mask"] = attention_mask - - # Get the base model outputs (before LM head) - if hasattr(unwrapped_model, "get_decoder"): - base_model = unwrapped_model.get_decoder() - else: - base_model = getattr(unwrapped_model, self.args.base_model_attribute_name, unwrapped_model) - - outputs = base_model( - input_ids, - use_cache=False, - **model_kwargs, - ) - hidden_states = outputs.last_hidden_state[:, :-1] - - # Get reference hidden states if needed - ref_hidden_states = None - if not self.reference_free and self.ref_model is not None: - unwrapped_ref_model = self.accelerator.unwrap_model(self.ref_model) - if hasattr(unwrapped_ref_model, "get_decoder"): - ref_base_model = unwrapped_ref_model.get_decoder() else: - ref_base_model = getattr( - unwrapped_ref_model, self.args.base_model_attribute_name, unwrapped_ref_model + prompt_ids = processing_class(text=example["prompt"])["input_ids"] + prompt_completion_ids = processing_class(text=example["prompt"] + example["completion"])[ + "input_ids" + ] + + # Check if the tokenized prompt starts with the tokenized prompt+completion + if not prompt_chosen_ids[: len(prompt_ids)] == prompt_ids: + warnings.warn( + "Mismatch between tokenized prompt and the start of tokenized prompt+chosen. " + "This may be due to unexpected tokenizer behavior, whitespace issues, or special " + "token handling. Verify that the tokenizer is processing text consistently." ) - - ref_outputs = ref_base_model( - input_ids, - use_cache=False, - **model_kwargs, - ) - ref_hidden_states = ref_outputs.last_hidden_state[:, :-1] - elif not self.reference_free: - if hasattr(unwrapped_model, "get_decoder"): - ref_base_model = unwrapped_model.get_decoder() - else: - ref_base_model = getattr(unwrapped_model, self.args.base_model_attribute_name, unwrapped_model) - with self.null_ref_context(): - ref_outputs = ref_base_model( - input_ids, - use_cache=False, - **model_kwargs, + if not prompt_rejected_ids[: len(prompt_ids)] == prompt_ids: + warnings.warn( + "Mismatch between tokenized prompt and the start of tokenized prompt+rejected. " + "This may be due to unexpected tokenizer behavior, whitespace issues, or special " + "token handling. Verify that the tokenizer is processing text consistently." ) - ref_hidden_states = ref_outputs.last_hidden_state[:, :-1] - - masked_input_ids = torch.where(loss_mask != 0, input_ids, self.label_pad_token_id) - labels = masked_input_ids[:, 1:] # Shift right for casual LM - # Get the LM head - lm_head = unwrapped_model.get_output_embeddings() + # Create a completion mask + chosen_mask = [0] * len(prompt_ids) + [1] * (len(prompt_chosen_ids) - len(prompt_ids)) + rejected_mask = [0] * len(prompt_ids) + [1] * (len(prompt_rejected_ids) - len(prompt_ids)) + output["chosen_ids"] = prompt_chosen_ids + output["rejected_ids"] = prompt_rejected_ids + output["chosen_mask"] = chosen_mask + output["rejected_mask"] = rejected_mask + + return output + + dataset = dataset.map(tokenize, fn_kwargs={"processing_class": processing_class}, **map_kwargs) + + # Truncate + if args.max_length is not None: + if isinstance(dataset, Dataset): # `IterableDataset.map` does not support `desc` + map_kwargs["desc"] = f"Truncating {dataset_name} dataset" + dataset = truncate_dataset(dataset, args.max_length, map_kwargs) + # For Liger kernel, ensure only the essential columns + if args.use_liger_kernel: + dataset = dataset.select_columns( + {"input_ids", "seq_lengths", "completion_mask"}.intersection(dataset.column_names) + ) + return dataset - # Get reference model weights if needed - ref_weight = None - ref_bias = None - if not self.reference_free: - if self.ref_model is not None: - unwrapped_ref_model = self.accelerator.unwrap_model(self.ref_model) - ref_lm_head = unwrapped_ref_model.get_output_embeddings() + def _set_signature_columns_if_needed(self): + # If `self.args.remove_unused_columns` is True, non-signature columns are removed. + # By default, this method sets `self._signature_columns` to the model's expected inputs (usually, "input_ids" + # and "attention_mask"). + if self._signature_columns is None: + if self._is_vlm: + self._signature_columns = ["messages", "images"] else: - with self.null_ref_context(): - ref_lm_head = unwrapped_model.get_output_embeddings() - ref_weight = ref_lm_head.weight - ref_bias = ref_lm_head.bias if hasattr(ref_lm_head, "bias") else None - - # Compute loss using Liger kernel - loss_output = self.dpo_loss_fn( - lm_head.weight, - hidden_states, - labels, - bias=lm_head.bias if hasattr(lm_head, "bias") else None, - ref_input=ref_hidden_states if not self.reference_free else None, - ref_weight=ref_weight if not self.reference_free else None, - ref_bias=ref_bias if not self.reference_free else None, - ) - ( - loss, - (chosen_logps, rejected_logps, chosen_logits_mean, rejected_logits_mean, nll_loss, *aux_outputs), - ) = loss_output - - output = { - "loss": loss, - "chosen_logps": chosen_logps, - "rejected_logps": rejected_logps, - "mean_chosen_logits": chosen_logits_mean, - "mean_rejected_logits": rejected_logits_mean, - "nll_loss": nll_loss, - "chosen_rewards": aux_outputs[0], - "rejected_rewards": aux_outputs[1], - } - if self.aux_loss_enabled: - output["aux_loss"] = outputs.aux_loss - - return output + self._signature_columns = ["chosen_ids", "rejected_ids", "chosen_mask", "rejected_mask"] - def concatenated_forward( - self, model: nn.Module, batch: dict[str, Union[list, torch.LongTensor]], is_ref_model: bool = False - ) -> dict[str, torch.Tensor]: + def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None): """ - Runs the given model on the given batch of inputs, concatenating the chosen and rejected inputs together. - - We do this to avoid doing two forward passes, because it's faster for FSDP. - - Args: - model: - Model to run the forward pass on. - batch: - Batch of input data. - is_ref_model: - Whether this method is being called for the reference model. If `True`, length desensitization is not - applied. + Compute training loss and additionally compute token accuracies """ - num_examples = batch["prompt_input_ids"].shape[0] - - concatenated_batch = self.concatenated_inputs(batch, padding_value=self.padding_value) - - model_kwargs = {"use_cache": False} - if self.aux_loss_enabled: - model_kwargs["output_router_logits"] = True - - # Add the pixel values and attention masks for vision models - if "pixel_values" in concatenated_batch: - model_kwargs["pixel_values"] = concatenated_batch["pixel_values"] - if "pixel_attention_mask" in concatenated_batch: - model_kwargs["pixel_attention_mask"] = concatenated_batch["pixel_attention_mask"] - if "image_sizes" in concatenated_batch: - model_kwargs["image_sizes"] = concatenated_batch["image_sizes"] - - prompt_input_ids = concatenated_batch["prompt_input_ids"] - prompt_attention_mask = concatenated_batch["prompt_attention_mask"] - completion_input_ids = concatenated_batch["completion_input_ids"] - completion_attention_mask = concatenated_batch["completion_attention_mask"] - if self.is_encoder_decoder: - labels = completion_input_ids - labels[completion_attention_mask == 0] = self.label_pad_token_id - outputs = model( - input_ids=prompt_input_ids, - attention_mask=prompt_attention_mask, - labels=labels, # we need the labels for the logits to be returned - **model_kwargs, - ) - logits = outputs.logits - loss_mask = completion_attention_mask.bool() - else: - # Concatenate the prompt and completion inputs - input_ids = torch.cat((prompt_input_ids, completion_input_ids), dim=1) - attention_mask = torch.cat((prompt_attention_mask, completion_attention_mask), dim=1) - # Mask the prompt but not the completion for the loss - loss_mask = torch.cat( - (torch.zeros_like(prompt_attention_mask), completion_attention_mask), - dim=1, - ) - - # Flush and truncate - if self.max_length is not None and self.max_length < attention_mask.size(1): - if self.truncation_mode == "keep_start": - # Flush left to reduce the memory usage - # [[0, 0, x, x, x, x], -> [[x, x, x, x], - # [0, x, x, x, 0, 0]] [x, x, x, 0]] - attention_mask, input_ids, loss_mask = flush_left(attention_mask, input_ids, loss_mask) - attention_mask = attention_mask[:, : self.max_length] - input_ids = input_ids[:, : self.max_length] - loss_mask = loss_mask[:, : self.max_length] - elif self.truncation_mode == "keep_end": - # Flush right before truncating left, then flush left - # [[0, 0, x, x, x, x], -> [[0, 0, x, x], - # [0, x, x, x, 0, 0]] [0, x, x, x]] - attention_mask, input_ids, loss_mask = flush_right(attention_mask, input_ids, loss_mask) - input_ids = input_ids[:, -self.max_length :] - attention_mask = attention_mask[:, -self.max_length :] - loss_mask = loss_mask[:, -self.max_length :] - attention_mask, input_ids, loss_mask = flush_left(attention_mask, input_ids, loss_mask) - else: - raise ValueError( - f"Unknown truncation mode: '{self.truncation_mode}'. Should be one of ['keep_end', " - "'keep_start']." - ) - else: - # Flush left to reduce the memory usage - # [[0, 0, x, x, x, x], -> [[x, x, x, x], - # [0, x, x, x, 0, 0]] [x, x, x, 0]] - attention_mask, input_ids, loss_mask = flush_left(attention_mask, input_ids, loss_mask) - - if self.use_logits_to_keep: - # Compute logits_to_keep based on loss_mask pattern: - # [[0, 0, 0, x, x, x, x], - # [0, 0, 0, x, x, x, 0]] - # ^ start computing logits from here ([:, -(7-3+1):]) - first_compute_index = loss_mask.nonzero(as_tuple=True)[1].min() - logits_to_keep = (loss_mask.shape[1] - first_compute_index).item() + 1 # +1 for the first label - model_kwargs["logits_to_keep"] = logits_to_keep - - model_kwargs["output_hidden_states"] = True - - if self.padding_free: - # Flatten the input_ids, position_ids, and loss_mask - # input_ids = [[a, b, c, 0], -> input_ids = [[a, b, c, d, e, f, g]] - # [d, e, f, g]] position_ids = [[0, 1, 2, 0, 1, 2, 3]] - input_ids = input_ids[attention_mask.bool()].unsqueeze(0) - loss_mask = loss_mask[attention_mask.bool()].unsqueeze(0) - position_ids = attention_mask.cumsum(1)[attention_mask.bool()].unsqueeze(0) - 1 - model_kwargs["position_ids"] = position_ids + mode = "train" if self.model.training else "eval" + + + shifted_labels = inputs["input_ids"][:, 1:] # (B, L-1) + shifted_mask = inputs["completion_mask"][:, 1:] # (B, L-1) + + logits = model(**inputs, use_cache=False).logits + ref_logits = self.ref_model(**inputs, use_cache=False).logits + + shifted_logits = ref_logits[:, :-1, :] # (B, L-1, V) + ref_all_logps = F.log_softmax(shifted_logits, dim=-1) # (B, L-1, V) + ref_per_token_logprobs = ref_all_logps.gather(-1, shifted_labels.unsqueeze(-1)).squeeze(-1) # (B, L-1) + ref_per_token_logprobs = ref_per_token_logprobs * shifted_mask + ref_logprobs = ref_per_token_logprobs.sum(dim=-1) # (B,) + ref_chosen_logprobs = ref_logprobs[: ref_logprobs.size(0) // 2] # (B//2,) + ref_rejected_logprobs = ref_logprobs[ref_logprobs.size(0) // 2 :] # (B//2,) + + shifted_logits = logits[:, :-1, :] # (B, L-1, V) + all_logps = F.log_softmax(shifted_logits, dim=-1) # (B, L-1, V) + per_token_logprobs = all_logps.gather(-1, shifted_labels.unsqueeze(-1)).squeeze(-1) # (B, L-1) + per_token_logprobs = per_token_logprobs * shifted_mask + logprobs = per_token_logprobs.sum(dim=-1) # (B,) + chosen_logprobs = logprobs[: logprobs.size(0) // 2] # (B//2,) + rejected_logprobs = logprobs[logprobs.size(0) // 2 :] # (B//2,) + + + if mode == "train": + # When using padding-free, the attention_mask is not present in the inputs, instead we have cu_seq_lens_q, + # cu_seq_lens_k, and max_length_k, max_length_q and position_ids. + if "attention_mask" in inputs: + num_tokens_in_batch = self.accelerator.gather_for_metrics(inputs["attention_mask"].sum()).sum().item() + elif "position_ids" in inputs: + local_num_tokens = torch.tensor(inputs["position_ids"].size(1), device=inputs["position_ids"].device) + num_tokens_in_batch = self.accelerator.gather_for_metrics(local_num_tokens).sum().item() else: - model_kwargs["attention_mask"] = attention_mask - - outputs = model(input_ids, **model_kwargs) - logits = outputs.logits - - # Offset the logits by one to align with the labels - labels = torch.roll(input_ids, shifts=-1, dims=1) - loss_mask = torch.roll(loss_mask, shifts=-1, dims=1).bool() - - if self.use_logits_to_keep: - # Align labels with logits - # logits: -, -, [x2, x3, x4, x5, x6] - # ^ --------- ^ after logits[:, :-1, :] - # labels: [y0, y1, y2, y3, y4, y5, y6] - # ^ --------- ^ with logits_to_keep=4, [:, -4:] - # loss_mask: [0, 0, 0, 1, 1, 1, 1] - labels = labels[:, -logits_to_keep:] - loss_mask = loss_mask[:, -logits_to_keep:] - - if logits.shape[:2] != labels.shape[:2]: - # for LLaVA, the returned logits include the image tokens (placed before the text tokens) - seq_len = labels.shape[1] - logits = logits[:, -seq_len:] - - # Compute the log probabilities of the labels - labels[~loss_mask] = 0 # dummy token; we'll ignore the losses on these tokens later - per_token_logps = selective_log_softmax(logits, labels) - per_token_logps[~loss_mask] = 0 - per_token_logps = torch.roll(per_token_logps, shifts=1, dims=1) - - if self.padding_free: - # Unflatten the per_token_logps (shape: [1, sum_seq_len] -> [batch_size, seq_len]) - batch_size, seq_len = attention_mask.shape - per_token_logps_ = torch.zeros( - batch_size, seq_len, device=outputs.logits.device, dtype=outputs.logits.dtype - ) - per_token_logps_[attention_mask.bool()] = per_token_logps - per_token_logps = per_token_logps_ - - all_logps = per_token_logps[:, 1:].sum(-1) + raise ValueError("Expected 'attention_mask' or 'position_ids' in inputs.") + self._total_train_tokens += num_tokens_in_batch + self._metrics[mode]["num_tokens"] = [self._total_train_tokens] - output = {} - - if self.use_weighting: + # Compute token accuracy if we have labels and if the model is not using Liger (no logits) + if "labels" in inputs and not self.args.use_liger_kernel: with torch.no_grad(): - # Eq (2) of the WPO paper: https://huggingface.co/papers/2406.11827 - logprobs = F.log_softmax(logits, dim=-1) - weights_adjustment_factor = torch.logsumexp(2 * logprobs, dim=-1) # same as sum(probs**2) in log space - per_token_logps_adjusted = per_token_logps - weights_adjustment_factor - all_weights = (per_token_logps_adjusted * loss_mask).sum(-1) / loss_mask.sum(-1) - chosen_weights = all_weights[:num_examples] - rejected_weights = all_weights[num_examples:] - output["policy_weights"] = torch.clamp(torch.exp(chosen_weights + rejected_weights), max=1) - - if self.args.rpo_alpha is not None or "sft" in self.loss_type: - # Only use the chosen logits for the RPO loss or SFT loss - chosen_logits = logits[:num_examples, :-1] if not self.is_encoder_decoder else logits[:num_examples] - chosen_labels = labels[:num_examples, :-1] if not self.is_encoder_decoder else labels[:num_examples] - - # Compute the log probabilities of the labels - output["nll_loss"] = F.cross_entropy( - torch.flatten(chosen_logits, end_dim=1), torch.flatten(chosen_labels, end_dim=1), ignore_index=0 - ) + shift_logits = outputs.logits[..., :-1, :].contiguous() + shift_labels = inputs["labels"][..., 1:].contiguous() - if "ipo" in self.loss_type: - all_logps = all_logps / loss_mask.sum(-1) + # When using Prompt Tuning, skip the virtual tokens in logits before accuracy computation, since they do + # not correspond to actual input labels. + shift_logits = shift_logits[:, self.num_virtual_tokens :, :] - if self.args.ld_alpha is not None and not is_ref_model: - # Compute response lengths based on loss_mask - completion_lengths = loss_mask.sum(dim=1) + # Get predictions + predictions = shift_logits.argmax(dim=-1) - chosen_lengths = completion_lengths[:num_examples] - rejected_lengths = completion_lengths[num_examples:] - public_lengths = torch.min(chosen_lengths, rejected_lengths) # l_p in the paper - public_lengths = torch.cat([public_lengths, public_lengths], dim=0) + # Create mask for non-padding tokens (assuming ignore_index is -100) + mask = shift_labels != -100 - seq_len = per_token_logps.size(1) - position_ids = torch.arange(seq_len, device=per_token_logps.device).expand_as(per_token_logps) - - ld_mask = position_ids < public_lengths.unsqueeze(1) - mask = position_ids < completion_lengths.unsqueeze(1) - - front_mask = (ld_mask & mask).float() - rear_mask = (~ld_mask & mask).float() - front_logps = (per_token_logps * front_mask).sum(dim=1) - rear_logps = (per_token_logps * rear_mask).sum(dim=1) - - all_logps = front_logps + self.args.ld_alpha * rear_logps - - output["chosen_logps"] = all_logps[:num_examples] - output["rejected_logps"] = all_logps[num_examples:] - - # Compute the mean logits - if self.padding_free: - # position_ids contains a sequence of range identifiers (e.g., [[0, 1, 2, 0, 1, 2, 3, ...]]). - # There are 2*num_examples ranges in total: the first half corresponds to the chosen tokens, - # and the second half to the rejected tokens. - # To find the start of the rejected tokens, we look for the num_examples+1-th zero in pos_id. - split_idx = (position_ids == 0).nonzero(as_tuple=True)[1][num_examples] - mean_chosen_logits = logits[0, :split_idx][loss_mask[0, :split_idx]].mean() - mean_rejected_logits = logits[0, split_idx:][loss_mask[0, split_idx:]].mean() - else: - mean_chosen_logits = logits[:num_examples][loss_mask[:num_examples]].mean() - mean_rejected_logits = logits[num_examples:][loss_mask[num_examples:]].mean() - - output["mean_chosen_logits"] = mean_chosen_logits - output["mean_rejected_logits"] = mean_rejected_logits - - if self.aux_loss_enabled: - output["aux_loss"] = outputs.aux_loss - - return output + # Calculate accuracy only on non-padding tokens + correct_predictions = (predictions == shift_labels) & mask + total_tokens = mask.sum() + correct_tokens = correct_predictions.sum() - def get_batch_loss_metrics( - self, - model: Union[PreTrainedModel, nn.Module], - batch: dict[str, Union[list, torch.LongTensor]], - train_eval: Literal["train", "eval"] = "train", - ) -> tuple[torch.Tensor, dict[str, float]]: - """Compute the DPO loss and other metrics for the given batch of inputs for train or test.""" - metrics = {} - - if self.args.use_liger_loss: - model_output = self._compute_loss_liger(model, batch) - losses = model_output["loss"] - chosen_rewards = model_output["chosen_rewards"] - rejected_rewards = model_output["rejected_rewards"] - else: - model_output = self.concatenated_forward(model, batch) - - # if ref_chosen_logps and ref_rejected_logps in batch use them, otherwise use the reference model - if "ref_chosen_logps" in batch and "ref_rejected_logps" in batch: - ref_chosen_logps = batch["ref_chosen_logps"] - ref_rejected_logps = batch["ref_rejected_logps"] - else: - ref_chosen_logps, ref_rejected_logps = self.compute_ref_log_probs(batch) - - # Initialize combined losses - losses = 0 - chosen_rewards = 0 - rejected_rewards = 0 - - # Compute losses for each loss type - for idx, loss_type in enumerate(self.loss_type): - # Compute individual loss using standard DPO loss function - _losses, _chosen_rewards, _rejected_rewards = self.dpo_loss( - model_output["chosen_logps"], - model_output["rejected_logps"], - ref_chosen_logps, - ref_rejected_logps, - loss_type, - model_output, - ) - - # Add weighted contributions - weight = self.loss_weights[idx] if self.loss_weights else 1.0 - losses = losses + _losses * weight - chosen_rewards = chosen_rewards + _chosen_rewards * weight - rejected_rewards = rejected_rewards + _rejected_rewards * weight - - reward_accuracies = (chosen_rewards > rejected_rewards).float() - - if self.args.rpo_alpha is not None: - losses = losses + self.args.rpo_alpha * model_output["nll_loss"] # RPO loss from V3 of the paper - - if self.use_weighting: - losses = losses * model_output["policy_weights"] - - if self.aux_loss_enabled: - losses = losses + self.aux_loss_coef * model_output["aux_loss"] - - prefix = "eval_" if train_eval == "eval" else "" - metrics[f"{prefix}rewards/chosen"] = self.accelerator.gather_for_metrics(chosen_rewards).mean().item() - metrics[f"{prefix}rewards/rejected"] = self.accelerator.gather_for_metrics(rejected_rewards).mean().item() - metrics[f"{prefix}rewards/accuracies"] = self.accelerator.gather_for_metrics(reward_accuracies).mean().item() - metrics[f"{prefix}rewards/margins"] = ( - self.accelerator.gather_for_metrics(chosen_rewards - rejected_rewards).mean().item() - ) - metrics[f"{prefix}logps/chosen"] = ( - self.accelerator.gather_for_metrics(model_output["chosen_logps"]).detach().mean().item() - ) - metrics[f"{prefix}logps/rejected"] = ( - self.accelerator.gather_for_metrics(model_output["rejected_logps"]).detach().mean().item() - ) - metrics[f"{prefix}logits/chosen"] = ( - self.accelerator.gather_for_metrics(model_output["mean_chosen_logits"]).detach().mean().item() - ) - metrics[f"{prefix}logits/rejected"] = ( - self.accelerator.gather_for_metrics(model_output["mean_rejected_logits"]).detach().mean().item() - ) - if self.args.rpo_alpha is not None or "sft" in self.loss_type: - metrics[f"{prefix}nll_loss"] = ( - self.accelerator.gather_for_metrics(model_output["nll_loss"]).detach().mean().item() - ) - if self.aux_loss_enabled: - metrics[f"{prefix}aux_loss"] = ( - self.accelerator.gather_for_metrics(model_output["aux_loss"]).detach().mean().item() - ) - - return losses.mean(), metrics - - def compute_loss( - self, - model: Union[PreTrainedModel, nn.Module], - inputs: dict[str, Union[torch.Tensor, Any]], - return_outputs=False, - num_items_in_batch=None, - ) -> Union[torch.Tensor, tuple[torch.Tensor, dict[str, float]]]: - compute_loss_context_manager = ( - autocast(self.accelerator.device.type) if self._peft_has_been_casted_to_bf16 else nullcontext() - ) - with compute_loss_context_manager: - loss, metrics = self.get_batch_loss_metrics(model, inputs, train_eval="train") + # Gather the correct_tokens and total_tokens across all processes + correct_tokens = self.accelerator.gather_for_metrics(correct_tokens) + total_tokens = self.accelerator.gather_for_metrics(total_tokens) - # Make sure to move the loss to the device the original accumulating loss is at back in the `Trainer` class: - loss = loss.to(self.args.device) - # force log the metrics - self.store_metrics(metrics, train_eval="train") + # Compute the mean token accuracy and log it + total_sum = total_tokens.sum() + accuracy = (correct_tokens.sum() / total_sum).item() if total_sum > 0 else 0.0 + self._metrics[mode]["mean_token_accuracy"].append(accuracy) - if return_outputs: - return loss, metrics + return (loss, outputs) if return_outputs else loss - return loss - - def generate_from_model_and_ref(self, model, batch: dict[str, torch.LongTensor]) -> tuple[str, str]: - """Generate samples from the model and reference model for the given batch of inputs.""" - - # If one uses `generate_during_eval` with peft + bf16, we need to explicitly call generate with - # the torch amp context manager as some hidden states are silently casted to full precision. - generate_context_manager = ( - autocast(self.accelerator.device.type) if self._peft_has_been_casted_to_bf16 else nullcontext() - ) - - with generate_context_manager: - policy_output = model.generate( - input_ids=batch["prompt_input_ids"], - attention_mask=batch["prompt_attention_mask"], - max_length=self.max_length, - do_sample=True, - pad_token_id=self.padding_value, - ) - - # if ref_output in batch use that otherwise use the reference model - if "ref_output" in batch: - ref_output = batch["ref_output"] - else: - if self.ref_model is None: - with self.null_ref_context(): - ref_output = self.model.generate( - input_ids=batch["prompt_input_ids"], - attention_mask=batch["prompt_attention_mask"], - max_length=self.max_length, - do_sample=True, - pad_token_id=self.padding_value, - ) - else: - ref_output = self.ref_model.generate( - input_ids=batch["prompt_input_ids"], - attention_mask=batch["prompt_attention_mask"], - max_length=self.max_length, - do_sample=True, - pad_token_id=self.padding_value, - ) - - policy_output = pad_to_length(policy_output, self.max_length, self.padding_value) - policy_output_decoded = self.processing_class.batch_decode(policy_output, skip_special_tokens=True) - - ref_output = pad_to_length(ref_output, self.max_length, self.padding_value) - ref_output_decoded = self.processing_class.batch_decode(ref_output, skip_special_tokens=True) - - return policy_output_decoded, ref_output_decoded - - def prediction_step( - self, - model: Union[PreTrainedModel, nn.Module], - inputs: dict[str, Union[torch.Tensor, Any]], - prediction_loss_only: bool, - ignore_keys: Optional[list[str]] = None, - ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]: - if ignore_keys is None: - if hasattr(model, "config"): - ignore_keys = getattr(model.config, "keys_to_ignore_at_inference", []) - else: - ignore_keys = [] - - prediction_context_manager = ( - autocast(self.accelerator.device.type) if self._peft_has_been_casted_to_bf16 else nullcontext() - ) - - with torch.no_grad(), prediction_context_manager: - loss, metrics = self.get_batch_loss_metrics(model, inputs, train_eval="eval") - - # force log the metrics - self.store_metrics(metrics, train_eval="eval") - - if prediction_loss_only: - return loss.detach(), None, None - - # logits for the chosen and rejected samples from model - logits_dict = { - "eval_logits/chosen": metrics["eval_logits/chosen"], - "eval_logits/rejected": metrics["eval_logits/rejected"], - } - logits = [v for k, v in logits_dict.items() if k not in ignore_keys] - logits = torch.tensor(logits, device=self.accelerator.device) - labels = torch.zeros(logits.shape[0], device=self.accelerator.device) - - return (loss.detach(), logits, labels) - - def store_metrics(self, metrics: dict[str, float], train_eval: Literal["train", "eval"] = "train") -> None: - for key, value in metrics.items(): - self._stored_metrics[train_eval][key].append(value) - - def evaluation_loop( - self, - dataloader: DataLoader, - description: str, - prediction_loss_only: Optional[bool] = None, - ignore_keys: Optional[list[str]] = None, - metric_key_prefix: str = "eval", - ) -> EvalLoopOutput: - """ - Overriding built-in evaluation loop to store metrics for each batch. Prediction/evaluation loop, shared by - `Trainer.evaluate()` and `Trainer.predict()`. - - Works both with or without labels. - """ - - # Sample and save to game log if requested (for one batch to save time) - if self.generate_during_eval: - # Generate random indices within the range of the total number of samples - num_samples = len(dataloader.dataset) - random_indices = random.sample(range(num_samples), k=self.args.eval_batch_size) - - # Use dataloader.dataset.select to get the random batch without iterating over the DataLoader - random_batch_dataset = dataloader.dataset.select(random_indices) - random_batch = self.data_collator(random_batch_dataset) - random_batch = self._prepare_inputs(random_batch) - - policy_output_decoded, ref_output_decoded = self.generate_from_model_and_ref(self.model, random_batch) - - table = pd.DataFrame( - columns=["Prompt", "Policy", "Ref Model"], - data=[ - [prompt, pol[len(prompt) :], ref[len(prompt) :]] - for prompt, pol, ref in zip( - random_batch_dataset["prompt"], policy_output_decoded, ref_output_decoded - ) - ], - ) - if "wandb" in self.args.report_to and self.accelerator.is_main_process: - wandb.log({"game_log": wandb.Table(data=table)}) - - if "comet_ml" in self.args.report_to: - log_table_to_comet_experiment( - name="game_log.csv", - table=table, - ) - - if "mlflow" in self.args.report_to and self.accelerator.is_main_process: - mlflow.log_table(data=table, artifact_file="game_log.json") - - # Base evaluation - initial_output = super().evaluation_loop( - dataloader, description, prediction_loss_only, ignore_keys, metric_key_prefix - ) - - return initial_output + # Override training step to add activation offloading context. + def training_step(self, *args, **kwargs): + with self.maybe_activation_offload_context: + return super().training_step(*args, **kwargs) def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> None: - """ - Log `logs` on the various objects watching training, including stored metrics. + mode = "train" if self.model.training else "eval" + metrics = {key: sum(val) / len(val) for key, val in self._metrics[mode].items()} # average the metrics - Args: - logs (`dict[str, float]`): - The values to log. - start_time (`float` or `None`, *optional*, defaults to `None`): - Start time of the training. - """ - # logs either has 'loss' or 'eval_loss' - train_eval = "train" if "loss" in logs else "eval" - # Add averaged stored metrics to logs - for key, metrics in self._stored_metrics[train_eval].items(): - logs[key] = torch.tensor(metrics).mean().item() - del self._stored_metrics[train_eval] - return super().log(logs, start_time) + # This method can be called both in training and evaluation. When called in evaluation, the keys in `logs` + # start with "eval_". We need to add the prefix "eval_" to the keys in `metrics` to match the format. + if mode == "eval": + metrics = {f"eval_{key}": val for key, val in metrics.items()} + + logs.update(metrics) + super().log(logs, start_time) + self._metrics[mode].clear() # Ensure the model card is saved along with the checkpoint def _save_checkpoint(self, model, trial): @@ -1967,31 +881,15 @@ def create_model_card( tags.update(self._tag_names) - # docstyle-ignore - citation = textwrap.dedent( - """\ - @inproceedings{rafailov2023direct, - title = {{Direct Preference Optimization: Your Language Model is Secretly a Reward Model}}, - author = {Rafael Rafailov and Archit Sharma and Eric Mitchell and Christopher D. Manning and Stefano Ermon and Chelsea Finn}, - year = 2023, - booktitle = {Advances in Neural Information Processing Systems 36: Annual Conference on Neural Information Processing Systems 2023, NeurIPS 2023, New Orleans, LA, USA, December 10 - 16, 2023}, - url = {http://papers.nips.cc/paper_files/paper/2023/hash/a85b405ed65c6477a4fe8302b5e06ce7-Abstract-Conference.html}, - editor = {Alice Oh and Tristan Naumann and Amir Globerson and Kate Saenko and Moritz Hardt and Sergey Levine}, - }""" - ) - model_card = generate_model_card( base_model=base_model, model_name=model_name, hub_model_id=self.hub_model_id, dataset_name=dataset_name, - tags=tags, + tags=list(tags), wandb_url=wandb.run.url if is_wandb_available() and wandb.run is not None else None, comet_url=get_comet_experiment_url(), trainer_name="DPO", - trainer_citation=citation, - paper_title="Direct Preference Optimization: Your Language Model is Secretly a Reward Model", - paper_id="2305.18290", ) model_card.save(os.path.join(self.args.output_dir, "README.md")) diff --git a/trl/trainer/sft_trainer.py b/trl/trainer/sft_trainer.py index ebfc60f64bf..a85096cbfa7 100644 --- a/trl/trainer/sft_trainer.py +++ b/trl/trainer/sft_trainer.py @@ -186,7 +186,7 @@ def torch_call(self, examples: list[Union[list[int], Any, dict[str, Any]]]) -> d # For packing with position_ids, we should NOT create attention_mask as it causes # FlashAttention to ignore position_ids and compute wrong cu_seq_lens from the all-1s mask if not has_packed_position_ids: - attention_mask = [torch.ones_like(input_ids) for input_ids in input_ids] + attention_mask = [torch.ones_like(ids) for ids in input_ids] if self.return_position_ids: if "seq_lengths" in examples[0]: